"""Streamlit front-end for the brain MRI tumor classifier.

The user uploads an MRI (or picks a bundled sample). The app runs the PyTorch
model, shows per-class probabilities with a Grad-CAM heatmap, and offers a PDF
report for download.

NOTE(review): the HTML/CSS markup in this module was lost in the copy under
review (only the visible text survived). The markup and CSS below are a minimal
reconstruction built around that text and the CSS variables the code
references (--good, --warn, --bad, --accent, --accent-2). Restore the original
styling if it is available.
"""

from __future__ import annotations

import os
from io import BytesIO
from pathlib import Path

import streamlit as st
import torch
from PIL import Image
from torchvision import transforms

from src.model import get_gradcam_target_layer, load_model
from src.utils import (
    LABELS,
    compute_gradcam,
    is_likely_mri,
    make_pdf_report,
    predict_with_probs,
)

ROOT = Path(__file__).resolve().parent
MODEL_PATH = ROOT / "models" / "model_38"
SAMPLE_DIR = ROOT / "sample"
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
MODEL_VERSION = "model_38"

GITHUB_URL = "https://github.com/HalemoGPA/BrainMRI-Tumor-Classifier-Pytorch"
DATASET_URL = "https://www.kaggle.com/datasets/masoudnickparvar/brain-tumor-mri-dataset"

# Probability cut-offs shared by confidence_color() and confidence_label(), so
# the badge colour and its wording can never drift apart.
HIGH_CONFIDENCE = 0.85
MEDIUM_CONFIDENCE = 0.6

SAMPLE_SUFFIXES = {".jpg", ".jpeg", ".png"}

CLASS_DESCRIPTIONS = {
    "No Tumor": "Healthy brain tissue. No tumor detected in the scan.",
    "Pituitary": "Tumor in the pituitary gland (base of the brain). Often benign.",
    "Glioma": "Tumor arising from glial cells. Can be aggressive; needs urgent review.",
    "Meningioma": "Tumor of the meninges (brain's protective layers). Usually slow-growing.",
}

st.set_page_config(
    page_title="Brain MRI Tumor Classifier",
    page_icon="🧠",
    layout="wide",
    initial_sidebar_state="expanded",
    menu_items={
        "Get Help": GITHUB_URL,
        "Report a bug": f"{GITHUB_URL}/issues",
        "About": "Brain MRI Tumor Classifier - PyTorch CNN with Grad-CAM explainability.",
    },
)

# Injected once per run by main(). Defines the CSS variables used by
# confidence_color() and render_probability_bars(), plus the classes used by
# the HTML snippets below.
CSS = """
<style>
:root {
  --accent: #7c5cff;
  --accent-2: #00c2ff;
  --good: #22c55e;
  --warn: #f59e0b;
  --bad: #ef4444;
}
.hero p { opacity: 0.8; margin-top: 0; }
.disclaimer, .warn-box { border-radius: 8px; padding: 0.75rem 1rem; margin: 0.5rem 0 1rem 0; }
.disclaimer { background: rgba(245, 158, 11, 0.12); border-left: 4px solid var(--warn); }
.warn-box { background: rgba(239, 68, 68, 0.12); border-left: 4px solid var(--bad); }
.step-title { display: flex; align-items: center; gap: 0.6rem; font-size: 1.25rem; font-weight: 600; margin: 1.5rem 0 0.75rem 0; }
.step-num { display: inline-flex; align-items: center; justify-content: center; width: 1.8rem; height: 1.8rem; border-radius: 50%; background: linear-gradient(135deg, var(--accent), var(--accent-2)); color: #fff; font-size: 0.95rem; }
.pred-card { border-radius: 12px; padding: 1rem 1.25rem; background: rgba(255, 255, 255, 0.04); border: 1px solid rgba(255, 255, 255, 0.1); margin-bottom: 1rem; }
.pred-kicker { text-transform: uppercase; font-size: 0.75rem; letter-spacing: 0.08em; opacity: 0.7; }
.pred-label { font-size: 2rem; font-weight: 700; margin: 0.2rem 0; }
.pred-conf { display: flex; justify-content: space-between; font-weight: 600; }
.pred-desc { opacity: 0.85; margin-top: 0.5rem; }
.prob-row { display: grid; grid-template-columns: 7.5rem 1fr 3.5rem; align-items: center; gap: 0.5rem; margin: 0.3rem 0; }
.prob-track { height: 0.6rem; border-radius: 999px; background: rgba(255, 255, 255, 0.08); overflow: hidden; }
.prob-fill { height: 100%; border-radius: 999px; }
.prob-pct { text-align: right; font-variant-numeric: tabular-nums; }
</style>
"""

