Spaces:
Sleeping
Sleeping
| import os | |
| import json | |
| import time | |
| from typing import List, Dict, Any, Optional | |
| import numpy as np | |
| import faiss | |
| from fastapi import FastAPI, HTTPException | |
| from fastapi.middleware.cors import CORSMiddleware | |
| from pydantic import BaseModel | |
| from sentence_transformers import SentenceTransformer | |
# -----------------------------
# Paths
# -----------------------------
# Everything is resolved relative to this file, not the working directory.
BASE_DIR = os.path.dirname(os.path.abspath(__file__))
JSON_PATH = os.path.join(BASE_DIR, "hadith_corpus25k.json")

# Cached build artifacts (FAISS index, raw embeddings, id <-> position maps).
ART_DIR = os.path.join(BASE_DIR, "artifacts_hadith_faiss")
INDEX_PATH = os.path.join(ART_DIR, "faiss.index")
# np.save appends ".npy" to a filename that lacks it, so this path is kept
# WITHOUT the extension; the file on disk is "embeddings.npy".
EMB_PATH = os.path.join(ART_DIR, "embeddings")
ID_BY_POS_PATH = os.path.join(ART_DIR, "id_by_pos.json")
POS_BY_ID_PATH = os.path.join(ART_DIR, "pos_by_id.json")

# Settings (each one can be overridden through an environment variable)
MODEL_NAME = os.environ.get("MODEL_NAME", "intfloat/multilingual-e5-base")
BATCH_SIZE = int(os.environ.get("BATCH_SIZE", "64"))  # encoder batch size used when building
TOPK_MAX = int(os.environ.get("TOPK_MAX", "50"))  # upper bound applied to requested top-k
# -----------------------------
# App
# -----------------------------
app = FastAPI(version="1.0", title="Hadith FAISS API")

# "*" lets any origin call the API. To lock it down to your own site only,
# replace "*" with that site's origin(s).
# NOTE(review): the CORS/Fetch spec does not allow a wildcard origin together
# with credentials; confirm allow_credentials=True is actually required.
app.add_middleware(
    CORSMiddleware,
    allow_headers=["*"],
    allow_methods=["*"],
    allow_credentials=True,
    allow_origins=["*"],
)
# -----------------------------
# Globals (loaded at startup)
# -----------------------------
# Becomes True only after on_startup() finishes without an exception.
_READY: bool = False

# Dataset
_items: List[Dict[str, Any]] = []  # raw records in file order
_item_by_id: Dict[int, Dict[str, Any]] = {}  # corpusID -> record

# Encoder; created lazily by get_model()
_model: Optional[SentenceTransformer] = None

# Search artifacts (loaded from disk or built by build_all())
_index: Optional[faiss.Index] = None
_emb: Optional[np.ndarray] = None
_id_by_pos: List[int] = []  # index position -> corpusID
_pos_by_id: Dict[int, int] = {}  # corpusID -> index position
_DIM: int = 0  # embedding dimension (0 until artifacts are available)
# -----------------------------
# Helpers
# -----------------------------
def build_text(x: Dict[str, Any]) -> str:
    """Return the text that is embedded for a single corpus record.

    The Arabic side prefers ``arabic_clean`` and falls back to ``arabic``.
    Both sides are whitespace-stripped; when both are non-empty they are
    joined with " [SEP] ", otherwise the non-empty one is returned, or ""
    when neither has text.
    """
    sides = (
        (x.get("arabic_clean") or x.get("arabic") or "").strip(),
        (x.get("english") or "").strip(),
    )
    return " [SEP] ".join(side for side in sides if side)
def ensure_dirs():
    """Create the artifacts directory (and any missing parents); no-op if it exists."""
    os.makedirs(name=ART_DIR, exist_ok=True)
def artifacts_exist() -> bool:
    """Return True only when all four cached artifact files are on disk.

    Only presence is checked, not whether the files are readable or
    consistent with each other.
    """
    required_files = (INDEX_PATH, EMB_PATH + ".npy", ID_BY_POS_PATH, POS_BY_ID_PATH)
    return all(os.path.exists(path) for path in required_files)
