stephenhoang committed on
Commit
079aad4
·
verified ·
1 Parent(s): fc5f72b

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +58 -78
app.py CHANGED
@@ -1,5 +1,3 @@
1
-
2
-
3
  import os
4
  import json
5
  import tempfile
@@ -9,26 +7,22 @@ import gradio as gr
9
  import numpy as np
10
  import soundfile as sf
11
  import torch
 
12
 
13
  from inference import StyleTTS2
14
 
15
  # =========================
16
- # CONFIG: CHINH 2 DUONG DAN NAY
17
  # =========================
18
- DATA_ROOT = "./demo_data"
 
19
  SPEAKER2REFS_PATH = os.path.join(DATA_ROOT, "speaker2refs.json")
20
 
21
- # Repo StyleTTS2-lite-vi (neu app.py nam trong repo thi de "./")
22
- repo_dir = "./"
23
- config_path = os.path.join(repo_dir, "Models", "config.yaml")
24
- # models_path = os.path.join(repo_dir, "Models", "Finetune", "epoch_00000.pth")
25
- from huggingface_hub import hf_hub_download
26
-
27
  CKPT_REPO = "stephenhoang/ttsStyleTTS2-ms152"
28
  models_path = hf_hub_download(repo_id=CKPT_REPO, filename="epoch_00000.pth")
29
  config_path = hf_hub_download(repo_id=CKPT_REPO, filename="config.yaml")
30
 
31
-
32
  device = "cuda" if torch.cuda.is_available() else "cpu"
33
 
34
  # =========================
@@ -44,10 +38,18 @@ SPEAKER_CHOICES = sorted(SPEAKER2REFS.keys())
44
  if not SPEAKER_CHOICES:
45
  raise RuntimeError("speaker2refs.json is empty (no speakers found).")
46
 
47
-
48
- def _abs_audio(p: str) -> str:
49
- # p trong json đang là "demo_data/refs/id_k.wav" => join theo repo root
50
- return p if os.path.isabs(p) else os.path.join(repo_dir, p)
 
 
 
 
 
 
 
 
51
 
52
  # =========================
53
  # LOAD MODEL
@@ -55,8 +57,7 @@ def _abs_audio(p: str) -> str:
55
  model = StyleTTS2(config_path, models_path).eval().to(device)
56
 
57
  # =========================
58
- # STYLE CACHE (giam lag khi gen nhieu lan cung speaker)
59
- # key = (speaker, denoise, avg_style)
60
  # =========================
61
  STYLE_CACHE = {}
62
  STYLE_CACHE_MAX = 64
@@ -72,7 +73,6 @@ def _cache_set(key, val):
72
  STYLE_CACHE.pop(next(iter(STYLE_CACHE)))
73
  STYLE_CACHE[key] = val
74
 
75
-
76
  @torch.inference_mode()
77
  def synth_one_speaker(speaker_name: str, text_prompt: str,
78
  denoise: float, avg_style: bool, stabilize: bool):
@@ -80,66 +80,73 @@ def synth_one_speaker(speaker_name: str, text_prompt: str,
80
  if not speaker_name:
81
  return None, "Bạn chưa chọn speaker."
82
 
83
- spk = SPEAKER2REFS.get(speaker_name)
84
- if not isinstance(spk, dict):
85
- return None, f"Speaker '{speaker_name}' không đúng format trong speaker2refs.json."
 
 
 
 
86
 
87
- ref_rel = spk.get("path")
88
- if not ref_rel:
89
- return None, f"Speaker '{speaker_name}' thiếu field 'path' trong speaker2refs.json."
90
 
91
- ref_path = _abs_audio(ref_rel)
92
  if not os.path.isfile(ref_path):
93
  return None, f"Ref audio not found: {ref_path}"
94
 
95
  if not text_prompt or not text_prompt.strip():
96
  return None, "Bạn chưa nhập text."
97
 
98
- spk_lang = spk.get("lang", "vi")
99
- spk_speed = float(spk.get("speed", 1.0))
100
-
101
- # speakers dict phải dùng key đúng speaker_name (vd "id_73")
102
  speakers = {
103
- speaker_name: {"path": ref_path, "lang": spk_lang, "speed": spk_speed}
104
  }
105
 
106
  cache_key = (speaker_name, float(denoise), bool(avg_style))
107
  styles = _cache_get(cache_key)
108
  if styles is None:
109
- styles = model.get_styles(speakers, denoise, avg_style)
110
  _cache_set(cache_key, styles)
111
 
112
  text_prompt = text_prompt.strip()
113
-
114
- # Nếu user không tự thêm tag speaker, tự thêm [id_k]
115
  if "[id_" not in text_prompt:
116
- text_prompt = f"[{speaker_name}] " + text_prompt
 
 
 
 
 
 
 
 
117
 
118
- # default_speaker cũng phải là speaker đang chọn
119
- r = model.generate(text_prompt, styles, stabilize, 18, f"[{speaker_name}]")
120
- dur = len(r) / 24000
121
- print("GEN_SAMPLES=", len(r), "DUR_SEC=", dur)
122
 
123
- r = np.asarray(r, dtype=np.float32)
124
- m = float(np.max(np.abs(r))) if r.size else 0.0
125
- if m > 1e-9:
126
- r = r / m
127
 
128
  out_f = tempfile.NamedTemporaryFile(delete=False, suffix=".wav")
129
  out_path = out_f.name
130
  out_f.close()
131
- sf.write(out_path, r, samplerate=24000)
132
 
133
  status = (
134
  "OK\n"
135
  f"speaker: {speaker_name}\n"
136
- f"ref: {ref_rel}\n"
 
 
137
  f"device: {device}"
138
  )
139
  return out_path, status
140
 
141
  except Exception:
142
  return None, traceback.format_exc()
 
143
  # =========================
144
  # GRADIO UI
145
  # =========================
@@ -160,48 +167,21 @@ with gr.Blocks() as demo:
160
  )
