"""Image -> hadith similarity search service.

An uploaded image is captioned with BLIP, the caption is embedded with a
multilingual SBERT model, and the embedding is searched against a prebuilt
FAISS index whose rows are described by a parquet metadata table.
"""

import io
import logging
import os
from html import escape

import faiss
import numpy as np
import pandas as pd
import torch
from fastapi import FastAPI, File, HTTPException, Query, UploadFile
from fastapi.concurrency import run_in_threadpool
from fastapi.responses import HTMLResponse, JSONResponse
from PIL import Image
from sentence_transformers import SentenceTransformer
from transformers import BlipForConditionalGeneration, BlipProcessor

logger = logging.getLogger(__name__)

APP_TITLE = "Image → Hadith Similarity (FAISS)"
INDEX_PATH = "hadith_semantic.faiss"
META_PATH = "hadith_meta.parquet"
SBERT_NAME = "sentence-transformers/paraphrase-multilingual-MiniLM-L12-v2"
BLIP_NAME = "Salesforce/blip-image-captioning-base"

# Candidate metadata column names, tried in order by pick_col().
ID_COLS = ["hadithID", "hadith_id", "id", "doc_id"]
TEXT_AR_COLS = ["text_ar", "arabic", "ar", "text"]
TEXT_EN_COLS = ["text_en", "english", "en"]
SOURCE_COLS = ["source", "book", "collection"]

app = FastAPI(title=APP_TITLE)

# Populated by load_all() at startup.
index = None
meta = None
sbert = None
blip_processor = None
blip_model = None
device = "cuda" if torch.cuda.is_available() else "cpu"


@app.on_event("startup")
def load_all():
    """Load the FAISS index, metadata table and both models into module globals.

    Raises:
        RuntimeError: if the index file or the metadata file does not exist.
    """
    global index, meta, sbert, blip_processor, blip_model
    if not os.path.exists(INDEX_PATH):
        raise RuntimeError(f"Missing FAISS index: {INDEX_PATH}")
    if not os.path.exists(META_PATH):
        raise RuntimeError(f"Missing meta file: {META_PATH}")

    index = faiss.read_index(INDEX_PATH)
    meta = pd.read_parquet(META_PATH)

    # Search hits are resolved with meta.iloc[id], so a size mismatch usually
    # means the index and the metadata were built in different orders.
    if len(meta) != index.ntotal:
        logger.warning(
            "meta rows (%d) != index.ntotal (%d). "
            "Results will use row positions; ensure they align.",
            len(meta),
            index.ntotal,
        )

    sbert = SentenceTransformer(SBERT_NAME)
    blip_processor = BlipProcessor.from_pretrained(BLIP_NAME)
    blip_model = BlipForConditionalGeneration.from_pretrained(BLIP_NAME).to(device)
    blip_model.eval()


@app.get("/health")
def health():
    """Report loaded artifact sizes, model names and the torch device."""
    # Not every FAISS index type is guaranteed to expose `d`; fall back to None.
    dim = getattr(index, "d", None)
    return {
        "ok": True,
        "index_file": INDEX_PATH,
        "meta_file": META_PATH,
        "index_ntotal": int(index.ntotal),
        "meta_rows": int(len(meta)),
        "dim": int(dim) if dim is not None else None,
        "text_model": SBERT_NAME,
        "caption_model": BLIP_NAME,
        "device": device,
    }


def caption_image(pil_img: Image.Image) -> str:
    """Generate a short (at most 30 new tokens) BLIP caption for *pil_img*."""
    inputs = blip_processor(images=pil_img, return_tensors="pt").to(device)
    with torch.no_grad():
        out = blip_model.generate(**inputs, max_new_tokens=30)
    cap = blip_processor.decode(out[0], skip_special_tokens=True)
    return cap.strip()


def embed_text(text: str) -> np.ndarray:
    """Embed *text* as a (1, dim) float32 array suitable for index.search().

    Embeddings are L2-normalized, so inner-product scores equal cosine
    similarity (assumes the FAISS index is an inner-product index — confirm
    against how the index was built).
    """
    v = sbert.encode([text], normalize_embeddings=True)
    return v.astype("float32")


