timofeiiz committed on
Commit
57dbf98
·
verified ·
1 Parent(s): 1de3069

Upload folder using huggingface_hub

Browse files
Files changed (4) hide show
  1. README.md +1 -3
  2. config.json +6 -0
  3. model.py +386 -0
  4. model.safetensors +3 -0
README.md CHANGED
@@ -1,3 +1 @@
1
- ---
2
- license: mit
3
- ---
 
1
+ SoundStream implementation. Sample rate: 16000 Hz.
 
 
config.json ADDED
@@ -0,0 +1,6 @@
 
 
 
 
 
 
 
1
+ {
2
+ "channels": 32,
3
+ "codebook_size": 1024,
4
+ "latent_dim": 512,
5
+ "num_quantizers": 8
6
+ }
model.py ADDED
@@ -0,0 +1,386 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn.functional as F
3
+ from torch import nn
4
+ from huggingface_hub import PyTorchModelHubMixin
5
+
6
+
7
class CausalConv1d(nn.Module):
    """1-D convolution that pads only on the left of the time axis.

    Before the (otherwise unpadded) ``nn.Conv1d`` is applied, the last
    dimension of the input receives ``dilation * (kernel_size - 1)`` zeros on
    its left side and none on its right, so an output frame never reads input
    samples located after the position it is aligned with.
    """

    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        kernel_size: int,
        stride: int = 1,
        dilation: int = 1,
    ):
        super().__init__()
        taps_before_current = kernel_size - 1
        self.left_padding = taps_before_current * dilation
        # The attribute name "conv" ends up in the state-dict keys.
        self.conv = nn.Conv1d(
            in_channels,
            out_channels,
            kernel_size,
            stride=stride,
            dilation=dilation,
        )

    def forward(self, x):
        """Zero-pad the left of the last dimension of ``x``, then convolve."""
        padded = F.pad(x, (self.left_padding, 0))
        return self.conv(padded)
30
+
31
+
32
class CausalConvTranspose1d(nn.Module):
    """Transposed 1-D convolution whose output is cut to ``stride * input_length``.

    ``nn.ConvTranspose1d`` is applied without padding and only the first
    ``stride * input_length`` samples of the result are kept; whatever the
    transposed convolution produces beyond that point is discarded.
    """

    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        kernel_size: int,
        stride: int,
    ):
        super().__init__()
        self.stride = stride
        # The attribute name "conv_transpose" ends up in the state-dict keys.
        self.conv_transpose = nn.ConvTranspose1d(
            in_channels,
            out_channels,
            kernel_size,
            stride=stride,
        )

    def forward(self, x):
        """Upsample ``x`` along its last dimension by a factor of ``stride``."""
        upsampled = self.conv_transpose(x)
        keep = self.stride * x.shape[-1]
        return upsampled[..., :keep]
54
+
55
+
56
class ResidualUnit(nn.Module):
    """Residual branch: ELU -> dilated causal conv (k=7) -> ELU -> 1x1 conv.

    The channel count is the same on input and output, so the branch output
    can be added directly to the input.
    """

    def __init__(self, channels: int, dilation: int):
        super().__init__()
        dilated_conv = CausalConv1d(
            in_channels=channels,
            out_channels=channels,
            kernel_size=7,
            dilation=dilation,
        )
        pointwise_conv = nn.Conv1d(in_channels=channels, out_channels=channels, kernel_size=1)
        # Positions inside the Sequential are part of the state-dict keys,
        # so the order (activation, conv, activation, conv) must not change.
        self.block = nn.Sequential(nn.ELU(), dilated_conv, nn.ELU(), pointwise_conv)

    def forward(self, x):
        """Return ``x`` plus the output of the residual branch."""
        branch = self.block(x)
        return x + branch
69
+
70
+
71
class EncoderBlock(nn.Module):
    """Three residual units followed by a strided causal down-sampling conv.

    ``channels`` is the number of OUTPUT channels; the residual units operate
    on ``channels // 2`` channels with dilations 1, 3 and 9, and the final
    convolution (kernel ``2 * s``, stride ``s``) widens to ``channels``.
    """

    def __init__(self, channels: int, s: int):
        super().__init__()
        narrow = channels // 2
        stages = [ResidualUnit(channels=narrow, dilation=rate) for rate in (1, 3, 9)]
        stages.append(
            CausalConv1d(
                in_channels=narrow,
                out_channels=channels,
                kernel_size=2 * s,
                stride=s,
            )
        )
        self.block = nn.Sequential(*stages)

    def forward(self, x):
        """Run ``x`` through the residual stack and the down-sampling conv."""
        return self.block(x)
