Spaces:
Running
on
Zero
Running
on
Zero
| import gradio as gr | |
| import soundfile as sf | |
| import tempfile | |
| import torch | |
| from vieneu_tts import VieNeuTTS | |
| import os | |
| import time | |
| import threading | |
| import pickle | |
| import hashlib | |
| import numpy as np | |
| from pydub import AudioSegment | |
| from fastapi import FastAPI, HTTPException | |
| from fastapi.responses import FileResponse | |
| from pydantic import BaseModel | |
| import base64 | |
| import io | |
# --- FASTAPI INITIALISATION ---
# Bare FastAPI application; the Gradio UI is mounted onto it later in this file.
app = FastAPI()
print("⏳ Đang khởi động VieNeu-TTS...")

# --- 1. MODEL SETUP ---
# Run on the GPU whenever torch reports CUDA as available, otherwise on the CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"
print(f"🖥️ Sử dụng thiết bị: {device.upper()}")

# Reference-code cache: an on-disk directory of pickles plus an in-memory
# dict whose accesses are serialised through a lock.
CACHE_DIR = "./reference_cache"
reference_cache = {}
reference_cache_lock = threading.Lock()
os.makedirs(CACHE_DIR, exist_ok=True)
# Cache helper functions
def get_cache_path(cache_key):
    """Return the pickle file path inside ``CACHE_DIR`` for ``cache_key``.

    The file name is the hexadecimal MD5 digest of the encoded key followed
    by ``.pkl``, so any key text maps to a fixed-length file name. MD5 is
    used here only to derive a file name.
    """
    digest = hashlib.md5(cache_key.encode()).hexdigest()
    filename = f"{digest}.pkl"
    return os.path.join(CACHE_DIR, filename)
def load_cache_from_disk(cache_key):
    """Load the pickled reference codes stored for ``cache_key``.

    Returns the unpickled object, or ``None`` when no cache file exists or
    the file cannot be read/unpickled. Reading is best-effort: a damaged
    cache entry is reported and treated as a cache miss.
    """
    cache_path = get_cache_path(cache_key)
    try:
        with open(cache_path, "rb") as f:
            # NOTE(security): pickle.load can execute arbitrary code. This is
            # only acceptable while CACHE_DIR holds files written by this
            # application itself; never point it at untrusted data.
            return pickle.load(f)
    except FileNotFoundError:
        # Plain cache miss: nothing has been saved for this key yet.
        return None
    except Exception as e:
        # Corrupt or unreadable entry. Unlike a bare ``except:``, this lets
        # KeyboardInterrupt / SystemExit propagate.
        print(f"⚠️ Lỗi đọc cache: {e}")
        return None
def save_cache_to_disk(cache_key, ref_codes):
    """Persist ``ref_codes`` as a pickle under the cache path for ``cache_key``.

    Saving is best-effort: any failure is reported with a warning and never
    raised, so synthesis can continue without the disk cache. The data is
    written to a temporary sibling file and then moved into place, so a
    failed dump does not leave a truncated cache file behind.
    """
    cache_path = get_cache_path(cache_key)
    tmp_path = f"{cache_path}.tmp"
    try:
        with open(tmp_path, "wb") as f:
            pickle.dump(ref_codes, f)
        # Publish the finished file in a single rename step.
        os.replace(tmp_path, cache_path)
    except Exception as e:
        print(f"⚠️ Lỗi ghi cache: {e}")
        # Remove the partial temporary file if one was created.
        try:
            os.remove(tmp_path)
        except OSError:
            pass
# Load the model
# Both the backbone and the codec are placed on the device selected above.
_model_options = {
    "backbone_repo": "pnnbao-ump/VieNeu-TTS",
    "backbone_device": device,
    "codec_repo": "neuphonic/neucodec",
    "codec_device": device,
}
try:
    tts = VieNeuTTS(**_model_options)
    print("✅ Model đã tải xong!")
except Exception as load_error:
    # Keep the module importable when loading fails; ``tts`` is then None.
    print(f"⚠️ Lỗi tải model: {load_error}")
    tts = None
