Spaces:
Build error
Build error
| # Copyright (c) Meta Platforms, Inc. and affiliates. | |
| # All rights reserved. | |
| # | |
| # This source code is licensed under the license found in the | |
| # LICENSE file in the root directory of this source tree. | |
| """EnCodec model implementation.""" | |
| import math | |
| from pathlib import Path | |
| import typing as tp | |
| import numpy as np | |
| import torch | |
| from torch import nn | |
| from . import quantization as qt | |
| from . import modules as m | |
| from .utils import _check_checksum, _linear_overlap_add, _get_checkpoint_url | |
# Base URL of the pretrained checkpoints; it is combined with a checkpoint
# file name through `_get_checkpoint_url` before downloading.
ROOT_URL = 'https://dl.fbaipublicfiles.com/encodec/v0/'

# One encoded segment, as a `(codes, scale)` pair. `scale` is `None` when the
# model does not normalize the audio.
EncodedFrame = tp.Tuple[torch.Tensor, tp.Optional[torch.Tensor]]
class LMModel(nn.Module):
    """Language model estimating the probability of every codebook entry.

    All codebooks are predicted in parallel for a given time step.

    Args:
        n_q (int): number of codebooks.
        card (int): codebook cardinality.
        dim (int): transformer dimension.
        **kwargs: forwarded to `m.StreamingTransformerEncoder`.
    """

    def __init__(self, n_q: int = 32, card: int = 1024, dim: int = 200, **kwargs):
        super().__init__()
        self.card, self.n_q, self.dim = card, n_q, dim
        self.transformer = m.StreamingTransformerEncoder(dim=dim, **kwargs)

        # One embedding table per codebook; `card + 1` rows because index 0
        # is reserved for "no previous code" (see `forward`).
        self.emb = nn.ModuleList()
        for _ in range(n_q):
            self.emb.append(nn.Embedding(card + 1, dim))

        # One output projection per codebook, mapping back to `card` logits.
        self.linears = nn.ModuleList()
        for _ in range(n_q):
            self.linears.append(nn.Linear(dim, card))

    def forward(self, indices: torch.Tensor,
                states: tp.Optional[tp.List[torch.Tensor]] = None, offset: int = 0):
        """Estimate the codebook entry probabilities for each time step.

        Args:
            indices (torch.Tensor): indices from the previous time step, given
                as 1 + actual index in the codebook. The value 0 is reserved
                for a missing index (i.e. first time step). Shape should be
                `[B, n_q, T]`.
            states: state for the streaming decoding.
            offset: offset of the current time step.

        Returns:
            A 3-tuple `(probabilities, new_states, new_offset)` where
            `probabilities` has shape `[B, card, n_q, T]`.
        """
        _, num_codebooks, _ = indices.shape

        # Embed each codebook stream with its own table and add them up.
        embedded = [
            self.emb[k](codebook_indices)
            for k, codebook_indices in enumerate(indices.unbind(dim=1))
        ]
        transformer_in = sum(embedded)

        hidden, states, offset = self.transformer(transformer_in, states, offset)

        per_codebook_logits = []
        for k in range(num_codebooks):
            per_codebook_logits.append(self.linears[k](hidden))

        # Stacking on dim 1 gives [B, K, T, card]; move `card` to dim 1.
        logits = torch.stack(per_codebook_logits, dim=1)
        logits = logits.permute(0, 3, 1, 2)
        probabilities = torch.softmax(logits, dim=1)
        return probabilities, states, offset
