---
license: apache-2.0
tags:
- image-classification
- vision-transformer
- vit
- lung-cancer
- medical-imaging
- pytorch
- transformers
base_model: google/vit-base-patch16-224
pipeline_tag: image-classification
---

# 🫁 ViT Lung Cancer Classifier

Fine-tuned **Vision Transformer (ViT-Base/16)** that classifies lung cancer CT images into 3 classes: **normal**, **malignant**, and **benign**.

## 📊 Model Details

| Property | Value |
|---|---|
| Base Model | `google/vit-base-patch16-224` |
| Task | Image Classification (3 classes) |
| Input Size | 224 × 224 px |
| Precision | fp16 |
| Training | Full fine-tuning + early stopping |

## 🏷️ Label Mapping

| ID | Label | Description |
|---|---|---|
| 0 | `normal` | Normal lung tissue |
| 1 | `malignant` | Malignant (cancerous) tissue |
| 2 | `benign` | Benign (non-cancerous) tissue |

## 📅 Dataset

The model was trained on a lung cancer dataset containing global clinical and risk-factor data.

| Property | Details |
|---|---|
| **Total Records** | 1,500 patient records |
| **Features** | 41 variables (Clinical, Demographic, Genetic, Risk Factors) |
| **Period** | 2015 – 2025 |
| **Scope** | 60 countries across 6 WHO Regions |
| **Key Factors** | Smoking status, BMI, Air Pollution, Genetic Mutations, Tumor Stage |

## 🚀 Usage

### Install

```bash
pip install transformers torch pillow
```

### Inference

```python
from transformers import ViTForImageClassification, ViTImageProcessor
from PIL import Image
import torch

model_id = "TurkishCodeMan/vit-lung-cancer"
processor = ViTImageProcessor.from_pretrained(model_id)
model = ViTForImageClassification.from_pretrained(model_id)

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.eval().to(device)

def predict(image_path: str) -> dict:
    # Load the scan and force 3-channel RGB (CT slices are often stored as grayscale)
    img = Image.open(image_path).convert("RGB")
    # Resize and normalize using the processor's saved preprocessing config
    inputs = processor(images=img, return_tensors="pt").to(device)

    with torch.no_grad():
        logits = model(**inputs).logits

    pred_id = logits.argmax(-1).item()
    # Cast to float32 for a numerically stable softmax (relevant if running in fp16)
    probs = torch.softmax(logits.float(), dim=-1)[0]

    return {
        "prediction": model.config.id2label[pred_id],
        "probabilities": {
            label: round(probs[i].item(), 4)
            for i, label in model.config.id2label.items()
        },
    }

result = predict("lung_scan.jpg")
print(result)
```

## 🛠️ Training Config

| Parameter | Value |
|---|---|
| **Optimizer** | AdamW |
| **Learning Rate** | 2e-5 |
| **Batch Size** | 16 |
| **Max Epochs** | 30 |
| **Early Stopping** | 5 epochs patience |
| **Mixed Precision** | fp16 |
| **Best Metric** | F1-Macro |
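The config above selects the best checkpoint by F1-Macro and stops after 5 epochs without improvement. The snippet below is a minimal, dependency-free sketch of how those two settings usually interact. It assumes early stopping monitors validation macro-F1 (higher is better). The names `macro_f1` and `EarlyStopper` are hypothetical and are not taken from the actual training code, and the per-epoch scores in the demo are made-up toy values, not results from this model.

```python
from typing import List, Optional


def macro_f1(y_true: List[int], y_pred: List[int], num_classes: int = 3) -> float:
    """Unweighted mean of per-class F1 over classes 0..num_classes-1
    (0 = normal, 1 = malignant, 2 = benign in this card's label mapping)."""
    scores = []
    for c in range(num_classes):
        tp = sum(1 for t, p in zip(y_true, y_pred) if t == c and p == c)
        fp = sum(1 for t, p in zip(y_true, y_pred) if t != c and p == c)
        fn = sum(1 for t, p in zip(y_true, y_pred) if t == c and p != c)
        denom = 2 * tp + fp + fn
        # F1 = 2TP / (2TP + FP + FN); a class that never appears scores 0.0
        scores.append(2 * tp / denom if denom else 0.0)
    # Every class counts equally, so a weak minority class pulls the score down
    return sum(scores) / num_classes


class EarlyStopper:
    """Hypothetical helper: stop once the monitored score has not improved
    for `patience` consecutive epochs."""

    def __init__(self, patience: int = 5):
        self.patience = patience
        self.best: Optional[float] = None
        self.best_epoch: Optional[int] = None
        self.bad_epochs = 0

    def step(self, epoch: int, score: float) -> bool:
        """Record one epoch's validation score; return True if training should stop."""
        if self.best is None or score > self.best:
            # New best: this is the checkpoint you would keep
            self.best, self.best_epoch, self.bad_epochs = score, epoch, 0
        else:
            self.bad_epochs += 1
        return self.bad_epochs >= self.patience


if __name__ == "__main__":
    MAX_EPOCHS = 30
    stopper = EarlyStopper(patience=5)
    # Toy validation macro-F1 values, for illustration only
    val_scores = [0.61, 0.70, 0.74, 0.73, 0.74, 0.72, 0.71, 0.73]
    for epoch, score in enumerate(val_scores[:MAX_EPOCHS], start=1):
        if stopper.step(epoch, score):
            print(f"stopping after epoch {epoch}; best epoch {stopper.best_epoch}")
            break
```

Macro averaging weights `normal`, `malignant`, and `benign` equally, which is a common choice when classes are imbalanced. If training used the Hugging Face `Trainer` (the card does not say), the equivalent setup is typically `EarlyStoppingCallback(early_stopping_patience=5)` combined with `load_best_model_at_end=True` and `metric_for_best_model` in `TrainingArguments`.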