Spaces:
Runtime error
Runtime error
Upload 7 files
Browse files- utils/__init__.py +0 -0
- utils/callbacks.py +202 -0
- utils/dataloader.py +296 -0
- utils/utils.py +83 -0
- utils/utils_bbox.py +272 -0
- utils/utils_fit.py +104 -0
- utils/utils_map.py +757 -0
utils/__init__.py
ADDED
|
File without changes
|
utils/callbacks.py
ADDED
|
@@ -0,0 +1,202 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import datetime
|
| 2 |
+
import os
|
| 3 |
+
|
| 4 |
+
import torch
|
| 5 |
+
import matplotlib
|
| 6 |
+
matplotlib.use('Agg')
|
| 7 |
+
import scipy.signal
|
| 8 |
+
from matplotlib import pyplot as plt
|
| 9 |
+
from torch.utils.tensorboard import SummaryWriter
|
| 10 |
+
|
| 11 |
+
import shutil
|
| 12 |
+
import numpy as np
|
| 13 |
+
|
| 14 |
+
from PIL import Image
|
| 15 |
+
from tqdm import tqdm
|
| 16 |
+
from .utils import cvtColor, preprocess_input, resize_image
|
| 17 |
+
from .utils_bbox import DecodeBox
|
| 18 |
+
from .utils_map import get_coco_map, get_map
|
| 19 |
+
|
| 20 |
+
|
| 21 |
+
class LossHistory():
    """Tracks per-epoch train/val losses, mirrors them to TensorBoard and
    plain-text files, and re-renders a loss-curve PNG after every epoch.

    Files written under ``log_dir``: ``epoch_loss.txt``,
    ``epoch_val_loss.txt``, ``epoch_loss.png``, plus TensorBoard events.
    """

    def __init__(self, log_dir, model, input_shape):
        self.log_dir = log_dir
        self.losses = []      # one training-loss entry per epoch
        self.val_loss = []    # one validation-loss entry per epoch

        # Fix: the original os.makedirs(self.log_dir) raised FileExistsError
        # when the directory was already present (e.g. resumed training).
        os.makedirs(self.log_dir, exist_ok=True)
        self.writer = SummaryWriter(self.log_dir)
        try:
            # Graph tracing is best-effort: some models cannot be traced.
            dummy_input = torch.randn(2, 3, input_shape[0], input_shape[1])
            self.writer.add_graph(model, dummy_input)
        except Exception:
            # Fix: was a bare `except:` which also swallowed KeyboardInterrupt.
            pass

    def append_loss(self, epoch, loss, val_loss):
        """Record one epoch's losses, persist them, and refresh the plot."""
        if not os.path.exists(self.log_dir):
            os.makedirs(self.log_dir)

        self.losses.append(loss)
        self.val_loss.append(val_loss)

        with open(os.path.join(self.log_dir, "epoch_loss.txt"), 'a') as f:
            f.write(str(loss))
            f.write("\n")
        with open(os.path.join(self.log_dir, "epoch_val_loss.txt"), 'a') as f:
            f.write(str(val_loss))
            f.write("\n")

        self.writer.add_scalar('loss', loss, epoch)
        self.writer.add_scalar('val_loss', val_loss, epoch)
        self.loss_plot()

    def loss_plot(self):
        """Render raw and smoothed loss curves to <log_dir>/epoch_loss.png."""
        iters = range(len(self.losses))

        plt.figure()
        plt.plot(iters, self.losses, 'red', linewidth = 2, label='train loss')
        plt.plot(iters, self.val_loss, 'coral', linewidth = 2, label='val loss')
        try:
            # Savitzky-Golay window must be <= the number of points, hence
            # the smaller window early in training.
            if len(self.losses) < 25:
                num = 5
            else:
                num = 15

            plt.plot(iters, scipy.signal.savgol_filter(self.losses, num, 3), 'green', linestyle = '--', linewidth = 2, label='smooth train loss')
        except Exception:
            # Fix: was a bare `except:`. Too few points for the filter
            # window is expected early on; just skip the smoothed curve.
            pass

        plt.grid(True)
        plt.xlabel('Epoch')
        plt.ylabel('Loss')
        plt.legend(loc="upper right")

        plt.savefig(os.path.join(self.log_dir, "epoch_loss.png"))

        plt.cla()
        plt.close("all")
|
| 78 |
+
|
| 79 |
+
class EvalCallback():
    """Periodic validation-set evaluation during training.

    Every ``period`` epochs it runs the current model over ``val_lines``,
    writes per-image detection/ground-truth text files under
    ``map_out_path``, computes mAP via ``get_coco_map`` (falling back to
    ``get_map``), appends it to ``epoch_map.txt`` and re-renders
    ``epoch_map.png`` in ``log_dir``.
    """

    def __init__(self, net, input_shape, anchors, anchors_mask, class_names, num_classes, val_lines, log_dir, cuda, \
            map_out_path=".temp_map_out", max_boxes=100, confidence=0.05, nms_iou=0.5, letterbox_image=True, MINOVERLAP=0.5, eval_flag=True, period=1):
        super(EvalCallback, self).__init__()

        self.net = net
        self.input_shape = input_shape
        self.anchors = anchors
        self.anchors_mask = anchors_mask
        self.class_names = class_names
        self.num_classes = num_classes
        self.val_lines = val_lines
        self.log_dir = log_dir
        self.cuda = cuda
        self.map_out_path = map_out_path        # scratch dir, deleted after each eval
        self.max_boxes = max_boxes              # keep at most this many detections per image
        self.confidence = confidence            # score threshold fed to NMS
        self.nms_iou = nms_iou
        self.letterbox_image = letterbox_image  # pad-resize instead of stretch
        self.MINOVERLAP = MINOVERLAP            # IoU threshold for the mAP fallback
        self.eval_flag = eval_flag              # master on/off switch
        self.period = period                    # evaluate every N epochs

        self.bbox_util = DecodeBox(self.anchors, self.num_classes, (self.input_shape[0], self.input_shape[1]), self.anchors_mask)

        # Seed the curves with a 0-mAP point at epoch 0 so the plot starts
        # at the origin.
        self.maps = [0]
        self.epoches = [0]
        if self.eval_flag:
            with open(os.path.join(self.log_dir, "epoch_map.txt"), 'a') as f:
                f.write(str(0))
                f.write("\n")

    def get_map_txt(self, image_id, image, class_names, map_out_path):
        """Run inference on one PIL image and write its detections to
        ``<map_out_path>/detection-results/<image_id>.txt`` in the
        ``class score left top right bottom`` format expected by the
        mAP tooling.
        """
        f = open(os.path.join(map_out_path, "detection-results/"+image_id+".txt"), "w", encoding='utf-8')
        image_shape = np.array(np.shape(image)[0:2])
        # Ensure 3-channel RGB, then letterbox/stretch to the network input
        # size; note input_shape is (h, w) but resize_image takes (w, h).
        image = cvtColor(image)
        image_data = resize_image(image, (self.input_shape[1], self.input_shape[0]), self.letterbox_image)
        # HWC uint8 -> normalized CHW float32 with a leading batch axis.
        image_data = np.expand_dims(np.transpose(preprocess_input(np.array(image_data, dtype='float32')), (2, 0, 1)), 0)

        with torch.no_grad():
            images = torch.from_numpy(image_data)
            if self.cuda:
                images = images.cuda()
            outputs = self.net(images)
            outputs = self.bbox_util.decode_box(outputs)
            # Stack all feature levels, then NMS back to image coordinates.
            results = self.bbox_util.non_max_suppression(torch.cat(outputs, 1), self.num_classes, self.input_shape,
                        image_shape, self.letterbox_image, conf_thres = self.confidence, nms_thres = self.nms_iou)

            # No detections survived thresholding: leave an empty file.
            if results[0] is None:
                return

            top_label = np.array(results[0][:, 6], dtype = 'int32')
            # Final score = objectness * class confidence.
            top_conf = results[0][:, 4] * results[0][:, 5]
            top_boxes = results[0][:, :4]

        # Keep only the max_boxes highest-scoring detections.
        top_100 = np.argsort(top_conf)[::-1][:self.max_boxes]
        top_boxes = top_boxes[top_100]
        top_conf = top_conf[top_100]
        top_label = top_label[top_100]

        for i, c in list(enumerate(top_label)):
            predicted_class = self.class_names[int(c)]
            box = top_boxes[i]
            score = str(top_conf[i])

            top, left, bottom, right = box
            # Skip classes the caller did not ask to evaluate.
            if predicted_class not in class_names:
                continue

            f.write("%s %s %s %s %s %s\n" % (predicted_class, score[:6], str(int(left)), str(int(top)), str(int(right)),str(int(bottom))))

        f.close()
        return

    def on_epoch_end(self, epoch, model_eval):
        """If this epoch is due for evaluation, compute mAP over the
        validation lines and update the log files and the mAP plot.

        ``model_eval`` replaces ``self.net`` (the eval-mode copy of the
        model supplied by the training loop).
        """
        if epoch % self.period == 0 and self.eval_flag:
            self.net = model_eval
            if not os.path.exists(self.map_out_path):
                os.makedirs(self.map_out_path)
            if not os.path.exists(os.path.join(self.map_out_path, "ground-truth")):
                os.makedirs(os.path.join(self.map_out_path, "ground-truth"))
            if not os.path.exists(os.path.join(self.map_out_path, "detection-results")):
                os.makedirs(os.path.join(self.map_out_path, "detection-results"))
            print("Get map.")
            for annotation_line in tqdm(self.val_lines):
                # Annotation format: "<image_path> x1,y1,x2,y2,cls ...".
                line = annotation_line.split()
                image_id = os.path.basename(line[0]).split('.')[0]
                image = Image.open(line[0])
                gt_boxes = np.array([np.array(list(map(int,box.split(',')))) for box in line[1:]])
                self.get_map_txt(image_id, image, self.class_names, self.map_out_path)

                # Mirror the ground truth into the format the mAP tools read.
                with open(os.path.join(self.map_out_path, "ground-truth/"+image_id+".txt"), "w") as new_f:
                    for box in gt_boxes:
                        left, top, right, bottom, obj = box
                        obj_name = self.class_names[obj]
                        new_f.write("%s %s %s %s %s\n" % (obj_name, left, top, right, bottom))

            print("Calculate Map.")
            try:
                # Preferred: COCO-style mAP; element [1] of its return value.
                temp_map = get_coco_map(class_names = self.class_names, path = self.map_out_path)[1]
            except:
                # NOTE(review): bare except — falls back to VOC-style mAP on
                # ANY failure (including missing pycocotools); consider
                # narrowing.
                temp_map = get_map(self.MINOVERLAP, False, path = self.map_out_path)
            self.maps.append(temp_map)
            self.epoches.append(epoch)

            with open(os.path.join(self.log_dir, "epoch_map.txt"), 'a') as f:
                f.write(str(temp_map))
                f.write("\n")

            plt.figure()
            plt.plot(self.epoches, self.maps, 'red', linewidth = 2, label='train map')

            plt.grid(True)
            plt.xlabel('Epoch')
            plt.ylabel('Map %s'%str(self.MINOVERLAP))
            plt.title('A Map Curve')
            plt.legend(loc="upper right")

            plt.savefig(os.path.join(self.log_dir, "epoch_map.png"))
            plt.cla()
            plt.close("all")

            print("Get map done.")
            # Remove the scratch directory so the next eval starts clean.
            shutil.rmtree(self.map_out_path)
