| | |
| | import torch.nn as nn |
| | from mmcv.cnn import build_norm_layer |
| |
|
| | from mmseg.registry import MODELS |
| |
|
| |
|
@MODELS.register_module()
class Feature2Pyramid(nn.Module):
    """Feature2Pyramid.

    A neck structure that connects a ViT backbone and decoder heads. Each
    input feature map is passed through its own resampling op so that the
    outputs form a feature pyramid.

    Args:
        embed_dim (int): Embedding dimension. Used as both the input and
            output channel count of the transposed convolutions and as the
            feature count of the norm layer.
        rescales (list[float]): Different sampling multiples were
            used to obtain pyramid features. Supported values are
            4, 2, 1, 0.5 and 0.25. Default: [4, 2, 1, 0.5].
        norm_cfg (dict): Config dict for normalization layer.
            Default: dict(type='SyncBN', requires_grad=True).

    Raises:
        KeyError: If ``rescales`` contains an unsupported factor.
    """

    # Attribute that stores the op for each supported factor. The attribute
    # names are also the submodule names, so they must stay stable for
    # existing checkpoints to keep loading.
    _SCALE_TO_ATTR = {
        4: 'upsample_4x',
        2: 'upsample_2x',
        1: 'identity',
        0.5: 'downsample_2x',
        0.25: 'downsample_4x',
    }

    def __init__(self,
                 embed_dim,
                 rescales=[4, 2, 1, 0.5],
                 norm_cfg=dict(type='SyncBN', requires_grad=True)):
        super().__init__()
        # Copy so that neither the shared default list nor the caller's
        # list is aliased by this instance.
        self.rescales = list(rescales)
        # Kept as an explicit ``None`` when 4x upsampling is not requested,
        # so the attribute always exists.
        self.upsample_4x = None
        for k in self.rescales:
            if k == 4:
                # Two stride-2 transposed convs give the 4x upsampling.
                self.upsample_4x = nn.Sequential(
                    nn.ConvTranspose2d(
                        embed_dim, embed_dim, kernel_size=2, stride=2),
                    build_norm_layer(norm_cfg, embed_dim)[1],
                    nn.GELU(),
                    nn.ConvTranspose2d(
                        embed_dim, embed_dim, kernel_size=2, stride=2),
                )
            elif k == 2:
                self.upsample_2x = nn.Sequential(
                    nn.ConvTranspose2d(
                        embed_dim, embed_dim, kernel_size=2, stride=2))
            elif k == 1:
                self.identity = nn.Identity()
            elif k == 0.5:
                self.downsample_2x = nn.MaxPool2d(kernel_size=2, stride=2)
            elif k == 0.25:
                self.downsample_4x = nn.MaxPool2d(kernel_size=4, stride=4)
            else:
                raise KeyError(f'invalid {k} for feature2pyramid')

    def forward(self, inputs):
        """Rescale every input feature map with its matching op.

        The ops are applied in descending order of scale factor (largest
        upsampling first), regardless of the order in which the factors
        appear in ``rescales``. For the two classic configurations
        ``{4, 2, 1, 0.5}`` and ``{2, 1, 0.5, 0.25}`` this is exactly the
        fixed ordering used previously; other subsets of the supported
        factors now work as well instead of failing on a missing attribute.

        Args:
            inputs (list | tuple): One feature map per entry of
                ``rescales``.

        Returns:
            tuple: The rescaled feature maps, in the same order as
            ``inputs``.

        Raises:
            AssertionError: If the number of inputs differs from the number
                of rescale factors.
        """
        # Raised explicitly (rather than via ``assert``) so that the check
        # still runs under ``python -O``.
        if len(inputs) != len(self.rescales):
            raise AssertionError(
                f'expected {len(self.rescales)} inputs, got {len(inputs)}')
        ordered_scales = sorted(self.rescales, reverse=True)
        ops = [getattr(self, self._SCALE_TO_ATTR[k]) for k in ordered_scales]
        return tuple(op(feat) for op, feat in zip(ops, inputs))
| |
|