| | |
| | import math |
| |
|
| | import torch |
| | import torch.nn as nn |
| | import torch.nn.functional as F |
| | from einops import rearrange |
| |
|
| | from swift.utils.logger import get_logger |
| |
|
| | logger = get_logger() |
| |
|
| |
|
class ResTuner(nn.Module):
    """Factory/wrapper that builds the concrete tuner selected by ``tuner_cfg``.

    ``tuner_cfg`` is either a string naming the tuner type or a dict whose key
    names the type and whose value is that tuner's sub-config. Recognized
    types: 'res_adapter', 'res_group_adapter', 'upsample'; anything else
    yields an ``Identity`` tuner. A 'zero' config makes ``forward`` return
    0.0 so the residual branch contributes nothing.

    Args:
        dim: feature dimension handed to the underlying tuner.
        layer_num: index of the layer this tuner is attached to.
        depth: total number of layers (used for zero-init of the last layer).
        zero_init_last: zero-initialize the tuner's final projection on the
            last layer so training starts from an identity mapping.
        stage: free-form stage tag forwarded to the tuner.
        tuner_cfg: tuner selector/config (str or dict); defaults to ``{}``.
    """

    def __init__(self, dim=None, layer_num=-1, depth=-1, zero_init_last=False, stage='', tuner_cfg=None, **kwargs):
        super().__init__()
        # Avoid the mutable default argument (`tuner_cfg={}`): one shared dict
        # across all instances is a classic Python pitfall.
        tuner_cfg = {} if tuner_cfg is None else tuner_cfg
        self.dim = dim
        self.layer_num = layer_num
        self.depth = depth
        self.stage = stage
        self.tuner_cfg = tuner_cfg

        if (isinstance(tuner_cfg, str) and tuner_cfg == 'res_adapter') or \
                (isinstance(tuner_cfg, dict) and 'res_adapter' in tuner_cfg):
            # Unwrap the sub-config when given as a dict entry.
            tuner_cfg = tuner_cfg['res_adapter'] if isinstance(tuner_cfg, dict) else tuner_cfg
            self.tuner = ResAdapter(
                dim=dim,
                layer_num=layer_num,
                depth=depth,
                zero_init_last=zero_init_last,
                stage=stage,
                tuner_cfg=tuner_cfg,
                **kwargs)
        elif (isinstance(tuner_cfg, str) and tuner_cfg == 'res_group_adapter') or \
                (isinstance(tuner_cfg, dict) and 'res_group_adapter' in tuner_cfg):
            tuner_cfg = tuner_cfg['res_group_adapter'] if isinstance(tuner_cfg, dict) else tuner_cfg
            self.tuner = ResGroupAdapter(
                dim=dim,
                layer_num=layer_num,
                depth=depth,
                zero_init_last=zero_init_last,
                stage=stage,
                tuner_cfg=tuner_cfg,
                **kwargs)
        elif (isinstance(tuner_cfg, str) and tuner_cfg == 'upsample') or \
                (isinstance(tuner_cfg, dict) and 'upsample' in tuner_cfg):
            tuner_cfg = tuner_cfg['upsample'] if isinstance(tuner_cfg, dict) else tuner_cfg
            if 'upsample_out_channels' in kwargs:
                # A truthy explicit out_channels implies a projection conv.
                out_channels = kwargs['upsample_out_channels']
                use_conv = bool(out_channels)
            else:
                out_channels = dim
                use_conv = False
            self.tuner = Upsample(
                channels=dim, use_conv=use_conv, out_channels=out_channels, tuner_cfg=tuner_cfg, **kwargs)
        else:
            self.tuner = Identity()

    def forward(self, x, *args, **kwargs):
        """Apply the wrapped tuner; return 0.0 for a 'zero' config."""
        # NOTE: this checks the *original* cfg stored on self (string equality
        # or dict-key membership), not the unwrapped sub-config used above.
        if self.tuner_cfg == 'zero' or 'zero' in self.tuner_cfg:
            x_out = 0.0
        else:
            x_out = self.tuner(x, *args, **kwargs)
        return x_out
