Halemo's picture
Cache analyze + PDF (fix vibrate-on-resize)
367d7c1 verified
from __future__ import annotations
import os
from io import BytesIO
from pathlib import Path
import streamlit as st
import torch
from PIL import Image
from torchvision import transforms
from src.model import get_gradcam_target_layer, load_model
from src.utils import (
LABELS,
compute_gradcam,
is_likely_mri,
make_pdf_report,
predict_with_probs,
)
# Filesystem locations are resolved relative to the directory holding this file.
ROOT = Path(__file__).resolve().parent
# Location of the saved model that get_model() passes to load_model().
MODEL_PATH = ROOT / "models" / "model_38"
# Directory scanned by list_sample_images() for the bundled example images.
SAMPLE_DIR = ROOT / "sample"
# Use the GPU when CUDA is available, otherwise fall back to the CPU.
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
# Version string shown in the sidebar and passed to make_pdf_report().
MODEL_VERSION = "model_38"
# External links used by the page menu and the sidebar "Resources" list.
GITHUB_URL = "https://github.com/HalemoGPA/BrainMRI-Tumor-Classifier-Pytorch"
DATASET_URL = "https://www.kaggle.com/datasets/masoudnickparvar/brain-tumor-mri-dataset"
# Short explanatory text shown in the prediction card, looked up by the
# predicted label in render_results().
# NOTE(review): the keys presumably mirror the entries of LABELS from
# src.utils - confirm they stay in sync.
CLASS_DESCRIPTIONS = {
    "No Tumor": "Healthy brain tissue. No tumor detected in the scan.",
    "Pituitary": "Tumor in the pituitary gland (base of the brain). Often benign.",
    "Glioma": "Tumor arising from glial cells. Can be aggressive; needs urgent review.",
    "Meningioma": "Tumor of the meninges (brain's protective layers). Usually slow-growing.",
}
# Page-level Streamlit settings: browser title and icon, wide layout, sidebar
# expanded on load, and the links shown in the page menu.
# NOTE(review): Streamlit documents that set_page_config should run before
# other st.* calls - keep this at module top level, ahead of main().
st.set_page_config(
    page_title="Brain MRI Tumor Classifier",
    page_icon="🧠",
    layout="wide",
    initial_sidebar_state="expanded",
    menu_items={
        "Get Help": GITHUB_URL,
        "Report a bug": f"{GITHUB_URL}/issues",
        "About": "Brain MRI Tumor Classifier - PyTorch CNN with Grad-CAM explainability.",
    },
)
# Custom stylesheet that main() injects with st.markdown(..., unsafe_allow_html=True).
# It defines the colour variables plus the hero, banner, section-title,
# prediction-card and probability-bar classes referenced by the HTML snippets
# that the render_* helpers emit. The text is a runtime string: edit with care.
CSS = """
<style>
:root {
--accent: #7c3aed;
--accent-2: #06b6d4;
--good: #10b981;
--warn: #f59e0b;
--bad: #ef4444;
}
.block-container { padding-top: 1.2rem; max-width: 1280px; }
header[data-testid="stHeader"] { background: transparent; }
.hero {
background: linear-gradient(120deg, rgba(124,58,237,0.18), rgba(6,182,212,0.12));
border: 1px solid rgba(124,58,237,0.35);
border-radius: 16px;
padding: 22px 28px;
margin-bottom: 18px;
}
.hero h1 {
font-size: 2.2rem;
margin: 0 0 4px 0;
background: linear-gradient(90deg, #c4b5fd, #67e8f9);
-webkit-background-clip: text;
-webkit-text-fill-color: transparent;
background-clip: text;
}
.hero p { margin: 0; color: #b8c0d6; }
.disclaimer-banner {
background: rgba(245, 158, 11, 0.08);
border-left: 3px solid var(--warn);
padding: 8px 14px;
border-radius: 6px;
font-size: 0.86rem;
color: #fcd34d;
margin-bottom: 18px;
}
.ood-banner {
background: rgba(239, 68, 68, 0.10);
border-left: 3px solid var(--bad);
padding: 12px 16px;
border-radius: 8px;
font-size: 0.92rem;
color: #fecaca;
margin: 8px 0 18px 0;
line-height: 1.5;
}
.ood-banner em { color: #fca5a5; font-style: normal; font-weight: 600; }
.section-title {
display: flex; align-items: center; gap: 10px;
font-size: 1.1rem; font-weight: 600; color: #e6e6e6;
margin: 18px 0 10px 0;
}
.section-title .step {
display: inline-flex; align-items: center; justify-content: center;
width: 26px; height: 26px;
background: linear-gradient(135deg, var(--accent), var(--accent-2));
color: white; border-radius: 50%; font-size: 0.82rem; font-weight: 700;
}
.pred-card {
background: rgba(124, 58, 237, 0.08);
border: 1px solid rgba(124, 58, 237, 0.35);
border-radius: 14px;
padding: 18px 22px;
margin-bottom: 14px;
}
.pred-card .label-tag {
display: inline-block;
background: linear-gradient(90deg, var(--accent), var(--accent-2));
color: white; padding: 4px 12px; border-radius: 999px;
font-size: 0.78rem; letter-spacing: 0.04em; text-transform: uppercase;
font-weight: 700;
margin-bottom: 8px;
}
.pred-card h2 { margin: 4px 0 10px 0; font-size: 2rem; }
.pred-card .conf { font-size: 0.92rem; color: #b8c0d6; margin-bottom: 6px; }
.pred-card .conf-val { font-size: 2.4rem; font-weight: 700; line-height: 1; }
.pred-card .desc { color: #c8cfde; font-size: 0.9rem; margin-top: 12px; line-height: 1.5; }
.prob-row {
display: flex; align-items: center; gap: 12px; margin: 6px 0;
font-size: 0.92rem;
}
.prob-row .name { width: 110px; color: #d6dbeb; }
.prob-row .bar {
flex: 1; background: rgba(255,255,255,0.06);
border-radius: 999px; height: 10px; overflow: hidden;
}
.prob-row .bar > div {
height: 100%; border-radius: 999px;
transition: width 0.6s cubic-bezier(0.22, 1, 0.36, 1);
}
.prob-row .pct { width: 56px; text-align: right; color: #b8c0d6; font-variant-numeric: tabular-nums; }
div[data-testid="stImage"] img {
border-radius: 12px;
box-shadow: 0 8px 24px rgba(0,0,0,0.4);
}
.sample-card { text-align: center; }
.sample-card img { border-radius: 10px; }
div[data-testid="stFileUploader"] section {
border: 2px dashed rgba(124, 58, 237, 0.4);
border-radius: 12px;
background: rgba(124, 58, 237, 0.04);
transition: border-color 0.2s, background 0.2s;
}
div[data-testid="stFileUploader"] section:hover {
border-color: var(--accent);
background: rgba(124, 58, 237, 0.08);
}
div.stButton > button[kind="primary"] {
background: linear-gradient(90deg, var(--accent), var(--accent-2));
border: none;
}
div.stDownloadButton > button {
background: linear-gradient(90deg, var(--accent), var(--accent-2));
border: none; color: white; font-weight: 600;
}
footer { visibility: hidden; }
</style>
"""
# Preprocessing pipeline applied by preprocess(): resize to 224x224, convert
# to a tensor, then normalise each channel with the given mean/std.
# NOTE(review): the mean/std values are the commonly used ImageNet channel
# statistics - presumably the same ones used at training time; confirm.
TRANSFORM = transforms.Compose(
    [
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
    ]
)
@st.cache_resource(show_spinner="Loading model…")
def get_model():
    """Load the classifier stored at ``MODEL_PATH`` for ``DEVICE``.

    The function is decorated with ``st.cache_resource``, so Streamlit reuses
    the loaded object instead of calling ``load_model`` on every rerun.

    Returns:
        Whatever ``load_model`` returns for the configured path and device.
    """
    weights_location = str(MODEL_PATH)
    classifier = load_model(weights_location, DEVICE)
    return classifier
