Spaces:
Running
Running
| import torch.nn as nn | |
| import torch.nn.functional as F | |
| import torch | |
| # from nets.autoencoders.cvViT import CVViT | |
| from .blocks.complexblock import CVConvNeXtBlock, ComplexDConvBlock, ComplexConv1x1Block | |
| from .blocks.unetblock import UnetBasicBlock, UnetPrUpBlock | |
| from .vit import ViT | |
class CVEncoder(nn.Module):
    """Stack of encoder blocks driven by a list of per-stage channel widths.

    The leading run of ``hidden_dims`` entries that keep changing from one
    entry to the next is built from ``ComplexDConvBlock``; every entry from the
    first repeated width onwards is built from ``ComplexConv1x1Block``.

    Args:
        in_channels: Channel count fed to the first block.
        hidden_dims: Channel widths, one per block. ``None`` selects
            ``[64, 128, 256, 512, 512, 512, 512]``.
        use_max_pool: When true, ``forward`` applies ``F.max_pool2d(x, 2)``
            after every block.
        **kwargs: Accepted and ignored.

    Attributes:
        non_constant_depth: Number of leading blocks built as
            ``ComplexDConvBlock``.
        constant_depth: Number of remaining blocks.
        complex_encoder: ``nn.ModuleList`` holding the blocks in order.
    """

    def __init__(self, in_channels=2, hidden_dims=None, use_max_pool=True, **kwargs):
        super(CVEncoder, self).__init__()
        if hidden_dims is None:
            hidden_dims = [64, 128, 256, 512, 512, 512, 512]
        self.non_constant_depth = self.count_non_constant_hidden_dims(hidden_dims)
        self.constant_depth = len(hidden_dims) - self.non_constant_depth

        modules = []
        previous_dim = in_channels
        for i, h_dim in enumerate(hidden_dims):
            if i < self.non_constant_depth:
                # Encoder part 1: dilated conv blocks (dilation currently fixed at 1).
                # enc_block = ComplexDConvBlock(previous_dim, h_dim, kernel_size=3, stride=1, dilation=2**(self.non_constant_depth-i))
                enc_block = ComplexDConvBlock(previous_dim, h_dim, kernel_size=3, stride=1, dilation=1)
            else:
                # Encoder part 2: the original note describes this as channel-wise
                # pooling with constant feature maps.
                # NOTE(review): the block is given h_dim (not previous_dim) as its
                # input width and h_dim * 2 as its second argument; the two only
                # coincide while the tail of hidden_dims is truly constant -
                # confirm against ComplexConv1x1Block's contract.
                enc_block = ComplexConv1x1Block(h_dim, h_dim * 2, kernel_size=3, dilation=1)
            modules.append(enc_block)
            previous_dim = h_dim

        self.complex_encoder = nn.ModuleList(modules)
        self.use_max_pool = use_max_pool

    def count_non_constant_hidden_dims(self, hidden_dims):
        """Return how many leading entries precede the first repeated width.

        The first entry always counts; counting stops at the first entry that
        equals its predecessor. An empty sequence yields 0, so the derived
        ``constant_depth`` can never become negative.
        """
        if not hidden_dims:
            return 0
        count = 1
        for previous, current in zip(hidden_dims, hidden_dims[1:]):
            if current == previous:
                break
            count += 1
        return count

    def forward(self, x):
        """Run ``x`` through every block.

        Returns:
            A tuple ``(x, laterals)``: the final tensor and the list of each
            block's output, captured before the optional max-pooling step.
        """
        laterals = []
        for layer in self.complex_encoder:
            x = layer(x)
            laterals.append(x)
            if self.use_max_pool:  # and i < self.non_constant_depth - 1: # Apply max pooling only to the non-constant part
                x = F.max_pool2d(x, 2)
        return x, laterals
