Spaces:
Build error
Build error
| import math | |
| import parse | |
| import logging | |
| from utils import util | |
| from torch.utils.data.distributed import DistributedSampler | |
| from torch.nn.parallel import DistributedDataParallel as DDP | |
| from data import create_dataset, create_dataloader | |
| from models.utils.loss import * | |
| import yaml | |
| from abc import abstractmethod, ABCMeta | |
| from models.utils.flow_losses import AdversarialLoss | |
class Trainer(metaclass=ABCMeta):
    """Base class for training loops that drive a ``model`` / ``dist`` network pair.

    The constructor assembles the training context from ``opt``:
    output directories and loggers (main process only), random seeds, the
    training dataset/sampler/loader, the networks/optimizers/schedulers
    returned by :meth:`init_model`, optional resume state, and
    DistributedDataParallel wrappers when ``opt['distributed']`` is set.

    NOTE(review): ``dist`` is presumably a discriminator, given the hinge
    ``AdversarialLoss`` and the separate ``dist_optim`` — confirm.

    Subclasses are expected to override the hook methods that are no-ops
    here: ``init_model`` (mandatory — ``__init__`` unpacks its result),
    ``init_flow_model``, ``resume_training``, ``_trainEpoch``,
    ``_printLog``, ``save_checkpoint`` and ``_validate``.
    """

    def __init__(self, opt, rank):
        """Build the full training context.

        Args:
            opt: nested configuration dict. Keys read here and in the helpers
                include ``'path'``, ``'datasets'``, ``'train'``, ``'device'``,
                ``'distributed'``, ``'local_rank'``, ``'world_size'`` and
                ``'global_rank'``.
            rank: process rank; ``rank <= 0`` marks the main process, which
                alone creates output directories and loggers.
        """
        self.opt = opt
        self.rank = rank
        # make directory and set logger (main process only)
        if rank <= 0:
            self.mkdir()
            self.logger, self.tb_logger = self.setLogger()
        else:
            # Give non-main ranks defined attributes so checks such as
            # ``if self.tb_logger:`` do not raise AttributeError.
            self.logger, self.tb_logger = None, None
        self.setSeed()
        (self.dataInfo, self.valInfo, self.trainSet, self.trainSize,
         self.totalIterations, self.totalEpochs, self.trainLoader,
         self.trainSampler) = self.prepareDataset()
        (self.model, self.dist, self.optimizer, self.dist_optim,
         self.scheduler, self.dist_scheduler) = self.init_model()
        self.flow_model = self.init_flow_model()
        self.model = self.model.to(self.opt['device'])
        self.dist = self.dist.to(self.opt['device'])
        # Resume only when a generator checkpoint path is configured.
        if opt['path'].get('gen_state', None):
            self.startEpoch, self.currentStep = self.resume_training()
        else:
            self.startEpoch, self.currentStep = 0, 0
        if opt['distributed']:
            self.model = self._wrap_ddp(self.model)
            self.dist = self._wrap_ddp(self.dist)
        if self.rank <= 0:
            self.logger.info('Start training from epoch: {}, iter: {}'.format(
                self.startEpoch, self.currentStep))
        self.best_psnr = 0
        self.valid_best_psnr = 0
        self.maskedLoss = nn.L1Loss()
        self.validLoss = nn.L1Loss()
        self.adversarial_loss = AdversarialLoss(type='hinge').to(self.opt['device'])
        self.countDown = 0
        # metrics recorder
        self.total_loss = 0
        self.total_psnr = 0
        self.total_ssim = 0
        self.total_l1 = 0
        self.total_l2 = 0

    def _wrap_ddp(self, module):
        """Wrap *module* in DistributedDataParallel on this process's local device."""
        return DDP(
            module,
            device_ids=[self.opt['local_rank']],
            output_device=self.opt['local_rank'],
            find_unused_parameters=True,
        )

    def get_lr(self):
        """Return the current learning rates.

        The list holds every param group of ``self.optimizer``, followed by
        every param group of ``self.dist_optim``.
        """
        return ([group['lr'] for group in self.optimizer.param_groups]
                + [group['lr'] for group in self.dist_optim.param_groups])

    def adjust_learning_rate(self, optimizer, target_lr):
        """Set ``target_lr`` on every param group of *optimizer*.

        Note: ``self.dist_optim`` is *always* updated to ``target_lr`` as well,
        whichever optimizer is passed in.
        """
        for param_group in optimizer.param_groups:
            param_group['lr'] = target_lr
        for param_group in self.dist_optim.param_groups:
            param_group['lr'] = target_lr

    def mkdir(self):
        """Create the output directories and dump the config to ``LOG/config.yaml``.

        ``util.mkdir_and_rename`` is called on ``OUTPUT_ROOT``. When it returns
        a truthy name, the TRAINING_STATE / LOG / VAL_IMAGES entries of
        ``opt['path']`` are redirected under that name before being created.
        """
        paths = self.opt['path']
        new_name = util.mkdir_and_rename(paths['OUTPUT_ROOT'])
        if new_name:
            paths['TRAINING_STATE'] = os.path.join(new_name, 'training_state')
            paths['LOG'] = os.path.join(new_name, 'log')
            paths['VAL_IMAGES'] = os.path.join(new_name, 'val_images')
        # exist_ok avoids the check-then-create race of os.path.exists + makedirs.
        for key in ('TRAINING_STATE', 'LOG', 'VAL_IMAGES'):
            os.makedirs(paths[key], exist_ok=True)
        # save config file for output
        with open(os.path.join(paths['LOG'], 'config.yaml'), 'w') as f:
            yaml.dump(self.opt, f)

    @staticmethod
    def _torch_version(version):
        """Return ``(major, minor)`` as ints from a version string like '1.10.2+cu113'.

        Integer-tuple comparison is used because ``float(version[:3])``
        truncates two-digit minor versions ('1.10' -> 1.1).

        Raises:
            ValueError: if *version* does not start with ``<major>.<minor>``.
        """
        import re
        match = re.match(r'(\d+)\.(\d+)', version)
        if match is None:
            raise ValueError('Unrecognised PyTorch version string: {!r}'.format(version))
        return int(match.group(1)), int(match.group(2))

    def setLogger(self):
        """Configure the 'base' screen/file logger and, optionally, TensorBoard.

        Returns:
            ``(logger, tb_logger)``. ``tb_logger`` is a ``SummaryWriter`` writing
            to ``OUTPUT_ROOT/log``, or None when ``opt['use_tb_logger']`` is falsy.
        """
        util.setup_logger('base', self.opt['path']['LOG'], 'train_' + self.opt['name'],
                          level=logging.INFO, screen=True, tofile=True)
        logger = logging.getLogger('base')
        logger.info(parse.toString(self.opt))
        logger.info('OUTPUT DIR IS: {}'.format(self.opt['path']['OUTPUT_ROOT']))
        if not self.opt['use_tb_logger']:
            return logger, None
        # torch.utils.tensorboard is used from PyTorch 1.1 on; older builds fall back to tensorboardX.
        if self._torch_version(torch.__version__) >= (1, 1):
            from torch.utils.tensorboard import SummaryWriter
        else:
            logger.info('You are using PyTorch {}, Tensorboard will use [tensorboardX]'.format(
                torch.__version__))
            from tensorboardX import SummaryWriter
        # NOTE(review): writes under OUTPUT_ROOT/log, not opt['path']['LOG'],
        # which mkdir() may have redirected — confirm this is intended.
        tb_logger = SummaryWriter(os.path.join(self.opt['path']['OUTPUT_ROOT'], 'log'))
        return logger, tb_logger

    def setSeed(self):
        """Seed the RNGs via ``util.set_random_seed`` and configure cuDNN.

        A seed of 0 requests deterministic cuDNN kernels. In that case
        benchmark mode (kernel autotuning) is turned off, because it may
        choose non-deterministic algorithms. Otherwise benchmark mode is on.
        """
        seed = self.opt['train']['manual_seed']
        if self.rank <= 0:
            self.logger.info('Random seed: {}'.format(seed))
        util.set_random_seed(seed)
        deterministic = seed == 0
        torch.backends.cudnn.benchmark = not deterministic
        if deterministic:
            torch.backends.cudnn.deterministic = True

    def prepareDataset(self):
        """Build the training dataset, sampler and loader from ``opt['datasets']``.

        Every entry of ``opt['datasets']`` is updated in place with the shared
        settings copied from the top level of ``opt``. This includes the
        'dataInfo' and 'valInfo' entries. Only the phase named 'train'
        (case-insensitive) is instantiated.

        Returns:
            ``(dataInfo, valInfo, train_set, train_size, total_iterations,
            total_epochs, train_loader, train_sampler)``, where ``train_size``
            is ``ceil(len(train_set) / (batch_size * world_size))``.

        Raises:
            ValueError: if the training dataset is empty.
            AssertionError: if no 'train' phase is configured.
        """
        # Top-level options copied into every dataset config (order preserved).
        shared_keys = ('norm', 'dataMode', 'num_frames', 'sample', 'flow2rgb',
                       'flow_direction', 'max_val', 'input_resolution')
        dataInfo = self.opt['datasets']['dataInfo']
        valInfo = self.opt['datasets']['valInfo']
        valInfo['norm'] = self.opt['norm']
        if self.rank <= 0:
            self.logger.debug('Val info is: {}'.format(valInfo))
        train_set, train_size, total_iterations, total_epochs = None, 0, 0, 0
        train_loader, train_sampler = None, None
        for phase, dataset in self.opt['datasets'].items():
            for key in shared_keys:
                dataset[key] = self.opt[key]
            if phase.lower() != 'train':
                continue
            train_set = create_dataset(dataset, dataInfo, phase, self.opt['datasetName_train'])
            # Guard before dividing: an empty set would otherwise surface as a
            # bare ZeroDivisionError when computing total_epochs.
            if len(train_set) == 0:
                raise ValueError('Training dataset is empty')
            train_size = math.ceil(
                len(train_set) / (dataset['batch_size'] * self.opt['world_size']))
            total_iterations = self.opt['train']['MAX_ITERS']
            total_epochs = int(math.ceil(total_iterations / train_size))
            if self.opt['distributed']:
                train_sampler = DistributedSampler(
                    train_set,
                    num_replicas=self.opt['world_size'],
                    rank=self.opt['global_rank'])
            train_loader = create_dataloader(phase, train_set, dataset, self.opt, train_sampler)
            if self.rank <= 0:
                self.logger.info(
                    'Number of training samples: {}, batches per epoch: {}, iters: {}'.format(
                        len(train_set), train_size, total_iterations))
                self.logger.info('Total epoch needed: {} for iters {}'.format(
                    total_epochs, total_iterations))
        # Raised explicitly (not via `assert`) so the checks survive `python -O`;
        # AssertionError is kept for compatibility with existing callers.
        if train_set is None or train_size == 0:
            raise AssertionError("Train size cannot be zero")
        if train_loader is None:
            raise AssertionError("Cannot find train set, val set can be None")
        return dataInfo, valInfo, train_set, train_size, total_iterations, total_epochs, train_loader, train_sampler

    def init_model(self):
        """Hook: build networks, optimizers and schedulers.

        Must return the 6-tuple ``(model, dist, optimizer, dist_optim,
        scheduler, dist_scheduler)`` unpacked in ``__init__``. ``model`` and
        ``dist`` must support ``.to(device)``. The base implementation returns
        None, so subclasses have to override it.
        """

    def init_flow_model(self):
        """Hook: build an optional flow network; the base implementation returns None."""

    def resume_training(self):
        """Hook: restore state when ``opt['path']['gen_state']`` is set.

        Must return ``(start_epoch, current_step)``. The base implementation
        returns None.
        """

    def train(self):
        """Run training epochs.

        Epochs run from ``startEpoch`` through ``totalEpochs`` inclusive. The
        loop stops early once ``currentStep`` exceeds ``totalIterations``;
        ``currentStep`` is presumably advanced by ``_trainEpoch``.

        After each epoch that does not stop early, ``_validate`` runs every
        ``val_freq`` epochs when ``use_valid`` is set. Then ``scheduler`` and
        ``dist_scheduler`` are stepped.
        """
        for epoch in range(self.startEpoch, self.totalEpochs + 1):
            if self.opt['distributed']:
                # DistributedSampler derives its shuffling order from the epoch.
                self.trainSampler.set_epoch(epoch)
            self._trainEpoch(epoch)
            if self.currentStep > self.totalIterations:
                break
            if self.opt['use_valid'] and (epoch + 1) % self.opt['train']['val_freq'] == 0:
                self._validate(epoch)
            # NOTE(review): if these are torch.optim.lr_scheduler objects,
            # passing `epoch` to step() is deprecated — kept for compatibility.
            self.scheduler.step(epoch)
            self.dist_scheduler.step(epoch)

    def _trainEpoch(self, epoch):
        """Hook: run one training epoch; no-op in the base class."""

    def _printLog(self, logs, epoch, loss):
        """Hook: report training logs; no-op in the base class."""

    def save_checkpoint(self, epoch, is_best, metric, number):
        """Hook: persist a checkpoint; no-op in the base class."""

    def _validate(self, epoch):
        """Hook: run validation for *epoch*; no-op in the base class."""