import copy
import math

import torch
import torch.nn as nn


class MultiHeadAttention(nn.Module):
    """
    Multi-Head Attention:
        1) Linear projections for Q, K, V
        2) Scaled dot-product attention per head
        3) Concatenate heads and apply a final linear projection

    Reference: "Attention Is All You Need", https://arxiv.org/pdf/1706.03762
    """

    def __init__(self, embed_dim: int, key_dim: int, num_heads: int, dropout: float = 0.0):
        super().__init__()
        assert embed_dim % num_heads == 0, "embed_dim must be divisible by num_heads"

        self.embed_dim = embed_dim
        self.key_dim = key_dim
        self.num_heads = num_heads
        self.head_dim = embed_dim // num_heads
        self.scale = math.sqrt(self.head_dim)  # sqrt(d_k), used to scale the attention logits

        # Separate linear projections for query, key, and value.
        # Queries come from the embed_dim space; keys and values may come from
        # a different key_dim space (e.g. encoder features in cross-attention).
        self.q_proj = nn.Linear(embed_dim, embed_dim)
        self.k_proj = nn.Linear(key_dim, embed_dim)
        self.v_proj = nn.Linear(key_dim, embed_dim)

        # Output projection applied after concatenating the heads (embed_dim -> embed_dim)
        self.out_proj = nn.Linear(embed_dim, embed_dim)

        # Dropouts on the attention weights and on the output projection
        self.attn_dropout = nn.Dropout(dropout)
        self.proj_dropout = nn.Dropout(dropout)

    def forward(self, query: torch.Tensor, key: torch.Tensor, value: torch.Tensor) -> torch.Tensor:
        """
        Args:
            query: [batch, seq_q, embed_dim]
            key:   [batch, seq_k, key_dim]
            value: [batch, seq_k, key_dim]
        Returns:
            out: [batch, seq_q, embed_dim]
        """
        B, seq_q, _ = query.size()
        _, seq_k, _ = key.size()

        # 1) Project inputs and split into heads
        q = self.q_proj(query).view(B, seq_q, self.num_heads, self.head_dim).transpose(1, 2)  # [B, heads, seq_q, head_dim]
        k = self.k_proj(key).view(B, seq_k, self.num_heads, self.head_dim).transpose(1, 2)    # [B, heads, seq_k, head_dim]
        v = self.v_proj(value).view(B, seq_k, self.num_heads, self.head_dim).transpose(1, 2)  # [B, heads, seq_k, head_dim]

        # 2) Scaled dot-product attention: the transpose swaps the last two dims of k
        scores = (q @ k.transpose(-1, -2)) / self.scale  # [B, heads, seq_q, seq_k]
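        # Scaled dot-product attention (Eq. 1 of "Attention Is All You Need"):
        #     Attention(Q, K, V) = softmax(Q K^T / sqrt(d_k)) V
        # Dividing by sqrt(d_k) keeps the logit magnitude from growing with head_dim,
        # which would otherwise push the softmax into regions with very small gradients.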
        weights = torch.softmax(scores, dim=-1)  # Normalize scores into attention weights
        weights = self.attn_dropout(weights)
        attn = weights @ v  # [B, heads, seq_q, head_dim]

        # 3) Concatenate heads
        attn = attn.transpose(1, 2).contiguous().view(B, seq_q, self.embed_dim)  # [B, seq_q, embed_dim]

        # 4) Final projection
        out = self.out_proj(attn)
        out = self.proj_dropout(out)
        return out


class TransformerEncoderLayer(nn.Module):
    """
    Transformer Encoder Layer (pre-norm):
        1) Multi-head self-attention
        2) Feed-forward network
        3) Residual connections + LayerNorm
    """

    def __init__(
        self,
        embed_dim: int,
        num_heads: int,
        mlp_dim: int,
        dropout: float = 0.1,
    ):
        super().__init__()

        # 1) Self-attention
        self.self_attn = MultiHeadAttention(
            embed_dim=embed_dim,
            key_dim=embed_dim,  # self-attention uses the same dimension for Q, K, V
            num_heads=num_heads,
            dropout=dropout,
        )

        # 2) Feed-forward network
        self.ffn = nn.Sequential(
            nn.Linear(embed_dim, mlp_dim),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(mlp_dim, embed_dim),
        )

        # 3) LayerNorms and dropouts for the residual branches
        self.norm1 = nn.LayerNorm(embed_dim)
        self.norm2 = nn.LayerNorm(embed_dim)
        self.attn_dropout = nn.Dropout(dropout)
        self.ff_dropout = nn.Dropout(dropout)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Args:
            x: Tensor of shape [batch, seq_len, embed_dim]
        Returns:
            Tensor of the same shape
        """
        # 1) Self-attention block (pre-norm + residual)
        x_norm = self.norm1(x)
        attn_out = self.self_attn(x_norm, x_norm, x_norm)
        x = x + self.attn_dropout(attn_out)

        # 2) Feed-forward block (pre-norm + residual)
        x_norm_ff = self.norm2(x)
        ff = self.ffn(x_norm_ff)
        x = x + self.ff_dropout(ff)

        return x


class TransformerEncoder(nn.Module):
    def __init__(self, encoder_layer: TransformerEncoderLayer, num_layers: int, embed_dim: int):
        super().__init__()
        # Clone the provided encoder_layer num_layers times
        self.layers = nn.ModuleList([copy.deepcopy(encoder_layer) for _ in range(num_layers)])
        self.norm = nn.LayerNorm(embed_dim)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Pass the input through each encoder layer in turn
        for layer in self.layers:
            x = layer(x)
        # Apply final normalization
        x = self.norm(x)
        return x


class TransformerDecoderLayer(nn.Module):
    """
    Transformer Decoder Layer (pre-norm):
        1) Self-attention
        2) Cross-attention (over encoder features)
        3) Feed-forward network
        4) Residual connections + LayerNorm
    """

    def __init__(
        self,
        encoder_embed_dim: int,
        decoder_embed_dim: int,
        num_heads: int,
        mlp_dim: int,
        dropout: float = 0.1,
    ):
        super().__init__()

        # 1. Self-attention on the decoder input
        self.self_attn = MultiHeadAttention(decoder_embed_dim, decoder_embed_dim, num_heads, dropout)

        # 2. Cross-attention over encoder features (keys/values come from the encoder)
        self.cross_attn = MultiHeadAttention(decoder_embed_dim, encoder_embed_dim, num_heads, dropout)

        # 3. Feed-forward network
        self.ffn = nn.Sequential(
            nn.Linear(decoder_embed_dim, mlp_dim),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(mlp_dim, decoder_embed_dim),
        )

        # 4. LayerNorms and dropouts, one per sublayer (self-attention, cross-attention, FFN)
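        # Pre-norm residual pattern applied in forward() for each of the three sublayers:
        #     x = x + Dropout(Sublayer(LayerNorm(x)))
        # hence one LayerNorm/Dropout pair per sublayer below.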
        self.norm1 = nn.LayerNorm(decoder_embed_dim)
        self.norm2 = nn.LayerNorm(decoder_embed_dim)
        self.norm3 = nn.LayerNorm(decoder_embed_dim)
        self.dropout1 = nn.Dropout(dropout)
        self.dropout2 = nn.Dropout(dropout)
        self.dropout3 = nn.Dropout(dropout)

    def forward(self, x: torch.Tensor, encoder_features: torch.Tensor) -> torch.Tensor:
        # 1) Self-attention block (pre-norm + residual)
        x_norm = self.norm1(x)
        sa = self.self_attn(x_norm, x_norm, x_norm)
        x = x + self.dropout1(sa)

        # 2) Cross-attention block: queries from the decoder, keys/values from the encoder
        x_norm_ca = self.norm2(x)
        ca = self.cross_attn(x_norm_ca, encoder_features, encoder_features)
        x = x + self.dropout2(ca)

        # 3) Feed-forward block (pre-norm + residual)
        x_norm_ff = self.norm3(x)
        ff = self.ffn(x_norm_ff)
        x = x + self.dropout3(ff)

        return x


class TransformerDecoder(nn.Module):
    def __init__(
        self,
        decoder_layer: TransformerDecoderLayer,
        num_layers: int,
        embed_dim: int,
    ):
        super().__init__()
        # Clone the provided decoder_layer num_layers times
        self.layers = nn.ModuleList([copy.deepcopy(decoder_layer) for _ in range(num_layers)])
        self.norm = nn.LayerNorm(embed_dim)

    def forward(self, x: torch.Tensor, encoder_features: torch.Tensor) -> torch.Tensor:
        # Pass the input through each decoder layer, attending to the encoder features
        for layer in self.layers:
            x = layer(x, encoder_features)
        # Apply final normalization
        x = self.norm(x)
        return x
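
# Example usage: a minimal smoke test of the encoder-decoder stack.
# The hyperparameters and tensor shapes below are illustrative assumptions,
# not values prescribed by the modules above.
if __name__ == "__main__":
    torch.manual_seed(0)

    embed_dim, num_heads, mlp_dim = 64, 4, 128

    encoder = TransformerEncoder(
        TransformerEncoderLayer(embed_dim=embed_dim, num_heads=num_heads, mlp_dim=mlp_dim),
        num_layers=2,
        embed_dim=embed_dim,
    )
    decoder = TransformerDecoder(
        TransformerDecoderLayer(
            encoder_embed_dim=embed_dim,
            decoder_embed_dim=embed_dim,
            num_heads=num_heads,
            mlp_dim=mlp_dim,
        ),
        num_layers=2,
        embed_dim=embed_dim,
    )

    src = torch.randn(2, 10, embed_dim)  # [batch, src_len, embed_dim]
    tgt = torch.randn(2, 7, embed_dim)   # [batch, tgt_len, embed_dim]

    memory = encoder(src)       # [2, 10, 64]
    out = decoder(tgt, memory)  # [2, 7, 64]
    print(memory.shape, out.shape)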