"""Hadith semantic-search API.

A FastAPI app that embeds the hadiths in ``hadith_corpus25k.json`` with a
SentenceTransformer model (E5-style ``passage:`` / ``query:`` prefixes) and
serves nearest-neighbour lookups from a FAISS inner-product index. Index
artifacts are cached in ``artifacts_hadith_faiss/`` and rebuilt at startup when
they are missing, unreadable, or do not match the current dataset or model.
"""

import json
import os
import time
import traceback
from typing import Any, Callable, Dict, List, Optional

import faiss
import numpy as np
from fastapi import FastAPI, HTTPException
from fastapi.middleware.cors import CORSMiddleware
from pydantic import BaseModel
from sentence_transformers import SentenceTransformer

# -----------------------------
# Paths
# -----------------------------
BASE_DIR = os.path.dirname(os.path.abspath(__file__))
JSON_PATH = os.path.join(BASE_DIR, "hadith_corpus25k.json")
ART_DIR = os.path.join(BASE_DIR, "artifacts_hadith_faiss")
INDEX_PATH = os.path.join(ART_DIR, "faiss.index")
# IMPORTANT: np.save appends ".npy" to a filename that lacks it, so this path is
# kept WITHOUT the extension; the file on disk is "embeddings.npy".
EMB_PATH = os.path.join(ART_DIR, "embeddings")
ID_BY_POS_PATH = os.path.join(ART_DIR, "id_by_pos.json")
POS_BY_ID_PATH = os.path.join(ART_DIR, "pos_by_id.json")
# Optional build metadata (model name, vector count, dim). Artifact sets written
# before this file was introduced are still accepted without it.
META_PATH = os.path.join(ART_DIR, "meta.json")

# Settings
MODEL_NAME = os.getenv("MODEL_NAME", "intfloat/multilingual-e5-base")
BATCH_SIZE = int(os.getenv("BATCH_SIZE", "64"))
TOPK_MAX = int(os.getenv("TOPK_MAX", "50"))
# Comma-separated browser origins allowed by CORS, e.g.
# "https://example.com,https://www.example.com". The default "*" allows any
# origin; set this to lock the API down to your own site's domain(s).
CORS_ORIGINS = [
    o.strip() for o in os.getenv("CORS_ORIGINS", "*").split(",") if o.strip()
] or ["*"]

# -----------------------------
# App
# -----------------------------
app = FastAPI(title="Hadith FAISS API", version="1.0")
app.add_middleware(
    CORSMiddleware,
    allow_origins=CORS_ORIGINS,
    # Credentials are only allowed with an explicit origin list. With "*" plus
    # credentials, Starlette reflects whatever Origin the browser sends, which
    # would let any website make credentialed requests. This API uses no cookies.
    allow_credentials="*" not in CORS_ORIGINS,
    allow_methods=["*"],
    allow_headers=["*"],
)

# -----------------------------
# Globals (loaded at startup)
# -----------------------------
_items: List[Dict[str, Any]] = []  # raw dataset records, in file order
_item_by_id: Dict[int, Dict[str, Any]] = {}  # corpusID -> record
_model: Optional[SentenceTransformer] = None
_index: Optional[faiss.Index] = None
_emb: Optional[np.ndarray] = None  # float32 embeddings, row i == index position i
_id_by_pos: List[int] = []  # index position -> corpusID
_pos_by_id: Dict[int, int] = {}  # corpusID -> index position
_DIM: int = 0
_READY: bool = False


# -----------------------------
# Helpers
# -----------------------------
def build_text(x: Dict[str, Any]) -> str:
    """Return the text to embed for one dataset record.

    Prefers ``arabic_clean`` over ``arabic``. When both Arabic and English text
    are present they are joined with " [SEP] "; otherwise whichever one is
    non-empty is returned (possibly "").
    """
    ar = (x.get("arabic_clean") or x.get("arabic") or "").strip()
    en = (x.get("english") or "").strip()
    if ar and en:
        return ar + " [SEP] " + en
    return ar or en


def ensure_dirs() -> None:
    """Create the artifact directory if it does not already exist."""
    os.makedirs(ART_DIR, exist_ok=True)


def artifacts_exist() -> bool:
    """Return True when all four required artifact files are present on disk."""
    return (
        os.path.exists(INDEX_PATH)
        and os.path.exists(EMB_PATH + ".npy")
        and os.path.exists(ID_BY_POS_PATH)
        and os.path.exists(POS_BY_ID_PATH)
    )


