File size: 10,654 Bytes
3301011
 
 
 
 
 
 
 
cfeb5a2
 
 
3301011
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
769160c
3301011
 
 
 
 
 
 
 
 
 
769160c
3301011
 
 
 
 
 
 
 
 
 
769160c
3301011
 
 
 
 
 
 
 
 
 
769160c
3301011
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
import os
from numpy import pad
import torch
from huggingface_hub import hf_hub_download
import phonemizer
import yaml
from split_audio.models import load_ASR_models, load_F0_models, build_model
from split_audio.utils import mask_from_lens, maximum_path
from split_audio.utils import length_to_mask, recursive_munch
from split_audio.plbert.plbert import load_plbert
from split_audio.text_utils import TextCleaner
import librosa
import numpy as np
import torchaudio
import soundfile as sf

# Mel-spectrogram front-end settings (passed to torchaudio's MelSpectrogram in wav_to_mel).
N_MELS = 80    # number of mel bands
N_FFT = 2048   # FFT size
WIN = 1200     # window length in samples
HOP = 300      # hop length in samples: one mel frame advances HOP samples

# Log-mel normalization applied in wav_to_mel: (log(1e-5 + mel) - MEAN) / STD.
MEAN, STD = -4.0, 4.0

# Number of zero samples prepended and appended to the waveform before the mel transform.
PAD = 5000

