# wast_modules.py  (originally hosted under "LARRES / utilpack")
# NOTE: the lines above this file's imports were file-hosting page residue
# ("Upload 50 files", commit hash) and have been converted into this comment
# so the module parses as valid Python.
import torch, pywt
import torch.nn as nn
from einops import rearrange
from functools import partial
from itertools import accumulate
from timm.layers import DropPath, activations
from timm.models._efficientnet_blocks import SqueezeExcite, InvertedResidual
# version adaptation for PyTorch > 1.7.1
# torch.__version__ may carry a local build suffix (e.g. '1.13.1+cu117');
# strip it before the numeric tuple comparison.
IS_HIGH_VERSION = tuple(map(int, torch.__version__.split('+')[0].split('.'))) > (1, 7, 1)
if IS_HIGH_VERSION:
    # On newer releases the FFT API lives in the torch.fft module (used by
    # HighFocalFrequencyLoss.tensor2freq); older releases use torch.rfft.
    import torch.fft
class HighFocalFrequencyLoss(nn.Module):
    """Focal frequency loss computed on successive wavelet low-bands, with the
    lowest Fourier frequencies masked out.

    Each level applies a 2D DWT to prediction and target, keeps the LL
    sub-band, and compares the two spectra with a dynamic spectrum-weight
    matrix; a corner mask of radius ``tau * max(H, W)`` suppresses the lowest
    frequencies of the (unshifted) spectrum.

    Example:
        fake = torch.randn(4, 3, 128, 64)
        real = torch.randn(4, 3, 128, 64)
        hffl = HighFocalFrequencyLoss()
        loss = hffl(fake, real)
        print(loss)

    Args:
        loss_weight (float): global scale applied to the accumulated loss.
        level (int): number of successive DWT levels to accumulate.
        tau (float): corner-mask radius as a fraction of max(H, W).
        alpha (float): exponent sharpening the online weight matrix.
        patch_factor (int): split each image into patch_factor^2 crops before the FFT.
        ave_spectrum (bool): average spectra over the batch before weighting.
        log_matrix (bool): apply log(1 + .) to the online weight matrix.
        batch_matrix (bool): normalize the weight matrix by the batch-wide max
            instead of per-sample maxima.
    """

    def __init__(self, loss_weight=0.001, level=1, tau=0.1, alpha=1.0, patch_factor=1,
                 ave_spectrum=False, log_matrix=True, batch_matrix=False):
        super(HighFocalFrequencyLoss, self).__init__()
        self.loss_weight = loss_weight
        self.alpha = alpha
        self.patch_factor = patch_factor
        self.ave_spectrum = ave_spectrum
        self.log_matrix = log_matrix
        self.batch_matrix = batch_matrix
        self.level = level
        self.tau = tau
        # BUG FIX: the original called .cuda() here, which crashed on CPU-only
        # hosts and broke the docstring example. Registered as a plain
        # submodule; forward() moves its filter buffers to the input's device.
        self.DWT = WaveletTransform2D()

    def tensor2freq(self, x):
        """Crop x (N, C, H, W) into patch_factor^2 patches and FFT each patch.

        Returns:
            Real tensor of shape (N, P, C, h, w, 2) holding (real, imag) parts.
        """
        # crop image patches
        patch_factor = self.patch_factor
        _, _, h, w = x.shape
        assert h % patch_factor == 0 and w % patch_factor == 0, (
            'Patch factor should be divisible by image height and width')
        patch_list = []
        patch_h = h // patch_factor
        patch_w = w // patch_factor
        for i in range(patch_factor):
            for j in range(patch_factor):
                patch_list.append(x[:, :, i * patch_h:(i + 1) * patch_h, j * patch_w:(j + 1) * patch_w])
        # stack to patch tensor
        y = torch.stack(patch_list, 1)
        # perform 2D DFT (real-to-complex, orthonormalization)
        if IS_HIGH_VERSION:
            freq = torch.fft.fft2(y, norm='ortho')
            freq = torch.stack([freq.real, freq.imag], -1)
        else:
            # legacy API, removed in modern torch releases
            freq = torch.rfft(y, 2, onesided=False, normalized=True)
        return freq

    def build_freq_mask(self, shape, device=None):
        """Mask (H, W) zeroing discs of radius tau*max(H, W) at the four corners.

        In an unshifted spectrum the corners hold the lowest frequencies, so
        this keeps only the high-frequency terms.

        Args:
            shape: any shape whose last two dims are (H, W).
            device: device for the mask (new optional arg with a backward-
                compatible default; the original hard-coded .cuda()).
        """
        H, W = shape[-2:]
        radius = self.tau * max(H, W)
        # Broadcasted coordinate grids: avoids torch.meshgrid and its
        # indexing-keyword differences across torch versions.
        Y = torch.arange(H, dtype=torch.float32, device=device).view(-1, 1)
        X = torch.arange(W, dtype=torch.float32, device=device).view(1, -1)
        mask = torch.ones((H, W), dtype=torch.float32, device=device)
        for cy, cx in ((0, 0), (0, W - 1), (H - 1, 0), (H - 1, W - 1)):
            distance = torch.sqrt((X - cx) ** 2 + (Y - cy) ** 2)
            mask[distance <= radius] = 0
        return mask

    def loss_formulation(self, recon_freq, real_freq, matrix=None):
        """Masked, spectrum-weighted squared distance between two spectra."""
        # spectrum weight matrix
        if matrix is not None:
            # if the matrix is predefined
            weight_matrix = matrix.detach()
        else:
            # if the matrix is calculated online: continuous, dynamic, based on current Euclidean distance
            matrix_tmp = (recon_freq - real_freq) ** 2
            matrix_tmp = torch.sqrt(matrix_tmp[..., 0] + matrix_tmp[..., 1]) ** self.alpha
            # whether to adjust the spectrum weight matrix by logarithm
            if self.log_matrix:
                matrix_tmp = torch.log(matrix_tmp + 1.0)
            # whether to calculate the spectrum weight matrix using batch-based statistics
            if self.batch_matrix:
                matrix_tmp = matrix_tmp / matrix_tmp.max()
            else:
                matrix_tmp = matrix_tmp / matrix_tmp.max(-1).values.max(-1).values[:, :, :, None, None]
            matrix_tmp[torch.isnan(matrix_tmp)] = 0.0
            matrix_tmp = torch.clamp(matrix_tmp, min=0.0, max=1.0)
            weight_matrix = matrix_tmp.clone().detach()
        assert weight_matrix.min().item() >= 0 and weight_matrix.max().item() <= 1, (
            'The values of spectrum weight matrix should be in the range [0, 1], '
            'but got Min: %.10f Max: %.10f' % (weight_matrix.min().item(), weight_matrix.max().item()))
        # frequency distance using (squared) Euclidean distance
        tmp = (recon_freq - real_freq) ** 2
        freq_distance = tmp[..., 0] + tmp[..., 1]
        # dynamic spectrum weighting (Hadamard product); build the mask on the
        # inputs' device (the original hard-coded .cuda())
        mask = self.build_freq_mask(weight_matrix.shape, device=freq_distance.device)
        loss = weight_matrix * freq_distance * mask
        return torch.mean(loss)

    def frequency_loss(self, pred, target, matrix=None):
        """Forward function to calculate focal frequency loss.

        Args:
            pred (torch.Tensor): of shape (N, C, H, W). Predicted tensor.
            target (torch.Tensor): of shape (N, C, H, W). Target tensor.
            matrix (torch.Tensor, optional): Element-wise spectrum weight matrix.
                Default: None (If set to None: calculated online, dynamic).
        """
        pred_freq = self.tensor2freq(pred)
        target_freq = self.tensor2freq(target)
        # whether to use minibatch average spectrum
        if self.ave_spectrum:
            pred_freq = torch.mean(pred_freq, 0, keepdim=True)
            target_freq = torch.mean(target_freq, 0, keepdim=True)
        return self.loss_formulation(pred_freq, target_freq, matrix)

    def forward(self, pred, target, matrix=None, **kwargs):
        """Accumulate the loss over self.level successive DWT low-bands.

        Accepts ``reshape=True`` in kwargs to fold (b, t, c, h, w) video
        tensors into the batch dimension first.
        """
        # BUG FIX: the original indexed kwargs["reshape"] directly, raising
        # KeyError whenever the flag was omitted; default to False instead.
        if kwargs.get("reshape", False) is True:
            pred = rearrange(pred, 'b t c h w -> (b t) c h w')
            target = rearrange(target, 'b t c h w -> (b t) c h w')
        # Keep the wavelet filter buffers on the same device as the inputs
        # (no-op when already there).
        self.DWT.to(pred.device)
        loss = 0
        for _level in range(self.level):
            # Keep only the LL sub-band at each level.
            pred, _, _, _ = self.DWT(pred)
            target, _, _, _ = self.DWT(target)
            loss += self.frequency_loss(pred, target, matrix)
        return loss * self.loss_weight
