| # Copyright © 2023 Apple Inc. | |
| import math | |
| import mlx.core as mx | |
| from mlx.nn.layers.base import Module | |
class Embedding(Module):
    """A trainable lookup table from integer ids to dense vectors.

    The layer owns a single ``weight`` array with one row per embeddable
    token; calling the layer indexes that array with the input.

    Args:
        num_embeddings (int): Number of rows in the table, i.e. how many
          distinct discrete tokens can be embedded (the vocabulary size).
        dims (int): Length of each embedding vector.
    """

    def __init__(self, num_embeddings: int, dims: int):
        super().__init__()
        # Samples from ``mx.random.normal`` are multiplied by sqrt(1 / dims),
        # so the initial magnitude of the entries shrinks as ``dims`` grows.
        init_scale = math.sqrt(1 / dims)
        table_shape = (num_embeddings, dims)
        self.weight = mx.random.normal(table_shape) * init_scale

    def _extra_repr(self):
        # Report the table as "<num_embeddings>, <dims>", read back from the
        # stored weight rather than from the constructor arguments.
        shape = self.weight.shape
        return f"{shape[0]}, {shape[1]}"

    def __call__(self, x):
        # Plain indexing of the table with ``x``; presumably ``x`` holds
        # integer token ids (see class docstring) -- not validated here.
        table = self.weight
        return table[x]