from typing import Optional

import torch
import torch.nn as nn

from .blocks.complexblock import (
    CVConvNeXtDBlock,
    ComplexLinearLayer,
    ComplexUpConstantUNet,
    ComplexUpsampleUnet,
)
from .blocks.complexmodule import ComplexConv1D
from .blocks.unetblock import UnetUpBlock


class CVDecoder(nn.Module):
    """Decoder made of a stack of complex-valued upsampling blocks.

    ``hidden_dims`` is walked from its last entry to its first. Entries at
    index ``>= non_constant_depth`` get a ``ComplexUpConstantUNet`` block
    built on the latent width (``hidden_dims[-1]``); the remaining entries
    get a ``ComplexUpsampleUnet`` block mapping the previously emitted width
    to ``hidden_dims[i]``.

    Args:
        hidden_dims: Channel widths, ordered from the output side to the
            latent side. Defaults to ``[64, 128, 256, 512, 512, 512, 512]``.
        **kwargs: Accepted and ignored.
    """

    def __init__(self, hidden_dims: Optional[list] = None, **kwargs) -> None:
        super().__init__()
        if hidden_dims is None:
            hidden_dims = [64, 128, 256, 512, 512, 512, 512]

        self.non_constant_depth = self.count_non_constant_hidden_dims(hidden_dims)
        self.constant_depth = len(hidden_dims) - self.non_constant_depth

        modules = []
        latent_dim = hidden_dims[-1]
        pre_h_dim = latent_dim

        # Build the decoder in reverse: latent side first, output side last.
        for i in reversed(range(len(hidden_dims))):
            h_dim = hidden_dims[i]
            if i >= self.non_constant_depth:
                # Constant-width tail: the block keeps the latent width.
                dec_block = ComplexUpConstantUNet(
                    latent_dim,
                    dilation=(1, 1),
                    padding=(1, 1),
                    output_padding=(1, 1),
                )
            else:
                dec_block = ComplexUpsampleUnet(
                    pre_h_dim,
                    h_dim,
                    dilation=(1, 1),
                    padding=(1, 1),
                    output_padding=(1, 1),
                )
                pre_h_dim = h_dim
            modules.append(dec_block)

        # Adjusts the lateral fed to the first non-constant block (see forward).
        self.lateral_projection = ComplexLinearLayer(
            hidden_dims[-1], hidden_dims[-1] // 2
        )
        self.complex_decoder = nn.ModuleList(modules)

    def count_non_constant_hidden_dims(self, hidden_dims):
        """Return the length of the leading run without equal neighbours.

        Counting starts at 1 and stops at the first index whose value equals
        its predecessor, e.g. ``[64, 128, 256, 512, 512]`` gives 4. An empty
        or single-element list gives 1.
        """
        count = 1
        for previous, current in zip(hidden_dims, hidden_dims[1:]):
            if current == previous:
                break
            count += 1
        return count

    def forward(self, x, laterals=None):
        """Run ``x`` through every decoder block in order.

        Args:
            x: Input passed to the first block.
            laterals: Optional sequence of skip inputs, consumed from the
                end: block ``i`` receives ``laterals[-i - 1]``. When omitted,
                every block receives ``None`` as its second argument.

        Returns:
            The output of the last decoder block.
        """
        for i, layer in enumerate(self.complex_decoder):
            residual = None
            if laterals is not None:
                residual = laterals[-i - 1]
                # Block index `constant_depth` is the first non-constant
                # block; only its lateral goes through the projection.
                if i == self.constant_depth:
                    residual = self.lateral_projection(residual)
            x = layer(x, residual)
        return x


class ViTUnetDecoder(nn.Module):
    """Four-stage U-Net style decoder built from ``UnetUpBlock`` layers.

    Channel widths are derived from ``token_size = H // patch_size``: the
    stages map ``hidden_size -> 8 * token_size -> 4 * token_size ->
    2 * token_size -> token_size``.

    Args:
        feature_size: ``(H, W)`` pair; both values must be equal.
        patch_size: Divisor used to derive ``token_size`` from ``H``.
        hidden_size: Input channel count of the first stage.
        num_layers: Stored on the instance; not otherwise used here.
        kernel_size: Forwarded to every ``UnetUpBlock``.
        stride: Forwarded to every ``UnetUpBlock``.
        **kwargs: Accepted and ignored.

    Raises:
        ValueError: If ``feature_size`` is not square.
    """

    def __init__(
        self,
        feature_size=(256, 256),
        patch_size=16,
        hidden_size=768,
        num_layers=4,
        kernel_size=3,
        stride=1,
        **kwargs,
    ):
        super().__init__()
        H, W = feature_size
        # Explicit raise instead of `assert`, which is stripped under -O.
        if H != W:
            raise ValueError("Currently only supports square feature maps")

        token_size = H // patch_size  # e.g., 256 // 16 = 16 tokens per side
        self.hidden_size = hidden_size
        self.token_size = token_size
        self.num_layers = num_layers

        # Decoder stages; the widths in the comments assume token_size == 16.
        self.decoder5 = UnetUpBlock(
            in_channels=hidden_size,
            out_channels=self.token_size * 8,
            kernel_size=kernel_size,
            stride=stride,
        )  # x8 -> 128
        self.decoder4 = UnetUpBlock(
            in_channels=self.token_size * 8,
            out_channels=self.token_size * 4,
            kernel_size=kernel_size,
            stride=stride,
        )  # x4 -> 64
        self.decoder3 = UnetUpBlock(
            in_channels=self.token_size * 4,
            out_channels=self.token_size * 2,
            kernel_size=kernel_size,
            stride=stride,
        )  # x2 -> 32
        self.decoder2 = UnetUpBlock(
            in_channels=self.token_size * 2,
            out_channels=self.token_size,
            kernel_size=kernel_size,
            stride=stride,
        )  # x1 -> 16

    def forward(self, x, residuals=None):
        """Apply the four decoder stages in order.

        Args:
            x: Input passed to ``decoder5``.
            residuals: Optional sequence of skip inputs. Stage ``k``
                (1-based) receives ``residuals[-k]``. When omitted, each
                stage is called with the running tensor only.

        Returns:
            The output of ``decoder2``.
        """
        stages = (self.decoder5, self.decoder4, self.decoder3, self.decoder2)
        out = x
        for k, stage in enumerate(stages, start=1):
            if residuals is not None:
                out = stage(out, residuals[-k])
            else:
                out = stage(out)
        return out


