Spaces:
Sleeping
Sleeping
Update app.py
Browse files
app.py
CHANGED
|
@@ -58,7 +58,6 @@ tfm = transforms.Compose([
|
|
| 58 |
),
|
| 59 |
])
|
| 60 |
|
| 61 |
-
|
| 62 |
@torch.no_grad()
|
| 63 |
def predict(image: Image.Image):
|
| 64 |
if image is None:
|
|
@@ -68,11 +67,20 @@ def predict(image: Image.Image):
|
|
| 68 |
x = tfm(image).unsqueeze(0).to(DEVICE)
|
| 69 |
|
| 70 |
log_sp, log_st = model(x)
|
| 71 |
-
sp_id = int(log_sp.argmax(dim=1))
|
| 72 |
-
st_id = int(log_st.argmax(dim=1))
|
| 73 |
|
| 74 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 75 |
|
|
|
|
| 76 |
|
| 77 |
demo = gr.Interface(
|
| 78 |
fn=predict,
|
|
|
|
| 58 |
),
|
| 59 |
])
|
| 60 |
|
|
|
|
| 61 |
@torch.no_grad()
|
| 62 |
def predict(image: Image.Image):
|
| 63 |
if image is None:
|
|
|
|
| 67 |
x = tfm(image).unsqueeze(0).to(DEVICE)
|
| 68 |
|
| 69 |
log_sp, log_st = model(x)
|
|
|
|
|
|
|
| 70 |
|
| 71 |
+
prob_sp = torch.softmax(log_sp, dim=1)[0] # [num_species]
|
| 72 |
+
prob_st = torch.softmax(log_st, dim=1)[0] # [num_states]
|
| 73 |
+
|
| 74 |
+
sp_id = int(prob_sp.argmax().item())
|
| 75 |
+
st_id = int(prob_st.argmax().item())
|
| 76 |
+
|
| 77 |
+
sp_conf = float(prob_sp[sp_id].item())
|
| 78 |
+
st_conf = float(prob_st[st_id].item())
|
| 79 |
+
|
| 80 |
+
sp_text = f"{SPECIES[sp_id]} (id={sp_id}, conf={sp_conf:.3f})"
|
| 81 |
+
st_text = f"{STATE[st_id]} (id={st_id}, conf={st_conf:.3f})"
|
| 82 |
|
| 83 |
+
return sp_text, st_text
|
| 84 |
|
| 85 |
demo = gr.Interface(
|
| 86 |
fn=predict,
|