FSub / gradio_app.py
nhantrungsp's picture
Update gradio_app.py
d46de93 verified
raw
history blame
8.8 kB
import spaces # <--- QUAN TRỌNG: PHẢI ĐỂ DÒNG ĐẦU TIÊN
import os
import time
import threading
import pickle
import hashlib
import base64
import io
import tempfile
import numpy as np
# Các thư viện khác import sau spaces
import torch
import soundfile as sf
from pydub import AudioSegment
import gradio as gr
from fastapi import FastAPI, HTTPException
from pydantic import BaseModel
# Import thư viện nội bộ
from vieneu_tts import VieNeuTTS
# --- FastAPI app (the Gradio UI is mounted onto it further down) ---
app = FastAPI()

print("⏳ Đang khởi động VieNeu-TTS...")

# --- 1. Device setup ---
# Pick CUDA when torch can see a GPU, otherwise fall back to CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"
print(f"🖥️ Sử dụng thiết bị (Global): {device.upper()}")

# --- Reference-code cache ---
# Encoded reference audio is memoised in-process (reference_cache, guarded by
# reference_cache_lock) and persisted as pickle files under CACHE_DIR.
CACHE_DIR = "./reference_cache"
os.makedirs(CACHE_DIR, exist_ok=True)
reference_cache_lock = threading.Lock()
reference_cache = {}
# Cache helper functions
def get_cache_path(cache_key):
    """Return the pickle file path used to cache ``cache_key`` on disk.

    The key is hashed with MD5 (as a file-name hash, not for security) so that
    arbitrary strings, e.g. voice names with spaces and diacritics, map to safe
    file names inside CACHE_DIR.
    """
    file_name = "{}.pkl".format(hashlib.md5(cache_key.encode()).hexdigest())
    return os.path.join(CACHE_DIR, file_name)
def load_cache_from_disk(cache_key):
    """Load previously pickled reference codes for ``cache_key``.

    Returns:
        The unpickled object, or None on a cache miss or when the cache file
        cannot be read/unpickled (callers then simply re-encode the reference).
    """
    cache_path = get_cache_path(cache_key)
    try:
        with open(cache_path, "rb") as f:
            # NOTE: pickle can execute code while loading. In this file the only
            # writer of these files is save_cache_to_disk; never point CACHE_DIR
            # at untrusted data.
            return pickle.load(f)
    except FileNotFoundError:
        # Plain cache miss.
        return None
    except Exception as e:
        # Corrupt/truncated/unreadable cache file: treat as a miss rather than
        # failing synthesis, but make the problem visible. Unlike the previous
        # bare ``except:``, KeyboardInterrupt/SystemExit still propagate.
        print(f"⚠️ Lỗi đọc cache: {e}")
        return None
def save_cache_to_disk(cache_key, ref_codes):
    """Persist ``ref_codes`` for ``cache_key`` to the on-disk cache (best effort).

    The pickle is first written to a temporary file in the cache directory and
    then moved into place with ``os.replace``, so a failure mid-write never
    leaves a truncated cache file behind. Errors are reported but never raised.
    """
    cache_path = get_cache_path(cache_key)
    tmp_path = None
    try:
        with tempfile.NamedTemporaryFile(
            mode="wb",
            dir=os.path.dirname(cache_path) or ".",
            suffix=".tmp",
            delete=False,
        ) as tmp:
            tmp_path = tmp.name
            pickle.dump(ref_codes, tmp)
        # Same directory => same filesystem, so the rename replaces atomically.
        os.replace(tmp_path, cache_path)
        tmp_path = None  # now owned by cache_path; nothing left to clean up
    except Exception as e:
        print(f"⚠️ Lỗi lưu cache: {e}")
    finally:
        if tmp_path is not None:
            try:
                os.unlink(tmp_path)
            except OSError:
                pass
# Load the TTS model once at import time. If loading fails the error is only
# printed and ``tts`` is left as None, so the app itself still starts.
try:
    print("📦 Đang tải model vào bộ nhớ...")
    tts = VieNeuTTS(
        backbone_repo="pnnbao-ump/VieNeu-TTS",
        backbone_device=device,
        codec_repo="neuphonic/neucodec",
        codec_device=device,
    )
