File size: 7,058 Bytes
68e74d1
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182

'''
Here you can find the encoder and decoder architecture. The encoder takes in the raw waveform and produces a sequence of latent frames, which are then quantized with RVQ and fed to the transformer.
The decoder takes in the (possibly corrupted) latent frames and reconstructs the waveform.
'''

import typing as tp

import numpy as np
import torch.nn as nn
from Utils import Snake, SLSTM, SConv1d, SConvTranspose1d


class ZPEncoderResnetBlock(nn.Module):
    """Residual bottleneck block shared by the encoder and the decoder.

    Despite the "Encoder" in its name, the block is used on both sides of the
    codec. Credit for the design goes to the residual unit in Meta's EnCodec
    implementation: a stack of (activation, SConv1d) pairs whose output is
    added to an identity shortcut.

    Args:
        dim: Channel count at the block's input and output.
        kernel_sizes: Kernel size of each convolution in the stack.
        dilations: Dilation of each convolution; must have the same length
            as ``kernel_sizes``.
        norm: Normalization name forwarded to ``SConv1d``.
        causal: Forwarded to ``SConv1d``; also selects the padding mode
            ('constant' when causal, 'reflect' otherwise).
        compress: Divisor applied to ``dim`` to obtain the inner width.
        make_act: Callable taking a channel count and returning an
            activation module. ``None`` selects ``nn.ELU()`` everywhere.
    """
    def __init__(self, dim: int, kernel_sizes: tp.List[int] = [3, 1], dilations: tp.List[int] = [1, 1],
                 norm: str = 'weight_norm', causal: bool = False, compress: int = 2, make_act=None):
        super().__init__()
        assert len(kernel_sizes) == len(dilations)

        def _default_act(_channels):
            # ELU has no per-channel state, so the channel count is ignored.
            return nn.ELU()

        act_factory = _default_act if make_act is None else make_act
        padding = 'reflect' if not causal else 'constant'
        inner_width = dim // compress
        final_idx = len(kernel_sizes) - 1

        layers = []
        for idx, (ksize, dil) in enumerate(zip(kernel_sizes, dilations)):
            # Only the first conv reads `dim` channels and only the last conv
            # writes `dim` channels; everything in between is `inner_width`.
            if idx == 0:
                width_in = dim
            else:
                width_in = inner_width
            if idx == final_idx:
                width_out = dim
            else:
                width_out = inner_width

            layers.append(act_factory(width_in))
            layers.append(
                SConv1d(width_in, width_out, kernel_size=ksize, dilation=dil,
                        norm=norm, causal=causal, pad_mode=padding)
            )

        self.block = nn.Sequential(*layers)
        self.shortcut = nn.Identity()

    def forward(self, x):
        """Return the shortcut branch plus the conv stack applied to ``x``."""
        skip = self.shortcut(x)
        return skip + self.block(x)


class ZPEncoder(nn.Module):
    """
    Convolutional encoder (causal by default): waveform -> latent frames.

    Layer order: stem conv, then for every ratio a stack of dilated residual
    blocks followed by a strided conv, then a single LSTM layer, and finally a
    projection conv to ``dimension`` channels.

    ``ratios`` is stored reversed on the instance, so the strides are applied
    from the last entry to the first. ``hop_length`` is the product of the
    ratios; the defaults give 240 samples = 15ms at 16kHz.

    Args:
        channels: Number of input channels fed to the stem conv.
        dimension: Channel count produced by the final projection conv.
        n_filters: Channel count after the stem conv; doubled at every
            downsampling stage.
        ratios: Stride of each downsampling stage (applied in reverse order).
        norm: Normalization name forwarded to every conv.
        kernel_size: Kernel size of the stem conv.
        last_kernel_size: Kernel size of the final projection conv.
        residual_kernel_size: Kernel size of the first conv in each residual block.
        residual_dilations: One residual block is built per entry, per stage.
        causal: Forwarded to every conv; also selects the padding mode
            ('constant' when causal, 'reflect' otherwise).
        compress: Bottleneck divisor forwarded to the residual blocks.
    """
    def __init__(
        self,
        channels: int = 1,
        dimension: int = 128,
        n_filters: int = 32,
        ratios: tp.List[int] = [8, 5, 3, 2],
        norm: str = 'weight_norm',
        kernel_size: int = 7,
        last_kernel_size: int = 7,
        residual_kernel_size: int = 7,
        residual_dilations: tp.List[int] = [1, 3, 9],
        causal: bool = True,
        compress: int = 2,
    ):
        super().__init__()
        self.channels = channels
        self.dimension = dimension
        self.n_filters = n_filters
        self.ratios = list(reversed(ratios))
        self.hop_length = np.prod(self.ratios)
        self.residual_dilations = residual_dilations

        padding = 'constant' if causal else 'reflect'
        # Keyword arguments shared by every plain SConv1d in the stack.
        conv_opts = dict(norm=norm, causal=causal, pad_mode=padding)

        # Running channel count; starts at n_filters and doubles per stage.
        width = n_filters
        layers: tp.List[nn.Module] = [SConv1d(channels, width, kernel_size, **conv_opts)]

        for stride in self.ratios:
            # Dilated residual stack before each downsampling step (like
            # SoundStream) to catch long-term dependencies early on.
            layers.extend(
                ZPEncoderResnetBlock(
                    width,
                    kernel_sizes=[residual_kernel_size, 1],
                    dilations=[dilation, 1],
                    norm=norm,
                    causal=causal,
                    compress=compress,
                )
                for dilation in self.residual_dilations
            )
            # Strided conv: doubles the channel count and is given
            # stride = ratio with a kernel twice as long as the stride.
            layers.append(nn.ELU())
            layers.append(
                SConv1d(width, width * 2, kernel_size=stride * 2, stride=stride, **conv_opts)
            )
            width *= 2

        # Only one LSTM layer while EnCodec has 2: counting the transformer
        # and the 30ms lookahead, two layers made the total latency too high
        # to be realtime (tested with RTF).
        layers.append(SLSTM(width, num_layers=1))
        layers.append(nn.ELU())
        layers.append(SConv1d(width, dimension, last_kernel_size, **conv_opts))

        self.model = nn.Sequential(*layers)

    def forward(self, x):
        """Run ``x`` through the full encoder stack and return the result."""
        latents = self.model(x)
        return latents


