Alshargi committed on
Commit
42a6a19
·
verified ·
1 Parent(s): 6949f58

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +52 -100
app.py CHANGED
@@ -8,7 +8,6 @@ from typing import List, Dict, Any, Tuple
8
  import numpy as np
9
  import pandas as pd
10
  import faiss
11
-
12
  from flask import Flask, request, jsonify
13
  from flask_cors import CORS
14
  from sentence_transformers import SentenceTransformer
@@ -24,10 +23,10 @@ MODEL_NAME = os.getenv("HADITH_MODEL_NAME", "intfloat/multilingual-e5-base")
24
  DEFAULT_TOP_K = 10
25
  MAX_TOP_K = 50
26
 
27
- DEFAULT_HL_TOPN = 6 # segments with strong highlight
28
  MAX_HL_TOPN = 25
29
 
30
- DEFAULT_SEG_MAXLEN = 220 # segment size
31
  MAX_SEG_MAXLEN = 420
32
 
33
 
@@ -42,7 +41,6 @@ _AR_DIACRITICS = re.compile(r"""
42
  """, re.VERBOSE)
43
 
44
  def normalize_ar(text: str) -> str:
45
- """Remove tashkeel + normalize common Arabic letter variants."""
46
  if text is None:
47
  return ""
48
  text = str(text)
@@ -69,15 +67,12 @@ def escape_html(s: str) -> str:
69
 
70
 
71
  # =========================
72
- # Semantic segment highlighting
73
  # =========================
74
- def split_ar_segments(text: str, max_len: int = DEFAULT_SEG_MAXLEN) -> List[str]:
75
- """Split Arabic clean text into short segments for semantic highlighting."""
76
  if not text:
77
  return []
78
  t = re.sub(r"\s+", " ", str(text)).strip()
79
-
80
- # Split on punctuation (Arabic + Latin)
81
  parts = re.split(r"(?<=[\.\!\?؟\،\,\;\:])\s+", t)
82
 
83
  segs: List[str] = []
@@ -96,7 +91,6 @@ def split_ar_segments(text: str, max_len: int = DEFAULT_SEG_MAXLEN) -> List[str]
96
  if buf:
97
  segs.append(buf)
98
 
99
- # Fallback chunking
100
  if len(segs) <= 1 and len(t) > max_len:
101
  segs = [t[i:i+max_len].strip() for i in range(0, len(t), max_len) if t[i:i+max_len].strip()]
102
 
@@ -106,18 +100,18 @@ def semantic_highlight_segments_html(
106
  model: SentenceTransformer,
107
  query_norm: str,
108
  arabic_clean: str,
109
- top_n: int = DEFAULT_HL_TOPN,
110
- seg_max_len: int = DEFAULT_SEG_MAXLEN
111
- ) -> Tuple[str, float, float]:
112
  """
113
- Returns HTML with segments colored by semantic similarity to query.
114
- Also returns min/max similarity.
 
115
  """
116
  segs = split_ar_segments(arabic_clean, max_len=seg_max_len)
117
  if not segs:
118
- return escape_html(arabic_clean), 0.0, 0.0
119
 
120
- # E5 format: "query:" and "passage:"
121
  q_emb = model.encode(["query: " + query_norm], normalize_embeddings=True).astype("float32")
122
  seg_emb = model.encode(["passage: " + s for s in segs], normalize_embeddings=True).astype("float32")
123
 
