---
license: apache-2.0
tags:
- image-classification
- vision-transformer
- vit
- lung-cancer
- medical-imaging
- pytorch
- transformers
base_model: google/vit-base-patch16-224
pipeline_tag: image-classification
---
# 🫁 ViT Lung Cancer Classifier
Fine-tuned **Vision Transformer (ViT-Base/16)** for lung cancer CT image classification
into 3 classes: **normal**, **malignant**, and **benign**.
## 📊 Model Details
| Property | Value |
|---|---|
| Base Model | `google/vit-base-patch16-224` |
| Task | Image Classification (3 classes) |
| Input Size | 224 × 224 px |
| Precision | fp16 |
| Training | Full fine-tuning + early stopping |
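ViT-Base/16 splits each 224 × 224 input into non-overlapping 16 × 16 patches. The arithmetic below works out the resulting sequence length. It is a sketch that assumes the standard ViT layout of one prepended `[CLS]` token; `vit_token_count` is a hypothetical helper, not part of any library.

```python
def vit_token_count(image_size: int = 224, patch_size: int = 16) -> dict:
    """Sequence-length arithmetic for a square ViT input (assumes one [CLS] token)."""
    if image_size % patch_size != 0:
        raise ValueError("image_size must be divisible by patch_size")
    per_side = image_size // patch_size  # 224 // 16 = 14
    num_patches = per_side * per_side    # 14 * 14 = 196
    return {
        "patches_per_side": per_side,
        "num_patches": num_patches,
        "sequence_length": num_patches + 1,               # 196 patches + [CLS] = 197
        "values_per_patch": patch_size * patch_size * 3,  # 16 * 16 * RGB = 768
    }


print(vit_token_count())
# {'patches_per_side': 14, 'num_patches': 196, 'sequence_length': 197, 'values_per_patch': 768}
```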
## 🏷️ Label Mapping
| ID | Label | Description |
|---|---|---|
| 0 | `normal` | Normal lung tissue |
| 1 | `malignant` | Malignant (cancerous) tissue |
| 2 | `benign` | Benign (non-cancerous) tissue |
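A minimal sketch of how a mapping like this is usually attached to a Hugging Face Transformers classification head. The names `ID2LABEL`, `LABEL2ID`, and `HEAD_KWARGS` are hypothetical, and this card does not say how the head was actually initialized during fine-tuning.

```python
# Label mapping from the table above (assumed to match the checkpoint's config.json)
ID2LABEL = {0: "normal", 1: "malignant", 2: "benign"}
LABEL2ID = {label: idx for idx, label in ID2LABEL.items()}

# Keyword arguments one could pass to
#   ViTForImageClassification.from_pretrained("google/vit-base-patch16-224", **HEAD_KWARGS)
# to replace the base model's 1000-class ImageNet head with a freshly initialized 3-class head.
HEAD_KWARGS = {
    "num_labels": len(ID2LABEL),
    "id2label": ID2LABEL,
    "label2id": LABEL2ID,
    "ignore_mismatched_sizes": True,  # classifier weight shapes differ from the base checkpoint
}
```

Passing `id2label`/`label2id` this way stores the names in the saved `config.json`, which is what lets `model.config.id2label[pred_id]` return `"malignant"` instead of `"LABEL_1"` at inference time.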
## 📅 Dataset
The model was trained on a global lung cancer dataset that combines clinical and risk-factor data:
| Property | Details |
|---|---|
| **Total Records** | 1,500 patient records |
| **Features** | 41 variables (Clinical, Demographic, Genetic, Risk Factors) |
| **Period** | 2015 – 2025 |
| **Scope** | 60 countries across 6 WHO Regions |
| **Key Factors** | Smoking status, BMI, Air Pollution, Genetic Mutations, Tumor Stage |
## 🚀 Usage
### Install
```bash
pip install transformers torch pillow
```
### Inference
```python
from transformers import ViTForImageClassification, ViTImageProcessor
from PIL import Image
import torch

model_id = "TurkishCodeMan/vit-lung-cancer"
processor = ViTImageProcessor.from_pretrained(model_id)
model = ViTForImageClassification.from_pretrained(model_id)

# Run on GPU when available; eval() disables dropout for inference
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.eval().to(device)


def predict(image_path: str) -> dict:
    # Force 3 channels so grayscale or RGBA scans match the model's RGB input
    img = Image.open(image_path).convert("RGB")
    # Resize and normalize to the 224 x 224 tensor the model expects
    inputs = processor(images=img, return_tensors="pt").to(device)

    with torch.no_grad():
        logits = model(**inputs).logits

    pred_id = logits.argmax(-1).item()
    # Cast to fp32 before softmax in case the logits are fp16
    probs = torch.softmax(logits.float(), dim=-1)[0]

    return {
        "prediction": model.config.id2label[pred_id],
        "probabilities": {
            label: round(probs[i].item(), 4)
            for i, label in model.config.id2label.items()
        },
    }


result = predict("lung_scan.jpg")
print(result)
```
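`ViTImageProcessor` hides the preprocessing step. The sketch below re-implements it with Pillow and NumPy, assuming this checkpoint keeps the `google/vit-base-patch16-224` processor defaults (bilinear resize to 224 × 224, rescale by 1/255, normalize with mean 0.5 and std 0.5 per channel); check the repo's `preprocessor_config.json` before relying on it. `preprocess_like_vit` is a hypothetical helper, useful for example when serving the model from a runtime without `transformers`.

```python
import numpy as np
from PIL import Image


def preprocess_like_vit(img: Image.Image, size: int = 224) -> np.ndarray:
    """Approximate ViTImageProcessor output as a (1, 3, size, size) float32 array.

    Assumes mean = std = 0.5 per channel, so pixel values end up in [-1, 1].
    """
    img = img.convert("RGB").resize((size, size), Image.BILINEAR)
    x = np.asarray(img, dtype=np.float32) / 255.0  # HWC, values in [0, 1]
    x = (x - 0.5) / 0.5                            # normalize to [-1, 1]
    return x.transpose(2, 0, 1)[np.newaxis, ...]   # HWC -> NCHW with a batch dim


# Usage with a synthetic image standing in for a real scan
demo = Image.new("RGB", (512, 300), color=(255, 0, 128))
pixel_values = preprocess_like_vit(demo)
print(pixel_values.shape)  # (1, 3, 224, 224)
```

If the assumptions hold, `pixel_values` should closely match `processor(images=demo, return_tensors="np")["pixel_values"]`.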
## 🛠️ Training Config
| Parameter | Value |
|---|---|
| **Optimizer** | AdamW |
| **Learning Rate** | 2e-5 |
| **Batch Size** | 16 |
| **Max Epochs** | 30 |
| **Early Stopping** | 5 epochs patience |
| **Mixed Precision**| fp16 |
| **Best Metric** | F1-Macro |
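The table reports that checkpoints were selected by F1-Macro, with training stopped after 5 epochs without improvement. Below is a minimal pure-Python sketch of both pieces under stated assumptions: `macro_f1`, `compute_metrics`, and `early_stopping_epoch` are hypothetical helpers, the metric key `f1_macro` is an assumed name, and "improvement" means a strictly higher score (no minimum delta). In a `Trainer` setup this role is typically played by `compute_metrics`, `metric_for_best_model`, and `EarlyStoppingCallback(early_stopping_patience=5)`, but the card does not confirm which tooling was used.

```python
from typing import Dict, Sequence, Tuple


def macro_f1(preds: Sequence[int], labels: Sequence[int], num_classes: int = 3) -> float:
    """Unweighted mean of per-class F1; a class neither present nor predicted counts as 0."""
    f1_scores = []
    for c in range(num_classes):
        tp = sum(1 for p, y in zip(preds, labels) if p == c and y == c)
        fp = sum(1 for p, y in zip(preds, labels) if p == c and y != c)
        fn = sum(1 for p, y in zip(preds, labels) if p != c and y == c)
        denom = 2 * tp + fp + fn  # F1 = 2TP / (2TP + FP + FN)
        f1_scores.append(2 * tp / denom if denom else 0.0)
    return sum(f1_scores) / num_classes


def compute_metrics(eval_pred) -> Dict[str, float]:
    """Trainer-style hook: takes (logits, labels) and returns {"f1_macro": ...}."""
    logits, labels = eval_pred
    # Argmax over the 3 class logits for each sample
    preds = [max(range(len(row)), key=lambda i: row[i]) for row in logits]
    return {"f1_macro": macro_f1(preds, [int(y) for y in labels])}


def early_stopping_epoch(history: Sequence[float], patience: int = 5) -> Tuple[int, int]:
    """Return (best_epoch, last_epoch_run), both 1-indexed, from per-epoch validation scores."""
    best_score, best_epoch, epochs_without_gain = float("-inf"), 0, 0
    for epoch, score in enumerate(history, start=1):
        if score > best_score:  # strict improvement resets the patience counter
            best_score, best_epoch, epochs_without_gain = score, epoch, 0
        else:
            epochs_without_gain += 1
            if epochs_without_gain >= patience:
                return best_epoch, epoch  # stop here and keep the best_epoch checkpoint
    return best_epoch, len(history)  # ran out of epochs (the card caps this at 30)


print(early_stopping_epoch([0.60, 0.70, 0.72, 0.71, 0.70, 0.72, 0.69, 0.68, 0.90]))  # (3, 8)
```

With these example scores, training stops after epoch 8 and keeps the epoch-3 checkpoint. The 0.90 at epoch 9 is never reached, which is the trade-off the patience value controls.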