"""LoRA+ optimizer construction and a Trainer subclass that uses it.

LoRA+ trains the LoRA ``B`` matrices with a learning rate that is a fixed
multiple (``loraplus_lr_ratio``) of the one used for the ``A`` matrices, and
gives LoRA embedding layers their own, separately configured learning rate.
"""

from dataclasses import dataclass, field
from functools import reduce
from typing import Callable, Dict, List, Optional, Tuple, Union

import torch
import torch.nn as nn
from peft.tuners import lora
from torch.utils.data import Dataset
from transformers import Seq2SeqTrainingArguments, Trainer
from transformers.data.data_collator import DataCollator
from transformers.pytorch_utils import ALL_LAYERNORM_LAYERS
from transformers.trainer import (
    EvalPrediction,
    PreTrainedModel,
    PreTrainedTokenizerBase,
    TrainerCallback,
)
from transformers.trainer_pt_utils import get_parameter_names
from transformers.utils import is_sagemaker_mp_enabled, logging

from logTrainer import LogTrainer

logger = logging.get_logger(__name__)


@dataclass
class LoraPlusTrainingArguments(Seq2SeqTrainingArguments):
    """Seq2Seq training arguments extended with the two LoRA+ settings."""

    # Ratio lr_B / lr_A. ``None`` makes ``LoraPlusTrainer.create_optimizer``
    # fall back to the parent trainer's optimizer.
    loraplus_lr_ratio: Optional[float] = field(
        default=None, metadata={"help": "loraplus learning rate ratio lr_B / lr_A."}
    )
    # Learning rate applied to parameters that live in ``lora.Embedding`` modules.
    loraplus_lr_embedding: Optional[float] = field(
        default=1e-6,
        metadata={"help": "loraplus learning rate for lora embedding layers."},
    )


def get_module(name: str, opt_model: nn.Module) -> nn.Module:
    """Retrieve the module that owns a parameter, given the parameter's name.

    The dotted ``name`` is split on ``"."``; the trailing component(s) are
    dropped and the remaining path is resolved attribute by attribute
    starting from ``opt_model``.

    Args:
        name: Full dotted name of the parameter, as yielded by
            ``opt_model.named_parameters()``.
        opt_model: The model from which to retrieve the module.

    Returns:
        The module reached by following the truncated path.
    """
    # Names containing "lora" lose their last TWO components, all others only
    # the last one. Presumably this skips the per-adapter container level that
    # PEFT inserts (e.g. ``...lora_A.<adapter>.weight``) -- verify against the
    # PEFT version in use.
    parent_idx = 2 if "lora" in name else 1
    module_names = name.split(sep=".")[:-parent_idx]
    return reduce(getattr, module_names, opt_model)