@st.cache_data(show_spinner=False)
def list_sample_images() -> list[tuple[str, bytes]]:
    """Return the bundled sample images as ``(file name, raw bytes)`` pairs.

    Only regular files in ``SAMPLE_DIR`` whose suffix is ``.jpg``, ``.jpeg`` or
    ``.png`` (case-insensitive) are included, sorted by path.

    Returns:
        A list of ``(name, data)`` tuples; empty when ``SAMPLE_DIR`` does not
        exist or contains no matching files.
    """
    # A deployment without the sample folder yields an empty list rather than
    # an exception from iterdir().
    if not SAMPLE_DIR.is_dir():
        return []
    allowed_suffixes = {".jpg", ".jpeg", ".png"}
    return [
        (path.name, path.read_bytes())
        for path in sorted(SAMPLE_DIR.iterdir())
        # is_file() skips directories that happen to carry an image suffix,
        # which read_bytes() could not read.
        if path.is_file() and path.suffix.lower() in allowed_suffixes
    ]
def preprocess(image: Image.Image) -> torch.Tensor:
    """Run *image* through ``TRANSFORM`` and prepend a batch axis.

    Args:
        image: The PIL image to convert.

    Returns:
        The transformed tensor with an extra leading dimension of size 1.
    """
    transformed = TRANSFORM(image)
    batched = transformed.unsqueeze(0)
    return batched
