Lucabr01 committed on
Commit
11c5699
·
verified ·
1 Parent(s): b9fd30d

Upload zpcodec/Utils.py with huggingface_hub

Browse files
Files changed (1) hide show
  1. zpcodec/Utils.py +358 -0
zpcodec/Utils.py ADDED
@@ -0,0 +1,358 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+ #
4
+ # This source code is licensed under the MIT License found in
5
+ # META_LICENSE.txt in the root directory of this source tree.
6
+
7
+ """Convolutional layer wrappers and utilities, plus Snake and LSTM layer modules."""
8
+
9
+ import math
10
+ import typing as tp
11
+ import warnings
12
+
13
+ import torch
14
+ from torch import nn
15
+ from torch.nn import functional as F
16
+ from torch.nn.utils import spectral_norm, weight_norm
17
+ import einops
18
+
19
+
20
class ConvLayerNorm(nn.LayerNorm):
    """LayerNorm that can be applied directly to convolution-style tensors.

    The trailing axis of the input is moved to position 1 (right after the
    leading axis), so the remaining axes end up last and are the ones
    ``nn.LayerNorm`` normalizes over. The original axis order is restored
    before returning.
    """

    def __init__(self, normalized_shape: tp.Union[int, tp.List[int], torch.Size], **kwargs):
        super().__init__(normalized_shape, **kwargs)

    def forward(self, x):
        # (b, ..., t) -> (b, t, ...): put the axes to normalize last.
        moved = einops.rearrange(x, 'b ... t -> b t ...')
        normed = super().forward(moved)
        # (b, t, ...) -> (b, ..., t): back to the caller's layout.
        return einops.rearrange(normed, 'b t ... -> b ... t')
33
+
34
+
35
# Every normalization name accepted by `apply_parametrization_norm` and
# `get_norm_module` (both assert membership in this set).
CONV_NORMALIZATIONS = frozenset({
    'none',
    'weight_norm',
    'spectral_norm',
    'time_layer_norm',
    'layer_norm',
    'time_group_norm',
})
37
+
38
+
39
def apply_parametrization_norm(module: nn.Module, norm: str = 'none') -> nn.Module:
    """Wrap ``module`` with the weight reparametrization selected by ``norm``.

    Args:
        module: the layer to (possibly) reparametrize.
        norm: one of ``CONV_NORMALIZATIONS``.

    Returns:
        ``weight_norm(module)`` for ``'weight_norm'``, ``spectral_norm(module)``
        for ``'spectral_norm'``, and ``module`` itself for every other name.

    Raises:
        AssertionError: if ``norm`` is not in ``CONV_NORMALIZATIONS``.
    """
    assert norm in CONV_NORMALIZATIONS
    if norm == 'weight_norm':
        return weight_norm(module)
    if norm == 'spectral_norm':
        return spectral_norm(module)
    # Membership in CONV_NORMALIZATIONS was already checked above, so any
    # remaining choice needs no reparametrization.
    return module
49
+
50
+
51
def get_norm_module(module: nn.Module, causal: bool = False, norm: str = 'none', **norm_kwargs) -> nn.Module:
    """Build the normalization layer that should follow ``module``.

    Args:
        module: the convolution whose ``out_channels`` sizes the norm layer.
        causal: when True, normalizations that cannot be evaluated causally
            are rejected.
        norm: one of ``CONV_NORMALIZATIONS``.
        **norm_kwargs: forwarded to the normalization layer's constructor.

    Returns:
        ``ConvLayerNorm`` for ``'layer_norm'``, a single-group ``nn.GroupNorm``
        for ``'time_group_norm'``, and ``nn.Identity`` for every other name.

    Raises:
        ValueError: if ``norm`` is ``'time_group_norm'`` and ``causal`` is True.
        AssertionError: if ``norm`` is unknown, or if a norm layer is requested
            for a module that is not a ``_ConvNd``.
    """
    assert norm in CONV_NORMALIZATIONS
    if norm not in ('layer_norm', 'time_group_norm'):
        # Nothing to apply on the activations for the remaining choices.
        return nn.Identity()
    if norm == 'time_group_norm' and causal:
        raise ValueError("GroupNorm doesn't support causal evaluation.")
    assert isinstance(module, nn.modules.conv._ConvNd)
    channels = module.out_channels
    if norm == 'layer_norm':
        return ConvLayerNorm(channels, **norm_kwargs)
    return nn.GroupNorm(1, channels, **norm_kwargs)