84
+
85
+
86
class DecoderBlock(nn.Module):
    """Causal transposed-conv up-sampling followed by three residual units.

    ``channels`` is the number of INPUT channels; the transposed convolution
    (kernel ``2 * s``, stride ``s``) narrows to ``channels // 2``, on which
    the residual units with dilations 1, 3 and 9 then operate.
    """

    def __init__(self, channels: int, s: int):
        super().__init__()
        narrow = channels // 2
        stages = [
            CausalConvTranspose1d(
                in_channels=channels,
                out_channels=narrow,
                kernel_size=2 * s,
                stride=s,
            )
        ]
        stages.extend(ResidualUnit(channels=narrow, dilation=rate) for rate in (1, 3, 9))
        self.block = nn.Sequential(*stages)

    def forward(self, x):
        """Run ``x`` through the up-sampling conv and the residual stack."""
        return self.block(x)
99
+
100
+
101
class Encoder(nn.Module):
    """Convolutional encoder mapping single-channel audio to ``dim`` channels.

    Layout: causal conv (k=7, 1 -> ``channels``), four ``EncoderBlock`` stages
    with strides 2, 4, 5, 5 and output widths 2x, 4x, 8x, 16x ``channels``,
    then a causal conv (k=3) projecting to ``dim`` channels.
    """

    def __init__(self, channels: int = 16, dim: int = 512):
        super().__init__()
        # (output-width multiplier, stride) of each down-sampling stage.
        stage_plan = ((2, 2), (4, 4), (8, 5), (16, 5))

        layers = [CausalConv1d(in_channels=1, out_channels=channels, kernel_size=7)]
        for width_mult, stride in stage_plan:
            layers.append(EncoderBlock(channels=width_mult * channels, s=stride))
        layers.append(CausalConv1d(in_channels=16 * channels, out_channels=dim, kernel_size=3))

        # NB: attribute name "encoder" matches training checkpoint keys
        self.encoder = nn.Sequential(*layers)

    def forward(self, audio):
        """Encode ``audio`` (the first conv expects exactly one input channel)."""
        return self.encoder(audio)
117
+
118
+
119
class Decoder(nn.Module):
    """Convolutional decoder mapping ``dim``-channel latents to one channel.

    Layout: causal conv (k=7, ``dim`` -> 16x ``channels``), four
    ``DecoderBlock`` stages with strides 5, 5, 4, 2 and input widths 16x, 8x,
    4x, 2x ``channels``, then a causal conv (k=7) down to a single channel.
    """

    def __init__(self, channels: int = 16, dim: int = 512):
        super().__init__()
        # (input-width multiplier, stride) of each up-sampling stage.
        stage_plan = ((16, 5), (8, 5), (4, 4), (2, 2))

        layers = [CausalConv1d(in_channels=dim, out_channels=16 * channels, kernel_size=7)]
        for width_mult, stride in stage_plan:
            layers.append(DecoderBlock(channels=width_mult * channels, s=stride))
        layers.append(CausalConv1d(in_channels=channels, out_channels=1, kernel_size=7))

        # NB: attribute name "decoder" matches training checkpoint keys
        self.decoder = nn.Sequential(*layers)

    def forward(self, quantized):
        """Decode the ``dim``-channel tensor ``quantized`` to one channel."""
        return self.decoder(quantized)
