PolarisFTL committed (verified)
Commit 5b93456 · Parent(s): 0c8e2d1

Upload 7 files

utils/__init__.py ADDED
File without changes
utils/callbacks.py ADDED
@@ -0,0 +1,202 @@
+ import datetime
+ import os
+
+ import torch
+ import matplotlib
+ matplotlib.use('Agg')
+ import scipy.signal
+ from matplotlib import pyplot as plt
+ from torch.utils.tensorboard import SummaryWriter
+
+ import shutil
+ import numpy as np
+
+ from PIL import Image
+ from tqdm import tqdm
+ from .utils import cvtColor, preprocess_input, resize_image
+ from .utils_bbox import DecodeBox
+ from .utils_map import get_coco_map, get_map
+
+
+ class LossHistory():
+     def __init__(self, log_dir, model, input_shape):
+         self.log_dir = log_dir
+         self.losses = []
+         self.val_loss = []
+
+         os.makedirs(self.log_dir)
+         self.writer = SummaryWriter(self.log_dir)
+         try:
+             dummy_input = torch.randn(2, 3, input_shape[0], input_shape[1])
+             self.writer.add_graph(model, dummy_input)
+         except:
+             pass
+
+     def append_loss(self, epoch, loss, val_loss):
+         if not os.path.exists(self.log_dir):
+             os.makedirs(self.log_dir)
+
+         self.losses.append(loss)
+         self.val_loss.append(val_loss)
+
+         with open(os.path.join(self.log_dir, "epoch_loss.txt"), 'a') as f:
+             f.write(str(loss))
+             f.write("\n")
+         with open(os.path.join(self.log_dir, "epoch_val_loss.txt"), 'a') as f:
+             f.write(str(val_loss))
+             f.write("\n")
+
+         self.writer.add_scalar('loss', loss, epoch)
+         self.writer.add_scalar('val_loss', val_loss, epoch)
+         self.loss_plot()
+
+     def loss_plot(self):
+         iters = range(len(self.losses))
+
+         plt.figure()
+         plt.plot(iters, self.losses, 'red', linewidth = 2, label='train loss')
+         plt.plot(iters, self.val_loss, 'coral', linewidth = 2, label='val loss')
+         try:
+             if len(self.losses) < 25:
+                 num = 5
+             else:
+                 num = 15
+
+             plt.plot(iters, scipy.signal.savgol_filter(self.losses, num, 3), 'green', linestyle = '--', linewidth = 2, label='smooth train loss')
+         except:
+             pass
+
+         plt.grid(True)
+         plt.xlabel('Epoch')
+         plt.ylabel('Loss')
+         plt.legend(loc="upper right")
+
+         plt.savefig(os.path.join(self.log_dir, "epoch_loss.png"))
+
+         plt.cla()
+         plt.close("all")
+
+ class EvalCallback():
+     def __init__(self, net, input_shape, anchors, anchors_mask, class_names, num_classes, val_lines, log_dir, cuda, \
+             map_out_path=".temp_map_out", max_boxes=100, confidence=0.05, nms_iou=0.5, letterbox_image=True, MINOVERLAP=0.5, eval_flag=True, period=1):
+         super(EvalCallback, self).__init__()
+
+         self.net = net
+         self.input_shape = input_shape
+         self.anchors = anchors
+         self.anchors_mask = anchors_mask
+         self.class_names = class_names
+         self.num_classes = num_classes
+         self.val_lines = val_lines
+         self.log_dir = log_dir
+         self.cuda = cuda
+         self.map_out_path = map_out_path
+         self.max_boxes = max_boxes
+         self.confidence = confidence
+         self.nms_iou = nms_iou
+         self.letterbox_image = letterbox_image
+         self.MINOVERLAP = MINOVERLAP
+         self.eval_flag = eval_flag
+         self.period = period
+
+         self.bbox_util = DecodeBox(self.anchors, self.num_classes, (self.input_shape[0], self.input_shape[1]), self.anchors_mask)
+
+         self.maps = [0]
+         self.epoches = [0]
+         if self.eval_flag:
+             with open(os.path.join(self.log_dir, "epoch_map.txt"), 'a') as f:
+                 f.write(str(0))
+                 f.write("\n")
+
+     def get_map_txt(self, image_id, image, class_names, map_out_path):
+         f = open(os.path.join(map_out_path, "detection-results/"+image_id+".txt"), "w", encoding='utf-8')
+         image_shape = np.array(np.shape(image)[0:2])
+         image = cvtColor(image)
+         image_data = resize_image(image, (self.input_shape[1], self.input_shape[0]), self.letterbox_image)
+         image_data = np.expand_dims(np.transpose(preprocess_input(np.array(image_data, dtype='float32')), (2, 0, 1)), 0)
+
+         with torch.no_grad():
+             images = torch.from_numpy(image_data)
+             if self.cuda:
+                 images = images.cuda()
+             outputs = self.net(images)
+             outputs = self.bbox_util.decode_box(outputs)
+             results = self.bbox_util.non_max_suppression(torch.cat(outputs, 1), self.num_classes, self.input_shape,
+                         image_shape, self.letterbox_image, conf_thres = self.confidence, nms_thres = self.nms_iou)
+
+             if results[0] is None:
+                 return
+
+             top_label = np.array(results[0][:, 6], dtype = 'int32')
+             top_conf = results[0][:, 4] * results[0][:, 5]
+             top_boxes = results[0][:, :4]
+
+         top_100 = np.argsort(top_conf)[::-1][:self.max_boxes]
+         top_boxes = top_boxes[top_100]
+         top_conf = top_conf[top_100]
+         top_label = top_label[top_100]
+
+         for i, c in list(enumerate(top_label)):
+             predicted_class = self.class_names[int(c)]
+             box = top_boxes[i]
+             score = str(top_conf[i])
+
+             top, left, bottom, right = box
+             if predicted_class not in class_names:
+                 continue
+
+             f.write("%s %s %s %s %s %s\n" % (predicted_class, score[:6], str(int(left)), str(int(top)), str(int(right)),str(int(bottom))))
+
+         f.close()
+         return
+
+     def on_epoch_end(self, epoch, model_eval):
+         if epoch % self.period == 0 and self.eval_flag:
+             self.net = model_eval
+             if not os.path.exists(self.map_out_path):
+                 os.makedirs(self.map_out_path)
+             if not os.path.exists(os.path.join(self.map_out_path, "ground-truth")):
+                 os.makedirs(os.path.join(self.map_out_path, "ground-truth"))
+             if not os.path.exists(os.path.join(self.map_out_path, "detection-results")):
+                 os.makedirs(os.path.join(self.map_out_path, "detection-results"))
+             print("Get map.")
+             for annotation_line in tqdm(self.val_lines):
+                 line = annotation_line.split()
+                 image_id = os.path.basename(line[0]).split('.')[0]
+                 image = Image.open(line[0])
+                 gt_boxes = np.array([np.array(list(map(int,box.split(',')))) for box in line[1:]])
+                 self.get_map_txt(image_id, image, self.class_names, self.map_out_path)
+
+                 with open(os.path.join(self.map_out_path, "ground-truth/"+image_id+".txt"), "w") as new_f:
+                     for box in gt_boxes:
+                         left, top, right, bottom, obj = box
+                         obj_name = self.class_names[obj]
+                         new_f.write("%s %s %s %s %s\n" % (obj_name, left, top, right, bottom))
+
+             print("Calculate Map.")
+             try:
+                 temp_map = get_coco_map(class_names = self.class_names, path = self.map_out_path)[1]
+             except:
+                 temp_map = get_map(self.MINOVERLAP, False, path = self.map_out_path)
+             self.maps.append(temp_map)
+             self.epoches.append(epoch)
+
+             with open(os.path.join(self.log_dir, "epoch_map.txt"), 'a') as f:
+                 f.write(str(temp_map))
+                 f.write("\n")
+
+             plt.figure()
+             plt.plot(self.epoches, self.maps, 'red', linewidth = 2, label='train map')
+
+             plt.grid(True)
+             plt.xlabel('Epoch')
+             plt.ylabel('Map %s'%str(self.MINOVERLAP))
+             plt.title('A Map Curve')
+             plt.legend(loc="upper right")
+
+             plt.savefig(os.path.join(self.log_dir, "epoch_map.png"))
+             plt.cla()
+             plt.close("all")
+
+             print("Get map done.")
+             shutil.rmtree(self.map_out_path)
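
Usage sketch (not part of the commit; the stand-in model, shapes and paths below are illustrative only): LossHistory creates its own log directory, writes loss scalars to TensorBoard, and redraws epoch_loss.png on every append_loss call.

    import os, datetime
    import torch
    from utils.callbacks import LossHistory

    model = torch.nn.Conv2d(3, 75, 1)  # stand-in for the real detector
    log_dir = os.path.join('logs', 'loss_' + datetime.datetime.now().strftime('%Y_%m_%d_%H_%M_%S'))
    history = LossHistory(log_dir, model, input_shape=[416, 416])
    history.append_loss(1, 1.234, 1.456)  # appends to the txt logs and refreshes the loss curve plot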
utils/dataloader.py ADDED
@@ -0,0 +1,296 @@
+ from random import sample, shuffle
+
+ import cv2
+ import numpy as np
+ import torch
+ from PIL import Image
+ from torch.utils.data.dataset import Dataset
+
+ from utils.utils import cvtColor, preprocess_input
+
+
+ class YoloDataset(Dataset):
+     def __init__(self, annotation_lines, input_shape, num_classes, epoch_length, \
+                  mosaic, mixup, mosaic_prob, mixup_prob, train, special_aug_ratio = 0.7):
+         super(YoloDataset, self).__init__()
+         self.annotation_lines = annotation_lines
+         self.input_shape = input_shape
+         self.num_classes = num_classes
+         self.epoch_length = epoch_length
+         self.mosaic = mosaic
+         self.mosaic_prob = mosaic_prob
+         self.mixup = mixup
+         self.mixup_prob = mixup_prob
+         self.train = train
+         self.special_aug_ratio = special_aug_ratio
+
+         self.epoch_now = -1
+         self.length = len(self.annotation_lines)
+
+     def __len__(self):
+         return self.length
+
+     def __getitem__(self, index):
+         index = index % self.length
+         if self.mosaic and self.rand() < self.mosaic_prob and self.epoch_now < self.epoch_length * self.special_aug_ratio:
+             lines = sample(self.annotation_lines, 3)
+             lines.append(self.annotation_lines[index])
+             shuffle(lines)
+             image, box = self.get_random_data_with_Mosaic(lines, self.input_shape)
+
+             if self.mixup and self.rand() < self.mixup_prob:
+                 lines = sample(self.annotation_lines, 1)
+                 image_2, box_2 = self.get_random_data(lines[0], self.input_shape, random = self.train)
+                 image, box = self.get_random_data_with_MixUp(image, box, image_2, box_2)
+         else:
+             image, box = self.get_random_data(self.annotation_lines[index], self.input_shape, random = self.train)
+
+         image = np.transpose(preprocess_input(np.array(image, dtype=np.float32)), (2, 0, 1))
+         box = np.array(box, dtype=np.float32)
+         if len(box) != 0:
+             box[:, [0, 2]] = box[:, [0, 2]] / self.input_shape[1]
+             box[:, [1, 3]] = box[:, [1, 3]] / self.input_shape[0]
+
+             box[:, 2:4] = box[:, 2:4] - box[:, 0:2]
+             box[:, 0:2] = box[:, 0:2] + box[:, 2:4] / 2
+         return image, box
+
+     def rand(self, a=0, b=1):
+         return np.random.rand()*(b-a) + a
+
+     def get_random_data(self, annotation_line, input_shape, jitter=.3, hue=.1, sat=0.7, val=0.4, random=True):
+         line = annotation_line.split()
+         image = Image.open(line[0])
+         image = cvtColor(image)
+         iw, ih = image.size
+         h, w = input_shape
+         box = np.array([np.array(list(map(int,box.split(',')))) for box in line[1:]])
+
+         if not random:
+             scale = min(w/iw, h/ih)
+             nw = int(iw*scale)
+             nh = int(ih*scale)
+             dx = (w-nw)//2
+             dy = (h-nh)//2
+
+             image = image.resize((nw,nh), Image.BICUBIC)
+             new_image = Image.new('RGB', (w,h), (128,128,128))
+             new_image.paste(image, (dx, dy))
+             image_data = np.array(new_image, np.float32)
+
+             if len(box)>0:
+                 np.random.shuffle(box)
+                 box[:, [0,2]] = box[:, [0,2]]*nw/iw + dx
+                 box[:, [1,3]] = box[:, [1,3]]*nh/ih + dy
+                 box[:, 0:2][box[:, 0:2]<0] = 0
+                 box[:, 2][box[:, 2]>w] = w
+                 box[:, 3][box[:, 3]>h] = h
+                 box_w = box[:, 2] - box[:, 0]
+                 box_h = box[:, 3] - box[:, 1]
+                 box = box[np.logical_and(box_w>1, box_h>1)]
+
+             return image_data, box
+
+         new_ar = iw/ih * self.rand(1-jitter,1+jitter) / self.rand(1-jitter,1+jitter)
+         scale = self.rand(.25, 2)
+         if new_ar < 1:
+             nh = int(scale*h)
+             nw = int(nh*new_ar)
+         else:
+             nw = int(scale*w)
+             nh = int(nw/new_ar)
+         image = image.resize((nw,nh), Image.BICUBIC)
+         dx = int(self.rand(0, w-nw))
+         dy = int(self.rand(0, h-nh))
+         new_image = Image.new('RGB', (w,h), (128,128,128))
+         new_image.paste(image, (dx, dy))
+         image = new_image
+
+         flip = self.rand()<.5
+         if flip: image = image.transpose(Image.FLIP_LEFT_RIGHT)
+
+         image_data = np.array(image, np.uint8)
+
+         r = np.random.uniform(-1, 1, 3) * [hue, sat, val] + 1
+
+         hue, sat, val = cv2.split(cv2.cvtColor(image_data, cv2.COLOR_RGB2HSV))
+         dtype = image_data.dtype
+
+         x = np.arange(0, 256, dtype=r.dtype)
+         lut_hue = ((x * r[0]) % 180).astype(dtype)
+         lut_sat = np.clip(x * r[1], 0, 255).astype(dtype)
+         lut_val = np.clip(x * r[2], 0, 255).astype(dtype)
+
+         image_data = cv2.merge((cv2.LUT(hue, lut_hue), cv2.LUT(sat, lut_sat), cv2.LUT(val, lut_val)))
+         image_data = cv2.cvtColor(image_data, cv2.COLOR_HSV2RGB)
+
+         if len(box)>0:
+             np.random.shuffle(box)
+             box[:, [0,2]] = box[:, [0,2]]*nw/iw + dx
+             box[:, [1,3]] = box[:, [1,3]]*nh/ih + dy
+             if flip: box[:, [0,2]] = w - box[:, [2,0]]
+             box[:, 0:2][box[:, 0:2]<0] = 0
+             box[:, 2][box[:, 2]>w] = w
+             box[:, 3][box[:, 3]>h] = h
+             box_w = box[:, 2] - box[:, 0]
+             box_h = box[:, 3] - box[:, 1]
+             box = box[np.logical_and(box_w>1, box_h>1)]
+
+         return image_data, box
+
+     def merge_bboxes(self, bboxes, cutx, cuty):
+         merge_bbox = []
+         for i in range(len(bboxes)):
+             for box in bboxes[i]:
+                 tmp_box = []
+                 x1, y1, x2, y2 = box[0], box[1], box[2], box[3]
+
+                 if i == 0:
+                     if y1 > cuty or x1 > cutx:
+                         continue
+                     if y2 >= cuty and y1 <= cuty:
+                         y2 = cuty
+                     if x2 >= cutx and x1 <= cutx:
+                         x2 = cutx
+
+                 if i == 1:
+                     if y2 < cuty or x1 > cutx:
+                         continue
+                     if y2 >= cuty and y1 <= cuty:
+                         y1 = cuty
+                     if x2 >= cutx and x1 <= cutx:
+                         x2 = cutx
+
+                 if i == 2:
+                     if y2 < cuty or x2 < cutx:
+                         continue
+                     if y2 >= cuty and y1 <= cuty:
+                         y1 = cuty
+                     if x2 >= cutx and x1 <= cutx:
+                         x1 = cutx
+
+                 if i == 3:
+                     if y1 > cuty or x2 < cutx:
+                         continue
+                     if y2 >= cuty and y1 <= cuty:
+                         y2 = cuty
+                     if x2 >= cutx and x1 <= cutx:
+                         x1 = cutx
+                 tmp_box.append(x1)
+                 tmp_box.append(y1)
+                 tmp_box.append(x2)
+                 tmp_box.append(y2)
+                 tmp_box.append(box[-1])
+                 merge_bbox.append(tmp_box)
+         return merge_bbox
+
+     def get_random_data_with_Mosaic(self, annotation_line, input_shape, jitter=0.3, hue=.1, sat=0.7, val=0.4):
+         h, w = input_shape
+         min_offset_x = self.rand(0.3, 0.7)
+         min_offset_y = self.rand(0.3, 0.7)
+
+         image_datas = []
+         box_datas = []
+         index = 0
+         for line in annotation_line:
+             line_content = line.split()
+             image = Image.open(line_content[0])
+             image = cvtColor(image)
+
+             iw, ih = image.size
+             box = np.array([np.array(list(map(int,box.split(',')))) for box in line_content[1:]])
+
+             flip = self.rand()<.5
+             if flip and len(box)>0:
+                 image = image.transpose(Image.FLIP_LEFT_RIGHT)
+                 box[:, [0,2]] = iw - box[:, [2,0]]
+
+             new_ar = iw/ih * self.rand(1-jitter,1+jitter) / self.rand(1-jitter,1+jitter)
+             scale = self.rand(.4, 1)
+             if new_ar < 1:
+                 nh = int(scale*h)
+                 nw = int(nh*new_ar)
+             else:
+                 nw = int(scale*w)
+                 nh = int(nw/new_ar)
+             image = image.resize((nw, nh), Image.BICUBIC)
+
+             if index == 0:
+                 dx = int(w*min_offset_x) - nw
+                 dy = int(h*min_offset_y) - nh
+             elif index == 1:
+                 dx = int(w*min_offset_x) - nw
+                 dy = int(h*min_offset_y)
+             elif index == 2:
+                 dx = int(w*min_offset_x)
+                 dy = int(h*min_offset_y)
+             elif index == 3:
+                 dx = int(w*min_offset_x)
+                 dy = int(h*min_offset_y) - nh
+
+             new_image = Image.new('RGB', (w,h), (128,128,128))
+             new_image.paste(image, (dx, dy))
+             image_data = np.array(new_image)
+
+             index = index + 1
+             box_data = []
+             if len(box)>0:
+                 np.random.shuffle(box)
+                 box[:, [0,2]] = box[:, [0,2]]*nw/iw + dx
+                 box[:, [1,3]] = box[:, [1,3]]*nh/ih + dy
+                 box[:, 0:2][box[:, 0:2]<0] = 0
+                 box[:, 2][box[:, 2]>w] = w
+                 box[:, 3][box[:, 3]>h] = h
+                 box_w = box[:, 2] - box[:, 0]
+                 box_h = box[:, 3] - box[:, 1]
+                 box = box[np.logical_and(box_w>1, box_h>1)]
+                 box_data = np.zeros((len(box),5))
+                 box_data[:len(box)] = box
+
+             image_datas.append(image_data)
+             box_datas.append(box_data)
+
+         cutx = int(w * min_offset_x)
+         cuty = int(h * min_offset_y)
+
+         new_image = np.zeros([h, w, 3])
+         new_image[:cuty, :cutx, :] = image_datas[0][:cuty, :cutx, :]
+         new_image[cuty:, :cutx, :] = image_datas[1][cuty:, :cutx, :]
+         new_image[cuty:, cutx:, :] = image_datas[2][cuty:, cutx:, :]
+         new_image[:cuty, cutx:, :] = image_datas[3][:cuty, cutx:, :]
+
+         new_image = np.array(new_image, np.uint8)
+         r = np.random.uniform(-1, 1, 3) * [hue, sat, val] + 1
+         hue, sat, val = cv2.split(cv2.cvtColor(new_image, cv2.COLOR_RGB2HSV))
+         dtype = new_image.dtype
+         x = np.arange(0, 256, dtype=r.dtype)
+         lut_hue = ((x * r[0]) % 180).astype(dtype)
+         lut_sat = np.clip(x * r[1], 0, 255).astype(dtype)
+         lut_val = np.clip(x * r[2], 0, 255).astype(dtype)
+
+         new_image = cv2.merge((cv2.LUT(hue, lut_hue), cv2.LUT(sat, lut_sat), cv2.LUT(val, lut_val)))
+         new_image = cv2.cvtColor(new_image, cv2.COLOR_HSV2RGB)
+
+         new_boxes = self.merge_bboxes(box_datas, cutx, cuty)
+
+         return new_image, new_boxes
+
+     def get_random_data_with_MixUp(self, image_1, box_1, image_2, box_2):
+         new_image = np.array(image_1, np.float32) * 0.5 + np.array(image_2, np.float32) * 0.5
+         if len(box_1) == 0:
+             new_boxes = box_2
+         elif len(box_2) == 0:
+             new_boxes = box_1
+         else:
+             new_boxes = np.concatenate([box_1, box_2], axis=0)
+         return new_image, new_boxes
+
+ def yolo_dataset_collate(batch):
+     images = []
+     bboxes = []
+     for img, box in batch:
+         images.append(img)
+         bboxes.append(box)
+     images = torch.from_numpy(np.array(images)).type(torch.FloatTensor)
+     bboxes = [torch.from_numpy(ann).type(torch.FloatTensor) for ann in bboxes]
+     return images, bboxes
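
A hedged wiring example (not part of the commit; the file name and hyperparameters are placeholders): each annotation line is expected as "path/to/img.jpg x1,y1,x2,y2,class ...", which is the format get_random_data parses, and yolo_dataset_collate keeps the per-image box lists as a list of tensors rather than stacking them.

    from torch.utils.data import DataLoader
    from utils.dataloader import YoloDataset, yolo_dataset_collate

    with open('train.txt') as f:          # placeholder annotation file
        annotation_lines = f.readlines()

    train_dataset = YoloDataset(annotation_lines, input_shape=[416, 416], num_classes=20,
                                epoch_length=100, mosaic=True, mixup=True,
                                mosaic_prob=0.5, mixup_prob=0.5, train=True)
    gen = DataLoader(train_dataset, batch_size=8, shuffle=True,
                     collate_fn=yolo_dataset_collate)  # yields an image tensor plus a list of box tensors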
utils/utils.py ADDED
@@ -0,0 +1,83 @@
+ import random
+
+ import numpy as np
+ import torch
+ from PIL import Image
+
+
+ def cvtColor(image):
+     if len(np.shape(image)) == 3 and np.shape(image)[2] == 3:
+         return image
+     else:
+         image = image.convert('RGB')
+         return image
+
+
+
+ def resize_image(image, size, letterbox_image):
+     iw, ih = image.size
+     w, h = size
+     if letterbox_image:
+         scale = min(w / iw, h / ih)
+         nw = int(iw * scale)
+         nh = int(ih * scale)
+
+         image = image.resize((nw, nh), Image.BICUBIC)
+         new_image = Image.new('RGB', size, (128, 128, 128))
+         new_image.paste(image, ((w - nw) // 2, (h - nh) // 2))
+     else:
+         new_image = image.resize((w, h), Image.BICUBIC)
+     return new_image
+
+
+ def get_classes(classes_path):
+     with open(classes_path, encoding='utf-8') as f:
+         class_names = f.readlines()
+     class_names = [c.strip() for c in class_names]
+     return class_names, len(class_names)
+
+
+ def get_anchors(anchors_path):
+     '''loads the anchors from a file'''
+     with open(anchors_path, encoding='utf-8') as f:
+         anchors = f.readline()
+     anchors = [float(x) for x in anchors.split(',')]
+     anchors = np.array(anchors).reshape(-1, 2)
+     return anchors, len(anchors)
+
+
+ def get_lr(optimizer):
+     for param_group in optimizer.param_groups:
+         return param_group['lr']
+
+
+ def seed_everything(seed=11):
+     random.seed(seed)
+     np.random.seed(seed)
+     torch.manual_seed(seed)
+     torch.cuda.manual_seed(seed)
+     torch.cuda.manual_seed_all(seed)
+     torch.backends.cudnn.deterministic = True
+     torch.backends.cudnn.benchmark = False
+
+
+ def worker_init_fn(worker_id, rank, seed):
+     worker_seed = rank + seed
+     random.seed(worker_seed)
+     np.random.seed(worker_seed)
+     torch.manual_seed(worker_seed)
+
+
+ def preprocess_input(image):
+     image /= 255.0
+     return image
+
+
+ def show_config(**kwargs):
+     print('Configurations:')
+     print('-' * 70)
+     print('|%25s | %40s|' % ('keys', 'values'))
+     print('-' * 70)
+     for key, value in kwargs.items():
+         print('|%25s | %40s|' % (str(key), str(value)))
+     print('-' * 70)
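
For context, a small sketch of how these helpers are typically called (the model_data paths are placeholders, not files in this commit):

    from utils.utils import get_classes, get_anchors, seed_everything

    class_names, num_classes = get_classes('model_data/voc_classes.txt')  # one class name per line
    anchors, num_anchors = get_anchors('model_data/yolo_anchors.txt')     # one comma-separated line of w,h pairs
    seed_everything(11)  # seeds random/numpy/torch and makes cuDNN deterministic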
utils/utils_bbox.py ADDED
@@ -0,0 +1,272 @@
+ import torch
+ import torch.nn as nn
+ from torchvision.ops import nms
+ import numpy as np
+
+ class DecodeBox():
+     def __init__(self, anchors, num_classes, input_shape, anchors_mask = [[6,7,8], [3,4,5], [0,1,2]]):
+         super(DecodeBox, self).__init__()
+         self.anchors = anchors
+         self.num_classes = num_classes
+         self.bbox_attrs = 5 + num_classes
+         self.input_shape = input_shape
+         self.anchors_mask = anchors_mask
+
+     def decode_box(self, inputs):
+         outputs = []
+         for i, input in enumerate(inputs):
+             batch_size = input.size(0)
+             input_height = input.size(2)
+             input_width = input.size(3)
+
+             stride_h = self.input_shape[0] / input_height
+             stride_w = self.input_shape[1] / input_width
+             scaled_anchors = [(anchor_width / stride_w, anchor_height / stride_h) for anchor_width, anchor_height in self.anchors[self.anchors_mask[i]]]
+
+             prediction = input.view(batch_size, len(self.anchors_mask[i]),
+                                     self.bbox_attrs, input_height, input_width).permute(0, 1, 3, 4, 2).contiguous()
+
+             x = torch.sigmoid(prediction[..., 0])
+             y = torch.sigmoid(prediction[..., 1])
+             w = prediction[..., 2]
+             h = prediction[..., 3]
+             conf = torch.sigmoid(prediction[..., 4])
+             pred_cls = torch.sigmoid(prediction[..., 5:])
+
+             FloatTensor = torch.cuda.FloatTensor if x.is_cuda else torch.FloatTensor
+             LongTensor = torch.cuda.LongTensor if x.is_cuda else torch.LongTensor
+
+             grid_x = torch.linspace(0, input_width - 1, input_width).repeat(input_height, 1).repeat(
+                 batch_size * len(self.anchors_mask[i]), 1, 1).view(x.shape).type(FloatTensor)
+             grid_y = torch.linspace(0, input_height - 1, input_height).repeat(input_width, 1).t().repeat(
+                 batch_size * len(self.anchors_mask[i]), 1, 1).view(y.shape).type(FloatTensor)
+
+             anchor_w = FloatTensor(scaled_anchors).index_select(1, LongTensor([0]))
+             anchor_h = FloatTensor(scaled_anchors).index_select(1, LongTensor([1]))
+             anchor_w = anchor_w.repeat(batch_size, 1).repeat(1, 1, input_height * input_width).view(w.shape)
+             anchor_h = anchor_h.repeat(batch_size, 1).repeat(1, 1, input_height * input_width).view(h.shape)
+
+             pred_boxes = FloatTensor(prediction[..., :4].shape)
+             pred_boxes[..., 0] = x.data + grid_x
+             pred_boxes[..., 1] = y.data + grid_y
+             pred_boxes[..., 2] = torch.exp(w.data) * anchor_w
+             pred_boxes[..., 3] = torch.exp(h.data) * anchor_h
+
+             _scale = torch.Tensor([input_width, input_height, input_width, input_height]).type(FloatTensor)
+             output = torch.cat((pred_boxes.view(batch_size, -1, 4) / _scale,
+                                 conf.view(batch_size, -1, 1), pred_cls.view(batch_size, -1, self.num_classes)), -1)
+             outputs.append(output.data)
+         return outputs
+
+     def yolo_correct_boxes(self, box_xy, box_wh, input_shape, image_shape, letterbox_image):
+         box_yx = box_xy[..., ::-1]
+         box_hw = box_wh[..., ::-1]
+         input_shape = np.array(input_shape)
+         image_shape = np.array(image_shape)
+
+         if letterbox_image:
+             new_shape = np.round(image_shape * np.min(input_shape/image_shape))
+             offset = (input_shape - new_shape)/2./input_shape
+             scale = input_shape/new_shape
+
+             box_yx = (box_yx - offset) * scale
+             box_hw *= scale
+
+         box_mins = box_yx - (box_hw / 2.)
+         box_maxes = box_yx + (box_hw / 2.)
+         boxes = np.concatenate([box_mins[..., 0:1], box_mins[..., 1:2], box_maxes[..., 0:1], box_maxes[..., 1:2]], axis=-1)
+         boxes *= np.concatenate([image_shape, image_shape], axis=-1)
+         return boxes
+
+     def non_max_suppression(self, prediction, num_classes, input_shape, image_shape, letterbox_image, conf_thres=0.5, nms_thres=0.4):
+         box_corner = prediction.new(prediction.shape)
+         box_corner[:, :, 0] = prediction[:, :, 0] - prediction[:, :, 2] / 2
+         box_corner[:, :, 1] = prediction[:, :, 1] - prediction[:, :, 3] / 2
+         box_corner[:, :, 2] = prediction[:, :, 0] + prediction[:, :, 2] / 2
+         box_corner[:, :, 3] = prediction[:, :, 1] + prediction[:, :, 3] / 2
+         prediction[:, :, :4] = box_corner[:, :, :4]
+
+         output = [None for _ in range(len(prediction))]
+         for i, image_pred in enumerate(prediction):
+             class_conf, class_pred = torch.max(image_pred[:, 5:5 + num_classes], 1, keepdim=True)
+
+             conf_mask = (image_pred[:, 4] * class_conf[:, 0] >= conf_thres).squeeze()
+
+             image_pred = image_pred[conf_mask]
+             class_conf = class_conf[conf_mask]
+             class_pred = class_pred[conf_mask]
+             if not image_pred.size(0):
+                 continue
+             detections = torch.cat((image_pred[:, :5], class_conf.float(), class_pred.float()), 1)
+
+             unique_labels = detections[:, -1].cpu().unique()
+
+             if prediction.is_cuda:
+                 unique_labels = unique_labels.cuda()
+                 detections = detections.cuda()
+
+             for c in unique_labels:
+                 detections_class = detections[detections[:, -1] == c]
+
+                 keep = nms(
+                     detections_class[:, :4],
+                     detections_class[:, 4] * detections_class[:, 5],
+                     nms_thres
+                 )
+                 max_detections = detections_class[keep]
+
+
+                 output[i] = max_detections if output[i] is None else torch.cat((output[i], max_detections))
+
+             if output[i] is not None:
+                 output[i] = output[i].cpu().numpy()
+                 box_xy, box_wh = (output[i][:, 0:2] + output[i][:, 2:4])/2, output[i][:, 2:4] - output[i][:, 0:2]
+                 output[i][:, :4] = self.yolo_correct_boxes(box_xy, box_wh, input_shape, image_shape, letterbox_image)
+         return output
+
+ class DecodeBoxNP():
+     def __init__(self, anchors, num_classes, input_shape, anchors_mask = [[6,7,8], [3,4,5], [0,1,2]]):
+         super(DecodeBoxNP, self).__init__()
+         self.anchors = anchors
+         self.num_classes = num_classes
+         self.bbox_attrs = 5 + num_classes
+         self.input_shape = input_shape
+         self.anchors_mask = anchors_mask
+
+     def sigmoid(self, x):
+         return 1 / (1 + np.exp(-x))
+
+     def decode_box(self, inputs):
+         outputs = []
+         for i, input in enumerate(inputs):
+             batch_size = np.shape(input)[0]
+             input_height = np.shape(input)[2]
+             input_width = np.shape(input)[3]
+
+             stride_h = self.input_shape[0] / input_height
+             stride_w = self.input_shape[1] / input_width
+             scaled_anchors = [(anchor_width / stride_w, anchor_height / stride_h) for anchor_width, anchor_height in self.anchors[self.anchors_mask[i]]]
+
+             prediction = np.transpose(np.reshape(input, (batch_size, len(self.anchors_mask[i]), self.bbox_attrs, input_height, input_width)), (0, 1, 3, 4, 2))
+
+             x = self.sigmoid(prediction[..., 0])
+             y = self.sigmoid(prediction[..., 1])
+             w = prediction[..., 2]
+             h = prediction[..., 3]
+             conf = self.sigmoid(prediction[..., 4])
+             pred_cls = self.sigmoid(prediction[..., 5:])
+
+             grid_x = np.repeat(np.expand_dims(np.repeat(np.expand_dims(np.linspace(0, input_width - 1, input_width), 0), input_height, axis=0), 0), batch_size * len(self.anchors_mask[i]), axis=0)
+             grid_x = np.reshape(grid_x, np.shape(x))
+             grid_y = np.repeat(np.expand_dims(np.repeat(np.expand_dims(np.linspace(0, input_height - 1, input_height), 0), input_width, axis=0).T, 0), batch_size * len(self.anchors_mask[i]), axis=0)
+             grid_y = np.reshape(grid_y, np.shape(y))
+
+             anchor_w = np.repeat(np.expand_dims(np.repeat(np.expand_dims(np.array(scaled_anchors)[:, 0], 0), batch_size, axis=0), -1), input_height * input_width, axis=-1)
+             anchor_h = np.repeat(np.expand_dims(np.repeat(np.expand_dims(np.array(scaled_anchors)[:, 1], 0), batch_size, axis=0), -1), input_height * input_width, axis=-1)
+             anchor_w = np.reshape(anchor_w, np.shape(w))
+             anchor_h = np.reshape(anchor_h, np.shape(h))
+             pred_boxes = np.zeros(np.shape(prediction[..., :4]))
+             pred_boxes[..., 0] = x + grid_x
+             pred_boxes[..., 1] = y + grid_y
+             pred_boxes[..., 2] = np.exp(w) * anchor_w
+             pred_boxes[..., 3] = np.exp(h) * anchor_h
+
+             _scale = np.array([input_width, input_height, input_width, input_height])
+             output = np.concatenate([np.reshape(pred_boxes, (batch_size, -1, 4)) / _scale,
+                                 np.reshape(conf, (batch_size, -1, 1)), np.reshape(pred_cls, (batch_size, -1, self.num_classes))], -1)
+             outputs.append(output)
+         return outputs
+
+     def bbox_iou(self, box1, box2, x1y1x2y2=True):
+         """
+         Compute the IoU between two sets of boxes.
+         """
+         if not x1y1x2y2:
+             b1_x1, b1_x2 = box1[:, 0] - box1[:, 2] / 2, box1[:, 0] + box1[:, 2] / 2
+             b1_y1, b1_y2 = box1[:, 1] - box1[:, 3] / 2, box1[:, 1] + box1[:, 3] / 2
+             b2_x1, b2_x2 = box2[:, 0] - box2[:, 2] / 2, box2[:, 0] + box2[:, 2] / 2
+             b2_y1, b2_y2 = box2[:, 1] - box2[:, 3] / 2, box2[:, 1] + box2[:, 3] / 2
+         else:
+             b1_x1, b1_y1, b1_x2, b1_y2 = box1[:, 0], box1[:, 1], box1[:, 2], box1[:, 3]
+             b2_x1, b2_y1, b2_x2, b2_y2 = box2[:, 0], box2[:, 1], box2[:, 2], box2[:, 3]
+
+         inter_rect_x1 = np.maximum(b1_x1, b2_x1)
+         inter_rect_y1 = np.maximum(b1_y1, b2_y1)
+         inter_rect_x2 = np.minimum(b1_x2, b2_x2)
+         inter_rect_y2 = np.minimum(b1_y2, b2_y2)
+
+         inter_area = np.maximum(inter_rect_x2 - inter_rect_x1, 0) * \
+                      np.maximum(inter_rect_y2 - inter_rect_y1, 0)
+
+         b1_area = (b1_x2 - b1_x1) * (b1_y2 - b1_y1)
+         b2_area = (b2_x2 - b2_x1) * (b2_y2 - b2_y1)
+
+         iou = inter_area / np.maximum(b1_area + b2_area - inter_area, 1e-6)
+
+         return iou
+
+     def yolo_correct_boxes(self, box_xy, box_wh, input_shape, image_shape, letterbox_image):
+         box_yx = box_xy[..., ::-1]
+         box_hw = box_wh[..., ::-1]
+         input_shape = np.array(input_shape)
+         image_shape = np.array(image_shape)
+
+         if letterbox_image:
+             new_shape = np.round(image_shape * np.min(input_shape/image_shape))
+             offset = (input_shape - new_shape)/2./input_shape
+             scale = input_shape/new_shape
+
+             box_yx = (box_yx - offset) * scale
+             box_hw *= scale
+
+         box_mins = box_yx - (box_hw / 2.)
+         box_maxes = box_yx + (box_hw / 2.)
+         boxes = np.concatenate([box_mins[..., 0:1], box_mins[..., 1:2], box_maxes[..., 0:1], box_maxes[..., 1:2]], axis=-1)
+         boxes *= np.concatenate([image_shape, image_shape], axis=-1)
+         return boxes
+
+     def non_max_suppression(self, prediction, num_classes, input_shape, image_shape, letterbox_image, conf_thres=0.5, nms_thres=0.4):
+         box_corner = np.zeros_like(prediction)
+         box_corner[:, :, 0] = prediction[:, :, 0] - prediction[:, :, 2] / 2
+         box_corner[:, :, 1] = prediction[:, :, 1] - prediction[:, :, 3] / 2
+         box_corner[:, :, 2] = prediction[:, :, 0] + prediction[:, :, 2] / 2
+         box_corner[:, :, 3] = prediction[:, :, 1] + prediction[:, :, 3] / 2
+         prediction[:, :, :4] = box_corner[:, :, :4]
+
+         output = [None for _ in range(len(prediction))]
+         for i, image_pred in enumerate(prediction):
+             class_conf = np.max(image_pred[:, 5:5 + num_classes], 1, keepdims=True)
+             class_pred = np.expand_dims(np.argmax(image_pred[:, 5:5 + num_classes], 1), -1)
+
+             conf_mask = np.squeeze((image_pred[:, 4] * class_conf[:, 0] >= conf_thres))
+
+             image_pred = image_pred[conf_mask]
+             class_conf = class_conf[conf_mask]
+             class_pred = class_pred[conf_mask]
+             if not np.shape(image_pred)[0]:
+                 continue
+             detections = np.concatenate((image_pred[:, :5], class_conf, class_pred), 1)
+
+             unique_labels = np.unique(detections[:, -1])
+
+             for c in unique_labels:
+                 detections_class = detections[detections[:, -1] == c]
+
+                 conf_sort_index = np.argsort(detections_class[:, 4] * detections_class[:, 5])[::-1]
+                 detections_class = detections_class[conf_sort_index]
+                 max_detections = []
+                 while np.shape(detections_class)[0]:
+                     max_detections.append(detections_class[0:1])
+                     if len(detections_class) == 1:
+                         break
+                     ious = self.bbox_iou(max_detections[-1], detections_class[1:])
+                     detections_class = detections_class[1:][ious < nms_thres]
+                 max_detections = np.concatenate(max_detections, 0)
+
+                 output[i] = max_detections if output[i] is None else np.concatenate((output[i], max_detections))
+
+             if output[i] is not None:
+                 output[i] = output[i]
+                 box_xy, box_wh = (output[i][:, 0:2] + output[i][:, 2:4])/2, output[i][:, 2:4] - output[i][:, 0:2]
+                 output[i][:, :4] = self.yolo_correct_boxes(box_xy, box_wh, input_shape, image_shape, letterbox_image)
+         return output
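
A self-contained decode sketch under assumed shapes (random tensors stand in for real head outputs; 75 channels = 3 anchors x (5 + 20 classes) for a 20-class model at 416x416): decode_box maps raw head tensors to normalized boxes, and non_max_suppression filters them and rescales survivors to the original image.

    import numpy as np
    import torch
    from utils.utils_bbox import DecodeBox

    anchors = np.array([[10,13],[16,30],[33,23],[30,61],[62,45],[59,119],[116,90],[156,198],[373,326]])
    bbox_util = DecodeBox(anchors, num_classes=20, input_shape=(416, 416))

    outputs = [torch.randn(1, 75, s, s) for s in (13, 26, 52)]  # dummy head outputs
    decoded = bbox_util.decode_box(outputs)
    results = bbox_util.non_max_suppression(torch.cat(decoded, 1), 20, (416, 416),
                                            image_shape=np.array([375, 500]),
                                            letterbox_image=True,
                                            conf_thres=0.5, nms_thres=0.4)
    # results[0] is None or an (N, 7) array: top, left, bottom, right, obj conf, class conf, class id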
utils/utils_fit.py ADDED
@@ -0,0 +1,104 @@
+ import os
+
+ import torch
+ from tqdm import tqdm
+
+ from utils.utils import get_lr
+
+
+ def fit_one_epoch(model_train, model, yolo_loss, loss_history, eval_callback, optimizer, epoch, epoch_step, epoch_step_val, gen, gen_val, Epoch, cuda, fp16, scaler, save_period, save_dir, local_rank=0):
+     loss = 0
+     val_loss = 0
+
+     if local_rank == 0:
+         print('Start Train')
+         pbar = tqdm(total=epoch_step,desc=f'Epoch {epoch + 1}/{Epoch}',postfix=dict,mininterval=0.3)
+     model_train.train()
+     for iteration, batch in enumerate(gen):
+         if iteration >= epoch_step:
+             break
+
+         images, targets = batch[0], batch[1]
+         with torch.no_grad():
+             if cuda:
+                 images = images.cuda(local_rank)
+                 targets = [ann.cuda(local_rank) for ann in targets]
+         optimizer.zero_grad()
+         if not fp16:
+             outputs = model_train(images)
+
+             loss_value_all = 0
+             for l in range(len(outputs)):
+                 loss_item = yolo_loss(l, outputs[l], targets)
+                 loss_value_all += loss_item
+             loss_value = loss_value_all
+
+             loss_value.backward()
+             optimizer.step()
+         else:
+             from torch.cuda.amp import autocast
+             with autocast():
+                 outputs = model_train(images)
+
+                 loss_value_all = 0
+                 for l in range(len(outputs)):
+                     loss_item = yolo_loss(l, outputs[l], targets)
+                     loss_value_all += loss_item
+                 loss_value = loss_value_all
+
+             scaler.scale(loss_value).backward()
+             scaler.step(optimizer)
+             scaler.update()
+
+         loss += loss_value.item()
+
+         if local_rank == 0:
+             pbar.set_postfix(**{'loss'  : loss / (iteration + 1),
+                                 'lr'    : get_lr(optimizer)})
+             pbar.update(1)
+
+     if local_rank == 0:
+         pbar.close()
+         print('Finish Train')
+         print('Start Validation')
+         pbar = tqdm(total=epoch_step_val, desc=f'Epoch {epoch + 1}/{Epoch}',postfix=dict,mininterval=0.3)
+
+     model_train.eval()
+     for iteration, batch in enumerate(gen_val):
+         if iteration >= epoch_step_val:
+             break
+         images, targets = batch[0], batch[1]
+         with torch.no_grad():
+             if cuda:
+                 images = images.cuda(local_rank)
+                 targets = [ann.cuda(local_rank) for ann in targets]
+             optimizer.zero_grad()
+             outputs = model_train(images)
+
+             loss_value_all = 0
+             for l in range(len(outputs)):
+                 loss_item = yolo_loss(l, outputs[l], targets)
+                 loss_value_all += loss_item
+             loss_value = loss_value_all
+
+         val_loss += loss_value.item()
+         if local_rank == 0:
+             pbar.set_postfix(**{'val_loss': val_loss / (iteration + 1)})
+             pbar.update(1)
+
+     if local_rank == 0:
+         pbar.close()
+         print('Finish Validation')
+         loss_history.append_loss(epoch + 1, loss / epoch_step, val_loss / epoch_step_val)
+         eval_callback.on_epoch_end(epoch + 1, model_train)
+         print('Epoch:'+ str(epoch + 1) + '/' + str(Epoch))
+         print('Total Loss: %.3f || Val Loss: %.3f ' % (loss / epoch_step, val_loss / epoch_step_val))
+
+         if (epoch + 1) % save_period == 0 or epoch + 1 == Epoch:
+             torch.save(model.state_dict(), os.path.join(save_dir, "ep%03d-loss%.3f-val_loss%.3f.pth" % (epoch + 1, loss / epoch_step, val_loss / epoch_step_val)))
+
+         if len(loss_history.val_loss) <= 1 or (val_loss / epoch_step_val) <= min(loss_history.val_loss):
+             print('Save best model to best_epoch_weights.pth')
+             torch.save(model.state_dict(), os.path.join(save_dir, "best_epoch_weights.pth"))
+
+         torch.save(model.state_dict(), os.path.join(save_dir, "last_epoch_weights.pth"))
utils/utils_map.py ADDED
@@ -0,0 +1,757 @@
1
+ import glob
2
+ import json
3
+ import math
4
+ import operator
5
+ import os
6
+ import shutil
7
+ import sys
8
+ try:
9
+ from pycocotools.coco import COCO
10
+ from pycocotools.cocoeval import COCOeval
11
+ except:
12
+ pass
13
+ import cv2
14
+ import matplotlib
15
+ matplotlib.use('Agg')
16
+ from matplotlib import pyplot as plt
17
+ import numpy as np
18
+
19
+ def log_average_miss_rate(precision, fp_cumsum, num_images):
20
+
21
+ if precision.size == 0:
22
+ lamr = 0
23
+ mr = 1
24
+ fppi = 0
25
+ return lamr, mr, fppi
26
+
27
+ fppi = fp_cumsum / float(num_images)
28
+ mr = (1 - precision)
29
+
30
+ fppi_tmp = np.insert(fppi, 0, -1.0)
31
+ mr_tmp = np.insert(mr, 0, 1.0)
32
+
33
+ ref = np.logspace(-2.0, 0.0, num = 9)
34
+ for i, ref_i in enumerate(ref):
35
+ j = np.where(fppi_tmp <= ref_i)[-1][-1]
36
+ ref[i] = mr_tmp[j]
37
+
38
+ lamr = math.exp(np.mean(np.log(np.maximum(1e-10, ref))))
39
+
40
+ return lamr, mr, fppi
41
+
42
+ def error(msg):
43
+ print(msg)
44
+ sys.exit(0)
45
+
46
+ def is_float_between_0_and_1(value):
47
+ try:
48
+ val = float(value)
49
+ if val > 0.0 and val < 1.0:
50
+ return True
51
+ else:
52
+ return False
53
+ except ValueError:
54
+ return False
55
+
56
+ def voc_ap(rec, prec):
57
+
58
+ rec.insert(0, 0.0)
59
+ rec.append(1.0)
60
+ mrec = rec[:]
61
+ prec.insert(0, 0.0)
62
+ prec.append(0.0)
63
+ mpre = prec[:]
64
+
65
+ for i in range(len(mpre)-2, -1, -1):
66
+ mpre[i] = max(mpre[i], mpre[i+1])
67
+
68
+ i_list = []
69
+ for i in range(1, len(mrec)):
70
+ if mrec[i] != mrec[i-1]:
71
+ i_list.append(i)
72
+
73
+ ap = 0.0
74
+ for i in i_list:
75
+ ap += ((mrec[i]-mrec[i-1])*mpre[i])
76
+ return ap, mrec, mpre
77
+
78
+ def file_lines_to_list(path):
79
+ with open(path) as f:
80
+ content = f.readlines()
81
+ content = [x.strip() for x in content]
82
+ return content
83
+
84
+
85
+ def draw_text_in_image(img, text, pos, color, line_width):
86
+ font = cv2.FONT_HERSHEY_PLAIN
87
+ fontScale = 1
88
+ lineType = 1
89
+ bottomLeftCornerOfText = pos
90
+ cv2.putText(img, text,
91
+ bottomLeftCornerOfText,
92
+ font,
93
+ fontScale,
94
+ color,
95
+ lineType)
96
+ text_width, _ = cv2.getTextSize(text, font, fontScale, lineType)[0]
97
+ return img, (line_width + text_width)
98
+
99
+ def adjust_axes(r, t, fig, axes):
100
+ bb = t.get_window_extent(renderer=r)
101
+ text_width_inches = bb.width / fig.dpi
102
+ current_fig_width = fig.get_figwidth()
103
+ new_fig_width = current_fig_width + text_width_inches
104
+ propotion = new_fig_width / current_fig_width
105
+ x_lim = axes.get_xlim()
106
+ axes.set_xlim([x_lim[0], x_lim[1]*propotion])
107
+
108
+ def draw_plot_func(dictionary, n_classes, window_title, plot_title, x_label, output_path, to_show, plot_color, true_p_bar):
109
+ sorted_dic_by_value = sorted(dictionary.items(), key=operator.itemgetter(1))
110
+ sorted_keys, sorted_values = zip(*sorted_dic_by_value)
111
+ if true_p_bar != "":
112
+ fp_sorted = []
113
+ tp_sorted = []
114
+ for key in sorted_keys:
115
+ fp_sorted.append(dictionary[key] - true_p_bar[key])
116
+ tp_sorted.append(true_p_bar[key])
117
+ plt.barh(range(n_classes), fp_sorted, align='center', color='crimson', label='False Positive')
118
+ plt.barh(range(n_classes), tp_sorted, align='center', color='forestgreen', label='True Positive', left=fp_sorted)
119
+ plt.legend(loc='lower right')
120
+
121
+ fig = plt.gcf()
122
+ axes = plt.gca()
123
+ r = fig.canvas.get_renderer()
124
+ for i, val in enumerate(sorted_values):
125
+ fp_val = fp_sorted[i]
126
+ tp_val = tp_sorted[i]
127
+ fp_str_val = " " + str(fp_val)
128
+ tp_str_val = fp_str_val + " " + str(tp_val)
129
+ t = plt.text(val, i, tp_str_val, color='forestgreen', va='center', fontweight='bold')
130
+ plt.text(val, i, fp_str_val, color='crimson', va='center', fontweight='bold')
131
+ if i == (len(sorted_values)-1):
132
+ adjust_axes(r, t, fig, axes)
133
+ else:
134
+ plt.barh(range(n_classes), sorted_values, color=plot_color)
135
+ """
136
+ Write number on side of bar
137
+ """
138
+ fig = plt.gcf()
139
+ axes = plt.gca()
140
+ r = fig.canvas.get_renderer()
141
+ for i, val in enumerate(sorted_values):
142
+ str_val = " " + str(val)
143
+ if val < 1.0:
144
+ str_val = " {0:.2f}".format(val)
145
+ t = plt.text(val, i, str_val, color=plot_color, va='center', fontweight='bold')
146
+ if i == (len(sorted_values)-1):
147
+ adjust_axes(r, t, fig, axes)
148
+ fig.canvas.set_window_title(window_title)
149
+ tick_font_size = 12
150
+ plt.yticks(range(n_classes), sorted_keys, fontsize=tick_font_size)
151
+ """
152
+ Re-scale height accordingly
153
+ """
154
+ init_height = fig.get_figheight()
155
+ dpi = fig.dpi
156
+ height_pt = n_classes * (tick_font_size * 1.4)
157
+ height_in = height_pt / dpi
158
+ top_margin = 0.15
159
+ bottom_margin = 0.05
160
+ figure_height = height_in / (1 - top_margin - bottom_margin)
161
+ if figure_height > init_height:
162
+ fig.set_figheight(figure_height)
163
+
164
+ plt.title(plot_title, fontsize=14)
165
+ plt.xlabel(x_label, fontsize='large')
166
+ fig.tight_layout()
167
+ fig.savefig(output_path)
168
+ if to_show:
169
+ plt.show()
170
+ plt.close()
171
+
172
+ def get_map(MINOVERLAP, draw_plot, score_threhold=0.5, path = './map_out'):
173
+ GT_PATH = os.path.join(path, 'ground-truth')
174
+ DR_PATH = os.path.join(path, 'detection-results')
175
+ IMG_PATH = os.path.join(path, 'images-optional')
176
+ TEMP_FILES_PATH = os.path.join(path, '.temp_files')
177
+ RESULTS_FILES_PATH = os.path.join(path, 'results')
178
+
179
+ show_animation = True
180
+ if os.path.exists(IMG_PATH):
181
+ for dirpath, dirnames, files in os.walk(IMG_PATH):
182
+ if not files:
183
+ show_animation = False
184
+ else:
185
+ show_animation = False
186
+
187
+ if not os.path.exists(TEMP_FILES_PATH):
188
+ os.makedirs(TEMP_FILES_PATH)
189
+
190
+ if os.path.exists(RESULTS_FILES_PATH):
191
+ shutil.rmtree(RESULTS_FILES_PATH)
192
+ else:
193
+ os.makedirs(RESULTS_FILES_PATH)
194
+ if draw_plot:
195
+ try:
196
+ matplotlib.use('TkAgg')
197
+ except:
198
+ pass
199
+ os.makedirs(os.path.join(RESULTS_FILES_PATH, "AP"))
200
+ os.makedirs(os.path.join(RESULTS_FILES_PATH, "F1"))
201
+ os.makedirs(os.path.join(RESULTS_FILES_PATH, "Recall"))
202
+ os.makedirs(os.path.join(RESULTS_FILES_PATH, "Precision"))
203
+ if show_animation:
204
+ os.makedirs(os.path.join(RESULTS_FILES_PATH, "images", "detections_one_by_one"))
205
+
206
+ ground_truth_files_list = glob.glob(GT_PATH + '/*.txt')
207
+ if len(ground_truth_files_list) == 0:
208
+ error("Error: No ground-truth files found!")
209
+ ground_truth_files_list.sort()
210
+ gt_counter_per_class = {}
211
+ counter_images_per_class = {}
212
+
213
+ for txt_file in ground_truth_files_list:
214
+ file_id = txt_file.split(".txt", 1)[0]
215
+ file_id = os.path.basename(os.path.normpath(file_id))
216
+ temp_path = os.path.join(DR_PATH, (file_id + ".txt"))
217
+ if not os.path.exists(temp_path):
218
+ error_msg = "Error. File not found: {}\n".format(temp_path)
219
+ error(error_msg)
220
+ lines_list = file_lines_to_list(txt_file)
221
+ bounding_boxes = []
222
+ is_difficult = False
223
+ already_seen_classes = []
224
+ for line in lines_list:
225
+ try:
226
+ if "difficult" in line:
227
+ class_name, left, top, right, bottom, _difficult = line.split()
228
+ is_difficult = True
229
+ else:
230
+ class_name, left, top, right, bottom = line.split()
231
+ except:
232
+ if "difficult" in line:
233
+ line_split = line.split()
234
+ _difficult = line_split[-1]
235
+ bottom = line_split[-2]
236
+ right = line_split[-3]
237
+ top = line_split[-4]
238
+ left = line_split[-5]
239
+ class_name = ""
240
+ for name in line_split[:-5]:
241
+ class_name += name + " "
242
+ class_name = class_name[:-1]
243
+ is_difficult = True
244
+ else:
245
+ line_split = line.split()
246
+ bottom = line_split[-1]
247
+ right = line_split[-2]
248
+ top = line_split[-3]
249
+ left = line_split[-4]
250
+ class_name = ""
251
+ for name in line_split[:-4]:
252
+ class_name += name + " "
253
+ class_name = class_name[:-1]
254
+
255
+ bbox = left + " " + top + " " + right + " " + bottom
256
+ if is_difficult:
257
+ bounding_boxes.append({"class_name":class_name, "bbox":bbox, "used":False, "difficult":True})
258
+ is_difficult = False
259
+ else:
260
+ bounding_boxes.append({"class_name":class_name, "bbox":bbox, "used":False})
261
+ if class_name in gt_counter_per_class:
262
+ gt_counter_per_class[class_name] += 1
263
+ else:
264
+ gt_counter_per_class[class_name] = 1
265
+
266
+ if class_name not in already_seen_classes:
267
+ if class_name in counter_images_per_class:
268
+ counter_images_per_class[class_name] += 1
269
+ else:
270
+ counter_images_per_class[class_name] = 1
271
+ already_seen_classes.append(class_name)
272
+
273
+ with open(TEMP_FILES_PATH + "/" + file_id + "_ground_truth.json", 'w') as outfile:
274
+ json.dump(bounding_boxes, outfile)
275
+
276
+ gt_classes = list(gt_counter_per_class.keys())
277
+ gt_classes = sorted(gt_classes)
278
+ n_classes = len(gt_classes)
279
+
280
+ dr_files_list = glob.glob(DR_PATH + '/*.txt')
281
+ dr_files_list.sort()
282
+ for class_index, class_name in enumerate(gt_classes):
283
+ bounding_boxes = []
284
+ for txt_file in dr_files_list:
285
+ file_id = txt_file.split(".txt",1)[0]
286
+ file_id = os.path.basename(os.path.normpath(file_id))
287
+ temp_path = os.path.join(GT_PATH, (file_id + ".txt"))
288
+ if class_index == 0:
289
+ if not os.path.exists(temp_path):
290
+ error_msg = "Error. File not found: {}\n".format(temp_path)
291
+ error(error_msg)
292
+ lines = file_lines_to_list(txt_file)
293
+ for line in lines:
294
+ try:
295
+ tmp_class_name, confidence, left, top, right, bottom = line.split()
296
+ except:
297
+ line_split = line.split()
298
+ bottom = line_split[-1]
299
+ right = line_split[-2]
300
+ top = line_split[-3]
301
+ left = line_split[-4]
302
+ confidence = line_split[-5]
303
+ tmp_class_name = ""
304
+ for name in line_split[:-5]:
305
+ tmp_class_name += name + " "
306
+ tmp_class_name = tmp_class_name[:-1]
307
+
308
+ if tmp_class_name == class_name:
309
+ bbox = left + " " + top + " " + right + " " +bottom
310
+ bounding_boxes.append({"confidence":confidence, "file_id":file_id, "bbox":bbox})
311
+
312
+ bounding_boxes.sort(key=lambda x:float(x['confidence']), reverse=True)
313
+ with open(TEMP_FILES_PATH + "/" + class_name + "_dr.json", 'w') as outfile:
314
+ json.dump(bounding_boxes, outfile)
315
+
316
+ sum_AP = 0.0
317
+ ap_dictionary = {}
318
+ lamr_dictionary = {}
319
+ with open(RESULTS_FILES_PATH + "/results.txt", 'w') as results_file:
320
+ count_true_positives = {}
321
+
322
+ for class_index, class_name in enumerate(gt_classes):
323
+ count_true_positives[class_name] = 0
324
+ dr_file = TEMP_FILES_PATH + "/" + class_name + "_dr.json"
325
+ dr_data = json.load(open(dr_file))
326
+
327
+ nd = len(dr_data)
328
+ tp = [0] * nd
329
+ fp = [0] * nd
330
+ score = [0] * nd
331
+ score_threhold_idx = 0
332
+ for idx, detection in enumerate(dr_data):
333
+ file_id = detection["file_id"]
334
+ score[idx] = float(detection["confidence"])
335
+ if score[idx] >= score_threhold:
336
+ score_threhold_idx = idx
337
+
338
+ if show_animation:
339
+ ground_truth_img = glob.glob1(IMG_PATH, file_id + ".*")
340
+ if len(ground_truth_img) == 0:
341
+ error("Error. Image not found with id: " + file_id)
342
+ elif len(ground_truth_img) > 1:
343
+ error("Error. Multiple image with id: " + file_id)
344
+ else:
345
+ img = cv2.imread(IMG_PATH + "/" + ground_truth_img[0])
346
+ img_cumulative_path = RESULTS_FILES_PATH + "/images/" + ground_truth_img[0]
347
+ if os.path.isfile(img_cumulative_path):
348
+ img_cumulative = cv2.imread(img_cumulative_path)
349
+ else:
350
+ img_cumulative = img.copy()
351
+ bottom_border = 60
352
+ BLACK = [0, 0, 0]
353
+ img = cv2.copyMakeBorder(img, 0, bottom_border, 0, 0, cv2.BORDER_CONSTANT, value=BLACK)
354
+
355
+ gt_file = TEMP_FILES_PATH + "/" + file_id + "_ground_truth.json"
356
+ ground_truth_data = json.load(open(gt_file))
357
+ ovmax = -1
358
+ gt_match = -1
359
+ bb = [float(x) for x in detection["bbox"].split()]
360
+ for obj in ground_truth_data:
361
+ if obj["class_name"] == class_name:
362
+ bbgt = [ float(x) for x in obj["bbox"].split() ]
363
+ bi = [max(bb[0],bbgt[0]), max(bb[1],bbgt[1]), min(bb[2],bbgt[2]), min(bb[3],bbgt[3])]
364
+ iw = bi[2] - bi[0] + 1
365
+ ih = bi[3] - bi[1] + 1
366
+ if iw > 0 and ih > 0:
367
+ ua = (bb[2] - bb[0] + 1) * (bb[3] - bb[1] + 1) + (bbgt[2] - bbgt[0]
368
+ + 1) * (bbgt[3] - bbgt[1] + 1) - iw * ih
369
+ ov = iw * ih / ua
370
+ if ov > ovmax:
371
+ ovmax = ov
372
+ gt_match = obj
373
+
374
+ if show_animation:
375
+ status = "NO MATCH FOUND!"
376
+
377
+ min_overlap = MINOVERLAP
378
+ if ovmax >= min_overlap:
379
+ if "difficult" not in gt_match:
380
+ if not bool(gt_match["used"]):
381
+ tp[idx] = 1
382
+ gt_match["used"] = True
383
+ count_true_positives[class_name] += 1
384
+ with open(gt_file, 'w') as f:
385
+ f.write(json.dumps(ground_truth_data))
386
+ if show_animation:
387
+ status = "MATCH!"
388
+ else:
389
+ fp[idx] = 1
390
+ if show_animation:
391
+ status = "REPEATED MATCH!"
392
+ else:
393
+ fp[idx] = 1
394
+ if ovmax > 0:
395
+ status = "INSUFFICIENT OVERLAP"
396
+ if show_animation:
397
+ height, widht = img.shape[:2]
398
+ white = (255,255,255)
399
+ light_blue = (255,200,100)
400
+ green = (0,255,0)
401
+ light_red = (30,30,255)
402
+ margin = 10
403
+ v_pos = int(height - margin - (bottom_border / 2.0))
404
+ text = "Image: " + ground_truth_img[0] + " "
405
+ img, line_width = draw_text_in_image(img, text, (margin, v_pos), white, 0)
406
+ text = "Class [" + str(class_index) + "/" + str(n_classes) + "]: " + class_name + " "
407
+ img, line_width = draw_text_in_image(img, text, (margin + line_width, v_pos), light_blue, line_width)
408
+ if ovmax != -1:
409
+ color = light_red
410
+ if status == "INSUFFICIENT OVERLAP":
411
+ text = "IoU: {0:.2f}% ".format(ovmax*100) + "< {0:.2f}% ".format(min_overlap*100)
412
+ else:
413
+ text = "IoU: {0:.2f}% ".format(ovmax*100) + ">= {0:.2f}% ".format(min_overlap*100)
414
+ color = green
415
+ img, _ = draw_text_in_image(img, text, (margin + line_width, v_pos), color, line_width)
416
+ v_pos += int(bottom_border / 2.0)
417
+ rank_pos = str(idx+1)
418
+ img, line_width = draw_text_in_image(img, text, (margin, v_pos), white, 0)
419
+ color = light_red
420
+ if status == "MATCH!":
421
+ color = green
422
+ text = "Result: " + status + " "
423
+ img, line_width = draw_text_in_image(img, text, (margin + line_width, v_pos), color, line_width)
424
+
425
+ font = cv2.FONT_HERSHEY_SIMPLEX
426
+ if ovmax > 0:
427
+ bbgt = [ int(round(float(x))) for x in gt_match["bbox"].split() ]
428
+ cv2.rectangle(img,(bbgt[0],bbgt[1]),(bbgt[2],bbgt[3]),light_blue,2)
429
+ cv2.rectangle(img_cumulative,(bbgt[0],bbgt[1]),(bbgt[2],bbgt[3]),light_blue,2)
430
+ cv2.putText(img_cumulative, class_name, (bbgt[0],bbgt[1] - 5), font, 0.6, light_blue, 1, cv2.LINE_AA)
431
+ bb = [int(i) for i in bb]
432
+ cv2.rectangle(img,(bb[0],bb[1]),(bb[2],bb[3]),color,2)
433
+ cv2.rectangle(img_cumulative,(bb[0],bb[1]),(bb[2],bb[3]),color,2)
434
+ cv2.putText(img_cumulative, class_name, (bb[0],bb[1] - 5), font, 0.6, color, 1, cv2.LINE_AA)
435
+
436
+ cv2.imshow("Animation", img)
437
+ cv2.waitKey(20)
438
+ output_img_path = RESULTS_FILES_PATH + "/images/detections_one_by_one/" + class_name + "_detection" + str(idx) + ".jpg"
439
+ cv2.imwrite(output_img_path, img)
440
+ cv2.imwrite(img_cumulative_path, img_cumulative)
441
+
442
+ cumsum = 0
443
+ for idx, val in enumerate(fp):
444
+ fp[idx] += cumsum
445
+ cumsum += val
446
+
447
+ cumsum = 0
448
+ for idx, val in enumerate(tp):
449
+ tp[idx] += cumsum
450
+ cumsum += val
451
+
452
+ rec = tp[:]
453
+ for idx, val in enumerate(tp):
454
+ rec[idx] = float(tp[idx]) / np.maximum(gt_counter_per_class[class_name], 1)
455
+
456
+ prec = tp[:]
457
+ for idx, val in enumerate(tp):
458
+ prec[idx] = float(tp[idx]) / np.maximum((fp[idx] + tp[idx]), 1)
459
+
460
+ ap, mrec, mprec = voc_ap(rec[:], prec[:])
461
+ F1 = np.array(rec)*np.array(prec)*2 / np.where((np.array(prec)+np.array(rec))==0, 1, (np.array(prec)+np.array(rec)))
462
+
463
+ sum_AP += ap
464
+
465
+ if len(prec)>0:
466
+ F1_text = "{0:.2f}".format(F1[score_threhold_idx]) + " = " + class_name + " F1 "
467
+ Recall_text = "{0:.2f}%".format(rec[score_threhold_idx]*100) + " = " + class_name + " Recall "
468
+ Precision_text = "{0:.2f}%".format(prec[score_threhold_idx]*100) + " = " + class_name + " Precision "
469
+ else:
470
+ F1_text = "0.00" + " = " + class_name + " F1 "
471
+ Recall_text = "0.00%" + " = " + class_name + " Recall "
472
+ Precision_text = "0.00%" + " = " + class_name + " Precision "
473
+
474
+ rounded_prec = [ '%.2f' % elem for elem in prec ]
475
+ rounded_rec = [ '%.2f' % elem for elem in rec ]
476
+ results_file.write(text + "\n Precision: " + str(rounded_prec) + "\n Recall :" + str(rounded_rec) + "\n\n")
477
+
478
+ if len(prec)>0:
479
+ print(text + "\t||\tscore_threhold=" + str(score_threhold) + " : " + "F1=" + "{0:.2f}".format(F1[score_threhold_idx])\
480
+ + " ; Recall=" + "{0:.2f}%".format(rec[score_threhold_idx]*100) + " ; Precision=" + "{0:.2f}%".format(prec[score_threhold_idx]*100))
481
+ else:
482
+ print(text + "\t||\tscore_threhold=" + str(score_threhold) + " : " + "F1=0.00% ; Recall=0.00% ; Precision=0.00%")
483
+ ap_dictionary[class_name] = ap
484
+
485
+ n_images = counter_images_per_class[class_name]
486
+ lamr, mr, fppi = log_average_miss_rate(np.array(rec), np.array(fp), n_images)
487
+ lamr_dictionary[class_name] = lamr
488
+
489
+ if draw_plot:
+ plt.plot(rec, prec, '-o')
+ area_under_curve_x = mrec[:-1] + [mrec[-2]] + [mrec[-1]]
+ area_under_curve_y = mprec[:-1] + [0.0] + [mprec[-1]]
+ plt.fill_between(area_under_curve_x, 0, area_under_curve_y, alpha=0.2, edgecolor='r')
+
+ fig = plt.gcf()
+ fig.canvas.manager.set_window_title('AP ' + class_name)
+
+ plt.title('class: ' + text)
+ plt.xlabel('Recall')
+ plt.ylabel('Precision')
+ axes = plt.gca()
+ axes.set_xlim([0.0,1.0])
+ axes.set_ylim([0.0,1.05])
+ fig.savefig(RESULTS_FILES_PATH + "/AP/" + class_name + ".png")
+ plt.cla()
+
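+ # F1, Recall, and Precision plotted against the score threshold, one figure per class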
+ plt.plot(score, F1, "-", color='orangered')
+ plt.title('class: ' + F1_text + "\nscore_threshold=" + str(score_threhold))
+ plt.xlabel('Score_Threshold')
+ plt.ylabel('F1')
+ axes = plt.gca()
+ axes.set_xlim([0.0,1.0])
+ axes.set_ylim([0.0,1.05])
+ fig.savefig(RESULTS_FILES_PATH + "/F1/" + class_name + ".png")
+ plt.cla()
+
+ plt.plot(score, rec, "-H", color='gold')
+ plt.title('class: ' + Recall_text + "\nscore_threshold=" + str(score_threhold))
+ plt.xlabel('Score_Threshold')
+ plt.ylabel('Recall')
+ axes = plt.gca()
+ axes.set_xlim([0.0,1.0])
+ axes.set_ylim([0.0,1.05])
+ fig.savefig(RESULTS_FILES_PATH + "/Recall/" + class_name + ".png")
+ plt.cla()
+
+ plt.plot(score, prec, "-s", color='palevioletred')
+ plt.title('class: ' + Precision_text + "\nscore_threshold=" + str(score_threhold))
+ plt.xlabel('Score_Threshold')
+ plt.ylabel('Precision')
+ axes = plt.gca()
+ axes.set_xlim([0.0,1.0])
+ axes.set_ylim([0.0,1.05])
+ fig.savefig(RESULTS_FILES_PATH + "/Precision/" + class_name + ".png")
+ plt.cla()
+
+ if show_animation:
+ cv2.destroyAllWindows()
+ if n_classes == 0:
+ print("No classes were detected; check that the label files and the classes_path in get_map.py are set correctly.")
+ return 0
+ mAP = sum_AP / n_classes
+ text = "mAP = {0:.2f}%".format(mAP*100)
+ results_file.write(text + "\n")
+ print(text)
+
+ shutil.rmtree(TEMP_FILES_PATH)
+
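+ # Count how many detections each class produced across all detection-result files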
+ det_counter_per_class = {}
+ for txt_file in dr_files_list:
+ lines_list = file_lines_to_list(txt_file)
+ for line in lines_list:
+ class_name = line.split()[0]
+ if class_name in det_counter_per_class:
+ det_counter_per_class[class_name] += 1
+ else:
+ det_counter_per_class[class_name] = 1
+ dr_classes = list(det_counter_per_class.keys())
+
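+ # Append the number of ground-truth objects per class to results.txt; classes detected but absent from the ground truth get zero true positives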
+ with open(RESULTS_FILES_PATH + "/results.txt", 'a') as results_file:
+ for class_name in sorted(gt_counter_per_class):
+ results_file.write(class_name + ": " + str(gt_counter_per_class[class_name]) + "\n")
+
+ for class_name in dr_classes:
+ if class_name not in gt_classes:
+ count_true_positives[class_name] = 0
+
+ with open(RESULTS_FILES_PATH + "/results.txt", 'a') as results_file:
+ for class_name in sorted(dr_classes):
+ n_det = det_counter_per_class[class_name]
+ text = class_name + ": " + str(n_det)
+ text += " (tp:" + str(count_true_positives[class_name])
+ text += ", fp:" + str(n_det - count_true_positives[class_name]) + ")\n"
+ results_file.write(text)
+
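+ # Bar chart of ground-truth object counts per class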
+ if draw_plot:
+ window_title = "ground-truth-info"
+ plot_title = "ground-truth\n"
+ plot_title += "(" + str(len(ground_truth_files_list)) + " files and " + str(n_classes) + " classes)"
+ x_label = "Number of objects per class"
+ output_path = RESULTS_FILES_PATH + "/ground-truth-info.png"
+ to_show = False
+ plot_color = 'forestgreen'
+ draw_plot_func(
+ gt_counter_per_class,
+ n_classes,
+ window_title,
+ plot_title,
+ x_label,
+ output_path,
+ to_show,
+ plot_color,
+ '',
+ )
+
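+ # Bar chart of the per-class log-average miss rate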
+ if draw_plot:
+ window_title = "lamr"
+ plot_title = "log-average miss rate"
+ x_label = "log-average miss rate"
+ output_path = RESULTS_FILES_PATH + "/lamr.png"
+ to_show = False
+ plot_color = 'royalblue'
+ draw_plot_func(
+ lamr_dictionary,
+ n_classes,
+ window_title,
+ plot_title,
+ x_label,
+ output_path,
+ to_show,
+ plot_color,
+ ""
+ )
+
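+ # Bar chart of per-class AP; the plot title reports the overall mAP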
+ if draw_plot:
+ window_title = "mAP"
+ plot_title = "mAP = {0:.2f}%".format(mAP*100)
+ x_label = "Average Precision"
+ output_path = RESULTS_FILES_PATH + "/mAP.png"
+ to_show = True
+ plot_color = 'royalblue'
+ draw_plot_func(
+ ap_dictionary,
+ n_classes,
+ window_title,
+ plot_title,
+ x_label,
+ output_path,
+ to_show,
+ plot_color,
+ ""
+ )
+ return mAP
+
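+ # Convert the plain-text ground-truth files into a COCO-style dict ("images", "categories", "annotations"); image width/height are set to 1 as placeholders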
+ def preprocess_gt(gt_path, class_names):
+ image_ids = os.listdir(gt_path)
+ results = {}
+
+ images = []
+ bboxes = []
+ for i, image_id in enumerate(image_ids):
+ lines_list = file_lines_to_list(os.path.join(gt_path, image_id))
+ boxes_per_image = []
+ image = {}
+ image_id = os.path.splitext(image_id)[0]
+ image['file_name'] = image_id + '.jpg'
+ image['width'] = 1
+ image['height'] = 1
+ image['id'] = str(image_id)
+
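+ # Each line reads "<class name> <left> <top> <right> <bottom> [difficult]"; class names may contain spaces, so the box fields are taken from the end of the line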
+ for line in lines_list:
+ difficult = 0
+ if "difficult" in line:
+ line_split = line.split()
+ left, top, right, bottom, _difficult = line_split[-5:]
+ class_name = ""
+ for name in line_split[:-5]:
+ class_name += name + " "
+ class_name = class_name[:-1]
+ difficult = 1
+ else:
+ line_split = line.split()
+ left, top, right, bottom = line_split[-4:]
+ class_name = ""
+ for name in line_split[:-4]:
+ class_name += name + " "
+ class_name = class_name[:-1]
+
+ left, top, right, bottom = float(left), float(top), float(right), float(bottom)
+ if class_name not in class_names:
+ continue
+ cls_id = class_names.index(class_name) + 1
+ bbox = [left, top, right - left, bottom - top, difficult, str(image_id), cls_id, (right - left) * (bottom - top) - 10.0]
+ boxes_per_image.append(bbox)
+ images.append(image)
+ bboxes.extend(boxes_per_image)
+ results['images'] = images
+
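+ # COCO category ids are 1-based; "difficult" boxes are mapped onto COCO's iscrowd flag below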
+ categories = []
+ for i, cls in enumerate(class_names):
+ category = {}
+ category['supercategory'] = cls
+ category['name'] = cls
+ category['id'] = i + 1
+ categories.append(category)
+ results['categories'] = categories
+
+ annotations = []
+ for i, box in enumerate(bboxes):
+ annotation = {}
+ annotation['area'] = box[-1]
+ annotation['category_id'] = box[-2]
+ annotation['image_id'] = box[-3]
+ annotation['iscrowd'] = box[-4]
+ annotation['bbox'] = box[:4]
+ annotation['id'] = i
+ annotations.append(annotation)
+ results['annotations'] = annotations
+ return results
+
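+ # Convert the detection-result files into COCO's results format: one dict per detection with image_id, category_id, bbox ([x, y, w, h]), and score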
+ def preprocess_dr(dr_path, class_names):
+ image_ids = os.listdir(dr_path)
+ results = []
+ for image_id in image_ids:
+ lines_list = file_lines_to_list(os.path.join(dr_path, image_id))
+ image_id = os.path.splitext(image_id)[0]
+ for line in lines_list:
+ line_split = line.split()
+ confidence, left, top, right, bottom = line_split[-5:]
+ class_name = ""
+ for name in line_split[:-5]:
+ class_name += name + " "
+ class_name = class_name[:-1]
+ left, top, right, bottom = float(left), float(top), float(right), float(bottom)
+ result = {}
+ result["image_id"] = str(image_id)
+ if class_name not in class_names:
+ continue
+ result["category_id"] = class_names.index(class_name) + 1
+ result["bbox"] = [left, top, right - left, bottom - top]
+ result["score"] = float(confidence)
+ results.append(result)
+ return results
+
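+ # Export ground truth and detections to JSON, then run the official pycocotools bbox evaluation; returns the 12-element COCO stats vector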
+ def get_coco_map(class_names, path):
+ GT_PATH = os.path.join(path, 'ground-truth')
+ DR_PATH = os.path.join(path, 'detection-results')
+ COCO_PATH = os.path.join(path, 'coco_eval')
+
+ if not os.path.exists(COCO_PATH):
+ os.makedirs(COCO_PATH)
+
+ GT_JSON_PATH = os.path.join(COCO_PATH, 'instances_gt.json')
+ DR_JSON_PATH = os.path.join(COCO_PATH, 'instances_dr.json')
+
+ with open(GT_JSON_PATH, "w") as f:
+ results_gt = preprocess_gt(GT_PATH, class_names)
+ json.dump(results_gt, f, indent=4)
+
+ with open(DR_JSON_PATH, "w") as f:
+ results_dr = preprocess_dr(DR_PATH, class_names)
+ json.dump(results_dr, f, indent=4)
+ if len(results_dr) == 0:
+ print("No objects were detected.")
+ return [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]
+
+ cocoGt = COCO(GT_JSON_PATH)
+ cocoDt = cocoGt.loadRes(DR_JSON_PATH)
+ cocoEval = COCOeval(cocoGt, cocoDt, 'bbox')
+ cocoEval.evaluate()
+ cocoEval.accumulate()
+ cocoEval.summarize()
+
+ return cocoEval.stats