Spaces:
Sleeping
Sleeping
File size: 7,101 Bytes
cd71d82 079aad4 41b25d7 cd71d82 079aad4 cd71d82 079aad4 cd71d82 079aad4 5a4f853 cff0d4e 5a4f853 cd71d82 8bca434 cff0d4e 6a6f4a2 cff0d4e cd71d82 079aad4 8c2d6d8 cd71d82 adebebc c5bfe73 adebebc cd71d82 079aad4 cd71d82 079aad4 bddffef 079aad4 cd71d82 079aad4 cd71d82 079aad4 cd71d82 079aad4 cd71d82 079aad4 cd71d82 079aad4 cd71d82 049e266 cd71d82 079aad4 049e266 cd71d82 079aad4 cd71d82 079aad4 cd71d82 079aad4 cd71d82 079aad4 cd71d82 0b0e589 079aad4 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 |
import os
import json
import tempfile
import traceback
import gradio as gr
import numpy as np
import soundfile as sf
import torch
from huggingface_hub import hf_hub_download
import yaml
from inference import StyleTTS2
# =========================
# PATHS
# =========================
SPACE_ROOT = os.path.dirname(os.path.abspath(__file__))
DATA_ROOT = os.path.join(SPACE_ROOT, "demo_data")
SPEAKER2REFS_PATH = os.path.join(DATA_ROOT, "speaker2refs.json")
# Model repo (ckpt + config) on the Hugging Face Hub.
CKPT_REPO = "stephenhoang/ttsStyleTTS2-ms152"
models_path = hf_hub_download(repo_id=CKPT_REPO, filename="epoch_00000.pth")
config_path = hf_hub_download(repo_id=CKPT_REPO, filename="config.yaml")
# Fix: the original passed a bare open() into yaml.safe_load and never
# closed the handle; a context manager releases it deterministically.
with open(config_path, "r", encoding="utf-8") as _cfg_file:
    cfg = yaml.safe_load(_cfg_file)
device = "cuda" if torch.cuda.is_available() else "cpu"
# ================= DEBUG SYMBOLS =================
# Best-effort dump of the tokenizer symbol table to the Space logs so
# vocab/tokenization issues can be diagnosed; never fatal on a bad config.
try:
    _sym = cfg['symbol']
    symbols = [
        ch
        for section in ('pad', 'punctuation', 'letters', 'letters_ipa', 'extend')
        for ch in _sym[section]
    ]
    print("\n========== SYMBOL DEBUG ==========")
    print("Total symbols (+pad):", len(symbols))
    for idx, sym in enumerate(symbols[:30]):
        print(idx, repr(sym))
    print("==================================\n")
except Exception as e:
    print("❌ SYMBOL DEBUG ERROR:", e)
# =================================================
# =========================
# LOAD speaker2refs.json
# =========================
# Closed-set demo: every selectable speaker must have an entry in this file.
if not os.path.isfile(SPEAKER2REFS_PATH):
    raise FileNotFoundError(f"speaker2refs.json not found: {SPEAKER2REFS_PATH}")
with open(SPEAKER2REFS_PATH, "r", encoding="utf-8") as fh:
    SPEAKER2REFS = json.load(fh)
SPEAKER_CHOICES = sorted(SPEAKER2REFS)  # iterating a dict yields its keys
if not SPEAKER_CHOICES:
    raise RuntimeError("speaker2refs.json is empty (no speakers found).")
def _abs_ref_path(p: str) -> str:
"""
Hỗ trợ cả 2 kiểu:
- "refs/id_1.wav"
- "demo_data/refs/id_1.wav"
"""
p = p.lstrip("./")
if os.path.isabs(p):
return p
if p.startswith("demo_data/"):
return os.path.join(SPACE_ROOT, p)
return os.path.join(DATA_ROOT, p)
# =========================
# LOAD MODEL
# =========================
model = StyleTTS2(config_path, models_path).eval().to(device)
# ================= VOCAB DEBUG =================
# Compare the checkpoint's text-embedding size against the runtime symbol
# table so vocab mismatches show up immediately in the Space logs.
ckpt = torch.load(models_path, map_location="cpu")
for k, v in ckpt["net"].items():
    if "embedding.weight" in k:
        print("✅ CKPT embedding:", v.shape)
print("✅ Runtime symbols:", len(symbols))
# If a text_tensor already exists in this scope, print a sample; otherwise
# note its absence. Fix: the original used a bare `except:`, which also
# swallows KeyboardInterrupt/SystemExit — only NameError is expected here.
try:
    print("✅ Text tokens sample:", text_tensor[:30])
except NameError:
    print("⚠️ text_tensor chưa tồn tại ở đây")
# ===============================================
# =========================
# STYLE CACHE
# =========================
# Small FIFO cache of computed speaker styles keyed by
# (speaker, denoise, avg_style); the oldest entry is evicted at capacity.
STYLE_CACHE = {}
STYLE_CACHE_MAX = 64
def _cache_get(key):
    """Return the cached value for `key`, or None when absent."""
    return STYLE_CACHE.get(key)
def _cache_set(key, val):
    """Insert or overwrite `key`; evict the oldest entry when full."""
    if key not in STYLE_CACHE and len(STYLE_CACHE) >= STYLE_CACHE_MAX:
        oldest = next(iter(STYLE_CACHE))
        del STYLE_CACHE[oldest]
    STYLE_CACHE[key] = val