class AudioSplitter:
    """Cut the stretch of audio that speaks a sub-sentence out of a longer recording.

    Both the full transcript and the target sub-sentence are phonemized with
    espeak and mapped to token ids. The target's token run is located inside
    the full token sequence. The model's text aligner then gives an attention
    between tokens and (downsampled) mel frames, which ``maximum_path`` turns
    into a monotonic alignment. The frames aligned to the target tokens are
    mapped back to a sample range of the waveform.
    """

    # Hugging Face repo the model files are fetched from, and the local directory they land in.
    _HF_REPO_ID = "presencesw/tts"
    _MODEL_DIR = "Models"
    # Suffixes appended to model_name to form the files fetched for one model.
    _MODEL_FILE_SUFFIXES = (".pth", ".yml", "_asr.yml", "_plbert.yml")

    def __init__(self, language: str, model_name: str = "phoaudio_single_v1", device: str = "cpu"):
        """Set up the phonemizer, fetch the model files, and load the model weights.

        Args:
            language: espeak language code given to the phonemizer backend (e.g. "vi").
            model_name: base name of the files ``<name>.pth``, ``<name>.yml``,
                ``<name>_asr.yml`` and ``<name>_plbert.yml`` fetched from the
                Hugging Face repo into the local ``Models`` directory.
            device: torch device every model sub-module is moved to.
        """
        self.language = language
        self.model_name = model_name
        self.backend_phonemizer = phonemizer.backend.EspeakBackend(
            language=language,
            preserve_punctuation=True,
            with_stress=True,
        )
        self.device = device
        self.textcleaner = TextCleaner()

        self._download_model_files()

        config_path = os.path.join(self._MODEL_DIR, self.model_name + ".yml")
        with open(config_path, encoding="utf-8") as config_file:
            self.config = yaml.safe_load(config_file)

        text_aligner = load_ASR_models(self.config.get("ASR_path"), self.config.get("ASR_config"))
        pitch_extractor = load_F0_models(self.config.get("F0_path"))
        plbert = load_plbert(self.config.get("PLBERT_dir"))
        model_params = recursive_munch(self.config["model_params"])
        self.model = build_model(model_params, text_aligner, pitch_extractor, plbert)
        for key in self.model:
            self.model[key].to(self.device)

        # Load on CPU; load_state_dict copies into the already-placed parameters.
        checkpoint_path = os.path.join(self._MODEL_DIR, self.model_name + ".pth")
        params = torch.load(checkpoint_path, map_location="cpu")["net"]
        self._load_weights(params)

        # Inference only; load_state_dict does not change the train/eval mode.
        for key in self.model:
            self.model[key].eval()
        self.n_down = self.model.text_aligner.n_down
        # Downsampling factor between full-rate mel frames and aligner frames.
        self.d = 2 ** self.n_down

    def _download_model_files(self):
        """Try to fetch every model file for ``self.model_name`` into ``_MODEL_DIR``.

        This is best-effort: a failure is printed and ignored so that files
        already present locally can still be used. A file that is really missing
        makes the later ``open``/``torch.load`` of it fail.
        """
        token = os.getenv("HF_TOKEN", None)
        for suffix in self._MODEL_FILE_SUFFIXES:
            filename = self.model_name + suffix
            try:
                hf_hub_download(
                    repo_id=self._HF_REPO_ID,
                    filename=filename,
                    local_dir=self._MODEL_DIR,
                    token=token,
                )
            except Exception as e:
                print(f"Error downloading {filename}: {e}")

    def _load_weights(self, params):
        """Copy checkpoint weights into every sub-module of ``self.model`` that *params* has.

        A strict load is tried first. If it raises, the ``module.`` prefix
        (presumably left by a DataParallel-wrapped model — confirm) is stripped
        from the keys that carry it, and a non-strict load is done. Keys without
        the prefix are kept as they are rather than losing their first 7 characters.
        """
        prefix = "module."
        for key in self.model:
            if key not in params:
                continue
            print('%s loaded' % key)
            state_dict = params[key]
            try:
                self.model[key].load_state_dict(state_dict)
            except RuntimeError:
                stripped = {
                    (k[len(prefix):] if k.startswith(prefix) else k): v
                    for k, v in state_dict.items()
                }
                self.model[key].load_state_dict(stripped, strict=False)

    def find_subsequence(self, seq, subseq):
        """Return the first index where *subseq* occurs contiguously in *seq*.

        Returns None when *subseq* is empty, longer than *seq*, or absent.
        """
        n, m = len(seq), len(subseq)
        if m == 0 or m > n:
            return None
        for i in range(n - m + 1):
            if seq[i:i + m] == subseq:
                return i
        return None

    def to_tokens(self, txt: str):
        """Phonemize *txt* with espeak and convert the phoneme string to token ids.

        The "(en)" / "(vi)" tags in the phonemizer output (presumably espeak
        language-switch markers) are removed before TextCleaner maps it.
        """
        ps = self.backend_phonemizer.phonemize([txt])[0].strip()
        ps = ps.replace("(en)", "").replace("(vi)", "")
        return self.textcleaner(ps)

    def wav_to_mel(self, wave_1d: np.ndarray):
        """Zero-pad *wave_1d* and compute its normalized log-mel spectrogram.

        PAD zero samples are added on both sides. The mel frame count T is
        trimmed down to a multiple of ``self.d``. When trimming happens, the
        padded waveform is cut to ``T * HOP`` samples so that both cover the
        same time span.

        Returns:
            ``(wave_pad, mel)``: the padded np.ndarray and a ``[N_MELS, T]`` tensor.
        """
        silence = np.zeros(PAD, dtype=wave_1d.dtype)
        wave_pad = np.concatenate([silence, wave_1d, silence])
        to_mel = torchaudio.transforms.MelSpectrogram(
            n_mels=N_MELS, n_fft=N_FFT, win_length=WIN, hop_length=HOP
        )
        mel = to_mel(torch.from_numpy(wave_pad).float())  # [N_MELS, T]
        mel = (torch.log(1e-5 + mel) - MEAN) / STD
        # Trim so the frame count divides evenly by the aligner's downsampling factor.
        T = mel.shape[1]
        T_trim = T - (T % self.d)
        if T_trim != T:
            mel = mel[:, :T_trim]
            wave_pad = wave_pad[: T_trim * HOP]  # keep waveform and mel time-aligned
        return wave_pad, mel

    def cal_attn(self, mel_len, text_len, mel, tokens):
        """Compute a monotonic token-to-frame alignment with the text aligner.

        Args:
            mel_len: LongTensor ``[1]`` holding the full-rate mel frame count.
            text_len: LongTensor ``[1]`` holding the token count.
            mel: ``[N_MELS, T]`` mel spectrogram; a batch axis is added here.
            tokens: LongTensor ``[1, text_len]`` of token ids.

        Returns:
            The ``maximum_path`` result with shape ``[1, text_len, T // 2**n_down]``.
        """
        # No gradients are needed for inference; this avoids building an autograd graph.
        with torch.no_grad():
            mel_len_down = mel_len // (2 ** self.n_down)
            # The masks are treated here as True on padded positions (hence the ~ below).
            mask_mel = length_to_mask(mel_len_down)
            text_mask = length_to_mask(text_len)
            _, _, s2s_attn = self.model.text_aligner(mel.unsqueeze(0), mask_mel, tokens)
            # Drop the first row along the text axis so the attention is [1, text_len, T_down]
            # and matches attn_mask (presumably a start-token slot — confirm).
            s2s_attn = s2s_attn.transpose(-1, -2)
            s2s_attn = s2s_attn[..., 1:]
            s2s_attn = s2s_attn.transpose(-1, -2)
            # Valid (token, frame) pairs are those where neither side is padding; zero the rest.
            frame_valid = (~mask_mel).unsqueeze(-1).expand(
                mask_mel.shape[0], mask_mel.shape[1], text_mask.shape[-1]
            ).float().transpose(-1, -2)
            token_valid = (~text_mask).unsqueeze(-1).expand(
                text_mask.shape[0], text_mask.shape[1], mask_mel.shape[-1]
            ).float()
            attn_mask = (frame_valid * token_valid) < 1
            s2s_attn.masked_fill_(attn_mask, 0.0)
            mask_ST = mask_from_lens(s2s_attn, text_len, mel_len_down)
            return maximum_path(s2s_attn, mask_ST)

    def convert_sr(self, wav, orig_sr, target_sr):
        """Resample *wav* from *orig_sr* to *target_sr* with librosa; return it unchanged if equal."""
        if orig_sr != target_sr:
            wav = librosa.resample(wav, orig_sr=orig_sr, target_sr=target_sr)
        return wav

    def load_audio(self, audio_input, target_sr=24000):
        """Return ``(waveform, target_sr)``.

        A ``str`` is treated as a file path: it is loaded at its native rate and
        resampled to *target_sr*. Anything else is assumed to already be a
        waveform at *target_sr* and is returned as is.
        """
        if isinstance(audio_input, str):
            wav, sr = librosa.load(audio_input, sr=None)
        else:
            wav = audio_input
            sr = target_sr
        wav = self.convert_sr(wav, orig_sr=sr, target_sr=target_sr)
        return wav, target_sr

    def split_audio(self, str_raw: str, str_trunc: str, audio_input):
        """Return the samples of *audio_input* that speak *str_trunc*.

        Args:
            str_raw: transcript of the whole of *audio_input*.
            str_trunc: the part of *str_raw* to cut out. Its phoneme tokens must
                appear as one contiguous run inside those of *str_raw*.
            audio_input: 1-D np.ndarray waveform. Presumably at the model's sample
                rate (``load_audio`` defaults to 24000) — confirm.

        Returns:
            A 1-D np.ndarray holding a slice of the input audio.

        Raises:
            ValueError: if the tokens of *str_trunc* are not found inside those of
                *str_raw*, or if no audio frame is aligned to them.
        """
        ps_trunc = self.to_tokens(str_trunc)
        ps_raw = self.to_tokens(str_raw)

        cut_start = self.find_subsequence(ps_raw, ps_trunc)
        if cut_start is None:
            raise ValueError(
                "phoneme tokens of str_trunc were not found as a contiguous run inside those of str_raw"
            )
        cut_end = cut_start + len(ps_trunc)

        # wav_to_mel already trims mel to a multiple of self.d frames.
        wav_np, mel = self.wav_to_mel(audio_input)

        # Inputs must live on the same device as the model.
        tokens = torch.LongTensor(ps_raw).unsqueeze(0).to(self.device)
        mel = mel.to(self.device)
        mel_len = torch.tensor([mel.shape[1]], dtype=torch.long, device=self.device)
        text_len = torch.tensor([tokens.shape[1]], dtype=torch.long, device=self.device)

        s2s_attn_mono = self.cal_attn(mel_len=mel_len, text_len=text_len, mel=mel, tokens=tokens)

        # Token index each downsampled frame is aligned to.
        token_per_frame_down = torch.argmax(s2s_attn_mono[0], dim=0).cpu().numpy()
        in_target = (token_per_frame_down >= cut_start) & (token_per_frame_down < cut_end)
        idx_down = np.where(in_target)[0]
        if idx_down.size == 0:
            raise ValueError("no audio frame was aligned to the tokens of str_trunc")

        start_frame_down = int(idx_down[0])
        end_frame_down = int(idx_down[-1]) + 1

        # downsampled frame -> full-rate mel frame (x d) -> sample in the padded wave (x HOP)
        start_sample_in_padded = start_frame_down * self.d * HOP
        end_sample_in_padded = end_frame_down * self.d * HOP

        # Shift to unpadded coordinates and keep at least one sample. Clamp the end to
        # the real audio: it ends at len(audio_input), and wav_to_mel may have cut
        # wav_np shorter. Clamping to len(wav_np) - PAD alone let trailing zero padding in.
        start_sample = max(0, start_sample_in_padded - PAD)
        end_sample = max(start_sample + 1, end_sample_in_padded - PAD)
        end_sample = min(end_sample, len(audio_input), len(wav_np) - PAD)
        return wav_np[start_sample + PAD : end_sample + PAD]

if __name__ == "__main__":
    # Demo: cut one sentence out of a longer Vietnamese recording and save it to disk.
    demo_splitter = AudioSplitter(language="vi", model_name="phoaudio_single_v1", device="cpu")

    full_text = "mệt mỏi vì lo lắng. họ ngủ một cách lơ mơ với cái ý thức cảnh giác cố hữu. nhà ga nào cũng đầy bọn trộm cắp. lâu lắm mới nghe tiếng chân của một người."
    target_text = "họ ngủ một cách lơ mơ với cái ý thức cảnh giác cố hữu"

    waveform, sample_rate = demo_splitter.load_audio("Đào_Hiếu.wav", target_sr=24000)
    segment = demo_splitter.split_audio(full_text, target_text, waveform)

    # Remove leading/trailing silence from the cut (librosa threshold top_db=15).
    segment, _trim_interval = librosa.effects.trim(segment, top_db=15)
    sf.write("example_cut.wav", segment, sample_rate)