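# NOTE: the lines below are Hugging Face Hub page residue, kept as comments so the module parses.
# WavTokenizer / modeling_wavtokenizer.py
# klemenk's picture
# Create modeling_wavtokenizer.py
# 1defa8d verified
# raw
# history blame
# 25.2 kB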
"""
WavTokenizer Model for HuggingFace Transformers
This module contains the complete implementation of WavTokenizer,
an acoustic discrete codec tokenizer for audio language modeling.
All dependencies are included to avoid external imports.
The architecture follows the original WavTokenizer implementation:
- Encoder: Strided convolutions for audio compression
- VQ: Vector quantization with single codebook
- Decoder: Vocos-style backbone with ConvNeXt blocks + iSTFT head
Reference: https://github.com/jishengpeng/WavTokenizer
Paper: "WavTokenizer: an Efficient Acoustic Discrete Codec Tokenizer for Audio Language Modeling"
"""
import math
from typing import Dict, List, Optional, Tuple, Union
from dataclasses import dataclass
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch import Tensor
from torch.nn.utils import weight_norm, remove_weight_norm
from transformers import PreTrainedModel
from transformers.tokenization_utils import BatchEncoding
from .configuration_wavtokenizer import WavTokenizerConfig
# ==============================================================================
# Utility Functions
# ==============================================================================
def convert_audio(wav: Tensor, sr: int, target_sr: int, target_channels: int) -> Tensor:
    """
    Convert audio to target sample rate and number of channels.

    Args:
        wav: Input waveform [C, T] or [T]
        sr: Source sample rate
        target_sr: Target sample rate
        target_channels: Target number of channels (1 for mono, 2 for stereo)

    Returns:
        Converted waveform [target_channels, T']
    """
    # Ensure 2D
    if wav.dim() == 1:
        wav = wav.unsqueeze(0)

    # Convert channels
    num_channels = wav.size(0)
    if num_channels > target_channels:
        # Downmix by averaging all channels ...
        wav = wav.mean(dim=0, keepdim=True)
        # ... then broadcast back up so the result really has `target_channels`
        # rows (previously e.g. 6 -> 2 channels silently returned mono).
        if target_channels > 1:
            wav = wav.expand(target_channels, -1)
    elif num_channels < target_channels:
        # expand() only broadcasts a singleton channel dimension (mono -> N).
        wav = wav.expand(target_channels, -1)

    # Resample if needed
    if sr != target_sr:
        # Imported lazily so that torchaudio is only required when resampling.
        import torchaudio

        wav = torchaudio.functional.resample(wav, sr, target_sr)
    return wav
# ==============================================================================
# Encoder Components (DAC-style)
# ==============================================================================
def WNConv1d(*args, **kwargs):
    """Build an ``nn.Conv1d`` from the given arguments and wrap it in weight normalization."""
    conv = nn.Conv1d(*args, **kwargs)
    return weight_norm(conv)
def WNConvTranspose1d(*args, **kwargs):
    """Build an ``nn.ConvTranspose1d`` from the given arguments and wrap it in weight normalization."""
    deconv = nn.ConvTranspose1d(*args, **kwargs)
    return weight_norm(deconv)
class ResidualUnit(nn.Module):
    """Residual unit: ELU -> dilated 7-tap conv -> ELU -> 1x1 conv, added back to the input."""

    def __init__(self, dim: int = 16, dilation: int = 1):
        super().__init__()
        kernel_size = 7
        # Padding that keeps the time length unchanged for the dilated kernel,
        # so the branch output can be summed with the input.
        same_pad = (kernel_size - 1) * dilation // 2
        layers = [
            nn.ELU(),
            WNConv1d(dim, dim, kernel_size=kernel_size, dilation=dilation, padding=same_pad),
            nn.ELU(),
            WNConv1d(dim, dim, kernel_size=1),
        ]
        self.block = nn.Sequential(*layers)

    def forward(self, x: Tensor) -> Tensor:
        """Return ``x`` plus the output of the convolutional branch."""
        branch = self.block(x)
        return x + branch
class EncoderBlock(nn.Module):
    """Three residual units (dilations 1, 3, 9) followed by a strided downsampling conv."""

    def __init__(self, dim: int = 16, stride: int = 1):
        super().__init__()
        # The residual stack runs at half the block's output width; the strided
        # conv then doubles the channel count while downsampling in time.
        half_dim = dim // 2
        residual_stack = [ResidualUnit(half_dim, dilation=d) for d in (1, 3, 9)]
        activation = nn.ELU()
        downsample = WNConv1d(
            half_dim,
            dim,
            kernel_size=2 * stride,
            stride=stride,
            padding=math.ceil(stride / 2),
        )
        self.block = nn.Sequential(*residual_stack, activation, downsample)

    def forward(self, x: Tensor) -> Tensor:
        """Apply the residual stack and the downsampling convolution to ``x``."""
        out = self.block(x)
        return out
class Encoder(nn.Module):
    """
    DAC-style encoder that compresses waveform to latent representation.
    Uses strided convolutions for downsampling.

    Args:
        d_model: Channel count produced by the initial convolution; it is
            doubled before each encoder block.
        strides: Downsampling stride of each encoder block. Defaults to
            ``[8, 5, 4, 2]`` when omitted.
        d_latent: Channel count of the final projection.

    Attributes:
        block: The full convolutional stack as an ``nn.Sequential``.
        enc_dim: Channel count after the last encoder block (i.e. the input
            width of the final projection).
    """

    def __init__(
        self,
        d_model: int = 64,
        strides: Optional[List[int]] = None,
        d_latent: int = 512,
    ):
        super().__init__()
        # Avoid a shared mutable default argument; None selects the stock strides.
        if strides is None:
            strides = [8, 5, 4, 2]

        # Initial conv (expects a single input channel)
        layers = [WNConv1d(1, d_model, kernel_size=7, padding=3)]

        # Encoder blocks with increasing channels
        for stride in strides:
            d_model *= 2
            layers.append(EncoderBlock(d_model, stride=stride))

        # Final projection
        layers.extend([
            nn.ELU(),
            WNConv1d(d_model, d_latent, kernel_size=3, padding=1),
        ])

        self.block = nn.Sequential(*layers)
        self.enc_dim = d_model

    def forward(self, x: Tensor) -> Tensor:
        """Run the waveform through the convolutional stack and return the latents."""
        return self.block(x)
# ==============================================================================
# Vector Quantization
# ==============================================================================
class VectorQuantize(nn.Module):
    """
    Single-codebook vector quantizer with nearest-neighbour lookup on
    L2-normalized vectors.

    Inputs are optionally projected to the codebook dimension, normalized,
    matched against the normalized codebook, and projected back.
    """

    def __init__(
        self,
        input_dim: int,
        codebook_size: int,
        codebook_dim: int,
        commitment: float = 0.25,
    ):
        super().__init__()
        self.input_dim = input_dim
        self.codebook_size = codebook_size
        self.codebook_dim = codebook_dim
        self.commitment = commitment

        # Linear projections are only needed when the two widths differ.
        if input_dim != codebook_dim:
            self.project_in = nn.Linear(input_dim, codebook_dim)
            self.project_out = nn.Linear(codebook_dim, input_dim)
        else:
            self.project_in = nn.Identity()
            self.project_out = nn.Identity()

        # Codebook, initialized uniformly in [-1/K, 1/K]
        self.codebook = nn.Embedding(codebook_size, codebook_dim)
        bound = 1.0 / codebook_size
        nn.init.uniform_(self.codebook.weight, -bound, bound)

    def forward(self, z: Tensor) -> Tuple[Tensor, Tensor, Tensor]:
        """
        Quantize a batch of latents.

        Args:
            z: Input [B, D, T]

        Returns:
            z_q: Quantized [B, D, T]
            commitment_loss: Loss scalar
            indices: Codes [B, T]
        """
        # Work in [B, T, D] so the feature axis is last.
        latents = self.project_in(z.transpose(1, 2))

        # Unit-length latents and codes
        z_e_norm = F.normalize(latents, dim=-1)
        codebook_norm = F.normalize(self.codebook.weight, dim=-1)

        # Squared Euclidean distance, expanded as |a|^2 + |b|^2 - 2 a.b
        latent_sq = z_e_norm.pow(2).sum(-1, keepdim=True)
        code_sq = codebook_norm.pow(2).sum(-1)
        cross = torch.einsum('btd,kd->btk', z_e_norm, codebook_norm)
        dist = latent_sq + code_sq - 2 * cross
        indices = dist.argmin(dim=-1)

        # Selected (normalized) code vectors
        nearest = F.embedding(indices, codebook_norm)

        # Commitment term: pulls the latents toward the (detached) codes
        commitment_loss = F.mse_loss(z_e_norm, nearest.detach()) * self.commitment

        # Straight-through estimator: forward value is the code, gradient
        # flows to the normalized latents.
        z_q = z_e_norm + (nearest - z_e_norm).detach()

        # Back to the input width and to [B, D, T]
        z_q = self.project_out(z_q).transpose(1, 2)
        return z_q, commitment_loss, indices

    def decode(self, indices: Tensor) -> Tensor:
        """Look up normalized code vectors for ``indices`` and return them as [B, D, T]."""
        normalized_codes = F.normalize(self.codebook.weight, dim=-1)
        vectors = F.embedding(indices, normalized_codes)
        return self.project_out(vectors).transpose(1, 2)
class ResidualVectorQuantize(nn.Module):
    """Stack of vector quantizers applied to successive residuals (typically 1 for WavTokenizer)."""

    def __init__(
        self,
        input_dim: int = 512,
        codebook_size: int = 4096,
        codebook_dim: int = 8,
        num_quantizers: int = 1,
        commitment: float = 0.25,
    ):
        super().__init__()
        self.num_quantizers = num_quantizers
        stages = []
        for _ in range(num_quantizers):
            stages.append(
                VectorQuantize(input_dim, codebook_size, codebook_dim, commitment)
            )
        self.quantizers = nn.ModuleList(stages)

    def forward(
        self, z: Tensor, n_quantizers: int = None
    ) -> Tuple[Tensor, Tensor, Tensor]:
        """
        Quantize ``z`` with the first ``n_quantizers`` stages.

        Args:
            z: Input [B, D, T]
            n_quantizers: Number of stages to use; a falsy value (None or 0)
                selects all configured stages.

        Returns:
            z_q: Sum of the per-stage quantized outputs [B, D, T]
            commitment_loss: Sum of the per-stage commitment losses
            codes: Per-stage indices stacked as [N_q, B, T]
        """
        active = n_quantizers or self.num_quantizers
        remaining = z
        z_q = torch.zeros_like(z)
        stage_codes = []
        commitment_loss = 0
        for stage in self.quantizers[:active]:
            stage_out, stage_loss, stage_indices = stage(remaining)
            # Each stage only sees what earlier stages failed to represent.
            remaining = remaining - stage_out
            z_q = z_q + stage_out
            stage_codes.append(stage_indices)
            commitment_loss = commitment_loss + stage_loss
        codes = torch.stack(stage_codes, dim=0)  # [N_q, B, T]
        return z_q, commitment_loss, codes

    def decode(self, codes: Tensor) -> Tensor:
        """Sum the decoded vectors of each stage; ``codes`` is [N_q, B, T] or [B, T]."""
        if codes.dim() == 2:
            # A 2-D tensor is treated as the codes of a single stage.
            codes = codes.unsqueeze(0)
        z_q = None
        # zip stops at the shorter of (code stages, quantizers).
        for per_stage_codes, stage in zip(codes, self.quantizers):
            decoded = stage.decode(per_stage_codes)
            if z_q is None:
                z_q = decoded
            else:
                z_q = z_q + decoded
        return z_q
# ==============================================================================
# Decoder Components (Vocos-style)
# ==============================================================================
class ConvNeXtBlock(nn.Module):
    """ConvNeXt block: depthwise conv, LayerNorm, pointwise MLP, optional layer scale, residual add."""

    def __init__(
        self,
        dim: int,
        intermediate_dim: int,
        kernel_size: int = 7,
        layer_scale_init_value: float = 1e-6,
    ):
        super().__init__()
        same_pad = (kernel_size - 1) // 2
        # groups=dim makes the convolution depthwise (one filter per channel).
        self.dwconv = nn.Conv1d(dim, dim, kernel_size, padding=same_pad, groups=dim)
        self.norm = nn.LayerNorm(dim)
        # Pointwise (per-frame) expansion and contraction, implemented as Linear layers.
        self.pwconv1 = nn.Linear(dim, intermediate_dim)
        self.act = nn.GELU()
        self.pwconv2 = nn.Linear(intermediate_dim, dim)
        # Learnable per-channel scale on the branch; disabled when the init value is <= 0.
        if layer_scale_init_value > 0:
            self.gamma = nn.Parameter(layer_scale_init_value * torch.ones(dim))
        else:
            self.gamma = None

    def forward(self, x: Tensor) -> Tensor:
        """Apply the block to ``x`` of shape [B, D, T] and return a tensor of the same shape."""
        h = self.dwconv(x)
        h = h.transpose(1, 2)  # [B, T, D] so LayerNorm/Linear act on channels
        h = self.pwconv2(self.act(self.pwconv1(self.norm(h))))
        if self.gamma is not None:
            h = self.gamma * h
        h = h.transpose(1, 2)  # back to [B, D, T]
        return x + h
class VocosBackbone(nn.Module):
    """Backbone: input conv + LayerNorm, optional pre-norm self-attention layers, ConvNeXt blocks, final LayerNorm."""

    def __init__(
        self,
        input_dim: int,
        dim: int,
        intermediate_dim: int,
        num_blocks: int,
        kernel_size: int = 7,
        layer_scale_init_value: float = 1e-6,
        use_attention: bool = True,
        num_heads: int = 8,
        num_attention_layers: int = 1,
    ):
        super().__init__()
        # Input projection
        self.input_conv = nn.Conv1d(input_dim, dim, kernel_size=7, padding=3)
        self.norm = nn.LayerNorm(dim)

        # Attention layers (the attributes only exist when attention is enabled)
        self.use_attention = use_attention
        if use_attention:
            attention_layers = [
                nn.MultiheadAttention(dim, num_heads, batch_first=True)
                for _ in range(num_attention_layers)
            ]
            self.attention = nn.ModuleList(attention_layers)
            attention_norms = [nn.LayerNorm(dim) for _ in range(num_attention_layers)]
            self.attn_norms = nn.ModuleList(attention_norms)

        # ConvNeXt blocks
        convnext_blocks = [
            ConvNeXtBlock(dim, intermediate_dim, kernel_size, layer_scale_init_value)
            for _ in range(num_blocks)
        ]
        self.convnext = nn.ModuleList(convnext_blocks)
        self.final_norm = nn.LayerNorm(dim)

    def forward(self, x: Tensor) -> Tensor:
        """Map features [B, input_dim, T] to backbone features [B, dim, T]."""
        # Input projection; LayerNorm is applied over channels, hence the transposes.
        h = self.input_conv(x)
        h = self.norm(h.transpose(1, 2)).transpose(1, 2)

        # Pre-norm self-attention with a residual connection
        if self.use_attention:
            for attn, norm in zip(self.attention, self.attn_norms):
                tokens = h.transpose(1, 2)  # [B, T, D]
                normed = norm(tokens)
                attended, _ = attn(normed, normed, normed)
                h = (tokens + attended).transpose(1, 2)  # [B, D, T]

        # ConvNeXt blocks
        for block in self.convnext:
            h = block(h)

        # Final norm over channels
        return self.final_norm(h.transpose(1, 2)).transpose(1, 2)
class ISTFTHead(nn.Module):
    """
    Inverse STFT head for waveform synthesis.

    A 1x1 convolution predicts a log-magnitude and a phase channel per
    frequency bin; the resulting complex spectrum is inverted with
    ``torch.istft``. The ``padding`` argument is stored but the inverse
    transform always runs with ``center=True``.
    """

    def __init__(
        self,
        dim: int,
        n_fft: int,
        hop_length: int,
        padding: str = "center",
    ):
        super().__init__()
        self.n_fft = n_fft
        self.hop_length = hop_length
        self.padding = padding
        # Number of one-sided frequency bins
        self.out_dim = n_fft // 2 + 1
        # Two output channels per bin: magnitude and phase
        self.proj = nn.Conv1d(dim, self.out_dim * 2, kernel_size=1)
        # Hann window kept as a non-persistent buffer (not saved in the state dict)
        hann = torch.hann_window(n_fft)
        self.register_buffer("window", hann, persistent=False)

    def forward(self, x: Tensor) -> Tensor:
        """
        Synthesize a waveform from backbone features.

        Args:
            x: [B, D, T]

        Returns:
            wav: [B, 1, T']
        """
        spec = self.proj(x)
        log_mag, raw_phase = spec.chunk(2, dim=1)

        # Magnitude is predicted in the log domain; the phase channel is squashed
        # with sin() and scaled by pi to obtain the angle.
        mag = torch.exp(log_mag)
        angle = torch.sin(raw_phase) * math.pi
        real_part = mag * torch.cos(angle)
        imag_part = mag * torch.sin(angle)
        S = torch.complex(real_part, imag_part)

        # Make sure the window lives on the same device as the spectrum
        window = self.window.to(spec.device)

        wav = torch.istft(
            S,
            n_fft=self.n_fft,
            hop_length=self.hop_length,
            window=window,
            center=True,
            normalized=False,
            onesided=True,
            return_complex=False,
        )
        # Add the channel axis: [B, T'] -> [B, 1, T']
        return wav.unsqueeze(1)
# ==============================================================================
# Feature Extractor (Mel Spectrogram)
# ==============================================================================
class MelSpectrogramFeatures(nn.Module):
    """Compute log-mel spectrogram features from audio (requires torchaudio at construction time)."""

    def __init__(
        self,
        sample_rate: int = 24000,
        n_fft: int = 1024,
        hop_length: int = 256,
        n_mels: int = 100,
        f_min: float = 0.0,
        f_max: float = None,
        padding: str = "center",
    ):
        super().__init__()
        self.sample_rate = sample_rate
        self.n_fft = n_fft
        self.hop_length = hop_length
        self.n_mels = n_mels
        self.padding = padding

        # Mel filterbank; a falsy f_max (None or 0) falls back to sample_rate // 2.
        import torchaudio
        upper_freq = f_max if f_max else sample_rate // 2
        filterbank = torchaudio.functional.melscale_fbanks(
            n_freqs=n_fft // 2 + 1,
            f_min=f_min,
            f_max=upper_freq,
            n_mels=n_mels,
            sample_rate=sample_rate,
            norm="slaney",
            mel_scale="slaney",
        )
        # Both buffers are non-persistent, i.e. not stored in the state dict.
        self.register_buffer("mel_fb", filterbank, persistent=False)
        self.register_buffer("window", torch.hann_window(n_fft), persistent=False)

    def forward(self, wav: Tensor) -> Tensor:
        """
        Compute the log-mel spectrogram of a batch of waveforms.

        Args:
            wav: [B, 1, T] or [B, T]

        Returns:
            mel: [B, n_mels, T']
        """
        # Drop the channel axis if present
        signal = wav.squeeze(1) if wav.dim() == 3 else wav

        spectrum = torch.stft(
            signal,
            n_fft=self.n_fft,
            hop_length=self.hop_length,
            window=self.window.to(signal.device),
            center=True,
            return_complex=True,
        )

        # Power spectrum
        power = spectrum.abs().pow(2)

        # Project the frequency bins onto the mel filters
        fb_t = self.mel_fb.T.to(power.device)
        mel = torch.matmul(fb_t, power)

        # Log scale, floored at 1e-5 to avoid log(0)
        return torch.log(mel.clamp(min=1e-5))
# ==============================================================================
# Main WavTokenizer Model
# ==============================================================================
class WavTokenizer(PreTrainedModel):
    """
    WavTokenizer: Efficient acoustic discrete codec tokenizer.

    Architecture:
    - Encoder: Strided convolutions for audio compression
    - VQ: Vector quantization (codebook size and number of quantizers come
      from the config)
    - Decoder: Vocos backbone (ConvNeXt + attention) + iSTFT head

    Usage:
    ```python
    model = WavTokenizer.from_pretrained("TuKoResearch/WavTokenizerSmall", trust_remote_code=True)
    # Encode
    features, codes = model.encode_infer(wav, bandwidth_id=torch.tensor([0]))
    # Decode
    wav_out = model.decode(features, bandwidth_id=torch.tensor([0]))
    # Or use codes directly
    features = model.codes_to_features(codes)
    wav_out = model.decode(features, bandwidth_id=torch.tensor([0]))
    ```
    """

    config_class = WavTokenizerConfig

    def __init__(self, config: WavTokenizerConfig):
        super().__init__(config)
        self.sample_rate = config.sample_rate
        self.hop_length = config.hop_length

        # Encoder
        self.encoder = Encoder(
            d_model=config.encoder_dim,
            strides=config.encoder_rates,
            d_latent=config.latent_dim,
        )

        # Quantizer
        self.quantizer = ResidualVectorQuantize(
            input_dim=config.latent_dim,
            codebook_size=config.codebook_size,
            codebook_dim=config.codebook_dim,
            num_quantizers=config.num_quantizers,
        )

        # Feature projection for decoder (1x1 conv: latent_dim -> backbone_dim)
        self.feature_proj = nn.Conv1d(config.latent_dim, config.backbone_dim, 1)

        # Decoder backbone
        self.backbone = VocosBackbone(
            input_dim=config.backbone_dim,
            dim=config.backbone_dim,
            intermediate_dim=config.backbone_intermediate_dim,
            num_blocks=config.backbone_num_blocks,
            kernel_size=config.backbone_kernel_size,
            layer_scale_init_value=config.backbone_layer_scale_init_value,
            use_attention=config.use_attention,
            num_heads=config.attention_heads,
            num_attention_layers=config.attention_layers,
        )

        # iSTFT head
        self.head = ISTFTHead(
            dim=config.backbone_dim,
            n_fft=config.n_fft,
            hop_length=config.hop_length,
            padding=config.padding,
        )

        # Bandwidth embedding (table with 4 entries)
        self.bandwidth_emb = nn.Embedding(4, config.backbone_dim)

        self.post_init()

    @property
    def vocab_size(self) -> int:
        """Number of entries in the codebook, as configured."""
        return self.config.codebook_size

    @property
    def frame_rate(self) -> float:
        """Frames per second implied by the config: ``sample_rate / hop_length``."""
        return self.config.sample_rate / self.config.hop_length

    @staticmethod
    def _to_batched_mono(wav: Tensor) -> Tensor:
        """Reshape a [T] or [B, T] waveform to [B, 1, T]; other ranks pass through unchanged."""
        if wav.dim() == 1:
            wav = wav.unsqueeze(0)  # [T] -> [1, T]
        if wav.dim() == 2:
            wav = wav.unsqueeze(1)  # [B, T] -> [B, 1, T]
        return wav

    def encode(
        self, wav: Tensor, bandwidth_id: Optional[Tensor] = None
    ) -> Tuple[Tensor, Tensor, Tensor]:
        """
        Encode waveform to quantized features.

        Args:
            wav: [B, 1, T], [B, T] or [T]
            bandwidth_id: Optional bandwidth ID (accepted but not used here)

        Returns:
            z_q: Quantized features [B, D, T']
            commitment_loss: VQ loss
            codes: Discrete codes [N_q, B, T']
        """
        wav = self._to_batched_mono(wav)
        z = self.encoder(wav)
        z_q, loss, codes = self.quantizer(z)
        return z_q, loss, codes

    @torch.no_grad()
    def encode_infer(
        self, wav: Tensor, bandwidth_id: Optional[Tensor] = None
    ) -> Tuple[Tensor, Tensor]:
        """
        Encode waveform to features and codes (inference, no gradients).

        Args:
            wav: [B, 1, T], [B, T] (including [1, T]) or [T]
            bandwidth_id: Optional bandwidth ID (accepted but not used here)

        Returns:
            features: [B, D, T']
            codes: [B, T'] when a single quantizer is used, else [N_q, B, T']
        """
        # For 2-D input, [1, T] and [B, T] both become [*, 1, T]; the shared
        # helper keeps this consistent with encode().
        wav = self._to_batched_mono(wav)
        z = self.encoder(wav)
        z_q, _, codes = self.quantizer(z)
        # Squeeze for single quantizer
        if codes.size(0) == 1:
            codes = codes.squeeze(0)
        return z_q, codes

    def decode(
        self, features: Tensor, bandwidth_id: Optional[Tensor] = None
    ) -> Tensor:
        """
        Decode features to waveform.

        Args:
            features: [B, D, T']
            bandwidth_id: Optional bandwidth ID; when given, its embedding is
                added to every frame of the projected features

        Returns:
            wav: [B, 1, T]
        """
        x = self.feature_proj(features)
        if bandwidth_id is not None:
            bw_emb = self.bandwidth_emb(bandwidth_id)
            # Trailing singleton axis lets the embedding broadcast over time
            x = x + bw_emb.unsqueeze(-1)
        x = self.backbone(x)
        wav = self.head(x)
        return wav

    @torch.no_grad()
    def codes_to_features(self, codes: Tensor) -> Tensor:
        """
        Convert codes to features.

        Args:
            codes: [N_q, B, T'] or [B, T']

        Returns:
            features: [B, D, T']
        """
        return self.quantizer.decode(codes)

    def forward(
        self,
        wav: Optional[Tensor] = None,
        codes: Optional[Tensor] = None,
        bandwidth_id: Optional[Tensor] = None,
        **kwargs
    ) -> Union[BatchEncoding, Tensor]:
        """
        Forward pass.

        If wav provided: encode to get tokens (``wav`` takes precedence over ``codes``)
        If codes provided: decode to get wav

        Returns:
            A ``BatchEncoding`` with ``input_values`` (features) and
            ``input_ids`` (codes) when encoding, or the decoded waveform tensor.

        Raises:
            ValueError: If neither ``wav`` nor ``codes`` is given.
        """
        if wav is not None:
            features, codes = self.encode_infer(wav, bandwidth_id)
            return BatchEncoding({
                "input_values": features,
                "input_ids": codes,
            })
        if codes is not None:
            features = self.codes_to_features(codes)
            return self.decode(features, bandwidth_id)
        raise ValueError("Provide either 'wav' or 'codes'")

    @classmethod
    def from_pretrained0802(
        cls,
        config_path: str,
        checkpoint_path: str,
        device: str = "cpu",
    ) -> "WavTokenizer":
        """
        Load from original WavTokenizer checkpoint.

        Args:
            config_path: Path to YAML config
            checkpoint_path: Path to .ckpt file
            device: Device to load to

        Returns:
            Loaded model. Weights are loaded non-strictly; only the number of
            missing/unexpected keys is printed, so callers should check that
            output before relying on the result.
        """
        import yaml

        # Load YAML config (an empty file parses to None, hence the fallback)
        with open(config_path, 'r', encoding='utf-8') as f:
            yaml_cfg = yaml.safe_load(f) or {}

        # Extract config params
        model_args = yaml_cfg.get('model', {}).get('init_args', {})
        head_args = model_args.get('head', {}).get('init_args', {})
        backbone_args = model_args.get('backbone', {}).get('init_args', {})
        quantizer_args = model_args.get('quantizer', {}).get('init_args', {})
        backbone_dim = backbone_args.get('dim', 512)

        # Create HF config
        config = WavTokenizerConfig(
            sample_rate=24000,
            n_fft=head_args.get('n_fft', 1280),
            hop_length=head_args.get('hop_length', 320),
            feature_dim=backbone_dim,
            latent_dim=backbone_args.get('input_channels', 512),
            backbone_dim=backbone_dim,
            backbone_intermediate_dim=backbone_args.get('intermediate_dim', 1536),
            backbone_num_blocks=backbone_args.get('num_layers', 8),
            codebook_size=quantizer_args.get('codebook_size', 4096),
            codebook_dim=quantizer_args.get('codebook_dim', 8),
            num_quantizers=quantizer_args.get('num_quantizers', 1),
            use_attention=True,
            attention_dim=backbone_dim,
            attention_heads=8,
            attention_layers=1,
        )

        # Create model
        model = cls(config)

        # Load checkpoint
        # NOTE(security): torch.load unpickles the file; only load checkpoints
        # from trusted sources (consider weights_only=True where supported).
        ckpt = torch.load(checkpoint_path, map_location=device)
        state_dict = ckpt.get('state_dict', ckpt)

        # Clean state dict: remove 'model.' prefix if present
        prefix = 'model.'
        new_state_dict = {
            (k[len(prefix):] if k.startswith(prefix) else k): v
            for k, v in state_dict.items()
        }

        # Load (non-strict to handle mismatches)
        missing, unexpected = model.load_state_dict(new_state_dict, strict=False)
        if missing:
            print(f"Missing keys: {len(missing)}")
        if unexpected:
            print(f"Unexpected keys: {len(unexpected)}")
        return model.to(device)