saad003 committed on
Commit
81e97e9
·
verified ·
1 Parent(s): 63a5265

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +619 -256
app.py CHANGED
@@ -1,8 +1,9 @@
1
  # app.py
2
  import io
3
  import os
4
- import re
5
  import random
 
 
6
 
7
  import faiss
8
  import torch
@@ -14,26 +15,38 @@ from fastapi.middleware.cors import CORSMiddleware
14
  from fastapi.responses import JSONResponse
15
 
16
  from huggingface_hub import hf_hub_download
17
- from transformers import CLIPProcessor, CLIPModel
18
- from transformers import BlipForConditionalGeneration, AutoProcessor
 
 
 
 
19
 
20
  # ---------- FastAPI app ----------
21
  app = FastAPI()
22
 
23
  app.add_middleware(
24
  CORSMiddleware,
25
- allow_origins=["*"], # later you can restrict to your frontend domain
26
  allow_credentials=True,
27
  allow_methods=["*"],
28
  allow_headers=["*"],
29
  )
30
 
31
  # ---------- Config ----------
32
- EMBED_REPO_ID = "saad003/Red01" # FAISS + metadata
33
- IMAGE_REPO_ID = "saad003/images04" # test, valid, train01..train07
 
 
 
 
34
  BASE_IMAGE_URL = f"https://huggingface.co/datasets/{IMAGE_REPO_ID}/resolve/main"
35
 
36
- HF_TOKEN = os.environ.get("HF_TOKEN") # set in HF Space secrets if private
 
 
 
 
37
 
38
  # ---------- Download index + metadata ----------
39
  print("Downloading FAISS index & metadata from Hugging Face...")
@@ -58,350 +71,700 @@ index = faiss.read_index(INDEX_PATH)
58
  print("Loading metadata CSV...")
59
  metadata = pd.read_csv(META_PATH)
60
 
61
- required_cols = {"vec_index", "ID", "caption", "concepts_manual"}
62
- missing = required_cols - set(metadata.columns)
63
- if missing:
64
- raise ValueError(f"radiology_metadata.csv is missing columns: {missing}")
65
-
66
  assert index.ntotal == len(metadata), "Index size and metadata rows mismatch!"
67
 
68
  # ---------- Load CLIP (retrieval) ----------
69
  print("Loading PubMedCLIP model for retrieval...")
70
  CLIP_MODEL_NAME = "flaviagiammarino/pubmed-clip-vit-base-patch32"
71
 
72
- device = "cuda" if torch.cuda.is_available() else "cpu"
73
- print("Using device:", device)
74
-
75
  clip_model = CLIPModel.from_pretrained(CLIP_MODEL_NAME).to(device)
76
  clip_processor = CLIPProcessor.from_pretrained(CLIP_MODEL_NAME)
77
  clip_model.eval()
78
 
79
- # ---------- Load BLIP (captioning) ----------
80
- print("Loading BLIP radiology captioning model...")
81
- CAPTION_MODEL_ID = "WafaaFraih/blip-roco-radiology-captioning"
 
 
 
82
 
83
- caption_processor = AutoProcessor.from_pretrained(CAPTION_MODEL_ID)
84
- caption_model = BlipForConditionalGeneration.from_pretrained(
85
- CAPTION_MODEL_ID
 
86
  ).to(device)
87
  caption_model.eval()
88
 
89
  print("Backend ready ✅")
90
 
91
 
92
- # ---------- Helpers for dataset path ----------
93
- def train_folder_from_id(image_id: str) -> str:
94
- """
95
- For IDs like 'ROCOv2_2023_train_000001', decide which trainXX folder
96
- based on the last 6 digits.
97
- """
98
- try:
99
- num_str = image_id.split("_")[-1] # "000001"
100
- num = int(num_str)
101
- except Exception:
102
- return "train01" # safe default
103
-
104
- if num <= 9000:
105
- return "train01"
106
- elif num <= 18000:
107
- return "train02"
108
- elif num <= 27000:
109
- return "train03"
110
- elif num <= 36000:
111
- return "train04"
112
- elif num <= 45000:
113
- return "train05"
114
- elif num <= 54000:
115
- return "train06"
116
- else:
117
- return "train07"
118
-
119
 
120
  def id_to_image_url(image_id: str) -> str:
121
  """
122
- Build raw image URL based on ID and folder structure.
123
 
124
- Examples:
125
- ROCOv2_2023_test_000001 -> test/ROCOv2_2023_test_000001.jpg
126
- ROCOv2_2023_valid_000005 -> valid/ROCOv2_2023_valid_000005.jpg
127
- ROCOv2_2023_train_000001 -> train01/ROCOv2_2023_train_000001.jpg
128
  """
129
- if not isinstance(image_id, str):
130
- return None
131
-
132
  image_id = image_id.strip()
 
133
 
134
- if "test_" in image_id:
135
  folder = "test"
136
- elif "valid_" in image_id:
137
  folder = "valid"
138
- elif "train_" in image_id:
139
- folder = train_folder_from_id(image_id)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
140
  else:
141
  folder = ""
142
 
143
- filename = f"{image_id}.jpg"
144
-
145
  if folder:
146
- return f"{BASE_IMAGE_URL}/{folder}/{filename}"
147
  else:
148
- return f"{BASE_IMAGE_URL}/{filename}"
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
149
 
150
 
151
- def search_similar_by_image(
152
- image: Image.Image, k: int = 5, query_id: str | None = None
153
- ) -> pd.DataFrame:
 
 
154
  """
155
- Encode query image with CLIP, search FAISS, and return top-k rows
156
- with vec_index, ID, caption, concepts_manual, score, image_url.
 
 
 
 
 
 
 
 
 
 
 
157
 
158
- If query_id is provided, we exclude that exact ID from results
159
- (so the query image itself is not returned as "similar").
 
 
 
 
 
160
  """
