---
license: cc-by-nc-4.0
tags:
  - skin-lesion
  - dermoscopy
  - classification
  - convnext
  - medical-imaging
  - research
datasets:
  - ISIC/MILK10k
metrics:
  - f1
  - auc
language:
  - en
pipeline_tag: image-classification
---

# ConvNeXt Dual-Modal Skin Lesion Classifier (ISIC 2025 / MILK10k)

> **Research prototype — not validated for clinical use.**
> This model is released for reproducibility and research purposes only. It must not be used to guide clinical decisions, patient triage, or any diagnostic process. See [Limitations](#limitations) and [Out-of-Scope Uses](#out-of-scope-uses).

---

## Model Description

A dual-input ConvNeXt-Base architecture trained end-to-end on the [MILK10k dataset](https://doi.org/10.34970/648456) (ISIC 2025 Challenge). The model processes a dermoscopic image and a clinical close-up photograph of the same lesion simultaneously, fusing feature representations before classification. It was developed as a research component submitted to the MedGemma Impact Challenge.

| Property | Value |
|---|---|
| Architecture | Dual ConvNeXt-Base (one encoder per modality), late fusion |
| Input | Paired dermoscopic + clinical images (384×384 px each) |
| Output | Softmax probabilities over 11 ISIC diagnostic classes |
| Training | 5-fold stratified cross-validation, macro F1 optimisation |
| Ensemble | 5 models (one per fold), predictions averaged at inference |

---

## Intended Use

This model is released strictly for **non-commercial research and educational purposes**, as part of the SkinAI application submitted to the MedGemma Impact Challenge. It is provided to support reproducibility of the challenge submission and to enable further research into multi-modal skin lesion classification.

**Intended users:** Researchers and developers working on dermatology AI, machine learning in medical imaging, or related computational fields.

---

## Out-of-Scope Uses

The following uses are explicitly out of scope and are **not supported**:

- **Clinical diagnosis or decision support** — the model has not been validated for clinical deployment and must not influence patient care in any setting.
- **Patient triage or screening** — performance has only been evaluated on held-out folds of the MILK10k training distribution; generalisability to other populations, imaging devices, or clinical workflows is unknown.
- **Autonomous or semi-autonomous medical decision making** — any application in which model outputs could directly or indirectly affect patient management.
- **Deployment without independent clinical validation** — any production use would require prospective validation by qualified clinicians under appropriate regulatory oversight.

The performance metrics reported below reflect internal cross-validation on a single dataset and are **not sufficient evidence of clinical utility**.

---

## Diagnostic Classes

| Class | Description |
|---|---|
| AKIEC | Actinic keratosis / intraepithelial carcinoma |
| BCC | Basal cell carcinoma |
| BEN_OTH | Other benign lesion |
| BKL | Benign keratosis |
| DF | Dermatofibroma |
| INF | Inflammatory / infectious |
| MAL_OTH | Other malignant lesion |
| MEL | Melanoma |
| NV | Melanocytic nevus |
| SCCKA | Squamous cell carcinoma / keratoacanthoma |
| VASC | Vascular lesion |

---

## Performance

> **Important caveat:** All metrics below are from held-out validation folds of the MILK10k training dataset using 5-fold stratified cross-validation. They represent performance under distribution-matched conditions and should not be interpreted as estimates of real-world clinical performance. External validation has not been performed.

### Aggregate Metrics

| Metric | Value |
|---|---|
| Balanced Multiclass Accuracy | 0.665 |
| Macro F1 (ConvNeXt alone) | 0.555 |
| Macro F1 (MedSigLIP + ConvNeXt ensemble) | 0.591 |
| ISIC 2025 Leaderboard Score (Dice) | 0.538 |

### Per-Class Metrics (Validation, Single ConvNeXt Fold)

| Class | AUC | AUC (Sens>80%) | Avg Precision | Sensitivity | Specificity | Dice | PPV | NPV |
|---|---|---|---|---|---|---|---|---|
| AKIEC | 0.933 | 0.873 | 0.704 | 0.732 | 0.924 | 0.675 | 0.627 | 0.952 |
| BCC | 0.975 | 0.960 | 0.838 | 0.951 | 0.919 | 0.758 | 0.630 | 0.992 |
| BEN_OTH | 0.978 | 0.953 | 0.505 | 0.429 | 0.998 | 0.545 | 0.750 | 0.992 |
| BKL | 0.881 | 0.713 | 0.746 | 0.750 | 0.865 | 0.664 | 0.595 | 0.929 |
| DF | 0.986 | 0.983 | 0.536 | 0.833 | 0.992 | 0.667 | 0.556 | 0.998 |
| INF | 0.841 | 0.722 | 0.164 | 0.364 | 0.985 | 0.364 | 0.364 | 0.985 |
| MAL_OTH | 0.820 | 0.717 | 0.518 | 0.400 | 0.993 | 0.571 | 1.000 | 0.987 |
| MEL | 0.957 | 0.935 | 0.820 | 0.821 | 0.950 | 0.688 | 0.593 | 0.984 |
| NV | 0.960 | 0.948 | 0.845 | 0.865 | 0.963 | 0.796 | 0.738 | 0.983 |
| SCCKA | 0.949 | 0.911 | 0.857 | 0.863 | 0.903 | 0.798 | 0.743 | 0.953 |
| VASC | 0.993 | 0.991 | 0.614 | 0.800 | 0.994 | 0.667 | 0.571 | 0.998 |
| **Mean** | **0.934** | **0.883** | **0.650** | **0.710** | **0.954** | **0.654** | **0.651** | **0.978** |

> Rare classes (INF: ~11 lesions, MAL_OTH: ~15 lesions, VASC: ~15 lesions) are severely underrepresented in MILK10k. Sensitivity figures for these classes should be interpreted with caution given the small sample sizes involved.
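
For readers unfamiliar with the headline metric above, balanced multiclass accuracy is the mean of per-class sensitivities (recalls), which is why it is less distorted by the class imbalance just noted than plain accuracy. A minimal sketch on a toy confusion matrix (illustrative numbers, not MILK10k results):

```python
import numpy as np

def balanced_accuracy(confusion: np.ndarray) -> float:
    """Mean of per-class recall (sensitivity) over all classes.

    Rows of `confusion` are true classes, columns are predictions.
    """
    per_class_recall = confusion.diagonal() / confusion.sum(axis=1)
    return float(per_class_recall.mean())

# Toy 3-class confusion matrix
cm = np.array([[8, 1, 1],
               [2, 6, 2],
               [0, 0, 10]])
print(balanced_accuracy(cm))  # (0.8 + 0.6 + 1.0) / 3 = 0.8
```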

---

## Usage

This code is provided for research reproducibility. Users are responsible for ensuring any application complies with applicable laws and ethical guidelines.

```python
import torch
import torch.nn.functional as F
import timm
import torch.nn as nn
from PIL import Image
import torchvision.transforms as transforms
from huggingface_hub import hf_hub_download


# --- Model Definition ---

class DualConvNeXt(nn.Module):
    def __init__(self, num_classes=11, model_name='convnext_base'):
        super().__init__()
        self.clinical_encoder = timm.create_model(
            model_name, pretrained=False, num_classes=0
        )
        self.derm_encoder = timm.create_model(
            model_name, pretrained=False, num_classes=0
        )
        feat_dim = self.clinical_encoder.num_features
        self.classifier = nn.Sequential(
            nn.Linear(feat_dim * 2, 512),
            nn.ReLU(),
            nn.Dropout(0.3),
            nn.Linear(512, num_classes)
        )

    def forward(self, clinical, derm):
        c = self.clinical_encoder(clinical)
        d = self.derm_encoder(derm)
        return self.classifier(torch.cat([c, d], dim=1))


# --- Load Model ---

CLASS_NAMES = ['AKIEC', 'BCC', 'BEN_OTH', 'BKL', 'DF',
               'INF', 'MAL_OTH', 'MEL', 'NV', 'SCCKA', 'VASC']

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

model = DualConvNeXt(num_classes=11)

weights_path = hf_hub_download(
    repo_id="tech-doc/ConvNeXt_Milk10k",
    filename="convnext_fold0_best.pth"
)
checkpoint = torch.load(weights_path, map_location=device)
model.load_state_dict(checkpoint['model_state_dict'])
model.eval().to(device)


# --- Preprocessing ---

transform = transforms.Compose([
    transforms.Resize((384, 384)),
    transforms.ToTensor(),
    transforms.Normalize(
        mean=[0.485, 0.456, 0.406],
        std=[0.229, 0.224, 0.225]
    )
])


# --- Inference ---

def predict(clinical_image_path: str, derm_image_path: str) -> dict:
    """
    Research inference only. Output must not be used for clinical decisions.

    Args:
        clinical_image_path: Path to clinical close-up photograph
        derm_image_path: Path to dermoscopic image

    Returns:
        dict with 'prediction', 'confidence', and 'probabilities'
    """
    clinical = transform(Image.open(clinical_image_path).convert('RGB')).unsqueeze(0).to(device)
    derm = transform(Image.open(derm_image_path).convert('RGB')).unsqueeze(0).to(device)

    with torch.no_grad():
        logits = model(clinical, derm)
        probs = F.softmax(logits, dim=1).squeeze().cpu().numpy()

    return {
        'prediction': CLASS_NAMES[probs.argmax()],
        'confidence': float(probs.max()),
        'probabilities': {c: float(p) for c, p in zip(CLASS_NAMES, probs)}
    }


# Example
result = predict('clinical.jpg', 'dermoscopy.jpg')
print(f"Prediction: {result['prediction']} ({result['confidence']:.1%})")
```
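
The model description above notes that the five fold models are ensembled by averaging predictions at inference. A minimal sketch of that mean-ensemble step, assuming each fold yields a softmax probability vector as in `predict` above (the per-fold checkpoint filenames beyond `convnext_fold0_best.pth` are not listed in this card):

```python
import numpy as np

def ensemble_probs(fold_probs: list) -> np.ndarray:
    """Simple mean ensemble: average softmax outputs across fold models."""
    return np.mean(np.stack(fold_probs), axis=0)

# Toy example: two folds, three classes
p0 = np.array([0.7, 0.2, 0.1])
p1 = np.array([0.5, 0.3, 0.2])
print(ensemble_probs([p0, p1]))  # [0.6, 0.25, 0.15]
```

The averaged vector is then used exactly like a single model's `probs` (argmax for the predicted class, max for the confidence).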

---

## Training Details

| Parameter | Value |
|---|---|
| Base model | `convnext_base` (ImageNet-22k pretrained via `timm`) |
| Image size | 384×384 px |
| Batch size | 32 |
| Optimiser | AdamW, lr=1e-4 |
| Scheduler | Cosine annealing with warm restarts |
| Loss | Cross-entropy with class weights + focal loss |
| Augmentation | Random flips, rotations, colour jitter, RandAugment |
| Folds | 5-fold stratified CV (seed 42) |
| Hardware | NVIDIA A100 (Google Colab) |
| Training time | ~4–6 hours per fold |
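
The loss row above combines class-weighted cross-entropy with a focal term. A minimal focal-loss sketch consistent with that description; `gamma=2.0` and the default of no class weights are illustrative assumptions, not the values used in training:

```python
import torch
import torch.nn.functional as F

def focal_loss(logits, targets, class_weights=None, gamma=2.0):
    """Focal loss: cross-entropy scaled by (1 - p_t)^gamma, which
    down-weights easy examples and focuses training on hard ones.
    `gamma` and `class_weights` here are illustrative defaults."""
    log_probs = F.log_softmax(logits, dim=1)
    probs = log_probs.exp()
    # Per-sample (optionally class-weighted) cross-entropy
    ce = F.nll_loss(log_probs, targets, weight=class_weights, reduction='none')
    # Probability assigned to the true class of each sample
    p_t = probs.gather(1, targets.unsqueeze(1)).squeeze(1)
    return ((1 - p_t) ** gamma * ce).mean()

# Confident, correct predictions incur near-zero loss
logits = torch.tensor([[5.0, -5.0], [-5.0, 5.0]])
targets = torch.tensor([0, 1])
print(focal_loss(logits, targets))
```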

---

## Limitations

- **Single-dataset evaluation:** Trained and evaluated exclusively on MILK10k (~5,240 lesions). No external validation has been performed. Reported metrics should not be generalised beyond this distribution.
- **Severe class imbalance:** Rare classes (INF: ~11 lesions, MAL_OTH: ~15 lesions, VASC: ~15 lesions) are underrepresented. Performance on these classes is highly uncertain and may not be reproducible on different samples.
- **Paired-image requirement:** The model requires simultaneous dermoscopic and clinical photographs of the same lesion. Single-image inference is architecturally unsupported and was not evaluated.
- **Skin tone representation:** The MILK10k dataset composition with respect to Fitzpatrick phototype has not been fully characterised. Performance across darker skin tones (Fitzpatrick IV–VI) has not been validated.
- **Paediatric populations:** The model was not evaluated on paediatric patients.
- **Device variability:** Performance may degrade with imaging devices, magnifications, or lighting conditions not represented in the training data.
- **No prospective validation:** All reported metrics are from retrospective cross-validation. Prospective clinical validation would be required before any consideration of real-world use.

---

## Citation

If you use this model or the MILK10k dataset in your research, please cite:

```bibtex
@dataset{milk10k2025,
  author    = {MILK study team},
  title     = {MILK10k},
  year      = {2025},
  publisher = {ISIC Archive},
  doi       = {10.34970/648456}
}
```

---

## License

**CC BY-NC 4.0** — This model was trained on MILK10k data (CC-BY-NC licensed). Non-commercial research use only. Any commercial application is prohibited without explicit permission.