| import os |
| import sys |
| import json |
| import logging |
| import warnings |
| from copy import deepcopy |
| from typing import Any, Dict, List, Optional, Tuple, Union, Sequence, Mapping |
|
|
| import torch |
| from torch import nn |
| from torch.utils.data import DataLoader, Dataset |
| from tqdm import tqdm |
| from transformers import Seq2SeqTrainer, DataCollator, DataCollatorForSeq2Seq |
| from transformers.deepspeed import is_deepspeed_zero3_enabled |
| from transformers.trainer import TRAINER_STATE_NAME |
|
|
# Module-level logger for this trainer module; INFO so progress/warning messages show.
logger = logging.getLogger(__name__)
logger.setLevel(logging.INFO)
# NOTE(review): basicConfig at import time configures the *root* logger for the whole
# process (stdout handler, timestamped format) — presumably intentional for training
# scripts, but confirm it does not clobber the host application's logging setup.
logging.basicConfig(
    format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
    datefmt="%m/%d/%Y %H:%M:%S",
    handlers=[logging.StreamHandler(sys.stdout), ],
)
|
|
|
|
class TrainerDifferentCollatorMixin:
    """Mixin letting a HF ``Trainer`` use different data collators for train/eval/test.

    The stock ``Trainer`` supports a single ``data_collator``; this mixin temporarily
    swaps ``self.data_collator`` while each ``get_*_dataloader`` builds its loader and
    restores the previous collator afterwards (even on error).

    Fallback chain: ``eval_collator`` defaults to ``train_collator``, and
    ``test_collator`` defaults to ``eval_collator``.
    """

    def __init__(self,
                 *args,
                 train_collator: Optional[DataCollator] = None,
                 eval_collator: Optional[DataCollator] = None,
                 test_collator: Optional[DataCollator] = None,
                 **kwargs):
        if train_collator is None and eval_collator is None and test_collator is None:
            raise ValueError("use different collator for trainer but get no collator function.")
        if eval_collator is not None and test_collator is not None and eval_collator != test_collator:
            # Fixed missing space between 'u should' and 'check' in the original message.
            warnings.warn('[WARNING!!!] use different collator for eval and test. but maybe do_eval and '
                          'do_predict both use trainer.predict (i.e. only test_collator is used.) u should '
                          'check your code and know exactly what u are doing.')
        self._train_collator = train_collator
        self._eval_collator = eval_collator if eval_collator is not None else self._train_collator
        self._test_collator = test_collator if test_collator is not None else self._eval_collator
        if "data_collator" in kwargs and kwargs["data_collator"] is not None:
            warnings.warn("use different collator for trainer but get 'data_collator' argument. It will take no effect and be ignored.")
        super().__init__(*args, **kwargs)

    def _build_with_collator(self, collator, builder, *builder_args):
        """Install `collator` as self.data_collator, call `builder`, always restore.

        try/finally guarantees the previous collator is restored even when
        dataloader construction raises (the original swap was not exception-safe).
        """
        old_collator = self.data_collator
        self.data_collator = collator
        try:
            return builder(*builder_args)
        finally:
            self.data_collator = old_collator

    def get_train_dataloader(self) -> DataLoader:
        """Build the training dataloader with the train collator installed."""
        return self._build_with_collator(self._train_collator, super().get_train_dataloader)

    def get_eval_dataloader(self, eval_dataset: Optional[Dataset] = None) -> DataLoader:
        """Build the evaluation dataloader with the eval collator installed."""
        return self._build_with_collator(self._eval_collator, super().get_eval_dataloader, eval_dataset)

    def get_test_dataloader(self, test_dataset: Dataset) -> DataLoader:
        """Build the test/prediction dataloader with the test collator installed."""
        return self._build_with_collator(self._test_collator, super().get_test_dataloader, test_dataset)
