Tuminha committed on
Commit
e0ee929
·
verified ·
1 Parent(s): 348a971

Upload src/retrieve.py with huggingface_hub

Browse files
Files changed (1) hide show
  1. src/retrieve.py +67 -0
src/retrieve.py ADDED
@@ -0,0 +1,67 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Top-k semantic retrieval against FAISS index.
3
+ """
from typing import Callable, Dict, List, Optional

import faiss
import numpy as np
9
def retrieve(query: str, index, embed_fn: Callable, chunks_lookup: dict = None, k: int = 5, metadata_df=None) -> List[Dict]: ...
def retrieve(query: str, index, embed_fn: Callable, metadata_df, chunks_lookup: Optional[dict] = None, k: int = 5) -> List[Dict]:
    """
    Return top-k semantic search results with text and metadata.

    Args:
        query: Query string.
        index: FAISS index (any object exposing ``search(vectors, k)``).
        embed_fn: Function that takes a string and returns a normalized
            embedding (numpy array or array-like).
        metadata_df: DataFrame with metadata columns chunk_id, book,
            para_idx_start, para_idx_end, char_count (optionally text).
        chunks_lookup: Optional dict mapping chunk_id to a chunk dict
            with a 'text' field.
        k: Number of results to request.

    Returns:
        List of dicts: {score, text, chunk_id, meta:{...}}. At most k
        entries; invalid/padded FAISS indices are skipped, so fewer than
        k results may be returned (e.g. when the index holds < k vectors).
    """
    if k <= 0:
        return []

    # Embed the query and coerce to the (1, dim) float32 layout FAISS
    # expects; asarray also accepts list-like returns from embed_fn.
    query_embedding = np.asarray(embed_fn(query), dtype=np.float32)
    if query_embedding.ndim == 1:
        query_embedding = query_embedding.reshape(1, -1)

    # Search FAISS index
    scores, indices = index.search(query_embedding, k)

    # Map indices to metadata and build result dicts
    results = []
    for score, idx in zip(scores[0], indices[0]):
        # FAISS pads with -1 when it cannot return k neighbors.
        if idx < 0 or idx >= len(metadata_df):
            continue  # Skip invalid indices

        row = metadata_df.iloc[idx]
        chunk_id = row['chunk_id']

        # Prefer the external chunks_lookup; fall back to an inline
        # 'text' column, then to a placeholder so callers always get
        # a string.
        if chunks_lookup and chunk_id in chunks_lookup:
            text = chunks_lookup[chunk_id].get('text', '')
        elif 'text' in row:
            text = row['text']
        else:
            text = f"[Chunk {chunk_id} - text not available]"

        results.append({
            'score': float(score),
            'text': text,
            'chunk_id': chunk_id,
            'meta': {
                'book': row['book'],
                'para_idx_start': int(row['para_idx_start']),
                'para_idx_end': int(row['para_idx_end']),
                'char_count': int(row['char_count'])
            }
        })

    return results