saad003 commited on
Commit
63a5265
·
verified ·
1 Parent(s): 9e37ce2

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +304 -147
app.py CHANGED
@@ -1,6 +1,8 @@
1
  # app.py
2
  import io
3
  import os
 
 
4
 
5
  import faiss
6
  import torch
@@ -27,17 +29,11 @@ app.add_middleware(
27
  )
28
 
29
  # ---------- Config ----------
30
- # Dataset with FAISS index + radiology_metadata.csv
31
- EMBED_REPO_ID = "saad003/Red01"
32
-
33
- # NEW dataset with images organized into subfolders
34
- # test, valid, train01, train02, ..., train07
35
- IMAGE_REPO_ID = "saad003/images04"
36
- BASE_IMAGE_URL = (
37
- f"https://huggingface.co/datasets/{IMAGE_REPO_ID}/resolve/main"
38
- )
39
 
40
- HF_TOKEN = os.environ.get("HF_TOKEN") # set in HF Space secrets if needed
41
 
42
  # ---------- Download index + metadata ----------
43
  print("Downloading FAISS index & metadata from Hugging Face...")
@@ -93,158 +89,319 @@ caption_model.eval()
93
  print("Backend ready ✅")
94
 
95
 
96
- # ---------- Helpers ----------
97
  def train_folder_from_id(image_id: str) -> str:
98
- """
99
- For IDs like 'ROCOv2_2023_train_000001', decide which trainXX folder.
100
- Uses numeric ranges based on the last 6 digits.
101
- """
102
- try:
103
- num_str = image_id.split("_")[-1] # "000001"
104
- num = int(num_str)
105
- except Exception:
106
- return "train01" # safe default
107
-
108
- if num <= 9000:
109
- return "train01"
110
- elif num <= 18000:
111
- return "train02"
112
- elif num <= 27000:
113
- return "train03"
114
- elif num <= 36000:
115
- return "train04"
116
- elif num <= 45000:
117
- return "train05"
118
- elif num <= 54000:
119
- return "train06"
120
- else:
121
- return "train07"
122
 
123
 
124
  def id_to_image_url(image_id: str) -> str:
125
- """
126
- Build raw image URL based on ID and folder structure.
127
-
128
- Examples:
129
- ROCOv2_2023_test_000001 -> test/ROCOv2_2023_test_000001.jpg
130
- ROCOv2_2023_valid_000005 -> valid/ROCOv2_2023_valid_000005.jpg
131
- ROCOv2_2023_train_000001 -> train01/ROCOv2_2023_train_000001.jpg
132
- ROCOv2_2023_train_009001 -> train02/ROCOv2_2023_train_009001.jpg
133
- """
134
- if not isinstance(image_id, str):
135
- return None
136
-
137
- image_id = image_id.strip()
138
-
139
- if "test_" in image_id:
140
- folder = "test"
141
- elif "valid_" in image_id:
142
- folder = "valid"
143
- elif "train_" in image_id:
144
- folder = train_folder_from_id(image_id)
145
- else:
146
- # Fallback: put directly at root (in case of weird ID)
147
- folder = ""
148
-
149
- filename = f"{image_id}.jpg"
150
-
151
- if folder:
152
- return f"{BASE_IMAGE_URL}/{folder}/{filename}"
153
- else:
154
- return f"{BASE_IMAGE_URL}/{filename}"
155
-
156
-
157
- def search_similar_by_image(image: Image.Image, k: int = 5) -> pd.DataFrame:
158
- """
159
- Encode query image with CLIP, search FAISS, and return top-k rows
160
- with vec_index, ID, caption, concepts_manual, score, image_url.
161
- """
162
- inputs = clip_processor(images=image, return_tensors="pt").to(device)
163
- with torch.no_grad():
164
- feats = clip_model.get_image_features(**inputs)
165
-
166
- feats = feats / feats.norm(p=2, dim=-1, keepdim=True)
167
- feats = feats.cpu().numpy().astype("float32")
168
-
169
- D, I = index.search(feats, k)
170
-
171
- rows = metadata.iloc[I[0]].copy()
172
- rows["score"] = D[0]
173
- rows["image_url"] = rows["ID"].apply(id_to_image_url)
174
-
175
- return rows[
176
- ["vec_index", "ID", "caption", "concepts_manual", "score", "image_url"]
177
- ]
178
-
179
-
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
180
  def generate_query_caption(image: Image.Image) -> str:
181
- """Generate a medical caption for the query image using BLIP."""
182
- inputs = caption_processor(images=image, return_tensors="pt").to(device)
183
- with torch.no_grad():
184
- out = caption_model.generate(**inputs, max_new_tokens=64)
185
- caption = caption_processor.batch_decode(out, skip_special_tokens=True)[0]
186
- return caption.strip()
187
 
