nhantrungsp commited on
Commit
d593d54
·
verified ·
1 Parent(s): d46de93

Update gradio_app.py

Browse files
Files changed (1) hide show
  1. gradio_app.py +58 -46
gradio_app.py CHANGED
@@ -1,4 +1,4 @@
1
- import spaces # <--- QUAN TRỌNG: PHẢI ĐỂ DÒNG ĐẦU TIÊN
2
  import os
3
  import time
4
  import threading
@@ -9,15 +9,13 @@ import io
9
  import tempfile
10
  import numpy as np
11
 
12
- # Các thư viện khác import sau spaces
13
  import torch
14
  import soundfile as sf
15
  from pydub import AudioSegment
16
  import gradio as gr
17
  from fastapi import FastAPI, HTTPException
18
  from pydantic import BaseModel
19
-
20
- # Import thư viện nội bộ
21
  from vieneu_tts import VieNeuTTS
22
 
23
  # --- KHỞI TẠO FASTAPI ---
@@ -25,9 +23,11 @@ app = FastAPI()
25
 
26
  print("⏳ Đang khởi động VieNeu-TTS...")
27
 
28
- # --- 1. SETUP MODEL ---
29
- device = "cuda" if torch.cuda.is_available() else "cpu"
30
- print(f"🖥️ Sử dụng thiết bị (Global): {device.upper()}")
 
 
31
 
32
  # Cache
33
  CACHE_DIR = "./reference_cache"
@@ -35,7 +35,6 @@ os.makedirs(CACHE_DIR, exist_ok=True)
35
  reference_cache = {}
36
  reference_cache_lock = threading.Lock()
37
 
38
- # Hàm Cache Helper
39
  def get_cache_path(cache_key):
40
  key_hash = hashlib.md5(cache_key.encode()).hexdigest()
41
  return os.path.join(CACHE_DIR, f"{key_hash}.pkl")
@@ -54,16 +53,16 @@ def save_cache_to_disk(cache_key, ref_codes):
54
  with open(cache_path, 'wb') as f: pickle.dump(ref_codes, f)
55
  except Exception: pass
56
 
57
- # Load Model
58
  try:
59
- print("📦 Đang tải model vào bộ nhớ...")
60
  tts = VieNeuTTS(
61
  backbone_repo="pnnbao-ump/VieNeu-TTS",
62
- backbone_device=device,
63
  codec_repo="neuphonic/neucodec",
64
- codec_device=device
65
  )
66
- print("✅ Model đã tải xong!")
67
  except Exception as e:
68
  print(f"⚠️ Lỗi tải model: {e}")
69
  tts = None
@@ -82,12 +81,31 @@ VOICE_SAMPLES = {
82
  "Nhỏ Ngọt Ngào": {"audio": "./sample/Nhỏ Ngọt Ngào.wav", "text": "./sample/Nhỏ Ngọt Ngào.txt"},
83
  }
84
 
85
- # --- 3. CORE LOGIC (Dùng chung cho cả API và UI) ---
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
86
 
87
- # QUAN TRỌNG: Decorator GPU
88
  @spaces.GPU
89
  def core_synthesize(text, voice_choice, speed_factor):
90
- # Lấy thông tin giọng
 
 
 
91
  voice_info = VOICE_SAMPLES.get(voice_choice)
92
  if not voice_info:
93
  raise ValueError("Giọng không tồn tại")
@@ -95,41 +113,41 @@ def core_synthesize(text, voice_choice, speed_factor):
95
  ref_audio_path = voice_info["audio"]
96
  ref_text_path = voice_info["text"]
97
 
98
- # Load reference text
99
  with open(ref_text_path, "r", encoding="utf-8") as f:
100
  ref_text_raw = f.read()
101
 
102
- # Encode reference (Cache logic)
103
  cache_key = f"preset:{voice_choice}"
104
  with reference_cache_lock:
105
  if cache_key in reference_cache:
