File size: 7,101 Bytes
cd71d82
 
 
 
 
 
 
 
 
079aad4
41b25d7
cd71d82
 
 
 
079aad4
cd71d82
079aad4
 
cd71d82
 
079aad4
5a4f853
 
 
cff0d4e
5a4f853
cd71d82
 
8bca434
cff0d4e
 
 
 
 
 
 
 
 
 
 
6a6f4a2
cff0d4e
 
 
 
 
 
 
 
 
 
cd71d82
 
 
 
 
 
 
 
 
 
 
 
079aad4
 
 
 
 
 
 
 
 
 
 
 
8c2d6d8
cd71d82
 
 
 
adebebc
 
 
c5bfe73
adebebc
 
 
 
 
 
 
 
 
 
 
cd71d82
 
079aad4
cd71d82
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
079aad4
 
 
 
 
 
 
bddffef
079aad4
 
 
cd71d82
 
 
 
 
 
 
 
079aad4
cd71d82
 
 
 
 
079aad4
cd71d82
 
 
 
079aad4
 
 
 
 
 
 
 
 
cd71d82
079aad4
 
 
cd71d82
079aad4
 
 
 
cd71d82
 
 
 
049e266
cd71d82
 
 
 
079aad4
 
049e266
cd71d82
 
 
 
 
 
079aad4
cd71d82
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
079aad4
cd71d82
 
 
 
 
079aad4
cd71d82
 
 
 
079aad4
 
cd71d82
0b0e589
079aad4
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
import os
import json
import tempfile
import traceback

import gradio as gr
import numpy as np
import soundfile as sf
import torch
from huggingface_hub import hf_hub_download
import yaml

from inference import StyleTTS2

# =========================
# PATHS
# =========================
SPACE_ROOT = os.path.dirname(os.path.abspath(__file__))          # directory containing this app file
DATA_ROOT = os.path.join(SPACE_ROOT, "demo_data")                # bundled demo assets (refs, mappings)
SPEAKER2REFS_PATH = os.path.join(DATA_ROOT, "speaker2refs.json") # speaker -> reference-audio mapping

# Model repo (ckpt + config) — fetched once from the Hugging Face Hub
# (hf_hub_download caches locally, so restarts don't re-download).
CKPT_REPO = "stephenhoang/ttsStyleTTS2-ms152"
models_path = hf_hub_download(repo_id=CKPT_REPO, filename="epoch_00000.pth")
config_path = hf_hub_download(repo_id=CKPT_REPO, filename="config.yaml")
# Use a context manager so the config file handle is closed deterministically
# (the original `yaml.safe_load(open(...))` left it to the GC).
with open(config_path, "r", encoding="utf-8") as _cfg_file:
    cfg = yaml.safe_load(_cfg_file)

device = "cuda" if torch.cuda.is_available() else "cpu"

# =======================
# ================= DEBUG SYMBOLS =================
# Build the runtime symbol table from the config and dump the first entries,
# so a vocab/embedding mismatch with the checkpoint is visible in the logs.
try:
    symbols = []
    for _group in ("pad", "punctuation", "letters", "letters_ipa", "extend"):
        symbols.extend(cfg['symbol'][_group])

    print("\n========== SYMBOL DEBUG ==========")
    print("Total symbols (+pad):", len(symbols))

    # Show up to the first 30 symbols with their token indices.
    for _idx, _sym in enumerate(symbols[:30]):
        print(_idx, repr(_sym))

    print("==================================\n")

except Exception as e:
    print("❌ SYMBOL DEBUG ERROR:", e)
# =================================================

# LOAD speaker2refs.json
# =========================
# Fail fast with a clear message if the demo mapping file is missing.
if not os.path.isfile(SPEAKER2REFS_PATH):
    raise FileNotFoundError(f"speaker2refs.json not found: {SPEAKER2REFS_PATH}")

with open(SPEAKER2REFS_PATH, "r", encoding="utf-8") as f:
    SPEAKER2REFS = json.load(f)

# Dropdown choices for the UI; sorted for a stable, predictable order.
SPEAKER_CHOICES = sorted(SPEAKER2REFS.keys())
if not SPEAKER_CHOICES:
    raise RuntimeError("speaker2refs.json is empty (no speakers found).")

def _abs_ref_path(p: str) -> str:
    """
    Hỗ trợ cả 2 kiểu:
      - "refs/id_1.wav"
      - "demo_data/refs/id_1.wav"
    """
    p = p.lstrip("./")
    if os.path.isabs(p):
        return p
    if p.startswith("demo_data/"):
        return os.path.join(SPACE_ROOT, p)
    return os.path.join(DATA_ROOT, p)

# =========================
# LOAD MODEL
# =========================
model = StyleTTS2(config_path, models_path).eval().to(device)
# ================= VOCAB DEBUG =================
# Re-load the checkpoint on CPU just to inspect the text-embedding matrix and
# confirm its vocab dimension matches the runtime symbol table built above.
# NOTE(review): torch.load on a downloaded checkpoint uses pickle — only safe
# because the repo is trusted.
ckpt = torch.load(models_path, map_location="cpu")

for k, v in ckpt["net"].items():
    if "embedding.weight" in k:
        print("✅ CKPT embedding:", v.shape)

print("✅ Runtime symbols:", len(symbols))

# If a text_tensor happens to exist in this scope, print a sample; otherwise
# just log that it is not defined here.
try:
    print("✅ Text tokens sample:", text_tensor[:30])
except NameError:
    # BUG FIX: was a bare `except:`, which also swallows SystemExit and
    # KeyboardInterrupt; the only expected failure is an undefined name.
    print("⚠️ text_tensor chưa tồn tại ở đây")
# ===============================================