TRANSFORM = transforms.Compose(
    [
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
    ]
)


@st.cache_resource(show_spinner="Loading model…")
def get_model():
    """Load the classifier once per server process and reuse it across sessions."""
    return load_model(str(MODEL_PATH), DEVICE)


@st.cache_data(show_spinner=False)
def list_sample_images() -> list[tuple[str, bytes]]:
    """Return ``(file name, raw bytes)`` for each bundled sample image, sorted by name.

    Returns an empty list when the sample directory does not exist, instead of
    letting ``Path.iterdir`` raise and take down the page.
    """
    if not SAMPLE_DIR.is_dir():
        return []
    return [
        (path.name, path.read_bytes())
        for path in sorted(SAMPLE_DIR.iterdir())
        if path.is_file() and path.suffix.lower() in SAMPLE_SUFFIXES
    ]


def preprocess(image: Image.Image) -> torch.Tensor:
    """Resize/normalise ``image`` with TRANSFORM and add a leading batch dimension."""
    return TRANSFORM(image).unsqueeze(0)


def confidence_color(p: float) -> str:
    """Map a probability to the CSS colour variable used for its badge."""
    if p >= HIGH_CONFIDENCE:
        return "var(--good)"
    if p >= MEDIUM_CONFIDENCE:
        return "var(--warn)"
    return "var(--bad)"


def confidence_label(p: float) -> str:
    """Map a probability to a human-readable confidence bucket."""
    if p >= HIGH_CONFIDENCE:
        return "High confidence"
    if p >= MEDIUM_CONFIDENCE:
        return "Medium confidence"
    return "Low confidence"


def _uploader_key() -> str:
    """Widget key for the current file uploader.

    The key includes a generation counter. Bumping the counter (see
    _reset_selection) makes Streamlit create a fresh, empty uploader widget.
    """
    return f"mri_uploader_{st.session_state.get('uploader_gen', 0)}"


def _reset_selection() -> None:
    """on_click callback for "Start over": forget the image and empty the uploader."""
    for key in ("mri_bytes", "mri_source"):
        st.session_state.pop(key, None)
    st.session_state["uploader_gen"] = st.session_state.get("uploader_gen", 0) + 1


def _store_upload(widget_key: str) -> None:
    """on_change callback: copy a newly uploaded file into session state.

    Runs only when the uploader's value changes. A file still sitting in the
    uploader therefore no longer overwrites a sample chosen later, or survives
    "Start over", on unrelated reruns. Removing the upload (value ``None``)
    keeps the current image, as before.
    """
    uploaded = st.session_state.get(widget_key)
    if uploaded is not None:
        st.session_state["mri_bytes"] = uploaded.getvalue()
        st.session_state["mri_source"] = uploaded.name


def _select_sample(data: bytes, source: str) -> None:
    """on_click callback: make a bundled sample the current image."""
    st.session_state["mri_bytes"] = data
    st.session_state["mri_source"] = source


def render_sidebar() -> None:
    """Render model info, class list, links and the "Start over" button."""
    with st.sidebar:
        st.markdown("## 🧠 Brain MRI Classifier")
        st.caption("PyTorch CNN with Grad-CAM explainability.")
        st.markdown("---")

        st.markdown("#### 🏗️ Model")
        st.markdown(
            f"`{MODEL_VERSION}` · 4 conv blocks → FC(512) → FC(num_classes)\n\n"
            f"**Input:** 224×224 RGB · **Device:** `{DEVICE}`"
        )

        st.markdown("#### 🏷️ Classes")
        for i, label in enumerate(LABELS):
            st.markdown(f"`{i}` · {label}")

        st.markdown("#### 📚 Resources")
        st.markdown(
            f"- [Dataset (Kaggle)]({DATASET_URL})\n"
            f"- [GitHub repository]({GITHUB_URL})"
        )

        st.markdown("---")
        # A callback runs before the rerun, so the rest of this run already
        # sees the cleared state (no explicit st.rerun() needed).
        st.button("🔄 Start over", use_container_width=True, on_click=_reset_selection)


