import torch.nn as nn

from model.attention import Attention
# from attention import Attention


class TransformerLayer(nn.Module):
    def __init__(self, config) -> None:
        super().__init__()
        self.hidden_dim = config["hidden_dim"]

        # Following "Attention Is All You Need", the de facto feed-forward dimension
        # is 4x the hidden dimension.
        self.ff_dim = 4 * self.hidden_dim

        # Layer norm for the attention block. elementwise_affine is False because the
        # scale and shift are predicted from the conditioning signal rather than
        # learned as static LayerNorm parameters.
        self.attn_norm = nn.LayerNorm(self.hidden_dim, elementwise_affine=False)
        self.attn_block = Attention(config)

        # Layer norm for the MLP block, also without learned affine parameters.
        self.ff_norm = nn.LayerNorm(self.hidden_dim, elementwise_affine=False)
        self.mlp_block = nn.Sequential(
            nn.Linear(self.hidden_dim, self.ff_dim),
            nn.GELU(approximate="tanh"),
            nn.Linear(self.ff_dim, self.hidden_dim),
        )

        # Scale/shift parameter predictions for this layer:
        # 1. Scale and shift for the layernorm of attention (2 * hidden_dim)
        # 2. Scale and shift for the layernorm of the MLP (2 * hidden_dim)
        # 3. Scale for the attention output prior to the residual connection (hidden_dim)
        # 4. Scale for the MLP output prior to the residual connection (hidden_dim)
        # Total: 6 * hidden_dim
        self.adaptive_norm_layer = nn.Sequential(
            nn.SiLU(),
            nn.Linear(self.hidden_dim, 6 * self.hidden_dim, bias=True),
        )

        nn.init.xavier_uniform_(self.mlp_block[0].weight)
        nn.init.constant_(self.mlp_block[0].bias, 0)
        nn.init.xavier_uniform_(self.mlp_block[-1].weight)
        nn.init.constant_(self.mlp_block[-1].bias, 0)

        # Zero-init the adaptive norm projection so all predicted scales and shifts
        # start at zero.
        nn.init.constant_(self.adaptive_norm_layer[-1].weight, 0)
        nn.init.constant_(self.adaptive_norm_layer[-1].bias, 0)

    def forward(self, x, condition):
        # Predict the six scale/shift vectors from the conditioning signal.
        scale_shift_params = self.adaptive_norm_layer(condition).chunk(6, dim=1)
        (
            pre_attn_shift,
            pre_attn_scale,
            post_attn_scale,
            pre_mlp_shift,
            pre_mlp_scale,
            post_mlp_scale,
        ) = scale_shift_params

        out = x

        # Attention block: modulated layernorm -> attention -> scaled residual.
        attn_norm = self.attn_norm(out) * (
            1 + pre_attn_scale.unsqueeze(1)
        ) + pre_attn_shift.unsqueeze(1)
        out = out + self.attn_block(attn_norm) * post_attn_scale.unsqueeze(1)

        # MLP block: modulated layernorm -> feed-forward -> scaled residual.
        mlp_norm = self.ff_norm(out) * (
            1 + pre_mlp_scale.unsqueeze(1)
        ) + pre_mlp_shift.unsqueeze(1)
        out = out + self.mlp_block(mlp_norm) * post_mlp_scale.unsqueeze(1)

        return out


# if __name__ == "__main__":
#     import torch
#
#     config = {"hidden_dim": 64, "num_heads": 4, "head_dim": 16}
#
#     # Initialize input tensor x
#     batch_size = 8
#     num_patches = 10
#     x = torch.randn(batch_size, num_patches, config["hidden_dim"])
#     condition = torch.randn(batch_size, config["hidden_dim"])
#
#     # Initialize the layer
#     layer = TransformerLayer(config)
#     output = layer(x, condition)
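
# A minimal sketch (assuming model.attention.Attention accepts the config dict used
# in the example above): because adaptive_norm_layer is zero-initialized, every
# predicted scale and shift starts at zero, so both residual branches contribute
# nothing and a freshly constructed layer acts as an identity map on its input.
#
#     import torch
#
#     config = {"hidden_dim": 64, "num_heads": 4, "head_dim": 16}
#     layer = TransformerLayer(config).eval()
#     x = torch.randn(2, 10, config["hidden_dim"])
#     condition = torch.randn(2, config["hidden_dim"])
#     with torch.no_grad():
#         assert torch.allclose(layer(x, condition), x)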