import os
from typing import Optional

import numpy as np
import safetensors.torch  # needed for safetensors.torch.save_file below
import torch
import torch.nn as nn
from peft import PeftModel
from transformers import Trainer
from transformers.modeling_utils import PreTrainedModel
from transformers.processing_utils import ProcessorMixin
from transformers.trainer import (
    is_sagemaker_mp_enabled,
    get_parameter_names,
    ALL_LAYERNORM_LAYERS,
    is_peft_available,
    WEIGHTS_NAME,
    TRAINING_ARGS_NAME,
    SAFE_WEIGHTS_NAME,
    TRAINER_STATE_NAME,
    PREFIX_CHECKPOINT_DIR,
    logger,
)

from training.train_utils import get_peft_state_maybe_zero_3, get_peft_state_non_lora_maybe_zero_3


def maybe_zero_3(param, ignore_status=False, name=None):
    """Gather a (possibly ZeRO-3 partitioned) parameter and return a detached CPU copy."""
    from deepspeed import zero
    from deepspeed.runtime.zero.partition_parameters import ZeroParamStatus

    if hasattr(param, "ds_id"):
        if param.ds_status == ZeroParamStatus.NOT_AVAILABLE:
            if not ignore_status:
                logger.warning(f"{name}: param.ds_status is ZeroParamStatus.NOT_AVAILABLE")
        with zero.GatheredParameters([param]):
            param = param.data.detach().cpu().clone()
    else:
        param = param.detach().cpu().clone()
    return param


class QwenTrainer(Trainer):
    """Trainer with separate vision/merger learning rates and LoRA-aware checkpointing."""

    def __init__(self, processor, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.processor = processor

    def create_optimizer(self):
        """
        Setup the optimizer.

        We provide a reasonable default that works well. If you want to use something else, you can pass a tuple in
        the Trainer's init through `optimizers`, or subclass and override this method in a subclass.
        """
        if is_sagemaker_mp_enabled():
            return super().create_optimizer()

        opt_model = self.model

        if self.optimizer is None:
            # Apply weight decay to everything except biases and LayerNorm weights.
            decay_parameters = get_parameter_names(opt_model, ALL_LAYERNORM_LAYERS)
            decay_parameters = [name for name in decay_parameters if "bias" not in name]
            lr_mapper = {}
            visual_parameters = []
            merger_parameters = []

            if self.args.vision_lr is not None:
                lr_mapper["visual"] = self.args.vision_lr
                visual_parameters = [
                    name for name, _ in opt_model.named_parameters() if "visual" in name and "merger" not in name
                ]
            if self.args.merger_lr is not None:
                lr_mapper["merger"] = self.args.merger_lr
                merger_parameters = [name for name, _ in opt_model.named_parameters() if "merger" in name]

            if len(lr_mapper) > 0:
                special_lr_parameters = merger_parameters + visual_parameters

                # Default learning rate for everything outside the vision tower and merger.
                optimizer_grouped_parameters = [
                    {
                        "params": [
                            p
                            for n, p in opt_model.named_parameters()
                            if (n in decay_parameters and n not in special_lr_parameters and p.requires_grad)
                        ],
                        "weight_decay": self.args.weight_decay,
                    },
                    {
                        "params": [
                            p
                            for n, p in opt_model.named_parameters()
                            if (n not in decay_parameters and n not in special_lr_parameters and p.requires_grad)
                        ],
                        "weight_decay": 0.0,
                    },
                ]

                if visual_parameters:
                    # Vision tower parameters get their own learning rate.
                    optimizer_grouped_parameters.extend(
                        [
                            {
                                "params": [
                                    p
                                    for n, p in opt_model.named_parameters()
                                    if (n in decay_parameters and n in visual_parameters and p.requires_grad)
                                ],
                                "weight_decay": self.args.weight_decay,
                                "lr": self.args.vision_lr,
                            },
                            {
                                "params": [
                                    p
                                    for n, p in opt_model.named_parameters()
                                    if (n not in decay_parameters and n in visual_parameters and p.requires_grad)
                                ],
                                "weight_decay": 0.0,
                                "lr": self.args.vision_lr,
                            },
                        ]
                    )

                if merger_parameters:
                    # Merger (projector) parameters get their own learning rate.
                    optimizer_grouped_parameters.extend(
                        [
                            {
                                "params": [
                                    p
                                    for n, p in opt_model.named_parameters()
                                    if (n in decay_parameters and n in merger_parameters and p.requires_grad)
                                ],
                                "weight_decay": self.args.weight_decay,
                                "lr": self.args.merger_lr,
                            },
                            {
                                "params": [
                                    p
                                    for n, p in opt_model.named_parameters()
                                    if (n not in decay_parameters and n in merger_parameters and p.requires_grad)
                                ],
                                "weight_decay": 0.0,
                                "lr": self.args.merger_lr,
                            },
                        ]
                    )
            else:
                optimizer_grouped_parameters = [
                    {
                        "params": [
                            p for n, p in opt_model.named_parameters() if (n in decay_parameters and p.requires_grad)
                        ],
                        "weight_decay": self.args.weight_decay,
                    },
                    {
                        "params": [
                            p
                            for n, p in opt_model.named_parameters()
                            if (n not in decay_parameters and p.requires_grad)
                        ],
                        "weight_decay": 0.0,
                    },
                ]
            optimizer_cls, optimizer_kwargs = Trainer.get_optimizer_cls_and_kwargs(self.args)

            self.optimizer = optimizer_cls(optimizer_grouped_parameters, **optimizer_kwargs)
            if optimizer_cls.__name__ == "Adam8bit":
                import bitsandbytes

                manager = bitsandbytes.optim.GlobalOptimManager.get_instance()

                skipped = 0
                for module in opt_model.modules():
                    if isinstance(module, nn.Embedding):
                        # Keep embedding layers in 32-bit optimizer states to preserve quality.
                        skipped += sum({p.data_ptr(): p.numel() for p in module.parameters()}.values())
                        logger.info(f"skipped {module}: {skipped / 2**20}M params")
                        manager.register_module_override(module, "weight", {"optim_bits": 32})
                        logger.debug(f"bitsandbytes: will optimize {module} in fp32")
                logger.info(f"skipped: {skipped / 2**20}M params")

        return self.optimizer
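
    # Note: as the docstring above says, the grouped-parameter logic can be bypassed
    # entirely by passing `optimizers=(optimizer, lr_scheduler)` to the Trainer's init.
    # A minimal sketch (the AdamW/scheduler settings here are illustrative assumptions,
    # not values used by this project):
    #
    #     optimizer = torch.optim.AdamW(model.parameters(), lr=2e-5)
    #     scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=1000)
    #     trainer = QwenTrainer(processor, model=model, args=training_args,
    #                           optimizers=(optimizer, scheduler))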

    def _save_checkpoint(self, model, trial):
        if self.args.lora_enable:
            checkpoint_folder = f"{PREFIX_CHECKPOINT_DIR}-{self.state.global_step}"

            if self.hp_search_backend is None and trial is None:
                self.store_flos()

            run_dir = self._get_output_dir(trial=trial)
            output_dir = os.path.join(run_dir, checkpoint_folder)

            # `save_model` writes the LoRA adapter; the non-LoRA trainables are saved separately below.
            self.save_model(output_dir, _internal_call=True)

            non_lora_weights = get_peft_state_non_lora_maybe_zero_3(self.model.named_parameters(), require_grad_only=False)
            torch.save(non_lora_weights, os.path.join(output_dir, "non_lora_state_dict.bin"))

            if not self.args.save_only_model:
                # Save optimizer and scheduler
                self._save_optimizer_and_scheduler(output_dir)
                # Save RNG state
                self._save_rng_state(output_dir)

            # Save the Trainer state
            if self.args.should_save:
                # Update the `TrainerControl` state to where we are currently
                self.state.stateful_callbacks["TrainerControl"] = self.control.state()
                self.state.save_to_json(os.path.join(output_dir, TRAINER_STATE_NAME))

            if self.args.push_to_hub:
                self._push_from_checkpoint(output_dir)

            # Maybe delete some older checkpoints.
            if self.args.should_save:
                self._rotate_checkpoints(use_mtime=False, output_dir=run_dir)
        else:
            super()._save_checkpoint(model, trial)

    def _save(self, output_dir: Optional[str] = None, state_dict=None):
        output_dir = output_dir if output_dir is not None else self.args.output_dir
        os.makedirs(output_dir, exist_ok=True)
        logger.info(f"Saving model checkpoint to {output_dir}")

        supported_classes = (PreTrainedModel,) if not is_peft_available() else (PreTrainedModel, PeftModel)

        # Save a trained model and configuration using `save_pretrained()`;
        # it can then be reloaded with `from_pretrained()`.
        if not isinstance(self.model, supported_classes):
            if state_dict is None:
                state_dict = self.model.state_dict()

            if isinstance(self.accelerator.unwrap_model(self.model), supported_classes):
                self.accelerator.unwrap_model(self.model).save_pretrained(
                    output_dir, state_dict=state_dict, safe_serialization=self.args.save_safetensors
                )
            else:
                logger.info("Trainer.model is not a `PreTrainedModel`, only saving its state dict.")
                if self.args.save_safetensors:
                    safetensors.torch.save_file(
                        state_dict, os.path.join(output_dir, SAFE_WEIGHTS_NAME), metadata={"format": "pt"}
                    )
                else:
                    torch.save(state_dict, os.path.join(output_dir, WEIGHTS_NAME))
        else:
            self.model.save_pretrained(
                output_dir, state_dict=state_dict, safe_serialization=self.args.save_safetensors
            )

        if self.tokenizer is not None:
            self.tokenizer.save_pretrained(output_dir)

        if self.processor is not None:
            self.processor.save_pretrained(output_dir)

        # Good practice: save the training arguments together with the trained model.
        torch.save(self.args, os.path.join(output_dir, TRAINING_ARGS_NAME))
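

# --- Usage sketch (illustrative, not part of the trainer) -------------------
# A minimal sketch of how this trainer might be wired up. The checkpoint name,
# `training_args` (which must expose the custom `vision_lr`, `merger_lr`, and
# `lora_enable` fields used above), and the dataset/collator are assumptions
# made for illustration; they come from the surrounding training scripts, not
# from stock `transformers`.
#
#     from transformers import AutoProcessor, Qwen2VLForConditionalGeneration
#
#     processor = AutoProcessor.from_pretrained("Qwen/Qwen2-VL-7B-Instruct")
#     model = Qwen2VLForConditionalGeneration.from_pretrained("Qwen/Qwen2-VL-7B-Instruct")
#     trainer = QwenTrainer(
#         processor,
#         model=model,
#         args=training_args,
#         train_dataset=train_dataset,
#         data_collator=data_collator,
#     )
#     trainer.train()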