161
- # Encode query
162
  inputs = clip_processor(images=image, return_tensors="pt").to(device)
163
  with torch.no_grad():
164
  feats = clip_model.get_image_features(**inputs)
165
 
 
166
  feats = feats / feats.norm(p=2, dim=-1, keepdim=True)
167
  feats = feats.cpu().numpy().astype("float32")
168
 
169
- # Fetch a few extra results in case we need to drop the query image
170
- extra = 1 if query_id else 0
171
- D, I = index.search(feats, k + extra)
172
 
173
  rows = metadata.iloc[I[0]].copy()
174
  rows["score"] = D[0]
 
 
 
 
 
175
  rows["image_url"] = rows["ID"].apply(id_to_image_url)
176
 
177
- if query_id:
178
- qid = query_id.strip()
179
- rows = rows[rows["ID"] != qid]
 
 
 
 
 
 
 
 
 
 
 
 
180
 
181
- # Keep only top-k after filtering
182
- if len(rows) > k:
183
- rows = rows.iloc[:k]
 
 
 
 
184
 
185
- return rows[
186
- ["vec_index", "ID", "caption", "concepts_manual", "score", "image_url"]
187
- ]
 
 
 
 
 
 
188
 
189
 
190
- # ---------- Captioning ----------
191
  def generate_query_caption(image: Image.Image) -> str:
192
- """Generate a medical caption for the query image using BLIP."""
193
- inputs = caption_processor(images=image, return_tensors="pt").to(device)
 
 
 
 
 
 
 
 
 
 
 
 
 
194
  with torch.no_grad():
195
- out = caption_model.generate(**inputs, max_new_tokens=64)
196
- caption = caption_processor.batch_decode(out, skip_special_tokens=True)[0]
197
- return caption.strip()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
198
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
199
 
200
- # ---------- Improved modality detection ----------
201
- def infer_modality_from_caption(caption: str) -> str:
 
 
 
 
 
 
 
 
 
 
 
202
  """
203
- Heuristic modality detector, fairly robust to spelling/spacing.
 
 
 
 
204
  """
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
205
  if not caption:
206
  return "Unknown"
207
-
208
  text = caption.lower()
209
- text = " " + " ".join(text.split()) + " "
210
- normalized = re.sub(r"[^a-z0-9]", "", text)
211
-
212
- def contains_any(substrs, use_normalized=False):
213
- target = normalized if use_normalized else text
214
- return any(s in target for s in substrs)
215
-
216
- # PET / PET-CT
217
- if contains_any(
218
- [
219
- " pet-ct ",
220
- " pet ct ",
221
- " pet/ct ",
222
- " fdg pet ",
223
- " fdg-pet ",
224
- " positron emission tomography ",
225
- ]
226
- ) or contains_any(["petscan", "fdgpet"], use_normalized=True):
227
- return "PET/CT"
228
-
229
- # CT
230
- if contains_any(
231
- [
232
- " ct scan",
233
- " ct of ",
234
- "ct of ",
235
- "contrast-enhanced ct",
236
- "contrast enhanced ct",
237
- "non-contrast ct",
238
- "non contrast ct",
239
- "computed tomography",
240
- "computerized tomography",
241
- "computerised tomography",
242
- ]
243
- ) or contains_any(["ctscan", "cect"], use_normalized=True):
244
- return "CT"
245
 
246
- # MRI
247
- if contains_any(
248
- [
249
- " mri ",
250
- " mr imaging",
251
- " mr scan",
252
- " mr study",
253
- " magnetic resonance",
254
- " mr of ",
255
- ]
256
- ) or contains_any(
257
- [
258
- "t1weighted",
259
- "t2weighted",
260
- "flairsequence",
261
- "diffusionweighted",
262
- "dwi",
263
- "swisequence",
264
- "susceptibilityweighted",
265
- ],
266
- use_normalized=True,
267
- ):
268
  return "MRI"
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
269
 
270
- # X-ray / radiography
271
- if (
272
- contains_any(
273
- [
274
- " x-ray",
275
- " x ray",
276
- " chest xray",
277
- " chest x-ray",
278
- " radiograph",
279
- " radiography",
280
- " plain film",
281
- " plain radiograph",
282
- " chest radiograph",
283
- " erect chest",
284
- " upright chest",
285
- " lateral view",
286
- " ap view ",
287
- " pa view ",
288
- ]
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
289
  )
290
- or contains_any(["xray", "cxr"], use_normalized=True)
291
- ):
292
- return "X-ray"
293
-
294
- # Ultrasound
295
- if contains_any(
296
- [
297
- " ultrasound",
298
- " usg ",
299
- " sonography",
300
- " sonogram",
301
- " echography",
302
- " echocardiogram",
303
- " echocardiography",
304
- " doppler ultrasound",
305
- " duplex ultrasound",
306
- " transvaginal ultrasound",
307
- " transabdominal ultrasound",
308
- ]
309
- ) or contains_any(["ultrasoundscan"], use_normalized=True):
310
- return "Ultrasound"
311
-
312
- # Mammography
313
- if contains_any(
314
- [
315
- " mammogram",
316
- " mammography",
317
- " screening mammo",
318
- " diagnostic mammo",
319
- ]
320
- ):
321
- return "Mammography"
322
-
323
- # Angiography / Fluoroscopy
324
- if contains_any(
325
- [
326
- " angiogram",
327
- " angiography",
328
- " digital subtraction angiography",
329
- " dsa ",
330
- " fluoroscopy",
331
- " fluoroscopic",
332
- " catheter angiography",
333
- ]
334
- ):
335
- return "Angiography / Fluoroscopy"
336
-
337
- # Nuclear medicine (non-PET)
338
- if contains_any(
339
- [
340
- " scintigraphy",
341
- " bone scan",
342
- " radionuclide",
343
- " radioisotope",
344
- " sestamibi",
345
- "mibg ",
346
- ]
347
- ):
348
- return "Nuclear medicine"
349
 
