sadjava commited on
Commit
94e4de7
·
1 Parent(s): 9b34fa8

fixed for cpu

Browse files
Files changed (1) hide show
  1. app.py +6 -6
app.py CHANGED
@@ -8,11 +8,12 @@ import gradio as gr
8
  import torch
9
  from torch.nn.functional import softmax
10
  import numpy as np
 
11
 
12
  # %% ../app.ipynb 3
13
  device = "cuda" if torch.cuda.is_available() else "cpu"
14
 
15
- model = torch.load('model.pth').to(device)
16
  model.eval()
17
 
18
  # %% ../app.ipynb 4
@@ -20,11 +21,10 @@ CLASS_LABELS = ['Anger', 'Disgust', 'Fear', 'Happy', 'Neutral', 'Sadness', "Sur
20
 
21
  # %% ../app.ipynb 5
22
  def classify_emotions(im):
23
- im = np.array(im) / 255
24
- if len(im.shape) == 2:
25
- im = im[..., np.newaxis]
26
- if im.shape[-1] == 1:
27
- im = np.concatenate((im, im, im), 2)
28
  im = torch.tensor(im.transpose(2, 0, 1), dtype=torch.float32)
29
  prediction = model.forward(im[np.newaxis, ...].to(device))
30
  return dict(zip(CLASS_LABELS, *softmax(prediction, dim=1).tolist()))
 
8
  import torch
9
  from torch.nn.functional import softmax
10
  import numpy as np
11
+ from PIL import Image
12
 
13
  # %% ../app.ipynb 3
14
  device = "cuda" if torch.cuda.is_available() else "cpu"
15
 
16
+ model = torch.load('model.pth', map_location=torch.device('cpu')).to(device)
17
  model.eval()
18
 
19
  # %% ../app.ipynb 4
 
21
 
22
  # %% ../app.ipynb 5
23
  def classify_emotions(im):
24
+ im = np.array(im)
25
+ im = np.array(Image.fromarray(im).convert('L')) / 255
26
+ im = im[..., np.newaxis]
27
+ im = np.concatenate((im, im, im), 2)
 
28
  im = torch.tensor(im.transpose(2, 0, 1), dtype=torch.float32)
29
  prediction = model.forward(im[np.newaxis, ...].to(device))
30
  return dict(zip(CLASS_LABELS, *softmax(prediction, dim=1).tolist()))