Spaces:
Running
Running
| import torch.nn as nn | |
| import torch | |
| from .blocks.complexblock import CVConvNeXtDBlock, ComplexUpsampleUnet, ComplexUpConstantUNet, ComplexLinearLayer | |
| from .blocks.complexmodule import ComplexConv1D | |
| from .blocks.unetblock import UnetUpBlock | |
class CVDecoder(nn.Module):
    """Complex-valued decoder assembled from a list of channel widths.

    ``hidden_dims`` is walked in reverse to build the decoder. Every index
    from the first repeated width onward is treated as the "constant" part
    and gets a ``ComplexUpConstantUNet`` block; the indices before it form
    the "non-constant" part and get ``ComplexUpsampleUnet`` blocks.
    """

    def __init__(self, hidden_dims: list = None, **kwargs) -> None:
        """Build the decoder blocks and the lateral projection.

        Args:
            hidden_dims: Channel widths; the decoder is built from the last
                entry back to the first. Defaults to
                ``[64, 128, 256, 512, 512, 512, 512]``.
            **kwargs: Accepted for call-site compatibility and ignored.

        Raises:
            ValueError: If ``hidden_dims`` is empty.
        """
        super().__init__()
        if hidden_dims is None:
            hidden_dims = [64, 128, 256, 512, 512, 512, 512]
        if len(hidden_dims) == 0:
            # Fail early with a clear message instead of an IndexError below.
            raise ValueError("hidden_dims must contain at least one entry")

        self.non_constant_depth = self.count_non_constant_hidden_dims(hidden_dims)
        self.constant_depth = len(hidden_dims) - self.non_constant_depth

        latent_dim = hidden_dims[-1]
        pre_h_dim = latent_dim
        modules = []
        # Build the decoder in reverse: deepest (last) width first.
        for i in reversed(range(len(hidden_dims))):
            if i >= self.non_constant_depth:
                # Constant part: every block is created with the latent width.
                dec_block = ComplexUpConstantUNet(
                    latent_dim,
                    dilation=(1, 1),
                    padding=(1, 1),
                    output_padding=(1, 1),
                )
            else:
                h_dim = hidden_dims[i]
                dec_block = ComplexUpsampleUnet(
                    pre_h_dim,
                    h_dim,
                    dilation=(1, 1),
                    padding=(1, 1),
                    output_padding=(1, 1),
                )
                # NOTE(review): the input width only advances on upsample
                # blocks. This equals advancing on every block whenever the
                # constant tail is all equal to hidden_dims[-1] - confirm this
                # matches the intended layout for other width lists.
                pre_h_dim = h_dim
            modules.append(dec_block)

        # Projection applied to the lateral fed into the first non-constant
        # block (see forward); it halves the last width.
        self.lateral_projection = ComplexLinearLayer(hidden_dims[-1], hidden_dims[-1] // 2)
        self.complex_decoder = nn.ModuleList(modules)

    def count_non_constant_hidden_dims(self, hidden_dims):
        """Return how many leading entries precede the first repeated width.

        The first entry always counts; counting stops at the first entry that
        equals its predecessor. An empty sequence yields 0.
        """
        if len(hidden_dims) == 0:
            return 0
        count = 1
        for previous, current in zip(hidden_dims, hidden_dims[1:]):
            if current == previous:
                break
            count += 1
        return count

    def forward(self, x, laterals=None):
        """Run ``x`` through every decoder block in order.

        Args:
            x: Input passed to the first decoder block.
            laterals: Optional sequence of lateral inputs. Block ``i``
                receives ``laterals[-i - 1]``; the lateral at index
                ``constant_depth`` is first passed through
                ``lateral_projection``. When ``None``, every block receives
                ``None`` as its second argument.

        Returns:
            The output of the last decoder block (``x`` unchanged if there
            are no blocks).
        """
        for i, layer in enumerate(self.complex_decoder):
            residual = None
            if laterals is not None:
                residual = laterals[-i - 1]
                if i == self.constant_depth:
                    residual = self.lateral_projection(residual)
            x = layer(x, residual)
        return x
class ViTUnetDecoder(nn.Module):
    """Decoder that chains four ``UnetUpBlock`` stages.

    Stage widths are multiples of ``token_size = H // patch_size``:
    ``hidden_size -> 8x -> 4x -> 2x -> 1x``.
    """

    def __init__(self, feature_size=(256, 256), patch_size=16, hidden_size=768,
                 num_layers=4, kernel_size=3, stride=1, **kwargs):
        """Create the four decoder stages.

        Args:
            feature_size: ``(H, W)`` pair; only square sizes are supported.
            patch_size: Divisor used to derive ``token_size`` from ``H``.
            hidden_size: Input channel count of the first stage.
            num_layers: Stored on the instance; not otherwise used here.
            kernel_size: Forwarded to every ``UnetUpBlock``.
            stride: Forwarded to every ``UnetUpBlock``.
            **kwargs: Accepted for call-site compatibility and ignored.

        Raises:
            ValueError: If ``feature_size`` is not square.
        """
        super().__init__()
        H, W = feature_size
        if H != W:
            # Explicit raise: an ``assert`` would be stripped under ``python -O``.
            raise ValueError("Currently only supports square feature maps")

        token_size = H // patch_size  # e.g., 256 // 16 = 16 tokens per side
        self.hidden_size = hidden_size
        self.token_size = token_size
        self.num_layers = num_layers

        # Decoder stages; the widths in the comments assume token_size == 16.
        self.decoder5 = UnetUpBlock(
            in_channels=hidden_size,
            out_channels=token_size * 8,
            kernel_size=kernel_size,
            stride=stride,
        )  # x8 -> 128
        self.decoder4 = UnetUpBlock(
            in_channels=token_size * 8,
            out_channels=token_size * 4,
            kernel_size=kernel_size,
            stride=stride,
        )  # x4 -> 64
        self.decoder3 = UnetUpBlock(
            in_channels=token_size * 4,
            out_channels=token_size * 2,
            kernel_size=kernel_size,
            stride=stride,
        )  # x2 -> 32
        self.decoder2 = UnetUpBlock(
            in_channels=token_size * 2,
            out_channels=token_size,
            kernel_size=kernel_size,
            stride=stride,
        )  # x1 -> 16

    def forward(self, x, residuals=None):
        """Apply the four stages in order ``decoder5 .. decoder2``.

        Args:
            x: Input to the first stage.
            residuals: Optional sequence with at least four entries. Stage
                ``k`` (1-based) receives ``residuals[-k]`` as its second
                argument. When ``None``, each stage is called with ``x`` only.

        Returns:
            The output of ``decoder2``.
        """
        stages = (self.decoder5, self.decoder4, self.decoder3, self.decoder2)
        if residuals is None:
            for stage in stages:
                x = stage(x)
            return x
        for offset, stage in enumerate(stages, start=1):
            # The last residual feeds the first stage, and so on backwards.
            x = stage(x, residuals[-offset])
        return x
class CVConvNextDecoder(nn.Module):
    """Stack of ``CVConvNeXtDBlock`` layers followed by a shared LayerNorm.

    The final LayerNorm is applied separately to the two halves obtained by
    splitting the output along ``complex_axis``.
    """

    def __init__(self,
                 input_dims=256,
                 hidden_dims=512,
                 intermediate_dim=1356,
                 num_layers=4,
                 complex_axis=1,
                 layer_scale_init_value=None,
                 **kwargs):
        """Create the block stack, the final norm and the input convolution.

        Args:
            input_dims: ``in_channels`` of the ``enc1`` convolution.
            hidden_dims: Block width and LayerNorm size.
            intermediate_dim: Forwarded to every ``CVConvNeXtDBlock``.
            num_layers: Number of blocks.
            complex_axis: Axis along which the two halves are stacked.
            layer_scale_init_value: Forwarded to every block. ``None`` selects
                ``1 / num_layers``; an explicit ``0`` is passed through.
            **kwargs: Accepted for call-site compatibility and ignored.
        """
        super().__init__()
        # ``is None`` rather than ``or`` so an explicit 0 is not overridden.
        if layer_scale_init_value is None:
            layer_scale_init_value = 1 / num_layers
        self.blocks = nn.ModuleList(
            [
                CVConvNeXtDBlock(
                    dim=hidden_dims,
                    intermediate_dim=intermediate_dim,
                    layer_scale_init_value=layer_scale_init_value,
                    complex_axis=complex_axis,
                )
                for _ in range(num_layers)
            ]
        )
        self.final_layer_norm = nn.LayerNorm(hidden_dims, eps=1e-6)
        self.complex_axis = complex_axis
        # NOTE(review): enc1 is created with complex_axis=1 regardless of the
        # constructor argument - confirm this is intentional.
        self.enc1 = ComplexConv1D(in_channels=input_dims, out_channels=hidden_dims,
                                  kernel_size=3, padding=1, complex_axis=1)
        self.num_layers = num_layers

    def forward(self, x, x_in=None, laterals=None):
        """Run the block stack and normalise both halves of the result.

        Args:
            x: Input to the first block.
            x_in: Optional 4-D tensor; its second and third axes are merged
                before it is passed through ``enc1``. Required when
                ``laterals`` is given.
            laterals: Optional sequence of lateral tensors. Three of them,
                at indices ``k * (num_layers // 4) - 1`` for ``k = 1, 2, 3``,
                are used together with ``enc1(x_in)`` as per-block residuals.

        Returns:
            The two normalised halves concatenated back along
            ``complex_axis``.

        Raises:
            ValueError: If ``laterals`` is given without ``x_in``.
        """
        if laterals is not None and x_in is None:
            raise ValueError("x_in is required when laterals are provided")

        if x_in is not None:
            # Merge axes 1 and 2: [batch, parts, freq, frames] -> [batch, parts * freq, frames]
            batch, parts, freq, frames = x_in.shape
            x_in = x_in.reshape(batch, parts * freq, frames)

        residuals = None
        if laterals is not None:
            step = self.num_layers // 4
            residuals = [
                self.enc1(x_in),
                laterals[step - 1],
                laterals[step * 2 - 1],
                laterals[step * 3 - 1],
            ]

        for i, layer in enumerate(self.blocks):
            # NOTE(review): only four residuals exist and they are consumed
            # last-to-first, so more than four blocks raises IndexError here.
            residual = residuals[-i - 1] if residuals is not None else None
            x = layer(x, residual)

        # Normalise each half separately with the same LayerNorm.
        real, imag = torch.chunk(x, 2, dim=self.complex_axis)
        real = self.final_layer_norm(real.transpose(1, 2)).transpose(1, 2)
        imag = self.final_layer_norm(imag.transpose(1, 2)).transpose(1, 2)
        return torch.cat([real, imag], dim=self.complex_axis)