import os
import torch
import torch.nn as nn

from transformers import Trainer
from transformers.trainer import (
    is_sagemaker_mp_enabled,
    get_parameter_names,
    ALL_LAYERNORM_LAYERS,
    is_peft_available,
    WEIGHTS_NAME,
    TRAINING_ARGS_NAME,
    SAFE_WEIGHTS_NAME,
    TRAINER_STATE_NAME,
    PREFIX_CHECKPOINT_DIR,
    logger,
)
import safetensors
# `safetensors.torch` is a submodule that `import safetensors` alone does not
# load; import it explicitly instead of relying on transformers having done so.
import safetensors.torch
from peft import PeftModel
from typing import Optional
import numpy as np
from transformers.processing_utils import ProcessorMixin
from transformers.modeling_utils import PreTrainedModel
from peft import PeftModel
from training.train_utils import get_peft_state_maybe_zero_3, get_peft_state_non_lora_maybe_zero_3


def maybe_zero_3(param, ignore_status=False, name=None):
    """Return a detached CPU clone of ``param``.

    Parameters carrying a ``ds_id`` attribute are treated as DeepSpeed
    ZeRO-managed and are cloned inside ``zero.GatheredParameters``. Any other
    tensor is cloned directly.

    Args:
        param: Tensor or parameter to copy.
        ignore_status: When False, print a notice if the ZeRO parameter's
            status is ``NOT_AVAILABLE`` before gathering it.
        name: Label used in that notice.

    Returns:
        A detached clone of the data living on the CPU.
    """
    if not hasattr(param, "ds_id"):
        # Plain tensor: no DeepSpeed involvement, so deepspeed need not be installed.
        return param.detach().cpu().clone()

    # Imported lazily so that only ZeRO-managed parameters require deepspeed.
    from deepspeed import zero
    from deepspeed.runtime.zero.partition_parameters import ZeroParamStatus

    if param.ds_status == ZeroParamStatus.NOT_AVAILABLE and not ignore_status:
        print(name, "no ignore status")
    with zero.GatheredParameters([param]):
        return param.data.detach().cpu().clone()


def _weight_decay_groups(named_params, decay_names, weight_decay, lr=None):
    """Build the (decayed, non-decayed) optimizer group pair for ``named_params``.

    Only parameters with ``requires_grad`` set are included. Both groups are
    always returned, even when empty, so the number and order of optimizer
    param groups does not depend on which parameters happen to be trainable.

    Args:
        named_params: Iterable of ``(name, parameter)`` pairs.
        decay_names: Set of parameter names that receive weight decay.
        weight_decay: Decay value for the decayed group.
        lr: Optional learning rate attached to both groups.

    Returns:
        A two-element list of optimizer param-group dicts.
    """
    decayed, undecayed = [], []
    for name, param in named_params:
        if not param.requires_grad:
            continue
        (decayed if name in decay_names else undecayed).append(param)

    groups = [
        {"params": decayed, "weight_decay": weight_decay},
        {"params": undecayed, "weight_decay": 0.0},
    ]
    if lr is not None:
        for group in groups:
            group["lr"] = lr
    return groups


