# ConvNeXt_Milk10k / inference.py
"""
inference.py β€” ConvNeXt Dual-Modal Skin Lesion Classifier
ISIC 2025 / MILK10k | CC BY-NC 4.0
Classifies skin lesions from paired dermoscopic + clinical images into 11 categories.
Used as a tool called by MedGemma in the Skin AI application.
"""
import os
from pathlib import Path
from typing import Optional, Union

import numpy as np
import timm
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision.transforms as transforms
from PIL import Image
# ─────────────────────────────────────────────
# Constants
# ─────────────────────────────────────────────
# 11-way label set. Order is significant: it must match the index order of the
# model's output logits (predict_* zips these names against the softmax vector).
CLASS_NAMES = ['AKIEC', 'BCC', 'BEN_OTH', 'BKL', 'DF',
               'INF', 'MAL_OTH', 'MEL', 'NV', 'SCCKA', 'VASC']
# Human-readable expansion of each class code, used in result dicts and CLI output.
CLASS_DESCRIPTIONS = {
    'AKIEC': 'Actinic keratosis / intraepithelial carcinoma',
    'BCC': 'Basal cell carcinoma',
    'BEN_OTH': 'Other benign lesion',
    'BKL': 'Benign keratosis',
    'DF': 'Dermatofibroma',
    'INF': 'Inflammatory / infectious',
    'MAL_OTH': 'Other malignant lesion',
    'MEL': 'Melanoma',
    'NV': 'Melanocytic nevus',
    'SCCKA': 'Squamous cell carcinoma / keratoacanthoma',
    'VASC': 'Vascular lesion',
}
# Square input resolution fed to both encoders.
# NOTE(review): presumably the checkpoints were trained at 384x384 — confirm
# against the training pipeline before changing.
IMG_SIZE = 384
# Shared preprocessing for both modalities: resize to IMG_SIZE, convert to a
# [0, 1] tensor, then normalize with the standard ImageNet mean/std statistics.
TRANSFORM = transforms.Compose([
    transforms.Resize((IMG_SIZE, IMG_SIZE)),
    transforms.ToTensor(),
    transforms.Normalize(
        mean=[0.485, 0.456, 0.406],
        std=[0.229, 0.224, 0.225]
    )
])
# ─────────────────────────────────────────────
# Architecture
# ─────────────────────────────────────────────
class DualConvNeXt(nn.Module):
    """Dual-input ConvNeXt-Base classifier for paired clinical + dermoscopic images.

    Two backbones of the same architecture are kept as separate modules (no
    weight sharing); their pooled features are concatenated and passed through
    a small MLP head to produce class logits.
    """

    def __init__(self, num_classes: int = 11, model_name: str = 'convnext_base'):
        super().__init__()
        # num_classes=0 strips timm's classification head so the backbone
        # returns pooled feature vectors instead of logits.
        self.clinical_encoder = timm.create_model(
            model_name, pretrained=False, num_classes=0
        )
        self.derm_encoder = timm.create_model(
            model_name, pretrained=False, num_classes=0
        )
        # Head input is both feature vectors concatenated (1024 each for
        # convnext_base, so 2048 in).
        fused_dim = 2 * self.clinical_encoder.num_features
        self.classifier = nn.Sequential(
            nn.Linear(fused_dim, 512),
            nn.ReLU(),
            nn.Dropout(0.3),
            nn.Linear(512, num_classes),
        )

    def forward(self, clinical: torch.Tensor, derm: torch.Tensor) -> torch.Tensor:
        """Encode each modality, fuse by concatenation, and return class logits."""
        clinical_feats = self.clinical_encoder(clinical)
        derm_feats = self.derm_encoder(derm)
        fused = torch.cat([clinical_feats, derm_feats], dim=1)
        return self.classifier(fused)
