# NOTE(review): the original file header here (likely a copyright/license
# comment) was garbled during extraction and could not be reconstructed —
# restore it from the upstream source.
import math
import typing as tp

import torch

from .base import BaseQuantizer, QuantizedResult
from .core_vq import ResidualVectorQuantization

|
class ResidualVectorQuantizer(BaseQuantizer):
    """Residual Vector Quantizer.

    Args:
        dimension (int): Dimension of the codebooks.
        n_q (int): Number of residual vector quantizers used.
        q_dropout (bool): Random quantizer drop out at train time.
        bins (int): Codebook size.
        decay (float): Decay for exponential moving average over the codebooks.
        kmeans_init (bool): Whether to use kmeans to initialize the codebooks.
        kmeans_iters (int): Number of iterations used for kmeans initialization.
        threshold_ema_dead_code (int): Threshold for dead code expiration. Replace any codes
            that have an exponential moving average cluster size less than the specified threshold
            with randomly selected vector from the current batch.
        orthogonal_reg_weight (float): Orthogonal regularization weight.
        orthogonal_reg_active_codes_only (bool): Apply orthogonal regularization only on active codes.
        orthogonal_reg_max_codes (optional int): Maximum number of codes to consider
            for orthogonal regularization.
    """
    def __init__(
        self,
        dimension: int = 256,
        n_q: int = 8,
        q_dropout: bool = False,
        bins: int = 1024,
        decay: float = 0.99,
        kmeans_init: bool = True,
        kmeans_iters: int = 10,
        threshold_ema_dead_code: int = 2,
        orthogonal_reg_weight: float = 0.0,
        orthogonal_reg_active_codes_only: bool = False,
        orthogonal_reg_max_codes: tp.Optional[int] = None,
    ):
        super().__init__()
        # max_n_q is the number of codebooks the underlying RVQ is built with;
        # n_q may later be lowered via set_num_codebooks() for variable bitrate.
        self.max_n_q = n_q
        self.n_q = n_q
        self.q_dropout = q_dropout
        self.dimension = dimension
        self.bins = bins
        self.decay = decay
        self.kmeans_init = kmeans_init
        self.kmeans_iters = kmeans_iters
        self.threshold_ema_dead_code = threshold_ema_dead_code
        self.orthogonal_reg_weight = orthogonal_reg_weight
        self.orthogonal_reg_active_codes_only = orthogonal_reg_active_codes_only
        self.orthogonal_reg_max_codes = orthogonal_reg_max_codes
        self.vq = ResidualVectorQuantization(
            dim=self.dimension,
            codebook_size=self.bins,
            num_quantizers=self.n_q,
            decay=self.decay,
            kmeans_init=self.kmeans_init,
            kmeans_iters=self.kmeans_iters,
            threshold_ema_dead_code=self.threshold_ema_dead_code,
            orthogonal_reg_weight=self.orthogonal_reg_weight,
            orthogonal_reg_active_codes_only=self.orthogonal_reg_active_codes_only,
            orthogonal_reg_max_codes=self.orthogonal_reg_max_codes,
            channels_last=False,
        )

    def forward(self, x: torch.Tensor, frame_rate: int) -> QuantizedResult:
        """Quantize `x`, returning the quantized tensor, codes, bandwidth and penalty.

        Args:
            x (torch.Tensor): Input tensor to quantize.
            frame_rate (int): Frame rate of the input, used only to compute the bandwidth.

        Returns:
            QuantizedResult: quantized tensor, codes (codebook axis moved after batch),
                bandwidth in kb/s, and the mean commitment loss as `penalty`.
        """
        n_q = self.n_q
        # Quantizer dropout: at train time randomly keep only the first `n_q`
        # quantizers so the model learns to work at lower bitrates too.
        if self.training and self.q_dropout:
            n_q = int(torch.randint(1, self.n_q + 1, (1,)).item())
        # Each codebook costs log2(bins) bits per frame; divide by 1000 for kb/s.
        bw_per_q = math.log2(self.bins) * frame_rate / 1000
        quantized, codes, commit_loss = self.vq(x, n_q=n_q)
        # Swap the per-quantizer axis after the batch axis (presumably
        # [K, B, T] -> [B, K, T] — confirm against ResidualVectorQuantization).
        codes = codes.transpose(0, 1)
        bw = torch.tensor(n_q * bw_per_q).to(x)
        return QuantizedResult(quantized, codes, bw, penalty=torch.mean(commit_loss))

    def encode(self, x: torch.Tensor) -> torch.Tensor:
        """Encode the input tensor and return the indices for each quantizer.

        Uses the currently configured number of quantizers (`self.n_q`,
        see `set_num_codebooks`).
        """
        codes = self.vq.encode(x, n_q=self.n_q)
        # Move the codebook axis after the batch axis, mirroring forward().
        codes = codes.transpose(0, 1)
        return codes

    def decode(self, codes: torch.Tensor) -> torch.Tensor:
        """Decode the given codes (as produced by `encode`) to the quantized representation."""
        # Undo the transpose done at encode time before handing codes to the RVQ.
        codes = codes.transpose(0, 1)
        return self.vq.decode(codes)

    @property
    def total_codebooks(self) -> int:
        """Maximum number of codebooks this quantizer was constructed with."""
        return self.max_n_q

    @property
    def num_codebooks(self) -> int:
        """Number of codebooks currently used for encoding."""
        return self.n_q

    def set_num_codebooks(self, n: int) -> None:
        """Set the number of codebooks used at encode/forward time.

        Raises:
            ValueError: if `n` is not in `[1, total_codebooks]`.
        """
        # Use a real exception rather than `assert`: asserts vanish under -O,
        # which would silently allow an invalid n_q.
        if not 1 <= n <= self.max_n_q:
            raise ValueError(f"n must be in [1, {self.max_n_q}], got {n}.")
        self.n_q = n
| |
|