135
+
136
+
137
+ @torch.no_grad()
138
+ def _k_means(vectors, num_clusters, num_iters):
139
+ n = vectors.size(0)
140
+ device = vectors.device
141
+
142
+ if n >= num_clusters:
143
+ init_indices = torch.randperm(n, device=device)[:num_clusters]
144
+ else:
145
+ init_indices = torch.randint(0, n, (num_clusters,), device=device)
146
+
147
+ centroids = vectors[init_indices].clone()
148
+
149
+ for _ in range(num_iters):
150
+ dists = (
151
+ vectors.pow(2).sum(1, keepdim=True)
152
+ - 2 * vectors @ centroids.t()
153
+ + centroids.pow(2).sum(1)
154
+ )
155
+ assignments = dists.argmin(1)
156
+
157
+ counts = torch.bincount(assignments, minlength=num_clusters).to(vectors.dtype)
158
+ sums = torch.zeros_like(centroids)
159
+ sums.index_add_(0, assignments, vectors)
160
+
161
+ non_empty = counts > 0
162
+ if non_empty.any():
163
+ centroids[non_empty] = sums[non_empty] / counts[non_empty].unsqueeze(1)
164
+
165
+ empty = ~non_empty
166
+ if empty.any():
167
+ centroids[empty] = vectors[torch.randint(0, n, (int(empty.sum()),), device=device)]
168
+
169
+ dists = (
170
+ vectors.pow(2).sum(1, keepdim=True)
171
+ - 2 * vectors @ centroids.t()
172
+ + centroids.pow(2).sum(1)
173
+ )
174
+ counts = torch.bincount(dists.argmin(1), minlength=num_clusters).to(vectors.dtype)
175
+ return centroids, counts
176
+
177
+
178
class VectorQuantizer(nn.Module):
    """Single-codebook vector quantizer with EMA codebook updates.

    The codebook (``embedding``) and its running statistics (``ema_n``,
    ``ema_s``) are registered as buffers rather than parameters.  In training
    mode ``forward`` updates them in place; the EMA update keeps
    ``embedding == ema_s / (ema_n + eps)``.

    Args:
        codebook_size: number of code vectors.
        latent_dim: dimensionality of each code vector.
        decay: EMA decay applied to ``ema_n`` and ``ema_s``.
        dead_code_threshold: codes whose ``ema_n`` falls below this value are
            re-seeded from the current batch.
        kmeans_iters: k-means iterations used for the one-time codebook
            initialisation on the first training batch.
    """

    def __init__(
        self,
        codebook_size: int,
        latent_dim: int,
        decay: float = 0.99,
        dead_code_threshold: float = 2.0,
        kmeans_iters: int = 50,
    ):
        super().__init__()
        self.codebook_size = codebook_size
        self.latent_dim = latent_dim
        self.decay = decay
        self.dead_code_threshold = dead_code_threshold
        self.kmeans_iters = kmeans_iters
        self.eps = 1e-8

        # Buffer names are part of the checkpoint keys.
        self.register_buffer("initialized", torch.tensor(False, dtype=torch.bool))
        self.register_buffer("embedding", torch.randn(codebook_size, latent_dim))
        self.register_buffer("ema_n", torch.zeros(codebook_size))
        self.register_buffer("ema_s", torch.zeros(codebook_size, latent_dim))

    def forward(self, latent):
        """Quantize ``latent`` of shape ``(B, D, T)``.

        In training mode this also initialises the codebook on the first call
        and then applies the EMA update and dead-code replacement.

        Returns:
            dict with ``"quantized"`` (``(B, D, T)``, straight-through
            estimator so gradients flow to ``latent``), ``"indices"``
            (``(B, T)``) and ``"commitment_loss"`` (scalar MSE between the
            latents and their detached codes).
        """
        B, D, T = latent.shape
        flat = latent.transpose(1, 2).reshape(-1, D)

        if self.training and not self.initialized:
            self._init_codebook(flat)
            self.initialized.fill_(True)

        # Codes are looked up BEFORE the codebook is updated for this batch.
        idx, quantized = self._nearest(flat)

        if self.training:
            self._update_ema(flat, idx)
            self._replace_dead_codes(flat)

        commit_loss = F.mse_loss(flat, quantized.detach())
        # Straight-through: forward value is the code, gradient is identity.
        quantized_ste = flat + (quantized - flat).detach()

        return {
            "quantized": quantized_ste.reshape(B, T, D).transpose(1, 2).contiguous(),
            "indices": idx.reshape(B, T),
            "commitment_loss": commit_loss,
        }

    def _nearest(self, flat):
        """Return ``(indices, codes)`` of the nearest code for each row of ``flat``."""
        # Squared Euclidean distance via ||x||^2 - 2 x.e + ||e||^2.
        dists = (
            flat.pow(2).sum(1, keepdim=True)
            - 2 * flat @ self.embedding.t()
            + self.embedding.pow(2).sum(1)
        )
        idx = dists.argmin(1)
        return idx, self.embedding[idx]

    @torch.no_grad()
    def _init_codebook(self, flat):
        """Initialise the codebook and EMA statistics from k-means on ``flat``."""
        centroids, counts = _k_means(flat, self.codebook_size, self.kmeans_iters)
        counts = counts.clamp_min(1.0)
        # Cluster sizes normalised to mean 1 serve as the initial EMA counts.
        w = counts / counts.mean()
        self.embedding.copy_(centroids)
        self.ema_n.copy_(w)
        # ema_s stores count-weighted sums so that ema_s / ema_n == centroids.
        self.ema_s.copy_(centroids * w.unsqueeze(1))

    @torch.no_grad()
    def _update_ema(self, flat, indices):
        """Fold this batch's assignment counts and sums into the EMA buffers."""
        bins = torch.bincount(indices, minlength=self.codebook_size).to(flat.dtype)
        sums = torch.zeros_like(self.ema_s)
        sums.index_add_(0, indices, flat)
        self.ema_n.mul_(self.decay).add_(bins, alpha=1 - self.decay)
        self.ema_s.mul_(self.decay).add_(sums, alpha=1 - self.decay)
        self.embedding.copy_(self.ema_s / (self.ema_n.unsqueeze(1) + self.eps))

    @torch.no_grad()
    def _replace_dead_codes(self, flat):
        """Re-seed codes whose EMA count is below ``dead_code_threshold``.

        Each dead code is replaced by a randomly chosen row of ``flat`` and
        its EMA count is reset to the threshold.
        """
        dead = self.ema_n < self.dead_code_threshold
        if not dead.any():
            return
        n = int(dead.sum())
        picks = flat[torch.randint(0, flat.size(0), (n,), device=flat.device)]
        self.embedding[dead] = picks
        # ema_s is a count-weighted sum: it must be scaled by the count that
        # is written to ema_n below, otherwise the next EMA update
        # (embedding = ema_s / ema_n) would shrink the fresh code by a factor
        # of dead_code_threshold.
        self.ema_s[dead] = picks * self.dead_code_threshold
        self.ema_n[dead] = self.dead_code_threshold

    @torch.no_grad()
    def quantize(self, latent):
        """Inference-only lookup for ``latent`` of shape ``(B, D, T)``.

        Returns a dict with ``"quantized"`` (``(B, D, T)``) and ``"indices"``
        (``(B, T)``); no buffers are modified.
        """
        B, D, T = latent.shape
        flat = latent.transpose(1, 2).reshape(-1, D)
        idx, quantized = self._nearest(flat)
        return {
            "quantized": quantized.reshape(B, T, D).transpose(1, 2).contiguous(),
            "indices": idx.reshape(B, T),
        }

    def decode_indices(self, indices):
        """Look up code vectors for ``indices`` and swap the last two axes.

        For ``indices`` of shape ``(B, T)`` the result is ``(B, D, T)``.
        """
        return self.embedding[indices].transpose(1, 2).contiguous()
