presscare / handler.py
RealNattawattHongthong's picture
Add RexNet pressure ulcer classification model with custom handler for inference endpoints
268f99f
import torch
import torch.nn as nn
from torchvision import transforms
from PIL import Image
import io
import base64
from typing import Dict, List, Any
import timm
class EndpointHandler:
    """Inference-endpoint handler wrapping a RexNet image classifier.

    The handler loads a ``rexnet_150`` model with one output per entry in
    ``class_names``, and on each call decodes a base64-encoded image, runs a
    forward pass and returns the softmax probabilities for every class
    together with the top-scoring predictions.
    """

    # Maximum number of highest-scoring classes listed under "predictions".
    TOP_K = 3

    def __init__(self, path=""):
        """Load the model weights and build the preprocessing pipeline.

        Args:
            path: Directory that contains ``pytorch_model.bin``. When empty,
                the file is looked up relative to the current working
                directory.
        """
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

        # Order matters: the label at index i names the model's output i.
        self.class_names = [
            'Invalid',
            'SDTI',
            'Stage_I',
            'Stage_II',
            'Stage_III',
            'Stage_IV',
            'Unstageable',
        ]

        # Derive the head size from the label list so the two cannot drift.
        self.model = timm.create_model(
            'rexnet_150',
            pretrained=False,
            num_classes=len(self.class_names),
        )

        model_path = f"{path}/pytorch_model.bin" if path else "pytorch_model.bin"
        # NOTE(review): torch.load may unpickle arbitrary objects; only point
        # this at checkpoints from a trusted source.
        state_dict = torch.load(model_path, map_location=self.device)
        self.model.load_state_dict(state_dict)
        self.model.to(self.device)
        self.model.eval()

        # NOTE(review): these look like the usual ImageNet mean/std values;
        # confirm they match the statistics used during training.
        self.transform = transforms.Compose([
            transforms.Resize((224, 224)),
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
        ])

    def __call__(self, data: Dict[str, Any]) -> List[Dict[str, Any]]:
        """Classify the image contained in an inference request.

        Args:
            data: Request payload. The image is read from ``data["inputs"]``
                (or from ``data`` itself when that key is absent) and must be
                either a base64 string or a dict with an ``"image"`` key
                holding the base64 data. A ``data:...;base64,`` URI prefix on
                a string is tolerated. The payload is not modified.

        Returns:
            A single-element list whose dict has the keys ``"predictions"``
            (top classes as ``{"label", "score"}`` dicts, best first),
            ``"probabilities"`` (label -> probability for every class),
            ``"predicted_class"`` and ``"confidence"``.

        Raises:
            ValueError: If the input has an unsupported format or the image
                cannot be decoded.
        """
        # .get() instead of .pop(): the caller's payload must stay intact.
        inputs = data.get("inputs", data)
        image = self._decode_image(self._extract_image_data(inputs))

        # Add the batch dimension the model expects, then move to its device.
        image_tensor = self.transform(image).unsqueeze(0).to(self.device)

        with torch.no_grad():
            outputs = self.model(image_tensor)
            # Batch size is always 1 here, so keep only the first row.
            probabilities = torch.nn.functional.softmax(outputs, dim=1)[0]

        return [self._format_response(probabilities)]

    @staticmethod
    def _extract_image_data(inputs):
        """Return the encoded image from the supported input shapes."""
        if isinstance(inputs, dict) and "image" in inputs:
            return inputs["image"]
        if isinstance(inputs, str):
            return inputs
        raise ValueError("Invalid input format. Expected {'image': base64_string} or base64_string")

    @staticmethod
    def _decode_image(image_data):
        """Decode base64 image data into an RGB PIL image."""
        try:
            if isinstance(image_data, str) and image_data.startswith("data:"):
                # Data URI ("data:image/png;base64,<payload>"): without this,
                # b64decode would drop the punctuation but keep the prefix
                # letters as payload and corrupt the image. Plain base64
                # never contains a comma, so splitting on the first is safe.
                image_data = image_data.partition(",")[2]
            image_bytes = base64.b64decode(image_data)
            return Image.open(io.BytesIO(image_bytes)).convert('RGB')
        except Exception as e:
            # Boundary conversion: any decoding failure is reported as a
            # ValueError, with the original error kept as the cause.
            raise ValueError(f"Failed to decode image: {e}") from e

    def _format_response(self, probabilities):
        """Build the response dict from a 1-D tensor of class probabilities."""
        # Clamp so a model with fewer outputs than TOP_K cannot break topk.
        k = min(self.TOP_K, probabilities.shape[-1])
        top_scores, top_indices = torch.topk(probabilities, k)

        # One tolist() per tensor instead of an .item() call per element.
        predictions = [
            {"label": self.class_names[index], "score": score}
            for index, score in zip(top_indices.tolist(), top_scores.tolist())
        ]
        all_probs = dict(zip(self.class_names, probabilities.tolist()))

        return {
            "predictions": predictions,
            "probabilities": all_probs,
            "predicted_class": predictions[0]["label"],
            "confidence": predictions[0]["score"],
        }