| |
|
| |
|
class ResAdapter(nn.Module):
    """Bottleneck residual adapter: down-project -> activation -> up-project.

    The adapter runs alongside the tuned features: ``out = x + adapter(x)``.
    4-D inputs (B, C, H, W) are flattened to a (B, H*W, C) token sequence
    before the MLP and reshaped back afterwards.

    Args:
        dim: channel dimension of the input features.
        layer_num: index of this layer; also selects a per-layer
            ``adapter_length`` when the config supplies a list.
        depth: total number of layers.
        zero_init_last: zero-init the up-projection on the last layer so the
            adapter starts as an identity mapping.
        stage: unused here; kept for interface parity with other tuners.
        tuner_cfg: dict with optional keys 'adapter_length' (int, tuple of 3,
            or per-layer list; default 32), 'adapter_type', 'adapter_weight'.
        act_layer: activation module class between the two projections.
    """

    def __init__(self,
                 dim,
                 layer_num=-1,
                 depth=-1,
                 zero_init_last=False,
                 stage='',
                 tuner_cfg=None,
                 act_layer=nn.GELU,
                 **kwargs):
        super(ResAdapter, self).__init__()
        self.dim = dim
        self.layer_num = layer_num
        self.depth = depth

        # Tolerate the default tuner_cfg=None (original crashed on subscript).
        tuner_cfg = tuner_cfg or {}
        self.adapter_length = tuner_cfg.get('adapter_length', 32)
        self.adapter_type = tuner_cfg.get('adapter_type', None)
        self.adapter_weight = tuner_cfg.get('adapter_weight', None)

        # A list supplies one adapter_length per layer, indexed by layer_num.
        if isinstance(self.adapter_length, list):
            self.adapter_length = self.adapter_length[self.layer_num]
        assert isinstance(self.adapter_length, int) or \
            (isinstance(self.adapter_length, tuple) and len(self.adapter_length) == 3), \
            'adapter_length must be an int or an (in, hidden, out) tuple'
        if isinstance(self.adapter_length, int):
            self.ln1 = nn.Linear(dim, self.adapter_length)
        else:
            self.ln1 = nn.Linear(self.adapter_length[0], self.adapter_length[1])
        self.activate = act_layer()
        if isinstance(self.adapter_length, int):
            self.ln2 = nn.Linear(self.adapter_length, dim)
        else:
            self.ln2 = nn.Linear(self.adapter_length[1], self.adapter_length[2])
            dim = self.adapter_length[2]  # scaling operates on the output dim

        self._xavier_init_weights(self.ln1)
        if zero_init_last and layer_num == depth - 1:
            # Last layer starts as identity: the adapter contributes zero.
            self._zero_init_weights(self.ln2)
        else:
            self._xavier_init_weights(self.ln2)

        self.scaling = init_weight_type(dim, self.adapter_weight)
        self._prepared = False

    def _zero_init_weights(self, m):
        """Zero weight and bias (identity-at-init for the up-projection)."""
        if isinstance(m, nn.Linear):
            nn.init.zeros_(m.weight)
            nn.init.zeros_(m.bias)

    def _kaiming_init_weights(self, m):
        """Kaiming-uniform weight, normal bias."""
        if isinstance(m, nn.Linear):
            nn.init.kaiming_uniform_(m.weight, a=math.sqrt(5))
            nn.init.normal_(m.bias)

    def _xavier_init_weights(self, m):
        """Xavier-uniform weight, tiny-std bias."""
        if isinstance(m, nn.Linear):
            nn.init.xavier_uniform_(m.weight)
            nn.init.normal_(m.bias, std=1e-6)

    def forward(self, x):
        # Lazily move submodules onto the input device on first use.
        if not self._prepared:
            self.ln1.to(x.device)
            self.activate.to(x.device)
            self.ln2.to(x.device)
            self._prepared = True

        x_dtype = x.dtype
        x = x.to(self.ln1.weight.dtype)
        x_shortcut = x
        if x_shortcut.dim() == 4:
            # (B, C, H, W) -> (B, H*W, C): token layout for the linear layers.
            x = x.reshape(x_shortcut.size(0), x_shortcut.size(1), -1).permute(0, 2, 1)

        x_adapter = self.ln2(self.activate(self.ln1(x)))

        if self.adapter_weight:
            x_adapter = apply_data_weight(x_adapter, self.scaling, self.adapter_weight)

        if x_shortcut.dim() == 4:
            # Back to (B, C_out, H, W); C_out may differ when a tuple
            # adapter_length is used. reshape (not view) because the permuted
            # tensor is generally non-contiguous and view would raise.
            x_adapter = x_adapter.permute(0, 2, 1).reshape(x_shortcut.size(0),
                                                           x_adapter.size(-1),
                                                           x_shortcut.size(2),
                                                           x_shortcut.size(3))
        x_out = x_shortcut + x_adapter
        return x_out.to(x_dtype)
