"""Gradio front end for an EfficientNet-B3 brain-tumor MRI classifier.

Importing this module downloads the checkpoint from the Hugging Face Hub,
builds the model and constructs the ``demo`` Blocks app; running it as a
script launches the app.
"""

import html
import os

import gradio as gr
import torch
import torch.nn as nn
from huggingface_hub import hf_hub_download
from PIL import Image
from torchvision import transforms
from torchvision.models import efficientnet_b3

# ── Config ────────────────────────────────────────────────────────
CKPT_FILE = "model.pt"
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
MEAN = [0.485, 0.456, 0.406]
STD = [0.229, 0.224, 0.225]

# Display metadata per class name; looked up case-insensitively by predict().
CLASS_INFO = {
    "Glioma": {
        "color": "#f87171",
        "glow": "rgba(248,113,113,0.25)",
        "icon": "🔴",
        "desc": "Originates in glial cells of the brain or spine. Accounts for ~30% of all brain tumors and ~80% of malignant tumors.",
    },
    "Meningioma": {
        "color": "#fb923c",
        "glow": "rgba(251,146,60,0.25)",
        "icon": "🟠",
        "desc": "Arises from the meninges surrounding the brain and spinal cord. Usually benign and slow-growing.",
    },
    "Pituitary Tumor": {
        "color": "#c084fc",
        "glow": "rgba(192,132,252,0.25)",
        "icon": "🟣",
        "desc": "Located in the pituitary gland at the brain's base. Most are benign but can disrupt hormone regulation.",
    },
    "No Tumor": {
        "color": "#4ade80",
        "glow": "rgba(74,222,128,0.25)",
        "icon": "🟢",
        "desc": "No tumor detected. Brain tissue appears within normal parameters.",
    },
}

# Lower-cased class name -> metadata, built once so lookups need no scan.
_CLASS_INFO_BY_LOWER_NAME = {name.lower(): meta for name, meta in CLASS_INFO.items()}


def _class_info(label):
    """Return the CLASS_INFO entry matching *label* ignoring case, or ``{}``."""
    return _CLASS_INFO_BY_LOWER_NAME.get(label.lower(), {})


# ── Model ─────────────────────────────────────────────────────────
class EfficientNetClassifier(nn.Module):
    """EfficientNet-B3 backbone with its classifier replaced by a two-layer head.

    Args:
        num_classes: Size of the final linear layer's output.
        dropout: Dropout probability before the first linear layer; half of
            this value is applied before the second one.
    """

    def __init__(self, num_classes=4, dropout=0.4):
        super().__init__()
        # weights=None: parameters are expected to come from a state dict.
        self.backbone = efficientnet_b3(weights=None)
        in_features = self.backbone.classifier[1].in_features
        self.backbone.classifier = nn.Sequential(
            nn.Dropout(p=dropout, inplace=True),
            nn.Linear(in_features, 512),
            nn.SiLU(),
            nn.Dropout(p=dropout / 2),
            nn.Linear(512, num_classes),
        )

    def forward(self, x):
        """Run *x* through the backbone and return its output."""
        return self.backbone(x)


def load_model():
    """Download the checkpoint and build the classifier in eval mode.

    Returns:
        A ``(model, img_size, id_to_label)`` tuple, where ``img_size`` and the
        class count fall back to 300 and 4 when absent from the checkpoint and
        ``id_to_label`` has its keys converted to ``int``.

    Raises:
        KeyError: If the checkpoint lacks ``"id_to_label"`` or ``"model"``.
    """
    ckpt_path = hf_hub_download(
        repo_id="S-4-G-4-R/brain-tumor-efficientnet-b3",
        filename=CKPT_FILE,
    )
    # NOTE(security): weights_only=False unpickles arbitrary objects from a
    # downloaded file. Only safe while the Hub repo is trusted; consider
    # weights_only=True if the checkpoint holds nothing but tensors and
    # plain containers.
    ckpt = torch.load(ckpt_path, map_location=DEVICE, weights_only=False)

    n_classes = ckpt.get("num_classes", 4)
    img_size = ckpt.get("img_size", 300)
    id_to_label = {int(k): v for k, v in ckpt["id_to_label"].items()}

    model = EfficientNetClassifier(n_classes).to(DEVICE)
    model.load_state_dict(ckpt["model"])
    model.eval()
    return model, img_size, id_to_label


print("Loading model…")
model, IMG_SIZE, id_to_label = load_model()
print(f"Model ready on {DEVICE}")

