Alshargi committed on
Commit
8662bec
·
verified ·
1 Parent(s): e11a55d

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +40 -26
app.py CHANGED
@@ -20,10 +20,17 @@ JSON_PATH = os.path.join(BASE_DIR, "hadith_corpus25k.json")
20
 
21
  ART_DIR = os.path.join(BASE_DIR, "artifacts_hadith_faiss")
22
  INDEX_PATH = os.path.join(ART_DIR, "faiss.index")
23
- EMB_PATH = os.path.join(ART_DIR, "embeddings.npy")
 
 
24
  ID_BY_POS_PATH = os.path.join(ART_DIR, "id_by_pos.json")
25
  POS_BY_ID_PATH = os.path.join(ART_DIR, "pos_by_id.json")
26
 
 
 
 
 
 
27
  # -----------------------------
28
  # App
29
  # -----------------------------
@@ -31,8 +38,8 @@ app = FastAPI(title="Hadith FAISS API", version="1.0")
31
 
32
  app.add_middleware(
33
  CORSMiddleware,
34
- allow_origins=["*"], # لو تبي تقفلها على دومين موقعك فقط قل لي
35
- allow_credentials=True,
36
  allow_methods=["*"],
37
  allow_headers=["*"],
38
  )
@@ -71,7 +78,7 @@ def ensure_dirs():
71
  def artifacts_exist() -> bool:
72
  return (
73
  os.path.exists(INDEX_PATH)
74
- and os.path.exists(EMB_PATH)
75
  and os.path.exists(ID_BY_POS_PATH)
76
  and os.path.exists(POS_BY_ID_PATH)
77
  )
@@ -85,19 +92,21 @@ def load_items():
85
  with open(JSON_PATH, "r", encoding="utf-8") as f:
86
  _items = json.load(f)
87
 
88
- # Build id map
 
 
89
  _item_by_id = {}
90
  for it in _items:
91
  cid = it.get("corpusID")
92
- if cid is not None:
93
- _item_by_id[int(cid)] = it
 
94
 
95
 
96
  def get_model() -> SentenceTransformer:
97
  global _model
98
  if _model is None:
99
- # intfloat/multilingual-e5-base
100
- _model = SentenceTransformer("intfloat/multilingual-e5-base")
101
  return _model
102
 
103
 
@@ -105,12 +114,11 @@ def save_artifacts(index: faiss.Index, emb: np.ndarray, id_by_pos: List[int], po
105
  ensure_dirs()
106
 
107
  faiss.write_index(index, INDEX_PATH)
108
- np.save(EMB_PATH, emb)
109
 
110
  with open(ID_BY_POS_PATH, "w", encoding="utf-8") as f:
111
  json.dump(id_by_pos, f, ensure_ascii=False)
112
 
113
- # keys must be str in json; we convert to str
114
  pos_by_id_str = {str(k): int(v) for k, v in pos_by_id.items()}
115
  with open(POS_BY_ID_PATH, "w", encoding="utf-8") as f:
116
  json.dump(pos_by_id_str, f, ensure_ascii=False)
@@ -120,7 +128,7 @@ def load_artifacts():
120
  global _index, _emb, _id_by_pos, _pos_by_id, _DIM
121
 
122
  _index = faiss.read_index(INDEX_PATH)
123
- _emb = np.load(EMB_PATH).astype("float32")
124
 
125
  with open(ID_BY_POS_PATH, "r", encoding="utf-8") as f:
126
  _id_by_pos = [int(x) for x in json.load(f)]
@@ -142,14 +150,12 @@ def build_all():
142
 
143
  model = get_model()
144
  texts = [build_text(x) for x in _items]
145
-
146
- # E5 recommends prefixes
147
- passages = ["passage: " + t for t in texts]
148
 
149
  emb = model.encode(
150
  passages,
151
  normalize_embeddings=True,
152
- batch_size=64,
153
  show_progress_bar=True,
154
  )
155
  emb = np.asarray(emb, dtype="float32")
@@ -158,7 +164,13 @@ def build_all():
158
  index = faiss.IndexFlatIP(dim) # cosine via IP since normalized
159
  index.add(emb)
160
 
161
- id_by_pos = [int(x["corpusID"]) for x in _items]
 
 
 
 
 
 
162
  pos_by_id = {cid: i for i, cid in enumerate(id_by_pos)}
163
 
164
  save_artifacts(index, emb, id_by_pos, pos_by_id)
@@ -174,12 +186,11 @@ def build_all():
174
 
175
 
176
  def require_ready():
177
- if not _READY or _index is None or _emb is None:
178
  raise HTTPException(status_code=503, detail="API is not ready yet. Try again in a moment.")
179
 
180
 
181
  def pack_item(it: Dict[str, Any]) -> Dict[str, Any]:
182
  # return only what you need (خفيف)
183
  return {
184
  "corpusID": it.get("corpusID"),
185
  "book": it.get("book"),
@@ -193,8 +204,7 @@ def pack_item(it: Dict[str, Any]) -> Dict[str, Any]:
193
 
194
  def embed_query(q: str) -> np.ndarray:
195
  model = get_model()
196
- # E5 query prefix:
197
- vec = model.encode(["query: " + q], normalize_embeddings=True)
198
  return np.asarray(vec, dtype="float32")
199
 
200
 
@@ -231,7 +241,6 @@ def on_startup():
231
  except Exception as e:
232
  _READY = False
233
  print("[startup] FAILED ❌", str(e))
234
- # keep app up but not ready
235
 
236
 
237
  # -----------------------------
@@ -254,6 +263,7 @@ def stats():
254
  "items": len(_items),
255
  "dim": _DIM,
256
  "index_type": type(_index).__name__,
 
257
  }
258
 
259
 
@@ -269,14 +279,17 @@ def get_item(corpus_id: int):
269
  @app.get("/similar/{corpus_id}")
270
  def similar(corpus_id: int, topk: int = 10):
271
  require_ready()
 
272
  cid = int(corpus_id)
273
  if cid not in _pos_by_id:
274
  raise HTTPException(status_code=404, detail="corpusID not found in index")
275
 
 
 
276
  pos = _pos_by_id[cid]
277
  q = _emb[pos:pos + 1] # already normalized
278
 
279
- scores, idxs = _index.search(q, int(topk) + 1) # +1 to skip itself
280
  scores = scores[0].tolist()
281
  idxs = idxs[0].tolist()
282
 
@@ -295,20 +308,21 @@ def similar(corpus_id: int, topk: int = 10):
295
  "score": float(sc),
296
  "item": pack_item(it),
297
  })
