Spaces:
Running
Running
Update app.py
Browse files
app.py
CHANGED
|
@@ -3,24 +3,27 @@ from __future__ import annotations
|
|
| 3 |
import os
|
| 4 |
import re
|
| 5 |
import time
|
| 6 |
-
import math
|
| 7 |
from functools import lru_cache
|
| 8 |
-
from typing import List, Dict, Any, Tuple
|
| 9 |
|
| 10 |
import numpy as np
|
| 11 |
import pandas as pd
|
| 12 |
import faiss
|
| 13 |
-
|
| 14 |
from flask import Flask, request, jsonify, Response
|
| 15 |
from sentence_transformers import SentenceTransformer
|
| 16 |
|
| 17 |
|
| 18 |
# =========================
|
| 19 |
-
# Config
|
| 20 |
# =========================
|
| 21 |
-
INDEX_PATH = os.getenv("HADITH_INDEX_PATH", "
|
| 22 |
META_PATH = os.getenv("HADITH_META_PATH", "hadith_meta.parquet")
|
| 23 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 24 |
|
| 25 |
DEFAULT_TOP_K = 10
|
| 26 |
MAX_TOP_K = 50
|
|
@@ -29,21 +32,21 @@ DEFAULT_RERANK_K = 35
|
|
| 29 |
MAX_RERANK_K = 120
|
| 30 |
MIN_RERANK_K = 10
|
| 31 |
|
| 32 |
-
DEFAULT_HL_TOPN = 6
|
| 33 |
MAX_HL_TOPN = 25
|
| 34 |
|
| 35 |
DEFAULT_SEG_MAXLEN = 220
|
| 36 |
MAX_SEG_MAXLEN = 420
|
| 37 |
MIN_SEG_MAXLEN = 120
|
| 38 |
|
| 39 |
-
# Rerank
|
| 40 |
-
RERANK_MAX_SEGS_PER_DOC = int(os.getenv("RERANK_MAX_SEGS_PER_DOC", "
|
| 41 |
-
RERANK_SEG_MAXLEN = int(os.getenv("RERANK_SEG_MAXLEN", "240"))
|
| 42 |
-
RERANK_WEIGHT = float(os.getenv("RERANK_WEIGHT", "0.65"))
|
| 43 |
RERANK_ENABLE = os.getenv("RERANK_ENABLE", "1").strip() != "0"
|
| 44 |
|
| 45 |
# CORS
|
| 46 |
-
CORS_ALLOW_ORIGIN = os.getenv("CORS_ALLOW_ORIGIN", "*")
|
| 47 |
|
| 48 |
|
| 49 |
# =========================
|
|
@@ -75,7 +78,6 @@ def ar_tokens(text: str) -> List[str]:
|
|
| 75 |
t = normalize_ar(text)
|
| 76 |
t = _AR_PUNCT.sub(" ", t)
|
| 77 |
toks = [x.strip() for x in t.split() if x.strip()]
|
| 78 |
-
# remove super short tokens
|
| 79 |
toks = [x for x in toks if len(x) >= 2]
|
| 80 |
return toks
|
| 81 |
|
|
@@ -117,19 +119,15 @@ def split_ar_segments(text: str, max_len: int) -> List[str]:
|
|
| 117 |
if buf:
|
| 118 |
segs.append(buf)
|
| 119 |
|
| 120 |
-
# fallback chunking
|
| 121 |
if len(segs) <= 1 and len(t) > max_len:
|
| 122 |
segs = [t[i:i+max_len].strip() for i in range(0, len(t), max_len) if t[i:i+max_len].strip()]
|
| 123 |
return segs
|
| 124 |
|
| 125 |
def pick_segs_for_rerank(segs: List[str], max_keep: int) -> List[str]:
|
| 126 |
-
"""Pick up to max_keep segments spread out (for speed)."""
|
| 127 |
if len(segs) <= max_keep:
|
| 128 |
return segs
|
| 129 |
-
# spread indices evenly
|
| 130 |
idxs = np.linspace(0, len(segs) - 1, num=max_keep)
|
| 131 |
idxs = [int(round(x)) for x in idxs]
|
| 132 |
-
# unique preserve order
|
| 133 |
seen = set()
|
| 134 |
out = []
|
| 135 |
for i in idxs:
|
|
@@ -141,10 +139,11 @@ def pick_segs_for_rerank(segs: List[str], max_keep: int) -> List[str]:
|
|
| 141 |
|
| 142 |
# =========================
|
| 143 |
# Embedding helpers (cached)
|
|
|
|
| 144 |
# =========================
|
| 145 |
@lru_cache(maxsize=2048)
|
| 146 |
def cached_query_emb(query_norm: str) -> bytes:
|
| 147 |
-
emb = model.encode([
|
| 148 |
return emb.tobytes()
|
| 149 |
|
| 150 |
def get_query_emb(query_norm: str) -> np.ndarray:
|
|
@@ -152,10 +151,9 @@ def get_query_emb(query_norm: str) -> np.ndarray:
|
|
| 152 |
|
| 153 |
|
| 154 |
# =========================
|
| 155 |
-
#
|
| 156 |
# =========================
|
| 157 |
def build_heatmap_html(segs: List[str], sims: np.ndarray, top_n: int = 6) -> str:
|
| 158 |
-
"""Small bar-like heatmap using segment similarity (already computed)."""
|
| 159 |
if not segs or sims.size == 0:
|
| 160 |
return ""
|
| 161 |
|
|
@@ -166,14 +164,12 @@ def build_heatmap_html(segs: List[str], sims: np.ndarray, top_n: int = 6) -> str
|
|
| 166 |
s_max = float(np.max(sims))
|
| 167 |
denom = (s_max - s_min) if (s_max - s_min) > 1e-6 else 1.0
|
| 168 |
|
| 169 |
-
# choose top indices
|
| 170 |
order = np.argsort(-sims)
|
| 171 |
keep = set(order[:top_n])
|
| 172 |
|
| 173 |
blocks = []
|
| 174 |
for i in range(n):
|
| 175 |
-
w = (float(sims[i]) - s_min) / denom
|
| 176 |
-
# stronger for top segments
|
| 177 |
alpha = (0.20 + 0.60 * w) if i in keep else (0.08 + 0.18 * w)
|
| 178 |
alpha = max(0.06, min(alpha, 0.85))
|
| 179 |
blocks.append(
|
|
@@ -193,7 +189,11 @@ def best_seg_html(segs: List[str], sims: np.ndarray) -> str:
|
|
| 193 |
if not segs or sims.size == 0:
|
| 194 |
return ""
|
| 195 |
i = int(np.argmax(sims))
|
| 196 |
-
return
|
|
|
|
|
|
|
|
|
|
|
|
|
| 197 |
|
| 198 |
def lexical_ratio(query_norm: str, doc_norm: str, max_terms: int = 10) -> Tuple[float, str]:
|
| 199 |
q_toks = ar_tokens(query_norm)
|
|
@@ -206,38 +206,29 @@ def lexical_ratio(query_norm: str, doc_norm: str, max_terms: int = 10) -> Tuple[
|
|
| 206 |
return float(ratio), terms
|
| 207 |
|
| 208 |
def confidence_label(score: float) -> Tuple[str, str]:
|
| 209 |
-
"""
|
| 210 |
-
Simple score->label mapping.
|
| 211 |
-
Assumes cosine-like range ~[0..1] after normalization & blending.
|
| 212 |
-
"""
|
| 213 |
if score >= 0.78:
|
| 214 |
return "HIGH", "bHigh"
|
| 215 |
if score >= 0.62:
|
| 216 |
return "MED", "bMed"
|
| 217 |
return "LOW", "bLow"
|
| 218 |
|
|
|
|
|
|
|
|
|
|
|
|
|
| 219 |
def rerank_rows(
|
| 220 |
query_norm: str,
|
| 221 |
df: pd.DataFrame,
|
| 222 |
k_final: int,
|
| 223 |
) -> Tuple[pd.DataFrame, Dict[int, Dict[str, Any]]]:
|
| 224 |
-
"""
|
| 225 |
-
Rerank using segment max similarity:
|
| 226 |
-
- Split each doc to segments (short)
|
| 227 |
-
- Pick a limited set of segments (speed)
|
| 228 |
-
- One encode call for all segments
|
| 229 |
-
Returns reranked df and per-hadith evidence dict (sims/segs + prebuilt html).
|
| 230 |
-
"""
|
| 231 |
evidence: Dict[int, Dict[str, Any]] = {}
|
| 232 |
|
| 233 |
if (not RERANK_ENABLE) or df.empty:
|
| 234 |
-
# still fill basic fields
|
| 235 |
for _, row in df.iterrows():
|
| 236 |
hid = int(row["hadithID"]) if pd.notna(row.get("hadithID")) else -1
|
| 237 |
evidence[hid] = {"mode": "disabled"}
|
| 238 |
return df.head(k_final), evidence
|
| 239 |
|
| 240 |
-
# Collect segments for each candidate
|
| 241 |
cand_rows = df.copy()
|
| 242 |
|
| 243 |
per_doc_segs: List[List[str]] = []
|
|
@@ -247,21 +238,16 @@ def rerank_rows(
|
|
| 247 |
hid = int(row["hadithID"]) if pd.notna(row.get("hadithID")) else -1
|
| 248 |
doc_hids.append(hid)
|
| 249 |
|
| 250 |
-
ar = str(row.get("
|
| 251 |
-
|
| 252 |
-
|
| 253 |
-
ar_clean = ""
|
| 254 |
-
ar_clean = str(ar_clean).strip()
|
| 255 |
-
if not ar_clean:
|
| 256 |
-
ar_clean = normalize_ar(ar)
|
| 257 |
|
| 258 |
-
segs = split_ar_segments(
|
| 259 |
segs = pick_segs_for_rerank(segs, max_keep=RERANK_MAX_SEGS_PER_DOC)
|
| 260 |
-
if not segs:
|
| 261 |
-
segs = [
|
| 262 |
per_doc_segs.append(segs)
|
| 263 |
|
| 264 |
-
# Flatten
|
| 265 |
all_segs: List[str] = []
|
| 266 |
offsets: List[Tuple[int, int]] = []
|
| 267 |
cur = 0
|
|
@@ -272,21 +258,14 @@ def rerank_rows(
|
|
| 272 |
offsets.append((start, cur))
|
| 273 |
|
| 274 |
if not all_segs:
|
| 275 |
-
# fallback: no rerank
|
| 276 |
for hid in doc_hids:
|
| 277 |
evidence[hid] = {"mode": "empty"}
|
| 278 |
return cand_rows.head(k_final), evidence
|
| 279 |
|
| 280 |
-
|
| 281 |
-
|
| 282 |
-
|
| 283 |
-
["passage: " + s for s in all_segs],
|
| 284 |
-
normalize_embeddings=True
|
| 285 |
-
).astype("float32") # (N, d)
|
| 286 |
-
|
| 287 |
-
sims_all = (seg_emb @ q_emb).astype(np.float32) # (N,)
|
| 288 |
|
| 289 |
-
# Compute per-doc rerank score = max(sim)
|
| 290 |
rr_scores: List[float] = []
|
| 291 |
for hid, (start, end), segs in zip(doc_hids, offsets, per_doc_segs):
|
| 292 |
if start == end:
|
|
@@ -297,7 +276,6 @@ def rerank_rows(
|
|
| 297 |
rr = float(np.max(sims))
|
| 298 |
rr_scores.append(rr)
|
| 299 |
|
| 300 |
-
# Build evidence HTML now (no extra encode)
|
| 301 |
hm = build_heatmap_html(segs, sims, top_n=min(6, len(segs))) if sims.size else ""
|
| 302 |
best = best_seg_html(segs, sims) if sims.size else ""
|
| 303 |
evidence[hid] = {
|
|
@@ -305,14 +283,10 @@ def rerank_rows(
|
|
| 305 |
"rerank_score": rr,
|
| 306 |
"heatmap_html": hm,
|
| 307 |
"best_seg_html": best,
|
| 308 |
-
"rerank_segs": segs, # keep for debugging (can omit if you want)
|
| 309 |
-
"rerank_sims": None, # don't ship full sims to client
|
| 310 |
}
|
| 311 |
|
| 312 |
cand_rows["rerank_score"] = rr_scores
|
| 313 |
|
| 314 |
-
# Blend: score_final = (1-w)*faiss + w*rerank
|
| 315 |
-
# Both are cosine-ish in [0,1] in your setup (normalize embeddings + IP index)
|
| 316 |
faiss_scores = cand_rows["score"].astype(float).to_numpy()
|
| 317 |
rr = cand_rows["rerank_score"].astype(float).to_numpy()
|
| 318 |
|
|
@@ -325,7 +299,7 @@ def rerank_rows(
|
|
| 325 |
|
| 326 |
|
| 327 |
# =========================
|
| 328 |
-
# Full highlight for ONE hadith
|
| 329 |
# =========================
|
| 330 |
def full_highlight_html(
|
| 331 |
query_norm: str,
|
|
@@ -342,11 +316,7 @@ def full_highlight_html(
|
|
| 342 |
}
|
| 343 |
|
| 344 |
q_emb = get_query_emb(query_norm)
|
| 345 |
-
seg_emb = model.encode(
|
| 346 |
-
["passage: " + s for s in segs],
|
| 347 |
-
normalize_embeddings=True
|
| 348 |
-
).astype("float32")
|
| 349 |
-
|
| 350 |
sims = (seg_emb @ q_emb).astype(np.float32)
|
| 351 |
|
| 352 |
s_min = float(np.min(sims))
|
|
@@ -392,6 +362,13 @@ model = SentenceTransformer(MODEL_NAME)
|
|
| 392 |
index = faiss.read_index(INDEX_PATH)
|
| 393 |
meta = pd.read_parquet(META_PATH)
|
| 394 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 395 |
required_cols = {"hadithID", "collection", "hadith_number", "arabic", "english"}
|
| 396 |
missing = required_cols - set(meta.columns)
|
| 397 |
if missing:
|
|
@@ -400,6 +377,10 @@ if missing:
|
|
| 400 |
if "arabic_clean" not in meta.columns:
|
| 401 |
meta["arabic_clean"] = ""
|
| 402 |
|
|
|
|
|
|
|
|
|
|
|
|
|
| 403 |
|
| 404 |
# =========================
|
| 405 |
# FAISS Search
|
|
@@ -410,7 +391,7 @@ def semantic_search_df(query: str, top_k: int) -> pd.DataFrame:
|
|
| 410 |
return meta.iloc[0:0].copy()
|
| 411 |
|
| 412 |
top_k = max(1, min(int(top_k), MAX_TOP_K))
|
| 413 |
-
q_norm = normalize_ar(q)
|
| 414 |
|
| 415 |
q_emb = get_query_emb(q_norm).reshape(1, -1)
|
| 416 |
scores, idx = index.search(q_emb, top_k)
|
|
@@ -418,9 +399,6 @@ def semantic_search_df(query: str, top_k: int) -> pd.DataFrame:
|
|
| 418 |
res = meta.iloc[idx[0]].copy()
|
| 419 |
res["score"] = scores[0]
|
| 420 |
res = res.sort_values("score", ascending=False)
|
| 421 |
-
|
| 422 |
-
# ensure arabic
|
| 423 |
-
res["arabic"] = res["arabic"].fillna("").astype(str)
|
| 424 |
res = res[res["arabic"].str.strip() != ""]
|
| 425 |
return res
|
| 426 |
|
|
@@ -473,7 +451,7 @@ def health():
|
|
| 473 |
def search():
|
| 474 |
q = request.args.get("q", "").strip()
|
| 475 |
|
| 476 |
-
#
|
| 477 |
k_raw = request.args.get("k", str(DEFAULT_TOP_K)).strip()
|
| 478 |
try:
|
| 479 |
k = int(k_raw) if k_raw else DEFAULT_TOP_K
|
|
@@ -481,7 +459,7 @@ def search():
|
|
| 481 |
k = DEFAULT_TOP_K
|
| 482 |
k = max(1, min(k, MAX_TOP_K))
|
| 483 |
|
| 484 |
-
# rerank pool
|
| 485 |
rk_raw = request.args.get("rerank_k", str(DEFAULT_RERANK_K)).strip()
|
| 486 |
try:
|
| 487 |
rerank_k = int(rk_raw) if rk_raw else DEFAULT_RERANK_K
|
|
@@ -490,7 +468,7 @@ def search():
|
|
| 490 |
rerank_k = max(MIN_RERANK_K, min(rerank_k, MAX_RERANK_K))
|
| 491 |
rerank_k = max(rerank_k, k)
|
| 492 |
|
| 493 |
-
#
|
| 494 |
hl_raw = request.args.get("hl_topn", str(DEFAULT_HL_TOPN)).strip()
|
| 495 |
seg_raw = request.args.get("seg_maxlen", str(DEFAULT_SEG_MAXLEN)).strip()
|
| 496 |
try:
|
|
@@ -526,33 +504,27 @@ def search():
|
|
| 526 |
|
| 527 |
t0 = time.time()
|
| 528 |
|
| 529 |
-
# 1)
|
| 530 |
df_pool = semantic_search_df(q, top_k=rerank_k)
|
| 531 |
q_norm = normalize_ar(q)
|
| 532 |
|
| 533 |
-
# 2) rerank
|
| 534 |
df_final, ev = rerank_rows(query_norm=q_norm, df=df_pool, k_final=k)
|
| 535 |
|
| 536 |
took_ms = int((time.time() - t0) * 1000)
|
| 537 |
|
| 538 |
-
# Build results
|
| 539 |
results: List[Dict[str, Any]] = []
|
| 540 |
for _, row in df_final.iterrows():
|
| 541 |
hid = int(row.get("hadithID")) if pd.notna(row.get("hadithID")) else None
|
| 542 |
arabic = str(row.get("arabic", "") or "")
|
| 543 |
english = str(row.get("english", "") or "")
|
| 544 |
|
| 545 |
-
ar_clean = row.get("arabic_clean", "")
|
| 546 |
-
if ar_clean is None or (isinstance(ar_clean, float) and np.isnan(ar_clean)):
|
| 547 |
-
ar_clean = ""
|
| 548 |
-
ar_clean = str(ar_clean).strip()
|
| 549 |
if not ar_clean:
|
| 550 |
ar_clean = normalize_ar(arabic)
|
| 551 |
|
| 552 |
-
# lexical
|
| 553 |
lex_r, lex_terms = lexical_ratio(q_norm, ar_clean)
|
| 554 |
|
| 555 |
-
# scores
|
| 556 |
faiss_score = float(row.get("score")) if pd.notna(row.get("score")) else 0.0
|
| 557 |
rerank_score = float(row.get("rerank_score")) if pd.notna(row.get("rerank_score")) else faiss_score
|
| 558 |
final_score = float(row.get("final_score")) if pd.notna(row.get("final_score")) else faiss_score
|
|
@@ -567,31 +539,20 @@ def search():
|
|
| 567 |
"hadithID": hid,
|
| 568 |
"collection": str(row.get("collection", "") or ""),
|
| 569 |
"hadith_number": int(row.get("hadith_number")) if pd.notna(row.get("hadith_number")) else None,
|
| 570 |
-
|
| 571 |
-
# unified score the UI should use
|
| 572 |
"score": final_score,
|
| 573 |
-
|
| 574 |
-
# diagnostics
|
| 575 |
"faiss_score": faiss_score,
|
| 576 |
"rerank_score": rerank_score,
|
| 577 |
-
|
| 578 |
"conf_label": conf_label,
|
| 579 |
"conf_class": conf_class,
|
| 580 |
-
|
| 581 |
"lex_ratio": float(lex_r),
|
| 582 |
"lex_terms": lex_terms,
|
| 583 |
-
|
| 584 |
"arabic": arabic,
|
| 585 |
"arabic_clean": ar_clean,
|
| 586 |
"english": english,
|
| 587 |
-
|
| 588 |
-
# Provide evidence html even in json (cheap: already computed in rerank)
|
| 589 |
"heatmap_html": heatmap_html,
|
| 590 |
"best_seg_html": best_html,
|
| 591 |
}
|
| 592 |
|
| 593 |
-
# If the caller asked for html AND did not disable highlight, also compute full highlight for each result.
|
| 594 |
-
# This is heavier. Recommended: keep hl_topn=0 for fast mode and use /highlight on click.
|
| 595 |
if want_html and hl_topn > 0:
|
| 596 |
extras = full_highlight_html(
|
| 597 |
query_norm=q_norm,
|
|
@@ -600,7 +561,6 @@ def search():
|
|
| 600 |
seg_maxlen=seg_maxlen,
|
| 601 |
)
|
| 602 |
r["arabic_clean_html"] = extras["arabic_clean_html"]
|
| 603 |
-
# You can overwrite with full-doc ones (optional):
|
| 604 |
r["heatmap_html"] = extras["heatmap_html"] or r["heatmap_html"]
|
| 605 |
r["best_seg_html"] = extras["best_seg_html"] or r["best_seg_html"]
|
| 606 |
|
|
@@ -624,10 +584,6 @@ def search():
|
|
| 624 |
|
| 625 |
@app.get("/highlight")
|
| 626 |
def highlight():
|
| 627 |
-
"""
|
| 628 |
-
Highlight a single hadith on-demand (for fast UI).
|
| 629 |
-
GET /highlight?q=...&hadithID=123&format=html&hl_topn=6&seg_maxlen=220
|
| 630 |
-
"""
|
| 631 |
q = request.args.get("q", "").strip()
|
| 632 |
hid_raw = request.args.get("hadithID", "").strip()
|
| 633 |
|
|
@@ -666,14 +622,10 @@ def highlight():
|
|
| 666 |
arabic = str(row.get("arabic", "") or "")
|
| 667 |
english = str(row.get("english", "") or "")
|
| 668 |
|
| 669 |
-
ar_clean = row.get("arabic_clean", "")
|
| 670 |
-
if ar_clean is None or (isinstance(ar_clean, float) and np.isnan(ar_clean)):
|
| 671 |
-
ar_clean = ""
|
| 672 |
-
ar_clean = str(ar_clean).strip()
|
| 673 |
if not ar_clean:
|
| 674 |
ar_clean = normalize_ar(arabic)
|
| 675 |
|
| 676 |
-
# Always produce evidence + highlight here (one doc only)
|
| 677 |
extras = full_highlight_html(
|
| 678 |
query_norm=q_norm,
|
| 679 |
arabic_clean_text=ar_clean,
|
|
@@ -681,7 +633,6 @@ def highlight():
|
|
| 681 |
seg_maxlen=seg_maxlen,
|
| 682 |
)
|
| 683 |
|
| 684 |
-
# lexical
|
| 685 |
lex_r, lex_terms = lexical_ratio(q_norm, ar_clean)
|
| 686 |
|
| 687 |
return jsonify({
|
|
@@ -692,14 +643,11 @@ def highlight():
|
|
| 692 |
"format": "html" if want_html else "json",
|
| 693 |
"hl_topn": hl_topn,
|
| 694 |
"seg_maxlen": seg_maxlen,
|
| 695 |
-
|
| 696 |
"lex_ratio": float(lex_r),
|
| 697 |
"lex_terms": lex_terms,
|
| 698 |
-
|
| 699 |
"arabic": arabic,
|
| 700 |
"arabic_clean": ar_clean,
|
| 701 |
"english": english,
|
| 702 |
-
|
| 703 |
"arabic_clean_html": extras.get("arabic_clean_html", "") if want_html else "",
|
| 704 |
"heatmap_html": extras.get("heatmap_html", ""),
|
| 705 |
"best_seg_html": extras.get("best_seg_html", ""),
|
|
@@ -707,5 +655,6 @@ def highlight():
|
|
| 707 |
|
| 708 |
|
| 709 |
if __name__ == "__main__":
|
| 710 |
-
#
|
| 711 |
-
|
|
|
|
|
|
| 3 |
import os
|
| 4 |
import re
|
| 5 |
import time
|
|
|
|
| 6 |
from functools import lru_cache
|
| 7 |
+
from typing import List, Dict, Any, Tuple
|
| 8 |
|
| 9 |
import numpy as np
|
| 10 |
import pandas as pd
|
| 11 |
import faiss
|
|
|
|
| 12 |
from flask import Flask, request, jsonify, Response
|
| 13 |
from sentence_transformers import SentenceTransformer
|
| 14 |
|
| 15 |
|
| 16 |
# =========================
|
| 17 |
+
# Config (HF Space defaults)
|
| 18 |
# =========================
|
| 19 |
+
INDEX_PATH = os.getenv("HADITH_INDEX_PATH", "hadith_semantic.faiss")
|
| 20 |
META_PATH = os.getenv("HADITH_META_PATH", "hadith_meta.parquet")
|
| 21 |
+
|
| 22 |
+
# Small/fast multilingual model (good on free CPU)
|
| 23 |
+
MODEL_NAME = os.getenv(
|
| 24 |
+
"HADITH_MODEL_NAME",
|
| 25 |
+
"sentence-transformers/paraphrase-multilingual-MiniLM-L12-v2"
|
| 26 |
+
)
|
| 27 |
|
| 28 |
DEFAULT_TOP_K = 10
|
| 29 |
MAX_TOP_K = 50
|
|
|
|
| 32 |
MAX_RERANK_K = 120
|
| 33 |
MIN_RERANK_K = 10
|
| 34 |
|
| 35 |
+
DEFAULT_HL_TOPN = 6
|
| 36 |
MAX_HL_TOPN = 25
|
| 37 |
|
| 38 |
DEFAULT_SEG_MAXLEN = 220
|
| 39 |
MAX_SEG_MAXLEN = 420
|
| 40 |
MIN_SEG_MAXLEN = 120
|
| 41 |
|
| 42 |
+
# Rerank knobs (keep small for HF free CPU)
|
| 43 |
+
RERANK_MAX_SEGS_PER_DOC = int(os.getenv("RERANK_MAX_SEGS_PER_DOC", "8"))
|
| 44 |
+
RERANK_SEG_MAXLEN = int(os.getenv("RERANK_SEG_MAXLEN", "240"))
|
| 45 |
+
RERANK_WEIGHT = float(os.getenv("RERANK_WEIGHT", "0.65"))
|
| 46 |
RERANK_ENABLE = os.getenv("RERANK_ENABLE", "1").strip() != "0"
|
| 47 |
|
| 48 |
# CORS
|
| 49 |
+
CORS_ALLOW_ORIGIN = os.getenv("CORS_ALLOW_ORIGIN", "*")
|
| 50 |
|
| 51 |
|
| 52 |
# =========================
|
|
|
|
| 78 |
t = normalize_ar(text)
|
| 79 |
t = _AR_PUNCT.sub(" ", t)
|
| 80 |
toks = [x.strip() for x in t.split() if x.strip()]
|
|
|
|
| 81 |
toks = [x for x in toks if len(x) >= 2]
|
| 82 |
return toks
|
| 83 |
|
|
|
|
| 119 |
if buf:
|
| 120 |
segs.append(buf)
|
| 121 |
|
|
|
|
| 122 |
if len(segs) <= 1 and len(t) > max_len:
|
| 123 |
segs = [t[i:i+max_len].strip() for i in range(0, len(t), max_len) if t[i:i+max_len].strip()]
|
| 124 |
return segs
|
| 125 |
|
| 126 |
def pick_segs_for_rerank(segs: List[str], max_keep: int) -> List[str]:
|
|
|
|
| 127 |
if len(segs) <= max_keep:
|
| 128 |
return segs
|
|
|
|
| 129 |
idxs = np.linspace(0, len(segs) - 1, num=max_keep)
|
| 130 |
idxs = [int(round(x)) for x in idxs]
|
|
|
|
| 131 |
seen = set()
|
| 132 |
out = []
|
| 133 |
for i in idxs:
|
|
|
|
| 139 |
|
| 140 |
# =========================
|
| 141 |
# Embedding helpers (cached)
|
| 142 |
+
# IMPORTANT: This model does NOT use "query:" / "passage:" prefixes.
|
| 143 |
# =========================
|
| 144 |
@lru_cache(maxsize=2048)
|
| 145 |
def cached_query_emb(query_norm: str) -> bytes:
|
| 146 |
+
emb = model.encode([query_norm], normalize_embeddings=True).astype("float32")[0]
|
| 147 |
return emb.tobytes()
|
| 148 |
|
| 149 |
def get_query_emb(query_norm: str) -> np.ndarray:
|
|
|
|
| 151 |
|
| 152 |
|
| 153 |
# =========================
|
| 154 |
+
# Evidence HTML
|
| 155 |
# =========================
|
| 156 |
def build_heatmap_html(segs: List[str], sims: np.ndarray, top_n: int = 6) -> str:
|
|
|
|
| 157 |
if not segs or sims.size == 0:
|
| 158 |
return ""
|
| 159 |
|
|
|
|
| 164 |
s_max = float(np.max(sims))
|
| 165 |
denom = (s_max - s_min) if (s_max - s_min) > 1e-6 else 1.0
|
| 166 |
|
|
|
|
| 167 |
order = np.argsort(-sims)
|
| 168 |
keep = set(order[:top_n])
|
| 169 |
|
| 170 |
blocks = []
|
| 171 |
for i in range(n):
|
| 172 |
+
w = (float(sims[i]) - s_min) / denom
|
|
|
|
| 173 |
alpha = (0.20 + 0.60 * w) if i in keep else (0.08 + 0.18 * w)
|
| 174 |
alpha = max(0.06, min(alpha, 0.85))
|
| 175 |
blocks.append(
|
|
|
|
| 189 |
if not segs or sims.size == 0:
|
| 190 |
return ""
|
| 191 |
i = int(np.argmax(sims))
|
| 192 |
+
return (
|
| 193 |
+
'<span style="background:rgba(255,230,120,0.55);'
|
| 194 |
+
'border:1px solid rgba(234,179,8,0.35);border-radius:12px;padding:3px 8px;display:inline;">'
|
| 195 |
+
f'{escape_html(segs[i])}</span>'
|
| 196 |
+
)
|
| 197 |
|
| 198 |
def lexical_ratio(query_norm: str, doc_norm: str, max_terms: int = 10) -> Tuple[float, str]:
|
| 199 |
q_toks = ar_tokens(query_norm)
|
|
|
|
| 206 |
return float(ratio), terms
|
| 207 |
|
| 208 |
def confidence_label(score: float) -> Tuple[str, str]:
|
|
|
|
|
|
|
|
|
|
|
|
|
| 209 |
if score >= 0.78:
|
| 210 |
return "HIGH", "bHigh"
|
| 211 |
if score >= 0.62:
|
| 212 |
return "MED", "bMed"
|
| 213 |
return "LOW", "bLow"
|
| 214 |
|
| 215 |
+
|
| 216 |
+
# =========================
|
| 217 |
+
# Rerank
|
| 218 |
+
# =========================
|
| 219 |
def rerank_rows(
|
| 220 |
query_norm: str,
|
| 221 |
df: pd.DataFrame,
|
| 222 |
k_final: int,
|
| 223 |
) -> Tuple[pd.DataFrame, Dict[int, Dict[str, Any]]]:
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 224 |
evidence: Dict[int, Dict[str, Any]] = {}
|
| 225 |
|
| 226 |
if (not RERANK_ENABLE) or df.empty:
|
|
|
|
| 227 |
for _, row in df.iterrows():
|
| 228 |
hid = int(row["hadithID"]) if pd.notna(row.get("hadithID")) else -1
|
| 229 |
evidence[hid] = {"mode": "disabled"}
|
| 230 |
return df.head(k_final), evidence
|
| 231 |
|
|
|
|
| 232 |
cand_rows = df.copy()
|
| 233 |
|
| 234 |
per_doc_segs: List[List[str]] = []
|
|
|
|
| 238 |
hid = int(row["hadithID"]) if pd.notna(row.get("hadithID")) else -1
|
| 239 |
doc_hids.append(hid)
|
| 240 |
|
| 241 |
+
ar = str(row.get("arabic_clean", "") or "").strip()
|
| 242 |
+
if not ar:
|
| 243 |
+
ar = normalize_ar(str(row.get("arabic", "") or ""))
|
|
|
|
|
|
|
|
|
|
|
|
|
| 244 |
|
| 245 |
+
segs = split_ar_segments(ar, max_len=RERANK_SEG_MAXLEN)
|
| 246 |
segs = pick_segs_for_rerank(segs, max_keep=RERANK_MAX_SEGS_PER_DOC)
|
| 247 |
+
if not segs and ar:
|
| 248 |
+
segs = [ar[:RERANK_SEG_MAXLEN]]
|
| 249 |
per_doc_segs.append(segs)
|
| 250 |
|
|
|
|
| 251 |
all_segs: List[str] = []
|
| 252 |
offsets: List[Tuple[int, int]] = []
|
| 253 |
cur = 0
|
|
|
|
| 258 |
offsets.append((start, cur))
|
| 259 |
|
| 260 |
if not all_segs:
|
|
|
|
| 261 |
for hid in doc_hids:
|
| 262 |
evidence[hid] = {"mode": "empty"}
|
| 263 |
return cand_rows.head(k_final), evidence
|
| 264 |
|
| 265 |
+
q_emb = get_query_emb(query_norm)
|
| 266 |
+
seg_emb = model.encode(all_segs, normalize_embeddings=True).astype("float32")
|
| 267 |
+
sims_all = (seg_emb @ q_emb).astype(np.float32)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 268 |
|
|
|
|
| 269 |
rr_scores: List[float] = []
|
| 270 |
for hid, (start, end), segs in zip(doc_hids, offsets, per_doc_segs):
|
| 271 |
if start == end:
|
|
|
|
| 276 |
rr = float(np.max(sims))
|
| 277 |
rr_scores.append(rr)
|
| 278 |
|
|
|
|
| 279 |
hm = build_heatmap_html(segs, sims, top_n=min(6, len(segs))) if sims.size else ""
|
| 280 |
best = best_seg_html(segs, sims) if sims.size else ""
|
| 281 |
evidence[hid] = {
|
|
|
|
| 283 |
"rerank_score": rr,
|
| 284 |
"heatmap_html": hm,
|
| 285 |
"best_seg_html": best,
|
|
|
|
|
|
|
| 286 |
}
|
| 287 |
|
| 288 |
cand_rows["rerank_score"] = rr_scores
|
| 289 |
|
|
|
|
|
|
|
| 290 |
faiss_scores = cand_rows["score"].astype(float).to_numpy()
|
| 291 |
rr = cand_rows["rerank_score"].astype(float).to_numpy()
|
| 292 |
|
|
|
|
| 299 |
|
| 300 |
|
| 301 |
# =========================
|
| 302 |
+
# Full highlight for ONE hadith
|
| 303 |
# =========================
|
| 304 |
def full_highlight_html(
|
| 305 |
query_norm: str,
|
|
|
|
| 316 |
}
|
| 317 |
|
| 318 |
q_emb = get_query_emb(query_norm)
|
| 319 |
+
seg_emb = model.encode(segs, normalize_embeddings=True).astype("float32")
|
|
|
|
|
|
|
|
|
|
|
|
|
| 320 |
sims = (seg_emb @ q_emb).astype(np.float32)
|
| 321 |
|
| 322 |
s_min = float(np.min(sims))
|
|
|
|
| 362 |
index = faiss.read_index(INDEX_PATH)
|
| 363 |
meta = pd.read_parquet(META_PATH)
|
| 364 |
|
| 365 |
+
# Accept corpusID or hadithID, normalize to hadithID
|
| 366 |
+
id_col = "hadithID" if "hadithID" in meta.columns else ("corpusID" if "corpusID" in meta.columns else None)
|
| 367 |
+
if id_col is None:
|
| 368 |
+
raise ValueError("Meta must contain 'hadithID' or 'corpusID'")
|
| 369 |
+
if id_col != "hadithID":
|
| 370 |
+
meta = meta.rename(columns={id_col: "hadithID"})
|
| 371 |
+
|
| 372 |
required_cols = {"hadithID", "collection", "hadith_number", "arabic", "english"}
|
| 373 |
missing = required_cols - set(meta.columns)
|
| 374 |
if missing:
|
|
|
|
| 377 |
if "arabic_clean" not in meta.columns:
|
| 378 |
meta["arabic_clean"] = ""
|
| 379 |
|
| 380 |
+
meta["arabic"] = meta["arabic"].fillna("").astype(str)
|
| 381 |
+
meta["english"] = meta["english"].fillna("").astype(str)
|
| 382 |
+
meta["arabic_clean"] = meta["arabic_clean"].fillna("").astype(str)
|
| 383 |
+
|
| 384 |
|
| 385 |
# =========================
|
| 386 |
# FAISS Search
|
|
|
|
| 391 |
return meta.iloc[0:0].copy()
|
| 392 |
|
| 393 |
top_k = max(1, min(int(top_k), MAX_TOP_K))
|
| 394 |
+
q_norm = normalize_ar(q) # Arabic normalize, safe for English too
|
| 395 |
|
| 396 |
q_emb = get_query_emb(q_norm).reshape(1, -1)
|
| 397 |
scores, idx = index.search(q_emb, top_k)
|
|
|
|
| 399 |
res = meta.iloc[idx[0]].copy()
|
| 400 |
res["score"] = scores[0]
|
| 401 |
res = res.sort_values("score", ascending=False)
|
|
|
|
|
|
|
|
|
|
| 402 |
res = res[res["arabic"].str.strip() != ""]
|
| 403 |
return res
|
| 404 |
|
|
|
|
| 451 |
def search():
|
| 452 |
q = request.args.get("q", "").strip()
|
| 453 |
|
| 454 |
+
# final top-k
|
| 455 |
k_raw = request.args.get("k", str(DEFAULT_TOP_K)).strip()
|
| 456 |
try:
|
| 457 |
k = int(k_raw) if k_raw else DEFAULT_TOP_K
|
|
|
|
| 459 |
k = DEFAULT_TOP_K
|
| 460 |
k = max(1, min(k, MAX_TOP_K))
|
| 461 |
|
| 462 |
+
# rerank pool size
|
| 463 |
rk_raw = request.args.get("rerank_k", str(DEFAULT_RERANK_K)).strip()
|
| 464 |
try:
|
| 465 |
rerank_k = int(rk_raw) if rk_raw else DEFAULT_RERANK_K
|
|
|
|
| 468 |
rerank_k = max(MIN_RERANK_K, min(rerank_k, MAX_RERANK_K))
|
| 469 |
rerank_k = max(rerank_k, k)
|
| 470 |
|
| 471 |
+
# highlight controls
|
| 472 |
hl_raw = request.args.get("hl_topn", str(DEFAULT_HL_TOPN)).strip()
|
| 473 |
seg_raw = request.args.get("seg_maxlen", str(DEFAULT_SEG_MAXLEN)).strip()
|
| 474 |
try:
|
|
|
|
| 504 |
|
| 505 |
t0 = time.time()
|
| 506 |
|
| 507 |
+
# 1) retrieve pool
|
| 508 |
df_pool = semantic_search_df(q, top_k=rerank_k)
|
| 509 |
q_norm = normalize_ar(q)
|
| 510 |
|
| 511 |
+
# 2) rerank -> final
|
| 512 |
df_final, ev = rerank_rows(query_norm=q_norm, df=df_pool, k_final=k)
|
| 513 |
|
| 514 |
took_ms = int((time.time() - t0) * 1000)
|
| 515 |
|
|
|
|
| 516 |
results: List[Dict[str, Any]] = []
|
| 517 |
for _, row in df_final.iterrows():
|
| 518 |
hid = int(row.get("hadithID")) if pd.notna(row.get("hadithID")) else None
|
| 519 |
arabic = str(row.get("arabic", "") or "")
|
| 520 |
english = str(row.get("english", "") or "")
|
| 521 |
|
| 522 |
+
ar_clean = str(row.get("arabic_clean", "") or "").strip()
|
|
|
|
|
|
|
|
|
|
| 523 |
if not ar_clean:
|
| 524 |
ar_clean = normalize_ar(arabic)
|
| 525 |
|
|
|
|
| 526 |
lex_r, lex_terms = lexical_ratio(q_norm, ar_clean)
|
| 527 |
|
|
|
|
| 528 |
faiss_score = float(row.get("score")) if pd.notna(row.get("score")) else 0.0
|
| 529 |
rerank_score = float(row.get("rerank_score")) if pd.notna(row.get("rerank_score")) else faiss_score
|
| 530 |
final_score = float(row.get("final_score")) if pd.notna(row.get("final_score")) else faiss_score
|
|
|
|
| 539 |
"hadithID": hid,
|
| 540 |
"collection": str(row.get("collection", "") or ""),
|
| 541 |
"hadith_number": int(row.get("hadith_number")) if pd.notna(row.get("hadith_number")) else None,
|
|
|
|
|
|
|
| 542 |
"score": final_score,
|
|
|
|
|
|
|
| 543 |
"faiss_score": faiss_score,
|
| 544 |
"rerank_score": rerank_score,
|
|
|
|
| 545 |
"conf_label": conf_label,
|
| 546 |
"conf_class": conf_class,
|
|
|
|
| 547 |
"lex_ratio": float(lex_r),
|
| 548 |
"lex_terms": lex_terms,
|
|
|
|
| 549 |
"arabic": arabic,
|
| 550 |
"arabic_clean": ar_clean,
|
| 551 |
"english": english,
|
|
|
|
|
|
|
| 552 |
"heatmap_html": heatmap_html,
|
| 553 |
"best_seg_html": best_html,
|
| 554 |
}
|
| 555 |
|
|
|
|
|
|
|
| 556 |
if want_html and hl_topn > 0:
|
| 557 |
extras = full_highlight_html(
|
| 558 |
query_norm=q_norm,
|
|
|
|
| 561 |
seg_maxlen=seg_maxlen,
|
| 562 |
)
|
| 563 |
r["arabic_clean_html"] = extras["arabic_clean_html"]
|
|
|
|
| 564 |
r["heatmap_html"] = extras["heatmap_html"] or r["heatmap_html"]
|
| 565 |
r["best_seg_html"] = extras["best_seg_html"] or r["best_seg_html"]
|
| 566 |
|
|
|
|
| 584 |
|
| 585 |
@app.get("/highlight")
|
| 586 |
def highlight():
|
|
|
|
|
|
|
|
|
|
|
|
|
| 587 |
q = request.args.get("q", "").strip()
|
| 588 |
hid_raw = request.args.get("hadithID", "").strip()
|
| 589 |
|
|
|
|
| 622 |
arabic = str(row.get("arabic", "") or "")
|
| 623 |
english = str(row.get("english", "") or "")
|
| 624 |
|
| 625 |
+
ar_clean = str(row.get("arabic_clean", "") or "").strip()
|
|
|
|
|
|
|
|
|
|
| 626 |
if not ar_clean:
|
| 627 |
ar_clean = normalize_ar(arabic)
|
| 628 |
|
|
|
|
| 629 |
extras = full_highlight_html(
|
| 630 |
query_norm=q_norm,
|
| 631 |
arabic_clean_text=ar_clean,
|
|
|
|
| 633 |
seg_maxlen=seg_maxlen,
|
| 634 |
)
|
| 635 |
|
|
|
|
| 636 |
lex_r, lex_terms = lexical_ratio(q_norm, ar_clean)
|
| 637 |
|
| 638 |
return jsonify({
|
|
|
|
| 643 |
"format": "html" if want_html else "json",
|
| 644 |
"hl_topn": hl_topn,
|
| 645 |
"seg_maxlen": seg_maxlen,
|
|
|
|
| 646 |
"lex_ratio": float(lex_r),
|
| 647 |
"lex_terms": lex_terms,
|
|
|
|
| 648 |
"arabic": arabic,
|
| 649 |
"arabic_clean": ar_clean,
|
| 650 |
"english": english,
|
|
|
|
| 651 |
"arabic_clean_html": extras.get("arabic_clean_html", "") if want_html else "",
|
| 652 |
"heatmap_html": extras.get("heatmap_html", ""),
|
| 653 |
"best_seg_html": extras.get("best_seg_html", ""),
|
|
|
|
| 655 |
|
| 656 |
|
| 657 |
if __name__ == "__main__":
|
| 658 |
+
# Hugging Face Spaces uses PORT=7860
|
| 659 |
+
port = int(os.getenv("PORT", "7860"))
|
| 660 |
+
app.run(host="0.0.0.0", port=port, debug=False)
|