# coco-guide-api / retrieval.py
# NOTE: Hugging Face upload-page residue removed from the top of this file
# (uploader: Abhiru1, commit: f8dbb8b verified, "Upload retrieval.py").
import json
import re
import unicodedata
from pathlib import Path
from functools import lru_cache
from typing import Dict, List, Any
import faiss
from sentence_transformers import SentenceTransformer
# -----------------------------
# Paths
# -----------------------------
DATA_PATH = Path("data/dataset.json")
MODEL_NAME = "sentence-transformers/use-cmlm-multilingual"
# Model name with the org prefix dropped and dashes replaced by underscores,
# so it is safe to embed in index/map filenames below.
SAFE_MODEL_NAME = MODEL_NAME.split("/")[-1].replace("-", "_")
INDEX_SI_PATH = Path(f"data/index_si_{SAFE_MODEL_NAME}.faiss")
INDEX_TA_PATH = Path(f"data/index_ta_{SAFE_MODEL_NAME}.faiss")
# Optional row-mapping files: FAISS vector position -> row index in DATA.
MAP_SI_PATH = Path(f"data/index_map_si_{SAFE_MODEL_NAME}.json")
MAP_TA_PATH = Path(f"data/index_map_ta_{SAFE_MODEL_NAME}.json")
# -----------------------------
# Safe Unicode Normalization
# -----------------------------
# Characters dropped outright: zero-width joiner/non-joiner, BOM, and
# assorted quote marks. A single str.translate pass replaces the original
# chain of str.replace calls plus a per-call quote regex.
_DROP_TABLE = dict.fromkeys(map(ord, "\u200d\u200c\ufeff\u201c\u201d\"'`´"))
# Patterns compiled once at import time instead of on every call.
_WS_RE = re.compile(r"\s+")
_TRAILING_PUNCT_RE = re.compile(r"[!?.,;:]+$")


def normalize(text: str) -> str:
    """Normalize a question string for exact-match lookup and embedding.

    Steps: NFC Unicode normalization, removal of zero-width characters,
    BOMs and quote marks, whitespace collapsing/stripping, and stripping
    of trailing sentence punctuation (!?.,;:).

    Non-str input is coerced with ``str()`` first, so this never raises
    on ints/None coming from a malformed dataset row.
    """
    text = unicodedata.normalize("NFC", str(text))
    text = text.translate(_DROP_TABLE)
    text = _WS_RE.sub(" ", text).strip()
    return _TRAILING_PUNCT_RE.sub("", text)
# -----------------------------
# Load Dataset
# -----------------------------
# Fail fast at import time: a missing or malformed dataset should stop the
# service before any index is consulted.
if not DATA_PATH.exists():
    raise FileNotFoundError(f"Dataset not found at: {DATA_PATH}")
with open(DATA_PATH, "r", encoding="utf-8") as f:
    # DATA is expected to be a non-empty list of record dicts
    # (question_si / question_ta / aliases_* / id / ... keys used below).
    DATA = json.load(f)
if not isinstance(DATA, list) or len(DATA) == 0:
    raise ValueError("dataset.json is empty or not a list. Please rebuild your dataset.")
# -----------------------------
# Helper to safely get aliases
# -----------------------------
def _get_aliases(item: Dict[str, Any], key: str) -> List[str]:
    """Return the normalized, non-empty alias strings stored under ``key``.

    A missing key or a non-list value (malformed dataset row) yields [].
    Entries that normalize to the empty string are dropped.
    """
    val = item.get(key, [])
    if not isinstance(val, list):
        return []
    out: List[str] = []
    for raw in val:
        # Normalize once per entry; the original normalized each alias
        # twice (once in the filter, once to build the value).
        cleaned = normalize(raw)
        if cleaned:
            out.append(cleaned)
    return out
# -----------------------------
# Exact Match Tables
# Includes primary questions + aliases
# -----------------------------
# Maps normalized question/alias text -> full dataset record. Insertion
# order matters: for each record the primary question is added first, then
# its aliases; later records overwrite earlier ones on key collisions, so
# duplicate normalized questions across records resolve to the last record.
EXACT_SI: Dict[str, Dict[str, Any]] = {}
EXACT_TA: Dict[str, Dict[str, Any]] = {}
for d in DATA:
    q_si = normalize(d.get("question_si", ""))
    q_ta = normalize(d.get("question_ta", ""))
    if q_si:
        EXACT_SI[q_si] = d
    if q_ta:
        EXACT_TA[q_ta] = d
    # Aliases point at the same record dict as the primary question.
    for a in _get_aliases(d, "aliases_si"):
        EXACT_SI[a] = d
    for a in _get_aliases(d, "aliases_ta"):
        EXACT_TA[a] = d
# -----------------------------
# Load FAISS Indexes
# -----------------------------
# Both language indexes must exist; they are produced offline by
# build_index.py from the same dataset.json.
if not INDEX_SI_PATH.exists() or not INDEX_TA_PATH.exists():
    raise FileNotFoundError(
        f"FAISS indexes not found. Expected:\n- {INDEX_SI_PATH}\n- {INDEX_TA_PATH}\n"
        "Run build_index.py to generate them."
    )
