from module.models_onnx import SynthesizerTrn, symbols
from AR.models.t2s_lightning_module_onnx import Text2SemanticLightningModule
import torch
import torchaudio
from torch import nn
from feature_extractor import cnhubert

cnhubert_base_path = "pretrained_models/chinese-hubert-base"
cnhubert.cnhubert_base_path = cnhubert_base_path
ssl_model = cnhubert.get_model()
from text import cleaned_text_to_sequence
import soundfile
from tools.my_utils import load_audio
import os
import json

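# Overview of what this script produces: each wrapper module below is traced
# with torch.onnx.export, yielding four graphs per project under onnx/<name>/:
#   <name>_t2s_encoder.onnx  - phoneme/BERT encoder plus reference prompt codes
#   <name>_t2s_fsdec.onnx    - first decoder step that initialises the KV cache
#   <name>_t2s_sdec.onnx     - subsequent decoder steps, one semantic token each
#   <name>_vits.onnx         - VITS synthesiser from semantic tokens to audio
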
def spectrogram_torch(y, n_fft, sampling_rate, hop_size, win_size, center=False):
    hann_window = torch.hann_window(win_size).to(dtype=y.dtype, device=y.device)
    y = torch.nn.functional.pad(
        y.unsqueeze(1),
        (int((n_fft - hop_size) / 2), int((n_fft - hop_size) / 2)),
        mode="reflect",
    )
    y = y.squeeze(1)
    spec = torch.stft(
        y,
        n_fft,
        hop_length=hop_size,
        win_length=win_size,
        window=hann_window,
        center=center,
        pad_mode="reflect",
        normalized=False,
        onesided=True,
        return_complex=False,
    )
    spec = torch.sqrt(spec.pow(2).sum(-1) + 1e-6)
    return spec


class DictToAttrRecursive(dict):
    def __init__(self, input_dict):
        super().__init__(input_dict)
        for key, value in input_dict.items():
            if isinstance(value, dict):
                value = DictToAttrRecursive(value)
            self[key] = value
            setattr(self, key, value)

    def __getattr__(self, item):
        try:
            return self[item]
        except KeyError:
            raise AttributeError(f"Attribute {item} not found")

    def __setattr__(self, key, value):
        if isinstance(value, dict):
            value = DictToAttrRecursive(value)
        super(DictToAttrRecursive, self).__setitem__(key, value)
        super().__setattr__(key, value)

    def __delattr__(self, item):
        try:
            del self[item]
        except KeyError:
            raise AttributeError(f"Attribute {item} not found")


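# T2SEncoder couples the text-to-semantic front end with the VITS quantizer:
# extract_latent() converts the reference cnhubert (SSL) features into the
# semantic-token prompt that seeds autoregressive decoding.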
class T2SEncoder(nn.Module):
    def __init__(self, t2s, vits):
        super().__init__()
        self.encoder = t2s.onnx_encoder
        self.vits = vits

    def forward(self, ref_seq, text_seq, ref_bert, text_bert, ssl_content):
        codes = self.vits.extract_latent(ssl_content)
        prompt_semantic = codes[0, 0]
        bert = torch.cat([ref_bert.transpose(0, 1), text_bert.transpose(0, 1)], 1)
        all_phoneme_ids = torch.cat([ref_seq, text_seq], 1)
        bert = bert.unsqueeze(0)
        prompt = prompt_semantic.unsqueeze(0)
        return self.encoder(all_phoneme_ids, bert), prompt


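# T2SModel loads the stage-1 (GPT) checkpoint and exposes the three pieces that
# get exported separately: the encoder, the first-stage decoder that builds the
# KV cache from the prompt, and the stage decoder that emits one token per call
# until EOS or the early-stop budget (hz * max_sec frames) is hit.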
class T2SModel(nn.Module):
    def __init__(self, t2s_path, vits_model):
        super().__init__()
        dict_s1 = torch.load(t2s_path, map_location="cpu")
        self.config = dict_s1["config"]
        self.t2s_model = Text2SemanticLightningModule(self.config, "ojbk", is_train=False)
        self.t2s_model.load_state_dict(dict_s1["weight"])
        self.t2s_model.eval()
        self.vits_model = vits_model.vq_model
        self.hz = 50
        self.max_sec = self.config["data"]["max_sec"]
        self.t2s_model.model.top_k = torch.LongTensor([self.config["inference"]["top_k"]])
        self.t2s_model.model.early_stop_num = torch.LongTensor([self.hz * self.max_sec])
        self.t2s_model = self.t2s_model.model
        self.t2s_model.init_onnx()
        self.onnx_encoder = T2SEncoder(self.t2s_model, self.vits_model)
        self.first_stage_decoder = self.t2s_model.first_stage_decoder
        self.stage_decoder = self.t2s_model.stage_decoder

    def forward(self, ref_seq, text_seq, ref_bert, text_bert, ssl_content):
        early_stop_num = self.t2s_model.early_stop_num

        x, prompts = self.onnx_encoder(ref_seq, text_seq, ref_bert, text_bert, ssl_content)
        prefix_len = prompts.shape[1]

        # The first decoder step consumes the encoded text plus the prompt and
        # initialises the KV cache.
        y, k, v, y_emb, x_example = self.first_stage_decoder(x, prompts)

        stop = False
        for idx in range(1, 1500):
            # Each call appends one semantic token and returns the updated cache.
            enco = self.stage_decoder(y, k, v, y_emb, x_example)
            y, k, v, y_emb, logits, samples = enco
            if early_stop_num != -1 and (y.shape[1] - prefix_len) > early_stop_num:
                stop = True
            if torch.argmax(logits, dim=-1)[0] == self.t2s_model.EOS or samples[0, 0] == self.t2s_model.EOS:
                stop = True
            if stop:
                break
        # Zero out the last generated token (EOS or overflow) before synthesis.
        y[0, -1] = 0

        return y[:, -idx:].unsqueeze(0)

    def export(self, ref_seq, text_seq, ref_bert, text_bert, ssl_content, project_name, dynamo=False):
        if dynamo:
            export_options = torch.onnx.ExportOptions(dynamic_shapes=True)
            # dynamo_export expects the example inputs as positional arguments.
            onnx_encoder_export_output = torch.onnx.dynamo_export(
                self.onnx_encoder,
                ref_seq, text_seq, ref_bert, text_bert, ssl_content,
                export_options=export_options
            )
            onnx_encoder_export_output.save(f"onnx/{project_name}/{project_name}_t2s_encoder.onnx")
            return

        torch.onnx.export(
            self.onnx_encoder,
            (ref_seq, text_seq, ref_bert, text_bert, ssl_content),
            f"onnx/{project_name}/{project_name}_t2s_encoder.onnx",
            input_names=["ref_seq", "text_seq", "ref_bert", "text_bert", "ssl_content"],
            output_names=["x", "prompts"],
            dynamic_axes={
                "ref_seq": {1: "ref_length"},
                "text_seq": {1: "text_length"},
                "ref_bert": {0: "ref_length"},
                "text_bert": {0: "text_length"},
                "ssl_content": {2: "ssl_length"},
            },
            opset_version=16
        )
        x, prompts = self.onnx_encoder(ref_seq, text_seq, ref_bert, text_bert, ssl_content)

        torch.onnx.export(
            self.first_stage_decoder,
            (x, prompts),
            f"onnx/{project_name}/{project_name}_t2s_fsdec.onnx",
            input_names=["x", "prompts"],
            output_names=["y", "k", "v", "y_emb", "x_example"],
            dynamic_axes={
                "x": {1: "x_length"},
                "prompts": {1: "prompts_length"},
            },
            verbose=False,
            opset_version=16
        )
        y, k, v, y_emb, x_example = self.first_stage_decoder(x, prompts)

        torch.onnx.export(
            self.stage_decoder,
            (y, k, v, y_emb, x_example),
            f"onnx/{project_name}/{project_name}_t2s_sdec.onnx",
            input_names=["iy", "ik", "iv", "iy_emb", "ix_example"],
            output_names=["y", "k", "v", "y_emb", "logits", "samples"],
            dynamic_axes={
                "iy": {1: "iy_length"},
                "ik": {1: "ik_length"},
                "iv": {1: "iv_length"},
                "iy_emb": {1: "iy_emb_length"},
                "ix_example": {1: "ix_example_length"},
            },
            verbose=False,
            opset_version=16
        )


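# VitsModel wraps SynthesizerTrn (the stage-2 SoVITS decoder): it builds the
# reference spectrogram from the prompt audio and turns predicted semantic
# tokens plus target phonemes into a waveform.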
class VitsModel(nn.Module):
    def __init__(self, vits_path):
        super().__init__()
        dict_s2 = torch.load(vits_path, map_location="cpu")
        self.hps = dict_s2["config"]
        self.hps = DictToAttrRecursive(self.hps)
        self.hps.model.semantic_frame_rate = "25hz"
        self.vq_model = SynthesizerTrn(
            self.hps.data.filter_length // 2 + 1,
            self.hps.train.segment_size // self.hps.data.hop_length,
            n_speakers=self.hps.data.n_speakers,
            **self.hps.model
        )
        self.vq_model.eval()
        self.vq_model.load_state_dict(dict_s2["weight"], strict=False)

    def forward(self, text_seq, pred_semantic, ref_audio):
        refer = spectrogram_torch(
            ref_audio,
            self.hps.data.filter_length,
            self.hps.data.sampling_rate,
            self.hps.data.hop_length,
            self.hps.data.win_length,
            center=False
        )
        return self.vq_model(pred_semantic, text_seq, refer)[0, 0]


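# GptSoVits chains both stages end to end. The debug branch additionally runs a
# previously exported VITS graph through onnxruntime so the PyTorch and ONNX
# outputs can be compared side by side.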
class GptSoVits(nn.Module):
    def __init__(self, vits, t2s):
        super().__init__()
        self.vits = vits
        self.t2s = t2s

    def forward(self, ref_seq, text_seq, ref_bert, text_bert, ref_audio, ssl_content, debug=False):
        pred_semantic = self.t2s(ref_seq, text_seq, ref_bert, text_bert, ssl_content)
        audio = self.vits(text_seq, pred_semantic, ref_audio)
        if debug:
            import onnxruntime
            # Compare against an already-exported VITS graph (hardcoded debug path).
            sess = onnxruntime.InferenceSession("onnx/koharu/koharu_vits.onnx", providers=["CPUExecutionProvider"])
            audio1 = sess.run(None, {
                "text_seq": text_seq.detach().cpu().numpy(),
                "pred_semantic": pred_semantic.detach().cpu().numpy(),
                "ref_audio": ref_audio.detach().cpu().numpy()
            })
            return audio, audio1
        return audio

    def export(self, ref_seq, text_seq, ref_bert, text_bert, ref_audio, ssl_content, project_name):
        self.t2s.export(ref_seq, text_seq, ref_bert, text_bert, ssl_content, project_name)
        pred_semantic = self.t2s(ref_seq, text_seq, ref_bert, text_bert, ssl_content)
        torch.onnx.export(
            self.vits,
            (text_seq, pred_semantic, ref_audio),
            f"onnx/{project_name}/{project_name}_vits.onnx",
            input_names=["text_seq", "pred_semantic", "ref_audio"],
            output_names=["audio"],
            dynamic_axes={
                "text_seq": {1: "text_length"},
                "pred_semantic": {2: "pred_length"},
                "ref_audio": {1: "audio_length"},
            },
            opset_version=17,
            verbose=False
        )


class SSLModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.ssl = ssl_model

    def forward(self, ref_audio_16k):
        return self.ssl.model(ref_audio_16k)["last_hidden_state"].transpose(1, 2)


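# export() wires everything together: it builds the wrapper modules, fabricates
# tracing inputs (placeholder pinyin phoneme sequences, random BERT features,
# five seconds of random 48 kHz "audio"), writes out.wav as a sanity check,
# exports the ONNX graphs, and drops a MoeVS-style JSON config next to them.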
def export(vits_path, gpt_path, project_name):
    vits = VitsModel(vits_path)
    gpt = T2SModel(gpt_path, vits)
    gpt_sovits = GptSoVits(vits, gpt)
    ssl = SSLModel()
    ref_seq = torch.LongTensor([cleaned_text_to_sequence(["n", "i2", "h", "ao3", ",", "w", "o3", "sh", "i4", "b", "ai2", "y", "e4"])])
    text_seq = torch.LongTensor([cleaned_text_to_sequence(["w", "o3", "sh", "i4", "b", "ai2", "y", "e4", "w", "o3", "sh", "i4", "b", "ai2", "y", "e4", "w", "o3", "sh", "i4", "b", "ai2", "y", "e4"])])
    ref_bert = torch.randn((ref_seq.shape[1], 1024)).float()
    text_bert = torch.randn((text_seq.shape[1], 1024)).float()
    ref_audio = torch.randn((1, 48000 * 5)).float()

    # Resample the dummy reference audio to 16 kHz for cnhubert and to the
    # model's native rate for the VITS reference spectrogram.
    ref_audio_16k = torchaudio.functional.resample(ref_audio, 48000, 16000).float()
    ref_audio_sr = torchaudio.functional.resample(ref_audio, 48000, vits.hps.data.sampling_rate).float()

    os.makedirs(f"onnx/{project_name}", exist_ok=True)

    ssl_content = ssl(ref_audio_16k).float()

    debug = False
    if debug:
        a, b = gpt_sovits(ref_seq, text_seq, ref_bert, text_bert, ref_audio_sr, ssl_content, debug=debug)
        soundfile.write("out1.wav", a.cpu().detach().numpy(), vits.hps.data.sampling_rate)
        soundfile.write("out2.wav", b[0], vits.hps.data.sampling_rate)
        return

    a = gpt_sovits(ref_seq, text_seq, ref_bert, text_bert, ref_audio_sr, ssl_content).detach().cpu().numpy()
    soundfile.write("out.wav", a, vits.hps.data.sampling_rate)

    gpt_sovits.export(ref_seq, text_seq, ref_bert, text_bert, ref_audio_sr, ssl_content, project_name)

    MoeVSConf = {
        "Folder": f"{project_name}",
        "Name": f"{project_name}",
        "Type": "GPT-SoVits",
        "Rate": vits.hps.data.sampling_rate,
        "NumLayers": gpt.t2s_model.num_layers,
        "EmbeddingDim": gpt.t2s_model.embedding_dim,
        "Dict": "BasicDict",
        "BertPath": "chinese-roberta-wwm-ext-large",
        "Symbol": symbols,
        "AddBlank": False
    }

    with open(f"onnx/{project_name}.json", "w") as MoeVSConfFile:
        json.dump(MoeVSConf, MoeVSConfFile, indent=4)


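# Example invocation. The checkpoint paths below are sample weights; point them
# at your own GPT (stage-1) .ckpt and SoVITS (stage-2) .pth files and pick an
# output folder name.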
if __name__ == "__main__":
    os.makedirs("onnx", exist_ok=True)

    gpt_path = "GPT_weights/nahida-e25.ckpt"
    vits_path = "SoVITS_weights/nahida_e30_s3930.pth"
    exp_path = "nahida"
    export(vits_path, gpt_path, exp_path)