import torch
import torchvision.models as models
from PIL import Image
from torchvision import transforms
|
|
def _build_classifier(checkpoint: str, num_classes: int) -> torch.nn.Module:
    """Return a ResNet-50 with a ``num_classes``-way head, loaded from ``checkpoint`` in eval mode.

    The checkpoint's tensors are mapped onto the CPU while loading.
    """
    net = models.resnet50()
    # Replace the stock classification head with one sized for this task.
    net.fc = torch.nn.Linear(net.fc.in_features, num_classes)
    # NOTE(review): torch.load unpickles the file, so only load trusted
    # checkpoints (consider weights_only=True if the installed torch supports it).
    state = torch.load(checkpoint, map_location="cpu")
    net.load_state_dict(state)
    # Switch dropout/batch-norm layers to inference behaviour.
    net.eval()
    return net


model_path = "resnet50_c3_lr3e-04_bs32_aug_heavy_opt_adam_drop0.5_ls0.1_6class.pth"
model = _build_classifier(model_path, 6)
|
|
# Emotion class names. Position i names the model's output score at index i,
# which predict() selects with argmax.
labels = [
    "angry",
    "fear",
    "happy",
    "sad",
    "surprise",
    "neutral",
]
|
|
# Fixed spatial size, (height, width), that every input is resized to before inference.
_INPUT_SIZE = (224, 224)

# Preprocessing pipeline: resize to _INPUT_SIZE, then convert to a tensor with
# ToTensor (per the torchvision docs, 8-bit images are scaled to [0, 1]).
# NOTE(review): no transforms.Normalize step is applied; confirm this matches
# the preprocessing used when the checkpoint was trained.
transform = transforms.Compose(
    [
        transforms.Resize(_INPUT_SIZE),
        transforms.ToTensor(),
    ]
)
|
|
def predict(image):
    """Classify the emotion shown in a single image.

    Args:
        image: A PIL image in any mode (e.g. "L", "RGBA", "P"). PIL images
            that are not already RGB are converted to RGB first, because the
            ResNet-50 backbone's first convolution expects 3 input channels.
            Without the conversion, a 1- or 4-channel tensor would fail inside
            ``model(...)``. Non-PIL inputs are passed to ``transform`` unchanged.

    Returns:
        A dict ``{"emotion": label}``, where ``label`` is the entry of
        ``labels`` at the index of the highest model output.

    Raises:
        TypeError: If ``image`` is None.
    """
    if image is None:
        raise TypeError("predict() expected an image, got None")
    if isinstance(image, Image.Image) and image.mode != "RGB":
        image = image.convert("RGB")

    # The model consumes batches; add a leading batch dimension of size 1.
    batch = transform(image).unsqueeze(0)

    # Inference only: skip autograd bookkeeping.
    with torch.no_grad():
        logits = model(batch)
    pred = torch.argmax(logits, dim=1).item()

    return {"emotion": labels[pred]}