import torch
from torch import nn, einsum
from einops import rearrange


class RandomProjectionQuantizer(nn.Module):
    """
    Random projection and codebook lookup module.

    Some code is borrowed from:
    https://github.com/lucidrains/vector-quantize-pytorch/blob/master/vector_quantize_pytorch/random_projection_quantizer.py
    but normalization here uses pre-computed global mean & variance instead of layer norm.
    """

    def __init__(
        self,
        input_dim,
        codebook_dim,
        codebook_size,
        seed=142,
    ):
        super().__init__()

        # Fix the seed so the projection and codebook are reproducible across runs.
        torch.manual_seed(seed)

        # Random projection matrix mapping input features into the codebook space.
        # Registered as a buffer (not a parameter), so it stays frozen during training.
        random_projection = torch.empty(input_dim, codebook_dim)
        nn.init.xavier_normal_(random_projection)
        self.register_buffer("random_projection", random_projection)

        # Randomly initialized codebook, likewise frozen for the whole training run.
        codebook = torch.empty(codebook_size, codebook_dim)
        nn.init.normal_(codebook)
        self.register_buffer("codebook", codebook)

    def codebook_lookup(self, x):
        # x: (batch, num_frames, codebook_dim) -> flatten the batch and time axes
        b = x.shape[0]
        x = rearrange(x, "b n e -> (b n) e")

        # L2-normalize both sides so the Euclidean distance below is
        # monotonic in cosine similarity.
        normalized_x = nn.functional.normalize(x, dim=1, p=2)
        normalized_codebook = nn.functional.normalize(self.codebook, dim=1, p=2)

        # Pairwise distances: (codebook_size, batch * num_frames)
        distances = torch.cdist(normalized_codebook, normalized_x)

        # Index of the nearest codebook entry for each input vector.
        nearest_indices = torch.argmin(distances, dim=0)

        # Restore the batch dimension: (batch, num_frames)
        xq = rearrange(nearest_indices, "(b n) -> b n", b=b)

        return xq

    @torch.no_grad()
    def forward(self, x):
        # The quantizer is never trained, so keep it in eval mode.
        self.eval()

        # Project inputs into the codebook space: (b, n, input_dim) -> (b, n, codebook_dim)
        x = einsum("b n d, d e -> b n e", x, self.random_projection)

        # Return the index of the nearest codebook entry for each frame.
        xq = self.codebook_lookup(x)

        return xq
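

# A minimal usage sketch (not from the original code): the shapes and
# hyper-parameters below are illustrative assumptions, as is the use of
# pre-computed global statistics to normalize the raw features before
# quantization, which the class docstring implies happens outside this module.
if __name__ == "__main__":
    quantizer = RandomProjectionQuantizer(
        input_dim=320,       # assumed per-frame feature dimension
        codebook_dim=16,     # assumed projection / codebook dimension
        codebook_size=8192,  # assumed number of codebook entries
    )

    raw = torch.randn(4, 100, 320)  # (batch, frames, input_dim)

    # Hypothetical stand-in for global mean & variance pre-computed over a dataset.
    global_mean, global_std = raw.mean(), raw.std()
    features = (raw - global_mean) / global_std

    labels = quantizer(features)  # (batch, frames) integer codebook indices
    print(labels.shape, labels.dtype)  # torch.Size([4, 100]) torch.int64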