Spaces:
Running
Running
| import json | |
| import re | |
| import unicodedata | |
| from pathlib import Path | |
| from functools import lru_cache | |
| from typing import Dict, List, Any | |
| import faiss | |
| from sentence_transformers import SentenceTransformer | |
| # ----------------------------- | |
| # Paths | |
| # ----------------------------- | |
# Embedding model identifier. The index and map file names below embed a
# filesystem-friendly form of its last path component (dashes -> underscores),
# so artifacts built for a different model are never picked up by accident.
MODEL_NAME = "sentence-transformers/use-cmlm-multilingual"
SAFE_MODEL_NAME = MODEL_NAME.rpartition("/")[2].replace("-", "_")

# Source dataset (JSON).
DATA_PATH = Path("data/dataset.json")

# Per-language FAISS indexes ("si" / "ta").
INDEX_SI_PATH = Path(f"data/index_si_{SAFE_MODEL_NAME}.faiss")
INDEX_TA_PATH = Path(f"data/index_ta_{SAFE_MODEL_NAME}.faiss")

# Per-language maps from index row to dataset row (optional files).
MAP_SI_PATH = Path(f"data/index_map_si_{SAFE_MODEL_NAME}.json")
MAP_TA_PATH = Path(f"data/index_map_ta_{SAFE_MODEL_NAME}.json")
| # ----------------------------- | |
| # Safe Unicode Normalization | |
| # ----------------------------- | |
def normalize(text: str) -> str:
    """Return a canonical form of *text* for exact-match keys and queries.

    Steps, in order:

    1. ``None`` becomes the empty string; any other value goes through
       ``str()``.
    2. Unicode NFC composition.
    3. Zero-width joiner (U+200D), zero-width non-joiner (U+200C) and the
       BOM / zero-width no-break space (U+FEFF) are removed.
    4. Curly and straight double quotes, apostrophes, backticks and acute
       accents are removed.
    5. Whitespace runs collapse to a single space and the ends are trimmed.
    6. Any trailing mix of ``! ? . , ; :`` and whitespace is removed.

    Args:
        text: The value to normalize.

    Returns:
        The normalized string; possibly empty.
    """
    # str(None) would yield the literal "None", which is never a real
    # question or alias, so treat a missing value as empty.
    if text is None:
        return ""
    text = unicodedata.normalize("NFC", str(text))
    text = text.replace("\u200d", "").replace("\u200c", "").replace("\ufeff", "")
    text = re.sub(r"[“”\"'`´]", "", text)
    text = re.sub(r"\s+", " ", text).strip()
    # Whitespace is part of the trailing class so that "word ?" and
    # "word ? !" reduce to "word" rather than keeping a dangling space.
    text = re.sub(r"[\s!?.,;:]+$", "", text)
    return text
| # ----------------------------- | |
| # Load Dataset | |
| # ----------------------------- | |
# Fail at import time with a clear message when the dataset file is missing.
if not DATA_PATH.exists():
    raise FileNotFoundError(f"Dataset not found at: {DATA_PATH}")

with DATA_PATH.open("r", encoding="utf-8") as dataset_file:
    DATA = json.load(dataset_file)

# The rest of the module iterates and indexes DATA, so it must be a
# non-empty list.
if not (isinstance(DATA, list) and len(DATA) > 0):
    raise ValueError("dataset.json is empty or not a list. Please rebuild your dataset.")
| # ----------------------------- | |
| # Helper to safely get aliases | |
| # ----------------------------- | |
def _get_aliases(item: Dict[str, Any], key: str) -> List[str]:
    """Return the normalized, non-empty aliases stored under ``item[key]``.

    Args:
        item: A dataset record.
        key: Name of the alias field, e.g. ``"aliases_si"``.

    Returns:
        A list of normalized alias strings. Empty when the field is missing
        or is not a list.
    """
    raw = item.get(key, [])
    if not isinstance(raw, list):
        return []

    aliases: List[str] = []
    for entry in raw:
        # A null entry would otherwise be stringified to the literal "None"
        # and registered as a bogus alias.
        if entry is None:
            continue
        cleaned = normalize(entry)  # normalize once, then test and keep
        if cleaned:
            aliases.append(cleaned)
    return aliases
| # ----------------------------- | |
| # Exact Match Tables | |
| # Includes primary questions + aliases | |
| # ----------------------------- | |
# Normalized question text -> dataset record, one table per language.
EXACT_SI: Dict[str, Dict[str, Any]] = {}
EXACT_TA: Dict[str, Dict[str, Any]] = {}

# (table, primary-question field, alias field) for each language.
_EXACT_SPECS = (
    (EXACT_SI, "question_si", "aliases_si"),
    (EXACT_TA, "question_ta", "aliases_ta"),
)

for record in DATA:
    for table, question_key, alias_key in _EXACT_SPECS:
        primary = normalize(record.get(question_key, ""))
        if primary:
            table[primary] = record
        # Aliases are written after the primary question; a later write to
        # the same key replaces the earlier one.
        for alias in _get_aliases(record, alias_key):
            table[alias] = record
| # ----------------------------- | |
| # Load FAISS Indexes | |
| # ----------------------------- | |
# Both language indexes are required; refuse to start if either is absent.
if not (INDEX_SI_PATH.exists() and INDEX_TA_PATH.exists()):
    missing_indexes_message = (
        f"FAISS indexes not found. Expected:\n- {INDEX_SI_PATH}\n- {INDEX_TA_PATH}\n"
        "Run build_index.py to generate them."
    )
    raise FileNotFoundError(missing_indexes_message)