66
+
67
+
68
def get_extra_padding_for_conv1d(x: torch.Tensor, kernel_size: int, stride: int,
                                 padding_total: int = 0) -> int:
    """Return the number of extra trailing samples needed for a full last window.

    The value is the difference between the smallest length that yields a
    whole number of convolution frames (given ``kernel_size``, ``stride`` and
    ``padding_total``) and the current size of the last axis of ``x``.
    See `pad_for_conv1d`.
    """
    current = x.shape[-1]
    # Possibly fractional frame count for the current length.
    frames = (current - kernel_size + padding_total) / stride + 1
    whole_frames = math.ceil(frames)
    # Length for which the frame count is exactly `whole_frames`.
    target = (whole_frames - 1) * stride + (kernel_size - padding_total)
    return target - current
76
+
77
+
78
def pad_for_conv1d(x: torch.Tensor, kernel_size: int, stride: int, padding_total: int = 0):
    """Right-pad ``x`` with zeros so that the last convolution window is full.

    The extra padding goes at the end. It is needed to be able to rebuild an
    output of the same length; otherwise, even with padding, some time steps
    might get removed.
    For instance, with total padding = 4, kernel size = 4, stride = 2:
        0 0 1 2 3 4 5 0 0   # (0s are padding)
        1   2   3           # (output frames of a convolution, last 0 is never used)
        0 0 1 2 3 4 5 0     # (output of tr. conv., but pos. 5 is going to get removed as padding)
            1 2 3 4         # once you removed padding, we are missing one time step !
    """
    right = get_extra_padding_for_conv1d(x, kernel_size, stride, padding_total)
    return F.pad(x, (0, right))
91
+
92
+
93
def pad1d(x: torch.Tensor, paddings: tp.Tuple[int, int], mode: str = 'zero', value: float = 0.):
    """Pad the last axis of ``x``; a thin wrapper around ``F.pad``.

    Two things are handled on top of ``F.pad``:

    * ``mode='zero'`` (the default) is accepted as an alias of ``'constant'``.
      ``F.pad`` itself has no ``'zero'`` mode, so forwarding the default
      unchanged would raise. ``value`` (0 by default) is used as the fill.
    * ``mode='reflect'`` works on inputs that are not longer than the
      requested padding: extra zeros are appended on the right before
      reflecting and trimmed off again afterwards.

    Args:
        x: tensor to pad along its last axis.
        paddings: ``(left, right)`` amounts, both non-negative.
        mode: ``'zero'`` or any mode understood by ``F.pad``.
        value: fill value, forwarded to ``F.pad``.

    Returns:
        The padded tensor; its last axis grows by ``left + right``.

    Raises:
        AssertionError: if either padding amount is negative.
    """
    length = x.shape[-1]
    padding_left, padding_right = paddings
    assert padding_left >= 0 and padding_right >= 0, (padding_left, padding_right)
    if mode == 'zero':
        # Translate the module-specific alias into the name F.pad understands.
        mode = 'constant'
    if mode == 'reflect':
        max_pad = max(padding_left, padding_right)
        extra_pad = 0
        if length <= max_pad:
            # Reflection needs an input strictly longer than the padding.
            extra_pad = max_pad - length + 1
            x = F.pad(x, (0, extra_pad))
        padded = F.pad(x, paddings, mode, value)
        # Drop the temporary zeros that were only added to allow reflection.
        end = padded.shape[-1] - extra_pad
        return padded[..., :end]
    return F.pad(x, paddings, mode, value)
111
+
112
+
113
def unpad1d(x: torch.Tensor, paddings: tp.Tuple[int, int]):
    """Strip ``(left, right)`` samples from the last axis of ``x``.

    A right padding of zero is handled correctly because the end index is
    computed explicitly instead of using a negative slice bound. Only for 1d!

    Raises:
        AssertionError: if a padding amount is negative or the two together
            exceed the size of the last axis.
    """
    left, right = paddings
    assert left >= 0 and right >= 0, (left, right)
    total = x.shape[-1]
    assert (left + right) <= total
    return x[..., left: total - right]