def load_items() -> None:
    """Load the dataset into ``_items`` and index it by ``corpusID``.

    Records without a ``corpusID`` stay in ``_items`` but are not addressable
    through ``_item_by_id``. With duplicate ids, the last record wins.

    Raises:
        RuntimeError: if the dataset file is missing or is not a JSON list.
    """
    global _items, _item_by_id
    if not os.path.exists(JSON_PATH):
        raise RuntimeError(f"Missing dataset file: {JSON_PATH}")
    with open(JSON_PATH, "r", encoding="utf-8") as f:
        data = json.load(f)
    if not isinstance(data, list):
        raise RuntimeError(f"Dataset must be a JSON list of objects: {JSON_PATH}")
    _items = data
    _item_by_id = {
        int(it["corpusID"]): it for it in data if it.get("corpusID") is not None
    }


def _indexable_items() -> List[Dict[str, Any]]:
    """Return the records that carry a ``corpusID``, in dataset order."""
    return [it for it in _items if it.get("corpusID") is not None]


def get_model() -> SentenceTransformer:
    """Return the shared SentenceTransformer, loading it on first use."""
    global _model
    if _model is None:
        _model = SentenceTransformer(MODEL_NAME)
    return _model


def _write_atomic(path: str, write: Callable[[str], None]) -> None:
    """Run ``write(tmp_path)`` and then atomically rename the result to ``path``.

    A crash mid-write leaves at most a stray ``*.tmp`` file, never a truncated
    artifact that a later startup would try to load.
    """
    tmp_path = path + ".tmp"
    write(tmp_path)
    os.replace(tmp_path, path)


def _write_json(path: str, obj: Any) -> None:
    """Write ``obj`` to ``path`` as UTF-8 JSON."""
    with open(path, "w", encoding="utf-8") as f:
        json.dump(obj, f, ensure_ascii=False)


def _write_npy(path: str, arr: np.ndarray) -> None:
    """Write ``arr`` in .npy format to exactly ``path``."""
    # Passing a file object (not a name) stops np.save from appending ".npy".
    with open(path, "wb") as f:
        np.save(f, arr)


def save_artifacts(
    index: faiss.Index,
    emb: np.ndarray,
    id_by_pos: List[int],
    pos_by_id: Dict[int, int],
) -> None:
    """Persist the index, embeddings, id/position maps and build metadata."""
    ensure_dirs()
    _write_atomic(INDEX_PATH, lambda p: faiss.write_index(index, p))
    _write_atomic(EMB_PATH + ".npy", lambda p: _write_npy(p, emb))
    _write_atomic(
        ID_BY_POS_PATH, lambda p: _write_json(p, [int(x) for x in id_by_pos])
    )
    # JSON object keys must be strings; load_artifacts converts them back to int.
    _write_atomic(
        POS_BY_ID_PATH,
        lambda p: _write_json(p, {str(k): int(v) for k, v in pos_by_id.items()}),
    )
    _write_atomic(
        META_PATH,
        lambda p: _write_json(
            p, {"model": MODEL_NAME, "n": len(id_by_pos), "dim": int(emb.shape[1])}
        ),
    )


def load_artifacts() -> None:
    """Load saved artifacts into the module globals.

    The globals are assigned only after every file has been read, so a failed
    load does not leave them half-updated.
    """
    global _index, _emb, _id_by_pos, _pos_by_id, _DIM
    index = faiss.read_index(INDEX_PATH)
    emb = np.load(EMB_PATH + ".npy").astype("float32")
    with open(ID_BY_POS_PATH, "r", encoding="utf-8") as f:
        id_by_pos = [int(x) for x in json.load(f)]
    with open(POS_BY_ID_PATH, "r", encoding="utf-8") as f:
        pos_by_id = {int(k): int(v) for k, v in json.load(f).items()}
    _index, _emb, _id_by_pos, _pos_by_id = index, emb, id_by_pos, pos_by_id
    _DIM = int(emb.shape[1])


