# Copyright (c) Alibaba, Inc. and its affiliates.
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License (MIT). See LICENSE in the repo root for license information.
from dataclasses import asdict, dataclass, field
from functools import reduce
from typing import List, Optional

import peft
import torch
from packaging import version
from torch import nn
from transformers import Trainer

# `.lora_layers` re-exports the names this file relies on (get_logger,
# LoraConfig, LoraModel, LoraLayer, LoRALayer, Embedding, MergedLinear,
# lora_state_dict, mark_lora_as_trainable, get_quantization_config, ...).
from .lora_layers import *  # noqa
from .utils import SwiftAdapter, SwiftConfig, SwiftOutput, set_adapter

logger = get_logger()


@dataclass
class LoRAConfig(LoraConfig, SwiftConfig):
    """
    The configuration class for the LoRA module.

    Args:
        use_qa_lora(bool): Use
            QA-LoRA: [Quantization-Aware Low-Rank Adaptation of Large Language Models](https://arxiv.org/abs/2309.14717)
            instead of LoRA. QA-LoRA only supports AutoGPTQ quantized models.
            Deprecated, do not use this argument.
        lora_dtype(str): The dtype for all LoRA modules; supported values are `fp32`, `fp16`, `bf16`.
            The default value `None` means following the dtype of the original module's weight.
        lorap_lr_ratio(float): The lr_ratio argument for [LoRA+](https://arxiv.org/abs/2402.12354).
    """

    use_qa_lora: bool = field(
        default=False, metadata={'help': 'Use [qa-lora](https://github.com/yuhuixu1993/qa-lora) or not'})

    use_merged_linear: bool = field(default=False, metadata={'help': 'Use merged Linear'})

    enable_lora: List[bool] = field(
        default=None, metadata={'help': 'The modules that need to be turned on when using the merged linear layer'})

    lora_dtype: Optional[str] = field(
        default=None, metadata={'help': 'The lora dtype, default None means following the original layer\'s dtype'})

    lorap_lr_ratio: float = field(default=2.0**4, metadata={'help': 'The lr ratio of lora_B in lora+'})

    lorap_emb_lr: float = field(default=1e-6, metadata={'help': 'The lr for embedding in lora+'})

    def __post_init__(self):
        super().__post_init__()
        from .mapping import SwiftTuners
        self.swift_type = SwiftTuners.LORA

    def can_be_saved_to_peft(self) -> bool:
        if self.use_qa_lora or self.use_merged_linear:
            logger.warning('QA-LoRA and MergedLinear cannot be saved to the peft format')
            return False
        return True

    def to_peft_config(self) -> LoraConfig:
        _dict = asdict(self)
        # Drop the SWIFT-only fields that peft's LoraConfig does not accept.
        _dict.pop('use_qa_lora', None)
        _dict.pop('enable_lora', None)
        _dict.pop('lora_dtype', None)
        _dict.pop('use_merged_linear', None)
        _dict['peft_type'] = _dict['swift_type']
        _dict.pop('swift_type', None)
        _dict.pop('lr_ratio', None)
        _dict.pop('model_key_mapping', None)
        return LoraConfig(**_dict)
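
    # A minimal sketch of exporting to the peft format, assuming the config is
    # peft-compatible (no QA-LoRA and no MergedLinear); the directory name is
    # an assumption for the example:
    #
    #   if config.can_be_saved_to_peft():
    #       peft_config = config.to_peft_config()
    #       peft_config.save_pretrained('./adapter_dir')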

    def save_pretrained(self, save_directory: str, **kwargs) -> None:
        # Skip past peft.LoraConfig in the MRO so the SWIFT implementation of
        # save_pretrained serializes the SWIFT-specific fields as well.
        super(peft.LoraConfig, self).save_pretrained(save_directory, **kwargs)
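

# Illustrative usage, as a minimal sketch: build a config and hand it to
# LoRA.prepare_model below. `target_modules` follows peft's LoraConfig; the
# module names and hyperparameter values are assumptions for the example.
#
#   config = LoRAConfig(
#       r=8,
#       lora_alpha=32,
#       target_modules=['q_proj', 'v_proj'],
#       lora_dtype='fp32',
#       lorap_lr_ratio=16.0)
#   output = LoRA.prepare_model(model, config, adapter_name='default')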


class LoRA(SwiftAdapter):

    @staticmethod
    def prepare_model(model: nn.Module, config: LoRAConfig, adapter_name: str):
        assert not config.use_qa_lora, 'Do not use qa-lora'
        if config.use_qa_lora:
            # Unreachable after the assert above; kept for the deprecated
            # QA-LoRA path, which reads the group size from the GPTQ config.
            auto_gptq_config = get_quantization_config(model, method='gptq')
            if auto_gptq_config:
                config.group_size = getattr(auto_gptq_config, 'group_size', None)
        LoraModel(model, config, adapter_name)

        def state_dict_callback(state_dict, adapter_name, cfg=None, **kwargs):
            # Keep only the LoRA parameters (plus biases, per the bias setting).
            return lora_state_dict(state_dict, adapter_name, cfg.bias if cfg else config.bias)

        def mark_trainable_callback(model, cfg=None):
            mark_lora_as_trainable(model, adapter_name, cfg.bias if cfg else config.bias)

        def optimizer_group_callback(model, **defaults):
            # LoRA+ (https://arxiv.org/abs/2402.12354): train lora_B (and the
            # embedding) with learning rates different from the base lr.
            # Disabled when the ratio is None.
            if config.lorap_lr_ratio is None:
                return None, None

            def get_module(name):
                # Resolve a dotted parameter name to a nearby parent module;
                # lora parameter names carry an extra adapter-name component,
                # so they step two levels up instead of one.
                parent_idx = 2 if 'lora' in name else 1
                module_names = name.split(sep='.')[:-parent_idx]
                module = reduce(getattr, module_names, model)
                return module
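
            # A sketch of the resolution with assumed peft-style names:
            #   'h.0.attn.lora_A.default.weight' -> `h[0].attn.lora_A`
            #   'wte.lora_embedding_A.default'   -> `wte` (an Embedding lora layer)
            #   'h.0.attn.base_layer.weight'     -> `h[0].attn.base_layer`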

            all_params = set()
            param_groups = {
                'groupA': {},
                'groupB': {},
                'groupB_no_decay': {},
                'embedding': {},
            }

            # Names of parameters that should receive weight decay
            # (transformers excludes biases and layer norms).
            decay_parameters = Trainer.get_decay_parameter_names(None, model)
            for name, param in model.named_parameters():
                if not param.requires_grad:
                    continue

                module = get_module(name)
                if isinstance(module, Embedding):
                    param_groups['embedding'][name] = param
                elif 'lora_B' in name or param.ndim == 1:
                    # lora_B weights and 1-D parameters get the scaled LoRA+
                    # lr; those outside the decay set also skip weight decay.
                    if name in decay_parameters:
                        param_groups['groupB'][name] = param
                    else:
                        param_groups['groupB_no_decay'][name] = param
                else:
                    param_groups['groupA'][name] = param
                all_params.add(name)

            lr = defaults['lr']
            weight_decay = defaults.get('weight_decay', 0.0)

            param_groups = [
                {
                    'params': list(param_groups['groupA'].values()),
                    'weight_decay': weight_decay,
                    'lr': lr,
                },
                {
                    'params': list(param_groups['embedding'].values()),
                    'weight_decay': weight_decay,
                    'lr': config.lorap_emb_lr,
                },
                {
                    'params': list(param_groups['groupB'].values()),
                    'weight_decay': weight_decay,
                    'lr': lr * config.lorap_lr_ratio,
                },
                {
                    'params': list(param_groups['groupB_no_decay'].values()),
                    'weight_decay': 0.0,
                    'lr': lr * config.lorap_lr_ratio,
                },
            ]
            return all_params, param_groups

        return SwiftOutput(
            config=config,
            state_dict_callback=state_dict_callback,
            mark_trainable_callback=mark_trainable_callback,
            optimizer_group_callback=optimizer_group_callback)
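
    # A minimal sketch of feeding the optimizer groups into a torch optimizer;
    # `model`, the base lr, and the weight decay are assumptions:
    #
    #   all_params, groups = output.optimizer_group_callback(
    #       model, lr=1e-4, weight_decay=0.01)
    #   optimizer = torch.optim.AdamW(groups)
    #
    # Parameters named in `all_params` are covered by `groups`; any remaining
    # trainable parameters would go into a default group.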

    @staticmethod
    def activate_adapter(module: torch.nn.Module, adapter_name: str, activate: bool, offload: str = None):
        set_adapter(module, adapter_name, activate, offload)
        for sub_module in module.modules():
            if isinstance(sub_module, (LoraLayer, LoRALayer)):
                sub_module.set_activation(adapter_name, activate)
                if hasattr(sub_module, 'save_memory'):
                    sub_module.save_memory(adapter_name, activate, offload)

    @staticmethod
    def unpatch_lora(model, config: LoRAConfig, adapter_name: str):
        """Unpatch lora modules and merge the weights into the original modules.

        LoRA constructs an additional layer with low-rank decomposition matrices of the weights in the network.
        'LoRA: Low-Rank Adaptation of Large Language Models' by Hu et al. (2021)
        See https://arxiv.org/abs/2106.09685

        Args:
            model(`torch.nn.Module`): The model called with the `tune` function.
            config(`LoRAConfig`): The `LoRAConfig` to use. Deprecated.
            adapter_name(`str`): The adapter name.
        """
        if not config.use_merged_linear:
            if version.parse(peft.__version__) < version.parse('0.6.3'):
                # peft older than 0.6.3 cannot merge a single adapter by name.
                logger.info('All adapters will be merged.')
                LoraModel(model, None, '').merge_and_unload()
            else:
                LoraModel(model, None, '').merge_and_unload(adapter_names=[adapter_name])
        else:
            # MergedLinear is a SWIFT layer unknown to peft: merge each one in
            # place and put its base layer back onto the parent module.
            for name, sub_module in model.named_modules():
                if isinstance(sub_module, MergedLinear):
                    sub_module.merge()
                    parent = model.get_submodule('.'.join(name.split('.')[:-1]))
                    target_name = name.split('.')[-1]
                    setattr(parent, target_name, sub_module.base_layer)
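
    # A minimal sketch of merging an adapter back into the base weights after
    # training, assuming a transformers model and a 'default' adapter:
    #
    #   LoRA.unpatch_lora(model, config, adapter_name='default')
    #   model.save_pretrained('./merged_model')  # now a plain base model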