class WaveletTransform2D(nn.Module):
    """Single-level two-dimensional wavelet transform (and its inverse).

    loss = nn.MSELoss()
    data = torch.rand(1, 3, 128, 256)
    DWT = WaveletTransform2D()
    IDWT = WaveletTransform2D(inverse=True)
    LL, LH, HL, HH = DWT(data)
    recdata = IDWT([LL, LH, HL, HH])
    print(loss(data, recdata))
    """

    def __init__(self, inverse=False, wavelet="haar", mode="constant"):
        super(WaveletTransform2D, self).__init__()
        self.mode = mode
        self.inverse = inverse
        wavelet = pywt.Wavelet(wavelet)
        if isinstance(wavelet, tuple):
            dec_lo, dec_hi, rec_lo, rec_hi = wavelet
        else:
            dec_lo, dec_hi, rec_lo, rec_hi = wavelet.filter_bank
        if inverse:
            lo = torch.tensor(rec_lo).unsqueeze(0)
            hi = torch.tensor(rec_hi).unsqueeze(0)
        else:
            # Decomposition convolves with the time-reversed analysis filters.
            lo = torch.tensor(dec_lo).flip(-1).unsqueeze(0)
            hi = torch.tensor(dec_hi).flip(-1).unsqueeze(0)
        self.build_filters(lo, hi)

    def build_filters(self, lo, hi):
        """Assemble the four separable 2D kernels and register them as a buffer."""
        self.dim_size = lo.shape[-1]
        # Band order is LL, LH, HL, HH (first factor spans rows, second columns).
        bands = [self.outer(row, col)
                 for row, col in ((lo, lo), (hi, lo), (lo, hi), (hi, hi))]
        kernels = torch.stack(bands, dim=0).unsqueeze(1)
        self.register_buffer('filters', kernels)  # [4, 1, height, width]

    def outer(self, a: torch.Tensor, b: torch.Tensor):
        """Outer product of two (flattenable) vectors, like numpy.outer."""
        return torch.reshape(a, [-1, 1]) * torch.reshape(b, [1, -1])

    def get_pad(self, data_len: int, filter_len: int):
        """Per-side padding amounts; the right side gets one extra when the
        signal length is odd so the padded length is even."""
        base = (2 * filter_len - 3) // 2
        padr, padl = base, base
        if data_len % 2 != 0:
            padr += 1
        return padr, padl

    def adaptive_pad(self, data):
        """Pad the last two dims so strided convolution tiles them exactly."""
        pad_bottom, pad_top = self.get_pad(data.shape[-2], self.dim_size)
        pad_right, pad_left = self.get_pad(data.shape[-1], self.dim_size)
        return torch.nn.functional.pad(
            data, [pad_left, pad_right, pad_top, pad_bottom], mode=self.mode)

    def forward(self, data):
        if self.inverse:
            # data is the list [LL, LH, HL, HH]; interleave sub-bands per channel.
            b, c, h, w = data[0].shape
            stacked = torch.stack(data, dim=2).reshape(b, -1, h, w)
            return torch.nn.functional.conv_transpose2d(
                stacked, self.filters.repeat(c, 1, 1, 1), stride=2, groups=c)
        b, c, h, w = data.shape
        padded = self.adaptive_pad(data)
        # One grouped, stride-2 convolution per sub-band kernel.
        return [torch.nn.functional.conv2d(padded, kernel.repeat(c, 1, 1, 1),
                                           stride=2, groups=c)
                for kernel in self.filters]
