Alshargi committed on
Commit
397688f
·
verified ·
1 Parent(s): 90c65a7

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +61 -112
app.py CHANGED
@@ -3,24 +3,27 @@ from __future__ import annotations
3
  import os
4
  import re
5
  import time
6
- import math
7
  from functools import lru_cache
8
- from typing import List, Dict, Any, Tuple, Optional
9
 
10
  import numpy as np
11
  import pandas as pd
12
  import faiss
13
-
14
  from flask import Flask, request, jsonify, Response
15
  from sentence_transformers import SentenceTransformer
16
 
17
 
18
  # =========================
19
- # Config
20
  # =========================
21
- INDEX_PATH = os.getenv("HADITH_INDEX_PATH", "hadith_ar.faiss")
22
  META_PATH = os.getenv("HADITH_META_PATH", "hadith_meta.parquet")
23
- MODEL_NAME = os.getenv("HADITH_MODEL_NAME", "intfloat/multilingual-e5-base")
 
 
 
 
 
24
 
25
  DEFAULT_TOP_K = 10
26
  MAX_TOP_K = 50
@@ -29,21 +32,21 @@ DEFAULT_RERANK_K = 35
29
  MAX_RERANK_K = 120
30
  MIN_RERANK_K = 10
31
 
32
- DEFAULT_HL_TOPN = 6 # for /highlight and html responses
33
  MAX_HL_TOPN = 25
34
 
35
  DEFAULT_SEG_MAXLEN = 220
36
  MAX_SEG_MAXLEN = 420
37
  MIN_SEG_MAXLEN = 120
38
 
39
- # Rerank speed/quality knobs (safe defaults)
40
- RERANK_MAX_SEGS_PER_DOC = int(os.getenv("RERANK_MAX_SEGS_PER_DOC", "10")) # keep it small for speed
41
- RERANK_SEG_MAXLEN = int(os.getenv("RERANK_SEG_MAXLEN", "240")) # segment length during rerank
42
- RERANK_WEIGHT = float(os.getenv("RERANK_WEIGHT", "0.65")) # 0..1 combine rerank with faiss
43
  RERANK_ENABLE = os.getenv("RERANK_ENABLE", "1").strip() != "0"
44
 
45
  # CORS
46
- CORS_ALLOW_ORIGIN = os.getenv("CORS_ALLOW_ORIGIN", "*") # set to your domain if you want strict
47
 
48
 
49
  # =========================
@@ -75,7 +78,6 @@ def ar_tokens(text: str) -> List[str]:
75
  t = normalize_ar(text)
76
  t = _AR_PUNCT.sub(" ", t)
77
  toks = [x.strip() for x in t.split() if x.strip()]
78
- # remove super short tokens
79
  toks = [x for x in toks if len(x) >= 2]
80
  return toks
81
 
@@ -117,19 +119,15 @@ def split_ar_segments(text: str, max_len: int) -> List[str]:
117
  if buf:
118
  segs.append(buf)
119
 
120
- # fallback chunking
121
  if len(segs) <= 1 and len(t) > max_len:
122
  segs = [t[i:i+max_len].strip() for i in range(0, len(t), max_len) if t[i:i+max_len].strip()]
123
  return segs
124
 
125
  def pick_segs_for_rerank(segs: List[str], max_keep: int) -> List[str]:
126
- """Pick up to max_keep segments spread out (for speed)."""
127
  if len(segs) <= max_keep:
128
  return segs
129
- # spread indices evenly
130
  idxs = np.linspace(0, len(segs) - 1, num=max_keep)
131
  idxs = [int(round(x)) for x in idxs]
132
- # unique preserve order
133
  seen = set()
134
  out = []
135
  for i in idxs:
@@ -141,10 +139,11 @@ def pick_segs_for_rerank(segs: List[str], max_keep: int) -> List[str]:
141
 
142
  # =========================
143
  # Embedding helpers (cached)
 
144
  # =========================
145
  @lru_cache(maxsize=2048)
146
  def cached_query_emb(query_norm: str) -> bytes:
147
- emb = model.encode(["query: " + query_norm], normalize_embeddings=True).astype("float32")[0]
148
  return emb.tobytes()
149
 
150
  def get_query_emb(query_norm: str) -> np.ndarray:
@@ -152,10 +151,9 @@ def get_query_emb(query_norm: str) -> np.ndarray:
152
 
153
 
154
  # =========================
155
- # Rerank + evidence HTML (no extra encode)
156
  # =========================
