| |
| |
| |
| |
| |
|
|
| from abc import ABC, abstractmethod |
| import typing as tp |
|
|
| from einops import rearrange |
| import torch |
| from torch import nn |
|
|
| from .. import quantization as qt |
|
|
|
|
| class CompressionModel(ABC, nn.Module): |
|
|
| @abstractmethod |
| def forward(self, x: torch.Tensor) -> qt.QuantizedResult: |
| ... |
|
|
| @abstractmethod |
| def encode(self, x: torch.Tensor) -> tp.Tuple[torch.Tensor, tp.Optional[torch.Tensor]]: |
| """See `EncodecModel.encode`""" |
| ... |
|
|
| @abstractmethod |
| def decode(self, codes: torch.Tensor, scale: tp.Optional[torch.Tensor] = None): |
| """See `EncodecModel.decode`""" |
| ... |
|
|
| @property |
| @abstractmethod |
| def channels(self) -> int: |
| ... |
|
|
| @property |
| @abstractmethod |
| def frame_rate(self) -> int: |
| ... |
|
|
| @property |
| @abstractmethod |
| def sample_rate(self) -> int: |
| ... |
|
|
| @property |
| @abstractmethod |
| def cardinality(self) -> int: |
| ... |
|
|
| @property |
| @abstractmethod |
| def num_codebooks(self) -> int: |
| ... |
|
|
| @property |
| @abstractmethod |
| def total_codebooks(self) -> int: |
| ... |
|
|
| @abstractmethod |
| def set_num_codebooks(self, n: int): |
| """Set the active number of codebooks used by the quantizer. |
| """ |
| ... |
|
|
|
|
| class EncodecModel(CompressionModel): |
| """Encodec model operating on the raw waveform. |
| |
| Args: |
| encoder (nn.Module): Encoder network. |
| decoder (nn.Module): Decoder network. |
| quantizer (qt.BaseQuantizer): Quantizer network. |
| frame_rate (int): Frame rate for the latent representation. |
| sample_rate (int): Audio sample rate. |
| channels (int): Number of audio channels. |
| causal (bool): Whether to use a causal version of the model. |
| renormalize (bool): Whether to renormalize the audio before running the model. |
| """ |
| |
| |
| frame_rate: int = 0 |
| sample_rate: int = 0 |
| channels: int = 0 |
|
|
| def __init__(self, |
| encoder: nn.Module, |
| decoder: nn.Module, |
| quantizer: qt.BaseQuantizer, |
| frame_rate: int, |
| sample_rate: int, |
| channels: int, |
| causal: bool = False, |
| renormalize: bool = False): |
| super().__init__() |
| self.encoder = encoder |
| self.decoder = decoder |
| self.quantizer = quantizer |
| self.frame_rate = frame_rate |
| self.sample_rate = sample_rate |
| self.channels = channels |
| self.renormalize = renormalize |
| self.causal = causal |
| if self.causal: |
| |
| |
| assert not self.renormalize, 'Causal model does not support renormalize' |
|
|
| @property |
| def total_codebooks(self): |
| """Total number of quantizer codebooks available. |
| """ |
| return self.quantizer.total_codebooks |
|
|
| @property |
| def num_codebooks(self): |
| """Active number of codebooks used by the quantizer. |
| """ |
| return self.quantizer.num_codebooks |
|
|
| def set_num_codebooks(self, n: int): |
| """Set the active number of codebooks used by the quantizer. |
| """ |
| self.quantizer.set_num_codebooks(n) |
|
|
| @property |
| def cardinality(self): |
| """Cardinality of each codebook. |
| """ |
| return self.quantizer.bins |
|
|
| def preprocess(self, x: torch.Tensor) -> tp.Tuple[torch.Tensor, tp.Optional[torch.Tensor]]: |
| scale: tp.Optional[torch.Tensor] |
| if self.renormalize: |
| mono = x.mean(dim=1, keepdim=True) |
| volume = mono.pow(2).mean(dim=2, keepdim=True).sqrt() |
| scale = 1e-8 + volume |
| x = x / scale |
| scale = scale.view(-1, 1) |
| else: |
| scale = None |
| return x, scale |
|
|
| def postprocess(self, |
| x: torch.Tensor, |
| scale: tp.Optional[torch.Tensor] = None) -> torch.Tensor: |
| if scale is not None: |
| assert self.renormalize |
| x = x * scale.view(-1, 1, 1) |
| return x |
|
|
| def forward(self, x: torch.Tensor) -> qt.QuantizedResult: |
| assert x.dim() == 3 |
| length = x.shape[-1] |
| x, scale = self.preprocess(x) |
|
|
| emb = self.encoder(x) |
| q_res = self.quantizer(emb, self.frame_rate) |
| out = self.decoder(q_res.x) |
|
|
| |
| assert out.shape[-1] >= length, (out.shape[-1], length) |
| out = out[..., :length] |
|
|
| q_res.x = self.postprocess(out, scale) |
|
|
| return q_res |
|
|
| def encode(self, x: torch.Tensor) -> tp.Tuple[torch.Tensor, tp.Optional[torch.Tensor]]: |
| """Encode the given input tensor to quantized representation along with scale parameter. |
| |
| Args: |
| x (torch.Tensor): Float tensor of shape [B, C, T] |
| |
| Returns: |
| codes, scale (tp.Tuple[torch.Tensor, torch.Tensor]): Tuple composed of: |
| codes a float tensor of shape [B, K, T] with K the number of codebooks used and T the timestep. |
| scale a float tensor containing the scale for audio renormalizealization. |
| """ |
| assert x.dim() == 3 |
| x, scale = self.preprocess(x) |
| emb = self.encoder(x) |
| codes = self.quantizer.encode(emb) |
| return codes, scale |
|
|
| def decode(self, codes: torch.Tensor, scale: tp.Optional[torch.Tensor] = None): |
| """Decode the given codes to a reconstructed representation, using the scale to perform |
| audio denormalization if needed. |
| |
| Args: |
| codes (torch.Tensor): Int tensor of shape [B, K, T] |
| scale (tp.Optional[torch.Tensor]): Float tensor containing the scale value. |
| |
| Returns: |
| out (torch.Tensor): Float tensor of shape [B, C, T], the reconstructed audio. |
| """ |
| emb = self.quantizer.decode(codes) |
| out = self.decoder(emb) |
| out = self.postprocess(out, scale) |
| |
| return out |
|
|
|
|
| class FlattenedCompressionModel(CompressionModel): |
| """Wraps a CompressionModel and flatten its codebooks, e.g. |
| instead of returning [B, K, T], return [B, S, T * (K // S)] with |
| S the number of codebooks per step, and `K // S` the number of 'virtual steps' |
| for each real time step. |
| |
| Args: |
| model (CompressionModel): compression model to wrap. |
| codebooks_per_step (int): number of codebooks to keep per step, |
| this must divide the number of codebooks provided by the wrapped model. |
| extend_cardinality (bool): if True, and for instance if codebooks_per_step = 1, |
| if each codebook has a cardinality N, then the first codebook will |
| use the range [0, N - 1], and the second [N, 2 N - 1] etc. |
| On decoding, this can lead to potentially invalid sequences. |
| Any invalid entry will be silently remapped to the proper range |
| with a modulo. |
| """ |
| def __init__(self, model: CompressionModel, codebooks_per_step: int = 1, |
| extend_cardinality: bool = True): |
| super().__init__() |
| self.model = model |
| self.codebooks_per_step = codebooks_per_step |
| self.extend_cardinality = extend_cardinality |
|
|
| @property |
| def total_codebooks(self): |
| return self.model.total_codebooks |
|
|
| @property |
| def num_codebooks(self): |
| """Active number of codebooks used by the quantizer. |
| |
| ..Warning:: this reports the number of codebooks after the flattening |
| of the codebooks! |
| """ |
| assert self.model.num_codebooks % self.codebooks_per_step == 0 |
| return self.codebooks_per_step |
|
|
| def set_num_codebooks(self, n: int): |
| """Set the active number of codebooks used by the quantizer. |
| |
| ..Warning:: this sets the number of codebooks **before** the flattening |
| of the codebooks. |
| """ |
| assert n % self.codebooks_per_step == 0 |
| self.model.set_num_codebooks(n) |
|
|
| @property |
| def num_virtual_steps(self) -> int: |
| """Return the number of virtual steps, e.g. one real step |
| will be split into that many steps. |
| """ |
| return self.model.num_codebooks // self.codebooks_per_step |
|
|
| @property |
| def frame_rate(self) -> int: |
| return self.model.frame_rate * self.num_virtual_steps |
|
|
| @property |
| def sample_rate(self) -> int: |
| return self.model.sample_rate |
|
|
| @property |
| def channels(self) -> int: |
| return self.model.channels |
|
|
| @property |
| def cardinality(self): |
| """Cardinality of each codebook. |
| """ |
| if self.extend_cardinality: |
| return self.model.cardinality * self.num_virtual_steps |
| else: |
| return self.model.cardinality |
|
|
| def forward(self, x: torch.Tensor) -> qt.QuantizedResult: |
| raise NotImplementedError("Not supported, use encode and decode.") |
|
|
| def encode(self, x: torch.Tensor) -> tp.Tuple[torch.Tensor, tp.Optional[torch.Tensor]]: |
| indices, scales = self.model.encode(x) |
| B, K, T = indices.shape |
| indices = rearrange(indices, 'b (k v) t -> b k t v', k=self.codebooks_per_step) |
| if self.extend_cardinality: |
| for virtual_step in range(1, self.num_virtual_steps): |
| indices[..., virtual_step] += self.model.cardinality * virtual_step |
| indices = rearrange(indices, 'b k t v -> b k (t v)') |
| return (indices, scales) |
|
|
| def decode(self, codes: torch.Tensor, scale: tp.Optional[torch.Tensor] = None): |
| B, K, T = codes.shape |
| assert T % self.num_virtual_steps == 0 |
| codes = rearrange(codes, 'b k (t v) -> b (k v) t', v=self.num_virtual_steps) |
| |
| |
| codes = codes % self.model.cardinality |
| return self.model.decode(codes, scale) |
|
|