"""Transformer building blocks for the Empath_EmotionClassifier (MAE/ViT
architecture): multi-head attention, encoder, and decoder modules."""

import copy
import math

import torch
import torch.nn as nn


class MultiHeadAttention(nn.Module):
"""
Multi-Head Attention:
1) Linear projections for Q, K, V
2) Scaled dot-product attention per head
3) Concatenate heads and final linear projection
https://arxiv.org/pdf/1706.03762
"""
    def __init__(self, embed_dim: int, key_dim: int, num_heads: int, dropout: float = 0.0):
super().__init__()
assert embed_dim % num_heads == 0, "embed_dim must be divisible by num_heads"
self.embed_dim = embed_dim
self.key_dim = key_dim
self.num_heads = num_heads
self.head_dim = embed_dim // num_heads
        self.scale = math.sqrt(self.head_dim)  # sqrt(d_k) scaling factor
        # Separate linear projections for Q, K, and V (three different transformations).
        # Q is projected from embed_dim; K and V from key_dim, which may differ
        # from embed_dim (as in cross-attention).
self.q_proj = nn.Linear(embed_dim, embed_dim)
        self.k_proj = nn.Linear(key_dim, embed_dim)
        self.v_proj = nn.Linear(key_dim, embed_dim)
        # Output projection after concatenating heads (embed_dim -> embed_dim)
        self.out_proj = nn.Linear(embed_dim, embed_dim)
# Dropouts
self.attn_dropout = nn.Dropout(dropout)
self.proj_dropout = nn.Dropout(dropout)

    def forward(self, query: torch.Tensor, key: torch.Tensor, value: torch.Tensor) -> torch.Tensor:
"""
Args:
query: [batch, seq_q, embed_dim]
            key: [batch, seq_k, key_dim]
            value: [batch, seq_k, key_dim]
Returns:
out: [batch, seq_q, embed_dim]
"""
B, seq_q, _ = query.size()
_, seq_k, _ = key.size()
# 1) Project inputs and split into heads
q = self.q_proj(query).view(B, seq_q, self.num_heads, self.head_dim).transpose(1, 2) # [B, heads, seq_q, head_dim]
k = self.k_proj(key).view(B, seq_k, self.num_heads, self.head_dim).transpose(1,2) # [B, heads, seq_k, head_dim]
v = self.v_proj(value).view(B, seq_k, self.num_heads, self.head_dim).transpose(1,2) # [B, heads, seq_k, head_dim]
# 2) Compute scaled dot-product attention
        scores = (q @ k.transpose(-1, -2)) / self.scale  # [B, heads, seq_q, seq_k]; transpose swaps k's last two dims
        weights = torch.softmax(scores, dim=-1)  # softmax over the key dimension
        weights = self.attn_dropout(weights)
        attn = weights @ v  # [B, heads, seq_q, head_dim]
# 3) Concatenate heads
attn = attn.transpose(1, 2).contiguous().view(B, seq_q, self.embed_dim) # [B, seq_q, embed_dim]
# 4) Final projection
        out = self.out_proj(attn)
out = self.proj_dropout(out)
return out
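

# A minimal shape sanity check for MultiHeadAttention. The dimensions below are
# illustrative assumptions only, not the Empath model's actual configuration;
# note that key_dim may differ from embed_dim (the cross-attention case).
if __name__ == "__main__":
    _mha = MultiHeadAttention(embed_dim=64, key_dim=32, num_heads=4, dropout=0.0)
    _q = torch.randn(2, 10, 64)   # [batch, seq_q, embed_dim]
    _kv = torch.randn(2, 7, 32)   # [batch, seq_k, key_dim]
    _out = _mha(_q, _kv, _kv)
    assert _out.shape == (2, 10, 64)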


class TransformerEncoderLayer(nn.Module):
"""
Transformer Encoder Layer:
1) Multi-head self-attention
2) Feed-forward network
3) Residual connections + LayerNorm
"""
def __init__(
self,
embed_dim: int,
num_heads: int,
mlp_dim: int,
dropout: float = 0.1,
):
super().__init__()
# 1) Self-attention
self.self_attn = MultiHeadAttention(
embed_dim=embed_dim,
key_dim=embed_dim, # self-attention uses same dimension for Q, K, V
num_heads=num_heads,
dropout=dropout,
)
# 2) Feed-forward network using nn.Sequential
self.ffn = nn.Sequential(
nn.Linear(embed_dim, mlp_dim),
nn.GELU(),
nn.Dropout(dropout),
nn.Linear(mlp_dim, embed_dim),
)
# 3) LayerNorm and Dropouts for residuals
self.norm1 = nn.LayerNorm(embed_dim)
self.norm2 = nn.LayerNorm(embed_dim)
self.attn_dropout = nn.Dropout(dropout)
self.ff_dropout = nn.Dropout(dropout)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
"""
Args:
x: Tensor of shape [batch, seq_len, embed_dim]
Returns:
Tensor of same shape
"""
# 1) Self-attention block
x_norm = self.norm1(x) # Normalize input
attn_out = self.self_attn(x_norm, x_norm, x_norm)
x = x + self.attn_dropout(attn_out) # Residual connection
# 2) Feed-forward block
x_norm_ff = self.norm2(x)
ff = self.ffn(x_norm_ff)
x = x + self.ff_dropout(ff) # Residual connection
return x


class TransformerEncoder(nn.Module):
def __init__(self, encoder_layer: TransformerEncoderLayer, num_layers: int, embed_dim: int):
super().__init__()
# Clone the provided encoder_layer num_layers times
self.layers = nn.ModuleList([copy.deepcopy(encoder_layer) for _ in range(num_layers)])
self.norm = nn.LayerNorm(embed_dim)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Pass the input through each encoder layer in sequence
for layer in self.layers:
x = layer(x)
# Apply final normalization
x = self.norm(x)
return x
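

# A quick smoke test for the encoder stack on a ViT-style token sequence
# (e.g. patch embeddings plus a class token). Sizes are illustrative
# assumptions, not the Empath model's actual hyperparameters.
if __name__ == "__main__":
    _enc_layer = TransformerEncoderLayer(embed_dim=64, num_heads=4, mlp_dim=128, dropout=0.0)
    _encoder = TransformerEncoder(_enc_layer, num_layers=2, embed_dim=64)
    _tokens = torch.randn(2, 50, 64)  # [batch, num_tokens, embed_dim]
    assert _encoder(_tokens).shape == (2, 50, 64)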


class TransformerDecoderLayer(nn.Module):
"""
Transformer Decoder Layer:
1) Self-attention
2) Cross-attention (over encoder features)
3) Feed-forward network
4) Residual connections + LayerNorm
"""
def __init__(
self,
encoder_embed_dim: int,
decoder_embed_dim: int,
num_heads: int,
mlp_dim: int,
dropout: float = 0.1,
):
super().__init__()
# 1. Self-attention on decoder input
self.self_attn = MultiHeadAttention(decoder_embed_dim, decoder_embed_dim, num_heads, dropout)
# 2. Cross-attention over encoder features
self.cross_attn = MultiHeadAttention(decoder_embed_dim, encoder_embed_dim, num_heads, dropout)
# 3. Feed-forward network
self.ffn = nn.Sequential(
nn.Linear(decoder_embed_dim, mlp_dim),
nn.GELU(),
nn.Dropout(dropout),
nn.Linear(mlp_dim, decoder_embed_dim)
)
# 4. LayerNorms and Dropouts
self.norm1 = nn.LayerNorm(decoder_embed_dim)
self.norm2 = nn.LayerNorm(decoder_embed_dim)
self.norm3 = nn.LayerNorm(decoder_embed_dim)
self.dropout1 = nn.Dropout(dropout)
self.dropout2 = nn.Dropout(dropout)
self.dropout3 = nn.Dropout(dropout)

    def forward(self, x: torch.Tensor, encoder_features: torch.Tensor) -> torch.Tensor:
# 1) Self-attention block
x_norm = self.norm1(x) # Normalize input
sa = self.self_attn(x_norm, x_norm, x_norm)
        x = x + self.dropout1(sa)  # Residual connection
# 2) Cross-attention block
x_norm_ca = self.norm2(x)
ca = self.cross_attn(x_norm_ca, encoder_features, encoder_features)
        x = x + self.dropout2(ca)  # Residual connection
# 3) Feed-forward block
x_norm_ff = self.norm3(x)
ff = self.ffn(x_norm_ff)
        x = x + self.dropout3(ff)  # Residual connection
return x


class TransformerDecoder(nn.Module):
def __init__(
self,
decoder_layer: TransformerDecoderLayer,
num_layers: int,
embed_dim: int,
):
super().__init__()
# Clone the provided decoder_layer num_layers times
self.layers = nn.ModuleList([copy.deepcopy(decoder_layer) for _ in range(num_layers)])
self.norm = nn.LayerNorm(embed_dim)

    def forward(self, x: torch.Tensor, encoder_features: torch.Tensor) -> torch.Tensor:
        # Pass the input and encoder features through each decoder layer in sequence
for layer in self.layers:
x = layer(x, encoder_features)
x = self.norm(x)
return x
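

# End-to-end smoke test in an MAE-style configuration, where the decoder is
# narrower than the encoder and cross-attention projects the encoder features
# (encoder_embed_dim) down to the decoder width. All sizes here are
# illustrative assumptions, not the Empath model's actual hyperparameters.
if __name__ == "__main__":
    _enc_layer = TransformerEncoderLayer(embed_dim=64, num_heads=4, mlp_dim=128, dropout=0.0)
    _encoder = TransformerEncoder(_enc_layer, num_layers=2, embed_dim=64)
    _dec_layer = TransformerDecoderLayer(
        encoder_embed_dim=64,
        decoder_embed_dim=32,
        num_heads=4,
        mlp_dim=64,
        dropout=0.0,
    )
    _decoder = TransformerDecoder(_dec_layer, num_layers=2, embed_dim=32)
    _memory = _encoder(torch.randn(2, 50, 64))  # encoder features
    _queries = torch.randn(2, 196, 32)          # decoder input tokens
    assert _decoder(_queries, _memory).shape == (2, 196, 32)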