106
  ref_codes = reference_cache[cache_key]
 
 
 
107
  else:
108
  ref_codes = load_cache_from_disk(cache_key)
109
  if ref_codes is None:
110
- # Đảm bảo dọn dẹp bộ nhớ trước khi encode
111
- if torch.cuda.is_available():
112
- torch.cuda.empty_cache()
113
- ref_codes = tts.encode_reference(ref_audio_path)
114
- save_cache_to_disk(cache_key, ref_codes)
 
 
115
  reference_cache[cache_key] = ref_codes
116
 
117
- # Infer
118
- if torch.cuda.is_available():
119
- torch.cuda.empty_cache()
120
  wav = tts.infer(text, ref_codes, ref_text_raw)
121
-
122
- # Speed
123
  if speed_factor != 1.0:
124
  with tempfile.NamedTemporaryFile(suffix='.wav', delete=False) as tmp:
125
  sf.write(tmp.name, wav, 24000)
126
  tmp_path = tmp.name
127
-
128
  sound = AudioSegment.from_wav(tmp_path)
129
  new_frame_rate = int(sound.frame_rate * speed_factor)
130
  sound_stretched = sound._spawn(sound.raw_data, overrides={'frame_rate': new_frame_rate})
131
  sound_stretched = sound_stretched.set_frame_rate(24000)
132
-
133
  wav = np.array(sound_stretched.get_array_of_samples()).astype(np.float32) / 32768.0
134
  if sound_stretched.channels == 2:
135
  wav = wav.reshape((-1, 2)).mean(axis=1)
@@ -137,16 +155,17 @@ def core_synthesize(text, voice_choice, speed_factor):
137
 
138
  return wav
139
 
140
- # Hàm riêng cho Custom Voice cũng cần GPU
141
  @spaces.GPU
142
  def custom_synthesize_logic(text, ref_audio_path, ref_text_raw):
143
- if torch.cuda.is_available():
144
- torch.cuda.empty_cache()
 
 
145
  ref_codes = tts.encode_reference(ref_audio_path)
146
  wav = tts.infer(text, ref_codes, ref_text_raw)
147
  return wav
148
 
149
- # --- 4. API ENDPOINTS (Cho Client App kết nối) ---
150
  class FastTTSRequest(BaseModel):
151
  text: str
152
  voice_choice: str
@@ -161,15 +180,14 @@ async def get_voices():
161
  async def fast_tts(request: FastTTSRequest):
162
  try:
163
  start = time.time()
164
- # Gọi hàm đã được decorate @spaces.GPU
165
  wav = core_synthesize(request.text, request.voice_choice, request.speed_factor)
166
  process_time = time.time() - start
167
 
168
- # Convert to Base64
169
  audio_buffer = io.BytesIO()
170
  sf.write(audio_buffer, wav, 24000, format='WAV')
171
- audio_bytes = audio_buffer.getvalue()
172
- audio_base64 = base64.b64encode(audio_bytes).decode('utf-8')
173
 
