nhantrungsp commited on
Commit
d46de93
·
verified ·
1 Parent(s): 847b717

Update gradio_app.py

Browse files
Files changed (1) hide show
  1. gradio_app.py +20 -23
gradio_app.py CHANGED
@@ -1,21 +1,24 @@
1
- import gradio as gr
2
- import soundfile as sf
3
- import tempfile
4
- import torch
5
- from vieneu_tts import VieNeuTTS
6
  import os
7
  import time
8
  import threading
9
  import pickle
10
  import hashlib
 
 
 
11
  import numpy as np
 
 
 
 
12
  from pydub import AudioSegment
 
13
  from fastapi import FastAPI, HTTPException
14
- from fastapi.responses import FileResponse
15
  from pydantic import BaseModel
16
- import base64
17
- import io
18
- import spaces # <--- THÊM THƯ VIỆN NÀY
19
 
20
  # --- KHỞI TẠO FASTAPI ---
21
  app = FastAPI()
@@ -23,7 +26,6 @@ app = FastAPI()
23
  print("⏳ Đang khởi động VieNeu-TTS...")
24
 
25
  # --- 1. SETUP MODEL ---
26
- # Trên ZeroGPU, ban đầu có thể nó nhận là CPU, nhưng @spaces.GPU sẽ lo phần chuyển đổi sau
27
  device = "cuda" if torch.cuda.is_available() else "cpu"
28
  print(f"🖥️ Sử dụng thiết bị (Global): {device.upper()}")
29
 
@@ -54,6 +56,7 @@ def save_cache_to_disk(cache_key, ref_codes):
54
 
55
  # Load Model
56
  try:
 
