|
|
|
|
|
import math
|
|
|
import torch
|
|
|
import torch.nn as nn
|
|
|
import torch.utils.checkpoint as cp
|
|
|
from tomesd.merge import bipartite_soft_matching_random2d
|
|
|
|
|
|
from ...utils import PatchEmbed
|
|
|
from ...utils import nchw_to_nlc, nlc_to_nchw
|
|
|
from ...utils import MODELS
|
|
|
from ...utils import Conv2d, build_activation_layer, build_norm_layer, build_dropout
|
|
|
from ..base_module import BaseModule, MultiheadAttention, ModuleList, Sequential
|
|
|
from ..weight_init import (constant_init, normal_init,
|
|
|
trunc_normal_init)
|
|
|
|
|
|
|
|
|
class MixFFN(BaseModule):
    """An implementation of MixFFN of Segformer.

    The differences between MixFFN & FFN:
        1. Use 1X1 Conv to replace Linear layer.
        2. Introduce 3X3 Conv to encode positional information.

    The layer stack, applied on the NCHW view of the tokens, is
    ``fc1 (1x1) -> depth-wise 3x3 conv -> activation -> dropout ->
    fc2 (1x1) -> dropout``, followed by a residual addition.

    Args:
        embed_dims (int): The feature dimension. Same as
            `MultiheadAttention`.
        feedforward_channels (int): The hidden dimension of FFNs.
        act_cfg (dict, optional): The activation config for FFNs.
            Default: dict(type='GELU').
        ffn_drop (float, optional): Probability of an element to be
            zeroed in FFN. Default 0.0.
        dropout_layer (obj:`ConfigDict`, optional): The dropout_layer used
            when adding the shortcut. A falsy value disables it (an
            ``nn.Identity`` is used instead). Default: None.
        init_cfg (obj:`mmcv.ConfigDict`): The Config for initialization.
            Default: None.
    """

    def __init__(self,
                 embed_dims,
                 feedforward_channels,
                 act_cfg=dict(type='GELU'),
                 ffn_drop=0.,
                 dropout_layer=None,
                 init_cfg=None):
        super().__init__(init_cfg)

        self.embed_dims = embed_dims
        self.feedforward_channels = feedforward_channels
        self.act_cfg = act_cfg
        self.activate = build_activation_layer(act_cfg)

        in_channels = embed_dims
        fc1 = Conv2d(
            in_channels=in_channels,
            out_channels=feedforward_channels,
            kernel_size=1,
            stride=1,
            bias=True)
        # Depth-wise (groups == channels) 3x3 conv with "same" padding: this
        # is the positional-information conv that distinguishes MixFFN.
        pe_conv = Conv2d(
            in_channels=feedforward_channels,
            out_channels=feedforward_channels,
            kernel_size=3,
            stride=1,
            padding=(3 - 1) // 2,
            bias=True,
            groups=feedforward_channels)
        fc2 = Conv2d(
            in_channels=feedforward_channels,
            out_channels=in_channels,
            kernel_size=1,
            stride=1,
            bias=True)
        # nn.Dropout holds no parameters, so applying the same instance
        # twice (after the activation and after fc2) is harmless.
        drop = nn.Dropout(ffn_drop)
        layers = [fc1, pe_conv, self.activate, drop, fc2, drop]
        self.layers = Sequential(*layers)
        self.dropout_layer = build_dropout(
            dropout_layer) if dropout_layer else torch.nn.Identity()

    def forward(self, x, hw_shape, identity=None):
        """Run the FFN on token sequence ``x`` and add the residual.

        Args:
            x (Tensor): Tokens; must be convertible with
                ``nlc_to_nchw(x, hw_shape)`` (presumably (B, H*W, C) —
                confirm against ``nlc_to_nchw``).
            hw_shape (Sequence[int]): ``(H, W)`` spatial shape of ``x``.
            identity (Tensor, optional): Tensor added as the shortcut.
                Defaults to ``x``.

        Returns:
            Tensor: ``identity + dropout_layer(ffn(x))`` in ``x``'s layout.
        """
        out = nlc_to_nchw(x, hw_shape)
        out = self.layers(out)
        out = nchw_to_nlc(out)
        if identity is None:
            identity = x
        return identity + self.dropout_layer(out)
