# rag_query.py — FireSymptom project: embed a query, search the FAISS index,
# and return the matching metadata records.
import json
import faiss
import numpy as np
import os
# ---------- LOAD SETTINGS ----------
def load_settings():
    """Read and return the project configuration from config/settings.json."""
    with open("config/settings.json", encoding="utf-8") as cfg_file:
        return json.load(cfg_file)
# Configuration is loaded once at import time; a KeyError here means
# settings.json is missing one of the required keys.
SETTINGS = load_settings()
EMBEDDING_MODEL, FAISS_INDEX_PATH, METADATA_PATH, TOP_K = (
    SETTINGS["embedding_model"],
    SETTINGS["faiss_index_path"],
    SETTINGS["metadata_path"],
    SETTINGS["top_k"],
)
# ---------- GLOBAL CACHES ----------
_model = None
_index = None
_metadata = None
# ---------- LOAD RESOURCES (LAZY) ----------
def load_resources():
    """Populate the module-level caches (_model, _index, _metadata) on first use.

    Safe to call repeatedly: each resource is initialised only once.
    Raises FileNotFoundError when the index or metadata file is absent.
    """
    global _model, _index, _metadata

    # sentence_transformers is imported lazily so merely importing this
    # module stays cheap.
    if _model is None:
        from sentence_transformers import SentenceTransformer
        _model = SentenceTransformer(EMBEDDING_MODEL)

    if _index is None:
        if os.path.exists(FAISS_INDEX_PATH):
            _index = faiss.read_index(FAISS_INDEX_PATH)
        else:
            raise FileNotFoundError("FAISS index not found. Build the index first.")

    if _metadata is None:
        if os.path.exists(METADATA_PATH):
            with open(METADATA_PATH, encoding="utf-8") as meta_file:
                _metadata = json.load(meta_file)
        else:
            raise FileNotFoundError("Metadata file not found. Build the index first.")
# ---------- RETRIEVAL ----------
def retrieve(query: str):
    """Return up to TOP_K metadata records most similar to *query*.

    Embeds the query with the cached sentence-transformer model, searches
    the FAISS index, and maps the resulting row ids back to metadata entries.
    """
    load_resources()
    # FAISS expects a 2-D float32 array of query vectors.
    query_embedding = _model.encode([query]).astype("float32")
    _distances, indices = _index.search(query_embedding, TOP_K)
    results = []
    for idx in indices[0]:
        # FAISS pads `indices` with -1 when the index holds fewer than TOP_K
        # vectors; the old `idx < len(_metadata)` check let -1 through, which
        # wrapped to _metadata[-1] and returned a bogus match. Guard both ends.
        if 0 <= idx < len(_metadata):
            results.append(_metadata[idx])
    return results