"""
inference.py — ConvNeXt Dual-Modal Skin Lesion Classifier
ISIC 2025 / MILK10k | CC BY-NC 4.0

Classifies skin lesions from paired dermoscopic + clinical images into 11 categories.
Used as a tool called by MedGemma in the Skin AI application.
"""

import os
from pathlib import Path
from typing import Optional, Union

import numpy as np
import timm
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision.transforms as transforms
from PIL import Image

# ─────────────────────────────────────────────
# Constants
# ─────────────────────────────────────────────

# Output classes; predicted indices are mapped to labels through this list,
# so its order must match the order of the classifier's output logits.
CLASS_NAMES = ['AKIEC', 'BCC', 'BEN_OTH', 'BKL', 'DF', 'INF',
               'MAL_OTH', 'MEL', 'NV', 'SCCKA', 'VASC']

# Human-readable description for every entry of CLASS_NAMES.
CLASS_DESCRIPTIONS = {
    'AKIEC': 'Actinic keratosis / intraepithelial carcinoma',
    'BCC': 'Basal cell carcinoma',
    'BEN_OTH': 'Other benign lesion',
    'BKL': 'Benign keratosis',
    'DF': 'Dermatofibroma',
    'INF': 'Inflammatory / infectious',
    'MAL_OTH': 'Other malignant lesion',
    'MEL': 'Melanoma',
    'NV': 'Melanocytic nevus',
    'SCCKA': 'Squamous cell carcinoma / keratoacanthoma',
    'VASC': 'Vascular lesion',
}

# Side length (pixels) every input image is resized to.
IMG_SIZE = 384

# Preprocessing applied to both the clinical and the dermoscopic image:
# resize to a fixed square, convert to a tensor, normalize per channel.
TRANSFORM = transforms.Compose([
    transforms.Resize((IMG_SIZE, IMG_SIZE)),
    transforms.ToTensor(),
    transforms.Normalize(
        mean=[0.485, 0.456, 0.406],
        std=[0.229, 0.224, 0.225]
    )
])


# ─────────────────────────────────────────────
# Architecture
# ─────────────────────────────────────────────

class DualConvNeXt(nn.Module):
    """
    Dual-input ConvNeXt-Base for paired dermoscopic + clinical image classification.
    Both encoders share the same architecture but are trained independently.

    Each image goes through its own encoder; the two feature vectors are
    concatenated and passed to a small MLP classification head.
    """

    def __init__(self, num_classes: int = 11, model_name: str = 'convnext_base'):
        """
        Args:
            num_classes: Number of output logits of the classification head
            model_name:  timm model name used for both encoders
        """
        super().__init__()
        # num_classes=0 makes timm drop its own classifier so each encoder
        # returns pooled features instead of logits.
        self.clinical_encoder = timm.create_model(
            model_name, pretrained=False, num_classes=0
        )
        self.derm_encoder = timm.create_model(
            model_name, pretrained=False, num_classes=0
        )
        feat_dim = self.clinical_encoder.num_features  # 1024 for convnext_base
        # The head consumes the concatenation of both encoders' features,
        # hence feat_dim * 2 inputs.
        self.classifier = nn.Sequential(
            nn.Linear(feat_dim * 2, 512),
            nn.ReLU(),
            nn.Dropout(0.3),
            nn.Linear(512, num_classes)
        )

    def forward(self, clinical: torch.Tensor, derm: torch.Tensor) -> torch.Tensor:
        """Return class logits for a batch of (clinical, dermoscopic) image pairs."""
        c = self.clinical_encoder(clinical)
        d = self.derm_encoder(derm)
        return self.classifier(torch.cat([c, d], dim=1))


# ─────────────────────────────────────────────
# Model loading
# ─────────────────────────────────────────────

def _default_device() -> torch.device:
    """Return the CUDA device when available, otherwise the CPU device."""
    return torch.device('cuda' if torch.cuda.is_available() else 'cpu')