|
|
|
|
|
|
|
|
|
class EfficientMultiheadAttention(MultiheadAttention):
    """An implementation of Efficient Multi-head Attention of Segformer.

    This module is modified from MultiheadAttention which is a module from
    mmcv.cnn.bricks.transformer. On top of Segformer's spatial reduction of
    keys/values it supports optional token reduction of queries and/or
    keys/values, configured through ``tome_cfg``.

    Args:
        embed_dims (int): The embedding dimension.
        num_heads (int): Parallel attention heads.
        attn_drop (float): A Dropout layer on attn_output_weights.
            Default: 0.0.
        proj_drop (float): A Dropout layer after `nn.MultiheadAttention`.
            Default: 0.0.
        dropout_layer (obj:`ConfigDict`): The dropout_layer used
            when adding the shortcut. Default: None.
        init_cfg (obj:`mmcv.ConfigDict`): The Config for initialization.
            Default: None.
        batch_first (bool): Key, Query and Value are shape of
            (batch, n, embed_dim)
            or (n, batch, embed_dim). Default: True.
        qkv_bias (bool): enable bias for qkv if True. Default: False.
        tome_cfg (dict, optional): Token-reduction settings. Keys read:

            - ``q_mode``: query reduction. ``'n1d'`` = 1-D average pooling
              along the token axis, ``'n2d'`` = 2-D average pooling on the
              (H, W) grid, ``'bsm'`` = tomesd bipartite soft matching,
              anything else / missing = no reduction. The attention output
              is brought back to the full token count afterwards
              (interpolation for pooling, ``unmerge`` for ``'bsm'``).
            - ``q_s``: pooling kernel and stride for ``'n1d'`` (int) or
              ``'n2d'`` (int or ``(sh, sw)``).
            - ``q_r``, ``q_sx``, ``q_sy``: for ``'bsm'``, the number of
              merged tokens is ``int(num_tokens * q_r)``; ``q_sx``/``q_sy``
              are passed as ``sx``/``sy``.
            - ``kv_mode``: key/value reduction applied after the spatial
              reduction: ``'n2d'``, ``'bsm'`` or no reduction.
            - ``kv_s``; ``kv_r``, ``kv_sx``, ``kv_sy``: as above, for keys
              and values.

            Default: None (no token reduction).
        norm_cfg (dict): Config dict for normalization layer.
            Default: dict(type='LN').
        sr_ratio (int): The ratio of spatial reduction of Efficient Multi-head
            Attention of Segformer. Default: 1.
    """

    def __init__(self,
                 embed_dims,
                 num_heads,
                 attn_drop=0.,
                 proj_drop=0.,
                 dropout_layer=None,
                 init_cfg=None,
                 batch_first=True,
                 qkv_bias=False,
                 tome_cfg=None,
                 norm_cfg=dict(type='LN'),
                 sr_ratio=1):
        super().__init__(
            embed_dims,
            num_heads,
            attn_drop,
            proj_drop,
            dropout_layer=dropout_layer,
            init_cfg=init_cfg,
            batch_first=batch_first,
            bias=qkv_bias)

        # Avoid a shared mutable default: every instance stores this dict.
        if tome_cfg is None:
            tome_cfg = dict()
        self.q_mode = tome_cfg.get('q_mode')
        self.kv_mode = tome_cfg.get('kv_mode')
        self.tome_cfg = tome_cfg

        self.sr_ratio = sr_ratio
        if sr_ratio > 1:
            # Non-overlapping conv that shrinks the key/value grid by
            # ``sr_ratio`` in each spatial dimension.
            self.sr = Conv2d(
                in_channels=embed_dims,
                out_channels=embed_dims,
                kernel_size=sr_ratio,
                stride=sr_ratio)
            # The ret[0] of build_norm_layer is norm name.
            self.norm = build_norm_layer(norm_cfg, embed_dims)[1]

    def _reduce_key_value(self, x, hw_shape):
        """Build the key/value tokens: spatial reduction, then ``kv_mode``.

        Args:
            x (Tensor): Input tokens; must be convertible with
                ``nlc_to_nchw(x, hw_shape)``.
            hw_shape (Sequence[int]): ``(H, W)`` of ``x``.

        Returns:
            Tensor: Key/value tokens in the same layout as ``x``.
        """
        if self.sr_ratio > 1:
            x_kv = self.sr(nlc_to_nchw(x, hw_shape))
            # Take the reduced grid from the conv output itself so it always
            # matches what ``self.sr`` actually produced.
            kv_hw_shape = tuple(x_kv.shape[-2:])
            x_kv = self.norm(nchw_to_nlc(x_kv))
        else:
            x_kv = x
            kv_hw_shape = hw_shape

        if self.kv_mode == 'n2d':
            kv_s = self.tome_cfg['kv_s']
            x_kv = nlc_to_nchw(x_kv, kv_hw_shape)
            x_kv = torch.nn.functional.avg_pool2d(
                x_kv, kernel_size=kv_s, stride=kv_s, ceil_mode=True)
            x_kv = nchw_to_nlc(x_kv)
        elif self.kv_mode == 'bsm':
            h_kv, w_kv = kv_hw_shape
            # Keys/values are never unmerged, so only ``merge`` is needed.
            merge, _ = bipartite_soft_matching_random2d(
                metric=x_kv,
                w=w_kv,
                h=h_kv,
                r=int(x_kv.size()[1] * self.tome_cfg['kv_r']),
                sx=self.tome_cfg['kv_sx'],
                sy=self.tome_cfg['kv_sy'],
                no_rand=True)
            x_kv = merge(x_kv)
        return x_kv

    def _reduce_query(self, x, hw_shape):
        """Reduce the query tokens according to ``self.q_mode``.

        Args:
            x (Tensor): Input tokens; must be convertible with
                ``nlc_to_nchw(x, hw_shape)``.
            hw_shape (Sequence[int]): ``(H, W)`` of ``x``.

        Returns:
            tuple: ``(x_q, restore)`` where ``restore`` maps an attention
            output computed for ``x_q`` back to ``x``'s token count.
        """
        if self.q_mode == 'n1d':
            q_s = self.tome_cfg['q_s']
            num_tokens = x.size()[-2]
            x_q = torch.nn.functional.avg_pool1d(
                x.transpose(-2, -1),
                kernel_size=q_s,
                stride=q_s,
                ceil_mode=True).transpose(-2, -1)

            def _restore_1d(out):
                out = torch.nn.functional.interpolate(
                    out.transpose(-2, -1), size=num_tokens)
                return out.transpose(-2, -1)

            return x_q, _restore_1d

        if self.q_mode == 'n2d':
            # ``q_s`` may be an int or an (sh, sw) pair; avg_pool2d accepts
            # both, and the reduced grid is read back from its output.
            q_s = self.tome_cfg['q_s']
            x_q = torch.nn.functional.avg_pool2d(
                nlc_to_nchw(x, hw_shape),
                kernel_size=q_s,
                stride=q_s,
                ceil_mode=True)
            reduced_hw = tuple(x_q.shape[-2:])
            x_q = nchw_to_nlc(x_q)

            def _restore_2d(out):
                out = nlc_to_nchw(out, reduced_hw)
                out = torch.nn.functional.interpolate(out, size=hw_shape)
                return nchw_to_nlc(out)

            return x_q, _restore_2d

        if self.q_mode == 'bsm':
            merge, unmerge = bipartite_soft_matching_random2d(
                metric=x,
                w=hw_shape[1],
                h=hw_shape[0],
                r=int(x.size()[1] * self.tome_cfg['q_r']),
                sx=self.tome_cfg['q_sx'],
                sy=self.tome_cfg['q_sy'],
                no_rand=True)
            return merge(x), unmerge

        return x, lambda out: out

    def forward(self, x, hw_shape, identity=None):
        """Attend from (possibly reduced) queries to reduced keys/values.

        Args:
            x (Tensor): Input tokens; must be convertible with
                ``nlc_to_nchw(x, hw_shape)``.
            hw_shape (Sequence[int]): ``(H, W)`` of ``x``.
            identity (Tensor, optional): Shortcut added to the output.
                Defaults to ``x``.

        Returns:
            Tensor: ``identity + dropout_layer(proj_drop(attn_out))`` with
            ``x``'s token count.
        """
        if identity is None:
            identity = x

        x_kv = self._reduce_key_value(x, hw_shape)
        x_q, restore_query = self._reduce_query(x, hw_shape)

        # ``self.attn`` is called with sequence-first tensors, so batch-first
        # inputs are transposed before the call and the output transposed
        # back (assumes the base class builds it without batch_first —
        # confirm against ``MultiheadAttention``).
        if self.batch_first:
            x_q = x_q.transpose(0, 1)
            x_kv = x_kv.transpose(0, 1)
        out = self.attn(query=x_q, key=x_kv, value=x_kv)[0]
        if self.batch_first:
            out = out.transpose(0, 1)

        out = restore_query(out)
        return identity + self.dropout_layer(self.proj_drop(out))
