Spaces:
Running
on
Zero
Running
on
Zero
File size: 7,462 Bytes
4a8f5a5 56ad495 d46de93 56ad495 d46de93 4196956 d46de93 56ad495 d46de93 56ad495 4a8f5a5 56ad495 4a8f5a5 4196956 56ad495 4a8f5a5 56ad495 8ad7767 56ad495 8ad7767 56ad495 4a8f5a5 56ad495 4a8f5a5 d593d54 4a8f5a5 4196956 4a8f5a5 d593d54 4196956 847b717 4196956 8ad7767 4a8f5a5 56ad495 8ad7767 4a8f5a5 8ad7767 4196956 d593d54 8ad7767 4a8f5a5 4196956 d593d54 4196956 d593d54 8ad7767 4a8f5a5 8ad7767 d593d54 4a8f5a5 8ad7767 4196956 8ad7767 4196956 8ad7767 56ad495 4a8f5a5 4196956 4a8f5a5 847b717 4a8f5a5 8ad7767 56ad495 8ad7767 4a8f5a5 56ad495 8ad7767 4a8f5a5 8ad7767 4a8f5a5 56ad495 8ad7767 4a8f5a5 56ad495 4a8f5a5 56ad495 4a8f5a5 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 |
import spaces # <--- BẮT BUỘC DÒNG 1
import os
import time
import threading
import pickle
import hashlib
import tempfile
import numpy as np
# Các thư viện khác
import torch
import soundfile as sf
from pydub import AudioSegment
import gradio as gr
from vieneu_tts import VieNeuTTS
print("⏳ Đang khởi động Server Gradio...")
# --- 1. QUẢN LÝ MODEL (Lazy Loading) ---
tts_model = None
model_lock = threading.Lock()
def get_tts_model():
"""Chỉ tải model khi có người dùng gọi (Tiết kiệm tài nguyên khởi động)"""
global tts_model
with model_lock:
if tts_model is None:
print("📦 Đang khởi tạo model lần đầu (Lazy Load)...")
# ZeroGPU yêu cầu khởi tạo model trên CPU hoặc trong hàm @spaces.GPU
# Ở đây ta khởi tạo trên CPU cho an toàn
tts_model = VieNeuTTS(
backbone_repo="pnnbao-ump/VieNeu-TTS",
backbone_device="cpu",
codec_repo="neuphonic/neucodec",
codec_device="cpu"
)
print("✅ Model tải thành công!")
return tts_model
# --- 2. XỬ LÝ CACHE ---
CACHE_DIR = "./reference_cache"
os.makedirs(CACHE_DIR, exist_ok=True)
reference_cache = {}
reference_cache_lock = threading.Lock()
def get_cache_path(cache_key):
key_hash = hashlib.md5(cache_key.encode()).hexdigest()
return os.path.join(CACHE_DIR, f"{key_hash}.pkl")
def load_cache_from_disk(cache_key):
cache_path = get_cache_path(cache_key)
if os.path.exists(cache_path):
try:
with open(cache_path, 'rb') as f: return pickle.load(f)
except: return None
return None
def save_cache_to_disk(cache_key, ref_codes):
cache_path = get_cache_path(cache_key)
try:
with open(cache_path, 'wb') as f: pickle.dump(ref_codes, f)
except Exception: pass
# --- 3. DỮ LIỆU GIỌNG NÓI ---
VOICE_SAMPLES = {
"Tuyên (nam miền Bắc)": {"audio": "./sample/Tuyên (nam miền Bắc).wav", "text": "./sample/Tuyên (nam miền Bắc).txt"},
"Vĩnh (nam miền Nam)": {"audio": "./sample/Vĩnh (nam miền Nam).wav", "text": "./sample/Vĩnh (nam miền Nam).txt"},
"Bình (nam miền Bắc)": {"audio": "./sample/Bình (nam miền Bắc).wav", "text": "./sample/Bình (nam miền Bắc).txt"},
"Nguyên (nam miền Nam)": {"audio": "./sample/Nguyên (nam miền Nam).wav", "text": "./sample/Nguyên (nam miền Nam).txt"},
"Sơn (nam miền Nam)": {"audio": "./sample/Sơn (nam miền Nam).wav", "text": "./sample/Sơn (nam miền Nam).txt"},
"Đoan (nữ miền Nam)": {"audio": "./sample/Đoan (nữ miền Nam).wav", "text": "./sample/Đoan (nữ miền Nam).txt"},
"Ngọc (nữ miền Bắc)": {"audio": "./sample/Ngọc (nữ miền Bắc).wav", "text": "./sample/Ngọc (nữ miền Bắc).txt"},
"Ly (nữ miền Bắc)": {"audio": "./sample/Ly (nữ miền Bắc).wav", "text": "./sample/Ly (nữ miền Bắc).txt"},
"Dung (nữ miền Nam)": {"audio": "./sample/Dung (nữ miền Nam).wav", "text": "./sample/Dung (nữ miền Nam).txt"},
"Nhỏ Ngọt Ngào": {"audio": "./sample/Nhỏ Ngọt Ngào.wav", "text": "./sample/Nhỏ Ngọt Ngào.txt"},
}
# --- 4. HÀM XỬ LÝ CHÍNH (GPU) ---
@spaces.GPU(duration=120)
def generate_speech(text, voice_choice, speed_factor):
"""
Hàm này sẽ được ZeroGPU cấp phát GPU khi chạy.
Nó cũng đóng vai trò là API endpoint chính.
"""
start_time = time.time()
# 1. Lấy Model (Tải nếu chưa có)
tts = get_tts_model()
# 2. Chuyển Model sang GPU (Chỉ làm trong hàm này)
if torch.cuda.is_available():
try:
if next(tts.backbone.parameters()).device.type != 'cuda':
tts.backbone.to("cuda")
tts.codec.to("cuda")
except: pass
# 3. Lấy thông tin giọng
voice_info = VOICE_SAMPLES.get(voice_choice)
if not voice_info:
# Fallback nếu không tìm thấy giọng
voice_choice = "Tuyên (nam miền Bắc)"
voice_info = VOICE_SAMPLES[voice_choice]
ref_audio_path = voice_info["audio"]
ref_text_path = voice_info["text"]
with open(ref_text_path, "r", encoding="utf-8") as f:
ref_text_raw = f.read()
# 4. Encode Reference (Có Cache)
cache_key = f"preset:{voice_choice}"
with reference_cache_lock:
if cache_key in reference_cache:
ref_codes = reference_cache[cache_key]
if isinstance(ref_codes, torch.Tensor) and torch.cuda.is_available():
ref_codes = ref_codes.to("cuda")
else:
ref_codes = load_cache_from_disk(cache_key)
if ref_codes is None:
# Encode
ref_codes = tts.encode_reference(ref_audio_path)
save_cache_to_disk(cache_key, ref_codes.cpu() if isinstance(ref_codes, torch.Tensor) else ref_codes)
if isinstance(ref_codes, torch.Tensor) and torch.cuda.is_available():
ref_codes = ref_codes.to("cuda")
reference_cache[cache_key] = ref_codes
# 5. Infer (Tạo giọng nói)
wav = tts.infer(text, ref_codes, ref_text_raw)
# 6. Xử lý tốc độ (Speed)
if speed_factor != 1.0:
with tempfile.NamedTemporaryFile(suffix='.wav', delete=False) as tmp:
sf.write(tmp.name, wav, 24000)
tmp_path = tmp.name
sound = AudioSegment.from_wav(tmp_path)
new_frame_rate = int(sound.frame_rate * speed_factor)
sound_stretched = sound._spawn(sound.raw_data, overrides={'frame_rate': new_frame_rate})
sound_stretched = sound_stretched.set_frame_rate(24000)
wav = np.array(sound_stretched.get_array_of_samples()).astype(np.float32) / 32768.0
if sound_stretched.channels == 2:
wav = wav.reshape((-1, 2)).mean(axis=1)
os.unlink(tmp_path)
# 7. Lưu file kết quả
with tempfile.NamedTemporaryFile(delete=False, suffix=".wav") as tmp_file:
sf.write(tmp_file.name, wav, 24000)
output_path = tmp_file.name
return output_path, f"✅ Hoàn tất ({time.time() - start_time:.2f}s)"
# --- 5. GIAO DIỆN GRADIO ---
theme = gr.themes.Soft()
css = ".container { max-width: 900px; margin: auto; }"
with gr.Blocks(theme=theme, css=css, title="VieNeu-TTS") as demo:
gr.Markdown("# 🎙️ VieNeu-TTS (ZeroGPU)")
with gr.Row():
with gr.Column():
inp_text = gr.Textbox(label="Văn bản", lines=3, value="Xin chào Việt Nam, đây là thử nghiệm giọng nói.")
inp_voice = gr.Dropdown(list(VOICE_SAMPLES.keys()), value="Tuyên (nam miền Bắc)", label="Chọn giọng")
inp_speed = gr.Slider(0.5, 2.0, value=1.0, label="Tốc độ")
btn = gr.Button("Đọc ngay", variant="primary")
with gr.Column():
out_audio = gr.Audio(label="Kết quả", autoplay=True)
out_status = gr.Textbox(label="Trạng thái")
# Map function vào button
btn.click(generate_speech, [inp_text, inp_voice, inp_speed], [out_audio, out_status])
# --- 6. KHỞI CHẠY ---
if __name__ == "__main__":
# Dùng demo.launch() chuẩn để ZeroGPU nhận diện được
demo.queue(default_concurrency_limit=40).launch(server_name="0.0.0.0", server_port=7860) |