except Exception as load_error:
    print(f"⚠️ Lỗi tải model: {load_error}")
    tts = None
else:
    print("✅ Model đã tải xong!")
# --- 2. DATA ---
# Preset voices: each display name maps to its reference clip and transcript,
# both stored under ./sample/ with the display name as the file stem.
VOICE_SAMPLES = {
    voice_name: {
        "audio": f"./sample/{voice_name}.wav",
        "text": f"./sample/{voice_name}.txt",
    }
    for voice_name in (
        "Tuyên (nam miền Bắc)",
        "Vĩnh (nam miền Nam)",
        "Bình (nam miền Bắc)",
        "Nguyên (nam miền Nam)",
        "Sơn (nam miền Nam)",
        "Đoan (nữ miền Nam)",
        "Ngọc (nữ miền Bắc)",
        "Ly (nữ miền Bắc)",
        "Dung (nữ miền Nam)",
        "Nhỏ Ngọt Ngào",
    )
}
# --- 3. CORE LOGIC (shared by both the API and the UI) ---


def _get_preset_ref_codes(voice_choice, ref_audio_path):
    """Return reference codes for a preset voice: memory cache, then disk cache, then encode."""
    cache_key = f"preset:{voice_choice}"
    # The whole lookup/encode runs under the lock, so first-time encodes are
    # serialized and the same voice is never encoded twice concurrently.
    with reference_cache_lock:
        if cache_key in reference_cache:
            return reference_cache[cache_key]
        ref_codes = load_cache_from_disk(cache_key)
        if ref_codes is None:
            # Release cached GPU memory before encoding.
            if torch.cuda.is_available():
                torch.cuda.empty_cache()
            ref_codes = tts.encode_reference(ref_audio_path)
            save_cache_to_disk(cache_key, ref_codes)
        reference_cache[cache_key] = ref_codes
        return ref_codes


def _change_speed(wav, speed_factor, sample_rate=24000):
    """Re-time ``wav`` by ``speed_factor`` and return mono float32 samples at ``sample_rate``.

    The samples are reinterpreted at a scaled frame rate and then resampled
    back, so pitch changes together with tempo.
    """
    # Close the temp file before writing to it by name (portable, incl. Windows),
    # and always remove it, even if soundfile/pydub raise.
    with tempfile.NamedTemporaryFile(suffix=".wav", delete=False) as tmp:
        tmp_path = tmp.name
    try:
        sf.write(tmp_path, wav, sample_rate)
        sound = AudioSegment.from_wav(tmp_path)
    finally:
        try:
            os.unlink(tmp_path)
        except OSError:
            pass
    # _spawn is a pydub-internal helper; overriding frame_rate re-times the raw data.
    retimed = sound._spawn(
        sound.raw_data,
        overrides={"frame_rate": int(sound.frame_rate * speed_factor)},
    ).set_frame_rate(sample_rate)
    # Scale integer PCM to [-1, 1) based on the actual sample width (32768 for 16-bit).
    full_scale = float(1 << (8 * retimed.sample_width - 1))
    samples = np.array(retimed.get_array_of_samples()).astype(np.float32) / full_scale
    if retimed.channels > 1:
        # Interleaved frames -> average the channels down to mono.
        samples = samples.reshape((-1, retimed.channels)).mean(axis=1)
    return samples


