File size: 7,058 Bytes
68e74d1 | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 |
'''
This module defines the encoder and decoder architectures. The encoder takes the raw waveform and produces a sequence of latent frames, which are then quantized with RVQ and fed to the transformer.
The decoder takes the (possibly corrupted) latent frames and reconstructs the waveform.
'''
import typing as tp
import numpy as np
import torch.nn as nn
from Utils import Snake, SLSTM, SConv1d, SConvTranspose1d
class ZPEncoderResnetBlock(nn.Module):
    """Residual block with a bottleneck conv stack; used in both the encoder and the decoder.

    The "Encoder" in the name is historical: the decoder uses this block too.
    Props to Meta's EnCodec, whose residual block this implementation follows.

    The block computes ``shortcut(x) + block(x)``. Here ``block`` is a stack of
    ``activation -> SConv1d`` pairs, one pair per entry of ``kernel_sizes`` / ``dilations``.
    - The first conv takes ``dim`` input channels.
    - The last conv produces ``dim`` output channels, so the residual sum is shape-compatible.
    - Every other conv boundary uses ``hidden = dim // compress`` channels (the bottleneck).

    Args:
        dim: Number of input and output channels of the block.
        kernel_sizes: Kernel size of each conv in the stack.
        dilations: Dilation of each conv in the stack. Must have the same length as
            ``kernel_sizes``.
        norm: Normalization name forwarded to ``SConv1d``.
        causal: Forwarded to ``SConv1d``. Also selects the padding mode
            (``'constant'`` when causal, ``'reflect'`` otherwise).
        compress: Channel reduction factor of the bottleneck.
        make_act: Optional factory ``make_act(num_channels) -> nn.Module`` that builds the
            activation placed before each conv. Defaults to ``nn.ELU()``; the channel
            count is ignored in that case.
    """

    def __init__(self, dim: int, kernel_sizes: tp.Sequence[int] = (3, 1),
                 dilations: tp.Sequence[int] = (1, 1), norm: str = 'weight_norm',
                 causal: bool = False, compress: int = 2,
                 make_act: tp.Optional[tp.Callable[[int], nn.Module]] = None):
        super().__init__()
        assert len(kernel_sizes) == len(dilations), \
            "kernel_sizes and dilations must have the same length"
        if make_act is None:
            make_act = lambda _: nn.ELU()
        pad_mode = 'constant' if causal else 'reflect'
        hidden = dim // compress
        last_idx = len(kernel_sizes) - 1
        block: tp.List[nn.Module] = []
        for i, (kernel_size, dilation) in enumerate(zip(kernel_sizes, dilations)):
            # Only the outer edges of the stack use `dim`; inner edges use the bottleneck width.
            in_chs = dim if i == 0 else hidden
            out_chs = dim if i == last_idx else hidden
            block += [
                make_act(in_chs),
                SConv1d(in_chs, out_chs, kernel_size=kernel_size, dilation=dilation,
                        norm=norm, causal=causal, pad_mode=pad_mode),
            ]
        self.block = nn.Sequential(*block)
        # Identity skip: the stack maps `dim` channels back to `dim` channels.
        self.shortcut = nn.Identity()

    def forward(self, x):
        """Return ``shortcut(x) + block(x)``.

        Assumes ``x`` is laid out as (batch, dim, time), as expected by ``SConv1d``
        (TODO confirm against Utils).
        """
        return self.shortcut(x) + self.block(x)
class ZPEncoder(nn.Module):
    """
    Convolutional encoder: waveform -> latent frames (causal by default).

    Architecture:
        stem conv
        -> (resnet blocks + ELU + strided conv) x len(ratios)
        -> LSTM
        -> ELU
        -> projection conv

    The downsampling stages iterate over ``ratios`` in reversed order. Each stage
    downsamples time by its ``ratio`` (conv stride) and doubles the channel count.
    hop_length = prod(ratios); the default is 8*5*3*2 = 240 samples (15 ms at 16 kHz).

    Args:
        channels: Number of input waveform channels (input of the stem conv).
        dimension: Number of latent channels produced by the final projection conv.
        n_filters: Base channel width. The stem outputs ``n_filters`` channels and each
            downsampling stage doubles the width.
        ratios: Downsampling factors, applied in reversed order.
        norm: Normalization name forwarded to every conv layer.
        kernel_size: Kernel size of the stem conv.
        last_kernel_size: Kernel size of the final projection conv.
        residual_kernel_size: Kernel size of the first conv in each residual block.
            The second conv always has kernel size 1.
        residual_dilations: One residual block is added per dilation before each
            downsampling step.
        causal: Forwarded to the conv layers. Also selects the padding mode
            (``'constant'`` when causal, ``'reflect'`` otherwise).
        compress: Bottleneck reduction factor of the residual blocks.
    """

    def __init__(
        self,
        channels: int = 1,
        dimension: int = 128,
        n_filters: int = 32,
        ratios: tp.Sequence[int] = (8, 5, 3, 2),
        norm: str = 'weight_norm',
        kernel_size: int = 7,
        last_kernel_size: int = 7,
        residual_kernel_size: int = 7,
        residual_dilations: tp.Sequence[int] = (1, 3, 9),
        causal: bool = True,
        compress: int = 2,
    ):
        super().__init__()
        self.channels = channels
        self.dimension = dimension
        self.n_filters = n_filters
        self.ratios = list(reversed(ratios))
        # np.prod returns a NumPy scalar (and the float 1.0 for an empty list); expose a plain int.
        self.hop_length = int(np.prod(self.ratios))
        # Copy so the instance never aliases the caller's sequence.
        self.residual_dilations = list(residual_dilations)
        pad_mode = 'constant' if causal else 'reflect'

        mult = 1
        model: tp.List[nn.Module] = [
            SConv1d(channels, mult * n_filters, kernel_size, norm=norm, causal=causal, pad_mode=pad_mode)
        ]
        for ratio in self.ratios:
            # Dilated residual stack before each downsampling step (as in SoundStream)
            # to capture long-term dependencies early on.
            for dilation in self.residual_dilations:
                model += [
                    ZPEncoderResnetBlock(
                        mult * n_filters,
                        kernel_sizes=[residual_kernel_size, 1],
                        dilations=[dilation, 1],
                        norm=norm,
                        causal=causal,
                        compress=compress,
                    )
                ]
            # Strided conv: downsamples time by `ratio` (stride) and doubles the channels.
            model += [
                nn.ELU(),
                SConv1d(mult * n_filters, mult * n_filters * 2, kernel_size=ratio * 2,
                        stride=ratio, norm=norm, causal=causal, pad_mode=pad_mode),
            ]
            mult *= 2

        # Only one LSTM layer, whereas EnCodec uses 2. With 2 layers the total latency
        # (counting the transformer and the 30 ms lookahead) was too high for real-time
        # operation, as measured with RTF tests.
        model += [SLSTM(mult * n_filters, num_layers=1)]
        # The projection conv below consumes mult * n_filters channels, so SLSTM is
        # expected to preserve the channel count (TODO confirm against Utils).
        model += [
            nn.ELU(),
            SConv1d(mult * n_filters, dimension, last_kernel_size, norm=norm, causal=causal, pad_mode=pad_mode),
        ]
        self.model = nn.Sequential(*model)

    def forward(self, x):
        """Encode ``x`` into latent frames with ``dimension`` channels.

        Assumes ``x`` is (batch, channels, samples) (TODO confirm against SConv1d).
        """
        return self.model(x)
