# Copyright (c) Alibaba, Inc. and its affiliates.
import importlib
import importlib.util
from dataclasses import dataclass
from typing import Any, Dict, List, Optional, Tuple, Union

import torch
from torch import nn
from torch.optim import Optimizer
from transformers import Trainer, TrainingArguments, get_scheduler

from swift.utils import get_logger

try:
    from torch.optim.lr_scheduler import _LRScheduler as LRScheduler
except ImportError:
    from torch.optim.lr_scheduler import LRScheduler

logger = get_logger()


@dataclass
class GaLoreConfig:
    """
    The configuration class for the Galore module.

    See https://arxiv.org/abs/2403.03507

    Args:
        rank (`int`): The galore rank
        target_modules (`Union[str, List[str]]`): The target modules to use. A module is selected
            when any of these keys is a substring of its name. A single string is treated as a
            one-element list. Must be resolved to a concrete value before building the optimizer.
        update_proj_gap (`int`): The projection update interval for galore
        galore_scale (`float`): The scale of gradient
        proj_type (`str`): The projection type of Galore, valid values are
            `std`, `reverse_std`, `right`, `left`, `full`
        optim_per_parameter (`bool`): Gives one optimizer per parameter. Ignored when `quantize`
            is enabled.
        quantize (`bool`): Use the q-galore optimizer. When enabled, the fields below are forwarded
            to the galore parameter group.
        proj_quant (`bool`): Forwarded as `quant`.
        proj_bits (`int`): Forwarded as `quant_n_bit`.
        proj_group_size (`int`): Forwarded as `quant_group_size`.
        cos_threshold (`float`): Forwarded as `cos_threshold`.
        gamma_proj (`int`): Forwarded as `gamma_proj`.
        queue_size (`int`): Forwarded as `queue_size`.
    """
    rank: int = 128
    target_modules: Optional[Union[str, List[str]]] = None
    update_proj_gap: int = 50
    galore_scale: float = 1.0
    proj_type: str = 'std'
    optim_per_parameter: bool = False
    quantize: bool = False
    proj_quant: bool = False
    proj_bits: int = 4
    proj_group_size: int = 256
    cos_threshold: float = 0.4
    gamma_proj: int = 2
    queue_size: int = 5


class GaloreOptimizerWrapper(Optimizer):
    """Presents a collection of optimizers behind a single `Optimizer` object.

    `zero_grad` and `step` are fanned out to every wrapped optimizer.
    """

    def __init__(self, optimizers: Dict[Any, Optimizer]):
        self.optimizers = optimizers
        # Placeholder parameter and lr handed to the base-class constructor;
        # this wrapper never updates it, all real work is delegated.
        super().__init__([torch.tensor([1., 2., 3.])], {'lr': 1.})

    def zero_grad(self, *args, **kwargs) -> None:
        """Call `zero_grad` on every wrapped optimizer with the given arguments."""
        for optim in self.optimizers.values():
            optim.zero_grad(*args, **kwargs)

    def step(self, *args, **kwargs) -> None:
        """Call `step` on every wrapped optimizer with the given arguments."""
        for optim in self.optimizers.values():
            optim.step(*args, **kwargs)


class GaloreSchedulerWrapper(LRScheduler):
    """Presents a collection of LR schedulers behind a single scheduler object."""

    def __init__(self, lr_schedulers: Dict[Any, LRScheduler]):
        # NOTE: the base-class __init__ is not invoked here; only the wrapped
        # schedulers are stored.
        self.lr_schedulers = lr_schedulers

    def step(self, *args, **kwargs) -> None:
        """Step every wrapped scheduler and record the last one's learning rate.

        `_last_lr` is taken from the final scheduler in iteration order. When
        there are no wrapped schedulers, nothing is stepped and `_last_lr` is
        left untouched.
        """
        last_scheduler = None
        for lr_scheduler in self.lr_schedulers.values():
            lr_scheduler.step(*args, **kwargs)
            last_scheduler = lr_scheduler
        if last_scheduler is not None:
            self._last_lr = last_scheduler.get_last_lr()


