"""Gradio demo: classify brain MRI/CT images and generate radiology captions.

Pipeline: four Swin classifiers (plane, modality, abnormality, tumor type)
produce keywords that are fed into a ViT + BioMedBERT captioning model; an
optional ground-truth caption yields BLEU scores for both generated captions.
"""

import warnings

import gradio as gr
import nltk
import torch
from nltk.translate.bleu_score import SmoothingFunction, sentence_bleu
from PIL import Image
from transformers import (
    AutoImageProcessor,
    AutoModelForImageClassification,
    AutoTokenizer,
    ViTFeatureExtractor,
    VisionEncoderDecoderModel,
)

# Make sure the punkt tokenizer data needed by nltk is present (one-time download).
try:
    nltk.data.find("tokenizers/punkt")
except LookupError:
    nltk.download("punkt")

warnings.filterwarnings("ignore", category=UserWarning)

device = "cuda" if torch.cuda.is_available() else "cpu"

# Captioning model: ViT encoder + BioMedBERT decoder fine-tuned on ROCO.
caption_model = VisionEncoderDecoderModel.from_pretrained(
    "bombshelll/ViT_BioMedBert_Captioning_ROCO"
).to(device)
tokenizer = AutoTokenizer.from_pretrained("bombshelll/ViT_BioMedBert_Captioning_ROCO")
feature_extractor = ViTFeatureExtractor.from_pretrained(
    "bombshelll/ViT_BioMedBert_Captioning_ROCO"
)

with open("style.css") as f:
    custom_css = f.read()


def load_classifier(model_id):
    """Load an image-classification model plus its processor onto `device`.

    Returns a (processor, model) tuple.
    """
    processor = AutoImageProcessor.from_pretrained(model_id)
    model = AutoModelForImageClassification.from_pretrained(model_id).to(device)
    return processor, model


# Insertion order matters: "abnormality" must be classified before "tumor_type"
# so classify_image can decide whether the tumor-type model applies.
classifiers = {
    "plane": load_classifier("bombshelll/swin-brain-plane-classification"),
    "modality": load_classifier("bombshelll/swin-brain-modality-classification"),
    "abnormality": load_classifier("bombshelll/swin-brain-abnormalities-classification"),
    "tumor_type": load_classifier("bombshelll/swin-brain-tumor-type-classification"),
}


def classify_image(image):
    """Run each classifier on `image` and return {name: predicted label}.

    The tumor-type label is included only when the abnormality classifier
    predicted "tumor". (Fix: skip the tumor-type forward pass entirely in
    that case instead of running it and discarding the result.)
    """
    results = {}
    for name, (processor, model) in classifiers.items():
        if name == "tumor_type" and results.get("abnormality") != "tumor":
            continue  # no tumor found — don't waste an inference pass
        inputs = processor(image, return_tensors="pt").to(device)
        with torch.no_grad():
            logits = model(**inputs).logits
        results[name] = model.config.id2label[logits.argmax(-1).item()]
    return results


def preprocess_caption(text):
    """Normalize radiology terminology and tokenize a caption for BLEU.

    Lower-cases, collapses synonyms/abbreviation variants (MRI, CT, T1/T2
    weighting, tumour vs tumor, plural lesions), strips hyphens, and returns
    the whitespace-split token list.
    """
    text = str(text).lower()
    for term in ["magnetic resonance imaging", "magnetic resonance image"]:
        text = text.replace(term, "mri")
    for term in ["computed tomography"]:
        text = text.replace(term, "ct")
    text = text.replace("t1-weighted", "t1").replace("t1w1", "t1").replace("t1ce", "t1")
    text = text.replace("t2-weighted", "t2").replace("t2w", "t2").replace("t2/flair", "flair")
    text = text.replace("tumour", "tumor").replace("lesions", "lesion").replace("-", " ")
    return text.split()


def generate_captions(image, keywords):
    """Generate two captions for `image`.

    caption1: plain greedy-ish generation with no guidance.
    caption2: keyword-guided generation — the classifier labels are joined into
    a prompt and fed as decoder_input_ids (last token dropped so the model
    continues from the prompt), with beam search and repetition control.

    Returns (caption1, caption2).
    """
    pixel_values = feature_extractor(images=image, return_tensors="pt").pixel_values.to(device)
    caption_model.eval()

    with torch.no_grad():
        output_ids = caption_model.generate(pixel_values, max_length=80)
    caption1 = tokenizer.decode(output_ids[0], skip_special_tokens=True)

    prompt = " ".join(keywords)
    prompt_ids = tokenizer(prompt, return_tensors="pt").input_ids.to(device)
    with torch.no_grad():
        output_ids = caption_model.generate(
            pixel_values,
            decoder_input_ids=prompt_ids[:, :-1],  # drop final token so decoding continues
            max_length=80,
            num_beams=4,
            no_repeat_ngram_size=3,
            length_penalty=2.0,
        )
    caption2 = tokenizer.decode(output_ids[0], skip_special_tokens=True)
    return caption1, caption2


def run_pipeline(image, actual_caption):
    """Full demo pipeline: classify, caption (with/without keywords), score BLEU.

    Returns (classification_text, caption1, caption2, bleu1, bleu2); BLEU
    fields are "-" when no ground-truth caption was supplied.
    """
    classification = classify_image(image)
    keywords = list(classification.values())
    caption1, caption2 = generate_captions(image, keywords)

    classification_text = (
        f"Plane: {classification.get('plane')}\n"
        f"Modality: {classification.get('modality')}\n"
        f"Abnormality: {classification.get('abnormality')}\n"
        + (f"Tumor Type: {classification.get('tumor_type')}" if "tumor_type" in classification else "")
    )

    if actual_caption.strip():
        ref = [preprocess_caption(actual_caption)]
        hyp1 = preprocess_caption(caption1)
        hyp2 = preprocess_caption(caption2)
        # Smoothing avoids zero scores on short captions with missing n-grams.
        smooth = SmoothingFunction().method1
        bleu1 = f"{sentence_bleu(ref, hyp1, smoothing_function=smooth):.2f}"
        bleu2 = f"{sentence_bleu(ref, hyp2, smoothing_function=smooth):.2f}"
    else:
        bleu1 = "-"
        bleu2 = "-"

    return classification_text, caption1, caption2, bleu1, bleu2


with gr.Blocks(theme=gr.themes.Soft(primary_hue="pink"), css=custom_css) as demo:
    gr.Markdown(
        """
Upload an MRI/CT brain image. The system will classify the image and generate captions. Optionally, provide ground truth to see BLEU scores.
""",
        elem_id="title",
    )
    with gr.Row():
        with gr.Column():
            image_input = gr.Image(type="pil", label="🖼️ Upload Brain MRI/CT")
            actual_caption = gr.Textbox(label="💬 Ground Truth Caption (optional)")
            btn = gr.Button("🚀 Submit")
        with gr.Column():
            cls_box = gr.Textbox(label="📋 Classification Result", lines=4)
            cap1_box = gr.Textbox(label="📝 Caption without Keyword Integration", lines=4)
            cap2_box = gr.Textbox(label="🧠 Caption with Keyword Integration", lines=4)
            bleu1_box = gr.Textbox(label="📊 BLEU (No Keyword)", lines=1)
            bleu2_box = gr.Textbox(label="📈 BLEU (With Keyword)", lines=1)

    btn.click(
        fn=run_pipeline,
        inputs=[image_input, actual_caption],
        outputs=[cls_box, cap1_box, cap2_box, bleu1_box, bleu2_box],
    )

demo.launch()