Spaces:
Runtime error
Runtime error
| """ | |
| Hello, welcome on board, | |
| """ | |
| from __future__ import print_function | |
| import argparse | |
| import os | |
| import time, platform | |
| import cv2 | |
| import numpy as np | |
| os.environ['CUDA_LAUNCH_BLOCKING']="0" | |
| import torch | |
| import torch.optim as optim | |
| from torch.utils.data import DataLoader | |
| from thop import profile | |
| from TEED.dataset import DATASET_NAMES, BipedDataset, TestDataset, dataset_info | |
| from TEED.loss2 import * | |
| from TEED.ted import TED # TEED architecture | |
| from TEED.utils.img_processing import (image_normalization, save_image_batch_to_disk, | |
| visualize_result, count_parameters) | |
# Set to False to train the TEED model instead of testing.
is_testing = True
# True when running on a Linux host (selects dataset paths in dataset_info).
IS_LINUX = platform.system() == "Linux"
def train_one_epoch(epoch, dataloader, model, criterions, optimizer, device,
                    log_interval_vis, tb_writer, args=None):
    """Run one TEED training epoch and return the mean batch loss.

    Args:
        epoch: 0-based epoch index (used for logging/visualization).
        dataloader: yields dicts with 'images' (BxCxHxW) and 'labels' (BxHxW).
        model: TED network; forward returns a list of edge-map predictions.
        criterions: (cats_loss, bdcn_loss2) pair — see main().
        optimizer: torch optimizer stepping the model parameters.
        device: torch.device the batch tensors are moved to.
        log_interval_vis: dump a visualization image every N batches.
        tb_writer: optional tensorboard SummaryWriter (may be None).
        args: parsed CLI namespace (reads output_dir and show_log).

    Returns:
        float: mean of the per-batch total losses over the epoch.

    Raises:
        TypeError: if `criterions` is not a two-element list/tuple.
    """
    imgs_res_folder = os.path.join(args.output_dir, 'current_res')
    os.makedirs(imgs_res_folder, exist_ok=True)
    show_log = args.show_log

    if isinstance(criterions, (list, tuple)):
        criterion1, criterion2 = criterions
    else:
        # BUG FIX: the original only assigned criterion1 here, leaving
        # criterion2 undefined and crashing with a NameError at the loss
        # computation below. Fail fast with a clear message instead.
        raise TypeError("criterions must be a (cats_loss, bdcn_loss2) pair")

    # Put model in training mode.
    model.train()

    l_weight0 = [1.1, 0.7, 1.1, 1.3]  # per-scale weights for bdcn_loss2 (B4)
    l_weight = [[0.05, 2.], [0.05, 2.], [0.01, 1.],
                [0.01, 3.]]  # for cats loss [0.01, 4.]
    loss_avg = []
    for batch_id, sample_batched in enumerate(dataloader):
        images = sample_batched['images'].to(device)  # BxCxHxW
        labels = sample_batched['labels'].to(device)  # BxHxW
        preds_list = model(images)
        # bdcn_loss2 on the intermediate outputs, cats_loss on the fused one.
        loss1 = sum([criterion2(preds, labels, l_w)
                     for preds, l_w in zip(preds_list[:-1], l_weight0)])
        loss2 = criterion1(preds_list[-1], labels, l_weight[-1], device)
        tLoss = loss2 + loss1
        optimizer.zero_grad()
        tLoss.backward()
        optimizer.step()
        loss_avg.append(tLoss.item())

        # One-off tensorboard point early in the first epoch.
        if epoch == 0 and (batch_id == 100 and tb_writer is not None):
            tmp_loss = np.array(loss_avg).mean()
            tb_writer.add_scalar('loss', tmp_loss, epoch)
        if batch_id % show_log == 0:
            print(time.ctime(), 'Epoch: {0} Sample {1}/{2} Loss: {3}'
                  .format(epoch, batch_id, len(dataloader),
                          format(tLoss.item(), '.4f')))
        if batch_id % log_interval_vis == 0:
            # Dump a side-by-side visualization of sample #2 of the batch:
            # input, ground truth, then every prediction scale.
            # NOTE(review): indexing [2] assumes batch size > 2 — confirm.
            res_data = []
            img = images.cpu().numpy()
            res_data.append(img[2])
            ed_gt = labels.cpu().numpy()
            res_data.append(ed_gt[2])
            for i in range(len(preds_list)):
                tmp = preds_list[i]
                tmp = tmp[2]
                tmp = torch.sigmoid(tmp).unsqueeze(dim=0)
                tmp = tmp.cpu().detach().numpy()
                res_data.append(tmp)
            vis_imgs = visualize_result(res_data, arg=args)
            del tmp, res_data
            vis_imgs = cv2.resize(
                vis_imgs,
                (int(vis_imgs.shape[1] * 0.8), int(vis_imgs.shape[0] * 0.8)))
            img_test = 'Epoch: {0} Iter: {1}/{2} Loss: {3}' \
                .format(epoch, batch_id, len(dataloader),
                        round(tLoss.item(), 4))
            # FIX: this tuple is red in BGR; it was mislabelled BLACK.
            FONT_COLOR = (0, 0, 255)
            font = cv2.FONT_HERSHEY_SIMPLEX
            font_size = 0.9
            font_thickness = 2
            x, y = 30, 30
            vis_imgs = cv2.putText(vis_imgs,
                                   img_test,
                                   (x, y),
                                   font, font_size, FONT_COLOR,
                                   font_thickness, cv2.LINE_AA)
            cv2.imwrite(os.path.join(imgs_res_folder, 'results.png'), vis_imgs)
    loss_avg = np.array(loss_avg).mean()
    return loss_avg
