Spaces:
Runtime error
Runtime error
Commit
·
f12f9bc
1
Parent(s):
d3ca6f9
Update app.py
Browse files
app.py
CHANGED
|
@@ -66,16 +66,16 @@ inference_model = LitResnet.load_from_checkpoint("cifar10_customresnet_20_epoch.
|
|
| 66 |
|
| 67 |
def inference(input_img, see_misclassified,num_misclassified_imgs,see_gradcam,num_gradcam_imgs,transparency = 0.85, target_layer_number = -1,top_classes=3):
|
| 68 |
|
| 69 |
-
if see_misclassified: # show misclassified images
|
| 70 |
-
|
| 71 |
-
|
| 72 |
|
| 73 |
-
elif num_gradcam_imgs > 0: # show gradcam on example images
|
| 74 |
-
|
| 75 |
-
|
| 76 |
|
| 77 |
-
else: # nothing chosen - misclassified or gradcam
|
| 78 |
-
|
| 79 |
|
| 80 |
# model inference
|
| 81 |
transform = transforms.ToTensor()
|
|
@@ -109,12 +109,14 @@ def inference(input_img, see_misclassified,num_misclassified_imgs,see_gradcam,nu
|
|
| 109 |
|
| 110 |
title = "CIFAR10 trained on ResNet18 Model with GradCAM"
|
| 111 |
description = "A simple Gradio interface to infer on ResNet model, and get GradCAM results"
|
|
|
|
| 112 |
|
| 113 |
demo = gr.Interface(
|
| 114 |
inference,
|
| 115 |
inputs = [gr.Image(shape=(32, 32), label="Input Image"), gr.Checkbox(label="Misclassified"),gr.Slider(0, 10, value = 0, step=1,label="Total Misclassified Images"),gr.Checkbox(label="Gradcam"),gr.Slider(0, 10, value = 0, step=1,label="Total GradCam Images"),gr.Slider(0, 1, value = 0.5, label="Opacity of GradCAM"), gr.Slider(-2, -1, value = -1, step=1, label="Which Layer?"), gr.Slider(1, 10, value=3, step=1, label="How many top classes?")],
|
| 116 |
outputs = [gr.Label(), gr.Image(shape=(32, 32), label="Output").style(width=128, height=128)],
|
| 117 |
title = title,
|
| 118 |
-
description = description,
|
|
|
|
| 119 |
|
| 120 |
demo.launch()
|
|
|
|
| 66 |
|
| 67 |
def inference(input_img, see_misclassified,num_misclassified_imgs,see_gradcam,num_gradcam_imgs,transparency = 0.85, target_layer_number = -1,top_classes=3):
|
| 68 |
|
| 69 |
+
# if see_misclassified: # show misclassified images
|
| 70 |
+
# org_img = cv2.imread('/content/drive/MyDrive/AI/ERA_course/session12/example_images/img_eg_0.jpg')
|
| 71 |
+
# input_img = org_img
|
| 72 |
|
| 73 |
+
# elif num_gradcam_imgs > 0: # show gradcam on example images
|
| 74 |
+
# org_img = cv2.imread('/content/drive/MyDrive/AI/ERA_course/session12/example_images/img_eg_0.jpg')
|
| 75 |
+
# input_img = org_img
|
| 76 |
|
| 77 |
+
# else: # nothing chosen - misclassified or gradcam
|
| 78 |
+
# org_img = input_img
|
| 79 |
|
| 80 |
# model inference
|
| 81 |
transform = transforms.ToTensor()
|
|
|
|
| 109 |
|
| 110 |
title = "CIFAR10 trained on ResNet18 Model with GradCAM"
|
| 111 |
description = "A simple Gradio interface to infer on ResNet model, and get GradCAM results"
|
| 112 |
+
examples = [["img_eg_0.jpg", False,0,False,0.5, -1,3], ["img_eg_1.jpg", False,0,False,0.5, -1,3],["img_eg_2.jpg", False,0,False,0.5, -1,3],["img_eg_3.jpg", False,0,False,0.5, -1,3],["img_eg_4.jpg", False,0,False,0.5, -1,3],["img_eg_5.jpg", False,0,False,0.5, -1,3],["img_eg_6.jpg", False,0,False,0.5, -1,3],["img_eg_7.jpg", False,0,False,0.5, -1,3],["img_eg_8.jpg", False,0,False,0.5, -1,3]]
|
| 113 |
|
| 114 |
demo = gr.Interface(
|
| 115 |
inference,
|
| 116 |
inputs = [gr.Image(shape=(32, 32), label="Input Image"), gr.Checkbox(label="Misclassified"),gr.Slider(0, 10, value = 0, step=1,label="Total Misclassified Images"),gr.Checkbox(label="Gradcam"),gr.Slider(0, 10, value = 0, step=1,label="Total GradCam Images"),gr.Slider(0, 1, value = 0.5, label="Opacity of GradCAM"), gr.Slider(-2, -1, value = -1, step=1, label="Which Layer?"), gr.Slider(1, 10, value=3, step=1, label="How many top classes?")],
|
| 117 |
outputs = [gr.Label(), gr.Image(shape=(32, 32), label="Output").style(width=128, height=128)],
|
| 118 |
title = title,
|
| 119 |
+
description = description,
|
| 120 |
+
examples = examples)
|
| 121 |
|
| 122 |
demo.launch()
|