def load_items():
    """Load the dataset JSON into ``_items`` and build the ``_item_by_id`` lookup.

    Records lacking a ``corpusID`` stay in ``_items`` but are left out of
    ``_item_by_id``. When ids repeat, the later record wins.

    Raises:
        RuntimeError: if the dataset file does not exist.
    """
    global _items, _item_by_id
    if not os.path.exists(JSON_PATH):
        raise RuntimeError(f"Missing dataset file: {JSON_PATH}")

    with open(JSON_PATH, encoding="utf-8") as fh:
        _items = json.load(fh)

    _item_by_id = {}
    for record in _items:
        corpus_id = record.get("corpusID")
        if corpus_id is None:
            continue
        _item_by_id[int(corpus_id)] = record
def get_model() -> SentenceTransformer:
    """Return the shared encoder, instantiating ``MODEL_NAME`` on first use."""
    global _model
    if _model is not None:
        return _model
    _model = SentenceTransformer(MODEL_NAME)
    return _model
def _write_atomically(final_path: str, writer) -> None:
    """Call ``writer(tmp_path)`` and then move the finished file onto *final_path*.

    ``os.replace`` swaps the file in one step, so a crash mid-write can never
    leave a truncated artifact at *final_path*. If that happened,
    ``artifacts_exist()`` would report True and every later startup would fail
    to load it.
    """
    tmp_path = final_path + ".tmp"
    try:
        writer(tmp_path)
        os.replace(tmp_path, final_path)
    finally:
        # Only still present when writer/os.replace failed; don't leave debris.
        if os.path.exists(tmp_path):
            os.remove(tmp_path)


def save_artifacts(
    index: faiss.Index,
    emb: np.ndarray,
    id_by_pos: List[int],
    pos_by_id: Dict[int, int],
):
    """Persist the FAISS index, embeddings and id <-> position maps under ART_DIR.

    Each file is written to a temporary name and then atomically renamed
    into place. The JSON maps store ints; ``pos_by_id`` keys become strings
    because JSON object keys must be strings (load_artifacts converts them back).
    """
    ensure_dirs()

    def write_index(path: str) -> None:
        faiss.write_index(index, path)

    def write_embeddings(path: str) -> None:
        # Save through a file handle: given a filename, np.save would append
        # ".npy" to the temporary name. The final file is still embeddings.npy.
        with open(path, "wb") as f:
            np.save(f, emb)

    def json_writer(payload):
        def write(path: str) -> None:
            with open(path, "w", encoding="utf-8") as f:
                json.dump(payload, f, ensure_ascii=False)
        return write

    _write_atomically(INDEX_PATH, write_index)
    _write_atomically(EMB_PATH + ".npy", write_embeddings)
    _write_atomically(ID_BY_POS_PATH, json_writer([int(x) for x in id_by_pos]))
    _write_atomically(
        POS_BY_ID_PATH,
        json_writer({str(k): int(v) for k, v in pos_by_id.items()}),
    )
def load_artifacts():
    """Load the cached FAISS index, embeddings and id/position maps from disk.

    Everything is read into locals and cross-checked before any global is
    replaced. A stale or mismatched artifact set therefore raises here,
    instead of serving positions that point past ``_id_by_pos``, and the
    module is never left half-loaded.

    Raises:
        RuntimeError: if the artifacts disagree on vector count, dimension,
            or contain out-of-range positions. Unreadable files raise
            whatever faiss / numpy / json raise.
    """
    global _index, _emb, _id_by_pos, _pos_by_id, _DIM
    index = faiss.read_index(INDEX_PATH)
    # copy=False: skip the extra copy when the array is already float32.
    emb = np.load(EMB_PATH + ".npy").astype("float32", copy=False)
    with open(ID_BY_POS_PATH, "r", encoding="utf-8") as f:
        id_by_pos = [int(x) for x in json.load(f)]
    with open(POS_BY_ID_PATH, "r", encoding="utf-8") as f:
        raw = json.load(f)
    # JSON object keys are strings; restore the int corpusIDs.
    pos_by_id = {int(k): int(v) for k, v in raw.items()}

    if emb.ndim != 2:
        raise RuntimeError(f"Embeddings must be a 2-D matrix, got shape {emb.shape}")
    n_rows, dim = emb.shape
    if index.ntotal != n_rows or len(id_by_pos) != n_rows:
        raise RuntimeError(
            "Artifact size mismatch: "
            f"index.ntotal={index.ntotal}, embeddings={n_rows}, id_by_pos={len(id_by_pos)}"
        )
    if index.d != dim:
        raise RuntimeError(f"Artifact dim mismatch: index.d={index.d}, embeddings dim={dim}")
    if any(not 0 <= p < n_rows for p in pos_by_id.values()):
        raise RuntimeError("pos_by_id contains positions outside the embedding matrix")

    _index = index
    _emb = emb
    _id_by_pos = id_by_pos
    _pos_by_id = pos_by_id
    _DIM = int(dim)
