"""
RetinaRadar Inference Module for Hugging Face

This module provides easy inference for the RetinaRadar model on Hugging Face.
"""

import json
from pathlib import Path
from typing import Any, Dict, List, Optional, Tuple, Union

import albumentations as A
import numpy as np
import torch
from albumentations.pytorch import ToTensorV2
from PIL import Image

# Normalization statistics applied by the preprocessing pipeline (the standard
# ImageNet values). Presumably these match training — confirm against the
# training code.
IMAGENET_MEAN = [0.485, 0.456, 0.406]
IMAGENET_STD = [0.229, 0.224, 0.225]


class RetinaRadarInference:
    """
    Inference handler for the RetinaRadar model on Hugging Face.

    Loads a pickled ``torch.nn.Module`` plus a JSON label-metadata file and
    exposes preprocessing, prediction, decoding and summary helpers.
    """

    def __init__(
        self,
        model_path: str = "retinaradar_model.ckpt",
        metadata_path: str = "label_metadata.json",
        device: Optional[str] = None,
    ):
        """
        Initialize the inference handler.

        Args:
            model_path: Path to a checkpoint containing a whole pickled
                ``torch.nn.Module`` (not a bare ``state_dict``).
            metadata_path: Path to the label metadata JSON. It must provide
                the keys ``'feature_names'`` and ``'onehot_feature_names'``
                (read by ``decode_predictions``).
            device: Device to run inference on (e.g. ``'cuda'`` or ``'cpu'``).
                When omitted, uses ``'cuda'`` if available, else ``'cpu'``.
                This is resolved here rather than once at import time.

        Raises:
            TypeError: If the checkpoint does not contain a ``torch.nn.Module``.
        """
        if device is None:
            device = "cuda" if torch.cuda.is_available() else "cpu"
        self.device = device

        # SECURITY: weights_only=False unpickles arbitrary Python objects and
        # can execute code embedded in the file. Only load checkpoints from
        # trusted sources. It is required here because the file stores a whole
        # module object, and PyTorch >= 2.6 defaults to weights_only=True,
        # which would refuse to load it.
        model = torch.load(model_path, map_location=device, weights_only=False)
        if not isinstance(model, torch.nn.Module):
            raise TypeError(
                f"Expected {model_path!r} to contain a torch.nn.Module, got "
                f"{type(model).__name__}. If this is a state_dict-style "
                "checkpoint, build the model class and load the weights into it."
            )
        # Module.to() moves parameters in place and returns the same module.
        self.model = model.to(device)
        self.model.eval()

        with open(metadata_path, "r", encoding="utf-8") as f:
            self.metadata = json.load(f)

        # Resize -> center-crop to 224x224 -> normalize -> CHW tensor.
        self.transform = A.Compose([
            A.Resize(256, 256),
            A.CenterCrop(224, 224),
            A.Normalize(mean=IMAGENET_MEAN, std=IMAGENET_STD),
            ToTensorV2(),
        ])

    @staticmethod
    def _ensure_three_channels(array: np.ndarray) -> np.ndarray:
        """
        Coerce a channels-last image array to exactly three channels.

        2-D (grayscale) and single-channel arrays are replicated to 3
        channels; 4-channel arrays have their last channel (presumably alpha)
        dropped. Any other shape is returned unchanged.
        """
        if array.ndim == 2:
            return np.stack([array] * 3, axis=-1)
        if array.ndim == 3 and array.shape[-1] == 1:
            return np.repeat(array, 3, axis=-1)
        if array.ndim == 3 and array.shape[-1] == 4:
            # Slicing yields a non-contiguous view; hand the transforms a
            # contiguous copy.
            return np.ascontiguousarray(array[..., :3])
        return array

    def preprocess(
        self, image: Union[str, Path, Image.Image, np.ndarray]
    ) -> torch.Tensor:
        """
        Preprocess an image for inference.

        Args:
            image: Image path, PIL Image, or numpy array. PIL images of any
                mode are converted to RGB. Numpy arrays are assumed to be
                channels-last with 0-255 pixel values (TODO confirm for
                callers); grayscale/RGBA arrays are coerced to 3 channels.

        Returns:
            torch.Tensor: A batch containing one preprocessed image,
            typically of shape ``(1, 3, 224, 224)``, on ``self.device``.
        """
        if isinstance(image, (str, Path)):
            # The context manager closes the file handle; convert() returns a
            # fully loaded copy that remains valid afterwards.
            with Image.open(image) as img:
                array = np.array(img.convert("RGB"))
        elif isinstance(image, Image.Image):
            # Non-RGB modes (L, RGBA, P, ...) would otherwise give a channel
            # count that does not match the 3-channel Normalize.
            array = np.array(image.convert("RGB"))
        else:
            array = self._ensure_three_channels(np.asarray(image))

        image_tensor = self.transform(image=array)["image"].unsqueeze(0)
        return image_tensor.to(self.device)

    def predict(
        self,
        image: Union[str, Path, Image.Image, np.ndarray],
        threshold: float = 0.5,
    ) -> Dict[str, Any]:
        """
        Run inference on a single image.

        Args:
            image: Image to process (see ``preprocess`` for accepted types).
            threshold: Probability above which a label counts as predicted.

        Returns:
            dict: Per-feature predictions, as produced by
            ``decode_predictions``.
        """
        image_tensor = self.preprocess(image)

        with torch.no_grad():
            logits = self.model(image_tensor)
            probabilities = torch.sigmoid(logits)

        # Only one image is in the batch; decode its row on the CPU.
        return self.decode_predictions(probabilities[0].cpu(), threshold=threshold)

    @staticmethod
    def _split_onehot_name(
        onehot_name: str, feature_names: List[str]
    ) -> Optional[Tuple[str, str]]:
        """
        Map a one-hot column name to ``(original feature name, value)``.

        The name is first parsed as ``<letter><feature index>_<value>``
        (presumably scikit-learn's default one-hot naming, e.g. ``x0_left``).
        If the part after the first character is not a decimal index, the
        name is matched as ``<feature name>_<value>``, preferring the longest
        feature name, because feature names may themselves contain
        underscores. Returns ``None`` when no feature matches.
        """
        prefix, sep, value = onehot_name.partition("_")
        if not sep:
            return None

        index_text = prefix[1:]
        if index_text.isdecimal():
            index = int(index_text)
            if index < len(feature_names):
                return feature_names[index], value
            return None

        for name in sorted(feature_names, key=len, reverse=True):
            if onehot_name.startswith(name + "_"):
                return name, onehot_name[len(name) + 1:]
        return None

    def decode_predictions(
        self,
        probabilities: torch.Tensor,
        threshold: float = 0.5,
    ) -> Dict[str, Any]:
        """
        Decode model predictions to a human-readable format.

        Args:
            probabilities: 1-D tensor of sigmoid probabilities, one per entry
                of ``metadata['onehot_feature_names']``.
            threshold: Probability above which a label counts as predicted.

        Returns:
            dict: For every name in ``metadata['feature_names']``, a dict with
            ``'probability'`` (of the highest-scoring value), ``'prediction'``
            (whether that probability exceeds ``threshold``) and ``'label'``
            (that value, or ``None`` if not predicted or if no one-hot column
            maps to the feature).
        """
        onehot_feature_names = self.metadata["onehot_feature_names"]
        feature_names = self.metadata["feature_names"]

        # Convert once instead of indexing the tensor per element. The
        # threshold comparison is still done by torch, so the results are
        # identical to comparing element by element.
        probs = probabilities.tolist()
        flags = (probabilities > threshold).tolist()

        # Group one-hot columns under their original feature.
        grouped: Dict[str, List[Dict[str, Any]]] = {name: [] for name in feature_names}
        for i, onehot_name in enumerate(onehot_feature_names):
            match = self._split_onehot_name(onehot_name, feature_names)
            if match is None:
                continue
            feature_name, value = match
            grouped[feature_name].append({
                "value": value,
                "probability": probs[i],
                "prediction": bool(flags[i]),
            })

        # Keep the highest-probability value for each feature.
        results: Dict[str, Any] = {}
        for feature_name, candidates in grouped.items():
            if not candidates:
                results[feature_name] = {
                    "probability": 0.0,
                    "prediction": False,
                    "label": None,
                }
                continue
            best = max(candidates, key=lambda c: c["probability"])
            results[feature_name] = {
                "probability": best["probability"],
                "prediction": best["prediction"],
                "label": best["value"] if best["prediction"] else None,
            }
        return results

    def get_summary(self, predictions: Dict[str, Any]) -> str:
        """
        Get a human-readable summary of predictions.

        Args:
            predictions: Predictions dictionary, as returned by ``predict``.
                Entries that are not dicts with a ``'prediction'`` key are
                skipped.

        Returns:
            str: One header line plus one line per feature. Features without
            a predicted label show ``label=N/A``.
        """
        lines = ["Predictions:"]
        for feature, values in predictions.items():
            if not (isinstance(values, dict) and "prediction" in values):
                continue
            pred = "✓" if values["prediction"] else "✗"
            prob = values["probability"]
            # decode_predictions always sets 'label', using None for "no label",
            # so a plain .get(..., 'N/A') default would never apply.
            label = values.get("label")
            if label is None:
                label = "N/A"
            lines.append(f" {feature}: {pred} (prob={prob:.3f}, label={label})")
        return "\n".join(lines)


# Example usage
if __name__ == "__main__":
    # Initialize (device is auto-detected: CUDA when available, else CPU)
    inferencer = RetinaRadarInference(
        model_path="retinaradar_model.ckpt",
        metadata_path="label_metadata.json",
    )

    # Run inference
    predictions = inferencer.predict("example_image.png")

    # Print results
    print(inferencer.get_summary(predictions))

    # Access specific predictions (assumes the metadata defines these features)
    print(f"\nLaterality: {predictions['laterality']['label']}")
    print(f"Image usable: {predictions['usable']['prediction']}")