"""Encoder modules: complex-valued CNN, ViT-based U-Net and ConvNeXt encoders."""

import torch
import torch.nn as nn
import torch.nn.functional as F

# from nets.autoencoders.cvViT import CVViT
from .blocks.complexblock import CVConvNeXtBlock, ComplexDConvBlock, ComplexConv1x1Block
from .blocks.unetblock import UnetBasicBlock, UnetPrUpBlock
from .vit import ViT


class CVEncoder(nn.Module):
    """Sequential stack of complex convolution blocks with optional pooling.

    The leading entries of ``hidden_dims`` (up to the first repeated value)
    are built as ``ComplexDConvBlock`` layers; the remaining entries are built
    as ``ComplexConv1x1Block`` layers.

    Args:
        in_channels: Channel count fed to the first block.
        hidden_dims: Output width of every block. Defaults to
            ``[64, 128, 256, 512, 512, 512, 512]``.
        use_max_pool: When true, ``forward`` applies a 2x2 max pooling after
            every block.
        **kwargs: Accepted for configuration compatibility and ignored.
    """

    def __init__(self, in_channels=2, hidden_dims=None, use_max_pool=True, **kwargs):
        super().__init__()
        if hidden_dims is None:
            hidden_dims = [64, 128, 256, 512, 512, 512, 512]

        self.non_constant_depth = self.count_non_constant_hidden_dims(hidden_dims)
        self.constant_depth = len(hidden_dims) - self.non_constant_depth

        blocks = []
        previous_dim = in_channels
        for index, h_dim in enumerate(hidden_dims):
            if index < self.non_constant_depth:
                # Encoder part 1: dilated conv blocks. A dilation schedule of
                # 2 ** (self.non_constant_depth - index) is kept here for
                # reference only; the active setting is a fixed dilation of 1.
                block = ComplexDConvBlock(
                    previous_dim, h_dim, kernel_size=3, stride=1, dilation=1
                )
            else:
                # Encoder part 2: channel-wise pooling with constant feature
                # maps (original author's description of ComplexConv1x1Block).
                block = ComplexConv1x1Block(
                    h_dim, h_dim * 2, kernel_size=3, dilation=1
                )
            blocks.append(block)
            previous_dim = h_dim

        self.complex_encoder = nn.ModuleList(blocks)
        self.use_max_pool = use_max_pool

    def count_non_constant_hidden_dims(self, hidden_dims):
        """Return how many leading entries precede the first repeated value.

        Counting starts at 1 and stops at the first entry that equals its
        predecessor, so an empty or single-element sequence yields 1.
        """
        count = 1
        for previous, current in zip(hidden_dims, hidden_dims[1:]):
            if current == previous:
                break
            count += 1
        return count

    def forward(self, x):
        """Run every block in order and collect the intermediate outputs.

        Returns:
            A tuple ``(x, laterals)`` where ``laterals`` holds each block's
            output captured *before* the optional max pooling and ``x`` is the
            output of the last block (after pooling when enabled).
        """
        laterals = []
        for layer in self.complex_encoder:
            x = layer(x)
            laterals.append(x)
            # Pooling is applied after every block; restricting it to the
            # non-constant part (index < self.non_constant_depth - 1) is a
            # disabled alternative in the original source.
            if self.use_max_pool:
                x = F.max_pool2d(x, 2)
        return x, laterals


