import warnings
from typing import Callable, Dict, List, Optional, Tuple, Union

import torch
from datasets import Dataset
from torch.utils.data import DataLoader
from transformers import (
    DataCollator,
    DataCollatorForLanguageModeling,
    DataCollatorForSeq2Seq,
    PreTrainedModel,
    PreTrainedTokenizerBase,
    Trainer,
    TrainingArguments,
)
from transformers.trainer_utils import EvalLoopOutput

from ..core import PPODecorators
from ..import_utils import is_peft_available


if is_peft_available():
    from peft import PeftModel


class IterativeSFTTrainer(Trainer):
    """
    The IterativeSFTTrainer can be used to fine-tune models with methods that require some steps between optimization.

    Attributes:
        **model** (`PreTrainedModel`) -- Model to be optimized, either an `AutoModelForCausalLM` or an `AutoModelForSeq2SeqLM`.
            Check the documentation of `PreTrainedModel` for more details.
        **args** (`transformers.TrainingArguments`) -- The arguments to use for training.
        **tokenizer** (`PreTrainedTokenizerBase`) -- Tokenizer to be used for encoding the
            data. Check the documentation of `transformers.PreTrainedTokenizer` and
            `transformers.PreTrainedTokenizerFast` for more details.
        **optimizers** (`Tuple[torch.optim.Optimizer, torch.optim.lr_scheduler.LambdaLR]`) -- The optimizer and scheduler to use for training.
        **data_collator** (`Union[DataCollatorForLanguageModeling, DataCollatorForSeq2Seq]`, *optional*) -- Data collator to be used for training and
            passed along the dataloader.
        **eval_dataset** (`datasets.Dataset`) -- The dataset to use for evaluation.
        **max_length** (`int`, defaults to `None`) -- The maximum length of the input.
        **truncation_mode** (`str`, defaults to `keep_end`) -- The truncation mode to use, either `keep_end` or `keep_start`.
        **preprocess_logits_for_metrics** (`Callable[[torch.Tensor, torch.Tensor], torch.Tensor]`) -- The function to use to preprocess the logits before computing the metrics.
        **compute_metrics** (`Callable[[EvalPrediction], Dict]`, *optional*) -- The function to use to compute the metrics. Must take an `EvalPrediction` and return a dictionary mapping strings to metric values.
        **optimize_device_cache** (`bool`, *optional*, defaults to `False`) -- Optimize CUDA cache for slightly more memory-efficient training.
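
    Example (a minimal usage sketch; the checkpoint name, training arguments, and texts below are illustrative
    assumptions, not prescribed by the library):

    ```python
    from transformers import AutoModelForCausalLM, AutoTokenizer, TrainingArguments
    from trl import IterativeSFTTrainer

    model = AutoModelForCausalLM.from_pretrained("gpt2")
    tokenizer = AutoTokenizer.from_pretrained("gpt2")
    tokenizer.pad_token = tokenizer.eos_token  # gpt2 has no pad token by default

    args = TrainingArguments(output_dir="iterative_sft", max_steps=100, per_device_train_batch_size=2)
    trainer = IterativeSFTTrainer(model=model, args=args, tokenizer=tokenizer)

    # each call to `step` runs optimization on the provided batch of texts
    trainer.step(texts=["The quick brown fox", "jumps over the lazy dog"])
    ```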
    """

    def __init__(
        self,
        model: PreTrainedModel = None,
        args: TrainingArguments = None,
        tokenizer: PreTrainedTokenizerBase = None,
        optimizers: Tuple[torch.optim.Optimizer, torch.optim.lr_scheduler.LambdaLR] = (
            None,
            None,
        ),
        data_collator: Optional[DataCollator] = None,
        eval_dataset: Optional[Union[Dataset, Dict[str, Dataset]]] = None,
        max_length: Optional[int] = None,
        truncation_mode: Optional[str] = "keep_end",
        preprocess_logits_for_metrics: Optional[Callable[[torch.Tensor, torch.Tensor], torch.Tensor]] = None,
        compute_metrics: Optional[Callable[[EvalLoopOutput], Dict]] = None,
        optimize_device_cache: Optional[bool] = False,
    ):
        if not isinstance(tokenizer, (PreTrainedTokenizerBase)):
            raise ValueError(
                f"tokenizer must be a PreTrainedTokenizerBase like a PreTrainedTokenizer or a PreTrainedTokenizerFast, got {type(tokenizer)}"
            )
        if not isinstance(model, PreTrainedModel):
            raise ValueError(f"model must be a PreTrainedModel, got {type(model)}")
        if not model.can_generate():
            warnings.warn(
                f"The current model class {type(model)} is not compatible with `.generate()`. "
                "Please make sure that this is intended."
            )
        if optimizers[1] is None and args.max_steps == -1:
            raise ValueError(
                "When no scheduler is provided, you need to set the total number of training steps to perform `max_steps`"
            )

        self.is_encoder_decoder = getattr(model.config, "is_encoder_decoder", False)
        self.is_peft_model = is_peft_available() and isinstance(model, PeftModel)

        self.tokenizer = tokenizer

        if data_collator is None:
            if self.is_encoder_decoder:
                warnings.warn(
                    "No data collator is provided. Using 'DataCollatorForSeq2Seq' with "
                    "'labels_pad_token_id' set to '-100' and 'pad_to_multiple_of' set to 8."
                )
                self.data_collator = DataCollatorForSeq2Seq(tokenizer, label_pad_token_id=-100, pad_to_multiple_of=8)
            else:
                warnings.warn("No data collator is provided. Using 'DataCollatorForLanguageModeling'")
                self.data_collator = DataCollatorForLanguageModeling(self.tokenizer, mlm=False)
        else:
            self.data_collator = data_collator

        self.max_length = max_length
        self.truncation_mode = truncation_mode
        self.optimize_device_cache = optimize_device_cache

        super().__init__(
            model=model,
            args=args,
            data_collator=self.data_collator,
            eval_dataset=eval_dataset,
            tokenizer=tokenizer,
            compute_metrics=compute_metrics,
            optimizers=optimizers,
            preprocess_logits_for_metrics=preprocess_logits_for_metrics,
        )
        self.create_optimizer_and_scheduler(self.args.max_steps)

        if not hasattr(self, "accelerator"):
            raise AttributeError(
                "Your `Trainer` does not have an `accelerator` object. Consider upgrading `transformers`."
            )

        # prepare model, optimizer and lr_scheduler for (distributed) training
        self.model, self.optimizer, self.lr_scheduler = self.accelerator.prepare(
            self.model, self.optimizer, self.lr_scheduler
        )

        self.tokenizer.truncation_side = "left" if self.truncation_mode == "keep_end" else "right"

        PPODecorators.optimize_device_cache = self.optimize_device_cache

    def prepare_model_inputs(self, input_ids: torch.Tensor, attention_mask: torch.Tensor, labels: torch.Tensor):
        if attention_mask is None:
            attention_mask = [torch.ones_like(ids) for ids in input_ids]

        if self.is_encoder_decoder:
            input_data = self.data_collator(
                [
                    {"input_ids": ids, "attention_mask": att, "labels": lab}
                    for ids, att, lab in zip(input_ids, attention_mask, labels)
                ]
            ).to(self.model.device)

            # decoder_input_ids are computed inside the model from the labels
            input_data.pop("decoder_input_ids", None)

            input_data["labels"][input_data["labels"] == self.tokenizer.pad_token_id] = -100

        else:
            input_data = self.data_collator(
                [{"input_ids": ids, "attention_mask": att} for ids, att in zip(input_ids, attention_mask)]
            ).to(self.model.device)

        # truncate along the sequence dimension in case the provided inputs are longer than max_length
        if self.max_length is not None:
            if self.truncation_mode == "keep_start":
                input_data = {k: v[:, : self.max_length] for k, v in input_data.items()}
            elif self.truncation_mode == "keep_end":
                input_data = {k: v[:, -self.max_length :] for k, v in input_data.items()}
            else:
                raise ValueError(f"Unknown truncation mode: {self.truncation_mode}")

        return input_data

    @staticmethod
    def _step_safety_checker(
        input_ids: List[torch.LongTensor],
        attention_mask: List[torch.LongTensor],
        labels: List[torch.LongTensor],
        texts: List[str],
        texts_labels: List[str],
    ):
        """
        Check if the input data is valid for training.

        Args:
            input_ids (List[`torch.LongTensor`]):
                List of tensors containing the input_ids
            attention_mask (List[`torch.LongTensor`]):
                List of tensors containing the attention_mask
            labels (List[`torch.LongTensor`]):
                List of tensors containing the labels
            texts (List[`str`]):
                List of strings containing the text input.
            texts_labels (List[`str`]):
                List of strings containing the text labels.
        Returns:
            `tuple`: The input data.
        """
        if texts is None:
            if attention_mask is None:
                for name, tensor_list in zip(["input_ids", "labels"], [input_ids, labels]):
                    if not isinstance(tensor_list, list):
                        raise ValueError(f"{name} must be a list of tensors - got {type(tensor_list)}")
                    if not isinstance(tensor_list[0], torch.Tensor):
                        raise ValueError(f"Elements in {name} must be tensors - got {type(tensor_list[0])}")
            else:
                for name, tensor_list in zip(
                    ["input_ids", "attention_mask", "labels"], [input_ids, attention_mask, labels]
                ):
                    if not isinstance(tensor_list, list):
                        raise ValueError(f"{name} must be a list of tensors - got {type(tensor_list)}")
                    if not isinstance(tensor_list[0], torch.Tensor):
                        raise ValueError(f"Elements in {name} must be tensors - got {type(tensor_list[0])}")
        else:
            if not isinstance(texts, list):
                raise ValueError(f"'text' must be a list of strings - got {type(texts)}")
            if not isinstance(texts[0], str):
                raise ValueError(f"Elements in 'text' must be strings - got {type(texts[0])}")
            if texts_labels is not None:
                if not isinstance(texts_labels, list):
                    raise ValueError(f"'text_labels' must be a list of strings - got {type(texts_labels)}")
                if not isinstance(texts_labels[0], str):
                    raise ValueError(f"Elements in 'text_labels' must be strings - got {type(texts_labels[0])}")

        return input_ids, attention_mask, labels, texts, texts_labels

    @PPODecorators.empty_device_cache()
    def step(
        self,
        input_ids: Optional[List[torch.LongTensor]] = None,
        attention_mask: Optional[List[torch.LongTensor]] = None,
        labels: Optional[List[torch.LongTensor]] = None,
        texts: Optional[List[str]] = None,
        texts_labels: Optional[List[str]] = None,
    ):
        """
        Run an optimisation step given a list of input_ids, attention_mask, and labels, or a list of texts and texts_labels.

        Args:
            input_ids (List[`torch.LongTensor`]):
                List of tensors containing the input_ids (if not provided, text will be used)
            attention_mask (List[`torch.LongTensor`], *optional*):
                List of tensors containing the attention_mask
            labels (List[`torch.LongTensor`], *optional*):
                List of tensors containing the labels (if set to None, will default to input_ids)
            texts (List[`str`], *optional*):
                List of strings containing the text input (if not provided, input_ids will directly be used)
            texts_labels (List[`str`], *optional*):
                List of strings containing the text labels (if set to None, will default to text)
        Returns:
            `dict[str, Any]`: A summary of the training statistics
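        Example (a minimal sketch, not taken from the library docs; it assumes `trainer` is an already
        constructed `IterativeSFTTrainer` for a causal LM and `tokenizer` is its tokenizer, and the texts
        are purely illustrative):

        ```python
        # pass raw texts and let the trainer tokenize them
        trainer.step(texts=["The quick brown fox jumps over the lazy dog."])

        # or pass pre-tokenized inputs as lists of tensors
        encoded = tokenizer(["The quick brown fox jumps over the lazy dog."], return_tensors="pt")
        trainer.step(
            input_ids=list(encoded["input_ids"]),
            attention_mask=list(encoded["attention_mask"]),
        )
        ```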
        """
        self.model.train()

        if self.state.global_step == 0:
            self.tr_loss = torch.tensor(0.0).to(self.args.device)
            self._globalstep_last_logged = self.state.global_step

        if input_ids is None and texts is None:
            raise ValueError("Step should include `input_ids` or `texts` as keyword arguments.")
        elif input_ids is not None and texts is not None:
            warnings.warn(
                "Both 'input_ids' and 'texts' are provided. 'input_ids' will be overwritten using inputs provided by the 'texts' keyword argument."
            )

        if labels is None and texts_labels is None and self.is_encoder_decoder:
            raise ValueError(
                "No 'labels' or 'text_labels' are provided. When using an encoder-decoder architecture, 'labels' or 'text_labels' must be passed."
            )

        input_ids, attention_mask, labels, texts, texts_labels = self._step_safety_checker(
            input_ids, attention_mask, labels, texts, texts_labels
        )
        if texts is not None:
            model_inputs = self.tokenizer(
                texts, max_length=self.max_length, truncation=True, padding=True, return_tensors="pt"
            )

            input_ids, attention_mask = model_inputs["input_ids"], model_inputs["attention_mask"]

        if texts_labels is not None:
            labels = self.tokenizer(
                texts_labels, max_length=self.max_length, truncation=True, padding=True, return_tensors="pt"
            )["input_ids"]

        if labels is None:
            warnings.warn("No labels are provided. Setting labels to input_ids")
            labels = input_ids

        model_inputs = self.prepare_model_inputs(input_ids, attention_mask, labels)
        model_inputs_names = list(model_inputs.keys())

        batch_dict = {}
        batch_dict.update(model_inputs)

        def collator(data):
            return_dict = dict()
            for key in data[0]:
                if key in ["input_ids", "attention_mask", "labels"]:
                    return_dict[key] = torch.stack([d[key] for d in data]).to(self.model.device)
            return return_dict

        # re-batch the provided inputs into mini-batches of `per_device_train_batch_size`
        batch_data = Dataset.from_dict(batch_dict)
        batch_data.set_format("torch")

        step_dataloader = DataLoader(
            batch_data,
            batch_size=self.args.per_device_train_batch_size,
            shuffle=True,
            collate_fn=collator,
        )
        for _, batch in enumerate(step_dataloader):
            with self.accelerator.accumulate(self.model):
                model_inputs = {k: batch[k] for k in model_inputs_names}
                loss = self.compute_loss(self.model, model_inputs)

                if self.args.n_gpu > 1:
                    loss = loss.mean()

                tr_loss_step = loss.detach()

                self.accelerator.backward(loss)

                if self.accelerator.sync_gradients and self.args.max_grad_norm is not None:
                    self.accelerator.clip_grad_norm_(
                        self.model.parameters(),
                        self.args.max_grad_norm,
                    )

                self.optimizer.step()
                self.optimizer.zero_grad()
                if self.lr_scheduler is not None:
                    self.lr_scheduler.step()

                self.state.global_step += 1

                # update loss statistics
                self.tr_loss += tr_loss_step

                self._maybe_log_save_evaluate()

    def _maybe_log_save_evaluate(self):
        # check if eval is needed
        if self.args.eval_steps is not None:
            if self.state.global_step % self.args.eval_steps == 0 and self.state.global_step != 0:
                self.evaluate(self.eval_dataset)

        # check if logging is needed
        if self.args.logging_steps is not None:
            if self.state.global_step % self.args.logging_steps == 0 and self.state.global_step != 0:
                logs: Dict[str, float] = {}

                tr_loss_scalar = self._nested_gather(self.tr_loss).mean().item()

                # reset tr_loss to zero
                self.tr_loss -= self.tr_loss

                logs["loss"] = round(tr_loss_scalar / (self.state.global_step - self._globalstep_last_logged), 4)
                logs["learning_rate"] = self._get_learning_rate()

                self._globalstep_last_logged = self.state.global_step

                self.log(logs)