@@ -130,11 +124,14 @@ def semantic_highlight_segments_html(
130
  keep = set(order[:min(top_n, len(segs))])
131
 
132
  html_parts: List[str] = []
 
 
133
  for i, seg in enumerate(segs):
134
  w = (float(sims[i]) - s_min) / denom # 0..1
 
135
 
136
- # Strong highlight for closest segments, softer for others
137
- alpha = (0.18 + 0.62 * w) if i in keep else (0.06 + 0.20 * w)
138
  alpha = max(0.05, min(alpha, 0.82))
139
  border_alpha = max(0.10, min(alpha * 0.8, 0.65))
140
 
@@ -147,12 +144,13 @@ def semantic_highlight_segments_html(
147
  "display: inline;"
148
  )
149
  html_parts.append(f'<span style="{style}">{escape_html(seg)}</span> ')
 
150
 
151
  html = "".join(html_parts).strip()
152
  if not html:
153
  html = escape_html(arabic_clean)
154
 
155
- return html, s_min, s_max
156
 
157
 
158
  # =========================
@@ -164,14 +162,9 @@ if not os.path.exists(INDEX_PATH):
164
  if not os.path.exists(META_PATH):
165
  raise FileNotFoundError(f"Meta parquet not found: {META_PATH}")
166
 
167
- print(f"[BOOT] Loading model: {MODEL_NAME}")
168
  model = SentenceTransformer(MODEL_NAME)
169
-
170
- print(f"[BOOT] Loading faiss index: {INDEX_PATH}")
171
  index = faiss.read_index(INDEX_PATH)
172
-
173
- print(f"[BOOT] Loading meta: {META_PATH}")
174
- meta = pd.read_parquet(META_PATH)
175
 
176
  required_cols = {"hadithID", "collection", "hadith_number", "arabic", "english"}
177
  missing = required_cols - set(meta.columns)
@@ -181,52 +174,33 @@ if missing:
181
  if "arabic_clean" not in meta.columns:
182
  meta["arabic_clean"] = ""
183
 
184
- # normalize types lightly
185
- meta["arabic"] = meta["arabic"].fillna("").astype(str)
186
- meta["english"] = meta["english"].fillna("").astype(str)
187
- meta["collection"] = meta["collection"].fillna("").astype(str)
188
-
189
 
190
- def semantic_search_df(query: str, top_k: int) -> pd.DataFrame:
191
  q = str(query or "").strip()
192
  if not q:
193
  return meta.iloc[0:0].copy()
194
 
195
  top_k = max(1, min(int(top_k), MAX_TOP_K))
196
- q_norm = normalize_ar(q)
197
 
 
198
  q_emb = model.encode(["query: " + q_norm], normalize_embeddings=True).astype("float32")
 
199
  scores, idx = index.search(q_emb, top_k)
200
 
201
  res = meta.iloc[idx[0]].copy()
202
  res["score"] = scores[0]
203
  res = res.sort_values("score", ascending=False)
204
 
205
- # filter empty arabic rows (avoid empty cards)
206
  res["arabic"] = res["arabic"].fillna("").astype(str)
207
  res = res[res["arabic"].str.strip() != ""]
208
-
209
  return res
210
 
211
 
212
  # =========================
213
- # Flask API
214
  # =========================
215
  app = Flask(__name__)
216
- CORS(app, resources={r"/*": {"origins": "*"}})
217
-
218
-
219
- @app.get("/")
220
- def root():
221
- return jsonify({
222
- "ok": True,
223
- "service": "hadith semantic search",
224
- "endpoints": {
225
- "health": "/health",
226
- "search": "/search?q=...&k=10&hl_topn=6&seg_maxlen=220"
227
- }
228
- })
229
-
230
 
231
  @app.get("/health")
232
  def health():
@@ -234,34 +208,29 @@ def health():
234
  "ok": True,
235
  "model": MODEL_NAME,
236
  "rows": int(len(meta)),
237
- "index_ntotal": int(getattr(index, "ntotal", -1)),
238
  })
239
 
240
-
241
  @app.get("/search")
242
  def search():
243
  q = request.args.get("q", "").strip()
244
 
245
- # topK
246
- k_raw = request.args.get("k", str(DEFAULT_TOP_K)).strip()
247
  try:
248
- k_int = int(k_raw) if k_raw else DEFAULT_TOP_K
249
  except Exception:
250
- k_int = DEFAULT_TOP_K
251
- k_int = min(max(1, k_int), MAX_TOP_K)
252
-
253
- # highlight knobs
254
- hl_raw = request.args.get("hl_topn", str(DEFAULT_HL_TOPN)).strip()
255
- seg_raw = request.args.get("seg_maxlen", str(DEFAULT_SEG_MAXLEN)).strip()
256
 
 
257
  try:
258
- hl_topn = int(hl_raw) if hl_raw else DEFAULT_HL_TOPN
259
  except Exception:
260
  hl_topn = DEFAULT_HL_TOPN
