# File size: 5,319 Bytes
# e8160b2
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
import torch.nn as nn
import torch.nn.functional as F
import torch
from .blocks.complexmodule import ComplexConv1D
from .utils.pos_embed import get_2d_sincos_pos_embed
from timm.models.vision_transformer import PatchEmbed


class CV1DTokenizer(nn.Module):
    """Embed a complex-valued feature map with a complex 1D convolution.

    The real and imaginary planes of the input are stacked along the channel
    axis, passed through a ``ComplexConv1D`` layer, and each half of the
    result is layer-normalised over its channel dimension by one shared
    ``nn.LayerNorm``.

    Args:
        input_dims: Input size handed to ``ComplexConv1D``.
        hidden_dims: Output size handed to ``ComplexConv1D``; also the
            normalised width of the shared ``LayerNorm``.
        kernel_size: Convolution kernel width. The layer is built with
            stride 1 and padding ``(kernel_size - 1) // 2``, which is meant
            to keep the time length unchanged.
        complex_axis: Axis along which real and imaginary parts are split
            and re-joined.
    """

    def __init__(self,
                 input_dims: int = 256,
                 hidden_dims: int = 512,
                 kernel_size: int = 3,
                 complex_axis: int = 1
                 ) -> None:
        super().__init__()
        self.embedding = ComplexConv1D(
            input_dims,
            hidden_dims,
            kernel_size,
            stride=1,
            padding=(kernel_size - 1) // 2,  # "same"-style padding for the time axis
            complex_axis=complex_axis,
        )
        self.norm = nn.LayerNorm(hidden_dims, eps=1e-6)
        self.complex_axis = complex_axis

    def forward(self, x, channel_first=True):
        """Tokenize ``x`` and return the normalised embedding.

        Args:
            x: Input tensor, laid out as ``[B, 2, F, T]`` according to the
                original author's annotation.
            channel_first: Only the literal ``False`` switches the output to
                ``[B, T, C]``; any other value keeps ``[B, C, T]``.

        Returns:
            The embedded tensor with real and imaginary halves concatenated
            along ``complex_axis``.
        """
        axis = self.complex_axis

        # [B, 2, F, T] -> two [B, F, T] planes -> stacked [B, 2*F, T].
        # NOTE(review): the singleton dim that is dropped is always dim 1,
        # independent of complex_axis - confirm complex_axis != 1 is unused.
        stacked = torch.cat(
            [plane.squeeze(1) for plane in torch.chunk(x, 2, axis)],
            dim=axis,
        )

        embedded = self.embedding(stacked)

        # LayerNorm works on the last dim, so each half is flipped to
        # [B, T, C], normalised, and flipped back to [B, C, T].
        normalised_halves = [
            self.norm(half.transpose(1, 2)).transpose(1, 2)
            for half in torch.chunk(embedded, 2, dim=axis)
        ]
        tokens = torch.cat(normalised_halves, dim=axis)

        if channel_first is not False:
            return tokens
        return tokens.transpose(1, 2)  # [B, C, T] -> [B, T, C]


class CV2DTokenizer(nn.Module):
    def __init__(self, 
                 feature_size: int | tuple = (256, 256),
                 patch_size: int = 16,
                 in_channels: int = 2,
                 embed_dim: int = 512,
                 complex_axis: int = 1
                 ) -> None:
        super(CV2DTokenizer, self).__init__()
        self.real_patch_embed = PatchEmbed(feature_size, patch_size, in_channels // 2, embed_dim // 2)
        self.imag_patch_embed = PatchEmbed(feature_size, patch_size, in_channels // 2, embed_dim // 2)
        
        self.num_patches = self.real_patch_embed.num_patches
        self.complex_axis = complex_axis
        
        self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
        self.pos_embed = nn.Parameter(torch.zeros(1, self.num_patches + 1, embed_dim), requires_grad=False)  # fixed sin-cos embedding
        self.initialize_weights()
    
    def initialize_weights(self):
        # initialization
        # initialize (and freeze) pos_embed by sin-cos embedding
        pos_embed = get_2d_sincos_pos_embed(self.pos_embed.shape[-1], int(self.num_patches**.5), cls_token=True)
        self.pos_embed.data.copy_(torch.from_numpy(pos_embed).float().unsqueeze(0))

        # initialize patch_embed like nn.Linear (instead of nn.Conv2d)
        w_r = self.real_patch_embed.proj.weight.data
        w_i = self.imag_patch_embed.proj.weight.data
        torch.nn.init.xavier_uniform_(w_r.view([w_r.shape[0], -1]))
        torch.nn.init.xavier_uniform_(w_i.view([w_i.shape[0], -1]))

        # timm's trunc_normal_(std=.02) is effectively normal_(std=0.02) as cutoff is too big (2.)
        torch.nn.init.normal_(self.cls_token, std=.02)

        # initialize nn.Linear and nn.LayerNorm
        self.apply(self._init_weights)
    
    def _init_weights(self, m):
        if isinstance(m, nn.Linear):
            # we use xavier_uniform following official JAX ViT:
            torch.nn.init.xavier_uniform_(m.weight)
            if isinstance(m, nn.Linear) and m.bias is not None:
                nn.init.constant_(m.bias, 0)
        elif isinstance(m, nn.LayerNorm):
            nn.init.constant_(m.bias, 0)
            nn.init.constant_(m.weight, 1.0)
    
    def forward(self, x, channel_first=True):
        # x: [B, 2, F, T]
        real, imag = torch.chunk(x, 2, self.complex_axis) # Split real and imaginary parts
        
        real_tokens = self.real_patch_embed(real) # [B, num_patches, embed_dim // 2]
        imag_tokens = self.imag_patch_embed(imag) # [B, num_patches, embed_dim // 2]
        
        x = torch.cat([real_tokens, imag_tokens], dim=-1) # Concatenate real and imaginary tokens
        
        x = x + self.pos_embed[:, 1:, :]
        
        cls_token = self.cls_token + self.pos_embed[:, :1, :]
        cls_tokens = cls_token.expand(x.shape[0], -1, -1)
        x = torch.cat((cls_tokens, x), dim=1) # [B, num_patches + 1, embed_dim]
        
        if channel_first:
            x = x.transpose(1, 2) # (B, C, T)
        
        return x
    

if __name__ == "__main__":
    # Smoke test: push one random batch through the 2D tokenizer and print
    # the shape of the resulting token tensor.
    batch_size, in_channels = 4, 2
    feature_size = (256, 256)
    patch_size, embed_dim = 16, 512

    tokenizer = CV2DTokenizer(
        feature_size=feature_size,
        patch_size=patch_size,
        in_channels=in_channels,
        embed_dim=embed_dim,
    )

    # Random input shaped [batch, channels, height, width].
    dummy_input = torch.randn(batch_size, in_channels, *feature_size)
    output_tokens = tokenizer(dummy_input)

    print("Output tokens shape:", output_tokens.shape)