---
license: mit
tags:
- medical
---
# 🧠 Brain Tumor Classification Using Vision Transformer (ViT)
This repository contains a fine-tuned **Vision Transformer (ViT)** model trained on a large collection of MRI scans for brain tumor classification. The model classifies MRI images into one of three categories:
- **Glioma**
- **Meningioma**
- **Tumor (General)**
The training dataset contains over **75,000 color-enhanced MRI images**. The model is intended for research and educational work on brain tumor classification.
---
## 📊 Dataset Information
- **Original Dataset Name**: Brain Cancer - MRI dataset
- **Author**: Rahman, Md Mizanur (2024)
- **Hosted on**: [Mendeley Data](https://data.mendeley.com/datasets/mk56jw9rns/1)
- **DOI**: [10.17632/mk56jw9rns.1](https://doi.org/10.17632/mk56jw9rns.1)
- **Kaggle Rehost (Colorized)**: [Shuvo Kumar Basak on Kaggle](https://www.kaggle.com/datasets/shuvokumarbasakbd/brain-cancer-mri-colorized-dataset)
> **Note:** The dataset is publicly available for non-commercial research use. This repository does not include the dataset itself.
---
## 🧠 Model Architecture
- **Model Type**: Vision Transformer (ViT-B/16)
- **Framework**: PyTorch + [timm](https://github.com/huggingface/pytorch-image-models)
- **Input Shape**: 224x224 RGB
- **Number of Classes**: 3
- **Loss Function**: CrossEntropyLoss
- **Optimizer**: AdamW
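
As a rough illustration of what ViT-B/16 with a 224x224 input implies, the sketch below works out the length of the token sequence the encoder processes. The patch size of 16 comes from the model name; the single class token is the standard ViT setting and is assumed here rather than read from this checkpoint. The helper name `vit_token_count` is hypothetical.

```python
def vit_token_count(image_size: int = 224, patch_size: int = 16) -> int:
    """Number of tokens a ViT processes: one per patch plus one class token."""
    if image_size % patch_size != 0:
        raise ValueError("image_size must be divisible by patch_size")
    patches_per_side = image_size // patch_size  # 224 // 16 = 14
    num_patches = patches_per_side ** 2          # 14 * 14 = 196
    return num_patches + 1                       # plus the class token (assumed)


# Each 16x16 RGB patch is flattened to 16 * 16 * 3 values before projection.
patch_values = 16 * 16 * 3

print(vit_token_count())  # 197
print(patch_values)       # 768
```

This is also why inputs are resized to exactly 224x224: the image side must divide evenly into 16-pixel patches.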
---
## 🏁 Training Pipeline Summary
1. **Image Preprocessing**:
   - Resize to 224x224
   - Normalization using ImageNet stats
   - Augmentations: Horizontal/Vertical Flip, ShiftScaleRotate, BrightnessContrast, etc.
2. **DataLoader**:
   - Stratified Split (Train/Val/Test)
   - PyTorch `Dataset` and `DataLoader` classes
3. **Model**:
   - Loaded ViT using `timm.create_model('vit_base_patch16_224', pretrained=True)`
   - Modified the classifier head to match 3 output classes
4. **Training**:
   - Trained using mixed precision (`torch.cuda.amp`)
   - Tracked using `tqdm`
5. **Saving**:
   - Model saved as `pytorch_model.bin`
   - Configuration saved as `config.json`
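
The stratified split in step 2 can be sketched without any framework. The function below is a minimal illustration, not the code used to train this model: the name `stratified_split`, the 70/15/15 ratios, the seed, and the toy labels are all assumptions chosen for the example.

```python
import random
from collections import defaultdict


def stratified_split(labels, val_frac=0.15, test_frac=0.15, seed=42):
    """Return (train, val, test) index lists that preserve per-class proportions."""
    rng = random.Random(seed)

    # Group sample indices by class label.
    by_class = defaultdict(list)
    for idx, label in enumerate(labels):
        by_class[label].append(idx)

    train, val, test = [], [], []
    for indices in by_class.values():
        rng.shuffle(indices)
        n_val = round(len(indices) * val_frac)
        n_test = round(len(indices) * test_frac)
        # Split each class separately so every subset keeps the class balance.
        val.extend(indices[:n_val])
        test.extend(indices[n_val:n_val + n_test])
        train.extend(indices[n_val + n_test:])
    return train, val, test


# Toy example: 3 classes with 100 samples each (hypothetical labels).
labels = [0] * 100 + [1] * 100 + [2] * 100
train_idx, val_idx, test_idx = stratified_split(labels)
print(len(train_idx), len(val_idx), len(test_idx))  # 210 45 45
```

The resulting index lists could then be wrapped with `torch.utils.data.Subset` and passed to a `DataLoader`, one per split.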
---
## 🔍 Intended Use
This model is designed for:
- Educational purposes (deep learning and medical imaging)
- Research in brain tumor classification using transformers
- Demonstrating ViT fine-tuning on colorized medical image datasets
⚠️ **Not intended for clinical use** or deployment without regulatory approval and further validation.
---
## 🚀 Inference Example (Python)
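```python
import torch
from PIL import Image
from timm import create_model
from torchvision import transforms

# Load the fine-tuned weights into a ViT-B/16 with a 3-class head
model = create_model('vit_base_patch16_224', pretrained=False, num_classes=3)
state_dict = torch.load("pytorch_model.bin", map_location="cpu")
model.load_state_dict(state_dict)
model.eval()

# Preprocessing: keep this identical to the transform used during training.
# (The training summary mentions ImageNet statistics; adjust mean/std if the
# checkpoint was trained with those values.)
transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.5]*3, std=[0.5]*3),
])

# Inference on a single image
image = Image.open("example_mri.jpg").convert("RGB")
tensor = transform(image).unsqueeze(0)  # add batch dimension: (1, 3, 224, 224)
with torch.no_grad():  # gradients are not needed at inference time
    output = model(tensor)
pred = torch.argmax(output, dim=1)
print("Predicted class:", pred.item())
```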