mmek commited on
Commit
6ecf40e
·
1 Parent(s): 623c32d

add model transforms

Browse files
Files changed (1) hide show
  1. app.py +5 -4
app.py CHANGED
@@ -12,10 +12,11 @@ categories = ("Aculus Olearius", "Healthy", "Peacock Spot")
12
 
13
 
14
  def classify_health(input_img):
15
- input_img = torch.from_numpy(input_img)
16
- image = transforms(input_img).unsqueeze(0)
17
- probs = model(image)
18
- idx = probs.argmax()
 
19
  return dict(zip(categories, map(float, probs)))
20
 
21
 
 
12
 
13
 
14
  def classify_health(input_img):
15
+ input_img = transforms.ToTensor()(input_img)
16
+ with torch.no_grad():
17
+ image = transforms(input_img).unsqueeze(0)
18
+ probs = model(image)
19
+ idx = probs.argmax()
20
  return dict(zip(categories, map(float, probs)))
21
 
22