import io
import os

import torch
import torchvision.transforms as T
from PIL import Image

from model import get_model


class EndpointHandler:
    """Inference handler that loads the classifier once and classifies one image per call."""

    def __init__(self, path=""):
        """Build the model, load its weights from *path*, and set up preprocessing.

        Args:
            path: Directory containing ``doge_223_sd-04.bin``. Defaults to the
                current working directory.
        """
        self.model = get_model()
        weights_path = os.path.join(path, "doge_223_sd-04.bin")
        # NOTE(review): torch.load unpickles the file, which is only safe because
        # the weights ship with the repository. Consider weights_only=True if the
        # deployed torch version supports it (added in torch 1.13).
        state_dict = torch.load(weights_path, map_location="cpu")
        self.model.load_state_dict(state_dict)
        # Inference mode for layers such as dropout / batch-norm.
        self.model.eval()

        # Resize to a fixed 224x224 input, convert to a tensor, and normalise
        # with the standard ImageNet channel mean/std.
        self.transform = T.Compose([
            T.Resize((224, 224)),
            T.ToTensor(),
            T.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
        ])

    @staticmethod
    def _to_image(inputs):
        """Return *inputs* decoded as an RGB PIL image.

        Accepts an already-decoded ``PIL.Image.Image``, raw encoded image bytes,
        or anything ``Image.open`` accepts (a filesystem path or a binary
        file-like object).
        """
        if isinstance(inputs, Image.Image):
            image = inputs
        elif isinstance(inputs, (bytes, bytearray, memoryview)):
            # Image.open needs a path or a file object, not a bytes value.
            image = Image.open(io.BytesIO(inputs))
        else:
            image = Image.open(inputs)
        return image.convert("RGB")

    def __call__(self, data):
        """Classify a single image.

        Args:
            data: Either a dict whose ``"inputs"`` entry holds the image, or the
                image payload itself. Supported payloads: a PIL image, encoded
                bytes, a path, or a binary file-like object. Note that the
                ``"inputs"`` key is popped from a dict argument.

        Returns:
            ``{"label": int}``, where the value is the index of the
            highest-scoring model output.
        """
        # Payloads may arrive wrapped in a dict or unwrapped; only dicts have .pop.
        if isinstance(data, dict):
            inputs = data.pop("inputs", data)
        else:
            inputs = data

        image = self._to_image(inputs)
        # Add a leading batch dimension of size 1.
        tensor = self.transform(image).unsqueeze(0)

        with torch.no_grad():
            outputs = self.model(tensor)

        # .item() assumes a single-image batch, which is what is built above.
        prediction = torch.argmax(outputs, dim=1).item()
        return {"label": prediction}