# Copyright (c) Alibaba, Inc. and its affiliates.
import copy
import re
import types
from collections import OrderedDict
from dataclasses import dataclass, field
from functools import partial
from itertools import repeat
from typing import Optional, Union

import torch
from torch import nn

from swift.utils.logger import get_logger
from swift.utils.torch_utils import find_sub_module
from .utils import ActivationMixin, SwiftAdapter, SwiftConfig, SwiftOutput

logger = get_logger()

@dataclass
class SideConfig(SwiftConfig):
    """
    The configuration class for the side module.

    Side-Tuning only needs to train one side network and
    weights the outputs of the pre-trained model and the side network.
    'Side-Tuning: A Baseline for Network Adaptation via Additive Side Networks'
    by Zhang et al. (2019)
    See https://arxiv.org/abs/1912.13503

    Args:
        dim: The dimension of the hidden states.
        target_modules: The target module to be replaced, matched against
            module keys with `re.fullmatch`.
        side_module_name: The name of the additive side network.
        source_hidden_pos: The position of the hidden state input to the
            target module, can be int (args) or str (kwargs).
        target_hidden_pos: The position of the hidden state output from the
            target module, can be int (args) or str (kwargs).
    """

    dim: Optional[int] = field(default=None, metadata={'help': 'The dimension of the hidden states'})

    target_modules: Optional[str] = field(
        default=None, metadata={'help': 'The target module to be replaced, in full match format'})

    side_module_name: str = field(default='fcn4', metadata={'help': 'The name of the additive side network'})

    source_hidden_pos: Union[str, int] = field(
        default=0,
        metadata={
            'help': 'The position of the hidden state input to the target module, can be int (args) or str (kwargs)'
        })

    target_hidden_pos: Union[str, int] = field(
        default=0,
        metadata={
            'help': 'The position of the hidden state output from the target module, can be int (args) or str (kwargs)'
        })

    def __post_init__(self):
        from .mapping import SwiftTuners
        self.swift_type = SwiftTuners.SIDE
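
# A minimal usage sketch (illustrative, not part of this module). The regex and
# dimension below are hypothetical and depend on the backbone being tuned;
# `Swift.prepare_model` is this package's usual entry point for applying a
# tuner config:
#
#     from swift import Swift
#     config = SideConfig(dim=768, target_modules=r'model\.encoder', side_module_name='mlp')
#     model = Swift.prepare_model(model, config)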

class Side(SwiftAdapter):

    @staticmethod
    def prepare_model(model: nn.Module, config: SideConfig, adapter_name: str) -> SwiftOutput:
        """Prepare a model with `SideConfig`"""
        module_keys = [key for key, _ in model.named_modules()]
        for module_key in module_keys:
            if re.fullmatch(config.target_modules, module_key):  # noqa
                tgt_module = model.get_submodule(module_key)
                logger.info(f'Matching target module [{module_key}] of type {type(tgt_module)}')

                if isinstance(tgt_module, (nn.ModuleList, nn.ModuleDict)):
                    raise Exception(
                        f'Type of {type(tgt_module)} may not be supported because of its customized forward')

                def _forward(self, *args, **kwargs):
                    args_main = getattr(self, f'forward_origin_{adapter_name}')(*args, **kwargs)
                    if isinstance(config.source_hidden_pos, int):
                        x = args[config.source_hidden_pos]
                    else:
                        x = kwargs[config.source_hidden_pos]
                    x_main = args_main[config.target_hidden_pos] \
                        if isinstance(args_main, (tuple, list, dict)) else args_main
                    out = getattr(self, f'side_{adapter_name}')(x, x_main)
                    if isinstance(args_main, tuple):
                        # Tuples are immutable; rebuild one with the blended output in place.
                        args_main = args_main[:config.target_hidden_pos] + (out, ) \
                            + args_main[config.target_hidden_pos + 1:]
                    elif isinstance(args_main, (list, dict)):
                        args_main[config.target_hidden_pos] = out
                    else:
                        args_main = out
                    return args_main

                if isinstance(tgt_module, nn.Sequential) and not hasattr(tgt_module, 'tgt_module_keys'):
                    # Record the original sub-module keys: the `setattr` below registers
                    # the side module as a child of the Sequential, and the default
                    # Sequential.forward would otherwise run it as an extra stage.
                    tgt_module.tgt_module_keys = copy.deepcopy(list(tgt_module._modules.keys()))

                    def forward_seq(self, input, *args, **kwargs):
                        for idx, module in enumerate(self):
                            if idx >= len(tgt_module.tgt_module_keys):
                                break
                            input = module(input)
                        return input

                    setattr(tgt_module, f'forward_origin_{adapter_name}', types.MethodType(forward_seq, tgt_module))
                else:
                    setattr(tgt_module, f'forward_origin_{adapter_name}', tgt_module.forward)
                tgt_module.forward = types.MethodType(_forward, tgt_module)
                side_module = SideModule(config.dim, adapter_name, module_key, config.side_module_name)
                setattr(tgt_module, f'side_{adapter_name}', side_module)
                logger.info(f'Side module of [{module_key}]: {module_key}.side_{adapter_name}')

        def state_dict_callback(state_dict, adapter_name, **kwargs):
            return {key: value for key, value in state_dict.items() if f'side_{adapter_name}' in key}

        def mark_trainable_callback(model):
            # Side-Tuning trains only the side networks; nothing extra to mark here.
            return

        return SwiftOutput(
            config=config, state_dict_callback=state_dict_callback, mark_trainable_callback=mark_trainable_callback)

    @staticmethod
    def activate_adapter(module: torch.nn.Module, adapter_name: str, activate: bool, offload: str = None):
        modules = find_sub_module(module, f'side_{adapter_name}')
        for _module in modules:
            _module: ActivationMixin
            _module.set_activation(adapter_name, activate)
            SwiftAdapter.save_memory(_module, adapter_name, _module.module_key, activate, offload)
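
# A small sketch of toggling a prepared adapter (the adapter name 'default' is
# hypothetical; use whatever name the adapter was registered under):
#
#     Side.activate_adapter(model, 'default', activate=False)  # bypass the side path
#     Side.activate_adapter(model, 'default', activate=True)   # re-enable it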

class SideModule(nn.Module, ActivationMixin):
    """The implementation of the vision side-tuning method.

    Side-Tuning only needs to train one side network and
    weights the outputs of the pre-trained model and the side network.
    'Side-Tuning: A Baseline for Network Adaptation via Additive Side Networks'
    by Zhang et al. (2019)
    See https://arxiv.org/abs/1912.13503

    Args:
        dim: The dimension of the hidden states.
        adapter_name: The name of the adapter.
        module_key: The key of the target module the side network attaches to.
        side_module_name: The name of the additive side network.
    """

    def __init__(self, dim, adapter_name, module_key, side_module_name='fcn4'):
        super(SideModule, self).__init__()
        # `super(nn.Module, self)` resolves to ActivationMixin in the MRO and
        # initializes the mixin with the key of the wrapped module.
        super(nn.Module, self).__init__(module_key)
        self.adapter_name = adapter_name

        side_module_name = side_module_name.lower()
        if side_module_name == 'fcn4':
            self.side_net = FCN4(out_dims=dim)
        elif side_module_name == 'mlp':
            self.side_net = Mlp(dim)
        elif side_module_name == 'alexnet':
            import torchvision

            # Note: `pretrained=True` is deprecated in recent torchvision releases
            # in favor of the `weights=...` argument.
            mm = torchvision.models.alexnet(pretrained=True)
            self.side_net = nn.Sequential(
                OrderedDict([('features', mm.features), ('avgpool', mm.avgpool), ('flatten', nn.Flatten()),
                             ('fc', nn.Linear(9216, dim, bias=False))]))
        else:
            raise ValueError(f'Unsupported side_module_name: {side_module_name}')
        self.alpha = nn.Parameter(torch.tensor(0.0))
        self.mark_all_sub_modules_as_plugin()

    def forward(self, x, x_main):
        if not self.is_activated(self.adapter_name):
            return x_main
        # Blend the backbone output with the side network output through a
        # learned sigmoid gate.
        alpha_squashed = torch.sigmoid(self.alpha)
        x_side = self.side_net(x)
        x_out = alpha_squashed * x_main + (1 - alpha_squashed) * x_side
        return x_out
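
    # Worked gate values (illustrative): at initialization alpha = 0 and
    # sigmoid(0) = 0.5, so the output is the even blend 0.5 * x_main + 0.5 * x_side.
    # Training then moves alpha: alpha -> +inf favors the pre-trained path
    # (x_main), alpha -> -inf favors the side network (x_side).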

class FCN4(nn.Module):
    """A simple four-layer fully convolutional network used as the side network."""

    def __init__(self, out_dims=-1, **kwargs):
        super(FCN4, self).__init__(**kwargs)
        self.conv1 = nn.Sequential(
            nn.Conv2d(3, 16, kernel_size=3, stride=1, padding=1, bias=False, dilation=1), nn.GroupNorm(2, 16),
            nn.ReLU())
        self.conv2 = nn.Sequential(
            nn.Conv2d(16, 16, kernel_size=3, stride=2, padding=0, bias=False, dilation=1), nn.GroupNorm(2, 16),
            nn.ReLU())
        self.conv3 = nn.Sequential(
            nn.Conv2d(16, 32, kernel_size=3, stride=2, padding=0, bias=False, dilation=1), nn.GroupNorm(2, 32),
            nn.ReLU())
        self.conv4 = nn.Sequential(
            nn.Conv2d(32, 64, kernel_size=3, stride=1, padding=0, bias=False, dilation=1), nn.GroupNorm(2, 64),
            nn.ReLU())
        self.pool = nn.AdaptiveAvgPool2d((1, 1))
        if out_dims > 0:
            self.fc = nn.Linear(64, out_dims)
        else:
            self.fc = None

    def forward(self, x):
        x = self.conv1(x)
        x = self.conv2(x)
        x = self.conv3(x)
        x = self.conv4(x)
        x = self.pool(x)
        x = x.view(x.size(0), -1)
        if self.fc is not None:
            x = self.fc(x)
        return x
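
    # Shape sketch for an illustrative (N, 3, 224, 224) input:
    #     conv1 -> (N, 16, 224, 224)    conv2 -> (N, 16, 111, 111)
    #     conv3 -> (N, 32, 55, 55)      conv4 -> (N, 64, 53, 53)
    #     pool -> (N, 64, 1, 1) -> flatten -> (N, 64) -> fc -> (N, out_dims)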

class Mlp(nn.Module):
    """MLP as used in Vision Transformer."""

    def __init__(
        self,
        in_features,
        hidden_features=None,
        out_features=None,
        act_layer=nn.GELU,
        norm_layer=None,
        bias=True,
        drop=0.,
        use_conv=False,
    ):
        super().__init__()
        out_features = out_features or in_features
        hidden_features = hidden_features or in_features
        # Use the same bias/dropout setting for both linear layers.
        bias = tuple(repeat(bias, 2))
        drop_probs = tuple(repeat(drop, 2))
        linear_layer = partial(nn.Conv2d, kernel_size=1) if use_conv else nn.Linear

        self.fc1 = linear_layer(in_features, hidden_features, bias=bias[0])
        self.act = act_layer()
        self.drop1 = nn.Dropout(drop_probs[0])
        self.norm = norm_layer(hidden_features) if norm_layer is not None else nn.Identity()
        self.fc2 = linear_layer(hidden_features, out_features, bias=bias[1])
        self.drop2 = nn.Dropout(drop_probs[1])

    def forward(self, x):
        x = self.fc1(x)
        x = self.act(x)
        x = self.drop1(x)
        x = self.norm(x)
        x = self.fc2(x)
        x = self.drop2(x)
        return x
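
# A tiny smoke-test sketch for the side networks above (shapes are illustrative):
#
#     import torch
#     mlp = Mlp(768)                  # (N, L, 768) -> (N, L, 768)
#     fcn = FCN4(out_dims=768)        # (N, 3, H, W) -> (N, 768)
#     print(mlp(torch.randn(2, 16, 768)).shape)
#     print(fcn(torch.randn(2, 3, 224, 224)).shape)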