# --- 2. DATA ---
# Preset voice names, in display order. Each voice has a reference recording
# and its transcript stored as ./sample/<name>.wav and ./sample/<name>.txt.
_VOICE_NAMES = (
    "Tuyên (nam miền Bắc)",
    "Vĩnh (nam miền Nam)",
    "Bình (nam miền Bắc)",
    "Nguyên (nam miền Nam)",
    "Sơn (nam miền Nam)",
    "Đoan (nữ miền Nam)",
    "Ngọc (nữ miền Bắc)",
    "Ly (nữ miền Bắc)",
    "Dung (nữ miền Nam)",
    "Nhỏ Ngọt Ngào",
)

# Maps a voice name to the paths of its reference audio and reference text.
VOICE_SAMPLES = {
    voice_name: {
        "audio": f"./sample/{voice_name}.wav",
        "text": f"./sample/{voice_name}.txt",
    }
    for voice_name in _VOICE_NAMES
}
# --- 3. CORE LOGIC (shared by the API and the UI) ---
def _get_reference_codes(voice_choice, ref_audio_path):
    """Return the encoded reference for a preset voice, using the caches.

    Lookup order: in-memory dict, then the on-disk pickle, then a fresh
    ``tts.encode_reference`` call whose result is saved to disk. Whatever is
    found or computed is stored in the in-memory dict.
    """
    cache_key = f"preset:{voice_choice}"
    with reference_cache_lock:
        if cache_key in reference_cache:
            return reference_cache[cache_key]
        ref_codes = load_cache_from_disk(cache_key)
        if ref_codes is None:
            ref_codes = tts.encode_reference(ref_audio_path)
            save_cache_to_disk(cache_key, ref_codes)
        reference_cache[cache_key] = ref_codes
        return ref_codes


def _change_speed(wav, speed_factor):
    """Return ``wav`` sped up or slowed down by ``speed_factor``.

    The audio is round-tripped through a temporary WAV file so pydub can
    reinterpret its frame rate and then resample back to 24000 Hz.
    NOTE(review): reinterpreting the frame rate presumably shifts pitch
    together with tempo — confirm this is the intended effect.
    """
    # Reserve a file name, then close the handle before writing by name so
    # the file is never opened twice at once.
    with tempfile.NamedTemporaryFile(suffix=".wav", delete=False) as tmp:
        tmp_path = tmp.name
    try:
        sf.write(tmp_path, wav, 24000)
        sound = AudioSegment.from_wav(tmp_path)
        new_frame_rate = int(sound.frame_rate * speed_factor)
        sound_stretched = sound._spawn(
            sound.raw_data, overrides={"frame_rate": new_frame_rate}
        )
        sound_stretched = sound_stretched.set_frame_rate(24000)
        # The 32768 divisor assumes 16-bit PCM samples — TODO confirm.
        samples = (
            np.array(sound_stretched.get_array_of_samples()).astype(np.float32)
            / 32768.0
        )
        if sound_stretched.channels == 2:
            # Interleaved stereo: average the two channels down to mono.
            samples = samples.reshape((-1, 2)).mean(axis=1)
        return samples
    finally:
        # Always remove the temporary file, even when processing fails.
        os.unlink(tmp_path)


def core_synthesize(text, voice_choice, speed_factor):
    """Synthesise ``text`` with a preset voice and return the waveform.

    Args:
        text: Text to speak.
        voice_choice: Key of ``VOICE_SAMPLES`` selecting the reference voice.
        speed_factor: Playback speed multiplier; ``1.0`` leaves the audio
            unchanged. Must be greater than zero.

    Returns:
        The waveform produced by ``tts.infer``, or a float32 array derived
        from it when a speed change is applied. Other code in this file
        writes the result at 24000 Hz.

    Raises:
        ValueError: Unknown ``voice_choice`` or non-positive ``speed_factor``.
        RuntimeError: The model failed to load at start-up (``tts`` is None).
    """
    voice_info = VOICE_SAMPLES.get(voice_choice)
    if not voice_info:
        raise ValueError("Giọng không tồn tại")
    if speed_factor <= 0:
        # A zero or negative factor would yield an invalid frame rate later.
        raise ValueError("Tốc độ phải lớn hơn 0")
    if tts is None:
        raise RuntimeError("Model chưa được tải")

    # The transcript of the reference recording is passed to inference.
    with open(voice_info["text"], "r", encoding="utf-8") as f:
        ref_text_raw = f.read()

    ref_codes = _get_reference_codes(voice_choice, voice_info["audio"])
    wav = tts.infer(text, ref_codes, ref_text_raw)

    if speed_factor != 1.0:
        wav = _change_speed(wav, speed_factor)
    return wav
