nxhong committed on
Commit
b166fd7
·
verified ·
1 Parent(s): 591dff6

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +58 -148
app.py CHANGED
@@ -5,38 +5,37 @@ import re
5
  import time
6
  import uuid
7
  from io import StringIO
 
8
 
9
  import gradio as gr
10
  import spaces
11
  import torch
12
  import torchaudio
 
 
 
13
  from huggingface_hub import HfApi, hf_hub_download, snapshot_download
14
  from TTS.tts.configs.xtts_config import XttsConfig
15
  from TTS.tts.models.xtts import Xtts
16
  from vinorm import TTSnorm
17
 
18
- # download for mecab
 
19
  os.system("python -m unidic download")
20
 
21
  HF_TOKEN = os.environ.get("HF_TOKEN")
22
  api = HfApi(token=HF_TOKEN)
23
 
24
- # This will trigger downloading model
25
- print("Downloading if not downloaded viXTTS")
26
  checkpoint_dir = "model/"
27
  repo_id = "capleaf/viXTTS"
28
  use_deepspeed = False
29
 
30
  os.makedirs(checkpoint_dir, exist_ok=True)
31
-
32
  required_files = ["model.pth", "config.json", "vocab.json", "speakers_xtts.pth"]
33
  files_in_dir = os.listdir(checkpoint_dir)
34
  if not all(file in files_in_dir for file in required_files):
