# ttsStyleTTS2 / app.py — Gradio Space by stephenhoang (commit c5bfe73)
import os
import json
import tempfile
import traceback
import gradio as gr
import numpy as np
import soundfile as sf
import torch
from huggingface_hub import hf_hub_download
import yaml
from inference import StyleTTS2
# =========================
# PATHS
# =========================
SPACE_ROOT = os.path.dirname(os.path.abspath(__file__))
DATA_ROOT = os.path.join(SPACE_ROOT, "demo_data")
SPEAKER2REFS_PATH = os.path.join(DATA_ROOT, "speaker2refs.json")

# Model repo (ckpt + config): fetched (or served from local cache) at import time.
CKPT_REPO = "stephenhoang/ttsStyleTTS2-ms152"
models_path = hf_hub_download(repo_id=CKPT_REPO, filename="epoch_00000.pth")
config_path = hf_hub_download(repo_id=CKPT_REPO, filename="config.yaml")

# BUG FIX: the original used yaml.safe_load(open(...)), leaking the file
# handle; a context manager closes it deterministically.
with open(config_path, "r", encoding="utf-8") as _cfg_file:
    cfg = yaml.safe_load(_cfg_file)

device = "cuda" if torch.cuda.is_available() else "cpu"
# ================= DEBUG SYMBOLS =================
# Best-effort dump of the runtime symbol inventory; any failure is non-fatal
# and simply logged.
try:
    _symbol_cfg = cfg['symbol']
    symbols = []
    for _section in ('pad', 'punctuation', 'letters', 'letters_ipa', 'extend'):
        symbols.extend(list(_symbol_cfg[_section]))
    print("\n========== SYMBOL DEBUG ==========")
    print("Total symbols (+pad):", len(symbols))
    # Show at most the first 30 symbols with their token ids.
    for _idx, _sym in enumerate(symbols[:30]):
        print(_idx, repr(_sym))
    print("==================================\n")
except Exception as e:
    print("❌ SYMBOL DEBUG ERROR:", e)
# =================================================
# =========================
# LOAD speaker2refs.json
# =========================
if not os.path.isfile(SPEAKER2REFS_PATH):
    raise FileNotFoundError(f"speaker2refs.json not found: {SPEAKER2REFS_PATH}")
with open(SPEAKER2REFS_PATH, "r", encoding="utf-8") as _refs_file:
    SPEAKER2REFS = json.load(_refs_file)
# Dropdown choices in deterministic, alphabetical speaker order.
SPEAKER_CHOICES = sorted(SPEAKER2REFS)
if not SPEAKER_CHOICES:
    raise RuntimeError("speaker2refs.json is empty (no speakers found).")
def _abs_ref_path(p: str) -> str:
"""
Hỗ trợ cả 2 kiểu:
- "refs/id_1.wav"
- "demo_data/refs/id_1.wav"
"""
p = p.lstrip("./")
if os.path.isabs(p):
return p
if p.startswith("demo_data/"):
return os.path.join(SPACE_ROOT, p)
return os.path.join(DATA_ROOT, p)
# =========================
# LOAD MODEL
# =========================
model = StyleTTS2(config_path, models_path).eval().to(device)

# ================= VOCAB DEBUG =================
# Sanity check: the checkpoint's text-embedding row count should match the
# runtime symbol count, otherwise token ids are misaligned.
ckpt = torch.load(models_path, map_location="cpu")
for k, v in ckpt["net"].items():
    if "embedding.weight" in k:
        print("✅ CKPT embedding:", v.shape)
print("✅ Runtime symbols:", len(symbols))
# Print a token sample only if a text_tensor happens to exist in this scope.
# BUG FIX: the original used a bare `except:`, which also swallows
# KeyboardInterrupt/SystemExit; only NameError is expected here.
try:
    print("✅ Text tokens sample:", text_tensor[:30])
except NameError:
    print("⚠️ text_tensor chưa tồn tại ở đây")
