import os import json import tempfile import traceback import gradio as gr import numpy as np import soundfile as sf import torch from huggingface_hub import hf_hub_download import yaml from inference import StyleTTS2 # ========================= # PATHS # ========================= SPACE_ROOT = os.path.dirname(os.path.abspath(__file__)) DATA_ROOT = os.path.join(SPACE_ROOT, "demo_data") SPEAKER2REFS_PATH = os.path.join(DATA_ROOT, "speaker2refs.json") # Model repo (ckpt + config) CKPT_REPO = "stephenhoang/ttsStyleTTS2-ms152" models_path = hf_hub_download(repo_id=CKPT_REPO, filename="epoch_00000.pth") config_path = hf_hub_download(repo_id=CKPT_REPO, filename="config.yaml") cfg = yaml.safe_load(open(config_path, "r", encoding="utf-8")) device = "cuda" if torch.cuda.is_available() else "cpu" # ======================= # ================= DEBUG SYMBOLS ================= try: symbols = ( list(cfg['symbol']['pad']) + list(cfg['symbol']['punctuation']) + list(cfg['symbol']['letters']) + list(cfg['symbol']['letters_ipa']) + list(cfg['symbol']['extend']) ) print("\n========== SYMBOL DEBUG ==========") print("Total symbols (+pad):", len(symbols)) for i in range(min(30, len(symbols))): print(i, repr(symbols[i])) print("==================================\n") except Exception as e: print("❌ SYMBOL DEBUG ERROR:", e) # ================================================= # LOAD speaker2refs.json # ========================= if not os.path.isfile(SPEAKER2REFS_PATH): raise FileNotFoundError(f"speaker2refs.json not found: {SPEAKER2REFS_PATH}") with open(SPEAKER2REFS_PATH, "r", encoding="utf-8") as f: SPEAKER2REFS = json.load(f) SPEAKER_CHOICES = sorted(SPEAKER2REFS.keys()) if not SPEAKER_CHOICES: raise RuntimeError("speaker2refs.json is empty (no speakers found).") def _abs_ref_path(p: str) -> str: """ Hỗ trợ cả 2 kiểu: - "refs/id_1.wav" - "demo_data/refs/id_1.wav" """ p = p.lstrip("./") if os.path.isabs(p): return p if p.startswith("demo_data/"): return os.path.join(SPACE_ROOT, p) return os.path.join(DATA_ROOT, p) # ========================= # LOAD MODEL # ========================= model = StyleTTS2(config_path, models_path).eval().to(device) # ================= VOCAB DEBUG ================= ckpt = torch.load(models_path, map_location="cpu") for k, v in ckpt["net"].items(): if "embedding.weight" in k: print("✅ CKPT embedding:", v.shape) print("✅ Runtime symbols:", len(symbols)) # Nếu có sẵn text_tensor ở scope thì in, còn không thì bỏ dòng này try: print("✅ Text tokens sample:", text_tensor[:30]) except: print("⚠️ text_tensor chưa tồn tại ở đây") # =============================================== # ========================= # STYLE CACHE # ========================= STYLE_CACHE = {} STYLE_CACHE_MAX = 64 def _cache_get(key): return STYLE_CACHE.get(key, None) def _cache_set(key, val): if key in STYLE_CACHE: STYLE_CACHE[key] = val return if len(STYLE_CACHE) >= STYLE_CACHE_MAX: STYLE_CACHE.pop(next(iter(STYLE_CACHE))) STYLE_CACHE[key] = val @torch.inference_mode() def synth_one_speaker(speaker_name: str, text_prompt: str, denoise: float, avg_style: bool, stabilize: bool): try: if not speaker_name: return None, "Bạn chưa chọn speaker." info = SPEAKER2REFS.get(speaker_name, None) if info is None: return None, f"Speaker '{speaker_name}' không tồn tại trong speaker2refs.json." # info là dict: {"path":..., "lang":..., "speed":..., ...} if not isinstance(info, dict) or "path" not in info: return None, f"Format speaker2refs.json sai cho speaker '{speaker_name}'. Expect dict có field 'path'." ref_path = _abs_ref_path(info["path"]) lang = info.get("lang", "vi") speed = float(info.get("speed", 1.0)) if not os.path.isfile(ref_path): return None, f"Ref audio not found: {ref_path}" if not text_prompt or not text_prompt.strip(): return None, "Bạn chưa nhập text." speakers = { "id_1": {"path": ref_path, "lang": lang, "speed": speed} } cache_key = (speaker_name, float(denoise), bool(avg_style)) styles = _cache_get(cache_key) if styles is None: styles = model.get_styles(speakers, denoise=denoise, avg_style=avg_style) _cache_set(cache_key, styles) text_prompt = text_prompt.strip() if "[id_" not in text_prompt: text_prompt = "[id_1] " + text_prompt wav = model.generate( text_prompt, styles, stabilize=stabilize, n_merge=18, default_speaker="[id_1]" ) wav = np.asarray(wav, dtype=np.float32) if wav.size == 0: return None, "Model output rỗng (0 samples). Kiểm tra phonemizer/espeak và tokenization." # normalize (không làm mất tiếng) peak = float(np.max(np.abs(wav))) if peak > 1e-6: wav = wav / peak out_f = tempfile.NamedTemporaryFile(delete=False, suffix=".wav") out_path = out_f.name out_f.close() sf.write(out_path, wav, samplerate=16000) status = ( "OK\n" f"speaker: {speaker_name}\n" f"ref: {ref_path}\n" f"lang: {lang}, speed: {speed}\n" f"samples: {wav.shape[0]}, sec: {wav.shape[0]/1600016000:.3f}\n" f"device: {device}" ) return out_path, status except Exception: return None, traceback.format_exc() # ========================= # GRADIO UI # ========================= with gr.Blocks() as demo: gr.HTML("