''' Guowang Xie from utilsV3.py '''
import torch
from torch.utils import data
from torch.autograd import Variable, Function
import numpy as np
import sys, os, math
import cv2
import time
import re
import random
from scipy.interpolate import griddata
from tpsV2 import createThinPlateSplineShapeTransformer


def adjust_position(x_min, y_min, x_max, y_max, new_shape):
    """Center an (x_max-x_min) x (y_max-y_min) region inside ``new_shape``.

    Returns (row_start, col_start, row_end, col_end) of the centered region;
    when the leftover padding is odd, the extra pixel goes to the end side.
    """
    if (new_shape[0] - (x_max - x_min)) % 2 == 0:
        f_g_0_0 = (new_shape[0] - (x_max - x_min)) // 2
        f_g_0_1 = f_g_0_0
    else:
        f_g_0_0 = (new_shape[0] - (x_max - x_min)) // 2
        f_g_0_1 = f_g_0_0 + 1
    if (new_shape[1] - (y_max - y_min)) % 2 == 0:
        f_g_1_0 = (new_shape[1] - (y_max - y_min)) // 2
        f_g_1_1 = f_g_1_0
    else:
        f_g_1_0 = (new_shape[1] - (y_max - y_min)) // 2
        f_g_1_1 = f_g_1_0 + 1
    # return f_g_0_0, f_g_0_1, f_g_1_0, f_g_1_1
    return f_g_0_0, f_g_1_0, new_shape[0] - f_g_0_1, new_shape[1] - f_g_1_1


def get_matric_edge(matric):
    """Return the border points of a (H, W, C) grid: both full columns,
    then the two rows with their corners excluded, concatenated on axis 0."""
    return np.concatenate((matric[:, 0, :], matric[:, -1, :], matric[0, 1:-1, :], matric[-1, 1:-1, :]), axis=0)


