# U-Past / modules/decoder.py
# Last commit: e8160b2 "Initial space" by lycaoduong (verified)
import torch.nn as nn
import torch
from .blocks.complexblock import CVConvNeXtDBlock, ComplexUpsampleUnet, ComplexUpConstantUNet, ComplexLinearLayer
from .blocks.complexmodule import ComplexConv1D
from .blocks.unetblock import UnetUpBlock
class CVDecoder(nn.Module):
    """Complex-valued U-Net decoder built from ``hidden_dims`` in reverse index order.

    ``hidden_dims`` is split into two parts:

    * a leading run whose neighbouring entries all differ, of length
      ``non_constant_depth``;
    * the remaining tail, of length ``constant_depth``, which starts at the
      first repeated value.

    One block is created per entry, walking the indices from last to first:

    * tail indices get ``ComplexUpConstantUNet(hidden_dims[-1], ...)``;
    * leading-run indices get ``ComplexUpsampleUnet(in_width, hidden_dims[i], ...)``.
      Here ``in_width`` starts at ``hidden_dims[-1]`` and afterwards equals the
      output width of the previously built upsampling block.

    Args:
        hidden_dims: Channel widths; the last entry is used as the latent width.
            ``None`` selects ``[64, 128, 256, 512, 512, 512, 512]``.
        **kwargs: Accepted and ignored.
    """

    def __init__(self, hidden_dims: list = None,
                 **kwargs) -> None:
        super().__init__()
        dims = [64, 128, 256, 512, 512, 512, 512] if hidden_dims is None else hidden_dims

        self.non_constant_depth = self.count_non_constant_hidden_dims(dims)
        self.constant_depth = len(dims) - self.non_constant_depth

        deepest = dims[-1]
        in_width = deepest  # input width handed to the next upsampling block
        blocks = []
        for idx, width in reversed(list(enumerate(dims))):
            if idx < self.non_constant_depth:
                blocks.append(
                    ComplexUpsampleUnet(in_width, width, dilation=(1, 1), padding=(1, 1), output_padding=(1, 1))
                )
                in_width = width
            else:
                # Tail of repeated widths: every block keeps the latent width.
                blocks.append(
                    ComplexUpConstantUNet(deepest, dilation=(1, 1), padding=(1, 1), output_padding=(1, 1))
                )

        # Maps the lateral consumed at step ``constant_depth`` from
        # ``hidden_dims[-1]`` to ``hidden_dims[-1] // 2`` channels.
        self.lateral_projection = ComplexLinearLayer(deepest, deepest // 2)
        self.complex_decoder = nn.ModuleList(blocks)

    def count_non_constant_hidden_dims(self, hidden_dims):
        """Return the length of the leading run of ``hidden_dims`` whose neighbours differ.

        Counting stops at the first position where an entry equals the one
        before it. Empty and single-element inputs both yield ``1``.
        """
        run_length = 1
        for previous, current in zip(hidden_dims, hidden_dims[1:]):
            if current == previous:
                break
            run_length += 1
        return run_length

    def forward(self, x, laterals=None):
        """Apply every decoder block in sequence.

        Args:
            x: Input to the first block.
            laterals: Optional sequence of skip inputs. Block ``k`` receives
                ``laterals[-(k + 1)]``. The entry used at step
                ``self.constant_depth`` is first passed through
                ``self.lateral_projection``. When ``None``, every block
                receives ``None``.

        Returns:
            The output of the last block.
        """
        for depth, block in enumerate(self.complex_decoder):
            skip = None
            if laterals is not None:
                skip = laterals[-(depth + 1)]
                if depth == self.constant_depth:
                    skip = self.lateral_projection(skip)
            x = block(x, skip)
        return x
class ViTUnetDecoder(nn.Module):
    """U-Net style decoder made of four chained ``UnetUpBlock`` stages.

    Channel widths are derived from ``token_size = feature_size[0] // patch_size``
    and run ``hidden_size -> 8*t -> 4*t -> 2*t -> t``, where ``t = token_size``.
    With the defaults this is 768 -> 128 -> 64 -> 32 -> 16.

    Args:
        feature_size: ``(H, W)`` pair; only square sizes are supported.
        patch_size: Divisor used to derive ``token_size`` from ``H``.
        hidden_size: Input channel count of the first stage.
        num_layers: Stored on the instance; not used by this class.
        kernel_size: Forwarded to every ``UnetUpBlock``.
        stride: Forwarded to every ``UnetUpBlock``.
        **kwargs: Accepted and ignored.

    Raises:
        ValueError: If ``H != W``.
    """

    # ``feature_size`` defaults to an immutable tuple (it is only unpacked), so no
    # shared mutable default exists; lists passed by callers still work.
    def __init__(self, feature_size=(256, 256), patch_size=16, hidden_size=768, num_layers=4, kernel_size=3, stride=1, **kwargs):
        super().__init__()
        H, W = feature_size
        # Explicit check rather than ``assert`` so validation survives ``python -O``.
        if H != W:
            raise ValueError("Currently only supports square feature maps")
        token_size = H // patch_size  # e.g., 256 // 16 = 16 tokens per side
        self.hidden_size = hidden_size
        self.token_size = token_size
        self.num_layers = num_layers
        # Decoder stages (registered in this order: decoder5, decoder4, decoder3, decoder2).
        self.decoder5 = UnetUpBlock(in_channels=hidden_size, out_channels=self.token_size * 8, kernel_size=kernel_size, stride=stride)  # x8 -> 128
        self.decoder4 = UnetUpBlock(in_channels=self.token_size * 8, out_channels=self.token_size * 4, kernel_size=kernel_size, stride=stride)  # x4 -> 64
        self.decoder3 = UnetUpBlock(in_channels=self.token_size * 4, out_channels=self.token_size * 2, kernel_size=kernel_size, stride=stride)  # x2 -> 32
        self.decoder2 = UnetUpBlock(in_channels=self.token_size * 2, out_channels=self.token_size, kernel_size=kernel_size, stride=stride)  # x1 -> 16

    def forward(self, x, residuals=None):
        """Run ``x`` through decoder5 -> decoder4 -> decoder3 -> decoder2.

        Args:
            x: Input to ``decoder5``.
            residuals: Optional sequence with at least four skip inputs.
                ``residuals[-1]`` feeds ``decoder5``, ``[-2]`` feeds
                ``decoder4``, ``[-3]`` feeds ``decoder3`` and ``[-4]`` feeds
                ``decoder2``. When ``None``, each stage is called with ``x`` only.

        Returns:
            The output of ``decoder2``.
        """
        dec4 = x
        if residuals is not None:
            dec3 = self.decoder5(dec4, residuals[-1])  # enc4
            dec2 = self.decoder4(dec3, residuals[-2])  # enc3
            dec1 = self.decoder3(dec2, residuals[-3])  # enc2
            out = self.decoder2(dec1, residuals[-4])  # enc1
        else:
            dec3 = self.decoder5(dec4)
            dec2 = self.decoder4(dec3)
            dec1 = self.decoder3(dec2)
            out = self.decoder2(dec1)
        return out
class CVConvNextDecoder(nn.Module):
    """Stack of ``CVConvNeXtDBlock`` layers followed by a shared final LayerNorm.

    Args:
        input_dims: ``in_channels`` of ``enc1``, which encodes the flattened ``x_in``.
        hidden_dims: ``dim`` of every ``CVConvNeXtDBlock``, ``out_channels`` of
            ``enc1`` and ``normalized_shape`` of the final LayerNorm.
        intermediate_dim: Forwarded to every ``CVConvNeXtDBlock``.
        num_layers: Number of ``CVConvNeXtDBlock`` layers.
        complex_axis: Forwarded to every block. Also the axis ``forward`` uses
            to split the output into real and imaginary halves and to rejoin them.
        layer_scale_init_value: Forwarded to every block. ``None`` (default)
            means ``1 / num_layers``. Any other value, including ``0``, is used
            exactly as given.
        **kwargs: Accepted and ignored.
    """

    def __init__(self,
                 input_dims=256,
                 hidden_dims=512,
                 intermediate_dim=1356,
                 num_layers=4,
                 complex_axis=1,
                 layer_scale_init_value=None,
                 **kwargs):
        super().__init__()
        layer_scale_init_value = self._resolve_layer_scale(layer_scale_init_value, num_layers)
        self.blocks = nn.ModuleList(
            [
                CVConvNeXtDBlock(
                    dim=hidden_dims,
                    intermediate_dim=intermediate_dim,
                    layer_scale_init_value=layer_scale_init_value,
                    complex_axis=complex_axis,
                )
                for _ in range(num_layers)
            ]
        )
        self.final_layer_norm = nn.LayerNorm(hidden_dims, eps=1e-6)
        self.complex_axis = complex_axis
        # complex_axis is fixed to 1 here rather than self.complex_axis; presumably
        # because forward() flattens x_in to [B, C * F, T] — confirm.
        self.enc1 = ComplexConv1D(in_channels=input_dims, out_channels=hidden_dims, kernel_size=3, padding=1, complex_axis=1)
        self.num_layers = num_layers

    @staticmethod
    def _resolve_layer_scale(layer_scale_init_value, num_layers):
        """Return ``layer_scale_init_value``, or ``1 / num_layers`` when it is ``None``.

        An ``is None`` test is used instead of ``or``, so an explicit ``0`` is
        not silently replaced by the default.
        """
        if layer_scale_init_value is None:
            return 1 / num_layers
        return layer_scale_init_value

    def forward(self, x, x_in=None, laterals=None):
        """Run ``x`` through all blocks, then layer-normalize its real/imag halves.

        Args:
            x: Input to the first block.
            x_in: Optional tensor unpacked as ``[B, C, F, T]`` and reshaped to
                ``[B, C * F, T]``. It is encoded by ``enc1`` only when
                ``laterals`` is given, so it must be provided in that case.
            laterals: Optional sequence of skip inputs. The entries at
                ``num_layers // 4 * k - 1`` for ``k = 1, 2, 3`` are combined with
                ``enc1(x_in)`` into ``residuals``. Block ``i`` then receives
                ``residuals[-(i + 1)]``. When ``None``, every block receives ``None``.

        Returns:
            ``x`` after all blocks. Its two halves along ``self.complex_axis``
            are each normalized by the shared LayerNorm over axis 1 and then
            concatenated back together.
        """
        if x_in is not None:
            # inputs: [B, 2, F, T]
            B, C, F, T = x_in.shape  # C = 2
            # [B, 2, F, T] -> [B, C * F, T]
            x_in = x_in.reshape(B, C * F, T)

        residuals = None
        if laterals is not None:
            # NOTE(review): residuals has exactly 4 entries. With num_layers > 4,
            # block 4 indexes residuals[-5] and raises IndexError. With
            # num_layers < 4, all three lateral indices collapse to laterals[-1].
            step = self.num_layers // 4
            residuals = [
                self.enc1(x_in),
                laterals[step * 1 - 1],
                laterals[step * 2 - 1],
                laterals[step * 3 - 1],
            ]

        for i, layer in enumerate(self.blocks):
            residual = residuals[-i - 1] if residuals is not None else None
            x = layer(x, residual)

        real, imag = torch.chunk(x, 2, dim=self.complex_axis)  # split real and imaginary parts
        # One shared LayerNorm for both halves; transpose(1, 2) moves axis 1 last so
        # LayerNorm normalizes over it, then transposes back.
        real = self.final_layer_norm(real.transpose(1, 2)).transpose(1, 2)
        imag = self.final_layer_norm(imag.transpose(1, 2)).transpose(1, 2)
        return torch.cat([real, imag], dim=self.complex_axis)