def validate_one_epoch(epoch, dataloader, model, device, output_dir, arg=None,test_resize=False):
    """Run the model over the validation set and write predictions to disk.

    XXX Despite the name this is really a test pass: no loss or metric is
    computed; only the fused (last) prediction of each batch is saved.
    """
    # Put model in eval mode.
    model.eval()
    with torch.no_grad():
        for batch in dataloader:
            inputs = batch['images'].to(device)
            names = batch['file_names']
            shapes = batch['image_shape']
            outputs = model(inputs, single_test=test_resize)
            # Persist only the final (fused) prediction from the list.
            save_image_batch_to_disk(outputs[-1],
                                     output_dir,
                                     names,
                                     img_shape=shapes,
                                     arg=arg)
def test(checkpoint_path, dataloader, model, device, output_dir, args, resize_input=False):
    """Restore weights from `checkpoint_path` and run inference over `dataloader`.

    Predictions are written to `output_dir` via save_image_batch_to_disk;
    per-batch inference time is accumulated to report an FPS figure.

    Raises:
        FileNotFoundError: if `checkpoint_path` does not exist.
    """
    if not os.path.isfile(checkpoint_path):
        # FIX: corrected garbled message ("Checkpoint filte note found").
        raise FileNotFoundError(
            f"Checkpoint file not found: {checkpoint_path}")
    print(f"Restoring weights from: {checkpoint_path}")
    model.load_state_dict(torch.load(checkpoint_path,
                                     map_location=device))
    model.eval()

    with torch.no_grad():
        total_duration = []
        for batch_id, sample_batched in enumerate(dataloader):
            images = sample_batched['images'].to(device)
            # NOTE(review): labels are loaded but unused in this pass;
            # CLASSIC-style datasets may not provide them — confirm.
            labels = sample_batched['labels'].to(device)
            file_names = sample_batched['file_names']
            image_shape = sample_batched['image_shape']
            print(f"{file_names}: {images.shape}")
            # FIX: synchronize BEFORE starting the timer, so pending GPU
            # work from the previous iteration is not billed to this one.
            if device.type == 'cuda':
                torch.cuda.synchronize()
            end = time.perf_counter()
            preds = model(images, single_test=resize_input)
            if device.type == 'cuda':
                torch.cuda.synchronize()
            tmp_duration = time.perf_counter() - end
            total_duration.append(tmp_duration)
            save_image_batch_to_disk(preds,
                                     output_dir,
                                     file_names,
                                     image_shape,
                                     arg=args)
            torch.cuda.empty_cache()
    total_duration = np.sum(np.array(total_duration))
    print("******** Testing finished in", args.test_data, "dataset. *****")
    # FIX: "%f.4" printed the full float followed by a literal ".4";
    # "%.4f" is the intended 4-decimal format.
    print("FPS: %.4f" % (len(dataloader) / total_duration))
