File size: 8,839 Bytes
7feac49
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
# Copyright (c) Alibaba, Inc. and its affiliates.
from dataclasses import dataclass
from types import MethodType
from typing import List, Literal, Optional

import json
import torch
from torch import nn

from swift.utils import get_logger, patch_getattr
from .utils import SwiftAdapter, SwiftConfig, SwiftOutput

logger = get_logger()


@dataclass
class ReftConfig(SwiftConfig):
    """Configuration for training a model with ReFT (Representation Finetuning).

    Paper: https://arxiv.org/pdf/2404.03592

    Args:
        model_type(`Optional[str]`): The model_type used to look up the
            down_proj/layers keys when `layer_key` is not given.
        layer_key(`Optional[str]`): Manually specify the layer key, for example
            `language_model.layers`.
        layers (`Optional[List[int]]`): The layer indices to inject into; all
            layers when omitted.
        r(`int`): The rank of Reft.
        intervention_type (`Literal['NoreftIntervention', 'LoreftIntervention',
                        'ConsreftIntervention', 'LobireftIntervention',
                        'DireftIntervention', 'NodireftIntervention']`): The
                        intervention type, default LoreftIntervention.
        args (`Optional[str]`): Extra intervention kwargs as a JSON string;
            decoded to a dict in `__post_init__`.
    """

    model_type: Optional[str] = None
    layer_key: Optional[str] = None
    layers: Optional[List[int]] = None
    r: int = 4
    intervention_type: Literal['NoreftIntervention', 'LoreftIntervention', 'ConsreftIntervention',
                               'LobireftIntervention', 'DireftIntervention',
                               'NodireftIntervention'] = 'LoreftIntervention'
    args: Optional[str] = None

    def __post_init__(self):
        from .mapping import SwiftTuners
        self.swift_type = SwiftTuners.REFT
        # `args` arrives as a JSON string (or None/empty); normalize to a dict.
        self.args = json.loads(self.args) if self.args else {}