161
 
162
  with gr.Row():
163
- denoise = gr.Slider(0.0, 1.0, step=0.1, value=0.6, label="Denoise Strength")
164
  avg_style = gr.Checkbox(label="Use Average Styles", value=True)
165
  stabilize = gr.Checkbox(label="Stabilize Speaking Speed", value=True)
166
 
167
  gen_button = gr.Button("Generate")
168
  synthesized_audio = gr.Audio(label="Generated Audio", type="filepath")
169
- status = gr.Textbox(label="Status", lines=4, interactive=False)
170
 
171
  gen_button.click(
172
  fn=synth_one_speaker,
173
  inputs=[speaker_name, text_prompt, denoise, avg_style, stabilize],
174
- outputs=[synthesized_audio, status]
 
175
  )
176
- import os
177
- import time
178
-
179
- PORT = int(os.environ.get("PORT", "7860"))
180
-
181
- if __name__ == "__main__":
182
- # queue() không truyền kwargs để khỏi lệch version
183
- try:
184
- demo.queue()
185
- except Exception:
186
- pass
187
-
188
- # launch() với fallback theo version
189
- try:
190
- demo.launch(
191
- server_name="0.0.0.0",
192
- server_port=PORT,
193
- show_error=True,
194
- ssr_mode=False,
195
- prevent_thread_lock=False, # nếu hỗ trợ thì sẽ block luôn
196
- )
197
- except TypeError:
198
- # gradio cũ không có ssr_mode / prevent_thread_lock
199
- demo.launch(
200
- server_name="0.0.0.0",
201
- server_port=PORT,
202
- show_error=True,
203
- )
204
 
205
- # nếu launch() không block (một số build), giữ process sống
206
- while True:
207
- time.sleep(3600)
 
 
 
1
  import os
2
  import json
3
  import tempfile
 
7
  import numpy as np
8
  import soundfile as sf
9
  import torch
10
+ from huggingface_hub import hf_hub_download
11
 
12
  from inference import StyleTTS2
13
 
14
  # =========================
15
+ # PATHS
16
  # =========================
17
+ SPACE_ROOT = os.path.dirname(os.path.abspath(__file__))
18
+ DATA_ROOT = os.path.join(SPACE_ROOT, "demo_data")
19
  SPEAKER2REFS_PATH = os.path.join(DATA_ROOT, "speaker2refs.json")
20
 
21
+ # Model repo (ckpt + config)
 
 
 
 
 
22
  CKPT_REPO = "stephenhoang/ttsStyleTTS2-ms152"
23
  models_path = hf_hub_download(repo_id=CKPT_REPO, filename="epoch_00000.pth")
24
  config_path = hf_hub_download(repo_id=CKPT_REPO, filename="config.yaml")
25
 
 
26
  device = "cuda" if torch.cuda.is_available() else "cpu"
27
 
28
  # =========================
 
38
  if not SPEAKER_CHOICES:
39
  raise RuntimeError("speaker2refs.json is empty (no speakers found).")
40
 
41
+ def _abs_ref_path(p: str) -> str:
42
+ """
43
+ Hỗ trợ cả 2 kiểu:
44
+ - "refs/id_1.wav"
45
+ - "demo_data/refs/id_1.wav"
46
+ """
47
+ p = p.lstrip("./")
48
+ if os.path.isabs(p):
49
+ return p
50
+ if p.startswith("demo_data/"):
51
+ return os.path.join(SPACE_ROOT, p)
52
+ return os.path.join(DATA_ROOT, p)
53
 
54
  # =========================