class SaveFlatImage(object):
    ''' Post-processing and save result.
    Function:
        flatByRegressWithClassiy_multiProcessV2: Selecting a post-processing method
        flatByfiducial_TPS: Thin Plate Spline, input multi-batch
        flatByfiducial_interpolation: Interpolation, input one image
    '''

    def __init__(self, path, date, date_time, _re_date, data_path_validate, data_path_test, batch_size,
                 preproccess=False, postprocess='tps_gpu', device=torch.device('cuda:0')):
        self.path = path
        self.date = date
        self.date_time = date_time
        self._re_date = _re_date
        self.preproccess = preproccess
        self.data_path_validate = data_path_validate
        self.data_path_test = data_path_test
        self.batch_size = batch_size
        self.device = device
        self.col_gap = 0  # 4
        self.row_gap = self.col_gap  # col_gap + 1 if col_gap < 6 else col_gap
        # fiducial_point_gaps = [1, 2, 3, 4, 5, 6, 10, 12, 15, 20, 30, 60]
        # POINTS NUM: 61, 31, 21, 16, 13, 11, 7, 6, 5, 4, 3, 2
        self.fiducial_point_gaps = [1, 2, 3, 5, 6, 10, 15, 30]  # POINTS NUM: 31, 16, 11, 7, 6, 4, 3, 2
        self.fiducial_point_num = [31, 16, 11, 7, 6, 4, 3, 2]
        self.fiducial_num = self.fiducial_point_num[self.col_gap], self.fiducial_point_num[self.row_gap]
        map_shape = (320, 320)
        self.postprocess = postprocess
        # NOTE(review): the TPS transformer is only built when postprocess == 'tps';
        # with the default 'tps_gpu' value, flatByfiducial_TPS would fail on self.tps.
        if self.postprocess == 'tps':
            self.tps = createThinPlateSplineShapeTransformer(map_shape, fiducial_num=self.fiducial_num,
                                                             device=self.device)

    def location_mark(self, img, location, color=(0, 0, 255)):
        """Draw a filled circle at every (x, y) point in ``location`` and
        return the (mutated) image."""
        stepSize = 0
        for l in location.astype(np.int64).reshape(-1, 2):
            cv2.circle(img, (l[0] + math.ceil(stepSize / 2), l[1] + math.ceil(stepSize / 2)), 3, color, -1)
        return img

    def flatByfiducial_TPS(self, fiducial_points, segment, im_name, epoch, perturbed_img=None,
                           scheme='validate', is_scaling=False):
        '''
        flat_shap controls the output image resolution
        '''
        # Load (or reshape) the distorted input image depending on the scheme.
        if scheme == 'test' or scheme == 'eval':
            perturbed_img_path = self.data_path_test + im_name
            perturbed_img = cv2.imread(perturbed_img_path, flags=cv2.IMREAD_COLOR)
            perturbed_img = cv2.resize(perturbed_img, (960, 1024))
        elif scheme == 'validate' and perturbed_img is None:
            RGB_name = im_name.replace('gw', 'png')
            perturbed_img_path = self.data_path_validate + '/png/' + RGB_name
            perturbed_img = cv2.imread(perturbed_img_path, flags=cv2.IMREAD_COLOR)
        elif perturbed_img is not None:
            perturbed_img = perturbed_img.transpose(1, 2, 0)  # CHW -> HWC
        # Normalize predicted points; 992 is presumably the regression output
        # coordinate space — TODO confirm against the model's output resolution.
        fiducial_points = fiducial_points / [992, 992]
        perturbed_img_shape = perturbed_img.shape[:2]
        sshape = fiducial_points[::self.fiducial_point_gaps[self.row_gap],
                                 ::self.fiducial_point_gaps[self.col_gap], :]
        flat_shap = segment * [self.fiducial_point_gaps[self.col_gap], self.fiducial_point_gaps[self.row_gap]] \
                            * [self.fiducial_point_num[self.col_gap], self.fiducial_point_num[self.row_gap]]
        # flat_shap = perturbed_img_shape
        time_1 = time.time()
        perturbed_img_ = torch.tensor(perturbed_img.transpose(2, 0, 1)[None, :])
        # Map normalized points from [0, 1] to [-1, 1] as the TPS module expects.
        fiducial_points_ = (torch.tensor(fiducial_points.transpose(1, 0, 2).reshape(-1, 2))[None, :] - 0.5) * 2
        rectified = self.tps(perturbed_img_.double().to(self.device),
                             fiducial_points_.to(self.device), list(flat_shap))
        time_2 = time.time()
        time_interval = time_2 - time_1
        print('TPS time: ' + str(time_interval))
        flat_img = rectified[0].cpu().numpy().transpose(1, 2, 0)
        '''save'''
        flat_img = flat_img.astype(np.uint8)
        i_path = os.path.join(self.path, self.date + self.date_time + ' @' + self._re_date, str(epoch)) \
            if self._re_date is not None else os.path.join(self.path, self.date + self.date_time, str(epoch))
        ''''''
        perturbed_img_mark = self.location_mark(perturbed_img.copy(),
                                                sshape * perturbed_img_shape[::-1], (0, 0, 255))
        if scheme == 'test':
            i_path += '/test'
        # exist_ok avoids the check-then-create race of os.path.exists + makedirs
        os.makedirs(i_path, exist_ok=True)
        im_name = im_name.replace('gw', 'png')
        cv2.imwrite(i_path + '/mark_' + im_name, perturbed_img_mark)
        cv2.imwrite(i_path + '/' + im_name, flat_img)

    def flatByfiducial_interpolation(self, fiducial_points, segment, im_name, epoch, perturbed_img=None,
                                     scheme='validate', is_scaling=False):
        ''''''
        # Load (or reshape) the distorted input image depending on the scheme.
        if scheme == 'test' or scheme == 'eval':
            perturbed_img_path = self.data_path_test + im_name
            perturbed_img = cv2.imread(perturbed_img_path, flags=cv2.IMREAD_COLOR)
            perturbed_img = cv2.resize(perturbed_img, (960, 1024))
        elif scheme == 'validate' and perturbed_img is None:
            RGB_name = im_name.replace('gw', 'png')
            perturbed_img_path = self.data_path_validate + '/png/' + RGB_name
            perturbed_img = cv2.imread(perturbed_img_path, flags=cv2.IMREAD_COLOR)
        elif perturbed_img is not None:
            perturbed_img = perturbed_img.transpose(1, 2, 0)  # CHW -> HWC
        # Rescale points from the regression space (992x992) to the image size.
        fiducial_points = fiducial_points / [992, 992] * [960, 1024]
        col_gap = 2  # 4
        row_gap = col_gap  # col_gap + 1 if col_gap < 6 else col_gap
        # fiducial_point_gaps = [1, 2, 3, 4, 5, 6, 10, 12, 15, 20, 30, 60]
        # POINTS NUM: 61, 31, 21, 16, 13, 11, 7, 6, 5, 4, 3, 2
        fiducial_point_gaps = [1, 2, 3, 5, 6, 10, 15, 30]  # POINTS NUM: 31, 16, 11, 7, 6, 4, 3, 2
        sshape = fiducial_points[::fiducial_point_gaps[row_gap], ::fiducial_point_gaps[col_gap], :]
        segment_h, segment_w = segment * [fiducial_point_gaps[col_gap], fiducial_point_gaps[row_gap]]
        fiducial_points_row, fiducial_points_col = sshape.shape[:2]

        # Build the target (flat) grid matching the subsampled fiducial grid.
        im_x, im_y = np.mgrid[0:(fiducial_points_col - 1):complex(fiducial_points_col),
                              0:(fiducial_points_row - 1):complex(fiducial_points_row)]
        tshape = np.stack((im_x, im_y), axis=2) * [segment_w, segment_h]
        tshape = tshape.reshape(-1, 2)
        sshape = sshape.reshape(-1, 2)
        output_shape = (segment_h * (fiducial_points_col - 1), segment_w * (fiducial_points_row - 1))
        grid_x, grid_y = np.mgrid[0:output_shape[0] - 1:complex(output_shape[0]),
                                  0:output_shape[1] - 1:complex(output_shape[1])]
        time_1 = time.time()
        # grid_z = griddata(tshape, sshape, (grid_y, grid_x), method='cubic').astype('float32')
        # Interpolate a dense backward map (flat pixel -> source pixel) and remap.
        grid_ = griddata(tshape, sshape, (grid_y, grid_x), method='linear').astype('float32')
        flat_img = cv2.remap(perturbed_img, grid_[:, :, 0], grid_[:, :, 1], cv2.INTER_CUBIC)
        time_2 = time.time()
        time_interval = time_2 - time_1
        print('Interpolation time: ' + str(time_interval))
        ''''''
        flat_img = flat_img.astype(np.uint8)
        i_path = os.path.join(self.path, self.date + self.date_time + ' @' + self._re_date, str(epoch)) \
            if self._re_date is not None else os.path.join(self.path, self.date + self.date_time, str(epoch))
        ''''''
        perturbed_img_mark = self.location_mark(perturbed_img.copy(), sshape, (0, 0, 255))
        # Center the rectified image on a canvas the size of the input, then
        # place input and result side by side for visual inspection.
        shrink_paddig = 0  # 2 * edge_padding
        x_start, x_end, y_start, y_end = shrink_paddig, segment_h * (fiducial_points_col - 1) - shrink_paddig, \
                                         shrink_paddig, segment_w * (fiducial_points_row - 1) - shrink_paddig
        x_ = (perturbed_img_mark.shape[0] - (x_end - x_start)) // 2
        y_ = (perturbed_img_mark.shape[1] - (y_end - y_start)) // 2
        flat_img_new = np.zeros_like(perturbed_img_mark)
        flat_img_new[x_:perturbed_img_mark.shape[0] - x_, y_:perturbed_img_mark.shape[1] - y_] = flat_img
        img_figure = np.concatenate((perturbed_img_mark, flat_img_new), axis=1)
        if scheme == 'test':
            i_path += '/test'
        # exist_ok avoids the check-then-create race of os.path.exists + makedirs
        os.makedirs(i_path, exist_ok=True)
        im_name = im_name.replace('gw', 'png')
        cv2.imwrite(i_path + '/' + im_name, img_figure)

    def flatByRegressWithClassiy_multiProcessV2(self, pred_fiducial_points, pred_segment, im_name, epoch,
                                                process_pool=None, perturbed_img=None, scheme='validate',
                                                is_scaling=False):
        """Dispatch each image in the batch to the configured post-processing method."""
        for i_val_i in range(pred_fiducial_points.shape[0]):
            if self.postprocess == 'tps':
                self.flatByfiducial_TPS(pred_fiducial_points[i_val_i], pred_segment[i_val_i], im_name[i_val_i],
                                        epoch, None if perturbed_img is None else perturbed_img[i_val_i],
                                        scheme, is_scaling)
            elif self.postprocess == 'interpolation':
                self.flatByfiducial_interpolation(pred_fiducial_points[i_val_i], pred_segment[i_val_i],
                                                  im_name[i_val_i], epoch,
                                                  None if perturbed_img is None else perturbed_img[i_val_i],
                                                  scheme, is_scaling)
            else:
                print('Error: Other postprocess.')
                # sys.exit is explicit and works even where the site-provided
                # exit() helper is unavailable; non-zero signals failure.
                sys.exit(1)