57
  tts = VieNeuTTS(
58
  backbone_repo="pnnbao-ump/VieNeu-TTS",
59
  backbone_device=device,
@@ -81,7 +84,7 @@ VOICE_SAMPLES = {
81
 
82
  # --- 3. CORE LOGIC (Dùng chung cho cả API và UI) ---
83
 
84
- # QUAN TRỌNG: Thêm @spaces.GPU vào hàm này để báo cho HF biết đây là hàm cần GPU
85
  @spaces.GPU
86
  def core_synthesize(text, voice_choice, speed_factor):
87
  # Lấy thông tin giọng
@@ -104,16 +107,16 @@ def core_synthesize(text, voice_choice, speed_factor):
104
  else:
105
  ref_codes = load_cache_from_disk(cache_key)
106
  if ref_codes is None:
107
- # Đảm bảo model đang đúng device trước khi encode
108
  if torch.cuda.is_available():
109
- # Move model components to GPU if needed inside the decorated function
110
- # (Usually VieNeuTTS handles this based on init, but we double check)
111
- pass
112
  ref_codes = tts.encode_reference(ref_audio_path)
113
  save_cache_to_disk(cache_key, ref_codes)
114
  reference_cache[cache_key] = ref_codes
115
 
116
  # Infer
 
 
117
  wav = tts.infer(text, ref_codes, ref_text_raw)
118
 
119
  # Speed
@@ -137,6 +140,8 @@ def core_synthesize(text, voice_choice, speed_factor):
137
  # Hàm riêng cho Custom Voice cũng cần GPU
138
  @spaces.GPU
139
  def custom_synthesize_logic(text, ref_audio_path, ref_text_raw):
 
 
140
  ref_codes = tts.encode_reference(ref_audio_path)
141
  wav = tts.infer(text, ref_codes, ref_text_raw)
142
  return wav
@@ -175,19 +180,14 @@ async def fast_tts(request: FastTTSRequest):
175
  raise HTTPException(status_code=500, detail=str(e))
176
 
177
  # --- 5. GRADIO UI SETUP ---
178
- # Dùng theme Soft để tránh lỗi
179
  theme = gr.themes.Soft()
180
-
181
- # CSS
182
  css = ".container { max-width: 900px; margin: auto; }"
183
 
184
  def ui_synthesize(text, voice, custom_audio, custom_text, mode, speed):
185
  try:
186
  start = time.time()
187
- # Logic riêng cho UI (hỗ trợ custom voice)
188
  if mode == "custom_mode":
189
  wav = custom_synthesize_logic(text, custom_audio, custom_text)
190
- # (Bỏ qua speed control cho custom để code gọn)
191
  else:
192
  wav = core_synthesize(text, voice, speed)
193
 
@@ -219,7 +219,6 @@ with gr.Blocks(theme=theme, css=css, title="VieNeu-TTS") as demo:
219
  out_audio = gr.Audio(label="Kết quả", autoplay=True)
220
  out_status = gr.Textbox(label="Trạng thái")
221
 
222
- # Ẩn hiện mode
223
  mode_state = gr.Textbox(visible=False, value="preset_mode")
224
  tabs.children[0].select(lambda: "preset_mode", None, mode_state)
225
  tabs.children[1].select(lambda: "custom_mode", None, mode_state)
@@ -227,11 +226,9 @@ with gr.Blocks(theme=theme, css=css, title="VieNeu-TTS") as demo:
227
  btn.click(ui_synthesize, [inp_text, inp_voice, inp_audio, inp_ref_text, mode_state, inp_speed], [out_audio, out_status])
228
 
229
  # --- 6. MOUNT GRADIO VÀO FASTAPI ---
230
- # Đây là bước quan trọng nhất để chạy cả 2 cùng lúc
231
  app = gr.mount_gradio_app(app, demo, path="/")
232
 
233
  # --- 7. CHẠY SERVER ---
234
  if __name__ == "__main__":
235
  import uvicorn
236
- # 0.0.0.0 Mở port ra ngoài internet
237
  uvicorn.run(app, host="0.0.0.0", port=7860)
 
1
+ import spaces # <--- QUAN TRỌNG: PHẢI ĐỂ DÒNG ĐẦU TIÊN
 
 
 
 
2
  import os
3
  import time
4
  import threading
5
  import pickle
6
  import hashlib
7
+ import base64
8
+ import io
9
+ import tempfile
10
  import numpy as np
11
+
12
+ # Các thư viện khác import sau spaces
13
+ import torch
14
+ import soundfile as sf
15
  from pydub import AudioSegment
16
+ import gradio as gr
17
  from fastapi import FastAPI, HTTPException
 
18
  from pydantic import BaseModel
19
+
20
+ # Import thư viện nội bộ
21
+ from vieneu_tts import VieNeuTTS
22
 
23
  # --- KHỞI TẠO FASTAPI ---
24
  app = FastAPI()
 
26
  print("⏳ Đang khởi động VieNeu-TTS...")
27
 
28
  # --- 1. SETUP MODEL ---
 
29
  device = "cuda" if torch.cuda.is_available() else "cpu"
30
  print(f"🖥️ Sử dụng thiết bị (Global): {device.upper()}")
31
 
 
56
 
57
  # Load Model
58
  try:
59
+ print("📦 Đang tải model vào bộ nhớ...")
60
  tts = VieNeuTTS(
61
  backbone_repo="pnnbao-ump/VieNeu-TTS",
62
  backbone_device=device,
 
84
 
85
  # --- 3. CORE LOGIC (Dùng chung cho cả API và UI) ---
86
 
87
+ # QUAN TRỌNG: Decorator GPU
88
  @spaces.GPU
89
  def core_synthesize(text, voice_choice, speed_factor):
90
  # Lấy thông tin giọng
 
107
  else:
108
  ref_codes = load_cache_from_disk(cache_key)
109
  if ref_codes is None:
110
+ # Đảm bảo dọn dẹp bộ nhớ trước khi encode
111
  if torch.cuda.is_available():
112
+ torch.cuda.empty_cache()
 
 
113
  ref_codes = tts.encode_reference(ref_audio_path)
114
  save_cache_to_disk(cache_key, ref_codes)
115
  reference_cache[cache_key] = ref_codes
116
 
117
  # Infer
118
+ if torch.cuda.is_available():
119
+ torch.cuda.empty_cache()
120
  wav = tts.infer(text, ref_codes, ref_text_raw)
121
 
122
  # Speed
 
140
  # Hàm riêng cho Custom Voice cũng cần GPU
141
  @spaces.GPU
142
  def custom_synthesize_logic(text, ref_audio_path, ref_text_raw):
143
+ if torch.cuda.is_available():
144
+ torch.cuda.empty_cache()
145
  ref_codes = tts.encode_reference(ref_audio_path)
146
  wav = tts.infer(text, ref_codes, ref_text_raw)
147
  return wav
 
180
  raise HTTPException(status_code=500, detail=str(e))
181
 
182
  # --- 5. GRADIO UI SETUP ---
 
183
  theme = gr.themes.Soft()
 
 
184
  css = ".container { max-width: 900px; margin: auto; }"
185
 
186
  def ui_synthesize(text, voice, custom_audio, custom_text, mode, speed):
187
  try:
188
  start = time.time()
 
189
  if mode == "custom_mode":
190
  wav = custom_synthesize_logic(text, custom_audio, custom_text)
 
191
  else:
192
  wav = core_synthesize(text, voice, speed)
193
 
 
219
  out_audio = gr.Audio(label="Kết quả", autoplay=True)
220
  out_status = gr.Textbox(label="Trạng thái")
221
 
 
222
  mode_state = gr.Textbox(visible=False, value="preset_mode")
223
  tabs.children[0].select(lambda: "preset_mode", None, mode_state)
224
  tabs.children[1].select(lambda: "custom_mode", None, mode_state)
 
226
  btn.click(ui_synthesize, [inp_text, inp_voice, inp_audio, inp_ref_text, mode_state, inp_speed], [out_audio, out_status])
227
 
228
  # --- 6. MOUNT GRADIO VÀO FASTAPI ---
 
229
  app = gr.mount_gradio_app(app, demo, path="/")
230
 
231
  # --- 7. CHẠY SERVER ---
232
  if __name__ == "__main__":
233
  import uvicorn
 
234
  uvicorn.run(app, host="0.0.0.0", port=7860)