---
license: cc-by-nc-4.0
tags:
- skin-lesion
- dermoscopy
- classification
- convnext
- medical-imaging
- research
datasets:
- ISIC/MILK10k
metrics:
- f1
- auc
language:
- en
pipeline_tag: image-classification
---
# ConvNeXt Dual-Modal Skin Lesion Classifier (ISIC 2025 / MILK10k)
> **Research prototype — not validated for clinical use.**
> This model is released for reproducibility and research purposes only. It must not be used to guide clinical decisions, patient triage, or any diagnostic process. See [Limitations](#limitations) and [Out-of-Scope Uses](#out-of-scope-uses).
---
## Model Description
A dual-input ConvNeXt-Base architecture trained end-to-end on the [MILK10k dataset](https://doi.org/10.34970/648456) (ISIC 2025 Challenge). The model processes a dermoscopic image and a clinical close-up photograph of the same lesion simultaneously, fusing feature representations before classification. It was developed as a research component submitted to the MedGemma Impact Challenge.
| Property | Value |
|---|---|
| Architecture | Dual ConvNeXt-Base encoders (one per imaging modality), late fusion |
| Input | Paired dermoscopic + clinical images (384×384 px each) |
| Output | Softmax probabilities over 11 ISIC diagnostic classes |
| Training | 5-fold stratified cross-validation, macro F1 optimisation |
| Ensemble | 5 models (one per fold), predictions averaged at inference |
---
## Intended Use
This model is released strictly for **non-commercial research and educational purposes**, as part of the SkinAI application submitted to the MedGemma Impact Challenge. It is provided to support reproducibility of the challenge submission and to enable further research into multi-modal skin lesion classification.
**Intended users:** Researchers and developers working on dermatology AI, machine learning in medical imaging, or related computational fields.
---
## Out-of-Scope Uses
The following uses are explicitly out of scope and are **not supported**:
- **Clinical diagnosis or decision support** — the model has not been validated for clinical deployment and must not influence patient care in any setting.
- **Patient triage or screening** — performance has only been evaluated on held-out folds of the MILK10k training distribution; generalisability to other populations, imaging devices, or clinical workflows is unknown.
- **Autonomous or semi-autonomous medical decision making** — any application in which model outputs could directly or indirectly affect patient management.
- **Deployment without independent clinical validation** — any production use would require prospective validation by qualified clinicians under appropriate regulatory oversight.
The performance metrics reported below reflect internal cross-validation on a single dataset and are **not sufficient evidence of clinical utility**.
---
## Diagnostic Classes
| Class | Description |
|---|---|
| AKIEC | Actinic keratosis / intraepithelial carcinoma |
| BCC | Basal cell carcinoma |
| BEN_OTH | Other benign lesion |
| BKL | Benign keratosis |
| DF | Dermatofibroma |
| INF | Inflammatory / infectious |
| MAL_OTH | Other malignant lesion |
| MEL | Melanoma |
| NV | Melanocytic nevus |
| SCCKA | Squamous cell carcinoma / keratoacanthoma |
| VASC | Vascular lesion |
---
## Performance
> **Important caveat:** All metrics below are from held-out validation folds of the MILK10k training dataset using 5-fold stratified cross-validation. They represent performance under distribution-matched conditions and should not be interpreted as estimates of real-world clinical performance. External validation has not been performed.
### Aggregate Metrics
| Metric | Value |
|---|---|
| Balanced Multiclass Accuracy | 0.665 |
| Macro F1 (ConvNeXt alone) | 0.555 |
| Macro F1 (MedSigLIP + ConvNeXt ensemble) | 0.591 |
| ISIC 2025 Leaderboard Score (Dice) | 0.538 |
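For reference, the aggregate metrics above follow the standard scikit-learn definitions. A minimal sketch of how they are computed (the label arrays here are illustrative toy data, not the actual validation outputs):

```python
import numpy as np
from sklearn.metrics import balanced_accuracy_score, f1_score

# Illustrative labels and predictions -- not real MILK10k validation outputs.
y_true = np.array([0, 0, 1, 1, 2, 2, 2, 1])
y_pred = np.array([0, 1, 1, 1, 2, 0, 2, 1])

# Balanced multiclass accuracy: unweighted mean of per-class recalls.
bal_acc = balanced_accuracy_score(y_true, y_pred)

# Macro F1: unweighted mean of per-class F1 scores.
macro_f1 = f1_score(y_true, y_pred, average="macro")

print(f"balanced accuracy: {bal_acc:.3f}, macro F1: {macro_f1:.3f}")
```

Both metrics weight every class equally regardless of prevalence, which is why they are preferred over plain accuracy on an imbalanced dataset like MILK10k.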
### Per-Class Metrics (Validation, Single ConvNeXt Fold)
| Class | AUC | AUC (Sens>80%) | Avg Precision | Sensitivity | Specificity | Dice | PPV | NPV |
|---|---|---|---|---|---|---|---|---|
| AKIEC | 0.933 | 0.873 | 0.704 | 0.732 | 0.924 | 0.675 | 0.627 | 0.952 |
| BCC | 0.975 | 0.960 | 0.838 | 0.951 | 0.919 | 0.758 | 0.630 | 0.992 |
| BEN_OTH | 0.978 | 0.953 | 0.505 | 0.429 | 0.998 | 0.545 | 0.750 | 0.992 |
| BKL | 0.881 | 0.713 | 0.746 | 0.750 | 0.865 | 0.664 | 0.595 | 0.929 |
| DF | 0.986 | 0.983 | 0.536 | 0.833 | 0.992 | 0.667 | 0.556 | 0.998 |
| INF | 0.841 | 0.722 | 0.164 | 0.364 | 0.985 | 0.364 | 0.364 | 0.985 |
| MAL_OTH | 0.820 | 0.717 | 0.518 | 0.400 | 0.993 | 0.571 | 1.000 | 0.987 |
| MEL | 0.957 | 0.935 | 0.820 | 0.821 | 0.950 | 0.688 | 0.593 | 0.984 |
| NV | 0.960 | 0.948 | 0.845 | 0.865 | 0.963 | 0.796 | 0.738 | 0.983 |
| SCCKA | 0.949 | 0.911 | 0.857 | 0.863 | 0.903 | 0.798 | 0.743 | 0.953 |
| VASC | 0.993 | 0.991 | 0.614 | 0.800 | 0.994 | 0.667 | 0.571 | 0.998 |
| **Mean** | **0.934** | **0.883** | **0.650** | **0.710** | **0.954** | **0.654** | **0.651** | **0.978** |
> Rare classes (INF: ~11 lesions, MAL_OTH: ~15 lesions, VASC: ~15 lesions) are severely underrepresented in MILK10k. Sensitivity figures for these classes should be interpreted with caution given the small sample sizes involved.
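The `AUC (Sens>80%)` column reports the ROC area restricted to the high-sensitivity operating region, in the spirit of the ISIC partial-AUC metric. One way such a constrained area can be computed is sketched below; the exact normalisation used for the table is not stated, so treat this as illustrative, and `auc_above_sensitivity` is a hypothetical helper name:

```python
import numpy as np
from sklearn.metrics import auc, roc_curve

def auc_above_sensitivity(y_true, y_score, min_tpr=0.8):
    """Partial ROC area restricted to sensitivity >= min_tpr.

    Returns the raw area between the ROC curve and the line tpr = min_tpr,
    so a perfect classifier scores (1 - min_tpr) = 0.2 here. The table above
    appears to use a rescaled variant; the normalisation is not specified.
    """
    fpr, tpr, _ = roc_curve(y_true, y_score)
    # Clip the curve so operating points below the sensitivity floor
    # contribute zero area.
    tpr_clipped = np.maximum(tpr, min_tpr)
    return auc(fpr, tpr_clipped - min_tpr)

# Perfectly separable toy scores attain the full raw area of 0.2.
pauc = auc_above_sensitivity([0, 0, 1, 1], [0.1, 0.2, 0.8, 0.9])
print(round(pauc, 6))
```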
---
## Usage
This code is provided for research reproducibility. Users are responsible for ensuring any application complies with applicable laws and ethical guidelines.
```python
import torch
import torch.nn.functional as F
import timm
import torch.nn as nn
from PIL import Image
import torchvision.transforms as transforms
from huggingface_hub import hf_hub_download
# --- Model Definition ---
class DualConvNeXt(nn.Module):
    def __init__(self, num_classes=11, model_name='convnext_base'):
        super().__init__()
        # Separate encoders for the clinical close-up and the dermoscopic view
        self.clinical_encoder = timm.create_model(
            model_name, pretrained=False, num_classes=0
        )
        self.derm_encoder = timm.create_model(
            model_name, pretrained=False, num_classes=0
        )
        feat_dim = self.clinical_encoder.num_features
        self.classifier = nn.Sequential(
            nn.Linear(feat_dim * 2, 512),
            nn.ReLU(),
            nn.Dropout(0.3),
            nn.Linear(512, num_classes)
        )

    def forward(self, clinical, derm):
        c = self.clinical_encoder(clinical)
        d = self.derm_encoder(derm)
        # Late fusion: concatenate both feature vectors before classification
        return self.classifier(torch.cat([c, d], dim=1))

# --- Load Model ---
CLASS_NAMES = ['AKIEC', 'BCC', 'BEN_OTH', 'BKL', 'DF',
               'INF', 'MAL_OTH', 'MEL', 'NV', 'SCCKA', 'VASC']

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = DualConvNeXt(num_classes=11)

weights_path = hf_hub_download(
    repo_id="tech-doc/ConvNeXt_Milk10k",
    filename="convnext_fold0_best.pth"
)
checkpoint = torch.load(weights_path, map_location=device)
model.load_state_dict(checkpoint['model_state_dict'])
model.eval().to(device)

# --- Preprocessing ---
transform = transforms.Compose([
    transforms.Resize((384, 384)),
    transforms.ToTensor(),
    transforms.Normalize(
        mean=[0.485, 0.456, 0.406],
        std=[0.229, 0.224, 0.225]
    )
])

# --- Inference ---
def predict(clinical_image_path: str, derm_image_path: str) -> dict:
    """
    Research inference only. Output must not be used for clinical decisions.

    Args:
        clinical_image_path: Path to clinical close-up photograph
        derm_image_path: Path to dermoscopic image

    Returns:
        dict with 'prediction', 'confidence', and 'probabilities'
    """
    clinical = transform(Image.open(clinical_image_path).convert('RGB')).unsqueeze(0).to(device)
    derm = transform(Image.open(derm_image_path).convert('RGB')).unsqueeze(0).to(device)
    with torch.no_grad():
        logits = model(clinical, derm)
        probs = F.softmax(logits, dim=1).squeeze().cpu().numpy()
    return {
        'prediction': CLASS_NAMES[probs.argmax()],
        'confidence': float(probs.max()),
        'probabilities': {c: float(p) for c, p in zip(CLASS_NAMES, probs)}
    }

# Example
result = predict('clinical.jpg', 'dermoscopy.jpg')
print(f"Prediction: {result['prediction']} ({result['confidence']:.1%})")
```
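The snippet above loads a single fold checkpoint (`fold0`). The challenge submission averaged softmax probabilities across all five fold models (soft voting); the combination step itself is just a per-class mean, sketched here with made-up probability vectors rather than real model outputs:

```python
import numpy as np

def soft_vote(fold_probs):
    """Average per-fold class-probability vectors (soft voting)."""
    return np.mean(fold_probs, axis=0)

# Illustrative 3-class outputs from three hypothetical fold models.
ensembled = soft_vote([
    np.array([0.6, 0.3, 0.1]),
    np.array([0.5, 0.4, 0.1]),
    np.array([0.4, 0.4, 0.2]),
])
print(ensembled.argmax())  # index of the ensembled prediction
```

In practice this means running `predict` once per fold checkpoint and averaging the resulting `probabilities` before taking the argmax.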
---
## Training Details
| Parameter | Value |
|---|---|
| Base model | `convnext_base` (ImageNet-22k pretrained via `timm`) |
| Image size | 384×384 px |
| Batch size | 32 |
| Optimiser | AdamW, lr=1e-4 |
| Scheduler | Cosine annealing with warm restarts |
| Loss | Cross-entropy with class weights + focal loss |
| Augmentation | Random flips, rotations, colour jitter, RandAugment |
| Folds | 5-fold stratified CV (seed 42) |
| Hardware | NVIDIA A100 (Google Colab) |
| Training time | ~4–6 hours per fold |
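The loss row above combines class-weighted cross-entropy with a focal term. A minimal sketch of one common way to implement that combination in PyTorch; the `gamma` value and the exact weighting scheme used in training are not stated, so both are assumptions:

```python
import torch
import torch.nn.functional as F

def weighted_focal_loss(logits, targets, class_weights, gamma=2.0):
    """Class-weighted focal loss: scales the weighted cross-entropy of each
    sample by (1 - p_t)^gamma, down-weighting easy examples.
    gamma=2.0 is the common default; the value used in training is not stated.
    """
    ce = F.cross_entropy(logits, targets, weight=class_weights, reduction='none')
    # p_t is taken from the unweighted loss so the focal term reflects
    # the model's actual confidence in the true class.
    p_t = torch.exp(-F.cross_entropy(logits, targets, reduction='none'))
    return ((1.0 - p_t) ** gamma * ce).mean()

# Illustrative usage with random logits over the 11 classes.
logits = torch.randn(4, 11)
targets = torch.tensor([0, 3, 7, 10])
weights = torch.ones(11)  # real training presumably used inverse-frequency weights (assumed)
loss = weighted_focal_loss(logits, targets, weights)
```

With `gamma=0` and uniform weights this reduces to plain cross-entropy, which makes the focal term easy to sanity-check.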
---
## Limitations
- **Single-dataset evaluation:** Trained and evaluated exclusively on MILK10k (~5,240 lesions). No external validation has been performed. Reported metrics should not be generalised beyond this distribution.
- **Severe class imbalance:** Rare classes (INF: ~11 lesions, MAL_OTH: ~15 lesions, VASC: ~15 lesions) are underrepresented. Performance on these classes is highly uncertain and may not be reproducible on different samples.
- **Paired-image requirement:** The model requires simultaneous dermoscopic and clinical photographs of the same lesion. Single-image inference is architecturally unsupported and was not evaluated.
- **Skin tone representation:** The MILK10k dataset composition with respect to Fitzpatrick phototype has not been fully characterised. Performance across darker skin tones (Fitzpatrick IV–VI) has not been validated.
- **Paediatric populations:** The model was not evaluated on paediatric patients.
- **Device variability:** Performance may degrade with imaging devices, magnifications, or lighting conditions not represented in the training data.
- **No prospective validation:** All reported metrics are from retrospective cross-validation. Prospective clinical validation would be required before any consideration of real-world use.
---
## Citation
If you use this model or the MILK10k dataset in your research, please cite:
```bibtex
@dataset{milk10k2025,
author = {MILK study team},
title = {MILK10k},
year = {2025},
publisher = {ISIC Archive},
doi = {10.34970/648456}
}
```
---
## License
**CC BY-NC 4.0** — This model was trained on MILK10k data (CC-BY-NC licensed). Non-commercial research use only. Any commercial application is prohibited without explicit permission.