174
  return {
175
  "status": "success",
@@ -179,7 +197,7 @@ async def fast_tts(request: FastTTSRequest):
179
  except Exception as e:
180
  raise HTTPException(status_code=500, detail=str(e))
181
 
182
- # --- 5. GRADIO UI SETUP ---
183
  theme = gr.themes.Soft()
184
  css = ".container { max-width: 900px; margin: auto; }"
185
 
@@ -190,7 +208,7 @@ def ui_synthesize(text, voice, custom_audio, custom_text, mode, speed):
190
  wav = custom_synthesize_logic(text, custom_audio, custom_text)
191
  else:
192
  wav = core_synthesize(text, voice, speed)
193
-
194
  with tempfile.NamedTemporaryFile(delete=False, suffix=".wav") as tmp:
195
  sf.write(tmp.name, wav, 24000)
196
  path = tmp.name
@@ -204,17 +222,14 @@ with gr.Blocks(theme=theme, css=css, title="VieNeu-TTS") as demo:
204
  with gr.Row():
205
  with gr.Column():
206
  inp_text = gr.Textbox(label="Văn bản", lines=3, value="Xin chào Việt Nam")
207
-
208
  with gr.Tabs() as tabs:
209
  with gr.TabItem("Giọng mẫu", id="preset_mode"):
210
  inp_voice = gr.Dropdown(list(VOICE_SAMPLES.keys()), value="Tuyên (nam miền Bắc)", label="Chọn giọng")
211
  with gr.TabItem("Custom", id="custom_mode"):
212
  inp_audio = gr.Audio(type="filepath")
213
  inp_ref_text = gr.Textbox(label="Lời thoại mẫu")
214
-
215
  inp_speed = gr.Slider(0.5, 2.0, value=1.0, label="Tốc độ")
216
  btn = gr.Button("Đọc ngay", variant="primary")
217
-
218
  with gr.Column():
219
  out_audio = gr.Audio(label="Kết quả", autoplay=True)
220
  out_status = gr.Textbox(label="Trạng thái")
@@ -222,13 +237,10 @@ with gr.Blocks(theme=theme, css=css, title="VieNeu-TTS") as demo:
222
  mode_state = gr.Textbox(visible=False, value="preset_mode")
223
  tabs.children[0].select(lambda: "preset_mode", None, mode_state)
224
  tabs.children[1].select(lambda: "custom_mode", None, mode_state)
225
-
226
  btn.click(ui_synthesize, [inp_text, inp_voice, inp_audio, inp_ref_text, mode_state, inp_speed], [out_audio, out_status])
227
 
228
- # --- 6. MOUNT GRADIO VÀO FASTAPI ---
229
  app = gr.mount_gradio_app(app, demo, path="/")
230
 
231
- # --- 7. CHẠY SERVER ---
232
  if __name__ == "__main__":
233
  import uvicorn
234
  uvicorn.run(app, host="0.0.0.0", port=7860)
 
1
+ import spaces # <--- LUÔN ĐỂ ĐẦU TIÊN
2
  import os
3
  import time
4
  import threading
 
9
  import tempfile
10
  import numpy as np
11
 
12
+ # Import các thư viện khác
13
  import torch
14
  import soundfile as sf
15
  from pydub import AudioSegment
16
  import gradio as gr
17
  from fastapi import FastAPI, HTTPException
18
  from pydantic import BaseModel
 
 
19
  from vieneu_tts import VieNeuTTS
20
 
21
  # --- KHỞI TẠO FASTAPI ---
 
23
 
24
  print("⏳ Đang khởi động VieNeu-TTS...")
25
 
26
+ # --- 1. SETUP MODEL (SỬA LẠI CHO ZEROGPU) ---
27
+ # QUAN TRỌNG: Trên ZeroGPU, lúc khởi động PHẢI DÙNG CPU
28
+ # GPU chỉ được kích hoạt bên trong hàm @spaces.GPU
29
+ device = "cpu"
30
+ print(f"🖥️ Thiết bị khởi động (Global): {device.upper()} (Sẽ chuyển sang CUDA khi chạy)")
31
 
32
  # Cache
33
  CACHE_DIR = "./reference_cache"
 
35
  reference_cache = {}
36
  reference_cache_lock = threading.Lock()
37
 
 
38
  def get_cache_path(cache_key):
39
  key_hash = hashlib.md5(cache_key.encode()).hexdigest()
40
  return os.path.join(CACHE_DIR, f"{key_hash}.pkl")
 
53
  with open(cache_path, 'wb') as f: pickle.dump(ref_codes, f)
54
  except Exception: pass
55
 
56
+ # Load Model vào CPU trước
57
  try:
58
+ print("📦 Đang tải model vào RAM (CPU)...")
59
  tts = VieNeuTTS(
60
  backbone_repo="pnnbao-ump/VieNeu-TTS",
61
+ backbone_device="cpu", # Bắt buộc là CPU
62
  codec_repo="neuphonic/neucodec",
63
+ codec_device="cpu" # Bắt buộc là CPU
64
  )
65
+ print("✅ Model đã tải xong (Ready on CPU)!")
66
  except Exception as e:
67
  print(f"⚠️ Lỗi tải model: {e}")
68
  tts = None
 
81
  "Nhỏ Ngọt Ngào": {"audio": "./sample/Nhỏ Ngọt Ngào.wav", "text": "./sample/Nhỏ Ngọt Ngào.txt"},
82
  }