class ViTUnetEncoder(nn.Module):
    """ViT encoder that can also expose U-Net style skip features.

    Args:
        in_channels: Channel count of the input image.
        feature_size: ``[H, W]`` of the input; ``H`` must equal ``W``.
            Defaults to ``[256, 256]``.
        patch_size: Patch edge length; ``H // patch_size`` tokens per side.
        hidden_size: Transformer embedding width.
        num_layers: Number of transformer layers.
        mlp_ratio: Forwarded to ``ViT``.
        num_heads: Forwarded to ``ViT``.
        kernel_size: Kernel size of the ``UnetPrUpBlock`` skip branches.
        stride: Stride of the ``UnetPrUpBlock`` skip branches.
        **kwargs: Accepted for configuration compatibility and ignored.
    """

    def __init__(
        self,
        in_channels=2,
        feature_size=None,
        patch_size=16,
        hidden_size=768,
        num_layers=4,
        mlp_ratio=4,
        num_heads=8,
        kernel_size=3,
        stride=1,
        **kwargs,
    ):
        super().__init__()
        # A fresh list per instance: a list literal as the default argument
        # would be one shared object handed to every ViT that is constructed.
        if feature_size is None:
            feature_size = [256, 256]

        H, W = feature_size
        assert H == W, "Currently only supports square feature maps"
        token_size = H // patch_size  # e.g., 256 // 16 = 16 tokens per side

        self.hidden_size = hidden_size
        self.token_size = token_size
        self.num_layers = num_layers

        # The complex-valued CVViT (same keyword arguments) is a disabled
        # alternative to ViT in the original source.
        self.visual_transformer = ViT(
            feature_size=feature_size,
            patch_size=patch_size,
            in_channels=in_channels,
            embed_dim=hidden_size,
            mlp_ratio=mlp_ratio,
            num_layers=num_layers,
            num_heads=num_heads,
        )

        # Maps an input with in_channels != 2 onto two channels before the
        # first skip branch (see forward).
        self.complex_proj = nn.Conv2d(
            in_channels=in_channels,
            out_channels=2,
            kernel_size=(3, 3),
            stride=(1, 1),
            padding=(1, 1),
        )
        self.inchannels = in_channels

        self.encoder1 = UnetBasicBlock(
            in_channels=2,
            out_channels=token_size,
            kernel_size=3,
            stride=1,
            residual=True,
        )
        # Skip branches widen with depth: token_size * 2 / 4 / 8 channels
        # (32 / 64 / 128 when token_size is 16).
        self.encoder2 = UnetPrUpBlock(
            in_channels=hidden_size,
            out_channels=token_size * 2,
            num_layers=2,
            kernel_size=kernel_size,
            stride=stride,
        )
        self.encoder3 = UnetPrUpBlock(
            in_channels=hidden_size,
            out_channels=token_size * 4,
            num_layers=1,
            kernel_size=kernel_size,
            stride=stride,
        )
        self.encoder4 = UnetPrUpBlock(
            in_channels=hidden_size,
            out_channels=token_size * 8,
            num_layers=0,
            kernel_size=kernel_size,
            stride=stride,
        )

    def proj_feat(self, x, hidden_size, token_size):
        """Reshape a token sequence into a channel-first feature map.

        ``x`` is viewed as ``[B, token_size, token_size, hidden_size]`` and
        permuted to ``[B, hidden_size, token_size, token_size]``.
        """
        grid = x.view(x.size(0), token_size, token_size, hidden_size)
        return grid.permute(0, 3, 1, 2).contiguous()  # B C H W

    def forward(self, x_in, skip_connections=False):
        """Encode ``x_in`` and optionally build skip-connection features.

        Args:
            x_in: Input passed to the transformer.
            skip_connections: When true, also compute the four skip features.

        Returns:
            A tuple ``(x, residual)``. ``x`` is the transformer output reshaped
            by ``proj_feat``; ``residual`` is ``[enc1, enc2, enc3, enc4]`` when
            ``skip_connections`` is true and ``None`` otherwise.
        """
        # Per the original comment the transformer output is [B, T, C].
        x, hidden_states = self.visual_transformer(x_in)

        residual = None
        if skip_connections:
            if self.inchannels != 2:
                # Original author's note: the input is assumed to be a
                # magnitude image and is projected to two channels.
                x_in = self.complex_proj(x_in)
            residual = [self.encoder1(x_in)]

            # Tap hidden states at 1/4, 2/4 and 3/4 of the depth.
            # NOTE(review): for num_layers < 4 the step is 0, so every tap
            # uses index -1 -- confirm this is intended.
            step = self.num_layers // 4
            branches = (self.encoder2, self.encoder3, self.encoder4)
            for multiple, branch in enumerate(branches, start=1):
                tokens = hidden_states[step * multiple - 1]
                feature_map = self.proj_feat(
                    tokens, self.hidden_size, self.token_size
                )
                residual.append(branch(feature_map))

        # [B, T, C] -> [B, C, H, W]
        x = self.proj_feat(x, self.hidden_size, self.token_size)
        return x, residual


class CVConvNextEncoder(nn.Module):
    """Stack of ``CVConvNeXtBlock`` layers followed by a shared LayerNorm.

    Args:
        hidden_dims: ``dim`` of every block and the LayerNorm width.
        intermediate_dim: ``intermediate_dim`` of every block.
        num_layers: Number of blocks.
        complex_axis: Axis along which the real and imaginary halves are
            stacked.
        layer_scale_init_value: Forwarded to every block. ``None`` selects
            ``1 / num_layers``; any other value, including ``0``, is passed
            through unchanged.
        **kwargs: Accepted for configuration compatibility and ignored.
    """

    def __init__(
        self,
        hidden_dims=512,
        intermediate_dim=1356,
        num_layers=4,
        complex_axis=1,
        layer_scale_init_value=None,
        **kwargs,
    ):
        super().__init__()
        # Test against None rather than truthiness so that an explicit 0 is
        # not silently replaced by the default.
        if layer_scale_init_value is None:
            layer_scale_init_value = 1 / num_layers

        self.blocks = nn.ModuleList(
            [
                CVConvNeXtBlock(
                    dim=hidden_dims,
                    intermediate_dim=intermediate_dim,
                    layer_scale_init_value=layer_scale_init_value,
                    complex_axis=complex_axis,
                )
                for _ in range(num_layers)
            ]
        )
        self.final_layer_norm = nn.LayerNorm(hidden_dims, eps=1e-6)
        self.complex_axis = complex_axis

    def forward(self, x):
        """Apply every block, then normalise the real and imaginary halves.

        Returns:
            A tuple ``(x, laterals)``. ``laterals`` holds, for each block, its
            output with axes 1 and 2 swapped and the first entry along axis 1
            dropped. ``x`` is the normalised output with axes 1 and 2 swapped.
        """
        laterals = []
        for layer in self.blocks:
            x = layer(x)
            # Per the original comments: [B, C, T] -> [B, T, C], and the
            # dropped first position is the CLS token.
            tokens_first = x.transpose(1, 2)
            laterals.append(tokens_first[:, 1:])

        # The same LayerNorm is applied to each half separately, on the
        # transposed layout, and the halves are then re-joined.
        real, imag = torch.chunk(x, 2, dim=self.complex_axis)
        real = self.final_layer_norm(real.transpose(1, 2)).transpose(1, 2)
        imag = self.final_layer_norm(imag.transpose(1, 2)).transpose(1, 2)
        x = torch.cat([real, imag], dim=self.complex_axis)

        x = x.transpose(1, 2)  # [B, C, T] -> [B, T, C]
        return x, laterals