|
|
|
|
|
""" |
|
|
LoRA 推理测试脚本。 |
|
|
|
|
|
用法: |
|
|
|
|
|
python scripts/test_voxcpm_lora_infer.py \ |
|
|
--config_path conf/voxcpm/voxcpm_finetune_test.yaml \ |
|
|
--lora_ckpt checkpoints/step_0002000 \ |
|
|
--text "你好,这是 LoRA 微调后的效果。" \ |
|
|
--output lora_test.wav |
|
|
|
|
|
带参考音频的音色克隆: |
|
|
|
|
|
python scripts/test_voxcpm_lora_infer.py \ |
|
|
--config_path conf/voxcpm/voxcpm_finetune_test.yaml \ |
|
|
--lora_ckpt checkpoints/step_0002000 \ |
|
|
--text "这是带参考音色的合成效果。" \ |
|
|
--prompt_audio path/to/ref.wav \ |
|
|
--prompt_text "参考音频对应的文本" \ |
|
|
--output lora_clone.wav |
|
|
""" |
|
|
|
|
|
import argparse |
|
|
from pathlib import Path |
|
|
|
|
|
import soundfile as sf |
|
|
import torch |
|
|
|
|
|
from voxcpm.model import VoxCPMModel |
|
|
from voxcpm.model.voxcpm import LoRAConfig |
|
|
from voxcpm.training.config import load_yaml_config |
|
|
|
|
|
|
|
|
def parse_args(argv=None):
    """Parse command-line arguments for the LoRA inference test.

    Args:
        argv: Optional list of argument strings. Defaults to ``None``,
            in which case ``sys.argv[1:]`` is used — this keeps the
            original call sites working while making the parser
            testable and reusable.

    Returns:
        argparse.Namespace with config_path, lora_ckpt, text,
        prompt_audio, prompt_text, output, cfg_value,
        inference_timesteps and max_len.
    """
    parser = argparse.ArgumentParser("VoxCPM LoRA inference test")
    parser.add_argument(
        "--config_path",
        type=str,
        required=True,
        help="训练时使用的 YAML 配置路径(包含 pretrained_path 和 lora 配置)",
    )
    parser.add_argument(
        "--lora_ckpt",
        type=str,
        required=True,
        help="LoRA checkpoint 目录(内含 generator.pth,仅包含 lora_A / lora_B)",
    )
    parser.add_argument(
        "--text",
        type=str,
        required=True,
        help="待合成的目标文本(target_text)",
    )
    parser.add_argument(
        "--prompt_audio",
        type=str,
        default="",
        help="可选:参考音频路径,用于 voice cloning(不填则为直接 TTS)",
    )
    parser.add_argument(
        "--prompt_text",
        type=str,
        default="",
        help="可选:参考音频对应的文本(与 prompt_audio 搭配使用)",
    )
    parser.add_argument(
        "--output",
        type=str,
        default="lora_test.wav",
        help="输出 wav 文件路径",
    )
    parser.add_argument(
        "--cfg_value",
        type=float,
        default=2.0,
        help="推理时的 CFG scale,与训练 / 官方示例保持一致(默认 2.0)",
    )
    parser.add_argument(
        "--inference_timesteps",
        type=int,
        default=10,
        help="扩散推理步数,默认为 10",
    )
    parser.add_argument(
        "--max_len",
        type=int,
        default=600,
        help="生成阶段的最大步数(对应 _generate 的 max_len)",
    )
    return parser.parse_args(argv)
|
|
|
|
|
|
|
|
def _synthesize(model, args, prompt_text, prompt_wav_path, out_path, suffix):
    """Generate audio for ``args.text`` with the model's current LoRA state
    and write it next to ``out_path`` with ``suffix`` appended to the stem.

    Returns the Path of the written wav file. Factored out because the
    original script repeated this generate→numpy→write→report sequence
    five times verbatim.
    """
    with torch.inference_mode():
        audio = model.generate(
            target_text=args.text,
            prompt_text=prompt_text,
            prompt_wav_path=prompt_wav_path,
            max_len=args.max_len,
            inference_timesteps=args.inference_timesteps,
            cfg_value=args.cfg_value,
        )
    # model.generate may return (1, T) or (T,); squeeze the batch dim if present.
    audio_np = audio.squeeze(0).cpu().numpy() if audio.dim() > 1 else audio.cpu().numpy()
    target = out_path.with_stem(out_path.stem + suffix)
    sf.write(str(target), audio_np, model.sample_rate)
    print(f" 已保存:{target},时长 {len(audio_np) / model.sample_rate:.2f}s")
    return target


