lglg666 committed on
Commit
6d32e7f
·
verified ·
1 Parent(s): 6766eda

Update VoxCPM/inference_lora.py

Browse files
Files changed (1) hide show
  1. VoxCPM/inference_lora.py +13 -4
VoxCPM/inference_lora.py CHANGED
@@ -7,7 +7,7 @@ from voxcpm.training.config import load_yaml_config
7
  import argparse
8
  import torch
9
  import os
10
-
11
  def main():
12
  parser = argparse.ArgumentParser()
13
  parser.add_argument("--lora_ckpt", type=str, required=True)
@@ -34,7 +34,10 @@ def main():
34
  training=False,
35
  lora_config=lora_cfg,
36
  )
37
-
 
 
 
38
  # 3. 加载 LoRA 权重(在 compile 后也能正常工作)
39
  ckpt_dir = Path(args.lora_ckpt)
40
  if not ckpt_dir.exists():
@@ -49,8 +52,11 @@ def main():
49
  print(f"\n[3/3] 开始推理...")
50
  if args.text:
51
  with torch.inference_mode():
 
 
 
52
  wav = model.generate(
53
- target_text=args.text,
54
  cfg_value=args.cfg_value, # LM guidance on LocDiT, higher for better adherence to the prompt, but maybe worse
55
  inference_timesteps=args.inference_timesteps, # LocDiT inference timesteps, higher for better result, lower for fast speed
56
  retry_badcase=True, # enable retrying mode for some bad cases (unstoppable)
@@ -73,8 +79,11 @@ def main():
73
  texts.append((wav_id, text))
74
  for wav_id, text in texts:
75
  with torch.inference_mode():
 
 
 
76
  wav = model.generate(
77
- target_text=text,
78
  cfg_value=args.cfg_value, # LM guidance on LocDiT, higher for better adherence to the prompt, but maybe worse
79
  inference_timesteps=args.inference_timesteps, # LocDiT inference timesteps, higher for better result, lower for fast speed
80
  retry_badcase=True, # enable retrying mode for some bad cases (unstoppable)
 
7
  import argparse
8
  import torch
9
  import os
10
+ import re
11
  def main():
12
  parser = argparse.ArgumentParser()
13
  parser.add_argument("--lora_ckpt", type=str, required=True)
 
34
  training=False,
35
  lora_config=lora_cfg,
36
  )
37
+
38
+ from src.voxcpm.utils.text_normalize import TextNormalizer
39
+ text_normalizer = TextNormalizer()
40
+
41
  # 3. 加载 LoRA 权重(在 compile 后也能正常工作)
42
  ckpt_dir = Path(args.lora_ckpt)
43
  if not ckpt_dir.exists():
 
52
  print(f"\n[3/3] 开始推理...")
53
  if args.text:
54
  with torch.inference_mode():
55
+ target_text = args.text.replace("\n", " ")
56
+ target_text = re.sub(r'\s+', ' ', target_text)
57
+ target_text = text_normalizer.normalize(target_text)
58
  wav = model.generate(
59
+ target_text=target_text,
60
  cfg_value=args.cfg_value, # LM guidance on LocDiT, higher for better adherence to the prompt, but maybe worse
61
  inference_timesteps=args.inference_timesteps, # LocDiT inference timesteps, higher for better result, lower for fast speed
62
  retry_badcase=True, # enable retrying mode for some bad cases (unstoppable)
 
79
  texts.append((wav_id, text))
80
  for wav_id, text in texts:
81
  with torch.inference_mode():
82
+ target_text = text.replace("\n", " ")
83
+ target_text = re.sub(r'\s+', ' ', target_text)
84
+ target_text = text_normalizer.normalize(target_text)
85
  wav = model.generate(
86
+ target_text=target_text,
87
  cfg_value=args.cfg_value, # LM guidance on LocDiT, higher for better adherence to the prompt, but maybe worse
88
  inference_timesteps=args.inference_timesteps, # LocDiT inference timesteps, higher for better result, lower for fast speed
89
  retry_badcase=True, # enable retrying mode for some bad cases (unstoppable)