# U-Past / modules/encoder.py
# Author: lycaoduong
# Commit e8160b2 ("Initial space", verified)
import torch.nn as nn
import torch.nn.functional as F
import torch
# from nets.autoencoders.cvViT import CVViT
from .blocks.complexblock import CVConvNeXtBlock, ComplexDConvBlock, ComplexConv1x1Block
from .blocks.unetblock import UnetBasicBlock, UnetPrUpBlock
from .vit import ViT
class CVEncoder(nn.Module):
    """Complex-valued convolutional encoder that also returns per-layer features.

    ``hidden_dims`` is split into two parts:

    * Part 1 ("non-constant"): the leading run of layers whose channel count
      differs from the previous layer's; built from ``ComplexDConvBlock``.
    * Part 2 ("constant"): every layer from the first repeated channel count
      onward; built from ``ComplexConv1x1Block``.

    Args:
        in_channels: Channel count of the input tensor.
        hidden_dims: Channel count of each encoder layer. Defaults to
            ``[64, 128, 256, 512, 512, 512, 512]``.
        use_max_pool: If True, ``F.max_pool2d(x, 2)`` is applied after every
            layer (both parts).
        **kwargs: Unused; extra keyword arguments are accepted and ignored.
    """

    def __init__(self, in_channels=2, hidden_dims=None, use_max_pool=True, **kwargs):
        super().__init__()
        if hidden_dims is None:
            hidden_dims = [64, 128, 256, 512, 512, 512, 512]
        self.non_constant_depth = self.count_non_constant_hidden_dims(hidden_dims)
        self.constant_depth = len(hidden_dims) - self.non_constant_depth

        modules = []
        previous_dim = in_channels
        for i, h_dim in enumerate(hidden_dims):
            if i < self.non_constant_depth:
                # Part 1: channel-changing blocks. A growing-dilation variant
                # (dilation=2**(self.non_constant_depth - i)) is disabled;
                # every block currently uses dilation=1.
                enc_block = ComplexDConvBlock(previous_dim, h_dim, kernel_size=3, stride=1, dilation=1)
            else:
                # Part 2: "channel-wise pooling with constant feature maps".
                # NOTE(review): the input width passed is h_dim, not previous_dim,
                # which only matches the previous layer while hidden_dims stays
                # constant; the meaning of the h_dim * 2 argument depends on
                # ComplexConv1x1Block — confirm.
                enc_block = ComplexConv1x1Block(h_dim, h_dim * 2, kernel_size=3, dilation=1)
            modules.append(enc_block)
            previous_dim = h_dim
        self.complex_encoder = nn.ModuleList(modules)
        self.use_max_pool = use_max_pool

    def count_non_constant_hidden_dims(self, hidden_dims):
        """Return the length of the leading run of layers whose width changes.

        Counts from the start of ``hidden_dims`` up to, but not including, the
        first element equal to its predecessor. For example,
        ``[64, 128, 256, 512, 512, 512]`` gives 4 and ``[8]`` gives 1. An empty
        sequence gives 0, so ``constant_depth`` can never be negative.
        """
        # len() rather than truthiness so array-like inputs also work.
        if len(hidden_dims) == 0:
            return 0
        for count in range(1, len(hidden_dims)):
            if hidden_dims[count] == hidden_dims[count - 1]:
                return count
        return len(hidden_dims)

    def forward(self, x):
        """Run every encoder layer in order.

        Returns:
            A tuple ``(x, laterals)``. ``x`` is the last layer's output, after
            its max-pool when ``use_max_pool`` is True. ``laterals`` holds each
            layer's output taken *before* pooling, in layer order.
        """
        laterals = []
        for layer in self.complex_encoder:
            x = layer(x)
            laterals.append(x)
            # Pooling follows every layer. Restricting it to the non-constant
            # part (i < self.non_constant_depth - 1) is intentionally disabled.
            if self.use_max_pool:
                x = F.max_pool2d(x, 2)
        return x, laterals
