LARRES / utilpack /swinlstm_modules.py
Staty's picture
Upload 50 files
2b21abc verified
import torch
import torch.nn as nn
from timm.models.swin_transformer import SwinTransformerBlock, window_reverse, PatchEmbed, PatchMerging, window_partition
from timm.layers import to_2tuple
class SwinLSTMCell(nn.Module):
def __init__(self, dim, input_resolution, num_heads, window_size, depth,
mlp_ratio=4., qkv_bias=True, qk_scale=None, drop=0., attn_drop=0.,
drop_path=0., norm_layer=nn.LayerNorm, flag=None):
"""
Args:
flag: 0 UpSample 1 DownSample 2 STconvert
"""
super(SwinLSTMCell, self).__init__()
self.STBs = nn.ModuleList(
STB(i, dim=dim, input_resolution=input_resolution, depth=depth,
num_heads=num_heads, window_size=window_size, mlp_ratio=mlp_ratio,
qkv_bias=qkv_bias, qk_scale=qk_scale, drop=drop, attn_drop=attn_drop,
drop_path=drop_path, norm_layer=norm_layer, flag=flag)
for i in range(depth))
def forward(self, xt, hidden_states):
"""
Args:
xt: input for t period
hidden_states: [hx, cx] hidden_states for t-1 period
"""
if hidden_states is None:
B, L, C = xt.shape
hx = torch.zeros(B, L, C).to(xt.device)
cx = torch.zeros(B, L, C).to(xt.device)
else:
hx, cx = hidden_states
outputs = []
for index, layer in enumerate(self.STBs):
if index == 0:
x = layer(xt, hx)
outputs.append(x)
else:
if index % 2 == 0:
x = layer(outputs[-1], xt)
outputs.append(x)
if index % 2 == 1:
x = layer(outputs[-1], None)
outputs.append(x)
o_t = outputs[-1]
Ft = torch.sigmoid(o_t)
cell = torch.tanh(o_t)
Ct = Ft * (cx + cell)
Ht = Ft * torch.tanh(Ct)
return Ht, (Ht, Ct)
class STB(SwinTransformerBlock):
def __init__(self, index, dim, input_resolution, depth, num_heads, window_size,
mlp_ratio=4., qkv_bias=True, qk_scale=None, drop=0., attn_drop=0.,
drop_path=0., norm_layer=nn.LayerNorm, flag=None):
if flag == 0:
drop_path = drop_path[depth - index - 1]
elif flag == 1:
drop_path = drop_path[index]
elif flag == 2:
drop_path = drop_path
super(STB, self).__init__(dim=dim, input_resolution=input_resolution,
num_heads=num_heads, window_size=window_size,
shift_size=0 if (index % 2 == 0) else window_size // 2,
mlp_ratio=mlp_ratio, qkv_bias=qkv_bias,
drop=drop, attn_drop=attn_drop,
drop_path=drop_path,
norm_layer=norm_layer)
self.red = nn.Linear(2 * dim, dim)
def forward(self, x, hx=None):
H, W = self.input_resolution
B, L, C = x.shape
assert L == H * W, "input feature has wrong size"
shortcut = x
x = self.norm1(x)
if hx is not None:
hx = self.norm1(hx)
x = torch.cat((x, hx), -1)
x = self.red(x)
x = x.view(B, H, W, C)
# cyclic shift
if self.shift_size > 0:
shifted_x = torch.roll(x, shifts=(-self.shift_size, -self.shift_size), dims=(1, 2))
else:
shifted_x = x
# partition windows
x_windows = window_partition(shifted_x, self.window_size) # num_win*B, window_size, window_size, C
x_windows = x_windows.view(-1, self.window_size * self.window_size, C) # num_win*B, window_size*window_size, C
# W-MSA/SW-MSA
attn_windows = self.attn(x_windows, mask=self.attn_mask) # num_win*B, window_size*window_size, C
# merge windows
attn_windows = attn_windows.view(-1, self.window_size, self.window_size, C)
shifted_x = window_reverse(attn_windows, self.window_size, H, W) # B H' W' C
# reverse cyclic shift
if self.shift_size > 0:
x = torch.roll(shifted_x, shifts=(self.shift_size, self.shift_size), dims=(1, 2))
else:
x = shifted_x
x = x.view(B, H * W, C)
# FFN
x = shortcut + self.drop_path(x)
x = x + self.drop_path(self.mlp(self.norm2(x)))
return x
class PatchInflated(nn.Module):
r""" Tensor to Patch Inflating
Args:
in_chans (int): Number of input image channels.
embed_dim (int): Number of linear projection output channels.
input_resolution (tuple[int]): Input resulotion.
"""
def __init__(self, in_chans, embed_dim, input_resolution, stride=2, padding=1, output_padding=1):
super(PatchInflated, self).__init__()
stride = to_2tuple(stride)
padding = to_2tuple(padding)
output_padding = to_2tuple(output_padding)
self.input_resolution = input_resolution
self.Conv = nn.ConvTranspose2d(in_channels=embed_dim, out_channels=in_chans, kernel_size=(3, 3),
stride=stride, padding=padding, output_padding=output_padding)
def forward(self, x):
H, W = self.input_resolution
B, L, C = x.shape
assert L == H * W, "input feature has wrong size"
assert H % 2 == 0 and W % 2 == 0, f"x size ({H}*{W}) are not even."
x = x.view(B, H, W, C)
x = x.permute(0, 3, 1, 2)
x = self.Conv(x)
return x
class PatchExpanding(nn.Module):
r""" Patch Expanding Layer.
Args:
input_resolution (tuple[int]): Resolution of input feature.
dim (int): Number of input channels.
norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm
"""
def __init__(self, input_resolution, dim, dim_scale=2, norm_layer=nn.LayerNorm):
super(PatchExpanding, self).__init__()
self.input_resolution = input_resolution
self.dim = dim
self.expand = nn.Linear(dim, 2 * dim, bias=False) if dim_scale == 2 else nn.Identity()
self.norm = norm_layer(dim // dim_scale)
def forward(self, x):
H, W = self.input_resolution
x = self.expand(x)
B, L, C = x.shape
assert L == H * W, "input feature has wrong size"
x = x.view(B, H, W, C)
x = x.reshape(B, H, W, 2, 2, C // 4)
x = x.permute(0, 1, 3, 2, 4, 5).reshape(B, H * 2, W * 2, C // 4)
x = x.view(B, -1, C // 4)
x = self.norm(x)
return x
class UpSample(nn.Module):
def __init__(self, img_size, patch_size, in_chans, embed_dim, depths_upsample, num_heads, window_size, mlp_ratio=4.,
qkv_bias=True, qk_scale=None, drop_rate=0., attn_drop_rate=0., drop_path_rate=0.1,
norm_layer=nn.LayerNorm, flag=0):
super(UpSample, self).__init__()
self.img_size = img_size
self.num_layers = len(depths_upsample)
self.embed_dim = embed_dim
self.mlp_ratio = mlp_ratio
self.patch_embed = PatchEmbed(img_size=img_size, patch_size=patch_size, in_chans=in_chans, embed_dim=embed_dim, norm_layer=nn.LayerNorm)
patches_resolution = self.patch_embed.grid_size
self.Unembed = PatchInflated(in_chans=in_chans, embed_dim=embed_dim, input_resolution=patches_resolution)
dpr = [x.item() for x in torch.linspace(0, drop_path_rate, sum(depths_upsample))]
self.layers = nn.ModuleList()
self.upsample = nn.ModuleList()
for i_layer in range(self.num_layers):
resolution1 = (patches_resolution[0] // (2 ** (self.num_layers - i_layer)))
resolution2 = (patches_resolution[1] // (2 ** (self.num_layers - i_layer)))
dimension = int(embed_dim * 2 ** (self.num_layers - i_layer))
upsample = PatchExpanding(input_resolution=(resolution1, resolution2), dim=dimension)
layer = SwinLSTMCell(dim=dimension, input_resolution=(resolution1, resolution2),
depth=depths_upsample[(self.num_layers - 1 - i_layer)],
num_heads=num_heads[(self.num_layers - 1 - i_layer)],
window_size=window_size,
mlp_ratio=self.mlp_ratio,
qkv_bias=qkv_bias, qk_scale=qk_scale,
drop=drop_rate, attn_drop=attn_drop_rate,
drop_path=dpr[sum(depths_upsample[:(self.num_layers - 1 - i_layer)]):
sum(depths_upsample[:(self.num_layers - 1 - i_layer) + 1])],
norm_layer=norm_layer, flag=flag)
self.layers.append(layer)
self.upsample.append(upsample)
def forward(self, x, y):
hidden_states_up = []
for index, layer in enumerate(self.layers):
x, hidden_state = layer(x, y[index])
x = self.upsample[index](x)
hidden_states_up.append(hidden_state)
x = torch.sigmoid(self.Unembed(x))
return hidden_states_up, x
class DownSample(nn.Module):
def __init__(self, img_size, patch_size, in_chans, embed_dim, depths_downsample, num_heads, window_size,
mlp_ratio=4., qkv_bias=True, qk_scale=None, drop_rate=0., attn_drop_rate=0., drop_path_rate=0.1,
norm_layer=nn.LayerNorm, flag=1):
super(DownSample, self).__init__()
self.num_layers = len(depths_downsample)
self.embed_dim = embed_dim
self.mlp_ratio = mlp_ratio
self.patch_embed = PatchEmbed(img_size=img_size, patch_size=patch_size, in_chans=in_chans, embed_dim=embed_dim, norm_layer=nn.LayerNorm)
patches_resolution = self.patch_embed.grid_size
dpr = [x.item() for x in torch.linspace(0, drop_path_rate, sum(depths_downsample))]
self.layers = nn.ModuleList()
self.downsample = nn.ModuleList()
for i_layer in range(self.num_layers):
downsample = PatchMerging(input_resolution=(patches_resolution[0] // (2 ** i_layer),
patches_resolution[1] // (2 ** i_layer)),
dim=int(embed_dim * 2 ** i_layer))
layer = SwinLSTMCell(dim=int(embed_dim * 2 ** i_layer),
input_resolution=(patches_resolution[0] // (2 ** i_layer),
patches_resolution[1] // (2 ** i_layer)),
depth=depths_downsample[i_layer],
num_heads=num_heads[i_layer],
window_size=window_size,
mlp_ratio=self.mlp_ratio,
qkv_bias=qkv_bias, qk_scale=qk_scale,
drop=drop_rate, attn_drop=attn_drop_rate,
drop_path=dpr[sum(depths_downsample[:i_layer]):sum(depths_downsample[:i_layer + 1])],
norm_layer=norm_layer, flag=flag)
self.layers.append(layer)
self.downsample.append(downsample)
def forward(self, x, y):
x = self.patch_embed(x)
hidden_states_down = []
for index, layer in enumerate(self.layers):
x, hidden_state = layer(x, y[index])
x = self.downsample[index](x)
hidden_states_down.append(hidden_state)
return hidden_states_down, x
class STconvert(nn.Module):
def __init__(self, img_size, patch_size, in_chans, embed_dim, depths, num_heads,
window_size, mlp_ratio=4., qkv_bias=True, qk_scale=None, drop_rate=0.,
attn_drop_rate=0., drop_path_rate=0.1, norm_layer=nn.LayerNorm, flag=2):
super(STconvert, self).__init__()
self.embed_dim = embed_dim
self.mlp_ratio = mlp_ratio
self.patch_embed = PatchEmbed(img_size=img_size, patch_size=patch_size,
in_chans=in_chans, embed_dim=embed_dim,
norm_layer=norm_layer)
patches_resolution = self.patch_embed.grid_size
self.patch_inflated = PatchInflated(in_chans=in_chans, embed_dim=embed_dim,
input_resolution=patches_resolution)
self.layer = SwinLSTMCell(dim=embed_dim,
input_resolution=(patches_resolution[0], patches_resolution[1]),
depth=depths, num_heads=num_heads,
window_size=window_size, mlp_ratio=mlp_ratio,
qkv_bias=qkv_bias, qk_scale=qk_scale,
drop=drop_rate, attn_drop=attn_drop_rate,
drop_path=drop_path_rate, norm_layer=norm_layer,
flag=flag)
def forward(self, x, h=None):
x = self.patch_embed(x)
x, hidden_state = self.layer(x, h)
x = torch.sigmoid(self.patch_inflated(x))
return x, hidden_state