import math
import typing as tp
from dataclasses import dataclass, field

import torch
import torch.nn.functional as F
from einops import rearrange
from torch import nn


@dataclass
class QuantizedResult:
    x: torch.Tensor
    codes: torch.Tensor
    bandwidth: torch.Tensor  # effective bandwidth used, in kb/s.
    penalty: tp.Optional[torch.Tensor] = None
    metrics: dict = field(default_factory=dict)
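

# --- Missing helpers -------------------------------------------------------
# `EuclideanCodebook` below references `uniform_init` and `kmeans`, neither
# of which is defined in this file. These are minimal sketches, assuming the
# usual conventions for EMA vector quantization: `samples` is [N, D], and
# empty clusters keep their previous centroid. `sample_vectors` is a helper
# introduced here for the k-means seeding.
def uniform_init(*shape: int):
    t = torch.empty(shape)
    nn.init.kaiming_uniform_(t)
    return t


def sample_vectors(samples, num: int):
    num_samples, device = samples.shape[0], samples.device
    if num_samples >= num:
        indices = torch.randperm(num_samples, device=device)[:num]
    else:
        indices = torch.randint(0, num_samples, (num,), device=device)
    return samples[indices]


def kmeans(samples, num_clusters: int, num_iters: int = 10):
    dim, dtype = samples.shape[-1], samples.dtype
    # Seed the centroids from the data itself.
    means = sample_vectors(samples, num_clusters)
    for _ in range(num_iters):
        # Assign each sample to its nearest centroid (negative squared
        # distance, so `max` picks the closest).
        dists = -(samples[:, None, :] - means[None, :, :]).pow(2).sum(-1)
        buckets = dists.max(dim=-1).indices
        bins = torch.bincount(buckets, minlength=num_clusters)
        zero_mask = bins == 0
        bins_clamped = bins.masked_fill(zero_mask, 1)
        # Recompute each centroid as the mean of its assigned samples.
        new_means = samples.new_zeros(num_clusters, dim, dtype=dtype)
        new_means.scatter_add_(0, buckets.unsqueeze(1).expand(-1, dim), samples)
        new_means = new_means / bins_clamped[..., None]
        # Empty clusters keep their previous centroid.
        means = torch.where(zero_mask[..., None], means, new_means)
    return means, bins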


class EuclideanCodebook(nn.Module):
    """Codebook with Euclidean distance."""

    def __init__(
        self,
        dim,
        codebook_size,
        kmeans_init=False,
        kmeans_iters=10,
        decay=0.8,
        epsilon=1e-5,
    ):
        super().__init__()
        self.decay = decay
        # With k-means init, start from zeros and fill the codebook lazily
        # from the first batch of data; otherwise initialize uniformly.
        init_fn = uniform_init if not kmeans_init else torch.zeros
        embed = init_fn(codebook_size, dim)

        self.codebook_size = codebook_size
        self.kmeans_iters = kmeans_iters
        self.epsilon = epsilon

        self.register_buffer("inited", torch.Tensor([not kmeans_init]))
        self.register_buffer("cluster_size", torch.zeros(codebook_size))
        self.register_buffer("embed", embed)
        self.register_buffer("embed_avg", embed.clone())

    @torch.jit.ignore
    def init_embed_(self, data):
        if self.inited:
            return

        embed, cluster_size = kmeans(data, self.codebook_size, self.kmeans_iters)
        self.embed.data.copy_(embed)
        self.embed_avg.data.copy_(embed.clone())
        self.cluster_size.data.copy_(cluster_size)
        self.inited.data.copy_(torch.Tensor([True]))
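
    # The encode path (used by `VectorQuantization` below) needs a
    # nearest-neighbour lookup that is missing from this file. A minimal
    # sketch, assuming flat [N, D] inputs after `preprocess`:
    def preprocess(self, x):
        return rearrange(x, "... d -> (...) d")

    def quantize(self, x):
        # Argmin over squared Euclidean distance, expanded so the codebook
        # norm can be broadcast: ||x - e||^2 = ||x||^2 - 2 x.e + ||e||^2.
        embed = self.embed.t()
        dist = -(
            x.pow(2).sum(1, keepdim=True)
            - 2 * x @ embed
            + embed.pow(2).sum(0, keepdim=True)
        )
        return dist.max(dim=-1).indices

    def encode(self, x):
        shape = x.shape
        x = self.preprocess(x)
        embed_ind = self.quantize(x)
        return self.postprocess_emb(embed_ind, shape)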

    def postprocess_emb(self, embed_ind, shape):
        return embed_ind.view(*shape[:-1])

    def dequantize(self, embed_ind):
        quantize = F.embedding(embed_ind, self.embed)
        return quantize

    def decode(self, embed_ind):
        quantize = self.dequantize(embed_ind)
        return quantize


class VectorQuantization(nn.Module):
    """Vector quantization over a single codebook, with optional projections."""

    def __init__(
        self,
        dim,
        codebook_size,
        codebook_dim=None,
        decay=0.8,
        epsilon=1e-5,
        kmeans_init=False,
        kmeans_iters=10,
        channels_last=False,
    ):
        super().__init__()
        _codebook_dim = codebook_dim if codebook_dim is not None else dim

        # Project in/out of the codebook dimension when it differs from `dim`.
        requires_projection = _codebook_dim != dim
        self.project_in = nn.Linear(dim, _codebook_dim) if requires_projection else nn.Identity()
        self.project_out = nn.Linear(_codebook_dim, dim) if requires_projection else nn.Identity()
        self._codebook = EuclideanCodebook(
            dim=_codebook_dim,
            codebook_size=codebook_size,
            kmeans_init=kmeans_init,
            kmeans_iters=kmeans_iters,
            decay=decay,
            epsilon=epsilon,
        )
        self.codebook_size = codebook_size
        self.channels_last = channels_last

    @property
    def codebook(self):
        return self._codebook.embed

    @property
    def inited(self):
        return self._codebook.inited

    def _postprocess(self, quantize):
        if not self.channels_last:
            quantize = rearrange(quantize, "b n d -> b d n")
        return quantize

    def decode(self, embed_ind):
        quantize = self._codebook.decode(embed_ind)
        quantize = self.project_out(quantize)
        quantize = self._postprocess(quantize)
        return quantize
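
    # `ResidualVectorQuantization` calls `encode` and `forward`, which are
    # missing here. Minimal sketches, assuming [B, D, N] inputs when
    # `channels_last` is False (mirroring `_postprocess`); the EMA codebook
    # updates of the full algorithm are intentionally omitted.
    def _preprocess(self, x):
        if not self.channels_last:
            x = rearrange(x, "b d n -> b n d")
        return x

    def encode(self, x):
        x = self._preprocess(x)
        x = self.project_in(x)
        return self._codebook.encode(x)

    def forward(self, x):
        x = self._preprocess(x)
        x = self.project_in(x)
        # Lazily k-means-initialize the codebook on the first batch seen.
        self._codebook.init_embed_(rearrange(x, "b n d -> (b n) d"))
        embed_ind = self._codebook.encode(x)
        quantize = self._codebook.decode(embed_ind)
        if self.training:
            # Straight-through estimator: gradients flow to the encoder.
            quantize = x + (quantize - x).detach()
        # Commitment loss pulls encoder outputs towards their codewords.
        commit_loss = F.mse_loss(quantize.detach(), x)
        quantize = self.project_out(quantize)
        quantize = self._postprocess(quantize)
        return quantize, embed_ind, commit_loss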


class ResidualVectorQuantization(nn.Module):
    """Residual vector quantization implementation.

    Follows Algorithm 1. in https://arxiv.org/pdf/2107.03312.pdf
    """

    def __init__(self, *, num_quantizers, **kwargs):
        super().__init__()
        self.layers = nn.ModuleList(
            [VectorQuantization(**kwargs) for _ in range(num_quantizers)]
        )

    def decode(self, q_indices: torch.Tensor) -> torch.Tensor:
        # q_indices is [K, B, T]; sum the decoded output of each layer.
        quantized_out = torch.tensor(0.0, device=q_indices.device)
        for i, indices in enumerate(q_indices):
            layer = self.layers[i]
            quantized = layer.decode(indices)
            quantized_out = quantized_out + quantized
        return quantized_out
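
    # `ResidualVectorQuantizer` below calls `forward(x, n_q=...)` and
    # `encode(x, n_q=...)`, neither of which is defined in this file.
    # Minimal sketches of the residual scheme from Algorithm 1: each layer
    # quantizes whatever the previous layers left unexplained.
    def forward(self, x: torch.Tensor, n_q: tp.Optional[int] = None):
        quantized_out = torch.zeros_like(x)
        residual = x
        all_losses = []
        all_indices = []
        n_q = n_q or len(self.layers)
        for layer in self.layers[:n_q]:
            quantized, indices, loss = layer(residual)
            residual = residual - quantized
            quantized_out = quantized_out + quantized
            all_indices.append(indices)
            all_losses.append(loss)
        out_losses, out_indices = map(torch.stack, (all_losses, all_indices))
        return quantized_out, out_indices, out_losses

    def encode(self, x: torch.Tensor, n_q: tp.Optional[int] = None) -> torch.Tensor:
        residual = x
        all_indices = []
        n_q = n_q or len(self.layers)
        for layer in self.layers[:n_q]:
            indices = layer.encode(residual)
            residual = residual - layer.decode(indices)
            all_indices.append(indices)
        return torch.stack(all_indices)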


class ResidualVectorQuantizer(nn.Module):
    """Residual Vector Quantizer, wrapping the RVQ stack with model-level config."""

    def __init__(
        self,
        dimension=128,
        n_q=4,
        q_dropout=False,
        bins=2048,
        decay=0.99,
        kmeans_init=True,
        kmeans_iters=50,
        threshold_ema_dead_code=2,
        orthogonal_reg_weight=0.0,
        orthogonal_reg_active_codes_only=False,
        orthogonal_reg_max_codes=None,
    ):
        super().__init__()
        self.max_n_q = n_q
        self.n_q = n_q
        self.q_dropout = q_dropout
        self.dimension = dimension
        self.bins = bins
        self.decay = decay
        self.kmeans_init = kmeans_init
        self.kmeans_iters = kmeans_iters
        self.threshold_ema_dead_code = threshold_ema_dead_code
        self.orthogonal_reg_weight = orthogonal_reg_weight
        self.orthogonal_reg_active_codes_only = orthogonal_reg_active_codes_only
        self.orthogonal_reg_max_codes = orthogonal_reg_max_codes
        self.vq = ResidualVectorQuantization(
            dim=self.dimension,
            codebook_size=self.bins,
            num_quantizers=self.n_q,
            decay=self.decay,
            kmeans_init=self.kmeans_init,
            kmeans_iters=self.kmeans_iters,
            channels_last=False,
        )

    def forward(self, x: torch.Tensor, frame_rate: int):
        n_q = self.n_q
        if self.training and self.q_dropout:
            # Quantizer dropout: train against a random number of quantizers.
            n_q = int(torch.randint(1, self.n_q + 1, (1,)).item())
        # Each quantizer contributes log2(bins) bits per frame, i.e. kb/s per layer.
        bw_per_q = math.log2(self.bins) * frame_rate / 1000
        quantized, codes, commit_loss = self.vq(x, n_q=n_q)
        codes = codes.transpose(0, 1)
        # codes is [B, K, T], with T frames and K the number of codebooks.
        bw = torch.tensor(n_q * bw_per_q).to(x)
        return QuantizedResult(quantized, codes, bw, penalty=torch.mean(commit_loss))

    def encode(self, x: torch.Tensor) -> torch.Tensor:
        """Encode the given input tensor.

        Uses the currently selected number of quantizers and returns the
        indices from each, as [B, K, T].
        """
        n_q = self.n_q
        codes = self.vq.encode(x, n_q=n_q)
        codes = codes.transpose(0, 1)
        return codes

    def decode(self, codes: torch.Tensor) -> torch.Tensor:
        """Decode the given codes to the quantized representation."""
        # codes is [B, K, T]; the RVQ stack expects [K, B, T].
        codes = codes.transpose(0, 1)
        quantized = self.vq.decode(codes)
        return quantized

    @property
    def total_codebooks(self):
        return self.max_n_q

    @property
    def num_codebooks(self):
        return self.n_q

    def set_num_codebooks(self, n: int):
        assert n > 0 and n <= self.max_n_q
        self.n_q = n
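

# A quick round-trip smoke test; sizes are hypothetical, and shapes follow
# the [B, D, N] convention assumed in the sketches above.
if __name__ == "__main__":
    rvq = ResidualVectorQuantizer(dimension=128, n_q=4, bins=1024)
    x = torch.randn(2, 128, 50)  # batch of 2, 128-dim frames, 50 timesteps
    result = rvq(x, frame_rate=50)
    print(result.codes.shape)  # torch.Size([2, 4, 50])
    print(result.bandwidth)    # 4 * log2(1024) * 50 / 1000 = 2.0 kb/s
    decoded = rvq.decode(rvq.encode(x))
    print(decoded.shape)       # torch.Size([2, 128, 50])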