import base64
import io
import os
from typing import Any, Dict, List

import timm
import torch
import torch.nn as nn
from PIL import Image
from torchvision import transforms


class EndpointHandler:
    """Classify a single image with a RexNet-150 model.

    The model weights are loaded once at construction time. Each call decodes
    an image from the request payload, preprocesses it, runs the model and
    returns the top-ranked labels together with the full softmax distribution
    over ``class_names``.
    """

    # Number of ranked entries returned under "predictions" (clamped to the
    # number of classes at call time).
    top_k = 3

    def __init__(self, path=""):
        """Load the model weights and build the preprocessing pipeline.

        Args:
            path: Directory containing ``pytorch_model.bin``. An empty string
                (or ``None``) resolves the file relative to the current
                working directory.
        """
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

        # Output labels, in the order of the model's logits.
        self.class_names = [
            'Invalid', 'SDTI', 'Stage_I', 'Stage_II',
            'Stage_III', 'Stage_IV', 'Unstageable',
        ]

        # Head size is derived from the label list so the two cannot drift apart.
        self.model = timm.create_model(
            'rexnet_150', pretrained=False, num_classes=len(self.class_names)
        )

        # os.path.join("", name) == name, matching the original empty-path behaviour.
        model_path = os.path.join(path or "", "pytorch_model.bin")
        # SECURITY NOTE(review): torch.load unpickles the file, which can execute
        # arbitrary code. Only load checkpoints from a trusted source; on PyTorch
        # versions that support it, consider passing weights_only=True.
        state_dict = torch.load(model_path, map_location=self.device)
        self.model.load_state_dict(state_dict)
        self.model.to(self.device)
        self.model.eval()

        # Resize to 224x224, convert to a tensor, and normalise with the commonly
        # used ImageNet mean/std (presumably matching training — verify).
        self.transform = transforms.Compose([
            transforms.Resize((224, 224)),
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                 std=[0.229, 0.224, 0.225]),
        ])

    def __call__(self, data: Dict[str, Any]) -> List[Dict[str, Any]]:
        """Classify the image contained in *data*.

        Args:
            data: Request payload. The inputs are taken from ``data["inputs"]``
                if present, otherwise ``data`` itself is used. The inputs may be
                a dict ``{"image": <base64 string>}``, a base64 string
                (optionally a ``data:<mime>;base64,`` URI), or an
                already-decoded ``PIL.Image.Image``.

        Returns:
            A one-element list holding a dict with:
            ``predictions`` (up to ``top_k`` ``{"label", "score"}`` entries,
            highest score first), ``probabilities`` (softmax score per class
            name), ``predicted_class`` and ``confidence`` (the top entry's
            label and score).

        Raises:
            ValueError: If the payload has an unsupported shape or the image
                cannot be decoded.
        """
        # .get rather than .pop so the caller's payload dict is not mutated.
        inputs = data.get("inputs", data)

        if isinstance(inputs, dict) and "image" in inputs:
            image_data = inputs["image"]
        elif isinstance(inputs, (str, Image.Image)):
            image_data = inputs
        else:
            raise ValueError(
                "Invalid input format. "
                "Expected {'image': base64_string} or base64_string"
            )

        image = self._load_image(image_data)
        image_tensor = self.transform(image).unsqueeze(0).to(self.device)

        with torch.no_grad():
            logits = self.model(image_tensor)
            # The batch holds one image, so keep only row 0 of the softmax output.
            probabilities = torch.nn.functional.softmax(logits, dim=1)[0]

        # Clamp so topk never asks for more entries than there are classes.
        k = min(self.top_k, probabilities.numel())
        top_probs, top_indices = torch.topk(probabilities, k)
        predictions = [
            {"label": self.class_names[idx], "score": float(prob)}
            for prob, idx in zip(top_probs.tolist(), top_indices.tolist())
        ]

        all_probs = {
            class_name: float(prob)
            for class_name, prob in zip(self.class_names, probabilities.tolist())
        }

        return [{
            "predictions": predictions,
            "probabilities": all_probs,
            "predicted_class": predictions[0]["label"],
            "confidence": predictions[0]["score"],
        }]

    @staticmethod
    def _load_image(image_data: Any) -> Image.Image:
        """Return *image_data* as an RGB PIL image, base64-decoding it if needed.

        Raises:
            ValueError: If the payload cannot be decoded into an image.
        """
        if isinstance(image_data, Image.Image):
            return image_data.convert('RGB')

        if isinstance(image_data, str) and image_data.startswith("data:"):
            # Strip a data-URI header such as "data:image/png;base64,". Without
            # this, b64decode discards the header's punctuation but keeps its
            # letters, silently corrupting the decoded bytes.
            image_data = image_data.partition(",")[2]

        try:
            image_bytes = base64.b64decode(image_data)
            return Image.open(io.BytesIO(image_bytes)).convert('RGB')
        except Exception as e:
            raise ValueError(f"Failed to decode image: {str(e)}") from e