---
license: mit
tags:
- medical
---
# 🧠 Brain Tumor Classification Using Vision Transformer (ViT)

This repository contains a fine-tuned **Vision Transformer (ViT)** model trained on a large collection of MRI scans for brain tumor classification. The model classifies MRI images into one of three categories:

- **Glioma**
- **Meningioma**
- **Tumor (General)**

The training dataset contains over **75,000 color-enhanced MRI images**. The model is intended for research and educational work on brain tumor detection.

---

## 📊 Dataset Information

- **Original Dataset Name**: Brain Cancer - MRI dataset
- **Author**: Rahman, Md Mizanur (2024)
- **Hosted on**: [Mendeley Data](https://data.mendeley.com/datasets/mk56jw9rns/1)  
- **DOI**: [10.17632/mk56jw9rns.1](https://doi.org/10.17632/mk56jw9rns.1)  
- **Kaggle Rehost (Colorized)**: [Shuvo Kumar Basak on Kaggle](https://www.kaggle.com/datasets/shuvokumarbasakbd/brain-cancer-mri-colorized-dataset)

> **Note:** This dataset is publicly available for non-commercial research use. This repository contains only the model, not the dataset itself.

---

## 🧠 Model Architecture

- **Model Type**: Vision Transformer (ViT-B/16)
- **Framework**: PyTorch + [timm](https://github.com/huggingface/pytorch-image-models)
- **Input Shape**: 224x224 RGB
- **Number of Classes**: 3
- **Loss Function**: CrossEntropyLoss
- **Optimizer**: AdamW
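
To make the "ViT-B/16" and "224x224" figures above concrete, the short sketch below works out how many patch tokens the model sees per image. It assumes the standard ViT-Base configuration (hidden width 768, one class token), which is not stated explicitly in this card.

```python
# Values from the list above: 224x224 RGB input, 16x16 patches (the "/16" in ViT-B/16).
image_size = 224
patch_size = 16
channels = 3

# Assumption: 768 is the hidden width of the standard ViT-Base configuration.
embed_dim = 768

patches_per_side = image_size // patch_size             # 14 patches along each side
num_patches = patches_per_side ** 2                     # 14 * 14 = 196 patches
sequence_length = num_patches + 1                       # 197 tokens, including the class token
values_per_patch = patch_size * patch_size * channels   # 768 raw pixel values per patch

print(num_patches, sequence_length, values_per_patch)
```

Each 16x16 RGB patch carries 768 raw values, which happens to match the assumed embedding width, so the patch projection maps 768 inputs to 768 features.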

---

## ๐Ÿ Training Pipeline Summary

1. **Image Preprocessing**:
   - Resize to 224x224
   - Normalization using ImageNet stats
   - Augmentations: Horizontal/Vertical Flip, ShiftScaleRotate, BrightnessContrast, etc.

2. **DataLoader**:
   - Stratified Split (Train/Val/Test)
   - PyTorch `Dataset` and `DataLoader` classes

3. **Model**:
   - Loaded ViT using `timm.create_model('vit_base_patch16_224', pretrained=True)`
   - Modified the classifier head to match 3 output classes

4. **Training**:
   - Trained using mixed precision (`torch.cuda.amp`)
   - Progress tracked with `tqdm`

5. **Saving**:
   - Model saved as `pytorch_model.bin`
   - Configuration saved as `config.json`
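
The split and training steps above can be sketched roughly as follows. This is a minimal illustration, not the original training script: the 80/10/10 split fractions, the learning rate, the helper names `stratified_split` and `train_step`, and the tiny linear stand-in model are all assumptions chosen so the example runs quickly on CPU. Mixed precision is enabled only when a CUDA device is present.

```python
import random
from collections import defaultdict

import torch
from torch import nn


def stratified_split(labels, fractions=(0.8, 0.1, 0.1), seed=0):
    """Split sample indices into train/val/test while keeping class ratios similar.

    The 80/10/10 fractions are an assumption; the original ratios are not stated.
    """
    by_class = defaultdict(list)
    for index, label in enumerate(labels):
        by_class[label].append(index)

    rng = random.Random(seed)
    train, val, test = [], [], []
    for indices in by_class.values():
        rng.shuffle(indices)
        n_train = int(len(indices) * fractions[0])
        n_val = int(len(indices) * fractions[1])
        train += indices[:n_train]
        val += indices[n_train:n_train + n_val]
        test += indices[n_train + n_val:]
    return train, val, test


def train_step(model, images, targets, optimizer, scaler, criterion, use_amp):
    """Run one optimizer update, using CUDA mixed precision when use_amp is True."""
    model.train()
    optimizer.zero_grad(set_to_none=True)
    with torch.cuda.amp.autocast(enabled=use_amp):
        loss = criterion(model(images), targets)
    scaler.scale(loss).backward()  # scaling is a no-op when the scaler is disabled
    scaler.step(optimizer)
    scaler.update()
    return loss.item()


device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
use_amp = device.type == "cuda"

# Hypothetical stand-in for the ViT so the sketch runs quickly; swap in the timm model.
model = nn.Sequential(nn.Flatten(), nn.Linear(3 * 224 * 224, 3)).to(device)
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4)  # lr is an assumption
scaler = torch.cuda.amp.GradScaler(enabled=use_amp)

labels = [0] * 10 + [1] * 10 + [2] * 10  # synthetic labels, 10 samples per class
train_idx, val_idx, test_idx = stratified_split(labels)

# Random tensors stand in for preprocessed MRI images.
images = torch.randn(len(train_idx), 3, 224, 224, device=device)
targets = torch.tensor([labels[i] for i in train_idx], device=device)
loss = train_step(model, images, targets, optimizer, scaler, criterion, use_amp)
print(f"train/val/test sizes: {len(train_idx)}/{len(val_idx)}/{len(test_idx)}")
```

Recent PyTorch releases expose the same functionality under `torch.amp` and may emit a deprecation warning for the `torch.cuda.amp` names used here; the older names are kept to match the summary above.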

---

## Intended Use

This model is designed for:

- Educational purposes (deep learning and medical imaging)
- Research in brain tumor classification using transformers
- Demonstrating ViT fine-tuning on colorized medical image datasets

โš ๏ธ **Not intended for clinical use** or deployment without regulatory approval and further validation.

---

## 🚀 Inference Example (Python)

```python
from timm import create_model
import torch
from torchvision import transforms
from PIL import Image

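# Load model: rebuild the architecture, then load the fine-tuned weights
model = create_model('vit_base_patch16_224', pretrained=False, num_classes=3)
state_dict = torch.load("pytorch_model.bin", map_location="cpu")  # works without a GPU
model.load_state_dict(state_dict)
model.eval()

# Transform
# Note: the training summary mentions ImageNet statistics. If the weights were
# trained that way, use mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225].
transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.5]*3, std=[0.5]*3),
])

# Inference (gradients are not needed)
image = Image.open("example_mri.jpg").convert("RGB")
tensor = transform(image).unsqueeze(0)  # add batch dimension: (1, 3, 224, 224)
with torch.no_grad():
    output = model(tensor)
pred = torch.argmax(output, dim=1)
print("Predicted class:", pred.item())  # class index in the range 0-2
```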