Spaces:
Sleeping
Sleeping
# -*- coding: utf-8 -*-
"""
Gradio app.py - StyleTTS2-vi with precomputed style embeddings (.pth)
- Compact UI with collapsible accordions
- Style Mixer: 4 fixed slots (Kore, Puck, Algenib, Leda); only the weights are adjustable; auto-normalize
- Reference samples inside an accordion
"""
| import os, re, glob, time, yaml, torch, librosa, numpy as np, gradio as gr | |
| from munch import Munch | |
| from soe_vinorm import SoeNormalizer | |
# ==============================================================
# Basic configuration
# ==============================================================
# Run on the GPU when torch reports one, otherwise fall back to the CPU.
DEVICE = "cpu"
if torch.cuda.is_available():
    DEVICE = "cuda"

# Sample rate handed to post-processing and returned alongside generated audio.
SR_OUT = 24000

# Default inference settings: style blending factors and diffusion sampler options.
ALPHA = 0.0
BETA = 0.0
DIFFUSION_STEPS = 5
EMBEDDING_SCALE = 1.0

# Directory that holds the reference audio samples (.wav).
REF_DIR = "ref_voice"
# ==============================================================
# Import StyleTTS2 modules
# ==============================================================
| from models import * | |
| from utils import * | |
| from models import build_model | |
| from text_utils import TextCleaner | |
| from Utils_extend_v1.PLBERT.util import load_plbert | |
| from Modules.diffusion.sampler import DiffusionSampler, ADPM2Sampler, KarrasSchedule | |
# Shared TextCleaner instance; inference_one calls it on the phonemized text
# to obtain the token sequence fed to the model.
textcleaner = TextCleaner()
# ==============================================================
# Load model and checkpoint
# ==============================================================
from collections import OrderedDict

from huggingface_hub import hf_hub_download

# Fetch the checkpoint file from the Hugging Face Hub into Models/gemini_vi.
hf_hub_download(
    repo_id="ltphuongunited/styletts2_vi",
    filename="gemini_2nd_00045.pth",
    local_dir="Models/gemini_vi",
    local_dir_use_symlinks=False,
)
CHECKPOINT_PTH = "Models/gemini_vi/gemini_2nd_00045.pth"
CONFIG_PATH = "Models/gemini_vi/config_gemini_vi_en.yml"

# Read the YAML config through a context manager so the handle is closed.
with open(CONFIG_PATH, encoding="utf-8") as _config_file:
    config = yaml.safe_load(_config_file)

ASR_config = config.get("ASR_config", False)
ASR_path = config.get("ASR_path", False)
F0_path = config.get("F0_path", False)
PLBERT_dir = config.get("PLBERT_dir", False)

# Auxiliary networks that build_model receives alongside the model parameters.
text_aligner = load_ASR_models(ASR_path, ASR_config)
pitch_extractor = load_F0_models(F0_path)
plbert = load_plbert(PLBERT_dir)

model_params = recursive_munch(config["model_params"])
model = build_model(model_params, text_aligner, pitch_extractor, plbert)
for _name in model:
    model[_name].to(DEVICE)
    model[_name].eval()

# NOTE(review): torch.load unpickles the downloaded file; consider passing
# weights_only=True if the installed torch version and this checkpoint allow it.
ckpt = torch.load(CHECKPOINT_PTH, map_location="cpu")["net"]
for key in model:
    if key not in ckpt:
        continue
    try:
        model[key].load_state_dict(ckpt[key])
    except RuntimeError:
        # Retry with a leading "module." removed from the parameter names
        # (presumably left by a DataParallel wrapper at training time — confirm).
        # Only names that really carry the prefix are rewritten.
        new_state = OrderedDict(
            (k[len("module."):] if k.startswith("module.") else k, v)
            for k, v in ckpt[key].items()
        )
        _load_result = model[key].load_state_dict(new_state, strict=False)
        # strict=False ignores mismatched names; report them instead of hiding them.
        if _load_result.missing_keys or _load_result.unexpected_keys:
            print(
                f"[WARN] '{key}': {len(_load_result.missing_keys)} missing and "
                f"{len(_load_result.unexpected_keys)} unexpected keys after non-strict load."
            )
