| from torch import Tensor | |
| import torch.nn as nn | |
| from typing import Callable | |
| from jaxtyping import Bool, Float | |
| from layers import MultiHeadAttention, PositionwiseFeedForward | |
class ResidualConnection(nn.Module):
    """
    Pre-LN residual wrapper around a sublayer (attention or FFN).

    Computes: x + Dropout(Sublayer(LayerNorm(x)))

    Normalizing *before* the sublayer (Pre-LN, as in GPT-2) trains more
    stably than the Post-LN layout of the original "Attention Is All You
    Need" paper.
    """

    def __init__(self, d_model: int, dropout: float = 0.1) -> None:
        """
        Initialize the residual wrapper.

        Args:
            d_model (int): Model (embedding) dimension D.
            dropout (float): Dropout probability applied to the sublayer output.
        """
        super().__init__()
        self.dropout: nn.Dropout = nn.Dropout(dropout)
        self.norm: nn.LayerNorm = nn.LayerNorm(d_model)

    def forward(
        self,
        x: "Float[Tensor, 'B T D']",
        sublayer: "Callable[[Float[Tensor, 'B T D']], Float[Tensor, 'B T D']]",
    ) -> "Float[Tensor, 'B T D']":
        """
        Apply the sublayer with a Pre-LN residual connection.

        Args:
            x (Tensor): Input tensor of shape (B, T, D).
            sublayer (Callable): Module or function applied to the
                normalized input (e.g. MHA or FFN).

        Returns:
            Tensor: x + Dropout(sublayer(LayerNorm(x))), shape (B, T, D).
        """
        return x + self.dropout(sublayer(self.norm(x)))
class EncoderLayer(nn.Module):
    """
    A single Transformer Encoder block.

    Two sublayers, each wrapped by a Pre-LN ResidualConnection
    (LayerNorm + Dropout):
      1. Multi-Head Self-Attention — queries, keys, and values all come
         from the same input.
      2. Position-wise Feed-Forward Network.

    Architecture:
        x  -> Residual_1(x, MHA)  -> x'
        x' -> Residual_2(x', FFN) -> output
    """

    def __init__(
        self, d_model: int, n_heads: int, d_ff: int, dropout: float = 0.1
    ) -> None:
        """
        Initialize the Encoder block.

        Args:
            d_model (int): Model (embedding) dimension D.
            n_heads (int): Number of attention heads H.
            d_ff (int): Inner dimension of the feed-forward network.
            dropout (float): Dropout rate used inside the residual wrappers.
        """
        super().__init__()
        self.self_attn: MultiHeadAttention = MultiHeadAttention(d_model, n_heads)
        self.feed_forward: PositionwiseFeedForward = PositionwiseFeedForward(
            d_model, d_ff
        )
        self.residual_1: ResidualConnection = ResidualConnection(d_model, dropout)
        self.residual_2: ResidualConnection = ResidualConnection(d_model, dropout)

    def forward(
        self, x: Float[Tensor, "B T D"], src_mask: Bool[Tensor, "B 1 1 T_k"]
    ) -> Float[Tensor, "B T D"]:
        """
        Run the block: self-attention, then FFN, each with a residual.

        Args:
            x (Tensor): Input from the previous layer or embedding, (B, T, D).
            src_mask (Tensor): Source padding mask of shape (B, 1, 1, T_k);
                broadcasts over heads and query positions to (B, H, T_q, T_k).

        Returns:
            Tensor: Output of the block, shape (B, T, D).
        """

        def _self_attend(normed: Float[Tensor, "B T D"]) -> Float[Tensor, "B T D"]:
            # Self-attention: q, k, and v are all the normalized input.
            return self.self_attn(q=normed, k=normed, v=normed, mask=src_mask)

        x = self.residual_1(x, _self_attend)
        return self.residual_2(x, self.feed_forward)
class Encoder(nn.Module):
    """
    The full Transformer Encoder: a stack of N identical EncoderLayers
    followed by one final LayerNorm.

    The trailing norm is required by the Pre-LN design — each layer only
    normalizes its *input*, so the stack output is normalized exactly once
    here before being handed to the Decoder as context/"memory".
    """

    def __init__(
        self, d_model: int, n_heads: int, d_ff: int, n_layers: int, dropout: float = 0.1
    ) -> None:
        """
        Initialize the Encoder stack.

        Args:
            d_model (int): Model (embedding) dimension D.
            n_heads (int): Number of attention heads H.
            d_ff (int): Inner dimension of the feed-forward network.
            n_layers (int): Number of EncoderLayer blocks to stack (N).
            dropout (float): Dropout rate for the residual connections.
        """
        super().__init__()
        self.layers: nn.ModuleList = nn.ModuleList(
            EncoderLayer(d_model, n_heads, d_ff, dropout) for _ in range(n_layers)
        )
        self.norm: nn.LayerNorm = nn.LayerNorm(d_model)

    def forward(
        self, x: Float[Tensor, "B T D"], src_mask: Bool[Tensor, "B 1 1 T"]
    ) -> Float[Tensor, "B T D"]:
        """
        Run the input through every encoder layer, then normalize.

        Args:
            x (Tensor): Token embeddings + positional encodings, (B, T, D).
            src_mask (Tensor): Padding mask for the source sentence.

        Returns:
            Tensor: Encoded context for the Decoder, shape (B, T, D).
        """
        out = x
        for block in self.layers:
            out = block(out, src_mask)
        return self.norm(out)