|
|
|
|
| |
class TrainerForMMLLM(TrainerDifferentCollatorMixin, Seq2SeqTrainer):
    """Seq2SeqTrainer variant for multimodal LLMs.

    Differences from the stock ``Seq2SeqTrainer``:
      * ``prediction_step`` forwards every non-label model input (e.g. image
        tensors) to ``model.generate`` and strips the echoed prompt tokens.
      * Helpers to dump predictions, save FSDP-sharded models, and plot the
        training-loss curve.
    """

    def prediction_step(
        self,
        model: nn.Module,
        inputs: Dict[str, Union[torch.Tensor, Any]],
        prediction_loss_only: bool,
        ignore_keys: Optional[List[str]] = None,
    ) -> Tuple[Optional[float], Optional[torch.Tensor], Optional[torch.Tensor]]:
        """Run one evaluation/prediction step.

        When ``predict_with_generate`` is off (or only the loss is wanted) this
        defers to the parent implementation. Otherwise it generates tokens,
        drops the prompt prefix, and pads generations/labels to a fixed length
        so they can be gathered across processes.

        Returns:
            ``(loss, generated_tokens, labels)`` — loss is always ``None`` on
            the generate path.
        """
        if not self.args.predict_with_generate or prediction_loss_only:
            return super().prediction_step(
                model, inputs, prediction_loss_only=prediction_loss_only, ignore_keys=ignore_keys
            )

        has_labels = "labels" in inputs
        inputs = self._prepare_inputs(inputs)

        # Build generation kwargs, falling back to the model-config defaults.
        gen_kwargs = self._gen_kwargs.copy()
        if gen_kwargs.get("max_length") is None and gen_kwargs.get("max_new_tokens") is None:
            gen_kwargs["max_length"] = self.model.config.max_length
        gen_kwargs["num_beams"] = (
            gen_kwargs["num_beams"] if gen_kwargs.get("num_beams") is not None else self.model.config.num_beams
        )
        # Under DeepSpeed ZeRO-3 all ranks must step generation together.
        default_synced_gpus = is_deepspeed_zero3_enabled()
        gen_kwargs["synced_gpus"] = (
            gen_kwargs["synced_gpus"] if gen_kwargs.get("synced_gpus") is not None else default_synced_gpus
        )

        # Forward every prepared input except labels (input_ids, attention_mask,
        # images, ...) straight into generate().
        filter_keys = ["labels"]
        for k in inputs:
            if k not in filter_keys:
                gen_kwargs[k] = inputs[k]
        self._logging_generate_kwargs(gen_kwargs.keys())
        with torch.inference_mode():
            with self.compute_loss_context_manager():
                generated_tokens = self.model.generate(**gen_kwargs)

        # Decoder-only models echo the prompt; keep only the newly generated part.
        generation_inputs = inputs['input_ids']
        generated_tokens = generated_tokens[:, generation_inputs.size()[-1]:]

        # Stop transformers from re-deriving the generation config from the model
        # config (and re-warning) on every call. The original performed this
        # identical reset twice; once is sufficient.
        if self.model.generation_config._from_model_config:
            self.model.generation_config._from_model_config = False

        # Pad to a deterministic length so tensors can be gathered across processes.
        gen_config = self.model.generation_config
        if generated_tokens.shape[-1] < gen_config.max_length:
            generated_tokens = self._pad_tensors_to_max_len(generated_tokens, gen_config.max_length)
        elif gen_config.max_new_tokens is not None and generated_tokens.shape[-1] < gen_config.max_new_tokens + 1:
            generated_tokens = self._pad_tensors_to_max_len(generated_tokens, gen_config.max_new_tokens + 1)

        loss = None

        if self.args.prediction_loss_only:
            return loss, None, None

        if has_labels:
            labels = inputs["labels"]
            if labels.shape[-1] < gen_config.max_length:
                labels = self._pad_tensors_to_max_len(labels, gen_config.max_length)
            elif gen_config.max_new_tokens is not None and labels.shape[-1] < gen_config.max_new_tokens + 1:
                labels = self._pad_tensors_to_max_len(labels, gen_config.max_new_tokens + 1)
        else:
            labels = None

        return loss, generated_tokens, labels

    def _logging_generate_kwargs(self, keys):
        # Log the kwargs handed to generate(), but only when the key set changes,
        # so the log is not spammed every prediction step.
        if not hasattr(self, '_generate_kwargs'):
            self._generate_kwargs = None
        if self._generate_kwargs != keys:
            self._generate_kwargs = keys
            logger.warning(f"generate use kwargs: {keys}")

    def save_prediction(self, predict_results, file_key_prefix='predict'):
        """Dump raw prediction arrays (.npy) and decoded text (.jsonl) to output_dir.

        Only the main process writes; other ranks return immediately.
        """
        if not self.is_world_process_zero():
            return

        import numpy as np
        os.makedirs(self.args.output_dir, exist_ok=True)
        np.save(os.path.join(self.args.output_dir, f"{file_key_prefix}_predictions.npy"), predict_results.predictions)
        np.save(os.path.join(self.args.output_dir, f"{file_key_prefix}_label_ids.npy"), predict_results.label_ids)

        preds, targets = predict_results.predictions, predict_results.label_ids
        origin_preds, origin_targets = preds, targets
        # Copy before the in-place masking below so the raw token ids stay intact.
        preds, targets = deepcopy(preds), deepcopy(targets)
        logger.warning(f"preds shape: {preds.shape}. targets shape: {targets.shape}")

        with open(os.path.join(self.args.output_dir, f'{file_key_prefix}_extra_prediction.jsonl'), 'a', encoding="utf-8") as g:
            for p, t, pi, ti in tqdm(
                zip(preds, targets, origin_preds, origin_targets),
                total=len(preds), desc=f"saving prediction for {file_key_prefix}",
            ):
                # Replace negative ids (presumably the -100 ignore index) with pad
                # so the tokenizer can decode them.
                p[p < 0] = self.tokenizer.pad_token_id
                t[t < 0] = self.tokenizer.pad_token_id
                p = self.tokenizer.decode(p, skip_special_tokens=True, clean_up_tokenization_spaces=True)
                t = self.tokenizer.decode(t, skip_special_tokens=True, clean_up_tokenization_spaces=True)
                obj = dict(
                    pred=p,
                    target=t,
                )
                g.write(json.dumps(obj) + '\n')
                g.flush()

    def save_model(self, output_dir: Optional[str] = None, _internal_call: bool = False):
        """Save the model; under FSDP, gather a full CPU state dict on rank0 first."""
        if self.fsdp is not None:
            if output_dir is None:
                output_dir = self.args.output_dir
            from torch.distributed.fsdp import (
                FullyShardedDataParallel as FSDP,
                FullStateDictConfig,
                StateDictType,
            )
            # Gather the full (unsharded) state dict, offloaded to CPU, rank0 only.
            save_policy = FullStateDictConfig(offload_to_cpu=True, rank0_only=True)
            with FSDP.state_dict_type(self.model, StateDictType.FULL_STATE_DICT, save_policy):
                cpu_state_dict = self.model.state_dict()
            if self.args.should_save:
                self._save(output_dir, state_dict=cpu_state_dict)
            if self.args.push_to_hub and not _internal_call:
                self.push_to_hub(commit_message="Model save")
        else:
            super().save_model(output_dir, _internal_call)

    def plot_loss(self) -> None:
        """Plot training loss from trainer_state.json and save a PNG (rank0 only)."""
        if not self.is_world_process_zero():
            return

        training_args = self.args
        FIGURE_NAME = "trainer_state.png"
        import matplotlib.pyplot as plt
        # `with` closes the state file promptly instead of leaking the handle.
        with open(os.path.join(training_args.output_dir, TRAINER_STATE_NAME), "r") as f:
            data = json.load(f)
        train_steps, train_losses = [], []
        # Keep only entries that carry a training "loss"; this skips eval logs and
        # the final summary entry (which the original excluded by dropping the last
        # element) instead of raising KeyError on them.
        for entry in data["log_history"]:
            if "loss" in entry:
                train_steps.append(entry["step"])
                train_losses.append(entry["loss"])
        plt.figure()
        plt.plot(train_steps, train_losses)
        plt.title("training loss of {}".format(training_args.output_dir))
        plt.xlabel("step")
        plt.ylabel("training loss")
        plt.savefig(os.path.join(training_args.output_dir, FIGURE_NAME), format="png", transparent=True, dpi=300)
        print("Figure saved: {}".format(os.path.join(training_args.output_dir, FIGURE_NAME)))
