StyleTTS2_vi / optimize_model.py
hieuducle's picture
Upload folder using huggingface_hub
84f3a60 verified
import torch
import torchaudio
import librosa
import yaml
import numpy as np
import soundfile as sf
import phonemizer
from munch import Munch
import os
import time
# Import các module từ StyleTTS2 repo
from models import *
from utils import *
from text_utils import TextCleaner
from Utils.PLBERT.util import load_plbert
from Modules.diffusion.sampler import DiffusionSampler, ADPM2Sampler, KarrasSchedule
# ================= CẤU HÌNH ĐƯỜNG DẪN (SỬA TẠI ĐÂY) =================
CONFIG_PATH = "/workspace/trainTTS/StyleTTS2_custom/Configs/config_ft.yml" # Đường dẫn file config
MODEL_PATH = "/workspace/trainTTS/StyleTTS2_custom/Models/mix5voice/model_iter_00032000.pth" # Đường dẫn model (đã clean hoặc chưa clean đều được)
REF_AUDIO_PATH = "/workspace/trainTTS/StyleTTS2_custom/test_voice_clone/hue_ban_mai.wav" # File giọng mẫu
OUTPUT_WAV = "/workspace/trainTTS/StyleTTS2_custom/test_voice_clone/hue_ban_mai_cut.wav" # File đầu ra
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
# ====================================================================
class StyleTTS2Inference:
def __init__(self, config_path, model_path, device=DEVICE):
self.device = device
self.config = yaml.safe_load(open(config_path))
# 1. Khởi tạo các công cụ hỗ trợ
self.phonemizer = phonemizer.backend.EspeakBackend(
language='vi', preserve_punctuation=True, with_stress=True
)
self.text_cleaner = TextCleaner()
# 2. Load các thành phần cốt lõi (Structure)
# Lưu ý: Vẫn cần load cấu trúc ASR/F0 để build_model không lỗi,
# dù sau này không load trọng số vào chúng cũng không sao.
text_aligner = load_ASR_models(self.config['ASR_path'], self.config['ASR_config'])
pitch_extractor = load_F0_models(self.config['F0_path'])
plbert = load_plbert(self.config['PLBERT_dir'])
# 3. Xây dựng kiến trúc model (Vỏ rỗng)
model_params = recursive_munch(self.config['model_params'])
self.model = build_model(model_params, text_aligner, pitch_extractor, plbert)
# 4. Load trọng số (State Dict) - PHẦN QUAN TRỌNG NHẤT
print(f"Loading model from: {model_path}")
params = torch.load(model_path, map_location='cpu')
# Nếu file save có key 'net' thì lấy, không thì lấy trực tiếp (tùy cách save)
if 'net' in params:
params = params['net']
for key in self.model:
# --- CHECK QUAN TRỌNG: Chỉ load nếu có trong file checkpoint ---
if key not in params:
print(f"⚠️ Bỏ qua module '{key}' (không tìm thấy trong checkpoint - OK với model inference)")
continue
# ---------------------------------------------------------------
state_dict = params[key]
new_state_dict = {}
# Xử lý prefix "module." nếu train bằng DataParallel
for k, v in state_dict.items():
if k.startswith("module."):
new_state_dict[k[len("module."):]] = v
else:
new_state_dict[k] = v
self.model[key].load_state_dict(new_state_dict, strict=True)
self.model[key].eval().to(self.device)
print(f"✅ Loaded module: {key}")
# 5. Khởi tạo Sampler cho Diffusion
self.sampler = DiffusionSampler(
self.model.diffusion.diffusion,
sampler=ADPM2Sampler(),
sigma_schedule=KarrasSchedule(sigma_min=0.0001, sigma_max=3.0, rho=9.0),
clamp=False
)
print("Model initialization complete.\n")
def preprocess_audio(self, audio_path):
"""Chuyển đổi audio reference thành Style Vector"""
wave, sr = librosa.load(audio_path, sr=24000)
audio, _ = librosa.effects.trim(wave, top_db=30)
to_mel = torchaudio.transforms.MelSpectrogram(
n_mels=80, n_fft=2048, win_length=1200, hop_length=300
)
mel = to_mel(torch.from_numpy(audio).float())
mel = (torch.log(1e-5 + mel.unsqueeze(0)) - (-4)) / 4
mel = mel.to(self.device)
with torch.no_grad():
ref_s = self.model.style_encoder(mel.unsqueeze(1))
ref_p = self.model.predictor_encoder(mel.unsqueeze(1))
ref_style = torch.cat([ref_s, ref_p], dim=1)
return ref_style
def preprocess_text(self, text):
"""Phonemize và Tokenize văn bản"""
text = text.strip()
if not text: return None
ps = self.phonemizer.phonemize([text])[0]
tokens = torch.LongTensor(self.text_cleaner(ps)).to(self.device).unsqueeze(0)
# Thêm token start/padding
tokens = torch.cat([torch.LongTensor([0]).to(self.device).unsqueeze(0), tokens], dim=-1)
return tokens
def inference(self, text, ref_style, diffusion_steps=5, alpha=0.3, beta=0.7):
"""Hàm suy luận cốt lõi"""
tokens = self.preprocess_text(text)
if tokens is None: return None
input_lengths = torch.LongTensor([tokens.shape[-1]]).to(self.device)
text_mask = length_to_mask(input_lengths).to(self.device)
with torch.no_grad():
# Text encoding & BERT
t_en = self.model.text_encoder(tokens, input_lengths, text_mask)
bert_dur = self.model.bert(tokens, attention_mask=(~text_mask).int())
d_en = self.model.bert_encoder(bert_dur).transpose(-1, -2)
# Diffusion Sampling (Tạo style vector đa dạng)
s_pred = self.sampler(
noise=torch.randn((1, 256)).unsqueeze(1).to(self.device),
embedding=bert_dur,
features=ref_style,
num_steps=diffusion_steps
).squeeze(1)
# Trộn style dự đoán và style gốc (Ref audio)
# alpha: trọng số giữ lại của style dự đoán (càng cao càng đa dạng nhưng có thể lệch giọng)
# beta: trọng số giữ lại của style gốc (càng cao càng giống giọng mẫu)
s = s_pred[:, 128:] * alpha + ref_style[:, 128:] * beta
ref = s_pred[:, :128] * alpha + ref_style[:, :128] * beta
# Predictor (Dự đoán Duration, F0, N)
d = self.model.predictor.text_encoder(d_en, s, input_lengths, text_mask)
x, _ = self.model.predictor.lstm(d)
duration = torch.sigmoid(self.model.predictor.duration_proj(x)).sum(axis=-1)
pred_dur = torch.round(duration.squeeze()).clamp(min=1)
# Alignment Map Construction
pred_aln_trg = torch.zeros(input_lengths, int(pred_dur.sum().data))
c_frame = 0
for i in range(pred_aln_trg.size(0)):
pred_aln_trg[i, c_frame:c_frame + int(pred_dur[i].data)] = 1
c_frame += int(pred_dur[i].data)
# Decoder (Sinh âm thanh)
en = (d.transpose(-1, -2) @ pred_aln_trg.unsqueeze(0).to(self.device))
F0_pred, N_pred = self.model.predictor.F0Ntrain(en, s)
asr = (t_en @ pred_aln_trg.unsqueeze(0).to(self.device))
out = self.model.decoder(asr, F0_pred, N_pred, ref.squeeze().unsqueeze(0))
return out.squeeze().cpu().numpy()[..., :-50] # Cắt bớt đuôi silence
def generate_long_text(self, text, ref_audio_path):
"""Xử lý văn bản dài bằng cách tách câu"""
print(f"Processing audio ref: {ref_audio_path}")
ref_style = self.preprocess_audio(ref_audio_path)
# Tách câu đơn giản (có thể cải thiện bằng nltk nếu cần)
sentences = text.split('.')
wavs = []
start_time = time.time()
print("Start synthesizing...")
for sent in sentences:
if len(sent.strip()) == 0: continue
# Thêm dấu chấm để ngắt nghỉ tự nhiên hơn nếu phonemizer cần
if not sent.strip().endswith('.'): sent += '.'
wav = self.inference(sent, ref_style)
if wav is not None:
wavs.append(wav)
# Thêm khoảng lặng ngắn giữa các câu (0.1s)
silence = np.zeros(int(24000 * 0.1))
wavs.append(silence)
full_wav = np.concatenate(wavs)
print(f"Done! Total time: {time.time() - start_time:.2f}s")
return full_wav
# ================= MAIN EXECUTION =================
if __name__ == "__main__":
# 1. Khởi tạo
tts = StyleTTS2Inference(CONFIG_PATH, MODEL_PATH)
# 2. Danh sách văn bản cần đọc
list_texts = ["xin chào việt nam, hôm nay trời rất đẹp"]
full_audio = []
# 3. Chạy vòng lặp tạo giọng
for text in list_texts:
audio_segment = tts.generate_long_text(text, REF_AUDIO_PATH)
full_audio.append(audio_segment)
# Thêm khoảng lặng giữa các đoạn văn lớn (0.5s)
full_audio.append(np.zeros(int(24000 * 0.5)))
# 4. Lưu file kết quả
final_wav = np.concatenate(full_audio)
sf.write(OUTPUT_WAV, final_wav, 24000)
print(f"File saved to: {OUTPUT_WAV}")