class DecoderLayer(nn.Module):
    """
    Implements one single Decoder Layer (or "Block") of the Transformer Decoder.
    A Decoder Layer consists of three main sublayers:
    1. A Masked Multi-Head Self-Attention mechanism (MHA).
    2. A Multi-Head Cross-Attention mechanism (MHA).
    3. A Position-wise Feed-Forward Network (FFN).
    Each sublayer is wrapped by a ResidualConnection (Pre-LN and Dropout).
    Architecture:
        x -> Residual_1(x, Masked_MHA) -> x'
        x' -> Residual_2(x', Cross_MHA, enc_output) -> x''
        x'' -> Residual_3(x'', FFN) -> output
    """

    def __init__(
        self, d_model: int, n_heads: int, d_ff: int, dropout: float = 0.1
    ) -> None:
        """
        Initializes the Decoder Layer.
        Args:
            d_model (int): The dimension of the model (D).
            n_heads (int): The number of attention heads (H).
            d_ff (int): The inner dimension of the Feed-Forward Network (D_FF).
            dropout (float): The dropout rate for the residual connections.
        """
        super().__init__()
        self.self_attn: MultiHeadAttention = MultiHeadAttention(d_model, n_heads)
        self.cross_attn: MultiHeadAttention = MultiHeadAttention(d_model, n_heads)
        self.feed_forward: PositionwiseFeedForward = PositionwiseFeedForward(
            d_model, d_ff
        )
        self.residual_1: ResidualConnection = ResidualConnection(d_model, dropout)
        self.residual_2: ResidualConnection = ResidualConnection(d_model, dropout)
        self.residual_3: ResidualConnection = ResidualConnection(d_model, dropout)

    def forward(
        self,
        x: Float[Tensor, "B T_tgt D"],
        enc_output: Float[Tensor, "B T_src D"],
        src_mask: Bool[Tensor, "B 1 1 T_src"],
        # NOTE: annotated (B, 1, T_tgt, T_tgt), not (B, 1, 1, T_tgt): a mask
        # that combines look-ahead with padding varies along the query axis,
        # so it needs a full T_tgt x T_tgt slice per batch element.
        tgt_mask: Bool[Tensor, "B 1 T_tgt T_tgt"],
    ) -> Float[Tensor, "B T_tgt D"]:
        """
        Forward pass for the Decoder Layer.
        Args:
            x (Tensor): Input tensor from the previous decoder layer.
            enc_output (Tensor): The output tensor from the Encoder (K, V).
            src_mask (Tensor): The padding mask for the source (Encoder) input.
            tgt_mask (Tensor): The combined look-ahead and padding mask for
                the target (Decoder) input; must broadcast to
                (B, H, T_tgt, T_tgt).
        Returns:
            Tensor: The output tensor of the Decoder Layer.
        """
        # 1) Masked self-attention over the target sequence.
        x = self.residual_1(
            x,
            lambda x_normed: self.self_attn(
                q=x_normed, k=x_normed, v=x_normed, mask=tgt_mask
            ),
        )
        # 2) Cross-attention: queries from the decoder, keys/values from the
        #    encoder output; the *source* padding mask applies here.
        x = self.residual_2(
            x,
            lambda x_normed: self.cross_attn(
                q=x_normed, k=enc_output, v=enc_output, mask=src_mask
            ),
        )
        # 3) Position-wise feed-forward network.
        x = self.residual_3(x, self.feed_forward)
        return x
class Decoder(nn.Module):
    """
    Implements the full Transformer Decoder, which is a stack of N
    identical DecoderLayers.
    This module takes the target embeddings + positional encodings and
    processes them through N layers of masked self-attention,
    cross-attention, and FFNs.
    (Best Practice: Uses Pre-LN, so a final LayerNorm is applied
    at the *end* of the stack, before passing to the final Generator).
    """

    def __init__(
        self, d_model: int, n_heads: int, d_ff: int, n_layers: int, dropout: float = 0.1
    ) -> None:
        """
        Initializes the Decoder stack.
        Args:
            d_model (int): The dimension of the model (D).
            n_heads (int): The number of attention heads (H).
            d_ff (int): The inner dimension of the Feed-Forward Network (D_FF).
            n_layers (int): The number of DecoderLayer blocks to stack (N).
            dropout (float): The dropout rate for the residual connections.
        """
        super().__init__()
        self.layers: nn.ModuleList = nn.ModuleList(
            [DecoderLayer(d_model, n_heads, d_ff, dropout) for _ in range(n_layers)]
        )
        # Final Pre-LN normalization, applied once after the whole stack.
        self.norm: nn.LayerNorm = nn.LayerNorm(d_model)

    def forward(
        self,
        x: Float[Tensor, "B T_tgt D"],
        enc_output: Float[Tensor, "B T_src D"],
        src_mask: Bool[Tensor, "B 1 1 T_src"],
        # NOTE: annotated with a batch dimension ("B 1 T_tgt T_tgt", matching
        # DecoderLayer.forward) rather than "1 1 T_tgt T_tgt": a mask that
        # also encodes target *padding* is batch-dependent. A purely causal
        # mask of shape (1, 1, T_tgt, T_tgt) still broadcasts into this.
        tgt_mask: Bool[Tensor, "B 1 T_tgt T_tgt"],
    ) -> Float[Tensor, "B T_tgt D"]:
        """
        Forward pass for the entire Decoder stack.
        Args:
            x (Tensor): Input tensor for the target (embeddings + pos enc).
            enc_output (Tensor): The output from the Encoder (K, V for cross-attn).
            src_mask (Tensor): Padding mask for the source (Encoder) sequence.
            tgt_mask (Tensor): Combined mask for the target (Decoder) sequence.
        Returns:
            Tensor: The output of the final Decoder layer, ready for the
                final projection (Generator).
        """
        for layer in self.layers:
            x = layer(x, enc_output, src_mask, tgt_mask)
        x = self.norm(x)
        return x