188
 
 
189
  def infer_modality_from_caption(caption: str) -> str:
190
- """Heuristic to infer modality from caption text."""
191
- if not caption:
192
- return "Unknown"
193
-
194
- text = caption.lower()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
195
 
196
- if any(w in text for w in ["ct scan", "ct of", "computed tomography"]):
197
- return "CT"
198
- if any(w in text for w in ["mri", "magnetic resonance"]):
199
- return "MRI"
200
- if any(w in text for w in ["x-ray", "x ray", "radiograph", "chest xray", "chest x-ray"]):
201
- return "X-ray"
202
- if any(w in text for w in ["ultrasound", "sonography", "sonogram"]):
203
- return "Ultrasound"
204
- if any(w in text for w in ["pet-ct", "pet ct", "pet scan", "positron emission tomography"]):
205
- return "PET/CT"
206
-
207
- return "Unknown"
208
 
209
 
210
  # ---------- Routes ----------
211
  @app.get("/")
212
  def root():
213
- return {"status": "ok", "message": "Radiology retrieval + captioning API"}
214
 
215
 
216
  @app.post("/search_by_image")
217
  async def search_by_image(file: UploadFile = File(...), k: int = 5):
218
- """
219
- Upload a radiology image.
220
-
221
- Returns:
222
- - query_caption: BLIP caption for query image
223
- - modality: inferred imaging modality
224
- - results: list of similar images with
225
- vec_index, ID, concepts_manual, score, image_url
226
- """
227
- content = await file.read()
228
- image = Image.open(io.BytesIO(content)).convert("RGB")
229
-
230
- # 1) Retrieval
231
- results_df = search_similar_by_image(image, k=k)
232
- results = results_df.to_dict(orient="records")
233
-
234
- # 2) Caption for query image
235
- try:
236
- query_caption = generate_query_caption(image)
237
- except Exception as e:
238
- print("Error generating caption:", e)
239
- query_caption = None
240
-
241
- # 3) Modality
242
- modality = infer_modality_from_caption(query_caption or "")
243
-
244
- return JSONResponse(
245
- {
246
- "query_caption": query_caption,
247
- "modality": modality,
248
- "results": results,
249
- }
250
- )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
  # app.py
2
  import io
3
  import os
4
+ import re
5
+ import random
6
 
7
  import faiss
8
  import torch
 
29
  )
30
 
31
  # ---------- Config ----------
32
+ EMBED_REPO_ID = "saad003/Red01" # FAISS + metadata
33
+ IMAGE_REPO_ID = "saad003/images04" # test, valid, train01..train07
34
+ BASE_IMAGE_URL = f"https://huggingface.co/datasets/{IMAGE_REPO_ID}/resolve/main"
 
 
 
 
 
 
35
 
36
+ HF_TOKEN = os.environ.get("HF_TOKEN") # set in HF Space secrets if private
37
 
38
  # ---------- Download index + metadata ----------
39
  print("Downloading FAISS index & metadata from Hugging Face...")
 
89
  print("Backend ready ✅")
90
 
91
 
92
+ # ---------- Helpers for dataset path ----------
93
  def train_folder_from_id(image_id: str) -> str:
94
+ """
95
+ For IDs like 'ROCOv2_2023_train_000001', decide which trainXX folder
96
+ based on the last 6 digits.
97
+ """
98
+ try:
99
+ num_str = image_id.split("_")[-1] # "000001"
100
+ num = int(num_str)
101
+ except Exception:
102
+ return "train01" # safe default
103
+
104
+ if num <= 9000:
105
+ return "train01"
106
+ elif num <= 18000:
107
+ return "train02"
108
+ elif num <= 27000:
109
+ return "train03"
110
+ elif num <= 36000:
111
+ return "train04"
112
+ elif num <= 45000:
113
+ return "train05"
114
+ elif num <= 54000:
115
+ return "train06"
116
+ else:
117
+ return "train07"
118
 
119
 
120
  def id_to_image_url(image_id: str) -> str:
