Update app.py
Browse files
app.py
CHANGED
|
@@ -40,39 +40,78 @@ def resize_image_pil(image, new_width, new_height):
|
|
| 40 |
|
| 41 |
return resized
|
| 42 |
|
| 43 |
-
def inference(input_img, transparency):
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 44 |
transform = transforms.ToTensor()
|
| 45 |
input_img = transform(input_img)
|
| 46 |
-
|
| 47 |
input_img = input_img.unsqueeze(0)
|
| 48 |
outputs = model(input_img)
|
|
|
|
|
|
|
|
|
|
|
|
|
| 49 |
_, prediction = torch.max(outputs, 1)
|
| 50 |
-
target_layers = [model.layer2[
|
| 51 |
-
cam = GradCAM(model=model, target_layers=target_layers
|
| 52 |
-
grayscale_cam = cam(input_tensor=input_img, targets=
|
| 53 |
grayscale_cam = grayscale_cam[0, :]
|
| 54 |
-
|
| 55 |
-
|
| 56 |
-
|
| 57 |
-
|
| 58 |
-
|
| 59 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 60 |
|
| 61 |
demo = gr.Interface(
|
| 62 |
-
inference,
|
| 63 |
-
inputs
|
| 64 |
gr.Image(width=256, height=256, label="Input Image"),
|
| 65 |
-
gr.Slider(0,
|
| 66 |
-
gr.Slider(-2, -1, value=-2,
|
| 67 |
],
|
| 68 |
outputs = [
|
| 69 |
"text",
|
| 70 |
gr.Image(width=256, height=256, label="Output"),
|
| 71 |
gr.Label(num_top_classes=3)
|
| 72 |
],
|
| 73 |
-
|
| 74 |
-
|
| 75 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 76 |
)
|
| 77 |
|
| 78 |
demo.launch()
|
|
|
|
| 40 |
|
| 41 |
return resized
|
| 42 |
|
| 43 |
+
# def inference(input_img, transparency):
|
| 44 |
+
# transform = transforms.ToTensor()
|
| 45 |
+
# input_img = transform(input_img)
|
| 46 |
+
# input_img = input_img.to(device)
|
| 47 |
+
# input_img = input_img.unsqueeze(0)
|
| 48 |
+
# outputs = model(input_img)
|
| 49 |
+
# _, prediction = torch.max(outputs, 1)
|
| 50 |
+
# target_layers = [model.layer2[-2]]
|
| 51 |
+
# cam = GradCAM(model=model, target_layers=target_layers, use_cuda=True)
|
| 52 |
+
# grayscale_cam = cam(input_tensor=input_img, targets=targets)
|
| 53 |
+
# grayscale_cam = grayscale_cam[0, :]
|
| 54 |
+
# img = input_img.squeeze(0).to('cpu')
|
| 55 |
+
# img = inv_normalize(img)
|
| 56 |
+
# rgb_img = np.transpose(img, (1, 2, 0))
|
| 57 |
+
# rgb_img = rgb_img.numpy()
|
| 58 |
+
# visualization = show_cam_on_image(rgb_img, grayscale_cam, use_rgb=True, image_weight=transparency)
|
| 59 |
+
# return classes[prediction[0].item()], visualization
|
| 60 |
+
|
| 61 |
+
def inference(input_img, transparency=0.5, target_layer_number=-1):
    """Classify an image with the CIFAR10 model and overlay a GradCAM heatmap.

    Args:
        input_img: Input image (PIL image / array from the Gradio widget);
            resized to 32x32 internally to match the model's training input.
        transparency: Weight of the original image in the CAM overlay (0-1).
        target_layer_number: Index into ``model.layer2`` selecting the layer
            GradCAM hooks into (UI slider offers -2 or -1).

    Returns:
        Tuple of (predicted class name, GradCAM visualization image,
        dict mapping each of the 10 class names to its softmax confidence).
    """
    # Model was trained on 32x32 CIFAR10 inputs.
    input_img = resize_image_pil(input_img, 32, 32)
    input_img = np.array(input_img)
    org_img = input_img  # keep the unnormalized pixels for the overlay

    # assumes a 3-channel RGB image — TODO confirm the Gradio widget config
    input_img = input_img.reshape((32, 32, 3))

    transform = transforms.ToTensor()
    input_img = transform(input_img)
    input_img = input_img.unsqueeze(0)  # add batch dimension

    outputs = model(input_img)

    # Per-class probabilities for the top-3 label display.
    softmax = torch.nn.Softmax(dim=0)
    o = softmax(outputs.flatten())
    confidences = {classes[i]: float(o[i]) for i in range(10)}

    _, prediction = torch.max(outputs, 1)

    # GradCAM on the selected layer; targets=None uses the predicted class.
    target_layers = [model.layer2[target_layer_number]]
    cam = GradCAM(model=model, target_layers=target_layers)
    grayscale_cam = cam(input_tensor=input_img, targets=None)
    grayscale_cam = grayscale_cam[0, :]

    visualization = show_cam_on_image(
        org_img / 255,  # show_cam_on_image expects floats in [0, 1]
        grayscale_cam,
        use_rgb=True,
        image_weight=transparency,
    )

    return classes[prediction[0].item()], visualization, confidences
|
| 90 |
+
|
| 91 |
+
|
| 92 |
+
|
| 93 |
|
| 94 |
# Gradio UI: one image + two sliders in; predicted label text, the GradCAM
# overlay image, and the top-3 class probabilities out.
demo = gr.Interface(
    fn=inference,
    inputs=[
        gr.Image(width=256, height=256, label="Input Image"),
        gr.Slider(0, 1, value=0.5, label="Overall opacity value"),
        gr.Slider(-2, -1, value=-2, label="Which model layer to use for GradCAM?"),
    ],
    outputs=[
        "text",
        gr.Image(width=256, height=256, label="Output"),
        gr.Label(num_top_classes=3),
    ],
    title="CIFAR10 trained on ResNet18 with GradCAM",
    description="A simple Gradio interface to infer on ResNet model with GradCAM results shown on top.",
    # Example rows: [image path, opacity, GradCAM layer index].
    examples=[
        ["cat.jpg", 0.5, -1],
        ["dog.jpg", 0.7, -2],
    ],
)

demo.launch()
|