# Diffusion-Transformer / model/attention.py
# Last commit: "Remove print statements" (fdc368b) by YashNagraj75
import torch
import torch.nn as nn
from einops import rearrange
class Attention(nn.Module):
    """Multi-head self-attention block for a Diffusion Transformer (DiT).

    Config keys read: ``num_heads``, ``hidden_dim``, ``head_dim``.
    Input and output have shape ``(batch, num_patches, hidden_dim)``.
    """

    def __init__(self, config) -> None:
        super().__init__()
        self.n_heads = config["num_heads"]
        self.hidden_dim = config["hidden_dim"]
        self.head_dim = config["head_dim"]
        self.attn_dim = self.n_heads * self.head_dim
        # Fused projection: Q, K and V are produced by a single matmul and
        # split apart in forward().
        self.qkv_proj = nn.Linear(self.hidden_dim, 3 * self.attn_dim, bias=True)
        # Kept as a one-element Sequential so state_dict keys
        # ("output_proj.0.weight"/".0.bias") stay compatible with
        # checkpoints saved by the original implementation.
        self.output_proj = nn.Sequential(nn.Linear(self.attn_dim, self.hidden_dim))
        # DiT-style initialization: Xavier-uniform weights, zero biases.
        nn.init.xavier_uniform_(self.qkv_proj.weight)
        nn.init.constant_(self.qkv_proj.bias, 0)
        nn.init.xavier_uniform_(self.output_proj[0].weight)
        nn.init.constant_(self.output_proj[0].bias, 0)

    def forward(self, x):
        """Apply multi-head self-attention.

        Args:
            x: tensor of shape (batch, num_patches, hidden_dim).

        Returns:
            Tensor of the same shape as ``x``.
        """
        B, N = x.shape[:2]
        # (B, N, 3*attn_dim) -> (B, N, 3, heads, head_dim)
        # -> (3, B, heads, N, head_dim), then unbind into q, k, v.
        qkv = self.qkv_proj(x).reshape(B, N, 3, self.n_heads, self.head_dim)
        q, k, v = qkv.permute(2, 0, 3, 1, 4).unbind(0)
        # softmax(Q K^T / sqrt(head_dim)) V — same math as the manual
        # matmul/softmax/matmul chain, but dispatched to fused, numerically
        # stable kernels (flash / memory-efficient attention where available).
        out = nn.functional.scaled_dot_product_attention(q, k, v)
        # (B, heads, N, head_dim) -> (B, N, heads * head_dim)
        out = out.transpose(1, 2).reshape(B, N, self.attn_dim)
        out = self.output_proj(out)
        # Sanity check: attention must preserve the token shape so it can be
        # used inside a residual branch.
        assert out.shape == x.shape, "Output shape should be equal to Input shape"
        return out
# if __name__ == "__main__":
# config = {"num_heads": 2, "hidden_dim": 64, "head_dim": 16}
#
# # Create an instance of the Attention class
# attention_layer = Attention(config)
#
# # Generate a random input tensor with shape (batch_size, num_patches, hidden_dim)
# batch_size = 8
# num_patches = 10
# hidden_dim = config["hidden_dim"]
# x = torch.randn(batch_size, num_patches, hidden_dim)
#
# # Pass the input through the attention layer
# output = attention_layer(x)