# Author: AlainDeLong — "Create translate app" (commit e27ab6a)
from torch import Tensor
import torch.nn as nn
from typing import Callable
from jaxtyping import Bool, Float
from layers import MultiHeadAttention, PositionwiseFeedForward
class ResidualConnection(nn.Module):
    """
    Pre-LN residual wrapper: normalizes the input, feeds it through a
    sublayer (e.g. attention or FFN), applies dropout, and adds the result
    back onto the original input.

    Computes: x + Dropout(Sublayer(LayerNorm(x)))

    The Pre-LN layout (as used in GPT-2) trains more stably than the
    Post-LN design of the original "Attention Is All You Need" paper.
    """

    def __init__(self, d_model: int, dropout: float = 0.1) -> None:
        """
        Args:
            d_model (int): Model dimension (D), used to size the LayerNorm.
            dropout (float): Probability of zeroing sublayer outputs.
        """
        super().__init__()
        self.dropout: nn.Dropout = nn.Dropout(dropout)
        self.norm: nn.LayerNorm = nn.LayerNorm(d_model)

    def forward(
        self,
        x: Float[Tensor, "B T D"],
        sublayer: Callable[[Float[Tensor, "B T D"]], Float[Tensor, "B T D"]],
    ) -> Float[Tensor, "B T D"]:
        """
        Apply the wrapped sublayer with a Pre-LN residual connection.

        Args:
            x (Tensor): Input activations from the previous layer.
            sublayer (Callable): Function mapping the normalized input to a
                tensor of the same shape (e.g. MHA or FFN).

        Returns:
            Tensor: x plus the dropped-out sublayer output.
        """
        return x + self.dropout(sublayer(self.norm(x)))
class EncoderLayer(nn.Module):
    """
    A single Transformer encoder block.

    Two sublayers, each wrapped in a Pre-LN ResidualConnection
    (LayerNorm -> sublayer -> Dropout -> add):
      1. Multi-head self-attention over the source sequence.
      2. Position-wise feed-forward network.
    """

    def __init__(
        self, d_model: int, n_heads: int, d_ff: int, dropout: float = 0.1
    ) -> None:
        """
        Args:
            d_model (int): Model dimension (D).
            n_heads (int): Number of attention heads (H).
            d_ff (int): Inner FFN dimension (D_FF).
            dropout (float): Dropout rate used inside the residual wrappers.
        """
        super().__init__()
        self.self_attn: MultiHeadAttention = MultiHeadAttention(d_model, n_heads)
        self.feed_forward: PositionwiseFeedForward = PositionwiseFeedForward(
            d_model, d_ff
        )
        self.residual_1: ResidualConnection = ResidualConnection(d_model, dropout)
        self.residual_2: ResidualConnection = ResidualConnection(d_model, dropout)

    def forward(
        self, x: Float[Tensor, "B T D"], src_mask: Bool[Tensor, "B 1 1 T_k"]
    ) -> Float[Tensor, "B T D"]:
        """
        Run one encoder block.

        Args:
            x (Tensor): Activations from the previous layer or embedding.
            src_mask (Tensor): Source padding mask; shape (B, 1, 1, T_k)
                broadcasts over heads and query positions.

        Returns:
            Tensor: Block output, same shape as the input.
        """
        def _self_attention(normed: Float[Tensor, "B T D"]) -> Float[Tensor, "B T D"]:
            # Q, K and V all come from the (normalized) encoder input.
            return self.self_attn(q=normed, k=normed, v=normed, mask=src_mask)

        attended = self.residual_1(x, _self_attention)
        return self.residual_2(attended, self.feed_forward)
class Encoder(nn.Module):
    """
    The full Transformer encoder: a stack of N identical EncoderLayers.

    Consumes token embeddings + positional encodings and refines them
    through N rounds of self-attention and FFNs. Because the layers are
    Pre-LN, one final LayerNorm is applied at the top of the stack before
    the result is handed to the decoder.
    """

    def __init__(
        self, d_model: int, n_heads: int, d_ff: int, n_layers: int, dropout: float = 0.1
    ) -> None:
        """
        Args:
            d_model (int): Model dimension (D).
            n_heads (int): Number of attention heads (H).
            d_ff (int): Inner FFN dimension (D_FF).
            n_layers (int): How many EncoderLayer blocks to stack (N).
            dropout (float): Dropout rate for the residual wrappers.
        """
        super().__init__()
        self.layers: nn.ModuleList = nn.ModuleList(
            EncoderLayer(d_model, n_heads, d_ff, dropout) for _ in range(n_layers)
        )
        self.norm: nn.LayerNorm = nn.LayerNorm(d_model)

    def forward(
        self, x: Float[Tensor, "B T D"], src_mask: Bool[Tensor, "B 1 1 T"]
    ) -> Float[Tensor, "B T D"]:
        """
        Run the whole encoder stack.

        Args:
            x (Tensor): Token embeddings + positional encodings.
            src_mask (Tensor): Padding mask over the source sequence.

        Returns:
            Tensor: Final encoder states (the "memory" consumed by the
                decoder's cross-attention).
        """
        for encoder_layer in self.layers:
            x = encoder_layer(x, src_mask)
        return self.norm(x)