class EncodecModel(nn.Module):
    """EnCodec model operating on the raw waveform.

    Args:
        encoder (nn.Module): Encoder network.
        decoder (nn.Module): Decoder network.
        quantizer (qt.ResidualVectorQuantizer): Quantizer turning the encoder
            output into discrete codes and back.
        target_bandwidths (list of float): Target bandwidths.
        sample_rate (int): Audio sample rate.
        channels (int): Number of audio channels.
        normalize (bool): Whether to apply audio normalization.
        segment (float or None): segment duration in sec. when doing overlap-add.
        overlap (float): overlap between segment, given as a fraction of the segment duration.
        name (str): name of the model, used as metadata when compressing audio.
    """

    def __init__(self,
                 encoder: m.SEANetEncoder,
                 decoder: m.SEANetDecoder,
                 quantizer: qt.ResidualVectorQuantizer,
                 target_bandwidths: tp.List[float],
                 sample_rate: int,
                 channels: int,
                 normalize: bool = False,
                 segment: tp.Optional[float] = None,
                 overlap: float = 0.01,
                 name: str = 'unset'):
        super().__init__()
        # No bandwidth is selected until `set_target_bandwidth` is called.
        self.bandwidth: tp.Optional[float] = None
        self.target_bandwidths = target_bandwidths
        self.encoder = encoder
        self.quantizer = quantizer
        self.decoder = decoder
        self.sample_rate = sample_rate
        self.channels = channels
        self.normalize = normalize
        self.segment = segment
        self.overlap = overlap
        # Number of latent frames per second: sample rate divided by the
        # product of the encoder ratios, rounded up.
        self.frame_rate = math.ceil(self.sample_rate / np.prod(self.encoder.ratios))
        self.name = name
        self.bits_per_codebook = int(math.log2(self.quantizer.bins))
        assert 2 ** self.bits_per_codebook == self.quantizer.bins, \
            "quantizer bins must be a power of 2."

    @property
    def segment_length(self) -> tp.Optional[int]:
        """Segment length in samples, or `None` when `segment` is unset."""
        if self.segment is None:
            return None
        return int(self.segment * self.sample_rate)

    @property
    def segment_stride(self) -> tp.Optional[int]:
        """Hop between consecutive segments in samples (at least 1).

        Returns `None` when `segment` is unset.
        """
        segment_length = self.segment_length
        if segment_length is None:
            return None
        return max(1, int((1 - self.overlap) * segment_length))

    def encode(self, x: torch.Tensor) -> tp.List[EncodedFrame]:
        """Given a tensor `x`, returns a list of frames containing
        the discrete encoded codes for `x`, along with rescaling factors
        for each segment, when `self.normalize` is True.

        Each frame is a tuple `(codebook, scale)`, with `codebook` of
        shape `[B, K, T]`, with `K` the number of codebooks.

        Args:
            x (torch.Tensor): 3-dim input whose last two dimensions are
                channels (1 or 2) and time.
        """
        assert x.dim() == 3
        _, channels, length = x.shape
        assert 0 < channels <= 2
        segment_length = self.segment_length
        if segment_length is None:
            # No segmenting: the whole input is encoded as a single frame.
            segment_length = length
            stride = length
        else:
            stride = self.segment_stride  # type: ignore
            assert stride is not None

        encoded_frames: tp.List[EncodedFrame] = []
        for offset in range(0, length, stride):
            frame = x[:, :, offset: offset + segment_length]
            encoded_frames.append(self._encode_frame(frame))
        return encoded_frames

    def _encode_frame(self, x: torch.Tensor) -> EncodedFrame:
        """Encode one segment into `(codes, scale)`.

        `scale` is `None` unless `self.normalize` is set, in which case it
        has shape `[B, 1]`.
        """
        length = x.shape[-1]
        duration = length / self.sample_rate
        # A frame must never be longer than the configured segment.
        assert self.segment is None or duration <= 1e-5 + self.segment

        if self.normalize:
            # Normalize by the RMS of the mono downmix; the epsilon avoids a
            # division by zero on silent input.
            mono = x.mean(dim=1, keepdim=True)
            volume = mono.pow(2).mean(dim=2, keepdim=True).sqrt()
            scale = 1e-8 + volume
            x = x / scale
            scale = scale.view(-1, 1)
        else:
            scale = None

        emb = self.encoder(x)
        codes = self.quantizer.encode(emb, self.frame_rate, self.bandwidth)
        codes = codes.transpose(0, 1)
        # codes is [B, K, T], with T frames, K nb of codebooks.
        return codes, scale

    def decode(self, encoded_frames: tp.List[EncodedFrame]) -> torch.Tensor:
        """Decode the given frames into a waveform.

        Note that the output might be a bit bigger than the input. In that case,
        any extra steps at the end can be trimmed.
        """
        segment_length = self.segment_length
        if segment_length is None:
            # Without segmenting, `encode` produces exactly one frame.
            assert len(encoded_frames) == 1
            return self._decode_frame(encoded_frames[0])

        frames = [self._decode_frame(frame) for frame in encoded_frames]
        return _linear_overlap_add(frames, self.segment_stride or 1)

    def _decode_frame(self, encoded_frame: EncodedFrame) -> torch.Tensor:
        """Decode one `(codes, scale)` pair, undoing normalization if any."""
        codes, scale = encoded_frame
        # Undo the transpose applied at the end of `_encode_frame`.
        codes = codes.transpose(0, 1)
        emb = self.quantizer.decode(codes)
        out = self.decoder(emb)
        if scale is not None:
            out = out * scale.view(-1, 1, 1)
        return out

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Encode then decode `x`, trimming the output to the input length."""
        frames = self.encode(x)
        return self.decode(frames)[:, :, :x.shape[-1]]

    def set_target_bandwidth(self, bandwidth: float):
        """Select the bandwidth used by subsequent calls to `encode`.

        Raises:
            ValueError: if `bandwidth` is not one of `self.target_bandwidths`.
        """
        if bandwidth not in self.target_bandwidths:
            raise ValueError(f"This model doesn't support the bandwidth {bandwidth}. "
                             f"Select one of {self.target_bandwidths}.")
        self.bandwidth = bandwidth

    def get_lm_model(self) -> LMModel:
        """Return the associated LM model to improve the compression rate.

        The LM is created on the same device as this model and put in eval
        mode after loading its pretrained weights.

        Raises:
            RuntimeError: if no pretrained LM is listed for `self.name`.
        """
        device = next(self.parameters()).device
        lm = LMModel(self.quantizer.n_q, self.quantizer.bins, num_layers=5, dim=200,
                     past_context=int(3.5 * self.frame_rate)).to(device)
        checkpoints = {
            'encodec_24khz': 'encodec_lm_24khz-1608e3c0.th',
            'encodec_48khz': 'encodec_lm_48khz-7add9fc3.th',
        }
        try:
            checkpoint_name = checkpoints[self.name]
        except KeyError:
            # The KeyError itself carries no useful information for the caller.
            raise RuntimeError("No LM pre-trained for the current Encodec model.") from None
        url = _get_checkpoint_url(ROOT_URL, checkpoint_name)
        state = torch.hub.load_state_dict_from_url(
            url, map_location='cpu', check_hash=True)  # type: ignore
        lm.load_state_dict(state)
        lm.eval()
        return lm

    @staticmethod
    def _get_model(target_bandwidths: tp.List[float],
                   sample_rate: int = 24_000,
                   channels: int = 1,
                   causal: bool = True,
                   model_norm: str = 'weight_norm',
                   audio_normalize: bool = False,
                   segment: tp.Optional[float] = None,
                   name: str = 'unset'):
        """Build an untrained `EncodecModel` with SEANet encoder/decoder."""
        encoder = m.SEANetEncoder(channels=channels, norm=model_norm, causal=causal)
        decoder = m.SEANetDecoder(channels=channels, norm=model_norm, causal=causal)
        # Number of codebooks sized from the last entry of `target_bandwidths`
        # (assumed to be the largest, in kbps -- TODO confirm). The constant 10
        # presumably is the bits per codebook for `bins=1024`.
        n_q = int(1000 * target_bandwidths[-1] // (math.ceil(sample_rate / encoder.hop_length) * 10))
        quantizer = qt.ResidualVectorQuantizer(
            dimension=encoder.dimension,
            n_q=n_q,
            bins=1024,
        )
        model = EncodecModel(
            encoder,
            decoder,
            quantizer,
            target_bandwidths,
            sample_rate,
            channels,
            normalize=audio_normalize,
            segment=segment,
            name=name,
        )
        return model

    @staticmethod
    def _get_pretrained(checkpoint_name: str, repository: tp.Optional[Path] = None):
        """Load a pretrained state dict from `repository` or from `ROOT_URL`.

        Raises:
            ValueError: if `repository` is given but is not a directory.
        """
        if repository is not None:
            if not repository.is_dir():
                raise ValueError(f"{repository} must exist and be a directory.")
            file = repository / checkpoint_name
            # Checkpoint names look like `<name>-<checksum>.th`.
            checksum = file.stem.split('-')[1]
            _check_checksum(file, checksum)
            # Load on CPU, like the download branch below, so a checkpoint
            # saved from a GPU can be read on a CPU-only machine.
            return torch.load(file, map_location='cpu')
        else:
            url = _get_checkpoint_url(ROOT_URL, checkpoint_name)
            return torch.hub.load_state_dict_from_url(url, map_location='cpu', check_hash=True)  # type:ignore

    @staticmethod
    def encodec_model_24khz(pretrained: bool = True, repository: tp.Optional[Path] = None):
        """Return the pretrained causal 24khz model.

        With `pretrained=False` the model is returned untrained; `repository`
        may only be given together with `pretrained=True`.
        """
        if repository:
            assert pretrained
        target_bandwidths = [1.5, 3., 6, 12., 24.]
        checkpoint_name = 'encodec_24khz-d7cc33bc.th'
        sample_rate = 24_000
        channels = 1
        model = EncodecModel._get_model(
            target_bandwidths, sample_rate, channels,
            causal=True, model_norm='weight_norm', audio_normalize=False,
            name='encodec_24khz' if pretrained else 'unset')
        if pretrained:
            state_dict = EncodecModel._get_pretrained(checkpoint_name, repository)
            model.load_state_dict(state_dict)
        model.eval()
        return model

    @staticmethod
    def encodec_model_48khz(pretrained: bool = True, repository: tp.Optional[Path] = None):
        """Return the pretrained 48khz model.

        With `pretrained=False` the model is returned untrained; `repository`
        may only be given together with `pretrained=True`.
        """
        if repository:
            assert pretrained
        target_bandwidths = [3., 6., 12., 24.]
        checkpoint_name = 'encodec_48khz-7e698e3e.th'
        sample_rate = 48_000
        channels = 2
        model = EncodecModel._get_model(
            target_bandwidths, sample_rate, channels,
            causal=False, model_norm='time_group_norm', audio_normalize=True,
            segment=1., name='encodec_48khz' if pretrained else 'unset')
        if pretrained:
            state_dict = EncodecModel._get_pretrained(checkpoint_name, repository)
            model.load_state_dict(state_dict)
        model.eval()
        return model
def test():
    """Smoke-test both pretrained models at every listed bandwidth.

    For each model, the first two seconds of `test_24k.wav` / `test_48k.wav`
    (read from the current directory) are run through the model, and the
    decoded output must have the same shape as the input.
    """
    import torchaudio

    bandwidths = [3, 6, 12, 24]
    models = {
        'encodec_24khz': EncodecModel.encodec_model_24khz,
        'encodec_48khz': EncodecModel.encodec_model_48khz
    }
    for model_name, build_model in models.items():
        # Building the model and reading the audio depend only on the model,
        # so do both once instead of once per bandwidth.
        model = build_model()
        audio_suffix = model_name.split('_')[1][:3]
        wav, _ = torchaudio.load(f"test_{audio_suffix}.wav")
        wav = wav[:, :model.sample_rate * 2]
        wav_in = wav.unsqueeze(0)
        for bw in bandwidths:
            model.set_target_bandwidth(bw)
            # Pure inference: no need to track gradients.
            with torch.no_grad():
                wav_dec = model(wav_in)[0]
            assert wav.shape == wav_dec.shape, (wav.shape, wav_dec.shape)


if __name__ == '__main__':
    test()