def confidence_color(p: float) -> str:
    """Map a probability to the CSS colour variable used for its display.

    Args:
        p: Probability of the predicted class.

    Returns:
        ``"var(--good)"`` for ``p >= 0.85``, ``"var(--warn)"`` for
        ``p >= 0.6``, otherwise ``"var(--bad)"``.
    """
    # Tiers are checked from the highest floor down; the first match wins.
    tiers = ((0.85, "var(--good)"), (0.6, "var(--warn)"))
    for floor, css_value in tiers:
        if p >= floor:
            return css_value
    return "var(--bad)"
def confidence_label(p: float) -> str:
    """Describe a probability as a high/medium/low confidence phrase.

    Args:
        p: Probability of the predicted class.

    Returns:
        ``"High confidence"`` for ``p >= 0.85``, ``"Medium confidence"`` for
        ``p >= 0.6``, otherwise ``"Low confidence"``.
    """
    bands = (("High confidence", 0.85), ("Medium confidence", 0.6))
    # The first band whose floor is reached wins; nothing reached means "low".
    return next(
        (phrase for phrase, floor in bands if p >= floor),
        "Low confidence",
    )
def render_sidebar() -> None:
    """Draw the sidebar: app title, model summary, class list, links, reset.

    Clicking "Start over" removes ``mri_bytes`` and ``mri_source`` from
    ``st.session_state`` and triggers ``st.rerun()``.
    """
    model_summary = (
        f"`{MODEL_VERSION}` · 4 conv blocks → FC(512) → FC(num_classes)\n\n"
        f"**Input:** 224×224 RGB · **Device:** `{DEVICE}`"
    )
    resource_links = (
        f"- [Dataset (Kaggle)]({DATASET_URL})\n"
        f"- [GitHub repository]({GITHUB_URL})"
    )

    with st.sidebar:
        st.markdown("## 🧠 Brain MRI Classifier")
        st.caption("PyTorch CNN with Grad-CAM explainability.")
        st.markdown("---")

        st.markdown("#### 🏗️ Model")
        st.markdown(model_summary)

        st.markdown("#### 🏷️ Classes")
        for class_index, class_name in enumerate(LABELS):
            st.markdown(f"`{class_index}` · {class_name}")

        st.markdown("#### 📚 Resources")
        st.markdown(resource_links)
        st.markdown("---")

        restart_clicked = st.button("🔄 Start over", use_container_width=True)
        if not restart_clicked:
            return
        # Drop the currently selected scan, then rerun the script.
        st.session_state.pop("mri_bytes", None)
        st.session_state.pop("mri_source", None)
        st.rerun()
def render_step_title(num: int, title: str) -> None:
    """Render a numbered section heading as raw HTML.

    Args:
        num: Step number shown inside the ``step`` badge.
        title: Heading text placed after the badge (inserted unescaped).
    """
    badge_html = f'<span class="step">{num}</span>'
    heading_html = f'<div class="section-title">{badge_html}{title}</div>'
    st.markdown(heading_html, unsafe_allow_html=True)
