Spaces:
Running
Running
| import torch | |
| import torchaudio | |
| import librosa | |
| import yaml | |
| import numpy as np | |
| import soundfile as sf | |
| import phonemizer | |
| from munch import Munch | |
| import os | |
| import time | |
| # Import các module từ StyleTTS2 repo | |
| from models import * | |
| from utils import * | |
| from text_utils import TextCleaner | |
| from Utils.PLBERT.util import load_plbert | |
| from Modules.diffusion.sampler import DiffusionSampler, ADPM2Sampler, KarrasSchedule | |
| # ================= CẤU HÌNH ĐƯỜNG DẪN (SỬA TẠI ĐÂY) ================= | |
| CONFIG_PATH = "/workspace/trainTTS/StyleTTS2_custom/Configs/config_ft.yml" # Đường dẫn file config | |
| MODEL_PATH = "/workspace/trainTTS/StyleTTS2_custom/Models/mix5voice/model_iter_00032000.pth" # Đường dẫn model (đã clean hoặc chưa clean đều được) | |
| REF_AUDIO_PATH = "/workspace/trainTTS/StyleTTS2_custom/test_voice_clone/hue_ban_mai.wav" # File giọng mẫu | |
| OUTPUT_WAV = "/workspace/trainTTS/StyleTTS2_custom/test_voice_clone/hue_ban_mai_cut.wav" # File đầu ra | |
| DEVICE = "cuda" if torch.cuda.is_available() else "cpu" | |
| # ==================================================================== | |
| class StyleTTS2Inference: | |
| def __init__(self, config_path, model_path, device=DEVICE): | |
| self.device = device | |
| self.config = yaml.safe_load(open(config_path)) | |
| # 1. Khởi tạo các công cụ hỗ trợ | |
| self.phonemizer = phonemizer.backend.EspeakBackend( | |
| language='vi', preserve_punctuation=True, with_stress=True | |
| ) | |
| self.text_cleaner = TextCleaner() | |
| # 2. Load các thành phần cốt lõi (Structure) | |
| # Lưu ý: Vẫn cần load cấu trúc ASR/F0 để build_model không lỗi, | |
| # dù sau này không load trọng số vào chúng cũng không sao. | |
| text_aligner = load_ASR_models(self.config['ASR_path'], self.config['ASR_config']) | |
| pitch_extractor = load_F0_models(self.config['F0_path']) | |
| plbert = load_plbert(self.config['PLBERT_dir']) | |
| # 3. Xây dựng kiến trúc model (Vỏ rỗng) | |
| model_params = recursive_munch(self.config['model_params']) | |
| self.model = build_model(model_params, text_aligner, pitch_extractor, plbert) | |
| # 4. Load trọng số (State Dict) - PHẦN QUAN TRỌNG NHẤT | |
| print(f"Loading model from: {model_path}") | |
| params = torch.load(model_path, map_location='cpu') | |
| # Nếu file save có key 'net' thì lấy, không thì lấy trực tiếp (tùy cách save) | |
| if 'net' in params: | |
| params = params['net'] | |
| for key in self.model: | |
| # --- CHECK QUAN TRỌNG: Chỉ load nếu có trong file checkpoint --- | |
| if key not in params: | |
| print(f"⚠️ Bỏ qua module '{key}' (không tìm thấy trong checkpoint - OK với model inference)") | |
| continue | |
| # --------------------------------------------------------------- | |
| state_dict = params[key] | |
| new_state_dict = {} | |
| # Xử lý prefix "module." nếu train bằng DataParallel | |
| for k, v in state_dict.items(): | |
| if k.startswith("module."): | |
| new_state_dict[k[len("module."):]] = v | |
| else: | |
| new_state_dict[k] = v | |
| self.model[key].load_state_dict(new_state_dict, strict=True) | |
| self.model[key].eval().to(self.device) | |
| print(f"✅ Loaded module: {key}") | |
| # 5. Khởi tạo Sampler cho Diffusion | |
| self.sampler = DiffusionSampler( | |
| self.model.diffusion.diffusion, | |
| sampler=ADPM2Sampler(), | |
| sigma_schedule=KarrasSchedule(sigma_min=0.0001, sigma_max=3.0, rho=9.0), | |
| clamp=False | |
| ) | |
| print("Model initialization complete.\n") | |
| def preprocess_audio(self, audio_path): | |
| """Chuyển đổi audio reference thành Style Vector""" | |
| wave, sr = librosa.load(audio_path, sr=24000) | |
| audio, _ = librosa.effects.trim(wave, top_db=30) | |
| to_mel = torchaudio.transforms.MelSpectrogram( | |
| n_mels=80, n_fft=2048, win_length=1200, hop_length=300 | |
| ) | |
| mel = to_mel(torch.from_numpy(audio).float()) | |
| mel = (torch.log(1e-5 + mel.unsqueeze(0)) - (-4)) / 4 | |
| mel = mel.to(self.device) | |
| with torch.no_grad(): | |
| ref_s = self.model.style_encoder(mel.unsqueeze(1)) | |
| ref_p = self.model.predictor_encoder(mel.unsqueeze(1)) | |
| ref_style = torch.cat([ref_s, ref_p], dim=1) | |
| return ref_style | |
| def preprocess_text(self, text): | |
| """Phonemize và Tokenize văn bản""" | |
| text = text.strip() | |
| if not text: return None | |
| ps = self.phonemizer.phonemize([text])[0] | |
| tokens = torch.LongTensor(self.text_cleaner(ps)).to(self.device).unsqueeze(0) | |
| # Thêm token start/padding | |
| tokens = torch.cat([torch.LongTensor([0]).to(self.device).unsqueeze(0), tokens], dim=-1) | |
| return tokens | |
| def inference(self, text, ref_style, diffusion_steps=5, alpha=0.3, beta=0.7): | |
| """Hàm suy luận cốt lõi""" | |
| tokens = self.preprocess_text(text) | |
| if tokens is None: return None | |
| input_lengths = torch.LongTensor([tokens.shape[-1]]).to(self.device) | |
| text_mask = length_to_mask(input_lengths).to(self.device) | |
| with torch.no_grad(): | |
| # Text encoding & BERT | |
| t_en = self.model.text_encoder(tokens, input_lengths, text_mask) | |
| bert_dur = self.model.bert(tokens, attention_mask=(~text_mask).int()) | |
| d_en = self.model.bert_encoder(bert_dur).transpose(-1, -2) | |
| # Diffusion Sampling (Tạo style vector đa dạng) | |
| s_pred = self.sampler( | |
| noise=torch.randn((1, 256)).unsqueeze(1).to(self.device), | |
| embedding=bert_dur, | |
| features=ref_style, | |
| num_steps=diffusion_steps | |
| ).squeeze(1) | |
| # Trộn style dự đoán và style gốc (Ref audio) | |
| # alpha: trọng số giữ lại của style dự đoán (càng cao càng đa dạng nhưng có thể lệch giọng) | |
| # beta: trọng số giữ lại của style gốc (càng cao càng giống giọng mẫu) | |
| s = s_pred[:, 128:] * alpha + ref_style[:, 128:] * beta | |
| ref = s_pred[:, :128] * alpha + ref_style[:, :128] * beta | |
| # Predictor (Dự đoán Duration, F0, N) | |
| d = self.model.predictor.text_encoder(d_en, s, input_lengths, text_mask) | |
| x, _ = self.model.predictor.lstm(d) | |
| duration = torch.sigmoid(self.model.predictor.duration_proj(x)).sum(axis=-1) | |
| pred_dur = torch.round(duration.squeeze()).clamp(min=1) | |
| # Alignment Map Construction | |
| pred_aln_trg = torch.zeros(input_lengths, int(pred_dur.sum().data)) | |
| c_frame = 0 | |
| for i in range(pred_aln_trg.size(0)): | |
| pred_aln_trg[i, c_frame:c_frame + int(pred_dur[i].data)] = 1 | |
| c_frame += int(pred_dur[i].data) | |
| # Decoder (Sinh âm thanh) | |
| en = (d.transpose(-1, -2) @ pred_aln_trg.unsqueeze(0).to(self.device)) | |
| F0_pred, N_pred = self.model.predictor.F0Ntrain(en, s) | |
| asr = (t_en @ pred_aln_trg.unsqueeze(0).to(self.device)) | |
| out = self.model.decoder(asr, F0_pred, N_pred, ref.squeeze().unsqueeze(0)) | |
| return out.squeeze().cpu().numpy()[..., :-50] # Cắt bớt đuôi silence | |
| def generate_long_text(self, text, ref_audio_path): | |
| """Xử lý văn bản dài bằng cách tách câu""" | |
| print(f"Processing audio ref: {ref_audio_path}") | |
| ref_style = self.preprocess_audio(ref_audio_path) | |
| # Tách câu đơn giản (có thể cải thiện bằng nltk nếu cần) | |
| sentences = text.split('.') | |
| wavs = [] | |
| start_time = time.time() | |
| print("Start synthesizing...") | |
| for sent in sentences: | |
| if len(sent.strip()) == 0: continue | |
| # Thêm dấu chấm để ngắt nghỉ tự nhiên hơn nếu phonemizer cần | |
| if not sent.strip().endswith('.'): sent += '.' | |
| wav = self.inference(sent, ref_style) | |
| if wav is not None: | |
| wavs.append(wav) | |
| # Thêm khoảng lặng ngắn giữa các câu (0.1s) | |
| silence = np.zeros(int(24000 * 0.1)) | |
| wavs.append(silence) | |
| full_wav = np.concatenate(wavs) | |
| print(f"Done! Total time: {time.time() - start_time:.2f}s") | |
| return full_wav | |
| # ================= MAIN EXECUTION ================= | |
| if __name__ == "__main__": | |
| # 1. Khởi tạo | |
| tts = StyleTTS2Inference(CONFIG_PATH, MODEL_PATH) | |
| # 2. Danh sách văn bản cần đọc | |
| list_texts = ["xin chào việt nam, hôm nay trời rất đẹp"] | |
| full_audio = [] | |
| # 3. Chạy vòng lặp tạo giọng | |
| for text in list_texts: | |
| audio_segment = tts.generate_long_text(text, REF_AUDIO_PATH) | |
| full_audio.append(audio_segment) | |
| # Thêm khoảng lặng giữa các đoạn văn lớn (0.5s) | |
| full_audio.append(np.zeros(int(24000 * 0.5))) | |
| # 4. Lưu file kết quả | |
| final_wav = np.concatenate(full_audio) | |
| sf.write(OUTPUT_WAV, final_wav, 24000) | |
| print(f"File saved to: {OUTPUT_WAV}") |