"""
Convert an original WavTokenizer checkpoint to HuggingFace format.

Usage:
    python convert_wavtokenizer.py \
        --config_path configs/wavtokenizer_smalldata_frame75_3s_nq1_code4096_dim512_kmeans200_attn.yaml \
        --checkpoint_path checkpoints/wavtokenizer_small_320_24k_4096.ckpt \
        --output_dir ./wavtokenizer_hf_converted

This creates a HuggingFace-compatible model directory that can be loaded with:

    model = AutoModel.from_pretrained("./wavtokenizer_hf_converted", trust_remote_code=True)
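
The output directory will contain config.json and pytorch_model.bin, plus
configuration_wavtokenizer.py, modeling_wavtokenizer.py, and README.md when
those files sit next to this script.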
"""

import argparse
import json
import os
import shutil
from pathlib import Path

import torch
import yaml


def convert_wavtokenizer(config_path: str, checkpoint_path: str, output_dir: str):
    """Convert WavTokenizer checkpoint to HuggingFace format."""
    print(f"Loading config from: {config_path}")
    print(f"Loading checkpoint from: {checkpoint_path}")

    with open(config_path, 'r') as f:
        yaml_cfg = yaml.safe_load(f)
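
    # For reference, the YAML is assumed to look roughly like the sketch
    # below (inferred from the .get() chains that follow, not verified
    # against every upstream config):
    #
    # model:
    #   init_args:
    #     head: {init_args: {n_fft: 1280, hop_length: 320, padding: center}}
    #     backbone: {init_args: {dim: 512, num_layers: 8, intermediate_dim: 1536}}
    #     quantizer: {init_args: {codebook_size: 4096, codebook_dim: 8, num_quantizers: 1}}
    #     feature_extractor: {init_args: {sample_rate: 24000, n_mels: 128}}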

    # Each component's kwargs are nested under 'init_args' in the YAML.
    model_args = yaml_cfg.get('model', {}).get('init_args', {})

    head_args = model_args.get('head', {}).get('init_args', {})
    backbone_args = model_args.get('backbone', {}).get('init_args', {})
    quantizer_args = model_args.get('quantizer', {}).get('init_args', {})
    feature_extractor_args = model_args.get('feature_extractor', {}).get('init_args', {})

    # Assemble the HuggingFace config. Values are read from the YAML where
    # available; the remaining fields are fixed defaults for this architecture.
    hf_config = {
        "_name_or_path": "WavTokenizerSmall",
        "architectures": ["WavTokenizer"],
        "auto_map": {
            "AutoConfig": "configuration_wavtokenizer.WavTokenizerConfig",
            "AutoModel": "modeling_wavtokenizer.WavTokenizer"
        },
        "model_type": "wavtokenizer",

        # Audio / STFT parameters
        "sample_rate": feature_extractor_args.get('sample_rate', 24000),
        "n_fft": head_args.get('n_fft', 1280),
        "hop_length": head_args.get('hop_length', 320),
        "n_mels": feature_extractor_args.get('n_mels', 128),
        "padding": head_args.get('padding', 'center'),

        # Encoder (encoder_dim and encoder_rates are fixed values, not read from the YAML)
        "feature_dim": backbone_args.get('dim', 512),
        "encoder_dim": 64,
        "encoder_rates": [8, 5, 4, 2],
        "latent_dim": backbone_args.get('input_channels', 512),

        # Quantizer
        "codebook_size": quantizer_args.get('codebook_size', 4096),
        "codebook_dim": quantizer_args.get('codebook_dim', 8),
        "num_quantizers": quantizer_args.get('num_quantizers', 1),

        # Vocos-style backbone
        "backbone_type": "vocos",
        "backbone_dim": backbone_args.get('dim', 512),
        "backbone_num_blocks": backbone_args.get('num_layers', 8),
        "backbone_intermediate_dim": backbone_args.get('intermediate_dim', 1536),
        "backbone_kernel_size": 7,
        "backbone_layer_scale_init_value": 1e-6,

        # ISTFT head (one bin per rFFT frequency)
        "head_type": "istft",
        "head_dim": head_args.get('n_fft', 1280) // 2 + 1,

        # Attention block
        "use_attention": True,
        "attention_dim": backbone_args.get('dim', 512),
        "attention_heads": 8,
        "attention_layers": 1,

        "torch_dtype": "float32",
        "transformers_version": "4.40.0"
    }

    os.makedirs(output_dir, exist_ok=True)

    config_out_path = os.path.join(output_dir, "config.json")
    with open(config_out_path, 'w') as f:
        json.dump(hf_config, f, indent=2)
    print(f"Saved config to: {config_out_path}")

    print("Loading checkpoint...")
    # Lightning checkpoints pickle more than plain tensors, so weights_only
    # must stay False here (it became the default in newer PyTorch releases).
    ckpt = torch.load(checkpoint_path, map_location='cpu', weights_only=False)
    # Lightning wraps the weights under 'state_dict'; a raw state dict passes through.
    state_dict = ckpt.get('state_dict', ckpt)

    # Strip the leading 'model.' prefix that the training wrapper adds to
    # parameter names; all other keys pass through unchanged.
    new_state_dict = {}
    for k, v in state_dict.items():
        if k.startswith('model.'):
            k = k[len('model.'):]
        new_state_dict[k] = v
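
    # Purely informational: report what is about to be saved. This assumes
    # every value in the state dict is a tensor, which holds for model weights.
    n_params = sum(v.numel() for v in new_state_dict.values())
    print(f"Collected {len(new_state_dict)} tensors ({n_params:,} parameters)")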

    model_out_path = os.path.join(output_dir, "pytorch_model.bin")
    torch.save(new_state_dict, model_out_path)
    print(f"Saved model weights to: {model_out_path}")

    # Copy the custom model code next to the weights so that loading with
    # trust_remote_code=True can find the classes named in auto_map.
    script_dir = Path(__file__).parent

    config_py = script_dir / "configuration_wavtokenizer.py"
    if config_py.exists():
        shutil.copy(config_py, output_dir)
        print("Copied: configuration_wavtokenizer.py")

    modeling_py = script_dir / "modeling_wavtokenizer.py"
    if modeling_py.exists():
        shutil.copy(modeling_py, output_dir)
        print("Copied: modeling_wavtokenizer.py")

    readme = script_dir / "README.md"
    if readme.exists():
        shutil.copy(readme, output_dir)
        print("Copied: README.md")

    print(f"\nConversion complete! Model saved to: {output_dir}")
    print("\nTo load the model:")
    print(f'    model = AutoModel.from_pretrained("{output_dir}", trust_remote_code=True)')


def main():
    parser = argparse.ArgumentParser(description="Convert WavTokenizer checkpoint to HuggingFace format")
    parser.add_argument(
        "--config_path",
        type=str,
        required=True,
        help="Path to WavTokenizer YAML config file"
    )
    parser.add_argument(
        "--checkpoint_path",
        type=str,
        required=True,
        help="Path to WavTokenizer .ckpt checkpoint file"
    )
    parser.add_argument(
        "--output_dir",
        type=str,
        default="./wavtokenizer_hf_converted",
        help="Output directory for HuggingFace model"
    )

    args = parser.parse_args()
    convert_wavtokenizer(args.config_path, args.checkpoint_path, args.output_dir)


if __name__ == "__main__":
    main()