Alshargi commited on
Commit
6949f58
·
verified ·
1 Parent(s): 50fe70f

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +227 -147
app.py CHANGED
@@ -3,11 +3,12 @@ from __future__ import annotations
3
  import os
4
  import re
5
  import time
6
- from typing import Any, Dict, List, Optional, Tuple
7
 
8
  import numpy as np
9
  import pandas as pd
10
  import faiss
 
11
  from flask import Flask, request, jsonify
12
  from flask_cors import CORS
13
  from sentence_transformers import SentenceTransformer
@@ -23,9 +24,15 @@ MODEL_NAME = os.getenv("HADITH_MODEL_NAME", "intfloat/multilingual-e5-base")
23
  DEFAULT_TOP_K = 10
24
  MAX_TOP_K = 50
25
 
 
 
 
 
 
 
26
 
27
  # =========================
28
- # Arabic normalization (remove tashkeel + normalize letters)
29
  # =========================
30
  _AR_DIACRITICS = re.compile(r"""
31
  [\u0610-\u061A]
@@ -35,6 +42,7 @@ _AR_DIACRITICS = re.compile(r"""
35
  """, re.VERBOSE)
36
 
37
  def normalize_ar(text: str) -> str:
 
38
  if text is None:
39
  return ""
40
  text = str(text)
@@ -47,99 +55,160 @@ def normalize_ar(text: str) -> str:
47
  text = re.sub(r"\s+", " ", text).strip()
48
  return text
49
 
 
 
 
 
 
 
 
 
 
 
 
 
50
 
51
  # =========================
52
- # Lazy load (load resources on demand)
53
  # =========================
54
- _model: Optional[SentenceTransformer] = None
55
- _index = None
56
- _meta: Optional[pd.DataFrame] = None
57
-
58
- def get_resources() -> Tuple[SentenceTransformer, Any, pd.DataFrame]:
59
- global _model, _index, _meta
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
60
 
61
- if _model is not None and _index is not None and _meta is not None:
62
- return _model, _index, _meta
 
63
 
64
- if not os.path.exists(INDEX_PATH):
65
- raise FileNotFoundError(f"FAISS index not found: {INDEX_PATH}")
 
 
66
 
67
- if not os.path.exists(META_PATH):
68
- raise FileNotFoundError(f"Meta parquet not found: {META_PATH}")
69
 
70
- _model = SentenceTransformer(MODEL_NAME)
71
- _index = faiss.read_index(INDEX_PATH)
72
- _meta = pd.read_parquet(META_PATH)
73
 
74
- required_cols = {"hadithID", "collection", "hadith_number", "arabic", "english"}
75
- missing = required_cols - set(_meta.columns)
76
- if missing:
77
- raise ValueError(f"Meta is missing required columns: {missing}")
78
 
79
- if "arabic_clean" not in _meta.columns:
80
- _meta["arabic_clean"] = ""
 
 
 
 
 
 
 
81
 
82
- # Normalize types / fill missing
83
- for col in ["arabic", "english", "arabic_clean", "collection"]:
84
- if col in _meta.columns:
85
- _meta[col] = _meta[col].fillna("").astype(str)
86
 
87
- return _model, _index, _meta
88
 
89
 
90
  # =========================
91
- # Search
92
  # =========================
93
- def semantic_search(query: str, top_k: int = DEFAULT_TOP_K) -> pd.DataFrame:
94
- model, index, meta = get_resources()
 
 
 
95
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
96
  q = str(query or "").strip()
97
  if not q:
98
  return meta.iloc[0:0].copy()
99
 
100
  top_k = max(1, min(int(top_k), MAX_TOP_K))
101
-
102
  q_norm = normalize_ar(q)
103
- q_emb = model.encode(["query: " + q_norm], normalize_embeddings=True).astype("float32")
104
 
 
105
  scores, idx = index.search(q_emb, top_k)
106
 
107
  res = meta.iloc[idx[0]].copy()
108
- res["score"] = scores[0].astype(float)
109
  res = res.sort_values("score", ascending=False)
110
 
111
- # Filter empty arabic just in case
112
  res["arabic"] = res["arabic"].fillna("").astype(str)
113
  res = res[res["arabic"].str.strip() != ""]
114
 
115
  return res
116
 
117
 
118
- def row_to_json(row: pd.Series, include_text: bool = True) -> Dict[str, Any]:
119
- arabic = str(row.get("arabic", "") or "")
120
- english = str(row.get("english", "") or "")
121
-
122
- arabic_clean = str(row.get("arabic_clean", "") or "").strip()
123
- if not arabic_clean:
124
- arabic_clean = normalize_ar(arabic)
125
-
126
- base: Dict[str, Any] = {
127
- "score": float(row.get("score", 0.0)),
128
- "hadithID": int(row.get("hadithID")),
129
- "collection": str(row.get("collection", "")),
130
- "hadith_number": int(row.get("hadith_number")),
131
- }
132
-
133
- if include_text:
134
- base.update({
135
- "arabic": arabic,
136
- "arabic_clean": arabic_clean,
137
- "english": english,
138
- })
139
-
140
- return base
141
-
142
-
143
  # =========================
144
  # Flask API
145
  # =========================
@@ -151,122 +220,133 @@ CORS(app, resources={r"/*": {"origins": "*"}})
151
  def root():
152
  return jsonify({
153
  "ok": True,
154
- "service": "hadeeth semantic search api",
155
- "endpoints": ["/health", "/search (GET/POST)"]
 
 
 
156
  })
157
 
158
 
159
  @app.get("/health")
160
  def health():
161
- # Don't force-load model/index/meta here if you want it super fast
162
- # But we can still show file presence:
163
- files_ok = os.path.exists(INDEX_PATH) and os.path.exists(META_PATH)
164
- info = {
165
  "ok": True,
166
- "files_ok": files_ok,
167
- "index_path": INDEX_PATH,
168
- "meta_path": META_PATH,
169
  "model": MODEL_NAME,
170
- }
 
 
171
 
172
- # If you want to show counts (this will load resources):
173
- try:
174
- _, index, meta = get_resources()
175
- info["rows"] = int(len(meta))
176
- info["index_ntotal"] = int(getattr(index, "ntotal", -1))
177
- info["loaded"] = True
178
- except Exception as e:
179
- info["loaded"] = False
180
- info["load_error"] = str(e)
181
 
182
- return jsonify(info)
 
 
183
 
 
 
 
 
 
 
 
184
 
185
- @app.post("/search")
186
- def search_post():
187
- """
188
- Body JSON:
189
- {
190
- "q": "الرزق",
191
- "k": 10,
192
- "include_text": true
193
- }
194
- """
195
- payload = request.get_json(silent=True) or {}
196
 
197
- q = (payload.get("q") or "").strip()
198
- if not q:
199
- return jsonify({"ok": False, "error": "Missing 'q'"}), 400
 
 
200
 
201
- k = payload.get("k", DEFAULT_TOP_K)
202
  try:
203
- k = int(k)
204
  except Exception:
205
- k = DEFAULT_TOP_K
206
- k = max(1, min(k, MAX_TOP_K))
207
 
208
- include_text = payload.get("include_text", True)
209
- include_text = bool(include_text)
 
 
 
 
 
 
 
 
210
 
211
  t0 = time.time()
212
- try:
213
- res_df = semantic_search(q, top_k=k)
214
- except Exception as e:
215
- return jsonify({"ok": False, "error": str(e)}), 500
216
  took_ms = int((time.time() - t0) * 1000)
217
 
218
- results = [row_to_json(r, include_text=include_text) for _, r in res_df.iterrows()]
219
-
220
- return jsonify({
221
- "ok": True,
222
- "query": q,
223
- "query_norm": normalize_ar(q),
224
- "k": k,
225
- "took_ms": took_ms,
226
- "results_count": len(results),
227
- "results": results
228
- })
229
-
230
-
231
- @app.get("/search")
232
- def search_get():
233
- """
234
- GET /search?q=...&k=10&include_text=1
235
- """
236
- q = (request.args.get("q") or "").strip()
237
- if not q:
238
- return jsonify({"ok": False, "error": "Missing 'q'"}), 400
239
 
240
- k_raw = request.args.get("k", str(DEFAULT_TOP_K))
241
- try:
242
- k = int(k_raw)
243
- except Exception:
244
- k = DEFAULT_TOP_K
245
- k = max(1, min(k, MAX_TOP_K))
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
246
 
247
- include_text_raw = request.args.get("include_text", "1")
248
- include_text = include_text_raw not in ("0", "false", "False", "")
 
249
 
250
- t0 = time.time()
251
- try:
252
- res_df = semantic_search(q, top_k=k)
253
- except Exception as e:
254
- return jsonify({"ok": False, "error": str(e)}), 500
255
- took_ms = int((time.time() - t0) * 1000)
256
 
257
- results = [row_to_json(r, include_text=include_text) for _, r in res_df.iterrows()]
 
 
 
258
 
259
  return jsonify({
260
  "ok": True,
261
  "query": q,
262
- "query_norm": normalize_ar(q),
263
- "k": k,
 
 
 
264
  "took_ms": took_ms,
265
- "results_count": len(results),
266
- "results": results
267
  })
268
 
269
 
 
 
270
  if __name__ == "__main__":
271
- # Local dev only
272
- app.run(host="0.0.0.0", port=7860, debug=False)
 
3
  import os
4
  import re
5
  import time
6
+ from typing import List, Dict, Any, Tuple
7
 
8
  import numpy as np
9
  import pandas as pd
10
  import faiss
11
+
12
  from flask import Flask, request, jsonify
13
  from flask_cors import CORS
14
  from sentence_transformers import SentenceTransformer
 
24
  DEFAULT_TOP_K = 10
25
  MAX_TOP_K = 50
26
 
27
+ DEFAULT_HL_TOPN = 6 # segments with strong highlight
28
+ MAX_HL_TOPN = 25
29
+
30
+ DEFAULT_SEG_MAXLEN = 220 # segment size
31
+ MAX_SEG_MAXLEN = 420
32
+
33
 
34
  # =========================
35
+ # Arabic normalization
36
  # =========================
37
  _AR_DIACRITICS = re.compile(r"""
38
  [\u0610-\u061A]
 
42
  """, re.VERBOSE)
43
 
44
  def normalize_ar(text: str) -> str:
45
+ """Remove tashkeel + normalize common Arabic letter variants."""
46
  if text is None:
47
  return ""
48
  text = str(text)
 
55
  text = re.sub(r"\s+", " ", text).strip()
56
  return text
57
 
58
def escape_html(s: str) -> str:
    """HTML-escape a value for safe embedding in markup.

    Accepts any value (it is stringified first); returns "" for None.

    Note: `&` is escaped FIRST — and must map to "&amp;", not to itself —
    otherwise the entities produced by the later replacements would be
    double-escaped (or, as in the previous no-op `replace("&", "&")`,
    raw ampersands would leak into the output unescaped).
    """
    if s is None:
        return ""
    return (
        str(s)
        .replace("&", "&amp;")  # bug fix: was the no-op replace("&", "&")
        .replace("<", "&lt;")
        .replace(">", "&gt;")
        .replace('"', "&quot;")
        .replace("'", "&#39;")
    )
69
+
70
 
71
  # =========================
72
+ # Semantic segment highlighting
73
  # =========================
74
def split_ar_segments(text: str, max_len: int = DEFAULT_SEG_MAXLEN) -> List[str]:
    """Break normalized Arabic text into short, punctuation-bounded segments.

    Adjacent punctuation-split pieces are greedily merged while the merged
    run stays within `max_len`. If punctuation yields at most one segment
    but the text is longer than `max_len`, falls back to fixed-width
    chunking so long unpunctuated text still gets split.
    """
    if not text:
        return []
    flattened = re.sub(r"\s+", " ", str(text)).strip()

    # Split after Arabic or Latin sentence/clause punctuation.
    pieces = re.split(r"(?<=[\.\!\?؟\،\,\;\:])\s+", flattened)

    segments: List[str] = []
    current = ""
    for raw in pieces:
        piece = (raw or "").strip()
        if not piece:
            continue
        if not current:
            current = piece
        elif len(current) + 1 + len(piece) <= max_len:
            current = f"{current} {piece}"
        else:
            segments.append(current)
            current = piece
    if current:
        segments.append(current)

    # Fixed-width fallback for long text with no usable punctuation.
    if len(segments) <= 1 and len(flattened) > max_len:
        segments = [
            chunk.strip()
            for chunk in (flattened[i:i + max_len] for i in range(0, len(flattened), max_len))
            if chunk.strip()
        ]

    return segments
104
+
105
def semantic_highlight_segments_html(
    model: SentenceTransformer,
    query_norm: str,
    arabic_clean: str,
    top_n: int = DEFAULT_HL_TOPN,
    seg_max_len: int = DEFAULT_SEG_MAXLEN
) -> Tuple[str, float, float]:
    """Render `arabic_clean` as HTML spans tinted by similarity to the query.

    Each segment is embedded with the E5 "passage: " prefix and compared to
    the "query: " embedding; the background opacity of each span scales with
    its (min-max normalized) similarity, with the `top_n` closest segments
    getting a visibly stronger tint.

    Returns (html, min_similarity, max_similarity).
    """
    segments = split_ar_segments(arabic_clean, max_len=seg_max_len)
    if not segments:
        return escape_html(arabic_clean), 0.0, 0.0

    # Embeddings are unit-normalized, so the dot product below is cosine similarity.
    query_vec = model.encode(["query: " + query_norm], normalize_embeddings=True).astype("float32")
    passage_vecs = model.encode(["passage: " + s for s in segments], normalize_embeddings=True).astype("float32")

    similarities = (passage_vecs @ query_vec[0]).astype(np.float32)
    s_min = float(np.min(similarities))
    s_max = float(np.max(similarities))
    spread = s_max - s_min
    denom = spread if spread > 1e-6 else 1.0  # guard against all-equal scores

    ranked = np.argsort(-similarities)
    strong = set(ranked[:min(top_n, len(segments))])

    rendered: List[str] = []
    for pos, segment in enumerate(segments):
        weight = (float(similarities[pos]) - s_min) / denom  # normalized to 0..1

        # Closest segments get a strong tint; everything else stays faint.
        if pos in strong:
            alpha = 0.18 + 0.62 * weight
        else:
            alpha = 0.06 + 0.20 * weight
        alpha = max(0.05, min(alpha, 0.82))
        border_alpha = max(0.10, min(alpha * 0.8, 0.65))

        style = (
            f"background: rgba(255, 230, 120, {alpha:.3f});"
            f"border: 1px solid rgba(234, 179, 8, {border_alpha:.3f});"
            "border-radius: 12px;"
            "padding: 3px 8px;"
            "margin: 0 4px 6px 0;"
            "display: inline;"
        )
        rendered.append(f'<span style="{style}">{escape_html(segment)}</span> ')

    html = "".join(rendered).strip()
    if not html:
        html = escape_html(arabic_clean)

    return html, s_min, s_max
156
 
157
 
158
  # =========================
159
+ # Load model + index + meta (once)
160
  # =========================
161
# Fail fast on missing artifacts before paying the model-load cost.
for _label, _path in (("FAISS index", INDEX_PATH), ("Meta parquet", META_PATH)):
    if not os.path.exists(_path):
        raise FileNotFoundError(f"{_label} not found: {_path}")

print(f"[BOOT] Loading model: {MODEL_NAME}")
model = SentenceTransformer(MODEL_NAME)

print(f"[BOOT] Loading faiss index: {INDEX_PATH}")
index = faiss.read_index(INDEX_PATH)

print(f"[BOOT] Loading meta: {META_PATH}")
meta = pd.read_parquet(META_PATH)

# The API depends on these columns existing; abort early otherwise.
required_cols = {"hadithID", "collection", "hadith_number", "arabic", "english"}
missing = required_cols - set(meta.columns)
if missing:
    raise ValueError(f"Meta is missing required columns: {missing}")

if "arabic_clean" not in meta.columns:
    meta["arabic_clean"] = ""

# normalize types lightly
for _col in ("arabic", "english", "collection"):
    meta[_col] = meta[_col].fillna("").astype(str)
188
+
189
+
190
def semantic_search_df(query: str, top_k: int) -> pd.DataFrame:
    """Embed the query and return the top_k nearest hadiths as a DataFrame.

    The returned frame carries the meta columns plus a `score` column,
    sorted best-first; rows with an empty `arabic` field are dropped.
    An empty/blank query yields an empty frame with the meta schema.
    """
    cleaned = str(query or "").strip()
    if not cleaned:
        return meta.iloc[0:0].copy()

    top_k = max(1, min(int(top_k), MAX_TOP_K))
    normalized = normalize_ar(cleaned)

    # E5 convention: queries are embedded with the "query: " prefix.
    embedded = model.encode(["query: " + normalized], normalize_embeddings=True).astype("float32")
    scores, hit_ids = index.search(embedded, top_k)

    hits = meta.iloc[hit_ids[0]].copy()
    hits["score"] = scores[0]
    hits = hits.sort_values("score", ascending=False)

    # filter empty arabic rows (avoid empty cards)
    hits["arabic"] = hits["arabic"].fillna("").astype(str)
    hits = hits[hits["arabic"].str.strip() != ""]

    return hits
210
 
211
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
212
  # =========================
213
  # Flask API
214
  # =========================
 
220
def root():
    """Service banner: name plus the available endpoints."""
    payload = {
        "ok": True,
        "service": "hadith semantic search",
        "endpoints": {
            "health": "/health",
            "search": "/search?q=...&k=10&hl_topn=6&seg_maxlen=220",
        },
    }
    return jsonify(payload)
229
 
230
 
231
@app.get("/health")
def health():
    """Liveness probe reporting the loaded model and corpus/index sizes."""
    status = {
        "ok": True,
        "model": MODEL_NAME,
        "rows": int(len(meta)),
        "index_ntotal": int(getattr(index, "ntotal", -1)),
    }
    return jsonify(status)
239
 
 
 
 
 
 
 
 
 
 
240
 
241
def _int_arg(name: str, default: int, lo: int, hi: int) -> int:
    """Read an integer query parameter, tolerating junk, clamped to [lo, hi]."""
    raw = request.args.get(name, str(default)).strip()
    try:
        value = int(raw) if raw else default
    except Exception:
        value = default
    return min(max(lo, value), hi)


@app.get("/search")
def search():
    """
    GET /search?q=...&k=10&hl_topn=6&seg_maxlen=220

    Runs a FAISS semantic search for `q`, then re-embeds each hit's text
    segments to produce per-segment highlight HTML.

    Query params:
        q          -- query text; empty yields an empty (but well-formed) payload
        k          -- number of hits, clamped to [1, MAX_TOP_K]
        hl_topn    -- segments given strong highlight, clamped to [1, MAX_HL_TOPN]
        seg_maxlen -- max segment length, clamped to [120, MAX_SEG_MAXLEN]
    """
    q = request.args.get("q", "").strip()

    # All three knobs share one parse/fallback/clamp pattern (was copy-pasted 3x).
    k_int = _int_arg("k", DEFAULT_TOP_K, 1, MAX_TOP_K)
    hl_topn = _int_arg("hl_topn", DEFAULT_HL_TOPN, 1, MAX_HL_TOPN)
    seg_maxlen = _int_arg("seg_maxlen", DEFAULT_SEG_MAXLEN, 120, MAX_SEG_MAXLEN)

    if not q:
        # Empty query is not an error: return an empty, well-formed payload.
        return jsonify({
            "ok": True,
            "query": "",
            "query_norm": "",
            "k": k_int,
            "rows": int(len(meta)),
            "took_ms": 0,
            "results": [],
        })

    t0 = time.time()
    try:
        res_df = semantic_search_df(q, top_k=k_int)
    except Exception as e:
        # Keep the API contract JSON-only even on failure (no HTML 500 page).
        return jsonify({"ok": False, "error": str(e)}), 500
    took_ms = int((time.time() - t0) * 1000)

    q_norm = normalize_ar(q)

    out: List[Dict[str, Any]] = []
    for _, row in res_df.iterrows():
        r = row.to_dict()

        arabic = str(r.get("arabic", "") or "")
        english = str(r.get("english", "") or "")

        # arabic_clean may be missing, None, or NaN (parquet) -> normalize lazily.
        arabic_clean_val = r.get("arabic_clean", "")
        if arabic_clean_val is None:
            arabic_clean_val = ""
        if isinstance(arabic_clean_val, float) and np.isnan(arabic_clean_val):
            arabic_clean_val = ""
        arabic_clean = str(arabic_clean_val).strip()
        if not arabic_clean:
            arabic_clean = normalize_ar(arabic)

        # Semantic highlight segments (returns HTML spans).
        arabic_clean_html, s_min, s_max = semantic_highlight_segments_html(
            model=model,
            query_norm=q_norm,
            arabic_clean=arabic_clean,
            top_n=hl_topn,
            seg_max_len=seg_maxlen
        )

        # Final fallback so the HTML field is never empty.
        if not str(arabic_clean_html).strip():
            arabic_clean_html = escape_html(arabic_clean if arabic_clean else arabic)

        out.append({
            "hadithID": int(r.get("hadithID")),
            "collection": str(r.get("collection", "")),
            "hadith_number": int(r.get("hadith_number")),
            "score": float(r.get("score", 0.0)),

            "arabic": arabic,
            "arabic_clean": arabic_clean,
            "english": english,

            # HTML-ready fields
            "arabic_clean_html": arabic_clean_html,
            "arabic_html": escape_html(arabic),
            "english_html": escape_html(english),

            # optional stats
            "hl_min": float(s_min),
            "hl_max": float(s_max),
        })

    return jsonify({
        "ok": True,
        "query": q,
        "query_norm": q_norm,
        "k": k_int,
        "hl_topn": hl_topn,
        "seg_maxlen": seg_maxlen,
        "rows": int(len(meta)),
        "took_ms": took_ms,
        "results": out,
    })
347
 
348
 
349
# HuggingFace Docker normally runs this via CMD (gunicorn/uvicorn),
# but this is useful if you run it locally:
if __name__ == "__main__":
    # Never default to debug=True on 0.0.0.0: the Werkzeug interactive
    # debugger allows remote code execution. Opt in via FLASK_DEBUG=1.
    app.run(
        host="0.0.0.0",
        port=int(os.getenv("PORT", "7860")),
        debug=os.getenv("FLASK_DEBUG", "0") == "1",
    )