---
license: apache-2.0
tags:
- image-classification
- vision-transformer
- vit
- lung-cancer
- medical-imaging
- pytorch
- transformers
base_model: google/vit-base-patch16-224
pipeline_tag: image-classification
---

# 🫁 ViT Lung Cancer Classifier

Fine-tuned **Vision Transformer (ViT-Base/16)** for lung cancer CT image classification into 3 classes: **normal**, **malignant**, and **benign**.
## 📊 Model Details

| Property | Value |
|---|---|
| Base Model | `google/vit-base-patch16-224` |
| Task | Image Classification (3 classes) |
| Input Size | 224 × 224 px |
| Precision | fp16 |
| Training | Full fine-tuning + early stopping |

## 🏷️ Label Mapping

| ID | Label | Description |
|---|---|---|
| 0 | `normal` | Normal lung tissue |
| 1 | `malignant` | Malignant (cancerous) tissue |
| 2 | `benign` | Benign (non-cancerous) tissue |
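The table above mirrors the `id2label`/`label2id` maps stored in the model config. As a minimal sketch, this is how such maps could be wired up when fine-tuning from the base checkpoint; the `from_pretrained` call is shown commented out because it downloads weights, and `ignore_mismatched_sizes` is needed to replace the 1000-class ImageNet head:

```python
# Label maps matching the table above.
id2label = {0: "normal", 1: "malignant", 2: "benign"}
label2id = {label: idx for idx, label in id2label.items()}

# Illustrative only -- passing the maps while loading the base checkpoint:
# from transformers import ViTForImageClassification
# model = ViTForImageClassification.from_pretrained(
#     "google/vit-base-patch16-224",
#     num_labels=3,
#     id2label=id2label,
#     label2id=label2id,
#     ignore_mismatched_sizes=True,  # swap out the 1000-class ImageNet head
# )
```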

## 📅 Dataset

The model was trained on a lung cancer dataset covering global clinical and risk-factor data:

| Property | Details |
|---|---|
| **Total Records** | 1,500 patient records |
| **Features** | 41 variables (Clinical, Demographic, Genetic, Risk Factors) |
| **Period** | 2015 – 2025 |
| **Scope** | 60 countries across 6 WHO Regions |
| **Key Factors** | Smoking status, BMI, Air Pollution, Genetic Mutations, Tumor Stage |

## 🚀 Usage

### Install

```bash
pip install transformers torch pillow
```

### Inference

```python
from transformers import ViTForImageClassification, ViTImageProcessor
from PIL import Image
import torch

model_id = "TurkishCodeMan/vit-lung-cancer"

processor = ViTImageProcessor.from_pretrained(model_id)
model = ViTForImageClassification.from_pretrained(model_id)

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.eval().to(device)

def predict(image_path: str) -> dict:
    """Classify a lung CT image; return the predicted label and per-class probabilities."""
    img = Image.open(image_path).convert("RGB")
    inputs = processor(images=img, return_tensors="pt").to(device)

    with torch.no_grad():
        logits = model(**inputs).logits

    pred_id = logits.argmax(-1).item()
    probs = torch.softmax(logits.float(), dim=-1)[0]

    return {
        "prediction": model.config.id2label[pred_id],
        "probabilities": {
            label: round(probs[i].item(), 4)
            for i, label in model.config.id2label.items()
        },
    }

result = predict("lung_scan.jpg")
print(result)
```
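The probabilities returned by `predict` can also feed a simple confidence gate. The `triage` helper below is hypothetical (not part of this repository) and flags low-confidence predictions for human review, which is prudent for medical imaging:

```python
def triage(result: dict, threshold: float = 0.8) -> str:
    """Return the predicted label, or 'needs-review' when the top class
    probability falls below `threshold`. `result` is the dict produced
    by `predict()` above."""
    top_label = result["prediction"]
    if result["probabilities"][top_label] < threshold:
        return "needs-review"
    return top_label

# Example with a mocked prediction dict:
mock = {
    "prediction": "benign",
    "probabilities": {"normal": 0.10, "malignant": 0.15, "benign": 0.75},
}
print(triage(mock))       # "needs-review" (0.75 < 0.8)
print(triage(mock, 0.7))  # "benign"
```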

## 🛠️ Training Config

| Parameter | Value |
|---|---|
| **Optimizer** | AdamW |
| **Learning Rate** | 2e-5 |
| **Batch Size** | 16 |
| **Max Epochs** | 30 |
| **Early Stopping** | 5 epochs patience |
| **Mixed Precision** | fp16 |
| **Best Metric** | F1-Macro |
|