# context_retriever.py
"""Hybrid (FAISS + BM25) context retriever over pre-built RAG documents.

Loads the document corpus and a dense FAISS index from ``context/`` at
import time and exposes a module-level ``retriever`` singleton.
"""

import os
import re
import json
import pickle
import logging

import numpy as np
import faiss
from tqdm.notebook import tqdm
from sentence_transformers import SentenceTransformer
from langchain_community.retrievers import BM25Retriever
from langchain.docstore.document import Document

logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
logger = logging.getLogger(__name__)

# Artifact locations produced by the (offline) indexing pipeline.
WORK = "context"
JSONL = f"{WORK}/rag_documents.jsonl"       # one JSON object per line: {"text": ..., "metadata": ...}
FAISS_INDEX = f"{WORK}/faiss_ivf.index"
BM25_PICKLE = f"{WORK}/bm25_retriever.pkl"

logger.info("Loading all RAG documents...")
# FIX: context manager — the original left the file handle open.
with open(JSONL, encoding='utf-8') as f:
    ALL_DOCS = [json.loads(line) for line in f]

# FAISS row id -> document text / metadata. Assumes vectors were added to the
# index in JSONL order — TODO confirm against the indexing pipeline.
LINE_TO_TEXT = {i: doc["text"] for i, doc in enumerate(ALL_DOCS)}
LINE_TO_META = {i: doc["metadata"] for i, doc in enumerate(ALL_DOCS)}


class HybridRetriever:
    """Dense-first hybrid retriever: FAISS hits, then BM25 fill-ins."""

    def __init__(self):
        # FAISS index (CPU) built offline.
        self.faiss_index = faiss.read_index(FAISS_INDEX)
        logger.info(f"FAISS loaded ({self.faiss_index.ntotal:,} vectors)")

        # Query encoder. NOTE(review): the CUDA_VISIBLE_DEVICES env check is a
        # heuristic (torch.cuda.is_available() would be more reliable); the
        # original env-var contract is preserved here.
        self.model = SentenceTransformer(
            "sentence-transformers/all-MiniLM-L6-v2",
            device="cuda" if os.environ.get("CUDA_VISIBLE_DEVICES") else "cpu",
        )

        # BM25: load the cached retriever if present, otherwise build and cache.
        # NOTE(security): unpickling executes arbitrary code — only load pickles
        # produced by this process / a trusted pipeline.
        if os.path.exists(BM25_PICKLE):
            # FIX: context managers so handles close even if load/dump raises.
            with open(BM25_PICKLE, "rb") as fh:
                self.bm25 = pickle.load(fh)
            logger.info("BM25 loaded")
        else:
            logger.info("Building BM25...")
            # Strip the "Filename:/FullPath:" header so BM25 scores body text
            # only. Pattern hoisted out of the per-document loop.
            header_re = re.compile(r"^Filename:.*\nFullPath:.*\n\n", flags=re.M)
            docs = [
                Document(page_content=header_re.sub("", doc["text"]),
                         metadata=doc["metadata"])
                for doc in ALL_DOCS
            ]
            self.bm25 = BM25Retriever.from_documents(docs)
            self.bm25.k = 30
            with open(BM25_PICKLE, "wb") as fh:
                pickle.dump(self.bm25, fh)
            logger.info("BM25 built and saved")

    def batch_retrieve(self, queries, top_k=3, faiss_k=10, bm25_k=3):
        """Retrieve up to ``top_k`` documents for each query.

        FAISS (dense) hits are taken first, highest score first; if fewer
        than ``top_k`` survive deduplication, BM25 results fill the rest.

        Args:
            queries: list of query strings.
            top_k:   maximum results returned per query.
            faiss_k: candidate count requested from FAISS per query.
            bm25_k:  maximum BM25 fill-ins considered per query.

        Returns:
            A list parallel to ``queries``; each element is a list of dicts
            with keys ``score``, ``text``, ``metadata``, ``source``. BM25
            fill-ins carry ``score == 0.0`` (BM25 scores are not on the
            cosine-similarity scale).
        """
        qvecs = self.model.encode(
            queries, show_progress_bar=False, normalize_embeddings=True
        ).astype("float32")
        D, I = self.faiss_index.search(qvecs, faiss_k)

        batch_results = []
        for qi, (scores, indices) in enumerate(zip(D, I)):
            results = []
            seen = set()

            # Dense hits first; FAISS pads empty slots with -1.
            for score, idx in zip(scores, indices):
                if idx == -1 or idx in seen:
                    continue
                results.append({
                    "score": float(score),
                    "text": LINE_TO_TEXT[idx],
                    "metadata": LINE_TO_META[idx],
                    "source": "FAISS",
                })
                seen.add(idx)
                if len(results) >= top_k:
                    break

            # Sparse fill. FIX: the original appended before checking the
            # quota, so a full FAISS result set still gained one extra BM25
            # entry (top_k + 1 results). Guard before appending — and skip
            # the BM25 call entirely when FAISS already filled the quota.
            if len(results) < top_k:
                bm25_docs = self.bm25.invoke(queries[qi])
                for doc in bm25_docs[:bm25_k]:
                    if len(results) >= top_k:
                        break
                    ln = doc.metadata.get("line_no")
                    if ln in seen:
                        continue
                    results.append({
                        "score": 0.0,
                        "text": LINE_TO_TEXT.get(ln, ""),
                        "metadata": LINE_TO_META.get(ln, doc.metadata),
                        "source": "BM25",
                    })
                    seen.add(ln)

            batch_results.append(results)
        return batch_results


# Module-level singleton — constructed at import time (performs file I/O and
# model loading; importing this module is therefore expensive).
retriever = HybridRetriever()