nhantrungsp committed on
Commit
4196956
·
verified ·
1 Parent(s): d593d54

Update gradio_app.py

Browse files
Files changed (1) hide show
  1. gradio_app.py +63 -61
gradio_app.py CHANGED
@@ -1,4 +1,4 @@
1
- import spaces # <--- LUÔN ĐỂ ĐẦU TIÊN
2
  import os
3
  import time
4
  import threading
@@ -9,7 +9,7 @@ import io
9
  import tempfile
10
  import numpy as np
11
 
12
- # Import các thư viện khác
13
  import torch
14
  import soundfile as sf
15
  from pydub import AudioSegment
@@ -18,16 +18,13 @@ from fastapi import FastAPI, HTTPException
18
  from pydantic import BaseModel
19
  from vieneu_tts import VieNeuTTS
20
 
21
- # --- KHỞI TẠO FASTAPI ---
22
  app = FastAPI()
 
23
 
24
- print("⏳ Đang khởi động VieNeu-TTS...")
25
-
26
- # --- 1. SETUP MODEL (SỬA LẠI CHO ZEROGPU) ---
27
- # QUAN TRỌNG: Trên ZeroGPU, lúc khởi động PHẢI DÙNG CPU
28
- # GPU chỉ được kích hoạt bên trong hàm @spaces.GPU
29
- device = "cpu"
30
- print(f"🖥️ Thiết bị khởi động (Global): {device.upper()} (Sẽ chuyển sang CUDA khi chạy)")
31
 
32
  # Cache
33
  CACHE_DIR = "./reference_cache"
@@ -53,21 +50,30 @@ def save_cache_to_disk(cache_key, ref_codes):
53
  with open(cache_path, 'wb') as f: pickle.dump(ref_codes, f)
54
  except Exception: pass
55
 
56
- # Load Model vào CPU trước
57
- try:
58
- print("📦 Đang tải model vào RAM (CPU)...")
59
- tts = VieNeuTTS(
60
- backbone_repo="pnnbao-ump/VieNeu-TTS",
61
- backbone_device="cpu", # Bắt buộc là CPU
62
- codec_repo="neuphonic/neucodec",
63
- codec_device="cpu" # Bắt buộc CPU
64
- )
65
- print("✅ Model đã tải xong (Ready on CPU)!")
66
- except Exception as e:
67
- print(f"⚠️ Lỗi tải model: {e}")
68
- tts = None
69
-
70
- # --- 2. DATA ---
 
 
 
 
 
 
 
 
 
