"""
inference.py — ConvNeXt Dual-Modal Skin Lesion Classifier
ISIC 2025 / MILK10k | CC BY-NC 4.0

Classifies skin lesions from paired dermoscopic + clinical images into 11 categories.
Used as a tool called by MedGemma in the Skin AI application.
"""
| |
|
import os
from pathlib import Path
from typing import List, Optional, Union

import numpy as np
import timm
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision.transforms as transforms
from PIL import Image
| |
|
| | |
| | |
| | |
| |
|
# Lesion class codes, in the index order produced by the model's output
# logits. Order must match training; do not reorder.
CLASS_NAMES = ['AKIEC', 'BCC', 'BEN_OTH', 'BKL', 'DF',
               'INF', 'MAL_OTH', 'MEL', 'NV', 'SCCKA', 'VASC']

# Human-readable description for each class code above.
CLASS_DESCRIPTIONS = {
    'AKIEC': 'Actinic keratosis / intraepithelial carcinoma',
    'BCC': 'Basal cell carcinoma',
    'BEN_OTH': 'Other benign lesion',
    'BKL': 'Benign keratosis',
    'DF': 'Dermatofibroma',
    'INF': 'Inflammatory / infectious',
    'MAL_OTH': 'Other malignant lesion',
    'MEL': 'Melanoma',
    'NV': 'Melanocytic nevus',
    'SCCKA': 'Squamous cell carcinoma / keratoacanthoma',
    'VASC': 'Vascular lesion',
}

# Square input resolution (pixels) expected by the ConvNeXt encoders.
IMG_SIZE = 384

# Preprocessing pipeline: resize, convert to float tensor, then normalize
# with ImageNet mean/std (the standard statistics for timm backbones).
TRANSFORM = transforms.Compose([
    transforms.Resize((IMG_SIZE, IMG_SIZE)),
    transforms.ToTensor(),
    transforms.Normalize(
        mean=[0.485, 0.456, 0.406],
        std=[0.229, 0.224, 0.225]
    )
])
| |
|
| |
|
| | |
| | |
| | |
| |
|
class DualConvNeXt(nn.Module):
    """
    Dual-input ConvNeXt-Base for paired dermoscopic + clinical image classification.

    Two independently-trained encoders (identical architecture) embed each
    modality; the concatenated features feed a small MLP head.
    """

    def __init__(self, num_classes: int = 11, model_name: str = 'convnext_base'):
        super().__init__()

        def _backbone() -> nn.Module:
            # Headless timm backbone (num_classes=0 -> pooled features only).
            return timm.create_model(model_name, pretrained=False, num_classes=0)

        # Attribute names and classifier layer order are fixed: checkpoint
        # state_dict keys depend on them.
        self.clinical_encoder = _backbone()
        self.derm_encoder = _backbone()

        embed_dim = self.clinical_encoder.num_features
        self.classifier = nn.Sequential(
            nn.Linear(2 * embed_dim, 512),
            nn.ReLU(),
            nn.Dropout(0.3),
            nn.Linear(512, num_classes),
        )

    def forward(self, clinical: torch.Tensor, derm: torch.Tensor) -> torch.Tensor:
        """Return class logits for a batch of paired clinical/dermoscopic images."""
        fused = torch.cat(
            (self.clinical_encoder(clinical), self.derm_encoder(derm)),
            dim=1,
        )
        return self.classifier(fused)
| |
|
| |
|
| | |
| | |
| | |
| |
|
def load_model(
    weights_path: Union[str, Path],
    device: Optional[torch.device] = None
) -> DualConvNeXt:
    """
    Load a trained DualConvNeXt model from a checkpoint file.

    Args:
        weights_path: Path to a .pth checkpoint — either a training
            checkpoint dict containing 'model_state_dict' or a bare state dict.
        device: Target device — defaults to CUDA if available, else CPU.

    Returns:
        The loaded model, moved to `device` and set to eval mode.
    """
    if device is None:
        device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    model = DualConvNeXt(num_classes=len(CLASS_NAMES))
    checkpoint = torch.load(weights_path, map_location=device)

    # Accept both full training checkpoints and bare state dicts.
    state = checkpoint.get('model_state_dict', checkpoint)
    model.load_state_dict(state)
    model.eval().to(device)
    return model
| |
|
| |
|
def load_ensemble(
    weights_dir: Union[str, Path],
    device: Optional[torch.device] = None
) -> List[DualConvNeXt]:
    """
    Load all fold models from a directory for ensemble inference.

    Args:
        weights_dir: Directory containing convnext_fold*.pth files
        device: Target device — defaults to CUDA if available, else CPU.

    Returns:
        List of loaded models, in sorted checkpoint-filename order.

    Raises:
        FileNotFoundError: If no convnext_fold*.pth files exist in weights_dir.
    """
    if device is None:
        device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    weights_dir = Path(weights_dir)
    # Sorted so fold order is deterministic across runs/filesystems.
    model_paths = sorted(weights_dir.glob('convnext_fold*.pth'))

    if not model_paths:
        raise FileNotFoundError(f"No fold checkpoints found in {weights_dir}")

    models = [load_model(p, device) for p in model_paths]
    print(f"Loaded {len(models)} fold models from {weights_dir}")
    return models