# Diffusion sampler built on the model's diffusion module. inference_one calls
# it to sample a style vector unless both alpha and beta are 0.
sampler = DiffusionSampler(
    model.diffusion.diffusion,
    sampler=ADPM2Sampler(),
    sigma_schedule=KarrasSchedule(sigma_min=1e-4, sigma_max=3.0, rho=9.0),
    clamp=False,
)
# ==============================================================
# Phonemizer
# ==============================================================
import phonemizer

# Espeak backend configured for language "vi"; punctuation is preserved and
# stress marks are requested. Used by phonemize_text.
vi_phonemizer = phonemizer.backend.EspeakBackend(
    language="vi", preserve_punctuation=True, with_stress=True
)
def phonemize_text(text: str) -> str:
    """Phonemize ``text`` with the Vietnamese espeak backend.

    The literal markers "(en)" and "(vi)" are removed from the backend output
    and surrounding whitespace is stripped.
    """
    phonemes = vi_phonemizer.phonemize([text])[0]
    for marker in ("(en)", "(vi)"):
        phonemes = phonemes.replace(marker, "")
    return phonemes.strip()
def length_to_mask(lengths: torch.LongTensor) -> torch.Tensor:
    """Build a boolean padding mask from a 1-D tensor of sequence lengths.

    The result has one row per entry of ``lengths`` and ``lengths.max()``
    columns; a cell is True where its position lies at or beyond that row's
    length (i.e. padding) and False inside the sequence.
    """
    positions = torch.arange(lengths.max())
    grid = positions.unsqueeze(0).expand(lengths.shape[0], -1).type_as(lengths)
    return (grid + 1) > lengths.unsqueeze(1)
# ==============================================================
# Load precomputed style embeddings
# ==============================================================
STYLE_PTH = "Models/styles_speaker_parallel.pth"
print(f"Loading precomputed styles: {STYLE_PTH}")
styles_dict = torch.load(STYLE_PTH, map_location=DEVICE)

# Fallback speaker used when the mixer is empty: the first preferred name that
# exists in styles_dict, otherwise the first key of styles_dict.
SPEAKER_ORDER_PREF = ["Kore", "Puck", "Algenib", "Leda"]
_first_available = list(styles_dict.keys())[0]
_preferred_available = [name for name in SPEAKER_ORDER_PREF if name in styles_dict]
DEFAULT_SPK = _preferred_available[0] if _preferred_available else _first_available
def get_style_by_length(speaker: str, phoneme_len: int):
    """Look up the precomputed style vector of ``speaker`` for a phoneme count.

    The row index is ``phoneme_len`` clamped to ``[1, number_of_rows]`` minus
    one. Per the original comments the table is expected to be [510, 1, 256]
    or [510, 256] — confirm against the .pth file. The selected row has its
    leading singleton dimensions squeezed and is returned with a new leading
    axis, on DEVICE.
    """
    table = styles_dict[speaker]
    row = min(max(phoneme_len, 1), table.shape[0]) - 1
    vec = table[row]
    # Drop leading singleton axes: first a 3-D row, then a 2-D one.
    for rank in (3, 2):
        if vec.ndim == rank:
            vec = vec.squeeze(0)
    return vec.unsqueeze(0).to(DEVICE)
# ==============================================================
# Style mixing utilities
# ==============================================================
def parse_mix_spec(spec: str) -> dict:
    """Parse 'Kore:0.75,Puck:0.25' into {'Kore': 0.75, 'Puck': 0.25}.

    Entries without a colon, with an empty name, with a non-numeric weight, or
    with a non-finite or non-positive weight are dropped. Weights of repeated
    names are summed. A falsy or non-string ``spec`` yields an empty dict.
    """
    weights = {}
    if not (spec and isinstance(spec, str)):
        return weights
    for chunk in spec.split(","):
        name, colon, raw = chunk.partition(":")
        name = name.strip()
        if not colon or not name:
            continue
        try:
            value = float(raw.strip())
        except ValueError:
            continue
        if np.isfinite(value) and value > 0:
            weights[name] = weights.get(name, 0.0) + value
    return weights
