import spaces  # IMPORTANT: must stay the very first import in this file

import os
import time
import threading
import pickle
import hashlib
import base64
import io
import tempfile
import numpy as np

# Every other library is imported after `spaces`.
import torch
import soundfile as sf
from pydub import AudioSegment
import gradio as gr
from fastapi import FastAPI, HTTPException
from pydantic import BaseModel

# Project-local library.
from vieneu_tts import VieNeuTTS

# --- FASTAPI INITIALISATION ---
app = FastAPI()

print("⏳ Đang khởi động VieNeu-TTS...")

# --- 1. SETUP MODEL ---
device = "cuda" if torch.cuda.is_available() else "cpu"
print(f"🖥️ Sử dụng thiết bị (Global): {device.upper()}")

# Reference-code cache: an in-memory dict guarded by a lock, backed by
# pickle files stored under CACHE_DIR.
CACHE_DIR = "./reference_cache"
os.makedirs(CACHE_DIR, exist_ok=True)
reference_cache = {}
reference_cache_lock = threading.Lock()


# Cache helpers
def get_cache_path(cache_key):
    """Return the pickle file path used to persist ``cache_key``.

    The key is hashed with MD5 only to obtain a fixed-length,
    filesystem-safe file name; the hash has no security role here.
    """
    key_hash = hashlib.md5(cache_key.encode()).hexdigest()
    return os.path.join(CACHE_DIR, f"{key_hash}.pkl")


def load_cache_from_disk(cache_key):
    """Load the cached value for ``cache_key`` from disk.

    Returns:
        The unpickled object, or ``None`` when no cache file exists or the
        file cannot be read/unpickled (the failure is reported on stdout).
    """
    cache_path = get_cache_path(cache_key)
    try:
        # NOTE(review): pickle.load can execute arbitrary code embedded in
        # the file. That is acceptable only while CACHE_DIR is written
        # exclusively by this process -- confirm for the deployment.
        with open(cache_path, 'rb') as f:
            return pickle.load(f)
    except FileNotFoundError:
        # Plain cache miss.
        return None
    except Exception as e:
        # A corrupt or unreadable cache file is treated as a miss. Catching
        # Exception instead of using a bare `except:` lets
        # KeyboardInterrupt / SystemExit propagate.
        print(f"⚠️ Không đọc được cache {cache_path}: {e}")
        return None


def save_cache_to_disk(cache_key, ref_codes):
    """Persist ``ref_codes`` for ``cache_key``; best effort, never raises.

    The pickle is written to a temporary file in CACHE_DIR and then moved
    over the final path with ``os.replace``, so an interrupted write cannot
    leave a truncated cache file behind. Failures are reported on stdout.
    """
    cache_path = get_cache_path(cache_key)
    tmp_path = None
    try:
        with tempfile.NamedTemporaryFile(
            "wb", dir=CACHE_DIR, suffix=".tmp", delete=False
        ) as f:
            tmp_path = f.name
            pickle.dump(ref_codes, f)
        os.replace(tmp_path, cache_path)
    except Exception as e:
        print(f"⚠️ Không ghi được cache {cache_path}: {e}")
        # Remove the partial temporary file, if one was created.
        if tmp_path is not None and os.path.exists(tmp_path):
            try:
                os.unlink(tmp_path)
            except OSError:
                pass


# Load the model once at import time. On failure `tts` is left as None so
# the module can still be imported.
try:
    print("📦 Đang tải model vào bộ nhớ...")
    tts = VieNeuTTS(
        backbone_repo="pnnbao-ump/VieNeu-TTS",
        backbone_device=device,
        codec_repo="neuphonic/neucodec",
        codec_device=device
    )
    print("✅ Model đã tải xong!")
except Exception as e:
    print(f"⚠️ Lỗi tải model: {e}")
    tts = None
# --- 2. DATA ---

# Display names of the bundled preset voices, in the order they are listed.
# Each voice has a "<name>.wav" recording and a "<name>.txt" transcript
# under ./sample/.
_PRESET_VOICE_NAMES = (
    "Tuyên (nam miền Bắc)",
    "Vĩnh (nam miền Nam)",
    "Bình (nam miền Bắc)",
    "Nguyên (nam miền Nam)",
    "Sơn (nam miền Nam)",
    "Đoan (nữ miền Nam)",
    "Ngọc (nữ miền Bắc)",
    "Ly (nữ miền Bắc)",
    "Dung (nữ miền Nam)",
    "Nhỏ Ngọt Ngào",
)

# Maps a voice display name to the paths of its reference audio and text.
VOICE_SAMPLES = {
    name: {"audio": f"./sample/{name}.wav", "text": f"./sample/{name}.txt"}
    for name in _PRESET_VOICE_NAMES
}
# --- 3. CORE LOGIC (shared by the API and the UI) ---

# Sample rate (Hz) used when writing and resampling audio in this section.
_OUTPUT_SAMPLE_RATE = 24000


