# StyleTTS2_vi / app.py
import gradio as gr
import spaces
import torch
import torch.nn.functional as F
import torchaudio
import librosa
import os
import yaml
import numpy as np
import soundfile as sf
import nltk
from nltk.tokenize import word_tokenize
from munch import Munch
import phonemizer
from huggingface_hub import hf_hub_download
# --- ENVIRONMENT SETUP ---
# Download NLTK data
nltk.download('punkt', quiet=True)
nltk.download('punkt_tab', quiet=True)
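# 'punkt' (and 'punkt_tab' on newer NLTK releases) backs nltk.word_tokenize,
# which is used below to normalize the phonemized text.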
# --- PROJECT IMPORTS ---
from models import *
from utils import *
from text_utils import TextCleaner
from Modules.diffusion.sampler import DiffusionSampler, ADPM2Sampler, KarrasSchedule
# --- CONFIG ---
device = 'cuda' if torch.cuda.is_available() else 'cpu'
textcleaner = TextCleaner()
to_mel = torchaudio.transforms.MelSpectrogram(
    n_mels=80, n_fft=2048, win_length=1200, hop_length=300)
mean, std = -4, 4
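# Fixed log-mel normalization statistics; preprocess() below applies
# (log(1e-5 + mel) - mean) / std with these values.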
print("Đang khởi tạo cấu hình...")
# 1. Load Config
config_path = "./Configs/config_ft.yml"
with open(config_path) as f:
    config = yaml.safe_load(f)
# Point the auxiliary modules at their bundled relative paths
config['ASR_config'] = "./Utils/ASR/config.yml"
config['ASR_path'] = "./Utils/ASR/epoch_00080_191_full.pth"
config['F0_path'] = "./Utils/JDC/bst.t7"
config['PLBERT_dir'] = "./Utils/PLBERT/"
# 2. Load auxiliary models
print("Loading ASR/F0/PL-BERT models...")
text_aligner = load_ASR_models(config['ASR_path'], config['ASR_config'])
pitch_extractor = load_F0_models(config['F0_path'])
from Utils.PLBERT.util import load_plbert
plbert = load_plbert(config['PLBERT_dir'])
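# These three auxiliary networks feed the main model: the ASR aligner for
# text-to-frame alignment, the JDC network for F0 extraction, and PL-BERT
# for phoneme-level linguistic features.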
# 3. Build the model architecture
model_params = recursive_munch(config['model_params'])
model = build_model(model_params, text_aligner, pitch_extractor, plbert)
_ = [model[key].eval() for key in model]
_ = [model[key].to(device) for key in model]
# --- LOAD CHECKPOINT FROM HUGGING FACE ---
print("Downloading model checkpoint from the Hugging Face Model Hub...")
MODEL_REPO_ID = "hieuducle/model_styletts2_dolly_checkpoint_12000"
MODEL_FILENAME = "workspace/StyleTTS2/Models/Dolly/model_iter_00012000.pth"
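# MODEL_FILENAME is a path inside the model repo; hf_hub_download fetches the
# file and returns its location in the local cache.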
try:
    CHECKPOINT_PATH = hf_hub_download(repo_id=MODEL_REPO_ID, filename=MODEL_FILENAME)
    print(f"-> Checkpoint downloaded to: {CHECKPOINT_PATH}")
except Exception as e:
    raise RuntimeError(f"Failed to download the model checkpoint: {e}")
# Load weights
params_whole = torch.load(CHECKPOINT_PATH, map_location='cpu')
params = params_whole['net']
for key in model:
    if key in params:
        try:
            model[key].load_state_dict(params[key])
        except RuntimeError:
            # Checkpoint was saved from a DataParallel model: strip the
            # 'module.' prefix from each parameter name and retry.
            from collections import OrderedDict
            state_dict = params[key]
            new_state_dict = OrderedDict()
            for k, v in state_dict.items():
                name = k[7:]  # remove 'module.'
                new_state_dict[name] = v
            model[key].load_state_dict(new_state_dict, strict=False)
_ = [model[key].eval() for key in model]
# 4. Init Sampler & Phonemizer
sampler = DiffusionSampler(
    model.diffusion.diffusion,
    sampler=ADPM2Sampler(),
    sigma_schedule=KarrasSchedule(sigma_min=0.0001, sigma_max=3.0, rho=9.0),
    clamp=False
)
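# The sampler draws 256-dim style vectors in a handful of denoising steps
# (5-50, set by the "Diffusion Steps" slider below).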
global_phonemizer = phonemizer.backend.EspeakBackend(
    language='vi',
    preserve_punctuation=True,
    with_stress=True,
    language_switch="remove-flags"
)
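# Grapheme-to-phoneme conversion via espeak-ng's Vietnamese voice;
# punctuation is preserved so it can inform phrasing.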
# --- HELPER FUNCTIONS ---
def length_to_mask(lengths):
    mask = torch.arange(lengths.max()).unsqueeze(0).expand(lengths.shape[0], -1).type_as(lengths)
    mask = torch.gt(mask + 1, lengths.unsqueeze(1))
    return mask
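# e.g. length_to_mask(torch.LongTensor([3, 5])) -> a (2, 5) boolean mask
# where True marks padding positions past each sequence's length.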
def preprocess(wave):
    wave_tensor = torch.from_numpy(wave).float()
    mel_tensor = to_mel(wave_tensor)
    mel_tensor = (torch.log(1e-5 + mel_tensor.unsqueeze(0)) - mean) / std
    return mel_tensor
def compute_style(path):
    # librosa.load already resamples to the 24 kHz rate the model expects
    wave, sr = librosa.load(path, sr=24000)
    audio, _ = librosa.effects.trim(wave, top_db=30)
    mel_tensor = preprocess(audio).to(device)
    with torch.no_grad():
        ref_s = model.style_encoder(mel_tensor.unsqueeze(1))
        ref_p = model.predictor_encoder(mel_tensor.unsqueeze(1))
    return torch.cat([ref_s, ref_p], dim=1)
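# The 256-dim reference is the concatenation of acoustic style (first 128 dims,
# from style_encoder) and prosodic style (last 128, from predictor_encoder);
# LFinference slices it at index 128 accordingly.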
def LFinference(text, s_prev, ref_s, alpha, beta, t, diffusion_steps, embedding_scale):
    """Long-form inference: synthesize one sentence and return the sampled
    style vector so the next call can blend with it (weight t)."""
    text = text.strip()
    ps = global_phonemizer.phonemize([text])
    ps = word_tokenize(ps[0])
    ps = ' '.join(ps)
    ps = ps.replace('``', '"').replace("''", '"')
    # Replace plain 't' with aspirated 'tʰ' while shielding dental 't̪'
    # behind a placeholder codepoint during the substitution
    ps = ps.replace('t̪', '\uFFFF').replace('t', 'tʰ').replace('\uFFFF', 't')
    tokens = textcleaner(ps)
    tokens.insert(0, 0)
    tokens = torch.LongTensor(tokens).to(device).unsqueeze(0)
    with torch.no_grad():
        input_lengths = torch.LongTensor([tokens.shape[-1]]).to(device)
        text_mask = length_to_mask(input_lengths).to(device)
        t_en = model.text_encoder(tokens, input_lengths, text_mask)
        bert_dur = model.bert(tokens, attention_mask=(~text_mask).int())
        d_en = model.bert_encoder(bert_dur).transpose(-1, -2)
        # Sample a style vector conditioned on the BERT features and reference style
        s_pred = sampler(noise=torch.randn((1, 256)).unsqueeze(1).to(device),
                         embedding=bert_dur,
                         embedding_scale=embedding_scale,
                         features=ref_s,
                         num_steps=diffusion_steps).squeeze(1)
        # Blend with the previous sentence's style for long-form continuity
        if s_prev is not None:
            s_pred = t * s_prev + (1 - t) * s_pred
        s = s_pred[:, 128:]
        ref = s_pred[:, :128]
        # alpha mixes acoustic style, beta mixes prosodic style with the reference
        ref = alpha * ref + (1 - alpha) * ref_s[:, :128]
        s = beta * s + (1 - beta) * ref_s[:, 128:]
        s_pred = torch.cat([ref, s], dim=-1)
        d = model.predictor.text_encoder(d_en, s, input_lengths, text_mask)
        x, _ = model.predictor.lstm(d)
        duration = model.predictor.duration_proj(x)
        duration = torch.sigmoid(duration).sum(axis=-1)
        pred_dur = torch.round(duration.squeeze()).clamp(min=1)
        # Build a hard alignment matrix by repeating each phoneme for its predicted duration
        pred_aln_trg = torch.zeros(input_lengths, int(pred_dur.sum().data))
        c_frame = 0
        for i in range(pred_aln_trg.size(0)):
            pred_aln_trg[i, c_frame:c_frame + int(pred_dur[i].data)] = 1
            c_frame += int(pred_dur[i].data)
        en = (d.transpose(-1, -2) @ pred_aln_trg.unsqueeze(0).to(device))
        if model_params.decoder.type == "hifigan":
            # Shift features right by one frame for the HiFi-GAN decoder
            # (first frame duplicated)
            asr_new = torch.zeros_like(en)
            asr_new[:, :, 0] = en[:, :, 0]
            asr_new[:, :, 1:] = en[:, :, 0:-1]
            en = asr_new
        F0_pred, N_pred = model.predictor.F0Ntrain(en, s)
        asr = (t_en @ pred_aln_trg.unsqueeze(0).to(device))
        if model_params.decoder.type == "hifigan":
            asr_new = torch.zeros_like(asr)
            asr_new[:, :, 0] = asr[:, :, 0]
            asr_new[:, :, 1:] = asr[:, :, 0:-1]
            asr = asr_new
        out = model.decoder(asr, F0_pred, N_pred, ref.squeeze().unsqueeze(0))
    # Trim the last 100 samples, which tend to contain an end-of-utterance artifact
    return out.squeeze().cpu().numpy()[..., :-100], s_pred
# --- GRADIO FUNCTION ---
@spaces.GPU(duration=120)
def generate_voice(text, ref_audio, alpha, beta, diffusion_steps):
    if not text:
        return None
    if not ref_audio:
        raise gr.Error("A reference audio file is required!")
    print(f"Gen: {text[:30]}...")
    s_ref = compute_style(ref_audio)
    # Split on '.' and synthesize sentence by sentence, carrying style across
    sentences = text.split('.')
    wavs = []
    s_prev = None
    for sent in sentences:
        if sent.strip() == "":
            continue
        sent += '.'
        wav, s_prev = LFinference(sent, s_prev, s_ref, alpha, beta, 0.7, int(diffusion_steps), 1.5)
        wavs.append(wav)
    return (24000, np.concatenate(wavs))
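# Note: the style-carryover weight (t=0.7) and guidance scale
# (embedding_scale=1.5) are hard-coded here rather than exposed in the UI.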
# --- GRADIO UI (SIMPLIFIED FOR GRADIO 5.x) ---
with gr.Blocks(title="StyleTTS2-Vi") as demo:
    gr.Markdown("# 🎙️ StyleTTS2 Vietnamese")
    gr.Markdown("Upload a reference audio clip and enter text to synthesize speech")
    with gr.Row(equal_height=True):
        with gr.Column(scale=1):
            text_input = gr.Textbox(
                label="Text to read",
                placeholder="Enter Vietnamese text...",
                value="Xin chào việt nam.",
                lines=4
            )
        with gr.Column(scale=1):
            # Plain sliders (not wrapped in an Accordion)
            alpha_slider = gr.Slider(
                minimum=0.0,
                maximum=1.0,
                value=0.3,
                step=0.1,
                label="Alpha (Style)"
            )
            beta_slider = gr.Slider(
                minimum=0.0,
                maximum=1.0,
                value=0.7,
                step=0.1,
                label="Beta (Pitch)"
            )
            steps_slider = gr.Slider(
                minimum=5,
                maximum=50,
                value=10,
                step=1,
                label="Diffusion Steps"
            )
    with gr.Row(equal_height=True):
        with gr.Column(scale=1):
            audio_input = gr.Audio(
                label="Reference Audio",
                type="filepath"
            )
            generate_btn = gr.Button("🎵 Generate speech", variant="primary")
        with gr.Column(scale=1):
            # generate_voice returns a (sample_rate, ndarray) tuple
            audio_output = gr.Audio(label="Output", type="numpy")
    # Event handler
    generate_btn.click(
        fn=generate_voice,
        inputs=[text_input, audio_input, alpha_slider, beta_slider, steps_slider],
        outputs=audio_output
    )
# Launch
if __name__ == "__main__":
    demo.launch()