chenchangliu commited on
Commit
0fc44d7
·
verified ·
1 Parent(s): 68fb97b

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +12 -4
app.py CHANGED
@@ -58,7 +58,6 @@ tfm = transforms.Compose([
58
  ),
59
  ])
60
 
61
-
62
  @torch.no_grad()
63
  def predict(image: Image.Image):
64
  if image is None:
@@ -68,11 +67,20 @@ def predict(image: Image.Image):
68
  x = tfm(image).unsqueeze(0).to(DEVICE)
69
 
70
  log_sp, log_st = model(x)
71
- sp_id = int(log_sp.argmax(dim=1))
72
- st_id = int(log_st.argmax(dim=1))
73
 
74
- return SPECIES[sp_id], STATE[st_id]
 
 
 
 
 
 
 
 
 
 
75
 
 
76
 
77
  demo = gr.Interface(
78
  fn=predict,
 
58
  ),
59
  ])
60
 
 
61
  @torch.no_grad()
62
  def predict(image: Image.Image):
63
  if image is None:
 
67
  x = tfm(image).unsqueeze(0).to(DEVICE)
68
 
69
  log_sp, log_st = model(x)
 
 
70
 
71
+ prob_sp = torch.softmax(log_sp, dim=1)[0] # [num_species]
72
+ prob_st = torch.softmax(log_st, dim=1)[0] # [num_states]
73
+
74
+ sp_id = int(prob_sp.argmax().item())
75
+ st_id = int(prob_st.argmax().item())
76
+
77
+ sp_conf = float(prob_sp[sp_id].item())
78
+ st_conf = float(prob_st[st_id].item())
79
+
80
+ sp_text = f"{SPECIES[sp_id]} (id={sp_id}, conf={sp_conf:.3f})"
81
+ st_text = f"{STATE[st_id]} (id={st_id}, conf={st_conf:.3f})"
82
 
83
+ return sp_text, st_text
84
 
85
  demo = gr.Interface(
86
  fn=predict,