File size: 1,727 Bytes
6358641
5190b74
6358641
 
5190b74
6358641
c45eea1
a5e4950
 
 
 
 
 
 
 
 
 
 
 
 
 
5af1f57
a5e4950
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
import faiss
import pickle
import logging
from pathlib import Path
from sentence_transformers import SentenceTransformer, CrossEncoder
logger = logging.getLogger(__name__)

class Context:
    """Two-stage retriever: dense FAISS search followed by cross-encoder reranking.

    The embedding and reranking models are constructed eagerly in ``__init__``.
    The FAISS index and the pickled list of chunk texts are loaded lazily on the
    first call to :meth:`retrieve_chunks`.
    """

    def __init__(self, chunk_faiss: Path, chunk_pkl: Path):
        """Store the artifact paths and construct the models.

        Args:
            chunk_faiss: Path to a serialized FAISS index (read with
                ``faiss.read_index``).
            chunk_pkl: Path to a pickle holding the chunk texts. Search hits
                are mapped to texts by position, so entry ``i`` is assumed to
                correspond to vector id ``i`` in the index.
        """
        self.chunk_faiss = chunk_faiss
        self.chunk_pkl = chunk_pkl
        self.embed_model = SentenceTransformer("BAAI/bge-small-en-v1.5")
        self.rerank_model = CrossEncoder("cross-encoder/ms-marco-MiniLM-L-6-v2")

        # Load index and chunks lazily on first query
        self.index = None
        self.chunks: list[str] = []

    def load_index_and_chunks(self) -> None:
        """Load the FAISS index and chunk texts once; later calls are no-ops.

        Both artifacts are read into locals first. They are published on
        ``self`` only after both reads succeed, so a failure (e.g. a missing or
        corrupt pickle) leaves the object unloaded and the next call retries.
        Otherwise the object would cache an index with no chunks behind it.
        """
        if self.index is not None:
            return
        index = faiss.read_index(str(self.chunk_faiss))
        # SECURITY: pickle.load can execute arbitrary code while unpickling;
        # chunk_pkl must come from a trusted source.
        with open(self.chunk_pkl, "rb") as f:
            chunks = pickle.load(f)
        if index.ntotal != len(chunks):
            # Hits are mapped back to texts by position, so a size mismatch
            # means some results will point at the wrong chunk (or none).
            logger.warning(
                "FAISS index has %d vectors but %d chunks were loaded",
                index.ntotal,
                len(chunks),
            )
        # Assign chunks first so that `index is not None` implies chunks are set.
        self.chunks = chunks
        self.index = index
        logger.info("Loaded FAISS index and %d chunks", len(self.chunks))

    def retrieve_chunks(self, query: str, top_k: int, retrieve_k: int) -> list[str]:
        """Return up to ``top_k`` chunks most relevant to ``query``.

        Args:
            query: Free-text query.
            top_k: Number of chunks to keep after reranking. Values <= 0
                yield an empty list.
            retrieve_k: Number of nearest neighbours to fetch from FAISS before
                reranking. Values <= 0 yield an empty list. At most
                ``min(top_k, retrieve_k)`` chunks are returned.

        Returns:
            Chunk texts ordered by descending cross-encoder score (ties keep
            the dense-retrieval order).
        """
        # FAISS search requires k > 0. A negative top_k would otherwise slice
        # off the *end* of the ranking instead of returning nothing.
        if top_k <= 0 or retrieve_k <= 0:
            return []

        self.load_index_and_chunks()

        # Step 1 — dense retrieval. Embeddings are normalized; presumably the
        # index was built with the same setting — confirm against the builder.
        query_embedding = self.embed_model.encode([query], normalize_embeddings=True)
        _, indices = self.index.search(query_embedding, retrieve_k)
        # FAISS pads with -1 when fewer than retrieve_k vectors are available.
        candidates = [self.chunks[i] for i in indices[0] if i != -1]
        if not candidates:
            return []

        # Step 2 — rerank each (query, chunk) pair with the cross-encoder.
        pairs = [[query, c] for c in candidates]
        scores = self.rerank_model.predict(pairs)
        ranked = sorted(zip(scores, candidates), key=lambda x: x[0], reverse=True)
        results = [text for _, text in ranked[:top_k]]
        logger.info("Retrieved %d chunks after reranking (from %d candidates)", len(results), len(candidates))
        return results