from PIL import Image
import torch
import torchvision.transforms as transforms
import torchvision.models as models

# Load model
# The classification head is built with 3 outputs. load_state_dict() is called
# with its default arguments, so the checkpoint has to match this architecture.
# NOTE(review): torch.load() deserializes the file with pickle unless
# weights_only is requested -- only load checkpoints from a trusted source.
model = models.resnet18(pretrained=False, num_classes=3)
model.load_state_dict(
    torch.load("pytorch_model.bin", map_location=torch.device("cpu"))
)
model.eval()  # inference mode: freezes batch-norm statistics

# Preprocessing function
# Resize to 224x224, convert to a tensor, then normalize with per-channel
# mean/std. The statistics have three entries, so the input must be 3-channel.
transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
])

# Labels
# NOTE(review): seven labels are listed but the model above is built with
# num_classes=3. Confirm against the checkpoint which of the two is correct
# and align them; until then predict() only reports the labels that have a
# matching model output.
labels = ["A", "B", "C", "D", "E", "F", "G"]


# Required function
def predict(image: Image.Image):
    """Classify *image* and return a mapping of label to probability.

    Args:
        image: A PIL image in any mode; it is converted to RGB before
            preprocessing so grayscale, palette and RGBA inputs work with
            the three-channel normalization.

    Returns:
        A dict mapping each label to its softmax probability as a float.
        Labels are paired positionally with the model outputs, so only as
        many labels as the model has outputs appear in the result.
    """
    # Normalize() is configured for exactly three channels.
    rgb_image = image.convert("RGB")
    # Add a leading batch dimension: the model is called on a batch of one.
    img_tensor = transform(rgb_image).unsqueeze(0)
    with torch.no_grad():
        outputs = model(img_tensor)
        probs = torch.nn.functional.softmax(outputs[0], dim=0)
    # zip() stops at the shorter sequence, which avoids the IndexError that
    # indexing probs by range(len(labels)) raised when labels outnumber the
    # model outputs.
    return {label: float(prob) for label, prob in zip(labels, probs)}