| |
| import importlib |
| from dataclasses import dataclass |
| from typing import Any, Dict, List, Tuple, Union |
|
|
| import torch |
| from torch import nn |
| from torch.optim import Optimizer |
| from transformers import Trainer, TrainingArguments, get_scheduler |
|
|
| from swift.utils import get_logger |
|
|
| try: |
| from torch.optim.lr_scheduler import _LRScheduler as LRScheduler |
| except ImportError: |
| from torch.optim.lr_scheduler import LRScheduler |
|
|
# Module-level logger from `swift.utils.get_logger`; used below to report which modules get GaLore
# and to emit optimizer installation hints.
logger = get_logger()
|
|
|
|
@dataclass
class GaLoreConfig:
    """
    The configuration class for the GaLore module.

    See https://arxiv.org/abs/2403.03507

    Args:
        rank (`int`): The GaLore projection rank.
        target_modules (`Union[str, List[str]]`, *optional*): Substrings matched against module names to select
            the `nn.Linear` / `nn.Embedding` weights that get GaLore. When `None`, the caller is expected to
            resolve it (e.g. to all attention and MLP linears) before an optimizer is built from this config.
        update_proj_gap (`int`): The projection update interval (in optimizer steps) for GaLore.
        galore_scale (`float`): The scale applied to the projected gradient.
        proj_type (`str`): The GaLore projection type; valid values are `std`, `reverse_std`, `right`, `left`,
            `full`.
        optim_per_parameter (`bool`): Create one optimizer and one LR scheduler per trainable parameter.
            Ignored when `quantize` is True.
        quantize (`bool`): Use the Q-GaLore optimizer (`q_galore_torch.QGaLoreAdamW8bit`) for the AdamW
            optimizers and forward the Q-GaLore options below to the GaLore parameter group.
        proj_quant (`bool`): Q-GaLore only; forwarded to the optimizer as `quant` (whether to quantize the
            projection).
        proj_bits (`int`): Q-GaLore only; forwarded as `quant_n_bit` (bit width of the quantized projection).
        proj_group_size (`int`): Q-GaLore only; forwarded as `quant_group_size` (quantization group size).
        cos_threshold (`float`): Q-GaLore only; forwarded as `cos_threshold` (presumably the cosine-similarity
            threshold for lazy projection updates -- see `q_galore_torch`).
        gamma_proj (`int`): Q-GaLore only; forwarded as `gamma_proj` (presumably the factor by which the
            projection update gap grows -- see `q_galore_torch`).
        queue_size (`int`): Q-GaLore only; forwarded as `queue_size` (presumably the number of past
            similarities tracked -- see `q_galore_torch`).
    """
    rank: int = 128
    # `None` is a valid (unresolved) value, so it is part of the declared type.
    target_modules: Union[str, List[str], None] = None
    update_proj_gap: int = 50
    galore_scale: float = 1.0
    proj_type: str = 'std'
    optim_per_parameter: bool = False
    quantize: bool = False
    proj_quant: bool = False
    proj_bits: int = 4
    proj_group_size: int = 256
    cos_threshold: float = 0.4
    gamma_proj: int = 2
    queue_size: int = 5
|
|
|
|
class GaloreOptimizerWrapper(Optimizer):
    """Present a dict of independent optimizers as a single `torch.optim.Optimizer`.

    Used for per-parameter GaLore optimization, where each trainable parameter owns its own optimizer.
    `zero_grad` and `step` are forwarded to every wrapped optimizer.

    NOTE(review): `state_dict` / `load_state_dict` are not forwarded, so they only cover the placeholder
    parameter registered in `__init__` -- confirm this is intended for checkpointing.

    Args:
        optimizers: Mapping whose values are the optimizers to drive (the caller keys them by parameter).
    """

    def __init__(self, optimizers: Dict[Any, Optimizer]):
        self.optimizers = optimizers
        # `Optimizer.__init__` requires at least one parameter; register a throw-away tensor because the
        # real parameters are owned by the wrapped optimizers.
        super().__init__([torch.tensor([1., 2., 3.])], {'lr': 1.})

    def zero_grad(self, *args, **kwargs) -> None:
        """Forward `zero_grad` (with the same arguments) to every wrapped optimizer."""
        for optim in self.optimizers.values():
            optim.zero_grad(*args, **kwargs)

    def step(self, *args, **kwargs) -> Any:
        """Forward `step` (with the same arguments) to every wrapped optimizer.

        Note that a closure, if passed, is evaluated once per wrapped optimizer.

        Returns:
            The last non-`None` value returned by a wrapped optimizer's `step` (the closure's loss when a
            closure is given), otherwise `None`.
        """
        loss = None
        for optim in self.optimizers.values():
            result = optim.step(*args, **kwargs)
            if result is not None:
                loss = result
        return loss