def render_step_title(num: int, title: str) -> None:
    """Render a numbered section heading."""
    st.markdown(
        f'<div class="step-title"><span class="step-num">{num}</span>{title}</div>',
        unsafe_allow_html=True,
    )


def render_input_section() -> Image.Image | None:
    """Render the upload/sample pickers and return the selected image, if any.

    Returns the current image from session state as RGB. Returns ``None`` when
    nothing is selected, or when the stored bytes cannot be decoded (an error
    is shown in that case).
    """
    render_step_title(1, "Choose an MRI image")
    col_upload, col_samples = st.columns([1, 1.4], gap="large")

    with col_upload:
        st.markdown("**Upload your scan**")
        uploader_key = _uploader_key()
        st.file_uploader(
            "Upload a brain MRI",
            type=["jpg", "jpeg", "png"],
            label_visibility="collapsed",
            accept_multiple_files=False,
            key=uploader_key,
            on_change=_store_upload,
            args=(uploader_key,),
        )

    with col_samples:
        st.markdown("**Or try a sample**")
        samples = list_sample_images()
        if not samples:
            # st.columns(0) raises, so show a hint instead of crashing.
            st.caption("No sample images found.")
        else:
            sample_cols = st.columns(len(samples))
            for idx, (col, (name, data)) in enumerate(zip(sample_cols, samples), start=1):
                with col:
                    st.image(data, use_container_width=True)
                    st.button(
                        f"▶ Sample {idx}",
                        key=f"use_{name}",
                        use_container_width=True,
                        type="primary",
                        on_click=_select_sample,
                        args=(data, f"Sample {idx}"),
                    )

    data = st.session_state.get("mri_bytes")
    if not data:
        return None
    try:
        return Image.open(BytesIO(data)).convert("RGB")
    except OSError:  # PIL's UnidentifiedImageError subclasses OSError
        st.error("Could not read this file as an image. Please upload a valid JPG or PNG.")
        return None


def render_probability_bars(probs, top_idx: int) -> None:
    """Draw one horizontal bar per class (LABELS order), highlighting ``top_idx``.

    Each ``probs[i]`` is multiplied by 100 for display, so entries are expected
    to be probabilities in [0, 1].
    """
    rows = []
    for i, label in enumerate(LABELS):
        pct = float(probs[i]) * 100
        if i == top_idx:
            color = "linear-gradient(90deg, var(--accent), var(--accent-2))"
        else:
            color = "rgba(255,255,255,0.18)"
        # Keep a sliver visible so near-zero classes still render a bar.
        width = max(pct, 0.5)
        rows.append(
            '<div class="prob-row">'
            f'<span class="prob-label">{label}</span>'
            '<div class="prob-track">'
            f'<div class="prob-fill" style="width:{width:.2f}%;background:{color};"></div>'
            "</div>"
            f'<span class="prob-pct">{pct:.1f}%</span>'
            "</div>"
        )
    st.markdown("\n".join(rows), unsafe_allow_html=True)


@st.cache_data(show_spinner="Analyzing scan…", max_entries=8)
def _analyze(mri_bytes: bytes):
    """Decode, classify and explain an image; cached per image bytes.

    Returns ``(image, top_idx, probs, heatmap)``. ``image`` is the decoded RGB
    image that was fed to both the model and Grad-CAM.
    """
    image = Image.open(BytesIO(mri_bytes)).convert("RGB")
    model = get_model()
    tensor = preprocess(image).to(DEVICE)
    top_idx, probs = predict_with_probs(model, tensor, DEVICE)
    target_layer = get_gradcam_target_layer(model)
    heatmap = compute_gradcam(model, tensor, target_layer, top_idx, image)
    return image, top_idx, probs, heatmap


