import html
import os

import gradio as gr
import torch
import torch.nn as nn
from huggingface_hub import hf_hub_download
from PIL import Image
from torchvision import transforms
from torchvision.models import efficientnet_b3
# ── Config ────────────────────────────────────────────────────────
# Checkpoint filename inside the Hugging Face model repository.
CKPT_FILE = "model.pt"
# Run on the first CUDA device when one is visible, otherwise on the CPU.
DEVICE = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

# Per-channel (R, G, B) normalisation statistics applied after ToTensor();
# these are the standard ImageNet values.
MEAN = [0.485, 0.456, 0.406]
STD = [0.229, 0.224, 0.225]

# Display metadata for each class, keyed by human-readable class name:
#   color -> accent colour (hex), glow -> translucent rgba used for shadows,
#   icon  -> emoji badge,         desc -> one-line explanation for the card.
# Lookups elsewhere in the file match these keys case-insensitively.
CLASS_INFO = {
    "Glioma": dict(
        color="#f87171",
        glow="rgba(248,113,113,0.25)",
        icon="🔴",
        desc="Originates in glial cells of the brain or spine. Accounts for ~30% of all brain tumors and ~80% of malignant tumors.",
    ),
    "Meningioma": dict(
        color="#fb923c",
        glow="rgba(251,146,60,0.25)",
        icon="🟠",
        desc="Arises from the meninges surrounding the brain and spinal cord. Usually benign and slow-growing.",
    ),
    "Pituitary Tumor": dict(
        color="#c084fc",
        glow="rgba(192,132,252,0.25)",
        icon="🟣",
        desc="Located in the pituitary gland at the brain's base. Most are benign but can disrupt hormone regulation.",
    ),
    "No Tumor": dict(
        color="#4ade80",
        glow="rgba(74,222,128,0.25)",
        icon="🟢",
        desc="No tumor detected. Brain tissue appears within normal parameters.",
    ),
}
# ── Model ─────────────────────────────────────────────────────────
class EfficientNetClassifier(nn.Module):
    """EfficientNet-B3 backbone topped with a two-layer MLP classification head.

    The backbone is constructed without pretrained weights (``weights=None``);
    trained parameters are expected to be restored afterwards through
    ``load_state_dict``.

    Args:
        num_classes: Number of output logits.
        dropout: Drop probability before the hidden layer; the second dropout
            layer uses half of this value.
    """

    def __init__(self, num_classes=4, dropout=0.4):
        super().__init__()
        self.backbone = efficientnet_b3(weights=None)
        # Width of the pooled features that fed torchvision's stock Linear head.
        feat_dim = self.backbone.classifier[1].in_features
        # Keep this exact layer order: it fixes the state_dict indices (0..4)
        # that saved checkpoints rely on.
        head_layers = [
            nn.Dropout(p=dropout, inplace=True),
            nn.Linear(feat_dim, 512),
            nn.SiLU(),
            nn.Dropout(p=dropout / 2),
            nn.Linear(512, num_classes),
        ]
        self.backbone.classifier = nn.Sequential(*head_layers)

    def forward(self, x):
        """Return raw (pre-softmax) class logits for the input batch ``x``."""
        logits = self.backbone(x)
        return logits
def load_model(repo_id="S-4-G-4-R/brain-tumor-efficientnet-b3", filename=None):
    """Download the trained checkpoint and rebuild the classifier from it.

    Args:
        repo_id: Hugging Face Hub model repository that holds the checkpoint.
        filename: Checkpoint file inside ``repo_id``; ``None`` means ``CKPT_FILE``.

    Returns:
        ``(model, img_size, id_to_label)``: the model in eval mode on ``DEVICE``,
        the square input size used for resizing (default 300), and a mapping
        from integer class index to label string.

    Raises:
        KeyError: if the checkpoint lacks the ``"model"`` or ``"id_to_label"`` entry.
        ValueError: if the label map size disagrees with the declared class count.
    """
    ckpt_path = hf_hub_download(repo_id=repo_id, filename=filename or CKPT_FILE)
    # SECURITY NOTE(review): weights_only=False unpickles arbitrary Python
    # objects, so a tampered repository could execute code at load time.
    # Switch to weights_only=True once the checkpoint is confirmed to contain
    # only tensors and plain containers/primitives.
    ckpt = torch.load(ckpt_path, map_location=DEVICE, weights_only=False)

    # Keys may have been serialised as strings; normalise to int indices.
    id_to_label = {int(k): v for k, v in ckpt["id_to_label"].items()}
    # Derive the fallback class count from the label map instead of a magic 4.
    n_classes = ckpt.get("num_classes", len(id_to_label))
    if len(id_to_label) != n_classes:
        # Fail fast here rather than with a KeyError / missing class at inference.
        raise ValueError(
            f"Checkpoint declares {n_classes} classes but id_to_label has "
            f"{len(id_to_label)} entries"
        )
    img_size = ckpt.get("img_size", 300)

    model = EfficientNetClassifier(n_classes).to(DEVICE)
    model.load_state_dict(ckpt["model"])
    model.eval()  # inference mode: disables the head's Dropout layers
    return model, img_size, id_to_label
