dpss-exp3-TTS / VoxCPM /scripts /test_voxcpm_lora_infer.py
lglg666's picture
Upload folder using huggingface_hub
6766eda verified
#!/usr/bin/env python3
"""
LoRA 推理测试脚本。
用法:
python scripts/test_voxcpm_lora_infer.py \
--config_path conf/voxcpm/voxcpm_finetune_test.yaml \
--lora_ckpt checkpoints/step_0002000 \
--text "你好,这是 LoRA 微调后的效果。" \
--output lora_test.wav
带参考音频的音色克隆:
python scripts/test_voxcpm_lora_infer.py \
--config_path conf/voxcpm/voxcpm_finetune_test.yaml \
--lora_ckpt checkpoints/step_0002000 \
--text "这是带参考音色的合成效果。" \
--prompt_audio path/to/ref.wav \
--prompt_text "参考音频对应的文本" \
--output lora_clone.wav
"""
import argparse
from pathlib import Path
import soundfile as sf
import torch
from voxcpm.model import VoxCPMModel
from voxcpm.model.voxcpm import LoRAConfig
from voxcpm.training.config import load_yaml_config
def parse_args():
    """Parse the command-line options of the LoRA inference test script.

    Three options are mandatory (``--config_path``, ``--lora_ckpt`` and
    ``--text``); the remaining ones fall back to the defaults listed in
    the table below.

    Returns:
        argparse.Namespace: the parsed options, one attribute per flag.
    """
    cli = argparse.ArgumentParser("VoxCPM LoRA inference test")

    # (flag, add_argument keyword options), registered in this exact order
    # so that the generated --help output keeps its layout.
    option_table = (
        (
            "--config_path",
            dict(
                type=str,
                required=True,
                help="训练时使用的 YAML 配置路径(包含 pretrained_path 和 lora 配置)",
            ),
        ),
        (
            "--lora_ckpt",
            dict(
                type=str,
                required=True,
                help="LoRA checkpoint 目录(内含 generator.pth,仅包含 lora_A / lora_B)",
            ),
        ),
        (
            "--text",
            dict(
                type=str,
                required=True,
                help="待合成的目标文本(target_text)",
            ),
        ),
        (
            "--prompt_audio",
            dict(
                type=str,
                default="",
                help="可选:参考音频路径,用于 voice cloning(不填则为直接 TTS)",
            ),
        ),
        (
            "--prompt_text",
            dict(
                type=str,
                default="",
                help="可选:参考音频对应的文本(与 prompt_audio 搭配使用)",
            ),
        ),
        (
            "--output",
            dict(
                type=str,
                default="lora_test.wav",
                help="输出 wav 文件路径",
            ),
        ),
        (
            "--cfg_value",
            dict(
                type=float,
                default=2.0,
                help="推理时的 CFG scale,与训练 / 官方示例保持一致(默认 2.0)",
            ),
        ),
        (
            "--inference_timesteps",
            dict(
                type=int,
                default=10,
                help="扩散推理步数,默认为 10",
            ),
        ),
        (
            "--max_len",
            dict(
                type=int,
                default=600,
                help="生成阶段的最大步数(对应 _generate 的 max_len)",
            ),
        ),
    )
    for flag, options in option_table:
        cli.add_argument(flag, **options)

    return cli.parse_args()
def _synthesize_and_save(model, args, prompt_text, prompt_wav_path, out_path, tag):
    """Run one synthesis pass and write the audio next to ``out_path``.

    Args:
        model: the loaded model; must expose ``generate`` and ``sample_rate``.
        args: parsed CLI options (``text``, ``max_len``,
            ``inference_timesteps`` and ``cfg_value`` are read).
        prompt_text: transcript of the reference audio, or ``""``.
        prompt_wav_path: path of the reference audio, or ``""``.
        out_path: base output path given on the command line.
        tag: text appended to the file stem to tell the test cases apart.

    Returns:
        pathlib.Path: the file that was written.
    """
    with torch.inference_mode():
        audio = model.generate(
            target_text=args.text,
            prompt_text=prompt_text,
            prompt_wav_path=prompt_wav_path,
            max_len=args.max_len,
            inference_timesteps=args.inference_timesteps,
            cfg_value=args.cfg_value,
        )
    # Drop the leading dimension when the result has more than one axis.
    if audio.dim() > 1:
        audio = audio.squeeze(0)
    audio_np = audio.cpu().numpy()

    # Same result as Path.with_stem(stem + tag), but also works before Python 3.9.
    target = out_path.with_name(f"{out_path.stem}{tag}{out_path.suffix}")
    sf.write(str(target), audio_np, model.sample_rate)
    print(f" 已保存:{target},时长 {len(audio_np) / model.sample_rate:.2f}s")
    return target


