StaticFace commited on
Commit
3d73fc7
·
verified ·
1 Parent(s): a09d229

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +14 -31
app.py CHANGED
@@ -27,13 +27,10 @@ _OriginalInferenceSession = ort.InferenceSession
27
 
28
  def _PatchedInferenceSession(*args, **kwargs):
29
  so = kwargs.get("sess_options", ort.SessionOptions())
30
-
31
  so.graph_optimization_level = ort.GraphOptimizationLevel.ORT_ENABLE_ALL
32
  so.execution_mode = ort.ExecutionMode.ORT_SEQUENTIAL
33
-
34
  so.intra_op_num_threads = CPU_THREADS
35
  so.inter_op_num_threads = 1
36
-
37
  kwargs["sess_options"] = so
38
  return _OriginalInferenceSession(*args, **kwargs)
39
 
@@ -43,25 +40,25 @@ from pocket_tts_onnx import PocketTTSOnnx
43
 
44
  tts_cache = {}
45
 
46
- def get_tts(precision: str, temperature: float, lsd_steps: int):
47
- key = (precision, float(temperature), int(lsd_steps))
48
  if key not in tts_cache:
49
  tts_cache[key] = PocketTTSOnnx(
50
- precision=precision,
51
  temperature=float(temperature),
52
  lsd_steps=int(lsd_steps),
53
  device="cpu",
54
  )
55
  return tts_cache[key]
56
 
57
- def synthesize(ref_audio_path, text, precision, temperature, lsd_steps):
58
  text = (text or "").strip()
59
  if not ref_audio_path:
60
  raise gr.Error("Upload a reference audio file.")
61
  if not text:
62
  raise gr.Error("Enter some text.")
63
 
64
- tts = get_tts(precision, temperature, int(lsd_steps))
65
  audio = tts.generate(text=text, voice=ref_audio_path)
66
 
67
  sr = getattr(tts, "SAMPLE_RATE", 24000)
@@ -71,36 +68,22 @@ def synthesize(ref_audio_path, text, precision, temperature, lsd_steps):
71
 
72
  out_path = os.path.join(tempfile.gettempdir(), "pocket_tts_out.wav")
73
  sf.write(out_path, audio_np, sr)
74
-
75
- info = (
76
- f"CPU_THREADS = {CPU_THREADS}\n"
77
- f"precision = {precision}\n"
78
- f"temperature = {tts.temperature}\n"
79
- f"lsd_steps = {tts.lsd_steps}\n"
80
- f"sample_rate = {sr}"
81
- )
82
- return out_path, info
83
 
84
  with gr.Blocks() as demo:
85
- gr.Markdown("# Pocket TTS ONNX\nReference audio + text → output audio")
86
- info_box = gr.Textbox(label="Runtime Info", value=f"CPU_THREADS = {CPU_THREADS}", lines=6)
87
-
88
  with gr.Row():
89
- ref_audio = gr.Audio(label="Reference Audio", type="filepath")
90
- text = gr.Textbox(label="Text", lines=6, value="Hello, this is a test of voice cloning.")
91
-
92
  with gr.Row():
93
- precision = gr.Dropdown(["int8", "fp32"], value="int8", label="Precision")
94
- temperature = gr.Slider(0.1, 1.2, value=0.7, step=0.05, label="Temperature")
95
- lsd_steps = gr.Slider(1, 20, value=10, step=1, label="LSD Steps")
96
-
97
- generate = gr.Button("Generate", variant="primary")
98
- out_audio = gr.Audio(label="Output", type="filepath")
99
 
100
  generate.click(
101
  fn=synthesize,
102
- inputs=[ref_audio, text, precision, temperature, lsd_steps],
103
- outputs=[out_audio, info_box],
104
  api_name="generate",
105
  )
106
 
 
27
 
28
  def _PatchedInferenceSession(*args, **kwargs):
29
  so = kwargs.get("sess_options", ort.SessionOptions())
 
30
  so.graph_optimization_level = ort.GraphOptimizationLevel.ORT_ENABLE_ALL
31
  so.execution_mode = ort.ExecutionMode.ORT_SEQUENTIAL
 
32
  so.intra_op_num_threads = CPU_THREADS
33
  so.inter_op_num_threads = 1
 
34
  kwargs["sess_options"] = so
35
  return _OriginalInferenceSession(*args, **kwargs)
36
 
 
40
 
41
  tts_cache = {}
42
 
43
+ def get_tts(temperature: float, lsd_steps: int):
44
+ key = (float(temperature), int(lsd_steps))
45
  if key not in tts_cache:
46
  tts_cache[key] = PocketTTSOnnx(
47
+ precision="int8",
48
  temperature=float(temperature),
49
  lsd_steps=int(lsd_steps),
50
  device="cpu",
51
  )
52
  return tts_cache[key]
53
 
54
+ def synthesize(ref_audio_path, text, temperature, lsd_steps):
55
  text = (text or "").strip()
56
  if not ref_audio_path:
57
  raise gr.Error("Upload a reference audio file.")
58
  if not text:
59
  raise gr.Error("Enter some text.")
60
 
61
+ tts = get_tts(temperature, int(lsd_steps))
62
  audio = tts.generate(text=text, voice=ref_audio_path)
63
 
64
  sr = getattr(tts, "SAMPLE_RATE", 24000)
 
68
 
69
  out_path = os.path.join(tempfile.gettempdir(), "pocket_tts_out.wav")
70
  sf.write(out_path, audio_np, sr)
71
+ return out_path
 
 
 
 
 
 
 
 
72
 
73
  with gr.Blocks() as demo:
 
 
 
74
  with gr.Row():
75
+ ref_audio = gr.Audio(type="filepath")
76
+ text = gr.Textbox(lines=6, value="Hello, this is a test.")
 
77
  with gr.Row():
78
+ temperature = gr.Slider(0.1, 1.2, value=0.7, step=0.05)
79
+ lsd_steps = gr.Slider(1, 20, value=10, step=1)
80
+ generate = gr.Button("Generate")
81
+ out_audio = gr.Audio(type="filepath")
 
 
82
 
83
  generate.click(
84
  fn=synthesize,
85
+ inputs=[ref_audio, text, temperature, lsd_steps],
86
+ outputs=[out_audio],
87
  api_name="generate",
88
  )
89