def create_optimizer_and_scheduler(model: nn.Module, args: TrainingArguments, config: GaLoreConfig, max_steps: int,
                                   **defaults):
    """Build the GaLore optimizer and its LR scheduler for `model`.

    Weights of trainable `nn.Linear` / `nn.Embedding` modules whose name contains
    one of `config.target_modules` are optimized with the GaLore settings; every
    other trainable parameter uses the plain `defaults`.

    Args:
        model: The model whose parameters are optimized.
        args: Training arguments; supplies the optimizer type and scheduler settings.
        config: The GaLore configuration.
        max_steps: Number of training steps passed to the scheduler.
        **defaults: Extra parameter-group options. `weight_decay` is required
            in the single-optimizer path.

    Returns:
        A `(optimizer, scheduler)` pair. With `config.optim_per_parameter` (and
        no quantization) these are `GaloreOptimizerWrapper` /
        `GaloreSchedulerWrapper` instances holding one optimizer and scheduler
        per trainable parameter; otherwise a single optimizer and scheduler.

    Raises:
        ValueError: If `config.target_modules` is `None`, or the optimizer type
            is not supported.
    """
    target_modules = config.target_modules
    if target_modules is None:
        raise ValueError('`GaLoreConfig.target_modules` must be set before creating the GaLore optimizer.')
    if isinstance(target_modules, str):
        # A bare string would otherwise be iterated character by character.
        target_modules = [target_modules]

    galore_params = []
    for module_name, module in model.named_modules():
        if not isinstance(module, (nn.Linear, nn.Embedding)):
            continue
        if not any(target_key in module_name for target_key in target_modules):
            continue
        if not module.weight.requires_grad:
            continue
        logger.info(f'Enable GaLore for weights in module: {module_name}')
        galore_params.append(module.weight)

    # Identity-based lookup: tensors are compared by object, not by value.
    galore_param_ids = {id(p) for p in galore_params}

    galore_defaults = {
        'rank': config.rank,
        'update_proj_gap': config.update_proj_gap,
        'scale': config.galore_scale,
        'proj_type': config.proj_type,
        **defaults
    }
    if config.quantize:
        galore_defaults['quant'] = config.proj_quant
        galore_defaults['quant_n_bit'] = config.proj_bits
        galore_defaults['quant_group_size'] = config.proj_group_size
        galore_defaults['cos_threshold'] = config.cos_threshold
        galore_defaults['gamma_proj'] = config.gamma_proj
        galore_defaults['queue_size'] = config.queue_size

    optim_cls, optim_kwargs = get_optimizer(args, config)

    if config.optim_per_parameter and not config.quantize:
        # q-galore does not support optim_per_parameter
        # NOTE(review): the projection gap, training steps and warmup steps are
        # all doubled in this mode - confirm the reason against the trainer
        # that drives these wrappers.
        galore_defaults['update_proj_gap'] = galore_defaults['update_proj_gap'] * 2
        trainable_params = [p for p in model.parameters() if p.requires_grad]

        optimizer_dict = {}
        for p in trainable_params:
            group_options = galore_defaults if id(p) in galore_param_ids else defaults
            optimizer_dict[p] = optim_cls([{'params': [p], **group_options}], **optim_kwargs)

        # get scheduler dict
        scheduler_dict = {}
        for p in trainable_params:
            scheduler_dict[p] = get_scheduler(
                optimizer=optimizer_dict[p],
                name=args.lr_scheduler_type,
                num_training_steps=max_steps * 2,
                num_warmup_steps=args.warmup_steps * 2,
                scheduler_specific_kwargs=args.lr_scheduler_kwargs,
            )

        return GaloreOptimizerWrapper(optimizer_dict), GaloreSchedulerWrapper(scheduler_dict)

    decay_parameters = Trainer.get_decay_parameter_names(Trainer, model)
    # Trainable parameters that are not handled by the GaLore group, split by
    # whether weight decay applies to them.
    decay_params = []
    no_decay_params = []
    for n, p in model.named_parameters():
        if id(p) in galore_param_ids or not p.requires_grad:
            continue
        if n in decay_parameters:
            decay_params.append(p)
        else:
            no_decay_params.append(p)

    param_groups = [
        {
            'params': galore_params,
            **galore_defaults,
        },
        {
            'params': decay_params,
            'weight_decay': defaults['weight_decay'],
        },
        {
            'params': no_decay_params,
            'weight_decay': 0.0,
        },
    ]
    optim = optim_cls(param_groups, **optim_kwargs)
    scheduler = get_scheduler(
        optimizer=optim,
        name=args.lr_scheduler_type,
        num_training_steps=max_steps,
        num_warmup_steps=args.warmup_steps,
        scheduler_specific_kwargs=args.lr_scheduler_kwargs,
    )
    return optim, scheduler


