# Copyright (c) Alibaba, Inc. and its affiliates.
from dataclasses import dataclass
from types import MethodType
from typing import List, Literal, Optional
import json

import torch
from torch import nn

from swift.utils import get_logger, patch_getattr
from .utils import SwiftAdapter, SwiftConfig, SwiftOutput

logger = get_logger()


@dataclass
class ReftConfig(SwiftConfig):
    """
    Train a model with Reft.
    Paper: https://arxiv.org/pdf/2404.03592

    Args:
        model_type(`Optional[str]`): The model_type to find down_proj/layers.
        layer_key(`Optional[str]`): Manually specify the layer key, for example `language_model.layers`.
        layers (`Optional[List[int]]`): The layer number to inject.
        r(`int`): The rank of Reft.
        intervention_type (`Literal['NoreftIntervention', 'LoreftIntervention',
            'ConsreftIntervention', 'LobireftIntervention', 'DireftIntervention',
            'NodireftIntervention']`): The intervention type, default LoreftIntervention
        args (`Optional[str]`): Other reft_args in json-string format
    """

    model_type: Optional[str] = None
    layer_key: Optional[str] = None
    layers: Optional[List[int]] = None
    r: int = 4
    intervention_type: Literal['NoreftIntervention', 'LoreftIntervention', 'ConsreftIntervention',
                               'LobireftIntervention', 'DireftIntervention',
                               'NodireftIntervention'] = 'LoreftIntervention'
    args: Optional[str] = None

    def __post_init__(self):
        from .mapping import SwiftTuners
        self.swift_type = SwiftTuners.REFT
        # `args` arrives as a json string (e.g. from the CLI); decode it once
        # here so downstream code can splat it as extra intervention kwargs.
        if self.args:
            self.args = json.loads(self.args)
        else:
            self.args = {}


