File size: 9,545 Bytes
1b242be | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 |
import torch
import torchaudio
import librosa
import yaml
import numpy as np
import soundfile as sf
import phonemizer
from munch import Munch
import os
import time
# Import các module từ StyleTTS2 repo
from models import *
from utils import *
from text_utils import TextCleaner
from Utils.PLBERT.util import load_plbert
from Modules.diffusion.sampler import DiffusionSampler, ADPM2Sampler, KarrasSchedule
# ================= CẤU HÌNH ĐƯỜNG DẪN (SỬA TẠI ĐÂY) =================
CONFIG_PATH = "/workspace/trainTTS/StyleTTS2_custom/Configs/config_ft.yml" # Đường dẫn file config
MODEL_PATH = "/workspace/trainTTS/StyleTTS2_custom/Models/mix5voice/model_iter_00032000.pth" # Đường dẫn model (đã clean hoặc chưa clean đều được)
REF_AUDIO_PATH = "/workspace/trainTTS/StyleTTS2_custom/test_voice_clone/hue_ban_mai.wav" # File giọng mẫu
OUTPUT_WAV = "/workspace/trainTTS/StyleTTS2_custom/test_voice_clone/hue_ban_mai_cut.wav" # File đầu ra
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
# ====================================================================
class StyleTTS2Inference:
def __init__(self, config_path, model_path, device=DEVICE):
self.device = device
self.config = yaml.safe_load(open(config_path))
# 1. Khởi tạo các công cụ hỗ trợ
self.phonemizer = phonemizer.backend.EspeakBackend(
language='vi', preserve_punctuation=True, with_stress=True
)
self.text_cleaner = TextCleaner()
# 2. Load các thành phần cốt lõi (Structure)
# Lưu ý: Vẫn cần load cấu trúc ASR/F0 để build_model không lỗi,
# dù sau này không load trọng số vào chúng cũng không sao.
text_aligner = load_ASR_models(self.config['ASR_path'], self.config['ASR_config'])
pitch_extractor = load_F0_models(self.config['F0_path'])
plbert = load_plbert(self.config['PLBERT_dir'])
# 3. Xây dựng kiến trúc model (Vỏ rỗng)
model_params = recursive_munch(self.config['model_params'])
self.model = build_model(model_params, text_aligner, pitch_extractor, plbert)
# 4. Load trọng số (State Dict) - PHẦN QUAN TRỌNG NHẤT
print(f"Loading model from: {model_path}")
params = torch.load(model_path, map_location='cpu')
# Nếu file save có key 'net' thì lấy, không thì lấy trực tiếp (tùy cách save)
if 'net' in params:
params = params['net']
for key in self.model:
# --- CHECK QUAN TRỌNG: Chỉ load nếu có trong file checkpoint ---
if key not in params:
print(f"⚠️ Bỏ qua module '{key}' (không tìm thấy trong checkpoint - OK với model inference)")
continue
# ---------------------------------------------------------------
state_dict = params[key]
new_state_dict = {}
# Xử lý prefix "module." nếu train bằng DataParallel
for k, v in state_dict.items():
if k.startswith("module."):
new_state_dict[k[len("module."):]] = v
else:
new_state_dict[k] = v
self.model[key].load_state_dict(new_state_dict, strict=True)
self.model[key].eval().to(self.device)
print(f"✅ Loaded module: {key}")
# 5. Khởi tạo Sampler cho Diffusion
self.sampler = DiffusionSampler(
self.model.diffusion.diffusion,
sampler=ADPM2Sampler(),
sigma_schedule=KarrasSchedule(sigma_min=0.0001, sigma_max=3.0, rho=9.0),
clamp=False
)
print("Model initialization complete.\n")
def preprocess_audio(self, audio_path):
"""Chuyển đổi audio reference thành Style Vector"""
wave, sr = librosa.load(audio_path, sr=24000)
audio, _ = librosa.effects.trim(wave, top_db=30)
to_mel = torchaudio.transforms.MelSpectrogram(
n_mels=80, n_fft=2048, win_length=1200, hop_length=300
)
mel = to_mel(torch.from_numpy(audio).float())
mel = (torch.log(1e-5 + mel.unsqueeze(0)) - (-4)) / 4
mel = mel.to(self.device)
with torch.no_grad():
ref_s = self.model.style_encoder(mel.unsqueeze(1))
ref_p = self.model.predictor_encoder(mel.unsqueeze(1))
ref_style = torch.cat([ref_s, ref_p], dim=1)
return ref_style
def preprocess_text(self, text):
"""Phonemize và Tokenize văn bản"""
text = text.strip()
if not text: return None
ps = self.phonemizer.phonemize([text])[0]
tokens = torch.LongTensor(self.text_cleaner(ps)).to(self.device).unsqueeze(0)
# Thêm token start/padding
tokens = torch.cat([torch.LongTensor([0]).to(self.device).unsqueeze(0), tokens], dim=-1)
return tokens
def inference(self, text, ref_style, diffusion_steps=5, alpha=0.3, beta=0.7):
"""Hàm suy luận cốt lõi"""
tokens = self.preprocess_text(text)
if tokens is None: return None
input_lengths = torch.LongTensor([tokens.shape[-1]]).to(self.device)
text_mask = length_to_mask(input_lengths).to(self.device)
with torch.no_grad():
# Text encoding & BERT
t_en = self.model.text_encoder(tokens, input_lengths, text_mask)
bert_dur = self.model.bert(tokens, attention_mask=(~text_mask).int())
d_en = self.model.bert_encoder(bert_dur).transpose(-1, -2)
# Diffusion Sampling (Tạo style vector đa dạng)
s_pred = self.sampler(
noise=torch.randn((1, 256)).unsqueeze(1).to(self.device),
embedding=bert_dur,
features=ref_style,
num_steps=diffusion_steps
).squeeze(1)
# Trộn style dự đoán và style gốc (Ref audio)
# alpha: trọng số giữ lại của style dự đoán (càng cao càng đa dạng nhưng có thể lệch giọng)
# beta: trọng số giữ lại của style gốc (càng cao càng giống giọng mẫu)
s = s_pred[:, 128:] * alpha + ref_style[:, 128:] * beta
ref = s_pred[:, :128] * alpha + ref_style[:, :128] * beta
# Predictor (Dự đoán Duration, F0, N)
d = self.model.predictor.text_encoder(d_en, s, input_lengths, text_mask)
x, _ = self.model.predictor.lstm(d)
duration = torch.sigmoid(self.model.predictor.duration_proj(x)).sum(axis=-1)
pred_dur = torch.round(duration.squeeze()).clamp(min=1)
# Alignment Map Construction
pred_aln_trg = torch.zeros(input_lengths, int(pred_dur.sum().data))
c_frame = 0
for i in range(pred_aln_trg.size(0)):
pred_aln_trg[i, c_frame:c_frame + int(pred_dur[i].data)] = 1
c_frame += int(pred_dur[i].data)
# Decoder (Sinh âm thanh)
en = (d.transpose(-1, -2) @ pred_aln_trg.unsqueeze(0).to(self.device))
F0_pred, N_pred = self.model.predictor.F0Ntrain(en, s)
asr = (t_en @ pred_aln_trg.unsqueeze(0).to(self.device))
out = self.model.decoder(asr, F0_pred, N_pred, ref.squeeze().unsqueeze(0))
return out.squeeze().cpu().numpy()[..., :-50] # Cắt bớt đuôi silence
def generate_long_text(self, text, ref_audio_path):
"""Xử lý văn bản dài bằng cách tách câu"""
print(f"Processing audio ref: {ref_audio_path}")
ref_style = self.preprocess_audio(ref_audio_path)
# Tách câu đơn giản (có thể cải thiện bằng nltk nếu cần)
sentences = text.split('.')
wavs = []
start_time = time.time()
print("Start synthesizing...")
for sent in sentences:
if len(sent.strip()) == 0: continue
# Thêm dấu chấm để ngắt nghỉ tự nhiên hơn nếu phonemizer cần
if not sent.strip().endswith('.'): sent += '.'
wav = self.inference(sent, ref_style)
if wav is not None:
wavs.append(wav)
# Thêm khoảng lặng ngắn giữa các câu (0.1s)
silence = np.zeros(int(24000 * 0.1))
wavs.append(silence)
full_wav = np.concatenate(wavs)
print(f"Done! Total time: {time.time() - start_time:.2f}s")
return full_wav
# ================= MAIN EXECUTION =================
if __name__ == "__main__":
# 1. Khởi tạo
tts = StyleTTS2Inference(CONFIG_PATH, MODEL_PATH)
# 2. Danh sách văn bản cần đọc
list_texts = ["xin chào việt nam, hôm nay trời rất đẹp"]
full_audio = []
# 3. Chạy vòng lặp tạo giọng
for text in list_texts:
audio_segment = tts.generate_long_text(text, REF_AUDIO_PATH)
full_audio.append(audio_segment)
# Thêm khoảng lặng giữa các đoạn văn lớn (0.5s)
full_audio.append(np.zeros(int(24000 * 0.5)))
# 4. Lưu file kết quả
final_wav = np.concatenate(full_audio)
sf.write(OUTPUT_WAV, final_wav, 24000)
print(f"File saved to: {OUTPUT_WAV}") |