def testPich(checkpoint_path, dataloader, model, device, output_dir, args, resize_input=False):
    """Like test(), but additionally runs each image with its first two
    channels interchanged (GBR) and saves both prediction sets together.

    Raises:
        FileNotFoundError: if `checkpoint_path` does not exist.
    """
    if not os.path.isfile(checkpoint_path):
        # FIX: corrected garbled message ("Checkpoint filte note found").
        raise FileNotFoundError(
            f"Checkpoint file not found: {checkpoint_path}")
    print(f"Restoring weights from: {checkpoint_path}")
    model.load_state_dict(torch.load(checkpoint_path,
                                     map_location=device))
    model.eval()
    with torch.no_grad():
        total_duration = []
        for batch_id, sample_batched in enumerate(dataloader):
            images = sample_batched['images'].to(device)
            if not args.test_data == "CLASSIC":
                labels = sample_batched['labels'].to(device)
            file_names = sample_batched['file_names']
            image_shape = sample_batched['image_shape']
            print(f"input tensor shape: {images.shape}")
            start_time = time.time()
            images2 = images[:, [1, 0, 2], :, :]  # swap first two channels: GBR
            # images2 = images[:, [2, 1, 0], :, :] # RGB
            preds = model(images, single_test=resize_input)
            preds2 = model(images2, single_test=resize_input)
            tmp_duration = time.time() - start_time
            total_duration.append(tmp_duration)
            save_image_batch_to_disk([preds, preds2],
                                     output_dir,
                                     file_names,
                                     image_shape,
                                     arg=args, is_inchannel=True)
            torch.cuda.empty_cache()
    total_duration = np.array(total_duration)
    print("******** Testing finished in", args.test_data, "dataset. *****")
    # FIX: "%f.4" -> "%.4f" (4-decimal float format); "spend" -> "spent".
    print("Average time per image: %.4f" % total_duration.mean(), "seconds")
    print("Time spent in the Dataset: %.4f" % total_duration.sum(), "seconds")