def get_optimizer(args: TrainingArguments, config: GaLoreConfig) -> Tuple[Any, Any]:
    """Select the GaLore optimizer class and its constructor kwargs.

    Args:
        args: Training arguments; `optim`, `learning_rate` and the Adam
            hyper-parameters are read from it.
        config: The GaLore configuration; `quantize` selects the q-galore AdamW.

    Returns:
        A `(optimizer_cls, optimizer_kwargs)` tuple.

    Raises:
        ImportError: If q-galore is requested but `q_galore_torch` is not installed.
        ValueError: If the 8-bit optimizer cannot be imported, or `args.optim`
            is not a supported optimizer type.
    """
    # parse args.optim_args
    # NOTE: the parsed mapping is not forwarded to the optimizer; it is kept so
    # that a malformed `optim_args` string still fails here.
    optim_args = {}
    if args.optim_args:
        for mapping in args.optim_args.replace(' ', '').split(','):
            key, value = mapping.split('=')
            optim_args[key] = value

    optimizer_kwargs = {'lr': args.learning_rate}
    adam_kwargs = {
        'betas': (args.adam_beta1, args.adam_beta2),
        'eps': args.adam_epsilon,
    }

    if args.optim == 'adafactor':
        from .adafactor import GaLoreAdafactor
        optimizer_cls = GaLoreAdafactor
        optimizer_kwargs.update({'scale_parameter': False, 'relative_step': False})
    elif args.optim in ('adamw_hf', 'adamw_torch'):
        if config.quantize:
            # Raise explicitly instead of asserting so the check survives `python -O`.
            if importlib.util.find_spec('q_galore_torch') is None:
                raise ImportError('Please install q-galore by `pip install q_galore_torch`')
            logger.info('If you encounter `absmax2` error, please downgrade your bitsandbytes to 0.40.0')
            # The same optimizer class is used for single- and multi-process runs.
            from q_galore_torch import QGaLoreAdamW8bit as GaLoreAdamW
        else:
            from .adamw import GaLoreAdamW
        optimizer_cls = GaLoreAdamW
        optimizer_kwargs.update(adam_kwargs)
    elif 'adamw' in args.optim and '8bit' in args.optim:
        try:
            from .adamw8bit import GaLoreAdamW8bit
        except ImportError as err:
            raise ValueError('Trainer tried to instantiate bnb optimizer but bnb is not installed!') from err
        optimizer_cls = GaLoreAdamW8bit
        optimizer_kwargs.update(adam_kwargs)
        optimizer_kwargs.update({'optim_bits': 8, 'is_paged': 'paged' in args.optim})
    else:
        raise ValueError(f'Galore not supported for optimizer type: {args.optim}')
    return optimizer_cls, optimizer_kwargs