# ─────────────────────────────────────────────
# Model loading
# ─────────────────────────────────────────────
def load_model(
    weights_path: Union[str, Path],
    device: Optional[torch.device] = None
) -> DualConvNeXt:
    """
    Load a trained DualConvNeXt model from a checkpoint file.

    Args:
        weights_path: Path to a .pth checkpoint — either a bare state dict or
            a dict wrapping one under 'model_state_dict'.
        device: Target device; defaults to CUDA if available, else CPU.
            (Fix: annotation was implicit-Optional `torch.device = None`.)

    Returns:
        Model moved to `device` and set to eval mode.
    """
    if device is None:
        device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    model = DualConvNeXt(num_classes=len(CLASS_NAMES))
    # NOTE(review): torch.load unpickles arbitrary objects — only load trusted
    # checkpoints (consider weights_only=True if the checkpoint format allows).
    checkpoint = torch.load(weights_path, map_location=device)
    # Accept both raw state dicts and wrapped training checkpoints.
    state = checkpoint.get('model_state_dict', checkpoint)
    model.load_state_dict(state)
    model.eval().to(device)
    return model
def load_ensemble(
    weights_dir: Union[str, Path],
    device: Optional[torch.device] = None
) -> list:
    """
    Load all fold models from a directory for ensemble inference.

    Args:
        weights_dir: Directory containing convnext_fold*.pth files.
        device: Target device; defaults to CUDA if available, else CPU.
            (Fix: annotation was implicit-Optional `torch.device = None`.)

    Returns:
        List of loaded models, in sorted checkpoint-filename order.

    Raises:
        FileNotFoundError: If no matching fold checkpoints exist in the directory.
    """
    if device is None:
        device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    weights_dir = Path(weights_dir)
    # sorted() keeps fold order deterministic across filesystems.
    model_paths = sorted(weights_dir.glob('convnext_fold*.pth'))
    if not model_paths:
        raise FileNotFoundError(f"No fold checkpoints found in {weights_dir}")
    models = [load_model(p, device) for p in model_paths]
    print(f"Loaded {len(models)} fold models from {weights_dir}")
    return models
# ─────────────────────────────────────────────
# Preprocessing
# ─────────────────────────────────────────────
def preprocess_image(image_path: Union[str, Path]) -> torch.Tensor:
    """Load one image file and apply the shared resize/normalize transform.

    Fix: Image.open is lazy and keeps the file handle alive; the context
    manager guarantees it is closed after pixel data is decoded.
    """
    with Image.open(image_path) as img:
        rgb = img.convert('RGB')  # forces decode; also drops alpha/palette modes
    return TRANSFORM(rgb)
# ─────────────────────────────────────────────
# Inference
# ─────────────────────────────────────────────
def predict_single(
    model: DualConvNeXt,
    clinical_path: Union[str, Path],
    derm_path: Union[str, Path],
    device: Optional[torch.device] = None
) -> dict:
    """
    Run inference with a single model on one clinical/dermoscopic pair.

    Args:
        model: Loaded DualConvNeXt model (eval mode).
        clinical_path: Path to clinical close-up image.
        derm_path: Path to dermoscopic image.
        device: Inference device; defaults to the device the model lives on.
            (Fix: annotation was implicit-Optional `torch.device = None`.)

    Returns:
        dict with 'prediction', 'description', 'confidence', and a
        'probabilities' mapping of class name -> softmax probability.
    """
    if device is None:
        # Follow the model's own placement so inputs and weights agree.
        device = next(model.parameters()).device
    clinical = preprocess_image(clinical_path).unsqueeze(0).to(device)
    derm = preprocess_image(derm_path).unsqueeze(0).to(device)
    with torch.no_grad():
        logits = model(clinical, derm)
        # squeeze(0) instead of squeeze(): only drop the batch dim, never a
        # size-1 class dim.
        probs = F.softmax(logits, dim=1).squeeze(0).cpu().numpy()
    pred_idx = int(probs.argmax())
    return {
        'prediction': CLASS_NAMES[pred_idx],
        'description': CLASS_DESCRIPTIONS[CLASS_NAMES[pred_idx]],
        'confidence': float(probs[pred_idx]),
        'probabilities': {c: float(p) for c, p in zip(CLASS_NAMES, probs)}
    }