transform = transforms.Compose([
    transforms.Resize((IMG_SIZE, IMG_SIZE)),
    transforms.ToTensor(),
    transforms.Normalize(MEAN, STD),
])


# ── Inference ─────────────────────────────────────────────────────
@torch.no_grad()
def predict(image: Image.Image):
    """Classify one MRI image and render the result card.

    Args:
        image: The uploaded PIL image, or ``None`` when the input is empty.

    Returns:
        A ``(results, markup)`` tuple: ``results`` maps each class label to
        its softmax probability rounded to 4 decimals (``None`` when no image
        was given) and ``markup`` is the text for the result panel.
    """
    if image is None:
        return None, _empty_state()

    tensor = transform(image.convert("RGB")).unsqueeze(0).to(DEVICE)
    logits = model(tensor)
    probs = torch.softmax(logits, dim=-1)[0]

    # Pair every model output with its label; a class id missing from the
    # checkpoint mapping gets a placeholder name instead of raising KeyError.
    results = {
        id_to_label.get(idx, f"Class {idx}"): round(prob, 4)
        for idx, prob in enumerate(probs.tolist())
    }
    top_label = max(results, key=results.get)
    top_prob = results[top_label]

    info = _class_info(top_label)
    icon = info.get("icon", "⚪")
    desc = info.get("desc", "")

    # ── Probability bars ──────────────────────────────────────────
    # Labels originate from the checkpoint, so escape them before they are
    # placed into HTML.
    ranked = sorted(results.items(), key=lambda item: item[1], reverse=True)
    bars_html = "".join(
        f"""
{_class_info(lbl).get('icon', '⚪')} {html.escape(lbl)} {prob * 100:.2f}%
"""
        for lbl, prob in ranked
    )

    result_markup = f"""
🔬 Diagnosis
{icon} {html.escape(top_label)}
{desc}
📊 Confidence {top_prob * 100:.1f}%
📈 All Classes
{bars_html}
⚠️ For research use only · Not a clinical diagnostic tool
"""
    return results, result_markup


def _empty_state():
    """Return the placeholder text shown before any image is classified."""
    return """
🧠
Awaiting MRI scan
Upload or drag-and-drop a brain MRI image on the left to see the classification result here.
"""


