LenDigLearn committed on
Commit
c5b3420
·
verified ·
1 Parent(s): 8219d4b

Upload NeucodecDecoder.py

Browse files
Files changed (1) hide show
  1. NeucodecDecoder.py +565 -0
NeucodecDecoder.py ADDED
@@ -0,0 +1,565 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn as nn
3
+
4
+ from einops import rearrange
5
+ from torchtune.modules import RotaryPositionalEmbeddings
6
+ from vector_quantize_pytorch import ResidualFSQ
7
+
8
+ from huggingface_hub import PyTorchModelHubMixin, hf_hub_download
9
+
10
+ # the following implementations were taken from the NeuCodec repository and slightly changed
11
+ # sources https://github.com/neuphonic/neucodec/blob/main/neucodec/model.py, https://github.com/neuphonic/neucodec/blob/main/neucodec/codec_decoder_vocos.py and https://github.com/neuphonic/neucodec/blob/main/neucodec/bs_roformer5.py
12
+
13
class RMSNorm(torch.nn.Module):
    r"""Root-mean-square normalisation over the last axis with a learnable gain.

    Adapted from https://github.com/meta-llama/llama/blob/main/llama/model.py
    """

    def __init__(self, dim: int, eps: float = 1e-6):
        """
        Args:
            dim: Length of the learnable ``weight`` vector (size of the last axis).
            eps: Constant added to the mean square before the reciprocal square root.
        """
        super().__init__()
        self.eps = eps
        self.weight = nn.Parameter(torch.ones(dim))

    def forward(self, x):
        """Multiply ``x`` by the reciprocal RMS of its last axis, then by ``weight``."""
        mean_square = x.pow(2).mean(dim=-1, keepdim=True)
        inv_rms = torch.rsqrt(mean_square + self.eps)
        return x * inv_rms * self.weight
24
+
25
+
26
class MLP(nn.Module):
    """Bias-free feed-forward block: expand to ``4 * dim``, SiLU, project back to ``dim``."""

    def __init__(self, dim: int) -> None:
        """
        Args:
            dim: Input and output feature size; the hidden layer is four times wider.
        """
        super().__init__()

        hidden = 4 * dim
        self.fc1 = nn.Linear(dim, hidden, bias=False)
        self.silu = nn.SiLU()
        self.fc2 = nn.Linear(hidden, dim, bias=False)

    def forward(self, x):
        """Apply ``fc1 -> SiLU -> fc2`` to the last axis of ``x``."""
        return self.fc2(self.silu(self.fc1(x)))
39
+
40
+
41
class Attention(nn.Module):
    """Non-causal multi-head self-attention with a rotary embedding applied to q and k."""

    def __init__(
        self, dim: int, n_heads: int, rotary_embed: RotaryPositionalEmbeddings
    ):
        """
        Args:
            dim: Model width; must be divisible by ``n_heads``.
            n_heads: Number of attention heads.
            rotary_embed: Rotary embedding module applied to queries and keys.
        """
        super().__init__()

        assert dim % n_heads == 0

        self.n_heads = n_heads
        self.dim = dim
        self.rotary_embed = rotary_embed

        # The forward pass relies on torch's fused scaled_dot_product_attention.
        self.flash = hasattr(torch.nn.functional, "scaled_dot_product_attention")
        assert self.flash, "Must have flash attention."

        # Single projection producing q, k and v stacked along the feature axis.
        self.c_attn = nn.Linear(dim, 3 * dim, bias=False)
        self.c_proj = nn.Linear(dim, dim, bias=False)

    def forward(self, x):
        r"""
        Args:
            x: (b, t, h*d)

        Constants:
            b: batch_size
            t: time steps
            r: 3
            h: heads_num
            d: heads_dim

        Returns:
            Tensor of shape (b, t, h*d).
        """
        # Unpacking doubles as a rank check: anything but a 3-D input fails here.
        _batch, _steps, _width = x.size()

        q, k, v = rearrange(
            self.c_attn(x), "b t (r h d) -> r b h t d", r=3, h=self.n_heads
        )
        # q, k, v: (b, h, t, d)

        # NOTE(review): torchtune documents RotaryPositionalEmbeddings input as
        # [b, s, n_h, h_d], whereas q/k here are laid out (b, h, t, d) - confirm
        # this is intended before touching it; changing the layout would alter
        # the outputs of already-trained weights.
        q = self.rotary_embed(q)
        k = self.rotary_embed(k)

        # The constructor asserts that scaled_dot_product_attention exists, so no
        # fallback branch is needed (the old guard would only have left `y` unbound).
        y = torch.nn.functional.scaled_dot_product_attention(
            q, k, v, attn_mask=None, dropout_p=0, is_causal=False
        )

        y = rearrange(y, "b h t d -> b t (h d)")

        # shape: (b, t, h*d)
        return self.c_proj(y)