| |
|
| |
|
| | |
| | |
| | |
| |
|
def preprocess_image(image_path: Union[str, Path]) -> torch.Tensor:
    """Load one image, force RGB, and apply the model's preprocessing transform."""
    rgb_image = Image.open(image_path).convert('RGB')
    tensor = TRANSFORM(rgb_image)
    return tensor
| |
|
| |
|
| | |
| | |
| | |
| |
|
def predict_single(
    model: DualConvNeXt,
    clinical_path: Union[str, Path],
    derm_path: Union[str, Path],
    device: Optional[torch.device] = None
) -> dict:
    """
    Run inference with a single model.

    Args:
        model: Loaded DualConvNeXt model
        clinical_path: Path to clinical close-up image
        derm_path: Path to dermoscopic image
        device: Target device — defaults to the device the model lives on.

    Returns:
        dict with keys 'prediction' (class code), 'description',
        'confidence' (top-class probability), and 'probabilities'
        (per-class softmax values).
    """
    if device is None:
        # Infer the device from the model's own parameters.
        device = next(model.parameters()).device

    clinical = preprocess_image(clinical_path).unsqueeze(0).to(device)
    derm = preprocess_image(derm_path).unsqueeze(0).to(device)

    with torch.no_grad():
        logits = model(clinical, derm)
        probs = F.softmax(logits, dim=1).squeeze().cpu().numpy()

    pred_idx = int(probs.argmax())
    return {
        'prediction': CLASS_NAMES[pred_idx],
        'description': CLASS_DESCRIPTIONS[CLASS_NAMES[pred_idx]],
        'confidence': float(probs[pred_idx]),
        'probabilities': {c: float(p) for c, p in zip(CLASS_NAMES, probs)}
    }
| |
|
| |
|
def predict_ensemble(
    models: list,
    clinical_path: Union[str, Path],
    derm_path: Union[str, Path],
    device: Optional[torch.device] = None
) -> dict:
    """
    Run ensemble inference by averaging softmax probabilities across fold models.

    Args:
        models: Non-empty list of loaded DualConvNeXt models
        clinical_path: Path to clinical close-up image
        derm_path: Path to dermoscopic image
        device: Target device — defaults to the first model's device.

    Returns:
        dict with ensemble prediction, confidence, per-class probabilities,
        and the number of models averaged ('n_models').

    Raises:
        ValueError: If `models` is empty.
    """
    if not models:
        # Fail loudly instead of IndexError / NaN mean further down.
        raise ValueError("predict_ensemble requires at least one model")

    if device is None:
        device = next(models[0].parameters()).device

    clinical = preprocess_image(clinical_path).unsqueeze(0).to(device)
    derm = preprocess_image(derm_path).unsqueeze(0).to(device)

    all_probs = []
    with torch.no_grad():
        for model in models:
            logits = model(clinical, derm)
            probs = F.softmax(logits, dim=1).squeeze().cpu().numpy()
            all_probs.append(probs)

    # Simple unweighted average of per-fold softmax outputs.
    ensemble_probs = np.mean(all_probs, axis=0)
    pred_idx = int(ensemble_probs.argmax())

    return {
        'prediction': CLASS_NAMES[pred_idx],
        'description': CLASS_DESCRIPTIONS[CLASS_NAMES[pred_idx]],
        'confidence': float(ensemble_probs[pred_idx]),
        'probabilities': {c: float(p) for c, p in zip(CLASS_NAMES, ensemble_probs)},
        'n_models': len(models)
    }
| |
|
| |
|
| | |
| | |
| | |
| |
|
def predict_batch(
    models: list,
    pairs: list,
    device: Optional[torch.device] = None
) -> list:
    """
    Run ensemble inference over a batch of image pairs.

    Args:
        models: List of loaded DualConvNeXt models
        pairs: List of (clinical_path, derm_path) tuples
        device: Target device — defaults to the first model's device.

    Returns:
        List of result dicts (same format as predict_ensemble), one per pair,
        in the same order as `pairs`.
    """
    return [predict_ensemble(models, clinical, derm, device)
            for clinical, derm in pairs]
| |
|
| |
|
| | |
| | |
| | |
| |
|
if __name__ == '__main__':
    import argparse

    parser = argparse.ArgumentParser(description='Skin lesion classifier inference')
    parser.add_argument('--clinical', required=True, help='Path to clinical image')
    parser.add_argument('--derm', required=True, help='Path to dermoscopic image')
    parser.add_argument('--weights', required=True,
                        help='Path to .pth checkpoint or directory of fold checkpoints')
    args = parser.parse_args()

    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    print(f"Using device: {device}")

    # A directory of fold checkpoints selects ensemble inference;
    # a single .pth file runs one model.
    weights_path = Path(args.weights)
    if weights_path.is_dir():
        models = load_ensemble(weights_path, device)
        result = predict_ensemble(models, args.clinical, args.derm, device)
        print(f"\nEnsemble ({result['n_models']} models)")
    else:
        model = load_model(weights_path, device)
        result = predict_single(model, args.clinical, args.derm, device)

    # Fixed mojibake in the output strings: the separator below was a
    # garbled em dash, and the bar glyph was a garbled block character.
    print(f"Prediction: {result['prediction']} — {result['description']}")
    print(f"Confidence: {result['confidence']:.1%}")
    print("\nAll class probabilities:")
    for cls, prob in sorted(result['probabilities'].items(),
                            key=lambda x: x[1], reverse=True):
        bar = '█' * int(prob * 30)
        print(f"  {cls:8s} {prob:.3f} {bar}")
|