class ViTUnetEncoder(nn.Module):
    """ViT backbone with optional U-Net style skip branches.

    ``forward`` always returns the transformer output reshaped into a square
    feature map. With ``skip_connections=True`` it additionally returns four
    skip tensors: one computed from the (optionally projected) input and three
    computed from intermediate transformer hidden states.

    Args:
        in_channels: Channel count of the input handed to the ViT.
        feature_size: ``[H, W]`` of the input; only square sizes are accepted.
            ``None`` selects ``[256, 256]``.
        patch_size: Patch edge length; ``H // patch_size`` gives the number of
            tokens per side.
        hidden_size: Transformer embedding width.
        num_layers: Number of transformer layers.
        mlp_ratio: Forwarded to ``ViT``.
        num_heads: Forwarded to ``ViT``.
        kernel_size: Forwarded to the three ``UnetPrUpBlock`` skip branches.
        stride: Forwarded to the three ``UnetPrUpBlock`` skip branches.
        **kwargs: Accepted and ignored.
    """

    def __init__(self, in_channels=2, feature_size=None, patch_size=16, hidden_size=768, num_layers=4, mlp_ratio=4, num_heads=8, kernel_size=3, stride=1, **kwargs):
        super(ViTUnetEncoder, self).__init__()
        # A fresh list per instance instead of a shared mutable default argument.
        if feature_size is None:
            feature_size = [256, 256]
        H, W = feature_size
        assert H == W, "Currently only supports square feature maps"
        token_size = H // patch_size  # e.g., 256 // 16 = 16 tokens per side
        self.hidden_size = hidden_size
        self.token_size = token_size
        self.num_layers = num_layers

        # A complex-valued CVViT variant taking the same keyword arguments is
        # currently disabled in favour of ViT.
        self.visual_transformer = ViT(
            feature_size=feature_size,
            patch_size=patch_size,
            in_channels=in_channels,
            embed_dim=hidden_size,
            mlp_ratio=mlp_ratio,
            num_layers=num_layers,
            num_heads=num_heads,
        )

        # 3x3, stride-1, padding-1 convolution mapping the input to 2 channels;
        # forward() only applies it when in_channels != 2.
        self.complex_proj = nn.Conv2d(
            in_channels=in_channels,
            out_channels=2,
            kernel_size=(3, 3),
            stride=(1, 1),
            padding=(1, 1),
        )
        self.inchannels = in_channels

        # Skip branches. Output widths are multiples of token_size
        # (16 / 32 / 64 / 128 with the default feature_size and patch_size).
        self.encoder1 = UnetBasicBlock(in_channels=2, out_channels=token_size, kernel_size=3, stride=1, residual=True)
        self.encoder2 = UnetPrUpBlock(in_channels=hidden_size, out_channels=token_size * 2, num_layers=2, kernel_size=kernel_size, stride=stride)
        self.encoder3 = UnetPrUpBlock(in_channels=hidden_size, out_channels=token_size * 4, num_layers=1, kernel_size=kernel_size, stride=stride)
        self.encoder4 = UnetPrUpBlock(in_channels=hidden_size, out_channels=token_size * 8, num_layers=0, kernel_size=kernel_size, stride=stride)

    def proj_feat(self, x, hidden_size, token_size):
        """Reshape a token sequence into a channel-first square feature map.

        ``x`` is viewed as ``[B, token_size, token_size, hidden_size]`` and then
        permuted to ``[B, hidden_size, token_size, token_size]``, so it must
        hold exactly ``token_size * token_size * hidden_size`` values per batch
        element.
        """
        grid = x.view(x.size(0), token_size, token_size, hidden_size)
        return grid.permute(0, 3, 1, 2).contiguous()

    def forward(self, x_in, skip_connections=False):
        """Encode ``x_in``.

        Returns:
            ``(x, residual)`` where ``x`` is the transformer output reshaped by
            ``proj_feat`` and ``residual`` is ``None`` or, when
            ``skip_connections`` is true, the list ``[enc1, enc2, enc3, enc4]``.
        """
        x, hidden_states = self.visual_transformer(x_in)  # [B, T, C] per the original note
        residual = None
        if skip_connections:
            if self.inchannels != 2:
                # Original note: input is assumed to be magnitude-only and is
                # converted to a 2-channel (complex) representation here.
                x_in = self.complex_proj(x_in)
            residual = [self.encoder1(x_in)]

            # Tap hidden states at indices quarter*1-1, quarter*2-1, quarter*3-1.
            # NOTE(review): for num_layers < 4 `quarter` is 0, so every index is
            # -1; with a list-like hidden_states all three skips would then read
            # the last hidden state - confirm this is intended.
            quarter = self.num_layers // 4
            skip_blocks = (self.encoder2, self.encoder3, self.encoder4)
            for step, skip_block in enumerate(skip_blocks, start=1):
                tapped = hidden_states[quarter * step - 1]
                feature_map = self.proj_feat(tapped, self.hidden_size, self.token_size)
                residual.append(skip_block(feature_map))

        x = self.proj_feat(x, self.hidden_size, self.token_size)  # [B, T, C] -> [B, C, H, W]
        return x, residual
