from typing import Dict, List, Any
import io
import torch
import timm
from timm.data import resolve_data_config
from timm.data.transforms_factory import create_transform
from PIL import Image
import gc
import os
import base64
class EndpointHandler:
    """Hugging Face Inference Endpoints handler for an OME detection model.

    Wraps a timm Inception-v4 binary classifier (otitis media with effusion).
    Accepts an image as a base64 string, URL, local file path, raw bytes, or
    PIL image, and returns ``[{"label": "OME", "score": float}]`` where a
    higher score means OME is more likely.
    """

    def __init__(self, path=""):
        """
        Initialize the endpoint handler with the OME detection model.

        Args:
            path (str): Path to the model weights. In HF Endpoints this is
                the local repository checkout; otherwise the model is pulled
                from the Hugging Face Hub.
        """
        # Inference runs on CPU to keep the endpoint's memory footprint low.
        self.device = torch.device("cpu")

        local_weights = os.path.join(path, "pytorch_model.bin")
        if os.path.isdir(path) and os.path.exists(local_weights):
            # HF Endpoints case: the model repository is checked out locally.
            print(f"Loading model from local path: {path}")
            # num_classes=1: single sigmoid logit for binary OME detection.
            self.model = timm.create_model("inception_v4", num_classes=1)
            # NOTE(review): torch.load without weights_only=True unpickles
            # arbitrary objects; acceptable only because the file ships with
            # the trusted model repo. Prefer weights_only=True on torch>=1.13.
            state_dict = torch.load(local_weights, map_location=self.device)
            self.model.load_state_dict(state_dict)
        else:
            # Fallback: pull the pretrained checkpoint from the Hub.
            print("Loading model from Hugging Face Hub: Thaweewat/inception_512_augv1")
            self.model = timm.create_model(
                "hf_hub:Thaweewat/inception_512_augv1", pretrained=True
            )

        self.model.to(self.device)
        self.model.eval()

        # Resolve the model-specific preprocessing config once and build the
        # (stateless) transform here instead of on every request.
        self.config = resolve_data_config({}, model=self.model)
        self.transform = create_transform(**self.config)

        self._free_memory()

    @staticmethod
    def _free_memory():
        """Best-effort release of cached CUDA blocks and Python garbage."""
        if torch.cuda.is_available():
            torch.cuda.empty_cache()
        gc.collect()

    def preprocess_image(self, image):
        """
        Center-crop to a square, resize to 512x512, and apply the model's
        timm transform.

        Args:
            image (PIL.Image.Image): Input image.

        Returns:
            torch.Tensor: Preprocessed image tensor ready for batching.
        """
        width, height = image.size
        # Largest square that fits the image, centered.
        crop_size = min(width, height)
        left = (width - crop_size) // 2
        top = (height - crop_size) // 2
        image = image.crop((left, top, left + crop_size, top + crop_size))
        # The model expects 512x512; only resample when necessary.
        if crop_size != 512:
            image = image.resize((512, 512), Image.LANCZOS)
        # Ensure 3-channel input (handles grayscale/RGBA/palette sources).
        image = image.convert('RGB')
        return self.transform(image)

    def _load_image(self, inputs):
        """Decode a request payload into a PIL image.

        Supports PIL images, raw bytes, URLs, base64 strings, and local
        file paths. Returns None when the input cannot be decoded.
        """
        if isinstance(inputs, Image.Image):
            return inputs
        if isinstance(inputs, bytes):
            return Image.open(io.BytesIO(inputs))
        if isinstance(inputs, str):
            if inputs.startswith(('http://', 'https://')):
                import requests
                # Bound the download so a slow host cannot hang the worker,
                # and fail fast on HTTP error pages instead of feeding them
                # to the image decoder.
                response = requests.get(inputs, timeout=30)
                response.raise_for_status()
                return Image.open(io.BytesIO(response.content))
            # Assume base64 first; fall back to treating it as a file path.
            try:
                image_bytes = base64.b64decode(inputs)
                return Image.open(io.BytesIO(image_bytes))
            except Exception as e:
                print(f"Error decoding base64: {e}")
                try:
                    return Image.open(inputs)
                except Exception as e2:
                    print(f"Error opening as file: {e2}")
                    return None
        print(f"Unsupported input type: {type(inputs)}")
        return None

    def __call__(self, data: Dict[str, Any]) -> List[Dict[str, Any]]:
        """
        Run OME detection on a request payload.

        Args:
            data (Dict[str, Any]): Must contain "inputs": a base64-encoded
                image, URL, file path, raw bytes, or PIL image.

        Returns:
            List[Dict[str, Any]]: ``[{"label": "OME", "score": float}]``.
                On any failure the score is 0.0.
        """
        fallback = [{"label": "OME", "score": 0.0}]
        try:
            if "inputs" not in data:
                print("No 'inputs' found in data")
                return fallback

            image = self._load_image(data["inputs"])
            if image is None:
                return fallback

            image_tensor = self.preprocess_image(image)

            # Inference only: skip autograd bookkeeping to save memory.
            with torch.no_grad():
                batch = image_tensor.unsqueeze(0).to(self.device)
                output = self.model(batch)
                if isinstance(output, tuple):
                    # Some timm models return (logits, aux); keep the logits.
                    output = output[0]
                if output.ndim > 1 and output.shape[1] > 1:
                    # Multi-class head: use only the first logit.
                    output = output[:, 0]
                prediction = torch.sigmoid(output).item()

            # Release tensors promptly to keep the CPU endpoint's RSS low.
            del image_tensor
            self._free_memory()

            # The raw sigmoid is high for NORMAL ears and low for OME, so
            # the reported OME confidence is the complement (1 - prediction).
            return [{"label": "OME", "score": float(1 - prediction)}]
        except Exception as e:
            print(f"Error processing image: {str(e)}")
            import traceback
            traceback.print_exc()
            return fallback