|
|
""" |
|
|
RetinaRadar Inference Module for Hugging Face |
|
|
|
|
|
This module provides easy inference for the RetinaRadar model on Hugging Face. |
|
|
""" |
|
|
|
|
|
import json
from pathlib import Path
from typing import Union, Dict, Any

import albumentations as A
import numpy as np
import torch
from albumentations.pytorch import ToTensorV2
from PIL import Image
|
|
|
|
|
|
|
|
class RetinaRadarInference:
    """
    Inference handler for the RetinaRadar model on Hugging Face.

    Wraps checkpoint loading, eval-time preprocessing, the forward pass,
    and decoding of the model's multi-label sigmoid outputs back into
    per-feature predictions.
    """

    def __init__(
        self,
        model_path: str = "retinaradar_model.ckpt",
        metadata_path: str = "label_metadata.json",
        device: str = "cuda" if torch.cuda.is_available() else "cpu"
    ):
        """
        Initialize the inference handler.

        Args:
            model_path: Path to the model checkpoint. NOTE(review): passed
                straight to ``torch.load``, so it is expected to hold a full
                pickled module (not a bare state dict) -- confirm. PyTorch
                >= 2.6 would additionally need ``weights_only=False`` here.
            metadata_path: Path to the label metadata JSON; must provide the
                'onehot_feature_names' and 'feature_names' keys consumed by
                ``decode_predictions``.
            device: Device to run inference on ('cuda' or 'cpu'). The default
                is resolved once, when the class definition is executed.
        """
        self.device = device

        # Security: torch.load unpickles arbitrary objects -- only load
        # checkpoints from trusted sources.
        self.model = torch.load(model_path, map_location=device)
        self.model.eval()
        self.model.to(device)

        with open(metadata_path, 'r') as f:
            self.metadata = json.load(f)

        # Standard ImageNet normalization statistics (RGB channel order).
        IMAGENET_MEAN = [0.485, 0.456, 0.406]
        IMAGENET_STD = [0.229, 0.224, 0.225]

        # Deterministic eval-time pipeline: resize -> center crop ->
        # normalize -> CHW tensor (the usual 256 -> 224 recipe).
        self.transform = A.Compose([
            A.Resize(256, 256),
            A.CenterCrop(224, 224),
            A.Normalize(mean=IMAGENET_MEAN, std=IMAGENET_STD),
            ToTensorV2(),
        ])

    def preprocess(self, image: Union[str, Path, Image.Image, np.ndarray]) -> torch.Tensor:
        """
        Preprocess an image for inference.

        Args:
            image: Image path, PIL Image, or numpy array. Arrays are fed to
                the albumentations pipeline as-is and are assumed to be
                HxWxC RGB -- TODO confirm with callers.

        Returns:
            torch.Tensor: Preprocessed (1, C, H, W) tensor on ``self.device``.
        """
        # Load from disk when given a path; force 3-channel RGB.
        if isinstance(image, (str, Path)):
            image = Image.open(image).convert('RGB')

        # Albumentations operates on numpy arrays, not PIL images.
        if isinstance(image, Image.Image):
            image = np.array(image)

        transformed = self.transform(image=image)
        # Prepend the batch dimension expected by the model.
        image_tensor = transformed["image"].unsqueeze(0)

        return image_tensor.to(self.device)

    def predict(
        self,
        image: Union[str, Path, Image.Image, np.ndarray],
        threshold: float = 0.5
    ) -> Dict[str, Any]:
        """
        Run inference on a single image.

        Args:
            image: Image to process (path, PIL Image, or numpy array).
            threshold: Probability threshold for binary predictions.

        Returns:
            dict: Per-feature predictions with labels and probabilities
                (see ``decode_predictions``).
        """
        image_tensor = self.preprocess(image)

        # Multi-label head: an independent sigmoid per one-hot output.
        with torch.no_grad():
            logits = self.model(image_tensor)
            probabilities = torch.sigmoid(logits)

        # Single-image batch -> decode the first (only) row on the CPU.
        predictions = self.decode_predictions(
            probabilities[0].cpu(),
            threshold=threshold
        )

        return predictions

    def decode_predictions(
        self,
        probabilities: torch.Tensor,
        threshold: float = 0.5
    ) -> Dict[str, Any]:
        """
        Decode model predictions to a human-readable format.

        One-hot output columns are expected to be named '<fIDX>_<value>'
        (e.g. 'f0_left'), where IDX indexes into
        ``metadata['feature_names']``. For each original feature the
        highest-probability candidate wins.

        Args:
            probabilities: 1-D tensor of sigmoid probabilities aligned with
                ``metadata['onehot_feature_names']``.
            threshold: Threshold for binary predictions.

        Returns:
            dict: Mapping of feature name to
                {'probability': float, 'prediction': bool, 'label': str|None};
                'label' is None when the best candidate did not pass the
                threshold or the feature had no one-hot columns.
        """
        binary_predictions = (probabilities > threshold).float()

        onehot_feature_names = self.metadata['onehot_feature_names']
        feature_names = self.metadata['feature_names']

        # Group the flat one-hot outputs back under their original features.
        feature_predictions = {fname: [] for fname in feature_names}

        for i, onehot_name in enumerate(onehot_feature_names):
            if '_' not in onehot_name:
                continue  # not a '<fIDX>_<value>' column name; ignore
            prefix, value = onehot_name.split('_', 1)
            index_digits = prefix[1:]
            # Robustness fix: skip malformed prefixes instead of letting
            # int() raise ValueError on unexpected metadata entries.
            if not index_digits.isdigit():
                continue
            feature_idx = int(index_digits)

            if feature_idx < len(feature_names):
                original_feature_name = feature_names[feature_idx]

                feature_predictions[original_feature_name].append({
                    'value': value,
                    'probability': float(probabilities[i]),
                    'prediction': bool(binary_predictions[i])
                })

        # Keep only the best candidate per feature (argmax over probability).
        results = {}
        for feature_name, predictions_list in feature_predictions.items():
            if not predictions_list:
                # Feature had no matching one-hot columns in the metadata.
                results[feature_name] = {
                    'probability': 0.0,
                    'prediction': False,
                    'label': None
                }
                continue

            best_pred = max(predictions_list, key=lambda x: x['probability'])

            results[feature_name] = {
                'probability': best_pred['probability'],
                'prediction': best_pred['prediction'],
                # Only report a label when the winner passed the threshold.
                'label': best_pred['value'] if best_pred['prediction'] else None
            }

        return results

    def get_summary(self, predictions: Dict[str, Any]) -> str:
        """
        Get a human-readable summary of predictions.

        Args:
            predictions: Predictions dictionary from ``decode_predictions``.

        Returns:
            str: Header line plus one formatted line per feature.
        """
        lines = ["Predictions:"]

        for feature, values in predictions.items():
            if isinstance(values, dict) and 'prediction' in values:
                pred = "✓" if values['prediction'] else "✗"
                prob = values['probability']
                # Bug fix: 'label' is always present (possibly None), so the
                # old .get(..., 'N/A') default was dead code and the string
                # 'None' leaked into the output; show 'N/A' explicitly.
                label = values['label'] if values.get('label') is not None else 'N/A'
                lines.append(f" {feature}: {pred} (prob={prob:.3f}, label={label})")

        return "\n".join(lines)
|
|
|
|
|
|
|
|
|
|
|
if __name__ == "__main__":
    # Portability fix: the original hard-coded device="cuda", which crashes
    # on CPU-only machines; auto-detect the device instead.
    device = "cuda" if torch.cuda.is_available() else "cpu"

    inferencer = RetinaRadarInference(
        model_path="retinaradar_model.ckpt",
        metadata_path="label_metadata.json",
        device=device
    )

    # Run the end-to-end pipeline on a sample image.
    predictions = inferencer.predict("example_image.png")

    print(inferencer.get_summary(predictions))

    # NOTE(review): assumes 'laterality' and 'usable' features exist in
    # label_metadata.json -- verify against the shipped metadata file.
    print(f"\nLaterality: {predictions['laterality']['label']}")
    print(f"Image usable: {predictions['usable']['prediction']}")
|
|
|