Anshul Prasad committed on
Commit
5190b74
·
1 Parent(s): acb9fe6

Chunking logic integration.

Browse files
Files changed (1) hide show
  1. api/retrieve_context.py +52 -13
api/retrieve_context.py CHANGED
@@ -1,25 +1,64 @@
1
  import faiss
 
2
  import logging
3
  from pathlib import Path
4
- from sentence_transformers import SentenceTransformer
5
 
6
  from config import TRANSCRIPT_INDEX
7
 
8
  logger = logging.getLogger(__name__)
9
 
10
- model = SentenceTransformer("all-MiniLM-L6-v2")
11
- index = faiss.read_index(TRANSCRIPT_INDEX)
 
12
 
 
 
 
13
 
14
def retrieve_transcripts(query: str, file_paths: list[Path], transcripts: list[str], top_k: int = 3) -> list[str]:
    """Return the texts of the top_k transcripts most similar to *query*.

    Embeds the query with the module-level sentence-transformer model and
    searches the module-level FAISS index; ``file_paths`` is used only for
    logging which file each hit came from.
    """
    embedded = model.encode([query])
    distances, indices = index.search(embedded, top_k)

    results: list[str] = []
    for idx in indices[0]:
        # FAISS pads the result row with -1 when fewer than top_k vectors exist.
        if idx == -1:
            continue
        results.append(transcripts[idx])
        logger.info(f"Retrieved transcript from: {file_paths[idx]}")

    logger.info("Retrieval process completed.")
    return results
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
  import faiss
2
+ import pickle
3
  import logging
4
  from pathlib import Path
5
+ from sentence_transformers import SentenceTransformer, CrossEncoder
6
 
7
  from config import TRANSCRIPT_INDEX
8
 
9
  logger = logging.getLogger(__name__)
10
 
11
+ EMBED_MODEL = "BAAI/bge-small-en-v1.5"
12
+ RERANK_MODEL = "cross-encoder/ms-marco-MiniLM-L-6-v2"
13
+ CHUNKS_PKL = "data/chunks.pkl"
14
 
15
+ # Load models once at startup (lightweight, always safe)
16
+ _embed_model = SentenceTransformer(EMBED_MODEL)
17
+ _rerank_model = CrossEncoder(RERANK_MODEL)
18
 
19
+ # Load index and chunks lazily on first query
20
+ _index = None
21
+ _chunks: list[str] = []
22
 
 
 
 
 
 
23
 
24
def _load_index_and_chunks() -> None:
    """Lazily load the FAISS index and chunk texts on first use.

    Idempotent: returns immediately once ``_index`` is set. Loads into
    locals first and commits to module state only after BOTH loads
    succeed — otherwise a failure in ``pickle.load`` would leave
    ``_index`` set and ``_chunks`` empty forever, silently returning no
    results on every subsequent query instead of retrying.
    """
    global _index, _chunks
    if _index is not None:
        return
    index = faiss.read_index(TRANSCRIPT_INDEX)
    with open(CHUNKS_PKL, "rb") as f:
        # NOTE(review): pickle deserialization executes arbitrary code if the
        # file is tampered with — acceptable only because CHUNKS_PKL is
        # produced by our own indexing pipeline; confirm it is never
        # user-supplied.
        chunks = pickle.load(f)
    # Commit both pieces of state together, index last (it is the guard flag).
    _chunks = chunks
    _index = index
    logger.info("Loaded FAISS index and %d chunks", len(_chunks))
32
+
33
+
34
def retrieve_transcripts(
    query: str,
    file_paths: list[Path],  # kept for API compatibility, unused now
    transcripts: list[str],  # kept for API compatibility, unused now
    top_k: int = 3,
    retrieve_k: int = 15,
) -> list[str]:
    """Retrieve the chunk texts most relevant to *query*.

    Two-stage pipeline:
      1. Dense retrieval — embed the query and pull the top ``retrieve_k``
         candidate chunks from the FAISS index.
      2. Rerank the candidates with the cross-encoder and return the
         ``top_k`` highest-scoring chunk texts.

    Args:
        query: Free-text search query.
        file_paths: Unused; retained so existing callers keep working.
        transcripts: Unused; retained so existing callers keep working.
        top_k: Number of chunks to return after reranking.
        retrieve_k: Number of candidates fetched from FAISS before reranking.

    Returns:
        Up to ``top_k`` chunk strings, best first; empty list when the
        index yields no candidates.
    """
    _load_index_and_chunks()

    # Bug fix: never fetch fewer candidates than the caller asked to
    # receive — previously top_k > retrieve_k silently truncated results.
    retrieve_k = max(retrieve_k, top_k)

    # Step 1 — dense retrieval. normalize_embeddings=True makes an
    # inner-product index behave like cosine similarity (assumes the
    # index was built the same way — TODO confirm against the indexer).
    query_embedding = _embed_model.encode([query], normalize_embeddings=True)
    _, indices = _index.search(query_embedding, retrieve_k)

    # FAISS pads the result row with -1 when the index holds fewer
    # than retrieve_k vectors.
    candidates = [_chunks[i] for i in indices[0] if i != -1]
    if not candidates:
        return []

    # Step 2 — rerank each (query, chunk) pair with the cross-encoder.
    pairs = [[query, c] for c in candidates]
    scores = _rerank_model.predict(pairs)
    ranked = sorted(zip(scores, candidates), key=lambda x: x[0], reverse=True)

    results = [text for _, text in ranked[:top_k]]
    logger.info("Retrieved %d chunks after reranking (from %d candidates)", len(results), len(candidates))
    return results