# faiss.read_index requires a plain str path, not a Path object.
index_si = faiss.read_index(str(INDEX_SI_PATH))
index_ta = faiss.read_index(str(INDEX_TA_PATH))
# -----------------------------
# Optional index maps
# If missing, fall back to 1:1 mapping
# -----------------------------
# MAP_* translate a FAISS vector position into a row index of DATA.
# When the map file is absent, vector i is assumed to correspond to DATA[i].
if MAP_SI_PATH.exists():
    with open(MAP_SI_PATH, "r", encoding="utf-8") as f:
        MAP_SI = json.load(f)
else:
    MAP_SI = list(range(len(DATA)))
if MAP_TA_PATH.exists():
    with open(MAP_TA_PATH, "r", encoding="utf-8") as f:
        MAP_TA = json.load(f)
else:
    MAP_TA = list(range(len(DATA)))
# Sanity check: every vector in the index must have a map entry, otherwise
# the index and dataset are out of sync and results would be garbage.
if index_si.ntotal != len(MAP_SI):
    raise ValueError(
        f"index_si.ntotal={index_si.ntotal} does not match len(MAP_SI)={len(MAP_SI)}. "
        "Rebuild indexes using build_index.py."
    )
if index_ta.ntotal != len(MAP_TA):
    raise ValueError(
        f"index_ta.ntotal={index_ta.ntotal} does not match len(MAP_TA)={len(MAP_TA)}. "
        "Rebuild indexes using build_index.py."
    )
# -----------------------------
# Embedding Model
# -----------------------------
# Loaded once at import time; must be the same model used by build_index.py
# or query vectors will not live in the same space as the index vectors.
embedder = SentenceTransformer(MODEL_NAME)
# -----------------------------
# Semantic Search
# -----------------------------
@lru_cache(maxsize=256)
def _encode_query(q: str):
    # Encode a single query; lru_cache skips re-encoding repeated queries.
    # normalize_embeddings=True yields unit-length vectors, so inner-product
    # scores behave like cosine similarity (assumes the index was built the
    # same way — confirm in build_index.py). Returns a 2-D (1, dim) array,
    # which is the shape faiss .search() expects.
    return embedder.encode([q], normalize_embeddings=True)
def search(query: str, lang: str = "si", k: int = 5) -> List[Dict[str, Any]]:
    """Semantic top-k search over the FAISS index for the given language.

    Args:
        query: User question; normalized before encoding. Empty after
            normalization -> returns [].
        lang: "si" or "ta" (case/whitespace-insensitive); anything else,
            including None, falls back to "si".
        k: Number of nearest neighbours requested from FAISS. After
            de-duplication the result list may contain fewer than k hits.

    Returns:
        Result dicts ordered by similarity, each with rank (1-based,
        post-dedup), score, lang, record id, matched_question (the primary
        question for that language) and the full dataset item.
    """
    lang = (lang or "si").lower().strip()
    if lang not in {"si", "ta"}:
        lang = "si"

    q = normalize(query)
    if not q:
        return []
    q_emb = _encode_query(q)

    if lang == "si":
        index, index_map = index_si, MAP_SI
    else:
        index, index_map = index_ta, MAP_TA
    scores, idxs = index.search(q_emb, k)

    results: List[Dict[str, Any]] = []
    seen_record_ids = set()
    for score, idx in zip(scores[0], idxs[0]):
        # FAISS pads missing neighbours with -1; the upper bound guards
        # against an index that is larger than its map file.
        # (The original also had a separate `idx == -1` check, which is
        # subsumed by `idx < 0`, and an unused enumerate() rank counter.)
        if idx < 0 or idx >= len(index_map):
            continue
        mapped_idx = int(index_map[int(idx)])
        if mapped_idx < 0 or mapped_idx >= len(DATA):
            continue
        item = DATA[mapped_idx]
        record_id = item.get("id", f"row_{mapped_idx}")
        # De-duplicate: several vectors (aliases) may map to one record.
        if record_id in seen_record_ids:
            continue
        seen_record_ids.add(record_id)
        matched_question = item.get(
            "question_si" if lang == "si" else "question_ta", ""
        )
        results.append({
            "rank": len(results) + 1,
            "score": float(score),
            "lang": lang,
            "id": record_id,
            "matched_question": matched_question,
            "item": item,
        })
    return results
def debug_search(query: str, lang: str = "si", k: int = 5) -> List[Dict[str, Any]]:
    """Run search() and return a compact, log-friendly view of each hit.

    Same arguments as search(); scores are rounded to 4 decimal places and
    only rank/score/id/category/matched_question are kept per hit.
    """
    summary: List[Dict[str, Any]] = []
    for hit in search(query, lang=lang, k=k):
        summary.append({
            "rank": hit["rank"],
            "score": round(hit["score"], 4),
            "id": hit["id"],
            "category": hit["item"].get("category", ""),
            "matched_question": hit["matched_question"],
        })
    return summary