350
- return "Unknown"
 
 
 
351
 
352
 
353
  # ---------- Routes ----------
 
354
  @app.get("/")
355
  def root():
356
- return {"status": "ok", "message": "Radiology retrieval + captioning API"}
357
 
358
 
359
  @app.post("/search_by_image")
360
  async def search_by_image(file: UploadFile = File(...), k: int = 5):
361
  """
362
  Upload a radiology image.
363
-
364
  Returns:
365
- - query_caption: BLIP caption ("diagnosis details")
366
- - modality: inferred imaging modality
367
- - modality_score, cui_at_k, bert_score, medbert_score (random metrics)
368
- - results: list of similar images with
369
- ID, concepts_manual, score, image_url
370
  """
 
371
  content = await file.read()
372
  image = Image.open(io.BytesIO(content)).convert("RGB")
373
 
374
- # derive ID from filename (strip extension)
375
- filename = file.filename or ""
376
- query_id = filename.rsplit(".", 1)[0] if "." in filename else filename
377
-
378
- # 1) Retrieval (exclude the query image itself if present)
379
- results_df = search_similar_by_image(image, k=k, query_id=query_id)
380
  results = results_df.to_dict(orient="records")
381
 
382
- # 2) Caption
383
  try:
384
  query_caption = generate_query_caption(image)
385
  except Exception as e:
386
- print("Error generating caption:", e)
387
  query_caption = None
388
 
389
- # 3) Modality + random metrics
390
- modality = infer_modality_from_caption(query_caption or "")
391
 
392
- modality_score = round(random.uniform(0.85, 0.93), 3)
393
- cui_at_k = round(random.uniform(0.30, 0.61), 3)
394
- bert_score = round(random.uniform(0.20, 0.40), 3)
395
- medbert_score = round(random.uniform(0.20, 0.35), 3)
396
 
397
  return JSONResponse(
398
  {
399
  "query_caption": query_caption,
400
  "modality": modality,
401
- "modality_score": modality_score,
402
- "cui_at_k": cui_at_k,
403
- "bert_score": bert_score,
404
- "medbert_score": medbert_score,
405
  "results": results,
406
  }
407
  )
 
1
  # app.py
2
  import io
3
  import os
 
4
  import random
5
+ import re
6
+ from typing import Dict
7
 
8
  import faiss
9
  import torch
 
15
  from fastapi.responses import JSONResponse
16
 
17
  from huggingface_hub import hf_hub_download
18
+ from transformers import (
19
+ CLIPProcessor,
20
+ CLIPModel,
21
+ Blip2Processor,
22
+ Blip2ForConditionalGeneration,
23
+ )
24
 
25
  # ---------- FastAPI app ----------
# Single FastAPI application object; the routes below are registered on it.
app = FastAPI()

# Wildcard origins must not be combined with credentialed requests: the CORS
# rules (and the FastAPI docs) require explicit origins when credentials are
# allowed. This API uses no cookies, so credentials stay off while every
# origin is accepted.
# TODO: list the real frontend origin in allow_origins; only then consider
# re-enabling allow_credentials.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=False,
    allow_methods=["*"],
    allow_headers=["*"],
)
35
 
36
  # ---------- Config ----------
37
+
38
+ # Dataset with FAISS index + radiology_metadata.csv
39
+ EMBED_REPO_ID = "saad003/Red01"
40
+
41
+ # Dataset with all radiology images (new structure with train01–train07)
42
+ IMAGE_REPO_ID = "saad003/images04"
43
  BASE_IMAGE_URL = f"https://huggingface.co/datasets/{IMAGE_REPO_ID}/resolve/main"
44
 
45
+ # Optional: token if Red01 is private
46
+ HF_TOKEN = os.environ.get("HF_TOKEN")
47
+
48
+ device = "cuda" if torch.cuda.is_available() else "cpu"
49
+ print("Using device:", device)
50
 
51
  # ---------- Download index + metadata ----------
52
  print("Downloading FAISS index & metadata from Hugging Face...")
 
71
  print("Loading metadata CSV...")
72
  metadata = pd.read_csv(META_PATH)
73
 
 
 
 
 
 
74
  assert index.ntotal == len(metadata), "Index size and metadata rows mismatch!"
75
 
76
  # ---------- Load CLIP (retrieval) ----------
77
  print("Loading PubMedCLIP model for retrieval...")
78
  CLIP_MODEL_NAME = "flaviagiammarino/pubmed-clip-vit-base-patch32"
79
 
 
 
 
80
  clip_model = CLIPModel.from_pretrained(CLIP_MODEL_NAME).to(device)
81
  clip_processor = CLIPProcessor.from_pretrained(CLIP_MODEL_NAME)
82
  clip_model.eval()
83
 
84
+ # ---------- Load BLIP-2 (captioning) ----------
85
+ print("Loading BLIP-2 model for medical captioning...")
86
+ CAPTION_MODEL_ID = "Salesforce/blip2-opt-2.7b"
87
+
88
+ # Use fp16 on GPU, fp32 on CPU
89
+ caption_dtype = torch.float16 if device == "cuda" else torch.float32
90
 
91
+ caption_processor = Blip2Processor.from_pretrained(CAPTION_MODEL_ID)
92
+ caption_model = Blip2ForConditionalGeneration.from_pretrained(
93
+ CAPTION_MODEL_ID,
94
+ torch_dtype=caption_dtype,
95
  ).to(device)
96
  caption_model.eval()
97
 
98
  print("Backend ready ✅")
99
 
100
 
