# StyleTTS2 Vietnamese TTS demo — Hugging Face Spaces app (Gradio).
import os
import json
import tempfile
import traceback

import gradio as gr
import numpy as np
import soundfile as sf
import torch
import yaml
from huggingface_hub import hf_hub_download

from inference import StyleTTS2

# =========================
# PATHS
# =========================
SPACE_ROOT = os.path.dirname(os.path.abspath(__file__))
DATA_ROOT = os.path.join(SPACE_ROOT, "demo_data")
SPEAKER2REFS_PATH = os.path.join(DATA_ROOT, "speaker2refs.json")

# Model repo (ckpt + config)
CKPT_REPO = "stephenhoang/ttsStyleTTS2-ms152"

# Download checkpoint + config from the Hub (hf_hub_download caches locally).
models_path = hf_hub_download(repo_id=CKPT_REPO, filename="epoch_00000.pth")
config_path = hf_hub_download(repo_id=CKPT_REPO, filename="config.yaml")

# Fix: the original leaked the file handle via a bare open(); use a context
# manager so the config file is always closed.
with open(config_path, "r", encoding="utf-8") as _cfg_file:
    cfg = yaml.safe_load(_cfg_file)

device = "cuda" if torch.cuda.is_available() else "cpu"
# =======================
# ================= DEBUG SYMBOLS =================
# Rebuild the symbol table the same way the tokenizer does and print a sample,
# so a vocab mismatch with the checkpoint embedding is visible in the logs.
symbols = []  # fix: fallback binding so later debug code can't hit a NameError
try:
    symbols = (
        list(cfg['symbol']['pad']) +
        list(cfg['symbol']['punctuation']) +
        list(cfg['symbol']['letters']) +
        list(cfg['symbol']['letters_ipa']) +
        list(cfg['symbol']['extend'])
    )
    print("\n========== SYMBOL DEBUG ==========")
    print("Total symbols (+pad):", len(symbols))
    for i in range(min(30, len(symbols))):
        print(i, repr(symbols[i]))
    print("==================================\n")
except Exception as e:
    # Debug-only path: report and keep going, the app can still start.
    print("❌ SYMBOL DEBUG ERROR:", e)
# =================================================
# =========================
# LOAD speaker2refs.json
# =========================
# speaker2refs.json maps each speaker name to its reference-audio metadata.
if not os.path.isfile(SPEAKER2REFS_PATH):
    raise FileNotFoundError(f"speaker2refs.json not found: {SPEAKER2REFS_PATH}")

with open(SPEAKER2REFS_PATH, "r", encoding="utf-8") as f:
    SPEAKER2REFS = json.load(f)

# Closed set of selectable speakers, in stable alphabetical order.
SPEAKER_CHOICES = sorted(SPEAKER2REFS)
if not SPEAKER_CHOICES:
    raise RuntimeError("speaker2refs.json is empty (no speakers found).")
| def _abs_ref_path(p: str) -> str: | |
| """ | |
| Hỗ trợ cả 2 kiểu: | |
| - "refs/id_1.wav" | |
| - "demo_data/refs/id_1.wav" | |
| """ | |
| p = p.lstrip("./") | |
| if os.path.isabs(p): | |
| return p | |
| if p.startswith("demo_data/"): | |
| return os.path.join(SPACE_ROOT, p) | |
| return os.path.join(DATA_ROOT, p) | |
# =========================
# LOAD MODEL
# =========================
model = StyleTTS2(config_path, models_path).eval().to(device)

# ================= VOCAB DEBUG =================
# Compare the checkpoint's text-embedding size against the runtime symbol
# table; a mismatch here would explain garbled or empty synthesis.
ckpt = torch.load(models_path, map_location="cpu")
for k, v in ckpt["net"].items():
    if "embedding.weight" in k:
        print("✅ CKPT embedding:", v.shape)
print("✅ Runtime symbols:", len(symbols))

# text_tensor usually does not exist at module scope; print a sample when it
# does. Fix: catch only NameError instead of the original bare `except:`,
# which silently swallowed every exception (including KeyboardInterrupt).
try:
    print("✅ Text tokens sample:", text_tensor[:30])
except NameError:
    print("⚠️ text_tensor chưa tồn tại ở đây")
