import torch
from torch import Tensor
import torch.nn as nn
from jaxtyping import Int, Float
import math


class InputEmbeddings(nn.Module):
"""
Implements the Input Embedding layer.
This module converts a tensor of token IDs into a tensor of
corresponding embedding vectors. It also scales the embeddings
    by sqrt(d_model), as described in "Attention Is All You Need",
    Section 3.4.
"""
def __init__(self, d_model: int, vocab_size: int) -> None:
"""
Initializes the InputEmbedding layer.
Args:
d_model (int): The dimension of the embedding vector (D).
vocab_size (int): The size of the vocabulary.
"""
super().__init__()
self.d_model: int = d_model
self.vocab_size: int = vocab_size
        self.token_emb: nn.Embedding = nn.Embedding(vocab_size, d_model)

def forward(self, x: Int[Tensor, "B T"]) -> Float[Tensor, "B T D"]:
"""
Forward pass for the InputEmbeddings.
Args:
            x (Tensor): Input tensor of token IDs, shape (B, T), where
                B is the batch size and T is the sequence length.
Returns:
Tensor: The corresponding embedding vectors, scaled by sqrt(d_model).
Shape (B, T, D).
"""
# (B, T) -> (B, T, D)
embeddings = self.token_emb(x)
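        # Scale by sqrt(d_model), as prescribed for the embedding layers
        # in Section 3.4 of the paper.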
        return embeddings * math.sqrt(self.d_model)


class PositionalEncoding(nn.Module):
"""
Implements the fixed (sin/cos) Positional Encoding module.
(Ref: "Attention Is All You Need", Section 3.5)
This module generates a tensor of positional encodings that are
added to the input embeddings. It also applies dropout to the
sum of the embeddings and the positional encodings.
"""
def __init__(self, d_model: int, max_seq_len: int, dropout: float = 0.1) -> None:
"""
Initializes the PositionalEncoding module.
Args:
d_model (int): The dimension of the model (D).
            max_seq_len (int): The maximum sequence length (T_max) for which
                encodings are pre-computed.
dropout (float): Dropout probability.
"""
super().__init__()
self.dropout: nn.Dropout = nn.Dropout(p=dropout)
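        # Pre-compute the sinusoidal table (Section 3.5), assuming an even d_model:
        #   PE(pos, 2i)   = sin(pos / 10000^(2i / d_model))
        #   PE(pos, 2i+1) = cos(pos / 10000^(2i / d_model))
        # div_term below holds 1 / 10000^(2i / d_model), evaluated in log
        # space for numerical stability.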
position: Tensor = torch.arange(max_seq_len).unsqueeze(1).float()
div_term: Tensor = torch.exp(
torch.arange(0, d_model, 2).float() * (-math.log(10000) / d_model)
)
# (T_max, D)
pe: Tensor = torch.zeros(max_seq_len, d_model)
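        # Even columns get the sine terms, odd columns the cosine terms.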
pe[:, 0::2] = torch.sin(position * div_term)
pe[:, 1::2] = torch.cos(position * div_term)
        # (T_max, D) -> (1, T_max, D)
pe = pe.unsqueeze(0)
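        # A buffer is saved in the state_dict and follows .to(device),
        # but is not a trainable parameter.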
        self.register_buffer("pe", pe)

def forward(self, x: Float[Tensor, "B T D"]) -> Float[Tensor, "B T D"]:
"""
Adds positional encoding to the input embeddings and applies dropout.
Args:
x (Tensor): Input tensor (token embeddings, already scaled).
Shape (B, T, D).
Returns:
Tensor: Output tensor with positional information and dropout.
Shape (B, T, D).
"""
x = x + self.pe[:, : x.size(1), :]
return self.dropout(x)
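

if __name__ == "__main__":
    # Minimal smoke test, not part of the original modules; the
    # hyperparameters below are illustrative placeholders.
    d_model, vocab_size, max_seq_len = 512, 10_000, 256
    embed = InputEmbeddings(d_model, vocab_size)
    pos_enc = PositionalEncoding(d_model, max_seq_len, dropout=0.1)

    token_ids = torch.randint(0, vocab_size, (2, 16))  # (B=2, T=16)
    out = pos_enc(embed(token_ids))
    print(out.shape)  # expected: torch.Size([2, 16, 512])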