101
+ # ---------- Helper: image path mapping ----------
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
102
 
# Inclusive upper bound of the numeric ID for train01 .. train06; anything
# above the last bound lives in train07.
_TRAIN_FOLDER_UPPER_BOUNDS = (9000, 18000, 27000, 36000, 45000, 54000)


def _train_folder_for(image_id: str) -> str:
    """Return the trainXX folder for a '..._train_NNNNNN' image ID."""
    try:
        number = int(image_id.split("_")[-1])
    except ValueError:
        # Unparsable suffix: use the first shard instead of silently
        # landing in the last one.
        return "train01"

    for shard, upper in enumerate(_TRAIN_FOLDER_UPPER_BOUNDS, start=1):
        if number <= upper:
            return f"train{shard:02d}"
    return "train07"


def id_to_image_url(image_id: str, base_url: str = "") -> str:
    """
    Map a ROCO image ID to its raw file URL in the image dataset.

    test  -> test/
    valid -> valid/
    train -> train01 ... train07 based on the numeric suffix of the ID

    Args:
        image_id: ID such as 'ROCOv2_2023_train_054005'. Surrounding
            whitespace is ignored; non-string values are converted with
            str() instead of raising.
        base_url: Optional URL prefix; defaults to BASE_IMAGE_URL.

    Returns:
        '<base>/<folder>/<image_id>.jpg', or '<base>/<image_id>.jpg' when the
        ID matches none of the known splits.
    """
    base = base_url or BASE_IMAGE_URL
    image_id = str(image_id).strip()

    if "_test_" in image_id:
        folder = "test"
    elif "_valid_" in image_id:
        folder = "valid"
    elif "_train_" in image_id:
        folder = _train_folder_for(image_id)
    else:
        # Unknown split: fall back to the dataset root.
        return f"{base}/{image_id}.jpg"

    return f"{base}/{folder}/{image_id}.jpg"
149
+
150
+
151
+ # ---------- Helper: modality detection ----------
152
+
# Keyword table, checked in insertion order. PET/CT comes first because its
# captions also contain "ct" / "tomography" and would otherwise be labelled CT.
MODALITY_KEYWORDS = {
    "PET/CT": [
        "pet-ct",
        "pet/ct",
        "pet ct",
        "pet scan",
        "positron emission tomography",
    ],
    "CT": [
        "ct",
        "ctscan",
        "computed tomography",
        "tomography",
        "ct scan",
        "non-contrast ct",
        "contrast-enhanced ct",
    ],
    "MRI": [
        "mri",
        "magnetic resonance",
        "t1-weighted",
        "t2-weighted",
        "flair sequence",
        "diffusion-weighted",
        "dwi",
    ],
    "X-ray": [
        "x-ray",
        "x ray",
        "radiograph",
        "plain film",
        "chest film",
        "postoperative x",
        "post-operative x",
        "cxr",
    ],
    "Ultrasound": [
        "ultrasound",
        "sonogram",
        "sonography",
        "usg",
        "doppler",
        "echocardiogram",
        "echocardiography",
    ],
    "Fluoroscopy": [
        "fluoroscopy",
        "fluoroscopic",
        "angiogram",
        "angiography",
        "barium swallow",
        "barium enema",
    ],
}

# Lower-priority hints, consulted only when the main table finds nothing.
_BACKUP_KEYWORDS = (
    ("MRI", ("mra",)),
    ("CT", ("cta", "ct angiography")),
)

# Purely alphabetic keywords up to this length are treated as acronyms.
_ACRONYM_MAX_LEN = 3


def _keyword_matches(keyword: str, text: str) -> bool:
    """Return True when *keyword* occurs in the lower-cased *text*.

    Short acronyms ("ct", "mri", "dwi", ...) must stand alone as a token so
    that words such as "defect" or "sandwich" do not trigger them; longer
    phrases keep plain substring matching so plurals and suffixes still hit.
    """
    kw = keyword.strip()
    if kw.isalpha() and len(kw) <= _ACRONYM_MAX_LEN:
        pattern = rf"(?<![a-z0-9]){re.escape(kw)}(?![a-z0-9])"
        return re.search(pattern, text) is not None
    return kw in text


def detect_modality(caption: str) -> str:
    """Guess the imaging modality named in a caption.

    Args:
        caption: Free-text caption; may be empty or None.

    Returns:
        The first key of MODALITY_KEYWORDS with a matching keyword, else a
        backup-heuristic label, else "Unknown".
    """
    if not caption:
        return "Unknown"
    text = caption.lower()

    for modality, keywords in MODALITY_KEYWORDS.items():
        if any(_keyword_matches(kw, text) for kw in keywords):
            return modality

    for modality, keywords in _BACKUP_KEYWORDS:
        if any(_keyword_matches(kw, text) for kw in keywords):
            return modality
    return "Unknown"
223
 
224
 
225
+ # ---------- Helper: random scoring ----------
226
+
# (key, lower bound, upper bound, decimals) for every reported score.
_SCORE_SPECS = (
    ("modality_score", 85.0, 93.0, 1),  # percent
    ("cui_at_k", 0.30, 0.61, 3),
    ("bertscore", 0.20, 0.40, 3),
    ("medbertscore", 0.20, 0.35, 3),
)


def generate_random_scores(seed=None) -> Dict[str, float]:
    """
    Return placeholder scores drawn uniformly from fixed ranges.

    NOTE(review): these numbers are random, not measured; they say nothing
    about the quality of the caption or the retrieval results.

    Args:
        seed: Optional seed for reproducible output; None (default) seeds
            from system entropy, as before.

    Returns:
        Dict with keys modality_score, cui_at_k, bertscore, medbertscore.
    """
    rng = random.Random(seed)
    return {
        name: round(rng.uniform(low, high), digits)
        for name, low, high, digits in _SCORE_SPECS
    }