|
|
|
|
class Seq2SeqDataCollator(DataCollatorForSeq2Seq):
    """DataCollatorForSeq2Seq that pads left in inference mode and right in training.

    Left padding at inference keeps prompts flush against the generation start
    (the usual requirement for decoder-only generation — confirm against the
    model used); right padding is the standard training layout. Only
    ``input_ids`` / ``labels`` / ``attention_mask`` are collated; any other
    feature keys are dropped here (subclasses may handle them).
    """

    def __init__(
        self,
        inference_mode: bool = False,
        **kwargs,
    ):
        # inference_mode: pad on the left when True, on the right when False.
        self.inference_mode = inference_mode
        self.text_keys = ['input_ids', 'labels', 'attention_mask']
        super().__init__(**kwargs)

    def __call__(self, features: Sequence[Dict[str, Sequence]], return_tensors=None) -> Dict[str, torch.Tensor]:
        """Collate the textual keys of `features` into padded batch tensors."""
        # Keep only the textual keys; extra per-sample keys are intentionally dropped.
        text_features = [{k: feature[k] for k in self.text_keys if k in feature} for feature in features]

        # The two original branches were identical except for the padding side;
        # collapse them, and restore the tokenizer's padding side via try/finally
        # so it is reset even if collation raises.
        old_padding_side = self.tokenizer.padding_side
        self.tokenizer.padding_side = 'left' if self.inference_mode else 'right'
        try:
            text_features = super().__call__(text_features)
        finally:
            self.tokenizer.padding_side = old_padding_side

        return text_features
|
|
|
|
class Seq2Seq2DataCollatorWithImage(Seq2SeqDataCollator):
    """Seq2SeqDataCollator that also stacks per-sample images into one batch tensor."""

    def __init__(self, preprocessor, **kwargs):
        # The text tokenizer is pulled out of the preprocessor bundle.
        super().__init__(tokenizer=preprocessor['text'], **kwargs)

    def _image_process(self, features: List[Dict[str, Any]]) -> Dict[str, Any]:
        # Stack each sample's 'image' tensor along a new leading batch dimension.
        batched = torch.stack([item['image'] for item in features], dim=0)
        return dict(images=batched)

    def __call__(self, features: List[Dict[str, Any]], return_tensors=None) -> Dict[str, torch.Tensor]:
        """Collate text via the parent, then merge in the stacked image batch."""
        collated = super().__call__(features, return_tensors)
        collated.update(self._image_process(features))
        return collated
|
|