# ===============================================
# =========================
# STYLE CACHE
# =========================
STYLE_CACHE = {}
STYLE_CACHE_MAX = 64
def _cache_get(key):
return STYLE_CACHE.get(key, None)
def _cache_set(key, val):
if key in STYLE_CACHE:
STYLE_CACHE[key] = val
return
if len(STYLE_CACHE) >= STYLE_CACHE_MAX:
STYLE_CACHE.pop(next(iter(STYLE_CACHE)))
STYLE_CACHE[key] = val
@torch.inference_mode()
def synth_one_speaker(speaker_name: str, text_prompt: str,
                      denoise: float, avg_style: bool, stabilize: bool):
    """Synthesize one utterance for a closed-set speaker.

    Args:
        speaker_name: key into SPEAKER2REFS (selected in the dropdown).
        text_prompt: text to read; a "[id_1]" speaker tag is prepended when
            no "[id_..." tag is present.
        denoise: denoise strength forwarded to model.get_styles.
        avg_style: whether to average styles over the reference audio.
        stabilize: whether to stabilize speaking speed in model.generate.

    Returns:
        (wav_path, status): path to a temporary WAV file plus a status string,
        or (None, error_message) on any failure; unexpected exceptions return
        the full traceback as the message so it shows up in the UI.
    """
    try:
        if not speaker_name:
            return None, "Bạn chưa chọn speaker."
        info = SPEAKER2REFS.get(speaker_name, None)
        if info is None:
            return None, f"Speaker '{speaker_name}' không tồn tại trong speaker2refs.json."
        # info is a dict: {"path": ..., "lang": ..., "speed": ..., ...}
        if not isinstance(info, dict) or "path" not in info:
            return None, f"Format speaker2refs.json sai cho speaker '{speaker_name}'. Expect dict có field 'path'."
        ref_path = _abs_ref_path(info["path"])
        lang = info.get("lang", "vi")
        speed = float(info.get("speed", 1.0))
        if not os.path.isfile(ref_path):
            return None, f"Ref audio not found: {ref_path}"
        if not text_prompt or not text_prompt.strip():
            return None, "Bạn chưa nhập text."
        speakers = {
            "id_1": {"path": ref_path, "lang": lang, "speed": speed}
        }
        # Style extraction is expensive; cache per (speaker, denoise, avg_style).
        cache_key = (speaker_name, float(denoise), bool(avg_style))
        styles = _cache_get(cache_key)
        if styles is None:
            styles = model.get_styles(speakers, denoise=denoise, avg_style=avg_style)
            _cache_set(cache_key, styles)
        text_prompt = text_prompt.strip()
        if "[id_" not in text_prompt:
            text_prompt = "[id_1] " + text_prompt
        wav = model.generate(
            text_prompt,
            styles,
            stabilize=stabilize,
            n_merge=18,
            default_speaker="[id_1]"
        )
        wav = np.asarray(wav, dtype=np.float32)
        if wav.size == 0:
            return None, "Model output rỗng (0 samples). Kiểm tra phonemizer/espeak và tokenization."
        # Peak-normalize without zeroing near-silent output.
        peak = float(np.max(np.abs(wav)))
        if peak > 1e-6:
            wav = wav / peak
        out_f = tempfile.NamedTemporaryFile(delete=False, suffix=".wav")
        out_path = out_f.name
        out_f.close()
        # NOTE(review): assumes the model emits 16 kHz audio — confirm against
        # the StyleTTS2 config's sample rate.
        sample_rate = 16000
        sf.write(out_path, wav, samplerate=sample_rate)
        status = (
            "OK\n"
            f"speaker: {speaker_name}\n"
            f"ref: {ref_path}\n"
            f"lang: {lang}, speed: {speed}\n"
            # BUG FIX: the original divided by 1600016000 (typo), reporting a
            # duration ~100000x too small; divide by the actual sample rate.
            f"samples: {wav.shape[0]}, sec: {wav.shape[0]/sample_rate:.3f}\n"
            f"device: {device}"
        )
        return out_path, status
    except Exception:
        # Surface the full traceback in the UI instead of crashing the worker.
        return None, traceback.format_exc()
# =========================
# GRADIO UI
# =========================
with gr.Blocks() as demo:
    gr.HTML("<h2 style='text-align:center;'>TTS</h2>")
    # Closed-set speaker picker: choices come from speaker2refs.json.
    speaker_name = gr.Dropdown(
        choices=SPEAKER_CHOICES,
        label="Speaker Name (closed-set)",
        value=SPEAKER_CHOICES[0],
        interactive=True
    )
    text_prompt = gr.Textbox(
        label="Text Prompt",
        placeholder="Nhập câu tiếng Việt cần đọc...",
        lines=4
    )
    with gr.Row():
        denoise = gr.Slider(0.0, 1.0, step=0.1, value=0.3, label="Denoise Strength")
        avg_style = gr.Checkbox(label="Use Average Styles", value=True)
        stabilize = gr.Checkbox(label="Stabilize Speaking Speed", value=True)
    gen_button = gr.Button("Generate")
    synthesized_audio = gr.Audio(label="Generated Audio", type="filepath")
    status = gr.Textbox(label="Status", lines=6, interactive=False)
    # Run at most one synthesis at a time for this event; overflow waits in
    # the queue configured below.
    gen_button.click(
        fn=synth_one_speaker,
        inputs=[speaker_name, text_prompt, denoise, avg_style, stabilize],
        outputs=[synthesized_audio, status],
        concurrency_limit=1,
    )
# Gradio 4.x: queue() takes default_concurrency_limit; the old
# concurrency_count argument was removed.
demo.queue(max_size=8, default_concurrency_limit=1)
demo.launch()