---
license: apache-2.0
tags:
- image-classification
- vision-transformer
- vit
- lung-cancer
- medical-imaging
- pytorch
- transformers
base_model: google/vit-base-patch16-224
pipeline_tag: image-classification
---
# 🫁 ViT Lung Cancer Classifier
Fine-tuned **Vision Transformer (ViT-Base/16)** for lung cancer CT image classification
into 3 classes: **normal**, **malignant**, and **benign**.
## 📊 Model Details
| Property | Value |
|---|---|
| Base Model | `google/vit-base-patch16-224` |
| Task | Image Classification (3 classes) |
| Input Size | 224 × 224 px |
| Precision | fp16 |
| Training | Full fine-tuning + early stopping |
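
The input size and patch size in the table determine how many tokens the encoder processes per image. A minimal sketch of that arithmetic, assuming the standard ViT layout (non-overlapping 16 × 16 patches plus one classification token); the helper name `vit_sequence_length` is hypothetical and not part of this repository:

```python
def vit_sequence_length(image_size: int = 224, patch_size: int = 16) -> int:
    """Number of tokens a ViT encoder sees for a square input image."""
    if image_size % patch_size != 0:
        raise ValueError("image_size must be divisible by patch_size")
    patches_per_side = image_size // patch_size  # 224 // 16 = 14
    num_patches = patches_per_side ** 2          # 14 * 14 = 196
    return num_patches + 1                       # +1 for the [CLS] token


print(vit_sequence_length())  # 197
```

Images of any other size are resized by the processor to 224 × 224 before patching, so the sequence length stays constant at inference time.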
## 🏷️ Label Mapping
| ID | Label | Description |
|---|---|---|
| 0 | `normal` | Normal lung tissue |
| 1 | `malignant` | Malignant (cancerous) tissue |
| 2 | `benign` | Benign (non-cancerous) tissue |
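
The mapping above can also be written as plain dictionaries, which is handy when decoding raw scores outside of `transformers`. This is a sketch that assumes the published `config.json` uses the same IDs as the table; the names `ID2LABEL`, `LABEL2ID`, and `decode` are illustrative only:

```python
# Assumed to mirror the table above; check model.config.id2label to confirm.
ID2LABEL = {0: "normal", 1: "malignant", 2: "benign"}
LABEL2ID = {label: idx for idx, label in ID2LABEL.items()}


def decode(scores: list[float]) -> str:
    """Return the label whose score (logit or probability) is highest."""
    if len(scores) != len(ID2LABEL):
        raise ValueError(f"expected {len(ID2LABEL)} scores, got {len(scores)}")
    best_id = max(range(len(scores)), key=lambda i: scores[i])
    return ID2LABEL[best_id]


print(decode([0.10, 0.70, 0.20]))  # malignant
```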
## 📅 Dataset
The model was trained on a lung cancer dataset that combines global clinical and risk-factor data.
| Property | Details |
|---|---|
| **Total Records** | 1,500 patient records |
| **Features** | 41 variables (Clinical, Demographic, Genetic, Risk Factors) |
| **Period** | 2015 – 2025 |
| **Scope** | 60 countries across 6 WHO Regions |
| **Key Factors** | Smoking status, BMI, Air Pollution, Genetic Mutations, Tumor Stage |
## 🚀 Usage
### Install
```bash
pip install transformers torch pillow
```
### Inference
```python
from transformers import ViTForImageClassification, ViTImageProcessor
from PIL import Image
import torch

model_id = "TurkishCodeMan/vit-lung-cancer"
processor = ViTImageProcessor.from_pretrained(model_id)
model = ViTForImageClassification.from_pretrained(model_id)

# Run on GPU when available, otherwise fall back to CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.eval().to(device)


def predict(image_path: str) -> dict:
    # Convert to RGB so grayscale scans get the 3 channels ViT expects.
    img = Image.open(image_path).convert("RGB")
    inputs = processor(images=img, return_tensors="pt").to(device)

    with torch.no_grad():
        logits = model(**inputs).logits

    pred_id = logits.argmax(-1).item()
    # Cast to float32 before softmax to avoid fp16 rounding issues.
    probs = torch.softmax(logits.float(), dim=-1)[0]

    return {
        "prediction": model.config.id2label[pred_id],
        "probabilities": {
            label: round(probs[i].item(), 4)
            for i, label in model.config.id2label.items()
        },
    }


result = predict("lung_scan.jpg")
print(result)
```
## 🛠️ Training Config
| Parameter | Value |
|---|---|
| **Optimizer** | AdamW |
| **Learning Rate** | 2e-5 |
| **Batch Size** | 16 |
| **Max Epochs** | 30 |
| **Early Stopping** | 5 epochs patience |
| **Mixed Precision** | fp16 |
| **Best Metric** | F1-Macro |
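
To make the last two rows concrete, here is a small pure-Python sketch of macro-averaged F1 and patience-based early stopping. It is an illustration only, not the training code used for this model: the card does not say how either was implemented, and `macro_f1` and `EarlyStopper` are hypothetical names. Classes with no true or predicted samples are assumed to score 0.

```python
def macro_f1(y_true: list[int], y_pred: list[int], num_classes: int = 3) -> float:
    """Unweighted mean of per-class F1 scores."""
    f1_scores = []
    for c in range(num_classes):
        tp = sum(1 for t, p in zip(y_true, y_pred) if t == c and p == c)
        fp = sum(1 for t, p in zip(y_true, y_pred) if t != c and p == c)
        fn = sum(1 for t, p in zip(y_true, y_pred) if t == c and p != c)
        denom = 2 * tp + fp + fn
        # Assumption: a class that never appears contributes an F1 of 0.
        f1_scores.append(2 * tp / denom if denom else 0.0)
    return sum(f1_scores) / num_classes


class EarlyStopper:
    """Stop when the monitored score has not improved for `patience` epochs."""

    def __init__(self, patience: int = 5):
        self.patience = patience
        self.best = float("-inf")
        self.bad_epochs = 0

    def step(self, score: float) -> bool:
        """Record one epoch's score; return True when training should stop."""
        if score > self.best:
            self.best = score
            self.bad_epochs = 0
        else:
            self.bad_epochs += 1
        return self.bad_epochs >= self.patience
```

In a training loop, `stopper.step(macro_f1(labels, preds))` would be called once per validation epoch, and the checkpoint with the highest score kept as the final model.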