import copy
import math

import torch
import torch.nn as nn


class MultiHeadAttention(nn.Module):
    """
    Multi-head attention (https://arxiv.org/pdf/1706.03762):
      1) Linear projections for Q, K, V
      2) Scaled dot-product attention per head
      3) Concatenation of heads followed by a final linear projection

    `key_dim` is the feature size of the key/value inputs; it may differ from
    `embed_dim` when the module is used for cross-attention.
    """

    def __init__(self, embed_dim: int, key_dim: int, num_heads: int, dropout: float = 0.0):
        super().__init__()

        assert embed_dim % num_heads == 0, "embed_dim must be divisible by num_heads"
        self.embed_dim = embed_dim
        self.key_dim = key_dim
        self.num_heads = num_heads
        self.head_dim = embed_dim // num_heads
        self.scale = math.sqrt(self.head_dim)

        # Queries are projected from the embed_dim stream; keys and values may come
        # from a different feature space of size key_dim.
        self.q_proj = nn.Linear(embed_dim, embed_dim)
        self.k_proj = nn.Linear(key_dim, embed_dim)
        self.v_proj = nn.Linear(key_dim, embed_dim)

        # Output projection applied after the heads are concatenated.
        self.out_proj = nn.Linear(embed_dim, embed_dim)

        self.attn_dropout = nn.Dropout(dropout)
        self.proj_dropout = nn.Dropout(dropout)

    def forward(self, query: torch.Tensor, key: torch.Tensor, value: torch.Tensor) -> torch.Tensor:
        """
        Args:
            query: [batch, seq_q, embed_dim]
            key:   [batch, seq_k, key_dim]
            value: [batch, seq_k, key_dim]
        Returns:
            out: [batch, seq_q, embed_dim]
        """
        B, seq_q, _ = query.size()
        _, seq_k, _ = key.size()

        # Project and split into heads: [batch, num_heads, seq, head_dim].
        q = self.q_proj(query).view(B, seq_q, self.num_heads, self.head_dim).transpose(1, 2)
        k = self.k_proj(key).view(B, seq_k, self.num_heads, self.head_dim).transpose(1, 2)
        v = self.v_proj(value).view(B, seq_k, self.num_heads, self.head_dim).transpose(1, 2)

        # Scaled dot-product attention: softmax(QK^T / sqrt(head_dim)) V.
        scores = (q @ k.transpose(-1, -2)) / self.scale
        weights = torch.softmax(scores, dim=-1)
        weights = self.attn_dropout(weights)
        attn = weights @ v

        # Merge heads back: [batch, seq_q, embed_dim].
        attn = attn.transpose(1, 2).contiguous().view(B, seq_q, self.embed_dim)

        out = self.out_proj(attn)
        out = self.proj_dropout(out)

        return out
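

# Usage sketch (illustrative sizes). key_dim may differ from embed_dim, e.g. when
# attending over encoder features of a different width:
#
#   mha = MultiHeadAttention(embed_dim=256, key_dim=512, num_heads=8)
#   q = torch.randn(2, 10, 256)    # [batch, seq_q, embed_dim]
#   kv = torch.randn(2, 20, 512)   # [batch, seq_k, key_dim]
#   out = mha(q, kv, kv)           # -> [2, 10, 256]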


class TransformerEncoderLayer(nn.Module):
    """
    Transformer encoder layer (pre-norm):
      1) Multi-head self-attention
      2) Feed-forward network
      3) Residual connections around both blocks, each preceded by LayerNorm
    """

    def __init__(
        self,
        embed_dim: int,
        num_heads: int,
        mlp_dim: int,
        dropout: float = 0.1,
    ):
        super().__init__()

        # Self-attention: queries, keys and values all come from the same stream,
        # so key_dim equals embed_dim.
        self.self_attn = MultiHeadAttention(
            embed_dim=embed_dim,
            key_dim=embed_dim,
            num_heads=num_heads,
            dropout=dropout,
        )

        # Position-wise feed-forward network.
        self.ffn = nn.Sequential(
            nn.Linear(embed_dim, mlp_dim),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(mlp_dim, embed_dim),
        )

        self.norm1 = nn.LayerNorm(embed_dim)
        self.norm2 = nn.LayerNorm(embed_dim)

        self.attn_dropout = nn.Dropout(dropout)
        self.ff_dropout = nn.Dropout(dropout)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Args:
            x: Tensor of shape [batch, seq_len, embed_dim]
        Returns:
            Tensor of the same shape
        """
        # Self-attention block (pre-norm + residual).
        x_norm = self.norm1(x)
        attn_out = self.self_attn(x_norm, x_norm, x_norm)
        x = x + self.attn_dropout(attn_out)

        # Feed-forward block (pre-norm + residual).
        x_norm_ff = self.norm2(x)
        ff = self.ffn(x_norm_ff)
        x = x + self.ff_dropout(ff)

        return x
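

# Usage sketch (illustrative sizes):
#
#   layer = TransformerEncoderLayer(embed_dim=256, num_heads=8, mlp_dim=1024)
#   x = torch.randn(2, 20, 256)    # [batch, seq_len, embed_dim]
#   y = layer(x)                   # -> [2, 20, 256]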


class TransformerEncoder(nn.Module):
    """Stack of TransformerEncoderLayer modules followed by a final LayerNorm."""

    def __init__(self, encoder_layer: TransformerEncoderLayer, num_layers: int, embed_dim: int):
        super().__init__()

        # Each layer is an independent deep copy of the prototype layer.
        self.layers = nn.ModuleList([copy.deepcopy(encoder_layer) for _ in range(num_layers)])
        self.norm = nn.LayerNorm(embed_dim)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Run the input through every encoder layer in sequence.
        for layer in self.layers:
            x = layer(x)

        x = self.norm(x)

        return x
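

# Usage sketch (illustrative sizes). The prototype layer is deep-copied num_layers
# times, so every layer has its own parameters:
#
#   layer = TransformerEncoderLayer(embed_dim=256, num_heads=8, mlp_dim=1024)
#   encoder = TransformerEncoder(layer, num_layers=6, embed_dim=256)
#   memory = encoder(torch.randn(2, 20, 256))    # -> [2, 20, 256]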


class TransformerDecoderLayer(nn.Module):
    """
    Transformer decoder layer (pre-norm):
      1) Self-attention over the decoder stream
      2) Cross-attention over the encoder features
      3) Feed-forward network
      4) Residual connections around each block, preceded by LayerNorm
    """

    def __init__(
        self,
        encoder_embed_dim: int,
        decoder_embed_dim: int,
        num_heads: int,
        mlp_dim: int,
        dropout: float = 0.1,
    ):
        super().__init__()

        # Self-attention: decoder tokens attend to each other.
        self.self_attn = MultiHeadAttention(decoder_embed_dim, decoder_embed_dim, num_heads, dropout)

        # Cross-attention: decoder queries attend to encoder features, which may
        # live in a different embedding dimension.
        self.cross_attn = MultiHeadAttention(decoder_embed_dim, encoder_embed_dim, num_heads, dropout)

        # Position-wise feed-forward network.
        self.ffn = nn.Sequential(
            nn.Linear(decoder_embed_dim, mlp_dim),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(mlp_dim, decoder_embed_dim),
        )

        self.norm1 = nn.LayerNorm(decoder_embed_dim)
        self.norm2 = nn.LayerNorm(decoder_embed_dim)
        self.norm3 = nn.LayerNorm(decoder_embed_dim)
        self.dropout1 = nn.Dropout(dropout)
        self.dropout2 = nn.Dropout(dropout)
        self.dropout3 = nn.Dropout(dropout)

    def forward(self, x: torch.Tensor, encoder_features: torch.Tensor) -> torch.Tensor:
        """
        Args:
            x: decoder input of shape [batch, tgt_len, decoder_embed_dim]
            encoder_features: encoder output of shape [batch, src_len, encoder_embed_dim]
        Returns:
            Tensor of shape [batch, tgt_len, decoder_embed_dim]
        """
        # Self-attention block (pre-norm + residual).
        x_norm = self.norm1(x)
        sa = self.self_attn(x_norm, x_norm, x_norm)
        x = x + self.dropout1(sa)

        # Cross-attention block: queries from the decoder, keys/values from the encoder.
        x_norm_ca = self.norm2(x)
        ca = self.cross_attn(x_norm_ca, encoder_features, encoder_features)
        x = x + self.dropout2(ca)

        # Feed-forward block (pre-norm + residual).
        x_norm_ff = self.norm3(x)
        ff = self.ffn(x_norm_ff)
        x = x + self.dropout3(ff)

        return x
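

# Usage sketch (illustrative sizes). Encoder features may have a different width than
# the decoder stream; cross-attention projects them to decoder_embed_dim internally:
#
#   layer = TransformerDecoderLayer(encoder_embed_dim=256, decoder_embed_dim=128,
#                                   num_heads=8, mlp_dim=512)
#   tgt = torch.randn(2, 10, 128)    # [batch, tgt_len, decoder_embed_dim]
#   mem = torch.randn(2, 20, 256)    # [batch, src_len, encoder_embed_dim]
#   y = layer(tgt, mem)              # -> [2, 10, 128]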


class TransformerDecoder(nn.Module):
    """Stack of TransformerDecoderLayer modules followed by a final LayerNorm."""

    def __init__(
        self,
        decoder_layer: TransformerDecoderLayer,
        num_layers: int,
        embed_dim: int,
    ):
        super().__init__()

        # Each layer is an independent deep copy of the prototype layer.
        self.layers = nn.ModuleList([copy.deepcopy(decoder_layer) for _ in range(num_layers)])
        self.norm = nn.LayerNorm(embed_dim)

    def forward(self, x: torch.Tensor, encoder_features: torch.Tensor) -> torch.Tensor:
        # Run the decoder input through every layer, attending to the encoder features.
        for layer in self.layers:
            x = layer(x, encoder_features)

        x = self.norm(x)

        return x
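

if __name__ == "__main__":
    # Minimal smoke test with illustrative sizes: runs the encoder and decoder stacks
    # end to end with different encoder/decoder widths and checks the output shapes.
    torch.manual_seed(0)

    enc_dim, dec_dim, heads, mlp = 256, 128, 8, 512

    enc_layer = TransformerEncoderLayer(embed_dim=enc_dim, num_heads=heads, mlp_dim=mlp)
    encoder = TransformerEncoder(enc_layer, num_layers=2, embed_dim=enc_dim)

    dec_layer = TransformerDecoderLayer(
        encoder_embed_dim=enc_dim,
        decoder_embed_dim=dec_dim,
        num_heads=heads,
        mlp_dim=mlp,
    )
    decoder = TransformerDecoder(dec_layer, num_layers=2, embed_dim=dec_dim)

    src = torch.randn(2, 20, enc_dim)    # [batch, src_len, enc_dim]
    tgt = torch.randn(2, 10, dec_dim)    # [batch, tgt_len, dec_dim]

    memory = encoder(src)                # -> [2, 20, enc_dim]
    out = decoder(tgt, memory)           # -> [2, 10, dec_dim]

    assert memory.shape == (2, 20, enc_dim)
    assert out.shape == (2, 10, dec_dim)
    print("ok:", tuple(memory.shape), tuple(out.shape))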