File size: 3,146 Bytes
0c120cf
 
31677e7
 
 
 
0c120cf
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
import torch.nn as nn

from model.attention import Attention

# from attention import Attention



class TransformerLayer(nn.Module):
    """DiT-style transformer block with adaptive layer norm (adaLN-Zero).

    A conditioning vector is projected to six per-sample parameter vectors
    that shift/scale the pre-attention and pre-MLP layer norms and gate the
    attention/MLP outputs before each residual connection.
    """

    def __init__(self, config) -> None:
        """Build the block.

        Args:
            config: dict with at least "hidden_dim"; the full dict is also
                forwarded to Attention (presumably it reads head-count keys —
                confirm against model.attention).
        """
        super().__init__()
        self.hidden_dim = config["hidden_dim"]

        # According to "Attention is All You Need", de-facto ff dimension
        # is 4 * d_model.
        self.ff_dim = 4 * self.hidden_dim

        # Affine is disabled because scale/shift are regressed from the
        # conditioning signal (adaptive_norm_layer) rather than learnt as
        # static LayerNorm parameters.
        # BUG FIX: nn.LayerNorm's second positional argument is `eps`, not a
        # second shape dimension — previously self.ff_dim was silently passed
        # as eps, effectively disabling the normalization.
        self.attn_norm = nn.LayerNorm(self.hidden_dim, elementwise_affine=False)

        self.attn_block = Attention(config)

        # Pre-MLP layer norm; same no-affine rationale as attn_norm.
        self.ff_norm = nn.LayerNorm(self.hidden_dim, elementwise_affine=False)

        self.mlp_block = nn.Sequential(
            nn.Linear(self.hidden_dim, self.ff_dim),
            nn.GELU(approximate="tanh"),
            nn.Linear(self.ff_dim, self.hidden_dim),
        )

        # Scale/shift parameter predictions for this layer:
        # 1. shift + scale for the attention layer norm  (2 * hidden_dim)
        # 2. shift + scale for the MLP layer norm        (2 * hidden_dim)
        # 3. gate for the attention output               (hidden_dim)
        # 4. gate for the MLP output                     (hidden_dim)
        # Total: 6 * hidden_dim
        self.adaptive_norm_layer = nn.Sequential(
            nn.SiLU(),
            nn.Linear(self.hidden_dim, 6 * self.hidden_dim, bias=True),
        )

        nn.init.xavier_uniform_(self.mlp_block[0].weight)
        nn.init.constant_(self.mlp_block[0].bias, 0)
        nn.init.xavier_uniform_(self.mlp_block[-1].weight)
        nn.init.constant_(self.mlp_block[-1].bias, 0)

        # adaLN-Zero init: zero the modulation projection so every gate and
        # modulation starts at zero and the block begins as (near) identity.
        nn.init.constant_(self.adaptive_norm_layer[-1].weight, 0)
        nn.init.constant_(self.adaptive_norm_layer[-1].bias, 0)

    def forward(self, x, condition):
        """Apply the block.

        Args:
            x: (batch, num_patches, hidden_dim) token sequence —
                assumed from the unsqueeze(1) broadcasting below.
            condition: (batch, hidden_dim) conditioning vector.

        Returns:
            Tensor with the same shape as ``x``.
        """
        (
            pre_attn_shift,
            pre_attn_scale,
            post_attn_scale,
            pre_mlp_shift,
            pre_mlp_scale,
            post_mlp_scale,
        ) = self.adaptive_norm_layer(condition).chunk(6, dim=1)

        out = x

        # Modulated pre-attention norm; unsqueeze(1) broadcasts the
        # per-sample parameters over the patch dimension.
        attn_norm = self.attn_norm(out) * (
            1 + pre_attn_scale.unsqueeze(1)
        ) + pre_attn_shift.unsqueeze(1)
        out = out + self.attn_block(attn_norm) * post_attn_scale.unsqueeze(1)

        # BUG FIX: previously reused self.attn_norm here, leaving self.ff_norm
        # dead; the MLP branch uses its own dedicated norm.
        mlp_norm = self.ff_norm(out) * (
            1 + pre_mlp_scale.unsqueeze(1)
        ) + pre_mlp_shift.unsqueeze(1)
        out = out + self.mlp_block(mlp_norm) * post_mlp_scale.unsqueeze(1)

        return out


# if __name__ == "__main__":
#    config = {"hidden_dim": 64, "num_heads": 4, "head_dim": 16}
#
#    # Initialize input tensor x
#    batch_size = 8
#    num_patches = 10
#    x = torch.randn(batch_size, num_patches, config["hidden_dim"])
#    condition = torch.randn(batch_size, config["hidden_dim"])
#
#    # Initialize the layer
#    layer = TransformerLayer(config)
#    output = layer(x, condition)