|
|
|
|
|
|
|
|
|
class TransformerEncoderLayer(BaseModule):
    """Implements one encoder layer in Segformer.

    Pre-norm layout: ``x = attn(norm1(x)) + x`` followed by
    ``x = ffn(norm2(x)) + x``; the residual additions happen inside the
    attention and FFN modules via their ``identity`` argument.

    Args:
        embed_dims (int): The feature dimension.
        num_heads (int): Parallel attention heads.
        feedforward_channels (int): The hidden dimension for FFNs.
        drop_rate (float): Dropout probability used both after the attention
            projection (``proj_drop``) and inside the FFN (``ffn_drop``).
            Default 0.0.
        attn_drop_rate (float): The drop out rate for attention layer.
            Default 0.0.
        drop_path_rate (float): stochastic depth rate. Default 0.0.
        qkv_bias (bool): enable bias for qkv if True.
            Default: True.
        tome_cfg (dict, optional): Token-reduction settings forwarded to
            ``EfficientMultiheadAttention``. Default: None (no reduction).
        act_cfg (dict): The activation config for FFNs.
            Default: dict(type='GELU').
        norm_cfg (dict): Config dict for normalization layer.
            Default: dict(type='LN').
        batch_first (bool): Key, Query and Value are shape of
            (batch, n, embed_dim)
            or (n, batch, embed_dim). Default: True.
        sr_ratio (int): The ratio of spatial reduction of Efficient Multi-head
            Attention of Segformer. Default: 1.
        with_cp (bool): Use checkpoint or not. Using checkpoint will save
            some memory while slowing down the training speed. Default: False.
        init_cfg (dict, optional): Initialization config dict.
            Default: None.
    """

    def __init__(self,
                 embed_dims,
                 num_heads,
                 feedforward_channels,
                 drop_rate=0.,
                 attn_drop_rate=0.,
                 drop_path_rate=0.,
                 qkv_bias=True,
                 tome_cfg=None,
                 act_cfg=dict(type='GELU'),
                 norm_cfg=dict(type='LN'),
                 batch_first=True,
                 sr_ratio=1,
                 with_cp=False,
                 init_cfg=None):
        super().__init__(init_cfg=init_cfg)

        # Avoid a shared mutable default; the attention module expects a dict.
        if tome_cfg is None:
            tome_cfg = dict()

        # The ret[0] of build_norm_layer is norm name.
        self.norm1 = build_norm_layer(norm_cfg, embed_dims)[1]

        self.attn = EfficientMultiheadAttention(
            embed_dims=embed_dims,
            num_heads=num_heads,
            attn_drop=attn_drop_rate,
            proj_drop=drop_rate,
            dropout_layer=dict(type='DropPath', drop_prob=drop_path_rate),
            batch_first=batch_first,
            qkv_bias=qkv_bias,
            tome_cfg=tome_cfg,
            norm_cfg=norm_cfg,
            sr_ratio=sr_ratio)

        # The ret[0] of build_norm_layer is norm name.
        self.norm2 = build_norm_layer(norm_cfg, embed_dims)[1]

        self.ffn = MixFFN(
            embed_dims=embed_dims,
            feedforward_channels=feedforward_channels,
            ffn_drop=drop_rate,
            dropout_layer=dict(type='DropPath', drop_prob=drop_path_rate),
            act_cfg=act_cfg)

        self.with_cp = with_cp

    def forward(self, x, hw_shape):
        """Apply attention and FFN sub-layers to tokens ``x``.

        Args:
            x (Tensor): Input tokens for a feature map of size ``hw_shape``.
            hw_shape (Sequence[int]): ``(H, W)`` of the token grid.

        Returns:
            Tensor: Output tokens, same shape as ``x``.
        """

        def _inner_forward(x):
            x = self.attn(self.norm1(x), hw_shape, identity=x)
            x = self.ffn(self.norm2(x), hw_shape, identity=x)
            return x

        # Gradient checkpointing only helps when a backward pass will run.
        if self.with_cp and x.requires_grad:
            x = cp.checkpoint(_inner_forward, x)
        else:
            x = _inner_forward(x)
        return x