# IMPORTANT: GPU decorator (Hugging Face `spaces`).
@spaces.GPU
def core_synthesize(text, voice_choice, speed_factor):
    """Synthesize ``text`` with one of the preset voices in VOICE_SAMPLES.

    Args:
        text: Text to read aloud.
        voice_choice: A key of VOICE_SAMPLES.
        speed_factor: Rate multiplier; exactly 1.0 leaves the audio untouched.

    Returns:
        The waveform from ``tts.infer`` (24 kHz, the rate this file writes it
        at); when ``speed_factor != 1.0``, mono float32 re-timed samples.

    Raises:
        ValueError: unknown voice or non-positive ``speed_factor``.
        RuntimeError: the model failed to load at startup.
    """
    voice_info = VOICE_SAMPLES.get(voice_choice)
    if not voice_info:
        raise ValueError("Giọng không tồn tại")
    if speed_factor <= 0:
        raise ValueError("speed_factor phải lớn hơn 0")
    if tts is None:
        raise RuntimeError("Model chưa được tải")

    # Transcript of the reference clip, required by tts.infer.
    with open(voice_info["text"], "r", encoding="utf-8") as f:
        ref_text_raw = f.read()

    ref_codes = _get_preset_ref_codes(voice_choice, voice_info["audio"])

    if torch.cuda.is_available():
        torch.cuda.empty_cache()
    wav = tts.infer(text, ref_codes, ref_text_raw)

    if speed_factor != 1.0:
        wav = _change_speed(wav, speed_factor)
    return wav
# Separate entry point for custom voices; it also needs the GPU decorator.
@spaces.GPU
def custom_synthesize_logic(text, ref_audio_path, ref_text_raw):
    """Synthesize ``text`` in the voice of a user-supplied reference clip.

    Unlike the presets, the reference is re-encoded on every call (no caching).

    Args:
        text: Text to read aloud.
        ref_audio_path: File path of the uploaded reference audio.
        ref_text_raw: Transcript of the reference audio.

    Returns:
        The waveform returned by ``tts.infer``.

    Raises:
        ValueError: no reference audio was provided.
        RuntimeError: the model failed to load at startup.
    """
    if not ref_audio_path:
        raise ValueError("Chưa có audio mẫu")
    if tts is None:
        raise RuntimeError("Model chưa được tải")
    # Release cached GPU memory before encoding + inference.
    if torch.cuda.is_available():
        torch.cuda.empty_cache()
    ref_codes = tts.encode_reference(ref_audio_path)
    return tts.infer(text, ref_codes, ref_text_raw)
# --- 4. API ENDPOINTS (for client apps to connect to) ---
class FastTTSRequest(BaseModel):
    """JSON body accepted by POST /fast-tts."""

    text: str  # text to synthesize
    voice_choice: str  # one of the preset names returned by GET /voices
    speed_factor: float = 1.0  # rate multiplier; 1.0 leaves the audio unchanged
    # NOTE(review): fast_tts always returns base64 audio; this flag is currently unused.
    return_base64: bool = False
@app.get("/voices")
async def get_voices():
    """List the preset voice names (in VOICE_SAMPLES order) as ``{"voices": [...]}``."""
    voice_names = [*VOICE_SAMPLES]
    return {"voices": voice_names}
