Spaces:
Runtime error
Runtime error
Commit
·
8fcb5ff
1
Parent(s):
5fa7512
Update app.py
Browse files
app.py
CHANGED
|
@@ -52,6 +52,8 @@ def inference(input_img, see_misclassified=False,num_misclassified_imgs=0,see_gr
|
|
| 52 |
softmax = torch.nn.Softmax(dim=0)
|
| 53 |
o = softmax(outputs.flatten())
|
| 54 |
confidences = {classes[i]: float(o[i]) for i in range(10)}
|
|
|
|
|
|
|
| 55 |
_, prediction = torch.max(outputs, 1)
|
| 56 |
|
| 57 |
# gradcam
|
|
@@ -69,8 +71,8 @@ def inference(input_img, see_misclassified=False,num_misclassified_imgs=0,see_gr
|
|
| 69 |
visualization = org_img
|
| 70 |
|
| 71 |
# top n classes only
|
| 72 |
-
|
| 73 |
-
return
|
| 74 |
|
| 75 |
title = "CIFAR10 trained on ResNet18 Model with GradCAM"
|
| 76 |
description = "A simple Gradio interface to infer on ResNet model, and get GradCAM results"
|
|
|
|
| 52 |
softmax = torch.nn.Softmax(dim=0)
|
| 53 |
o = softmax(outputs.flatten())
|
| 54 |
confidences = {classes[i]: float(o[i]) for i in range(10)}
|
| 55 |
+
sorted_confidences = dict(sorted(confidences.items(), key=lambda x:x[1], reverse=True))
|
| 56 |
+
|
| 57 |
_, prediction = torch.max(outputs, 1)
|
| 58 |
|
| 59 |
# gradcam
|
|
|
|
| 71 |
visualization = org_img
|
| 72 |
|
| 73 |
# top n classes only
|
| 74 |
+
sorted_confidences = {k: sorted_confidences[k] for k in list(sorted_confidences)[:top_classes]}
|
| 75 |
+
return sorted_confidences, visualization
|
| 76 |
|
| 77 |
title = "CIFAR10 trained on ResNet18 Model with GradCAM"
|
| 78 |
description = "A simple Gradio interface to infer on ResNet model, and get GradCAM results"
|