| import matplotlib
|
| from torch.nn import DataParallel
|
| from torch.nn.parallel import DistributedDataParallel
|
|
|
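| # Select a non-interactive matplotlib backend so training can run on headless machines.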
| matplotlib.use('Agg')
|
| import glob
|
| import itertools
|
| import subprocess
|
| import threading
|
| import traceback
|
|
|
| from pytorch_lightning.callbacks import GradientAccumulationScheduler
|
| from pytorch_lightning.callbacks import ModelCheckpoint
|
|
|
| from functools import wraps
|
| from torch.cuda._utils import _get_device_index
|
| import numpy as np
|
| import torch.optim
|
| import torch.utils.data
|
| import copy
|
| import logging
|
| import os
|
| import re
|
| import sys
|
| import torch
|
| import torch.distributed as dist
|
| import torch.multiprocessing as mp
|
| import tqdm
|
| from torch.optim.optimizer import Optimizer
|
| from packaging import version
|
|
|
|
|
| def get_a_var(obj):
|
| if isinstance(obj, torch.Tensor):
|
| return obj
|
|
|
|     if isinstance(obj, (list, tuple)):
|
| for result in map(get_a_var, obj):
|
| if isinstance(result, torch.Tensor):
|
| return result
|
| if isinstance(obj, dict):
|
| for result in map(get_a_var, obj.items()):
|
| if isinstance(result, torch.Tensor):
|
| return result
|
| return None
|
|
|
|
|
| def data_loader(fn):
|
| """
|
|     Decorator that lazily builds the wrapped dataloader, caches it on the instance and returns the cached value on later calls
|
| :param fn:
|
| :return:
|
| """
|
|
|
|     attr_name = '_lazy_' + fn.__name__
|
|
|
|     @wraps(fn)
|
|     def _get_data_loader(self):
|
| try:
|
| value = getattr(self, attr_name)
|
| except AttributeError:
|
| try:
|
| value = fn(self)
|
| if (
|
| value is not None and
|
| not isinstance(value, list) and
|
| fn.__name__ in ['test_dataloader', 'val_dataloader']
|
| ):
|
| value = [value]
|
| except AttributeError as e:
|
|
|
| traceback.print_exc()
|
| error = f'{fn.__name__}: An AttributeError was encountered: ' + str(e)
|
| raise RuntimeError(error) from e
|
| setattr(self, attr_name, value)
|
| return value
|
|
|
| return _get_data_loader
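|
|
| # Illustrative usage of the decorator (class and dataset names are hypothetical):
| #
| #     class MyTask(torch.nn.Module):
| #         @data_loader
| #         def val_dataloader(self):
| #             return torch.utils.data.DataLoader(my_dataset, batch_size=32)
| #
| # The first call builds the loader, wraps val/test loaders in a list and caches the result
| # on the instance as `_lazy_val_dataloader`; subsequent calls return the cached value.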
|
|
|
|
|
| def parallel_apply(modules, inputs, kwargs_tup=None, devices=None):
|
| r"""Applies each `module` in :attr:`modules` in parallel on arguments
|
| contained in :attr:`inputs` (positional) and :attr:`kwargs_tup` (keyword)
|
| on each of :attr:`devices`.
|
|
|
| Args:
|
|         modules (list of Module): modules to be parallelized
|
|         inputs (list of tensors): inputs to the modules
|
|         kwargs_tup (tuple of dicts, optional): keyword arguments for each module
|
|         devices (list of int or torch.device, optional): CUDA devices
|
|
|
| :attr:`modules`, :attr:`inputs`, :attr:`kwargs_tup` (if given), and
|
| :attr:`devices` (if given) should all have same length. Moreover, each
|
| element of :attr:`inputs` can either be a single object as the only argument
|
| to a module, or a collection of positional arguments.
|
| """
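|     # One worker per replica (separate threads when there is more than one module): each
|     # worker runs the module's training/validation/test step on its assigned device, and
|     # results (or raised exceptions) are stored by index under a lock and collected below.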
|
| assert len(modules) == len(inputs)
|
| if kwargs_tup is not None:
|
| assert len(modules) == len(kwargs_tup)
|
| else:
|
| kwargs_tup = ({},) * len(modules)
|
| if devices is not None:
|
| assert len(modules) == len(devices)
|
| else:
|
| devices = [None] * len(modules)
|
| devices = list(map(lambda x: _get_device_index(x, True), devices))
|
| lock = threading.Lock()
|
| results = {}
|
| grad_enabled = torch.is_grad_enabled()
|
|
|
| def _worker(i, module, input, kwargs, device=None):
|
| torch.set_grad_enabled(grad_enabled)
|
| if device is None:
|
| device = get_a_var(input).get_device()
|
| try:
|
| with torch.cuda.device(device):
|
|
|
| if not isinstance(input, (list, tuple)):
|
| input = (input,)
|
|
|
|
|
|
|
| if module.training:
|
| output = module.training_step(*input, **kwargs)
|
|
|
| elif module.testing:
|
| output = module.test_step(*input, **kwargs)
|
|
|
| else:
|
| output = module.validation_step(*input, **kwargs)
|
|
|
|
|
| with lock:
|
| results[i] = output
|
| except Exception as e:
|
| with lock:
|
| results[i] = e
|
|
|
|
|
|
|
| root_m = modules[0]
|
| for m in modules[1:]:
|
| m.training = root_m.training
|
| m.testing = root_m.testing
|
|
|
| if len(modules) > 1:
|
| threads = [threading.Thread(target=_worker,
|
| args=(i, module, input, kwargs, device))
|
| for i, (module, input, kwargs, device) in
|
| enumerate(zip(modules, inputs, kwargs_tup, devices))]
|
|
|
| for thread in threads:
|
| thread.start()
|
| for thread in threads:
|
| thread.join()
|
| else:
|
| _worker(0, modules[0], inputs[0], kwargs_tup[0], devices[0])
|
|
|
| outputs = []
|
| for i in range(len(inputs)):
|
| output = results[i]
|
| if isinstance(output, Exception):
|
| raise output
|
| outputs.append(output)
|
| return outputs
|
|
|
|
|
| def _find_tensors(obj):
|
| r"""
|
| Recursively find all tensors contained in the specified object.
|
| """
|
| if isinstance(obj, torch.Tensor):
|
| return [obj]
|
| if isinstance(obj, (list, tuple)):
|
| return itertools.chain(*map(_find_tensors, obj))
|
| if isinstance(obj, dict):
|
| return itertools.chain(*map(_find_tensors, obj.values()))
|
| return []
|
|
|
|
|
| class DDP(DistributedDataParallel):
|
| """
|
| Override the forward call in lightning so it goes to training and validation step respectively
|
| """
|
|
|
| def parallel_apply(self, replicas, inputs, kwargs):
|
| return parallel_apply(replicas, inputs, kwargs, self.device_ids[:len(replicas)])
|
|
|
| def forward(self, *inputs, **kwargs):
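|         # Dispatch to the Lightning-style step methods instead of module.forward(). For
|         # torch < 1.11 the reducer/_sync_params hooks are driven manually; for newer torch
|         # the relevant parts of DistributedDataParallel.forward are reproduced below.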
|
|         if version.parse(torch.__version__.split('+')[0]) < version.parse("1.11"):
|
| self._sync_params()
|
| inputs, kwargs = self.scatter(inputs, kwargs, self.device_ids)
|
| assert len(self.device_ids) == 1
|
| if self.module.training:
|
| output = self.module.training_step(*inputs[0], **kwargs[0])
|
| elif self.module.testing:
|
| output = self.module.test_step(*inputs[0], **kwargs[0])
|
| else:
|
| output = self.module.validation_step(*inputs[0], **kwargs[0])
|
| if torch.is_grad_enabled():
|
|
|
|
|
|
|
|
|
|
|
| if self.find_unused_parameters:
|
| self.reducer.prepare_for_backward(list(_find_tensors(output)))
|
| else:
|
| self.reducer.prepare_for_backward([])
|
| else:
|
| from torch.nn.parallel.distributed import \
|
| logging, Join, _DDPSink, _tree_flatten_with_rref, _tree_unflatten_with_rref
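|             # NOTE: these are private internals of torch.nn.parallel.distributed and may
|             # change between torch releases; the block below mirrors upstream DDP.forward.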
|
| with torch.autograd.profiler.record_function("DistributedDataParallel.forward"):
|
| if torch.is_grad_enabled() and self.require_backward_grad_sync:
|
| self.logger.set_runtime_stats_and_log()
|
| self.num_iterations += 1
|
| self.reducer.prepare_for_forward()
|
|
|
|
|
|
|
| work = Join.notify_join_context(self)
|
| if work:
|
| self.reducer._set_forward_pass_work_handle(
|
| work, self._divide_by_initial_world_size
|
| )
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| if torch.is_grad_enabled() and self.reducer._rebuild_buckets():
|
| logging.info("Reducer buckets have been rebuilt in this iteration.")
|
| self._has_rebuilt_buckets = True
|
|
|
|
|
|
|
| buffer_hook_registered = hasattr(self, 'buffer_hook')
|
| if self._check_sync_bufs_pre_fwd():
|
| self._sync_buffers()
|
|
|
| if self._join_config.enable:
|
|
|
| self._check_global_requires_backward_grad_sync(is_joined_rank=False)
|
|
|
| inputs, kwargs = self.scatter(inputs, kwargs, self.device_ids)
|
| if self.module.training:
|
| output = self.module.training_step(*inputs[0], **kwargs[0])
|
| elif self.module.testing:
|
| output = self.module.test_step(*inputs[0], **kwargs[0])
|
| else:
|
| output = self.module.validation_step(*inputs[0], **kwargs[0])
|
|
|
|
|
|
|
| if self._check_sync_bufs_post_fwd():
|
| self._sync_buffers()
|
|
|
| if torch.is_grad_enabled() and self.require_backward_grad_sync:
|
| self.require_forward_param_sync = True
|
|
|
|
|
|
|
|
|
|
|
| if self.find_unused_parameters and not self.static_graph:
|
|
|
| self.reducer.prepare_for_backward(list(_find_tensors(output)))
|
| else:
|
| self.reducer.prepare_for_backward([])
|
| else:
|
| self.require_forward_param_sync = False
|
|
|
|
|
|
|
| if (self.find_unused_parameters and not self.static_graph) or (
|
| self.static_graph and self.num_iterations == 1
|
| ):
|
| state_dict = {
|
| 'static_graph': self.static_graph,
|
| 'num_iterations': self.num_iterations,
|
| }
|
|
|
| output_tensor_list, treespec, output_is_rref = _tree_flatten_with_rref(
|
| output
|
| )
|
| output_placeholders = [None for _ in range(len(output_tensor_list))]
|
|
|
|
|
| for i, output in enumerate(output_tensor_list):
|
| if torch.is_tensor(output) and output.grad_fn is None:
|
| output_placeholders[i] = output
|
|
|
|
|
|
|
|
|
|
|
|
|
| passthrough_tensor_list = _DDPSink.apply(
|
| self.reducer,
|
| state_dict,
|
| *output_tensor_list,
|
| )
|
| for i in range(len(output_placeholders)):
|
| if output_placeholders[i] is None:
|
| output_placeholders[i] = passthrough_tensor_list[i]
|
|
|
|
|
| output = _tree_unflatten_with_rref(
|
| output_placeholders, treespec, output_is_rref
|
| )
|
| return output
|
|
|
|
|
| class DP(DataParallel):
|
| """
|
| Override the forward call in lightning so it goes to training and validation step respectively
|
| """
|
|
|
| def forward(self, *inputs, **kwargs):
|
| if not self.device_ids:
|
| return self.module(*inputs, **kwargs)
|
|
|
| for t in itertools.chain(self.module.parameters(), self.module.buffers()):
|
| if t.device != self.src_device_obj:
|
| raise RuntimeError("module must have its parameters and buffers "
|
| "on device {} (device_ids[0]) but found one of "
|
| "them on device: {}".format(self.src_device_obj, t.device))
|
|
|
| inputs, kwargs = self.scatter(inputs, kwargs, self.device_ids)
|
| if len(self.device_ids) == 1:
|
|
|
| if self.module.training:
|
| return self.module.training_step(*inputs[0], **kwargs[0])
|
| elif self.module.testing:
|
| return self.module.test_step(*inputs[0], **kwargs[0])
|
| else:
|
| return self.module.validation_step(*inputs[0], **kwargs[0])
|
|
|
| replicas = self.replicate(self.module, self.device_ids[:len(inputs)])
|
| outputs = self.parallel_apply(replicas, inputs, kwargs)
|
| return self.gather(outputs, self.output_device)
|
|
|
| def parallel_apply(self, replicas, inputs, kwargs):
|
| return parallel_apply(replicas, inputs, kwargs, self.device_ids[:len(replicas)])
|
|
|
|
|
| class GradientAccumulationScheduler:
|
| def __init__(self, scheduling: dict):
|
| if scheduling == {}:
|
|             raise TypeError("Empty dict cannot be interpreted correctly")
|
|
|
| for key in scheduling.keys():
|
| if not isinstance(key, int) or not isinstance(scheduling[key], int):
|
|                 raise TypeError("All epochs and accumulation factors must be integers")
|
|
|
| minimal_epoch = min(scheduling.keys())
|
| if minimal_epoch < 1:
|
|             msg = f"Epochs are indexed from 1; epoch {minimal_epoch} cannot be interpreted correctly"
|
| raise IndexError(msg)
|
| elif minimal_epoch != 1:
|
| scheduling.update({1: 1})
|
|
|
| self.scheduling = scheduling
|
| self.epochs = sorted(scheduling.keys())
|
|
|
| def on_epoch_begin(self, epoch, trainer):
|
| epoch += 1
|
| for i in reversed(range(len(self.epochs))):
|
| if epoch >= self.epochs[i]:
|
| trainer.accumulate_grad_batches = self.scheduling.get(self.epochs[i])
|
| break
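|
|
| # Illustrative scheduling example: GradientAccumulationScheduler({5: 2, 10: 4}) accumulates
| # 1 batch per step for epochs 1-4 (an implicit {1: 1} entry is added), 2 batches for epochs
| # 5-9 and 4 batches from epoch 10 on; on_epoch_begin picks the largest configured epoch that
| # is <= the (1-indexed) current epoch.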
|
|
|
|
|
| class LatestModelCheckpoint(ModelCheckpoint):
|
| def __init__(self, filepath, monitor='val_loss', verbose=0, num_ckpt_keep=5,
|
| save_weights_only=False, mode='auto', period=1, prefix='model', save_best=True):
|
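|         # Note: super(ModelCheckpoint, self) resolves to ModelCheckpoint's own base class, so
|         # ModelCheckpoint.__init__ is skipped on purpose; every attribute it would normally
|         # set is configured manually below.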
| super(ModelCheckpoint, self).__init__()
|
| self.monitor = monitor
|
| self.verbose = verbose
|
| self.filepath = filepath
|
| os.makedirs(filepath, exist_ok=True)
|
| self.num_ckpt_keep = num_ckpt_keep
|
| self.save_best = save_best
|
| self.save_weights_only = save_weights_only
|
| self.period = period
|
| self.epochs_since_last_check = 0
|
| self.prefix = prefix
|
| self.best_k_models = {}
|
|
|
| self.kth_best_model = ''
|
| self.save_top_k = 1
|
| self.task = None
|
| if mode == 'min':
|
| self.monitor_op = np.less
|
|             self.best = np.inf
|
| self.mode = 'min'
|
| elif mode == 'max':
|
| self.monitor_op = np.greater
|
|             self.best = -np.inf
|
| self.mode = 'max'
|
| else:
|
| if 'acc' in self.monitor or self.monitor.startswith('fmeasure'):
|
| self.monitor_op = np.greater
|
|                 self.best = -np.inf
|
| self.mode = 'max'
|
| else:
|
| self.monitor_op = np.less
|
|                 self.best = np.inf
|
| self.mode = 'min'
|
| if os.path.exists(f'{self.filepath}/best_valid.npy'):
|
| self.best = np.load(f'{self.filepath}/best_valid.npy')[0]
|
|
|
| def get_all_ckpts(self):
|
|         return sorted(glob.glob(f'{self.filepath}/{self.prefix}_ckpt_steps_*.ckpt'),
|
|                       key=lambda x: -int(re.findall(r'.*steps_(\d+)\.ckpt', x)[0]))
|
|
|
| def on_epoch_end(self, epoch, logs=None):
|
| logs = logs or {}
|
| self.epochs_since_last_check += 1
|
| best_filepath = f'{self.filepath}/{self.prefix}_ckpt_best.pt'
|
| if self.epochs_since_last_check >= self.period:
|
| self.epochs_since_last_check = 0
|
| filepath = f'{self.filepath}/{self.prefix}_ckpt_steps_{self.task.global_step}.ckpt'
|
| if self.verbose > 0:
|
| logging.info(f'Epoch {epoch:05d}@{self.task.global_step}: saving model to {filepath}')
|
| self._save_model(filepath)
|
| for old_ckpt in self.get_all_ckpts()[self.num_ckpt_keep:]:
|
| subprocess.check_call(f'rm -rf "{old_ckpt}"', shell=True)
|
| if self.verbose > 0:
|
| logging.info(f'Delete ckpt: {os.path.basename(old_ckpt)}')
|
| current = logs.get(self.monitor)
|
| if current is not None and self.save_best:
|
|                 if self.monitor_op(current, self.best):
|
|                     if self.verbose > 0:
|
|                         logging.info(
|
|                             f'Epoch {epoch:05d}@{self.task.global_step}: {self.monitor} reached'
|
|                             f' {current:0.5f} (previous best {self.best:0.5f}), saving model to'
|
|                             f' {best_filepath} as top 1')
|
|                     self.best = current
|
|                     self._save_model(best_filepath)
|
|                     np.save(f'{self.filepath}/best_valid.npy', [self.best])
|
|
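| # Minimal construction sketch (names and paths are illustrative, not part of this module):
| #
| #     ckpt_cb = LatestModelCheckpoint(filepath='checkpoints/exp1', monitor='val_loss',
| #                                     mode='min', num_ckpt_keep=5, save_best=True)
| #     trainer = BaseTrainer(checkpoint_callback=ckpt_cb, logger=my_logger, max_updates=100000)
| #
| # `my_logger` stands for any logger object exposing .rank, .log_metrics(), .save() and
| # .finalize(); the surrounding task code is also expected to set `ckpt_cb.task`, since
| # on_epoch_end reads `task.global_step` when naming checkpoints.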
|
|
|
| class BaseTrainer:
|
| def __init__(
|
| self,
|
| logger=True,
|
| checkpoint_callback=True,
|
| default_save_path=None,
|
| gradient_clip_val=0,
|
| process_position=0,
|
| gpus=-1,
|
| log_gpu_memory=None,
|
| show_progress_bar=True,
|
| track_grad_norm=-1,
|
| check_val_every_n_epoch=1,
|
| accumulate_grad_batches=1,
|
| max_updates=1000,
|
| min_epochs=1,
|
| val_check_interval=1.0,
|
| log_save_interval=100,
|
| row_log_interval=10,
|
| print_nan_grads=False,
|
| weights_summary='full',
|
| num_sanity_val_steps=5,
|
| resume_from_checkpoint=None,
|
| ):
|
| self.log_gpu_memory = log_gpu_memory
|
| self.gradient_clip_val = gradient_clip_val
|
| self.check_val_every_n_epoch = check_val_every_n_epoch
|
| self.track_grad_norm = track_grad_norm
|
|         self.on_gpu = bool(gpus) and torch.cuda.is_available()
|
| self.process_position = process_position
|
| self.weights_summary = weights_summary
|
| self.max_updates = max_updates
|
| self.min_epochs = min_epochs
|
| self.num_sanity_val_steps = num_sanity_val_steps
|
| self.print_nan_grads = print_nan_grads
|
| self.resume_from_checkpoint = resume_from_checkpoint
|
| self.default_save_path = default_save_path
|
|
|
|
|
| self.total_batch_idx = 0
|
| self.running_loss = []
|
| self.avg_loss = 0
|
| self.batch_idx = 0
|
| self.tqdm_metrics = {}
|
| self.callback_metrics = {}
|
| self.num_val_batches = 0
|
| self.num_training_batches = 0
|
| self.num_test_batches = 0
|
| self.get_train_dataloader = None
|
| self.get_test_dataloaders = None
|
| self.get_val_dataloaders = None
|
| self.is_iterable_train_dataloader = False
|
|
|
|
|
| self.model = None
|
| self.testing = False
|
| self.disable_validation = False
|
| self.lr_schedulers = []
|
| self.optimizers = None
|
| self.global_step = 0
|
| self.current_epoch = 0
|
| self.total_batches = 0
|
|
|
|
|
| self.checkpoint_callback = checkpoint_callback
|
| self.checkpoint_callback.save_function = self.save_checkpoint
|
| self.weights_save_path = self.checkpoint_callback.filepath
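|         # `checkpoint_callback` (and `logger` below) are used as real objects right away
|         # (save_function, filepath, rank, ...), so the boolean defaults are placeholders:
|         # in practice pass a configured callback such as LatestModelCheckpoint and a logger.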
|
|
|
|
|
| self.configure_accumulated_gradients(accumulate_grad_batches)
|
|
|
|
|
| self.data_parallel_device_ids = [
|
| int(x) for x in os.environ.get("CUDA_VISIBLE_DEVICES", "").split(",") if x != '']
|
| if len(self.data_parallel_device_ids) == 0:
|
| self.root_gpu = None
|
| self.on_gpu = False
|
| else:
|
| self.root_gpu = self.data_parallel_device_ids[0]
|
| self.on_gpu = True
|
|
|
|
|
| self.use_ddp = False
|
| self.use_dp = False
|
| self.single_gpu = False
|
| self.distributed_backend = 'ddp' if self.num_gpus > 0 else 'dp'
|
| self.set_distributed_mode(self.distributed_backend)
|
|
|
| self.proc_rank = 0
|
| self.world_size = 1
|
| self.node_rank = 0
|
|
|
|
|
|
|
| self.show_progress_bar = show_progress_bar
|
|
|
|
|
| self.log_save_interval = log_save_interval
|
| self.val_check_interval = val_check_interval
|
| self.logger = logger
|
| self.logger.rank = 0
|
| self.row_log_interval = row_log_interval
|
|
|
| @property
|
| def num_gpus(self):
|
| gpus = self.data_parallel_device_ids
|
| if gpus is None:
|
| return 0
|
| else:
|
| return len(gpus)
|
|
|
| @property
|
| def data_parallel(self):
|
| return self.use_dp or self.use_ddp
|
|
|
| def get_model(self):
|
| is_dp_module = isinstance(self.model, (DDP, DP))
|
| model = self.model.module if is_dp_module else self.model
|
| return model
|
|
|
|
|
|
|
|
|
| def fit(self, model):
|
| if self.use_ddp:
|
| mp.spawn(self.ddp_train, nprocs=self.num_gpus, args=(model,))
|
| else:
|
| model.model = model.build_model()
|
| if not self.testing:
|
| self.optimizers, self.lr_schedulers = self.init_optimizers(model.configure_optimizers())
|
| if self.use_dp:
|
| model.cuda(self.root_gpu)
|
| model = DP(model, device_ids=self.data_parallel_device_ids)
|
| elif self.single_gpu:
|
| model.cuda(self.root_gpu)
|
| self.run_pretrain_routine(model)
|
| return 1
|
|
|
| def init_optimizers(self, optimizers):
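|         # Accepted return shapes from model.configure_optimizers():
|         #   a single Optimizer                     -> ([optimizer], [])
|         #   (optimizer_list, scheduler_list)       -> passed through unchanged
|         #   a list/tuple of optimizers             -> (optimizers, [])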
|
|
|
|
|
| if isinstance(optimizers, Optimizer):
|
| return [optimizers], []
|
|
|
|
|
| elif len(optimizers) == 2 and isinstance(optimizers[0], list):
|
| optimizers, lr_schedulers = optimizers
|
| return optimizers, lr_schedulers
|
|
|
|
|
| elif isinstance(optimizers, list) or isinstance(optimizers, tuple):
|
| return optimizers, []
|
|
|
| def run_pretrain_routine(self, model):
|
| """Sanity check a few things before starting actual training.
|
|
|
| :param model:
|
| """
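|         # Order of operations: attach trainer references to the model, build dataloaders,
|         # restore weights, then either run evaluation (testing) or do the validation sanity
|         # check and enter the main training loop.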
|
| ref_model = model
|
| if self.data_parallel:
|
| ref_model = model.module
|
|
|
|
|
| ref_model.trainer = self
|
|
|
|
|
| self.copy_trainer_model_properties(ref_model)
|
|
|
|
|
| if self.logger is not None:
|
| ref_model.logger = self.logger
|
| self.logger.save()
|
|
|
| if self.use_ddp:
|
| dist.barrier()
|
|
|
|
|
|
|
|
|
|
|
| self.get_dataloaders(ref_model)
|
|
|
|
|
|
|
| self.model = model
|
|
|
|
|
| self.restore_weights(model)
|
|
|
|
|
| if self.testing:
|
| self.run_evaluation(test=True)
|
| return
|
|
|
|
|
| self.disable_validation = self.num_val_batches == 0
|
|
|
|
|
|
|
| ref_model.on_sanity_check_start()
|
| ref_model.on_train_start()
|
| if not self.disable_validation and self.num_sanity_val_steps > 0:
|
|
|
| pbar = tqdm.tqdm(desc='Validation sanity check',
|
| total=self.num_sanity_val_steps * len(self.get_val_dataloaders()),
|
| leave=False, position=2 * self.process_position,
|
| disable=not self.show_progress_bar, dynamic_ncols=True, unit='batch')
|
| self.main_progress_bar = pbar
|
|
|
| self.val_progress_bar = tqdm.tqdm(disable=True)
|
|
|
| self.evaluate(model, self.get_val_dataloaders(), self.num_sanity_val_steps, self.testing)
|
|
|
|
|
| self.main_progress_bar.close()
|
| self.val_progress_bar.close()
|
|
|
|
|
| pbar = tqdm.tqdm(leave=True, position=2 * self.process_position,
|
| disable=not self.show_progress_bar, dynamic_ncols=True, unit='batch',
|
| file=sys.stdout)
|
| self.main_progress_bar = pbar
|
|
|
|
|
| if self.on_gpu:
|
| torch.cuda.empty_cache()
|
|
|
|
|
| self.train()
|
|
|
| def test(self, model):
|
| self.testing = True
|
| self.fit(model)
|
|
|
| @property
|
| def training_tqdm_dict(self):
|
| tqdm_dict = {
|
| 'step': '{}'.format(self.global_step),
|
| }
|
| tqdm_dict.update(self.tqdm_metrics)
|
| return tqdm_dict
|
|
|
|
|
|
|
|
|
| def restore_weights(self, model):
|
| """
|
|         To restore weights there are two cases.
|
|         First, if `resume_from_checkpoint` was given, restore from that exact path.
|
|
|
|         Otherwise, restore from the most recent checkpoint found in the checkpoint callback's directory
|
| :param model:
|
| :return:
|
| """
|
|
|
| if self.on_gpu:
|
| torch.cuda.empty_cache()
|
|
|
| if self.resume_from_checkpoint is not None:
|
| self.restore(self.resume_from_checkpoint, on_gpu=self.on_gpu)
|
| else:
|
|
|
| self.restore_state_if_checkpoint_exists(model)
|
|
|
|
|
| if self.use_ddp:
|
|
|
| dist.barrier()
|
|
|
|
|
| if self.on_gpu:
|
| torch.cuda.empty_cache()
|
|
|
| def restore_state_if_checkpoint_exists(self, model):
|
| did_restore = False
|
|
|
|
|
| no_ckpt_callback = (self.checkpoint_callback is None) or (not self.checkpoint_callback)
|
| if no_ckpt_callback or not os.path.exists(self.checkpoint_callback.filepath):
|
| return did_restore
|
|
|
|
|
| last_steps = -1
|
| last_ckpt_name = None
|
|
|
|
|
| checkpoints = os.listdir(self.checkpoint_callback.filepath)
|
| for name in checkpoints:
|
| if '.ckpt' in name and not name.endswith('part'):
|
| if 'steps_' in name:
|
| steps = name.split('steps_')[1]
|
| steps = int(re.sub('[^0-9]', '', steps))
|
|
|
| if steps > last_steps:
|
| last_steps = steps
|
| last_ckpt_name = name
|
|
|
|
|
| if last_ckpt_name is not None:
|
| last_ckpt_path = os.path.join(self.checkpoint_callback.filepath, last_ckpt_name)
|
| self.restore(last_ckpt_path, self.on_gpu)
|
| logging.info(f'model and trainer restored from checkpoint: {last_ckpt_path}')
|
| did_restore = True
|
|
|
| return did_restore
|
|
|
| def restore(self, checkpoint_path, on_gpu):
|
| checkpoint = torch.load(checkpoint_path, map_location='cpu')
|
|
|
|
|
| model = self.get_model()
|
|
|
|
|
| model.load_state_dict(checkpoint['state_dict'], strict=False)
|
| if on_gpu:
|
| model.cuda(self.root_gpu)
|
|
|
| self.restore_training_state(checkpoint)
|
| model.global_step = self.global_step
|
| del checkpoint
|
|
|
| try:
|
| if dist.is_initialized() and dist.get_rank() > 0:
|
| return
|
| except Exception as e:
|
| print(e)
|
| return
|
|
|
| def restore_training_state(self, checkpoint):
|
| """
|
| Restore trainer state.
|
|         Model will get its chance to update
|
| :param checkpoint:
|
| :return:
|
| """
|
| if self.checkpoint_callback is not None and self.checkpoint_callback is not False:
|
| self.checkpoint_callback.best = checkpoint['checkpoint_callback_best']
|
|
|
| self.global_step = checkpoint['global_step']
|
| self.current_epoch = checkpoint['epoch']
|
|
|
| if self.testing:
|
| return
|
|
|
|
|
| optimizer_states = checkpoint['optimizer_states']
|
| for optimizer, opt_state in zip(self.optimizers, optimizer_states):
|
| if optimizer is None:
|
| return
|
| optimizer.load_state_dict(opt_state)
|
|
|
|
|
|
|
| if self.root_gpu is not None:
|
| for state in optimizer.state.values():
|
| for k, v in state.items():
|
| if isinstance(v, torch.Tensor):
|
| state[k] = v.cuda(self.root_gpu)
|
|
|
|
|
| lr_schedulers = checkpoint['lr_schedulers']
|
| for scheduler, lrs_state in zip(self.lr_schedulers, lr_schedulers):
|
| scheduler.load_state_dict(lrs_state)
|
|
|
|
|
|
|
|
|
| def _atomic_save(self, checkpoint, filepath):
|
| """Saves a checkpoint atomically, avoiding the creation of incomplete checkpoints.
|
|
|
|         This will create a temporary checkpoint with a suffix of ``.part``, then atomically rename it to the final
|
|         location once saving is finished.
|
|
|
| Args:
|
| checkpoint (object): The object to save.
|
| Built to be used with the ``dump_checkpoint`` method, but can deal with anything which ``torch.save``
|
| accepts.
|
| filepath (str|pathlib.Path): The path to which the checkpoint will be saved.
|
| This points to the file that the checkpoint will be stored in.
|
| """
|
| tmp_path = str(filepath) + ".part"
|
| torch.save(checkpoint, tmp_path)
|
| os.replace(tmp_path, filepath)
|
|
|
| def save_checkpoint(self, filepath):
|
| checkpoint = self.dump_checkpoint()
|
| self._atomic_save(checkpoint, filepath)
|
|
|
| def dump_checkpoint(self):
|
|
|
| checkpoint = {
|
| 'epoch': self.current_epoch,
|
| 'global_step': self.global_step
|
| }
|
|
|
| if self.checkpoint_callback is not None and self.checkpoint_callback is not False:
|
| checkpoint['checkpoint_callback_best'] = self.checkpoint_callback.best
|
|
|
|
|
| optimizer_states = []
|
| for i, optimizer in enumerate(self.optimizers):
|
| if optimizer is not None:
|
| optimizer_states.append(optimizer.state_dict())
|
|
|
| checkpoint['optimizer_states'] = optimizer_states
|
|
|
|
|
| lr_schedulers = []
|
| for i, scheduler in enumerate(self.lr_schedulers):
|
| lr_schedulers.append(scheduler.state_dict())
|
|
|
| checkpoint['lr_schedulers'] = lr_schedulers
|
|
|
|
|
| model = self.get_model()
|
| checkpoint['state_dict'] = model.state_dict()
|
|
|
| model.on_save_checkpoint(checkpoint)
|
|
|
| return checkpoint
|
|
|
| def copy_trainer_model_properties(self, model):
|
| if isinstance(model, DP):
|
| ref_model = model.module
|
| elif isinstance(model, DDP):
|
| ref_model = model.module
|
| else:
|
| ref_model = model
|
|
|
| for m in [model, ref_model]:
|
| m.trainer = self
|
| m.on_gpu = self.on_gpu
|
| m.use_dp = self.use_dp
|
| m.use_ddp = self.use_ddp
|
| m.testing = self.testing
|
| m.single_gpu = self.single_gpu
|
|
|
| def transfer_batch_to_gpu(self, batch, gpu_id):
|
|
|
| if callable(getattr(batch, 'cuda', None)):
|
| return batch.cuda(gpu_id, non_blocking=True)
|
|
|
| elif callable(getattr(batch, 'to', None)):
|
| return batch.to(torch.device('cuda', gpu_id), non_blocking=True)
|
|
|
|
|
| elif isinstance(batch, list):
|
| for i, x in enumerate(batch):
|
| batch[i] = self.transfer_batch_to_gpu(x, gpu_id)
|
| return batch
|
|
|
|
|
| elif isinstance(batch, tuple):
|
| batch = list(batch)
|
| for i, x in enumerate(batch):
|
| batch[i] = self.transfer_batch_to_gpu(x, gpu_id)
|
| return tuple(batch)
|
|
|
|
|
| elif isinstance(batch, dict):
|
| for k, v in batch.items():
|
| batch[k] = self.transfer_batch_to_gpu(v, gpu_id)
|
|
|
| return batch
|
|
|
|
|
| return batch
|
|
|
| def set_distributed_mode(self, distributed_backend):
|
|
|
| if self.num_gpus == 0:
|
| return
|
|
|
|
|
|
|
|
|
| elif self.num_gpus == 1:
|
| self.single_gpu = True
|
| self.use_dp = False
|
| self.use_ddp = False
|
| self.root_gpu = 0
|
| self.data_parallel_device_ids = [0]
|
| else:
|
| if distributed_backend is not None:
|
| self.use_dp = distributed_backend == 'dp'
|
| self.use_ddp = distributed_backend == 'ddp'
|
| elif distributed_backend is None:
|
| self.use_dp = True
|
| self.use_ddp = False
|
|
|
| logging.info(f'gpu available: {torch.cuda.is_available()}, used: {self.on_gpu}')
|
|
|
| def ddp_train(self, gpu_idx, model):
|
| """
|
|         Entry point into a DDP training process (one spawned per GPU)
|
|         :param gpu_idx:
|
|         :param model:
|
| :return:
|
| """
|
|
|
| self.node_rank = 0
|
|
|
|
|
| self.show_progress_bar = self.show_progress_bar and self.node_rank == 0 and gpu_idx == 0
|
|
|
|
|
| if self.use_ddp:
|
| self.proc_rank = self.node_rank * self.num_gpus + gpu_idx
|
| self.world_size = self.num_gpus
|
|
|
|
|
| if self.logger is not None:
|
| self.logger.rank = self.proc_rank
|
|
|
|
|
|
|
|
|
| model.trainer = self
|
| model.init_ddp_connection(self.proc_rank, self.world_size)
|
|
|
|
|
|
|
| model.model = model.build_model()
|
| if not self.testing:
|
| self.optimizers, self.lr_schedulers = self.init_optimizers(model.configure_optimizers())
|
|
|
|
|
|
|
| if self.distributed_backend == 'ddp':
|
| torch.cuda.set_device(gpu_idx)
|
| model.cuda(gpu_idx)
|
|
|
|
|
| self.copy_trainer_model_properties(model)
|
|
|
|
|
| self.root_gpu = gpu_idx
|
|
|
| if self.distributed_backend == 'ddp':
|
| device_ids = [gpu_idx]
|
| else:
|
| device_ids = None
|
|
|
|
|
| model = model.configure_ddp(model, device_ids)
|
|
|
|
|
| self.run_pretrain_routine(model)
|
|
|
| def resolve_root_node_address(self, root_node):
|
| if '[' in root_node:
|
| name = root_node.split('[')[0]
|
| number = root_node.split(',')[0]
|
| if '-' in number:
|
| number = number.split('-')[0]
|
|
|
| number = re.sub('[^0-9]', '', number)
|
| root_node = name + number
|
|
|
| return root_node
|
|
|
| def log_metrics(self, metrics, grad_norm_dic, step=None):
|
| """Logs the metric dict passed in.
|
|
|
| :param metrics:
|
| :param grad_norm_dic:
|
| """
|
|
|
| metrics['epoch'] = self.current_epoch
|
|
|
|
|
| metrics.update(grad_norm_dic)
|
|
|
|
|
| scalar_metrics = self.metrics_to_scalars(metrics)
|
|
|
| step = step if step is not None else self.global_step
|
|
|
| if self.proc_rank == 0 and self.logger is not None:
|
| self.logger.log_metrics(scalar_metrics, step=step)
|
| self.logger.save()
|
|
|
| def add_tqdm_metrics(self, metrics):
|
| for k, v in metrics.items():
|
| if type(v) is torch.Tensor:
|
| v = v.item()
|
|
|
| self.tqdm_metrics[k] = v
|
|
|
| def metrics_to_scalars(self, metrics):
|
| new_metrics = {}
|
| for k, v in metrics.items():
|
| if isinstance(v, torch.Tensor):
|
| v = v.item()
|
|
|
| if type(v) is dict:
|
| v = self.metrics_to_scalars(v)
|
|
|
| new_metrics[k] = v
|
|
|
| return new_metrics
|
|
|
| def process_output(self, output, train=False):
|
| """Reduces output according to the training mode.
|
|
|
| Separates loss from logging and tqdm metrics
|
| :param output:
|
| :return:
|
| """
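|         # Expected `output` shape (illustrative) from a *_step/*_end hook:
|         #   {'loss': Tensor, 'progress_bar': {...}, 'log': {...}, 'hiddens': ...}
|         # Every key outside 'progress_bar'/'log'/'hiddens' is treated as a callback metric.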
|
|
|
|
|
|
|
|
|
| callback_metrics = {}
|
| for k, v in output.items():
|
| if k not in ['progress_bar', 'log', 'hiddens']:
|
| callback_metrics[k] = v
|
|
|
| if train and self.use_dp:
|
| num_gpus = self.num_gpus
|
| callback_metrics = self.reduce_distributed_output(callback_metrics, num_gpus)
|
|
|
| for k, v in callback_metrics.items():
|
| if isinstance(v, torch.Tensor):
|
| callback_metrics[k] = v.item()
|
|
|
|
|
|
|
|
|
| try:
|
| progress_output = output['progress_bar']
|
|
|
|
|
| if train and self.use_dp:
|
| num_gpus = self.num_gpus
|
| progress_output = self.reduce_distributed_output(progress_output, num_gpus)
|
|
|
| progress_bar_metrics = progress_output
|
| except Exception:
|
| progress_bar_metrics = {}
|
|
|
|
|
|
|
|
|
|
|
| try:
|
| log_output = output['log']
|
|
|
|
|
| if train and self.use_dp:
|
| num_gpus = self.num_gpus
|
| log_output = self.reduce_distributed_output(log_output, num_gpus)
|
|
|
| log_metrics = log_output
|
| except Exception:
|
| log_metrics = {}
|
|
|
|
|
|
|
|
|
|
|
|
|
| loss = None
|
| if train:
|
| try:
|
| loss = output['loss']
|
| except Exception:
|
| if type(output) is torch.Tensor:
|
| loss = output
|
| else:
|
| raise RuntimeError(
|
| 'No `loss` value in the dictionary returned from `model.training_step()`.'
|
| )
|
|
|
|
|
| if self.use_dp:
|
| loss = self.reduce_distributed_output(loss, self.num_gpus)
|
|
|
|
|
|
|
|
|
| hiddens = output.get('hiddens')
|
|
|
|
|
| callback_metrics.update(progress_bar_metrics)
|
| callback_metrics.update(log_metrics)
|
|
|
|
|
| for k, v in callback_metrics.items():
|
| if isinstance(v, torch.Tensor):
|
| callback_metrics[k] = v.item()
|
|
|
| return loss, progress_bar_metrics, log_metrics, callback_metrics, hiddens
|
|
|
| def reduce_distributed_output(self, output, num_gpus):
|
| if num_gpus <= 1:
|
| return output
|
|
|
|
|
|
|
| if type(output) is torch.Tensor:
|
| return output.mean()
|
|
|
| for k, v in output.items():
|
|
|
| if isinstance(output[k], dict):
|
| output[k] = self.reduce_distributed_output(output[k], num_gpus)
|
|
|
|
|
| elif isinstance(output[k], torch.Tensor) and output[k].dim() == 0:
|
| pass
|
|
|
|
|
| elif output[k].size(0) == num_gpus:
|
| reduced = torch.mean(output[k])
|
| output[k] = reduced
|
| return output
|
|
|
| def clip_gradients(self):
|
| if self.gradient_clip_val > 0:
|
| model = self.get_model()
|
| torch.nn.utils.clip_grad_norm_(model.parameters(), self.gradient_clip_val)
|
|
|
| def print_nan_gradients(self):
|
| model = self.get_model()
|
| for param in model.parameters():
|
| if (param.grad is not None) and torch.isnan(param.grad.float()).any():
|
|                 logging.info('%s %s', param, param.grad)
|
|
|
| def configure_accumulated_gradients(self, accumulate_grad_batches):
|
| self.accumulate_grad_batches = None
|
|
|
| if isinstance(accumulate_grad_batches, dict):
|
| self.accumulation_scheduler = GradientAccumulationScheduler(accumulate_grad_batches)
|
| elif isinstance(accumulate_grad_batches, int):
|
| schedule = {1: accumulate_grad_batches}
|
| self.accumulation_scheduler = GradientAccumulationScheduler(schedule)
|
| else:
|
| raise TypeError("Gradient accumulation supports only int and dict types")
|
|
|
| def get_dataloaders(self, model):
|
| if not self.testing:
|
| self.init_train_dataloader(model)
|
| self.init_val_dataloader(model)
|
| else:
|
| self.init_test_dataloader(model)
|
|
|
| if self.use_ddp:
|
| dist.barrier()
|
| if not self.testing:
|
| self.get_train_dataloader()
|
| self.get_val_dataloaders()
|
| else:
|
| self.get_test_dataloaders()
|
|
|
| def init_train_dataloader(self, model):
|
|         self.first_epoch = True
|
| self.get_train_dataloader = model.train_dataloader
|
| if isinstance(self.get_train_dataloader(), torch.utils.data.DataLoader):
|
| self.num_training_batches = len(self.get_train_dataloader())
|
| self.num_training_batches = int(self.num_training_batches)
|
| else:
|
| self.num_training_batches = float('inf')
|
| self.is_iterable_train_dataloader = True
|
| if isinstance(self.val_check_interval, int):
|
| self.val_check_batch = self.val_check_interval
|
| else:
|
| self._percent_range_check('val_check_interval')
|
| self.val_check_batch = int(self.num_training_batches * self.val_check_interval)
|
| self.val_check_batch = max(1, self.val_check_batch)
|
|
|
| def init_val_dataloader(self, model):
|
| self.get_val_dataloaders = model.val_dataloader
|
| self.num_val_batches = 0
|
| if self.get_val_dataloaders() is not None:
|
| if isinstance(self.get_val_dataloaders()[0], torch.utils.data.DataLoader):
|
| self.num_val_batches = sum(len(dataloader) for dataloader in self.get_val_dataloaders())
|
| self.num_val_batches = int(self.num_val_batches)
|
| else:
|
| self.num_val_batches = float('inf')
|
|
|
| def init_test_dataloader(self, model):
|
| self.get_test_dataloaders = model.test_dataloader
|
| if self.get_test_dataloaders() is not None:
|
| if isinstance(self.get_test_dataloaders()[0], torch.utils.data.DataLoader):
|
| self.num_test_batches = sum(len(dataloader) for dataloader in self.get_test_dataloaders())
|
| self.num_test_batches = int(self.num_test_batches)
|
| else:
|
| self.num_test_batches = float('inf')
|
|
|
| def evaluate(self, model, dataloaders, max_batches, test=False):
|
| """Run evaluation code.
|
|
|
| :param model: PT model
|
| :param dataloaders: list of PT dataloaders
|
| :param max_batches: Scalar
|
| :param test: boolean
|
| :return:
|
| """
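|         # Flow: switch to eval mode with grads disabled, run up to `max_batches` batches from
|         # each dataloader through evaluation_forward, aggregate the outputs via
|         # model.test_end()/validation_end(), then switch back to train mode.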
|
|
|
| model.zero_grad()
|
| model.eval()
|
|
|
|
|
| self.copy_trainer_model_properties(model)
|
|
|
|
|
| torch.set_grad_enabled(False)
|
|
|
| if test:
|
| self.get_model().test_start()
|
|
|
| outputs = []
|
|
|
|
|
| for dataloader_idx, dataloader in enumerate(dataloaders):
|
| dl_outputs = []
|
| for batch_idx, batch in enumerate(dataloader):
|
|
|
| if batch is None:
|
| continue
|
|
|
|
|
| if batch_idx >= max_batches:
|
| break
|
|
|
|
|
|
|
|
|
| output = self.evaluation_forward(model,
|
| batch,
|
| batch_idx,
|
| dataloader_idx,
|
| test)
|
|
|
|
|
| dl_outputs.append(output)
|
|
|
|
|
| if test:
|
| self.test_progress_bar.update(1)
|
| else:
|
| self.val_progress_bar.update(1)
|
| outputs.append(dl_outputs)
|
|
|
|
|
| if len(dataloaders) == 1:
|
| outputs = outputs[0]
|
|
|
|
|
| model = self.get_model()
|
| if test:
|
| eval_results_ = model.test_end(outputs)
|
| else:
|
| eval_results_ = model.validation_end(outputs)
|
| eval_results = eval_results_
|
|
|
|
|
| model.train()
|
|
|
|
|
| torch.set_grad_enabled(True)
|
|
|
| return eval_results
|
|
|
| def run_evaluation(self, test=False):
|
|
|
| model = self.get_model()
|
| model.on_pre_performance_check()
|
|
|
|
|
| if test:
|
| dataloaders = self.get_test_dataloaders()
|
| max_batches = self.num_test_batches
|
| else:
|
|
|
| dataloaders = self.get_val_dataloaders()
|
| max_batches = self.num_val_batches
|
|
|
|
|
|
|
| position = 2 * self.process_position + (not test)
|
| desc = 'Testing' if test else 'Validating'
|
| pbar = tqdm.tqdm(desc=desc, total=max_batches, leave=test, position=position,
|
| disable=not self.show_progress_bar, dynamic_ncols=True,
|
| unit='batch', file=sys.stdout)
|
| setattr(self, f'{"test" if test else "val"}_progress_bar', pbar)
|
|
|
|
|
| eval_results = self.evaluate(self.model,
|
| dataloaders,
|
| max_batches,
|
| test)
|
| if eval_results is not None:
|
| _, prog_bar_metrics, log_metrics, callback_metrics, _ = self.process_output(
|
| eval_results)
|
|
|
|
|
| self.add_tqdm_metrics(prog_bar_metrics)
|
|
|
|
|
| self.log_metrics(log_metrics, {})
|
|
|
|
|
| self.callback_metrics.update(callback_metrics)
|
|
|
|
|
| model.on_post_performance_check()
|
|
|
|
|
| tqdm_metrics = self.training_tqdm_dict
|
| if not test:
|
| self.main_progress_bar.set_postfix(**tqdm_metrics)
|
|
|
|
|
| if test:
|
| self.test_progress_bar.close()
|
| else:
|
| self.val_progress_bar.close()
|
|
|
|
|
| if self.proc_rank == 0 and self.checkpoint_callback is not None and not test:
|
| self.checkpoint_callback.on_epoch_end(epoch=self.current_epoch,
|
| logs=self.callback_metrics)
|
|
|
| def evaluation_forward(self, model, batch, batch_idx, dataloader_idx, test=False):
|
|
|
| args = [batch, batch_idx]
|
|
|
| if test and len(self.get_test_dataloaders()) > 1:
|
| args.append(dataloader_idx)
|
|
|
| elif not test and len(self.get_val_dataloaders()) > 1:
|
| args.append(dataloader_idx)
|
|
|
|
|
| if self.use_ddp or self.use_dp:
|
| output = model(*args)
|
| return output
|
|
|
|
|
| if self.single_gpu:
|
|
|
| root_gpu = 0
|
| if isinstance(self.data_parallel_device_ids, list):
|
| root_gpu = self.data_parallel_device_ids[0]
|
| batch = self.transfer_batch_to_gpu(batch, root_gpu)
|
| args[0] = batch
|
|
|
|
|
| if test:
|
| output = model.test_step(*args)
|
| else:
|
| output = model.validation_step(*args)
|
|
|
| return output
|
|
|
| def train(self):
|
| model = self.get_model()
|
|
|
| for epoch in range(self.current_epoch, 1000000):
|
|
|
| if self.use_ddp and hasattr(self.get_train_dataloader().sampler, 'set_epoch'):
|
| self.get_train_dataloader().sampler.set_epoch(epoch)
|
|
|
|
|
| model = self.get_model()
|
|
|
|
|
| model.current_epoch = epoch
|
| self.current_epoch = epoch
|
|
|
| total_val_batches = 0
|
| if not self.disable_validation:
|
|
|
| is_val_epoch = (self.current_epoch + 1) % self.check_val_every_n_epoch == 0
|
| val_checks_per_epoch = self.num_training_batches // self.val_check_batch
|
| val_checks_per_epoch = val_checks_per_epoch if is_val_epoch else 0
|
| total_val_batches = self.num_val_batches * val_checks_per_epoch
|
|
|
|
|
| self.total_batches = self.num_training_batches + total_val_batches
|
| self.batch_loss_value = 0
|
|
|
| if self.is_iterable_train_dataloader:
|
|
|
| num_iterations = None
|
| else:
|
| num_iterations = self.total_batches
|
|
|
|
|
|
|
| desc = f'Epoch {epoch + 1}' if not self.is_iterable_train_dataloader else ''
|
| self.main_progress_bar.set_description(desc)
|
|
|
|
|
| self.accumulation_scheduler.on_epoch_begin(epoch, self)
|
|
|
|
|
|
|
|
|
| self.run_training_epoch()
|
|
|
|
|
| if self.lr_schedulers is not None:
|
| for lr_scheduler in self.lr_schedulers:
|
| lr_scheduler.step(epoch=self.current_epoch)
|
|
|
| self.main_progress_bar.close()
|
|
|
| model.on_train_end()
|
|
|
| if self.logger is not None:
|
| self.logger.finalize("success")
|
|
|
| def run_training_epoch(self):
|
|
|
| if self.is_function_implemented('on_epoch_start'):
|
| model = self.get_model()
|
| model.on_epoch_start()
|
|
|
|
|
| for batch_idx, batch in enumerate(self.get_train_dataloader()):
|
|
|
| if batch_idx >= self.num_training_batches:
|
| break
|
|
|
| self.batch_idx = batch_idx
|
|
|
| model = self.get_model()
|
| model.global_step = self.global_step
|
|
|
|
|
|
|
|
|
| output = self.run_training_batch(batch, batch_idx)
|
| batch_result, grad_norm_dic, batch_step_metrics = output
|
|
|
|
|
| early_stop_epoch = batch_result == -1
|
|
|
|
|
|
|
|
|
| should_check_val = (
|
|                 not self.disable_validation and self.global_step % self.val_check_batch == 0 and not self.first_epoch)
|
|             self.first_epoch = False
|
|
|
| if should_check_val:
|
| self.run_evaluation(test=self.testing)
|
|
|
|
|
| should_save_log = (batch_idx + 1) % self.log_save_interval == 0 or early_stop_epoch
|
| if should_save_log:
|
| if self.proc_rank == 0 and self.logger is not None:
|
| self.logger.save()
|
|
|
|
|
| should_log_metrics = batch_idx % self.row_log_interval == 0 or early_stop_epoch
|
| if should_log_metrics:
|
|
|
| self.log_metrics(batch_step_metrics, grad_norm_dic)
|
|
|
| self.global_step += 1
|
| self.total_batch_idx += 1
|
|
|
|
|
|
|
|
|
| if early_stop_epoch:
|
| break
|
| if self.global_step > self.max_updates:
|
| print("| Training end..")
|
| exit()
|
|
|
|
|
| if self.is_function_implemented('on_epoch_end'):
|
| model = self.get_model()
|
| model.on_epoch_end()
|
|
|
| def run_training_batch(self, batch, batch_idx):
|
|
|
| grad_norm_dic = {}
|
|
|
|
|
| all_callback_metrics = []
|
|
|
|
|
| all_log_metrics = []
|
|
|
| if batch is None:
|
| return 0, grad_norm_dic, {}
|
|
|
|
|
| if self.is_function_implemented('on_batch_start'):
|
| model_ref = self.get_model()
|
| response = model_ref.on_batch_start(batch)
|
|
|
| if response == -1:
|
| return -1, grad_norm_dic, {}
|
|
|
| splits = [batch]
|
| self.hiddens = None
|
| for split_idx, split_batch in enumerate(splits):
|
| self.split_idx = split_idx
|
|
|
|
|
| for opt_idx, optimizer in enumerate(self.optimizers):
|
| if optimizer is None:
|
| continue
|
|
|
|
|
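|                 # With multiple optimizers, freeze every parameter and re-enable only the
|                 # groups owned by the optimizer currently being stepped, so each optimizer
|                 # updates only its own sub-network.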
| if len(self.optimizers) > 1:
|
| for param in self.get_model().parameters():
|
| param.requires_grad = False
|
| for group in optimizer.param_groups:
|
| for param in group['params']:
|
| param.requires_grad = True
|
|
|
|
|
| def optimizer_closure():
|
|
|
| output = self.training_forward(
|
| split_batch, batch_idx, opt_idx, self.hiddens)
|
|
|
| closure_loss = output[0]
|
| progress_bar_metrics = output[1]
|
| log_metrics = output[2]
|
| callback_metrics = output[3]
|
| self.hiddens = output[4]
|
| if closure_loss is None:
|
| return None
|
|
|
|
|
|
|
| closure_loss = closure_loss / self.accumulate_grad_batches
|
|
|
|
|
| model_ref = self.get_model()
|
| if closure_loss.requires_grad:
|
| model_ref.backward(closure_loss, optimizer)
|
|
|
|
|
| all_callback_metrics.append(callback_metrics)
|
|
|
|
|
| self.add_tqdm_metrics(progress_bar_metrics)
|
| all_log_metrics.append(log_metrics)
|
|
|
|
|
| if self.is_function_implemented('on_after_backward'):
|
| model_ref = self.get_model()
|
| model_ref.on_after_backward()
|
|
|
| return closure_loss
|
|
|
|
|
| loss = optimizer_closure()
|
| if loss is None:
|
| continue
|
|
|
|
|
| if self.print_nan_grads:
|
| self.print_nan_gradients()
|
|
|
|
|
| self.batch_loss_value += loss.item()
|
|
|
|
|
| if (self.batch_idx + 1) % self.accumulate_grad_batches == 0:
|
|
|
|
|
| if batch_idx % self.row_log_interval == 0:
|
| if self.track_grad_norm > 0:
|
| model = self.get_model()
|
| grad_norm_dic = model.grad_norm(
|
| self.track_grad_norm)
|
|
|
|
|
| self.clip_gradients()
|
|
|
|
|
|
|
| model = self.get_model()
|
| model.optimizer_step(self.current_epoch, batch_idx, optimizer, opt_idx)
|
|
|
|
|
| self.running_loss.append(self.batch_loss_value)
|
| self.batch_loss_value = 0
|
| self.avg_loss = np.mean(self.running_loss[-100:])
|
|
|
|
|
| if self.is_function_implemented('on_batch_end'):
|
| model = self.get_model()
|
| model.on_batch_end()
|
|
|
|
|
| self.main_progress_bar.update(1)
|
| self.main_progress_bar.set_postfix(**self.training_tqdm_dict)
|
|
|
|
|
| all_log_metrics = {k: v for d in all_log_metrics for k, v in d.items()}
|
|
|
|
|
| self.callback_metrics.update({k: v for d in all_callback_metrics for k, v in d.items()})
|
|
|
| return 0, grad_norm_dic, all_log_metrics
|
|
|
| def training_forward(self, batch, batch_idx, opt_idx, hiddens):
|
| """
|
| Handle forward for each training case (distributed, single gpu, etc...)
|
| :param batch:
|
| :param batch_idx:
|
| :return:
|
| """
|
|
|
|
|
|
|
|
|
| args = [batch, batch_idx, opt_idx]
|
|
|
|
|
| if self.use_ddp or self.use_dp:
|
| output = self.model(*args)
|
|
|
| elif self.single_gpu:
|
| gpu_id = 0
|
| if isinstance(self.data_parallel_device_ids, list):
|
| gpu_id = self.data_parallel_device_ids[0]
|
| batch = self.transfer_batch_to_gpu(copy.copy(batch), gpu_id)
|
| args[0] = batch
|
| output = self.model.training_step(*args)
|
|
|
| else:
|
| output = self.model.training_step(*args)
|
|
|
|
|
| model_ref = self.get_model()
|
| output_ = model_ref.training_end(output)
|
| if output_ is not None:
|
| output = output_
|
|
|
|
|
| output = self.process_output(output, train=True)
|
|
|
| return output
|
|
|
|
|
|
|
|
|
| def is_function_implemented(self, f_name):
|
| model = self.get_model()
|
| f_op = getattr(model, f_name, None)
|
| return callable(f_op)
|
|
|
| def _percent_range_check(self, name):
|
| value = getattr(self, name)
|
| msg = f"`{name}` must lie in the range [0.0, 1.0], but got {value:.3f}."
|
| if name == "val_check_interval":
|
| msg += " If you want to disable validation set `val_percent_check` to 0.0 instead."
|
|
|
| if not 0. <= value <= 1.:
|
| raise ValueError(msg)
|
|
|