# transcript-rag-summarizer / src/retrieve_context.py
# Author: Anshul Prasad — "minor fix." (commit 5af1f57)
import faiss
import pickle
import logging
from pathlib import Path
from sentence_transformers import SentenceTransformer, CrossEncoder
logger = logging.getLogger(__name__)
class Context:
    """Dense retrieval + cross-encoder reranking over a prebuilt FAISS index.

    The FAISS index and the pickled chunk texts are loaded lazily on the
    first query so that constructing a Context stays cheap.
    """

    def __init__(self, chunk_faiss: Path, chunk_pkl: Path):
        """
        Args:
            chunk_faiss: Path to the serialized FAISS index file.
            chunk_pkl: Path to the pickle file holding the chunk texts
                (expected to be a list[str] aligned with the index rows —
                TODO confirm against the indexing script).
        """
        self.chunk_faiss = chunk_faiss
        self.chunk_pkl = chunk_pkl
        # Bi-encoder embeds queries for dense retrieval; the cross-encoder
        # rescores (query, chunk) pairs for the final ranking.
        self.embed_model = SentenceTransformer("BAAI/bge-small-en-v1.5")
        self.rerank_model = CrossEncoder("cross-encoder/ms-marco-MiniLM-L-6-v2")
        # Load index and chunks lazily on first query
        self.index = None
        self.chunks: list[str] = []

    def load_index_and_chunks(self) -> None:
        """Load the FAISS index and chunk texts on first use (idempotent)."""
        if self.index is not None:
            return
        self.index = faiss.read_index(str(self.chunk_faiss))
        # SECURITY: pickle.load executes arbitrary code on malicious input —
        # only ever load chunk files produced by this project's own indexer.
        with open(self.chunk_pkl, "rb") as f:
            self.chunks = pickle.load(f)
        logger.info("Loaded FAISS index and %d chunks", len(self.chunks))

    def retrieve_chunks(self, query: str, top_k: int = 5, retrieve_k: int = 20) -> list[str]:
        """Return up to *top_k* chunks most relevant to *query*, best first.

        Args:
            query: Natural-language search query.
            top_k: Number of chunks to return after reranking (default 5).
            retrieve_k: Number of dense-retrieval candidates to rerank
                (default 20); should be >= top_k for reranking to matter.

        Returns:
            Up to top_k chunk texts, highest-scored first; [] when the query
            is empty, top_k is non-positive, or nothing matches.
        """
        self.load_index_and_chunks()
        # Guard: an empty query or non-positive top_k can only yield [],
        # so skip the embed + search round trip entirely.
        if not query or top_k <= 0:
            return []
        # Step 1 — dense retrieval over normalized embeddings
        query_embedding = self.embed_model.encode([query], normalize_embeddings=True)
        # FAISS pads results with -1 when the index holds fewer than
        # retrieve_k vectors, hence the filter below.
        _, indices = self.index.search(query_embedding, retrieve_k)
        candidates = [self.chunks[i] for i in indices[0] if i != -1]
        if not candidates:
            return []
        # Step 2 — rerank candidates with the cross-encoder
        pairs = [[query, c] for c in candidates]
        scores = self.rerank_model.predict(pairs)
        ranked = sorted(zip(scores, candidates), key=lambda x: x[0], reverse=True)
        results = [text for _, text in ranked[:top_k]]
        logger.info("Retrieved %d chunks after reranking (from %d candidates)", len(results), len(candidates))
        return results