import json
import re
import unicodedata
from functools import lru_cache
from pathlib import Path
from typing import Any, Dict, List

import faiss
from sentence_transformers import SentenceTransformer

# -----------------------------
# Paths
# -----------------------------
DATA_PATH = Path("data/dataset.json")
MODEL_NAME = "sentence-transformers/use-cmlm-multilingual"
# Filesystem-safe model name used to namespace index files per model.
SAFE_MODEL_NAME = MODEL_NAME.split("/")[-1].replace("-", "_")
INDEX_SI_PATH = Path(f"data/index_si_{SAFE_MODEL_NAME}.faiss")
INDEX_TA_PATH = Path(f"data/index_ta_{SAFE_MODEL_NAME}.faiss")
MAP_SI_PATH = Path(f"data/index_map_si_{SAFE_MODEL_NAME}.json")
MAP_TA_PATH = Path(f"data/index_map_ta_{SAFE_MODEL_NAME}.json")

# -----------------------------
# Safe Unicode Normalization
# -----------------------------
def normalize(text: str) -> str:
    """Canonicalize text for exact-match lookups and embedding."""
    text = unicodedata.normalize("NFC", str(text))
    # Strip zero-width joiner/non-joiner and BOM, which vary across keyboards.
    text = text.replace("\u200d", "").replace("\u200c", "").replace("\ufeff", "")
    # Drop quote characters, collapse whitespace, trim trailing punctuation.
    text = re.sub(r"[“”\"'`´]", "", text)
    text = re.sub(r"\s+", " ", text).strip()
    text = re.sub(r"[!?.,;:]+$", "", text)
    return text

# -----------------------------
# Load Dataset
# -----------------------------
if not DATA_PATH.exists():
    raise FileNotFoundError(f"Dataset not found at: {DATA_PATH}")

with open(DATA_PATH, "r", encoding="utf-8") as f:
    DATA = json.load(f)

if not isinstance(DATA, list) or len(DATA) == 0:
    raise ValueError("dataset.json is empty or not a list. Please rebuild your dataset.")

# -----------------------------
# Helper to safely get aliases
# -----------------------------
def _get_aliases(item: Dict[str, Any], key: str) -> List[str]:
    val = item.get(key, [])
    if isinstance(val, list):
        # Normalize each alias once and drop entries that normalize to "".
        return [n for x in val if (n := normalize(x))]
    return []

# -----------------------------
# Exact Match Tables
# Includes primary questions + aliases
# -----------------------------
EXACT_SI: Dict[str, Dict[str, Any]] = {}
EXACT_TA: Dict[str, Dict[str, Any]] = {}

for d in DATA:
    q_si = normalize(d.get("question_si", ""))
    q_ta = normalize(d.get("question_ta", ""))
    if q_si:
        EXACT_SI[q_si] = d
    if q_ta:
        EXACT_TA[q_ta] = d
    for a in _get_aliases(d, "aliases_si"):
        EXACT_SI[a] = d
    for a in _get_aliases(d, "aliases_ta"):
        EXACT_TA[a] = d

# -----------------------------
# Load FAISS Indexes
# -----------------------------
if not INDEX_SI_PATH.exists() or not INDEX_TA_PATH.exists():
    raise FileNotFoundError(
        f"FAISS indexes not found. Expected:\n- {INDEX_SI_PATH}\n- {INDEX_TA_PATH}\n"
        "Run build_index.py to generate them."
    )

index_si = faiss.read_index(str(INDEX_SI_PATH))
index_ta = faiss.read_index(str(INDEX_TA_PATH))

# -----------------------------
# Optional index maps
# If missing, fall back to 1:1 mapping
# -----------------------------
if MAP_SI_PATH.exists():
    with open(MAP_SI_PATH, "r", encoding="utf-8") as f:
        MAP_SI = json.load(f)
else:
    MAP_SI = list(range(len(DATA)))

if MAP_TA_PATH.exists():
    with open(MAP_TA_PATH, "r", encoding="utf-8") as f:
        MAP_TA = json.load(f)
else:
    MAP_TA = list(range(len(DATA)))

if index_si.ntotal != len(MAP_SI):
    raise ValueError(
        f"index_si.ntotal={index_si.ntotal} does not match len(MAP_SI)={len(MAP_SI)}. "
        "Rebuild indexes using build_index.py."
    )

if index_ta.ntotal != len(MAP_TA):
    raise ValueError(
        f"index_ta.ntotal={index_ta.ntotal} does not match len(MAP_TA)={len(MAP_TA)}. "
        "Rebuild indexes using build_index.py."
    )

# -----------------------------
# Embedding Model
# -----------------------------
embedder = SentenceTransformer(MODEL_NAME)

# -----------------------------
# Semantic Search
# -----------------------------
@lru_cache(maxsize=256)
def _encode_query(q: str):
    # Cache embeddings for repeated queries. With normalize_embeddings=True,
    # inner-product scores coincide with cosine similarity (assuming
    # build_index.py built inner-product indexes).
    return embedder.encode([q], normalize_embeddings=True)


def search(query: str, lang: str = "si", k: int = 5) -> List[Dict[str, Any]]:
    lang = (lang or "si").lower().strip()
    if lang not in {"si", "ta"}:
        lang = "si"

    q = normalize(query)
    if not q:
        return []

    q_emb = _encode_query(q)

    if lang == "si":
        scores, idxs = index_si.search(q_emb, k)
        index_map = MAP_SI
    else:
        scores, idxs = index_ta.search(q_emb, k)
        index_map = MAP_TA

    results: List[Dict[str, Any]] = []
    seen_record_ids = set()
    for score, idx in zip(scores[0], idxs[0]):
        # FAISS pads with -1 when fewer than k vectors are indexed.
        if idx < 0 or idx >= len(index_map):
            continue
        mapped_idx = int(index_map[int(idx)])
        if mapped_idx < 0 or mapped_idx >= len(DATA):
            continue
        item = DATA[mapped_idx]
        record_id = item.get("id", f"row_{mapped_idx}")
        # De-duplicate the same advisory record if multiple aliases hit.
        if record_id in seen_record_ids:
            continue
        seen_record_ids.add(record_id)
        matched_question = (
            item.get("question_si", "") if lang == "si" else item.get("question_ta", "")
        )
        results.append({
            "rank": len(results) + 1,  # rank after de-duplication
            "score": float(score),
            "lang": lang,
            "id": record_id,
            "matched_question": matched_question,
            "item": item,
        })
    return results


def debug_search(query: str, lang: str = "si", k: int = 5) -> List[Dict[str, Any]]:
    """Compact view of search results for logging and troubleshooting."""
    hits = search(query, lang=lang, k=k)
    return [
        {
            "rank": h["rank"],
            "score": round(h["score"], 4),
            "id": h["id"],
            "category": h["item"].get("category", ""),
            "matched_question": h["matched_question"],
        }
        for h in hits
    ]
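
# -----------------------------
# Example usage (sketch)
# -----------------------------
# A minimal sketch of how this module might be consumed: try the exact-match
# tables first (cheap and deterministic), then fall back to semantic search.
# NOTE: this exact-first flow and the `answer` helper are illustrative
# assumptions; only EXACT_SI/EXACT_TA, normalize, search, and debug_search
# above come from this module.
def answer(query: str, lang: str = "si") -> Dict[str, Any]:
    exact_table = EXACT_SI if lang == "si" else EXACT_TA
    hit = exact_table.get(normalize(query))
    if hit is not None:
        return {"match_type": "exact", "id": hit.get("id"), "item": hit}
    hits = search(query, lang=lang, k=5)
    if hits:
        return {"match_type": "semantic", **hits[0]}
    return {"match_type": "none"}


if __name__ == "__main__":
    # Smoke test with a placeholder query string; substitute a real Sinhala
    # or Tamil question from dataset.json to get meaningful scores.
    print(json.dumps(debug_search("example query", lang="si", k=3),
                     ensure_ascii=False, indent=2))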