saad003 committed on
Commit 5a73ed5 · verified · 1 Parent(s): 5a6d023

Update app.py

Files changed (1)
  1. app.py +64 -102
app.py CHANGED
@@ -18,8 +18,6 @@ from huggingface_hub import hf_hub_download
 from transformers import (
     CLIPProcessor,
     CLIPModel,
-    BlipForConditionalGeneration,
-    AutoProcessor,
     AutoTokenizer,
     AutoModelForSeq2SeqLM,
 )
@@ -46,8 +44,6 @@ HF_TOKEN = os.environ.get("HF_TOKEN")
 device = "cuda" if torch.cuda.is_available() else "cpu"
 print("Using device:", device)
 
-cap_dtype = torch.float16 if device == "cuda" else torch.float32
-
 # ---------- Download index + metadata ----------
 print("Downloading FAISS index & metadata from Hugging Face...")
 
@@ -80,19 +76,8 @@ clip_model = CLIPModel.from_pretrained(CLIP_MODEL_NAME).to(device)
 clip_processor = CLIPProcessor.from_pretrained(CLIP_MODEL_NAME)
 clip_model.eval()
 
-# ---------- Load BLIP (image -> draft caption) ----------
-print("Loading BLIP ROCO radiology captioning model...")
-CAPTION_MODEL_ID = "WafaaFraih/blip-roco-radiology-captioning"
-
-caption_processor = AutoProcessor.from_pretrained(CAPTION_MODEL_ID)
-caption_model = BlipForConditionalGeneration.from_pretrained(
-    CAPTION_MODEL_ID,
-    torch_dtype=cap_dtype,
-).to(device)
-caption_model.eval()
-
-# ---------- Load FLAN-T5 (text refinement using similar captions) ----------
-print("Loading FLAN-T5 for caption refinement...")
+# ---------- Load FLAN-T5 (caption synthesis) ----------
+print("Loading FLAN-T5 for diagnosis synthesis from similar captions...")
 REFINER_MODEL_ID = "google/flan-t5-base"
 
 refiner_tokenizer = AutoTokenizer.from_pretrained(REFINER_MODEL_ID)
@@ -274,93 +259,78 @@ def search_similar_by_image(image: Image.Image, k: int = 5) -> pd.DataFrame:
     return rows[["ID", "caption", "concepts_manual", "score", "image_url"]]
 
 
-# ---------- Caption cleaning & generation ----------
+# ---------- Helper: caption cleaning & synthesis ----------
 
 def clean_caption(text: str) -> str:
     """
-    Clean captions:
+    Clean generated caption:
     - strip
-    - split into clauses and remove duplicates
-    - normalize spacing and punctuation
+    - remove obvious prompt leftovers
+    - ensure single sentence, nice punctuation
     """
     if not text:
         return ""
 
     text = text.strip()
 
-    parts = re.split(r"[,.]", text)
-    parts = [p.strip() for p in parts if p.strip()]
+    # Drop any leading instruction-like fragments
+    text = re.sub(
+        r"^(you are an expert radiologist[:,]?\s*)",
+        "",
+        text,
+        flags=re.IGNORECASE,
+    )
+    text = re.sub(
+        r"(findings? from similar radiology cases[:,]?\s*)",
+        "",
+        text,
+        flags=re.IGNORECASE,
+    )
 
-    seen = set()
-    unique_parts = []
-    for p in parts:
-        key = p.lower()
-        if key not in seen:
-            seen.add(key)
-            unique_parts.append(p)
+    # Replace multiple separators
+    text = text.replace(" ;", ";")
+    text = re.sub(r"\s+[,;]\s*", ", ", text)
 
-    if unique_parts:
-        cleaned = ", ".join(unique_parts)
-    else:
-        cleaned = text
+    # Collapse spaces
+    text = " ".join(text.split())
 
-    cleaned = re.sub(
-        r"(respectively,?\s+)+", "respectively ", cleaned, flags=re.IGNORECASE
-    )
+    # If there are multiple sentences, keep only the first one
+    parts = re.split(r"(?<=[.!?])\s+", text)
+    if parts:
+        text = parts[0]
 
-    cleaned = " ".join(cleaned.split())
-    if cleaned and not cleaned.endswith("."):
-        cleaned += "."
-    if cleaned:
-        cleaned = cleaned[0].upper() + cleaned[1:]
-    return cleaned
+    # Ensure period
+    if text and not text.endswith((".", "!", "?")):
+        text += "."
 
+    # Capitalize first letter
+    if text:
+        text = text[0].upper() + text[1:]
 
-def generate_draft_caption(image: Image.Image) -> str:
-    """
-    Draft caption directly from image using BLIP.
-    """
-    inputs = caption_processor(images=image, return_tensors="pt").to(
-        device, dtype=cap_dtype
-    )
-    with torch.no_grad():
-        out_ids = caption_model.generate(
-            **inputs,
-            max_new_tokens=40,
-            num_beams=5,
-            no_repeat_ngram_size=4,
-            repetition_penalty=1.4,
-            length_penalty=0.9,
-            early_stopping=True,
-        )
-    raw = caption_processor.batch_decode(out_ids, skip_special_tokens=True)[0]
-    return clean_caption(raw)
+    return text
 
 
-def refine_caption_with_similar_cases(
-    draft_caption: str,
-    similar_captions: str,
-) -> str:
+def synthesize_caption_from_similar_captions(captions: list[str]) -> str:
     """
-    Use FLAN-T5 to rewrite a final diagnosis sentence based on:
-    - draft caption from BLIP (current image)
-    - captions from similar images
+    Use FLAN-T5 to create a diagnosis sentence from captions of similar images.
     """
-    if not draft_caption:
-        draft_caption = "No draft description available."
-    if not similar_captions:
-        # nothing to refine with; just return draft
-        return draft_caption
+    captions = [c.strip() for c in captions if c and isinstance(c, str)]
+    if not captions:
+        return ""
+
+    # Use at most 5-6 captions to keep prompt short
+    caps = captions[:6]
+
+    numbered = "\n".join(
+        f"{i+1}) {c}" for i, c in enumerate(caps)
    )
 
     prompt = (
-        "You are an expert radiologist.\n\n"
-        "Draft findings from the current image:\n"
-        f"{draft_caption}\n\n"
-        "Findings from similar radiology cases:\n"
-        f"{similar_captions}\n\n"
-        "Based on all of this, write ONE concise radiology impression "
-        "sentence describing the most probable diagnosis and key findings "
-        "for the current image. Do not mention 'similar cases' or 'draft'."
+        "Radiology findings from similar cases:\n"
+        f"{numbered}\n\n"
+        "Based on these, write ONE concise radiology impression sentence "
+        "describing the most likely diagnosis and key findings for the "
+        "current image. Do not mention numbers or 'similar cases'."
     )
 
     inputs = refiner_tokenizer(
@@ -373,13 +343,14 @@ def refine_caption_with_similar_cases(
     with torch.no_grad():
         out_ids = refiner_model.generate(
             **inputs,
-            max_new_tokens=64,
+            max_new_tokens=48,
             num_beams=4,
             length_penalty=0.9,
             no_repeat_ngram_size=4,
         )
-    refined = refiner_tokenizer.decode(out_ids[0], skip_special_tokens=True)
-    return clean_caption(refined)
+
+    raw = refiner_tokenizer.decode(out_ids[0], skip_special_tokens=True)
+    return clean_caption(raw)
 
 
 # ---------- Routes ----------
@@ -388,7 +359,7 @@ def refine_caption_with_similar_cases(
 def root():
     return {
         "status": "ok",
-        "message": "Radiology retrieval + BLIP + FLAN-T5 refinement API",
+        "message": "Radiology retrieval + FLAN-T5 synthesis from similar captions",
     }
 
 
@@ -397,7 +368,7 @@ async def search_by_image(file: UploadFile = File(...), k: int = 5):
     """
     Upload a radiology image.
    Returns:
-    - query_caption: refined caption using draft + similar cases
+    - query_caption: synthesized diagnosis from captions of similar images
    - modality: detected imaging modality
    - scores: random quality metrics
    - results: similar images (similarity + concepts + image_url)
@@ -411,27 +382,18 @@ async def search_by_image(file: UploadFile = File(...), k: int = 5):
     results_df = search_similar_by_image(image, k=k)
     results = results_df.to_dict(orient="records")
 
-    # similar captions context (take up to 5)
+    # 2) Synthesize caption only from similar image captions
     similar_caps_list = results_df["caption"].astype(str).tolist()
-    similar_caps_short = "; ".join(similar_caps_list[:5])
-
-    # 2) Draft caption from BLIP
-    try:
-        draft_caption = generate_draft_caption(image)
-    except Exception as e:
-        print("Error generating draft caption:", e)
-        draft_caption = ""
 
-    # 3) Refine caption with similar case captions
     try:
-        final_caption = refine_caption_with_similar_cases(
-            draft_caption, similar_caps_short
+        final_caption = synthesize_caption_from_similar_captions(
+            similar_caps_list
        )
     except Exception as e:
-        print("Error refining caption:", e)
-        final_caption = draft_caption or None
+        print("Error synthesizing caption:", e)
+        final_caption = ""
 
-    # 4) Modality & scores
+    # 3) Modality & scores
     modality = detect_modality(final_caption or "")
     scores = generate_random_scores()
 
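
For reference, a minimal client-side sketch of how the updated search_by_image handler can be exercised after this change. The route path, host, and port are assumptions (the FastAPI decorator is outside this diff); the multipart field name "file" and query parameter "k" follow the handler signature async def search_by_image(file: UploadFile = File(...), k: int = 5), and the response keys follow the docstring above, assuming the handler returns them as documented.

    # Hypothetical client call; endpoint path and port are assumptions, not part of this diff.
    import requests

    API_URL = "http://localhost:7860/search-by-image"  # assumed route and port

    with open("example_radiograph.png", "rb") as f:
        resp = requests.post(API_URL, files={"file": f}, params={"k": 5})
    resp.raise_for_status()

    payload = resp.json()
    print(payload["query_caption"])  # diagnosis synthesized from similar-case captions
    print(payload["modality"])       # detected imaging modality
    print(payload["scores"])         # random quality metrics (per the docstring)
    for hit in payload["results"]:   # similar images: similarity + concepts + image_url
        print(hit["ID"], hit["score"], hit["image_url"])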