class WaveletTransform3D(nn.Module):
    """Single-level three-dimensional wavelet transform (and its inverse).

    Example:
        loss = nn.MSELoss()
        data = torch.rand(1, 3, 10, 128, 256)
        DWT = WaveletTransform3D()
        IDWT = WaveletTransform3D(inverse=True)
        LLL, LLH, LHL, LHH, HLL, HLH, HHL, HHH = DWT(data)
        recdata = IDWT([LLL, LLH, LHL, LHH, HLL, HLH, HHL, HHH])
        print(loss(data, recdata))
    """

    def __init__(self, inverse=False, wavelet="haar", mode="constant"):
        super(WaveletTransform3D, self).__init__()
        self.mode = mode
        self.inverse = inverse
        wavelet = pywt.Wavelet(wavelet)
        if isinstance(wavelet, tuple):
            dec_lo, dec_hi, rec_lo, rec_hi = wavelet
        else:
            dec_lo, dec_hi, rec_lo, rec_hi = wavelet.filter_bank
        if inverse:
            lo = torch.tensor(rec_lo).unsqueeze(0)
            hi = torch.tensor(rec_hi).unsqueeze(0)
        else:
            # Decomposition convolves with the time-reversed analysis filters.
            lo = torch.tensor(dec_lo).flip(-1).unsqueeze(0)
            hi = torch.tensor(dec_hi).flip(-1).unsqueeze(0)
        self.build_filters(lo, hi)

    def build_filters(self, lo, hi):
        """Assemble the eight separable 3D kernels and register them as a buffer."""
        self.dim_size = lo.shape[-1]
        size = [self.dim_size] * 3
        # Band order LLL, LLH, LHL, LHH, HLL, HLH, HHL, HHH — the comprehension
        # enumerates (a, b, c) with lo before hi on each axis, matching the
        # original explicit listing.
        bands = [self.outer(a, self.outer(b, c)).reshape(size)
                 for a in (lo, hi) for b in (lo, hi) for c in (lo, hi)]
        kernels = torch.stack(bands, dim=0).unsqueeze(1)
        self.register_buffer('filters', kernels)  # [8, 1, length, height, width]

    def outer(self, a: torch.Tensor, b: torch.Tensor):
        """Outer product of two (flattenable) vectors, like numpy.outer."""
        return torch.reshape(a, [-1, 1]) * torch.reshape(b, [1, -1])

    def get_pad(self, data_len: int, filter_len: int):
        """Per-side padding amounts; the right side gets one extra when the
        signal length is odd so the padded length is even."""
        base = (2 * filter_len - 3) // 2
        padr, padl = base, base
        if data_len % 2 != 0:
            padr += 1
        return padr, padl

    def adaptive_pad(self, data):
        """Pad the last three dims so strided convolution tiles them exactly."""
        pad_back, pad_front = self.get_pad(data.shape[-3], self.dim_size)
        pad_bottom, pad_top = self.get_pad(data.shape[-2], self.dim_size)
        pad_right, pad_left = self.get_pad(data.shape[-1], self.dim_size)
        return torch.nn.functional.pad(
            data, [pad_left, pad_right, pad_top, pad_bottom, pad_front, pad_back],
            mode=self.mode)

    def forward(self, data):
        if self.inverse:
            # data is the list of eight sub-bands; interleave them per channel.
            b, c, t, h, w = data[0].shape
            stacked = torch.stack(data, dim=2).reshape(b, -1, t, h, w)
            return torch.nn.functional.conv_transpose3d(
                stacked, self.filters.repeat(c, 1, 1, 1, 1), stride=2, groups=c)
        b, c, t, h, w = data.shape
        padded = self.adaptive_pad(data)
        # One grouped, stride-2 convolution per sub-band kernel.
        return [torch.nn.functional.conv3d(padded, kernel.repeat(c, 1, 1, 1, 1),
                                           stride=2, groups=c)
                for kernel in self.filters]
