sanjanatule committed
Commit 183bdff · 1 Parent(s): 4498318

Upload 4 files

Files changed (4):
  1. config.py +184 -0
  2. model.py +176 -0
  3. utils.py +613 -0
  4. yolo3_model.ckpt +3 -0
config.py ADDED
@@ -0,0 +1,184 @@
+ import albumentations as A
+ import cv2
+ import torch
+
+ from albumentations.pytorch import ToTensorV2
+ from utils import seed_everything
+
+ DATASET = 'PASCAL_VOC'
+ DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
+ # seed_everything()  # If you want deterministic behavior
+ NUM_WORKERS = 2
+ BATCH_SIZE = 16
+ IMAGE_SIZE = 416
+ NUM_CLASSES = 20
+ LEARNING_RATE = 1e-5
+ WEIGHT_DECAY = 1e-4
+ NUM_EPOCHS = 100
+ CONF_THRESHOLD = 0.05
+ MAP_IOU_THRESH = 0.5
+ NMS_IOU_THRESH = 0.45
+ S = [IMAGE_SIZE // 32, IMAGE_SIZE // 16, IMAGE_SIZE // 8]
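+ # prediction grid sizes at the three scales; [13, 26, 52] for IMAGE_SIZE = 416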
+ PIN_MEMORY = True
+ LOAD_MODEL = False
+ SAVE_MODEL = True
+ CHECKPOINT_FILE = "checkpoint.pth.tar"
+ IMG_DIR = "/content/drive/MyDrive/AI/ERA_course/session13_old/PASCAL_VOC/images/"
+ LABEL_DIR = "/content/drive/MyDrive/AI/ERA_course/session13_old/PASCAL_VOC/labels/"
+
+ ANCHORS = [
+     [(0.28, 0.22), (0.38, 0.48), (0.9, 0.78)],
+     [(0.07, 0.15), (0.15, 0.11), (0.14, 0.29)],
+     [(0.02, 0.03), (0.04, 0.07), (0.08, 0.06)],
+ ]  # Note these have been rescaled to be between [0, 1]
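+ # (width, height) pairs given as fractions of the image size, one triplet per
+ # prediction scale (coarsest 13x13 grid first); multiplied by S downstream to
+ # get anchors in cell units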
+
+ means = [0.485, 0.456, 0.406]
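+ # standard ImageNet channel means, kept for reference (the Normalize
+ # transforms below use mean 0 / std 1 instead)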
+
+ scale = 1.1
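+ # resize so the longest side is 1.1x IMAGE_SIZE, pad, then random-crop back to IMAGE_SIZE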
+ train_transforms = A.Compose(
+     [
+         A.LongestMaxSize(max_size=int(IMAGE_SIZE * scale)),
+         A.PadIfNeeded(
+             min_height=int(IMAGE_SIZE * scale),
+             min_width=int(IMAGE_SIZE * scale),
+             border_mode=cv2.BORDER_CONSTANT,
+         ),
+         A.Rotate(limit=10, interpolation=1, border_mode=4),
+         A.RandomCrop(width=IMAGE_SIZE, height=IMAGE_SIZE),
+         A.ColorJitter(brightness=0.6, contrast=0.6, saturation=0.6, hue=0.6, p=0.4),
+         A.OneOf(
+             [
+                 A.ShiftScaleRotate(
+                     rotate_limit=20, p=0.5, border_mode=cv2.BORDER_CONSTANT
+                 ),
+                 # A.Affine(shear=15, p=0.5, mode="constant"),
+             ],
+             p=1.0,
+         ),
+         A.HorizontalFlip(p=0.5),
+         A.Blur(p=0.1),
+         A.CLAHE(p=0.1),
+         A.Posterize(p=0.1),
+         A.ToGray(p=0.1),
+         A.ChannelShuffle(p=0.05),
+         A.Normalize(mean=[0, 0, 0], std=[1, 1, 1], max_pixel_value=255),
+         ToTensorV2(),
+     ],
+     bbox_params=A.BboxParams(format="yolo", min_visibility=0.4, label_fields=[]),
+ )
+ test_transforms = A.Compose(
+     [
+         A.LongestMaxSize(max_size=IMAGE_SIZE),
+         A.PadIfNeeded(
+             min_height=IMAGE_SIZE, min_width=IMAGE_SIZE, border_mode=cv2.BORDER_CONSTANT
+         ),
+         A.Normalize(mean=[0, 0, 0], std=[1, 1, 1], max_pixel_value=255),
+         ToTensorV2(),
+     ],
+     bbox_params=A.BboxParams(format="yolo", min_visibility=0.4, label_fields=[]),
+ )
+
+ PASCAL_CLASSES = [
+     "aeroplane", "bicycle", "bird", "boat", "bottle",
+     "bus", "car", "cat", "chair", "cow",
+     "diningtable", "dog", "horse", "motorbike", "person",
+     "pottedplant", "sheep", "sofa", "train", "tvmonitor",
+ ]
+
+ COCO_LABELS = [
+     'person', 'bicycle', 'car', 'motorcycle', 'airplane', 'bus', 'train', 'truck',
+     'boat', 'traffic light', 'fire hydrant', 'stop sign', 'parking meter', 'bench', 'bird', 'cat',
+     'dog', 'horse', 'sheep', 'cow', 'elephant', 'bear', 'zebra', 'giraffe',
+     'backpack', 'umbrella', 'handbag', 'tie', 'suitcase', 'frisbee', 'skis', 'snowboard',
+     'sports ball', 'kite', 'baseball bat', 'baseball glove', 'skateboard', 'surfboard', 'tennis racket', 'bottle',
+     'wine glass', 'cup', 'fork', 'knife', 'spoon', 'bowl', 'banana', 'apple',
+     'sandwich', 'orange', 'broccoli', 'carrot', 'hot dog', 'pizza', 'donut', 'cake',
+     'chair', 'couch', 'potted plant', 'bed', 'dining table', 'toilet', 'tv', 'laptop',
+     'mouse', 'remote', 'keyboard', 'cell phone', 'microwave', 'oven', 'toaster', 'sink',
+     'refrigerator', 'book', 'clock', 'vase', 'scissors', 'teddy bear', 'hair drier', 'toothbrush',
+ ]
model.py ADDED
@@ -0,0 +1,176 @@
+ """
+ Implementation of the YOLOv3 architecture
+ """
+
+ import torch
+ import torch.nn as nn
+
+ """
+ Information about the architecture config:
+ A tuple is structured as (filters, kernel_size, stride).
+ Every conv is a "same" convolution.
+ A list ["B", n] indicates a residual block repeated n times.
+ "S" is a scale prediction block, where the YOLO loss is computed.
+ "U" upsamples the feature map and concatenates it with a previous layer.
+ """
+ config = [
+     (32, 3, 1),
+     (64, 3, 2),
+     ["B", 1],
+     (128, 3, 2),
+     ["B", 2],
+     (256, 3, 2),
+     ["B", 8],
+     (512, 3, 2),
+     ["B", 8],
+     (1024, 3, 2),
+     ["B", 4],  # To this point is Darknet-53
+     (512, 1, 1),
+     (1024, 3, 1),
+     "S",
+     (256, 1, 1),
+     "U",
+     (256, 1, 1),
+     (512, 3, 1),
+     "S",
+     (128, 1, 1),
+     "U",
+     (128, 1, 1),
+     (256, 3, 1),
+     "S",
+ ]
+
+
+ class CNNBlock(nn.Module):
+     def __init__(self, in_channels, out_channels, bn_act=True, **kwargs):
+         super().__init__()
+         self.conv = nn.Conv2d(in_channels, out_channels, bias=not bn_act, **kwargs)
+         self.bn = nn.BatchNorm2d(out_channels)
+         self.leaky = nn.LeakyReLU(0.1)
+         self.use_bn_act = bn_act
+
+     def forward(self, x):
+         if self.use_bn_act:
+             return self.leaky(self.bn(self.conv(x)))
+         else:
+             return self.conv(x)
+
+
+ class ResidualBlock(nn.Module):
+     def __init__(self, channels, use_residual=True, num_repeats=1):
+         super().__init__()
+         self.layers = nn.ModuleList()
+         for _ in range(num_repeats):
+             self.layers += [
+                 nn.Sequential(
+                     CNNBlock(channels, channels // 2, kernel_size=1),
+                     CNNBlock(channels // 2, channels, kernel_size=3, padding=1),
+                 )
+             ]
+
+         self.use_residual = use_residual
+         self.num_repeats = num_repeats
+
+     def forward(self, x):
+         for layer in self.layers:
+             if self.use_residual:
+                 x = x + layer(x)
+             else:
+                 x = layer(x)
+
+         return x
+
+
+ class ScalePrediction(nn.Module):
+     def __init__(self, in_channels, num_classes):
+         super().__init__()
+         self.pred = nn.Sequential(
+             CNNBlock(in_channels, 2 * in_channels, kernel_size=3, padding=1),
+             CNNBlock(
+                 2 * in_channels, (num_classes + 5) * 3, bn_act=False, kernel_size=1
+             ),
+         )
+         self.num_classes = num_classes
+
+     def forward(self, x):
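+         # reshape (N, 3 * (5 + C), S, S) -> (N, 3, S, S, 5 + C)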
+         return (
+             self.pred(x)
+             .reshape(x.shape[0], 3, self.num_classes + 5, x.shape[2], x.shape[3])
+             .permute(0, 1, 3, 4, 2)
+         )
+
+
+ class YOLOv3(nn.Module):
+     def __init__(self, in_channels=3, num_classes=80):
+         super().__init__()
+         self.num_classes = num_classes
+         self.in_channels = in_channels
+         self.layers = self._create_conv_layers()
+
+     def forward(self, x):
+         outputs = []  # for each scale
+         route_connections = []
+         for layer in self.layers:
+             if isinstance(layer, ScalePrediction):
+                 outputs.append(layer(x))
+                 continue
+
+             x = layer(x)
+
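+             # the outputs of the two 8-repeat residual blocks are saved as
+             # route connections for the upsampling path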
+             if isinstance(layer, ResidualBlock) and layer.num_repeats == 8:
+                 route_connections.append(x)
+
+             elif isinstance(layer, nn.Upsample):
+                 x = torch.cat([x, route_connections[-1]], dim=1)
+                 route_connections.pop()
+
+         return outputs
+
+     def _create_conv_layers(self):
+         layers = nn.ModuleList()
+         in_channels = self.in_channels
+
+         for module in config:
+             if isinstance(module, tuple):
+                 out_channels, kernel_size, stride = module
+                 layers.append(
+                     CNNBlock(
+                         in_channels,
+                         out_channels,
+                         kernel_size=kernel_size,
+                         stride=stride,
+                         padding=1 if kernel_size == 3 else 0,
+                     )
+                 )
+                 in_channels = out_channels
+
+             elif isinstance(module, list):
+                 num_repeats = module[1]
+                 layers.append(ResidualBlock(in_channels, num_repeats=num_repeats))
+
+             elif isinstance(module, str):
+                 if module == "S":
+                     layers += [
+                         ResidualBlock(in_channels, use_residual=False, num_repeats=1),
+                         CNNBlock(in_channels, in_channels // 2, kernel_size=1),
+                         ScalePrediction(in_channels // 2, num_classes=self.num_classes),
+                     ]
+                     in_channels = in_channels // 2
+
+                 elif module == "U":
+                     layers.append(nn.Upsample(scale_factor=2))
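+                     # forward() concatenates a route connection carrying 2x
+                     # the current channels after each upsample, hence x3 below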
162
+ in_channels = in_channels * 3
163
+
164
+ return layers
165
+
166
+
167
+ if __name__ == "__main__":
168
+ num_classes = 20
169
+ IMAGE_SIZE = 416
170
+ model = YOLOv3(num_classes=num_classes)
171
+ x = torch.randn((2, 3, IMAGE_SIZE, IMAGE_SIZE))
172
+ out = model(x)
173
+ assert model(x)[0].shape == (2, 3, IMAGE_SIZE//32, IMAGE_SIZE//32, num_classes + 5)
174
+ assert model(x)[1].shape == (2, 3, IMAGE_SIZE//16, IMAGE_SIZE//16, num_classes + 5)
175
+ assert model(x)[2].shape == (2, 3, IMAGE_SIZE//8, IMAGE_SIZE//8, num_classes + 5)
176
+ print("Success!")
utils.py ADDED
@@ -0,0 +1,613 @@
+ import config
+ import matplotlib.pyplot as plt
+ import matplotlib.patches as patches
+ import numpy as np
+ import os
+ import random
+ import torch
+
+ from collections import Counter
+ from torch.utils.data import DataLoader
+ from tqdm import tqdm
+
+ def accuracy_fn(y, out, threshold):
+     """Class, object, and no-object accuracy after each epoch."""
+
+     tot_class_preds, correct_class = 0, 0
+     tot_noobj, correct_noobj = 0, 0
+     tot_obj, correct_obj = 0, 0
+
+     for i in range(3):
+         y[i] = y[i].to(config.DEVICE)
+         obj = y[i][..., 0] == 1  # in paper this is Iobj_i
+         noobj = y[i][..., 0] == 0  # in paper this is Inoobj_i
+
+         correct_class += torch.sum(
+             torch.argmax(out[i][..., 5:][obj], dim=-1) == y[i][..., 5][obj]
+         )
+         tot_class_preds += torch.sum(obj)
+
+         obj_preds = torch.sigmoid(out[i][..., 0]) > threshold
+         correct_obj += torch.sum(obj_preds[obj] == y[i][..., 0][obj])
+         tot_obj += torch.sum(obj)
+         correct_noobj += torch.sum(obj_preds[noobj] == y[i][..., 0][noobj])
+         tot_noobj += torch.sum(noobj)
+
+     class_accuracy = (correct_class / (tot_class_preds + 1e-16)) * 100
+     no_obj_accuracy = (correct_noobj / (tot_noobj + 1e-16)) * 100
+     obj_accuracy = (correct_obj / (tot_obj + 1e-16)) * 100
+     print(f"Class accuracy is: {class_accuracy:.2f}%")
+     print(f"No obj accuracy is: {no_obj_accuracy:.2f}%")
+     print(f"Obj accuracy is: {obj_accuracy:.2f}%")
+
+     return class_accuracy, no_obj_accuracy, obj_accuracy
+
+ def iou_width_height(boxes1, boxes2):
+     """
+     Parameters:
+         boxes1 (tensor): width and height of the first bounding boxes
+         boxes2 (tensor): width and height of the second bounding boxes
+     Returns:
+         tensor: Intersection over union of the corresponding boxes
+     """
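+     # compares shapes only, as if the boxes shared a center (typically used
+     # for matching ground-truth boxes to anchors)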
+     intersection = torch.min(boxes1[..., 0], boxes2[..., 0]) * torch.min(
+         boxes1[..., 1], boxes2[..., 1]
+     )
+     union = (
+         boxes1[..., 0] * boxes1[..., 1] + boxes2[..., 0] * boxes2[..., 1] - intersection
+     )
+     return intersection / union
+
+
+ def intersection_over_union(boxes_preds, boxes_labels, box_format="midpoint"):
+     """
+     Video explanation of this function:
+     https://youtu.be/XXYG5ZWtjj0
+
+     This function calculates intersection over union (IoU) given pred boxes
+     and target boxes.
+
+     Parameters:
+         boxes_preds (tensor): Predictions of Bounding Boxes (BATCH_SIZE, 4)
+         boxes_labels (tensor): Correct labels of Bounding Boxes (BATCH_SIZE, 4)
+         box_format (str): "midpoint" if boxes are (x, y, w, h), "corners" if (x1, y1, x2, y2)
+
+     Returns:
+         tensor: Intersection over union for all examples
+     """
+
+     if box_format == "midpoint":
+         box1_x1 = boxes_preds[..., 0:1] - boxes_preds[..., 2:3] / 2
+         box1_y1 = boxes_preds[..., 1:2] - boxes_preds[..., 3:4] / 2
+         box1_x2 = boxes_preds[..., 0:1] + boxes_preds[..., 2:3] / 2
+         box1_y2 = boxes_preds[..., 1:2] + boxes_preds[..., 3:4] / 2
+         box2_x1 = boxes_labels[..., 0:1] - boxes_labels[..., 2:3] / 2
+         box2_y1 = boxes_labels[..., 1:2] - boxes_labels[..., 3:4] / 2
+         box2_x2 = boxes_labels[..., 0:1] + boxes_labels[..., 2:3] / 2
+         box2_y2 = boxes_labels[..., 1:2] + boxes_labels[..., 3:4] / 2
+
+     elif box_format == "corners":
+         box1_x1 = boxes_preds[..., 0:1]
+         box1_y1 = boxes_preds[..., 1:2]
+         box1_x2 = boxes_preds[..., 2:3]
+         box1_y2 = boxes_preds[..., 3:4]
+         box2_x1 = boxes_labels[..., 0:1]
+         box2_y1 = boxes_labels[..., 1:2]
+         box2_x2 = boxes_labels[..., 2:3]
+         box2_y2 = boxes_labels[..., 3:4]
+
+     x1 = torch.max(box1_x1, box2_x1)
+     y1 = torch.max(box1_y1, box2_y1)
+     x2 = torch.min(box1_x2, box2_x2)
+     y2 = torch.min(box1_y2, box2_y2)
+
+     intersection = (x2 - x1).clamp(0) * (y2 - y1).clamp(0)
+     box1_area = abs((box1_x2 - box1_x1) * (box1_y2 - box1_y1))
+     box2_area = abs((box2_x2 - box2_x1) * (box2_y2 - box2_y1))
+
+     return intersection / (box1_area + box2_area - intersection + 1e-6)
+
+
+ def non_max_suppression(bboxes, iou_threshold, threshold, box_format="corners"):
+     """
+     Video explanation of this function:
+     https://youtu.be/YDkjWEN8jNA
+
+     Performs Non-Max Suppression on the given bboxes.
+
+     Parameters:
+         bboxes (list): list of lists containing all bboxes, each specified
+             as [class_pred, prob_score, x1, y1, x2, y2]
+         iou_threshold (float): IoU above which overlapping boxes of the same class are suppressed
+         threshold (float): confidence threshold to remove predicted bboxes (independent of IoU)
+         box_format (str): "midpoint" or "corners" used to specify bboxes
+
+     Returns:
+         list: bboxes after performing NMS given a specific IoU threshold
+     """
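+     # NMS is applied per class: a chosen box only suppresses boxes with the
+     # same class prediction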
+
+     assert type(bboxes) == list
+
+     bboxes = [box for box in bboxes if box[1] > threshold]
+     bboxes = sorted(bboxes, key=lambda x: x[1], reverse=True)
+     bboxes_after_nms = []
+
+     while bboxes:
+         chosen_box = bboxes.pop(0)
+
+         bboxes = [
+             box
+             for box in bboxes
+             if box[0] != chosen_box[0]
+             or intersection_over_union(
+                 torch.tensor(chosen_box[2:]),
+                 torch.tensor(box[2:]),
+                 box_format=box_format,
+             )
+             < iou_threshold
+         ]
+
+         bboxes_after_nms.append(chosen_box)
+
+     return bboxes_after_nms
+
+
+ def mean_average_precision(
+     pred_boxes, true_boxes, iou_threshold=0.5, box_format="midpoint", num_classes=20
+ ):
+     """
+     Video explanation of this function:
+     https://youtu.be/FppOzcDvaDI
+
+     This function calculates mean average precision (mAP)
+
+     Parameters:
+         pred_boxes (list): list of lists containing all bboxes, each specified
+             as [train_idx, class_prediction, prob_score, x1, y1, x2, y2]
+         true_boxes (list): same format as pred_boxes, but for the ground truth boxes
+         iou_threshold (float): IoU threshold at which a predicted bbox counts as correct
+         box_format (str): "midpoint" or "corners" used to specify bboxes
+         num_classes (int): number of classes
+
+     Returns:
+         float: mAP value across all classes given a specific IoU threshold
+     """
+
+     # list storing all AP for respective classes
+     average_precisions = []
+
+     # used for numerical stability later on
+     epsilon = 1e-6
+
+     for c in range(num_classes):
+         detections = []
+         ground_truths = []
+
+         # Go through all predictions and targets,
+         # and only add the ones that belong to the
+         # current class c
+         for detection in pred_boxes:
+             if detection[1] == c:
+                 detections.append(detection)
+
+         for true_box in true_boxes:
+             if true_box[1] == c:
+                 ground_truths.append(true_box)
+
+         # find the amount of bboxes for each training example
+         # Counter here finds how many ground truth bboxes we get
+         # for each training example, so let's say img 0 has 3,
+         # img 1 has 5 then we will obtain a dictionary with:
+         # amount_bboxes = {0:3, 1:5}
+         amount_bboxes = Counter([gt[0] for gt in ground_truths])
+
+         # We then go through each key, val in this dictionary
+         # and convert to the following (w.r.t same example):
+         # amount_bboxes = {0: torch.tensor([0,0,0]), 1: torch.tensor([0,0,0,0,0])}
+         for key, val in amount_bboxes.items():
+             amount_bboxes[key] = torch.zeros(val)
+
+         # sort by box probabilities which is index 2
+         detections.sort(key=lambda x: x[2], reverse=True)
+         TP = torch.zeros((len(detections)))
+         FP = torch.zeros((len(detections)))
+         total_true_bboxes = len(ground_truths)
+
+         # If none exists for this class then we can safely skip
+         if total_true_bboxes == 0:
+             continue
+
+         for detection_idx, detection in enumerate(detections):
+             # Only take out the ground_truths that have the same
+             # training idx as detection
+             ground_truth_img = [
+                 bbox for bbox in ground_truths if bbox[0] == detection[0]
+             ]
+
+             num_gts = len(ground_truth_img)
+             best_iou = 0
+
+             for idx, gt in enumerate(ground_truth_img):
+                 iou = intersection_over_union(
+                     torch.tensor(detection[3:]),
+                     torch.tensor(gt[3:]),
+                     box_format=box_format,
+                 )
+
+                 if iou > best_iou:
+                     best_iou = iou
+                     best_gt_idx = idx
+
+             if best_iou > iou_threshold:
+                 # only detect ground truth detection once
+                 if amount_bboxes[detection[0]][best_gt_idx] == 0:
+                     # true positive and add this bounding box to seen
+                     TP[detection_idx] = 1
+                     amount_bboxes[detection[0]][best_gt_idx] = 1
+                 else:
+                     FP[detection_idx] = 1
+
+             # if IoU is lower then the detection is a false positive
+             else:
+                 FP[detection_idx] = 1
+
+         TP_cumsum = torch.cumsum(TP, dim=0)
+         FP_cumsum = torch.cumsum(FP, dim=0)
+         recalls = TP_cumsum / (total_true_bboxes + epsilon)
+         precisions = TP_cumsum / (TP_cumsum + FP_cumsum + epsilon)
+         precisions = torch.cat((torch.tensor([1]), precisions))
+         recalls = torch.cat((torch.tensor([0]), recalls))
+         # torch.trapz for numerical integration
+         average_precisions.append(torch.trapz(precisions, recalls))
+
+     return sum(average_precisions) / len(average_precisions)
+
+
+ def plot_image(image, boxes):
+     """Plots predicted bounding boxes on the image"""
+     cmap = plt.get_cmap("tab20b")
+     class_labels = config.COCO_LABELS if config.DATASET == 'COCO' else config.PASCAL_CLASSES
+     colors = [cmap(i) for i in np.linspace(0, 1, len(class_labels))]
+     im = np.array(image)
+     height, width, _ = im.shape
+
+     # Create figure and axes
+     fig, ax = plt.subplots(1)
+     # Display the image
+     ax.imshow(im)
+
+     # box[0] is x midpoint, box[2] is width
+     # box[1] is y midpoint, box[3] is height
+
+     # Create a Rectangle patch
+     for box in boxes:
+         assert len(box) == 6, "box should contain class pred, confidence, x, y, width, height"
+         class_pred = box[0]
+         box = box[2:]
+         upper_left_x = box[0] - box[2] / 2
+         upper_left_y = box[1] - box[3] / 2
+         rect = patches.Rectangle(
+             (upper_left_x * width, upper_left_y * height),
+             box[2] * width,
+             box[3] * height,
+             linewidth=2,
+             edgecolor=colors[int(class_pred)],
+             facecolor="none",
+         )
+         # Add the patch to the Axes
+         ax.add_patch(rect)
+         plt.text(
+             upper_left_x * width,
+             upper_left_y * height,
+             s=class_labels[int(class_pred)],
+             color="white",
+             verticalalignment="top",
+             bbox={"color": colors[int(class_pred)], "pad": 0},
+         )
+
+     plt.show()
+
+
+ def get_evaluation_bboxes(
+     loader,
+     model,
+     iou_threshold,
+     anchors,
+     threshold,
+     box_format="midpoint",
+     device="cuda",
+ ):
+     # make sure the model is in eval mode before getting bboxes
+     model.eval()
+     train_idx = 0
+     all_pred_boxes = []
+     all_true_boxes = []
+     for batch_idx, (x, labels) in enumerate(tqdm(loader)):
+         x = x.to(device)
+
+         with torch.no_grad():
+             predictions = model(x)
+
+         batch_size = x.shape[0]
+         bboxes = [[] for _ in range(batch_size)]
+         for i in range(3):
+             S = predictions[i].shape[2]
+             anchor = torch.tensor([*anchors[i]]).to(device) * S
+             boxes_scale_i = cells_to_bboxes(
+                 predictions[i], anchor, S=S, is_preds=True
+             )
+             for idx, box in enumerate(boxes_scale_i):
+                 bboxes[idx] += box
+
+         # we just want one bbox for each label, not one for each scale
+         true_bboxes = cells_to_bboxes(
+             labels[2], anchor, S=S, is_preds=False
+         )
+
+         for idx in range(batch_size):
+             nms_boxes = non_max_suppression(
+                 bboxes[idx],
+                 iou_threshold=iou_threshold,
+                 threshold=threshold,
+                 box_format=box_format,
+             )
+
+             for nms_box in nms_boxes:
+                 all_pred_boxes.append([train_idx] + nms_box)
+
+             for box in true_bboxes[idx]:
+                 if box[1] > threshold:
+                     all_true_boxes.append([train_idx] + box)
+
+             train_idx += 1
+
+     model.train()
+     return all_pred_boxes, all_true_boxes
+
+
+ def cells_to_bboxes(predictions, anchors, S, is_preds=True):
+     """
+     Scales the predictions coming from the model to be relative to the entire
+     image, so that they can for example be plotted or evaluated later.
+     INPUT:
+         predictions: tensor of size (N, 3, S, S, num_classes+5)
+         anchors: the anchors used for the predictions
+         S: the number of cells the image is divided into along the width (and height)
+         is_preds: whether the input is predictions or the true bounding boxes
+     OUTPUT:
+         converted_bboxes: the converted boxes of size (N, num_anchors * S * S, 6)
+             with class index, object score, and bounding box coordinates
+     """
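+     # each returned entry is [class, score, x, y, w, h], coordinates normalized to [0, 1]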
+     BATCH_SIZE = predictions.shape[0]
+     num_anchors = len(anchors)
+     box_predictions = predictions[..., 1:5]
+     if is_preds:
+         anchors = anchors.reshape(1, len(anchors), 1, 1, 2)
+         box_predictions[..., 0:2] = torch.sigmoid(box_predictions[..., 0:2])
+         box_predictions[..., 2:] = torch.exp(box_predictions[..., 2:]) * anchors
+         scores = torch.sigmoid(predictions[..., 0:1])
+         best_class = torch.argmax(predictions[..., 5:], dim=-1).unsqueeze(-1)
+     else:
+         scores = predictions[..., 0:1]
+         best_class = predictions[..., 5:6]
+
+     cell_indices = (
+         torch.arange(S)
+         .repeat(predictions.shape[0], 3, S, 1)
+         .unsqueeze(-1)
+         .to(predictions.device)
+     )
+     x = 1 / S * (box_predictions[..., 0:1] + cell_indices)
+     y = 1 / S * (box_predictions[..., 1:2] + cell_indices.permute(0, 1, 3, 2, 4))
+     w_h = 1 / S * box_predictions[..., 2:4]
+     converted_bboxes = torch.cat((best_class, scores, x, y, w_h), dim=-1).reshape(BATCH_SIZE, num_anchors * S * S, 6)
+     return converted_bboxes.tolist()
+
+ def check_class_accuracy(model, loader, threshold):
+     model.eval()
+     tot_class_preds, correct_class = 0, 0
+     tot_noobj, correct_noobj = 0, 0
+     tot_obj, correct_obj = 0, 0
+
+     for idx, (x, y) in enumerate(tqdm(loader)):
+         x = x.to(config.DEVICE)
+         with torch.no_grad():
+             out = model(x)
+
+         for i in range(3):
+             y[i] = y[i].to(config.DEVICE)
+             obj = y[i][..., 0] == 1  # in paper this is Iobj_i
+             noobj = y[i][..., 0] == 0  # in paper this is Inoobj_i
+
+             correct_class += torch.sum(
+                 torch.argmax(out[i][..., 5:][obj], dim=-1) == y[i][..., 5][obj]
+             )
+             tot_class_preds += torch.sum(obj)
+
+             obj_preds = torch.sigmoid(out[i][..., 0]) > threshold
+             correct_obj += torch.sum(obj_preds[obj] == y[i][..., 0][obj])
+             tot_obj += torch.sum(obj)
+             correct_noobj += torch.sum(obj_preds[noobj] == y[i][..., 0][noobj])
+             tot_noobj += torch.sum(noobj)
+
+     print(f"Class accuracy is: {(correct_class/(tot_class_preds+1e-16))*100:.2f}%")
+     print(f"No obj accuracy is: {(correct_noobj/(tot_noobj+1e-16))*100:.2f}%")
+     print(f"Obj accuracy is: {(correct_obj/(tot_obj+1e-16))*100:.2f}%")
+     model.train()
+
+
+ def get_mean_std(loader):
+     # var[X] = E[X**2] - E[X]**2
+     channels_sum, channels_sqrd_sum, num_batches = 0, 0, 0
+
+     for data, _ in tqdm(loader):
+         channels_sum += torch.mean(data, dim=[0, 2, 3])
+         channels_sqrd_sum += torch.mean(data ** 2, dim=[0, 2, 3])
+         num_batches += 1
+
+     mean = channels_sum / num_batches
+     std = (channels_sqrd_sum / num_batches - mean ** 2) ** 0.5
+
+     return mean, std
+
+
+ def save_checkpoint(model, optimizer, filename="my_checkpoint.pth.tar"):
+     print("=> Saving checkpoint")
+     checkpoint = {
+         "state_dict": model.state_dict(),
+         "optimizer": optimizer.state_dict(),
+     }
+     torch.save(checkpoint, filename)
+
+
+ def load_checkpoint(checkpoint_file, model, optimizer, lr):
+     print("=> Loading checkpoint")
+     checkpoint = torch.load(checkpoint_file, map_location=config.DEVICE)
+     model.load_state_dict(checkpoint["state_dict"])
+     optimizer.load_state_dict(checkpoint["optimizer"])
+
+     # If we don't do this then the optimizer will just keep the learning rate
+     # of the old checkpoint, which can lead to many hours of debugging
+     for param_group in optimizer.param_groups:
+         param_group["lr"] = lr
474
+
475
+
476
+ def get_loaders(train_csv_path, test_csv_path):
477
+ from dataset import YOLODataset
478
+
479
+ IMAGE_SIZE = config.IMAGE_SIZE
480
+ train_dataset = YOLODataset(
481
+ train_csv_path,
482
+ transform=config.train_transforms,
483
+ S=[IMAGE_SIZE // 32, IMAGE_SIZE // 16, IMAGE_SIZE // 8],
484
+ img_dir=config.IMG_DIR,
485
+ label_dir=config.LABEL_DIR,
486
+ anchors=config.ANCHORS,
487
+ )
488
+ test_dataset = YOLODataset(
489
+ test_csv_path,
490
+ transform=config.test_transforms,
491
+ S=[IMAGE_SIZE // 32, IMAGE_SIZE // 16, IMAGE_SIZE // 8],
492
+ img_dir=config.IMG_DIR,
493
+ label_dir=config.LABEL_DIR,
494
+ anchors=config.ANCHORS,
495
+ )
496
+ train_loader = DataLoader(
497
+ dataset=train_dataset,
498
+ batch_size=config.BATCH_SIZE,
499
+ num_workers=config.NUM_WORKERS,
500
+ pin_memory=config.PIN_MEMORY,
501
+ shuffle=True,
502
+ drop_last=False,
503
+ )
504
+ test_loader = DataLoader(
505
+ dataset=test_dataset,
506
+ batch_size=config.BATCH_SIZE,
507
+ num_workers=config.NUM_WORKERS,
508
+ pin_memory=config.PIN_MEMORY,
509
+ shuffle=False,
510
+ drop_last=False,
511
+ )
512
+
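+     # evaluation view of the training set: same images, but with test-time
+     # transforms and no shuffling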
+     train_eval_dataset = YOLODataset(
+         train_csv_path,
+         transform=config.test_transforms,
+         S=[IMAGE_SIZE // 32, IMAGE_SIZE // 16, IMAGE_SIZE // 8],
+         img_dir=config.IMG_DIR,
+         label_dir=config.LABEL_DIR,
+         anchors=config.ANCHORS,
+     )
+     train_eval_loader = DataLoader(
+         dataset=train_eval_dataset,
+         batch_size=config.BATCH_SIZE,
+         num_workers=config.NUM_WORKERS,
+         pin_memory=config.PIN_MEMORY,
+         shuffle=False,
+         drop_last=False,
+     )
+
+     return train_loader, test_loader, train_eval_loader
+
+ def plot_couple_examples(model, loader, thresh, iou_thresh, anchors):
+     model.eval()
+     x, y = next(iter(loader))
+     x = x.to(config.DEVICE)
+     with torch.no_grad():
+         out = model(x)
+         bboxes = [[] for _ in range(x.shape[0])]
+         for i in range(3):
+             batch_size, A, S, _, _ = out[i].shape
+             anchor = anchors[i]
+             boxes_scale_i = cells_to_bboxes(
+                 out[i], anchor, S=S, is_preds=True
+             )
+             for idx, box in enumerate(boxes_scale_i):
+                 bboxes[idx] += box
+
+         model.train()
+
+     for i in range(batch_size // 4):
+         nms_boxes = non_max_suppression(
+             bboxes[i], iou_threshold=iou_thresh, threshold=thresh, box_format="midpoint",
+         )
+         plot_image(x[i].permute(1, 2, 0).detach().cpu(), nms_boxes)
+
+
+ def seed_everything(seed=42):
+     os.environ['PYTHONHASHSEED'] = str(seed)
+     random.seed(seed)
+     np.random.seed(seed)
+     torch.manual_seed(seed)
+     torch.cuda.manual_seed(seed)
+     torch.cuda.manual_seed_all(seed)
+     torch.backends.cudnn.deterministic = True
+     torch.backends.cudnn.benchmark = False
567
+
568
+
569
+ def clip_coords(boxes, img_shape):
570
+ # Clip bounding xyxy bounding boxes to image shape (height, width)
571
+ boxes[:, 0].clamp_(0, img_shape[1]) # x1
572
+ boxes[:, 1].clamp_(0, img_shape[0]) # y1
573
+ boxes[:, 2].clamp_(0, img_shape[1]) # x2
574
+ boxes[:, 3].clamp_(0, img_shape[0]) # y2
575
+
+ def xywhn2xyxy(x, w=640, h=640, padw=0, padh=0):
+     # Convert nx4 boxes from normalized [x, y, w, h] to [x1, y1, x2, y2], where xy1=top-left, xy2=bottom-right
+     y = x.clone() if isinstance(x, torch.Tensor) else np.copy(x)
+     y[..., 0] = w * (x[..., 0] - x[..., 2] / 2) + padw  # top left x
+     y[..., 1] = h * (x[..., 1] - x[..., 3] / 2) + padh  # top left y
+     y[..., 2] = w * (x[..., 0] + x[..., 2] / 2) + padw  # bottom right x
+     y[..., 3] = h * (x[..., 1] + x[..., 3] / 2) + padh  # bottom right y
+     return y
+
+
+ def xyn2xy(x, w=640, h=640, padw=0, padh=0):
+     # Convert normalized segments into pixel segments, shape (n,2)
+     y = x.clone() if isinstance(x, torch.Tensor) else np.copy(x)
+     y[..., 0] = w * x[..., 0] + padw  # top left x
+     y[..., 1] = h * x[..., 1] + padh  # top left y
+     return y
+
+ def xyxy2xywhn(x, w=640, h=640, clip=False, eps=0.0):
+     # Convert nx4 boxes from [x1, y1, x2, y2] to normalized [x, y, w, h], where xy1=top-left, xy2=bottom-right
+     if clip:
+         clip_boxes(x, (h - eps, w - eps))  # warning: inplace clip
+     y = x.clone() if isinstance(x, torch.Tensor) else np.copy(x)
+     y[..., 0] = ((x[..., 0] + x[..., 2]) / 2) / w  # x center
+     y[..., 1] = ((x[..., 1] + x[..., 3]) / 2) / h  # y center
+     y[..., 2] = (x[..., 2] - x[..., 0]) / w  # width
+     y[..., 3] = (x[..., 3] - x[..., 1]) / h  # height
+     return y
+
+ def clip_boxes(boxes, shape):
+     # Clip boxes (xyxy) to image shape (height, width)
+     if isinstance(boxes, torch.Tensor):  # faster individually
+         boxes[..., 0].clamp_(0, shape[1])  # x1
+         boxes[..., 1].clamp_(0, shape[0])  # y1
+         boxes[..., 2].clamp_(0, shape[1])  # x2
+         boxes[..., 3].clamp_(0, shape[0])  # y2
+     else:  # np.array (faster grouped)
+         boxes[..., [0, 2]] = boxes[..., [0, 2]].clip(0, shape[1])  # x1, x2
+         boxes[..., [1, 3]] = boxes[..., [1, 3]].clip(0, shape[0])  # y1, y2
yolo3_model.ckpt ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:2ea2f10d11fdabfa7e77ee96841a0990ee53c5e32f32023e463e38e80413ac1b
+ size 740092921