120
+
121
+
122
class NormConv1d(nn.Module):
    """``nn.Conv1d`` bundled with the normalization selected by ``norm``.

    Gives every normalization approach the same interface: positional and
    extra keyword arguments go to ``nn.Conv1d``, while ``causal``, ``norm``
    and ``norm_kwargs`` configure the normalization.
    """

    def __init__(self, *args, causal: bool = False, norm: str = 'none',
                 norm_kwargs: tp.Dict[str, tp.Any] = {}, **kwargs):
        super().__init__()
        conv = nn.Conv1d(*args, **kwargs)
        # Weight-level normalization (weight/spectral norm) wraps the conv ...
        self.conv = apply_parametrization_norm(conv, norm)
        # ... while activation-level normalization is a separate module.
        self.norm = get_norm_module(self.conv, causal, norm, **norm_kwargs)
        self.norm_type = norm

    def forward(self, x):
        return self.norm(self.conv(x))
137
+
138
+
139
class NormConv2d(nn.Module):
    """``nn.Conv2d`` bundled with the normalization selected by ``norm``.

    Gives every normalization approach the same interface: positional and
    extra keyword arguments go to ``nn.Conv2d``, while ``norm`` and
    ``norm_kwargs`` configure the normalization (never causal here).

    NOTE: another class named ``NormConv2d`` is defined further down in this
    file; once the module has been imported, the name refers to that later
    definition, not to this one.
    """

    def __init__(self, *args, norm: str = 'none',
                 norm_kwargs: tp.Dict[str, tp.Any] = {}, **kwargs):
        super().__init__()
        conv = nn.Conv2d(*args, **kwargs)
        # Weight-level normalization (weight/spectral norm) wraps the conv ...
        self.conv = apply_parametrization_norm(conv, norm)
        # ... while activation-level normalization is a separate module.
        self.norm = get_norm_module(self.conv, causal=False, norm=norm, **norm_kwargs)
        self.norm_type = norm

    def forward(self, x):
        return self.norm(self.conv(x))
154
+
155
+
156
class NormConvTranspose1d(nn.Module):
    """``nn.ConvTranspose1d`` bundled with the normalization selected by ``norm``.

    Gives every normalization approach the same interface: positional and
    extra keyword arguments go to ``nn.ConvTranspose1d``, while ``causal``,
    ``norm`` and ``norm_kwargs`` configure the normalization.
    """

    def __init__(self, *args, causal: bool = False, norm: str = 'none',
                 norm_kwargs: tp.Dict[str, tp.Any] = {}, **kwargs):
        super().__init__()
        convtr = nn.ConvTranspose1d(*args, **kwargs)
        # Weight-level normalization (weight/spectral norm) wraps the conv ...
        self.convtr = apply_parametrization_norm(convtr, norm)
        # ... while activation-level normalization is a separate module.
        self.norm = get_norm_module(self.convtr, causal, norm, **norm_kwargs)
        self.norm_type = norm

    def forward(self, x):
        return self.norm(self.convtr(x))
171
+
172
+
173
class NormConvTranspose2d(nn.Module):
    """Wrapper around ConvTranspose2d and normalization applied to this conv
    to provide a uniform interface across normalization approaches.

    Positional and extra keyword arguments go to ``nn.ConvTranspose2d``;
    ``norm`` and ``norm_kwargs`` configure the normalization (never causal
    here).
    """

    def __init__(self, *args, norm: str = 'none',
                 norm_kwargs: tp.Dict[str, tp.Any] = {}, **kwargs):
        super().__init__()
        self.convtr = apply_parametrization_norm(nn.ConvTranspose2d(*args, **kwargs), norm)
        self.norm = get_norm_module(self.convtr, causal=False, norm=norm, **norm_kwargs)
        # Record the normalization name, as the other Norm* wrappers in this
        # file do, so all four expose the same attributes.
        self.norm_type = norm

    def forward(self, x):
        x = self.convtr(x)
        x = self.norm(x)
        return x