def main():
    """Exercise the LoRA lifecycle of a VoxCPM model and save one wav per step.

    The same text is synthesized five times: with LoRA loaded, with LoRA
    disabled, re-enabled, after resetting the LoRA weights, and after
    loading them again. Each result is written next to ``--output`` with a
    distinguishing suffix in the file name.

    Raises:
        FileNotFoundError: if the LoRA checkpoint path, or a non-empty
            reference audio path, does not exist.
    """
    args = parse_args()

    # Validate user-supplied paths first, so a typo fails immediately
    # instead of after the (slow) model load and compile step.
    ckpt_dir = Path(args.lora_ckpt)
    if not ckpt_dir.exists():
        raise FileNotFoundError(f"找不到 LoRA checkpoint: {ckpt_dir}")
    prompt_wav_path = args.prompt_audio or ""
    prompt_text = args.prompt_text or ""
    # NOTE(review): assumes prompt_wav_path is always a local file path — confirm.
    if prompt_wav_path and not Path(prompt_wav_path).exists():
        raise FileNotFoundError(f"找不到参考音频: {prompt_wav_path}")

    # 1. Read the YAML config used for training.
    cfg = load_yaml_config(args.config_path)
    pretrained_path = cfg["pretrained_path"]
    lora_cfg_dict = cfg.get("lora", {}) or {}
    lora_cfg = LoRAConfig(**lora_cfg_dict) if lora_cfg_dict else None

    # 2. Load the base model with the LoRA structure attached.
    print(f"[1/3] 加载基础模型:{pretrained_path}")
    model = VoxCPMModel.from_local(
        pretrained_path,
        optimize=True,  # compile first; LoRA weights are loaded afterwards
        training=False,
        lora_config=lora_cfg,
    )

    # Debug aid: show the names of a few LoRA parameters under feat_decoder.
    dit_params = [
        name
        for name, _ in model.named_parameters()
        if "feat_decoder" in name and "lora" in name
    ]
    print(f"[DEBUG] compile 后 DiT LoRA 参数路径 (前3个): {dit_params[:3]}")

    # 3. Load the LoRA weights into the already-built model.
    print(f"[2/3] 加载 LoRA 权重:{ckpt_dir}")
    loaded, skipped = model.load_lora_weights(str(ckpt_dir))
    print(f" 已加载 {len(loaded)} 个参数")
    if skipped:
        print(f"[WARNING] 跳过 {len(skipped)} 个参数")
        print(f" 跳过的 key (前5个): {skipped[:5]}")

    # 4. Synthesize speech for every LoRA state.
    out_path = Path(args.output)
    out_path.parent.mkdir(parents=True, exist_ok=True)
    print("\n[3/3] 开始合成测试...")

    def run_case(tag):
        return _synthesize_and_save(
            model, args, prompt_text, prompt_wav_path, out_path, tag
        )

    # Test 1: LoRA active.
    print("\n [Test 1] 使用 LoRA 合成...")
    lora_output = run_case("_with_lora")

    # Test 2: LoRA switched off through set_lora_enabled.
    print("\n [Test 2] 禁用 LoRA (set_lora_enabled=False)...")
    model.set_lora_enabled(False)
    disabled_output = run_case("_lora_disabled")

    # Test 3: LoRA switched back on.
    print("\n [Test 3] 重新启用 LoRA (set_lora_enabled=True)...")
    model.set_lora_enabled(True)
    reenabled_output = run_case("_lora_reenabled")

    # Test 4: LoRA weights reset.
    print("\n [Test 4] 卸载 LoRA (reset_lora_weights)...")
    model.reset_lora_weights()
    reset_output = run_case("_lora_reset")

    # Test 5: load the LoRA weights again.
    print("\n [Test 5] 热加载 LoRA (load_lora_weights)...")
    loaded, _ = model.load_lora_weights(str(ckpt_dir))
    print(f" 重新加载了 {len(loaded)} 个参数")
    reload_output = run_case("_lora_reloaded")

    print("\n[完成] 所有测试完成!")
    print(f" - with_lora: {lora_output}")
    print(f" - lora_disabled: {disabled_output}")
    print(f" - lora_reenabled: {reenabled_output}")
    print(f" - lora_reset: {reset_output}")
    print(f" - lora_reloaded: {reload_output}")
# Run the test sequence only when this file is executed as a script,
# not when it is imported as a module.
if __name__ == "__main__":
    main()