def parse_args(is_testing=True, pl_opt_dir='output/teed'):
    """Parse command line arguments.

    Args:
        is_testing: default value for the --is_testing flag.
        pl_opt_dir: default value for the --pl_opt_dir output directory.

    Returns:
        (args, train_inf): the parsed argparse namespace and the
        training-dataset info dict from dataset_info().
    """
    parser = argparse.ArgumentParser(description='TEED model')
    parser.add_argument('--choose_test_data',
                        type=int,
                        default=-1,  # UDED=15
                        help='Choose a dataset for testing: 0 - 15')
    # Newly added epoch parameter.
    parser.add_argument('--epoch', type=int, required=True, help='Epoch number')
    # ----------- test -------0--
    # FIX: use parse_known_args() for this early, partial parse — a full
    # parse_args() here would abort on any flag defined *after* this point
    # (e.g. --batch_size), because the parser does not know it yet.
    TEST_DATA = DATASET_NAMES[parser.parse_known_args()[0].choose_test_data]  # max 8
    test_inf = dataset_info(TEST_DATA, is_linux=IS_LINUX)
    # Training settings
    TRAIN_DATA = DATASET_NAMES[0]  # BIPED=0, BRIND=6, MDBD=10, BIPBRI=13
    train_inf = dataset_info(TRAIN_DATA, is_linux=IS_LINUX)
    train_dir = train_inf['data_dir']

    # Data parameters
    parser.add_argument('--input_dir',
                        type=str,
                        default=train_dir,
                        help='the path to the directory with the input data.')
    parser.add_argument('--input_val_dir',
                        type=str,
                        default=test_inf['data_dir'],
                        help='the path to the directory with the input data for validation.')
    parser.add_argument('--output_dir',
                        type=str,
                        default='checkpoints',
                        help='the path to output the results.')
    parser.add_argument('--train_data',
                        type=str,
                        choices=DATASET_NAMES,
                        default=TRAIN_DATA,
                        help='Name of the dataset.')  # TRAIN_DATA,BIPED-B3
    parser.add_argument('--test_data',
                        type=str,
                        choices=DATASET_NAMES,
                        default=TEST_DATA,
                        help='Name of the dataset.')
    parser.add_argument('--test_list',
                        type=str,
                        default=test_inf['test_list'],
                        help='Dataset sample indices list.')
    parser.add_argument('--train_list',
                        type=str,
                        default=train_inf['train_list'],
                        help='Dataset sample indices list.')
    # NOTE(review): type=bool parses ANY non-empty string (even "False") as
    # True; only the defaults behave as expected. Consider a str2bool helper
    # or argparse.BooleanOptionalAction — left unchanged to preserve the CLI.
    parser.add_argument('--is_testing', type=bool,
                        default=is_testing,
                        help='Script in testing mode.')
    parser.add_argument('--predict_all',
                        type=bool,
                        default=False,
                        help='True: Generate all TEED outputs in all_edges ')
    parser.add_argument('--up_scale',
                        type=bool,
                        default=False,  # for Upscale test set in 30%
                        help='True: up scale x1.5 test image')  # Just for test
    parser.add_argument('--resume',
                        type=bool,
                        default=False,
                        help='use previous trained data')  # Just for test
    parser.add_argument('--checkpoint_data',
                        type=str,
                        default='5/5_model.pth',  # 37 for biped 60 MDBD
                        help='Checkpoint path.')
    parser.add_argument('--test_img_width',
                        type=int,
                        default=test_inf['img_width'],
                        help='Image width for testing.')
    parser.add_argument('--test_img_height',
                        type=int,
                        default=test_inf['img_height'],
                        help='Image height for testing.')
    parser.add_argument('--res_dir',
                        type=str,
                        default='result',
                        help='Result directory')
    parser.add_argument('--use_gpu', type=int,
                        default=0, help='use GPU')
    parser.add_argument('--log_interval_vis',
                        type=int,
                        default=200,  # 100
                        help='Interval to visualize predictions. 200')
    parser.add_argument('--show_log', type=int, default=20, help='display logs')
    parser.add_argument('--epochs',
                        type=int,
                        default=8,
                        metavar='N',
                        help='Number of training epochs (default: 25).')
    parser.add_argument('--lr', default=8e-4, type=float,
                        help='Initial learning rate. =1e-3')  # 1e-3
    parser.add_argument('--lrs', default=[8e-5], type=float,
                        help='LR for epochs')  # [7e-5]
    parser.add_argument('--wd', type=float, default=2e-4, metavar='WD',
                        help='weight decay (Good 5e-4/1e-4 )')  # good 12e-5
    parser.add_argument('--adjust_lr', default=[4], type=int,
                        help='Learning rate step size.')  # [4] [6,9,19]
    parser.add_argument('--version_notes',
                        default='TEED BIPED+BRIND-trainingdataLoader BRIND light AF -USNet--noBN xav init normal bdcnLoss2+cats2loss +DoubleFusion-3AF, AF sum',
                        type=str,
                        help='version notes')
    parser.add_argument('--batch_size',
                        type=int,
                        default=8,
                        metavar='B',
                        help='the mini-batch size (default: 8)')
    parser.add_argument('--workers',
                        default=8,
                        type=int,
                        help='The number of workers for the dataloaders.')
    parser.add_argument('--tensorboard', type=bool,
                        default=True,
                        help='Use Tensorboard for logging.')
    parser.add_argument('--img_width',
                        type=int,
                        default=300,
                        help='Image width for training.')  # BIPED 352/300 BRIND 256 MDBD 480
    parser.add_argument('--img_height',
                        type=int,
                        default=300,
                        help='Image height for training.')  # BIPED 352/300 BSDS 352/320
    parser.add_argument('--channel_swap',
                        default=[2, 1, 0],
                        type=int)
    parser.add_argument('--resume_chpt',
                        default='result/resume/',
                        type=str,
                        help='resume training')
    parser.add_argument('--pl_opt_dir',
                        default=pl_opt_dir,
                        type=str,
                        help='pl output directory')
    parser.add_argument('--crop_img',
                        default=True,
                        type=bool,
                        help='If true crop training images, else resize images to match image width and height.')
    parser.add_argument('--mean_test',
                        default=test_inf['mean'],
                        type=float)
    parser.add_argument('--mean_train',
                        default=train_inf['mean'],
                        type=float)  # [103.939,116.779,123.68,137.86] [104.00699, 116.66877, 122.67892]
    args = parser.parse_args()
    return args, train_inf
