# Source: Hugging Face Space by bombshelll — commit 24a5179 ("Add BLEU")
import gradio as gr
from PIL import Image
import torch
from transformers import VisionEncoderDecoderModel, AutoTokenizer, ViTFeatureExtractor, AutoImageProcessor, AutoModelForImageClassification
from nltk.translate.bleu_score import sentence_bleu, SmoothingFunction
import nltk
import warnings
# Make sure the NLTK "punkt" tokenizer data is present (downloaded once, then cached).
try:
    nltk.data.find("tokenizers/punkt")
except LookupError:
    nltk.download("punkt")

warnings.filterwarnings("ignore", category=UserWarning)

# Run every model on GPU when one is available, otherwise CPU.
device = "cuda" if torch.cuda.is_available() else "cpu"

# Captioning model: ViT encoder + BioMedBERT decoder fine-tuned on ROCO,
# with its matching tokenizer and image feature extractor.
caption_model = VisionEncoderDecoderModel.from_pretrained(
    "bombshelll/ViT_BioMedBert_Captioning_ROCO"
).to(device)
tokenizer = AutoTokenizer.from_pretrained("bombshelll/ViT_BioMedBert_Captioning_ROCO")
feature_extractor = ViTFeatureExtractor.from_pretrained(
    "bombshelll/ViT_BioMedBert_Captioning_ROCO"
)

# Custom stylesheet injected into the Gradio app below.
with open("style.css", encoding="utf-8") as f:
    custom_css = f.read()
def load_classifier(model_id):
    """Fetch an image-classification checkpoint and its processor.

    Returns a ``(processor, model)`` tuple with the model moved to ``device``.
    """
    image_processor = AutoImageProcessor.from_pretrained(model_id)
    classifier = AutoModelForImageClassification.from_pretrained(model_id).to(device)
    return image_processor, classifier
# One (processor, model) pair per classification head. Insertion order matters:
# classify_image iterates in this order and gates the "tumor_type" result on
# the "abnormality" prediction made earlier in the same pass.
classifiers = {
"plane": load_classifier("bombshelll/swin-brain-plane-classification"),
"modality": load_classifier("bombshelll/swin-brain-modality-classification"),
"abnormality": load_classifier("bombshelll/swin-brain-abnormalities-classification"),
"tumor_type": load_classifier("bombshelll/swin-brain-tumor-type-classification")
}
def classify_image(image):
    """Run the hierarchical classifier heads on ``image``.

    Returns a dict mapping head name ("plane", "modality", "abnormality",
    and conditionally "tumor_type") to the predicted label string. The
    tumor-type head only contributes when the abnormality head predicted
    "tumor".
    """
    results = {}
    for name, (processor, model) in classifiers.items():
        # Skip the tumor-type forward pass entirely unless a tumor was found.
        # (The original ran the model unconditionally and discarded the result;
        # this relies — as before — on "abnormality" preceding "tumor_type"
        # in the classifiers dict's insertion order.)
        if name == "tumor_type" and results.get("abnormality") != "tumor":
            continue
        inputs = processor(image, return_tensors="pt").to(device)
        with torch.no_grad():
            logits = model(**inputs).logits
        results[name] = model.config.id2label[logits.argmax(-1).item()]
    return results
def preprocess_caption(text):
    """Normalize a caption into lowercase tokens for BLEU comparison.

    Collapses common radiology synonyms ("magnetic resonance imaging" -> "mri",
    "computed tomography" -> "ct"), unifies MRI sequence spellings
    (t1/t2/flair variants), normalizes British spelling and plurals, strips
    hyphens, and splits on whitespace. Returns a list of token strings.
    """
    normalized = str(text).lower()
    # Ordered (old, new) substitutions — order matters (e.g. "t2/flair" must
    # be handled before hyphen removal would never apply to "t1-weighted").
    substitutions = (
        ("magnetic resonance imaging", "mri"),
        ("magnetic resonance image", "mri"),
        ("computed tomography", "ct"),
        ("t1-weighted", "t1"),
        ("t1w1", "t1"),
        ("t1ce", "t1"),
        ("t2-weighted", "t2"),
        ("t2w", "t2"),
        ("t2/flair", "flair"),
        ("tumour", "tumor"),
        ("lesions", "lesion"),
        ("-", " "),
    )
    for old, new in substitutions:
        normalized = normalized.replace(old, new)
    return normalized.split()