# --- 4. API ENDPOINTS (for client apps to connect to) ---
# Request body consumed by fast_tts.
class FastTTSRequest(BaseModel):
    # Text to synthesise.
    text: str
    # Voice name; fast_tts forwards it to core_synthesize, which looks it up
    # in VOICE_SAMPLES.
    voice_choice: str
    # Playback speed multiplier; 1.0 means no speed change.
    speed_factor: float = 1.0
    # NOTE(review): fast_tts as written never reads this flag — confirm
    # whether a non-base64 response was intended.
    return_base64: bool = False
async def get_voices():
    """Return the preset voice names as ``{"voices": [...]}``.

    NOTE(review): no route decorator is visible on this coroutine in this
    view of the file — confirm that it is registered on ``app``.
    """
    voice_names = [*VOICE_SAMPLES]
    return {"voices": voice_names}
async def fast_tts(request: FastTTSRequest):
    """Synthesise speech for ``request`` and return it as base64 WAV data.

    Returns:
        A dict with ``status``, ``audio_base64`` (a 24000 Hz WAV file encoded
        as base64 text) and ``processing_time`` (seconds spent synthesising).

    Raises:
        HTTPException: 503 when the model is not loaded, 400 when the
            requested voice is unknown, 500 for any failure during synthesis
            or encoding.

    NOTE(review): no route decorator is visible on this coroutine in this
    view of the file — confirm that it is registered on ``app``. Also,
    ``request.return_base64`` is never read: the audio is always returned
    as base64.
    """
    # Validate before the try block: the generic handler below would
    # otherwise re-wrap these HTTPExceptions as 500 errors.
    if tts is None:
        raise HTTPException(status_code=503, detail="Model chưa được tải")
    if request.voice_choice not in VOICE_SAMPLES:
        # A client mistake, so report it as such rather than a server error.
        raise HTTPException(status_code=400, detail="Giọng không tồn tại")

    try:
        # perf_counter is a monotonic clock suited to measuring durations.
        start = time.perf_counter()
        wav = core_synthesize(
            request.text, request.voice_choice, request.speed_factor
        )
        process_time = time.perf_counter() - start

        # Encode the waveform as an in-memory WAV file, then as base64 text.
        audio_buffer = io.BytesIO()
        sf.write(audio_buffer, wav, 24000, format="WAV")
        audio_base64 = base64.b64encode(audio_buffer.getvalue()).decode("utf-8")

        return {
            "status": "success",
            "audio_base64": audio_base64,
            "processing_time": process_time,
        }
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e)) from e
# --- 5. GRADIO UI SETUP ---
# Stylesheet passed to gr.Blocks: centres a container capped at 900px wide.
css = ".container { max-width: 900px; margin: auto; }"
# The Soft theme is used to avoid errors (per the original author's note).
theme = gr.themes.Soft()
def ui_synthesize(text, voice, custom_audio, custom_text, mode, speed):
    """Gradio click handler: synthesise speech and report the outcome.

    Args:
        text: Text to speak.
        voice: Preset voice name (used unless ``mode`` is ``"custom_mode"``).
        custom_audio: File path of the uploaded reference audio (custom mode).
        custom_text: Transcript of the uploaded reference audio (custom mode).
        mode: ``"custom_mode"`` for a user-supplied voice; any other value
            selects the preset path.
        speed: Speed multiplier; applied to preset voices only.

    Returns:
        ``(path, status)`` where ``path`` is a temporary WAV file holding the
        result (``None`` on failure) and ``status`` is a message for the UI.
        Errors are never raised; they are reported through ``status``.
    """
    # Fail early with readable messages instead of surfacing an
    # AttributeError/TypeError from deep inside the model code.
    if tts is None:
        return None, "❌ Lỗi: Model chưa được tải"
    if mode == "custom_mode":
        if not custom_audio:
            return None, "❌ Lỗi: Vui lòng tải lên audio mẫu"
        if not custom_text or not custom_text.strip():
            return None, "❌ Lỗi: Vui lòng nhập lời thoại mẫu"

    try:
        start = time.time()
        if mode == "custom_mode":
            # Custom references are encoded on every call (not cached), and
            # speed control is intentionally skipped for this mode.
            ref_codes = tts.encode_reference(custom_audio)
            wav = tts.infer(text, ref_codes, custom_text)
        else:
            wav = core_synthesize(text, voice, speed)

        # Reserve a file name, then close the handle before writing by name.
        # The file is left on disk because its path is the return value.
        with tempfile.NamedTemporaryFile(delete=False, suffix=".wav") as tmp:
            path = tmp.name
        try:
            sf.write(path, wav, 24000)
        except Exception:
            # Do not leave an empty temp file behind when writing fails.
            os.unlink(path)
            raise
        return path, f"✅ Xong! ({time.time()-start:.2f}s)"
    except Exception as e:
        return None, f"❌ Lỗi: {e}"
