gonzalocordova commited on
Commit
14cf308
·
1 Parent(s): 9bdb201

fix: predict_fn return bug

Browse files
Files changed (1) hide show
  1. app.py +3 -2
app.py CHANGED
@@ -19,7 +19,7 @@ transform = transforms.Compose([
19
  ])
20
 
21
 
22
- def predict_fn(image, raw_output=False):
23
  """
24
  This function will predict the class of an image
25
  :param image_path: The path of the image
@@ -37,6 +37,7 @@ def predict_fn(image, raw_output=False):
37
  probabilities = torch.exp(output)
38
  top_p, top_class = probabilities.topk(1, dim=1)
39
 
40
- return top_class.numpy()[0][0]
 
41
 
42
  gr.Interface(predict_fn, gr.inputs.Image(type="pil", label="Input Image"), outputs="label").launch()
 
19
  ])
20
 
21
 
22
+ def predict_fn(image):
23
  """
24
  This function will predict the class of an image
25
  :param image_path: The path of the image
 
37
  probabilities = torch.exp(output)
38
  top_p, top_class = probabilities.topk(1, dim=1)
39
 
40
+ # return dictionary whose keys are labels and values are confidences
41
+ return {str(top_class.item()), str(top_p.item())}
42
 
43
  gr.Interface(predict_fn, gr.inputs.Image(type="pil", label="Input Image"), outputs="label").launch()