nhantrungsp commited on
Commit
847b717
·
verified ·
1 Parent(s): 962fbc7

Update gradio_app.py

Browse files
Files changed (1) hide show
  1. gradio_app.py +21 -6
gradio_app.py CHANGED
@@ -15,6 +15,7 @@ from fastapi.responses import FileResponse
15
  from pydantic import BaseModel
16
  import base64
17
  import io
 
18
 
19
  # --- KHỞI TẠO FASTAPI ---
20
  app = FastAPI()
@@ -22,8 +23,9 @@ app = FastAPI()
22
  print("⏳ Đang khởi động VieNeu-TTS...")
23
 
24
  # --- 1. SETUP MODEL ---
 
25
  device = "cuda" if torch.cuda.is_available() else "cpu"
26
- print(f"🖥️ Sử dụng thiết bị: {device.upper()}")
27
 
28
  # Cache
29
  CACHE_DIR = "./reference_cache"
@@ -78,6 +80,9 @@ VOICE_SAMPLES = {
78
  }
79
 
80
  # --- 3. CORE LOGIC (Dùng chung cho cả API và UI) ---
 
 
 
81
  def core_synthesize(text, voice_choice, speed_factor):
82
  # Lấy thông tin giọng
83
  voice_info = VOICE_SAMPLES.get(voice_choice)
@@ -99,6 +104,11 @@ def core_synthesize(text, voice_choice, speed_factor):
99
  else:
100
  ref_codes = load_cache_from_disk(cache_key)
101
  if ref_codes is None:
 
 
 
 
 
102
  ref_codes = tts.encode_reference(ref_audio_path)
103
  save_cache_to_disk(cache_key, ref_codes)
104
  reference_cache[cache_key] = ref_codes
@@ -124,6 +134,13 @@ def core_synthesize(text, voice_choice, speed_factor):
124
 
125
  return wav
126
 
 
 
 
 
 
 
 
127
  # --- 4. API ENDPOINTS (Cho Client App kết nối) ---
128
  class FastTTSRequest(BaseModel):
129
  text: str
@@ -139,6 +156,7 @@ async def get_voices():
139
  async def fast_tts(request: FastTTSRequest):
140
  try:
141
  start = time.time()
 
142
  wav = core_synthesize(request.text, request.voice_choice, request.speed_factor)
143
  process_time = time.time() - start
144
 
@@ -168,10 +186,7 @@ def ui_synthesize(text, voice, custom_audio, custom_text, mode, speed):
168
  start = time.time()
169
  # Logic riêng cho UI (hỗ trợ custom voice)
170
  if mode == "custom_mode":
171
- ref_audio_path = custom_audio
172
- ref_text_raw = custom_text
173
- ref_codes = tts.encode_reference(ref_audio_path) # Không cache custom
174
- wav = tts.infer(text, ref_codes, ref_text_raw)
175
  # (Bỏ qua speed control cho custom để code gọn)
176
  else:
177
  wav = core_synthesize(text, voice, speed)
@@ -218,5 +233,5 @@ app = gr.mount_gradio_app(app, demo, path="/")
218
  # --- 7. CHẠY SERVER ---
219
  if __name__ == "__main__":
220
  import uvicorn
221
- # Chạy uvicorn thay demo.launch()
222
  uvicorn.run(app, host="0.0.0.0", port=7860)
 
15
  from pydantic import BaseModel
16
  import base64
17
  import io
18
+ import spaces # <--- THÊM THƯ VIỆN NÀY
19
 
20
  # --- KHỞI TẠO FASTAPI ---
21
  app = FastAPI()
 
23
  print("⏳ Đang khởi động VieNeu-TTS...")
24
 
25
  # --- 1. SETUP MODEL ---
26
+ # Trên ZeroGPU, ban đầu có thể nó nhận là CPU, nhưng @spaces.GPU sẽ lo phần chuyển đổi sau
27
  device = "cuda" if torch.cuda.is_available() else "cpu"
28
+ print(f"🖥️ Sử dụng thiết bị (Global): {device.upper()}")
29
 
30
  # Cache
31
  CACHE_DIR = "./reference_cache"
 
80
  }
81
 
82
  # --- 3. CORE LOGIC (Dùng chung cho cả API và UI) ---
83
+
84
+ # QUAN TRỌNG: Thêm @spaces.GPU vào hàm này để báo cho HF biết đây là hàm cần GPU
85
+ @spaces.GPU
86
  def core_synthesize(text, voice_choice, speed_factor):
87
  # Lấy thông tin giọng
88
  voice_info = VOICE_SAMPLES.get(voice_choice)
 
104
  else:
105
  ref_codes = load_cache_from_disk(cache_key)
106
  if ref_codes is None:
107
+ # Đảm bảo model đang ở đúng device trước khi encode
108
+ if torch.cuda.is_available():
109
+ # Move model components to GPU if needed inside the decorated function
110
+ # (Usually VieNeuTTS handles this based on init, but we double check)
111
+ pass
112
  ref_codes = tts.encode_reference(ref_audio_path)
113
  save_cache_to_disk(cache_key, ref_codes)
114
  reference_cache[cache_key] = ref_codes
 
134
 
135
  return wav
136
 
137
+ # Hàm riêng cho Custom Voice cũng cần GPU
138
+ @spaces.GPU
139
+ def custom_synthesize_logic(text, ref_audio_path, ref_text_raw):
140
+ ref_codes = tts.encode_reference(ref_audio_path)
141
+ wav = tts.infer(text, ref_codes, ref_text_raw)
142
+ return wav
143
+
144
  # --- 4. API ENDPOINTS (Cho Client App kết nối) ---
145
  class FastTTSRequest(BaseModel):
146
  text: str
 
156
  async def fast_tts(request: FastTTSRequest):
157
  try:
158
  start = time.time()
159
+ # Gọi hàm đã được decorate @spaces.GPU
160
  wav = core_synthesize(request.text, request.voice_choice, request.speed_factor)
161
  process_time = time.time() - start
162
 
 
186
  start = time.time()
187
  # Logic riêng cho UI (hỗ trợ custom voice)
188
  if mode == "custom_mode":
189
+ wav = custom_synthesize_logic(text, custom_audio, custom_text)
 
 
 
190
  # (Bỏ qua speed control cho custom để code gọn)
191
  else:
192
  wav = core_synthesize(text, voice, speed)
 
233
  # --- 7. CHẠY SERVER ---
234
  if __name__ == "__main__":
235
  import uvicorn
236
+ # 0.0.0.0 Mở port ra ngoài internet
237
  uvicorn.run(app, host="0.0.0.0", port=7860)