| import gradio as gr | |
| from PIL import Image | |
| import torch | |
| from transformers import VisionEncoderDecoderModel, AutoTokenizer, ViTFeatureExtractor, AutoImageProcessor, AutoModelForImageClassification | |
| from nltk.translate.bleu_score import sentence_bleu, SmoothingFunction | |
| import nltk | |
| import warnings | |
# Ensure the NLTK "punkt" tokenizer data is present (one-time download).
# NOTE(review): sentence_bleu works on pre-tokenized lists and this file
# tokenizes with str.split(), so punkt may not actually be needed — confirm.
try:
    nltk.data.find("tokenizers/punkt")
except LookupError:
    nltk.download("punkt")
# Silence noisy UserWarnings (e.g. deprecation notices from the model stack).
warnings.filterwarnings("ignore", category=UserWarning)
# Run on GPU when available, otherwise CPU.
device = "cuda" if torch.cuda.is_available() else "cpu"
# Captioning stack (per the repo name: ViT encoder + BioMedBERT decoder
# fine-tuned on ROCO). All three pieces come from the same checkpoint.
caption_model = VisionEncoderDecoderModel.from_pretrained("bombshelll/ViT_BioMedBert_Captioning_ROCO").to(device)
tokenizer = AutoTokenizer.from_pretrained("bombshelll/ViT_BioMedBert_Captioning_ROCO")
feature_extractor = ViTFeatureExtractor.from_pretrained("bombshelll/ViT_BioMedBert_Captioning_ROCO")
# Load the custom stylesheet injected into the Gradio UI below.
# encoding is pinned so the CSS decodes identically on every platform
# (the previous code used the locale-dependent default encoding).
with open("style.css", encoding="utf-8") as f:
    custom_css = f.read()
def load_classifier(model_id):
    """Fetch one classification head from the Hugging Face hub.

    Parameters
    ----------
    model_id : str
        Hub repository id of the image-classification checkpoint.

    Returns
    -------
    tuple
        ``(processor, model)`` where the model has been moved to the
        module-level ``device``.
    """
    classifier = AutoModelForImageClassification.from_pretrained(model_id)
    image_processor = AutoImageProcessor.from_pretrained(model_id)
    return image_processor, classifier.to(device)
# Hub checkpoints for each stage of the hierarchical classification.
# Insertion order matters: "abnormality" must come before "tumor_type",
# because the tumor-type prediction is only kept for tumor abnormalities.
_CLASSIFIER_REPOS = {
    "plane": "bombshelll/swin-brain-plane-classification",
    "modality": "bombshelll/swin-brain-modality-classification",
    "abnormality": "bombshelll/swin-brain-abnormalities-classification",
    "tumor_type": "bombshelll/swin-brain-tumor-type-classification",
}
# task name -> (processor, model) pairs, all loaded up front.
classifiers = {task: load_classifier(repo) for task, repo in _CLASSIFIER_REPOS.items()}
def classify_image(image):
    """Run the hierarchical classifier cascade on a brain image.

    Parameters
    ----------
    image : PIL.Image.Image
        The uploaded MRI/CT image.

    Returns
    -------
    dict
        Predicted labels keyed by task name ("plane", "modality",
        "abnormality", and — only when the abnormality is "tumor" —
        "tumor_type").
    """
    results = {}
    for name, (processor, model) in classifiers.items():
        # The tumor-type head is only meaningful when the abnormality stage
        # predicted "tumor". The original code ran the forward pass anyway
        # and threw the label away; skipping it here avoids a wasted
        # inference per image with identical output. Relies on dict
        # insertion order: "abnormality" is classified before "tumor_type".
        if name == "tumor_type" and results.get("abnormality") != "tumor":
            continue
        inputs = processor(image, return_tensors="pt").to(device)
        with torch.no_grad():
            logits = model(**inputs).logits
        results[name] = model.config.id2label[logits.argmax(-1).item()]
    return results
def preprocess_caption(text):
    """Normalize a medical caption and return its whitespace-split tokens.

    Lowercases the text, collapses common radiology synonyms and sequence
    names to canonical short forms (e.g. "magnetic resonance imaging" ->
    "mri", "tumour" -> "tumor"), replaces hyphens with spaces, and splits
    on whitespace so the result can be fed to BLEU scoring.
    """
    # Ordered substitutions: multi-word and hyphenated forms must be
    # collapsed before the final generic "-" -> " " pass.
    substitutions = (
        ("magnetic resonance imaging", "mri"),
        ("magnetic resonance image", "mri"),
        ("computed tomography", "ct"),
        ("t1-weighted", "t1"),
        ("t1w1", "t1"),
        ("t1ce", "t1"),
        ("t2-weighted", "t2"),
        ("t2w", "t2"),
        ("t2/flair", "flair"),
        ("tumour", "tumor"),
        ("lesions", "lesion"),
        ("-", " "),
    )
    normalized = str(text).lower()
    for old, new in substitutions:
        normalized = normalized.replace(old, new)
    return normalized.split()
