from __future__ import annotations import os import re import time from functools import lru_cache from typing import List, Dict, Any, Tuple import numpy as np import pandas as pd import faiss from flask import Flask, request, jsonify, Response from sentence_transformers import SentenceTransformer # ========================= # Config (HF Space defaults) # ========================= INDEX_PATH = os.getenv("HADITH_INDEX_PATH", "hadith_semantic.faiss") META_PATH = os.getenv("HADITH_META_PATH", "hadith_meta.parquet") # Small/fast multilingual model (good on free CPU) MODEL_NAME = os.getenv( "HADITH_MODEL_NAME", "sentence-transformers/paraphrase-multilingual-MiniLM-L12-v2" ) DEFAULT_TOP_K = 10 MAX_TOP_K = 50 DEFAULT_RERANK_K = 35 MAX_RERANK_K = 120 MIN_RERANK_K = 10 DEFAULT_HL_TOPN = 6 MAX_HL_TOPN = 25 DEFAULT_SEG_MAXLEN = 220 MAX_SEG_MAXLEN = 420 MIN_SEG_MAXLEN = 120 # Rerank knobs (keep small for HF free CPU) RERANK_MAX_SEGS_PER_DOC = int(os.getenv("RERANK_MAX_SEGS_PER_DOC", "8")) RERANK_SEG_MAXLEN = int(os.getenv("RERANK_SEG_MAXLEN", "240")) RERANK_WEIGHT = float(os.getenv("RERANK_WEIGHT", "0.65")) RERANK_ENABLE = os.getenv("RERANK_ENABLE", "1").strip() != "0" # CORS CORS_ALLOW_ORIGIN = os.getenv("CORS_ALLOW_ORIGIN", "*") # ========================= # Arabic normalization # ========================= _AR_DIACRITICS = re.compile(r""" [\u0610-\u061A] | [\u064B-\u065F] | [\u0670] | [\u06D6-\u06ED] """, re.VERBOSE) _AR_PUNCT = re.compile(r"[^\w\u0600-\u06FF]+", re.UNICODE) def normalize_ar(text: str) -> str: if text is None: return "" text = str(text) text = _AR_DIACRITICS.sub("", text) text = text.replace("ـ", "") text = re.sub(r"[إأآٱ]", "ا", text) text = text.replace("ى", "ي") text = text.replace("ؤ", "و") text = text.replace("ئ", "ي") text = re.sub(r"\s+", " ", text).strip() return text def ar_tokens(text: str) -> List[str]: t = normalize_ar(text) t = _AR_PUNCT.sub(" ", t) toks = [x.strip() for x in t.split() if x.strip()] toks = [x for x in toks if len(x) >= 2] return toks def escape_html(s: str) -> str: if s is None: return "" return ( str(s) .replace("&", "&") .replace("<", "<") .replace(">", ">") .replace('"', """) .replace("'", "'") ) # ========================= # Segmenting # ========================= def split_ar_segments(text: str, max_len: int) -> List[str]: if not text: return [] t = re.sub(r"\s+", " ", str(text)).strip() parts = re.split(r"(?<=[\.\!\?؟\،\,\;\:])\s+", t) segs: List[str] = [] buf = "" for p in parts: p = (p or "").strip() if not p: continue if not buf: buf = p elif len(buf) + 1 + len(p) <= max_len: buf = f"{buf} {p}" else: segs.append(buf) buf = p if buf: segs.append(buf) if len(segs) <= 1 and len(t) > max_len: segs = [t[i:i+max_len].strip() for i in range(0, len(t), max_len) if t[i:i+max_len].strip()] return segs def pick_segs_for_rerank(segs: List[str], max_keep: int) -> List[str]: if len(segs) <= max_keep: return segs idxs = np.linspace(0, len(segs) - 1, num=max_keep) idxs = [int(round(x)) for x in idxs] seen = set() out = [] for i in idxs: if i not in seen: seen.add(i) out.append(segs[i]) return out[:max_keep] # ========================= # Embedding helpers (cached) # IMPORTANT: This model does NOT use "query:" / "passage:" prefixes. # ========================= @lru_cache(maxsize=2048) def cached_query_emb(query_norm: str) -> bytes: emb = model.encode([query_norm], normalize_embeddings=True).astype("float32")[0] return emb.tobytes() def get_query_emb(query_norm: str) -> np.ndarray: return np.frombuffer(cached_query_emb(query_norm), dtype=np.float32) # ========================= # Evidence HTML # ========================= def build_heatmap_html(segs: List[str], sims: np.ndarray, top_n: int = 6) -> str: if not segs or sims.size == 0: return "" n = len(segs) top_n = max(1, min(top_n, n)) s_min = float(np.min(sims)) s_max = float(np.max(sims)) denom = (s_max - s_min) if (s_max - s_min) > 1e-6 else 1.0 order = np.argsort(-sims) keep = set(order[:top_n]) blocks = [] for i in range(n): w = (float(sims[i]) - s_min) / denom alpha = (0.20 + 0.60 * w) if i in keep else (0.08 + 0.18 * w) alpha = max(0.06, min(alpha, 0.85)) blocks.append( f'' ) return ( '
' '
Evidence heatmap
' + "".join(blocks) + '
' ) def best_seg_html(segs: List[str], sims: np.ndarray) -> str: if not segs or sims.size == 0: return "" i = int(np.argmax(sims)) return ( '' f'{escape_html(segs[i])}' ) def lexical_ratio(query_norm: str, doc_norm: str, max_terms: int = 10) -> Tuple[float, str]: q_toks = ar_tokens(query_norm) d_toks = set(ar_tokens(doc_norm)) if not q_toks: return 0.0, "" hit = [t for t in q_toks if t in d_toks] ratio = len(hit) / max(1, len(set(q_toks))) terms = " ".join(hit[:max_terms]) return float(ratio), terms def confidence_label(score: float) -> Tuple[str, str]: if score >= 0.78: return "HIGH", "bHigh" if score >= 0.62: return "MED", "bMed" return "LOW", "bLow" # ========================= # Rerank # ========================= def rerank_rows( query_norm: str, df: pd.DataFrame, k_final: int, ) -> Tuple[pd.DataFrame, Dict[int, Dict[str, Any]]]: evidence: Dict[int, Dict[str, Any]] = {} if (not RERANK_ENABLE) or df.empty: for _, row in df.iterrows(): hid = int(row["hadithID"]) if pd.notna(row.get("hadithID")) else -1 evidence[hid] = {"mode": "disabled"} return df.head(k_final), evidence cand_rows = df.copy() per_doc_segs: List[List[str]] = [] doc_hids: List[int] = [] for _, row in cand_rows.iterrows(): hid = int(row["hadithID"]) if pd.notna(row.get("hadithID")) else -1 doc_hids.append(hid) ar = str(row.get("arabic_clean", "") or "").strip() if not ar: ar = normalize_ar(str(row.get("arabic", "") or "")) segs = split_ar_segments(ar, max_len=RERANK_SEG_MAXLEN) segs = pick_segs_for_rerank(segs, max_keep=RERANK_MAX_SEGS_PER_DOC) if not segs and ar: segs = [ar[:RERANK_SEG_MAXLEN]] per_doc_segs.append(segs) all_segs: List[str] = [] offsets: List[Tuple[int, int]] = [] cur = 0 for segs in per_doc_segs: start = cur all_segs.extend(segs) cur += len(segs) offsets.append((start, cur)) if not all_segs: for hid in doc_hids: evidence[hid] = {"mode": "empty"} return cand_rows.head(k_final), evidence q_emb = get_query_emb(query_norm) seg_emb = model.encode(all_segs, normalize_embeddings=True).astype("float32") sims_all = (seg_emb @ q_emb).astype(np.float32) rr_scores: List[float] = [] for hid, (start, end), segs in zip(doc_hids, offsets, per_doc_segs): if start == end: rr = -1.0 sims = np.array([], dtype=np.float32) else: sims = sims_all[start:end] rr = float(np.max(sims)) rr_scores.append(rr) hm = build_heatmap_html(segs, sims, top_n=min(6, len(segs))) if sims.size else "" best = best_seg_html(segs, sims) if sims.size else "" evidence[hid] = { "mode": "rerank", "rerank_score": rr, "heatmap_html": hm, "best_seg_html": best, } cand_rows["rerank_score"] = rr_scores faiss_scores = cand_rows["score"].astype(float).to_numpy() rr = cand_rows["rerank_score"].astype(float).to_numpy() w = float(max(0.0, min(1.0, RERANK_WEIGHT))) blended = (1.0 - w) * faiss_scores + w * rr cand_rows["final_score"] = blended cand_rows = cand_rows.sort_values("final_score", ascending=False).head(k_final) return cand_rows, evidence # ========================= # Full highlight for ONE hadith # ========================= def full_highlight_html( query_norm: str, arabic_clean_text: str, hl_topn: int, seg_maxlen: int, ) -> Dict[str, str]: segs = split_ar_segments(arabic_clean_text, max_len=seg_maxlen) if not segs: return { "arabic_clean_html": escape_html(arabic_clean_text), "heatmap_html": "", "best_seg_html": "", } q_emb = get_query_emb(query_norm) seg_emb = model.encode(segs, normalize_embeddings=True).astype("float32") sims = (seg_emb @ q_emb).astype(np.float32) s_min = float(np.min(sims)) s_max = float(np.max(sims)) denom = (s_max - s_min) if (s_max - s_min) > 1e-6 else 1.0 order = np.argsort(-sims) keep = set(order[:max(0, min(hl_topn, len(segs)))]) parts: List[str] = [] for i, seg in enumerate(segs): w = (float(sims[i]) - s_min) / denom alpha = (0.18 + 0.62 * w) if i in keep else (0.06 + 0.20 * w) alpha = max(0.05, min(alpha, 0.82)) border_alpha = max(0.10, min(alpha * 0.8, 0.65)) style = ( f"background: rgba(255, 230, 120, {alpha:.3f});" f"border: 1px solid rgba(234, 179, 8, {border_alpha:.3f});" "border-radius: 12px;" "padding: 3px 8px;" "margin: 0 4px 6px 0;" "display: inline;" ) parts.append(f'{escape_html(seg)} ') return { "arabic_clean_html": "".join(parts).strip() or escape_html(arabic_clean_text), "heatmap_html": build_heatmap_html(segs, sims, top_n=min(6, len(segs))), "best_seg_html": best_seg_html(segs, sims), } # ========================= # Load model + index + meta (once) # ========================= if not os.path.exists(INDEX_PATH): raise FileNotFoundError(f"FAISS index not found: {INDEX_PATH}") if not os.path.exists(META_PATH): raise FileNotFoundError(f"Meta parquet not found: {META_PATH}") model = SentenceTransformer(MODEL_NAME) index = faiss.read_index(INDEX_PATH) meta = pd.read_parquet(META_PATH) # Accept corpusID or hadithID, normalize to hadithID id_col = "hadithID" if "hadithID" in meta.columns else ("corpusID" if "corpusID" in meta.columns else None) if id_col is None: raise ValueError("Meta must contain 'hadithID' or 'corpusID'") if id_col != "hadithID": meta = meta.rename(columns={id_col: "hadithID"}) required_cols = {"hadithID", "collection", "hadith_number", "arabic", "english"} missing = required_cols - set(meta.columns) if missing: raise ValueError(f"Meta is missing required columns: {missing}") if "arabic_clean" not in meta.columns: meta["arabic_clean"] = "" meta["arabic"] = meta["arabic"].fillna("").astype(str) meta["english"] = meta["english"].fillna("").astype(str) meta["arabic_clean"] = meta["arabic_clean"].fillna("").astype(str) # ========================= # FAISS Search # ========================= def semantic_search_df(query: str, top_k: int) -> pd.DataFrame: q = str(query or "").strip() if not q: return meta.iloc[0:0].copy() top_k = max(1, min(int(top_k), MAX_TOP_K)) q_norm = normalize_ar(q) # Arabic normalize, safe for English too q_emb = get_query_emb(q_norm).reshape(1, -1) scores, idx = index.search(q_emb, top_k) res = meta.iloc[idx[0]].copy() res["score"] = scores[0] res = res.sort_values("score", ascending=False) res = res[res["arabic"].str.strip() != ""] return res # ========================= # Flask app # ========================= app = Flask(__name__) def add_cors(resp): resp.headers["Access-Control-Allow-Origin"] = CORS_ALLOW_ORIGIN resp.headers["Access-Control-Allow-Methods"] = "GET, OPTIONS" resp.headers["Access-Control-Allow-Headers"] = "Content-Type, Authorization" resp.headers["Access-Control-Max-Age"] = "86400" return resp @app.after_request def _after(resp): return add_cors(resp) @app.route("/search", methods=["OPTIONS"]) @app.route("/highlight", methods=["OPTIONS"]) @app.route("/", methods=["OPTIONS"]) def options(): return add_cors(Response("", status=204)) @app.get("/") def health(): return jsonify({ "ok": True, "model": MODEL_NAME, "index_ntotal": int(getattr(index, "ntotal", -1)), "rows": int(len(meta)), "rerank": { "enabled": bool(RERANK_ENABLE), "weight": RERANK_WEIGHT, "max_segs_per_doc": RERANK_MAX_SEGS_PER_DOC, "seg_maxlen": RERANK_SEG_MAXLEN, }, "endpoints": { "search": "/search?q=...&k=10&rerank_k=35&format=json", "search_html": "/search?q=...&k=10&rerank_k=35&format=html", "highlight": "/highlight?q=...&hadithID=123&format=html&hl_topn=6&seg_maxlen=220", } }) @app.get("/search") def search(): q = request.args.get("q", "").strip() # final top-k k_raw = request.args.get("k", str(DEFAULT_TOP_K)).strip() try: k = int(k_raw) if k_raw else DEFAULT_TOP_K except Exception: k = DEFAULT_TOP_K k = max(1, min(k, MAX_TOP_K)) # rerank pool size rk_raw = request.args.get("rerank_k", str(DEFAULT_RERANK_K)).strip() try: rerank_k = int(rk_raw) if rk_raw else DEFAULT_RERANK_K except Exception: rerank_k = DEFAULT_RERANK_K rerank_k = max(MIN_RERANK_K, min(rerank_k, MAX_RERANK_K)) rerank_k = max(rerank_k, k) # highlight controls hl_raw = request.args.get("hl_topn", str(DEFAULT_HL_TOPN)).strip() seg_raw = request.args.get("seg_maxlen", str(DEFAULT_SEG_MAXLEN)).strip() try: hl_topn = int(hl_raw) if hl_raw else DEFAULT_HL_TOPN except Exception: hl_topn = DEFAULT_HL_TOPN try: seg_maxlen = int(seg_raw) if seg_raw else DEFAULT_SEG_MAXLEN except Exception: seg_maxlen = DEFAULT_SEG_MAXLEN hl_topn = max(0, min(hl_topn, MAX_HL_TOPN)) seg_maxlen = max(MIN_SEG_MAXLEN, min(seg_maxlen, MAX_SEG_MAXLEN)) fmt = (request.args.get("format", "json") or "json").lower() want_html = (fmt == "html") if not q: return jsonify({ "ok": True, "query": "", "query_norm": "", "k": k, "rerank_k": rerank_k, "n": 0, "rows": int(len(meta)), "took_ms": 0, "format": "html" if want_html else "json", "hl_topn": hl_topn, "seg_maxlen": seg_maxlen, "results": [], }) t0 = time.time() # 1) retrieve pool df_pool = semantic_search_df(q, top_k=rerank_k) q_norm = normalize_ar(q) # 2) rerank -> final df_final, ev = rerank_rows(query_norm=q_norm, df=df_pool, k_final=k) took_ms = int((time.time() - t0) * 1000) results: List[Dict[str, Any]] = [] for _, row in df_final.iterrows(): hid = int(row.get("hadithID")) if pd.notna(row.get("hadithID")) else None arabic = str(row.get("arabic", "") or "") english = str(row.get("english", "") or "") ar_clean = str(row.get("arabic_clean", "") or "").strip() if not ar_clean: ar_clean = normalize_ar(arabic) lex_r, lex_terms = lexical_ratio(q_norm, ar_clean) faiss_score = float(row.get("score")) if pd.notna(row.get("score")) else 0.0 rerank_score = float(row.get("rerank_score")) if pd.notna(row.get("rerank_score")) else faiss_score final_score = float(row.get("final_score")) if pd.notna(row.get("final_score")) else faiss_score conf_label, conf_class = confidence_label(final_score) e = ev.get(hid or -1, {}) heatmap_html = e.get("heatmap_html", "") if isinstance(e, dict) else "" best_html = e.get("best_seg_html", "") if isinstance(e, dict) else "" r = { "hadithID": hid, "collection": str(row.get("collection", "") or ""), "hadith_number": int(row.get("hadith_number")) if pd.notna(row.get("hadith_number")) else None, "score": final_score, "faiss_score": faiss_score, "rerank_score": rerank_score, "conf_label": conf_label, "conf_class": conf_class, "lex_ratio": float(lex_r), "lex_terms": lex_terms, "arabic": arabic, "arabic_clean": ar_clean, "english": english, "heatmap_html": heatmap_html, "best_seg_html": best_html, } if want_html and hl_topn > 0: extras = full_highlight_html( query_norm=q_norm, arabic_clean_text=ar_clean, hl_topn=hl_topn, seg_maxlen=seg_maxlen, ) r["arabic_clean_html"] = extras["arabic_clean_html"] r["heatmap_html"] = extras["heatmap_html"] or r["heatmap_html"] r["best_seg_html"] = extras["best_seg_html"] or r["best_seg_html"] results.append(r) return jsonify({ "ok": True, "query": q, "query_norm": q_norm, "k": k, "rerank_k": rerank_k, "n": len(results), "rows": int(len(meta)), "took_ms": took_ms, "format": "html" if want_html else "json", "hl_topn": hl_topn, "seg_maxlen": seg_maxlen, "results": results, }) @app.get("/highlight") def highlight(): q = request.args.get("q", "").strip() hid_raw = request.args.get("hadithID", "").strip() hl_raw = request.args.get("hl_topn", str(DEFAULT_HL_TOPN)).strip() seg_raw = request.args.get("seg_maxlen", str(DEFAULT_SEG_MAXLEN)).strip() try: hl_topn = int(hl_raw) if hl_raw else DEFAULT_HL_TOPN except Exception: hl_topn = DEFAULT_HL_TOPN try: seg_maxlen = int(seg_raw) if seg_raw else DEFAULT_SEG_MAXLEN except Exception: seg_maxlen = DEFAULT_SEG_MAXLEN hl_topn = max(0, min(hl_topn, MAX_HL_TOPN)) seg_maxlen = max(MIN_SEG_MAXLEN, min(seg_maxlen, MAX_SEG_MAXLEN)) fmt = (request.args.get("format", "html") or "html").lower() want_html = (fmt == "html") if not q or not hid_raw: return jsonify({"ok": False, "error": "q and hadithID are required"}), 400 try: hid = int(hid_raw) except Exception: return jsonify({"ok": False, "error": "hadithID must be int"}), 400 row_df = meta[meta["hadithID"] == hid] if row_df.empty: return jsonify({"ok": False, "error": "hadithID not found"}), 404 row = row_df.iloc[0] q_norm = normalize_ar(q) arabic = str(row.get("arabic", "") or "") english = str(row.get("english", "") or "") ar_clean = str(row.get("arabic_clean", "") or "").strip() if not ar_clean: ar_clean = normalize_ar(arabic) extras = full_highlight_html( query_norm=q_norm, arabic_clean_text=ar_clean, hl_topn=hl_topn if want_html else 0, seg_maxlen=seg_maxlen, ) lex_r, lex_terms = lexical_ratio(q_norm, ar_clean) return jsonify({ "ok": True, "query": q, "query_norm": q_norm, "hadithID": hid, "format": "html" if want_html else "json", "hl_topn": hl_topn, "seg_maxlen": seg_maxlen, "lex_ratio": float(lex_r), "lex_terms": lex_terms, "arabic": arabic, "arabic_clean": ar_clean, "english": english, "arabic_clean_html": extras.get("arabic_clean_html", "") if want_html else "", "heatmap_html": extras.get("heatmap_html", ""), "best_seg_html": extras.get("best_seg_html", ""), }) if __name__ == "__main__": # Hugging Face Spaces uses PORT=7860 port = int(os.getenv("PORT", "7860")) app.run(host="0.0.0.0", port=port, debug=False)