""" |
|
|
简单的全量微调权重推理测试脚本(无 LoRA): |
|
|
|
|
|
1. 使用训练时相同的 YAML(拿到 pretrained_path) |
|
|
2. 从 pretrained_path 加载基础 VoxCPM 模型(config.json / pytorch_model.bin / audiovae.pth) |
|
|
3. 再从全量微调 checkpoint 目录加载微调后的 generator.pth(保存的是完整 state_dict) |
|
|
4. 调用模型的 generate 接口合成语音并保存为 wav |
|
|
|
|
|
用法示例(与你的 finetune 配置保持一致): |
|
|
|
|
|
python scripts/test_voxcpm_ft_infer.py \ |
|
|
--config_path conf/voxcpm/voxcpm_finetune_test.yaml \ |
|
|
--ckpt_dir /user/liuxin/checkpoints/voxcpm_finetune/step_0001000 \ |
|
|
--text "你好,我是全量微调后的 VoxCPM。" \ |
|
|
--output ft_test.wav |
|
|
|
|
|
如果需要参考音色克隆: |
|
|
|
|
|
python scripts/test_voxcpm_ft_infer.py \ |
|
|
--config_path conf/voxcpm/voxcpm_finetune_test.yaml \ |
|
|
--ckpt_dir /user/liuxin/checkpoints/voxcpm_finetune/step_0001000 \ |
|
|
--text "你好,这是带参考音色的合成效果。" \ |
|
|
--prompt_audio path/to/ref.wav \ |
|
|
--prompt_text "参考音频对应的文本" \ |
|
|
--output ft_clone.wav |
|
|
""" |
import argparse
from pathlib import Path

import soundfile as sf
import torch

from voxcpm.model import VoxCPMModel
from voxcpm.training.config import load_yaml_config


def parse_args():
    parser = argparse.ArgumentParser("VoxCPM full-finetune inference test (no LoRA)")
    parser.add_argument(
        "--config_path",
        type=str,
        required=True,
        help="Path to the YAML config used for training (must at least contain pretrained_path)",
    )
    parser.add_argument(
        "--ckpt_dir",
        type=str,
        required=True,
        help="Full-finetune checkpoint directory (contains generator.pth with a complete state_dict)",
    )
    parser.add_argument(
        "--text",
        type=str,
        required=True,
        help="Target text to synthesize (target_text)",
    )
    parser.add_argument(
        "--prompt_audio",
        type=str,
        default="",
        help="Optional: reference audio path for voice cloning (omit for plain TTS)",
    )
    parser.add_argument(
        "--prompt_text",
        type=str,
        default="",
        help="Optional: transcript of the reference audio (used together with prompt_audio)",
    )
    parser.add_argument(
        "--output",
        type=str,
        default="ft_test.wav",
        help="Output wav file path",
    )
    parser.add_argument(
        "--cfg_value",
        type=float,
        default=2.0,
        help="CFG scale at inference time, consistent with training / the official examples (default 2.0)",
    )
    parser.add_argument(
        "--inference_timesteps",
        type=int,
        default=10,
        help="Number of diffusion inference steps (default 10)",
    )
    parser.add_argument(
        "--max_len",
        type=int,
        default=600,
        help="Maximum number of generation steps (the max_len passed to _generate)",
    )
    return parser.parse_args()


def main():
    args = parse_args()

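    # Load the training YAML; only pretrained_path is needed for inference.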
    cfg = load_yaml_config(args.config_path)
    pretrained_path = cfg["pretrained_path"]

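    # Build the base model for inference (no LoRA adapters); the fine-tuned weights are loaded on top below.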
print(f"[FT Inference] 加载基础模型:{pretrained_path}") |
|
|
model = VoxCPMModel.from_local( |
|
|
pretrained_path, |
|
|
optimize=True, |
|
|
training=False, |
|
|
lora_config=None, |
|
|
) |
|
|
|
|
|
|
|
|
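    # The full-finetune checkpoint directory is expected to contain generator.pth with the generator weights.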
    ckpt_dir = Path(args.ckpt_dir)
    ckpt_path = ckpt_dir / "generator.pth"
    if not ckpt_path.exists():
        raise FileNotFoundError(f"Full-finetune checkpoint not found: {ckpt_path}")

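    # generator.pth may either be a raw state_dict or wrap it under a "state_dict" key.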
print(f"[FT Inference] 从 {ckpt_path} 加载全量微调权重") |
|
|
ckpt = torch.load(ckpt_path, map_location="cpu") |
|
|
state_dict = ckpt.get("state_dict", ckpt) |
|
|
|
|
|
|
|
|
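    # strict=False tolerates key mismatches between the checkpoint and the model;
    # the missing/unexpected counts below act as a quick sanity check.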
    missing, unexpected = model.load_state_dict(state_dict, strict=False)
    print(f"[FT Inference] Load complete: missing={len(missing)}, unexpected={len(unexpected)}")

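    # Empty prompt audio/text means plain TTS; providing both enables voice cloning.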
    prompt_wav_path = args.prompt_audio or ""
    prompt_text = args.prompt_text or ""

    print(f"[FT Inference] Starting synthesis: text='{args.text}'")
    if prompt_wav_path:
        print(f"[FT Inference] Using reference audio: {prompt_wav_path}")
        print(f"[FT Inference] Reference text: {prompt_text}")

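    # Run generation without autograd bookkeeping; cfg_value, inference_timesteps and
    # max_len come straight from the CLI arguments.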
    with torch.inference_mode():
        audio = model.generate(
            target_text=args.text,
            prompt_text=prompt_text,
            prompt_wav_path=prompt_wav_path,
            max_len=args.max_len,
            inference_timesteps=args.inference_timesteps,
            cfg_value=args.cfg_value,
        )

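    # generate is expected to return a waveform tensor; drop a leading batch dim before writing.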
    if isinstance(audio, torch.Tensor):
        if audio.dim() > 1:
            audio = audio.squeeze(0)
        audio_np = audio.cpu().numpy()
    else:
        raise TypeError(f"Unexpected return type from model.generate: {type(audio)}")

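    # Write the waveform at the model's native sample rate, creating the output directory if needed.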
    out_path = Path(args.output)
    out_path.parent.mkdir(parents=True, exist_ok=True)
    sf.write(str(out_path), audio_np, model.sample_rate)

    print(f"[FT Inference] Saved to {out_path}, duration ~{len(audio_np) / model.sample_rate:.2f}s")


if __name__ == "__main__":
    main()