"""Emotion classification inference with a fine-tuned ResNet-50.

Importing this module builds a ResNet-50 with a 6-way classification head,
loads its weights from ``model_path`` onto the CPU and puts the network in
evaluation mode. ``predict`` then classifies a single PIL image.
"""

import torch
import torchvision.models as models
from torchvision import transforms
from PIL import Image

# Architecture: stock ResNet-50 whose final fully connected layer is replaced
# by a 6-output linear layer so it matches the checkpoint's head.
model = models.resnet50()
model.fc = torch.nn.Linear(model.fc.in_features, 6)

model_path = "resnet50_c3_lr3e-04_bs32_aug_heavy_opt_adam_drop0.5_ls0.1_6class.pth"

# NOTE(security): torch.load unpickles the file, so only load checkpoints from
# a trusted source. On PyTorch versions that support it, consider passing
# ``weights_only=True``.
model.load_state_dict(torch.load(model_path, map_location="cpu"))

# Evaluation mode: dropout is disabled and batch-norm uses running statistics.
model.eval()

# Index i of the model output maps to labels[i].
# NOTE(review): this order must match the class indices used at training time
# -- confirm against the training code.
labels = ["angry", "fear", "happy", "sad", "surprise", "neutral"]

# Preprocessing: resize to 224x224 and convert to a float tensor.
# NOTE(review): there is no transforms.Normalize step here; confirm that the
# model was trained on un-normalized inputs as well.
transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
])


def predict(image: Image.Image) -> dict:
    """Classify the emotion shown in a single image.

    Args:
        image: A PIL image in any mode. Images that are not already RGB
            (grayscale, palette, RGBA, ...) are converted to RGB first,
            because the network's first convolution expects 3 channels.

    Returns:
        A dict ``{"emotion": <label>}`` where the label is the entry of
        ``labels`` with the highest model score.
    """
    # Guarantee a 3-channel input; without this, 1- or 4-channel images
    # fail inside the model with a channel-mismatch error.
    if image.mode != "RGB":
        image = image.convert("RGB")

    # Add a leading batch dimension of size 1.
    batch = transform(image).unsqueeze(0)

    # Inference only: skip gradient bookkeeping.
    with torch.no_grad():
        output = model(batch)

    # Batch size is 1, so argmax over dim 1 yields a single class index.
    pred = torch.argmax(output, 1).item()
    return {"emotion": labels[pred]}