92
+
93
+
94
class TransformerBlock(nn.Module):
    """Pre-norm transformer layer: RMSNorm -> attention and RMSNorm -> MLP, each with a residual."""

    def __init__(
        self, dim: int, n_heads: int, rotary_embed: RotaryPositionalEmbeddings
    ):
        """
        Args:
            dim: Model width shared by the norms, the attention and the MLP.
            n_heads: Number of attention heads.
            rotary_embed: Rotary embedding module handed to the attention layer.
        """
        super().__init__()
        self.dim = dim
        self.n_heads = n_heads

        self.att_norm = RMSNorm(dim)
        self.ffn_norm = RMSNorm(dim)
        self.att = Attention(dim=dim, n_heads=n_heads, rotary_embed=rotary_embed)
        self.mlp = MLP(dim=dim)

    def forward(
        self,
        x: torch.Tensor,
    ):
        """Return ``x`` after the attention and feed-forward residual updates."""
        attended = self.att(self.att_norm(x))
        x = x + attended
        fed_forward = self.mlp(self.ffn_norm(x))
        return x + fed_forward
114
+
115
class ISTFT(nn.Module):
    """
    Hand-written inverse STFT, needed because torch.istft only supports windowed
    reconstruction with `center=True`: with any other padding its NOLA (Nonzero
    Overlap Add) check fails at the signal edges.
    See issue: https://github.com/pytorch/pytorch/issues/62323
    For neural vocoding we want CNN-style "same" padding; the NOLA constraint is
    satisfied because the padded edge samples are trimmed off anyway.

    Args:
        n_fft (int): Size of Fourier transform.
        hop_length (int): The distance between neighboring sliding window frames.
        win_length (int): The size of window frame and STFT filter.
        padding (str, optional): Type of padding. Options are "center" or "same". Defaults to "same".

    Raises:
        ValueError: If ``padding`` is neither "center" nor "same".
    """

    def __init__(
        self, n_fft: int, hop_length: int, win_length: int, padding: str = "same"
    ):
        super().__init__()
        if padding not in ("center", "same"):
            raise ValueError("Padding must be 'center' or 'same'.")
        self.padding = padding
        self.n_fft = n_fft
        self.hop_length = hop_length
        self.win_length = win_length
        # Non-persistent so the window is left out of the state dict
        # (changed from the upstream code for safetensors compatibility).
        self.register_buffer(
            "window", torch.hann_window(win_length), persistent=False
        )

    def forward(self, spec: torch.Tensor) -> torch.Tensor:
        """
        Compute the Inverse Short Time Fourier Transform (ISTFT) of a complex spectrogram.

        Args:
            spec (Tensor): Input complex spectrogram of shape (B, N, T), where B is the batch size,
                            N is the number of frequency bins, and T is the number of time frames.

        Returns:
            Tensor: Reconstructed time-domain signal of shape (B, L), where L is the length of the output signal.
        """
        if self.padding == "center":
            # Delegate to the native implementation, which handles this mode.
            return torch.istft(
                spec,
                self.n_fft,
                self.hop_length,
                self.win_length,
                self.window,
                center=True,
            )
        if self.padding != "same":
            raise ValueError("Padding must be 'center' or 'same'.")

        # Number of samples trimmed from each end of the overlap-added signal.
        trim = (self.win_length - self.hop_length) // 2

        assert spec.dim() == 3, "Expected a 3D tensor as input"
        _, _, num_frames = spec.shape

        # Per-frame inverse FFT along the frequency axis, then synthesis window.
        frames = torch.fft.irfft(spec, self.n_fft, dim=1, norm="backward")
        frames = frames * self.window[None, :, None]

        # `fold` performs the overlap-add; both calls share the same geometry.
        total_len = (num_frames - 1) * self.hop_length + self.win_length
        fold_geometry = dict(
            output_size=(1, total_len),
            kernel_size=(1, self.win_length),
            stride=(1, self.hop_length),
        )
        signal = torch.nn.functional.fold(frames, **fold_geometry)[
            :, 0, 0, trim:-trim
        ]

        # Overlap-added squared window, used to undo the windowing gain.
        squared_window = (
            self.window.square().expand(1, num_frames, -1).transpose(1, 2)
        )
        envelope = torch.nn.functional.fold(
            squared_window, **fold_geometry
        ).squeeze()[trim:-trim]

        # Guard against dividing by (near-)zero envelope samples.
        assert (envelope > 1e-11).all()
        return signal / envelope
