PrarthanaTS commited on
Commit
c7af039
·
1 Parent(s): 02c7167

Upload app.py

Browse files
Files changed (1) hide show
  1. app.py +114 -0
app.py ADDED
@@ -0,0 +1,114 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # -*- coding: utf-8 -*-
2
+ """
3
+ Created on Fri Aug 11 18:08:06 2023
4
+ @author: prarthana.ts
5
+ """
6
+
7
+ import torch
8
+ import torch.optim as optim
9
+ import lightning.pytorch as pl
10
+ from lightning.pytorch.tuner import Tuner
11
+
12
+ # import pytorch_lightning as pl
13
+ from tqdm import tqdm
14
+
15
+ from torch.optim.lr_scheduler import OneCycleLR
16
+ import matplotlib.pyplot as plt
17
+ import matplotlib.patches as patches
18
+ import albumentations as A
19
+ import cv2
20
+ import torch
21
+ from pytorch_grad_cam.utils.image import show_cam_on_image
22
+ import numpy as np
23
+ from albumentations.pytorch import ToTensorV2
24
+
25
+ from utils_for_app import cells_to_bboxes,non_max_suppression,plot_image,YoloCAM
26
+ from yolov3 import YOLOv3
27
+ from loss import YoloLoss
28
+ from utils import LearningRateFinder
29
+ # Create your config module or import it from the existing config.py file.
30
+ import config
31
+ from main_yolov3_lightening import YOLOv3Lightning
32
+ import torch
33
+ import cv2
34
+ import numpy as np
35
+ import gradio as gr
36
+
37
+ model = YOLOv3Lightning()
38
+ model.load_state_dict(torch.load("yolov3_model_without_75_mosaic.pth", map_location=torch.device('cpu')), strict=False)
39
+ model.setup(stage="test")
40
+
41
+ IMAGE_SIZE = 416
42
+ transforms = A.Compose(
43
+ [
44
+ A.LongestMaxSize(max_size=IMAGE_SIZE),
45
+ A.PadIfNeeded(
46
+ min_height=IMAGE_SIZE, min_width=IMAGE_SIZE, border_mode=cv2.BORDER_CONSTANT
47
+ ),
48
+ A.Normalize(mean=[0, 0, 0], std=[1, 1, 1], max_pixel_value=255,),
49
+ ToTensorV2(),
50
+ ],
51
+ )
52
+
53
+ ANCHORS = [
54
+ [(0.28, 0.22), (0.38, 0.48), (0.9, 0.78)],
55
+ [(0.07, 0.15), (0.15, 0.11), (0.14, 0.29)],
56
+ [(0.02, 0.03), (0.04, 0.07), (0.08, 0.06)],
57
+ ] # Note these have been rescaled to be between [0, 1]
58
+ S = [IMAGE_SIZE // 32, IMAGE_SIZE // 16, IMAGE_SIZE // 8]
59
+
60
+ scaled_anchors = (
61
+ torch.tensor(config.ANCHORS)
62
+ * torch.tensor(config.S).unsqueeze(1).unsqueeze(1).repeat(1, 3, 2)
63
+ )
64
+
65
+ def process_image_and_plot(image, model, scaled_anchors):
66
+
67
+ transformed_image = transforms(image=image)["image"].unsqueeze(0)
68
+ output = model(transformed_image)
69
+ bboxes = [[] for _ in range(1)]
70
+
71
+ for i in range(3):
72
+ batch_size, A, S, _, _ = output[i].shape
73
+ anchor = scaled_anchors[i]
74
+ boxes_scale_i = cells_to_bboxes(output[i], anchor, S=S, is_preds=True)
75
+ for idx, box in enumerate(boxes_scale_i):
76
+ bboxes[idx] += box
77
+
78
+ nms_boxes = non_max_suppression(
79
+ bboxes[0], iou_threshold=0.5, threshold=0.4, box_format="midpoint",
80
+ )
81
+ fig = plot_image(transformed_image[0].permute(1, 2, 0), nms_boxes)
82
+
83
+ cam = YoloCAM(model=model, target_layers=[model.model.layers[-2]], use_cuda=False)
84
+ grayscale_cam = cam(transformed_image, scaled_anchors)[0, :, :]
85
+ img = cv2.resize(image, (416, 416))
86
+ img = np.float32(img) / 255
87
+ cam_image = show_cam_on_image(img, grayscale_cam, use_rgb=True)
88
+
89
+ return fig,cam_image
90
+
91
+
92
+ examples = [
93
+ ["/content/images/automobile.jpg"],
94
+ ["/content/images/cycle.jpg"],
95
+ ["/content/images/dog-kitten.jpg"],
96
+ ["/content/images/human.jpg"],
97
+ ]
98
+
99
+ def processed_image(image):
100
+ figure,gradcam = process_image_and_plot(image, model, scaled_anchors)
101
+ return figure,gradcam
102
+
103
+ title = "YoloV3 on Pascal VOC Dataset with GradCAM"
104
+ description = "Pytorch Lightening Implemetation of YoloV3 trained from scratch"
105
+ demo = gr.Interface(processed_image,
106
+ inputs=[
107
+ gr.Image(label="Input Image"),
108
+ ],
109
+ outputs=[gr.Plot(),gr.Image(shape=(32, 32), label="Model Prediction")],
110
+ title=title,
111
+ description=description,
112
+ examples=examples,
113
+ )
114
+ demo.launch()