# NOTE(review): removed extraction artifacts that preceded the code
# (viewer file-size line, git-blame hash gutter, and line-number gutter) —
# they were not part of the program and made the file syntactically invalid.
import gradio as gr
from PIL import Image
import torch
from transformers import VisionEncoderDecoderModel, AutoTokenizer, ViTFeatureExtractor, AutoImageProcessor, AutoModelForImageClassification
from nltk.translate.bleu_score import sentence_bleu, SmoothingFunction
import nltk
import warnings
# Ensure the NLTK "punkt" tokenizer data is present (needed by the BLEU
# scoring path); download it on first run only.
try:
    nltk.data.find("tokenizers/punkt")
except LookupError:
    nltk.download("punkt")

# Silence noisy UserWarnings (e.g. from transformers/torch) in the demo UI.
warnings.filterwarnings("ignore", category=UserWarning)

# Run all models on GPU when available, otherwise CPU.
device = "cuda" if torch.cuda.is_available() else "cpu"

# Captioning stack: ViT encoder + BioMedBERT decoder fine-tuned on ROCO.
# NOTE: from_pretrained downloads the weights on first use.
caption_model = VisionEncoderDecoderModel.from_pretrained("bombshelll/ViT_BioMedBert_Captioning_ROCO").to(device)
tokenizer = AutoTokenizer.from_pretrained("bombshelll/ViT_BioMedBert_Captioning_ROCO")
feature_extractor = ViTFeatureExtractor.from_pretrained("bombshelll/ViT_BioMedBert_Captioning_ROCO")

# Custom stylesheet injected into the Gradio Blocks below.
with open("style.css") as f:
    custom_css = f.read()
def load_classifier(model_id):
    """Load a Swin image classifier and its processor from the Hub.

    Args:
        model_id: Hugging Face model identifier.

    Returns:
        A ``(processor, model)`` pair, with the model moved to ``device``.
    """
    image_processor = AutoImageProcessor.from_pretrained(model_id)
    classifier = AutoModelForImageClassification.from_pretrained(model_id)
    return image_processor, classifier.to(device)
# Hierarchical classifier suite, keyed by task name. Insertion order matters:
# classify_image() consults the "abnormality" result before deciding whether
# to keep the "tumor_type" prediction, so "abnormality" must precede it here.
classifiers = {
    "plane": load_classifier("bombshelll/swin-brain-plane-classification"),
    "modality": load_classifier("bombshelll/swin-brain-modality-classification"),
    "abnormality": load_classifier("bombshelll/swin-brain-abnormalities-classification"),
    "tumor_type": load_classifier("bombshelll/swin-brain-tumor-type-classification")
}
def classify_image(image):
    """Run the hierarchical classifier suite on a brain image.

    Args:
        image: Image accepted by the HF image processors (PIL image here).

    Returns:
        dict mapping "plane", "modality", "abnormality" (and, only when the
        abnormality classifier predicted "tumor", "tumor_type") to the
        predicted label string.
    """
    results = {}
    for name, (processor, model) in classifiers.items():
        # Fix: the original ran the tumor_type model on every image and then
        # discarded its prediction unless abnormality == "tumor". Skip the
        # forward pass entirely in that case — same output, less work.
        # Relies on "abnormality" preceding "tumor_type" in `classifiers`.
        if name == "tumor_type" and results.get("abnormality") != "tumor":
            continue
        inputs = processor(image, return_tensors="pt").to(device)
        with torch.no_grad():
            logits = model(**inputs).logits
        results[name] = model.config.id2label[logits.argmax(-1).item()]
    return results
def preprocess_caption(text):
    """Normalize a radiology caption and split it into tokens.

    Lower-cases the text, collapses common synonyms and spelling variants
    (e.g. "magnetic resonance imaging" -> "mri", "tumour" -> "tumor"),
    replaces hyphens with spaces, and returns the whitespace-split tokens.

    Args:
        text: Caption text (coerced to str).

    Returns:
        List of normalized tokens.
    """
    # Ordered substitution table — applied first-to-last, mirroring the
    # original chained .replace() calls exactly.
    _SUBS = (
        ("magnetic resonance imaging", "mri"),
        ("magnetic resonance image", "mri"),
        ("computed tomography", "ct"),
        ("t1-weighted", "t1"),
        ("t1w1", "t1"),
        ("t1ce", "t1"),
        ("t2-weighted", "t2"),
        ("t2w", "t2"),
        ("t2/flair", "flair"),
        ("tumour", "tumor"),
        ("lesions", "lesion"),
        ("-", " "),
    )
    normalized = str(text).lower()
    for old, new in _SUBS:
        normalized = normalized.replace(old, new)
    return normalized.split()
