File size: 1,031 Bytes
756deb2 59a758f 756deb2 | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 | import torch
from PIL import Image
import torchvision.transforms as T
from model import get_model
import os
class EndpointHandler:
    """Inference handler that classifies a single image.

    On construction it builds a model via ``get_model()`` and loads the
    ``doge_223_sd-04.bin`` state dict onto the CPU. Each call preprocesses one
    image and returns the arg-max index of the model output.
    """

    def __init__(self, path=""):
        """Build the model, load its weights and define preprocessing.

        Args:
            path: Directory containing ``doge_223_sd-04.bin``. The default
                empty string resolves the file relative to the current
                working directory.
        """
        self.model = get_model()
        weights_path = os.path.join(path, "doge_223_sd-04.bin")
        # NOTE(review): torch.load unpickles the file, so only load weights
        # from a trusted source; on torch versions that support it, consider
        # passing weights_only=True.
        state_dict = torch.load(weights_path, map_location="cpu")
        self.model.load_state_dict(state_dict)
        # Switch to eval mode so layers such as dropout / batch-norm (if the
        # model has any) use their inference behaviour.
        self.model.eval()

        # Resize to 224x224, convert to a float tensor, then normalize with the
        # common ImageNet channel mean/std (presumably matching training-time
        # preprocessing -- confirm against the training pipeline).
        self.transform = T.Compose([
            T.Resize((224, 224)),
            T.ToTensor(),
            T.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
        ])

    @staticmethod
    def _to_rgb(inputs):
        """Return *inputs* as an RGB ``PIL.Image.Image``.

        Accepts an already-decoded PIL image, raw encoded image bytes, or
        anything ``PIL.Image.open`` accepts (a file path or binary file-like
        object).
        """
        import io  # local import so the file-level import block is unchanged

        if isinstance(inputs, Image.Image):
            # Already decoded (presumably by the serving layer); just fix mode.
            return inputs.convert("RGB")
        if isinstance(inputs, (bytes, bytearray)):
            # Image.open treats bytes as a filename, so wrap the payload in a
            # binary stream instead.
            inputs = io.BytesIO(inputs)
        # The context manager releases any file handle Image.open opened;
        # convert() loads the pixels and returns an independent image.
        with Image.open(inputs) as img:
            return img.convert("RGB")

    def __call__(self, data):
        """Classify one image and return the predicted class index.

        Args:
            data: Request payload. If it contains an ``"inputs"`` key, that
                value is popped from ``data`` (mutating it) and used as the
                image; otherwise ``data`` itself is treated as the image.

        Returns:
            ``{"label": int}`` -- the index of the largest model output along
            dim 1 for the single image in the batch.
        """
        inputs = data.pop("inputs", data)
        image = self._to_rgb(inputs)
        # Add a leading batch dimension of size 1.
        tensor = self.transform(image).unsqueeze(0)
        with torch.no_grad():
            outputs = self.model(tensor)
        prediction = torch.argmax(outputs, dim=1).item()
        return {"label": prediction}