| | import os |
| | import torch |
| | from transformers import Trainer |
| | from typing import Optional |
| |
|
| |
|
def maybe_zero_3(param, ignore_status=False, name=None):
    """Return a detached CPU copy of ``param``, gathering it first under DeepSpeed ZeRO-3.

    Objects that carry a ``ds_id`` attribute are treated as DeepSpeed ZeRO-3
    partitioned parameters: their full data is gathered with
    ``deepspeed.zero.GatheredParameters`` before being copied. Any other
    tensor is simply detached, moved to CPU and cloned.

    Args:
        param: Parameter / tensor to copy.
        ignore_status: When False, a warning is logged if a ZeRO-3 parameter
            is in the ``NOT_AVAILABLE`` state before it is gathered.
        name: Optional parameter name; used only in that warning.

    Returns:
        A detached CPU tensor that does not share storage with ``param``.
    """
    if not hasattr(param, "ds_id"):
        # Plain tensor: no DeepSpeed involvement, so DeepSpeed need not even
        # be installed for this path.
        return param.detach().cpu().clone()

    # Imported lazily so callers without DeepSpeed can still use the plain path.
    import logging

    from deepspeed import zero
    from deepspeed.runtime.zero.partition_parameters import ZeroParamStatus

    if param.ds_status == ZeroParamStatus.NOT_AVAILABLE and not ignore_status:
        logging.warning("%s no ignore status", name)
    # Gather the partitioned data, then copy it off-device while it is materialized.
    with zero.GatheredParameters([param]):
        copied = param.data.detach().cpu().clone()
    return copied
| |
|
| |
|
def get_mm_adapter_state_maybe_zero_3(named_params, keys_to_match):
    """Collect CPU copies of every named parameter whose name contains any of ``keys_to_match``.

    Args:
        named_params: Iterable of ``(name, tensor)`` pairs, e.g. the result of
            ``model.named_parameters()``.
        keys_to_match: Substrings; a parameter is kept when its name contains
            at least one of them.

    Returns:
        Dict mapping each selected parameter name to the CPU copy produced by
        ``maybe_zero_3`` (called with ``ignore_status=True``).
    """
    # Phase 1: select matching parameters (a later duplicate name overwrites an earlier one).
    selected = {}
    for param_name, tensor in named_params:
        if not any(pattern in param_name for pattern in keys_to_match):
            continue
        selected[param_name] = tensor

    # Phase 2: materialize each selected parameter as a CPU copy.
    return {
        param_name: maybe_zero_3(tensor, ignore_status=True, name=param_name).cpu()
        for param_name, tensor in selected.items()
    }
| |
|
| |
|
class ChatUniViTrainer(Trainer):
    """``Trainer`` subclass with an optional adapter-only checkpoint mode.

    The adapter-only mode (saving only selected multimodal adapter weights plus
    the model config instead of a full checkpoint) is gated by
    ``save_mm_adapter_only``. With its default value of False, both overrides
    simply defer to the base ``Trainer`` implementation.
    """

    # Master switch for the adapter-only save path. It was previously disabled
    # with a hard-coded ``if 0 and ...``; it stays off by default so behavior is
    # unchanged. Set to True (on the class or an instance) to save only the
    # matched adapter weights whenever ``args.tune_mm_mlp_adapter`` is truthy.
    save_mm_adapter_only = False

    def _mm_adapter_only_enabled(self):
        """Return True when the adapter-only save path should be taken."""
        return bool(self.save_mm_adapter_only) and bool(
            getattr(self.args, "tune_mm_mlp_adapter", False)
        )

    def _save_checkpoint(self, model, trial, metrics=None):
        """Save a checkpoint, optionally writing only the multimodal adapter weights.

        In adapter-only mode, parameters of ``self.model`` whose names contain
        ``mm_projector``, ``ctm`` or ``block`` (plus ``embed_tokens`` /
        ``embed_in`` when ``args.use_im_start_end`` is set) are written to
        ``<output_dir>/<PREFIX_CHECKPOINT_DIR>-<global_step>/mm_projector.bin``
        alongside the model config. Otherwise the base ``Trainer`` checkpoint is saved.
        """
        if not self._mm_adapter_only_enabled():
            super()._save_checkpoint(model, trial, metrics)
            return

        from transformers.trainer_utils import PREFIX_CHECKPOINT_DIR

        checkpoint_folder = f"{PREFIX_CHECKPOINT_DIR}-{self.state.global_step}"
        run_dir = self._get_output_dir(trial=trial)
        output_dir = os.path.join(run_dir, checkpoint_folder)

        keys_to_match = ["mm_projector", "ctm", "block"]
        if getattr(self.args, "use_im_start_end", False):
            keys_to_match.extend(["embed_tokens", "embed_in"])

        # Collected on every rank, not only the writing one — presumably because
        # gathering ZeRO-3 partitioned parameters involves all ranks; TODO confirm.
        weight_to_save = get_mm_adapter_state_maybe_zero_3(
            self.model.named_parameters(), keys_to_match
        )

        # Only local rank 0 (or -1, presumably a non-distributed run) writes files.
        if self.args.local_rank in (0, -1):
            os.makedirs(output_dir, exist_ok=True)
            self.model.config.save_pretrained(output_dir)
            torch.save(weight_to_save, os.path.join(output_dir, "mm_projector.bin"))

    def _save(self, output_dir: Optional[str] = None, state_dict=None):
        """Save the model via ``Trainer._save``; skipped in adapter-only mode.

        In adapter-only mode the adapter weights are written by
        ``_save_checkpoint`` instead, so the full-model save is a no-op here.
        """
        if self._mm_adapter_only_enabled():
            return
        super()._save(output_dir, state_dict)
| |
|