@torch.inference_mode()
def synth_one_speaker(speaker_name: str, text_prompt: str,
                      denoise: float, avg_style: bool, stabilize: bool):
    """Synthesize one utterance for a closed-set speaker.

    Args:
        speaker_name: key into SPEAKER2REFS (from the dropdown).
        text_prompt: text to read; a "[id_1]" speaker tag is prepended
            automatically when the prompt carries none.
        denoise: denoise strength forwarded to model.get_styles.
        avg_style: average styles over the reference audio.
        stabilize: stabilize speaking speed in model.generate.

    Returns:
        (wav_path, status_text) on success, or (None, error_text) on any
        failure — errors are reported in the status box, never raised.
    """
    try:
        if not speaker_name:
            return None, "Bạn chưa chọn speaker."
        info = SPEAKER2REFS.get(speaker_name)
        if info is None:
            return None, f"Speaker '{speaker_name}' không tồn tại trong speaker2refs.json."
        # Each entry is a dict: {"path": ..., "lang": ..., "speed": ..., ...}
        if not isinstance(info, dict) or "path" not in info:
            return None, f"Format speaker2refs.json sai cho speaker '{speaker_name}'. Expect dict có field 'path'."
        ref_path = _abs_ref_path(info["path"])
        lang = info.get("lang", "vi")
        speed = float(info.get("speed", 1.0))
        if not os.path.isfile(ref_path):
            return None, f"Ref audio not found: {ref_path}"
        if not text_prompt or not text_prompt.strip():
            return None, "Bạn chưa nhập text."
        speakers = {
            "id_1": {"path": ref_path, "lang": lang, "speed": speed}
        }
        # Styles depend only on (speaker, denoise, avg_style) — not on the
        # text — so they are memoized across requests.
        cache_key = (speaker_name, float(denoise), bool(avg_style))
        styles = _cache_get(cache_key)
        if styles is None:
            styles = model.get_styles(speakers, denoise=denoise, avg_style=avg_style)
            _cache_set(cache_key, styles)
        text_prompt = text_prompt.strip()
        if "[id_" not in text_prompt:
            text_prompt = "[id_1] " + text_prompt
        wav = model.generate(
            text_prompt,
            styles,
            stabilize=stabilize,
            n_merge=18,
            default_speaker="[id_1]"
        )
        wav = np.asarray(wav, dtype=np.float32)
        if wav.size == 0:
            return None, "Model output rỗng (0 samples). Kiểm tra phonemizer/espeak và tokenization."
        # Peak-normalize (keeps relative dynamics; guards against silence).
        peak = float(np.max(np.abs(wav)))
        if peak > 1e-6:
            wav = wav / peak
        sample_rate = 16000  # assumes the model outputs 16 kHz — TODO confirm
        out_f = tempfile.NamedTemporaryFile(delete=False, suffix=".wav")
        out_path = out_f.name
        out_f.close()
        sf.write(out_path, wav, samplerate=sample_rate)
        # Fix: duration was previously divided by 1600016000 (a typo'd,
        # doubled sample rate), reporting ~0.000 s for every clip.
        status = (
            "OK\n"
            f"speaker: {speaker_name}\n"
            f"ref: {ref_path}\n"
            f"lang: {lang}, speed: {speed}\n"
            f"samples: {wav.shape[0]}, sec: {wav.shape[0] / sample_rate:.3f}\n"
            f"device: {device}"
        )
        return out_path, status
    except Exception:
        # Surface the full traceback in the status box for debugging.
        return None, traceback.format_exc()
# =========================
# GRADIO UI
# =========================
with gr.Blocks() as demo:
    gr.HTML("<h2 style='text-align:center;'>TTS</h2>")
    # Closed-set speaker selection: choices come from speaker2refs.json.
    speaker_name = gr.Dropdown(
        choices=SPEAKER_CHOICES,
        label="Speaker Name (closed-set)",
        value=SPEAKER_CHOICES[0],
        interactive=True
    )
    text_prompt = gr.Textbox(
        label="Text Prompt",
        placeholder="Nhập câu tiếng Việt cần đọc...",
        lines=4
    )
    with gr.Row():
        denoise = gr.Slider(0.0, 1.0, step=0.1, value=0.3, label="Denoise Strength")
        avg_style = gr.Checkbox(label="Use Average Styles", value=True)
        stabilize = gr.Checkbox(label="Stabilize Speaking Speed", value=True)
    gen_button = gr.Button("Generate")
    synthesized_audio = gr.Audio(label="Generated Audio", type="filepath")
    status = gr.Textbox(label="Status", lines=6, interactive=False)
    # One synthesis at a time for this event to avoid GPU/model contention.
    gen_button.click(
        fn=synth_one_speaker,
        inputs=[speaker_name, text_prompt, denoise, avg_style, stabilize],
        outputs=[synthesized_audio, status],
        concurrency_limit=1,
    )
# Gradio 4.x queue API: use default_concurrency_limit, not the removed
# concurrency_count keyword.
demo.queue(max_size=8, default_concurrency_limit=1)
demo.launch()
|