class ZPDecoder(nn.Module):
    """
    Convolutional decoder (causal by default): latent frames -> waveform.

    Mirror of ZPEncoder: projection conv, a single LSTM layer, then for every
    ratio a transposed conv followed by a stack of dilated residual blocks,
    and finally an output conv with a Tanh.

    The residual blocks use Snake activations (better for audio periodicity),
    props to the improved RVQGAN paper. In the encoder the Snake activations
    didn't improve much, so ELU was kept there; for waveform reconstruction
    they are a game changer.

    Args:
        channels: Number of output channels produced by the last conv.
        dimension: Channel count of the incoming latent frames.
        n_filters: Base channel count; the stack starts at
            ``2 ** len(ratios) * n_filters`` and halves at every stage.
        ratios: Stride of each upsampling stage, applied in the given order.
        norm: Normalization name forwarded to every conv.
        kernel_size: Kernel size of the input projection conv.
        last_kernel_size: Kernel size of the output conv.
        residual_kernel_size: Kernel size of the first conv in each residual block.
        residual_dilations: One residual block is built per entry, per stage.
        causal: Forwarded to every conv; also selects the padding mode
            ('constant' when causal, 'reflect' otherwise).
        compress: Bottleneck divisor forwarded to the residual blocks.
    """
    def __init__(
        self,
        channels: int = 1,
        dimension: int = 128,
        n_filters: int = 32,
        ratios: tp.List[int] = [8, 5, 3, 2],
        norm: str = 'weight_norm',
        kernel_size: int = 7,
        last_kernel_size: int = 7,
        residual_kernel_size: int = 7,
        residual_dilations: tp.List[int] = [1, 3, 9],
        causal: bool = True,
        compress: int = 2,
    ):
        super().__init__()
        self.channels = channels
        self.dimension = dimension
        self.n_filters = n_filters
        self.ratios = ratios
        self.hop_length = np.prod(self.ratios)
        self.residual_dilations = residual_dilations

        padding = 'constant' if causal else 'reflect'
        # Keyword arguments shared by every plain SConv1d in the stack.
        conv_opts = dict(norm=norm, causal=causal, pad_mode=padding)

        # Start at the maximum channel width and halve once per stage.
        channel_mult = 2 ** len(ratios)

        layers: tp.List[nn.Module] = [
            SConv1d(dimension, channel_mult * n_filters, kernel_size, **conv_opts),
            SLSTM(channel_mult * n_filters, num_layers=1),
        ]

        for stride in self.ratios:
            # Transposed conv: halves the channel count and is given
            # stride = ratio with a kernel twice as long as the stride.
            wide = channel_mult * n_filters
            layers.append(nn.ELU())
            layers.append(
                SConvTranspose1d(wide, wide // 2, kernel_size=stride * 2,
                                 stride=stride, norm=norm, causal=causal)
            )
            channel_mult //= 2

            # Residual stack at the new (narrower) width, with Snake built
            # per channel count instead of the default ELU.
            narrow = channel_mult * n_filters
            layers.extend(
                ZPEncoderResnetBlock(
                    narrow,
                    kernel_sizes=[residual_kernel_size, 1],
                    dilations=[dilation, 1],
                    norm=norm,
                    causal=causal,
                    compress=compress,
                    make_act=Snake,
                )
                for dilation in self.residual_dilations
            )

        layers.append(nn.ELU())
        layers.append(SConv1d(channel_mult * n_filters, channels, last_kernel_size, **conv_opts))
        layers.append(nn.Tanh())  # bounds the output to [-1, 1]

        self.model = nn.Sequential(*layers)

    def forward(self, x):
        """Run ``x`` through the full decoder stack and return the result."""
        waveform = self.model(x)
        return waveform