class AverageMeter(object):
    """Computes and stores the average and current value"""

    def __init__(self):
        self.reset()

    def reset(self):
        self.val = 0
        self.avg = 0
        self.sum = 0
        self.count = 0

    def update(self, val, n=1, m=1):
        # m weights the accumulated sum while n weights the count.
        self.val = val
        self.sum += val * m
        self.count += n
        self.avg = self.sum / self.count


class FlatImg(object):
    '''
    args:
        self.save_flat_mage: Initialize the post-processing. Select a method in "postprocess_list".
    '''

    def __init__(self, args, path, date, date_time, _re_date, model,
                 reslut_file, n_classes, optimizer,
                 model_D=None, optimizer_D=None,
                 loss_fn=None, loss_fn2=None, data_loader=None, data_loader_hdf5=None, dataPackage_loader=None,
                 data_path=None, data_path_validate=None, data_path_test=None, data_preproccess=True):
        # , valloaderSet, v_loaderSet
        self.args = args
        self.path = path
        self.date = date
        self.date_time = date_time
        self._re_date = _re_date
        # self.valloaderSet = valloaderSet
        # self.v_loaderSet = v_loaderSet
        self.model = model
        self.model_D = model_D
        self.reslut_file = reslut_file
        self.n_classes = n_classes
        self.optimizer = optimizer
        self.optimizer_D = optimizer_D
        self.loss_fn = loss_fn
        self.loss_fn2 = loss_fn2
        self.data_loader = data_loader
        self.data_loader_hdf5 = data_loader_hdf5
        self.dataPackage_loader = dataPackage_loader
        self.data_path = data_path
        self.data_path_validate = data_path_validate
        self.data_path_test = data_path_test
        self.data_preproccess = data_preproccess

        postprocess_list = ['tps', 'interpolation']
        self.save_flat_mage = SaveFlatImage(self.path, self.date, self.date_time, self._re_date,
                                            self.data_path_validate, self.data_path_test, self.args.batch_size,
                                            self.data_preproccess, postprocess=postprocess_list[0],
                                            device=torch.device(self.args.device))

        self.validate_loss = AverageMeter()
        self.validate_loss_regress = AverageMeter()
        self.validate_loss_segment = AverageMeter()
        self.lambda_loss = 1
        self.lambda_loss_segment = 1
        self.lambda_loss_a = 1
        self.lambda_loss_b = 1
        self.lambda_loss_c = 1

    def loadTrainData(self, data_split, is_shuffle=True):
        """Build and return the training DataLoader for ``data_split``."""
        train_loader = self.data_loader(self.data_path, split=data_split, img_shrink=self.args.img_shrink,
                                        preproccess=self.data_preproccess)
        trainloader = data.DataLoader(train_loader, batch_size=self.args.batch_size,
                                      num_workers=min(self.args.batch_size, 8),
                                      drop_last=True, pin_memory=True, shuffle=is_shuffle)
        return trainloader

    def loadValidateAndTestData(self, is_shuffle=True, sub_dir='shrink_512/crop/'):
        """Populate self.valloaderSet / self.testloaderSet (and raw loader dicts)."""
        v1_loader = self.data_loader(self.data_path_validate, split='validate', img_shrink=self.args.img_shrink,
                                     is_return_img_name=True, preproccess=self.data_preproccess)
        valloader1 = data.DataLoader(v1_loader, batch_size=self.args.batch_size,
                                     num_workers=min(self.args.batch_size, 8), pin_memory=True,
                                     shuffle=is_shuffle)
        '''val sets'''
        v_loaderSet = {
            'v1_loader': v1_loader,
        }
        valloaderSet = {
            'valloader1': valloader1,
        }
        # sub_dir = 'crop/crop/'
        t1_loader = self.data_loader(self.data_path_test, split='test', img_shrink=self.args.img_shrink,
                                     is_return_img_name=True)
        testloader1 = data.DataLoader(t1_loader, batch_size=self.args.batch_size,
                                      num_workers=self.args.batch_size, pin_memory=True,
                                      shuffle=False)
        '''test sets'''
        t_loaderSet = {
            't1_loader': v1_loader,
        }
        testloaderSet = {
            'testloader1': testloader1,
        }
        self.valloaderSet = valloaderSet
        self.v_loaderSet = v_loaderSet
        self.testloaderSet = testloaderSet
        self.t_loaderSet = t_loaderSet
        # return v_loaderSet, valloaderSet

    def loadTestData(self, is_shuffle=True):
        """Populate self.testloaderSet only."""
        t1_loader = self.data_loader(self.data_path_test, split='test', img_shrink=self.args.img_shrink,
                                     is_return_img_name=True)
        testloader1 = data.DataLoader(t1_loader, batch_size=self.args.batch_size,
                                      num_workers=self.args.batch_size, shuffle=False)
        '''test sets'''
        testloaderSet = {
            'testloader1': testloader1,
        }
        self.testloaderSet = testloaderSet

    def evalData(self, is_shuffle=True, sub_dir='shrink_512/crop/'):
        """Populate self.evalloaderSet for the 'eval' split."""
        eval_loader = self.data_loader(self.data_path_test, split='eval', img_shrink=self.args.img_shrink,
                                       is_return_img_name=True)
        evalloader = data.DataLoader(eval_loader, batch_size=self.args.batch_size,
                                     num_workers=self.args.batch_size, pin_memory=True,
                                     shuffle=False)
        self.evalloaderSet = evalloader
        # return v_loaderSet, valloaderSet

    def saveModel_epoch(self, epoch):
        """Save model + optimizer state for ``epoch`` (stored 1-based)."""
        epoch += 1
        state = {'epoch': epoch,
                 'model_state': self.model.state_dict(),
                 'optimizer_state': self.optimizer.state_dict(),  # AN ERROR HAS OCCURED
                 }
        i_path = os.path.join(self.path, self.date + self.date_time + ' @' + self._re_date, str(epoch)) \
            if self._re_date is not None else os.path.join(self.path, self.date + self.date_time, str(epoch))
        # exist_ok avoids the check-then-create race of os.path.exists + makedirs
        os.makedirs(i_path, exist_ok=True)
        if self._re_date is None:
            torch.save(state, i_path + '/' + self.date + self.date_time + "{}".format(self.args.arch) + ".pkl")
            # "./trained_model/{}_{}_best_model.pkl"
        else:
            torch.save(state, i_path + '/' + self._re_date + "@" + self.date + self.date_time + "{}".format(
                self.args.arch) + ".pkl")

    def validateOrTestModelV3(self, epoch, trian_t, validate_test='v_l2', is_scaling=False):
        """Run validation ('v_l4'), full test ('t_all'), or sampled test (default)
        over the pre-built loaders, optionally saving rectified images."""
        if validate_test == 'v_l4':
            loss_segment_list = 0
            loss_overall_list = 0
            loss_local_list = 0
            loss_edge_list = 0
            loss_rectangles_list = 0
            loss_list = []
            begin_test = time.time()
            with torch.no_grad():
                for i_valloader, valloader in enumerate(self.valloaderSet.values()):
                    for i_val, (images, labels, segment, im_name) in enumerate(valloader):
                        try:
                            # save_img_ = random.choices([True, False], weights=[1, 0])[0]
                            save_img_ = random.choices([True, False], weights=[0.05, 0.95])[0]
                            # save_img_ = True
                            images = Variable(images)
                            labels = Variable(labels.cuda(self.args.device))
                            segment = Variable(segment.cuda(self.args.device))

                            outputs, outputs_segment = self.model(images)

                            loss_overall, loss_local, loss_edge, loss_rectangles = self.loss_fn(
                                outputs, labels, size_average=True)
                            loss_segment = self.loss_fn2(outputs_segment, segment)
                            loss = self.lambda_loss * (loss_overall + loss_local + loss_edge * self.lambda_loss_a
                                                       + loss_rectangles * self.lambda_loss_b) \
                                   + self.lambda_loss_segment * loss_segment
                            # loss = self.lambda_loss * (loss_local + loss_rectangles + loss_edge*self.lambda_loss_a + loss_overall*self.lambda_loss_b) + self.lambda_loss_segment * loss_segment

                            pred_regress = outputs.data.cpu().numpy().transpose(0, 2, 3, 1)  # (4, 1280, 1024, 2)
                            pred_segment = outputs_segment.data.round().int().cpu().numpy()
                            # (4, 1280, 1024) ==outputs.data.argmax(dim=0).cpu().numpy()

                            if save_img_:
                                self.save_flat_mage.flatByRegressWithClassiy_multiProcessV2(
                                    pred_regress, pred_segment, im_name, epoch + 1,
                                    perturbed_img=images.numpy(), scheme='validate', is_scaling=is_scaling)

                            loss_list.append(loss.item())
                            loss_segment_list += loss_segment.item()
                            loss_overall_list += loss_overall.item()
                            loss_local_list += loss_local.item()
                            # loss_edge_list += loss_edge.item()
                            # loss_rectangles_list += loss_rectangles.item()
                        # narrow to Exception so Ctrl-C / SystemExit still propagate
                        except Exception:
                            print('* save image validated error :' + im_name[0])
            test_time = time.time() - begin_test
            # if always_save_model:
            #     self.saveModel(epoch, save_path=self.path)
            # guard: every batch may have failed, leaving loss_list empty
            list_len = max(len(loss_list), 1)
            print('train time : {trian_t:.3f}\t'
                  'validate time : {test_time:.3f}\t'
                  '[o:{overall_avg:.4f} l:{local_avg:.4f} e:{edge_avg:.4f} r:{rectangles_avg:.4f}\t'
                  '[{loss_regress:.4f} {loss_segment:.4f}]\n'.format(
                      trian_t=trian_t, test_time=test_time,
                      overall_avg=loss_overall_list / list_len, local_avg=loss_local_list / list_len,
                      edge_avg=loss_edge_list / list_len, rectangles_avg=loss_rectangles_list / list_len,
                      loss_regress=(loss_overall_list + loss_local_list + loss_edge_list) / list_len,
                      loss_segment=loss_segment_list / list_len))
            print('train time : {trian_t:.3f}\t'
                  'validate time : {test_time:.3f}\t'
                  '[o:{overall_avg:.4f} l:{local_avg:.4f} e:{edge_avg:.4f} r:{rectangles_avg:.4f}\t'
                  '[{loss_regress:.4f} {loss_segment:.4f}]\n'.format(
                      trian_t=trian_t, test_time=test_time,
                      overall_avg=loss_overall_list / list_len, local_avg=loss_local_list / list_len,
                      edge_avg=loss_edge_list / list_len, rectangles_avg=loss_rectangles_list / list_len,
                      loss_regress=(loss_overall_list + loss_local_list + loss_edge_list) / list_len,
                      loss_segment=loss_segment_list / list_len), file=self.reslut_file)
        elif validate_test == 't_all':
            begin_test = time.time()
            with torch.no_grad():
                for i_valloader, valloader in enumerate(self.testloaderSet.values()):
                    for i_val, (images, im_name) in enumerate(valloader):
                        try:
                            # save_img_ = True
                            save_img_ = random.choices([True, False], weights=[1, 0])[0]
                            # save_img_ = random.choices([True, False], weights=[0.2, 0.8])[0]
                            if save_img_:
                                images = Variable(images)
                                outputs, outputs_segment = self.model(images)
                                # outputs, outputs_segment = self.model(images, is_softmax=True)
                                pred_regress = outputs.data.cpu().numpy().transpose(0, 2, 3, 1)
                                pred_segment = outputs_segment.data.round().int().cpu().numpy()
                                # (4, 1280, 1024) ==outputs.data.argmax(dim=0).cpu().numpy()
                                self.save_flat_mage.flatByRegressWithClassiy_multiProcessV2(
                                    pred_regress, pred_segment, im_name, epoch + 1,
                                    scheme='test', is_scaling=is_scaling)
                        # narrow to Exception so Ctrl-C / SystemExit still propagate
                        except Exception:
                            print('* save image tested error :' + im_name[0])
            test_time = time.time() - begin_test
            print('test time : {test_time:.3f}'.format(test_time=test_time))
            print('test time : {test_time:.3f}'.format(test_time=test_time), file=self.reslut_file)
        else:
            begin_test = time.time()
            with torch.no_grad():
                for i_valloader, valloader in enumerate(self.testloaderSet.values()):
                    for i_val, (images, im_name) in enumerate(valloader):
                        try:
                            # save_img_ = True
                            # save_img_ = random.choices([True, False], weights=[1, 0])[0]
                            save_img_ = random.choices([True, False], weights=[0.4, 0.6])[0]
                            if save_img_:
                                images = Variable(images)
                                outputs, outputs_segment = self.model(images)
                                # outputs, outputs_segment = self.model(images, is_softmax=True)
                                pred_regress = outputs.data.cpu().numpy().transpose(0, 2, 3, 1)
                                pred_segment = outputs_segment.data.round().int().cpu().numpy()
                                # (4, 1280, 1024) ==outputs.data.argmax(dim=0).cpu().numpy()
                                self.save_flat_mage.flatByRegressWithClassiy_multiProcessV2(
                                    pred_regress, pred_segment, im_name, epoch + 1,
                                    scheme='test', is_scaling=is_scaling)
                        # narrow to Exception so Ctrl-C / SystemExit still propagate
                        except Exception:
                            print('* save image tested error :' + im_name[0])
            test_time = time.time() - begin_test
            print('test time : {test_time:.3f}'.format(test_time=test_time))
            print('test time : {test_time:.3f}'.format(test_time=test_time), file=self.reslut_file)