sanjanatule commited on
Commit
9a9229d
·
1 Parent(s): dba3c56

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +6 -59
app.py CHANGED
@@ -15,7 +15,6 @@ import gradio as gr
15
  from PIL import Image
16
  from pytorch_grad_cam import GradCAM
17
  from pytorch_grad_cam.utils.image import show_cam_on_image
18
- from models import custom_resnet
19
  import gradio as gr
20
  from pytorch_lightning import LightningModule, Trainer, seed_everything
21
  from pytorch_lightning.callbacks import LearningRateMonitor
@@ -23,60 +22,9 @@ from pytorch_lightning.callbacks.progress import TQDMProgressBar
23
  from pytorch_lightning.loggers import CSVLogger
24
  from pytorch_lightning.loggers import TensorBoardLogger
25
  from torchmetrics import Accuracy
 
 
26
 
27
- class LitResnet(LightningModule):
28
- def __init__(self, num_classes=10, lr=0.05):
29
- super().__init__()
30
-
31
- self.save_hyperparameters()
32
- self.model = custom_resnet.Net()
33
- self.criterion = nn.CrossEntropyLoss()
34
- self.BATCH_SIZE = 512
35
- self.torchmetrics_accuracy = Accuracy(task="multiclass", num_classes= self.hparams.num_classes)
36
-
37
- def forward(self, x):
38
- out = self.model(x)
39
- return out
40
-
41
- def training_step(self, batch, batch_idx):
42
- x, y = batch
43
- y_pred = self(x)
44
- loss = self.criterion(y_pred, y)
45
- acc = self.torchmetrics_accuracy(y_pred, y)
46
-
47
- self.log('train_loss', loss, prog_bar=True, on_step=False, on_epoch=True)
48
- self.log('train_acc', acc, prog_bar=True, on_step=False, on_epoch=True)
49
- return loss
50
-
51
-
52
- def evaluate(self, batch, stage=None):
53
- x, y = batch
54
- y_test_pred = self(x)
55
- loss = self.criterion(y_test_pred, y)
56
- acc = self.torchmetrics_accuracy(y_test_pred, y)
57
-
58
- if stage:
59
- self.log(f"{stage}_loss", loss, prog_bar=True)
60
- self.log(f"{stage}_acc", acc, prog_bar=True)
61
-
62
- def test_step(self, batch, batch_idx):
63
- self.evaluate(batch, "test")
64
-
65
- def validation_step(self, batch, batch_idx):
66
- self.evaluate(batch, "val")
67
-
68
- def configure_optimizers(self):
69
- optimizer = optim.Adam(self.parameters(), lr=self.hparams.lr, weight_decay=1e-4)
70
- scheduler = OneCycleLR(
71
- optimizer,
72
- max_lr= 5.38E-02, #self.hparams.lr,
73
- pct_start = 5/self.trainer.max_epochs,
74
- epochs=self.trainer.max_epochs,
75
- steps_per_epoch=len(train_loader),
76
- div_factor=100,verbose=False,
77
- three_phase=False
78
- )
79
- return ([optimizer],[scheduler])
80
 
81
  inference_model = LitResnet.load_from_checkpoint("cifar10_customresnet_20_epoch.ckpt")
82
 
@@ -112,10 +60,10 @@ def inference(input_img, see_misclassified=False,num_misclassified_imgs=0,see_gr
112
  cam = GradCAM(model=inference_model.model, target_layers=target_layers, use_cuda=False)
113
  grayscale_cam = cam(input_tensor=input_img, targets=None)
114
  grayscale_cam = grayscale_cam[0, :]
115
- img = input_img.squeeze(0)
116
- img = inv_normalize(img)
117
- rgb_img = np.transpose(img, (1, 2, 0))
118
- rgb_img = rgb_img.numpy()
119
  visualization = show_cam_on_image(org_img/255.0, grayscale_cam, use_rgb=True, image_weight=transparency)
120
  else:
121
  visualization = org_img
@@ -126,7 +74,6 @@ def inference(input_img, see_misclassified=False,num_misclassified_imgs=0,see_gr
126
 
127
  title = "CIFAR10 trained on ResNet18 Model with GradCAM"
128
  description = "A simple Gradio interface to infer on ResNet model, and get GradCAM results"
129
- #examples = [["img_eg_0.jpg", False,0,False,0.5, -1,3], ["img_eg_1.jpg", False,0,False,0.5, -1,3],["img_eg_2.jpg", False,0,False,0.5, -1,3],["img_eg_3.jpg", False,0,False,0.5, -1,3],["img_eg_4.jpg", False,0,False,0.5, -1,3],["img_eg_5.jpg", False,0,False,0.5, -1,3],["img_eg_6.jpg", False,0,False,0.5, -1,3],["img_eg_7.jpg", False,0,False,0.5, -1,3],["img_eg_8.jpg", False,0,False,0.5, -1,3]]
130
  examples = [["img_eg_0.jpg"], ["img_eg_1.jpg"],["img_eg_2.jpg"],["img_eg_3.jpg"],["img_eg_4.jpg"],["img_eg_5.jpg"],["img_eg_6.jpg"],["img_eg_7.jpg"],["img_eg_8.jpg"],["img_eg_9.jpg"]]
131
 
132
 
 
15
  from PIL import Image
16
  from pytorch_grad_cam import GradCAM
17
  from pytorch_grad_cam.utils.image import show_cam_on_image
 
18
  import gradio as gr
19
  from pytorch_lightning import LightningModule, Trainer, seed_everything
20
  from pytorch_lightning.callbacks import LearningRateMonitor
 
22
  from pytorch_lightning.loggers import CSVLogger
23
  from pytorch_lightning.loggers import TensorBoardLogger
24
  from torchmetrics import Accuracy
25
+ from models import custom_resnet
26
+ from network import LitResnet
27
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
28
 
29
  inference_model = LitResnet.load_from_checkpoint("cifar10_customresnet_20_epoch.ckpt")
30
 
 
60
  cam = GradCAM(model=inference_model.model, target_layers=target_layers, use_cuda=False)
61
  grayscale_cam = cam(input_tensor=input_img, targets=None)
62
  grayscale_cam = grayscale_cam[0, :]
63
+ # img = input_img.squeeze(0)
64
+ # img = inv_normalize(img)
65
+ # rgb_img = np.transpose(img, (1, 2, 0))
66
+ # rgb_img = rgb_img.numpy()
67
  visualization = show_cam_on_image(org_img/255.0, grayscale_cam, use_rgb=True, image_weight=transparency)
68
  else:
69
  visualization = org_img
 
74
 
75
  title = "CIFAR10 trained on ResNet18 Model with GradCAM"
76
  description = "A simple Gradio interface to infer on ResNet model, and get GradCAM results"
 
77
  examples = [["img_eg_0.jpg"], ["img_eg_1.jpg"],["img_eg_2.jpg"],["img_eg_3.jpg"],["img_eg_4.jpg"],["img_eg_5.jpg"],["img_eg_6.jpg"],["img_eg_7.jpg"],["img_eg_8.jpg"],["img_eg_9.jpg"]]
78
 
79