def get_style_mixed_by_length(mix_dict: dict, phoneme_len: int):
    """Blend the style vectors of several speakers by weight.

    Args:
        mix_dict: mapping of speaker name to weight.
        phoneme_len: phoneme count used to pick each speaker's style row.

    Returns:
        The weighted sum of the per-speaker vectors returned by
        ``get_style_by_length``. Weights are normalised over the speakers that
        are actually used, so the blend always sums to 1. Falls back to
        DEFAULT_SPK when no usable speaker remains.
    """
    if not mix_dict:
        return get_style_by_length(DEFAULT_SPK, phoneme_len)

    # Keep only known speakers with a positive weight. Normalising over this
    # set (not the raw input) stops unknown names from shrinking the result,
    # and applies the same clamping to the total and to each share.
    usable = {}
    for spk, w in mix_dict.items():
        if spk not in styles_dict:
            print(f"[WARN] Speaker '{spk}' không có trong styles_dict, bỏ qua.")
            continue
        weight = max(0.0, float(w))
        if weight > 0:
            usable[spk] = weight

    total = sum(usable.values())
    if total <= 0:
        return get_style_by_length(DEFAULT_SPK, phoneme_len)

    mix_feat = None
    for spk, weight in usable.items():
        contribution = get_style_by_length(spk, phoneme_len) * (weight / total)
        mix_feat = contribution if mix_feat is None else mix_feat + contribution
    return mix_feat
# ==============================================================
# Audio postprocess (librosa): trim + denoise + remove internal silence
# ==============================================================
| def _simple_spectral_denoise(y, sr, n_fft=1024, hop=256, prop_decrease=0.8): | |
| if y.size == 0: | |
| return y | |
| D = librosa.stft(y, n_fft=n_fft, hop_length=hop, win_length=n_fft) | |
| S = np.abs(D) | |
| noise = np.median(S, axis=1, keepdims=True) | |
| S_clean = S - prop_decrease * noise | |
| S_clean = np.maximum(S_clean, 0.0) | |
| gain = S_clean / (S + 1e-8) | |
| D_denoised = D * gain | |
| y_out = librosa.istft(D_denoised, hop_length=hop, win_length=n_fft, length=len(y)) | |
| return y_out | |
| def _concat_with_crossfade(segments, crossfade_samples=0): | |
| if not segments: | |
| return np.array([], dtype=np.float32) | |
| out = segments[0].astype(np.float32, copy=True) | |
| for seg in segments[1:]: | |
| seg = seg.astype(np.float32, copy=False) | |
| if crossfade_samples > 0 and out.size > 0 and seg.size > 0: | |
| cf = min(crossfade_samples, out.size, seg.size) | |
| fade_out = np.linspace(1.0, 0.0, cf, dtype=np.float32) | |
| fade_in = 1.0 - fade_out | |
| tail = out[-cf:] * fade_out + seg[:cf] * fade_in | |
| out = np.concatenate([out[:-cf], tail, seg[cf:]], axis=0) | |
| else: | |
| out = np.concatenate([out, seg], axis=0) | |
| return out | |
| def _reduce_internal_silence(y, sr, top_db=30, min_keep_ms=40, crossfade_ms=8): | |
| if y.size == 0: | |
| return y | |
| intervals = librosa.effects.split(y, top_db=top_db) | |
| if intervals.size == 0: | |
| return y | |
| min_keep = int(sr * (min_keep_ms / 1000.0)) | |
| segs = [] | |
| for s, e in intervals: | |
| if e - s >= min_keep: | |
| segs.append(y[s:e]) | |
| if not segs: | |
| return y | |
| crossfade = int(sr * (crossfade_ms / 1000.0)) | |
| y_out = _concat_with_crossfade(segs, crossfade_samples=crossfade) | |
| return y_out | |
def postprocess_audio(y, sr,
                      trim_top_db=30,
                      denoise=True,
                      denoise_n_fft=1024,
                      denoise_hop=256,
                      denoise_strength=0.8,
                      remove_internal_silence=True,
                      split_top_db=30,
                      min_keep_ms=40,
                      crossfade_ms=8):
    """Clean up generated audio and return it as float32.

    Steps: replace non-finite samples with 0, trim leading/trailing silence,
    optionally apply spectral denoising, optionally remove internal silence,
    and scale down only if the peak exceeds 1.0.

    Args:
        y: audio samples as a NumPy array.
        sr: sample rate in Hz.
        trim_top_db: threshold passed to ``librosa.effects.trim``.
        denoise: whether to run ``_simple_spectral_denoise``.
        denoise_n_fft, denoise_hop, denoise_strength: denoiser settings.
        remove_internal_silence: whether to run ``_reduce_internal_silence``.
        split_top_db, min_keep_ms, crossfade_ms: silence-removal settings.

    Returns:
        The processed samples as a float32 array (possibly empty).
    """
    if y.size == 0:
        return y.astype(np.float32)

    # Sanitise before any librosa call: its STFT validates the buffer and
    # rejects NaN/Inf, so cleaning up only at the end would be too late.
    y = np.nan_to_num(y, nan=0.0, posinf=0.0, neginf=0.0)

    y_trim, _ = librosa.effects.trim(y, top_db=trim_top_db)
    if denoise:
        y_trim = _simple_spectral_denoise(
            y_trim, sr, n_fft=denoise_n_fft, hop=denoise_hop, prop_decrease=denoise_strength
        )
    if remove_internal_silence:
        y_trim = _reduce_internal_silence(
            y_trim, sr, top_db=split_top_db, min_keep_ms=min_keep_ms, crossfade_ms=crossfade_ms
        )

    y_trim = np.nan_to_num(y_trim, nan=0.0, posinf=0.0, neginf=0.0).astype(np.float32)
    # np.max has no identity for an empty array; nothing to normalise anyway.
    if y_trim.size == 0:
        return y_trim
    peak = np.max(np.abs(y_trim)) + 1e-8
    if peak > 1.0:
        y_trim = y_trim / peak
    return y_trim
