Update app.py
Browse files
app.py
CHANGED
|
@@ -207,11 +207,13 @@ transform = transforms.Compose([
|
|
| 207 |
])
|
| 208 |
|
| 209 |
def predict(img: Image.Image):
|
| 210 |
-
|
| 211 |
with torch.no_grad():
|
| 212 |
-
out = model(
|
| 213 |
-
|
| 214 |
-
|
|
|
|
|
|
|
| 215 |
|
| 216 |
# ------------------------
|
| 217 |
# Gradio interface
|
|
|
|
| 207 |
])
|
| 208 |
|
| 209 |
def predict(img: Image.Image):
|
| 210 |
+
img_t = transform(img).unsqueeze(0).to(device)
|
| 211 |
with torch.no_grad():
|
| 212 |
+
out = model(img_t)
|
| 213 |
+
probs = torch.softmax(out, dim=1)[0]
|
| 214 |
+
top5 = probs.topk(5)
|
| 215 |
+
result = {classes[i]: float(probs[i]) for i in top5.indices}
|
| 216 |
+
return result
|
| 217 |
|
| 218 |
# ------------------------
|
| 219 |
# Gradio interface
|