diff --git a/encoder/__init__.py b/encoder/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..ff8fd2ada59e0e15d4df2854052edf150e5238e3
--- /dev/null
+++ b/encoder/__init__.py
@@ -0,0 +1,12 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
+# flake8: noqa
+
+"""EnCodec neural audio codec."""
+
+__version__ = "0.1.2a3"
+
+from .model import EncodecModel
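The package exposes `EncodecModel` as its public entry point. As a quick orientation for reviewers, here is a minimal round-trip sketch using only APIs added in this diff (`encodec_model_24khz`, `set_target_bandwidth`, `encode`, `decode`); the 2-second random input is an arbitrary placeholder:

import torch
from encoder import EncodecModel

model = EncodecModel.encodec_model_24khz()  # downloads pretrained weights
model.set_target_bandwidth(6.0)             # must be one of model.target_bandwidths

wav = torch.randn(1, 1, model.sample_rate * 2)   # [B, channels, T], mono at 24 kHz
frames = model.encode(wav)                       # list of (codes [B, K, T'], scale)
reconstruction = model.decode(frames)[:, :, :wav.shape[-1]]  # decode may overshoot; trim
assert reconstruction.shape == wav.shape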
diff --git a/encoder/distrib.py b/encoder/distrib.py
new file mode 100644
index 0000000000000000000000000000000000000000..b1662d8085cf2878c4cd058537d0f097de91d158
--- /dev/null
+++ b/encoder/distrib.py
@@ -0,0 +1,124 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
+
+"""Torch distributed utilities."""
+
+import typing as tp
+
+import torch
+
+
+def rank():
+    if torch.distributed.is_initialized():
+        return torch.distributed.get_rank()
+    else:
+        return 0
+
+
+def world_size():
+    if torch.distributed.is_initialized():
+        return torch.distributed.get_world_size()
+    else:
+        return 1
+
+
+def is_distributed():
+    return world_size() > 1
+
+
+def all_reduce(tensor: torch.Tensor, op=torch.distributed.ReduceOp.SUM):
+    if is_distributed():
+        return torch.distributed.all_reduce(tensor, op)
+
+
+def _is_complex_or_float(tensor):
+    return torch.is_floating_point(tensor) or torch.is_complex(tensor)
+
+
+def _check_number_of_params(params: tp.List[torch.Tensor]):
+    # Utility function to check that the number of params in all workers is the same,
+    # and thus avoid a deadlock with distributed all reduce.
+    if not is_distributed() or not params:
+        return
+    tensor = torch.tensor([len(params)], device=params[0].device, dtype=torch.long)
+    all_reduce(tensor)
+    if tensor.item() != len(params) * world_size():
+        # If the workers disagree on the number of params, this inequality will hold
+        # for at least one of them.
+        raise RuntimeError(f"Mismatch in number of params: ours is {len(params)}, "
+                           "at least one worker has a different one.")
+
+
+def broadcast_tensors(tensors: tp.Iterable[torch.Tensor], src: int = 0):
+    """Broadcast the tensors from the given parameters to all workers.
+    This can be used to ensure that all workers have the same model to start with.
+    """
+    if not is_distributed():
+        return
+    tensors = [tensor for tensor in tensors if _is_complex_or_float(tensor)]
+    _check_number_of_params(tensors)
+    handles = []
+    for tensor in tensors:
+        handle = torch.distributed.broadcast(tensor.data, src=src, async_op=True)
+        handles.append(handle)
+    for handle in handles:
+        handle.wait()
+
+
+def sync_buffer(buffers, average=True):
+    """
+    Sync buffers across workers. If `average` is False, broadcast from rank 0
+    instead of averaging.
+    """
+    if not is_distributed():
+        return
+    handles = []
+    for buffer in buffers:
+        if torch.is_floating_point(buffer.data):
+            if average:
+                handle = torch.distributed.all_reduce(
+                    buffer.data, op=torch.distributed.ReduceOp.SUM, async_op=True)
+            else:
+                handle = torch.distributed.broadcast(
+                    buffer.data, src=0, async_op=True)
+            handles.append((buffer, handle))
+    for buffer, handle in handles:
+        handle.wait()
+        if average:
+            buffer.data /= world_size()
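These helpers replace DistributedDataParallel with explicit collectives. A minimal training-step sketch of the intended call order (broadcast once at startup, then `sync_grad`, defined just below, after each backward); the toy model and optimizer are placeholders:

import torch
from encoder import distrib

model = torch.nn.Linear(8, 8)
opt = torch.optim.Adam(model.parameters())

distrib.broadcast_tensors(model.parameters())    # all workers start from rank 0's weights
for batch in [torch.randn(4, 8) for _ in range(3)]:
    loss = model(batch).pow(2).mean()
    loss.backward()
    distrib.sync_grad(model.parameters())        # all-reduce, then average the gradients
    opt.step()
    opt.zero_grad()

In a single process every call is a no-op, so the same loop runs with or without `torch.distributed` initialized.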
+ """ + if not is_distributed(): + return + handles = [] + for p in params: + if p.grad is not None: + handle = torch.distributed.all_reduce( + p.grad.data, op=torch.distributed.ReduceOp.SUM, async_op=True) + handles.append((p, handle)) + for p, handle in handles: + handle.wait() + p.grad.data /= world_size() + + +def average_metrics(metrics: tp.Dict[str, float], count=1.): + """Average a dictionary of metrics across all workers, using the optional + `count` as unnormalized weight. + """ + if not is_distributed(): + return metrics + keys, values = zip(*metrics.items()) + device = 'cuda' if torch.cuda.is_available() else 'cpu' + tensor = torch.tensor(list(values) + [1], device=device, dtype=torch.float32) + tensor *= count + all_reduce(tensor) + averaged = (tensor[:-1] / tensor[-1]).cpu().tolist() + return dict(zip(keys, averaged)) diff --git a/encoder/model.py b/encoder/model.py new file mode 100644 index 0000000000000000000000000000000000000000..33be28de408112b0f54f062df43ac13953e170ea --- /dev/null +++ b/encoder/model.py @@ -0,0 +1,324 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + +"""EnCodec model implementation.""" + +import math +from pathlib import Path +import typing as tp + +import numpy as np +import torch +from torch import nn + +from . import quantization as qt +from . import modules as m +from .utils import _check_checksum, _linear_overlap_add, _get_checkpoint_url + + +ROOT_URL = 'https://dl.fbaipublicfiles.com/encodec/v0/' + +EncodedFrame = tp.Tuple[torch.Tensor, tp.Optional[torch.Tensor]] + + +class LMModel(nn.Module): + """Language Model to estimate probabilities of each codebook entry. + We predict all codebooks in parallel for a given time step. + + Args: + n_q (int): number of codebooks. + card (int): codebook cardinality. + dim (int): transformer dimension. + **kwargs: passed to `encoder.modules.transformer.StreamingTransformerEncoder`. + """ + def __init__(self, n_q: int = 32, card: int = 1024, dim: int = 200, **kwargs): + super().__init__() + self.card = card + self.n_q = n_q + self.dim = dim + self.transformer = m.StreamingTransformerEncoder(dim=dim, **kwargs) + self.emb = nn.ModuleList([nn.Embedding(card + 1, dim) for _ in range(n_q)]) + self.linears = nn.ModuleList([nn.Linear(dim, card) for _ in range(n_q)]) + + def forward(self, indices: torch.Tensor, + states: tp.Optional[tp.List[torch.Tensor]] = None, offset: int = 0): + """ + Args: + indices (torch.Tensor): indices from the previous time step. Indices + should be 1 + actual index in the codebook. The value 0 is reserved for + when the index is missing (i.e. first time step). Shape should be + `[B, n_q, T]`. + states: state for the streaming decoding. + offset: offset of the current time step. + + Returns a 3-tuple `(probabilities, new_states, new_offset)` with probabilities + with a shape `[B, card, n_q, T]`. + + """ + B, K, T = indices.shape + input_ = sum([self.emb[k](indices[:, k]) for k in range(K)]) + out, states, offset = self.transformer(input_, states, offset) + logits = torch.stack([self.linears[k](out) for k in range(K)], dim=1).permute(0, 3, 1, 2) + return torch.softmax(logits, dim=1), states, offset + + +class EncodecModel(nn.Module): + """EnCodec model operating on the raw waveform. + Args: + target_bandwidths (list of float): Target bandwidths. + encoder (nn.Module): Encoder network. + decoder (nn.Module): Decoder network. 
+        sample_rate (int): Audio sample rate.
+        channels (int): Number of audio channels.
+        normalize (bool): Whether to apply audio normalization.
+        segment (float or None): segment duration in sec. when doing overlap-add.
+        overlap (float): overlap between segments, given as a fraction of the segment duration.
+        name (str): name of the model, used as metadata when compressing audio.
+    """
+    def __init__(self,
+                 encoder: m.SEANetEncoder,
+                 decoder: m.SEANetDecoder,
+                 quantizer: qt.ResidualVectorQuantizer,
+                 target_bandwidths: tp.List[float],
+                 sample_rate: int,
+                 channels: int,
+                 normalize: bool = False,
+                 segment: tp.Optional[float] = None,
+                 overlap: float = 0.01,
+                 name: str = 'unset'):
+        super().__init__()
+        self.bandwidth: tp.Optional[float] = None
+        self.target_bandwidths = target_bandwidths
+        self.encoder = encoder
+        self.quantizer = quantizer
+        self.decoder = decoder
+        self.sample_rate = sample_rate
+        self.channels = channels
+        self.normalize = normalize
+        self.segment = segment
+        self.overlap = overlap
+        self.frame_rate = math.ceil(self.sample_rate / np.prod(self.encoder.ratios))
+        self.name = name
+        self.bits_per_codebook = int(math.log2(self.quantizer.bins))
+        assert 2 ** self.bits_per_codebook == self.quantizer.bins, \
+            "quantizer bins must be a power of 2."
+
+    @property
+    def segment_length(self) -> tp.Optional[int]:
+        if self.segment is None:
+            return None
+        return int(self.segment * self.sample_rate)
+
+    @property
+    def segment_stride(self) -> tp.Optional[int]:
+        segment_length = self.segment_length
+        if segment_length is None:
+            return None
+        return max(1, int((1 - self.overlap) * segment_length))
+
+    def encode(self, x: torch.Tensor) -> tp.List[EncodedFrame]:
+        """Given a tensor `x`, returns a list of frames containing
+        the discrete encoded codes for `x`, along with rescaling factors
+        for each segment, when `self.normalize` is True.
+
+        Each frame is a tuple `(codebook, scale)`, with `codebook` of
+        shape `[B, K, T]`, with `K` the number of codebooks.
+        """
+        assert x.dim() == 3
+        _, channels, length = x.shape
+        assert channels > 0 and channels <= 2
+        segment_length = self.segment_length
+        if segment_length is None:
+            segment_length = length
+            stride = length
+        else:
+            stride = self.segment_stride  # type: ignore
+            assert stride is not None
+
+        encoded_frames: tp.List[EncodedFrame] = []
+        for offset in range(0, length, stride):
+            frame = x[:, :, offset: offset + segment_length]
+            encoded_frames.append(self._encode_frame(frame))
+        return encoded_frames
+
+    def _encode_frame(self, x: torch.Tensor) -> EncodedFrame:
+        length = x.shape[-1]
+        duration = length / self.sample_rate
+        assert self.segment is None or duration <= 1e-5 + self.segment
+
+        if self.normalize:
+            mono = x.mean(dim=1, keepdim=True)
+            volume = mono.pow(2).mean(dim=2, keepdim=True).sqrt()
+            scale = 1e-8 + volume
+            x = x / scale
+            scale = scale.view(-1, 1)
+        else:
+            scale = None
+
+        emb = self.encoder(x)
+        codes = self.quantizer.encode(emb, self.frame_rate, self.bandwidth)
+        codes = codes.transpose(0, 1)
+        # codes is [B, K, T], with T frames, K nb of codebooks.
+        return codes, scale
+
+    def decode(self, encoded_frames: tp.List[EncodedFrame]) -> torch.Tensor:
+        """Decode the given frames into a waveform.
+        Note that the output might be a bit bigger than the input. In that case,
+        any extra steps at the end can be trimmed.
+ """ + segment_length = self.segment_length + if segment_length is None: + assert len(encoded_frames) == 1 + return self._decode_frame(encoded_frames[0]) + + frames = [self._decode_frame(frame) for frame in encoded_frames] + return _linear_overlap_add(frames, self.segment_stride or 1) + + def _decode_frame(self, encoded_frame: EncodedFrame) -> torch.Tensor: + codes, scale = encoded_frame + codes = codes.transpose(0, 1) + emb = self.quantizer.decode(codes) + out = self.decoder(emb) + if scale is not None: + out = out * scale.view(-1, 1, 1) + return out + + def forward(self, x: torch.Tensor) -> torch.Tensor: + frames = self.encode(x) + return self.decode(frames)[:, :, :x.shape[-1]] + + def set_target_bandwidth(self, bandwidth: float): + if bandwidth not in self.target_bandwidths: + raise ValueError(f"This model doesn't support the bandwidth {bandwidth}. " + f"Select one of {self.target_bandwidths}.") + self.bandwidth = bandwidth + + def get_lm_model(self) -> LMModel: + """Return the associated LM model to improve the compression rate. + """ + device = next(self.parameters()).device + lm = LMModel(self.quantizer.n_q, self.quantizer.bins, num_layers=5, dim=200, + past_context=int(3.5 * self.frame_rate)).to(device) + checkpoints = { + 'encodec_24khz': 'encodec_lm_24khz-1608e3c0.th', + 'encodec_48khz': 'encodec_lm_48khz-7add9fc3.th', + } + try: + checkpoint_name = checkpoints[self.name] + except KeyError: + raise RuntimeError("No LM pre-trained for the current Encodec model.") + url = _get_checkpoint_url(ROOT_URL, checkpoint_name) + state = torch.hub.load_state_dict_from_url( + url, map_location='cpu', check_hash=True) # type: ignore + lm.load_state_dict(state) + lm.eval() + return lm + + @staticmethod + def _get_model(target_bandwidths: tp.List[float], + sample_rate: int = 24_000, + channels: int = 1, + causal: bool = True, + model_norm: str = 'weight_norm', + audio_normalize: bool = False, + segment: tp.Optional[float] = None, + name: str = 'unset'): + encoder = m.SEANetEncoder(channels=channels, norm=model_norm, causal=causal) + decoder = m.SEANetDecoder(channels=channels, norm=model_norm, causal=causal) + n_q = int(1000 * target_bandwidths[-1] // (math.ceil(sample_rate / encoder.hop_length) * 10)) + quantizer = qt.ResidualVectorQuantizer( + dimension=encoder.dimension, + n_q=n_q, + bins=1024, + ) + model = EncodecModel( + encoder, + decoder, + quantizer, + target_bandwidths, + sample_rate, + channels, + normalize=audio_normalize, + segment=segment, + name=name, + ) + return model + + @staticmethod + def _get_pretrained(checkpoint_name: str, repository: tp.Optional[Path] = None): + if repository is not None: + if not repository.is_dir(): + raise ValueError(f"{repository} must exist and be a directory.") + file = repository / checkpoint_name + checksum = file.stem.split('-')[1] + _check_checksum(file, checksum) + return torch.load(file) + else: + url = _get_checkpoint_url(ROOT_URL, checkpoint_name) + return torch.hub.load_state_dict_from_url(url, map_location='cpu', check_hash=True) # type:ignore + + @staticmethod + def encodec_model_24khz(pretrained: bool = True, repository: tp.Optional[Path] = None): + """Return the pretrained causal 24khz model. + """ + if repository: + assert pretrained + target_bandwidths = [1.5, 3., 6, 12., 24.] 
+        checkpoint_name = 'encodec_24khz-d7cc33bc.th'
+        sample_rate = 24_000
+        channels = 1
+        model = EncodecModel._get_model(
+            target_bandwidths, sample_rate, channels,
+            causal=True, model_norm='weight_norm', audio_normalize=False,
+            name='encodec_24khz' if pretrained else 'unset')
+        if pretrained:
+            state_dict = EncodecModel._get_pretrained(checkpoint_name, repository)
+            model.load_state_dict(state_dict)
+        model.eval()
+        return model
+
+    @staticmethod
+    def encodec_model_48khz(pretrained: bool = True, repository: tp.Optional[Path] = None):
+        """Return the pretrained 48khz model.
+        """
+        if repository:
+            assert pretrained
+        target_bandwidths = [3., 6., 12., 24.]
+        checkpoint_name = 'encodec_48khz-7e698e3e.th'
+        sample_rate = 48_000
+        channels = 2
+        model = EncodecModel._get_model(
+            target_bandwidths, sample_rate, channels,
+            causal=False, model_norm='time_group_norm', audio_normalize=True,
+            segment=1., name='encodec_48khz' if pretrained else 'unset')
+        if pretrained:
+            state_dict = EncodecModel._get_pretrained(checkpoint_name, repository)
+            model.load_state_dict(state_dict)
+        model.eval()
+        return model
+
+
+def test():
+    from itertools import product
+    import torchaudio
+    bandwidths = [3, 6, 12, 24]
+    models = {
+        'encodec_24khz': EncodecModel.encodec_model_24khz,
+        'encodec_48khz': EncodecModel.encodec_model_48khz
+    }
+    for model_name, bw in product(models.keys(), bandwidths):
+        model = models[model_name]()
+        model.set_target_bandwidth(bw)
+        audio_suffix = model_name.split('_')[1][:3]
+        wav, sr = torchaudio.load(f"test_{audio_suffix}.wav")
+        wav = wav[:, :model.sample_rate * 2]
+        wav_in = wav.unsqueeze(0)
+        wav_dec = model(wav_in)[0]
+        assert wav.shape == wav_dec.shape, (wav.shape, wav_dec.shape)


if __name__ == '__main__':
    test()
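For review context: the 48 kHz model sets `segment=1.` and `overlap=0.01`, so `encode` chops the input into 1 s windows with a 1% overlap that `decode` recombines via `_linear_overlap_add`. A small sketch of the resulting bookkeeping (the values follow directly from the definitions above; the 3 s input length is an arbitrary example):

sample_rate = 48_000
segment, overlap = 1., 0.01
segment_length = int(segment * sample_rate)                   # 48000 samples per window
segment_stride = max(1, int((1 - overlap) * segment_length))  # 47520 samples between windows
length = 3 * sample_rate
n_frames = len(range(0, length, segment_stride))              # 4 encoded frames for 3 s of audio
print(segment_length, segment_stride, n_frames)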
diff --git a/encoder/modules/__init__.py b/encoder/modules/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..b8e2f987aafa3abf9b882fe15ca5a3b6e150ea32
--- /dev/null
+++ b/encoder/modules/__init__.py
@@ -0,0 +1,22 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
+
+"""Torch modules."""
+
+# flake8: noqa
+from .conv import (
+    pad1d,
+    unpad1d,
+    NormConv1d,
+    NormConvTranspose1d,
+    NormConv2d,
+    NormConvTranspose2d,
+    SConv1d,
+    SConvTranspose1d,
+)
+from .lstm import SLSTM
+from .seanet import SEANetEncoder, SEANetDecoder
+from .transformer import StreamingTransformerEncoder
diff --git a/encoder/modules/conv.py b/encoder/modules/conv.py
new file mode 100644
index 0000000000000000000000000000000000000000..e83ae84d20ad2082c6e83bb7fc73bb22ac58cf13
--- /dev/null
+++ b/encoder/modules/conv.py
@@ -0,0 +1,253 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
+
+"""Convolutional layers wrappers and utilities."""
+
+import math
+import typing as tp
+import warnings
+
+import torch
+from torch import nn
+from torch.nn import functional as F
+from torch.nn.utils import spectral_norm, weight_norm
+
+from .norm import ConvLayerNorm
+
+
+CONV_NORMALIZATIONS = frozenset(['none', 'weight_norm', 'spectral_norm',
+                                 'time_layer_norm', 'layer_norm', 'time_group_norm'])
+
+
+def apply_parametrization_norm(module: nn.Module, norm: str = 'none') -> nn.Module:
+    assert norm in CONV_NORMALIZATIONS
+    if norm == 'weight_norm':
+        return weight_norm(module)
+    elif norm == 'spectral_norm':
+        return spectral_norm(module)
+    else:
+        # We already checked that `norm` is in CONV_NORMALIZATIONS, so any other
+        # choice doesn't need reparametrization.
+        return module
+
+
+def get_norm_module(module: nn.Module, causal: bool = False, norm: str = 'none', **norm_kwargs) -> nn.Module:
+    """Return the proper normalization module. If causal is True, this will ensure the returned
+    module is causal, or raise an error if the normalization doesn't support causal evaluation.
+    """
+    assert norm in CONV_NORMALIZATIONS
+    if norm == 'layer_norm':
+        assert isinstance(module, nn.modules.conv._ConvNd)
+        return ConvLayerNorm(module.out_channels, **norm_kwargs)
+    elif norm == 'time_group_norm':
+        if causal:
+            raise ValueError("GroupNorm doesn't support causal evaluation.")
+        assert isinstance(module, nn.modules.conv._ConvNd)
+        return nn.GroupNorm(1, module.out_channels, **norm_kwargs)
+    else:
+        return nn.Identity()
+
+
+def get_extra_padding_for_conv1d(x: torch.Tensor, kernel_size: int, stride: int,
+                                 padding_total: int = 0) -> int:
+    """See `pad_for_conv1d`."""
+    length = x.shape[-1]
+    n_frames = (length - kernel_size + padding_total) / stride + 1
+    ideal_length = (math.ceil(n_frames) - 1) * stride + (kernel_size - padding_total)
+    return ideal_length - length
+
+
+def pad_for_conv1d(x: torch.Tensor, kernel_size: int, stride: int, padding_total: int = 0):
+    """Pad for a convolution to make sure that the last window is full.
+    Extra padding is added at the end. This is required to ensure that we can rebuild
+    an output of the same length, as otherwise, even with padding, some time steps
+    might get removed.
+    For instance, with total padding = 4, kernel size = 4, stride = 2:
+        0 0 1 2 3 4 5 0 0   # (0s are padding)
+        1   2   3           # (output frames of a convolution, last 0 is never used)
+        0 0 1 2 3 4 5 0     # (output of tr. conv., but pos. 5 is going to get removed as padding)
+            1 2 3 4         # once you removed padding, we are missing one time step!
+    """
+    extra_padding = get_extra_padding_for_conv1d(x, kernel_size, stride, padding_total)
+    return F.pad(x, (0, extra_padding))
+
+
+def pad1d(x: torch.Tensor, paddings: tp.Tuple[int, int], mode: str = 'zero', value: float = 0.):
+    """Tiny wrapper around F.pad, just to allow for reflect padding on small input.
+    If this is the case, we insert extra 0 padding to the right before the reflection happens.
+ """ + length = x.shape[-1] + padding_left, padding_right = paddings + assert padding_left >= 0 and padding_right >= 0, (padding_left, padding_right) + if mode == 'reflect': + max_pad = max(padding_left, padding_right) + extra_pad = 0 + if length <= max_pad: + extra_pad = max_pad - length + 1 + x = F.pad(x, (0, extra_pad)) + padded = F.pad(x, paddings, mode, value) + end = padded.shape[-1] - extra_pad + return padded[..., :end] + else: + return F.pad(x, paddings, mode, value) + + +def unpad1d(x: torch.Tensor, paddings: tp.Tuple[int, int]): + """Remove padding from x, handling properly zero padding. Only for 1d!""" + padding_left, padding_right = paddings + assert padding_left >= 0 and padding_right >= 0, (padding_left, padding_right) + assert (padding_left + padding_right) <= x.shape[-1] + end = x.shape[-1] - padding_right + return x[..., padding_left: end] + + +class NormConv1d(nn.Module): + """Wrapper around Conv1d and normalization applied to this conv + to provide a uniform interface across normalization approaches. + """ + def __init__(self, *args, causal: bool = False, norm: str = 'none', + norm_kwargs: tp.Dict[str, tp.Any] = {}, **kwargs): + super().__init__() + self.conv = apply_parametrization_norm(nn.Conv1d(*args, **kwargs), norm) + self.norm = get_norm_module(self.conv, causal, norm, **norm_kwargs) + self.norm_type = norm + + def forward(self, x): + x = self.conv(x) + x = self.norm(x) + return x + + +class NormConv2d(nn.Module): + """Wrapper around Conv2d and normalization applied to this conv + to provide a uniform interface across normalization approaches. + """ + def __init__(self, *args, norm: str = 'none', + norm_kwargs: tp.Dict[str, tp.Any] = {}, **kwargs): + super().__init__() + self.conv = apply_parametrization_norm(nn.Conv2d(*args, **kwargs), norm) + self.norm = get_norm_module(self.conv, causal=False, norm=norm, **norm_kwargs) + self.norm_type = norm + + def forward(self, x): + x = self.conv(x) + x = self.norm(x) + return x + + +class NormConvTranspose1d(nn.Module): + """Wrapper around ConvTranspose1d and normalization applied to this conv + to provide a uniform interface across normalization approaches. + """ + def __init__(self, *args, causal: bool = False, norm: str = 'none', + norm_kwargs: tp.Dict[str, tp.Any] = {}, **kwargs): + super().__init__() + self.convtr = apply_parametrization_norm(nn.ConvTranspose1d(*args, **kwargs), norm) + self.norm = get_norm_module(self.convtr, causal, norm, **norm_kwargs) + self.norm_type = norm + + def forward(self, x): + x = self.convtr(x) + x = self.norm(x) + return x + + +class NormConvTranspose2d(nn.Module): + """Wrapper around ConvTranspose2d and normalization applied to this conv + to provide a uniform interface across normalization approaches. + """ + def __init__(self, *args, norm: str = 'none', + norm_kwargs: tp.Dict[str, tp.Any] = {}, **kwargs): + super().__init__() + self.convtr = apply_parametrization_norm(nn.ConvTranspose2d(*args, **kwargs), norm) + self.norm = get_norm_module(self.convtr, causal=False, norm=norm, **norm_kwargs) + + def forward(self, x): + x = self.convtr(x) + x = self.norm(x) + return x + + +class SConv1d(nn.Module): + """Conv1d with some builtin handling of asymmetric or causal padding + and normalization. 
+ """ + def __init__(self, in_channels: int, out_channels: int, + kernel_size: int, stride: int = 1, dilation: int = 1, + groups: int = 1, bias: bool = True, causal: bool = False, + norm: str = 'none', norm_kwargs: tp.Dict[str, tp.Any] = {}, + pad_mode: str = 'reflect'): + super().__init__() + # warn user on unusual setup between dilation and stride + if stride > 1 and dilation > 1: + warnings.warn('SConv1d has been initialized with stride > 1 and dilation > 1' + f' (kernel_size={kernel_size} stride={stride}, dilation={dilation}).') + self.conv = NormConv1d(in_channels, out_channels, kernel_size, stride, + dilation=dilation, groups=groups, bias=bias, causal=causal, + norm=norm, norm_kwargs=norm_kwargs) + self.causal = causal + self.pad_mode = pad_mode + + def forward(self, x): + B, C, T = x.shape + kernel_size = self.conv.conv.kernel_size[0] + stride = self.conv.conv.stride[0] + dilation = self.conv.conv.dilation[0] + kernel_size = (kernel_size - 1) * dilation + 1 # effective kernel size with dilations + padding_total = kernel_size - stride + extra_padding = get_extra_padding_for_conv1d(x, kernel_size, stride, padding_total) + if self.causal: + # Left padding for causal + x = pad1d(x, (padding_total, extra_padding), mode=self.pad_mode) + else: + # Asymmetric padding required for odd strides + padding_right = padding_total // 2 + padding_left = padding_total - padding_right + x = pad1d(x, (padding_left, padding_right + extra_padding), mode=self.pad_mode) + return self.conv(x) + + +class SConvTranspose1d(nn.Module): + """ConvTranspose1d with some builtin handling of asymmetric or causal padding + and normalization. + """ + def __init__(self, in_channels: int, out_channels: int, + kernel_size: int, stride: int = 1, causal: bool = False, + norm: str = 'none', trim_right_ratio: float = 1., + norm_kwargs: tp.Dict[str, tp.Any] = {}): + super().__init__() + self.convtr = NormConvTranspose1d(in_channels, out_channels, kernel_size, stride, + causal=causal, norm=norm, norm_kwargs=norm_kwargs) + self.causal = causal + self.trim_right_ratio = trim_right_ratio + assert self.causal or self.trim_right_ratio == 1., \ + "`trim_right_ratio` != 1.0 only makes sense for causal convolutions" + assert self.trim_right_ratio >= 0. and self.trim_right_ratio <= 1. + + def forward(self, x): + kernel_size = self.convtr.convtr.kernel_size[0] + stride = self.convtr.convtr.stride[0] + padding_total = kernel_size - stride + + y = self.convtr(x) + + # We will only trim fixed padding. Extra padding from `pad_for_conv1d` would be + # removed at the very end, when keeping only the right length for the output, + # as removing it here would require also passing the length at the matching layer + # in the encoder. + if self.causal: + # Trim the padding on the right according to the specified ratio + # if trim_right_ratio = 1.0, trim everything from right + padding_right = math.ceil(padding_total * self.trim_right_ratio) + padding_left = padding_total - padding_right + y = unpad1d(y, (padding_left, padding_right)) + else: + # Asymmetric padding required for odd strides + padding_right = padding_total // 2 + padding_left = padding_total - padding_right + y = unpad1d(y, (padding_left, padding_right)) + return y diff --git a/encoder/modules/lstm.py b/encoder/modules/lstm.py new file mode 100644 index 0000000000000000000000000000000000000000..49908198953deed173bed6eed5199eb74b99e5f8 --- /dev/null +++ b/encoder/modules/lstm.py @@ -0,0 +1,39 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. 
diff --git a/encoder/modules/lstm.py b/encoder/modules/lstm.py
new file mode 100644
index 0000000000000000000000000000000000000000..49908198953deed173bed6eed5199eb74b99e5f8
--- /dev/null
+++ b/encoder/modules/lstm.py
@@ -0,0 +1,39 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
+
+"""LSTM layers module."""
+
+from torch import nn
+
+
+class SLSTM(nn.Module):
+    """
+    LSTM without worrying about the hidden state, nor the layout of the data.
+    Expects input as convolutional layout.
+    """
+    def __init__(self, dimension: int, num_layers: int = 2, skip: bool = True):
+        super().__init__()
+        self.skip = skip
+        self.lstm = nn.LSTM(dimension, dimension, num_layers)
+
+    def forward(self, x):
+        # Changed the permute order relative to the original implementation:
+        # permute [B, C, T] -> [T, B, C] for the LSTM, permute back to [B, C, T],
+        # and only then apply the skip connection.
+        x1 = x.permute(2, 0, 1)
+        y, _ = self.lstm(x1)
+        y = y.permute(1, 2, 0)
+        if self.skip:
+            y = y + x
+        return y
diff --git a/encoder/modules/norm.py b/encoder/modules/norm.py
new file mode 100644
index 0000000000000000000000000000000000000000..19970e0a21ea1c10461cb56d776619dd5f64ff36
--- /dev/null
+++ b/encoder/modules/norm.py
@@ -0,0 +1,28 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
+
+"""Normalization modules."""
+
+import typing as tp
+
+import einops
+import torch
+from torch import nn
+
+
+class ConvLayerNorm(nn.LayerNorm):
+    """
+    Convolution-friendly LayerNorm that moves channels to last dimensions
+    before running the normalization and moves them back to original position right after.
+    """
+    def __init__(self, normalized_shape: tp.Union[int, tp.List[int], torch.Size], **kwargs):
+        super().__init__(normalized_shape, **kwargs)
+
+    def forward(self, x):
+        x = einops.rearrange(x, 'b ... t -> b t ...')
+        x = super().forward(x)
+        x = einops.rearrange(x, 'b t ... -> b ... t')
+        return x
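Both modules are drop-in [B, C, T] transforms, which is what lets SEANet stack them inside plain nn.Sequential pipelines. A shape-preservation check for the two classes above (the sizes are arbitrary):

import torch
from encoder.modules.lstm import SLSTM
from encoder.modules.norm import ConvLayerNorm

x = torch.randn(2, 32, 50)                    # [B, C, T]
assert SLSTM(32)(x).shape == x.shape          # recurrent block keeps the conv layout
assert ConvLayerNorm(32)(x).shape == x.shape  # LayerNorm over channels, layout restored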
+ """ + def __init__(self, dim: int, kernel_sizes: tp.List[int] = [3, 1], dilations: tp.List[int] = [1, 1], + activation: str = 'ELU', activation_params: dict = {'alpha': 1.0}, + norm: str = 'weight_norm', norm_params: tp.Dict[str, tp.Any] = {}, causal: bool = False, + pad_mode: str = 'reflect', compress: int = 2, true_skip: bool = True): + super().__init__() + assert len(kernel_sizes) == len(dilations), 'Number of kernel sizes should match number of dilations' + act = getattr(nn, activation) + hidden = dim // compress + block = [] + for i, (kernel_size, dilation) in enumerate(zip(kernel_sizes, dilations)): + in_chs = dim if i == 0 else hidden + out_chs = dim if i == len(kernel_sizes) - 1 else hidden + block += [ + act(**activation_params), + SConv1d(in_chs, out_chs, kernel_size=kernel_size, dilation=dilation, + norm=norm, norm_kwargs=norm_params, + causal=causal, pad_mode=pad_mode), + ] + self.block = nn.Sequential(*block) + self.shortcut: nn.Module + if true_skip: + self.shortcut = nn.Identity() + else: + self.shortcut = SConv1d(dim, dim, kernel_size=1, norm=norm, norm_kwargs=norm_params, + causal=causal, pad_mode=pad_mode) + + def forward(self, x): + return self.shortcut(x) + self.block(x) + + +class SEANetEncoder(nn.Module): + """SEANet encoder. + Args: + channels (int): Audio channels. + dimension (int): Intermediate representation dimension. + n_filters (int): Base width for the model. + n_residual_layers (int): nb of residual layers. + ratios (Sequence[int]): kernel size and stride ratios. The encoder uses downsampling ratios instead of + upsampling ratios, hence it will use the ratios in the reverse order to the ones specified here + that must match the decoder order + activation (str): Activation function. + activation_params (dict): Parameters to provide to the activation function + norm (str): Normalization method. + norm_params (dict): Parameters to provide to the underlying normalization used along with the convolution. + kernel_size (int): Kernel size for the initial convolution. + last_kernel_size (int): Kernel size for the initial convolution. + residual_kernel_size (int): Kernel size for the residual layers. + dilation_base (int): How much to increase the dilation with each layer. + causal (bool): Whether to use fully causal convolution. + pad_mode (str): Padding mode for the convolutions. + true_skip (bool): Whether to use true skip connection or a simple + (streamable) convolution as the skip connection in the residual network blocks. + compress (int): Reduced dimensionality in residual branches (from Demucs v3). + lstm (int): Number of LSTM layers at the end of the encoder. 
+ """ + def __init__(self, channels: int = 1, dimension: int = 128, n_filters: int = 32, n_residual_layers: int = 1, + ratios: tp.List[int] = [8, 5, 4, 2], activation: str = 'ELU', activation_params: dict = {'alpha': 1.0}, + norm: str = 'weight_norm', norm_params: tp.Dict[str, tp.Any] = {}, kernel_size: int = 7, + last_kernel_size: int = 7, residual_kernel_size: int = 3, dilation_base: int = 2, causal: bool = False, + pad_mode: str = 'reflect', true_skip: bool = False, compress: int = 2, lstm: int = 2): + super().__init__() + self.channels = channels + self.dimension = dimension + self.n_filters = n_filters + self.ratios = list(reversed(ratios)) + del ratios + self.n_residual_layers = n_residual_layers + self.hop_length = np.prod(self.ratios) + + act = getattr(nn, activation) + mult = 1 + model: tp.List[nn.Module] = [ + SConv1d(channels, mult * n_filters, kernel_size, norm=norm, norm_kwargs=norm_params, + causal=causal, pad_mode=pad_mode) + ] + # Downsample to raw audio scale + for i, ratio in enumerate(self.ratios): + # Add residual layers + for j in range(n_residual_layers): + model += [ + SEANetResnetBlock(mult * n_filters, kernel_sizes=[residual_kernel_size, 1], + dilations=[dilation_base ** j, 1], + norm=norm, norm_params=norm_params, + activation=activation, activation_params=activation_params, + causal=causal, pad_mode=pad_mode, compress=compress, true_skip=true_skip)] + + # Add downsampling layers + model += [ + act(**activation_params), + SConv1d(mult * n_filters, mult * n_filters * 2, + kernel_size=ratio * 2, stride=ratio, + norm=norm, norm_kwargs=norm_params, + causal=causal, pad_mode=pad_mode), + ] + mult *= 2 + + if lstm: + model += [SLSTM(mult * n_filters, num_layers=lstm)] + + model += [ + act(**activation_params), + SConv1d(mult * n_filters, dimension, last_kernel_size, norm=norm, norm_kwargs=norm_params, + causal=causal, pad_mode=pad_mode) + ] + + self.model = nn.Sequential(*model) + + def forward(self, x): + return self.model(x) + + +class SEANetDecoder(nn.Module): + """SEANet decoder. + Args: + channels (int): Audio channels. + dimension (int): Intermediate representation dimension. + n_filters (int): Base width for the model. + n_residual_layers (int): nb of residual layers. + ratios (Sequence[int]): kernel size and stride ratios + activation (str): Activation function. + activation_params (dict): Parameters to provide to the activation function + final_activation (str): Final activation function after all convolutions. + final_activation_params (dict): Parameters to provide to the activation function + norm (str): Normalization method. + norm_params (dict): Parameters to provide to the underlying normalization used along with the convolution. + kernel_size (int): Kernel size for the initial convolution. + last_kernel_size (int): Kernel size for the initial convolution. + residual_kernel_size (int): Kernel size for the residual layers. + dilation_base (int): How much to increase the dilation with each layer. + causal (bool): Whether to use fully causal convolution. + pad_mode (str): Padding mode for the convolutions. + true_skip (bool): Whether to use true skip connection or a simple + (streamable) convolution as the skip connection in the residual network blocks. + compress (int): Reduced dimensionality in residual branches (from Demucs v3). + lstm (int): Number of LSTM layers at the end of the encoder. + trim_right_ratio (float): Ratio for trimming at the right of the transposed convolution under the causal setup. 
+ If equal to 1.0, it means that all the trimming is done at the right. + """ + def __init__(self, channels: int = 1, dimension: int = 128, n_filters: int = 32, n_residual_layers: int = 1, + ratios: tp.List[int] = [8, 5, 4, 2], activation: str = 'ELU', activation_params: dict = {'alpha': 1.0}, + final_activation: tp.Optional[str] = None, final_activation_params: tp.Optional[dict] = None, + norm: str = 'weight_norm', norm_params: tp.Dict[str, tp.Any] = {}, kernel_size: int = 7, + last_kernel_size: int = 7, residual_kernel_size: int = 3, dilation_base: int = 2, causal: bool = False, + pad_mode: str = 'reflect', true_skip: bool = False, compress: int = 2, lstm: int = 2, + trim_right_ratio: float = 1.0): + super().__init__() + self.dimension = dimension + self.channels = channels + self.n_filters = n_filters + self.ratios = ratios + del ratios + self.n_residual_layers = n_residual_layers + self.hop_length = np.prod(self.ratios) + + act = getattr(nn, activation) + mult = int(2 ** len(self.ratios)) + model: tp.List[nn.Module] = [ + SConv1d(dimension, mult * n_filters, kernel_size, norm=norm, norm_kwargs=norm_params, + causal=causal, pad_mode=pad_mode) + ] + + if lstm: + model += [SLSTM(mult * n_filters, num_layers=lstm)] + + # Upsample to raw audio scale + for i, ratio in enumerate(self.ratios): + # Add upsampling layers + model += [ + act(**activation_params), + SConvTranspose1d(mult * n_filters, mult * n_filters // 2, + kernel_size=ratio * 2, stride=ratio, + norm=norm, norm_kwargs=norm_params, + causal=causal, trim_right_ratio=trim_right_ratio), + ] + # Add residual layers + for j in range(n_residual_layers): + model += [ + SEANetResnetBlock(mult * n_filters // 2, kernel_sizes=[residual_kernel_size, 1], + dilations=[dilation_base ** j, 1], + activation=activation, activation_params=activation_params, + norm=norm, norm_params=norm_params, causal=causal, + pad_mode=pad_mode, compress=compress, true_skip=true_skip)] + + mult //= 2 + + # Add final layers + model += [ + act(**activation_params), + SConv1d(n_filters, channels, last_kernel_size, norm=norm, norm_kwargs=norm_params, + causal=causal, pad_mode=pad_mode) + ] + # Add optional final activation to decoder (eg. tanh) + if final_activation is not None: + final_act = getattr(nn, final_activation) + final_activation_params = final_activation_params or {} + model += [ + final_act(**final_activation_params) + ] + self.model = nn.Sequential(*model) + + def forward(self, z): + y = self.model(z) + return y + + +def test(): + import torch + encoder = SEANetEncoder() + decoder = SEANetDecoder() + x = torch.randn(1, 1, 24000) + z = encoder(x) + assert list(z.shape) == [1, 128, 75], z.shape + y = decoder(z) + assert y.shape == x.shape, (x.shape, y.shape) + + +if __name__ == '__main__': + test() diff --git a/encoder/modules/transformer.py b/encoder/modules/transformer.py new file mode 100644 index 0000000000000000000000000000000000000000..44b47918f84aa47021c0d6f5bd58364641088541 --- /dev/null +++ b/encoder/modules/transformer.py @@ -0,0 +1,119 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + +"""A streamable transformer.""" + +import typing as tp + +import torch +import torch.nn as nn +import torch.nn.functional as F + + +def create_sin_embedding(positions: torch.Tensor, dim: int, max_period: float = 10000): + """Create time embedding for the given positions, target dimension `dim`. 
+ """ + # We aim for BTC format + assert dim % 2 == 0 + half_dim = dim // 2 + adim = torch.arange(half_dim, device=positions.device).view(1, 1, -1) + phase = positions / (max_period ** (adim / (half_dim - 1))) + return torch.cat([ + torch.cos(phase), + torch.sin(phase), + ], dim=-1) + + +class StreamingTransformerEncoderLayer(nn.TransformerEncoderLayer): + def forward(self, x: torch.Tensor, x_past: torch.Tensor, past_context: int): # type: ignore + if self.norm_first: + sa_input = self.norm1(x) + x = x + self._sa_block(sa_input, x_past, past_context) + x = x + self._ff_block(self.norm2(x)) + else: + sa_input = x + x = self.norm1(x + self._sa_block(sa_input, x_past, past_context)) + x = self.norm2(x + self._ff_block(x)) + + return x, sa_input + + # self-attention block + def _sa_block(self, x: torch.Tensor, x_past: torch.Tensor, past_context: int): # type: ignore + _, T, _ = x.shape + _, H, _ = x_past.shape + + queries = x + keys = torch.cat([x_past, x], dim=1) + values = keys + + queries_pos = torch.arange(H, T + H, device=x.device).view(-1, 1) + keys_pos = torch.arange(T + H, device=x.device).view(1, -1) + delta = queries_pos - keys_pos + valid_access = (delta >= 0) & (delta <= past_context) + x = self.self_attn(queries, keys, values, + attn_mask=~valid_access, + need_weights=False)[0] + return self.dropout1(x) + + +class StreamingTransformerEncoder(nn.Module): + """TransformerEncoder with streaming support. + + Args: + dim (int): dimension of the data. + hidden_scale (int): intermediate dimension of FF module is this times the dimension. + num_heads (int): number of heads. + num_layers (int): number of layers. + max_period (float): maxium period of cosines in the positional embedding. + past_context (int or None): receptive field for the causal mask, infinite if None. + gelu (bool): if true uses GeLUs, otherwise use ReLUs. + norm_in (bool): normalize the input. + dropout (float): dropout probability. + **kwargs: See `nn.TransformerEncoderLayer`. 
+ """ + def __init__(self, dim, hidden_scale: float = 4., num_heads: int = 8, num_layers: int = 5, + max_period: float = 10000, past_context: int = 1000, gelu: bool = True, + norm_in: bool = True, dropout: float = 0., **kwargs): + super().__init__() + assert dim % num_heads == 0 + hidden_dim = int(dim * hidden_scale) + + self.max_period = max_period + self.past_context = past_context + activation: tp.Any = F.gelu if gelu else F.relu + + self.norm_in: nn.Module + if norm_in: + self.norm_in = nn.LayerNorm(dim) + else: + self.norm_in = nn.Identity() + + self.layers = nn.ModuleList() + for idx in range(num_layers): + self.layers.append( + StreamingTransformerEncoderLayer( + dim, num_heads, hidden_dim, + activation=activation, batch_first=True, dropout=dropout, **kwargs)) + + def forward(self, x: torch.Tensor, + states: tp.Optional[tp.List[torch.Tensor]] = None, + offset: tp.Union[int, torch.Tensor] = 0): + B, T, C = x.shape + if states is None: + states = [torch.zeros_like(x[:, :1]) for _ in range(1 + len(self.layers))] + + positions = torch.arange(T, device=x.device).view(1, -1, 1) + offset + pos_emb = create_sin_embedding(positions, C, max_period=self.max_period) + + new_state: tp.List[torch.Tensor] = [] + x = self.norm_in(x) + x = x + pos_emb + + for layer_state, layer in zip(states, self.layers): + x, new_layer_state = layer(x, layer_state, self.past_context) + new_layer_state = torch.cat([layer_state, new_layer_state], dim=1) + new_state.append(new_layer_state[:, -self.past_context:, :]) + return x, new_state, offset + T diff --git a/encoder/msstftd.py b/encoder/msstftd.py new file mode 100644 index 0000000000000000000000000000000000000000..a1d3242a57e1e20e99bc2fa86e363cc5ec92cbf7 --- /dev/null +++ b/encoder/msstftd.py @@ -0,0 +1,147 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + +"""MS-STFT discriminator, provided here for reference.""" + +import typing as tp + +import torchaudio +import torch +from torch import nn +from einops import rearrange + +from .modules import NormConv2d + + +FeatureMapType = tp.List[torch.Tensor] +LogitsType = torch.Tensor +DiscriminatorOutput = tp.Tuple[tp.List[LogitsType], tp.List[FeatureMapType]] + + +def get_2d_padding(kernel_size: tp.Tuple[int, int], dilation: tp.Tuple[int, int] = (1, 1)): + return (((kernel_size[0] - 1) * dilation[0]) // 2, ((kernel_size[1] - 1) * dilation[1]) // 2) + + +class DiscriminatorSTFT(nn.Module): + """STFT sub-discriminator. + Args: + filters (int): Number of filters in convolutions + in_channels (int): Number of input channels. Default: 1 + out_channels (int): Number of output channels. Default: 1 + n_fft (int): Size of FFT for each scale. Default: 1024 + hop_length (int): Length of hop between STFT windows for each scale. Default: 256 + kernel_size (tuple of int): Inner Conv2d kernel sizes. Default: ``(3, 9)`` + stride (tuple of int): Inner Conv2d strides. Default: ``(1, 2)`` + dilations (list of int): Inner Conv2d dilation on the time dimension. Default: ``[1, 2, 4]`` + win_length (int): Window size for each scale. Default: 1024 + normalized (bool): Whether to normalize by magnitude after stft. Default: True + norm (str): Normalization method. Default: `'weight_norm'` + activation (str): Activation function. Default: `'LeakyReLU'` + activation_params (dict): Parameters to provide to the activation function. + growth (int): Growth factor for the filters. 
diff --git a/encoder/msstftd.py b/encoder/msstftd.py
new file mode 100644
index 0000000000000000000000000000000000000000..a1d3242a57e1e20e99bc2fa86e363cc5ec92cbf7
--- /dev/null
+++ b/encoder/msstftd.py
@@ -0,0 +1,147 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
+
+"""MS-STFT discriminator, provided here for reference."""
+
+import typing as tp
+
+import torchaudio
+import torch
+from torch import nn
+from einops import rearrange
+
+from .modules import NormConv2d
+
+
+FeatureMapType = tp.List[torch.Tensor]
+LogitsType = torch.Tensor
+DiscriminatorOutput = tp.Tuple[tp.List[LogitsType], tp.List[FeatureMapType]]
+
+
+def get_2d_padding(kernel_size: tp.Tuple[int, int], dilation: tp.Tuple[int, int] = (1, 1)):
+    return (((kernel_size[0] - 1) * dilation[0]) // 2, ((kernel_size[1] - 1) * dilation[1]) // 2)
+
+
+class DiscriminatorSTFT(nn.Module):
+    """STFT sub-discriminator.
+    Args:
+        filters (int): Number of filters in convolutions.
+        in_channels (int): Number of input channels. Default: 1
+        out_channels (int): Number of output channels. Default: 1
+        n_fft (int): Size of FFT for each scale. Default: 1024
+        hop_length (int): Length of hop between STFT windows for each scale. Default: 256
+        kernel_size (tuple of int): Inner Conv2d kernel sizes. Default: ``(3, 9)``
+        stride (tuple of int): Inner Conv2d strides. Default: ``(1, 2)``
+        dilations (list of int): Inner Conv2d dilation on the time dimension. Default: ``[1, 2, 4]``
+        win_length (int): Window size for each scale. Default: 1024
+        normalized (bool): Whether to normalize by magnitude after stft. Default: True
+        norm (str): Normalization method. Default: `'weight_norm'`
+        activation (str): Activation function. Default: `'LeakyReLU'`
+        activation_params (dict): Parameters to provide to the activation function.
+        filters_scale (int): Growth factor for the filters. Default: 1
+        max_filters (int): Maximum number of filters. Default: 1024
+    """
+    def __init__(self, filters: int, in_channels: int = 1, out_channels: int = 1,
+                 n_fft: int = 1024, hop_length: int = 256, win_length: int = 1024, max_filters: int = 1024,
+                 filters_scale: int = 1, kernel_size: tp.Tuple[int, int] = (3, 9), dilations: tp.List = [1, 2, 4],
+                 stride: tp.Tuple[int, int] = (1, 2), normalized: bool = True, norm: str = 'weight_norm',
+                 activation: str = 'LeakyReLU', activation_params: dict = {'negative_slope': 0.2}):
+        super().__init__()
+        assert len(kernel_size) == 2
+        assert len(stride) == 2
+        self.filters = filters
+        self.in_channels = in_channels
+        self.out_channels = out_channels
+        self.n_fft = n_fft
+        self.hop_length = hop_length
+        self.win_length = win_length
+        self.normalized = normalized
+        self.activation = getattr(torch.nn, activation)(**activation_params)
+        self.spec_transform = torchaudio.transforms.Spectrogram(
+            n_fft=self.n_fft, hop_length=self.hop_length, win_length=self.win_length, window_fn=torch.hann_window,
+            normalized=self.normalized, center=False, pad_mode=None, power=None)
+        spec_channels = 2 * self.in_channels
+        self.convs = nn.ModuleList()
+        self.convs.append(
+            NormConv2d(spec_channels, self.filters, kernel_size=kernel_size, padding=get_2d_padding(kernel_size))
+        )
+        in_chs = min(filters_scale * self.filters, max_filters)
+        for i, dilation in enumerate(dilations):
+            out_chs = min((filters_scale ** (i + 1)) * self.filters, max_filters)
+            self.convs.append(NormConv2d(in_chs, out_chs, kernel_size=kernel_size, stride=stride,
+                                         dilation=(dilation, 1), padding=get_2d_padding(kernel_size, (dilation, 1)),
+                                         norm=norm))
+            in_chs = out_chs
+        out_chs = min((filters_scale ** (len(dilations) + 1)) * self.filters, max_filters)
+        self.convs.append(NormConv2d(in_chs, out_chs, kernel_size=(kernel_size[0], kernel_size[0]),
+                                     padding=get_2d_padding((kernel_size[0], kernel_size[0])),
+                                     norm=norm))
+        self.conv_post = NormConv2d(out_chs, self.out_channels,
+                                    kernel_size=(kernel_size[0], kernel_size[0]),
+                                    padding=get_2d_padding((kernel_size[0], kernel_size[0])),
+                                    norm=norm)
+
+    def forward(self, x: torch.Tensor):
+        fmap = []
+        z = self.spec_transform(x)  # [B, 2, Freq, Frames, 2]
+        z = torch.cat([z.real, z.imag], dim=1)
+        z = rearrange(z, 'b c w t -> b c t w')
+        for i, layer in enumerate(self.convs):
+            z = layer(z)
+            z = self.activation(z)
+            fmap.append(z)
+        z = self.conv_post(z)
+        return z, fmap
+
+
+class MultiScaleSTFTDiscriminator(nn.Module):
+    """Multi-Scale STFT (MS-STFT) discriminator.
+    Args:
+        filters (int): Number of filters in convolutions.
+        in_channels (int): Number of input channels. Default: 1
+        out_channels (int): Number of output channels.
Default: 1 + n_ffts (Sequence[int]): Size of FFT for each scale + hop_lengths (Sequence[int]): Length of hop between STFT windows for each scale + win_lengths (Sequence[int]): Window size for each scale + **kwargs: additional args for STFTDiscriminator + """ + def __init__(self, filters: int, in_channels: int = 1, out_channels: int = 1, + n_ffts: tp.List[int] = [1024, 2048, 512], hop_lengths: tp.List[int] = [256, 512, 128], + win_lengths: tp.List[int] = [1024, 2048, 512], **kwargs): + super().__init__() + assert len(n_ffts) == len(hop_lengths) == len(win_lengths) + self.discriminators = nn.ModuleList([ + DiscriminatorSTFT(filters, in_channels=in_channels, out_channels=out_channels, + n_fft=n_ffts[i], win_length=win_lengths[i], hop_length=hop_lengths[i], **kwargs) + for i in range(len(n_ffts)) + ]) + self.num_discriminators = len(self.discriminators) + + def forward(self, x: torch.Tensor) -> DiscriminatorOutput: + logits = [] + fmaps = [] + for disc in self.discriminators: + logit, fmap = disc(x) + logits.append(logit) + fmaps.append(fmap) + return logits, fmaps + + +def test(): + disc = MultiScaleSTFTDiscriminator(filters=32) + y = torch.randn(1, 1, 24000) + y_hat = torch.randn(1, 1, 24000) + + y_disc_r, fmap_r = disc(y) + y_disc_gen, fmap_gen = disc(y_hat) + assert len(y_disc_r) == len(y_disc_gen) == len(fmap_r) == len(fmap_gen) == disc.num_discriminators + + assert all([len(fm) == 5 for fm in fmap_r + fmap_gen]) + assert all([list(f.shape)[:2] == [1, 32] for fm in fmap_r + fmap_gen for f in fm]) + assert all([len(logits.shape) == 4 for logits in y_disc_r + y_disc_gen]) + + +if __name__ == '__main__': + test() diff --git a/encoder/quantization/__init__.py b/encoder/quantization/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..bfabe52b8cb6f260cdda6137b34df2f4736bd02f --- /dev/null +++ b/encoder/quantization/__init__.py @@ -0,0 +1,8 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. 
+ +# flake8: noqa +from .vq import QuantizedResult, ResidualVectorQuantizer diff --git a/encoder/quantization/__pycache__/__init__.cpython-310.pyc b/encoder/quantization/__pycache__/__init__.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..4cc4486a7736e1b3df42a634a80941593723f7fa Binary files /dev/null and b/encoder/quantization/__pycache__/__init__.cpython-310.pyc differ diff --git a/encoder/quantization/__pycache__/__init__.cpython-38.pyc b/encoder/quantization/__pycache__/__init__.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..a12af2ffad71d3693f11d93e858a980883881bf0 Binary files /dev/null and b/encoder/quantization/__pycache__/__init__.cpython-38.pyc differ diff --git a/encoder/quantization/__pycache__/__init__.cpython-39.pyc b/encoder/quantization/__pycache__/__init__.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..1b4092569e04ba3bd7cfcd2a58e484366830b8c0 Binary files /dev/null and b/encoder/quantization/__pycache__/__init__.cpython-39.pyc differ diff --git a/encoder/quantization/__pycache__/core_vq.cpython-310.pyc b/encoder/quantization/__pycache__/core_vq.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..697d4470912d91c4deb64ba3b1516e7a0b913541 Binary files /dev/null and b/encoder/quantization/__pycache__/core_vq.cpython-310.pyc differ diff --git a/encoder/quantization/__pycache__/core_vq.cpython-38.pyc b/encoder/quantization/__pycache__/core_vq.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..a4b50fbc6b025adf921725dc899676e7863fed84 Binary files /dev/null and b/encoder/quantization/__pycache__/core_vq.cpython-38.pyc differ diff --git a/encoder/quantization/__pycache__/core_vq.cpython-39.pyc b/encoder/quantization/__pycache__/core_vq.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..78af66c127804bb3deedbab6e5a4b1551af8048a Binary files /dev/null and b/encoder/quantization/__pycache__/core_vq.cpython-39.pyc differ diff --git a/encoder/quantization/__pycache__/vq.cpython-310.pyc b/encoder/quantization/__pycache__/vq.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..fee3f97870bdbcee04c49b262fbb80fcfab618cb Binary files /dev/null and b/encoder/quantization/__pycache__/vq.cpython-310.pyc differ diff --git a/encoder/quantization/__pycache__/vq.cpython-38.pyc b/encoder/quantization/__pycache__/vq.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..c139b8c1483e0deb5b58265798374a483195afcf Binary files /dev/null and b/encoder/quantization/__pycache__/vq.cpython-38.pyc differ diff --git a/encoder/quantization/__pycache__/vq.cpython-39.pyc b/encoder/quantization/__pycache__/vq.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..6d81919274e357c70fd25751dfe3a69156385b1f Binary files /dev/null and b/encoder/quantization/__pycache__/vq.cpython-39.pyc differ diff --git a/encoder/quantization/ac.py b/encoder/quantization/ac.py new file mode 100644 index 0000000000000000000000000000000000000000..f0f3e5dcd385cd273a145effa3f53ce7ccfdc74c --- /dev/null +++ b/encoder/quantization/ac.py @@ -0,0 +1,292 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. 
+
+"""Arithmetic coder."""
+
+import io
+import math
+import random
+import typing as tp
+import torch
+
+from ..binary import BitPacker, BitUnpacker
+
+
+def build_stable_quantized_cdf(pdf: torch.Tensor, total_range_bits: int,
+                               roundoff: float = 1e-8, min_range: int = 2,
+                               check: bool = True) -> torch.Tensor:
+    """Turn the given PDF into a quantized CDF that splits
+    [0, 2 ** total_range_bits - 1] into chunks of size roughly proportional
+    to the PDF.
+
+    Args:
+        pdf (torch.Tensor): probability distribution, shape should be `[N]`.
+        total_range_bits (int): see `ArithmeticCoder`, the typical range we expect
+            during the coding process is `[0, 2 ** total_range_bits - 1]`.
+        roundoff (float): will round the pdf up to that level to remove differences coming
+            from e.g. evaluating the Language Model on different architectures.
+        min_range (int): minimum range width. Should always be at least 2 for numerical
+            stability. Use this to avoid pathological behavior if a value
+            that is expected to be rare actually happens in real life.
+        check (bool): if True, checks that nothing bad happened, can be deactivated for speed.
+    """
+    if min_range < 2:
+        raise ValueError("min_range must be at least 2.")
+    pdf = pdf.detach()
+    if roundoff:
+        pdf = (pdf / roundoff).floor() * roundoff
+    # Interpolate with the uniform distribution to achieve the desired minimum probability.
+    total_range = 2 ** total_range_bits
+    cardinality = len(pdf)
+    alpha = min_range * cardinality / total_range
+    assert alpha <= 1, "you must reduce min_range"
+    ranges = (((1 - alpha) * total_range) * pdf).floor().long()
+    ranges += min_range
+    quantized_cdf = torch.cumsum(ranges, dim=-1)
+    if check:
+        assert quantized_cdf[-1] <= 2 ** total_range_bits, quantized_cdf[-1]
+        if ((quantized_cdf[1:] - quantized_cdf[:-1]) < min_range).any() or quantized_cdf[0] < min_range:
+            raise ValueError("You must increase your total_range_bits.")
+    return quantized_cdf
+
+
+class ArithmeticCoder:
+    """ArithmeticCoder.
+    Let us take a distribution `p` over `N` symbols, and assume we have a stream
+    of random variables `s_t` sampled from `p`. Let us assume that we have a budget
+    of `B` bits that we can afford to write out. There are `2**B` possible numbers,
+    corresponding to the range `[0, 2 ** B - 1]`. We can map each of those numbers to a single
+    sequence `(s_t)` by doing the following:
+
+    1) Initialize the current range to `[0, 2 ** B - 1]`.
+    2) For each time step t, split the current range into contiguous chunks,
+        one for each possible outcome, with size roughly proportional to `p`.
+        For instance, if `p = [0.75, 0.25]`, and the range is `[0, 3]`, the chunks
+        would be `{[0, 2], [3, 3]}`.
+    3) Select the chunk corresponding to `s_t`, and replace the current range with this.
+    4) When done encoding all the values, just select any value remaining in the range.
+
+    You will notice that this procedure can fail: for instance if at any point in time
+    the range is smaller than `N`, then we can no longer assign a non-empty chunk to each
+    possible outcome. Intuitively, the more likely a value is, the less the range width
+    will reduce, and the longer we can go on encoding values. This makes sense: for any efficient
+    coding scheme, likely outcomes would take fewer bits, and more of them can be coded
+    with a fixed budget.
+
+    In practice, we do not know `B` ahead of time, but we have a way to inject new bits
+    when the current range decreases below a given limit (given by `total_range_bits`), without
+    having to redo all the computations.
If we mostly encode likely values, we will seldom
+    need to inject new bits, but a single rare value can deplete our stock of entropy!
+
+    In this explanation, we assumed that the distribution `p` was constant. In fact, the present
+    code works for any sequence `(p_t)`, possibly different for each timestep.
+    We also assume that `s_t ~ p_t`, but that doesn't need to be true, although the smaller
+    the KL between the true distribution and `p_t`, the more efficient the coding will be.
+
+    Args:
+        fo (IO[bytes]): file-like object to which the bytes will be written.
+        total_range_bits (int): the maximum range width described above is `2 ** total_range_bits`.
+            Any time the current range width falls below this limit, new bits will
+            be injected to rescale the initial range.
+    """
+
+    def __init__(self, fo: tp.IO[bytes], total_range_bits: int = 24):
+        assert total_range_bits <= 30
+        self.total_range_bits = total_range_bits
+        self.packer = BitPacker(bits=1, fo=fo)  # we push single bits at a time.
+        self.low: int = 0
+        self.high: int = 0
+        self.max_bit: int = -1
+        self._dbg: tp.List[tp.Any] = []
+        self._dbg2: tp.List[tp.Any] = []
+
+    @property
+    def delta(self) -> int:
+        """Return the current range width."""
+        return self.high - self.low + 1
+
+    def _flush_common_prefix(self):
+        # If self.low and self.high start with the same bits,
+        # those won't change anymore as we always just increase the range
+        # by powers of 2, and we can flush them out to the bit stream.
+        assert self.high >= self.low, (self.low, self.high)
+        assert self.high < 2 ** (self.max_bit + 1)
+        while self.max_bit >= 0:
+            b1 = self.low >> self.max_bit
+            b2 = self.high >> self.max_bit
+            if b1 == b2:
+                self.low -= (b1 << self.max_bit)
+                self.high -= (b1 << self.max_bit)
+                assert self.high >= self.low, (self.high, self.low, self.max_bit)
+                assert self.low >= 0
+                self.max_bit -= 1
+                self.packer.push(b1)
+            else:
+                break
+
+    def push(self, symbol: int, quantized_cdf: torch.Tensor):
+        """Push the given symbol on the stream, flushing out bits
+        if possible.
+
+        Args:
+            symbol (int): symbol to encode with the AC.
+            quantized_cdf (torch.Tensor): use `build_stable_quantized_cdf`
+                to build this from your pdf estimate.
+        """
+        while self.delta < 2 ** self.total_range_bits:
+            self.low *= 2
+            self.high = self.high * 2 + 1
+            self.max_bit += 1
+
+        range_low = 0 if symbol == 0 else quantized_cdf[symbol - 1].item()
+        range_high = quantized_cdf[symbol].item() - 1
+        effective_low = int(math.ceil(range_low * (self.delta / (2 ** self.total_range_bits))))
+        effective_high = int(math.floor(range_high * (self.delta / (2 ** self.total_range_bits))))
+        assert self.low <= self.high
+        self.high = self.low + effective_high
+        self.low = self.low + effective_low
+        assert self.low <= self.high, (effective_low, effective_high, range_low, range_high)
+        self._dbg.append((self.low, self.high))
+        self._dbg2.append((self.low, self.high))
+        outs = self._flush_common_prefix()
+        assert self.low <= self.high
+        assert self.max_bit >= -1
+        assert self.max_bit <= 61, self.max_bit
+        return outs
+
+    def flush(self):
+        """Flush the remaining information to the stream."""
+        while self.max_bit >= 0:
+            b1 = (self.low >> self.max_bit) & 1
+            self.packer.push(b1)
+            self.max_bit -= 1
+        self.packer.flush()
+
+
+class ArithmeticDecoder:
+    """ArithmeticDecoder, see `ArithmeticCoder` for a detailed explanation.
+
+    Note that this must be called with **exactly** the same parameters and sequence
+    of quantized cdf as the arithmetic encoder or the wrong values will be decoded.
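+
+    A rough round-trip sketch, mirroring the `test()` function at the end of this
+    file (`pdf` stands for any probability vector summing to 1, `symbol` for an
+    int in `range(len(pdf))`):
+
+        fo = io.BytesIO()
+        coder = ArithmeticCoder(fo)
+        q_cdf = build_stable_quantized_cdf(pdf, coder.total_range_bits)
+        coder.push(symbol, q_cdf)              # repeat for every symbol to encode
+        coder.flush()
+        fo.seek(0)
+        decoder = ArithmeticDecoder(fo)
+        assert decoder.pull(q_cdf) == symbol   # same cdf sequence as at encoding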
+
+    If the AC encoder current range is [L, H], with `L` and `H` sharing a common
+    prefix (i.e. the same most significant bits), then this prefix will be flushed to the stream.
+    For instance, having read 3 bits `b1 b2 b3`, we know that `[L, H]` is contained inside
+    `[b1 b2 b3 0 ... 0, b1 b2 b3 1 ... 1]`. Now this specific sub-range can only be obtained
+    for a specific sequence of symbols, and a binary search allows us to decode those symbols.
+    At some point, the prefix `b1 b2 b3` will no longer be sufficient to decode new symbols,
+    and we will need to read new bits from the stream and repeat the process.
+
+    """
+    def __init__(self, fo: tp.IO[bytes], total_range_bits: int = 24):
+        self.total_range_bits = total_range_bits
+        self.low: int = 0
+        self.high: int = 0
+        self.current: int = 0
+        self.max_bit: int = -1
+        self.unpacker = BitUnpacker(bits=1, fo=fo)  # we pull single bits at a time.
+        # Following is for debugging
+        self._dbg: tp.List[tp.Any] = []
+        self._dbg2: tp.List[tp.Any] = []
+        self._last: tp.Any = None
+
+    @property
+    def delta(self) -> int:
+        return self.high - self.low + 1
+
+    def _flush_common_prefix(self):
+        # Given the current range [L, H], if both have a common prefix,
+        # we know we can remove it from our representation to avoid handling large numbers.
+        while self.max_bit >= 0:
+            b1 = self.low >> self.max_bit
+            b2 = self.high >> self.max_bit
+            if b1 == b2:
+                self.low -= (b1 << self.max_bit)
+                self.high -= (b1 << self.max_bit)
+                self.current -= (b1 << self.max_bit)
+                assert self.high >= self.low
+                assert self.low >= 0
+                self.max_bit -= 1
+            else:
+                break
+
+    def pull(self, quantized_cdf: torch.Tensor) -> tp.Optional[int]:
+        """Pull a symbol, reading as many bits from the stream as required.
+        This returns `None` when the stream has been exhausted.
+
+        Args:
+            quantized_cdf (torch.Tensor): use `build_stable_quantized_cdf`
+                to build this from your pdf estimate. This must be **exactly**
+                the same cdf as the one used at encoding time.
+ """ + while self.delta < 2 ** self.total_range_bits: + bit = self.unpacker.pull() + if bit is None: + return None + self.low *= 2 + self.high = self.high * 2 + 1 + self.current = self.current * 2 + bit + self.max_bit += 1 + + def bin_search(low_idx: int, high_idx: int): + # Binary search is not just for coding interviews :) + if high_idx < low_idx: + raise RuntimeError("Binary search failed") + mid = (low_idx + high_idx) // 2 + range_low = quantized_cdf[mid - 1].item() if mid > 0 else 0 + range_high = quantized_cdf[mid].item() - 1 + effective_low = int(math.ceil(range_low * (self.delta / (2 ** self.total_range_bits)))) + effective_high = int(math.floor(range_high * (self.delta / (2 ** self.total_range_bits)))) + low = effective_low + self.low + high = effective_high + self.low + if self.current >= low: + if self.current <= high: + return (mid, low, high, self.current) + else: + return bin_search(mid + 1, high_idx) + else: + return bin_search(low_idx, mid - 1) + + self._last = (self.low, self.high, self.current, self.max_bit) + sym, self.low, self.high, self.current = bin_search(0, len(quantized_cdf) - 1) + self._dbg.append((self.low, self.high, self.current)) + self._flush_common_prefix() + self._dbg2.append((self.low, self.high, self.current)) + + return sym + + +def test(): + torch.manual_seed(1234) + random.seed(1234) + for _ in range(4): + pdfs = [] + cardinality = random.randrange(4000) + steps = random.randrange(100, 500) + fo = io.BytesIO() + encoder = ArithmeticCoder(fo) + symbols = [] + for step in range(steps): + pdf = torch.softmax(torch.randn(cardinality), dim=0) + pdfs.append(pdf) + q_cdf = build_stable_quantized_cdf(pdf, encoder.total_range_bits) + symbol = torch.multinomial(pdf, 1).item() + symbols.append(symbol) + encoder.push(symbol, q_cdf) + encoder.flush() + + fo.seek(0) + decoder = ArithmeticDecoder(fo) + for idx, (pdf, symbol) in enumerate(zip(pdfs, symbols)): + q_cdf = build_stable_quantized_cdf(pdf, encoder.total_range_bits) + decoded_symbol = decoder.pull(q_cdf) + assert decoded_symbol == symbol, idx + assert decoder.pull(torch.zeros(1)) is None + + +if __name__ == "__main__": + test() diff --git a/encoder/quantization/core_vq.py b/encoder/quantization/core_vq.py new file mode 100644 index 0000000000000000000000000000000000000000..774781c2947622e6c0c7a55c6eded26a2813b7c7 --- /dev/null +++ b/encoder/quantization/core_vq.py @@ -0,0 +1,421 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. +# +# This implementation is inspired from +# https://github.com/lucidrains/vector-quantize-pytorch +# which is released under MIT License. Hereafter, the original license: +# MIT License +# +# Copyright (c) 2020 Phil Wang +# +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the "Software"), to deal +# in the Software without restriction, including without limitation the rights +# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +# copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included in all +# copies or substantial portions of the Software. 
+# +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +# SOFTWARE. + +"""Core vector quantization implementation.""" + +import typing as tp +import warnings + +from einops import rearrange, repeat +import torch +from torch import nn +import torch.nn.functional as F + +from .. import distrib + + +def default(val: tp.Any, d: tp.Any) -> tp.Any: + return val if val is not None else d + + +def ema_inplace(moving_avg, new, decay: float): + moving_avg.data.mul_(decay).add_(new, alpha=(1 - decay)) + + +def laplace_smoothing(x, n_categories: int, epsilon: float = 1e-5): + return (x + epsilon) / (x.sum() + n_categories * epsilon) + + +def uniform_init(*shape: int): + t = torch.empty(shape) + nn.init.kaiming_uniform_(t) + return t + + +def sample_vectors(samples, num: int): + num_samples, device = samples.shape[0], samples.device + + if num_samples >= num: + indices = torch.randperm(num_samples, device=device)[:num] + else: + indices = torch.randint(0, num_samples, (num,), device=device) + + return samples[indices] + + +def kmeans(samples, num_clusters: int, num_iters: int = 10): + dim, dtype = samples.shape[-1], samples.dtype + + means = sample_vectors(samples, num_clusters) + + for _ in range(num_iters): + diffs = rearrange(samples, "n d -> n () d") - rearrange( + means, "c d -> () c d" + ) + dists = -(diffs ** 2).sum(dim=-1) + + buckets = dists.max(dim=-1).indices + bins = torch.bincount(buckets, minlength=num_clusters) + zero_mask = bins == 0 + bins_min_clamped = bins.masked_fill(zero_mask, 1) + + new_means = buckets.new_zeros(num_clusters, dim, dtype=dtype) + new_means.scatter_add_(0, repeat(buckets, "n -> n d", d=dim), samples) + new_means = new_means / bins_min_clamped[..., None] + + means = torch.where(zero_mask[..., None], means, new_means) + + return means, bins + + +class EuclideanCodebook(nn.Module): + """Codebook with Euclidean distance. + Args: + dim (int): Dimension. + codebook_size (int): Codebook size. + kmeans_init (bool): Whether to use k-means to initialize the codebooks. + If set to true, run the k-means algorithm on the first training batch and use + the learned centroids as initialization. + kmeans_iters (int): Number of iterations used for k-means algorithm at initialization. + decay (float): Decay for exponential moving average over the codebooks. + epsilon (float): Epsilon value for numerical stability. + threshold_ema_dead_code (int): Threshold for dead code expiration. Replace any codes + that have an exponential moving average cluster size less than the specified threshold with + randomly selected vector from the current batch. 
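+
+    A rough usage sketch (shapes are illustrative; `[B, T, D] = [64, 75, 128]`
+    matches the shape comments used elsewhere in this file):
+
+        codebook = EuclideanCodebook(dim=128, codebook_size=1024)
+        x = torch.randn(64, 75, 128)         # [B, T, D]
+        quantize, embed_ind = codebook(x)    # [B, T, D], [B, T]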
+ """ + def __init__( + self, + dim: int, + codebook_size: int, + kmeans_init: int = False, + kmeans_iters: int = 10, + decay: float = 0.99, + epsilon: float = 1e-5, + threshold_ema_dead_code: int = 2, + ): + super().__init__() + self.decay = decay + init_fn: tp.Union[tp.Callable[..., torch.Tensor], tp.Any] = uniform_init if not kmeans_init else torch.zeros + embed = init_fn(codebook_size, dim) + + self.codebook_size = codebook_size + + self.kmeans_iters = kmeans_iters + self.epsilon = epsilon + self.threshold_ema_dead_code = threshold_ema_dead_code + + self.register_buffer("inited", torch.Tensor([not kmeans_init])) + self.register_buffer("cluster_size", torch.zeros(codebook_size)) + self.register_buffer("embed", embed) + self.register_buffer("embed_avg", embed.clone()) + + @torch.jit.ignore + def init_embed_(self, data): + if self.inited: + return + + embed, cluster_size = kmeans(data, self.codebook_size, self.kmeans_iters) #data不变 + self.embed.data.copy_(embed) + self.embed_avg.data.copy_(embed.clone()) + self.cluster_size.data.copy_(cluster_size) + self.inited.data.copy_(torch.Tensor([True])) + # Make sure all buffers across workers are in sync after initialization + distrib.broadcast_tensors(self.buffers()) + + def replace_(self, samples, mask): + modified_codebook = torch.where( + mask[..., None], sample_vectors(samples, self.codebook_size), self.embed + ) + self.embed.data.copy_(modified_codebook) + + def expire_codes_(self, batch_samples): + if self.threshold_ema_dead_code == 0: + return + + expired_codes = self.cluster_size < self.threshold_ema_dead_code + if not torch.any(expired_codes): + return + + batch_samples = rearrange(batch_samples, "... d -> (...) d") + self.replace_(batch_samples, mask=expired_codes) + distrib.broadcast_tensors(self.buffers()) + + def preprocess(self, x): + x = rearrange(x, "... d -> (...) d") + return x + + def quantize(self, x): + embed = self.embed.t() + dist = -( + x.pow(2).sum(1, keepdim=True) + - 2 * x @ embed + + embed.pow(2).sum(0, keepdim=True) + ) + embed_ind = dist.max(dim=-1).indices + return embed_ind + + def postprocess_emb(self, embed_ind, shape): + return embed_ind.view(*shape[:-1]) + + def dequantize(self, embed_ind): + quantize = F.embedding(embed_ind, self.embed) + return quantize + + def encode(self, x): + shape = x.shape + # pre-process + x = self.preprocess(x) + # quantize + embed_ind = self.quantize(x) + # post-process + embed_ind = self.postprocess_emb(embed_ind, shape) + return embed_ind + + def decode(self, embed_ind): + quantize = self.dequantize(embed_ind) + return quantize + + def forward(self, x): + shape, dtype = x.shape, x.dtype + x = self.preprocess(x) + + self.init_embed_(x) + + embed_ind = self.quantize(x) + embed_onehot = F.one_hot(embed_ind, self.codebook_size).type(dtype) + embed_ind = self.postprocess_emb(embed_ind, shape) + quantize = self.dequantize(embed_ind) + + if self.training: + # We do the expiry of code at that point as buffers are in sync + # and all the workers will take the same decision. + self.expire_codes_(x) + ema_inplace(self.cluster_size, embed_onehot.sum(0), self.decay) + embed_sum = x.t() @ embed_onehot + ema_inplace(self.embed_avg, embed_sum.t(), self.decay) + cluster_size = ( + laplace_smoothing(self.cluster_size, self.codebook_size, self.epsilon) + * self.cluster_size.sum() + ) + embed_normalized = self.embed_avg / cluster_size.unsqueeze(1) + self.embed.data.copy_(embed_normalized) + + return quantize, embed_ind + + +class VectorQuantization(nn.Module): + """Vector quantization implementation. 
+    Currently supports only euclidean distance.
+    Args:
+        dim (int): Dimension
+        codebook_size (int): Codebook size
+        codebook_dim (int): Codebook dimension. If not defined, uses the specified dimension in dim.
+        decay (float): Decay for exponential moving average over the codebooks.
+        epsilon (float): Epsilon value for numerical stability.
+        kmeans_init (bool): Whether to use kmeans to initialize the codebooks.
+        kmeans_iters (int): Number of iterations used for kmeans initialization.
+        threshold_ema_dead_code (int): Threshold for dead code expiration. Replace any codes
+            that have an exponential moving average cluster size less than the specified threshold with
+            a randomly selected vector from the current batch.
+        commitment_weight (float): Weight for commitment loss.
+    """
+    def __init__(
+        self,
+        dim: int,
+        codebook_size: int,
+        codebook_dim: tp.Optional[int] = None,
+        decay: float = 0.99,
+        epsilon: float = 1e-5,
+        kmeans_init: bool = True,
+        kmeans_iters: int = 50,
+        threshold_ema_dead_code: int = 2,
+        commitment_weight: float = 1.,
+    ):
+        super().__init__()
+        _codebook_dim: int = default(codebook_dim, dim)
+
+        requires_projection = _codebook_dim != dim
+        self.project_in = (nn.Linear(dim, _codebook_dim) if requires_projection else nn.Identity())
+        self.project_out = (nn.Linear(_codebook_dim, dim) if requires_projection else nn.Identity())
+
+        self.epsilon = epsilon
+        self.commitment_weight = commitment_weight
+
+        self._codebook = EuclideanCodebook(dim=_codebook_dim, codebook_size=codebook_size,
+                                           kmeans_init=kmeans_init, kmeans_iters=kmeans_iters,
+                                           decay=decay, epsilon=epsilon,
+                                           threshold_ema_dead_code=threshold_ema_dead_code)
+        self.codebook_size = codebook_size
+
+    @property
+    def codebook(self):
+        return self._codebook.embed
+
+    def encode(self, x):
+        x = rearrange(x, "b d n -> b n d")
+        x = self.project_in(x)
+        embed_in = self._codebook.encode(x)
+        return embed_in
+
+    def decode(self, embed_ind):
+        quantize = self._codebook.decode(embed_ind)
+        quantize = self.project_out(quantize)
+        quantize = rearrange(quantize, "b n d -> b d n")
+        return quantize
+
+    def forward(self, x):
+        device = x.device
+        x = rearrange(x, "b d n -> b n d")
+        x = self.project_in(x)
+        quantize, embed_ind = self._codebook(x)
+        if self.training:
+            # Straight-through estimator: gradients flow to `x` as if no
+            # quantization had happened.
+            quantize = x + (quantize - x).detach()
+        loss = torch.tensor([0.0], device=device, requires_grad=self.training)
+
+        if self.training:
+            # warnings.warn('When using RVQ in training model, first check '
+            #               'https://github.com/facebookresearch/encodec/issues/25 . '
+            #               'The bug wasn\'t fixed here for reproducibility.')
+            if self.commitment_weight > 0:
+                commit_loss = F.mse_loss(quantize.detach(), x)
+                loss = loss + commit_loss * self.commitment_weight
+
+        quantize = self.project_out(quantize)
+        quantize = rearrange(quantize, "b n d -> b d n")
+        return quantize, embed_ind, loss
+
+
+class ResidualVectorQuantization(nn.Module):
+    """Residual vector quantization implementation.
+    Follows Algorithm 1 in https://arxiv.org/pdf/2107.03312.pdf
+    """
+    def __init__(self, *, num_quantizers, **kwargs):
+        super().__init__()
+        self.layers = nn.ModuleList(
+            [VectorQuantization(**kwargs) for _ in range(num_quantizers)]
+        )
+
+    def forward(self, x, n_q: tp.Optional[int] = None):
+        quantized_out = 0.0
+        residual = x
+
+        all_losses = []
+        all_indices = []
+
+        n_q = n_q or len(self.layers)
+        for layer in self.layers[:n_q]:
+            quantized, indices, loss = layer(residual)
+            residual = residual - quantized.detach()
+            quantized_out = quantized_out + quantized
+            all_indices.append(indices)
+            all_losses.append(loss)
+
+        out_losses, out_indices = map(torch.stack, (all_losses, all_indices))
+        return quantized_out, out_indices, out_losses
+
+    def encode(self, x: torch.Tensor, n_q: tp.Optional[int] = None) -> torch.Tensor:
+        residual = x
+        all_indices = []
+        n_q = n_q or len(self.layers)
+        for layer in self.layers[:n_q]:
+            indices = layer.encode(residual)
+            all_indices.append(indices)
+            quantized = layer.decode(indices)
+            residual = residual - quantized.detach()
+        out_indices = torch.stack(all_indices)
+        return out_indices
+
+    def decode(self, q_indices: torch.Tensor) -> torch.Tensor:
+        quantized_out = torch.tensor(0.0, device=q_indices.device)
+        for i, indices in enumerate(q_indices):
+            layer = self.layers[i]
+            quantized = layer.decode(indices)
+            quantized_out = quantized_out + quantized
+        return quantized_out
+
+
+class LanguageVectorQuantization(nn.Module):
+    """Variant of `ResidualVectorQuantization` used for language-token features.
+
+    Unlike the residual version, `forward` feeds the *same* input to every layer
+    (no residual subtraction) and returns only the last layer's quantized output,
+    together with the indices and losses of all layers. `encode`/`decode` still
+    follow the residual scheme of Algorithm 1 in https://arxiv.org/pdf/2107.03312.pdf
+    """
+    def __init__(self, *, num_quantizers, **kwargs):
+        super().__init__()
+        self.layers = nn.ModuleList(
+            [VectorQuantization(**kwargs) for _ in range(num_quantizers)]
+        )
+
+    def forward(self, x, n_q: tp.Optional[int] = None):
+        # x: [B, T, C], e.g. [64, 75, 128]
+        quantized_out = 0.0
+        residual = x
+
+        all_losses = []
+        all_indices = []
+
+        n_q = n_q or len(self.layers)
+
+        for layer in self.layers[:n_q]:
+            # Get this layer's representation, indices ([64, 75]) and loss.
+            # Note: `residual` is deliberately not updated here, so every layer
+            # quantizes the same input and `quantized_out` keeps only the last
+            # layer's output.
+            quantized_out, indices, loss = layer(residual)
+            all_indices.append(indices)
+            all_losses.append(loss)
+
+        out_losses, out_indices = map(torch.stack, (all_losses, all_indices))
+        return quantized_out, out_indices, out_losses
+
+    def encode(self, x: torch.Tensor, n_q: tp.Optional[int] = None) -> torch.Tensor:
+        residual = x
+        all_indices = []
+        n_q = n_q or len(self.layers)
+        for layer in self.layers[:n_q]:
+            indices = layer.encode(residual)
+            all_indices.append(indices)
+            quantized = layer.decode(indices)
+            residual = residual - quantized.detach()
+        out_indices = torch.stack(all_indices)
+        return out_indices
+
+    def decode(self, q_indices: torch.Tensor) -> torch.Tensor:
+        quantized_out = torch.tensor(0.0, device=q_indices.device)
+        for i, indices in enumerate(q_indices):
+            layer = self.layers[i]
+            quantized = layer.decode(indices)
+            quantized_out = quantized_out + quantized
+        return quantized_out
\ No newline at end of file
diff --git a/encoder/quantization/vq.py b/encoder/quantization/vq.py
new file mode 100644
index 0000000000000000000000000000000000000000..f8e316b4bf912c2a743cd27fe038a17e85bceb13
--- /dev/null
+++ b/encoder/quantization/vq.py
@@ -0,0 +1,172 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
+
+"""Residual vector quantizer implementation."""
+
+from dataclasses import dataclass, field
+import math
+import typing as tp
+
+import torch
+from torch import nn
+
+from .core_vq import ResidualVectorQuantization, LanguageVectorQuantization
+
+
+@dataclass
+class QuantizedResult:
+    quantized: torch.Tensor
+    codes: torch.Tensor
+    bandwidth: torch.Tensor  # bandwidth in kb/s used, per batch item.
+    penalty: tp.Optional[torch.Tensor] = None
+    metrics: dict = field(default_factory=dict)
+
+
+class ResidualVectorQuantizer(nn.Module):
+    """Residual Vector Quantizer.
+    Args:
+        dimension (int): Dimension of the codebooks.
+        n_q (int): Number of residual vector quantizers used.
+        bins (int): Codebook size.
+        decay (float): Decay for exponential moving average over the codebooks.
+        kmeans_init (bool): Whether to use kmeans to initialize the codebooks.
+        kmeans_iters (int): Number of iterations used for kmeans initialization.
+        threshold_ema_dead_code (int): Threshold for dead code expiration. Replace any codes
+            that have an exponential moving average cluster size less than the specified threshold with
+            a randomly selected vector from the current batch.
+    """
+    def __init__(
+        self,
+        dimension: int = 256,
+        n_q: int = 8,
+        bins: int = 1024,
+        decay: float = 0.99,
+        kmeans_init: bool = True,
+        kmeans_iters: int = 50,
+        threshold_ema_dead_code: int = 2,
+    ):
+        super().__init__()
+        self.n_q = n_q
+        self.dimension = dimension
+        self.bins = bins
+        self.decay = decay
+        self.kmeans_init = kmeans_init
+        self.kmeans_iters = kmeans_iters
+        self.threshold_ema_dead_code = threshold_ema_dead_code
+
+        self.vq = LanguageVectorQuantization(
+            dim=self.dimension,
+            codebook_size=self.bins,
+            num_quantizers=self.n_q,
+            decay=self.decay,
+            kmeans_init=self.kmeans_init,
+            kmeans_iters=self.kmeans_iters,
+            threshold_ema_dead_code=self.threshold_ema_dead_code,
+        )
+        # The original residual variant is kept for reference:
+        # self.vq = ResidualVectorQuantization(
+        #     dim=self.dimension,
+        #     codebook_size=self.bins,
+        #     num_quantizers=self.n_q,
+        #     decay=self.decay,
+        #     kmeans_init=self.kmeans_init,
+        #     kmeans_iters=self.kmeans_iters,
+        #     threshold_ema_dead_code=self.threshold_ema_dead_code,
+        # )
+
+    def forward(self, x: torch.Tensor, frame_rate: int, bandwidth: tp.Optional[float] = None) -> QuantizedResult:
+        """Residual vector quantization on the given input tensor.
+        Args:
+            x (torch.Tensor): Input tensor.
+            frame_rate (int): Frame rate of the input tensor (codes per second).
+            bandwidth (float): Target bandwidth.
+        Returns:
+            QuantizedResult:
+                The quantized (or approximately quantized) representation with
+                the associated bandwidth and any penalty term for the loss.
+        """
+        bw_per_q = self.get_bandwidth_per_quantizer(frame_rate)
+        n_q = self.get_num_quantizers_for_bandwidth(frame_rate, bandwidth)
+        # During training the number of quantizers is sampled from `nq_choice`,
+        # so the model learns to operate at several bitrates.
+        nq_choice = [4, 6, 8]
+        if self.training:
+            choice = int(torch.randint(0, 3, (1,)).item())
+            n_q = nq_choice[choice]
+        quantized, codes, commit_loss = self.vq(x, n_q=n_q)
+        bw = torch.tensor(n_q * bw_per_q).to(x)
+        return QuantizedResult(quantized, codes, bw, penalty=torch.mean(commit_loss))
+
+    def infer(self, x: torch.Tensor, frame_rate: int, bandwidth: tp.Optional[float] = None) -> QuantizedResult:
+        """Residual vector quantization on the given input tensor, at inference time.
+        Args:
+            x (torch.Tensor): Input tensor.
+            frame_rate (int): Frame rate of the input tensor (codes per second).
+            bandwidth (float): Target bandwidth.
+        Returns:
+            QuantizedResult:
+                The quantized (or approximately quantized) representation with
+                the associated bandwidth and any penalty term for the loss.
+        """
+        bw_per_q = self.get_bandwidth_per_quantizer(frame_rate)
+        # At inference time a single quantizer is used; the bandwidth-based
+        # selection is kept for reference:
+        # n_q = self.get_num_quantizers_for_bandwidth(frame_rate, bandwidth)
+        n_q = 1
+        quantized, codes, commit_loss = self.vq(x, n_q=n_q)
+        bw = torch.tensor(n_q * bw_per_q).to(x)
+        return QuantizedResult(quantized, codes, bw, penalty=torch.mean(commit_loss))
+
+    def get_num_quantizers_for_bandwidth(self, frame_rate: int, bandwidth: tp.Optional[float] = None) -> int:
+        """Return n_q based on the specified target bandwidth.
+        """
+        bw_per_q = self.get_bandwidth_per_quantizer(frame_rate)
+        n_q = self.n_q
+        if bandwidth and bandwidth > 0.:
+            # `bandwidth` is expressed in kbps while `bw_per_q` is in bits per second,
+            # hence the factor 1000, e.g. 6 kbps is passed as bandwidth == 6.0.
+            n_q = int(max(1, math.floor(bandwidth * 1000 / bw_per_q)))
+        return n_q
+
+    def get_bandwidth_per_quantizer(self, frame_rate: int):
+        """Return bandwidth per quantizer for a given input frame rate.
+        Each quantizer encodes a frame with log2(bins) bits.
+        """
+        return math.log2(self.bins) * frame_rate
+
+    def encode(self, x: torch.Tensor, frame_rate: int, bandwidth: tp.Optional[float] = None) -> torch.Tensor:
+        """Encode a given input tensor with the specified frame rate at the given bandwidth.
+        The RVQ encode method sets the appropriate number of quantizers to use
+        and returns indices for each quantizer.
+        """
+        n_q = self.get_num_quantizers_for_bandwidth(frame_rate, bandwidth)
+        codes = self.vq.encode(x, n_q=n_q)
+        return codes
+
+    def decode(self, codes: torch.Tensor) -> torch.Tensor:
+        """Decode the given codes to the quantized representation.
+        """
+        quantized = self.vq.decode(codes)
+        return quantized
diff --git a/encoder/utils.py b/encoder/utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..e3f0f9e9bcb37f2267b2f8adefabfc3672453dc5
--- /dev/null
+++ b/encoder/utils.py
@@ -0,0 +1,103 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
+
+"""Various utilities."""
+
+from hashlib import sha256
+from pathlib import Path
+import typing as tp
+
+import torch
+import torchaudio
+
+
+def _linear_overlap_add(frames: tp.List[torch.Tensor], stride: int):
+    # Generic overlap add, with linear fade-in/fade-out, supporting complex scenarios
+    # e.g., more than 2 frames per position.
+    # The core idea is to use a weight function that is a triangle,
+    # with a maximum value at the middle of the segment.
+    # We use this weighting when summing the frames, and divide by the sum of weights
+    # for each position at the end. Thus:
+    #   - if a frame is the only one to cover a position, the weighting is a no-op.
+    #   - if 2 frames cover a position:
+    #          ...  ...
+    #         /    \/    \
+    #        /    /\      \
+    #            S  T       , i.e. S offset where the second frame starts, T end of the first frame.
+    #     Then the weight function for each one is: (t - S), (T - t), with `t` a given offset.
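+    #     For example, with S = 2 and T = 6, at t = 4 both weights equal 2, so the
+    #     two frames contribute equally at the midpoint of the overlap.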
+ # After the final normalization, the weight of the second frame at position `t` is + # (t - S) / (t - S + (T - t)) = (t - S) / (T - S), which is exactly what we want. + # + # - if more than 2 frames overlap at a given point, we hope that by induction + # something sensible happens. + assert len(frames) + device = frames[0].device + dtype = frames[0].dtype + shape = frames[0].shape[:-1] + total_size = stride * (len(frames) - 1) + frames[-1].shape[-1] + + frame_length = frames[0].shape[-1] + t = torch.linspace(0, 1, frame_length + 2, device=device, dtype=dtype)[1: -1] + weight = 0.5 - (t - 0.5).abs() + + sum_weight = torch.zeros(total_size, device=device, dtype=dtype) + out = torch.zeros(*shape, total_size, device=device, dtype=dtype) + offset: int = 0 + + for frame in frames: + frame_length = frame.shape[-1] + out[..., offset:offset + frame_length] += weight[:frame_length] * frame + sum_weight[offset:offset + frame_length] += weight[:frame_length] + offset += stride + assert sum_weight.min() > 0 + return out / sum_weight + + +def _get_checkpoint_url(root_url: str, checkpoint: str): + if not root_url.endswith('/'): + root_url += '/' + return root_url + checkpoint + + +def _check_checksum(path: Path, checksum: str): + sha = sha256() + with open(path, 'rb') as file: + while True: + buf = file.read(2**20) + if not buf: + break + sha.update(buf) + actual_checksum = sha.hexdigest()[:len(checksum)] + if actual_checksum != checksum: + raise RuntimeError(f'Invalid checksum for file {path}, ' + f'expected {checksum} but got {actual_checksum}') + + +def convert_audio(wav: torch.Tensor, sr: int, target_sr: int, target_channels: int): + assert wav.dim() >= 2, "Audio tensor must have at least 2 dimensions" + assert wav.shape[-2] in [1, 2], "Audio must be mono or stereo." + *shape, channels, length = wav.shape + if target_channels == 1: + wav = wav.mean(-2, keepdim=True) + elif target_channels == 2: + wav = wav.expand(*shape, target_channels, length) + elif channels == 1: + wav = wav.expand(target_channels, -1) + else: + raise RuntimeError(f"Impossible to convert from {channels} to {target_channels}") + wav = torchaudio.transforms.Resample(sr, target_sr)(wav) + return wav + + +def save_audio(wav: torch.Tensor, path: tp.Union[Path, str], + sample_rate: int, rescale: bool = False): + limit = 0.99 + mx = wav.abs().max() + if rescale: + wav = wav * min(limit / mx, 1) + else: + wav = wav.clamp(-limit, limit) + torchaudio.save(str(path), wav, sample_rate=sample_rate, encoding='PCM_S', bits_per_sample=16)
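+
+
+def test():
+    # Illustrative smoke test, added here for documentation only (not part of
+    # the upstream EnCodec utils); it mirrors the `test()` convention of the
+    # other modules and exercises just the helpers defined above.
+    frames = [torch.randn(1, 480) for _ in range(10)]
+    out = _linear_overlap_add(frames, stride=240)
+    assert out.shape == (1, 240 * 9 + 480)
+
+    # Stereo 24 kHz -> mono 16 kHz.
+    wav = convert_audio(torch.randn(2, 24000), sr=24000, target_sr=16000, target_channels=1)
+    assert wav.shape == (1, 16000)
+
+
+if __name__ == '__main__':
+    test()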