# ==============================================================
# Inference core
# ==============================================================
def inference_one(text, ref_feat, alpha=ALPHA, beta=BETA,
                  diffusion_steps=DIFFUSION_STEPS, embedding_scale=EMBEDDING_SCALE):
    """Synthesize one utterance from ``text`` using the style vector ``ref_feat``.

    Args:
        text: input text; it is phonemized here.
        ref_feat: reference style tensor, sliced as ``[:, :128]`` and
            ``[:, 128:]`` (the original comments give its shape as [1, 256]).
        alpha: blend factor between the sampled and reference first 128 dims.
        beta: blend factor between the sampled and reference remaining dims.
        diffusion_steps: number of sampler steps when sampling is used.
        embedding_scale: scale passed to the diffusion sampler.

    Returns:
        ``(wav, ps, simi_timbre, simi_prosody)``: the post-processed waveform,
        the phoneme string, and the cosine similarities between the predicted
        and reference halves of the style vector.
    """
    ps = phonemize_text(text)
    tokens = textcleaner(ps)
    tokens.insert(0, 0)  # prepend token id 0 ahead of the phoneme ids
    tokens = torch.LongTensor(tokens).unsqueeze(0).to(DEVICE)
    input_lengths = torch.LongTensor([tokens.shape[-1]]).to(DEVICE)
    text_mask = length_to_mask(input_lengths).to(DEVICE)
    n_tokens = int(input_lengths.item())

    with torch.no_grad():
        t_en = model.text_encoder(tokens, input_lengths, text_mask)
        bert_d = model.bert(tokens, attention_mask=(~text_mask).int())
        d_en = model.bert_encoder(bert_d).transpose(-1, -2)

        if alpha == 0 and beta == 0:
            # Nothing from the sampler would be kept, so skip diffusion.
            s_pred = ref_feat.clone()
        else:
            s_pred = sampler(
                noise=torch.randn((1, 256)).unsqueeze(1).to(DEVICE),
                embedding=bert_d,
                embedding_scale=embedding_scale,
                features=ref_feat,
                num_steps=diffusion_steps,
            ).squeeze(1)

        s, ref = s_pred[:, 128:], s_pred[:, :128]
        ref = alpha * ref + (1 - alpha) * ref_feat[:, :128]
        s = beta * s + (1 - beta) * ref_feat[:, 128:]

        # --- Metrics (cosine) ---
        cosine = torch.nn.functional.cosine_similarity
        simi_timbre = cosine(s_pred[:, :128], ref_feat[:, :128], dim=1).mean().item()
        simi_prosody = cosine(s_pred[:, 128:], ref_feat[:, 128:], dim=1).mean().item()

        # --- Duration / Alignment ---
        d = model.predictor.text_encoder(d_en, s, input_lengths, text_mask)
        x, _ = model.predictor.lstm(d)
        duration = torch.sigmoid(model.predictor.duration_proj(x)).sum(axis=-1)
        # reshape(-1) keeps one entry per token even for a single token; a bare
        # squeeze() would give a 0-dim tensor there and break the indexing below.
        pred_dur = torch.round(duration.reshape(-1)).clamp(min=1)
        # One host transfer for all durations instead of one sync per token.
        spans = [int(v) for v in pred_dur.tolist()]
        total_frames = sum(spans)

        # Alignment matrix: row i is 1 over the frames assigned to token i.
        pred_aln = torch.zeros(n_tokens, total_frames, device=DEVICE)
        cursor = 0
        for i, span in enumerate(spans):
            pred_aln[i, cursor:cursor + span] = 1.0
            cursor += span

        en = (d.transpose(-1, -2) @ pred_aln.unsqueeze(0))
        if model_params.decoder.type == "hifigan":
            # Shift by one frame, repeating the first frame.
            en = torch.cat([en[:, :, :1], en[:, :, :-1]], dim=2)
        F0_pred, N_pred = model.predictor.F0Ntrain(en, s)

        asr = (t_en @ pred_aln.unsqueeze(0))
        if model_params.decoder.type == "hifigan":
            asr = torch.cat([asr[:, :, :1], asr[:, :, :-1]], dim=2)

        out = model.decoder(asr, F0_pred, N_pred, ref.squeeze().unsqueeze(0))

    wav = out.squeeze().detach().cpu().numpy()
    # Drop the last 50 samples when the clip is long enough.
    if wav.shape[-1] > 50:
        wav = wav[:-50]

    # Post-processing: trim + denoise (internal-silence removal is disabled here).
    wav = postprocess_audio(
        wav, SR_OUT,
        trim_top_db=30,
        denoise=True,
        denoise_n_fft=1024, denoise_hop=256, denoise_strength=0.8,
        remove_internal_silence=False,
        split_top_db=30, min_keep_ms=40, crossfade_ms=8
    )
    return wav, ps, simi_timbre, simi_prosody
