from typing import Callable

import torch.nn as nn
from jaxtyping import Bool, Float
from torch import Tensor

from layers import MultiHeadAttention, PositionwiseFeedForward


class ResidualConnection(nn.Module):
    """Pre-LN residual wrapper around an arbitrary sublayer.

    Computes ``x + Dropout(sublayer(LayerNorm(x)))`` — the "Pre-LN" variant
    popularized by GPT-2, which trains more stably than the Post-LN layout
    of the original "Attention Is All You Need" paper.
    """

    def __init__(self, d_model: int, dropout: float = 0.1) -> None:
        """Set up the normalization and dropout applied around the sublayer.

        Args:
            d_model (int): Model (feature) dimension D.
            dropout (float): Dropout probability applied to the sublayer output.
        """
        super().__init__()
        self.dropout: nn.Dropout = nn.Dropout(dropout)
        self.norm: nn.LayerNorm = nn.LayerNorm(d_model)

    def forward(
        self,
        x: Float[Tensor, "B T D"],
        sublayer: Callable[[Float[Tensor, "B T D"]], Float[Tensor, "B T D"]],
    ) -> Float[Tensor, "B T D"]:
        """Apply the residual connection.

        Args:
            x (Tensor): Input from the previous layer, shape (B, T, D).
            sublayer (Callable): Module (e.g. MHA or FFN) to wrap.

        Returns:
            Tensor: ``x`` plus the dropped-out output of the sublayer applied
            to the normalized input.
        """
        # Pre-LN order: normalize -> sublayer -> dropout -> add the skip path.
        return x + self.dropout(sublayer(self.norm(x)))


class EncoderLayer(nn.Module):
    """A single Transformer encoder block.

    Two residual-wrapped sublayers, applied in order:
      1. multi-head self-attention,
      2. position-wise feed-forward network.
    Each is wrapped in a :class:`ResidualConnection` (Pre-LN + dropout).
    """

    def __init__(
        self, d_model: int, n_heads: int, d_ff: int, dropout: float = 0.1
    ) -> None:
        """Build the attention, FFN, and residual sub-modules.

        Args:
            d_model (int): Model dimension D.
            n_heads (int): Number of attention heads H.
            d_ff (int): Inner dimension of the feed-forward network.
            dropout (float): Dropout rate used by the residual connections.
        """
        super().__init__()
        self.self_attn: MultiHeadAttention = MultiHeadAttention(d_model, n_heads)
        self.feed_forward: PositionwiseFeedForward = PositionwiseFeedForward(
            d_model, d_ff
        )
        self.residual_1: ResidualConnection = ResidualConnection(d_model, dropout)
        self.residual_2: ResidualConnection = ResidualConnection(d_model, dropout)

    def forward(
        self, x: Float[Tensor, "B T D"], src_mask: Bool[Tensor, "B 1 1 T_k"]
    ) -> Float[Tensor, "B T D"]:
        """Run one encoder block.

        Args:
            x (Tensor): Input from the previous layer or the embedding.
            src_mask (Tensor): Source padding mask, shape (B, 1, 1, T_k),
                broadcastable against attention scores of shape (B, H, T_q, T_k).

        Returns:
            Tensor: Block output, shape (B, T, D).
        """

        def _self_attend(h: Float[Tensor, "B T D"]) -> Float[Tensor, "B T D"]:
            # Self-attention: queries, keys, and values all come from `h`.
            return self.self_attn(q=h, k=h, v=h, mask=src_mask)

        x = self.residual_1(x, _self_attend)
        return self.residual_2(x, self.feed_forward)


class Encoder(nn.Module):
    """Stack of N identical :class:`EncoderLayer` blocks.

    Consumes embeddings + positional encodings and produces the "memory"
    consumed by the decoder. Because the blocks are Pre-LN, one final
    LayerNorm is applied after the whole stack.
    """

    def __init__(
        self, d_model: int, n_heads: int, d_ff: int, n_layers: int, dropout: float = 0.1
    ) -> None:
        """Build the layer stack and the final normalization.

        Args:
            d_model (int): Model dimension D.
            n_heads (int): Number of attention heads H.
            d_ff (int): Inner dimension of the feed-forward network.
            n_layers (int): Number of stacked encoder blocks N.
            dropout (float): Dropout rate used by the residual connections.
        """
        super().__init__()
        self.layers: nn.ModuleList = nn.ModuleList(
            [EncoderLayer(d_model, n_heads, d_ff, dropout) for _ in range(n_layers)]
        )
        self.norm: nn.LayerNorm = nn.LayerNorm(d_model)

    def forward(
        self, x: Float[Tensor, "B T D"], src_mask: Bool[Tensor, "B 1 1 T"]
    ) -> Float[Tensor, "B T D"]:
        """Run the full encoder stack.

        Args:
            x (Tensor): Token embeddings + positional encodings, (B, T, D).
            src_mask (Tensor): Source padding mask.

        Returns:
            Tensor: Final-normalized encoder output ("memory" for the decoder).
        """
        for block in self.layers:
            x = block(x, src_mask)
        # Pre-LN stacks need one last normalization at the top.
        return self.norm(x)


