saad003 commited on
Commit
2abe25a
·
verified ·
1 Parent(s): e3ac39c

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +188 -270
app.py CHANGED
@@ -1,9 +1,6 @@
1
  # app.py
2
  import io
3
  import os
4
- import random
5
- import re
6
- from typing import Dict
7
 
8
  import faiss
9
  import torch
@@ -15,36 +12,32 @@ from fastapi.middleware.cors import CORSMiddleware
15
  from fastapi.responses import JSONResponse
16
 
17
  from huggingface_hub import hf_hub_download
18
- from transformers import (
19
- CLIPProcessor,
20
- CLIPModel,
21
- AutoTokenizer,
22
- AutoModelForSeq2SeqLM,
23
- )
24
 
25
- # ---------- FastAPI app ----------
26
  app = FastAPI()
27
 
28
  app.add_middleware(
29
  CORSMiddleware,
30
- allow_origins=["*"], # later restrict to your frontend domain
31
  allow_credentials=True,
32
  allow_methods=["*"],
33
  allow_headers=["*"],
34
  )
35
 
36
- # ---------- Config ----------
 
 
37
 
38
- EMBED_REPO_ID = "saad003/Red01" # FAISS + metadata
39
- IMAGE_REPO_ID = "saad003/images04" # images04 with test/valid/train01–train07
40
  BASE_IMAGE_URL = f"https://huggingface.co/datasets/{IMAGE_REPO_ID}/resolve/main"
41
 
42
  HF_TOKEN = os.environ.get("HF_TOKEN")
43
 
44
- device = "cuda" if torch.cuda.is_available() else "cpu"
45
- print("Using device:", device)
46
-
47
- # ---------- Download index + metadata ----------
48
  print("Downloading FAISS index & metadata from Hugging Face...")
49
 
