Upload 54 files

This view is limited to 50 files because the commit contains too many changes; see the raw diff for the full set.
- encoder/__init__.py +12 -0
- encoder/__pycache__/__init__.cpython-310.pyc +0 -0
- encoder/__pycache__/__init__.cpython-38.pyc +0 -0
- encoder/__pycache__/__init__.cpython-39.pyc +0 -0
- encoder/__pycache__/distrib.cpython-310.pyc +0 -0
- encoder/__pycache__/distrib.cpython-38.pyc +0 -0
- encoder/__pycache__/distrib.cpython-39.pyc +0 -0
- encoder/__pycache__/model.cpython-310.pyc +0 -0
- encoder/__pycache__/model.cpython-38.pyc +0 -0
- encoder/__pycache__/model.cpython-39.pyc +0 -0
- encoder/__pycache__/utils.cpython-310.pyc +0 -0
- encoder/__pycache__/utils.cpython-38.pyc +0 -0
- encoder/__pycache__/utils.cpython-39.pyc +0 -0
- encoder/distrib.py +124 -0
- encoder/model.py +324 -0
- encoder/modules/__init__.py +22 -0
- encoder/modules/__pycache__/__init__.cpython-310.pyc +0 -0
- encoder/modules/__pycache__/__init__.cpython-38.pyc +0 -0
- encoder/modules/__pycache__/__init__.cpython-39.pyc +0 -0
- encoder/modules/__pycache__/conv.cpython-310.pyc +0 -0
- encoder/modules/__pycache__/conv.cpython-38.pyc +0 -0
- encoder/modules/__pycache__/conv.cpython-39.pyc +0 -0
- encoder/modules/__pycache__/lstm.cpython-310.pyc +0 -0
- encoder/modules/__pycache__/lstm.cpython-38.pyc +0 -0
- encoder/modules/__pycache__/lstm.cpython-39.pyc +0 -0
- encoder/modules/__pycache__/norm.cpython-310.pyc +0 -0
- encoder/modules/__pycache__/norm.cpython-38.pyc +0 -0
- encoder/modules/__pycache__/norm.cpython-39.pyc +0 -0
- encoder/modules/__pycache__/seanet.cpython-310.pyc +0 -0
- encoder/modules/__pycache__/seanet.cpython-38.pyc +0 -0
- encoder/modules/__pycache__/seanet.cpython-39.pyc +0 -0
- encoder/modules/__pycache__/transformer.cpython-310.pyc +0 -0
- encoder/modules/__pycache__/transformer.cpython-38.pyc +0 -0
- encoder/modules/__pycache__/transformer.cpython-39.pyc +0 -0
- encoder/modules/conv.py +253 -0
- encoder/modules/lstm.py +39 -0
- encoder/modules/norm.py +28 -0
- encoder/modules/seanet.py +253 -0
- encoder/modules/transformer.py +119 -0
- encoder/msstftd.py +147 -0
- encoder/quantization/__init__.py +8 -0
- encoder/quantization/__pycache__/__init__.cpython-310.pyc +0 -0
- encoder/quantization/__pycache__/__init__.cpython-38.pyc +0 -0
- encoder/quantization/__pycache__/__init__.cpython-39.pyc +0 -0
- encoder/quantization/__pycache__/core_vq.cpython-310.pyc +0 -0
- encoder/quantization/__pycache__/core_vq.cpython-38.pyc +0 -0
- encoder/quantization/__pycache__/core_vq.cpython-39.pyc +0 -0
- encoder/quantization/__pycache__/vq.cpython-310.pyc +0 -0
- encoder/quantization/__pycache__/vq.cpython-38.pyc +0 -0
- encoder/quantization/__pycache__/vq.cpython-39.pyc +0 -0
encoder/__init__.py
ADDED
@@ -0,0 +1,12 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
# flake8: noqa

"""EnCodec neural audio codec."""

__version__ = "0.1.2a3"

from .model import EncodecModel
encoder/__pycache__/__init__.cpython-310.pyc
ADDED (binary file, 268 Bytes)

encoder/__pycache__/__init__.cpython-38.pyc
ADDED (binary file, 278 Bytes)

encoder/__pycache__/__init__.cpython-39.pyc
ADDED (binary file, 266 Bytes)

encoder/__pycache__/distrib.cpython-310.pyc
ADDED (binary file, 3.74 kB)

encoder/__pycache__/distrib.cpython-38.pyc
ADDED (binary file, 3.79 kB)

encoder/__pycache__/distrib.cpython-39.pyc
ADDED (binary file, 3.76 kB)

encoder/__pycache__/model.cpython-310.pyc
ADDED (binary file, 11.7 kB)

encoder/__pycache__/model.cpython-38.pyc
ADDED (binary file, 11.7 kB)

encoder/__pycache__/model.cpython-39.pyc
ADDED (binary file, 11.6 kB)

encoder/__pycache__/utils.cpython-310.pyc
ADDED (binary file, 2.68 kB)

encoder/__pycache__/utils.cpython-38.pyc
ADDED (binary file, 2.65 kB)

encoder/__pycache__/utils.cpython-39.pyc
ADDED (binary file, 2.66 kB)
encoder/distrib.py
ADDED
@@ -0,0 +1,124 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.

"""Torch distributed utilities."""

import typing as tp

import torch


def rank():
    if torch.distributed.is_initialized():
        return torch.distributed.get_rank()
    else:
        return 0


def world_size():
    if torch.distributed.is_initialized():
        return torch.distributed.get_world_size()
    else:
        return 1


def is_distributed():
    return world_size() > 1


def all_reduce(tensor: torch.Tensor, op=torch.distributed.ReduceOp.SUM):
    if is_distributed():
        return torch.distributed.all_reduce(tensor, op)


def _is_complex_or_float(tensor):
    return torch.is_floating_point(tensor) or torch.is_complex(tensor)


def _check_number_of_params(params: tp.List[torch.Tensor]):
    # utility function to check that the number of params in all workers is the same,
    # and thus avoid a deadlock with distributed all reduce.
    if not is_distributed() or not params:
        return
    tensor = torch.tensor([len(params)], device=params[0].device, dtype=torch.long)
    all_reduce(tensor)
    if tensor.item() != len(params) * world_size():
        # If not all the workers have the same number, for at least one of them,
        # this inequality will be verified.
        raise RuntimeError(f"Mismatch in number of params: ours is {len(params)}, "
                           "at least one worker has a different one.")


def broadcast_tensors(tensors: tp.Iterable[torch.Tensor], src: int = 0):
    """Broadcast the tensors from the given parameters to all workers.
    This can be used to ensure that all workers have the same model to start with.
    """
    if not is_distributed():
        return
    tensors = [tensor for tensor in tensors if _is_complex_or_float(tensor)]
    _check_number_of_params(tensors)
    handles = []
    for tensor in tensors:
        handle = torch.distributed.broadcast(tensor.data, src=src, async_op=True)
        handles.append(handle)
    for handle in handles:
        handle.wait()


def sync_buffer(buffers, average=True):
    """
    Sync buffers across workers. If average is False, broadcast instead of averaging.
    """
    if not is_distributed():
        return
    handles = []
    for buffer in buffers:
        if torch.is_floating_point(buffer.data):
            if average:
                handle = torch.distributed.all_reduce(
                    buffer.data, op=torch.distributed.ReduceOp.SUM, async_op=True)
            else:
                handle = torch.distributed.broadcast(
                    buffer.data, src=0, async_op=True)
            handles.append((buffer, handle))
    for buffer, handle in handles:
        handle.wait()
        if average:
            buffer.data /= world_size()


def sync_grad(params):
    """
    Simpler alternative to DistributedDataParallel, that doesn't rely
    on any black magic. For simple models it can also be as fast.
    Just call this on your model parameters after the call to backward!
    """
    if not is_distributed():
        return
    handles = []
    for p in params:
        if p.grad is not None:
            handle = torch.distributed.all_reduce(
                p.grad.data, op=torch.distributed.ReduceOp.SUM, async_op=True)
            handles.append((p, handle))
    for p, handle in handles:
        handle.wait()
        p.grad.data /= world_size()


def average_metrics(metrics: tp.Dict[str, float], count=1.):
    """Average a dictionary of metrics across all workers, using the optional
    `count` as unnormalized weight.
    """
    if not is_distributed():
        return metrics
    keys, values = zip(*metrics.items())
    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    tensor = torch.tensor(list(values) + [1], device=device, dtype=torch.float32)
    tensor *= count
    all_reduce(tensor)
    averaged = (tensor[:-1] / tensor[-1]).cpu().tolist()
    return dict(zip(keys, averaged))
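These utilities are a lightweight substitute for `DistributedDataParallel`. A minimal sketch of how `broadcast_tensors` and `sync_grad` might slot into a training step; the model, optimizer, and process-group initialization here are hypothetical placeholders, not part of this upload, and every helper is a no-op when run single-process:

import torch
from torch import nn
from encoder import distrib

# Hypothetical setup: assumes torch.distributed was initialized elsewhere
# (e.g. via torch.distributed.init_process_group).
model = nn.Linear(16, 16)
opt = torch.optim.Adam(model.parameters())

# Make sure every worker starts from the same weights.
distrib.broadcast_tensors(model.parameters())

x = torch.randn(8, 16)
loss = model(x).pow(2).mean()
loss.backward()

# Average gradients across workers, then step as usual.
distrib.sync_grad(model.parameters())
opt.step()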
encoder/model.py
ADDED
@@ -0,0 +1,324 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.

"""EnCodec model implementation."""

import math
from pathlib import Path
import typing as tp

import numpy as np
import torch
from torch import nn

from . import quantization as qt
from . import modules as m
from .utils import _check_checksum, _linear_overlap_add, _get_checkpoint_url


ROOT_URL = 'https://dl.fbaipublicfiles.com/encodec/v0/'

EncodedFrame = tp.Tuple[torch.Tensor, tp.Optional[torch.Tensor]]


class LMModel(nn.Module):
    """Language Model to estimate probabilities of each codebook entry.
    We predict all codebooks in parallel for a given time step.

    Args:
        n_q (int): number of codebooks.
        card (int): codebook cardinality.
        dim (int): transformer dimension.
        **kwargs: passed to `encoder.modules.transformer.StreamingTransformerEncoder`.
    """
    def __init__(self, n_q: int = 32, card: int = 1024, dim: int = 200, **kwargs):
        super().__init__()
        self.card = card
        self.n_q = n_q
        self.dim = dim
        self.transformer = m.StreamingTransformerEncoder(dim=dim, **kwargs)
        self.emb = nn.ModuleList([nn.Embedding(card + 1, dim) for _ in range(n_q)])
        self.linears = nn.ModuleList([nn.Linear(dim, card) for _ in range(n_q)])

    def forward(self, indices: torch.Tensor,
                states: tp.Optional[tp.List[torch.Tensor]] = None, offset: int = 0):
        """
        Args:
            indices (torch.Tensor): indices from the previous time step. Indices
                should be 1 + actual index in the codebook. The value 0 is reserved for
                when the index is missing (i.e. first time step). Shape should be
                `[B, n_q, T]`.
            states: state for the streaming decoding.
            offset: offset of the current time step.

        Returns a 3-tuple `(probabilities, new_states, new_offset)`, with
        `probabilities` of shape `[B, card, n_q, T]`.
        """
        B, K, T = indices.shape
        input_ = sum([self.emb[k](indices[:, k]) for k in range(K)])
        out, states, offset = self.transformer(input_, states, offset)
        logits = torch.stack([self.linears[k](out) for k in range(K)], dim=1).permute(0, 3, 1, 2)
        return torch.softmax(logits, dim=1), states, offset


class EncodecModel(nn.Module):
    """EnCodec model operating on the raw waveform.
    Args:
        target_bandwidths (list of float): Target bandwidths.
        encoder (nn.Module): Encoder network.
        decoder (nn.Module): Decoder network.
        sample_rate (int): Audio sample rate.
        channels (int): Number of audio channels.
        normalize (bool): Whether to apply audio normalization.
        segment (float or None): segment duration in sec. when doing overlap-add.
        overlap (float): overlap between segments, given as a fraction of the segment duration.
        name (str): name of the model, used as metadata when compressing audio.
    """
    def __init__(self,
                 encoder: m.SEANetEncoder,
                 decoder: m.SEANetDecoder,
                 quantizer: qt.ResidualVectorQuantizer,
                 target_bandwidths: tp.List[float],
                 sample_rate: int,
                 channels: int,
                 normalize: bool = False,
                 segment: tp.Optional[float] = None,
                 overlap: float = 0.01,
                 name: str = 'unset'):
        super().__init__()
        self.bandwidth: tp.Optional[float] = None
        self.target_bandwidths = target_bandwidths
        self.encoder = encoder
        self.quantizer = quantizer
        self.decoder = decoder
        self.sample_rate = sample_rate
        self.channels = channels
        self.normalize = normalize
        self.segment = segment
        self.overlap = overlap
        self.frame_rate = math.ceil(self.sample_rate / np.prod(self.encoder.ratios))
        self.name = name
        self.bits_per_codebook = int(math.log2(self.quantizer.bins))
        assert 2 ** self.bits_per_codebook == self.quantizer.bins, \
            "quantizer bins must be a power of 2."

    @property
    def segment_length(self) -> tp.Optional[int]:
        if self.segment is None:
            return None
        return int(self.segment * self.sample_rate)

    @property
    def segment_stride(self) -> tp.Optional[int]:
        segment_length = self.segment_length
        if segment_length is None:
            return None
        return max(1, int((1 - self.overlap) * segment_length))

    def encode(self, x: torch.Tensor) -> tp.List[EncodedFrame]:
        """Given a tensor `x`, returns a list of frames containing
        the discrete encoded codes for `x`, along with rescaling factors
        for each segment, when `self.normalize` is True.

        Each frame is a tuple `(codebook, scale)`, with `codebook` of
        shape `[B, K, T]`, with `K` the number of codebooks.
        """
        assert x.dim() == 3
        _, channels, length = x.shape
        assert channels > 0 and channels <= 2
        segment_length = self.segment_length
        if segment_length is None:
            segment_length = length
            stride = length
        else:
            stride = self.segment_stride  # type: ignore
            assert stride is not None

        encoded_frames: tp.List[EncodedFrame] = []
        for offset in range(0, length, stride):
            frame = x[:, :, offset: offset + segment_length]
            encoded_frames.append(self._encode_frame(frame))
        return encoded_frames

    def _encode_frame(self, x: torch.Tensor) -> EncodedFrame:
        length = x.shape[-1]
        duration = length / self.sample_rate
        assert self.segment is None or duration <= 1e-5 + self.segment

        if self.normalize:
            mono = x.mean(dim=1, keepdim=True)
            volume = mono.pow(2).mean(dim=2, keepdim=True).sqrt()
            scale = 1e-8 + volume
            x = x / scale
            scale = scale.view(-1, 1)
        else:
            scale = None

        emb = self.encoder(x)
        codes = self.quantizer.encode(emb, self.frame_rate, self.bandwidth)
        codes = codes.transpose(0, 1)
        # codes is [B, K, T], with T frames, K nb of codebooks.
        return codes, scale

    def decode(self, encoded_frames: tp.List[EncodedFrame]) -> torch.Tensor:
        """Decode the given frames into a waveform.
        Note that the output might be a bit bigger than the input. In that case,
        any extra steps at the end can be trimmed.
        """
        segment_length = self.segment_length
        if segment_length is None:
            assert len(encoded_frames) == 1
            return self._decode_frame(encoded_frames[0])

        frames = [self._decode_frame(frame) for frame in encoded_frames]
        return _linear_overlap_add(frames, self.segment_stride or 1)

    def _decode_frame(self, encoded_frame: EncodedFrame) -> torch.Tensor:
        codes, scale = encoded_frame
        codes = codes.transpose(0, 1)
        emb = self.quantizer.decode(codes)
        out = self.decoder(emb)
        if scale is not None:
            out = out * scale.view(-1, 1, 1)
        return out

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        frames = self.encode(x)
        return self.decode(frames)[:, :, :x.shape[-1]]

    def set_target_bandwidth(self, bandwidth: float):
        if bandwidth not in self.target_bandwidths:
            raise ValueError(f"This model doesn't support the bandwidth {bandwidth}. "
                             f"Select one of {self.target_bandwidths}.")
        self.bandwidth = bandwidth

    def get_lm_model(self) -> LMModel:
        """Return the associated LM model to improve the compression rate."""
        device = next(self.parameters()).device
        lm = LMModel(self.quantizer.n_q, self.quantizer.bins, num_layers=5, dim=200,
                     past_context=int(3.5 * self.frame_rate)).to(device)
        checkpoints = {
            'encodec_24khz': 'encodec_lm_24khz-1608e3c0.th',
            'encodec_48khz': 'encodec_lm_48khz-7add9fc3.th',
        }
        try:
            checkpoint_name = checkpoints[self.name]
        except KeyError:
            raise RuntimeError("No LM pre-trained for the current Encodec model.")
        url = _get_checkpoint_url(ROOT_URL, checkpoint_name)
        state = torch.hub.load_state_dict_from_url(
            url, map_location='cpu', check_hash=True)  # type: ignore
        lm.load_state_dict(state)
        lm.eval()
        return lm

    @staticmethod
    def _get_model(target_bandwidths: tp.List[float],
                   sample_rate: int = 24_000,
                   channels: int = 1,
                   causal: bool = True,
                   model_norm: str = 'weight_norm',
                   audio_normalize: bool = False,
                   segment: tp.Optional[float] = None,
                   name: str = 'unset'):
        encoder = m.SEANetEncoder(channels=channels, norm=model_norm, causal=causal)
        decoder = m.SEANetDecoder(channels=channels, norm=model_norm, causal=causal)
        n_q = int(1000 * target_bandwidths[-1] // (math.ceil(sample_rate / encoder.hop_length) * 10))
        quantizer = qt.ResidualVectorQuantizer(
            dimension=encoder.dimension,
            n_q=n_q,
            bins=1024,
        )
        model = EncodecModel(
            encoder,
            decoder,
            quantizer,
            target_bandwidths,
            sample_rate,
            channels,
            normalize=audio_normalize,
            segment=segment,
            name=name,
        )
        return model

    @staticmethod
    def _get_pretrained(checkpoint_name: str, repository: tp.Optional[Path] = None):
        if repository is not None:
            if not repository.is_dir():
                raise ValueError(f"{repository} must exist and be a directory.")
            file = repository / checkpoint_name
            checksum = file.stem.split('-')[1]
            _check_checksum(file, checksum)
            return torch.load(file)
        else:
            url = _get_checkpoint_url(ROOT_URL, checkpoint_name)
            return torch.hub.load_state_dict_from_url(url, map_location='cpu', check_hash=True)  # type: ignore

    @staticmethod
    def encodec_model_24khz(pretrained: bool = True, repository: tp.Optional[Path] = None):
        """Return the pretrained causal 24khz model."""
        if repository:
            assert pretrained
        target_bandwidths = [1.5, 3., 6., 12., 24.]
        checkpoint_name = 'encodec_24khz-d7cc33bc.th'
        sample_rate = 24_000
        channels = 1
        model = EncodecModel._get_model(
            target_bandwidths, sample_rate, channels,
            causal=True, model_norm='weight_norm', audio_normalize=False,
            name='encodec_24khz' if pretrained else 'unset')
        if pretrained:
            state_dict = EncodecModel._get_pretrained(checkpoint_name, repository)
            model.load_state_dict(state_dict)
        model.eval()
        return model

    @staticmethod
    def encodec_model_48khz(pretrained: bool = True, repository: tp.Optional[Path] = None):
        """Return the pretrained 48khz model."""
        if repository:
            assert pretrained
        target_bandwidths = [3., 6., 12., 24.]
        checkpoint_name = 'encodec_48khz-7e698e3e.th'
        sample_rate = 48_000
        channels = 2
        model = EncodecModel._get_model(
            target_bandwidths, sample_rate, channels,
            causal=False, model_norm='time_group_norm', audio_normalize=True,
            segment=1., name='encodec_48khz' if pretrained else 'unset')
        if pretrained:
            state_dict = EncodecModel._get_pretrained(checkpoint_name, repository)
            model.load_state_dict(state_dict)
        model.eval()
        return model


def test():
    from itertools import product
    import torchaudio
    bandwidths = [3, 6, 12, 24]
    models = {
        'encodec_24khz': EncodecModel.encodec_model_24khz,
        'encodec_48khz': EncodecModel.encodec_model_48khz
    }
    for model_name, bw in product(models.keys(), bandwidths):
        model = models[model_name]()
        model.set_target_bandwidth(bw)
        audio_suffix = model_name.split('_')[1][:3]
        wav, sr = torchaudio.load(f"test_{audio_suffix}.wav")
        wav = wav[:, :model.sample_rate * 2]
        wav_in = wav.unsqueeze(0)
        wav_dec = model(wav_in)[0]
        assert wav.shape == wav_dec.shape, (wav.shape, wav_dec.shape)


if __name__ == '__main__':
    test()
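For orientation, a hedged usage sketch of the API above; it assumes network access for the checkpoint download and substitutes random noise for a real recording:

import torch
from encoder.model import EncodecModel

model = EncodecModel.encodec_model_24khz()
model.set_target_bandwidth(6.0)    # kbps; must be one of model.target_bandwidths

wav = torch.randn(1, 1, 24000)     # placeholder [B, channels, samples]: one fake second
with torch.no_grad():
    frames = model.encode(wav)             # list of (codes, scale) frames
    codes, scale = frames[0]               # codes: [B, n_q, T] integer indices
    out = model.decode(frames)[:, :, :wav.shape[-1]]
assert out.shape == wav.shape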
encoder/modules/__init__.py
ADDED
@@ -0,0 +1,22 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.

"""Torch modules."""

# flake8: noqa
from .conv import (
    pad1d,
    unpad1d,
    NormConv1d,
    NormConvTranspose1d,
    NormConv2d,
    NormConvTranspose2d,
    SConv1d,
    SConvTranspose1d,
)
from .lstm import SLSTM
from .seanet import SEANetEncoder, SEANetDecoder
from .transformer import StreamingTransformerEncoder
encoder/modules/__pycache__/__init__.cpython-310.pyc
ADDED (binary file, 557 Bytes)

encoder/modules/__pycache__/__init__.cpython-38.pyc
ADDED (binary file, 567 Bytes)

encoder/modules/__pycache__/__init__.cpython-39.pyc
ADDED (binary file, 555 Bytes)

encoder/modules/__pycache__/conv.cpython-310.pyc
ADDED (binary file, 9.2 kB)

encoder/modules/__pycache__/conv.cpython-38.pyc
ADDED (binary file, 9.48 kB)

encoder/modules/__pycache__/conv.cpython-39.pyc
ADDED (binary file, 9.43 kB)

encoder/modules/__pycache__/lstm.cpython-310.pyc
ADDED (binary file, 1.05 kB)

encoder/modules/__pycache__/lstm.cpython-38.pyc
ADDED (binary file, 1.06 kB)

encoder/modules/__pycache__/lstm.cpython-39.pyc
ADDED (binary file, 1.05 kB)

encoder/modules/__pycache__/norm.cpython-310.pyc
ADDED (binary file, 1.15 kB)

encoder/modules/__pycache__/norm.cpython-38.pyc
ADDED (binary file, 1.15 kB)

encoder/modules/__pycache__/norm.cpython-39.pyc
ADDED (binary file, 1.14 kB)

encoder/modules/__pycache__/seanet.cpython-310.pyc
ADDED (binary file, 9.72 kB)

encoder/modules/__pycache__/seanet.cpython-38.pyc
ADDED (binary file, 9.65 kB)

encoder/modules/__pycache__/seanet.cpython-39.pyc
ADDED (binary file, 9.48 kB)

encoder/modules/__pycache__/transformer.cpython-310.pyc
ADDED (binary file, 4.54 kB)

encoder/modules/__pycache__/transformer.cpython-38.pyc
ADDED (binary file, 4.53 kB)

encoder/modules/__pycache__/transformer.cpython-39.pyc
ADDED (binary file, 4.47 kB)
encoder/modules/conv.py
ADDED
@@ -0,0 +1,253 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.

"""Convolutional layers wrappers and utilities."""

import math
import typing as tp
import warnings

import torch
from torch import nn
from torch.nn import functional as F
from torch.nn.utils import spectral_norm, weight_norm

from .norm import ConvLayerNorm


CONV_NORMALIZATIONS = frozenset(['none', 'weight_norm', 'spectral_norm',
                                 'time_layer_norm', 'layer_norm', 'time_group_norm'])


def apply_parametrization_norm(module: nn.Module, norm: str = 'none') -> nn.Module:
    assert norm in CONV_NORMALIZATIONS
    if norm == 'weight_norm':
        return weight_norm(module)
    elif norm == 'spectral_norm':
        return spectral_norm(module)
    else:
        # We already checked that `norm` is in CONV_NORMALIZATIONS, so any other
        # choice doesn't need reparametrization.
        return module


def get_norm_module(module: nn.Module, causal: bool = False, norm: str = 'none', **norm_kwargs) -> nn.Module:
    """Return the proper normalization module. If causal is True, this will ensure the returned
    module is causal, or return an error if the normalization doesn't support causal evaluation.
    """
    assert norm in CONV_NORMALIZATIONS
    if norm == 'layer_norm':
        assert isinstance(module, nn.modules.conv._ConvNd)
        return ConvLayerNorm(module.out_channels, **norm_kwargs)
    elif norm == 'time_group_norm':
        if causal:
            raise ValueError("GroupNorm doesn't support causal evaluation.")
        assert isinstance(module, nn.modules.conv._ConvNd)
        return nn.GroupNorm(1, module.out_channels, **norm_kwargs)
    else:
        return nn.Identity()


def get_extra_padding_for_conv1d(x: torch.Tensor, kernel_size: int, stride: int,
                                 padding_total: int = 0) -> int:
    """See `pad_for_conv1d`."""
    length = x.shape[-1]
    n_frames = (length - kernel_size + padding_total) / stride + 1
    ideal_length = (math.ceil(n_frames) - 1) * stride + (kernel_size - padding_total)
    return ideal_length - length


def pad_for_conv1d(x: torch.Tensor, kernel_size: int, stride: int, padding_total: int = 0):
    """Pad for a convolution to make sure that the last window is full.
    Extra padding is added at the end. This is required to ensure that we can rebuild
    an output of the same length, as otherwise, even with padding, some time steps
    might get removed.
    For instance, with total padding = 4, kernel size = 4, stride = 2:
        0 0 1 2 3 4 5 0 0   # (0s are padding)
        1   2   3           # (output frames of a convolution, last 0 is never used)
        0 0 1 2 3 4 5 0     # (output of tr. conv., but pos. 5 is going to get removed as padding)
            1 2 3 4         # once you remove padding, we are missing one time step!
    """
    extra_padding = get_extra_padding_for_conv1d(x, kernel_size, stride, padding_total)
    return F.pad(x, (0, extra_padding))


def pad1d(x: torch.Tensor, paddings: tp.Tuple[int, int], mode: str = 'zero', value: float = 0.):
    """Tiny wrapper around F.pad, just to allow for reflect padding on small input.
    If the input is too small, we insert extra 0 padding to the right before the reflection happens.
    """
    length = x.shape[-1]
    padding_left, padding_right = paddings
    assert padding_left >= 0 and padding_right >= 0, (padding_left, padding_right)
    if mode == 'reflect':
        max_pad = max(padding_left, padding_right)
        extra_pad = 0
        if length <= max_pad:
            extra_pad = max_pad - length + 1
            x = F.pad(x, (0, extra_pad))
        padded = F.pad(x, paddings, mode, value)
        end = padded.shape[-1] - extra_pad
        return padded[..., :end]
    else:
        return F.pad(x, paddings, mode, value)


def unpad1d(x: torch.Tensor, paddings: tp.Tuple[int, int]):
    """Remove padding from x, handling properly zero padding. Only for 1d!"""
    padding_left, padding_right = paddings
    assert padding_left >= 0 and padding_right >= 0, (padding_left, padding_right)
    assert (padding_left + padding_right) <= x.shape[-1]
    end = x.shape[-1] - padding_right
    return x[..., padding_left: end]


class NormConv1d(nn.Module):
    """Wrapper around Conv1d and normalization applied to this conv
    to provide a uniform interface across normalization approaches.
    """
    def __init__(self, *args, causal: bool = False, norm: str = 'none',
                 norm_kwargs: tp.Dict[str, tp.Any] = {}, **kwargs):
        super().__init__()
        self.conv = apply_parametrization_norm(nn.Conv1d(*args, **kwargs), norm)
        self.norm = get_norm_module(self.conv, causal, norm, **norm_kwargs)
        self.norm_type = norm

    def forward(self, x):
        x = self.conv(x)
        x = self.norm(x)
        return x


class NormConv2d(nn.Module):
    """Wrapper around Conv2d and normalization applied to this conv
    to provide a uniform interface across normalization approaches.
    """
    def __init__(self, *args, norm: str = 'none',
                 norm_kwargs: tp.Dict[str, tp.Any] = {}, **kwargs):
        super().__init__()
        self.conv = apply_parametrization_norm(nn.Conv2d(*args, **kwargs), norm)
        self.norm = get_norm_module(self.conv, causal=False, norm=norm, **norm_kwargs)
        self.norm_type = norm

    def forward(self, x):
        x = self.conv(x)
        x = self.norm(x)
        return x


class NormConvTranspose1d(nn.Module):
    """Wrapper around ConvTranspose1d and normalization applied to this conv
    to provide a uniform interface across normalization approaches.
    """
    def __init__(self, *args, causal: bool = False, norm: str = 'none',
                 norm_kwargs: tp.Dict[str, tp.Any] = {}, **kwargs):
        super().__init__()
        self.convtr = apply_parametrization_norm(nn.ConvTranspose1d(*args, **kwargs), norm)
        self.norm = get_norm_module(self.convtr, causal, norm, **norm_kwargs)
        self.norm_type = norm

    def forward(self, x):
        x = self.convtr(x)
        x = self.norm(x)
        return x


class NormConvTranspose2d(nn.Module):
    """Wrapper around ConvTranspose2d and normalization applied to this conv
    to provide a uniform interface across normalization approaches.
    """
    def __init__(self, *args, norm: str = 'none',
                 norm_kwargs: tp.Dict[str, tp.Any] = {}, **kwargs):
        super().__init__()
        self.convtr = apply_parametrization_norm(nn.ConvTranspose2d(*args, **kwargs), norm)
        self.norm = get_norm_module(self.convtr, causal=False, norm=norm, **norm_kwargs)

    def forward(self, x):
        x = self.convtr(x)
        x = self.norm(x)
        return x


class SConv1d(nn.Module):
    """Conv1d with some builtin handling of asymmetric or causal padding
    and normalization.
    """
    def __init__(self, in_channels: int, out_channels: int,
                 kernel_size: int, stride: int = 1, dilation: int = 1,
                 groups: int = 1, bias: bool = True, causal: bool = False,
                 norm: str = 'none', norm_kwargs: tp.Dict[str, tp.Any] = {},
                 pad_mode: str = 'reflect'):
        super().__init__()
        # warn user on unusual setup between dilation and stride
        if stride > 1 and dilation > 1:
            warnings.warn('SConv1d has been initialized with stride > 1 and dilation > 1'
                          f' (kernel_size={kernel_size} stride={stride}, dilation={dilation}).')
        self.conv = NormConv1d(in_channels, out_channels, kernel_size, stride,
                               dilation=dilation, groups=groups, bias=bias, causal=causal,
                               norm=norm, norm_kwargs=norm_kwargs)
        self.causal = causal
        self.pad_mode = pad_mode

    def forward(self, x):
        B, C, T = x.shape
        kernel_size = self.conv.conv.kernel_size[0]
        stride = self.conv.conv.stride[0]
        dilation = self.conv.conv.dilation[0]
        kernel_size = (kernel_size - 1) * dilation + 1  # effective kernel size with dilations
        padding_total = kernel_size - stride
        extra_padding = get_extra_padding_for_conv1d(x, kernel_size, stride, padding_total)
        if self.causal:
            # Left padding for causal
            x = pad1d(x, (padding_total, extra_padding), mode=self.pad_mode)
        else:
            # Asymmetric padding required for odd strides
            padding_right = padding_total // 2
            padding_left = padding_total - padding_right
            x = pad1d(x, (padding_left, padding_right + extra_padding), mode=self.pad_mode)
        return self.conv(x)


class SConvTranspose1d(nn.Module):
    """ConvTranspose1d with some builtin handling of asymmetric or causal padding
    and normalization.
    """
    def __init__(self, in_channels: int, out_channels: int,
                 kernel_size: int, stride: int = 1, causal: bool = False,
                 norm: str = 'none', trim_right_ratio: float = 1.,
                 norm_kwargs: tp.Dict[str, tp.Any] = {}):
        super().__init__()
        self.convtr = NormConvTranspose1d(in_channels, out_channels, kernel_size, stride,
                                          causal=causal, norm=norm, norm_kwargs=norm_kwargs)
        self.causal = causal
        self.trim_right_ratio = trim_right_ratio
        assert self.causal or self.trim_right_ratio == 1., \
            "`trim_right_ratio` != 1.0 only makes sense for causal convolutions"
        assert self.trim_right_ratio >= 0. and self.trim_right_ratio <= 1.

    def forward(self, x):
        kernel_size = self.convtr.convtr.kernel_size[0]
        stride = self.convtr.convtr.stride[0]
        padding_total = kernel_size - stride

        y = self.convtr(x)

        # We will only trim fixed padding. Extra padding from `pad_for_conv1d` would be
        # removed at the very end, when keeping only the right length for the output,
        # as removing it here would require also passing the length at the matching layer
        # in the encoder.
        if self.causal:
            # Trim the padding on the right according to the specified ratio
            # if trim_right_ratio = 1.0, trim everything from right
            padding_right = math.ceil(padding_total * self.trim_right_ratio)
            padding_left = padding_total - padding_right
            y = unpad1d(y, (padding_left, padding_right))
        else:
            # Asymmetric padding required for odd strides
            padding_right = padding_total // 2
            padding_left = padding_total - padding_right
            y = unpad1d(y, (padding_left, padding_right))
        return y
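To make the padding arithmetic concrete, a small sanity-check sketch (run inside this repo so the import resolves): with `padding_total = effective_kernel_size - stride` plus the extra right padding, a strided `SConv1d` always yields `ceil(T / stride)` output steps.

import math
import torch
from encoder.modules.conv import SConv1d

conv = SConv1d(1, 8, kernel_size=7, stride=2, causal=True)
x = torch.randn(1, 1, 24000)
y = conv(x)
# The built-in padding keeps the streaming invariant: ceil(T / stride)
# output steps, regardless of kernel size.
assert y.shape[-1] == math.ceil(x.shape[-1] / 2), y.shape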
encoder/modules/lstm.py
ADDED
@@ -0,0 +1,39 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.

"""LSTM layers module."""

from torch import nn


class SLSTM(nn.Module):
    """
    LSTM without worrying about the hidden state, nor the layout of the data.
    Expects input as convolutional layout.
    """
    def __init__(self, dimension: int, num_layers: int = 2, skip: bool = True):
        super().__init__()
        self.skip = skip
        self.lstm = nn.LSTM(dimension, dimension, num_layers)

    # def forward(self, x):
    #     x = x.permute(2, 0, 1)
    #     y, _ = self.lstm(x)
    #     if self.skip:
    #         y = y + x
    #     y = y.permute(1, 2, 0)
    #     return y

    # Changed the transpose order: permute back to the convolutional layout
    # before adding the skip connection (equivalent to the original above).
    def forward(self, x):
        # # Insert a reshape (kept from the original, disabled):
        # x = x.reshape(x.shape)
        x1 = x.permute(2, 0, 1)   # [B, C, T] -> [T, B, C] for nn.LSTM
        y, _ = self.lstm(x1)
        y = y.permute(1, 2, 0)    # back to [B, C, T]
        if self.skip:
            y = y + x
        return y
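The translated comment says only the transpose order changed. A short numerical sketch confirming that the rewritten `forward` matches the commented-out original, since adding the skip before or after permuting back is the same computation:

import torch
from encoder.modules.lstm import SLSTM

torch.manual_seed(0)
m = SLSTM(dimension=16)
x = torch.randn(2, 16, 50)              # convolutional layout [B, C, T]

y_new = m(x)
x1 = x.permute(2, 0, 1)                 # replay the original ordering
y_old, _ = m.lstm(x1)
y_old = (y_old + x1).permute(1, 2, 0)   # skip first, then permute back
assert torch.allclose(y_new, y_old, atol=1e-6)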
encoder/modules/norm.py
ADDED
@@ -0,0 +1,28 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.

"""Normalization modules."""

import typing as tp

import einops
import torch
from torch import nn


class ConvLayerNorm(nn.LayerNorm):
    """
    Convolution-friendly LayerNorm that moves channels to last dimensions
    before running the normalization and moves them back to original position right after.
    """
    def __init__(self, normalized_shape: tp.Union[int, tp.List[int], torch.Size], **kwargs):
        super().__init__(normalized_shape, **kwargs)

    def forward(self, x):
        x = einops.rearrange(x, 'b ... t -> b t ...')
        x = super().forward(x)
        x = einops.rearrange(x, 'b t ... -> b ... t')
        return x
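A quick shape sketch for `ConvLayerNorm` (assumes `einops` is installed): channels move to the last axis, the usual LayerNorm runs there, and the input layout is restored.

import torch
from encoder.modules.norm import ConvLayerNorm

ln = ConvLayerNorm(16)          # normalize over 16 channels
x = torch.randn(2, 16, 50)      # [B, C, T]
y = ln(x)
assert y.shape == x.shape       # layout preserved after the round-trip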
encoder/modules/seanet.py
ADDED
|
@@ -0,0 +1,253 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
| 2 |
+
# All rights reserved.
|
| 3 |
+
#
|
| 4 |
+
# This source code is licensed under the license found in the
|
| 5 |
+
# LICENSE file in the root directory of this source tree.
|
| 6 |
+
|
| 7 |
+
"""Encodec SEANet-based encoder and decoder implementation."""
|
| 8 |
+
|
| 9 |
+
import typing as tp
|
| 10 |
+
|
| 11 |
+
import numpy as np
|
| 12 |
+
import torch.nn as nn
|
| 13 |
+
|
| 14 |
+
from . import (
|
| 15 |
+
SConv1d,
|
| 16 |
+
SConvTranspose1d,
|
| 17 |
+
SLSTM
|
| 18 |
+
)
|
| 19 |
+
|
| 20 |
+
|
| 21 |
+
class SEANetResnetBlock(nn.Module):
|
| 22 |
+
"""Residual block from SEANet model.
|
| 23 |
+
Args:
|
| 24 |
+
dim (int): Dimension of the input/output
|
| 25 |
+
kernel_sizes (list): List of kernel sizes for the convolutions.
|
| 26 |
+
dilations (list): List of dilations for the convolutions.
|
| 27 |
+
activation (str): Activation function.
|
| 28 |
+
activation_params (dict): Parameters to provide to the activation function
|
| 29 |
+
norm (str): Normalization method.
|
| 30 |
+
norm_params (dict): Parameters to provide to the underlying normalization used along with the convolution.
|
| 31 |
+
causal (bool): Whether to use fully causal convolution.
|
| 32 |
+
pad_mode (str): Padding mode for the convolutions.
|
| 33 |
+
compress (int): Reduced dimensionality in residual branches (from Demucs v3)
|
| 34 |
+
true_skip (bool): Whether to use true skip connection or a simple convolution as the skip connection.
|
| 35 |
+
"""
|
| 36 |
+
def __init__(self, dim: int, kernel_sizes: tp.List[int] = [3, 1], dilations: tp.List[int] = [1, 1],
|
| 37 |
+
activation: str = 'ELU', activation_params: dict = {'alpha': 1.0},
|
| 38 |
+
norm: str = 'weight_norm', norm_params: tp.Dict[str, tp.Any] = {}, causal: bool = False,
|
| 39 |
+
pad_mode: str = 'reflect', compress: int = 2, true_skip: bool = True):
|
| 40 |
+
super().__init__()
|
| 41 |
+
assert len(kernel_sizes) == len(dilations), 'Number of kernel sizes should match number of dilations'
|
| 42 |
+
act = getattr(nn, activation)
|
| 43 |
+
hidden = dim // compress
|
| 44 |
+
block = []
|
| 45 |
+
for i, (kernel_size, dilation) in enumerate(zip(kernel_sizes, dilations)):
|
| 46 |
+
in_chs = dim if i == 0 else hidden
|
| 47 |
+
out_chs = dim if i == len(kernel_sizes) - 1 else hidden
|
| 48 |
+
block += [
|
| 49 |
+
act(**activation_params),
|
| 50 |
+
SConv1d(in_chs, out_chs, kernel_size=kernel_size, dilation=dilation,
|
| 51 |
+
norm=norm, norm_kwargs=norm_params,
|
| 52 |
+
causal=causal, pad_mode=pad_mode),
|
| 53 |
+
]
|
| 54 |
+
self.block = nn.Sequential(*block)
|
| 55 |
+
self.shortcut: nn.Module
|
| 56 |
+
if true_skip:
|
| 57 |
+
self.shortcut = nn.Identity()
|
| 58 |
+
else:
|
| 59 |
+
self.shortcut = SConv1d(dim, dim, kernel_size=1, norm=norm, norm_kwargs=norm_params,
|
| 60 |
+
causal=causal, pad_mode=pad_mode)
|
| 61 |
+
|
| 62 |
+
def forward(self, x):
|
| 63 |
+
return self.shortcut(x) + self.block(x)
|
| 64 |
+
|
| 65 |
+
|
| 66 |
+
class SEANetEncoder(nn.Module):
|
| 67 |
+
"""SEANet encoder.
|
| 68 |
+
Args:
|
| 69 |
+
channels (int): Audio channels.
|
| 70 |
+
dimension (int): Intermediate representation dimension.
|
| 71 |
+
n_filters (int): Base width for the model.
|
| 72 |
+
n_residual_layers (int): nb of residual layers.
|
| 73 |
+
ratios (Sequence[int]): kernel size and stride ratios. The encoder uses downsampling ratios instead of
|
| 74 |
+
upsampling ratios, hence it will use the ratios in the reverse order to the ones specified here
|
| 75 |
+
that must match the decoder order
|
| 76 |
+
activation (str): Activation function.
|
| 77 |
+
activation_params (dict): Parameters to provide to the activation function
|
| 78 |
+
norm (str): Normalization method.
|
| 79 |
+
norm_params (dict): Parameters to provide to the underlying normalization used along with the convolution.
|
| 80 |
+
kernel_size (int): Kernel size for the initial convolution.
|
| 81 |
+
last_kernel_size (int): Kernel size for the initial convolution.
|
| 82 |
+
residual_kernel_size (int): Kernel size for the residual layers.
|
| 83 |
+
dilation_base (int): How much to increase the dilation with each layer.
|
| 84 |
+
causal (bool): Whether to use fully causal convolution.
|
| 85 |
+
pad_mode (str): Padding mode for the convolutions.
|
| 86 |
+
true_skip (bool): Whether to use true skip connection or a simple
|
| 87 |
+
(streamable) convolution as the skip connection in the residual network blocks.
|
| 88 |
+
compress (int): Reduced dimensionality in residual branches (from Demucs v3).
|
| 89 |
+
lstm (int): Number of LSTM layers at the end of the encoder.
|
| 90 |
+
"""
|
| 91 |
+
def __init__(self, channels: int = 1, dimension: int = 128, n_filters: int = 32, n_residual_layers: int = 1,
|
| 92 |
+
ratios: tp.List[int] = [8, 5, 4, 2], activation: str = 'ELU', activation_params: dict = {'alpha': 1.0},
|
| 93 |
+
norm: str = 'weight_norm', norm_params: tp.Dict[str, tp.Any] = {}, kernel_size: int = 7,
|
| 94 |
+
last_kernel_size: int = 7, residual_kernel_size: int = 3, dilation_base: int = 2, causal: bool = False,
|
| 95 |
+
pad_mode: str = 'reflect', true_skip: bool = False, compress: int = 2, lstm: int = 2):
|
| 96 |
+
super().__init__()
|
| 97 |
+
self.channels = channels
|
| 98 |
+
self.dimension = dimension
|
| 99 |
+
self.n_filters = n_filters
|
| 100 |
+
self.ratios = list(reversed(ratios))
|
| 101 |
+
del ratios
|
| 102 |
+
self.n_residual_layers = n_residual_layers
|
| 103 |
+
self.hop_length = np.prod(self.ratios)
|
| 104 |
+
|
| 105 |
+
act = getattr(nn, activation)
|
| 106 |
+
mult = 1
|
| 107 |
+
model: tp.List[nn.Module] = [
|
| 108 |
+
SConv1d(channels, mult * n_filters, kernel_size, norm=norm, norm_kwargs=norm_params,
|
| 109 |
+
causal=causal, pad_mode=pad_mode)
|
| 110 |
+
]
|
| 111 |
+
# Downsample to raw audio scale
|
| 112 |
+
for i, ratio in enumerate(self.ratios):
|
| 113 |
+
# Add residual layers
|
| 114 |
+
for j in range(n_residual_layers):
|
| 115 |
+
model += [
|
| 116 |
+
SEANetResnetBlock(mult * n_filters, kernel_sizes=[residual_kernel_size, 1],
|
| 117 |
+
dilations=[dilation_base ** j, 1],
|
| 118 |
+
norm=norm, norm_params=norm_params,
|
| 119 |
+
activation=activation, activation_params=activation_params,
|
| 120 |
+
causal=causal, pad_mode=pad_mode, compress=compress, true_skip=true_skip)]
|
| 121 |
+
|
| 122 |
+
# Add downsampling layers
|
| 123 |
+
model += [
|
| 124 |
+
act(**activation_params),
|
+                SConv1d(mult * n_filters, mult * n_filters * 2,
+                        kernel_size=ratio * 2, stride=ratio,
+                        norm=norm, norm_kwargs=norm_params,
+                        causal=causal, pad_mode=pad_mode),
+            ]
+            mult *= 2
+
+        if lstm:
+            model += [SLSTM(mult * n_filters, num_layers=lstm)]
+
+        model += [
+            act(**activation_params),
+            SConv1d(mult * n_filters, dimension, last_kernel_size, norm=norm, norm_kwargs=norm_params,
+                    causal=causal, pad_mode=pad_mode)
+        ]
+
+        self.model = nn.Sequential(*model)
+
+    def forward(self, x):
+        return self.model(x)
+
+
+class SEANetDecoder(nn.Module):
+    """SEANet decoder.
+    Args:
+        channels (int): Audio channels.
+        dimension (int): Intermediate representation dimension.
+        n_filters (int): Base width for the model.
+        n_residual_layers (int): Number of residual layers.
+        ratios (Sequence[int]): Kernel size and stride ratios.
+        activation (str): Activation function.
+        activation_params (dict): Parameters to provide to the activation function.
+        final_activation (str): Final activation function after all convolutions.
+        final_activation_params (dict): Parameters to provide to the final activation function.
+        norm (str): Normalization method.
+        norm_params (dict): Parameters to provide to the underlying normalization used along with the convolution.
+        kernel_size (int): Kernel size for the initial convolution.
+        last_kernel_size (int): Kernel size for the last convolution.
+        residual_kernel_size (int): Kernel size for the residual layers.
+        dilation_base (int): How much to increase the dilation with each layer.
+        causal (bool): Whether to use fully causal convolution.
+        pad_mode (str): Padding mode for the convolutions.
+        true_skip (bool): Whether to use true skip connection or a simple
+            (streamable) convolution as the skip connection in the residual network blocks.
+        compress (int): Reduced dimensionality in residual branches (from Demucs v3).
+        lstm (int): Number of LSTM layers at the start of the decoder.
+        trim_right_ratio (float): Ratio for trimming at the right of the transposed convolution under the causal setup.
+            If equal to 1.0, it means that all the trimming is done at the right.
+    """
+    def __init__(self, channels: int = 1, dimension: int = 128, n_filters: int = 32, n_residual_layers: int = 1,
+                 ratios: tp.List[int] = [8, 5, 4, 2], activation: str = 'ELU', activation_params: dict = {'alpha': 1.0},
+                 final_activation: tp.Optional[str] = None, final_activation_params: tp.Optional[dict] = None,
+                 norm: str = 'weight_norm', norm_params: tp.Dict[str, tp.Any] = {}, kernel_size: int = 7,
+                 last_kernel_size: int = 7, residual_kernel_size: int = 3, dilation_base: int = 2, causal: bool = False,
+                 pad_mode: str = 'reflect', true_skip: bool = False, compress: int = 2, lstm: int = 2,
+                 trim_right_ratio: float = 1.0):
+        super().__init__()
+        self.dimension = dimension
+        self.channels = channels
+        self.n_filters = n_filters
+        self.ratios = ratios
+        del ratios
+        self.n_residual_layers = n_residual_layers
+        self.hop_length = np.prod(self.ratios)
+
+        act = getattr(nn, activation)
+        mult = int(2 ** len(self.ratios))
+        model: tp.List[nn.Module] = [
+            SConv1d(dimension, mult * n_filters, kernel_size, norm=norm, norm_kwargs=norm_params,
+                    causal=causal, pad_mode=pad_mode)
+        ]
+
+        if lstm:
+            model += [SLSTM(mult * n_filters, num_layers=lstm)]
+
+        # Upsample to raw audio scale
+        for i, ratio in enumerate(self.ratios):
+            # Add upsampling layers
+            model += [
+                act(**activation_params),
+                SConvTranspose1d(mult * n_filters, mult * n_filters // 2,
+                                 kernel_size=ratio * 2, stride=ratio,
+                                 norm=norm, norm_kwargs=norm_params,
+                                 causal=causal, trim_right_ratio=trim_right_ratio),
+            ]
+            # Add residual layers
+            for j in range(n_residual_layers):
+                model += [
+                    SEANetResnetBlock(mult * n_filters // 2, kernel_sizes=[residual_kernel_size, 1],
+                                      dilations=[dilation_base ** j, 1],
+                                      activation=activation, activation_params=activation_params,
+                                      norm=norm, norm_params=norm_params, causal=causal,
+                                      pad_mode=pad_mode, compress=compress, true_skip=true_skip)]
+
+            mult //= 2
+
+        # Add final layers
+        model += [
+            act(**activation_params),
+            SConv1d(n_filters, channels, last_kernel_size, norm=norm, norm_kwargs=norm_params,
+                    causal=causal, pad_mode=pad_mode)
+        ]
+        # Add optional final activation to decoder (e.g. tanh)
+        if final_activation is not None:
+            final_act = getattr(nn, final_activation)
+            final_activation_params = final_activation_params or {}
+            model += [
+                final_act(**final_activation_params)
+            ]
+        self.model = nn.Sequential(*model)
+
+    def forward(self, z):
+        y = self.model(z)
+        return y
+
+
+def test():
+    import torch
+    encoder = SEANetEncoder()
+    decoder = SEANetDecoder()
+    x = torch.randn(1, 1, 24000)
+    z = encoder(x)
+    assert list(z.shape) == [1, 128, 75], z.shape
+    y = decoder(z)
+    assert y.shape == x.shape, (x.shape, y.shape)
+
+
+if __name__ == '__main__':
+    test()
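The test above pins down the shape contract: with the default ratios [8, 5, 4, 2] the hop length is 8 * 5 * 4 * 2 = 320, so 24000 samples map to 75 latent frames of dimension 128, and the decoder restores the original length. A minimal usage sketch along the same lines (the import path assumes the encoder/modules/seanet.py layout of this upload; the 48000-sample input is only an illustration):

import torch
from encoder.modules.seanet import SEANetEncoder, SEANetDecoder

encoder = SEANetEncoder()                      # default ratios give hop length 320
decoder = SEANetDecoder()

x = torch.randn(2, 1, 48000)                   # (batch, channels, samples)
z = encoder(x)                                 # 48000 / 320 = 150 latent frames
assert list(z.shape) == [2, 128, 150], z.shape
y = decoder(z)                                 # back to (2, 1, 48000)
assert y.shape == x.shape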
encoder/modules/transformer.py ADDED
@@ -0,0 +1,119 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
+
+"""A streamable transformer."""
+
+import typing as tp
+
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
+
+def create_sin_embedding(positions: torch.Tensor, dim: int, max_period: float = 10000):
+    """Create time embedding for the given positions, target dimension `dim`."""
+    # We aim for BTC format
+    assert dim % 2 == 0
+    half_dim = dim // 2
+    adim = torch.arange(half_dim, device=positions.device).view(1, 1, -1)
+    phase = positions / (max_period ** (adim / (half_dim - 1)))
+    return torch.cat([
+        torch.cos(phase),
+        torch.sin(phase),
+    ], dim=-1)
+
+
+class StreamingTransformerEncoderLayer(nn.TransformerEncoderLayer):
+    def forward(self, x: torch.Tensor, x_past: torch.Tensor, past_context: int):  # type: ignore
+        if self.norm_first:
+            sa_input = self.norm1(x)
+            x = x + self._sa_block(sa_input, x_past, past_context)
+            x = x + self._ff_block(self.norm2(x))
+        else:
+            sa_input = x
+            x = self.norm1(x + self._sa_block(sa_input, x_past, past_context))
+            x = self.norm2(x + self._ff_block(x))
+
+        return x, sa_input
+
+    # self-attention block
+    def _sa_block(self, x: torch.Tensor, x_past: torch.Tensor, past_context: int):  # type: ignore
+        _, T, _ = x.shape
+        _, H, _ = x_past.shape
+
+        queries = x
+        keys = torch.cat([x_past, x], dim=1)
+        values = keys
+
+        queries_pos = torch.arange(H, T + H, device=x.device).view(-1, 1)
+        keys_pos = torch.arange(T + H, device=x.device).view(1, -1)
+        delta = queries_pos - keys_pos
+        valid_access = (delta >= 0) & (delta <= past_context)
+        x = self.self_attn(queries, keys, values,
+                           attn_mask=~valid_access,
+                           need_weights=False)[0]
+        return self.dropout1(x)
+
+
+class StreamingTransformerEncoder(nn.Module):
+    """TransformerEncoder with streaming support.
+
+    Args:
+        dim (int): Dimension of the data.
+        hidden_scale (float): Intermediate dimension of the FF module is this times the dimension.
+        num_heads (int): Number of heads.
+        num_layers (int): Number of layers.
+        max_period (float): Maximum period of the cosines in the positional embedding.
+        past_context (int or None): Receptive field for the causal mask, infinite if None.
+        gelu (bool): If true, uses GELUs, otherwise ReLUs.
+        norm_in (bool): Normalize the input.
+        dropout (float): Dropout probability.
+        **kwargs: See `nn.TransformerEncoderLayer`.
+    """
+    def __init__(self, dim, hidden_scale: float = 4., num_heads: int = 8, num_layers: int = 5,
+                 max_period: float = 10000, past_context: int = 1000, gelu: bool = True,
+                 norm_in: bool = True, dropout: float = 0., **kwargs):
+        super().__init__()
+        assert dim % num_heads == 0
+        hidden_dim = int(dim * hidden_scale)
+
+        self.max_period = max_period
+        self.past_context = past_context
+        activation: tp.Any = F.gelu if gelu else F.relu
+
+        self.norm_in: nn.Module
+        if norm_in:
+            self.norm_in = nn.LayerNorm(dim)
+        else:
+            self.norm_in = nn.Identity()
+
+        self.layers = nn.ModuleList()
+        for idx in range(num_layers):
+            self.layers.append(
+                StreamingTransformerEncoderLayer(
+                    dim, num_heads, hidden_dim,
+                    activation=activation, batch_first=True, dropout=dropout, **kwargs))
+
+    def forward(self, x: torch.Tensor,
+                states: tp.Optional[tp.List[torch.Tensor]] = None,
+                offset: tp.Union[int, torch.Tensor] = 0):
+        B, T, C = x.shape
+        if states is None:
+            states = [torch.zeros_like(x[:, :1]) for _ in range(1 + len(self.layers))]
+
+        positions = torch.arange(T, device=x.device).view(1, -1, 1) + offset
+        pos_emb = create_sin_embedding(positions, C, max_period=self.max_period)
+
+        new_state: tp.List[torch.Tensor] = []
+        x = self.norm_in(x)
+        x = x + pos_emb
+
+        for layer_state, layer in zip(states, self.layers):
+            x, new_layer_state = layer(x, layer_state, self.past_context)
+            new_layer_state = torch.cat([layer_state, new_layer_state], dim=1)
+            new_state.append(new_layer_state[:, -self.past_context:, :])
+        return x, new_state, offset + T
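forward returns the transformed chunk together with the updated per-layer states (truncated to past_context frames) and the advanced positional offset, so streaming inference is a matter of threading those two values through successive calls. A short sketch under that reading (the import path assumes the layout of this upload; sizes are illustrative):

import torch
from encoder.modules.transformer import StreamingTransformerEncoder

model = StreamingTransformerEncoder(dim=128, num_layers=2, past_context=100).eval()

stream = torch.randn(1, 1000, 128)            # (batch, time, dim), i.e. BTC format
states, offset = None, 0
outputs = []
with torch.no_grad():
    for chunk in stream.split(250, dim=1):    # feed the signal 250 frames at a time
        out, states, offset = model(chunk, states, offset)
        outputs.append(out)
y = torch.cat(outputs, dim=1)                 # one output frame per input frame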
encoder/msstftd.py ADDED
@@ -0,0 +1,147 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
+
+"""MS-STFT discriminator, provided here for reference."""
+
+import typing as tp
+
+import torchaudio
+import torch
+from torch import nn
+from einops import rearrange
+
+from .modules import NormConv2d
+
+
+FeatureMapType = tp.List[torch.Tensor]
+LogitsType = torch.Tensor
+DiscriminatorOutput = tp.Tuple[tp.List[LogitsType], tp.List[FeatureMapType]]
+
+
+def get_2d_padding(kernel_size: tp.Tuple[int, int], dilation: tp.Tuple[int, int] = (1, 1)):
+    return (((kernel_size[0] - 1) * dilation[0]) // 2, ((kernel_size[1] - 1) * dilation[1]) // 2)
+
+
+class DiscriminatorSTFT(nn.Module):
+    """STFT sub-discriminator.
+    Args:
+        filters (int): Number of filters in convolutions.
+        in_channels (int): Number of input channels. Default: 1
+        out_channels (int): Number of output channels. Default: 1
+        n_fft (int): Size of FFT for each scale. Default: 1024
+        hop_length (int): Length of hop between STFT windows for each scale. Default: 256
+        kernel_size (tuple of int): Inner Conv2d kernel sizes. Default: ``(3, 9)``
+        stride (tuple of int): Inner Conv2d strides. Default: ``(1, 2)``
+        dilations (list of int): Inner Conv2d dilation on the time dimension. Default: ``[1, 2, 4]``
+        win_length (int): Window size for each scale. Default: 1024
+        normalized (bool): Whether to normalize by magnitude after stft. Default: True
+        norm (str): Normalization method. Default: `'weight_norm'`
+        activation (str): Activation function. Default: `'LeakyReLU'`
+        activation_params (dict): Parameters to provide to the activation function.
+        filters_scale (int): Growth factor for the filters. Default: 1
+    """
+    def __init__(self, filters: int, in_channels: int = 1, out_channels: int = 1,
+                 n_fft: int = 1024, hop_length: int = 256, win_length: int = 1024, max_filters: int = 1024,
+                 filters_scale: int = 1, kernel_size: tp.Tuple[int, int] = (3, 9), dilations: tp.List = [1, 2, 4],
+                 stride: tp.Tuple[int, int] = (1, 2), normalized: bool = True, norm: str = 'weight_norm',
+                 activation: str = 'LeakyReLU', activation_params: dict = {'negative_slope': 0.2}):
+        super().__init__()
+        assert len(kernel_size) == 2
+        assert len(stride) == 2
+        self.filters = filters
+        self.in_channels = in_channels
+        self.out_channels = out_channels
+        self.n_fft = n_fft
+        self.hop_length = hop_length
+        self.win_length = win_length
+        self.normalized = normalized
+        self.activation = getattr(torch.nn, activation)(**activation_params)
+        self.spec_transform = torchaudio.transforms.Spectrogram(
+            n_fft=self.n_fft, hop_length=self.hop_length, win_length=self.win_length, window_fn=torch.hann_window,
+            normalized=self.normalized, center=False, pad_mode=None, power=None)
+        spec_channels = 2 * self.in_channels
+        self.convs = nn.ModuleList()
+        self.convs.append(
+            NormConv2d(spec_channels, self.filters, kernel_size=kernel_size, padding=get_2d_padding(kernel_size))
+        )
+        in_chs = min(filters_scale * self.filters, max_filters)
+        for i, dilation in enumerate(dilations):
+            out_chs = min((filters_scale ** (i + 1)) * self.filters, max_filters)
+            self.convs.append(NormConv2d(in_chs, out_chs, kernel_size=kernel_size, stride=stride,
+                                         dilation=(dilation, 1), padding=get_2d_padding(kernel_size, (dilation, 1)),
+                                         norm=norm))
+            in_chs = out_chs
+        out_chs = min((filters_scale ** (len(dilations) + 1)) * self.filters, max_filters)
+        self.convs.append(NormConv2d(in_chs, out_chs, kernel_size=(kernel_size[0], kernel_size[0]),
+                                     padding=get_2d_padding((kernel_size[0], kernel_size[0])),
+                                     norm=norm))
+        self.conv_post = NormConv2d(out_chs, self.out_channels,
+                                    kernel_size=(kernel_size[0], kernel_size[0]),
+                                    padding=get_2d_padding((kernel_size[0], kernel_size[0])),
+                                    norm=norm)
+
+    def forward(self, x: torch.Tensor):
+        fmap = []
+        z = self.spec_transform(x)  # complex spectrogram: [B, C, Freq, Frames]
+        z = torch.cat([z.real, z.imag], dim=1)
+        z = rearrange(z, 'b c w t -> b c t w')
+        for i, layer in enumerate(self.convs):
+            z = layer(z)
+            z = self.activation(z)
+            fmap.append(z)
+        z = self.conv_post(z)
+        return z, fmap
+
+
+class MultiScaleSTFTDiscriminator(nn.Module):
+    """Multi-Scale STFT (MS-STFT) discriminator.
+    Args:
+        filters (int): Number of filters in convolutions.
+        in_channels (int): Number of input channels. Default: 1
+        out_channels (int): Number of output channels. Default: 1
+        n_ffts (Sequence[int]): Size of FFT for each scale.
+        hop_lengths (Sequence[int]): Length of hop between STFT windows for each scale.
+        win_lengths (Sequence[int]): Window size for each scale.
+        **kwargs: Additional args for `DiscriminatorSTFT`.
+    """
+    def __init__(self, filters: int, in_channels: int = 1, out_channels: int = 1,
+                 n_ffts: tp.List[int] = [1024, 2048, 512], hop_lengths: tp.List[int] = [256, 512, 128],
+                 win_lengths: tp.List[int] = [1024, 2048, 512], **kwargs):
+        super().__init__()
+        assert len(n_ffts) == len(hop_lengths) == len(win_lengths)
+        self.discriminators = nn.ModuleList([
+            DiscriminatorSTFT(filters, in_channels=in_channels, out_channels=out_channels,
+                              n_fft=n_ffts[i], win_length=win_lengths[i], hop_length=hop_lengths[i], **kwargs)
+            for i in range(len(n_ffts))
+        ])
+        self.num_discriminators = len(self.discriminators)
+
+    def forward(self, x: torch.Tensor) -> DiscriminatorOutput:
+        logits = []
+        fmaps = []
+        for disc in self.discriminators:
+            logit, fmap = disc(x)
+            logits.append(logit)
+            fmaps.append(fmap)
+        return logits, fmaps
+
+
+def test():
+    disc = MultiScaleSTFTDiscriminator(filters=32)
+    y = torch.randn(1, 1, 24000)
+    y_hat = torch.randn(1, 1, 24000)
+
+    y_disc_r, fmap_r = disc(y)
+    y_disc_gen, fmap_gen = disc(y_hat)
+    assert len(y_disc_r) == len(y_disc_gen) == len(fmap_r) == len(fmap_gen) == disc.num_discriminators
+
+    assert all([len(fm) == 5 for fm in fmap_r + fmap_gen])
+    assert all([list(f.shape)[:2] == [1, 32] for fm in fmap_r + fmap_gen for f in fm])
+    assert all([len(logits.shape) == 4 for logits in y_disc_r + y_disc_gen])
+
+
+if __name__ == '__main__':
+    test()
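The discriminator returns one logits tensor and one list of intermediate feature maps per STFT scale. A sketch of how such outputs are typically consumed in adversarial training, using a hinge loss plus L1 feature matching; this pairing is a common GAN-vocoder recipe and is illustrative only, not necessarily the exact losses used to train this model:

import torch
import torch.nn.functional as F
from encoder.msstftd import MultiScaleSTFTDiscriminator

disc = MultiScaleSTFTDiscriminator(filters=32)
real = torch.randn(1, 1, 24000)
fake = torch.randn(1, 1, 24000)               # stand-in for decoder output

logits_real, fmaps_real = disc(real)
logits_fake, fmaps_fake = disc(fake)

# Discriminator update: hinge loss summed over every sub-discriminator.
d_loss = sum(torch.relu(1 - lr).mean() + torch.relu(1 + lf).mean()
             for lr, lf in zip(logits_real, logits_fake))

# Generator update: adversarial term plus L1 feature matching on the fmaps.
g_adv = sum((-lf).mean() for lf in logits_fake)
g_fm = sum(F.l1_loss(f_fake, f_real.detach())
           for fm_fake, fm_real in zip(fmaps_fake, fmaps_real)
           for f_fake, f_real in zip(fm_fake, fm_real))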
encoder/quantization/__init__.py ADDED
@@ -0,0 +1,8 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
+
+# flake8: noqa
+from .vq import QuantizedResult, ResidualVectorQuantizer
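vq.py itself is not part of this hunk, so the sketch below leans on the upstream encodec package's API (a ResidualVectorQuantizer built from dimension, n_q and bins, whose forward takes the latents, a frame rate and a target bandwidth and returns a QuantizedResult with quantized latents and codebook indices); treat those names and the call signature as assumptions about this copy, not facts confirmed by the diff:

import torch
from encoder.quantization import ResidualVectorQuantizer

# Constructor arguments assumed from the upstream encodec API.
rvq = ResidualVectorQuantizer(dimension=128, n_q=8, bins=1024)
z = torch.randn(1, 128, 75)                   # latent frames from the SEANet encoder
result = rvq(z, 75, 6.0)                      # (latents, frame rate, bandwidth in kbit/s) -- assumed signature
z_q, codes = result.quantized, result.codes   # quantized latents and codebook indices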
encoder/quantization/__pycache__/__init__.cpython-310.pyc ADDED
Binary file (243 Bytes).
encoder/quantization/__pycache__/__init__.cpython-38.pyc ADDED
Binary file (253 Bytes).
encoder/quantization/__pycache__/__init__.cpython-39.pyc ADDED
Binary file (241 Bytes).
encoder/quantization/__pycache__/core_vq.cpython-310.pyc ADDED
Binary file (12.3 kB).
encoder/quantization/__pycache__/core_vq.cpython-38.pyc ADDED
Binary file (13.4 kB).
encoder/quantization/__pycache__/core_vq.cpython-39.pyc ADDED
Binary file (12.6 kB).
encoder/quantization/__pycache__/vq.cpython-310.pyc ADDED
Binary file (5.14 kB).
encoder/quantization/__pycache__/vq.cpython-38.pyc ADDED
Binary file (4.8 kB).
encoder/quantization/__pycache__/vq.cpython-39.pyc ADDED
Binary file (5.12 kB).