import glob
import logging
import os
import random
import re
import subprocess
import sys
from datetime import datetime

import matplotlib

# Select the non-interactive Agg backend before anything imports pyplot, so
# plotting works on headless training machines.
matplotlib.use('Agg')

import numpy as np
import torch
import torch.distributed as dist
import torch.utils.data
from pytorch_lightning.loggers import TensorBoardLogger
from torch import nn

import utils
from utils.hparams import hparams, set_hparams
from utils.pl_utils import LatestModelCheckpoint, BaseTrainer, data_loader, DDP

# Share tensors between DataLoader workers via the file system, which avoids
# "too many open files" errors from the default descriptor-passing strategy.
torch.multiprocessing.set_sharing_strategy(os.getenv('TORCH_SHARE_STRATEGY', 'file_system'))

log_format = '%(asctime)s %(message)s'
logging.basicConfig(stream=sys.stdout, level=logging.INFO,
                    format=log_format, datefmt='%m/%d %I:%M:%S %p')


class BaseDataset(torch.utils.data.Dataset):
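    """Dataset base class with per-example sizes for length-based batching.

    Subclasses implement ``__getitem__`` and ``collater`` and populate
    ``self.sizes`` with one length per example.
    """
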
    def __init__(self, shuffle):
        super().__init__()
        self.hparams = hparams
        self.shuffle = shuffle
        self.sort_by_len = hparams['sort_by_len']
        self.sizes = None

    @property
    def _sizes(self):
        return self.sizes

    def __getitem__(self, index):
        raise NotImplementedError

    def collater(self, samples):
        raise NotImplementedError

    def __len__(self):
        return len(self._sizes)

    def num_tokens(self, index):
        return self.size(index)

    def size(self, index):
        """Return an example's size (a float or tuple), capped at
        ``hparams['max_frames']``. Used when batching examples by length."""
        return min(self._sizes[index], hparams['max_frames'])

    def ordered_indices(self):
        """Return an ordered list of indices. Batches will be constructed
        based on this order."""
        if self.shuffle:
            indices = np.random.permutation(len(self))
            if self.sort_by_len:
                # Stable mergesort keeps the random permutation among
                # examples of equal size.
                indices = indices[np.argsort(np.array(self._sizes)[indices], kind='mergesort')]
        else:
            indices = np.arange(len(self))
        return indices

    @property
    def num_workers(self):
        return int(os.getenv('NUM_WORKERS', hparams['ds_workers']))
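
# A minimal subclass sketch, for reference (hypothetical names; assumes a 1-D
# padding helper along the lines of ``utils.collate_1d``):
#
#     class TokenDataset(BaseDataset):
#         def __init__(self, token_lists, shuffle):
#             super().__init__(shuffle)
#             self.token_lists = token_lists
#             self.sizes = [len(t) for t in token_lists]  # drives batching
#
#         def __getitem__(self, index):
#             return torch.LongTensor(self.token_lists[index])
#
#         def collater(self, samples):
#             return utils.collate_1d(samples, pad_idx=0)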


class BaseTask(nn.Module):
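    """Training task: couples a model with its optimization, logging, and
    checkpointing logic. Subclasses implement ``build_model``,
    ``_training_step``, ``validation_step`` and ``_validation_end``.
    """
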
    def __init__(self, *args, **kwargs):
        super(BaseTask, self).__init__(*args, **kwargs)
        self.current_epoch = 0
        self.global_step = 0
        self.loaded_optimizer_states_dict = {}
        self.trainer = None
        self.logger = None
        self.on_gpu = False
        self.use_dp = False
        self.use_ddp = False
        self.example_input_array = None

        # -1 means: reuse the training batch limits for evaluation.
        self.max_tokens = hparams['max_tokens']
        self.max_sentences = hparams['max_sentences']
        self.max_eval_tokens = hparams['max_eval_tokens']
        if self.max_eval_tokens == -1:
            hparams['max_eval_tokens'] = self.max_eval_tokens = self.max_tokens
        self.max_eval_sentences = hparams['max_eval_sentences']
        if self.max_eval_sentences == -1:
            hparams['max_eval_sentences'] = self.max_eval_sentences = self.max_sentences

        self.model = None
        self.training_losses_meter = None

    def build_model(self):
        raise NotImplementedError

    def load_ckpt(self, ckpt_base_dir, current_model_name=None, model_name='model', force=True, strict=True):
        if current_model_name is None:
            current_model_name = model_name
        utils.load_ckpt(getattr(self, current_model_name), ckpt_base_dir, current_model_name, force, strict)

    def on_epoch_start(self):
        self.training_losses_meter = {'total_loss': utils.AvgrageMeter()}

    def _training_step(self, sample, batch_idx, optimizer_idx):
        """Compute the loss for one training batch.

        :param sample: batch produced by the dataset's ``collater``
        :param batch_idx: index of the batch within the epoch
        :param optimizer_idx: index of the active optimizer (-1 if single)
        :return: (total loss: torch.Tensor, loss_log: dict), or None to skip
        """
        raise NotImplementedError

    def training_step(self, sample, batch_idx, optimizer_idx=-1):
        loss_ret = self._training_step(sample, batch_idx, optimizer_idx)
        self.opt_idx = optimizer_idx
        if loss_ret is None:
            return {'loss': None}
        total_loss, log_outputs = loss_ret
        log_outputs = utils.tensors_to_scalars(log_outputs)
        for k, v in log_outputs.items():
            if k not in self.training_losses_meter:
                self.training_losses_meter[k] = utils.AvgrageMeter()
            if not np.isnan(v):
                self.training_losses_meter[k].update(v)
        self.training_losses_meter['total_loss'].update(total_loss.item())

        try:
            log_outputs['lr'] = self.scheduler.get_lr()
            if isinstance(log_outputs['lr'], list):
                log_outputs['lr'] = log_outputs['lr'][0]
        except Exception:
            # No scheduler is configured, or it does not expose get_lr().
            pass

        progress_bar_log = log_outputs
        tb_log = {f'tr/{k}': v for k, v in log_outputs.items()}
        return {
            'loss': total_loss,
            'progress_bar': progress_bar_log,
            'log': tb_log
        }

    def optimizer_step(self, epoch, batch_idx, optimizer, optimizer_idx):
        optimizer.step()
        optimizer.zero_grad()
        if self.scheduler is not None:
            # Step the scheduler once per optimizer step, not per micro-batch.
            self.scheduler.step(self.global_step // hparams['accumulate_grad_batches'])

    def on_epoch_end(self):
        loss_outputs = {k: round(v.avg, 4) for k, v in self.training_losses_meter.items()}
        print(f"\n==============\n "
              f"Epoch {self.current_epoch} ended. Steps: {self.global_step}. {loss_outputs}"
              f"\n==============\n")

    def validation_step(self, sample, batch_idx):
        """Evaluate the model on one validation batch.

        :param sample: batch produced by the dataset's ``collater``
        :param batch_idx: index of the batch within the validation set
        :return: output: dict
        """
        raise NotImplementedError

    def _validation_end(self, outputs):
        """Aggregate per-batch validation outputs.

        :param outputs: list of dicts returned by ``validation_step``
        :return: loss_output: dict (must contain 'total_loss')
        """
        raise NotImplementedError

    def validation_end(self, outputs):
        loss_output = self._validation_end(outputs)
        print(f"\n==============\n "
              f"valid results: {loss_output}"
              f"\n==============\n")
        return {
            'log': {f'val/{k}': v for k, v in loss_output.items()},
            'val_loss': loss_output['total_loss']
        }

    def build_scheduler(self, optimizer):
        raise NotImplementedError

    def build_optimizer(self, model):
        raise NotImplementedError

    def configure_optimizers(self):
        optm = self.build_optimizer(self.model)
        # The scheduler stays on the task and is stepped manually in
        # optimizer_step, so only the optimizer is returned.
        self.scheduler = self.build_scheduler(optm)
        return [optm]

    def test_start(self):
        pass

    def test_step(self, sample, batch_idx):
        return self.validation_step(sample, batch_idx)

    def test_end(self, outputs):
        return self.validation_end(outputs)

    @classmethod
    def start(cls):
        set_hparams()
        os.environ['MASTER_PORT'] = str(random.randint(15000, 30000))
        random.seed(hparams['seed'])
        np.random.seed(hparams['seed'])
        task = cls()
        work_dir = hparams['work_dir']
        trainer = BaseTrainer(
            checkpoint_callback=LatestModelCheckpoint(
                filepath=work_dir,
                verbose=True,
                monitor='val_loss',
                mode='min',
                num_ckpt_keep=hparams['num_ckpt_keep'],
                save_best=hparams['save_best'],
                period=1 if hparams['save_ckpt'] else 100000),
            logger=TensorBoardLogger(
                save_dir=work_dir,
                name='lightning_logs',
                version='latest'),
            gradient_clip_val=hparams['clip_grad_norm'],
            val_check_interval=hparams['val_check_interval'],
            row_log_interval=hparams['log_interval'],
            max_updates=hparams['max_updates'],
            num_sanity_val_steps=10000 if hparams['validate'] else hparams['num_sanity_val_steps'],
            accumulate_grad_batches=hparams['accumulate_grad_batches'])
        if not hparams['infer']:
            # Snapshot the source files listed in hparams['save_codes'], for
            # reproducibility.
            t = datetime.now().strftime('%Y%m%d%H%M%S')
            code_dir = f'{work_dir}/codes/{t}'
            subprocess.check_call(f'mkdir -p "{code_dir}"', shell=True)
            for c in hparams['save_codes']:
                subprocess.check_call(f'cp -r "{c}" "{code_dir}/"', shell=True)
            print(f"| Copied codes to {code_dir}.")
            trainer.checkpoint_callback.task = task
            trainer.fit(task)
        else:
            trainer.test(task)
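
    # Typical entry point from a task module (hypothetical subclass name):
    #
    #     if __name__ == '__main__':
    #         MyTask.start()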

    def configure_ddp(self, model, device_ids):
        model = DDP(
            model,
            device_ids=device_ids,
            find_unused_parameters=True
        )
        if dist.get_rank() != 0 and not hparams['debug']:
            # Silence stdout/stderr on all but the rank-0 process.
            sys.stdout = open(os.devnull, "w")
            sys.stderr = open(os.devnull, "w")
        random.seed(hparams['seed'])
        np.random.seed(hparams['seed'])
        return model

    def training_end(self, *args, **kwargs):
        return None

    def init_ddp_connection(self, proc_rank, world_size):
        set_hparams(print_hparams=False)
        # Use MASTER_PORT from the environment if set; otherwise fall back to
        # a fixed default and export it for the other processes.
        default_port = 12910
        try:
            default_port = os.environ['MASTER_PORT']
        except Exception:
            os.environ['MASTER_PORT'] = str(default_port)

        root_node = '127.0.0.2'
        root_node = self.trainer.resolve_root_node_address(root_node)
        os.environ['MASTER_ADDR'] = root_node
        dist.init_process_group('nccl', rank=proc_rank, world_size=world_size)

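    # Placeholder dataloaders; concrete tasks override these.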
    @data_loader
    def train_dataloader(self):
        return None

    @data_loader
    def test_dataloader(self):
        return None

    @data_loader
    def val_dataloader(self):
        return None

    def on_load_checkpoint(self, checkpoint):
        pass

    def on_save_checkpoint(self, checkpoint):
        pass

    def on_sanity_check_start(self):
        pass

    def on_train_start(self):
        pass

    def on_train_end(self):
        pass

    def on_batch_start(self, batch):
        pass

    def on_batch_end(self):
        pass

    def on_pre_performance_check(self):
        pass

    def on_post_performance_check(self):
        pass

    def on_before_zero_grad(self, optimizer):
        pass

    def on_after_backward(self):
        pass

    def backward(self, loss, optimizer):
        loss.backward()

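    # Per-parameter and total gradient p-norms for logging. The total norm is
    # accumulated as (sum_i ||g_i||_p^p)^(1/p) over all trainable parameters.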
    def grad_norm(self, norm_type):
        results = {}
        total_norm = 0
        for name, p in self.named_parameters():
            if p.requires_grad:
                try:
                    param_norm = p.grad.data.norm(norm_type)
                    total_norm += param_norm ** norm_type
                    norm = param_norm ** (1 / norm_type)

                    grad = round(norm.data.cpu().numpy().flatten()[0], 3)
                    results['grad_{}_norm_{}'.format(norm_type, name)] = grad
                except Exception:
                    # Parameters without gradients (e.g. unused this step)
                    # are skipped.
                    pass

        total_norm = total_norm ** (1. / norm_type)
        grad = round(total_norm.data.cpu().numpy().flatten()[0], 3)
        results['grad_{}_norm_total'.format(norm_type)] = grad
        return results