import os
import numpy as np
from tqdm import tqdm

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.distributed as dist
import torch.multiprocessing as mp
import torch.optim as optim
import torch.utils.data.distributed

import sys
sys.path.append('../../')
import utils.utils as utils
import utils.visualize as vis_utils

logger = utils.setup_custom_logger('root')
import logging
logging.getLogger('PIL').setLevel(logging.INFO)

import projects.dsine.config as config
from projects.baseline_normal.dataloader import *
from projects.dsine.losses import ComputeLoss


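# key of the validation metric used to select the best checkpoint (lower is better)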
BEST_KEY = 'mean'


def train(model, args, device):
    if device is None:
        device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')

    should_write = ((not args.distributed) or args.rank == 0)
    if should_write:
        logger.info('Number of model parameters: %s' % sum([p.data.nelement() for p in model.parameters()]))

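    # dataloaders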
    train_loader = TrainLoader(args, epoch=0).data
    val_loader = ValLoader(args).data

    loss_fn = ComputeLoss(args)

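    # optimizer; with args.diff_lr, the get_1x_lr_params() group uses lr/10 and the get_10x_lr_params() group uses lr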
    if not args.diff_lr:
        logger.info("Using same LR")
        params = model.parameters()
    else:
        logger.info("Using diff LR")
        m = model.module if args.multigpu else model
        params = [{"params": m.get_1x_lr_params(), "lr": args.lr / 10},
                  {"params": m.get_10x_lr_params(), "lr": args.lr}]
    optimizer = optim.AdamW(params, weight_decay=args.weight_decay, lr=args.lr)

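    # one-cycle LR schedule; scheduler.step() is called once per optimizer update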
    scheduler = optim.lr_scheduler.OneCycleLR(optimizer=optimizer,
                                              max_lr=args.lr,
                                              epochs=args.num_epochs,
                                              steps_per_epoch=len(train_loader) // args.accumulate_grad_batches,
                                              div_factor=25.0,
                                              final_div_factor=10000.0)

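    # cudnn autotuning and mixed-precision gradient scaling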
    torch.backends.cudnn.enabled = True
    torch.backends.cudnn.benchmark = True
    scaler = torch.cuda.amp.GradScaler()

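    # lowest validation error (metrics[BEST_KEY]) observed so far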
    best_acc = 1e4

    total_iter = 0
    model.train()
    for epoch in range(args.num_epochs):
        # the train loader is re-created at the start of every epoch
        train_loader = TrainLoader(args, epoch=epoch).data

        # show a progress bar only on the main process
        if args.rank == 0:
            data_loader = tqdm(train_loader, desc=f"Epoch: {epoch + 1}/{args.num_epochs}. Loop: Train")
        else:
            data_loader = train_loader

        for batch_idx, data_dict in enumerate(data_loader):
            total_iter += args.batch_size_orig

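            # move inputs and ground truth to the device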
            img = data_dict['img'].to(device)
            gt_norm = data_dict['normal'].to(device)
            gt_norm_mask = data_dict['normal_mask'].to(device)
            intrins = data_dict['intrins'].to(device)

            pred_list = model(img, intrins=intrins, mode='train')
            loss = loss_fn(pred_list, gt_norm, gt_norm_mask)

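            # backward pass with gradient accumulation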
            loss = loss / args.accumulate_grad_batches
            scaler.scale(loss).backward()

            if (batch_idx + 1) % args.accumulate_grad_batches == 0:
                scaler.unscale_(optimizer)
                nn.utils.clip_grad_norm_(model.parameters(), args.grad_clip)
                scaler.step(optimizer)
                scaler.update()
                optimizer.zero_grad()
                scheduler.step()

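            # log the (accumulation-scaled) loss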
            if should_write:
                loss_ = float(loss.data.cpu().numpy())
                args.writer.add_scalar('loss', loss_, global_step=total_iter)
                data_loader.set_description(f"Epoch: {epoch + 1}/{args.num_epochs}. Loop: Train. Loss: {'%.5f' % loss_}")
                data_loader.refresh()

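            # periodically visualize the predictions in tensorboard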
            if should_write and ((total_iter % args.visualize_every) < args.batch_size_orig):
                for pred_idx, norm_out in enumerate(pred_list):
                    vis_ = vis_utils.visualize_normal_tb(args, img, norm_out, gt_norm, gt_norm_mask)
                    args.writer.add_image('train vis (pred: %s)' % pred_idx, vis_, global_step=total_iter, dataformats='HWC')

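            # periodically validate and checkpoint the model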
            if should_write and ((total_iter % args.validate_every) < args.batch_size_orig):
                if args.save_all_models:
                    utils.save_model(model, os.path.join(args.output_dir, 'models', 'iter_%010d.pt' % total_iter), total_iter)
                else:
                    model.eval()
                    metrics = validate(model, args, val_loader, device, total_iter)
                    if metrics[BEST_KEY] <= best_acc:
                        utils.save_model(model, os.path.join(args.output_dir, 'models', 'best.pt'), total_iter)
                        best_acc = metrics[BEST_KEY]
                        print('best acc: %s' % best_acc)
                    model.train()

        if should_write:
            if args.save_all_models:
                utils.save_model(model, os.path.join(args.output_dir, 'models', 'iter_%010d.pt' % total_iter), total_iter)
            else:
                model.eval()
                metrics = validate(model, args, val_loader, device, total_iter)
                if metrics[BEST_KEY] <= best_acc:
                    utils.save_model(model, os.path.join(args.output_dir, 'models', 'best.pt'), total_iter)
                    best_acc = metrics[BEST_KEY]
                    print('best acc: %s' % best_acc)
                if epoch + 1 == args.num_epochs:
                    utils.save_model(model, os.path.join(args.output_dir, 'models', 'last.pt'), total_iter)
                model.train()

        del train_loader
    return model


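# evaluate on the validation set and return the aggregated surface-normal metrics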
def validate(model, args, val_loader, device, total_iter):
    with torch.no_grad():
        total_metrics = utils.RunningAverageDict()
        for data_dict in tqdm(val_loader, desc="Loop: Validation"):
            img = data_dict['img'].to(device)
            gt_norm = data_dict['normal'].to(device)
            gt_norm_mask = data_dict['normal_mask'].to(device)
            intrins = data_dict['intrins'].to(device)

            # keep only the final prediction; the first three channels are the predicted normal
            norm_out = model(img, intrins=intrins, mode='test')[-1]
            pred_norm = norm_out[:, :3, :, :]

            # per-pixel prediction error over the valid mask, accumulated as a running average weighted by pixel count
            pred_error = utils.compute_normal_error(pred_norm, gt_norm)
            metrics_, num_pixels = utils.compute_normal_metrics2(pred_error[gt_norm_mask])
            total_metrics.update(metrics_, num_pixels)

        metrics = total_metrics.get_value()
        metrics['rmse'] = np.sqrt(metrics['mse'])

        # log the validation metrics to tensorboard
        for k, v in metrics.items():
            args.writer.add_scalar(k, v, global_step=total_iter)
        return metrics


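# per-process entry point: build the model and set up (distributed) multi-GPU training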
def main_worker(gpu, ngpus_per_node, args):
    args.gpu = gpu

    # select the DSINE architecture variant specified in the config
    if args.NNET_architecture == 'v00':
        from models.dsine.v00 import DSINE_v00 as DSINE
    elif args.NNET_architecture == 'v01':
        from models.dsine.v01 import DSINE_v01 as DSINE
    elif args.NNET_architecture == 'v02':
        from models.dsine.v02 import DSINE_v02 as DSINE
    elif args.NNET_architecture == 'v02_kappa':
        from models.dsine.v02_kappa import DSINE_v02_kappa as DSINE
    else:
        raise Exception('invalid model architecture')
    model = DSINE(args)

    if args.gpu is not None:
        torch.cuda.set_device(args.gpu)
        model = model.cuda(args.gpu)

    args.multigpu = False
    if args.distributed:
        # one process per GPU: adjust the rank, per-GPU batch size and workers accordingly
        args.multigpu = True
        args.rank = args.rank * ngpus_per_node + gpu
        dist.init_process_group(backend=args.dist_backend, init_method=args.dist_url,
                                world_size=args.world_size, rank=args.rank)
        args.batch_size = int(args.batch_size / ngpus_per_node)
        args.workers = int((args.num_workers + ngpus_per_node - 1) / ngpus_per_node)

        logger.info('GPU: %s / RANK: %s / Batch size: %s / Num workers: %s' %
                    (args.gpu, args.rank, args.batch_size, args.workers))

        torch.cuda.set_device(args.gpu)
        model = nn.SyncBatchNorm.convert_sync_batchnorm(model)
        model = model.cuda(args.gpu)
        model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[args.gpu], output_device=args.gpu,
                                                          find_unused_parameters=True)

    elif args.gpu is None:
        # no specific GPU and not distributed: use DataParallel over all visible GPUs
        args.multigpu = True
        model = model.cuda()
        model = torch.nn.DataParallel(model)

    train(model, args, device=args.gpu)


if __name__ == '__main__':
    args = config.get_args()

    # restrict the visible GPUs (args.gpus is joined element-wise with commas, e.g. '01' -> '0,1')
    os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
    if args.gpus != '-1':
        os.environ["CUDA_VISIBLE_DEVICES"] = ",".join(list(args.gpus))

    args.world_size = 1
    args.rank = 0
    nodes = ["127.0.0.1"]

    if args.distributed:
        # single-node DDP: rendezvous on a random local port
        mp.set_start_method('forkserver')
        port = np.random.randint(15000, 15025)
        args.dist_url = 'tcp://{}:{}'.format(nodes[0], port)
        args.dist_backend = 'nccl'
        args.gpu = None

    ngpus_per_node = torch.cuda.device_count()
    args.num_workers = args.workers
    args.ngpus_per_node = ngpus_per_node

    # keep the global (pre-division) batch size for iteration counting
    args.batch_size_orig = args.batch_size

    if args.distributed:
        args.world_size = ngpus_per_node * args.world_size
        mp.spawn(main_worker, nprocs=ngpus_per_node, args=(ngpus_per_node, args))
    else:
        if ngpus_per_node == 1:
            args.gpu = 0
        main_worker(args.gpu, ngpus_per_node, args)