bombshelll committed on
Commit
6d6d9b8
·
1 Parent(s): 6453d14

Refine Gradio UI

Browse files
Files changed (1) hide show
  1. app.py +32 -26
app.py CHANGED
@@ -11,12 +11,10 @@ warnings.filterwarnings("ignore", category=UserWarning)
11
 
12
  device = "cuda" if torch.cuda.is_available() else "cpu"
13
 
14
- # Load captioning model
15
  caption_model = VisionEncoderDecoderModel.from_pretrained("bombshelll/ViT_BioMedBert_Captioning_ROCO").to(device)
16
  tokenizer = AutoTokenizer.from_pretrained("bombshelll/ViT_BioMedBert_Captioning_ROCO")
17
  feature_extractor = ViTFeatureExtractor.from_pretrained("bombshelll/ViT_BioMedBert_Captioning_ROCO")
18
 
19
- # Load classification models
20
  def load_classifier(model_id):
21
  processor = AutoImageProcessor.from_pretrained(model_id)
22
  model = AutoModelForImageClassification.from_pretrained(model_id).to(device)
@@ -29,7 +27,6 @@ classifiers = {
29
  "tumor_type": load_classifier("bombshelll/swin-brain-tumor-type-classification")
30
  }
31
 
32
- # Classification function
33
  def classify_image(image):
34
  results = {}
35
  for name, (processor, model) in classifiers.items():
@@ -41,7 +38,6 @@ def classify_image(image):
41
  results[name] = label
42
  return results
43
 
44
- # Preprocessing caption
45
  def preprocess_caption(text):
46
  text = str(text).lower()
47
  text = text.replace("magnetic resonance imaging", "mri")
@@ -59,17 +55,14 @@ def preprocess_caption(text):
59
  text = text.replace("-", " ")
60
  return text.split()
61
 
62
- # Caption generation
63
  def generate_captions(image, keywords):
64
  pixel_values = feature_extractor(images=image, return_tensors="pt").pixel_values.to(device)
65
-
66
- # Caption without keywords
67
  caption_model.eval()
68
  with torch.no_grad():
69
  output_ids = caption_model.generate(pixel_values, max_length=80)
70
  caption1 = tokenizer.decode(output_ids[0], skip_special_tokens=True)
71
 
72
- # Caption with keywords
73
  prompt = " ".join(keywords)
74
  prompt_ids = tokenizer(prompt, return_tensors="pt").input_ids.to(device)
75
  with torch.no_grad():
@@ -85,22 +78,23 @@ def generate_captions(image, keywords):
85
 
86
  return caption1, caption2
87
 
88
- # Main pipeline
89
  def run_pipeline(image, actual_caption):
90
  classification = classify_image(image)
91
  keywords = list(classification.values())
92
  caption1, caption2 = generate_captions(image, keywords)
93
 
94
- # BLEU Score
95
  if actual_caption.strip():
96
  ref = [preprocess_caption(actual_caption)]
97
- hyp = preprocess_caption(caption2)
98
- score = sentence_bleu(ref, hyp, smoothing_function=nltk.translate.bleu_score.SmoothingFunction().method1)
99
- bleu = f"{score:.2f}"
 
 
 
100
  else:
101
- bleu = "-"
 
102
 
103
- # Format outputs
104
  result_sections = {
105
  "classification": (
106
  f"Plane: {classification.get('plane')}\n"
@@ -110,31 +104,43 @@ def run_pipeline(image, actual_caption):
110
  ),
111
  "caption1": caption1,
112
  "caption2": caption2,
113
- "bleu": bleu
 
114
  }
115
 
116
- return result_sections["classification"], result_sections["caption1"], result_sections["caption2"], result_sections["bleu"]
 
 
 
 
 
 
117
 
118
- # Gradio UI
119
  with gr.Blocks(theme=gr.themes.Soft(primary_hue="pink"), css="*{font-family:'Poppins', sans-serif;}") as demo:
120
  gr.Markdown(
121
  """
 
122
  <h1 style='text-align: center;'>🧠 Brain Hierarchical Classification + Captioning</h1>
123
- <p style='text-align: center;'>Upload an MRI/CT brain image. The system will classify the image (plane, modality, abnormality, tumor type) and generate two captions. Optionally, provide a ground truth caption to get BLEU score.</p>
124
  """,
125
  elem_id="title"
126
  )
127
  with gr.Row():
128
  with gr.Column():
129
  image_input = gr.Image(type="pil", label="πŸ–ΌοΈ Upload Brain MRI/CT")
130
- actual_caption = gr.Textbox(label="🧠 Ground Truth Caption (optional)")
131
  btn = gr.Button("πŸš€ Submit")
132
  with gr.Column():
133
- cls_box = gr.Textbox(label="🧾 Classification Result", lines=4)
134
- cap1_box = gr.Textbox(label="✏️ Caption without Keyword Integration", lines=4)
135
- cap2_box = gr.Textbox(label="✨ Caption with Keyword Integration", lines=4)
136
- bleu_box = gr.Textbox(label="πŸ“Š BLEU Score", lines=1)
137
-
138
- btn.click(fn=run_pipeline, inputs=[image_input, actual_caption], outputs=[cls_box, cap1_box, cap2_box, bleu_box])
 
 
 
 
 
