File size: 709 Bytes
2e50f2e
857d62b
 
2e50f2e
 
857d62b
 
 
 
2e50f2e
 
857d62b
 
2e50f2e
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
import torch
import torchvision.models as models
from torchvision import transforms
from PIL import Image

# Path to the fine-tuned checkpoint (a state_dict), resolved relative to the
# current working directory.
model_path = "resnet50_c3_lr3e-04_bs32_aug_heavy_opt_adam_drop0.5_ls0.1_6class.pth"

# Rebuild the ResNet-50 architecture without pretrained weights, replace its
# default classification head with a 6-way linear layer (one output per class
# name in `labels`), then load the saved parameters onto the CPU.
# NOTE(review): torch.load unpickles the file, so only load trusted checkpoints;
# on torch >= 1.13 consider passing weights_only=True.
model = models.resnet50()
model.fc = torch.nn.Linear(in_features=model.fc.in_features, out_features=6)
model.load_state_dict(torch.load(model_path, map_location="cpu"))

# Switch to inference mode: dropout disabled, batch-norm uses running stats.
model.eval()

# Class names indexed by the model's output position (argmax index -> name).
# NOTE(review): this order must match the label indices used at training
# time — confirm against the training script.
labels = [
    "angry",
    "fear",
    "happy",
    "sad",
    "surprise",
    "neutral",
]

# Preprocessing applied by predict(): resize any input to a fixed 224x224,
# then convert the PIL image to a float CHW tensor scaled to [0, 1].
# NOTE(review): no transforms.Normalize step is applied; this presumably
# mirrors the training-time preprocessing — confirm before changing it.
transform = transforms.Compose(
    [
        transforms.Resize(size=(224, 224)),
        transforms.ToTensor(),
    ]
)

def predict(image):
    """Classify the facial emotion shown in *image*.

    Args:
        image: A PIL image, or an array accepted by ``PIL.Image.fromarray``
            (e.g. an ``H x W x 3`` uint8 array). Images in any mode
            (RGBA, L, P, ...) are converted to RGB before preprocessing.

    Returns:
        dict: ``{"emotion": name}`` where ``name`` is the entry of ``labels``
        at the index of the highest model output.

    Raises:
        TypeError: If ``image`` is ``None``.
    """
    if image is None:
        raise TypeError("predict() requires an image, got None")

    # Array inputs cannot go through transforms.Resize directly; wrap them.
    if not isinstance(image, Image.Image):
        image = Image.fromarray(image)

    # RGBA / grayscale / palette images would become 4- or 1-channel tensors,
    # which ResNet-50's 3-channel first convolution rejects.
    if image.mode != "RGB":
        image = image.convert("RGB")

    # Add a leading batch dimension of size 1.
    batch = transform(image).unsqueeze(0)

    with torch.no_grad():
        logits = model(batch)

    pred = int(torch.argmax(logits, dim=1).item())
    return {"emotion": labels[pred]}