def render_input_section() -> Image.Image | None:
    """Render step 1 (upload or sample picker) and return the chosen image.

    The selected scan is kept in ``st.session_state`` under ``mri_bytes`` (raw
    file content) and ``mri_source`` (display name). An upload is written to
    the session only once, when it first appears; ``mri_upload_token``
    remembers which upload was already consumed. Without that, the file still
    held by the uploader widget would overwrite a later sample choice, or undo
    a reset, on every rerun.

    Returns:
        The selected image converted to RGB, or ``None`` when nothing is
        selected or the selected bytes cannot be decoded as an image.
    """
    render_step_title(1, "Choose an MRI image")
    col_upload, col_samples = st.columns([1, 1.4], gap="large")

    with col_upload:
        st.markdown("**Upload your scan**")
        uploaded = st.file_uploader(
            "Upload a brain MRI",
            type=["jpg", "jpeg", "png"],
            label_visibility="collapsed",
            accept_multiple_files=False,
        )
        if uploaded is None:
            # Uploader is empty: forget the last consumed upload so the same
            # file can be chosen again later.
            st.session_state.pop("mri_upload_token", None)
        else:
            # Prefer Streamlit's per-upload id when the attribute exists;
            # otherwise fall back to (name, size) as an identity.
            upload_token = getattr(uploaded, "file_id", None) or (
                uploaded.name,
                uploaded.size,
            )
            if st.session_state.get("mri_upload_token") != upload_token:
                st.session_state["mri_upload_token"] = upload_token
                st.session_state["mri_bytes"] = uploaded.getvalue()
                st.session_state["mri_source"] = uploaded.name

    with col_samples:
        st.markdown("**Or try a sample**")
        # Guard the directory here as well so a missing sample folder cannot
        # abort the page.
        samples = list_sample_images() if SAMPLE_DIR.is_dir() else []
        if not samples:
            # st.columns() rejects a count of zero, so skip the grid entirely.
            st.caption("No sample images are available.")
        else:
            sample_cols = st.columns(len(samples))
            for idx, (col, (name, sample_bytes)) in enumerate(
                zip(sample_cols, samples), start=1
            ):
                with col:
                    st.image(sample_bytes, use_container_width=True)
                    if st.button(
                        f"▶ Sample {idx}",
                        key=f"use_{name}",
                        use_container_width=True,
                        type="primary",
                    ):
                        st.session_state["mri_bytes"] = sample_bytes
                        st.session_state["mri_source"] = f"Sample {idx}"

    data = st.session_state.get("mri_bytes")
    if not data:
        return None
    try:
        return Image.open(BytesIO(data)).convert("RGB")
    except OSError:
        # Pillow raises OSError subclasses for unreadable or truncated files.
        st.error("Could not read this file as an image. Please choose a JPG or PNG scan.")
        return None
def render_probability_bars(probs, top_idx: int) -> None:
    """Render one horizontal bar per class showing its probability.

    Args:
        probs: Indexable collection with one probability per entry of
            ``LABELS``; each value is passed through ``float()``.
        top_idx: Index of the predicted class, whose bar gets the accent
            gradient while the others use a muted fill.
    """
    highlight_fill = "linear-gradient(90deg, var(--accent), var(--accent-2))"
    muted_fill = "rgba(255,255,255,0.18)"

    def _row_html(position: int, class_name: str) -> str:
        percent = float(probs[position]) * 100
        fill = highlight_fill if position == top_idx else muted_fill
        # Keep a sliver visible even for near-zero probabilities.
        bar_width = max(percent, 0.5)
        return (
            f'<div class="prob-row">'
            f'<div class="name">{class_name}</div>'
            f'<div class="bar"><div style="width:{bar_width}%; background:{fill};"></div></div>'
            f'<div class="pct">{percent:.1f}%</div>'
            f"</div>"
        )

    markup = "\n".join(
        _row_html(position, class_name)
        for position, class_name in enumerate(LABELS)
    )
    st.markdown(markup, unsafe_allow_html=True)
@st.cache_data(show_spinner="Analyzing scan…", max_entries=8)
def _analyze(mri_bytes: bytes):
    """Decode a scan, classify it and compute its Grad-CAM output.

    Results are cached by ``st.cache_data`` (at most 8 entries), keyed on the
    raw bytes.

    Args:
        mri_bytes: Encoded image file content.

    Returns:
        A 4-tuple ``(image, top_idx, probs, heatmap)``: the decoded RGB image,
        the two values returned by ``predict_with_probs``, and the result of
        ``compute_gradcam`` for the predicted class.
    """
    buffer = BytesIO(mri_bytes)
    scan = Image.open(buffer).convert("RGB")
    net = get_model()
    batch = preprocess(scan).to(DEVICE)
    predicted_idx, class_probs = predict_with_probs(net, batch, DEVICE)
    cam_layer = get_gradcam_target_layer(net)
    overlay = compute_gradcam(net, batch, cam_layer, predicted_idx, scan)
    return scan, predicted_idx, class_probs, overlay
@st.cache_data(show_spinner=False, max_entries=8)
def _build_pdf(mri_bytes: bytes, predicted_label: str, probabilities_tuple: tuple) -> bytes:
    """Build the PDF report for a scan, cached on all three arguments.

    Args:
        mri_bytes: Encoded image file content; passed to ``_analyze`` to get
            the decoded image and the heatmap.
        predicted_label: Label text to print in the report.
        probabilities_tuple: Per-class probabilities as a tuple; converted to
            a NumPy array before being handed to ``make_pdf_report``.

    Returns:
        The value returned by ``make_pdf_report`` (annotated as ``bytes``).
    """
    import numpy as np

    scan, _top_idx, _probs, overlay = _analyze(mri_bytes)
    probability_array = np.asarray(probabilities_tuple)
    report = make_pdf_report(
        original_image=scan,
        heatmap_image=overlay,
        predicted_label=predicted_label,
        probabilities=probability_array,
        model_version=MODEL_VERSION,
    )
    return report
