File size: 8,084 Bytes
7feac49
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
# Copyright (c) Alibaba, Inc. and its affiliates.
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License (MIT). See LICENSE in the repo root for license information.
from dataclasses import asdict, dataclass, field
from functools import reduce

import peft
import torch
from packaging import version
from transformers import Trainer

from .lora_layers import *  # noqa
from .utils import SwiftAdapter, SwiftConfig, SwiftOutput, set_adapter

# Module-level logger used for the warnings/info messages emitted below.
# NOTE(review): `get_logger` is not imported by name in this file; presumably it is
# provided by the star import of `.lora_layers` -- TODO confirm.
logger = get_logger()


@dataclass
class LoRAConfig(LoraConfig, SwiftConfig):
    """
    The configuration class for the LoRA module.

    Args:
        use_qa_lora(bool): Use
            QA-LoRA:[Quantization-Aware Low-Rank Adaptation of Large Language Models](https://arxiv.org/abs/2309.14717)
            instead of LoRA. QA-LoRA only supports AutoGPTQ quantized models.
            Deprecated, do not use this argument.
        use_merged_linear(bool): Use the merged linear layer. Default value is `False`.
        enable_lora(List[bool]): The modules need to be turned on when using the merged linear layer.
            Default value is `None`.
        lora_dtype(str): The dtype for all lora modules, supported values are `fp32`, `fp16`, `bf16`.
            Default value is `None`, which means follow the dtype of original module's weight.
        lorap_lr_ratio(float): The lr_ratio argument for [LoRA+](https://arxiv.org/abs/2402.12354),
            i.e. the lr ratio applied to lora_B. Default value is `2.0**4`.
        lorap_emb_lr(float): The lr for embedding in LoRA+. Default value is `1e-6`.
    """

    use_qa_lora: bool = field(
        default=False, metadata={'help': 'Use [qa-lora](https://github.com/yuhuixu1993/qa-lora) or not'})

    use_merged_linear: bool = field(default=False, metadata={'help': 'Use merged Linear'})

    # Defaults to None, so the annotation is Optional (consistent with `lora_dtype` below).
    enable_lora: Optional[List[bool]] = field(
        default=None, metadata={'help': 'The modules need to be turned on when using the merged linear layer'})

    lora_dtype: Optional[str] = field(
        default=None, metadata={'help': 'The lora dtype, default None means following the original layer\'s dtype'})

    lorap_lr_ratio: float = field(default=2.0**4, metadata={'help': 'The lr ratio of lora_B in lora+'})

    lorap_emb_lr: float = field(default=1e-6, metadata={'help': 'The lr for embedding in lora+'})

    def __post_init__(self):
        """Run the parent post-init, then tag this config with the LoRA swift type."""
        super().__post_init__()
        # Imported locally rather than at module level -- presumably to avoid a circular
        # import with `.mapping` (TODO confirm).
        from .mapping import SwiftTuners
        self.swift_type = SwiftTuners.LORA

    def can_be_saved_to_peft(self) -> bool:
        """Tell whether this config can be exported in peft format.

        Returns:
            `False` (after logging a warning) when `use_qa_lora` or `use_merged_linear` is set,
            `True` otherwise.
        """
        if self.use_qa_lora or self.use_merged_linear:
            # `Logger.warn` is a deprecated alias of `Logger.warning`.
            logger.warning('QA-LoRA and MergedLinear cannot be saved to peft format')
            return False
        return True

    def to_peft_config(self) -> LoraConfig:
        """Build a `LoraConfig` from this config.

        The swift-only keys are removed, `swift_type` is carried over as `peft_type`, and any
        remaining key that `LoraConfig` does not declare as an init field is dropped.

        Returns:
            A `LoraConfig` constructed from the remaining fields.
        """
        from dataclasses import fields

        _dict = asdict(self)
        _dict['peft_type'] = _dict['swift_type']
        for key in ('use_qa_lora', 'enable_lora', 'lora_dtype', 'use_merged_linear', 'swift_type', 'lr_ratio',
                    'model_key_mapping'):
            _dict.pop(key, None)
        # The LoRA+ fields of this class are named `lorap_lr_ratio` / `lorap_emb_lr`, not
        # `lr_ratio`, so the pops above do not remove them. Keep only what `LoraConfig`
        # accepts so that such extra fields cannot make the constructor raise a TypeError;
        # keys that `LoraConfig` does declare are passed through unchanged.
        accepted = {f.name for f in fields(LoraConfig) if f.init}
        return LoraConfig(**{key: value for key, value in _dict.items() if key in accepted})

    def save_pretrained(self, save_directory: str, **kwargs) -> None:
        """Save this config to `save_directory`.

        Dispatches to the `save_pretrained` that follows `peft.LoraConfig` in the MRO, i.e. the
        implementation of `peft.LoraConfig` itself is skipped.

        Args:
            save_directory(str): The target directory, forwarded as-is.
            **kwargs: Forwarded as-is to the underlying `save_pretrained`.
        """
        super(peft.LoraConfig, self).save_pretrained(save_directory, **kwargs)


