import os
from numpy import pad
import torch
from huggingface_hub import hf_hub_download
import phonemizer
import yaml
from split_audio.models import load_ASR_models, load_F0_models, build_model
from split_audio.utils import mask_from_lens, maximum_path
from split_audio.utils import length_to_mask, recursive_munch
from split_audio.plbert.plbert import load_plbert
from split_audio.text_utils import TextCleaner
import librosa
import numpy as np
import torchaudio
import soundfile as sf

# Mel-spectrogram parameters (presumably must match the aligner's training setup).
N_MELS = 80
N_FFT = 2048
WIN = 1200
HOP = 300
# Normalisation applied to the log-mel: (log_mel - MEAN) / STD.
MEAN, STD = -4.0, 4.0
# Number of zero samples added on each side of the waveform before the mel is computed.
PAD = 5000

# Hugging Face repo and the per-model files fetched from it (suffix appended to model_name).
_HF_REPO_ID = "presencesw/tts"
_MODEL_FILE_SUFFIXES = (".pth", ".yml", "_asr.yml", "_plbert.yml")
_MODELS_DIR = "Models"


class AudioSplitter:
    """Cut out the part of an utterance's audio that corresponds to a sub-sentence.

    The full transcript is phonemized and force-aligned to the audio with the
    text aligner built from the checkpoint. The mel frames aligned to the
    phoneme tokens of the truncated text are then mapped back to waveform samples.
    """

    def __init__(self, language: str, model_name: str = "phoaudio_single_v1", device: str = "cpu"):
        """Build the phonemizer and load the model from ``Models/``.

        Args:
            language: espeak language code passed to the phonemizer backend.
            model_name: base filename of the checkpoint/config files in ``Models/``.
            device: torch device the model modules are moved to.
        """
        self.language = language
        self.model_name = model_name
        self.backend_phonemizer = phonemizer.backend.EspeakBackend(
            language=language,
            preserve_punctuation=True,
            with_stress=True,
        )
        self.device = device
        self.textcleaner = TextCleaner()
        # Built once here rather than on every wav_to_mel call.
        self.to_mel = torchaudio.transforms.MelSpectrogram(
            n_mels=N_MELS, n_fft=N_FFT, win_length=WIN, hop_length=HOP
        )

        self._download_model_files()
        self.config = self._load_config()

        text_aligner = load_ASR_models(self.config.get("ASR_path"), self.config.get("ASR_config"))
        pitch_extractor = load_F0_models(self.config.get("F0_path"))
        plbert = load_plbert(self.config.get("PLBERT_dir"))
        model_params = recursive_munch(self.config["model_params"])
        self.model = build_model(model_params, text_aligner, pitch_extractor, plbert)
        for key in self.model:
            self.model[key].to(self.device)
        self._load_weights()
        for key in self.model:
            self.model[key].eval()

        self.n_down = self.model.text_aligner.n_down
        # Time-downsampling factor of the text aligner relative to mel frames.
        self.d = 2 ** self.n_down

    def _download_model_files(self):
        """Best-effort download of the checkpoint and its config files into ``Models/``.

        Failures are printed and ignored so that files already present locally can still be used.
        """
        token = os.getenv("HF_TOKEN", None)
        for suffix in _MODEL_FILE_SUFFIXES:
            try:
                hf_hub_download(
                    repo_id=_HF_REPO_ID,
                    filename=self.model_name + suffix,
                    local_dir=_MODELS_DIR,
                    token=token,
                )
            except Exception as e:  # deliberate best effort; see docstring
                print(f"Error downloading model: {e}")

    def _load_config(self):
        """Parse ``Models/<model_name>.yml``, closing the file afterwards."""
        config_path = os.path.join(_MODELS_DIR, self.model_name + ".yml")
        with open(config_path, encoding="utf-8") as f:
            return yaml.safe_load(f)

    def _load_weights(self):
        """Load the ``'net'`` state dicts from ``Models/<model_name>.pth`` into ``self.model``."""
        # NOTE(review): torch.load unpickles the checkpoint; only load trusted files.
        params_whole = torch.load(os.path.join(_MODELS_DIR, self.model_name + ".pth"), map_location="cpu")
        params = params_whole['net']
        prefix = "module."
        for key in self.model:
            if key not in params:
                continue
            print('%s loaded' % key)
            try:
                self.model[key].load_state_dict(params[key])
            except RuntimeError:
                # Keys did not match: retry with the "module." prefix stripped
                # (only from keys that actually carry it) and load non-strictly.
                state_dict = {
                    (k[len(prefix):] if k.startswith(prefix) else k): v
                    for k, v in params[key].items()
                }
                self.model[key].load_state_dict(state_dict, strict=False)

    def find_subsequence(self, seq, subseq):
        """Return the first index where ``subseq`` occurs contiguously in ``seq``, else None.

        Returns None for an empty ``subseq`` or one longer than ``seq``.
        """
        n, m = len(seq), len(subseq)
        if m == 0 or m > n:
            return None
        for i in range(n - m + 1):
            if seq[i:i + m] == subseq:
                return i
        return None

    def to_tokens(self, txt: str):
        """Phonemize ``txt`` and convert the phoneme string to token ids with TextCleaner."""
        ps = self.backend_phonemizer.phonemize([txt])[0].strip()
        # Remove "(en)"/"(vi)" markers from the phonemizer output.
        ps = ps.replace("(en)", "").replace("(vi)", "")
        return self.textcleaner(ps)

    def wav_to_mel(self, wave_1d: np.ndarray):
        """Zero-pad a 1-D waveform by PAD on each side and compute its normalised log-mel.

        Returns:
            ``(wave_pad, mel)``: the padded waveform (np.ndarray) and a [80, T] tensor,
            where T is trimmed to a multiple of ``self.d`` and ``wave_pad`` is cut to
            ``T * HOP`` samples in that case to stay in sync with the mel.
        """
        silence = np.zeros(PAD, dtype=wave_1d.dtype)
        wave_pad = np.concatenate([silence, wave_1d, silence])
        mel = self.to_mel(torch.from_numpy(wave_pad).float())  # [n_mels, T]
        mel = (torch.log(1e-5 + mel) - MEAN) / STD  # [80, T]

        # Trim so the number of frames is divisible by d.
        T = mel.shape[1]
        T_trim = T - (T % self.d)
        if T_trim != T:
            mel = mel[:, :T_trim]
            wave_pad = wave_pad[: T_trim * HOP]  # keep the waveform in sync
        return wave_pad, mel

    def cal_attn(self, mel_len, text_len, mel, tokens):
        """Compute the monotonic (hard) alignment between tokens and downsampled mel frames.

        Args:
            mel_len: 1-D LongTensor of mel lengths in full-resolution frames.
            text_len: 1-D LongTensor of token counts.
            mel: [80, T] mel tensor; a batch dimension is added here.
            tokens: LongTensor of token ids (split_audio passes shape [1, L]).

        Returns:
            The output of ``maximum_path``; presumably [1, L, T // d] — TODO confirm.
        """
        with torch.no_grad():  # inference only: no autograd graph needed
            mel_len = mel_len.to(self.device)
            text_len = text_len.to(self.device)
            mask_mel = length_to_mask(mel_len // self.d).to(self.device)
            text_mask = length_to_mask(text_len).to(self.device)
            mels_in = mel.unsqueeze(0).to(self.device)  # [1, 80, T]
            tokens = tokens.to(self.device)

            _, _, s2s_attn = self.model.text_aligner(mels_in, mask_mel, tokens)
            # Drop the first position along the text axis of the attention.
            s2s_attn = s2s_attn.transpose(-1, -2)
            s2s_attn = s2s_attn[..., 1:]
            s2s_attn = s2s_attn.transpose(-1, -2)

            # Zero the attention wherever either the mel frame or the token is padding.
            attn_mask = (~mask_mel).unsqueeze(-1).expand(
                mask_mel.shape[0], mask_mel.shape[1], text_mask.shape[-1]
            ).float().transpose(-1, -2)
            attn_mask = attn_mask.float() * (~text_mask).unsqueeze(-1).expand(
                text_mask.shape[0], text_mask.shape[1], mask_mel.shape[-1]
            ).float()
            attn_mask = attn_mask < 1
            s2s_attn.masked_fill_(attn_mask, 0.0)

            mask_ST = mask_from_lens(s2s_attn, text_len, mel_len // self.d)
            return maximum_path(s2s_attn, mask_ST)

    def convert_sr(self, wav, orig_sr, target_sr):
        """Resample ``wav`` with librosa when ``orig_sr`` differs from ``target_sr``."""
        if orig_sr != target_sr:
            wav = librosa.resample(wav, orig_sr=orig_sr, target_sr=target_sr)
        return wav

    def load_audio(self, audio_input, target_sr=24000):
        """Load audio from a path (or take an array assumed to be at ``target_sr``).

        Returns:
            ``(wav, target_sr)`` with ``wav`` resampled to ``target_sr``.
        """
        if isinstance(audio_input, str):
            wav, sr = librosa.load(audio_input, sr=None)
        else:
            wav = audio_input
            sr = target_sr
        wav = self.convert_sr(wav, orig_sr=sr, target_sr=target_sr)
        return wav, target_sr

    def split_audio(self, str_raw: str, str_trunc: str, audio_input):
        """Return the samples of ``audio_input`` aligned to the phonemes of ``str_trunc``.

        Args:
            str_raw: transcript of the whole utterance.
            str_trunc: a contiguous part of ``str_raw`` whose audio should be cut out.
            audio_input: 1-D waveform of the utterance (``load_audio`` output).

        Returns:
            1-D numpy slice of the waveform. It never extends into the artificial
            zero padding added by ``wav_to_mel``.

        Raises:
            ValueError: if the tokens of ``str_trunc`` are not a contiguous
                subsequence of those of ``str_raw``, or if no frame is aligned to them.
        """
        tokens_trunc = self.to_tokens(str_trunc)
        tokens_raw = self.to_tokens(str_raw)

        cut_start = self.find_subsequence(tokens_raw, tokens_trunc)
        if cut_start is None:
            raise ValueError("phonemes of str_trunc were not found inside phonemes of str_raw")
        cut_end = cut_start + len(tokens_trunc)

        # wav_to_mel already trims the mel to a multiple of self.d.
        wav_np, mel = self.wav_to_mel(audio_input)

        tokens = torch.LongTensor(tokens_raw).unsqueeze(0)
        mel_len = torch.tensor([mel.shape[1]], dtype=torch.long)
        text_len = torch.tensor([tokens.shape[1]], dtype=torch.long)
        s2s_attn_mono = self.cal_attn(mel_len=mel_len, text_len=text_len, mel=mel, tokens=tokens)

        # Token index with the highest alignment weight for each downsampled frame.
        token_per_frame_down = torch.argmax(s2s_attn_mono[0], dim=0).cpu().numpy()
        in_span = (token_per_frame_down >= cut_start) & (token_per_frame_down < cut_end)
        idx_down = np.flatnonzero(in_span)
        if idx_down.size == 0:
            raise ValueError("no audio frames were aligned to str_trunc")

        # Downsampled frame -> full-resolution frame -> sample index in the padded waveform.
        start_sample_in_padded = int(idx_down[0]) * self.d * HOP
        end_sample_in_padded = (int(idx_down[-1]) + 1) * self.d * HOP

        # Convert to coordinates in the unpadded audio, clamped to the real samples.
        n_valid = min(len(audio_input), len(wav_np) - PAD)
        start_sample = max(0, start_sample_in_padded - PAD)
        end_sample = max(start_sample + 1, end_sample_in_padded - PAD)
        end_sample = min(end_sample, n_valid)
        return wav_np[start_sample + PAD : end_sample + PAD]


if __name__ == "__main__":
    splitter = AudioSplitter(language="vi", model_name="phoaudio_single_v1", device="cpu")
    str_raw = "mệt mỏi vì lo lắng. họ ngủ một cách lơ mơ với cái ý thức cảnh giác cố hữu. nhà ga nào cũng đầy bọn trộm cắp. lâu lắm mới nghe tiếng chân của một người."
    str_trunc = "họ ngủ một cách lơ mơ với cái ý thức cảnh giác cố hữu"
    audio_input, sr = splitter.load_audio("Đào_Hiếu.wav", target_sr=24000)
    y_cut = splitter.split_audio(str_raw, str_trunc, audio_input)
    # Trim leading/trailing silence with librosa (top_db=15).
    y_cut = librosa.effects.trim(y_cut, top_db=15)[0]
    sf.write("example_cut.wav", y_cut, sr)