mmek commited on
Commit
66fa8be
·
1 Parent(s): 76e7437

add model transforms

Browse files
Files changed (1) hide show
  1. app.py +1 -1
app.py CHANGED
@@ -17,7 +17,7 @@ def classify_health(input_img):
17
  image = data_transforms(input_img).unsqueeze(0)
18
  probs = model(image)
19
  idx = probs.argmax(dim=1)
20
- return dict(zip(categories, map(float, probs)))
21
 
22
 
23
  labels = gr.Label()
 
17
  image = data_transforms(input_img).unsqueeze(0)
18
  probs = model(image)
19
  idx = probs.argmax(dim=1)
20
+ return dict(zip(categories, map(float, probs[0])))
21
 
22
 
23
  labels = gr.Label()