---
license: mit
tags:
- medical
---

# 🧠 Brain Tumor Classification Using Vision Transformer (ViT)

This repository contains a **Vision Transformer (ViT)** model fine-tuned on MRI scans for brain tumor classification. The model assigns each MRI image to one of three categories:

- **Glioma**
- **Meningioma**
- **Tumor (General)**

The training data comprises over **75,000 color-enhanced MRI images**. The model is intended for research and educational work on brain tumor classification.

---

## Dataset Information

- **Original Dataset Name**: Brain Cancer - MRI dataset
- **Author**: Rahman, Md Mizanur (2024)
- **Hosted on**: [Mendeley Data](https://data.mendeley.com/datasets/mk56jw9rns/1)
- **DOI**: [10.17632/mk56jw9rns.1](https://doi.org/10.17632/mk56jw9rns.1)
- **Kaggle Rehost (Colorized)**: [Shuvo Kumar Basak on Kaggle](https://www.kaggle.com/datasets/shuvokumarbasakbd/brain-cancer-mri-colorized-dataset)

> **Note:** This dataset is publicly available for non-commercial research use. This repository does not include the dataset itself.

---

## Model Architecture

- **Model Type**: Vision Transformer (ViT-B/16)
- **Framework**: PyTorch + [timm](https://github.com/huggingface/pytorch-image-models)
- **Input Shape**: 224x224 RGB
- **Number of Classes**: 3
- **Loss Function**: CrossEntropyLoss
- **Optimizer**: AdamW
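
You can work out by hand the sequence length that ViT-B/16 processes at the stated 224x224 input. The sketch below assumes the standard ViT-B/16 settings: 16x16 patches, 768-dimensional embeddings, and one class token. It also assumes a single linear classifier head, as in timm's default ViT. These numbers are not read from this repository's `config.json`, and the helper name `vit_token_stats` is hypothetical.

```python
# Assumed standard ViT-B/16 hyperparameters (not read from config.json)
IMAGE_SIZE = 224   # input is 224x224 RGB
PATCH_SIZE = 16    # the "/16" in ViT-B/16: 16x16-pixel patches
CHANNELS = 3       # RGB
EMBED_DIM = 768    # ViT-Base hidden size
NUM_CLASSES = 3    # glioma, meningioma, tumor (general)


def vit_token_stats(image_size: int, patch_size: int, channels: int) -> dict:
    """Hypothetical helper: patch-grid and sequence-length figures for a square ViT input."""
    if image_size % patch_size != 0:
        raise ValueError("image size must be divisible by patch size")
    grid = image_size // patch_size                     # patches along each side
    num_patches = grid * grid                           # patches per image
    patch_values = patch_size * patch_size * channels   # flattened values per patch
    return {
        "grid": grid,
        "num_patches": num_patches,
        "sequence_length": num_patches + 1,             # +1 for the class token
        "patch_values": patch_values,
    }


stats = vit_token_stats(IMAGE_SIZE, PATCH_SIZE, CHANNELS)
# Assumes the new head is one linear layer: weights plus biases
head_params = EMBED_DIM * NUM_CLASSES + NUM_CLASSES
print(stats, head_params)
```

Under these assumptions, each image becomes 197 tokens (196 patches plus the class token). The new 3-class head has 2,307 parameters (768 × 3 weights plus 3 biases), which is tiny next to the pretrained backbone.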

---

## Training Pipeline Summary

1. **Image Preprocessing**:
   - Resize to 224x224
   - Normalization using ImageNet statistics
   - Augmentations: Horizontal/Vertical Flip, ShiftScaleRotate, BrightnessContrast, etc.

2. **DataLoader**:
   - Stratified split (train/val/test)
   - PyTorch `Dataset` and `DataLoader` classes

3. **Model**:
   - Loaded ViT with `timm.create_model('vit_base_patch16_224', pretrained=True)`
   - Replaced the classifier head so it outputs 3 classes

4. **Training**:
   - Mixed-precision training (`torch.cuda.amp`)
   - Progress tracked with `tqdm`

5. **Saving**:
   - Model weights saved as `pytorch_model.bin`
   - Configuration saved as `config.json`
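
Step 2 names a stratified train/val/test split but does not give the ratios. The sketch below shows one way to do such a split with only the standard library. The 70/15/15 ratios, the seed, the toy label counts, and the helper name `stratified_split` are all illustrative assumptions, not the settings used for this model.

```python
import random
from collections import defaultdict


def stratified_split(labels, ratios=(0.7, 0.15, 0.15), seed=42):
    """Hypothetical helper: split sample indices into train/val/test, preserving class proportions.

    labels: sequence with one class label per sample.
    Returns three lists of indices: (train, val, test).
    """
    if abs(sum(ratios) - 1.0) > 1e-9:
        raise ValueError("ratios must sum to 1")
    rng = random.Random(seed)  # local RNG so the split is reproducible

    # Group sample indices by class label
    by_class = defaultdict(list)
    for idx, label in enumerate(labels):
        by_class[label].append(idx)

    train, val, test = [], [], []
    for label in sorted(by_class):  # sorted for a deterministic class order
        idxs = by_class[label]
        rng.shuffle(idxs)
        n_train = round(len(idxs) * ratios[0])
        n_val = round(len(idxs) * ratios[1])
        train.extend(idxs[:n_train])
        val.extend(idxs[n_train:n_train + n_val])
        test.extend(idxs[n_train + n_val:])  # rest of the class goes to test
    return train, val, test


# Hypothetical toy labels: 100 glioma, 60 meningioma, 40 tumor
labels = ["glioma"] * 100 + ["meningioma"] * 60 + ["tumor"] * 40
train_idx, val_idx, test_idx = stratified_split(labels)
print(len(train_idx), len(val_idx), len(test_idx))
```

Shuffling within each class, rather than across the whole dataset, keeps each split's class mix close to the overall mix. This matters when the classes are imbalanced. You could then wrap the resulting index lists in `torch.utils.data.Subset` before building the `DataLoader`s.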

---

## Intended Use

This model is designed for:

- Educational purposes (deep learning and medical imaging)
- Research in brain tumor classification using transformers
- Demonstrating the power of ViT on colorized medical datasets

⚠️ **Not intended for clinical use** or for deployment without regulatory approval and further validation.

---

## Inference Example (Python)

```python
import torch
from PIL import Image
from timm import create_model
from torchvision import transforms

# Load model: ViT-B/16 with a 3-class head, then the fine-tuned weights
model = create_model('vit_base_patch16_224', pretrained=False, num_classes=3)
state_dict = torch.load("pytorch_model.bin", map_location="cpu")
model.load_state_dict(state_dict)
model.eval()

# Transform: resize and normalize with ImageNet statistics, matching the training pipeline
transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])

# Inference
image = Image.open("example_mri.jpg").convert("RGB")
tensor = transform(image).unsqueeze(0)  # add batch dimension: (1, 3, 224, 224)

with torch.no_grad():  # no gradients needed at inference time
    logits = model(tensor)
    probs = torch.softmax(logits, dim=1)

# The index-to-name mapping depends on the label encoding used during training
pred = torch.argmax(probs, dim=1).item()
print(f"Predicted class index: {pred} (probability {probs[0, pred].item():.3f})")