File size: 7,230 Bytes
8655c2b
 
 
 
 
 
 
 
9906796
b85a040
8655c2b
 
 
 
 
 
 
 
 
 
 
 
9906796
 
 
 
 
 
 
 
 
 
 
 
 
 
 
8655c2b
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
045d71c
bc5a8a0
8655c2b
 
 
 
 
 
b85a040
8655c2b
 
 
 
b85a040
8655c2b
b85a040
 
 
 
 
 
 
 
 
 
bc5a8a0
b85a040
 
 
 
 
 
 
 
bc5a8a0
8655c2b
b85a040
bc5a8a0
8655c2b
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
bc5a8a0
045d71c
bc5a8a0
 
8655c2b
bc5a8a0
 
 
045d71c
bc5a8a0
 
8655c2b
 
045d71c
b85a040
 
bc5a8a0
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
from typing import Dict, List, Any
import io
import torch
import timm
from timm.data import resolve_data_config
from timm.data.transforms_factory import create_transform
from PIL import Image
import gc
import os
import base64

class EndpointHandler:
    """Hugging Face Inference Endpoints handler for the OME detection model.

    Loads an ``inception_v4`` binary classifier (single output logit) and
    exposes a ``__call__`` interface that accepts an image — URL, base64
    string, raw bytes, or PIL image — and returns
    ``[{"label": "OME", "score": float}]``.
    """

    def __init__(self, path=""):
        """
        Initialize the endpoint handler with the OME detection model.

        Args:
            path (str): Path to the model weights. Either a local directory
                containing ``pytorch_model.bin`` (the HF Endpoints layout) or
                anything else, in which case the weights are pulled from the
                Hugging Face Hub.
        """
        # Run on CPU to reduce memory usage.
        self.device = torch.device("cpu")

        # In HF Endpoints, the model is loaded from the local repository
        # directory when present.
        state_dict_path = os.path.join(path, "pytorch_model.bin")
        if os.path.isdir(path) and os.path.exists(state_dict_path):
            print(f"Loading model from local path: {path}")
            self.model = timm.create_model("inception_v4", num_classes=1)
            # weights_only=True: a state dict needs no pickled code objects,
            # and the flag guards against arbitrary-code-execution payloads
            # in an untrusted checkpoint (requires torch >= 1.13).
            state_dict = torch.load(
                state_dict_path, map_location=self.device, weights_only=True
            )
            self.model.load_state_dict(state_dict)
        else:
            # Fall back to the Hugging Face Hub ID.
            print("Loading model from Hugging Face Hub: Thaweewat/inception_512_augv1")
            self.model = timm.create_model(
                "hf_hub:Thaweewat/inception_512_augv1", pretrained=True
            )

        self.model.to(self.device)
        self.model.eval()

        # Model-specific preprocessing config. Build the transform once here
        # rather than on every request (it is request-invariant).
        self.config = resolve_data_config({}, model=self.model)
        self.transform = create_transform(**self.config)

        # Free up memory left over from model construction.
        if torch.cuda.is_available():
            torch.cuda.empty_cache()
        gc.collect()

    def preprocess_image(self, image):
        """
        Preprocess the image for model inference: center-crop to a square,
        resize to 512x512, convert to RGB, then apply the model's timm
        transform.

        Args:
            image (PIL.Image.Image): Input image.

        Returns:
            torch.Tensor: Preprocessed image tensor (no batch dimension).
        """
        width, height = image.size

        # Square crop on the smaller dimension, centered.
        crop_size = min(width, height)
        left = (width - crop_size) // 2
        top = (height - crop_size) // 2
        image = image.crop((left, top, left + crop_size, top + crop_size))

        # Resize to 512x512 if not already that size.
        if crop_size != 512:
            image = image.resize((512, 512), Image.LANCZOS)

        # Ensure 3-channel input regardless of source mode (L, RGBA, ...).
        image = image.convert('RGB')

        return self.transform(image)

    def _load_image(self, inputs):
        """Resolve the accepted input formats into a PIL image.

        Accepts an http(s) URL, a base64-encoded string (with a file-path
        fallback), raw bytes, or a PIL image. Returns ``None`` when the
        input cannot be interpreted as an image.
        """
        if isinstance(inputs, str) and inputs.startswith(("http://", "https://")):
            import requests
            # Timeout so an unresponsive URL cannot hang the worker forever.
            response = requests.get(inputs, timeout=30)
            return Image.open(io.BytesIO(response.content))
        if isinstance(inputs, str):
            # Assume base64 encoded image.
            try:
                image_bytes = base64.b64decode(inputs)
                return Image.open(io.BytesIO(image_bytes))
            except Exception as e:
                print(f"Error decoding base64: {e}")
                # Try to open as file path
                try:
                    return Image.open(inputs)
                except Exception as e2:
                    print(f"Error opening as file: {e2}")
                    return None
        if isinstance(inputs, bytes):
            # Handle binary data directly.
            return Image.open(io.BytesIO(inputs))
        if isinstance(inputs, Image.Image):
            return inputs
        print(f"Unsupported input type: {type(inputs)}")
        return None

    def __call__(self, data: Dict[str, Any]) -> List[Dict[str, Any]]:
        """
        Process the input data and return predictions.

        Args:
            data (Dict[str, Any]): Input data containing "inputs": a URL,
                base64 string, raw bytes, or PIL image.

        Returns:
            List[Dict[str, Any]]: Prediction results in the format required
                by HF Endpoints: ``[{"label": "OME", "score": float}]``.
                A score of 0.0 is returned on any input/processing error.
        """
        try:
            if "inputs" not in data:
                print("No 'inputs' found in data")
                return [{"label": "OME", "score": 0.0}]

            image = self._load_image(data["inputs"])
            if image is None:
                return [{"label": "OME", "score": 0.0}]

            image_tensor = self.preprocess_image(image)

            # Disable gradient tracking to save memory during inference.
            with torch.no_grad():
                image_tensor = image_tensor.unsqueeze(0).to(self.device)
                output = self.model(image_tensor)

                # Some models return multiple outputs as a tuple.
                if isinstance(output, tuple):
                    output = output[0]

                # If output has multiple classes, take the first one.
                if output.ndim > 1 and output.shape[1] > 1:
                    output = output[:, 0]

                prediction = torch.sigmoid(output).item()

            # Free memory before returning.
            del image_tensor
            if torch.cuda.is_available():
                torch.cuda.empty_cache()
            gc.collect()

            # Reversed logic based on the model's behavior:
            # scores near 1.0 indicate a normal ear (no OME), scores near
            # 0.0 indicate OME — so 1 - prediction is the OME confidence.
            ome_score = float(1 - prediction)

            # Always return "OME" as the label with the appropriate score.
            return [{"label": "OME", "score": ome_score}]

        except Exception as e:
            print(f"Error processing image: {str(e)}")
            import traceback
            traceback.print_exc()
            return [{"label": "OME", "score": 0.0}]