# app.py — NeuroScan brain-tumor MRI classifier (Hugging Face Space by S-4-G-4-R).
# Last change: commit f8dd4c8 "Update app.py" (verified).
import os
import torch
import torch.nn as nn
from torchvision import transforms
from torchvision.models import efficientnet_b3
from PIL import Image
import gradio as gr
from huggingface_hub import hf_hub_download
# ── Config ────────────────────────────────────────────────────────
# Checkpoint file name inside the Hugging Face model repo (downloaded by load_model).
CKPT_FILE = "model.pt"
# Use CUDA when it is available, otherwise fall back to the CPU.
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# Per-channel (R, G, B) normalisation statistics applied by the preprocessing
# transform; these are the standard ImageNet mean/std values.
MEAN = [0.485, 0.456, 0.406]
STD = [0.229, 0.224, 0.225]
# Display metadata for each class, keyed by label name. The result panel matches
# checkpoint labels to these keys case-insensitively. Every "glow" value carries
# the alpha "0.25", because the diagnosis card derives its fainter gradient stop
# by replacing that substring.
CLASS_INFO = {
    "Glioma": {
        "color": "#f87171",
        "glow": "rgba(248,113,113,0.25)",
        "icon": "🔴",
        "desc": "Originates in glial cells of the brain or spine. Accounts for ~30% of all brain tumors and ~80% of malignant tumors.",
    },
    "Meningioma": {
        "color": "#fb923c",
        "glow": "rgba(251,146,60,0.25)",
        "icon": "🟠",
        "desc": "Arises from the meninges surrounding the brain and spinal cord. Usually benign and slow-growing.",
    },
    "Pituitary Tumor": {
        "color": "#c084fc",
        "glow": "rgba(192,132,252,0.25)",
        "icon": "🟣",
        "desc": "Located in the pituitary gland at the brain's base. Most are benign but can disrupt hormone regulation.",
    },
    "No Tumor": {
        "color": "#4ade80",
        "glow": "rgba(74,222,128,0.25)",
        "icon": "🟢",
        "desc": "No tumor detected. Brain tissue appears within normal parameters.",
    },
}
# ── Model ─────────────────────────────────────────────────────────
class EfficientNetClassifier(nn.Module):
    """EfficientNet-B3 backbone whose stock head is replaced by a small MLP.

    The new head is ``Dropout -> Linear(in, 512) -> SiLU -> Dropout -> Linear(512,
    num_classes)``. The attribute name ``backbone`` and the layer order inside
    the head determine the ``state_dict`` keys, so they must stay in sync with
    the checkpoint passed to ``load_state_dict``.

    Args:
        num_classes: number of output logits.
        dropout: drop probability of the first dropout layer; the second one
            uses half of this value.
    """

    def __init__(self, num_classes=4, dropout=0.4):
        super().__init__()
        self.backbone = efficientnet_b3(weights=None)
        # Feature width feeding the stock head (its Linear layer sits at index 1).
        head_width = self.backbone.classifier[1].in_features
        head_layers = [
            nn.Dropout(p=dropout, inplace=True),
            nn.Linear(in_features=head_width, out_features=512),
            nn.SiLU(),
            nn.Dropout(p=dropout / 2),
            nn.Linear(in_features=512, out_features=num_classes),
        ]
        self.backbone.classifier = nn.Sequential(*head_layers)

    def forward(self, x):
        """Return the raw (pre-softmax) class logits for the input batch ``x``."""
        logits = self.backbone(x)
        return logits
def load_model():
    """Download the fine-tuned checkpoint from the Hugging Face Hub and build the model.

    Returns:
        ``(model, img_size, id_to_label)``:
        - ``model``: the ``EfficientNetClassifier`` on ``DEVICE``, in eval mode.
        - ``img_size``: the square input resolution stored in the checkpoint
          (``300`` if absent).
        - ``id_to_label``: a mapping from integer class index to label name.

    Raises:
        KeyError: if the checkpoint has no ``"model"`` or ``"id_to_label"`` entry.
        ValueError: if ``id_to_label`` does not cover exactly the indices
            ``0 .. num_classes - 1``. Inference looks labels up by position, so a
            gap or size mismatch would otherwise surface later as a KeyError,
            an IndexError, or a silently ignored class.
    """
    ckpt_path = hf_hub_download(repo_id="S-4-G-4-R/brain-tumor-efficientnet-b3", filename=CKPT_FILE)
    # NOTE(review): weights_only=False unpickles arbitrary Python objects, so a
    # tampered checkpoint could execute code. This is acceptable only because the
    # file comes from the author's own repo. Consider weights_only=True if the
    # checkpoint holds nothing but tensors and plain containers.
    ckpt = torch.load(ckpt_path, map_location=DEVICE, weights_only=False)
    n_classes = ckpt.get("num_classes", 4)
    img_size = ckpt.get("img_size", 300)
    # Keys may have been stored as strings (e.g. "0"); normalise them to ints.
    id_to_label = {int(k): v for k, v in ckpt["id_to_label"].items()}
    if set(id_to_label) != set(range(n_classes)):
        raise ValueError(
            f"Checkpoint label map has indices {sorted(id_to_label)}, "
            f"expected 0..{n_classes - 1} for num_classes={n_classes}"
        )
    model = EfficientNetClassifier(n_classes).to(DEVICE)
    model.load_state_dict(ckpt["model"])
    model.eval()
    return model, img_size, id_to_label