# ==============================================================
# Ref-audio mapping (scan ./ref_voice for a sample file per speaker)
# ==============================================================
| def _norm(s: str) -> str: | |
| import unicodedata | |
| s = unicodedata.normalize("NFKD", s) | |
| s = "".join([c for c in s if not unicodedata.combining(c)]) | |
| s = s.lower() | |
| s = re.sub(r"[^a-z0-9_\-\.]+", "", s) | |
| return s | |
def build_ref_map(ref_dir: str) -> dict:
    """Map each speaker to a reference .wav found under ``ref_dir``.

    All .wav files are collected recursively and keyed by their normalised
    file name. For every speaker (the keys of ``styles_dict``, or a built-in
    list of four names when it is not a dict) the first file whose name
    contains ``_<speaker>_`` is chosen; failing that, the first file whose
    name contains the speaker name at all. Speakers without a match are
    omitted from the result.
    """
    pattern = os.path.join(ref_dir, "**", "*.wav")
    normalized = {
        _norm(os.path.basename(path)): path
        for path in glob.glob(pattern, recursive=True)
    }

    if isinstance(styles_dict, dict):
        speakers = list(styles_dict.keys())
    else:
        speakers = ["Kore","Algenib","Puck","Leda"]

    result = {}
    for spk in speakers:
        token = _norm(spk)
        # Prefer a delimited match, then fall back to a plain substring match.
        match = next((p for name, p in normalized.items() if f"_{token}_" in name), None)
        if not match:
            match = next((p for name, p in normalized.items() if token in name), None)
        if match:
            result[spk] = match
    return result
# Speaker name -> reference wav path, built once when the module is loaded.
REF_MAP = build_ref_map(REF_DIR)


def get_ref_path_for_speaker(spk: str):
    """Return the reference wav path stored for ``spk``, or None if there is none."""
    if spk in REF_MAP:
        return REF_MAP[spk]
    return None
# ==============================================================
# Wrapper for Gradio (receives speaker_mix_spec as a hidden string)
# ==============================================================
# Shared text normalizer, created on first use and reused across requests.
_NORMALIZER = None


def _get_normalizer():
    """Return the shared SoeNormalizer, constructing it on first use."""
    global _NORMALIZER
    if _NORMALIZER is None:
        _NORMALIZER = SoeNormalizer()
    return _NORMALIZER