| |
|
| |
|
class ResGroupAdapter(nn.Module):
    """Grouped ('multi-head') residual adapter for 4-D feature maps.

    Channels are split into ``adapter_head`` groups; every group is passed
    through the same two-layer MLP (expand by ``adapter_scale_factor``, then
    project back). The result is added to the input: ``out = x + adapter(x)``.

    Args:
        dim: channel dimension of the input feature map.
        layer_num: index of this layer (used for ``zero_init_last``).
        depth: total number of layers.
        zero_init_last: zero-init the up-projection on the last layer so the
            adapter starts as an identity mapping.
        stage: unused here; kept for interface parity with other tuners.
        tuner_cfg: dict with optional keys 'adapter_type', 'adapter_weight',
            'dim' (default ``dim``), 'head' (default 4), 'scale_factor'
            (default 2).
        act_layer: activation module class between the two projections.
    """

    def __init__(self,
                 dim,
                 layer_num=-1,
                 depth=-1,
                 zero_init_last=False,
                 stage='',
                 tuner_cfg=None,
                 act_layer=nn.GELU,
                 **kwargs):
        super(ResGroupAdapter, self).__init__()
        self.dim = dim
        self.layer_num = layer_num
        self.depth = depth

        # Tolerate the default tuner_cfg=None (original crashed on subscript).
        tuner_cfg = tuner_cfg or {}
        self.adapter_type = tuner_cfg.get('adapter_type', None)
        self.adapter_weight = tuner_cfg.get('adapter_weight', None)

        self.adapter_dim = tuner_cfg.get('dim', dim)
        self.adapter_head = tuner_cfg.get('head', 4)
        self.adapter_scale_factor = tuner_cfg.get('scale_factor', 2)

        assert self.adapter_dim % self.adapter_head == 0, 'adapter dim should be divisible by adapter head'
        # Per-group channel width fed to the shared MLP.
        self.dim_mlp = self.adapter_dim // self.adapter_head

        self.ln1 = nn.Linear(self.dim_mlp, self.dim_mlp * self.adapter_scale_factor)
        self.ln2 = nn.Linear(self.dim_mlp * self.adapter_scale_factor, self.dim_mlp)
        self.activate = act_layer()

        self._kaiming_init_weights(self.ln1)
        if zero_init_last and layer_num == depth - 1:
            # Last layer starts as identity: the adapter contributes zero.
            self._zero_init_weights(self.ln2)
        else:
            self._kaiming_init_weights(self.ln2)
        self.scaling = init_weight_type(dim, self.adapter_weight)
        self._prepared = False

    def _zero_init_weights(self, m):
        """Zero weight and bias (identity-at-init for the up-projection)."""
        if isinstance(m, nn.Linear):
            nn.init.zeros_(m.weight)
            nn.init.zeros_(m.bias)

    def _kaiming_init_weights(self, m):
        """Kaiming-uniform weight, normal bias."""
        if isinstance(m, nn.Linear):
            nn.init.kaiming_uniform_(m.weight, a=math.sqrt(5))
            nn.init.normal_(m.bias)

    def _xavier_init_weights(self, m):
        """Xavier-uniform weight, tiny-std bias (unused here; kept for parity
        with ResAdapter)."""
        if isinstance(m, nn.Linear):
            nn.init.xavier_uniform_(m.weight)
            nn.init.normal_(m.bias, std=1e-6)

    def forward(self, x):
        # Lazily move submodules onto the input device on first use.
        if not self._prepared:
            self.ln1.to(x.device)
            self.activate.to(x.device)
            self.ln2.to(x.device)
            self._prepared = True

        x_dtype = x.dtype
        x = x.to(self.ln1.weight.dtype)
        x_shortcut = x

        batch, inner_dim, height, width = x.shape

        # (B, C, H, W) -> (B, H*W, C)
        x_adapter = x.permute(0, 2, 3, 1).reshape(batch, height * width, inner_dim)

        # Split channels into head groups, run the shared MLP, merge back.
        x_adapter = rearrange(x_adapter, 'b n (c h) -> (b h) n c', h=self.adapter_head)
        x_adapter = self.ln2(self.activate(self.ln1(x_adapter)))
        x_adapter = rearrange(x_adapter, '(b h) n c -> b n (c h)', h=self.adapter_head)

        if self.adapter_weight:
            x_adapter = apply_data_weight(x_adapter, self.scaling, self.adapter_weight)

        # (B, H*W, C) -> (B, C, H, W)
        x_adapter = x_adapter.reshape(batch, height, width, -1).permute(0, 3, 1, 2).contiguous()
        x_out = x_shortcut + x_adapter

        return x_out.to(x_dtype)
| |
|
| |
|
class Identity(nn.Module):
    """No-op tuner: returns its first input unchanged, ignoring extra args."""

    def forward(self, inputs, *args, **kwargs):
        """Pass ``inputs`` straight through."""
        return inputs