83
 
84
+ # --- 3. CORE LOGIC (ZeroGPU Optimization) ---
85
+
86
+ def move_model_to_cuda():
87
+ """Hàm helper để đẩy model sang GPU khi cần"""
88
+ if torch.cuda.is_available():
89
+ # Kiểm tra xem model đã ở trên GPU chưa để tránh move thừa
90
+ # VieNeuTTS lưu model trong self.backbone và self.codec
91
+ try:
92
+ # Move backbone
93
+ if next(tts.backbone.parameters()).device.type != 'cuda':
94
+ print(" 🚀 Moving model to GPU...")
95
+ tts.backbone.to("cuda")
96
+
97
+ # Move codec
98
+ if next(tts.codec.parameters()).device.type != 'cuda':
99
+ tts.codec.to("cuda")
100
+ except Exception as e:
101
+ print(f"⚠️ Lỗi khi move model sang GPU: {e}")
102
 
 
103
  @spaces.GPU
104
  def core_synthesize(text, voice_choice, speed_factor):
105
+ # 1. Đẩy model sang GPU (Chỉ làm việc này bên trong hàm @spaces.GPU)
106
+ move_model_to_cuda()
107
+
108
+ # 2. Lấy thông tin giọng
109
  voice_info = VOICE_SAMPLES.get(voice_choice)
110
  if not voice_info:
111
  raise ValueError("Giọng không tồn tại")
 
113
  ref_audio_path = voice_info["audio"]
114
  ref_text_path = voice_info["text"]
115
 
 
116
  with open(ref_text_path, "r", encoding="utf-8") as f:
117
  ref_text_raw = f.read()
118
 
119
+ # 3. Encode Reference
120
  cache_key = f"preset:{voice_choice}"
121
  with reference_cache_lock:
122
  if cache_key in reference_cache:
123
  ref_codes = reference_cache[cache_key]
124
+ # Đảm bảo ref_codes cũng ở trên GPU
125
+ if isinstance(ref_codes, torch.Tensor):
126
+ ref_codes = ref_codes.to("cuda")
127
  else:
128
  ref_codes = load_cache_from_disk(cache_key)
129
  if ref_codes is None:
130
+ ref_codes = tts.encode_reference(ref_audio_path) # Lúc này model đã GPU nên encode sẽ nhanh
131
+ # Move về CPU để cache
132
+ save_cache_to_disk(cache_key, ref_codes.cpu() if isinstance(ref_codes, torch.Tensor) else ref_codes)
133
+
134
+ # Đẩy lại lên GPU để dùng
135
+ if isinstance(ref_codes, torch.Tensor):
136
+ ref_codes = ref_codes.to("cuda")
137
  reference_cache[cache_key] = ref_codes
138
 
139
+ # 4. Infer
 
 
140
  wav = tts.infer(text, ref_codes, ref_text_raw)
141
+
142
+ # 5. Speed Control (CPU)
143
  if speed_factor != 1.0:
144
  with tempfile.NamedTemporaryFile(suffix='.wav', delete=False) as tmp:
145
  sf.write(tmp.name, wav, 24000)
146
  tmp_path = tmp.name
 
147
  sound = AudioSegment.from_wav(tmp_path)
148
  new_frame_rate = int(sound.frame_rate * speed_factor)
