erlinersi committed on
Commit
e19ae65
·
1 Parent(s): e730503

Upload 3 files

Files changed (3)
  1. utils/general.py +748 -0
  2. utils/plots.py +89 -0
  3. utils/torch_utils.py +321 -0
utils/general.py ADDED
@@ -0,0 +1,748 @@
import glob
import logging
import math
import random
import re
import time
from pathlib import Path

import cv2
import numpy as np
import torch
import torch.nn.functional as F
import torchvision
from matplotlib import pyplot as plt

id_dict = {'person': 0, 'rider': 1, 'car': 2, 'bus': 3, 'truck': 4,
           'bike': 5, 'motor': 6, 'tl_green': 7, 'tl_red': 8,
           'tl_yellow': 9, 'tl_none': 10, 'traffic sign': 11, 'train': 12}
id_dict_single = {'car': 0, 'bus': 1, 'truck': 2, 'train': 3}


def clean_str(s):
    # Cleans a string by replacing special characters with underscore _
    return re.sub(pattern="[|@#!¡·$€%&()=?¿^*;:,¨´><+]", repl="_", string=s)


def set_logging(rank=-1):
    logging.basicConfig(
        format="%(message)s",
        level=logging.INFO if rank in [-1, 0] else logging.WARN)


def convert(size, box):
    dw = 1. / (size[0])
    dh = 1. / (size[1])
    x = (box[0] + box[1]) / 2.0
    y = (box[2] + box[3]) / 2.0
    w = box[1] - box[0]
    h = box[3] - box[2]
    x = x * dw
    w = w * dw
    y = y * dh
    h = h * dh
    return (x, y, w, h)


def xywh2xyxy(x):
    # Convert nx4 boxes from [x, y, w, h] to [x1, y1, x2, y2] where xy1=top-left, xy2=bottom-right
    y = x.clone() if isinstance(x, torch.Tensor) else np.copy(x)
    y[:, 0] = x[:, 0] - x[:, 2] / 2  # top left x
    y[:, 1] = x[:, 1] - x[:, 3] / 2  # top left y
    y[:, 2] = x[:, 0] + x[:, 2] / 2  # bottom right x
    y[:, 3] = x[:, 1] + x[:, 3] / 2  # bottom right y
    return y


def xyxy2xywh(x):
    # Convert nx4 boxes from [x1, y1, x2, y2] to [x, y, w, h] where xy1=top-left, xy2=bottom-right
    y = x.clone() if isinstance(x, torch.Tensor) else np.copy(x)
    y[:, 0] = (x[:, 0] + x[:, 2]) / 2  # x center
    y[:, 1] = (x[:, 1] + x[:, 3]) / 2  # y center
    y[:, 2] = x[:, 2] - x[:, 0]  # width
    y[:, 3] = x[:, 3] - x[:, 1]  # height
    return y
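
# Illustrative round-trip (usage sketch, not from the uploaded files):
# xywh2xyxy and xyxy2xywh are exact inverses for float boxes, assuming an
# nx4 array as documented above.
#   boxes = np.array([[50., 50., 20., 10.]])       # one box in [x, y, w, h]
#   corners = xywh2xyxy(boxes)                     # -> [[40., 45., 60., 55.]]
#   assert np.allclose(xyxy2xywh(corners), boxes)  # inverse transform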


def bbox_iou(box1, box2, x1y1x2y2=True, GIoU=False, DIoU=False, CIoU=False, SIoU=False, WIoU=False, eps=1e-7):
    # Returns the IoU of box1 to box2. box1 is 4, box2 is nx4
    box2 = box2.T

    # Get the coordinates of bounding boxes
    if x1y1x2y2:  # x1, y1, x2, y2 = box1
        b1_x1, b1_y1, b1_x2, b1_y2 = box1[0], box1[1], box1[2], box1[3]
        b2_x1, b2_y1, b2_x2, b2_y2 = box2[0], box2[1], box2[2], box2[3]
    else:  # transform from xywh to xyxy
        b1_x1, b1_x2 = box1[0] - box1[2] / 2, box1[0] + box1[2] / 2
        b1_y1, b1_y2 = box1[1] - box1[3] / 2, box1[1] + box1[3] / 2
        b2_x1, b2_x2 = box2[0] - box2[2] / 2, box2[0] + box2[2] / 2
        b2_y1, b2_y2 = box2[1] - box2[3] / 2, box2[1] + box2[3] / 2

    # Intersection area
    inter = (torch.min(b1_x2, b2_x2) - torch.max(b1_x1, b2_x1)).clamp(0) * \
            (torch.min(b1_y2, b2_y2) - torch.max(b1_y1, b2_y1)).clamp(0)

    # Union area
    w1, h1 = b1_x2 - b1_x1, b1_y2 - b1_y1 + eps
    w2, h2 = b2_x2 - b2_x1, b2_y2 - b2_y1 + eps
    union = w1 * h1 + w2 * h2 - inter + eps

    iou = inter / union
    if GIoU or DIoU or CIoU or SIoU:
        cw = torch.max(b1_x2, b2_x2) - torch.min(b1_x1, b2_x1)  # convex (smallest enclosing box) width
        ch = torch.max(b1_y2, b2_y2) - torch.min(b1_y1, b2_y1)  # convex height
        if SIoU:  # SIoU Loss https://arxiv.org/pdf/2205.12740.pdf
            s_cw = (b2_x1 + b2_x2 - b1_x1 - b1_x2) * 0.5
            s_ch = (b2_y1 + b2_y2 - b1_y1 - b1_y2) * 0.5
            sigma = torch.pow(s_cw ** 2 + s_ch ** 2, 0.5)
            sin_alpha_1 = torch.abs(s_cw) / sigma
            sin_alpha_2 = torch.abs(s_ch) / sigma
            threshold = pow(2, 0.5) / 2
            sin_alpha = torch.where(sin_alpha_1 > threshold, sin_alpha_2, sin_alpha_1)
            # angle_cost = 1 - 2 * torch.pow(torch.sin(torch.arcsin(sin_alpha) - np.pi / 4), 2)
            angle_cost = torch.cos(torch.arcsin(sin_alpha) * 2 - np.pi / 2)
            rho_x = (s_cw / cw) ** 2
            rho_y = (s_ch / ch) ** 2
            gamma = angle_cost - 2
            distance_cost = 2 - torch.exp(gamma * rho_x) - torch.exp(gamma * rho_y)
            omiga_w = torch.abs(w1 - w2) / torch.max(w1, w2)
            omiga_h = torch.abs(h1 - h2) / torch.max(h1, h2)
            shape_cost = torch.pow(1 - torch.exp(-1 * omiga_w), 4) + torch.pow(1 - torch.exp(-1 * omiga_h), 4)
            return iou - 0.5 * (distance_cost + shape_cost)
        if CIoU or DIoU:  # Distance or Complete IoU https://arxiv.org/abs/1911.08287v1
            c2 = cw ** 2 + ch ** 2 + eps  # convex diagonal squared
            rho2 = ((b2_x1 + b2_x2 - b1_x1 - b1_x2) ** 2 +
                    (b2_y1 + b2_y2 - b1_y1 - b1_y2) ** 2) / 4  # center distance squared
            if DIoU:
                return iou - rho2 / c2  # DIoU
            elif CIoU:  # https://github.com/Zzh-tju/DIoU-SSD-pytorch/blob/master/utils/box/box_utils.py#L47
                v = (4 / math.pi ** 2) * torch.pow(torch.atan(w2 / h2) - torch.atan(w1 / h1), 2)
                with torch.no_grad():
                    alpha = v / (v - iou + (1 + eps))
                return iou - (rho2 / c2 + v * alpha)  # CIoU
        else:  # GIoU https://arxiv.org/pdf/1902.09630.pdf
            c_area = cw * ch + eps  # convex area
            return iou - (c_area - union) / c_area  # GIoU
    elif WIoU:
        b1 = torch.stack([b1_x1, b1_y1, b1_x2, b1_y2], dim=-1)
        b2 = torch.stack([b2_x1, b2_y1, b2_x2, b2_y2], dim=-1)

        self = IoU_Cal(b1, b2)
        loss = getattr(IoU_Cal, 'WIoU')(b1, b2, self=self)
        iou = 1 - self.iou
        return loss, iou
    else:
        return iou  # IoU
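
# Illustrative call (usage sketch, not from the uploaded files): CIoU between
# one box and a batch, assuming xyxy tensors as documented above.
#   b1 = torch.tensor([0., 0., 10., 10.])
#   b2 = torch.tensor([[2., 2., 12., 12.], [0., 0., 10., 10.]])
#   bbox_iou(b1, b2, x1y1x2y2=True, CIoU=True)  # roughly tensor([0.44, 1.00])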


def box_iou(box1, box2):
    # https://github.com/pytorch/vision/blob/master/torchvision/ops/boxes.py
    """
    Return intersection-over-union (Jaccard index) of boxes.
    Both sets of boxes are expected to be in (x1, y1, x2, y2) format.
    Arguments:
        box1 (Tensor[N, 4])
        box2 (Tensor[M, 4])
    Returns:
        iou (Tensor[N, M]): the NxM matrix containing the pairwise
            IoU values for every element in boxes1 and boxes2
    """

    def box_area(box):
        # box = 4xn
        return (box[2] - box[0]) * (box[3] - box[1])

    area1 = box_area(box1.T)
    area2 = box_area(box2.T)

    # inter(N,M) = (rb(N,M,2) - lt(N,M,2)).clamp(0).prod(2)
    inter = (torch.min(box1[:, None, 2:], box2[:, 2:]) - torch.max(box1[:, None, :2], box2[:, :2])).clamp(0).prod(2)
    return inter / (area1[:, None] + area2 - inter)  # iou = inter / (area1 + area2 - inter)
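
# Illustrative call (usage sketch, not from the uploaded files): pairwise IoU
# matrix, assuming xyxy tensors. Shapes: box1 is Nx4, box2 is Mx4, result NxM.
#   a = torch.tensor([[0., 0., 10., 10.]])
#   b = torch.tensor([[0., 0., 10., 10.], [5., 5., 15., 15.]])
#   box_iou(a, b)  # -> tensor([[1.0000, 0.1429]])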


def letterbox(combination, new_shape=(640, 640), color=(114, 114, 114), auto=True, scaleFill=False, scaleup=True):
    """Resize the input image and automatically pad it to a suitable shape: https://zhuanlan.zhihu.com/p/172121380"""
    # Resize image to a 32-pixel-multiple rectangle https://github.com/ultralytics/yolov3/issues/232
    img, gray, line = combination
    shape = img.shape[:2]  # current shape [height, width]
    if isinstance(new_shape, int):
        new_shape = (new_shape, new_shape)

    # Scale ratio (new / old)
    r = min(new_shape[0] / shape[0], new_shape[1] / shape[1])
    if not scaleup:  # only scale down, do not scale up (for better test mAP)
        r = min(r, 1.0)

    # Compute padding
    ratio = r, r  # width, height ratios
    new_unpad = int(round(shape[1] * r)), int(round(shape[0] * r))
    dw, dh = new_shape[1] - new_unpad[0], new_shape[0] - new_unpad[1]  # wh padding
    # if auto:  # minimum rectangle
    #     dw, dh = np.mod(dw, 32), np.mod(dh, 32)  # wh padding
    # elif scaleFill:  # stretch
    #     dw, dh = 0.0, 0.0
    #     new_unpad = (new_shape[1], new_shape[0])
    #     ratio = new_shape[1] / shape[1], new_shape[0] / shape[0]  # width, height ratios

    dw /= 2  # divide padding into 2 sides
    dh /= 2

    if shape[::-1] != new_unpad:  # resize
        img = cv2.resize(img, new_unpad, interpolation=cv2.INTER_LINEAR)
        gray = cv2.resize(gray, new_unpad, interpolation=cv2.INTER_LINEAR)
        line = cv2.resize(line, new_unpad, interpolation=cv2.INTER_LINEAR)

    top, bottom = int(round(dh - 0.1)), int(round(dh + 0.1))
    left, right = int(round(dw - 0.1)), int(round(dw + 0.1))

    img = cv2.copyMakeBorder(img, top, bottom, left, right, cv2.BORDER_CONSTANT, value=color)  # add border
    gray = cv2.copyMakeBorder(gray, top, bottom, left, right, cv2.BORDER_CONSTANT, value=0)  # add border
    line = cv2.copyMakeBorder(line, top, bottom, left, right, cv2.BORDER_CONSTANT, value=0)  # add border
    # print(img.shape)

    combination = (img, gray, line)
    return combination, ratio, (dw, dh)


def letterbox_for_img(img, new_shape=(640, 640), color=(114, 114, 114), auto=True, scaleFill=False, scaleup=True):
    # Resize image to a 32-pixel-multiple rectangle https://github.com/ultralytics/yolov3/issues/232
    shape = img.shape[:2]  # current shape [height, width]
    if isinstance(new_shape, int):
        new_shape = (new_shape, new_shape)

    # Scale ratio (new / old)
    r = min(new_shape[0] / shape[0], new_shape[1] / shape[1])
    if not scaleup:  # only scale down, do not scale up (for better test mAP)
        r = min(r, 1.0)

    # Compute padding
    ratio = r, r  # width, height ratios
    new_unpad = int(round(shape[1] * r)), int(round(shape[0] * r))
    dw, dh = new_shape[1] - new_unpad[0], new_shape[0] - new_unpad[1]  # wh padding
    if auto:  # minimum rectangle
        dw, dh = np.mod(dw, 32), np.mod(dh, 32)  # wh padding
    elif scaleFill:  # stretch
        dw, dh = 0.0, 0.0
        new_unpad = (new_shape[1], new_shape[0])
        ratio = new_shape[1] / shape[1], new_shape[0] / shape[0]  # width, height ratios

    dw /= 2  # divide padding into 2 sides
    dh /= 2
    if shape[::-1] != new_unpad:  # resize
        img = cv2.resize(img, new_unpad, interpolation=cv2.INTER_AREA)

    top, bottom = int(round(dh - 0.1)), int(round(dh + 0.1))
    left, right = int(round(dw - 0.1)), int(round(dw + 0.1))
    img = cv2.copyMakeBorder(img, top, bottom, left, right, cv2.BORDER_CONSTANT, value=color)  # add border
    return img, ratio, (dw, dh)
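
# Illustrative call (usage sketch, not from the uploaded files): letterbox a
# 720p BGR frame to a 32-multiple rectangle, assuming an HxWx3 uint8 array.
#   frame = np.zeros((720, 1280, 3), dtype=np.uint8)
#   out, ratio, (dw, dh) = letterbox_for_img(frame, new_shape=640, auto=True)
#   out.shape  # -> (384, 640, 3): scaled by 0.5, height padded up to /32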


def colorstr(*input):
    # Colors a string https://en.wikipedia.org/wiki/ANSI_escape_code, i.e. colorstr('blue', 'hello world')
    *args, string = input if len(input) > 1 else ('blue', 'bold', input[0])  # color arguments, string
    colors = {'black': '\033[30m',  # basic colors
              'red': '\033[31m',
              'green': '\033[32m',
              'yellow': '\033[33m',
              'blue': '\033[34m',
              'magenta': '\033[35m',
              'cyan': '\033[36m',
              'white': '\033[37m',
              'bright_black': '\033[90m',  # bright colors
              'bright_red': '\033[91m',
              'bright_green': '\033[92m',
              'bright_yellow': '\033[93m',
              'bright_blue': '\033[94m',
              'bright_magenta': '\033[95m',
              'bright_cyan': '\033[96m',
              'bright_white': '\033[97m',
              'end': '\033[0m',  # misc
              'bold': '\033[1m',
              'underline': '\033[4m'}
    return ''.join(colors[x] for x in args) + f'{string}' + colors['end']
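
# Illustrative calls (usage sketch, not from the uploaded files):
#   colorstr('hello world')                          # defaults to bold blue
#   colorstr('bright_red', 'underline', 'WARNING')   # stack any listed codes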


def make_divisible(x, divisor):
    # Returns x evenly divisible by divisor
    return math.ceil(x / divisor) * divisor


def check_img_size(img_size, s=32):
    # Verify img_size is a multiple of stride s
    new_size = make_divisible(img_size, int(s))  # ceil gs-multiple
    if new_size != img_size:
        print('WARNING: --img-size %g must be multiple of max stride %g, updating to %g' % (img_size, s, new_size))
    return new_size


def scale_img(img, ratio=1.0, same_shape=False, gs=32):  # img(16,3,256,416)
    # scales img(bs,3,y,x) by ratio constrained to gs-multiple
    if ratio == 1.0:
        return img
    else:
        h, w = img.shape[2:]
        s = (int(h * ratio), int(w * ratio))  # new size
        img = F.interpolate(img, size=s, mode='bilinear', align_corners=False)  # resize
        if not same_shape:  # pad/crop img
            h, w = [math.ceil(x * ratio / gs) * gs for x in (h, w)]
        return F.pad(img, [0, w - s[1], 0, h - s[0]], value=0.447)  # value = ImageNet mean


def random_perspective(combination, targets=(), degrees=10, translate=.1, scale=.1, shear=10, perspective=0.0,
                       border=(0, 0)):
    """Apply one random perspective/affine transform jointly to the (img, gray, line) combination"""
    # torchvision.transforms.RandomAffine(degrees=(-10, 10), translate=(.1, .1), scale=(.9, 1.1), shear=(-10, 10))
    # targets = [cls, xyxy]
    img, gray, line = combination
    height = img.shape[0] + border[0] * 2  # shape(h,w,c)
    width = img.shape[1] + border[1] * 2

    # Center
    C = np.eye(3)
    C[0, 2] = -img.shape[1] / 2  # x translation (pixels)
    C[1, 2] = -img.shape[0] / 2  # y translation (pixels)

    # Perspective
    P = np.eye(3)
    P[2, 0] = random.uniform(-perspective, perspective)  # x perspective (about y)
    P[2, 1] = random.uniform(-perspective, perspective)  # y perspective (about x)

    # Rotation and Scale
    R = np.eye(3)
    a = random.uniform(-degrees, degrees)
    # a += random.choice([-180, -90, 0, 90])  # add 90deg rotations to small rotations
    s = random.uniform(1 - scale, 1 + scale)
    # s = 2 ** random.uniform(-scale, scale)
    R[:2] = cv2.getRotationMatrix2D(angle=a, center=(0, 0), scale=s)

    # Shear
    S = np.eye(3)
    S[0, 1] = math.tan(random.uniform(-shear, shear) * math.pi / 180)  # x shear (deg)
    S[1, 0] = math.tan(random.uniform(-shear, shear) * math.pi / 180)  # y shear (deg)

    # Translation
    T = np.eye(3)
    T[0, 2] = random.uniform(0.5 - translate, 0.5 + translate) * width  # x translation (pixels)
    T[1, 2] = random.uniform(0.5 - translate, 0.5 + translate) * height  # y translation (pixels)

    # Combined rotation matrix
    M = T @ S @ R @ P @ C  # order of operations (right to left) is IMPORTANT
    if (border[0] != 0) or (border[1] != 0) or (M != np.eye(3)).any():  # image changed
        if perspective:
            img = cv2.warpPerspective(img, M, dsize=(width, height), borderValue=(114, 114, 114))
            gray = cv2.warpPerspective(gray, M, dsize=(width, height), borderValue=0)
            line = cv2.warpPerspective(line, M, dsize=(width, height), borderValue=0)
        else:  # affine
            img = cv2.warpAffine(img, M[:2], dsize=(width, height), borderValue=(114, 114, 114))
            gray = cv2.warpAffine(gray, M[:2], dsize=(width, height), borderValue=0)
            line = cv2.warpAffine(line, M[:2], dsize=(width, height), borderValue=0)

    # Visualize
    # import matplotlib.pyplot as plt
    # ax = plt.subplots(1, 2, figsize=(12, 6))[1].ravel()
    # ax[0].imshow(img[:, :, ::-1])  # base
    # ax[1].imshow(img2[:, :, ::-1])  # warped

    # Transform label coordinates
    n = len(targets)
    if n:
        # warp points
        xy = np.ones((n * 4, 3))
        xy[:, :2] = targets[:, [1, 2, 3, 4, 1, 4, 3, 2]].reshape(n * 4, 2)  # x1y1, x2y2, x1y2, x2y1
        xy = xy @ M.T  # transform
        if perspective:
            xy = (xy[:, :2] / xy[:, 2:3]).reshape(n, 8)  # rescale
        else:  # affine
            xy = xy[:, :2].reshape(n, 8)

        # create new boxes
        x = xy[:, [0, 2, 4, 6]]
        y = xy[:, [1, 3, 5, 7]]
        xy = np.concatenate((x.min(1), y.min(1), x.max(1), y.max(1))).reshape(4, n).T

        # # apply angle-based reduction of bounding boxes
        # radians = a * math.pi / 180
        # reduction = max(abs(math.sin(radians)), abs(math.cos(radians))) ** 0.5
        # x = (xy[:, 2] + xy[:, 0]) / 2
        # y = (xy[:, 3] + xy[:, 1]) / 2
        # w = (xy[:, 2] - xy[:, 0]) * reduction
        # h = (xy[:, 3] - xy[:, 1]) * reduction
        # xy = np.concatenate((x - w / 2, y - h / 2, x + w / 2, y + h / 2)).reshape(4, n).T

        # clip boxes
        xy[:, [0, 2]] = xy[:, [0, 2]].clip(0, width)
        xy[:, [1, 3]] = xy[:, [1, 3]].clip(0, height)

        # filter candidates
        i = _box_candidates(box1=targets[:, 1:5].T * s, box2=xy.T)
        targets = targets[i]
        targets[:, 1:5] = xy[i]

    combination = (img, gray, line)
    return combination, targets


def _box_candidates(box1, box2, wh_thr=2, ar_thr=20, area_thr=0.1):  # box1(4,n), box2(4,n)
    # Compute candidate boxes: box1 before augment, box2 after augment, wh_thr (pixels), aspect_ratio_thr, area_ratio
    w1, h1 = box1[2] - box1[0], box1[3] - box1[1]
    w2, h2 = box2[2] - box2[0], box2[3] - box2[1]
    ar = np.maximum(w2 / (h2 + 1e-16), h2 / (w2 + 1e-16))  # aspect ratio
    return (w2 > wh_thr) & (h2 > wh_thr) & (w2 * h2 / (w1 * h1 + 1e-16) > area_thr) & (ar < ar_thr)  # candidates


def mixup(im, labels, seg_label, lane_label, im2, labels2, seg_label2, lane_label2):
    # Applies MixUp augmentation https://arxiv.org/pdf/1710.09412.pdf
    r = np.random.beta(32.0, 32.0)  # mixup ratio, alpha=beta=32.0
    im = (im * r + im2 * (1 - r)).astype(np.uint8)
    labels = np.concatenate((labels, labels2), 0)
    seg_label |= seg_label2
    lane_label |= lane_label2
    return im, labels, seg_label, lane_label


def augment_hsv(img, hgain=0.5, sgain=0.5, vgain=0.5):
    """Randomly shift image hue, saturation and value (HSV augmentation), in place"""
    r = np.random.uniform(-1, 1, 3) * [hgain, sgain, vgain] + 1  # random gains
    hue, sat, val = cv2.split(cv2.cvtColor(img, cv2.COLOR_BGR2HSV))
    dtype = img.dtype  # uint8

    x = np.arange(0, 256, dtype=np.int16)
    lut_hue = ((x * r[0]) % 180).astype(dtype)
    lut_sat = np.clip(x * r[1], 0, 255).astype(dtype)
    lut_val = np.clip(x * r[2], 0, 255).astype(dtype)

    img_hsv = cv2.merge((cv2.LUT(hue, lut_hue), cv2.LUT(sat, lut_sat), cv2.LUT(val, lut_val))).astype(dtype)
    cv2.cvtColor(img_hsv, cv2.COLOR_HSV2BGR, dst=img)  # no return needed

    # Histogram equalization
    # if random.random() < 0.2:
    #     for i in range(3):
    #         img[:, :, i] = cv2.equalizeHist(img[:, :, i])


def non_max_suppression(prediction, conf_thres=0.25, iou_thres=0.45, classes=None, agnostic=False, multi_label=False,
                        labels=()):
    """Runs Non-Maximum Suppression (NMS) on inference results

    Returns:
        list of detections, one (n, 6) tensor per image [xyxy, conf, cls]
    """

    nc = prediction.shape[2] - 5  # number of classes
    xc = prediction[..., 4] > conf_thres  # candidates

    # Settings
    min_wh, max_wh = 2, 4096  # (pixels) minimum and maximum box width and height
    max_det = 300  # maximum number of detections per image
    max_nms = 30000  # maximum number of boxes into torchvision.ops.nms()
    time_limit = 10.0  # seconds to quit after
    redundant = True  # require redundant detections
    multi_label &= nc > 1  # multiple labels per box (adds 0.5ms/img)
    merge = False  # use merge-NMS

    t = time.time()
    output = [torch.zeros((0, 6), device=prediction.device)] * prediction.shape[0]
    for xi, x in enumerate(prediction):  # image index, image inference
        # Apply constraints
        # x[((x[..., 2:4] < min_wh) | (x[..., 2:4] > max_wh)).any(1), 4] = 0  # width-height
        x = x[xc[xi]]  # confidence

        # Cat apriori labels if autolabelling
        if labels and len(labels[xi]):
            l = labels[xi]
            v = torch.zeros((len(l), nc + 5), device=x.device)
            v[:, :4] = l[:, 1:5]  # box
            v[:, 4] = 1.0  # conf
            v[range(len(l)), l[:, 0].long() + 5] = 1.0  # cls
            x = torch.cat((x, v), 0)

        # If none remain process next image
        if not x.shape[0]:
            continue

        # Compute conf
        if nc == 1:
            x[:, 5:] = x[:, 4:5]  # for models with one class, cls_loss is 0 and cls_conf is always 0.5,
            # so there is no need to multiply
        else:
            x[:, 5:] *= x[:, 4:5]  # conf = obj_conf * cls_conf

        # Box (center x, center y, width, height) to (x1, y1, x2, y2)
        box = xywh2xyxy(x[:, :4])

        # Detections matrix nx6 (xyxy, conf, cls)
        if multi_label:
            i, j = (x[:, 5:] > conf_thres).nonzero(as_tuple=False).T
            x = torch.cat((box[i], x[i, j + 5, None], j[:, None].float()), 1)
        else:  # best class only
            conf, j = x[:, 5:].max(1, keepdim=True)
            x = torch.cat((box, conf, j.float()), 1)[conf.view(-1) > conf_thres]

        # Filter by class
        if classes is not None:
            x = x[(x[:, 5:6] == torch.tensor(classes, device=x.device)).any(1)]

        # Apply finite constraint
        # if not torch.isfinite(x).all():
        #     x = x[torch.isfinite(x).all(1)]

        # Check shape
        n = x.shape[0]  # number of boxes
        if not n:  # no boxes
            continue
        elif n > max_nms:  # excess boxes
            x = x[x[:, 4].argsort(descending=True)[:max_nms]]  # sort by confidence

        # Batched NMS
        c = x[:, 5:6] * (0 if agnostic else max_wh)  # classes
        boxes, scores = x[:, :4] + c, x[:, 4]  # boxes (offset by class), scores
        i = torchvision.ops.nms(boxes, scores, iou_thres)  # NMS
        if i.shape[0] > max_det:  # limit detections
            i = i[:max_det]
        if merge and (1 < n < 3E3):  # Merge NMS (boxes merged using weighted mean)
            # update boxes as boxes(i,4) = weights(i,n) * boxes(n,4)
            iou = box_iou(boxes[i], boxes) > iou_thres  # iou matrix
            weights = iou * scores[None]  # box weights
            x[i, :4] = torch.mm(weights, x[:, :4]).float() / weights.sum(1, keepdim=True)  # merged boxes
            if redundant:
                i = i[iou.sum(1) > 1]  # require redundancy

        output[xi] = x[i]
        if (time.time() - t) > time_limit:
            print(f'WARNING: NMS time limit {time_limit}s exceeded')
            break  # time limit exceeded

    return output
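
# Illustrative use (usage sketch, not from the uploaded files): filter raw
# detector output, assuming `pred` has shape (batch, num_boxes, 5 + nc) with
# xywh boxes. Each returned element is an (n, 6) tensor [xyxy, conf, cls].
#   dets = non_max_suppression(pred, conf_thres=0.25, iou_thres=0.45)
#   for det in dets:  # one tensor per image in the batch
#       for *xyxy, conf, cls in det:
#           ...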


def scale_coords(img1_shape, coords, img0_shape, ratio_pad=None):
    # Rescale coords (xyxy) from img1_shape to img0_shape
    if ratio_pad is None:  # calculate from img0_shape
        gain = min(img1_shape[0] / img0_shape[0], img1_shape[1] / img0_shape[1])  # gain = old / new
        pad = (img1_shape[1] - img0_shape[1] * gain) / 2, (img1_shape[0] - img0_shape[0] * gain) / 2  # wh padding
    else:
        gain = ratio_pad[0][0]
        pad = ratio_pad[1]

    coords[:, [0, 2]] -= pad[0]  # x padding
    coords[:, [1, 3]] -= pad[1]  # y padding
    coords[:, :4] /= gain
    clip_coords(coords, img0_shape)
    return coords
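
# Illustrative use (usage sketch, not from the uploaded files): map boxes
# predicted on a 384x640 letterboxed input back to the original 720x1280
# frame, assuming `det[:, :4]` holds xyxy boxes in network-input pixels.
#   det[:, :4] = scale_coords((384, 640), det[:, :4], (720, 1280)).round()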


def clip_coords(boxes, img_shape):
    # Clip xyxy bounding boxes to image shape (height, width)
    boxes[:, 0].clamp_(0, img_shape[1])  # x1
    boxes[:, 1].clamp_(0, img_shape[0])  # y1
    boxes[:, 2].clamp_(0, img_shape[1])  # x2
    boxes[:, 3].clamp_(0, img_shape[0])  # y2


def ap_per_class(tp, conf, pred_cls, target_cls, plot=False, save_dir='precision-recall_curve.png', names=[]):
    """ Compute the average precision, given the recall and precision curves.
    Source: https://github.com/rafaelpadilla/Object-Detection-Metrics.
    # Arguments
        tp:  True positives (nparray, nx1 or nx10).
        conf:  Objectness value from 0-1 (nparray).
        pred_cls:  Predicted object classes (nparray).
        target_cls:  True object classes (nparray).
        plot:  Plot precision-recall curve at mAP@0.5
        save_dir:  Plot save directory
    # Returns
        The average precision as computed in py-faster-rcnn.
    """

    # Sort by objectness
    i = np.argsort(-conf)
    tp, conf, pred_cls = tp[i], conf[i], pred_cls[i]

    # Find unique classes
    unique_classes = np.unique(target_cls)

    # Create Precision-Recall curve and compute AP for each class
    px, py = np.linspace(0, 1, 1000), []  # for plotting
    pr_score = 0.1  # score to evaluate P and R https://github.com/ultralytics/yolov3/issues/898
    s = [unique_classes.shape[0], tp.shape[1]]  # number class, number iou thresholds (i.e. 10 for mAP0.5...0.95)
    ap, p, r = np.zeros(s), np.zeros((unique_classes.shape[0], 1000)), np.zeros((unique_classes.shape[0], 1000))
    for ci, c in enumerate(unique_classes):
        i = pred_cls == c
        n_l = (target_cls == c).sum()  # number of labels
        n_p = i.sum()  # number of predictions

        if n_p == 0 or n_l == 0:
            continue
        else:
            # Accumulate FPs and TPs
            fpc = (1 - tp[i]).cumsum(0)
            tpc = tp[i].cumsum(0)

            # Recall
            recall = tpc / (n_l + 1e-16)  # recall curve
            r[ci] = np.interp(-px, -conf[i], recall[:, 0], left=0)  # negative x, xp because xp decreases

            # Precision
            precision = tpc / (tpc + fpc)  # precision curve
            p[ci] = np.interp(-px, -conf[i], precision[:, 0], left=1)  # p at pr_score
            # AP from recall-precision curve
            for j in range(tp.shape[1]):
                ap[ci, j], mpre, mrec = compute_ap(recall[:, j], precision[:, j])
                if plot and (j == 0):
                    py.append(np.interp(px, mrec, mpre))  # precision at mAP@0.5

    # Compute F1 score (harmonic mean of precision and recall)
    f1 = 2 * p * r / (p + r + 1e-16)
    i = r.mean(0).argmax()

    if plot:
        plot_pr_curve(px, py, ap, save_dir, names)

    return p[:, i], r[:, i], ap, f1[:, i], unique_classes.astype('int32')


def compute_ap(recall, precision):
    """ Compute the average precision, given the recall and precision curves.
    Source: https://github.com/rbgirshick/py-faster-rcnn.
    # Arguments
        recall:  The recall curve (list).
        precision:  The precision curve (list).
    # Returns
        The average precision as computed in py-faster-rcnn.
    """

    # Append sentinel values to beginning and end
    mrec = np.concatenate(([0.], recall, [recall[-1] + 1E-3]))
    mpre = np.concatenate(([1.], precision, [0.]))

    # Compute the precision envelope
    mpre = np.flip(np.maximum.accumulate(np.flip(mpre)))

    # Integrate area under curve
    method = 'interp'  # methods: 'continuous', 'interp'
    if method == 'interp':
        x = np.linspace(0, 1, 101)  # 101-point interp (COCO)
        ap = np.trapz(np.interp(x, mrec, mpre), x)  # integrate
    else:  # 'continuous'
        i = np.where(mrec[1:] != mrec[:-1])[0]  # points where x axis (recall) changes
        ap = np.sum((mrec[i + 1] - mrec[i]) * mpre[i + 1])  # area under curve

    return ap, mpre, mrec
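
# Illustrative call (usage sketch, not from the uploaded files): AP over a
# tiny monotone PR curve, using the default 101-point COCO-style interp.
#   ap, mpre, mrec = compute_ap(recall=np.array([0.5, 1.0]),
#                               precision=np.array([1.0, 0.5]))
#   ap  # -> approximately 0.875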


def plot_pr_curve(px, py, ap, save_dir='.', names=()):
    fig, ax = plt.subplots(1, 1, figsize=(9, 6), tight_layout=True)
    py = np.stack(py, axis=1)

    if 0 < len(names) < 21:  # show mAP in legend if < 21 classes
        for i, y in enumerate(py.T):
            ax.plot(px, y, linewidth=1, label=f'{names[i]} %.3f' % ap[i, 0])  # plot(recall, precision)
    else:
        ax.plot(px, py, linewidth=1, color='grey')  # plot(recall, precision)

    ax.plot(px, py.mean(1), linewidth=3, color='blue', label='all classes %.3f mAP@0.5' % ap[:, 0].mean())
    ax.set_xlabel('Recall')
    ax.set_ylabel('Precision')
    ax.set_xlim(0, 1)
    ax.set_ylim(0, 1)
    plt.legend(bbox_to_anchor=(1.04, 1), loc="upper left")
    fig.savefig(Path(save_dir) / 'precision_recall_curve.png', dpi=250)


def increment_path(path, exist_ok=True, sep=''):
    # Increment path, i.e. runs/exp --> runs/exp{sep}0, runs/exp{sep}1 etc.
    path = Path(path)  # os-agnostic
    if (path.exists() and exist_ok) or (not path.exists()):
        return str(path)
    else:
        dirs = glob.glob(f"{path}{sep}*")  # similar paths
        matches = [re.search(rf"{path.stem}{sep}(\d+)", d) for d in dirs]
        i = [int(m.groups()[0]) for m in matches if m]  # indices
        n = max(i) + 1 if i else 2  # increment number
        return f"{path}{sep}{n}"  # update path
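
# Illustrative behavior (usage sketch, not from the uploaded files), assuming
# runs/exp and runs/exp2 already exist on disk:
#   increment_path('runs/exp', exist_ok=False)  # -> 'runs/exp3'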


class IoU_Cal:
    """pred, target: x0, y0, x1, y1
    monotonous: {
        None: origin
        True: monotonic FM
        False: non-monotonic FM
    }
    momentum: the momentum of the running mean"""
    iou_mean = 1.
    monotonous = False
    momentum = 1 - 0.5 ** (1 / 7000)

    _is_train = True

    def __init__(self, pred, target):
        self.pred, self.target = pred, target
        self._fget = {
            # x, y, w, h
            'pred_xy': lambda: (self.pred[..., :2] + self.pred[..., 2: 4]) / 2,
            'pred_wh': lambda: self.pred[..., 2: 4] - self.pred[..., :2],
            'target_xy': lambda: (self.target[..., :2] + self.target[..., 2: 4]) / 2,
            'target_wh': lambda: self.target[..., 2: 4] - self.target[..., :2],
            # x0, y0, x1, y1
            'min_coord': lambda: torch.minimum(self.pred[..., :4], self.target[..., :4]),
            'max_coord': lambda: torch.maximum(self.pred[..., :4], self.target[..., :4]),
            # The overlapping region
            'wh_inter': lambda: self.min_coord[..., 2: 4] - self.max_coord[..., :2],
            's_inter': lambda: torch.prod(torch.relu(self.wh_inter), dim=-1),
            # The area covered
            's_union': lambda: torch.prod(self.pred_wh, dim=-1) +
                               torch.prod(self.target_wh, dim=-1) - self.s_inter,
            # The smallest enclosing box
            'wh_box': lambda: self.max_coord[..., 2: 4] - self.min_coord[..., :2],
            's_box': lambda: torch.prod(self.wh_box, dim=-1),
            'l2_box': lambda: torch.square(self.wh_box).sum(dim=-1),
            # The central points' connection of the bounding boxes
            'd_center': lambda: self.pred_xy - self.target_xy,
            'l2_center': lambda: torch.square(self.d_center).sum(dim=-1),
            # IoU (note: stored as 1 - IoU, i.e. the IoU loss)
            'iou': lambda: 1 - self.s_inter / self.s_union
        }
        self._update(self)

    def __setitem__(self, key, value):
        self._fget[key] = value

    def __getattr__(self, item):
        if callable(self._fget[item]):
            self._fget[item] = self._fget[item]()
        return self._fget[item]

    @classmethod
    def train(cls):
        cls._is_train = True

    @classmethod
    def eval(cls):
        cls._is_train = False

    @classmethod
    def _update(cls, self):
        if cls._is_train:
            cls.iou_mean = (1 - cls.momentum) * cls.iou_mean + \
                           cls.momentum * self.iou.detach().mean().item()

    def _scaled_loss(self, loss, gamma=1.9, delta=3):
        if isinstance(self.monotonous, bool):
            if self.monotonous:
                loss *= (self.iou.detach() / self.iou_mean).sqrt()
            else:
                beta = self.iou.detach() / self.iou_mean
                alpha = delta * torch.pow(gamma, beta - delta)
                loss *= beta / alpha
        return loss

    @classmethod
    def IoU(cls, pred, target, self=None):
        self = self if self else cls(pred, target)
        return self.iou

    @classmethod
    def WIoU(cls, pred, target, self=None):
        self = self if self else cls(pred, target)
        dist = torch.exp(self.l2_center / self.l2_box.detach())
        return self._scaled_loss(dist * self.iou)
utils/plots.py ADDED
@@ -0,0 +1,89 @@
import random

import cv2
import matplotlib
import numpy as np
import torch


def show_seg_result(img, result, index=0, epoch=0, batch=0, save_dir=None, is_ll=False, palette=None, is_demo=False,
                    is_gt=False):
    if palette is None:
        palette = np.random.randint(0, 255, size=(3, 3))
    palette[0] = [0, 0, 0]
    palette[1] = [0, 255, 0]
    palette[2] = [255, 0, 0]
    palette = np.array(palette)
    assert palette.shape[0] == 3  # len(classes)
    assert palette.shape[1] == 3
    assert len(palette.shape) == 2

    if not is_demo:
        color_seg = np.zeros((result.shape[0], result.shape[1], 3), dtype=np.uint8)
        for label, color in enumerate(palette):
            color_seg[result == label, :] = color
    else:
        color_area = np.zeros((result[0].shape[0], result[0].shape[1], 3), dtype=np.uint8)

        color_area[result[0] == 1] = [0, 255, 0]
        color_area[result[1] == 1] = [255, 0, 0]
        color_seg = color_area

    # convert to BGR
    color_seg = color_seg[..., ::-1]
    # print(color_seg.shape)
    color_mask = np.mean(color_seg, 2)
    img[color_mask != 0] = img[color_mask != 0] * 0.5 + color_seg[color_mask != 0] * 0.5
    # img = img * 0.5 + color_seg * 0.5
    img = img.astype(np.uint8)
    img = cv2.resize(img, (1280, 720), interpolation=cv2.INTER_LINEAR)

    if not is_demo:
        if not is_gt:
            if not is_ll:
                cv2.imwrite(save_dir + "/batch_{}_{}_{}_da_segresult.png".format(epoch, batch, index), img)
            else:
                cv2.imwrite(save_dir + "/batch_{}_{}_{}_ll_segresult.png".format(epoch, batch, index), img)
        else:
            if not is_ll:
                cv2.imwrite(save_dir + "/batch_{}_{}_{}_da_seg_gt.png".format(epoch, batch, index), img)
            else:
                cv2.imwrite(save_dir + "/batch_{}_{}_{}_ll_seg_gt.png".format(epoch, batch, index), img)
    return img


def plot_one_box(x, img, color=None, label=None, line_thickness=3):
    # Plots one bounding box on image img
    tl = line_thickness or round(0.002 * (img.shape[0] + img.shape[1]) / 2) + 1  # line/font thickness
    color = color or [random.randint(0, 255) for _ in range(3)]
    c1, c2 = (int(x[0]), int(x[1])), (int(x[2]), int(x[3]))
    cv2.rectangle(img, c1, c2, color, thickness=tl, lineType=cv2.LINE_AA)
    if label:
        tf = max(tl - 1, 1)  # font thickness
        t_size = cv2.getTextSize(label, 0, fontScale=tl / 3, thickness=tf)[0]
        c2 = c1[0] + t_size[0], c1[1] - t_size[1] - 3
        cv2.rectangle(img, c1, c2, color, -1, cv2.LINE_AA)  # filled
        cv2.putText(img, label, (c1[0], c1[1] - 2), 0, tl / 3, [225, 255, 255], thickness=tf, lineType=cv2.LINE_AA)
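
# Illustrative call (usage sketch, not from the uploaded files): draw one
# labeled box on a BGR frame in place, assuming xyxy pixel coordinates.
#   plot_one_box([40, 45, 60, 55], frame, color=(0, 255, 0), label='car 0.91')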


def driving_area_mask(da_seg_out, width, height, pad_w, pad_h, ratio):
    # Crop the letterbox padding, upsample back to original scale, take the per-pixel argmax class
    da_predict = da_seg_out[:, :, pad_h:(height - pad_h), pad_w:(width - pad_w)]
    da_seg_mask = torch.nn.functional.interpolate(da_predict, scale_factor=int(1 / ratio), mode='bilinear')
    _, da_seg_mask = torch.max(da_seg_mask, 1)
    da_seg_mask = da_seg_mask.int().squeeze().cpu().numpy()
    return da_seg_mask


def lane_line_mask(ll_seg_out, width, height, pad_w, pad_h, ratio):
    # Same post-processing as driving_area_mask, applied to the lane-line head
    ll_predict = ll_seg_out[:, :, pad_h:(height - pad_h), pad_w:(width - pad_w)]
    ll_seg_mask = torch.nn.functional.interpolate(ll_predict, scale_factor=int(1 / ratio), mode='bilinear')
    _, ll_seg_mask = torch.max(ll_seg_mask, 1)
    ll_seg_mask = ll_seg_mask.int().squeeze().cpu().numpy()
    return ll_seg_mask


def color_list():
    # Return first 10 plt colors as (r,g,b) https://stackoverflow.com/questions/51350872/python-from-color-name-to-rgb
    def hex2rgb(h):
        return tuple(int(h[1 + i:1 + i + 2], 16) for i in (0, 2, 4))

    return [hex2rgb(h) for h in matplotlib.colors.TABLEAU_COLORS.values()]  # or BASE_ (8), CSS4_ (148), XKCD_ (949)
utils/torch_utils.py ADDED
@@ -0,0 +1,321 @@
import datetime
import logging
import math
import os
import platform
import subprocess
import time
from copy import deepcopy
from pathlib import Path

import thop
import torch
import torch.optim as optim
from prefetch_generator import BackgroundGenerator
from torch import nn
from torch.utils.data import DataLoader

logger = logging.getLogger(__name__)


def create_logger(args, setting_yaml, phase='train', rank=-1):
    # set up the logger directory
    dataset = setting_yaml['dataset_dataset']
    model = setting_yaml['model_name']
    cfg_path = os.path.basename(args.log_dir).split('.')[0]

    if rank in [-1, 0]:
        time_str = time.strftime('%Y-%m-%d-%H-%M')
        log_file = '{}_{}_{}.log'.format(cfg_path, time_str, phase)
        # set up tensorboard_log_dir
        tensorboard_log_dir = Path(args.log_dir) / dataset / model / (cfg_path + '_' + time_str)
        final_output_dir = tensorboard_log_dir
        if not tensorboard_log_dir.exists():
            print('=> creating {}'.format(tensorboard_log_dir))
            tensorboard_log_dir.mkdir(parents=True)

        final_log_file = tensorboard_log_dir / log_file
        head = '%(asctime)-15s %(message)s'
        logging.basicConfig(filename=str(final_log_file),
                            format=head)
        logger = logging.getLogger()
        logger.setLevel(logging.INFO)
        console = logging.StreamHandler()
        logging.getLogger('').addHandler(console)

        return logger, str(final_output_dir), str(tensorboard_log_dir)
    else:
        return None, None, None


def select_device(logger=None, device='', batch_size=None):
    # device = 'cpu' or '0' or '0,1,2,3'
    s = f'mtpnet 🚀 {git_describe() or date_modified()} torch {torch.__version__} '  # string
    cpu = device.lower() == 'cpu'
    if cpu:
        os.environ['CUDA_VISIBLE_DEVICES'] = '-1'  # force torch.cuda.is_available() = False
    elif device:  # non-cpu device requested
        os.environ['CUDA_VISIBLE_DEVICES'] = device  # set environment variable
        assert torch.cuda.is_available(), f'CUDA unavailable, invalid device {device} requested'  # check availability

    cuda = not cpu and torch.cuda.is_available()
    if cuda:
        n = torch.cuda.device_count()
        if n > 1 and batch_size:  # check that batch_size is compatible with device_count
            assert batch_size % n == 0, f'batch-size {batch_size} not multiple of GPU count {n}'
        space = ' ' * len(s)
        for i, d in enumerate(device.split(',') if device else range(n)):
            p = torch.cuda.get_device_properties(i)
            s += f"{'' if i == 0 else space}CUDA:{d} ({p.name}, {p.total_memory / 1024 ** 2}MB)\n"  # bytes to MB
    else:
        s += 'CPU\n'

    if logger:
        logger.info(s.encode().decode('ascii', 'ignore') if platform.system() == 'Windows' else s)  # emoji-safe
    return torch.device('cuda:0' if cuda else 'cpu')


def get_optimizer(setting_yaml, model):
    optimizer = None
    if setting_yaml['train_optimizer'] == 'sgd':
        optimizer = optim.SGD(
            filter(lambda p: p.requires_grad, model.parameters()),
            lr=setting_yaml['train_lr0'],
            momentum=setting_yaml['train_momentum'],
            weight_decay=setting_yaml['train_wd'],
            nesterov=setting_yaml['train_nesterov']
        )
    elif setting_yaml['train_optimizer'] == 'adam':
        optimizer = optim.Adam(
            filter(lambda p: p.requires_grad, model.parameters()),
            # model.parameters(),
            lr=setting_yaml['train_lr0'],
            betas=(setting_yaml['train_momentum'], 0.999),
            eps=1e-5
        )

    return optimizer


def save_checkpoint(epoch, name, model, optimizer, scheduler, ema, output_dir, is_best, best_fitness):
    model_state = model.module.state_dict() if is_parallel(model) else model.state_dict()
    checkpoint = {
        'epoch': epoch,
        'model': name,
        'state_dict': model_state,
        'ema': deepcopy(ema.ema).half(),
        'updates': ema.updates,
        'optimizer': optimizer.state_dict(),
        'scheduler': scheduler.state_dict(),
        'best_fitness': best_fitness
    }
    last = os.path.join(output_dir, 'last.pth')
    last_st = os.path.join(output_dir, 'last_st.pth')
    best = os.path.join(output_dir, 'best.pth')
    best_st = os.path.join(output_dir, 'best_st.pth')
    torch.save(checkpoint, last)
    torch.save(model_state, last_st)
    if is_best and (epoch <= 100):
        torch.save(checkpoint, best)
        torch.save(model_state, best_st)
    if is_best and (epoch > 100):
        torch.save(checkpoint, os.path.join(output_dir, 'best_{:03d}.pth'.format(epoch)))
        torch.save(model_state, os.path.join(output_dir, 'best_st_{:03d}.pth'.format(epoch)))


class ModelEMA:
    """ Model Exponential Moving Average from https://github.com/rwightman/pytorch-image-models
    Keep a moving average of everything in the model state_dict (parameters and buffers).
    This is intended to allow functionality like
    https://www.tensorflow.org/api_docs/python/tf/train/ExponentialMovingAverage
    A smoothed version of the weights is necessary for some training schemes to perform well.
    This class is sensitive to where it is initialized in the sequence of model init,
    GPU assignment and distributed training wrappers.
    """

    def __init__(self, model, decay=0.9999, updates=0):
        # Create EMA
        self.ema = deepcopy(model.module if is_parallel(model) else model).eval()  # FP32 EMA
        # if next(model.parameters()).device.type != 'cpu':
        #     self.ema.half()  # FP16 EMA
        self.updates = updates  # number of EMA updates
        self.decay = lambda x: decay * (1 - math.exp(-x / 2000))  # decay exponential ramp (to help early epochs)
        for p in self.ema.parameters():
            p.requires_grad_(False)

    def update(self, model):
        # Update EMA parameters
        with torch.no_grad():
            self.updates += 1
            d = self.decay(self.updates)

            msd = model.module.state_dict() if is_parallel(model) else model.state_dict()  # model state_dict
            for k, v in self.ema.state_dict().items():
                if v.dtype.is_floating_point:
                    v *= d
                    v += (1. - d) * msd[k].detach()

    def update_attr(self, model, include=(), exclude=('process_group', 'reducer')):
        # Update EMA attributes
        copy_attr(self.ema, model, include, exclude)
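
# Illustrative training-loop usage (sketch, not from the uploaded files),
# assuming `model` is an nn.Module and `ema.ema` is what gets evaluated:
#   ema = ModelEMA(model)
#   for batch in loader:
#       ...forward, backward, optimizer.step()...
#       ema.update(model)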


class DataLoaderX(DataLoader):
    """DataLoader that prefetches batches in a background thread"""

    def __iter__(self):
        return BackgroundGenerator(super().__iter__())


class AverageMeter(object):
    """Computes and stores the average and current value"""

    def __init__(self):
        self.val = 0
        self.avg = 0
        self.sum = 0
        self.count = 0

    def reset(self):
        self.val = 0
        self.avg = 0
        self.sum = 0
        self.count = 0

    def update(self, val, n=1):
        self.val = val
        self.sum += val * n
        self.count += n
        self.avg = self.sum / self.count if self.count != 0 else 0


def git_describe(path=Path(__file__).parent):  # path must be a directory
    # return human-readable git description, i.e. v5.0-5-g3e25f1e https://git-scm.com/docs/git-describe
    s = f'git -C {path} describe --tags --long --always'
    try:
        return subprocess.check_output(s, shell=True, stderr=subprocess.STDOUT).decode()[:-1]
    except subprocess.CalledProcessError:
        return ''  # not a git repository


def date_modified(path=__file__):
    # return human-readable file modification date, i.e. '2021-3-26'
    t = datetime.datetime.fromtimestamp(Path(path).stat().st_mtime)
    return f'{t.year}-{t.month}-{t.day}'


def is_parallel(model):
    return type(model) in (nn.parallel.DataParallel, nn.parallel.DistributedDataParallel)


def copy_attr(a, b, include=(), exclude=()):
    # Copy attributes from b to a, options to only include [...] and to exclude [...]
    for k, v in b.__dict__.items():
        if (len(include) and k not in include) or k.startswith('_') or k in exclude:
            continue
        else:
            setattr(a, k, v)


def time_synchronized():
    # pytorch-accurate time
    if torch.cuda.is_available():
        torch.cuda.synchronize()
    return time.time()


def profile(x, ops, n=100, device=None):
    # profile a pytorch module or list of modules. Example usage:
    #     x = torch.randn(16, 3, 640, 640)  # input
    #     m1 = lambda x: x * torch.sigmoid(x)
    #     m2 = nn.SiLU()
    #     profile(x, [m1, m2], n=100)  # profile speed over 100 iterations

    device = device or torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
    x = x.to(device)
    x.requires_grad = True
    print(torch.__version__, device.type, torch.cuda.get_device_properties(0) if device.type == 'cuda' else '')
    print(f"\n{'Params':>12s}{'GFLOPS':>12s}{'forward (ms)':>16s}{'backward (ms)':>16s}{'input':>24s}{'output':>24s}")
    for m in ops if isinstance(ops, list) else [ops]:
        m = m.to(device) if hasattr(m, 'to') else m  # device
        m = m.half() if hasattr(m, 'half') and isinstance(x, torch.Tensor) and x.dtype is torch.float16 else m  # type
        dtf, dtb, t = 0., 0., [0., 0., 0.]  # dt forward, backward
        try:
            flops = thop.profile(m, inputs=(x,), verbose=False)[0] / 1E9 * 2  # GFLOPS
        except Exception:
            flops = 0

        for _ in range(n):
            t[0] = time_synchronized()
            y = m(x)
            t[1] = time_synchronized()
            try:
                _ = y.sum().backward()
                t[2] = time_synchronized()
            except Exception:  # no backward method
                t[2] = float('nan')
            dtf += (t[1] - t[0]) * 1000 / n  # ms per op forward
            dtb += (t[2] - t[1]) * 1000 / n  # ms per op backward

        s_in = tuple(x.shape) if isinstance(x, torch.Tensor) else 'list'
        s_out = tuple(y.shape) if isinstance(y, torch.Tensor) else 'list'
        p = sum(list(x.numel() for x in m.parameters())) if isinstance(m, nn.Module) else 0  # parameters
        print(f'{p:12}{flops:12.4g}{dtf:16.4g}{dtb:16.4g}{str(s_in):>24s}{str(s_out):>24s}')


def initialize_weights(model):
    for m in model.modules():
        t = type(m)
        if t is nn.Conv2d:
            pass  # nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
        elif t is nn.BatchNorm2d:
            m.eps = 1e-3
            m.momentum = 0.03
        elif t in [nn.Hardswish, nn.LeakyReLU, nn.ReLU, nn.ReLU6]:
            m.inplace = True


def fuse_conv_and_bn(conv, bn):
    # Fuse convolution and batchnorm layers https://tehnokv.com/posts/fusing-batchnorm-and-conv/
    fusedconv = nn.Conv2d(conv.in_channels,
                          conv.out_channels,
                          kernel_size=conv.kernel_size,
                          stride=conv.stride,
                          padding=conv.padding,
                          groups=conv.groups,
                          bias=True).requires_grad_(False).to(conv.weight.device)

    # prepare filters
    w_conv = conv.weight.clone().view(conv.out_channels, -1)
    w_bn = torch.diag(bn.weight.div(torch.sqrt(bn.eps + bn.running_var)))
    fusedconv.weight.copy_(torch.mm(w_bn, w_conv).view(fusedconv.weight.shape))

    # prepare spatial bias
    b_conv = torch.zeros(conv.weight.size(0), device=conv.weight.device) if conv.bias is None else conv.bias
    b_bn = bn.bias - bn.weight.mul(bn.running_mean).div(torch.sqrt(bn.running_var + bn.eps))
    fusedconv.bias.copy_(torch.mm(w_bn, b_conv.reshape(-1, 1)).reshape(-1) + b_bn)

    return fusedconv
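
# Illustrative fusion pass (sketch, not from the uploaded files): fold a
# BatchNorm into its preceding conv for inference, assuming a block `m` with
# submodules named `conv` and `bn`; `fuseforward` is a hypothetical forward
# that skips the removed bn, only valid if the block defines one.
#   with torch.no_grad():
#       m.conv = fuse_conv_and_bn(m.conv, m.bn)
#       delattr(m, 'bn')
#       m.forward = m.fuseforward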


def model_info(model, verbose=False, img_size=640):
    # Model information. img_size may be int or list, i.e. img_size=640 or img_size=[640, 320]
    n_p = sum(x.numel() for x in model.parameters())  # number parameters
    n_g = sum(x.numel() for x in model.parameters() if x.requires_grad)  # number gradients
    if verbose:
        print('%5s %40s %9s %12s %20s %10s %10s' % ('layer', 'name', 'gradient', 'parameters', 'shape', 'mu', 'sigma'))
        for i, (name, p) in enumerate(model.named_parameters()):
            name = name.replace('module_list.', '')
            print('%5g %40s %9s %12g %20s %10.3g %10.3g' %
                  (i, name, p.requires_grad, p.numel(), list(p.shape), p.mean(), p.std()))

    try:  # FLOPS
        from thop import profile
        stride = max(int(model.stride.max()), 32) if hasattr(model, 'stride') else 32
        img = torch.zeros((1, model.yaml.get('ch', 3), stride, stride), device=next(model.parameters()).device)  # input
        flops = profile(deepcopy(model), inputs=(img,), verbose=False)[0] / 1E9 * 2  # stride GFLOPS
        img_size = img_size if isinstance(img_size, list) else [img_size, img_size]  # expand if int/float
        fs = ', %.1f GFLOPS' % (flops * img_size[0] / stride * img_size[1] / stride)  # 640x640 GFLOPS
    except Exception:
        fs = ''

    logger.info(f"Model Summary: {len(list(model.modules()))} layers, {n_p} parameters, {n_g} gradients{fs}")