#!/usr/bin/env python
"""
Convert original WavTokenizer checkpoint to HuggingFace format.

Usage:
    python convert_wavtokenizer.py \
        --config_path configs/wavtokenizer_smalldata_frame75_3s_nq1_code4096_dim512_kmeans200_attn.yaml \
        --checkpoint_path checkpoints/wavtokenizer_small_320_24k_4096.ckpt \
        --output_dir ./wavtokenizer_hf_converted

This will create a HuggingFace-compatible model directory that can be loaded with:
    model = AutoModel.from_pretrained("./wavtokenizer_hf_converted", trust_remote_code=True)
"""

import argparse
import json
import os
import shutil
from pathlib import Path

import torch
import yaml


def convert_wavtokenizer(config_path: str, checkpoint_path: str, output_dir: str):
    """Convert WavTokenizer checkpoint to HuggingFace format."""
    print(f"Loading config from: {config_path}")
    print(f"Loading checkpoint from: {checkpoint_path}")

    # Load YAML config
    with open(config_path, 'r') as f:
        yaml_cfg = yaml.safe_load(f)

    # Extract model parameters
    model_args = yaml_cfg.get('model', {}).get('init_args', {})

    # Get specific component configs
    head_args = model_args.get('head', {}).get('init_args', {})
    backbone_args = model_args.get('backbone', {}).get('init_args', {})
    quantizer_args = model_args.get('quantizer', {}).get('init_args', {})
    feature_extractor_args = model_args.get('feature_extractor', {}).get('init_args', {})

    # Create HuggingFace config
    hf_config = {
        "_name_or_path": "WavTokenizerSmall",
        "architectures": ["WavTokenizer"],
        "auto_map": {
            "AutoConfig": "configuration_wavtokenizer.WavTokenizerConfig",
            "AutoModel": "modeling_wavtokenizer.WavTokenizer"
        },
        "model_type": "wavtokenizer",
        # Audio parameters
        "sample_rate": feature_extractor_args.get('sample_rate', 24000),
        "n_fft": head_args.get('n_fft', 1280),
        "hop_length": head_args.get('hop_length', 320),
        "n_mels": feature_extractor_args.get('n_mels', 128),
        "padding": head_args.get('padding', 'center'),
        # Feature dimensions
        "feature_dim": backbone_args.get('dim', 512),
        "encoder_dim": 64,  # Default DAC encoder
        "encoder_rates": [8, 5, 4, 2],  # Default DAC encoder rates
        "latent_dim": backbone_args.get('input_channels', 512),
        # Quantizer parameters
        "codebook_size": quantizer_args.get('codebook_size', 4096),
        "codebook_dim": quantizer_args.get('codebook_dim', 8),
        "num_quantizers": quantizer_args.get('num_quantizers', 1),
        # Backbone parameters
        "backbone_type": "vocos",
        "backbone_dim": backbone_args.get('dim', 512),
        "backbone_num_blocks": backbone_args.get('num_layers', 8),
        "backbone_intermediate_dim": backbone_args.get('intermediate_dim', 1536),
        "backbone_kernel_size": 7,
        "backbone_layer_scale_init_value": 1e-6,
        # Head parameters
        "head_type": "istft",
        "head_dim": head_args.get('n_fft', 1280) // 2 + 1,
        # Attention parameters
        "use_attention": True,
        "attention_dim": backbone_args.get('dim', 512),
        "attention_heads": 8,
        "attention_layers": 1,
        "torch_dtype": "float32",
        "transformers_version": "4.40.0"
    }

    # Create output directory
    os.makedirs(output_dir, exist_ok=True)

    # Save config.json
    config_out_path = os.path.join(output_dir, "config.json")
    with open(config_out_path, 'w') as f:
        json.dump(hf_config, f, indent=2)
    print(f"Saved config to: {config_out_path}")

    # Load checkpoint
    print("Loading checkpoint...")
    ckpt = torch.load(checkpoint_path, map_location='cpu')
    state_dict = ckpt.get('state_dict', ckpt)

    # Clean state dict keys
    new_state_dict = {}
    for k, v in state_dict.items():
        # Remove 'model.' prefix if present
        if k.startswith('model.'):
            k = k[6:]
        new_state_dict[k] = v

    # Save as pytorch_model.bin
    model_out_path = os.path.join(output_dir, "pytorch_model.bin")
    torch.save(new_state_dict, model_out_path)
    print(f"Saved model weights to: {model_out_path}")

    # Copy Python files
    script_dir = Path(__file__).parent

    # Copy configuration file
    config_py = script_dir / "configuration_wavtokenizer.py"
    if config_py.exists():
        shutil.copy(config_py, output_dir)
        print("Copied: configuration_wavtokenizer.py")

    # Copy modeling file
    modeling_py = script_dir / "modeling_wavtokenizer.py"
    if modeling_py.exists():
        shutil.copy(modeling_py, output_dir)
        print("Copied: modeling_wavtokenizer.py")

    # Copy README
    readme = script_dir / "README.md"
    if readme.exists():
        shutil.copy(readme, output_dir)
        print("Copied: README.md")

    print(f"\nConversion complete! Model saved to: {output_dir}")
    print("\nTo load the model:")
    print(f'    model = AutoModel.from_pretrained("{output_dir}", trust_remote_code=True)')


def main():
    parser = argparse.ArgumentParser(description="Convert WavTokenizer checkpoint to HuggingFace format")
    parser.add_argument(
        "--config_path",
        type=str,
        required=True,
        help="Path to WavTokenizer YAML config file"
    )
    parser.add_argument(
        "--checkpoint_path",
        type=str,
        required=True,
        help="Path to WavTokenizer .ckpt checkpoint file"
    )
    parser.add_argument(
        "--output_dir",
        type=str,
        default="./wavtokenizer_hf_converted",
        help="Output directory for HuggingFace model"
    )
    args = parser.parse_args()

    convert_wavtokenizer(args.config_path, args.checkpoint_path, args.output_dir)


if __name__ == "__main__":
    main()
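

# --- Optional sanity check: a minimal illustrative sketch, not part of the
# original conversion flow. It reloads the artifacts that convert_wavtokenizer()
# writes (config.json and pytorch_model.bin) to confirm they are readable and
# prints basic statistics. It relies only on the torch/json/os calls already
# imported above; the helper name `verify_conversion` is hypothetical.
def verify_conversion(output_dir: str) -> None:
    """Spot-check a converted directory by reloading its config and weights."""
    with open(os.path.join(output_dir, "config.json"), "r") as f:
        cfg = json.load(f)
    state_dict = torch.load(os.path.join(output_dir, "pytorch_model.bin"), map_location="cpu")
    n_params = sum(t.numel() for t in state_dict.values() if torch.is_tensor(t))
    print(f"config entries: {len(cfg)}, tensors: {len(state_dict)}, total parameters: {n_params:,}")
    # Example usage (hypothetical path, matching the default --output_dir):
    #     verify_conversion("./wavtokenizer_hf_converted")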