# Fetch and build the model once at import time so every request reuses it.
print("Loading model…")
model, IMG_SIZE, id_to_label = load_model()
print(f"Model ready on {DEVICE}")

# Inference preprocessing: square resize to the checkpoint's img_size, convert
# to a float tensor, then normalise with the MEAN/STD constants.
# NOTE(review): presumably mirrors the training pipeline — confirm.
transform = transforms.Compose(
    [
        transforms.Resize(size=(IMG_SIZE, IMG_SIZE)),
        transforms.ToTensor(),
        transforms.Normalize(mean=MEAN, std=STD),
    ]
)
# ── Inference ─────────────────────────────────────────────────────
def _class_info(label):
    """Return the CLASS_INFO entry whose key matches *label* case-insensitively.

    Returns an empty dict when nothing matches, so callers fall back to
    neutral display defaults.
    """
    wanted = label.lower()
    return next(
        (info for name, info in CLASS_INFO.items() if name.lower() == wanted),
        {},
    )


def _probability_row(label, prob, is_top):
    """Render one labelled probability bar (icon, name, percentage, fill) as HTML."""
    info = _class_info(label)
    color = info.get("color", "#555")
    icon = info.get("icon", "⚪")
    pct = prob * 100
    # Emphasise the winning class so it stands out in the ranked list.
    text_color = "#e5e7eb" if is_top else "#9ca3af"
    weight = 700 if is_top else 500
    return (
        '<div style="margin:10px 0;">'
        '<div style="display:flex;justify-content:space-between;'
        f'font-size:13px;font-weight:{weight};color:{text_color};">'
        f"<span>{icon} {html.escape(label)}</span>"
        f'<span style="font-family:var(--mono);">{pct:.2f}%</span>'
        "</div>"
        '<div style="height:6px;margin-top:4px;border-radius:99px;'
        'background:#1f2937;overflow:hidden;">'
        f'<div style="width:{pct:.2f}%;height:100%;border-radius:99px;'
        f'background:{color};"></div>'
        "</div>"
        "</div>"
    )


@torch.no_grad()
def predict(image: Image.Image):
    """Classify one MRI image and render the result card.

    Args:
        image: Image from the Gradio input, or ``None`` when the input is
            empty/cleared.

    Returns:
        ``(probabilities, html)``: a ``{label: probability}`` dict (rounded to
        4 decimals, in class-index order) for the hidden ``gr.Label`` output,
        and the HTML result card. For ``None`` input returns
        ``(None, _empty_state())``.
    """
    if image is None:
        return None, _empty_state()

    tensor = transform(image.convert("RGB")).unsqueeze(0).to(DEVICE)
    probs = torch.softmax(model(tensor), dim=-1)[0].tolist()
    # Iterate over the model's actual outputs; an index missing from the label
    # map gets a generic name instead of raising KeyError.
    results = {
        id_to_label.get(i, f"Class {i}"): round(p, 4) for i, p in enumerate(probs)
    }
    top_label = max(results, key=results.get)
    top_prob = results[top_label]

    info = _class_info(top_label)
    color = info.get("color", "#ffffff")
    glow = info.get("glow", "rgba(255,255,255,0.1)")
    icon = info.get("icon", "⚪")
    desc = info.get("desc", "")

    # ── Probability bars, highest first ───────────────────────────
    ranked = sorted(results.items(), key=lambda kv: kv[1], reverse=True)
    bars_html = "".join(
        _probability_row(lbl, prob, lbl == top_label) for lbl, prob in ranked
    )

    # Shared style for the small uppercase section headings in the card.
    section_attr = (
        'style="font-family:var(--mono);font-size:11px;letter-spacing:0.14em;'
        'text-transform:uppercase;color:var(--muted);margin:0 0 8px;"'
    )
    card = f"""
<div style="background:var(--surface);border:1px solid {color};border-radius:16px;padding:24px;box-shadow:0 0 32px {glow};">
  <div {section_attr}>🔬 Diagnosis</div>
  <div style="font-size:26px;font-weight:800;color:{color};margin-bottom:8px;">{icon} {html.escape(top_label)}</div>
  <p style="font-size:13px;line-height:1.6;color:var(--muted);margin:0 0 20px;">{desc}</p>
  <div {section_attr}>📊 Confidence</div>
  <div style="font-family:var(--mono);font-size:32px;font-weight:700;color:{color};margin-bottom:20px;">{top_prob*100:.1f}%</div>
  <div {section_attr}>📈 All Classes</div>
  {bars_html}
  <div style="font-size:11px;color:#4b5563;margin-top:20px;text-align:center;">⚠️ For research use only · Not a clinical diagnostic tool</div>
</div>
"""
    return results, card
