Abhiru1 committed on
Commit
f8dbb8b
·
verified ·
1 Parent(s): 733f19b

Upload retrieval.py

Browse files
Files changed (1) hide show
  1. retrieval.py +204 -0
retrieval.py ADDED
@@ -0,0 +1,204 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import json
2
+ import re
3
+ import unicodedata
4
+ from pathlib import Path
5
+ from functools import lru_cache
6
+ from typing import Dict, List, Any
7
+
8
+ import faiss
9
+ from sentence_transformers import SentenceTransformer
10
+
11
+
12
# -----------------------------
# Paths
# -----------------------------
# Source Q&A dataset: expected to be a JSON list of records carrying
# question_si / question_ta fields (see loading/validation below).
DATA_PATH = Path("data/dataset.json")

# Multilingual sentence-embedding model used for both indexing and querying.
MODEL_NAME = "sentence-transformers/use-cmlm-multilingual"

# Filesystem-safe model tag (e.g. "use_cmlm_multilingual") baked into the
# index filenames, so indexes built with a different model are never mixed up.
SAFE_MODEL_NAME = MODEL_NAME.split("/")[-1].replace("-", "_")
INDEX_SI_PATH = Path(f"data/index_si_{SAFE_MODEL_NAME}.faiss")   # Sinhala FAISS index
INDEX_TA_PATH = Path(f"data/index_ta_{SAFE_MODEL_NAME}.faiss")   # Tamil FAISS index
MAP_SI_PATH = Path(f"data/index_map_si_{SAFE_MODEL_NAME}.json")  # Sinhala row map (optional)
MAP_TA_PATH = Path(f"data/index_map_ta_{SAFE_MODEL_NAME}.json")  # Tamil row map (optional)
24
+
25
+
26
+
27
+ # -----------------------------
28
+ # Safe Unicode Normalization
29
+ # -----------------------------
30
def normalize(text: str) -> str:
    """Canonicalize a question string for exact-match lookup and embedding.

    Steps: NFC Unicode normalization; removal of zero-width joiner/non-joiner
    and BOM characters; removal of quote characters; collapsing whitespace
    runs to single spaces; trimming trailing sentence punctuation.
    """
    cleaned = unicodedata.normalize("NFC", str(text))
    for invisible in ("\u200d", "\u200c", "\ufeff"):
        cleaned = cleaned.replace(invisible, "")
    cleaned = re.sub(r"[“”\"'`´]", "", cleaned)
    cleaned = re.sub(r"\s+", " ", cleaned).strip()
    return re.sub(r"[!?.,;:]+$", "", cleaned)
37
+
38
+
39
# -----------------------------
# Load Dataset
# -----------------------------
if not DATA_PATH.exists():
    raise FileNotFoundError(f"Dataset not found at: {DATA_PATH}")

with DATA_PATH.open("r", encoding="utf-8") as handle:
    DATA = json.load(handle)

# The rest of the module assumes a non-empty list of record dicts.
if not isinstance(DATA, list) or not DATA:
    raise ValueError("dataset.json is empty or not a list. Please rebuild your dataset.")
50
+
51
+
52
+ # -----------------------------
53
+ # Helper to safely get aliases
54
+ # -----------------------------
55
def _get_aliases(item: Dict[str, Any], key: str) -> List[str]:
    """Return the normalized, non-empty alias strings stored under *key*.

    Args:
        item: One dataset record.
        key: Field name holding the aliases (e.g. "aliases_si").

    Returns:
        Normalized aliases with empty results dropped; an empty list when
        the field is missing or is not a list.
    """
    val = item.get(key, [])
    if not isinstance(val, list):
        return []
    aliases: List[str] = []
    for raw in val:
        # Normalize exactly once per element; the original normalized twice
        # (once in the comprehension filter, once for the kept value).
        cleaned = normalize(raw)
        if cleaned:
            aliases.append(cleaned)
    return aliases
60
+
61
+
62
# -----------------------------
# Exact Match Tables
# Includes primary questions + aliases
# -----------------------------
EXACT_SI: Dict[str, Dict[str, Any]] = {}
EXACT_TA: Dict[str, Dict[str, Any]] = {}

# Register every record under its normalized primary question and every
# normalized alias. Later records overwrite earlier duplicate keys.
for record in DATA:
    for question_key, alias_key, table in (
        ("question_si", "aliases_si", EXACT_SI),
        ("question_ta", "aliases_ta", EXACT_TA),
    ):
        primary = normalize(record.get(question_key, ""))
        if primary:
            table[primary] = record
        for alias in _get_aliases(record, alias_key):
            table[alias] = record
82
+
83
+
84
# -----------------------------
# Load FAISS Indexes
# -----------------------------
both_present = INDEX_SI_PATH.exists() and INDEX_TA_PATH.exists()
if not both_present:
    raise FileNotFoundError(
        f"FAISS indexes not found. Expected:\n- {INDEX_SI_PATH}\n- {INDEX_TA_PATH}\n"
        "Run build_index.py to generate them."
    )

