dpss-exp3-TTS / VoxCPM /scripts /test_voxcpm_ft_infer.py
lglg666's picture
Upload folder using huggingface_hub
6766eda verified
#!/usr/bin/env python3
"""
简单的全量微调权重推理测试脚本(无 LoRA):
1. 使用训练时相同的 YAML(拿到 pretrained_path)
2. 从 pretrained_path 加载基础 VoxCPM 模型(config.json / pytorch_model.bin / audiovae.pth)
3. 再从全量微调 checkpoint 目录加载微调后的 generator.pth(保存的是完整 state_dict)
4. 调用模型的 generate 接口合成语音并保存为 wav
用法示例(与你的 finetune 配置保持一致):
python scripts/test_voxcpm_ft_infer.py \
--config_path conf/voxcpm/voxcpm_finetune_test.yaml \
--ckpt_dir /user/liuxin/checkpoints/voxcpm_finetune/step_0001000 \
--text "你好,我是全量微调后的 VoxCPM。" \
--output ft_test.wav
如果需要参考音色克隆:
python scripts/test_voxcpm_ft_infer.py \
--config_path conf/voxcpm/voxcpm_finetune_test.yaml \
--ckpt_dir /user/liuxin/checkpoints/voxcpm_finetune/step_0001000 \
--text "你好,这是带参考音色的合成效果。" \
--prompt_audio path/to/ref.wav \
--prompt_text "参考音频对应的文本" \
--output ft_clone.wav
"""
import argparse
from pathlib import Path
import soundfile as sf
import torch
from voxcpm.model import VoxCPMModel
from voxcpm.training.config import load_yaml_config
def parse_args():
parser = argparse.ArgumentParser("VoxCPM full-finetune inference test (no LoRA)")
parser.add_argument(
"--config_path",
type=str,
required=True,
help="训练时使用的 YAML 配置路径(至少包含 pretrained_path)",
)
parser.add_argument(
"--ckpt_dir",
type=str,
required=True,
help="全量微调 checkpoint 目录(内含 generator.pth,保存完整 state_dict)",
)
parser.add_argument(
"--text",
type=str,
required=True,
help="待合成的目标文本(target_text)",
)
parser.add_argument(
"--prompt_audio",
type=str,
default="",
help="可选:参考音频路径,用于 voice cloning(不填则为直接 TTS)",
)
parser.add_argument(
"--prompt_text",
type=str,
default="",
help="可选:参考音频对应的文本(与 prompt_audio 搭配使用)",
)
parser.add_argument(
"--output",
type=str,
default="ft_test.wav",
help="输出 wav 文件路径",
)
parser.add_argument(
"--cfg_value",
type=float,
default=2.0,
help="推理时的 CFG scale,与训练 / 官方示例保持一致(默认 2.0)",
)
parser.add_argument(
"--inference_timesteps",
type=int,
default=10,
help="扩散推理步数,默认为 10",
)
parser.add_argument(
"--max_len",
type=int,
default=600,
help="生成阶段的最大步数(对应 _generate 的 max_len)",
)
return parser.parse_args()
def main():
args = parse_args()
# 1. 读取 YAML 配置,拿到 pretrained_path
cfg = load_yaml_config(args.config_path)
pretrained_path = cfg["pretrained_path"]
# 2. 加载基础模型(无 LoRA),推理模式:training=False,使用 config.dtype,并执行 optimize()
print(f"[FT Inference] 加载基础模型:{pretrained_path}")
model = VoxCPMModel.from_local(
pretrained_path,
optimize=True,
training=False,
lora_config=None,
)
# 3. 从全量微调 ckpt 目录加载 generator.pth(完整 state_dict)
ckpt_dir = Path(args.ckpt_dir)
ckpt_path = ckpt_dir / "generator.pth"
if not ckpt_path.exists():
raise FileNotFoundError(f"找不到全量微调 checkpoint: {ckpt_path}")
print(f"[FT Inference] 从 {ckpt_path} 加载全量微调权重")
ckpt = torch.load(ckpt_path, map_location="cpu")
state_dict = ckpt.get("state_dict", ckpt)
# 为兼容可能存在的额外键,使用 strict=False,并打印 missing / unexpected 数量
missing, unexpected = model.load_state_dict(state_dict, strict=False)
print(f"[FT Inference] 加载完成:missing={len(missing)}, unexpected={len(unexpected)}")
# 4. 调用 generate 进行推理
prompt_wav_path = args.prompt_audio or ""
prompt_text = args.prompt_text or ""
print(f"[FT Inference] 开始合成:text='{args.text}'")
if prompt_wav_path:
print(f"[FT Inference] 使用参考音频:{prompt_wav_path}")
print(f"[FT Inference] 参考文本:{prompt_text}")
with torch.inference_mode():
audio = model.generate(
target_text=args.text,
prompt_text=prompt_text,
prompt_wav_path=prompt_wav_path,
max_len=args.max_len,
inference_timesteps=args.inference_timesteps,
cfg_value=args.cfg_value,
)
# generate 返回的是一维 / 二维 Tensor,这里做一次收缩并落盘
if isinstance(audio, torch.Tensor):
if audio.dim() > 1:
audio = audio.squeeze(0)
audio_np = audio.cpu().numpy()
else:
raise TypeError(f"model.generate 返回类型异常:{type(audio)}")
out_path = Path(args.output)
out_path.parent.mkdir(parents=True, exist_ok=True)
sf.write(str(out_path), audio_np, model.sample_rate)
print(f"[FT Inference] 已保存到:{out_path},时长约 {len(audio_np) / model.sample_rate:.2f}s")
if __name__ == "__main__":
main()