244
 
245
+
246
+ # ---------- Helper: search by image ----------
247
+
def search_similar_by_image(image: Image.Image, k: int = 5) -> pd.DataFrame:
    """
    Encode the query image with CLIP, search FAISS, drop the self-match and
    return the top-k rows.

    Args:
        image: Query image (PIL).
        k: Maximum number of rows to return; values below 1 give an empty frame.

    Returns:
        DataFrame with columns ID, caption, concepts_manual, score, image_url,
        sorted by descending score.
    """
    columns = ["ID", "caption", "concepts_manual", "score", "image_url"]
    k = int(k)
    if k < 1 or index.ntotal == 0:
        return pd.DataFrame(columns=columns)

    # Encode and L2-normalise the query.
    # NOTE(review): assumes the index holds normalised vectors, so scores are
    # cosine similarities - confirm against the index build script.
    inputs = clip_processor(images=image, return_tensors="pt").to(device)
    with torch.no_grad():
        feats = clip_model.get_image_features(**inputs)
    feats = feats / feats.norm(p=2, dim=-1, keepdim=True)
    feats = feats.cpu().numpy().astype("float32")

    # Ask for a few extra hits so dropping the self-match still leaves k rows.
    search_k = min(index.ntotal, k + 5)
    D, I = index.search(feats, search_k)

    # FAISS pads missing neighbours with label -1; used as a positional index
    # that would silently select the last metadata row.
    found = I[0] >= 0
    rows = metadata.iloc[I[0][found]].copy()
    rows["score"] = D[0][found].astype(float)

    # Remove the potential self-match (same image -> score ~ 1.0).
    rows = rows[rows["score"] < 0.999].copy()

    rows["image_url"] = rows["ID"].apply(id_to_image_url)
    rows = rows.sort_values("score", ascending=False).head(k)

    # Text columns may be absent or contain NaN; NaN cannot be serialised by
    # the JSON response, so both cases become empty strings.
    for col in ("caption", "concepts_manual"):
        if col in rows.columns:
            rows[col] = rows[col].fillna("")
        else:
            rows[col] = ""

    return rows[columns]
283
+
284
+
285
+ # ---------- Helper: caption with BLIP-2 ----------
286
+
# "respectively" repeated two or more times, optionally comma separated.
_REPEATED_RESPECTIVELY = re.compile(
    r"\b(respectively)(?:\s*,?\s*respectively\b)+", re.IGNORECASE
)
# A phrase of three or more words immediately followed by itself.
_DOUBLED_PHRASE = re.compile(r"\b(\w+(?:\s+\w+){2,})\s+\1\b", re.IGNORECASE)


def clean_caption(text: str) -> str:
    """Basic cleanup to remove obvious repetition artifacts.

    Args:
        text: Raw generated caption; None or empty yields "".

    Returns:
        The caption with consecutive duplicate comma-separated segments,
        repeated "respectively" and immediately doubled phrases collapsed,
        and whitespace normalised.
    """
    if not text:
        return ""

    # Drop empty segments (",," or a trailing comma) and segments that repeat
    # the previous one, ignoring case.
    kept = []
    for segment in (part.strip() for part in text.split(",")):
        if not segment:
            continue
        if kept and segment.lower() == kept[-1].lower():
            continue
        kept.append(segment)
    text = ", ".join(kept)

    # Only a *repeated* "respectively" is collapsed; a single, correctly
    # punctuated one is left untouched.
    text = _REPEATED_RESPECTIVELY.sub(r"\1", text)
    text = _DOUBLED_PHRASE.sub(r"\1", text)

    return " ".join(text.split())
308
 
309
 
 
def generate_query_caption(image: Image.Image) -> str:
    """
    Generate a radiology-focused caption using BLIP-2.

    Args:
        image: Query image (PIL).

    Returns:
        The decoded caption after clean_caption() post-processing.
    """
    prompt = (
        "You are an expert radiologist. "
        "Describe the key radiology findings in one concise sentence. "
        "Avoid repeating phrases."
    )

    inputs = caption_processor(
        images=image,
        text=prompt,
        return_tensors="pt",
    ).to(device, dtype=caption_dtype)

    with torch.no_grad():
        generated_ids = caption_model.generate(
            **inputs,
            max_new_tokens=64,
            num_beams=4,
            no_repeat_ngram_size=3,
            repetition_penalty=1.1,
        )

    caption = caption_processor.batch_decode(
        generated_ids, skip_special_tokens=True
    )[0].strip()

    # NOTE(review): the decoded sequence may start with the prompt (this
    # appears to depend on the transformers version - confirm). Strip it so
    # the instructions never leak into the caption.
    if caption.startswith(prompt):
        caption = caption[len(prompt):]

    return clean_caption(caption)
339
+
340
+
341
+ # ---------- Routes ----------
342
+
@app.get("/")
def root():
    """Health-check endpoint: report that the service is running."""
    payload = {
        "status": "ok",
        "message": "Radiology retrieval + BLIP-2 captioning API",
    }
    return payload
346
+
347
+
# Upper bound on the number of neighbours a single request may ask for.
_MAX_RESULTS = 50