| |
|
| |
|
class Upsample(nn.Module):
    """Nearest-neighbor upsampling (2x or to a target size) with an optional
    zero-initialized 3x3 convolution.

    The conv, when enabled, starts at all-zeros so the module initially
    outputs zeros after projection — the surrounding tuner begins as a no-op
    contribution. (The original docstring documented a nonexistent ``dims``
    parameter; this layer is 2-D only, operating on (N, C, H, W).)

    Args:
        channels: number of channels expected in the input.
        use_conv: apply a 3x3 Conv2d after interpolation.
        out_channels: conv output channels; defaults to ``channels``.
        padding: conv padding (default 1 preserves spatial size).
    """

    def __init__(self, channels, use_conv=False, out_channels=None, padding=1, **kwargs):
        super().__init__()
        self.channels = channels
        self.out_channels = out_channels or channels
        self.use_conv = use_conv
        if use_conv:
            self.conv = nn.Conv2d(self.channels, self.out_channels, 3, padding=padding)
        self.init_weights()

    def init_weights(self):
        """Zero-initialize all Conv2d submodules (see class docstring)."""

        def _init_weights(m):
            if isinstance(m, nn.Conv2d):
                nn.init.zeros_(m.weight)
                nn.init.zeros_(m.bias)

        self.apply(_init_weights)

    def forward(self, x, target_size=None, *args, **kwargs):
        """Upsample ``x`` (N, C, H, W) by 2x, or to ``target_size`` if given."""
        assert x.shape[1] == self.channels, \
            f'expected {self.channels} channels, got {x.shape[1]}'
        # Interpolate in float32, then cast back: nearest-mode kernels are not
        # available for all dtypes/backends.
        if target_size is None:
            x = F.interpolate(x.float(), scale_factor=2, mode='nearest').type_as(x)
        else:
            x = F.interpolate(x.float(), target_size, mode='nearest').type_as(x)
        if self.use_conv:
            x = self.conv(x)
        return x
| |
|
| |
|
def init_weight_type(dim, weight_type):
    """Build the scaling object for a weighting scheme (see apply_data_weight).

    Returns a ``nn.Linear`` for 'gate', ones-initialized ``nn.Parameter``(s)
    for the 'scale*' schemes, a float for 'scalar_<x>', otherwise ``None``.
    """

    def _ones(size):
        # Parameter of the given size, filled with 1.0.
        param = nn.Parameter(torch.Tensor(size))
        param.data.fill_(1)
        return param

    if weight_type == 'gate':
        return nn.Linear(dim, 1)
    if weight_type == 'scale':
        return _ones(1)
    if weight_type == 'scale_kv':
        return (_ones(1), _ones(1))
    if weight_type == 'scale_channel':
        return _ones(dim)
    if weight_type == 'scale_kv_channel':
        return (_ones(dim), _ones(dim))
    if weight_type and weight_type.startswith('scalar'):
        return float(weight_type.split('_')[-1])
    return None
| |
|
| |
|
def apply_data_weight(data, scaling, weight_type):
    """Scale ``data`` according to the weighting scheme built by init_weight_type.

    'gate' computes a per-sample gate in (0, 1) from ``scaling`` (a Linear);
    'scale', 'scale_channel' and 'scalar_<x>' multiply by ``scaling`` directly;
    anything else (including None) returns ``data`` unchanged.

    Fix: the original called ``weight_type.startswith`` unguarded, raising
    AttributeError for a falsy/None weight_type.
    """
    if not weight_type:
        return data
    if weight_type == 'gate':
        # Learned per-sample gate: mean sigmoid over the token axis of
        # (B, N, C) data, broadcast back as (B, 1, 1).
        weight = torch.mean(torch.sigmoid(scaling(data)), dim=1).view(-1, 1, 1)
    elif weight_type in ('scale', 'scale_channel') or weight_type.startswith('scalar'):
        weight = scaling
    else:
        weight = None
    if weight is not None:
        data = data * weight
    return data
| |
|
| |
|
def detach_tensors(feats):
    """Recursively detach all tensors inside nested lists/tuples/dicts.

    Tuples come back as lists (preserved from the original behavior); dict
    structure is preserved. Fix: the original ``else`` branch called
    ``.detach()`` on arbitrary objects, raising AttributeError for any
    non-tensor leaf (int, str, None, ...) — such leaves now pass through
    unchanged, matching the None-tolerance already present in the list branch.
    """
    if isinstance(feats, (list, tuple)):
        feats = [detach_tensors(feat) if feat is not None else None for feat in feats]
    elif isinstance(feats, dict):
        feats = {key: detach_tensors(val) for key, val in feats.items()}
    elif isinstance(feats, torch.Tensor):
        feats = feats.detach()
    # Any other type is returned untouched.
    return feats
| |
|
| |
|
def probe_tensors(module, feats, name):
    """Detach ``feats`` and stash the result on ``module`` as attribute ``name``."""
    setattr(module, name, detach_tensors(feats))
| |
|
| |
|
def probe_input_pre_hook(self, args):
    """Forward pre-hook: record the (detached) first positional input on the module."""
    probe_tensors(self, args[0], 'probe_input_data')
    return args
| |
|
| |
|
def probe_output_hook(self, args, result):
    """Forward hook: record the (detached) module output on the module."""
    probe_tensors(self, result, 'probe_output_data')
    return result
| |
|