| |
|
| | import os, re, json, pickle, logging, numpy as np, faiss
|
| | from tqdm.notebook import tqdm
|
| | from sentence_transformers import SentenceTransformer
|
| | from langchain_community.retrievers import BM25Retriever
|
| | from langchain.docstore.document import Document
|
| |
|
# Root-logger setup: timestamped INFO-level lines for progress reporting.
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(levelname)s - %(message)s',
)
logger = logging.getLogger(__name__)

# Working directory that holds every retrieval artifact.
WORK = "context"
# Corpus file: one JSON document per line (JSONL).
JSONL = WORK + "/rag_documents.jsonl"
# Pre-built dense (FAISS IVF) vector index over the corpus.
FAISS_INDEX = WORK + "/faiss_ivf.index"
# Cached BM25 retriever, pickled after the first build.
BM25_PICKLE = WORK + "/bm25_retriever.pkl"
|
| |
|
logger.info("Loading all RAG documents...")
with open(JSONL, encoding='utf-8') as f:
    # Skip blank lines so a trailing newline or stray padding in the JSONL
    # file doesn't crash json.loads (json.loads("") raises ValueError).
    ALL_DOCS = [json.loads(line) for line in f if line.strip()]

# Row-number -> field lookup tables used by the retriever.
# Built in a single pass instead of enumerating ALL_DOCS twice.
LINE_TO_TEXT = {}
LINE_TO_META = {}
for i, doc in enumerate(ALL_DOCS):
    LINE_TO_TEXT[i] = doc["text"]
    LINE_TO_META[i] = doc["metadata"]
|
| |
|
class HybridRetriever:
    """Hybrid dense + sparse retriever over the RAG corpus.

    FAISS supplies the primary (dense) candidates per query; BM25 backfills
    any remaining slots up to ``top_k`` with lexical matches not already
    returned by FAISS.
    """

    def __init__(self):
        """Load the FAISS index, the sentence encoder, and the BM25 retriever.

        The BM25 retriever is built from ``ALL_DOCS`` on the first run and
        cached to ``BM25_PICKLE``; later runs load the cache.
        """
        # Pre-built dense IVF index on disk.
        self.faiss_index = faiss.read_index(FAISS_INDEX)
        logger.info(f"FAISS loaded ({self.faiss_index.ntotal:,} vectors)")

        # Query encoder; presumably the same model that built the index --
        # TODO confirm against the index-building script.
        self.model = SentenceTransformer(
            "sentence-transformers/all-MiniLM-L6-v2",
            device="cuda" if os.environ.get("CUDA_VISIBLE_DEVICES") else "cpu")

        if os.path.exists(BM25_PICKLE):
            # NOTE(security): pickle.load can execute arbitrary code.
            # Only load caches this pipeline wrote itself.
            # Use a context manager so the handle is closed promptly
            # (the original `pickle.load(open(...))` leaked it).
            with open(BM25_PICKLE, "rb") as fh:
                self.bm25 = pickle.load(fh)
            logger.info("BM25 loaded")
        else:
            logger.info("Building BM25...")
            # Strip the leading "Filename:/FullPath:" header block so BM25
            # scores only the document body, not boilerplate metadata.
            docs = [
                Document(
                    page_content=re.sub(
                        r"^Filename:.*\nFullPath:.*\n\n", "",
                        doc["text"], flags=re.M),
                    metadata=doc["metadata"],
                )
                for doc in ALL_DOCS
            ]
            self.bm25 = BM25Retriever.from_documents(docs)
            self.bm25.k = 30
            with open(BM25_PICKLE, "wb") as fh:
                pickle.dump(self.bm25, fh)
            logger.info("BM25 built and saved")

    def batch_retrieve(self, queries, top_k=3, faiss_k=10, bm25_k=3):
        """Retrieve up to ``top_k`` documents for each query.

        Args:
            queries: list of query strings.
            top_k: maximum number of results per query.
            faiss_k: number of dense candidates fetched from FAISS per query.
            bm25_k: maximum number of BM25 candidates considered for backfill.

        Returns:
            One list per query, each containing at most ``top_k`` dicts with
            keys ``score``, ``text``, ``metadata`` and ``source``
            ("FAISS" or "BM25"; BM25 hits carry score 0.0).
        """
        qvecs = self.model.encode(
            queries, show_progress_bar=False,
            normalize_embeddings=True).astype("float32")
        D, I = self.faiss_index.search(qvecs, faiss_k)

        batch_results = []
        for qi, (scores, indices) in enumerate(zip(D, I)):
            results = []
            seen = set()
            for score, idx in zip(scores, indices):
                # FAISS pads missing results with -1.
                if idx == -1 or idx in seen:
                    continue
                idx = int(idx)  # numpy int64 -> plain int for dict keys
                results.append({"score": float(score),
                                "text": LINE_TO_TEXT[idx],
                                "metadata": LINE_TO_META[idx],
                                "source": "FAISS"})
                seen.add(idx)
                if len(results) >= top_k:
                    break

            # BUGFIX: the length check used to come AFTER the append, so one
            # extra BM25 doc was added even when FAISS had already filled all
            # top_k slots, letting results exceed top_k. Guard first, and skip
            # the BM25 call entirely when no slots remain.
            if len(results) < top_k:
                for doc in self.bm25.invoke(queries[qi])[:bm25_k]:
                    ln = doc.metadata.get("line_no")
                    if ln in seen:
                        continue
                    results.append({"score": 0.0,
                                    "text": LINE_TO_TEXT.get(ln, ""),
                                    "metadata": LINE_TO_META.get(ln, doc.metadata),
                                    "source": "BM25"})
                    seen.add(ln)
                    if len(results) >= top_k:
                        break
            batch_results.append(results)
        return batch_results
|
| |
|
| |
|
# Module-level singleton: instantiation loads the FAISS index, the sentence
# encoder, and the BM25 cache, so it is done once at import time.
retriever = HybridRetriever()
|
| |
|