199
+
200
+
201
class FourierHead(nn.Module):
    """Abstract parent of heads that turn model features into audio via an inverse Fourier transform."""

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Interface definition only; concrete heads override this method.

        Args:
            x (Tensor): Features of shape (B, L, H) - batch size, sequence length
                        and model dimension.

        Returns:
            Tensor: Time-domain audio of shape (B, T), with T the output length.

        Raises:
            NotImplementedError: Always, in this base class.
        """
        raise NotImplementedError("Subclasses must implement the forward method.")
214
+
215
+
216
class ISTFTHead(FourierHead):
    """
    ISTFT Head module for predicting STFT complex coefficients.

    A linear layer maps each frame to ``n_fft + 2`` values, which are split in half
    into a log-magnitude part and a phase part before the inverse STFT.

    Args:
        dim (int): Hidden dimension of the model.
        n_fft (int): Size of Fourier transform.
        hop_length (int): The distance between neighboring sliding window frames, which should align with
                          the resolution of the input features.
        padding (str, optional): Type of padding. Options are "center" or "same". Defaults to "same".
    """

    def __init__(self, dim: int, n_fft: int, hop_length: int, padding: str = "same"):
        super().__init__()
        out_dim = n_fft + 2
        self.out = torch.nn.Linear(dim, out_dim)
        self.istft = ISTFT(
            n_fft=n_fft, hop_length=hop_length, win_length=n_fft, padding=padding
        )

    def forward(self, x: torch.Tensor) -> "tuple[torch.Tensor, torch.Tensor]":
        """
        Forward pass of the ISTFTHead module.

        Args:
            x (Tensor): Input tensor of shape (B, L, H), where B is the batch size,
                        L is the sequence length, and H denotes the model dimension.

        Returns:
            A pair ``(audio, x_pred)``:
              * ``audio``: the ISTFT output with a singleton axis inserted at dim 1.
              * ``x_pred``: the raw linear projection, transposed to (B, n_fft + 2, L).
        """
        x_pred = self.out(x).transpose(1, 2)

        # First half of the channels is log-magnitude, second half is phase.
        log_mag, phase = x_pred.chunk(2, dim=1)

        # Safeguard to prevent excessively large magnitudes.
        mag = torch.clip(torch.exp(log_mag), max=1e2)

        # cos/sin wrap the phase and give the real and imaginary parts directly,
        # so no atan2 / complex-exponential round trip is needed.
        real = torch.cos(phase)
        imag = torch.sin(phase)
        spec = mag * (real + 1j * imag)

        audio = self.istft(spec)
        return audio.unsqueeze(1), x_pred
266
+
267
+
268
def nonlinearity(x):
    """Swish activation: ``x`` multiplied by its own sigmoid."""
    gate = torch.sigmoid(x)
    return x * gate
271
+
272
+
273
def Normalize(in_channels, num_groups=32):
    """Build an affine ``GroupNorm`` (eps=1e-6) over ``in_channels`` channels.

    Args:
        in_channels: Number of channels to normalise.
        num_groups: Number of groups the channels are split into. Defaults to 32.
    """
    group_norm = torch.nn.GroupNorm(
        num_groups=num_groups,
        num_channels=in_channels,
        eps=1e-6,
        affine=True,
    )
    return group_norm
277
+
278
+
279
class ResnetBlock(nn.Module):
    """Residual block of two (GroupNorm -> swish -> Conv1d) stages over (B, C, L) features.

    Args:
        in_channels: Channels of the input.
        out_channels: Channels of the output; defaults to ``in_channels``.
        conv_shortcut: When the channel counts differ, use a kernel-3 convolution on
            the skip path instead of a 1x1 convolution.
        dropout: Dropout probability applied before the second convolution.
        temb_channels: Width of the optional ``temb`` conditioning vector. The
            projection layer is only created when this is greater than zero.
    """

    def __init__(
        self,
        *,
        in_channels,
        out_channels=None,
        conv_shortcut=False,
        dropout,
        temb_channels=512,
    ):
        super().__init__()
        self.in_channels = in_channels
        out_channels = in_channels if out_channels is None else out_channels
        self.out_channels = out_channels
        self.use_conv_shortcut = conv_shortcut

        self.norm1 = Normalize(in_channels)
        self.conv1 = torch.nn.Conv1d(
            in_channels, out_channels, kernel_size=3, stride=1, padding=1
        )
        if temb_channels > 0:
            self.temb_proj = torch.nn.Linear(temb_channels, out_channels)
        self.norm2 = Normalize(out_channels)
        self.dropout = torch.nn.Dropout(dropout)
        self.conv2 = torch.nn.Conv1d(
            out_channels, out_channels, kernel_size=3, stride=1, padding=1
        )
        # A skip-path projection is only needed when the channel count changes.
        if self.in_channels != self.out_channels:
            if self.use_conv_shortcut:
                self.conv_shortcut = torch.nn.Conv1d(
                    in_channels, out_channels, kernel_size=3, stride=1, padding=1
                )
            else:
                self.nin_shortcut = torch.nn.Conv1d(
                    in_channels, out_channels, kernel_size=1, stride=1, padding=0
                )

    def forward(self, x, temb=None):
        """
        Args:
            x: Features fed to ``Conv1d``, i.e. (B, in_channels, L).
            temb: Optional conditioning vector of shape (B, temb_channels). Requires
                the block to have been built with ``temb_channels > 0``.

        Returns:
            Tensor of shape (B, out_channels, L).
        """
        h = x
        h = self.norm1(h)
        h = nonlinearity(h)
        h = self.conv1(h)

        if temb is not None:
            # The features are 1-D conv maps (B, C, L), so the projected embedding
            # (B, C) needs exactly one trailing axis to broadcast over time. The
            # previous `[:, :, None, None]` was a 2-D leftover that mis-broadcast.
            h = h + self.temb_proj(nonlinearity(temb))[:, :, None]

        h = self.norm2(h)
        h = nonlinearity(h)
        h = self.dropout(h)
        h = self.conv2(h)

        if self.in_channels != self.out_channels:
            if self.use_conv_shortcut:
                x = self.conv_shortcut(x)
            else:
                x = self.nin_shortcut(x)

        return x + h
337
+
338
+
339
class Backbone(nn.Module):
    """Abstract parent of generator backbones, which keep the temporal resolution unchanged across layers."""

    def forward(self, x: torch.Tensor, **kwargs) -> torch.Tensor:
        """
        Interface definition only; concrete backbones override this method.

        Args:
            x (Tensor): Input of shape (B, C, L) - batch size, feature channels
                        and sequence length.

        Returns:
            Tensor: Output of shape (B, L, H) - batch size, sequence length and
                    model dimension.

        Raises:
            NotImplementedError: Always, in this base class.
        """
        raise NotImplementedError("Subclasses must implement the forward method.")
353
+
354
+
355
class VocosBackbone(Backbone):
    """
    Backbone made of a kernel-7 Conv1d embedding, two ResnetBlocks, a stack of
    rotary-embedding transformer blocks, two more ResnetBlocks and a final LayerNorm.

    Args:
        hidden_dim (int): Channel / model width used by every layer. Defaults to 1024.
        depth (int): Number of transformer blocks. Defaults to 12.
        heads (int): Attention heads per transformer block. Defaults to 16.
        pos_meb_dim (int): ``dim`` passed to ``RotaryPositionalEmbeddings``; the
            defaults make it equal to ``hidden_dim // heads``. Defaults to 64.
    """

    def __init__(self, hidden_dim=1024, depth=12, heads=16, pos_meb_dim=64):
        super().__init__()

        self.embed = nn.Conv1d(hidden_dim, hidden_dim, kernel_size=7, padding=3)

        # No conditioning embedding is used, so the ResnetBlocks get no temb projection.
        self.temb_ch = 0
        block_in = hidden_dim
        dropout = 0.1

        def resnet_pair() -> nn.Sequential:
            # Two channel-preserving residual blocks, shared recipe for prior/post nets.
            return nn.Sequential(
                *(
                    ResnetBlock(
                        in_channels=block_in,
                        out_channels=block_in,
                        temb_channels=self.temb_ch,
                        dropout=dropout,
                    )
                    for _ in range(2)
                )
            )

        # Construction order matches the attribute order of the original module
        # so state-dict keys and parameter initialisation order stay the same.
        self.prior_net = resnet_pair()

        # One rotary embedding instance is shared by every transformer block.
        time_rotary_embed = RotaryPositionalEmbeddings(dim=pos_meb_dim)
        self.transformers = nn.Sequential(
            *(
                TransformerBlock(
                    dim=hidden_dim, n_heads=heads, rotary_embed=time_rotary_embed
                )
                for _ in range(depth)
            )
        )
        self.final_layer_norm = nn.LayerNorm(hidden_dim, eps=1e-6)
        self.post_net = resnet_pair()

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Args:
            x (Tensor): Channel-last features; the first transpose hands them to
                        ``Conv1d`` as (B, hidden_dim, L), so the input is (B, L, hidden_dim).

        Returns:
            Tensor: Layer-normalised features, channel-last like the input.
        """
        # Convolutional stages run channel-first, the transformer runs channel-last.
        x = x.transpose(1, 2)
        x = self.embed(x)
        x = self.prior_net(x)
        x = x.transpose(1, 2)
        x = self.transformers(x)
        x = x.transpose(1, 2)
        x = self.post_net(x)
        x = x.transpose(1, 2)
        x = self.final_layer_norm(x)
        return x