274
+
275
+
276
class ResidualVectorQuantizer(nn.Module):
    """Chain of ``VectorQuantizer`` stages, each quantizing the previous residual.

    The first stage quantizes the latent itself; every later stage quantizes
    what the preceding stages left unexplained.  The final quantized tensor is
    the sum of all stage outputs.
    """

    def __init__(
        self,
        latent_dim: int,
        num_quantizers: int = 8,
        codebook_size: int = 1024,
    ):
        super().__init__()
        self.num_quantizers = num_quantizers
        stages = [VectorQuantizer(codebook_size, latent_dim) for _ in range(num_quantizers)]
        # The attribute name "quantizers" ends up in the state-dict keys.
        self.quantizers = nn.ModuleList(stages)

    def forward(self, latent):
        """Quantize ``latent`` with every stage in turn.

        Returns:
            dict with ``"quantized"`` (sum of the stage outputs, same shape as
            ``latent``), ``"indices"`` (per-stage indices stacked on dim 1)
            and ``"commitment_loss"`` (sum of the per-stage losses).
        """
        remaining = latent
        accumulated = torch.zeros_like(latent)
        commit_sum = latent.new_zeros(())
        per_stage_indices = []

        for stage in self.quantizers:
            stage_out = stage(remaining)
            stage_quantized = stage_out["quantized"]
            accumulated = accumulated + stage_quantized
            # Detach so the residual's gradient path runs only through `latent`.
            remaining = remaining - stage_quantized.detach()
            commit_sum = commit_sum + stage_out["commitment_loss"]
            per_stage_indices.append(stage_out["indices"])

        stacked = torch.stack(per_stage_indices, dim=1)
        return {
            "quantized": accumulated,
            "indices": stacked,
            "commitment_loss": commit_sum,
        }

    @torch.no_grad()
    def encode(self, latent):
        """Return the per-stage code indices for ``latent``, stacked on dim 1."""
        remaining = latent
        per_stage_indices = []
        for stage in self.quantizers:
            stage_out = stage.quantize(remaining)
            per_stage_indices.append(stage_out["indices"])
            remaining = remaining - stage_out["quantized"]
        return torch.stack(per_stage_indices, dim=1)

    @torch.no_grad()
    def decode(self, indices):
        """Sum the code vectors selected by ``indices[:, i]`` over all stages.

        Returns ``None`` when the module holds no quantizer stages.
        """
        total = None
        for stage_no, stage in enumerate(self.quantizers):
            contribution = stage.decode_indices(indices[:, stage_no])
            if total is None:
                total = contribution
            else:
                total = total + contribution
        return total
