nhantrungsp commited on
Commit
4a8f5a5
·
verified ·
1 Parent(s): 4196956

Update gradio_app.py

Browse files
Files changed (1) hide show
  1. gradio_app.py +56 -124
gradio_app.py CHANGED
@@ -1,11 +1,9 @@
1
- import spaces # <--- BẮT BUỘC DÒNG 1
2
  import os
3
  import time
4
  import threading
5
  import pickle
6
  import hashlib
7
- import base64
8
- import io
9
  import tempfile
10
  import numpy as np
11
 
@@ -14,19 +12,32 @@ import torch
14
  import soundfile as sf
15
  from pydub import AudioSegment
16
  import gradio as gr
17
- from fastapi import FastAPI, HTTPException
18
- from pydantic import BaseModel
19
  from vieneu_tts import VieNeuTTS
20
 
21
- # --- KHỞI TẠO ---
22
- app = FastAPI()
23
- print("⏳ Đang khởi động Server...")
24
 
25
- # Biến toàn cục để lưu model (Lazy Load)
26
  tts_model = None
27
  model_lock = threading.Lock()
28
 
29
- # Cache
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
30
  CACHE_DIR = "./reference_cache"
31
  os.makedirs(CACHE_DIR, exist_ok=True)
32
  reference_cache = {}
@@ -50,30 +61,7 @@ def save_cache_to_disk(cache_key, ref_codes):
50
  with open(cache_path, 'wb') as f: pickle.dump(ref_codes, f)
51
  except Exception: pass
52
 