class Reft(SwiftAdapter):

    @staticmethod
    def prepare_model(model: nn.Module, config: ReftConfig, adapter_name: str):
        """Wrap `model` with pyreft interventions described by `config`.

        Returns a `SwiftOutput` whose `model` is a `pyreft.ReftModel` kept
        call-compatible with the original model: a forward pre-hook translates
        ordinary `input_ids`/`attention_mask`/`labels` kwargs into pyreft's
        intervened-forward signature, a forward hook unwraps pyreft's tuple
        return value, and `generate` is patched the same way.

        Raises:
            ImportError: if pyreft is not installed.
        """
        from swift.utils.import_utils import is_pyreft_available
        if not is_pyreft_available():
            raise ImportError('Please install pyreft before using ReFT: '
                              '`pip install pyreft`')

        import pyreft
        from pyreft import ReftModel
        from pyreft.interventions import LowRankRotateLayer
        from pyreft import (
            NoreftIntervention,
            LoreftIntervention,
            ConsreftIntervention,
            LobireftIntervention,
            DireftIntervention,
            NodireftIntervention,
        )

        intervention_mapping = {
            'NoreftIntervention': NoreftIntervention,
            'LoreftIntervention': LoreftIntervention,
            'ConsreftIntervention': ConsreftIntervention,
            'LobireftIntervention': LobireftIntervention,
            'DireftIntervention': DireftIntervention,
            'NodireftIntervention': NodireftIntervention,
        }

        # Let attribute lookups on the wrapper fall through to the inner
        # `model`, so `reft_model.<attr>` behaves like the original model.
        patch_getattr(ReftModel, 'model')

        # pyreft modules may sit on a different device than the activations
        # they intervene on; move them lazily to the input's device on use.
        def forward(self, x):
            self.to(x.device)
            return self.forward_origin(x)

        def forward2(self, base, source=None, subspaces=None):
            self.to(base.device)
            return self.forward_origin(base, source, subspaces)

        # Patch each class exactly once, even if `prepare_model` runs again;
        # re-patching would make `forward_origin` point at the patched forward
        # and recurse.
        if not hasattr(LowRankRotateLayer, 'forward_origin'):
            LowRankRotateLayer.forward_origin = LowRankRotateLayer.forward
            LowRankRotateLayer.forward = forward
            for intervention_cls in (NoreftIntervention, LoreftIntervention, ConsreftIntervention,
                                     LobireftIntervention, DireftIntervention, NodireftIntervention):
                intervention_cls.forward_origin = intervention_cls.forward
                intervention_cls.forward = forward2

        module_list_key = config.layer_key
        if module_list_key is None:
            # Derive the transformer-layers key (e.g. `model.layers`) from the
            # model_type mapping when not specified manually.
            model_key_mapping = Reft.get_model_key_mapping(config.model_type, config)
            module_list_key = model_key_mapping.module_list
        logger.info(f'Applying Reft to module: {module_list_key}')
        module_list: nn.ModuleList = model.get_submodule(module_list_key)

        representations = []
        for idx, layer in enumerate(module_list):
            # An empty/None `config.layers` means: intervene on every layer.
            if config.layers and idx not in config.layers:
                continue
            intervention_config = {
                'layer': idx,
                'component': module_list_key + f'[{idx}].output',
                'low_rank_dimension': config.r,
                'intervention':
                    intervention_mapping[config.intervention_type](
                        embed_dim=model.config.hidden_size, low_rank_dimension=config.r, **config.args)
            }
            representations.append(intervention_config)

        reft_config = pyreft.ReftConfig(representations=representations)
        reft_model = pyreft.get_reft_model(model, reft_config, set_device=False)
        # Keep the HF model config where downstream code expects it
        # (`reft_model.config`), stashing pyreft's own config aside.
        reft_model.reft_config = reft_model.config
        reft_model.config = reft_model.model.config

        def _get_unit_locations(kwargs):
            # Translate swift's `intervention_locations` tensor into pyreft's
            # `unit_locations` structure; shared by forward and generate paths.
            if 'intervention_locations' in kwargs:
                if kwargs['intervention_locations'].dim() == 3:
                    return {'sources->base': (None, kwargs['intervention_locations'].permute(1, 0, 2).tolist())}
                else:
                    # this is dummy for lora only baseline
                    return {'sources->base': (None, 0)}
            return None

        def _pre_forward_hook(module, args, kwargs):
            if 'base' in kwargs:
                # Already in pyreft's native signature; nothing to translate.
                return args, kwargs
            if 'input_ids' not in kwargs:
                raise ValueError('Input does not contain `input_ids`, maybe the model does not support ReFT.')
            # run intervened forward pass
            unit_locations = _get_unit_locations(kwargs)
            kwargs = {
                'base': {
                    'input_ids': kwargs['input_ids'],
                    # `.get`: inference callers may omit the mask/labels;
                    # pass None through instead of raising KeyError.
                    'attention_mask': kwargs.get('attention_mask')
                },
                'unit_locations': unit_locations,
                'labels': kwargs.get('labels'),
                'subspaces': kwargs['subspaces'].permute(1, 0, 2).tolist() if 'subspaces' in kwargs else None
            }
            return args, kwargs

        def _post_forward_hook(module, args, kwargs, outputs):
            # pyreft returns (intervention outputs, model outputs); callers
            # expect only the model outputs.
            return outputs[1]

        def _generate(self, **kwargs):
            # Mirror `_pre_forward_hook` for the generation path; remaining
            # kwargs (generation options) are forwarded untouched.
            unit_locations = _get_unit_locations(kwargs)
            _kwargs = {
                'base': {
                    'input_ids': kwargs.pop('input_ids'),
                    'attention_mask': kwargs.pop('attention_mask')
                },
                'unit_locations': unit_locations,
                'subspaces': kwargs.pop('subspaces').permute(1, 0, 2).tolist() if 'subspaces' in kwargs else None
            }
            _kwargs = {**_kwargs, **kwargs}
            return self.generate_origin(**_kwargs)[1]

        reft_model.generate_origin = reft_model.generate
        reft_model.generate = MethodType(_generate, reft_model)
        reft_model.register_forward_pre_hook(_pre_forward_hook, with_kwargs=True)
        reft_model.register_forward_hook(_post_forward_hook, with_kwargs=True)

        def save_callback(swift_model, model_dir, adapter_name):
            # Interventions are serialized separately from base-model weights.
            reft_model.save_intervention(save_directory=model_dir, include_model=False)

        def mark_trainable_callback(model):
            # pyreft already freezes the base model and marks intervention
            # parameters trainable; nothing extra to do here.
            return

        def load_callback(swift_model, model_dir, adapter_name):
            reft_model.load_intervention(model_dir, include_model=False)

        return SwiftOutput(
            model=reft_model,
            config=config,
            mark_trainable_callback=mark_trainable_callback,
            save_callback=save_callback,
            load_callback=load_callback)

    @staticmethod
    def has_additional_modules():
        # ReFT adds new intervention modules on top of the base model.
        return True

    @staticmethod
    def activate_adapter(module: torch.nn.Module, adapter_name: str, activate: bool, offload: Optional[str] = None):
        # NOTE(review): `assert` is stripped under `python -O`; kept as-is to
        # preserve the AssertionError type callers may rely on.
        assert activate, 'ReFT does not support deactivate'