saad003 committed
Commit 602ea6a · verified · 1 Parent(s): 4f357eb

Update app.py

Files changed (1): app.py (+100, -45)
app.py CHANGED
@@ -20,6 +20,8 @@ from transformers import (
     CLIPModel,
     BlipForConditionalGeneration,
     AutoProcessor,
+    AutoTokenizer,
+    AutoModelForSeq2SeqLM,
 )
 
 # ---------- FastAPI app ----------
@@ -35,21 +37,16 @@ app.add_middleware(
 
 # ---------- Config ----------
 
-# Dataset with FAISS index + radiology_metadata.csv
-EMBED_REPO_ID = "saad003/Red01"
-
-# Dataset with all radiology images (test, valid, train01–train07)
-IMAGE_REPO_ID = "saad003/images04"
+EMBED_REPO_ID = "saad003/Red01"     # FAISS + metadata
+IMAGE_REPO_ID = "saad003/images04"  # images04 with test/valid/train01–train07
 BASE_IMAGE_URL = f"https://huggingface.co/datasets/{IMAGE_REPO_ID}/resolve/main"
 
-# Optional: token if Red01 is private
 HF_TOKEN = os.environ.get("HF_TOKEN")
 
 device = "cuda" if torch.cuda.is_available() else "cpu"
 print("Using device:", device)
 
-# use fp16 on GPU to speed up BLIP, fp32 on CPU
-caption_dtype = torch.float16 if device == "cuda" else torch.float32
+cap_dtype = torch.float16 if device == "cuda" else torch.float32
 
 # ---------- Download index + metadata ----------
 print("Downloading FAISS index & metadata from Hugging Face...")
@@ -73,7 +70,6 @@ index = faiss.read_index(INDEX_PATH)
 
 print("Loading metadata CSV...")
 metadata = pd.read_csv(META_PATH)
-
 assert index.ntotal == len(metadata), "Index size and metadata rows mismatch!"
 
 # ---------- Load CLIP (retrieval) ----------
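Note: the startup assert only holds if index rows were added in the same order as the metadata CSV. A sketch of building such an aligned pair offline (function name and index type are assumptions, not the repo's code):

    import faiss
    import numpy as np
    import pandas as pd

    def build_aligned_index(embs: np.ndarray, meta: pd.DataFrame) -> faiss.Index:
        embs = np.ascontiguousarray(embs, dtype=np.float32)
        faiss.normalize_L2(embs)             # inner product == cosine for unit vectors
        index = faiss.IndexFlatIP(embs.shape[1])
        index.add(embs)                      # row i of the index <-> row i of the CSV
        assert index.ntotal == len(meta)     # the invariant app.py checks at startup
        return index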
@@ -84,17 +80,27 @@ clip_model = CLIPModel.from_pretrained(CLIP_MODEL_NAME).to(device)
 clip_processor = CLIPProcessor.from_pretrained(CLIP_MODEL_NAME)
 clip_model.eval()
 
-# ---------- Load BLIP (radiology captioning) ----------
+# ---------- Load BLIP (image -> draft caption) ----------
 print("Loading BLIP ROCO radiology captioning model...")
 CAPTION_MODEL_ID = "WafaaFraih/blip-roco-radiology-captioning"
 
 caption_processor = AutoProcessor.from_pretrained(CAPTION_MODEL_ID)
 caption_model = BlipForConditionalGeneration.from_pretrained(
     CAPTION_MODEL_ID,
-    torch_dtype=caption_dtype,
+    torch_dtype=cap_dtype,
 ).to(device)
 caption_model.eval()
 
+# ---------- Load FLAN-T5 (text refinement using similar captions) ----------
+print("Loading FLAN-T5 for caption refinement...")
+REFINER_MODEL_ID = "google/flan-t5-base"
+
+refiner_tokenizer = AutoTokenizer.from_pretrained(REFINER_MODEL_ID)
+refiner_model = AutoModelForSeq2SeqLM.from_pretrained(
+    REFINER_MODEL_ID
+).to(device)
+refiner_model.eval()
+
 print("Backend ready ✅")
@@ -142,7 +148,6 @@ def id_to_image_url(image_id: str) -> str:
     if folder:
         return f"{base}/{folder}/{image_id}.jpg"
     else:
-        # fallback – should not happen, but safe
         return f"{base}/{image_id}.jpg"
@@ -152,9 +157,9 @@ MODALITY_KEYWORDS = {
     "CT": [
         "ct ",
         "ctscan",
+        "ct scan",
         "computed tomography",
         "tomography",
-        "ct scan",
         "non-contrast ct",
         "contrast-enhanced ct",
     ],
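Note: moving "ct scan" ahead of "computed tomography" only changes behavior if matching stops at the first hit. detect_modality itself is not in this diff; a minimal first-match reading consistent with MODALITY_KEYWORDS would be (a sketch, not the repo's implementation):

    def detect_modality_sketch(caption: str) -> str:
        text = caption.lower()
        for modality, keywords in MODALITY_KEYWORDS.items():
            if any(kw in text for kw in keywords):
                return modality
        return "Unknown"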
@@ -222,9 +227,6 @@ def detect_modality(caption: str) -> str:
 # ---------- Helper: random scoring ----------
 
 def generate_random_scores() -> Dict[str, float]:
-    """
-    Return random scores in the ranges you specified.
-    """
     rng = random.Random()
 
     modality_score = rng.uniform(85.0, 93.0)  # percent
@@ -272,11 +274,11 @@ def search_similar_by_image(image: Image.Image, k: int = 5) -> pd.DataFrame:
     return rows[["ID", "caption", "concepts_manual", "score", "image_url"]]
 
 
-# ---------- Helper: caption cleaning & generation ----------
+# ---------- Caption cleaning & generation ----------
 
 def clean_caption(text: str) -> str:
     """
-    Clean BLIP captions:
+    Clean captions:
     - strip
     - split into clauses and remove duplicates
     - normalize spacing and punctuation
@@ -286,7 +288,6 @@ def clean_caption(text: str) -> str:
 
     text = text.strip()
 
-    # break into clauses
     parts = re.split(r"[,.]", text)
     parts = [p.strip() for p in parts if p.strip()]
 
@@ -298,12 +299,11 @@ def clean_caption(text: str) -> str:
             seen.add(key)
             unique_parts.append(p)
 
-    if not unique_parts:
-        cleaned = text
-    else:
+    if unique_parts:
         cleaned = ", ".join(unique_parts)
+    else:
+        cleaned = text
 
-    # remove repeated 'respectively'
     cleaned = re.sub(
         r"(respectively,?\s+)+", "respectively ", cleaned, flags=re.IGNORECASE
     )
@@ -311,19 +311,18 @@ def clean_caption(text: str) -> str:
     cleaned = " ".join(cleaned.split())
     if cleaned and not cleaned.endswith("."):
         cleaned += "."
-    cleaned = cleaned[0].upper() + cleaned[1:] if cleaned else cleaned
+    if cleaned:
+        cleaned = cleaned[0].upper() + cleaned[1:]
     return cleaned
 
 
-def generate_query_caption(image: Image.Image) -> str:
+def generate_draft_caption(image: Image.Image) -> str:
     """
-    Generate a radiology caption using BLIP (ROCO).
-    Tuned decoding to reduce repetition and keep it concise.
+    Draft caption directly from image using BLIP.
     """
     inputs = caption_processor(images=image, return_tensors="pt").to(
-        device, dtype=caption_dtype
+        device, dtype=cap_dtype
     )
-
     with torch.no_grad():
         out_ids = caption_model.generate(
             **inputs,
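Note: with the clause dedup plus the new capitalization guard, clean_caption behaves like this worked example:

    clean_caption("ct scan of the chest, ct scan of the chest, axial view")
    # -> "Ct scan of the chest, axial view."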
@@ -334,11 +333,53 @@ def generate_query_caption(image: Image.Image) -> str:
             length_penalty=0.9,
             early_stopping=True,
         )
-
-    raw_caption = caption_processor.batch_decode(
-        out_ids, skip_special_tokens=True
-    )[0]
-    return clean_caption(raw_caption)
+    raw = caption_processor.batch_decode(out_ids, skip_special_tokens=True)[0]
+    return clean_caption(raw)
+
+
+def refine_caption_with_similar_cases(
+    draft_caption: str,
+    similar_captions: str,
+) -> str:
+    """
+    Use FLAN-T5 to rewrite a final diagnosis sentence based on:
+    - draft caption from BLIP (current image)
+    - captions from similar images
+    """
+    if not draft_caption:
+        draft_caption = "No draft description available."
+    if not similar_captions:
+        # nothing to refine with; just return draft
+        return draft_caption
+
+    prompt = (
+        "You are an expert radiologist.\n\n"
+        "Draft findings from the current image:\n"
+        f"{draft_caption}\n\n"
+        "Findings from similar radiology cases:\n"
+        f"{similar_captions}\n\n"
+        "Based on all of this, write ONE concise radiology impression "
+        "sentence describing the most probable diagnosis and key findings "
+        "for the current image. Do not mention 'similar cases' or 'draft'."
+    )
+
+    inputs = refiner_tokenizer(
+        prompt,
+        return_tensors="pt",
+        truncation=True,
+        max_length=512,
+    ).to(device)
+
+    with torch.no_grad():
+        out_ids = refiner_model.generate(
+            **inputs,
+            max_new_tokens=64,
+            num_beams=4,
+            length_penalty=0.9,
+            no_repeat_ngram_size=4,
+        )
+    refined = refiner_tokenizer.decode(out_ids[0], skip_special_tokens=True)
+    return clean_caption(refined)
 
 
 # ---------- Routes ----------
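Note: a usage sketch for the new refinement step (caption strings invented for illustration):

    draft = "Axial ct image of the abdomen."
    similar = "; ".join([
        "CT shows a hypodense lesion in the liver.",
        "Contrast-enhanced abdominal CT, hepatic cyst.",
    ])
    print(refine_caption_with_similar_cases(draft, similar))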
@@ -347,7 +388,7 @@ def generate_query_caption(image: Image.Image) -> str:
 def root():
     return {
         "status": "ok",
-        "message": "Radiology retrieval + BLIP radiology captioning API",
+        "message": "Radiology retrieval + BLIP + FLAN-T5 refinement API",
     }
 
 
@@ -356,33 +397,47 @@ async def search_by_image(file: UploadFile = File(...), k: int = 5):
     """
     Upload a radiology image.
     Returns:
-    - query_caption: BLIP caption for the query image
-    - modality: detected imaging modality from caption
+    - query_caption: refined caption using draft + similar cases
+    - modality: detected imaging modality
     - scores: random quality metrics
     - results: similar images (similarity + concepts + image_url)
     """
     content = await file.read()
     image = Image.open(io.BytesIO(content)).convert("RGB")
 
-    # Retrieval
-    results_df = search_similar_by_image(image, k=int(k))
+    k = int(k)
+
+    # 1) Retrieval
+    results_df = search_similar_by_image(image, k=k)
     results = results_df.to_dict(orient="records")
 
-    # Caption + modality
+    # similar captions context (take up to 5)
+    similar_caps_list = results_df["caption"].astype(str).tolist()
+    similar_caps_short = "; ".join(similar_caps_list[:5])
+
+    # 2) Draft caption from BLIP
     try:
-        query_caption = generate_query_caption(image)
+        draft_caption = generate_draft_caption(image)
     except Exception as e:
-        print("Error generating caption:", e)
-        query_caption = None
+        print("Error generating draft caption:", e)
+        draft_caption = ""
 
-    modality = detect_modality(query_caption or "")
+    # 3) Refine caption with similar case captions
+    try:
+        final_caption = refine_caption_with_similar_cases(
+            draft_caption, similar_caps_short
+        )
+    except Exception as e:
+        print("Error refining caption:", e)
+        final_caption = draft_caption or None
 
-    # Random scores
+    # 4) Modality & scores
+    modality = detect_modality(final_caption or "")
     scores = generate_random_scores()
 
     return JSONResponse(
         {
-            "query_caption": query_caption,
+            "query_caption": final_caption,
             "modality": modality,
             "scores": scores,
             "results": results,
 