def run_inference(text, alpha, beta, speaker_mix_spec):
    """Gradio callback: normalise text, pick a style, synthesize, report metrics.

    Args:
        text: raw user text.
        alpha: blend factor forwarded to ``inference_one``; None falls back to ALPHA.
        beta: blend factor forwarded to ``inference_one``; None falls back to BETA.
        speaker_mix_spec: 'Name:weight,...' string; empty or invalid selects
            DEFAULT_SPK.

    Returns:
        ``((SR_OUT, wav), info)`` where ``wav`` is float32 audio and ``info`` is
        a dict of details about the generation.

    Raises:
        gr.Error: when ``text`` is empty or whitespace only.
    """
    if not text or not str(text).strip():
        raise gr.Error("Please enter some text to synthesize.")
    # A cleared number box arrives as None; use the module defaults instead.
    alpha = ALPHA if alpha is None else float(alpha)
    beta = BETA if beta is None else float(beta)

    text = _get_normalizer().normalize(text).replace(" ,", ",").replace(" .", ".")
    ps = phonemize_text(text)
    phoneme_len = len(ps.replace(" ", ""))
    # Reported for information only; 510 is the cap hard-coded by the original code.
    ref_idx = min(phoneme_len, 510)

    mix_dict = parse_mix_spec(speaker_mix_spec)
    if mix_dict:
        ref_feat = get_style_mixed_by_length(mix_dict, phoneme_len)
        total = sum(mix_dict.values())
        mix_info = {k: round(float(v / total), 3) for k, v in mix_dict.items()}
        chosen_speakers = list(mix_dict.keys())
    else:
        ref_feat = get_style_by_length(DEFAULT_SPK, phoneme_len)
        mix_info = {DEFAULT_SPK: 1.0}
        chosen_speakers = [DEFAULT_SPK]

    t0 = time.time()
    wav, ps_out, simi_timbre, simi_prosody = inference_one(
        text, ref_feat, alpha=alpha, beta=beta
    )
    gen_time = time.time() - t0
    # Real-time factor: generation time divided by audio duration.
    rtf = gen_time / max(1e-6, len(wav) / SR_OUT)

    info = {
        "Text after soe_vinorms:": text,
        "Speakers": chosen_speakers,
        "Mix weights (normalized)": mix_info,
        "Phonemes": ps_out,
        "Phoneme length": phoneme_len,
        "Ref index": ref_idx,
        "simi_timbre": round(float(simi_timbre), 4),
        "simi_prosody": round(float(simi_prosody), 4),
        "alpha": alpha,
        "beta": beta,
        "RTF": round(float(rtf), 3),
        "Device": DEVICE,
    }
    return (SR_OUT, wav.astype(np.float32)), info
