| import faiss |
| import pickle |
| import logging |
| from pathlib import Path |
| from sentence_transformers import SentenceTransformer, CrossEncoder |
| logger = logging.getLogger(__name__) |
|
|
class Context:
    """Retrieve text chunks from a FAISS index and rerank them with a cross-encoder.

    The embedding and reranking models are instantiated in ``__init__``; the
    FAISS index and the pickled chunk list are loaded lazily on first use.
    """

    def __init__(self, chunk_faiss: Path, chunk_pkl: Path):
        """Store the artifact paths and instantiate both models.

        Args:
            chunk_faiss: Path to the serialized FAISS index.
            chunk_pkl: Path to the pickle file holding the chunk texts. Chunks
                are looked up by the ids the index returns, so their order
                must match the order of the vectors in the index.
        """
        self.chunk_faiss = chunk_faiss
        self.chunk_pkl = chunk_pkl
        self.embed_model = SentenceTransformer("BAAI/bge-small-en-v1.5")
        self.rerank_model = CrossEncoder("cross-encoder/ms-marco-MiniLM-L-6-v2")

        # Filled in by load_index_and_chunks(); ``index is None`` means
        # "nothing loaded yet".
        self.index = None
        self.chunks: list[str] = []

    def load_index_and_chunks(self) -> None:
        """Load the FAISS index and the chunk list from disk, once.

        Subsequent calls are no-ops. Both artifacts are read into locals and
        only published on the instance after both loads succeeded, so a
        failure while reading the pickle cannot leave the instance looking
        "loaded" with an empty chunk list.

        Raises:
            Whatever ``faiss.read_index``, ``open`` or ``pickle.load`` raise
            when an artifact is missing or unreadable.
        """
        if self.index is not None:
            return

        index = faiss.read_index(str(self.chunk_faiss))
        # NOTE(security): pickle.load can execute arbitrary code. Only point
        # chunk_pkl at files produced by a trusted pipeline.
        with open(self.chunk_pkl, "rb") as f:
            chunks = pickle.load(f)

        if index.ntotal != len(chunks):
            # Ids returned by the index are used as positions in ``chunks``;
            # a size mismatch means lookups may fail or return wrong text.
            logger.warning(
                "FAISS index holds %d vectors but %d chunks were loaded",
                index.ntotal,
                len(chunks),
            )

        # Assign the index last: it doubles as the "already loaded" flag.
        self.chunks = chunks
        self.index = index
        logger.info("Loaded FAISS index and %d chunks", len(self.chunks))

    def retrieve_chunks(self, query: str, top_k: int, retrieve_k: int) -> list[str]:
        """Return the ``top_k`` chunks most relevant to ``query``.

        The query is embedded and used to fetch ``retrieve_k`` candidates from
        the FAISS index; the candidates are then scored against the query by
        the cross-encoder and the highest-scoring ``top_k`` are returned.

        Args:
            query: The search query.
            top_k: Maximum number of chunks to return after reranking.
            retrieve_k: Number of candidates to request from the index.

        Returns:
            Chunk texts ordered by descending reranker score. Empty when
            ``top_k`` or ``retrieve_k`` is not positive, or when the index
            yields no candidates.
        """
        # A negative top_k would otherwise slice ``ranked[:-k]`` and silently
        # return almost everything; a non-positive retrieve_k is not a
        # meaningful search request. Bail out before any I/O or inference.
        if top_k <= 0 or retrieve_k <= 0:
            return []

        self.load_index_and_chunks()

        query_embedding = self.embed_model.encode([query], normalize_embeddings=True)
        _, indices = self.index.search(query_embedding, retrieve_k)
        # The search pads the result row with -1 when it finds fewer than
        # retrieve_k neighbours; drop those placeholders.
        candidates = [self.chunks[i] for i in indices[0] if i != -1]
        if not candidates:
            return []

        pairs = [[query, c] for c in candidates]
        scores = self.rerank_model.predict(pairs)
        # Sort on the score only, so equal scores keep their retrieval order.
        ranked = sorted(zip(scores, candidates), key=lambda x: x[0], reverse=True)
        results = [text for _, text in ranked[:top_k]]
        logger.info(
            "Retrieved %d chunks after reranking (from %d candidates)",
            len(results),
            len(candidates),
        )
        return results