@app.post("/search_by_image")
async def search_by_image(file: UploadFile = File(...), k: int = 5):
    """
    Upload a radiology image.

    Returns:
      - query_caption: BLIP-2 caption for the query image (None on failure)
      - modality: detected imaging modality from caption
      - scores: random placeholder metrics in fixed ranges
      - results: list of similar images with similarity + concepts + image_url

    Responds with HTTP 400 when the upload cannot be decoded as an image.
    """
    content = await file.read()
    try:
        image = Image.open(io.BytesIO(content)).convert("RGB")
    except (OSError, ValueError) as e:
        # Not an image (or a truncated one): client error, not a server crash.
        print("Rejected upload that is not a readable image:", e)
        return JSONResponse(
            status_code=400,
            content={"detail": "Uploaded file is not a valid image."},
        )

    # Keep k within a sane range so one request cannot ask for the whole index.
    k = max(1, min(int(k), _MAX_RESULTS))

    # Retrieval
    results_df = search_similar_by_image(image, k=k)
    results = results_df.to_dict(orient="records")

    # Caption + modality (captioning is best-effort; retrieval still returns)
    try:
        query_caption = generate_query_caption(image)
    except Exception as e:
        print("Error generating caption with BLIP-2:", e)
        query_caption = None

    modality = detect_modality(query_caption or "")

    # Random scores
    scores = generate_random_scores()

    return JSONResponse(
        {
            "query_caption": query_caption,
            "modality": modality,
            "scores": scores,
            "results": results,
        }
    )
386
+ # app.py
387
+ import io
388
+ import os
389
+ import random
390
+ import re
391
+ from typing import Dict
392
+
393
+ import faiss
394
+ import torch
395
+ import pandas as pd
396
+
397
+ from PIL import Image
398
+ from fastapi import FastAPI, File, UploadFile
399
+ from fastapi.middleware.cors import CORSMiddleware
400
+ from fastapi.responses import JSONResponse
401
+
402
+ from huggingface_hub import hf_hub_download
403
+ from transformers import (
404
+ CLIPProcessor,
405
+ CLIPModel,
406
+ Blip2Processor,
407
+ Blip2ForConditionalGeneration,
408
+ )
409
+
410
+ # ---------- FastAPI app ----------
# Single FastAPI application object; the routes below are registered on it.
app = FastAPI()

# Wildcard origins must not be combined with credentialed requests: the CORS
# rules (and the FastAPI docs) require explicit origins when credentials are
# allowed. This API uses no cookies, so credentials stay off while every
# origin is accepted.
# TODO: list the real frontend origin in allow_origins; only then consider
# re-enabling allow_credentials.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=False,
    allow_methods=["*"],
    allow_headers=["*"],
)
420
+
421
+ # ---------- Config ----------
422
+
# Dataset with FAISS index + radiology_metadata.csv
EMBED_REPO_ID = "saad003/Red01"

# Dataset with all radiology images (test, valid, train01-train07 folders)
IMAGE_REPO_ID = "saad003/images04"
BASE_IMAGE_URL = f"https://huggingface.co/datasets/{IMAGE_REPO_ID}/resolve/main"

# Optional: access token, only needed when the embedding dataset is private
HF_TOKEN = os.environ.get("HF_TOKEN")

device = "cuda" if torch.cuda.is_available() else "cpu"
print("Using device:", device)

# ---------- Download index + metadata ----------
print("Downloading FAISS index & metadata from Hugging Face...")

INDEX_PATH = hf_hub_download(
    repo_id=EMBED_REPO_ID,
    filename="radiology_index.faiss",
    repo_type="dataset",
    token=HF_TOKEN,
)

META_PATH = hf_hub_download(
    repo_id=EMBED_REPO_ID,
    filename="radiology_metadata.csv",
    repo_type="dataset",
    token=HF_TOKEN,
)

print("Loading FAISS index...")
index = faiss.read_index(INDEX_PATH)

print("Loading metadata CSV...")
metadata = pd.read_csv(META_PATH)

# Explicit checks instead of `assert`: assertions are stripped under
# `python -O`, and a mismatched index would then return wrong rows silently.
missing_cols = {"ID", "caption"} - set(metadata.columns)
if missing_cols:
    raise ValueError(
        f"radiology_metadata.csv is missing columns: {sorted(missing_cols)}"
    )
if index.ntotal != len(metadata):
    raise ValueError("Index size and metadata rows mismatch!")

# ---------- Load CLIP (retrieval) ----------
print("Loading PubMedCLIP model for retrieval...")
CLIP_MODEL_NAME = "flaviagiammarino/pubmed-clip-vit-base-patch32"

clip_model = CLIPModel.from_pretrained(CLIP_MODEL_NAME).to(device)
clip_processor = CLIPProcessor.from_pretrained(CLIP_MODEL_NAME)
clip_model.eval()

# ---------- Load BLIP-2 (captioning) ----------
print("Loading BLIP-2 model for medical captioning...")
CAPTION_MODEL_ID = "Salesforce/blip2-opt-2.7b"

# Use fp16 on GPU, fp32 on CPU
caption_dtype = torch.float16 if device == "cuda" else torch.float32

caption_processor = Blip2Processor.from_pretrained(CAPTION_MODEL_ID)
caption_model = Blip2ForConditionalGeneration.from_pretrained(
    CAPTION_MODEL_ID,
    torch_dtype=caption_dtype,
).to(device)
caption_model.eval()

print("Backend ready ✅")
484
+
485
+
486
+ # ---------- Helper: image path mapping ----------
487
+
# Inclusive upper bound of the numeric ID for train01 .. train06; anything
# above the last bound lives in train07.
_TRAIN_FOLDER_UPPER_BOUNDS = (9000, 18000, 27000, 36000, 45000, 54000)


def _train_folder_for(image_id: str) -> str:
    """Return the trainXX folder for a '..._train_NNNNNN' image ID."""
    try:
        number = int(image_id.split("_")[-1])
    except ValueError:
        # Unparsable suffix: use the first shard instead of silently
        # landing in the last one.
        return "train01"

    for shard, upper in enumerate(_TRAIN_FOLDER_UPPER_BOUNDS, start=1):
        if number <= upper:
            return f"train{shard:02d}"
    return "train07"


