#!/usr/bin/env python3
"""
LoRA 推理测试脚本。
用法:
python scripts/test_voxcpm_lora_infer.py \
--config_path conf/voxcpm/voxcpm_finetune_test.yaml \
--lora_ckpt checkpoints/step_0002000 \
--text "你好,这是 LoRA 微调后的效果。" \
--output lora_test.wav
带参考音频的音色克隆:
python scripts/test_voxcpm_lora_infer.py \
--config_path conf/voxcpm/voxcpm_finetune_test.yaml \
--lora_ckpt checkpoints/step_0002000 \
--text "这是带参考音色的合成效果。" \
--prompt_audio path/to/ref.wav \
--prompt_text "参考音频对应的文本" \
--output lora_clone.wav
"""
import argparse
from pathlib import Path

import soundfile as sf
import torch

from voxcpm.model import VoxCPMModel
from voxcpm.model.voxcpm import LoRAConfig
from voxcpm.training.config import load_yaml_config


def parse_args():
    parser = argparse.ArgumentParser("VoxCPM LoRA inference test")
    parser.add_argument(
        "--config_path",
        type=str,
        required=True,
        help="Path to the YAML config used for training (contains pretrained_path and the lora section)",
    )
    parser.add_argument(
        "--lora_ckpt",
        type=str,
        required=True,
        help="LoRA checkpoint directory (contains generator.pth with only lora_A / lora_B tensors)",
    )
    parser.add_argument(
        "--text",
        type=str,
        required=True,
        help="Target text to synthesize (target_text)",
    )
    parser.add_argument(
        "--prompt_audio",
        type=str,
        default="",
        help="Optional: reference audio path for voice cloning (plain TTS when omitted)",
    )
    parser.add_argument(
        "--prompt_text",
        type=str,
        default="",
        help="Optional: transcript of the reference audio (used together with prompt_audio)",
    )
    parser.add_argument(
        "--output",
        type=str,
        default="lora_test.wav",
        help="Output wav file path",
    )
    parser.add_argument(
        "--cfg_value",
        type=float,
        default=2.0,
        help="CFG scale at inference time, consistent with training / the official examples (default 2.0)",
    )
    parser.add_argument(
        "--inference_timesteps",
        type=int,
        default=10,
        help="Number of diffusion inference steps (default 10)",
    )
    parser.add_argument(
        "--max_len",
        type=int,
        default=600,
        help="Maximum number of generation steps (maps to max_len of _generate)",
    )
    return parser.parse_args()
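

def synthesize_and_save(model, args, out_path: Path, suffix: str) -> Path:
    """Run one generate() call and write the result to <stem><suffix>.wav.

    Factored out of the five test blocks in main(), which all repeated the
    same generate / squeeze / write sequence; behavior is unchanged.
    """
    with torch.inference_mode():
        audio = model.generate(
            target_text=args.text,
            prompt_text=args.prompt_text or "",
            prompt_wav_path=args.prompt_audio or "",
            max_len=args.max_len,
            inference_timesteps=args.inference_timesteps,
            cfg_value=args.cfg_value,
        )
    # generate() may return a (1, T) batch or a flat (T,) tensor; normalize to 1-D numpy
    audio_np = audio.squeeze(0).cpu().numpy() if audio.dim() > 1 else audio.cpu().numpy()
    target = out_path.with_stem(out_path.stem + suffix)
    sf.write(str(target), audio_np, model.sample_rate)
    print(f"    Saved: {target}, duration {len(audio_np) / model.sample_rate:.2f}s")
    return target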


def main():
    args = parse_args()

    # 1. Read the YAML config
    cfg = load_yaml_config(args.config_path)
    pretrained_path = cfg["pretrained_path"]
    lora_cfg_dict = cfg.get("lora", {}) or {}
    lora_cfg = LoRAConfig(**lora_cfg_dict) if lora_cfg_dict else None

    # 2. Load the base model (LoRA modules attached, then torch.compile)
    print(f"[1/3] Loading base model: {pretrained_path}")
    model = VoxCPMModel.from_local(
        pretrained_path,
        optimize=True,  # compile first; load_lora_weights stays compatible via named_parameters
        training=False,
        lora_config=lora_cfg,
    )

    # Debug: check the DiT LoRA parameter paths after compilation
    dit_params = [n for n, _ in model.named_parameters() if "feat_decoder" in n and "lora" in n]
    print(f"[DEBUG] DiT LoRA parameter paths after compile (first 3): {dit_params[:3]}")

    # 3. Load the LoRA weights (also works after compilation)
    ckpt_dir = Path(args.lora_ckpt)
    if not ckpt_dir.exists():
        raise FileNotFoundError(f"LoRA checkpoint not found: {ckpt_dir}")
    print(f"[2/3] Loading LoRA weights: {ckpt_dir}")
    loaded, skipped = model.load_lora_weights(str(ckpt_dir))
    print(f"    Loaded {len(loaded)} parameters")
    if skipped:
        print(f"[WARNING] Skipped {len(skipped)} parameters")
        print(f"    Skipped keys (first 5): {skipped[:5]}")

    # 4. Synthesize speech
    out_path = Path(args.output)
    out_path.parent.mkdir(parents=True, exist_ok=True)
    print("\n[3/3] Running synthesis tests...")
    # === Test 1: synthesize with LoRA active ===
    print("\n  [Test 1] Synthesizing with LoRA...")
    lora_output = synthesize_and_save(model, args, out_path, "_with_lora")

    # === Test 2: disable LoRA (via set_lora_enabled) ===
    print("\n  [Test 2] Disabling LoRA (set_lora_enabled=False)...")
    model.set_lora_enabled(False)
    disabled_output = synthesize_and_save(model, args, out_path, "_lora_disabled")

    # === Test 3: re-enable LoRA ===
    print("\n  [Test 3] Re-enabling LoRA (set_lora_enabled=True)...")
    model.set_lora_enabled(True)
    reenabled_output = synthesize_and_save(model, args, out_path, "_lora_reenabled")

    # === Test 4: unload LoRA (reset_lora_weights) ===
    print("\n  [Test 4] Unloading LoRA (reset_lora_weights)...")
    model.reset_lora_weights()
    reset_output = synthesize_and_save(model, args, out_path, "_lora_reset")

    # === Test 5: hot-reload LoRA (load the weights again) ===
    print("\n  [Test 5] Hot-reloading LoRA (load_lora_weights)...")
    loaded, _ = model.load_lora_weights(str(ckpt_dir))
    print(f"    Reloaded {len(loaded)} parameters")
    reload_output = synthesize_and_save(model, args, out_path, "_lora_reloaded")
print(f"\n[完成] 所有测试完成!")
print(f" - with_lora: {lora_output}")
print(f" - lora_disabled: {disabled_output}")
print(f" - lora_reenabled: {reenabled_output}")
print(f" - lora_reset: {reset_output}")
print(f" - lora_reloaded: {reload_output}")
if __name__ == "__main__":
main()