def generate_captions(image, keywords):
    """Generate two captions for an image: free-form and keyword-primed.

    Args:
        image: Input image accepted by the feature extractor.
        keywords: List of classifier labels used to seed the second caption.

    Returns:
        Tuple ``(plain_caption, primed_caption)``.
    """
    pixel_values = feature_extractor(images=image, return_tensors="pt").pixel_values.to(device)
    caption_model.eval()

    # Caption 1: unconditional generation, default (greedy) decoding.
    with torch.no_grad():
        plain_ids = caption_model.generate(pixel_values, max_length=80)
    plain_caption = tokenizer.decode(plain_ids[0], skip_special_tokens=True)

    # Caption 2: prime the decoder with the classifier keywords and use
    # beam search. The prompt's final token is dropped ([:, :-1]) —
    # presumably the tokenizer's trailing special token; TODO confirm.
    keyword_prompt = " ".join(keywords)
    prompt_token_ids = tokenizer(keyword_prompt, return_tensors="pt").input_ids.to(device)
    with torch.no_grad():
        primed_ids = caption_model.generate(
            pixel_values,
            decoder_input_ids=prompt_token_ids[:, :-1],
            max_length=80,
            num_beams=4,
            no_repeat_ngram_size=3,
            length_penalty=2.0,
        )
    primed_caption = tokenizer.decode(primed_ids[0], skip_special_tokens=True)

    return plain_caption, primed_caption
def run_pipeline(image, actual_caption):
    """Full demo pipeline: classify, caption, and optionally score with BLEU.

    Args:
        image: Uploaded brain image (PIL, from the Gradio Image component).
        actual_caption: Optional ground-truth caption. May be "" or None
            (Gradio can deliver None for a cleared textbox).

    Returns:
        Tuple of five strings for the output widgets:
        (classification summary, caption without keywords, caption with
        keywords, BLEU for caption 1, BLEU for caption 2). BLEU fields are
        "-" when no ground truth was provided.
    """
    classification = classify_image(image)
    keywords = list(classification.values())
    caption1, caption2 = generate_captions(image, keywords)

    classification_text = (
        f"Plane: {classification.get('plane')}\n"
        f"Modality: {classification.get('modality')}\n"
        f"Abnormality: {classification.get('abnormality')}\n"
        # tumor_type is only present when an actual tumor was detected.
        + (f"Tumor Type: {classification.get('tumor_type')}" if "tumor_type" in classification else "")
    )

    # Fix: guard against None — the original called .strip() directly and
    # would raise AttributeError when Gradio passed None for the textbox.
    if actual_caption and actual_caption.strip():
        ref = [preprocess_caption(actual_caption)]
        hyp1 = preprocess_caption(caption1)
        hyp2 = preprocess_caption(caption2)
        # Smoothing avoids zero BLEU on short captions with no 4-gram overlap.
        smooth = SmoothingFunction().method1
        bleu1 = f"{sentence_bleu(ref, hyp1, smoothing_function=smooth):.2f}"
        bleu2 = f"{sentence_bleu(ref, hyp2, smoothing_function=smooth):.2f}"
    else:
        bleu1 = "-"
        bleu2 = "-"

    return classification_text, caption1, caption2, bleu1, bleu2
# --- Gradio UI --------------------------------------------------------------
# Two-column layout: inputs (image + optional ground-truth caption) on the
# left, classification/caption/BLEU outputs on the right.
with gr.Blocks(theme=gr.themes.Soft(primary_hue="pink"), css=custom_css) as demo:
    gr.Markdown(
        """
        <link href="https://fonts.googleapis.com/css2?family=Poppins&display=swap" rel="stylesheet">
        <h1 style='text-align: center;'>π§ Brain Hierarchical Classification + Captioning</h1>
        <p style='text-align: center;'>Upload an MRI/CT brain image. The system will classify the image and generate captions. Optionally, provide ground truth to see BLEU scores.</p>
        """,
        elem_id="title"
    )
    with gr.Row():
        with gr.Column():
            image_input = gr.Image(type="pil", label="πΌοΈ Upload Brain MRI/CT")
            actual_caption = gr.Textbox(label="π¬ Ground Truth Caption (optional)")
            btn = gr.Button("π Submit")
        with gr.Column():
            cls_box = gr.Textbox(label="π Classification Result", lines=4)
            cap1_box = gr.Textbox(label="π Caption without Keyword Integration", lines=4)
            cap2_box = gr.Textbox(label="π§ Caption with Keyword Integration", lines=4)
            bleu1_box = gr.Textbox(label="π BLEU (No Keyword)", lines=1)
            bleu2_box = gr.Textbox(label="π BLEU (With Keyword)", lines=1)
    # Wire the submit button to the full pipeline; outputs map 1:1 to the
    # five return values of run_pipeline.
    btn.click(
        fn=run_pipeline,
        inputs=[image_input, actual_caption],
        outputs=[cls_box, cap1_box, cap2_box, bleu1_box, bleu2_box]
    )

demo.launch()