File size: 6,609 Bytes
8554c13
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
"""
RetinaRadar Inference Module for Hugging Face

This module provides easy inference for the RetinaRadar model on Hugging Face.
"""

import json
from pathlib import Path
from typing import Union, Dict, Any

import albumentations as A
import numpy as np
import torch
from albumentations.pytorch import ToTensorV2
from PIL import Image


class RetinaRadarInference:
    """
    Inference handler for the RetinaRadar model on Hugging Face.

    Loads a serialized model checkpoint plus a label-metadata JSON file,
    preprocesses images with standard ImageNet normalization, and decodes
    the model's multi-label sigmoid outputs into per-feature predictions.
    """

    def __init__(
        self,
        model_path: str = "retinaradar_model.ckpt",
        metadata_path: str = "label_metadata.json",
        device: str = "cuda" if torch.cuda.is_available() else "cpu"
    ):
        """
        Initialize the inference handler.

        Args:
            model_path: Path to the model checkpoint.
            metadata_path: Path to label metadata JSON; must contain the
                keys 'onehot_feature_names' and 'feature_names'.
            device: Device to run inference on ('cuda' or 'cpu').
                NOTE: the default is evaluated once at import time.
        """
        self.device = device

        # SECURITY: torch.load unpickles arbitrary Python objects -- only
        # load checkpoints from trusted sources. Since torch 2.6 the
        # default weights_only=True refuses full pickled model objects,
        # so request the legacy behavior explicitly.
        try:
            self.model = torch.load(
                model_path, map_location=device, weights_only=False
            )
        except TypeError:
            # torch < 1.13 does not accept the weights_only keyword.
            self.model = torch.load(model_path, map_location=device)
        self.model.eval()
        self.model.to(device)

        # Label metadata maps flat one-hot output indices back to features.
        with open(metadata_path, 'r') as f:
            self.metadata = json.load(f)

        # Standard ImageNet preprocessing: resize, center-crop, normalize.
        IMAGENET_MEAN = [0.485, 0.456, 0.406]
        IMAGENET_STD = [0.229, 0.224, 0.225]

        self.transform = A.Compose([
            A.Resize(256, 256),
            A.CenterCrop(224, 224),
            A.Normalize(mean=IMAGENET_MEAN, std=IMAGENET_STD),
            ToTensorV2(),
        ])

    def preprocess(
        self, image: Union[str, Path, "Image.Image", np.ndarray]
    ) -> torch.Tensor:
        """
        Preprocess an image for inference.

        Args:
            image: Image path, PIL Image, or RGB numpy array.

        Returns:
            torch.Tensor: Preprocessed, batched tensor on self.device.
        """
        # Load from disk when given a path; force 3-channel RGB.
        if isinstance(image, (str, Path)):
            image = Image.open(image).convert('RGB')

        # Albumentations expects a numpy array, not a PIL image.
        if isinstance(image, Image.Image):
            image = np.array(image)

        transformed = self.transform(image=image)
        # Add the batch dimension expected by the model.
        image_tensor = transformed["image"].unsqueeze(0)

        return image_tensor.to(self.device)

    def predict(
        self,
        image: Union[str, Path, "Image.Image", np.ndarray],
        threshold: float = 0.5
    ) -> Dict[str, Any]:
        """
        Run inference on a single image.

        Args:
            image: Image to process (path, PIL Image, or numpy array).
            threshold: Probability threshold for a positive prediction.

        Returns:
            dict: Decoded predictions keyed by feature name (see
                decode_predictions).
        """
        image_tensor = self.preprocess(image)

        # Multi-label head: an independent sigmoid per one-hot output.
        with torch.no_grad():
            logits = self.model(image_tensor)
            probabilities = torch.sigmoid(logits)

        # Drop the batch dimension and decode on CPU.
        return self.decode_predictions(
            probabilities[0].cpu(),
            threshold=threshold
        )

    def decode_predictions(
        self,
        probabilities: torch.Tensor,
        threshold: float = 0.5
    ) -> Dict[str, Any]:
        """
        Decode model probabilities into a human-readable result per feature.

        One-hot output names are expected to look like 'f<idx>_<value>',
        where <idx> indexes into metadata['feature_names']. Malformed or
        out-of-range names are skipped instead of raising.

        Args:
            probabilities: 1-D tensor of sigmoid probabilities, one per
                entry in metadata['onehot_feature_names'].
            threshold: Threshold for binary predictions.

        Returns:
            dict: For each feature name, a dict with 'probability',
                'prediction' (bool), and 'label' (value string, or None
                when nothing cleared the threshold).
        """
        binary_predictions = (probabilities > threshold).float()

        onehot_feature_names = self.metadata['onehot_feature_names']
        feature_names = self.metadata['feature_names']

        # Group the flat one-hot outputs back under their source feature.
        feature_predictions = {fname: [] for fname in feature_names}

        for i, onehot_name in enumerate(onehot_feature_names):
            prefix, sep, value = onehot_name.partition('_')
            if not sep:
                continue  # no '_' separator: not a one-hot column name
            try:
                feature_idx = int(prefix[1:])  # strip the leading 'f'
            except ValueError:
                continue  # non-numeric index: skip malformed entry
            # Reject out-of-range AND negative indices (a negative index
            # would silently wrap to the end of feature_names).
            if not 0 <= feature_idx < len(feature_names):
                continue

            feature_predictions[feature_names[feature_idx]].append({
                'value': value,
                'probability': float(probabilities[i]),
                'prediction': bool(binary_predictions[i])
            })

        # For each feature, keep only its most probable candidate value.
        results = {}
        for feature_name, predictions_list in feature_predictions.items():
            if not predictions_list:
                # Feature had no (valid) one-hot columns at all.
                results[feature_name] = {
                    'probability': 0.0,
                    'prediction': False,
                    'label': None
                }
                continue

            best_pred = max(predictions_list, key=lambda x: x['probability'])

            results[feature_name] = {
                'probability': best_pred['probability'],
                'prediction': best_pred['prediction'],
                # Only report a label when the winner cleared the threshold.
                'label': best_pred['value'] if best_pred['prediction'] else None
            }

        return results

    def get_summary(self, predictions: Dict[str, Any]) -> str:
        """
        Format a predictions dict as a human-readable multi-line string.

        Args:
            predictions: Output of predict()/decode_predictions().

        Returns:
            str: A "Predictions:" header followed by one line per feature.
        """
        lines = ["Predictions:"]

        for feature, values in predictions.items():
            # Tolerate extra non-feature entries in the dict.
            if isinstance(values, dict) and 'prediction' in values:
                mark = "✓" if values['prediction'] else "✗"
                prob = values['probability']
                label = values.get('label', 'N/A')
                lines.append(f"  {feature}: {mark} (prob={prob:.3f}, label={label})")

        return "\n".join(lines)


def main() -> None:
    """Example usage: run RetinaRadar on a sample image and print results."""
    # Build the inference handler from the bundled checkpoint + metadata.
    inferencer = RetinaRadarInference(
        model_path="retinaradar_model.ckpt",
        metadata_path="label_metadata.json",
        device="cuda",
    )

    # Run a single-image prediction.
    predictions = inferencer.predict("example_image.png")

    # Human-readable overview of every predicted feature.
    print(inferencer.get_summary(predictions))

    # Individual features can be read straight out of the result dict.
    print(f"\nLaterality: {predictions['laterality']['label']}")
    print(f"Image usable: {predictions['usable']['prediction']}")


if __name__ == "__main__":
    main()