File size: 5,537 Bytes
6766eda
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
#!/usr/bin/env python3
"""
简单的全量微调权重推理测试脚本(无 LoRA):

1. 使用训练时相同的 YAML(拿到 pretrained_path)
2. 从 pretrained_path 加载基础 VoxCPM 模型(config.json / pytorch_model.bin / audiovae.pth)
3. 再从全量微调 checkpoint 目录加载微调后的 generator.pth(保存的是完整 state_dict)
4. 调用模型的 generate 接口合成语音并保存为 wav

用法示例(与你的 finetune 配置保持一致):

    python scripts/test_voxcpm_ft_infer.py \
        --config_path conf/voxcpm/voxcpm_finetune_test.yaml \
        --ckpt_dir /user/liuxin/checkpoints/voxcpm_finetune/step_0001000 \
        --text "你好,我是全量微调后的 VoxCPM。" \
        --output ft_test.wav

如果需要参考音色克隆:

    python scripts/test_voxcpm_ft_infer.py \
        --config_path conf/voxcpm/voxcpm_finetune_test.yaml \
        --ckpt_dir /user/liuxin/checkpoints/voxcpm_finetune/step_0001000 \
        --text "你好,这是带参考音色的合成效果。" \
        --prompt_audio path/to/ref.wav \
        --prompt_text "参考音频对应的文本" \
        --output ft_clone.wav
"""

import argparse
from pathlib import Path

import soundfile as sf
import torch

from voxcpm.model import VoxCPMModel
from voxcpm.training.config import load_yaml_config


def parse_args():
    parser = argparse.ArgumentParser("VoxCPM full-finetune inference test (no LoRA)")
    parser.add_argument(
        "--config_path",
        type=str,
        required=True,
        help="训练时使用的 YAML 配置路径(至少包含 pretrained_path)",
    )
    parser.add_argument(
        "--ckpt_dir",
        type=str,
        required=True,
        help="全量微调 checkpoint 目录(内含 generator.pth,保存完整 state_dict)",
    )
    parser.add_argument(
        "--text",
        type=str,
        required=True,
        help="待合成的目标文本(target_text)",
    )
    parser.add_argument(
        "--prompt_audio",
        type=str,
        default="",
        help="可选:参考音频路径,用于 voice cloning(不填则为直接 TTS)",
    )
    parser.add_argument(
        "--prompt_text",
        type=str,
        default="",
        help="可选:参考音频对应的文本(与 prompt_audio 搭配使用)",
    )
    parser.add_argument(
        "--output",
        type=str,
        default="ft_test.wav",
        help="输出 wav 文件路径",
    )
    parser.add_argument(
        "--cfg_value",
        type=float,
        default=2.0,
        help="推理时的 CFG scale,与训练 / 官方示例保持一致(默认 2.0)",
    )
    parser.add_argument(
        "--inference_timesteps",
        type=int,
        default=10,
        help="扩散推理步数,默认为 10",
    )
    parser.add_argument(
        "--max_len",
        type=int,
        default=600,
        help="生成阶段的最大步数(对应 _generate 的 max_len)",
    )
    return parser.parse_args()


def main():
    args = parse_args()

    # 1. 读取 YAML 配置,拿到 pretrained_path
    cfg = load_yaml_config(args.config_path)
    pretrained_path = cfg["pretrained_path"]

    # 2. 加载基础模型(无 LoRA),推理模式:training=False,使用 config.dtype,并执行 optimize()
    print(f"[FT Inference] 加载基础模型:{pretrained_path}")
    model = VoxCPMModel.from_local(
        pretrained_path,
        optimize=True,
        training=False,
        lora_config=None,
    )

    # 3. 从全量微调 ckpt 目录加载 generator.pth(完整 state_dict)
    ckpt_dir = Path(args.ckpt_dir)
    ckpt_path = ckpt_dir / "generator.pth"
    if not ckpt_path.exists():
        raise FileNotFoundError(f"找不到全量微调 checkpoint: {ckpt_path}")

    print(f"[FT Inference] 从 {ckpt_path} 加载全量微调权重")
    ckpt = torch.load(ckpt_path, map_location="cpu")
    state_dict = ckpt.get("state_dict", ckpt)

    # 为兼容可能存在的额外键,使用 strict=False,并打印 missing / unexpected 数量
    missing, unexpected = model.load_state_dict(state_dict, strict=False)
    print(f"[FT Inference] 加载完成:missing={len(missing)}, unexpected={len(unexpected)}")

    # 4. 调用 generate 进行推理
    prompt_wav_path = args.prompt_audio or ""
    prompt_text = args.prompt_text or ""

    print(f"[FT Inference] 开始合成:text='{args.text}'")
    if prompt_wav_path:
        print(f"[FT Inference] 使用参考音频:{prompt_wav_path}")
        print(f"[FT Inference] 参考文本:{prompt_text}")

    with torch.inference_mode():
        audio = model.generate(
            target_text=args.text,
            prompt_text=prompt_text,
            prompt_wav_path=prompt_wav_path,
            max_len=args.max_len,
            inference_timesteps=args.inference_timesteps,
            cfg_value=args.cfg_value,
        )

    # generate 返回的是一维 / 二维 Tensor,这里做一次收缩并落盘
    if isinstance(audio, torch.Tensor):
        if audio.dim() > 1:
            audio = audio.squeeze(0)
        audio_np = audio.cpu().numpy()
    else:
        raise TypeError(f"model.generate 返回类型异常:{type(audio)}")

    out_path = Path(args.output)
    out_path.parent.mkdir(parents=True, exist_ok=True)
    sf.write(str(out_path), audio_np, model.sample_rate)

    print(f"[FT Inference] 已保存到:{out_path},时长约 {len(audio_np) / model.sample_rate:.2f}s")


if __name__ == "__main__":
    main()