viserion999 commited on
Commit
857d62b
·
verified ·
1 Parent(s): 2e50f2e

Update inference.py

Browse files
Files changed (1) hide show
  1. inference.py +8 -7
inference.py CHANGED
@@ -1,20 +1,21 @@
1
  import torch
2
- import torchvision.transforms as transforms
 
3
  from PIL import Image
4
- import pickle
5
 
6
- # load model
7
- model = pickle.load(open("resnet50_c3_lr3e-04_bs32_aug_heavy_opt_adam_dr.pkl", "rb"))
 
 
8
  model.eval()
9
 
 
 
10
  transform = transforms.Compose([
11
  transforms.Resize((224,224)),
12
  transforms.ToTensor()
13
  ])
14
 
15
- labels = ["angry","happy","sad","fear","surprise","neutral"]
16
-
17
-
18
  def predict(image):
19
  image = transform(image).unsqueeze(0)
20
 
 
1
  import torch
2
+ import torchvision.models as models
3
+ from torchvision import transforms
4
  from PIL import Image
 
5
 
6
+ model = models.resnet50()
7
+ model.fc = torch.nn.Linear(model.fc.in_features, 6)
8
+ model_path = "resnet50_c3_lr3e-04_bs32_aug_heavy_opt_adam_drop0.5_ls0.1_6class.pth"
9
+ model.load_state_dict(torch.load(model_path, map_location="cpu"))
10
  model.eval()
11
 
12
+ labels = ["angry","fear","happy","sad","surprise","neutral"]
13
+
14
  transform = transforms.Compose([
15
  transforms.Resize((224,224)),
16
  transforms.ToTensor()
17
  ])
18
 
 
 
 
19
  def predict(image):
20
  image = transform(image).unsqueeze(0)
21