187
+
188
+
189
class SConv1d(nn.Module):
    """1-D convolution that pads its own input and applies a normalization.

    The padding is computed in ``forward`` from the kernel size, stride and
    dilation. It is placed entirely on the left for causal convolutions
    (plus whatever is needed on the right to complete the last window) and
    split between both sides otherwise.
    """

    def __init__(self, in_channels: int, out_channels: int,
                 kernel_size: int, stride: int = 1, dilation: int = 1,
                 groups: int = 1, bias: bool = True, causal: bool = False,
                 norm: str = 'none', norm_kwargs: tp.Dict[str, tp.Any] = {},
                 pad_mode: str = 'reflect'):
        super().__init__()
        # Striding and dilating at the same time is unusual: tell the user.
        if stride > 1 and dilation > 1:
            warnings.warn('SConv1d has been initialized with stride > 1 and dilation > 1'
                          f' (kernel_size={kernel_size} stride={stride}, dilation={dilation}).')
        self.conv = NormConv1d(in_channels, out_channels, kernel_size, stride,
                               dilation=dilation, groups=groups, bias=bias, causal=causal,
                               norm=norm, norm_kwargs=norm_kwargs)
        self.causal = causal
        self.pad_mode = pad_mode

    def forward(self, x):
        # Unpacking also rejects inputs that do not have exactly three axes.
        _batch, _channels, _steps = x.shape
        inner = self.conv.conv
        stride = inner.stride[0]
        # Effective kernel size once dilation is taken into account.
        kernel_size = (inner.kernel_size[0] - 1) * inner.dilation[0] + 1
        padding_total = kernel_size - stride
        extra_padding = get_extra_padding_for_conv1d(x, kernel_size, stride, padding_total)
        if self.causal:
            # All of the fixed padding goes to the left for causal convs.
            pads = (padding_total, extra_padding)
        else:
            # Asymmetric split, required for odd strides.
            padding_right = padding_total // 2
            padding_left = padding_total - padding_right
            pads = (padding_left, padding_right + extra_padding)
        return self.conv(pad1d(x, pads, mode=self.pad_mode))
226
+
227
+
228
class SConvTranspose1d(nn.Module):
    """1-D transposed convolution that trims its own padding and applies a normalization.

    After the transposed convolution, ``kernel_size - stride`` samples are
    removed from the output. For causal convolutions ``trim_right_ratio``
    decides which share of them is taken from the right; otherwise they are
    split between both sides.
    """

    def __init__(self, in_channels: int, out_channels: int,
                 kernel_size: int, stride: int = 1, causal: bool = False,
                 norm: str = 'none', trim_right_ratio: float = 1.,
                 norm_kwargs: tp.Dict[str, tp.Any] = {}):
        super().__init__()
        self.convtr = NormConvTranspose1d(in_channels, out_channels, kernel_size, stride,
                                          causal=causal, norm=norm, norm_kwargs=norm_kwargs)
        self.causal = causal
        self.trim_right_ratio = trim_right_ratio
        assert self.causal or self.trim_right_ratio == 1., \
            "`trim_right_ratio` != 1.0 only makes sense for causal convolutions"
        assert 0. <= self.trim_right_ratio <= 1.

    def forward(self, x):
        inner = self.convtr.convtr
        padding_total = inner.kernel_size[0] - inner.stride[0]

        y = self.convtr(x)

        # Only the fixed padding is trimmed here. Extra padding coming from
        # `pad_for_conv1d` is removed at the very end, when the output is cut
        # to the right length; removing it here would require passing the
        # length of the matching encoder layer.
        if self.causal:
            # Share of the padding trimmed from the right; with
            # trim_right_ratio = 1.0 everything is trimmed from the right.
            padding_right = math.ceil(padding_total * self.trim_right_ratio)
        else:
            # Asymmetric split, required for odd strides.
            padding_right = padding_total // 2
        padding_left = padding_total - padding_right
        return unpad1d(y, (padding_left, padding_right))
268
+
269
def get_2d_padding(
    kernel_size: tp.Tuple[int, int],
    dilation: tp.Tuple[int, int] = (1, 1),
) -> tp.Tuple[int, int]:
    """Return ``((k - 1) * d) // 2`` for each of the two kernel axes.

    NOTE: a function with the same name and the same computation is defined
    again further down in this file; that later definition is the one bound
    to the name after import.
    """
    pad_first = ((kernel_size[0] - 1) * dilation[0]) // 2
    pad_second = ((kernel_size[1] - 1) * dilation[1]) // 2
    return (pad_first, pad_second)
