Update VoxCPM/inference_lora.py
Browse files — VoxCPM/inference_lora.py (+13 −4)
VoxCPM/inference_lora.py
CHANGED
|
@@ -7,7 +7,7 @@ from voxcpm.training.config import load_yaml_config
|
|
| 7 |
import argparse
|
| 8 |
import torch
|
| 9 |
import os
|
| 10 |
-
|
| 11 |
def main():
|
| 12 |
parser = argparse.ArgumentParser()
|
| 13 |
parser.add_argument("--lora_ckpt", type=str, required=True)
|
|
@@ -34,7 +34,10 @@ def main():
|
|
| 34 |
training=False,
|
| 35 |
lora_config=lora_cfg,
|
| 36 |
)
|
| 37 |
-
|
|
|
|
|
|
|
|
|
|
| 38 |
# 3. 加载 LoRA 权重(在 compile 后也能正常工作)
|
| 39 |
ckpt_dir = Path(args.lora_ckpt)
|
| 40 |
if not ckpt_dir.exists():
|
|
@@ -49,8 +52,11 @@ def main():
|
|
| 49 |
print(f"\n[3/3] 开始推理...")
|
| 50 |
if args.text:
|
| 51 |
with torch.inference_mode():
|
|
|
|
|
|
|
|
|
|
| 52 |
wav = model.generate(
|
| 53 |
-
target_text=
|
| 54 |
cfg_value=args.cfg_value, # LM guidance on LocDiT, higher for better adherence to the prompt, but maybe worse
|
| 55 |
inference_timesteps=args.inference_timesteps, # LocDiT inference timesteps, higher for better result, lower for fast speed
|
| 56 |
retry_badcase=True, # enable retrying mode for some bad cases (unstoppable)
|
|
@@ -73,8 +79,11 @@ def main():
|
|
| 73 |
texts.append((wav_id, text))
|
| 74 |
for wav_id, text in texts:
|
| 75 |
with torch.inference_mode():
|
|
|
|
|
|
|
|
|
|
| 76 |
wav = model.generate(
|
| 77 |
-
target_text=
|
| 78 |
cfg_value=args.cfg_value, # LM guidance on LocDiT, higher for better adherence to the prompt, but maybe worse
|
| 79 |
inference_timesteps=args.inference_timesteps, # LocDiT inference timesteps, higher for better result, lower for fast speed
|
| 80 |
retry_badcase=True, # enable retrying mode for some bad cases (unstoppable)
|
|
|
|
| 7 |
import argparse
|
| 8 |
import torch
|
| 9 |
import os
|
| 10 |
+
import re
|
| 11 |
def main():
|
| 12 |
parser = argparse.ArgumentParser()
|
| 13 |
parser.add_argument("--lora_ckpt", type=str, required=True)
|
|
|
|
| 34 |
training=False,
|
| 35 |
lora_config=lora_cfg,
|
| 36 |
)
|
| 37 |
+
|
| 38 |
+
from src.voxcpm.utils.text_normalize import TextNormalizer
|
| 39 |
+
text_normalizer = TextNormalizer()
|
| 40 |
+
|
| 41 |
# 3. 加载 LoRA 权重(在 compile 后也能正常工作)
|
| 42 |
ckpt_dir = Path(args.lora_ckpt)
|
| 43 |
if not ckpt_dir.exists():
|
|
|
|
| 52 |
print(f"\n[3/3] 开始推理...")
|
| 53 |
if args.text:
|
| 54 |
with torch.inference_mode():
|
| 55 |
+
target_text = args.text.replace("\n", " ")
|
| 56 |
+
target_text = re.sub(r'\s+', ' ', target_text)
|
| 57 |
+
target_text = text_normalizer.normalize(target_text)
|
| 58 |
wav = model.generate(
|
| 59 |
+
target_text=target_text,
|
| 60 |
cfg_value=args.cfg_value, # LM guidance on LocDiT, higher for better adherence to the prompt, but maybe worse
|
| 61 |
inference_timesteps=args.inference_timesteps, # LocDiT inference timesteps, higher for better result, lower for fast speed
|
| 62 |
retry_badcase=True, # enable retrying mode for some bad cases (unstoppable)
|
|
|
|
| 79 |
texts.append((wav_id, text))
|
| 80 |
for wav_id, text in texts:
|
| 81 |
with torch.inference_mode():
|
| 82 |
+
target_text = text.replace("\n", " ")
|
| 83 |
+
target_text = re.sub(r'\s+', ' ', target_text)
|
| 84 |
+
target_text = text_normalizer.normalize(target_text)
|
| 85 |
wav = model.generate(
|
| 86 |
+
target_text=target_text,
|
| 87 |
cfg_value=args.cfg_value, # LM guidance on LocDiT, higher for better adherence to the prompt, but maybe worse
|
| 88 |
inference_timesteps=args.inference_timesteps, # LocDiT inference timesteps, higher for better result, lower for fast speed
|
| 89 |
retry_badcase=True, # enable retrying mode for some bad cases (unstoppable)
|