| import torch | |
| from torch import nn | |
class TokenEmbedding(nn.Module):
    """Learned lookup table turning token ids into dense vectors.

    Each id ``token`` is mapped to the row ``W_embed[token]`` of a
    ``(vocab_size, d_model)`` parameter matrix, i.e. ``x_t = W_embed[token]``.
    Rows are initialised from a normal distribution with mean 0.0 and
    standard deviation 0.02.

    Args:
        vocab_size: Number of rows in the table (distinct token ids).
        d_model: Width of each embedding vector.
    """

    def __init__(self, vocab_size: int, d_model: int):
        super().__init__()
        table = nn.Embedding(vocab_size, d_model)
        # Replace the default nn.Embedding initialisation with N(0, 0.02).
        nn.init.normal_(table.weight, mean=0.0, std=0.02)
        # Registered under the name "embedding", so the parameter key in
        # state_dict is "embedding.weight".
        self.embedding = table

    def weight(self) -> torch.Tensor:
        """Return the underlying ``(vocab_size, d_model)`` embedding parameter.

        The parameter object itself is returned (not a copy).
        """
        module = self.embedding
        return module.weight

    def forward(self, token_ids: torch.Tensor) -> torch.Tensor:
        """Look up the embedding row for every id in ``token_ids``.

        The result has the shape of ``token_ids`` with an extra trailing
        dimension of size ``d_model``.
        """
        lookup = self.embedding
        return lookup(token_ids)