---
title: Brain Tumor MRI Classifier
emoji: 🧠
colorFrom: indigo
colorTo: purple
sdk: gradio
sdk_version: 5.29.0
app_file: app.py
pinned: false
license: mit
tags:
  - medical-imaging
  - brain-tumor
  - efficientnet
  - image-classification
  - pytorch
  - mri
---

# Brain Tumor MRI Classifier (EfficientNet-B3)

A fine-tuned **EfficientNet-B3** model for 4-class brain tumor classification from MRI scans, achieving **98.98% validation accuracy** and **0.9896 macro F1**.

## Classes

| Class | Description |
|---|---|
| Glioma | Tumor originating in glial cells of the brain or spine |
| Meningioma | Tumor arising from the meninges surrounding the brain |
| Pituitary Tumor | Tumor in the pituitary gland at the base of the brain |
| No Tumor | No tumor detected in the MRI scan |

## Model

- **Architecture**: EfficientNet-B3 (ImageNet pretrained) with a custom classification head
- **Head**: `Dropout → Linear(1536, 512) → SiLU → Dropout → Linear(512, 4)`
- **Input size**: 300 × 300 pixels
- **Training**: Two-phase. The backbone is frozen for 5 epochs (head LR 1e-3), then the full network is fine-tuned with differential LRs (backbone 1e-4, head 1e-3)
- **Schedule**: Cosine decay with a 3-epoch linear warmup
- **Loss**: Class-weighted cross-entropy

## Weights

The model weights (`model.pt`) are hosted in this repository and downloaded automatically on first run via `huggingface_hub`.
To download manually:

```python
from huggingface_hub import hf_hub_download

ckpt_path = hf_hub_download(
    repo_id="your-hf-username/brain-tumor-efficientnet-b3",
    filename="model.pt",
)
```

## Dataset

Trained on a merged dataset from two sources:

- **Figshare Brain Tumor Dataset**: glioma, meningioma, and pituitary MRI scans
- **Kaggle Brain Tumor MRI Dataset**: a 4-class dataset with glioma, meningioma, pituitary, and no tumor

| Split | Images |
|---|---|
| Train | 8,211 |
| Validation | 2,053 |

## Results

| Metric | Score |
|---|---|
| Accuracy | 0.9898 |
| Macro F1 | 0.9896 |
| Weighted F1 | 0.9898 |

Per-class F1: Glioma 0.9915 · Meningioma 0.9832 · No Tumor 0.9903 · Pituitary 0.9935

## Usage

```python
import torch
import torch.nn as nn
from torchvision import transforms
from torchvision.models import efficientnet_b3
from huggingface_hub import hf_hub_download
from PIL import Image


class EfficientNetClassifier(nn.Module):
    def __init__(self, num_classes=4, dropout=0.4):
        super().__init__()
        # Architecture only; the trained weights come from the checkpoint below.
        self.backbone = efficientnet_b3(weights=None)
        in_features = self.backbone.classifier[1].in_features  # 1536 for B3
        # Replace the 1000-class ImageNet head with the custom 4-class head.
        self.backbone.classifier = nn.Sequential(
            nn.Dropout(p=dropout, inplace=True),
            nn.Linear(in_features, 512),
            nn.SiLU(),
            nn.Dropout(p=dropout / 2),
            nn.Linear(512, num_classes),
        )

    def forward(self, x):
        return self.backbone(x)


# Load
ckpt_path = hf_hub_download("your-hf-username/brain-tumor-efficientnet-b3", "model.pt")
# weights_only=False unpickles arbitrary objects; only load checkpoints you trust.
ckpt = torch.load(ckpt_path, map_location="cpu", weights_only=False)
id_to_label = {int(k): v for k, v in ckpt["id_to_label"].items()}

model = EfficientNetClassifier()
model.load_state_dict(ckpt["model"])
model.eval()

# Infer: resize to the 300 x 300 training resolution, normalize with ImageNet stats
transform = transforms.Compose([
    transforms.Resize((300, 300)),
    transforms.ToTensor(),
    transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
])

img = Image.open("mri_scan.jpg").convert("RGB")
with torch.no_grad():
    probs = torch.softmax(model(transform(img).unsqueeze(0)), dim=-1)[0]
pred = id_to_label[probs.argmax().item()]
print(pred)
```
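The usage snippet prints only the top label. If you also want to see how confident the model is across all four classes, a small helper can rank the softmax outputs. `rank_predictions` is a hypothetical convenience, not part of this repository; it takes plain Python values, so you would pass `probs.tolist()` and the `id_to_label` mapping built in the usage code.

```python
def rank_predictions(probs, id_to_label, k=None):
    """Return (label, probability) pairs sorted from most to least likely.

    probs: sequence of class probabilities indexed by class id
           (e.g. probs.tolist() from the usage snippet).
    id_to_label: dict mapping int class id -> label string.
    k: optional cap on the number of pairs returned.
    """
    if len(probs) != len(id_to_label):
        raise ValueError("probs and id_to_label must cover the same classes")
    ranked = sorted(
        ((id_to_label[i], float(p)) for i, p in enumerate(probs)),
        key=lambda pair: pair[1],
        reverse=True,
    )
    return ranked if k is None else ranked[:k]


if __name__ == "__main__":
    # Hypothetical mapping and probabilities, for illustration only;
    # the real mapping comes from ckpt["id_to_label"].
    labels = {0: "Glioma", 1: "Meningioma", 2: "No Tumor", 3: "Pituitary Tumor"}
    example_probs = [0.05, 0.80, 0.10, 0.05]
    for label, p in rank_predictions(example_probs, labels, k=2):
        print(f"{label:<16} {p:.3f}")
```

Returning pairs rather than printing keeps the helper reusable, for example in a Gradio `Label` output that expects a label-to-confidence mapping (`dict(rank_predictions(...))`).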
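The cosine decay with a 3-epoch linear warmup listed under Model can be written as a learning-rate multiplier. This is a sketch under assumptions the README does not state: the schedule is stepped once per epoch, warmup ramps linearly up to the base LR, decay heads toward zero, and the total epoch count (`total_epochs=30` here) is a hypothetical value. `warmup_cosine_factor` is an illustrative name, not a function from the training code.

```python
import math


def warmup_cosine_factor(epoch, warmup_epochs=3, total_epochs=30):
    """LR multiplier in [0, 1] for a 0-indexed epoch.

    Linear warmup over the first `warmup_epochs` epochs, then cosine decay
    that reaches 0 at epoch == total_epochs. total_epochs=30 is hypothetical.
    """
    if epoch < warmup_epochs:
        # Ramp 1/w, 2/w, ..., reaching 1.0 on the last warmup epoch.
        return (epoch + 1) / warmup_epochs
    # Fraction of the decay phase completed, clamped to [0, 1].
    progress = (epoch - warmup_epochs) / max(1, total_epochs - warmup_epochs)
    progress = min(progress, 1.0)
    return 0.5 * (1.0 + math.cos(math.pi * progress))


if __name__ == "__main__":
    for e in range(6):
        print(e, round(warmup_cosine_factor(e), 3))
```

Because the function only returns a multiplier, it can be passed to PyTorch as `torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda=warmup_cosine_factor)`. `LambdaLR` scales each parameter group's initial LR, so with separate backbone (1e-4) and head (1e-3) groups both would follow the same shape while keeping their 10× ratio.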
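The Model section specifies class-weighted cross-entropy but not how the weights were derived. One common choice, assumed here, is inverse-frequency weighting, `n_samples / (n_classes * count[c])`, the same formula as scikit-learn's `"balanced"` class-weight heuristic, so rarer classes contribute more to the loss. `balanced_class_weights` is a hypothetical helper, and the toy labels below are made up rather than taken from this dataset.

```python
from collections import Counter


def balanced_class_weights(labels, num_classes):
    """Inverse-frequency weights: n_samples / (num_classes * count[c]).

    labels: iterable of integer class ids in 0 .. num_classes - 1.
    Returns a list of floats indexed by class id. A class that never
    appears gets weight 0.0 to avoid dividing by zero.
    """
    labels = list(labels)
    counts = Counter(labels)
    n = len(labels)
    return [
        n / (num_classes * counts[c]) if counts[c] else 0.0
        for c in range(num_classes)
    ]


if __name__ == "__main__":
    # Made-up, deliberately imbalanced labels: class 0 is twice as common.
    toy_labels = [0] * 40 + [1] * 20 + [2] * 20 + [3] * 20
    print(balanced_class_weights(toy_labels, num_classes=4))  # [0.625, 1.25, 1.25, 1.25]
```

In PyTorch, the resulting list would be passed as `nn.CrossEntropyLoss(weight=torch.tensor(weights))`, where `weights` is the list returned above.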
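Macro F1 is the unweighted mean of the per-class F1 scores, so it can be checked directly against the per-class values listed under Results. Weighted F1 instead weights each class by its support (the number of true instances per class), which this README does not report, so only the macro average is recomputed here.

```python
# Per-class F1 scores as reported in the Results section.
per_class_f1 = {
    "Glioma": 0.9915,
    "Meningioma": 0.9832,
    "No Tumor": 0.9903,
    "Pituitary": 0.9935,
}

# Macro F1: every class counts equally, regardless of how many images it has.
macro_f1 = sum(per_class_f1.values()) / len(per_class_f1)
print(f"{macro_f1:.6f}")  # 0.989625, which rounds to the reported 0.9896
```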
## Disclaimer

This model is intended for **research purposes only** and is not a certified medical diagnostic tool. Do not use it for clinical decision-making.