""" |
|
|
WavTokenizer Model for HuggingFace Transformers |
|
|
|
|
|
This module contains the complete implementation of WavTokenizer, |
|
|
an acoustic discrete codec tokenizer for audio language modeling. |
|
|
All dependencies are included to avoid external imports. |
|
|
|
|
|
The architecture follows the original WavTokenizer implementation: |
|
|
- Encoder: Strided convolutions for audio compression |
|
|
- VQ: Vector quantization with single codebook |
|
|
- Decoder: Vocos-style backbone with ConvNeXt blocks + iSTFT head |
|
|
|
|
|
Reference: https://github.com/jishengpeng/WavTokenizer |
|
|
Paper: "WavTokenizer: an Efficient Acoustic Discrete Codec Tokenizer for Audio Language Modeling" |
|
|
""" |
import math
from typing import Optional, Tuple, Union

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch import Tensor
from torch.nn.utils import weight_norm

from transformers import PreTrainedModel
from transformers.tokenization_utils import BatchEncoding

from .configuration_wavtokenizer import WavTokenizerConfig


def convert_audio(wav: Tensor, sr: int, target_sr: int, target_channels: int) -> Tensor:
    """
    Convert audio to the target sample rate and number of channels.

    Args:
        wav: Input waveform [C, T] or [T]
        sr: Source sample rate
        target_sr: Target sample rate
        target_channels: Target number of channels (1 for mono, 2 for stereo)

    Returns:
        Converted waveform [target_channels, T']
    """
    # Lazy import: torchaudio is only required when audio conversion is used.
    import torchaudio

    if wav.dim() == 1:
        wav = wav.unsqueeze(0)

    # Downmix to mono first, then replicate channels if more are requested.
    # Two separate checks (not if/elif) so that e.g. 4 -> 2 channels first
    # downmixes to mono and then expands back to two channels.
    if wav.size(0) > target_channels:
        wav = wav.mean(dim=0, keepdim=True)
    if wav.size(0) < target_channels:
        wav = wav.expand(target_channels, -1)

    if sr != target_sr:
        wav = torchaudio.functional.resample(wav, sr, target_sr)

    return wav
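# A minimal usage sketch for `convert_audio` (the file name and the 24 kHz
# mono target are illustrative assumptions, not requirements of this module):
#
#     import torchaudio
#     wav, sr = torchaudio.load("speech.wav")     # [C, T] at the file's rate
#     wav = convert_audio(wav, sr, target_sr=24000, target_channels=1)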


def WNConv1d(*args, **kwargs):
    """Weight-normalized Conv1d."""
    # The legacy `torch.nn.utils.weight_norm` is used (rather than the newer
    # parametrization API) so state-dict keys keep the `weight_g`/`weight_v`
    # layout of checkpoints trained with it.
    return weight_norm(nn.Conv1d(*args, **kwargs))


def WNConvTranspose1d(*args, **kwargs):
    """Weight-normalized ConvTranspose1d."""
    return weight_norm(nn.ConvTranspose1d(*args, **kwargs))


class ResidualUnit(nn.Module):
    """Residual unit with a dilated convolution."""

    def __init__(self, dim: int = 16, dilation: int = 1):
        super().__init__()
        # "Same" padding for a kernel of size 7 at the given dilation.
        pad = ((7 - 1) * dilation) // 2
        self.block = nn.Sequential(
            nn.ELU(),
            WNConv1d(dim, dim, kernel_size=7, dilation=dilation, padding=pad),
            nn.ELU(),
            WNConv1d(dim, dim, kernel_size=1),
        )

    def forward(self, x: Tensor) -> Tensor:
        return x + self.block(x)


class EncoderBlock(nn.Module):
    """Encoder block with residual units followed by strided downsampling."""

    def __init__(self, dim: int = 16, stride: int = 1):
        super().__init__()
        self.block = nn.Sequential(
            ResidualUnit(dim // 2, dilation=1),
            ResidualUnit(dim // 2, dilation=3),
            ResidualUnit(dim // 2, dilation=9),
            nn.ELU(),
            WNConv1d(
                dim // 2, dim,
                kernel_size=2 * stride,
                stride=stride,
                padding=math.ceil(stride / 2),
            ),
        )

    def forward(self, x: Tensor) -> Tensor:
        return self.block(x)


class Encoder(nn.Module):
    """
    DAC-style encoder that compresses a waveform into a latent representation.
    Uses strided convolutions for downsampling.
    """

    def __init__(
        self,
        d_model: int = 64,
        strides: Tuple[int, ...] = (8, 5, 4, 2),
        d_latent: int = 512,
    ):
        super().__init__()

        # Initial convolution over the raw mono waveform.
        block = [WNConv1d(1, d_model, kernel_size=7, padding=3)]

        # Each block doubles the channel count while downsampling by `stride`.
        for stride in strides:
            d_model *= 2
            block.append(EncoderBlock(d_model, stride=stride))

        # Project to the latent dimension expected by the quantizer.
        block.extend([
            nn.ELU(),
            WNConv1d(d_model, d_latent, kernel_size=3, padding=1),
        ])

        self.block = nn.Sequential(*block)
        self.enc_dim = d_model

    def forward(self, x: Tensor) -> Tensor:
        return self.block(x)
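# The encoder's overall hop size is the product of the strides; with the
# defaults above that is 8 * 5 * 4 * 2 = 320 samples per latent frame, i.e.
# 24000 / 320 = 75 frames per second at a 24 kHz sample rate (these numbers
# are illustrative defaults, not requirements). A quick shape check:
#
#     enc = Encoder(d_model=64, strides=(8, 5, 4, 2), d_latent=512)
#     z = enc(torch.randn(1, 1, 24000))   # one second of audio
#     assert z.shape == (1, 512, 75)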


class VectorQuantize(nn.Module):
    """
    Vector quantization with a projected codebook.

    Uses L2-normalized codes for better stability.
    """

    def __init__(
        self,
        input_dim: int,
        codebook_size: int,
        codebook_dim: int,
        commitment: float = 0.25,
    ):
        super().__init__()

        self.input_dim = input_dim
        self.codebook_size = codebook_size
        self.codebook_dim = codebook_dim
        self.commitment = commitment

        # Project in/out of the (usually much smaller) codebook space.
        requires_projection = input_dim != codebook_dim
        self.project_in = nn.Linear(input_dim, codebook_dim) if requires_projection else nn.Identity()
        self.project_out = nn.Linear(codebook_dim, input_dim) if requires_projection else nn.Identity()

        self.codebook = nn.Embedding(codebook_size, codebook_dim)
        nn.init.uniform_(self.codebook.weight, -1.0 / codebook_size, 1.0 / codebook_size)

    def forward(self, z: Tensor) -> Tuple[Tensor, Tensor, Tensor]:
        """
        Forward pass.

        Args:
            z: Input [B, D, T]

        Returns:
            z_q: Quantized [B, D, T]
            commitment_loss: Loss scalar
            indices: Codes [B, T]
        """
        z = z.transpose(1, 2)  # [B, T, D]
        z_e = self.project_in(z)

        # L2-normalize encodings and codebook entries.
        z_e_norm = F.normalize(z_e, dim=-1)
        codebook_norm = F.normalize(self.codebook.weight, dim=-1)

        # Squared Euclidean distance to every code. Since both sides are
        # unit-norm, this equals 2 - 2 * cosine similarity, so the nearest
        # code is also the most cosine-similar one.
        dist = (
            z_e_norm.pow(2).sum(-1, keepdim=True)
            + codebook_norm.pow(2).sum(-1)
            - 2 * torch.einsum('btd,kd->btk', z_e_norm, codebook_norm)
        )
        indices = dist.argmin(dim=-1)

        # Look up the selected (normalized) codes.
        z_q = F.embedding(indices, codebook_norm)

        # Commitment loss pulls the encoder output toward the chosen codes.
        commitment_loss = F.mse_loss(z_e_norm, z_q.detach()) * self.commitment

        # Straight-through estimator: copy gradients from z_q to the encoder.
        z_q = z_e_norm + (z_q - z_e_norm).detach()

        z_q = self.project_out(z_q)
        z_q = z_q.transpose(1, 2)  # back to [B, D, T]

        return z_q, commitment_loss, indices

    def decode(self, indices: Tensor) -> Tensor:
        """Decode indices to vectors."""
        codebook = F.normalize(self.codebook.weight, dim=-1)
        z_q = F.embedding(indices, codebook)
        z_q = self.project_out(z_q)
        return z_q.transpose(1, 2)
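# A small consistency sketch (dimensions are illustrative): decoding the
# indices returned by `forward` reproduces the quantized output exactly.
#
#     vq = VectorQuantize(input_dim=512, codebook_size=4096, codebook_dim=8)
#     z = torch.randn(2, 512, 75)
#     z_q, loss, idx = vq(z)
#     assert z_q.shape == z.shape and idx.shape == (2, 75)
#     assert torch.allclose(vq.decode(idx), z_q.detach(), atol=1e-6)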


class ResidualVectorQuantize(nn.Module):
    """Residual VQ with multiple codebooks (typically 1 for WavTokenizer)."""

    def __init__(
        self,
        input_dim: int = 512,
        codebook_size: int = 4096,
        codebook_dim: int = 8,
        num_quantizers: int = 1,
        commitment: float = 0.25,
    ):
        super().__init__()

        self.num_quantizers = num_quantizers
        self.quantizers = nn.ModuleList([
            VectorQuantize(input_dim, codebook_size, codebook_dim, commitment)
            for _ in range(num_quantizers)
        ])

    def forward(
        self, z: Tensor, n_quantizers: Optional[int] = None
    ) -> Tuple[Tensor, Tensor, Tensor]:
        n_q = n_quantizers or self.num_quantizers

        # Each quantizer codes the residual left over by the previous ones.
        residual = z
        z_q = torch.zeros_like(z)
        all_indices = []
        all_losses = []

        for quantizer in self.quantizers[:n_q]:
            _z_q, loss, indices = quantizer(residual)
            residual = residual - _z_q
            z_q = z_q + _z_q
            all_indices.append(indices)
            all_losses.append(loss)

        codes = torch.stack(all_indices, dim=0)  # [N_q, B, T]
        commitment_loss = sum(all_losses)

        return z_q, commitment_loss, codes

    def decode(self, codes: Tensor) -> Tensor:
        """Decode codes to vectors."""
        if codes.dim() == 2:
            codes = codes.unsqueeze(0)  # [B, T] -> [1, B, T]

        # Sum the contributions of each quantizer stage.
        z_q = None
        for i, quantizer in enumerate(self.quantizers[:codes.size(0)]):
            _z_q = quantizer.decode(codes[i])
            z_q = _z_q if z_q is None else z_q + _z_q

        return z_q
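# With a single quantizer (the WavTokenizer setting), `codes` has shape
# [1, B, T] and `decode` inverts the code lookup (a sketch, shapes only):
#
#     rvq = ResidualVectorQuantize(num_quantizers=1)
#     z_q, loss, codes = rvq(torch.randn(2, 512, 75))
#     assert codes.shape == (1, 2, 75)
#     assert rvq.decode(codes).shape == (2, 512, 75)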


class ConvNeXtBlock(nn.Module):
    """ConvNeXt block: depthwise conv + pointwise expansion, with residual."""

    def __init__(
        self,
        dim: int,
        intermediate_dim: int,
        kernel_size: int = 7,
        layer_scale_init_value: float = 1e-6,
    ):
        super().__init__()

        padding = (kernel_size - 1) // 2
        self.dwconv = nn.Conv1d(dim, dim, kernel_size, padding=padding, groups=dim)
        self.norm = nn.LayerNorm(dim)
        self.pwconv1 = nn.Linear(dim, intermediate_dim)
        self.act = nn.GELU()
        self.pwconv2 = nn.Linear(intermediate_dim, dim)

        # Optional layer scale (a learnable per-channel residual gain).
        self.gamma = nn.Parameter(
            layer_scale_init_value * torch.ones(dim)
        ) if layer_scale_init_value > 0 else None

    def forward(self, x: Tensor) -> Tensor:
        residual = x
        x = self.dwconv(x)
        x = x.transpose(1, 2)  # [B, T, C] for LayerNorm and linear layers
        x = self.norm(x)
        x = self.pwconv1(x)
        x = self.act(x)
        x = self.pwconv2(x)
        if self.gamma is not None:
            x = self.gamma * x
        x = x.transpose(1, 2)  # back to [B, C, T]
        return residual + x
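# The block is shape-preserving, which is what lets the backbone stack any
# number of them (a sketch with illustrative sizes):
#
#     blk = ConvNeXtBlock(dim=512, intermediate_dim=1536)
#     x = torch.randn(2, 512, 75)
#     assert blk(x).shape == x.shape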


class VocosBackbone(nn.Module):
    """Vocos backbone with optional self-attention and ConvNeXt blocks."""

    def __init__(
        self,
        input_dim: int,
        dim: int,
        intermediate_dim: int,
        num_blocks: int,
        kernel_size: int = 7,
        layer_scale_init_value: float = 1e-6,
        use_attention: bool = True,
        num_heads: int = 8,
        num_attention_layers: int = 1,
    ):
        super().__init__()

        # Embed input features and normalize.
        self.input_conv = nn.Conv1d(input_dim, dim, kernel_size=7, padding=3)
        self.norm = nn.LayerNorm(dim)

        # Optional pre-norm self-attention layers.
        self.use_attention = use_attention
        if use_attention:
            self.attention = nn.ModuleList([
                nn.MultiheadAttention(dim, num_heads, batch_first=True)
                for _ in range(num_attention_layers)
            ])
            self.attn_norms = nn.ModuleList([
                nn.LayerNorm(dim) for _ in range(num_attention_layers)
            ])

        # Stack of shape-preserving ConvNeXt blocks.
        self.convnext = nn.ModuleList([
            ConvNeXtBlock(dim, intermediate_dim, kernel_size, layer_scale_init_value)
            for _ in range(num_blocks)
        ])

        self.final_norm = nn.LayerNorm(dim)

    def forward(self, x: Tensor) -> Tensor:
        x = self.input_conv(x)
        x = x.transpose(1, 2)
        x = self.norm(x)
        x = x.transpose(1, 2)

        # Pre-norm residual self-attention over the time axis.
        if self.use_attention:
            for attn, norm in zip(self.attention, self.attn_norms):
                x_t = x.transpose(1, 2)  # [B, T, C]
                residual = x_t
                x_t = norm(x_t)
                x_t, _ = attn(x_t, x_t, x_t)
                x_t = residual + x_t
                x = x_t.transpose(1, 2)

        for block in self.convnext:
            x = block(x)

        x = x.transpose(1, 2)
        x = self.final_norm(x)
        x = x.transpose(1, 2)

        return x


class ISTFTHead(nn.Module):
    """Inverse STFT head for waveform synthesis."""

    def __init__(
        self,
        dim: int,
        n_fft: int,
        hop_length: int,
        padding: str = "center",
    ):
        super().__init__()

        self.n_fft = n_fft
        self.hop_length = hop_length
        # Note: only centered iSTFT is implemented below; `padding` is kept
        # for config compatibility.
        self.padding = padding

        # Predict magnitude and phase for each of the n_fft // 2 + 1 bins.
        self.out_dim = n_fft // 2 + 1
        self.proj = nn.Conv1d(dim, self.out_dim * 2, kernel_size=1)

        self.register_buffer(
            "window",
            torch.hann_window(n_fft),
            persistent=False
        )

    def forward(self, x: Tensor) -> Tensor:
        """
        Args:
            x: [B, D, T]
        Returns:
            wav: [B, 1, T']
        """
        x = self.proj(x)

        # Split into magnitude and phase predictions.
        mag, phase = x.chunk(2, dim=1)

        # Magnitude is predicted in log space; sin squashes the raw phase
        # prediction into [-1, 1], and the factor of pi below maps it to
        # [-pi, pi].
        mag = torch.exp(mag)
        phase = torch.sin(phase)

        # Build the complex spectrogram from polar coordinates.
        S = torch.complex(mag * torch.cos(phase * math.pi), mag * torch.sin(phase * math.pi))

        window = self.window.to(x.device)

        wav = torch.istft(
            S,
            n_fft=self.n_fft,
            hop_length=self.hop_length,
            window=window,
            center=True,
            normalized=False,
            onesided=True,
            return_complex=False,
        )

        return wav.unsqueeze(1)
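# With center=True, torch.istft returns hop_length * (frames - 1) samples, so
# each latent frame contributes one hop of audio. A sketch using the values
# that `from_pretrained0802` below defaults to (n_fft=1280, hop=320):
#
#     head = ISTFTHead(dim=512, n_fft=1280, hop_length=320)
#     wav = head(torch.randn(1, 512, 76))
#     assert wav.shape == (1, 1, 320 * 75)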


class MelSpectrogramFeatures(nn.Module):
    """Extract log-mel spectrogram features from audio."""

    def __init__(
        self,
        sample_rate: int = 24000,
        n_fft: int = 1024,
        hop_length: int = 256,
        n_mels: int = 100,
        f_min: float = 0.0,
        f_max: Optional[float] = None,
        padding: str = "center",
    ):
        super().__init__()

        self.sample_rate = sample_rate
        self.n_fft = n_fft
        self.hop_length = hop_length
        self.n_mels = n_mels
        self.padding = padding

        # Precompute the mel filterbank; torchaudio is only needed here.
        import torchaudio
        mel_fb = torchaudio.functional.melscale_fbanks(
            n_freqs=n_fft // 2 + 1,
            f_min=f_min,
            f_max=f_max or sample_rate // 2,
            n_mels=n_mels,
            sample_rate=sample_rate,
            norm="slaney",
            mel_scale="slaney",
        )
        self.register_buffer("mel_fb", mel_fb, persistent=False)
        self.register_buffer("window", torch.hann_window(n_fft), persistent=False)

    def forward(self, wav: Tensor) -> Tensor:
        """
        Args:
            wav: [B, 1, T] or [B, T]
        Returns:
            mel: [B, n_mels, T']
        """
        if wav.dim() == 3:
            wav = wav.squeeze(1)

        stft = torch.stft(
            wav,
            n_fft=self.n_fft,
            hop_length=self.hop_length,
            window=self.window.to(wav.device),
            center=True,
            return_complex=True,
        )

        # Power spectrogram -> mel bins -> log compression.
        power = stft.abs().pow(2)
        mel = torch.matmul(self.mel_fb.T.to(power.device), power)
        mel = torch.log(mel.clamp(min=1e-5))

        return mel
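# Shape sketch with the defaults above (center=True gives T // hop + 1 frames):
#
#     mel_fn = MelSpectrogramFeatures()
#     mel = mel_fn(torch.randn(2, 1, 24000))
#     assert mel.shape == (2, 100, 24000 // 256 + 1)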


class WavTokenizer(PreTrainedModel):
    """
    WavTokenizer: Efficient acoustic discrete codec tokenizer.

    Architecture:
    - Encoder: Strided convolutions for audio compression
    - VQ: Single-codebook vector quantization (4096 codes)
    - Decoder: Vocos backbone (ConvNeXt + attention) + iSTFT head

    Usage:
        ```python
        model = WavTokenizer.from_pretrained("TuKoResearch/WavTokenizerSmall", trust_remote_code=True)

        # Encode
        features, codes = model.encode_infer(wav, bandwidth_id=torch.tensor([0]))

        # Decode
        wav_out = model.decode(features, bandwidth_id=torch.tensor([0]))

        # Or use codes directly
        features = model.codes_to_features(codes)
        wav_out = model.decode(features, bandwidth_id=torch.tensor([0]))
        ```
    """

    config_class = WavTokenizerConfig

    def __init__(self, config: WavTokenizerConfig):
        super().__init__(config)

        self.sample_rate = config.sample_rate
        self.hop_length = config.hop_length

        # Waveform -> latent encoder.
        self.encoder = Encoder(
            d_model=config.encoder_dim,
            strides=config.encoder_rates,
            d_latent=config.latent_dim,
        )

        # Latent -> discrete codes.
        self.quantizer = ResidualVectorQuantize(
            input_dim=config.latent_dim,
            codebook_size=config.codebook_size,
            codebook_dim=config.codebook_dim,
            num_quantizers=config.num_quantizers,
        )

        # Project quantized features to the decoder width.
        self.feature_proj = nn.Conv1d(config.latent_dim, config.backbone_dim, 1)

        # Vocos-style decoder backbone.
        self.backbone = VocosBackbone(
            input_dim=config.backbone_dim,
            dim=config.backbone_dim,
            intermediate_dim=config.backbone_intermediate_dim,
            num_blocks=config.backbone_num_blocks,
            kernel_size=config.backbone_kernel_size,
            layer_scale_init_value=config.backbone_layer_scale_init_value,
            use_attention=config.use_attention,
            num_heads=config.attention_heads,
            num_attention_layers=config.attention_layers,
        )

        # Spectrogram prediction + inverse STFT synthesis.
        self.head = ISTFTHead(
            dim=config.backbone_dim,
            n_fft=config.n_fft,
            hop_length=config.hop_length,
            padding=config.padding,
        )

        # Bandwidth conditioning embedding (up to 4 bandwidth settings).
        self.bandwidth_emb = nn.Embedding(4, config.backbone_dim)

        self.post_init()

    @property
    def vocab_size(self) -> int:
        return self.config.codebook_size

    @property
    def frame_rate(self) -> float:
        return self.config.sample_rate / self.config.hop_length
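    # `frame_rate` is discrete tokens per second of audio. With the defaults
    # used by `from_pretrained0802` below (24 kHz sample rate, hop length 320)
    # this is 24000 / 320 = 75 tokens per second.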
    def encode(
        self, wav: Tensor, bandwidth_id: Optional[Tensor] = None
    ) -> Tuple[Tensor, Tensor, Tensor]:
        """
        Encode waveform to quantized features.

        Args:
            wav: [B, 1, T] or [B, T]
            bandwidth_id: Optional bandwidth ID

        Returns:
            z_q: Quantized features [B, D, T']
            commitment_loss: VQ loss
            codes: Discrete codes [N_q, B, T']
        """
        if wav.dim() == 2:
            wav = wav.unsqueeze(1)

        z = self.encoder(wav)
        z_q, loss, codes = self.quantizer(z)

        return z_q, loss, codes

    @torch.no_grad()
    def encode_infer(
        self, wav: Tensor, bandwidth_id: Optional[Tensor] = None
    ) -> Tuple[Tensor, Tensor]:
        """
        Encode waveform to features and codes (inference).

        Args:
            wav: [B, 1, T], [1, T], or [B, T]
            bandwidth_id: Optional bandwidth ID

        Returns:
            features: [B, D, T']
            codes: [B, T'] (squeezed if single quantizer)
        """
        # Insert the channel axis: [B, T] -> [B, 1, T]. (For [1, T] input this
        # is the same as treating it as a batch of one mono waveform.)
        if wav.dim() == 2:
            wav = wav.unsqueeze(1)

        z = self.encoder(wav)
        z_q, _, codes = self.quantizer(z)

        # Drop the quantizer axis when there is a single codebook.
        if codes.size(0) == 1:
            codes = codes.squeeze(0)

        return z_q, codes
    def decode(
        self, features: Tensor, bandwidth_id: Optional[Tensor] = None
    ) -> Tensor:
        """
        Decode features to waveform.

        Args:
            features: [B, D, T']
            bandwidth_id: Optional bandwidth ID

        Returns:
            wav: [B, 1, T]
        """
        x = self.feature_proj(features)

        # Add the bandwidth conditioning embedding, broadcast over time.
        if bandwidth_id is not None:
            bw_emb = self.bandwidth_emb(bandwidth_id)
            x = x + bw_emb.unsqueeze(-1)

        x = self.backbone(x)
        wav = self.head(x)

        return wav

    @torch.no_grad()
    def codes_to_features(self, codes: Tensor) -> Tensor:
        """
        Convert codes to features.

        Args:
            codes: [N_q, B, T'] or [B, T']

        Returns:
            features: [B, D, T']
        """
        return self.quantizer.decode(codes)
    def forward(
        self,
        wav: Optional[Tensor] = None,
        codes: Optional[Tensor] = None,
        bandwidth_id: Optional[Tensor] = None,
        **kwargs,
    ) -> Union[BatchEncoding, Tensor]:
        """
        Forward pass.

        If `wav` is provided: encode it to features and discrete tokens.
        If `codes` is provided: decode them back to a waveform.
        """
        if wav is not None:
            features, codes = self.encode_infer(wav, bandwidth_id)
            return BatchEncoding({
                "input_values": features,
                "input_ids": codes,
            })
        elif codes is not None:
            features = self.codes_to_features(codes)
            return self.decode(features, bandwidth_id)
        else:
            raise ValueError("Provide either 'wav' or 'codes'")
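    # End-to-end round trip (a sketch; the checkpoint name comes from the
    # class docstring and one second of silence stands in for real audio):
    #
    #     model = WavTokenizer.from_pretrained(
    #         "TuKoResearch/WavTokenizerSmall", trust_remote_code=True
    #     ).eval()
    #     wav = torch.zeros(1, 1, 24000)
    #     enc = model(wav=wav)                       # BatchEncoding
    #     wav_out = model(codes=enc["input_ids"])    # [1, 1, T]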
    @classmethod
    def from_pretrained0802(
        cls,
        config_path: str,
        checkpoint_path: str,
        device: str = "cpu",
    ) -> "WavTokenizer":
        """
        Load from an original WavTokenizer checkpoint.

        Args:
            config_path: Path to the YAML config
            checkpoint_path: Path to the .ckpt file
            device: Device to load onto

        Returns:
            Loaded model
        """
        import yaml

        with open(config_path, 'r') as f:
            yaml_cfg = yaml.safe_load(f)

        model_args = yaml_cfg.get('model', {}).get('init_args', {})

        # Map the original Lightning-style YAML layout onto our config.
        def _init_args(section: str) -> dict:
            return model_args.get(section, {}).get('init_args', {})

        config = WavTokenizerConfig(
            sample_rate=24000,
            n_fft=_init_args('head').get('n_fft', 1280),
            hop_length=_init_args('head').get('hop_length', 320),
            feature_dim=_init_args('backbone').get('dim', 512),
            latent_dim=_init_args('backbone').get('input_channels', 512),
            backbone_dim=_init_args('backbone').get('dim', 512),
            backbone_intermediate_dim=_init_args('backbone').get('intermediate_dim', 1536),
            backbone_num_blocks=_init_args('backbone').get('num_layers', 8),
            codebook_size=_init_args('quantizer').get('codebook_size', 4096),
            codebook_dim=_init_args('quantizer').get('codebook_dim', 8),
            num_quantizers=_init_args('quantizer').get('num_quantizers', 1),
            use_attention=True,
            attention_dim=_init_args('backbone').get('dim', 512),
            attention_heads=8,
            attention_layers=1,
        )

        model = cls(config)

        ckpt = torch.load(checkpoint_path, map_location=device)
        state_dict = ckpt.get('state_dict', ckpt)

        # Strip the Lightning 'model.' prefix, if present.
        new_state_dict = {}
        for k, v in state_dict.items():
            if k.startswith('model.'):
                k = k[len('model.'):]
            new_state_dict[k] = v

        missing, unexpected = model.load_state_dict(new_state_dict, strict=False)

        if missing:
            print(f"Missing keys: {len(missing)}")
        if unexpected:
            print(f"Unexpected keys: {len(unexpected)}")

        return model.to(device)
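# Loading from an original-repo checkpoint (a sketch; both paths below are
# placeholders for wherever the YAML config and .ckpt file actually live):
#
#     model = WavTokenizer.from_pretrained0802(
#         "configs/wavtokenizer.yaml",
#         "checkpoints/wavtokenizer.ckpt",
#         device="cpu",
#     )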