def render_results(image: Image.Image) -> None:
    """Render step 2 (prediction, probabilities, images) and step 3 (PDF).

    Args:
        image: Kept for call compatibility. The image actually displayed is
            the one returned by the cached ``_analyze`` call on
            ``st.session_state["mri_bytes"]``, which replaces this argument.

    Raises:
        KeyError: If ``mri_bytes`` is not present in ``st.session_state``.
    """
    mri_bytes = st.session_state["mri_bytes"]
    image, top_idx, probs, heatmap = _analyze(mri_bytes)
    label = LABELS[top_idx]
    confidence = float(probs[top_idx])
    conf_pct = confidence * 100
    # A label without a description entry renders an empty blurb instead of
    # raising KeyError.
    description = CLASS_DESCRIPTIONS.get(label, "")

    render_step_title(2, "Result")

    # Only the boolean verdict is used; the second value is not displayed.
    is_mri, _diag = is_likely_mri(image)
    if not is_mri:
        st.markdown(
            '<div class="ood-banner">⚠️ <strong>This image does not look like a brain MRI.</strong> '
            "The model was trained on grayscale MRI scans only and will still output "
            "a confident-looking prediction for any image you give it. "
            "Treat the result below as <em>not meaningful</em>.</div>",
            unsafe_allow_html=True,
        )

    # Assembled line by line so the emitted markup has no leading indentation.
    card_html = (
        "\n"
        '<div class="pred-card">\n'
        '<span class="label-tag">Prediction</span>\n'
        f"<h2>{label}</h2>\n"
        f'<div class="conf">{confidence_label(confidence)}</div>\n'
        f'<div class="conf-val" style="color:{confidence_color(confidence)};">{conf_pct:.1f}%</div>\n'
        f'<div class="desc">{description}</div>\n'
        "</div>\n"
    )

    col_pred, col_imgs = st.columns([1, 1.2], gap="large")
    with col_pred:
        st.markdown(card_html, unsafe_allow_html=True)
        st.markdown("**Per-class probability**")
        render_probability_bars(probs, top_idx)
    with col_imgs:
        img_col1, img_col2 = st.columns(2)
        with img_col1:
            st.image(
                image,
                caption=f"Uploaded · {st.session_state.get('mri_source', '')}",
                use_container_width=True,
            )
        with img_col2:
            st.image(
                heatmap,
                caption="Grad-CAM · model focus",
                use_container_width=True,
            )

    render_step_title(3, "Download report")
    # Probabilities are passed as a tuple of floats; _build_pdf is cached on
    # its arguments.
    pdf_bytes = _build_pdf(mri_bytes, label, tuple(float(p) for p in probs))
    cdl, cdr = st.columns([1, 3])
    with cdl:
        st.download_button(
            "📄 Download PDF report",
            data=pdf_bytes,
            file_name=f"mri_report_{label.replace(' ', '_').lower()}.pdf",
            mime="application/pdf",
            use_container_width=True,
        )
    with cdr:
        st.caption(
            "The PDF includes the uploaded image, Grad-CAM heatmap, prediction, "
            "per-class probabilities, timestamp, and the disclaimer."
        )
def main() -> None:
    """Compose the page: styles, sidebar, header, input step and results.

    Stops the script with an error message when ``MODEL_PATH`` does not exist.
    When no image is selected yet, shows a hint instead of the results.
    """
    # Assembled line by line so the emitted markup has no leading indentation.
    hero_html = (
        "\n"
        '<div class="hero">\n'
        "<h1>🧠 Brain MRI Tumor Classifier</h1>\n"
        "<p>Upload a brain MRI scan and get a tumor-type prediction explained with Grad-CAM.</p>\n"
        "</div>\n"
    )
    disclaimer_html = (
        '<div class="disclaimer-banner">⚠️ <strong>Educational and research use only</strong> '
        "- this app is not a medical device. Do not use these predictions for "
        "diagnosis or treatment decisions.</div>"
    )

    st.markdown(CSS, unsafe_allow_html=True)
    render_sidebar()
    st.markdown(hero_html, unsafe_allow_html=True)
    st.markdown(disclaimer_html, unsafe_allow_html=True)

    model_available = MODEL_PATH.exists()
    if not model_available:
        st.error(f"Model file not found at `{MODEL_PATH}`. Cannot run the app.")
        st.stop()

    selected_image = render_input_section()
    if selected_image is not None:
        render_results(selected_image)
    else:
        st.info("👆 Upload an MRI or click a sample to begin.")
# Entry-point guard: build the page when this module is run as the main script.
if __name__ == "__main__":
    main()