class DecoderLayer(nn.Module):
    """A single Transformer decoder block.

    Three residual-wrapped sublayers, applied in order:
      1. masked multi-head self-attention over the target,
      2. cross-attention from the target onto the encoder output,
      3. position-wise feed-forward network.
    Each is wrapped in a :class:`ResidualConnection` (Pre-LN + dropout).
    """

    def __init__(
        self, d_model: int, n_heads: int, d_ff: int, dropout: float = 0.1
    ) -> None:
        """Build the attention, cross-attention, FFN, and residual sub-modules.

        Args:
            d_model (int): Model dimension D.
            n_heads (int): Number of attention heads H.
            d_ff (int): Inner dimension of the feed-forward network.
            dropout (float): Dropout rate used by the residual connections.
        """
        super().__init__()
        self.self_attn: MultiHeadAttention = MultiHeadAttention(d_model, n_heads)
        self.cross_attn: MultiHeadAttention = MultiHeadAttention(d_model, n_heads)
        self.feed_forward: PositionwiseFeedForward = PositionwiseFeedForward(
            d_model, d_ff
        )
        self.residual_1: ResidualConnection = ResidualConnection(d_model, dropout)
        self.residual_2: ResidualConnection = ResidualConnection(d_model, dropout)
        self.residual_3: ResidualConnection = ResidualConnection(d_model, dropout)

    def forward(
        self,
        x: Float[Tensor, "B T_tgt D"],
        enc_output: Float[Tensor, "B T_src D"],
        src_mask: Bool[Tensor, "B 1 1 T_src"],
        tgt_mask: Bool[Tensor, "B 1 1 T_tgt"],
    ) -> Float[Tensor, "B T_tgt D"]:
        """Run one decoder block.

        Args:
            x (Tensor): Input from the previous decoder layer.
            enc_output (Tensor): Encoder output; supplies K and V for
                cross-attention.
            src_mask (Tensor): Padding mask over the source sequence.
            tgt_mask (Tensor): Combined look-ahead + padding mask over the
                target sequence.

        Returns:
            Tensor: Block output, shape (B, T_tgt, D).
        """

        def _masked_self_attend(h: Float[Tensor, "B T_tgt D"]) -> Float[Tensor, "B T_tgt D"]:
            # Causal self-attention over the target stream.
            return self.self_attn(q=h, k=h, v=h, mask=tgt_mask)

        def _cross_attend(h: Float[Tensor, "B T_tgt D"]) -> Float[Tensor, "B T_tgt D"]:
            # Queries from the (normalized) decoder stream; keys/values from
            # the encoder memory, masked by the source padding mask.
            return self.cross_attn(q=h, k=enc_output, v=enc_output, mask=src_mask)

        x = self.residual_1(x, _masked_self_attend)
        x = self.residual_2(x, _cross_attend)
        return self.residual_3(x, self.feed_forward)


class Decoder(nn.Module):
    """Stack of N identical :class:`DecoderLayer` blocks.

    Consumes target embeddings + positional encodings alongside the encoder
    memory. Because the blocks are Pre-LN, one final LayerNorm is applied
    after the whole stack, before the output projection ("generator").
    """

    def __init__(
        self, d_model: int, n_heads: int, d_ff: int, n_layers: int, dropout: float = 0.1
    ) -> None:
        """Build the layer stack and the final normalization.

        Args:
            d_model (int): Model dimension D.
            n_heads (int): Number of attention heads H.
            d_ff (int): Inner dimension of the feed-forward network.
            n_layers (int): Number of stacked decoder blocks N.
            dropout (float): Dropout rate used by the residual connections.
        """
        super().__init__()
        self.layers: nn.ModuleList = nn.ModuleList(
            [DecoderLayer(d_model, n_heads, d_ff, dropout) for _ in range(n_layers)]
        )
        self.norm: nn.LayerNorm = nn.LayerNorm(d_model)

    def forward(
        self,
        x: Float[Tensor, "B T_tgt D"],
        enc_output: Float[Tensor, "B T_src D"],
        src_mask: Bool[Tensor, "B 1 1 T_src"],
        tgt_mask: Bool[Tensor, "1 1 T_tgt T_tgt"],
    ) -> Float[Tensor, "B T_tgt D"]:
        """Run the full decoder stack.

        Args:
            x (Tensor): Target embeddings + positional encodings.
            enc_output (Tensor): Encoder output (K, V for cross-attention).
            src_mask (Tensor): Padding mask over the source sequence.
            tgt_mask (Tensor): Combined mask over the target sequence.

        Returns:
            Tensor: Final-normalized decoder output, ready for the output
            projection (generator).
        """
        for block in self.layers:
            x = block(x, enc_output, src_mask, tgt_mask)
        # Pre-LN stacks need one last normalization at the top.
        return self.norm(x)