139
 
140
  demo.launch()
 
11
 
12
  device = "cuda" if torch.cuda.is_available() else "cpu"
13
 
 
14
  caption_model = VisionEncoderDecoderModel.from_pretrained("bombshelll/ViT_BioMedBert_Captioning_ROCO").to(device)
15
  tokenizer = AutoTokenizer.from_pretrained("bombshelll/ViT_BioMedBert_Captioning_ROCO")
16
  feature_extractor = ViTFeatureExtractor.from_pretrained("bombshelll/ViT_BioMedBert_Captioning_ROCO")
17
 
 
18
  def load_classifier(model_id):
19
  processor = AutoImageProcessor.from_pretrained(model_id)
20
  model = AutoModelForImageClassification.from_pretrained(model_id).to(device)
 
27
  "tumor_type": load_classifier("bombshelll/swin-brain-tumor-type-classification")
28
  }
29
 
 
30
  def classify_image(image):
31
  results = {}
32
  for name, (processor, model) in classifiers.items():
 
38
  results[name] = label
39
  return results
40
 
 
41
  def preprocess_caption(text):
42
  text = str(text).lower()
43
  text = text.replace("magnetic resonance imaging", "mri")
 
55
  text = text.replace("-", " ")
56
  return text.split()
57
 
 
58
  def generate_captions(image, keywords):
59
  pixel_values = feature_extractor(images=image, return_tensors="pt").pixel_values.to(device)
60
+
 
61
  caption_model.eval()
62
  with torch.no_grad():
63
  output_ids = caption_model.generate(pixel_values, max_length=80)
64
  caption1 = tokenizer.decode(output_ids[0], skip_special_tokens=True)
65
 
 
66
  prompt = " ".join(keywords)
67
  prompt_ids = tokenizer(prompt, return_tensors="pt").input_ids.to(device)
68
  with torch.no_grad():
 
78
 
79
  return caption1, caption2
80
 
 
81
  def run_pipeline(image, actual_caption):
82
  classification = classify_image(image)
83
  keywords = list(classification.values())
84
  caption1, caption2 = generate_captions(image, keywords)
85
 
 
86
  if actual_caption.strip():
87
  ref = [preprocess_caption(actual_caption)]
88
+ hyp1 = preprocess_caption(caption1)
89
+ hyp2 = preprocess_caption(caption2)
90
+ score1 = sentence_bleu(ref, hyp1, smoothing_function=nltk.translate.bleu_score.SmoothingFunction().method1)
91
+ score2 = sentence_bleu(ref, hyp2, smoothing_function=nltk.translate.bleu_score.SmoothingFunction().method1)
92
+ bleu1 = f"{score1:.2f}"
93
+ bleu2 = f"{score2:.2f}"
94
  else:
95
+ bleu1 = "-"
96
+ bleu2 = "-"
97
 
 
98
  result_sections = {
99
  "classification": (
100
  f"Plane: {classification.get('plane')}\n"
 
104
  ),
105
  "caption1": caption1,
106
  "caption2": caption2,
107
+ "bleu1": bleu1,
108
+ "bleu2": bleu2
109
  }
110
 
111
+ return (
112
+ result_sections["classification"],
113
+ result_sections["caption1"],
114
+ result_sections["bleu1"],
115
+ result_sections["caption2"],
116
+ result_sections["bleu2"]
117
+ )
118
 
 
119
  with gr.Blocks(theme=gr.themes.Soft(primary_hue="pink"), css="*{font-family:'Poppins', sans-serif;}") as demo:
120
  gr.Markdown(
121
  """
122
+ <link href="https://fonts.googleapis.com/css2?family=Poppins&display=swap" rel="stylesheet">
123
  <h1 style='text-align: center;'>🧠 Brain Hierarchical Classification + Captioning</h1>
124
+ <p style='text-align: center;'>Upload an MRI/CT brain image. The system will classify the image (plane, modality, abnormality, tumor type) and generate two captions. Optionally, provide a ground truth caption to get BLEU scores.</p>
125
  """,
126
  elem_id="title"
127
  )
128
  with gr.Row():
129
  with gr.Column():
130
  image_input = gr.Image(type="pil", label="πŸ–ΌοΈ Upload Brain MRI/CT")
131
+ actual_caption = gr.Textbox(label="πŸ’¬ Ground Truth Caption (optional)")
132
  btn = gr.Button("πŸš€ Submit")
133
  with gr.Column():
134
+ cls_box = gr.Textbox(label="πŸ“‹ Classification Result", lines=4)
135
+ cap1_box = gr.Textbox(label="πŸ“ Caption without Keyword Integration", lines=4)
136
+ bleu1_box = gr.Textbox(label="πŸ“Š BLEU Score (No Keyword)", lines=1)
137
+ cap2_box = gr.Textbox(label="🧠 Caption with Keyword Integration", lines=4)
138
+ bleu2_box = gr.Textbox(label="πŸ“ˆ BLEU Score (With Keyword)", lines=1)
139
+
140
+ btn.click(
141
+ fn=run_pipeline,
142
+ inputs=[image_input, actual_caption],
143
+ outputs=[cls_box, cap1_box, bleu1_box, cap2_box, bleu2_box]
144
+ )
145
 
146
  demo.launch()