---
license: mit
tags:
- medical
---

# 🧠 Brain Tumor Classification Using Vision Transformer (ViT)

This repository contains a fine-tuned **Vision Transformer (ViT)** model trained on a large collection of MRI scans for brain tumor classification. The model classifies each MRI image into one of three categories:

- **Glioma**
- **Meningioma**
- **Tumor (General)**

The dataset includes over **75,000 color-enhanced MRI images**, which makes the model suitable for research and educational work on brain tumor detection.

---

## 📊 Dataset Information

- **Original Dataset Name**: Brain Cancer - MRI dataset
- **Author**: Rahman, Md Mizanur (2024)
- **Hosted on**: [Mendeley Data](https://data.mendeley.com/datasets/mk56jw9rns/1)
- **DOI**: [10.17632/mk56jw9rns.1](https://doi.org/10.17632/mk56jw9rns.1)
- **Kaggle Rehost (Colorized)**: [Shuvo Kumar Basak on Kaggle](https://www.kaggle.com/datasets/shuvokumarbasakbd/brain-cancer-mri-colorized-dataset)

> **Note:** This dataset is publicly available for non-commercial research use. This repository does not include the dataset itself.

---

## 🧠 Model Architecture

- **Model Type**: Vision Transformer (ViT-B/16)
- **Framework**: PyTorch + [timm](https://github.com/huggingface/pytorch-image-models)
- **Input Shape**: 224x224 RGB
- **Number of Classes**: 3
- **Loss Function**: CrossEntropyLoss
- **Optimizer**: AdamW

---

## 🏁 Training Pipeline Summary

1. **Image Preprocessing**:
   - Resize to 224x224
   - Normalization using ImageNet stats
   - Augmentations: Horizontal/Vertical Flip, ShiftScaleRotate, BrightnessContrast, etc.

2. **DataLoader**:
   - Stratified Split (Train/Val/Test)
   - PyTorch `Dataset` and `DataLoader` classes

3. **Model**:
   - Loaded ViT using `timm.create_model('vit_base_patch16_224', pretrained=True)`
   - Modified the classifier head to match 3 output classes

4. **Training**:
   - Trained using mixed precision (`torch.cuda.amp`)
   - Progress tracked using `tqdm`

5. **Saving**:
   - Model weights saved as `pytorch_model.bin`
   - Configuration saved as `config.json`

---

## 🔍 Intended Use

This model is designed for:

- Educational purposes (deep learning and medical imaging)
- Research in brain tumor classification using transformers
- Demonstrating ViT on colorized medical imaging datasets

⚠️ **Not intended for clinical use** or deployment without regulatory approval and further validation.

---

## 🚀 Inference Example (Python)

```python
from timm import create_model
import torch
from torchvision import transforms
from PIL import Image

# Load model: build the architecture, then load the fine-tuned weights on CPU
model = create_model('vit_base_patch16_224', pretrained=False, num_classes=3)
model.load_state_dict(torch.load("pytorch_model.bin", map_location="cpu"))
model.eval()

# Transform
# NOTE: the training summary lists ImageNet statistics for normalization,
# while this example uses mean/std of 0.5. Use the values the checkpoint
# was actually trained with.
transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.5]*3, std=[0.5]*3),
])

# Inference: add a batch dimension and disable gradient tracking
image = Image.open("example_mri.jpg").convert("RGB")
tensor = transform(image).unsqueeze(0)
with torch.no_grad():
    output = model(tensor)
pred = torch.argmax(output, dim=1)
print("Predicted class:", pred.item())
```
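
The inference example above prints a bare class index. The following is a minimal sketch of turning the model's logits into a readable label and a confidence score, using only the standard library. The ordering in `CLASS_NAMES` and the helper names `softmax` and `describe_prediction` are assumptions made for illustration and do not come from this repository, so check the index order against the label mapping used during training.

```python
import math

# ASSUMPTION: this index-to-label order is hypothetical; confirm it against
# the label mapping that was used when the model was trained.
CLASS_NAMES = ["Glioma", "Meningioma", "Tumor (General)"]


def softmax(logits):
    """Convert raw logits to probabilities that sum to 1."""
    peak = max(logits)  # subtract the max for numerical stability
    exps = [math.exp(x - peak) for x in logits]
    total = sum(exps)
    return [e / total for e in exps]


def describe_prediction(logits, class_names=CLASS_NAMES):
    """Return (label, probability) for the highest-scoring class."""
    if len(logits) != len(class_names):
        raise ValueError("expected one logit per class")
    probs = softmax(logits)
    best = max(range(len(probs)), key=probs.__getitem__)
    return class_names[best], probs[best]


# Made-up logits for illustration only; with the model above you could pass
# output[0].detach().tolist() instead.
label, confidence = describe_prediction([0.2, 2.1, -0.5])
print(f"{label}: {confidence:.1%}")
```

A softmax probability is not a calibrated clinical confidence; it only reflects how the three logits compare with one another.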