277
+
278
+
279
class Snake(nn.Module):
    """Element-wise activation ``x + sin(alpha * x)**2 / (|alpha| + 1e-9)``.

    ``alpha`` is a learnable parameter of shape ``(1, channels, 1)`` whose
    entries all start at ``alpha_init``; it is broadcast against the input.
    """

    def __init__(self, channels: int, alpha_init: float = 1.0):
        super().__init__()
        initial = torch.ones(1, channels, 1) * alpha_init
        self.alpha = nn.Parameter(initial)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        periodic = torch.sin(self.alpha * x).pow(2)
        # The small constant keeps the division finite when alpha is zero.
        scale = self.alpha.abs() + 1e-9
        return x + periodic / scale
286
+
287
+
288
class SLSTM(nn.Module):
    """LSTM that hides both the hidden state and the data layout from the caller.

    The input is expected in convolutional layout: its last axis is moved to
    the front before being fed to ``nn.LSTM`` and moved back afterwards.
    With ``skip=True`` the LSTM input is added to its output.
    """

    def __init__(self, dimension: int, num_layers: int = 2, skip: bool = True):
        super().__init__()
        self.skip = skip
        self.lstm = nn.LSTM(dimension, dimension, num_layers)

    def forward(self, x):
        # Last axis first, as consumed by the (non batch_first) LSTM.
        seq_first = x.permute(2, 0, 1)
        out, _state = self.lstm(seq_first)
        if self.skip:
            out = out + seq_first
        # Back to the layout the caller passed in.
        return out.permute(1, 2, 0)
305
+
306
+
307
class NormConv2d(nn.Module):
    """Small self-contained Conv2d wrapper with optional weight reparametrization.

    The original EnCodec code imports NormConv2d from its internal modules.
    This local version avoids making the discriminator depend on the full
    EnCodec codebase.

    NOTE: this definition rebinds the name ``NormConv2d``, replacing the
    class of the same name defined earlier in this file.

    Args:
        in_channels, out_channels, kernel_size, stride, dilation, padding,
            bias: forwarded to ``nn.Conv2d``.
        norm: ``"weight_norm"``, ``"spectral_norm"``, or ``"none"`` / ``None``
            for a plain convolution.

    Raises:
        ValueError: if ``norm`` is none of the supported values.
    """

    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        kernel_size: tp.Tuple[int, int],
        stride: tp.Tuple[int, int] = (1, 1),
        dilation: tp.Tuple[int, int] = (1, 1),
        padding: tp.Tuple[int, int] = (0, 0),
        norm: str = "weight_norm",
        bias: bool = True,
    ) -> None:
        super().__init__()
        layer = nn.Conv2d(
            in_channels=in_channels,
            out_channels=out_channels,
            kernel_size=kernel_size,
            stride=stride,
            dilation=dilation,
            padding=padding,
            bias=bias,
        )

        # Map each reparametrization name to the helper that applies it.
        wrappers = {
            "weight_norm": nn.utils.weight_norm,
            "spectral_norm": nn.utils.spectral_norm,
        }
        if norm in wrappers:
            layer = wrappers[norm](layer)
        elif norm not in {"none", None}:
            raise ValueError(f"Unsupported norm={norm!r}. Use 'weight_norm', 'spectral_norm', or 'none'.")

        self.conv = layer

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.conv(x)
350
+
351
def get_2d_padding(
    kernel_size: tp.Tuple[int, int],
    dilation: tp.Tuple[int, int] = (1, 1),
) -> tp.Tuple[int, int]:
    """Same-padding approximation for Conv2d on [B, C, time, freq].

    Each entry is ``((k - 1) * d) // 2`` for the corresponding kernel axis.
    NOTE: this rebinds the name of an earlier function in this file that
    performs the same computation.
    """
    return (
        ((kernel_size[0] - 1) * dilation[0]) // 2,
        ((kernel_size[1] - 1) * dilation[1]) // 2,
    )