class FrequencyAttention(nn.Module):
    """Directional attention gate over a (N, C, H, W) feature map.

    Pools along each spatial axis separately, mixes the two pooled profiles
    through a shared 1x1 bottleneck, and produces per-row and per-column
    sigmoid gates that multiplicatively modulate the input.

    Args:
        in_dim: input channel count.
        out_dim: channel count of the produced gates (must match the input's
            channels for the final elementwise product to broadcast).
        reduction: bottleneck reduction ratio (hidden width >= 8).
    """

    def __init__(self, in_dim, out_dim, reduction=32):
        super(FrequencyAttention, self).__init__()
        # Pool over one spatial axis at a time: (N,C,H,1) and (N,C,1,W).
        self.avgpool_h = nn.AdaptiveAvgPool2d((None, 1))
        self.avgpool_w = nn.AdaptiveAvgPool2d((1, None))
        hidden_dim = max(8, in_dim // reduction)
        self.conv1 = nn.Conv2d(in_dim, hidden_dim, kernel_size=1, stride=1, padding=0)
        self.bn1 = nn.BatchNorm2d(hidden_dim)
        # IMPROVEMENT: torch's built-in nn.Hardswish computes the same
        # x * relu6(x + 3) / 6 as the timm HardSwish previously used here,
        # removing an unnecessary third-party symbol from this block.
        self.act = nn.Hardswish(inplace=True)
        self.conv_h = nn.Conv2d(hidden_dim, out_dim, kernel_size=1, stride=1, padding=0)
        self.conv_w = nn.Conv2d(hidden_dim, out_dim, kernel_size=1, stride=1, padding=0)

    def forward(self, x):
        identity = x
        n, c, h, w = x.size()
        x_h = self.avgpool_h(x)                         # b c h 1
        x_w = self.avgpool_w(x).permute(0, 1, 3, 2)     # b c w 1
        # Concatenate the two profiles so they share one bottleneck transform.
        y = torch.cat([x_h, x_w], dim=2)                # b c (h+w) 1
        y = self.conv1(y)
        y = self.bn1(y)
        y = self.act(y)
        # Split back into the per-axis profiles and form sigmoid gates.
        x_h, x_w = torch.split(y, [h, w], dim=2)
        x_w = x_w.permute(0, 1, 3, 2)
        a_h = self.conv_h(x_h).sigmoid()
        a_w = self.conv_w(x_w).sigmoid()
        out = identity * a_w * a_h
        return out
class TF_AwareBlock(nn.Module):
    """Time/frequency-aware mixing block.

    Branch 1 (spatial): two orthogonal depthwise strip convolutions
    (large-kernel path) plus one small depthwise square convolution, each
    gated by a FrequencyAttention module. Branch 2: an SE inverted-residual
    mixer. Both branches are residual with per-channel LayerScale and
    stochastic depth (DropPath).
    """

    def __init__(self, dim, mlp_ratio=4., drop=0., ls_init_value=1e-2, drop_path=0.1, large_kernel=51, small_kernel=5):
        super().__init__()
        self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
        self.norm1 = nn.BatchNorm2d(dim)
        self.norm2 = nn.BatchNorm2d(dim)
        # Two orthogonal strip convolutions approximating one large dense kernel.
        self.lk1 = nn.Sequential(
            nn.Conv2d(dim, dim, kernel_size=(large_kernel, 5), groups=dim, padding="same"),
            nn.BatchNorm2d(dim)
        )
        self.lk2 = nn.Sequential(
            nn.Conv2d(dim, dim, kernel_size=(5, large_kernel), groups=dim, padding="same"),
            nn.BatchNorm2d(dim)
        )
        self.sk = nn.Sequential(
            nn.Conv2d(dim, dim, kernel_size=(small_kernel, small_kernel), groups=dim, padding="same"),
            nn.BatchNorm2d(dim)
        )
        self.low_frequency_attn = FrequencyAttention(in_dim=dim, out_dim=dim, reduction=4)
        self.high_frequency_attn = FrequencyAttention(in_dim=dim, out_dim=dim, reduction=4)
        self.temporal_mixer = InvertedResidual(in_chs=dim, out_chs=dim, dw_kernel_size=7, exp_ratio=mlp_ratio,
                                               se_layer=partial(SqueezeExcite, rd_ratio=0.25), noskip=True)
        # One LayerScale gain per residual branch.
        self.layer_scale_1 = nn.Parameter(ls_init_value * torch.ones((dim)), requires_grad=True)
        self.layer_scale_2 = nn.Parameter(ls_init_value * torch.ones((dim)), requires_grad=True)

    @torch.jit.ignore
    def no_weight_decay(self):
        # LayerScale gains are excluded from weight decay.
        return {'layer_scale_1', 'layer_scale_2'}

    def forward(self, x):
        attn = self.norm1(x)
        x = x + self.drop_path(self.layer_scale_1.unsqueeze(-1).unsqueeze(-1)
                               * (self.low_frequency_attn(self.lk1(attn) + self.lk2(attn))
                                  + self.high_frequency_attn(self.sk(attn))))
        # BUG FIX: this branch previously reused layer_scale_1, leaving
        # layer_scale_2 created (and exempted from weight decay above) but
        # never applied anywhere.
        x = x + self.drop_path(self.layer_scale_2.unsqueeze(-1).unsqueeze(-1)
                               * self.temporal_mixer(self.norm2(x)))
        return x
class TF_AwareBlocks(nn.Module):
    """A stage of stacked TF_AwareBlock modules operating on (b, c*t, h, w).

    Three mutually exclusive configurations:
      * use_hid=True, use_bottleneck=None: fuse an incoming skip via a 3x3
        conv before the blocks; returns the 5D feature.
      * use_hid=False, use_bottleneck=None: plain stack; returns the 5D
        feature plus its flattened form as a skip.
      * use_bottleneck set: bottleneck stage; with "decompose" the blocks run
        on the LLL wavelet sub-band and the result is re-synthesized.
    """

    def __init__(self, dim, num_blocks, drop_path, use_bottleneck=None, use_hid=False,
                 mlp_ratio=4., drop=0., ls_init_value=1e-2, large_kernel=51, small_kernel=5):
        super().__init__()
        assert len(drop_path) == num_blocks, "drop_path list doesn't match num_blocks"
        self.use_hid = use_hid
        self.use_bottleneck = use_bottleneck
        self.blocks = nn.Sequential(*[
            TF_AwareBlock(dim, mlp_ratio, drop, ls_init_value, drop_path[i], large_kernel, small_kernel)
            for i in range(num_blocks)
        ])
        # 3x3 fusion conv, present only when a hidden skip is concatenated in.
        self.concat_block = nn.Conv2d(dim * 2, dim, 3, 1, 1) if use_hid else None
        decompose = use_bottleneck == "decompose"
        self.DWT = WaveletTransform3D(inverse=False) if decompose else None
        self.IDWT = WaveletTransform3D(inverse=True) if decompose else None

    def forward(self, x, skip=None):  # x: (b, c, t, h, w)
        if self.use_bottleneck is not None:
            if self.use_bottleneck == "decompose":
                LLL, LLH, LHL, LHH, HLL, HLH, HHL, HHH = self.DWT(x)
            else:
                LLL, LLH, LHL, LHH, HLL, HLH, HHL, HHH = x, None, None, None, None, None, None, None
            b, c, t, h, w = LLL.shape
            feat = rearrange(LLL, 'b c t h w -> b (c t) h w')
            feat = self.blocks(feat)
            feat = rearrange(feat, 'b (c t) h w -> b c t h w', t=t)
            if self.use_bottleneck == "decompose":
                return self.IDWT([feat, LLH, LHL, LHH, HLL, HLH, HHL, HHH])
            return feat
        b, c, t, h, w = x.shape
        flat = rearrange(x, 'b c t h w -> b (c t) h w')
        if self.concat_block is not None:
            fused = self.blocks(self.concat_block(torch.cat([flat, skip], dim=1)))
            return rearrange(fused, 'b (c t) h w -> b c t h w', t=t)
        hid = self.blocks(flat)
        return rearrange(hid, 'b (c t) h w -> b c t h w', t=t), hid
class Wavelet_3D_Embedding(nn.Module):
    """Wavelet-based 3D downsampling with a raw-input embedding branch.

    The main branch convolves x, wavelet-decomposes it, keeps the two bands
    whose first two filters are low-pass, and projects them; the embedding
    branch decomposes x_emb and folds all eight bands in. The six remaining
    high-frequency bands are returned as a skip for later reconstruction.
    """

    def __init__(self, in_dim, out_dim, emb_dim=None):
        super().__init__()
        emb_dim = in_dim if emb_dim is None else emb_dim
        self.conv_0 = nn.Sequential(
            nn.Conv3d(in_dim, in_dim, kernel_size=(3, 3, 3), stride=(1, 1, 1), padding=(1, 1, 1),),
            nn.BatchNorm3d(in_dim),
            nn.GELU(),
        )
        self.conv_1 = nn.Sequential(
            nn.Conv3d(in_dim, out_dim, kernel_size=(3, 3, 3), stride=(1, 1, 1), padding=(1, 1, 1),),
            nn.BatchNorm3d(out_dim),
            nn.GELU(),
        )
        self.conv_emb = nn.Conv3d(emb_dim * 4, out_dim, kernel_size=(3, 3, 3), stride=(1, 1, 1), padding=(1, 1, 1),)
        self.DWT = WaveletTransform3D(inverse=False)

    def forward(self, x, x_emb=None):
        # Embedding branch: group bands by the low/high split on the last
        # filter axis (channels), then stack the two groups along dim 2.
        # NOTE(review): the original named these lo_temp/hi_temp — verify the
        # intended axis semantics against the filter layout.
        LLL, LLH, LHL, LHH, HLL, HLH, HHL, HHH = self.DWT(x_emb)
        group_lo = torch.cat([LLL, LHL, HLL, HHL], dim=1)
        group_hi = torch.cat([LLH, LHH, HLH, HHH], dim=1)
        emb = self.conv_emb(torch.cat([group_lo, group_hi], dim=2))
        # Downsampling branch: keep the two L?-leading bands, project them,
        # and hand the other six bands back as a reconstruction skip.
        feat = self.conv_0(x)
        LLL, LLH, LHL, LHH, HLL, HLH, HHL, HHH = self.DWT(feat)
        spatio_lo_coeffs = torch.cat([LLL, LLH], dim=2)
        spatio_hi_coeffs = torch.cat([LHL, LHH, HLL, HLH, HHL, HHH], dim=1)
        out = self.conv_1(spatio_lo_coeffs)
        return (out + emb), spatio_hi_coeffs
class Wavelet_3D_Reconstruction(nn.Module):
    """Inverse-wavelet 3D upsampling: project the low-band feature, refine the
    six high-frequency skip bands with a grouped conv, and synthesize."""

    def __init__(self, in_dim, out_dim, hi_dim):
        super().__init__()
        self.conv_0 = nn.Sequential(
            nn.Conv3d(in_dim, out_dim, kernel_size=(3, 3, 3), stride=(1, 1, 1), padding=(1, 1, 1),),
            nn.BatchNorm3d(out_dim),
            nn.GELU(),
        )
        # groups=6 keeps the six sub-bands independent through the projection.
        self.conv_hi = nn.Sequential(
            nn.Conv3d(int(hi_dim * 6), int(out_dim * 6), kernel_size=(3, 3, 3), stride=(1, 1, 1), padding=(1, 1, 1), groups=6),
            nn.BatchNorm3d(out_dim * 6),
            nn.GELU(),
        )
        self.IDWT = WaveletTransform3D(inverse=True)

    def forward(self, x, skip_hi=None):
        # Recover LLL/LLH from the projected feature (they were stacked along dim 2).
        LLL, LLH = torch.chunk(self.conv_0(x), chunks=2, dim=2)
        hi_bands = torch.chunk(self.conv_hi(skip_hi), chunks=6, dim=1)
        return self.IDWT([LLL, LLH, *hi_bands])
class WaST_level1(nn.Module):
    """Single-level WaST: encode -> translator -> wavelet down -> bottleneck
    translator -> wavelet up -> translator (with skip) -> decode.

    Input/output are (b, t, c, h, w) video tensors; in_shape is
    (frame, in_dim, H, W).
    """

    def __init__(self, in_shape, encoder_dim, block_list=[2, 2, 2], drop_path_rate=0.1, mlp_ratio=4., **kwargs):
        super().__init__()
        frame, in_dim, H, W = in_shape
        self.block_list = block_list
        dp_lists = self._split_drop_path(drop_path_rate, block_list)
        self.conv_in = nn.Sequential(
            nn.Conv3d(in_dim, encoder_dim, kernel_size=(3, 3, 3), stride=(1, 1, 1), padding=(1, 1, 1)),
            nn.BatchNorm3d(encoder_dim),
            nn.GELU()
        )
        self.translator1 = TF_AwareBlocks(dim=encoder_dim * frame, num_blocks=block_list[0], drop_path=dp_lists[0], mlp_ratio=mlp_ratio, large_kernel=51, small_kernel=5)
        self.wavelet_embed1 = Wavelet_3D_Embedding(in_dim=encoder_dim, out_dim=encoder_dim * 2, emb_dim=in_dim)  # wavelet_recon2: hi_dim = in_dim
        self.bottleneck_translator = TF_AwareBlocks(dim=encoder_dim * 2 * frame, num_blocks=block_list[1], drop_path=dp_lists[1], use_bottleneck=True, mlp_ratio=mlp_ratio, large_kernel=21, small_kernel=5)
        self.wavelet_recon1 = Wavelet_3D_Reconstruction(in_dim=encoder_dim * 2, out_dim=encoder_dim, hi_dim=encoder_dim)
        self.translator2 = TF_AwareBlocks(dim=encoder_dim * frame, num_blocks=block_list[2], drop_path=dp_lists[2], use_hid=True, mlp_ratio=mlp_ratio, large_kernel=51, small_kernel=5)
        self.conv_out = nn.Sequential(
            nn.BatchNorm3d(encoder_dim),
            nn.GELU(),
            nn.Conv3d(encoder_dim, in_dim, kernel_size=(3, 3, 3), stride=(1, 1, 1), padding=(1, 1, 1))
        )

    @staticmethod
    def _split_drop_path(drop_path_rate, block_list):
        """Linearly spaced stochastic-depth rates, split per translator stage."""
        rates = [r.item() for r in torch.linspace(0, drop_path_rate, sum(block_list))]
        ends = list(accumulate(block_list))
        return [rates[start:end] for start, end in zip([0] + ends, ends)]

    def update_drop_path(self, drop_path_rate):
        """Reassign per-block stochastic-depth probabilities in place."""
        dp_lists = self._split_drop_path(drop_path_rate, self.block_list)
        stages = [self.translator1.blocks, self.bottleneck_translator.blocks, self.translator2.blocks]
        for stage, stage_rates in zip(stages, dp_lists):
            for block, rate in zip(stage, stage_rates):
                block.drop_path.drop_prob = rate

    def forward(self, x):
        x = rearrange(x, 'b t c h w -> b c t h w')
        raw = x  # raw frames feed the wavelet embedding branch
        feat = self.conv_in(x)
        feat, translator_skip = self.translator1(feat)
        feat, hi_skip = self.wavelet_embed1(feat, x_emb=raw)
        feat = self.bottleneck_translator(feat)
        feat = self.wavelet_recon1(feat, hi_skip)
        feat = self.translator2(feat, translator_skip)
        out = self.conv_out(feat)
        return rearrange(out, 'b c t h w -> b t c h w')
if __name__ == "__main__":
    # Smoke test: build the model, push one dummy clip through it, and report
    # a FLOP breakdown. Requires a CUDA device and the fvcore package.
    from fvcore.nn import FlopCountAnalysis, flop_count_table
    # import os
    # os.environ["CUDA_VISIBLE_DEVICES"] = "3"
    model = WaST_level1(in_shape=(4, 2, 32, 32), encoder_dim=20, block_list=[2, 8, 2]).cuda()
    print(model)
    # (batch, frames, channels, H, W) matching in_shape=(frames, channels, H, W)
    dummy_tensor = torch.rand(1, 4, 2, 32, 32).cuda()
    output = model(dummy_tensor)
    print(f"input shape is {dummy_tensor.shape}, output shape is {output.shape}...")
    flops = FlopCountAnalysis(model, dummy_tensor)
    print(flop_count_table(flops))