#!/usr/bin/env python3
r"""
LoRA inference test script.

Loads the base VoxCPM model (with the LoRA structure described in the training
YAML), applies a LoRA checkpoint, and synthesizes the same text under several
LoRA states: enabled, disabled, re-enabled, reset, and hot-reloaded. One wav is
written per state, named after ``--output`` with a state suffix appended to the
file stem.

Usage:
    python scripts/test_voxcpm_lora_infer.py \
        --config_path conf/voxcpm/voxcpm_finetune_test.yaml \
        --lora_ckpt checkpoints/step_0002000 \
        --text "你好,这是 LoRA 微调后的效果。" \
        --output lora_test.wav

Voice cloning with a reference audio:
    python scripts/test_voxcpm_lora_infer.py \
        --config_path conf/voxcpm/voxcpm_finetune_test.yaml \
        --lora_ckpt checkpoints/step_0002000 \
        --text "这是带参考音色的合成效果。" \
        --prompt_audio path/to/ref.wav \
        --prompt_text "参考音频对应的文本" \
        --output lora_clone.wav
"""
import argparse
from pathlib import Path

import soundfile as sf
import torch

from voxcpm.model import VoxCPMModel
from voxcpm.model.voxcpm import LoRAConfig
from voxcpm.training.config import load_yaml_config


def parse_args():
    """Parse the command-line arguments for the LoRA inference test."""
    # NOTE: the first positional argument of ArgumentParser is `prog`; the
    # title belongs in `description`, otherwise it replaces the program name
    # in the usage line.
    parser = argparse.ArgumentParser(description="VoxCPM LoRA inference test")
    parser.add_argument(
        "--config_path",
        type=str,
        required=True,
        help="训练时使用的 YAML 配置路径(包含 pretrained_path 和 lora 配置)",
    )
    parser.add_argument(
        "--lora_ckpt",
        type=str,
        required=True,
        help="LoRA checkpoint 目录(内含 generator.pth,仅包含 lora_A / lora_B)",
    )
    parser.add_argument(
        "--text",
        type=str,
        required=True,
        help="待合成的目标文本(target_text)",
    )
    parser.add_argument(
        "--prompt_audio",
        type=str,
        default="",
        help="可选:参考音频路径,用于 voice cloning(不填则为直接 TTS)",
    )
    parser.add_argument(
        "--prompt_text",
        type=str,
        default="",
        help="可选:参考音频对应的文本(与 prompt_audio 搭配使用)",
    )
    parser.add_argument(
        "--output",
        type=str,
        default="lora_test.wav",
        help="输出 wav 文件路径",
    )
    parser.add_argument(
        "--cfg_value",
        type=float,
        default=2.0,
        help="推理时的 CFG scale,与训练 / 官方示例保持一致(默认 2.0)",
    )
    parser.add_argument(
        "--inference_timesteps",
        type=int,
        default=10,
        help="扩散推理步数,默认为 10",
    )
    parser.add_argument(
        "--max_len",
        type=int,
        default=600,
        help="生成阶段的最大步数(对应 _generate 的 max_len)",
    )
    return parser.parse_args()


def _warn_skipped(skipped):
    """Print a warning listing (up to 5) checkpoint keys the loader skipped."""
    if skipped:
        print(f"[WARNING] 跳过 {len(skipped)} 个参数")
        print(f" 跳过的 key (前5个): {skipped[:5]}")


def _synthesize_and_save(model, args, prompt_text, prompt_wav_path, out_path, suffix):
    """Synthesize ``args.text`` with the model's current LoRA state and save it.

    Args:
        model: the loaded VoxCPMModel (whatever LoRA state it is currently in).
        args: parsed CLI namespace; supplies text and generation settings.
        prompt_text: reference transcript ("" for plain TTS).
        prompt_wav_path: reference audio path ("" for plain TTS).
        out_path: base output path from ``--output``.
        suffix: appended to ``out_path``'s stem to name this variant's file.

    Returns:
        The ``Path`` of the wav file that was written.
    """
    with torch.inference_mode():
        audio = model.generate(
            target_text=args.text,
            prompt_text=prompt_text,
            prompt_wav_path=prompt_wav_path,
            max_len=args.max_len,
            inference_timesteps=args.inference_timesteps,
            cfg_value=args.cfg_value,
        )
    # Drop the leading axis when the output is multi-dimensional, presumably a
    # batch dim of size 1 (TODO confirm generate()'s output shape).
    audio_np = audio.squeeze(0).cpu().numpy() if audio.dim() > 1 else audio.cpu().numpy()

    target = out_path.with_stem(out_path.stem + suffix)
    sf.write(str(target), audio_np, model.sample_rate)
    print(f" 已保存:{target},时长 {len(audio_np) / model.sample_rate:.2f}s")
    return target


