|
|
import torch |
|
|
import torch.nn as nn |
|
|
from torchvision import transforms |
|
|
from PIL import Image |
|
|
import io |
|
|
import base64 |
|
|
from typing import Dict, List, Any |
|
|
import timm |
|
|
|
|
|
class EndpointHandler:
    """Inference endpoint handler for a 7-class image classifier.

    Loads a ReXNet-150 (timm) model from ``pytorch_model.bin`` and serves
    single-image predictions from base64-encoded input. The label set
    (SDTI, Stage_I..IV, Unstageable) appears to be wound/pressure-injury
    staging — confirm against the training pipeline.
    """

    def __init__(self, path=""):
        """
        Initialize handler with model path.

        Args:
            path: Directory containing ``pytorch_model.bin``. When empty,
                the weights file is expected in the current working directory.
        """
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

        # Label list in model output-index order; must match training label order.
        self.class_names = ['Invalid', 'SDTI', 'Stage_I', 'Stage_II', 'Stage_III', 'Stage_IV', 'Unstageable']

        # Derive num_classes from class_names so the two can never drift apart.
        self.model = timm.create_model('rexnet_150', pretrained=False,
                                       num_classes=len(self.class_names))

        model_path = f"{path}/pytorch_model.bin" if path else "pytorch_model.bin"
        # weights_only=True restricts unpickling to tensors/containers, which is
        # sufficient for a state_dict and blocks arbitrary-code execution from a
        # tampered checkpoint (torch >= 1.13 required for this keyword).
        state_dict = torch.load(model_path, map_location=self.device, weights_only=True)
        self.model.load_state_dict(state_dict)
        self.model.to(self.device)
        self.model.eval()

        # 224x224 resize + ImageNet normalization constants.
        self.transform = transforms.Compose([
            transforms.Resize((224, 224)),
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
        ])

    def __call__(self, data: Dict[str, Any]) -> List[Dict[str, Any]]:
        """
        Process one inference request.

        Args:
            data: Request payload. Accepts ``{"inputs": {"image": <b64>}}``,
                ``{"inputs": <b64>}``, or a payload whose top level itself is
                one of those two shapes.

        Returns:
            A one-element list containing: top-3 ``predictions``
            (label/score pairs), the full per-class ``probabilities`` dict,
            and the argmax ``predicted_class`` with its ``confidence``.

        Raises:
            ValueError: If the input shape is unrecognized or the base64
                payload cannot be decoded into an image.
        """
        # .get (not .pop) so the caller's dict is never mutated.
        inputs = data.get("inputs", data)

        if isinstance(inputs, dict) and "image" in inputs:
            image_data = inputs["image"]
        elif isinstance(inputs, str):
            image_data = inputs
        else:
            raise ValueError("Invalid input format. Expected {'image': base64_string} or base64_string")

        try:
            image_bytes = base64.b64decode(image_data)
            # convert('RGB') normalizes grayscale/RGBA inputs to 3 channels.
            image = Image.open(io.BytesIO(image_bytes)).convert('RGB')
        except Exception as e:
            # Chain the cause so the original decode error stays visible.
            raise ValueError(f"Failed to decode image: {str(e)}") from e

        image_tensor = self.transform(image).unsqueeze(0).to(self.device)

        with torch.no_grad():
            outputs = self.model(image_tensor)
            probabilities = torch.nn.functional.softmax(outputs, dim=1)

        # Guard against configurations with fewer than 3 classes.
        top_k = min(3, len(self.class_names))
        top_prob, top_indices = torch.topk(probabilities, top_k)

        predictions = [
            {"label": self.class_names[idx.item()], "score": float(prob.item())}
            for prob, idx in zip(top_prob[0], top_indices[0])
        ]

        all_probs = {
            class_name: float(probabilities[0][i].item())
            for i, class_name in enumerate(self.class_names)
        }

        return [{
            "predictions": predictions,
            "probabilities": all_probs,
            "predicted_class": predictions[0]["label"],
            "confidence": predictions[0]["score"]
        }]