Upload 3 files
- utils/general.py +748 -0
- utils/plots.py +89 -0
- utils/torch_utils.py +321 -0
utils/general.py
ADDED
@@ -0,0 +1,748 @@
import glob
import logging
import math
import random
import re
import time
from pathlib import Path

import cv2
import numpy as np
import torch
import torch.nn.functional as F
import torchvision
from matplotlib import pyplot as plt

id_dict = {'person': 0, 'rider': 1, 'car': 2, 'bus': 3, 'truck': 4,
           'bike': 5, 'motor': 6, 'tl_green': 7, 'tl_red': 8,
           'tl_yellow': 9, 'tl_none': 10, 'traffic sign': 11, 'train': 12}
id_dict_single = {'car': 0, 'bus': 1, 'truck': 2, 'train': 3}


def clean_str(s):
    # Cleans a string by replacing special characters with underscore _
    return re.sub(pattern="[|@#!¡·$€%&()=?¿^*;:,¨´><+]", repl="_", string=s)


def set_logging(rank=-1):
    logging.basicConfig(
        format="%(message)s",
        level=logging.INFO if rank in [-1, 0] else logging.WARN)


def convert(size, box):  # convert (xmin, xmax, ymin, ymax) pixel box to normalized xywh
    dw = 1. / (size[0])
    dh = 1. / (size[1])
    x = (box[0] + box[1]) / 2.0
    y = (box[2] + box[3]) / 2.0
    w = box[1] - box[0]
    h = box[3] - box[2]
    x = x * dw
    w = w * dw
    y = y * dh
    h = h * dh
    return (x, y, w, h)


def xywh2xyxy(x):
    # Convert nx4 boxes from [x, y, w, h] to [x1, y1, x2, y2] where xy1=top-left, xy2=bottom-right
    y = x.clone() if isinstance(x, torch.Tensor) else np.copy(x)
    y[:, 0] = x[:, 0] - x[:, 2] / 2  # top left x
    y[:, 1] = x[:, 1] - x[:, 3] / 2  # top left y
    y[:, 2] = x[:, 0] + x[:, 2] / 2  # bottom right x
    y[:, 3] = x[:, 1] + x[:, 3] / 2  # bottom right y
    return y


def xyxy2xywh(x):
    # Convert nx4 boxes from [x1, y1, x2, y2] to [x, y, w, h] where xy1=top-left, xy2=bottom-right
    y = x.clone() if isinstance(x, torch.Tensor) else np.copy(x)
    y[:, 0] = (x[:, 0] + x[:, 2]) / 2  # x center
    y[:, 1] = (x[:, 1] + x[:, 3]) / 2  # y center
    y[:, 2] = x[:, 2] - x[:, 0]  # width
    y[:, 3] = x[:, 3] - x[:, 1]  # height
    return y


def bbox_iou(box1, box2, x1y1x2y2=True, GIoU=False, DIoU=False, CIoU=False, SIoU=False, WIoU=False, eps=1e-7):
    # Returns the IoU of box1 to box2. box1 is 4, box2 is nx4
    box2 = box2.T

    # Get the coordinates of bounding boxes
    if x1y1x2y2:  # x1, y1, x2, y2 = box1
        b1_x1, b1_y1, b1_x2, b1_y2 = box1[0], box1[1], box1[2], box1[3]
        b2_x1, b2_y1, b2_x2, b2_y2 = box2[0], box2[1], box2[2], box2[3]
    else:  # transform from xywh to xyxy
        b1_x1, b1_x2 = box1[0] - box1[2] / 2, box1[0] + box1[2] / 2
        b1_y1, b1_y2 = box1[1] - box1[3] / 2, box1[1] + box1[3] / 2
        b2_x1, b2_x2 = box2[0] - box2[2] / 2, box2[0] + box2[2] / 2
        b2_y1, b2_y2 = box2[1] - box2[3] / 2, box2[1] + box2[3] / 2

    # Intersection area
    inter = (torch.min(b1_x2, b2_x2) - torch.max(b1_x1, b2_x1)).clamp(0) * \
            (torch.min(b1_y2, b2_y2) - torch.max(b1_y1, b2_y1)).clamp(0)

    # Union Area
    w1, h1 = b1_x2 - b1_x1, b1_y2 - b1_y1 + eps
    w2, h2 = b2_x2 - b2_x1, b2_y2 - b2_y1 + eps
    union = w1 * h1 + w2 * h2 - inter + eps

    iou = inter / union
    if GIoU or DIoU or CIoU or SIoU:
        cw = torch.max(b1_x2, b2_x2) - torch.min(b1_x1, b2_x1)  # convex (smallest enclosing box) width
        ch = torch.max(b1_y2, b2_y2) - torch.min(b1_y1, b2_y1)  # convex height
        if SIoU:  # SIoU Loss https://arxiv.org/pdf/2205.12740.pdf
            s_cw = (b2_x1 + b2_x2 - b1_x1 - b1_x2) * 0.5
            s_ch = (b2_y1 + b2_y2 - b1_y1 - b1_y2) * 0.5
            sigma = torch.pow(s_cw ** 2 + s_ch ** 2, 0.5)
            sin_alpha_1 = torch.abs(s_cw) / sigma
            sin_alpha_2 = torch.abs(s_ch) / sigma
            threshold = pow(2, 0.5) / 2
            sin_alpha = torch.where(sin_alpha_1 > threshold, sin_alpha_2, sin_alpha_1)
            # angle_cost = 1 - 2 * torch.pow(torch.sin(torch.arcsin(sin_alpha) - np.pi/4), 2)
            angle_cost = torch.cos(torch.arcsin(sin_alpha) * 2 - np.pi / 2)
            rho_x = (s_cw / cw) ** 2
            rho_y = (s_ch / ch) ** 2
            gamma = angle_cost - 2
            distance_cost = 2 - torch.exp(gamma * rho_x) - torch.exp(gamma * rho_y)
            omiga_w = torch.abs(w1 - w2) / torch.max(w1, w2)
            omiga_h = torch.abs(h1 - h2) / torch.max(h1, h2)
            shape_cost = torch.pow(1 - torch.exp(-1 * omiga_w), 4) + torch.pow(1 - torch.exp(-1 * omiga_h), 4)
            return iou - 0.5 * (distance_cost + shape_cost)
        if CIoU or DIoU:  # Distance or Complete IoU https://arxiv.org/abs/1911.08287v1
            c2 = cw ** 2 + ch ** 2 + eps  # convex diagonal squared
            rho2 = ((b2_x1 + b2_x2 - b1_x1 - b1_x2) ** 2 +
                    (b2_y1 + b2_y2 - b1_y1 - b1_y2) ** 2) / 4  # center distance squared
            if DIoU:
                return iou - rho2 / c2  # DIoU
            elif CIoU:  # https://github.com/Zzh-tju/DIoU-SSD-pytorch/blob/master/utils/box/box_utils.py#L47
                v = (4 / math.pi ** 2) * torch.pow(torch.atan(w2 / h2) - torch.atan(w1 / h1), 2)
                with torch.no_grad():
                    alpha = v / (v - iou + (1 + eps))
                return iou - (rho2 / c2 + v * alpha)  # CIoU
        else:  # GIoU https://arxiv.org/pdf/1902.09630.pdf
            c_area = cw * ch + eps  # convex area
            return iou - (c_area - union) / c_area  # GIoU
    elif WIoU:
        b1 = torch.stack([b1_x1, b1_y1, b1_x2, b1_y2], dim=-1)
        b2 = torch.stack([b2_x1, b2_y1, b2_x2, b2_y2], dim=-1)

        self = IoU_Cal(b1, b2)
        loss = getattr(IoU_Cal, 'WIoU')(b1, b2, self=self)
        iou = 1 - self.iou
        return loss, iou
    else:
        return iou  # IoU


def box_iou(box1, box2):
    # https://github.com/pytorch/vision/blob/master/torchvision/ops/boxes.py
    """
    Return intersection-over-union (Jaccard index) of boxes.
    Both sets of boxes are expected to be in (x1, y1, x2, y2) format.
    Arguments:
        box1 (Tensor[N, 4])
        box2 (Tensor[M, 4])
    Returns:
        iou (Tensor[N, M]): the NxM matrix containing the pairwise
            IoU values for every element in boxes1 and boxes2
    """

    def box_area(box):
        # box = 4xn
        return (box[2] - box[0]) * (box[3] - box[1])

    area1 = box_area(box1.T)
    area2 = box_area(box2.T)

    # inter(N,M) = (rb(N,M,2) - lt(N,M,2)).clamp(0).prod(2)
    inter = (torch.min(box1[:, None, 2:], box2[:, 2:]) - torch.max(box1[:, None, :2], box2[:, :2])).clamp(0).prod(2)
    return inter / (area1[:, None] + area2 - inter)  # iou = inter / (area1 + area2 - inter)


def letterbox(combination, new_shape=(640, 640), color=(114, 114, 114), auto=True, scaleFill=False, scaleup=True):
    """Resize the input image and automatically pad it to a suitable shape: https://zhuanlan.zhihu.com/p/172121380"""
    # Resize image to a 32-pixel-multiple rectangle https://github.com/ultralytics/yolov3/issues/232
    img, gray, line = combination
    shape = img.shape[:2]  # current shape [height, width]
    if isinstance(new_shape, int):
        new_shape = (new_shape, new_shape)

    # Scale ratio (new / old)
    r = min(new_shape[0] / shape[0], new_shape[1] / shape[1])
    if not scaleup:  # only scale down, do not scale up (for better test mAP)
        r = min(r, 1.0)

    # Compute padding
    ratio = r, r  # width, height ratios
    new_unpad = int(round(shape[1] * r)), int(round(shape[0] * r))
    dw, dh = new_shape[1] - new_unpad[0], new_shape[0] - new_unpad[1]  # wh padding
    # if auto:  # minimum rectangle
    #     dw, dh = np.mod(dw, 32), np.mod(dh, 32)  # wh padding
    # elif scaleFill:  # stretch
    #     dw, dh = 0.0, 0.0
    #     new_unpad = (new_shape[1], new_shape[0])
    #     ratio = new_shape[1] / shape[1], new_shape[0] / shape[0]  # width, height ratios

    dw /= 2  # divide padding into 2 sides
    dh /= 2

    if shape[::-1] != new_unpad:  # resize
        img = cv2.resize(img, new_unpad, interpolation=cv2.INTER_LINEAR)
        gray = cv2.resize(gray, new_unpad, interpolation=cv2.INTER_LINEAR)
        line = cv2.resize(line, new_unpad, interpolation=cv2.INTER_LINEAR)

    top, bottom = int(round(dh - 0.1)), int(round(dh + 0.1))
    left, right = int(round(dw - 0.1)), int(round(dw + 0.1))

    img = cv2.copyMakeBorder(img, top, bottom, left, right, cv2.BORDER_CONSTANT, value=color)  # add border
    gray = cv2.copyMakeBorder(gray, top, bottom, left, right, cv2.BORDER_CONSTANT, value=0)  # add border
    line = cv2.copyMakeBorder(line, top, bottom, left, right, cv2.BORDER_CONSTANT, value=0)  # add border
    # print(img.shape)

    combination = (img, gray, line)
    return combination, ratio, (dw, dh)


def letterbox_for_img(img, new_shape=(640, 640), color=(114, 114, 114), auto=True, scaleFill=False, scaleup=True):
    # Resize image to a 32-pixel-multiple rectangle https://github.com/ultralytics/yolov3/issues/232
    shape = img.shape[:2]  # current shape [height, width]
    if isinstance(new_shape, int):
        new_shape = (new_shape, new_shape)

    # Scale ratio (new / old)
    r = min(new_shape[0] / shape[0], new_shape[1] / shape[1])
    if not scaleup:  # only scale down, do not scale up (for better test mAP)
        r = min(r, 1.0)

    # Compute padding
    ratio = r, r  # width, height ratios
    new_unpad = int(round(shape[1] * r)), int(round(shape[0] * r))
    dw, dh = new_shape[1] - new_unpad[0], new_shape[0] - new_unpad[1]  # wh padding
    if auto:  # minimum rectangle
        dw, dh = np.mod(dw, 32), np.mod(dh, 32)  # wh padding
    elif scaleFill:  # stretch
        dw, dh = 0.0, 0.0
        new_unpad = (new_shape[1], new_shape[0])
        ratio = new_shape[1] / shape[1], new_shape[0] / shape[0]  # width, height ratios

    dw /= 2  # divide padding into 2 sides
    dh /= 2
    if shape[::-1] != new_unpad:  # resize
        img = cv2.resize(img, new_unpad, interpolation=cv2.INTER_AREA)

    top, bottom = int(round(dh - 0.1)), int(round(dh + 0.1))
    left, right = int(round(dw - 0.1)), int(round(dw + 0.1))
    img = cv2.copyMakeBorder(img, top, bottom, left, right, cv2.BORDER_CONSTANT, value=color)  # add border
    return img, ratio, (dw, dh)


def colorstr(*input):
    # Colors a string https://en.wikipedia.org/wiki/ANSI_escape_code, i.e. colorstr('blue', 'hello world')
    *args, string = input if len(input) > 1 else ('blue', 'bold', input[0])  # color arguments, string
    colors = {'black': '\033[30m',  # basic colors
              'red': '\033[31m',
              'green': '\033[32m',
              'yellow': '\033[33m',
              'blue': '\033[34m',
              'magenta': '\033[35m',
              'cyan': '\033[36m',
              'white': '\033[37m',
              'bright_black': '\033[90m',  # bright colors
              'bright_red': '\033[91m',
              'bright_green': '\033[92m',
              'bright_yellow': '\033[93m',
              'bright_blue': '\033[94m',
              'bright_magenta': '\033[95m',
              'bright_cyan': '\033[96m',
              'bright_white': '\033[97m',
              'end': '\033[0m',  # misc
              'bold': '\033[1m',
              'underline': '\033[4m'}
    return ''.join(colors[x] for x in args) + f'{string}' + colors['end']


def make_divisible(x, divisor):
    # Returns x evenly divisible by divisor
    return math.ceil(x / divisor) * divisor


def check_img_size(img_size, s=32):
    # Verify img_size is a multiple of stride s
    new_size = make_divisible(img_size, int(s))  # ceil gs-multiple
    if new_size != img_size:
        print('WARNING: --img-size %g must be multiple of max stride %g, updating to %g' % (img_size, s, new_size))
    return new_size


def scale_img(img, ratio=1.0, same_shape=False, gs=32):  # img(16,3,256,416)
    # scales img(bs,3,y,x) by ratio constrained to gs-multiple
    if ratio == 1.0:
        return img
    else:
        h, w = img.shape[2:]
        s = (int(h * ratio), int(w * ratio))  # new size
        img = F.interpolate(img, size=s, mode='bilinear', align_corners=False)  # resize
        if not same_shape:  # pad/crop img
            h, w = [math.ceil(x * ratio / gs) * gs for x in (h, w)]
        return F.pad(img, [0, w - s[1], 0, h - s[0]], value=0.447)  # value = imagenet mean


def random_perspective(combination, targets=(), degrees=10, translate=.1, scale=.1, shear=10, perspective=0.0,
                       border=(0, 0)):
    """combination of img transform"""
    # torchvision.transforms.RandomAffine(degrees=(-10, 10), translate=(.1, .1), scale=(.9, 1.1), shear=(-10, 10))
    # targets = [cls, xyxy]
    img, gray, line = combination
    height = img.shape[0] + border[0] * 2  # shape(h,w,c)
    width = img.shape[1] + border[1] * 2

    # Center
    C = np.eye(3)
    C[0, 2] = -img.shape[1] / 2  # x translation (pixels)
    C[1, 2] = -img.shape[0] / 2  # y translation (pixels)

    # Perspective
    P = np.eye(3)
    P[2, 0] = random.uniform(-perspective, perspective)  # x perspective (about y)
    P[2, 1] = random.uniform(-perspective, perspective)  # y perspective (about x)

    # Rotation and Scale
    R = np.eye(3)
    a = random.uniform(-degrees, degrees)
    # a += random.choice([-180, -90, 0, 90])  # add 90deg rotations to small rotations
    s = random.uniform(1 - scale, 1 + scale)
    # s = 2 ** random.uniform(-scale, scale)
    R[:2] = cv2.getRotationMatrix2D(angle=a, center=(0, 0), scale=s)

    # Shear
    S = np.eye(3)
    S[0, 1] = math.tan(random.uniform(-shear, shear) * math.pi / 180)  # x shear (deg)
    S[1, 0] = math.tan(random.uniform(-shear, shear) * math.pi / 180)  # y shear (deg)

    # Translation
    T = np.eye(3)
    T[0, 2] = random.uniform(0.5 - translate, 0.5 + translate) * width  # x translation (pixels)
    T[1, 2] = random.uniform(0.5 - translate, 0.5 + translate) * height  # y translation (pixels)

    # Combined rotation matrix
    M = T @ S @ R @ P @ C  # order of operations (right to left) is IMPORTANT
    if (border[0] != 0) or (border[1] != 0) or (M != np.eye(3)).any():  # image changed
        if perspective:
            img = cv2.warpPerspective(img, M, dsize=(width, height), borderValue=(114, 114, 114))
            gray = cv2.warpPerspective(gray, M, dsize=(width, height), borderValue=0)
            line = cv2.warpPerspective(line, M, dsize=(width, height), borderValue=0)
        else:  # affine
            img = cv2.warpAffine(img, M[:2], dsize=(width, height), borderValue=(114, 114, 114))
            gray = cv2.warpAffine(gray, M[:2], dsize=(width, height), borderValue=0)
            line = cv2.warpAffine(line, M[:2], dsize=(width, height), borderValue=0)

    # Visualize
    # import matplotlib.pyplot as plt
    # ax = plt.subplots(1, 2, figsize=(12, 6))[1].ravel()
    # ax[0].imshow(img[:, :, ::-1])  # base
    # ax[1].imshow(img2[:, :, ::-1])  # warped

    # Transform label coordinates
    n = len(targets)
    if n:
        # warp points
        xy = np.ones((n * 4, 3))
        xy[:, :2] = targets[:, [1, 2, 3, 4, 1, 4, 3, 2]].reshape(n * 4, 2)  # x1y1, x2y2, x1y2, x2y1
        xy = xy @ M.T  # transform
        if perspective:
            xy = (xy[:, :2] / xy[:, 2:3]).reshape(n, 8)  # rescale
        else:  # affine
            xy = xy[:, :2].reshape(n, 8)

        # create new boxes
        x = xy[:, [0, 2, 4, 6]]
        y = xy[:, [1, 3, 5, 7]]
        xy = np.concatenate((x.min(1), y.min(1), x.max(1), y.max(1))).reshape(4, n).T

        # # apply angle-based reduction of bounding boxes
        # radians = a * math.pi / 180
        # reduction = max(abs(math.sin(radians)), abs(math.cos(radians))) ** 0.5
        # x = (xy[:, 2] + xy[:, 0]) / 2
        # y = (xy[:, 3] + xy[:, 1]) / 2
        # w = (xy[:, 2] - xy[:, 0]) * reduction
        # h = (xy[:, 3] - xy[:, 1]) * reduction
        # xy = np.concatenate((x - w / 2, y - h / 2, x + w / 2, y + h / 2)).reshape(4, n).T

        # clip boxes
        xy[:, [0, 2]] = xy[:, [0, 2]].clip(0, width)
        xy[:, [1, 3]] = xy[:, [1, 3]].clip(0, height)

        # filter candidates
        i = _box_candidates(box1=targets[:, 1:5].T * s, box2=xy.T)
        targets = targets[i]
        targets[:, 1:5] = xy[i]

    combination = (img, gray, line)
    return combination, targets


def _box_candidates(box1, box2, wh_thr=2, ar_thr=20, area_thr=0.1):  # box1(4,n), box2(4,n)
    # Compute candidate boxes: box1 before augment, box2 after augment, wh_thr (pixels), aspect_ratio_thr, area_ratio
    w1, h1 = box1[2] - box1[0], box1[3] - box1[1]
    w2, h2 = box2[2] - box2[0], box2[3] - box2[1]
    ar = np.maximum(w2 / (h2 + 1e-16), h2 / (w2 + 1e-16))  # aspect ratio
    return (w2 > wh_thr) & (h2 > wh_thr) & (w2 * h2 / (w1 * h1 + 1e-16) > area_thr) & (ar < ar_thr)  # candidates


def mixup(im, labels, seg_label, lane_label, im2, labels2, seg_label2, lane_label2):
    # Applies MixUp augmentation https://arxiv.org/pdf/1710.09412.pdf
    r = np.random.beta(32.0, 32.0)  # mixup ratio, alpha=beta=32.0
    im = (im * r + im2 * (1 - r)).astype(np.uint8)
    labels = np.concatenate((labels, labels2), 0)
    seg_label |= seg_label2
    lane_label |= lane_label2
    return im, labels, seg_label, lane_label


def augment_hsv(img, hgain=0.5, sgain=0.5, vgain=0.5):
    """change color hue, saturation, value"""
    r = np.random.uniform(-1, 1, 3) * [hgain, sgain, vgain] + 1  # random gains
    hue, sat, val = cv2.split(cv2.cvtColor(img, cv2.COLOR_BGR2HSV))
    dtype = img.dtype  # uint8

    x = np.arange(0, 256, dtype=np.int16)
    lut_hue = ((x * r[0]) % 180).astype(dtype)
    lut_sat = np.clip(x * r[1], 0, 255).astype(dtype)
    lut_val = np.clip(x * r[2], 0, 255).astype(dtype)

    img_hsv = cv2.merge((cv2.LUT(hue, lut_hue), cv2.LUT(sat, lut_sat), cv2.LUT(val, lut_val))).astype(dtype)
    cv2.cvtColor(img_hsv, cv2.COLOR_HSV2BGR, dst=img)  # no return needed

    # Histogram equalization
    # if random.random() < 0.2:
    #     for i in range(3):
    #         img[:, :, i] = cv2.equalizeHist(img[:, :, i])


def non_max_suppression(prediction, conf_thres=0.25, iou_thres=0.45, classes=None, agnostic=False, multi_label=False,
                        labels=()):
    """Runs Non-Maximum Suppression (NMS) on inference results

    Returns:
        list of detections, one (n, 6) tensor per image [xyxy, conf, cls]
    """

    nc = prediction.shape[2] - 5  # number of classes
    xc = prediction[..., 4] > conf_thres  # candidates

    # Settings
    min_wh, max_wh = 2, 4096  # (pixels) minimum and maximum box width and height
    max_det = 300  # maximum number of detections per image
    max_nms = 30000  # maximum number of boxes into torchvision.ops.nms()
    time_limit = 10.0  # seconds to quit after
    redundant = True  # require redundant detections
    multi_label &= nc > 1  # multiple labels per box (adds 0.5ms/img)
    merge = False  # use merge-NMS

    t = time.time()
    output = [torch.zeros((0, 6), device=prediction.device)] * prediction.shape[0]
    for xi, x in enumerate(prediction):  # image index, image inference
        # Apply constraints
        # x[((x[..., 2:4] < min_wh) | (x[..., 2:4] > max_wh)).any(1), 4] = 0  # width-height
        x = x[xc[xi]]  # confidence

        # Cat apriori labels if autolabelling
        if labels and len(labels[xi]):
            l = labels[xi]
            v = torch.zeros((len(l), nc + 5), device=x.device)
            v[:, :4] = l[:, 1:5]  # box
            v[:, 4] = 1.0  # conf
            v[range(len(l)), l[:, 0].long() + 5] = 1.0  # cls
            x = torch.cat((x, v), 0)

        # If none remain process next image
        if not x.shape[0]:
            continue

        # Compute conf
        if nc == 1:
            x[:, 5:] = x[:, 4:5]  # for models with one class, cls_loss is 0 and cls_conf is always 0.5,
            # so there is no need to multiply.
        else:
            x[:, 5:] *= x[:, 4:5]  # conf = obj_conf * cls_conf

        # Box (center x, center y, width, height) to (x1, y1, x2, y2)
        box = xywh2xyxy(x[:, :4])

        # Detections matrix nx6 (xyxy, conf, cls)
        if multi_label:
            i, j = (x[:, 5:] > conf_thres).nonzero(as_tuple=False).T
            x = torch.cat((box[i], x[i, j + 5, None], j[:, None].float()), 1)
        else:  # best class only
            conf, j = x[:, 5:].max(1, keepdim=True)
            x = torch.cat((box, conf, j.float()), 1)[conf.view(-1) > conf_thres]

        # Filter by class
        if classes is not None:
            x = x[(x[:, 5:6] == torch.tensor(classes, device=x.device)).any(1)]

        # Apply finite constraint
        # if not torch.isfinite(x).all():
        #     x = x[torch.isfinite(x).all(1)]

        # Check shape
        n = x.shape[0]  # number of boxes
        if not n:  # no boxes
            continue
        elif n > max_nms:  # excess boxes
            x = x[x[:, 4].argsort(descending=True)[:max_nms]]  # sort by confidence

        # Batched NMS
        c = x[:, 5:6] * (0 if agnostic else max_wh)  # classes
        boxes, scores = x[:, :4] + c, x[:, 4]  # boxes (offset by class), scores
        i = torchvision.ops.nms(boxes, scores, iou_thres)  # NMS
        if i.shape[0] > max_det:  # limit detections
            i = i[:max_det]
        if merge and (1 < n < 3E3):  # Merge NMS (boxes merged using weighted mean)
            # update boxes as boxes(i,4) = weights(i,n) * boxes(n,4)
            iou = box_iou(boxes[i], boxes) > iou_thres  # iou matrix
            weights = iou * scores[None]  # box weights
            x[i, :4] = torch.mm(weights, x[:, :4]).float() / weights.sum(1, keepdim=True)  # merged boxes
            if redundant:
                i = i[iou.sum(1) > 1]  # require redundancy

        output[xi] = x[i]
        if (time.time() - t) > time_limit:
            print(f'WARNING: NMS time limit {time_limit}s exceeded')
            break  # time limit exceeded

    return output


def scale_coords(img1_shape, coords, img0_shape, ratio_pad=None):
    # Rescale coords (xyxy) from img1_shape to img0_shape
    if ratio_pad is None:  # calculate from img0_shape
        gain = min(img1_shape[0] / img0_shape[0], img1_shape[1] / img0_shape[1])  # gain = old / new
        pad = (img1_shape[1] - img0_shape[1] * gain) / 2, (img1_shape[0] - img0_shape[0] * gain) / 2  # wh padding
    else:
        gain = ratio_pad[0][0]
        pad = ratio_pad[1]

    coords[:, [0, 2]] -= pad[0]  # x padding
    coords[:, [1, 3]] -= pad[1]  # y padding
    coords[:, :4] /= gain
    clip_coords(coords, img0_shape)
    return coords


def clip_coords(boxes, img_shape):
    # Clip xyxy bounding boxes to image shape (height, width)
    boxes[:, 0].clamp_(0, img_shape[1])  # x1
    boxes[:, 1].clamp_(0, img_shape[0])  # y1
    boxes[:, 2].clamp_(0, img_shape[1])  # x2
    boxes[:, 3].clamp_(0, img_shape[0])  # y2


def ap_per_class(tp, conf, pred_cls, target_cls, plot=False, save_dir='precision-recall_curve.png', names=[]):
    """ Compute the average precision, given the recall and precision curves.
    Source: https://github.com/rafaelpadilla/Object-Detection-Metrics.
    # Arguments
        tp: True positives (nparray, nx1 or nx10).
        conf: Objectness value from 0-1 (nparray).
        pred_cls: Predicted object classes (nparray).
        target_cls: True object classes (nparray).
        plot: Plot precision-recall curve at mAP@0.5
        save_dir: Plot save directory
    # Returns
        The average precision as computed in py-faster-rcnn.
    """

    # Sort by objectness
    i = np.argsort(-conf)
    tp, conf, pred_cls = tp[i], conf[i], pred_cls[i]

    # Find unique classes
    unique_classes = np.unique(target_cls)

    # Create Precision-Recall curve and compute AP for each class
    px, py = np.linspace(0, 1, 1000), []  # for plotting
    pr_score = 0.1  # score to evaluate P and R https://github.com/ultralytics/yolov3/issues/898
    s = [unique_classes.shape[0], tp.shape[1]]  # number class, number iou thresholds (i.e. 10 for mAP0.5...0.95)
    ap, p, r = np.zeros(s), np.zeros((unique_classes.shape[0], 1000)), np.zeros((unique_classes.shape[0], 1000))
    for ci, c in enumerate(unique_classes):
        i = pred_cls == c
        n_l = (target_cls == c).sum()  # number of labels
        n_p = i.sum()  # number of predictions

        if n_p == 0 or n_l == 0:
            continue
        else:
            # Accumulate FPs and TPs
            fpc = (1 - tp[i]).cumsum(0)
            tpc = tp[i].cumsum(0)

            # Recall
            recall = tpc / (n_l + 1e-16)  # recall curve
            r[ci] = np.interp(-px, -conf[i], recall[:, 0], left=0)  # negative x, xp because xp decreases

            # Precision
            precision = tpc / (tpc + fpc)  # precision curve
            p[ci] = np.interp(-px, -conf[i], precision[:, 0], left=1)  # p at pr_score
            # AP from recall-precision curve
            for j in range(tp.shape[1]):
                ap[ci, j], mpre, mrec = compute_ap(recall[:, j], precision[:, j])
                if plot and (j == 0):
                    py.append(np.interp(px, mrec, mpre))  # precision at mAP@0.5

    # Compute F1 score (harmonic mean of precision and recall)
    f1 = 2 * p * r / (p + r + 1e-16)
    i = r.mean(0).argmax()

    if plot:
        plot_pr_curve(px, py, ap, save_dir, names)

    return p[:, i], r[:, i], ap, f1[:, i], unique_classes.astype('int32')


def compute_ap(recall, precision):
    """ Compute the average precision, given the recall and precision curves.
    Source: https://github.com/rbgirshick/py-faster-rcnn.
    # Arguments
        recall: The recall curve (list).
        precision: The precision curve (list).
    # Returns
        The average precision as computed in py-faster-rcnn.
    """

    # Append sentinel values to beginning and end
    mrec = np.concatenate(([0.], recall, [recall[-1] + 1E-3]))
    mpre = np.concatenate(([1.], precision, [0.]))

    # Compute the precision envelope
    mpre = np.flip(np.maximum.accumulate(np.flip(mpre)))

    # Integrate area under curve
    method = 'interp'  # methods: 'continuous', 'interp'
    if method == 'interp':
        x = np.linspace(0, 1, 101)  # 101-point interp (COCO)
        ap = np.trapz(np.interp(x, mrec, mpre), x)  # integrate

    else:  # 'continuous'
        i = np.where(mrec[1:] != mrec[:-1])[0]  # points where x axis (recall) changes
        ap = np.sum((mrec[i + 1] - mrec[i]) * mpre[i + 1])  # area under curve

    return ap, mpre, mrec


def plot_pr_curve(px, py, ap, save_dir='.', names=()):
    fig, ax = plt.subplots(1, 1, figsize=(9, 6), tight_layout=True)
    py = np.stack(py, axis=1)

    if 0 < len(names) < 21:  # show per-class mAP in legend if fewer than 21 classes
        for i, y in enumerate(py.T):
            ax.plot(px, y, linewidth=1, label=f'{names[i]} %.3f' % ap[i, 0])  # plot(recall, precision)
    else:
        ax.plot(px, py, linewidth=1, color='grey')  # plot(recall, precision)

    ax.plot(px, py.mean(1), linewidth=3, color='blue', label='all classes %.3f mAP@0.5' % ap[:, 0].mean())
    ax.set_xlabel('Recall')
    ax.set_ylabel('Precision')
    ax.set_xlim(0, 1)
    ax.set_ylim(0, 1)
    plt.legend(bbox_to_anchor=(1.04, 1), loc="upper left")
    fig.savefig(Path(save_dir) / 'precision_recall_curve.png', dpi=250)


def increment_path(path, exist_ok=True, sep=''):
    # Increment path, i.e. runs/exp --> runs/exp{sep}0, runs/exp{sep}1 etc.
    path = Path(path)  # os-agnostic
    if (path.exists() and exist_ok) or (not path.exists()):
        return str(path)
    else:
        dirs = glob.glob(f"{path}{sep}*")  # similar paths
        matches = [re.search(rf"%s{sep}(\d+)" % path.stem, d) for d in dirs]
        i = [int(m.groups()[0]) for m in matches if m]  # indices
        n = max(i) + 1 if i else 2  # increment number
        return f"{path}{sep}{n}"  # update path


class IoU_Cal:
    ''' pred, target: x0,y0,x1,y1
    monotonous: {
        None: origin
        True: monotonic FM
        False: non-monotonic FM
    }
    momentum: The momentum of running mean'''
    iou_mean = 1.
    monotonous = False
    momentum = 1 - 0.5 ** (1 / 7000)

    _is_train = True

    def __init__(self, pred, target):
        self.pred, self.target = pred, target
        self._fget = {
            # x,y,w,h
            'pred_xy': lambda: (self.pred[..., :2] + self.pred[..., 2: 4]) / 2,
            'pred_wh': lambda: self.pred[..., 2: 4] - self.pred[..., :2],
            'target_xy': lambda: (self.target[..., :2] + self.target[..., 2: 4]) / 2,
            'target_wh': lambda: self.target[..., 2: 4] - self.target[..., :2],
            # x0,y0,x1,y1
            'min_coord': lambda: torch.minimum(self.pred[..., :4], self.target[..., :4]),
            'max_coord': lambda: torch.maximum(self.pred[..., :4], self.target[..., :4]),
            # The overlapping region
            'wh_inter': lambda: self.min_coord[..., 2: 4] - self.max_coord[..., :2],
            's_inter': lambda: torch.prod(torch.relu(self.wh_inter), dim=-1),
            # The area covered
            's_union': lambda: torch.prod(self.pred_wh, dim=-1) +
                               torch.prod(self.target_wh, dim=-1) - self.s_inter,
            # The smallest enclosing box
            'wh_box': lambda: self.max_coord[..., 2: 4] - self.min_coord[..., :2],
            's_box': lambda: torch.prod(self.wh_box, dim=-1),
            'l2_box': lambda: torch.square(self.wh_box).sum(dim=-1),
            # The central points' connection of the bounding boxes
            'd_center': lambda: self.pred_xy - self.target_xy,
            'l2_center': lambda: torch.square(self.d_center).sum(dim=-1),
            # IoU
            'iou': lambda: 1 - self.s_inter / self.s_union
        }
        self._update(self)

    def __setitem__(self, key, value):
        self._fget[key] = value

    def __getattr__(self, item):
        if callable(self._fget[item]):
            self._fget[item] = self._fget[item]()
        return self._fget[item]

    @classmethod
    def train(cls):
        cls._is_train = True

    @classmethod
    def eval(cls):
        cls._is_train = False

    @classmethod
    def _update(cls, self):
        if cls._is_train:
            cls.iou_mean = (1 - cls.momentum) * cls.iou_mean + \
                           cls.momentum * self.iou.detach().mean().item()

    def _scaled_loss(self, loss, gamma=1.9, delta=3):
        if isinstance(self.monotonous, bool):
            if self.monotonous:
                loss *= (self.iou.detach() / self.iou_mean).sqrt()
            else:
                beta = self.iou.detach() / self.iou_mean
                alpha = delta * torch.pow(gamma, beta - delta)
                loss *= beta / alpha
        return loss

    @classmethod
    def IoU(cls, pred, target, self=None):
        self = self if self else cls(pred, target)
        return self.iou

    @classmethod
    def WIoU(cls, pred, target, self=None):
        self = self if self else cls(pred, target)
        dist = torch.exp(self.l2_center / self.l2_box.detach())
        return self._scaled_loss(dist * self.iou)
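Taken together, the detection helpers in this file form the usual letterbox → forward → NMS → rescale pipeline. Below is a minimal sketch of that flow (not part of the commit), assuming a loaded `model` whose detection head emits raw `(bs, n, 5 + nc)` predictions; the model handle and file names are illustrative:

import cv2
import numpy as np
import torch

# Hypothetical wiring: `model` is an assumed, already-loaded detector.
img0 = cv2.imread('demo.jpg')                                  # original BGR frame
img, ratio, (dw, dh) = letterbox_for_img(img0, new_shape=640)  # pad to a 32-multiple rectangle
img = np.ascontiguousarray(img.transpose(2, 0, 1))             # HWC -> CHW
img = torch.from_numpy(img).float().unsqueeze(0) / 255.0       # add batch dim, scale to [0, 1]

with torch.no_grad():
    pred = model(img)                                          # raw predictions (assumption)
det = non_max_suppression(pred, conf_thres=0.25, iou_thres=0.45)[0]
if len(det):
    # map boxes from the letterboxed frame back to original image coordinates
    det[:, :4] = scale_coords(img.shape[2:], det[:, :4], img0.shape).round()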
utils/plots.py
ADDED
@@ -0,0 +1,89 @@
import random

import cv2
import matplotlib
import numpy as np
import torch


def show_seg_result(img, result, index=0, epoch=0, batch=0, save_dir=None, is_ll=False, palette=None, is_demo=False,
                    is_gt=False):
    if palette is None:
        palette = np.random.randint(0, 255, size=(3, 3))
    palette[0] = [0, 0, 0]
    palette[1] = [0, 255, 0]
    palette[2] = [255, 0, 0]
    palette = np.array(palette)
    assert palette.shape[0] == 3  # len(classes)
    assert palette.shape[1] == 3
    assert len(palette.shape) == 2

    if not is_demo:
        color_seg = np.zeros((result.shape[0], result.shape[1], 3), dtype=np.uint8)
        for label, color in enumerate(palette):
            color_seg[result == label, :] = color
    else:
        color_area = np.zeros((result[0].shape[0], result[0].shape[1], 3), dtype=np.uint8)

        color_area[result[0] == 1] = [0, 255, 0]
        color_area[result[1] == 1] = [255, 0, 0]
        color_seg = color_area

    # convert to BGR
    color_seg = color_seg[..., ::-1]
    # print(color_seg.shape)
    color_mask = np.mean(color_seg, 2)
    img[color_mask != 0] = img[color_mask != 0] * 0.5 + color_seg[color_mask != 0] * 0.5
    # img = img * 0.5 + color_seg * 0.5
    img = img.astype(np.uint8)
    img = cv2.resize(img, (1280, 720), interpolation=cv2.INTER_LINEAR)

    if not is_demo:
        if not is_gt:
            if not is_ll:
                cv2.imwrite(save_dir + "/batch_{}_{}_{}_da_segresult.png".format(epoch, batch, index), img)
            else:
                cv2.imwrite(save_dir + "/batch_{}_{}_{}_ll_segresult.png".format(epoch, batch, index), img)
        else:
            if not is_ll:
                cv2.imwrite(save_dir + "/batch_{}_{}_{}_da_seg_gt.png".format(epoch, batch, index), img)
            else:
                cv2.imwrite(save_dir + "/batch_{}_{}_{}_ll_seg_gt.png".format(epoch, batch, index), img)
    return img


def plot_one_box(x, img, color=None, label=None, line_thickness=3):
    # Plots one bounding box on image img
    tl = line_thickness or round(0.002 * (img.shape[0] + img.shape[1]) / 2) + 1  # line/font thickness
    color = color or [random.randint(0, 255) for _ in range(3)]
    c1, c2 = (int(x[0]), int(x[1])), (int(x[2]), int(x[3]))
    cv2.rectangle(img, c1, c2, color, thickness=tl, lineType=cv2.LINE_AA)
    if label:
        tf = max(tl - 1, 1)  # font thickness
        t_size = cv2.getTextSize(label, 0, fontScale=tl / 3, thickness=tf)[0]
        c2 = c1[0] + t_size[0], c1[1] - t_size[1] - 3
        cv2.rectangle(img, c1, c2, color, -1, cv2.LINE_AA)  # filled
        cv2.putText(img, label, (c1[0], c1[1] - 2), 0, tl / 3, [225, 255, 255], thickness=tf, lineType=cv2.LINE_AA)


def driving_area_mask(da_seg_out, width, height, pad_w, pad_h, ratio):
    da_predict = da_seg_out[:, :, pad_h:(height - pad_h), pad_w:(width - pad_w)]
    da_seg_mask = torch.nn.functional.interpolate(da_predict, scale_factor=int(1 / ratio), mode='bilinear')
    _, da_seg_mask = torch.max(da_seg_mask, 1)
    da_seg_mask = da_seg_mask.int().squeeze().cpu().numpy()
    return da_seg_mask


def lane_line_mask(ll_seg_out, width, height, pad_w, pad_h, ratio):
    ll_predict = ll_seg_out[:, :, pad_h:(height - pad_h), pad_w:(width - pad_w)]
    ll_seg_mask = torch.nn.functional.interpolate(ll_predict, scale_factor=int(1 / ratio), mode='bilinear')
    _, ll_seg_mask = torch.max(ll_seg_mask, 1)
    ll_seg_mask = ll_seg_mask.int().squeeze().cpu().numpy()
    return ll_seg_mask

def color_list():
    # Return first 10 plt colors as (r,g,b) https://stackoverflow.com/questions/51350872/python-from-color-name-to-rgb
    def hex2rgb(h):
        return tuple(int(h[1 + i:1 + i + 2], 16) for i in (0, 2, 4))

    return [hex2rgb(h) for h in matplotlib.colors.TABLEAU_COLORS.values()]  # or BASE_ (8), CSS4_ (148), XKCD_ (949)
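Continuing the earlier sketch, the drawing helpers here consume exactly the `(n, 6)` detection tensor that `non_max_suppression` produces. A short illustrative snippet (the `det`, `img0`, and `names` objects are assumptions, not part of the commit):

import cv2

# Illustrative only: `det` holds [x1, y1, x2, y2, conf, cls] rows; `names` is an
# assumed class-name list indexed by class id.
colors = color_list()  # ten (r, g, b) tableau colors
for *xyxy, conf, cls in det:
    label = f'{names[int(cls)]} {conf:.2f}'
    plot_one_box(xyxy, img0, color=colors[int(cls) % len(colors)], label=label)
cv2.imwrite('result.jpg', img0)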
utils/torch_utils.py
ADDED
@@ -0,0 +1,321 @@
import datetime
import logging
import math
import os
import platform
import subprocess
import time
from copy import deepcopy
from pathlib import Path

import thop
import torch
import torch.optim as optim
from prefetch_generator import BackgroundGenerator
from torch import nn
from torch.utils.data import DataLoader

logger = logging.getLogger(__name__)


def create_logger(args, setting_yaml, phase='train', rank=-1):
    # set up logger dir
    dataset = setting_yaml['dataset_dataset']
    model = setting_yaml['model_name']
    cfg_path = os.path.basename(args.log_dir).split('.')[0]

    if rank in [-1, 0]:
        time_str = time.strftime('%Y-%m-%d-%H-%M')
        log_file = '{}_{}_{}.log'.format(cfg_path, time_str, phase)
        # set up tensorboard_log_dir
        tensorboard_log_dir = Path(args.log_dir) / dataset / model / (cfg_path + '_' + time_str)
        final_output_dir = tensorboard_log_dir
        if not tensorboard_log_dir.exists():
            print('=> creating {}'.format(tensorboard_log_dir))
            tensorboard_log_dir.mkdir(parents=True)

        final_log_file = tensorboard_log_dir / log_file
        head = '%(asctime)-15s %(message)s'
        logging.basicConfig(filename=str(final_log_file),
                            format=head)
        logger = logging.getLogger()
        logger.setLevel(logging.INFO)
        console = logging.StreamHandler()
        logging.getLogger('').addHandler(console)

        return logger, str(final_output_dir), str(tensorboard_log_dir)
    else:
        return None, None, None


def select_device(logger=None, device='', batch_size=None):
    # device = 'cpu' or '0' or '0,1,2,3'
    s = f'mtpnet 🚀 {git_describe() or date_modified()} torch {torch.__version__} '  # string
    cpu = device.lower() == 'cpu'
    if cpu:
        os.environ['CUDA_VISIBLE_DEVICES'] = '-1'  # force torch.cuda.is_available() = False
    elif device:  # non-cpu device requested
        os.environ['CUDA_VISIBLE_DEVICES'] = device  # set environment variable
        assert torch.cuda.is_available(), f'CUDA unavailable, invalid device {device} requested'  # check availability

    cuda = not cpu and torch.cuda.is_available()
    if cuda:
        n = torch.cuda.device_count()
        if n > 1 and batch_size:  # check that batch_size is compatible with device_count
            assert batch_size % n == 0, f'batch-size {batch_size} not multiple of GPU count {n}'
        space = ' ' * len(s)
        for i, d in enumerate(device.split(',') if device else range(n)):
            p = torch.cuda.get_device_properties(i)
            s += f"{'' if i == 0 else space}CUDA:{d} ({p.name}, {p.total_memory / 1024 ** 2}MB)\n"  # bytes to MB
    else:
        s += 'CPU\n'

    if logger:
        logger.info(s.encode().decode('ascii', 'ignore') if platform.system() == 'Windows' else s)  # emoji-safe
    return torch.device('cuda:0' if cuda else 'cpu')


def get_optimizer(setting_yaml, model):
    optimizer = None
    if setting_yaml['train_optimizer'] == 'sgd':
        optimizer = optim.SGD(
            filter(lambda p: p.requires_grad, model.parameters()),
            lr=setting_yaml['train_lr0'],
            momentum=setting_yaml['train_momentum'],
            weight_decay=setting_yaml['train_wd'],
            nesterov=setting_yaml['train_nesterov']
        )
    elif setting_yaml['train_optimizer'] == 'adam':
        optimizer = optim.Adam(
            filter(lambda p: p.requires_grad, model.parameters()),
            # model.parameters(),
            lr=setting_yaml['train_lr0'],
            betas=(setting_yaml['train_momentum'], 0.999),
            eps=1e-5
        )

    return optimizer


def save_checkpoint(epoch, name, model, optimizer, scheduler, ema, output_dir, is_best, best_fitness):
    model_state = model.module.state_dict() if is_parallel(model) else model.state_dict()
    checkpoint = {
        'epoch': epoch,
        'model': name,
        'state_dict': model_state,
        'ema': deepcopy(ema.ema).half(),
        'updates': ema.updates,
        'optimizer': optimizer.state_dict(),
        'scheduler': scheduler.state_dict(),
        'best_fitness': best_fitness
    }
    last = os.path.join(output_dir, 'last.pth')
    last_st = os.path.join(output_dir, 'last_st.pth')
    best = os.path.join(output_dir, 'best.pth')
    best_st = os.path.join(output_dir, 'best_st.pth')
    torch.save(checkpoint, last)
    torch.save(model_state, last_st)
    if is_best and (epoch <= 100):
        torch.save(checkpoint, best)
        torch.save(model_state, best_st)
    if is_best and (epoch > 100):
        torch.save(checkpoint, os.path.join(output_dir, 'best_{:03d}.pth'.format(epoch)))
        torch.save(model_state, os.path.join(output_dir, 'best_st_{:03d}.pth'.format(epoch)))


class ModelEMA:
    """ Model Exponential Moving Average from https://github.com/rwightman/pytorch-image-models
    Keep a moving average of everything in the model state_dict (parameters and buffers).
    This is intended to allow functionality like
    https://www.tensorflow.org/api_docs/python/tf/train/ExponentialMovingAverage
    A smoothed version of the weights is necessary for some training schemes to perform well.
    This class is sensitive to where it is initialized in the sequence of model init,
    GPU assignment and distributed training wrappers.
    """

    def __init__(self, model, decay=0.9999, updates=0):
        # Create EMA
        self.ema = deepcopy(model.module if is_parallel(model) else model).eval()  # FP32 EMA
        # if next(model.parameters()).device.type != 'cpu':
        #     self.ema.half()  # FP16 EMA
        self.updates = updates  # number of EMA updates
        self.decay = lambda x: decay * (1 - math.exp(-x / 2000))  # decay exponential ramp (to help early epochs)
        for p in self.ema.parameters():
            p.requires_grad_(False)

    def update(self, model):
        # Update EMA parameters
        with torch.no_grad():
            self.updates += 1
            d = self.decay(self.updates)

            msd = model.module.state_dict() if is_parallel(model) else model.state_dict()  # model state_dict
            for k, v in self.ema.state_dict().items():
                if v.dtype.is_floating_point:
                    v *= d
                    v += (1. - d) * msd[k].detach()

    def update_attr(self, model, include=(), exclude=('process_group', 'reducer')):
        # Update EMA attributes
        copy_attr(self.ema, model, include, exclude)


class DataLoaderX(DataLoader):
    """prefetch dataloader"""
    def __iter__(self):
        return BackgroundGenerator(super().__iter__())


class AverageMeter(object):
    """Computes and stores the average and current value"""

    def __init__(self):
        self.val = 0
        self.avg = 0
        self.sum = 0
        self.count = 0

    def reset(self):
        self.val = 0
        self.avg = 0
        self.sum = 0
        self.count = 0

    def update(self, val, n=1):
        self.val = val
        self.sum += val * n
        self.count += n
        self.avg = self.sum / self.count if self.count != 0 else 0


def git_describe(path=Path(__file__).parent):  # path must be a directory
    # return human-readable git description, i.e. v5.0-5-g3e25f1e https://git-scm.com/docs/git-describe
    s = f'git -C {path} describe --tags --long --always'
    try:
        return subprocess.check_output(s, shell=True, stderr=subprocess.STDOUT).decode()[:-1]
    except subprocess.CalledProcessError:
        return ''  # not a git repository


def date_modified(path=__file__):
    # return human-readable file modification date, i.e. '2021-3-26'
    t = datetime.datetime.fromtimestamp(Path(path).stat().st_mtime)
    return f'{t.year}-{t.month}-{t.day}'


def is_parallel(model):
    return type(model) in (nn.parallel.DataParallel, nn.parallel.DistributedDataParallel)


def copy_attr(a, b, include=(), exclude=()):
    # Copy attributes from b to a, options to only include [...] and to exclude [...]
    for k, v in b.__dict__.items():
        if (len(include) and k not in include) or k.startswith('_') or k in exclude:
            continue
        else:
            setattr(a, k, v)


def time_synchronized():
    # pytorch-accurate time
    if torch.cuda.is_available():
        torch.cuda.synchronize()
    return time.time()


def profile(x, ops, n=100, device=None):
    # profile a pytorch module or list of modules. Example usage:
    # x = torch.randn(16, 3, 640, 640)  # input
    # m1 = lambda x: x * torch.sigmoid(x)
    # m2 = nn.SiLU()
    # profile(x, [m1, m2], n=100)  # profile speed over 100 iterations

    device = device or torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
    x = x.to(device)
    x.requires_grad = True
    print(torch.__version__, device.type, torch.cuda.get_device_properties(0) if device.type == 'cuda' else '')
    print(f"\n{'Params':>12s}{'GFLOPS':>12s}{'forward (ms)':>16s}{'backward (ms)':>16s}{'input':>24s}{'output':>24s}")
    for m in ops if isinstance(ops, list) else [ops]:
        m = m.to(device) if hasattr(m, 'to') else m  # device
        m = m.half() if hasattr(m, 'half') and isinstance(x, torch.Tensor) and x.dtype is torch.float16 else m  # type
        dtf, dtb, t = 0., 0., [0., 0., 0.]  # dt forward, backward
        try:
            flops = thop.profile(m, inputs=(x,), verbose=False)[0] / 1E9 * 2  # GFLOPS
        except Exception:
            flops = 0

        for _ in range(n):
            t[0] = time_synchronized()
            y = m(x)
            t[1] = time_synchronized()
            try:
                _ = y.sum().backward()
                t[2] = time_synchronized()
            except Exception:  # no backward method
                t[2] = float('nan')
            dtf += (t[1] - t[0]) * 1000 / n  # ms per op forward
            dtb += (t[2] - t[1]) * 1000 / n  # ms per op backward

        s_in = tuple(x.shape) if isinstance(x, torch.Tensor) else 'list'
        s_out = tuple(y.shape) if isinstance(y, torch.Tensor) else 'list'
        p = sum(list(x.numel() for x in m.parameters())) if isinstance(m, nn.Module) else 0  # parameters
        print(f'{p:12}{flops:12.4g}{dtf:16.4g}{dtb:16.4g}{str(s_in):>24s}{str(s_out):>24s}')


def initialize_weights(model):
    for m in model.modules():
        t = type(m)
        if t is nn.Conv2d:
            pass  # nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
        elif t is nn.BatchNorm2d:
            m.eps = 1e-3
            m.momentum = 0.03
        elif t in [nn.Hardswish, nn.LeakyReLU, nn.ReLU, nn.ReLU6]:
            m.inplace = True


def fuse_conv_and_bn(conv, bn):
    # Fuse convolution and batchnorm layers https://tehnokv.com/posts/fusing-batchnorm-and-conv/
    fusedconv = nn.Conv2d(conv.in_channels,
                          conv.out_channels,
                          kernel_size=conv.kernel_size,
                          stride=conv.stride,
                          padding=conv.padding,
                          groups=conv.groups,
                          bias=True).requires_grad_(False).to(conv.weight.device)

    # prepare filters
    w_conv = conv.weight.clone().view(conv.out_channels, -1)
    w_bn = torch.diag(bn.weight.div(torch.sqrt(bn.eps + bn.running_var)))
    fusedconv.weight.copy_(torch.mm(w_bn, w_conv).view(fusedconv.weight.shape))

    # prepare spatial bias
    b_conv = torch.zeros(conv.weight.size(0), device=conv.weight.device) if conv.bias is None else conv.bias
    b_bn = bn.bias - bn.weight.mul(bn.running_mean).div(torch.sqrt(bn.running_var + bn.eps))
    fusedconv.bias.copy_(torch.mm(w_bn, b_conv.reshape(-1, 1)).reshape(-1) + b_bn)

    return fusedconv


def model_info(model, verbose=False, img_size=640):
    # Model information. img_size may be int or list, i.e. img_size=640 or img_size=[640, 320]
    n_p = sum(x.numel() for x in model.parameters())  # number parameters
    n_g = sum(x.numel() for x in model.parameters() if x.requires_grad)  # number gradients
    if verbose:
        print('%5s %40s %9s %12s %20s %10s %10s' % ('layer', 'name', 'gradient', 'parameters', 'shape', 'mu', 'sigma'))
        for i, (name, p) in enumerate(model.named_parameters()):
            name = name.replace('module_list.', '')
            print('%5g %40s %9s %12g %20s %10.3g %10.3g' %
                  (i, name, p.requires_grad, p.numel(), list(p.shape), p.mean(), p.std()))

    try:  # FLOPS
        from thop import profile
        stride = max(int(model.stride.max()), 32) if hasattr(model, 'stride') else 32
        img = torch.zeros((1, model.yaml.get('ch', 3), stride, stride), device=next(model.parameters()).device)  # input
        flops = profile(deepcopy(model), inputs=(img,), verbose=False)[0] / 1E9 * 2  # stride GFLOPS
        img_size = img_size if isinstance(img_size, list) else [img_size, img_size]  # expand if int/float
        fs = ', %.1f GFLOPS' % (flops * img_size[0] / stride * img_size[1] / stride)  # 640x640 GFLOPS
    except Exception:
        fs = ''

    logger.info(f"Model Summary: {len(list(model.modules()))} layers, {n_p} parameters, {n_g} gradients{fs}")
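For completeness, a compact sketch of how these training utilities are typically wired together; the `model`, `dataset`, and `setting_yaml` objects (and their `train_*` keys, which match what `get_optimizer` reads) are assumptions, not part of the commit:

import torch

# Hypothetical training wiring; `model`, `dataset`, and `setting_yaml` are assumed.
device = select_device(device='0', batch_size=16)
model = model.to(device)
optimizer = get_optimizer(setting_yaml, model)   # reads setting_yaml['train_*'] keys
ema = ModelEMA(model)                            # shadow weights for evaluation
loader = DataLoaderX(dataset, batch_size=16, shuffle=True)

for imgs, targets in loader:
    loss = model(imgs.to(device), targets.to(device))  # assumed loss-returning forward
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    ema.update(model)                            # refresh the EMA after every optimizer step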