def main(args, train_inf):
    """Run TEED: testing when args.is_testing, otherwise the full training loop.

    Args:
        args: parsed CLI namespace from parse_args().
        train_inf: training-dataset info dict (currently unused here).
    """
    # Tensorboard summary writer
    # torch.autograd.set_detect_anomaly(True)
    tb_writer = None
    training_dir = os.path.join(args.output_dir, args.train_data)
    os.makedirs(training_dir, exist_ok=True)
    # NOTE(review): this first assignment is dead code — immediately
    # overwritten by the os.path.join below.
    checkpoint_path = './TEED/checkpoints/BIPED/5/5_model.pth'
    # NOTE(review): this builds the checkpoint folder from args.epochs (the
    # training epoch count, default 8); the required --epoch argument looks
    # like the intended value here — confirm with the caller.
    checkpoint_path = os.path.join('./TEED/checkpoints', 'BIPED', str(args.epochs), '5_model.pth')
    if args.tensorboard and not args.is_testing:
        # from tensorboardX import SummaryWriter  # previous torch version
        from torch.utils.tensorboard import SummaryWriter  # for torch 1.4 or greater
        tb_writer = SummaryWriter(log_dir=training_dir)
        # Save the training settings next to the logs for reproducibility.
        training_notes = [args.version_notes + ' RL= ' + str(args.lr) + ' WD= '
                          + str(args.wd) + ' image size = ' + str(args.img_width)
                          + ' adjust LR=' + str(args.adjust_lr) + ' LRs= '
                          + str(args.lrs) + ' Loss Function= BDCNloss2 + CAST-loss2.py '
                          + str(time.asctime()) + ' trained on ' + args.train_data]
        info_txt = open(os.path.join(training_dir, 'training_settings.txt'), 'w')
        info_txt.write(str(training_notes))
        info_txt.close()
        print("Training details> ", training_notes)

    # Get computing device: CUDA when any GPU is visible, else CPU.
    device = torch.device('cpu' if torch.cuda.device_count() == 0
                          else 'cuda')
    # torch.cuda.set_device(args.use_gpu)  # set a desired gpu
    print(f"Number of GPU's available: {torch.cuda.device_count()}")
    print(f"Pytorch version: {torch.__version__}")
    # print(f'GPU: {torch.cuda.get_device_name()}')
    print(f'Trainimage mean: {args.mean_train}')
    print(f'Test image mean: {args.mean_test}')

    # Instantiate model and move it to the computing device.
    model = TED().to(device)
    # model = nn.DataParallel(model)
    ini_epoch = 0
    if not args.is_testing:
        if args.resume:
            # Resume from a fixed prior run; epoch counter restarts at 8.
            checkpoint_path2 = os.path.join(args.output_dir, 'BIPED-54-B4', args.checkpoint_data)
            ini_epoch = 8
            model.load_state_dict(torch.load(checkpoint_path2,
                                             map_location=device))
        # Training dataset loading...
        dataset_train = BipedDataset(args.input_dir,
                                     img_width=args.img_width,
                                     img_height=args.img_height,
                                     train_mode='train',
                                     arg=args
                                     )
        dataloader_train = DataLoader(dataset_train,
                                      batch_size=args.batch_size,
                                      shuffle=True,
                                      num_workers=args.workers)
    # Test dataset loading... (needed in both modes)
    dataset_val = TestDataset(args.input_val_dir,
                              test_data=args.test_data,
                              img_width=args.test_img_width,
                              img_height=args.test_img_height,
                              test_list=args.test_list, arg=args
                              )
    dataloader_val = DataLoader(dataset_val,
                                batch_size=1,
                                shuffle=False,
                                num_workers=args.workers)

    # Testing: datasets outside this list get resized inputs.
    if_resize_img = False if args.test_data in ['BIPED', 'CID', 'MDBD'] else True
    if args.is_testing:
        # output_dir = os.path.join(args.res_dir, args.train_data+"2"+ args.test_data)
        output_dir = args.pl_opt_dir
        print(f"output_dir: {output_dir}")
        test(checkpoint_path, dataloader_val, model, device,
             output_dir, args, if_resize_img)
        # Count parameters:
        num_param = count_parameters(model)
        print('-------------------------------------------------------')
        print('TED parameters:')
        print(num_param)
        print('-------------------------------------------------------')
        return

    # Loss functions: cats_loss for the fused output, bdcn_loss2 for the rest.
    criterion1 = cats_loss  # bdcn_loss2
    criterion2 = bdcn_loss2  # cats_loss # f1_accuracy2
    criterion = [criterion1, criterion2]
    optimizer = optim.Adam(model.parameters(),
                           lr=args.lr,
                           weight_decay=args.wd)

    # Count parameters:
    num_param = count_parameters(model)
    print('-------------------------------------------------------')
    print('TEED parameters:')
    print(num_param)
    print('-------------------------------------------------------')

    # Main training loop
    seed = 1021
    adjust_lr = args.adjust_lr
    k = 0  # index into set_lr; advanced each time an LR step fires
    set_lr = args.lrs  # [25e-4, 5e-6]
    for epoch in range(ini_epoch, args.epochs):
        # Re-seed every 5 epochs for reproducibility of the augmentation.
        if epoch % 5 == 0:  # before 7
            seed = seed + 1000
            np.random.seed(seed)
            torch.manual_seed(seed)
            torch.cuda.manual_seed(seed)
            print("------ Random seed applied-------------")
        # Adjust learning rate at the epochs listed in adjust_lr.
        if adjust_lr is not None:
            if epoch in adjust_lr:
                lr2 = set_lr[k]
                for param_group in optimizer.param_groups:
                    param_group['lr'] = lr2
                k += 1
        # Create output directories for this epoch's checkpoint and images.
        output_dir_epoch = os.path.join(args.output_dir, args.train_data, str(epoch))
        img_test_dir = os.path.join(output_dir_epoch, args.test_data + '_res')
        os.makedirs(output_dir_epoch, exist_ok=True)
        os.makedirs(img_test_dir, exist_ok=True)
        print("**************** Validating the training from the scratch **********")
        # validate_one_epoch(epoch,
        #                    dataloader_val,
        #                    model,
        #                    device,
        #                    img_test_dir,
        #                    arg=args, test_resize=if_resize_img)
        avg_loss = train_one_epoch(epoch, dataloader_train,
                                   model, criterion,
                                   optimizer,
                                   device,
                                   args.log_interval_vis,
                                   tb_writer=tb_writer,
                                   args=args)
        validate_one_epoch(epoch,
                           dataloader_val,
                           model,
                           device,
                           img_test_dir,
                           arg=args, test_resize=if_resize_img)
        # Save model after end of every epoch (unwrap DataParallel if present).
        torch.save(model.module.state_dict() if hasattr(model, "module") else model.state_dict(),
                   os.path.join(output_dir_epoch, '{0}_model.pth'.format(epoch)))
        if tb_writer is not None:
            tb_writer.add_scalar('loss',
                                 avg_loss,
                                 epoch + 1)
        print('Last learning rate> ', optimizer.param_groups[0]['lr'])

    num_param = count_parameters(model)
    print('-------------------------------------------------------')
    print('TEED parameters:')
    print(num_param)
    print('-------------------------------------------------------')
if __name__ == '__main__':
    # os.system(" ".join(command))
    # Entry point: run TEED in testing mode (set to False to train instead).
    is_testing = True
    cli_args, train_info = parse_args(is_testing=is_testing)
    main(cli_args, train_info)