433
+
434
+
435
def init_weights(m):
    """Initialise a ``Conv1d``: truncated-normal weights (std 0.02) and zero bias.

    Intended for ``nn.Module.apply``; modules that are not ``nn.Conv1d`` are left
    untouched.

    Args:
        m: Any module.
    """
    if isinstance(m, nn.Conv1d):
        nn.init.trunc_normal_(m.weight, std=0.02)
        # Convolutions built with bias=False have no bias tensor to reset.
        if m.bias is not None:
            nn.init.constant_(m.bias, 0)
439
+
440
class CodecDecoderVocos(nn.Module):
    """Decoder generator: a ResidualFSQ quantizer, a VocosBackbone and an ISTFT head.

    Args:
        hidden_dim: Width of the backbone and of the head input. Defaults to 1024.
        depth: Number of transformer blocks in the backbone. Defaults to 12.
        heads: Attention heads per transformer block. Defaults to 16.
        pos_meb_dim: Rotary embedding dimension passed to the backbone. Defaults to 64.
        hop_length: ISTFT hop length; the head uses ``n_fft = 4 * hop_length``.
        vq_dim: Feature dimension of the quantizer. Defaults to 2048.
        vq_num_quantizers, vq_commit_weight, vq_weight_init, vq_full_commit_loss,
        codebook_size, codebook_dim: Accepted for signature compatibility but not
            used; the quantizer is always built with eight 4-level FSQ dimensions
            and a single quantizer.
    """

    def __init__(
        self,
        hidden_dim=1024,
        depth=12,
        heads=16,
        pos_meb_dim=64,
        hop_length=320,
        vq_num_quantizers=1,
        vq_dim=2048,  # 1024 2048
        vq_commit_weight=0.25,
        vq_weight_init=False,
        vq_full_commit_loss=False,
        codebook_size=16384,
        codebook_dim=16,
    ):
        super().__init__()
        self.hop_length = hop_length

        self.quantizer = ResidualFSQ(
            dim=vq_dim, levels=[4, 4, 4, 4, 4, 4, 4, 4], num_quantizers=1
        )

        self.backbone = VocosBackbone(
            hidden_dim=hidden_dim, depth=depth, heads=heads, pos_meb_dim=pos_meb_dim
        )

        self.head = ISTFTHead(
            dim=hidden_dim,
            n_fft=self.hop_length * 4,
            hop_length=self.hop_length,
            padding="same",
        )

        self.reset_parameters()

    def _decode(self, x):
        """Run backbone then head; returns the head's ``(audio, x_pred)`` pair."""
        return self.head(self.backbone(x))

    def forward(self, x, vq=True):
        """
        Quantize ``x`` or decode it to audio, depending on ``vq``.

        Args:
            x: With ``vq=True`` the input is permuted (0, 2, 1) before the quantizer
               and the outputs are permuted back. Otherwise ``x`` goes to the backbone.
            vq: Only the literal ``True`` selects the quantizer path.

        Returns:
            ``(quantized, indices, None)`` when ``vq is True``; otherwise the head's
            ``(audio, x_pred)`` pair.
        """
        if vq is True:
            x = x.permute(0, 2, 1)
            x, q = self.quantizer(x)
            x = x.permute(0, 2, 1)
            q = q.permute(0, 2, 1)
            return x, q, None

        audio, x_pred = self._decode(x)
        return audio, x_pred

    def vq2emb(self, vq):
        """Put the quantizer in eval mode and return ``quantizer.vq2emb(vq)``."""
        # NOTE(review): relies on the quantizer exposing `vq2emb` - confirm that
        # ResidualFSQ provides it.
        self.quantizer = self.quantizer.eval()
        x = self.quantizer.vq2emb(vq)
        return x

    def get_emb(self):
        """Put the quantizer in eval mode and return ``quantizer.get_emb()``."""
        # NOTE(review): relies on the quantizer exposing `get_emb` - confirm that
        # ResidualFSQ provides it.
        self.quantizer = self.quantizer.eval()
        embs = self.quantizer.get_emb()
        return embs

    def inference_vq(self, vq):
        """Add a leading batch axis to ``vq`` and decode it to audio."""
        # Previously called the never-defined `self.model`; the backbone + head
        # path used by `forward(vq=False)` is the only decoder in this class.
        x = vq[None, :, :]
        audio, _ = self._decode(x)
        return audio

    def inference_0(self, x):
        """Quantize ``x`` and decode the quantized features; returns ``(audio, None)``."""
        # The quantizer returns two values (see `forward`), not four.
        # NOTE(review): this only works when the quantizer dim equals the backbone
        # width, which is not the case with the default arguments - confirm usage.
        x, _indices = self.quantizer(x)
        audio, _ = self._decode(x)
        return audio, None

    def inference(self, x):
        """Decode features to audio; returns ``(audio, None)``."""
        audio, _ = self._decode(x)
        return audio, None

    def remove_weight_norm(self):
        """Remove weight normalization module from all of the layers."""

        def _remove_weight_norm(m):
            try:
                torch.nn.utils.remove_weight_norm(m)
            except ValueError:  # this module didn't have weight norm
                return

        self.apply(_remove_weight_norm)

    def apply_weight_norm(self):
        """Apply weight normalization to every Conv1d and ConvTranspose1d layer."""

        def _apply_weight_norm(m):
            if isinstance(m, (nn.Conv1d, nn.ConvTranspose1d)):
                torch.nn.utils.weight_norm(m)

        self.apply(_apply_weight_norm)

    def reset_parameters(self):
        """Re-initialise every submodule through ``init_weights``."""
        self.apply(init_weights)
