""" Hello, welcome on board, """ from __future__ import print_function import argparse import os import time, platform import cv2 import numpy as np os.environ['CUDA_LAUNCH_BLOCKING']="0" import torch import torch.optim as optim from torch.utils.data import DataLoader from thop import profile from TEED.dataset import DATASET_NAMES, BipedDataset, TestDataset, dataset_info from TEED.loss2 import * from TEED.ted import TED # TEED architecture from TEED.utils.img_processing import (image_normalization, save_image_batch_to_disk, visualize_result, count_parameters) is_testing =True # set False to train with TEED model IS_LINUX = True if platform.system()=="Linux" else False def train_one_epoch(epoch, dataloader, model, criterions, optimizer, device, log_interval_vis, tb_writer, args=None): imgs_res_folder = os.path.join(args.output_dir, 'current_res') os.makedirs(imgs_res_folder,exist_ok=True) show_log = args.show_log if isinstance(criterions, list): criterion1, criterion2 = criterions else: criterion1 = criterions # Put model in training mode model.train() l_weight0 = [1.1,0.7,1.1,1.3] # for bdcn loss2-B4 l_weight = [[0.05, 2.], [0.05, 2.], [0.01, 1.], [0.01, 3.]] # for cats loss [0.01, 4.] loss_avg =[] for batch_id, sample_batched in enumerate(dataloader): images = sample_batched['images'].to(device) # BxCxHxW labels = sample_batched['labels'].to(device) # BxHxW preds_list = model(images) loss1 = sum([criterion2(preds, labels,l_w) for preds, l_w in zip(preds_list[:-1],l_weight0)]) # bdcn_loss2 [1,2,3] TEED loss2 = criterion1(preds_list[-1], labels, l_weight[-1], device) # cats_loss [dfuse] TEED tLoss = loss2+loss1 # TEED optimizer.zero_grad() tLoss.backward() optimizer.step() loss_avg.append(tLoss.item()) if epoch==0 and (batch_id==100 and tb_writer is not None): tmp_loss = np.array(loss_avg).mean() tb_writer.add_scalar('loss', tmp_loss,epoch) if batch_id % (show_log) == 0: print(time.ctime(), 'Epoch: {0} Sample {1}/{2} Loss: {3}' .format(epoch, batch_id, len(dataloader), format(tLoss.item(),'.4f'))) if batch_id % log_interval_vis == 0: res_data = [] img = images.cpu().numpy() res_data.append(img[2]) ed_gt = labels.cpu().numpy() res_data.append(ed_gt[2]) # tmp_pred = tmp_preds[2,...] for i in range(len(preds_list)): tmp = preds_list[i] tmp = tmp[2] # print(tmp.shape) tmp = torch.sigmoid(tmp).unsqueeze(dim=0) tmp = tmp.cpu().detach().numpy() res_data.append(tmp) vis_imgs = visualize_result(res_data, arg=args) del tmp, res_data vis_imgs = cv2.resize(vis_imgs, (int(vis_imgs.shape[1]*0.8), int(vis_imgs.shape[0]*0.8))) img_test = 'Epoch: {0} Iter: {1}/{2} Loss: {3}' \ .format(epoch, batch_id, len(dataloader), round(tLoss.item(),4)) BLACK = (0, 0, 255) font = cv2.FONT_HERSHEY_SIMPLEX font_size = 0.9 font_color = BLACK font_thickness = 2 x, y = 30, 30 vis_imgs = cv2.putText(vis_imgs, img_test, (x, y), font, font_size, font_color, font_thickness, cv2.LINE_AA) # tmp_vis_name = str(batch_id)+'-results.png' # cv2.imwrite(os.path.join(imgs_res_folder, tmp_vis_name), vis_imgs) cv2.imwrite(os.path.join(imgs_res_folder, 'results.png'), vis_imgs) loss_avg = np.array(loss_avg).mean() return loss_avg def validate_one_epoch(epoch, dataloader, model, device, output_dir, arg=None,test_resize=False): # XXX This is not really validation, but testing # Put model in eval mode model.eval() with torch.no_grad(): for _, sample_batched in enumerate(dataloader): images = sample_batched['images'].to(device) # labels = sample_batched['labels'].to(device) file_names = sample_batched['file_names'] image_shape = sample_batched['image_shape'] preds = model(images,single_test=test_resize) # print('pred shape', preds[0].shape) save_image_batch_to_disk(preds[-1], output_dir, file_names,img_shape=image_shape, arg=arg) def test(checkpoint_path, dataloader, model, device, output_dir, args,resize_input=False): if not os.path.isfile(checkpoint_path): raise FileNotFoundError( f"Checkpoint filte note found: {checkpoint_path}") print(f"Restoring weights from: {checkpoint_path}") model.load_state_dict(torch.load(checkpoint_path, map_location=device)) model.eval() # just for the new dataset # os.makedirs(os.path.join(output_dir,"healthy"), exist_ok=True) # os.makedirs(os.path.join(output_dir,"infected"), exist_ok=True) with torch.no_grad(): total_duration = [] for batch_id, sample_batched in enumerate(dataloader): images = sample_batched['images'].to(device) # if not args.test_data == "CLASSIC": labels = sample_batched['labels'].to(device) file_names = sample_batched['file_names'] image_shape = sample_batched['image_shape'] print(f"{file_names}: {images.shape}") end = time.perf_counter() if device.type == 'cuda': torch.cuda.synchronize() preds = model(images, single_test=resize_input) if device.type == 'cuda': torch.cuda.synchronize() tmp_duration = time.perf_counter() - end total_duration.append(tmp_duration) save_image_batch_to_disk(preds, output_dir, # output_dir file_names, image_shape, arg=args) torch.cuda.empty_cache() total_duration = np.sum(np.array(total_duration)) print("******** Testing finished in", args.test_data, "dataset. *****") print("FPS: %f.4" % (len(dataloader)/total_duration)) # print("Time spend in the Dataset: %f.4" % total_duration.sum(), "seconds") def testPich(checkpoint_path, dataloader, model, device, output_dir, args, resize_input=False): # a test model plus the interganged channels if not os.path.isfile(checkpoint_path): raise FileNotFoundError( f"Checkpoint filte note found: {checkpoint_path}") print(f"Restoring weights from: {checkpoint_path}") model.load_state_dict(torch.load(checkpoint_path, map_location=device)) model.eval() with torch.no_grad(): total_duration = [] for batch_id, sample_batched in enumerate(dataloader): images = sample_batched['images'].to(device) if not args.test_data == "CLASSIC": labels = sample_batched['labels'].to(device) file_names = sample_batched['file_names'] image_shape = sample_batched['image_shape'] print(f"input tensor shape: {images.shape}") start_time = time.time() images2 = images[:, [1, 0, 2], :, :] #GBR # images2 = images[:, [2, 1, 0], :, :] # RGB preds = model(images,single_test=resize_input) preds2 = model(images2,single_test=resize_input) tmp_duration = time.time() - start_time total_duration.append(tmp_duration) save_image_batch_to_disk([preds,preds2], output_dir, file_names, image_shape, arg=args, is_inchannel=True) torch.cuda.empty_cache() total_duration = np.array(total_duration) print("******** Testing finished in", args.test_data, "dataset. *****") print("Average time per image: %f.4" % total_duration.mean(), "seconds") print("Time spend in the Dataset: %f.4" % total_duration.sum(), "seconds") def parse_args(is_testing=True, pl_opt_dir='output/teed'): """Parse command line arguments.""" parser = argparse.ArgumentParser(description='TEED model') parser.add_argument('--choose_test_data', type=int, default=-1, # UDED=15 help='Choose a dataset for testing: 0 - 15') # 新增的 epoch 参数 parser.add_argument('--epoch', type=int, required=True, help='Epoch number') # ----------- test -------0-- TEST_DATA = DATASET_NAMES[parser.parse_args().choose_test_data] # max 8 test_inf = dataset_info(TEST_DATA, is_linux=IS_LINUX) # Training settings TRAIN_DATA = DATASET_NAMES[0] # BIPED=0, BRIND=6, MDBD=10, BIPBRI=13 train_inf = dataset_info(TRAIN_DATA, is_linux=IS_LINUX) train_dir = train_inf['data_dir'] # Data parameters parser.add_argument('--input_dir', type=str, default=train_dir, help='the path to the directory with the input data.') parser.add_argument('--input_val_dir', type=str, default=test_inf['data_dir'], help='the path to the directory with the input data for validation.') parser.add_argument('--output_dir', type=str, default='checkpoints', help='the path to output the results.') parser.add_argument('--train_data', type=str, choices=DATASET_NAMES, default=TRAIN_DATA, help='Name of the dataset.')# TRAIN_DATA,BIPED-B3 parser.add_argument('--test_data', type=str, choices=DATASET_NAMES, default=TEST_DATA, help='Name of the dataset.') parser.add_argument('--test_list', type=str, default=test_inf['test_list'], help='Dataset sample indices list.') parser.add_argument('--train_list', type=str, default=train_inf['train_list'], help='Dataset sample indices list.') parser.add_argument('--is_testing',type=bool, default=is_testing, help='Script in testing mode.') parser.add_argument('--predict_all', type=bool, default=False, help='True: Generate all TEED outputs in all_edges ') parser.add_argument('--up_scale', type=bool, default=False, # for Upsale test set in 30% help='True: up scale x1.5 test image') # Just for test parser.add_argument('--resume', type=bool, default=False, help='use previous trained data') # Just for test parser.add_argument('--checkpoint_data', type=str, default='5/5_model.pth',# 37 for biped 60 MDBD help='Checkpoint path.') parser.add_argument('--test_img_width', type=int, default=test_inf['img_width'], help='Image width for testing.') parser.add_argument('--test_img_height', type=int, default=test_inf['img_height'], help='Image height for testing.') parser.add_argument('--res_dir', type=str, default='result', help='Result directory') parser.add_argument('--use_gpu',type=int, default=0, help='use GPU') parser.add_argument('--log_interval_vis', type=int, default=200,# 100 help='Interval to visualize predictions. 200') parser.add_argument('--show_log', type=int, default=20, help='display logs') parser.add_argument('--epochs', type=int, default=8, metavar='N', help='Number of training epochs (default: 25).') parser.add_argument('--lr', default=8e-4, type=float, help='Initial learning rate. =1e-3') # 1e-3 parser.add_argument('--lrs', default=[8e-5], type=float, help='LR for epochs') # [7e-5] parser.add_argument('--wd', type=float, default=2e-4, metavar='WD', help='weight decay (Good 5e-4/1e-4 )') # good 12e-5 parser.add_argument('--adjust_lr', default=[4], type=int, help='Learning rate step size.') # [4] [6,9,19] parser.add_argument('--version_notes', default='TEED BIPED+BRIND-trainingdataLoader BRIND light AF -USNet--noBN xav init normal bdcnLoss2+cats2loss +DoubleFusion-3AF, AF sum', type=str, help='version notes') parser.add_argument('--batch_size', type=int, default=8, metavar='B', help='the mini-batch size (default: 8)') parser.add_argument('--workers', default=8, type=int, help='The number of workers for the dataloaders.') parser.add_argument('--tensorboard',type=bool, default=True, help='Use Tensorboard for logging.'), parser.add_argument('--img_width', type=int, default=300, help='Image width for training.') # BIPED 352/300 BRIND 256 MDBD 480 parser.add_argument('--img_height', type=int, default=300, help='Image height for training.') # BIPED 352/300 BSDS 352/320 parser.add_argument('--channel_swap', default=[2, 1, 0], type=int) parser.add_argument('--resume_chpt', default='result/resume/', type=str, help='resume training') parser.add_argument('--pl_opt_dir', default=pl_opt_dir, type=str, help='pl output directory') parser.add_argument('--crop_img', default=True, type=bool, help='If true crop training images, else resize images to match image width and height.') parser.add_argument('--mean_test', default=test_inf['mean'], type=float) parser.add_argument('--mean_train', default=train_inf['mean'], type=float) # [103.939,116.779,123.68,137.86] [104.00699, 116.66877, 122.67892] args = parser.parse_args() return args, train_inf def main(args, train_inf): # Tensorboard summary writer # torch.autograd.set_detect_anomaly(True) tb_writer = None training_dir = os.path.join(args.output_dir,args.train_data) os.makedirs(training_dir,exist_ok=True) checkpoint_path = './TEED/checkpoints/BIPED/5/5_model.pth' checkpoint_path = os.path.join('./TEED/checkpoints', 'BIPED', str(args.epochs), '5_model.pth') if args.tensorboard and not args.is_testing: # from tensorboardX import SummaryWriter # previous torch version from torch.utils.tensorboard import SummaryWriter # for torch 1.4 or greather tb_writer = SummaryWriter(log_dir=training_dir) # saving training settings training_notes =[args.version_notes+ ' RL= ' + str(args.lr) + ' WD= ' + str(args.wd) + ' image size = ' + str(args.img_width) + ' adjust LR=' + str(args.adjust_lr) +' LRs= ' + str(args.lrs)+' Loss Function= BDCNloss2 + CAST-loss2.py ' + str(time.asctime())+' trained on '+args.train_data] info_txt = open(os.path.join(training_dir, 'training_settings.txt'), 'w') info_txt.write(str(training_notes)) info_txt.close() print("Training details> ",training_notes) # Get computing device device = torch.device('cpu' if torch.cuda.device_count() == 0 else 'cuda') # torch.cuda.set_device(args.use_gpu) # set a desired gpu print(f"Number of GPU's available: {torch.cuda.device_count()}") print(f"Pytorch version: {torch.__version__}") # print(f'GPU: {torch.cuda.get_device_name()}') print(f'Trainimage mean: {args.mean_train}') print(f'Test image mean: {args.mean_test}') # Instantiate model and move it to the computing device model = TED().to(device) # model = nn.DataParallel(model) ini_epoch =0 if not args.is_testing: if args.resume: checkpoint_path2= os.path.join(args.output_dir, 'BIPED-54-B4',args.checkpoint_data) ini_epoch=8 model.load_state_dict(torch.load(checkpoint_path2, map_location=device)) # Training dataset loading... dataset_train = BipedDataset(args.input_dir, img_width=args.img_width, img_height=args.img_height, train_mode='train', arg=args ) dataloader_train = DataLoader(dataset_train, batch_size=args.batch_size, shuffle=True, num_workers=args.workers) # Test dataset loading... dataset_val = TestDataset(args.input_val_dir, test_data=args.test_data, img_width=args.test_img_width, img_height=args.test_img_height, test_list=args.test_list, arg=args ) dataloader_val = DataLoader(dataset_val, batch_size=1, shuffle=False, num_workers=args.workers) # Testing if_resize_img = False if args.test_data in ['BIPED', 'CID', 'MDBD'] else True if args.is_testing: # output_dir = os.path.join(args.res_dir, args.train_data+"2"+ args.test_data) output_dir = args.pl_opt_dir print(f"output_dir: {output_dir}") test(checkpoint_path, dataloader_val, model, device, output_dir, args,if_resize_img) # Count parameters: num_param = count_parameters(model) print('-------------------------------------------------------') print('TED parameters:') print(num_param) print('-------------------------------------------------------') return criterion1 = cats_loss #bdcn_loss2 criterion2 = bdcn_loss2#cats_loss#f1_accuracy2 criterion = [criterion1,criterion2] optimizer = optim.Adam(model.parameters(), lr=args.lr, weight_decay=args.wd) # Count parameters: num_param = count_parameters(model) print('-------------------------------------------------------') print('TEED parameters:') print(num_param) print('-------------------------------------------------------') # Main training loop seed=1021 adjust_lr = args.adjust_lr k=0 set_lr = args.lrs#[25e-4, 5e-6] for epoch in range(ini_epoch,args.epochs): if epoch%5==0: # before 7 seed = seed+1000 np.random.seed(seed) torch.manual_seed(seed) torch.cuda.manual_seed(seed) print("------ Random seed applied-------------") # adjust learning rate if adjust_lr is not None: if epoch in adjust_lr: lr2 = set_lr[k] for param_group in optimizer.param_groups: param_group['lr'] = lr2 k+=1 # Create output directories output_dir_epoch = os.path.join(args.output_dir,args.train_data, str(epoch)) img_test_dir = os.path.join(output_dir_epoch, args.test_data + '_res') os.makedirs(output_dir_epoch,exist_ok=True) os.makedirs(img_test_dir,exist_ok=True) print("**************** Validating the training from the scratch **********") # validate_one_epoch(epoch, # dataloader_val, # model, # device, # img_test_dir, # arg=args,test_resize=if_resize_img) avg_loss =train_one_epoch(epoch,dataloader_train, model, criterion, optimizer, device, args.log_interval_vis, tb_writer=tb_writer, args=args) validate_one_epoch(epoch, dataloader_val, model, device, img_test_dir, arg=args, test_resize=if_resize_img) # Save model after end of every epoch torch.save(model.module.state_dict() if hasattr(model, "module") else model.state_dict(), os.path.join(output_dir_epoch, '{0}_model.pth'.format(epoch))) if tb_writer is not None: tb_writer.add_scalar('loss', avg_loss, epoch+1) print('Last learning rate> ', optimizer.param_groups[0]['lr']) num_param = count_parameters(model) print('-------------------------------------------------------') print('TEED parameters:') print(num_param) print('-------------------------------------------------------') if __name__ == '__main__': # os.system(" ".join(command)) is_testing =True # True to use TEED for testing args, train_info = parse_args(is_testing=is_testing) main(args, train_info)