class LoRA(SwiftAdapter):
    """The LoRA tuner: patches a model with `LoraModel` and provides the adapter callbacks."""

    @staticmethod
    def prepare_model(model: nn.Module, config: LoRAConfig, adapter_name: str):
        """Patch `model` with LoRA and build the callbacks for this adapter.

        Args:
            model(`torch.nn.Module`): The model passed to `LoraModel`. The constructed `LoraModel`
                is not kept, so the patching is presumably done in place -- TODO confirm.
            config(`LoRAConfig`): The `LoRAConfig` to use.
            adapter_name(`str`): The adapter name.

        Returns:
            A `SwiftOutput` holding `config` and the state-dict, mark-trainable and
            optimizer-group callbacks defined below.
        """
        assert not config.use_qa_lora, 'Do not use qa-lora'
        # NOTE: because of the assert above, this branch only runs when asserts are
        # stripped (e.g. `python -O`).
        if config.use_qa_lora:
            auto_gptq_config = get_quantization_config(model, method='gptq')
            if auto_gptq_config:
                config.group_size = getattr(auto_gptq_config, 'group_size', None)
        LoraModel(model, config, adapter_name)

        def state_dict_callback(state_dict, adapter_name, cfg=None, **kwargs):
            """Filter `state_dict` through `lora_state_dict`, using `cfg.bias` when `cfg` is given."""
            return lora_state_dict(state_dict, adapter_name, cfg.bias if cfg else config.bias)

        def mark_trainable_callback(model, cfg=None):
            """Call `mark_lora_as_trainable` for this adapter, using `cfg.bias` when `cfg` is given."""
            mark_lora_as_trainable(model, adapter_name, cfg.bias if cfg else config.bias)

        def optimizer_group_callback(model, **defaults):
            """Split the trainable parameters into optimizer groups with LoRA+ learning rates.

            Args:
                model: The model whose `named_parameters()` are grouped.
                **defaults: Optimizer defaults; `lr` is required, `weight_decay` falls back to 0.0.

            Returns:
                `(None, None)` when `config.lorap_lr_ratio` is None. Otherwise a tuple of the set
                of grouped parameter names and a list of four param-group dicts, in this order:
                groupA (`lr`), embedding (`config.lorap_emb_lr`), groupB
                (`lr * config.lorap_lr_ratio`) and groupB without weight decay.
            """
            if config.lorap_lr_ratio is None:
                return None, None

            def get_module(name):
                # Resolve the module owning parameter `name`: drop the last two name components
                # when 'lora' appears in the name, otherwise only the last one.
                parent_idx = 2 if 'lora' in name else 1
                module_names = name.split(sep='.')[:-parent_idx]
                module = reduce(getattr, module_names, model)
                return module

            all_params = set()
            param_groups = {
                'groupA': {},
                'groupB': {},
                'groupB_no_decay': {},
                'embedding': {},
            }

            decay_parameters = Trainer.get_decay_parameter_names(None, model)
            for name, param in model.named_parameters():
                if not param.requires_grad:
                    continue
                module = get_module(name)
                if isinstance(module, Embedding):
                    param_groups['embedding'][name] = param
                elif 'lora_B' in name or param.ndim == 1:
                    # lora_B weights and 1-d parameters get the scaled lr; weight decay is
                    # applied only to names listed in `decay_parameters`.
                    if name in decay_parameters:
                        param_groups['groupB'][name] = param
                    else:
                        param_groups['groupB_no_decay'][name] = param
                else:
                    param_groups['groupA'][name] = param
                all_params.add(name)

            lr = defaults['lr']
            weight_decay = defaults.get('weight_decay', 0.0)

            param_groups = [
                {
                    'params': list(param_groups['groupA'].values()),
                    'weight_decay': weight_decay,
                    'lr': lr,
                },
                {
                    'params': list(param_groups['embedding'].values()),
                    'weight_decay': weight_decay,
                    'lr': config.lorap_emb_lr,
                },
                {
                    'params': list(param_groups['groupB'].values()),
                    'weight_decay': weight_decay,
                    'lr': lr * config.lorap_lr_ratio,
                },
                {
                    'params': list(param_groups['groupB_no_decay'].values()),
                    'weight_decay': 0.0,
                    'lr': lr * config.lorap_lr_ratio,
                },
            ]
            return all_params, param_groups

        return SwiftOutput(
            config=config,
            state_dict_callback=state_dict_callback,
            mark_trainable_callback=mark_trainable_callback,
            optimizer_group_callback=optimizer_group_callback)

    @staticmethod
    def activate_adapter(module: torch.nn.Module, adapter_name: str, activate: bool, offload: str = None):
        """Activate or deactivate the adapter on `module` and its LoRA sub-modules.

        Args:
            module(`torch.nn.Module`): The module whose sub-modules are visited.
            adapter_name(`str`): The adapter name.
            activate(`bool`): The activation flag forwarded to `set_adapter` and `set_activation`.
            offload(`str`): Forwarded to `set_adapter` and, where available, `save_memory`.
        """
        set_adapter(module, adapter_name, activate, offload)
        for sub_module in module.modules():
            if isinstance(sub_module, (LoraLayer, LoRALayer)):
                sub_module.set_activation(adapter_name, activate)
                # Not every matching layer provides `save_memory`, hence the check.
                if hasattr(sub_module, 'save_memory'):
                    sub_module.save_memory(adapter_name, activate, offload)

    @staticmethod
    def unpatch_lora(model, config: LoRAConfig, adapter_name: str):
        """Unpatch lora modules and merge the weights to original modules.

        LoRA constructs an additional layer with low-rank decomposition matrices of the weights in the network.
        'LoRA: Low-Rank Adaptation of Large Language Models' by Hu et al.(2021)
        See https://arxiv.org/abs/2106.09685

        Args:
            model(`torch.nn.Module`): The model called with `tune` function.
            config(`LoRAConfig`): The `LoRAConfig` to use; only `use_merged_linear` is read here
                to choose the merge path.
            adapter_name(`str`): The adapter name
        """
        if not config.use_merged_linear:
            if version.parse(peft.__version__) < version.parse('0.6.3'):
                # On this peft version `merge_and_unload` is called without adapter names.
                logger.info('All adapters will be merged.')
                LoraModel(model, None, '').merge_and_unload()
            else:
                LoraModel(model, None, '').merge_and_unload(adapter_names=[adapter_name])
        else:
            # Snapshot the modules first: the loop replaces sub-modules with `setattr`, and the
            # module tree should not be modified while `named_modules()` is still iterating it.
            for name, sub_module in list(model.named_modules()):
                if isinstance(sub_module, MergedLinear):
                    sub_module.merge()
                    # 'a.b.c' -> parent 'a.b', attribute 'c'; a top-level name gives parent ''.
                    parent_name, _, target_name = name.rpartition('.')
                    parent = model.get_submodule(parent_name)
                    setattr(parent, target_name, sub_module.base_layer)