class Reft(SwiftAdapter):
    """SwiftAdapter that wraps a model with ReFT interventions via the
    third-party ``pyreft`` library (paper: https://arxiv.org/pdf/2404.03592)."""

    @staticmethod
    def prepare_model(model: nn.Module, config: ReftConfig, adapter_name: str) -> SwiftOutput:
        """Wrap ``model`` in a ``pyreft.ReftModel`` configured by ``config``.

        The wrapped model keeps the usual transformers-style calling convention
        (``input_ids``/``attention_mask``/``labels``): forward hooks translate
        those kwargs into pyreft's ``base``/``unit_locations`` format, and
        ``generate`` is rebound to do the same translation.

        Args:
            model: The base model to wrap.
            config: ReFT settings (target layers, rank, intervention type).
            adapter_name: Unused here; part of the SwiftAdapter interface.

        Returns:
            A ``SwiftOutput`` holding the ReftModel plus save/load and
            mark-trainable callbacks.

        Raises:
            ImportError: If ``pyreft`` is not installed.
        """
        from swift.utils.import_utils import is_pyreft_available
        if not is_pyreft_available():
            raise ImportError('Please install pyreft before using ReFT: ' '`pip install pyreft`')

        import pyreft
        from pyreft import ReftModel
        from pyreft.interventions import LowRankRotateLayer
        from pyreft import (
            NoreftIntervention,
            LoreftIntervention,
            ConsreftIntervention,
            LobireftIntervention,
            DireftIntervention,
            NodireftIntervention,
        )

        # Resolve the intervention class from its string name in the config.
        intervention_mapping = {
            'NoreftIntervention': NoreftIntervention,
            'LoreftIntervention': LoreftIntervention,
            'ConsreftIntervention': ConsreftIntervention,
            'LobireftIntervention': LobireftIntervention,
            'DireftIntervention': DireftIntervention,
            'NodireftIntervention': NodireftIntervention,
        }

        # Let attribute access on ReftModel fall through to the wrapped model.
        patch_getattr(ReftModel, 'model')

        def forward(self, x):
            """Device-safe replacement for ``LowRankRotateLayer.forward``:
            move the layer to the input's device before delegating."""
            self.to(x.device)
            return self.forward_origin(x)

        def forward2(self, base, source=None, subspaces=None):
            """Device-safe replacement for the intervention classes' forward."""
            self.to(base.device)
            return self.forward_origin(base, source, subspaces)

        # Patch the pyreft classes exactly once (guarded by the presence of
        # `forward_origin`): stash each original forward and substitute the
        # device-moving wrappers defined above.
        if not hasattr(LowRankRotateLayer, 'forward_origin'):
            LowRankRotateLayer.forward_origin = LowRankRotateLayer.forward
            LowRankRotateLayer.forward = forward
            NoreftIntervention.forward_origin = NoreftIntervention.forward
            NoreftIntervention.forward = forward2
            LoreftIntervention.forward_origin = LoreftIntervention.forward
            LoreftIntervention.forward = forward2
            ConsreftIntervention.forward_origin = ConsreftIntervention.forward
            ConsreftIntervention.forward = forward2
            LobireftIntervention.forward_origin = LobireftIntervention.forward
            LobireftIntervention.forward = forward2
            DireftIntervention.forward_origin = DireftIntervention.forward
            DireftIntervention.forward = forward2
            NodireftIntervention.forward_origin = NodireftIntervention.forward
            NodireftIntervention.forward = forward2

        # Resolve which nn.ModuleList of transformer layers to intervene on:
        # an explicit `layer_key`, or the model-type mapping shipped with swift.
        module_list_key = config.layer_key
        if module_list_key is None:
            model_key_mapping = Reft.get_model_key_mapping(config.model_type, config)
            module_list_key = model_key_mapping.module_list
        logger.info(f'Applying Reft to module: {module_list_key}')
        module_list: nn.ModuleList = model.get_submodule(module_list_key)
        representations = []
        for idx, layer in enumerate(module_list):
            # When `config.layers` is given, inject only into the listed layers.
            if config.layers and idx not in config.layers:
                continue
            # One intervention per selected layer, attached to that layer's
            # output, with rank `config.r` and any extra kwargs from the config.
            intervention_config = {
                'layer':
                idx,
                'component':
                module_list_key + f'[{idx}].output',
                'low_rank_dimension':
                config.r,
                'intervention':
                intervention_mapping[config.intervention_type](
                    embed_dim=model.config.hidden_size, low_rank_dimension=config.r, **config.args)
            }
            representations.append(intervention_config)

        reft_config = pyreft.ReftConfig(representations=representations)
        reft_model = pyreft.get_reft_model(model, reft_config, set_device=False)
        # Keep pyreft's own config reachable as `reft_config` while exposing
        # the base model's config as `config`, which downstream code expects.
        reft_model.reft_config = reft_model.config
        reft_model.config = reft_model.model.config

        def _pre_forward_hook(module, args, kwargs):
            """Translate transformers-style kwargs into pyreft's
            ``base``/``unit_locations``/``labels``/``subspaces`` convention."""
            if 'base' in kwargs:
                # Already in pyreft format; pass through untouched.
                return args, kwargs

            if 'input_ids' not in kwargs:
                raise ValueError('Input does not contain `input_ids`, maybe the model does not support ReFT.')
            # run intervened forward pass
            unit_locations = None
            if 'intervention_locations' in kwargs:
                if kwargs['intervention_locations'].dim() == 3:
                    # (batch, n_interventions, n_positions) -> pyreft expects
                    # the intervention axis first, as nested lists.
                    unit_locations = {
                        'sources->base': (None, kwargs['intervention_locations'].permute(1, 0, 2).tolist())
                    }
                else:
                    # this is dummy for lora only baseline
                    unit_locations = {'sources->base': (None, 0)}
            kwargs = {
                'base': {
                    'input_ids': kwargs['input_ids'],
                    'attention_mask': kwargs['attention_mask']
                },
                'unit_locations': unit_locations,
                'labels': kwargs['labels'],
                'subspaces': kwargs['subspaces'].permute(1, 0, 2).tolist() if 'subspaces' in kwargs else None
            }
            return args, kwargs

        def _post_forward_hook(module, args, kwargs, outputs):
            # pyreft's forward returns a pair; index 1 presumably holds the
            # base model outputs — TODO(review): confirm against pyreft.
            return outputs[1]

        def _generate(self, **kwargs):
            """``generate`` wrapper performing the same kwarg translation as
            ``_pre_forward_hook`` before delegating to pyreft's generate."""
            # run intervened forward pass
            unit_locations = None
            if 'intervention_locations' in kwargs:
                if kwargs['intervention_locations'].dim() == 3:
                    unit_locations = {
                        'sources->base': (None, kwargs['intervention_locations'].permute(1, 0, 2).tolist())
                    }
                else:
                    # this is dummy for lora only baseline
                    unit_locations = {'sources->base': (None, 0)}

            _kwargs = {
                'base': {
                    'input_ids': kwargs.pop('input_ids'),
                    'attention_mask': kwargs.pop('attention_mask')
                },
                'unit_locations': unit_locations,
                'subspaces': kwargs.pop('subspaces').permute(1, 0, 2).tolist() if 'subspaces' in kwargs else None
            }
            # Forward any remaining generation kwargs unchanged.
            _kwargs = {**_kwargs, **kwargs}
            return self.generate_origin(**_kwargs)[1]

        # Rebind generate and install the kwarg-translation hooks.
        reft_model.generate_origin = reft_model.generate
        reft_model.generate = MethodType(_generate, reft_model)
        reft_model.register_forward_pre_hook(_pre_forward_hook, with_kwargs=True)
        reft_model.register_forward_hook(_post_forward_hook, with_kwargs=True)

        def save_callback(swift_model, model_dir, adapter_name):
            """Persist only the intervention weights, not the base model."""
            reft_model.save_intervention(save_directory=model_dir, include_model=False)

        def mark_trainable_callback(model):
            # Intentionally a no-op: parameter trainability is presumably
            # handled by pyreft when building the ReftModel — TODO confirm.
            return

        def load_callback(swift_model, model_dir, adapter_name):
            """Restore intervention weights written by ``save_callback``."""
            reft_model.load_intervention(model_dir, include_model=False)

        return SwiftOutput(
            model=reft_model,
            config=config,
            mark_trainable_callback=mark_trainable_callback,
            save_callback=save_callback,
            load_callback=load_callback)

    @staticmethod
    def has_additional_modules():
        """ReFT adds new (intervention) modules to the wrapped model."""
        return True

    @staticmethod
    def activate_adapter(module: torch.nn.Module, adapter_name: str, activate: bool, offload: str = None):
        """ReFT adapters are always active; deactivation is unsupported."""
        assert activate, 'ReFT does not support deactivate'