# Source: models/src/training/trainer.py, uploaded by litwell with huggingface_hub
# (commit f8e18e6, verified).
import os
import torch
import torch.nn as nn
from transformers import Trainer
from transformers.trainer import (
is_sagemaker_mp_enabled,
get_parameter_names,
ALL_LAYERNORM_LAYERS,
is_peft_available,
WEIGHTS_NAME,
TRAINING_ARGS_NAME,
SAFE_WEIGHTS_NAME,
TRAINER_STATE_NAME,
PREFIX_CHECKPOINT_DIR,
logger,
)
import safetensors
from peft import PeftModel
from typing import Optional
import numpy as np
from transformers.processing_utils import ProcessorMixin
from transformers.modeling_utils import PreTrainedModel
from peft import PeftModel
from training.train_utils import get_peft_state_maybe_zero_3, get_peft_state_non_lora_maybe_zero_3
def maybe_zero_3(param, ignore_status=False, name=None):
    """Return a detached CPU clone of ``param``, gathering it first if needed.

    A parameter that carries a ``ds_id`` attribute is treated as managed by
    DeepSpeed ZeRO: it is cloned inside ``zero.GatheredParameters`` so the
    clone is taken while the parameter is gathered. Any other tensor is
    simply detached, moved to CPU and cloned.

    Args:
        param: Tensor or parameter to copy.
        ignore_status: When false, a message is printed if the parameter's
            ``ds_status`` is ``ZeroParamStatus.NOT_AVAILABLE``.
        name: Label used in that message only.

    Returns:
        A detached CPU copy of the (gathered) parameter data.
    """
    if not hasattr(param, "ds_id"):
        # Plain tensors never touch DeepSpeed, so this path must work even
        # when DeepSpeed is not installed.
        return param.detach().cpu().clone()

    # Imported lazily: only ZeRO-managed parameters need DeepSpeed.
    from deepspeed import zero
    from deepspeed.runtime.zero.partition_parameters import ZeroParamStatus

    if param.ds_status == ZeroParamStatus.NOT_AVAILABLE and not ignore_status:
        print(name, "no ignore status")
    with zero.GatheredParameters([param]):
        return param.data.detach().cpu().clone()