326
+
327
+
328
class SoundStreamCodec(
    nn.Module,
    PyTorchModelHubMixin,
    library_name="soundstream-impl",
    license="mit",
):
    """Encoder -> residual vector quantizer -> decoder audio codec.

    Inputs are right-padded so their length is a multiple of
    ``downsampling_factor`` (the product of ``strides``, i.e. 200) before
    being encoded.  ``strides`` is kept on the instance for that computation
    only; the per-stage strides of ``Encoder``/``Decoder`` are fixed inside
    those classes.
    """

    def __init__(
        self,
        channels: int = 32,
        latent_dim: int = 512,
        codebook_size: int = 1024,
        num_quantizers: int = 8,
    ):
        super().__init__()
        self.strides = (2, 4, 5, 5)
        factor = 1
        for stride in self.strides:
            factor = factor * stride
        self.downsampling_factor = factor

        # Attribute names below are part of the checkpoint keys.
        self.encoder = Encoder(channels=channels, dim=latent_dim)
        self.quantizer = ResidualVectorQuantizer(
            latent_dim=latent_dim,
            codebook_size=codebook_size,
            num_quantizers=num_quantizers,
        )
        self.decoder = Decoder(channels=channels, dim=latent_dim)

    def forward(self, audio, **kwargs):
        """Encode, quantize and reconstruct ``audio``.

        Extra keyword arguments are accepted and ignored.

        Returns:
            dict with ``"reconstructed_audio"`` (trimmed to the input length),
            ``"latent"`` (encoder output) and every entry of the quantizer's
            output dict.
        """
        input_length = audio.size(-1)
        padded = self._pad_to_stride(audio)

        latent = self.encoder(padded)
        quantizer_out = self.quantizer(latent)
        decoded = self.decoder(quantizer_out["quantized"])

        result = {
            "reconstructed_audio": decoded[..., :input_length],
            "latent": latent,
        }
        result.update(quantizer_out)
        return result

    @torch.no_grad()
    def encode(self, audio):
        """Return the stacked quantizer indices for ``audio``."""
        padded = self._pad_to_stride(audio)
        latent = self.encoder(padded)
        return self.quantizer.encode(latent)

    @torch.no_grad()
    def decode(self, indices, original_length=None):
        """Reconstruct audio from quantizer ``indices``.

        When ``original_length`` is given, the output is trimmed to that many
        samples along the last dimension.
        """
        waveform = self.decoder(self.quantizer.decode(indices))
        if original_length is None:
            return waveform
        return waveform[..., :original_length]

    def _pad_to_stride(self, audio):
        """Right-pad ``audio`` (replicating its last sample) to a multiple of the stride."""
        missing = -audio.size(-1) % self.downsampling_factor
        if missing == 0:
            return audio
        return F.pad(audio, (0, missing), mode="replicate")
model.safetensors ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:eaf8b7aa910c717a05558fd1bf214823564331504c29a9cf2cf50fc3bf29e452
3
+ size 74533004