def create_loraplus_optimizer(
    opt_model: nn.Module,
    optimizer_cls: type,
    optimizer_kwargs: dict,
    loraplus_lr_ratio: float,
    loraplus_lr_embedding: Optional[float] = None,
) -> torch.optim.Optimizer:
    """Create an optimizer whose parameter groups follow the LoRA+ scheme.

    Trainable parameters are sorted into four groups:

    * ``groupA``: everything not matched below; uses the base ``lr``.
    * ``embedding``: parameters owned by a ``lora.Embedding`` module; uses
      ``loraplus_lr_embedding``.
    * ``groupB``: parameters whose name contains ``"lora_B"`` or that are
      1-D, and that are eligible for weight decay; uses
      ``lr * loraplus_lr_ratio``.
    * ``groupB_no_decay``: same selection as ``groupB`` but not eligible for
      weight decay; uses ``lr * loraplus_lr_ratio`` with zero weight decay.

    Args:
        opt_model: The model for which the optimizer is being created.
        optimizer_cls: The optimizer class to instantiate
            (e.g. ``torch.optim.Adam``).
        optimizer_kwargs: Keyword arguments for the optimizer. Must contain
            ``"lr"``; ``"weight_decay"`` is optional and defaults to ``0.0``.
        loraplus_lr_ratio: Multiplier applied to ``lr`` for the B groups.
        loraplus_lr_embedding: Learning rate for the embedding group.
            ``None`` is replaced by ``1e-6``.

    Returns:
        An instance of ``optimizer_cls`` built from the four parameter groups.

    Raises:
        AssertionError: If ``loraplus_lr_ratio`` is ``None``.
        KeyError: If ``optimizer_kwargs`` has no ``"lr"`` entry.
    """
    assert loraplus_lr_ratio is not None, "loraplus_lr_ratio must be provided."

    if loraplus_lr_embedding is None:
        loraplus_lr_embedding = 1e-6

    # Names eligible for weight decay: everything outside the layer-norm
    # layers, minus biases. Kept as a set because it is probed once per
    # trainable parameter in the loop below.
    decay_parameters = {
        name
        for name in get_parameter_names(opt_model, ALL_LAYERNORM_LAYERS)
        if "bias" not in name
    }

    param_groups = {
        "groupA": {},
        "groupB": {},
        "groupB_no_decay": {},
        "embedding": {},
    }

    for name, param in opt_model.named_parameters():
        if not param.requires_grad:
            continue

        module = get_module(name, opt_model)
        if isinstance(module, lora.Embedding):
            param_groups["embedding"][name] = param
        elif "lora_B" in name or param.ndim == 1:
            if name in decay_parameters:
                param_groups["groupB"][name] = param
            else:
                param_groups["groupB_no_decay"][name] = param
        else:
            param_groups["groupA"][name] = param

    assigned_param_groups = "".join(
        f"{group}\n {list(params)}\n\n" for group, params in param_groups.items()
    )
    logger.debug(assigned_param_groups)

    lr = optimizer_kwargs["lr"]
    # NOTE(review): weight decay is taken only from ``optimizer_kwargs``; the
    # caller visible in this file does not add it explicitly, so confirm that
    # the configured weight decay actually reaches this function.
    weight_decay = optimizer_kwargs.get("weight_decay", 0.0)

    # Per-group "lr"/"weight_decay" entries take precedence over the defaults
    # passed through ``**optimizer_kwargs``.
    optimizer_grouped_parameters = [
        {
            "params": list(param_groups["groupA"].values()),
            "weight_decay": weight_decay,
            "lr": lr,
        },
        {
            "params": list(param_groups["embedding"].values()),
            "weight_decay": weight_decay,
            "lr": loraplus_lr_embedding,
        },
        {
            "params": list(param_groups["groupB"].values()),
            "weight_decay": weight_decay,
            "lr": lr * loraplus_lr_ratio,
        },
        {
            "params": list(param_groups["groupB_no_decay"].values()),
            "weight_decay": 0.0,
            "lr": lr * loraplus_lr_ratio,
        },
    ]

    optimizer = optimizer_cls(optimizer_grouped_parameters, **optimizer_kwargs)

    if optimizer_cls.__name__ == "Adam8bit":
        # Imported lazily so bitsandbytes is only required for 8-bit Adam.
        import bitsandbytes

        manager = bitsandbytes.optim.GlobalOptimManager.get_instance()

        # Running total of embedding parameter elements that are switched to
        # 32-bit optimizer state; keyed by data pointer so that parameters
        # sharing storage within a module are counted once.
        skipped = 0
        for module in opt_model.modules():
            if isinstance(module, nn.Embedding):
                skipped += sum(
                    {p.data_ptr(): p.numel() for p in module.parameters()}.values()
                )
                logger.info(f"skipped {module}: {skipped/2**20}M params")
                manager.register_module_override(module, "weight", {"optim_bits": 32})
                logger.debug(f"bitsandbytes: will optimize {module} in fp32")
        logger.info(f"skipped: {skipped/2**20}M params")

    return optimizer


class LoraPlusTrainer(LogTrainer):
    """``LogTrainer`` variant that builds its optimizer with LoRA+ groups."""

    def __init__(
        self,
        model: Union[PreTrainedModel, nn.Module] = None,
        args: LoraPlusTrainingArguments = None,
        data_collator: Optional[DataCollator] = None,
        train_dataset: Optional[Dataset] = None,
        eval_dataset: Optional[Union[Dataset, Dict[str, Dataset]]] = None,
        tokenizer: Optional[PreTrainedTokenizerBase] = None,
        model_init: Optional[Callable[[], PreTrainedModel]] = None,
        compute_metrics: Optional[Callable[[EvalPrediction], Dict]] = None,
        callbacks: Optional[List[TrainerCallback]] = None,
        optimizers: Tuple[torch.optim.Optimizer, torch.optim.lr_scheduler.LambdaLR] = (
            None,
            None,
        ),
        preprocess_logits_for_metrics: Optional[
            Callable[[torch.Tensor, torch.Tensor], torch.Tensor]
        ] = None,
    ):
        """Validate ``args`` and forward every argument to the parent trainer.

        Raises:
            AssertionError: If ``args`` is not a ``LoraPlusTrainingArguments``.
        """
        assert isinstance(
            args, LoraPlusTrainingArguments
        ), "args must be of type LoraPlusTrainingArguments"
        # Arguments are forwarded positionally, in the order declared above.
        super().__init__(
            model,
            args,
            data_collator,
            train_dataset,
            eval_dataset,
            tokenizer,
            model_init,
            compute_metrics,
            callbacks,
            optimizers,
            preprocess_logits_for_metrics,
        )

    def create_optimizer(self):
        """Create (once) and return the optimizer, using LoRA+ when configured.

        When ``args.loraplus_lr_ratio`` is ``None`` the parent implementation
        is used unchanged. Otherwise, if no optimizer exists yet, one is built
        with ``create_loraplus_optimizer`` and cached on ``self.optimizer``.

        Returns:
            The optimizer stored on the trainer.
        """
        if self.args.loraplus_lr_ratio is None:
            return super().create_optimizer()

        opt_model = self.model_wrapped if is_sagemaker_mp_enabled() else self.model
        if self.optimizer is None:
            optimizer_cls, optimizer_kwargs = Trainer.get_optimizer_cls_and_kwargs(
                self.args
            )
            self.optimizer = create_loraplus_optimizer(
                opt_model,
                optimizer_cls,
                optimizer_kwargs,
                # Known to exist and be non-None: checked at the top of this method.
                self.args.loraplus_lr_ratio,
                getattr(self.args, "loraplus_lr_embedding", None),
            )

        return self.optimizer