def load_model(
    weights_path: Union[str, Path],
    device: Optional[torch.device] = None
) -> DualConvNeXt:
    """
    Load a trained DualConvNeXt model from a checkpoint file.

    Args:
        weights_path: Path to .pth checkpoint (expects dict with 'model_state_dict')
        device:       torch.device — defaults to CUDA if available

    Returns:
        Loaded model in eval mode
    """
    if device is None:
        device = _default_device()

    model = DualConvNeXt(num_classes=len(CLASS_NAMES))
    # NOTE(security): torch.load unpickles the file, which can execute
    # arbitrary code. Only load checkpoints from trusted sources.
    checkpoint = torch.load(weights_path, map_location=device)

    # Handle both raw state dict and wrapped checkpoints
    state = checkpoint.get('model_state_dict', checkpoint)
    model.load_state_dict(state)
    model.eval().to(device)
    return model


def load_ensemble(
    weights_dir: Union[str, Path],
    device: Optional[torch.device] = None
) -> list:
    """
    Load all fold models from a directory for ensemble inference.

    Args:
        weights_dir: Directory containing convnext_fold*.pth files
        device:      torch.device — defaults to CUDA if available

    Returns:
        List of loaded models, ordered by sorted checkpoint path

    Raises:
        FileNotFoundError: If the directory contains no matching checkpoint
    """
    if device is None:
        device = _default_device()

    weights_dir = Path(weights_dir)
    # Sorted so the model order is deterministic across runs.
    model_paths = sorted(weights_dir.glob('convnext_fold*.pth'))

    if not model_paths:
        raise FileNotFoundError(f"No fold checkpoints found in {weights_dir}")

    models = [load_model(p, device) for p in model_paths]
    print(f"Loaded {len(models)} fold models from {weights_dir}")
    return models


# ─────────────────────────────────────────────
# Preprocessing
# ─────────────────────────────────────────────

def preprocess_image(image_path: Union[str, Path]) -> torch.Tensor:
    """Load and preprocess a single image to model input format."""
    # The context manager closes the underlying file as soon as the pixel
    # data has been converted, instead of waiting for garbage collection.
    with Image.open(image_path) as img:
        return TRANSFORM(img.convert('RGB'))


def _load_pair(
    clinical_path: Union[str, Path],
    derm_path: Union[str, Path],
    device: torch.device
) -> tuple:
    """Preprocess both images and return them as batch-of-one tensors on `device`."""
    clinical = preprocess_image(clinical_path).unsqueeze(0).to(device)
    derm = preprocess_image(derm_path).unsqueeze(0).to(device)
    return clinical, derm


# ─────────────────────────────────────────────
# Inference
# ─────────────────────────────────────────────

def _class_probabilities(probs: np.ndarray) -> dict:
    """Map every class name to its probability as a plain Python float."""
    return {c: float(p) for c, p in zip(CLASS_NAMES, probs)}


def _build_result(probs: np.ndarray) -> dict:
    """Build the result dict (top class, description, confidence, all probabilities)."""
    pred_idx = int(probs.argmax())
    label = CLASS_NAMES[pred_idx]
    return {
        'prediction': label,
        'description': CLASS_DESCRIPTIONS[label],
        'confidence': float(probs[pred_idx]),
        'probabilities': _class_probabilities(probs),
    }


def predict_single(
    model: DualConvNeXt,
    clinical_path: Union[str, Path],
    derm_path: Union[str, Path],
    device: Optional[torch.device] = None
) -> dict:
    """
    Run inference with a single model.

    Args:
        model:         Loaded DualConvNeXt model
        clinical_path: Path to clinical close-up image
        derm_path:     Path to dermoscopic image
        device:        torch.device — defaults to the device of the model's parameters

    Returns:
        dict with prediction, confidence, and per-class probabilities
    """
    if device is None:
        device = next(model.parameters()).device

    clinical, derm = _load_pair(clinical_path, derm_path, device)

    with torch.no_grad():
        logits = model(clinical, derm)
        # squeeze(0) removes only the batch axis; a bare squeeze() would also
        # drop the class axis if it ever had size 1.
        probs = F.softmax(logits, dim=1).squeeze(0).cpu().numpy()

    return _build_result(probs)


