File size: 5,537 Bytes
6766eda |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 |
#!/usr/bin/env python3
"""
简单的全量微调权重推理测试脚本(无 LoRA):
1. 使用训练时相同的 YAML(拿到 pretrained_path)
2. 从 pretrained_path 加载基础 VoxCPM 模型(config.json / pytorch_model.bin / audiovae.pth)
3. 再从全量微调 checkpoint 目录加载微调后的 generator.pth(保存的是完整 state_dict)
4. 调用模型的 generate 接口合成语音并保存为 wav
用法示例(与你的 finetune 配置保持一致):
python scripts/test_voxcpm_ft_infer.py \
--config_path conf/voxcpm/voxcpm_finetune_test.yaml \
--ckpt_dir /user/liuxin/checkpoints/voxcpm_finetune/step_0001000 \
--text "你好,我是全量微调后的 VoxCPM。" \
--output ft_test.wav
如果需要参考音色克隆:
python scripts/test_voxcpm_ft_infer.py \
--config_path conf/voxcpm/voxcpm_finetune_test.yaml \
--ckpt_dir /user/liuxin/checkpoints/voxcpm_finetune/step_0001000 \
--text "你好,这是带参考音色的合成效果。" \
--prompt_audio path/to/ref.wav \
--prompt_text "参考音频对应的文本" \
--output ft_clone.wav
"""
import argparse
from pathlib import Path
import soundfile as sf
import torch
from voxcpm.model import VoxCPMModel
from voxcpm.training.config import load_yaml_config
def parse_args():
parser = argparse.ArgumentParser("VoxCPM full-finetune inference test (no LoRA)")
parser.add_argument(
"--config_path",
type=str,
required=True,
help="训练时使用的 YAML 配置路径(至少包含 pretrained_path)",
)
parser.add_argument(
"--ckpt_dir",
type=str,
required=True,
help="全量微调 checkpoint 目录(内含 generator.pth,保存完整 state_dict)",
)
parser.add_argument(
"--text",
type=str,
required=True,
help="待合成的目标文本(target_text)",
)
parser.add_argument(
"--prompt_audio",
type=str,
default="",
help="可选:参考音频路径,用于 voice cloning(不填则为直接 TTS)",
)
parser.add_argument(
"--prompt_text",
type=str,
default="",
help="可选:参考音频对应的文本(与 prompt_audio 搭配使用)",
)
parser.add_argument(
"--output",
type=str,
default="ft_test.wav",
help="输出 wav 文件路径",
)
parser.add_argument(
"--cfg_value",
type=float,
default=2.0,
help="推理时的 CFG scale,与训练 / 官方示例保持一致(默认 2.0)",
)
parser.add_argument(
"--inference_timesteps",
type=int,
default=10,
help="扩散推理步数,默认为 10",
)
parser.add_argument(
"--max_len",
type=int,
default=600,
help="生成阶段的最大步数(对应 _generate 的 max_len)",
)
return parser.parse_args()
def main():
args = parse_args()
# 1. 读取 YAML 配置,拿到 pretrained_path
cfg = load_yaml_config(args.config_path)
pretrained_path = cfg["pretrained_path"]
# 2. 加载基础模型(无 LoRA),推理模式:training=False,使用 config.dtype,并执行 optimize()
print(f"[FT Inference] 加载基础模型:{pretrained_path}")
model = VoxCPMModel.from_local(
pretrained_path,
optimize=True,
training=False,
lora_config=None,
)
# 3. 从全量微调 ckpt 目录加载 generator.pth(完整 state_dict)
ckpt_dir = Path(args.ckpt_dir)
ckpt_path = ckpt_dir / "generator.pth"
if not ckpt_path.exists():
raise FileNotFoundError(f"找不到全量微调 checkpoint: {ckpt_path}")
print(f"[FT Inference] 从 {ckpt_path} 加载全量微调权重")
ckpt = torch.load(ckpt_path, map_location="cpu")
state_dict = ckpt.get("state_dict", ckpt)
# 为兼容可能存在的额外键,使用 strict=False,并打印 missing / unexpected 数量
missing, unexpected = model.load_state_dict(state_dict, strict=False)
print(f"[FT Inference] 加载完成:missing={len(missing)}, unexpected={len(unexpected)}")
# 4. 调用 generate 进行推理
prompt_wav_path = args.prompt_audio or ""
prompt_text = args.prompt_text or ""
print(f"[FT Inference] 开始合成:text='{args.text}'")
if prompt_wav_path:
print(f"[FT Inference] 使用参考音频:{prompt_wav_path}")
print(f"[FT Inference] 参考文本:{prompt_text}")
with torch.inference_mode():
audio = model.generate(
target_text=args.text,
prompt_text=prompt_text,
prompt_wav_path=prompt_wav_path,
max_len=args.max_len,
inference_timesteps=args.inference_timesteps,
cfg_value=args.cfg_value,
)
# generate 返回的是一维 / 二维 Tensor,这里做一次收缩并落盘
if isinstance(audio, torch.Tensor):
if audio.dim() > 1:
audio = audio.squeeze(0)
audio_np = audio.cpu().numpy()
else:
raise TypeError(f"model.generate 返回类型异常:{type(audio)}")
out_path = Path(args.output)
out_path.parent.mkdir(parents=True, exist_ok=True)
sf.write(str(out_path), audio_np, model.sample_rate)
print(f"[FT Inference] 已保存到:{out_path},时长约 {len(audio_np) / model.sample_rate:.2f}s")
if __name__ == "__main__":
main()
|