Spaces:
Running
Running
| import torch.nn as nn | |
| import torch.nn.functional as F | |
| import torch | |
| from .blocks.complexmodule import ComplexConv1D | |
| from .utils.pos_embed import get_2d_sincos_pos_embed | |
| from timm.models.vision_transformer import PatchEmbed | |
class CV1DTokenizer(nn.Module):
    """Embed a complex-valued input into per-frame tokens via ``ComplexConv1D``.

    The real and imaginary halves (split along ``complex_axis``) are merged into
    one channel axis, embedded by ``ComplexConv1D``, split back into two halves,
    and each half is layer-normalized over its channel dimension. Note that a
    single ``LayerNorm`` module is shared by both halves.

    Args:
        input_dims: Input channel count forwarded to ``ComplexConv1D``.
        hidden_dims: Output channel count forwarded to ``ComplexConv1D``; also
            the normalized size of the ``LayerNorm``.
        kernel_size: Convolution kernel size. Padding is ``(kernel_size - 1) // 2``,
            intended to keep the time length unchanged (for odd kernel sizes).
        complex_axis: Axis along which real and imaginary parts are split and
            re-joined.
    """

    def __init__(self,
                 input_dims: int = 256,
                 hidden_dims: int = 512,
                 kernel_size: int = 3,
                 complex_axis: int = 1
                 ) -> None:
        super().__init__()
        # "Same"-style padding so the time dimension is preserved after the conv.
        same_pad = (kernel_size - 1) // 2
        self.embedding = ComplexConv1D(
            input_dims,
            hidden_dims,
            kernel_size,
            stride=1,
            padding=same_pad,
            complex_axis=complex_axis,
        )
        self.norm = nn.LayerNorm(hidden_dims, eps=1e-6)
        self.complex_axis = complex_axis

    def forward(self, x, channel_first=True):
        """Tokenize ``x``.

        Args:
            x: Complex input; the original comments describe it as
                ``[B, 2, F, T]`` (real/imag stacked on dim 1) — TODO confirm
                for non-default ``complex_axis``.
            channel_first: Unless this is exactly ``False``, the result is
                returned as ``[B, C, T]``; when ``False`` it is ``[B, T, C]``.

        Returns:
            The normalized real and imaginary embeddings concatenated along
            ``complex_axis``.
        """
        axis = self.complex_axis

        # [B, 2, F, T] -> two [B, F, T] halves -> merged [B, 2*F, T]
        re_in, im_in = torch.chunk(x, 2, axis)
        merged = torch.cat((re_in.squeeze(1), im_in.squeeze(1)), dim=axis)

        def _layer_norm_channels(part):
            # LayerNorm works on the last dim: move channels there and back.
            return self.norm(part.transpose(1, 2)).transpose(1, 2)

        re_out, im_out = torch.chunk(self.embedding(merged), 2, dim=axis)
        out = torch.cat(
            (_layer_norm_channels(re_out), _layer_norm_channels(im_out)),
            dim=axis,
        )
        if channel_first is not False:
            return out
        return out.transpose(1, 2)  # [B, C, T] -> [B, T, C]
| class CV2DTokenizer(nn.Module): | |
| def __init__(self, | |
| feature_size: int | tuple = (256, 256), | |
| patch_size: int = 16, | |
| in_channels: int = 2, | |
| embed_dim: int = 512, | |
| complex_axis: int = 1 | |
| ) -> None: | |
| super(CV2DTokenizer, self).__init__() | |
| self.real_patch_embed = PatchEmbed(feature_size, patch_size, in_channels // 2, embed_dim // 2) | |
| self.imag_patch_embed = PatchEmbed(feature_size, patch_size, in_channels // 2, embed_dim // 2) | |
| self.num_patches = self.real_patch_embed.num_patches | |
| self.complex_axis = complex_axis | |
| self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim)) | |
| self.pos_embed = nn.Parameter(torch.zeros(1, self.num_patches + 1, embed_dim), requires_grad=False) # fixed sin-cos embedding | |
| self.initialize_weights() | |
| def initialize_weights(self): | |
| # initialization | |
| # initialize (and freeze) pos_embed by sin-cos embedding | |
| pos_embed = get_2d_sincos_pos_embed(self.pos_embed.shape[-1], int(self.num_patches**.5), cls_token=True) | |
| self.pos_embed.data.copy_(torch.from_numpy(pos_embed).float().unsqueeze(0)) | |
| # initialize patch_embed like nn.Linear (instead of nn.Conv2d) | |
| w_r = self.real_patch_embed.proj.weight.data | |
| w_i = self.imag_patch_embed.proj.weight.data | |
| torch.nn.init.xavier_uniform_(w_r.view([w_r.shape[0], -1])) | |
| torch.nn.init.xavier_uniform_(w_i.view([w_i.shape[0], -1])) | |
| # timm's trunc_normal_(std=.02) is effectively normal_(std=0.02) as cutoff is too big (2.) | |
| torch.nn.init.normal_(self.cls_token, std=.02) | |
| # initialize nn.Linear and nn.LayerNorm | |
| self.apply(self._init_weights) | |
| def _init_weights(self, m): | |
| if isinstance(m, nn.Linear): | |
| # we use xavier_uniform following official JAX ViT: | |
| torch.nn.init.xavier_uniform_(m.weight) | |
| if isinstance(m, nn.Linear) and m.bias is not None: | |
| nn.init.constant_(m.bias, 0) | |
| elif isinstance(m, nn.LayerNorm): | |
| nn.init.constant_(m.bias, 0) | |
| nn.init.constant_(m.weight, 1.0) | |
| def forward(self, x, channel_first=True): | |
| # x: [B, 2, F, T] | |
| real, imag = torch.chunk(x, 2, self.complex_axis) # Split real and imaginary parts | |
| real_tokens = self.real_patch_embed(real) # [B, num_patches, embed_dim // 2] | |
| imag_tokens = self.imag_patch_embed(imag) # [B, num_patches, embed_dim // 2] | |
| x = torch.cat([real_tokens, imag_tokens], dim=-1) # Concatenate real and imaginary tokens | |
| x = x + self.pos_embed[:, 1:, :] | |
| cls_token = self.cls_token + self.pos_embed[:, :1, :] | |
| cls_tokens = cls_token.expand(x.shape[0], -1, -1) | |
| x = torch.cat((cls_tokens, x), dim=1) # [B, num_patches + 1, embed_dim] | |
| if channel_first: | |
| x = x.transpose(1, 2) # (B, C, T) | |
| return x | |
if __name__ == "__main__":
    # Smoke test: push a random two-channel (real/imag) batch through
    # CV2DTokenizer and report the shape of the resulting token tensor.
    batch_size, in_channels = 4, 2
    feature_size = (256, 256)
    tokenizer = CV2DTokenizer(
        feature_size=feature_size,
        patch_size=16,
        in_channels=in_channels,
        embed_dim=512,
    )
    dummy_input = torch.randn(batch_size, in_channels, *feature_size)
    output_tokens = tokenizer(dummy_input)
    print("Output tokens shape:", output_tokens.shape)