sanjanatule commited on
Commit
6bc5e46
·
1 Parent(s): 6c8bc0d

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +72 -84
app.py CHANGED
@@ -1,11 +1,10 @@
 
1
  import gradio as gr
2
  from torchvision import datasets, transforms
3
- import cv2
4
  import albumentations as Al
5
  from albumentations.pytorch import ToTensorV2
6
  from PIL import Image
7
  import matplotlib.pyplot as plt
8
- import matplotlib.patches as patches
9
  import io
10
  import numpy as np
11
  import pandas as pd
@@ -13,8 +12,7 @@ from torch.optim.lr_scheduler import OneCycleLR
13
  from pytorch_lightning import LightningModule, Trainer, seed_everything
14
  from pytorch_lightning.callbacks import LearningRateMonitor, ModelCheckpoint
15
  from pytorch_lightning.callbacks.progress import TQDMProgressBar
16
- from pytorch_lightning.loggers import CSVLogger
17
- from pytorch_lightning.loggers import TensorBoardLogger
18
  from tqdm import tqdm
19
  import torch
20
  import torch.optim as optim
@@ -139,83 +137,73 @@ class LitYolo(LightningModule):
139
  three_phase=False
140
  )
141
  return ([optimizer],[scheduler])
142
- yololit = LitYolo()
143
- inference_model = yololit.load_from_checkpoint("yolo3_model.ckpt")
144
-
145
- def yolo3_inference(input_img):
146
-
147
- anchors = (torch.tensor(config.ANCHORS) * torch.tensor(config.S).unsqueeze(1).unsqueeze(1).repeat(1, 3, 2))
148
- bboxes = [[]]
149
-
150
- # color of the boxes
151
- cmap = plt.get_cmap("tab20b")
152
- class_labels = config.PASCAL_CLASSES
153
- colors = [cmap(i) for i in np.linspace(0, 1, len(class_labels))]
154
-
155
-
156
- # image transformation
157
- test_transforms = Al.Compose(
158
- [
159
- Al.LongestMaxSize(max_size=416),
160
- Al.PadIfNeeded(
161
- min_height=416, min_width=416, border_mode=cv2.BORDER_CONSTANT
162
- ),
163
- Al.Normalize(mean=[0, 0, 0], std=[1, 1, 1], max_pixel_value=255,),
164
- ToTensorV2(),
165
- ]
166
- )
167
- pr_input_img = test_transforms(image=input_img)
168
- pr_input_img = pr_input_img['image'].unsqueeze(0)
169
- test_img_out = inference_model(pr_input_img)
170
-
171
- # process the outputs
172
- for i in range(3):
173
- batch_size, A, S, _, _ = test_img_out[i].shape # 1, anchors = 3, scaling = 13/26/52
174
- anchor = anchors[i]
175
- boxes_scale_i = utils.cells_to_bboxes(test_img_out[i], anchor, S=S, is_preds=True)
176
- for idx, (box) in enumerate(boxes_scale_i):
177
- bboxes[idx] += box
178
- # nms
179
- boxes = utils.non_max_suppression(bboxes[0], iou_threshold=0.6, threshold=0.5, box_format="midpoint",)
180
-
181
- # create matplotlib plot
182
- fig, ax = plt.subplots(1)
183
- # Display the image
184
- ax.imshow(input_img)
185
- height, width, _ = input_img.shape
186
-
187
- # add boxes to the image
188
- for box in boxes:
189
- assert len(box) == 6, "box should contain class pred, confidence, x, y, width, height"
190
- class_pred = box[0]
191
- box = box[2:]
192
- upper_left_x = box[0] - box[2] / 2
193
- upper_left_y = box[1] - box[3] / 2
194
- rect = patches.Rectangle(
195
- (upper_left_x * width, upper_left_y * height),
196
- box[2] * width,
197
- box[3] * height,
198
- linewidth=2,
199
- edgecolor=colors[int(class_pred)],
200
- facecolor="none",
201
- )
202
- # Add the patch to the Axes
203
- ax.add_patch(rect)
204
- plt.text(
205
- upper_left_x * width,
206
- upper_left_y * height,
207
- s=class_labels[int(class_pred)],
208
- color="white",
209
- verticalalignment="top",
210
- bbox={"color": colors[int(class_pred)], "pad": 0},
211
- )
212
- #plt.show()
213
- img_buf = io.BytesIO()
214
- fig.savefig(img_buf, format='png')
215
- img_buf.seek(0)
216
- img_arr = np.frombuffer(img_buf.getvalue(), dtype=np.uint8)
217
- img_buf.close()
218
- output_img = cv2.imdecode(img_arr, 1)
219
- output_img = cv2.cvtColor(output_img, cv2.COLOR_BGR2RGB)
220
-
221
- return output_img
 
