sanjanatule commited on
Commit
8fcb5ff
·
1 Parent(s): 5fa7512

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +4 -2
app.py CHANGED
@@ -52,6 +52,8 @@ def inference(input_img, see_misclassified=False,num_misclassified_imgs=0,see_gr
52
  softmax = torch.nn.Softmax(dim=0)
53
  o = softmax(outputs.flatten())
54
  confidences = {classes[i]: float(o[i]) for i in range(10)}
 
 
55
  _, prediction = torch.max(outputs, 1)
56
 
57
  # gradcam
@@ -69,8 +71,8 @@ def inference(input_img, see_misclassified=False,num_misclassified_imgs=0,see_gr
69
  visualization = org_img
70
 
71
  # top n classes only
72
- confidences = {k: confidences[k] for k in list(confidences)[:top_classes]}
73
- return confidences, visualization
74
 
75
  title = "CIFAR10 trained on ResNet18 Model with GradCAM"
76
  description = "A simple Gradio interface to infer on ResNet model, and get GradCAM results"
 
52
  softmax = torch.nn.Softmax(dim=0)
53
  o = softmax(outputs.flatten())
54
  confidences = {classes[i]: float(o[i]) for i in range(10)}
55
+ sorted_confidences = dict(sorted(confidences.items(), key=lambda x:x[1], reverse=True))
56
+
57
  _, prediction = torch.max(outputs, 1)
58
 
59
  # gradcam
 
71
  visualization = org_img
72
 
73
  # top n classes only
74
+ sorted_confidences = {k: sorted_confidences[k] for k in list(sorted_confidences)[:top_classes]}
75
+ return sorted_confidences, visualization
76
 
77
  title = "CIFAR10 trained on ResNet18 Model with GradCAM"
78
  description = "A simple Gradio interface to infer on ResNet model, and get GradCAM results"