# Page layout: inputs in the left column, results in the right column.
with gr.Blocks(theme=theme, css=css, title="VieNeu-TTS") as demo:
    gr.Markdown("# 🎙️ VieNeu-TTS (API + UI)")
    with gr.Row():
        with gr.Column():
            inp_text = gr.Textbox(label="Văn bản", lines=3, value="Xin chào Việt Nam")
            with gr.Tabs() as tabs:
                # First tab: choose one of the preset voices.
                with gr.TabItem("Giọng mẫu", id="preset_mode") as preset_tab:
                    inp_voice = gr.Dropdown(
                        list(VOICE_SAMPLES.keys()),
                        value="Tuyên (nam miền Bắc)",
                        label="Chọn giọng",
                    )
                # Second tab: upload a reference recording and its transcript.
                with gr.TabItem("Custom", id="custom_mode") as custom_tab:
                    inp_audio = gr.Audio(type="filepath")
                    inp_ref_text = gr.Textbox(label="Lời thoại mẫu")
            inp_speed = gr.Slider(0.5, 2.0, value=1.0, label="Tốc độ")
            btn = gr.Button("Đọc ngay", variant="primary")
        with gr.Column():
            out_audio = gr.Audio(label="Kết quả", autoplay=True)
            out_status = gr.Textbox(label="Trạng thái")

    # Hidden textbox tracking which tab is active; selecting a tab stores
    # that tab's mode string so the click handler can branch on it.
    mode_state = gr.Textbox(visible=False, value="preset_mode")
    preset_tab.select(lambda: "preset_mode", None, mode_state)
    custom_tab.select(lambda: "custom_mode", None, mode_state)

    handler_inputs = [inp_text, inp_voice, inp_audio, inp_ref_text, mode_state, inp_speed]
    handler_outputs = [out_audio, out_status]
    btn.click(ui_synthesize, handler_inputs, handler_outputs)
# --- 6. MOUNT GRADIO ONTO FASTAPI ---
# The key step that lets the API and the UI run together: the Gradio app is
# attached to the FastAPI application at the root path.
app = gr.mount_gradio_app(app, demo, path="/")

# --- 7. RUN THE SERVER ---
if __name__ == "__main__":
    import uvicorn

    # Serve the combined application with uvicorn instead of demo.launch().
    server_options = {"host": "0.0.0.0", "port": 7860}
    uvicorn.run(app, **server_options)