149
  sound_stretched = sound._spawn(sound.raw_data, overrides={'frame_rate': new_frame_rate})
150
  sound_stretched = sound_stretched.set_frame_rate(24000)
 
151
  wav = np.array(sound_stretched.get_array_of_samples()).astype(np.float32) / 32768.0
152
  if sound_stretched.channels == 2:
153
  wav = wav.reshape((-1, 2)).mean(axis=1)
 
155
 
156
  return wav
157
 
 
158
  @spaces.GPU
159
  def custom_synthesize_logic(text, ref_audio_path, ref_text_raw):
160
+ # 1. Đẩy model sang GPU
161
+ move_model_to_cuda()
162
+
163
+ # 2. Xử lý
164
  ref_codes = tts.encode_reference(ref_audio_path)
165
  wav = tts.infer(text, ref_codes, ref_text_raw)
166
  return wav
167
 
168
+ # --- 4. API ---
169
  class FastTTSRequest(BaseModel):
170
  text: str
171
  voice_choice: str
 
180
  async def fast_tts(request: FastTTSRequest):
181
  try:
182
  start = time.time()
183
+ # Gọi hàm đã decorate
184
  wav = core_synthesize(request.text, request.voice_choice, request.speed_factor)
185
  process_time = time.time() - start
186
 
187
+ # Base64
188
  audio_buffer = io.BytesIO()
189
  sf.write(audio_buffer, wav, 24000, format='WAV')
190
+ audio_base64 = base64.b64encode(audio_buffer.getvalue()).decode('utf-8')
 
191
 
192
  return {
193
  "status": "success",
 
197
  except Exception as e:
198
  raise HTTPException(status_code=500, detail=str(e))
199
 
200
+ # --- 5. UI ---
201
  theme = gr.themes.Soft()
202
  css = ".container { max-width: 900px; margin: auto; }"
203
 
 
208
  wav = custom_synthesize_logic(text, custom_audio, custom_text)
209
  else:
210
  wav = core_synthesize(text, voice, speed)
211
+
212
  with tempfile.NamedTemporaryFile(delete=False, suffix=".wav") as tmp:
213
  sf.write(tmp.name, wav, 24000)
214
  path = tmp.name
 
222
  with gr.Row():
223
  with gr.Column():
224
  inp_text = gr.Textbox(label="Văn bản", lines=3, value="Xin chào Việt Nam")
 
225
  with gr.Tabs() as tabs:
226
  with gr.TabItem("Giọng mẫu", id="preset_mode"):
227
  inp_voice = gr.Dropdown(list(VOICE_SAMPLES.keys()), value="Tuyên (nam miền Bắc)", label="Chọn giọng")
228
  with gr.TabItem("Custom", id="custom_mode"):
229
  inp_audio = gr.Audio(type="filepath")
230
  inp_ref_text = gr.Textbox(label="Lời thoại mẫu")
 
231
  inp_speed = gr.Slider(0.5, 2.0, value=1.0, label="Tốc độ")
232
  btn = gr.Button("Đọc ngay", variant="primary")
 
233
  with gr.Column():
234
  out_audio = gr.Audio(label="Kết quả", autoplay=True)
235
  out_status = gr.Textbox(label="Trạng thái")
 
237
  mode_state = gr.Textbox(visible=False, value="preset_mode")
238
  tabs.children[0].select(lambda: "preset_mode", None, mode_state)
239
  tabs.children[1].select(lambda: "custom_mode", None, mode_state)
 
240
  btn.click(ui_synthesize, [inp_text, inp_voice, inp_audio, inp_ref_text, mode_state, inp_speed], [out_audio, out_status])
241
 
 
242
  app = gr.mount_gradio_app(app, demo, path="/")
243
 
 
244
  if __name__ == "__main__":
245
  import uvicorn
246
  uvicorn.run(app, host="0.0.0.0", port=7860)