|
|
import random |
|
|
import gradio as gr |
|
|
from PIL import Image |
|
|
from datasets import load_dataset |
|
|
from app.model import predict, gradcam, CLASS_NAMES |
|
|
|
|
|
|
|
|
# Hugging Face Hub dataset that supplies the "random sample" images.
DATASET_ID = "AIOmarRehan/Brain_Tumor_MRI_Dataset"

# Loaded once at import time; get_random_image() samples rows from this split.
dataset = load_dataset(DATASET_ID, split="train")
|
|
|
|
|
|
|
|
def to_pil(example):
    """Coerce *example* into a ``PIL.Image.Image``.

    Inputs that are already PIL images are returned unchanged (same object).
    Anything else is handed to ``Image.fromarray``, so it must be array-like.
    """
    if not isinstance(example, Image.Image):
        example = Image.fromarray(example)
    return example
|
|
|
|
|
def get_random_image():
    """Return the ``"image"`` field of a uniformly random ``dataset`` row as a PIL image.

    ``random.choice`` raises ``IndexError`` if the dataset is empty.
    """
    return to_pil(random.choice(dataset)["image"])
|
|
|
|
|
|
|
|
|
|
|
def predict_fn(img):
    """Classify *img* and package the result for the UI.

    Returns a 2-tuple ``(label, details)``:
      * ``label`` -- the predicted class label, as returned by ``predict``.
      * ``details`` -- a JSON-friendly dict with the predicted label, the
        confidence rounded to 3 decimals, and the per-class probabilities
        ordered from most to least likely.
    """
    label, confidence, probs = predict(img)

    # Coerce every number to a builtin float. The model presumably returns
    # NumPy scalars (the original already float()-ed the probabilities), and
    # e.g. numpy.float32 is not JSON-serializable by the stdlib encoder.
    # round() on a NumPy scalar keeps the NumPy type, so coerce before rounding.
    probs_sorted = dict(
        sorted(
            ((name, float(p)) for name, p in probs.items()),
            key=lambda item: item[1],
            reverse=True,
        )
    )

    return label, {
        "Predicted label": label,
        "Confidence": round(float(confidence), 3),
        "Class probabilities": probs_sorted,
    }
|
|
|
|
|
def gradcam_fn(img, interpolant):
    """Return the Grad-CAM overlay for *img* as a PIL image.

    ``interpolant`` is the slider value. It is cast to ``float`` and forwarded
    to ``gradcam(..., interpolant=...)``.

    The result goes through ``to_pil`` rather than a bare ``Image.fromarray``.
    Array outputs are converted exactly as before, and an output that is
    already a PIL image is passed through instead of crashing.
    """
    heatmap = gradcam(img, interpolant=float(interpolant))
    return to_pil(heatmap)
|
|
|
|
|
|
|
|
|
|
|
def _predict_and_explain(img, interp):
    """Submit-button handler: classify *img* and build its Grad-CAM overlay.

    Returns ``(label, details_dict, overlay_image)``, which matches the three
    output components. If no image has been provided, raises ``gr.Error`` so
    the user sees a clear message instead of an opaque failure from
    ``predict(None)``.
    """
    if img is None:
        raise gr.Error("Please upload an MRI image or pick a random dataset image first.")
    label, details = predict_fn(img)
    return label, details, gradcam_fn(img, interp)


with gr.Blocks(title="Brain Tumor MRI Classifier (InceptionV3 + Grad-CAM)") as demo:
    gr.Markdown("# Brain Tumor MRI Classifier (InceptionV3 + Grad-CAM)")
    gr.Markdown("Upload an MRI image OR use a random sample from the dataset.")

    with gr.Row():
        # Left column: image input and controls.
        with gr.Column():
            input_img = gr.Image(type="pil", label="Upload MRI Image")
            random_btn = gr.Button("Use Random Dataset Image")
            interpolant_slider = gr.Slider(0, 1, value=0.5, label="Grad-CAM Intensity (interpolant)")
            submit_btn = gr.Button("Run Prediction + Grad-CAM")

        # Right column: prediction outputs.
        with gr.Column():
            output_label = gr.Textbox(label="Predicted Label Only")
            output_json = gr.JSON(label="Prediction Results")
            output_cam = gr.Image(label="Grad-CAM Overlay")

    # Listeners are attached inside the Blocks context so they belong to `demo`.
    random_btn.click(
        fn=get_random_image,
        inputs=[],
        outputs=[input_img],
    )

    submit_btn.click(
        fn=_predict_and_explain,
        inputs=[input_img, interpolant_slider],
        outputs=[output_label, output_json, output_cam],
    )

demo.launch()