71
  VOICE_SAMPLES = {
72
  "Tuyên (nam miền Bắc)": {"audio": "./sample/Tuyên (nam miền Bắc).wav", "text": "./sample/Tuyên (nam miền Bắc).txt"},
73
  "Vĩnh (nam miền Nam)": {"audio": "./sample/Vĩnh (nam miền Nam).wav", "text": "./sample/Vĩnh (nam miền Nam).txt"},
@@ -81,31 +87,22 @@ VOICE_SAMPLES = {
81
  "Nhỏ Ngọt Ngào": {"audio": "./sample/Nhỏ Ngọt Ngào.wav", "text": "./sample/Nhỏ Ngọt Ngào.txt"},
82
  }
83
 
84
- # --- 3. CORE LOGIC (ZeroGPU Optimization) ---
85
 
86
- def move_model_to_cuda():
87
- """Hàm helper để đẩy model sang GPU khi cần"""
 
 
 
 
88
  if torch.cuda.is_available():
89
- # Kiểm tra xem model đã ở trên GPU chưa để tránh move thừa
90
- # VieNeuTTS lưu model trong self.backbone và self.codec
91
  try:
92
- # Move backbone
93
  if next(tts.backbone.parameters()).device.type != 'cuda':
94
- print(" 🚀 Moving model to GPU...")
95
  tts.backbone.to("cuda")
96
-
97
- # Move codec
98
- if next(tts.codec.parameters()).device.type != 'cuda':
99
  tts.codec.to("cuda")
100
- except Exception as e:
101
- print(f"⚠️ Lỗi khi move model sang GPU: {e}")
102
 
103
- @spaces.GPU
104
- def core_synthesize(text, voice_choice, speed_factor):
105
- # 1. Đẩy model sang GPU (Chỉ làm việc này bên trong hàm @spaces.GPU)
106
- move_model_to_cuda()
107
-
108
- # 2. Lấy thông tin giọng
109
  voice_info = VOICE_SAMPLES.get(voice_choice)
110
  if not voice_info:
111
  raise ValueError("Giọng không tồn tại")
@@ -116,38 +113,38 @@ def core_synthesize(text, voice_choice, speed_factor):
116
  with open(ref_text_path, "r", encoding="utf-8") as f:
117
  ref_text_raw = f.read()
118
 
119
- # 3. Encode Reference
120
  cache_key = f"preset:{voice_choice}"
121
  with reference_cache_lock:
122
  if cache_key in reference_cache:
123
  ref_codes = reference_cache[cache_key]
124
- # Đảm bảo ref_codes cũng trên GPU
125
- if isinstance(ref_codes, torch.Tensor):
126
  ref_codes = ref_codes.to("cuda")
127
  else:
128
  ref_codes = load_cache_from_disk(cache_key)
129
  if ref_codes is None:
130
- ref_codes = tts.encode_reference(ref_audio_path) # Lúc này model đã ở GPU nên encode sẽ nhanh
131
- # Move về CPU để cache
132
  save_cache_to_disk(cache_key, ref_codes.cpu() if isinstance(ref_codes, torch.Tensor) else ref_codes)
133
 
134
- # Đẩy lại lên GPU để dùng
135
- if isinstance(ref_codes, torch.Tensor):
136
  ref_codes = ref_codes.to("cuda")
137
  reference_cache[cache_key] = ref_codes
138
 
139
- # 4. Infer
140
  wav = tts.infer(text, ref_codes, ref_text_raw)
141
 
142
- # 5. Speed Control (CPU)
143
  if speed_factor != 1.0:
144
  with tempfile.NamedTemporaryFile(suffix='.wav', delete=False) as tmp:
145
  sf.write(tmp.name, wav, 24000)
146
  tmp_path = tmp.name
 
147
  sound = AudioSegment.from_wav(tmp_path)
148
  new_frame_rate = int(sound.frame_rate * speed_factor)
149
  sound_stretched = sound._spawn(sound.raw_data, overrides={'frame_rate': new_frame_rate})
150
  sound_stretched = sound_stretched.set_frame_rate(24000)
 
151
  wav = np.array(sound_stretched.get_array_of_samples()).astype(np.float32) / 32768.0
152
  if sound_stretched.channels == 2:
153
  wav = wav.reshape((-1, 2)).mean(axis=1)
@@ -155,17 +152,21 @@ def core_synthesize(text, voice_choice, speed_factor):
155
 
156
  return wav
157
 
158
- @spaces.GPU
159
  def custom_synthesize_logic(text, ref_audio_path, ref_text_raw):
160
- # 1. Đẩy model sang GPU
161
- move_model_to_cuda()
162
-
163
- # 2. Xử
 
 
 
 
164
  ref_codes = tts.encode_reference(ref_audio_path)
165
  wav = tts.infer(text, ref_codes, ref_text_raw)
166
  return wav
167
 
168
- # --- 4. API ---
169
  class FastTTSRequest(BaseModel):
170
  text: str
171
  voice_choice: str
@@ -180,11 +181,10 @@ async def get_voices():
180
  async def fast_tts(request: FastTTSRequest):
181
  try:
182
  start = time.time()
183
- # Gọi hàm đã decorate
184
  wav = core_synthesize(request.text, request.voice_choice, request.speed_factor)
185
  process_time = time.time() - start
186
 
187
- # Base64
188
  audio_buffer = io.BytesIO()
189
  sf.write(audio_buffer, wav, 24000, format='WAV')
190
  audio_base64 = base64.b64encode(audio_buffer.getvalue()).decode('utf-8')
@@ -197,7 +197,7 @@ async def fast_tts(request: FastTTSRequest):
197
  except Exception as e:
198
  raise HTTPException(status_code=500, detail=str(e))
199
 
200
- # --- 5. UI ---
201
  theme = gr.themes.Soft()
202
  css = ".container { max-width: 900px; margin: auto; }"
203
 
@@ -239,8 +239,10 @@ with gr.Blocks(theme=theme, css=css, title="VieNeu-TTS") as demo:
239
  tabs.children[1].select(lambda: "custom_mode", None, mode_state)
240
  btn.click(ui_synthesize, [inp_text, inp_voice, inp_audio, inp_ref_text, mode_state, inp_speed], [out_audio, out_status])
241
 
 
242
  app = gr.mount_gradio_app(app, demo, path="/")
243
 
244
  if __name__ == "__main__":
245
  import uvicorn
 
246
  uvicorn.run(app, host="0.0.0.0", port=7860)
 
1
+ import spaces # <--- BẮT BUỘC DÒNG 1
2
  import os
3
  import time
4
  import threading
 
9
  import tempfile
10
  import numpy as np
11
 
12
+ # Các thư viện khác
13
  import torch
14
  import soundfile as sf
15
  from pydub import AudioSegment
 
18
  from pydantic import BaseModel
19
  from vieneu_tts import VieNeuTTS
20
 
21
+ # --- KHỞI TẠO ---
22
  app = FastAPI()
23
+ print("⏳ Đang khởi động Server...")
24
 
25
+ # Biến toàn cục để lưu model (Lazy Load)
26
+ tts_model = None
27
+ model_lock = threading.Lock()
 
 
 
 
28
 
29
  # Cache
30
  CACHE_DIR = "./reference_cache"
 
50
  with open(cache_path, 'wb') as f: pickle.dump(ref_codes, f)
51
  except Exception: pass
52
 
53
+ # --- HELPER: LOAD MODEL AN TOÀN ---
54
def get_tts_model():
    """Return the shared VieNeuTTS instance, building it on the first call.

    The instance is kept in the module-level ``tts_model`` variable. The
    check-and-create step runs while holding ``model_lock``. The device
    string passed to ``VieNeuTTS`` is "cuda" when
    ``torch.cuda.is_available()`` is true at the time of the first call,
    otherwise "cpu".

    Returns:
        The cached ``VieNeuTTS`` instance.

    Raises:
        Exception: whatever ``VieNeuTTS(...)`` raises. The error is printed
            and re-raised; ``tts_model`` is left as ``None`` in that case,
            so the next call attempts the load again.
    """
    global tts_model
    with model_lock:
        if tts_model is None:
            print("📦 Đang khởi tạo model lần đầu (Lazy Load)...")
            device = "cuda" if torch.cuda.is_available() else "cpu"
            print(f" 🖥️ Device: {device}")
            try:
                tts_model = VieNeuTTS(
                    backbone_repo="pnnbao-ump/VieNeu-TTS",
                    backbone_device=device,
                    codec_repo="neuphonic/neucodec",
                    codec_device=device,
                )
                print(" ✅ Model tải thành công!")
            except Exception as e:
                print(f" ❌ Lỗi tải model: {e}")
                # Bare ``raise`` re-raises the active exception unchanged.
                raise
    return tts_model
75
+
76
+ # --- DATA ---
77
  VOICE_SAMPLES = {
78
  "Tuyên (nam miền Bắc)": {"audio": "./sample/Tuyên (nam miền Bắc).wav", "text": "./sample/Tuyên (nam miền Bắc).txt"},
79
  "Vĩnh (nam miền Nam)": {"audio": "./sample/Vĩnh (nam miền Nam).wav", "text": "./sample/Vĩnh (nam miền Nam).txt"},
 
87
  "Nhỏ Ngọt Ngào": {"audio": "./sample/Nhỏ Ngọt Ngào.wav", "text": "./sample/Nhỏ Ngọt Ngào.txt"},
88
  }
89
 
90
+ # --- CORE LOGIC (DECORATED WITH @spaces.GPU) ---
91
 
92
+ @spaces.GPU(duration=120) # Tăng thời gian timeout lên 120s cho lần đầu load model
93
+ def core_synthesize(text, voice_choice, speed_factor):
94
+ # 1. Lấy model (Sẽ tải nếu chưa có)
95
+ tts = get_tts_model()
96
+
97
+ # 2. Đảm bảo model ở đúng device (GPU)
98
  if torch.cuda.is_available():
 
 
99
  try:
 
100
  if next(tts.backbone.parameters()).device.type != 'cuda':
 
101
  tts.backbone.to("cuda")
 
 
 
102
  tts.codec.to("cuda")
103
+ except: pass
 
104
 
105
+ # 3. Lấy thông tin giọng
 
 
 
 
 
106
  voice_info = VOICE_SAMPLES.get(voice_choice)
107
  if not voice_info:
108
  raise ValueError("Giọng không tồn tại")
 
113
  with open(ref_text_path, "r", encoding="utf-8") as f:
114
  ref_text_raw = f.read()
115
 
116
+ # 4. Encode Reference
117
  cache_key = f"preset:{voice_choice}"
118
  with reference_cache_lock:
119
  if cache_key in reference_cache:
120
  ref_codes = reference_cache[cache_key]
121
+ if isinstance(ref_codes, torch.Tensor) and torch.cuda.is_available():
 
122
  ref_codes = ref_codes.to("cuda")
123
  else:
124
  ref_codes = load_cache_from_disk(cache_key)
125
  if ref_codes is None:
126
+ ref_codes = tts.encode_reference(ref_audio_path)
127
+ # Cache trên CPU
128
  save_cache_to_disk(cache_key, ref_codes.cpu() if isinstance(ref_codes, torch.Tensor) else ref_codes)
129
 
130
+ if isinstance(ref_codes, torch.Tensor) and torch.cuda.is_available():
 
131
  ref_codes = ref_codes.to("cuda")
132
  reference_cache[cache_key] = ref_codes
133
 
134
+ # 5. Infer
135
  wav = tts.infer(text, ref_codes, ref_text_raw)
136
 
137
+ # 6. Speed Control (CPU Processing)
138
  if speed_factor != 1.0:
139
  with tempfile.NamedTemporaryFile(suffix='.wav', delete=False) as tmp:
140
  sf.write(tmp.name, wav, 24000)
141
  tmp_path = tmp.name
142
+
143
  sound = AudioSegment.from_wav(tmp_path)
144
  new_frame_rate = int(sound.frame_rate * speed_factor)
145
  sound_stretched = sound._spawn(sound.raw_data, overrides={'frame_rate': new_frame_rate})
146
  sound_stretched = sound_stretched.set_frame_rate(24000)
147
+
148
  wav = np.array(sound_stretched.get_array_of_samples()).astype(np.float32) / 32768.0
149
  if sound_stretched.channels == 2:
150
  wav = wav.reshape((-1, 2)).mean(axis=1)
 
152
 
153
  return wav
154
 
155
# duration=120 matches the decorator on core_synthesize in this file;
# presumably it leaves room for the first-call model load -- TODO confirm.
@spaces.GPU(duration=120)
def custom_synthesize_logic(text, ref_audio_path, ref_text_raw):
    """Synthesize ``text`` using a caller-supplied reference recording.

    Args:
        text: Text to synthesize; passed straight to ``tts.infer``.
        ref_audio_path: Path of the reference audio handed to
            ``tts.encode_reference``.
        ref_text_raw: Transcript of the reference audio; passed straight to
            ``tts.infer``.

    Returns:
        Whatever ``tts.infer`` returns (NOTE(review): presumably a waveform
        array -- confirm against VieNeuTTS).

    Raises:
        Exception: anything raised by ``get_tts_model``,
            ``tts.encode_reference`` or ``tts.infer``. A failure while
            moving the model to CUDA is logged and does not abort the call.
    """
    tts = get_tts_model()
    if torch.cuda.is_available():
        try:
            # Backbone and codec are moved together, keyed on the device the
            # backbone's first parameter currently lives on.
            if next(tts.backbone.parameters()).device.type != 'cuda':
                tts.backbone.to("cuda")
                tts.codec.to("cuda")
        except Exception as e:
            # Best effort: report the failure instead of hiding it, then
            # continue on whatever device the model is currently on.
            print(f"⚠️ Lỗi khi move model sang GPU: {e}")

    ref_codes = tts.encode_reference(ref_audio_path)
    wav = tts.infer(text, ref_codes, ref_text_raw)
    return wav
168
 
169
+ # --- API ---
170
  class FastTTSRequest(BaseModel):
171
  text: str
172
  voice_choice: str
 
181
  async def fast_tts(request: FastTTSRequest):
182
  try:
183
  start = time.time()
184
+ # Gọi hàm GPU
185
  wav = core_synthesize(request.text, request.voice_choice, request.speed_factor)
186
  process_time = time.time() - start
187
 
 
188
  audio_buffer = io.BytesIO()
189
  sf.write(audio_buffer, wav, 24000, format='WAV')
190
  audio_base64 = base64.b64encode(audio_buffer.getvalue()).decode('utf-8')
 
197
  except Exception as e:
198
  raise HTTPException(status_code=500, detail=str(e))
199
 
200
+ # --- GRADIO UI ---
201
  theme = gr.themes.Soft()
202
  css = ".container { max-width: 900px; margin: auto; }"
203
 
 
239
  tabs.children[1].select(lambda: "custom_mode", None, mode_state)
240
  btn.click(ui_synthesize, [inp_text, inp_voice, inp_audio, inp_ref_text, mode_state, inp_speed], [out_audio, out_status])
241
 
242
+ # Mount Gradio vào FastAPI
243
  app = gr.mount_gradio_app(app, demo, path="/")
244
 
245
  if __name__ == "__main__":
246
  import uvicorn
247
+ # Mở port 7860 để Hugging Face truy cập
248
  uvicorn.run(app, host="0.0.0.0", port=7860)