298
- if len(results) >= int(topk):
299
  break
300
 
301
- return {"query_id": cid, "topk": int(topk), "results": results}
302
 
303
 
304
  @app.post("/search")
305
  def search(req: SearchRequest):
306
  require_ready()
 
307
  q = (req.query or "").strip()
308
  if not q:
309
  raise HTTPException(status_code=400, detail="query is empty")
310
 
311
- topk = max(1, min(int(req.topk), 50))
312
 
313
  qv = embed_query(q)
314
  scores, idxs = _index.search(qv, topk)
 
20
 
21
  ART_DIR = os.path.join(BASE_DIR, "artifacts_hadith_faiss")
22
  INDEX_PATH = os.path.join(ART_DIR, "faiss.index")
23
+
24
+ # IMPORTANT: np.save adds ".npy" if not present; keep path WITHOUT extension
25
+ EMB_PATH = os.path.join(ART_DIR, "embeddings") # will produce embeddings.npy
26
  ID_BY_POS_PATH = os.path.join(ART_DIR, "id_by_pos.json")
27
  POS_BY_ID_PATH = os.path.join(ART_DIR, "pos_by_id.json")
28
 
29
+ # Settings
30
+ MODEL_NAME = os.getenv("MODEL_NAME", "intfloat/multilingual-e5-base")
31
+ BATCH_SIZE = int(os.getenv("BATCH_SIZE", "64"))
32
+ TOPK_MAX = int(os.getenv("TOPK_MAX", "50"))
33
+
34
  # -----------------------------
35
  # App
36
  # -----------------------------
 
38
 
39
  app.add_middleware(
40
  CORSMiddleware,
41
+ allow_origins=["*"], # لاحقاً: استبدلها بدومين موقعك للأمان
42
+ allow_credentials=False,
43
  allow_methods=["*"],
44
  allow_headers=["*"],
45
  )
 
78
  def artifacts_exist() -> bool:
79
  return (
80
  os.path.exists(INDEX_PATH)
81
+ and os.path.exists(EMB_PATH + ".npy")
82
  and os.path.exists(ID_BY_POS_PATH)
83
  and os.path.exists(POS_BY_ID_PATH)
84
  )
 
92
  with open(JSON_PATH, "r", encoding="utf-8") as f:
93
  _items = json.load(f)
94
 
95
+ if not isinstance(_items, list):
96
+ raise RuntimeError("Dataset JSON root must be a list")
97
+
98
  _item_by_id = {}
99
  for it in _items:
100
  cid = it.get("corpusID")
101
+ if cid is None:
102
+ continue
103
+ _item_by_id[int(cid)] = it
104
 
105
 