|
utils/dataloader.py
ADDED
|
@@ -0,0 +1,296 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from random import sample, shuffle
|
| 2 |
+
|
| 3 |
+
import cv2
|
| 4 |
+
import numpy as np
|
| 5 |
+
import torch
|
| 6 |
+
from PIL import Image
|
| 7 |
+
from torch.utils.data.dataset import Dataset
|
| 8 |
+
|
| 9 |
+
from utils.utils import cvtColor, preprocess_input
|
| 10 |
+
|
| 11 |
+
|
| 12 |
+
class YoloDataset(Dataset):
    """torch Dataset yielding (image, boxes) pairs for YOLO training.

    Images come back as normalized CHW float32 arrays; boxes as
    (N, 5) float32 arrays of normalized (cx, cy, w, h, class).
    Training mode applies random geometric + HSV augmentation, with
    optional Mosaic (4-image collage) and MixUp during the first
    ``special_aug_ratio`` fraction of ``epoch_length`` epochs.
    """

    def __init__(self, annotation_lines, input_shape, num_classes, epoch_length, \
                 mosaic, mixup, mosaic_prob, mixup_prob, train, special_aug_ratio = 0.7):
        super(YoloDataset, self).__init__()
        self.annotation_lines = annotation_lines    # "<path> x1,y1,x2,y2,cls ..." strings
        self.input_shape = input_shape              # (h, w) network input size
        self.num_classes = num_classes
        self.epoch_length = epoch_length
        self.mosaic = mosaic                        # enable Mosaic augmentation
        self.mosaic_prob = mosaic_prob
        self.mixup = mixup                          # enable MixUp (only after Mosaic)
        self.mixup_prob = mixup_prob
        self.train = train                          # False -> deterministic letterbox only
        self.special_aug_ratio = special_aug_ratio  # Mosaic allowed only early in training

        # Updated externally by the training loop; gates Mosaic via
        # epoch_now < epoch_length * special_aug_ratio below.
        self.epoch_now = -1
        self.length = len(self.annotation_lines)

    def __len__(self):
        return self.length

    def __getitem__(self, index):
        index = index % self.length
        # Mosaic path: sample 3 extra images, collage with the indexed one.
        if self.mosaic and self.rand() < self.mosaic_prob and self.epoch_now < self.epoch_length * self.special_aug_ratio:
            lines = sample(self.annotation_lines, 3)
            lines.append(self.annotation_lines[index])
            shuffle(lines)
            image, box = self.get_random_data_with_Mosaic(lines, self.input_shape)

            # MixUp blends the mosaic with one more augmented image.
            if self.mixup and self.rand() < self.mixup_prob:
                lines = sample(self.annotation_lines, 1)
                image_2, box_2 = self.get_random_data(lines[0], self.input_shape, random = self.train)
                image, box = self.get_random_data_with_MixUp(image, box, image_2, box_2)
        else:
            image, box = self.get_random_data(self.annotation_lines[index], self.input_shape, random = self.train)

        # HWC -> normalized CHW float32.
        image = np.transpose(preprocess_input(np.array(image, dtype=np.float32)), (2, 0, 1))
        box = np.array(box, dtype=np.float32)
        if len(box) != 0:
            # Pixel corners -> normalized coords, then corner -> center format:
            # (x1,y1,x2,y2) becomes (cx, cy, w, h).
            box[:, [0, 2]] = box[:, [0, 2]] / self.input_shape[1]
            box[:, [1, 3]] = box[:, [1, 3]] / self.input_shape[0]

            box[:, 2:4] = box[:, 2:4] - box[:, 0:2]
            box[:, 0:2] = box[:, 0:2] + box[:, 2:4] / 2
        return image, box

    def rand(self, a=0, b=1):
        # Uniform float in [a, b).
        return np.random.rand()*(b-a) + a

    def get_random_data(self, annotation_line, input_shape, jitter=.3, hue=.1, sat=0.7, val=0.4, random=True):
        """Load one annotated image.

        random=False: deterministic letterbox resize (validation).
        random=True: random aspect-ratio jitter, scale, placement,
        horizontal flip and HSV shift. Returns (image_array, boxes) with
        boxes in pixel corner format; degenerate boxes (<1px) dropped.
        """
        line = annotation_line.split()
        image = Image.open(line[0])
        image = cvtColor(image)
        iw, ih = image.size
        h, w = input_shape
        box = np.array([np.array(list(map(int,box.split(',')))) for box in line[1:]])

        if not random:
            # Plain letterbox: keep aspect ratio, pad with gray.
            scale = min(w/iw, h/ih)
            nw = int(iw*scale)
            nh = int(ih*scale)
            dx = (w-nw)//2
            dy = (h-nh)//2

            image = image.resize((nw,nh), Image.BICUBIC)
            new_image = Image.new('RGB', (w,h), (128,128,128))
            new_image.paste(image, (dx, dy))
            image_data = np.array(new_image, np.float32)

            if len(box)>0:
                np.random.shuffle(box)
                # Map boxes through the same resize+offset, clip to the
                # canvas, and drop boxes that became degenerate.
                box[:, [0,2]] = box[:, [0,2]]*nw/iw + dx
                box[:, [1,3]] = box[:, [1,3]]*nh/ih + dy
                box[:, 0:2][box[:, 0:2]<0] = 0
                box[:, 2][box[:, 2]>w] = w
                box[:, 3][box[:, 3]>h] = h
                box_w = box[:, 2] - box[:, 0]
                box_h = box[:, 3] - box[:, 1]
                box = box[np.logical_and(box_w>1, box_h>1)]

            return image_data, box

        # Random aspect-ratio jitter and scale, then random placement.
        new_ar = iw/ih * self.rand(1-jitter,1+jitter) / self.rand(1-jitter,1+jitter)
        scale = self.rand(.25, 2)
        if new_ar < 1:
            nh = int(scale*h)
            nw = int(nh*new_ar)
        else:
            nw = int(scale*w)
            nh = int(nw/new_ar)
        image = image.resize((nw,nh), Image.BICUBIC)
        dx = int(self.rand(0, w-nw))
        dy = int(self.rand(0, h-nh))
        new_image = Image.new('RGB', (w,h), (128,128,128))
        new_image.paste(image, (dx, dy))
        image = new_image

        flip = self.rand()<.5
        if flip: image = image.transpose(Image.FLIP_LEFT_RIGHT)

        image_data = np.array(image, np.uint8)

        # Random HSV gains around 1.0 (hue/sat/val are max deltas here).
        r = np.random.uniform(-1, 1, 3) * [hue, sat, val] + 1

        # NOTE: hue/sat/val are rebound from parameters to channel arrays.
        hue, sat, val = cv2.split(cv2.cvtColor(image_data, cv2.COLOR_RGB2HSV))
        dtype = image_data.dtype

        # Apply the gains via lookup tables; OpenCV hue wraps at 180.
        x = np.arange(0, 256, dtype=r.dtype)
        lut_hue = ((x * r[0]) % 180).astype(dtype)
        lut_sat = np.clip(x * r[1], 0, 255).astype(dtype)
        lut_val = np.clip(x * r[2], 0, 255).astype(dtype)

        image_data = cv2.merge((cv2.LUT(hue, lut_hue), cv2.LUT(sat, lut_sat), cv2.LUT(val, lut_val)))
        image_data = cv2.cvtColor(image_data, cv2.COLOR_HSV2RGB)

        if len(box)>0:
            np.random.shuffle(box)
            # Same geometric transform for the boxes (including the flip).
            box[:, [0,2]] = box[:, [0,2]]*nw/iw + dx
            box[:, [1,3]] = box[:, [1,3]]*nh/ih + dy
            if flip: box[:, [0,2]] = w - box[:, [2,0]]
            box[:, 0:2][box[:, 0:2]<0] = 0
            box[:, 2][box[:, 2]>w] = w
            box[:, 3][box[:, 3]>h] = h
            box_w = box[:, 2] - box[:, 0]
            box_h = box[:, 3] - box[:, 1]
            box = box[np.logical_and(box_w>1, box_h>1)]

        return image_data, box

    def merge_bboxes(self, bboxes, cutx, cuty):
        """Clip/merge the 4 quadrant box lists of a mosaic at the seam
        (cutx, cuty). Quadrant order: 0=top-left, 1=bottom-left,
        2=bottom-right, 3=top-right (matches the paste offsets in
        get_random_data_with_Mosaic)."""
        merge_bbox = []
        for i in range(len(bboxes)):
            for box in bboxes[i]:
                tmp_box = []
                x1, y1, x2, y2 = box[0], box[1], box[2], box[3]

                if i == 0:
                    # Box entirely outside the kept (top-left) region: drop.
                    if y1 > cuty or x1 > cutx:
                        continue
                    # Box straddles the seam: clip to the region.
                    if y2 >= cuty and y1 <= cuty:
                        y2 = cuty
                    if x2 >= cutx and x1 <= cutx:
                        x2 = cutx

                if i == 1:
                    if y2 < cuty or x1 > cutx:
                        continue
                    if y2 >= cuty and y1 <= cuty:
                        y1 = cuty
                    if x2 >= cutx and x1 <= cutx:
                        x2 = cutx

                if i == 2:
                    if y2 < cuty or x2 < cutx:
                        continue
                    if y2 >= cuty and y1 <= cuty:
                        y1 = cuty
                    if x2 >= cutx and x1 <= cutx:
                        x1 = cutx

                if i == 3:
                    if y1 > cuty or x2 < cutx:
                        continue
                    if y2 >= cuty and y1 <= cuty:
                        y2 = cuty
                    if x2 >= cutx and x1 <= cutx:
                        x1 = cutx
                tmp_box.append(x1)
                tmp_box.append(y1)
                tmp_box.append(x2)
                tmp_box.append(y2)
                tmp_box.append(box[-1])  # class id preserved
                merge_bbox.append(tmp_box)
        return merge_bbox

    def get_random_data_with_Mosaic(self, annotation_line, input_shape, jitter=0.3, hue=.1, sat=0.7, val=0.4):
        """Build a 4-image mosaic: each image is independently jittered,
        scaled and flipped, pasted into one quadrant around a random
        seam (min_offset_x/y in [0.3, 0.7]), then a joint HSV shift is
        applied and boxes are clipped at the seam.

        ``annotation_line`` is expected to hold exactly 4 annotation
        strings (index only cycles through 4 quadrant placements).
        """
        h, w = input_shape
        min_offset_x = self.rand(0.3, 0.7)
        min_offset_y = self.rand(0.3, 0.7)

        image_datas = []
        box_datas = []
        index = 0
        for line in annotation_line:
            line_content = line.split()
            image = Image.open(line_content[0])
            image = cvtColor(image)

            iw, ih = image.size
            box = np.array([np.array(list(map(int,box.split(',')))) for box in line_content[1:]])

            flip = self.rand()<.5
            if flip and len(box)>0:
                image = image.transpose(Image.FLIP_LEFT_RIGHT)
                box[:, [0,2]] = iw - box[:, [2,0]]

            # Per-image aspect jitter and scale (smaller range than the
            # single-image path since four images share the canvas).
            new_ar = iw/ih * self.rand(1-jitter,1+jitter) / self.rand(1-jitter,1+jitter)
            scale = self.rand(.4, 1)
            if new_ar < 1:
                nh = int(scale*h)
                nw = int(nh*new_ar)
            else:
                nw = int(scale*w)
                nh = int(nw/new_ar)
            image = image.resize((nw, nh), Image.BICUBIC)

            # Anchor each image so its far corner meets the seam point.
            if index == 0:
                dx = int(w*min_offset_x) - nw
                dy = int(h*min_offset_y) - nh
            elif index == 1:
                dx = int(w*min_offset_x) - nw
                dy = int(h*min_offset_y)
            elif index == 2:
                dx = int(w*min_offset_x)
                dy = int(h*min_offset_y)
            elif index == 3:
                dx = int(w*min_offset_x)
                dy = int(h*min_offset_y) - nh

            new_image = Image.new('RGB', (w,h), (128,128,128))
            new_image.paste(image, (dx, dy))
            image_data = np.array(new_image)

            index = index + 1
            box_data = []
            if len(box)>0:
                np.random.shuffle(box)
                # Transform, clip, and drop degenerate boxes, as above.
                box[:, [0,2]] = box[:, [0,2]]*nw/iw + dx
                box[:, [1,3]] = box[:, [1,3]]*nh/ih + dy
                box[:, 0:2][box[:, 0:2]<0] = 0
                box[:, 2][box[:, 2]>w] = w
                box[:, 3][box[:, 3]>h] = h
                box_w = box[:, 2] - box[:, 0]
                box_h = box[:, 3] - box[:, 1]
                box = box[np.logical_and(box_w>1, box_h>1)]
                box_data = np.zeros((len(box),5))
                box_data[:len(box)] = box

            image_datas.append(image_data)
            box_datas.append(box_data)

        # Stitch the four quadrants along the seam.
        cutx = int(w * min_offset_x)
        cuty = int(h * min_offset_y)

        new_image = np.zeros([h, w, 3])
        new_image[:cuty, :cutx, :] = image_datas[0][:cuty, :cutx, :]
        new_image[cuty:, :cutx, :] = image_datas[1][cuty:, :cutx, :]
        new_image[cuty:, cutx:, :] = image_datas[2][cuty:, cutx:, :]
        new_image[:cuty, cutx:, :] = image_datas[3][:cuty, cutx:, :]

        # Joint HSV shift on the stitched canvas (same LUT scheme as
        # get_random_data).
        new_image = np.array(new_image, np.uint8)
        r = np.random.uniform(-1, 1, 3) * [hue, sat, val] + 1
        hue, sat, val = cv2.split(cv2.cvtColor(new_image, cv2.COLOR_RGB2HSV))
        dtype = new_image.dtype
        x = np.arange(0, 256, dtype=r.dtype)
        lut_hue = ((x * r[0]) % 180).astype(dtype)
        lut_sat = np.clip(x * r[1], 0, 255).astype(dtype)
        lut_val = np.clip(x * r[2], 0, 255).astype(dtype)

        new_image = cv2.merge((cv2.LUT(hue, lut_hue), cv2.LUT(sat, lut_sat), cv2.LUT(val, lut_val)))
        new_image = cv2.cvtColor(new_image, cv2.COLOR_HSV2RGB)

        new_boxes = self.merge_bboxes(box_datas, cutx, cuty)

        return new_image, new_boxes

    def get_random_data_with_MixUp(self, image_1, box_1, image_2, box_2):
        """Blend two images 50/50 and concatenate their box lists."""
        new_image = np.array(image_1, np.float32) * 0.5 + np.array(image_2, np.float32) * 0.5
        if len(box_1) == 0:
            new_boxes = box_2
        elif len(box_2) == 0:
            new_boxes = box_1
        else:
            new_boxes = np.concatenate([box_1, box_2], axis=0)
        return new_image, new_boxes