@app.post("/fast-tts")
async def fast_tts(request: FastTTSRequest):
    """Synthesize speech with a preset voice and return it as base64-encoded WAV.

    Response: ``{"status": "success", "audio_base64": str, "processing_time": float}``.
    ``request.return_base64`` is ignored; the audio is always base64-encoded.

    Raises:
        HTTPException(400): unknown voice or non-positive speed factor.
        HTTPException(500): any failure during synthesis or encoding.
    """
    # Validate before the try below, so client errors are reported as 400
    # instead of being wrapped into a generic 500.
    if request.voice_choice not in VOICE_SAMPLES:
        raise HTTPException(status_code=400, detail="Giọng không tồn tại")
    if request.speed_factor <= 0:
        raise HTTPException(status_code=400, detail="speed_factor phải lớn hơn 0")

    from fastapi.concurrency import run_in_threadpool

    try:
        start = time.perf_counter()
        # core_synthesize (decorated with @spaces.GPU) is blocking. Run it in a
        # worker thread so the event loop, which also serves the mounted Gradio
        # UI, stays responsive during inference.
        wav = await run_in_threadpool(
            core_synthesize,
            request.text,
            request.voice_choice,
            request.speed_factor,
        )
        process_time = time.perf_counter() - start

        # Encode as an in-memory 24 kHz WAV, then base64 for JSON transport.
        audio_buffer = io.BytesIO()
        sf.write(audio_buffer, wav, 24000, format="WAV")
        audio_base64 = base64.b64encode(audio_buffer.getvalue()).decode("utf-8")
        return {
            "status": "success",
            "audio_base64": audio_base64,
            "processing_time": process_time,
        }
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e)) from e
# --- 5. GRADIO UI SETUP ---
# Look & feel for the Blocks UI: the Soft theme, plus CSS that centers
# `.container` elements and caps their width at 900px.
css = ".container { max-width: 900px; margin: auto; }"
theme = gr.themes.Soft()
def ui_synthesize(text, voice, custom_audio, custom_text, mode, speed):
    """Gradio click handler: synthesize and return ``(wav_path, status_message)``.

    Args:
        text: Text to read aloud.
        voice: Preset voice name (used in preset mode).
        custom_audio: File path of the uploaded reference clip (custom mode).
        custom_text: Transcript of ``custom_audio`` (custom mode).
        mode: "custom_mode" selects the custom-voice path; anything else uses presets.
        speed: Speed factor. NOTE: only applied to preset voices.

    Errors are never raised to Gradio; they are returned as ``(None, "❌ Lỗi: ...")``.
    """
    try:
        start = time.time()
        # Fail fast with a readable message instead of an obscure model error.
        if not text or not text.strip():
            raise ValueError("Vui lòng nhập văn bản")
        if mode == "custom_mode":
            if not custom_audio:
                raise ValueError("Chưa có audio mẫu")
            wav = custom_synthesize_logic(text, custom_audio, custom_text)
        else:
            wav = core_synthesize(text, voice, speed)
        # The output Audio component receives a file path; delete=False keeps the
        # file after this function returns. Write after closing the handle so this
        # also works on platforms that lock open files.
        with tempfile.NamedTemporaryFile(delete=False, suffix=".wav") as tmp:
            path = tmp.name
        sf.write(path, wav, 24000)
        return path, f"✅ Xong! ({time.time() - start:.2f}s)"
    except Exception as e:
        return None, f"❌ Lỗi: {e}"
# Layout: text, voice-source tabs, and speed on the left; result on the right.
with gr.Blocks(theme=theme, css=css, title="VieNeu-TTS") as demo:
    gr.Markdown("# 🎙️ VieNeu-TTS (API + UI)")
    with gr.Row():
        with gr.Column():
            inp_text = gr.Textbox(label="Văn bản", lines=3, value="Xin chào Việt Nam")
            with gr.Tabs() as tabs:
                # Keep direct references to each tab instead of indexing
                # tabs.children, which depends on Gradio's internal child order.
                with gr.TabItem("Giọng mẫu", id="preset_mode") as preset_tab:
                    inp_voice = gr.Dropdown(
                        list(VOICE_SAMPLES),
                        value="Tuyên (nam miền Bắc)",
                        label="Chọn giọng",
                    )
                with gr.TabItem("Custom", id="custom_mode") as custom_tab:
                    inp_audio = gr.Audio(type="filepath")
                    inp_ref_text = gr.Textbox(label="Lời thoại mẫu")
            inp_speed = gr.Slider(0.5, 2.0, value=1.0, label="Tốc độ")
            btn = gr.Button("Đọc ngay", variant="primary")
        with gr.Column():
            out_audio = gr.Audio(label="Kết quả", autoplay=True)
            out_status = gr.Textbox(label="Trạng thái")
    # Which tab is active, tracked per session; gr.State is Gradio's intended
    # mechanism for this (no hidden component rendered in the page).
    mode_state = gr.State("preset_mode")
    preset_tab.select(lambda: "preset_mode", None, mode_state)
    custom_tab.select(lambda: "custom_mode", None, mode_state)
    btn.click(
        ui_synthesize,
        [inp_text, inp_voice, inp_audio, inp_ref_text, mode_state, inp_speed],
        [out_audio, out_status],
    )
# --- 6. Mount the Gradio UI onto the FastAPI app ---
# /voices and /fast-tts were registered above, before this "/" mount, so
# Starlette matches them first.
app = gr.mount_gradio_app(app=app, blocks=demo, path="/")

# --- 7. Run the server when this file is executed directly ---
if __name__ == "__main__":
    import uvicorn

    uvicorn.run(app, host="0.0.0.0", port=7860)