from copy import deepcopy
from typing import Sequence

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.utils.checkpoint as cp
from torch.nn.modules.batchnorm import _BatchNorm

from ..compat import (
    build_norm_layer, FFN, PatchEmbed, trunc_normal_,
    BaseModule, ModuleList, PatchMerging, ShiftWindowMSA, WindowMSAV2,
    resize_pos_embed, to_2tuple, BaseBackbone,
    CheckpointLoader, load_state_dict, MultiheadAttention, get_root_logger,
)

class SwinBlockV2(BaseModule):
    """Swin Transformer V2 block. Uses post-normalization.

    Args:
        embed_dims (int): Number of input channels.
        num_heads (int): Number of attention heads.
        window_size (int): The height and width of the window. Defaults to 8.
        shift (bool): Shift the attention window or not. Defaults to False.
        extra_norm (bool): Whether to add an extra norm at the end of the main
            branch. Defaults to False.
        ffn_ratio (float): The expansion ratio of feedforward network hidden
            layer channels. Defaults to 4.
        drop_path (float): The drop path rate after attention and ffn.
            Defaults to 0.
        pad_small_map (bool): If True, pad the small feature map to the window
            size, which is commonly used in detection and segmentation. If
            False, avoid shifting window and shrink the window size to the
            size of the feature map, which is commonly used in classification.
            Defaults to False.
        attn_cfgs (dict): The extra config of Shift Window-MSA.
            Defaults to empty dict.
        ffn_cfgs (dict): The extra config of FFN. Defaults to empty dict.
        norm_cfg (dict): The config of norm layers.
            Defaults to ``dict(type='LN')``.
        with_cp (bool): Use checkpoint or not. Using checkpoint will save some
            memory while slowing down the training speed. Defaults to False.
        pretrained_window_size (int): Window size of the pretrained model.
            Defaults to 0.
        init_cfg (dict, optional): The extra config for initialization.
            Defaults to None.
    """

    def __init__(self,
                 embed_dims,
                 num_heads,
                 window_size=8,
                 shift=False,
                 extra_norm=False,
                 ffn_ratio=4.,
                 drop_path=0.,
                 pad_small_map=False,
                 attn_cfgs=dict(),
                 ffn_cfgs=dict(),
                 norm_cfg=dict(type='LN'),
                 with_cp=False,
                 pretrained_window_size=0,
                 init_cfg=None):
        super(SwinBlockV2, self).__init__(init_cfg)
        self.with_cp = with_cp
        self.extra_norm = extra_norm

        _attn_cfgs = {
            'embed_dims': embed_dims,
            'num_heads': num_heads,
            'shift_size': window_size // 2 if shift else 0,
            'window_size': window_size,
            'dropout_layer': dict(type='DropPath', drop_prob=drop_path),
            'pad_small_map': pad_small_map,
            **attn_cfgs
        }
        # Swap in the V2 window attention and pass the pretrained window size
        # through to it.
        _attn_cfgs.update(
            window_msa=WindowMSAV2,
            msa_cfg=dict(
                pretrained_window_size=to_2tuple(pretrained_window_size)))
        self.attn = ShiftWindowMSA(**_attn_cfgs)
        self.norm1 = build_norm_layer(norm_cfg, embed_dims)[1]

        _ffn_cfgs = {
            'embed_dims': embed_dims,
            'feedforward_channels': int(embed_dims * ffn_ratio),
            'num_fcs': 2,
            'ffn_drop': 0,
            'dropout_layer': dict(type='DropPath', drop_prob=drop_path),
            'act_cfg': dict(type='GELU'),
            'add_identity': False,
            **ffn_cfgs
        }
        self.ffn = FFN(**_ffn_cfgs)
        self.norm2 = build_norm_layer(norm_cfg, embed_dims)[1]

        if self.extra_norm:
            self.norm3 = build_norm_layer(norm_cfg, embed_dims)[1]

    def forward(self, x, hw_shape):

        def _inner_forward(x):
            # Post-normalization: normalize after attention / FFN, then add
            # the residual identity.
            identity = x
            x = self.attn(x, hw_shape)
            x = self.norm1(x)
            x = x + identity

            identity = x
            x = self.ffn(x)
            x = self.norm2(x)
            x = x + identity

            if self.extra_norm:
                x = self.norm3(x)

            return x

        if self.with_cp and x.requires_grad:
            x = cp.checkpoint(_inner_forward, x)
        else:
            x = _inner_forward(x)

        return x

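# A minimal usage sketch for SwinBlockV2, kept as a comment so importing this
# module stays side-effect free. The shapes are assumptions read off the
# forward signature: tokens come in as (B, H*W, C) together with the (H, W)
# grid, and the block preserves that shape.
#
#     block = SwinBlockV2(embed_dims=96, num_heads=3, window_size=8)
#     x = torch.rand(2, 64 * 64, 96)        # a 64x64 patch grid
#     out = block(x, hw_shape=(64, 64))     # -> (2, 64 * 64, 96)
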
class SwinBlockV2Sequence(BaseModule):
    """Module with successive Swin Transformer V2 blocks and a downsample
    layer.

    Args:
        embed_dims (int): Number of input channels.
        depth (int): Number of successive swin transformer blocks.
        num_heads (int): Number of attention heads.
        window_size (int): The height and width of the window. Defaults to 8.
        downsample (bool): Downsample the input of the block sequence by
            patch merging before the blocks run. Defaults to False.
        downsample_cfg (dict): The extra config of the patch merging layer.
            Defaults to empty dict.
        drop_paths (Sequence[float] | float): The drop path rate in each block.
            Defaults to 0.
        block_cfgs (Sequence[dict] | dict): The extra config of each block.
            Defaults to empty dicts.
        with_cp (bool): Use checkpoint or not. Using checkpoint will save some
            memory while slowing down the training speed. Defaults to False.
        pad_small_map (bool): If True, pad the small feature map to the window
            size, which is commonly used in detection and segmentation. If
            False, avoid shifting window and shrink the window size to the
            size of the feature map, which is commonly used in classification.
            Defaults to False.
        extra_norm_every_n_blocks (int): Add an extra norm at the end of the
            main branch every n blocks. Defaults to 0, which means no extra
            norm layer.
        pretrained_window_size (int): Window size of the pretrained model.
            Defaults to 0.
        init_cfg (dict, optional): The extra config for initialization.
            Defaults to None.
    """

    def __init__(self,
                 embed_dims,
                 depth,
                 num_heads,
                 window_size=8,
                 downsample=False,
                 downsample_cfg=dict(),
                 drop_paths=0.,
                 block_cfgs=dict(),
                 with_cp=False,
                 pad_small_map=False,
                 extra_norm_every_n_blocks=0,
                 pretrained_window_size=0,
                 init_cfg=None):
        super().__init__(init_cfg)

        if not isinstance(drop_paths, Sequence):
            drop_paths = [drop_paths] * depth

        if not isinstance(block_cfgs, Sequence):
            block_cfgs = [deepcopy(block_cfgs) for _ in range(depth)]

        if downsample:
            self.out_channels = 2 * embed_dims
            _downsample_cfg = {
                'in_channels': embed_dims,
                'out_channels': self.out_channels,
                'norm_cfg': dict(type='LN'),
                **downsample_cfg
            }
            self.downsample = PatchMerging(**_downsample_cfg)
        else:
            self.out_channels = embed_dims
            self.downsample = None

        self.blocks = ModuleList()
        for i in range(depth):
            extra_norm = (extra_norm_every_n_blocks > 0
                          and (i + 1) % extra_norm_every_n_blocks == 0)
            _block_cfg = {
                'embed_dims': self.out_channels,
                'num_heads': num_heads,
                'window_size': window_size,
                'shift': i % 2 != 0,
                'extra_norm': extra_norm,
                'drop_path': drop_paths[i],
                'with_cp': with_cp,
                'pad_small_map': pad_small_map,
                'pretrained_window_size': pretrained_window_size,
                **block_cfgs[i]
            }
            block = SwinBlockV2(**_block_cfg)
            self.blocks.append(block)

    def forward(self, x, in_shape):
        # Unlike Swin V1, the downsample layer runs before the blocks, so the
        # blocks operate on the merged (downsampled) resolution.
        if self.downsample:
            x, out_shape = self.downsample(x, in_shape)
        else:
            out_shape = in_shape

        for block in self.blocks:
            x = block(x, out_shape)

        return x, out_shape

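# A minimal sketch of one stage (comment only; the sizes are illustrative
# assumptions). With downsample=True the patch grid is halved in each
# dimension before the blocks run, and the channel width doubles:
#
#     stage = SwinBlockV2Sequence(embed_dims=96, depth=2, num_heads=3,
#                                 window_size=8, downsample=True)
#     x = torch.rand(2, 64 * 64, 96)
#     x, out_shape = stage(x, in_shape=(64, 64))
#     # out_shape == (32, 32), x.shape == (2, 32 * 32, 192)
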
class SwinTransformerV2(BaseBackbone):
    """Swin Transformer V2.

    A PyTorch implementation of: `Swin Transformer V2:
    Scaling Up Capacity and Resolution
    <https://arxiv.org/abs/2111.09883>`_

    Inspired by the official implementation:
    https://github.com/microsoft/Swin-Transformer

    Args:
        arch (str | dict): Swin Transformer architecture. If it's a string,
            choose from 'tiny', 'small', 'base', 'large', 'huge' and 'giant'.
            If it's a dict, it should have the below keys:

            - **embed_dims** (int): The dimensions of embedding.
            - **depths** (List[int]): The number of blocks in each stage.
            - **num_heads** (List[int]): The number of heads in attention
              modules of each stage.
            - **extra_norm_every_n_blocks** (int): Add extra norm at the end
              of main branch every n blocks.

            Defaults to 'tiny'.
        img_size (int | tuple): The expected input image shape. Because we
            support dynamic input shape, just set the argument to the most
            common input image shape. Defaults to 256.
        patch_size (int | tuple): The patch size in patch embedding.
            Defaults to 4.
        in_channels (int): The num of input channels. Defaults to 3.
        vocabulary_size (int): The size of the annotation vocabulary used by
            subclasses such as ``SwinTransformerV2MSL``; one extra entry is
            reserved internally. Defaults to 128.
        window_size (int | Sequence): The height and width of the window.
            Defaults to 8.
        drop_rate (float): Dropout rate after embedding. Defaults to 0.
        drop_path_rate (float): Stochastic depth rate. Defaults to 0.1.
        out_indices (Sequence[int]): Indices of the stages whose feature maps
            are returned. Defaults to ``(3, )``.
        use_abs_pos_embed (bool): If True, add absolute position embedding to
            the patch embedding. Defaults to False.
        interpolate_mode (str): Select the interpolate mode for absolute
            position embedding vector resize. Defaults to "bicubic".
        with_cp (bool): Use checkpoint or not. Using checkpoint will save some
            memory while slowing down the training speed. Defaults to False.
        frozen_stages (int): Stages to be frozen (stop grad and set eval mode).
            -1 means not freezing any parameters. Defaults to -1.
        norm_eval (bool): Whether to set norm layers to eval mode, namely,
            freeze running stats (mean and var). Note: Effect on Batch Norm
            and its variants only. Defaults to False.
        pad_small_map (bool): If True, pad the small feature map to the window
            size, which is commonly used in detection and segmentation. If
            False, avoid shifting window and shrink the window size to the
            size of the feature map, which is commonly used in classification.
            Defaults to False.
        norm_cfg (dict): Config dict for normalization layer for all output
            features. Defaults to ``dict(type='LN')``.
        stage_cfgs (Sequence[dict] | dict): Extra config dict for each stage.
            Defaults to ``dict(downsample_cfg=dict(is_post_norm=True))``.
        patch_cfg (dict): Extra config dict for patch embedding.
            Defaults to an empty dict.
        pretrained_window_sizes (tuple(int)): Pretrained window sizes of
            each stage. Defaults to [0, 0, 0, 0].
        init_cfg (dict, optional): The Config for initialization.
            Defaults to None.

    Examples:
        >>> from lib.models.backbones.swin_v2 import SwinTransformerV2
        >>> import torch
        >>> extra_config = dict(
        >>>     arch='tiny',
        >>>     stage_cfgs=dict(downsample_cfg={'kernel_size': 3,
        >>>                                     'padding': 'same'}))
        >>> self = SwinTransformerV2(**extra_config)
        >>> inputs = torch.rand(1, 3, 224, 224)
        >>> outputs = self.forward(inputs)
        >>> # ``forward`` returns one feature map per index in out_indices.
        >>> for out in outputs:
        >>>     print(out.shape)
    """
    arch_zoo = {
        **dict.fromkeys(['t', 'tiny'],
                        {'embed_dims': 96,
                         'depths': [2, 2, 6, 2],
                         'num_heads': [3, 6, 12, 24],
                         'extra_norm_every_n_blocks': 0}),
        **dict.fromkeys(['s', 'small'],
                        {'embed_dims': 96,
                         'depths': [2, 2, 18, 2],
                         'num_heads': [3, 6, 12, 24],
                         'extra_norm_every_n_blocks': 0}),
        **dict.fromkeys(['b', 'base'],
                        {'embed_dims': 128,
                         'depths': [2, 2, 18, 2],
                         'num_heads': [4, 8, 16, 32],
                         'extra_norm_every_n_blocks': 0}),
        **dict.fromkeys(['l', 'large'],
                        {'embed_dims': 192,
                         'depths': [2, 2, 18, 2],
                         'num_heads': [6, 12, 24, 48],
                         'extra_norm_every_n_blocks': 0}),
        **dict.fromkeys(['h', 'huge'],
                        {'embed_dims': 352,
                         'depths': [2, 2, 18, 2],
                         'num_heads': [8, 16, 32, 64],
                         'extra_norm_every_n_blocks': 6}),
        **dict.fromkeys(['g', 'giant'],
                        {'embed_dims': 512,
                         'depths': [2, 2, 42, 4],
                         'num_heads': [16, 32, 64, 128],
                         'extra_norm_every_n_blocks': 6}),
    }

    _version = 1
    num_extra_tokens = 0

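    # A custom architecture can also be passed as a dict instead of a preset
    # name; it must provide exactly the keys checked in ``__init__`` below
    # (the values here are illustrative):
    #
    #     arch = dict(embed_dims=64, depths=[2, 2, 6, 2],
    #                 num_heads=[2, 4, 8, 16], extra_norm_every_n_blocks=0)
    #     model = SwinTransformerV2(arch=arch)
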
    def __init__(self,
                 arch='tiny',
                 img_size=256,
                 patch_size=4,
                 in_channels=3,
                 vocabulary_size=128,
                 window_size=8,
                 drop_rate=0.,
                 drop_path_rate=0.1,
                 out_indices=(3, ),
                 use_abs_pos_embed=False,
                 interpolate_mode='bicubic',
                 with_cp=False,
                 frozen_stages=-1,
                 norm_eval=False,
                 pad_small_map=False,
                 norm_cfg=dict(type='LN'),
                 stage_cfgs=dict(downsample_cfg=dict(is_post_norm=True)),
                 patch_cfg=dict(),
                 pretrained_window_sizes=[0, 0, 0, 0],
                 init_cfg=None):
        super(SwinTransformerV2, self).__init__(init_cfg=init_cfg)

        if isinstance(arch, str):
            arch = arch.lower()
            assert arch in set(self.arch_zoo), \
                f'Arch {arch} is not in default archs {set(self.arch_zoo)}'
            self.arch_settings = self.arch_zoo[arch]
        else:
            essential_keys = {
                'embed_dims', 'depths', 'num_heads',
                'extra_norm_every_n_blocks'
            }
            assert isinstance(arch, dict) and set(arch) == essential_keys, \
                f'Custom arch needs a dict with keys {essential_keys}'
            self.arch_settings = arch

        # One extra vocabulary entry is reserved (used by subclasses that
        # embed annotation tokens, e.g. ``SwinTransformerV2MSL``).
        self.vocabulary_size = vocabulary_size + 1
        self.embed_dims = self.arch_settings['embed_dims']
        self.depths = self.arch_settings['depths']
        self.num_heads = self.arch_settings['num_heads']
        self.extra_norm_every_n_blocks = self.arch_settings[
            'extra_norm_every_n_blocks']
        self.num_layers = len(self.depths)
        self.out_indices = out_indices
        self.use_abs_pos_embed = use_abs_pos_embed
        self.interpolate_mode = interpolate_mode
        self.frozen_stages = frozen_stages

        if isinstance(window_size, int):
            self.window_sizes = [window_size for _ in range(self.num_layers)]
        elif isinstance(window_size, Sequence):
            assert len(window_size) == self.num_layers, \
                f'Length of window_sizes {len(window_size)} is not equal to '\
                f'length of stages {self.num_layers}.'
            self.window_sizes = window_size
        else:
            raise TypeError('window_size should be a Sequence or int.')

        _patch_cfg = dict(
            in_channels=in_channels,
            input_size=img_size,
            embed_dims=self.embed_dims,
            conv_type='Conv2d',
            kernel_size=patch_size,
            stride=patch_size,
            norm_cfg=dict(type='LN'),
        )
        _patch_cfg.update(patch_cfg)
        self.patch_embed = PatchEmbed(**_patch_cfg)
        self.patch_resolution = self.patch_embed.init_out_size
        self.patch_size = patch_size

        if self.use_abs_pos_embed:
            num_patches = self.patch_resolution[0] * self.patch_resolution[1]
            self.absolute_pos_embed = nn.Parameter(
                torch.zeros(1, num_patches, self.embed_dims))
            self._register_load_state_dict_pre_hook(
                self._prepare_abs_pos_embed)

        self._register_load_state_dict_pre_hook(self._delete_reinit_params)

        self.drop_after_pos = nn.Dropout(p=drop_rate)
        self.norm_eval = norm_eval

        # Stochastic depth: the drop path rate increases linearly over the
        # total depth of the network.
        total_depth = sum(self.depths)
        dpr = [
            x.item() for x in torch.linspace(0, drop_path_rate, total_depth)
        ]

        self.stages = ModuleList()
        embed_dims = [self.embed_dims]
        for i, (depth, num_heads) in enumerate(zip(self.depths,
                                                   self.num_heads)):
            if isinstance(stage_cfgs, Sequence):
                stage_cfg = stage_cfgs[i]
            else:
                stage_cfg = deepcopy(stage_cfgs)
            downsample = i > 0
            _stage_cfg = {
                'embed_dims': embed_dims[-1],
                'depth': depth,
                'num_heads': num_heads,
                'window_size': self.window_sizes[i],
                'downsample': downsample,
                'drop_paths': dpr[:depth],
                'with_cp': with_cp,
                'pad_small_map': pad_small_map,
                'extra_norm_every_n_blocks': self.extra_norm_every_n_blocks,
                'pretrained_window_size': pretrained_window_sizes[i],
                **stage_cfg
            }

            stage = SwinBlockV2Sequence(**_stage_cfg)
            self.stages.append(stage)

            dpr = dpr[depth:]
            embed_dims.append(stage.out_channels)

        for i in out_indices:
            if norm_cfg is not None:
                norm_layer = build_norm_layer(norm_cfg, embed_dims[i + 1])[1]
            else:
                norm_layer = nn.Identity()

            self.add_module(f'norm{i}', norm_layer)

    def init_weights(self):
        if (isinstance(self.init_cfg, dict)
                and self.init_cfg['type'] == 'Pretrained'):
            logger = get_root_logger()
            checkpoint = CheckpointLoader.load_checkpoint(
                self.init_cfg['checkpoint'], logger=logger, map_location='cpu')

            if 'state_dict' in checkpoint:
                state_dict = checkpoint['state_dict']
            else:
                state_dict = checkpoint

            load_state_dict(self, state_dict, strict=False, logger=logger)
            return
        else:
            super(SwinTransformerV2, self).init_weights()
            if self.use_abs_pos_embed:
                trunc_normal_(self.absolute_pos_embed, std=0.02)

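    # A sketch of loading pretrained weights through ``init_cfg`` (the
    # checkpoint path is a placeholder). Loading is non-strict, and the
    # load-state-dict pre-hooks below resize the absolute position embedding
    # and drop the relative-position buffers that get re-initialized anyway:
    #
    #     model = SwinTransformerV2(
    #         arch='base',
    #         init_cfg=dict(type='Pretrained', checkpoint='path/to/ckpt.pth'))
    #     model.init_weights()
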
    def forward(self, x):
        x, hw_shape = self.patch_embed(x)

        if self.use_abs_pos_embed:
            x = x + resize_pos_embed(
                self.absolute_pos_embed, self.patch_resolution, hw_shape,
                self.interpolate_mode, self.num_extra_tokens)
        x = self.drop_after_pos(x)

        outs = []
        for i, stage in enumerate(self.stages):
            x, hw_shape = stage(x, hw_shape)
            if i in self.out_indices:
                norm_layer = getattr(self, f'norm{i}')
                out = norm_layer(x)
                # (B, H*W, C) -> (B, C, H, W)
                out = out.view(-1, *hw_shape, stage.out_channels)
                out = out.permute(0, 3, 1, 2).contiguous()
                outs.append(out)

        return outs

    def _freeze_stages(self):
        if self.frozen_stages >= 0:
            self.patch_embed.eval()
            for param in self.patch_embed.parameters():
                param.requires_grad = False

            for i in range(0, self.frozen_stages + 1):
                m = self.stages[i]
                m.eval()
                for param in m.parameters():
                    param.requires_grad = False
            for i in self.out_indices:
                if i <= self.frozen_stages:
                    for param in getattr(self, f'norm{i}').parameters():
                        param.requires_grad = False

    def train(self, mode=True):
        super(SwinTransformerV2, self).train(mode)
        self._freeze_stages()
        if mode and self.norm_eval:
            for m in self.modules():
                # Only BatchNorm-like layers are switched to eval here.
                if isinstance(m, _BatchNorm):
                    m.eval()

    def _prepare_abs_pos_embed(self, state_dict, prefix, *args, **kwargs):
        name = prefix + 'absolute_pos_embed'
        if name not in state_dict.keys():
            return

        ckpt_pos_embed_shape = state_dict[name].shape
        if self.absolute_pos_embed.shape != ckpt_pos_embed_shape:
            logger = get_root_logger()
            logger.info(
                'Resize the absolute_pos_embed shape from '
                f'{ckpt_pos_embed_shape} to {self.absolute_pos_embed.shape}.')

            ckpt_pos_embed_shape = to_2tuple(
                int(np.sqrt(ckpt_pos_embed_shape[1] - self.num_extra_tokens)))
            pos_embed_shape = self.patch_embed.init_out_size

            state_dict[name] = resize_pos_embed(state_dict[name],
                                                ckpt_pos_embed_shape,
                                                pos_embed_shape,
                                                self.interpolate_mode,
                                                self.num_extra_tokens)

    def _delete_reinit_params(self, state_dict, prefix, *args, **kwargs):
        # Delete relative_position_index since it is always re-initialized.
        relative_position_index_keys = [
            k for k in state_dict.keys() if 'relative_position_index' in k
        ]
        for k in relative_position_index_keys:
            del state_dict[k]

        # Delete relative_coords_table since it is always re-initialized.
        relative_coords_table_keys = [
            k for k in state_dict.keys() if 'relative_coords_table' in k
        ]
        for k in relative_coords_table_keys:
            del state_dict[k]

class Proj_MHSA(nn.Module):
    """Multi-head self-attention in a projected (smaller) space.

    The input is linearly projected from ``embed_dims`` to ``proj_dims``,
    passed through multi-head self-attention, and projected back to
    ``embed_dims``.
    """

    def __init__(self,
                 embed_dims,
                 proj_dims,
                 num_heads=16,
                 batch_first=True,
                 bias=True):
        super().__init__()
        self.proj_in = nn.Linear(in_features=embed_dims,
                                 out_features=proj_dims)
        self.attn = MultiheadAttention(
            embed_dims=proj_dims,
            num_heads=num_heads,
            batch_first=batch_first,
            bias=bias)
        self.proj_out = nn.Linear(in_features=proj_dims,
                                  out_features=embed_dims)

    def forward(self, x):
        x = self.proj_in(x)
        # Self-attention: query, key and value are all ``x``.
        x = self.attn(x, x, x)
        x = self.proj_out(x)
        return x

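# A minimal shape sketch for Proj_MHSA (comment only; the dims are
# illustrative assumptions). Attention runs in the smaller projected space,
# and the output comes back at the input width, so it can serve as a residual
# branch exactly as SwinTransformerV2MSL uses it below:
#
#     attn = Proj_MHSA(embed_dims=1408, proj_dims=1024, num_heads=16)
#     x = torch.rand(2, 64, 1408)
#     x = x + attn(x)    # shape unchanged: (2, 64, 1408)
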
class SwinTransformerV2MSL(SwinTransformerV2):
    """Swin Transformer V2 variant with an extra annotation-token stream.

    Besides the image, the forward pass takes an annotation map whose labels
    are embedded through a learnable vocabulary, optionally masked with a
    learnable mask token, and merged with the image tokens either immediately
    (``merge_stage=0``) or after the stage indexed by ``merge_stage - 1``.
    """

    def __init__(self, **kwargs):
        self.use_attn = kwargs.pop('use_attn', False)
        self.merge_stage = kwargs.pop('merge_stage', 0)
        self.with_cls_pos = kwargs.pop('with_cls_pos', False)
        super().__init__(**kwargs)

        self.mask_token = nn.Parameter(torch.zeros(1, 1, self.embed_dims))
        # One learnable embedding per annotation label, plus learnable weights
        # over the pixel positions inside a patch.
        self.vocabulary_token = nn.Parameter(
            torch.zeros(self.vocabulary_size, self.embed_dims))
        self.vocabulary_weight = nn.Parameter(
            torch.zeros(1, self.patch_size * self.patch_size))
        trunc_normal_(self.mask_token, mean=0., std=.02)
        trunc_normal_(self.vocabulary_token, mean=0., std=.02)

        if self.use_attn:
            # The hard-coded widths match the stage output channels of the
            # 'huge' arch (352, 704 and 1408 for stages 0-2).
            self.attn1 = Proj_MHSA(embed_dims=352, proj_dims=256,
                                   num_heads=16, batch_first=True, bias=True)
            self.attn2 = Proj_MHSA(embed_dims=704, proj_dims=512,
                                   num_heads=16, batch_first=True, bias=True)
            self.attn3 = Proj_MHSA(embed_dims=1408, proj_dims=1024,
                                   num_heads=16, batch_first=True, bias=True)
            self.attention_blocks = [self.attn1, self.attn2, self.attn3]
            self.norm_attn = build_norm_layer(dict(type='LN'), 1408)[1]

    def create_ann_token(self, anno_img):
        """Embed an integer annotation map into patch-level tokens."""
        B, H, W = anno_img.shape
        # Look up one vocabulary embedding per pixel label.
        ann_token = torch.index_select(
            self.vocabulary_token, 0,
            anno_img.reshape(-1)).reshape(B, H, W, -1)
        assert H % self.patch_size == 0 and W % self.patch_size == 0
        nph, npw = H // self.patch_size, W // self.patch_size
        # Learnable convex weights over the positions inside each patch,
        # scaled so that the average pooling below amounts to a weighted sum.
        weight = F.softmax(self.vocabulary_weight,
                           dim=1) * self.patch_size * self.patch_size
        weight = weight.reshape(1, 1, self.patch_size, 1, self.patch_size) \
            .repeat(1, nph, 1, npw, 1).reshape(1, H, W, 1)
        ann_token = ann_token * weight
        # (B, H, W, C) -> (B, C, H, W), then pool each patch to one token.
        ann_token = F.avg_pool2d(
            torch.einsum('BHWC->BCHW', ann_token),
            self.patch_size, self.patch_size)
        ann_token = torch.einsum('BCHW->BHWC', ann_token) \
            .reshape(B, nph * npw, self.embed_dims)
        return ann_token

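    # Shape sketch for create_ann_token (comment only; the sizes are
    # illustrative): with patch_size=4, a (B, 256, 256) integer annotation
    # map becomes (B, 64 * 64, embed_dims) -- one softmax-weighted token per
    # 4x4 patch, matching the patch_embed output so the two streams can be
    # averaged or concatenated in ``forward`` below.
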
    def forward(self, hr_img, anno_img, mask=None):
        x, hw_shape = self.patch_embed(hr_img)
        y = self.create_ann_token(anno_img)
        assert x.shape == y.shape
        B, L, C = y.shape
        if mask is not None:
            # Replace masked annotation tokens with the learnable mask token.
            mask_tokens = self.mask_token.expand(B, L, -1)
            w = mask.flatten(1).unsqueeze(-1).type_as(mask_tokens)
            y = y * (1. - w) + mask_tokens * w

        if self.merge_stage == 0:
            # Merge immediately by averaging the two token streams.
            x = (x + y) * 0.5
        else:
            # Keep both streams, laid out side by side along the width.
            x = x.reshape(B, *hw_shape, C)
            y = y.reshape(B, *hw_shape, C)
            x = torch.cat((x, y), dim=2)
            hw_shape = (hw_shape[0], hw_shape[1] * 2)
            x = x.reshape(B, -1, C)
        if self.use_abs_pos_embed:
            x = x + resize_pos_embed(
                self.absolute_pos_embed, self.patch_resolution, hw_shape,
                self.interpolate_mode, self.num_extra_tokens)
        if self.with_cls_pos:
            # Add the absolute position embedding to each of the two
            # side-by-side streams separately (this reuses
            # ``absolute_pos_embed`` and hence requires ``use_abs_pos_embed``).
            hw_shape_half = [hw_shape[0], hw_shape[1] // 2]
            x = x.reshape(B, *hw_shape, C)
            x1 = x[:, :, :x.shape[2] // 2, :].reshape(B, -1, C)
            x2 = x[:, :, x.shape[2] // 2:, :].reshape(B, -1, C)
            x1 = x1 + resize_pos_embed(
                self.absolute_pos_embed, self.patch_resolution, hw_shape_half,
                self.interpolate_mode, self.num_extra_tokens)
            x2 = x2 + resize_pos_embed(
                self.absolute_pos_embed, self.patch_resolution, hw_shape_half,
                self.interpolate_mode, self.num_extra_tokens)
            x1 = x1.reshape(B, *hw_shape_half, C)
            x2 = x2.reshape(B, *hw_shape_half, C)
            x = torch.cat((x1, x2), dim=2).reshape(B, -1, C)
        x = self.drop_after_pos(x)
        outs = []
        merge_idx = self.merge_stage - 1
        for i, stage in enumerate(self.stages):
            x, hw_shape = stage(x, hw_shape)
            if i == merge_idx:
                # Fuse the two streams by averaging the side-by-side halves.
                x = x.reshape(x.shape[0], *hw_shape, x.shape[-1])
                x = (x[:, :, :x.shape[2] // 2] +
                     x[:, :, x.shape[2] // 2:]) * 0.5
                x = x.reshape(x.shape[0], -1, x.shape[-1])
                hw_shape = (hw_shape[0], hw_shape[1] // 2)
            if self.use_attn:
                if i <= len(self.attention_blocks) - 1:
                    x = x + self.attention_blocks[i](x)
                if i == len(self.attention_blocks) - 1:
                    x = self.norm_attn(x)
            if i in self.out_indices:
                norm_layer = getattr(self, f'norm{i}')
                out = norm_layer(x)
                out = out.view(-1, *hw_shape, stage.out_channels)
                out = out.permute(0, 3, 1, 2).contiguous()
                outs.append(out)
        return outs
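
# A minimal end-to-end sketch for SwinTransformerV2MSL (comment only; the
# arch and sizes are assumptions -- ``use_attn=True`` expects the 'huge'
# arch, whose stage widths match the hard-coded Proj_MHSA dims above):
#
#     model = SwinTransformerV2MSL(arch='huge', img_size=256, window_size=8,
#                                  merge_stage=2, use_attn=True)
#     hr_img = torch.rand(2, 3, 256, 256)
#     anno_img = torch.randint(0, 128, (2, 256, 256))
#     mask = torch.randint(0, 2, (2, 64, 64))   # one flag per 4x4 patch
#     outs = model(hr_img, anno_img, mask)      # list of (B, C, H, W) maps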