@st.cache_data(show_spinner=False, max_entries=8)
def _build_pdf(mri_bytes: bytes, predicted_label: str, probabilities_tuple: tuple[float, ...]) -> bytes:
    """Build the PDF report for an analysed image.

    Probabilities arrive as a tuple of floats, a hashable and cache-friendly
    argument, and are converted back to an array for the report.
    """
    import numpy as np

    image, _, _, heatmap = _analyze(mri_bytes)
    return make_pdf_report(
        original_image=image,
        heatmap_image=heatmap,
        predicted_label=predicted_label,
        probabilities=np.asarray(probabilities_tuple),
        model_version=MODEL_VERSION,
    )


def render_results(image: Image.Image) -> None:
    """Show the prediction card, probability bars, images and the PDF download.

    ``image`` is kept for signature compatibility. It is replaced by the image
    returned from ``_analyze``, so the displayed image is exactly the one the
    model and Grad-CAM processed.
    """
    mri_bytes = st.session_state["mri_bytes"]
    image, top_idx, probs, heatmap = _analyze(mri_bytes)
    label = LABELS[top_idx]
    confidence = float(probs[top_idx])
    conf_pct = confidence * 100

    render_step_title(2, "Result")

    is_mri, _diag = is_likely_mri(image)
    if not is_mri:
        st.markdown(
            '<div class="warn-box">⚠️ This image does not look like a brain MRI. '
            "The model was trained on grayscale MRI scans only and will still output "
            "a confident-looking prediction for any image you give it. "
            "Treat the result below as not meaningful.</div>",
            unsafe_allow_html=True,
        )

    col_pred, col_imgs = st.columns([1, 1.2], gap="large")
    with col_pred:
        st.markdown(
            '<div class="pred-card">'
            '<div class="pred-kicker">Prediction</div>'
            f'<div class="pred-label">{label}</div>'
            f'<div class="pred-conf" style="color:{confidence_color(confidence)};">'
            f"<span>{confidence_label(confidence)}</span>"
            f"<span>{conf_pct:.1f}%</span>"
            "</div>"
            f'<div class="pred-desc">{CLASS_DESCRIPTIONS[label]}</div>'
            "</div>",
            unsafe_allow_html=True,
        )
        st.markdown("**Per-class probability**")
        render_probability_bars(probs, top_idx)

    with col_imgs:
        img_col1, img_col2 = st.columns(2)
        with img_col1:
            st.image(
                image,
                caption=f"Uploaded · {st.session_state.get('mri_source', '')}",
                use_container_width=True,
            )
        with img_col2:
            st.image(
                heatmap,
                caption="Grad-CAM · model focus",
                use_container_width=True,
            )

    render_step_title(3, "Download report")
    pdf_bytes = _build_pdf(mri_bytes, label, tuple(float(p) for p in probs))
    cdl, cdr = st.columns([1, 3])
    with cdl:
        st.download_button(
            "📄 Download PDF report",
            data=pdf_bytes,
            file_name=f"mri_report_{label.replace(' ', '_').lower()}.pdf",
            mime="application/pdf",
            use_container_width=True,
        )
    with cdr:
        st.caption(
            "The PDF includes the uploaded image, Grad-CAM heatmap, prediction, "
            "per-class probabilities, timestamp, and the disclaimer."
        )


def main() -> None:
    """Render the full page: styles, sidebar, header, input and (if any) results."""
    st.markdown(CSS, unsafe_allow_html=True)
    render_sidebar()

    st.markdown(
        '<div class="hero">'
        "<h1>🧠 Brain MRI Tumor Classifier</h1>"
        "<p>Upload a brain MRI scan and get a tumor-type prediction explained with Grad-CAM.</p>"
        "</div>",
        unsafe_allow_html=True,
    )
    st.markdown(
        '<div class="disclaimer"><b>⚠️ Educational and research use only</b> '
        "- this app is not a medical device. Do not use these predictions for "
        "diagnosis or treatment decisions.</div>",
        unsafe_allow_html=True,
    )

    if not MODEL_PATH.exists():
        st.error(f"Model file not found at `{MODEL_PATH}`. Cannot run the app.")
        st.stop()

    image = render_input_section()
    if image is None:
        st.info("👆 Upload an MRI or click a sample to begin.")
        return

    render_results(image)


if __name__ == "__main__":
    main()