35
- snapshot_download(
36
- repo_id=repo_id,
37
- repo_type="model",
38
- local_dir=checkpoint_dir,
39
- )
40
  hf_hub_download(
41
  repo_id="coqui/XTTS-v2",
42
  filename="speakers_xtts.pth",
@@ -47,17 +46,16 @@ xtts_config = os.path.join(checkpoint_dir, "config.json")
47
  config = XttsConfig()
48
  config.load_json(xtts_config)
49
  MODEL = Xtts.init_from_config(config)
50
- MODEL.load_checkpoint(
51
- config, checkpoint_dir=checkpoint_dir, use_deepspeed=use_deepspeed
52
- )
53
  if torch.cuda.is_available():
54
  MODEL.cuda()
55
 
56
  supported_languages = config.languages
57
- if not "vi" in supported_languages:
58
  supported_languages.append("vi")
59
 
60
 
 
61
  def normalize_vietnamese_text(text):
62
  text = (
63
  TTSnorm(text, unknown=False, lower=False, rule=True)
@@ -75,13 +73,11 @@ def normalize_vietnamese_text(text):
75
 
76
 
77
  def calculate_keep_len(text, lang):
78
- """Simple hack for short sentences"""
79
  if lang in ["ja", "zh-cn"]:
80
  return -1
81
-
82
  word_count = len(text.split())
83
  num_punct = text.count(".") + text.count("!") + text.count("?") + text.count(",")
84
-
85
  if word_count < 5:
86
  return 15000 * word_count + 2000 * num_punct
87
  elif word_count < 10:
@@ -89,63 +85,33 @@ def calculate_keep_len(text, lang):
89
  return -1
90
 
91
 
 
92
  @spaces.GPU
93
- def predict(
94
- prompt,
95
- language,
96
- audio_file_pth,
97
- normalize_text=True,
98
- ):
99
  if language not in supported_languages:
100
  metrics_text = gr.Warning(
101
- f"Language you put {language} in is not in is not in our Supported Languages, please choose from dropdown"
102
  )
103
-
104
  return (None, metrics_text)
105
 
106
  speaker_wav = audio_file_pth
107
-
108
  if len(prompt) < 2:
109
- metrics_text = gr.Warning("Please give a longer prompt text")
110
- return (None, metrics_text)
111
-
112
- # if len(prompt) > 250:
113
- # metrics_text = gr.Warning(
114
- # str(len(prompt))
115
- # + " characters.\n"
116
- # + "Your prompt is too long, please keep it under 250 characters\n"
117
- # + "Văn bản quá dài, vui lòng giữ dưới 250 ký tự."
118
- # )
119
- # return (None, metrics_text)
120
 
121
  try:
122
  metrics_text = ""
123
- t_latent = time.time()
124
-
125
- try:
126
- (
127
- gpt_cond_latent,
128
- speaker_embedding,
129
- ) = MODEL.get_conditioning_latents(
130
- audio_path=speaker_wav,
131
- gpt_cond_len=30,
132
- gpt_cond_chunk_len=4,
133
- max_ref_length=60,
134
- )
135
-
136
- except Exception as e:
137
- print("Speaker encoding error", str(e))
138
- metrics_text = gr.Warning(
139
- "It appears something wrong with reference, did you unmute your microphone?"
140
- )
141
- return (None, metrics_text)
142
-
143
- prompt = re.sub("([^\x00-\x7F]|\w)(\.|\。|\?)", r"\1 \2\2", prompt)
144
 
145
  if normalize_text and language == "vi":
146
  prompt = normalize_vietnamese_text(prompt)
147
 
148
- print("I: Generating new audio...")
149
  t0 = time.time()
150
  out = MODEL.inference(
151
  prompt,
@@ -156,134 +122,71 @@ def predict(
156
  temperature=0.75,
157
  enable_text_splitting=True,
158
  )
 
159
  inference_time = time.time() - t0
160
- print(f"I: Time to generate audio: {round(inference_time*1000)} milliseconds")
161
- metrics_text += (
162
- f"Time to generate audio: {round(inference_time*1000)} milliseconds\n"
163
- )
164
  real_time_factor = (time.time() - t0) / out["wav"].shape[-1] * 24000
165
- print(f"Real-time factor (RTF): {real_time_factor}")
166
  metrics_text += f"Real-time factor (RTF): {real_time_factor:.2f}\n"
167
 
168
- # Temporary hack for short sentences
169
  keep_len = calculate_keep_len(prompt, language)
170
  out["wav"] = out["wav"][:keep_len]
171
-
172
  torchaudio.save("output.wav", torch.tensor(out["wav"]).unsqueeze(0), 24000)
173
 
174
- except RuntimeError as e:
175
- if "device-side assert" in str(e):
176
- # cannot do anything on cuda device side error, need to restart
177
- print(
178
- f"Exit due to: Unrecoverable exception caused by language:{language} prompt:{prompt}",
179
- flush=True,
180
- )
181
- gr.Warning("Unhandled Exception encounter, please retry in a minute")
182
- print("Cuda device-assert Runtime encountered need restart")
183
-
184
- error_time = datetime.datetime.now().strftime("%d-%m-%Y-%H:%M:%S")
185
- error_data = [
186
- error_time,
187
- prompt,
188
- language,
189
- audio_file_pth,
190
- ]
191
- error_data = [str(e) if type(e) != str else e for e in error_data]
192
- print(error_data)
193
- print(speaker_wav)
194
- write_io = StringIO()
195
- csv.writer(write_io).writerows([error_data])
196
- csv_upload = write_io.getvalue().encode()
197
-
198
- filename = error_time + "_" + str(uuid.uuid4()) + ".csv"
199
- print("Writing error csv")
200
- error_api = HfApi()
201
- error_api.upload_file(
202
- path_or_fileobj=csv_upload,
203
- path_in_repo=filename,
204
- repo_id="coqui/xtts-flagged-dataset",
205
- repo_type="dataset",
206
- )
207
-
208
- # speaker_wav
209
- print("Writing error reference audio")
210
- speaker_filename = error_time + "_reference_" + str(uuid.uuid4()) + ".wav"
211
- error_api = HfApi()
212
- error_api.upload_file(
213
- path_or_fileobj=speaker_wav,
214
- path_in_repo=speaker_filename,
215
- repo_id="coqui/xtts-flagged-dataset",
216
- repo_type="dataset",
217
- )
218
 
219
- # HF Space specific.. This error is unrecoverable need to restart space
220
- space = api.get_space_runtime(repo_id=repo_id)
221
- if space.stage != "BUILDING":
222
- api.restart_space(repo_id=repo_id)
223
- else:
224
- print("TRIED TO RESTART but space is building")
225
-
226
- else:
227
- if "Failed to decode" in str(e):
228
- print("Speaker encoding error", str(e))
229
- metrics_text = gr.Warning(
230
- metrics_text="It appears something wrong with reference, did you unmute your microphone?"
231
- )
232
- else:
233
- print("RuntimeError: non device-side assert error:", str(e))
234
- metrics_text = gr.Warning(
235
- "Something unexpected happened please retry again."
236
- )
237
- return (None, metrics_text)
238
  return ("output.wav", metrics_text)
239
 
240
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
241
  with gr.Blocks(analytics_enabled=False) as demo:
242
  gr.Markdown("# 🇻🇳 Text to Speech Vietnamese (capleaf/viXTTS)")
243
- gr.Markdown("Nhập văn bản tiếng Việt và chọn giọng mẫu để tạo âm thanh 🎙️")
244
 
245
  with gr.Row():
246
  with gr.Column(scale=1):
247
  input_text_gr = gr.Textbox(
248
  label="Nhập văn bản",
249
- placeholder="Nhập câu tiếng Việt để chuyển thành giọng nói...",
250
- lines=3,
251
- interactive=True,
252
- value="Xin chào, tôi là mô hình chuyển đổi văn bản thành giọng nói tiếng Việt."
253
  )
254
-
255
  language_gr = gr.Dropdown(
256
  label="Ngôn ngữ",
257
  choices=["vi", "en", "zh-cn", "ja", "ko"],
258
  value="vi",
259
- interactive=True
260
- )
261
-
262
- normalize_text = gr.Checkbox(
263
- label="Chuẩn hóa văn bản tiếng Việt",
264
- value=True
265
  )
266
-
267
  ref_gr = gr.Audio(
268
  label="Giọng mẫu (Reference Audio)",
269
  type="filepath",
270
- value="model/samples/nu-luu-loat.wav"
271
  )
272
-
273
  tts_button = gr.Button("▶️ Đọc văn bản", variant="primary")
274
 
275
  with gr.Column(scale=1):
276
  audio_gr = gr.Audio(label="Kết quả giọng nói", autoplay=True)
277
  out_text_gr = gr.Textbox(label="Thông tin chi tiết", interactive=False)
278
 
279
- # Nút sinh âm thanh
280
  tts_button.click(
281
  predict,
282
  inputs=[input_text_gr, language_gr, ref_gr, normalize_text],
283
- outputs=[audio_gr, out_text_gr]
284
  )
285
 
286
- # Khi chạy Space sẽ tự test 1 câu luôn
287
  demo.load(
288
  predict,
289
  inputs=[
@@ -295,5 +198,12 @@ with gr.Blocks(analytics_enabled=False) as demo:
295
  outputs=[audio_gr, out_text_gr],
296
  )
297
 
298
- demo.queue()
299
- demo.launch(debug=True, show_api=True, share=True)
 
 
 
 
 
 
 
 
5
  import time
6
  import uuid
7
  from io import StringIO
8
+ import threading
9
 
10
  import gradio as gr
11
  import spaces
12
  import torch
13
  import torchaudio
14
+ from fastapi import FastAPI
15
+ from fastapi.responses import FileResponse
16
+ import uvicorn
17
  from huggingface_hub import HfApi, hf_hub_download, snapshot_download
18
  from TTS.tts.configs.xtts_config import XttsConfig
19
  from TTS.tts.models.xtts import Xtts
20
  from vinorm import TTSnorm
21
 
22
+
23
+ # ========== SETUP MODEL ==========
24
  os.system("python -m unidic download")
25
 
26
  HF_TOKEN = os.environ.get("HF_TOKEN")
27
  api = HfApi(token=HF_TOKEN)
28
 
29
+ print("🔽 Downloading model if not exist...")
 
30
  checkpoint_dir = "model/"
31
  repo_id = "capleaf/viXTTS"
32
  use_deepspeed = False
33
 
34
  os.makedirs(checkpoint_dir, exist_ok=True)
 
35
  required_files = ["model.pth", "config.json", "vocab.json", "speakers_xtts.pth"]
36
  files_in_dir = os.listdir(checkpoint_dir)
37
  if not all(file in files_in_dir for file in required_files):
38
+ snapshot_download(repo_id=repo_id, repo_type="model", local_dir=checkpoint_dir)
 
 
 
 
39
  hf_hub_download(
40
  repo_id="coqui/XTTS-v2",
41
  filename="speakers_xtts.pth",
 
46
  config = XttsConfig()
47
  config.load_json(xtts_config)
48
  MODEL = Xtts.init_from_config(config)
49
+ MODEL.load_checkpoint(config, checkpoint_dir=checkpoint_dir, use_deepspeed=use_deepspeed)
 
 
50
  if torch.cuda.is_available():
51
  MODEL.cuda()
52
 
53
  supported_languages = config.languages
54
+ if "vi" not in supported_languages:
55
  supported_languages.append("vi")
56
 
57
 
58
+ # ========== UTILS ==========
59
  def normalize_vietnamese_text(text):
60
  text = (
61
  TTSnorm(text, unknown=False, lower=False, rule=True)
 
73
 
74
 
75
  def calculate_keep_len(text, lang):
76
+ """Hack giữ độ dài âm thanh cho câu ngắn"""
77
  if lang in ["ja", "zh-cn"]:
78
  return -1
 
79
  word_count = len(text.split())
80
  num_punct = text.count(".") + text.count("!") + text.count("?") + text.count(",")
 
81
  if word_count < 5:
82
  return 15000 * word_count + 2000 * num_punct
83
  elif word_count < 10:
 
85
  return -1
86
 
87
 
88
+ # ========== PREDICT ==========
89
  @spaces.GPU
90
+ def predict(prompt, language, audio_file_pth, normalize_text=True):
 
 
 
 
 
91
  if language not in supported_languages:
92
  metrics_text = gr.Warning(
93
+ f"Language '{language}' không được hỗ trợ. Vui lòng chọn lại."
94
  )
 
95
  return (None, metrics_text)
96
 
97
  speaker_wav = audio_file_pth
 
98
  if len(prompt) < 2:
99
+ return (None, gr.Warning("Vui lòng nhập câu dài hơn."))
 
 
 
 
 
 
 
 
 
 
100
 
101
  try:
102
  metrics_text = ""
103
+ print("🎙️ Encoding speaker...")
104
+ (gpt_cond_latent, speaker_embedding) = MODEL.get_conditioning_latents(
105
+ audio_path=speaker_wav,
106
+ gpt_cond_len=30,
107
+ gpt_cond_chunk_len=4,
108
+ max_ref_length=60,
109
+ )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
110
 
111
  if normalize_text and language == "vi":
112
  prompt = normalize_vietnamese_text(prompt)
113
 
114
+ print("⚙️ Generating speech...")
115
  t0 = time.time()
116
  out = MODEL.inference(
117
  prompt,
 
122
  temperature=0.75,
123
  enable_text_splitting=True,
124
  )
125
+
126
  inference_time = time.time() - t0
127
+ metrics_text += f"Thời gian sinh âm thanh: {round(inference_time*1000)} ms\n"
 
 
 
128
  real_time_factor = (time.time() - t0) / out["wav"].shape[-1] * 24000
 
129
  metrics_text += f"Real-time factor (RTF): {real_time_factor:.2f}\n"
130
 
 
131
  keep_len = calculate_keep_len(prompt, language)
132
  out["wav"] = out["wav"][:keep_len]
 
133
  torchaudio.save("output.wav", torch.tensor(out["wav"]).unsqueeze(0), 24000)
134
 
135
+ except Exception as e:
136
+ print(" Error:", str(e))
137
+ return (None, gr.Warning(f"Lỗi khi tạo giọng nói: {str(e)}"))
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
138
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
139
  return ("output.wav", metrics_text)
140
 
141
 
142
# ========== FASTAPI ENDPOINT ==========
api_app = FastAPI()


@api_app.post("/api/speak")
def speak_api(text: str = "Xin chào, tôi là mô hình viXTTS.", language: str = "vi"):
    """Synthesize speech for *text* in *language* and return it as a WAV file.

    Uses a fixed female reference voice shipped with the model checkout.

    Args:
        text: Text to synthesize (defaults to a Vietnamese greeting).
        language: Language code understood by the model (default "vi").

    Returns:
        FileResponse streaming the generated ``output.wav``.

    Raises:
        HTTPException(400): when synthesis fails — ``predict`` signals
            failure by returning ``None`` for the audio path (unsupported
            language, too-short prompt, or an inference error).
    """
    # Local import: fastapi is already a file-level dependency; importing
    # here keeps this fix self-contained.
    from fastapi import HTTPException

    ref_audio = "model/samples/nu-luu-loat.wav"
    audio_path, metrics = predict(text, language, ref_audio, True)
    if audio_path is None:
        # Previously this passed None straight into FileResponse, which
        # crashed with an opaque server error. Surface the warning text
        # from predict() as a proper HTTP error instead.
        raise HTTPException(status_code=400, detail=str(metrics))
    return FileResponse(audio_path, media_type="audio/wav")
154
+
155
+
156
+ # ========== GRADIO UI ==========
157
  with gr.Blocks(analytics_enabled=False) as demo:
158
  gr.Markdown("# 🇻🇳 Text to Speech Vietnamese (capleaf/viXTTS)")
159
+ gr.Markdown("Nhập văn bản và chọn giọng mẫu để tạo âm thanh 🎙️")
160
 
161
  with gr.Row():
162
  with gr.Column(scale=1):
163
  input_text_gr = gr.Textbox(
164
  label="Nhập văn bản",
165
+ value="Xin chào, tôi hình chuyển văn bản thành giọng nói tiếng Việt.",
 
 
 
166
  )
 
167
  language_gr = gr.Dropdown(
168
  label="Ngôn ngữ",
169
  choices=["vi", "en", "zh-cn", "ja", "ko"],
170
  value="vi",
 
 
 
 
 
 
171
  )
172
+ normalize_text = gr.Checkbox(label="Chuẩn hóa văn bản tiếng Việt", value=True)
173
  ref_gr = gr.Audio(
174
  label="Giọng mẫu (Reference Audio)",
175
  type="filepath",
176
+ value="model/samples/nu-luu-loat.wav",
177
  )
 
178
  tts_button = gr.Button("▶️ Đọc văn bản", variant="primary")
179
 
180
  with gr.Column(scale=1):
181
  audio_gr = gr.Audio(label="Kết quả giọng nói", autoplay=True)
182
  out_text_gr = gr.Textbox(label="Thông tin chi tiết", interactive=False)
183
 
 
184
  tts_button.click(
185
  predict,
186
  inputs=[input_text_gr, language_gr, ref_gr, normalize_text],
187
+ outputs=[audio_gr, out_text_gr],
188
  )
189
 
 
190
  demo.load(
191
  predict,
192
  inputs=[
 
198
  outputs=[audio_gr, out_text_gr],
199
  )
200
 
201
+
202
# ========== RUN BOTH FASTAPI + GRADIO ==========
if __name__ == "__main__":
    # Serve the REST API on :8000 from a daemon thread so the Gradio UI
    # can own the main thread on :7860; the daemon flag lets the process
    # exit when Gradio shuts down.
    api_thread = threading.Thread(
        target=uvicorn.run,
        args=(api_app,),
        kwargs={"host": "0.0.0.0", "port": 8000},
        daemon=True,
    )
    api_thread.start()

    demo.queue()
    demo.launch(server_name="0.0.0.0", server_port=7860)