from __future__ import annotations
import os
import re
import time
from functools import lru_cache
from typing import List, Dict, Any, Tuple
import numpy as np
import pandas as pd
import faiss
from flask import Flask, request, jsonify, Response
from sentence_transformers import SentenceTransformer
# =========================
# Config (HF Space defaults)
# =========================
INDEX_PATH = os.getenv("HADITH_INDEX_PATH", "hadith_semantic.faiss")
META_PATH = os.getenv("HADITH_META_PATH", "hadith_meta.parquet")
# Small/fast multilingual model (good on free CPU)
MODEL_NAME = os.getenv(
"HADITH_MODEL_NAME",
"sentence-transformers/paraphrase-multilingual-MiniLM-L12-v2"
)
DEFAULT_TOP_K = 10
MAX_TOP_K = 50
DEFAULT_RERANK_K = 35
MAX_RERANK_K = 120
MIN_RERANK_K = 10
DEFAULT_HL_TOPN = 6
MAX_HL_TOPN = 25
DEFAULT_SEG_MAXLEN = 220
MAX_SEG_MAXLEN = 420
MIN_SEG_MAXLEN = 120
# Rerank knobs (keep small for HF free CPU)
RERANK_MAX_SEGS_PER_DOC = int(os.getenv("RERANK_MAX_SEGS_PER_DOC", "8"))
RERANK_SEG_MAXLEN = int(os.getenv("RERANK_SEG_MAXLEN", "240"))
RERANK_WEIGHT = float(os.getenv("RERANK_WEIGHT", "0.65"))
RERANK_ENABLE = os.getenv("RERANK_ENABLE", "1").strip() != "0"
# CORS
CORS_ALLOW_ORIGIN = os.getenv("CORS_ALLOW_ORIGIN", "*")
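# Example: overriding the defaults from a Space's variables or a local shell.
# These values are illustrative, not recommendations:
#
#   export HADITH_MODEL_NAME="sentence-transformers/paraphrase-multilingual-mpnet-base-v2"
#   export RERANK_WEIGHT="0.5"   # lean more on the raw FAISS score
#   export RERANK_ENABLE="0"     # skip segment-level reranking entirely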
# =========================
# Arabic normalization
# =========================
_AR_DIACRITICS = re.compile(r"""
[\u0610-\u061A]
| [\u064B-\u065F]
| [\u0670]
| [\u06D6-\u06ED]
""", re.VERBOSE)
_AR_PUNCT = re.compile(r"[^\w\u0600-\u06FF]+", re.UNICODE)
def normalize_ar(text: str) -> str:
if text is None:
return ""
text = str(text)
text = _AR_DIACRITICS.sub("", text)
text = text.replace("ـ", "")
text = re.sub(r"[إأآٱ]", "ا", text)
text = text.replace("ى", "ي")
text = text.replace("ؤ", "و")
text = text.replace("ئ", "ي")
text = re.sub(r"\s+", " ", text).strip()
return text
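# Quick sanity check (illustrative input): diacritics are stripped, alef and
# yaa/hamza variants are folded, and whitespace is collapsed, e.g.
#   normalize_ar("إِنَّمَا الأَعْمَالُ")  ->  "انما الاعمال"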
def ar_tokens(text: str) -> List[str]:
t = normalize_ar(text)
t = _AR_PUNCT.sub(" ", t)
toks = [x.strip() for x in t.split() if x.strip()]
toks = [x for x in toks if len(x) >= 2]
return toks
def escape_html(s: str) -> str:
if s is None:
return ""
return (
str(s)
.replace("&", "&")
.replace("<", "&lt;")
.replace(">", "&gt;")
.replace('"', "&quot;")
.replace("'", "&#39;")
)
# =========================
# Segmenting
# =========================
def split_ar_segments(text: str, max_len: int) -> List[str]:
if not text:
return []
t = re.sub(r"\s+", " ", str(text)).strip()
parts = re.split(r"(?<=[\.\!\?؟\،\,\;\:])\s+", t)
segs: List[str] = []
buf = ""
for p in parts:
p = (p or "").strip()
if not p:
continue
if not buf:
buf = p
elif len(buf) + 1 + len(p) <= max_len:
buf = f"{buf} {p}"
else:
segs.append(buf)
buf = p
if buf:
segs.append(buf)
if len(segs) <= 1 and len(t) > max_len:
segs = [t[i:i+max_len].strip() for i in range(0, len(t), max_len) if t[i:i+max_len].strip()]
return segs
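# Worked example (Latin text for readability, max_len=12):
#   split_ar_segments("aa bb. cc dd. ee ff.", 12)
#   -> ["aa bb.", "cc dd.", "ee ff."]
# Each sentence is 6 chars; packing any two would need 13 > 12, so each
# sentence becomes its own segment. Longer unpunctuated text falls back to
# fixed-width slicing.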
def pick_segs_for_rerank(segs: List[str], max_keep: int) -> List[str]:
if len(segs) <= max_keep:
return segs
idxs = np.linspace(0, len(segs) - 1, num=max_keep)
idxs = [int(round(x)) for x in idxs]
seen = set()
out = []
for i in idxs:
if i not in seen:
seen.add(i)
out.append(segs[i])
return out[:max_keep]
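# Worked example: 20 segments sampled down to max_keep=8 keeps indices
# [0, 3, 5, 8, 11, 14, 16, 19] -- evenly spaced, always including the first
# and last segment, so the whole document stays represented.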
# =========================
# Embedding helpers (cached)
# IMPORTANT: This model does NOT use "query:" / "passage:" prefixes.
# =========================
@lru_cache(maxsize=2048)
def cached_query_emb(query_norm: str) -> bytes:
emb = model.encode([query_norm], normalize_embeddings=True).astype("float32")[0]
return emb.tobytes()
def get_query_emb(query_norm: str) -> np.ndarray:
return np.frombuffer(cached_query_emb(query_norm), dtype=np.float32)
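# Design note: numpy arrays are not hashable, so the lru_cache stores raw
# bytes and get_query_emb rebuilds a read-only float32 view per call. The
# global `model` is defined further down at module load, before any request
# can reach this code.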
# =========================
# Evidence HTML
# =========================
def build_heatmap_html(segs: List[str], sims: np.ndarray, top_n: int = 6) -> str:
if not segs or sims.size == 0:
return ""
n = len(segs)
top_n = max(1, min(top_n, n))
s_min = float(np.min(sims))
s_max = float(np.max(sims))
denom = (s_max - s_min) if (s_max - s_min) > 1e-6 else 1.0
order = np.argsort(-sims)
keep = set(order[:top_n])
blocks = []
for i in range(n):
w = (float(sims[i]) - s_min) / denom
alpha = (0.20 + 0.60 * w) if i in keep else (0.08 + 0.18 * w)
alpha = max(0.06, min(alpha, 0.85))
blocks.append(
f'<span title="{escape_html(segs[i])}" '
f'style="display:inline-block;width:10px;height:10px;margin:0 3px 0 0;'
f'border-radius:4px;background:rgba(37,99,235,{alpha:.3f});border:1px solid rgba(37,99,235,0.20);"></span>'
)
return (
'<div style="margin:10px 0 0;direction:ltr;text-align:left;">'
'<div style="font-size:12px;color:#475569;margin-bottom:6px;">Evidence heatmap</div>'
+ "".join(blocks) +
'</div>'
)
def best_seg_html(segs: List[str], sims: np.ndarray) -> str:
if not segs or sims.size == 0:
return ""
i = int(np.argmax(sims))
return (
'<span style="background:rgba(255,230,120,0.55);'
'border:1px solid rgba(234,179,8,0.35);border-radius:12px;padding:3px 8px;display:inline;">'
f'{escape_html(segs[i])}</span>'
)
def lexical_ratio(query_norm: str, doc_norm: str, max_terms: int = 10) -> Tuple[float, str]:
q_toks = ar_tokens(query_norm)
d_toks = set(ar_tokens(doc_norm))
if not q_toks:
return 0.0, ""
    # Deduplicate while preserving order so a repeated query token
    # cannot push the ratio above 1.0.
    hit = [t for t in dict.fromkeys(q_toks) if t in d_toks]
    ratio = len(hit) / max(1, len(set(q_toks)))
terms = " ".join(hit[:max_terms])
return float(ratio), terms
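# Worked example: for distinct query tokens {"صيام", "رمضان"} and a document
# that contains only "صيام", this returns (0.5, "صيام") -- half the distinct
# query tokens appear verbatim in the normalized document.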
def confidence_label(score: float) -> Tuple[str, str]:
if score >= 0.78:
return "HIGH", "bHigh"
if score >= 0.62:
return "MED", "bMed"
return "LOW", "bLow"
# =========================
# Rerank
# =========================
def rerank_rows(
query_norm: str,
df: pd.DataFrame,
k_final: int,
) -> Tuple[pd.DataFrame, Dict[int, Dict[str, Any]]]:
evidence: Dict[int, Dict[str, Any]] = {}
if (not RERANK_ENABLE) or df.empty:
for _, row in df.iterrows():
hid = int(row["hadithID"]) if pd.notna(row.get("hadithID")) else -1
evidence[hid] = {"mode": "disabled"}
return df.head(k_final), evidence
cand_rows = df.copy()
per_doc_segs: List[List[str]] = []
doc_hids: List[int] = []
for _, row in cand_rows.iterrows():
hid = int(row["hadithID"]) if pd.notna(row.get("hadithID")) else -1
doc_hids.append(hid)
ar = str(row.get("arabic_clean", "") or "").strip()
if not ar:
ar = normalize_ar(str(row.get("arabic", "") or ""))
segs = split_ar_segments(ar, max_len=RERANK_SEG_MAXLEN)
segs = pick_segs_for_rerank(segs, max_keep=RERANK_MAX_SEGS_PER_DOC)
if not segs and ar:
segs = [ar[:RERANK_SEG_MAXLEN]]
per_doc_segs.append(segs)
all_segs: List[str] = []
offsets: List[Tuple[int, int]] = []
cur = 0
for segs in per_doc_segs:
start = cur
all_segs.extend(segs)
cur += len(segs)
offsets.append((start, cur))
if not all_segs:
for hid in doc_hids:
evidence[hid] = {"mode": "empty"}
return cand_rows.head(k_final), evidence
q_emb = get_query_emb(query_norm)
seg_emb = model.encode(all_segs, normalize_embeddings=True).astype("float32")
sims_all = (seg_emb @ q_emb).astype(np.float32)
rr_scores: List[float] = []
for hid, (start, end), segs in zip(doc_hids, offsets, per_doc_segs):
if start == end:
rr = -1.0
sims = np.array([], dtype=np.float32)
else:
sims = sims_all[start:end]
rr = float(np.max(sims))
rr_scores.append(rr)
hm = build_heatmap_html(segs, sims, top_n=min(6, len(segs))) if sims.size else ""
best = best_seg_html(segs, sims) if sims.size else ""
evidence[hid] = {
"mode": "rerank",
"rerank_score": rr,
"heatmap_html": hm,
"best_seg_html": best,
}
cand_rows["rerank_score"] = rr_scores
faiss_scores = cand_rows["score"].astype(float).to_numpy()
rr = cand_rows["rerank_score"].astype(float).to_numpy()
w = float(max(0.0, min(1.0, RERANK_WEIGHT)))
blended = (1.0 - w) * faiss_scores + w * rr
cand_rows["final_score"] = blended
cand_rows = cand_rows.sort_values("final_score", ascending=False).head(k_final)
return cand_rows, evidence
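# Worked example of the blend at the default RERANK_WEIGHT=0.65: a document
# with faiss_score=0.70 whose best segment scores 0.80 against the query gets
# final_score = 0.35 * 0.70 + 0.65 * 0.80 = 0.765, so strong local evidence
# can outrank a slightly higher whole-document FAISS score.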
# =========================
# Full highlight for ONE hadith
# =========================
def full_highlight_html(
query_norm: str,
arabic_clean_text: str,
hl_topn: int,
seg_maxlen: int,
) -> Dict[str, str]:
segs = split_ar_segments(arabic_clean_text, max_len=seg_maxlen)
if not segs:
return {
"arabic_clean_html": escape_html(arabic_clean_text),
"heatmap_html": "",
"best_seg_html": "",
}
q_emb = get_query_emb(query_norm)
seg_emb = model.encode(segs, normalize_embeddings=True).astype("float32")
sims = (seg_emb @ q_emb).astype(np.float32)
s_min = float(np.min(sims))
s_max = float(np.max(sims))
denom = (s_max - s_min) if (s_max - s_min) > 1e-6 else 1.0
order = np.argsort(-sims)
keep = set(order[:max(0, min(hl_topn, len(segs)))])
parts: List[str] = []
for i, seg in enumerate(segs):
w = (float(sims[i]) - s_min) / denom
alpha = (0.18 + 0.62 * w) if i in keep else (0.06 + 0.20 * w)
alpha = max(0.05, min(alpha, 0.82))
border_alpha = max(0.10, min(alpha * 0.8, 0.65))
style = (
f"background: rgba(255, 230, 120, {alpha:.3f});"
f"border: 1px solid rgba(234, 179, 8, {border_alpha:.3f});"
"border-radius: 12px;"
"padding: 3px 8px;"
"margin: 0 4px 6px 0;"
"display: inline;"
)
parts.append(f'<span style="{style}">{escape_html(seg)}</span> ')
return {
"arabic_clean_html": "".join(parts).strip() or escape_html(arabic_clean_text),
"heatmap_html": build_heatmap_html(segs, sims, top_n=min(6, len(segs))),
"best_seg_html": best_seg_html(segs, sims),
}
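# Performance note: in /search with format=html this runs once per result and
# re-encodes that hadith's segments each time; on a free CPU Space this is
# usually the dominant cost, which is why hl_topn=0 (or format=json) skips it.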
# =========================
# Load model + index + meta (once)
# =========================
if not os.path.exists(INDEX_PATH):
raise FileNotFoundError(f"FAISS index not found: {INDEX_PATH}")
if not os.path.exists(META_PATH):
raise FileNotFoundError(f"Meta parquet not found: {META_PATH}")
model = SentenceTransformer(MODEL_NAME)
index = faiss.read_index(INDEX_PATH)
meta = pd.read_parquet(META_PATH)
# Accept corpusID or hadithID, normalize to hadithID
id_col = "hadithID" if "hadithID" in meta.columns else ("corpusID" if "corpusID" in meta.columns else None)
if id_col is None:
raise ValueError("Meta must contain 'hadithID' or 'corpusID'")
if id_col != "hadithID":
meta = meta.rename(columns={id_col: "hadithID"})
required_cols = {"hadithID", "collection", "hadith_number", "arabic", "english"}
missing = required_cols - set(meta.columns)
if missing:
raise ValueError(f"Meta is missing required columns: {missing}")
if "arabic_clean" not in meta.columns:
meta["arabic_clean"] = ""
meta["arabic"] = meta["arabic"].fillna("").astype(str)
meta["english"] = meta["english"].fillna("").astype(str)
meta["arabic_clean"] = meta["arabic_clean"].fillna("").astype(str)
# =========================
# FAISS Search
# =========================
def semantic_search_df(query: str, top_k: int) -> pd.DataFrame:
q = str(query or "").strip()
if not q:
return meta.iloc[0:0].copy()
top_k = max(1, min(int(top_k), MAX_TOP_K))
q_norm = normalize_ar(q) # Arabic normalize, safe for English too
q_emb = get_query_emb(q_norm).reshape(1, -1)
    scores, idx = index.search(q_emb, top_k)
    ids = idx[0]
    valid = ids >= 0  # FAISS pads with -1 when the index holds fewer than top_k vectors
    res = meta.iloc[ids[valid]].copy()
    res["score"] = scores[0][valid]
res = res.sort_values("score", ascending=False)
res = res[res["arabic"].str.strip() != ""]
return res
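# Note on score semantics: "score" is whatever index.search emits. Assuming
# the index was built for inner product (e.g. IndexFlatIP) over L2-normalized
# embeddings -- matching normalize_embeddings=True here -- the scores are
# cosine similarities in [-1, 1].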
# =========================
# Flask app
# =========================
app = Flask(__name__)
def add_cors(resp):
resp.headers["Access-Control-Allow-Origin"] = CORS_ALLOW_ORIGIN
resp.headers["Access-Control-Allow-Methods"] = "GET, OPTIONS"
resp.headers["Access-Control-Allow-Headers"] = "Content-Type, Authorization"
resp.headers["Access-Control-Max-Age"] = "86400"
return resp
@app.after_request
def _after(resp):
return add_cors(resp)
@app.route("/search", methods=["OPTIONS"])
@app.route("/highlight", methods=["OPTIONS"])
@app.route("/", methods=["OPTIONS"])
def options():
return add_cors(Response("", status=204))
@app.get("/")
def health():
return jsonify({
"ok": True,
"model": MODEL_NAME,
"index_ntotal": int(getattr(index, "ntotal", -1)),
"rows": int(len(meta)),
"rerank": {
"enabled": bool(RERANK_ENABLE),
"weight": RERANK_WEIGHT,
"max_segs_per_doc": RERANK_MAX_SEGS_PER_DOC,
"seg_maxlen": RERANK_SEG_MAXLEN,
},
"endpoints": {
"search": "/search?q=...&k=10&rerank_k=35&format=json",
"search_html": "/search?q=...&k=10&rerank_k=35&format=html",
"highlight": "/highlight?q=...&hadithID=123&format=html&hl_topn=6&seg_maxlen=220",
}
})
@app.get("/search")
def search():
q = request.args.get("q", "").strip()
# final top-k
k_raw = request.args.get("k", str(DEFAULT_TOP_K)).strip()
try:
k = int(k_raw) if k_raw else DEFAULT_TOP_K
except Exception:
k = DEFAULT_TOP_K
k = max(1, min(k, MAX_TOP_K))
# rerank pool size
rk_raw = request.args.get("rerank_k", str(DEFAULT_RERANK_K)).strip()
try:
rerank_k = int(rk_raw) if rk_raw else DEFAULT_RERANK_K
except Exception:
rerank_k = DEFAULT_RERANK_K
rerank_k = max(MIN_RERANK_K, min(rerank_k, MAX_RERANK_K))
rerank_k = max(rerank_k, k)
# highlight controls
hl_raw = request.args.get("hl_topn", str(DEFAULT_HL_TOPN)).strip()
seg_raw = request.args.get("seg_maxlen", str(DEFAULT_SEG_MAXLEN)).strip()
try:
hl_topn = int(hl_raw) if hl_raw else DEFAULT_HL_TOPN
except Exception:
hl_topn = DEFAULT_HL_TOPN
try:
seg_maxlen = int(seg_raw) if seg_raw else DEFAULT_SEG_MAXLEN
except Exception:
seg_maxlen = DEFAULT_SEG_MAXLEN
hl_topn = max(0, min(hl_topn, MAX_HL_TOPN))
seg_maxlen = max(MIN_SEG_MAXLEN, min(seg_maxlen, MAX_SEG_MAXLEN))
fmt = (request.args.get("format", "json") or "json").lower()
want_html = (fmt == "html")
if not q:
return jsonify({
"ok": True,
"query": "",
"query_norm": "",
"k": k,
"rerank_k": rerank_k,
"n": 0,
"rows": int(len(meta)),
"took_ms": 0,
"format": "html" if want_html else "json",
"hl_topn": hl_topn,
"seg_maxlen": seg_maxlen,
"results": [],
})
t0 = time.time()
# 1) retrieve pool
df_pool = semantic_search_df(q, top_k=rerank_k)
q_norm = normalize_ar(q)
# 2) rerank -> final
df_final, ev = rerank_rows(query_norm=q_norm, df=df_pool, k_final=k)
took_ms = int((time.time() - t0) * 1000)
results: List[Dict[str, Any]] = []
for _, row in df_final.iterrows():
hid = int(row.get("hadithID")) if pd.notna(row.get("hadithID")) else None
arabic = str(row.get("arabic", "") or "")
english = str(row.get("english", "") or "")
ar_clean = str(row.get("arabic_clean", "") or "").strip()
if not ar_clean:
ar_clean = normalize_ar(arabic)
lex_r, lex_terms = lexical_ratio(q_norm, ar_clean)
faiss_score = float(row.get("score")) if pd.notna(row.get("score")) else 0.0
rerank_score = float(row.get("rerank_score")) if pd.notna(row.get("rerank_score")) else faiss_score
final_score = float(row.get("final_score")) if pd.notna(row.get("final_score")) else faiss_score
conf_label, conf_class = confidence_label(final_score)
e = ev.get(hid or -1, {})
heatmap_html = e.get("heatmap_html", "") if isinstance(e, dict) else ""
best_html = e.get("best_seg_html", "") if isinstance(e, dict) else ""
r = {
"hadithID": hid,
"collection": str(row.get("collection", "") or ""),
"hadith_number": int(row.get("hadith_number")) if pd.notna(row.get("hadith_number")) else None,
"score": final_score,
"faiss_score": faiss_score,
"rerank_score": rerank_score,
"conf_label": conf_label,
"conf_class": conf_class,
"lex_ratio": float(lex_r),
"lex_terms": lex_terms,
"arabic": arabic,
"arabic_clean": ar_clean,
"english": english,
"heatmap_html": heatmap_html,
"best_seg_html": best_html,
}
if want_html and hl_topn > 0:
extras = full_highlight_html(
query_norm=q_norm,
arabic_clean_text=ar_clean,
hl_topn=hl_topn,
seg_maxlen=seg_maxlen,
)
r["arabic_clean_html"] = extras["arabic_clean_html"]
r["heatmap_html"] = extras["heatmap_html"] or r["heatmap_html"]
r["best_seg_html"] = extras["best_seg_html"] or r["best_seg_html"]
results.append(r)
return jsonify({
"ok": True,
"query": q,
"query_norm": q_norm,
"k": k,
"rerank_k": rerank_k,
"n": len(results),
"rows": int(len(meta)),
"took_ms": took_ms,
"format": "html" if want_html else "json",
"hl_topn": hl_topn,
"seg_maxlen": seg_maxlen,
"results": results,
})
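# Example call (hypothetical local run; a Space URL works the same way):
#
#   import requests
#   r = requests.get(
#       "http://localhost:7860/search",
#       params={"q": "الصيام", "k": 5, "format": "json"},
#   )
#   hits = r.json()["results"]  # each hit carries score/faiss_score/rerank_score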
@app.get("/highlight")
def highlight():
q = request.args.get("q", "").strip()
hid_raw = request.args.get("hadithID", "").strip()
hl_raw = request.args.get("hl_topn", str(DEFAULT_HL_TOPN)).strip()
seg_raw = request.args.get("seg_maxlen", str(DEFAULT_SEG_MAXLEN)).strip()
try:
hl_topn = int(hl_raw) if hl_raw else DEFAULT_HL_TOPN
except Exception:
hl_topn = DEFAULT_HL_TOPN
try:
seg_maxlen = int(seg_raw) if seg_raw else DEFAULT_SEG_MAXLEN
except Exception:
seg_maxlen = DEFAULT_SEG_MAXLEN
hl_topn = max(0, min(hl_topn, MAX_HL_TOPN))
seg_maxlen = max(MIN_SEG_MAXLEN, min(seg_maxlen, MAX_SEG_MAXLEN))
fmt = (request.args.get("format", "html") or "html").lower()
want_html = (fmt == "html")
if not q or not hid_raw:
return jsonify({"ok": False, "error": "q and hadithID are required"}), 400
try:
hid = int(hid_raw)
except Exception:
return jsonify({"ok": False, "error": "hadithID must be int"}), 400
row_df = meta[meta["hadithID"] == hid]
if row_df.empty:
return jsonify({"ok": False, "error": "hadithID not found"}), 404
row = row_df.iloc[0]
q_norm = normalize_ar(q)
arabic = str(row.get("arabic", "") or "")
english = str(row.get("english", "") or "")
ar_clean = str(row.get("arabic_clean", "") or "").strip()
if not ar_clean:
ar_clean = normalize_ar(arabic)
extras = full_highlight_html(
query_norm=q_norm,
arabic_clean_text=ar_clean,
hl_topn=hl_topn if want_html else 0,
seg_maxlen=seg_maxlen,
)
lex_r, lex_terms = lexical_ratio(q_norm, ar_clean)
return jsonify({
"ok": True,
"query": q,
"query_norm": q_norm,
"hadithID": hid,
"format": "html" if want_html else "json",
"hl_topn": hl_topn,
"seg_maxlen": seg_maxlen,
"lex_ratio": float(lex_r),
"lex_terms": lex_terms,
"arabic": arabic,
"arabic_clean": ar_clean,
"english": english,
"arabic_clean_html": extras.get("arabic_clean_html", "") if want_html else "",
"heatmap_html": extras.get("heatmap_html", ""),
"best_seg_html": extras.get("best_seg_html", ""),
})
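# Example call (hypothetical ID; format=html adds the highlighted Arabic):
#   curl "http://localhost:7860/highlight?q=...&hadithID=123&format=html&hl_topn=6"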
if __name__ == "__main__":
# Hugging Face Spaces uses PORT=7860
port = int(os.getenv("PORT", "7860"))
app.run(host="0.0.0.0", port=port, debug=False)
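# For production a WSGI server is typical; a minimal sketch (assuming
# gunicorn is installed) would be:
#   gunicorn -w 1 -b 0.0.0.0:7860 app:app
# A single worker keeps just one copy of the model and index in memory.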