klemenk committed on
Commit bc6461b · verified · 1 Parent(s): 230e45f

Update modeling_wavtokenizer.py

Files changed (1):
  1. modeling_wavtokenizer.py +451 -617
modeling_wavtokenizer.py CHANGED
@@ -1,611 +1,541 @@
  """
- WavTokenizer Model for HuggingFace Transformers

- This module contains the complete implementation of WavTokenizer,
- an acoustic discrete codec tokenizer for audio language modeling.
- All dependencies are included to avoid external imports.
-
- The architecture follows the original WavTokenizer implementation:
- - Encoder: Strided convolutions for audio compression
- - VQ: Vector quantization with single codebook
- - Decoder: Vocos-style backbone with ConvNeXt blocks + iSTFT head
-
- Reference: https://github.com/jishengpeng/WavTokenizer
- Paper: "WavTokenizer: an Efficient Acoustic Discrete Codec Tokenizer for Audio Language Modeling"
  """

  import math
- from typing import Dict, List, Optional, Tuple, Union
- from dataclasses import dataclass

  import torch
  import torch.nn as nn
  import torch.nn.functional as F
- from torch import Tensor
- from torch.nn.utils import weight_norm, remove_weight_norm
-
  from transformers import PreTrainedModel
- from transformers.tokenization_utils import BatchEncoding

  from .configuration_wavtokenizer import WavTokenizerConfig


- # ==============================================================================
- # Utility Functions
- # ==============================================================================

- def convert_audio(wav: Tensor, sr: int, target_sr: int, target_channels: int) -> Tensor:
-     """
-     Convert audio to target sample rate and number of channels.
-
-     Args:
-         wav: Input waveform [C, T] or [T]
-         sr: Source sample rate
-         target_sr: Target sample rate
-         target_channels: Target number of channels (1 for mono, 2 for stereo)
-
-     Returns:
-         Converted waveform [target_channels, T']
-     """
-     import torchaudio
-
-     # Ensure 2D
      if wav.dim() == 1:
-         wav = wav.unsqueeze(0)

-     # Convert channels
-     if wav.size(0) > target_channels:
-         wav = wav.mean(dim=0, keepdim=True)
-     elif wav.size(0) < target_channels:
-         wav = wav.expand(target_channels, -1)

-     # Resample if needed
      if sr != target_sr:
-         wav = torchaudio.functional.resample(wav, sr, target_sr)

      return wav


- # ==============================================================================
- # Encoder Components (DAC-style)
- # ==============================================================================

- def WNConv1d(*args, **kwargs):
-     """Weight-normalized Conv1d."""
-     return weight_norm(nn.Conv1d(*args, **kwargs))


- def WNConvTranspose1d(*args, **kwargs):
      """Weight-normalized ConvTranspose1d."""
-     return weight_norm(nn.ConvTranspose1d(*args, **kwargs))


- class ResidualUnit(nn.Module):
-     """Residual unit with dilated convolution."""
-
-     def __init__(self, dim: int = 16, dilation: int = 1):
          super().__init__()
-         pad = ((7 - 1) * dilation) // 2
-         self.block = nn.Sequential(
-             nn.ELU(),
-             WNConv1d(dim, dim, kernel_size=7, dilation=dilation, padding=pad),
-             nn.ELU(),
-             WNConv1d(dim, dim, kernel_size=1),
-         )

-     def forward(self, x: Tensor) -> Tensor:
-         return x + self.block(x)


- class EncoderBlock(nn.Module):
-     """Encoder block with residual units and downsampling."""

-     def __init__(self, dim: int = 16, stride: int = 1):
          super().__init__()
-         self.block = nn.Sequential(
-             ResidualUnit(dim // 2, dilation=1),
-             ResidualUnit(dim // 2, dilation=3),
-             ResidualUnit(dim // 2, dilation=9),
-             nn.ELU(),
-             WNConv1d(
-                 dim // 2, dim,
-                 kernel_size=2 * stride,
-                 stride=stride,
-                 padding=math.ceil(stride / 2),
-             ),
-         )

-     def forward(self, x: Tensor) -> Tensor:
-         return self.block(x)


- class Encoder(nn.Module):
      """
-     DAC-style encoder that compresses waveform to latent representation.
-     Uses strided convolutions for downsampling.
      """
-
-     def __init__(
-         self,
-         d_model: int = 64,
-         strides: List[int] = [8, 5, 4, 2],
-         d_latent: int = 512,
-     ):
          super().__init__()

-         # Initial conv
-         self.block = [WNConv1d(1, d_model, kernel_size=7, padding=3)]

-         # Encoder blocks with increasing channels
-         for stride in strides:
-             d_model *= 2
-             self.block.append(EncoderBlock(d_model, stride=stride))

-         # Final projection
-         self.block.extend([
-             nn.ELU(),
-             WNConv1d(d_model, d_latent, kernel_size=3, padding=1),
-         ])

-         self.block = nn.Sequential(*self.block)
-         self.enc_dim = d_model

-     def forward(self, x: Tensor) -> Tensor:
-         return self.block(x)


- # ==============================================================================
- # Vector Quantization
- # ==============================================================================

- class VectorQuantize(nn.Module):
-     """
-     Improved vector quantization with EMA codebook updates.
-
-     Uses L2-normalized codes for better stability.
-     """
-
-     def __init__(
-         self,
-         input_dim: int,
-         codebook_size: int,
-         codebook_dim: int,
-         commitment: float = 0.25,
-     ):
          super().__init__()
-
-         self.input_dim = input_dim
-         self.codebook_size = codebook_size
-         self.codebook_dim = codebook_dim
-         self.commitment = commitment
-
-         # Projections
-         requires_projection = input_dim != codebook_dim
-         self.project_in = nn.Linear(input_dim, codebook_dim) if requires_projection else nn.Identity()
-         self.project_out = nn.Linear(codebook_dim, input_dim) if requires_projection else nn.Identity()
-
-         # Codebook
-         self.codebook = nn.Embedding(codebook_size, codebook_dim)
-         nn.init.uniform_(self.codebook.weight, -1.0 / codebook_size, 1.0 / codebook_size)

-     def forward(self, z: Tensor) -> Tuple[Tensor, Tensor, Tensor]:
          """
-         Forward pass.
-
          Args:
-             z: Input [B, D, T]
-
          Returns:
-             z_q: Quantized [B, D, T]
-             commitment_loss: Loss scalar
-             indices: Codes [B, T]
          """
-         # [B, D, T] -> [B, T, D]
-         z = z.transpose(1, 2)
-         z_e = self.project_in(z)
-
          # L2 normalize
-         z_e_norm = F.normalize(z_e, dim=-1)
-         codebook_norm = F.normalize(self.codebook.weight, dim=-1)
-
-         # Find nearest codes
-         dist = (
-             z_e_norm.pow(2).sum(-1, keepdim=True)
-             + codebook_norm.pow(2).sum(-1)
-             - 2 * torch.einsum('btd,kd->btk', z_e_norm, codebook_norm)
-         )
-         indices = dist.argmin(dim=-1)

-         # Look up quantized values
-         z_q = F.embedding(indices, codebook_norm)

-         # Commitment loss
-         commitment_loss = F.mse_loss(z_e_norm, z_q.detach()) * self.commitment

          # Straight-through
-         z_q = z_e_norm + (z_q - z_e_norm).detach()

-         # Project out and transpose back
-         z_q = self.project_out(z_q)
-         z_q = z_q.transpose(1, 2)  # [B, D, T]
-
-         return z_q, commitment_loss, indices

-     def decode(self, indices: Tensor) -> Tensor:
-         """Decode indices to vectors."""
-         codebook = F.normalize(self.codebook.weight, dim=-1)
-         z_q = F.embedding(indices, codebook)
-         z_q = self.project_out(z_q)
-         return z_q.transpose(1, 2)


- class ResidualVectorQuantize(nn.Module):
-     """Residual VQ with multiple codebooks (typically 1 for WavTokenizer)."""

-     def __init__(
-         self,
-         input_dim: int = 512,
-         codebook_size: int = 4096,
-         codebook_dim: int = 8,
-         num_quantizers: int = 1,
-         commitment: float = 0.25,
-     ):
          super().__init__()
-
-         self.num_quantizers = num_quantizers
-         self.quantizers = nn.ModuleList([
-             VectorQuantize(input_dim, codebook_size, codebook_dim, commitment)
-             for _ in range(num_quantizers)
          ])

-     def forward(
-         self, z: Tensor, n_quantizers: int = None
-     ) -> Tuple[Tensor, Tensor, Tensor]:
-         n_q = n_quantizers or self.num_quantizers
-
-         residual = z
-         z_q = torch.zeros_like(z)
-         all_indices = []
-         all_losses = []

-         for i, quantizer in enumerate(self.quantizers[:n_q]):
-             _z_q, loss, indices = quantizer(residual)
-             residual = residual - _z_q
-             z_q = z_q + _z_q
-             all_indices.append(indices)
-             all_losses.append(loss)

-         codes = torch.stack(all_indices, dim=0)  # [N_q, B, T]
-         commitment_loss = sum(all_losses)

-         return z_q, commitment_loss, codes

-     def decode(self, codes: Tensor) -> Tensor:
-         """Decode codes to vectors."""
-         if codes.dim() == 2:
-             codes = codes.unsqueeze(0)

-         z_q = None
-         for i, quantizer in enumerate(self.quantizers[:codes.size(0)]):
-             _z_q = quantizer.decode(codes[i])
-             z_q = _z_q if z_q is None else z_q + _z_q

-         return z_q
-

- # ==============================================================================
- # Decoder Components (Vocos-style)
- # ==============================================================================

  class ConvNeXtBlock(nn.Module):
-     """ConvNeXt block with depthwise conv + pointwise expansion."""
-
-     def __init__(
-         self,
-         dim: int,
-         intermediate_dim: int,
-         kernel_size: int = 7,
-         layer_scale_init_value: float = 1e-6,
-     ):
          super().__init__()
-
          padding = (kernel_size - 1) // 2
          self.dwconv = nn.Conv1d(dim, dim, kernel_size, padding=padding, groups=dim)
-         self.norm = nn.LayerNorm(dim)
          self.pwconv1 = nn.Linear(dim, intermediate_dim)
-         self.act = nn.GELU()
          self.pwconv2 = nn.Linear(intermediate_dim, dim)
-
-         self.gamma = nn.Parameter(
-             layer_scale_init_value * torch.ones(dim)
-         ) if layer_scale_init_value > 0 else None

-     def forward(self, x: Tensor) -> Tensor:
          residual = x
          x = self.dwconv(x)
-         x = x.transpose(1, 2)  # [B, T, D]
-         x = self.norm(x)
          x = self.pwconv1(x)
-         x = self.act(x)
          x = self.pwconv2(x)
-         if self.gamma is not None:
-             x = self.gamma * x
-         x = x.transpose(1, 2)  # [B, D, T]
          return residual + x


- class VocosBackbone(nn.Module):
-     """Vocos backbone with attention and ConvNeXt blocks."""

-     def __init__(
-         self,
-         input_dim: int,
-         dim: int,
-         intermediate_dim: int,
-         num_blocks: int,
-         kernel_size: int = 7,
-         layer_scale_init_value: float = 1e-6,
-         use_attention: bool = True,
-         num_heads: int = 8,
-         num_attention_layers: int = 1,
-     ):
          super().__init__()

-         # Input projection
-         self.input_conv = nn.Conv1d(input_dim, dim, kernel_size=7, padding=3)
-         self.norm = nn.LayerNorm(dim)
-
-         # Attention layers
-         self.use_attention = use_attention
-         if use_attention:
-             self.attention = nn.ModuleList([
-                 nn.MultiheadAttention(dim, num_heads, batch_first=True)
-                 for _ in range(num_attention_layers)
-             ])
-             self.attn_norms = nn.ModuleList([
-                 nn.LayerNorm(dim) for _ in range(num_attention_layers)
-             ])

-         # ConvNeXt blocks
          self.convnext = nn.ModuleList([
-             ConvNeXtBlock(dim, intermediate_dim, kernel_size, layer_scale_init_value)
              for _ in range(num_blocks)
          ])

-         self.final_norm = nn.LayerNorm(dim)

-     def forward(self, x: Tensor) -> Tensor:
          # Input projection
-         x = self.input_conv(x)
-         x = x.transpose(1, 2)  # [B, T, D]
-         x = self.norm(x)
-         x = x.transpose(1, 2)  # [B, D, T]
-
-         # Attention
-         if self.use_attention:
-             for attn, norm in zip(self.attention, self.attn_norms):
-                 x_t = x.transpose(1, 2)  # [B, T, D]
-                 residual = x_t
-                 x_t = norm(x_t)
-                 x_t, _ = attn(x_t, x_t, x_t)
-                 x_t = residual + x_t
-                 x = x_t.transpose(1, 2)  # [B, D, T]

          # ConvNeXt blocks
          for block in self.convnext:
-             x = block(x)

          # Final norm
-         x = x.transpose(1, 2)
-         x = self.final_norm(x)
-         x = x.transpose(1, 2)

          return x


- class ISTFTHead(nn.Module):
-     """Inverse STFT head for waveform synthesis."""
-
-     def __init__(
-         self,
-         dim: int,
-         n_fft: int,
-         hop_length: int,
-         padding: str = "center",
-     ):
          super().__init__()
-
          self.n_fft = n_fft
-         self.hop_length = hop_length
-         self.padding = padding
-
-         self.out_dim = n_fft // 2 + 1
-         self.proj = nn.Conv1d(dim, self.out_dim * 2, kernel_size=1)
-
-         # Register window buffer
-         self.register_buffer(
-             "window",
-             torch.hann_window(n_fft),
-             persistent=False
-         )
-
-     def forward(self, x: Tensor) -> Tensor:
-         """
-         Args:
-             x: [B, D, T]
-         Returns:
-             wav: [B, 1, T']
-         """
-         x = self.proj(x)
-
-         # Split mag/phase
-         mag, phase = x.chunk(2, dim=1)
-
-         # Process
-         mag = torch.exp(mag)
-         phase = torch.sin(phase)
-
-         # Complex spectrum
-         S = torch.complex(mag * torch.cos(phase * math.pi), mag * torch.sin(phase * math.pi))
-
-         # Ensure window is on same device
-         window = self.window.to(x.device)
-
-         # iSTFT
-         wav = torch.istft(
-             S,
-             n_fft=self.n_fft,
-             hop_length=self.hop_length,
-             window=window,
-             center=True,
-             normalized=False,
-             onesided=True,
-             return_complex=False,
-         )
-
-         return wav.unsqueeze(1)


- # ==============================================================================
- # Feature Extractor (Mel Spectrogram)
- # ==============================================================================
-
- class MelSpectrogramFeatures(nn.Module):
-     """Extract mel spectrogram features from audio."""

-     def __init__(
-         self,
-         sample_rate: int = 24000,
-         n_fft: int = 1024,
-         hop_length: int = 256,
-         n_mels: int = 100,
-         f_min: float = 0.0,
-         f_max: float = None,
-         padding: str = "center",
-     ):
          super().__init__()
-
-         self.sample_rate = sample_rate
          self.n_fft = n_fft
          self.hop_length = hop_length
-         self.n_mels = n_mels
          self.padding = padding

-         # Mel filterbank
-         import torchaudio
-         mel_fb = torchaudio.functional.melscale_fbanks(
-             n_freqs=n_fft // 2 + 1,
-             f_min=f_min,
-             f_max=f_max or sample_rate // 2,
-             n_mels=n_mels,
-             sample_rate=sample_rate,
-             norm="slaney",
-             mel_scale="slaney",
-         )
-         self.register_buffer("mel_fb", mel_fb, persistent=False)
-         self.register_buffer("window", torch.hann_window(n_fft), persistent=False)

-     def forward(self, wav: Tensor) -> Tensor:
          """
          Args:
-             wav: [B, 1, T] or [B, T]
          Returns:
-             mel: [B, n_mels, T']
          """
-         if wav.dim() == 3:
-             wav = wav.squeeze(1)
-
-         # STFT
-         stft = torch.stft(
-             wav,
              n_fft=self.n_fft,
              hop_length=self.hop_length,
-             window=self.window.to(wav.device),
-             center=True,
-             return_complex=True,
          )

-         # Power spectrum
-         power = stft.abs().pow(2)
-
-         # Mel spectrogram
-         mel = torch.matmul(self.mel_fb.T.to(power.device), power)
-
-         # Log scale
-         mel = torch.log(mel.clamp(min=1e-5))
-
-         return mel


- # ==============================================================================
  # Main WavTokenizer Model
- # ==============================================================================

  class WavTokenizer(PreTrainedModel):
      """
-     WavTokenizer: Efficient acoustic discrete codec tokenizer.

-     Architecture:
-     - Encoder: Strided convolutions for audio compression
-     - VQ: Single-codebook vector quantization (4096 codes)
-     - Decoder: Vocos backbone (ConvNeXt + attention) + iSTFT head
-
-     Usage:
-     ```python
-     model = WavTokenizer.from_pretrained("TuKoResearch/WavTokenizerSmall", trust_remote_code=True)
-
-     # Encode
-     features, codes = model.encode_infer(wav, bandwidth_id=torch.tensor([0]))
-
-     # Decode
-     wav_out = model.decode(features, bandwidth_id=torch.tensor([0]))
-
-     # Or use codes directly
-     features = model.codes_to_features(codes)
-     wav_out = model.decode(features, bandwidth_id=torch.tensor([0]))
-     ```
      """

      config_class = WavTokenizerConfig

      def __init__(self, config: WavTokenizerConfig):
          super().__init__(config)
-
-         self.sample_rate = config.sample_rate
-         self.hop_length = config.hop_length
-
-         # Encoder
-         self.encoder = Encoder(
-             d_model=config.encoder_dim,
-             strides=config.encoder_rates,
-             d_latent=config.latent_dim,
-         )
-
-         # Quantizer
-         self.quantizer = ResidualVectorQuantize(
-             input_dim=config.latent_dim,
              codebook_size=config.codebook_size,
-             codebook_dim=config.codebook_dim,
              num_quantizers=config.num_quantizers,
          )

-         # Feature projection for decoder
-         self.feature_proj = nn.Conv1d(config.latent_dim, config.backbone_dim, 1)
-
-         # Decoder backbone
-         self.backbone = VocosBackbone(
-             input_dim=config.backbone_dim,
              dim=config.backbone_dim,
              intermediate_dim=config.backbone_intermediate_dim,
              num_blocks=config.backbone_num_blocks,
-             kernel_size=config.backbone_kernel_size,
-             layer_scale_init_value=config.backbone_layer_scale_init_value,
-             use_attention=config.use_attention,
-             num_heads=config.attention_heads,
-             num_attention_layers=config.attention_layers,
          )

-         # iSTFT head
          self.head = ISTFTHead(
              dim=config.backbone_dim,
              n_fft=config.n_fft,
@@ -613,201 +543,105 @@ class WavTokenizer(PreTrainedModel):
              padding=config.padding,
          )

-         # Bandwidth embedding
-         self.bandwidth_emb = nn.Embedding(4, config.backbone_dim)
-
          self.post_init()

-     @property
-     def vocab_size(self) -> int:
-         return self.config.codebook_size
-
-     @property
-     def frame_rate(self) -> float:
-         return self.config.sample_rate / self.config.hop_length
-
-     def encode(
-         self, wav: Tensor, bandwidth_id: Tensor = None
-     ) -> Tuple[Tensor, Tensor, Tensor]:
          """
-         Encode waveform to quantized features.

          Args:
-             wav: [B, 1, T] or [B, T]
-             bandwidth_id: Optional bandwidth ID

          Returns:
-             z_q: Quantized features [B, D, T']
-             commitment_loss: VQ loss
-             codes: Discrete codes [N_q, B, T']
          """
-         if wav.dim() == 2:
-             wav = wav.unsqueeze(1)
-
-         z = self.encoder(wav)
-         z_q, loss, codes = self.quantizer(z)
-
-         return z_q, loss, codes

-     @torch.no_grad()
-     def encode_infer(
-         self, wav: Tensor, bandwidth_id: Tensor = None
-     ) -> Tuple[Tensor, Tensor]:
          """
-         Encode waveform to features and codes (inference).

          Args:
-             wav: [B, 1, T] or [1, T] or [B, T]
-             bandwidth_id: Optional bandwidth ID

          Returns:
-             features: [B, D, T']
-             codes: [B, T'] (squeezed if single quantizer)
          """
-         if wav.dim() == 2:
-             if wav.size(0) == 1:
-                 wav = wav.unsqueeze(0)  # [1, T] -> [1, 1, T]
-             else:
-                 wav = wav.unsqueeze(1)  # [B, T] -> [B, 1, T]
-
-         z = self.encoder(wav)
-         z_q, _, codes = self.quantizer(z)
-
-         # Squeeze for single quantizer
-         if codes.size(0) == 1:
-             codes = codes.squeeze(0)
-
-         return z_q, codes

-     def decode(
-         self, features: Tensor, bandwidth_id: Tensor = None
-     ) -> Tensor:
          """
-         Decode features to waveform.

          Args:
-             features: [B, D, T']
-             bandwidth_id: Optional bandwidth ID

          Returns:
-             wav: [B, 1, T]
          """
-         x = self.feature_proj(features)
-
-         if bandwidth_id is not None:
-             bw_emb = self.bandwidth_emb(bandwidth_id)
-             x = x + bw_emb.unsqueeze(-1)
-
-         x = self.backbone(x)
-         wav = self.head(x)
-
-         return wav

-     @torch.no_grad()
-     def codes_to_features(self, codes: Tensor) -> Tensor:
          """
-         Convert codes to features.

          Args:
-             codes: [N_q, B, T'] or [B, T']

          Returns:
-             features: [B, D, T']
          """
-         return self.quantizer.decode(codes)

      def forward(
          self,
-         wav: Tensor = None,
-         codes: Tensor = None,
-         bandwidth_id: Tensor = None,
-         **kwargs
-     ) -> Union[BatchEncoding, Tensor]:
-         """
-         Forward pass.
-
-         If wav provided: encode to get tokens
-         If codes provided: decode to get wav
-         """
-         if wav is not None:
-             features, codes = self.encode_infer(wav, bandwidth_id)
-             return BatchEncoding({
-                 "input_values": features,
-                 "input_ids": codes,
-             })
-         elif codes is not None:
-             features = self.codes_to_features(codes)
-             return self.decode(features, bandwidth_id)
-         else:
-             raise ValueError("Provide either 'wav' or 'codes'")
-
-     @classmethod
-     def from_pretrained0802(
-         cls,
-         config_path: str,
-         checkpoint_path: str,
-         device: str = "cpu",
-     ) -> "WavTokenizer":
          """
-         Load from original WavTokenizer checkpoint.

          Args:
-             config_path: Path to YAML config
-             checkpoint_path: Path to .ckpt file
-             device: Device to load to

          Returns:
-             Loaded model
          """
-         import yaml
-
-         # Load YAML config
-         with open(config_path, 'r') as f:
-             yaml_cfg = yaml.safe_load(f)
-
-         # Extract config params
-         model_args = yaml_cfg.get('model', {}).get('init_args', {})
-
-         # Create HF config
-         config = WavTokenizerConfig(
-             sample_rate=24000,
-             n_fft=model_args.get('head', {}).get('init_args', {}).get('n_fft', 1280),
-             hop_length=model_args.get('head', {}).get('init_args', {}).get('hop_length', 320),
-             feature_dim=model_args.get('backbone', {}).get('init_args', {}).get('dim', 512),
-             latent_dim=model_args.get('backbone', {}).get('init_args', {}).get('input_channels', 512),
-             backbone_dim=model_args.get('backbone', {}).get('init_args', {}).get('dim', 512),
-             backbone_intermediate_dim=model_args.get('backbone', {}).get('init_args', {}).get('intermediate_dim', 1536),
-             backbone_num_blocks=model_args.get('backbone', {}).get('init_args', {}).get('num_layers', 8),
-             codebook_size=model_args.get('quantizer', {}).get('init_args', {}).get('codebook_size', 4096),
-             codebook_dim=model_args.get('quantizer', {}).get('init_args', {}).get('codebook_dim', 8),
-             num_quantizers=model_args.get('quantizer', {}).get('init_args', {}).get('num_quantizers', 1),
-             use_attention=True,
-             attention_dim=model_args.get('backbone', {}).get('init_args', {}).get('dim', 512),
-             attention_heads=8,
-             attention_layers=1,
-         )
-
-         # Create model
-         model = cls(config)
-
-         # Load checkpoint
-         ckpt = torch.load(checkpoint_path, map_location=device)
-         state_dict = ckpt.get('state_dict', ckpt)
-
-         # Clean state dict
-         new_state_dict = {}
-         for k, v in state_dict.items():
-             # Remove 'model.' prefix if present
-             if k.startswith('model.'):
-                 k = k[6:]
-             new_state_dict[k] = v
-
-         # Load (non-strict to handle mismatches)
-         missing, unexpected = model.load_state_dict(new_state_dict, strict=False)
-
-         if missing:
-             print(f"Missing keys: {len(missing)}")
-         if unexpected:
-             print(f"Unexpected keys: {len(unexpected)}")

-         return model.to(device)
 
 
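The rewritten file below renames every module so the `state_dict` keys line up with the original WavTokenizer checkpoint (`feature_extractor.encodec.encoder.model.*`, `quantizer.vq.layers.0._codebook.*`, `backbone.convnext.*`, `head.out`), making the non-strict key remapping of `from_pretrained0802` above unnecessary. As a reference, here is a minimal sketch of how one might verify that alignment before switching to strict loading; the `wavtokenizer.ckpt` filename and the bare default config are assumptions, not part of this commit:

```python
import torch

# Hypothetical checkpoint path; the commit itself does not pin one.
ckpt = torch.load("wavtokenizer.ckpt", map_location="cpu")
state_dict = ckpt.get("state_dict", ckpt)

# WavTokenizer / WavTokenizerConfig as defined in this file and its config module.
model = WavTokenizer(WavTokenizerConfig())

# Keys the module tree expects but the checkpoint lacks, and vice versa.
missing = set(model.state_dict()) - set(state_dict)
unexpected = set(state_dict) - set(model.state_dict())
print(f"missing: {len(missing)}, unexpected: {len(unexpected)}")

# With matching keys, strict loading becomes a meaningful integrity check.
model.load_state_dict(state_dict, strict=True)
```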
  """
+ WavTokenizer model implementation for HuggingFace.

+ This implementation exactly matches the checkpoint structure for direct weight loading.
  """

  import math
+ from typing import Optional, Tuple, Union

  import torch
  import torch.nn as nn
  import torch.nn.functional as F
  from transformers import PreTrainedModel
+ from transformers.modeling_outputs import BaseModelOutput

  from .configuration_wavtokenizer import WavTokenizerConfig


+ # =============================================================================
+ # Audio Utilities
+ # =============================================================================

+ def convert_audio(wav, sr, target_sr, target_channels=1):
+     """Convert audio to target sample rate and channels."""
      if wav.dim() == 1:
+         wav = wav.unsqueeze(0).unsqueeze(0)
+     elif wav.dim() == 2:
+         wav = wav.unsqueeze(1)

+     if wav.shape[1] > target_channels:
+         wav = wav[:, :target_channels, :]
+     elif wav.shape[1] < target_channels:
+         wav = wav.repeat(1, target_channels, 1)

      if sr != target_sr:
+         wav = F.interpolate(wav, size=int(wav.shape[-1] * target_sr / sr), mode='linear', align_corners=False)

      return wav


+ # =============================================================================
+ # Weight-Normalized Conv1d (matching checkpoint's weight_g/weight_v structure)
+ # =============================================================================

+ class WNConv1d(nn.Module):
+     """Weight-normalized Conv1d matching checkpoint structure with weight_g/weight_v."""
+     def __init__(self, in_channels, out_channels, kernel_size, stride=1, padding=0, dilation=1, groups=1, bias=True):
+         super().__init__()
+         self.conv = nn.utils.weight_norm(
+             nn.Conv1d(in_channels, out_channels, kernel_size, stride, padding, dilation, groups, bias)
+         )
+
+     def forward(self, x):
+         return self.conv(x)


+ class WNConvTranspose1d(nn.Module):
      """Weight-normalized ConvTranspose1d."""
+     def __init__(self, in_channels, out_channels, kernel_size, stride=1, padding=0, output_padding=0, groups=1, bias=True):
+         super().__init__()
+         self.convtr = nn.utils.weight_norm(
+             nn.ConvTranspose1d(in_channels, out_channels, kernel_size, stride, padding, output_padding, groups, bias)
+         )
+
+     def forward(self, x):
+         return self.convtr(x)


+ # =============================================================================
+ # Encoder (EnCodec-style, matching feature_extractor.encodec.encoder.model.*)
+ # =============================================================================
+
+ class _ConvWrapper(nn.Module):
+     """Wrapper to match checkpoint structure: conv.conv.weight_g, conv.conv.weight_v, conv.conv.bias"""
+     def __init__(self, in_ch, out_ch, kernel_size, stride=1, padding=0):
          super().__init__()
+         self.conv = WNConv1d(in_ch, out_ch, kernel_size, stride=stride, padding=padding)

+     def forward(self, x):
+         return self.conv(x)


+ class _ResBlockWrapper(nn.Module):
+     """Wrapper to match checkpoint structure: block.1.conv.conv, block.3.conv.conv, shortcut.conv.conv"""
+     def __init__(self, dim):
+         super().__init__()
+         self.block = nn.Sequential()
+         self.block.add_module('0', nn.ELU())
+         self.block.add_module('1', _ConvWrapper(dim, dim // 2, 3, padding=1))
+         self.block.add_module('2', nn.ELU())
+         self.block.add_module('3', _ConvWrapper(dim // 2, dim, 1))
+         self.shortcut = _ConvWrapper(dim, dim, 1)

+     def forward(self, x):
+         return self.shortcut(x) + self.block(x)
+
+
+ class _LSTMWrapper(nn.Module):
+     """LSTM wrapper matching checkpoint: lstm.weight_ih_l0, etc."""
+     def __init__(self, dim, num_layers=2):
          super().__init__()
+         self.lstm = nn.LSTM(dim, dim, num_layers=num_layers, batch_first=True)

+     def forward(self, x):
+         x = x.transpose(1, 2)
+         y, _ = self.lstm(x)
+         y = y + x
+         return y.transpose(1, 2)


+ class EncoderModel(nn.Module):
      """
+     Encoder matching checkpoint: feature_extractor.encodec.encoder.model.*
+
+     Structure based on checkpoint:
+     - model.0: initial conv (1 -> 32)
+     - model.1: residual block (32)
+     - model.2: ELU (not saved)
+     - model.3: downsample conv (32->64, stride=2)
+     - model.4: residual block (64)
+     - model.5: ELU
+     - model.6: downsample conv (64->128, stride=4)
+     - model.7: residual block (128)
+     - model.8: ELU
+     - model.9: downsample conv (128->256, stride=5)
+     - model.10: residual block (256)
+     - model.11: ELU
+     - model.12: downsample conv (256->512, stride=8)
+     - model.13: LSTM
+     - model.14: ELU
+     - model.15: output conv (512->512)
      """
+     def __init__(self, channels=1, n_filters=32, dimension=512, ratios=[2, 4, 5, 8]):
          super().__init__()

+         layers = []

+         # model.0: Initial conv
+         layers.append(_ConvWrapper(channels, n_filters, 7, padding=3))

+         # Encoder blocks with downsampling
+         in_ch = n_filters
+         for ratio in ratios:
+             out_ch = in_ch * 2
+             # Residual block
+             layers.append(_ResBlockWrapper(in_ch))
+             # ELU (implicit in original, but we need it)
+             layers.append(nn.ELU())
+             # Downsample conv
+             layers.append(_ConvWrapper(in_ch, out_ch, ratio * 2, stride=ratio, padding=ratio // 2))
+             in_ch = out_ch
+
+         # LSTM
+         layers.append(_LSTMWrapper(in_ch))
+
+         # ELU
+         layers.append(nn.ELU())
+
+         # Output conv
+         layers.append(_ConvWrapper(in_ch, dimension, 7, padding=3))

+         self.model = nn.Sequential(*layers)

+     def forward(self, x):
+         return self.model(x)


+ # =============================================================================
+ # Quantizer (matching feature_extractor.encodec.quantizer.vq.layers.0._codebook.*)
+ # =============================================================================

+ class Codebook(nn.Module):
+     """Codebook matching checkpoint: _codebook.embed, _codebook.inited, _codebook.cluster_size, _codebook.embed_avg"""
+     def __init__(self, num_embeddings, embedding_dim):
          super().__init__()
+         # These match checkpoint structure exactly
+         self.register_buffer('inited', torch.zeros(1))
+         self.register_buffer('cluster_size', torch.zeros(num_embeddings))
+         self.register_buffer('embed', torch.randn(num_embeddings, embedding_dim))
+         self.register_buffer('embed_avg', torch.randn(num_embeddings, embedding_dim))

+     def forward(self, x):
          """
          Args:
+             x: (B, T, D) input
          Returns:
+             quantized: (B, T, D) quantized output
+             indices: (B, T) codebook indices
          """
          # L2 normalize
+         embed = F.normalize(self.embed, dim=-1)
+         x_norm = F.normalize(x, dim=-1)

+         # Find nearest
+         dist = torch.cdist(x_norm, embed)
+         indices = dist.argmin(dim=-1)

+         # Quantize
+         quantized = F.embedding(indices, embed)

          # Straight-through
+         quantized = x_norm + (quantized - x_norm).detach()

+         return quantized, indices

+     def decode(self, indices):
+         embed = F.normalize(self.embed, dim=-1)
+         return F.embedding(indices, embed)


+ class VQLayer(nn.Module):
+     """VQ layer matching checkpoint: vq.layers.0._codebook.*"""
+     def __init__(self, dim, codebook_size):
+         super().__init__()
+         self._codebook = Codebook(codebook_size, dim)

+     def forward(self, x):
+         # x: (B, D, T)
+         x = x.transpose(1, 2)  # (B, T, D)
+         quantized, indices = self._codebook(x)
+         return quantized.transpose(1, 2), indices
+
+     def decode(self, indices):
+         quantized = self._codebook.decode(indices)
+         return quantized.transpose(1, 2)
+
+
+ class VQ(nn.Module):
+     """VQ wrapper matching checkpoint: vq.layers"""
+     def __init__(self, dim, codebook_size, num_quantizers=1):
          super().__init__()
+         self.layers = nn.ModuleList([
+             VQLayer(dim, codebook_size) for _ in range(num_quantizers)
          ])

+     def forward(self, x):
+         indices_list = []
+         quantized = torch.zeros_like(x)
+         residual = x

+         for layer in self.layers:
+             q, idx = layer(residual)
+             residual = residual - q
+             quantized = quantized + q
+             indices_list.append(idx)

+         indices = torch.stack(indices_list, dim=1)
+         return quantized, indices
+
+     def decode(self, indices):
+         quantized = None
+         for i, layer in enumerate(self.layers):
+             q = layer.decode(indices[:, i])
+             quantized = q if quantized is None else quantized + q
+         return quantized
+
+
+ class Quantizer(nn.Module):
+     """Quantizer matching checkpoint: quantizer.vq"""
+     def __init__(self, dim, codebook_size, num_quantizers=1):
+         super().__init__()
+         self.vq = VQ(dim, codebook_size, num_quantizers)
+
+     def forward(self, x):
+         return self.vq(x)
+
+     def decode(self, indices):
+         return self.vq.decode(indices)
+
+
+ class EnCodecWrapper(nn.Module):
+     """Wrapper matching checkpoint: encodec.encoder, encodec.quantizer"""
+     def __init__(self, channels=1, n_filters=32, dimension=512, ratios=[2, 4, 5, 8],
+                  codebook_size=4096, num_quantizers=1):
+         super().__init__()
+         self.encoder = EncoderModel(channels, n_filters, dimension, ratios)
+         self.quantizer = Quantizer(dimension, codebook_size, num_quantizers)
+         # Note: decoder exists in checkpoint but we use Vocos backbone instead
+
+     def encode(self, x):
+         z = self.encoder(x)
+         z_q, codes = self.quantizer(z)
+         return z_q, codes
+
+
+ class FeatureExtractor(nn.Module):
+     """Feature extractor matching checkpoint: feature_extractor.encodec"""
+     def __init__(self, **kwargs):
+         super().__init__()
+         self.encodec = EnCodecWrapper(**kwargs)
+
+     def encode(self, x):
+         return self.encodec.encode(x)
+
+     def decode_codes(self, codes):
+         return self.encodec.quantizer.decode(codes)
+
+
+ # =============================================================================
+ # Backbone (Vocos-style with bandwidth-conditioned AdaLayerNorm)
+ # =============================================================================
+
+ class AdaLayerNorm(nn.Module):
+     """
+     Bandwidth-conditioned Adaptive LayerNorm.
+
+     Checkpoint structure:
+     - norm.scale.weight: [4, 768] (4 bandwidth conditions)
+     - norm.shift.weight: [4, 768]
+     """
+     def __init__(self, dim, num_bandwidths=4, eps=1e-6):
+         super().__init__()
+         self.eps = eps
+         self.dim = dim
+         # Match checkpoint: scale.weight and shift.weight are [num_bandwidths, dim]
+         self.scale = nn.Embedding(num_bandwidths, dim)
+         self.shift = nn.Embedding(num_bandwidths, dim)

+         # Initialize
+         nn.init.ones_(self.scale.weight)
+         nn.init.zeros_(self.shift.weight)

+     def forward(self, x, bandwidth_id=None):
+         """
+         Args:
+             x: (B, C, T) input
+             bandwidth_id: (B,) bandwidth index, or None for default (0)
+         """
+         # Normalize
+         mean = x.mean(dim=1, keepdim=True)
+         var = x.var(dim=1, keepdim=True, unbiased=False)
+         x = (x - mean) / torch.sqrt(var + self.eps)

+         # Get scale/shift based on bandwidth_id
+         if bandwidth_id is None:
+             bandwidth_id = torch.zeros(x.shape[0], dtype=torch.long, device=x.device)

+         scale = self.scale(bandwidth_id)  # (B, dim)
+         shift = self.shift(bandwidth_id)  # (B, dim)
+
+         # Apply: (B, dim, 1) for broadcasting
+         x = x * scale.unsqueeze(-1) + shift.unsqueeze(-1)
+
+         return x


  class ConvNeXtBlock(nn.Module):
+     """
+     ConvNeXt block matching checkpoint structure exactly.
+
+     Checkpoint keys:
+     - dwconv.weight: [768, 1, 7]
+     - dwconv.bias: [768]
+     - norm.scale.weight: [4, 768]
+     - norm.shift.weight: [4, 768]
+     - pwconv1.weight: [2304, 768]
+     - pwconv1.bias: [2304]
+     - pwconv2.weight: [768, 2304]
+     - pwconv2.bias: [768]
+     - gamma: [768]
+     """
+     def __init__(self, dim, intermediate_dim, kernel_size=7, layer_scale_init=1e-6, num_bandwidths=4):
          super().__init__()
          padding = (kernel_size - 1) // 2
+
          self.dwconv = nn.Conv1d(dim, dim, kernel_size, padding=padding, groups=dim)
+         self.norm = AdaLayerNorm(dim, num_bandwidths)
          self.pwconv1 = nn.Linear(dim, intermediate_dim)
          self.pwconv2 = nn.Linear(intermediate_dim, dim)
+         self.gamma = nn.Parameter(layer_scale_init * torch.ones(dim))

+     def forward(self, x, bandwidth_id=None):
          residual = x
          x = self.dwconv(x)
+         x = self.norm(x, bandwidth_id)
+         x = x.transpose(1, 2)  # (B, T, C)
          x = self.pwconv1(x)
+         x = F.gelu(x)
          x = self.pwconv2(x)
+         x = x.transpose(1, 2)  # (B, C, T)
+         x = self.gamma.unsqueeze(0).unsqueeze(-1) * x
          return residual + x


+ class Backbone(nn.Module):
+     """
+     Vocos backbone matching checkpoint structure.

+     Checkpoint keys:
+     - embed.weight, embed.bias
+     - norm.scale.weight, norm.shift.weight
+     - convnext.0-11.*
+     - final_layer_norm.weight, final_layer_norm.bias
+     """
+     def __init__(self, input_dim=512, dim=768, intermediate_dim=2304, num_blocks=12,
+                  num_bandwidths=4):
          super().__init__()

+         # Input projection: backbone.embed
+         self.embed = nn.Conv1d(input_dim, dim, kernel_size=3, padding=1)

+         # Input normalization: backbone.norm
+         self.norm = AdaLayerNorm(dim, num_bandwidths)
+
+         # ConvNeXt blocks: backbone.convnext.0-11
          self.convnext = nn.ModuleList([
+             ConvNeXtBlock(dim, intermediate_dim, num_bandwidths=num_bandwidths)
              for _ in range(num_blocks)
          ])

+         # Final norm: backbone.final_layer_norm
+         self.final_layer_norm = nn.LayerNorm(dim)

+     def forward(self, x, bandwidth_id=None):
          # Input projection
+         x = self.embed(x)
+         x = self.norm(x, bandwidth_id)

          # ConvNeXt blocks
          for block in self.convnext:
+             x = block(x, bandwidth_id)

          # Final norm
+         x = x.transpose(1, 2)  # (B, T, C)
+         x = self.final_layer_norm(x)
+         x = x.transpose(1, 2)  # (B, C, T)

          return x


+ # =============================================================================
+ # Head (iSTFT)
+ # =============================================================================
+
+ class ISTFT(nn.Module):
+     """ISTFT module matching checkpoint: istft.window"""
+     def __init__(self, n_fft=1280):
          super().__init__()
          self.n_fft = n_fft
+         self.register_buffer('window', torch.hann_window(n_fft))


+ class ISTFTHead(nn.Module):
+     """
+     iSTFT head matching checkpoint structure.

+     Checkpoint keys:
+     - out.weight: [1282, 768]
+     - out.bias: [1282]
+     - istft.window: [1280]
+     """
+     def __init__(self, dim, n_fft=1280, hop_length=320, padding='center'):
          super().__init__()
          self.n_fft = n_fft
          self.hop_length = hop_length
          self.padding = padding

+         # Output projection: head.out
+         self.out = nn.Linear(dim, n_fft + 2)
+
+         # ISTFT window: head.istft.window
+         self.istft = ISTFT(n_fft)

+     def forward(self, x):
          """
          Args:
+             x: (B, C, T) backbone output
          Returns:
+             audio: (B, 1, samples)
          """
+         B, C, T = x.shape
+         x = x.transpose(1, 2)  # (B, T, C)
+         x = self.out(x)  # (B, T, n_fft + 2)
+
+         # Split magnitude and phase
+         n_bins = self.n_fft // 2 + 1  # 641
+         mag = torch.exp(x[:, :, :n_bins])
+         phase = x[:, :, n_bins:]
+
+         # Construct complex STFT
+         stft = torch.complex(mag * torch.cos(phase), mag * torch.sin(phase))
+         stft = stft.transpose(1, 2)  # (B, n_bins, T)
+
+         # Inverse STFT
+         audio = torch.istft(
+             stft,
              n_fft=self.n_fft,
              hop_length=self.hop_length,
+             win_length=self.n_fft,
+             window=self.istft.window,
+             center=(self.padding == 'center'),
+             return_complex=False,
          )

+         return audio.unsqueeze(1)


+ # =============================================================================
  # Main WavTokenizer Model
+ # =============================================================================

  class WavTokenizer(PreTrainedModel):
      """
+     WavTokenizer model for audio tokenization.

+     This implementation exactly matches the checkpoint structure for direct weight loading.
      """

      config_class = WavTokenizerConfig
+     base_model_prefix = "wavtokenizer"

      def __init__(self, config: WavTokenizerConfig):
          super().__init__(config)
+         self.config = config
+
+         # Feature extractor (encoder + quantizer)
+         # Matches: feature_extractor.encodec.*
+         self.feature_extractor = FeatureExtractor(
+             channels=1,
+             n_filters=config.encoder_dim,
+             dimension=config.latent_dim,
+             ratios=config.encoder_rates,
              codebook_size=config.codebook_size,
              num_quantizers=config.num_quantizers,
          )

+         # Backbone (Vocos-style decoder)
+         # Matches: backbone.*
+         self.backbone = Backbone(
+             input_dim=config.latent_dim,
              dim=config.backbone_dim,
              intermediate_dim=config.backbone_intermediate_dim,
              num_blocks=config.backbone_num_blocks,
+             num_bandwidths=4,
          )

+         # Head (iSTFT)
+         # Matches: head.*
          self.head = ISTFTHead(
              dim=config.backbone_dim,
              n_fft=config.n_fft,
              padding=config.padding,
          )

          self.post_init()

+     def encode(self, audio, bandwidth_id=None):
          """
+         Encode audio to quantized features and codes.

          Args:
+             audio: (B, 1, T) audio waveform
+             bandwidth_id: Optional (B,) bandwidth index

          Returns:
+             features: (B, D, T') quantized features
+             codes: (B, num_quantizers, T') discrete codes
          """
+         return self.feature_extractor.encode(audio)

+     def encode_infer(self, audio, bandwidth_id=None):
          """
+         Encode audio for inference.

          Args:
+             audio: (B, 1, T) audio waveform
+             bandwidth_id: Optional bandwidth index (scalar or tensor)

          Returns:
+             features: (B, D, T') quantized features
+             codes: (B, T') discrete codes (squeezed for single quantizer)
          """
+         features, codes = self.encode(audio, bandwidth_id)
+         if codes.shape[1] == 1:
+             codes = codes.squeeze(1)
+         return features, codes

+     def decode(self, features, bandwidth_id=None):
          """
+         Decode features to audio.

          Args:
+             features: (B, D, T') quantized features
+             bandwidth_id: Optional (B,) bandwidth index

          Returns:
+             audio: (B, 1, T) reconstructed waveform
          """
+         x = self.backbone(features, bandwidth_id)
+         return self.head(x)

+     def codes_to_features(self, codes):
          """
+         Convert discrete codes back to continuous features.

          Args:
+             codes: (B, T) or (B, num_quantizers, T) discrete codes

          Returns:
+             features: (B, D, T) continuous features
          """
+         if codes.dim() == 2:
+             codes = codes.unsqueeze(1)
+         return self.feature_extractor.decode_codes(codes)

      def forward(
          self,
+         input_values: Optional[torch.Tensor] = None,
+         input_ids: Optional[torch.Tensor] = None,
+         bandwidth_id: Optional[torch.Tensor] = None,
+         **kwargs,
+     ):
          """
+         HuggingFace-style forward pass.

          Args:
+             input_values: (B, 1, T) or (B, T) audio waveform
+             input_ids: (B, T) or (B, num_quantizers, T) discrete codes
+             bandwidth_id: Optional (B,) bandwidth index

          Returns:
+             BaseModelOutput with last_hidden_state (features) and hidden_states (codes, audio)
          """
+         if input_values is not None:
+             if input_values.dim() == 2:
+                 input_values = input_values.unsqueeze(1)
+
+             features, codes = self.encode(input_values, bandwidth_id)
+             audio = self.decode(features, bandwidth_id)
+
+             return BaseModelOutput(
+                 last_hidden_state=features,
+                 hidden_states=(codes, audio),
+             )
+
+         elif input_ids is not None:
+             features = self.codes_to_features(input_ids)
+             audio = self.decode(features, bandwidth_id)
+
+             return BaseModelOutput(
+                 last_hidden_state=features,
+                 hidden_states=(input_ids, audio),
+             )

+         else:
+             raise ValueError("Either input_values or input_ids must be provided")
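For reference, a short usage sketch of the rewritten API. The repo id is carried over from the previous version's docstring, and treating the input as one second of 24 kHz mono audio is an assumption based on the config defaults:

```python
import torch
from transformers import AutoModel

# Repo id taken from the old docstring; trust_remote_code loads this file.
model = AutoModel.from_pretrained("TuKoResearch/WavTokenizerSmall", trust_remote_code=True)
model.eval()

wav = torch.randn(1, 1, 24000)  # (B, 1, T), assumed 24 kHz mono

with torch.no_grad():
    # Encode to quantized features and single-codebook codes
    features, codes = model.encode_infer(wav)  # (B, D, T'), (B, T')

    # Reconstruct audio from codes alone, as a language model would emit them
    features = model.codes_to_features(codes)  # (B, D, T')
    audio = model.decode(features, bandwidth_id=torch.tensor([0]))  # (B, 1, T)

    # Or use the HF-style forward, which wraps the round trip in a BaseModelOutput
    out = model(input_values=wav)
    features, (codes, audio) = out.last_hidden_state, out.hidden_states
```

Unlike the old forward, which returned a `BatchEncoding` for audio input and a raw tensor for codes, both paths now return a `BaseModelOutput`, keeping the model usable with generic HuggingFace tooling.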