Spaces:
Sleeping
Sleeping
import html
import io
import os

import faiss
import numpy as np
import pandas as pd
import torch
from fastapi import FastAPI, File, UploadFile, Query
from fastapi.responses import JSONResponse, HTMLResponse
from PIL import Image
from sentence_transformers import SentenceTransformer
from transformers import BlipProcessor, BlipForConditionalGeneration
# --- Configuration -----------------------------------------------------------
APP_TITLE = "Image → Hadith Similarity (FAISS)"
# Prebuilt artifacts expected in the working directory at startup.
INDEX_PATH = "hadith_semantic.faiss"
META_PATH = "hadith_meta.parquet"
# Multilingual sentence embedder (queries are BLIP captions in English,
# corpus may be Arabic — hence the multilingual model).
SBERT_NAME = "sentence-transformers/paraphrase-multilingual-MiniLM-L12-v2"
# Image-captioning model used to turn the uploaded image into query text.
BLIP_NAME = "Salesforce/blip-image-captioning-base"

app = FastAPI(title=APP_TITLE)

# Module-level singletons; all None until load_all() populates them.
index = None            # faiss index read from INDEX_PATH
meta = None             # pandas DataFrame read from META_PATH
sbert = None            # SentenceTransformer instance
blip_processor = None   # BLIP preprocessor
blip_model = None       # BLIP captioning model (moved to `device`)

# Prefer GPU when one is available; everything else falls back to CPU.
device = "cuda" if torch.cuda.is_available() else "cpu"
def load_all():
    """Initialize all module-level singletons: FAISS index, metadata table,
    the sentence embedder, and the BLIP captioning model.

    Raises RuntimeError when either artifact file is missing on disk.
    """
    global index, meta, sbert, blip_processor, blip_model

    # Fail fast with a clear message before touching any heavy model loads.
    if not os.path.exists(INDEX_PATH):
        raise RuntimeError(f"Missing FAISS index: {INDEX_PATH}")
    if not os.path.exists(META_PATH):
        raise RuntimeError(f"Missing meta file: {META_PATH}")

    index = faiss.read_index(INDEX_PATH)
    meta = pd.read_parquet(META_PATH)

    # A row-count mismatch usually means the index was built from a different
    # ordering/subset of the metadata. Warn but continue: results are looked
    # up by row position, so the caller must ensure alignment.
    rows, total = len(meta), index.ntotal
    if rows != total:
        print(
            f"[WARN] meta rows ({rows}) != index.ntotal ({total}). "
            f"Results will use row positions; ensure they align."
        )

    sbert = SentenceTransformer(SBERT_NAME)
    blip_processor = BlipProcessor.from_pretrained(BLIP_NAME)
    blip_model = BlipForConditionalGeneration.from_pretrained(BLIP_NAME).to(device)
    blip_model.eval()  # inference only — disable dropout etc.
def health():
    """Return a status snapshot: artifact paths, index/meta sizes, model
    names, and the compute device in use."""
    # The embedding dimensionality is exposed as `d` on most FAISS indexes;
    # fall back to None when the attribute is absent.
    dim = getattr(index, "d", None)
    report = {
        "ok": True,
        "index_file": INDEX_PATH,
        "meta_file": META_PATH,
        "index_ntotal": int(index.ntotal),
        "meta_rows": int(len(meta)),
        "dim": None if dim is None else int(dim),
        "text_model": SBERT_NAME,
        "caption_model": BLIP_NAME,
        "device": device,
    }
    return report
def caption_image(pil_img: Image.Image) -> str:
    """Produce a short caption for *pil_img* with the BLIP model.

    Generation is capped at 30 new tokens; the decoded text is returned
    with surrounding whitespace stripped.
    """
    batch = blip_processor(images=pil_img, return_tensors="pt").to(device)
    with torch.no_grad():  # inference only — skip autograd bookkeeping
        token_ids = blip_model.generate(**batch, max_new_tokens=30)
    decoded = blip_processor.decode(token_ids[0], skip_special_tokens=True)
    return decoded.strip()
def embed_text(text: str) -> np.ndarray:
    """Embed *text* as a single L2-normalized float32 row vector.

    Normalization makes inner-product search in FAISS equivalent to
    cosine similarity.
    """
    vec = sbert.encode([text], normalize_embeddings=True)
    return vec.astype("float32")
def pick_col(row, candidates, default=""):
    """Return the first non-null value found in *row* under any of the
    *candidates* keys, scanning in order; *default* when none matches.

    Null-ness is decided by pandas (`pd.notna`), so NaN and None both
    count as missing.
    """
    hits = (row[key] for key in candidates if key in row and pd.notna(row[key]))
    return next(hits, default)
async def search_image(
    file: UploadFile = File(...),
    k: int = Query(10, ge=1, le=50),
    format: str = Query("json"),  # kept as-is: the public query-param name
):
    """Caption an uploaded image, embed the caption, and return the k
    nearest hadiths from the FAISS index.

    Parameters
    ----------
    file : uploaded image; decoded via PIL and converted to RGB.
    k : number of neighbours to return (1..50).
    format : "html" for a rendered page, anything else returns JSON.

    Returns JSONResponse with {caption, k, results} or an HTMLResponse.
    """
    data = await file.read()
    pil = Image.open(io.BytesIO(data)).convert("RGB")

    cap = caption_image(pil)
    qvec = embed_text(cap)
    scores, idxs = index.search(qvec, k)

    results = []
    for rank, (i, s) in enumerate(zip(idxs[0].tolist(), scores[0].tolist()), start=1):
        # FAISS pads with -1 when fewer than k hits exist; also guard against
        # a metadata table shorter than the index (see load_all warning).
        if i < 0 or i >= len(meta):
            continue
        row = meta.iloc[i].to_dict()
        hadith_id = pick_col(row, ["hadithID", "hadith_id", "id", "doc_id"], default=i)
        text_ar = pick_col(row, ["text_ar", "arabic", "ar", "text"], default="")
        text_en = pick_col(row, ["text_en", "english", "en"], default="")
        source = pick_col(row, ["source", "book", "collection"], default="")
        results.append({
            "rank": rank,
            "score": float(s),
            "hadithID": int(hadith_id) if str(hadith_id).isdigit() else str(hadith_id),
            "text_ar": str(text_ar),
            "text_en": str(text_en),
            "source": str(source),
        })

    payload = {"caption": cap, "k": k, "results": results}

    if format == "html":
        # SECURITY: caption and metadata text come from a model / database and
        # may contain markup — escape everything interpolated into the page
        # to prevent HTML/script injection. The JSON path is unaffected.
        esc = html.escape
        items = "\n".join(
            f"<li><b>#{r['rank']}</b> score={r['score']:.3f} — hadithID={esc(str(r['hadithID']))}<br>"
            f"<div style='font-family: system-ui; direction: rtl; font-size: 18px'>{esc(r['text_ar'])}</div>"
            f"<div style='color:#666; margin-top:6px'>{esc(r['text_en'])}</div>"
            f"<div style='color:#999; margin-top:6px'>source: {esc(r['source'])}</div>"
            f"</li>"
            for r in results
        )
        page = f"""
        <html>
        <body style="margin:18px; font-family: system-ui">
        <h3>Caption</h3>
        <p>{esc(cap)}</p>
        <h3>Top Results</h3>
        <ol>{items}</ol>
        </body>
        </html>
        """
        return HTMLResponse(page)

    return JSONResponse(payload)