def build_all():
    """
    Build embeddings + FAISS then save.
    This should run only if artifacts are missing.

    Only records that carry a ``corpusID`` are indexed. ``load_items`` already
    leaves id-less records out of ``_item_by_id``, so they could never be
    returned; previously they crashed the build with a KeyError.

    Raises:
        RuntimeError: if no record has a ``corpusID``. Without this check an
            empty corpus failed with an opaque IndexError on the embedding shape.
    """
    global _index, _emb, _id_by_pos, _pos_by_id, _DIM
    t0 = time.perf_counter()

    indexed = [x for x in _items if x.get("corpusID") is not None]
    if not indexed:
        raise RuntimeError("No items with a corpusID to index")

    model = get_model()
    # E5 models expect documents prefixed with "passage: " (queries use "query: ").
    passages = ["passage: " + build_text(x) for x in indexed]
    emb = model.encode(
        passages,
        normalize_embeddings=True,
        batch_size=BATCH_SIZE,
        show_progress_bar=True,
    )
    emb = np.asarray(emb, dtype="float32")
    dim = int(emb.shape[1])

    # Vectors are L2-normalised, so inner product equals cosine similarity.
    index = faiss.IndexFlatIP(dim)
    index.add(emb)

    # Row i of emb / position i in the index corresponds to indexed[i].
    id_by_pos = [int(x["corpusID"]) for x in indexed]
    pos_by_id = {cid: i for i, cid in enumerate(id_by_pos)}

    save_artifacts(index, emb, id_by_pos, pos_by_id)

    _index = index
    _emb = emb
    _id_by_pos = id_by_pos
    _pos_by_id = pos_by_id
    _DIM = dim

    dt = time.perf_counter() - t0
    print(f"[build_all] Built + saved artifacts in {dt:.2f}s. dim={_DIM}, n={len(_id_by_pos)}")
def require_ready():
    """Raise HTTP 503 unless startup succeeded and the index and embeddings are loaded."""
    artifacts_loaded = _index is not None and _emb is not None
    if _READY and artifacts_loaded:
        return
    raise HTTPException(status_code=503, detail="API is not ready yet. Try again in a moment.")
def pack_item(it: Dict[str, Any]) -> Dict[str, Any]:
    """Project a raw corpus record onto the public response shape.

    Missing fields come back as None. ``arabic`` prefers the cleaned text
    (``arabic_clean``) and falls back to the raw ``arabic`` field.
    """
    arabic_text = it.get("arabic_clean") or it.get("arabic")
    return dict(
        corpusID=it.get("corpusID"),
        book=it.get("book"),
        chapter=it.get("chapter"),
        arabic=arabic_text,
        english=it.get("english"),
        grade=it.get("grade"),
        meta=it.get("meta"),
    )
def embed_query(q: str) -> np.ndarray:
    """Encode one search string into a normalised float32 matrix with a single row."""
    prefixed = "query: " + q  # E5 models expect queries prefixed with "query: "
    encoded = get_model().encode([prefixed], normalize_embeddings=True)
    return np.asarray(encoded, dtype="float32")
# -----------------------------
# Request Models
# -----------------------------
# JSON body accepted by search(). Comments are used instead of a class
# docstring because pydantic would publish a docstring as the schema description.
class SearchRequest(BaseModel):
    query: str  # free-text query; search() strips it and answers 400 if it is empty
    topk: int = 10  # requested hit count; search() clamps it to [1, TOPK_MAX]