# ==============================================================
# UI helper: build the mix-spec for the 4 FIXED speaker slots
# ==============================================================
| def _build_mix_spec_ui_fixed(normalize, w1, w2, w3, w4, order): | |
| pairs = [(order[0], float(w1 or 0.0)), | |
| (order[1], float(w2 or 0.0)), | |
| (order[2], float(w3 or 0.0)), | |
| (order[3], float(w4 or 0.0))] | |
| pairs = [(s, w) for s, w in pairs if w > 0] | |
| if not pairs: | |
| return "", {}, "**Sum:** 0.000" | |
| total = sum(w for _, w in pairs) | |
| if normalize and total > 0: | |
| pairs = [(s, w/total) for s, w in pairs] | |
| acc = {} | |
| for s, w in pairs: | |
| acc[s] = acc.get(s, 0.0) + w | |
| mix_spec = ",".join([f"{s}:{w:.4f}" for s, w in acc.items()]) | |
| mix_view = {"weights": {s: round(w, 3) for s, w in acc.items()}, "normalized": bool(normalize)} | |
| sum_md = f"**Sum:** {round(sum(acc.values()), 3)}" | |
| return mix_spec, mix_view, sum_md | |
# ==============================================================
# Gradio UI - Compact & Clean Version
# ==============================================================
# NOTE(review): the nesting of the layout containers below was reconstructed
# from a flattened copy of this file — confirm it against the intended layout.
with gr.Blocks(title="StyleTTS2-vi Demo", theme=gr.themes.Soft()) as demo:
    gr.Markdown("# 🎙️ StyleTTS2-vi Demo")

    with gr.Row():
        with gr.Column(scale=1):
            text_inp = gr.Textbox(
                label="📝 Text Input",
                lines=3,
                placeholder="Nhập văn bản cần đọc...",
                value="Trăng treo lơ lửng trên đỉnh núi chơ vơ, ánh sáng bàng bạc phủ lên bãi đá ngổn ngang. Con dế thổn thức trong khe cỏ, tiếng gió hun hút lùa qua hốc núi trập trùng. Dưới thung lũng, đàn trâu gặm cỏ ung dung, hơi sương vẩn đục, lảng bảng giữa đồng khuya tĩnh mịch."
            )

            # Speakers available in styles_dict.
            spk_choices = list(styles_dict.keys()) if isinstance(styles_dict, dict) else ["Kore","Algenib","Puck","Leda"]

            # FIXED slot order for the mixer: preferred names first, then any
            # other available speakers until four slots are filled.
            fixed_order = [s for s in ["Kore", "Puck", "Algenib", "Leda"] if s in spk_choices]
            for s in spk_choices:
                if len(fixed_order) >= 4:
                    break
                if s not in fixed_order:
                    fixed_order.append(s)

            # The mixer always wires four sliders. With fewer than four
            # speakers the remaining slots get an empty name and a hidden
            # slider that stays at 0, so they never contribute to the mix.
            n_active = len(fixed_order)
            slot_names = fixed_order + [""] * (4 - n_active)

            # === Reference samples - Compact grid ===
            with gr.Accordion("🎵 Reference Samples", open=True):
                gr.Markdown("*Click to preview voice samples*")
                for i in range(0, 4, 2):
                    with gr.Row():
                        for j in range(2):
                            idx = i + j
                            if idx < n_active:
                                spk = fixed_order[idx]
                                with gr.Column(min_width=200):
                                    gr.Audio(
                                        value=get_ref_path_for_speaker(spk),
                                        label=spk,
                                        type="filepath",
                                        interactive=False,
                                        show_download_button=False
                                    )

            # ---- Style Mixer - More compact ----
            with gr.Accordion("🎨 Style Mixer", open=True):
                normalize_ck = gr.Checkbox(value=True, label="Auto-normalize", container=False)

                # 2x2 grid for the 4 sliders.
                sliders = []
                for row_start in (0, 2):
                    with gr.Row():
                        for slot in range(row_start, row_start + 2):
                            active = slot < n_active
                            with gr.Column(scale=1):
                                gr.Markdown(f"**{slot_names[slot]}**", visible=active)
                                sliders.append(
                                    gr.Slider(
                                        0.0, 1.0, value=0.0, step=0.05,
                                        show_label=False, container=False, visible=active
                                    )
                                )
                w1, w2, w3, w4 = sliders

                with gr.Row():
                    mix_sum_md = gr.Markdown("**Sum:** 0.000")

                mix_view_json = gr.JSON(label="Current Mix", visible=False)
                mix_spec_state = gr.State("")
                order_state = gr.State(slot_names)

            # Advanced settings - Collapsed by default
            with gr.Accordion("⚙️ Advanced Settings", open=False):
                with gr.Row():
                    alpha_n = gr.Number(value=ALPHA, label="Alpha (timbre)", precision=3, minimum=0, maximum=1)
                    beta_n = gr.Number(value=BETA, label="Beta (prosody)", precision=3, minimum=0, maximum=1)

            btn = gr.Button("🔊 Generate Speech", variant="primary", size="lg")

        with gr.Column(scale=1):
            out_audio = gr.Audio(label="🎧 Output Audio", type="numpy")
            with gr.Accordion("📊 Generation Metrics", open=False):
                metrics = gr.JSON(label="Details")

    # Event handlers: any change to the checkbox or a slider rebuilds the
    # hidden mix-spec string, the JSON view and the sum label.
    for comp in [normalize_ck, w1, w2, w3, w4]:
        comp.change(
            _build_mix_spec_ui_fixed,
            inputs=[normalize_ck, w1, w2, w3, w4, order_state],
            outputs=[mix_spec_state, mix_view_json, mix_sum_md]
        )

    btn.click(
        run_inference,
        inputs=[text_inp, alpha_n, beta_n, mix_spec_state],
        outputs=[out_audio, metrics]
    )

if __name__ == "__main__":
    demo.launch()