"""Complex-valued spectrogram tokenizers (1-D frame tokens and 2-D ViT-style patch tokens)."""

import math

import torch
import torch.nn as nn
import torch.nn.functional as F
from timm.models.vision_transformer import PatchEmbed

from .blocks.complexmodule import ComplexConv1D
from .utils.pos_embed import get_2d_sincos_pos_embed


class CV1DTokenizer(nn.Module):
    """Tokenize a complex spectrogram frame-by-frame with a complex 1-D convolution.

    The real and imaginary parts are concatenated along the channel axis and
    passed through ``ComplexConv1D`` (stride 1). Each output part is then
    layer-normalized over its feature dimension with one shared ``LayerNorm``.

    Args:
        input_dims: ``in_channels`` forwarded to ``ComplexConv1D``.
        hidden_dims: ``out_channels`` forwarded to ``ComplexConv1D``; also the
            normalized size of the shared ``LayerNorm``.
        kernel_size: Temporal kernel size of the convolution.
        complex_axis: Axis on which real and imaginary parts are stacked.
    """

    def __init__(self, input_dims: int = 256, hidden_dims: int = 512, kernel_size: int = 3,
                 complex_axis: int = 1) -> None:
        super().__init__()
        # Keeps the time dimension unchanged for odd kernel sizes at stride 1
        # (assuming ComplexConv1D follows nn.Conv1d length arithmetic).
        padding = (kernel_size - 1) // 2
        self.embedding = ComplexConv1D(input_dims, hidden_dims, kernel_size, stride=1,
                                       padding=padding, complex_axis=complex_axis)
        # A single LayerNorm instance is shared by the real and imaginary parts.
        self.norm = nn.LayerNorm(hidden_dims, eps=1e-6)
        self.complex_axis = complex_axis

    def forward(self, x, channel_first=True):
        """Embed ``x`` into normalized complex frame tokens.

        Args:
            x: Tensor with real/imaginary parts stacked on ``complex_axis``,
                e.g. ``[B, 2, F, T]`` for the default ``complex_axis=1``.
            channel_first: If true, return ``[B, C, T]``; otherwise ``[B, T, C]``.

        Returns:
            The normalized real and imaginary features concatenated along
            ``complex_axis``.
        """
        # [B, 2, F, T] -> two [B, 1, F, T] halves -> squeeze -> [B, F, T] each.
        # Squeeze the same axis that was chunked (identical to the old hard-coded
        # axis 1 for the default complex_axis).
        real, imag = torch.chunk(x, 2, dim=self.complex_axis)
        real = real.squeeze(self.complex_axis)
        imag = imag.squeeze(self.complex_axis)
        x = torch.cat([real, imag], dim=self.complex_axis)  # [B, 2*F, T]

        x = self.embedding(x)
        real, imag = torch.chunk(x, 2, dim=self.complex_axis)

        # LayerNorm normalizes the last dim, so move features last and back:
        # [B, hidden_dims, T] -> [B, T, hidden_dims] -> norm -> [B, hidden_dims, T].
        # NOTE(review): this requires each half to carry hidden_dims channels, i.e.
        # ComplexConv1D presumably emits 2 * hidden_dims channels in total — confirm.
        real = self.norm(real.transpose(1, 2)).transpose(1, 2)
        imag = self.norm(imag.transpose(1, 2)).transpose(1, 2)
        x = torch.cat([real, imag], dim=self.complex_axis)

        if not channel_first:
            x = x.transpose(1, 2)  # [B, C, T] -> [B, T, C]
        return x