# -----------------------------
# Startup
# -----------------------------
def on_startup():
    """Load the corpus and search artifacts, building the artifacts on first run.

    Any ``Exception`` is caught: the error and its full traceback are printed
    and ``_READY`` stays False. The process keeps running; root()/health()
    report ready=False and require_ready() answers 503.

    NOTE(review): no ``@app.on_event("startup")`` / lifespan registration for
    this function is visible here — confirm it is actually wired up.
    """
    import traceback

    global _READY
    try:
        print("[startup] Loading items...")
        load_items()
        print(f"[startup] Loaded items: {len(_items)}")
        if artifacts_exist():
            print("[startup] Artifacts found. Loading...")
            load_artifacts()
            print(f"[startup] Loaded artifacts: dim={_DIM}, n={len(_id_by_pos)}")
        else:
            print("[startup] Artifacts NOT found. Building now (first run)...")
            build_all()
        _READY = True
        print("[startup] READY ✅")
    except Exception as e:
        _READY = False
        print("[startup] FAILED ❌", str(e))
        # Keep the app up but not ready. The traceback is the only record of
        # where startup broke, so don't discard it.
        traceback.print_exc()
# -----------------------------
# Routes
# -----------------------------
def root():
    """Identify the service and report whether startup has completed.

    NOTE(review): no route decorator is visible for this handler — confirm
    how it is registered on ``app``.
    """
    status = {"name": "Hadith FAISS API"}
    status["ready"] = _READY
    return status
def health():
    """Report readiness plus load diagnostics; answers even before startup finishes."""
    report = dict(
        ready=_READY,
        items=len(_items),
        dim=_DIM,
        has_artifacts=artifacts_exist(),
        model=MODEL_NAME,
    )
    return report
def stats():
    """Describe the loaded corpus and index; answers 503 until the API is ready."""
    require_ready()
    info = {}
    info["items"] = len(_items)
    info["dim"] = _DIM
    info["index_type"] = type(_index).__name__
    info["topk_max"] = TOPK_MAX
    return info
def get_item(corpus_id: int):
    """Return one corpus record in response shape.

    Raises:
        HTTPException: 503 when not ready, 404 when the id is unknown.
    """
    require_ready()
    record = _item_by_id.get(int(corpus_id))
    if record:
        return pack_item(record)
    raise HTTPException(status_code=404, detail="corpusID not found")
def similar(corpus_id: int, topk: int = 10):
    """Return up to ``topk`` items nearest to an item that is already indexed.

    The item's stored (already normalised) embedding is the query vector.
    One extra neighbour is requested because the item normally matches
    itself. Self-hits, negative FAISS placeholder positions and ids absent
    from the dataset are skipped. ``topk`` is clamped to [1, TOPK_MAX].

    Raises:
        HTTPException: 503 when not ready, 404 when ``corpus_id`` is not indexed.
    """
    require_ready()
    cid = int(corpus_id)
    if cid not in _pos_by_id:
        raise HTTPException(status_code=404, detail="corpusID not found in index")
    k = max(1, min(int(topk), TOPK_MAX))

    row = _pos_by_id[cid]
    query_vec = _emb[row:row + 1]  # keep it 2-D: a batch of one stored vector
    distances, positions = _index.search(query_vec, k + 1)  # +1 to skip itself

    hits = []
    for score, position in zip(distances[0].tolist(), positions[0].tolist()):
        if position < 0:
            continue
        hit_id = _id_by_pos[position]
        record = None if hit_id == cid else _item_by_id.get(int(hit_id))
        if not record:
            continue
        hits.append({
            "corpusID": int(hit_id),
            "score": float(score),
            "item": pack_item(record),
        })
        if len(hits) >= k:
            break
    return {"query_id": cid, "topk": k, "results": hits}
def search(req: SearchRequest):
    """Semantic search over the corpus for a free-text query.

    The query is stripped and embedded with the E5 "query: " prefix, and the
    index is searched for ``topk`` hits (clamped to [1, TOPK_MAX]). Negative
    FAISS positions and ids missing from the dataset are dropped, so fewer
    than ``topk`` results may be returned.

    Raises:
        HTTPException: 503 when not ready, 400 when the query is blank.
    """
    require_ready()
    query = (req.query or "").strip()
    if not query:
        raise HTTPException(status_code=400, detail="query is empty")
    k = max(1, min(int(req.topk), TOPK_MAX))

    distances, positions = _index.search(embed_query(query), k)

    results = []
    for score, position in zip(distances[0].tolist(), positions[0].tolist()):
        if position < 0:
            continue
        hit_id = int(_id_by_pos[position])
        record = _item_by_id.get(hit_id)
        if record:
            results.append({
                "corpusID": hit_id,
                "score": float(score),
                "item": pack_item(record),
            })
    return {"query": query, "topk": k, "results": results}