import json
import os
from types import MethodType
from typing import TYPE_CHECKING, Any, Optional, Union

import numpy as np
import torch
from transformers import Seq2SeqTrainer
from typing_extensions import override

from ...extras import logging
from ...extras.constants import IGNORE_INDEX
from ..callbacks import SaveProcessorCallback
from ..fp8_utils import configure_fp8_environment, patch_accelerator_for_fp8, verify_fp8_status
from ..trainer_utils import create_custom_optimizer, create_custom_scheduler


if TYPE_CHECKING:
    from torch.utils.data import Dataset
    from transformers import ProcessorMixin
    from transformers.trainer import PredictionOutput

    from ...hparams import FinetuningArguments, ModelArguments, TrainingArguments


logger = logging.get_logger(__name__)


class CustomSeq2SeqTrainer(Seq2SeqTrainer):
    r"""Inherits Seq2SeqTrainer to compute generative metrics such as BLEU and ROUGE."""

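    # A minimal construction sketch (assumed wiring; `model`, `training_args`,
    # `data_collator`, and the dataset/argument objects come from the surrounding
    # LLaMA-Factory workflow and are hypothetical here):
    #
    #     trainer = CustomSeq2SeqTrainer(
    #         model=model,
    #         args=training_args,
    #         tokenizer=tokenizer,
    #         processor=processor,
    #         finetuning_args=finetuning_args,
    #         data_collator=data_collator,
    #         train_dataset=train_dataset,
    #     )
    #     trainer.train()
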
    def __init__(
        self,
        finetuning_args: "FinetuningArguments",
        processor: Optional["ProcessorMixin"],
        model_args: Optional["ModelArguments"] = None,
        gen_kwargs: Optional[dict[str, Any]] = None,
        **kwargs,
    ) -> None:
        kwargs["processing_class"] = kwargs.pop("tokenizer")  # `Trainer` expects `processing_class` instead of the deprecated `tokenizer`

        training_args: "TrainingArguments" = kwargs.get("args")
        if training_args.fp8:
            configure_fp8_environment(training_args)
            if getattr(training_args, "fp8_backend", "auto") == "te":
                patch_accelerator_for_fp8()

        super().__init__(**kwargs)
        if processor is not None:
            # do not pass loss kwargs such as `num_items_in_batch` to the model forward,
            # which can skew the loss under gradient accumulation
            self.model_accepts_loss_kwargs = False

        self.finetuning_args = finetuning_args
        if gen_kwargs is not None:
            # store the generation kwargs used by `evaluate()` and `predict()`
            self._gen_kwargs = gen_kwargs

        if processor is not None:
            self.add_callback(SaveProcessorCallback(processor))

        if finetuning_args.use_badam:
            from badam import BAdamCallback, clip_grad_norm_old_version

            self.accelerator.clip_grad_norm_ = MethodType(clip_grad_norm_old_version, self.accelerator)
            self.add_callback(BAdamCallback)

        if finetuning_args.use_dft_loss:
            from ..trainer_utils import dft_loss_func

            self.compute_loss_func = dft_loss_func
        elif finetuning_args.use_eaft_loss:
            from ..trainer_utils import eaft_loss_func

            self.compute_loss_func = lambda outputs, labels, num_items_in_batch=None: eaft_loss_func(
                outputs, labels, num_items_in_batch, finetuning_args.eaft_alpha
            )

        if training_args.fp8 and hasattr(self, "accelerator"):
            verify_fp8_status(self.accelerator, training_args)

    @override
    def create_optimizer(self) -> "torch.optim.Optimizer":
        if self.optimizer is None:
            # may return None, in which case the parent builds the default optimizer
            self.optimizer = create_custom_optimizer(self.model, self.args, self.finetuning_args)

        return super().create_optimizer()

    @override
    def create_scheduler(
        self, num_training_steps: int, optimizer: Optional["torch.optim.Optimizer"] = None
    ) -> "torch.optim.lr_scheduler.LRScheduler":
        create_custom_scheduler(self.args, num_training_steps, optimizer)
        return super().create_scheduler(num_training_steps, optimizer)

    @override
    def _get_train_sampler(self, *args, **kwargs) -> Optional["torch.utils.data.Sampler"]:
        if self.finetuning_args.disable_shuffling:
            # keep samples in dataset order instead of shuffling them
            return torch.utils.data.SequentialSampler(self.train_dataset)

        return super()._get_train_sampler(*args, **kwargs)

    @override
    def compute_loss(self, model, inputs, *args, **kwargs):
        # the parent implementation consults `self.compute_loss_func` when it is set
        # (see `__init__` for the DFT/EAFT loss functions)
        return super().compute_loss(model, inputs, *args, **kwargs)

    @override
    def prediction_step(
        self,
        model: "torch.nn.Module",
        inputs: dict[str, Union["torch.Tensor", Any]],
        prediction_loss_only: bool,
        ignore_keys: Optional[list[str]] = None,
        **gen_kwargs,
    ) -> tuple[Optional[float], Optional["torch.Tensor"], Optional["torch.Tensor"]]:
        r"""Remove the prompt part in the generated tokens.

        Subclass and override to inject custom behavior.
        """
        if self.args.predict_with_generate:
            # pop the labels so they are not fed to `generate`
            labels = inputs.pop("labels", None)
        else:
            labels = inputs.get("labels")

        loss, generated_tokens, _ = super().prediction_step(
            model, inputs, prediction_loss_only=prediction_loss_only, ignore_keys=ignore_keys, **gen_kwargs
        )
        if generated_tokens is not None and self.args.predict_with_generate:
            # mask the prompt part, which occupies the first `input_ids.size(-1)` positions
            generated_tokens[:, : inputs["input_ids"].size(-1)] = self.processing_class.pad_token_id
            generated_tokens = generated_tokens.contiguous()

        return loss, generated_tokens, labels

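    # Illustration of the trimming in `prediction_step` (hypothetical token ids):
    # with left padding and a 4-token prompt, `generate` returns something like
    #     [pad, pad, 11, 12, 7, 8, 9]
    # and masking the first `input_ids.size(-1)` positions yields
    #     [pad, pad, pad, pad, 7, 8, 9]
    # so only the completion survives decoding.
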
    def save_predictions(
        self, dataset: "Dataset", predict_results: "PredictionOutput", skip_special_tokens: bool = True
    ) -> None:
        r"""Save model predictions to `output_dir`.

        A custom behavior not contained in Seq2SeqTrainer.
        """
        if not self.is_world_process_zero():
            return

        output_prediction_file = os.path.join(self.args.output_dir, "generated_predictions.jsonl")
        logger.info_rank0(f"Saving prediction results to {output_prediction_file}")

        # replace IGNORE_INDEX with the pad token id so the sequences can be decoded
        labels = np.where(
            predict_results.label_ids != IGNORE_INDEX, predict_results.label_ids, self.processing_class.pad_token_id
        )
        preds = np.where(
            predict_results.predictions != IGNORE_INDEX,
            predict_results.predictions,
            self.processing_class.pad_token_id,
        )

        for i in range(len(preds)):
            pad_len = np.nonzero(preds[i] != self.processing_class.pad_token_id)[0]  # positions of non-pad tokens
            if len(pad_len):  # contains an answer
                # rotate the sequence so the leading pad tokens are moved to the end
                preds[i] = np.concatenate((preds[i][pad_len[0] :], preds[i][: pad_len[0]]), axis=-1)

        decoded_inputs = self.processing_class.batch_decode(dataset["input_ids"], skip_special_tokens=False)
        decoded_preds = self.processing_class.batch_decode(preds, skip_special_tokens=skip_special_tokens)
        decoded_labels = self.processing_class.batch_decode(labels, skip_special_tokens=skip_special_tokens)

        with open(output_prediction_file, "w", encoding="utf-8") as f:
            for text, pred, label in zip(decoded_inputs, decoded_preds, decoded_labels):
                f.write(json.dumps({"prompt": text, "predict": pred, "label": label}, ensure_ascii=False) + "\n")
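
# Each line of `generated_predictions.jsonl` written by `save_predictions` is a
# JSON object of the form (values illustrative):
#     {"prompt": "...", "predict": "...", "label": "..."}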