Spaces:
Runtime error
Runtime error
Commit
·
9a9229d
1
Parent(s):
dba3c56
Update app.py
Browse files
app.py
CHANGED
|
@@ -15,7 +15,6 @@ import gradio as gr
|
|
| 15 |
from PIL import Image
|
| 16 |
from pytorch_grad_cam import GradCAM
|
| 17 |
from pytorch_grad_cam.utils.image import show_cam_on_image
|
| 18 |
-
from models import custom_resnet
|
| 19 |
import gradio as gr
|
| 20 |
from pytorch_lightning import LightningModule, Trainer, seed_everything
|
| 21 |
from pytorch_lightning.callbacks import LearningRateMonitor
|
|
@@ -23,60 +22,9 @@ from pytorch_lightning.callbacks.progress import TQDMProgressBar
|
|
| 23 |
from pytorch_lightning.loggers import CSVLogger
|
| 24 |
from pytorch_lightning.loggers import TensorBoardLogger
|
| 25 |
from torchmetrics import Accuracy
|
|
|
|
|
|
|
| 26 |
|
| 27 |
-
class LitResnet(LightningModule):
|
| 28 |
-
def __init__(self, num_classes=10, lr=0.05):
|
| 29 |
-
super().__init__()
|
| 30 |
-
|
| 31 |
-
self.save_hyperparameters()
|
| 32 |
-
self.model = custom_resnet.Net()
|
| 33 |
-
self.criterion = nn.CrossEntropyLoss()
|
| 34 |
-
self.BATCH_SIZE = 512
|
| 35 |
-
self.torchmetrics_accuracy = Accuracy(task="multiclass", num_classes= self.hparams.num_classes)
|
| 36 |
-
|
| 37 |
-
def forward(self, x):
|
| 38 |
-
out = self.model(x)
|
| 39 |
-
return out
|
| 40 |
-
|
| 41 |
-
def training_step(self, batch, batch_idx):
|
| 42 |
-
x, y = batch
|
| 43 |
-
y_pred = self(x)
|
| 44 |
-
loss = self.criterion(y_pred, y)
|
| 45 |
-
acc = self.torchmetrics_accuracy(y_pred, y)
|
| 46 |
-
|
| 47 |
-
self.log('train_loss', loss, prog_bar=True, on_step=False, on_epoch=True)
|
| 48 |
-
self.log('train_acc', acc, prog_bar=True, on_step=False, on_epoch=True)
|
| 49 |
-
return loss
|
| 50 |
-
|
| 51 |
-
|
| 52 |
-
def evaluate(self, batch, stage=None):
|
| 53 |
-
x, y = batch
|
| 54 |
-
y_test_pred = self(x)
|
| 55 |
-
loss = self.criterion(y_test_pred, y)
|
| 56 |
-
acc = self.torchmetrics_accuracy(y_test_pred, y)
|
| 57 |
-
|
| 58 |
-
if stage:
|
| 59 |
-
self.log(f"{stage}_loss", loss, prog_bar=True)
|
| 60 |
-
self.log(f"{stage}_acc", acc, prog_bar=True)
|
| 61 |
-
|
| 62 |
-
def test_step(self, batch, batch_idx):
|
| 63 |
-
self.evaluate(batch, "test")
|
| 64 |
-
|
| 65 |
-
def validation_step(self, batch, batch_idx):
|
| 66 |
-
self.evaluate(batch, "val")
|
| 67 |
-
|
| 68 |
-
def configure_optimizers(self):
|
| 69 |
-
optimizer = optim.Adam(self.parameters(), lr=self.hparams.lr, weight_decay=1e-4)
|
| 70 |
-
scheduler = OneCycleLR(
|
| 71 |
-
optimizer,
|
| 72 |
-
max_lr= 5.38E-02, #self.hparams.lr,
|
| 73 |
-
pct_start = 5/self.trainer.max_epochs,
|
| 74 |
-
epochs=self.trainer.max_epochs,
|
| 75 |
-
steps_per_epoch=len(train_loader),
|
| 76 |
-
div_factor=100,verbose=False,
|
| 77 |
-
three_phase=False
|
| 78 |
-
)
|
| 79 |
-
return ([optimizer],[scheduler])
|
| 80 |
|
| 81 |
inference_model = LitResnet.load_from_checkpoint("cifar10_customresnet_20_epoch.ckpt")
|
| 82 |
|
|
@@ -112,10 +60,10 @@ def inference(input_img, see_misclassified=False,num_misclassified_imgs=0,see_gr
|
|
| 112 |
cam = GradCAM(model=inference_model.model, target_layers=target_layers, use_cuda=False)
|
| 113 |
grayscale_cam = cam(input_tensor=input_img, targets=None)
|
| 114 |
grayscale_cam = grayscale_cam[0, :]
|
| 115 |
-
img = input_img.squeeze(0)
|
| 116 |
-
img = inv_normalize(img)
|
| 117 |
-
rgb_img = np.transpose(img, (1, 2, 0))
|
| 118 |
-
rgb_img = rgb_img.numpy()
|
| 119 |
visualization = show_cam_on_image(org_img/255.0, grayscale_cam, use_rgb=True, image_weight=transparency)
|
| 120 |
else:
|
| 121 |
visualization = org_img
|
|
@@ -126,7 +74,6 @@ def inference(input_img, see_misclassified=False,num_misclassified_imgs=0,see_gr
|
|
| 126 |
|
| 127 |
title = "CIFAR10 trained on ResNet18 Model with GradCAM"
|
| 128 |
description = "A simple Gradio interface to infer on ResNet model, and get GradCAM results"
|
| 129 |
-
#examples = [["img_eg_0.jpg", False,0,False,0.5, -1,3], ["img_eg_1.jpg", False,0,False,0.5, -1,3],["img_eg_2.jpg", False,0,False,0.5, -1,3],["img_eg_3.jpg", False,0,False,0.5, -1,3],["img_eg_4.jpg", False,0,False,0.5, -1,3],["img_eg_5.jpg", False,0,False,0.5, -1,3],["img_eg_6.jpg", False,0,False,0.5, -1,3],["img_eg_7.jpg", False,0,False,0.5, -1,3],["img_eg_8.jpg", False,0,False,0.5, -1,3]]
|
| 130 |
examples = [["img_eg_0.jpg"], ["img_eg_1.jpg"],["img_eg_2.jpg"],["img_eg_3.jpg"],["img_eg_4.jpg"],["img_eg_5.jpg"],["img_eg_6.jpg"],["img_eg_7.jpg"],["img_eg_8.jpg"],["img_eg_9.jpg"]]
|
| 131 |
|
| 132 |
|
|
|
|
| 15 |
from PIL import Image
|
| 16 |
from pytorch_grad_cam import GradCAM
|
| 17 |
from pytorch_grad_cam.utils.image import show_cam_on_image
|
|
|
|
| 18 |
import gradio as gr
|
| 19 |
from pytorch_lightning import LightningModule, Trainer, seed_everything
|
| 20 |
from pytorch_lightning.callbacks import LearningRateMonitor
|
|
|
|
| 22 |
from pytorch_lightning.loggers import CSVLogger
|
| 23 |
from pytorch_lightning.loggers import TensorBoardLogger
|
| 24 |
from torchmetrics import Accuracy
|
| 25 |
+
from models import custom_resnet
|
| 26 |
+
from network import LitResnet
|
| 27 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 28 |
|
| 29 |
inference_model = LitResnet.load_from_checkpoint("cifar10_customresnet_20_epoch.ckpt")
|
| 30 |
|
|
|
|
| 60 |
cam = GradCAM(model=inference_model.model, target_layers=target_layers, use_cuda=False)
|
| 61 |
grayscale_cam = cam(input_tensor=input_img, targets=None)
|
| 62 |
grayscale_cam = grayscale_cam[0, :]
|
| 63 |
+
# img = input_img.squeeze(0)
|
| 64 |
+
# img = inv_normalize(img)
|
| 65 |
+
# rgb_img = np.transpose(img, (1, 2, 0))
|
| 66 |
+
# rgb_img = rgb_img.numpy()
|
| 67 |
visualization = show_cam_on_image(org_img/255.0, grayscale_cam, use_rgb=True, image_weight=transparency)
|
| 68 |
else:
|
| 69 |
visualization = org_img
|
|
|
|
| 74 |
|
| 75 |
title = "CIFAR10 trained on ResNet18 Model with GradCAM"
|
| 76 |
description = "A simple Gradio interface to infer on ResNet model, and get GradCAM results"
|
|
|
|
| 77 |
examples = [["img_eg_0.jpg"], ["img_eg_1.jpg"],["img_eg_2.jpg"],["img_eg_3.jpg"],["img_eg_4.jpg"],["img_eg_5.jpg"],["img_eg_6.jpg"],["img_eg_7.jpg"],["img_eg_8.jpg"],["img_eg_9.jpg"]]
|
| 78 |
|
| 79 |
|