50
  INDEX_PATH = hf_hub_download(
@@ -66,173 +59,147 @@ index = faiss.read_index(INDEX_PATH)
66
 
67
  print("Loading metadata CSV...")
68
  metadata = pd.read_csv(META_PATH)
 
 
69
  assert index.ntotal == len(metadata), "Index size and metadata rows mismatch!"
70
 
71
- # ---------- Load CLIP (retrieval) ----------
72
  print("Loading PubMedCLIP model for retrieval...")
73
  CLIP_MODEL_NAME = "flaviagiammarino/pubmed-clip-vit-base-patch32"
74
 
 
 
 
75
  clip_model = CLIPModel.from_pretrained(CLIP_MODEL_NAME).to(device)
76
  clip_processor = CLIPProcessor.from_pretrained(CLIP_MODEL_NAME)
77
  clip_model.eval()
78
 
79
- # ---------- Load FLAN-T5 (caption synthesis) ----------
80
- print("Loading FLAN-T5 for diagnosis synthesis from similar captions...")
81
- REFINER_MODEL_ID = "google/flan-t5-base"
82
 
83
- refiner_tokenizer = AutoTokenizer.from_pretrained(REFINER_MODEL_ID)
84
- refiner_model = AutoModelForSeq2SeqLM.from_pretrained(
85
- REFINER_MODEL_ID
86
- ).to(device)
87
- refiner_model.eval()
88
 
89
- print("Backend ready ✅")
90
 
 
91
 
92
- # ---------- Helper: image path mapping ----------
 
 
 
93
 
94
- def id_to_image_url(image_id: str) -> str:
95
- """
96
- Map ROCO image IDs to folders in saad003/images04.
 
97
 
98
- test -> test/
99
- valid -> valid/
100
- train -> train01 ... train07 based on numeric ID
101
- """
102
- image_id = image_id.strip()
103
- base = BASE_IMAGE_URL
104
 
105
- if "_test_" in image_id:
 
 
 
 
 
 
 
 
 
 
 
 
106
  folder = "test"
107
- elif "_valid_" in image_id:
108
  folder = "valid"
109
- elif "_train_" in image_id:
110
- num_str = image_id.split("_")[-1]
 
111
  try:
112
- n = int(num_str)
113
- except ValueError:
114
- n = 0
115
-
116
- if 1 <= n <= 9000:
117
  folder = "train01"
118
- elif 9001 <= n <= 18000:
119
- folder = "train02"
120
- elif 18001 <= n <= 27000:
121
- folder = "train03"
122
- elif 27001 <= n <= 36000:
123
- folder = "train04"
124
- elif 36001 <= n <= 45000:
125
- folder = "train05"
126
- elif 45001 <= n <= 54000:
127
- folder = "train06"
128
  else:
129
- folder = "train07"
130
- else:
131
- folder = ""
132
-
133
- if folder:
134
- return f"{base}/{folder}/{image_id}.jpg"
135
- else:
136
- return f"{base}/{image_id}.jpg"
137
-
138
-
139
- # ---------- Helper: modality detection ----------
140
-
141
- MODALITY_KEYWORDS = {
142
- "CT": [
143
- "ct ",
144
- "ctscan",
145
- "ct scan",
146
- "computed tomography",
147
- "tomography",
148
- "non-contrast ct",
149
- "contrast-enhanced ct",
150
- ],
151
- "MRI": [
152
- "mri ",
153
- "magnetic resonance",
154
- "t1-weighted",
155
- "t2-weighted",
156
- "flair sequence",
157
- "diffusion-weighted",
158
- "dwi",
159
- ],
160
- "X-ray": [
161
- "x-ray",
162
- "x ray",
163
- "radiograph",
164
- "plain film",
165
- "chest film",
166
- "postoperative x",
167
- "post-operative x",
168
- "cxr",
169
- ],
170
- "Ultrasound": [
171
- "ultrasound",
172
- "sonogram",
173
- "sonography",
174
- "usg",
175
- "doppler",
176
- "echocardiogram",
177
- "echocardiography",
178
- ],
179
- "PET/CT": [
180
- "pet-ct",
181
- "pet/ct",
182
- "pet scan",
183
- "positron emission tomography",
184
- ],
185
- "Fluoroscopy": [
186
- "fluoroscopy",
187
- "fluoroscopic",
188
- "angiogram",
189
- "angiography",
190
- "barium swallow",
191
- "barium enema",
192
- ],
193
- }
194
-
195
- def detect_modality(caption: str) -> str:
196
- if not caption:
197
- return "Unknown"
198
- text = caption.lower()
199
-
200
- for modality, keywords in MODALITY_KEYWORDS.items():
201
- for kw in keywords:
202
- if kw in text:
203
- return modality
204
-
205
- if "mra" in text:
206
- return "MRI"
207
- if "cta " in text or "ct angiography" in text:
208
  return "CT"
 
 
 
 
 
 
 
 
 
 
209
  return "Unknown"
210
 
211
 
212
- # ---------- Helper: random scoring ----------
213
-
214
- def generate_random_scores() -> Dict[str, float]:
215
- rng = random.Random()
216
-
217
- modality_score = rng.uniform(85.0, 93.0) # percent
218
- cui_at_k = rng.uniform(0.30, 0.61)
219
- bert = rng.uniform(0.20, 0.40)
220
- medbert = rng.uniform(0.20, 0.35)
221
-
222
- return {
223
- "modality_score": round(modality_score, 1),
224
- "cui_at_k": round(cui_at_k, 3),
225
- "bertscore": round(bert, 3),
226
- "medbertscore": round(medbert, 3),
227
- }
228
-
229
-
230
- # ---------- Helper: FAISS search ----------
231
-
232
  def search_similar_by_image(image: Image.Image, k: int = 5) -> pd.DataFrame:
233
  """
234
- Encode query image with CLIP, search FAISS,
235
- filter out self-match, and return top-k results.
 
 
 
236
  """
237
  inputs = clip_processor(images=image, return_tensors="pt").to(device)
238
  with torch.no_grad():
@@ -241,167 +208,118 @@ def search_similar_by_image(image: Image.Image, k: int = 5) -> pd.DataFrame:
241
  feats = feats / feats.norm(p=2, dim=-1, keepdim=True)
242
  feats = feats.cpu().numpy().astype("float32")
243
 
244
- search_k = min(index.ntotal, k + 5)
245
- D, I = index.search(feats, search_k)
246
-
247
  rows = metadata.iloc[I[0]].copy()
248
  rows["score"] = D[0]
249
 
250
- # drop exact self-match
251
- rows = rows[rows["score"] < 0.999].copy()
252
-
253
- rows["image_url"] = rows["ID"].apply(id_to_image_url)
254
-
255
- rows = rows.sort_values("score", ascending=False).head(k)
256
- if "concepts_manual" not in rows.columns:
257
- rows["concepts_manual"] = ""
258
-
259
- return rows[["ID", "caption", "concepts_manual", "score", "image_url"]]
260
-
261
 
262
- # ---------- Helper: caption cleaning & synthesis ----------
 
263
 
264
- def clean_caption(text: str) -> str:
265
- """
266
- Clean generated caption:
267
- - strip
268
- - remove obvious prompt leftovers
269
- - ensure single sentence, nice punctuation
270
- """
271
- if not text:
272
- return ""
273
-
274
- text = text.strip()
275
-
276
- # Drop any leading instruction-like fragments
277
- text = re.sub(
278
- r"^(you are an expert radiologist[:,]?\s*)",
279
- "",
280
- text,
281
- flags=re.IGNORECASE,
282
  )
283
- text = re.sub(
284
- r"(findings? from similar radiology cases[:,]?\s*)",
285
- "",
286
- text,
287
- flags=re.IGNORECASE,
288
- )
289
-
290
- # Replace multiple separators
291
- text = text.replace(" ;", ";")
292
- text = re.sub(r"\s+[,;]\s*", ", ", text)
293
-
294
- # Collapse spaces
295
- text = " ".join(text.split())
296
 
297
- # If there are multiple sentences, keep only the first one
298
- parts = re.split(r"(?<=[.!?])\s+", text)
299
- if parts:
300
- text = parts[0]
301
 
302
- # Ensure period
303
- if text and not text.endswith((".", "!", "?")):
304
- text += "."
305
 
306
- # Capitalize first letter
307
- if text:
308
- text = text[0].upper() + text[1:]
309
-
310
- return text
311
-
312
-
313
- def synthesize_caption_from_similar_captions(captions: list[str]) -> str:
314
  """
315
- Use FLAN-T5 to create a diagnosis sentence from captions of similar images.
 
316
  """
317
- captions = [c.strip() for c in captions if c and isinstance(c, str)]
318
- if not captions:
319
- return ""
320
 
321
- # Use at most 5-6 captions to keep prompt short
322
- caps = captions[:6]
 
 
 
 
 
323
 
324
- numbered = "\n".join(
325
- f"{i+1}) {c}" for i, c in enumerate(caps)
326
- )
 
327
 
328
  prompt = (
329
- "Radiology findings from similar cases:\n"
330
- f"{numbered}\n\n"
331
- "Based on these, write ONE concise radiology impression sentence "
332
- "describing the most likely diagnosis and key findings for the "
333
- "current image. Do not mention numbers or 'similar cases'."
334
  )
335
 
336
- inputs = refiner_tokenizer(
337
- prompt,
 
338
  return_tensors="pt",
339
- truncation=True,
340
- max_length=512,
341
- ).to(device)
342
 
343
  with torch.no_grad():
344
- out_ids = refiner_model.generate(
345
  **inputs,
346
- max_new_tokens=48,
347
- num_beams=4,
348
- length_penalty=0.9,
349
- no_repeat_ngram_size=4,
 
350
  )
351
 
352
- raw = refiner_tokenizer.decode(out_ids[0], skip_special_tokens=True)
353
- return clean_caption(raw)
 
354
 
 
355
 
356
- # ---------- Routes ----------
357
 
 
358
  @app.get("/")
359
  def root():
360
- return {
361
- "status": "ok",
362
- "message": "Radiology retrieval + FLAN-T5 synthesis from similar captions",
363
- }
364
 
365
 
366
  @app.post("/search_by_image")
367
  async def search_by_image(file: UploadFile = File(...), k: int = 5):
368
  """
369
- Upload a radiology image.
370
- Returns:
371
- - query_caption: synthesized diagnosis from captions of similar images
372
- - modality: detected imaging modality
373
- - scores: random quality metrics
374
- - results: similar images (similarity + concepts + image_url)
 
 
375
  """
376
  content = await file.read()
377
  image = Image.open(io.BytesIO(content)).convert("RGB")
378
 
379
- k = int(k)
380
-
381
  # 1) Retrieval
382
  results_df = search_similar_by_image(image, k=k)
383
  results = results_df.to_dict(orient="records")
384
 
385
- # 2) Synthesize caption only from similar image captions
386
- similar_caps_list = results_df["caption"].astype(str).tolist()
387
 
 
388
  try:
389
- final_caption = synthesize_caption_from_similar_captions(
390
- similar_caps_list
391
- )
392
  except Exception as e:
393
- print("Error synthesizing caption:", e)
394
- final_caption = ""
395
 
396
- # 3) Modality & scores
397
- modality = detect_modality(final_caption or "")
398
- scores = generate_random_scores()
399
 
400
  return JSONResponse(
401
  {
402
- "query_caption": final_caption,
403
  "modality": modality,
404
- "scores": scores,
405
  "results": results,
406
  }
407
  )
 
1
  # app.py
2
  import io
3
  import os
 
 
 
4
 
5
  import faiss
6
  import torch
 
12
  from fastapi.responses import JSONResponse
13
 
14
  from huggingface_hub import hf_hub_download
15
+ from transformers import CLIPProcessor, CLIPModel
16
+ from transformers import AutoProcessor, AutoModelForVision2Seq
17
+ from peft import PeftConfig, PeftModel
 
 
 
18
 
19
+ # ---------------- FastAPI app ----------------
20
  app = FastAPI()
21
 
22
  app.add_middleware(
23
  CORSMiddleware,
24
+ allow_origins=["*"],
25
  allow_credentials=True,
26
  allow_methods=["*"],
27
  allow_headers=["*"],
28
  )
29
 
30
+ # ---------------- Config ----------------
31
+ # FAISS index + metadata
32
+ EMBED_REPO_ID = "saad003/Red01"
33
 
34
+ # All radiology images (with test / valid / train01..07 folders)
35
+ IMAGE_REPO_ID = "saad003/images04"
36
  BASE_IMAGE_URL = f"https://huggingface.co/datasets/{IMAGE_REPO_ID}/resolve/main"
37
 
38
  HF_TOKEN = os.environ.get("HF_TOKEN")
39
 
40
+ # ---------------- Download index + metadata ----------------
 
 
 
41
  print("Downloading FAISS index & metadata from Hugging Face...")
42
 
43
  INDEX_PATH = hf_hub_download(
 
59
 
60
  print("Loading metadata CSV...")
61
  metadata = pd.read_csv(META_PATH)
62
+
63
+ # Sanity-check sizes
64
  assert index.ntotal == len(metadata), "Index size and metadata rows mismatch!"
65
 
66
+ # ---------------- CLIP retrieval model ----------------
67
  print("Loading PubMedCLIP model for retrieval...")
68
  CLIP_MODEL_NAME = "flaviagiammarino/pubmed-clip-vit-base-patch32"
69
 
70
+ device = "cuda" if torch.cuda.is_available() else "cpu"
71
+ print("Using device:", device)
72
+
73
  clip_model = CLIPModel.from_pretrained(CLIP_MODEL_NAME).to(device)
74
  clip_processor = CLIPProcessor.from_pretrained(CLIP_MODEL_NAME)
75
  clip_model.eval()
76
 
77
+ # ---------------- Med-BLIP-2 captioning model ----------------
78
+ # This is a BLIP-2 model fine-tuned on ROCO via QLoRA
79
+ print("Loading Med-BLIP-2 captioning model...")
80
 
81
+ CAPTION_ADAPTER_ID = "NouRed/Med-BLIP-2-QLoRA-ROCO"
82
+ peft_config = PeftConfig.from_pretrained(CAPTION_ADAPTER_ID)
83
+ BASE_CAPTION_MODEL = peft_config.base_model_name_or_path # should be Salesforce/blip2-opt-2.7b
 
 
84
 
85
+ caption_processor = AutoProcessor.from_pretrained(BASE_CAPTION_MODEL)
86
 
87
+ dtype = torch.float16 if torch.cuda.is_available() else torch.float32
88
 
89
+ base_caption_model = AutoModelForVision2Seq.from_pretrained(
90
+ BASE_CAPTION_MODEL,
91
+ torch_dtype=dtype,
92
+ )
93
 
94
+ caption_model = PeftModel.from_pretrained(
95
+ base_caption_model,
96
+ CAPTION_ADAPTER_ID,
97
+ )
98
 
99
+ caption_model.to(device)
100
+ caption_model.eval()
 
 
 
 
101
 
102
+ print("Backend ready ✅")
103
+
104
+
105
+ # ---------------- Helper: build image URL ----------------
106
+ def id_to_image_url(image_id: str, split: str) -> str:
107
+ """
108
+ Map ROCO ID + split to the correct folder in saad003/images04.
109
+ Folders:
110
+ - test/...
111
+ - valid/...
112
+ - train01/ .. train07/ (train images split by numeric range)
113
+ """
114
+ if split == "test":
115
  folder = "test"
116
+ elif split == "valid":
117
  folder = "valid"
118
+ else:
119
+ # train split, we route to train01..train07 based on ID number
120
+ # Example ID: ROCOv2_2023_train_036004 -> num = 36004
121
  try:
122
+ num_str = image_id.split("_")[-1]
123
+ num = int(num_str)
124
+ except Exception:
125
+ # fallback, just put in train01
 
126
  folder = "train01"
 
 
 
 
 
 
 
 
 
 
127
  else:
128
+ # Roughly 9k images per shard, based on how you uploaded them
129
+ if num <= 9000:
130
+ folder = "train01"
131
+ elif num <= 18000:
132
+ folder = "train02"
133
+ elif num <= 27000:
134
+ folder = "train03"
135
+ elif num <= 36000:
136
+ folder = "train04"
137
+ elif num <= 45000:
138
+ folder = "train05"
139
+ elif num <= 54000:
140
+ folder = "train06"
141
+ else:
142
+ folder = "train07"
143
+
144
+ return f"{BASE_IMAGE_URL}/{folder}/{image_id}.jpg"
145
+
146
+
147
+ # ---------------- Helper: modality detection ----------------
148
+ def infer_modality_from_text(text: str) -> str:
149
+ """
150
+ Simple keyword-based modality detection from the generated caption.
151
+ Tries to be generous with synonyms.
152
+ """
153
+ t = text.lower()
154
+
155
+ ct_keywords = [
156
+ "ct scan", "computed tomography", "ct of the", "ct angiography",
157
+ "cta", "contrast-enhanced ct", "non-contrast ct", "non contrast ct",
158
+ ]
159
+ mri_keywords = [
160
+ "mri", "mr imaging", "magnetic resonance",
161
+ "t1-weighted", "t2-weighted", "flair sequence", "diffusion-weighted imaging",
162
+ ]
163
+ xray_keywords = [
164
+ "x-ray", "x ray", "radiograph", "plain film",
165
+ "chest film", "chest xray", "chest x-ray", "anteroposterior", "posteroanterior",
166
+ ]
167
+ ultrasound_keywords = [
168
+ "ultrasound", "sonography", "sonogram", "echogenic", "doppler",
169
+ ]
170
+ nuclear_keywords = [
171
+ "pet-ct", "pet ct", "pet/ct", "spect", "nuclear medicine", "scintigraphy",
172
+ ]
173
+ mammo_keywords = [
174
+ "mammogram", "mammography", "craniocaudal", "mediolateral oblique",
175
+ ]
176
+
177
+ def has_any(keys):
178
+ return any(k in t for k in keys)
179
+
180
+ if has_any(ct_keywords):
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
181
  return "CT"
182
+ if has_any(mri_keywords):
183
+ return "MRI"
184
+ if has_any(xray_keywords):
185
+ return "X-ray"
186
+ if has_any(ultrasound_keywords):
187
+ return "Ultrasound"
188
+ if has_any(nuclear_keywords):
189
+ return "Nuclear medicine / PET"
190
+ if has_any(mammo_keywords):
191
+ return "Mammography"
192
  return "Unknown"
193
 
194
 
195
+ # ---------------- Helper: FAISS retrieval ----------------
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
196
  def search_similar_by_image(image: Image.Image, k: int = 5) -> pd.DataFrame:
197
  """
198
+ Encode query image with PubMedCLIP, search FAISS, return DataFrame with:
199
+ ID, split, caption, concepts_manual, score, image_url
200
+
201
+ Also removes the *exact* self-match (score very close to 1.0)
202
+ so the query image is not shown again in the similar-images list.
203
  """
204
  inputs = clip_processor(images=image, return_tensors="pt").to(device)
205
  with torch.no_grad():
 
208
  feats = feats / feats.norm(p=2, dim=-1, keepdim=True)
209
  feats = feats.cpu().numpy().astype("float32")
210
 
211
+ D, I = index.search(feats, k + 1) # search a bit more so we can drop the self-match
 
 
212
  rows = metadata.iloc[I[0]].copy()
213
  rows["score"] = D[0]
214
 
215
+ # Drop suspected identical match (usually score == 1.0)
216
+ rows = rows[rows["score"] < 0.9999]
 
 
 
 
 
 
 
 
 
217
 
218
+ # Limit to requested top-k after filtering
219
+ rows = rows.head(k)
220
 
221
+ # Add image URLs
222
+ rows["image_url"] = rows.apply(
223
+ lambda r: id_to_image_url(str(r["ID"]), str(r["split"])), axis=1
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
224
  )
 
 
 
 
 
 
 
 
 
 
 
 
 
225
 
226
+ # Keep only what we actually need
227
+ return rows[["ID", "split", "caption", "concepts_manual", "score", "image_url"]]
 
 
228
 
 
 
 
229
 
230
+ # ---------------- Helper: BLIP-2 caption using similar captions ----------------
231
+ def generate_query_caption(image: Image.Image, similar_captions=None) -> str:
 
 
 
 
 
 
232
  """
233
+ Use Med-BLIP-2 to generate a diagnosis-style caption.
234
+ We condition the text prompt on captions from top-k similar images.
235
  """
236
+ similar_captions = similar_captions or []
 
 
237
 
238
+ # Take at most 3 similar captions and truncate each a bit so the prompt doesn't explode
239
+ cleaned_similar = []
240
+ for cap in similar_captions[:3]:
241
+ cap = str(cap).strip()
242
+ if len(cap) > 260:
243
+ cap = cap[:260] + "..."
244
+ cleaned_similar.append(cap)
245
 
246
+ similar_block = ""
247
+ if cleaned_similar:
248
+ joined = " || ".join(cleaned_similar)
249
+ similar_block = f" Findings from similar radiology cases: {joined}"
250
 
251
  prompt = (
252
+ "You are an expert radiologist. Based only on the image and the findings below, "
253
+ "write a concise diagnostic summary in 2–3 short sentences. "
254
+ "Use precise medical terminology and avoid repeating words or phrases."
255
+ + similar_block
 
256
  )
257
 
258
+ inputs = caption_processor(
259
+ images=image,
260
+ text=prompt,
261
  return_tensors="pt",
262
+ ).to(device, dtype)
 
 
263
 
264
  with torch.no_grad():
265
+ generated_ids = caption_model.generate(
266
  **inputs,
267
+ max_new_tokens=96,
268
+ num_beams=3,
269
+ do_sample=False,
270
+ repetition_penalty=1.25,
271
+ no_repeat_ngram_size=3,
272
  )
273
 
274
+ caption = caption_processor.batch_decode(
275
+ generated_ids, skip_special_tokens=True
276
+ )[0]
277
 
278
+ return caption.strip()
279
 
 
280
 
281
+ # ---------------- Routes ----------------
282
  @app.get("/")
283
  def root():
284
+ return {"status": "ok", "message": "Radiology retrieval + Med-BLIP-2 captioning API"}
 
 
 
285
 
286
 
287
  @app.post("/search_by_image")
288
  async def search_by_image(file: UploadFile = File(...), k: int = 5):
289
  """
290
+ Request:
291
+ - file: uploaded radiology image
292
+ - k: number of similar images
293
+
294
+ Response:
295
+ - query_caption: Med-BLIP-2 diagnosis summary for the query
296
+ - modality: inferred imaging modality
297
+ - results: list of similar images with their captions, concepts, score, image_url
298
  """
299
  content = await file.read()
300
  image = Image.open(io.BytesIO(content)).convert("RGB")
301
 
 
 
302
  # 1) Retrieval
303
  results_df = search_similar_by_image(image, k=k)
304
  results = results_df.to_dict(orient="records")
305
 
306
+ # 2) Use captions of similar images as extra context
307
+ similar_caps_for_prompt = results_df["caption"].tolist()
308
 
309
+ # 3) Captioning for the query image
310
  try:
311
+ query_caption = generate_query_caption(image, similar_caps_for_prompt)
 
 
312
  except Exception as e:
313
+ print("Error generating caption:", e)
314
+ query_caption = ""
315
 
316
+ # 4) Modality inference from the generated caption
317
+ modality = infer_modality_from_text(query_caption)
 
318
 
319
  return JSONResponse(
320
  {
321
+ "query_caption": query_caption,
322
  "modality": modality,
 
323
  "results": results,
324
  }
325
  )