File size: 5,931 Bytes
85ac35b |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 |
#!/usr/bin/env python
"""
Convert original WavTokenizer checkpoint to HuggingFace format.
Usage:
python convert_wavtokenizer.py \
--config_path configs/wavtokenizer_smalldata_frame75_3s_nq1_code4096_dim512_kmeans200_attn.yaml \
--checkpoint_path checkpoints/wavtokenizer_small_320_24k_4096.ckpt \
--output_dir ./wavtokenizer_hf_converted
This will create a HuggingFace-compatible model directory that can be loaded with:
model = AutoModel.from_pretrained("./wavtokenizer_hf_converted", trust_remote_code=True)
"""
import argparse
import json
import os
import shutil
from pathlib import Path
import torch
import yaml
def _init_args(parent, key):
    """Return ``parent[key]['init_args']``, or ``{}`` if any level is missing or null."""
    section = parent.get(key) or {}
    return section.get('init_args') or {}


def _build_hf_config(yaml_cfg):
    """Build the ``config.json`` dictionary from the parsed WavTokenizer YAML config."""
    model_args = _init_args(yaml_cfg, 'model')
    head_args = _init_args(model_args, 'head')
    backbone_args = _init_args(model_args, 'backbone')
    quantizer_args = _init_args(model_args, 'quantizer')
    feature_extractor_args = _init_args(model_args, 'feature_extractor')

    # Looked up once because both "n_fft" and "head_dim" are derived from it.
    n_fft = head_args.get('n_fft', 1280)

    return {
        "_name_or_path": "WavTokenizerSmall",
        "architectures": ["WavTokenizer"],
        "auto_map": {
            "AutoConfig": "configuration_wavtokenizer.WavTokenizerConfig",
            "AutoModel": "modeling_wavtokenizer.WavTokenizer",
        },
        "model_type": "wavtokenizer",
        # Audio parameters
        "sample_rate": feature_extractor_args.get('sample_rate', 24000),
        "n_fft": n_fft,
        "hop_length": head_args.get('hop_length', 320),
        "n_mels": feature_extractor_args.get('n_mels', 128),
        "padding": head_args.get('padding', 'center'),
        # Feature dimensions
        "feature_dim": backbone_args.get('dim', 512),
        "encoder_dim": 64,  # Default DAC encoder
        "encoder_rates": [8, 5, 4, 2],  # Default DAC encoder rates
        "latent_dim": backbone_args.get('input_channels', 512),
        # Quantizer parameters
        "codebook_size": quantizer_args.get('codebook_size', 4096),
        "codebook_dim": quantizer_args.get('codebook_dim', 8),
        "num_quantizers": quantizer_args.get('num_quantizers', 1),
        # Backbone parameters
        "backbone_type": "vocos",
        "backbone_dim": backbone_args.get('dim', 512),
        "backbone_num_blocks": backbone_args.get('num_layers', 8),
        "backbone_intermediate_dim": backbone_args.get('intermediate_dim', 1536),
        "backbone_kernel_size": 7,
        "backbone_layer_scale_init_value": 1e-6,
        # Head parameters
        "head_type": "istft",
        "head_dim": n_fft // 2 + 1,
        # Attention parameters
        "use_attention": True,
        "attention_dim": backbone_args.get('dim', 512),
        "attention_heads": 8,
        "attention_layers": 1,
        "torch_dtype": "float32",
        "transformers_version": "4.40.0",
    }


def _strip_key_prefix(state_dict, prefix='model.'):
    """Return a new dict with a leading ``prefix`` removed from every key that has it."""
    return {
        (key[len(prefix):] if key.startswith(prefix) else key): value
        for key, value in state_dict.items()
    }


