| | from PIL import Image
|
| | import io
|
| | import torch
|
| | import torch.nn as nn
|
| | import torchvision.transforms as transforms
|
| |
|
class EndpointHandler:
    """Inference endpoint handler: decodes raw image bytes and returns
    per-class softmax probabilities from a small CPU image classifier.
    """

    def __init__(self, model_path: str):
        """Load the trained network and build the preprocessing pipeline.

        Args:
            model_path: Filesystem path to a ``state_dict`` checkpoint
                loadable on CPU.
        """
        # Local import: the model definition ships alongside the deployed
        # artifact rather than being installable from the global path.
        from your_model_file import Net

        self.model = Net()
        # NOTE(review): torch.load is pickle-based — only load checkpoints
        # from trusted sources. On torch >= 2.0, weights_only=True would
        # harden this; left as-is to avoid breaking older runtimes.
        self.model.load_state_dict(torch.load(model_path, map_location="cpu"))
        self.model.eval()

        # Preprocessing pipeline. The image is already converted to mode "L"
        # (single-channel grayscale) in __call__, so the Grayscale() step the
        # original pipeline carried was a no-op and has been dropped.
        self.transform = transforms.Compose([
            transforms.Resize((28, 28)),
            transforms.ToTensor(),  # PIL image -> CHW float tensor in [0, 1]
        ])

    def __call__(self, data: dict) -> dict:
        """Run inference on the raw image bytes under ``data["inputs"]``.

        Args:
            data: Request payload; ``data["inputs"]`` must be the encoded
                image bytes (any format PIL can open).

        Returns:
            ``{"scores": [...]}`` — softmax probabilities, one per class.

        Raises:
            KeyError: If ``"inputs"`` is missing from *data*.
        """
        image_bytes = data["inputs"]
        image = Image.open(io.BytesIO(image_bytes)).convert("L")
        # unsqueeze adds the batch dimension: (1, 1, 28, 28).
        tensor = self.transform(image).unsqueeze(0)
        with torch.no_grad():
            logits = self.model(tensor)
            # [0] drops the batch dimension before serializing to a list.
            scores = torch.softmax(logits, dim=1)[0].tolist()

        return {"scores": scores}
|
| |
|