metadata
tags:
  - pytorch
  - image-classification
  - medical-imaging
  - chest-x-ray
  - multi-label-classification
  - chexnet
  - vit-base-patch16-224
license: mit
datasets:
  - alkzar90/NIH-Chest-X-ray-dataset
language:
  - en
metrics:
  - accuracy
  - f1
  - roc_auc

🫁 ViT-Base-Patch16-224 - Multi-Label Chest X-ray Classification (14 Pathologies)

This model was trained for multi-label classification of 14 thoracic pathologies from chest X-rays in the ChestX-ray14 dataset.

📋 Description

  • Architecture: ViT-Base-Patch16-224
  • Task: Multi-label classification (14 pathologies)
  • Dataset: NIH Chest X-ray (ChestX-ray14)
  • Framework: PyTorch
  • Image Size: 224×224

📊 Overall Performance

| Metric | Value |
|---|---|
| AUC-ROC (macro) | 0.7595 |
| AUC-ROC (micro) | 0.8191 |
| F1 (macro) | 0.0916 |
| mAP | 0.2041 |
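
The two AUC-ROC rows differ only in how the 14 labels are pooled. Macro averaging computes one AUC per pathology and takes the unweighted mean, so a rare class such as Hernia counts as much as Infiltration. Micro averaging flattens every (image, label) pair into a single binary problem, so frequent labels dominate. The sketch below illustrates the difference on made-up toy scores (not model outputs); `auc_roc` is a hypothetical plain-Python helper, not the evaluation code used for this card.

```python
def auc_roc(y_true, y_score):
    """AUC-ROC as the probability that a positive outranks a negative (ties count 0.5)."""
    pos = [s for t, s in zip(y_true, y_score) if t == 1]
    neg = [s for t, s in zip(y_true, y_score) if t == 0]
    wins = sum(1.0 if p > n else 0.5 if p == n else 0.0 for p in pos for n in neg)
    return wins / (len(pos) * len(neg))


# Toy data: 4 images, 2 labels (rows = images, columns = labels). Values are invented.
y_true = [[1, 0], [0, 0], [1, 1], [0, 0]]
y_prob = [[0.9, 0.2], [0.4, 0.1], [0.6, 0.3], [0.7, 0.5]]
num_labels = len(y_true[0])

# Macro: one AUC per label column, then an unweighted mean.
per_label_auc = [
    auc_roc([row[j] for row in y_true], [row[j] for row in y_prob])
    for j in range(num_labels)
]
macro_auc = sum(per_label_auc) / num_labels

# Micro: pool every (image, label) pair into one binary problem.
flat_true = [t for row in y_true for t in row]
flat_prob = [p for row in y_prob for p in row]
micro_auc = auc_roc(flat_true, flat_prob)

print(f"macro={macro_auc:.4f} micro={micro_auc:.4f}")
```

In practice `sklearn.metrics.roc_auc_score` with `average="macro"` or `average="micro"` computes the same quantities on multi-label indicator arrays; the card does not say which tool produced the numbers above.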

Comparison with the CheXNet paper

| Model | AUC (macro) | Δ |
|---|---|---|
| CheXNet (paper) | 0.8414 | - |
| This model | 0.7595 | -0.0819 |

📈 Per-Pathology Performance

| Pathology | AUROC | F1 | Support |
|---|---|---|---|
| Atelectasis | 0.7253 | 0.0457 | 3279 |
| Cardiomegaly | 0.8585 | 0.2737 | 1069 |
| Effusion | 0.7918 | 0.3613 | 4658 |
| Infiltration | 0.6662 | 0.2098 | 6112 |
| Mass | 0.7563 | 0.1563 | 1748 |
| Nodule | 0.6822 | 0.0322 | 1623 |
| Pneumonia | 0.6705 | 0.0000 | 555 |
| Pneumothorax | 0.7889 | 0.0729 | 2665 |
| Consolidation | 0.7221 | 0.0129 | 1815 |
| Edema | 0.8229 | 0.0796 | 925 |
| Emphysema | 0.7516 | 0.0159 | 1093 |
| Fibrosis | 0.7701 | 0.0000 | 435 |
| Pleural_Thickening | 0.7429 | 0.0000 | 1143 |
| Hernia | 0.8836 | 0.0227 | 86 |
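
The macro figures reported earlier are plain averages of the per-pathology columns, and the gap to CheXNet is a simple difference. A short check using only the numbers printed in this card (the variable names are illustrative):

```python
# (pathology, AUROC, F1) copied from the per-pathology table in this card.
rows = [
    ("Atelectasis", 0.7253, 0.0457), ("Cardiomegaly", 0.8585, 0.2737),
    ("Effusion", 0.7918, 0.3613), ("Infiltration", 0.6662, 0.2098),
    ("Mass", 0.7563, 0.1563), ("Nodule", 0.6822, 0.0322),
    ("Pneumonia", 0.6705, 0.0000), ("Pneumothorax", 0.7889, 0.0729),
    ("Consolidation", 0.7221, 0.0129), ("Edema", 0.8229, 0.0796),
    ("Emphysema", 0.7516, 0.0159), ("Fibrosis", 0.7701, 0.0000),
    ("Pleural_Thickening", 0.7429, 0.0000), ("Hernia", 0.8836, 0.0227),
]

# Macro averages: unweighted means over the 14 classes.
macro_auroc = sum(auroc for _, auroc, _ in rows) / len(rows)
macro_f1 = sum(f1 for _, _, f1 in rows) / len(rows)

# Gap to the macro AUC quoted for CheXNet in the comparison table.
chexnet_auc = 0.8414
delta = round(macro_auroc, 4) - chexnet_auc

print(f"macro AUROC = {macro_auroc:.4f}")  # matches the reported 0.7595
print(f"macro F1    = {macro_f1:.4f}")     # matches the reported 0.0916
print(f"delta       = {delta:+.4f}")       # matches the reported -0.0819
```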

🏷️ The 14 Pathologies

| ID | Pathology |
|---|---|
| 0 | Atelectasis |
| 1 | Cardiomegaly |
| 2 | Effusion |
| 3 | Infiltration |
| 4 | Mass |
| 5 | Nodule |
| 6 | Pneumonia |
| 7 | Pneumothorax |
| 8 | Consolidation |
| 9 | Edema |
| 10 | Emphysema |
| 11 | Fibrosis |
| 12 | Pleural_Thickening |
| 13 | Hernia |
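
These IDs are the positions in the model's 14-dimensional output, so a training target is a multi-hot vector with a 1 at each pathology present. The sketch below assumes labels arrive as pipe-separated strings, as in the `Finding Labels` column of the original NIH metadata, where `No Finding` marks a normal study. The helper `encode_findings` is hypothetical, and the Hugging Face copy of the dataset may expose labels in a different form.

```python
PATHOLOGIES = [
    "Atelectasis", "Cardiomegaly", "Effusion", "Infiltration",
    "Mass", "Nodule", "Pneumonia", "Pneumothorax", "Consolidation",
    "Edema", "Emphysema", "Fibrosis", "Pleural_Thickening", "Hernia",
]
LABEL_TO_ID = {name: idx for idx, name in enumerate(PATHOLOGIES)}


def encode_findings(finding_labels: str) -> list:
    """Turn a pipe-separated label string into a 14-dim multi-hot target."""
    target = [0.0] * len(PATHOLOGIES)
    for name in finding_labels.split("|"):
        name = name.strip()
        # Unknown names, including "No Finding", leave the vector unchanged.
        if name in LABEL_TO_ID:
            target[LABEL_TO_ID[name]] = 1.0
    return target


both = encode_findings("Cardiomegaly|Effusion")
normal = encode_findings("No Finding")
print(both)
print(normal)
```

A study with no listed pathology therefore maps to the all-zero vector, which is what `BCEWithLogitsLoss` expects for a negative example on every class.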

⚙️ Training Configuration

{
  "data_variant": "full",
  "batch_size": 16,
  "image_size": 224,
  "num_classes": 14,
  "learning_rate": 0.0001,
  "num_epochs": 50,
  "scheduler": "ReduceLROnPlateau (factor=0.5, patience=5)",
  "optimizer": "AdamW (weight_decay=0.01)",
  "loss": "BCEWithLogitsLoss (unweighted)"
}
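
As a rough illustration, the configuration above maps onto standard PyTorch objects as follows. This is a sketch, not the author's training script. A small `nn.Linear` stands in for the ViT so the snippet runs without `timm`, the inputs are random, and `mode="min"` (stepping the scheduler on a loss value) is an assumption, since the card does not say which quantity the scheduler monitors.

```python
import torch
from torch import nn

NUM_CLASSES = 14
BATCH_SIZE = 16

# Stand-in for the ViT backbone (assumption, for a dependency-free sketch).
model = nn.Linear(8, NUM_CLASSES)

# Unweighted multi-label loss: one independent sigmoid/BCE term per class.
criterion = nn.BCEWithLogitsLoss()
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4, weight_decay=0.01)
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
    optimizer, mode="min", factor=0.5, patience=5
)

# One dummy optimization step on random data.
inputs = torch.randn(BATCH_SIZE, 8)
targets = torch.randint(0, 2, (BATCH_SIZE, NUM_CLASSES)).float()

optimizer.zero_grad()
loss = criterion(model(inputs), targets)
loss.backward()
optimizer.step()

# In a real loop this would be called once per epoch with the monitored metric.
scheduler.step(loss.item())
print(f"loss={loss.item():.4f} lr={optimizer.param_groups[0]['lr']}")
```

With `factor=0.5` and `patience=5`, the learning rate is halved once the monitored value has failed to improve for more than five consecutive scheduler steps.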

🚀 Usage

import timm
import torch
from PIL import Image
from torchvision import transforms

# Build the ViT-Base architecture with a 14-output head, then load the fine-tuned weights

model = timm.create_model(
    'vit_base_patch16_224',
    pretrained=False,
    num_classes=14
)
model.load_state_dict(torch.load('pytorch_model.bin', map_location='cpu'))
model.eval()

# Preprocessing: resize to 224×224 and normalize with ImageNet mean/std
transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
])

# Class names, in the same order as the model's output indices
PATHOLOGIES = [
    "Atelectasis", "Cardiomegaly", "Effusion", "Infiltration",
    "Mass", "Nodule", "Pneumonia", "Pneumothorax", "Consolidation",
    "Edema", "Emphysema", "Fibrosis", "Pleural_Thickening", "Hernia"
]

# Prediction (the image is converted to 3-channel RGB to match the ViT input)
image = Image.open('chest_xray.png').convert('RGB')
input_tensor = transform(image).unsqueeze(0)

with torch.no_grad():
    logits = model(input_tensor)
    probs = torch.sigmoid(logits)

# Print the per-pathology probabilities (independent sigmoids, so they need not sum to 1)
for name, prob in zip(PATHOLOGIES, probs[0]):
    print(f"{name}: {prob:.4f}")
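
The snippet above stops at probabilities; turning them into predicted labels requires a decision threshold. The card does not state which threshold produced the F1 column. Several classes show an F1 of 0.0000 despite an AUROC near 0.7, which is the pattern one would expect if sigmoid outputs for rare classes seldom exceed a fixed cutoff such as 0.5. One common remedy, not part of the author's pipeline as far as this card shows, is to pick a per-class threshold on a validation split. A minimal sketch with invented validation scores for a single rare class and hypothetical helpers `f1_at` and `best_threshold`:

```python
def f1_at(y_true, y_prob, threshold):
    """F1 score when predicting positive for every probability >= threshold."""
    tp = sum(1 for t, p in zip(y_true, y_prob) if t == 1 and p >= threshold)
    fp = sum(1 for t, p in zip(y_true, y_prob) if t == 0 and p >= threshold)
    fn = sum(1 for t, p in zip(y_true, y_prob) if t == 1 and p < threshold)
    denom = 2 * tp + fp + fn
    return 2 * tp / denom if denom else 0.0


def best_threshold(y_true, y_prob, candidates):
    """Return the first candidate threshold that maximizes F1."""
    return max(candidates, key=lambda th: f1_at(y_true, y_prob, th))


# Invented validation data for one rare class: positives score low but above negatives.
val_true = [0, 0, 0, 1, 0, 1, 0, 0]
val_prob = [0.02, 0.05, 0.10, 0.22, 0.08, 0.31, 0.04, 0.12]

candidates = [i / 20 for i in range(1, 20)]  # 0.05, 0.10, ..., 0.95
threshold = best_threshold(val_true, val_prob, candidates)

print(f"F1 at 0.5: {f1_at(val_true, val_prob, 0.5):.2f}")
print(f"F1 at tuned threshold {threshold:.2f}: {f1_at(val_true, val_prob, threshold):.2f}")
```

Thresholds chosen this way should be selected on validation data and then held fixed for the test set, otherwise the reported F1 is optimistic.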

📚 Citation

@inproceedings{Wang_2017,
    title = {ChestX-Ray8: Hospital-Scale Chest X-Ray Database and Benchmarks},
    author = {Wang, Xiaosong and Peng, Yifan and Lu, Le and Lu, Zhiyong and Bagheri, Mohammadhadi and Summers, Ronald M},
    booktitle = {IEEE CVPR},
    year = {2017}
}

@article{rajpurkar2017chexnet,
    title={CheXNet: Radiologist-Level Pneumonia Detection on Chest X-Rays with Deep Learning},
    author={Rajpurkar, Pranav and others},
    journal={arXiv preprint arXiv:1711.05225},
    year={2017}
}

📄 License

MIT License