1
+
2
  import gradio as gr
3
  from torchvision import datasets, transforms
 
4
  import albumentations as Al
5
  from albumentations.pytorch import ToTensorV2
6
  from PIL import Image
7
  import matplotlib.pyplot as plt
 
8
  import io
9
  import numpy as np
10
  import pandas as pd
 
12
  from pytorch_lightning import LightningModule, Trainer, seed_everything
13
  from pytorch_lightning.callbacks import LearningRateMonitor, ModelCheckpoint
14
  from pytorch_lightning.callbacks.progress import TQDMProgressBar
15
+ from pytorch_lightning.loggers import CSVLogger,TensorBoardLogger
 
16
  from tqdm import tqdm
17
  import torch
18
  import torch.optim as optim
 
137
  three_phase=False
138
  )
139
  return ([optimizer],[scheduler])
140
+
141
+
142
+
143
+
144
+ # gradio
145
+ with gr.Blocks() as demo:
146
+
147
+ # colors for the bboxes
148
+ cmap = plt.get_cmap("tab20b")
149
+ class_labels = config.PASCAL_CLASSES
150
+ colors = [cmap(i) for i in np.linspace(0, 1, len(class_labels))]
151
+ colors_hex = {class_labels[i]:matplotlib.colors.rgb2hex(colors[i]) for i in range(0,len(class_labels))}
152
+
153
+ # app GUI
154
+ with gr.Row():
155
+ img_input = gr.Image()
156
+ img_output = gr.AnnotatedImage().style(color_map = colors_hex)
157
+ section_btn = gr.Button("Identify Objects")
158
+
159
+ def yolo3_inference(input_img): # function for yolo inference
160
+
161
+ yololit = LitYolo()
162
+ inference_model = yololit.load_from_checkpoint("yolo3_model.ckpt")
163
+ anchors = (torch.tensor(config.ANCHORS) * torch.tensor(config.S).unsqueeze(1).unsqueeze(1).repeat(1, 3, 2))
164
+ bboxes = [[]]
165
+ sections = [] # to return image and annotations
166
+
167
+ # image transformation
168
+ test_transforms = Al.Compose(
169
+ [
170
+ Al.LongestMaxSize(max_size=416),
171
+ Al.PadIfNeeded(
172
+ min_height=416, min_width=416, border_mode=cv2.BORDER_CONSTANT
173
+ ),
174
+ Al.Normalize(mean=[0, 0, 0], std=[1, 1, 1], max_pixel_value=255,),
175
+ ToTensorV2(),
176
+ ]
177
+ )
178
+ pr_input_img = test_transforms(image=input_img)
179
+ pr_input_img = pr_input_img['image'].unsqueeze(0)
180
+ # infer the image
181
+ inference_model.eval()
182
+ test_img_out = inference_model(pr_input_img)
183
+
184
+ # process the outputs to create bounding boxes
185
+ for i in range(3):
186
+ batch_size, A, S, _, _ = test_img_out[i].shape # 1, anchors = 3, scaling = 13/26/52
187
+ anchor = anchors[i]
188
+ boxes_scale_i = utils_org.cells_to_bboxes(test_img_out[i], anchor, S=S, is_preds=True)
189
+ for idx, (box) in enumerate(boxes_scale_i):
190
+ bboxes[idx] += box
191
+ # nms
192
+ nms_boxes = utils_org.non_max_suppression(bboxes[0], iou_threshold=0.6, threshold=0.5, box_format="midpoint",)
193
+
194
+ # use gradio image annotations
195
+ height, width = 416, 416
196
+ for box in nms_boxes:
197
+ class_pred = box[0]
198
+ box = box[2:]
199
+ upper_left_x = int((box[0] - box[2] / 2) * width)
200
+ upper_left_y = max(int((box[1] - box[3] / 2) * height),0) # less than 0, box collapses
201
+ lower_right_x = int(upper_left_x + (box[2] * width))
202
+ lower_right_y = int(upper_left_y + (box[3] * height))
203
+ sections.append(((upper_left_x,upper_left_y,lower_right_x,lower_right_y), class_labels[int(class_pred)]))
204
+ return (np.array(pr_input_img.squeeze(0).permute(1,2,0)),sections)
205
+
206
+ section_btn.click(yolo3_inference, img_input, img_output)
207
+
208
+ if __name__ == "__main__":
209
+ demo.launch()