| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
|
|
| from typing import Any, Optional, Union |
|
|
| import torch |
| from torch import nn |
| from torch.utils.data import DistributedSampler, RandomSampler |
|
|
| from transformers import PreTrainedModel, Trainer, logging |
| from transformers.models.fsmt.configuration_fsmt import FSMTConfig |
| from transformers.optimization import ( |
| Adafactor, |
| get_constant_schedule, |
| get_constant_schedule_with_warmup, |
| get_cosine_schedule_with_warmup, |
| get_cosine_with_hard_restarts_schedule_with_warmup, |
| get_linear_schedule_with_warmup, |
| get_polynomial_decay_schedule_with_warmup, |
| ) |
| from transformers.trainer_pt_utils import get_tpu_sampler |
| from transformers.training_args import ParallelMode |
| from transformers.utils import is_torch_xla_available |
|
|
|
|
logger = logging.get_logger(__name__)

# Lookup table from the value of the custom ``--lr_scheduler`` training argument to the
# ``transformers.optimization`` factory that builds the corresponding learning-rate schedule.
# Insertion order (and therefore iteration order) matches the historical list of choices.
arg_to_scheduler = dict(
    linear=get_linear_schedule_with_warmup,
    cosine=get_cosine_schedule_with_warmup,
    cosine_w_restarts=get_cosine_with_hard_restarts_schedule_with_warmup,
    polynomial=get_polynomial_decay_schedule_with_warmup,
    constant=get_constant_schedule,
    constant_w_warmup=get_constant_schedule_with_warmup,
)
|
|
|
|
class Seq2SeqTrainer(Trainer):
    """
    :class:`~transformers.Trainer` subclass for sequence-to-sequence training and evaluation.

    On top of the base trainer it supports an optional label-smoothed loss (``args.label_smoothing``), an
    Adafactor/AdamW optimizer with a scheduler selected by ``args.lr_scheduler``, an optional dataset-provided
    sortish training sampler (``args.sortish_sampler``) and generation-based evaluation
    (``args.predict_with_generate``).
    """

    def __init__(self, config=None, data_args=None, *args, **kwargs):
        """
        Args:
            config: Model configuration. If ``None``, ``self.model.config`` is used, which requires the model to
                be a :class:`~transformers.PreTrainedModel`.
            data_args: Optional data arguments. Attributes read by this class: ``ignore_pad_token_for_loss``,
                ``val_max_target_length`` and ``eval_beams``.
            *args, **kwargs: Forwarded unchanged to :class:`~transformers.Trainer`.
        """
        super().__init__(*args, **kwargs)

        if config is None:
            assert isinstance(self.model, PreTrainedModel), (
                "If no `config` is passed the model to be trained has to be of type `PreTrainedModel`, but is"
                f" {self.model.__class__}"
            )
            self.config = self.model.config
        else:
            self.config = config

        self.data_args = data_args
        # FSMT configs expose the target-side vocabulary size as ``tgt_vocab_size``.
        self.vocab_size = self.config.tgt_vocab_size if isinstance(self.config, FSMTConfig) else self.config.vocab_size

        # Both the manual cross-entropy path and the label-smoothed path mask positions using ``pad_token_id``,
        # so it must be defined whenever either is active.
        if self.args.label_smoothing != 0 or (self.data_args is not None and self.data_args.ignore_pad_token_for_loss):
            assert self.config.pad_token_id is not None, (
                "Make sure that `config.pad_token_id` is correctly defined when ignoring `pad_token` for loss"
                " calculation or doing label smoothing."
            )

        if self.config.pad_token_id is None and self.config.eos_token_id is not None:
            logger.warning(
                f"The `config.pad_token_id` is `None`. Using `config.eos_token_id` = {self.config.eos_token_id} for"
                " padding.."
            )

        if self.args.label_smoothing == 0:
            self.loss_fn = torch.nn.CrossEntropyLoss(ignore_index=self.config.pad_token_id)
        else:
            # Imported lazily: the ``utils`` module is only needed when label smoothing is enabled.
            from utils import label_smoothed_nll_loss

            self.loss_fn = label_smoothed_nll_loss

    def create_optimizer_and_scheduler(self, num_training_steps: int):
        """
        Setup the optimizer and the learning rate scheduler.

        Biases and ``LayerNorm.weight`` parameters are excluded from weight decay. Adafactor is used when
        ``args.adafactor`` is set, otherwise ``torch.optim.AdamW``. If you want to use something else, pass a tuple
        in the Trainer's init through :obj:`optimizers`, or override this method in a subclass.

        Args:
            num_training_steps (:obj:`int`): Total number of optimization steps, forwarded to the scheduler.
        """
        if self.optimizer is None:
            no_decay = ["bias", "LayerNorm.weight"]
            optimizer_grouped_parameters = [
                {
                    "params": [p for n, p in self.model.named_parameters() if not any(nd in n for nd in no_decay)],
                    "weight_decay": self.args.weight_decay,
                },
                {
                    "params": [p for n, p in self.model.named_parameters() if any(nd in n for nd in no_decay)],
                    "weight_decay": 0.0,
                },
            ]
            if self.args.adafactor:
                optimizer_cls = Adafactor
                # Use the externally supplied learning rate rather than Adafactor's relative step size.
                optimizer_kwargs = {"scale_parameter": False, "relative_step": False}
            else:
                optimizer_cls = torch.optim.AdamW
                optimizer_kwargs = {
                    "betas": (self.args.adam_beta1, self.args.adam_beta2),
                    "eps": self.args.adam_epsilon,
                }
            optimizer_kwargs["lr"] = self.args.learning_rate
            self.optimizer = optimizer_cls(optimizer_grouped_parameters, **optimizer_kwargs)

        if self.lr_scheduler is None:
            self.lr_scheduler = self._get_lr_scheduler(num_training_steps)
        else:
            logger.warning("scheduler is passed to `Seq2SeqTrainer`, `--lr_scheduler` arg is ignored.")

    def _get_lr_scheduler(self, num_training_steps):
        """
        Build the scheduler named by ``args.lr_scheduler`` for ``self.optimizer``.

        Raises:
            KeyError: if ``args.lr_scheduler`` is not a key of ``arg_to_scheduler``.
        """
        schedule_func = arg_to_scheduler[self.args.lr_scheduler]
        # The factories take different arguments: "constant" takes no warmup, "constant_w_warmup" takes no total
        # step count, and all others take both.
        if self.args.lr_scheduler == "constant":
            scheduler = schedule_func(self.optimizer)
        elif self.args.lr_scheduler == "constant_w_warmup":
            scheduler = schedule_func(self.optimizer, num_warmup_steps=self.args.warmup_steps)
        else:
            scheduler = schedule_func(
                self.optimizer, num_warmup_steps=self.args.warmup_steps, num_training_steps=num_training_steps
            )
        return scheduler

    def _get_train_sampler(self, train_dataset=None) -> Optional[torch.utils.data.Sampler]:
        """
        Choose the sampler for the training dataloader.

        Args:
            train_dataset: Dataset to sample from; defaults to ``self.train_dataset``.

        Returns:
            ``None`` for iterable datasets, a TPU sampler when XLA is available, the dataset's sortish sampler when
            ``args.sortish_sampler`` is set (and the dataset returns one), otherwise a ``RandomSampler`` (single
            process, ``local_rank == -1``) or a ``DistributedSampler``.
        """
        dataset = self.train_dataset if train_dataset is None else train_dataset

        if isinstance(dataset, torch.utils.data.IterableDataset):
            return None
        if is_torch_xla_available():
            return get_tpu_sampler(dataset)

        if self.args.sortish_sampler:
            # The sampler built by the dataset must actually be returned, otherwise the flag has no effect.
            sortish_sampler = dataset.make_sortish_sampler(
                self.args.per_device_train_batch_size,
                distributed=(self.args.parallel_mode == ParallelMode.DISTRIBUTED),
            )
            if sortish_sampler is not None:
                return sortish_sampler

        if self.args.local_rank == -1:
            return RandomSampler(dataset)
        return DistributedSampler(dataset)

    def _compute_loss(self, model, inputs, labels):
        """
        Run ``model`` on ``inputs`` and compute the loss against ``labels``.

        Three paths:
          * no label smoothing and ``data_args.ignore_pad_token_for_loss``: flattened cross-entropy that ignores
            ``pad_token_id``;
          * no label smoothing otherwise: the model computes its own loss from ``labels``;
          * label smoothing: ``self.loss_fn`` on log-softmaxed logits, ignoring ``pad_token_id``.

        Returns:
            ``(loss, logits)``.
        """
        if self.args.label_smoothing == 0:
            if self.data_args is not None and self.data_args.ignore_pad_token_for_loss:
                # Compute the cross-entropy ourselves so that pad tokens are ignored.
                logits = model(**inputs, use_cache=False)[0]
                loss = self.loss_fn(logits.view(-1, logits.shape[-1]), labels.view(-1))
            else:
                # Let the model compute its default loss from the labels.
                loss, logits = model(**inputs, labels=labels, use_cache=False)[:2]
        else:
            # Label-smoothed NLL loss over log-probabilities.
            logits = model(**inputs, use_cache=False)[0]
            lprobs = torch.nn.functional.log_softmax(logits, dim=-1)
            loss, _ = self.loss_fn(lprobs, labels, self.args.label_smoothing, ignore_index=self.config.pad_token_id)
        return loss, logits

    def compute_loss(self, model, inputs, return_outputs=False, num_items_in_batch=None):
        """
        Training-loss hook used by :class:`~transformers.Trainer`.

        Note: ``"labels"`` is popped from ``inputs`` (the dict is mutated).

        Args:
            model: The model to run.
            inputs: Batch dict; must contain ``"labels"``.
            return_outputs (:obj:`bool`): If ``True``, return ``(loss, logits)`` instead of just the loss.
            num_items_in_batch: Accepted for compatibility with ``Trainer`` versions that pass it; ignored here.

        Returns:
            The loss, or ``(loss, logits)`` when ``return_outputs`` is ``True``.
        """
        labels = inputs.pop("labels")
        loss, logits = self._compute_loss(model, inputs, labels)
        return (loss, logits) if return_outputs else loss

    def prediction_step(
        self,
        model: nn.Module,
        inputs: dict[str, Union[torch.Tensor, Any]],
        prediction_loss_only: bool,
        ignore_keys: Optional[list[str]] = None,
    ) -> tuple[Optional[float], Optional[torch.Tensor], Optional[torch.Tensor]]:
        """
        Perform an evaluation step on :obj:`model` using :obj:`inputs`.

        Subclass and override to inject custom behavior.

        Args:
            model (:obj:`nn.Module`):
                The model to evaluate.
            inputs (:obj:`Dict[str, Union[torch.Tensor, Any]]`):
                The inputs and targets of the model.

                The dictionary will be unpacked before being fed to the model. Most models expect the targets under the
                argument :obj:`labels`. Check your model's documentation for all accepted arguments.
            prediction_loss_only (:obj:`bool`):
                Whether or not to return the loss only. If ``None``, ``args.prediction_loss_only`` is used.
            ignore_keys (:obj:`List[str]`, `optional`):
                Accepted for interface compatibility; not used by this implementation.

        Return:
            Tuple[Optional[float], Optional[torch.Tensor], Optional[torch.Tensor]]:
            A tuple with the loss, logits (or generated token ids when ``args.predict_with_generate`` is set) and
            labels (each being optional).
        """
        if prediction_loss_only is None:
            prediction_loss_only = self.args.prediction_loss_only

        inputs = self._prepare_inputs(inputs)

        gen_kwargs = {
            "max_length": self.data_args.val_max_target_length
            if self.data_args is not None
            else self.config.max_length,
            "num_beams": self.data_args.eval_beams if self.data_args is not None else self.config.num_beams,
        }

        generated_tokens = None
        if self.args.predict_with_generate and not prediction_loss_only:
            generated_tokens = self.model.generate(
                inputs["input_ids"],
                attention_mask=inputs.get("attention_mask"),
                **gen_kwargs,
            )
            # Pad shorter generations up to ``max_length`` so every batch yields the same width
            # (presumably so per-batch outputs can be gathered/concatenated downstream).
            if generated_tokens.shape[-1] < gen_kwargs["max_length"]:
                generated_tokens = self._pad_tensors_to_max_len(generated_tokens, gen_kwargs["max_length"])

        labels = inputs.pop("labels")
        with torch.no_grad():
            # Compute the loss on the evaluation batch.
            loss, logits = self._compute_loss(model, inputs, labels)

        loss = loss.mean().detach()
        if prediction_loss_only:
            return (loss, None, None)

        logits = generated_tokens if self.args.predict_with_generate else logits

        if labels.shape[-1] < gen_kwargs["max_length"]:
            labels = self._pad_tensors_to_max_len(labels, gen_kwargs["max_length"])

        return (loss, logits, labels)

    def _pad_tensors_to_max_len(self, tensor, max_length):
        """
        Right-pad the last dimension of a 2-D ``tensor`` to ``max_length``.

        The pad value is ``config.pad_token_id``, falling back to ``config.eos_token_id``. Callers only invoke this
        when ``tensor.shape[-1] < max_length``.

        Raises:
            ValueError: if neither ``pad_token_id`` nor ``eos_token_id`` is defined.
        """
        # If PAD token is not defined, at least EOS token has to be defined.
        pad_token_id = self.config.pad_token_id if self.config.pad_token_id is not None else self.config.eos_token_id

        if pad_token_id is None:
            raise ValueError(
                "Make sure that either `config.pad_token_id` or `config.eos_token_id` is defined if tensor has to be"
                f" padded to `max_length`={max_length}"
            )

        padded_tensor = torch.full(
            (tensor.shape[0], max_length), pad_token_id, dtype=tensor.dtype, device=tensor.device
        )
        padded_tensor[:, : tensor.shape[-1]] = tensor
        return padded_tensor
|
|