|
|
| import torch |
| import torch.nn as nn |
| import deepspeed |
| from transformers import Trainer |
| from transformers.trainer_pt_utils import nested_detach |
| from transformers.utils import is_sagemaker_mp_enabled |
| from transformers.trainer import * |
| from transformers.integrations import is_deepspeed_zero3_enabled |
|
|
|
|
class CPMTrainer(Trainer):
    """`Trainer` subclass for (Mini)CPM-style multimodal models.

    The wrapped model consumes its whole batch through a single ``data=``
    keyword argument instead of unpacked ``**inputs``, so ``compute_loss``,
    ``prediction_step`` and ``training_step`` are overridden to call it that
    way.  ``_save`` additionally supports PEFT/LoRA models and safetensors
    serialization.  Relies on names star-imported from
    ``transformers.trainer`` (``smp_forward_only``, ``amp``, ``logger``,
    ``WEIGHTS_NAME``, ...).
    """

    def compute_loss(self, model, inputs, return_outputs=False):
        """Compute the training loss.

        If ``labels`` are present in ``inputs`` they are popped and a
        token-level cross-entropy against the model logits is computed
        manually; otherwise the loss reported by the model itself is used.

        Args:
            model: The model passed in by the Trainer loop (NOTE: not used
                directly here — see comment below).
            inputs: Batch dict; ``labels`` is removed if present, the rest is
                forwarded to the model as a single ``data=`` argument.
            return_outputs: If True, return ``(loss, outputs)`` instead of
                just the loss.
        """
        # Pull labels out so they are not forwarded to the model; the loss is
        # computed manually below when labels are present.
        if "labels" in inputs:
            labels = inputs.pop("labels")
        else:
            labels = None

        if not self.args.use_lora:
            # NOTE(review): this calls `self.model` rather than the `model`
            # argument, bypassing any DDP/DeepSpeed wrapper the Trainer may
            # have applied — presumably intentional for this model family,
            # but worth confirming under distributed training.
            outputs = self.model(data=inputs, use_cache=False)
        else:
            # LoRA path: run the PEFT base model while the PEFT forward
            # hooks are active for this batch.
            with self.model._enable_peft_forward_hooks(**inputs):
                outputs = self.model.base_model(data=inputs, use_cache=False)

        if labels is not None:
            # Flatten to (batch*seq, vocab) vs (batch*seq,) and take plain
            # cross entropy.  Assumes `labels` are already shifted/aligned by
            # the data pipeline — TODO confirm against the collator.
            loss_fct = nn.CrossEntropyLoss()
            logits = outputs.logits.view(-1,
                                         self.model.config.vocab_size).contiguous()
            labels = labels.view(-1).long().contiguous()
            # Make sure labels live on the same device as the logits.
            labels = labels.to(logits.device)
            loss = loss_fct(logits, labels)
        else:
            if isinstance(outputs, dict) and "loss" not in outputs:
                raise ValueError(
                    "The model did not return a loss from the inputs, only the following keys: "
                    f"{','.join(outputs.keys())}. For reference, the inputs it received are {','.join(inputs.keys())}."
                )
            # The model may return either a ModelOutput/dict or a plain
            # tuple whose first element is the loss.
            loss = outputs["loss"] if isinstance(outputs, dict) else outputs[0]

        return (loss, outputs) if return_outputs else loss

    def prediction_step(
        self,
        model: nn.Module,
        inputs: Dict[str, Union[torch.Tensor, Any]],
        prediction_loss_only: bool,
        ignore_keys: Optional[List[str]] = None,
    ) -> Tuple[Optional[torch.Tensor], Optional[torch.Tensor], Optional[torch.Tensor]]:
        """
        Perform an evaluation step on `model` using `inputs`.

        Subclass and override to inject custom behavior.

        Args:
            model (`nn.Module`):
                The model to evaluate.
            inputs (`Dict[str, Union[torch.Tensor, Any]]`):
                The inputs and targets of the model.

                The dictionary will be unpacked before being fed to the model. Most models expect the targets under the
                argument `labels`. Check your model's documentation for all accepted arguments.
            prediction_loss_only (`bool`):
                Whether or not to return the loss only.
            ignore_keys (`List[str]`, *optional*):
                A list of keys in the output of your model (if it is a dictionary) that should be ignored when
                gathering predictions.

        Return:
            Tuple[Optional[torch.Tensor], Optional[torch.Tensor], Optional[torch.Tensor]]: A tuple with the loss,
            logits and labels (each being optional).
        """
        # We "have labels" only if every configured label name is present
        # (and non-None) in the batch.
        has_labels = (
            False
            if len(self.label_names) == 0
            else all(inputs.get(k) is not None for k in self.label_names)
        )

        # Some models can compute a loss with no explicit labels (e.g. via a
        # `return_loss` flag in the batch, or because their forward defaults
        # to returning one).
        return_loss = inputs.get("return_loss", None)
        if return_loss is None:
            return_loss = self.can_return_loss
        loss_without_labels = (
            True if len(self.label_names) == 0 and return_loss else False
        )

        inputs = self._prepare_inputs(inputs)
        if ignore_keys is None:
            # Default to the model config's inference-time ignore list.
            if hasattr(self.model, "config"):
                ignore_keys = getattr(
                    self.model.config, "keys_to_ignore_at_inference", []
                )
            else:
                ignore_keys = []

        # Detach labels up front: `compute_loss` pops "labels" from the
        # batch, so grab them before the forward pass.
        if has_labels or loss_without_labels:
            labels = nested_detach(tuple(inputs.get(name)
                                         for name in self.label_names))
            if len(labels) == 1:
                labels = labels[0]
        else:
            labels = None

        with torch.no_grad():
            if is_sagemaker_mp_enabled():
                # SageMaker model-parallel path: forward-only call returns
                # micro-batched ("mb") results that must be reduced/concat'd.
                raw_outputs = smp_forward_only(model, inputs)
                if has_labels or loss_without_labels:
                    if isinstance(raw_outputs, dict):
                        loss_mb = raw_outputs["loss"]
                        # Everything but the loss and ignored keys counts as
                        # logits to gather.
                        logits_mb = tuple(
                            v
                            for k, v in raw_outputs.items()
                            if k not in ignore_keys + ["loss"]
                        )
                    else:
                        loss_mb = raw_outputs[0]
                        logits_mb = raw_outputs[1:]

                    loss = loss_mb.reduce_mean().detach().cpu()
                    logits = smp_nested_concat(logits_mb)
                else:
                    loss = None
                    if isinstance(raw_outputs, dict):
                        logits_mb = tuple(
                            v for k, v in raw_outputs.items() if k not in ignore_keys
                        )
                    else:
                        logits_mb = raw_outputs
                    logits = smp_nested_concat(logits_mb)
            else:
                if has_labels or loss_without_labels:
                    # Reuse the overridden compute_loss so eval loss matches
                    # the training objective.
                    with self.compute_loss_context_manager():
                        loss, outputs = self.compute_loss(
                            model, inputs, return_outputs=True
                        )
                    loss = loss.mean().detach()

                    if isinstance(outputs, dict):
                        logits = tuple(
                            v
                            for k, v in outputs.items()
                            if k not in ignore_keys + ["loss"]
                        )
                    else:
                        # Tuple output: element 0 is the loss, the rest are
                        # logits/extras.
                        logits = outputs[1:]
                else:
                    loss = None
                    with self.compute_loss_context_manager():
                        outputs = model(**inputs)
                    if isinstance(outputs, dict):
                        logits = tuple(
                            v for k, v in outputs.items() if k not in ignore_keys
                        )
                    else:
                        logits = outputs

                if self.args.past_index >= 0:
                    # Cache past state for the next step; index shifted by -1
                    # because the loss occupies position 0 of the outputs.
                    self._past = outputs[self.args.past_index - 1]

        if prediction_loss_only:
            return (loss, None, None)

        logits = nested_detach(logits)
        # Unwrap a single-element tuple so callers get the tensor directly.
        if len(logits) == 1:
            logits = logits[0]

        return (loss, logits, labels)

    def training_step(self, model: nn.Module, inputs: Dict[str, Union[torch.Tensor, Any]]) -> torch.Tensor:
        """
        Perform a training step on a batch of inputs.

        Subclass and override to inject custom behavior.

        Args:
            model (`nn.Module`):
                The model to train.
            inputs (`Dict[str, Union[torch.Tensor, Any]]`):
                The inputs and targets of the model.

                The dictionary will be unpacked before being fed to the model. Most models expect the targets under the
                argument `labels`. Check your model's documentation for all accepted arguments.

        Return:
            `torch.Tensor`: The tensor with training loss on this batch.
        """
        model.train()
        inputs = self._prepare_inputs(inputs)

        if is_sagemaker_mp_enabled():
            # SageMaker MP does forward+backward in one fused call.
            loss_mb = smp_forward_backward(model, inputs, self.args.gradient_accumulation_steps)
            return loss_mb.reduce_mean().detach().to(self.args.device)

        with self.compute_loss_context_manager():
            loss = self.compute_loss(model, inputs)

        # Free the batch and flush the CUDA caching allocator before the
        # backward pass.  NOTE(review): `empty_cache()` on every step trades
        # throughput for lower peak memory — presumably deliberate for large
        # multimodal batches.
        del inputs
        torch.cuda.empty_cache()

        if self.args.n_gpu > 1:
            # Average the per-device losses under DataParallel.
            loss = loss.mean()

        if self.use_apex:
            # Legacy apex AMP path.
            with amp.scale_loss(loss, self.optimizer) as scaled_loss:
                scaled_loss.backward()
        else:
            self.accelerator.backward(loss)

        # Scale for gradient accumulation so logged losses are comparable.
        return loss.detach() / self.args.gradient_accumulation_steps

    def _save(self, output_dir: Optional[str] = None, state_dict=None):
        """Save the model checkpoint, tokenizer and training args to
        ``output_dir`` (defaults to ``self.args.output_dir``).

        Handles three cases: a ``PreTrainedModel``/``PeftModel`` saved via
        ``save_pretrained``; a wrapped model whose unwrapped core supports
        ``save_pretrained``; and an arbitrary module whose raw state dict is
        written directly (safetensors or ``torch.save`` depending on
        ``args.save_safetensors``).
        """
        # If we are executing this function, we are the process zero, so we don't check for that.
        output_dir = output_dir if output_dir is not None else self.args.output_dir
        os.makedirs(output_dir, exist_ok=True)
        logger.info(f"Saving model checkpoint to {output_dir}")

        # PeftModel only counts as "supported" when peft is installed.
        supported_classes = (PreTrainedModel,) if not is_peft_available() else (PreTrainedModel, PeftModel)

        if not isinstance(self.model, supported_classes):
            if state_dict is None:
                state_dict = self.model.state_dict()

            if isinstance(unwrap_model(self.model), supported_classes):
                # The wrapper (e.g. DDP) isn't saveable, but the inner model
                # is — save through it with the outer state dict.
                unwrap_model(self.model).save_pretrained(
                    output_dir, state_dict=state_dict, safe_serialization=self.args.save_safetensors
                )
            else:
                logger.info("Trainer.model is not a `PreTrainedModel`, only saving its state dict.")
                if self.args.save_safetensors:
                    safetensors.torch.save_file(
                        state_dict, os.path.join(output_dir, SAFE_WEIGHTS_NAME), metadata={"format": "pt"}
                    )
                else:
                    torch.save(state_dict, os.path.join(output_dir, WEIGHTS_NAME))
        else:
            self.model.save_pretrained(
                output_dir, state_dict=state_dict, safe_serialization=self.args.save_safetensors
            )

        if self.tokenizer is not None:
            self.tokenizer.save_pretrained(output_dir)

        # Good practice: save your training arguments together with the trained model
        torch.save(self.args, os.path.join(output_dir, TRAINING_ARGS_NAME))
|
|