# ── CSS ───────────────────────────────────────────────────────────
CSS = """
@import url('https://fonts.googleapis.com/css2?family=Space+Grotesk:wght@300;400;500;600;700;800&family=Space+Mono:wght@400;700&display=swap');
*, *::before, *::after { box-sizing: border-box; }
:root { --bg: #080c12; --surface: #0d1117; --border: #1f2937; --accent: #6366f1; --muted: #6b7280; --text: #e5e7eb; --font: 'Space Grotesk', sans-serif; --mono: 'Space Mono', monospace; }
html, body, .gradio-container { background: var(--bg) !important; font-family: var(--font) !important; color: var(--text) !important; }
.gradio-container { max-width: 1100px !important; margin: 0 auto !important; padding: 0 16px !important; }
/* ── Header ── */
#hero { padding: 44px 8px 36px; text-align: center; border-bottom: 1px solid var(--border); margin-bottom: 36px; }
#hero .pill { display: inline-block; font-family: var(--mono); font-size: 10px; letter-spacing: 0.15em; text-transform: uppercase; padding: 5px 14px; border: 1px solid #2a3a4a; border-radius: 99px; color: #4b6a8a; margin-bottom: 20px; background: #0a131e; }
#hero h1 { font-size: clamp(26px, 5vw, 42px); font-weight: 800; letter-spacing: -0.04em; color: #f1f5f9; margin: 0 0 12px; line-height: 1.1; }
#hero h1 span { color: #6366f1; }
#hero p { font-size: 14px; color: var(--muted); margin: 0; line-height: 1.7; max-width: 520px; margin: 0 auto; }
/* ── Two-column wrapper ── */
#main-row { display: grid !important; grid-template-columns: 1fr 1fr !important; gap: 20px !important; align-items: start !important; }
@media (max-width: 700px) { #main-row { grid-template-columns: 1fr !important; } }
/* ── Left panel ── */
#upload-panel { background: var(--surface) !important; border: 1px solid var(--border) !important; border-radius: 16px !important; padding: 24px !important; }
#upload-panel .panel-label { font-size: 11px; letter-spacing: 0.14em; text-transform: uppercase; color: var(--muted); margin-bottom: 16px; font-family: var(--mono); }
/* Gradio image component dark styling */
.upload-wrap .svelte-1ipelgc, .upload-wrap [data-testid="image"] { background: #080c12 !important; border: 1.5px dashed #2a3a4a !important; border-radius: 12px !important; min-height: 260px !important; transition: border-color 0.25s; }
.upload-wrap [data-testid="image"]:hover { border-color: var(--accent) !important; }
/* ── Classify button ── */
#classify-btn { margin-top: 14px !important; width: 100% !important; background: var(--accent) !important; border: none !important; border-radius: 10px !important; color: #fff !important; font-family: var(--font) !important; font-size: 14px !important; font-weight: 700 !important; letter-spacing: 0.06em !important; padding: 13px 0 !important; cursor: pointer !important; transition: opacity 0.2s, transform 0.15s !important; box-shadow: 0 0 24px rgba(99,102,241,0.35) !important; }
#classify-btn:hover { opacity: 0.88 !important; transform: translateY(-1px) !important; }
#classify-btn:active { transform: translateY(0) !important; }
/* ── Upload hint text ── */
#upload-hint { font-size: 12px; color: #374151; text-align: center; margin-top: 10px; line-height: 1.6; }
/* ── Stats strip ── */
#stats-strip { display: flex; gap: 12px; margin-top: 16px; }
.stat-chip { flex: 1; background: #0a131e; border: 1px solid #1a2535; border-radius: 8px; padding: 10px 12px; text-align: center; }
.stat-chip .val { font-size: 16px; font-weight: 800; color: #6366f1; font-family: var(--mono); display: block; letter-spacing: -0.02em; }
.stat-chip .lbl { font-size: 10px; color: #374151; text-transform: uppercase; letter-spacing: 0.1em; margin-top: 2px; display: block; }
/* ── Right panel / result ── */
.result-panel > label { display: none !important; }
#result-col { align-self: stretch; }
/* ── Footer ── */
#footer { text-align: center; padding: 28px 16px; border-top: 1px solid var(--border); margin-top: 36px; font-size: 12px; color: #2d3748; line-height: 1.8; }
#footer a { color: #4b6a8a; text-decoration: none; }
#footer a:hover { color: var(--accent); }
/* ── Gradio internal overrides ── */
label span { font-family: var(--font) !important; font-size: 11px !important; font-weight: 600 !important; letter-spacing: 0.1em !important; text-transform: uppercase !important; color: var(--muted) !important; }
/* Remove default gradio row gaps */
.gr-row { gap: 0 !important; }
"""

# ── Gradio UI ─────────────────────────────────────────────────────
with gr.Blocks(
    css=CSS,
    theme=gr.themes.Base(),
    title="NeuroScan · Brain Tumor MRI Classifier",
) as demo:

    # ── Hero ──────────────────────────────────────────────────────
    gr.HTML("""
⚡ EfficientNet-B3  ·  98.98% Val Acc  ·  4 Classes

🧠 NeuroScan

AI-powered brain tumor detection from MRI scans.
Classifies Glioma · Meningioma · Pituitary Tumor · No Tumor
in seconds — just upload your scan below.

""")

    # ── Main two-column layout ─────────────────────────────────────
    with gr.Row(elem_id="main-row"):

        # ── Left: Upload panel ────────────────────────────────────
        with gr.Column(elem_id="upload-panel", scale=1):
            gr.HTML("""
📤 Upload MRI Scan
""")
            image_input = gr.Image(
                type="pil",
                label="",
                elem_classes=["upload-wrap"],
                height=280,
                show_label=False,
            )
            gr.HTML("""
🖼️ Drag & drop or click to browse
Supports JPG · PNG · WEBP  ·  Axial / coronal / sagittal views
""")
            run_btn = gr.Button("🔍 Classify MRI Scan", elem_id="classify-btn")
            gr.HTML("""
8.2K Train Images
98.98% Val Accuracy
4 Classes
""")

        # ── Right: Result panel ───────────────────────────────────
        with gr.Column(elem_id="result-col", scale=1):
            result_html = gr.HTML(
                value=_empty_state(),
                label="",
                elem_classes=["result-panel"],
            )
            # Hidden label output (internal use)
            label_output = gr.Label(visible=False)

    # ── Event bindings ─────────────────────────────────────────────
    # Both the button and any change of the image re-run the classifier.
    run_btn.click(fn=predict, inputs=[image_input], outputs=[label_output, result_html])
    image_input.change(fn=predict, inputs=[image_input], outputs=[label_output, result_html])

    # ── Footer ────────────────────────────────────────────────────
    gr.HTML(""" """)

if __name__ == "__main__":
    demo.launch()