def _require_model():
    """Raise a clear error when the global model failed to load."""
    if tts is None:
        raise RuntimeError("Model TTS chưa được tải thành công")


def _change_speed(wav, speed_factor):
    """Return ``wav`` re-timed by ``speed_factor`` as float32 mono samples.

    The waveform is round-tripped through a temporary WAV file so pydub can
    relabel the raw samples with a scaled frame rate and then resample back
    to the output rate. Because the samples themselves are only relabelled,
    pitch shifts together with tempo.
    """
    # Reserve a path and close the handle before writing to it, so the file
    # is not open twice at the same time.
    with tempfile.NamedTemporaryFile(suffix='.wav', delete=False) as tmp:
        tmp_path = tmp.name
    try:
        sf.write(tmp_path, wav, _OUTPUT_SAMPLE_RATE)
        sound = AudioSegment.from_wav(tmp_path)
        new_frame_rate = int(sound.frame_rate * speed_factor)
        stretched = sound._spawn(sound.raw_data, overrides={'frame_rate': new_frame_rate})
        stretched = stretched.set_frame_rate(_OUTPUT_SAMPLE_RATE)
        # Scale by 32768 to map 16-bit integer samples into [-1.0, 1.0).
        samples = np.array(stretched.get_array_of_samples()).astype(np.float32) / 32768.0
        if stretched.channels == 2:
            # Interleaved stereo -> mono by averaging the two channels.
            samples = samples.reshape((-1, 2)).mean(axis=1)
        return samples
    finally:
        # Always remove the temporary file, even when processing fails.
        try:
            os.unlink(tmp_path)
        except OSError:
            pass


# IMPORTANT: GPU decorator (Hugging Face `spaces` package).
@spaces.GPU
def core_synthesize(text, voice_choice, speed_factor):
    """Synthesize ``text`` with a preset voice and return the waveform.

    Args:
        text: Text to speak.
        voice_choice: A key of ``VOICE_SAMPLES``.
        speed_factor: Playback speed multiplier; ``1.0`` leaves the audio
            untouched. Must be greater than 0.

    Returns:
        The waveform returned by ``tts.infer``, or, when ``speed_factor`` is
        not 1.0, a float32 mono array produced by the speed change.

    Raises:
        ValueError: Unknown voice, or ``speed_factor`` not greater than 0.
        RuntimeError: The global model failed to load.
    """
    # Look up the voice.
    voice_info = VOICE_SAMPLES.get(voice_choice)
    if not voice_info:
        raise ValueError("Giọng không tồn tại")
    # Fail fast: a zero/negative factor would yield an invalid frame rate
    # only after the expensive inference had already run.
    if speed_factor <= 0:
        raise ValueError("Hệ số tốc độ phải lớn hơn 0")
    _require_model()

    ref_audio_path = voice_info["audio"]
    ref_text_path = voice_info["text"]

    # Load the reference transcript.
    with open(ref_text_path, "r", encoding="utf-8") as f:
        ref_text_raw = f.read()

    # Encode the reference audio: memory cache -> disk cache -> fresh encode.
    # The lock is held for the whole lookup, including the encode.
    cache_key = f"preset:{voice_choice}"
    with reference_cache_lock:
        if cache_key in reference_cache:
            ref_codes = reference_cache[cache_key]
        else:
            ref_codes = load_cache_from_disk(cache_key)
            if ref_codes is None:
                # Free cached GPU memory before encoding.
                if torch.cuda.is_available():
                    torch.cuda.empty_cache()
                ref_codes = tts.encode_reference(ref_audio_path)
                save_cache_to_disk(cache_key, ref_codes)
            reference_cache[cache_key] = ref_codes

    # Inference.
    if torch.cuda.is_available():
        torch.cuda.empty_cache()
    wav = tts.infer(text, ref_codes, ref_text_raw)

    # Speed adjustment.
    if speed_factor != 1.0:
        wav = _change_speed(wav, speed_factor)

    return wav


# The custom-voice path needs the GPU as well.
@spaces.GPU
def custom_synthesize_logic(text, ref_audio_path, ref_text_raw):
    """Synthesize ``text`` using a caller-supplied reference recording.

    Args:
        text: Text to speak.
        ref_audio_path: Path of the reference audio file.
        ref_text_raw: Transcript of the reference audio.

    Returns:
        The waveform returned by ``tts.infer``. The reference encoding is
        not cached.

    Raises:
        ValueError: ``ref_audio_path`` is empty or ``None``.
        RuntimeError: The global model failed to load.
    """
    if not ref_audio_path:
        raise ValueError("Thiếu file âm thanh mẫu")
    _require_model()
    if torch.cuda.is_available():
        torch.cuda.empty_cache()
    ref_codes = tts.encode_reference(ref_audio_path)
    wav = tts.infer(text, ref_codes, ref_text_raw)
    return wav
