import math

import torch
import torch.nn as nn
from timm.layers import DropPath, trunc_normal_
from timm.models.convnext import ConvNeXtBlock
from timm.models.mlp_mixer import MixerBlock
from timm.models.swin_transformer import SwinTransformerBlock, window_partition, window_reverse
from timm.models.vision_transformer import Block as ViTBlock

from .layers import (HorBlock, ChannelAggregationFFN, MultiOrderGatedAggregation,
                     PoolFormerBlock, CBlock, SABlock, MixMlp, VANBlock)


class BasicConv2d(nn.Module):
    """A basic Conv2d block with optional GroupNorm + SiLU and PixelShuffle upsampling."""

    def __init__(self,
                 in_channels,
                 out_channels,
                 kernel_size=3,
                 stride=1,
                 padding=0,
                 dilation=1,
                 upsampling=False,
                 act_norm=False,
                 act_inplace=True):
        super(BasicConv2d, self).__init__()
        self.act_norm = act_norm
        if upsampling is True:
            self.conv = nn.Sequential(*[
                nn.Conv2d(in_channels, out_channels*4, kernel_size=kernel_size,
                          stride=1, padding=padding, dilation=dilation),
                nn.PixelShuffle(2)
            ])
        else:
            self.conv = nn.Conv2d(
                in_channels, out_channels, kernel_size=kernel_size,
                stride=stride, padding=padding, dilation=dilation)

        self.norm = nn.GroupNorm(2, out_channels)
        self.act = nn.SiLU(inplace=act_inplace)

        self.apply(self._init_weights)

    def _init_weights(self, m):
        if isinstance(m, nn.Conv2d):
            trunc_normal_(m.weight, std=.02)
            nn.init.constant_(m.bias, 0)

    def forward(self, x):
        y = self.conv(x)
        if self.act_norm:
            y = self.act(self.norm(y))
        return y


class ConvSC(nn.Module):
    """A Conv2d wrapper that handles 2x down/upsampling for SimVP."""

    def __init__(self,
                 C_in,
                 C_out,
                 kernel_size=3,
                 downsampling=False,
                 upsampling=False,
                 act_norm=True,
                 act_inplace=True):
        super(ConvSC, self).__init__()

        stride = 2 if downsampling is True else 1
        padding = (kernel_size - stride + 1) // 2

        self.conv = BasicConv2d(C_in, C_out, kernel_size=kernel_size, stride=stride,
                                upsampling=upsampling, padding=padding,
                                act_norm=act_norm, act_inplace=act_inplace)

    def forward(self, x):
        y = self.conv(x)
        return y
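
# Usage sketch (illustrative only, not part of the module API): padding is derived
# from kernel_size and stride, so a downsampling ConvSC halves the spatial size
# while an upsampling one doubles it via PixelShuffle. Assuming a 64-channel
# 32x32 feature map:
#
#   x = torch.randn(1, 64, 32, 32)
#   ConvSC(64, 128, downsampling=True)(x).shape   # (1, 128, 16, 16)
#   ConvSC(64, 128, upsampling=True)(x).shape     # (1, 128, 64, 64)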


class GroupConv2d(nn.Module):
    """A grouped Conv2d block with GroupNorm and LeakyReLU."""

    def __init__(self,
                 in_channels,
                 out_channels,
                 kernel_size=3,
                 stride=1,
                 padding=0,
                 groups=1,
                 act_norm=False,
                 act_inplace=True):
        super(GroupConv2d, self).__init__()
        self.act_norm = act_norm
        if in_channels % groups != 0:
            groups = 1
        self.conv = nn.Conv2d(
            in_channels, out_channels, kernel_size=kernel_size,
            stride=stride, padding=padding, groups=groups)
        self.norm = nn.GroupNorm(groups, out_channels)
        self.activate = nn.LeakyReLU(0.2, inplace=act_inplace)

    def forward(self, x):
        y = self.conv(x)
        if self.act_norm:
            y = self.activate(self.norm(y))
        return y


class gInception_ST(nn.Module):
    """An IncepU block for SimVP"""

    def __init__(self, C_in, C_hid, C_out, incep_ker=[3, 5, 7, 11], groups=8):
        super(gInception_ST, self).__init__()
        self.conv1 = nn.Conv2d(C_in, C_hid, kernel_size=1, stride=1, padding=0)

        layers = []
        for ker in incep_ker:
            layers.append(GroupConv2d(
                C_hid, C_out, kernel_size=ker, stride=1,
                padding=ker//2, groups=groups, act_norm=True))
        self.layers = nn.Sequential(*layers)

    def forward(self, x):
        x = self.conv1(x)
        y = 0
        for layer in self.layers:
            y += layer(x)
        return y
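
# Note (illustrative): each GroupConv2d branch above uses stride 1 and
# padding=ker//2, so every odd kernel size keeps the spatial size unchanged and
# the branch outputs can be summed element-wise in forward(). For example, with
# the default incep_ker=[3, 5, 7, 11]:
#
#   inc = gInception_ST(C_in=64, C_hid=32, C_out=32)
#   inc(torch.randn(1, 64, 16, 16)).shape   # (1, 32, 16, 16)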


class AttentionModule(nn.Module):
    """Large Kernel Attention for SimVP"""

    def __init__(self, dim, kernel_size, dilation=3):
        super().__init__()
        d_k = 2 * dilation - 1
        d_p = (d_k - 1) // 2
        dd_k = kernel_size // dilation + ((kernel_size // dilation) % 2 - 1)
        dd_p = (dilation * (dd_k - 1) // 2)

        self.conv0 = nn.Conv2d(dim, dim, d_k, padding=d_p, groups=dim)
        self.conv_spatial = nn.Conv2d(
            dim, dim, dd_k, stride=1, padding=dd_p, groups=dim, dilation=dilation)
        self.conv1 = nn.Conv2d(dim, 2*dim, 1)

    def forward(self, x):
        attn = self.conv0(x)
        attn = self.conv_spatial(attn)

        f_g = self.conv1(attn)
        split_dim = f_g.shape[1] // 2
        f_x, g_x = torch.split(f_g, split_dim, dim=1)
        return torch.sigmoid(g_x) * f_x
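
# Note (illustrative): AttentionModule follows the large-kernel-attention
# decomposition used in VAN. With the defaults in this file (kernel_size=21,
# dilation=3), the formulas above give d_k=5 and dd_k=7, i.e. a 5x5 depth-wise
# conv followed by a 7x7 depth-wise conv with dilation 3, which approximates a
# 21x21 depth-wise kernel with far fewer parameters. The trailing 1x1 conv
# doubles the channels so forward() can split them into a value branch f_x and
# a sigmoid gate g_x.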


class SpatialAttention(nn.Module):
    """A Spatial Attention block for SimVP"""

    def __init__(self, d_model, kernel_size=21, attn_shortcut=True):
        super().__init__()

        self.proj_1 = nn.Conv2d(d_model, d_model, 1)
        self.activation = nn.GELU()
        self.spatial_gating_unit = AttentionModule(d_model, kernel_size)
        self.proj_2 = nn.Conv2d(d_model, d_model, 1)
        self.attn_shortcut = attn_shortcut

    def forward(self, x):
        if self.attn_shortcut:
            shortcut = x.clone()
        x = self.proj_1(x)
        x = self.activation(x)
        x = self.spatial_gating_unit(x)
        x = self.proj_2(x)
        if self.attn_shortcut:
            x = x + shortcut
        return x


class GASubBlock(nn.Module):
    """A GABlock (gSTA) for SimVP"""

    def __init__(self, dim, kernel_size=21, mlp_ratio=4.,
                 drop=0., drop_path=0.1, init_value=1e-2, act_layer=nn.GELU):
        super().__init__()
        self.norm1 = nn.BatchNorm2d(dim)
        self.attn = SpatialAttention(dim, kernel_size)
        self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()

        self.norm2 = nn.BatchNorm2d(dim)
        mlp_hidden_dim = int(dim * mlp_ratio)
        self.mlp = MixMlp(
            in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop)

        self.layer_scale_1 = nn.Parameter(init_value * torch.ones(dim), requires_grad=True)
        self.layer_scale_2 = nn.Parameter(init_value * torch.ones(dim), requires_grad=True)

        self.apply(self._init_weights)

    def _init_weights(self, m):
        if isinstance(m, nn.Linear):
            trunc_normal_(m.weight, std=.02)
            if isinstance(m, nn.Linear) and m.bias is not None:
                nn.init.constant_(m.bias, 0)
        elif isinstance(m, nn.LayerNorm):
            nn.init.constant_(m.bias, 0)
            nn.init.constant_(m.weight, 1.0)
        elif isinstance(m, nn.Conv2d):
            fan_out = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
            fan_out //= m.groups
            m.weight.data.normal_(0, math.sqrt(2.0 / fan_out))
            if m.bias is not None:
                m.bias.data.zero_()

    @torch.jit.ignore
    def no_weight_decay(self):
        return {'layer_scale_1', 'layer_scale_2'}

    def forward(self, x):
        x = x + self.drop_path(
            self.layer_scale_1.unsqueeze(-1).unsqueeze(-1) * self.attn(self.norm1(x)))
        x = x + self.drop_path(
            self.layer_scale_2.unsqueeze(-1).unsqueeze(-1) * self.mlp(self.norm2(x)))
        return x
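
# Usage sketch (illustrative): GASubBlock preserves the (B, C, H, W) shape, so
# it can be stacked as the MetaFormer-style unit of a SimVP hidden translator:
#
#   blk = GASubBlock(dim=64, kernel_size=21, drop_path=0.1)
#   blk(torch.randn(2, 64, 16, 16)).shape   # (2, 64, 16, 16)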


class ConvMixerSubBlock(nn.Module):
    """A block of ConvMixer."""

    def __init__(self, dim, kernel_size=9, activation=nn.GELU):
        super().__init__()

        self.conv_dw = nn.Conv2d(dim, dim, kernel_size, groups=dim, padding="same")
        self.act_1 = activation()
        self.norm_1 = nn.BatchNorm2d(dim)

        self.conv_pw = nn.Conv2d(dim, dim, kernel_size=1)
        self.act_2 = activation()
        self.norm_2 = nn.BatchNorm2d(dim)

        self.apply(self._init_weights)

    def _init_weights(self, m):
        if isinstance(m, nn.BatchNorm2d):
            nn.init.constant_(m.bias, 0)
            nn.init.constant_(m.weight, 1.0)
        elif isinstance(m, nn.Conv2d):
            fan_out = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
            fan_out //= m.groups
            m.weight.data.normal_(0, math.sqrt(2.0 / fan_out))
            if m.bias is not None:
                m.bias.data.zero_()

    @torch.jit.ignore
    def no_weight_decay(self):
        return dict()

    def forward(self, x):
        x = x + self.norm_1(self.act_1(self.conv_dw(x)))
        x = self.norm_2(self.act_2(self.conv_pw(x)))
        return x


class ConvNeXtSubBlock(ConvNeXtBlock):
    """A block of ConvNeXt."""

    def __init__(self, dim, mlp_ratio=4., drop=0., drop_path=0.1):
        super().__init__(dim, mlp_ratio=mlp_ratio,
                         drop_path=drop_path, ls_init_value=1e-6, conv_mlp=True)
        self.apply(self._init_weights)

    def _init_weights(self, m):
        if isinstance(m, (nn.LayerNorm, nn.GroupNorm, nn.BatchNorm2d)):
            nn.init.constant_(m.bias, 0)
            nn.init.constant_(m.weight, 1.0)
        elif isinstance(m, nn.Conv2d):
            fan_out = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
            fan_out //= m.groups
            m.weight.data.normal_(0, math.sqrt(2.0 / fan_out))
            if m.bias is not None:
                m.bias.data.zero_()

    @torch.jit.ignore
    def no_weight_decay(self):
        return {'gamma'}

    def forward(self, x):
        x = x + self.drop_path(
            self.gamma.reshape(1, -1, 1, 1) * self.mlp(self.norm(self.conv_dw(x))))
        return x


class HorNetSubBlock(HorBlock):
    """A block of HorNet."""

    def __init__(self, dim, mlp_ratio=4., drop_path=0.1, init_value=1e-6):
        super().__init__(dim, mlp_ratio=mlp_ratio, drop_path=drop_path, init_value=init_value)
        self.apply(self._init_weights)

    @torch.jit.ignore
    def no_weight_decay(self):
        return {'gamma1', 'gamma2'}

    def _init_weights(self, m):
        if isinstance(m, nn.Linear):
            trunc_normal_(m.weight, std=.02)
            if isinstance(m, nn.Linear) and m.bias is not None:
                nn.init.constant_(m.bias, 0)
        elif isinstance(m, nn.LayerNorm):
            nn.init.constant_(m.bias, 0)
            nn.init.constant_(m.weight, 1.0)
        elif isinstance(m, nn.Conv2d):
            fan_out = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
            fan_out //= m.groups
            m.weight.data.normal_(0, math.sqrt(2.0 / fan_out))
            if m.bias is not None:
                m.bias.data.zero_()


class MLPMixerSubBlock(MixerBlock):
    """A block of MLP-Mixer."""

    def __init__(self, dim, input_resolution=None, mlp_ratio=4., drop=0., drop_path=0.1):
        seq_len = input_resolution[0] * input_resolution[1]
        super().__init__(dim, seq_len=seq_len,
                         mlp_ratio=(0.5, mlp_ratio), drop_path=drop_path, drop=drop)
        self.apply(self._init_weights)

    def _init_weights(self, m):
        if isinstance(m, nn.Linear):
            trunc_normal_(m.weight, std=.02)
            if isinstance(m, nn.Linear) and m.bias is not None:
                nn.init.constant_(m.bias, 0)
        elif isinstance(m, (nn.LayerNorm, nn.GroupNorm, nn.BatchNorm2d)):
            nn.init.constant_(m.bias, 0)
            nn.init.constant_(m.weight, 1.0)

    @torch.jit.ignore
    def no_weight_decay(self):
        return dict()

    def forward(self, x):
        B, C, H, W = x.shape
        x = x.flatten(2).transpose(1, 2)
        x = x + self.drop_path(self.mlp_tokens(self.norm1(x).transpose(1, 2)).transpose(1, 2))
        x = x + self.drop_path(self.mlp_channels(self.norm2(x)))
        return x.reshape(B, H, W, C).permute(0, 3, 1, 2)
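
# Note (illustrative): the token-based sub-blocks in this file (MLP-Mixer, Swin,
# ViT) share the same (B, C, H, W) interface: forward() flattens the spatial
# grid into H*W tokens, runs the parent block, and folds the tokens back, e.g.
#
#   x = x.flatten(2).transpose(1, 2)              # (B, C, H, W) -> (B, H*W, C)
#   ...                                           # token/channel mixing
#   x.reshape(B, H, W, C).permute(0, 3, 1, 2)     # (B, H*W, C) -> (B, C, H, W)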


class MogaSubBlock(nn.Module):
    """A block of MogaNet."""

    def __init__(self, embed_dims, mlp_ratio=4., drop_rate=0., drop_path_rate=0., init_value=1e-5,
                 attn_dw_dilation=[1, 2, 3], attn_channel_split=[1, 3, 4]):
        super(MogaSubBlock, self).__init__()
        self.out_channels = embed_dims

        self.norm1 = nn.BatchNorm2d(embed_dims)
        self.attn = MultiOrderGatedAggregation(
            embed_dims, attn_dw_dilation=attn_dw_dilation, attn_channel_split=attn_channel_split)
        self.drop_path = DropPath(drop_path_rate) if drop_path_rate > 0. else nn.Identity()

        self.norm2 = nn.BatchNorm2d(embed_dims)
        mlp_hidden_dims = int(embed_dims * mlp_ratio)
        self.mlp = ChannelAggregationFFN(
            embed_dims=embed_dims, mlp_hidden_dims=mlp_hidden_dims, ffn_drop=drop_rate)

        self.layer_scale_1 = nn.Parameter(init_value * torch.ones((1, embed_dims, 1, 1)), requires_grad=True)
        self.layer_scale_2 = nn.Parameter(init_value * torch.ones((1, embed_dims, 1, 1)), requires_grad=True)

        self.apply(self._init_weights)

    def _init_weights(self, m):
        if isinstance(m, (nn.LayerNorm, nn.GroupNorm, nn.BatchNorm2d)):
            nn.init.constant_(m.bias, 0)
            nn.init.constant_(m.weight, 1.0)
        elif isinstance(m, nn.Conv2d):
            fan_out = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
            fan_out //= m.groups
            m.weight.data.normal_(0, math.sqrt(2.0 / fan_out))
            if m.bias is not None:
                m.bias.data.zero_()

    @torch.jit.ignore
    def no_weight_decay(self):
        return {'layer_scale_1', 'layer_scale_2', 'sigma'}

    def forward(self, x):
        x = x + self.drop_path(self.layer_scale_1 * self.attn(self.norm1(x)))
        x = x + self.drop_path(self.layer_scale_2 * self.mlp(self.norm2(x)))
        return x


class PoolFormerSubBlock(PoolFormerBlock):
    """A block of PoolFormer."""

    def __init__(self, dim, mlp_ratio=4., drop=0., drop_path=0.1):
        super().__init__(dim, pool_size=3, mlp_ratio=mlp_ratio, drop_path=drop_path,
                         drop=drop, init_value=1e-5)
        self.apply(self._init_weights)

    @torch.jit.ignore
    def no_weight_decay(self):
        return {'layer_scale_1', 'layer_scale_2'}

    def _init_weights(self, m):
        if isinstance(m, nn.Linear):
            trunc_normal_(m.weight, std=.02)
            if isinstance(m, nn.Linear) and m.bias is not None:
                nn.init.constant_(m.bias, 0)
        elif isinstance(m, (nn.LayerNorm, nn.GroupNorm, nn.BatchNorm2d)):
            nn.init.constant_(m.bias, 0)
            nn.init.constant_(m.weight, 1.0)


class SwinSubBlock(SwinTransformerBlock):
    """A block of Swin Transformer."""

    def __init__(self, dim, input_resolution=None, layer_i=0, mlp_ratio=4., drop=0., drop_path=0.1):
        window_size = 7 if input_resolution[0] % 7 == 0 else max(4, input_resolution[0] // 16)
        window_size = min(8, window_size)
        shift_size = 0 if (layer_i % 2 == 0) else window_size // 2
        super().__init__(dim, input_resolution, num_heads=8, window_size=window_size,
                         shift_size=shift_size, mlp_ratio=mlp_ratio,
                         drop_path=drop_path, attn_drop=drop, proj_drop=drop, qkv_bias=True)
        self.apply(self._init_weights)

    def _init_weights(self, m):
        if isinstance(m, nn.Linear):
            trunc_normal_(m.weight, std=.02)
            if isinstance(m, nn.Linear) and m.bias is not None:
                nn.init.constant_(m.bias, 0)
        elif isinstance(m, (nn.LayerNorm, nn.GroupNorm, nn.BatchNorm2d)):
            nn.init.constant_(m.bias, 0)
            nn.init.constant_(m.weight, 1.0)

    @torch.jit.ignore
    def no_weight_decay(self):
        return {}

    def forward(self, x):
        B, C, H, W = x.shape
        x = x.flatten(2).transpose(1, 2)
        x = self.norm1(x)
        x = x.view(B, H, W, C)
        x = super().forward(x)

        return x.reshape(B, H, W, C).permute(0, 3, 1, 2)


def UniformerSubBlock(embed_dims, mlp_ratio=4., drop=0., drop_path=0.,
                      init_value=1e-6, block_type='Conv'):
    """Build a block of Uniformer."""

    assert block_type in ['Conv', 'MHSA']
    if block_type == 'Conv':
        return CBlock(dim=embed_dims, mlp_ratio=mlp_ratio, drop=drop, drop_path=drop_path)
    else:
        return SABlock(dim=embed_dims, num_heads=8, mlp_ratio=mlp_ratio, qkv_bias=True,
                       drop=drop, drop_path=drop_path, init_value=init_value)
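
# Usage sketch (illustrative): following the Uniformer design, convolutional
# blocks are typically used in shallow stages and self-attention blocks in
# deeper stages:
#
#   shallow = UniformerSubBlock(64, block_type='Conv')
#   deep = UniformerSubBlock(64, block_type='MHSA')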


class VANSubBlock(VANBlock):
    """A block of VAN."""

    def __init__(self, dim, mlp_ratio=4., drop=0., drop_path=0., init_value=1e-2, act_layer=nn.GELU):
        super().__init__(dim=dim, mlp_ratio=mlp_ratio, drop=drop, drop_path=drop_path,
                         init_value=init_value, act_layer=act_layer)
        self.apply(self._init_weights)

    @torch.jit.ignore
    def no_weight_decay(self):
        return {'layer_scale_1', 'layer_scale_2'}

    def _init_weights(self, m):
        if isinstance(m, (nn.LayerNorm, nn.GroupNorm, nn.BatchNorm2d)):
            nn.init.constant_(m.bias, 0)
            nn.init.constant_(m.weight, 1.0)
        elif isinstance(m, nn.Conv2d):
            fan_out = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
            fan_out //= m.groups
            m.weight.data.normal_(0, math.sqrt(2.0 / fan_out))
            if m.bias is not None:
                m.bias.data.zero_()


class ViTSubBlock(ViTBlock):
    """A block of Vision Transformer."""

    def __init__(self, dim, mlp_ratio=4., drop=0., drop_path=0.1):
        super().__init__(dim=dim, num_heads=8, mlp_ratio=mlp_ratio, qkv_bias=True,
                         attn_drop=drop, proj_drop=0, drop_path=drop_path,
                         act_layer=nn.GELU, norm_layer=nn.LayerNorm)
        self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
        self.apply(self._init_weights)

    def _init_weights(self, m):
        if isinstance(m, nn.Linear):
            trunc_normal_(m.weight, std=.02)
            if isinstance(m, nn.Linear) and m.bias is not None:
                nn.init.constant_(m.bias, 0)
        elif isinstance(m, (nn.LayerNorm, nn.GroupNorm, nn.BatchNorm2d)):
            nn.init.constant_(m.bias, 0)
            nn.init.constant_(m.weight, 1.0)

    @torch.jit.ignore
    def no_weight_decay(self):
        return {}

    def forward(self, x):
        B, C, H, W = x.shape
        x = x.flatten(2).transpose(1, 2)
        x = x + self.drop_path(self.attn(self.norm1(x)))
        x = x + self.drop_path(self.mlp(self.norm2(x)))
        return x.reshape(B, H, W, C).permute(0, 3, 1, 2)


class TemporalAttention(nn.Module):
    """A Temporal Attention block for Temporal Attention Unit"""

    def __init__(self, d_model, kernel_size=21, attn_shortcut=True):
        super().__init__()

        self.proj_1 = nn.Conv2d(d_model, d_model, 1)
        self.activation = nn.GELU()
        self.spatial_gating_unit = TemporalAttentionModule(d_model, kernel_size)
        self.proj_2 = nn.Conv2d(d_model, d_model, 1)
        self.attn_shortcut = attn_shortcut

    def forward(self, x):
        if self.attn_shortcut:
            shortcut = x.clone()
        x = self.proj_1(x)
        x = self.activation(x)
        x = self.spatial_gating_unit(x)
        x = self.proj_2(x)
        if self.attn_shortcut:
            x = x + shortcut
        return x


class TemporalAttentionModule(nn.Module):
    """Large Kernel Attention with a squeeze-and-excitation channel gate for TAU"""

    def __init__(self, dim, kernel_size, dilation=3, reduction=16):
        super().__init__()
        d_k = 2 * dilation - 1
        d_p = (d_k - 1) // 2
        dd_k = kernel_size // dilation + ((kernel_size // dilation) % 2 - 1)
        dd_p = (dilation * (dd_k - 1) // 2)

        self.conv0 = nn.Conv2d(dim, dim, d_k, padding=d_p, groups=dim)
        self.conv_spatial = nn.Conv2d(
            dim, dim, dd_k, stride=1, padding=dd_p, groups=dim, dilation=dilation)
        self.conv1 = nn.Conv2d(dim, dim, 1)

        self.reduction = max(dim // reduction, 4)
        self.avg_pool = nn.AdaptiveAvgPool2d(1)
        self.fc = nn.Sequential(
            nn.Linear(dim, dim // self.reduction, bias=False),
            nn.ReLU(True),
            nn.Linear(dim // self.reduction, dim, bias=False),
            nn.Sigmoid()
        )

    def forward(self, x):
        u = x.clone()
        attn = self.conv0(x)
        attn = self.conv_spatial(attn)
        f_x = self.conv1(attn)

        b, c, _, _ = x.size()
        se_atten = self.avg_pool(x).view(b, c)
        se_atten = self.fc(se_atten).view(b, c, 1, 1)
        return se_atten * f_x * u
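
# Note (illustrative): TemporalAttentionModule pairs the large-kernel spatial
# attention used above (conv0 -> conv_spatial -> conv1 producing f_x) with a
# squeeze-and-excitation style channel gate (avg_pool -> fc producing se_atten),
# and modulates the input u by both. TAUSubBlock below swaps this attention in
# for the plain SpatialAttention of GASubBlock.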


class TAUSubBlock(GASubBlock):
    """A TAUBlock (tau) for Temporal Attention Unit"""

    def __init__(self, dim, kernel_size=21, mlp_ratio=4.,
                 drop=0., drop_path=0.1, init_value=1e-2, act_layer=nn.GELU):
        super().__init__(dim=dim, kernel_size=kernel_size, mlp_ratio=mlp_ratio,
                         drop=drop, drop_path=drop_path, init_value=init_value, act_layer=act_layer)

        self.attn = TemporalAttention(dim, kernel_size)