# Load the model once at import time, so every request reuses the same weights.
print("Loading model…")
model, IMG_SIZE, id_to_label = load_model()
print(f"Model ready on {DEVICE}")

# Preprocessing for every uploaded image:
# 1. square resize to the checkpoint's training resolution,
# 2. conversion to a tensor,
# 3. per-channel normalisation with MEAN/STD.
transform = transforms.Compose(
    [
        transforms.Resize(size=(IMG_SIZE, IMG_SIZE)),
        transforms.ToTensor(),
        transforms.Normalize(mean=MEAN, std=STD),
    ]
)
# ── Inference ─────────────────────────────────────────────────────
# Icon used when a label has no CLASS_INFO entry (U+26AA MEDIUM WHITE CIRCLE).
_FALLBACK_ICON = "⚪"


def _class_info(label):
    """Return the CLASS_INFO entry whose key equals *label* ignoring case, else ``{}``.

    Labels come from the checkpoint's ``id_to_label`` map. Case-insensitive
    matching lets, for example, ``"glioma"`` still find the ``"Glioma"`` entry.
    """
    wanted = label.lower()
    return next(
        (info for name, info in CLASS_INFO.items() if name.lower() == wanted),
        {},
    )


def _prob_bar_html(label, prob, is_top):
    """Render one row of the "All Classes" list: icon, label, percentage and a bar.

    Args:
        label: class name as it appears in the prediction results.
        prob: softmax probability in ``[0, 1]``.
        is_top: whether this row is the predicted class. If so, it gets brighter,
            bolder text and an opaque bar.
    """
    info = _class_info(label)
    bar_color = info.get("color", "#555")
    icon = info.get("icon", _FALLBACK_ICON)
    return f"""
<div style="margin-bottom:14px;">
  <div style="display:flex;justify-content:space-between;align-items:center;margin-bottom:5px;">
    <span style="font-size:13px;color:{'#e5e7eb' if is_top else '#6b7280'};
                 font-weight:{'600' if is_top else '400'};
                 font-family:'Space Grotesk',sans-serif;">
      {icon} {label}
    </span>
    <span style="font-size:13px;color:{bar_color};font-weight:700;
                 font-family:'Space Grotesk',sans-serif;">{prob*100:.2f}%</span>
  </div>
  <div style="background:#1f2937;border-radius:99px;height:5px;overflow:hidden;">
    <div style="height:100%;width:{prob*100:.2f}%;background:{bar_color};
                border-radius:99px;opacity:{'1' if is_top else '0.45'};
                transition:width 0.7s cubic-bezier(0.4,0,0.2,1);"></div>
  </div>
</div>"""


