bombshelll committed on
Commit
f2ba684
·
1 Parent(s): 227593e

Add hierarchical classification and captioning app

Browse files
Files changed (1) hide show
  1. app.py +53 -24
app.py CHANGED
@@ -2,10 +2,11 @@ import gradio as gr
2
  from PIL import Image
3
  import torch
4
  from transformers import VisionEncoderDecoderModel, AutoTokenizer, ViTFeatureExtractor, AutoImageProcessor, AutoModelForImageClassification
 
5
 
6
  device = "cuda" if torch.cuda.is_available() else "cpu"
7
 
8
- # Load image captioning model
9
  caption_model = VisionEncoderDecoderModel.from_pretrained("bombshelll/ViT_BioMedBert_Captioning_ROCO").to(device)
10
  tokenizer = AutoTokenizer.from_pretrained("bombshelll/ViT_BioMedBert_Captioning_ROCO")
11
  feature_extractor = ViTFeatureExtractor.from_pretrained("bombshelll/ViT_BioMedBert_Captioning_ROCO")
@@ -13,7 +14,7 @@ feature_extractor = ViTFeatureExtractor.from_pretrained("bombshelll/ViT_BioMedBe
13
  # Load classification models
14
  def load_classifier(model_id):
15
  processor = AutoImageProcessor.from_pretrained(model_id)
16
- model = AutoModelForImageClassification.from_pretrained(model_id)
17
  return processor, model
18
 
19
  classifiers = {
@@ -23,7 +24,7 @@ classifiers = {
23
  "tumor_type": load_classifier("bombshelll/swin-brain-tumor-type-classification")
24
  }
25
 
26
- # Inference functions
27
  def classify_image(image):
28
  results = {}
29
  for name, (processor, model) in classifiers.items():
@@ -35,22 +36,23 @@ def classify_image(image):
35
  results[name] = label
36
  return results
37
 
 
38
  def generate_captions(image, keywords):
39
  pixel_values = feature_extractor(images=image, return_tensors="pt").pixel_values.to(device)
40
 
41
- # Without keywords
42
  caption_model.eval()
43
  with torch.no_grad():
44
  output_ids = caption_model.generate(pixel_values, max_length=80)
45
  caption1 = tokenizer.decode(output_ids[0], skip_special_tokens=True)
46
 
47
- # With keywords
48
  prompt = " ".join(keywords)
49
  prompt_ids = tokenizer(prompt, return_tensors="pt").input_ids.to(device)
50
  with torch.no_grad():
51
  output_ids = caption_model.generate(
52
  pixel_values,
53
- decoder_input_ids=prompt_ids,
54
  max_length=80,
55
  num_beams=4,
56
  no_repeat_ngram_size=3,
@@ -60,24 +62,51 @@ def generate_captions(image, keywords):
60
 
61
  return caption1, caption2
62
 
63
- # Main app logic
64
- def run_pipeline(image):
65
  classification = classify_image(image)
66
  keywords = list(classification.values())
67
  caption1, caption2 = generate_captions(image, keywords)
68
- return classification, caption1, caption2
69
-
70
- # Gradio Interface
71
- interface = gr.Interface(
72
- fn=run_pipeline,
73
- inputs=gr.Image(type="pil"),
74
- outputs=[
75
- gr.JSON(label="Classification Result"),
76
- gr.Textbox(label="Caption without Keywords"),
77
- gr.Textbox(label="Caption with Keywords")
78
- ],
79
- title="๐Ÿง  Brain Hierarchical Classification + Captioning",
80
- description="Upload an MRI/CT brain image. The system will classify (plane, modality, abnormality, tumor) and generate two captions: one plain and one guided by the classification keywords."
81
- )
82
-
83
- interface.launch()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
2
  from PIL import Image
3
  import torch
4
  from transformers import VisionEncoderDecoderModel, AutoTokenizer, ViTFeatureExtractor, AutoImageProcessor, AutoModelForImageClassification
5
+ from nltk.translate.bleu_score import sentence_bleu, SmoothingFunction
6
 
7
# Run all models on GPU when available, otherwise fall back to CPU.
device = "cuda" if torch.cuda.is_available() else "cpu"
8
 
9
# Load the image-captioning model (ViT encoder + BioMedBERT decoder).
# The checkpoint id is used for model, tokenizer, and feature extractor,
# so hoist it into one constant instead of repeating the string.
CAPTION_MODEL_ID = "bombshelll/ViT_BioMedBert_Captioning_ROCO"

# .eval() up front: the model is only ever used for inference in this app.
caption_model = VisionEncoderDecoderModel.from_pretrained(CAPTION_MODEL_ID).to(device).eval()
tokenizer = AutoTokenizer.from_pretrained(CAPTION_MODEL_ID)
feature_extractor = ViTFeatureExtractor.from_pretrained(CAPTION_MODEL_ID)
 
14
# Load classification models
def load_classifier(model_id):
    """Load an image-classification checkpoint and its preprocessor.

    Args:
        model_id: Hugging Face hub id of the classifier checkpoint.

    Returns:
        Tuple of (processor, model). The model is moved to the global
        ``device`` and switched to eval mode, since these classifiers are
        only used for inference in this app (the original code never
        called .eval(), leaving dropout/batch-norm in training mode).
    """
    processor = AutoImageProcessor.from_pretrained(model_id)
    model = AutoModelForImageClassification.from_pretrained(model_id).to(device).eval()
    return processor, model
19
 
20
  classifiers = {
 
24
  "tumor_type": load_classifier("bombshelll/swin-brain-tumor-type-classification")
25
  }
26
 
27
+ # Classification function
28
  def classify_image(image):
29
  results = {}
30
  for name, (processor, model) in classifiers.items():
 
36
  results[name] = label
37
  return results
38
 
39
+ # Caption generation
40
  def generate_captions(image, keywords):
41
  pixel_values = feature_extractor(images=image, return_tensors="pt").pixel_values.to(device)
42
 
43
+ # Caption without keywords
44
  caption_model.eval()
45
  with torch.no_grad():
46
  output_ids = caption_model.generate(pixel_values, max_length=80)
47
  caption1 = tokenizer.decode(output_ids[0], skip_special_tokens=True)
48
 
49
+ # Caption with keywords
50
  prompt = " ".join(keywords)
51
  prompt_ids = tokenizer(prompt, return_tensors="pt").input_ids.to(device)
52
  with torch.no_grad():
53
  output_ids = caption_model.generate(
54
  pixel_values,
55
+ decoder_input_ids=prompt_ids[:, :-1],
56
  max_length=80,
57
  num_beams=4,
58
  no_repeat_ngram_size=3,
 
62
 
63
  return caption1, caption2
64
 
65
# Main pipeline
def run_pipeline(image, actual_caption):
    """Classify a brain image, generate two captions, and format a report.

    Args:
        image: PIL image from the Gradio image input.
        actual_caption: Optional ground-truth caption string. When
            non-empty, a BLEU score of the keyword-guided caption against
            it is appended to the report.

    Returns:
        A single formatted multi-line string with the classification
        results, both captions, and the BLEU score (or "-" if no
        reference caption was provided).
    """
    classification = classify_image(image)
    keywords = list(classification.values())
    caption1, caption2 = generate_captions(image, keywords)

    # Format classification result as string
    classification_str = (
        f"๐Ÿงญ Plane: {classification.get('plane')}\n"
        f"๐Ÿ–ผ๏ธ Modality: {classification.get('modality')}\n"
        f"๐Ÿงฌ Abnormality: {classification.get('abnormality')}\n"
    )
    if "tumor_type" in classification:
        classification_str += f"๐Ÿ”ฌ Tumor Type: {classification.get('tumor_type')}\n"

    # BLEU score of the keyword-guided caption against the reference.
    # Guard against None as well as "": Gradio can deliver None for a
    # cleared textbox, and None.strip() would raise AttributeError.
    if actual_caption and actual_caption.strip():
        ref = [actual_caption.lower().split()]
        hyp = caption2.lower().split()
        # Smoothing avoids a hard zero when higher n-gram orders have no match.
        score = sentence_bleu(ref, hyp, smoothing_function=SmoothingFunction().method1)
        bleu = f"๐Ÿ“Š BLEU Score: {score:.2f}"
    else:
        bleu = "๐Ÿ“Š BLEU Score: -"

    # Output
    result_text = f"{classification_str}\n\nโœ๏ธ Caption without Keywords:\n{caption1}\n\nโœจ Caption with Keywords:\n{caption2}\n\n{bleu}"

    return result_text
93
+
94
# Gradio UI: two-column Blocks layout — image + optional ground-truth
# caption on the left, one formatted result textbox on the right.
# NOTE: component creation order defines the rendered layout, so the
# statement order below is load-bearing.
with gr.Blocks(theme=gr.themes.Soft(primary_hue="pink")) as demo:
    # Centered HTML header with a short usage description.
    gr.Markdown(
        """
        <h1 style='text-align: center;'>๐Ÿง  Brain Hierarchical Classification + Captioning</h1>
        <p style='text-align: center;'>Upload an MRI/CT brain image. The system will classify the image (plane, modality, abnormality, tumor) and generate two captions, along with a BLEU score if ground truth is given.</p>
        """
    )
    with gr.Row():
        with gr.Column():
            image_input = gr.Image(type="pil", label="๐Ÿ–ผ๏ธ Upload Brain MRI/CT")
            # Optional reference caption; run_pipeline only computes BLEU
            # when this is non-empty.
            actual_caption = gr.Textbox(label="๐Ÿง  Ground Truth Caption (optional)")
            btn = gr.Button("๐Ÿš€ Submit")
        with gr.Column():
            output_box = gr.Textbox(label="๐Ÿ“ Result", lines=20)

    # Wire the button to the pipeline defined above.
    btn.click(fn=run_pipeline, inputs=[image_input, actual_caption], outputs=output_box)

demo.launch()