55
  # LOAD MODEL
 
57
  model = StyleTTS2(config_path, models_path).eval().to(device)
58
 
59
  # =========================
60
+ # STYLE CACHE
 
61
  # =========================
62
  STYLE_CACHE = {}
63
  STYLE_CACHE_MAX = 64
 
73
  STYLE_CACHE.pop(next(iter(STYLE_CACHE)))
74
  STYLE_CACHE[key] = val
75
 
 
76
  @torch.inference_mode()
77
  def synth_one_speaker(speaker_name: str, text_prompt: str,
78
  denoise: float, avg_style: bool, stabilize: bool):
 
80
  if not speaker_name:
81
  return None, "Bạn chưa chọn speaker."
82
 
83
+ info = SPEAKER2REFS.get(speaker_name, None)
84
+ if info is None:
85
+ return None, f"Speaker '{speaker_name}' không tồn tại trong speaker2refs.json."
86
+
87
+ # info là dict: {"path":..., "lang":..., "speed":..., ...}
88
+ if not isinstance(info, dict) or "path" not in info:
89
+ return None, f"Format speaker2refs.json sai cho speaker '{speaker_name}'. Expect dict có field 'path'."
90
 
91
+ ref_path = _abs_ref_path(info["path"])
92
+ lang = info.get("lang", "vi")
93
+ speed = float(info.get("speed", 1.0))
94
 
 
95
  if not os.path.isfile(ref_path):
96
  return None, f"Ref audio not found: {ref_path}"
97
 
98
  if not text_prompt or not text_prompt.strip():
99
  return None, "Bạn chưa nhập text."
100
 
 
 
 
 
101
  speakers = {
102
+ "id_1": {"path": ref_path, "lang": lang, "speed": speed}
103
  }
104
 
105
  cache_key = (speaker_name, float(denoise), bool(avg_style))
106
  styles = _cache_get(cache_key)
107
  if styles is None:
108
+ styles = model.get_styles(speakers, denoise=denoise, avg_style=avg_style)
109
  _cache_set(cache_key, styles)
110
 
111
  text_prompt = text_prompt.strip()
 
 
112
  if "[id_" not in text_prompt:
113
+ text_prompt = "[id_1] " + text_prompt
114
+
115
+ wav = model.generate(
116
+ text_prompt,
117
+ styles,
118
+ stabilize=stabilize,
119
+ n_merge=18,
120
+ default_speaker="[id_1]"
121
+ )
122
 
123
+ wav = np.asarray(wav, dtype=np.float32)
124
+ if wav.size == 0:
125
+ return None, "Model output rỗng (0 samples). Kiểm tra phonemizer/espeak và tokenization."
 
126
 
127
+ # normalize (không làm mất tiếng)
128
+ peak = float(np.max(np.abs(wav)))
129
+ if peak > 1e-6:
130
+ wav = wav / peak
131
 
132
  out_f = tempfile.NamedTemporaryFile(delete=False, suffix=".wav")
133
  out_path = out_f.name
134
  out_f.close()
135
+ sf.write(out_path, wav, samplerate=24000)
136
 
137
  status = (
138
  "OK\n"
139
  f"speaker: {speaker_name}\n"
140
+ f"ref: {ref_path}\n"
141
+ f"lang: {lang}, speed: {speed}\n"
142
+ f"samples: {wav.shape[0]}, sec: {wav.shape[0]/24000:.3f}\n"
143
  f"device: {device}"
144
  )
145
  return out_path, status
146
 
147
  except Exception:
148
  return None, traceback.format_exc()
149
+
150
  # =========================
151
  # GRADIO UI
152
  # =========================
 
167
  )
168
 
169
  with gr.Row():
170
+ denoise = gr.Slider(0.0, 1.0, step=0.1, value=0.3, label="Denoise Strength")
171
  avg_style = gr.Checkbox(label="Use Average Styles", value=True)
172
  stabilize = gr.Checkbox(label="Stabilize Speaking Speed", value=True)
173
 
174
  gen_button = gr.Button("Generate")
175
  synthesized_audio = gr.Audio(label="Generated Audio", type="filepath")
176
+ status = gr.Textbox(label="Status", lines=6, interactive=False)
177
 
178
  gen_button.click(
179
  fn=synth_one_speaker,
180
  inputs=[speaker_name, text_prompt, denoise, avg_style, stabilize],
181
+ outputs=[synthesized_audio, status],
182
+ concurrency_limit=1,
183
  )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
184
 
185
# Standard Gradio 4.x queue()/launch(). The legacy `concurrency_count`
# kwarg was removed; per-event concurrency is set via
# `default_concurrency_limit` (and per-listener `concurrency_limit`).
demo.queue(max_size=8, default_concurrency_limit=1)
demo.launch()