class ViTUnetEncoder(nn.Module):
    """ViT backbone with convolutional skip branches fed by intermediate layers.

    Args:
        in_channels: Input channel count. When it is not 2 and skip connections
            are requested, ``complex_proj`` maps the input to 2 channels before
            ``encoder1``.
        feature_size: ``[H, W]`` of the input; must be square. Defaults to
            ``[256, 256]``.
        patch_size: ViT patch size; the token grid is ``H // patch_size`` per side.
        hidden_size: ViT embedding dimension.
        num_layers: Number of transformer layers.
        mlp_ratio, num_heads: Forwarded to ``ViT``.
        kernel_size, stride: Forwarded to the ``UnetPrUpBlock`` skip branches.
        **kwargs: Unused; extra keyword arguments are accepted and ignored.
    """

    def __init__(self, in_channels=2, feature_size=None, patch_size=16, hidden_size=768, num_layers=4,
                 mlp_ratio=4, num_heads=8, kernel_size=3, stride=1, **kwargs):
        super().__init__()
        # Fresh list per instance instead of one mutable default object shared
        # by every instance (and handed on to ViT).
        if feature_size is None:
            feature_size = [256, 256]
        H, W = feature_size
        assert H == W, "Currently only supports square feature maps"
        token_size = H // patch_size  # e.g., 256 // 16 = 16 tokens per side
        self.hidden_size = hidden_size
        self.token_size = token_size
        self.num_layers = num_layers
        self.visual_transformer = ViT(
            feature_size=feature_size,
            patch_size=patch_size,
            in_channels=in_channels,
            embed_dim=hidden_size,
            mlp_ratio=mlp_ratio,
            num_layers=num_layers,
            num_heads=num_heads,
        )
        # Learned 3x3 conv from in_channels to 2 channels (presumably real/imag),
        # used only when in_channels != 2 so that encoder1 always gets 2 channels.
        self.complex_proj = nn.Conv2d(
            in_channels=in_channels,
            out_channels=2,
            kernel_size=(3, 3),
            stride=(1, 1),
            padding=(1, 1)
        )
        self.inchannels = in_channels
        self.encoder1 = UnetBasicBlock(in_channels=2, out_channels=token_size, kernel_size=3, stride=1, residual=True)
        # Skip branches on reshaped ViT hidden states; output widths are
        # token_size * 2 / * 4 / * 8 (32 / 64 / 128 with the defaults).
        self.encoder2 = UnetPrUpBlock(in_channels=hidden_size, out_channels=token_size * 2, num_layers=2, kernel_size=kernel_size, stride=stride)
        self.encoder3 = UnetPrUpBlock(in_channels=hidden_size, out_channels=token_size * 4, num_layers=1, kernel_size=kernel_size, stride=stride)
        self.encoder4 = UnetPrUpBlock(in_channels=hidden_size, out_channels=token_size * 8, num_layers=0, kernel_size=kernel_size, stride=stride)

    def proj_feat(self, x, hidden_size, token_size):
        """Reshape tokens ``[B, token_size**2, hidden_size]`` into a feature map
        ``[B, hidden_size, token_size, token_size]``."""
        x = x.view(x.size(0), token_size, token_size, hidden_size)
        x = x.permute(0, 3, 1, 2).contiguous()
        return x

    def forward(self, x_in, skip_connections=False):
        """Encode ``x_in`` with the ViT and optionally build skip features.

        Returns:
            A tuple ``(x, residual)``. ``x`` is the final ViT output reshaped to
            ``[B, hidden_size, token_size, token_size]``. ``residual`` is None
            unless ``skip_connections`` is True, in which case it is
            ``[enc1, enc2, enc3, enc4]``.
        """
        x, hidden_states = self.visual_transformer(x_in)  # x: [B, T, C]
        residual = None
        if skip_connections:
            if self.inchannels != 2:
                x_in = self.complex_proj(x_in)
            enc1 = self.encoder1(x_in)
            # Tap hidden states at evenly spaced depths (indices step*k - 1).
            # NOTE(review): when num_layers < 4, step is 0 and all three taps
            # resolve to hidden_states[-1].
            step = self.num_layers // 4
            x2 = hidden_states[step * 1 - 1]
            enc2 = self.encoder2(self.proj_feat(x2, self.hidden_size, self.token_size))
            x3 = hidden_states[step * 2 - 1]
            enc3 = self.encoder3(self.proj_feat(x3, self.hidden_size, self.token_size))
            x4 = hidden_states[step * 3 - 1]
            enc4 = self.encoder4(self.proj_feat(x4, self.hidden_size, self.token_size))
            residual = [enc1, enc2, enc3, enc4]
        x = self.proj_feat(x, self.hidden_size, self.token_size)  # [B, T, C] -> [B, C, H, W]
        return x, residual
