Alshargi committed on
Commit
50fe70f
·
verified ·
1 Parent(s): d293589

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +102 -55
app.py CHANGED
@@ -3,7 +3,7 @@ from __future__ import annotations
3
  import os
4
  import re
5
  import time
6
- from typing import Any, Dict, List, Optional
7
 
8
  import numpy as np
9
  import pandas as pd
@@ -23,12 +23,9 @@ MODEL_NAME = os.getenv("HADITH_MODEL_NAME", "intfloat/multilingual-e5-base")
23
  DEFAULT_TOP_K = 10
24
  MAX_TOP_K = 50
25
 
26
- # If you want a smaller response payload
27
- DEFAULT_INCLUDE_TEXT = True
28
-
29
 
30
  # =========================
31
- # Arabic normalization
32
  # =========================
33
  _AR_DIACRITICS = re.compile(r"""
34
  [\u0610-\u061A]
@@ -38,7 +35,6 @@ _AR_DIACRITICS = re.compile(r"""
38
  """, re.VERBOSE)
39
 
40
  def normalize_ar(text: str) -> str:
41
- """Remove tashkeel + normalize common Arabic letter variants."""
42
  if text is None:
43
  return ""
44
  text = str(text)
@@ -53,33 +49,50 @@ def normalize_ar(text: str) -> str:
53
 
54
 
55
  # =========================
56
- # Load model + index + meta (once)
57
  # =========================
58
- if not os.path.exists(INDEX_PATH):
59
- raise FileNotFoundError(f"FAISS index not found: {INDEX_PATH}")
 
 
 
 
 
 
 
 
 
 
 
 
 
60
 
61
- if not os.path.exists(META_PATH):
62
- raise FileNotFoundError(f"Meta parquet not found: {META_PATH}")
 
63
 
64
- model = SentenceTransformer(MODEL_NAME)
65
- index = faiss.read_index(INDEX_PATH)
66
- meta = pd.read_parquet(META_PATH)
 
67
 
68
- required_cols = {"hadithID", "collection", "hadith_number", "arabic", "english"}
69
- missing = required_cols - set(meta.columns)
70
- if missing:
71
- raise ValueError(f"Meta is missing required columns: {missing}")
72
 
73
- if "arabic_clean" not in meta.columns:
74
- meta["arabic_clean"] = ""
 
 
75
 
76
- # Normalize column types to avoid NaN surprises
77
- for col in ["arabic", "english", "arabic_clean", "collection"]:
78
- if col in meta.columns:
79
- meta[col] = meta[col].fillna("").astype(str)
80
 
81
 
 
 
 
82
  def semantic_search(query: str, top_k: int = DEFAULT_TOP_K) -> pd.DataFrame:
 
 
83
  q = str(query or "").strip()
84
  if not q:
85
  return meta.iloc[0:0].copy()
@@ -88,13 +101,14 @@ def semantic_search(query: str, top_k: int = DEFAULT_TOP_K) -> pd.DataFrame:
88
 
89
  q_norm = normalize_ar(q)
90
  q_emb = model.encode(["query: " + q_norm], normalize_embeddings=True).astype("float32")
 
91
  scores, idx = index.search(q_emb, top_k)
92
 
93
  res = meta.iloc[idx[0]].copy()
94
  res["score"] = scores[0].astype(float)
95
  res = res.sort_values("score", ascending=False)
96
 
97
- # Ensure no empty Arabic (avoid useless results)
98
  res["arabic"] = res["arabic"].fillna("").astype(str)
99
  res = res[res["arabic"].str.strip() != ""]
100
 
@@ -103,11 +117,13 @@ def semantic_search(query: str, top_k: int = DEFAULT_TOP_K) -> pd.DataFrame:
103
 
104
  def row_to_json(row: pd.Series, include_text: bool = True) -> Dict[str, Any]:
105
  arabic = str(row.get("arabic", "") or "")
 
 
106
  arabic_clean = str(row.get("arabic_clean", "") or "").strip()
107
  if not arabic_clean:
108
  arabic_clean = normalize_ar(arabic)
109
 
110
- base = {
111
  "score": float(row.get("score", 0.0)),
112
  "hadithID": int(row.get("hadithID")),
113
  "collection": str(row.get("collection", "")),
@@ -118,58 +134,88 @@ def row_to_json(row: pd.Series, include_text: bool = True) -> Dict[str, Any]:
118
  base.update({
119
  "arabic": arabic,
120
  "arabic_clean": arabic_clean,
121
- "english": str(row.get("english", "") or ""),
122
  })
123
 
124
  return base
125
 
126
 
127
  # =========================
128
- # Flask API app
129
  # =========================
130
  app = Flask(__name__)
131
- CORS(app, resources={r"/*": {"origins": "*"}}) # allow calls from other hosts
132
 
133
 
134
- @app.get("/health")
135
- def health():
136
  return jsonify({
137
  "ok": True,
138
- "rows": int(len(meta)),
139
- "index_ntotal": int(getattr(index, "ntotal", -1)),
140
- "model": MODEL_NAME
141
  })
142
 
143
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
144
  @app.post("/search")
145
- def search():
146
  """
147
- JSON body:
148
  {
149
- "q": "الزرق و سبيل الرزق",
150
  "k": 10,
151
  "include_text": true
152
  }
153
  """
154
  payload = request.get_json(silent=True) or {}
155
- q = (payload.get("q") or "").strip()
156
- k = payload.get("k", DEFAULT_TOP_K)
157
- include_text = payload.get("include_text", DEFAULT_INCLUDE_TEXT)
158
 
159
- # Validate
160
  if not q:
161
  return jsonify({"ok": False, "error": "Missing 'q'"}), 400
 
 
162
  try:
163
  k = int(k)
164
  except Exception:
165
  k = DEFAULT_TOP_K
166
  k = max(1, min(k, MAX_TOP_K))
167
 
 
 
 
168
  t0 = time.time()
169
- res_df = semantic_search(q, top_k=k)
 
 
 
170
  took_ms = int((time.time() - t0) * 1000)
171
 
172
- results = [row_to_json(r, include_text=bool(include_text)) for _, r in res_df.iterrows()]
173
 
174
  return jsonify({
175
  "ok": True,
@@ -186,34 +232,35 @@ def search():
186
  def search_get():
187
  """
188
  GET /search?q=...&k=10&include_text=1
189
- Useful for quick testing in browser.
190
  """
191
  q = (request.args.get("q") or "").strip()
192
- k = request.args.get("k", str(DEFAULT_TOP_K))
193
- include_text = request.args.get("include_text", "1")
194
-
195
  if not q:
196
  return jsonify({"ok": False, "error": "Missing 'q'"}), 400
197
 
 
198
  try:
199
- k_int = int(k)
200
  except Exception:
201
- k_int = DEFAULT_TOP_K
202
- k_int = max(1, min(k_int, MAX_TOP_K))
203
 
204
- include_text_bool = include_text not in ("0", "false", "False", "")
 
205
 
206
  t0 = time.time()
207
- res_df = semantic_search(q, top_k=k_int)
 
 
 
208
  took_ms = int((time.time() - t0) * 1000)
209
 
210
- results = [row_to_json(r, include_text=include_text_bool) for _, r in res_df.iterrows()]
211
 
212
  return jsonify({
213
  "ok": True,
214
  "query": q,
215
  "query_norm": normalize_ar(q),
216
- "k": k_int,
217
  "took_ms": took_ms,
218
  "results_count": len(results),
219
  "results": results
@@ -221,5 +268,5 @@ def search_get():
221
 
222
 
223
  if __name__ == "__main__":
224
- # For local debug only. On HF Spaces, gunicorn/uvicorn handles it.
225
  app.run(host="0.0.0.0", port=7860, debug=False)
 
3
  import os
4
  import re
5
  import time
6
+ from typing import Any, Dict, List, Optional, Tuple
7
 
8
  import numpy as np
9
  import pandas as pd
 
23
  DEFAULT_TOP_K = 10
24
  MAX_TOP_K = 50
25
 
 
 
 
26
 
27
  # =========================
28
+ # Arabic normalization (remove tashkeel + normalize letters)
29
  # =========================
30
  _AR_DIACRITICS = re.compile(r"""
31
  [\u0610-\u061A]
 
35
  """, re.VERBOSE)
36
 
37
  def normalize_ar(text: str) -> str:
 
38
  if text is None:
39
  return ""
40
  text = str(text)
 
49
 
50
 
51
  # =========================
52
+ # Lazy load (load resources on demand)
53
  # =========================
54
# =========================
# Lazy-loaded shared resources (embedding model / FAISS index / metadata)
# =========================
_model: Optional[SentenceTransformer] = None
_index = None
_meta: Optional[pd.DataFrame] = None


def get_resources() -> Tuple[SentenceTransformer, Any, pd.DataFrame]:
    """Return the (model, index, meta) triple, loading each exactly once.

    On first use this reads the FAISS index and metadata parquet from disk
    and instantiates the sentence-transformer; later calls hit the cached
    module-level globals.

    Raises:
        FileNotFoundError: if the index or metadata file is absent.
        ValueError: if the metadata parquet lacks a required column.
    """
    global _model, _index, _meta

    # Fast path: everything already materialized.
    if not (_model is None or _index is None or _meta is None):
        return _model, _index, _meta

    # Fail early with a clear message if either artifact is missing.
    for path, label in ((INDEX_PATH, "FAISS index"), (META_PATH, "Meta parquet")):
        if not os.path.exists(path):
            raise FileNotFoundError(f"{label} not found: {path}")

    _model = SentenceTransformer(MODEL_NAME)
    _index = faiss.read_index(INDEX_PATH)
    _meta = pd.read_parquet(META_PATH)

    required_cols = {"hadithID", "collection", "hadith_number", "arabic", "english"}
    missing = required_cols - set(_meta.columns)
    if missing:
        raise ValueError(f"Meta is missing required columns: {missing}")

    # Optional column: create it empty so downstream code can rely on it.
    if "arabic_clean" not in _meta.columns:
        _meta["arabic_clean"] = ""

    # Coerce the text columns to str and replace NaN with "".
    for col in ("arabic", "english", "arabic_clean", "collection"):
        if col in _meta.columns:
            _meta[col] = _meta[col].fillna("").astype(str)

    return _model, _index, _meta
 
 
 
88
 
89
 
90
+ # =========================
91
+ # Search
92
+ # =========================
93
  def semantic_search(query: str, top_k: int = DEFAULT_TOP_K) -> pd.DataFrame:
94
+ model, index, meta = get_resources()
95
+
96
  q = str(query or "").strip()
97
  if not q:
98
  return meta.iloc[0:0].copy()
 
101
 
102
  q_norm = normalize_ar(q)
103
  q_emb = model.encode(["query: " + q_norm], normalize_embeddings=True).astype("float32")
104
+
105
  scores, idx = index.search(q_emb, top_k)
106
 
107
  res = meta.iloc[idx[0]].copy()
108
  res["score"] = scores[0].astype(float)
109
  res = res.sort_values("score", ascending=False)
110
 
111
+ # Filter empty arabic just in case
112
  res["arabic"] = res["arabic"].fillna("").astype(str)
113
  res = res[res["arabic"].str.strip() != ""]
114
 
 
117
 
118
  def row_to_json(row: pd.Series, include_text: bool = True) -> Dict[str, Any]:
119
  arabic = str(row.get("arabic", "") or "")
120
+ english = str(row.get("english", "") or "")
121
+
122
  arabic_clean = str(row.get("arabic_clean", "") or "").strip()
123
  if not arabic_clean:
124
  arabic_clean = normalize_ar(arabic)
125
 
126
+ base: Dict[str, Any] = {
127
  "score": float(row.get("score", 0.0)),
128
  "hadithID": int(row.get("hadithID")),
129
  "collection": str(row.get("collection", "")),
 
134
  base.update({
135
  "arabic": arabic,
136
  "arabic_clean": arabic_clean,
137
+ "english": english,
138
  })
139
 
140
  return base
141
 
142
 
143
  # =========================
144
+ # Flask API
145
  # =========================
146
  app = Flask(__name__)
147
+ CORS(app, resources={r"/*": {"origins": "*"}})
148
 
149
 
150
@app.get("/")
def root():
    """Service banner: confirms the API is up and lists its endpoints."""
    banner = {
        "ok": True,
        "service": "hadeeth semantic search api",
        "endpoints": ["/health", "/search (GET/POST)"],
    }
    return jsonify(banner)
157
 
158
 
159
@app.get("/health")
def health():
    """Health probe: report artifact-file presence and loaded-resource counts.

    NOTE(review): calling get_resources() below forces the model/index/meta
    load on the first probe; remove that section if the probe must stay cheap.
    """
    info: Dict[str, Any] = {
        "ok": True,
        "files_ok": os.path.exists(INDEX_PATH) and os.path.exists(META_PATH),
        "index_path": INDEX_PATH,
        "meta_path": META_PATH,
        "model": MODEL_NAME,
    }

    # Loading the resources lets us expose row/vector counts; a failure is
    # reported in the payload rather than as an HTTP error.
    try:
        _, index, meta = get_resources()
        info["rows"] = int(len(meta))
        info["index_ntotal"] = int(getattr(index, "ntotal", -1))
        info["loaded"] = True
    except Exception as exc:
        info["loaded"] = False
        info["load_error"] = str(exc)

    return jsonify(info)
183
+
184
+
185
  @app.post("/search")
186
+ def search_post():
187
  """
188
+ Body JSON:
189
  {
190
+ "q": "الرزق",
191
  "k": 10,
192
  "include_text": true
193
  }
194
  """
195
  payload = request.get_json(silent=True) or {}
 
 
 
196
 
197
+ q = (payload.get("q") or "").strip()
198
  if not q:
199
  return jsonify({"ok": False, "error": "Missing 'q'"}), 400
200
+
201
+ k = payload.get("k", DEFAULT_TOP_K)
202
  try:
203
  k = int(k)
204
  except Exception:
205
  k = DEFAULT_TOP_K
206
  k = max(1, min(k, MAX_TOP_K))
207
 
208
+ include_text = payload.get("include_text", True)
209
+ include_text = bool(include_text)
210
+
211
  t0 = time.time()
212
+ try:
213
+ res_df = semantic_search(q, top_k=k)
214
+ except Exception as e:
215
+ return jsonify({"ok": False, "error": str(e)}), 500
216
  took_ms = int((time.time() - t0) * 1000)
217
 
218
+ results = [row_to_json(r, include_text=include_text) for _, r in res_df.iterrows()]
219
 
220
  return jsonify({
221
  "ok": True,
 
232
  def search_get():
233
  """
234
  GET /search?q=...&k=10&include_text=1
 
235
  """
236
  q = (request.args.get("q") or "").strip()
 
 
 
237
  if not q:
238
  return jsonify({"ok": False, "error": "Missing 'q'"}), 400
239
 
240
+ k_raw = request.args.get("k", str(DEFAULT_TOP_K))
241
  try:
242
+ k = int(k_raw)
243
  except Exception:
244
+ k = DEFAULT_TOP_K
245
+ k = max(1, min(k, MAX_TOP_K))
246
 
247
+ include_text_raw = request.args.get("include_text", "1")
248
+ include_text = include_text_raw not in ("0", "false", "False", "")
249
 
250
  t0 = time.time()
251
+ try:
252
+ res_df = semantic_search(q, top_k=k)
253
+ except Exception as e:
254
+ return jsonify({"ok": False, "error": str(e)}), 500
255
  took_ms = int((time.time() - t0) * 1000)
256
 
257
+ results = [row_to_json(r, include_text=include_text) for _, r in res_df.iterrows()]
258
 
259
  return jsonify({
260
  "ok": True,
261
  "query": q,
262
  "query_norm": normalize_ar(q),
263
+ "k": k,
264
  "took_ms": took_ms,
265
  "results_count": len(results),
266
  "results": results
 
268
 
269
 
270
if __name__ == "__main__":
    # Local dev only: run Flask's built-in server on the HF Spaces default
    # port. In production a WSGI server imports `app` instead of this branch.
    app.run(host="0.0.0.0", port=7860, debug=False)