# faiss.read_index is given a string path, not a Path object.
index_si = faiss.read_index(str(INDEX_SI_PATH))
index_ta = faiss.read_index(str(INDEX_TA_PATH))
| # ----------------------------- | |
| # Optional index maps | |
| # If missing, fall back to 1:1 mapping | |
| # ----------------------------- | |
def _read_index_map(map_path):
    """Load an index-row -> dataset-row map, or build the 1:1 fallback.

    When ``map_path`` does not exist, index row ``i`` is taken to refer to
    ``DATA[i]``.
    """
    if not map_path.exists():
        return list(range(len(DATA)))
    with open(map_path, "r", encoding="utf-8") as map_file:
        return json.load(map_file)


MAP_SI = _read_index_map(MAP_SI_PATH)
MAP_TA = _read_index_map(MAP_TA_PATH)

# Every vector in an index needs exactly one map entry; a mismatch means the
# index and the map (or dataset) were built from different inputs.
if len(MAP_SI) != index_si.ntotal:
    raise ValueError(
        f"index_si.ntotal={index_si.ntotal} does not match len(MAP_SI)={len(MAP_SI)}. "
        "Rebuild indexes using build_index.py."
    )

if len(MAP_TA) != index_ta.ntotal:
    raise ValueError(
        f"index_ta.ntotal={index_ta.ntotal} does not match len(MAP_TA)={len(MAP_TA)}. "
        "Rebuild indexes using build_index.py."
    )
| # ----------------------------- | |
| # Embedding Model | |
| # ----------------------------- | |
# Module-level embedding model, created once at import and used by
# _encode_query for every search.
embedder = SentenceTransformer(model_name_or_path=MODEL_NAME)
| # ----------------------------- | |
| # Semantic Search | |
| # ----------------------------- | |
def _encode_query(q: str):
    """Embed one query string with the module-level ``embedder``.

    The query is passed as a one-element batch with
    ``normalize_embeddings=True``; the encoder's return value is handed
    back unchanged (presumably a ``(1, dim)`` array, as ``search`` passes it
    straight to a FAISS index - confirm against the encoder docs).
    """
    single_item_batch = [q]
    return embedder.encode(single_item_batch, normalize_embeddings=True)
def search(query: str, lang: str = "si", k: int = 5) -> List[Dict[str, Any]]:
    """Run a semantic search for *query* against one language's FAISS index.

    Args:
        query: Free-text user query; it is normalized before embedding.
        lang: ``"si"`` or ``"ta"`` (case and surrounding whitespace are
            ignored). Any other value, including ``None`` or ``""``, falls
            back to ``"si"``.
        k: Maximum number of index hits to request. Values below 1 yield an
            empty result.

    Returns:
        A list of hit dicts in index order, each with the keys ``rank``
        (1-based position in the returned list), ``score`` (the index score
        as a float), ``lang``, ``id``, ``matched_question`` (the record's
        primary question in ``lang``) and ``item`` (the dataset record).
        Hits that map to a record already returned are skipped, so the list
        may hold fewer than ``k`` entries.
    """
    lang = (lang or "si").lower().strip()
    if lang not in {"si", "ta"}:
        lang = "si"

    q = normalize(query)
    if not q:
        return []

    if lang == "si":
        index, index_map, question_key = index_si, MAP_SI, "question_si"
    else:
        index, index_map, question_key = index_ta, MAP_TA, "question_ta"

    # The faiss Python wrapper requires k > 0, and asking for more neighbours
    # than the index holds only yields -1 padding. Clamp first and skip the
    # embedding work entirely when there is nothing to retrieve.
    k = min(k, index.ntotal)
    if k <= 0:
        return []

    scores, idxs = index.search(_encode_query(q), k)

    results: List[Dict[str, Any]] = []
    seen_record_ids = set()
    for score, idx in zip(scores[0], idxs[0]):
        idx = int(idx)
        # Covers the -1 "no neighbour" marker as well as any row that has
        # no entry in the map.
        if not 0 <= idx < len(index_map):
            continue

        mapped_idx = index_map[idx]
        if mapped_idx < 0 or mapped_idx >= len(DATA):
            continue

        item = DATA[int(mapped_idx)]
        record_id = item.get("id", f"row_{mapped_idx}")

        # De-duplicate the same record when several of its index rows
        # (e.g. aliases) are hit.
        if record_id in seen_record_ids:
            continue
        seen_record_ids.add(record_id)

        results.append({
            "rank": len(results) + 1,
            "score": float(score),
            "lang": lang,
            "id": record_id,
            "matched_question": item.get(question_key, ""),
            "item": item,
        })

    return results
def debug_search(query: str, lang: str = "si", k: int = 5) -> List[Dict[str, Any]]:
    """Return a compact summary of ``search`` hits for inspection.

    Each summary keeps ``rank``, ``id`` and ``matched_question`` from the
    hit, rounds ``score`` to four decimals, and adds the record's
    ``category`` (empty string when the record has none). The full record
    is omitted.
    """
    summaries: List[Dict[str, Any]] = []
    for hit in search(query, lang=lang, k=k):
        record = hit["item"]
        summary = {
            "rank": hit["rank"],
            "score": round(hit["score"], 4),
            "id": hit["id"],
            "category": record.get("category", ""),
            "matched_question": hit["matched_question"],
        }
        summaries.append(summary)
    return summaries