|
|
import torch |
|
|
from torch import Tensor |
|
|
import torch.nn as nn |
|
|
from jaxtyping import Int, Float |
|
|
import math |
|
|
|
|
|
|
|
|
class InputEmbeddings(nn.Module):
    """
    Token-ID -> embedding-vector lookup for the Transformer input.

    Looks up each token ID in a learned embedding table and scales the
    result by sqrt(d_model), as prescribed in "Attention Is All You Need"
    (Section 3.4).
    """

    def __init__(self, d_model: int, vocab_size: int) -> None:
        """
        Build the embedding table.

        Args:
            d_model (int): Dimension of each embedding vector (D).
            vocab_size (int): Number of distinct token IDs.
        """
        super().__init__()
        self.d_model: int = d_model
        self.vocab_size: int = vocab_size
        self.token_emb: nn.Embedding = nn.Embedding(vocab_size, d_model)
        # Fixed scale factor applied to every embedding (paper, Sec. 3.4).
        self._scale: float = math.sqrt(d_model)

    def forward(self, x: Tensor) -> Tensor:
        """
        Map token IDs to scaled embedding vectors.

        Args:
            x (Tensor): Integer token IDs, shape (B, T) where
                B = batch size, T = sequence length.

        Returns:
            Tensor: Embeddings scaled by sqrt(d_model), shape (B, T, D).
        """
        return self.token_emb(x) * self._scale
|
|
|
|
|
|
|
|
class PositionalEncoding(nn.Module):
    """
    Fixed sinusoidal positional encoding.
    (Ref: "Attention Is All You Need", Section 3.5)

    Precomputes a (1, T_max, D) table of sin/cos positional encodings,
    adds the first T rows to the input embeddings in ``forward``, and
    applies dropout to the sum.
    """

    def __init__(self, d_model: int, max_seq_len: int, dropout: float = 0.1) -> None:
        """
        Precompute the positional-encoding table.

        Args:
            d_model (int): The dimension of the model (D). May be odd.
            max_seq_len (int): The maximum sequence length (T_max) to pre-compute.
            dropout (float): Dropout probability applied to (embeddings + PE).
        """
        super().__init__()
        self.dropout: nn.Dropout = nn.Dropout(p=dropout)

        # Positions 0 .. T_max-1 as a (T_max, 1) column for broadcasting.
        position: Tensor = torch.arange(max_seq_len).unsqueeze(1).float()

        # Frequencies 1 / 10000^(2i / d_model), i = 0 .. ceil(D/2)-1,
        # computed in log space for numerical stability.
        div_term: Tensor = torch.exp(
            torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model)
        )

        pe: Tensor = torch.zeros(max_seq_len, d_model)
        # Even dimensions get sine.
        pe[:, 0::2] = torch.sin(position * div_term)
        # Odd dimensions get cosine. Slice div_term to d_model // 2 entries:
        # for odd d_model there is one fewer odd column than even columns,
        # and the unsliced product would fail with a shape mismatch.
        pe[:, 1::2] = torch.cos(position * div_term[: d_model // 2])

        # (1, T_max, D) so it broadcasts over the batch dimension.
        pe = pe.unsqueeze(0)

        # Buffer (not a Parameter): moves with .to(device) and is saved in
        # state_dict, but is never updated by the optimizer.
        self.register_buffer("pe", pe)

    def forward(self, x: Tensor) -> Tensor:
        """
        Add positional encoding to the input embeddings and apply dropout.

        Args:
            x (Tensor): Token embeddings (already scaled), shape (B, T, D).
                Requires T <= max_seq_len.

        Returns:
            Tensor: Embeddings with positional information, after dropout.
                Shape (B, T, D).
        """
        x = x + self.pe[:, : x.size(1), :]
        return self.dropout(x)
|
|
|