import torch import torch.nn as nn from timm.models.swin_transformer import SwinTransformerBlock, window_reverse, PatchEmbed, PatchMerging, window_partition from timm.layers import to_2tuple class SwinLSTMCell(nn.Module): def __init__(self, dim, input_resolution, num_heads, window_size, depth, mlp_ratio=4., qkv_bias=True, qk_scale=None, drop=0., attn_drop=0., drop_path=0., norm_layer=nn.LayerNorm, flag=None): """ Args: flag: 0 UpSample 1 DownSample 2 STconvert """ super(SwinLSTMCell, self).__init__() self.STBs = nn.ModuleList( STB(i, dim=dim, input_resolution=input_resolution, depth=depth, num_heads=num_heads, window_size=window_size, mlp_ratio=mlp_ratio, qkv_bias=qkv_bias, qk_scale=qk_scale, drop=drop, attn_drop=attn_drop, drop_path=drop_path, norm_layer=norm_layer, flag=flag) for i in range(depth)) def forward(self, xt, hidden_states): """ Args: xt: input for t period hidden_states: [hx, cx] hidden_states for t-1 period """ if hidden_states is None: B, L, C = xt.shape hx = torch.zeros(B, L, C).to(xt.device) cx = torch.zeros(B, L, C).to(xt.device) else: hx, cx = hidden_states outputs = [] for index, layer in enumerate(self.STBs): if index == 0: x = layer(xt, hx) outputs.append(x) else: if index % 2 == 0: x = layer(outputs[-1], xt) outputs.append(x) if index % 2 == 1: x = layer(outputs[-1], None) outputs.append(x) o_t = outputs[-1] Ft = torch.sigmoid(o_t) cell = torch.tanh(o_t) Ct = Ft * (cx + cell) Ht = Ft * torch.tanh(Ct) return Ht, (Ht, Ct) class STB(SwinTransformerBlock): def __init__(self, index, dim, input_resolution, depth, num_heads, window_size, mlp_ratio=4., qkv_bias=True, qk_scale=None, drop=0., attn_drop=0., drop_path=0., norm_layer=nn.LayerNorm, flag=None): if flag == 0: drop_path = drop_path[depth - index - 1] elif flag == 1: drop_path = drop_path[index] elif flag == 2: drop_path = drop_path super(STB, self).__init__(dim=dim, input_resolution=input_resolution, num_heads=num_heads, window_size=window_size, shift_size=0 if (index % 2 == 0) else window_size // 2, mlp_ratio=mlp_ratio, qkv_bias=qkv_bias, drop=drop, attn_drop=attn_drop, drop_path=drop_path, norm_layer=norm_layer) self.red = nn.Linear(2 * dim, dim) def forward(self, x, hx=None): H, W = self.input_resolution B, L, C = x.shape assert L == H * W, "input feature has wrong size" shortcut = x x = self.norm1(x) if hx is not None: hx = self.norm1(hx) x = torch.cat((x, hx), -1) x = self.red(x) x = x.view(B, H, W, C) # cyclic shift if self.shift_size > 0: shifted_x = torch.roll(x, shifts=(-self.shift_size, -self.shift_size), dims=(1, 2)) else: shifted_x = x # partition windows x_windows = window_partition(shifted_x, self.window_size) # num_win*B, window_size, window_size, C x_windows = x_windows.view(-1, self.window_size * self.window_size, C) # num_win*B, window_size*window_size, C # W-MSA/SW-MSA attn_windows = self.attn(x_windows, mask=self.attn_mask) # num_win*B, window_size*window_size, C # merge windows attn_windows = attn_windows.view(-1, self.window_size, self.window_size, C) shifted_x = window_reverse(attn_windows, self.window_size, H, W) # B H' W' C # reverse cyclic shift if self.shift_size > 0: x = torch.roll(shifted_x, shifts=(self.shift_size, self.shift_size), dims=(1, 2)) else: x = shifted_x x = x.view(B, H * W, C) # FFN x = shortcut + self.drop_path(x) x = x + self.drop_path(self.mlp(self.norm2(x))) return x class PatchInflated(nn.Module): r""" Tensor to Patch Inflating Args: in_chans (int): Number of input image channels. embed_dim (int): Number of linear projection output channels. input_resolution (tuple[int]): Input resulotion. """ def __init__(self, in_chans, embed_dim, input_resolution, stride=2, padding=1, output_padding=1): super(PatchInflated, self).__init__() stride = to_2tuple(stride) padding = to_2tuple(padding) output_padding = to_2tuple(output_padding) self.input_resolution = input_resolution self.Conv = nn.ConvTranspose2d(in_channels=embed_dim, out_channels=in_chans, kernel_size=(3, 3), stride=stride, padding=padding, output_padding=output_padding) def forward(self, x): H, W = self.input_resolution B, L, C = x.shape assert L == H * W, "input feature has wrong size" assert H % 2 == 0 and W % 2 == 0, f"x size ({H}*{W}) are not even." x = x.view(B, H, W, C) x = x.permute(0, 3, 1, 2) x = self.Conv(x) return x class PatchExpanding(nn.Module): r""" Patch Expanding Layer. Args: input_resolution (tuple[int]): Resolution of input feature. dim (int): Number of input channels. norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm """ def __init__(self, input_resolution, dim, dim_scale=2, norm_layer=nn.LayerNorm): super(PatchExpanding, self).__init__() self.input_resolution = input_resolution self.dim = dim self.expand = nn.Linear(dim, 2 * dim, bias=False) if dim_scale == 2 else nn.Identity() self.norm = norm_layer(dim // dim_scale) def forward(self, x): H, W = self.input_resolution x = self.expand(x) B, L, C = x.shape assert L == H * W, "input feature has wrong size" x = x.view(B, H, W, C) x = x.reshape(B, H, W, 2, 2, C // 4) x = x.permute(0, 1, 3, 2, 4, 5).reshape(B, H * 2, W * 2, C // 4) x = x.view(B, -1, C // 4) x = self.norm(x) return x class UpSample(nn.Module): def __init__(self, img_size, patch_size, in_chans, embed_dim, depths_upsample, num_heads, window_size, mlp_ratio=4., qkv_bias=True, qk_scale=None, drop_rate=0., attn_drop_rate=0., drop_path_rate=0.1, norm_layer=nn.LayerNorm, flag=0): super(UpSample, self).__init__() self.img_size = img_size self.num_layers = len(depths_upsample) self.embed_dim = embed_dim self.mlp_ratio = mlp_ratio self.patch_embed = PatchEmbed(img_size=img_size, patch_size=patch_size, in_chans=in_chans, embed_dim=embed_dim, norm_layer=nn.LayerNorm) patches_resolution = self.patch_embed.grid_size self.Unembed = PatchInflated(in_chans=in_chans, embed_dim=embed_dim, input_resolution=patches_resolution) dpr = [x.item() for x in torch.linspace(0, drop_path_rate, sum(depths_upsample))] self.layers = nn.ModuleList() self.upsample = nn.ModuleList() for i_layer in range(self.num_layers): resolution1 = (patches_resolution[0] // (2 ** (self.num_layers - i_layer))) resolution2 = (patches_resolution[1] // (2 ** (self.num_layers - i_layer))) dimension = int(embed_dim * 2 ** (self.num_layers - i_layer)) upsample = PatchExpanding(input_resolution=(resolution1, resolution2), dim=dimension) layer = SwinLSTMCell(dim=dimension, input_resolution=(resolution1, resolution2), depth=depths_upsample[(self.num_layers - 1 - i_layer)], num_heads=num_heads[(self.num_layers - 1 - i_layer)], window_size=window_size, mlp_ratio=self.mlp_ratio, qkv_bias=qkv_bias, qk_scale=qk_scale, drop=drop_rate, attn_drop=attn_drop_rate, drop_path=dpr[sum(depths_upsample[:(self.num_layers - 1 - i_layer)]): sum(depths_upsample[:(self.num_layers - 1 - i_layer) + 1])], norm_layer=norm_layer, flag=flag) self.layers.append(layer) self.upsample.append(upsample) def forward(self, x, y): hidden_states_up = [] for index, layer in enumerate(self.layers): x, hidden_state = layer(x, y[index]) x = self.upsample[index](x) hidden_states_up.append(hidden_state) x = torch.sigmoid(self.Unembed(x)) return hidden_states_up, x class DownSample(nn.Module): def __init__(self, img_size, patch_size, in_chans, embed_dim, depths_downsample, num_heads, window_size, mlp_ratio=4., qkv_bias=True, qk_scale=None, drop_rate=0., attn_drop_rate=0., drop_path_rate=0.1, norm_layer=nn.LayerNorm, flag=1): super(DownSample, self).__init__() self.num_layers = len(depths_downsample) self.embed_dim = embed_dim self.mlp_ratio = mlp_ratio self.patch_embed = PatchEmbed(img_size=img_size, patch_size=patch_size, in_chans=in_chans, embed_dim=embed_dim, norm_layer=nn.LayerNorm) patches_resolution = self.patch_embed.grid_size dpr = [x.item() for x in torch.linspace(0, drop_path_rate, sum(depths_downsample))] self.layers = nn.ModuleList() self.downsample = nn.ModuleList() for i_layer in range(self.num_layers): downsample = PatchMerging(input_resolution=(patches_resolution[0] // (2 ** i_layer), patches_resolution[1] // (2 ** i_layer)), dim=int(embed_dim * 2 ** i_layer)) layer = SwinLSTMCell(dim=int(embed_dim * 2 ** i_layer), input_resolution=(patches_resolution[0] // (2 ** i_layer), patches_resolution[1] // (2 ** i_layer)), depth=depths_downsample[i_layer], num_heads=num_heads[i_layer], window_size=window_size, mlp_ratio=self.mlp_ratio, qkv_bias=qkv_bias, qk_scale=qk_scale, drop=drop_rate, attn_drop=attn_drop_rate, drop_path=dpr[sum(depths_downsample[:i_layer]):sum(depths_downsample[:i_layer + 1])], norm_layer=norm_layer, flag=flag) self.layers.append(layer) self.downsample.append(downsample) def forward(self, x, y): x = self.patch_embed(x) hidden_states_down = [] for index, layer in enumerate(self.layers): x, hidden_state = layer(x, y[index]) x = self.downsample[index](x) hidden_states_down.append(hidden_state) return hidden_states_down, x class STconvert(nn.Module): def __init__(self, img_size, patch_size, in_chans, embed_dim, depths, num_heads, window_size, mlp_ratio=4., qkv_bias=True, qk_scale=None, drop_rate=0., attn_drop_rate=0., drop_path_rate=0.1, norm_layer=nn.LayerNorm, flag=2): super(STconvert, self).__init__() self.embed_dim = embed_dim self.mlp_ratio = mlp_ratio self.patch_embed = PatchEmbed(img_size=img_size, patch_size=patch_size, in_chans=in_chans, embed_dim=embed_dim, norm_layer=norm_layer) patches_resolution = self.patch_embed.grid_size self.patch_inflated = PatchInflated(in_chans=in_chans, embed_dim=embed_dim, input_resolution=patches_resolution) self.layer = SwinLSTMCell(dim=embed_dim, input_resolution=(patches_resolution[0], patches_resolution[1]), depth=depths, num_heads=num_heads, window_size=window_size, mlp_ratio=mlp_ratio, qkv_bias=qkv_bias, qk_scale=qk_scale, drop=drop_rate, attn_drop=attn_drop_rate, drop_path=drop_path_rate, norm_layer=norm_layer, flag=flag) def forward(self, x, h=None): x = self.patch_embed(x) x, hidden_state = self.layer(x, h) x = torch.sigmoid(self.patch_inflated(x)) return x, hidden_state