class CVConvNextDecoder(nn.Module):
    """Stack of ``CVConvNeXtDBlock`` layers followed by a shared LayerNorm.

    Args:
        input_dims: Input channel count of the ``enc1`` convolution.
        hidden_dims: Width of every block, of ``enc1``'s output and of the
            final LayerNorm.
        intermediate_dim: Forwarded to every block.
        num_layers: Number of blocks.
        complex_axis: Axis along which real and imaginary halves are split.
        layer_scale_init_value: Forwarded to every block; a falsy value
            (``None`` or ``0``) is replaced by ``1 / num_layers``.
        **kwargs: Accepted and ignored.
    """

    def __init__(
        self,
        input_dims=256,
        hidden_dims=512,
        intermediate_dim=1356,
        num_layers=4,
        complex_axis=1,
        layer_scale_init_value=None,
        **kwargs,
    ):
        super().__init__()
        # NOTE(review): `or` also overrides an explicit 0 - confirm intended.
        layer_scale_init_value = layer_scale_init_value or 1 / num_layers
        self.blocks = nn.ModuleList(
            [
                CVConvNeXtDBlock(
                    dim=hidden_dims,
                    intermediate_dim=intermediate_dim,
                    layer_scale_init_value=layer_scale_init_value,
                    complex_axis=complex_axis,
                )
                for _ in range(num_layers)
            ]
        )
        self.final_layer_norm = nn.LayerNorm(hidden_dims, eps=1e-6)
        self.complex_axis = complex_axis
        # NOTE(review): complex_axis is hard-coded to 1 here rather than
        # taken from the constructor argument - confirm this is intended.
        self.enc1 = ComplexConv1D(
            in_channels=input_dims,
            out_channels=hidden_dims,
            kernel_size=3,
            padding=1,
            complex_axis=1,
        )
        self.num_layers = num_layers

    def forward(self, x, x_in=None, laterals=None):
        """Run the block stack, then normalise real and imaginary halves.

        Args:
            x: Input passed to the first block.
            x_in: Optional 4-D tensor, unpacked as ``(B, C, F, T)`` and
                reshaped to ``(B, C * F, T)`` before ``enc1`` is applied.
                Required when ``laterals`` is given.
            laterals: Optional sequence of skip inputs. Three of them are
                picked at indices ``(num_layers // 4) * k - 1`` for
                ``k = 1, 2, 3`` and combined with ``enc1(x_in)``.

        Returns:
            ``x`` after all blocks, with the shared LayerNorm applied to
            each half separately and the halves concatenated again along
            ``complex_axis``.

        Raises:
            ValueError: If ``laterals`` is given without ``x_in``.
        """
        if x_in is not None:
            # inputs: [B, 2, F, T] -> [B, 2 * F, T]
            B, C, F, T = x_in.shape
            x_in = x_in.reshape(B, C * F, T)

        residuals = None
        if laterals is not None:
            if x_in is None:
                raise ValueError("x_in is required when laterals are provided")
            step = self.num_layers // 4
            residuals = [
                self.enc1(x_in),
                laterals[step * 1 - 1],
                laterals[step * 2 - 1],
                laterals[step * 3 - 1],
            ]

        for i, layer in enumerate(self.blocks):
            # NOTE(review): only four residuals exist, so this lookup raises
            # IndexError once i >= 4 (num_layers > 4 with laterals supplied).
            residual = residuals[-i - 1] if residuals is not None else None
            x = layer(x, residual)

        # Split real and imaginary parts, apply the same LayerNorm to each
        # over dim 1 (moved last via transpose), then concatenate them back.
        real, imag = torch.chunk(x, 2, dim=self.complex_axis)
        real = self.final_layer_norm(real.transpose(1, 2)).transpose(1, 2)
        imag = self.final_layer_norm(imag.transpose(1, 2)).transpose(1, 2)
        x = torch.cat([real, imag], dim=self.complex_axis)
        return x