Anshul Prasad commited on
Commit
6358641
·
1 Parent(s): 8b5035b

few refinements

Browse files
Files changed (1) hide show
  1. api/retrieve_context.py +18 -16
api/retrieve_context.py CHANGED
@@ -1,28 +1,30 @@
1
- import numpy as np
2
- import faiss, logging
 
3
  from sentence_transformers import SentenceTransformer
 
4
  from config import TRANSCRIPT_INDEX
5
- logging.basicConfig(format="%(asctime)s %(levelname)s:%(message)s", level=logging.INFO)
6
 
7
- embedding_model = SentenceTransformer("all-MiniLM-L6-v2")
8
- logging.info("Loaded embedding model for query.")
9
 
 
10
  index = faiss.read_index(TRANSCRIPT_INDEX)
11
- logging.info(f"Loaded FAISS index from {TRANSCRIPT_INDEX}.")
12
-
13
- def retrieve_transcripts(query, file_path, transcripts, top_k=3):
14
- logging.info("Starting retrieval process...")
15
 
16
- query_embedding = embedding_model.encode([query], convert_to_tensor=False)
17
- logging.info("Encoded query to embedding.")
18
 
19
- distances, indices = index.search(np.array(query_embedding), top_k)
20
- logging.info(f"Retrieved top {top_k} results from index.")
 
 
 
 
 
 
21
 
22
  results = []
23
  for idx in indices[0]:
24
- results.append(transcripts[idx])
25
- logging.info(f"Retrieved transcript from: {file_path[idx]}")
 
26
 
27
- logging.info("Retrieval process completed.")
28
  return results
 
1
+ import faiss
2
+ import logging
3
+ from pathlib import Path
4
  from sentence_transformers import SentenceTransformer
5
+
6
  from config import TRANSCRIPT_INDEX
 
7
 
8
+ logger = logging.getLogger(__name__)
 
9
 
10
+ model = SentenceTransformer("all-MiniLM-L6-v2")
11
  index = faiss.read_index(TRANSCRIPT_INDEX)
 
 
 
 
12
 
 
 
13
 
14
+ def retrieve_transcripts(
15
+ query: str,
16
+ file_paths: list[Path],
17
+ transcripts: list[str],
18
+ top_k: int = 3,
19
+ ) -> list[str]:
20
+ query_embedding = model.encode([query])
21
+ distances, indices = index.search(query_embedding, top_k)
22
 
23
  results = []
24
  for idx in indices[0]:
25
+ if idx != -1:
26
+ results.append(transcripts[idx])
27
+ logger.info(f"Retrieved transcript from: {file_paths[idx]}")
28
 
29
+ logger.info("Retrieval process completed.")
30
  return results