# Copyright (c) Alibaba, Inc. and its affiliates.
import json
from dataclasses import dataclass
from types import MethodType
from typing import List, Literal, Optional

import torch
from torch import nn

from swift.utils import get_logger, patch_getattr

from .utils import SwiftAdapter, SwiftConfig, SwiftOutput

logger = get_logger()


@dataclass
class ReftConfig(SwiftConfig):
    """
    Train a model with ReFT (Representation Finetuning).

    Paper: https://arxiv.org/pdf/2404.03592

    Args:
        model_type (`Optional[str]`): The model_type used to look the down_proj/layers keys up.
        layer_key (`Optional[str]`): Manually specify the layers key, for example `language_model.layers`.
        layers (`Optional[List[int]]`): The indices of the layers to inject into. Defaults to all layers.
        r (`int`): The rank of ReFT.
        intervention_type (`Literal['NoreftIntervention', 'LoreftIntervention',
            'ConsreftIntervention', 'LobireftIntervention',
            'DireftIntervention', 'NodireftIntervention']`): The intervention type,
            defaults to `LoreftIntervention`.
        args (`Optional[str]`): Extra reft args in JSON string format.
    """
    model_type: Optional[str] = None
    layer_key: Optional[str] = None
    layers: Optional[List[int]] = None
    r: int = 4
    intervention_type: Literal['NoreftIntervention', 'LoreftIntervention', 'ConsreftIntervention',
                               'LobireftIntervention', 'DireftIntervention',
                               'NodireftIntervention'] = 'LoreftIntervention'
    args: Optional[str] = None

    def __post_init__(self):
        from .mapping import SwiftTuners
        self.swift_type = SwiftTuners.REFT
        # Parse the extra intervention kwargs from their JSON string form.
        self.args = json.loads(self.args) if self.args else {}
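
# A minimal usage sketch (hypothetical layer indices; assumes a decoder-style
# model whose transformer blocks are reachable via `model_type` or `layer_key`):
#
#   config = ReftConfig(layers=[4, 8, 12], r=4, intervention_type='LoreftIntervention')
#   output = Reft.prepare_model(model, config, adapter_name='default')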


class Reft(SwiftAdapter):

    @staticmethod
    def prepare_model(model: nn.Module, config: ReftConfig, adapter_name: str):
        from swift.utils.import_utils import is_pyreft_available
        if not is_pyreft_available():
            raise ImportError('Please install pyreft before using ReFT: `pip install pyreft`')

        import pyreft
        from pyreft import (
            ReftModel,
            NoreftIntervention,
            LoreftIntervention,
            ConsreftIntervention,
            LobireftIntervention,
            DireftIntervention,
            NodireftIntervention,
        )
        from pyreft.interventions import LowRankRotateLayer

        intervention_mapping = {
            'NoreftIntervention': NoreftIntervention,
            'LoreftIntervention': LoreftIntervention,
            'ConsreftIntervention': ConsreftIntervention,
            'LobireftIntervention': LobireftIntervention,
            'DireftIntervention': DireftIntervention,
            'NodireftIntervention': NodireftIntervention,
        }

        # Let attribute access on the wrapped ReftModel fall through to the base model.
        patch_getattr(ReftModel, 'model')

        def forward(self, x):
            # Move the rotation layer to the input's device before computing,
            # so multi-device setups keep working.
            self.to(x.device)
            return self.forward_origin(x)

        def forward2(self, base, source=None, subspaces=None):
            self.to(base.device)
            return self.forward_origin(base, source, subspaces)

        # Patch each pyreft layer/intervention class once so its weights
        # follow the device of the incoming activations.
        if not hasattr(LowRankRotateLayer, 'forward_origin'):
            LowRankRotateLayer.forward_origin = LowRankRotateLayer.forward
            LowRankRotateLayer.forward = forward
            for intervention_cls in intervention_mapping.values():
                intervention_cls.forward_origin = intervention_cls.forward
                intervention_cls.forward = forward2

        # Resolve the key of the transformer block list, either specified
        # manually via `layer_key` or looked up from the model_type mapping.
        module_list_key = config.layer_key
        if module_list_key is None:
            model_key_mapping = Reft.get_model_key_mapping(config.model_type, config)
            module_list_key = model_key_mapping.module_list
        logger.info(f'Applying Reft to module: {module_list_key}')
        module_list: nn.ModuleList = model.get_submodule(module_list_key)

        representations = []
        for idx, layer in enumerate(module_list):
            if config.layers and idx not in config.layers:
                continue
            # One intervention per selected layer, hooked on the block output.
            intervention_config = {
                'layer': idx,
                'component': module_list_key + f'[{idx}].output',
                'low_rank_dimension': config.r,
                'intervention': intervention_mapping[config.intervention_type](
                    embed_dim=model.config.hidden_size, low_rank_dimension=config.r, **config.args),
            }
            representations.append(intervention_config)
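
        # For reference, with the defaults each entry resolves to something
        # like the following (a sketch; the component path depends on the
        # actual model structure):
        #   {'layer': 4, 'component': 'model.layers[4].output',
        #    'low_rank_dimension': 4, 'intervention': LoreftIntervention(...)}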

        reft_config = pyreft.ReftConfig(representations=representations)
        reft_model = pyreft.get_reft_model(model, reft_config, set_device=False)
        # Keep the pyreft config reachable under `reft_config`, and expose the
        # base model's config as `config` so downstream code sees a regular
        # transformers config object.
        reft_model.reft_config = reft_model.config
        reft_model.config = reft_model.model.config

        def _pre_forward_hook(module, args, kwargs):
            # Already in pyreft's calling convention, pass through untouched.
            if 'base' in kwargs:
                return args, kwargs
            if 'input_ids' not in kwargs:
                raise ValueError('Input does not contain `input_ids`, maybe the model does not support ReFT.')
            # Run an intervened forward pass: translate HF-style kwargs into
            # pyreft's base/unit_locations/subspaces convention.
            unit_locations = None
            if 'intervention_locations' in kwargs:
                if kwargs['intervention_locations'].dim() == 3:
                    unit_locations = {
                        'sources->base': (None, kwargs['intervention_locations'].permute(1, 0, 2).tolist())
                    }
                else:
                    # This is a dummy value for the LoRA-only baseline.
                    unit_locations = {'sources->base': (None, 0)}
            kwargs = {
                'base': {
                    'input_ids': kwargs['input_ids'],
                    'attention_mask': kwargs['attention_mask']
                },
                'unit_locations': unit_locations,
                'labels': kwargs['labels'],
                'subspaces': kwargs['subspaces'].permute(1, 0, 2).tolist() if 'subspaces' in kwargs else None
            }
            return args, kwargs

        def _post_forward_hook(module, args, kwargs, outputs):
            # pyreft returns a (base_outputs, intervened_outputs) pair; keep
            # only the intervened outputs to match the usual model interface.
            return outputs[1]

        def _generate(self, **kwargs):
            # Mirror the forward pre-hook: translate HF-style kwargs into
            # pyreft's calling convention before running generation.
            unit_locations = None
            # Pop `intervention_locations` so it is not forwarded to the
            # underlying `generate`, which would reject the unknown kwarg.
            intervention_locations = kwargs.pop('intervention_locations', None)
            if intervention_locations is not None:
                if intervention_locations.dim() == 3:
                    unit_locations = {
                        'sources->base': (None, intervention_locations.permute(1, 0, 2).tolist())
                    }
                else:
                    # This is a dummy value for the LoRA-only baseline.
                    unit_locations = {'sources->base': (None, 0)}
            _kwargs = {
                'base': {
                    'input_ids': kwargs.pop('input_ids'),
                    'attention_mask': kwargs.pop('attention_mask')
                },
                'unit_locations': unit_locations,
                'subspaces': kwargs.pop('subspaces').permute(1, 0, 2).tolist() if 'subspaces' in kwargs else None
            }
            # Any remaining kwargs (e.g. generation settings) pass through unchanged.
            _kwargs = {**_kwargs, **kwargs}
            return self.generate_origin(**_kwargs)[1]

        # Route both generation and the forward pass through the intervened
        # code paths defined above.
        reft_model.generate_origin = reft_model.generate
        reft_model.generate = MethodType(_generate, reft_model)
        reft_model.register_forward_pre_hook(_pre_forward_hook, with_kwargs=True)
        reft_model.register_forward_hook(_post_forward_hook, with_kwargs=True)
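
        # Call-shape sketch after patching (hypothetical tensors;
        # `intervention_locations` follows pyreft's
        # (batch, num_interventions, num_positions) layout before permutation):
        #   outputs = reft_model(input_ids=input_ids,
        #                        attention_mask=attention_mask,
        #                        labels=labels,
        #                        intervention_locations=intervention_locations)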

        def save_callback(swift_model, model_dir, adapter_name):
            # Persist only the intervention weights, not the base model.
            reft_model.save_intervention(save_directory=model_dir, include_model=False)

        def mark_trainable_callback(model):
            # No-op: pyreft already marks the intervention parameters trainable.
            return

        def load_callback(swift_model, model_dir, adapter_name):
            reft_model.load_intervention(model_dir, include_model=False)

        return SwiftOutput(
            model=reft_model,
            config=config,
            mark_trainable_callback=mark_trainable_callback,
            save_callback=save_callback,
            load_callback=load_callback)

    @staticmethod
    def has_additional_modules():
        return True

    @staticmethod
    def activate_adapter(module: torch.nn.Module, adapter_name: str, activate: bool, offload: Optional[str] = None):
        assert activate, 'ReFT does not support deactivation'