sanjanatule commited on
Commit
fc4b125
·
1 Parent(s): ef2a693

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +120 -0
app.py ADDED
@@ -0,0 +1,120 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch, torchvision
2
+ from torchvision import transforms
3
+ import numpy as np
4
+ import gradio as gr
5
+ from PIL import Image
6
+ from pytorch_grad_cam import GradCAM
7
+ from pytorch_grad_cam.utils.image import show_cam_on_image
8
+ from resnet import ResNet18
9
+ import gradio as gr
10
+
11
+ class LitResnet(LightningModule):
12
+ def __init__(self, num_classes=10, lr=0.05):
13
+ super().__init__()
14
+
15
+ self.save_hyperparameters()
16
+ self.model = custom_resnet.Net()
17
+ self.criterion = nn.CrossEntropyLoss()
18
+ self.BATCH_SIZE = 512
19
+ self.torchmetrics_accuracy = Accuracy(task="multiclass", num_classes= self.hparams.num_classes)
20
+
21
+ def forward(self, x):
22
+ out = self.model(x)
23
+ return out
24
+
25
+ def training_step(self, batch, batch_idx):
26
+ x, y = batch
27
+ y_pred = self(x)
28
+ loss = self.criterion(y_pred, y)
29
+ acc = self.torchmetrics_accuracy(y_pred, y)
30
+
31
+ self.log('train_loss', loss, prog_bar=True, on_step=False, on_epoch=True)
32
+ self.log('train_acc', acc, prog_bar=True, on_step=False, on_epoch=True)
33
+ return loss
34
+
35
+
36
+ def evaluate(self, batch, stage=None):
37
+ x, y = batch
38
+ y_test_pred = self(x)
39
+ loss = self.criterion(y_test_pred, y)
40
+ acc = self.torchmetrics_accuracy(y_test_pred, y)
41
+
42
+ if stage:
43
+ self.log(f"{stage}_loss", loss, prog_bar=True)
44
+ self.log(f"{stage}_acc", acc, prog_bar=True)
45
+
46
+ def test_step(self, batch, batch_idx):
47
+ self.evaluate(batch, "test")
48
+
49
+ def validation_step(self, batch, batch_idx):
50
+ self.evaluate(batch, "val")
51
+
52
+ def configure_optimizers(self):
53
+ optimizer = optim.Adam(self.parameters(), lr=self.hparams.lr, weight_decay=1e-4)
54
+ scheduler = OneCycleLR(
55
+ optimizer,
56
+ max_lr= 5.38E-02, #self.hparams.lr,
57
+ pct_start = 5/self.trainer.max_epochs,
58
+ epochs=self.trainer.max_epochs,
59
+ steps_per_epoch=len(train_loader),
60
+ div_factor=100,verbose=False,
61
+ three_phase=False
62
+ )
63
+ return ([optimizer],[scheduler])
64
+
65
+ inference_model = LitResnet.load_from_checkpoint("cifar10_customresnet_20_epoch.ckpt")
66
+
67
+ def inference(input_img, see_misclassified,num_misclassified_imgs,see_gradcam,num_gradcam_imgs,transparency = 0.85, target_layer_number = -1,top_classes=3):
68
+
69
+ if see_misclassified: # show misclassified images
70
+ org_img = cv2.imread('/content/drive/MyDrive/AI/ERA_course/session12/example_images/img_eg_0.jpg')
71
+ input_img = org_img
72
+
73
+ elif num_gradcam_imgs > 0: # show gradcam on example images
74
+ org_img = cv2.imread('/content/drive/MyDrive/AI/ERA_course/session12/example_images/img_eg_0.jpg')
75
+ input_img = org_img
76
+
77
+ else: # nothing chosen - misclassified or gradcam
78
+ org_img = input_img
79
+
80
+ # model inference
81
+ transform = transforms.ToTensor()
82
+ input_img = transform(input_img)
83
+ input_img = input_img.unsqueeze(0)
84
+ outputs = inference_model.model(input_img)
85
+ softmax = torch.nn.Softmax(dim=0)
86
+ o = softmax(outputs.flatten())
87
+ confidences = {classes[i]: float(o[i]) for i in range(10)}
88
+ _, prediction = torch.max(outputs, 1)
89
+
90
+ # gradcam
91
+ if see_gradcam:
92
+ target_layers = [inference_model.model.layer2[target_layer_number]]
93
+ cam = GradCAM(model=inference_model.model, target_layers=target_layers, use_cuda=False)
94
+ grayscale_cam = cam(input_tensor=input_img, targets=None)
95
+ grayscale_cam = grayscale_cam[0, :]
96
+ img = input_img.squeeze(0)
97
+ img = inv_normalize(img)
98
+ rgb_img = np.transpose(img, (1, 2, 0))
99
+ rgb_img = rgb_img.numpy()
100
+ visualization = show_cam_on_image(org_img/255.0, grayscale_cam, use_rgb=True, image_weight=transparency)
101
+ plt.imshow(visualization)
102
+ else:
103
+ plt.imshow(org_img)
104
+ visualization = org_img
105
+
106
+ # top n classes only
107
+ confidences = {k: confidences[k] for k in list(confidences)[:top_classes]}
108
+ return confidences, visualization
109
+
110
+ title = "CIFAR10 trained on ResNet18 Model with GradCAM"
111
+ description = "A simple Gradio interface to infer on ResNet model, and get GradCAM results"
112
+
113
+ demo = gr.Interface(
114
+ inference,
115
+ inputs = [gr.Image(shape=(32, 32), label="Input Image"), gr.Checkbox(label="Misclassified"),gr.Slider(0, 10, value = 0, step=1,label="Total Misclassified Images"),gr.Checkbox(label="Gradcam"),gr.Slider(0, 10, value = 0, step=1,label="Total GradCam Images"),gr.Slider(0, 1, value = 0.5, label="Opacity of GradCAM"), gr.Slider(-2, -1, value = -1, step=1, label="Which Layer?"), gr.Slider(1, 10, value=3, step=1, label="How many top classes?")],
116
+ outputs = [gr.Label(), gr.Image(shape=(32, 32), label="Output").style(width=128, height=128)],
117
+ title = title,
118
+ description = description,)
119
+
120
+ demo.launch()