Spaces:
Sleeping
Sleeping
| # Copyright (c) Meta Platforms, Inc. and affiliates. | |
| # All rights reserved. | |
| # | |
| # This source code is licensed under the license found in the | |
| # LICENSE file in the root directory of this source tree. | |
| import math | |
| import typing as tp | |
| import torch | |
| from .modules import SEANetDecoder | |
| from .modules import SEANetEncoder | |
| from .quantization import ResidualVectorQuantizer | |
| ################################################################################ | |
| # Encodec neural audio codec | |
| ################################################################################ | |
class Encodec(torch.nn.Module):
    """
    Encodec neural audio codec proposed in "High Fidelity Neural Audio
    Compression" (https://arxiv.org/abs/2210.13438) by Défossez et al.

    The module wires together a `SEANetEncoder`, a `ResidualVectorQuantizer`
    and a `SEANetDecoder`. The currently selected bandwidth (``self.bandwidth``)
    is passed to the quantizer on every call to `encode`.
    """

    def __init__(
        self,
        sample_rate: int,
        channels: int,
        causal: bool,
        model_norm: str,
        target_bandwidths: tp.Sequence[float],
        audio_normalize: bool,
        ratios: tp.Sequence[int] = (8, 5, 4, 2),
        codebook_size: int = 1024,
        n_filters: int = 32,
        true_skip: bool = False,
        encoder_kwargs: tp.Optional[tp.Dict[str, tp.Any]] = None,
        decoder_kwargs: tp.Optional[tp.Dict[str, tp.Any]] = None,
    ):
        """
        Parameters
        ----------
        sample_rate : int
            Audio sample rate in Hz.
        channels : int
            Number of audio channels expected at input.
        causal : bool
            Whether to use causal convolution layers in encoder/decoder.
        model_norm : str
            Type of normalization to use in encoder/decoder.
        target_bandwidths : tp.Sequence[float]
            List of target bandwidths in kb/s. The last entry is used both
            to size the quantizer and as the initial bandwidth.
        audio_normalize : bool
            Whether to normalize encoded and decoded audio segments using
            simple scaling factors. Stored as ``self.normalize``.
        ratios : tp.Sequence[int], optional
            List of downsampling ratios used in encoder/decoder, by default (8, 5, 4, 2)
        codebook_size : int, optional
            Size of residual vector quantizer codebooks, by default 1024.
            Must be a power of 2.
        n_filters : int, optional
            Number of filters used in encoder/decoder, by default 32
        true_skip : bool, optional
            Whether to use true skip connections in encoder/decoder rather than
            convolutional skip connections, by default False
        encoder_kwargs : dict, optional
            Extra keyword arguments forwarded to `SEANetEncoder`. ``None``
            (the default) is treated as an empty dict.
        decoder_kwargs : dict, optional
            Extra keyword arguments forwarded to `SEANetDecoder`. ``None``
            (the default) is treated as an empty dict.
        """
        super().__init__()

        # Never share a mutable default between instances.
        encoder_kwargs = encoder_kwargs or {}
        decoder_kwargs = decoder_kwargs or {}

        self.encoder = SEANetEncoder(
            channels=channels,
            causal=causal,
            norm=model_norm,
            ratios=ratios,
            n_filters=n_filters,
            true_skip=true_skip,
            **encoder_kwargs,
        )
        self.decoder = SEANetDecoder(
            channels=channels,
            causal=causal,
            norm=model_norm,
            ratios=ratios,
            n_filters=n_filters,
            true_skip=true_skip,
            **decoder_kwargs,
        )

        # Number of quantizers needed to reach the largest target bandwidth:
        # (bits per second) // (frames per second * bits per frame per codebook).
        # NOTE(review): the constant 10 equals log2(1024), i.e. it matches the
        # default `codebook_size` only and is not derived from the argument.
        # Left untouched because changing it would alter the architecture of
        # existing checkpoints -- confirm before generalizing.
        n_q = int(
            1000
            * target_bandwidths[-1]
            // (math.ceil(sample_rate / self.encoder.hop_length) * 10)
        )
        self.n_q = n_q  # Maximum number of quantizers

        self.quantizer = ResidualVectorQuantizer(
            dimension=self.encoder.dimension,
            n_q=n_q,
            bins=codebook_size,
        )

        self.sample_rate = sample_rate
        self.normalize = audio_normalize
        self.channels = channels
        # Latent frames per second of audio.
        self.frame_rate = math.ceil(self.sample_rate / math.prod(self.encoder.ratios))
        self.target_bandwidths = target_bandwidths

        self.bits_per_codebook = int(math.log2(self.quantizer.bins))
        assert (
            2**self.bits_per_codebook == self.quantizer.bins
        ), "quantizer bins must be a power of 2."

        # Start at the highest supported bandwidth.
        self.bandwidth = self.target_bandwidths[-1]

    def set_target_bandwidth(self, bandwidth: float) -> None:
        """
        Set the target bandwidth for the codec by adjusting the
        number of residual vector quantizers used

        Parameters
        ----------
        bandwidth : float
            Must be one of the values in ``self.target_bandwidths``.

        Raises
        ------
        ValueError
            If `bandwidth` is not one of the supported target bandwidths.
            ``self.bandwidth`` is left unchanged in that case.
        """
        if bandwidth not in self.target_bandwidths:
            raise ValueError(
                f"This model doesn't support the bandwidth {bandwidth}. "
                f"Select one of {self.target_bandwidths}."
            )
        self.bandwidth = bandwidth

    def encode(
        self, x: torch.Tensor
    ) -> tp.Tuple[torch.Tensor, tp.Any, tp.Any, torch.Tensor]:
        """
        Map an audio waveform `x` to discrete residual latent codes.

        Parameters
        ----------
        x : torch.Tensor
            Audio waveform of shape `(n_batch, n_channels, n_samples)` with
            one or two channels.

        Returns
        -------
        codes : torch.Tensor
            Tensor of shape `(n_batch, n_codebooks, n_frames)`; this is the
            first output of the quantizer with its first two axes swapped.
        z_O, z_o
            Second and third outputs of ``self.quantizer.encode``, returned
            unchanged. Their meaning is defined by the quantizer.
        z : torch.Tensor
            Output of ``self.encoder`` before quantization.

        Raises
        ------
        AssertionError
            If `x` is not 3-dimensional or has a channel count outside 1..2.
        """
        assert x.dim() == 3, (
            "expected input of shape (n_batch, n_channels, n_samples), "
            f"got a tensor with {x.dim()} dimension(s)"
        )
        _, channels, _ = x.shape
        assert 0 < channels <= 2, f"expected 1 or 2 channels, got {channels}"

        z = self.encoder(x)
        codes, z_O, z_o = self.quantizer.encode(z, self.frame_rate, self.bandwidth)
        # Swap the first two axes of the quantizer output so that batch comes first.
        codes = codes.transpose(0, 1)
        return codes, z_O, z_o, z

    def decode(self, codes: torch.Tensor) -> torch.Tensor:
        """
        Decode quantized latents to obtain waveform audio.

        Parameters
        ----------
        codes : torch.Tensor
            Tensor of shape `(n_batch, n_codebooks, n_frames)`.

        Returns
        -------
        out : torch.Tensor
            Tensor of shape `(n_batch, n_channels, n_samples)`.
        """
        # Undo the axis swap applied in `encode` before handing the codes
        # back to the quantizer.
        codes = codes.transpose(0, 1)
        emb = self.quantizer.decode(codes)
        out = self.decoder(emb)
        return out