|
| 287 |
+
|
| 288 |
+
def yolo_dataset_collate(batch):
    """DataLoader collate_fn for YoloDataset samples.

    Stacks the images into one float32 tensor and converts each sample's
    (variable-length) box array into its own float32 tensor, since box
    counts differ per image and cannot be stacked.
    """
    image_arrays = [pair[0] for pair in batch]
    box_arrays = [pair[1] for pair in batch]

    stacked_images = torch.from_numpy(np.array(image_arrays)).type(torch.FloatTensor)
    box_tensors = [torch.from_numpy(arr).type(torch.FloatTensor) for arr in box_arrays]
    return stacked_images, box_tensors
|
utils/utils.py
ADDED
|
@@ -0,0 +1,83 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import random
|
| 2 |
+
|
| 3 |
+
import numpy as np
|
| 4 |
+
import torch
|
| 5 |
+
from PIL import Image
|
| 6 |
+
|
| 7 |
+
|
| 8 |
+
def cvtColor(image):
    """Ensure *image* is 3-channel RGB.

    An input whose shape is already (H, W, 3) is returned untouched;
    anything else (grayscale, RGBA, palette images) is converted through
    ``image.convert('RGB')``.
    """
    shape = np.shape(image)
    if len(shape) == 3 and shape[2] == 3:
        return image
    return image.convert('RGB')