def generate_captions(image, keywords):
    """Produce two captions for ``image``.

    Caption 1 is plain greedy decoding; caption 2 is beam search with the
    classification ``keywords`` fed in as a decoder prefix. Returns the pair
    ``(plain_caption, keyword_caption)``.
    """
    pixel_values = feature_extractor(images=image, return_tensors="pt").pixel_values.to(device)
    caption_model.eval()

    # Caption 1: unconditioned generation.
    with torch.no_grad():
        plain_ids = caption_model.generate(pixel_values, max_length=80)
    plain_caption = tokenizer.decode(plain_ids[0], skip_special_tokens=True)

    # Caption 2: seed the decoder with the keyword string. The final prompt
    # token is dropped (presumably the tokenizer's end/[SEP] token — TODO
    # confirm) so generation continues from the prompt instead of stopping.
    prompt_ids = tokenizer(" ".join(keywords), return_tensors="pt").input_ids.to(device)
    with torch.no_grad():
        keyword_ids = caption_model.generate(
            pixel_values,
            decoder_input_ids=prompt_ids[:, :-1],
            max_length=80,
            num_beams=4,
            no_repeat_ngram_size=3,
            length_penalty=2.0,
        )
    keyword_caption = tokenizer.decode(keyword_ids[0], skip_special_tokens=True)
    return plain_caption, keyword_caption
def run_pipeline(image, actual_caption):
    """End-to-end demo pipeline: classify, caption, and optionally BLEU-score.

    Returns (classification_text, caption1, caption2, bleu1, bleu2); the BLEU
    fields are "-" when no ground-truth caption was supplied.
    """
    classification = classify_image(image)
    caption1, caption2 = generate_captions(image, list(classification.values()))

    classification_text = (
        f"Plane: {classification.get('plane')}\n"
        f"Modality: {classification.get('modality')}\n"
        f"Abnormality: {classification.get('abnormality')}\n"
    )
    if "tumor_type" in classification:
        classification_text += f"Tumor Type: {classification.get('tumor_type')}"

    if not actual_caption.strip():
        return classification_text, caption1, caption2, "-", "-"

    # Smoothed sentence-level BLEU of each hypothesis against the single reference.
    reference = [preprocess_caption(actual_caption)]
    smooth = SmoothingFunction().method1
    bleu1, bleu2 = (
        f"{sentence_bleu(reference, preprocess_caption(hyp), smoothing_function=smooth):.2f}"
        for hyp in (caption1, caption2)
    )
    return classification_text, caption1, caption2, bleu1, bleu2
# Gradio UI. NOTE: the original label strings contained mojibake emoji
# ("πŸš€", "πŸ“‹", ...) from a broken encoding round-trip; the intended
# emoji are restored here.
with gr.Blocks(theme=gr.themes.Soft(primary_hue="pink"), css=custom_css) as demo:
    gr.Markdown(
        """
<link href="https://fonts.googleapis.com/css2?family=Poppins&display=swap" rel="stylesheet">
<h1 style='text-align: center;'>🧠 Brain Hierarchical Classification + Captioning</h1>
<p style='text-align: center;'>Upload an MRI/CT brain image. The system will classify the image and generate captions. Optionally, provide ground truth to see BLEU scores.</p>
""",
        elem_id="title",
    )
    with gr.Row():
        with gr.Column():
            # Inputs: the scan image and an optional reference caption.
            image_input = gr.Image(type="pil", label="🖼️ Upload Brain MRI/CT")
            actual_caption = gr.Textbox(label="💬 Ground Truth Caption (optional)")
            btn = gr.Button("🚀 Submit")
        with gr.Column():
            # Outputs: classification summary, both captions, and BLEU scores.
            cls_box = gr.Textbox(label="📋 Classification Result", lines=4)
            cap1_box = gr.Textbox(label="📝 Caption without Keyword Integration", lines=4)
            cap2_box = gr.Textbox(label="🧠 Caption with Keyword Integration", lines=4)
            bleu1_box = gr.Textbox(label="📊 BLEU (No Keyword)", lines=1)
            bleu2_box = gr.Textbox(label="📈 BLEU (With Keyword)", lines=1)
    btn.click(
        fn=run_pipeline,
        inputs=[image_input, actual_caption],
        outputs=[cls_box, cap1_box, cap2_box, bleu1_box, bleu2_box],
    )
demo.launch()