tech-doc committed on
Commit
84a0314
Β·
verified Β·
1 Parent(s): 9295690

upload inference.py

Browse files
Files changed (1) hide show
  1. inference.py +294 -0
inference.py ADDED
@@ -0,0 +1,294 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ inference.py — ConvNeXt Dual-Modal Skin Lesion Classifier
3
+ ISIC 2025 / MILK10k | CC BY-NC 4.0
4
+
5
+ Classifies skin lesions from paired dermoscopic + clinical images into 11 categories.
6
+ Used as a tool called by MedGemma in the Skin AI application.
7
+ """
8
+
9
+ import os
10
+ import numpy as np
11
+ import torch
12
+ import torch.nn as nn
13
+ import torch.nn.functional as F
14
+ import timm
15
+ from PIL import Image
16
+ import torchvision.transforms as transforms
17
+ from pathlib import Path
18
+ from typing import Union
19
+
20
# ─────────────────────────────────────────────
# Constants
# ─────────────────────────────────────────────

# Diagnostic categories in the fixed order the model's output logits use;
# index i of the logit/probability vector corresponds to CLASS_NAMES[i].
CLASS_NAMES = ['AKIEC', 'BCC', 'BEN_OTH', 'BKL', 'DF',
               'INF', 'MAL_OTH', 'MEL', 'NV', 'SCCKA', 'VASC']

# Human-readable expansion of each class code (surfaced in result dicts).
CLASS_DESCRIPTIONS = {
    'AKIEC': 'Actinic keratosis / intraepithelial carcinoma',
    'BCC': 'Basal cell carcinoma',
    'BEN_OTH': 'Other benign lesion',
    'BKL': 'Benign keratosis',
    'DF': 'Dermatofibroma',
    'INF': 'Inflammatory / infectious',
    'MAL_OTH': 'Other malignant lesion',
    'MEL': 'Melanoma',
    'NV': 'Melanocytic nevus',
    'SCCKA': 'Squamous cell carcinoma / keratoacanthoma',
    'VASC': 'Vascular lesion',
}

# Square input resolution (pixels) fed to both encoders.
IMG_SIZE = 384

# Shared eval-time preprocessing pipeline: resize, tensor conversion, and
# channel normalization with the standard ImageNet mean/std statistics.
TRANSFORM = transforms.Compose([
    transforms.Resize((IMG_SIZE, IMG_SIZE)),
    transforms.ToTensor(),
    transforms.Normalize(
        mean=[0.485, 0.456, 0.406],
        std=[0.229, 0.224, 0.225]
    )
])
51
+
52
+
53
+ # ─────────────────────────────────────────────
54
+ # Architecture
55
+ # ─────────────────────────────────────────────
56
+
57
class DualConvNeXt(nn.Module):
    """Two-stream ConvNeXt classifier for paired skin-lesion images.

    One encoder processes the clinical close-up, the other the dermoscopic
    view; their pooled feature vectors are concatenated and passed through a
    small MLP head to produce class logits. The two encoders share an
    architecture but hold independent weights.
    """

    def __init__(self, num_classes: int = 11, model_name: str = 'convnext_base'):
        super().__init__()
        # num_classes=0 strips timm's classification head so each backbone
        # returns pooled features (num_features dims, 1024 for convnext_base).
        self.clinical_encoder = timm.create_model(
            model_name, pretrained=False, num_classes=0
        )
        self.derm_encoder = timm.create_model(
            model_name, pretrained=False, num_classes=0
        )
        fused_dim = 2 * self.clinical_encoder.num_features
        self.classifier = nn.Sequential(
            nn.Linear(fused_dim, 512),
            nn.ReLU(),
            nn.Dropout(0.3),
            nn.Linear(512, num_classes),
        )

    def forward(self, clinical: torch.Tensor, derm: torch.Tensor) -> torch.Tensor:
        """Return raw class logits for a batch of paired images."""
        fused = torch.cat(
            [self.clinical_encoder(clinical), self.derm_encoder(derm)], dim=1
        )
        return self.classifier(fused)
83
+
84
+
85
+ # ─────────────────────────────────────────────
86
+ # Model loading
87
+ # ─────────────────────────────────────────────
88
+
89
def load_model(
    weights_path: Union[str, Path],
    device: Union[torch.device, None] = None
) -> DualConvNeXt:
    """
    Load a trained DualConvNeXt model from a checkpoint file.

    Args:
        weights_path: Path to .pth checkpoint — either a raw state dict or a
            dict wrapping one under 'model_state_dict'
        device: Target device; defaults to CUDA if available, else CPU

    Returns:
        Loaded model in eval mode, moved to *device*
    """
    if device is None:
        device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    model = DualConvNeXt(num_classes=len(CLASS_NAMES))
    # SECURITY NOTE: torch.load unpickles arbitrary objects — only load
    # checkpoints from trusted sources.
    checkpoint = torch.load(weights_path, map_location=device)

    # Handle both raw state dicts and wrapped training checkpoints.
    state = checkpoint.get('model_state_dict', checkpoint)
    model.load_state_dict(state)
    model.eval().to(device)
    return model
114
+
115
+
116
def load_ensemble(
    weights_dir: Union[str, Path],
    device: Union[torch.device, None] = None
) -> list:
    """
    Load all fold models from a directory for ensemble inference.

    Args:
        weights_dir: Directory containing convnext_fold*.pth files
        device: Target device; defaults to CUDA if available, else CPU

    Returns:
        List of loaded DualConvNeXt models (eval mode)

    Raises:
        FileNotFoundError: If no matching checkpoint files are found
    """
    if device is None:
        device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    weights_dir = Path(weights_dir)
    # sorted() gives a deterministic fold order regardless of filesystem.
    model_paths = sorted(weights_dir.glob('convnext_fold*.pth'))

    if not model_paths:
        raise FileNotFoundError(f"No fold checkpoints found in {weights_dir}")

    models = [load_model(p, device) for p in model_paths]
    print(f"Loaded {len(models)} fold models from {weights_dir}")
    return models
142
+
143
+
144
+ # ─────────────────────────────────────────────
145
+ # Preprocessing
146
+ # ─────────────────────────────────────────────
147
+
148
def preprocess_image(image_path: Union[str, Path]) -> torch.Tensor:
    """Load and preprocess a single image to model input format.

    Uses a context manager so the underlying file handle is closed promptly
    (Image.open is lazy and otherwise keeps the file open).
    """
    with Image.open(image_path) as img:
        return TRANSFORM(img.convert('RGB'))
152
+
153
+
154
+ # ─────────────────────────────────────────────
155
+ # Inference
156
+ # ─────────────────────────────────────────────
157
+
158
def predict_single(
    model: DualConvNeXt,
    clinical_path: Union[str, Path],
    derm_path: Union[str, Path],
    device: torch.device = None
) -> dict:
    """
    Run inference with a single model on one clinical/dermoscopic pair.

    Args:
        model: Loaded DualConvNeXt model
        clinical_path: Path to clinical close-up image
        derm_path: Path to dermoscopic image
        device: torch.device; defaults to the device the model lives on

    Returns:
        dict with keys 'prediction', 'description', 'confidence',
        and 'probabilities' (per-class softmax values)
    """
    if device is None:
        device = next(model.parameters()).device

    # Batch dimension of 1 for each modality.
    clinical_batch = preprocess_image(clinical_path).unsqueeze(0).to(device)
    derm_batch = preprocess_image(derm_path).unsqueeze(0).to(device)

    with torch.no_grad():
        scores = model(clinical_batch, derm_batch)
        probs = F.softmax(scores, dim=1).squeeze().cpu().numpy()

    top = int(probs.argmax())
    top_class = CLASS_NAMES[top]
    return {
        'prediction': top_class,
        'description': CLASS_DESCRIPTIONS[top_class],
        'confidence': float(probs[top]),
        'probabilities': {name: float(p) for name, p in zip(CLASS_NAMES, probs)}
    }
193
+
194
+
195
def predict_ensemble(
    models: list,
    clinical_path: Union[str, Path],
    derm_path: Union[str, Path],
    device: torch.device = None
) -> dict:
    """
    Run ensemble inference by averaging softmax probabilities across fold models.

    Args:
        models: List of loaded DualConvNeXt models
        clinical_path: Path to clinical close-up image
        derm_path: Path to dermoscopic image
        device: torch.device; defaults to the first model's device

    Returns:
        dict with keys 'prediction', 'description', 'confidence',
        'probabilities' (averaged per-class values), and 'n_models'
    """
    if device is None:
        device = next(models[0].parameters()).device

    # Preprocess once; the same input tensors are reused by every fold model.
    clinical_batch = preprocess_image(clinical_path).unsqueeze(0).to(device)
    derm_batch = preprocess_image(derm_path).unsqueeze(0).to(device)

    with torch.no_grad():
        fold_probs = [
            F.softmax(m(clinical_batch, derm_batch), dim=1).squeeze().cpu().numpy()
            for m in models
        ]

    mean_probs = np.mean(fold_probs, axis=0)
    top = int(mean_probs.argmax())
    top_class = CLASS_NAMES[top]

    return {
        'prediction': top_class,
        'description': CLASS_DESCRIPTIONS[top_class],
        'confidence': float(mean_probs[top]),
        'probabilities': {name: float(p) for name, p in zip(CLASS_NAMES, mean_probs)},
        'n_models': len(models)
    }
237
+
238
+
239
+ # ─────────────────────────────────────────────
240
+ # Batch inference
241
+ # ─────────────────────────────────────────────
242
+
243
def predict_batch(
    models: list,
    pairs: list,
    device: torch.device = None
) -> list:
    """
    Run ensemble inference over a batch of image pairs.

    Args:
        models: List of loaded DualConvNeXt models
        pairs: List of (clinical_path, derm_path) tuples
        device: torch.device

    Returns:
        List of result dicts (same format as predict_ensemble), one per pair,
        in the same order as *pairs*
    """
    results = []
    for clinical_path, derm_path in pairs:
        results.append(predict_ensemble(models, clinical_path, derm_path, device))
    return results
260
+
261
+
262
+ # ─────────────────────────────────────────────
263
+ # CLI / Quick test
264
+ # ─────────────────────────────────────────────
265
+
266
if __name__ == '__main__':
    import argparse

    parser = argparse.ArgumentParser(description='Skin lesion classifier inference')
    parser.add_argument('--clinical', required=True, help='Path to clinical image')
    parser.add_argument('--derm', required=True, help='Path to dermoscopic image')
    parser.add_argument('--weights', required=True,
                        help='Path to .pth checkpoint or directory of fold checkpoints')
    args = parser.parse_args()

    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    print(f"Using device: {device}")

    weights_path = Path(args.weights)
    if weights_path.is_dir():
        # A directory means "load every fold checkpoint and ensemble them".
        models = load_ensemble(weights_path, device)
        result = predict_ensemble(models, args.clinical, args.derm, device)
        print(f"\nEnsemble ({result['n_models']} models)")
    else:
        model = load_model(weights_path, device)
        result = predict_single(model, args.clinical, args.derm, device)

    print(f"Prediction: {result['prediction']} — {result['description']}")
    print(f"Confidence: {result['confidence']:.1%}")
    print("\nAll class probabilities:")
    # Sort classes by descending probability; bar length scales to 30 chars.
    for cls, prob in sorted(result['probabilities'].items(),
                            key=lambda x: x[1], reverse=True):
        bar = '█' * int(prob * 30)
        print(f"  {cls:8s} {prob:.3f} {bar}")