def _artifacts_problem() -> Optional[str]:
    """Describe why the loaded artifacts cannot serve current data, or None.

    Detects mismatched artifact files and stale artifacts (dataset edited,
    MODEL_NAME changed). Either would otherwise cause wrong results or FAISS
    dimension errors at query time.
    """
    n = len(_id_by_pos)
    if _index.ntotal != n or _emb.shape[0] != n:
        return (
            f"size mismatch (index={_index.ntotal}, "
            f"embeddings={_emb.shape[0]}, ids={n})"
        )
    if _index.d != _DIM:
        return f"dimension mismatch (index={_index.d}, embeddings={_DIM})"
    if set(_pos_by_id) != set(_id_by_pos) or any(
        not (0 <= p < n) or _id_by_pos[p] != cid for cid, p in _pos_by_id.items()
    ):
        return "pos_by_id.json disagrees with id_by_pos.json"
    if set(_id_by_pos) != set(_item_by_id):
        return "indexed corpusIDs differ from the dataset"
    if os.path.exists(META_PATH):
        with open(META_PATH, "r", encoding="utf-8") as f:
            built_with = json.load(f).get("model")
        if built_with != MODEL_NAME:
            return f"built with model {built_with!r}, configured {MODEL_NAME!r}"
    model_dim = get_model().get_sentence_embedding_dimension()
    if model_dim is not None and int(model_dim) != _DIM:
        return f"model dimension {model_dim} != index dimension {_DIM}"
    return None


def _try_load_artifacts() -> Optional[str]:
    """Load and validate artifacts; return a reason string if they are unusable."""
    try:
        load_artifacts()
        return _artifacts_problem()
    except Exception as e:  # corrupt/truncated files: report and let caller rebuild
        return f"load failed: {e}"


def build_all() -> None:
    """Embed every addressable record, build the FAISS index, and save it.

    Intended to run only when artifacts are missing or unusable. Records without
    a ``corpusID`` are skipped, because the API cannot address them and the
    index positions must map 1:1 onto ids. Populates the same globals as
    load_artifacts().

    Raises:
        RuntimeError: if no record carries a ``corpusID``.
    """
    global _index, _emb, _id_by_pos, _pos_by_id, _DIM

    t0 = time.time()
    model = get_model()

    items = _indexable_items()
    if not items:
        raise RuntimeError("No items with a corpusID to index")

    passages = ["passage: " + build_text(x) for x in items]  # E5 passage prefix
    emb = model.encode(
        passages,
        normalize_embeddings=True,
        batch_size=BATCH_SIZE,
        show_progress_bar=True,
    )
    emb = np.asarray(emb, dtype="float32")
    dim = int(emb.shape[1])

    # Embeddings are L2-normalised, so inner product equals cosine similarity.
    index = faiss.IndexFlatIP(dim)
    index.add(emb)

    id_by_pos = [int(x["corpusID"]) for x in items]
    pos_by_id = {cid: i for i, cid in enumerate(id_by_pos)}

    save_artifacts(index, emb, id_by_pos, pos_by_id)

    _index = index
    _emb = emb
    _id_by_pos = id_by_pos
    _pos_by_id = pos_by_id
    _DIM = dim

    dt = time.time() - t0
    print(f"[build_all] Built + saved artifacts in {dt:.2f}s. dim={_DIM}, n={len(_id_by_pos)}")


def require_ready() -> None:
    """Raise HTTP 503 unless startup finished and the index is loaded."""
    if not _READY or _index is None or _emb is None:
        raise HTTPException(status_code=503, detail="API is not ready yet. Try again in a moment.")


def pack_item(it: Dict[str, Any]) -> Dict[str, Any]:
    """Project a dataset record onto the public response fields."""
    return {
        "corpusID": it.get("corpusID"),
        "book": it.get("book"),
        "chapter": it.get("chapter"),
        "arabic": it.get("arabic_clean") or it.get("arabic"),
        "english": it.get("english"),
        "grade": it.get("grade"),
        "meta": it.get("meta"),
    }


def embed_query(q: str) -> np.ndarray:
    """Embed a search query as a normalised float32 array of shape (1, dim)."""
    model = get_model()
    vec = model.encode(["query: " + q], normalize_embeddings=True)  # E5 query prefix
    return np.asarray(vec, dtype="float32")


