# got_retrievers/context_retriever.py
# Provenance: uploaded by "hash-map" ("Upload 5 files", commit dff5c6e, verified).
# context_retriever.py
# Hybrid retrieval over a prebuilt RAG corpus: dense FAISS search plus
# sparse BM25 keyword matching.

# --- standard library ---
import json
import logging
import os
import pickle
import re

# --- third party ---
import faiss
import numpy as np
from tqdm.notebook import tqdm  # noqa: F401 -- kept; may be used by code outside this view
from sentence_transformers import SentenceTransformer
from langchain_community.retrievers import BM25Retriever
from langchain.docstore.document import Document

logging.basicConfig(level=logging.INFO,
                    format='%(asctime)s - %(levelname)s - %(message)s')
logger = logging.getLogger(__name__)
# Directory holding the prebuilt retrieval artifacts.
WORK = "context"
JSONL = f"{WORK}/rag_documents.jsonl"       # one JSON document per line
FAISS_INDEX = f"{WORK}/faiss_ivf.index"     # dense IVF index over the corpus
BM25_PICKLE = f"{WORK}/bm25_retriever.pkl"  # cached BM25 retriever (built on first run)

logger.info("Loading all RAG documents...")
with open(JSONL, encoding='utf-8') as f:
    ALL_DOCS = [json.loads(line) for line in f]

# Map FAISS row id -> document text / metadata.
# Built in a single pass (the original enumerated ALL_DOCS twice).
LINE_TO_TEXT = {}
LINE_TO_META = {}
for i, doc in enumerate(ALL_DOCS):
    LINE_TO_TEXT[i] = doc["text"]
    LINE_TO_META[i] = doc["metadata"]
class HybridRetriever:
    """Dense (FAISS) + sparse (BM25) retriever over the RAG corpus.

    FAISS supplies semantically similar documents; BM25 tops results up
    with keyword matches that dense retrieval may miss.
    """

    def __init__(self):
        # FAISS index runs on CPU.
        self.faiss_index = faiss.read_index(FAISS_INDEX)
        logger.info("FAISS loaded (%s vectors)", f"{self.faiss_index.ntotal:,}")
        # Query encoder: GPU when CUDA devices are exposed, else CPU.
        self.model = SentenceTransformer(
            "sentence-transformers/all-MiniLM-L6-v2",
            device="cuda" if os.environ.get("CUDA_VISIBLE_DEVICES") else "cpu")
        # BM25: load the pickled cache if present, otherwise build and cache it.
        if os.path.exists(BM25_PICKLE):
            # NOTE(review): pickle.load is only safe on this locally produced,
            # trusted cache file -- never point it at untrusted data.
            # Fix: use `with` so the file handle is closed (original leaked it).
            with open(BM25_PICKLE, "rb") as fh:
                self.bm25 = pickle.load(fh)
            logger.info("BM25 loaded")
        else:
            logger.info("Building BM25...")
            # Strip the "Filename:/FullPath:" header block so BM25 scores the
            # document body only.
            docs = [Document(page_content=re.sub(r"^Filename:.*\nFullPath:.*\n\n", "",
                                                 doc["text"], flags=re.M),
                             metadata=doc["metadata"]) for doc in ALL_DOCS]
            self.bm25 = BM25Retriever.from_documents(docs)
            self.bm25.k = 30
            # Fix: use `with` so the file handle is closed (original leaked it).
            with open(BM25_PICKLE, "wb") as fh:
                pickle.dump(self.bm25, fh)
            logger.info("BM25 built and saved")

    def batch_retrieve(self, queries, top_k=3, faiss_k=10, bm25_k=3):
        """Retrieve up to ``top_k`` documents for each query.

        Args:
            queries: list of query strings.
            top_k: maximum number of results returned per query.
            faiss_k: number of dense candidates fetched from FAISS per query.
            bm25_k: maximum number of BM25 documents considered per query.

        Returns:
            A list with one entry per query; each entry is a list of dicts
            with keys ``score``, ``text``, ``metadata`` and ``source``
            ("FAISS" or "BM25"). Never longer than ``top_k``.
        """
        qvecs = self.model.encode(queries, show_progress_bar=False,
                                  normalize_embeddings=True).astype("float32")
        D, I = self.faiss_index.search(qvecs, faiss_k)
        batch_results = []
        for qi, (scores, indices) in enumerate(zip(D, I)):
            results = []
            seen = set()
            # Dense hits first.
            for score, idx in zip(scores, indices):
                if len(results) >= top_k:
                    break
                if idx == -1 or idx in seen:  # -1 = FAISS padding for missing hits
                    continue
                results.append({"score": float(score), "text": LINE_TO_TEXT[idx],
                                "metadata": LINE_TO_META[idx], "source": "FAISS"})
                seen.add(idx)
            # Top up with BM25 keyword hits.
            # Fix: check the cap BEFORE appending. The original checked only
            # after appending, so when FAISS had already filled top_k slots it
            # still appended one extra BM25 doc (returning top_k + 1 results),
            # and it invoked BM25 even when no slots remained.
            if len(results) < top_k:
                for doc in self.bm25.invoke(queries[qi])[:bm25_k]:
                    if len(results) >= top_k:
                        break
                    ln = doc.metadata.get("line_no")
                    if ln in seen:
                        continue
                    results.append({"score": 0.0, "text": LINE_TO_TEXT.get(ln, ""),
                                    "metadata": LINE_TO_META.get(ln, doc.metadata),
                                    "source": "BM25"})
                    seen.add(ln)
            batch_results.append(results)
        return batch_results
# Module-level singleton: construction is expensive (index load, encoder
# load, possible BM25 build), so it happens once at import time and all
# importers share this instance.
retriever = HybridRetriever()