261
  hl_topn = min(max(1, hl_topn), MAX_HL_TOPN)
262
 
263
  try:
264
- seg_maxlen = int(seg_raw) if seg_raw else DEFAULT_SEG_MAXLEN
265
  except Exception:
266
  seg_maxlen = DEFAULT_SEG_MAXLEN
267
  seg_maxlen = min(max(120, seg_maxlen), MAX_SEG_MAXLEN)
@@ -271,14 +240,15 @@ def search():
271
  "ok": True,
272
  "query": "",
273
  "query_norm": "",
274
- "k": k_int,
275
- "rows": int(len(meta)),
 
276
  "took_ms": 0,
277
- "results": [],
278
  })
279
 
280
  t0 = time.time()
281
- res_df = semantic_search_df(q, top_k=k_int)
282
  took_ms = int((time.time() - t0) * 1000)
283
 
284
  q_norm = normalize_ar(q)
@@ -287,66 +257,48 @@ def search():
287
  for _, row in res_df.iterrows():
288
  r = row.to_dict()
289
 
290
- arabic = str(r.get("arabic", "") or "")
291
- english = str(r.get("english", "") or "")
292
 
293
  arabic_clean_val = r.get("arabic_clean", "")
294
- if arabic_clean_val is None:
295
- arabic_clean_val = ""
296
- # handle NaN
297
- if isinstance(arabic_clean_val, float) and np.isnan(arabic_clean_val):
298
  arabic_clean_val = ""
299
  arabic_clean = str(arabic_clean_val).strip()
300
  if not arabic_clean:
301
- arabic_clean = normalize_ar(arabic)
302
 
303
- # ✅ semantic highlight segments (returns HTML spans)
304
- arabic_clean_html, s_min, s_max = semantic_highlight_segments_html(
305
  model=model,
306
  query_norm=q_norm,
307
  arabic_clean=arabic_clean,
308
  top_n=hl_topn,
309
- seg_max_len=seg_maxlen
310
  )
311
 
312
- # final fallback never empty
313
- if not str(arabic_clean_html).strip():
314
- arabic_clean_html = escape_html(arabic_clean if arabic_clean else arabic)
315
-
316
  out.append({
 
317
  "hadithID": int(r.get("hadithID")),
318
  "collection": str(r.get("collection", "")),
319
  "hadith_number": int(r.get("hadith_number")),
320
- "score": float(r.get("score", 0.0)),
321
-
322
- "arabic": arabic,
323
  "arabic_clean": arabic_clean,
324
- "english": english,
325
-
326
- # HTML-ready fields
327
- "arabic_clean_html": arabic_clean_html,
328
- "arabic_html": escape_html(arabic),
329
- "english_html": escape_html(english),
330
-
331
- # optional stats
332
- "hl_min": float(s_min),
333
- "hl_max": float(s_max),
334
  })
335
 
336
  return jsonify({
337
  "ok": True,
338
  "query": q,
339
  "query_norm": q_norm,
340
- "k": k_int,
341
  "hl_topn": hl_topn,
342
  "seg_maxlen": seg_maxlen,
343
- "rows": int(len(meta)),
344
  "took_ms": took_ms,
345
- "results": out,
346
  })
347
 
348
 
349
- # HuggingFace Docker runs via CMD (gunicorn/uvicorn) عادة
350
- # لكن هذا مفيد لو شغّلته محلياً:
351
  if __name__ == "__main__":
352
- app.run(host="0.0.0.0", port=int(os.getenv("PORT", "7860")), debug=True)
 
8
  import numpy as np
9
  import pandas as pd
10
  import faiss
 
11
  from flask import Flask, request, jsonify
12
  from flask_cors import CORS
13
  from sentence_transformers import SentenceTransformer
 
23
  DEFAULT_TOP_K = 10
24
  MAX_TOP_K = 50
25
 
26
+ DEFAULT_HL_TOPN = 6
27
  MAX_HL_TOPN = 25
28
 
29
+ DEFAULT_SEG_MAXLEN = 220
30
  MAX_SEG_MAXLEN = 420
31
 