class DecoderLayer(nn.Module):
    """
    A single Transformer decoder block.

    Three sublayers, each wrapped in a Pre-LN ResidualConnection
    (LayerNorm -> sublayer -> Dropout -> add):
      1. Masked multi-head self-attention over the target sequence.
      2. Multi-head cross-attention (queries from the decoder, keys and
         values from the encoder output).
      3. Position-wise feed-forward network.
    """

    def __init__(
        self, d_model: int, n_heads: int, d_ff: int, dropout: float = 0.1
    ) -> None:
        """
        Args:
            d_model (int): Model dimension (D).
            n_heads (int): Number of attention heads (H).
            d_ff (int): Inner FFN dimension (D_FF).
            dropout (float): Dropout rate used inside the residual wrappers.
        """
        super().__init__()
        self.self_attn: MultiHeadAttention = MultiHeadAttention(d_model, n_heads)
        self.cross_attn: MultiHeadAttention = MultiHeadAttention(d_model, n_heads)
        self.feed_forward: PositionwiseFeedForward = PositionwiseFeedForward(
            d_model, d_ff
        )
        self.residual_1: ResidualConnection = ResidualConnection(d_model, dropout)
        self.residual_2: ResidualConnection = ResidualConnection(d_model, dropout)
        self.residual_3: ResidualConnection = ResidualConnection(d_model, dropout)

    def forward(
        self,
        x: Float[Tensor, "B T_tgt D"],
        enc_output: Float[Tensor, "B T_src D"],
        src_mask: Bool[Tensor, "B 1 1 T_src"],
        tgt_mask: Bool[Tensor, "B 1 1 T_tgt"],
    ) -> Float[Tensor, "B T_tgt D"]:
        """
        Run one decoder block.

        Args:
            x (Tensor): Activations from the previous decoder layer.
            enc_output (Tensor): Encoder output, used as K/V in cross-attention.
            src_mask (Tensor): Padding mask over the source sequence.
            tgt_mask (Tensor): Combined look-ahead + padding mask over the
                target sequence.

        Returns:
            Tensor: Block output, same shape as the input x.
        """
        def _masked_self_attention(normed: Float[Tensor, "B T_tgt D"]) -> Float[Tensor, "B T_tgt D"]:
            # Causal self-attention: Q, K, V all from the (normalized) target.
            return self.self_attn(q=normed, k=normed, v=normed, mask=tgt_mask)

        def _cross_attention(normed: Float[Tensor, "B T_tgt D"]) -> Float[Tensor, "B T_tgt D"]:
            # Queries from the decoder; keys/values from the encoder memory.
            return self.cross_attn(q=normed, k=enc_output, v=enc_output, mask=src_mask)

        x = self.residual_1(x, _masked_self_attention)
        x = self.residual_2(x, _cross_attention)
        return self.residual_3(x, self.feed_forward)
class Decoder(nn.Module):
    """
    The full Transformer decoder: a stack of N identical DecoderLayers.

    Consumes target embeddings + positional encodings and refines them
    through N rounds of masked self-attention, cross-attention, and FFNs.
    Because the layers are Pre-LN, one final LayerNorm is applied at the
    top of the stack before the final Generator projection.
    """

    def __init__(
        self, d_model: int, n_heads: int, d_ff: int, n_layers: int, dropout: float = 0.1
    ) -> None:
        """
        Args:
            d_model (int): Model dimension (D).
            n_heads (int): Number of attention heads (H).
            d_ff (int): Inner FFN dimension (D_FF).
            n_layers (int): How many DecoderLayer blocks to stack (N).
            dropout (float): Dropout rate for the residual wrappers.
        """
        super().__init__()
        self.layers: nn.ModuleList = nn.ModuleList(
            DecoderLayer(d_model, n_heads, d_ff, dropout) for _ in range(n_layers)
        )
        self.norm: nn.LayerNorm = nn.LayerNorm(d_model)

    def forward(
        self,
        x: Float[Tensor, "B T_tgt D"],
        enc_output: Float[Tensor, "B T_src D"],
        src_mask: Bool[Tensor, "B 1 1 T_src"],
        tgt_mask: Bool[Tensor, "1 1 T_tgt T_tgt"],
    ) -> Float[Tensor, "B T_tgt D"]:
        """
        Run the whole decoder stack.

        NOTE(review): the tgt_mask shape hint here ("1 1 T_tgt T_tgt")
        disagrees with DecoderLayer.forward's ("B 1 1 T_tgt"), even though
        the mask is forwarded unchanged — confirm the intended mask shape
        against the mask-construction code.

        Args:
            x (Tensor): Target embeddings + positional encodings.
            enc_output (Tensor): Encoder output (K/V for cross-attention).
            src_mask (Tensor): Padding mask over the source sequence.
            tgt_mask (Tensor): Combined mask over the target sequence.

        Returns:
            Tensor: Final decoder states, ready for the output projection
                (Generator).
        """
        for decoder_layer in self.layers:
            x = decoder_layer(x, enc_output, src_mask, tgt_mask)
        return self.norm(x)