# faiss.read_index wants a plain string path, not a pathlib.Path.
index_si = faiss.read_index(str(INDEX_SI_PATH))
index_ta = faiss.read_index(str(INDEX_TA_PATH))
95
+
96
+
97
+ # -----------------------------
98
+ # Optional index maps
99
+ # If missing, fall back to 1:1 mapping
100
+ # -----------------------------
101
+ if MAP_SI_PATH.exists():
102
+ with open(MAP_SI_PATH, "r", encoding="utf-8") as f:
103
+ MAP_SI = json.load(f)
104
+ else:
105
+ MAP_SI = list(range(len(DATA)))
106
+
107
+ if MAP_TA_PATH.exists():
108
+ with open(MAP_TA_PATH, "r", encoding="utf-8") as f:
109
+ MAP_TA = json.load(f)
110
+ else:
111
+ MAP_TA = list(range(len(DATA)))
112
+
113
+ if index_si.ntotal != len(MAP_SI):
114
+ raise ValueError(
115
+ f"index_si.ntotal={index_si.ntotal} does not match len(MAP_SI)={len(MAP_SI)}. "
116
+ "Rebuild indexes using build_index.py."
117
+ )
118
+
119
+ if index_ta.ntotal != len(MAP_TA):
120
+ raise ValueError(
121
+ f"index_ta.ntotal={index_ta.ntotal} does not match len(MAP_TA)={len(MAP_TA)}. "
122
+ "Rebuild indexes using build_index.py."
123
+ )
124
+
125
+
126
# -----------------------------
# Embedding Model
# -----------------------------
# Shared sentence-embedding model, loaded once at module import.
embedder = SentenceTransformer(MODEL_NAME)
130
+
131
+
132
+ # -----------------------------
133
+ # Semantic Search
134
+ # -----------------------------
135
@lru_cache(maxsize=256)
def _encode_query(q: str):
    """Embed a single normalized query, memoizing up to 256 recent results.

    NOTE(review): the cached object is whatever ``embedder.encode`` returns;
    callers must not mutate it, or later cache hits would see the mutation.
    """
    vectors = embedder.encode([q], normalize_embeddings=True)
    return vectors
138
+
139
+
140
def search(query: str, lang: str = "si", k: int = 5) -> List[Dict[str, Any]]:
    """Semantic top-k search over the per-language FAISS index.

    Args:
        query: Raw user question; normalized before embedding.
        lang: "si" or "ta" (case-insensitive); anything else falls back
            to "si".
        k: Number of nearest neighbours requested from FAISS. After
            de-duplication the result list may contain fewer than k hits.

    Returns:
        Ranked result dicts with keys: rank, score, lang, id,
        matched_question, item. Empty list for a blank query.
    """
    lang = (lang or "si").lower().strip()
    if lang not in {"si", "ta"}:
        lang = "si"

    q = normalize(query)
    if not q:
        return []

    q_emb = _encode_query(q)

    if lang == "si":
        index, index_map = index_si, MAP_SI
    else:
        index, index_map = index_ta, MAP_TA
    scores, idxs = index.search(q_emb, k)

    results: List[Dict[str, Any]] = []
    seen_record_ids = set()

    for score, idx in zip(scores[0], idxs[0]):
        idx = int(idx)
        # FAISS pads with -1 when fewer than k vectors exist; the single
        # range check below also covers that (the original tested -1 and
        # idx < 0 separately). Guard against a map shorter than the index.
        if idx < 0 or idx >= len(index_map):
            continue

        mapped_idx = int(index_map[idx])
        if mapped_idx < 0 or mapped_idx >= len(DATA):
            continue

        item = DATA[mapped_idx]
        record_id = item.get("id", f"row_{mapped_idx}")

        # De-duplicate the same advisory record if multiple aliases hit.
        if record_id in seen_record_ids:
            continue
        seen_record_ids.add(record_id)

        matched_question = item.get("question_si", "") if lang == "si" else item.get("question_ta", "")

        results.append({
            "rank": len(results) + 1,  # rank after de-duplication
            "score": float(score),
            "lang": lang,
            "id": record_id,
            "matched_question": matched_question,
            "item": item,
        })

    return results
191
+
192
+
193
def debug_search(query: str, lang: str = "si", k: int = 5) -> List[Dict[str, Any]]:
    """Run search() and project each hit to a compact, log-friendly dict."""
    summaries = []
    for hit in search(query, lang=lang, k=k):
        summaries.append({
            "rank": hit["rank"],
            "score": round(hit["score"], 4),
            "id": hit["id"],
            "category": hit["item"].get("category", ""),
            "matched_question": hit["matched_question"],
        })
    return summaries