53
- # --- HELPER: LOAD MODEL AN TOÀN ---
54
- def get_tts_model():
55
- """Hàm này chỉ tải model khi được gọi lần đầu tiên"""
56
- global tts_model
57
- with model_lock:
58
- if tts_model is None:
59
- print("📦 Đang khởi tạo model lần đầu (Lazy Load)...")
60
- device = "cuda" if torch.cuda.is_available() else "cpu"
61
- print(f" 🖥️ Device: {device}")
62
- try:
63
- # Load model
64
- tts_model = VieNeuTTS(
65
- backbone_repo="pnnbao-ump/VieNeu-TTS",
66
- backbone_device=device,
67
- codec_repo="neuphonic/neucodec",
68
- codec_device=device
69
- )
70
- print(" ✅ Model tải thành công!")
71
- except Exception as e:
72
- print(f" ❌ Lỗi tải model: {e}")
73
- raise e
74
- return tts_model
75
-
76
- # --- DATA ---
77
  VOICE_SAMPLES = {
78
  "Tuyên (nam miền Bắc)": {"audio": "./sample/Tuyên (nam miền Bắc).wav", "text": "./sample/Tuyên (nam miền Bắc).txt"},
79
  "Vĩnh (nam miền Nam)": {"audio": "./sample/Vĩnh (nam miền Nam).wav", "text": "./sample/Vĩnh (nam miền Nam).txt"},
@@ -87,14 +75,20 @@ VOICE_SAMPLES = {
87
  "Nhỏ Ngọt Ngào": {"audio": "./sample/Nhỏ Ngọt Ngào.wav", "text": "./sample/Nhỏ Ngọt Ngào.txt"},
88
  }
89
 
90
- # --- CORE LOGIC (DECORATED WITH @spaces.GPU) ---
91
 
92
- @spaces.GPU(duration=120) # Tăng thời gian timeout lên 120s cho lần đầu load model
93
- def core_synthesize(text, voice_choice, speed_factor):
94
- # 1. Lấy model (Sẽ tải nếu chưa có)
 
 
 
 
 
 
95
  tts = get_tts_model()
96
 
97
- # 2. Đảm bảo model đúng device (GPU)
98
  if torch.cuda.is_available():
99
  try:
100
  if next(tts.backbone.parameters()).device.type != 'cuda':
@@ -105,7 +99,9 @@ def core_synthesize(text, voice_choice, speed_factor):
105
  # 3. Lấy thông tin giọng
106
  voice_info = VOICE_SAMPLES.get(voice_choice)
107
  if not voice_info:
108
- raise ValueError("Giọng không tồn tại")
 
 
109
 
110
  ref_audio_path = voice_info["audio"]
111
  ref_text_path = voice_info["text"]
@@ -113,7 +109,7 @@ def core_synthesize(text, voice_choice, speed_factor):
113
  with open(ref_text_path, "r", encoding="utf-8") as f:
114
  ref_text_raw = f.read()
115
 
116
- # 4. Encode Reference
117
  cache_key = f"preset:{voice_choice}"
118
  with reference_cache_lock:
119
  if cache_key in reference_cache:
@@ -123,18 +119,18 @@ def core_synthesize(text, voice_choice, speed_factor):
123
  else:
124
  ref_codes = load_cache_from_disk(cache_key)
125
  if ref_codes is None:
 
126
  ref_codes = tts.encode_reference(ref_audio_path)
127
- # Cache trên CPU
128
  save_cache_to_disk(cache_key, ref_codes.cpu() if isinstance(ref_codes, torch.Tensor) else ref_codes)
129
 
130
  if isinstance(ref_codes, torch.Tensor) and torch.cuda.is_available():
131
  ref_codes = ref_codes.to("cuda")
132
  reference_cache[cache_key] = ref_codes
133
 
134
- # 5. Infer
135
  wav = tts.infer(text, ref_codes, ref_text_raw)
136
 
137
- # 6. Speed Control (CPU Processing)
138
  if speed_factor != 1.0:
139
  with tempfile.NamedTemporaryFile(suffix='.wav', delete=False) as tmp:
140
  sf.write(tmp.name, wav, 24000)
@@ -149,100 +145,36 @@ def core_synthesize(text, voice_choice, speed_factor):
149
  if sound_stretched.channels == 2:
150
  wav = wav.reshape((-1, 2)).mean(axis=1)
151
  os.unlink(tmp_path)
152
-
153
- return wav
154
 
155
- @spaces.GPU(duration=120)
156
- def custom_synthesize_logic(text, ref_audio_path, ref_text_raw):
157
- tts = get_tts_model()
158
- if torch.cuda.is_available():
159
- try:
160
- if next(tts.backbone.parameters()).device.type != 'cuda':
161
- tts.backbone.to("cuda")
162
- tts.codec.to("cuda")
163
- except: pass
164
 
165
- ref_codes = tts.encode_reference(ref_audio_path)
166
- wav = tts.infer(text, ref_codes, ref_text_raw)
167
- return wav
168
 
169
- # --- API ---
170
- class FastTTSRequest(BaseModel):
171
- text: str
172
- voice_choice: str
173
- speed_factor: float = 1.0
174
- return_base64: bool = False
175
-
176
- @app.get("/voices")
177
- async def get_voices():
178
- return {"voices": list(VOICE_SAMPLES.keys())}
179
-
180
- @app.post("/fast-tts")
181
- async def fast_tts(request: FastTTSRequest):
182
- try:
183
- start = time.time()
184
- # Gọi hàm GPU
185
- wav = core_synthesize(request.text, request.voice_choice, request.speed_factor)
186
- process_time = time.time() - start
187
-
188
- audio_buffer = io.BytesIO()
189
- sf.write(audio_buffer, wav, 24000, format='WAV')
190
- audio_base64 = base64.b64encode(audio_buffer.getvalue()).decode('utf-8')
191
-
192
- return {
193
- "status": "success",
194
- "audio_base64": audio_base64,
195
- "processing_time": process_time
196
- }
197
- except Exception as e:
198
- raise HTTPException(status_code=500, detail=str(e))
199
-
200
- # --- GRADIO UI ---
201
  theme = gr.themes.Soft()
202
  css = ".container { max-width: 900px; margin: auto; }"
203
 
204
- def ui_synthesize(text, voice, custom_audio, custom_text, mode, speed):
205
- try:
206
- start = time.time()
207
- if mode == "custom_mode":
208
- wav = custom_synthesize_logic(text, custom_audio, custom_text)
209
- else:
210
- wav = core_synthesize(text, voice, speed)
211
-
212
- with tempfile.NamedTemporaryFile(delete=False, suffix=".wav") as tmp:
213
- sf.write(tmp.name, wav, 24000)
214
- path = tmp.name
215
- return path, f"✅ Xong! ({time.time()-start:.2f}s)"
216
- except Exception as e:
217
- return None, f"❌ Lỗi: {e}"
218
-
219
  with gr.Blocks(theme=theme, css=css, title="VieNeu-TTS") as demo:
220
- gr.Markdown("# 🎙️ VieNeu-TTS (API + UI)")
221
 
222
  with gr.Row():
223
  with gr.Column():
224
- inp_text = gr.Textbox(label="Văn bản", lines=3, value="Xin chào Việt Nam")
225
- with gr.Tabs() as tabs:
226
- with gr.TabItem("Giọng mẫu", id="preset_mode"):
227
- inp_voice = gr.Dropdown(list(VOICE_SAMPLES.keys()), value="Tuyên (nam miền Bắc)", label="Chọn giọng")
228
- with gr.TabItem("Custom", id="custom_mode"):
229
- inp_audio = gr.Audio(type="filepath")
230
- inp_ref_text = gr.Textbox(label="Lời thoại mẫu")
231
  inp_speed = gr.Slider(0.5, 2.0, value=1.0, label="Tốc độ")
232
  btn = gr.Button("Đọc ngay", variant="primary")
 
233
  with gr.Column():
234
  out_audio = gr.Audio(label="Kết quả", autoplay=True)
235
  out_status = gr.Textbox(label="Trạng thái")
 
 
 
236
 
237
- mode_state = gr.Textbox(visible=False, value="preset_mode")
238
- tabs.children[0].select(lambda: "preset_mode", None, mode_state)
239
- tabs.children[1].select(lambda: "custom_mode", None, mode_state)
240
- btn.click(ui_synthesize, [inp_text, inp_voice, inp_audio, inp_ref_text, mode_state, inp_speed], [out_audio, out_status])
241
-
242
- # Mount Gradio vào FastAPI
243
- app = gr.mount_gradio_app(app, demo, path="/")
244
-
245
  if __name__ == "__main__":
246
- import uvicorn
247
- # Mở port 7860 để Hugging Face truy cập
248
- uvicorn.run(app, host="0.0.0.0", port=7860)
 
1
+ import spaces # <--- BẮT BUỘC DÒNG 1
2
  import os
3
  import time
4
  import threading
5
  import pickle
6
  import hashlib
 
 
7
  import tempfile
8
  import numpy as np
9
 
 
12
  import soundfile as sf
13
  from pydub import AudioSegment
14
  import gradio as gr
 
 
15
  from vieneu_tts import VieNeuTTS
16
 
17
+ print("⏳ Đang khởi động Server Gradio...")
 
 
18
 
19
+ # --- 1. QUẢN MODEL (Lazy Loading) ---
20
  tts_model = None
21
  model_lock = threading.Lock()
22
 
23
+ def get_tts_model():
24
+ """Chỉ tải model khi có người dùng gọi (Tiết kiệm tài nguyên khởi động)"""
25
+ global tts_model
26
+ with model_lock:
27
+ if tts_model is None:
28
+ print("📦 Đang khởi tạo model lần đầu (Lazy Load)...")
29
+ # ZeroGPU yêu cầu khởi tạo model trên CPU hoặc trong hàm @spaces.GPU
30
+ # Ở đây ta khởi tạo trên CPU cho an toàn
31
+ tts_model = VieNeuTTS(
32
+ backbone_repo="pnnbao-ump/VieNeu-TTS",
33
+ backbone_device="cpu",
34
+ codec_repo="neuphonic/neucodec",
35
+ codec_device="cpu"
36
+ )
37
+ print("✅ Model tải thành công!")
38
+ return tts_model
39
+
40
+ # --- 2. XỬ LÝ CACHE ---
41
  CACHE_DIR = "./reference_cache"
42
  os.makedirs(CACHE_DIR, exist_ok=True)
43
  reference_cache = {}
 
61
  with open(cache_path, 'wb') as f: pickle.dump(ref_codes, f)
62
  except Exception: pass
63
 
64
+ # --- 3. DỮ LIỆU GIỌNG NÓI ---
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
65
  VOICE_SAMPLES = {
66
  "Tuyên (nam miền Bắc)": {"audio": "./sample/Tuyên (nam miền Bắc).wav", "text": "./sample/Tuyên (nam miền Bắc).txt"},
67
  "Vĩnh (nam miền Nam)": {"audio": "./sample/Vĩnh (nam miền Nam).wav", "text": "./sample/Vĩnh (nam miền Nam).txt"},
 
75
  "Nhỏ Ngọt Ngào": {"audio": "./sample/Nhỏ Ngọt Ngào.wav", "text": "./sample/Nhỏ Ngọt Ngào.txt"},
76
  }
77
 
78
+ # --- 4. HÀM XỬ CHÍNH (GPU) ---
79
 
80
+ @spaces.GPU(duration=120)
81
+ def generate_speech(text, voice_choice, speed_factor):
82
+ """
83
+ Hàm này sẽ được ZeroGPU cấp phát GPU khi chạy.
84
+ Nó cũng đóng vai trò là API endpoint chính.
85
+ """
86
+ start_time = time.time()
87
+
88
+ # 1. Lấy Model (Tải nếu chưa có)
89
  tts = get_tts_model()
90
 
91
+ # 2. Chuyển Model sang GPU (Chỉ làm trong hàm này)
92
  if torch.cuda.is_available():
93
  try:
94
  if next(tts.backbone.parameters()).device.type != 'cuda':
 
99
  # 3. Lấy thông tin giọng
100
  voice_info = VOICE_SAMPLES.get(voice_choice)
101
  if not voice_info:
102
+ # Fallback nếu không tìm thấy giọng
103
+ voice_choice = "Tuyên (nam miền Bắc)"
104
+ voice_info = VOICE_SAMPLES[voice_choice]
105
 
106
  ref_audio_path = voice_info["audio"]
107
  ref_text_path = voice_info["text"]
 
109
  with open(ref_text_path, "r", encoding="utf-8") as f:
110
  ref_text_raw = f.read()
111
 
112
+ # 4. Encode Reference (Có Cache)
113
  cache_key = f"preset:{voice_choice}"
114
  with reference_cache_lock:
115
  if cache_key in reference_cache:
 
119
  else:
120
  ref_codes = load_cache_from_disk(cache_key)
121
  if ref_codes is None:
122
+ # Encode
123
  ref_codes = tts.encode_reference(ref_audio_path)
 
124
  save_cache_to_disk(cache_key, ref_codes.cpu() if isinstance(ref_codes, torch.Tensor) else ref_codes)
125
 
126
  if isinstance(ref_codes, torch.Tensor) and torch.cuda.is_available():
127
  ref_codes = ref_codes.to("cuda")
128
  reference_cache[cache_key] = ref_codes
129
 
130
+ # 5. Infer (Tạo giọng nói)
131
  wav = tts.infer(text, ref_codes, ref_text_raw)
132
 
133
+ # 6. Xử tốc độ (Speed)
134
  if speed_factor != 1.0:
135
  with tempfile.NamedTemporaryFile(suffix='.wav', delete=False) as tmp:
136
  sf.write(tmp.name, wav, 24000)
 
145
  if sound_stretched.channels == 2:
146
  wav = wav.reshape((-1, 2)).mean(axis=1)
147
  os.unlink(tmp_path)
 
 
148
 
149
+ # 7. Lưu file kết quả
150
+ with tempfile.NamedTemporaryFile(delete=False, suffix=".wav") as tmp_file:
151
+ sf.write(tmp_file.name, wav, 24000)
152
+ output_path = tmp_file.name
 
 
 
 
 
153
 
154
+ return output_path, f"✅ Hoàn tất ({time.time() - start_time:.2f}s)"
 
 
155
 
156
+ # --- 5. GIAO DIỆN GRADIO ---
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
157
  theme = gr.themes.Soft()
158
  css = ".container { max-width: 900px; margin: auto; }"
159
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
160
  with gr.Blocks(theme=theme, css=css, title="VieNeu-TTS") as demo:
161
+ gr.Markdown("# 🎙️ VieNeu-TTS (ZeroGPU)")
162
 
163
  with gr.Row():
164
  with gr.Column():
165
+ inp_text = gr.Textbox(label="Văn bản", lines=3, value="Xin chào Việt Nam, đây là thử nghiệm giọng nói.")
166
+ inp_voice = gr.Dropdown(list(VOICE_SAMPLES.keys()), value="Tuyên (nam miền Bắc)", label="Chọn giọng")
 
 
 
 
 
167
  inp_speed = gr.Slider(0.5, 2.0, value=1.0, label="Tốc độ")
168
  btn = gr.Button("Đọc ngay", variant="primary")
169
+
170
  with gr.Column():
171
  out_audio = gr.Audio(label="Kết quả", autoplay=True)
172
  out_status = gr.Textbox(label="Trạng thái")
173
+
174
+ # Map function vào button
175
+ btn.click(generate_speech, [inp_text, inp_voice, inp_speed], [out_audio, out_status])
176
 
177
+ # --- 6. KHỞI CHẠY ---
 
 
 
 
 
 
 
178
  if __name__ == "__main__":
179
+ # Dùng demo.launch() chuẩn để ZeroGPU nhận diện được
180
+ demo.queue(default_concurrency_limit=40).launch(server_name="0.0.0.0", server_port=7860)