saad003 committed
Commit 4f357eb · verified · 1 Parent(s): 05a8813

Update app.py

Files changed (1):
  1. app.py +66 -446

app.py CHANGED
@@ -18,8 +18,8 @@ from huggingface_hub import hf_hub_download
 from transformers import (
     CLIPProcessor,
     CLIPModel,
-    Blip2Processor,
-    Blip2ForConditionalGeneration,
+    BlipForConditionalGeneration,
+    AutoProcessor,
 )

 # ---------- FastAPI app ----------
@@ -38,7 +38,7 @@ app.add_middleware(
 # Dataset with FAISS index + radiology_metadata.csv
 EMBED_REPO_ID = "saad003/Red01"

-# Dataset with all radiology images (new structure with train01–train07)
+# Dataset with all radiology images (test, valid, train01–train07)
 IMAGE_REPO_ID = "saad003/images04"
 BASE_IMAGE_URL = f"https://huggingface.co/datasets/{IMAGE_REPO_ID}/resolve/main"

@@ -48,6 +48,9 @@ HF_TOKEN = os.environ.get("HF_TOKEN")
 device = "cuda" if torch.cuda.is_available() else "cpu"
 print("Using device:", device)

+# use fp16 on GPU to speed up BLIP, fp32 on CPU
+caption_dtype = torch.float16 if device == "cuda" else torch.float32
+
 # ---------- Download index + metadata ----------
 print("Downloading FAISS index & metadata from Hugging Face...")

@@ -81,15 +84,12 @@ clip_model = CLIPModel.from_pretrained(CLIP_MODEL_NAME).to(device)
 clip_processor = CLIPProcessor.from_pretrained(CLIP_MODEL_NAME)
 clip_model.eval()

-# ---------- Load BLIP-2 (captioning) ----------
-print("Loading BLIP-2 model for medical captioning...")
-CAPTION_MODEL_ID = "Salesforce/blip2-opt-2.7b"
-
-# Use fp16 on GPU, fp32 on CPU
-caption_dtype = torch.float16 if device == "cuda" else torch.float32
+# ---------- Load BLIP (radiology captioning) ----------
+print("Loading BLIP ROCO radiology captioning model...")
+CAPTION_MODEL_ID = "WafaaFraih/blip-roco-radiology-captioning"

-caption_processor = Blip2Processor.from_pretrained(CAPTION_MODEL_ID)
-caption_model = Blip2ForConditionalGeneration.from_pretrained(
+caption_processor = AutoProcessor.from_pretrained(CAPTION_MODEL_ID)
+caption_model = BlipForConditionalGeneration.from_pretrained(
     CAPTION_MODEL_ID,
     torch_dtype=caption_dtype,
 ).to(device)
@@ -116,14 +116,12 @@ def id_to_image_url(image_id: str) -> str:
     elif "_valid_" in image_id:
         folder = "valid"
     elif "_train_" in image_id:
-        # last part: ROCOv2_2023_train_054005 -> "054005"
         num_str = image_id.split("_")[-1]
         try:
             n = int(num_str)
         except ValueError:
             n = 0

-        # Rough ranges based on your description
         if 1 <= n <= 9000:
             folder = "train01"
         elif 9001 <= n <= 18000:
@@ -214,7 +212,6 @@ def detect_modality(caption: str) -> str:
             if kw in text:
                 return modality

-    # Back-up heuristics
     if "mra" in text:
         return "MRI"
     if "cta " in text or "ct angiography" in text:
@@ -243,491 +240,115 @@ def generate_random_scores() -> Dict[str, float]:
     }


-# ---------- Helper: search by image ----------
+# ---------- Helper: FAISS search ----------

 def search_similar_by_image(image: Image.Image, k: int = 5) -> pd.DataFrame:
     """
     Encode query image with CLIP, search FAISS,
-    filter out self-match (score ~ 1.0), and return top-k results.
+    filter out self-match, and return top-k results.
     """
-    # Encode image
     inputs = clip_processor(images=image, return_tensors="pt").to(device)
     with torch.no_grad():
         feats = clip_model.get_image_features(**inputs)

-    # Normalize (same as you did when building the index)
     feats = feats / feats.norm(p=2, dim=-1, keepdim=True)
     feats = feats.cpu().numpy().astype("float32")

-    # Search a bit more than k so we can drop self-match
     search_k = min(index.ntotal, k + 5)
     D, I = index.search(feats, search_k)

     rows = metadata.iloc[I[0]].copy()
     rows["score"] = D[0]

-    # Remove potential self-match (exact same image → cosine ~ 1.0)
+    # drop exact self-match
     rows = rows[rows["score"] < 0.999].copy()

-    # Add image_url
     rows["image_url"] = rows["ID"].apply(id_to_image_url)

-    # Keep only needed columns and top-k by score
     rows = rows.sort_values("score", ascending=False).head(k)
-
-    # If concepts_manual is missing, fill with empty string
     if "concepts_manual" not in rows.columns:
         rows["concepts_manual"] = ""

     return rows[["ID", "caption", "concepts_manual", "score", "image_url"]]


-# ---------- Helper: caption with BLIP-2 ----------
+# ---------- Helper: caption cleaning & generation ----------

 def clean_caption(text: str) -> str:
-    """Basic cleanup to remove obvious repetition artifacts."""
+    """
+    Clean BLIP captions:
+    - strip
+    - split into clauses and remove duplicates
+    - normalize spacing and punctuation
+    """
+    if not text:
+        return ""
+
     text = text.strip()

-    # Deduplicate immediate repeated phrases separated by commas
-    parts = [p.strip() for p in text.split(",")]
-    dedup = []
+    # break into clauses
+    parts = re.split(r"[,.]", text)
+    parts = [p.strip() for p in parts if p.strip()]
+
+    seen = set()
+    unique_parts = []
     for p in parts:
-        if not dedup or p.lower() != dedup[-1].lower():
-            dedup.append(p)
-    text = ", ".join(dedup)
+        key = p.lower()
+        if key not in seen:
+            seen.add(key)
+            unique_parts.append(p)

-    # Remove repeated 'respectively'
-    text = re.sub(r"(respectively,?\s+)+", "respectively ", text, flags=re.IGNORECASE)
+    if not unique_parts:
+        cleaned = text
+    else:
+        cleaned = ", ".join(unique_parts)

-    # Remove exact doubled sentence patterns like "..., and a large ... and a large ..."
-    text = re.sub(r"\b(\w+(?:\s+\w+){2,})\s+\1\b", r"\1", text, flags=re.IGNORECASE)
+    # remove repeated 'respectively'
+    cleaned = re.sub(
+        r"(respectively,?\s+)+", "respectively ", cleaned, flags=re.IGNORECASE
+    )

-    # Normalize whitespace
-    text = " ".join(text.split())
-    return text
+    cleaned = " ".join(cleaned.split())
+    if cleaned and not cleaned.endswith("."):
+        cleaned += "."
+    cleaned = cleaned[0].upper() + cleaned[1:] if cleaned else cleaned
+    return cleaned


 def generate_query_caption(image: Image.Image) -> str:
     """
-    Generate a radiology-focused caption using BLIP-2.
+    Generate a radiology caption using BLIP (ROCO).
+    Tuned decoding to reduce repetition and keep it concise.
     """
-    prompt = (
-        "You are an expert radiologist. "
-        "Describe the key radiology findings in one concise sentence. "
-        "Avoid repeating phrases."
-    )
-
-    inputs = caption_processor(
-        images=image,
-        text=prompt,
-        return_tensors="pt",
-    ).to(device, dtype=caption_dtype)
+    inputs = caption_processor(images=image, return_tensors="pt").to(
+        device, dtype=caption_dtype
+    )

     with torch.no_grad():
-        generated_ids = caption_model.generate(
+        out_ids = caption_model.generate(
             **inputs,
-            max_new_tokens=64,
-            num_beams=4,
-            no_repeat_ngram_size=3,
-            repetition_penalty=1.1,
+            max_new_tokens=40,
+            num_beams=5,
+            no_repeat_ngram_size=4,
+            repetition_penalty=1.4,
+            length_penalty=0.9,
+            early_stopping=True,
         )

-    caption = caption_processor.batch_decode(
-        generated_ids, skip_special_tokens=True
+    raw_caption = caption_processor.batch_decode(
+        out_ids, skip_special_tokens=True
     )[0]
-    return clean_caption(caption)
+    return clean_caption(raw_caption)


 # ---------- Routes ----------

 @app.get("/")
 def root():
-    return {"status": "ok", "message": "Radiology retrieval + BLIP-2 captioning API"}
+    return {
+        "status": "ok",
+        "message": "Radiology retrieval + BLIP radiology captioning API",
+    }


-@app.post("/search_by_image")
-async def search_by_image(file: UploadFile = File(...), k: int = 5):
-    """
-    Upload a radiology image.
-    Returns:
-    - query_caption: BLIP-2 caption for the query image
-    - modality: detected imaging modality from caption
-    - scores: random quality metrics in given ranges
-    - results: list of similar images with similarity + concepts + image_url
-    """
-    # Read uploaded file
-    content = await file.read()
-    image = Image.open(io.BytesIO(content)).convert("RGB")
-
-    # Retrieval
-    results_df = search_similar_by_image(image, k=int(k))
-    results = results_df.to_dict(orient="records")
-
-    # Caption + modality
-    try:
-        query_caption = generate_query_caption(image)
-    except Exception as e:
-        print("Error generating caption with BLIP-2:", e)
-        query_caption = None
-
-    modality = detect_modality(query_caption or "")
-
-    # Random scores
-    scores = generate_random_scores()
-
-    return JSONResponse(
-        {
-            "query_caption": query_caption,
-            "modality": modality,
-            "scores": scores,
-            "results": results,
-        }
-    )
-# app.py
-import io
-import os
-import random
-import re
-from typing import Dict
-
-import faiss
-import torch
-import pandas as pd
-
-from PIL import Image
-from fastapi import FastAPI, File, UploadFile
-from fastapi.middleware.cors import CORSMiddleware
-from fastapi.responses import JSONResponse
-
-from huggingface_hub import hf_hub_download
-from transformers import (
-    CLIPProcessor,
-    CLIPModel,
-    Blip2Processor,
-    Blip2ForConditionalGeneration,
-)
-
-# ---------- FastAPI app ----------
-app = FastAPI()
-
-app.add_middleware(
-    CORSMiddleware,
-    allow_origins=["*"],  # later restrict to your frontend domain
-    allow_credentials=True,
-    allow_methods=["*"],
-    allow_headers=["*"],
-)
-
-# ---------- Config ----------
-
-# Dataset with FAISS index + radiology_metadata.csv
-EMBED_REPO_ID = "saad003/Red01"
-
-# Dataset with all radiology images (new structure with train01–train07)
-IMAGE_REPO_ID = "saad003/images04"
-BASE_IMAGE_URL = f"https://huggingface.co/datasets/{IMAGE_REPO_ID}/resolve/main"
-
-# Optional: token if Red01 is private
-HF_TOKEN = os.environ.get("HF_TOKEN")
-
-device = "cuda" if torch.cuda.is_available() else "cpu"
-print("Using device:", device)
-
-# ---------- Download index + metadata ----------
-print("Downloading FAISS index & metadata from Hugging Face...")
-
-INDEX_PATH = hf_hub_download(
-    repo_id=EMBED_REPO_ID,
-    filename="radiology_index.faiss",
-    repo_type="dataset",
-    token=HF_TOKEN,
-)
-
-META_PATH = hf_hub_download(
-    repo_id=EMBED_REPO_ID,
-    filename="radiology_metadata.csv",
-    repo_type="dataset",
-    token=HF_TOKEN,
-)
-
-print("Loading FAISS index...")
-index = faiss.read_index(INDEX_PATH)
-
-print("Loading metadata CSV...")
-metadata = pd.read_csv(META_PATH)
-
-assert index.ntotal == len(metadata), "Index size and metadata rows mismatch!"
-
-# ---------- Load CLIP (retrieval) ----------
-print("Loading PubMedCLIP model for retrieval...")
-CLIP_MODEL_NAME = "flaviagiammarino/pubmed-clip-vit-base-patch32"
-
-clip_model = CLIPModel.from_pretrained(CLIP_MODEL_NAME).to(device)
-clip_processor = CLIPProcessor.from_pretrained(CLIP_MODEL_NAME)
-clip_model.eval()
-
-# ---------- Load BLIP-2 (captioning) ----------
-print("Loading BLIP-2 model for medical captioning...")
-CAPTION_MODEL_ID = "Salesforce/blip2-opt-2.7b"
-
-# Use fp16 on GPU, fp32 on CPU
-caption_dtype = torch.float16 if device == "cuda" else torch.float32
-
-caption_processor = Blip2Processor.from_pretrained(CAPTION_MODEL_ID)
-caption_model = Blip2ForConditionalGeneration.from_pretrained(
-    CAPTION_MODEL_ID,
-    torch_dtype=caption_dtype,
-).to(device)
-caption_model.eval()
-
-print("Backend ready ✅")
-
-
-# ---------- Helper: image path mapping ----------
-
-def id_to_image_url(image_id: str) -> str:
-    """
-    Map ROCO image IDs to folders in saad003/images04.
-
-    test -> test/
-    valid -> valid/
-    train -> train01 ... train07 based on numeric ID
-    """
-    image_id = image_id.strip()
-    base = BASE_IMAGE_URL
-
-    if "_test_" in image_id:
-        folder = "test"
-    elif "_valid_" in image_id:
-        folder = "valid"
-    elif "_train_" in image_id:
-        # last part: ROCOv2_2023_train_054005 -> "054005"
-        num_str = image_id.split("_")[-1]
-        try:
-            n = int(num_str)
-        except ValueError:
-            n = 0
-
-        # Rough ranges based on your description
-        if 1 <= n <= 9000:
-            folder = "train01"
-        elif 9001 <= n <= 18000:
-            folder = "train02"
-        elif 18001 <= n <= 27000:
-            folder = "train03"
-        elif 27001 <= n <= 36000:
-            folder = "train04"
-        elif 36001 <= n <= 45000:
-            folder = "train05"
-        elif 45001 <= n <= 54000:
-            folder = "train06"
-        else:
-            folder = "train07"
-    else:
-        folder = ""
-
-    if folder:
-        return f"{base}/{folder}/{image_id}.jpg"
-    else:
-        # fallback – should not happen, but safe
-        return f"{base}/{image_id}.jpg"
-
-
-# ---------- Helper: modality detection ----------
-
-MODALITY_KEYWORDS = {
-    "CT": [
-        "ct ",
-        "ctscan",
-        "computed tomography",
-        "tomography",
-        "ct scan",
-        "non-contrast ct",
-        "contrast-enhanced ct",
-    ],
-    "MRI": [
-        "mri ",
-        "magnetic resonance",
-        "t1-weighted",
-        "t2-weighted",
-        "flair sequence",
-        "diffusion-weighted",
-        "dwi",
-    ],
-    "X-ray": [
-        "x-ray",
-        "x ray",
-        "radiograph",
-        "plain film",
-        "chest film",
-        "postoperative x",
-        "post-operative x",
-        "cxr",
-    ],
-    "Ultrasound": [
-        "ultrasound",
-        "sonogram",
-        "sonography",
-        "usg",
-        "doppler",
-        "echocardiogram",
-        "echocardiography",
-    ],
-    "PET/CT": [
-        "pet-ct",
-        "pet/ct",
-        "pet scan",
-        "positron emission tomography",
-    ],
-    "Fluoroscopy": [
-        "fluoroscopy",
-        "fluoroscopic",
-        "angiogram",
-        "angiography",
-        "barium swallow",
-        "barium enema",
-    ],
-}
-
-def detect_modality(caption: str) -> str:
-    if not caption:
-        return "Unknown"
-    text = caption.lower()
-
-    for modality, keywords in MODALITY_KEYWORDS.items():
-        for kw in keywords:
-            if kw in text:
-                return modality
-
-    # Back-up heuristics
-    if "mra" in text:
-        return "MRI"
-    if "cta " in text or "ct angiography" in text:
-        return "CT"
-    return "Unknown"
-
-
-# ---------- Helper: random scoring ----------
-
-def generate_random_scores() -> Dict[str, float]:
-    """
-    Return random scores in the ranges you specified.
-    """
-    rng = random.Random()
-
-    modality_score = rng.uniform(85.0, 93.0)  # percent
-    cui_at_k = rng.uniform(0.30, 0.61)
-    bert = rng.uniform(0.20, 0.40)
-    medbert = rng.uniform(0.20, 0.35)
-
-    return {
-        "modality_score": round(modality_score, 1),
-        "cui_at_k": round(cui_at_k, 3),
-        "bertscore": round(bert, 3),
-        "medbertscore": round(medbert, 3),
-    }
-
-
-# ---------- Helper: search by image ----------
-
-def search_similar_by_image(image: Image.Image, k: int = 5) -> pd.DataFrame:
-    """
-    Encode query image with CLIP, search FAISS,
-    filter out self-match (score ~ 1.0), and return top-k results.
-    """
-    # Encode image
-    inputs = clip_processor(images=image, return_tensors="pt").to(device)
-    with torch.no_grad():
-        feats = clip_model.get_image_features(**inputs)
-
-    # Normalize (same as you did when building the index)
-    feats = feats / feats.norm(p=2, dim=-1, keepdim=True)
-    feats = feats.cpu().numpy().astype("float32")
-
-    # Search a bit more than k so we can drop self-match
-    search_k = min(index.ntotal, k + 5)
-    D, I = index.search(feats, search_k)
-
-    rows = metadata.iloc[I[0]].copy()
-    rows["score"] = D[0]
-
-    # Remove potential self-match (exact same image → cosine ~ 1.0)
-    rows = rows[rows["score"] < 0.999].copy()
-
-    # Add image_url
-    rows["image_url"] = rows["ID"].apply(id_to_image_url)
-
-    # Keep only needed columns and top-k by score
-    rows = rows.sort_values("score", ascending=False).head(k)
-
-    # If concepts_manual is missing, fill with empty string
-    if "concepts_manual" not in rows.columns:
-        rows["concepts_manual"] = ""
-
-    return rows[["ID", "caption", "concepts_manual", "score", "image_url"]]
-
-
-# ---------- Helper: caption with BLIP-2 ----------
-
-def clean_caption(text: str) -> str:
-    """Basic cleanup to remove obvious repetition artifacts."""
-    text = text.strip()
-
-    # Deduplicate immediate repeated phrases separated by commas
-    parts = [p.strip() for p in text.split(",")]
-    dedup = []
-    for p in parts:
-        if not dedup or p.lower() != dedup[-1].lower():
-            dedup.append(p)
-    text = ", ".join(dedup)
-
-    # Remove repeated 'respectively'
-    text = re.sub(r"(respectively,?\s+)+", "respectively ", text, flags=re.IGNORECASE)
-
-    # Remove exact doubled sentence patterns like "..., and a large ... and a large ..."
-    text = re.sub(r"\b(\w+(?:\s+\w+){2,})\s+\1\b", r"\1", text, flags=re.IGNORECASE)
-
-    # Normalize whitespace
-    text = " ".join(text.split())
-    return text
-
-
-def generate_query_caption(image: Image.Image) -> str:
-    """
-    Generate a radiology-focused caption using BLIP-2.
-    """
-    prompt = (
-        "You are an expert radiologist. "
-        "Describe the key radiology findings in one concise sentence. "
-        "Avoid repeating phrases."
-    )
-
-    inputs = caption_processor(
-        images=image,
-        text=prompt,
-        return_tensors="pt",
-    ).to(device, dtype=caption_dtype)
-
-    with torch.no_grad():
-        generated_ids = caption_model.generate(
-            **inputs,
-            max_new_tokens=64,
-            num_beams=4,
-            no_repeat_ngram_size=3,
-            repetition_penalty=1.1,
-        )
-
-    caption = caption_processor.batch_decode(
-        generated_ids, skip_special_tokens=True
-    )[0]
-    return clean_caption(caption)
-
-
-# ---------- Routes ----------
-
-@app.get("/")
-def root():
-    return {"status": "ok", "message": "Radiology retrieval + BLIP-2 captioning API"}
-
-
 @app.post("/search_by_image")
@@ -735,12 +356,11 @@ async def search_by_image(file: UploadFile = File(...), k: int = 5):
     """
     Upload a radiology image.
     Returns:
-    - query_caption: BLIP-2 caption for the query image
+    - query_caption: BLIP caption for the query image
     - modality: detected imaging modality from caption
-    - scores: random quality metrics in given ranges
-    - results: list of similar images with similarity + concepts + image_url
+    - scores: random quality metrics
+    - results: similar images (similarity + concepts + image_url)
     """
-    # Read uploaded file
     content = await file.read()
     image = Image.open(io.BytesIO(content)).convert("RGB")

@@ -752,7 +372,7 @@ async def search_by_image(file: UploadFile = File(...), k: int = 5):
     try:
         query_caption = generate_query_caption(image)
     except Exception as e:
-        print("Error generating caption with BLIP-2:", e)
+        print("Error generating caption:", e)
         query_caption = None

     modality = detect_modality(query_caption or "")
 
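For reviewers who want to sanity-check the new captioning path on its own, here is a minimal, hypothetical sketch (not part of the commit). It assumes the checkpoint above downloads cleanly and that a local file named example.jpg stands in for a query image; the decoding settings mirror generate_query_caption in the new app.py.

import torch
from PIL import Image
from transformers import AutoProcessor, BlipForConditionalGeneration

MODEL_ID = "WafaaFraih/blip-roco-radiology-captioning"  # checkpoint used by the commit

device = "cuda" if torch.cuda.is_available() else "cpu"
dtype = torch.float16 if device == "cuda" else torch.float32  # fp16 on GPU, fp32 on CPU

processor = AutoProcessor.from_pretrained(MODEL_ID)
model = BlipForConditionalGeneration.from_pretrained(MODEL_ID, torch_dtype=dtype).to(device)
model.eval()

# example.jpg is a placeholder for any radiology image on disk
image = Image.open("example.jpg").convert("RGB")
inputs = processor(images=image, return_tensors="pt").to(device, dtype=dtype)

# same decoding settings as generate_query_caption above
with torch.no_grad():
    out_ids = model.generate(
        **inputs,
        max_new_tokens=40,
        num_beams=5,
        no_repeat_ngram_size=4,
        repetition_penalty=1.4,
        length_penalty=0.9,
        early_stopping=True,
    )

print(processor.batch_decode(out_ids, skip_special_tokens=True)[0])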
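Likewise, a hypothetical client call against a locally running instance of the app (the base URL, port, and query.jpg are placeholders; the response keys match the JSON built in search_by_image above):

import requests

# assumes the app is served locally, e.g. `uvicorn app:app --port 8000`
url = "http://localhost:8000/search_by_image"

with open("query.jpg", "rb") as f:
    resp = requests.post(url, files={"file": f}, params={"k": 5})
resp.raise_for_status()

data = resp.json()
print("caption: ", data["query_caption"])  # BLIP caption, or None if generation failed
print("modality:", data["modality"])       # keyword-based modality guess
for hit in data["results"]:                # top-k FAISS neighbours
    print(round(hit["score"], 3), hit["ID"], hit["image_url"])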