Spaces:
Runtime error
Runtime error
| # SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. | |
| # SPDX-License-Identifier: MIT | |
| # | |
| # Permission is hereby granted, free of charge, to any person obtaining a | |
| # copy of this software and associated documentation files (the "Software"), | |
| # to deal in the Software without restriction, including without limitation | |
| # the rights to use, copy, modify, merge, publish, distribute, sublicense, | |
| # and/or sell copies of the Software, and to permit persons to whom the | |
| # Software is furnished to do so, subject to the following conditions: | |
| # | |
| # The above copyright notice and this permission notice shall be included in | |
| # all copies or substantial portions of the Software. | |
| # | |
| # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR | |
| # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, | |
| # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL | |
| # THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER | |
| # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING | |
| # FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER | |
| # DEALINGS IN THE SOFTWARE. | |
| # Copyright (c) Kyutai, all rights reserved. | |
| # This source code is licensed under the license found in the | |
| # LICENSE file in the root directory of this source tree. | |
| # Copyright (c) Meta Platforms, Inc. and affiliates. | |
| # All rights reserved. | |
| # | |
| # This source code is licensed under the license found in the | |
| # LICENSE file in the root directory of this source tree. | |
| import math | |
| import typing as tp | |
| import torch | |
| from .base import BaseQuantizer, QuantizedResult | |
| from .core_vq import ResidualVectorQuantization | |
| class ResidualVectorQuantizer(BaseQuantizer): | |
| """Residual Vector Quantizer. | |
| Args: | |
| dimension (int): Dimension of the codebooks. | |
| input_dimension (None or int): dimension of the input, defaults to `dimension` if not provided. | |
| output_dimension (None or int): dimension of the output, defaults to `dimension` if not provided. | |
| n_q (int): Number of vector quantizers used. | |
| q_dropout (bool): Random quantizer drop out at train time. | |
| no_quantization_rate (float): Gives the probability of applying no quantization at all | |
| at train time. The RVQ codebooks will still get the input value to learn the proper codebook. | |
| bins (int): Codebook size. | |
| decay (float): Decay for exponential moving average over the codebooks. | |
| threshold_usage_ratio (float): Defines the threshold for the cluster usage under which a centroid | |
| is replaced. This is expressed as a fraction of the usage a centroid would get under | |
| a uniform distribution, so that it doesn't depend on the batch size etc. | |
| replaced_usage_ratio (float): When replacing a centroid, use this as an initial centroid usage, | |
| to avoid the centroid getting replaced too quickly. | |
| codebook_offset (int): Offset to use for the codebook indices. This is useful when using multiple quantizers | |
| such as in SplitResidualVectorQuantizer. | |
| force_projection (bool): Whether to force input and output projections even when dimension is constant. | |
| generator_seed (int or None): seed used to initialize the RNG used for no quantization. | |
| """ | |
| def __init__( | |
| self, | |
| dimension: int = 128, | |
| input_dimension: tp.Optional[int] = None, | |
| output_dimension: tp.Optional[int] = None, | |
| n_q: int = 8, | |
| q_dropout: bool = False, | |
| q_first_only_proba: float = 0.0, | |
| no_quantization_rate: float = 0.0, | |
| bins: int = 1024, | |
| decay: float = 0.99, | |
| threshold_usage_ratio: float = 0.1, | |
| replaced_usage_ratio: float = 1.0, | |
| codebook_offset: int = 0, | |
| force_projection: bool = False, | |
| generator_seed: tp.Optional[int] = None, | |
| ): | |
| super().__init__() | |
| self.max_n_q = n_q | |
| self.n_q = n_q | |
| self.q_dropout = q_dropout | |
| self.no_quantization_rate = no_quantization_rate | |
| self.q_first_only_proba = q_first_only_proba | |
| self.dimension = dimension | |
| self.input_dimension = input_dimension or dimension | |
| self.output_dimension = output_dimension or dimension | |
| self.bins = bins | |
| self.decay = decay | |
| self.input_proj: torch.nn.Module | |
| self.output_proj: torch.nn.Module | |
| self.generator = None | |
| if generator_seed is not None: | |
| self.generator = torch.Generator( | |
| device="cuda" if torch.cuda.is_available() else "cpu" | |
| ) | |
| self.generator.manual_seed(generator_seed) | |
| if self.input_dimension == self.dimension and not force_projection: | |
| self.input_proj = torch.nn.Identity() | |
| else: | |
| self.input_proj = torch.nn.Conv1d( | |
| self.input_dimension, self.dimension, 1, bias=False | |
| ) | |
| if self.output_dimension == self.dimension and not force_projection: | |
| self.output_proj = torch.nn.Identity() | |
| else: | |
| self.output_proj = torch.nn.Conv1d( | |
| self.dimension, self.output_dimension, 1, bias=False | |
| ) | |
| self.vq = ResidualVectorQuantization( | |
| dim=self.dimension, | |
| codebook_size=self.bins, | |
| num_quantizers=self.n_q, | |
| decay=self.decay, | |
| threshold_usage_ratio=threshold_usage_ratio, | |
| replaced_usage_ratio=replaced_usage_ratio, | |
| codebook_offset=codebook_offset, | |
| ) | |
| def forward(self, x: torch.Tensor, frame_rate: int): | |
| """ | |
| Args: | |
| x (torch.Tensor): Input tensor of shape [B, C, T] with `C` number of channels. | |
| frame_rate (int): frame rate of the input (e.g `T = frame_rate * duration`), used to compute | |
| the bandwidth. | |
| Returns: | |
| QuantizedResult: Quantized result with the following attributes: | |
| - `x` (torch.Tensor): Quantized tensor of shape [B, C, T]. | |
| - `codes` (torch.Tensor): Quantized codes of shape [B, K, T] with `K` number of codebooks. | |
| - `bw` (torch.Tensor): Bandwidth of the quantized tensor in kbits per second. | |
| - `penalty` (torch.Tensor): Commitment loss. | |
| - `metrics` (dict): RVQ metrics, in particular rate of dead code replacement, and entropy. | |
| """ | |
| n_q = self.n_q | |
| x = self.input_proj(x) | |
| bw_per_q = math.log2(self.bins) * frame_rate / 1000 | |
| quantized, codes, commit_loss, metrics = self.vq(x, n_q=n_q) | |
| B, _, _ = quantized.shape | |
| quantized = self.output_proj(quantized) | |
| codes = codes.transpose(0, 1) | |
| # codes is [B, K, T], with T frames, K nb of codebooks. | |
| bw = torch.tensor(n_q * bw_per_q).to(x) | |
| return QuantizedResult( | |
| quantized, codes, bw, penalty=torch.mean(commit_loss), metrics=metrics | |
| ) | |
| def encode(self, x: torch.Tensor) -> torch.Tensor: | |
| """Encode a given input tensor with the specified frame rate at the given bandwidth. | |
| The RVQ encode method sets the appropriate number of quantizer to use | |
| and returns indices for each quantizer. | |
| """ | |
| n_q = self.n_q | |
| if x.shape[-1] == 0: | |
| return torch.empty((x.shape[0], n_q, 0), device=x.device, dtype=torch.int64) | |
| x = self.input_proj(x) | |
| codes = self.vq.encode(x, n_q=n_q) | |
| codes = codes.transpose(0, 1) | |
| # codes is [B, K, T], with T frames, K nb of codebooks. | |
| return codes | |
| def decode(self, codes: torch.Tensor) -> torch.Tensor: | |
| """Decode the given codes to the quantized representation.""" | |
| # codes is [B, K, T], with T frames, K nb of codebooks, vq.decode expects [K, B, T]. | |
| codes = codes.transpose(0, 1) | |
| quantized = self.vq.decode(codes) | |
| quantized = self.output_proj(quantized) | |
| return quantized | |
| def total_codebooks(self): | |
| return self.max_n_q | |
| def num_codebooks(self): | |
| return self.n_q | |
| def set_num_codebooks(self, n: int): | |
| assert n >= 0 and n <= self.max_n_q | |
| self.n_q = n | |
| def cardinality(self) -> int: | |
| return self.bins | |
| class SplitResidualVectorQuantizer(BaseQuantizer): | |
| """Residual Vector Quantizer with separate projections for the first quantizer and the rest. | |
| Args: | |
| n_q (int): Number of residual vector quantizers used. | |
| n_semantic_q (int): Number of residual vector quantizers used for the semantic quantizer. | |
| no_quantization_mode (str): if 'true_skip', when doing no quantization, the input will not go | |
| through the sub quantizers. If `independent`, independent decisions are taken by | |
| the semantic and acoustic quantizers. If `same` (the default), the same decision is taken by both. | |
| **kwargs: Arguments to the constructor of `ResidualVectorQuantizer` that are shared between both. | |
| """ | |
| def __init__( | |
| self, | |
| *, | |
| n_q: int = 8, | |
| no_quantization_rate: float = 0.0, | |
| no_quantization_mode: str = "same", | |
| n_q_semantic: int = 1, | |
| **kwargs, | |
| ): | |
| super().__init__() | |
| assert n_q > n_q_semantic, ( | |
| f"Number of quantizers {n_q} must be larger " | |
| f"than the number of semantic quantizers {n_q_semantic}." | |
| ) | |
| self.max_n_q = n_q | |
| self.n_q_semantic = n_q_semantic | |
| self.n_q_acoustic = n_q - n_q_semantic | |
| if no_quantization_mode == "true_skip": | |
| self.no_quantization_rate = no_quantization_rate | |
| # Setting to zero for the underlying RVQ. | |
| no_quantization_rate = 0.0 | |
| else: | |
| self.no_quantization_rate = 0.0 | |
| if no_quantization_mode == "same": | |
| kwargs["generator_seed"] = 1234 | |
| kwargs["no_quantization_rate"] = no_quantization_rate | |
| q_dropout = kwargs.pop("q_dropout", False) | |
| self.rvq_first = ResidualVectorQuantizer( | |
| n_q=n_q_semantic, force_projection=True, q_dropout=False, **kwargs | |
| ) | |
| self.rvq_rest = ResidualVectorQuantizer( | |
| n_q=n_q - n_q_semantic, | |
| codebook_offset=1, | |
| force_projection=True, | |
| q_dropout=q_dropout, | |
| **kwargs, | |
| ) | |
| if no_quantization_mode == "true_skip": | |
| assert self.rvq_first.input_dimension == self.rvq_first.output_dimension | |
| assert self.rvq_rest.input_dimension == self.rvq_rest.output_dimension | |
| def _renorm_and_add( | |
| self, | |
| first_val: torch.Tensor, | |
| rest_val: torch.Tensor, | |
| n_q_semantic: int, | |
| n_q_acoustic: int, | |
| ): | |
| """Renormalizes values from `rvq_first` and `rvq_rest` and adds them. | |
| This allows correcting statistics that are normalized by the number of quantizers. To renormalize, we use the | |
| number of quantizers that are actually used, e.g. taking into account quantizer dropout. | |
| """ | |
| n_q = n_q_semantic + n_q_acoustic | |
| renorm_first_val = first_val * n_q_semantic / n_q | |
| renorm_rest_val = rest_val * n_q_acoustic / n_q | |
| return renorm_first_val + renorm_rest_val | |
| def forward(self, x: torch.Tensor, frame_rate: int): | |
| """ | |
| Args: | |
| x (torch.Tensor): Input tensor of shape [B, C, T] with `C` number of channels. | |
| frame_rate (int): frame rate of the input (e.g `T = frame_rate * duration`), used to compute | |
| the bandwidth. | |
| Returns: | |
| QuantizedResult: Quantized result with the following attributes: | |
| - `x` (torch.Tensor): Quantized tensor of shape [B, C, T]. | |
| - `codes` (torch.Tensor): Quantized codes of shape [B, K, T] with `K` number of codebooks. | |
| - `bw` (torch.Tensor): Bandwidth of the quantized tensor in kbits per second. | |
| - `penalty` (torch.Tensor): Commitment loss. | |
| - `metrics` (dict): RVQ metrics, in particular rate of dead code replacement, and entropy. | |
| """ | |
| semantic_result = self.rvq_first(x, frame_rate) | |
| if self.n_q == self.n_q_semantic: | |
| return semantic_result | |
| acoustic_result = self.rvq_rest(x, frame_rate) | |
| full_quantized_emb = semantic_result.x + acoustic_result.x | |
| full_quantized_codes = torch.cat( | |
| [semantic_result.codes, acoustic_result.codes], dim=1 | |
| ) | |
| # This is the actual number of quantizers used, e.g. taking into account quantizer dropout. | |
| n_q_semantic = semantic_result.codes.shape[1] | |
| n_q_acoustic = acoustic_result.codes.shape[1] | |
| full_quantized_bandwidth = semantic_result.bandwidth + acoustic_result.bandwidth | |
| full_quantized_penalty = self._renorm_and_add( | |
| semantic_result.penalty, acoustic_result.penalty, n_q_semantic, n_q_acoustic | |
| ) | |
| full_quantized_metrics = semantic_result.metrics | |
| for key, value in acoustic_result.metrics.items(): | |
| if key in full_quantized_metrics: | |
| full_quantized_metrics[key] = self._renorm_and_add( | |
| full_quantized_metrics[key], value, n_q_semantic, n_q_acoustic | |
| ) | |
| else: | |
| full_quantized_metrics[key] = value | |
| return QuantizedResult( | |
| full_quantized_emb, | |
| full_quantized_codes, | |
| full_quantized_bandwidth, | |
| penalty=full_quantized_penalty, | |
| metrics=full_quantized_metrics, | |
| ) | |
| def encode(self, x: torch.Tensor) -> torch.Tensor: | |
| """Encode a given input tensor with the specified frame rate at the given bandwidth. | |
| The RVQ encode method sets the appropriate number of quantizer to use | |
| and returns indices for each quantizer. | |
| """ | |
| codes = self.rvq_first.encode(x) | |
| if self.n_q > self.n_q_semantic: | |
| acoustic_codes = self.rvq_rest.encode(x) | |
| codes = torch.cat([codes, acoustic_codes], dim=1) | |
| # codes is [B, K, T], with T frames, K nb of codebooks. | |
| return codes | |
| def decode(self, codes: torch.Tensor) -> torch.Tensor: | |
| """Decode the given codes to the quantized representation.""" | |
| # codes is [B, K, T], with T frames, K nb of codebooks. | |
| quantized = self.rvq_first.decode(codes[:, : self.n_q_semantic]) | |
| if codes.shape[1] > self.n_q_semantic: | |
| quantized += self.rvq_rest.decode(codes[:, self.n_q_semantic :]) | |
| return quantized | |
| def total_codebooks(self): | |
| return self.rvq_first.max_n_q + self.rvq_rest.max_n_q | |
| def num_codebooks(self): | |
| return self.rvq_first.num_codebooks + self.rvq_rest.num_codebooks | |
| def n_q(self): | |
| return self.rvq_first.n_q + self.rvq_rest.n_q | |
| def dimension(self): | |
| return self.rvq_first.dimension | |
| def semantic_quantizer(self) -> ResidualVectorQuantizer: | |
| """This returns the quantizer that models the first level of the hierarchy (typically semantic).""" | |
| return self.rvq_first | |
| def acoustic_quantizer(self) -> ResidualVectorQuantizer: | |
| """This returns the quantizer that models the higher levels of the hierarchy (typically acoustic).""" | |
| return self.rvq_rest | |
| def set_num_codebooks(self, n: int): | |
| assert n >= self.n_q_semantic and n <= self.total_codebooks | |
| self.rvq_rest.set_num_codebooks(n - self.n_q_semantic) | |
| def cardinality(self) -> int: | |
| assert self.rvq_rest.cardinality == self.rvq_first.cardinality | |
| return self.rvq_first.cardinality | |