|
|
|
|
class GaloreSchedulerWrapper(LRScheduler):
    """Present a dict of independent LR schedulers as a single scheduler.

    Companion of `GaloreOptimizerWrapper`: `step` advances every wrapped scheduler, and `get_last_lr()`
    reports the learning rates of the last scheduler in the dict's iteration order.

    NOTE(review): the inherited `state_dict` serializes `self.__dict__`, i.e. the wrapped scheduler objects
    themselves -- confirm this is acceptable for checkpointing.

    Args:
        lr_schedulers: Mapping whose values are the schedulers to drive (the caller keys them by parameter).
    """

    def __init__(self, lr_schedulers: Dict[Any, LRScheduler]):
        # The base-class constructor is not called because it requires a single optimizer.
        self.lr_schedulers = lr_schedulers
        # Seed `_last_lr` so `get_last_lr()` works before the first `step()`; previously it raised
        # AttributeError until a step had happened.
        self._refresh_last_lr()

    def _refresh_last_lr(self) -> None:
        """Copy the last wrapped scheduler's learning rates into `_last_lr` (no-op when there are none)."""
        if self.lr_schedulers:
            last_scheduler = list(self.lr_schedulers.values())[-1]
            self._last_lr = last_scheduler.get_last_lr()

    def step(self, *args, **kwargs) -> None:
        """Advance every wrapped scheduler (with the same arguments), then refresh `_last_lr`."""
        for lr_scheduler in self.lr_schedulers.values():
            lr_scheduler.step(*args, **kwargs)
        self._refresh_last_lr()
|
|
|
|
def _collect_galore_params(model: nn.Module, config: GaLoreConfig) -> List[nn.Parameter]:
    """Return the unique trainable weights of `nn.Linear` / `nn.Embedding` modules matched by `target_modules`.

    Raises:
        ValueError: If `config.target_modules` is `None` (it must be resolved by the caller first).
    """
    target_modules = config.target_modules
    if target_modules is None:
        raise ValueError('GaLoreConfig.target_modules must be set before creating the GaLore optimizer, '
                         'e.g. to the names of all attention and MLP linear layers.')
    if isinstance(target_modules, str):
        # A bare string would otherwise be matched character by character.
        target_modules = [target_modules]

    galore_params = []
    seen_ids = set()
    for module_name, module in model.named_modules():
        if not isinstance(module, (nn.Linear, nn.Embedding)):
            continue
        if not any(target_key in module_name for target_key in target_modules):
            continue
        weight = module.weight
        # Skip frozen weights, and weights already collected through another module (tied weights);
        # a duplicated parameter in one param group would be updated twice per step.
        if not weight.requires_grad or id(weight) in seen_ids:
            continue
        seen_ids.add(id(weight))
        logger.info(f'Enable GaLore for weights in module: {module_name}')
        galore_params.append(weight)
    return galore_params


def _build_galore_defaults(config: GaLoreConfig, defaults: Dict[str, Any]) -> Dict[str, Any]:
    """Return the param-group options for GaLore weights: GaLore settings, then `defaults`, then Q-GaLore keys."""
    galore_defaults = {
        'rank': config.rank,
        'update_proj_gap': config.update_proj_gap,
        'scale': config.galore_scale,
        'proj_type': config.proj_type,
        **defaults,
    }
    if config.quantize:
        galore_defaults.update({
            'quant': config.proj_quant,
            'quant_n_bit': config.proj_bits,
            'quant_group_size': config.proj_group_size,
            'cos_threshold': config.cos_threshold,
            'gamma_proj': config.gamma_proj,
            'queue_size': config.queue_size,
        })
    return galore_defaults