class CV2DTokenizer(nn.Module):
    """ViT-style patch tokenizer for complex spectrograms.

    Two independent ``PatchEmbed`` modules embed the real and imaginary parts,
    each producing ``embed_dim // 2`` features, and the results are
    concatenated. A fixed (non-trainable) 2-D sin-cos positional embedding is
    added, and a learnable CLS token is prepended.

    Args:
        feature_size: Spatial input size forwarded to ``PatchEmbed``. The
            resulting patch grid must be square, because the positional
            embedding is built for a single grid side length.
        patch_size: Patch size forwarded to ``PatchEmbed``.
        in_channels: Total input channels (real + imaginary); must be even.
        embed_dim: Total token width (real + imaginary); must be even.
        complex_axis: Axis on which real and imaginary parts are stacked.

    Raises:
        ValueError: If ``in_channels`` or ``embed_dim`` is odd, or the patch
            grid is not square.
    """

    def __init__(self, feature_size: int | tuple = (256, 256), patch_size: int = 16, in_channels: int = 2,
                 embed_dim: int = 512, complex_axis: int = 1) -> None:
        super().__init__()
        # Each PatchEmbed handles one half, so both totals must split evenly;
        # otherwise the concatenated token width would not match pos_embed.
        if in_channels % 2 != 0:
            raise ValueError(f"in_channels must be even (real + imaginary), got {in_channels}")
        if embed_dim % 2 != 0:
            raise ValueError(f"embed_dim must be even (real + imaginary), got {embed_dim}")

        self.real_patch_embed = PatchEmbed(feature_size, patch_size, in_channels // 2, embed_dim // 2)
        self.imag_patch_embed = PatchEmbed(feature_size, patch_size, in_channels // 2, embed_dim // 2)
        self.num_patches = self.real_patch_embed.num_patches
        self.complex_axis = complex_axis

        self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
        # Fixed sin-cos embedding: row 0 is for the CLS token, rows 1: for patches.
        self.pos_embed = nn.Parameter(torch.zeros(1, self.num_patches + 1, embed_dim), requires_grad=False)

        self.initialize_weights()

    def initialize_weights(self):
        """Fill ``pos_embed`` with 2-D sin-cos values and initialize learnable weights.

        Raises:
            ValueError: If ``num_patches`` is not a perfect square.
        """
        # get_2d_sincos_pos_embed receives a single side length, so only a square
        # grid yields num_patches rows. Fail clearly instead of letting copy_
        # raise an opaque shape mismatch.
        grid_side = math.isqrt(self.num_patches)
        if grid_side * grid_side != self.num_patches:
            raise ValueError(
                "CV2DTokenizer needs a square patch grid for its 2-D sin-cos position "
                f"embedding, got {self.num_patches} patches"
            )
        pos_embed = get_2d_sincos_pos_embed(self.pos_embed.shape[-1], grid_side, cls_token=True)
        self.pos_embed.data.copy_(torch.from_numpy(pos_embed).float().unsqueeze(0))

        # Initialize the patch projections like nn.Linear (instead of nn.Conv2d)
        # by viewing each conv kernel as a 2-D [out, in * kh * kw] matrix.
        w_r = self.real_patch_embed.proj.weight.data
        w_i = self.imag_patch_embed.proj.weight.data
        torch.nn.init.xavier_uniform_(w_r.view([w_r.shape[0], -1]))
        torch.nn.init.xavier_uniform_(w_i.view([w_i.shape[0], -1]))

        # timm's trunc_normal_(std=.02) is effectively normal_(std=0.02) as cutoff is too big (2.)
        torch.nn.init.normal_(self.cls_token, std=.02)

        # Initialize every nn.Linear and nn.LayerNorm submodule.
        self.apply(self._init_weights)

    def _init_weights(self, m):
        """Per-module init: Xavier-uniform Linear weights, zero biases, unit LayerNorm."""
        if isinstance(m, nn.Linear):
            # we use xavier_uniform following official JAX ViT:
            torch.nn.init.xavier_uniform_(m.weight)
            if m.bias is not None:
                nn.init.constant_(m.bias, 0)
        elif isinstance(m, nn.LayerNorm):
            nn.init.constant_(m.bias, 0)
            nn.init.constant_(m.weight, 1.0)

    def forward(self, x, channel_first=True):
        """Tokenize ``x`` into a CLS token followed by positional patch tokens.

        Args:
            x: Tensor with real/imaginary parts stacked on ``complex_axis``,
                e.g. ``[B, 2, F, T]``.
            channel_first: If true, return ``[B, embed_dim, num_patches + 1]``;
                otherwise ``[B, num_patches + 1, embed_dim]``.
        """
        real, imag = torch.chunk(x, 2, self.complex_axis)
        real_tokens = self.real_patch_embed(real)  # [B, num_patches, embed_dim // 2]
        imag_tokens = self.imag_patch_embed(imag)  # [B, num_patches, embed_dim // 2]
        x = torch.cat([real_tokens, imag_tokens], dim=-1)  # [B, num_patches, embed_dim]

        x = x + self.pos_embed[:, 1:, :]
        cls_token = self.cls_token + self.pos_embed[:, :1, :]
        cls_tokens = cls_token.expand(x.shape[0], -1, -1)
        x = torch.cat((cls_tokens, x), dim=1)  # [B, num_patches + 1, embed_dim]

        if channel_first:
            x = x.transpose(1, 2)  # [B, embed_dim, num_patches + 1]
        return x


if __name__ == "__main__":
    # Smoke test: requires running as part of the package (e.g. `python -m ...`)
    # because of the relative imports above.
    batch_size = 4
    feature_size = (256, 256)
    patch_size = 16
    in_channels = 2
    embed_dim = 512
    tokenizer = CV2DTokenizer(feature_size, patch_size, in_channels, embed_dim)
    dummy_input = torch.randn(batch_size, in_channels, feature_size[0], feature_size[1])
    output_tokens = tokenizer(dummy_input)
    print("Output tokens shape:", output_tokens.shape)