Alshargi committed · Commit cec2dd5 · verified · 1 Parent(s): dfbcb8d

Update app.py

Files changed (1):
  1. app.py +63 -19
app.py CHANGED
@@ -1,4 +1,3 @@
- # app.py
  import io
  import os
  import faiss
@@ -12,16 +11,14 @@ from transformers import BlipProcessor, BlipForConditionalGeneration
  import torch

  APP_TITLE = "Image → Hadith Similarity (FAISS)"
- INDEX_PATH = "faiss.index"  # put your file here (git-lfs likely)
- META_PATH = "hadith_meta.parquet"  # hadithID, text_ar, text_en, etc.
+ INDEX_PATH = "hadith_semantic.faiss"
+ META_PATH = "hadith_meta.parquet"

- # Models
  SBERT_NAME = "sentence-transformers/paraphrase-multilingual-MiniLM-L12-v2"
  BLIP_NAME = "Salesforce/blip-image-captioning-base"

  app = FastAPI(title=APP_TITLE)

- # -------- Load once --------
  index = None
  meta = None
  sbert = None
@@ -29,32 +26,49 @@ blip_processor = None
  blip_model = None
  device = "cuda" if torch.cuda.is_available() else "cpu"

+
  @app.on_event("startup")
  def load_all():
      global index, meta, sbert, blip_processor, blip_model

-     assert os.path.exists(INDEX_PATH), f"Missing {INDEX_PATH}"
-     assert os.path.exists(META_PATH), f"Missing {META_PATH}"
+     if not os.path.exists(INDEX_PATH):
+         raise RuntimeError(f"Missing FAISS index: {INDEX_PATH}")
+     if not os.path.exists(META_PATH):
+         raise RuntimeError(f"Missing meta file: {META_PATH}")

      index = faiss.read_index(INDEX_PATH)
      meta = pd.read_parquet(META_PATH)

+     # Basic sanity check
+     if len(meta) != index.ntotal:
+         # Not always fatal, but usually means a mismatch between index build order and meta rows.
+         print(f"[WARN] meta rows ({len(meta)}) != index.ntotal ({index.ntotal}). "
+               f"Results will use row positions; ensure they align.")
+
      sbert = SentenceTransformer(SBERT_NAME)

      blip_processor = BlipProcessor.from_pretrained(BLIP_NAME)
      blip_model = BlipForConditionalGeneration.from_pretrained(BLIP_NAME).to(device)
      blip_model.eval()

+
  @app.get("/health")
  def health():
+     # Try to infer dim from the index when possible
+     dim = getattr(index, "d", None)
      return {
          "ok": True,
+         "index_file": INDEX_PATH,
+         "meta_file": META_PATH,
          "index_ntotal": int(index.ntotal),
-         "model": SBERT_NAME,
+         "meta_rows": int(len(meta)),
+         "dim": int(dim) if dim is not None else None,
+         "text_model": SBERT_NAME,
          "caption_model": BLIP_NAME,
-         "dim": 384
+         "device": device
      }

+
  def caption_image(pil_img: Image.Image) -> str:
      inputs = blip_processor(images=pil_img, return_tensors="pt").to(device)
      with torch.no_grad():
@@ -62,15 +76,25 @@ def caption_image(pil_img: Image.Image) -> str:
      cap = blip_processor.decode(out[0], skip_special_tokens=True)
      return cap.strip()

+
  def embed_text(text: str) -> np.ndarray:
+     # normalize_embeddings => cosine via inner-product
      v = sbert.encode([text], normalize_embeddings=True)
      return v.astype("float32")

+
+ def pick_col(row, candidates, default=""):
+     for c in candidates:
+         if c in row and pd.notna(row[c]):
+             return row[c]
+     return default
+
+
  @app.post("/search_image")
  async def search_image(
      file: UploadFile = File(...),
      k: int = Query(10, ge=1, le=50),
-     format: str = Query("json")
+     format: str = Query("json"),
  ):
      data = await file.read()
      pil = Image.open(io.BytesIO(data)).convert("RGB")
@@ -78,30 +102,50 @@ async def search_image(
      cap = caption_image(pil)
      qvec = embed_text(cap)

-     # cosine via normalized + inner product
      scores, idxs = index.search(qvec, k)

      results = []
      for rank, (i, s) in enumerate(zip(idxs[0].tolist(), scores[0].tolist()), start=1):
-         row = meta.iloc[i]
+         if i < 0 or i >= len(meta):
+             continue
+
+         row = meta.iloc[i].to_dict()
+
+         hadith_id = pick_col(row, ["hadithID", "hadith_id", "id", "doc_id"], default=i)
+         text_ar = pick_col(row, ["text_ar", "arabic", "ar", "text"], default="")
+         text_en = pick_col(row, ["text_en", "english", "en"], default="")
+         source = pick_col(row, ["source", "book", "collection"], default="")
+
          results.append({
              "rank": rank,
              "score": float(s),
-             "hadithID": int(row.get("hadithID", i)),
-             "text_ar": row.get("text_ar", ""),
-             "text_en": row.get("text_en", ""),
-             "source": row.get("source", "")
+             "hadithID": int(hadith_id) if str(hadith_id).isdigit() else str(hadith_id),
+             "text_ar": str(text_ar),
+             "text_en": str(text_en),
+             "source": str(source),
          })

      payload = {"caption": cap, "k": k, "results": results}

      if format == "html":
-         # minimal HTML (you can beautify)
          items = "\n".join([
-             f"<li><b>{r['rank']}</b> score={r['score']:.3f} — hadithID={r['hadithID']}<br>{r['text_ar']}</li>"
+             f"<li><b>#{r['rank']}</b> score={r['score']:.3f} — hadithID={r['hadithID']}<br>"
+             f"<div style='font-family: system-ui; direction: rtl; font-size: 18px'>{r['text_ar']}</div>"
+             f"<div style='color:#666; margin-top:6px'>{r['text_en']}</div>"
+             f"<div style='color:#999; margin-top:6px'>source: {r['source']}</div>"
+             f"</li>"
              for r in results
          ])
-         html = f"<h3>Caption</h3><p>{cap}</p><h3>Results</h3><ol>{items}</ol>"
+         html = f"""
+         <html>
+           <body style="margin:18px; font-family: system-ui">
+             <h3>Caption</h3>
+             <p>{cap}</p>
+             <h3>Top Results</h3>
+             <ol>{items}</ol>
+           </body>
+         </html>
+         """
          return HTMLResponse(html)

      return JSONResponse(payload)
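
The updated app assumes that hadith_semantic.faiss was built from the same rows, in the same order, as hadith_meta.parquet, with normalized vectors so inner-product search equals cosine similarity (this is what the startup sanity check and the normalize_embeddings comment rely on). A minimal build sketch under those assumptions; the IndexFlatIP choice and the text_en column are guesses, not part of this commit:

import faiss
import pandas as pd
from sentence_transformers import SentenceTransformer

meta = pd.read_parquet("hadith_meta.parquet")
sbert = SentenceTransformer("sentence-transformers/paraphrase-multilingual-MiniLM-L12-v2")

# Embed one text column per row, preserving meta row order
# (the text_en column name is an assumption).
vecs = sbert.encode(meta["text_en"].fillna("").tolist(),
                    normalize_embeddings=True).astype("float32")

index = faiss.IndexFlatIP(vecs.shape[1])  # inner product over unit vectors == cosine
index.add(vecs)
faiss.write_index(index, "hadith_semantic.faiss")
assert index.ntotal == len(meta)  # keeps the startup sanity check quiet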
 
 
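To exercise the endpoints once the server is up, a small client like the following works; the base URL is an assumption for a local run, and example.jpg is a placeholder:

import requests

BASE = "http://localhost:8000"  # assumption: local server on the default uvicorn port

print(requests.get(f"{BASE}/health").json())

with open("example.jpg", "rb") as f:
    r = requests.post(
        f"{BASE}/search_image",
        params={"k": 5, "format": "json"},
        files={"file": ("example.jpg", f, "image/jpeg")},
    )
r.raise_for_status()
for hit in r.json()["results"]:
    print(hit["rank"], round(hit["score"], 3), hit["hadithID"])
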
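One caveat: @app.on_event("startup") is deprecated in recent FastAPI releases in favor of a lifespan handler. If the app is ever migrated, a sketch of the equivalent load-once wiring could look like this (load_all stands for the same loader defined above):

from contextlib import asynccontextmanager
from fastapi import FastAPI

@asynccontextmanager
async def lifespan(app: FastAPI):
    # Run the load-once logic before the app starts serving requests.
    load_all()
    yield  # the app serves requests while suspended here

app = FastAPI(title=APP_TITLE, lifespan=lifespan)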