#!/usr/bin/env python
"""
Convert original WavTokenizer checkpoint to HuggingFace format.
Usage:
python convert_wavtokenizer.py \
--config_path configs/wavtokenizer_smalldata_frame75_3s_nq1_code4096_dim512_kmeans200_attn.yaml \
--checkpoint_path checkpoints/wavtokenizer_small_320_24k_4096.ckpt \
--output_dir ./wavtokenizer_hf_converted
This will create a HuggingFace-compatible model directory that can be loaded with:
model = AutoModel.from_pretrained("./wavtokenizer_hf_converted", trust_remote_code=True)
"""
import argparse
import json
import os
import shutil
from pathlib import Path

import torch
import yaml

def convert_wavtokenizer(config_path: str, checkpoint_path: str, output_dir: str):
    """Convert a WavTokenizer checkpoint to HuggingFace format."""
    print(f"Loading config from: {config_path}")
    print(f"Loading checkpoint from: {checkpoint_path}")

    # Load YAML config
    with open(config_path, 'r') as f:
        yaml_cfg = yaml.safe_load(f)
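
    # The config is expected to follow the original repo's nested layout,
    # roughly as sketched below. This is inferred from the lookups that
    # follow; exact keys can vary between WavTokenizer configs:
    #
    #   model:
    #     init_args:
    #       feature_extractor: {init_args: {sample_rate: 24000, n_mels: 128, ...}}
    #       backbone:          {init_args: {dim: 512, num_layers: 8, ...}}
    #       head:              {init_args: {n_fft: 1280, hop_length: 320, ...}}
    #       quantizer:         {init_args: {codebook_size: 4096, ...}}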

    # Extract model parameters
    model_args = yaml_cfg.get('model', {}).get('init_args', {})

    # Get specific component configs
    head_args = model_args.get('head', {}).get('init_args', {})
    backbone_args = model_args.get('backbone', {}).get('init_args', {})
    quantizer_args = model_args.get('quantizer', {}).get('init_args', {})
    feature_extractor_args = model_args.get('feature_extractor', {}).get('init_args', {})

    # Create HuggingFace config
    hf_config = {
        "_name_or_path": "WavTokenizerSmall",
        "architectures": ["WavTokenizer"],
        "auto_map": {
            "AutoConfig": "configuration_wavtokenizer.WavTokenizerConfig",
            "AutoModel": "modeling_wavtokenizer.WavTokenizer"
        },
        "model_type": "wavtokenizer",
        # Audio parameters
        "sample_rate": feature_extractor_args.get('sample_rate', 24000),
        "n_fft": head_args.get('n_fft', 1280),
        "hop_length": head_args.get('hop_length', 320),
        "n_mels": feature_extractor_args.get('n_mels', 128),
        "padding": head_args.get('padding', 'center'),
        # Feature dimensions
        "feature_dim": backbone_args.get('dim', 512),
        "encoder_dim": 64,  # Default DAC encoder
        "encoder_rates": [8, 5, 4, 2],  # Default DAC encoder rates
        "latent_dim": backbone_args.get('input_channels', 512),
        # Quantizer parameters
        "codebook_size": quantizer_args.get('codebook_size', 4096),
        "codebook_dim": quantizer_args.get('codebook_dim', 8),
        "num_quantizers": quantizer_args.get('num_quantizers', 1),
        # Backbone parameters
        "backbone_type": "vocos",
        "backbone_dim": backbone_args.get('dim', 512),
        "backbone_num_blocks": backbone_args.get('num_layers', 8),
        "backbone_intermediate_dim": backbone_args.get('intermediate_dim', 1536),
        "backbone_kernel_size": 7,
        "backbone_layer_scale_init_value": 1e-6,
        # Head parameters
        "head_type": "istft",
        "head_dim": head_args.get('n_fft', 1280) // 2 + 1,
        # Attention parameters
        "use_attention": True,
        "attention_dim": backbone_args.get('dim', 512),
        "attention_heads": 8,
        "attention_layers": 1,
        "torch_dtype": "float32",
        "transformers_version": "4.40.0"
    }

    # Create output directory
    os.makedirs(output_dir, exist_ok=True)

    # Save config.json
    config_out_path = os.path.join(output_dir, "config.json")
    with open(config_out_path, 'w') as f:
        json.dump(hf_config, f, indent=2)
    print(f"Saved config to: {config_out_path}")

    # Load checkpoint; Lightning .ckpt files contain non-tensor objects, so
    # weights_only=False is required on torch >= 2.6, where True became the default
    print("Loading checkpoint...")
    ckpt = torch.load(checkpoint_path, map_location='cpu', weights_only=False)
    state_dict = ckpt.get('state_dict', ckpt)

    # Clean state dict keys
    new_state_dict = {}
    for k, v in state_dict.items():
        # Remove 'model.' prefix if present
        if k.startswith('model.'):
            k = k[len('model.'):]
        new_state_dict[k] = v
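
    # Quick sanity check on the key mapping (a convenience added here, not
    # part of the original conversion logic)
    renamed = sum(1 for key in state_dict if key.startswith('model.'))
    print(f"Stripped 'model.' prefix from {renamed} of {len(state_dict)} keys")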

    # Save as pytorch_model.bin
    model_out_path = os.path.join(output_dir, "pytorch_model.bin")
    torch.save(new_state_dict, model_out_path)
    print(f"Saved model weights to: {model_out_path}")

    # Copy the custom Python files next to the weights so that
    # AutoModel.from_pretrained(..., trust_remote_code=True) can find them
    script_dir = Path(__file__).parent

    config_py = script_dir / "configuration_wavtokenizer.py"
    if config_py.exists():
        shutil.copy(config_py, output_dir)
        print("Copied: configuration_wavtokenizer.py")

    modeling_py = script_dir / "modeling_wavtokenizer.py"
    if modeling_py.exists():
        shutil.copy(modeling_py, output_dir)
        print("Copied: modeling_wavtokenizer.py")

    readme = script_dir / "README.md"
    if readme.exists():
        shutil.copy(readme, output_dir)
        print("Copied: README.md")

    print(f"\nConversion complete! Model saved to: {output_dir}")
    print("\nTo load the model:")
    print(f'    model = AutoModel.from_pretrained("{output_dir}", trust_remote_code=True)')

def main():
    parser = argparse.ArgumentParser(description="Convert WavTokenizer checkpoint to HuggingFace format")
    parser.add_argument(
        "--config_path",
        type=str,
        required=True,
        help="Path to WavTokenizer YAML config file"
    )
    parser.add_argument(
        "--checkpoint_path",
        type=str,
        required=True,
        help="Path to WavTokenizer .ckpt checkpoint file"
    )
    parser.add_argument(
        "--output_dir",
        type=str,
        default="./wavtokenizer_hf_converted",
        help="Output directory for HuggingFace model"
    )
    args = parser.parse_args()
    convert_wavtokenizer(args.config_path, args.checkpoint_path, args.output_dir)


if __name__ == "__main__":
    main()
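
# A minimal post-conversion smoke test (hypothetical usage; assumes the copied
# modeling_wavtokenizer.py exposes the converted model as a standard nn.Module
# whose forward accepts a raw waveform):
#
#   import torch
#   from transformers import AutoModel
#
#   model = AutoModel.from_pretrained("./wavtokenizer_hf_converted",
#                                     trust_remote_code=True)
#   wav = torch.randn(1, 24000)  # one second of audio at 24 kHz
#   with torch.no_grad():
#       out = model(wav)  # exact call signature depends on the modeling file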