import math
from typing import Optional

from einops import rearrange
import torch
import torch.nn as nn
import torch.nn.functional as F


class Quantizer(nn.Module):
    """Vector quantizer with learned input/output projections and a similarity-based codebook lookup."""

    def __init__(self, codebook_size: int, codebook_dim: int, input_dim: int, max_codebook_updates_with_revival: Optional[int] = None) -> None:
        super().__init__()
        # The codebook size must be a power of two so the entropy threshold below is well defined.
        assert math.log2(codebook_size).is_integer()
        # Usage-entropy threshold (in bits) associated with codeword revival.
        self.revival_entropy_threshold = int(math.log2(codebook_size)) - 2
        self.max_codebook_updates_with_revival = max_codebook_updates_with_revival
        # Project encoder features into the codebook space and back.
        self.pre_quant_proj = nn.Linear(input_dim, codebook_dim)
        self.post_quant_proj = nn.Linear(codebook_dim, input_dim)
        # Codewords are initialized uniformly in [-1/codebook_size, 1/codebook_size].
        codebook = torch.empty(codebook_size, codebook_dim, requires_grad=False).uniform_(-1.0 / codebook_size, 1.0 / codebook_size)
        self.register_buffer('num_codebook_updates', torch.tensor(0))
        self.register_buffer('codebook', codebook)
        # Running usage frequency of each codeword, initialized to a uniform distribution.
        self.register_buffer('codewords_freqs', torch.ones(codebook_size).div(codebook_size))

    def forward(self, z: torch.Tensor) -> dict:
        # z has shape (b, t, k, input_dim); project into codebook space and L2-normalize.
        z = self.pre_quant_proj(z)
        z = F.normalize(z, dim=-1)
        b, k = z.size(0), z.size(2)
        # Flatten all leading dimensions so each vector is matched against the codebook independently.
        z = rearrange(z, 'b t k e -> (b t k) e')
        # Similarity of each (unit-norm) input vector to every codeword; pick the best match.
        cosine_similarity = torch.einsum('n e, c e -> n c', z, self.codebook)
        tokens = cosine_similarity.argmax(dim=-1)
        q = self.codebook[tokens]
        # Straight-through estimator: the forward pass uses the codeword, gradients flow through z.
        q = z + (q - z).detach()
        q = self.post_quant_proj(q)
        # Restore the original (b, t, k, ...) layout.
        q = rearrange(q, '(b t k) e -> b t k e', b=b, k=k)
        tokens = rearrange(tokens, '(b t k) -> b t k', b=b, k=k)
        return {
            "q": q,
            "tokens": tokens,
        }

    def compute_codebook_entropy(self) -> float:
        # Shannon entropy (in bits) of the codeword usage distribution, ignoring unused codewords.
        probs = self.codewords_freqs[self.codewords_freqs != 0]
        return -(torch.log2(probs) * probs).sum().item()

    @torch.no_grad()
    def embed_tokens(self, tokens: torch.LongTensor) -> torch.FloatTensor:
        # Map discrete token indices back to embeddings in the input feature space.
        return self.post_quant_proj(self.codebook[tokens])
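

# A minimal usage sketch (not part of the original module); the shapes and
# hyperparameters below are illustrative assumptions only.
if __name__ == "__main__":
    # Hypothetical configuration: 512 codewords of dimension 8 over 64-dimensional encoder features.
    quantizer = Quantizer(codebook_size=512, codebook_dim=8, input_dim=64)
    # Dummy input of shape (b, t, k, input_dim), matching the layout expected by forward().
    z = torch.randn(2, 10, 4, 64)
    out = quantizer(z)
    print(out["q"].shape)       # torch.Size([2, 10, 4, 64])
    print(out["tokens"].shape)  # torch.Size([2, 10, 4])
    print(quantizer.compute_codebook_entropy())  # ~9.0 = log2(512) with the uniform initial frequencies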