import math
from typing import Optional

import torch
import torch.nn as nn
import torch.nn.functional as F
from einops import rearrange


class Quantizer(nn.Module):
    """Vector quantizer with cosine-similarity code lookup and a
    straight-through estimator for gradients."""

    def __init__(
        self,
        codebook_size: int,
        codebook_dim: int,
        input_dim: int,
        max_codebook_updates_with_revival: Optional[int] = None,
    ) -> None:
        super().__init__()
        # The codebook size must be a power of two so the entropy threshold below is well defined.
        assert math.log2(codebook_size).is_integer()
        # Dead-codeword revival is triggered while codebook entropy stays below log2(codebook_size) - 2 bits.
        self.revival_entropy_threshold = int(math.log2(codebook_size)) - 2
        self.max_codebook_updates_with_revival = max_codebook_updates_with_revival

        # Project encoder features into the codebook space and back.
        self.pre_quant_proj = nn.Linear(input_dim, codebook_dim)
        self.post_quant_proj = nn.Linear(codebook_dim, input_dim)

        # The codebook is a non-trainable buffer (updated outside autograd).
        codebook = torch.empty(codebook_size, codebook_dim).uniform_(
            -1.0 / codebook_size, 1.0 / codebook_size
        )
        self.register_buffer('num_codebook_updates', torch.tensor(0))
        self.register_buffer('codebook', codebook)
        # Running codeword usage frequencies, initialized to a uniform distribution.
        self.register_buffer('codewords_freqs', torch.ones(codebook_size).div(codebook_size))

    def forward(self, z: torch.Tensor) -> dict:
        # z: (batch, time, k, input_dim) encoder features.
        z = self.pre_quant_proj(z)
        z = F.normalize(z, dim=-1)
        b, k = z.size(0), z.size(2)
        z = rearrange(z, 'b t k e -> (b t k) e')

        # Nearest-codeword lookup by cosine similarity (the codebook is assumed to hold
        # unit-norm vectors; otherwise this reduces to a plain dot-product similarity).
        cosine_similarity = torch.einsum('n e, c e -> n c', z, self.codebook)
        tokens = cosine_similarity.argmax(dim=-1)
        q = self.codebook[tokens]

        # Straight-through estimator: the forward pass uses q, gradients flow through z.
        q = z + (q - z).detach()
        q = self.post_quant_proj(q)

        q = rearrange(q, '(b t k) e -> b t k e', b=b, k=k)
        tokens = rearrange(tokens, '(b t k) -> b t k', b=b, k=k)
        return {
            "q": q,
            "tokens": tokens,
        }

    def compute_codebook_entropy(self) -> float:
        # Shannon entropy (in bits) of the codeword usage distribution;
        # zero-frequency codewords are excluded so log2 stays finite.
        probs = self.codewords_freqs[self.codewords_freqs != 0]
        return -(torch.log2(probs) * probs).sum().item()

    @torch.no_grad()
    def embed_tokens(self, tokens: torch.LongTensor) -> torch.FloatTensor:
        # Map discrete tokens back to the model's feature space without gradients.
        return self.post_quant_proj(self.codebook[tokens])
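

# Minimal usage sketch (illustrative only): the batch size, sequence length, number of
# codebooks k, and dimensions below are arbitrary assumptions, and the codebook-update /
# revival logic referenced by the buffers above lives outside this class.
if __name__ == "__main__":
    quantizer = Quantizer(codebook_size=1024, codebook_dim=32, input_dim=512)

    # Features shaped (batch, time, k, input_dim), matching the 'b t k e' layout in forward().
    z = torch.randn(2, 50, 4, 512)
    out = quantizer(z)
    print(out["q"].shape)       # (2, 50, 4, 512): quantized features projected back to input_dim
    print(out["tokens"].shape)  # (2, 50, 4): discrete codeword indices

    # Entropy (in bits) of the running codeword usage distribution.
    print(quantizer.compute_codebook_entropy())

    # Decode tokens back to the feature space without gradients.
    recon = quantizer.embed_tokens(out["tokens"])
    print(recon.shape)          # (2, 50, 4, 512)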