viserion999 commited on
Commit
2e50f2e
·
verified ·
1 Parent(s): 1e0192e

Create inference.py

Browse files
Files changed (1) hide show
  1. inference.py +25 -0
inference.py ADDED
@@ -0,0 +1,25 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torchvision.transforms as transforms
3
+ from PIL import Image
4
+ import pickle
5
+
6
+ # load model
7
+ model = pickle.load(open("resnet50_c3_lr3e-04_bs32_aug_heavy_opt_adam_dr.pkl", "rb"))
8
+ model.eval()
9
+
10
+ transform = transforms.Compose([
11
+ transforms.Resize((224,224)),
12
+ transforms.ToTensor()
13
+ ])
14
+
15
+ labels = ["angry","happy","sad","fear","surprise","neutral"]
16
+
17
+
18
+ def predict(image):
19
+ image = transform(image).unsqueeze(0)
20
+
21
+ with torch.no_grad():
22
+ output = model(image)
23
+ pred = torch.argmax(output,1).item()
24
+
25
+ return {"emotion": labels[pred]}