32
 
 
41
  """, re.VERBOSE)
42
 
43
  def normalize_ar(text: str) -> str:
 
44
  if text is None:
45
  return ""
46
  text = str(text)
 
67
 
68
 
69
  # =========================
70
+ # Segmenting + semantic highlight
71
  # =========================
72
+ def split_ar_segments(text: str, max_len: int) -> List[str]:
 
73
  if not text:
74
  return []
75
  t = re.sub(r"\s+", " ", str(text)).strip()
 
 
76
  parts = re.split(r"(?<=[\.\!\?؟\،\,\;\:])\s+", t)
77
 
78
  segs: List[str] = []
 
91
  if buf:
92
  segs.append(buf)
93
 
 
94
  if len(segs) <= 1 and len(t) > max_len:
95
  segs = [t[i:i+max_len].strip() for i in range(0, len(t), max_len) if t[i:i+max_len].strip()]
96
 
 
100
  model: SentenceTransformer,
101
  query_norm: str,
102
  arabic_clean: str,
103
+ top_n: int,
104
+ seg_max_len: int
105
+ ) -> Tuple[str, List[Dict[str, Any]]]:
106
  """
107
+ Returns:
108
+ - HTML string with highlighted segments
109
+ - segments_debug: list of {seg, sim, strong}
110
  """
111
  segs = split_ar_segments(arabic_clean, max_len=seg_max_len)
112
  if not segs:
113
+ return escape_html(arabic_clean), []
114
 
 
115
  q_emb = model.encode(["query: " + query_norm], normalize_embeddings=True).astype("float32")
116
  seg_emb = model.encode(["passage: " + s for s in segs], normalize_embeddings=True).astype("float32")
117
 
 
124
  keep = set(order[:min(top_n, len(segs))])
125
 
126
  html_parts: List[str] = []
127
+ dbg: List[Dict[str, Any]] = []
128
+
129
  for i, seg in enumerate(segs):
130
  w = (float(sims[i]) - s_min) / denom # 0..1
131
+ strong = i in keep
132
 
133
+ # Strong highlight for top segments, softer for others
134
+ alpha = (0.18 + 0.62 * w) if strong else (0.06 + 0.20 * w)
135
  alpha = max(0.05, min(alpha, 0.82))
136
  border_alpha = max(0.10, min(alpha * 0.8, 0.65))
137
 
 
144
  "display: inline;"
145
  )
146
  html_parts.append(f'<span style="{style}">{escape_html(seg)}</span> ')
147
+ dbg.append({"seg": seg, "sim": float(sims[i]), "strong": bool(strong)})
148
 
149
  html = "".join(html_parts).strip()
150
  if not html:
151
  html = escape_html(arabic_clean)
152
 
153
+ return html, dbg
154
 
155
 
156
  # =========================
 
162
  if not os.path.exists(META_PATH):
163
  raise FileNotFoundError(f"Meta parquet not found: {META_PATH}")
164
 
 
165
  model = SentenceTransformer(MODEL_NAME)
 
 
166
  index = faiss.read_index(INDEX_PATH)
167
+ meta = pd.read_parquet(META_PATH)
 
 
168
 
169
  required_cols = {"hadithID", "collection", "hadith_number", "arabic", "english"}
170
  missing = required_cols - set(meta.columns)
 
174
  if "arabic_clean" not in meta.columns:
175
  meta["arabic_clean"] = ""
176
 
 
 
 
 
 
177
 
178
+ def semantic_search(query: str, top_k: int) -> pd.DataFrame:
179
  q = str(query or "").strip()
180
  if not q:
181
  return meta.iloc[0:0].copy()
182
 
183
  top_k = max(1, min(int(top_k), MAX_TOP_K))
 
184
 
185
+ q_norm = normalize_ar(q)
186
  q_emb = model.encode(["query: " + q_norm], normalize_embeddings=True).astype("float32")
187
+
188
  scores, idx = index.search(q_emb, top_k)
189
 
190
  res = meta.iloc[idx[0]].copy()
191
  res["score"] = scores[0]
192
  res = res.sort_values("score", ascending=False)
193
 
 
194
  res["arabic"] = res["arabic"].fillna("").astype(str)
195
  res = res[res["arabic"].str.strip() != ""]
 
196
  return res
197
 
198
 
199
  # =========================
200
+ # Flask app (JSON API)
201
  # =========================
202
  app = Flask(__name__)
203
+ CORS(app) # مهم عشان تقدر تناديه من أي هوست (HTML خارجي)
 
 
 
 
 
 
 
 
 
 
 
 
 
204
 
205
  @app.get("/health")
206
  def health():
 
208
  "ok": True,
209
  "model": MODEL_NAME,
210
  "rows": int(len(meta)),
211
+ "index_ntotal": int(index.ntotal),
212
  })
213
 
 
214
  @app.get("/search")
215
  def search():
216
  q = request.args.get("q", "").strip()
217
 
218
+ # k
 
219
  try:
220
+ k = int(request.args.get("k", str(DEFAULT_TOP_K)))
221
  except Exception:
222
+ k = DEFAULT_TOP_K
223
+ k = min(max(1, k), MAX_TOP_K)
 
 
 
 
224
 
225
+ # highlight controls
226
  try:
227
+ hl_topn = int(request.args.get("hl_topn", str(DEFAULT_HL_TOPN)))
228
  except Exception:
229
  hl_topn = DEFAULT_HL_TOPN
230
  hl_topn = min(max(1, hl_topn), MAX_HL_TOPN)
231
 
232
  try:
233
+ seg_maxlen = int(request.args.get("seg_maxlen", str(DEFAULT_SEG_MAXLEN)))
234
  except Exception:
235
  seg_maxlen = DEFAULT_SEG_MAXLEN
236
  seg_maxlen = min(max(120, seg_maxlen), MAX_SEG_MAXLEN)
 
240
  "ok": True,
241
  "query": "",
242
  "query_norm": "",
243
+ "k": k,
244
+ "hl_topn": hl_topn,
245
+ "seg_maxlen": seg_maxlen,
246
  "took_ms": 0,
247
+ "results": []
248
  })
249
 
250
  t0 = time.time()
251
+ res_df = semantic_search(q, top_k=k)
252
  took_ms = int((time.time() - t0) * 1000)
253
 
254
  q_norm = normalize_ar(q)
 
257
  for _, row in res_df.iterrows():
258
  r = row.to_dict()
259
 
260
+ arabic_text = str(r.get("arabic", "") or "")
261
+ english_text = str(r.get("english", "") or "")
262
 
263
  arabic_clean_val = r.get("arabic_clean", "")
264
+ if arabic_clean_val is None or (isinstance(arabic_clean_val, float) and np.isnan(arabic_clean_val)):
 
 
 
265
  arabic_clean_val = ""
266
  arabic_clean = str(arabic_clean_val).strip()
267
  if not arabic_clean:
268
+ arabic_clean = normalize_ar(arabic_text)
269
 
270
+ # ✅ هنا الهايلايت الدلالي مثل كودك
271
+ arabic_highlight_html, _dbg = semantic_highlight_segments_html(
272
  model=model,
273
  query_norm=q_norm,
274
  arabic_clean=arabic_clean,
275
  top_n=hl_topn,
276
+ seg_max_len=seg_maxlen,
277
  )
278
 
 
 
 
 
279
  out.append({
280
+ "score": float(r.get("score", 0.0)),
281
  "hadithID": int(r.get("hadithID")),
282
  "collection": str(r.get("collection", "")),
283
  "hadith_number": int(r.get("hadith_number")),
284
+ "arabic": arabic_text,
 
 
285
  "arabic_clean": arabic_clean,
286
+ "arabic_highlight_html": arabic_highlight_html, # ✅ أهم شيء
287
+ "english": english_text,
 
 
 
 
 
 
 
 
288
  })
289
 
290
  return jsonify({
291
  "ok": True,
292
  "query": q,
293
  "query_norm": q_norm,
294
+ "k": k,
295
  "hl_topn": hl_topn,
296
  "seg_maxlen": seg_maxlen,
 
297
  "took_ms": took_ms,
298
+ "results": out
299
  })
300
 
301
 
302
+ # HF Spaces runs with gunicorn; locally:
 
303
  if __name__ == "__main__":
304
+ app.run(host="0.0.0.0", port=7860, debug=True)