def _empty_state():
return """
🧠
Awaiting MRI scan
Upload or drag-and-drop a brain MRI image on the left to see the classification result here.
"""
# ── CSS ───────────────────────────────────────────────────────────
# Custom stylesheet passed to gr.Blocks(css=...). It defines the dark palette
# as CSS custom properties on :root (--bg, --surface, --border, --accent,
# --muted, --text, --font, --mono) and styles elements by id/class (#hero,
# #main-row, #upload-panel, #classify-btn, #upload-hint, #stats-strip,
# .stat-chip, #result-col, #footer, ...). Those hooks must be provided by the
# gr.HTML snippets or by components' elem_id / elem_classes.
# NOTE(review): `.svelte-1ipelgc` looks like a Svelte-generated scoped class
# name and may change between Gradio releases; the [data-testid="image"]
# selector is the sturdier hook. `.gr-row` may likewise be version-specific —
# confirm against the installed Gradio.
# NOTE(review): `#hero p` declares `margin` twice; the later `margin: 0 auto`
# wins, so the earlier `margin: 0` is redundant.
CSS = """
@import url('https://fonts.googleapis.com/css2?family=Space+Grotesk:wght@300;400;500;600;700;800&family=Space+Mono:wght@400;700&display=swap');
*, *::before, *::after { box-sizing: border-box; }
:root {
--bg: #080c12;
--surface: #0d1117;
--border: #1f2937;
--accent: #6366f1;
--muted: #6b7280;
--text: #e5e7eb;
--font: 'Space Grotesk', sans-serif;
--mono: 'Space Mono', monospace;
}
html, body, .gradio-container {
background: var(--bg) !important;
font-family: var(--font) !important;
color: var(--text) !important;
}
.gradio-container {
max-width: 1100px !important;
margin: 0 auto !important;
padding: 0 16px !important;
}
/* ── Header ── */
#hero {
padding: 44px 8px 36px;
text-align: center;
border-bottom: 1px solid var(--border);
margin-bottom: 36px;
}
#hero .pill {
display: inline-block;
font-family: var(--mono);
font-size: 10px;
letter-spacing: 0.15em;
text-transform: uppercase;
padding: 5px 14px;
border: 1px solid #2a3a4a;
border-radius: 99px;
color: #4b6a8a;
margin-bottom: 20px;
background: #0a131e;
}
#hero h1 {
font-size: clamp(26px, 5vw, 42px);
font-weight: 800;
letter-spacing: -0.04em;
color: #f1f5f9;
margin: 0 0 12px;
line-height: 1.1;
}
#hero h1 span { color: #6366f1; }
#hero p {
font-size: 14px;
color: var(--muted);
margin: 0;
line-height: 1.7;
max-width: 520px;
margin: 0 auto;
}
/* ── Two-column wrapper ── */
#main-row {
display: grid !important;
grid-template-columns: 1fr 1fr !important;
gap: 20px !important;
align-items: start !important;
}
@media (max-width: 700px) {
#main-row {
grid-template-columns: 1fr !important;
}
}
/* ── Left panel ── */
#upload-panel {
background: var(--surface) !important;
border: 1px solid var(--border) !important;
border-radius: 16px !important;
padding: 24px !important;
}
#upload-panel .panel-label {
font-size: 11px;
letter-spacing: 0.14em;
text-transform: uppercase;
color: var(--muted);
margin-bottom: 16px;
font-family: var(--mono);
}
/* Gradio image component dark styling */
.upload-wrap .svelte-1ipelgc,
.upload-wrap [data-testid="image"] {
background: #080c12 !important;
border: 1.5px dashed #2a3a4a !important;
border-radius: 12px !important;
min-height: 260px !important;
transition: border-color 0.25s;
}
.upload-wrap [data-testid="image"]:hover {
border-color: var(--accent) !important;
}
/* ── Classify button ── */
#classify-btn {
margin-top: 14px !important;
width: 100% !important;
background: var(--accent) !important;
border: none !important;
border-radius: 10px !important;
color: #fff !important;
font-family: var(--font) !important;
font-size: 14px !important;
font-weight: 700 !important;
letter-spacing: 0.06em !important;
padding: 13px 0 !important;
cursor: pointer !important;
transition: opacity 0.2s, transform 0.15s !important;
box-shadow: 0 0 24px rgba(99,102,241,0.35) !important;
}
#classify-btn:hover {
opacity: 0.88 !important;
transform: translateY(-1px) !important;
}
#classify-btn:active {
transform: translateY(0) !important;
}
/* ── Upload hint text ── */
#upload-hint {
font-size: 12px;
color: #374151;
text-align: center;
margin-top: 10px;
line-height: 1.6;
}
/* ── Stats strip ── */
#stats-strip {
display: flex;
gap: 12px;
margin-top: 16px;
}
.stat-chip {
flex: 1;
background: #0a131e;
border: 1px solid #1a2535;
border-radius: 8px;
padding: 10px 12px;
text-align: center;
}
.stat-chip .val {
font-size: 16px;
font-weight: 800;
color: #6366f1;
font-family: var(--mono);
display: block;
letter-spacing: -0.02em;
}
.stat-chip .lbl {
font-size: 10px;
color: #374151;
text-transform: uppercase;
letter-spacing: 0.1em;
margin-top: 2px;
display: block;
}
/* ── Right panel / result ── */
.result-panel > label { display: none !important; }
#result-col { align-self: stretch; }
/* ── Footer ── */
#footer {
text-align: center;
padding: 28px 16px;
border-top: 1px solid var(--border);
margin-top: 36px;
font-size: 12px;
color: #2d3748;
line-height: 1.8;
}
#footer a { color: #4b6a8a; text-decoration: none; }
#footer a:hover { color: var(--accent); }
/* ── Gradio internal overrides ── */
label span {
font-family: var(--font) !important;
font-size: 11px !important;
font-weight: 600 !important;
letter-spacing: 0.1em !important;
text-transform: uppercase !important;
color: var(--muted) !important;
}
/* Remove default gradio row gaps */
.gr-row { gap: 0 !important; }
"""
# ── Gradio UI ─────────────────────────────────────────────────────
with gr.Blocks(css=CSS, theme=gr.themes.Base(), title="NeuroScan · Brain Tumor MRI Classifier") as demo:
    # ── Hero ──────────────────────────────────────────────────────
    # The #hero / .pill hooks match the selectors styled in CSS.
    gr.HTML("""
<div id="hero">
  <div class="pill">⚡ EfficientNet-B3 · 98.98% Val Acc · 4 Classes</div>
  <h1>🧠 NeuroScan</h1>
  <p>AI-powered brain tumor detection from MRI scans.<br>
  Classifies Glioma · Meningioma · Pituitary Tumor · No Tumor
  in seconds — just upload your scan below.</p>
</div>
""")

    # ── Main two-column layout ─────────────────────────────────────
    with gr.Row(elem_id="main-row"):
        # ── Left: Upload panel ────────────────────────────────────
        with gr.Column(elem_id="upload-panel", scale=1):
            # Single-line literal: a quoted string may not span a line break.
            gr.HTML('<div class="panel-label">📤 Upload MRI Scan</div>')
            image_input = gr.Image(
                type="pil",
                label="",
                elem_classes=["upload-wrap"],
                height=280,
                show_label=False,
            )
            gr.HTML("""
<div id="upload-hint">🖼️ Drag &amp; drop or click to browse<br>
Supports JPG · PNG · WEBP · Axial / coronal / sagittal views</div>
""")
            run_btn = gr.Button("🔍 Classify MRI Scan", elem_id="classify-btn")
            # .stat-chip .val / .lbl render as stacked value/caption pairs.
            gr.HTML("""
<div id="stats-strip">
  <div class="stat-chip"><span class="val">8.2K</span><span class="lbl">Train Images</span></div>
  <div class="stat-chip"><span class="val">98.98%</span><span class="lbl">Val Accuracy</span></div>
  <div class="stat-chip"><span class="val">4</span><span class="lbl">Classes</span></div>
</div>
""")

        # ── Right: Result panel ───────────────────────────────────
        with gr.Column(elem_id="result-col", scale=1):
            result_html = gr.HTML(
                value=_empty_state(),
                label="",
                elem_classes=["result-panel"],
            )
            # Hidden gr.Label receives predict's {label: probability} dict;
            # only the HTML card is shown to the user.
            label_output = gr.Label(visible=False)

    # ── Event bindings ─────────────────────────────────────────────
    # Classify on an explicit click and whenever the image changes; when the
    # image is cleared predict receives None and returns the empty state.
    run_btn.click(fn=predict, inputs=[image_input], outputs=[label_output, result_html])
    image_input.change(fn=predict, inputs=[image_input], outputs=[label_output, result_html])

    # ── Footer ────────────────────────────────────────────────────
    # NOTE(review): footer markup is empty here — TODO restore its content.
    gr.HTML("""
""")


if __name__ == "__main__":
    demo.launch()