106
  def get_model() -> SentenceTransformer:
107
  global _model
108
  if _model is None:
109
+ _model = SentenceTransformer(MODEL_NAME)
 
110
  return _model
111
 
112
 
 
114
  ensure_dirs()
115
 
116
  faiss.write_index(index, INDEX_PATH)
117
+ np.save(EMB_PATH, emb) # creates EMB_PATH + ".npy"
118
 
119
  with open(ID_BY_POS_PATH, "w", encoding="utf-8") as f:
120
  json.dump(id_by_pos, f, ensure_ascii=False)
121
 
 
122
  pos_by_id_str = {str(k): int(v) for k, v in pos_by_id.items()}
123
  with open(POS_BY_ID_PATH, "w", encoding="utf-8") as f:
124
  json.dump(pos_by_id_str, f, ensure_ascii=False)
 
128
  global _index, _emb, _id_by_pos, _pos_by_id, _DIM
129
 
130
  _index = faiss.read_index(INDEX_PATH)
131
+ _emb = np.load(EMB_PATH + ".npy").astype("float32", copy=False)
132
 
133
  with open(ID_BY_POS_PATH, "r", encoding="utf-8") as f:
134
  _id_by_pos = [int(x) for x in json.load(f)]
 
150
 
151
  model = get_model()
152
  texts = [build_text(x) for x in _items]
153
+ passages = ["passage: " + t for t in texts] # E5 passage prefix
 
 
154
 
155
  emb = model.encode(
156
  passages,
157
  normalize_embeddings=True,
158
+ batch_size=BATCH_SIZE,
159
  show_progress_bar=True,
160
  )
161
  emb = np.asarray(emb, dtype="float32")
 
164
  index = faiss.IndexFlatIP(dim) # cosine via IP since normalized
165
  index.add(emb)
166
 
167
+ # Build ID mappings
168
+ id_by_pos = []
169
+ for x in _items:
170
+ if "corpusID" not in x:
171
+ raise RuntimeError("Each item must have corpusID")
172
+ id_by_pos.append(int(x["corpusID"]))
173
+
174
  pos_by_id = {cid: i for i, cid in enumerate(id_by_pos)}
175
 
176
  save_artifacts(index, emb, id_by_pos, pos_by_id)
 
186
 
187
 
188
  def require_ready():
189
+ if (not _READY) or (_index is None) or (_emb is None):
190
  raise HTTPException(status_code=503, detail="API is not ready yet. Try again in a moment.")
191
 
192
 
193
  def pack_item(it: Dict[str, Any]) -> Dict[str, Any]:
 
194
  return {
195
  "corpusID": it.get("corpusID"),
196
  "book": it.get("book"),
 
204
 
205
  def embed_query(q: str) -> np.ndarray:
206
  model = get_model()
207
+ vec = model.encode(["query: " + q], normalize_embeddings=True) # E5 query prefix
 
208
  return np.asarray(vec, dtype="float32")
209
 
210
 
 
241
  except Exception as e:
242
  _READY = False
243
  print("[startup] FAILED ❌", str(e))
 
244
 
245
 
246
  # -----------------------------
 
263
  "items": len(_items),
264
  "dim": _DIM,
265
  "index_type": type(_index).__name__,
266
+ "model": MODEL_NAME,
267
  }
268
 
269
 
 
279
  @app.get("/similar/{corpus_id}")
280
  def similar(corpus_id: int, topk: int = 10):
281
  require_ready()
282
+
283
  cid = int(corpus_id)
284
  if cid not in _pos_by_id:
285
  raise HTTPException(status_code=404, detail="corpusID not found in index")
286
 
287
+ topk = max(1, min(int(topk), TOPK_MAX))
288
+
289
  pos = _pos_by_id[cid]
290
  q = _emb[pos:pos + 1] # already normalized
291
 
292
+ scores, idxs = _index.search(q, topk + 1) # +1 to skip itself
293
  scores = scores[0].tolist()
294
  idxs = idxs[0].tolist()
295
 
 
308
  "score": float(sc),
309
  "item": pack_item(it),
310
  })
311
+ if len(results) >= topk:
312
  break
313
 
314
+ return {"query_id": cid, "topk": topk, "results": results}
315
 
316
 
317
  @app.post("/search")
318
  def search(req: SearchRequest):
319
  require_ready()
320
+
321
  q = (req.query or "").strip()
322
  if not q:
323
  raise HTTPException(status_code=400, detail="query is empty")
324
 
325
+ topk = max(1, min(int(req.topk), TOPK_MAX))
326
 
327
  qv = embed_query(q)
328
  scores, idxs = _index.search(qv, topk)