File size: 2,190 Bytes
96b6606 3f365e6 d759a77 3f365e6 d759a77 96b6606 d759a77 3f365e6 96b6606 3f365e6 d759a77 3f365e6 d759a77 3f365e6 d759a77 3f365e6 96b6606 3f365e6 96b6606 d759a77 96b6606 3f365e6 96b6606 3f365e6 96b6606 3f365e6 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 |
import random
import gradio as gr
from PIL import Image
from datasets import load_dataset
from app.model import predict, gradcam, CLASS_NAMES
# Hugging Face dataset that backs the "Use Random Dataset Image" button.
# It is loaded a single time when this module is imported; get_random_image()
# samples from this same object on every click.
DATASET_REPO = "AIOmarRehan/Brain_Tumor_MRI_Dataset"
dataset = load_dataset(DATASET_REPO, split="train")
# Normalise whatever the dataset hands back into a PIL image.
def to_pil(example):
    """Return *example* as a PIL image.

    Objects that are already ``PIL.Image.Image`` instances are returned
    unchanged. Anything else is passed to ``Image.fromarray``, so it must be
    array-like (for example a NumPy array); other inputs raise from Pillow.
    """
    if not isinstance(example, Image.Image):
        example = Image.fromarray(example)
    return example
def get_random_image():
    """Pick one record at random from ``dataset`` and return its image as PIL.

    ``random.choice`` raises ``IndexError`` if the dataset is empty.
    """
    record = random.choice(dataset)
    raw_image = record["image"]
    return to_pil(raw_image)
# Prediction and Grad-CAM logic
def predict_fn(img):
    """Classify *img* and package the result for the UI.

    Returns a ``(label, details)`` pair:

    - ``label`` is the predicted class as returned by ``predict`` and is
      shown in the text box.
    - ``details`` is a dict for the JSON panel containing the predicted
      label, the confidence rounded to 3 decimals, and the per-class
      probabilities ordered from most to least likely.
    """
    label, confidence, probs = predict(img)
    # Convert to built-in float so the JSON panel can always serialise the
    # values. The model may return NumPy scalars: the standard json module
    # rejects e.g. numpy.float32, and round() on a NumPy scalar keeps the
    # NumPy type. probs was already cast; confidence now matches, and is
    # cast *before* rounding so the result is a clean 3-decimal float.
    probs_sorted = {
        k: float(v)
        for k, v in sorted(probs.items(), key=lambda kv: kv[1], reverse=True)
    }
    return label, {
        "Predicted label": label,
        "Confidence": round(float(confidence), 3),
        "Class probabilities": probs_sorted,
    }
def gradcam_fn(img, interpolant):
    """Produce the Grad-CAM overlay for *img* as a PIL image.

    *interpolant* comes from the "Grad-CAM Intensity" slider and is coerced
    to ``float`` before being forwarded to ``gradcam``. The value ``gradcam``
    returns (presumably an image array, which should be confirmed in
    app.model) is converted with ``Image.fromarray``.
    """
    strength = float(interpolant)
    overlay = gradcam(img, interpolant=strength)
    return Image.fromarray(overlay)
# GRADIO UI
def _predict_and_gradcam(img, interp):
    """Handle one "Run Prediction + Grad-CAM" click.

    Returns ``(label, details, overlay)`` for the three output components.
    If no image is present, raises ``gr.Error`` so Gradio shows a readable
    message. Without this check, ``None`` would be passed into
    predict()/gradcam().
    """
    if img is None:
        raise gr.Error("Please upload an MRI image or load a random dataset image first.")
    label, details = predict_fn(img)
    return label, details, gradcam_fn(img, interp)


with gr.Blocks(title="Brain Tumor MRI Classifier (InceptionV3 + Grad-CAM)") as demo:
    gr.Markdown("# Brain Tumor MRI Classifier (InceptionV3 + Grad-CAM)")
    gr.Markdown("Upload an MRI image OR use a random sample from the dataset.")

    with gr.Row():
        # Left column: inputs and actions.
        with gr.Column():
            input_img = gr.Image(type="pil", label="Upload MRI Image")
            random_btn = gr.Button("Use Random Dataset Image")
            interpolant_slider = gr.Slider(0, 1, value=0.5, label="Grad-CAM Intensity (interpolant)")
            submit_btn = gr.Button("Run Prediction + Grad-CAM")
        # Right column: results.
        with gr.Column():
            output_label = gr.Textbox(label="Predicted Label Only")
            output_json = gr.JSON(label="Prediction Results")
            output_cam = gr.Image(label="Grad-CAM Overlay")

    # Load a random dataset image into the input slot. The function is
    # passed directly; no lambda wrapper is needed.
    random_btn.click(
        fn=get_random_image,
        inputs=[],
        outputs=[input_img],
    )

    # Prediction + Grad-CAM on whatever image is currently in the input slot.
    submit_btn.click(
        fn=_predict_and_gradcam,
        inputs=[input_img, interpolant_slider],
        outputs=[output_label, output_json, output_cam],
    )

demo.launch()