def main():
    """Run the five LoRA-state synthesis tests and print where each wav went."""
    args = parse_args()

    # Validate the checkpoint path before the slow base-model load/compile so a
    # typo fails immediately instead of after minutes of setup.
    ckpt_dir = Path(args.lora_ckpt)
    if not ckpt_dir.exists():
        raise FileNotFoundError(f"找不到 LoRA checkpoint: {ckpt_dir}")

    # 1. Read the training YAML config.
    cfg = load_yaml_config(args.config_path)
    pretrained_path = cfg["pretrained_path"]
    lora_cfg_dict = cfg.get("lora", {}) or {}
    lora_cfg = LoRAConfig(**lora_cfg_dict) if lora_cfg_dict else None
    if lora_cfg is None:
        # from_local() will receive lora_config=None; the checkpoint keys are
        # then likely to be skipped (see the skipped-key warning below).
        print("[WARNING] 配置中未找到 lora 配置,lora_config=None")

    # 2. Load the base model (builds the LoRA structure and runs torch.compile).
    print(f"[1/3] 加载基础模型:{pretrained_path}")
    model = VoxCPMModel.from_local(
        pretrained_path,
        optimize=True,  # compile first; per the original note, load_lora_weights uses named_parameters so it still works after compile
        training=False,
        lora_config=lora_cfg,
    )

    # Debug: show where the DiT (feat_decoder) LoRA parameters live after compile.
    dit_params = [n for n, _ in model.named_parameters() if 'feat_decoder' in n and 'lora' in n]
    print(f"[DEBUG] compile 后 DiT LoRA 参数路径 (前3个): {dit_params[:3]}")

    # 3. Load the LoRA weights (works after compile as well).
    print(f"[2/3] 加载 LoRA 权重:{ckpt_dir}")
    loaded, skipped = model.load_lora_weights(str(ckpt_dir))
    print(f" 已加载 {len(loaded)} 个参数")
    _warn_skipped(skipped)

    # 4. Synthesize speech under each LoRA state.
    prompt_wav_path = args.prompt_audio or ""
    prompt_text = args.prompt_text or ""
    out_path = Path(args.output)
    out_path.parent.mkdir(parents=True, exist_ok=True)

    def synth(suffix):
        return _synthesize_and_save(
            model, args, prompt_text, prompt_wav_path, out_path, suffix
        )

    print("\n[3/3] 开始合成测试...")

    # Test 1: LoRA applied.
    print("\n [Test 1] 使用 LoRA 合成...")
    lora_output = synth("_with_lora")

    # Test 2: LoRA disabled via set_lora_enabled(False).
    print("\n [Test 2] 禁用 LoRA (set_lora_enabled=False)...")
    model.set_lora_enabled(False)
    disabled_output = synth("_lora_disabled")

    # Test 3: LoRA re-enabled.
    print("\n [Test 3] 重新启用 LoRA (set_lora_enabled=True)...")
    model.set_lora_enabled(True)
    reenabled_output = synth("_lora_reenabled")

    # Test 4: LoRA weights reset (unloaded) via reset_lora_weights().
    print("\n [Test 4] 卸载 LoRA (reset_lora_weights)...")
    model.reset_lora_weights()
    reset_output = synth("_lora_reset")

    # Test 5: hot-reload the same LoRA checkpoint.
    print("\n [Test 5] 热加载 LoRA (load_lora_weights)...")
    loaded, skipped = model.load_lora_weights(str(ckpt_dir))
    print(f" 重新加载了 {len(loaded)} 个参数")
    _warn_skipped(skipped)
    reload_output = synth("_lora_reloaded")

    print("\n[完成] 所有测试完成!")
    print(f" - with_lora: {lora_output}")
    print(f" - lora_disabled: {disabled_output}")
    print(f" - lora_reenabled: {reenabled_output}")
    print(f" - lora_reset: {reset_output}")
    print(f" - lora_reloaded: {reload_output}")


if __name__ == "__main__":
    main()