class CVConvNextEncoder(nn.Module):
    """Sequence of ``CVConvNeXtBlock`` layers followed by a shared LayerNorm.

    Args:
        hidden_dims: ``dim`` passed to every block and the normalized width of
            the final ``nn.LayerNorm``.
        intermediate_dim: Forwarded to every block.
        num_layers: Number of blocks.
        complex_axis: Axis along which the tensor is split into two halves
            (treated as real and imaginary parts) before normalization; also
            forwarded to every block.
        layer_scale_init_value: Forwarded to every block. ``None`` selects
            ``1 / num_layers``; an explicit ``0`` is passed through unchanged.
        **kwargs: Accepted and ignored.
    """

    def __init__(self,
                 hidden_dims=512,
                 intermediate_dim=1356,
                 num_layers=4,
                 complex_axis=1,
                 layer_scale_init_value=None,
                 **kwargs):
        super(CVConvNextEncoder, self).__init__()
        # Only substitute the default when the caller supplied nothing: a plain
        # `or` would also override an explicit 0. The num_layers guard avoids a
        # division by zero when no blocks are built (the value is unused then).
        if layer_scale_init_value is None and num_layers > 0:
            layer_scale_init_value = 1 / num_layers
        self.blocks = nn.ModuleList(
            [
                CVConvNeXtBlock(
                    dim=hidden_dims,
                    intermediate_dim=intermediate_dim,
                    layer_scale_init_value=layer_scale_init_value,
                    complex_axis=complex_axis,
                )
                for _ in range(num_layers)
            ]
        )
        self.final_layer_norm = nn.LayerNorm(hidden_dims, eps=1e-6)
        self.complex_axis = complex_axis

    def forward(self, x):
        """Apply every block, then normalize the two halves of the result.

        Returns:
            ``(x, laterals)``. ``laterals`` holds, for each block, its output
            with axes 1 and 2 swapped and index 0 of the new axis 1 dropped
            (the original note calls that entry the CLS token). ``x`` is the
            last block's output after the shared LayerNorm has been applied
            separately to each half, with axes 1 and 2 swapped.
        """
        laterals = []
        for layer in self.blocks:
            x = layer(x)
            res = x.transpose(1, 2)  # [B, C, T] -> [B, T, C] per the original note
            laterals.append(res[:, 1:])  # drop the first entry along axis 1

        # Split into two halves along complex_axis. The same LayerNorm instance
        # normalizes both; with complex_axis == 1 each half must therefore have
        # hidden_dims entries along axis 1.
        real, imag = torch.chunk(x, 2, dim=self.complex_axis)
        real = self.final_layer_norm(real.transpose(1, 2)).transpose(1, 2)
        imag = self.final_layer_norm(imag.transpose(1, 2)).transpose(1, 2)
        x = torch.cat([real, imag], dim=self.complex_axis)
        x = x.transpose(1, 2)  # [B, C, T] -> [B, T, C]
        return x, laterals