|
| 14 |
+
|
| 15 |
+
|
| 16 |
+
|
| 17 |
+
def resize_image(image, size, letterbox_image):
    """Resize a PIL image to *size* (width, height).

    With ``letterbox_image`` the aspect ratio is preserved and the
    leftover area is padded gray (128, 128, 128); otherwise the image is
    simply stretched to fit.
    """
    target_w, target_h = size
    if not letterbox_image:
        return image.resize((target_w, target_h), Image.BICUBIC)

    source_w, source_h = image.size
    ratio = min(target_w / source_w, target_h / source_h)
    scaled_w = int(source_w * ratio)
    scaled_h = int(source_h * ratio)

    scaled = image.resize((scaled_w, scaled_h), Image.BICUBIC)
    canvas = Image.new('RGB', size, (128, 128, 128))
    canvas.paste(scaled, ((target_w - scaled_w) // 2, (target_h - scaled_h) // 2))
    return canvas
|
| 31 |
+
|
| 32 |
+
|
| 33 |
+
def get_classes(classes_path):
    """Read class names, one per line, from *classes_path*.

    Returns ``(names, count)``. Each line is stripped of surrounding
    whitespace (including the trailing newline).
    """
    with open(classes_path, encoding='utf-8') as handle:
        names = [entry.strip() for entry in handle.readlines()]
    return names, len(names)
|
| 38 |
+
|
| 39 |
+
|
| 40 |
+
def get_anchors(anchors_path):
    """Load YOLO anchors from the first line of *anchors_path*.

    The line is a flat comma-separated list of numbers that is reshaped
    into (width, height) pairs. Returns ``(anchors, count)`` where
    ``anchors`` has shape (-1, 2).
    """
    with open(anchors_path, encoding='utf-8') as handle:
        first_line = handle.readline()
    values = [float(token) for token in first_line.split(',')]
    anchors = np.array(values).reshape(-1, 2)
    return anchors, len(anchors)
|
| 47 |
+
|
| 48 |
+
|
| 49 |
+
def get_lr(optimizer):
    """Return the learning rate of the optimizer's first parameter group,
    or ``None`` when the optimizer has no parameter groups."""
    groups = optimizer.param_groups
    if not groups:
        return None
    return groups[0]['lr']
|
| 52 |
+
|
| 53 |
+
|
| 54 |
+
def seed_everything(seed=11):
    """Seed every RNG in use (python, numpy, torch CPU and CUDA) and put
    cuDNN into deterministic mode for reproducible training runs."""
    for seeder in (
        random.seed,
        np.random.seed,
        torch.manual_seed,
        torch.cuda.manual_seed,
        torch.cuda.manual_seed_all,
    ):
        seeder(seed)
    # Trade cuDNN autotuning speed for bitwise-repeatable convolutions.
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False
|
| 62 |
+
|
| 63 |
+
|
| 64 |
+
def worker_init_fn(worker_id, rank, seed):
    """Seed the RNGs inside a DataLoader worker process.

    NOTE(review): ``worker_id`` is not folded into the seed, so every
    worker on the same rank draws an identical random stream — confirm
    that is intended.
    """
    combined_seed = rank + seed
    for seeder in (random.seed, np.random.seed, torch.manual_seed):
        seeder(combined_seed)
|
| 69 |
+
|
| 70 |
+
|
| 71 |
+
def preprocess_input(image):
    # Scale pixel values from [0, 255] down to [0, 1].
    # NOTE: `/=` divides IN PLACE — the caller's float array is mutated as
    # well as returned (and an integer array would raise here).
    image /= 255.0
    return image
|
| 74 |
+
|
| 75 |
+
|
| 76 |
+
def show_config(**kwargs):
    """Pretty-print configuration key/value pairs as a fixed-width table."""
    rule = '-' * 70
    print('Configurations:')
    print(rule)
    print('|%25s | %40s|' % ('keys', 'values'))
    print(rule)
    for key, value in kwargs.items():
        print('|%25s | %40s|' % (str(key), str(value)))
    print(rule)
|
utils/utils_bbox.py
ADDED
|
@@ -0,0 +1,272 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
|
| 2 |
+
import torch.nn as nn
|
| 3 |
+
from torchvision.ops import nms
|
| 4 |
+
import numpy as np
|
| 5 |
+
|
| 6 |
+
class DecodeBox():
|
| 7 |
+
def __init__(self, anchors, num_classes, input_shape, anchors_mask = [[6,7,8], [3,4,5], [0,1,2]]):
|
| 8 |
+
super(DecodeBox, self).__init__()
|
| 9 |
+
self.anchors = anchors
|
| 10 |
+
self.num_classes = num_classes
|
| 11 |
+
self.bbox_attrs = 5 + num_classes
|
| 12 |
+
self.input_shape = input_shape
|
| 13 |
+
self.anchors_mask = anchors_mask
|
| 14 |
+
|
| 15 |
+
def decode_box(self, inputs):
    """Turn raw YOLO head outputs into normalized candidate boxes.

    ``inputs`` is one raw feature map per level, each of shape
    (batch, num_anchors_at_level * bbox_attrs, grid_h, grid_w). Returns a
    list with one tensor per level of shape
    (batch, grid_h*grid_w*num_anchors, 4 + 1 + num_classes), where the 4
    box values (cx, cy, w, h) are normalized to [0, 1] of the grid.
    """
    outputs = []
    for i, input in enumerate(inputs):
        batch_size = input.size(0)
        input_height = input.size(2)
        input_width = input.size(3)

        # Stride in input-image pixels per grid cell; used to bring the
        # pixel-space anchors into grid units.
        stride_h = self.input_shape[0] / input_height
        stride_w = self.input_shape[1] / input_width
        scaled_anchors = [(anchor_width / stride_w, anchor_height / stride_h) for anchor_width, anchor_height in self.anchors[self.anchors_mask[i]]]

        # (b, a*attrs, h, w) -> (b, a, h, w, attrs).
        prediction = input.view(batch_size, len(self.anchors_mask[i]),
                                self.bbox_attrs, input_height, input_width).permute(0, 1, 3, 4, 2).contiguous()

        # Sigmoid the center offsets, objectness and class scores; w/h
        # stay raw because they feed torch.exp below.
        x = torch.sigmoid(prediction[..., 0])
        y = torch.sigmoid(prediction[..., 1])
        w = prediction[..., 2]
        h = prediction[..., 3]
        conf = torch.sigmoid(prediction[..., 4])
        pred_cls = torch.sigmoid(prediction[..., 5:])

        # Build new tensors on the same device as the input.
        FloatTensor = torch.cuda.FloatTensor if x.is_cuda else torch.FloatTensor
        LongTensor = torch.cuda.LongTensor if x.is_cuda else torch.LongTensor

        # Grid cell coordinates, broadcast to the prediction shape.
        grid_x = torch.linspace(0, input_width - 1, input_width).repeat(input_height, 1).repeat(
            batch_size * len(self.anchors_mask[i]), 1, 1).view(x.shape).type(FloatTensor)
        grid_y = torch.linspace(0, input_height - 1, input_height).repeat(input_width, 1).t().repeat(
            batch_size * len(self.anchors_mask[i]), 1, 1).view(y.shape).type(FloatTensor)

        # Anchor widths/heights tiled over every grid cell.
        anchor_w = FloatTensor(scaled_anchors).index_select(1, LongTensor([0]))
        anchor_h = FloatTensor(scaled_anchors).index_select(1, LongTensor([1]))
        anchor_w = anchor_w.repeat(batch_size, 1).repeat(1, 1, input_height * input_width).view(w.shape)
        anchor_h = anchor_h.repeat(batch_size, 1).repeat(1, 1, input_height * input_width).view(h.shape)

        # YOLO decoding: center = cell + sigmoid offset, size = anchor *
        # exp(raw), all in grid units.
        pred_boxes = FloatTensor(prediction[..., :4].shape)
        pred_boxes[..., 0] = x.data + grid_x
        pred_boxes[..., 1] = y.data + grid_y
        pred_boxes[..., 2] = torch.exp(w.data) * anchor_w
        pred_boxes[..., 3] = torch.exp(h.data) * anchor_h

        # Normalize grid-unit boxes to [0, 1] and append scores.
        _scale = torch.Tensor([input_width, input_height, input_width, input_height]).type(FloatTensor)
        output = torch.cat((pred_boxes.view(batch_size, -1, 4) / _scale,
                            conf.view(batch_size, -1, 1), pred_cls.view(batch_size, -1, self.num_classes)), -1)
        outputs.append(output.data)
    return outputs
|
| 60 |
+
|
| 61 |
+
def yolo_correct_boxes(self, box_xy, box_wh, input_shape, image_shape, letterbox_image):
    """Map normalized (x, y)/(w, h) box predictions back onto the original image.

    box_xy / box_wh are in network-input space with axes ordered (x, y); the
    result is an array of (y_min, x_min, y_max, x_max) boxes scaled by
    image_shape (h, w).  When letterbox_image is True, the padding added while
    letterboxing the image into the network input is undone first.
    """
    # Flip to (y, x)/(h, w) order so everything lines up with image_shape.
    yx = box_xy[..., ::-1]
    hw = box_wh[..., ::-1]
    input_shape = np.array(input_shape)
    image_shape = np.array(image_shape)

    if letterbox_image:
        # Size of the valid (unpadded) region inside the network input.
        valid = np.round(image_shape * np.min(input_shape / image_shape))
        pad_frac = (input_shape - valid) / 2. / input_shape
        stretch = input_shape / valid
        yx = (yx - pad_frac) * stretch
        # In-place on purpose: hw is a view of box_wh, matching the
        # original implementation's side effect.
        hw *= stretch

    half = hw / 2.
    top_left = yx - half
    bottom_right = yx + half
    boxes = np.concatenate(
        [top_left[..., 0:1], top_left[..., 1:2],
         bottom_right[..., 0:1], bottom_right[..., 1:2]], axis=-1)
    # Scale (y, x, y, x) by (h, w, h, w) to get pixel coordinates.
    boxes *= np.concatenate([image_shape, image_shape], axis=-1)
    return boxes
|
| 80 |
+
|
| 81 |
+
def non_max_suppression(self, prediction, num_classes, input_shape, image_shape, letterbox_image, conf_thres=0.5, nms_thres=0.4):
    """Filter decoded predictions by confidence and per-class NMS.

    prediction: tensor of shape (batch, num_boxes, 4 + 1 + num_classes) with
    boxes as (cx, cy, w, h).  NOTE: prediction is modified in place — its
    first 4 channels are rewritten to corner form (x1, y1, x2, y2).
    Returns a list (one entry per image) of numpy arrays
    (x1, y1, x2, y2, obj_conf, class_conf, class_pred) in original-image
    pixel coordinates, or None for images with no surviving detections.
    """
    # Convert (cx, cy, w, h) -> (x1, y1, x2, y2); written back into
    # `prediction` itself.
    box_corner = prediction.new(prediction.shape)
    box_corner[:, :, 0] = prediction[:, :, 0] - prediction[:, :, 2] / 2
    box_corner[:, :, 1] = prediction[:, :, 1] - prediction[:, :, 3] / 2
    box_corner[:, :, 2] = prediction[:, :, 0] + prediction[:, :, 2] / 2
    box_corner[:, :, 3] = prediction[:, :, 1] + prediction[:, :, 3] / 2
    prediction[:, :, :4] = box_corner[:, :, :4]

    output = [None for _ in range(len(prediction))]
    for i, image_pred in enumerate(prediction):
        # Best class score and its index for every box.
        class_conf, class_pred = torch.max(image_pred[:, 5:5 + num_classes], 1, keepdim=True)

        # Keep boxes whose objectness * class score clears the threshold.
        conf_mask = (image_pred[:, 4] * class_conf[:, 0] >= conf_thres).squeeze()

        image_pred = image_pred[conf_mask]
        class_conf = class_conf[conf_mask]
        class_pred = class_pred[conf_mask]
        if not image_pred.size(0):
            continue
        # detections: (x1, y1, x2, y2, obj_conf, class_conf, class_pred)
        detections = torch.cat((image_pred[:, :5], class_conf.float(), class_pred.float()), 1)

        unique_labels = detections[:, -1].cpu().unique()

        if prediction.is_cuda:
            unique_labels = unique_labels.cuda()
            detections = detections.cuda()

        for c in unique_labels:
            detections_class = detections[detections[:, -1] == c]

            # `nms` is expected to be torchvision.ops.nms, imported at the
            # top of the module (outside this view) — TODO confirm.
            keep = nms(
                detections_class[:, :4],
                detections_class[:, 4] * detections_class[:, 5],
                nms_thres
            )
            max_detections = detections_class[keep]

            output[i] = max_detections if output[i] is None else torch.cat((output[i], max_detections))

        if output[i] is not None:
            output[i] = output[i].cpu().numpy()
            # Back to center/size form, then undo input scaling/letterboxing.
            box_xy, box_wh = (output[i][:, 0:2] + output[i][:, 2:4])/2, output[i][:, 2:4] - output[i][:, 0:2]
            output[i][:, :4] = self.yolo_correct_boxes(box_xy, box_wh, input_shape, image_shape, letterbox_image)
    return output
|
| 126 |
+
|
| 127 |
+
class DecodeBoxNP():
    """NumPy-only twin of DecodeBox: decodes raw YOLO head outputs into
    scored boxes and applies per-class non-max suppression without torch.
    """
    # NOTE(review): mutable default for anchors_mask is shared across calls;
    # it is only read here, so it is harmless, but do not mutate it.
    def __init__(self, anchors, num_classes, input_shape, anchors_mask = [[6,7,8], [3,4,5], [0,1,2]]):
        super(DecodeBoxNP, self).__init__()
        self.anchors = anchors
        self.num_classes = num_classes
        # x, y, w, h, objectness + one score per class.
        self.bbox_attrs = 5 + num_classes
        self.input_shape = input_shape
        self.anchors_mask = anchors_mask

    def sigmoid(self, x):
        """Elementwise logistic function."""
        return 1 / (1 + np.exp(-x))

    def decode_box(self, inputs):
        """Decode each feature map into boxes normalized to feature-map size.

        Each element of *inputs* is assumed to be a
        (batch, anchors*bbox_attrs, h, w) array — TODO confirm against caller.
        Returns a list of (batch, h*w*anchors, 4 + 1 + num_classes) arrays.
        """
        outputs = []
        for i, input in enumerate(inputs):
            batch_size = np.shape(input)[0]
            input_height = np.shape(input)[2]
            input_width = np.shape(input)[3]

            # How many input pixels one feature-map cell covers.
            stride_h = self.input_shape[0] / input_height
            stride_w = self.input_shape[1] / input_width
            # Anchors in feature-map cells instead of input pixels.
            # NOTE(review): indexing self.anchors with a list assumes it is a
            # numpy array — confirm how it is constructed.
            scaled_anchors = [(anchor_width / stride_w, anchor_height / stride_h) for anchor_width, anchor_height in self.anchors[self.anchors_mask[i]]]

            # (b, anchors, attrs, h, w) -> (b, anchors, h, w, attrs)
            prediction = np.transpose(np.reshape(input, (batch_size, len(self.anchors_mask[i]), self.bbox_attrs, input_height, input_width)), (0, 1, 3, 4, 2))

            # Center offsets and confidences go through a sigmoid;
            # width/height stay raw because they feed exp() below.
            x = self.sigmoid(prediction[..., 0])
            y = self.sigmoid(prediction[..., 1])
            w = prediction[..., 2]
            h = prediction[..., 3]
            conf = self.sigmoid(prediction[..., 4])
            pred_cls = self.sigmoid(prediction[..., 5:])

            # Per-cell grid coordinates broadcast to the prediction shape.
            grid_x = np.repeat(np.expand_dims(np.repeat(np.expand_dims(np.linspace(0, input_width - 1, input_width), 0), input_height, axis=0), 0), batch_size * len(self.anchors_mask[i]), axis=0)
            grid_x = np.reshape(grid_x, np.shape(x))
            grid_y = np.repeat(np.expand_dims(np.repeat(np.expand_dims(np.linspace(0, input_height - 1, input_height), 0), input_width, axis=0).T, 0), batch_size * len(self.anchors_mask[i]), axis=0)
            grid_y = np.reshape(grid_y, np.shape(y))

            # Anchor sizes tiled over batch and all grid cells.
            anchor_w = np.repeat(np.expand_dims(np.repeat(np.expand_dims(np.array(scaled_anchors)[:, 0], 0), batch_size, axis=0), -1), input_height * input_width, axis=-1)
            anchor_h = np.repeat(np.expand_dims(np.repeat(np.expand_dims(np.array(scaled_anchors)[:, 1], 0), batch_size, axis=0), -1), input_height * input_width, axis=-1)
            anchor_w = np.reshape(anchor_w, np.shape(w))
            anchor_h = np.reshape(anchor_h, np.shape(h))
            # Standard YOLO decode: center = sigmoid(offset) + cell index,
            # size = exp(raw) * anchor.
            pred_boxes = np.zeros(np.shape(prediction[..., :4]))
            pred_boxes[..., 0] = x + grid_x
            pred_boxes[..., 1] = y + grid_y
            pred_boxes[..., 2] = np.exp(w) * anchor_w
            pred_boxes[..., 3] = np.exp(h) * anchor_h

            # Normalize boxes to [0, 1] relative to the feature map.
            _scale = np.array([input_width, input_height, input_width, input_height])
            output = np.concatenate([np.reshape(pred_boxes, (batch_size, -1, 4)) / _scale,
                                    np.reshape(conf, (batch_size, -1, 1)), np.reshape(pred_cls, (batch_size, -1, self.num_classes))], -1)
            outputs.append(output)
        return outputs

    def bbox_iou(self, box1, box2, x1y1x2y2=True):
        """
        Compute the IoU between box1 and box2 (corner form unless
        x1y1x2y2=False, in which case center/size form is converted first).
        """
        if not x1y1x2y2:
            b1_x1, b1_x2 = box1[:, 0] - box1[:, 2] / 2, box1[:, 0] + box1[:, 2] / 2
            b1_y1, b1_y2 = box1[:, 1] - box1[:, 3] / 2, box1[:, 1] + box1[:, 3] / 2
            b2_x1, b2_x2 = box2[:, 0] - box2[:, 2] / 2, box2[:, 0] + box2[:, 2] / 2
            b2_y1, b2_y2 = box2[:, 1] - box2[:, 3] / 2, box2[:, 1] + box2[:, 3] / 2
        else:
            b1_x1, b1_y1, b1_x2, b1_y2 = box1[:, 0], box1[:, 1], box1[:, 2], box1[:, 3]
            b2_x1, b2_y1, b2_x2, b2_y2 = box2[:, 0], box2[:, 1], box2[:, 2], box2[:, 3]

        # Intersection rectangle (empty intersections clamp to 0 below).
        inter_rect_x1 = np.maximum(b1_x1, b2_x1)
        inter_rect_y1 = np.maximum(b1_y1, b2_y1)
        inter_rect_x2 = np.minimum(b1_x2, b2_x2)
        inter_rect_y2 = np.minimum(b1_y2, b2_y2)

        inter_area = np.maximum(inter_rect_x2 - inter_rect_x1, 0) * \
                     np.maximum(inter_rect_y2 - inter_rect_y1, 0)

        b1_area = (b1_x2 - b1_x1) * (b1_y2 - b1_y1)
        b2_area = (b2_x2 - b2_x1) * (b2_y2 - b2_y1)

        # 1e-6 guards against division by zero for degenerate boxes.
        iou = inter_area / np.maximum(b1_area + b2_area - inter_area, 1e-6)

        return iou

    def yolo_correct_boxes(self, box_xy, box_wh, input_shape, image_shape, letterbox_image):
        """Map normalized (x, y)/(w, h) boxes to (y1, x1, y2, x2) pixel
        coordinates on the original image, undoing letterbox padding when
        letterbox_image is True."""
        # Flip to (y, x)/(h, w) order to match image_shape.
        box_yx = box_xy[..., ::-1]
        box_hw = box_wh[..., ::-1]
        input_shape = np.array(input_shape)
        image_shape = np.array(image_shape)

        if letterbox_image:
            # Size of the valid (unpadded) region inside the network input.
            new_shape = np.round(image_shape * np.min(input_shape/image_shape))
            offset = (input_shape - new_shape)/2./input_shape
            scale = input_shape/new_shape

            box_yx = (box_yx - offset) * scale
            # In-place: box_hw is a view of the caller's box_wh array.
            box_hw *= scale

        box_mins = box_yx - (box_hw / 2.)
        box_maxes = box_yx + (box_hw / 2.)
        boxes = np.concatenate([box_mins[..., 0:1], box_mins[..., 1:2], box_maxes[..., 0:1], box_maxes[..., 1:2]], axis=-1)
        # Scale (y, x, y, x) by (h, w, h, w).
        boxes *= np.concatenate([image_shape, image_shape], axis=-1)
        return boxes

    def non_max_suppression(self, prediction, num_classes, input_shape, image_shape, letterbox_image, conf_thres=0.5, nms_thres=0.4):
        """Confidence filtering plus greedy per-class NMS, pure numpy.

        NOTE: prediction's first 4 channels are rewritten in place from
        (cx, cy, w, h) to corner form.  Returns one array (or None) per image:
        (x1, y1, x2, y2, obj_conf, class_conf, class_pred) in original-image
        pixels.
        """
        box_corner = np.zeros_like(prediction)
        box_corner[:, :, 0] = prediction[:, :, 0] - prediction[:, :, 2] / 2
        box_corner[:, :, 1] = prediction[:, :, 1] - prediction[:, :, 3] / 2
        box_corner[:, :, 2] = prediction[:, :, 0] + prediction[:, :, 2] / 2
        box_corner[:, :, 3] = prediction[:, :, 1] + prediction[:, :, 3] / 2
        prediction[:, :, :4] = box_corner[:, :, :4]

        output = [None for _ in range(len(prediction))]
        for i, image_pred in enumerate(prediction):
            # Best class score and index for every box.
            class_conf = np.max(image_pred[:, 5:5 + num_classes], 1, keepdims=True)
            class_pred = np.expand_dims(np.argmax(image_pred[:, 5:5 + num_classes], 1), -1)

            # Keep boxes whose objectness * class score clears the threshold.
            conf_mask = np.squeeze((image_pred[:, 4] * class_conf[:, 0] >= conf_thres))

            image_pred = image_pred[conf_mask]
            class_conf = class_conf[conf_mask]
            class_pred = class_pred[conf_mask]
            if not np.shape(image_pred)[0]:
                continue
            # (x1, y1, x2, y2, obj_conf, class_conf, class_pred)
            detections = np.concatenate((image_pred[:, :5], class_conf, class_pred), 1)

            unique_labels = np.unique(detections[:, -1])

            for c in unique_labels:
                detections_class = detections[detections[:, -1] == c]

                # Greedy NMS: sort by score, keep the best, drop overlaps.
                conf_sort_index = np.argsort(detections_class[:, 4] * detections_class[:, 5])[::-1]
                detections_class = detections_class[conf_sort_index]
                max_detections = []
                while np.shape(detections_class)[0]:
                    max_detections.append(detections_class[0:1])
                    if len(detections_class) == 1:
                        break
                    ious = self.bbox_iou(max_detections[-1], detections_class[1:])
                    detections_class = detections_class[1:][ious < nms_thres]
                max_detections = np.concatenate(max_detections, 0)

                output[i] = max_detections if output[i] is None else np.concatenate((output[i], max_detections))

            if output[i] is not None:
                # No-op kept for symmetry with the torch version's
                # .cpu().numpy() conversion.
                output[i] = output[i]
                box_xy, box_wh = (output[i][:, 0:2] + output[i][:, 2:4])/2, output[i][:, 2:4] - output[i][:, 0:2]
                output[i][:, :4] = self.yolo_correct_boxes(box_xy, box_wh, input_shape, image_shape, letterbox_image)
        return output
|
utils/utils_fit.py
ADDED
|
@@ -0,0 +1,104 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os
|
| 2 |
+
|
| 3 |
+
import torch
|
| 4 |
+
from tqdm import tqdm
|
| 5 |
+
|
| 6 |
+
from utils.utils import get_lr
|
| 7 |
+
|
| 8 |
+
|
| 9 |
+
def fit_one_epoch(model_train, model, yolo_loss, loss_history, eval_callback, optimizer, epoch, epoch_step, epoch_step_val, gen, gen_val, Epoch, cuda, fp16, scaler, save_period, save_dir, local_rank=0):
    """Run one training epoch followed by one validation epoch.

    model_train is the (possibly DDP-wrapped) module used for forward/backward;
    model is the underlying module whose state_dict is saved.  Progress bars,
    logging, metric bookkeeping and checkpointing happen only on local_rank 0.
    yolo_loss is called per output level as yolo_loss(l, outputs[l], targets)
    — its exact contract lives elsewhere in the project.
    """
    loss = 0
    val_loss = 0

    if local_rank == 0:
        print('Start Train')
        pbar = tqdm(total=epoch_step,desc=f'Epoch {epoch + 1}/{Epoch}',postfix=dict,mininterval=0.3)
    model_train.train()
    for iteration, batch in enumerate(gen):
        # The loader may be longer than one nominal epoch; stop at epoch_step.
        if iteration >= epoch_step:
            break

        images, targets = batch[0], batch[1]
        with torch.no_grad():
            # Host->device transfer only; no graph needed here.
            if cuda:
                images = images.cuda(local_rank)
                targets = [ann.cuda(local_rank) for ann in targets]
        optimizer.zero_grad()
        if not fp16:
            outputs = model_train(images)

            # Sum the loss over all YOLO output levels.
            loss_value_all = 0
            for l in range(len(outputs)):
                loss_item = yolo_loss(l, outputs[l], targets)
                loss_value_all += loss_item
            loss_value = loss_value_all

            loss_value.backward()
            optimizer.step()
        else:
            # Mixed-precision path: autocast forward, scaled backward.
            from torch.cuda.amp import autocast
            with autocast():
                outputs = model_train(images)

                loss_value_all = 0
                for l in range(len(outputs)):
                    loss_item = yolo_loss(l, outputs[l], targets)
                    loss_value_all += loss_item
                loss_value = loss_value_all

            scaler.scale(loss_value).backward()
            scaler.step(optimizer)
            scaler.update()

        loss += loss_value.item()

        if local_rank == 0:
            pbar.set_postfix(**{'loss' : loss / (iteration + 1),
                                'lr' : get_lr(optimizer)})
            pbar.update(1)

    if local_rank == 0:
        pbar.close()
        print('Finish Train')
        print('Start Validation')
        pbar = tqdm(total=epoch_step_val, desc=f'Epoch {epoch + 1}/{Epoch}',postfix=dict,mininterval=0.3)

    model_train.eval()
    for iteration, batch in enumerate(gen_val):
        if iteration >= epoch_step_val:
            break
        images, targets = batch[0], batch[1]
        with torch.no_grad():
            if cuda:
                images = images.cuda(local_rank)
                targets = [ann.cuda(local_rank) for ann in targets]
            # zero_grad is redundant under no_grad, but kept as in the
            # original loop.
            optimizer.zero_grad()
            outputs = model_train(images)

            loss_value_all = 0
            for l in range(len(outputs)):
                loss_item = yolo_loss(l, outputs[l], targets)
                loss_value_all += loss_item
            loss_value = loss_value_all

        val_loss += loss_value.item()
        if local_rank == 0:
            pbar.set_postfix(**{'val_loss': val_loss / (iteration + 1)})
            pbar.update(1)

    if local_rank == 0:
        pbar.close()
        print('Finish Validation')
        # Record the epoch's losses, then run the mAP/eval callback.
        loss_history.append_loss(epoch + 1, loss / epoch_step, val_loss / epoch_step_val)
        eval_callback.on_epoch_end(epoch + 1, model_train)
        print('Epoch:'+ str(epoch + 1) + '/' + str(Epoch))
        print('Total Loss: %.3f || Val Loss: %.3f ' % (loss / epoch_step, val_loss / epoch_step_val))

        # Periodic checkpoint.
        if (epoch + 1) % save_period == 0 or epoch + 1 == Epoch:
            torch.save(model.state_dict(), os.path.join(save_dir, "ep%03d-loss%.3f-val_loss%.3f.pth" % (epoch + 1, loss / epoch_step, val_loss / epoch_step_val)))

        # Best-so-far checkpoint.  NOTE(review): append_loss above already
        # added this epoch's val_loss, so min() includes it — confirm this
        # "<= min" behavior is intended.
        if len(loss_history.val_loss) <= 1 or (val_loss / epoch_step_val) <= min(loss_history.val_loss):
            print('Save best model to best_epoch_weights.pth')
            torch.save(model.state_dict(), os.path.join(save_dir, "best_epoch_weights.pth"))

        # Always keep the latest weights for resuming.
        torch.save(model.state_dict(), os.path.join(save_dir, "last_epoch_weights.pth"))
|
utils/utils_map.py
ADDED
|
@@ -0,0 +1,757 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import glob
|
| 2 |
+
import json
|
| 3 |
+
import math
|
| 4 |
+
import operator
|
| 5 |
+
import os
|
| 6 |
+
import shutil
|
| 7 |
+
import sys
|
| 8 |
+
try:
|
| 9 |
+
from pycocotools.coco import COCO
|
| 10 |
+
from pycocotools.cocoeval import COCOeval
|
| 11 |
+
except:
|
| 12 |
+
pass
|
| 13 |
+
import cv2
|
| 14 |
+
import matplotlib
|
| 15 |
+
matplotlib.use('Agg')
|
| 16 |
+
from matplotlib import pyplot as plt
|
| 17 |
+
import numpy as np
|
| 18 |
+
|
| 19 |
+
def log_average_miss_rate(precision, fp_cumsum, num_images):
    """Compute the log-average miss rate for one class.

    The miss rate is sampled at nine false-positives-per-image (FPPI)
    reference points evenly spaced in log-space over [1e-2, 1e0], and the
    geometric mean of those samples is returned.

    Returns (lamr, miss_rate_array, fppi_array); (0, 1, 0) when *precision*
    is empty (no detections for the class).
    """
    if precision.size == 0:
        # Nothing detected: by convention miss rate is 1 everywhere.
        return 0, 1, 0

    fppi = fp_cumsum / float(num_images)
    mr = 1 - precision

    # Prepend sentinels so every reference point has at least one
    # fppi value <= it (fppi starts at -1 with miss rate 1).
    fppi_padded = np.insert(fppi, 0, -1.0)
    mr_padded = np.insert(mr, 0, 1.0)

    ref_points = np.logspace(-2.0, 0.0, num=9)
    samples = []
    for ref_point in ref_points:
        # Largest index whose fppi does not exceed the reference point.
        idx = np.where(fppi_padded <= ref_point)[-1][-1]
        samples.append(mr_padded[idx])

    # Geometric mean, with a floor of 1e-10 to keep log() finite.
    lamr = math.exp(np.mean(np.log(np.maximum(1e-10, np.array(samples)))))

    return lamr, mr, fppi
|
| 41 |
+
|
| 42 |
+
def error(msg):
    """Print *msg* and abort the program.

    Exits with status 1 so shell callers and CI can detect the failure.
    (The original exited with status 0, which reports success to the shell
    and masks errors in scripted pipelines.)
    """
    print(msg)
    sys.exit(1)
|
| 45 |
+
|
| 46 |
+
def is_float_between_0_and_1(value):
    """Return True when *value* parses as a float strictly inside (0, 1).

    The bounds are exclusive: "0" and "1" themselves return False, as does
    anything that is not a valid float literal.
    """
    try:
        return 0.0 < float(value) < 1.0
    except ValueError:
        return False
|
| 55 |
+
|
| 56 |
+
def voc_ap(rec, prec):
    """Compute VOC-style Average Precision from recall/precision lists.

    *rec* and *prec* are padded IN PLACE with boundary sentinels
    (recall 0->1, precision 0 at both ends), matching the behavior of the
    reference mAP implementation.  Returns (ap, mrec, mpre), where mrec/mpre
    are the padded and monotonized curves used for the integration.
    """
    # Pad the curves with boundary values (mutates the inputs on purpose).
    rec.insert(0, 0.0)
    rec.append(1.0)
    prec.insert(0, 0.0)
    prec.append(0.0)
    mrec = rec[:]
    mpre = prec[:]

    # Make the precision envelope monotonically non-increasing
    # (sweep right-to-left, propagating the running maximum).
    for idx in range(len(mpre) - 2, -1, -1):
        if mpre[idx] < mpre[idx + 1]:
            mpre[idx] = mpre[idx + 1]

    # Each recall change contributes one rectangle to the AP integral.
    change_points = [k for k in range(1, len(mrec)) if mrec[k] != mrec[k - 1]]
    ap = sum(((mrec[k] - mrec[k - 1]) * mpre[k] for k in change_points), 0.0)
    return ap, mrec, mpre
|
| 77 |
+
|
| 78 |
+
def file_lines_to_list(path):
    """Read the text file at *path* and return its lines as a list, each
    stripped of surrounding whitespace (including the trailing newline)."""
    with open(path) as handle:
        return [line.strip() for line in handle]
|
| 83 |
+
|
| 84 |
+
|
| 85 |
+
def draw_text_in_image(img, text, pos, color, line_width):
    """Draw *text* onto *img* (OpenCV BGR array) at *pos*.

    Returns (img, new_line_width) where new_line_width is line_width plus the
    rendered text's pixel width, so successive calls can lay text out
    left-to-right on the same line.  img is modified in place by cv2.putText.
    """
    font = cv2.FONT_HERSHEY_PLAIN
    fontScale = 1
    lineType = 1
    bottomLeftCornerOfText = pos
    cv2.putText(img, text,
            bottomLeftCornerOfText,
            font,
            fontScale,
            color,
            lineType)
    # getTextSize returns ((width, height), baseline); we keep only width.
    text_width, _ = cv2.getTextSize(text, font, fontScale, lineType)[0]
    return img, (line_width + text_width)
|
| 98 |
+
|
| 99 |
+
def adjust_axes(r, t, fig, axes):
    """Widen the x-axis so the text artist *t* fits inside the figure.

    r is the figure's renderer, used to measure the rendered text extent;
    the x-limit is stretched proportionally to the extra width needed.
    """
    # Measured width of the text in inches at the figure's dpi.
    bb = t.get_window_extent(renderer=r)
    text_width_inches = bb.width / fig.dpi
    current_fig_width = fig.get_figwidth()
    new_fig_width = current_fig_width + text_width_inches
    propotion = new_fig_width / current_fig_width
    x_lim = axes.get_xlim()
    axes.set_xlim([x_lim[0], x_lim[1]*propotion])
|
| 107 |
+
|
| 108 |
+
def draw_plot_func(dictionary, n_classes, window_title, plot_title, x_label, output_path, to_show, plot_color, true_p_bar):
    """Render a horizontal bar chart of per-class values and save it.

    dictionary maps class name -> value; bars are sorted ascending by value.
    When true_p_bar is a dict (not ""), each bar is split into true-positive
    (green) and false-positive (red) segments; otherwise a single bar in
    plot_color is drawn with its value printed beside it.  The figure is
    written to output_path and optionally shown.
    """
    sorted_dic_by_value = sorted(dictionary.items(), key=operator.itemgetter(1))
    sorted_keys, sorted_values = zip(*sorted_dic_by_value)
    if true_p_bar != "":
        # Stacked TP/FP bars: FP segment first, TP stacked to its right.
        fp_sorted = []
        tp_sorted = []
        for key in sorted_keys:
            fp_sorted.append(dictionary[key] - true_p_bar[key])
            tp_sorted.append(true_p_bar[key])
        plt.barh(range(n_classes), fp_sorted, align='center', color='crimson', label='False Positive')
        plt.barh(range(n_classes), tp_sorted, align='center', color='forestgreen', label='True Positive', left=fp_sorted)
        plt.legend(loc='lower right')

        fig = plt.gcf()
        axes = plt.gca()
        r = fig.canvas.get_renderer()
        for i, val in enumerate(sorted_values):
            fp_val = fp_sorted[i]
            tp_val = tp_sorted[i]
            fp_str_val = " " + str(fp_val)
            tp_str_val = fp_str_val + " " + str(tp_val)
            # Green total label drawn over the red FP label.
            t = plt.text(val, i, tp_str_val, color='forestgreen', va='center', fontweight='bold')
            plt.text(val, i, fp_str_val, color='crimson', va='center', fontweight='bold')
            # Only the topmost (longest-reaching) label can overflow; widen once.
            if i == (len(sorted_values)-1):
                adjust_axes(r, t, fig, axes)
    else:
        plt.barh(range(n_classes), sorted_values, color=plot_color)
        """
        Write number on side of bar
        """
        fig = plt.gcf()
        axes = plt.gca()
        r = fig.canvas.get_renderer()
        for i, val in enumerate(sorted_values):
            str_val = " " + str(val)
            if val < 1.0:
                str_val = " {0:.2f}".format(val)
            t = plt.text(val, i, str_val, color=plot_color, va='center', fontweight='bold')
            if i == (len(sorted_values)-1):
                adjust_axes(r, t, fig, axes)
    # NOTE(review): FigureCanvas.set_window_title was deprecated and later
    # removed in newer Matplotlib (use fig.canvas.manager.set_window_title) —
    # confirm the pinned Matplotlib version still supports this call.
    fig.canvas.set_window_title(window_title)
    tick_font_size = 12
    plt.yticks(range(n_classes), sorted_keys, fontsize=tick_font_size)
    """
    Re-scale height accordingly
    """
    # Grow the figure so every class label row has room.
    init_height = fig.get_figheight()
    dpi = fig.dpi
    height_pt = n_classes * (tick_font_size * 1.4)
    height_in = height_pt / dpi
    top_margin = 0.15
    bottom_margin = 0.05
    figure_height = height_in / (1 - top_margin - bottom_margin)
    if figure_height > init_height:
        fig.set_figheight(figure_height)

    plt.title(plot_title, fontsize=14)
    plt.xlabel(x_label, fontsize='large')
    fig.tight_layout()
    fig.savefig(output_path)
    if to_show:
        plt.show()
    plt.close()
|
| 171 |
+
|
| 172 |
+
def get_map(MINOVERLAP, draw_plot, score_threhold=0.5, path = './map_out'):
    """
    Compute VOC-style mAP at IoU threshold ``MINOVERLAP`` from the text files
    under ``path`` ('ground-truth' and 'detection-results' sub-directories,
    plus an optional 'images-optional' folder used for per-detection
    animation).

    Parameters
    ----------
    MINOVERLAP : float
        Minimum IoU for a detection to count as a true positive.
    draw_plot : bool
        If True, save AP / F1 / Recall / Precision curve images under
        ``path``/results/.
    score_threhold : float
        Confidence threshold at which F1 / Recall / Precision are reported.
        (Name kept misspelled for backward compatibility with callers.)
    path : str
        Root directory of the evaluation files.

    Returns
    -------
    float
        The mAP over all ground-truth classes, or 0 if no class was found.
    """
    GT_PATH = os.path.join(path, 'ground-truth')
    DR_PATH = os.path.join(path, 'detection-results')
    IMG_PATH = os.path.join(path, 'images-optional')
    TEMP_FILES_PATH = os.path.join(path, '.temp_files')
    RESULTS_FILES_PATH = os.path.join(path, 'results')

    # Per-detection animation is only possible when the optional image
    # folder exists and contains files.
    show_animation = True
    if os.path.exists(IMG_PATH):
        for dirpath, dirnames, files in os.walk(IMG_PATH):
            if not files:
                show_animation = False
    else:
        show_animation = False

    if not os.path.exists(TEMP_FILES_PATH):
        os.makedirs(TEMP_FILES_PATH)

    # Start from a clean results directory on every run.
    if os.path.exists(RESULTS_FILES_PATH):
        shutil.rmtree(RESULTS_FILES_PATH)
    else:
        os.makedirs(RESULTS_FILES_PATH)
    if draw_plot:
        try:
            matplotlib.use('TkAgg')
        except:
            pass
        os.makedirs(os.path.join(RESULTS_FILES_PATH, "AP"))
        os.makedirs(os.path.join(RESULTS_FILES_PATH, "F1"))
        os.makedirs(os.path.join(RESULTS_FILES_PATH, "Recall"))
        os.makedirs(os.path.join(RESULTS_FILES_PATH, "Precision"))
    if show_animation:
        os.makedirs(os.path.join(RESULTS_FILES_PATH, "images", "detections_one_by_one"))

    ground_truth_files_list = glob.glob(GT_PATH + '/*.txt')
    if len(ground_truth_files_list) == 0:
        error("Error: No ground-truth files found!")
    ground_truth_files_list.sort()
    gt_counter_per_class = {}
    counter_images_per_class = {}

    # ---- Parse ground-truth files into one JSON cache per image ----------
    for txt_file in ground_truth_files_list:
        file_id = txt_file.split(".txt", 1)[0]
        file_id = os.path.basename(os.path.normpath(file_id))
        temp_path = os.path.join(DR_PATH, (file_id + ".txt"))
        if not os.path.exists(temp_path):
            error_msg = "Error. File not found: {}\n".format(temp_path)
            error(error_msg)
        lines_list = file_lines_to_list(txt_file)
        bounding_boxes = []
        is_difficult = False
        already_seen_classes = []
        for line in lines_list:
            try:
                if "difficult" in line:
                    class_name, left, top, right, bottom, _difficult = line.split()
                    is_difficult = True
                else:
                    class_name, left, top, right, bottom = line.split()
            except:
                # Class names may contain spaces: rebuild the name from all
                # tokens except the trailing coordinates (and flag).
                if "difficult" in line:
                    line_split = line.split()
                    _difficult = line_split[-1]
                    bottom = line_split[-2]
                    right = line_split[-3]
                    top = line_split[-4]
                    left = line_split[-5]
                    class_name = ""
                    for name in line_split[:-5]:
                        class_name += name + " "
                    class_name = class_name[:-1]
                    is_difficult = True
                else:
                    line_split = line.split()
                    bottom = line_split[-1]
                    right = line_split[-2]
                    top = line_split[-3]
                    left = line_split[-4]
                    class_name = ""
                    for name in line_split[:-4]:
                        class_name += name + " "
                    class_name = class_name[:-1]

            bbox = left + " " + top + " " + right + " " + bottom
            if is_difficult:
                # "difficult" objects are neither counted nor penalized.
                bounding_boxes.append({"class_name":class_name, "bbox":bbox, "used":False, "difficult":True})
                is_difficult = False
            else:
                bounding_boxes.append({"class_name":class_name, "bbox":bbox, "used":False})
                if class_name in gt_counter_per_class:
                    gt_counter_per_class[class_name] += 1
                else:
                    gt_counter_per_class[class_name] = 1

                if class_name not in already_seen_classes:
                    if class_name in counter_images_per_class:
                        counter_images_per_class[class_name] += 1
                    else:
                        counter_images_per_class[class_name] = 1
                    already_seen_classes.append(class_name)

        with open(TEMP_FILES_PATH + "/" + file_id + "_ground_truth.json", 'w') as outfile:
            json.dump(bounding_boxes, outfile)

    gt_classes = list(gt_counter_per_class.keys())
    gt_classes = sorted(gt_classes)
    n_classes = len(gt_classes)

    # ---- Parse detection results into one JSON cache per class -----------
    dr_files_list = glob.glob(DR_PATH + '/*.txt')
    dr_files_list.sort()
    for class_index, class_name in enumerate(gt_classes):
        bounding_boxes = []
        for txt_file in dr_files_list:
            file_id = txt_file.split(".txt",1)[0]
            file_id = os.path.basename(os.path.normpath(file_id))
            temp_path = os.path.join(GT_PATH, (file_id + ".txt"))
            if class_index == 0:
                # Only check GT/DR pairing once (first class pass).
                if not os.path.exists(temp_path):
                    error_msg = "Error. File not found: {}\n".format(temp_path)
                    error(error_msg)
            lines = file_lines_to_list(txt_file)
            for line in lines:
                try:
                    tmp_class_name, confidence, left, top, right, bottom = line.split()
                except:
                    line_split = line.split()
                    bottom = line_split[-1]
                    right = line_split[-2]
                    top = line_split[-3]
                    left = line_split[-4]
                    confidence = line_split[-5]
                    tmp_class_name = ""
                    for name in line_split[:-5]:
                        tmp_class_name += name + " "
                    tmp_class_name = tmp_class_name[:-1]

                if tmp_class_name == class_name:
                    bbox = left + " " + top + " " + right + " " +bottom
                    bounding_boxes.append({"confidence":confidence, "file_id":file_id, "bbox":bbox})

        # Detections must be visited in order of descending confidence for
        # the TP/FP accumulation below to be a valid PR curve.
        bounding_boxes.sort(key=lambda x:float(x['confidence']), reverse=True)
        with open(TEMP_FILES_PATH + "/" + class_name + "_dr.json", 'w') as outfile:
            json.dump(bounding_boxes, outfile)

    sum_AP = 0.0
    ap_dictionary = {}
    lamr_dictionary = {}
    with open(RESULTS_FILES_PATH + "/results.txt", 'w') as results_file:
        count_true_positives = {}

        for class_index, class_name in enumerate(gt_classes):
            count_true_positives[class_name] = 0
            dr_file = TEMP_FILES_PATH + "/" + class_name + "_dr.json"
            dr_data = json.load(open(dr_file))

            nd = len(dr_data)
            tp = [0] * nd
            fp = [0] * nd
            score = [0] * nd
            score_threhold_idx = 0
            for idx, detection in enumerate(dr_data):
                file_id = detection["file_id"]
                score[idx] = float(detection["confidence"])
                # Detections are confidence-sorted, so this ends up pointing
                # at the last detection whose score is >= the threshold.
                if score[idx] >= score_threhold:
                    score_threhold_idx = idx

                if show_animation:
                    ground_truth_img = glob.glob1(IMG_PATH, file_id + ".*")
                    if len(ground_truth_img) == 0:
                        error("Error. Image not found with id: " + file_id)
                    elif len(ground_truth_img) > 1:
                        error("Error. Multiple image with id: " + file_id)
                    else:
                        img = cv2.imread(IMG_PATH + "/" + ground_truth_img[0])
                        img_cumulative_path = RESULTS_FILES_PATH + "/images/" + ground_truth_img[0]
                        if os.path.isfile(img_cumulative_path):
                            img_cumulative = cv2.imread(img_cumulative_path)
                        else:
                            img_cumulative = img.copy()
                        bottom_border = 60
                        BLACK = [0, 0, 0]
                        img = cv2.copyMakeBorder(img, 0, bottom_border, 0, 0, cv2.BORDER_CONSTANT, value=BLACK)

                gt_file = TEMP_FILES_PATH + "/" + file_id + "_ground_truth.json"
                ground_truth_data = json.load(open(gt_file))
                ovmax = -1
                gt_match = -1
                bb = [float(x) for x in detection["bbox"].split()]
                for obj in ground_truth_data:
                    if obj["class_name"] == class_name:
                        bbgt = [ float(x) for x in obj["bbox"].split() ]
                        bi = [max(bb[0],bbgt[0]), max(bb[1],bbgt[1]), min(bb[2],bbgt[2]), min(bb[3],bbgt[3])]
                        iw = bi[2] - bi[0] + 1
                        ih = bi[3] - bi[1] + 1
                        if iw > 0 and ih > 0:
                            # IoU with the +1 pixel convention of the VOC code.
                            ua = (bb[2] - bb[0] + 1) * (bb[3] - bb[1] + 1) + (bbgt[2] - bbgt[0]
                                            + 1) * (bbgt[3] - bbgt[1] + 1) - iw * ih
                            ov = iw * ih / ua
                            if ov > ovmax:
                                ovmax = ov
                                gt_match = obj

                if show_animation:
                    status = "NO MATCH FOUND!"

                min_overlap = MINOVERLAP
                if ovmax >= min_overlap:
                    if "difficult" not in gt_match:
                        if not bool(gt_match["used"]):
                            # First detection matching this GT box: TP.
                            tp[idx] = 1
                            gt_match["used"] = True
                            count_true_positives[class_name] += 1
                            with open(gt_file, 'w') as f:
                                f.write(json.dumps(ground_truth_data))
                            if show_animation:
                                status = "MATCH!"
                        else:
                            # GT box already claimed: duplicate detection is FP.
                            fp[idx] = 1
                            if show_animation:
                                status = "REPEATED MATCH!"
                else:
                    fp[idx] = 1
                    if ovmax > 0:
                        status = "INSUFFICIENT OVERLAP"

                if show_animation:
                    height, widht = img.shape[:2]
                    white = (255,255,255)
                    light_blue = (255,200,100)
                    green = (0,255,0)
                    light_red = (30,30,255)
                    margin = 10
                    v_pos = int(height - margin - (bottom_border / 2.0))
                    text = "Image: " + ground_truth_img[0] + " "
                    img, line_width = draw_text_in_image(img, text, (margin, v_pos), white, 0)
                    text = "Class [" + str(class_index) + "/" + str(n_classes) + "]: " + class_name + " "
                    img, line_width = draw_text_in_image(img, text, (margin + line_width, v_pos), light_blue, line_width)
                    if ovmax != -1:
                        color = light_red
                        if status == "INSUFFICIENT OVERLAP":
                            text = "IoU: {0:.2f}% ".format(ovmax*100) + "< {0:.2f}% ".format(min_overlap*100)
                        else:
                            text = "IoU: {0:.2f}% ".format(ovmax*100) + ">= {0:.2f}% ".format(min_overlap*100)
                            color = green
                        img, _ = draw_text_in_image(img, text, (margin + line_width, v_pos), color, line_width)
                    v_pos += int(bottom_border / 2.0)
                    rank_pos = str(idx+1)
                    # FIX: restore the caption assignment (lost line); without
                    # it the previous IoU string was re-drawn here.
                    text = "Detection #rank: " + rank_pos + " confidence: {0:.2f}% ".format(float(detection["confidence"])*100)
                    img, line_width = draw_text_in_image(img, text, (margin, v_pos), white, 0)
                    color = light_red
                    if status == "MATCH!":
                        color = green
                    text = "Result: " + status + " "
                    img, line_width = draw_text_in_image(img, text, (margin + line_width, v_pos), color, line_width)

                    font = cv2.FONT_HERSHEY_SIMPLEX
                    if ovmax > 0:
                        bbgt = [ int(round(float(x))) for x in gt_match["bbox"].split() ]
                        cv2.rectangle(img,(bbgt[0],bbgt[1]),(bbgt[2],bbgt[3]),light_blue,2)
                        cv2.rectangle(img_cumulative,(bbgt[0],bbgt[1]),(bbgt[2],bbgt[3]),light_blue,2)
                        cv2.putText(img_cumulative, class_name, (bbgt[0],bbgt[1] - 5), font, 0.6, light_blue, 1, cv2.LINE_AA)
                    bb = [int(i) for i in bb]
                    cv2.rectangle(img,(bb[0],bb[1]),(bb[2],bb[3]),color,2)
                    cv2.rectangle(img_cumulative,(bb[0],bb[1]),(bb[2],bb[3]),color,2)
                    cv2.putText(img_cumulative, class_name, (bb[0],bb[1] - 5), font, 0.6, color, 1, cv2.LINE_AA)

                    cv2.imshow("Animation", img)
                    cv2.waitKey(20)
                    output_img_path = RESULTS_FILES_PATH + "/images/detections_one_by_one/" + class_name + "_detection" + str(idx) + ".jpg"
                    cv2.imwrite(output_img_path, img)
                    cv2.imwrite(img_cumulative_path, img_cumulative)

            # Turn per-detection FP/TP flags into cumulative counts.
            cumsum = 0
            for idx, val in enumerate(fp):
                fp[idx] += cumsum
                cumsum += val

            cumsum = 0
            for idx, val in enumerate(tp):
                tp[idx] += cumsum
                cumsum += val

            rec = tp[:]
            for idx, val in enumerate(tp):
                rec[idx] = float(tp[idx]) / np.maximum(gt_counter_per_class[class_name], 1)

            prec = tp[:]
            for idx, val in enumerate(tp):
                prec[idx] = float(tp[idx]) / np.maximum((fp[idx] + tp[idx]), 1)

            ap, mrec, mprec = voc_ap(rec[:], prec[:])
            F1 = np.array(rec)*np.array(prec)*2 / np.where((np.array(prec)+np.array(rec))==0, 1, (np.array(prec)+np.array(rec)))

            sum_AP += ap
            # FIX: restore the per-class AP summary string (lost line).
            # `text` is consumed below by results_file.write / print / the
            # plot title; without this assignment the code raises NameError
            # when show_animation is False.
            text = "{0:.2f}%".format(ap*100) + " = " + class_name + " AP "

            if len(prec)>0:
                F1_text = "{0:.2f}".format(F1[score_threhold_idx]) + " = " + class_name + " F1 "
                Recall_text = "{0:.2f}%".format(rec[score_threhold_idx]*100) + " = " + class_name + " Recall "
                Precision_text = "{0:.2f}%".format(prec[score_threhold_idx]*100) + " = " + class_name + " Precision "
            else:
                F1_text = "0.00" + " = " + class_name + " F1 "
                Recall_text = "0.00%" + " = " + class_name + " Recall "
                Precision_text = "0.00%" + " = " + class_name + " Precision "

            rounded_prec = [ '%.2f' % elem for elem in prec ]
            rounded_rec = [ '%.2f' % elem for elem in rec ]
            results_file.write(text + "\n Precision: " + str(rounded_prec) + "\n Recall :" + str(rounded_rec) + "\n\n")

            if len(prec)>0:
                print(text + "\t||\tscore_threhold=" + str(score_threhold) + " : " + "F1=" + "{0:.2f}".format(F1[score_threhold_idx])\
                    + " ; Recall=" + "{0:.2f}%".format(rec[score_threhold_idx]*100) + " ; Precision=" + "{0:.2f}%".format(prec[score_threhold_idx]*100))
            else:
                print(text + "\t||\tscore_threhold=" + str(score_threhold) + " : " + "F1=0.00% ; Recall=0.00% ; Precision=0.00%")
            ap_dictionary[class_name] = ap

            n_images = counter_images_per_class[class_name]
            lamr, mr, fppi = log_average_miss_rate(np.array(rec), np.array(fp), n_images)
            lamr_dictionary[class_name] = lamr

            if draw_plot:
                plt.plot(rec, prec, '-o')
                # Close the PR curve so fill_between shades the AP area.
                area_under_curve_x = mrec[:-1] + [mrec[-2]] + [mrec[-1]]
                area_under_curve_y = mprec[:-1] + [0.0] + [mprec[-1]]
                plt.fill_between(area_under_curve_x, 0, area_under_curve_y, alpha=0.2, edgecolor='r')

                fig = plt.gcf()
                # FIX: FigureCanvas.set_window_title was removed in
                # Matplotlib 3.4; the manager carries it now. Fall back for
                # old Matplotlib versions.
                try:
                    fig.canvas.manager.set_window_title('AP ' + class_name)
                except AttributeError:
                    fig.canvas.set_window_title('AP ' + class_name)

                plt.title('class: ' + text)
                plt.xlabel('Recall')
                plt.ylabel('Precision')
                axes = plt.gca()
                axes.set_xlim([0.0,1.0])
                axes.set_ylim([0.0,1.05])
                fig.savefig(RESULTS_FILES_PATH + "/AP/" + class_name + ".png")
                plt.cla()

                plt.plot(score, F1, "-", color='orangered')
                plt.title('class: ' + F1_text + "\nscore_threhold=" + str(score_threhold))
                plt.xlabel('Score_Threhold')
                plt.ylabel('F1')
                axes = plt.gca()
                axes.set_xlim([0.0,1.0])
                axes.set_ylim([0.0,1.05])
                fig.savefig(RESULTS_FILES_PATH + "/F1/" + class_name + ".png")
                plt.cla()

                plt.plot(score, rec, "-H", color='gold')
                plt.title('class: ' + Recall_text + "\nscore_threhold=" + str(score_threhold))
                plt.xlabel('Score_Threhold')
                plt.ylabel('Recall')
                axes = plt.gca()
                axes.set_xlim([0.0,1.0])
                axes.set_ylim([0.0,1.05])
                fig.savefig(RESULTS_FILES_PATH + "/Recall/" + class_name + ".png")
                plt.cla()

                plt.plot(score, prec, "-s", color='palevioletred')
                plt.title('class: ' + Precision_text + "\nscore_threhold=" + str(score_threhold))
                plt.xlabel('Score_Threhold')
                plt.ylabel('Precision')
                axes = plt.gca()
                axes.set_xlim([0.0,1.0])
                axes.set_ylim([0.0,1.05])
                fig.savefig(RESULTS_FILES_PATH + "/Precision/" + class_name + ".png")
                plt.cla()

        if show_animation:
            cv2.destroyAllWindows()
        if n_classes == 0:
            print("未检测到任何种类,请检查标签信息与get_map.py中的classes_path是否修改。")
            return 0
        mAP = sum_AP / n_classes
        text = "mAP = {0:.2f}%".format(mAP*100)
        results_file.write(text + "\n")
        print(text)

    shutil.rmtree(TEMP_FILES_PATH)

    # ---- Summaries: detection counts per class ---------------------------
    det_counter_per_class = {}
    for txt_file in dr_files_list:
        lines_list = file_lines_to_list(txt_file)
        for line in lines_list:
            class_name = line.split()[0]
            if class_name in det_counter_per_class:
                det_counter_per_class[class_name] += 1
            else:
                det_counter_per_class[class_name] = 1
    dr_classes = list(det_counter_per_class.keys())

    with open(RESULTS_FILES_PATH + "/results.txt", 'a') as results_file:
        for class_name in sorted(gt_counter_per_class):
            results_file.write(class_name + ": " + str(gt_counter_per_class[class_name]) + "\n")

    # Classes detected but never present in the ground truth have zero TPs.
    for class_name in dr_classes:
        if class_name not in gt_classes:
            count_true_positives[class_name] = 0

    with open(RESULTS_FILES_PATH + "/results.txt", 'a') as results_file:
        for class_name in sorted(dr_classes):
            n_det = det_counter_per_class[class_name]
            text = class_name + ": " + str(n_det)
            text += " (tp:" + str(count_true_positives[class_name]) + ""
            text += ", fp:" + str(n_det - count_true_positives[class_name]) + ")\n"
            results_file.write(text)

    if draw_plot:
        window_title = "ground-truth-info"
        plot_title = "ground-truth\n"
        plot_title += "(" + str(len(ground_truth_files_list)) + " files and " + str(n_classes) + " classes)"
        x_label = "Number of objects per class"
        output_path = RESULTS_FILES_PATH + "/ground-truth-info.png"
        to_show = False
        plot_color = 'forestgreen'
        draw_plot_func(
            gt_counter_per_class,
            n_classes,
            window_title,
            plot_title,
            x_label,
            output_path,
            to_show,
            plot_color,
            '',
            )

    if draw_plot:
        window_title = "lamr"
        plot_title = "log-average miss rate"
        x_label = "log-average miss rate"
        output_path = RESULTS_FILES_PATH + "/lamr.png"
        to_show = False
        plot_color = 'royalblue'
        draw_plot_func(
            lamr_dictionary,
            n_classes,
            window_title,
            plot_title,
            x_label,
            output_path,
            to_show,
            plot_color,
            ""
            )

    if draw_plot:
        window_title = "mAP"
        plot_title = "mAP = {0:.2f}%".format(mAP*100)
        x_label = "Average Precision"
        output_path = RESULTS_FILES_PATH + "/mAP.png"
        to_show = True
        plot_color = 'royalblue'
        draw_plot_func(
            ap_dictionary,
            n_classes,
            window_title,
            plot_title,
            x_label,
            output_path,
            to_show,
            plot_color,
            ""
            )
    return mAP
|
| 637 |
+
|
| 638 |
+
def preprocess_gt(gt_path, class_names):
    """
    Build a COCO-style ground-truth dictionary (images / categories /
    annotations) from the per-image .txt files in `gt_path`.

    Image sizes are unknown here, so width/height are set to the dummy
    value 1; the area field is the raw box area minus 10.0 (kept as-is
    from the original implementation).
    """
    results = {}

    image_records = []
    box_records = []
    for gt_file in os.listdir(gt_path):
        annotation_lines = file_lines_to_list(os.path.join(gt_path, gt_file))
        stem = os.path.splitext(gt_file)[0]

        image_record = {}
        image_record['file_name'] = stem + '.jpg'
        image_record['width'] = 1
        image_record['height'] = 1
        image_record['id'] = str(stem)

        boxes_for_image = []
        for annotation_line in annotation_lines:
            tokens = annotation_line.split()
            if "difficult" in annotation_line:
                # Trailing tokens: left top right bottom difficult-flag;
                # everything before them is the (possibly spaced) class name.
                left, top, right, bottom, _flag = tokens[-5:]
                cls = " ".join(tokens[:-5])
                difficult = 1
            else:
                left, top, right, bottom = tokens[-4:]
                cls = " ".join(tokens[:-4])
                difficult = 0

            left, top, right, bottom = float(left), float(top), float(right), float(bottom)
            if cls not in class_names:
                continue
            cls_id = class_names.index(cls) + 1
            boxes_for_image.append([left, top, right - left, bottom - top, difficult,
                                    str(stem), cls_id,
                                    (right - left) * (bottom - top) - 10.0])
        image_records.append(image_record)
        box_records.extend(boxes_for_image)
    results['images'] = image_records

    # COCO category ids are 1-based.
    category_records = []
    for category_index, cls in enumerate(class_names):
        category_records.append({
            'supercategory': cls,
            'name': cls,
            'id': category_index + 1,
        })
    results['categories'] = category_records

    annotation_records = []
    for annotation_id, box in enumerate(box_records):
        annotation_records.append({
            'area': box[-1],
            'category_id': box[-2],
            'image_id': box[-3],
            'iscrowd': box[-4],
            'bbox': box[:4],
            'id': annotation_id,
        })
    results['annotations'] = annotation_records
    return results
|
| 703 |
+
|
| 704 |
+
def preprocess_dr(dr_path, class_names):
    """
    Build a COCO-style detection-results list from the per-image .txt
    files in `dr_path`. Each entry carries image_id, 1-based category_id,
    an [x, y, w, h] bbox and the confidence score; detections whose class
    is not in `class_names` are skipped.
    """
    detection_records = []
    for dr_file in os.listdir(dr_path):
        detection_lines = file_lines_to_list(os.path.join(dr_path, dr_file))
        stem = os.path.splitext(dr_file)[0]
        for detection_line in detection_lines:
            tokens = detection_line.split()
            # Trailing tokens: confidence left top right bottom; the class
            # name (possibly containing spaces) is everything before them.
            confidence, left, top, right, bottom = tokens[-5:]
            cls = " ".join(tokens[:-5])
            left, top, right, bottom = float(left), float(top), float(right), float(bottom)
            record = {}
            record["image_id"] = str(stem)
            if cls not in class_names:
                continue
            record["category_id"] = class_names.index(cls) + 1
            record["bbox"] = [left, top, right - left, bottom - top]
            record["score"] = float(confidence)
            detection_records.append(record)
    return detection_records
|
| 727 |
+
|
| 728 |
+
def get_coco_map(class_names, path):
    """
    Run the official pycocotools bbox evaluation on the ground-truth and
    detection-result files under `path` and return cocoEval.stats (the 12
    standard COCO metrics). Returns a list of twelve zeros when there are
    no detections at all.
    """
    gt_dir = os.path.join(path, 'ground-truth')
    dr_dir = os.path.join(path, 'detection-results')
    coco_dir = os.path.join(path, 'coco_eval')

    if not os.path.exists(coco_dir):
        os.makedirs(coco_dir)

    gt_json_path = os.path.join(coco_dir, 'instances_gt.json')
    dr_json_path = os.path.join(coco_dir, 'instances_dr.json')

    # Materialize both sides as COCO-format JSON for pycocotools.
    with open(gt_json_path, "w") as handle:
        json.dump(preprocess_gt(gt_dir, class_names), handle, indent=4)

    with open(dr_json_path, "w") as handle:
        detection_results = preprocess_dr(dr_dir, class_names)
        json.dump(detection_results, handle, indent=4)
    if len(detection_results) == 0:
        print("未检测到任何目标。")
        return [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]

    coco_gt = COCO(gt_json_path)
    coco_dt = coco_gt.loadRes(dr_json_path)
    evaluator = COCOeval(coco_gt, coco_dt, 'bbox')
    evaluator.evaluate()
    evaluator.accumulate()
    evaluator.summarize()

    return evaluator.stats
|