# U-Past / modules/embedding.py
# lycaoduong — commit e8160b2 "Initial space" (verified)
import torch.nn as nn
import torch.nn.functional as F
import torch
from .blocks.complexmodule import ComplexConv1D
from .utils.pos_embed import get_2d_sincos_pos_embed
from timm.models.vision_transformer import PatchEmbed
class CV1DTokenizer(nn.Module):
    """Embed a two-part (real / imaginary) input into per-frame complex tokens.

    The real and imaginary parts are split off the input, stacked along
    ``complex_axis``, passed through a ``ComplexConv1D`` embedding, and each
    half of the embedding output is layer-normalized over its channel axis
    with one shared ``LayerNorm``.

    Args:
        input_dims: Input size forwarded to ``ComplexConv1D`` (presumably the
            per-part feature count ``F`` — TODO confirm against ComplexConv1D).
        hidden_dims: Output size forwarded to ``ComplexConv1D``; also the size
            normalized by the shared ``LayerNorm``.
        kernel_size: Temporal kernel size of the convolution.
        complex_axis: Axis along which real and imaginary parts are stacked.
    """

    def __init__(self,
                 input_dims: int = 256,
                 hidden_dims: int = 512,
                 kernel_size: int = 3,
                 complex_axis: int = 1
                 ) -> None:
        super().__init__()
        # "Same" padding: keeps the time length unchanged for odd kernel sizes
        # at stride 1 (assuming ComplexConv1D pads like nn.Conv1d); an even
        # kernel_size shortens the output by one frame.
        padding = (kernel_size - 1) // 2
        self.embedding = ComplexConv1D(input_dims, hidden_dims, kernel_size, stride=1,
                                       padding=padding, complex_axis=complex_axis)
        # One LayerNorm instance is shared by the real and imaginary halves.
        self.norm = nn.LayerNorm(hidden_dims, eps=1e-6)
        self.complex_axis = complex_axis

    def forward(self, x, channel_first=True):
        """Tokenize ``x``.

        Args:
            x: Input with real and imaginary parts stacked along
                ``complex_axis`` — expected ``[B, 2, F, T]`` for the default
                axis (TODO confirm for other axes).
            channel_first: If truthy, return ``[B, C, T]``; otherwise
                ``[B, T, C]``.

        Returns:
            The normalized real and imaginary embeddings concatenated along
            ``complex_axis``.
        """
        # [B, 2, F, T] -> two tensors of shape [B, 1, F, T]
        real, imag = torch.chunk(x, 2, self.complex_axis)
        # [B, 1, F, T] -> [B, F, T]
        real = real.squeeze(1)
        imag = imag.squeeze(1)
        # Stack real and imaginary features: [B, 2*F, T]
        x = torch.cat([real, imag], dim=self.complex_axis)
        x = self.embedding(x)
        # Split the embedding output back into its real and imaginary halves.
        real, imag = torch.chunk(x, 2, dim=self.complex_axis)
        # LayerNorm normalizes the last dim, so move channels last and back:
        # [B, hidden_dims, T] -> [B, T, hidden_dims] -> norm -> [B, hidden_dims, T]
        real = self.norm(real.transpose(1, 2)).transpose(1, 2)
        imag = self.norm(imag.transpose(1, 2)).transpose(1, 2)
        x = torch.cat([real, imag], dim=self.complex_axis)
        # Truthiness test (not `is False`) so any falsy flag selects [B, T, C],
        # matching CV2DTokenizer.forward.
        if not channel_first:
            x = x.transpose(1, 2)  # [B, C, T] -> [B, T, C]
        return x
class CV2DTokenizer(nn.Module):
def __init__(self,
feature_size: int | tuple = (256, 256),
patch_size: int = 16,
in_channels: int = 2,
embed_dim: int = 512,
complex_axis: int = 1
) -> None:
super(CV2DTokenizer, self).__init__()
self.real_patch_embed = PatchEmbed(feature_size, patch_size, in_channels // 2, embed_dim // 2)
self.imag_patch_embed = PatchEmbed(feature_size, patch_size, in_channels // 2, embed_dim // 2)
self.num_patches = self.real_patch_embed.num_patches
self.complex_axis = complex_axis
self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
self.pos_embed = nn.Parameter(torch.zeros(1, self.num_patches + 1, embed_dim), requires_grad=False) # fixed sin-cos embedding
self.initialize_weights()
def initialize_weights(self):
# initialization
# initialize (and freeze) pos_embed by sin-cos embedding
pos_embed = get_2d_sincos_pos_embed(self.pos_embed.shape[-1], int(self.num_patches**.5), cls_token=True)
self.pos_embed.data.copy_(torch.from_numpy(pos_embed).float().unsqueeze(0))
# initialize patch_embed like nn.Linear (instead of nn.Conv2d)
w_r = self.real_patch_embed.proj.weight.data
w_i = self.imag_patch_embed.proj.weight.data
torch.nn.init.xavier_uniform_(w_r.view([w_r.shape[0], -1]))
torch.nn.init.xavier_uniform_(w_i.view([w_i.shape[0], -1]))
# timm's trunc_normal_(std=.02) is effectively normal_(std=0.02) as cutoff is too big (2.)
torch.nn.init.normal_(self.cls_token, std=.02)
# initialize nn.Linear and nn.LayerNorm
self.apply(self._init_weights)
def _init_weights(self, m):
if isinstance(m, nn.Linear):
# we use xavier_uniform following official JAX ViT:
torch.nn.init.xavier_uniform_(m.weight)
if isinstance(m, nn.Linear) and m.bias is not None:
nn.init.constant_(m.bias, 0)
elif isinstance(m, nn.LayerNorm):
nn.init.constant_(m.bias, 0)
nn.init.constant_(m.weight, 1.0)
def forward(self, x, channel_first=True):
# x: [B, 2, F, T]
real, imag = torch.chunk(x, 2, self.complex_axis) # Split real and imaginary parts
real_tokens = self.real_patch_embed(real) # [B, num_patches, embed_dim // 2]
imag_tokens = self.imag_patch_embed(imag) # [B, num_patches, embed_dim // 2]
x = torch.cat([real_tokens, imag_tokens], dim=-1) # Concatenate real and imaginary tokens
x = x + self.pos_embed[:, 1:, :]
cls_token = self.cls_token + self.pos_embed[:, :1, :]
cls_tokens = cls_token.expand(x.shape[0], -1, -1)
x = torch.cat((cls_tokens, x), dim=1) # [B, num_patches + 1, embed_dim]
if channel_first:
x = x.transpose(1, 2) # (B, C, T)
return x
if __name__ == "__main__":
    # Smoke test: run a random batch (real/imaginary parts as two channels)
    # through the 2D tokenizer and report the shape of the produced tokens.
    batch_size, in_channels = 4, 2
    feature_size = (256, 256)
    patch_size, embed_dim = 16, 512

    tokenizer = CV2DTokenizer(
        feature_size=feature_size,
        patch_size=patch_size,
        in_channels=in_channels,
        embed_dim=embed_dim,
    )
    dummy_input = torch.randn(batch_size, in_channels, *feature_size)
    output_tokens = tokenizer(dummy_input)
    print("Output tokens shape:", output_tokens.shape)