class QwenTrainer(Trainer):
    """`Trainer` subclass that keeps a processor and customises optimizer/saving.

    Differences from the base class visible in this file:

    * ``create_optimizer`` can give parameters whose names contain
      ``"visual"`` / ``"merger"`` their own learning rates, read from
      ``self.args.vision_lr`` and ``self.args.merger_lr``.
    * ``_save_checkpoint`` additionally writes ``non_lora_state_dict.bin``
      when ``self.args.lora_enable`` is set.
    * ``_save`` also saves ``self.processor`` next to the model.
    """

    def __init__(self, processor, *args, **kwargs):
        """Initialise the base trainer and remember ``processor``.

        Args:
            processor: Object saved alongside the model in ``_save`` via its
                ``save_pretrained`` method; may be ``None``.
            *args: Forwarded unchanged to ``Trainer.__init__``.
            **kwargs: Forwarded unchanged to ``Trainer.__init__``.
        """
        super().__init__(*args, **kwargs)
        self.processor = processor

    def create_optimizer(self):
        """Set up the optimizer, optionally with per-module learning rates.

        Parameter groups are always emitted in this order: default
        (decay, no-decay), then vision (decay, no-decay) if
        ``args.vision_lr`` is set and matching parameters exist, then merger
        (decay, no-decay) under the same condition for ``args.merger_lr``.
        Only parameters with ``requires_grad`` are included. A parameter is
        decayed when ``get_parameter_names`` returns it and its name does not
        contain ``"bias"``.

        Returns:
            The optimizer stored on ``self.optimizer`` (or whatever the base
            class returns when SageMaker model parallelism is enabled).
        """
        if is_sagemaker_mp_enabled():
            return super().create_optimizer()

        opt_model = self.model

        if self.optimizer is None:
            train_args = self.args

            # Sets give O(1) membership tests; the name lists were previously
            # scanned once per parameter per group.
            decay_names = {
                name
                for name in get_parameter_names(opt_model, ALL_LAYERNORM_LAYERS)
                if "bias" not in name
            }

            visual_names = set()
            merger_names = set()
            if train_args.vision_lr is not None:
                # "merger" parameters are excluded here so they can get their
                # own learning rate below.
                visual_names = {
                    name
                    for name, _ in opt_model.named_parameters()
                    if "visual" in name and "merger" not in name
                }
            if train_args.merger_lr is not None:
                merger_names = {
                    name for name, _ in opt_model.named_parameters() if "merger" in name
                }
            special_names = visual_names | merger_names

            def trainable(selected, decayed):
                """Trainable params accepted by ``selected`` with the given decay flag."""
                return [
                    param
                    for name, param in opt_model.named_parameters()
                    if param.requires_grad
                    and selected(name)
                    and (name in decay_names) == decayed
                ]

            def decay_pair(selected, lr=None):
                """Build the (decay, no-decay) group pair for one parameter subset."""
                extra = {} if lr is None else {"lr": lr}
                return [
                    {
                        "params": trainable(selected, True),
                        "weight_decay": train_args.weight_decay,
                        **extra,
                    },
                    {
                        "params": trainable(selected, False),
                        "weight_decay": 0.0,
                        **extra,
                    },
                ]

            # With no special learning rates `special_names` is empty, so the
            # default pair simply covers every trainable parameter.
            optimizer_grouped_parameters = decay_pair(
                lambda name: name not in special_names
            )
            if visual_names:
                optimizer_grouped_parameters.extend(
                    decay_pair(lambda name: name in visual_names, lr=train_args.vision_lr)
                )
            if merger_names:
                optimizer_grouped_parameters.extend(
                    decay_pair(lambda name: name in merger_names, lr=train_args.merger_lr)
                )

            optimizer_cls, optimizer_kwargs = Trainer.get_optimizer_cls_and_kwargs(self.args)
            self.optimizer = optimizer_cls(optimizer_grouped_parameters, **optimizer_kwargs)

            if optimizer_cls.__name__ == "Adam8bit":
                import bitsandbytes

                manager = bitsandbytes.optim.GlobalOptimManager.get_instance()

                skipped = 0
                for module in opt_model.modules():
                    if isinstance(module, nn.Embedding):
                        # Keyed by data_ptr so tensors sharing storage are counted once.
                        skipped += sum({p.data_ptr(): p.numel() for p in module.parameters()}.values())
                        logger.info(f"skipped {module}: {skipped/2**20}M params")
                        manager.register_module_override(module, "weight", {"optim_bits": 32})
                        logger.debug(f"bitsandbytes: will optimize {module} in fp32")
                logger.info(f"skipped: {skipped/2**20}M params")

        return self.optimizer

    def _save_checkpoint(self, model, trial, metrics=None):
        """Save a checkpoint, adding the non-LoRA weights when LoRA is enabled.

        When ``self.args.lora_enable`` is false this defers entirely to the
        base class. Otherwise it saves the model, writes
        ``non_lora_state_dict.bin`` into the checkpoint folder, then saves
        optimizer/scheduler/RNG state (unless ``save_only_model``), the
        trainer state, optionally pushes to the hub, and rotates old
        checkpoints.

        Args:
            model: Passed through to the base implementation.
            trial: Hyper-parameter search trial, or ``None``.
            metrics: Accepted so callers that pass ``metrics=...`` do not
                fail; forwarded to the base class only when provided and
                not used by the LoRA branch.
                NOTE(review): whether the installed transformers release
                passes this keyword depends on its version - confirm.
        """
        if not self.args.lora_enable:
            if metrics is None:
                super()._save_checkpoint(model, trial)
            else:
                super()._save_checkpoint(model, trial, metrics=metrics)
            return

        checkpoint_folder = f"{PREFIX_CHECKPOINT_DIR}-{self.state.global_step}"

        if self.hp_search_backend is None and trial is None:
            self.store_flos()

        run_dir = self._get_output_dir(trial=trial)
        output_dir = os.path.join(run_dir, checkpoint_folder)

        self.save_model(output_dir, _internal_call=True)

        # The helper is still called on every process, exactly as before
        # (presumably it gathers ZeRO-3 shards collectively - TODO confirm).
        non_lora_weights = get_peft_state_non_lora_maybe_zero_3(
            self.model.named_parameters(), require_grad_only=False
        )
        # Only the saving process writes the file: having every rank write
        # the same path races on shared storage and can hit a missing
        # directory on ranks that did not create it.
        if self.args.should_save:
            os.makedirs(output_dir, exist_ok=True)
            torch.save(non_lora_weights, os.path.join(output_dir, "non_lora_state_dict.bin"))

        if not self.args.save_only_model:
            # Save optimizer and scheduler
            self._save_optimizer_and_scheduler(output_dir)
            # Save RNG state
            self._save_rng_state(output_dir)

        # Save the Trainer state
        if self.args.should_save:
            # Update the `TrainerControl` state to where we are currently
            self.state.stateful_callbacks["TrainerControl"] = self.control.state()
            self.state.save_to_json(os.path.join(output_dir, TRAINER_STATE_NAME))

        if self.args.push_to_hub:
            self._push_from_checkpoint(output_dir)

        # Maybe delete some older checkpoints.
        if self.args.should_save:
            # Solely rely on numerical checkpoint id for rotation.
            # mtime is not reliable especially on some fuse fs in cloud environments.
            self._rotate_checkpoints(use_mtime=False, output_dir=run_dir)

    def _save(self, output_dir: Optional[str] = None, state_dict=None):
        """Write the model, tokenizer, processor and training args to disk.

        Args:
            output_dir: Target directory; defaults to ``self.args.output_dir``.
                Created if it does not exist.
            state_dict: Optional state dict to save instead of the model's own.
        """
        # If we are executing this function, we are the process zero, so we don't check for that.
        output_dir = output_dir if output_dir is not None else self.args.output_dir
        os.makedirs(output_dir, exist_ok=True)
        logger.info(f"Saving model checkpoint to {output_dir}")

        supported_classes = (PreTrainedModel,) if not is_peft_available() else (PreTrainedModel, PeftModel)
        # Save a trained model and configuration using `save_pretrained()`.
        # They can then be reloaded using `from_pretrained()`
        if isinstance(self.model, supported_classes):
            self.model.save_pretrained(
                output_dir, state_dict=state_dict, safe_serialization=self.args.save_safetensors
            )
        else:
            if state_dict is None:
                state_dict = self.model.state_dict()

            unwrapped_model = self.accelerator.unwrap_model(self.model)
            if isinstance(unwrapped_model, supported_classes):
                unwrapped_model.save_pretrained(
                    output_dir, state_dict=state_dict, safe_serialization=self.args.save_safetensors
                )
            else:
                logger.info("Trainer.model is not a `PreTrainedModel`, only saving its state dict.")
                if self.args.save_safetensors:
                    # Import the submodule explicitly: a bare `import safetensors`
                    # does not guarantee that `safetensors.torch` is loaded.
                    from safetensors.torch import save_file as save_safetensors_file

                    save_safetensors_file(
                        state_dict, os.path.join(output_dir, SAFE_WEIGHTS_NAME), metadata={"format": "pt"}
                    )
                else:
                    torch.save(state_dict, os.path.join(output_dir, WEIGHTS_NAME))

        if self.tokenizer is not None:
            self.tokenizer.save_pretrained(output_dir)

        if self.processor is not None:
            self.processor.save_pretrained(output_dir)

        # Good practice: save your training arguments together with the trained model
        torch.save(self.args, os.path.join(output_dir, TRAINING_ARGS_NAME))

    # Disabled debugging hook, kept for reference: prints the name of every
    # trainable "visual" parameter before each training step.
    # def training_step(self, model, inputs):
    #     for name, param in model.named_parameters():
    #         if 'visual' in name and param.requires_grad:
    #             print(f"Training parameter {name}")
    #
    #     return super().training_step(model, inputs)