---
license: cc-by-nc-4.0
tags:
- medical-imaging
- dermatology
- skin-lesion-classification
- convnext
- isic
- multi-modal
base_model: facebook/convnext-base-224-22k-1k
metrics:
- balanced-accuracy
- macro-f1
- auc-roc
language:
- en
pipeline_tag: image-classification
---

# ConvNeXt Dual-Modal Skin Lesion Classifier (ISIC 2025 / MILK10k)

This model classifies skin lesions into 11 diagnostic categories using paired dermoscopic and clinical photographs. It is part of the **Skin AI** application, where MedGemma calls it as a tool to obtain structured skin lesion classifications.

## Model Description

A dual-input ConvNeXt-Base architecture trained end-to-end on the MILK10k dataset (ISIC 2025 Challenge). The model processes a dermoscopic image and a clinical close-up photograph of the same lesion simultaneously, fusing their representations before classification.

- **Architecture:** Dual ConvNeXt-Base encoders (one per modality), late fusion
- **Input:** Paired dermoscopic + clinical images (384×384 px)
- **Output:** Softmax probabilities over 11 ISIC diagnostic classes
- **Training:** 5-fold cross-validation, macro F1 optimisation
- **Ensemble:** 5 models (one per fold), predictions averaged

## Intended Use

This model is intended for **research use only** as a component of the Skin AI application submitted to the MedGemma Impact Challenge. It is not validated for clinical use and must not be used to guide diagnosis or patient management.

**Intended users:** Researchers and developers building medical AI applications.

**Out of scope:** Direct clinical decision support, patient triage, or any deployment without further validation by qualified clinicians.

## Diagnostic Classes

| Class | Description |
|-------|-------------|
| AKIEC | Actinic keratosis / intraepithelial carcinoma |
| BCC | Basal cell carcinoma |
| BEN_OTH | Other benign lesion |
| BKL | Benign keratosis |
| DF | Dermatofibroma |
| INF | Inflammatory / infectious |
| MAL_OTH | Other malignant lesion |
| MEL | Melanoma |
| NV | Melanocytic nevus |
| SCCKA | Squamous cell carcinoma / keratoacanthoma |
| VASC | Vascular lesion |

## Performance

Evaluated on held-out validation folds from the MILK10k training data (5-fold cross-validation, stratified by lesion diagnosis).

### Aggregate Metrics

| Metric | Value |
|--------|-------|
| **Balanced multiclass accuracy** | **0.665** |
| Macro F1 (ConvNeXt alone) | 0.555 |
| Macro F1 (MedSigLIP + ConvNeXt ensemble) | 0.591 |
| ISIC 2025 leaderboard (Dice) | 0.538 |

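The two aggregate metrics can be reproduced from out-of-fold predictions with scikit-learn. A minimal sketch (the labels below are illustrative placeholders, not real evaluation data):

```python
from sklearn.metrics import balanced_accuracy_score, f1_score

# Illustrative placeholder labels -- real evaluation uses the
# out-of-fold predictions from the 5-fold cross-validation.
y_true = [0, 0, 1, 1, 2, 2, 2]
y_pred = [0, 1, 1, 1, 2, 2, 0]

# Balanced accuracy = mean per-class recall, so rare classes
# count as much as common ones.
bal_acc = balanced_accuracy_score(y_true, y_pred)

# Macro F1 averages per-class F1 scores with equal weight.
macro_f1 = f1_score(y_true, y_pred, average='macro')

print(f"balanced accuracy: {bal_acc:.3f}, macro F1: {macro_f1:.3f}")
# → balanced accuracy: 0.722, macro F1: 0.700
```

Both metrics weight every class equally, which is why they are preferred here over plain accuracy given the heavy class imbalance in MILK10k.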
### Per-Class Metrics (Validation, Single ConvNeXt)

| Class | AUC | AUC (Sens > 80%) | Avg. Precision | Sensitivity | Specificity | Dice | PPV | NPV |
|-------|-----|------------------|----------------|-------------|-------------|------|-----|-----|
| AKIEC | 0.933 | 0.873 | 0.704 | 0.732 | 0.924 | 0.675 | 0.627 | 0.952 |
| BCC | 0.975 | 0.960 | 0.838 | 0.951 | 0.919 | 0.758 | 0.630 | 0.992 |
| BEN_OTH | 0.978 | 0.953 | 0.505 | 0.429 | 0.998 | 0.545 | 0.750 | 0.992 |
| BKL | 0.881 | 0.713 | 0.746 | 0.750 | 0.865 | 0.664 | 0.595 | 0.929 |
| DF | 0.986 | 0.983 | 0.536 | 0.833 | 0.992 | 0.667 | 0.556 | 0.998 |
| INF | 0.841 | 0.722 | 0.164 | 0.364 | 0.985 | 0.364 | 0.364 | 0.985 |
| MAL_OTH | 0.820 | 0.717 | 0.518 | 0.400 | 0.993 | 0.571 | 1.000 | 0.987 |
| MEL | 0.957 | 0.935 | 0.820 | 0.821 | 0.950 | 0.688 | 0.593 | 0.984 |
| NV | 0.960 | 0.948 | 0.845 | 0.865 | 0.963 | 0.796 | 0.738 | 0.983 |
| SCCKA | 0.949 | 0.911 | 0.857 | 0.863 | 0.903 | 0.798 | 0.743 | 0.953 |
| VASC | 0.993 | 0.991 | 0.614 | 0.800 | 0.994 | 0.667 | 0.571 | 0.998 |
| **Mean** | **0.934** | **0.883** | **0.650** | **0.710** | **0.954** | **0.654** | **0.651** | **0.978** |

> **Note:** Rare classes (INF, MAL_OTH, BEN_OTH) show lower sensitivity due to class imbalance in the MILK10k dataset.

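For reference, the threshold-based columns in the table follow the standard one-vs-rest definitions (the AUC columns additionally require full score distributions and are not shown here). A self-contained sketch with made-up counts:

```python
def per_class_metrics(tp, fp, fn, tn):
    """One-vs-rest metrics matching the columns in the table above."""
    return {
        'sensitivity': tp / (tp + fn),        # recall / true positive rate
        'specificity': tn / (tn + fp),        # true negative rate
        'ppv': tp / (tp + fp),                # positive predictive value
        'npv': tn / (tn + fn),                # negative predictive value
        'dice': 2 * tp / (2 * tp + fp + fn),  # Dice == per-class F1
    }

# Made-up counts for illustration only.
m = per_class_metrics(tp=30, fp=10, fn=7, tn=953)
print({k: round(v, 3) for k, v in m.items()})
```

Note how NPV stays high for every class simply because negatives dominate in a one-vs-rest split, which is visible in the table as well.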
## Usage

```python
import torch
import torch.nn as nn
import torch.nn.functional as F
import timm
from PIL import Image
import torchvision.transforms as transforms
from huggingface_hub import hf_hub_download

# --- Model Definition ---

class DualConvNeXt(nn.Module):
    def __init__(self, num_classes=11, model_name='convnext_base'):
        super().__init__()
        self.clinical_encoder = timm.create_model(
            model_name, pretrained=False, num_classes=0
        )
        self.derm_encoder = timm.create_model(
            model_name, pretrained=False, num_classes=0
        )
        feat_dim = self.clinical_encoder.num_features
        self.classifier = nn.Sequential(
            nn.Linear(feat_dim * 2, 512),
            nn.ReLU(),
            nn.Dropout(0.3),
            nn.Linear(512, num_classes)
        )

    def forward(self, clinical, derm):
        c = self.clinical_encoder(clinical)
        d = self.derm_encoder(derm)
        return self.classifier(torch.cat([c, d], dim=1))


# --- Load Model ---

CLASS_NAMES = ['AKIEC', 'BCC', 'BEN_OTH', 'BKL', 'DF',
               'INF', 'MAL_OTH', 'MEL', 'NV', 'SCCKA', 'VASC']

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

model = DualConvNeXt(num_classes=11)

# Load weights (update repo_id to your HF repo)
weights_path = hf_hub_download(
    repo_id="tech-doc/ConvNeXt_Milk10k",
    filename="convnext_fold0_best.pth"
)
checkpoint = torch.load(weights_path, map_location=device)
model.load_state_dict(checkpoint['model_state_dict'])
model.eval().to(device)


# --- Preprocessing ---

transform = transforms.Compose([
    transforms.Resize((384, 384)),
    transforms.ToTensor(),
    transforms.Normalize(
        mean=[0.485, 0.456, 0.406],
        std=[0.229, 0.224, 0.225]
    )
])


# --- Inference ---

def predict(clinical_image_path: str, derm_image_path: str) -> dict:
    """
    Classify a skin lesion from paired images.

    Args:
        clinical_image_path: Path to clinical close-up photograph
        derm_image_path: Path to dermoscopic image

    Returns:
        dict with 'prediction' (class name), 'confidence' (float),
        and 'probabilities' (per-class dict)
    """
    clinical = transform(Image.open(clinical_image_path).convert('RGB')).unsqueeze(0).to(device)
    derm = transform(Image.open(derm_image_path).convert('RGB')).unsqueeze(0).to(device)

    with torch.no_grad():
        logits = model(clinical, derm)
        probs = F.softmax(logits, dim=1).squeeze(0).cpu().numpy()

    return {
        'prediction': CLASS_NAMES[probs.argmax()],
        'confidence': float(probs.max()),
        'probabilities': {c: float(p) for c, p in zip(CLASS_NAMES, probs)}
    }


# Example
result = predict('clinical.jpg', 'dermoscopy.jpg')
print(f"Prediction: {result['prediction']} ({result['confidence']:.1%})")
```
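The reported ensemble averages softmax probabilities across the five per-fold checkpoints rather than using a single fold. A sketch of that averaging step, usable with a list of loaded `DualConvNeXt` instances (checkpoint filenames for folds other than `convnext_fold0_best.pth` are an assumption; adjust to the files in the repo):

```python
import torch
import torch.nn.functional as F

@torch.no_grad()
def ensemble_predict(models, clinical, derm):
    """Average softmax probabilities across per-fold models.

    `models` is a list of loaded models taking (clinical, derm) batches,
    e.g. one DualConvNeXt per fold checkpoint.
    """
    probs = torch.stack([F.softmax(m(clinical, derm), dim=1) for m in models])
    return probs.mean(dim=0)  # shape: (batch, num_classes)
```

Averaging probabilities (rather than logits or hard votes) keeps the output a valid distribution and tends to smooth over fold-specific errors.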

## Training Details

- **Base model:** `convnext_base` (ImageNet-22k pretrained, via timm)
- **Image size:** 384×384
- **Batch size:** 32
- **Optimiser:** AdamW, lr = 1e-4
- **Scheduler:** Cosine annealing with warm restarts
- **Loss:** Cross-entropy with class weights + focal loss
- **Augmentation:** Random flips, rotations, colour jitter, RandAugment
- **Folds:** 5-fold stratified CV (seed 42)
- **GPU:** NVIDIA A100 (Google Colab)
- **Training time:** ~4–6 hours per fold

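The class-weighted cross-entropy + focal loss combination listed above can be sketched as follows. This is a minimal sketch under assumptions: `gamma=2.0` and an equal 50/50 blend are illustrative defaults, not the exact training configuration.

```python
import torch
import torch.nn.functional as F

def focal_ce_loss(logits, targets, class_weights=None, gamma=2.0, blend=0.5):
    """Blend of class-weighted cross-entropy and focal loss.

    gamma=2.0 and the 50/50 blend are illustrative assumptions, not the
    exact configuration used to train this model.
    """
    ce_raw = F.cross_entropy(logits, targets, reduction='none')
    pt = torch.exp(-ce_raw)                  # model probability of the true class
    focal = ((1.0 - pt) ** gamma) * ce_raw   # down-weights easy, confident examples
    weighted_ce = F.cross_entropy(logits, targets, weight=class_weights,
                                  reduction='none')
    return ((1.0 - blend) * weighted_ce + blend * focal).mean()
```

Both terms target the same class-imbalance problem from different angles: the weights rescale rare classes directly, while the focal factor suppresses the gradient contribution of already-easy majority-class examples.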
## Limitations

- Trained exclusively on MILK10k (5,240 lesions); performance on external datasets has not been validated.
- Rare classes (INF: 11 lesions, MAL_OTH: 15 lesions, VASC: 15 lesions) are underrepresented, so sensitivity for these classes is lower.
- The model requires paired clinical + dermoscopic images; single-image inference is not supported.
- Not evaluated at scale on paediatric patients or on skin tones outside Fitzpatrick types I–III.

## Citation

If you use this model, please cite the MILK10k dataset:

```bibtex
@dataset{milk10k2025,
  author    = {MILK study team},
  title     = {MILK10k},
  year      = {2025},
  publisher = {ISIC Archive},
  doi       = {10.34970/648456}
}
```

## License

**CC BY-NC 4.0.** This model was trained on MILK10k data (CC-BY-NC licensed) and is restricted to non-commercial research use.