# RetinaRadar / hf_inference.py
# Author: Hunter Gill
# Initial commit with Git LFS (8554c13)
"""
RetinaRadar Inference Module for Hugging Face
This module provides easy inference for the RetinaRadar model on Hugging Face.
"""
from __future__ import annotations

import json
from pathlib import Path
from typing import Any, Dict, Union

import numpy as np
import torch
from PIL import Image

import albumentations as A
from albumentations.pytorch import ToTensorV2
class RetinaRadarInference:
    """
    Inference handler for the RetinaRadar multi-label classifier.

    Loads a serialized model checkpoint and its label metadata, then exposes
    ``predict`` for end-to-end inference on a single image. Metadata is
    expected to contain ``feature_names`` (original feature labels) and
    ``onehot_feature_names`` (one-hot column names of the form
    ``f<index>_<value>``, where ``<index>`` points into ``feature_names``).
    """

    # ImageNet normalization statistics — presumably what the model was
    # trained with (standard torchvision values); confirm against training.
    IMAGENET_MEAN = [0.485, 0.456, 0.406]
    IMAGENET_STD = [0.229, 0.224, 0.225]

    def __init__(
        self,
        model_path: str = "retinaradar_model.ckpt",
        metadata_path: str = "label_metadata.json",
        device: str = "cuda" if torch.cuda.is_available() else "cpu"
    ):
        """
        Initialize the inference handler.

        Args:
            model_path: Path to the model checkpoint (a pickled full model
                object, not a state dict — it is called directly below).
            metadata_path: Path to the label metadata JSON.
            device: Device to run inference on ('cuda' or 'cpu'). Defaults
                to CUDA when available.

        Raises:
            FileNotFoundError: If either file does not exist.
            json.JSONDecodeError: If the metadata file is not valid JSON.
        """
        self.device = device

        # SECURITY NOTE: torch.load unpickles arbitrary Python objects.
        # Only load checkpoints from trusted sources; weights_only=True is
        # not applicable here because the checkpoint stores the full model.
        self.model = torch.load(model_path, map_location=device)
        self.model.eval()
        self.model.to(device)

        with open(metadata_path, 'r') as f:
            self.metadata = json.load(f)

        # Preprocessing pipeline: resize shorter pipeline to 256x256, center
        # crop to the model's 224x224 input, normalize, convert HWC -> CHW.
        self.transform = A.Compose([
            A.Resize(256, 256),
            A.CenterCrop(224, 224),
            A.Normalize(mean=self.IMAGENET_MEAN, std=self.IMAGENET_STD),
            ToTensorV2(),
        ])

    def preprocess(self, image: Union[str, Path, Image.Image, np.ndarray]) -> torch.Tensor:
        """
        Preprocess an image for inference.

        Args:
            image: Image path, PIL Image, or numpy array. Paths and PIL
                images are converted to RGB; a raw numpy array is assumed
                to already be an HWC RGB uint8 image — TODO confirm callers.

        Returns:
            torch.Tensor: Preprocessed (1, C, 224, 224) tensor on
            ``self.device``.
        """
        # Accept a path-like and load it as RGB.
        if isinstance(image, (str, Path)):
            image = Image.open(image).convert('RGB')

        # Albumentations operates on numpy arrays, not PIL images.
        if isinstance(image, Image.Image):
            image = np.array(image)

        transformed = self.transform(image=image)
        # Add the batch dimension expected by the model.
        image_tensor = transformed["image"].unsqueeze(0)
        return image_tensor.to(self.device)

    def predict(
        self,
        image: Union[str, Path, Image.Image, np.ndarray],
        threshold: float = 0.5
    ) -> Dict[str, Any]:
        """
        Run inference on a single image.

        Args:
            image: Image to process (path, PIL Image, or numpy array).
            threshold: Sigmoid-probability threshold for a positive call.

        Returns:
            dict: Per-feature predictions (see ``decode_predictions``).
        """
        image_tensor = self.preprocess(image)

        # Multi-label head: independent sigmoid per one-hot output.
        with torch.no_grad():
            logits = self.model(image_tensor)
            probabilities = torch.sigmoid(logits)

        # Single-image batch: decode the first (only) row on CPU.
        return self.decode_predictions(
            probabilities[0].cpu(),
            threshold=threshold
        )

    def decode_predictions(
        self,
        probabilities: torch.Tensor,
        threshold: float = 0.5
    ) -> Dict[str, Any]:
        """
        Decode model probabilities to a human-readable, per-feature format.

        One-hot column names look like ``f<index>_<value>``; columns sharing
        an index belong to the same original feature, and the highest-scoring
        value wins for that feature.

        Args:
            probabilities: 1-D tensor of sigmoid probabilities, one per
                one-hot column.
            threshold: Threshold for binary positive/negative calls.

        Returns:
            dict: ``{feature_name: {'probability', 'prediction', 'label'}}``.
                ``label`` is the winning value string when predicted
                positive, else ``None``. Features with no matching one-hot
                columns get a zero-probability negative entry.
        """
        binary_predictions = (probabilities > threshold).float()

        onehot_feature_names = self.metadata['onehot_feature_names']
        feature_names = self.metadata['feature_names']

        # Group one-hot columns back under their original feature.
        feature_predictions = {fname: [] for fname in feature_names}
        for i, onehot_name in enumerate(onehot_feature_names):
            if '_' in onehot_name:
                prefix, value = onehot_name.split('_', 1)
                # prefix is "f<index>"; strip the leading letter.
                feature_idx = int(prefix[1:])
                # Guard against stale metadata with out-of-range indices.
                if feature_idx < len(feature_names):
                    original_feature_name = feature_names[feature_idx]
                    feature_predictions[original_feature_name].append({
                        'value': value,
                        'probability': float(probabilities[i]),
                        'prediction': bool(binary_predictions[i])
                    })

        # Keep only the highest-probability value per feature.
        results = {}
        for feature_name, predictions_list in feature_predictions.items():
            if not predictions_list:
                # No one-hot columns mapped to this feature.
                results[feature_name] = {
                    'probability': 0.0,
                    'prediction': False,
                    'label': None
                }
                continue

            best_pred = max(predictions_list, key=lambda x: x['probability'])
            results[feature_name] = {
                'probability': best_pred['probability'],
                'prediction': best_pred['prediction'],
                # Only expose a label for confident (above-threshold) calls.
                'label': best_pred['value'] if best_pred['prediction'] else None
            }

        return results

    def get_summary(self, predictions: Dict[str, Any]) -> str:
        """
        Format predictions as a human-readable multi-line summary.

        Args:
            predictions: Output of ``decode_predictions`` / ``predict``.

        Returns:
            str: One line per feature with a check/cross mark, probability,
            and label (``N/A`` when no label was predicted).
        """
        lines = ["Predictions:"]
        for feature, values in predictions.items():
            if isinstance(values, dict) and 'prediction' in values:
                pred = "✓" if values['prediction'] else "✗"
                prob = values['probability']
                # Fix: the label key always exists (possibly None), so the
                # old `.get(..., 'N/A')` default never fired and printed
                # "label=None"; substitute N/A explicitly.
                label = values['label'] if values.get('label') is not None else 'N/A'
                lines.append(f"  {feature}: {pred} (prob={prob:.3f}, label={label})")
        return "\n".join(lines)
# Example usage
if __name__ == "__main__":
    # Initialize. The device argument is deliberately omitted: the class
    # auto-detects CUDA and falls back to CPU, so this example no longer
    # crashes on CPU-only machines (previously hard-coded device="cuda").
    inferencer = RetinaRadarInference(
        model_path="retinaradar_model.ckpt",
        metadata_path="label_metadata.json",
    )

    # Run inference on a sample image.
    predictions = inferencer.predict("example_image.png")

    # Print a human-readable summary of all features.
    print(inferencer.get_summary(predictions))

    # Access specific predictions by feature name.
    print(f"\nLaterality: {predictions['laterality']['label']}")
    print(f"Image usable: {predictions['usable']['prediction']}")