# =========================
# STYLE CACHE
# =========================
STYLE_CACHE = {}
STYLE_CACHE_MAX = 64

def _cache_get(key):
    return STYLE_CACHE.get(key, None)

def _cache_set(key, val):
    if key in STYLE_CACHE:
        STYLE_CACHE[key] = val
        return
    if len(STYLE_CACHE) >= STYLE_CACHE_MAX:
        STYLE_CACHE.pop(next(iter(STYLE_CACHE)))
    STYLE_CACHE[key] = val

@torch.inference_mode()
def synth_one_speaker(speaker_name: str, text_prompt: str,
                      denoise: float, avg_style: bool, stabilize: bool):
    """Synthesize one utterance for a closed-set speaker.

    Args:
        speaker_name: key into SPEAKER2REFS (picked from the UI dropdown).
        text_prompt: text to read; "[id_1] " is prepended when no speaker
            tag is present.
        denoise: denoise strength forwarded to model.get_styles.
        avg_style: average styles over the reference audio.
        stabilize: stabilize speaking speed in model.generate.

    Returns:
        (wav_path, status): wav_path is a temp .wav file path, or None on
        any failure, in which case status carries the reason (user-facing
        messages are intentionally in Vietnamese) or a full traceback.
    """
    SAMPLE_RATE = 16000  # output rate: used for both writing and the duration report
    try:
        if not speaker_name:
            return None, "Bạn chưa chọn speaker."

        info = SPEAKER2REFS.get(speaker_name, None)
        if info is None:
            return None, f"Speaker '{speaker_name}' không tồn tại trong speaker2refs.json."

        # info is a dict: {"path":..., "lang":..., "speed":..., ...}
        if not isinstance(info, dict) or "path" not in info:
            return None, f"Format speaker2refs.json sai cho speaker '{speaker_name}'. Expect dict có field 'path'."

        ref_path = _abs_ref_path(info["path"])
        lang = info.get("lang", "vi")
        speed = float(info.get("speed", 1.0))

        if not os.path.isfile(ref_path):
            return None, f"Ref audio not found: {ref_path}"

        if not text_prompt or not text_prompt.strip():
            return None, "Bạn chưa nhập text."

        speakers = {
            "id_1": {"path": ref_path, "lang": lang, "speed": speed}
        }

        # Style extraction is the expensive step; cache per
        # (speaker, denoise, avg_style) so repeated prompts reuse it.
        cache_key = (speaker_name, float(denoise), bool(avg_style))
        styles = _cache_get(cache_key)
        if styles is None:
            styles = model.get_styles(speakers, denoise=denoise, avg_style=avg_style)
            _cache_set(cache_key, styles)

        text_prompt = text_prompt.strip()
        if "[id_" not in text_prompt:
            text_prompt = "[id_1] " + text_prompt

        wav = model.generate(
            text_prompt,
            styles,
            stabilize=stabilize,
            n_merge=18,
            default_speaker="[id_1]"
        )

        wav = np.asarray(wav, dtype=np.float32)
        if wav.size == 0:
            return None, "Model output rỗng (0 samples). Kiểm tra phonemizer/espeak và tokenization."

        # Peak-normalize (keeps the audio audible without clipping).
        peak = float(np.max(np.abs(wav)))
        if peak > 1e-6:
            wav = wav / peak

        out_f = tempfile.NamedTemporaryFile(delete=False, suffix=".wav")
        out_path = out_f.name
        out_f.close()
        sf.write(out_path, wav, samplerate=SAMPLE_RATE)

        status = (
            "OK\n"
            f"speaker: {speaker_name}\n"
            f"ref: {ref_path}\n"
            f"lang: {lang}, speed: {speed}\n"
            # BUG FIX: the duration divisor was garbled (1600016000); the
            # file is written at SAMPLE_RATE Hz, so sec = samples / rate.
            f"samples: {wav.shape[0]}, sec: {wav.shape[0]/SAMPLE_RATE:.3f}\n"
            f"device: {device}"
        )
        return out_path, status

    except Exception:
        # Surface the full traceback in the UI status box for debugging.
        return None, traceback.format_exc()

# =========================
# GRADIO UI
# =========================
with gr.Blocks() as demo:
    gr.HTML("<h2 style='text-align:center;'>TTS</h2>")

    # Closed-set speaker selection: only names present in speaker2refs.json.
    speaker_name = gr.Dropdown(
        choices=SPEAKER_CHOICES,
        label="Speaker Name (closed-set)",
        value=SPEAKER_CHOICES[0],
        interactive=True
    )

    text_prompt = gr.Textbox(
        label="Text Prompt",
        placeholder="Nhập câu tiếng Việt cần đọc...",
        lines=4
    )

    # Synthesis knobs mapped 1:1 onto synth_one_speaker's parameters.
    with gr.Row():
        denoise = gr.Slider(0.0, 1.0, step=0.1, value=0.3, label="Denoise Strength")
        avg_style = gr.Checkbox(label="Use Average Styles", value=True)
        stabilize = gr.Checkbox(label="Stabilize Speaking Speed", value=True)

    gen_button = gr.Button("Generate")
    synthesized_audio = gr.Audio(label="Generated Audio", type="filepath")
    status = gr.Textbox(label="Status", lines=6, interactive=False)

    # concurrency_limit=1: one synthesis at a time to avoid GPU contention.
    gen_button.click(
        fn=synth_one_speaker,
        inputs=[speaker_name, text_prompt, denoise, avg_style, stabilize],
        outputs=[synthesized_audio, status],
        concurrency_limit=1,
    )

# Gradio 4.x: use queue() with default_concurrency_limit; the old
# `concurrency_count` argument no longer exists.
demo.queue(max_size=8, default_concurrency_limit=1)
demo.launch()