WavTokenizer / modeling_wavtokenizer.py
klemenk's picture
Update modeling_wavtokenizer.py
0ef180e verified
"""
WavTokenizer model implementation for HuggingFace.
This implementation exactly matches the checkpoint structure for direct weight loading.
"""
import math
from typing import Optional, Tuple, Union
import torch
import torch.nn as nn
import torch.nn.functional as F
from transformers import PreTrainedModel
from transformers.modeling_outputs import BaseModelOutput
from .configuration_wavtokenizer import WavTokenizerConfig
# =============================================================================
# Audio Utilities
# =============================================================================
def convert_audio(wav, sr, target_sr, target_channels=1):
"""Convert audio to target sample rate and channels."""
if wav.dim() == 1:
wav = wav.unsqueeze(0).unsqueeze(0)
elif wav.dim() == 2:
wav = wav.unsqueeze(1)
if wav.shape[1] > target_channels:
wav = wav[:, :target_channels, :]
elif wav.shape[1] < target_channels:
wav = wav.repeat(1, target_channels, 1)
if sr != target_sr:
wav = F.interpolate(wav, size=int(wav.shape[-1] * target_sr / sr), mode='linear', align_corners=False)
return wav
# =============================================================================
# Weight-Normalized Conv1d (using parametrizations API to match checkpoint)
# =============================================================================
class WNConv1d(nn.Module):
"""Weight-normalized Conv1d using parametrizations API to match checkpoint structure."""
def __init__(self, in_channels, out_channels, kernel_size, stride=1, padding=0, dilation=1, groups=1, bias=True):
super().__init__()
conv = nn.Conv1d(in_channels, out_channels, kernel_size, stride, padding, dilation, groups, bias)
# Use parametrizations API (PyTorch 2.0+) to match checkpoint naming
self.conv = nn.utils.parametrizations.weight_norm(conv)
def forward(self, x):
return self.conv(x)
class WNConvTranspose1d(nn.Module):
"""Weight-normalized ConvTranspose1d using parametrizations API."""
def __init__(self, in_channels, out_channels, kernel_size, stride=1, padding=0, output_padding=0, groups=1, bias=True):
super().__init__()
convtr = nn.ConvTranspose1d(in_channels, out_channels, kernel_size, stride, padding, output_padding, groups, bias)
self.convtr = nn.utils.parametrizations.weight_norm(convtr)
def forward(self, x):
return self.convtr(x)
# =============================================================================
# Encoder (EnCodec-style, matching feature_extractor.encodec.encoder.model.*)
# =============================================================================
class _ConvWrapper(nn.Module):
"""Wrapper to match checkpoint structure: conv.conv.weight_g, conv.conv.weight_v, conv.conv.bias"""
def __init__(self, in_ch, out_ch, kernel_size, stride=1, padding=0):
super().__init__()
self.conv = WNConv1d(in_ch, out_ch, kernel_size, stride=stride, padding=padding)
def forward(self, x):
return self.conv(x)
class _ResBlockWrapper(nn.Module):
"""Wrapper to match checkpoint structure: block.1.conv.conv, block.3.conv.conv, shortcut.conv.conv"""
def __init__(self, dim):
super().__init__()
self.block = nn.Sequential()
self.block.add_module('0', nn.ELU())
self.block.add_module('1', _ConvWrapper(dim, dim // 2, 3, padding=1))
self.block.add_module('2', nn.ELU())
self.block.add_module('3', _ConvWrapper(dim // 2, dim, 1))
self.shortcut = _ConvWrapper(dim, dim, 1)
def forward(self, x):
return self.shortcut(x) + self.block(x)
class _LSTMWrapper(nn.Module):
"""LSTM wrapper matching checkpoint: lstm.weight_ih_l0, etc."""
def __init__(self, dim, num_layers=2):
super().__init__()
self.lstm = nn.LSTM(dim, dim, num_layers=num_layers, batch_first=True)
def forward(self, x):
x = x.transpose(1, 2)
y, _ = self.lstm(x)
y = y + x
return y.transpose(1, 2)
class EncoderModel(nn.Module):
"""
Encoder matching checkpoint: feature_extractor.encodec.encoder.model.*
Structure based on checkpoint:
- model.0: initial conv (1 -> 32)
- model.1: residual block (32)
- model.2: ELU (not saved)
- model.3: downsample conv (32->64, stride=2)
- model.4: residual block (64)
- model.5: ELU
- model.6: downsample conv (64->128, stride=4)
- model.7: residual block (128)
- model.8: ELU
- model.9: downsample conv (128->256, stride=5)
- model.10: residual block (256)
- model.11: ELU
- model.12: downsample conv (256->512, stride=8)
- model.13: LSTM
- model.14: ELU
- model.15: output conv (512->512)
"""
def __init__(self, channels=1, n_filters=32, dimension=512, ratios=[2, 4, 5, 8]):
super().__init__()
layers = []
# model.0: Initial conv
layers.append(_ConvWrapper(channels, n_filters, 7, padding=3))
# Encoder blocks with downsampling
in_ch = n_filters
for ratio in ratios:
out_ch = in_ch * 2
# Residual block
layers.append(_ResBlockWrapper(in_ch))
# ELU (implicit in original, but we need it)
layers.append(nn.ELU())
# Downsample conv
layers.append(_ConvWrapper(in_ch, out_ch, ratio * 2, stride=ratio, padding=ratio // 2))
in_ch = out_ch
# LSTM
layers.append(_LSTMWrapper(in_ch))
# ELU
layers.append(nn.ELU())
# Output conv
layers.append(_ConvWrapper(in_ch, dimension, 7, padding=3))
self.model = nn.Sequential(*layers)
def forward(self, x):
return self.model(x)
# =============================================================================
# Quantizer (matching feature_extractor.encodec.quantizer.vq.layers.0._codebook.*)
# =============================================================================
class Codebook(nn.Module):
"""Codebook matching checkpoint: _codebook.embed, _codebook.inited, _codebook.cluster_size, _codebook.embed_avg"""
def __init__(self, num_embeddings, embedding_dim):
super().__init__()
# These match checkpoint structure exactly
self.register_buffer('inited', torch.zeros(1))
self.register_buffer('cluster_size', torch.zeros(num_embeddings))
self.register_buffer('embed', torch.randn(num_embeddings, embedding_dim))
self.register_buffer('embed_avg', torch.randn(num_embeddings, embedding_dim))
def forward(self, x):
"""
Args:
x: (B, T, D) input
Returns:
quantized: (B, T, D) quantized output
indices: (B, T) codebook indices
"""
# L2 normalize
embed = F.normalize(self.embed, dim=-1)
x_norm = F.normalize(x, dim=-1)
# Find nearest
dist = torch.cdist(x_norm, embed)
indices = dist.argmin(dim=-1)
# Quantize
quantized = F.embedding(indices, embed)
# Straight-through
quantized = x_norm + (quantized - x_norm).detach()
return quantized, indices
def decode(self, indices):
embed = F.normalize(self.embed, dim=-1)
return F.embedding(indices, embed)
class VQLayer(nn.Module):
"""VQ layer matching checkpoint: vq.layers.0._codebook.*"""
def __init__(self, dim, codebook_size):
super().__init__()
self._codebook = Codebook(codebook_size, dim)
def forward(self, x):
# x: (B, D, T)
x = x.transpose(1, 2) # (B, T, D)
quantized, indices = self._codebook(x)
return quantized.transpose(1, 2), indices
def decode(self, indices):
quantized = self._codebook.decode(indices)
return quantized.transpose(1, 2)
class VQ(nn.Module):
"""VQ wrapper matching checkpoint: vq.layers"""
def __init__(self, dim, codebook_size, num_quantizers=1):
super().__init__()
self.layers = nn.ModuleList([
VQLayer(dim, codebook_size) for _ in range(num_quantizers)
])
def forward(self, x):
indices_list = []
quantized = torch.zeros_like(x)
residual = x
for layer in self.layers:
q, idx = layer(residual)
residual = residual - q
quantized = quantized + q
indices_list.append(idx)
indices = torch.stack(indices_list, dim=1)
return quantized, indices
def decode(self, indices):
quantized = None
for i, layer in enumerate(self.layers):
q = layer.decode(indices[:, i])
quantized = q if quantized is None else quantized + q
return quantized
class Quantizer(nn.Module):
"""Quantizer matching checkpoint: quantizer.vq"""
def __init__(self, dim, codebook_size, num_quantizers=1):
super().__init__()
self.vq = VQ(dim, codebook_size, num_quantizers)
def forward(self, x):
return self.vq(x)
def decode(self, indices):
return self.vq.decode(indices)
class EnCodecWrapper(nn.Module):
"""Wrapper matching checkpoint: encodec.encoder, encodec.quantizer"""
def __init__(self, channels=1, n_filters=32, dimension=512, ratios=[2, 4, 5, 8],
codebook_size=4096, num_quantizers=1):
super().__init__()
self.encoder = EncoderModel(channels, n_filters, dimension, ratios)
self.quantizer = Quantizer(dimension, codebook_size, num_quantizers)
# Note: decoder exists in checkpoint but we use Vocos backbone instead
def encode(self, x):
z = self.encoder(x)
z_q, codes = self.quantizer(z)
return z_q, codes
class FeatureExtractor(nn.Module):
"""Feature extractor matching checkpoint: feature_extractor.encodec"""
def __init__(self, **kwargs):
super().__init__()
self.encodec = EnCodecWrapper(**kwargs)
def encode(self, x):
return self.encodec.encode(x)
def decode_codes(self, codes):
return self.encodec.quantizer.decode(codes)
# =============================================================================
# Backbone (Vocos-style with bandwidth-conditioned AdaLayerNorm)
# =============================================================================
class AdaLayerNorm(nn.Module):
"""
Bandwidth-conditioned Adaptive LayerNorm.
Checkpoint structure:
- norm.scale.weight: [4, 768] (4 bandwidth conditions)
- norm.shift.weight: [4, 768]
"""
def __init__(self, dim, num_bandwidths=4, eps=1e-6):
super().__init__()
self.eps = eps
self.dim = dim
# Match checkpoint: scale.weight and shift.weight are [num_bandwidths, dim]
self.scale = nn.Embedding(num_bandwidths, dim)
self.shift = nn.Embedding(num_bandwidths, dim)
# Initialize
nn.init.ones_(self.scale.weight)
nn.init.zeros_(self.shift.weight)
def forward(self, x, bandwidth_id=None):
"""
Args:
x: (B, C, T) input
bandwidth_id: (B,) bandwidth index, or None for default (0)
"""
# Normalize
mean = x.mean(dim=1, keepdim=True)
var = x.var(dim=1, keepdim=True, unbiased=False)
x = (x - mean) / torch.sqrt(var + self.eps)
# Get scale/shift based on bandwidth_id
if bandwidth_id is None:
bandwidth_id = torch.zeros(x.shape[0], dtype=torch.long, device=x.device)
scale = self.scale(bandwidth_id) # (B, dim)
shift = self.shift(bandwidth_id) # (B, dim)
# Apply: (B, dim, 1) for broadcasting
x = x * scale.unsqueeze(-1) + shift.unsqueeze(-1)
return x
class ConvNeXtBlock(nn.Module):
"""
ConvNeXt block matching checkpoint structure exactly.
Checkpoint keys:
- dwconv.weight: [768, 1, 7]
- dwconv.bias: [768]
- norm.scale.weight: [4, 768]
- norm.shift.weight: [4, 768]
- pwconv1.weight: [2304, 768]
- pwconv1.bias: [2304]
- pwconv2.weight: [768, 2304]
- pwconv2.bias: [768]
- gamma: [768]
"""
def __init__(self, dim, intermediate_dim, kernel_size=7, layer_scale_init=1e-6, num_bandwidths=4):
super().__init__()
padding = (kernel_size - 1) // 2
self.dwconv = nn.Conv1d(dim, dim, kernel_size, padding=padding, groups=dim)
self.norm = AdaLayerNorm(dim, num_bandwidths)
self.pwconv1 = nn.Linear(dim, intermediate_dim)
self.pwconv2 = nn.Linear(intermediate_dim, dim)
self.gamma = nn.Parameter(layer_scale_init * torch.ones(dim))
def forward(self, x, bandwidth_id=None):
residual = x
x = self.dwconv(x)
x = self.norm(x, bandwidth_id)
x = x.transpose(1, 2) # (B, T, C)
x = self.pwconv1(x)
x = F.gelu(x)
x = self.pwconv2(x)
x = x.transpose(1, 2) # (B, C, T)
x = self.gamma.unsqueeze(0).unsqueeze(-1) * x
return residual + x
class Backbone(nn.Module):
"""
Vocos backbone matching checkpoint structure.
Checkpoint keys:
- embed.weight, embed.bias
- norm.scale.weight, norm.shift.weight
- convnext.0-11.*
- final_layer_norm.weight, final_layer_norm.bias
"""
def __init__(self, input_dim=512, dim=768, intermediate_dim=2304, num_blocks=12,
num_bandwidths=4):
super().__init__()
# Input projection: backbone.embed (kernel_size=7 to match checkpoint)
self.embed = nn.Conv1d(input_dim, dim, kernel_size=7, padding=3)
# Input normalization: backbone.norm
self.norm = AdaLayerNorm(dim, num_bandwidths)
# ConvNeXt blocks: backbone.convnext.0-11
self.convnext = nn.ModuleList([
ConvNeXtBlock(dim, intermediate_dim, num_bandwidths=num_bandwidths)
for _ in range(num_blocks)
])
# Final norm: backbone.final_layer_norm
self.final_layer_norm = nn.LayerNorm(dim)
def forward(self, x, bandwidth_id=None):
# Input projection
x = self.embed(x)
x = self.norm(x, bandwidth_id)
# ConvNeXt blocks
for block in self.convnext:
x = block(x, bandwidth_id)
# Final norm
x = x.transpose(1, 2) # (B, T, C)
x = self.final_layer_norm(x)
x = x.transpose(1, 2) # (B, C, T)
return x
# =============================================================================
# Head (iSTFT)
# =============================================================================
class ISTFT(nn.Module):
"""ISTFT module matching checkpoint: istft.window"""
def __init__(self, n_fft=1280):
super().__init__()
self.n_fft = n_fft
self.register_buffer('window', torch.hann_window(n_fft))
class ISTFTHead(nn.Module):
"""
iSTFT head matching checkpoint structure.
Checkpoint keys:
- out.weight: [1282, 768]
- out.bias: [1282]
- istft.window: [1280]
"""
def __init__(self, dim, n_fft=1280, hop_length=320, padding='center'):
super().__init__()
self.n_fft = n_fft
self.hop_length = hop_length
self.padding = padding
# Output projection: head.out
self.out = nn.Linear(dim, n_fft + 2)
# ISTFT window: head.istft.window
self.istft = ISTFT(n_fft)
def forward(self, x):
"""
Args:
x: (B, C, T) backbone output
Returns:
audio: (B, 1, samples)
"""
B, C, T = x.shape
x = x.transpose(1, 2) # (B, T, C)
x = self.out(x) # (B, T, n_fft + 2)
# Split magnitude and phase
n_bins = self.n_fft // 2 + 1 # 641
mag = torch.exp(x[:, :, :n_bins])
phase = x[:, :, n_bins:]
# Construct complex STFT
stft = torch.complex(mag * torch.cos(phase), mag * torch.sin(phase))
stft = stft.transpose(1, 2) # (B, n_bins, T)
# Inverse STFT
audio = torch.istft(
stft,
n_fft=self.n_fft,
hop_length=self.hop_length,
win_length=self.n_fft,
window=self.istft.window,
center=(self.padding == 'center'),
return_complex=False,
)
return audio.unsqueeze(1)
# =============================================================================
# Main WavTokenizer Model
# =============================================================================
class WavTokenizer(PreTrainedModel):
"""
WavTokenizer model for audio tokenization.
This implementation exactly matches the checkpoint structure for direct weight loading.
"""
config_class = WavTokenizerConfig
base_model_prefix = "wavtokenizer"
def __init__(self, config: WavTokenizerConfig):
super().__init__(config)
self.config = config
# Feature extractor (encoder + quantizer)
# Matches: feature_extractor.encodec.*
self.feature_extractor = FeatureExtractor(
channels=1,
n_filters=config.encoder_dim,
dimension=config.latent_dim,
ratios=config.encoder_rates,
codebook_size=config.codebook_size,
num_quantizers=config.num_quantizers,
)
# Backbone (Vocos-style decoder)
# Matches: backbone.*
self.backbone = Backbone(
input_dim=config.latent_dim,
dim=config.backbone_dim,
intermediate_dim=config.backbone_intermediate_dim,
num_blocks=config.backbone_num_blocks,
num_bandwidths=4,
)
# Head (iSTFT)
# Matches: head.*
self.head = ISTFTHead(
dim=config.backbone_dim,
n_fft=config.n_fft,
hop_length=config.hop_length,
padding=config.padding,
)
self.post_init()
def encode(self, audio, bandwidth_id=None):
"""
Encode audio to quantized features and codes.
Args:
audio: (B, 1, T) audio waveform
bandwidth_id: Optional (B,) bandwidth index
Returns:
features: (B, D, T') quantized features
codes: (B, num_quantizers, T') discrete codes
"""
return self.feature_extractor.encode(audio)
def encode_infer(self, audio, bandwidth_id=None):
"""
Encode audio for inference.
Args:
audio: (B, 1, T) audio waveform
bandwidth_id: Optional bandwidth index (scalar or tensor)
Returns:
features: (B, D, T') quantized features
codes: (B, T') discrete codes (squeezed for single quantizer)
"""
features, codes = self.encode(audio, bandwidth_id)
if codes.shape[1] == 1:
codes = codes.squeeze(1)
return features, codes
def decode(self, features, bandwidth_id=None):
"""
Decode features to audio.
Args:
features: (B, D, T') quantized features
bandwidth_id: Optional (B,) bandwidth index
Returns:
audio: (B, 1, T) reconstructed waveform
"""
x = self.backbone(features, bandwidth_id)
return self.head(x)
def codes_to_features(self, codes):
"""
Convert discrete codes back to continuous features.
Args:
codes: (B, T) or (B, num_quantizers, T) discrete codes
Returns:
features: (B, D, T) continuous features
"""
if codes.dim() == 2:
codes = codes.unsqueeze(1)
return self.feature_extractor.decode_codes(codes)
def forward(
self,
input_values: Optional[torch.Tensor] = None,
input_ids: Optional[torch.Tensor] = None,
bandwidth_id: Optional[torch.Tensor] = None,
**kwargs,
):
"""
HuggingFace-style forward pass.
Args:
input_values: (B, 1, T) or (B, T) audio waveform
input_ids: (B, T) or (B, num_quantizers, T) discrete codes
bandwidth_id: Optional (B,) bandwidth index
Returns:
BaseModelOutput with last_hidden_state (features) and hidden_states (codes, audio)
"""
if input_values is not None:
if input_values.dim() == 2:
input_values = input_values.unsqueeze(1)
features, codes = self.encode(input_values, bandwidth_id)
audio = self.decode(features, bandwidth_id)
return BaseModelOutput(
last_hidden_state=features,
hidden_states=(codes, audio),
)
elif input_ids is not None:
features = self.codes_to_features(input_ids)
audio = self.decode(features, bandwidth_id)
return BaseModelOutput(
last_hidden_state=features,
hidden_states=(input_ids, audio),
)
else:
raise ValueError("Either input_values or input_ids must be provided")