bombshelll committed on
Commit
bfe2b83
·
1 Parent(s): 6d6d9b8

Remove BLEU

Browse files
Files changed (2) hide show
  1. app.py +15 -59
  2. style.css +13 -0
app.py CHANGED
@@ -2,10 +2,7 @@ import gradio as gr
2
  from PIL import Image
3
  import torch
4
  from transformers import VisionEncoderDecoderModel, AutoTokenizer, ViTFeatureExtractor, AutoImageProcessor, AutoModelForImageClassification
5
- from nltk.translate.bleu_score import sentence_bleu
6
  import warnings
7
- import nltk
8
- nltk.download('punkt')
9
 
10
  warnings.filterwarnings("ignore", category=UserWarning)
11
 
@@ -15,6 +12,9 @@ caption_model = VisionEncoderDecoderModel.from_pretrained("bombshelll/ViT_BioMed
15
  tokenizer = AutoTokenizer.from_pretrained("bombshelll/ViT_BioMedBert_Captioning_ROCO")
16
  feature_extractor = ViTFeatureExtractor.from_pretrained("bombshelll/ViT_BioMedBert_Captioning_ROCO")
17
 
 
 
 
18
  def load_classifier(model_id):
19
  processor = AutoImageProcessor.from_pretrained(model_id)
20
  model = AutoModelForImageClassification.from_pretrained(model_id).to(device)
@@ -38,23 +38,6 @@ def classify_image(image):
38
  results[name] = label
39
  return results
40
 
41
- def preprocess_caption(text):
42
- text = str(text).lower()
43
- text = text.replace("magnetic resonance imaging", "mri")
44
- text = text.replace("magnetic resonance image", "mri")
45
- text = text.replace("computed tomography", "ct")
46
- text = text.replace("t1-weighted", "t1")
47
- text = text.replace("t1w1", "t1")
48
- text = text.replace("t1w", "t1")
49
- text = text.replace("t1ce", "t1")
50
- text = text.replace("t2-weighted", "t2")
51
- text = text.replace("t2w", "t2")
52
- text = text.replace("t2/flair", "flair")
53
- text = text.replace("tumour", "tumor")
54
- text = text.replace("lesions", "lesion")
55
- text = text.replace("-", " ")
56
- return text.split()
57
-
58
  def generate_captions(image, keywords):
59
  pixel_values = feature_extractor(images=image, return_tensors="pt").pixel_values.to(device)
60
 
@@ -78,69 +61,42 @@ def generate_captions(image, keywords):
78
 
79
  return caption1, caption2
80
 
81
- def run_pipeline(image, actual_caption):
82
  classification = classify_image(image)
83
  keywords = list(classification.values())
84
  caption1, caption2 = generate_captions(image, keywords)
85
 
86
- if actual_caption.strip():
87
- ref = [preprocess_caption(actual_caption)]
88
- hyp1 = preprocess_caption(caption1)
89
- hyp2 = preprocess_caption(caption2)
90
- score1 = sentence_bleu(ref, hyp1, smoothing_function=nltk.translate.bleu_score.SmoothingFunction().method1)
91
- score2 = sentence_bleu(ref, hyp2, smoothing_function=nltk.translate.bleu_score.SmoothingFunction().method1)
92
- bleu1 = f"{score1:.2f}"
93
- bleu2 = f"{score2:.2f}"
94
- else:
95
- bleu1 = "-"
96
- bleu2 = "-"
97
-
98
- result_sections = {
99
- "classification": (
100
- f"Plane: {classification.get('plane')}\n"
101
- f"Modality: {classification.get('modality')}\n"
102
- f"Abnormality: {classification.get('abnormality')}\n"
103
- + (f"Tumor Type: {classification.get('tumor_type')}" if "tumor_type" in classification else "")
104
- ),
105
- "caption1": caption1,
106
- "caption2": caption2,
107
- "bleu1": bleu1,
108
- "bleu2": bleu2
109
- }
110
-
111
- return (
112
- result_sections["classification"],
113
- result_sections["caption1"],
114
- result_sections["bleu1"],
115
- result_sections["caption2"],
116
- result_sections["bleu2"]
117
  )
118
 
119
- with gr.Blocks(theme=gr.themes.Soft(primary_hue="pink"), css="*{font-family:'Poppins', sans-serif;}") as demo:
 
 
120
  gr.Markdown(
121
  """
122
  <link href="https://fonts.googleapis.com/css2?family=Poppins&display=swap" rel="stylesheet">
123
  <h1 style='text-align: center;'>🧠 Brain Hierarchical Classification + Captioning</h1>
124
- <p style='text-align: center;'>Upload an MRI/CT brain image. The system will classify the image (plane, modality, abnormality, tumor type) and generate two captions. Optionally, provide a ground truth caption to get BLEU scores.</p>
125
  """,
126
  elem_id="title"
127
  )
128
  with gr.Row():
129
  with gr.Column():
130
  image_input = gr.Image(type="pil", label="πŸ–ΌοΈ Upload Brain MRI/CT")
131
- actual_caption = gr.Textbox(label="πŸ’¬ Ground Truth Caption (optional)")
132
  btn = gr.Button("πŸš€ Submit")
133
  with gr.Column():
134
  cls_box = gr.Textbox(label="πŸ“‹ Classification Result", lines=4)
135
  cap1_box = gr.Textbox(label="πŸ“ Caption without Keyword Integration", lines=4)
136
- bleu1_box = gr.Textbox(label="πŸ“Š BLEU Score (No Keyword)", lines=1)
137
  cap2_box = gr.Textbox(label="🧠 Caption with Keyword Integration", lines=4)
138
- bleu2_box = gr.Textbox(label="πŸ“ˆ BLEU Score (With Keyword)", lines=1)
139
 
140
  btn.click(
141
  fn=run_pipeline,
142
- inputs=[image_input, actual_caption],
143
- outputs=[cls_box, cap1_box, bleu1_box, cap2_box, bleu2_box]
144
  )
145
 
146
  demo.launch()
 
2
  from PIL import Image
3
  import torch
4
  from transformers import VisionEncoderDecoderModel, AutoTokenizer, ViTFeatureExtractor, AutoImageProcessor, AutoModelForImageClassification
 
5
  import warnings
 
 
6
 
7
  warnings.filterwarnings("ignore", category=UserWarning)
8
 
 
12
  tokenizer = AutoTokenizer.from_pretrained("bombshelll/ViT_BioMedBert_Captioning_ROCO")
13
  feature_extractor = ViTFeatureExtractor.from_pretrained("bombshelll/ViT_BioMedBert_Captioning_ROCO")
14
 
15
+ with open("style.css") as f:
16
+ custom_css = f.read()
17
+
18
  def load_classifier(model_id):
19
  processor = AutoImageProcessor.from_pretrained(model_id)
20
  model = AutoModelForImageClassification.from_pretrained(model_id).to(device)
 
38
  results[name] = label
39
  return results
40
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
41
  def generate_captions(image, keywords):
42
  pixel_values = feature_extractor(images=image, return_tensors="pt").pixel_values.to(device)
43
 
 
61
 
62
  return caption1, caption2
63
 
64
+ def run_pipeline(image):
65
  classification = classify_image(image)
66
  keywords = list(classification.values())
67
  caption1, caption2 = generate_captions(image, keywords)
68
 
69
+ classification_text = (
70
+ f"Plane: {classification.get('plane')}\n"
71
+ f"Modality: {classification.get('modality')}\n"
72
+ f"Abnormality: {classification.get('abnormality')}\n"
73
+ + (f"Tumor Type: {classification.get('tumor_type')}" if "tumor_type" in classification else "")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
74
  )
75
 
76
+ return classification_text, caption1, caption2
77
+
78
+ with gr.Blocks(theme=gr.themes.Soft(primary_hue="pink"), css=custom_css) as demo:
79
  gr.Markdown(
80
  """
81
  <link href="https://fonts.googleapis.com/css2?family=Poppins&display=swap" rel="stylesheet">
82
  <h1 style='text-align: center;'>🧠 Brain Hierarchical Classification + Captioning</h1>
83
+ <p style='text-align: center;'>Upload an MRI/CT brain image. The system will classify the image (plane, modality, abnormality, tumor type) and generate two captions.</p>
84
  """,
85
  elem_id="title"
86
  )
87
  with gr.Row():
88
  with gr.Column():
89
  image_input = gr.Image(type="pil", label="πŸ–ΌοΈ Upload Brain MRI/CT")
 
90
  btn = gr.Button("πŸš€ Submit")
91
  with gr.Column():
92
  cls_box = gr.Textbox(label="πŸ“‹ Classification Result", lines=4)
93
  cap1_box = gr.Textbox(label="πŸ“ Caption without Keyword Integration", lines=4)
 
94
  cap2_box = gr.Textbox(label="🧠 Caption with Keyword Integration", lines=4)
 
95
 
96
  btn.click(
97
  fn=run_pipeline,
98
+ inputs=[image_input],
99
+ outputs=[cls_box, cap1_box, cap2_box]
100
  )
101
 
102
  demo.launch()
style.css ADDED
@@ -0,0 +1,13 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ * {
2
+ font-family: 'Poppins', sans-serif;
3
+ }
4
+
5
+ .gr-column > div {
6
+ max-height: 600px;
7
+ overflow-y: auto;
8
+ }
9
+
10
+ body, html {
11
+ margin: 0;
12
+ padding: 0;
13
+ }