def _collect_hits(
    scores: np.ndarray,
    idxs: np.ndarray,
    topk: int,
    exclude_id: Optional[int] = None,
) -> List[Dict[str, Any]]:
    """Turn one row of FAISS results into response entries.

    Skips padding positions, ``exclude_id``, and ids missing from the dataset.
    Stops after ``topk`` entries.
    """
    results: List[Dict[str, Any]] = []
    for sc, p in zip(scores[0].tolist(), idxs[0].tolist()):
        if p < 0:  # FAISS pads with -1 when fewer than k vectors exist
            continue
        hit_id = int(_id_by_pos[p])
        if hit_id == exclude_id:
            continue
        it = _item_by_id.get(hit_id)
        if it is None:
            continue
        results.append({
            "corpusID": hit_id,
            "score": float(sc),
            "item": pack_item(it),
        })
        if len(results) >= topk:
            break
    return results


# -----------------------------
# Request Models
# -----------------------------
class SearchRequest(BaseModel):
    """Body of POST /search; ``topk`` is clamped to [1, TOPK_MAX]."""

    query: str
    topk: int = 10


# -----------------------------
# Startup
# -----------------------------
@app.on_event("startup")
def on_startup():
    """Load data, model and index; rebuild the index if it is absent or stale.

    Any failure is logged with its traceback. The app keeps serving but
    reports not-ready, so /health stays reachable for diagnosis.
    """
    global _READY
    _READY = False
    try:
        print("[startup] Loading items...")
        load_items()
        print(f"[startup] Loaded items: {len(_items)}")

        # Load the model up front. Otherwise the first request pays the load
        # cost, and concurrent first requests could each load their own copy.
        print(f"[startup] Loading model {MODEL_NAME}...")
        get_model()

        if not artifacts_exist():
            print("[startup] Artifacts NOT found. Building now (first run)...")
            build_all()
        else:
            print("[startup] Artifacts found. Loading...")
            problem = _try_load_artifacts()
            if problem is None:
                print(f"[startup] Loaded artifacts: dim={_DIM}, n={len(_id_by_pos)}")
            else:
                print(f"[startup] Artifacts unusable ({problem}). Rebuilding...")
                build_all()

        _READY = True
        print("[startup] READY ✅")
    except Exception as e:
        _READY = False
        print("[startup] FAILED ❌", str(e))
        traceback.print_exc()  # keep app up but not ready; surface the cause


# -----------------------------
# Routes
# -----------------------------
@app.get("/")
def root():
    """Return the service name and readiness flag."""
    return {"name": "Hadith FAISS API", "ready": _READY}


@app.get("/health")
def health():
    """Report readiness and basic sizes; works even when not ready."""
    return {
        "ready": _READY,
        "items": len(_items),
        "dim": _DIM,
        "has_artifacts": artifacts_exist(),
        "model": MODEL_NAME,
    }


@app.get("/stats")
def stats():
    """Return index statistics (503 until ready)."""
    require_ready()
    return {
        "items": len(_items),
        "dim": _DIM,
        "index_type": type(_index).__name__,
        "topk_max": TOPK_MAX,
    }


@app.get("/item/{corpus_id}")
def get_item(corpus_id: int):
    """Return one record by corpusID (404 if unknown)."""
    require_ready()
    it = _item_by_id.get(int(corpus_id))
    if it is None:
        raise HTTPException(status_code=404, detail="corpusID not found")
    return pack_item(it)


@app.get("/similar/{corpus_id}")
def similar(corpus_id: int, topk: int = 10):
    """Return the records nearest to a given record, excluding itself."""
    require_ready()
    cid = int(corpus_id)
    if cid not in _pos_by_id:
        raise HTTPException(status_code=404, detail="corpusID not found in index")

    topk = max(1, min(int(topk), TOPK_MAX))
    pos = _pos_by_id[cid]
    q = _emb[pos:pos + 1]  # stored vectors are already normalised
    scores, idxs = _index.search(q, topk + 1)  # +1 because the item matches itself

    results = _collect_hits(scores, idxs, topk, exclude_id=cid)
    return {"query_id": cid, "topk": topk, "results": results}


@app.post("/search")
def search(req: SearchRequest):
    """Semantic search over the corpus for a free-text query."""
    require_ready()
    q = (req.query or "").strip()
    if not q:
        raise HTTPException(status_code=400, detail="query is empty")

    topk = max(1, min(int(req.topk), TOPK_MAX))
    qv = embed_query(q)
    scores, idxs = _index.search(qv, topk)

    results = _collect_hits(scores, idxs, topk)
    return {"query": q, "topk": topk, "results": results}