| """ |
| This is the loadable seq2seq trainer library that is |
| in charge of training details, loss compute, and statistics. |
| See train.py for a use case of this library. |
| |
| Note: To make this a general library, we implement *only* |
| mechanism things here(i.e. what to do), and leave the strategy |
| things to users(i.e. how to do it). Also see train.py(one of the |
| users of this library) for the strategy things we do. |
| """ |
|
|
| import time |
| import sys |
| import torch |
| import traceback |
| import onmt.utils |
| from onmt.utils.loss import LossCompute |
| from onmt.utils.logging import logger |
| from onmt.utils.scoring_utils import ScoringPreparator |
| from onmt.scorers import get_scorers_cls, build_scorers |
|
|
|
|
def build_trainer(opt, device_id, model, vocabs, optim, model_saver=None):
    """
    Simplify `Trainer` creation based on user `opt`s.

    Args:
        opt (:obj:`Namespace`): user options (usually from argument parsing)
        device_id (int): id of the GPU to use (< 0 means CPU)
        model (:obj:`onmt.models.NMTModel`): the model to train
        vocabs (dict): dict of source and target vocabularies
        optim (:obj:`onmt.utils.Optimizer`): optimizer used during training
        model_saver (:obj:`onmt.models.ModelSaverBase`): the utility object
            used to save the model
    """

    train_loss = LossCompute.from_opts(opt, model, vocabs["tgt"])
    valid_loss = LossCompute.from_opts(opt, model, vocabs["tgt"], train=False)

    scoring_preparator = ScoringPreparator(vocabs=vocabs, opt=opt)
    validset_transforms = opt.data.get("valid", {}).get("transforms", None)
    if validset_transforms:
        scoring_preparator.warm_up(validset_transforms)
    scorers_cls = get_scorers_cls(opt.valid_metrics)
    valid_scorers = build_scorers(opt, scorers_cls)

    trunc_size = opt.truncated_decoder
    norm_method = opt.normalization
    accum_count = opt.accum_count
    accum_steps = opt.accum_steps
    n_gpu = opt.world_size
    parallel_mode = opt.parallel_mode
    average_decay = opt.average_decay
    average_every = opt.average_every
    dropout = opt.dropout
    attention_dropout = opt.attention_dropout
    dropout_steps = opt.dropout_steps
    zero_out_prompt_loss = opt.zero_out_prompt_loss
    if device_id >= 0:
        gpu_rank = opt.gpu_ranks[device_id]
    else:
        gpu_rank = -1
        n_gpu = 0

    earlystopper = (
        onmt.utils.EarlyStopping(
            opt.early_stopping, scorers=onmt.utils.scorers_from_opts(opt)
        )
        if opt.early_stopping > 0
        else None
    )

    report_manager = onmt.utils.build_report_manager(opt, gpu_rank)
    trainer = Trainer(
        model,
        train_loss,
        valid_loss,
        scoring_preparator,
        valid_scorers,
        optim,
        trunc_size,
        norm_method,
        accum_count,
        accum_steps,
        n_gpu,
        gpu_rank,
        parallel_mode,
        report_manager,
        with_align=opt.lambda_align > 0,
        model_saver=model_saver,
        average_decay=average_decay,
        average_every=average_every,
        model_dtype=opt.model_dtype,
        earlystopper=earlystopper,
        dropout=dropout,
        attention_dropout=attention_dropout,
        dropout_steps=dropout_steps,
        zero_out_prompt_loss=zero_out_prompt_loss,
    )
    return trainer


class Trainer(object):
    """Class that controls the training process.

    Args:
        model (:py:class:`onmt.models.model.NMTModel`): model to train
        train_loss (:obj:`onmt.utils.loss.LossComputeBase`):
            training loss computation
        valid_loss (:obj:`onmt.utils.loss.LossComputeBase`):
            validation loss computation
        scoring_preparator (:obj:`onmt.translate.utils.ScoringPreparator`):
            preparator for the calculation of metrics via the
            _eval_handler method
        valid_scorers (dict): keeps in memory the current values
            of the validation metrics
        optim (:obj:`onmt.utils.optimizers.Optimizer`):
            the optimizer responsible for parameter updates
        trunc_size (int): length of truncated back propagation
            through time
        norm_method (str): loss normalization method,
            "sents" or "tokens"
        accum_count (list): accumulate gradients this many times
        accum_steps (list): steps at which the accumulation count changes
        n_gpu (int): number of GPUs
        gpu_rank (int): ordinal rank of the GPU in the list
        parallel_mode (str): "data_parallel" or "tensor_parallel"
        report_manager (:obj:`onmt.utils.ReportMgrBase`):
            the object that creates reports, or None
        with_align (bool): whether to jointly learn alignment
            (Transformer)
        model_saver (:obj:`onmt.models.ModelSaverBase`): the saver is
            used to save a checkpoint;
            nothing will be saved if this parameter is None
        average_decay (float): cf opt.average_decay
        average_every (int): average the model every x steps
        model_dtype (str): fp32 or fp16
        earlystopper (:obj:`onmt.utils.EarlyStopping`): adds an early
            stopping mechanism
        dropout (float): dropout value in RNN or FF layers
        attention_dropout (float): dropout in attention layers
        dropout_steps (list): dropout values scheduling in steps
        zero_out_prompt_loss (bool): whether to zero-out the prompt loss
            (mostly for LLM finetuning)."""
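
    # Shape of `valid_scorers` as consumed by validate() (illustrative;
    # the actual metric names come from opt.valid_metrics):
    #     {"BLEU": {"scorer": <scorer object>, "value": None}, ...}
    # validate() refreshes each entry's "value" on every run.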

    def __init__(
        self,
        model,
        train_loss,
        valid_loss,
        scoring_preparator,
        valid_scorers,
        optim,
        trunc_size=0,
        norm_method="sents",
        accum_count=[1],
        accum_steps=[0],
        n_gpu=1,
        gpu_rank=1,
        parallel_mode="data_parallel",
        report_manager=None,
        with_align=False,
        model_saver=None,
        average_decay=0,
        average_every=1,
        model_dtype="fp32",
        earlystopper=None,
        dropout=[0.3],
        attention_dropout=[0.1],
        dropout_steps=[0],
        zero_out_prompt_loss=False,
    ):
        self.model = model
        self.train_loss = train_loss
        self.valid_loss = valid_loss
        self.scoring_preparator = scoring_preparator
        self.valid_scorers = valid_scorers
        self.optim = optim
        self.trunc_size = trunc_size
        self.norm_method = norm_method
        self.accum_count_l = accum_count
        self.accum_count = accum_count[0]
        self.accum_steps = accum_steps
        self.n_gpu = n_gpu
        self.gpu_rank = gpu_rank
        self.parallel_mode = parallel_mode
        self.report_manager = report_manager
        self.with_align = with_align
        self.model_saver = model_saver
        self.average_decay = average_decay
        self.moving_average = None
        self.average_every = average_every
        self.model_dtype = model_dtype
        self.earlystopper = earlystopper
        self.dropout = dropout
        self.attention_dropout = attention_dropout
        self.dropout_steps = dropout_steps
        self.zero_out_prompt_loss = zero_out_prompt_loss

        for i in range(len(self.accum_count_l)):
            assert self.accum_count_l[i] > 0

        # Set model in training mode.
        self.model.train()

    def _eval_handler(self, scorer, preds, texts_ref):
        """Trigger metrics calculations.

        Args:
            scorer (:obj:`onmt.scorers.Scorer`): scorer.
            preds, texts_ref: outputs of the
                `ScoringPreparator.translate` method.

        Returns:
            The metric calculated by the scorer."""

        return scorer.compute_score(preds, texts_ref)
    def _accum_count(self, step):
        """Return the number of batches to accumulate at this `step`."""
        # Default to the first entry so the value is defined even when
        # `step` has not yet passed any entry in accum_steps.
        _accum = self.accum_count_l[0]
        for i in range(len(self.accum_steps)):
            if step > self.accum_steps[i]:
                _accum = self.accum_count_l[i]
        return _accum
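
    # Schedule illustration for _accum_count (hypothetical values, not
    # the defaults): with accum_count=[1, 4] and accum_steps=[0, 8000],
    # steps 1..8000 accumulate one batch per optimizer update and every
    # later step accumulates four, since the last accum_steps entry that
    # `step` has passed wins.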

    def _maybe_update_dropout(self, step):
        """Update dropout rates when entering a new scheduling window."""
        for i in range(len(self.dropout_steps)):
            if step > 1 and step == self.dropout_steps[i] + 1:
                self.model.update_dropout(self.dropout[i], self.attention_dropout[i])
                logger.info(
                    "Updated dropout/attn dropout to %f %f at step %d"
                    % (self.dropout[i], self.attention_dropout[i], step)
                )
    def _accum_batches(self, iterator):
        """Group batches from `iterator` into lists of `accum_count`
        batches, yielding each group together with its loss normalization
        factor (number of target tokens or of sentences)."""
        batches = []
        normalization = 0
        self.accum_count = self._accum_count(self.optim.training_step)
        for batch in iterator:
            batches.append(batch)
            if self.norm_method == "tokens":
                # Count non-padding target tokens; the slice [:, 1:, 0]
                # already skips the initial BOS position.
                num_tokens = (
                    batch["tgt"][:, 1:, 0].ne(self.train_loss.padding_idx).sum()
                )
                normalization += num_tokens.item()
            else:
                normalization += len(batch["tgt"])
            if len(batches) == self.accum_count:
                yield batches, normalization
                self.accum_count = self._accum_count(self.optim.training_step)
                batches = []
                normalization = 0
        if batches:
            yield batches, normalization
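
    # Normalization illustration for _accum_batches (toy tensor, not part
    # of the API): with norm_method == "tokens" and a target batch
    #     [[BOS, a, b, EOS, PAD],
    #      [BOS, c, EOS, PAD, PAD]]
    # the slice [:, 1:, 0] keeps "a b EOS PAD" and "c EOS PAD PAD", the
    # padding mask leaves 3 + 2 = 5 tokens, and the summed loss is later
    # divided by 5; with norm_method == "sents" the divisor would be 2.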

    def _update_average(self, step):
        """Update the exponential moving average of the model parameters."""
        if self.moving_average is None:
            copy_params = [
                params.detach().float() for params in self.model.parameters()
            ]
            self.moving_average = copy_params
        else:
            average_decay = max(self.average_decay, 1 - (step + 1) / (step + 10))
            for (i, avg), cpt in zip(
                enumerate(self.moving_average), self.model.parameters()
            ):
                self.moving_average[i] = (
                    1 - average_decay
                ) * avg + cpt.detach().float() * average_decay
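
    # Decay schedule illustration (values derived from the formula in
    # _update_average): the dynamic term 1 - (step + 1) / (step + 10) is
    # 0.9 at step 0, 0.45 at step 10 and ~0.09 at step 90, so the average
    # initially tracks the raw weights closely; once that term drops below
    # `average_decay`, each update moves the average by a constant
    # `average_decay` fraction toward the new parameter values.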

    def train(
        self,
        train_iter,
        train_steps,
        save_checkpoint_steps=5000,
        valid_iter=None,
        valid_steps=10000,
    ):
        """The main training loop, iterating over ``train_iter`` and possibly
        running validation on ``valid_iter``.

        Args:
            train_iter: An iterator that returns the next training batch.
            train_steps: Run training for this many iterations.
            save_checkpoint_steps: Save a checkpoint every this many
                iterations.
            valid_iter: A generator that returns the next validation batch.
            valid_steps: Run evaluation every this many iterations.

        Returns:
            :obj:`onmt.utils.Statistics`: training loss statistics"""

        if valid_iter is None:
            logger.info("Start training loop without validation...")
            valid_stats = None
        else:
            logger.info(
                "Start training loop and validate every %d steps...", valid_steps
            )
            logger.info("Scoring with: {}".format(self.scoring_preparator.transform))

        total_stats = onmt.utils.Statistics()
        report_stats = onmt.utils.Statistics()
        self._start_report_manager(start_time=total_stats.start_time)

        torch.cuda.empty_cache()
        for i, (batches, normalization) in enumerate(self._accum_batches(train_iter)):
            step = self.optim.training_step
            # Update dropout according to the schedule, if needed.
            self._maybe_update_dropout(step)

            if self.n_gpu > 1 and self.parallel_mode == "data_parallel":
                normalization = sum(
                    onmt.utils.distributed.all_gather_list(normalization)
                )

            self._gradient_accumulation(
                batches, normalization, total_stats, report_stats
            )

            if self.average_decay > 0 and i % self.average_every == 0:
                self._update_average(step)

            report_stats = self._maybe_report_training(
                step, train_steps, self.optim.learning_rate(), report_stats
            )

            if valid_iter is not None and step % valid_steps == 0:
                valid_stats = self.validate(
                    valid_iter, moving_average=self.moving_average
                )

            if step % valid_steps == 0 and self.gpu_rank <= 0:
                self._report_step(
                    self.optim.learning_rate(),
                    step,
                    valid_stats=valid_stats,
                    train_stats=total_stats,
                )

                # Run the patience mechanism (only meaningful when
                # validation statistics are available).
                if self.earlystopper is not None and valid_stats is not None:
                    self.earlystopper(valid_stats, step)
                    # If the patience has reached the limit, stop training.
                    if self.earlystopper.has_stopped():
                        logger.info("earlystopper has_stopped!")
                        break

            if self.model_saver is not None and (
                save_checkpoint_steps != 0 and step % save_checkpoint_steps == 0
            ):
                self.model_saver.save(step, moving_average=self.moving_average)

            if train_steps > 0 and step >= train_steps:
                break

        if self.model_saver is not None:
            self.model_saver.save(step, moving_average=self.moving_average)
        return total_stats
    def validate(self, valid_iter, moving_average=None):
        """Validate model.

        Args:
            valid_iter: validation data iterator

        Returns:
            :obj:`onmt.utils.Statistics`: validation loss statistics"""

        valid_model = self.model
        if moving_average:
            # Swap the moving-average weights in, keeping the raw
            # parameters aside so they can be restored afterwards.
            model_params_data = []
            for avg, param in zip(self.moving_average, valid_model.parameters()):
                model_params_data.append(param.data)
                param.data = (
                    avg.data.half() if param.dtype == torch.float16 else avg.data
                )

        # Set model in validating mode.
        valid_model.eval()

        with torch.no_grad():
            stats = onmt.utils.Statistics()
            start = time.time()
            for batch in valid_iter:
                src = batch["src"]
                src_len = batch["srclen"]
                tgt = batch["tgt"]

                with torch.cuda.amp.autocast(enabled=self.optim.amp):
                    # F-prop through the model.
                    model_out, attns = valid_model(
                        src, tgt, src_len, with_align=self.with_align
                    )

                    # Compute loss.
                    _, batch_stats = self.valid_loss(batch, model_out, attns)

                # Update statistics.
                stats.update(batch_stats)
            logger.info(
                "valid stats calculation took: {} s.".format(time.time() - start)
            )

            # Compute validation metrics (e.g. BLEU) if requested.
            if len(self.valid_scorers) > 0:
                computed_metrics = {}
                start = time.time()
                preds, texts_ref = self.scoring_preparator.translate(
                    model=self.model,
                    gpu_rank=self.gpu_rank,
                    step=self.optim.training_step,
                )
                logger.info(
                    "The translation of the valid dataset for dynamic scoring "
                    "took: {} s.".format(time.time() - start)
                )
                for metric in self.valid_scorers:
                    logger.info("UPDATING VALIDATION {}".format(metric))
                    self.valid_scorers[metric]["value"] = self._eval_handler(
                        scorer=self.valid_scorers[metric]["scorer"],
                        preds=preds,
                        texts_ref=texts_ref,
                    )
                    computed_metrics[metric] = self.valid_scorers[metric]["value"]
                    logger.info(
                        "validation {}: {}".format(
                            metric, self.valid_scorers[metric]["value"]
                        )
                    )

                metric_stats = onmt.utils.Statistics(0, 0, 0, 0, 0, computed_metrics)

                # Update statistics with the computed metrics.
                stats.update(metric_stats)

        if moving_average:
            # Restore the original parameters.
            for param_data, param in zip(model_params_data, self.model.parameters()):
                param.data = param_data

        # Set model back to training mode.
        valid_model.train()

        return stats
    def _gradient_accumulation(
        self, true_batches, normalization, total_stats, report_stats
    ):
        """Iterate over the sub-batches of a big batch (``true_batches``),
        perform a backward pass on the loss of each sub-batch, and
        update the parameters once at the end of the big batch."""

        if self.accum_count > 1:
            self.optim.zero_grad(set_to_none=True)
        for k, batch in enumerate(true_batches):
            target_size = batch["tgt"].size(1)
            # Truncated BPTT: reminder, not compatible with accum > 1.
            if self.trunc_size:
                trunc_size = self.trunc_size
            else:
                trunc_size = target_size

            src = batch["src"]
            src_len = batch["srclen"]
            if src_len is not None:
                report_stats.n_src_words += src_len.sum().item()
                total_stats.n_src_words += src_len.sum().item()

            tgt_outer = batch["tgt"]

            bptt = False
            for j in range(0, target_size - 1, trunc_size):
                # 1. Create truncated target.
                tgt = tgt_outer[:, j : j + trunc_size, :]

                # 2. F-prop all but generator.
                if self.accum_count == 1:
                    self.optim.zero_grad(set_to_none=True)
                try:
                    with torch.cuda.amp.autocast(enabled=self.optim.amp):
                        model_out, attns = self.model(
                            src, tgt, src_len, bptt=bptt, with_align=self.with_align
                        )
                        bptt = True

                        # 3. Compute loss.
                        if self.zero_out_prompt_loss:
                            # The loss on the prompt is set to zero.
                            batch = self.train_loss.ignore_prompt(batch)
                        loss, batch_stats = self.train_loss(
                            batch,
                            model_out,
                            attns,
                            trunc_start=j,
                            trunc_size=trunc_size,
                        )
                    if loss is not None:
                        loss /= normalization
                        self.optim.backward(loss)

                    total_stats.update(batch_stats)
                    report_stats.update(batch_stats)

                except Exception as exc:
                    trace_content = traceback.format_exc()
                    if "CUDA out of memory" in trace_content:
                        logger.info(
                            "Step %d, cuda OOM - batch removed",
                            self.optim.training_step,
                        )
                        torch.cuda.empty_cache()
                        if self.n_gpu > 1 and self.parallel_mode == "tensor_parallel":
                            torch.distributed.destroy_process_group()
                            sys.exit()
                    else:
                        traceback.print_exc()
                        raise exc

                # If truncated, don't backprop fully.
                if self.model.decoder.state != {}:
                    self.model.decoder.detach_state()

        # 4. Update the parameters and statistics; in multi-GPU
        # data-parallel mode, gradients are all-reduced across workers
        # before the step.
        if self.n_gpu > 1 and self.parallel_mode == "data_parallel":
            grads = [
                p.grad.data
                for p in self.model.parameters()
                if p.requires_grad and p.grad is not None
            ]
            onmt.utils.distributed.all_reduce_and_rescale_tensors(
                grads, float(self.n_gpu)
            )

        self.optim.step()
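
    # Truncation illustration (derived from the loop in
    # _gradient_accumulation): with target_size = 50 and trunc_size = 20,
    # j takes the values 0, 20 and 40, each forward/backward sees
    # tgt[:, j:j + 20, :], and detach_state() cuts the graph between
    # windows so back-propagation never spans more than 20 target
    # positions.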

    def _start_report_manager(self, start_time=None):
        """Simple function to start report manager (if any)."""
        if self.report_manager is not None:
            if start_time is None:
                self.report_manager.start()
            else:
                self.report_manager.start_time = start_time

    def _maybe_report_training(self, step, num_steps, learning_rate, report_stats):
        """Simple function to report training stats (if report_manager is set);
        see ``onmt.utils.ReportManagerBase.report_training`` for doc."""
        if self.report_manager is not None:
            return self.report_manager.report_training(
                step,
                num_steps,
                learning_rate,
                None
                if self.earlystopper is None
                else self.earlystopper.current_tolerance,
                report_stats,
                multigpu=self.n_gpu > 1 and self.parallel_mode == "data_parallel",
            )

    def _report_step(self, learning_rate, step, valid_stats=None, train_stats=None):
        """Simple function to report stats (if report_manager is set);
        see ``onmt.utils.ReportManagerBase.report_step`` for doc."""
        if self.report_manager is not None:
            return self.report_manager.report_step(
                learning_rate,
                None
                if self.earlystopper is None
                else self.earlystopper.current_tolerance,
                step,
                valid_stats=valid_stats,
                train_stats=train_stats,
            )