"""
WavTokenizer Model for HuggingFace Transformers

This module contains the complete implementation of WavTokenizer, an acoustic
discrete codec tokenizer for audio language modeling. All model components are
defined here, so no code from the original WavTokenizer repository is needed.

The architecture follows the original WavTokenizer implementation:
- Encoder: Strided convolutions for audio compression
- VQ: Vector quantization with single codebook
- Decoder: Vocos-style backbone with ConvNeXt blocks + iSTFT head

Reference: https://github.com/jishengpeng/WavTokenizer
Paper: "WavTokenizer: an Efficient Acoustic Discrete Codec Tokenizer for Audio
Language Modeling"
"""

import math
from typing import Dict, List, Optional, Sequence, Tuple, Union
from dataclasses import dataclass

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch import Tensor
from torch.nn.utils import weight_norm, remove_weight_norm

from transformers import PreTrainedModel
from transformers.tokenization_utils import BatchEncoding

from .configuration_wavtokenizer import WavTokenizerConfig


# ==============================================================================
# Utility Functions
# ==============================================================================


def convert_audio(wav: Tensor, sr: int, target_sr: int, target_channels: int) -> Tensor:
    """
    Convert audio to a target sample rate and number of channels.

    Args:
        wav: Input waveform [C, T] or [T]
        sr: Source sample rate
        target_sr: Target sample rate
        target_channels: Target number of channels (1 for mono, 2 for stereo)

    Returns:
        Converted waveform [target_channels, T']
    """
    import torchaudio

    # Ensure 2D: a 1-D input is treated as a single channel.
    if wav.dim() == 1:
        wav = wav.unsqueeze(0)

    channels = wav.size(0)
    if channels != target_channels:
        # Tensor.expand can only broadcast singleton dimensions, so any
        # multi-channel input that does not already match is first
        # down-mixed to mono, then replicated up to the requested count.
        if channels > 1:
            wav = wav.mean(dim=0, keepdim=True)
        if target_channels > 1:
            wav = wav.expand(target_channels, -1)

    # Resample if needed
    if sr != target_sr:
        wav = torchaudio.functional.resample(wav, sr, target_sr)

    return wav


# ==============================================================================
# Encoder Components (DAC-style)
# ==============================================================================


def WNConv1d(*args, **kwargs) -> nn.Conv1d:
    """Weight-normalized Conv1d (arguments are forwarded to ``nn.Conv1d``)."""
    return weight_norm(nn.Conv1d(*args, **kwargs))


def WNConvTranspose1d(*args, **kwargs) -> nn.ConvTranspose1d:
    """Weight-normalized ConvTranspose1d (arguments forwarded to ``nn.ConvTranspose1d``)."""
    return weight_norm(nn.ConvTranspose1d(*args, **kwargs))


class ResidualUnit(nn.Module):
    """Residual unit: ELU -> dilated conv (k=7) -> ELU -> 1x1 conv, plus skip."""

    def __init__(self, dim: int = 16, dilation: int = 1):
        super().__init__()
        # "Same" padding for a kernel of 7 with the given dilation, so the
        # residual sum below has matching lengths.
        pad = ((7 - 1) * dilation) // 2
        self.block = nn.Sequential(
            nn.ELU(),
            WNConv1d(dim, dim, kernel_size=7, dilation=dilation, padding=pad),
            nn.ELU(),
            WNConv1d(dim, dim, kernel_size=1),
        )

    def forward(self, x: Tensor) -> Tensor:
        return x + self.block(x)