class CVConvNextEncoder(nn.Module):
    """Stack of complex-valued ConvNeXt blocks followed by a shared LayerNorm.

    Args:
        hidden_dims: ``dim`` of every ``CVConvNeXtBlock`` and the normalized
            size of ``final_layer_norm``.
        intermediate_dim: Inner width of every ``CVConvNeXtBlock``.
        num_layers: Number of blocks.
        complex_axis: Axis along which ``forward`` splits the tensor into two
            equal chunks (real, imaginary) and concatenates them back.
        layer_scale_init_value: Forwarded to every block. ``None`` (default)
            means ``1 / num_layers``. Any other value, including ``0``, is
            passed through unchanged.
        **kwargs: Unused; extra keyword arguments are accepted and ignored.
    """

    def __init__(self,
                 hidden_dims=512,
                 intermediate_dim=1356,
                 num_layers=4,
                 complex_axis=1,
                 layer_scale_init_value=None,
                 **kwargs):
        super().__init__()
        # Check for None explicitly so that an explicit 0 / 0.0 is honored
        # instead of being treated as "unset". max() avoids dividing by zero
        # when num_layers == 0 (no block consumes the value in that case).
        if layer_scale_init_value is None:
            layer_scale_init_value = 1 / max(num_layers, 1)
        self.blocks = nn.ModuleList(
            [
                CVConvNeXtBlock(
                    dim=hidden_dims,
                    intermediate_dim=intermediate_dim,
                    layer_scale_init_value=layer_scale_init_value,
                    complex_axis=complex_axis,
                )
                for _ in range(num_layers)
            ]
        )
        self.final_layer_norm = nn.LayerNorm(hidden_dims, eps=1e-6)
        self.complex_axis = complex_axis

    def forward(self, x):
        """Run all blocks, then LayerNorm the real and imaginary halves separately.

        With the default ``complex_axis=1``, ``x`` is expected as
        ``[B, 2 * hidden_dims, T]``, with the real half first.

        Returns:
            A tuple ``(x, laterals)``. ``x`` is the normalized output transposed
            to ``[B, T, C]``. ``laterals`` holds each block's output transposed
            to ``[B, T, C]`` with the first time step dropped (treated as a CLS
            token). The returned ``x`` keeps that first step.
        """
        laterals = []
        for layer in self.blocks:
            x = layer(x)
            res = x.transpose(1, 2)  # [B, C, T] -> [B, T, C]
            laterals.append(res[:, 1:])  # Remove CLS token
        # NOTE(review): the transposes below are fixed to dims 1/2, so a
        # complex_axis other than 1 likely needs extra care — confirm.
        real, imag = torch.chunk(x, 2, dim=self.complex_axis)
        # One shared LayerNorm (over the last dim after transposing) for both halves.
        real = self.final_layer_norm(real.transpose(1, 2)).transpose(1, 2)
        imag = self.final_layer_norm(imag.transpose(1, 2)).transpose(1, 2)
        x = torch.cat([real, imag], dim=self.complex_axis)
        x = x.transpose(1, 2)  # [B, C, T] -> [B, T, C]
        return x, laterals