# ===============================================
| # ========================= | |
| # STYLE CACHE | |
| # ========================= | |
| STYLE_CACHE = {} | |
| STYLE_CACHE_MAX = 64 | |
| def _cache_get(key): | |
| return STYLE_CACHE.get(key, None) | |
| def _cache_set(key, val): | |
| if key in STYLE_CACHE: | |
| STYLE_CACHE[key] = val | |
| return | |
| if len(STYLE_CACHE) >= STYLE_CACHE_MAX: | |
| STYLE_CACHE.pop(next(iter(STYLE_CACHE))) | |
| STYLE_CACHE[key] = val | |
def synth_one_speaker(speaker_name: str, text_prompt: str,
                      denoise: float, avg_style: bool, stabilize: bool):
    """Synthesize *text_prompt* in the voice of *speaker_name*.

    Returns (wav_path, status_message). On failure wav_path is None and the
    status message describes the problem (user-facing errors in Vietnamese,
    unexpected errors as a full traceback for the Status box).
    """
    try:
        if not speaker_name:
            return None, "Bạn chưa chọn speaker."

        info = SPEAKER2REFS.get(speaker_name, None)
        if info is None:
            return None, f"Speaker '{speaker_name}' không tồn tại trong speaker2refs.json."

        # info is a dict: {"path":..., "lang":..., "speed":..., ...}
        if not isinstance(info, dict) or "path" not in info:
            return None, f"Format speaker2refs.json sai cho speaker '{speaker_name}'. Expect dict có field 'path'."

        ref_path = _abs_ref_path(info["path"])
        lang = info.get("lang", "vi")
        speed = float(info.get("speed", 1.0))

        if not os.path.isfile(ref_path):
            return None, f"Ref audio not found: {ref_path}"
        if not text_prompt or not text_prompt.strip():
            return None, "Bạn chưa nhập text."

        speakers = {
            "id_1": {"path": ref_path, "lang": lang, "speed": speed}
        }

        # Style extraction is the expensive step — cache per
        # (speaker, denoise, avg_style).
        cache_key = (speaker_name, float(denoise), bool(avg_style))
        styles = _cache_get(cache_key)
        if styles is None:
            styles = model.get_styles(speakers, denoise=denoise, avg_style=avg_style)
            _cache_set(cache_key, styles)

        text_prompt = text_prompt.strip()
        # The model expects an inline speaker tag; prepend one when missing.
        if "[id_" not in text_prompt:
            text_prompt = "[id_1] " + text_prompt

        wav = model.generate(
            text_prompt,
            styles,
            stabilize=stabilize,
            n_merge=18,
            default_speaker="[id_1]"
        )
        wav = np.asarray(wav, dtype=np.float32)
        if wav.size == 0:
            return None, "Model output rỗng (0 samples). Kiểm tra phonemizer/espeak và tokenization."

        # Peak-normalize (keeps the audio audible without clipping).
        peak = float(np.max(np.abs(wav)))
        if peak > 1e-6:
            wav = wav / peak

        # Write to a temp .wav; Gradio serves it via type="filepath".
        out_f = tempfile.NamedTemporaryFile(delete=False, suffix=".wav")
        out_path = out_f.name
        out_f.close()
        sf.write(out_path, wav, samplerate=16000)

        status = (
            "OK\n"
            f"speaker: {speaker_name}\n"
            f"ref: {ref_path}\n"
            f"lang: {lang}, speed: {speed}\n"
            # Fix: the divisor was garbled as 1600016000; the file is written
            # at 16000 Hz, so duration in seconds is samples / 16000.
            f"samples: {wav.shape[0]}, sec: {wav.shape[0] / 16000:.3f}\n"
            f"device: {device}"
        )
        return out_path, status

    except Exception:
        # Surface the full traceback in the UI status box for debugging.
        return None, traceback.format_exc()
# =========================
# GRADIO UI
# =========================
with gr.Blocks() as demo:
    gr.HTML("<h2 style='text-align:center;'>TTS</h2>")

    # Closed-set speaker selection driven by speaker2refs.json.
    speaker_name = gr.Dropdown(
        choices=SPEAKER_CHOICES,
        value=SPEAKER_CHOICES[0],
        label="Speaker Name (closed-set)",
        interactive=True,
    )
    text_prompt = gr.Textbox(
        label="Text Prompt",
        placeholder="Nhập câu tiếng Việt cần đọc...",
        lines=4,
    )

    with gr.Row():
        denoise = gr.Slider(0.0, 1.0, step=0.1, value=0.3, label="Denoise Strength")
        avg_style = gr.Checkbox(label="Use Average Styles", value=True)
        stabilize = gr.Checkbox(label="Stabilize Speaking Speed", value=True)

    gen_button = gr.Button("Generate")
    synthesized_audio = gr.Audio(label="Generated Audio", type="filepath")
    status = gr.Textbox(label="Status", lines=6, interactive=False)

    # One synthesis at a time; the model is not safe to run concurrently.
    gen_button.click(
        fn=synth_one_speaker,
        inputs=[speaker_name, text_prompt, denoise, avg_style, stabilize],
        outputs=[synthesized_audio, status],
        concurrency_limit=1,
    )

# Gradio 4.x: configure queueing via queue(); concurrency_count was removed.
demo.queue(max_size=8, default_concurrency_limit=1)
demo.launch()