| |
| |
| |
| |
| |
| """Compression models or wrapper around existing models. |
| Also defines the main interface that a model must follow to be usable as an audio tokenizer. |
| """ |
|
|
| from abc import ABC, abstractmethod |
| import logging |
| import math |
| from pathlib import Path |
| import typing as tp |
|
|
| from einops import rearrange |
| import numpy as np |
| import torch |
| from torch import nn |
| from transformers import EncodecModel as HFEncodecModel |
|
|
| from .. import quantization as qt |
|
|
|
|
| logger = logging.getLogger() |
|
|
|
|
| class CompressionModel(ABC, nn.Module): |
| """Base API for all compression model that aim at being used as audio tokenizers |
| with a language model. |
| """ |
|
|
| @abstractmethod |
| def forward(self, x: torch.Tensor) -> qt.QuantizedResult: |
| ... |
|
|
| @abstractmethod |
| def encode(self, x: torch.Tensor) -> tp.Tuple[torch.Tensor, tp.Optional[torch.Tensor]]: |
| """See `EncodecModel.encode`.""" |
| ... |
|
|
| @abstractmethod |
| def decode(self, codes: torch.Tensor, scale: tp.Optional[torch.Tensor] = None): |
| """See `EncodecModel.decode`.""" |
| ... |
|
|
| @abstractmethod |
| def decode_latent(self, codes: torch.Tensor): |
| """Decode from the discrete codes to continuous latent space.""" |
| ... |
|
|
| @property |
| @abstractmethod |
| def channels(self) -> int: |
| ... |
|
|
| @property |
| @abstractmethod |
| def frame_rate(self) -> float: |
| ... |
|
|
| @property |
| @abstractmethod |
| def sample_rate(self) -> int: |
| ... |
|
|
| @property |
| @abstractmethod |
| def cardinality(self) -> int: |
| ... |
|
|
| @property |
| @abstractmethod |
| def num_codebooks(self) -> int: |
| ... |
|
|
| @property |
| @abstractmethod |
| def total_codebooks(self) -> int: |
| ... |
|
|
| @abstractmethod |
| def set_num_codebooks(self, n: int): |
| """Set the active number of codebooks used by the quantizer.""" |
| ... |
|
|
| @staticmethod |
| def get_pretrained( |
| name: str, device: tp.Union[torch.device, str] = 'cpu' |
| ) -> 'CompressionModel': |
| """Instantiate a CompressionModel from a given pretrained model. |
| |
| Args: |
| name (Path or str): name of the pretrained model. See after. |
| device (torch.device or str): Device on which the model is loaded. |
| |
| Pretrained models: |
| - dac_44khz (https://github.com/descriptinc/descript-audio-codec) |
| - dac_24khz (same) |
| - facebook/encodec_24khz (https://huggingface.co/facebook/encodec_24khz) |
| - facebook/encodec_32khz (https://huggingface.co/facebook/encodec_32khz) |
| - your own model on HugginFace. Export instructions to come... |
| """ |
|
|
| from . import builders, loaders |
| model: CompressionModel |
| if name in ['dac_44khz', 'dac_24khz']: |
| model_type = name.split('_')[1] |
| logger.info("Getting pretrained compression model from DAC %s", model_type) |
| model = DAC(model_type) |
| elif name in ['debug_compression_model']: |
| logger.info("Getting pretrained compression model for debug") |
| model = builders.get_debug_compression_model() |
| elif Path(name).exists(): |
| |
| |
| model = loaders.load_compression_model(name, device=device) |
| else: |
| logger.info("Getting pretrained compression model from HF %s", name) |
| hf_model = HFEncodecModel.from_pretrained(name) |
| model = HFEncodecCompressionModel(hf_model).to(device) |
| return model.to(device).eval() |
|
|
|
|
| class EncodecModel(CompressionModel): |
| """Encodec model operating on the raw waveform. |
| |
| Args: |
| encoder (nn.Module): Encoder network. |
| decoder (nn.Module): Decoder network. |
| quantizer (qt.BaseQuantizer): Quantizer network. |
| frame_rate (int): Frame rate for the latent representation. |
| sample_rate (int): Audio sample rate. |
| channels (int): Number of audio channels. |
| causal (bool): Whether to use a causal version of the model. |
| renormalize (bool): Whether to renormalize the audio before running the model. |
| """ |
| |
| |
| frame_rate: float = 0 |
| sample_rate: int = 0 |
| channels: int = 0 |
|
|
| def __init__(self, |
| encoder: nn.Module, |
| decoder: nn.Module, |
| quantizer: qt.BaseQuantizer, |
| frame_rate: int, |
| sample_rate: int, |
| channels: int, |
| causal: bool = False, |
| renormalize: bool = False): |
| super().__init__() |
| self.encoder = encoder |
| self.decoder = decoder |
| self.quantizer = quantizer |
| self.frame_rate = frame_rate |
| self.sample_rate = sample_rate |
| self.channels = channels |
| self.renormalize = renormalize |
| self.causal = causal |
| if self.causal: |
| |
| |
| assert not self.renormalize, 'Causal model does not support renormalize' |
|
|
| @property |
| def total_codebooks(self): |
| """Total number of quantizer codebooks available.""" |
| return self.quantizer.total_codebooks |
|
|
| @property |
| def num_codebooks(self): |
| """Active number of codebooks used by the quantizer.""" |
| return self.quantizer.num_codebooks |
|
|
| def set_num_codebooks(self, n: int): |
| """Set the active number of codebooks used by the quantizer.""" |
| self.quantizer.set_num_codebooks(n) |
|
|
| @property |
| def cardinality(self): |
| """Cardinality of each codebook.""" |
| return self.quantizer.bins |
|
|
| def preprocess(self, x: torch.Tensor) -> tp.Tuple[torch.Tensor, tp.Optional[torch.Tensor]]: |
| scale: tp.Optional[torch.Tensor] |
| if self.renormalize: |
| mono = x.mean(dim=1, keepdim=True) |
| volume = mono.pow(2).mean(dim=2, keepdim=True).sqrt() |
| scale = 1e-8 + volume |
| x = x / scale |
| scale = scale.view(-1, 1) |
| else: |
| scale = None |
| return x, scale |
|
|
| def postprocess(self, |
| x: torch.Tensor, |
| scale: tp.Optional[torch.Tensor] = None) -> torch.Tensor: |
| if scale is not None: |
| assert self.renormalize |
| x = x * scale.view(-1, 1, 1) |
| return x |
|
|
| def forward(self, x: torch.Tensor) -> qt.QuantizedResult: |
| assert x.dim() == 3 |
| length = x.shape[-1] |
| x, scale = self.preprocess(x) |
|
|
| emb = self.encoder(x) |
| q_res = self.quantizer(emb, self.frame_rate) |
| out = self.decoder(q_res.x) |
|
|
| |
| assert out.shape[-1] >= length, (out.shape[-1], length) |
| out = out[..., :length] |
|
|
| q_res.x = self.postprocess(out, scale) |
|
|
| return q_res |
|
|
| def encode(self, x: torch.Tensor) -> tp.Tuple[torch.Tensor, tp.Optional[torch.Tensor]]: |
| """Encode the given input tensor to quantized representation along with scale parameter. |
| |
| Args: |
| x (torch.Tensor): Float tensor of shape [B, C, T] |
| |
| Returns: |
| codes, scale (tuple of torch.Tensor, torch.Tensor): Tuple composed of: |
| codes a float tensor of shape [B, K, T] with K the number of codebooks used and T the timestep. |
| scale a float tensor containing the scale for audio renormalizealization. |
| """ |
| assert x.dim() == 3 |
| x, scale = self.preprocess(x) |
| emb = self.encoder(x) |
| codes = self.quantizer.encode(emb) |
| return codes, scale |
|
|
| def decode(self, codes: torch.Tensor, scale: tp.Optional[torch.Tensor] = None): |
| """Decode the given codes to a reconstructed representation, using the scale to perform |
| audio denormalization if needed. |
| |
| Args: |
| codes (torch.Tensor): Int tensor of shape [B, K, T] |
| scale (torch.Tensor, optional): Float tensor containing the scale value. |
| |
| Returns: |
| out (torch.Tensor): Float tensor of shape [B, C, T], the reconstructed audio. |
| """ |
| emb = self.decode_latent(codes) |
| out = self.decoder(emb) |
| out = self.postprocess(out, scale) |
| |
| return out |
|
|
| def decode_latent(self, codes: torch.Tensor): |
| """Decode from the discrete codes to continuous latent space.""" |
| return self.quantizer.decode(codes) |
|
|
|
|
| class DAC(CompressionModel): |
| def __init__(self, model_type: str = "44khz"): |
| super().__init__() |
| try: |
| import dac.utils |
| except ImportError: |
| raise RuntimeError("Could not import dac, make sure it is installed, " |
| "please run `pip install descript-audio-codec`") |
| self.model = dac.utils.load_model(model_type=model_type) |
| self.n_quantizers = self.total_codebooks |
| self.model.eval() |
|
|
| def forward(self, x: torch.Tensor) -> qt.QuantizedResult: |
| |
| raise NotImplementedError("Forward and training with DAC not supported.") |
|
|
| def encode(self, x: torch.Tensor) -> tp.Tuple[torch.Tensor, tp.Optional[torch.Tensor]]: |
| codes = self.model.encode(x, self.n_quantizers)[1] |
| return codes[:, :self.n_quantizers], None |
|
|
| def decode(self, codes: torch.Tensor, scale: tp.Optional[torch.Tensor] = None): |
| assert scale is None |
| z_q = self.decode_latent(codes) |
| return self.model.decode(z_q) |
|
|
| def decode_latent(self, codes: torch.Tensor): |
| """Decode from the discrete codes to continuous latent space.""" |
| return self.model.quantizer.from_codes(codes)[0] |
|
|
| @property |
| def channels(self) -> int: |
| return 1 |
|
|
| @property |
| def frame_rate(self) -> float: |
| return self.model.sample_rate / self.model.hop_length |
|
|
| @property |
| def sample_rate(self) -> int: |
| return self.model.sample_rate |
|
|
| @property |
| def cardinality(self) -> int: |
| return self.model.codebook_size |
|
|
| @property |
| def num_codebooks(self) -> int: |
| return self.n_quantizers |
|
|
| @property |
| def total_codebooks(self) -> int: |
| return self.model.n_codebooks |
|
|
| def set_num_codebooks(self, n: int): |
| """Set the active number of codebooks used by the quantizer. |
| """ |
| assert n >= 1 |
| assert n <= self.total_codebooks |
| self.n_quantizers = n |
|
|
|
|
| class HFEncodecCompressionModel(CompressionModel): |
| """Wrapper around HuggingFace Encodec. |
| """ |
| def __init__(self, model: HFEncodecModel): |
| super().__init__() |
| self.model = model |
| bws = self.model.config.target_bandwidths |
| num_codebooks = [ |
| bw * 1000 / (self.frame_rate * math.log2(self.cardinality)) |
| for bw in bws |
| ] |
| deltas = [nc - int(nc) for nc in num_codebooks] |
| |
| assert all(deltas) <= 1e-3, deltas |
| self.possible_num_codebooks = [int(nc) for nc in num_codebooks] |
| self.set_num_codebooks(max(self.possible_num_codebooks)) |
|
|
| def forward(self, x: torch.Tensor) -> qt.QuantizedResult: |
| |
| raise NotImplementedError("Forward and training with HF EncodecModel not supported.") |
|
|
| def encode(self, x: torch.Tensor) -> tp.Tuple[torch.Tensor, tp.Optional[torch.Tensor]]: |
| bandwidth_index = self.possible_num_codebooks.index(self.num_codebooks) |
| bandwidth = self.model.config.target_bandwidths[bandwidth_index] |
| res = self.model.encode(x, None, bandwidth) |
| assert len(res[0]) == 1 |
| assert len(res[1]) == 1 |
| return res[0][0], res[1][0] |
|
|
| def decode(self, codes: torch.Tensor, scale: tp.Optional[torch.Tensor] = None): |
| if scale is None: |
| scales = [None] |
| else: |
| scales = scale |
| res = self.model.decode(codes[None], scales) |
| return res[0] |
|
|
| def decode_latent(self, codes: torch.Tensor): |
| """Decode from the discrete codes to continuous latent space.""" |
| return self.model.quantizer.decode(codes.transpose(0, 1)) |
|
|
| @property |
| def channels(self) -> int: |
| return self.model.config.audio_channels |
|
|
| @property |
| def frame_rate(self) -> float: |
| hop_length = int(np.prod(self.model.config.upsampling_ratios)) |
| return self.sample_rate / hop_length |
|
|
| @property |
| def sample_rate(self) -> int: |
| return self.model.config.sampling_rate |
|
|
| @property |
| def cardinality(self) -> int: |
| return self.model.config.codebook_size |
|
|
| @property |
| def num_codebooks(self) -> int: |
| return self._num_codebooks |
|
|
| @property |
| def total_codebooks(self) -> int: |
| return max(self.possible_num_codebooks) |
|
|
| def set_num_codebooks(self, n: int): |
| """Set the active number of codebooks used by the quantizer. |
| """ |
| if n not in self.possible_num_codebooks: |
| raise ValueError(f"Allowed values for num codebooks: {self.possible_num_codebooks}") |
| self._num_codebooks = n |
|
|
|
|
| class InterleaveStereoCompressionModel(CompressionModel): |
| """Wraps a CompressionModel to support stereo inputs. The wrapped model |
| will be applied independently to the left and right channels, and both codebooks |
| will be interleaved. If the wrapped model returns a representation `[B, K ,T]` per |
| channel, then the output will be `[B, K * 2, T]` or `[B, K, T * 2]` depending on |
| `per_timestep`. |
| |
| Args: |
| model (CompressionModel): Compression model to wrap. |
| per_timestep (bool): Whether to interleave on the timestep dimension |
| or on the codebooks dimension. |
| """ |
| def __init__(self, model: CompressionModel, per_timestep: bool = False): |
| super().__init__() |
| self.model = model |
| self.per_timestep = per_timestep |
| assert self.model.channels == 1, "Wrapped model is expected to be for monophonic audio" |
|
|
| @property |
| def total_codebooks(self): |
| return self.model.total_codebooks |
|
|
| @property |
| def num_codebooks(self): |
| """Active number of codebooks used by the quantizer. |
| |
| ..Warning:: this reports the number of codebooks after the interleaving |
| of the codebooks! |
| """ |
| return self.model.num_codebooks if self.per_timestep else self.model.num_codebooks * 2 |
|
|
| def set_num_codebooks(self, n: int): |
| """Set the active number of codebooks used by the quantizer. |
| |
| ..Warning:: this sets the number of codebooks before the interleaving! |
| """ |
| self.model.set_num_codebooks(n) |
|
|
| @property |
| def num_virtual_steps(self) -> float: |
| """Return the number of virtual steps, e.g. one real step |
| will be split into that many steps. |
| """ |
| return 2 if self.per_timestep else 1 |
|
|
| @property |
| def frame_rate(self) -> float: |
| return self.model.frame_rate * self.num_virtual_steps |
|
|
| @property |
| def sample_rate(self) -> int: |
| return self.model.sample_rate |
|
|
| @property |
| def channels(self) -> int: |
| return 2 |
|
|
| @property |
| def cardinality(self): |
| """Cardinality of each codebook. |
| """ |
| return self.model.cardinality |
|
|
| def forward(self, x: torch.Tensor) -> qt.QuantizedResult: |
| raise NotImplementedError("Not supported, use encode and decode.") |
|
|
| def encode(self, x: torch.Tensor) -> tp.Tuple[torch.Tensor, tp.Optional[torch.Tensor]]: |
| B, C, T = x.shape |
| assert C == self.channels, f"Expecting stereo audio but audio num channels is {C}" |
|
|
| indices_c0, scales_c0 = self.model.encode(x[:, 0, ...].unsqueeze(1)) |
| indices_c1, scales_c1 = self.model.encode(x[:, 1, ...].unsqueeze(1)) |
| indices = torch.stack([indices_c0, indices_c1], dim=0) |
| scales: tp.Optional[torch.Tensor] = None |
| if scales_c0 is not None and scales_c1 is not None: |
| scales = torch.stack([scales_c0, scales_c1], dim=1) |
|
|
| if self.per_timestep: |
| indices = rearrange(indices, 'c b k t -> b k (t c)', c=2) |
| else: |
| indices = rearrange(indices, 'c b k t -> b (k c) t', c=2) |
|
|
| return (indices, scales) |
|
|
| def get_left_right_codes(self, codes: torch.Tensor) -> tp.Tuple[torch.Tensor, torch.Tensor]: |
| if self.per_timestep: |
| codes = rearrange(codes, 'b k (t c) -> c b k t', c=2) |
| else: |
| codes = rearrange(codes, 'b (k c) t -> c b k t', c=2) |
| return codes[0], codes[1] |
|
|
| def decode(self, codes: torch.Tensor, scale: tp.Optional[torch.Tensor] = None): |
| B, K, T = codes.shape |
| assert T % self.num_virtual_steps == 0, "Provided codes' number of timesteps does not match" |
| assert K == self.num_codebooks, "Provided codes' number of codebooks does not match" |
|
|
| scale_c0, scale_c1 = None, None |
| if scale is not None: |
| assert scale.size(0) == B and scale.size(1) == 2, f"Scale has unexpected shape: {scale.shape}" |
| scale_c0 = scale[0, ...] |
| scale_c1 = scale[1, ...] |
|
|
| codes_c0, codes_c1 = self.get_left_right_codes(codes) |
| audio_c0 = self.model.decode(codes_c0, scale_c0) |
| audio_c1 = self.model.decode(codes_c1, scale_c1) |
| return torch.cat([audio_c0, audio_c1], dim=1) |
|
|
| def decode_latent(self, codes: torch.Tensor): |
| """Decode from the discrete codes to continuous latent space.""" |
| raise NotImplementedError("Not supported by interleaved stereo wrapped models.") |
|
|