import torch
import torch.nn as nn


class Attention(nn.Module):
    """Multi-head self-attention over a sequence of patch embeddings.

    Config keys (read in ``__init__``):
        num_heads:  number of attention heads.
        hidden_dim: model/channel dimension of the input and output.
        head_dim:   per-head dimension; total attention dim is
                    ``num_heads * head_dim``.

    Input/output shape: ``(batch, num_patches, hidden_dim)``.
    """

    def __init__(self, config) -> None:
        super().__init__()
        self.n_heads = config["num_heads"]
        self.hidden_dim = config["hidden_dim"]
        self.head_dim = config["head_dim"]
        self.attn_dim = self.n_heads * self.head_dim
        # Hoisted loop-invariant: 1/sqrt(head_dim) scaling used every forward.
        self.scale = self.head_dim ** (-0.5)

        # Fused QKV projection: one matmul producing q, k, v concatenated
        # along the last dim.
        self.qkv_proj = nn.Linear(self.hidden_dim, 3 * self.attn_dim, bias=True)
        # Kept as a one-element Sequential so state-dict keys
        # ("output_proj.0.weight"/"output_proj.0.bias") stay compatible
        # with existing checkpoints.
        self.output_proj = nn.Sequential(nn.Linear(self.attn_dim, self.hidden_dim))

        # DiT-style initialization: Xavier-uniform weights, zero biases.
        nn.init.xavier_uniform_(self.qkv_proj.weight)
        nn.init.constant_(self.qkv_proj.bias, 0)
        nn.init.xavier_uniform_(self.output_proj[0].weight)
        nn.init.constant_(self.output_proj[0].bias, 0)

    def forward(self, x):
        """Apply multi-head self-attention.

        Args:
            x: tensor of shape (batch, num_patches, hidden_dim).

        Returns:
            Tensor of the same shape as ``x``.
        """
        B, N = x.shape[:2]

        # (B, N, 3*attn_dim) -> three tensors of (B, N, attn_dim).
        q, k, v = self.qkv_proj(x).split(self.attn_dim, dim=-1)

        # (B, N, heads*head_dim) -> (B, heads, N, head_dim) using native
        # view/transpose (no einops dependency needed for plain reshapes).
        q = q.view(B, N, self.n_heads, self.head_dim).transpose(1, 2)
        k = k.view(B, N, self.n_heads, self.head_dim).transpose(1, 2)
        v = v.view(B, N, self.n_heads, self.head_dim).transpose(1, 2)

        # Scaled dot-product attention:
        # (B, H, N, d) @ (B, H, d, N) -> (B, H, N, N) score matrix.
        attn = torch.matmul(q, k.transpose(-2, -1)) * self.scale
        attn = nn.functional.softmax(attn, dim=-1)
        out = torch.matmul(attn, v)

        # (B, H, N, head_dim) -> (B, N, heads*head_dim).
        out = out.transpose(1, 2).reshape(B, N, self.attn_dim)
        out = self.output_proj(out)

        # Internal invariant: attention is shape-preserving.
        assert out.shape == x.shape, "Output shape should be equal to Input shape"
        return out


if __name__ == "__main__":
    # Smoke test: run a random batch through the layer.
    config = {"num_heads": 2, "hidden_dim": 64, "head_dim": 16}
    attention_layer = Attention(config)

    batch_size = 8
    num_patches = 10
    x = torch.randn(batch_size, num_patches, config["hidden_dim"])

    output = attention_layer(x)
    print(output.shape)