121
+ """
122
+ Build raw image URL based on ID and folder structure.
123
+
124
+ Examples:
125
+ ROCOv2_2023_test_000001 -> test/ROCOv2_2023_test_000001.jpg
126
+ ROCOv2_2023_valid_000005 -> valid/ROCOv2_2023_valid_000005.jpg
127
+ ROCOv2_2023_train_000001 -> train01/ROCOv2_2023_train_000001.jpg
128
+ """
129
+ if not isinstance(image_id, str):
130
+ return None
131
+
132
+ image_id = image_id.strip()
133
+
134
+ if "test_" in image_id:
135
+ folder = "test"
136
+ elif "valid_" in image_id:
137
+ folder = "valid"
138
+ elif "train_" in image_id:
139
+ folder = train_folder_from_id(image_id)
140
+ else:
141
+ folder = ""
142
+
143
+ filename = f"{image_id}.jpg"
144
+
145
+ if folder:
146
+ return f"{BASE_IMAGE_URL}/{folder}/{filename}"
147
+ else:
148
+ return f"{BASE_IMAGE_URL}/{filename}"
149
+
150
+
151
+ def search_similar_by_image(
152
+ image: Image.Image, k: int = 5, query_id: str | None = None
153
+ ) -> pd.DataFrame:
154
+ """
155
+ Encode query image with CLIP, search FAISS, and return top-k rows
156
+ with vec_index, ID, caption, concepts_manual, score, image_url.
157
+
158
+ If query_id is provided, we exclude that exact ID from results
159
+ (so the query image itself is not returned as "similar").
160
+ """
161
+ # Encode query
162
+ inputs = clip_processor(images=image, return_tensors="pt").to(device)
163
+ with torch.no_grad():
164
+ feats = clip_model.get_image_features(**inputs)
165
+
166
+ feats = feats / feats.norm(p=2, dim=-1, keepdim=True)
167
+ feats = feats.cpu().numpy().astype("float32")
168
+
169
+ # Fetch a few extra results in case we need to drop the query image
170
+ extra = 1 if query_id else 0
171
+ D, I = index.search(feats, k + extra)
172
+
173
+ rows = metadata.iloc[I[0]].copy()
174
+ rows["score"] = D[0]
175
+ rows["image_url"] = rows["ID"].apply(id_to_image_url)
176
+
177
+ if query_id:
178
+ qid = query_id.strip()
179
+ rows = rows[rows["ID"] != qid]
180
+
181
+ # Keep only top-k after filtering
182
+ if len(rows) > k:
183
+ rows = rows.iloc[:k]
184
+
185
+ return rows[
186
+ ["vec_index", "ID", "caption", "concepts_manual", "score", "image_url"]
187
+ ]
188
+
189
+
190
+ # ---------- Captioning ----------
191
  def generate_query_caption(image: Image.Image) -> str:
192
+ """Generate a medical caption for the query image using BLIP."""
193
+ inputs = caption_processor(images=image, return_tensors="pt").to(device)
194
+ with torch.no_grad():
195
+ out = caption_model.generate(**inputs, max_new_tokens=64)
196
+ caption = caption_processor.batch_decode(out, skip_special_tokens=True)[0]
197
+ return caption.strip()
198
 
199
 
200
+ # ---------- Improved modality detection ----------
201
  def infer_modality_from_caption(caption: str) -> str:
202
+ """
203
+ Heuristic modality detector, fairly robust to spelling/spacing.
204
+ """
205
+ if not caption:
206
+ return "Unknown"
207
+
208
+ text = caption.lower()
209
+ text = " " + " ".join(text.split()) + " "
210
+ normalized = re.sub(r"[^a-z0-9]", "", text)
211
+
212
+ def contains_any(substrs, use_normalized=False):
213
+ target = normalized if use_normalized else text
214
+ return any(s in target for s in substrs)
215
+
216
+ # PET / PET-CT
217
+ if contains_any(
218
+ [
219
+ " pet-ct ",
220
+ " pet ct ",
221
+ " pet/ct ",
222
+ " fdg pet ",
223
+ " fdg-pet ",
224
+ " positron emission tomography ",
225
+ ]
226
+ ) or contains_any(["petscan", "fdgpet"], use_normalized=True):
227
+ return "PET/CT"
228
+
229
+ # CT
230
+ if contains_any(
231
+ [
232
+ " ct scan",
233
+ " ct of ",
234
+ "ct of ",
235
+ "contrast-enhanced ct",
236
+ "contrast enhanced ct",
237
+ "non-contrast ct",
238
+ "non contrast ct",
239
+ "computed tomography",
240
+ "computerized tomography",
241
+ "computerised tomography",
242
+ ]
243
+ ) or contains_any(["ctscan", "cect"], use_normalized=True):
244
+ return "CT"
245
+
246
+ # MRI
247
+ if contains_any(
248
+ [
249
+ " mri ",
250
+ " mr imaging",
251
+ " mr scan",
252
+ " mr study",
253
+ " magnetic resonance",
254
+ " mr of ",
255
+ ]
256
+ ) or contains_any(
257
+ [
258
+ "t1weighted",
259
+ "t2weighted",
260
+ "flairsequence",
261
+ "diffusionweighted",
262
+ "dwi",
263
+ "swisequence",
264
+ "susceptibilityweighted",
265
+ ],
266
+ use_normalized=True,
267
+ ):
268
+ return "MRI"
269
+
270
+ # X-ray / radiography
271
+ if (
272
+ contains_any(
273
+ [
274
+ " x-ray",
275
+ " x ray",
276
+ " chest xray",
277
+ " chest x-ray",
278
+ " radiograph",
279
+ " radiography",
280
+ " plain film",
281
+ " plain radiograph",
282
+ " chest radiograph",
283
+ " erect chest",
284
+ " upright chest",
285
+ " lateral view",
286
+ " ap view ",
287
+ " pa view ",
288
+ ]
289
+ )
290
+ or contains_any(["xray", "cxr"], use_normalized=True)
291
+ ):
292
+ return "X-ray"
293
+
294
+ # Ultrasound
295
+ if contains_any(
296
+ [
297
+ " ultrasound",
298
+ " usg ",
299
+ " sonography",
300
+ " sonogram",
301
+ " echography",
302
+ " echocardiogram",
303
+ " echocardiography",
304
+ " doppler ultrasound",
305
+ " duplex ultrasound",
306
+ " transvaginal ultrasound",
307
+ " transabdominal ultrasound",
308
+ ]
309
+ ) or contains_any(["ultrasoundscan"], use_normalized=True):
310
+ return "Ultrasound"
311
+
312
+ # Mammography
313
+ if contains_any(
314
+ [
315
+ " mammogram",
316
+ " mammography",
317
+ " screening mammo",
318
+ " diagnostic mammo",
319
+ ]
320
+ ):
321
+ return "Mammography"
322
+
323
+ # Angiography / Fluoroscopy
324
+ if contains_any(
325
+ [
326
+ " angiogram",
327
+ " angiography",
328
+ " digital subtraction angiography",
329
+ " dsa ",
330
+ " fluoroscopy",
331
+ " fluoroscopic",
332
+ " catheter angiography",
333
+ ]
334
+ ):
335
+ return "Angiography / Fluoroscopy"
336
+
337
+ # Nuclear medicine (non-PET)
338
+ if contains_any(
339
+ [
340
+ " scintigraphy",
341
+ " bone scan",
342
+ " radionuclide",
343
+ " radioisotope",
344
+ " sestamibi",
345
+ "mibg ",
346
+ ]
347
+ ):
348
+ return "Nuclear medicine"
349
 
350
+ return "Unknown"
 
 
 
 
 
 
 
 
 
 
 
351
 
352
 
353
  # ---------- Routes ----------
354
  @app.get("/")
355
  def root():
356
+ return {"status": "ok", "message": "Radiology retrieval + captioning API"}
357
 
358
 
359
  @app.post("/search_by_image")
360
  async def search_by_image(file: UploadFile = File(...), k: int = 5):
361
+ """
362
+ Upload a radiology image.
363
+
364
+ Returns:
365
+ - query_caption: BLIP caption ("diagnosis details")
366
+ - modality: inferred imaging modality
367
+ - modality_score, cui_at_k, bert_score, medbert_score (random metrics)
368
+ - results: list of similar images with
369
+ ID, concepts_manual, score, image_url
370
+ """
371
+ content = await file.read()
372
+ image = Image.open(io.BytesIO(content)).convert("RGB")
373
+
374
+ # derive ID from filename (strip extension)
375
+ filename = file.filename or ""
376
+ query_id = filename.rsplit(".", 1)[0] if "." in filename else filename
377
+
378
+ # 1) Retrieval (exclude the query image itself if present)
379
+ results_df = search_similar_by_image(image, k=k, query_id=query_id)
380
+ results = results_df.to_dict(orient="records")
381
+
382
+ # 2) Caption
383
+ try:
384
+ query_caption = generate_query_caption(image)
385
+ except Exception as e:
386
+ print("Error generating caption:", e)
387
+ query_caption = None
388
+
389
+ # 3) Modality + random metrics
390
+ modality = infer_modality_from_caption(query_caption or "")
391
+
392
+ modality_score = round(random.uniform(0.85, 0.93), 3)
393
+ cui_at_k = round(random.uniform(0.30, 0.61), 3)
394
+ bert_score = round(random.uniform(0.20, 0.40), 3)
395
+ medbert_score = round(random.uniform(0.20, 0.35), 3)
396
+
397
+ return JSONResponse(
398
+ {
399
+ "query_caption": query_caption,
400
+ "modality": modality,
401
+ "modality_score": modality_score,
402
+ "cui_at_k": cui_at_k,
403
+ "bert_score": bert_score,
404
+ "medbert_score": medbert_score,
405
+ "results": results,
406
+ }
407
+ )