def predict_ensemble(
    models: list,
    clinical_path: Union[str, Path],
    derm_path: Union[str, Path],
    device: Optional[torch.device] = None
) -> dict:
    """
    Run ensemble inference by averaging softmax probabilities across fold models.

    Args:
        models:        List of loaded DualConvNeXt models (must not be empty)
        clinical_path: Path to clinical close-up image
        derm_path:     Path to dermoscopic image
        device:        torch.device — defaults to the device of the first model's
                       parameters; the inputs are moved there once and fed to
                       every model

    Returns:
        dict with ensemble prediction, confidence, per-class probabilities,
        number of models ('n_models'), and per-model probability breakdown
        ('per_model_probabilities', one class->probability dict per model,
        in the order of `models`)

    Raises:
        ValueError: If `models` is empty
    """
    if not models:
        raise ValueError("models must contain at least one loaded model")

    if device is None:
        device = next(models[0].parameters()).device

    # Images are preprocessed once and shared by all fold models.
    clinical, derm = _load_pair(clinical_path, derm_path, device)

    all_probs = []
    with torch.no_grad():
        for model in models:
            logits = model(clinical, derm)
            all_probs.append(F.softmax(logits, dim=1).squeeze(0).cpu().numpy())

    # Average the probabilities (not the logits) across models.
    ensemble_probs = np.stack(all_probs).mean(axis=0)

    result = _build_result(ensemble_probs)
    result['n_models'] = len(models)
    result['per_model_probabilities'] = [
        _class_probabilities(probs) for probs in all_probs
    ]
    return result


# ─────────────────────────────────────────────
# Batch inference
# ─────────────────────────────────────────────

def predict_batch(
    models: list,
    pairs: list,
    device: Optional[torch.device] = None
) -> list:
    """
    Run ensemble inference over a batch of image pairs.

    Pairs are processed one at a time, each through predict_ensemble.

    Args:
        models: List of loaded DualConvNeXt models
        pairs:  List of (clinical_path, derm_path) tuples
        device: torch.device

    Returns:
        List of result dicts (same format as predict_ensemble)
    """
    return [predict_ensemble(models, c, d, device) for c, d in pairs]


# ─────────────────────────────────────────────
# CLI / Quick test
# ─────────────────────────────────────────────

if __name__ == '__main__':
    import argparse

    parser = argparse.ArgumentParser(description='Skin lesion classifier inference')
    parser.add_argument('--clinical', required=True, help='Path to clinical image')
    parser.add_argument('--derm', required=True, help='Path to dermoscopic image')
    parser.add_argument('--weights', required=True,
                        help='Path to .pth checkpoint or directory of fold checkpoints')
    args = parser.parse_args()

    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    print(f"Using device: {device}")

    weights_path = Path(args.weights)

    # A directory means "ensemble of fold checkpoints"; a file means one model.
    if weights_path.is_dir():
        models = load_ensemble(weights_path, device)
        result = predict_ensemble(models, args.clinical, args.derm, device)
        print(f"\nEnsemble ({result['n_models']} models)")
    else:
        model = load_model(weights_path, device)
        result = predict_single(model, args.clinical, args.derm, device)

    print(f"Prediction: {result['prediction']} — {result['description']}")
    print(f"Confidence: {result['confidence']:.1%}")
    print("\nAll class probabilities:")
    # Most likely class first, with a simple text bar proportional to probability.
    for cls, prob in sorted(result['probabilities'].items(),
                            key=lambda x: x[1], reverse=True):
        bar = '█' * int(prob * 30)
        print(f" {cls:8s} {prob:.3f} {bar}")