|
|
|
|
|
|
|
|
|
@MODELS.register_module()
class MixVisionTransformer(BaseModule):
    """The backbone of Segformer.

    This backbone is the implementation of `SegFormer: Simple and
    Efficient Design for Semantic Segmentation with
    Transformers <https://arxiv.org/abs/2105.15203>`_.

    Args:
        in_channels (int): Number of input channels. Default: 3.
        embed_dims (int): Base embedding dimension; stage ``i`` uses
            ``embed_dims * num_heads[i]`` channels. Default: 64.
        num_stages (int): The num of stages. Default: 4.
        num_layers (Sequence[int]): The layer number of each transformer encode
            layer. Default: [3, 4, 6, 3].
        num_heads (Sequence[int]): The attention heads of each transformer
            encode layer. Default: [1, 2, 4, 8].
        patch_sizes (Sequence[int]): The patch_size of each overlapped patch
            embedding. Default: [7, 3, 3, 3].
        strides (Sequence[int]): The stride of each overlapped patch embedding.
            Default: [4, 2, 2, 2].
        sr_ratios (Sequence[int]): The spatial reduction rate of each
            transformer encode layer. Default: [8, 4, 2, 1].
        out_indices (Sequence[int] | int): Output from which stages.
            Default: (0, 1, 2, 3).
        mlp_ratio (int): ratio of mlp hidden dim to embedding dim.
            Default: 4.
        qkv_bias (bool): Enable bias for qkv if True. Default: True.
        drop_rate (float): Probability of an element to be zeroed.
            Default 0.0
        attn_drop_rate (float): The drop out rate for attention layer.
            Default 0.0
        drop_path_rate (float): stochastic depth rate. Default 0.0
        tome_cfg (Sequence[dict], optional): Per-stage token-reduction
            settings passed to every encoder layer of that stage. Default:
            None, meaning one empty dict (no reduction) per stage.
        act_cfg (dict): The activation config for FFNs.
            Default: dict(type='GELU').
        norm_cfg (dict): Config dict for normalization layer.
            Default: dict(type='LN', eps=1e-6)
        init_cfg (dict or list[dict], optional): Initialization config dict.
            Default: None.
        with_cp (bool): Use checkpoint or not. Using checkpoint will save
            some memory while slowing down the training speed. Default: False.
        down_sample (bool): If True, the input image is resized by a factor
            of 0.5 in each spatial dimension before the first stage.
            Default: False.
    """

    def __init__(self,
                 in_channels=3,
                 embed_dims=64,
                 num_stages=4,
                 num_layers=[3, 4, 6, 3],
                 num_heads=[1, 2, 4, 8],
                 patch_sizes=[7, 3, 3, 3],
                 strides=[4, 2, 2, 2],
                 sr_ratios=[8, 4, 2, 1],
                 out_indices=(0, 1, 2, 3),
                 mlp_ratio=4,
                 qkv_bias=True,
                 drop_rate=0.,
                 attn_drop_rate=0.,
                 drop_path_rate=0.,
                 tome_cfg=None,
                 act_cfg=dict(type='GELU'),
                 norm_cfg=dict(type='LN', eps=1e-6),
                 init_cfg=None,
                 with_cp=False,
                 down_sample=False):
        super().__init__(init_cfg=init_cfg)

        self.embed_dims = embed_dims
        self.num_stages = num_stages
        self.num_layers = num_layers
        self.num_heads = num_heads
        self.patch_sizes = patch_sizes
        self.strides = strides
        self.sr_ratios = sr_ratios
        self.with_cp = with_cp
        self.down_sample = down_sample
        assert num_stages == len(num_layers) == len(num_heads) \
            == len(patch_sizes) == len(strides) == len(sr_ratios)

        # The docstring allows a single int; normalize it so that both the
        # ``max`` check below and the ``in`` test in ``forward`` work.
        if isinstance(out_indices, int):
            out_indices = (out_indices, )
        self.out_indices = out_indices
        assert max(out_indices) < self.num_stages

        # One (empty) token-reduction config per stage by default, sized to
        # ``num_stages`` rather than a hard-coded four.
        if tome_cfg is None:
            tome_cfg = [dict() for _ in range(num_stages)]

        # transformer encoder
        dpr = [
            x.item()
            for x in torch.linspace(0, drop_path_rate, sum(num_layers))
        ]  # stochastic num_layer decay rule

        cur = 0
        self.layers = ModuleList()
        for i, num_layer in enumerate(num_layers):
            embed_dims_i = embed_dims * num_heads[i]
            patch_embed = PatchEmbed(
                in_channels=in_channels,
                embed_dims=embed_dims_i,
                kernel_size=patch_sizes[i],
                stride=strides[i],
                padding=patch_sizes[i] // 2,
                norm_cfg=norm_cfg)
            layer = ModuleList([
                TransformerEncoderLayer(
                    embed_dims=embed_dims_i,
                    num_heads=num_heads[i],
                    feedforward_channels=mlp_ratio * embed_dims_i,
                    drop_rate=drop_rate,
                    attn_drop_rate=attn_drop_rate,
                    drop_path_rate=dpr[cur + idx],
                    qkv_bias=qkv_bias,
                    tome_cfg=tome_cfg[i],
                    act_cfg=act_cfg,
                    norm_cfg=norm_cfg,
                    with_cp=with_cp,
                    sr_ratio=sr_ratios[i]) for idx in range(num_layer)
            ])
            # The next stage's patch embedding consumes this stage's output.
            in_channels = embed_dims_i
            # The ret[0] of build_norm_layer is norm name.
            norm = build_norm_layer(norm_cfg, embed_dims_i)[1]
            self.layers.append(ModuleList([patch_embed, layer, norm]))
            cur += num_layer

    def init_weights(self):
        """Initialize weights.

        Without ``init_cfg``: truncated-normal Linear weights, unit/zero
        LayerNorm, and He-style normal Conv2d weights (fan-out divided by
        ``groups``). Otherwise defer to ``BaseModule.init_weights``.
        """
        if self.init_cfg is None:
            for m in self.modules():
                if isinstance(m, nn.Linear):
                    trunc_normal_init(m, std=.02, bias=0.)
                elif isinstance(m, nn.LayerNorm):
                    constant_init(m, val=1.0, bias=0.)
                elif isinstance(m, nn.Conv2d):
                    fan_out = m.kernel_size[0] * m.kernel_size[
                        1] * m.out_channels
                    fan_out //= m.groups
                    normal_init(
                        m, mean=0, std=math.sqrt(2.0 / fan_out), bias=0)
        else:
            super().init_weights()

    def forward(self, x):
        """Run all stages.

        Args:
            x (Tensor): Input image batch in NCHW layout.

        Returns:
            list[Tensor]: NCHW feature maps of the stages in ``out_indices``.
        """
        if self.down_sample:
            x = torch.nn.functional.interpolate(x, scale_factor=(0.5, 0.5))
        outs = []

        for i, layer in enumerate(self.layers):
            # layer = [patch_embed, encoder blocks, norm]
            x, hw_shape = layer[0](x)
            for block in layer[1]:
                x = block(x, hw_shape)
            x = layer[2](x)
            x = nlc_to_nchw(x, hw_shape)
            if i in self.out_indices:
                outs.append(x)

        return outs
|
|
|
|