@torch.no_grad()
def predict(image: Image.Image):
    """Classify one MRI image and render the result panel.

    Args:
        image: the uploaded PIL image, or ``None`` when no image is present.

    Returns:
        ``(results, html)``:
        - ``results`` maps every label to its softmax probability, rounded to
          4 decimals.
        - ``html`` is the rendered result card.
        For ``None`` input, returns ``(None, _empty_state())`` instead.
    """
    if image is None:
        return None, _empty_state()

    tensor = transform(image.convert("RGB")).unsqueeze(0).to(DEVICE)
    logits = model(tensor)
    # One device->host transfer instead of a .item() call per class.
    probs = torch.softmax(logits, dim=-1)[0].tolist()
    results = {id_to_label[i]: round(probs[i], 4) for i in range(len(id_to_label))}

    top_label = max(results, key=results.get)
    top_prob = results[top_label]

    info = _class_info(top_label)
    color = info.get("color", "#ffffff")
    glow = info.get("glow", "rgba(255,255,255,0.1)")
    icon = info.get("icon", _FALLBACK_ICON)
    desc = info.get("desc", "")
    # Fainter gradient stop: same hue at lower alpha. Only CLASS_INFO glows
    # contain "0.25"; the fallback glow passes through unchanged.
    glow_faint = glow.replace("0.25", "0.08")

    # ── Probability bars (highest probability first) ──────────────
    bars_html = "".join(
        _prob_bar_html(lbl, prob, lbl == top_label)
        for lbl, prob in sorted(results.items(), key=lambda kv: kv[1], reverse=True)
    )

    html = f"""
<div style="
    background:linear-gradient(145deg,#0d1117,#111827);
    border:1px solid #1f2937;
    border-radius:16px;
    padding:28px;
    font-family:'Space Grotesk',sans-serif;
    height:100%;
    box-sizing:border-box;
    animation: fadeIn 0.4s ease;
">
  <!-- Diagnosis card -->
  <div style="
      background:linear-gradient(135deg,{glow},{glow_faint});
      border:1px solid {color}33;
      border-radius:12px;
      padding:20px;
      margin-bottom:24px;
      box-shadow: 0 0 32px {glow};
  ">
    <div style="font-size:11px;letter-spacing:0.14em;color:#6b7280;
                text-transform:uppercase;margin-bottom:8px;">
      🔬 Diagnosis
    </div>
    <div style="font-size:30px;font-weight:700;color:{color};
                letter-spacing:-0.03em;margin-bottom:6px;">
      {icon} {top_label}
    </div>
    <div style="font-size:13px;color:#9ca3af;line-height:1.65;">{desc}</div>
  </div>
  <!-- Confidence meter -->
  <div style="margin-bottom:24px;">
    <div style="display:flex;justify-content:space-between;align-items:center;margin-bottom:8px;">
      <span style="font-size:11px;letter-spacing:0.12em;color:#6b7280;text-transform:uppercase;">
        📊 Confidence
      </span>
      <span style="font-size:22px;font-weight:800;color:{color};">{top_prob*100:.1f}%</span>
    </div>
    <div style="background:#1f2937;border-radius:99px;height:8px;overflow:hidden;">
      <div style="height:100%;width:{top_prob*100:.1f}%;
                  background:linear-gradient(90deg,{color}99,{color});
                  border-radius:99px;
                  box-shadow:0 0 12px {color}66;
                  transition:width 0.7s cubic-bezier(0.4,0,0.2,1);"></div>
    </div>
  </div>
  <!-- All probabilities -->
  <div>
    <div style="font-size:11px;letter-spacing:0.12em;color:#6b7280;
                text-transform:uppercase;margin-bottom:14px;">
      📈 All Classes
    </div>
    {bars_html}
  </div>
  <!-- Disclaimer -->
  <div style="margin-top:20px;padding-top:16px;border-top:1px solid #1f2937;
              font-size:11px;color:#374151;text-align:center;line-height:1.5;">
    ⚠️ For research use only · Not a clinical diagnostic tool
  </div>
</div>
<style>
@keyframes fadeIn {{ from {{opacity:0;transform:translateY(6px)}} to {{opacity:1;transform:translateY(0)}} }}
</style>
"""
    return results, html