def id_to_image_url(image_id: str, base_url: str = "") -> str:
    """
    Map a ROCO image ID to its raw file URL in the image dataset.

    test  -> test/
    valid -> valid/
    train -> train01 ... train07 based on the numeric suffix of the ID

    Args:
        image_id: ID such as 'ROCOv2_2023_train_054005'. Surrounding
            whitespace is ignored; non-string values are converted with
            str() instead of raising.
        base_url: Optional URL prefix; defaults to BASE_IMAGE_URL.

    Returns:
        '<base>/<folder>/<image_id>.jpg', or '<base>/<image_id>.jpg' when the
        ID matches none of the known splits.
    """
    base = base_url or BASE_IMAGE_URL
    image_id = str(image_id).strip()

    if "_test_" in image_id:
        folder = "test"
    elif "_valid_" in image_id:
        folder = "valid"
    elif "_train_" in image_id:
        folder = _train_folder_for(image_id)
    else:
        # Unknown split: fall back to the dataset root.
        return f"{base}/{image_id}.jpg"

    return f"{base}/{folder}/{image_id}.jpg"
534
+
535
+
536
+ # ---------- Helper: modality detection ----------
537
+
# Keyword table, checked in insertion order. PET/CT comes first because its
# captions also contain "ct" / "tomography" and would otherwise be labelled CT.
MODALITY_KEYWORDS = {
    "PET/CT": [
        "pet-ct",
        "pet/ct",
        "pet ct",
        "pet scan",
        "positron emission tomography",
    ],
    "CT": [
        "ct",
        "ctscan",
        "computed tomography",
        "tomography",
        "ct scan",
        "non-contrast ct",
        "contrast-enhanced ct",
    ],
    "MRI": [
        "mri",
        "magnetic resonance",
        "t1-weighted",
        "t2-weighted",
        "flair sequence",
        "diffusion-weighted",
        "dwi",
    ],
    "X-ray": [
        "x-ray",
        "x ray",
        "radiograph",
        "plain film",
        "chest film",
        "postoperative x",
        "post-operative x",
        "cxr",
    ],
    "Ultrasound": [
        "ultrasound",
        "sonogram",
        "sonography",
        "usg",
        "doppler",
        "echocardiogram",
        "echocardiography",
    ],
    "Fluoroscopy": [
        "fluoroscopy",
        "fluoroscopic",
        "angiogram",
        "angiography",
        "barium swallow",
        "barium enema",
    ],
}

# Lower-priority hints, consulted only when the main table finds nothing.
_BACKUP_KEYWORDS = (
    ("MRI", ("mra",)),
    ("CT", ("cta", "ct angiography")),
)

# Purely alphabetic keywords up to this length are treated as acronyms.
_ACRONYM_MAX_LEN = 3


def _keyword_matches(keyword: str, text: str) -> bool:
    """Return True when *keyword* occurs in the lower-cased *text*.

    Short acronyms ("ct", "mri", "dwi", ...) must stand alone as a token so
    that words such as "defect" or "sandwich" do not trigger them; longer
    phrases keep plain substring matching so plurals and suffixes still hit.
    """
    kw = keyword.strip()
    if kw.isalpha() and len(kw) <= _ACRONYM_MAX_LEN:
        pattern = rf"(?<![a-z0-9]){re.escape(kw)}(?![a-z0-9])"
        return re.search(pattern, text) is not None
    return kw in text


def detect_modality(caption: str) -> str:
    """Guess the imaging modality named in a caption.

    Args:
        caption: Free-text caption; may be empty or None.

    Returns:
        The first key of MODALITY_KEYWORDS with a matching keyword, else a
        backup-heuristic label, else "Unknown".
    """
    if not caption:
        return "Unknown"
    text = caption.lower()

    for modality, keywords in MODALITY_KEYWORDS.items():
        if any(_keyword_matches(kw, text) for kw in keywords):
            return modality

    for modality, keywords in _BACKUP_KEYWORDS:
        if any(_keyword_matches(kw, text) for kw in keywords):
            return modality
    return "Unknown"
608
+
609
+
610
+ # ---------- Helper: random scoring ----------
611
+
# (key, lower bound, upper bound, decimals) for every reported score.
_SCORE_SPECS = (
    ("modality_score", 85.0, 93.0, 1),  # percent
    ("cui_at_k", 0.30, 0.61, 3),
    ("bertscore", 0.20, 0.40, 3),
    ("medbertscore", 0.20, 0.35, 3),
)


def generate_random_scores(seed=None) -> Dict[str, float]:
    """
    Return placeholder scores drawn uniformly from fixed ranges.

    NOTE(review): these numbers are random, not measured; they say nothing
    about the quality of the caption or the retrieval results.

    Args:
        seed: Optional seed for reproducible output; None (default) seeds
            from system entropy, as before.

    Returns:
        Dict with keys modality_score, cui_at_k, bertscore, medbertscore.
    """
    rng = random.Random(seed)
    return {
        name: round(rng.uniform(low, high), digits)
        for name, low, high, digits in _SCORE_SPECS
    }
