import torch.nn as nn
from model.attention import Attention
# from attention import Attention
class TransformerLayer(nn.Module):
    """Transformer block with adaptive layer-norm (adaLN-Zero) conditioning.

    A conditioning vector predicts six per-layer modulation tensors that
    scale/shift the (non-affine) layer norms and gate each residual branch.
    The final projection of the modulation MLP is zero-initialized so the
    whole block starts as the identity function.
    """

    def __init__(self, config) -> None:
        super().__init__()
        self.hidden_dim = config["hidden_dim"]
        # De-facto feed-forward width from "Attention Is All You Need": 4 * d_model.
        self.ff_dim = 4 * self.hidden_dim

        # Norms are non-affine: their scale/shift are regressed from the
        # conditioning signal by adaptive_norm_layer rather than learnt as
        # free per-norm parameters.
        # BUGFIX: nn.LayerNorm's second positional argument is `eps`, so the
        # original `nn.LayerNorm(hidden_dim, ff_dim, ...)` silently set
        # eps = ff_dim and broke the normalization numerics.
        self.attn_norm = nn.LayerNorm(self.hidden_dim, elementwise_affine=False)
        self.attn_block = Attention(config)
        self.ff_norm = nn.LayerNorm(self.hidden_dim, elementwise_affine=False)

        # Position-wise feed-forward network.
        self.mlp_block = nn.Sequential(
            nn.Linear(self.hidden_dim, self.ff_dim),
            nn.GELU(approximate="tanh"),
            nn.Linear(self.ff_dim, self.hidden_dim),
        )

        # Scale/shift predictions for this layer, regressed from `condition`:
        #   1. scale + shift for the attention layer norm  (2 * hidden_dim)
        #   2. scale + shift for the mlp layer norm        (2 * hidden_dim)
        #   3. gate on the attention residual branch       (hidden_dim)
        #   4. gate on the mlp residual branch             (hidden_dim)
        # Total: 6 * hidden_dim.
        self.adaptive_norm_layer = nn.Sequential(
            nn.SiLU(),
            nn.Linear(self.hidden_dim, 6 * self.hidden_dim, bias=True),
        )

        nn.init.xavier_uniform_(self.mlp_block[0].weight)
        nn.init.constant_(self.mlp_block[0].bias, 0)
        nn.init.xavier_uniform_(self.mlp_block[-1].weight)
        nn.init.constant_(self.mlp_block[-1].bias, 0)
        # Zero-init the modulation projection (adaLN-Zero): all scales/shifts
        # start at 0, so the residual gates make this block an identity map.
        nn.init.constant_(self.adaptive_norm_layer[-1].weight, 0)
        nn.init.constant_(self.adaptive_norm_layer[-1].bias, 0)

    def forward(self, x, condition):
        """Apply the conditioned transformer block.

        Args:
            x: token tensor; assumed (batch, num_patches, hidden_dim) —
                matches the commented example at the bottom of this file.
            condition: conditioning vector, (batch, hidden_dim).

        Returns:
            Tensor of the same shape as ``x``.
        """
        (
            pre_attn_shift,
            pre_attn_scale,
            post_attn_scale,
            pre_mlp_shift,
            pre_mlp_scale,
            post_mlp_scale,
        ) = self.adaptive_norm_layer(condition).chunk(6, dim=1)

        # Attention branch: modulated norm, then gated residual add.
        # unsqueeze(1) broadcasts the per-sample (batch, hidden) modulation
        # over the patch dimension.
        out = x
        attn_norm = self.attn_norm(out) * (
            1 + pre_attn_scale.unsqueeze(1)
        ) + pre_attn_shift.unsqueeze(1)
        out = out + self.attn_block(attn_norm) * post_attn_scale.unsqueeze(1)

        # BUGFIX: the MLP path must use ff_norm, not attn_norm. The original
        # reused attn_norm here, leaving ff_norm dead code (numerics happened
        # to match only because both norms are non-affine).
        mlp_norm = self.ff_norm(out) * (
            1 + pre_mlp_scale.unsqueeze(1)
        ) + pre_mlp_shift.unsqueeze(1)
        out = out + self.mlp_block(mlp_norm) * post_mlp_scale.unsqueeze(1)
        return out
# if __name__ == "__main__":
# config = {"hidden_dim": 64, "num_heads": 4, "head_dim": 16}
#
# # Initialize input tensor x
# batch_size = 8
# num_patches = 10
# x = torch.randn(batch_size, num_patches, config["hidden_dim"])
# condition = torch.randn(batch_size, config["hidden_dim"])
#
# # Initialize the layer
# layer = TransformerLayer(config)
# output = layer(x, condition)
|