class QwenTrainer(Trainer):
    """``Trainer`` subclass that also tracks a processor.

    It supports separate learning rates for parameters whose names contain
    "visual" / "merger", and writes an extra ``non_lora_state_dict.bin`` into
    checkpoints when ``args.lora_enable`` is set.
    """

    def __init__(self, processor, *args, **kwargs):
        """Initialise the base ``Trainer`` and keep ``processor`` for saving."""
        super().__init__(*args, **kwargs)
        self.processor = processor

    def create_optimizer(self):
        """Set up the optimizer.

        Parameters are split into weight-decay / no-decay groups (names from
        ``get_parameter_names(..., ALL_LAYERNORM_LAYERS)`` minus anything
        containing "bias" are decayed). When ``args.vision_lr`` or
        ``args.merger_lr`` is set, the matching parameters are moved into
        their own pair of groups with that learning rate. An optimizer that
        already exists is returned untouched.

        Returns:
            The optimizer stored on ``self.optimizer``.
        """
        if is_sagemaker_mp_enabled():
            return super().create_optimizer()

        opt_model = self.model

        if self.optimizer is None:
            # Sets give O(1) membership tests; the names are checked once per parameter.
            decay_names = {
                name
                for name in get_parameter_names(opt_model, ALL_LAYERNORM_LAYERS)
                if "bias" not in name
            }
            # Walk the model once and reuse the result for every group.
            named_params = list(opt_model.named_parameters())

            vision_lr = self.args.vision_lr
            merger_lr = self.args.merger_lr

            visual_names = set()
            merger_names = set()
            if vision_lr is not None:
                # "merger" parameters are excluded here so the two sets stay disjoint.
                visual_names = {
                    name for name, _ in named_params if "visual" in name and "merger" not in name
                }
            if merger_lr is not None:
                merger_names = {name for name, _ in named_params if "merger" in name}

            special_names = visual_names | merger_names
            weight_decay = self.args.weight_decay

            # Group order (default, visual, merger) is kept stable because
            # optimizer state dicts address param groups by index.
            optimizer_grouped_parameters = _weight_decay_groups(
                [(n, p) for n, p in named_params if n not in special_names],
                decay_names,
                weight_decay,
            )
            if visual_names:
                optimizer_grouped_parameters.extend(
                    _weight_decay_groups(
                        [(n, p) for n, p in named_params if n in visual_names],
                        decay_names,
                        weight_decay,
                        lr=vision_lr,
                    )
                )
            if merger_names:
                optimizer_grouped_parameters.extend(
                    _weight_decay_groups(
                        [(n, p) for n, p in named_params if n in merger_names],
                        decay_names,
                        weight_decay,
                        lr=merger_lr,
                    )
                )

            optimizer_cls, optimizer_kwargs = Trainer.get_optimizer_cls_and_kwargs(self.args)
            self.optimizer = optimizer_cls(optimizer_grouped_parameters, **optimizer_kwargs)

            if optimizer_cls.__name__ == "Adam8bit":
                import bitsandbytes

                manager = bitsandbytes.optim.GlobalOptimManager.get_instance()

                skipped = 0
                for module in opt_model.modules():
                    if isinstance(module, nn.Embedding):
                        # Keyed by data_ptr so tied weights are only counted once per module.
                        skipped += sum({p.data_ptr(): p.numel() for p in module.parameters()}.values())
                        logger.info(f"skipped {module}: {skipped/2**20}M params")
                        manager.register_module_override(module, "weight", {"optim_bits": 32})
                        logger.debug(f"bitsandbytes: will optimize {module} in fp32")
                logger.info(f"skipped: {skipped/2**20}M params")

        return self.optimizer

    def _save_checkpoint(self, model, trial):
        """Write a training checkpoint.

        With ``args.lora_enable`` set, the checkpoint additionally receives a
        ``non_lora_state_dict.bin`` file; otherwise the base implementation
        is used unchanged.
        """
        if not self.args.lora_enable:
            super()._save_checkpoint(model, trial)
            return

        checkpoint_folder = f"{PREFIX_CHECKPOINT_DIR}-{self.state.global_step}"

        if self.hp_search_backend is None and trial is None:
            self.store_flos()

        run_dir = self._get_output_dir(trial=trial)
        output_dir = os.path.join(run_dir, checkpoint_folder)

        self.save_model(output_dir, _internal_call=True)

        # Collected on every rank: the helper presumably gathers ZeRO-3 shards,
        # which would need all ranks to participate -- TODO confirm against
        # training.train_utils.
        non_lora_weights = get_peft_state_non_lora_maybe_zero_3(
            self.model.named_parameters(), require_grad_only=False
        )
        # Only the saving rank writes the file, so ranks do not race on the same path.
        if self.args.should_save:
            os.makedirs(output_dir, exist_ok=True)
            torch.save(non_lora_weights, os.path.join(output_dir, "non_lora_state_dict.bin"))

        if not self.args.save_only_model:
            # Save optimizer and scheduler
            self._save_optimizer_and_scheduler(output_dir)
            # Save RNG state
            self._save_rng_state(output_dir)

        # Save the Trainer state
        if self.args.should_save:
            # Update the `TrainerControl` state to where we are currently
            self.state.stateful_callbacks["TrainerControl"] = self.control.state()
            self.state.save_to_json(os.path.join(output_dir, TRAINER_STATE_NAME))

        if self.args.push_to_hub:
            self._push_from_checkpoint(output_dir)

        # Maybe delete some older checkpoints.
        if self.args.should_save:
            # Solely rely on numerical checkpoint id for rotation.
            # mtime is not reliable especially on some fuse fs in cloud environments.
            self._rotate_checkpoints(use_mtime=False, output_dir=run_dir)

    def _save(self, output_dir: Optional[str] = None, state_dict=None):
        """Save model weights, tokenizer, processor and training args.

        Args:
            output_dir: Target directory; defaults to ``args.output_dir``.
            state_dict: Optional state dict to save instead of the model's own.
        """
        # If we are executing this function, we are the process zero, so we don't check for that.
        output_dir = output_dir if output_dir is not None else self.args.output_dir
        os.makedirs(output_dir, exist_ok=True)
        logger.info(f"Saving model checkpoint to {output_dir}")

        supported_classes = (PreTrainedModel,) if not is_peft_available() else (PreTrainedModel, PeftModel)
        # Save a trained model and configuration using `save_pretrained()`.
        # They can then be reloaded using `from_pretrained()`
        if isinstance(self.model, supported_classes):
            self.model.save_pretrained(
                output_dir, state_dict=state_dict, safe_serialization=self.args.save_safetensors
            )
        else:
            if state_dict is None:
                state_dict = self.model.state_dict()

            unwrapped = self.accelerator.unwrap_model(self.model)
            if isinstance(unwrapped, supported_classes):
                unwrapped.save_pretrained(
                    output_dir, state_dict=state_dict, safe_serialization=self.args.save_safetensors
                )
            else:
                logger.info("Trainer.model is not a `PreTrainedModel`, only saving its state dict.")
                if self.args.save_safetensors:
                    safetensors.torch.save_file(
                        state_dict, os.path.join(output_dir, SAFE_WEIGHTS_NAME), metadata={"format": "pt"}
                    )
                else:
                    torch.save(state_dict, os.path.join(output_dir, WEIGHTS_NAME))

        if self.tokenizer is not None:
            self.tokenizer.save_pretrained(output_dir)

        if self.processor is not None:
            self.processor.save_pretrained(output_dir)

        # Good practice: save your training arguments together with the trained model
        torch.save(self.args, os.path.join(output_dir, TRAINING_ARGS_NAME))

    # def training_step(self, model, inputs):
    #     for name, param in model.named_parameters():
    #         if 'visual' in name and param.requires_grad:
    #             print(f"Training parameter {name}")
    #
    #     return super().training_step(model, inputs)