def create_optimizer_and_scheduler(model: nn.Module, args: TrainingArguments, config: GaLoreConfig, max_steps,
                                   **defaults):
    """Build the GaLore optimizer and LR scheduler for `model`.

    Weights of `nn.Linear` / `nn.Embedding` modules whose name contains one of `config.target_modules` (and
    that require grad) are optimized with the GaLore options; all other trainable parameters are optimized
    with `defaults` only.

    Args:
        model: The model to optimize.
        args: Training arguments; `optim` selects the optimizer class (via `get_optimizer`), while
            `lr_scheduler_type`, `warmup_steps` and `lr_scheduler_kwargs` configure the scheduler.
        config: The GaLore configuration.
        max_steps: Total number of training steps.
        **defaults: Extra options merged into the param groups; the grouped path reads
            `defaults['weight_decay']`.

    Returns:
        An `(optimizer, lr_scheduler)` tuple. When `config.optim_per_parameter` is set and `config.quantize`
        is not, these are a `GaloreOptimizerWrapper` / `GaloreSchedulerWrapper` holding one optimizer and
        one scheduler per trainable parameter.

    Raises:
        ValueError: If `config.target_modules` is `None`.
    """
    galore_params = _collect_galore_params(model, config)
    # A set keeps the per-parameter membership tests below O(1).
    id_galore_params = {id(p) for p in galore_params}
    galore_defaults = _build_galore_defaults(config, defaults)
    optim_cls, optim_kwargs = get_optimizer(args, config)

    if config.optim_per_parameter and not config.quantize:
        # Per-parameter mode doubles the projection gap, training steps and warm-up steps together.
        # NOTE(review): presumably each training step drives these optimizers/schedulers twice -- confirm.
        galore_defaults['update_proj_gap'] *= 2
        optimizer_dict = {}
        scheduler_dict = {}
        for p in model.parameters():
            if not p.requires_grad:
                continue
            group_options = galore_defaults if id(p) in id_galore_params else defaults
            optimizer_dict[p] = optim_cls([{'params': [p], **group_options}], **optim_kwargs)
            scheduler_dict[p] = get_scheduler(
                optimizer=optimizer_dict[p],
                name=args.lr_scheduler_type,
                num_training_steps=max_steps * 2,
                num_warmup_steps=args.warmup_steps * 2,
                scheduler_specific_kwargs=args.lr_scheduler_kwargs,
            )
        return GaloreOptimizerWrapper(optimizer_dict), GaloreSchedulerWrapper(scheduler_dict)

    # Called unbound with the Trainer class standing in for `self`.
    # NOTE(review): relies on `get_decay_parameter_names` not touching instance state.
    decay_parameters = set(Trainer.get_decay_parameter_names(Trainer, model))
    decay_params = []
    no_decay_params = []
    for name, p in model.named_parameters():
        if not p.requires_grad or id(p) in id_galore_params:
            continue
        if name in decay_parameters:
            decay_params.append(p)
        else:
            no_decay_params.append(p)

    param_groups = [
        {'params': galore_params, **galore_defaults},
        {'params': decay_params, 'weight_decay': defaults['weight_decay']},
        {'params': no_decay_params, 'weight_decay': 0.0},
    ]
    optim = optim_cls(param_groups, **optim_kwargs)
    scheduler = get_scheduler(
        optimizer=optim,
        name=args.lr_scheduler_type,
        num_training_steps=max_steps,
        num_warmup_steps=args.warmup_steps,
        scheduler_specific_kwargs=args.lr_scheduler_kwargs,
    )
    return optim, scheduler
|
|
|
|
def get_optimizer(args: TrainingArguments, config: GaLoreConfig) -> Tuple[Any, Any]:
    """Resolve the GaLore optimizer class and its constructor kwargs from the training arguments.

    Supported `args.optim` values: `adafactor`, `adamw_hf`, `adamw_torch`, and any value containing both
    `adamw` and `8bit` (bitsandbytes). `args.optim_args` is not consumed by the GaLore optimizers.

    Args:
        args: Training arguments; reads `optim`, `learning_rate`, `adam_beta1`, `adam_beta2`, `adam_epsilon`.
        config: The GaLore configuration; `quantize` switches `adamw_hf` / `adamw_torch` to Q-GaLore.

    Returns:
        `(optimizer_cls, optimizer_kwargs)`.

    Raises:
        ImportError: If Q-GaLore is requested but `q_galore_torch` is not installed.
        ValueError: If `args.optim` is unsupported, or an 8-bit AdamW is requested but cannot be imported.
    """
    optimizer_kwargs = {'lr': args.learning_rate}
    adam_kwargs = {
        'betas': (args.adam_beta1, args.adam_beta2),
        'eps': args.adam_epsilon,
    }

    if args.optim == 'adafactor':
        from .adafactor import GaLoreAdafactor
        optimizer_cls = GaLoreAdafactor
        optimizer_kwargs.update({'scale_parameter': False, 'relative_step': False})
    elif args.optim in ('adamw_hf', 'adamw_torch'):
        if config.quantize:
            # `import importlib` alone does not guarantee the `util` submodule is loaded.
            import importlib.util
            # An explicit raise (not `assert`) so the check survives `python -O`.
            if importlib.util.find_spec('q_galore_torch') is None:
                raise ImportError('Please install q-galore by `pip install q_galore_torch`')
            logger.info('If you encounter `absmax2` error, please downgrade your bitsandbytes to 0.40.0')
            from q_galore_torch import QGaLoreAdamW8bit as GaLoreAdamW
        else:
            from .adamw import GaLoreAdamW
        optimizer_cls = GaLoreAdamW
        optimizer_kwargs.update(adam_kwargs)
    elif 'adamw' in args.optim and '8bit' in args.optim:
        try:
            from .adamw8bit import GaLoreAdamW8bit
        except ImportError as err:
            raise ValueError('Trainer tried to instantiate bnb optimizer but bnb is not installed!') from err
        optimizer_cls = GaLoreAdamW8bit
        optimizer_kwargs.update(adam_kwargs)
        optimizer_kwargs.update({'optim_bits': 8, 'is_paged': 'paged' in args.optim})
    else:
        raise ValueError(f'Galore not supported for optimizer type: {args.optim}')
    return optimizer_cls, optimizer_kwargs
|
|