def convert_wavtokenizer(config_path: str, checkpoint_path: str, output_dir: str) -> None:
    """Convert WavTokenizer checkpoint to HuggingFace format.

    Writes ``config.json`` and ``pytorch_model.bin`` into ``output_dir``
    (created if needed) and copies ``configuration_wavtokenizer.py``,
    ``modeling_wavtokenizer.py`` and ``README.md`` from this script's
    directory when they exist.

    Args:
        config_path: Path to the WavTokenizer YAML config. Values are read from
            ``model.init_args.<component>.init_args``; anything missing or null
            falls back to the built-in defaults.
        checkpoint_path: Path to the checkpoint. If the loaded object has a
            ``state_dict`` entry, that entry is used; otherwise the loaded
            object itself is treated as the state dict.
        output_dir: Directory that receives the converted model files.
    """
    print(f"Loading config from: {config_path}")
    print(f"Loading checkpoint from: {checkpoint_path}")

    # safe_load returns None for an empty document; treat that as "no overrides"
    # so the defaults apply instead of crashing on None.get(...).
    with open(config_path, 'r', encoding='utf-8') as f:
        yaml_cfg = yaml.safe_load(f) or {}

    hf_config = _build_hf_config(yaml_cfg)

    os.makedirs(output_dir, exist_ok=True)

    config_out_path = os.path.join(output_dir, "config.json")
    with open(config_out_path, 'w', encoding='utf-8') as f:
        json.dump(hf_config, f, indent=2)
    print(f"Saved config to: {config_out_path}")

    print("Loading checkpoint...")
    # NOTE(security): torch.load unpickles the file, which can execute arbitrary
    # code. Only convert checkpoints obtained from a trusted source.
    # NOTE(review): the default of `weights_only` differs between torch
    # versions; checkpoints that carry non-tensor objects may need an explicit
    # setting -- confirm against the torch version in use.
    ckpt = torch.load(checkpoint_path, map_location='cpu')
    state_dict = ckpt.get('state_dict', ckpt)

    # Drop the 'model.' prefix from the checkpoint's parameter names, if present.
    new_state_dict = _strip_key_prefix(state_dict)

    model_out_path = os.path.join(output_dir, "pytorch_model.bin")
    torch.save(new_state_dict, model_out_path)
    print(f"Saved model weights to: {model_out_path}")

    # Companion files live next to this script. The two Python modules are the
    # ones named in the "auto_map" entry written to config.json, so a missing
    # one is reported instead of being skipped silently.
    script_dir = Path(__file__).parent
    companion_files = (
        ("configuration_wavtokenizer.py", True),
        ("modeling_wavtokenizer.py", True),
        ("README.md", False),
    )
    for filename, referenced_by_auto_map in companion_files:
        source = script_dir / filename
        if source.exists():
            shutil.copy(source, output_dir)
            print(f"Copied: {filename}")
        elif referenced_by_auto_map:
            print(
                f"Warning: {filename} not found in {script_dir}; "
                "config.json references it via auto_map."
            )

    print(f"\nConversion complete! Model saved to: {output_dir}")
    print("\nTo load the model:")
    print(f' model = AutoModel.from_pretrained("{output_dir}", trust_remote_code=True)')
def main():
    """Parse the command-line options and run the checkpoint conversion."""
    cli = argparse.ArgumentParser(
        description="Convert WavTokenizer checkpoint to HuggingFace format"
    )

    # The two input paths are mandatory and share the same option shape.
    mandatory_inputs = (
        ("--config_path", "Path to WavTokenizer YAML config file"),
        ("--checkpoint_path", "Path to WavTokenizer .ckpt checkpoint file"),
    )
    for flag, description in mandatory_inputs:
        cli.add_argument(flag, type=str, required=True, help=description)

    # The destination is optional and falls back to a local directory.
    cli.add_argument(
        "--output_dir",
        type=str,
        default="./wavtokenizer_hf_converted",
        help="Output directory for HuggingFace model",
    )

    options = cli.parse_args()
    convert_wavtokenizer(
        options.config_path,
        options.checkpoint_path,
        options.output_dir,
    )
# Run the converter only when executed as a script, not when imported.
if __name__ == "__main__":
    main()