class ZPDecoder(nn.Module):
    """
    Convolutional decoder: latent frames -> waveform (causal by default).

    Mirror of ZPEncoder:
        projection conv
        -> LSTM
        -> (ELU + transposed conv + resnet blocks) x len(ratios)
        -> ELU
        -> output conv
        -> Tanh

    Processing starts at ``n_filters * 2**len(ratios)`` channels. The upsampling stages
    iterate over ``ratios`` in order. Each stage upsamples time by its ``ratio``
    (transposed-conv stride) and halves the channel count.

    Residual blocks use Snake activations, which are better for audio periodicity
    (props to the improved RVQGAN paper). In the encoder Snake activations did not
    improve much, so ELU was kept there; for waveform reconstruction they are a game
    changer.

    Args:
        channels: Number of output waveform channels.
        dimension: Number of latent channels consumed by the first conv.
        n_filters: Base channel width. The last upsampling stage outputs ``n_filters``
            channels.
        ratios: Upsampling factors, applied in the given order.
        norm: Normalization name forwarded to every conv layer.
        kernel_size: Kernel size of the first (projection) conv.
        last_kernel_size: Kernel size of the output conv.
        residual_kernel_size: Kernel size of the first conv in each residual block.
            The second conv always has kernel size 1.
        residual_dilations: One residual block is added per dilation after each
            upsampling step.
        causal: Forwarded to the conv layers. Also selects the padding mode
            (``'constant'`` when causal, ``'reflect'`` otherwise).
        compress: Bottleneck reduction factor of the residual blocks.
    """

    def __init__(
        self,
        channels: int = 1,
        dimension: int = 128,
        n_filters: int = 32,
        ratios: tp.Sequence[int] = (8, 5, 3, 2),
        norm: str = 'weight_norm',
        kernel_size: int = 7,
        last_kernel_size: int = 7,
        residual_kernel_size: int = 7,
        residual_dilations: tp.Sequence[int] = (1, 3, 9),
        causal: bool = True,
        compress: int = 2,
    ):
        super().__init__()
        self.channels = channels
        self.dimension = dimension
        self.n_filters = n_filters
        # Copies so the instance never aliases the caller's sequences.
        self.ratios = list(ratios)
        # np.prod returns a NumPy scalar (and the float 1.0 for an empty list); expose a plain int.
        self.hop_length = int(np.prod(self.ratios))
        self.residual_dilations = list(residual_dilations)
        pad_mode = 'constant' if causal else 'reflect'

        mult = 2 ** len(self.ratios)  # start at the maximum channel width
        model: tp.List[nn.Module] = [
            SConv1d(dimension, mult * n_filters, kernel_size, norm=norm, causal=causal, pad_mode=pad_mode)
        ]
        model += [SLSTM(mult * n_filters, num_layers=1)]
        for ratio in self.ratios:
            # Transposed conv: upsamples time by `ratio` (stride) and halves the channels.
            model += [
                nn.ELU(),
                SConvTranspose1d(mult * n_filters, mult * n_filters // 2, kernel_size=ratio * 2,
                                 stride=ratio, norm=norm, causal=causal),
            ]
            mult //= 2
            for dilation in self.residual_dilations:
                model += [
                    ZPEncoderResnetBlock(
                        mult * n_filters,
                        kernel_sizes=[residual_kernel_size, 1],
                        dilations=[dilation, 1],
                        norm=norm,
                        causal=causal,
                        compress=compress,
                        make_act=Snake,  # called as Snake(num_channels)
                    )
                ]
        model += [
            nn.ELU(),
            SConv1d(mult * n_filters, channels, last_kernel_size, norm=norm, causal=causal, pad_mode=pad_mode),
            nn.Tanh(),  # squash the output into (-1, 1)
        ]
        self.model = nn.Sequential(*model)

    def forward(self, x):
        """Decode latent frames ``x`` into a waveform with ``channels`` channels.

        Assumes ``x`` is (batch, dimension, frames) (TODO confirm against SConv1d).
        """
        return self.model(x)
|