class EncoderBlock(nn.Module):
    """Encoder block: three residual units (dilations 1, 3, 9) then a strided conv.

    The input has ``dim // 2`` channels; the output has ``dim`` channels and is
    downsampled in time by ``stride``.
    """

    def __init__(self, dim: int = 16, stride: int = 1):
        super().__init__()
        self.block = nn.Sequential(
            ResidualUnit(dim // 2, dilation=1),
            ResidualUnit(dim // 2, dilation=3),
            ResidualUnit(dim // 2, dilation=9),
            nn.ELU(),
            WNConv1d(
                dim // 2,
                dim,
                kernel_size=2 * stride,
                stride=stride,
                padding=math.ceil(stride / 2),
            ),
        )

    def forward(self, x: Tensor) -> Tensor:
        return self.block(x)


class Encoder(nn.Module):
    """
    DAC-style encoder that compresses a waveform to a latent representation.

    Uses strided convolutions for downsampling: each entry in ``strides`` adds
    one ``EncoderBlock`` that doubles the channel count and downsamples in time
    by that stride.

    Args:
        d_model: Channel count after the initial convolution.
        strides: Per-block downsampling strides.
        d_latent: Channel count of the output latent.
    """

    def __init__(
        self,
        d_model: int = 64,
        strides: Sequence[int] = (8, 5, 4, 2),
        d_latent: int = 512,
    ):
        super().__init__()
        # Initial conv lifts the single-channel waveform to d_model channels.
        layers: List[nn.Module] = [WNConv1d(1, d_model, kernel_size=7, padding=3)]

        # Encoder blocks with increasing channels
        for stride in strides:
            d_model *= 2
            layers.append(EncoderBlock(d_model, stride=stride))

        # Final projection
        layers += [
            nn.ELU(),
            WNConv1d(d_model, d_latent, kernel_size=3, padding=1),
        ]

        # Attribute name "block" is kept so checkpoint keys stay "encoder.block.*".
        self.block = nn.Sequential(*layers)
        # Channel count of the last encoder block (before the latent projection).
        self.enc_dim = d_model

    def forward(self, x: Tensor) -> Tensor:
        return self.block(x)


# ==============================================================================
# Vector Quantization
# ==============================================================================


class VectorQuantize(nn.Module):
    """
    Vector quantizer with an L2-normalized (cosine-similarity) codebook lookup.

    Inputs are projected from ``input_dim`` to ``codebook_dim`` (identity when
    they are equal), L2-normalized, and snapped to the nearest L2-normalized
    codebook entry. Gradients reach the input through a straight-through
    estimator.

    Args:
        input_dim: Channel count of the incoming features.
        codebook_size: Number of codebook entries.
        codebook_dim: Dimension of each codebook entry.
        commitment: Weight applied to the commitment loss.
    """

    def __init__(
        self,
        input_dim: int,
        codebook_size: int,
        codebook_dim: int,
        commitment: float = 0.25,
    ):
        super().__init__()
        self.input_dim = input_dim
        self.codebook_size = codebook_size
        self.codebook_dim = codebook_dim
        self.commitment = commitment

        # Projections (identity when no dimension change is needed)
        requires_projection = input_dim != codebook_dim
        self.project_in = nn.Linear(input_dim, codebook_dim) if requires_projection else nn.Identity()
        self.project_out = nn.Linear(codebook_dim, input_dim) if requires_projection else nn.Identity()

        # Codebook
        self.codebook = nn.Embedding(codebook_size, codebook_dim)
        nn.init.uniform_(self.codebook.weight, -1.0 / codebook_size, 1.0 / codebook_size)

    def forward(self, z: Tensor) -> Tuple[Tensor, Tensor, Tensor]:
        """
        Quantize a feature sequence.

        Args:
            z: Input [B, D, T]

        Returns:
            z_q: Quantized [B, D, T]
            commitment_loss: Loss scalar
            indices: Codes [B, T]
        """
        # [B, D, T] -> [B, T, D]
        z = z.transpose(1, 2)
        z_e = self.project_in(z)

        # L2 normalize both sides so the distance below ranks by cosine similarity.
        z_e_norm = F.normalize(z_e, dim=-1)
        codebook_norm = F.normalize(self.codebook.weight, dim=-1)

        # Squared Euclidean distance to every code: |a|^2 + |b|^2 - 2 a.b
        dist = (
            z_e_norm.pow(2).sum(-1, keepdim=True)
            + codebook_norm.pow(2).sum(-1)
            - 2 * torch.einsum('btd,kd->btk', z_e_norm, codebook_norm)
        )
        indices = dist.argmin(dim=-1)

        # Look up quantized values
        z_q = F.embedding(indices, codebook_norm)

        # NOTE(review): z_q is only used detached below (in the loss and in the
        # straight-through term), so the codebook receives no gradient from
        # this module. Confirm the codebook is meant to stay frozen here.
        commitment_loss = F.mse_loss(z_e_norm, z_q.detach()) * self.commitment

        # Straight-through estimator: forward uses z_q, backward uses z_e_norm.
        z_q = z_e_norm + (z_q - z_e_norm).detach()

        # Project out and transpose back
        z_q = self.project_out(z_q)
        z_q = z_q.transpose(1, 2)  # [B, D, T]

        return z_q, commitment_loss, indices

    def decode(self, indices: Tensor) -> Tensor:
        """Decode code indices [B, T] to projected vectors [B, D, T]."""
        codebook = F.normalize(self.codebook.weight, dim=-1)
        z_q = F.embedding(indices, codebook)
        z_q = self.project_out(z_q)
        return z_q.transpose(1, 2)


class ResidualVectorQuantize(nn.Module):
    """Residual VQ with multiple codebooks (typically 1 for WavTokenizer).

    Each quantizer encodes the residual left by the previous ones; the
    quantized outputs are summed.
    """

    def __init__(
        self,
        input_dim: int = 512,
        codebook_size: int = 4096,
        codebook_dim: int = 8,
        num_quantizers: int = 1,
        commitment: float = 0.25,
    ):
        super().__init__()
        self.num_quantizers = num_quantizers
        self.quantizers = nn.ModuleList([
            VectorQuantize(input_dim, codebook_size, codebook_dim, commitment)
            for _ in range(num_quantizers)
        ])

    def forward(
        self, z: Tensor, n_quantizers: Optional[int] = None
    ) -> Tuple[Tensor, Tensor, Tensor]:
        """
        Quantize ``z`` with the first ``n_quantizers`` quantizers.

        Args:
            z: Input features [B, D, T]
            n_quantizers: How many quantizers to use; ``None`` (or 0) uses all.

        Returns:
            z_q: Sum of quantized outputs [B, D, T]
            commitment_loss: Sum of per-quantizer commitment losses
            codes: Stacked indices [N_q, B, T]
        """
        n_q = n_quantizers or self.num_quantizers

        residual = z
        z_q = torch.zeros_like(z)
        all_indices = []
        all_losses = []

        for quantizer in self.quantizers[:n_q]:
            _z_q, loss, indices = quantizer(residual)
            residual = residual - _z_q
            z_q = z_q + _z_q
            all_indices.append(indices)
            all_losses.append(loss)

        codes = torch.stack(all_indices, dim=0)  # [N_q, B, T]
        commitment_loss = sum(all_losses)

        return z_q, commitment_loss, codes

    def decode(self, codes: Tensor) -> Tensor:
        """Decode codes [N_q, B, T] (or [B, T] for one quantizer) to [B, D, T]."""
        if codes.dim() == 2:
            codes = codes.unsqueeze(0)

        z_q = None
        for i, quantizer in enumerate(self.quantizers[:codes.size(0)]):
            _z_q = quantizer.decode(codes[i])
            z_q = _z_q if z_q is None else z_q + _z_q

        return z_q


# ==============================================================================
# Decoder Components (Vocos-style)
# ==============================================================================


class ConvNeXtBlock(nn.Module):
    """ConvNeXt block: depthwise conv, LayerNorm, pointwise MLP, optional layer scale, skip."""

    def __init__(
        self,
        dim: int,
        intermediate_dim: int,
        kernel_size: int = 7,
        layer_scale_init_value: float = 1e-6,
    ):
        super().__init__()
        padding = (kernel_size - 1) // 2
        self.dwconv = nn.Conv1d(dim, dim, kernel_size, padding=padding, groups=dim)
        self.norm = nn.LayerNorm(dim)
        self.pwconv1 = nn.Linear(dim, intermediate_dim)
        self.act = nn.GELU()
        self.pwconv2 = nn.Linear(intermediate_dim, dim)
        # Layer scale is disabled when layer_scale_init_value <= 0.
        self.gamma = nn.Parameter(
            layer_scale_init_value * torch.ones(dim)
        ) if layer_scale_init_value > 0 else None

    def forward(self, x: Tensor) -> Tensor:
        residual = x
        x = self.dwconv(x)
        x = x.transpose(1, 2)  # [B, T, D]
        x = self.norm(x)
        x = self.pwconv1(x)
        x = self.act(x)
        x = self.pwconv2(x)
        if self.gamma is not None:
            x = self.gamma * x
        x = x.transpose(1, 2)  # [B, D, T]
        return residual + x


class VocosBackbone(nn.Module):
    """Vocos-style backbone.

    Input conv + LayerNorm, then optional pre-norm residual self-attention
    layers, then a stack of ConvNeXt blocks and a final LayerNorm. Operates on
    [B, D, T] tensors.
    """

    def __init__(
        self,
        input_dim: int,
        dim: int,
        intermediate_dim: int,
        num_blocks: int,
        kernel_size: int = 7,
        layer_scale_init_value: float = 1e-6,
        use_attention: bool = True,
        num_heads: int = 8,
        num_attention_layers: int = 1,
    ):
        super().__init__()

        # Input projection
        self.input_conv = nn.Conv1d(input_dim, dim, kernel_size=7, padding=3)
        self.norm = nn.LayerNorm(dim)

        # Attention layers
        self.use_attention = use_attention
        if use_attention:
            self.attention = nn.ModuleList([
                nn.MultiheadAttention(dim, num_heads, batch_first=True)
                for _ in range(num_attention_layers)
            ])
            self.attn_norms = nn.ModuleList([
                nn.LayerNorm(dim) for _ in range(num_attention_layers)
            ])

        # ConvNeXt blocks
        self.convnext = nn.ModuleList([
            ConvNeXtBlock(dim, intermediate_dim, kernel_size, layer_scale_init_value)
            for _ in range(num_blocks)
        ])

        self.final_norm = nn.LayerNorm(dim)

    def forward(self, x: Tensor) -> Tensor:
        # Input projection
        x = self.input_conv(x)
        x = x.transpose(1, 2)  # [B, T, D]
        x = self.norm(x)
        x = x.transpose(1, 2)  # [B, D, T]

        # Attention (pre-norm, residual)
        if self.use_attention:
            for attn, norm in zip(self.attention, self.attn_norms):
                x_t = x.transpose(1, 2)  # [B, T, D]
                residual = x_t
                x_t = norm(x_t)
                x_t, _ = attn(x_t, x_t, x_t)
                x_t = residual + x_t
                x = x_t.transpose(1, 2)  # [B, D, T]

        # ConvNeXt blocks
        for block in self.convnext:
            x = block(x)

        # Final norm (LayerNorm works on the last axis, hence the transposes)
        x = x.transpose(1, 2)
        x = self.final_norm(x)
        x = x.transpose(1, 2)

        return x


class ISTFTHead(nn.Module):
    """Inverse-STFT head for waveform synthesis.

    A 1x1 conv projects features to ``n_fft // 2 + 1`` log-magnitude channels
    plus the same number of phase channels. The resulting complex spectrum is
    inverted with ``torch.istft`` using a Hann window.
    """

    def __init__(
        self,
        dim: int,
        n_fft: int,
        hop_length: int,
        padding: str = "center",
    ):
        super().__init__()
        self.n_fft = n_fft
        self.hop_length = hop_length
        # NOTE(review): `padding` is stored but forward() always uses center=True.
        self.padding = padding
        self.out_dim = n_fft // 2 + 1

        self.proj = nn.Conv1d(dim, self.out_dim * 2, kernel_size=1)

        # Register window buffer (non-persistent: not saved in the state dict)
        self.register_buffer(
            "window", torch.hann_window(n_fft), persistent=False
        )

    def forward(self, x: Tensor) -> Tensor:
        """
        Args:
            x: [B, D, T]

        Returns:
            wav: [B, 1, T']
        """
        x = self.proj(x)

        # Split mag/phase
        mag, phase = x.chunk(2, dim=1)

        # Magnitude is predicted in log space.
        mag = torch.exp(mag)
        # NOTE(review): the phase channels are squashed with sin() and then
        # scaled by pi, giving an angle in [-pi, pi]. Upstream Vocos uses
        # cos(p) / sin(p) directly on the raw channels; confirm this matches
        # how the loaded weights were trained.
        phase = torch.sin(phase)

        # Complex spectrum
        S = torch.complex(mag * torch.cos(phase * math.pi), mag * torch.sin(phase * math.pi))

        # Ensure window is on same device
        window = self.window.to(x.device)

        # iSTFT
        wav = torch.istft(
            S,
            n_fft=self.n_fft,
            hop_length=self.hop_length,
            window=window,
            center=True,
            normalized=False,
            onesided=True,
            return_complex=False,
        )

        return wav.unsqueeze(1)
# ==============================================================================
# Feature Extractor (Mel Spectrogram)
# ==============================================================================


class MelSpectrogramFeatures(nn.Module):
    """Extract log-mel spectrogram features from audio.

    Computes a centered STFT with a Hann window, takes the power spectrum,
    applies a Slaney-normalized mel filterbank, and returns the natural log
    (clamped at 1e-5).
    """

    def __init__(
        self,
        sample_rate: int = 24000,
        n_fft: int = 1024,
        hop_length: int = 256,
        n_mels: int = 100,
        f_min: float = 0.0,
        f_max: Optional[float] = None,
        padding: str = "center",
    ):
        super().__init__()
        self.sample_rate = sample_rate
        self.n_fft = n_fft
        self.hop_length = hop_length
        self.n_mels = n_mels
        # NOTE(review): `padding` is stored but forward() always uses center=True.
        self.padding = padding

        # Mel filterbank [n_freqs, n_mels]; f_max defaults to the Nyquist frequency.
        import torchaudio
        mel_fb = torchaudio.functional.melscale_fbanks(
            n_freqs=n_fft // 2 + 1,
            f_min=f_min,
            f_max=f_max or sample_rate // 2,
            n_mels=n_mels,
            sample_rate=sample_rate,
            norm="slaney",
            mel_scale="slaney",
        )
        self.register_buffer("mel_fb", mel_fb, persistent=False)
        self.register_buffer("window", torch.hann_window(n_fft), persistent=False)

    def forward(self, wav: Tensor) -> Tensor:
        """
        Args:
            wav: [B, 1, T] or [B, T]

        Returns:
            mel: [B, n_mels, T']
        """
        if wav.dim() == 3:
            wav = wav.squeeze(1)

        # STFT
        stft = torch.stft(
            wav,
            n_fft=self.n_fft,
            hop_length=self.hop_length,
            window=self.window.to(wav.device),
            center=True,
            return_complex=True,
        )

        # Power spectrum
        power = stft.abs().pow(2)

        # Mel spectrogram: [n_mels, n_freqs] @ [B, n_freqs, T'] -> [B, n_mels, T']
        mel = torch.matmul(self.mel_fb.T.to(power.device), power)

        # Log scale (clamped to avoid log(0))
        mel = torch.log(mel.clamp(min=1e-5))

        return mel


# ==============================================================================
# Main WavTokenizer Model
# ==============================================================================


class WavTokenizer(PreTrainedModel):
    """
    WavTokenizer: Efficient acoustic discrete codec tokenizer.

    Architecture:
        - Encoder: Strided convolutions for audio compression
        - VQ: Single-codebook vector quantization (``config.codebook_size``
          codes; 4096 by default when loading with ``from_pretrained0802``)
        - Decoder: Vocos backbone (ConvNeXt + attention) + iSTFT head

    Usage:
        ```python
        model = WavTokenizer.from_pretrained("TuKoResearch/WavTokenizerSmall", trust_remote_code=True)

        # Encode
        features, codes = model.encode_infer(wav, bandwidth_id=torch.tensor([0]))

        # Decode
        wav_out = model.decode(features, bandwidth_id=torch.tensor([0]))

        # Or use codes directly
        features = model.codes_to_features(codes)
        wav_out = model.decode(features, bandwidth_id=torch.tensor([0]))
        ```
    """

    config_class = WavTokenizerConfig

    def __init__(self, config: WavTokenizerConfig):
        super().__init__(config)

        self.sample_rate = config.sample_rate
        self.hop_length = config.hop_length

        # Encoder
        self.encoder = Encoder(
            d_model=config.encoder_dim,
            strides=config.encoder_rates,
            d_latent=config.latent_dim,
        )

        # Quantizer
        self.quantizer = ResidualVectorQuantize(
            input_dim=config.latent_dim,
            codebook_size=config.codebook_size,
            codebook_dim=config.codebook_dim,
            num_quantizers=config.num_quantizers,
        )

        # Feature projection for decoder
        self.feature_proj = nn.Conv1d(config.latent_dim, config.backbone_dim, 1)

        # Decoder backbone
        self.backbone = VocosBackbone(
            input_dim=config.backbone_dim,
            dim=config.backbone_dim,
            intermediate_dim=config.backbone_intermediate_dim,
            num_blocks=config.backbone_num_blocks,
            kernel_size=config.backbone_kernel_size,
            layer_scale_init_value=config.backbone_layer_scale_init_value,
            use_attention=config.use_attention,
            num_heads=config.attention_heads,
            num_attention_layers=config.attention_layers,
        )

        # iSTFT head
        self.head = ISTFTHead(
            dim=config.backbone_dim,
            n_fft=config.n_fft,
            hop_length=config.hop_length,
            padding=config.padding,
        )

        # Bandwidth embedding (4 entries); added to the decoder input when
        # decode() receives a bandwidth_id.
        self.bandwidth_emb = nn.Embedding(4, config.backbone_dim)

        self.post_init()

    @property
    def vocab_size(self) -> int:
        """Number of discrete codes (the codebook size)."""
        return self.config.codebook_size

    @property
    def frame_rate(self) -> float:
        """Token frames per second: sample_rate / hop_length."""
        return self.config.sample_rate / self.config.hop_length

    @staticmethod
    def _to_model_input(wav: Tensor) -> Tensor:
        """Reshape a waveform to the encoder's [B, 1, T] layout.

        [T] becomes [1, 1, T] and [B, T] becomes [B, 1, T]. Inputs that are
        already 3-D are returned unchanged.
        """
        if wav.dim() == 1:
            return wav[None, None, :]
        if wav.dim() == 2:
            return wav.unsqueeze(1)
        return wav

    def encode(
        self, wav: Tensor, bandwidth_id: Optional[Tensor] = None
    ) -> Tuple[Tensor, Tensor, Tensor]:
        """
        Encode a waveform to quantized features (gradients enabled).

        Args:
            wav: [B, 1, T], [B, T] or [T]
            bandwidth_id: Accepted for API compatibility; unused by the encoder.

        Returns:
            z_q: Quantized features [B, D, T']
            commitment_loss: VQ loss
            codes: Discrete codes [N_q, B, T']
        """
        z = self.encoder(self._to_model_input(wav))
        z_q, loss, codes = self.quantizer(z)
        return z_q, loss, codes

    @torch.no_grad()
    def encode_infer(
        self, wav: Tensor, bandwidth_id: Optional[Tensor] = None
    ) -> Tuple[Tensor, Tensor]:
        """
        Encode a waveform to features and codes (inference, no gradients).

        Args:
            wav: [B, 1, T], [B, T] or [T]
            bandwidth_id: Accepted for API compatibility; unused by the encoder.

        Returns:
            features: [B, D, T']
            codes: [B, T'] for a single quantizer, otherwise [N_q, B, T']
        """
        z = self.encoder(self._to_model_input(wav))
        z_q, _, codes = self.quantizer(z)

        # Squeeze for single quantizer
        if codes.size(0) == 1:
            codes = codes.squeeze(0)

        return z_q, codes

    def decode(
        self, features: Tensor, bandwidth_id: Optional[Tensor] = None
    ) -> Tensor:
        """
        Decode features to a waveform.

        Args:
            features: [B, D, T']
            bandwidth_id: Optional index (or indices) into ``bandwidth_emb``;
                its embedding is broadcast-added over time before the backbone.

        Returns:
            wav: [B, 1, T]
        """
        x = self.feature_proj(features)

        if bandwidth_id is not None:
            bw_emb = self.bandwidth_emb(bandwidth_id)
            x = x + bw_emb.unsqueeze(-1)

        x = self.backbone(x)
        wav = self.head(x)

        return wav

    @torch.no_grad()
    def codes_to_features(self, codes: Tensor) -> Tensor:
        """
        Convert discrete codes back to quantized features.

        Args:
            codes: [N_q, B, T'] or [B, T']

        Returns:
            features: [B, D, T']
        """
        return self.quantizer.decode(codes)

    def forward(
        self,
        wav: Optional[Tensor] = None,
        codes: Optional[Tensor] = None,
        bandwidth_id: Optional[Tensor] = None,
        **kwargs
    ) -> Union[BatchEncoding, Tensor]:
        """
        Encode or decode depending on which input is supplied.

        If ``wav`` is provided: encode and return a BatchEncoding with
        ``input_values`` (features) and ``input_ids`` (codes).
        Otherwise, if ``codes`` is provided: decode and return the waveform.

        Raises:
            ValueError: if neither ``wav`` nor ``codes`` is given.
        """
        if wav is not None:
            features, codes = self.encode_infer(wav, bandwidth_id)
            return BatchEncoding({
                "input_values": features,
                "input_ids": codes,
            })
        elif codes is not None:
            features = self.codes_to_features(codes)
            return self.decode(features, bandwidth_id)
        else:
            raise ValueError("Provide either 'wav' or 'codes'")

    @classmethod
    def from_pretrained0802(
        cls,
        config_path: str,
        checkpoint_path: str,
        device: str = "cpu",
    ) -> "WavTokenizer":
        """
        Load a model from an original WavTokenizer YAML config and checkpoint.

        Args:
            config_path: Path to YAML config
            checkpoint_path: Path to .ckpt file
            device: Device to load to

        Returns:
            Loaded model
        """
        import yaml

        # Load YAML config (safe_load: no arbitrary object construction).
        with open(config_path, 'r', encoding='utf-8') as f:
            yaml_cfg = yaml.safe_load(f) or {}

        model_args = (yaml_cfg.get('model') or {}).get('init_args') or {}

        def _init_args(section: str) -> dict:
            # Sub-module hyper-parameters live under model.init_args.<section>.init_args;
            # missing or empty sections fall back to the defaults below.
            return (model_args.get(section) or {}).get('init_args') or {}

        head_args = _init_args('head')
        backbone_args = _init_args('backbone')
        quantizer_args = _init_args('quantizer')

        # Create HF config. NOTE(review): sample_rate and the attention settings
        # are hard-coded rather than read from the YAML.
        config = WavTokenizerConfig(
            sample_rate=24000,
            n_fft=head_args.get('n_fft', 1280),
            hop_length=head_args.get('hop_length', 320),
            feature_dim=backbone_args.get('dim', 512),
            latent_dim=backbone_args.get('input_channels', 512),
            backbone_dim=backbone_args.get('dim', 512),
            backbone_intermediate_dim=backbone_args.get('intermediate_dim', 1536),
            backbone_num_blocks=backbone_args.get('num_layers', 8),
            codebook_size=quantizer_args.get('codebook_size', 4096),
            codebook_dim=quantizer_args.get('codebook_dim', 8),
            num_quantizers=quantizer_args.get('num_quantizers', 1),
            use_attention=True,
            attention_dim=backbone_args.get('dim', 512),
            attention_heads=8,
            attention_layers=1,
        )

        # Create model
        model = cls(config)

        # SECURITY: torch.load unpickles the file (unless weights_only is in
        # effect for the installed torch version); only load trusted checkpoints.
        ckpt = torch.load(checkpoint_path, map_location=device)
        state_dict = ckpt.get('state_dict', ckpt)

        # Strip a leading 'model.' prefix (e.g. from a Lightning wrapper).
        prefix = 'model.'
        new_state_dict = {
            (k[len(prefix):] if k.startswith(prefix) else k): v
            for k, v in state_dict.items()
        }

        # Non-strict load: mismatched keys are only reported as counts, so
        # verify these are zero when adding support for a new checkpoint.
        missing, unexpected = model.load_state_dict(new_state_dict, strict=False)
        if missing:
            print(f"Missing keys: {len(missing)}")
        if unexpected:
            print(f"Unexpected keys: {len(unexpected)}")

        return model.to(device)