""" WavTokenizer model implementation for HuggingFace. This implementation exactly matches the checkpoint structure for direct weight loading. """ import math from typing import Optional, Tuple, Union import torch import torch.nn as nn import torch.nn.functional as F from transformers import PreTrainedModel from transformers.modeling_outputs import BaseModelOutput from .configuration_wavtokenizer import WavTokenizerConfig # ============================================================================= # Audio Utilities # ============================================================================= def convert_audio(wav, sr, target_sr, target_channels=1): """Convert audio to target sample rate and channels.""" if wav.dim() == 1: wav = wav.unsqueeze(0).unsqueeze(0) elif wav.dim() == 2: wav = wav.unsqueeze(1) if wav.shape[1] > target_channels: wav = wav[:, :target_channels, :] elif wav.shape[1] < target_channels: wav = wav.repeat(1, target_channels, 1) if sr != target_sr: wav = F.interpolate(wav, size=int(wav.shape[-1] * target_sr / sr), mode='linear', align_corners=False) return wav # ============================================================================= # Weight-Normalized Conv1d (matching checkpoint's weight_g/weight_v structure) # ============================================================================= class WNConv1d(nn.Module): """Weight-normalized Conv1d matching checkpoint structure with weight_g/weight_v.""" def __init__(self, in_channels, out_channels, kernel_size, stride=1, padding=0, dilation=1, groups=1, bias=True): super().__init__() self.conv = nn.utils.weight_norm( nn.Conv1d(in_channels, out_channels, kernel_size, stride, padding, dilation, groups, bias) ) def forward(self, x): return self.conv(x) class WNConvTranspose1d(nn.Module): """Weight-normalized ConvTranspose1d.""" def __init__(self, in_channels, out_channels, kernel_size, stride=1, padding=0, output_padding=0, groups=1, bias=True): super().__init__() self.convtr = nn.utils.weight_norm( nn.ConvTranspose1d(in_channels, out_channels, kernel_size, stride, padding, output_padding, groups, bias) ) def forward(self, x): return self.convtr(x) # ============================================================================= # Encoder (EnCodec-style, matching feature_extractor.encodec.encoder.model.*) # ============================================================================= class _ConvWrapper(nn.Module): """Wrapper to match checkpoint structure: conv.conv.weight_g, conv.conv.weight_v, conv.conv.bias""" def __init__(self, in_ch, out_ch, kernel_size, stride=1, padding=0): super().__init__() self.conv = WNConv1d(in_ch, out_ch, kernel_size, stride=stride, padding=padding) def forward(self, x): return self.conv(x) class _ResBlockWrapper(nn.Module): """Wrapper to match checkpoint structure: block.1.conv.conv, block.3.conv.conv, shortcut.conv.conv""" def __init__(self, dim): super().__init__() self.block = nn.Sequential() self.block.add_module('0', nn.ELU()) self.block.add_module('1', _ConvWrapper(dim, dim // 2, 3, padding=1)) self.block.add_module('2', nn.ELU()) self.block.add_module('3', _ConvWrapper(dim // 2, dim, 1)) self.shortcut = _ConvWrapper(dim, dim, 1) def forward(self, x): return self.shortcut(x) + self.block(x) class _LSTMWrapper(nn.Module): """LSTM wrapper matching checkpoint: lstm.weight_ih_l0, etc.""" def __init__(self, dim, num_layers=2): super().__init__() self.lstm = nn.LSTM(dim, dim, num_layers=num_layers, batch_first=True) def forward(self, x): x = x.transpose(1, 2) y, _ = self.lstm(x) y = y + x return y.transpose(1, 2) class EncoderModel(nn.Module): """ Encoder matching checkpoint: feature_extractor.encodec.encoder.model.* Structure based on checkpoint: - model.0: initial conv (1 -> 32) - model.1: residual block (32) - model.2: ELU (not saved) - model.3: downsample conv (32->64, stride=2) - model.4: residual block (64) - model.5: ELU - model.6: downsample conv (64->128, stride=4) - model.7: residual block (128) - model.8: ELU - model.9: downsample conv (128->256, stride=5) - model.10: residual block (256) - model.11: ELU - model.12: downsample conv (256->512, stride=8) - model.13: LSTM - model.14: ELU - model.15: output conv (512->512) """ def __init__(self, channels=1, n_filters=32, dimension=512, ratios=[2, 4, 5, 8]): super().__init__() layers = [] # model.0: Initial conv layers.append(_ConvWrapper(channels, n_filters, 7, padding=3)) # Encoder blocks with downsampling in_ch = n_filters for ratio in ratios: out_ch = in_ch * 2 # Residual block layers.append(_ResBlockWrapper(in_ch)) # ELU (implicit in original, but we need it) layers.append(nn.ELU()) # Downsample conv layers.append(_ConvWrapper(in_ch, out_ch, ratio * 2, stride=ratio, padding=ratio // 2)) in_ch = out_ch # LSTM layers.append(_LSTMWrapper(in_ch)) # ELU layers.append(nn.ELU()) # Output conv layers.append(_ConvWrapper(in_ch, dimension, 7, padding=3)) self.model = nn.Sequential(*layers) def forward(self, x): return self.model(x) # ============================================================================= # Quantizer (matching feature_extractor.encodec.quantizer.vq.layers.0._codebook.*) # ============================================================================= class Codebook(nn.Module): """Codebook matching checkpoint: _codebook.embed, _codebook.inited, _codebook.cluster_size, _codebook.embed_avg""" def __init__(self, num_embeddings, embedding_dim): super().__init__() # These match checkpoint structure exactly self.register_buffer('inited', torch.zeros(1)) self.register_buffer('cluster_size', torch.zeros(num_embeddings)) self.register_buffer('embed', torch.randn(num_embeddings, embedding_dim)) self.register_buffer('embed_avg', torch.randn(num_embeddings, embedding_dim)) def forward(self, x): """ Args: x: (B, T, D) input Returns: quantized: (B, T, D) quantized output indices: (B, T) codebook indices """ # L2 normalize embed = F.normalize(self.embed, dim=-1) x_norm = F.normalize(x, dim=-1) # Find nearest dist = torch.cdist(x_norm, embed) indices = dist.argmin(dim=-1) # Quantize quantized = F.embedding(indices, embed) # Straight-through quantized = x_norm + (quantized - x_norm).detach() return quantized, indices def decode(self, indices): embed = F.normalize(self.embed, dim=-1) return F.embedding(indices, embed) class VQLayer(nn.Module): """VQ layer matching checkpoint: vq.layers.0._codebook.*""" def __init__(self, dim, codebook_size): super().__init__() self._codebook = Codebook(codebook_size, dim) def forward(self, x): # x: (B, D, T) x = x.transpose(1, 2) # (B, T, D) quantized, indices = self._codebook(x) return quantized.transpose(1, 2), indices def decode(self, indices): quantized = self._codebook.decode(indices) return quantized.transpose(1, 2) class VQ(nn.Module): """VQ wrapper matching checkpoint: vq.layers""" def __init__(self, dim, codebook_size, num_quantizers=1): super().__init__() self.layers = nn.ModuleList([ VQLayer(dim, codebook_size) for _ in range(num_quantizers) ]) def forward(self, x): indices_list = [] quantized = torch.zeros_like(x) residual = x for layer in self.layers: q, idx = layer(residual) residual = residual - q quantized = quantized + q indices_list.append(idx) indices = torch.stack(indices_list, dim=1) return quantized, indices def decode(self, indices): quantized = None for i, layer in enumerate(self.layers): q = layer.decode(indices[:, i]) quantized = q if quantized is None else quantized + q return quantized class Quantizer(nn.Module): """Quantizer matching checkpoint: quantizer.vq""" def __init__(self, dim, codebook_size, num_quantizers=1): super().__init__() self.vq = VQ(dim, codebook_size, num_quantizers) def forward(self, x): return self.vq(x) def decode(self, indices): return self.vq.decode(indices) class EnCodecWrapper(nn.Module): """Wrapper matching checkpoint: encodec.encoder, encodec.quantizer""" def __init__(self, channels=1, n_filters=32, dimension=512, ratios=[2, 4, 5, 8], codebook_size=4096, num_quantizers=1): super().__init__() self.encoder = EncoderModel(channels, n_filters, dimension, ratios) self.quantizer = Quantizer(dimension, codebook_size, num_quantizers) # Note: decoder exists in checkpoint but we use Vocos backbone instead def encode(self, x): z = self.encoder(x) z_q, codes = self.quantizer(z) return z_q, codes class FeatureExtractor(nn.Module): """Feature extractor matching checkpoint: feature_extractor.encodec""" def __init__(self, **kwargs): super().__init__() self.encodec = EnCodecWrapper(**kwargs) def encode(self, x): return self.encodec.encode(x) def decode_codes(self, codes): return self.encodec.quantizer.decode(codes) # ============================================================================= # Backbone (Vocos-style with bandwidth-conditioned AdaLayerNorm) # ============================================================================= class AdaLayerNorm(nn.Module): """ Bandwidth-conditioned Adaptive LayerNorm. Checkpoint structure: - norm.scale.weight: [4, 768] (4 bandwidth conditions) - norm.shift.weight: [4, 768] """ def __init__(self, dim, num_bandwidths=4, eps=1e-6): super().__init__() self.eps = eps self.dim = dim # Match checkpoint: scale.weight and shift.weight are [num_bandwidths, dim] self.scale = nn.Embedding(num_bandwidths, dim) self.shift = nn.Embedding(num_bandwidths, dim) # Initialize nn.init.ones_(self.scale.weight) nn.init.zeros_(self.shift.weight) def forward(self, x, bandwidth_id=None): """ Args: x: (B, C, T) input bandwidth_id: (B,) bandwidth index, or None for default (0) """ # Normalize mean = x.mean(dim=1, keepdim=True) var = x.var(dim=1, keepdim=True, unbiased=False) x = (x - mean) / torch.sqrt(var + self.eps) # Get scale/shift based on bandwidth_id if bandwidth_id is None: bandwidth_id = torch.zeros(x.shape[0], dtype=torch.long, device=x.device) scale = self.scale(bandwidth_id) # (B, dim) shift = self.shift(bandwidth_id) # (B, dim) # Apply: (B, dim, 1) for broadcasting x = x * scale.unsqueeze(-1) + shift.unsqueeze(-1) return x class ConvNeXtBlock(nn.Module): """ ConvNeXt block matching checkpoint structure exactly. Checkpoint keys: - dwconv.weight: [768, 1, 7] - dwconv.bias: [768] - norm.scale.weight: [4, 768] - norm.shift.weight: [4, 768] - pwconv1.weight: [2304, 768] - pwconv1.bias: [2304] - pwconv2.weight: [768, 2304] - pwconv2.bias: [768] - gamma: [768] """ def __init__(self, dim, intermediate_dim, kernel_size=7, layer_scale_init=1e-6, num_bandwidths=4): super().__init__() padding = (kernel_size - 1) // 2 self.dwconv = nn.Conv1d(dim, dim, kernel_size, padding=padding, groups=dim) self.norm = AdaLayerNorm(dim, num_bandwidths) self.pwconv1 = nn.Linear(dim, intermediate_dim) self.pwconv2 = nn.Linear(intermediate_dim, dim) self.gamma = nn.Parameter(layer_scale_init * torch.ones(dim)) def forward(self, x, bandwidth_id=None): residual = x x = self.dwconv(x) x = self.norm(x, bandwidth_id) x = x.transpose(1, 2) # (B, T, C) x = self.pwconv1(x) x = F.gelu(x) x = self.pwconv2(x) x = x.transpose(1, 2) # (B, C, T) x = self.gamma.unsqueeze(0).unsqueeze(-1) * x return residual + x class Backbone(nn.Module): """ Vocos backbone matching checkpoint structure. Checkpoint keys: - embed.weight, embed.bias - norm.scale.weight, norm.shift.weight - convnext.0-11.* - final_layer_norm.weight, final_layer_norm.bias """ def __init__(self, input_dim=512, dim=768, intermediate_dim=2304, num_blocks=12, num_bandwidths=4): super().__init__() # Input projection: backbone.embed self.embed = nn.Conv1d(input_dim, dim, kernel_size=3, padding=1) # Input normalization: backbone.norm self.norm = AdaLayerNorm(dim, num_bandwidths) # ConvNeXt blocks: backbone.convnext.0-11 self.convnext = nn.ModuleList([ ConvNeXtBlock(dim, intermediate_dim, num_bandwidths=num_bandwidths) for _ in range(num_blocks) ]) # Final norm: backbone.final_layer_norm self.final_layer_norm = nn.LayerNorm(dim) def forward(self, x, bandwidth_id=None): # Input projection x = self.embed(x) x = self.norm(x, bandwidth_id) # ConvNeXt blocks for block in self.convnext: x = block(x, bandwidth_id) # Final norm x = x.transpose(1, 2) # (B, T, C) x = self.final_layer_norm(x) x = x.transpose(1, 2) # (B, C, T) return x # ============================================================================= # Head (iSTFT) # ============================================================================= class ISTFT(nn.Module): """ISTFT module matching checkpoint: istft.window""" def __init__(self, n_fft=1280): super().__init__() self.n_fft = n_fft self.register_buffer('window', torch.hann_window(n_fft)) class ISTFTHead(nn.Module): """ iSTFT head matching checkpoint structure. Checkpoint keys: - out.weight: [1282, 768] - out.bias: [1282] - istft.window: [1280] """ def __init__(self, dim, n_fft=1280, hop_length=320, padding='center'): super().__init__() self.n_fft = n_fft self.hop_length = hop_length self.padding = padding # Output projection: head.out self.out = nn.Linear(dim, n_fft + 2) # ISTFT window: head.istft.window self.istft = ISTFT(n_fft) def forward(self, x): """ Args: x: (B, C, T) backbone output Returns: audio: (B, 1, samples) """ B, C, T = x.shape x = x.transpose(1, 2) # (B, T, C) x = self.out(x) # (B, T, n_fft + 2) # Split magnitude and phase n_bins = self.n_fft // 2 + 1 # 641 mag = torch.exp(x[:, :, :n_bins]) phase = x[:, :, n_bins:] # Construct complex STFT stft = torch.complex(mag * torch.cos(phase), mag * torch.sin(phase)) stft = stft.transpose(1, 2) # (B, n_bins, T) # Inverse STFT audio = torch.istft( stft, n_fft=self.n_fft, hop_length=self.hop_length, win_length=self.n_fft, window=self.istft.window, center=(self.padding == 'center'), return_complex=False, ) return audio.unsqueeze(1) # ============================================================================= # Main WavTokenizer Model # ============================================================================= class WavTokenizer(PreTrainedModel): """ WavTokenizer model for audio tokenization. This implementation exactly matches the checkpoint structure for direct weight loading. """ config_class = WavTokenizerConfig base_model_prefix = "wavtokenizer" def __init__(self, config: WavTokenizerConfig): super().__init__(config) self.config = config # Feature extractor (encoder + quantizer) # Matches: feature_extractor.encodec.* self.feature_extractor = FeatureExtractor( channels=1, n_filters=config.encoder_dim, dimension=config.latent_dim, ratios=config.encoder_rates, codebook_size=config.codebook_size, num_quantizers=config.num_quantizers, ) # Backbone (Vocos-style decoder) # Matches: backbone.* self.backbone = Backbone( input_dim=config.latent_dim, dim=config.backbone_dim, intermediate_dim=config.backbone_intermediate_dim, num_blocks=config.backbone_num_blocks, num_bandwidths=4, ) # Head (iSTFT) # Matches: head.* self.head = ISTFTHead( dim=config.backbone_dim, n_fft=config.n_fft, hop_length=config.hop_length, padding=config.padding, ) self.post_init() def encode(self, audio, bandwidth_id=None): """ Encode audio to quantized features and codes. Args: audio: (B, 1, T) audio waveform bandwidth_id: Optional (B,) bandwidth index Returns: features: (B, D, T') quantized features codes: (B, num_quantizers, T') discrete codes """ return self.feature_extractor.encode(audio) def encode_infer(self, audio, bandwidth_id=None): """ Encode audio for inference. Args: audio: (B, 1, T) audio waveform bandwidth_id: Optional bandwidth index (scalar or tensor) Returns: features: (B, D, T') quantized features codes: (B, T') discrete codes (squeezed for single quantizer) """ features, codes = self.encode(audio, bandwidth_id) if codes.shape[1] == 1: codes = codes.squeeze(1) return features, codes def decode(self, features, bandwidth_id=None): """ Decode features to audio. Args: features: (B, D, T') quantized features bandwidth_id: Optional (B,) bandwidth index Returns: audio: (B, 1, T) reconstructed waveform """ x = self.backbone(features, bandwidth_id) return self.head(x) def codes_to_features(self, codes): """ Convert discrete codes back to continuous features. Args: codes: (B, T) or (B, num_quantizers, T) discrete codes Returns: features: (B, D, T) continuous features """ if codes.dim() == 2: codes = codes.unsqueeze(1) return self.feature_extractor.decode_codes(codes) def forward( self, input_values: Optional[torch.Tensor] = None, input_ids: Optional[torch.Tensor] = None, bandwidth_id: Optional[torch.Tensor] = None, **kwargs, ): """ HuggingFace-style forward pass. Args: input_values: (B, 1, T) or (B, T) audio waveform input_ids: (B, T) or (B, num_quantizers, T) discrete codes bandwidth_id: Optional (B,) bandwidth index Returns: BaseModelOutput with last_hidden_state (features) and hidden_states (codes, audio) """ if input_values is not None: if input_values.dim() == 2: input_values = input_values.unsqueeze(1) features, codes = self.encode(input_values, bandwidth_id) audio = self.decode(features, bandwidth_id) return BaseModelOutput( last_hidden_state=features, hidden_states=(codes, audio), ) elif input_ids is not None: features = self.codes_to_features(input_ids) audio = self.decode(features, bandwidth_id) return BaseModelOutput( last_hidden_state=features, hidden_states=(input_ids, audio), ) else: raise ValueError("Either input_values or input_ids must be provided")