157
  def build_heatmap_html(segs: List[str], sims: np.ndarray, top_n: int = 6) -> str:
158
- """Small bar-like heatmap using segment similarity (already computed)."""
159
  if not segs or sims.size == 0:
160
  return ""
161
 
@@ -166,14 +164,12 @@ def build_heatmap_html(segs: List[str], sims: np.ndarray, top_n: int = 6) -> str
166
  s_max = float(np.max(sims))
167
  denom = (s_max - s_min) if (s_max - s_min) > 1e-6 else 1.0
168
 
169
- # choose top indices
170
  order = np.argsort(-sims)
171
  keep = set(order[:top_n])
172
 
173
  blocks = []
174
  for i in range(n):
175
- w = (float(sims[i]) - s_min) / denom # 0..1
176
- # stronger for top segments
177
  alpha = (0.20 + 0.60 * w) if i in keep else (0.08 + 0.18 * w)
178
  alpha = max(0.06, min(alpha, 0.85))
179
  blocks.append(
@@ -193,7 +189,11 @@ def best_seg_html(segs: List[str], sims: np.ndarray) -> str:
193
  if not segs or sims.size == 0:
194
  return ""
195
  i = int(np.argmax(sims))
196
- return f'<span style="background:rgba(255,230,120,0.55);border:1px solid rgba(234,179,8,0.35);border-radius:12px;padding:3px 8px;display:inline;">{escape_html(segs[i])}</span>'
 
 
 
 
197
 
198
  def lexical_ratio(query_norm: str, doc_norm: str, max_terms: int = 10) -> Tuple[float, str]:
199
  q_toks = ar_tokens(query_norm)
@@ -206,38 +206,29 @@ def lexical_ratio(query_norm: str, doc_norm: str, max_terms: int = 10) -> Tuple[
206
  return float(ratio), terms
207
 
208
  def confidence_label(score: float) -> Tuple[str, str]:
209
- """
210
- Simple score->label mapping.
211
- Assumes cosine-like range ~[0..1] after normalization & blending.
212
- """
213
  if score >= 0.78:
214
  return "HIGH", "bHigh"
215
  if score >= 0.62:
216
  return "MED", "bMed"
217
  return "LOW", "bLow"
218
 
 
 
 
 
219
  def rerank_rows(
220
  query_norm: str,
221
  df: pd.DataFrame,
222
  k_final: int,
223
  ) -> Tuple[pd.DataFrame, Dict[int, Dict[str, Any]]]:
224
- """
225
- Rerank using segment max similarity:
226
- - Split each doc to segments (short)
227
- - Pick a limited set of segments (speed)
228
- - One encode call for all segments
229
- Returns reranked df and per-hadith evidence dict (sims/segs + prebuilt html).
230
- """
231
  evidence: Dict[int, Dict[str, Any]] = {}
232
 
233
  if (not RERANK_ENABLE) or df.empty:
234
- # still fill basic fields
235
  for _, row in df.iterrows():
236
  hid = int(row["hadithID"]) if pd.notna(row.get("hadithID")) else -1
237
  evidence[hid] = {"mode": "disabled"}
238
  return df.head(k_final), evidence
239
 
240
- # Collect segments for each candidate
241
  cand_rows = df.copy()
242
 
243
  per_doc_segs: List[List[str]] = []
@@ -247,21 +238,16 @@ def rerank_rows(
247
  hid = int(row["hadithID"]) if pd.notna(row.get("hadithID")) else -1
248
  doc_hids.append(hid)
249
 
250
- ar = str(row.get("arabic", "") or "")
251
- ar_clean = row.get("arabic_clean", "")
252
- if ar_clean is None or (isinstance(ar_clean, float) and np.isnan(ar_clean)):
253
- ar_clean = ""
254
- ar_clean = str(ar_clean).strip()
255
- if not ar_clean:
256
- ar_clean = normalize_ar(ar)
257
 
258
- segs = split_ar_segments(ar_clean, max_len=RERANK_SEG_MAXLEN)
259
  segs = pick_segs_for_rerank(segs, max_keep=RERANK_MAX_SEGS_PER_DOC)
260
- if not segs:
261
- segs = [ar_clean[:RERANK_SEG_MAXLEN]] if ar_clean else []
262
  per_doc_segs.append(segs)
263
 
264
- # Flatten
265
  all_segs: List[str] = []
266
  offsets: List[Tuple[int, int]] = []
267
  cur = 0
@@ -272,21 +258,14 @@ def rerank_rows(
272
  offsets.append((start, cur))
273
 
274
  if not all_segs:
275
- # fallback: no rerank
276
  for hid in doc_hids:
277
  evidence[hid] = {"mode": "empty"}
278
  return cand_rows.head(k_final), evidence
279
 
280
- # Encode query once + all segments once
281
- q_emb = get_query_emb(query_norm) # (d,)
282
- seg_emb = model.encode(
283
- ["passage: " + s for s in all_segs],
284
- normalize_embeddings=True
285
- ).astype("float32") # (N, d)
286
-
287
- sims_all = (seg_emb @ q_emb).astype(np.float32) # (N,)
288
 
289
- # Compute per-doc rerank score = max(sim)
290
  rr_scores: List[float] = []
291
  for hid, (start, end), segs in zip(doc_hids, offsets, per_doc_segs):
292
  if start == end:
@@ -297,7 +276,6 @@ def rerank_rows(
297
  rr = float(np.max(sims))
298
  rr_scores.append(rr)
299
 
300
- # Build evidence HTML now (no extra encode)
301
  hm = build_heatmap_html(segs, sims, top_n=min(6, len(segs))) if sims.size else ""
302
  best = best_seg_html(segs, sims) if sims.size else ""
303
  evidence[hid] = {
@@ -305,14 +283,10 @@ def rerank_rows(
305
  "rerank_score": rr,
306
  "heatmap_html": hm,
307
  "best_seg_html": best,
308
- "rerank_segs": segs, # keep for debugging (can omit if you want)
309
- "rerank_sims": None, # don't ship full sims to client
310
  }
311
 
312
  cand_rows["rerank_score"] = rr_scores
313
 
314
- # Blend: score_final = (1-w)*faiss + w*rerank
315
- # Both are cosine-ish in [0,1] in your setup (normalize embeddings + IP index)
316
  faiss_scores = cand_rows["score"].astype(float).to_numpy()
317
  rr = cand_rows["rerank_score"].astype(float).to_numpy()
318
 
@@ -325,7 +299,7 @@ def rerank_rows(
325
 
326
 
327
  # =========================
328
- # Full highlight for ONE hadith (on click)
329
  # =========================
330
  def full_highlight_html(
331
  query_norm: str,
@@ -342,11 +316,7 @@ def full_highlight_html(
342
  }
343
 
344
  q_emb = get_query_emb(query_norm)
345
- seg_emb = model.encode(
346
- ["passage: " + s for s in segs],
347
- normalize_embeddings=True
348
- ).astype("float32")
349
-
350
  sims = (seg_emb @ q_emb).astype(np.float32)
351
 
352
  s_min = float(np.min(sims))
@@ -392,6 +362,13 @@ model = SentenceTransformer(MODEL_NAME)
392
  index = faiss.read_index(INDEX_PATH)
393
  meta = pd.read_parquet(META_PATH)
394
 
 
 
 
 
 
 
 
395
  required_cols = {"hadithID", "collection", "hadith_number", "arabic", "english"}
396
  missing = required_cols - set(meta.columns)
397
  if missing:
@@ -400,6 +377,10 @@ if missing:
400
  if "arabic_clean" not in meta.columns:
401
  meta["arabic_clean"] = ""
402
 
 
 
 
 
403
 
404
  # =========================
405
  # FAISS Search
@@ -410,7 +391,7 @@ def semantic_search_df(query: str, top_k: int) -> pd.DataFrame:
410
  return meta.iloc[0:0].copy()
411
 
412
  top_k = max(1, min(int(top_k), MAX_TOP_K))
413
- q_norm = normalize_ar(q)
414
 
415
  q_emb = get_query_emb(q_norm).reshape(1, -1)
416
  scores, idx = index.search(q_emb, top_k)
@@ -418,9 +399,6 @@ def semantic_search_df(query: str, top_k: int) -> pd.DataFrame:
418
  res = meta.iloc[idx[0]].copy()
419
  res["score"] = scores[0]
420
  res = res.sort_values("score", ascending=False)
421
-
422
- # ensure arabic
423
- res["arabic"] = res["arabic"].fillna("").astype(str)
424
  res = res[res["arabic"].str.strip() != ""]
425
  return res
426
 
@@ -473,7 +451,7 @@ def health():
473
  def search():
474
  q = request.args.get("q", "").strip()
475
 
476
- # TopK final
477
  k_raw = request.args.get("k", str(DEFAULT_TOP_K)).strip()
478
  try:
479
  k = int(k_raw) if k_raw else DEFAULT_TOP_K
@@ -481,7 +459,7 @@ def search():
481
  k = DEFAULT_TOP_K
482
  k = max(1, min(k, MAX_TOP_K))
483
 
484
- # rerank pool
485
  rk_raw = request.args.get("rerank_k", str(DEFAULT_RERANK_K)).strip()
486
  try:
487
  rerank_k = int(rk_raw) if rk_raw else DEFAULT_RERANK_K
@@ -490,7 +468,7 @@ def search():
490
  rerank_k = max(MIN_RERANK_K, min(rerank_k, MAX_RERANK_K))
491
  rerank_k = max(rerank_k, k)
492
 
493
- # Highlight controls (only used for format=html; for fast mode you can still send hl_topn=0)
494
  hl_raw = request.args.get("hl_topn", str(DEFAULT_HL_TOPN)).strip()
495
  seg_raw = request.args.get("seg_maxlen", str(DEFAULT_SEG_MAXLEN)).strip()
496
  try:
@@ -526,33 +504,27 @@ def search():
526
 
527
  t0 = time.time()
528
 
529
- # 1) FAISS retrieve pool (rerank_k)
530
  df_pool = semantic_search_df(q, top_k=rerank_k)
531
  q_norm = normalize_ar(q)
532
 
533
- # 2) rerank to final k + evidence (no extra encode)
534
  df_final, ev = rerank_rows(query_norm=q_norm, df=df_pool, k_final=k)
535
 
536
  took_ms = int((time.time() - t0) * 1000)
537
 
538
- # Build results
539
  results: List[Dict[str, Any]] = []
540
  for _, row in df_final.iterrows():
541
  hid = int(row.get("hadithID")) if pd.notna(row.get("hadithID")) else None
542
  arabic = str(row.get("arabic", "") or "")
543
  english = str(row.get("english", "") or "")
544
 
545
- ar_clean = row.get("arabic_clean", "")
546
- if ar_clean is None or (isinstance(ar_clean, float) and np.isnan(ar_clean)):
547
- ar_clean = ""
548
- ar_clean = str(ar_clean).strip()
549
  if not ar_clean:
550
  ar_clean = normalize_ar(arabic)
551
 
552
- # lexical
553
  lex_r, lex_terms = lexical_ratio(q_norm, ar_clean)
554
 
555
- # scores
556
  faiss_score = float(row.get("score")) if pd.notna(row.get("score")) else 0.0
557
  rerank_score = float(row.get("rerank_score")) if pd.notna(row.get("rerank_score")) else faiss_score
558
  final_score = float(row.get("final_score")) if pd.notna(row.get("final_score")) else faiss_score
@@ -567,31 +539,20 @@ def search():
567
  "hadithID": hid,
568
  "collection": str(row.get("collection", "") or ""),
569
  "hadith_number": int(row.get("hadith_number")) if pd.notna(row.get("hadith_number")) else None,
570
-
571
- # unified score the UI should use
572
  "score": final_score,
573
-
574
- # diagnostics
575
  "faiss_score": faiss_score,
576
  "rerank_score": rerank_score,
577
-
578
  "conf_label": conf_label,
579
  "conf_class": conf_class,
580
-
581
  "lex_ratio": float(lex_r),
582
  "lex_terms": lex_terms,
583
-
584
  "arabic": arabic,
585
  "arabic_clean": ar_clean,
586
  "english": english,
587
-
588
- # Provide evidence html even in json (cheap: already computed in rerank)
589
  "heatmap_html": heatmap_html,
590
  "best_seg_html": best_html,
591
  }
592
 
593
- # If the caller asked for html AND did not disable highlight, also compute full highlight for each result.
594
- # This is heavier. Recommended: keep hl_topn=0 for fast mode and use /highlight on click.
595
  if want_html and hl_topn > 0:
596
  extras = full_highlight_html(
597
  query_norm=q_norm,
@@ -600,7 +561,6 @@ def search():
600
  seg_maxlen=seg_maxlen,
601
  )
602
  r["arabic_clean_html"] = extras["arabic_clean_html"]
603
- # You can overwrite with full-doc ones (optional):
604
  r["heatmap_html"] = extras["heatmap_html"] or r["heatmap_html"]
605
  r["best_seg_html"] = extras["best_seg_html"] or r["best_seg_html"]
606
 
@@ -624,10 +584,6 @@ def search():
624
 
625
  @app.get("/highlight")
626
  def highlight():
627
- """
628
- Highlight a single hadith on-demand (for fast UI).
629
- GET /highlight?q=...&hadithID=123&format=html&hl_topn=6&seg_maxlen=220
630
- """
631
  q = request.args.get("q", "").strip()
632
  hid_raw = request.args.get("hadithID", "").strip()
633
 
@@ -666,14 +622,10 @@ def highlight():
666
  arabic = str(row.get("arabic", "") or "")
667
  english = str(row.get("english", "") or "")
668
 
669
- ar_clean = row.get("arabic_clean", "")
670
- if ar_clean is None or (isinstance(ar_clean, float) and np.isnan(ar_clean)):
671
- ar_clean = ""
672
- ar_clean = str(ar_clean).strip()
673
  if not ar_clean:
674
  ar_clean = normalize_ar(arabic)
675
 
676
- # Always produce evidence + highlight here (one doc only)
677
  extras = full_highlight_html(
678
  query_norm=q_norm,
679
  arabic_clean_text=ar_clean,
@@ -681,7 +633,6 @@ def highlight():
681
  seg_maxlen=seg_maxlen,
682
  )
683
 
684
- # lexical
685
  lex_r, lex_terms = lexical_ratio(q_norm, ar_clean)
686
 
687
  return jsonify({
@@ -692,14 +643,11 @@ def highlight():
692
  "format": "html" if want_html else "json",
693
  "hl_topn": hl_topn,
694
  "seg_maxlen": seg_maxlen,
695
-
696
  "lex_ratio": float(lex_r),
697
  "lex_terms": lex_terms,
698
-
699
  "arabic": arabic,
700
  "arabic_clean": ar_clean,
701
  "english": english,
702
-
703
  "arabic_clean_html": extras.get("arabic_clean_html", "") if want_html else "",
704
  "heatmap_html": extras.get("heatmap_html", ""),
705
  "best_seg_html": extras.get("best_seg_html", ""),
@@ -707,5 +655,6 @@ def highlight():
707
 
708
 
709
  if __name__ == "__main__":
710
- # local run only
711
- app.run(host="127.0.0.1", port=5000, debug=True)
 
 
3
  import os
4
  import re
5
  import time
 
6
  from functools import lru_cache
7
+ from typing import List, Dict, Any, Tuple
8
 
9
  import numpy as np
10
  import pandas as pd
11
  import faiss
 
12
  from flask import Flask, request, jsonify, Response
13
  from sentence_transformers import SentenceTransformer
14
 
15
 
16
  # =========================
17
+ # Config (HF Space defaults)
18
  # =========================
19
+ INDEX_PATH = os.getenv("HADITH_INDEX_PATH", "hadith_semantic.faiss")
20
  META_PATH = os.getenv("HADITH_META_PATH", "hadith_meta.parquet")
21
+
22
+ # Small/fast multilingual model (good on free CPU)
23
+ MODEL_NAME = os.getenv(
24
+ "HADITH_MODEL_NAME",
25
+ "sentence-transformers/paraphrase-multilingual-MiniLM-L12-v2"
26
+ )
27
 
28
  DEFAULT_TOP_K = 10
29
  MAX_TOP_K = 50
 
32
  MAX_RERANK_K = 120
33
  MIN_RERANK_K = 10
34
 
35
+ DEFAULT_HL_TOPN = 6
36
  MAX_HL_TOPN = 25
37
 
38
  DEFAULT_SEG_MAXLEN = 220
39
  MAX_SEG_MAXLEN = 420
40
  MIN_SEG_MAXLEN = 120
41
 
42
+ # Rerank knobs (keep small for HF free CPU)
43
+ RERANK_MAX_SEGS_PER_DOC = int(os.getenv("RERANK_MAX_SEGS_PER_DOC", "8"))
44
+ RERANK_SEG_MAXLEN = int(os.getenv("RERANK_SEG_MAXLEN", "240"))
45
+ RERANK_WEIGHT = float(os.getenv("RERANK_WEIGHT", "0.65"))
46
  RERANK_ENABLE = os.getenv("RERANK_ENABLE", "1").strip() != "0"
47
 
48
  # CORS
49
+ CORS_ALLOW_ORIGIN = os.getenv("CORS_ALLOW_ORIGIN", "*")
50
 
51
 
52
  # =========================
 
78
  t = normalize_ar(text)
79
  t = _AR_PUNCT.sub(" ", t)
80
  toks = [x.strip() for x in t.split() if x.strip()]
 
81
  toks = [x for x in toks if len(x) >= 2]
82
  return toks
83
 
 
119
  if buf:
120
  segs.append(buf)
121
 
 
122
  if len(segs) <= 1 and len(t) > max_len:
123
  segs = [t[i:i+max_len].strip() for i in range(0, len(t), max_len) if t[i:i+max_len].strip()]
124
  return segs
125
 
126
  def pick_segs_for_rerank(segs: List[str], max_keep: int) -> List[str]:
 
127
  if len(segs) <= max_keep:
128
  return segs
 
129
  idxs = np.linspace(0, len(segs) - 1, num=max_keep)
130
  idxs = [int(round(x)) for x in idxs]
 
131
  seen = set()
132
  out = []
133
  for i in idxs:
 
139
 
140
  # =========================
141
  # Embedding helpers (cached)
142
+ # IMPORTANT: This model does NOT use "query:" / "passage:" prefixes.
143
  # =========================
144
  @lru_cache(maxsize=2048)
145
  def cached_query_emb(query_norm: str) -> bytes:
146
+ emb = model.encode([query_norm], normalize_embeddings=True).astype("float32")[0]
147
  return emb.tobytes()
148
 
149
  def get_query_emb(query_norm: str) -> np.ndarray:
 
151
 
152
 
153
  # =========================
154
+ # Evidence HTML
155
  # =========================
156
  def build_heatmap_html(segs: List[str], sims: np.ndarray, top_n: int = 6) -> str:
 
157
  if not segs or sims.size == 0:
158
  return ""
159
 
 
164
  s_max = float(np.max(sims))
165
  denom = (s_max - s_min) if (s_max - s_min) > 1e-6 else 1.0
166
 
 
167
  order = np.argsort(-sims)
168
  keep = set(order[:top_n])
169
 
170
  blocks = []
171
  for i in range(n):
172
+ w = (float(sims[i]) - s_min) / denom
 
173
  alpha = (0.20 + 0.60 * w) if i in keep else (0.08 + 0.18 * w)
174
  alpha = max(0.06, min(alpha, 0.85))
175
  blocks.append(
 
189
  if not segs or sims.size == 0:
190
  return ""
191
  i = int(np.argmax(sims))
192
+ return (
193
+ '<span style="background:rgba(255,230,120,0.55);'
194
+ 'border:1px solid rgba(234,179,8,0.35);border-radius:12px;padding:3px 8px;display:inline;">'
195
+ f'{escape_html(segs[i])}</span>'
196
+ )
197
 
198
  def lexical_ratio(query_norm: str, doc_norm: str, max_terms: int = 10) -> Tuple[float, str]:
199
  q_toks = ar_tokens(query_norm)
 
206
  return float(ratio), terms
207
 
208
  def confidence_label(score: float) -> Tuple[str, str]:
 
 
 
 
209
  if score >= 0.78:
210
  return "HIGH", "bHigh"
211
  if score >= 0.62:
212
  return "MED", "bMed"
213
  return "LOW", "bLow"
214
 
215
+
216
+ # =========================
217
+ # Rerank
218
+ # =========================
219
  def rerank_rows(
220
  query_norm: str,
221
  df: pd.DataFrame,
222
  k_final: int,
223
  ) -> Tuple[pd.DataFrame, Dict[int, Dict[str, Any]]]:
 
 
 
 
 
 
 
224
  evidence: Dict[int, Dict[str, Any]] = {}
225
 
226
  if (not RERANK_ENABLE) or df.empty:
 
227
  for _, row in df.iterrows():
228
  hid = int(row["hadithID"]) if pd.notna(row.get("hadithID")) else -1
229
  evidence[hid] = {"mode": "disabled"}
230
  return df.head(k_final), evidence
231
 
 
232
  cand_rows = df.copy()
233
 
234
  per_doc_segs: List[List[str]] = []
 
238
  hid = int(row["hadithID"]) if pd.notna(row.get("hadithID")) else -1
239
  doc_hids.append(hid)
240
 
241
+ ar = str(row.get("arabic_clean", "") or "").strip()
242
+ if not ar:
243
+ ar = normalize_ar(str(row.get("arabic", "") or ""))
 
 
 
 
244
 
245
+ segs = split_ar_segments(ar, max_len=RERANK_SEG_MAXLEN)
246
  segs = pick_segs_for_rerank(segs, max_keep=RERANK_MAX_SEGS_PER_DOC)
247
+ if not segs and ar:
248
+ segs = [ar[:RERANK_SEG_MAXLEN]]
249
  per_doc_segs.append(segs)
250
 
 
251
  all_segs: List[str] = []
252
  offsets: List[Tuple[int, int]] = []
253
  cur = 0
 
258
  offsets.append((start, cur))
259
 
260
  if not all_segs:
 
261
  for hid in doc_hids:
262
  evidence[hid] = {"mode": "empty"}
263
  return cand_rows.head(k_final), evidence
264
 
265
+ q_emb = get_query_emb(query_norm)
266
+ seg_emb = model.encode(all_segs, normalize_embeddings=True).astype("float32")
267
+ sims_all = (seg_emb @ q_emb).astype(np.float32)
 
 
 
 
 
268
 
 
269
  rr_scores: List[float] = []
270
  for hid, (start, end), segs in zip(doc_hids, offsets, per_doc_segs):
271
  if start == end:
 
276
  rr = float(np.max(sims))
277
  rr_scores.append(rr)
278
 
 
279
  hm = build_heatmap_html(segs, sims, top_n=min(6, len(segs))) if sims.size else ""
280
  best = best_seg_html(segs, sims) if sims.size else ""
281
  evidence[hid] = {
 
283
  "rerank_score": rr,
284
  "heatmap_html": hm,
285
  "best_seg_html": best,
 
 
286
  }
287
 
288
  cand_rows["rerank_score"] = rr_scores
289
 
 
 
290
  faiss_scores = cand_rows["score"].astype(float).to_numpy()
291
  rr = cand_rows["rerank_score"].astype(float).to_numpy()
292
 
 
299
 
300
 
301
  # =========================
302
+ # Full highlight for ONE hadith
303
  # =========================
304
  def full_highlight_html(
305
  query_norm: str,
 
316
  }
317
 
318
  q_emb = get_query_emb(query_norm)
319
+ seg_emb = model.encode(segs, normalize_embeddings=True).astype("float32")
 
 
 
 
320
  sims = (seg_emb @ q_emb).astype(np.float32)
321
 
322
  s_min = float(np.min(sims))
 
362
  index = faiss.read_index(INDEX_PATH)
363
  meta = pd.read_parquet(META_PATH)
364
 
365
+ # Accept corpusID or hadithID, normalize to hadithID
366
+ id_col = "hadithID" if "hadithID" in meta.columns else ("corpusID" if "corpusID" in meta.columns else None)
367
+ if id_col is None:
368
+ raise ValueError("Meta must contain 'hadithID' or 'corpusID'")
369
+ if id_col != "hadithID":
370
+ meta = meta.rename(columns={id_col: "hadithID"})
371
+
372
  required_cols = {"hadithID", "collection", "hadith_number", "arabic", "english"}
373
  missing = required_cols - set(meta.columns)
374
  if missing:
 
377
  if "arabic_clean" not in meta.columns:
378
  meta["arabic_clean"] = ""
379
 
380
+ meta["arabic"] = meta["arabic"].fillna("").astype(str)
381
+ meta["english"] = meta["english"].fillna("").astype(str)
382
+ meta["arabic_clean"] = meta["arabic_clean"].fillna("").astype(str)
383
+
384
 
385
  # =========================
386
  # FAISS Search
 
391
  return meta.iloc[0:0].copy()
392
 
393
  top_k = max(1, min(int(top_k), MAX_TOP_K))
394
+ q_norm = normalize_ar(q) # Arabic normalize, safe for English too
395
 
396
  q_emb = get_query_emb(q_norm).reshape(1, -1)
397
  scores, idx = index.search(q_emb, top_k)
 
399
  res = meta.iloc[idx[0]].copy()
400
  res["score"] = scores[0]
401
  res = res.sort_values("score", ascending=False)
 
 
 
402
  res = res[res["arabic"].str.strip() != ""]
403
  return res
404
 
 
451
  def search():
452
  q = request.args.get("q", "").strip()
453
 
454
+ # final top-k
455
  k_raw = request.args.get("k", str(DEFAULT_TOP_K)).strip()
456
  try:
457
  k = int(k_raw) if k_raw else DEFAULT_TOP_K
 
459
  k = DEFAULT_TOP_K
460
  k = max(1, min(k, MAX_TOP_K))
461
 
462
+ # rerank pool size
463
  rk_raw = request.args.get("rerank_k", str(DEFAULT_RERANK_K)).strip()
464
  try:
465
  rerank_k = int(rk_raw) if rk_raw else DEFAULT_RERANK_K
 
468
  rerank_k = max(MIN_RERANK_K, min(rerank_k, MAX_RERANK_K))
469
  rerank_k = max(rerank_k, k)
470
 
471
+ # highlight controls
472
  hl_raw = request.args.get("hl_topn", str(DEFAULT_HL_TOPN)).strip()
473
  seg_raw = request.args.get("seg_maxlen", str(DEFAULT_SEG_MAXLEN)).strip()
474
  try:
 
504
 
505
  t0 = time.time()
506
 
507
+ # 1) retrieve pool
508
  df_pool = semantic_search_df(q, top_k=rerank_k)
509
  q_norm = normalize_ar(q)
510
 
511
+ # 2) rerank -> final
512
  df_final, ev = rerank_rows(query_norm=q_norm, df=df_pool, k_final=k)
513
 
514
  took_ms = int((time.time() - t0) * 1000)
515
 
 
516
  results: List[Dict[str, Any]] = []
517
  for _, row in df_final.iterrows():
518
  hid = int(row.get("hadithID")) if pd.notna(row.get("hadithID")) else None
519
  arabic = str(row.get("arabic", "") or "")
520
  english = str(row.get("english", "") or "")
521
 
522
+ ar_clean = str(row.get("arabic_clean", "") or "").strip()
 
 
 
523
  if not ar_clean:
524
  ar_clean = normalize_ar(arabic)
525
 
 
526
  lex_r, lex_terms = lexical_ratio(q_norm, ar_clean)
527
 
 
528
  faiss_score = float(row.get("score")) if pd.notna(row.get("score")) else 0.0
529
  rerank_score = float(row.get("rerank_score")) if pd.notna(row.get("rerank_score")) else faiss_score
530
  final_score = float(row.get("final_score")) if pd.notna(row.get("final_score")) else faiss_score
 
539
  "hadithID": hid,
540
  "collection": str(row.get("collection", "") or ""),
541
  "hadith_number": int(row.get("hadith_number")) if pd.notna(row.get("hadith_number")) else None,
 
 
542
  "score": final_score,
 
 
543
  "faiss_score": faiss_score,
544
  "rerank_score": rerank_score,
 
545
  "conf_label": conf_label,
546
  "conf_class": conf_class,
 
547
  "lex_ratio": float(lex_r),
548
  "lex_terms": lex_terms,
 
549
  "arabic": arabic,
550
  "arabic_clean": ar_clean,
551
  "english": english,
 
 
552
  "heatmap_html": heatmap_html,
553
  "best_seg_html": best_html,
554
  }
555
 
 
 
556
  if want_html and hl_topn > 0:
557
  extras = full_highlight_html(
558
  query_norm=q_norm,
 
561
  seg_maxlen=seg_maxlen,
562
  )
563
  r["arabic_clean_html"] = extras["arabic_clean_html"]
 
564
  r["heatmap_html"] = extras["heatmap_html"] or r["heatmap_html"]
565
  r["best_seg_html"] = extras["best_seg_html"] or r["best_seg_html"]
566
 
 
584
 
585
  @app.get("/highlight")
586
  def highlight():
 
 
 
 
587
  q = request.args.get("q", "").strip()
588
  hid_raw = request.args.get("hadithID", "").strip()
589
 
 
622
  arabic = str(row.get("arabic", "") or "")
623
  english = str(row.get("english", "") or "")
624
 
625
+ ar_clean = str(row.get("arabic_clean", "") or "").strip()
 
 
 
626
  if not ar_clean:
627
  ar_clean = normalize_ar(arabic)
628
 
 
629
  extras = full_highlight_html(
630
  query_norm=q_norm,
631
  arabic_clean_text=ar_clean,
 
633
  seg_maxlen=seg_maxlen,
634
  )
635
 
 
636
  lex_r, lex_terms = lexical_ratio(q_norm, ar_clean)
637
 
638
  return jsonify({
 
643
  "format": "html" if want_html else "json",
644
  "hl_topn": hl_topn,
645
  "seg_maxlen": seg_maxlen,
 
646
  "lex_ratio": float(lex_r),
647
  "lex_terms": lex_terms,
 
648
  "arabic": arabic,
649
  "arabic_clean": ar_clean,
650
  "english": english,
 
651
  "arabic_clean_html": extras.get("arabic_clean_html", "") if want_html else "",
652
  "heatmap_html": extras.get("heatmap_html", ""),
653
  "best_seg_html": extras.get("best_seg_html", ""),
 
655
 
656
 
657
  if __name__ == "__main__":
658
+ # Hugging Face Spaces uses PORT=7860
659
+ port = int(os.getenv("PORT", "7860"))
660
+ app.run(host="0.0.0.0", port=port, debug=False)