535
+
536
class NeuCodecDecoder(nn.Module, PyTorchModelHubMixin):
    """Hub-loadable wrapper that turns FSQ codes into audio through ``CodecDecoderVocos``."""

    def __init__(self, sample_rate: int, hop_length: int):
        """
        Args:
            sample_rate: Stored on the instance; not otherwise used in this class.
            hop_length: Stored on the instance and forwarded to the generator.
        """
        super().__init__()
        self.sample_rate = sample_rate
        self.hop_length = hop_length
        self.generator = CodecDecoderVocos(hop_length=hop_length)
        # Maps quantizer embeddings (2048) to the generator's default width (1024).
        self.fc_post_a = nn.Linear(2048, 1024)

    @property
    def device(self):
        """Device of the first parameter of this module."""
        first_param = next(self.parameters())
        return first_param.device

    def decode_code(self, fsq_codes: torch.Tensor) -> torch.Tensor:
        """
        Args:
            fsq_codes: torch.Tensor [B, 1, F], 50hz FSQ codes

        Returns:
            recon: torch.Tensor [B, 1, T], reconstructed 24kHz audio
        """
        # Indices are handed to the quantizer with the last two axes swapped.
        code_indices = fsq_codes.transpose(1, 2)
        latent = self.generator.quantizer.get_output_from_indices(code_indices)

        # The linear layer and the generator both consume channel-last features,
        # so no further transposes are needed between them.
        hidden = self.fc_post_a(latent)
        return self.generator(hidden, vq=False)[0]