def _empty_state():
return """
<div style="
background:linear-gradient(145deg,#0d1117,#111827);
border:1px solid #1f2937;
border-radius:16px;
padding:28px;
display:flex;
flex-direction:column;
align-items:center;
justify-content:center;
gap:16px;
min-height:340px;
box-sizing:border-box;
font-family:'Space Grotesk',sans-serif;
">
<div style="font-size:52px;opacity:0.18;">🧠</div>
<div style="font-size:16px;font-weight:600;color:#374151;letter-spacing:-0.01em;">
Awaiting MRI scan
</div>
<div style="font-size:13px;color:#374151;text-align:center;line-height:1.6;max-width:240px;">
Upload or drag-and-drop a brain MRI image on the left to see the classification result here.
</div>
</div>"""
# ── CSS ───────────────────────────────────────────────────────────
# Stylesheet passed to gr.Blocks(css=...). Theme colours live in the :root custom
# properties. Most selectors target the elem_id / elem_classes set on components
# in the UI section.
# NOTE(review): `.svelte-1ipelgc` is a build-generated Svelte class name. It will
# likely stop matching on other Gradio versions; the [data-testid] selector is the
# sturdier hook.
CSS = """
@import url('https://fonts.googleapis.com/css2?family=Space+Grotesk:wght@300;400;500;600;700;800&family=Space+Mono:wght@400;700&display=swap');
*, *::before, *::after { box-sizing: border-box; }
:root {
    --bg: #080c12;
    --surface: #0d1117;
    --border: #1f2937;
    --accent: #6366f1;
    --muted: #6b7280;
    --text: #e5e7eb;
    --font: 'Space Grotesk', sans-serif;
    --mono: 'Space Mono', monospace;
}
html, body, .gradio-container {
    background: var(--bg) !important;
    font-family: var(--font) !important;
    color: var(--text) !important;
}
.gradio-container {
    max-width: 1100px !important;
    margin: 0 auto !important;
    padding: 0 16px !important;
}
/* ── Header ── */
#hero {
    padding: 44px 8px 36px;
    text-align: center;
    border-bottom: 1px solid var(--border);
    margin-bottom: 36px;
}
#hero .pill {
    display: inline-block;
    font-family: var(--mono);
    font-size: 10px;
    letter-spacing: 0.15em;
    text-transform: uppercase;
    padding: 5px 14px;
    border: 1px solid #2a3a4a;
    border-radius: 99px;
    color: #4b6a8a;
    margin-bottom: 20px;
    background: #0a131e;
}
#hero h1 {
    font-size: clamp(26px, 5vw, 42px);
    font-weight: 800;
    letter-spacing: -0.04em;
    color: #f1f5f9;
    margin: 0 0 12px;
    line-height: 1.1;
}
#hero h1 span { color: var(--accent); }
#hero p {
    font-size: 14px;
    color: var(--muted);
    line-height: 1.7;
    max-width: 520px;
    margin: 0 auto;
}
/* ── Two-column wrapper ── */
#main-row {
    display: grid !important;
    grid-template-columns: 1fr 1fr !important;
    gap: 20px !important;
    align-items: start !important;
}
@media (max-width: 700px) {
    #main-row {
        grid-template-columns: 1fr !important;
    }
}
/* ── Left panel ── */
#upload-panel {
    background: var(--surface) !important;
    border: 1px solid var(--border) !important;
    border-radius: 16px !important;
    padding: 24px !important;
}
#upload-panel .panel-label {
    font-size: 11px;
    letter-spacing: 0.14em;
    text-transform: uppercase;
    color: var(--muted);
    margin-bottom: 16px;
    font-family: var(--mono);
}
/* Gradio image component dark styling */
.upload-wrap .svelte-1ipelgc,
.upload-wrap [data-testid="image"] {
    background: var(--bg) !important;
    border: 1.5px dashed #2a3a4a !important;
    border-radius: 12px !important;
    min-height: 260px !important;
    transition: border-color 0.25s;
}
.upload-wrap [data-testid="image"]:hover {
    border-color: var(--accent) !important;
}
/* ── Classify button ── */
#classify-btn {
    margin-top: 14px !important;
    width: 100% !important;
    background: var(--accent) !important;
    border: none !important;
    border-radius: 10px !important;
    color: #fff !important;
    font-family: var(--font) !important;
    font-size: 14px !important;
    font-weight: 700 !important;
    letter-spacing: 0.06em !important;
    padding: 13px 0 !important;
    cursor: pointer !important;
    transition: opacity 0.2s, transform 0.15s !important;
    box-shadow: 0 0 24px rgba(99,102,241,0.35) !important;
}
#classify-btn:hover {
    opacity: 0.88 !important;
    transform: translateY(-1px) !important;
}
#classify-btn:active {
    transform: translateY(0) !important;
}
/* ── Upload hint text ── */
#upload-hint {
    font-size: 12px;
    color: #374151;
    text-align: center;
    margin-top: 10px;
    line-height: 1.6;
}
/* ── Stats strip ── */
#stats-strip {
    display: flex;
    gap: 12px;
    margin-top: 16px;
}
.stat-chip {
    flex: 1;
    background: #0a131e;
    border: 1px solid #1a2535;
    border-radius: 8px;
    padding: 10px 12px;
    text-align: center;
}
.stat-chip .val {
    font-size: 16px;
    font-weight: 800;
    color: var(--accent);
    font-family: var(--mono);
    display: block;
    letter-spacing: -0.02em;
}
.stat-chip .lbl {
    font-size: 10px;
    color: #374151;
    text-transform: uppercase;
    letter-spacing: 0.1em;
    margin-top: 2px;
    display: block;
}
/* ── Right panel / result ── */
.result-panel > label { display: none !important; }
#result-col { align-self: stretch; }
/* ── Footer ── */
#footer {
    text-align: center;
    padding: 28px 16px;
    border-top: 1px solid var(--border);
    margin-top: 36px;
    font-size: 12px;
    color: #2d3748;
    line-height: 1.8;
}
#footer a { color: #4b6a8a; text-decoration: none; }
#footer a:hover { color: var(--accent); }
/* ── Gradio internal overrides ── */
label span {
    font-family: var(--font) !important;
    font-size: 11px !important;
    font-weight: 600 !important;
    letter-spacing: 0.1em !important;
    text-transform: uppercase !important;
    color: var(--muted) !important;
}
/* Remove default gradio row gaps */
.gr-row { gap: 0 !important; }
"""
# ── Gradio UI ─────────────────────────────────────────────────────
# Layout: a hero banner, a two-column row (upload panel | result panel) and a
# footer. Both a new upload and the button click run `predict`. Its two outputs
# (label probabilities, result HTML) go to a hidden gr.Label and to the visible
# result panel respectively.
with gr.Blocks(css=CSS, theme=gr.themes.Base(), title="NeuroScan · Brain Tumor MRI Classifier") as demo:
    # ── Hero ──────────────────────────────────────────────────────
    gr.HTML("""
<div id="hero">
  <div class="pill">⚡ EfficientNet-B3 &nbsp;·&nbsp; 98.98% Val Acc &nbsp;·&nbsp; 4 Classes</div>
  <h1>🧠 Neuro<span>Scan</span></h1>
  <p>
    AI-powered brain tumor detection from MRI scans.<br>
    Classifies <strong style="color:#e5e7eb;">Glioma · Meningioma · Pituitary Tumor · No Tumor</strong><br>
    in seconds — just upload your scan below.
  </p>
</div>
""")

    # ── Main two-column layout ─────────────────────────────────────
    with gr.Row(elem_id="main-row"):
        # ── Left: Upload panel ────────────────────────────────────
        with gr.Column(elem_id="upload-panel", scale=1):
            gr.HTML('<div class="panel-label">📤 Upload MRI Scan</div>')
            image_input = gr.Image(
                type="pil",
                label="",
                elem_classes=["upload-wrap"],
                height=280,
                show_label=False,
            )
            gr.HTML("""
<div id="upload-hint">
  🖼️ Drag & drop or click to browse<br>
  Supports <code style="color:#4b6a8a;">JPG · PNG · WEBP</code> &nbsp;·&nbsp; Axial / coronal / sagittal views
</div>
""")
            run_btn = gr.Button("🔍 Classify MRI Scan", elem_id="classify-btn")
            gr.HTML("""
<div id="stats-strip">
  <div class="stat-chip">
    <span class="val">8.2K</span>
    <span class="lbl">Train Images</span>
  </div>
  <div class="stat-chip">
    <span class="val">98.98%</span>
    <span class="lbl">Val Accuracy</span>
  </div>
  <div class="stat-chip">
    <span class="val">4</span>
    <span class="lbl">Classes</span>
  </div>
</div>
""")

        # ── Right: Result panel ───────────────────────────────────
        with gr.Column(elem_id="result-col", scale=1):
            result_html = gr.HTML(
                value=_empty_state(),
                label="",
                elem_classes=["result-panel"],
            )
            # Hidden label output: receives predict's probability dict, which
            # has no visible consumer in this layout.
            label_output = gr.Label(visible=False)

    # ── Event bindings ─────────────────────────────────────────────
    # Classify on button press and also whenever the image changes. A cleared
    # image arrives as None, and predict then restores the empty state.
    run_btn.click(fn=predict, inputs=[image_input], outputs=[label_output, result_html])
    image_input.change(fn=predict, inputs=[image_input], outputs=[label_output, result_html])

    # ── Footer ────────────────────────────────────────────────────
    gr.HTML("""
<div id="footer">
  🔬 <strong style="color:#374151;">NeuroScan</strong> &nbsp;·&nbsp;
  EfficientNet-B3 fine-tuned on Figshare + Kaggle Brain Tumor datasets &nbsp;·&nbsp;
  <a href="https://huggingface.co/S-4-G-4-R/brain-tumor-efficientnet-b3" target="_blank" rel="noopener noreferrer">
    🤗 Model on Hugging Face
  </a>
  <br>
  ⚠️ This tool is intended for research and educational purposes only.
  It is <strong>not</strong> a substitute for clinical diagnosis.
</div>
""")
if __name__ == "__main__":
    # Start the Gradio server for the `demo` Blocks app when run as a script.
    demo.launch()