Spaces: Running on Zero
Update gradio_app.py
Browse files: gradio_app.py (+21 -6)
gradio_app.py
CHANGED
|
@@ -15,6 +15,7 @@ from fastapi.responses import FileResponse
|
|
| 15 |
from pydantic import BaseModel
|
| 16 |
import base64
|
| 17 |
import io
|
|
|
|
| 18 |
|
| 19 |
# --- KHỞI TẠO FASTAPI ---
|
| 20 |
app = FastAPI()
|
|
@@ -22,8 +23,9 @@ app = FastAPI()
|
|
| 22 |
print("⏳ Đang khởi động VieNeu-TTS...")
|
| 23 |
|
| 24 |
# --- 1. SETUP MODEL ---
|
|
|
|
| 25 |
device = "cuda" if torch.cuda.is_available() else "cpu"
|
| 26 |
-
print(f"🖥️ Sử dụng thiết
|
| 27 |
|
| 28 |
# Cache
|
| 29 |
CACHE_DIR = "./reference_cache"
|
|
@@ -78,6 +80,9 @@ VOICE_SAMPLES = {
|
|
| 78 |
}
|
| 79 |
|
| 80 |
# --- 3. CORE LOGIC (Dùng chung cho cả API và UI) ---
|
|
|
|
|
|
|
|
|
|
| 81 |
def core_synthesize(text, voice_choice, speed_factor):
|
| 82 |
# Lấy thông tin giọng
|
| 83 |
voice_info = VOICE_SAMPLES.get(voice_choice)
|
|
@@ -99,6 +104,11 @@ def core_synthesize(text, voice_choice, speed_factor):
|
|
| 99 |
else:
|
| 100 |
ref_codes = load_cache_from_disk(cache_key)
|
| 101 |
if ref_codes is None:
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 102 |
ref_codes = tts.encode_reference(ref_audio_path)
|
| 103 |
save_cache_to_disk(cache_key, ref_codes)
|
| 104 |
reference_cache[cache_key] = ref_codes
|
|
@@ -124,6 +134,13 @@ def core_synthesize(text, voice_choice, speed_factor):
|
|
| 124 |
|
| 125 |
return wav
|
| 126 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 127 |
# --- 4. API ENDPOINTS (Cho Client App kết nối) ---
|
| 128 |
class FastTTSRequest(BaseModel):
|
| 129 |
text: str
|
|
@@ -139,6 +156,7 @@ async def get_voices():
|
|
| 139 |
async def fast_tts(request: FastTTSRequest):
|
| 140 |
try:
|
| 141 |
start = time.time()
|
|
|
|
| 142 |
wav = core_synthesize(request.text, request.voice_choice, request.speed_factor)
|
| 143 |
process_time = time.time() - start
|
| 144 |
|
|
@@ -168,10 +186,7 @@ def ui_synthesize(text, voice, custom_audio, custom_text, mode, speed):
|
|
| 168 |
start = time.time()
|
| 169 |
# Logic riêng cho UI (hỗ trợ custom voice)
|
| 170 |
if mode == "custom_mode":
|
| 171 |
-
|
| 172 |
-
ref_text_raw = custom_text
|
| 173 |
-
ref_codes = tts.encode_reference(ref_audio_path) # Không cache custom
|
| 174 |
-
wav = tts.infer(text, ref_codes, ref_text_raw)
|
| 175 |
# (Bỏ qua speed control cho custom để code gọn)
|
| 176 |
else:
|
| 177 |
wav = core_synthesize(text, voice, speed)
|
|
@@ -218,5 +233,5 @@ app = gr.mount_gradio_app(app, demo, path="/")
|
|
| 218 |
# --- 7. CHẠY SERVER ---
|
| 219 |
if __name__ == "__main__":
|
| 220 |
import uvicorn
|
| 221 |
-
#
|
| 222 |
uvicorn.run(app, host="0.0.0.0", port=7860)
|
|
|
|
| 15 |
from pydantic import BaseModel
|
| 16 |
import base64
|
| 17 |
import io
|
| 18 |
+
import spaces # <--- THÊM THƯ VIỆN NÀY
|
| 19 |
|
| 20 |
# --- KHỞI TẠO FASTAPI ---
|
| 21 |
app = FastAPI()
|
|
|
|
| 23 |
print("⏳ Đang khởi động VieNeu-TTS...")
|
| 24 |
|
| 25 |
# --- 1. SETUP MODEL ---
|
| 26 |
+
# Trên ZeroGPU, ban đầu có thể nó nhận là CPU, nhưng @spaces.GPU sẽ lo phần chuyển đổi sau
|
| 27 |
device = "cuda" if torch.cuda.is_available() else "cpu"
|
| 28 |
+
print(f"🖥️ Sử dụng thiết bị (Global): {device.upper()}")
|
| 29 |
|
| 30 |
# Cache
|
| 31 |
CACHE_DIR = "./reference_cache"
|
|
|
|
| 80 |
}
|
| 81 |
|
| 82 |
# --- 3. CORE LOGIC (Dùng chung cho cả API và UI) ---
|
| 83 |
+
|
| 84 |
+
# QUAN TRỌNG: Thêm @spaces.GPU vào hàm này để báo cho HF biết đây là hàm cần GPU
|
| 85 |
+
@spaces.GPU
|
| 86 |
def core_synthesize(text, voice_choice, speed_factor):
|
| 87 |
# Lấy thông tin giọng
|
| 88 |
voice_info = VOICE_SAMPLES.get(voice_choice)
|
|
|
|
| 104 |
else:
|
| 105 |
ref_codes = load_cache_from_disk(cache_key)
|
| 106 |
if ref_codes is None:
|
| 107 |
+
# Đảm bảo model đang ở đúng device trước khi encode
|
| 108 |
+
if torch.cuda.is_available():
|
| 109 |
+
# Move model components to GPU if needed inside the decorated function
|
| 110 |
+
# (VieNeuTTS is expected to handle device placement at init; nothing is moved here - this branch is currently a no-op placeholder)
|
| 111 |
+
pass
|
| 112 |
ref_codes = tts.encode_reference(ref_audio_path)
|
| 113 |
save_cache_to_disk(cache_key, ref_codes)
|
| 114 |
reference_cache[cache_key] = ref_codes
|
|
|
|
| 134 |
|
| 135 |
return wav
|
| 136 |
|
| 137 |
+
# Hàm riêng cho Custom Voice cũng cần GPU
|
| 138 |
+
@spaces.GPU
|
| 139 |
+
def custom_synthesize_logic(text, ref_audio_path, ref_text_raw):
|
| 140 |
+
ref_codes = tts.encode_reference(ref_audio_path)
|
| 141 |
+
wav = tts.infer(text, ref_codes, ref_text_raw)
|
| 142 |
+
return wav
|
| 143 |
+
|
| 144 |
# --- 4. API ENDPOINTS (Cho Client App kết nối) ---
|
| 145 |
class FastTTSRequest(BaseModel):
|
| 146 |
text: str
|
|
|
|
| 156 |
async def fast_tts(request: FastTTSRequest):
|
| 157 |
try:
|
| 158 |
start = time.time()
|
| 159 |
+
# Gọi hàm đã được decorate @spaces.GPU
|
| 160 |
wav = core_synthesize(request.text, request.voice_choice, request.speed_factor)
|
| 161 |
process_time = time.time() - start
|
| 162 |
|
|
|
|
| 186 |
start = time.time()
|
| 187 |
# Logic riêng cho UI (hỗ trợ custom voice)
|
| 188 |
if mode == "custom_mode":
|
| 189 |
+
wav = custom_synthesize_logic(text, custom_audio, custom_text)
|
|
|
|
|
|
|
|
|
|
| 190 |
# (Bỏ qua speed control cho custom để code gọn)
|
| 191 |
else:
|
| 192 |
wav = core_synthesize(text, voice, speed)
|
|
|
|
| 233 |
# --- 7. CHẠY SERVER ---
|
| 234 |
if __name__ == "__main__":
|
| 235 |
import uvicorn
|
| 236 |
+
# host="0.0.0.0": server lắng nghe trên mọi network interface, để có thể truy cập từ bên ngoài container
|
| 237 |
uvicorn.run(app, host="0.0.0.0", port=7860)
|