Spaces:
Runtime error
Runtime error
| import datetime | |
| import os | |
| import torch | |
| import matplotlib | |
| matplotlib.use('Agg') | |
| import scipy.signal | |
| from matplotlib import pyplot as plt | |
| from torch.utils.tensorboard import SummaryWriter | |
| import shutil | |
| import numpy as np | |
| from PIL import Image | |
| from tqdm import tqdm | |
| from .utils import cvtColor, preprocess_input, resize_image | |
| from .utils_bbox import DecodeBox | |
| from .utils_map import get_coco_map, get_map | |
| class LossHistory(): | |
| def __init__(self, log_dir, model, input_shape): | |
| self.log_dir = log_dir | |
| self.losses = [] | |
| self.val_loss = [] | |
| os.makedirs(self.log_dir) | |
| self.writer = SummaryWriter(self.log_dir) | |
| try: | |
| dummy_input = torch.randn(2, 3, input_shape[0], input_shape[1]) | |
| self.writer.add_graph(model, dummy_input) | |
| except: | |
| pass | |
| def append_loss(self, epoch, loss, val_loss): | |
| if not os.path.exists(self.log_dir): | |
| os.makedirs(self.log_dir) | |
| self.losses.append(loss) | |
| self.val_loss.append(val_loss) | |
| with open(os.path.join(self.log_dir, "epoch_loss.txt"), 'a') as f: | |
| f.write(str(loss)) | |
| f.write("\n") | |
| with open(os.path.join(self.log_dir, "epoch_val_loss.txt"), 'a') as f: | |
| f.write(str(val_loss)) | |
| f.write("\n") | |
| self.writer.add_scalar('loss', loss, epoch) | |
| self.writer.add_scalar('val_loss', val_loss, epoch) | |
| self.loss_plot() | |
| def loss_plot(self): | |
| iters = range(len(self.losses)) | |
| plt.figure() | |
| plt.plot(iters, self.losses, 'red', linewidth = 2, label='train loss') | |
| plt.plot(iters, self.val_loss, 'coral', linewidth = 2, label='val loss') | |
| try: | |
| if len(self.losses) < 25: | |
| num = 5 | |
| else: | |
| num = 15 | |
| plt.plot(iters, scipy.signal.savgol_filter(self.losses, num, 3), 'green', linestyle = '--', linewidth = 2, label='smooth train loss') | |
| except: | |
| pass | |
| plt.grid(True) | |
| plt.xlabel('Epoch') | |
| plt.ylabel('Loss') | |
| plt.legend(loc="upper right") | |
| plt.savefig(os.path.join(self.log_dir, "epoch_loss.png")) | |
| plt.cla() | |
| plt.close("all") | |
| class EvalCallback(): | |
| def __init__(self, net, input_shape, anchors, anchors_mask, class_names, num_classes, val_lines, log_dir, cuda, \ | |
| map_out_path=".temp_map_out", max_boxes=100, confidence=0.05, nms_iou=0.5, letterbox_image=True, MINOVERLAP=0.5, eval_flag=True, period=1): | |
| super(EvalCallback, self).__init__() | |
| self.net = net | |
| self.input_shape = input_shape | |
| self.anchors = anchors | |
| self.anchors_mask = anchors_mask | |
| self.class_names = class_names | |
| self.num_classes = num_classes | |
| self.val_lines = val_lines | |
| self.log_dir = log_dir | |
| self.cuda = cuda | |
| self.map_out_path = map_out_path | |
| self.max_boxes = max_boxes | |
| self.confidence = confidence | |
| self.nms_iou = nms_iou | |
| self.letterbox_image = letterbox_image | |
| self.MINOVERLAP = MINOVERLAP | |
| self.eval_flag = eval_flag | |
| self.period = period | |
| self.bbox_util = DecodeBox(self.anchors, self.num_classes, (self.input_shape[0], self.input_shape[1]), self.anchors_mask) | |
| self.maps = [0] | |
| self.epoches = [0] | |
| if self.eval_flag: | |
| with open(os.path.join(self.log_dir, "epoch_map.txt"), 'a') as f: | |
| f.write(str(0)) | |
| f.write("\n") | |
| def get_map_txt(self, image_id, image, class_names, map_out_path): | |
| f = open(os.path.join(map_out_path, "detection-results/"+image_id+".txt"), "w", encoding='utf-8') | |
| image_shape = np.array(np.shape(image)[0:2]) | |
| image = cvtColor(image) | |
| image_data = resize_image(image, (self.input_shape[1], self.input_shape[0]), self.letterbox_image) | |
| image_data = np.expand_dims(np.transpose(preprocess_input(np.array(image_data, dtype='float32')), (2, 0, 1)), 0) | |
| with torch.no_grad(): | |
| images = torch.from_numpy(image_data) | |
| if self.cuda: | |
| images = images.cuda() | |
| outputs = self.net(images) | |
| outputs = self.bbox_util.decode_box(outputs) | |
| results = self.bbox_util.non_max_suppression(torch.cat(outputs, 1), self.num_classes, self.input_shape, | |
| image_shape, self.letterbox_image, conf_thres = self.confidence, nms_thres = self.nms_iou) | |
| if results[0] is None: | |
| return | |
| top_label = np.array(results[0][:, 6], dtype = 'int32') | |
| top_conf = results[0][:, 4] * results[0][:, 5] | |
| top_boxes = results[0][:, :4] | |
| top_100 = np.argsort(top_conf)[::-1][:self.max_boxes] | |
| top_boxes = top_boxes[top_100] | |
| top_conf = top_conf[top_100] | |
| top_label = top_label[top_100] | |
| for i, c in list(enumerate(top_label)): | |
| predicted_class = self.class_names[int(c)] | |
| box = top_boxes[i] | |
| score = str(top_conf[i]) | |
| top, left, bottom, right = box | |
| if predicted_class not in class_names: | |
| continue | |
| f.write("%s %s %s %s %s %s\n" % (predicted_class, score[:6], str(int(left)), str(int(top)), str(int(right)),str(int(bottom)))) | |
| f.close() | |
| return | |
| def on_epoch_end(self, epoch, model_eval): | |
| if epoch % self.period == 0 and self.eval_flag: | |
| self.net = model_eval | |
| if not os.path.exists(self.map_out_path): | |
| os.makedirs(self.map_out_path) | |
| if not os.path.exists(os.path.join(self.map_out_path, "ground-truth")): | |
| os.makedirs(os.path.join(self.map_out_path, "ground-truth")) | |
| if not os.path.exists(os.path.join(self.map_out_path, "detection-results")): | |
| os.makedirs(os.path.join(self.map_out_path, "detection-results")) | |
| print("Get map.") | |
| for annotation_line in tqdm(self.val_lines): | |
| line = annotation_line.split() | |
| image_id = os.path.basename(line[0]).split('.')[0] | |
| image = Image.open(line[0]) | |
| gt_boxes = np.array([np.array(list(map(int,box.split(',')))) for box in line[1:]]) | |
| self.get_map_txt(image_id, image, self.class_names, self.map_out_path) | |
| with open(os.path.join(self.map_out_path, "ground-truth/"+image_id+".txt"), "w") as new_f: | |
| for box in gt_boxes: | |
| left, top, right, bottom, obj = box | |
| obj_name = self.class_names[obj] | |
| new_f.write("%s %s %s %s %s\n" % (obj_name, left, top, right, bottom)) | |
| print("Calculate Map.") | |
| try: | |
| temp_map = get_coco_map(class_names = self.class_names, path = self.map_out_path)[1] | |
| except: | |
| temp_map = get_map(self.MINOVERLAP, False, path = self.map_out_path) | |
| self.maps.append(temp_map) | |
| self.epoches.append(epoch) | |
| with open(os.path.join(self.log_dir, "epoch_map.txt"), 'a') as f: | |
| f.write(str(temp_map)) | |
| f.write("\n") | |
| plt.figure() | |
| plt.plot(self.epoches, self.maps, 'red', linewidth = 2, label='train map') | |
| plt.grid(True) | |
| plt.xlabel('Epoch') | |
| plt.ylabel('Map %s'%str(self.MINOVERLAP)) | |
| plt.title('A Map Curve') | |
| plt.legend(loc="upper right") | |
| plt.savefig(os.path.join(self.log_dir, "epoch_map.png")) | |
| plt.cla() | |
| plt.close("all") | |
| print("Get map done.") | |
| shutil.rmtree(self.map_out_path) | |