def pick_col(row, candidates, default=""):
    """Return the first non-null value in *row* among *candidates*, else *default*."""
    for c in candidates:
        if c in row and pd.notna(row[c]):
            return row[c]
    return default


def _normalize_id(value):
    """Return *value* as an int when it is a non-negative integral id, else as str.

    Integral floats (e.g. 12.0 from an int column that pandas upcast because
    of NaNs) become ints. ``str.isdecimal`` is used instead of ``isdigit`` so
    characters like "²" do not reach ``int()`` and raise ValueError.
    """
    if isinstance(value, (float, np.floating)) and float(value).is_integer():
        value = int(value)
    text = str(value)
    return int(text) if text.isdecimal() else text


def _decode_image(data: bytes) -> Image.Image:
    """Decode uploaded bytes into an RGB PIL image.

    Raises:
        HTTPException: 400 when the upload is empty or not a readable image.
    """
    if not data:
        raise HTTPException(status_code=400, detail="Empty upload.")
    try:
        return Image.open(io.BytesIO(data)).convert("RGB")
    except (OSError, Image.DecompressionBombError) as err:
        # UnidentifiedImageError and truncated-file errors are OSError subclasses.
        raise HTTPException(
            status_code=400, detail="Uploaded file is not a readable image."
        ) from err


def _search(data: bytes, k: int):
    """Caption the image in *data*, search the index, and build result dicts.

    Returns:
        (caption, results) where results is a list of JSON-serializable dicts.
    """
    cap = caption_image(_decode_image(data))
    scores, idxs = index.search(embed_text(cap), k)

    results = []
    for rank, (i, s) in enumerate(zip(idxs[0].tolist(), scores[0].tolist()), start=1):
        # FAISS reports missing neighbours as -1; also skip ids past the end of
        # the metadata table in case the two are out of sync.
        if i < 0 or i >= len(meta):
            continue
        row = meta.iloc[i].to_dict()
        results.append({
            "rank": rank,
            "score": float(s),
            "hadithID": _normalize_id(pick_col(row, ID_COLS, default=i)),
            "text_ar": str(pick_col(row, TEXT_AR_COLS, default="")),
            "text_en": str(pick_col(row, TEXT_EN_COLS, default="")),
            "source": str(pick_col(row, SOURCE_COLS, default="")),
        })
    return cap, results


def _render_html(cap: str, results) -> str:
    """Render the caption and results as a simple HTML page.

    All caption and metadata text is HTML-escaped because it comes from a
    model and from the metadata file, neither of which is trusted markup.
    """
    items = "\n".join(
        f"<li><b>#{r['rank']}</b> score={r['score']:.3f} — "
        f"hadithID={escape(str(r['hadithID']))}<br/>"
        f"<div dir='rtl'>{escape(r['text_ar'])}</div>"
        f"<div>{escape(r['text_en'])}</div>"
        f"<small>source: {escape(r['source'])}</small>"
        f"</li>"
        for r in results
    )
    return f"""<!doctype html>
<html>
<head><meta charset="utf-8"><title>{escape(APP_TITLE)}</title></head>
<body>
  <h3>Caption</h3>
  <p>{escape(cap)}</p>
  <h3>Top Results</h3>
  <ol>
    {items}
  </ol>
</body>
</html>
"""


@app.post("/search_image")
async def search_image(
    file: UploadFile = File(...),
    k: int = Query(10, ge=1, le=50),
    format: str = Query("json"),
):
    """Return the *k* hadith entries most similar to the uploaded image's caption.

    Args:
        file: uploaded image.
        k: number of neighbours to request from FAISS (1-50).
        format: "html" for an HTML page; any other value returns JSON.

    Raises:
        HTTPException: 400 when the upload is empty or not a readable image.
    """
    data = await file.read()
    # Decoding, captioning, embedding and FAISS search are synchronous and
    # compute-heavy; run them in a worker thread so the event loop stays free.
    cap, results = await run_in_threadpool(_search, data, k)

    if format == "html":
        return HTMLResponse(_render_html(cap, results))
    return JSONResponse({"caption": cap, "k": k, "results": results})