# -*- coding: utf-8 -*-
"""
Gradio app.py - StyleTTS2-vi with precomputed style embeddings (.pth)
- UI gọn gàng với accordion thu gọn
- Style Mixer: 4 slot cố định (Kore, Puck, Algenib, Leda), chỉ chỉnh weight; auto-normalize
- Reference samples trong accordion
"""
import os
import re
import glob
import time

import yaml
import torch
import librosa
import numpy as np
import gradio as gr
from munch import Munch
from soe_vinorm import SoeNormalizer
# ==============================================================
# Basic configuration
# ==============================================================
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
SR_OUT = 24000
ALPHA, BETA, DIFFUSION_STEPS, EMBEDDING_SCALE = 0.0, 0.0, 5, 1.0
REF_DIR = "ref_voice"  # directory holding reference audio samples (.wav)
# ==============================================================
# Import StyleTTS2 modules
# ==============================================================
from models import *
from utils import *
from models import build_model
from text_utils import TextCleaner
from Utils_extend_v1.PLBERT.util import load_plbert
from Modules.diffusion.sampler import DiffusionSampler, ADPM2Sampler, KarrasSchedule
textcleaner = TextCleaner()
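# Vietnamese text normalizer from soe_vinorm, created once at import time so each
# request does not pay re-initialization cost (assumes SoeNormalizer is reusable
# across calls).
text_normalizer = SoeNormalizer()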
# ==============================================================
# Load model and checkpoint
# ==============================================================
from huggingface_hub import hf_hub_download
hf_hub_download(
repo_id="ltphuongunited/styletts2_vi",
filename="gemini_2nd_00045.pth",
local_dir="Models/gemini_vi",
local_dir_use_symlinks=False,
)
CHECKPOINT_PTH = "Models/gemini_vi/gemini_2nd_00045.pth"
CONFIG_PATH = "Models/gemini_vi/config_gemini_vi_en.yml"
with open(CONFIG_PATH) as f:
    config = yaml.safe_load(f)
ASR_config = config.get("ASR_config", False)
ASR_path = config.get("ASR_path", False)
F0_path = config.get("F0_path", False)
PLBERT_dir = config.get("PLBERT_dir", False)
text_aligner = load_ASR_models(ASR_path, ASR_config)
pitch_extractor = load_F0_models(F0_path)
plbert = load_plbert(PLBERT_dir)
model_params = recursive_munch(config["model_params"])
model = build_model(model_params, text_aligner, pitch_extractor, plbert)
_ = [model[k].to(DEVICE) for k in model]
_ = [model[k].eval() for k in model]
ckpt = torch.load(CHECKPOINT_PTH, map_location="cpu")["net"]
for key in model:
    if key in ckpt:
        try:
            model[key].load_state_dict(ckpt[key])
        except Exception:
            # checkpoint was saved from a DataParallel-wrapped module: strip the
            # "module." prefix (7 characters) from every key and retry leniently
            from collections import OrderedDict
            new_state = OrderedDict()
            for k, v in ckpt[key].items():
                new_state[k[7:]] = v
            model[key].load_state_dict(new_state, strict=False)
sampler = DiffusionSampler(
model.diffusion.diffusion,
sampler=ADPM2Sampler(),
sigma_schedule=KarrasSchedule(sigma_min=1e-4, sigma_max=3.0, rho=9.0),
clamp=False,
)
# ==============================================================
# Phonemizer
# ==============================================================
import phonemizer
vi_phonemizer = phonemizer.backend.EspeakBackend(
language="vi", preserve_punctuation=True, with_stress=True
)
def phonemize_text(text: str) -> str:
ps = vi_phonemizer.phonemize([text])[0]
return ps.replace("(en)", "").replace("(vi)", "").strip()
def length_to_mask(lengths: torch.LongTensor) -> torch.Tensor:
mask = torch.arange(lengths.max()).unsqueeze(0).expand(lengths.shape[0], -1).type_as(lengths)
mask = torch.gt(mask + 1, lengths.unsqueeze(1))
return mask
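# Worked example: for lengths = [3, 5], padded positions are marked True:
#   length_to_mask(torch.LongTensor([3, 5]))
#   -> [[False, False, False,  True,  True],
#       [False, False, False, False, False]]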
# ==============================================================
# Load precomputed style embeddings
# ==============================================================
STYLE_PTH = "Models/styles_speaker_parallel.pth"
print(f"Loading precomputed styles: {STYLE_PTH}")
styles_dict = torch.load(STYLE_PTH, map_location=DEVICE)
# fallback speaker when the mixer is empty
SPEAKER_ORDER_PREF = ["Kore", "Puck", "Algenib", "Leda"]
DEFAULT_SPK = next((s for s in SPEAKER_ORDER_PREF if s in styles_dict), list(styles_dict.keys())[0])
def get_style_by_length(speaker: str, phoneme_len: int):
    spk_tensor = styles_dict[speaker]  # [510, 1, 256] or [510, 256]
    idx = min(max(phoneme_len, 1), spk_tensor.shape[0]) - 1
    feat = spk_tensor[idx]
    # coerce to [1, 256]
    if feat.ndim == 3:  # [1, 1, 256]
        feat = feat.squeeze(0)
    if feat.ndim == 2:  # [1, 256]
        feat = feat.squeeze(0)
    return feat.unsqueeze(0).to(DEVICE)  # [1, 256]
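# The style table has one row per phoneme length (510 rows in these checkpoints),
# clamped at both ends: e.g. phoneme_len=0 -> row 0, phoneme_len=9999 -> last row.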
# ==============================================================
# Style mixing utils
# ==============================================================
def parse_mix_spec(spec: str) -> dict:
    """Parse 'Kore:0.75,Puck:0.25' -> {'Kore': 0.75, 'Puck': 0.25} (drops malformed parts, merges duplicates)."""
mix = {}
if not spec or not isinstance(spec, str):
return mix
for part in spec.split(","):
if ":" not in part:
continue
k, v = part.split(":", 1)
k = (k or "").strip()
if not k:
continue
try:
w = float((v or "").strip())
except Exception:
continue
if not np.isfinite(w) or w <= 0:
continue
mix[k] = mix.get(k, 0.0) + w
return mix
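# Example (duplicate keys accumulate, malformed parts are dropped):
#   parse_mix_spec("Kore:0.6, Puck:0.2, Puck:0.2, bogus")
#   -> {"Kore": 0.6, "Puck": 0.4}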
def get_style_mixed_by_length(mix_dict: dict, phoneme_len: int):
    """Blend several speakers' styles by weight. Returns [1, 256] on DEVICE."""
if not mix_dict:
return get_style_by_length(DEFAULT_SPK, phoneme_len)
total = sum(max(0.0, float(w)) for w in mix_dict.values())
if total <= 0:
return get_style_by_length(DEFAULT_SPK, phoneme_len)
mix_feat = None
for spk, w in mix_dict.items():
if spk not in styles_dict:
print(f"[WARN] Speaker '{spk}' không có trong styles_dict, bỏ qua.")
continue
feat_i = get_style_by_length(spk, phoneme_len) # [1,256]
wi = float(w) / total
mix_feat = feat_i * wi if mix_feat is None else mix_feat + feat_i * wi
if mix_feat is None:
return get_style_by_length(DEFAULT_SPK, phoneme_len)
return mix_feat # [1,256]
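# Mixing is a convex combination of the per-speaker style vectors:
#   s_mix = sum_i (w_i / W) * s_i,  W = sum_i w_i
# so weights only set relative influence, not overall gain.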
# ==============================================================
# Audio postprocess (librosa): trim + denoise + remove internal silence
# ==============================================================
def _simple_spectral_denoise(y, sr, n_fft=1024, hop=256, prop_decrease=0.8):
    """Light spectral subtraction: treat the per-bin median magnitude as the
    noise floor and partially subtract it from the whole spectrogram."""
    if y.size == 0:
        return y
D = librosa.stft(y, n_fft=n_fft, hop_length=hop, win_length=n_fft)
S = np.abs(D)
noise = np.median(S, axis=1, keepdims=True)
S_clean = S - prop_decrease * noise
S_clean = np.maximum(S_clean, 0.0)
gain = S_clean / (S + 1e-8)
D_denoised = D * gain
y_out = librosa.istft(D_denoised, hop_length=hop, win_length=n_fft, length=len(y))
return y_out
def _concat_with_crossfade(segments, crossfade_samples=0):
if not segments:
return np.array([], dtype=np.float32)
out = segments[0].astype(np.float32, copy=True)
for seg in segments[1:]:
seg = seg.astype(np.float32, copy=False)
if crossfade_samples > 0 and out.size > 0 and seg.size > 0:
cf = min(crossfade_samples, out.size, seg.size)
fade_out = np.linspace(1.0, 0.0, cf, dtype=np.float32)
fade_in = 1.0 - fade_out
tail = out[-cf:] * fade_out + seg[:cf] * fade_in
out = np.concatenate([out[:-cf], tail, seg[cf:]], axis=0)
else:
out = np.concatenate([out, seg], axis=0)
return out
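# The linear fades satisfy fade_out + fade_in = 1 at every sample, so the join
# keeps constant gain across the overlap (a simple equal-gain crossfade).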
def _reduce_internal_silence(y, sr, top_db=30, min_keep_ms=40, crossfade_ms=8):
    """Remove internal silences: keep voiced intervals at least min_keep_ms long
    and rejoin them with a short crossfade."""
    if y.size == 0:
        return y
intervals = librosa.effects.split(y, top_db=top_db)
if intervals.size == 0:
return y
min_keep = int(sr * (min_keep_ms / 1000.0))
segs = []
for s, e in intervals:
if e - s >= min_keep:
segs.append(y[s:e])
if not segs:
return y
crossfade = int(sr * (crossfade_ms / 1000.0))
y_out = _concat_with_crossfade(segs, crossfade_samples=crossfade)
return y_out
def postprocess_audio(y, sr,
trim_top_db=30,
denoise=True,
denoise_n_fft=1024,
denoise_hop=256,
denoise_strength=0.8,
remove_internal_silence=True,
split_top_db=30,
min_keep_ms=40,
crossfade_ms=8):
if y.size == 0:
return y.astype(np.float32)
y_trim, _ = librosa.effects.trim(y, top_db=trim_top_db)
if denoise:
y_trim = _simple_spectral_denoise(
y_trim, sr, n_fft=denoise_n_fft, hop=denoise_hop, prop_decrease=denoise_strength
)
if remove_internal_silence:
y_trim = _reduce_internal_silence(
y_trim, sr, top_db=split_top_db, min_keep_ms=min_keep_ms, crossfade_ms=crossfade_ms
)
y_trim = np.nan_to_num(y_trim, nan=0.0, posinf=0.0, neginf=0.0).astype(np.float32)
m = np.max(np.abs(y_trim)) + 1e-8
if m > 1.0:
y_trim = y_trim / m
return y_trim
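# Standalone usage sketch (file name is hypothetical):
#   y, sr = librosa.load("sample.wav", sr=SR_OUT)
#   y = postprocess_audio(y, sr, remove_internal_silence=True)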
# ==============================================================
# Inference core
# ==============================================================
def inference_one(text, ref_feat, alpha=ALPHA, beta=BETA,
diffusion_steps=DIFFUSION_STEPS, embedding_scale=EMBEDDING_SCALE):
    ps = phonemize_text(text)
    tokens = textcleaner(ps)
    tokens.insert(0, 0)  # prepend the blank token (index 0), as in reference StyleTTS2 inference
    tokens = torch.LongTensor(tokens).unsqueeze(0).to(DEVICE)
input_lengths = torch.LongTensor([tokens.shape[-1]]).to(DEVICE)
text_mask = length_to_mask(input_lengths).to(DEVICE)
with torch.no_grad():
t_en = model.text_encoder(tokens, input_lengths, text_mask)
bert_d = model.bert(tokens, attention_mask=(~text_mask).int())
d_en = model.bert_encoder(bert_d).transpose(-1, -2)
        if alpha == 0 and beta == 0:
            # pure reference style: the blend below returns ref_feat unchanged,
            # so skip the diffusion sampler entirely
            s_pred = ref_feat.clone()  # [1, 256]
        else:
            s_pred = sampler(
                noise=torch.randn((1, 256)).unsqueeze(1).to(DEVICE),
                embedding=bert_d,
                embedding_scale=embedding_scale,
                features=ref_feat,  # [1, 256]
                num_steps=diffusion_steps,
            ).squeeze(1)  # [1, 256]
        # first 128 dims carry timbre, last 128 carry prosody; alpha/beta
        # interpolate between the sampled style and the precomputed reference
        s, ref = s_pred[:, 128:], s_pred[:, :128]
        ref = alpha * ref + (1 - alpha) * ref_feat[:, :128]
        s = beta * s + (1 - beta) * ref_feat[:, 128:]
# --- Metrics (cosine) ---
def cosine_sim(a, b):
return torch.nn.functional.cosine_similarity(a, b, dim=1).mean().item()
simi_timbre = cosine_sim(s_pred[:, :128], ref_feat[:, :128])
simi_prosody = cosine_sim(s_pred[:, 128:], ref_feat[:, 128:])
# --- Duration / Alignment ---
d = model.predictor.text_encoder(d_en, s, input_lengths, text_mask)
x, _ = model.predictor.lstm(d)
duration = torch.sigmoid(model.predictor.duration_proj(x)).sum(axis=-1)
pred_dur = torch.round(duration.squeeze()).clamp(min=1)
T = int(pred_dur.sum().item())
pred_aln = torch.zeros(input_lengths.item(), T, device=DEVICE)
c = 0
for i in range(input_lengths.item()):
span = int(pred_dur[i].item())
pred_aln[i, c:c+span] = 1.0
c += span
en = (d.transpose(-1, -2) @ pred_aln.unsqueeze(0))
if model_params.decoder.type == "hifigan":
en = torch.cat([en[:, :, :1], en[:, :, :-1]], dim=2)
F0_pred, N_pred = model.predictor.F0Ntrain(en, s)
asr = (t_en @ pred_aln.unsqueeze(0))
if model_params.decoder.type == "hifigan":
asr = torch.cat([asr[:, :, :1], asr[:, :, :-1]], dim=2)
out = model.decoder(asr, F0_pred, N_pred, ref.squeeze().unsqueeze(0))
wav = out.squeeze().detach().cpu().numpy()
if wav.shape[-1] > 50:
wav = wav[:-50]
    # post-process: trim edges + light denoise (internal-silence removal disabled below)
wav = postprocess_audio(
wav, SR_OUT,
trim_top_db=30,
denoise=True,
denoise_n_fft=1024, denoise_hop=256, denoise_strength=0.8,
remove_internal_silence=False,
split_top_db=30, min_keep_ms=40, crossfade_ms=8
)
return wav, ps, simi_timbre, simi_prosody
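# Pipeline recap: text -> phonemes -> (text encoding + PL-BERT features)
# -> style vector (copied, or diffusion-sampled when alpha/beta > 0)
# -> durations/alignment -> F0 + energy -> decoder -> post-processed waveform.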
# ==============================================================
# Ref-audio mapping (scan ./ref_voice for each speaker's sample file)
# ==============================================================
def _norm(s: str) -> str:
import unicodedata
s = unicodedata.normalize("NFKD", s)
s = "".join([c for c in s if not unicodedata.combining(c)])
s = s.lower()
s = re.sub(r"[^a-z0-9_\-\.]+", "", s)
return s
def build_ref_map(ref_dir: str) -> dict:
paths = glob.glob(os.path.join(ref_dir, "**", "*.wav"), recursive=True)
by_name = {}
for p in paths:
fname = os.path.basename(p)
by_name[_norm(fname)] = p
spk_map = {}
speakers = list(styles_dict.keys()) if isinstance(styles_dict, dict) else ["Kore","Algenib","Puck","Leda"]
for spk in speakers:
spk_n = _norm(spk)
hit = None
for k, p in by_name.items():
if f"_{spk_n}_" in k:
hit = p
break
if not hit:
for k, p in by_name.items():
if spk_n in k:
hit = p
break
if hit:
spk_map[spk] = hit
return spk_map
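# Matching example (file name is hypothetical): "vi_Kore_01.wav" normalizes to
# "vi_kore_01.wav" and matches speaker "Kore" via the exact "_kore_" token;
# the plain-substring pass is only a fallback.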
REF_MAP = build_ref_map(REF_DIR)
def get_ref_path_for_speaker(spk: str):
return REF_MAP.get(spk)
# ==============================================================
# Gradio wrapper (receives speaker_mix_spec as a hidden string)
# ==============================================================
def run_inference(text, alpha, beta, speaker_mix_spec):
    # normalize Vietnamese text (numbers, dates, abbreviations) before phonemizing
    text = text_normalizer.normalize(text).replace(" ,", ",").replace(" .", ".")
ps = phonemize_text(text)
phoneme_len = len(ps.replace(" ", ""))
mix_dict = parse_mix_spec(speaker_mix_spec)
    if len(mix_dict) > 0:
        ref_feat = get_style_mixed_by_length(mix_dict, phoneme_len)
        total = sum(mix_dict.values())
        mix_info = {k: round(float(v / total), 3) for k, v in mix_dict.items()}
        chosen_speakers = list(mix_dict.keys())
    else:
        ref_feat = get_style_by_length(DEFAULT_SPK, phoneme_len)
        mix_info = {DEFAULT_SPK: 1.0}
        chosen_speakers = [DEFAULT_SPK]
    # row actually used in the style table (mirrors get_style_by_length's clamping)
    ref_idx = min(max(phoneme_len, 1), 510) - 1
t0 = time.time()
wav, ps_out, simi_timbre, simi_prosody = inference_one(
text, ref_feat, alpha=float(alpha), beta=float(beta)
)
gen_time = time.time() - t0
rtf = gen_time / max(1e-6, len(wav) / SR_OUT)
info = {
"Text after soe_vinorms:": text,
"Speakers": chosen_speakers,
"Mix weights (normalized)": mix_info,
"Phonemes": ps_out,
"Phoneme length": phoneme_len,
"Ref index": ref_idx,
"simi_timbre": round(float(simi_timbre), 4),
"simi_prosody": round(float(simi_prosody), 4),
"alpha": float(alpha),
"beta": float(beta),
"RTF": round(float(rtf), 3),
"Device": DEVICE,
}
return (SR_OUT, wav.astype(np.float32)), info
# ==============================================================
# UI helper: build a FIXED mix-spec over the 4 speakers
# ==============================================================
def _build_mix_spec_ui_fixed(normalize, w1, w2, w3, w4, order):
pairs = [(order[0], float(w1 or 0.0)),
(order[1], float(w2 or 0.0)),
(order[2], float(w3 or 0.0)),
(order[3], float(w4 or 0.0))]
pairs = [(s, w) for s, w in pairs if w > 0]
if not pairs:
return "", {}, "**Sum:** 0.000"
total = sum(w for _, w in pairs)
if normalize and total > 0:
pairs = [(s, w/total) for s, w in pairs]
acc = {}
for s, w in pairs:
acc[s] = acc.get(s, 0.0) + w
mix_spec = ",".join([f"{s}:{w:.4f}" for s, w in acc.items()])
mix_view = {"weights": {s: round(w, 3) for s, w in acc.items()}, "normalized": bool(normalize)}
sum_md = f"**Sum:** {round(sum(acc.values()), 3)}"
return mix_spec, mix_view, sum_md
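# Example: weights (0.5, 0.5, 0, 0) over ["Kore", "Puck", "Algenib", "Leda"]
# with normalize=True yield mix_spec "Kore:0.5000,Puck:0.5000" and Sum 1.0.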
# ==============================================================
# Gradio UI - Compact & Clean Version
# ==============================================================
with gr.Blocks(title="StyleTTS2-vi Demo", theme=gr.themes.Soft()) as demo:
gr.Markdown("# 🎙️ StyleTTS2-vi Demo")
with gr.Row():
with gr.Column(scale=1):
text_inp = gr.Textbox(
label="📝 Text Input",
lines=3,
                placeholder="Enter the text to synthesize...",
value="Trăng treo lơ lửng trên đỉnh núi chơ vơ, ánh sáng bàng bạc phủ lên bãi đá ngổn ngang. Con dế thổn thức trong khe cỏ, tiếng gió hun hút lùa qua hốc núi trập trùng. Dưới thung lũng, đàn trâu gặm cỏ ung dung, hơi sương vẩn đục, lảng bảng giữa đồng khuya tĩnh mịch."
)
            # speakers available in styles_dict
spk_choices = list(styles_dict.keys()) if isinstance(styles_dict, dict) else ["Kore","Algenib","Puck","Leda"]
            # fixed slot order for the mixer
fixed_order = [s for s in ["Kore", "Puck", "Algenib", "Leda"] if s in spk_choices]
if len(fixed_order) < 4:
for s in spk_choices:
if s not in fixed_order:
fixed_order.append(s)
if len(fixed_order) == 4:
break
# === Reference samples - Compact grid ===
with gr.Accordion("🎵 Reference Samples", open=True):
gr.Markdown("*Click to preview voice samples*")
for i in range(0, 4, 2):
with gr.Row():
for j in range(2):
idx = i + j
if idx < len(fixed_order):
spk = fixed_order[idx]
with gr.Column(min_width=200):
gr.Audio(
value=get_ref_path_for_speaker(spk),
label=spk,
type="filepath",
interactive=False,
show_download_button=False
)
# ---- Style Mixer - More compact ----
with gr.Accordion("🎨 Style Mixer", open=True):
normalize_ck = gr.Checkbox(value=True, label="Auto-normalize", container=False)
                # 2x2 grid for the 4 sliders
with gr.Row():
with gr.Column(scale=1):
gr.Markdown(f"**{fixed_order[0]}**")
w1 = gr.Slider(0.0, 1.0, value=0.0, step=0.05, show_label=False, container=False)
with gr.Column(scale=1):
gr.Markdown(f"**{fixed_order[1]}**")
w2 = gr.Slider(0.0, 1.0, value=0.0, step=0.05, show_label=False, container=False)
with gr.Row():
with gr.Column(scale=1):
gr.Markdown(f"**{fixed_order[2]}**")
w3 = gr.Slider(0.0, 1.0, value=0.0, step=0.05, show_label=False, container=False)
with gr.Column(scale=1):
gr.Markdown(f"**{fixed_order[3]}**")
w4 = gr.Slider(0.0, 1.0, value=0.0, step=0.05, show_label=False, container=False)
with gr.Row():
mix_sum_md = gr.Markdown("**Sum:** 0.000")
mix_view_json = gr.JSON(label="Current Mix", visible=False)
mix_spec_state = gr.State("")
order_state = gr.State(fixed_order)
# Advanced settings - Collapsed by default
with gr.Accordion("⚙️ Advanced Settings", open=False):
with gr.Row():
alpha_n = gr.Number(value=ALPHA, label="Alpha (timbre)", precision=3, minimum=0, maximum=1)
beta_n = gr.Number(value=BETA, label="Beta (prosody)", precision=3, minimum=0, maximum=1)
btn = gr.Button("🔊 Generate Speech", variant="primary", size="lg")
with gr.Column(scale=1):
out_audio = gr.Audio(label="🎧 Output Audio", type="numpy")
with gr.Accordion("📊 Generation Metrics", open=False):
metrics = gr.JSON(label="Details")
# Event handlers
def _ui_build_wrapper_fixed(normalize, w1, w2, w3, w4, order):
spec, view, summ = _build_mix_spec_ui_fixed(normalize, w1, w2, w3, w4, order)
return spec, view, summ
for comp in [normalize_ck, w1, w2, w3, w4]:
comp.change(
_ui_build_wrapper_fixed,
inputs=[normalize_ck, w1, w2, w3, w4, order_state],
outputs=[mix_spec_state, mix_view_json, mix_sum_md]
)
btn.click(
run_inference,
inputs=[text_inp, alpha_n, beta_n, mix_spec_state],
outputs=[out_audio, metrics]
)
if __name__ == "__main__":
demo.launch()