# --- 4. API ENDPOINTS (for client applications) ---
class FastTTSRequest(BaseModel):
    """Request body of ``POST /fast-tts``."""

    text: str
    voice_choice: str
    speed_factor: float = 1.0
    # NOTE(review): accepted but not read by fast_tts; the response always
    # carries base64 audio.
    return_base64: bool = False


@app.get("/voices")
async def get_voices():
    """Return the names of the available preset voices."""
    return {"voices": list(VOICE_SAMPLES.keys())}


@app.post("/fast-tts")
async def fast_tts(request: FastTTSRequest):
    """Synthesize speech with a preset voice; return it as base64 WAV.

    Responds with ``status``, ``audio_base64`` and ``processing_time``
    (seconds spent in synthesis).

    Raises:
        HTTPException: 400 for an unknown voice, 500 for any failure during
            synthesis or encoding.
    """
    # Reject an unknown voice up front so a client mistake is reported as a
    # client error rather than a generic 500. This sits outside the `try`
    # so the handler below does not rewrap it.
    if request.voice_choice not in VOICE_SAMPLES:
        raise HTTPException(status_code=400, detail="Giọng không tồn tại")
    try:
        start = time.time()
        # Calls the @spaces.GPU-decorated function.
        # NOTE(review): this is a blocking call inside an `async def`
        # handler, so the event loop is stalled while it runs. Consider
        # moving it to a worker thread once tts is confirmed thread-safe.
        wav = core_synthesize(request.text, request.voice_choice, request.speed_factor)
        process_time = time.time() - start

        # Convert to base64.
        audio_buffer = io.BytesIO()
        sf.write(audio_buffer, wav, 24000, format='WAV')
        audio_bytes = audio_buffer.getvalue()
        audio_base64 = base64.b64encode(audio_bytes).decode('utf-8')

        return {
            "status": "success",
            "audio_base64": audio_base64,
            "processing_time": process_time
        }
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e))


# --- 5. GRADIO UI SETUP ---
theme = gr.themes.Soft()
css = ".container { max-width: 900px; margin: auto; }"


def ui_synthesize(text, voice, custom_audio, custom_text, mode, speed):
    """Gradio click handler: synthesize and return ``(wav_path, status)``.

    ``mode`` selects the custom-voice path when it equals ``"custom_mode"``;
    any other value uses the preset voice. On failure the audio path is
    ``None`` and the status text carries the error.
    """
    try:
        start = time.time()
        if mode == "custom_mode":
            # NOTE(review): `speed` is not applied on this path.
            wav = custom_synthesize_logic(text, custom_audio, custom_text)
        else:
            wav = core_synthesize(text, voice, speed)

        # Reserve a path and close the handle before writing to it. The file
        # is not deleted here because its path is returned to Gradio.
        with tempfile.NamedTemporaryFile(delete=False, suffix=".wav") as tmp:
            path = tmp.name
        sf.write(path, wav, 24000)
        return path, f"✅ Xong! ({time.time()-start:.2f}s)"
    except Exception as e:
        return None, f"❌ Lỗi: {e}"


with gr.Blocks(theme=theme, css=css, title="VieNeu-TTS") as demo:
    gr.Markdown("# 🎙️ VieNeu-TTS (API + UI)")
    with gr.Row():
        with gr.Column():
            inp_text = gr.Textbox(label="Văn bản", lines=3, value="Xin chào Việt Nam")
            with gr.Tabs() as tabs:
                with gr.TabItem("Giọng mẫu", id="preset_mode") as preset_tab:
                    inp_voice = gr.Dropdown(list(VOICE_SAMPLES.keys()), value="Tuyên (nam miền Bắc)", label="Chọn giọng")
                with gr.TabItem("Custom", id="custom_mode") as custom_tab:
                    inp_audio = gr.Audio(type="filepath")
                    inp_ref_text = gr.Textbox(label="Lời thoại mẫu")
            inp_speed = gr.Slider(0.5, 2.0, value=1.0, label="Tốc độ")
            btn = gr.Button("Đọc ngay", variant="primary")
        with gr.Column():
            out_audio = gr.Audio(label="Kết quả", autoplay=True)
            out_status = gr.Textbox(label="Trạng thái")

    # Hidden field tracking which tab is active; passed to ui_synthesize.
    mode_state = gr.Textbox(visible=False, value="preset_mode")
    # Bind the listeners to the named tab objects rather than to positional
    # entries of `tabs.children`.
    preset_tab.select(lambda: "preset_mode", None, mode_state)
    custom_tab.select(lambda: "custom_mode", None, mode_state)

    btn.click(ui_synthesize, [inp_text, inp_voice, inp_audio, inp_ref_text, mode_state, inp_speed], [out_audio, out_status])

# --- 6. MOUNT GRADIO INTO FASTAPI ---
app = gr.mount_gradio_app(app, demo, path="/")

# --- 7. RUN THE SERVER ---
if __name__ == "__main__":
    import uvicorn
    uvicorn.run(app, host="0.0.0.0", port=7860)