stevfoy commited on
Commit
985c437
·
1 Parent(s): 9222fe1
app.py CHANGED
@@ -30,7 +30,7 @@ import gradio as gr
30
  import os
31
 
32
  model = YOLOv3Lightning()
33
- model.load_state_dict(torch.load("yolov3_608_ckpt_40.pth", map_location=torch.device('cpu')), strict=False)
34
  model.setup(stage="test")
35
 
36
  IMAGE_SIZE = 416
 
30
  import os
31
 
32
  model = YOLOv3Lightning()
33
+ model.load_state_dict(torch.load("yolov3_model.pth", map_location=torch.device('cpu')), strict=False)
34
  model.setup(stage="test")
35
 
36
  IMAGE_SIZE = 416
pytorchyolo/__init__.py ADDED
File without changes
pytorchyolo/models.py ADDED
@@ -0,0 +1,344 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from __future__ import division
2
+
3
+ import os
4
+ from itertools import chain
5
+ from typing import List, Tuple
6
+
7
+ import numpy as np
8
+ import torch
9
+ import torch.nn as nn
10
+ import torch.nn.functional as F
11
+
12
+ from pytorchyolo.utils.parse_config import parse_model_config
13
+ from pytorchyolo.utils.utils import weights_init_normal
14
+
15
+
16
def create_modules(module_defs: List[dict]) -> Tuple[dict, nn.ModuleList]:
    """
    Construct a module list of layer blocks from the parsed darknet config.

    :param module_defs: List of dictionaries with module definitions; the first
        entry holds the network hyperparameters and is consumed (popped).
    :return: Tuple of (hyperparameters dict with numeric fields converted,
        nn.ModuleList with one nn.Sequential per layer block)
    :raises AssertionError: if the configured height and width differ.
    """
    hyperparams = module_defs.pop(0)
    hyperparams.update({
        'batch': int(hyperparams['batch']),
        'subdivisions': int(hyperparams['subdivisions']),
        'width': int(hyperparams['width']),
        'height': int(hyperparams['height']),
        'channels': int(hyperparams['channels']),
        'optimizer': hyperparams.get('optimizer'),
        'momentum': float(hyperparams['momentum']),
        'decay': float(hyperparams['decay']),
        'learning_rate': float(hyperparams['learning_rate']),
        'burn_in': int(hyperparams['burn_in']),
        'max_batches': int(hyperparams['max_batches']),
        'policy': hyperparams['policy'],
        'lr_steps': list(zip(map(int, hyperparams["steps"].split(",")),
                             map(float, hyperparams["scales"].split(","))))
    })
    assert hyperparams["height"] == hyperparams["width"], \
        "Height and width should be equal! Non square images are padded with zeros."
    # Track the channel count produced by each layer; index 0 is the input.
    output_filters = [hyperparams["channels"]]
    module_list = nn.ModuleList()
    for module_i, module_def in enumerate(module_defs):
        modules = nn.Sequential()

        if module_def["type"] == "convolutional":
            bn = int(module_def["batch_normalize"])
            filters = int(module_def["filters"])
            kernel_size = int(module_def["size"])
            pad = (kernel_size - 1) // 2
            modules.add_module(
                f"conv_{module_i}",
                nn.Conv2d(
                    in_channels=output_filters[-1],
                    out_channels=filters,
                    kernel_size=kernel_size,
                    stride=int(module_def["stride"]),
                    padding=pad,
                    bias=not bn,  # BN has its own bias term
                ),
            )
            if bn:
                modules.add_module(f"batch_norm_{module_i}",
                                   nn.BatchNorm2d(filters, momentum=0.1, eps=1e-5))
            if module_def["activation"] == "leaky":
                modules.add_module(f"leaky_{module_i}", nn.LeakyReLU(0.1))
            elif module_def["activation"] == "mish":
                modules.add_module(f"mish_{module_i}", nn.Mish())
            elif module_def["activation"] == "logistic":
                modules.add_module(f"sigmoid_{module_i}", nn.Sigmoid())
            elif module_def["activation"] == "swish":
                modules.add_module(f"swish_{module_i}", nn.SiLU())

        elif module_def["type"] == "maxpool":
            kernel_size = int(module_def["size"])
            stride = int(module_def["stride"])
            if kernel_size == 2 and stride == 1:
                # Keep the spatial size for the size-2/stride-1 case used by tiny models.
                modules.add_module(f"_debug_padding_{module_i}", nn.ZeroPad2d((0, 1, 0, 1)))
            maxpool = nn.MaxPool2d(kernel_size=kernel_size, stride=stride,
                                   padding=int((kernel_size - 1) // 2))
            modules.add_module(f"maxpool_{module_i}", maxpool)

        elif module_def["type"] == "upsample":
            upsample = Upsample(scale_factor=int(module_def["stride"]), mode="nearest")
            modules.add_module(f"upsample_{module_i}", upsample)

        elif module_def["type"] == "route":
            layers = [int(x) for x in module_def["layers"].split(",")]
            filters = sum([output_filters[1:][i] for i in layers]) // int(module_def.get("groups", 1))
            # Placeholder; routing is performed in Darknet.forward.
            modules.add_module(f"route_{module_i}", nn.Sequential())

        elif module_def["type"] == "shortcut":
            filters = output_filters[1:][int(module_def["from"])]
            # Placeholder; the residual add is performed in Darknet.forward.
            modules.add_module(f"shortcut_{module_i}", nn.Sequential())

        elif module_def["type"] == "yolo":
            anchor_idxs = [int(x) for x in module_def["mask"].split(",")]
            # Extract anchors and keep only those selected by the mask
            anchors = [int(x) for x in module_def["anchors"].split(",")]
            anchors = [(anchors[i], anchors[i + 1]) for i in range(0, len(anchors), 2)]
            anchors = [anchors[i] for i in anchor_idxs]
            num_classes = int(module_def["classes"])
            # BUGFIX: config values are strings, and bool("0") is True —
            # parse through int so "new_coords=0" is correctly treated as False.
            new_coords = bool(int(module_def.get("new_coords", 0)))
            # Define detection layer
            yolo_layer = YOLOLayer(anchors, num_classes, new_coords)
            modules.add_module(f"yolo_{module_i}", yolo_layer)
        # Register module list and number of output filters
        module_list.append(modules)
        output_filters.append(filters)

    return hyperparams, module_list
113
+
114
+
115
class Upsample(nn.Module):
    """Interpolation wrapper used instead of the deprecated nn.Upsample."""

    def __init__(self, scale_factor, mode: str = "nearest"):
        """
        :param scale_factor: Multiplier for the spatial dimensions
        :param mode: Interpolation algorithm passed to F.interpolate
        """
        super(Upsample, self).__init__()
        self.scale_factor = scale_factor
        self.mode = mode

    def forward(self, x):
        """Return *x* resampled by ``scale_factor`` using ``mode``."""
        return F.interpolate(x, scale_factor=self.scale_factor, mode=self.mode)
126
+
127
+
128
class YOLOLayer(nn.Module):
    """Detection head that reshapes raw feature maps and (at inference)
    decodes them into absolute box predictions."""

    def __init__(self, anchors: List[Tuple[int, int]], num_classes: int, new_coords: bool):
        """
        Create a YOLO layer.

        :param anchors: List of (width, height) anchor pairs in pixels
        :param num_classes: Number of object classes
        :param new_coords: Whether to use the new coordinate format from YOLO V7
        """
        super(YOLOLayer, self).__init__()
        self.num_anchors = len(anchors)
        self.num_classes = num_classes
        self.new_coords = new_coords
        self.mse_loss = nn.MSELoss()
        self.bce_loss = nn.BCELoss()
        self.no = num_classes + 5  # outputs per anchor: box(4) + obj(1) + classes
        self.grid = torch.zeros(1)  # lazily (re)built in forward  # TODO

        flat_anchors = torch.tensor(list(chain(*anchors))).float().view(-1, 2)
        self.register_buffer('anchors', flat_anchors)
        self.register_buffer(
            'anchor_grid', flat_anchors.clone().view(1, -1, 1, 1, 2))
        self.stride = None  # set on the first forward pass

    def forward(self, x: torch.Tensor, img_size: int) -> torch.Tensor:
        """
        Forward pass of the YOLO layer.

        :param x: Input feature map of shape (batch, num_anchors * no, ny, nx)
        :param img_size: Side length of the (square) network input image
        """
        cell_stride = img_size // x.size(2)
        self.stride = cell_stride
        batch, _, ny, nx = x.shape
        # (batch, anchors*no, ny, nx) -> (batch, anchors, ny, nx, no)
        x = (x.view(batch, self.num_anchors, self.no, ny, nx)
              .permute(0, 1, 3, 4, 2)
              .contiguous())

        if not self.training:  # decode boxes only at inference time
            if self.grid.shape[2:4] != x.shape[2:4]:
                self.grid = self._make_grid(nx, ny).to(x.device)

            if self.new_coords:
                x[..., 0:2] = (x[..., 0:2] + self.grid) * cell_stride  # xy
                x[..., 2:4] = x[..., 2:4] ** 2 * (4 * self.anchor_grid)  # wh
            else:
                x[..., 0:2] = (x[..., 0:2].sigmoid() + self.grid) * cell_stride  # xy
                x[..., 2:4] = torch.exp(x[..., 2:4]) * self.anchor_grid  # wh
                x[..., 4:] = x[..., 4:].sigmoid()  # conf, cls
            x = x.view(batch, -1, self.no)

        return x

    @staticmethod
    def _make_grid(nx: int = 20, ny: int = 20) -> torch.Tensor:
        """
        Create a (1, 1, ny, nx, 2) grid of cell (x, y) coordinates.

        :param nx: Number of x coordinates
        :param ny: Number of y coordinates
        """
        yv, xv = torch.meshgrid([torch.arange(ny), torch.arange(nx)], indexing='ij')
        return torch.stack((xv, yv), 2).view((1, 1, ny, nx, 2)).float()
191
+
192
+
193
class Darknet(nn.Module):
    """YOLOv3 object detection model built from a darknet .cfg definition."""

    def __init__(self, config_path):
        """
        :param config_path: Path to the darknet model definition (.cfg) file
        """
        super(Darknet, self).__init__()
        self.module_defs = parse_model_config(config_path)
        self.hyperparams, self.module_list = create_modules(self.module_defs)
        # Collect the detection heads for convenient access by the loss code.
        self.yolo_layers = [block[0] for block in self.module_list
                            if isinstance(block[0], YOLOLayer)]
        self.seen = 0  # number of images seen during training
        self.header_info = np.array([0, 0, 0, self.seen, 0], dtype=np.int32)

    def forward(self, x):
        """Run the network; returns per-head outputs in training mode,
        concatenated detections otherwise."""
        img_size = x.size(2)
        layer_outputs, yolo_outputs = [], []
        for module_def, module in zip(self.module_defs, self.module_list):
            kind = module_def["type"]
            if kind in ("convolutional", "upsample", "maxpool"):
                x = module(x)
            elif kind == "route":
                combined_outputs = torch.cat(
                    [layer_outputs[int(layer_i)] for layer_i in module_def["layers"].split(",")], 1)
                group_size = combined_outputs.shape[1] // int(module_def.get("groups", 1))
                group_id = int(module_def.get("group_id", 0))
                # Slice groupings used by yolo v4
                x = combined_outputs[:, group_size * group_id: group_size * (group_id + 1)]
            elif kind == "shortcut":
                layer_i = int(module_def["from"])
                x = layer_outputs[-1] + layer_outputs[layer_i]
            elif kind == "yolo":
                x = module[0](x, img_size)
                yolo_outputs.append(x)
            layer_outputs.append(x)
        return yolo_outputs if self.training else torch.cat(yolo_outputs, 1)

    def load_darknet_weights(self, weights_path):
        """Parses and loads the weights stored in 'weights_path'"""

        # Open the weights file
        with open(weights_path, "rb") as f:
            # First five are header values
            header = np.fromfile(f, dtype=np.int32, count=5)
            self.header_info = header  # Needed to write header when saving weights
            self.seen = header[3]  # number of images seen during training
            weights = np.fromfile(f, dtype=np.float32)  # The rest are weights

        # Establish cutoff for loading backbone weights.
        # If the weights file has a cutoff, the filename tells us:
        # e.g. darknet53.conv.74 -> cutoff is 74
        cutoff = None
        filename = os.path.basename(weights_path)
        if ".conv." in filename:
            try:
                cutoff = int(filename.split(".")[-1])  # use last part of filename
            except ValueError:
                pass

        offset = 0
        for i, (module_def, module) in enumerate(zip(self.module_defs, self.module_list)):
            if i == cutoff:
                break
            if module_def["type"] == "convolutional":
                conv_layer = module[0]
                if module_def["batch_normalize"]:
                    # Load BN bias, weights, running mean and running variance
                    # (this order matches how darknet serializes them).
                    bn_layer = module[1]
                    num_b = bn_layer.bias.numel()  # Number of biases
                    bn_b = torch.from_numpy(
                        weights[offset: offset + num_b]).view_as(bn_layer.bias)
                    bn_layer.bias.data.copy_(bn_b)
                    offset += num_b
                    bn_w = torch.from_numpy(
                        weights[offset: offset + num_b]).view_as(bn_layer.weight)
                    bn_layer.weight.data.copy_(bn_w)
                    offset += num_b
                    bn_rm = torch.from_numpy(
                        weights[offset: offset + num_b]).view_as(bn_layer.running_mean)
                    bn_layer.running_mean.data.copy_(bn_rm)
                    offset += num_b
                    bn_rv = torch.from_numpy(
                        weights[offset: offset + num_b]).view_as(bn_layer.running_var)
                    bn_layer.running_var.data.copy_(bn_rv)
                    offset += num_b
                else:
                    # No BN: the conv layer carries its own bias
                    num_b = conv_layer.bias.numel()
                    conv_b = torch.from_numpy(
                        weights[offset: offset + num_b]).view_as(conv_layer.bias)
                    conv_layer.bias.data.copy_(conv_b)
                    offset += num_b
                # Load conv. weights
                num_w = conv_layer.weight.numel()
                conv_w = torch.from_numpy(
                    weights[offset: offset + num_w]).view_as(conv_layer.weight)
                conv_layer.weight.data.copy_(conv_w)
                offset += num_w

    def save_darknet_weights(self, path, cutoff=-1):
        """
        Serialize the model in darknet binary format.

        :param path: path of the new weights file
        :param cutoff: save layers between 0 and cutoff (cutoff = -1 -> all are saved)
        """
        fp = open(path, "wb")
        self.header_info[3] = self.seen
        self.header_info.tofile(fp)

        # Iterate through layers, mirroring the read order in load_darknet_weights.
        for module_def, module in zip(self.module_defs[:cutoff], self.module_list[:cutoff]):
            if module_def["type"] == "convolutional":
                conv_layer = module[0]
                if module_def["batch_normalize"]:
                    # BN parameters first (bias, weight, running stats)
                    bn_layer = module[1]
                    bn_layer.bias.data.cpu().numpy().tofile(fp)
                    bn_layer.weight.data.cpu().numpy().tofile(fp)
                    bn_layer.running_mean.data.cpu().numpy().tofile(fp)
                    bn_layer.running_var.data.cpu().numpy().tofile(fp)
                else:
                    # Conv bias only when there is no BN
                    conv_layer.bias.data.cpu().numpy().tofile(fp)
                # Conv weights always follow
                conv_layer.weight.data.cpu().numpy().tofile(fp)

        fp.close()
318
+
319
+
320
def load_model(model_path, weights_path=None):
    """Loads the yolo model from file.

    :param model_path: Path to model definition file (.cfg)
    :type model_path: str
    :param weights_path: Path to weights or checkpoint file (.weights or .pth)
    :type weights_path: str
    :return: Returns model
    :rtype: Darknet
    """
    # Select device for inference
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    model = Darknet(model_path).to(device)
    model.apply(weights_init_normal)

    # If pretrained weights are specified, start from checkpoint or weight file
    if weights_path:
        if weights_path.endswith(".pth"):
            # PyTorch checkpoint
            model.load_state_dict(torch.load(weights_path, map_location=device))
        else:
            # Darknet binary weights
            model.load_darknet_weights(weights_path)
    return model
pytorchyolo/utils/__init__.py ADDED
File without changes
pytorchyolo/utils/augmentations.py ADDED
@@ -0,0 +1,123 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import imgaug.augmenters as iaa
2
+ from torchvision import transforms
3
+ from pytorchyolo.utils.transforms import ToTensor, PadSquare, RelativeLabels, AbsoluteLabels, ImgAug, adjustGrassColor
4
+ import imgaug as ia
5
+
6
class DefaultAug(ImgAug):
    """Mild default pipeline: sharpen, affine jitter, brightness/hue shift, h-flip."""

    def __init__(self):
        steps = [
            iaa.Sharpen((0.0, 0.1)),
            iaa.Affine(rotate=(-0, 0), translate_percent=(-0.1, 0.1), scale=(0.8, 1.1)),
            iaa.AddToBrightness((-20, 100)),
            iaa.AddToHue((-10, 10)),
            iaa.Fliplr(0.5),
        ]
        self.augmentations = iaa.Sequential(steps)
15
+
16
class greenAug(ImgAug):
    """Pipeline that shifts images toward a greener/brighter look by editing
    fixed amounts in HSV space after geometric jitter."""

    def __init__(self):
        steps = [
            iaa.Sharpen((0.0, 0.1)),
            iaa.Affine(rotate=(-10, 10), translate_percent=(-0.1, 0.1), scale=(0.6, 1.2)),
            iaa.ChangeColorspace(from_colorspace="RGB", to_colorspace="HSV"),
            iaa.WithChannels(0, iaa.Add((4))),  # Adjust hue
            iaa.WithChannels(1, iaa.LinearContrast((1))),  # Adjust saturation
            iaa.WithChannels(1, iaa.Add((5))),
            iaa.WithChannels(2, iaa.LinearContrast((1))),  # Adjust value/brightness
            iaa.WithChannels(2, iaa.Add((92))),
            iaa.ChangeColorspace(from_colorspace="HSV", to_colorspace="RGB"),
        ]
        self.augmentations = iaa.Sequential(steps)
29
+
30
class StrongAug(ImgAug):
    """Stronger variant of the default pipeline (wider rotation range)."""

    def __init__(self):
        steps = [
            # iaa.Dropout([0.0, 0.01]),
            iaa.Sharpen((0.0, 0.1)),
            iaa.Affine(rotate=(-15, 15), translate_percent=(-0.1, 0.1), scale=(0.8, 1.1)),
            iaa.AddToBrightness((-10, 60)),
            iaa.AddToHue((-5, 10)),
            iaa.Fliplr(0.5),
        ]
        self.augmentations = iaa.Sequential(steps)
40
+
41
class greyAug(ImgAug):
    """Aggressive pipeline ending in a random partial grayscale conversion."""

    def __init__(self):
        steps = [
            iaa.Dropout([0.0, 0.01]),
            iaa.Sharpen((0.0, 0.1)),
            iaa.Affine(rotate=(-45, 45), translate_percent=(-0.1, 0.1), scale=(0.8, 1.1)),
            iaa.AddToBrightness((0, 80)),
            iaa.AddToHue((10, 20)),
            iaa.Fliplr(0.5),
            # iaa.ChangeColorTemperature((1100,10000)),
            iaa.Grayscale(alpha=(0.0, 1.0)),
        ]
        self.augmentations = iaa.Sequential(steps)
53
+
54
class newAug(ImgAug):
    """Pipeline combining flip, affine, blur and photometric jitter."""

    def __init__(self):
        steps = [
            iaa.Fliplr(0.5),  # Horizontally flip 50% of the images
            iaa.Affine(
                rotate=(-10, 10),  # Rotate images between -10 and 10 degrees
                shear=(-8, 8),  # Shear images
                scale={"x": (0.8, 1.2), "y": (0.8, 1.2)}  # Scale images
            ),
            iaa.GaussianBlur(sigma=(0, 1.0)),  # Gaussian blur with sigma in [0, 1.0]
            iaa.Multiply((0.8, 4.2)),  # Change brightness (80-420% of original value)
            iaa.LinearContrast((0.8, 1.2)),  # Adjust contrast
            iaa.AddToHueAndSaturation((-20, 20)),  # Add/Subtract hue and saturation
        ]
        self.augmentations = iaa.Sequential(steps)
69
+
70
+
71
class GrassAug(ImgAug):
    """Pipeline tuned for grass scenes: random saturation/value shifts in HSV."""

    def __init__(self):
        steps = [
            iaa.Sharpen((0.0, 0.1)),
            iaa.Affine(rotate=(-15, 15), translate_percent=(-0.1, 0.1), scale=(0.8, 1.1)),
            iaa.AddToBrightness((0, 20)),
            iaa.WithColorspace(
                to_colorspace="HSV",
                from_colorspace="RGB",
                children=iaa.Sequential([
                    iaa.WithChannels(1, iaa.Add((-5, 5))),  # Randomly adjust saturation
                    iaa.WithChannels(2, iaa.Add((-20, 90)))  # Randomly adjust value/brightness
                ])
            ),
            iaa.Fliplr(0.5),
        ]
        self.augmentations = iaa.Sequential(steps)
87
+
88
+
89
+
90
+
91
+
92
+
93
+
94
# ---------------------------------------------------------------------------
# Ready-made transform pipelines. Each one converts labels to absolute pixel
# coordinates, (optionally) augments, pads the image to a square, converts
# labels back to relative coordinates and finally produces tensors.
# ---------------------------------------------------------------------------

AUGMENTATION_TRANSFORMS_Version1 = transforms.Compose([
    AbsoluteLabels(),
    StrongAug(),
    PadSquare(),
    RelativeLabels(),
    ToTensor(),
])

# NOTE(review): currently identical to AUGMENTATION_TRANSFORMS_Version1;
# kept separate so the two can diverge without touching callers.
AUGMENTATION_TRANSFORMS = transforms.Compose([
    AbsoluteLabels(),
    StrongAug(),
    PadSquare(),
    RelativeLabels(),
    ToTensor(),
])

AUGMENTATION_TRANSFORMS_VersionHSV_PAPER = transforms.Compose([
    AbsoluteLabels(),
    GrassAug(),
    PadSquare(),
    RelativeLabels(),
    ToTensor(),
])

# No augmentation: only label conversion, padding and tensorization.
AUGMENTATION_NONE = transforms.Compose([
    AbsoluteLabels(),
    PadSquare(),
    RelativeLabels(),
    ToTensor(),
])
pytorchyolo/utils/datasets.py ADDED
@@ -0,0 +1,145 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from torch.utils.data import Dataset
2
+ import torch.nn.functional as F
3
+ import torch
4
+ import glob
5
+ import random
6
+ import os
7
+ import warnings
8
+ import numpy as np
9
+ from PIL import Image
10
+ from PIL import ImageFile
11
+
12
+ ImageFile.LOAD_TRUNCATED_IMAGES = True
13
+
14
+
15
def pad_to_square(img, pad_value):
    """Pad a CHW image tensor with *pad_value* so height equals width.

    :param img: Tensor of shape (channels, height, width)
    :param pad_value: Fill value for the padded area
    :return: Tuple of (padded image, pad tuple as used by F.pad)
    """
    c, h, w = img.shape
    dim_diff = np.abs(h - w)
    # Split the difference: (upper / left) and (lower / right) padding
    pad1, pad2 = dim_diff // 2, dim_diff - dim_diff // 2
    # Pad height when the image is wider than tall, otherwise pad width
    pad = (0, 0, pad1, pad2) if h <= w else (pad1, pad2, 0, 0)
    return F.pad(img, pad, "constant", value=pad_value), pad
26
+
27
+
28
def resize(image, size):
    """Nearest-neighbor resize of a CHW image tensor to (size, size)."""
    # interpolate needs a batch dimension, so add and strip one
    return F.interpolate(image.unsqueeze(0), size=size, mode="nearest").squeeze(0)
31
+
32
+
33
class ImageFolder(Dataset):
    """Dataset over every file in a folder; labels are zero placeholders."""

    def __init__(self, folder_path, transform=None):
        """
        :param folder_path: Directory whose files are treated as images
        :param transform: Optional callable applied to (image, boxes) pairs
        """
        self.files = sorted(glob.glob("%s/*.*" % folder_path))
        self.transform = transform

    def __getitem__(self, index):
        """Return (path, image) for the file at *index* (wraps around)."""
        img_path = self.files[index % len(self.files)]
        img = np.array(Image.open(img_path).convert('RGB'), dtype=np.uint8)

        # Label placeholder: one all-zero box row
        boxes = np.zeros((1, 5))

        if self.transform:
            img, _ = self.transform((img, boxes))

        return img_path, img

    def __len__(self):
        return len(self.files)
56
+
57
+
58
class ListDataset(Dataset):
    """Dataset driven by a text file listing image paths; label files are
    derived by swapping the last 'images' path component for 'labels'."""

    def __init__(self, list_path, img_size=416, multiscale=True, transform=None):
        """
        :param list_path: Text file with one image path per line
        :param img_size: Base square input size in pixels
        :param multiscale: Re-sample the input size every tenth batch
        :param transform: Optional callable applied to (image, boxes) pairs
        """
        with open(list_path, "r") as file:
            self.img_files = file.readlines()

        self.label_files = []
        for path in self.img_files:
            image_dir = os.path.dirname(path)
            # Replace only the last 'images' component with 'labels'
            label_dir = "labels".join(image_dir.rsplit("images", 1))
            assert label_dir != image_dir, \
                f"Image path must contain a folder named 'images'! \n'{image_dir}'"
            label_file = os.path.join(label_dir, os.path.basename(path))
            label_file = os.path.splitext(label_file)[0] + '.txt'
            self.label_files.append(label_file)

        self.img_size = img_size
        self.max_objects = 100
        self.multiscale = multiscale
        # Multiscale range: +/- 3 strides of 32 pixels around img_size
        self.min_size = self.img_size - 3 * 32
        self.max_size = self.img_size + 3 * 32
        self.batch_count = 0
        self.transform = transform

    def __getitem__(self, index):
        """Return (path, image, targets); returns None on any read/transform
        failure so collate_fn can drop the sample."""
        # ---------
        #  Image
        # ---------
        try:
            img_path = self.img_files[index % len(self.img_files)].rstrip()
            img = np.array(Image.open(img_path).convert('RGB'), dtype=np.uint8)
        except Exception:
            print(f"Could not read image '{img_path}'.")
            return

        # ---------
        #  Label
        # ---------
        try:
            label_path = self.label_files[index % len(self.img_files)].rstrip()
            # Ignore warning if file is empty
            with warnings.catch_warnings():
                warnings.simplefilter("ignore")
                boxes = np.loadtxt(label_path).reshape(-1, 5)
        except Exception:
            print(f"Could not read label '{label_path}'.")
            return

        # -----------
        #  Transform
        # -----------
        if self.transform:
            try:
                img, bb_targets = self.transform((img, boxes))
            except Exception:
                print("Could not apply transform.")
                return

        return img_path, img, bb_targets

    def collate_fn(self, batch):
        """Stack a batch, dropping samples that failed to load."""
        self.batch_count += 1

        # Drop invalid images
        batch = [data for data in batch if data is not None]

        paths, imgs, bb_targets = list(zip(*batch))

        # Selects new image size every tenth batch
        if self.multiscale and self.batch_count % 10 == 0:
            self.img_size = random.choice(
                range(self.min_size, self.max_size + 1, 32))

        # Resize images to input shape
        imgs = torch.stack([resize(img, self.img_size) for img in imgs])

        # Add sample index to targets so boxes can be matched to images
        for i, boxes in enumerate(bb_targets):
            boxes[:, 0] = i
        bb_targets = torch.cat(bb_targets, 0)

        return paths, imgs, bb_targets

    def __len__(self):
        return len(self.img_files)
pytorchyolo/utils/logger.py ADDED
@@ -0,0 +1,22 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import datetime
3
+ from torch.utils.tensorboard import SummaryWriter
4
+
5
+
6
class Logger(object):
    """Thin wrapper around TensorBoard's SummaryWriter."""

    def __init__(self, log_dir, log_hist=True):
        """Create a summary writer logging to log_dir.

        When log_hist is True a timestamped sub-folder is created so each
        run gets its own log directory.
        """
        if log_hist:
            stamp = datetime.datetime.now().strftime("%Y_%m_%d__%H_%M_%S")
            log_dir = os.path.join(log_dir, stamp)
        self.writer = SummaryWriter(log_dir)

    def scalar_summary(self, tag, value, step):
        """Log a single scalar variable."""
        self.writer.add_scalar(tag, value, step)

    def list_of_scalars_summary(self, tag_value_pairs, step):
        """Log several (tag, value) pairs at the same step."""
        for tag, value in tag_value_pairs:
            self.writer.add_scalar(tag, value, step)
pytorchyolo/utils/loss.py ADDED
@@ -0,0 +1,182 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import math
2
+
3
+ import torch
4
+ import torch.nn as nn
5
+
6
+ from .utils import to_cpu
7
+
8
+ # This new loss function is based on https://github.com/ultralytics/yolov3/blob/master/utils/loss.py
9
+
10
+
11
def bbox_iou(box1, box2, x1y1x2y2=True, GIoU=False, DIoU=False, CIoU=False, eps=1e-9):
    """Return the IoU (or GIoU/DIoU/CIoU) of box1 against each box in box2.

    :param box1: Tensor of 4 coordinates
    :param box2: Tensor of shape (n, 4)
    :param x1y1x2y2: Coordinates are corners if True, else (cx, cy, w, h)
    :param GIoU, DIoU, CIoU: Select a generalized IoU variant (at most one)
    :param eps: Small constant to avoid division by zero
    """
    box2 = box2.T  # work column-wise: box2 becomes 4 x n

    # Corner coordinates for both boxes
    if x1y1x2y2:
        b1_x1, b1_y1, b1_x2, b1_y2 = box1[0], box1[1], box1[2], box1[3]
        b2_x1, b2_y1, b2_x2, b2_y2 = box2[0], box2[1], box2[2], box2[3]
    else:
        # Convert center+size to corners
        b1_x1, b1_x2 = box1[0] - box1[2] / 2, box1[0] + box1[2] / 2
        b1_y1, b1_y2 = box1[1] - box1[3] / 2, box1[1] + box1[3] / 2
        b2_x1, b2_x2 = box2[0] - box2[2] / 2, box2[0] + box2[2] / 2
        b2_y1, b2_y2 = box2[1] - box2[3] / 2, box2[1] + box2[3] / 2

    # Intersection area (clamped so disjoint boxes give 0)
    inter = (torch.min(b1_x2, b2_x2) - torch.max(b1_x1, b2_x1)).clamp(0) * \
            (torch.min(b1_y2, b2_y2) - torch.max(b1_y1, b2_y1)).clamp(0)

    # Union area (eps on heights matches the upstream ultralytics formulation)
    w1, h1 = b1_x2 - b1_x1, b1_y2 - b1_y1 + eps
    w2, h2 = b2_x2 - b2_x1, b2_y2 - b2_y1 + eps
    union = w1 * h1 + w2 * h2 - inter + eps

    iou = inter / union
    if GIoU or DIoU or CIoU:
        # Smallest enclosing (convex) box dimensions
        cw = torch.max(b1_x2, b2_x2) - torch.min(b1_x1, b2_x1)
        ch = torch.max(b1_y2, b2_y2) - torch.min(b1_y1, b2_y1)
        if CIoU or DIoU:  # Distance or Complete IoU https://arxiv.org/abs/1911.08287v1
            c2 = cw ** 2 + ch ** 2 + eps  # convex diagonal squared
            rho2 = ((b2_x1 + b2_x2 - b1_x1 - b1_x2) ** 2 +
                    (b2_y1 + b2_y2 - b1_y1 - b1_y2) ** 2) / 4  # center distance squared
            if DIoU:
                return iou - rho2 / c2
            elif CIoU:  # https://github.com/Zzh-tju/DIoU-SSD-pytorch/blob/master/utils/box/box_utils.py#L47
                v = (4 / math.pi ** 2) * \
                    torch.pow(torch.atan(w2 / h2) - torch.atan(w1 / h1), 2)
                with torch.no_grad():
                    alpha = v / ((1 + eps) - iou + v)
                return iou - (rho2 / c2 + v * alpha)
        else:  # GIoU https://arxiv.org/pdf/1902.09630.pdf
            c_area = cw * ch + eps  # convex area
            return iou - (c_area - union) / c_area
    else:
        return iou
56
+
57
+
58
def compute_loss(predictions, targets, model):
    """Compute the combined YOLO loss (box + objectness + class).

    :param predictions: List of per-yolo-layer prediction tensors
    :param targets: Tensor of (image, class, x, y, w, h) rows
    :param model: Model exposing ``yolo_layers`` (used by build_targets)
    :return: Tuple of (total loss, detached CPU tensor [lbox, lobj, lcls, loss])
    """
    device = targets.device

    # Accumulators for class, box and objectness losses
    lcls = torch.zeros(1, device=device)
    lbox = torch.zeros(1, device=device)
    lobj = torch.zeros(1, device=device)

    # Build yolo targets
    tcls, tbox, indices, anchors = build_targets(predictions, targets, model)

    # Binary cross entropy for both classification tasks
    BCEcls = nn.BCEWithLogitsLoss(
        pos_weight=torch.tensor([1.0], device=device))
    BCEobj = nn.BCEWithLogitsLoss(
        pos_weight=torch.tensor([1.0], device=device))

    # Per-yolo-layer losses
    for layer_index, layer_predictions in enumerate(predictions):
        # Image ids, anchor ids and grid cell (j, i) for each target of this layer
        b, anchor, grid_j, grid_i = indices[layer_index]
        # Objectness target tensor, same shape as the objectness prediction
        tobj = torch.zeros_like(layer_predictions[..., 0], device=device)
        # Each target is a label box matched to an anchor; a label may appear
        # for 0 or several anchors, so num_targets need not equal label count.
        num_targets = b.shape[0]
        if num_targets:
            # Predictions at the cells/anchors that own a target
            ps = layer_predictions[b, anchor, grid_j, grid_i]

            # Box regression: sigmoid the xy cell offsets, exponentiate wh and
            # scale by the matched anchor
            pxy = ps[:, :2].sigmoid()
            pwh = torch.exp(ps[:, 2:4]) * anchors[layer_index]
            pbox = torch.cat((pxy, pwh), 1)
            # CIoU between each predicted box and its target box
            iou = bbox_iou(pbox.T, tbox[layer_index], x1y1x2y2=False, CIoU=True)
            # Best IoU is 1, so minimize the mean of (1 - IoU)
            lbox += (1.0 - iou).mean()

            # Objectness: use the (detached, clamped) IoU as the soft target
            tobj[b, anchor, grid_j, grid_i] = iou.detach().clamp(0).type(tobj.dtype)

            # Class loss only when there is more than one class
            if ps.size(1) - 5 > 1:
                # One-hot class encoding
                t = torch.zeros_like(ps[:, 5:], device=device)
                t[range(num_targets), tcls[layer_index]] = 1
                lcls += BCEcls(ps[:, 5:], t)

        # Objectness BCE over the whole layer (positives and negatives)
        lobj += BCEobj(layer_predictions[..., 4], tobj)

    # Fixed loss balancing weights
    lbox *= 0.05
    lobj *= 1.0
    lcls *= 0.5

    loss = lbox + lobj + lcls

    return loss, to_cpu(torch.cat((lbox, lobj, lcls, loss)))
126
+
127
+
128
def build_targets(p, targets, model):
    """Build per-yolo-layer training targets for compute_loss().

    :param p: List of per-layer prediction tensors (bs, anchors, ny, nx, no)
    :param targets: Tensor of (image, class, x, y, w, h) rows, normalized 0-1
    :param model: Model exposing ``yolo_layers`` with ``anchors``/``stride``
    :return: Tuple (tcls, tbox, indices, anch) of per-layer lists
    """
    na, nt = 3, targets.shape[0]  # number of anchors, targets  # TODO
    tcls, tbox, indices, anch = [], [], [], []
    gain = torch.ones(7, device=targets.device)  # normalized-to-gridspace gain
    # Anchor index column: 0..na-1, one row per anchor, repeated per target
    ai = torch.arange(na, device=targets.device).float().view(na, 1).repeat(1, nt)
    # Replicate every target once per anchor and append its anchor index;
    # the anchor is also expressed by the new leading dimension.
    targets = torch.cat((targets.repeat(na, 1, 1), ai[:, :, None]), 2)

    for i, yolo_layer in enumerate(model.yolo_layers):
        # Anchors in cell units: an anchor the size of a cell becomes 1
        anchors = yolo_layer.anchors / yolo_layer.stride
        # gain columns mirror the target columns (img, class, x, y, w, h, anchor)
        gain[2:6] = torch.tensor(p[i].shape)[[3, 2, 3, 2]]  # xyxy gain
        # Targets scaled into this layer's grid coordinate system
        t = targets * gain
        if nt:
            # Width/height ratio between each target and each anchor
            r = t[:, :, 4:6] / anchors[:, None]
            # Keep pairs whose worst-axis ratio is below 4x either way
            j = torch.max(r, 1. / r).max(2)[0] < 4  # compare  # TODO
            # Drop the anchor dimension; the anchor id survives in column 6
            t = t[j]
        else:
            t = targets[0]

        # Image id within the batch and class id
        b, c = t[:, :2].long().T
        # Grid-space centers and sizes (x = 1.2 means 1.2 cell widths)
        gxy = t[:, 2:4]
        gwh = t[:, 4:6]
        # Truncate to get the owning cell index (1.2 -> cell 1)
        gij = gxy.long()
        gi, gj = gij.T  # grid xy indices

        # Matched anchor index as int
        a = t[:, 6].long()
        # Clamp indices to the grid to avoid out-of-bounds writes
        indices.append((b, a,
                        gj.clamp_(0, gain[3].long() - 1),
                        gi.clamp_(0, gain[2].long() - 1)))
        # Box target: offset within the owning cell plus grid-space size
        tbox.append(torch.cat((gxy - gij, gwh), 1))
        # Matched anchor per target
        anch.append(anchors[a])
        # Class per target
        tcls.append(c)

    return tcls, tbox, indices, anch
pytorchyolo/utils/parse_config.py ADDED
@@ -0,0 +1,37 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+
2
+
3
def parse_model_config(path):
    """Parse a Darknet/YOLO layer configuration file into module definitions.

    :param path: path to a ``.cfg`` file
    :return: list of dicts, one per ``[block]`` section; keys and values are
        strings, except ``batch_normalize`` which defaults to the int 0 for
        convolutional blocks (may be overridden by the file)
    """
    # Use a context manager so the file handle is always closed
    # (the original opened the file and never closed it).
    with open(path, 'r') as config_file:
        lines = config_file.read().split('\n')

    # Strip whitespace BEFORE filtering so whitespace-only lines and
    # indented comments are dropped reliably (previously a whitespace-only
    # line survived the filter and crashed the key=value split below).
    lines = [line.strip() for line in lines]
    lines = [line for line in lines if line and not line.startswith('#')]

    module_defs = []
    for line in lines:
        if line.startswith('['):  # This marks the start of a new block
            module_defs.append({})
            module_defs[-1]['type'] = line[1:-1].rstrip()
            if module_defs[-1]['type'] == 'convolutional':
                # Darknet convention: batch_normalize defaults to off.
                module_defs[-1]['batch_normalize'] = 0
        else:
            key, value = line.split("=")
            module_defs[-1][key.rstrip()] = value.strip()

    return module_defs
22
+
23
+
24
def parse_data_config(path):
    """Parse a data configuration file into an options dict.

    :param path: path to a ``.data`` file of ``key = value`` lines
    :return: dict of options; ``gpus`` and ``num_workers`` have defaults
        that the file may override
    """
    # Defaults that a config file may override.
    options = {'gpus': '0,1,2,3', 'num_workers': '10'}
    with open(path, 'r') as config_file:
        for raw_line in config_file:
            entry = raw_line.strip()
            # Skip blank lines and comments.
            if not entry or entry.startswith('#'):
                continue
            key, value = entry.split('=')
            options[key.strip()] = value.strip()
    return options
pytorchyolo/utils/transforms.py ADDED
@@ -0,0 +1,330 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn.functional as F
3
+ import numpy as np
4
+
5
+ import imgaug.augmenters as iaa
6
+ from imgaug.augmentables.bbs import BoundingBox, BoundingBoxesOnImage
7
+
8
+ from .utils import xywh2xyxy_np
9
+ import torchvision.transforms as transforms
10
+
11
+ import cv2
12
+ from PIL import Image
13
+
14
class ImgAug(object):
    """Base transform: apply an imgaug augmentation pipeline to an image and
    its (class, cx, cy, w, h) boxes, keeping the boxes consistent with the
    transformed image.
    """

    def __init__(self, augmentations=None):
        # BUGFIX: the original used a mutable default argument ([]), which is
        # shared across all instances; use None as the "no pipeline" sentinel.
        # Subclasses overwrite self.augmentations with an iaa.Sequential.
        self.augmentations = augmentations

    def __call__(self, data):
        """
        :param data: (image ndarray HWC, boxes array-like of (cls, cx, cy, w, h))
        :return: (augmented image, boxes back in (cls, cx, cy, w, h) form)
        """
        img, boxes = data

        # Reshape handles the empty-boxes case (the original crashed on
        # boxes[:, 1:] for a 1-D empty array). Convert xywh -> xyxy corners.
        boxes = np.array(boxes).reshape(-1, 5)
        boxes[:, 1:] = xywh2xyxy_np(boxes[:, 1:])

        # Wrap boxes in imgaug's container so they are transformed with the image.
        bounding_boxes = BoundingBoxesOnImage(
            [BoundingBox(*box[1:], label=box[0]) for box in boxes],
            shape=img.shape)

        # Apply the configured augmentation pipeline.
        img, bounding_boxes = self.augmentations(
            image=img,
            bounding_boxes=bounding_boxes)

        # Drop/clip boxes pushed outside the image by the augmentation.
        bounding_boxes = bounding_boxes.clip_out_of_image()

        # Convert back to (class, cx, cy, w, h).
        boxes = np.zeros((len(bounding_boxes), 5))
        for box_idx, box in enumerate(bounding_boxes):
            x1, y1, x2, y2 = box.x1, box.y1, box.x2, box.y2
            boxes[box_idx, 0] = box.label
            boxes[box_idx, 1] = (x1 + x2) / 2
            boxes[box_idx, 2] = (y1 + y2) / 2
            boxes[box_idx, 3] = x2 - x1
            boxes[box_idx, 4] = y2 - y1

        return img, boxes
56
+
57
+
58
class RelativeLabels(object):
    """Convert box coordinates from absolute pixels to [0, 1] relative units."""

    def __init__(self):
        pass

    def __call__(self, data):
        image, boxes = data
        height, width, _ = image.shape
        # Columns are (class, x, y, w, h): odd columns scale by width,
        # even columns (past 0) by height. Division happens in place.
        boxes[:, 1::2] /= width
        boxes[:, 2::2] /= height
        return image, boxes
+ return img, boxes
68
+
69
+
70
class AbsoluteLabels(object):
    """Convert box coordinates from [0, 1] relative units to absolute pixels."""

    def __init__(self):
        pass

    def __call__(self, data):
        image, boxes = data
        height, width, _ = image.shape
        # Columns are (class, x, y, w, h): odd columns scale by width,
        # even columns (past 0) by height. Multiplication happens in place.
        boxes[:, 1::2] *= width
        boxes[:, 2::2] *= height
        return image, boxes
+ return img, boxes
80
+
81
+
82
class PadSquare(ImgAug):
    """Pad the image (centered) to a 1:1 aspect ratio via imgaug.

    Does not call super().__init__; it installs its own deterministic
    pad pipeline directly on ``self.augmentations``.
    """
    def __init__(self, ):
        self.augmentations = iaa.Sequential([
            iaa.PadToAspectRatio(
                1.0,
                position="center-center").to_deterministic()
        ])
89
+
90
+
91
class ToTensor(object):
    """Convert (image, boxes) from numpy arrays to PyTorch tensors.

    Output boxes have shape (N, 6): column 0 is left zeroed (presumably
    filled with the sample index when batching — confirm against the
    collate function), columns 1..5 hold the (class, x, y, w, h) values.
    """
    def __init__(self, ):
        pass

    def __call__(self, data):
        img, boxes = data
        # Extract image as PyTorch tensor (torchvision ToTensor: HWC -> CHW,
        # scaling uint8 to [0, 1])
        img = transforms.ToTensor()(img)

        bb_targets = torch.zeros((len(boxes), 6))
        # NOTE(review): ToTensor on a 2-D boxes array yields a (1, N, 5)
        # tensor that broadcasts into the (N, 5) slice — works for float
        # arrays, but torch.from_numpy would be clearer; confirm boxes is
        # a float ndarray here.
        bb_targets[:, 1:] = transforms.ToTensor()(boxes)

        return img, bb_targets
104
+
105
+
106
class Resize(object):
    """Resize an image tensor to ``size`` x ``size`` with nearest-neighbour
    interpolation; boxes are relative coordinates and need no adjustment."""

    def __init__(self, size):
        self.size = size

    def __call__(self, data):
        image, boxes = data
        # interpolate expects a batch dimension, so add and remove one.
        batched = image.unsqueeze(0)
        resized = F.interpolate(batched, size=self.size, mode="nearest")
        return resized.squeeze(0), boxes
114
+
115
+ # Adjust color brightness strategy
116
# Adjust color brightness strategy
class adjustGrassColor(object):
    """Boost saturation and brightness in HSV space to emphasise grass color.

    Operates on (3, H, W) float RGB tensors with values in [0, 1]
    (layout established by the channel unbind(0) below).
    """

    def __init__(self):
        # Fixed enhancement factors applied to the S and V channels.
        self.saturation = 1.25
        self.brightness = 1.15

    def rgb_to_hsv(self, rgb_img):
        """Convert a (3, H, W) RGB tensor in [0, 1] to HSV (hue in degrees [0, 360))."""
        r, g, b = rgb_img.unbind(0)

        # Get the max and min values across RGB
        max_val, _ = torch.max(rgb_img, dim=0)
        min_val, _ = torch.min(rgb_img, dim=0)
        diff = max_val - min_val

        # Hue: undefined for gray pixels (max == min); leave those at 0.
        # BUGFIX: the original applied the max==b / max==g branches to gray
        # pixels as well, dividing 0/0 and producing NaN hues.
        h = torch.zeros_like(r)
        chromatic = max_val != min_val
        h[chromatic] = 60.0 * ((g[chromatic] - b[chromatic]) / diff[chromatic] % 6)
        mask = (max_val == b) & chromatic
        h[mask] = 60.0 * ((r[mask] - g[mask]) / diff[mask] + 4)
        mask = (max_val == g) & chromatic
        h[mask] = 60.0 * ((b[mask] - r[mask]) / diff[mask] + 2)

        # Saturation: chroma relative to value (0 for black pixels).
        s = torch.zeros_like(r)
        nonzero = max_val != 0
        s[nonzero] = diff[nonzero] / max_val[nonzero]

        # Value is simply the channel maximum.
        v = max_val

        return torch.stack([h, s, v])

    def hsv_to_rgb(self, hsv_img):
        """Convert a (3, H, W) HSV tensor back to RGB in [0, 1].

        BUGFIX: the original mixed up the per-sextant chroma assignments and
        only added the value offset ``m`` on some segments, so conversion did
        not invert rgb_to_hsv. This is the standard piecewise formula; ``m``
        is added to all three channels.
        """
        h, s, v = hsv_img.unbind(0)
        c = v * s                               # chroma
        hh = h / 60.0                           # hue sextant in [0, 6)
        x = c * (1 - torch.abs(hh % 2 - 1))     # second-largest component
        m = v - c                               # offset shared by all channels

        segments = hh.to(torch.int32)
        r = c * (segments == 0) + x * (segments == 1) + x * (segments == 4) + c * (segments == 5) + m
        g = x * (segments == 0) + c * (segments == 1) + c * (segments == 2) + x * (segments == 3) + m
        b = x * (segments == 2) + c * (segments == 3) + c * (segments == 4) + x * (segments == 5) + m

        return torch.stack([r, g, b])

    def adjust_grass_color(self, rgb_img):
        """Apply the saturation/brightness boost, clamping S and V to [0, 1]."""
        hsv_img = self.rgb_to_hsv(rgb_img)

        # Adjust saturation
        hsv_img[1] = torch.clamp(hsv_img[1] * self.saturation, 0, 1)

        # Adjust brightness
        hsv_img[2] = torch.clamp(hsv_img[2] * self.brightness, 0, 1)

        return self.hsv_to_rgb(hsv_img)

    def __call__(self, data):
        img, boxes = data
        return self.adjust_grass_color(img), boxes
181
+
182
+
183
+
184
+ # Normalize the data to Image Net if weight were trained this way, need to explore darknet code
185
+
186
class Normalize(object):
    """Normalize the image tensor with ImageNet channel statistics;
    boxes pass through unchanged."""

    def __init__(self):
        pass

    def __call__(self, data):
        image, boxes = data
        # Standard ImageNet mean/std used by torchvision pretrained models.
        normalizer = transforms.Normalize(
            mean=[0.485, 0.456, 0.406],
            std=[0.229, 0.224, 0.225])
        return normalizer(image), boxes
196
+
197
+
198
+
199
+
200
def load_image(path, device):
    """Load the image at *path* as a (1, 3, H, W) float tensor on *device*.

    The file is converted to RGB; ToTensor yields CHW in [0, 1], and
    unsqueeze(0) adds the batch dimension.
    """
    image = Image.open(path).convert('RGB')
    transform = transforms.Compose([
        transforms.ToTensor()
    ])
    return transform(image).unsqueeze(0).to(device)
206
+
207
+
208
+
209
+
210
+
211
+
212
+
213
def compute_cumulative_histogram(image):
    """Compute per-channel normalized cumulative histograms (CDFs).

    BUGFIX: ``torch.histc`` returns a flat 256-bin tensor, so the original
    ``torch.cumsum(hist, dim=1)`` / ``cdf[:, -1:]`` raised immediately.
    This version builds the (B, C, 256) layout that match_histogram indexes
    as ``cdf[b, c]`` and ``cdf[b, c, i]``.

    :param image: float tensor of shape (B, C, H, W) with values in [0, 1]
    :return: tensor of shape (B, C, 256); each [b, c] row is monotonically
        increasing and ends at 1
    """
    batch, channels = image.shape[0], image.shape[1]
    cdf = image.new_empty((batch, channels, 256))
    for b in range(batch):
        for c in range(channels):
            # 256-bin histogram of this channel over [0, 1].
            hist = torch.histc(image[b, c], bins=256, min=0, max=1)
            cdf[b, c] = torch.cumsum(hist, dim=0)
    # Normalize so each CDF ends at 1 (clamp guards an all-empty channel).
    return cdf / cdf[:, :, -1:].clamp(min=1e-12)
219
+
220
+
221
def match_histogram(source, reference):
    """Remap *source* intensities so their distribution matches *reference*.

    Classic histogram matching: for each batch/channel, every 1/256-wide
    source intensity bin is mapped to the reference bin with the nearest
    cumulative mass (via searchsorted).

    :param source: float tensor, indexed as (B, C, ...) with values in [0, 1]
    :param reference: float tensor with the same layout as ``source``
    :return: tensor shaped like ``source`` with remapped intensities

    NOTE(review): relies on compute_cumulative_histogram returning a
    (B, C, 256) CDF per batch/channel — confirm that contract.
    """
    reference_cdf = compute_cumulative_histogram(reference)
    source_cdf = compute_cumulative_histogram(source)

    matched_image = torch.zeros_like(source)

    # O(B*C*256) Python loop — slow but bounded, since bins are fixed at 256.
    for b in range(source.size(0)):
        for c in range(source.size(1)):
            for i in range(256):
                # Center of source bin i.
                source_val = (i + 0.5) / 256
                # Reference bin whose cumulative mass matches this bin's.
                ref_idx = torch.searchsorted(reference_cdf[b, c], source_cdf[b, c, i])
                ref_val = (ref_idx + 0.5) / 256
                # All pixels falling inside source bin i take the reference value.
                mask = (source[b, c] >= source_val - 0.5/256) & (source[b, c] < source_val + 0.5/256)
                matched_image[b, c, mask] = ref_val

    return matched_image
238
+
239
+
240
+ # Convert back to PIL Image and save
241
+
242
class balance_image:
    """Histogram-match input images against a fixed reference photograph."""

    # Previous hard-coded reference image, kept as the default so existing
    # no-argument construction keeps working.
    DEFAULT_REFERENCE = "C:\\Users\\stevf\\OneDrive\\Documents\\Projects\\PyTorch-YOLOv3\\data\\turfgrass_VOC\\images\\YOLODataset\\images\\20230210_152530.png"

    def __init__(self, reference_path=DEFAULT_REFERENCE):
        """
        :param reference_path: image whose color distribution inputs are
            matched to (generalized from the previously hard-coded path)
        """
        self.imageName = reference_path
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        # Loaded once at construction; reused for every __call__.
        self.reference_image = load_image(self.imageName, self.device)

    def __call__(self, data):
        image, boxes = data
        # Remap intensities so the image's histogram matches the reference.
        return match_histogram(image, self.reference_image), boxes
255
+
256
+
257
+
258
class WhiteBalanceTransform:
    """Gray-world white balance via OpenCV: scale each channel toward the
    global mean, then min-max normalize each channel back to [0, 255].

    Expects an RGB uint8 HWC ndarray; boxes pass through unchanged.
    """

    def __call__(self, data):
        img, boxes = data

        # Convert to BGR format for OpenCV
        img_bgr = cv2.cvtColor(img, cv2.COLOR_RGB2BGR)

        # Split the channels
        b, g, r = cv2.split(img_bgr)

        # Compute the mean of each channel
        r_avg = cv2.mean(r)[0]
        g_avg = cv2.mean(g)[0]
        b_avg = cv2.mean(b)[0]

        # Gray-world assumption: each channel mean should equal the overall mean.
        k = (r_avg + g_avg + b_avg) / 3
        kr = k / r_avg
        kg = k / g_avg
        kb = k / b_avg

        # White balance correction; NORM_MINMAX rescales each channel to 0..255.
        r = cv2.normalize(r * kr, None, 0, 255, cv2.NORM_MINMAX).astype(np.uint8)
        g = cv2.normalize(g * kg, None, 0, 255, cv2.NORM_MINMAX).astype(np.uint8)
        b = cv2.normalize(b * kb, None, 0, 255, cv2.NORM_MINMAX).astype(np.uint8)

        # Merge channels and convert back to RGB format
        img_balanced = cv2.merge([b, g, r])
        img_balanced = cv2.cvtColor(img_balanced, cv2.COLOR_BGR2RGB)

        # Round-trip through PIL; the returned value is still a numpy array.
        return np.array(Image.fromarray(img_balanced)), boxes
293
+
294
class correctImage(ImgAug):
    """Darken images by a random amount (brightness delta sampled from [-100, 0])."""
    def __init__(self, ):
        self.augmentations = iaa.Sequential(
            [
                iaa.AddToBrightness((-100, 0)),
            ],
        )
301
+
302
class correctImageAspectRatio(ImgAug):
    """Resize to height 416 (preserving aspect ratio), then crop to 416x416."""
    def __init__(self, ):
        self.augmentations = iaa.Sequential(
            [
                iaa.Resize({"height": 416, "width": "keep-aspect-ratio"}),
                iaa.CropToFixedSize(height=416, width=416)
            ],
        )
310
+
311
class crop(ImgAug):
    """Center-crop by fixed pixel amounts computed for a 416x555 input.

    NOTE(review): height/width are hard-coded; images of other sizes get
    cropped by these same pixel amounts — confirm inputs really are 416x555.
    """
    def __init__(self, ):
        height, width = 416, 555
        target_width = 416

        # Pixels to remove from each side to reach the target size.
        crop_left_right = max(0, (width - target_width) // 2)
        crop_top_bottom = max(0, (height - target_width) // 2)  # Assuming you also want height to be 416

        # iaa.Crop px order is (top, right, bottom, left).
        self.augmentations = iaa.Sequential([
            iaa.Crop(px=(crop_top_bottom, crop_left_right, crop_top_bottom, crop_left_right))
        ])
322
+
323
+
324
# Default pipeline: boxes to absolute pixels -> fixed center crop ->
# pad to square -> boxes back to relative units -> tensors.
DEFAULT_TRANSFORMS = transforms.Compose([
    AbsoluteLabels(),
    crop(),
    PadSquare(),
    RelativeLabels(),
    ToTensor(),
])
pytorchyolo/utils/utils.py ADDED
@@ -0,0 +1,398 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from __future__ import division
2
+
3
+ import time
4
+ import platform
5
+ import tqdm
6
+ import torch
7
+ import torch.nn as nn
8
+ import torchvision
9
+ import numpy as np
10
+ import subprocess
11
+ import random
12
+ import imgaug as ia
13
+
14
+
15
def provide_determinism(seed=42):
    """Seed every RNG in use (random, numpy, torch CPU/CUDA, imgaug) and
    force deterministic cuDNN kernels for reproducible runs."""
    for seeder in (random.seed, np.random.seed, torch.manual_seed,
                   torch.cuda.manual_seed_all, ia.seed):
        seeder(seed)

    # Trade speed for reproducibility: no autotuning, deterministic kernels.
    torch.backends.cudnn.benchmark = False
    torch.backends.cudnn.deterministic = True
24
+
25
+
26
def worker_seed_set(worker_id):
    """DataLoader ``worker_init_fn``: give each worker distinct, reproducible
    NumPy and ``random`` seeds derived from torch's per-worker seed.

    See for details of numpy:
    https://github.com/pytorch/pytorch/issues/5059#issuecomment-817392562
    See for details of random:
    https://pytorch.org/docs/stable/notes/randomness.html#dataloader
    """
    base_seed = torch.initial_seed()

    # NumPy: spawn a fresh seed sequence from torch's 64-bit worker seed.
    seed_sequence = np.random.SeedSequence([base_seed])
    np.random.seed(seed_sequence.generate_state(4))

    # Python's random module wants a value that fits in 32 bits.
    random.seed(base_seed % 2**32)
40
+
41
+
42
def to_cpu(tensor):
    """Return *tensor* detached from the autograd graph and moved to the CPU."""
    return tensor.cpu().detach()
44
+
45
+
46
def load_classes(path):
    """Read one class name per line from *path* and return them as a list."""
    # Context manager guarantees the handle is closed even on error.
    with open(path, "r") as handle:
        return handle.read().splitlines()
53
+
54
+
55
def weights_init_normal(m):
    """Init hook for ``model.apply``: N(0, 0.02) conv weights; N(1, 0.02)
    BatchNorm2d weights with zero bias. Other layers are left untouched."""
    layer_name = type(m).__name__
    if "Conv" in layer_name:
        nn.init.normal_(m.weight.data, 0.0, 0.02)
    elif "BatchNorm2d" in layer_name:
        nn.init.normal_(m.weight.data, 1.0, 0.02)
        nn.init.constant_(m.bias.data, 0.0)
62
+
63
+
64
def rescale_boxes(boxes, current_dim, original_shape):
    """Map boxes from the square, letterbox-padded network input back to the
    original image's pixel coordinates.

    :param boxes: (N, >=4) tensor with xyxy coordinates in network-input space
    :param current_dim: side length of the square network input
    :param original_shape: (height, width) of the original image
    :return: the same tensor, rescaled in place
    """
    orig_h, orig_w = original_shape

    # Padding that was added along the shorter side to square the image.
    scale = current_dim / max(original_shape)
    pad_x = max(orig_h - orig_w, 0) * scale
    pad_y = max(orig_w - orig_h, 0) * scale

    # Extent of the actual image inside the padded square.
    unpad_h = current_dim - pad_y
    unpad_w = current_dim - pad_x

    # Remove padding, then scale each coordinate back to original pixels.
    for col, pad, unpad, orig in ((0, pad_x, unpad_w, orig_w),
                                  (1, pad_y, unpad_h, orig_h),
                                  (2, pad_x, unpad_w, orig_w),
                                  (3, pad_y, unpad_h, orig_h)):
        boxes[:, col] = ((boxes[:, col] - pad // 2) / unpad) * orig
    return boxes
84
+
85
+
86
def xywh2xyxy(x):
    """Convert boxes from (cx, cy, w, h) to (x1, y1, x2, y2) along the last dim."""
    y = x.new(x.shape)
    half = x[..., 2:4] / 2
    # Corners are the center offset by half the extent.
    y[..., 0:2] = x[..., 0:2] - half
    y[..., 2:4] = x[..., 0:2] + half
    return y
93
+
94
+
95
def xywh2xyxy_np(x):
    """NumPy twin of xywh2xyxy: (cx, cy, w, h) -> (x1, y1, x2, y2)."""
    y = np.zeros_like(x)
    half = x[..., 2:4] / 2
    # Corners are the center offset by half the extent.
    y[..., 0:2] = x[..., 0:2] - half
    y[..., 2:4] = x[..., 0:2] + half
    return y
102
+
103
+
104
def ap_per_class(tp, conf, pred_cls, target_cls):
    """ Compute the average precision, given the recall and precision curves.
    Source: https://github.com/rafaelpadilla/Object-Detection-Metrics.
    # Arguments
        tp: True positives (list).
        conf: Objectness value from 0-1 (list).
        pred_cls: Predicted object classes (list).
        target_cls: True object classes (list).
    # Returns
        The average precision as computed in py-faster-rcnn.
        Concretely: (precision, recall, AP, F1, unique classes) arrays,
        one entry per class that occurs in predictions or ground truth.
    """

    # Sort by objectness (descending) so cumsums trace the PR curve.
    i = np.argsort(-conf)
    tp, conf, pred_cls = tp[i], conf[i], pred_cls[i]

    # Find unique classes
    unique_classes = np.unique(target_cls)

    # Create Precision-Recall curve and compute AP for each class
    ap, p, r = [], [], []
    for c in tqdm.tqdm(unique_classes, desc="Computing AP"):
        i = pred_cls == c
        n_gt = (target_cls == c).sum()  # Number of ground truth objects
        n_p = i.sum()  # Number of predicted objects

        if n_p == 0 and n_gt == 0:
            # Class absent from both predictions and ground truth: skip.
            continue
        elif n_p == 0 or n_gt == 0:
            # No detections (or no ground truth): AP/P/R are all zero.
            ap.append(0)
            r.append(0)
            p.append(0)
        else:
            # Accumulate FPs and TPs along the confidence-sorted detections.
            fpc = (1 - tp[i]).cumsum()
            tpc = (tp[i]).cumsum()

            # Recall; the last point of the curve is the class recall.
            recall_curve = tpc / (n_gt + 1e-16)
            r.append(recall_curve[-1])

            # Precision; the last point of the curve is the class precision.
            precision_curve = tpc / (tpc + fpc)
            p.append(precision_curve[-1])

            # AP from recall-precision curve
            ap.append(compute_ap(recall_curve, precision_curve))

    # Compute F1 score (harmonic mean of precision and recall)
    p, r, ap = np.array(p), np.array(r), np.array(ap)
    f1 = 2 * p * r / (p + r + 1e-16)

    return p, r, ap, f1, unique_classes.astype("int32")
157
+
158
+
159
def compute_ap(recall, precision):
    """ Compute the average precision, given the recall and precision curves.
    Code originally from https://github.com/rbgirshick/py-faster-rcnn.

    # Arguments
        recall: The recall curve (list).
        precision: The precision curve (list).
    # Returns
        The average precision as computed in py-faster-rcnn.
    """
    # Sentinel values close the curve at both ends.
    mrec = np.concatenate(([0.0], recall, [1.0]))
    mpre = np.concatenate(([0.0], precision, [0.0]))

    # Precision envelope: running maximum from right to left makes the
    # precision curve monotonically non-increasing.
    mpre = np.maximum.accumulate(mpre[::-1])[::-1]

    # Area under the PR curve: sum precision * delta-recall at the points
    # where recall actually changes.
    change_points = np.where(mrec[1:] != mrec[:-1])[0]
    return np.sum((mrec[change_points + 1] - mrec[change_points]) * mpre[change_points + 1])
185
+
186
+
187
def get_batch_statistics(outputs, targets, iou_threshold):
    """ Compute true positives, predicted scores and predicted labels per sample

    :param outputs: per-image detections; each entry is None or an (n, 6+)
        tensor read as (x1, y1, x2, y2, conf, ..., cls) — cls is the last column
    :param targets: (m, 6) tensor of (sample_idx, cls, box...) rows
        (column layout inferred from the slicing below — confirm with caller)
    :param iou_threshold: minimum IoU for a detection to count as a match
    :return: list of [true_positives, pred_scores, pred_labels] per image
    """
    batch_metrics = []
    for sample_i in range(len(outputs)):

        if outputs[sample_i] is None:
            continue

        output = outputs[sample_i]
        pred_boxes = output[:, :4]
        pred_scores = output[:, 4]
        pred_labels = output[:, -1]

        true_positives = np.zeros(pred_boxes.shape[0])

        # Ground-truth rows belonging to this sample (drop the sample index).
        annotations = targets[targets[:, 0] == sample_i][:, 1:]
        target_labels = annotations[:, 0] if len(annotations) else []
        if len(annotations):
            detected_boxes = []
            target_boxes = annotations[:, 1:]

            for pred_i, (pred_box, pred_label) in enumerate(zip(pred_boxes, pred_labels)):

                # If targets are found break
                if len(detected_boxes) == len(annotations):
                    break

                # Ignore if label is not one of the target labels
                if pred_label not in target_labels:
                    continue

                # Filter target_boxes by pred_label so that we only match against boxes of our own label
                filtered_target_position, filtered_targets = zip(*filter(lambda x: target_labels[x[0]] == pred_label, enumerate(target_boxes)))

                # Find the best matching target for our predicted box
                iou, box_filtered_index = bbox_iou(pred_box.unsqueeze(0), torch.stack(filtered_targets)).max(0)

                # Remap the index in the list of filtered targets for that label to the index in the list with all targets.
                box_index = filtered_target_position[box_filtered_index]

                # Count as TP only if IoU clears the threshold and this target
                # has not already been claimed by an earlier prediction.
                if iou >= iou_threshold and box_index not in detected_boxes:
                    true_positives[pred_i] = 1
                    detected_boxes += [box_index]
        batch_metrics.append([true_positives, pred_scores, pred_labels])
    return batch_metrics
233
+
234
+
235
def bbox_wh_iou(wh1, wh2):
    """IoU of boxes specified only by (w, h), i.e. as if sharing one center.

    :param wh1: tensor of shape (2,) — a single (w, h) pair (e.g. an anchor)
    :param wh2: tensor of shape (N, 2) — (w, h) pairs to compare against
    :return: tensor of shape (N,) with IoU values
    """
    w1, h1 = wh1[0], wh1[1]
    wh2_t = wh2.t()
    w2, h2 = wh2_t[0], wh2_t[1]
    # Center-aligned boxes: intersection is just min-width * min-height.
    overlap = torch.min(w1, w2) * torch.min(h1, h2)
    union = (w1 * h1 + 1e-16) + w2 * h2 - overlap
    return overlap / union
242
+
243
+
244
def bbox_iou(box1, box2, x1y1x2y2=True):
    """
    Returns the IoU of two aligned sets of bounding boxes.

    Uses the inclusive pixel convention (+1 on each extent).

    :param box1: (N, 4) boxes
    :param box2: (N, 4) boxes, compared row-by-row with box1
    :param x1y1x2y2: if False, inputs are (cx, cy, w, h) and are converted
    :return: (N,) tensor of IoU values
    """
    if x1y1x2y2:
        # Inputs are already corner coordinates.
        b1_x1, b1_y1, b1_x2, b1_y2 = (box1[:, k] for k in range(4))
        b2_x1, b2_y1, b2_x2, b2_y2 = (box2[:, k] for k in range(4))
    else:
        # Transform from center/size to corner coordinates.
        b1_x1 = box1[:, 0] - box1[:, 2] / 2
        b1_x2 = box1[:, 0] + box1[:, 2] / 2
        b1_y1 = box1[:, 1] - box1[:, 3] / 2
        b1_y2 = box1[:, 1] + box1[:, 3] / 2
        b2_x1 = box2[:, 0] - box2[:, 2] / 2
        b2_x2 = box2[:, 0] + box2[:, 2] / 2
        b2_y1 = box2[:, 1] - box2[:, 3] / 2
        b2_y2 = box2[:, 1] + box2[:, 3] / 2

    # Intersection rectangle, clamped to zero when the boxes don't overlap.
    inter_w = (torch.min(b1_x2, b2_x2) - torch.max(b1_x1, b2_x1) + 1).clamp(min=0)
    inter_h = (torch.min(b1_y2, b2_y2) - torch.max(b1_y1, b2_y1) + 1).clamp(min=0)
    inter_area = inter_w * inter_h

    # Individual areas (same +1 inclusive convention as the intersection).
    b1_area = (b1_x2 - b1_x1 + 1) * (b1_y2 - b1_y1 + 1)
    b2_area = (b2_x2 - b2_x1 + 1) * (b2_y2 - b2_y1 + 1)

    # Epsilon guards division by zero for degenerate boxes.
    return inter_area / (b1_area + b2_area - inter_area + 1e-16)
277
+
278
+
279
def box_iou(box1, box2):
    # https://github.com/pytorch/vision/blob/master/torchvision/ops/boxes.py
    """
    Return intersection-over-union (Jaccard index) of boxes.
    Both sets of boxes are expected to be in (x1, y1, x2, y2) format.
    Arguments:
        box1 (Tensor[N, 4])
        box2 (Tensor[M, 4])
    Returns:
        iou (Tensor[N, M]): the NxM matrix containing the pairwise
            IoU values for every element in boxes1 and boxes2
    """
    area1 = (box1[:, 2] - box1[:, 0]) * (box1[:, 3] - box1[:, 1])
    area2 = (box2[:, 2] - box2[:, 0]) * (box2[:, 3] - box2[:, 1])

    # Broadcast to (N, M, 2): top-left is the max of mins, bottom-right the
    # min of maxes; clamp kills negative extents for disjoint boxes.
    top_left = torch.max(box1[:, None, :2], box2[:, :2])
    bottom_right = torch.min(box1[:, None, 2:], box2[:, 2:])
    inter = (bottom_right - top_left).clamp(0).prod(2)

    return inter / (area1[:, None] + area2 - inter)
304
+
305
+
306
def non_max_suppression(prediction, conf_thres=0.25, iou_thres=0.45, classes=None):
    """Performs Non-Maximum Suppression (NMS) on inference results

    :param prediction: model output of shape (batch, boxes, 5 + num_classes),
        rows read as (cx, cy, w, h, obj_conf, class_conf...)
    :param conf_thres: objectness / class-confidence threshold
    :param iou_thres: IoU threshold for torchvision NMS
    :param classes: optional iterable of class ids to keep
    Returns:
         detections with shape: nx6 (x1, y1, x2, y2, conf, cls)
    """

    nc = prediction.shape[2] - 5  # number of classes

    # Settings
    # (pixels) minimum and maximum box width and height
    max_wh = 4096
    max_det = 300  # maximum number of detections per image
    max_nms = 30000  # maximum number of boxes into torchvision.ops.nms()
    time_limit = 1.0  # seconds to quit after
    multi_label = nc > 1  # multiple labels per box (adds 0.5ms/img)

    t = time.time()
    # One (0, 6) placeholder per image so output length always equals batch size.
    output = [torch.zeros((0, 6), device="cpu")] * prediction.shape[0]

    for xi, x in enumerate(prediction):  # image index, image inference
        # Apply constraints
        # x[((x[..., 2:4] < min_wh) | (x[..., 2:4] > max_wh)).any(1), 4] = 0  # width-height
        x = x[x[..., 4] > conf_thres]  # confidence

        # If none remain process next image
        if not x.shape[0]:
            continue

        # Compute conf
        x[:, 5:] *= x[:, 4:5]  # conf = obj_conf * cls_conf

        # Box (center x, center y, width, height) to (x1, y1, x2, y2)
        box = xywh2xyxy(x[:, :4])

        # Detections matrix nx6 (xyxy, conf, cls)
        if multi_label:
            # Keep every (box, class) pair above the threshold.
            i, j = (x[:, 5:] > conf_thres).nonzero(as_tuple=False).T
            x = torch.cat((box[i], x[i, j + 5, None], j[:, None].float()), 1)
        else:  # best class only
            conf, j = x[:, 5:].max(1, keepdim=True)
            x = torch.cat((box, conf, j.float()), 1)[conf.view(-1) > conf_thres]

        # Filter by class
        if classes is not None:
            x = x[(x[:, 5:6] == torch.tensor(classes, device=x.device)).any(1)]

        # Check shape
        n = x.shape[0]  # number of boxes
        if not n:  # no boxes
            continue
        elif n > max_nms:  # excess boxes
            # sort by confidence
            x = x[x[:, 4].argsort(descending=True)[:max_nms]]

        # Batched NMS: offsetting each box by class_id * max_wh keeps boxes of
        # different classes from suppressing one another in a single NMS call.
        c = x[:, 5:6] * max_wh  # classes
        # boxes (offset by class), scores
        boxes, scores = x[:, :4] + c, x[:, 4]
        i = torchvision.ops.nms(boxes, scores, iou_thres)  # NMS
        if i.shape[0] > max_det:  # limit detections
            i = i[:max_det]

        output[xi] = to_cpu(x[i])

        if (time.time() - t) > time_limit:
            print(f'WARNING: NMS time limit {time_limit}s exceeded')
            break  # time limit exceeded

    return output
375
+
376
+
377
def print_environment_info():
    """
    Prints infos about the environment and the system.
    This should help when people make issues containing the printout.
    """
    print("Environment information:")

    # Operating system name and release.
    print(f"System: {platform.system()} {platform.release()}")

    # Poetry-managed package version, if poetry is installed.
    try:
        version = subprocess.check_output(
            ["poetry", "version"],
            stderr=subprocess.DEVNULL).decode("ascii").strip()
        print(f"Current Version: {version}")
    except (subprocess.CalledProcessError, FileNotFoundError):
        print("Not using the poetry package")

    # Short git commit hash, if running inside a repository.
    try:
        commit = subprocess.check_output(
            ["git", "rev-parse", "--short", "HEAD"],
            stderr=subprocess.DEVNULL).decode("ascii").strip()
        print(f"Current Commit Hash: {commit}")
    except (subprocess.CalledProcessError, FileNotFoundError):
        print("No git or repo found")