File size: 9,545 Bytes
1b242be
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218


import torch
import torchaudio
import librosa
import yaml
import numpy as np
import soundfile as sf
import phonemizer
from munch import Munch
import os
import time

# Import các module từ StyleTTS2 repo
from models import *
from utils import *
from text_utils import TextCleaner
from Utils.PLBERT.util import load_plbert
from Modules.diffusion.sampler import DiffusionSampler, ADPM2Sampler, KarrasSchedule

# ================= CẤU HÌNH ĐƯỜNG DẪN (SỬA TẠI ĐÂY) =================
CONFIG_PATH = "/workspace/trainTTS/StyleTTS2_custom/Configs/config_ft.yml"             # Đường dẫn file config
MODEL_PATH = "/workspace/trainTTS/StyleTTS2_custom/Models/mix5voice/model_iter_00032000.pth" # Đường dẫn model (đã clean hoặc chưa clean đều được)
REF_AUDIO_PATH = "/workspace/trainTTS/StyleTTS2_custom/test_voice_clone/hue_ban_mai.wav"                     # File giọng mẫu
OUTPUT_WAV = "/workspace/trainTTS/StyleTTS2_custom/test_voice_clone/hue_ban_mai_cut.wav"                  # File đầu ra

DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
# ====================================================================

class StyleTTS2Inference:
    def __init__(self, config_path, model_path, device=DEVICE):
        self.device = device
        self.config = yaml.safe_load(open(config_path))
        
        # 1. Khởi tạo các công cụ hỗ trợ
        self.phonemizer = phonemizer.backend.EspeakBackend(
            language='vi', preserve_punctuation=True, with_stress=True
        )
        self.text_cleaner = TextCleaner()
        
        # 2. Load các thành phần cốt lõi (Structure)
        # Lưu ý: Vẫn cần load cấu trúc ASR/F0 để build_model không lỗi, 
        # dù sau này không load trọng số vào chúng cũng không sao.
        text_aligner = load_ASR_models(self.config['ASR_path'], self.config['ASR_config'])
        pitch_extractor = load_F0_models(self.config['F0_path'])
        plbert = load_plbert(self.config['PLBERT_dir'])
        
        # 3. Xây dựng kiến trúc model (Vỏ rỗng)
        model_params = recursive_munch(self.config['model_params'])
        self.model = build_model(model_params, text_aligner, pitch_extractor, plbert)
        
        # 4. Load trọng số (State Dict) - PHẦN QUAN TRỌNG NHẤT
        print(f"Loading model from: {model_path}")
        params = torch.load(model_path, map_location='cpu')
        
        # Nếu file save có key 'net' thì lấy, không thì lấy trực tiếp (tùy cách save)
        if 'net' in params:
            params = params['net']
            
        for key in self.model:
            # --- CHECK QUAN TRỌNG: Chỉ load nếu có trong file checkpoint ---
            if key not in params:
                print(f"⚠️ Bỏ qua module '{key}' (không tìm thấy trong checkpoint - OK với model inference)")
                continue
            # ---------------------------------------------------------------
            
            state_dict = params[key]
            new_state_dict = {}
            # Xử lý prefix "module." nếu train bằng DataParallel
            for k, v in state_dict.items():
                if k.startswith("module."):
                    new_state_dict[k[len("module."):]] = v
                else:
                    new_state_dict[k] = v
            
            self.model[key].load_state_dict(new_state_dict, strict=True)
            self.model[key].eval().to(self.device)
            print(f"✅ Loaded module: {key}")

        # 5. Khởi tạo Sampler cho Diffusion
        self.sampler = DiffusionSampler(
            self.model.diffusion.diffusion,
            sampler=ADPM2Sampler(),
            sigma_schedule=KarrasSchedule(sigma_min=0.0001, sigma_max=3.0, rho=9.0),
            clamp=False
        )
        print("Model initialization complete.\n")

    def preprocess_audio(self, audio_path):
        """Chuyển đổi audio reference thành Style Vector"""
        wave, sr = librosa.load(audio_path, sr=24000)
        audio, _ = librosa.effects.trim(wave, top_db=30)
        
        to_mel = torchaudio.transforms.MelSpectrogram(
            n_mels=80, n_fft=2048, win_length=1200, hop_length=300
        )
        mel = to_mel(torch.from_numpy(audio).float())
        mel = (torch.log(1e-5 + mel.unsqueeze(0)) - (-4)) / 4
        mel = mel.to(self.device)
        
        with torch.no_grad():
            ref_s = self.model.style_encoder(mel.unsqueeze(1))
            ref_p = self.model.predictor_encoder(mel.unsqueeze(1))
            ref_style = torch.cat([ref_s, ref_p], dim=1)
            
        return ref_style

    def preprocess_text(self, text):
        """Phonemize và Tokenize văn bản"""
        text = text.strip()
        if not text: return None
        
        ps = self.phonemizer.phonemize([text])[0]
        tokens = torch.LongTensor(self.text_cleaner(ps)).to(self.device).unsqueeze(0)
        # Thêm token start/padding
        tokens = torch.cat([torch.LongTensor([0]).to(self.device).unsqueeze(0), tokens], dim=-1)
        return tokens

    def inference(self, text, ref_style, diffusion_steps=5, alpha=0.3, beta=0.7):
        """Hàm suy luận cốt lõi"""
        tokens = self.preprocess_text(text)
        if tokens is None: return None
        
        input_lengths = torch.LongTensor([tokens.shape[-1]]).to(self.device)
        text_mask = length_to_mask(input_lengths).to(self.device)

        with torch.no_grad():
            # Text encoding & BERT
            t_en = self.model.text_encoder(tokens, input_lengths, text_mask)
            bert_dur = self.model.bert(tokens, attention_mask=(~text_mask).int())
            d_en = self.model.bert_encoder(bert_dur).transpose(-1, -2)

            # Diffusion Sampling (Tạo style vector đa dạng)
            s_pred = self.sampler(
                noise=torch.randn((1, 256)).unsqueeze(1).to(self.device),
                embedding=bert_dur,
                features=ref_style,
                num_steps=diffusion_steps
            ).squeeze(1)

            # Trộn style dự đoán và style gốc (Ref audio)
            # alpha: trọng số giữ lại của style dự đoán (càng cao càng đa dạng nhưng có thể lệch giọng)
            # beta: trọng số giữ lại của style gốc (càng cao càng giống giọng mẫu)
            s = s_pred[:, 128:] * alpha + ref_style[:, 128:] * beta
            ref = s_pred[:, :128] * alpha + ref_style[:, :128] * beta

            # Predictor (Dự đoán Duration, F0, N)
            d = self.model.predictor.text_encoder(d_en, s, input_lengths, text_mask)
            x, _ = self.model.predictor.lstm(d)
            duration = torch.sigmoid(self.model.predictor.duration_proj(x)).sum(axis=-1)
            pred_dur = torch.round(duration.squeeze()).clamp(min=1)

            # Alignment Map Construction
            pred_aln_trg = torch.zeros(input_lengths, int(pred_dur.sum().data))
            c_frame = 0
            for i in range(pred_aln_trg.size(0)):
                pred_aln_trg[i, c_frame:c_frame + int(pred_dur[i].data)] = 1
                c_frame += int(pred_dur[i].data)

            # Decoder (Sinh âm thanh)
            en = (d.transpose(-1, -2) @ pred_aln_trg.unsqueeze(0).to(self.device))
            F0_pred, N_pred = self.model.predictor.F0Ntrain(en, s)
            asr = (t_en @ pred_aln_trg.unsqueeze(0).to(self.device))
            
            out = self.model.decoder(asr, F0_pred, N_pred, ref.squeeze().unsqueeze(0))

        return out.squeeze().cpu().numpy()[..., :-50] # Cắt bớt đuôi silence

    def generate_long_text(self, text, ref_audio_path):
        """Xử lý văn bản dài bằng cách tách câu"""
        print(f"Processing audio ref: {ref_audio_path}")
        ref_style = self.preprocess_audio(ref_audio_path)
        
        # Tách câu đơn giản (có thể cải thiện bằng nltk nếu cần)
        sentences = text.split('.')
        wavs = []
        
        start_time = time.time()
        print("Start synthesizing...")
        
        for sent in sentences:
            if len(sent.strip()) == 0: continue
            
            # Thêm dấu chấm để ngắt nghỉ tự nhiên hơn nếu phonemizer cần
            if not sent.strip().endswith('.'): sent += '.'
            
            wav = self.inference(sent, ref_style)
            if wav is not None:
                wavs.append(wav)
                # Thêm khoảng lặng ngắn giữa các câu (0.1s)
                silence = np.zeros(int(24000 * 0.1))
                wavs.append(silence)
        
        full_wav = np.concatenate(wavs)
        print(f"Done! Total time: {time.time() - start_time:.2f}s")
        return full_wav

# ================= MAIN EXECUTION =================
if __name__ == "__main__":
    # 1. Khởi tạo
    tts = StyleTTS2Inference(CONFIG_PATH, MODEL_PATH)

    # 2. Danh sách văn bản cần đọc
    list_texts = ["xin chào việt nam, hôm nay trời rất đẹp"]

    full_audio = []
    
    # 3. Chạy vòng lặp tạo giọng
    for text in list_texts:
        audio_segment = tts.generate_long_text(text, REF_AUDIO_PATH)
        full_audio.append(audio_segment)
        # Thêm khoảng lặng giữa các đoạn văn lớn (0.5s)
        full_audio.append(np.zeros(int(24000 * 0.5)))

    # 4. Lưu file kết quả
    final_wav = np.concatenate(full_audio)
    sf.write(OUTPUT_WAV, final_wav, 24000)
    print(f"File saved to: {OUTPUT_WAV}")