Aumkeshchy2003 commited on
Commit
cafcdec
·
verified ·
1 Parent(s): 961131f

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +6 -4
app.py CHANGED
@@ -207,11 +207,13 @@ transform = transforms.Compose([
207
  ])
208
 
209
  def predict(img: Image.Image):
210
- img = transform(img).unsqueeze(0).to(device)
211
  with torch.no_grad():
212
- out = model(img)
213
- pred = out.argmax(1).item()
214
- return classes[pred]
 
 
215
 
216
  # ------------------------
217
  # Gradio interface
 
207
  ])
208
 
209
  def predict(img: Image.Image):
210
+ img_t = transform(img).unsqueeze(0).to(device)
211
  with torch.no_grad():
212
+ out = model(img_t)
213
+ probs = torch.softmax(out, dim=1)[0]
214
+ top5 = probs.topk(5)
215
+ result = {classes[i]: float(probs[i]) for i in top5.indices}
216
+ return result
217
 
218
  # ------------------------
219
  # Gradio interface