def main():
    """Run the LoRA inference test suite.

    Loads the base model plus LoRA config from the training YAML, loads the
    LoRA checkpoint, then synthesizes the same text five times to exercise
    the LoRA lifecycle: enabled, disabled, re-enabled, reset (unloaded),
    and hot-reloaded. One wav file is written per test.
    """
    args = parse_args()

    # Resolve base model path and optional LoRA config from the training YAML.
    cfg = load_yaml_config(args.config_path)
    pretrained_path = cfg["pretrained_path"]
    lora_cfg_dict = cfg.get("lora", {}) or {}
    lora_cfg = LoRAConfig(**lora_cfg_dict) if lora_cfg_dict else None

    print(f"[1/3] 加载基础模型:{pretrained_path}")
    model = VoxCPMModel.from_local(
        pretrained_path,
        optimize=True,
        training=False,
        lora_config=lora_cfg,
    )

    # Sanity check: confirm LoRA parameter paths survive model optimization/compile.
    dit_params = [n for n, _ in model.named_parameters() if 'feat_decoder' in n and 'lora' in n]
    print(f"[DEBUG] compile 后 DiT LoRA 参数路径 (前3个): {dit_params[:3]}")

    ckpt_dir = Path(args.lora_ckpt)
    if not ckpt_dir.exists():
        raise FileNotFoundError(f"找不到 LoRA checkpoint: {ckpt_dir}")

    print(f"[2/3] 加载 LoRA 权重:{ckpt_dir}")
    loaded, skipped = model.load_lora_weights(str(ckpt_dir))
    print(f" 已加载 {len(loaded)} 个参数")
    if skipped:
        print(f"[WARNING] 跳过 {len(skipped)} 个参数")
        print(f" 跳过的 key (前5个): {skipped[:5]}")

    # Optional voice-cloning inputs; empty strings mean plain TTS.
    prompt_wav_path = args.prompt_audio or ""
    prompt_text = args.prompt_text or ""
    out_path = Path(args.output)
    out_path.parent.mkdir(parents=True, exist_ok=True)

    print("\n[3/3] 开始合成测试...")

    print("\n [Test 1] 使用 LoRA 合成...")
    lora_output = _synthesize(model, args, prompt_text, prompt_wav_path, out_path, "_with_lora")

    print("\n [Test 2] 禁用 LoRA (set_lora_enabled=False)...")
    model.set_lora_enabled(False)
    disabled_output = _synthesize(model, args, prompt_text, prompt_wav_path, out_path, "_lora_disabled")

    print("\n [Test 3] 重新启用 LoRA (set_lora_enabled=True)...")
    model.set_lora_enabled(True)
    reenabled_output = _synthesize(model, args, prompt_text, prompt_wav_path, out_path, "_lora_reenabled")

    print("\n [Test 4] 卸载 LoRA (reset_lora_weights)...")
    model.reset_lora_weights()
    reset_output = _synthesize(model, args, prompt_text, prompt_wav_path, out_path, "_lora_reset")

    print("\n [Test 5] 热加载 LoRA (load_lora_weights)...")
    loaded, _ = model.load_lora_weights(str(ckpt_dir))
    print(f" 重新加载了 {len(loaded)} 个参数")
    reload_output = _synthesize(model, args, prompt_text, prompt_wav_path, out_path, "_lora_reloaded")

    print("\n[完成] 所有测试完成!")
    print(f" - with_lora: {lora_output}")
    print(f" - lora_disabled: {disabled_output}")
    print(f" - lora_reenabled: {reenabled_output}")
    print(f" - lora_reset: {reset_output}")
    print(f" - lora_reloaded: {reload_output}")
|
|
|
|
|
|
|
|
# Script entry point: run the full LoRA inference test suite when executed directly.
if __name__ == "__main__":


    main()
|
|
|
|
|
|
|
|
|