sanjanatule commited on
Commit
0fd9718
·
1 Parent(s): 183bdff

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +221 -0
app.py ADDED
@@ -0,0 +1,221 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ from torchvision import datasets, transforms
3
+ import cv2
4
+ import albumentations as Al
5
+ from albumentations.pytorch import ToTensorV2
6
+ from PIL import Image
7
+ import matplotlib.pyplot as plt
8
+ import matplotlib.patches as patches
9
+ import io
10
+ import numpy as np
11
+ import pandas as pd
12
+ from torch.optim.lr_scheduler import OneCycleLR
13
+ from pytorch_lightning import LightningModule, Trainer, seed_everything
14
+ from pytorch_lightning.callbacks import LearningRateMonitor, ModelCheckpoint
15
+ from pytorch_lightning.callbacks.progress import TQDMProgressBar
16
+ from pytorch_lightning.loggers import CSVLogger
17
+ from pytorch_lightning.loggers import TensorBoardLogger
18
+ from tqdm import tqdm
19
+ import torch
20
+ import torch.optim as optim
21
+
22
+ # my files
23
+ import utils
24
+ import config
25
+ from model import YOLOv3
26
+ from utils import (
27
+ mean_average_precision,
28
+ cells_to_bboxes,
29
+ get_evaluation_bboxes,
30
+ save_checkpoint,
31
+ load_checkpoint,
32
+ check_class_accuracy,
33
+ plot_couple_examples,
34
+ accuracy_fn,
35
+ get_loaders
36
+ )
37
+ from loss import YoloLoss
38
+
39
+
40
+ # custom functions for yolo
41
+ # loss function for yolov3
42
+ loss_fn = YoloLoss()
43
+
44
+ def model_criterion(out, y,anchors):
45
+ loss = ( loss_fn(out[0], y[0], anchors[0])
46
+ + loss_fn(out[1], y[1], anchors[1])
47
+ + loss_fn(out[2], y[2], anchors[2])
48
+ )
49
+ return loss
50
+
51
+
52
+ # accuracy function for yolov3
53
+ def accuracy_fn(y, out, threshold,correct_class, correct_obj,correct_noobj, tot_class_preds,tot_obj, tot_noobj):
54
+
55
+ for i in range(3):
56
+
57
+ obj = y[i][..., 0] == 1 # in paper this is Iobj_i
58
+ noobj = y[i][..., 0] == 0 # in paper this is Iobj_i
59
+
60
+ correct_class += torch.sum(
61
+ torch.argmax(out[i][..., 5:][obj], dim=-1) == y[i][..., 5][obj]
62
+ )
63
+ tot_class_preds += torch.sum(obj)
64
+
65
+ obj_preds = torch.sigmoid(out[i][..., 0]) > threshold
66
+ correct_obj += torch.sum(obj_preds[obj] == y[i][..., 0][obj])
67
+ tot_obj += torch.sum(obj)
68
+ correct_noobj += torch.sum(obj_preds[noobj] == y[i][..., 0][noobj])
69
+ tot_noobj += torch.sum(noobj)
70
+
71
+ return((correct_class/(tot_class_preds+1e-16))*100,
72
+ (correct_noobj/(tot_noobj+1e-16))*100,
73
+ (correct_obj/(tot_obj+1e-16))*100)
74
+
75
+ # pytorch lightning
76
+ class LitYolo(LightningModule):
77
+ def __init__(self, num_classes=config.NUM_CLASSES, lr=1E-3,weight_decay=config.WEIGHT_DECAY,threshold=config.CONF_THRESHOLD):
78
+ super().__init__()
79
+
80
+ self.save_hyperparameters()
81
+ self.model = YOLOv3(num_classes=self.hparams.num_classes)
82
+ self.criterion = model_criterion
83
+ self.accuracy_fn = accuracy_fn
84
+ self.scaled_anchors = (torch.tensor(config.ANCHORS) * torch.tensor(config.S).unsqueeze(1).unsqueeze(1).repeat(1, 3, 2))
85
+ self.tot_class_preds, self.correct_class = 0, 0
86
+ self.tot_noobj, self.correct_noobj = 0, 0
87
+ self.tot_obj, self.correct_obj = 0, 0
88
+
89
+ def forward(self, x):
90
+ out = self.model(x)
91
+ return out
92
+
93
+ def training_step(self, batch, batch_idx):
94
+ x, y = batch
95
+ out = self(x)
96
+ loss = self.criterion(out,y,self.scaled_anchors)
97
+ acc = self.accuracy_fn(y,out,self.hparams.threshold,self.correct_class,
98
+ self.correct_obj,
99
+ self.correct_noobj,
100
+ self.tot_class_preds,
101
+ self.tot_obj,
102
+ self.tot_noobj)
103
+
104
+ self.log('train_loss', loss, prog_bar=True, on_step=False, on_epoch=True)
105
+ self.log_dict({"class_accuracy": acc[0], "no_object_accuracy": acc[1], "object_accuracy":acc[2]},prog_bar=True,on_step=False, on_epoch=True)
106
+ return loss
107
+
108
+
109
+ def evaluate(self, batch, stage=None):
110
+ x, y = batch
111
+ out = self(x)
112
+ loss = self.criterion(out,y,self.scaled_anchors)
113
+ acc = self.accuracy_fn(y,out,self.hparams.threshold,self.correct_class,
114
+ self.correct_obj,
115
+ self.correct_noobj,
116
+ self.tot_class_preds,
117
+ self.tot_obj,
118
+ self.tot_noobj)
119
+
120
+ if stage:
121
+ self.log(f"{stage}_loss", loss, prog_bar=True)
122
+ self.log_dict({"class_accuracy": acc[0], "no_object_accuracy": acc[1], "object_accuracy":acc[2]},prog_bar=True)
123
+
124
+ def test_step(self, batch, batch_idx):
125
+ self.evaluate(batch, "test")
126
+
127
+ def validation_step(self, batch, batch_idx):
128
+ self.evaluate(batch, "val")
129
+
130
+ def configure_optimizers(self):
131
+ optimizer = optim.Adam(self.parameters(), lr=self.hparams.lr, weight_decay=self.hparams.weight_decay)
132
+ scheduler = OneCycleLR(
133
+ optimizer,
134
+ max_lr= 1E-3,
135
+ pct_start = 5/self.trainer.max_epochs,
136
+ epochs=self.trainer.max_epochs,
137
+ steps_per_epoch=len(train_loader),
138
+ div_factor=100,verbose=True,
139
+ three_phase=False
140
+ )
141
+ return ([optimizer],[scheduler])
142
+ yololit = LitYolo()
143
+ inference_model = yololit.load_from_checkpoint("yolo3_model.ckpt")
144
+
145
+ def yolo3_inference(input_img):
146
+
147
+ anchors = (torch.tensor(config.ANCHORS) * torch.tensor(config.S).unsqueeze(1).unsqueeze(1).repeat(1, 3, 2))
148
+ bboxes = [[]]
149
+
150
+ # color of the boxes
151
+ cmap = plt.get_cmap("tab20b")
152
+ class_labels = config.PASCAL_CLASSES
153
+ colors = [cmap(i) for i in np.linspace(0, 1, len(class_labels))]
154
+
155
+
156
+ # image transformation
157
+ test_transforms = Al.Compose(
158
+ [
159
+ Al.LongestMaxSize(max_size=416),
160
+ Al.PadIfNeeded(
161
+ min_height=416, min_width=416, border_mode=cv2.BORDER_CONSTANT
162
+ ),
163
+ Al.Normalize(mean=[0, 0, 0], std=[1, 1, 1], max_pixel_value=255,),
164
+ ToTensorV2(),
165
+ ]
166
+ )
167
+ pr_input_img = test_transforms(image=input_img)
168
+ pr_input_img = pr_input_img['image'].unsqueeze(0)
169
+ test_img_out = inference_model(pr_input_img)
170
+
171
+ # process the outputs
172
+ for i in range(3):
173
+ batch_size, A, S, _, _ = test_img_out[i].shape # 1, anchors = 3, scaling = 13/26/52
174
+ anchor = anchors[i]
175
+ boxes_scale_i = utils.cells_to_bboxes(test_img_out[i], anchor, S=S, is_preds=True)
176
+ for idx, (box) in enumerate(boxes_scale_i):
177
+ bboxes[idx] += box
178
+ # nms
179
+ boxes = utils.non_max_suppression(bboxes[0], iou_threshold=0.6, threshold=0.5, box_format="midpoint",)
180
+
181
+ # create matplotlib plot
182
+ fig, ax = plt.subplots(1)
183
+ # Display the image
184
+ ax.imshow(input_img)
185
+ height, width, _ = input_img.shape
186
+
187
+ # add boxes to the image
188
+ for box in boxes:
189
+ assert len(box) == 6, "box should contain class pred, confidence, x, y, width, height"
190
+ class_pred = box[0]
191
+ box = box[2:]
192
+ upper_left_x = box[0] - box[2] / 2
193
+ upper_left_y = box[1] - box[3] / 2
194
+ rect = patches.Rectangle(
195
+ (upper_left_x * width, upper_left_y * height),
196
+ box[2] * width,
197
+ box[3] * height,
198
+ linewidth=2,
199
+ edgecolor=colors[int(class_pred)],
200
+ facecolor="none",
201
+ )
202
+ # Add the patch to the Axes
203
+ ax.add_patch(rect)
204
+ plt.text(
205
+ upper_left_x * width,
206
+ upper_left_y * height,
207
+ s=class_labels[int(class_pred)],
208
+ color="white",
209
+ verticalalignment="top",
210
+ bbox={"color": colors[int(class_pred)], "pad": 0},
211
+ )
212
+ #plt.show()
213
+ img_buf = io.BytesIO()
214
+ fig.savefig(img_buf, format='png')
215
+ img_buf.seek(0)
216
+ img_arr = np.frombuffer(img_buf.getvalue(), dtype=np.uint8)
217
+ img_buf.close()
218
+ output_img = cv2.imdecode(img_arr, 1)
219
+ output_img = cv2.cvtColor(output_img, cv2.COLOR_BGR2RGB)
220
+
221
+ return output_img