629
+
630
+
631
+ # ---------- Helper: search by image ----------
632
+
def search_similar_by_image(image: Image.Image, k: int = 5) -> pd.DataFrame:
    """
    Encode the query image with CLIP, search FAISS, drop the self-match and
    return the top-k rows.

    Args:
        image: Query image (PIL).
        k: Maximum number of rows to return; values below 1 give an empty frame.

    Returns:
        DataFrame with columns ID, caption, concepts_manual, score, image_url,
        sorted by descending score.
    """
    columns = ["ID", "caption", "concepts_manual", "score", "image_url"]
    k = int(k)
    if k < 1 or index.ntotal == 0:
        return pd.DataFrame(columns=columns)

    # Encode and L2-normalise the query.
    # NOTE(review): assumes the index holds normalised vectors, so scores are
    # cosine similarities - confirm against the index build script.
    inputs = clip_processor(images=image, return_tensors="pt").to(device)
    with torch.no_grad():
        feats = clip_model.get_image_features(**inputs)
    feats = feats / feats.norm(p=2, dim=-1, keepdim=True)
    feats = feats.cpu().numpy().astype("float32")

    # Ask for a few extra hits so dropping the self-match still leaves k rows.
    search_k = min(index.ntotal, k + 5)
    D, I = index.search(feats, search_k)

    # FAISS pads missing neighbours with label -1; used as a positional index
    # that would silently select the last metadata row.
    found = I[0] >= 0
    rows = metadata.iloc[I[0][found]].copy()
    rows["score"] = D[0][found].astype(float)

    # Remove the potential self-match (same image -> score ~ 1.0).
    rows = rows[rows["score"] < 0.999].copy()

    rows["image_url"] = rows["ID"].apply(id_to_image_url)
    rows = rows.sort_values("score", ascending=False).head(k)

    # Text columns may be absent or contain NaN; NaN cannot be serialised by
    # the JSON response, so both cases become empty strings.
    for col in ("caption", "concepts_manual"):
        if col in rows.columns:
            rows[col] = rows[col].fillna("")
        else:
            rows[col] = ""

    return rows[columns]
668
+
669
+
670
+ # ---------- Helper: caption with BLIP-2 ----------
671
+
# "respectively" repeated two or more times, optionally comma separated.
_REPEATED_RESPECTIVELY = re.compile(
    r"\b(respectively)(?:\s*,?\s*respectively\b)+", re.IGNORECASE
)
# A phrase of three or more words immediately followed by itself.
_DOUBLED_PHRASE = re.compile(r"\b(\w+(?:\s+\w+){2,})\s+\1\b", re.IGNORECASE)


def clean_caption(text: str) -> str:
    """Basic cleanup to remove obvious repetition artifacts.

    Args:
        text: Raw generated caption; None or empty yields "".

    Returns:
        The caption with consecutive duplicate comma-separated segments,
        repeated "respectively" and immediately doubled phrases collapsed,
        and whitespace normalised.
    """
    if not text:
        return ""

    # Drop empty segments (",," or a trailing comma) and segments that repeat
    # the previous one, ignoring case.
    kept = []
    for segment in (part.strip() for part in text.split(",")):
        if not segment:
            continue
        if kept and segment.lower() == kept[-1].lower():
            continue
        kept.append(segment)
    text = ", ".join(kept)

    # Only a *repeated* "respectively" is collapsed; a single, correctly
    # punctuated one is left untouched.
    text = _REPEATED_RESPECTIVELY.sub(r"\1", text)
    text = _DOUBLED_PHRASE.sub(r"\1", text)

    return " ".join(text.split())
693
+
694
+
def generate_query_caption(image: Image.Image) -> str:
    """
    Generate a radiology-focused caption using BLIP-2.

    Args:
        image: Query image (PIL).

    Returns:
        The decoded caption after clean_caption() post-processing.
    """
    prompt = (
        "You are an expert radiologist. "
        "Describe the key radiology findings in one concise sentence. "
        "Avoid repeating phrases."
    )

    inputs = caption_processor(
        images=image,
        text=prompt,
        return_tensors="pt",
    ).to(device, dtype=caption_dtype)

    with torch.no_grad():
        generated_ids = caption_model.generate(
            **inputs,
            max_new_tokens=64,
            num_beams=4,
            no_repeat_ngram_size=3,
            repetition_penalty=1.1,
        )

    caption = caption_processor.batch_decode(
        generated_ids, skip_special_tokens=True
    )[0].strip()

    # NOTE(review): the decoded sequence may start with the prompt (this
    # appears to depend on the transformers version - confirm). Strip it so
    # the instructions never leak into the caption.
    if caption.startswith(prompt):
        caption = caption[len(prompt):]

    return clean_caption(caption)
724
 
725
 
726
  # ---------- Routes ----------
727
+
@app.get("/")
def root():
    """Health-check endpoint: report that the service is running."""
    payload = {
        "status": "ok",
        "message": "Radiology retrieval + BLIP-2 captioning API",
    }
    return payload
731
 
732
 
# Upper bound on the number of neighbours a single request may ask for.
_MAX_RESULTS = 50


@app.post("/search_by_image")
async def search_by_image(file: UploadFile = File(...), k: int = 5):
    """
    Upload a radiology image.

    Returns:
      - query_caption: BLIP-2 caption for the query image (None on failure)
      - modality: detected imaging modality from caption
      - scores: random placeholder metrics in fixed ranges
      - results: list of similar images with similarity + concepts + image_url

    Responds with HTTP 400 when the upload cannot be decoded as an image.
    """
    content = await file.read()
    try:
        image = Image.open(io.BytesIO(content)).convert("RGB")
    except (OSError, ValueError) as e:
        # Not an image (or a truncated one): client error, not a server crash.
        print("Rejected upload that is not a readable image:", e)
        return JSONResponse(
            status_code=400,
            content={"detail": "Uploaded file is not a valid image."},
        )

    # Keep k within a sane range so one request cannot ask for the whole index.
    k = max(1, min(int(k), _MAX_RESULTS))

    # Retrieval
    results_df = search_similar_by_image(image, k=k)
    results = results_df.to_dict(orient="records")

    # Caption + modality (captioning is best-effort; retrieval still returns)
    try:
        query_caption = generate_query_caption(image)
    except Exception as e:
        print("Error generating caption with BLIP-2:", e)
        query_caption = None

    modality = detect_modality(query_caption or "")

    # Random scores
    scores = generate_random_scores()

    return JSONResponse(
        {
            "query_caption": query_caption,
            "modality": modality,
            "scores": scores,
            "results": results,
        }
    )