def generate_captions(image, keywords):
    """Generate two captions for *image*: one free-running, one keyword-primed.

    Parameters
    ----------
    image : PIL.Image.Image
        The uploaded brain image.
    keywords : list[str]
        Classifier labels used to prime the second caption's decoder.

    Returns
    -------
    tuple[str, str]
        ``(caption1, caption2)`` — greedy caption without keywords, and a
        beam-search caption whose decoding starts from the keyword prompt.
    """
    pixel_values = feature_extractor(images=image, return_tensors="pt").pixel_values.to(device)
    caption_model.eval()
    # Caption 1: plain greedy generation, no prompt.
    with torch.no_grad():
        output_ids = caption_model.generate(pixel_values, max_length=80)
    caption1 = tokenizer.decode(output_ids[0], skip_special_tokens=True)
    # Caption 2: feed the keywords as the decoder prefix so generation
    # continues from them. [:, :-1] drops the tokenizer's trailing special
    # token (presumably [SEP] for a BERT-style tokenizer — TODO confirm),
    # so the model extends the prompt instead of stopping at it.
    prompt = " ".join(keywords)
    prompt_ids = tokenizer(prompt, return_tensors="pt").input_ids.to(device)
    with torch.no_grad():
        output_ids = caption_model.generate(
            pixel_values,
            decoder_input_ids=prompt_ids[:, :-1],
            max_length=80,
            num_beams=4,
            no_repeat_ngram_size=3,
            length_penalty=2.0
        )
    caption2 = tokenizer.decode(output_ids[0], skip_special_tokens=True)
    return caption1, caption2
def run_pipeline(image, actual_caption):
    """Classify an image, caption it twice, and optionally score with BLEU.

    Parameters
    ----------
    image : PIL.Image.Image
        The uploaded brain image.
    actual_caption : str | None
        Optional ground-truth caption; may be None or empty when the Gradio
        textbox is left blank.

    Returns
    -------
    tuple
        ``(classification_text, caption1, caption2, bleu1, bleu2)`` — BLEU
        entries are formatted scores, or "-" when no ground truth was given.
    """
    classification = classify_image(image)
    keywords = list(classification.values())
    caption1, caption2 = generate_captions(image, keywords)
    # Trailing "\n" after the abnormality line matches the original output
    # exactly (the tumor line, when present, follows it without a gap).
    classification_text = (
        f"Plane: {classification.get('plane')}\n"
        f"Modality: {classification.get('modality')}\n"
        f"Abnormality: {classification.get('abnormality')}\n"
    )
    if "tumor_type" in classification:
        classification_text += f"Tumor Type: {classification.get('tumor_type')}"
    # Guard against None: Gradio can deliver None for an untouched optional
    # textbox, which previously crashed on .strip() with AttributeError.
    if actual_caption and actual_caption.strip():
        ref = [preprocess_caption(actual_caption)]
        hyp1 = preprocess_caption(caption1)
        hyp2 = preprocess_caption(caption2)
        # Smoothing avoids zero scores on short captions with missing n-grams.
        smooth = SmoothingFunction().method1
        bleu1 = f"{sentence_bleu(ref, hyp1, smoothing_function=smooth):.2f}"
        bleu2 = f"{sentence_bleu(ref, hyp2, smoothing_function=smooth):.2f}"
    else:
        bleu1 = "-"
        bleu2 = "-"
    return classification_text, caption1, caption2, bleu1, bleu2
# ---------------------------------------------------------------------------
# Gradio UI: image + optional ground-truth caption in; classification text,
# two generated captions, and two BLEU scores out.
# ---------------------------------------------------------------------------
with gr.Blocks(theme=gr.themes.Soft(primary_hue="pink"), css=custom_css) as demo:
    gr.Markdown(
        """
        <link href="https://fonts.googleapis.com/css2?family=Poppins&display=swap" rel="stylesheet">
        <h1 style='text-align: center;'>π§  Brain Hierarchical Classification + Captioning</h1>
        <p style='text-align: center;'>Upload an MRI/CT brain image. The system will classify the image and generate captions. Optionally, provide ground truth to see BLEU scores.</p>
        """,
        elem_id="title"
    )
    with gr.Row():
        # Left column: inputs and the submit trigger.
        with gr.Column():
            image_input = gr.Image(type="pil", label="πΌοΈ Upload Brain MRI/CT")
            actual_caption = gr.Textbox(label="π¬ Ground Truth Caption (optional)")
            btn = gr.Button("π Submit")
        # Right column: pipeline outputs.
        with gr.Column():
            cls_box = gr.Textbox(label="π Classification Result", lines=4)
            cap1_box = gr.Textbox(label="π Caption without Keyword Integration", lines=4)
            cap2_box = gr.Textbox(label="π§  Caption with Keyword Integration", lines=4)
            bleu1_box = gr.Textbox(label="π BLEU (No Keyword)", lines=1)
            bleu2_box = gr.Textbox(label="π BLEU (With Keyword)", lines=1)
    # Wire the button to the full pipeline; output order matches run_pipeline.
    btn.click(
        fn=run_pipeline,
        inputs=[image_input, actual_caption],
        outputs=[cls_box, cap1_box, cap2_box, bleu1_box, bleu2_box]
    )
# NOTE(review): label strings above appear mojibake-encoded (likely UTF-8
# emoji read with the wrong codec upstream); reproduced verbatim here since
# the intended glyphs cannot be recovered with certainty.
demo.launch()