def predict_ensemble(
    models: list,
    clinical_path: Union[str, Path],
    derm_path: Union[str, Path],
    device: Optional[torch.device] = None
) -> dict:
    """
    Run ensemble inference by averaging softmax probabilities across fold models.

    Args:
        models: List of loaded DualConvNeXt models (non-empty).
        clinical_path: Path to clinical close-up image.
        derm_path: Path to dermoscopic image.
        device: Inference device; defaults to the first model's device.
            (Fix: annotation was implicit-Optional `torch.device = None`.)

    Returns:
        dict with ensemble 'prediction', 'description', 'confidence',
        'probabilities' (class name -> mean softmax probability), and
        'n_models' (ensemble size).
    """
    if device is None:
        # Assumes all fold models live on the same device — TODO confirm.
        device = next(models[0].parameters()).device
    # Preprocess once; the same tensors are reused for every fold model.
    clinical = preprocess_image(clinical_path).unsqueeze(0).to(device)
    derm = preprocess_image(derm_path).unsqueeze(0).to(device)
    all_probs = []
    with torch.no_grad():
        for model in models:
            logits = model(clinical, derm)
            # squeeze(0): only drop the batch dim (see predict_single).
            all_probs.append(F.softmax(logits, dim=1).squeeze(0).cpu().numpy())
    ensemble_probs = np.mean(all_probs, axis=0)
    pred_idx = int(ensemble_probs.argmax())
    return {
        'prediction': CLASS_NAMES[pred_idx],
        'description': CLASS_DESCRIPTIONS[CLASS_NAMES[pred_idx]],
        'confidence': float(ensemble_probs[pred_idx]),
        'probabilities': {c: float(p) for c, p in zip(CLASS_NAMES, ensemble_probs)},
        'n_models': len(models)
    }
# ─────────────────────────────────────────────
# Batch inference
# ─────────────────────────────────────────────
def predict_batch(
    models: list,
    pairs: list,
    device: Optional[torch.device] = None
) -> list:
    """
    Run ensemble inference over a batch of image pairs.

    Args:
        models: List of loaded DualConvNeXt models.
        pairs: List of (clinical_path, derm_path) tuples.
        device: Inference device, forwarded to predict_ensemble.
            (Fix: annotation was implicit-Optional `torch.device = None`.)

    Returns:
        List of result dicts, one per pair, in input order (same format as
        predict_ensemble). Empty input yields an empty list.
    """
    return [predict_ensemble(models, c, d, device) for c, d in pairs]
# ─────────────────────────────────────────────
# CLI / Quick test
# ─────────────────────────────────────────────
if __name__ == '__main__':
    import argparse

    parser = argparse.ArgumentParser(description='Skin lesion classifier inference')
    parser.add_argument('--clinical', required=True, help='Path to clinical image')
    parser.add_argument('--derm', required=True, help='Path to dermoscopic image')
    parser.add_argument('--weights', required=True,
                        help='Path to .pth checkpoint or directory of fold checkpoints')
    args = parser.parse_args()

    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    print(f"Using device: {device}")

    weights_path = Path(args.weights)
    # A directory holds per-fold checkpoints -> ensemble; a file is one model.
    if weights_path.is_dir():
        models = load_ensemble(weights_path, device)
        result = predict_ensemble(models, args.clinical, args.derm, device)
        print(f"\nEnsemble ({result['n_models']} models)")
    else:
        model = load_model(weights_path, device)
        result = predict_single(model, args.clinical, args.derm, device)

    # Fix: original printed mojibake ('β€”' for the em dash, 'β–ˆ' for the
    # bar glyph) from a mis-decoded source file.
    print(f"Prediction: {result['prediction']} — {result['description']}")
    print(f"Confidence: {result['confidence']:.1%}")
    print("\nAll class probabilities:")
    for cls, prob in sorted(result['probabilities'].items(),
                            key=lambda x: x[1], reverse=True):
        bar = '█' * int(prob * 30)  # simple text histogram, 30 chars = 100%
        print(f"  {cls:8s} {prob:.3f} {bar}")