File size: 10,597 Bytes
7feac49 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 |
# Copyright (c) Alibaba, Inc. and its affiliates.
import re
import types
from dataclasses import dataclass, field
from typing import List, Optional, Union
import torch
from torch import nn
from swift.tuners.utils import ActivationMixin, SwiftAdapter, SwiftConfig, SwiftOutput
from swift.utils import get_logger
from swift.utils.torch_utils import find_sub_module
from .scetuning_components import probe_output_hook
logger = get_logger()
@dataclass
class SCETuningConfig(SwiftConfig):
"""
The configuration class for the SCEdit module.
'SCEdit: Efficient and Controllable Image Diffusion Generation via Skip Connection Editing' by Jiang et al.(2023)
See https://arxiv.org/abs/2312.11392
Args:
dims(`Union[List[int], int]`): The dimensions of the hidden states
target_modules(`Union[List[str], str]`): The target module to be replaced, can a regex string
hint_modules(`Union[List[str], str]`): The hint module to be replaced, can a regex string
tuner_mode(`str`): Location of tuner operation.
tuner_op(`str`): Tuner operation.
down_ratio(`float`): The dim down ratio of tuner hidden state.
"""
dims: Optional[Union[List[int], int]] = field(
default=None, metadata={'help': 'The dimensions of the hidden states'})
target_modules: Optional[Union[List[str], str]] = field(
default=None,
metadata={'help': 'The target module to be replaced, can be a regex string or name list of full match format'})
hint_modules: Optional[Union[List[str], str]] = field(
default=None,
metadata={'help': 'The hint modules to be replaced, can be a regex string or name list of full match format'})
tuner_mode: str = field(
default='decoder',
metadata={'help': 'Location of tuner operation. The tuner mode choices: encoder, decoder, and identity'})
tuner_op: str = field(default='SCEAdapter', metadata={'help': 'The tuner ops choices: SCEAdapter'})
down_ratio: float = field(default=1.0, metadata={'help': 'The dim down ratio of tuner hidden state'})
def __post_init__(self):
from swift.tuners.mapping import SwiftTuners
self.swift_type = SwiftTuners.SCETUNING
class SCETuning(SwiftAdapter):
@staticmethod
def prepare_model(model: nn.Module, config: SCETuningConfig, adapter_name: str) -> SwiftOutput:
"""Prepare a model with `SCETuningConfig`"""
module_keys = [key for key, _ in model.named_modules()]
# 1. Matching the hint module
hint_module_ins_list = []
if config.hint_modules:
if isinstance(config.hint_modules, list):
for module_key in config.hint_modules:
assert module_key in module_keys
h_module = model.get_submodule(module_key)
logger.info(f'Matching hint module [{module_key}] of type {type(h_module)}')
if isinstance(h_module, (nn.ModuleList, nn.ModuleDict)):
logger.warning(
f'Type of {type(h_module)} may not be supported because of its customized forward')
h_module.register_forward_hook(probe_output_hook, with_kwargs=True)
hint_module_ins_list.append(h_module)
else:
for module_key in module_keys:
if re.fullmatch(config.hint_modules, module_key):
h_module = model.get_submodule(module_key)
logger.info(f'Matching hint module [{module_key}] of type {type(h_module)}')
if isinstance(h_module, (nn.ModuleList, nn.ModuleDict)):
logger.warning(
f'Type of {type(h_module)} may not be supported because of its customized forward')
h_module.register_forward_hook(probe_output_hook, with_kwargs=True)
hint_module_ins_list.append(h_module)
if len(hint_module_ins_list) == 0:
logger.error('Cannot match hint modules')
def _get_module(module):
if isinstance(module, nn.ModuleList):
module = module[-1]
return _get_module(module)
return module
# 2. Matching the target module
target_module_ins_list = []
assert config.target_modules is not None
if isinstance(config.target_modules, list):
for module_key in config.target_modules:
assert module_key in module_keys
t_module = model.get_submodule(module_key)
logger.info(f'Matching target module [{module_key}] of type {type(t_module)}')
target_module_ins_list.append(_get_module(t_module))
else:
for module_key in module_keys:
if re.fullmatch(config.target_modules, module_key):
t_module = model.get_submodule(module_key)
logger.info(f'Matching target module [{module_key}] of type {type(t_module)}')
target_module_ins_list.append(_get_module(t_module))
if len(target_module_ins_list) == 0:
logger.error('Cannot match target modules')
if len(hint_module_ins_list) > 0 and not len(hint_module_ins_list) == len(target_module_ins_list):
logger.info("Target modules' length should be equal with hint modules.")
assert len(hint_module_ins_list) == len(target_module_ins_list)
if isinstance(config.dims, int):
dims = [config.dims for _ in target_module_ins_list]
else:
assert len(config.dims) == len(target_module_ins_list)
dims = config.dims
# refactor forward function
def _forward_encoder_mode(self, *args, **kwargs):
args = getattr(self, f'forward_origin_{adapter_name}')(*args, **kwargs)
args_type = type(args)
if args_type is tuple:
args = args[0]
if hasattr(self, 'hint'):
hint_out = self.hint.probe_output_data
args_main = getattr(self, f'scetuner_{adapter_name}')(args, hint_out)
else:
args_main = getattr(self, f'scetuner_{adapter_name}')(args)
if args_type is tuple:
args_main = (args_main, )
return args_main
def _forward_decoder_mode(self, *args, **kwargs):
args_type = type(args)
if args_type is tuple:
args_sub_tuner = args[0]
args_sub_extra = args[1:]
tuner_module = getattr(self, f'scetuner_{adapter_name}')
args_hidden, args_res = torch.split(args_sub_tuner, args_sub_tuner.shape[1] - tuner_module.dim, 1)
if hasattr(self, 'hint'):
hint_out = self.hint.probe_output_data
args_res_new = tuner_module(args_res, hint_out)
else:
args_res_new = tuner_module(args_res)
args_sub_tuner_new = torch.cat([args_hidden, args_res_new], dim=1)
if args_type is tuple:
args_main = (args_sub_tuner_new, *args_sub_extra)
args_main = getattr(self, f'forward_origin_{adapter_name}')(*args_main, **kwargs)
return args_main
# 3. inject the tuners
for tuner_id, t_module in enumerate(target_module_ins_list):
setattr(t_module, f'forward_origin_{adapter_name}', getattr(t_module, 'forward'))
if config.tuner_mode in ('encoder', 'identity'):
_forward = _forward_encoder_mode
elif config.tuner_mode == 'decoder':
_forward = _forward_decoder_mode
else:
raise Exception(f'Error tuner_mode: {config.tuner_mode}')
setattr(t_module, 'forward', types.MethodType(_forward, t_module))
tuner_op = SCETunerModule(
name=config.tuner_op,
adapter_name=adapter_name,
module_key=str(tuner_id),
dim=dims[tuner_id],
tuner_length=int(dims[tuner_id] * config.down_ratio))
setattr(t_module, f'scetuner_{adapter_name}', tuner_op)
if len(hint_module_ins_list) > 0:
setattr(t_module, 'hint', hint_module_ins_list[tuner_id])
def state_dict_callback(state_dict, adapter_name, **kwargs):
state_dict_new = {key: value for key, value in state_dict.items() if f'scetuner_{adapter_name}' in key}
return state_dict_new
def mark_trainable_callback(model):
return
return SwiftOutput(
config=config, state_dict_callback=state_dict_callback, mark_trainable_callback=mark_trainable_callback)
@staticmethod
def activate_adapter(module: torch.nn.Module, adapter_name: str, activate: bool, offload: str = None):
modules = find_sub_module(module, f'scetuner_{adapter_name}')
for _module in modules:
_module: ActivationMixin
_module: nn.Module
_module.set_activation(adapter_name, activate)
SwiftAdapter.save_memory(_module, adapter_name, _module.module_key, activate, offload)
class SCETunerModule(nn.Module, ActivationMixin):
def __init__(self,
name,
adapter_name,
module_key,
dim,
tuner_length,
tuner_type=None,
tuner_weight=None,
act_layer=nn.GELU,
zero_init_last=True,
use_bias=True):
super(SCETunerModule, self).__init__()
super(nn.Module, self).__init__(module_key)
self.name = name
self.adapter_name = adapter_name
self.dim = dim
if name == 'SCEAdapter':
from .scetuning_components import SCEAdapter
self.tuner_op = SCEAdapter(
dim=dim,
adapter_length=tuner_length,
adapter_type=tuner_type,
adapter_weight=tuner_weight,
act_layer=act_layer)
else:
raise Exception(f'Error tuner op {name}')
self.mark_all_sub_modules_as_plugin()
def forward(self, x, x_shortcut=None, use_shortcut=True, **kwargs):
if not self.is_activated(self.adapter_name):
return x
if self.name == 'SCEAdapter':
self.tuner_op.to(x.device)
out = self.tuner_op(x)
else:
raise Exception(f'Error tuner op {self.name}')
return out
|