# modified from https://github.com/yangdongchao/SoundStorm/blob/master/soundstorm/s1/AR/models/t2s_model.py
# reference: https://github.com/lifeiteng/vall-e
import argparse
import os
from io import BytesIO
from typing import Optional

import soundfile
import torch
import torchaudio

# Assumption: the original bare "import kaldi" refers to torchaudio's
# Kaldi-compatible ops; Kaldi.fbank below matches that API.
import torchaudio.compliance.kaldi as Kaldi
from torch import IntTensor, LongTensor, Tensor, nn
from torch.nn import functional as F
from transformers import AutoModelForMaskedLM, AutoTokenizer

from my_utils import load_audio
from feature_extractor import cnhubert
from AR.models.t2s_lightning_module import Text2SemanticLightningModule
from module.models_onnx import SynthesizerTrn
from inference_webui import get_phones_and_bert
from sv import SV
default_config = {
    "embedding_dim": 512,
    "hidden_dim": 512,
    "num_head": 8,
    "num_layers": 12,
    "num_codebook": 8,
    "p_dropout": 0.0,
    "vocab_size": 1024 + 1,
    "phoneme_vocab_size": 512,
    "EOS": 1024,
}
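# The defaults above mirror the t2s checkpoint hyperparameters; vocab_size is
# 1024 codebook entries plus one EOS token (id 1024, matching "EOS" above).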
sv_cn_model = None


def init_sv_cn(device, is_half):
    global sv_cn_model
    sv_cn_model = SV(device, is_half)
def load_sovits_new(sovits_path):
    # Some SoVITS checkpoints have the zip "PK" magic bytes at the start of the
    # file overwritten; restore them so torch.load can read the archive.
    with open(sovits_path, "rb") as f:
        meta = f.read(2)
        if meta != b"PK":
            data = b"PK" + f.read()
            bio = BytesIO(data)
            return torch.load(bio, map_location="cpu", weights_only=False)
    return torch.load(sovits_path, map_location="cpu", weights_only=False)
def get_raw_t2s_model(dict_s1) -> Text2SemanticLightningModule:
    config = dict_s1["config"]
    config["model"]["dropout"] = float(config["model"]["dropout"])
    t2s_model = Text2SemanticLightningModule(config, "****", is_train=False)
    t2s_model.load_state_dict(dict_s1["weight"])
    t2s_model = t2s_model.eval()
    return t2s_model
def logits_to_probs(
    logits,
    previous_tokens: Optional[torch.Tensor] = None,
    temperature: float = 1.0,
    top_k: Optional[int] = None,
    top_p: Optional[float] = None,
    repetition_penalty: float = 1.0,
):
    # if previous_tokens is not None:
    #     previous_tokens = previous_tokens.squeeze()
    # print(logits.shape, previous_tokens.shape)
    # pdb.set_trace()
    if previous_tokens is not None and repetition_penalty != 1.0:
        previous_tokens = previous_tokens.long()
        score = torch.gather(logits, dim=1, index=previous_tokens)
        score = torch.where(score < 0, score * repetition_penalty, score / repetition_penalty)
        logits.scatter_(dim=1, index=previous_tokens, src=score)
    if top_p is not None and top_p < 1.0:
        sorted_logits, sorted_indices = torch.sort(logits, descending=True)
        cum_probs = torch.cumsum(torch.nn.functional.softmax(sorted_logits, dim=-1), dim=-1)
        sorted_indices_to_remove = cum_probs > top_p
        sorted_indices_to_remove[:, 0] = False  # keep at least one option
        indices_to_remove = sorted_indices_to_remove.scatter(dim=1, index=sorted_indices, src=sorted_indices_to_remove)
        logits = logits.masked_fill(indices_to_remove, -float("Inf"))
    logits = logits / max(temperature, 1e-5)
    if top_k is not None:
        v, _ = torch.topk(logits, min(top_k, logits.size(-1)))
        pivot = v[:, -1].unsqueeze(-1)
        logits = torch.where(logits < pivot, -float("Inf"), logits)
    probs = torch.nn.functional.softmax(logits, dim=-1)
    return probs
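# Shape sketch (illustrative values, not from the repo): with logits [1, 1025]
# and previous_tokens [1, T], logits_to_probs(logits, previous_tokens, top_k=5)
# returns a [1, 1025] distribution that sums to 1 along the last dim.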
def multinomial_sample_one_no_sync(probs_sort):
    # Multinomial sampling without a CUDA synchronization: dividing the
    # probabilities by i.i.d. Exponential(1) noise and taking the argmax draws
    # index i with probability proportional to probs_sort[i] (the
    # exponential-race trick), avoiding torch.multinomial's host sync.
    q = torch.empty_like(probs_sort).exponential_(1.0)
    return torch.argmax(probs_sort / q, dim=-1, keepdim=True).to(dtype=torch.int)
def sample(
    logits,
    previous_tokens,
    temperature: float = 1.0,
    top_k: Optional[int] = None,
    top_p: Optional[float] = None,
    repetition_penalty: float = 1.35,
):
    probs = logits_to_probs(
        logits=logits,
        previous_tokens=previous_tokens,
        temperature=temperature,
        top_k=top_k,
        top_p=top_p,
        repetition_penalty=repetition_penalty,
    )
    idx_next = multinomial_sample_one_no_sync(probs)
    return idx_next, probs
def spectrogram_torch(
    hann_window: Tensor, y: Tensor, n_fft: int, sampling_rate: int, hop_size: int, win_size: int, center: bool = False
):
    # hann_window = torch.hann_window(win_size, device=y.device, dtype=y.dtype)
    y = torch.nn.functional.pad(
        y.unsqueeze(1),
        (int((n_fft - hop_size) / 2), int((n_fft - hop_size) / 2)),
        mode="reflect",
    )
    y = y.squeeze(1)
    spec = torch.stft(
        y,
        n_fft,
        hop_length=hop_size,
        win_length=win_size,
        window=hann_window,
        center=center,
        pad_mode="reflect",
        normalized=False,
        onesided=True,
        return_complex=False,
    )
    spec = torch.sqrt(spec.pow(2).sum(-1) + 1e-6)
    return spec
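# hann_window is passed in precomputed (see VitsModel.__init__) instead of
# being built per call (the commented line above), keeping the traced graph
# free of per-call allocations on a fixed device/dtype.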
class DictToAttrRecursive(dict):
    def __init__(self, input_dict):
        super().__init__(input_dict)
        for key, value in input_dict.items():
            if isinstance(value, dict):
                value = DictToAttrRecursive(value)
            self[key] = value
            setattr(self, key, value)

    def __getattr__(self, item):
        try:
            return self[item]
        except KeyError:
            raise AttributeError(f"Attribute {item} not found")

    def __setattr__(self, key, value):
        if isinstance(value, dict):
            value = DictToAttrRecursive(value)
        super(DictToAttrRecursive, self).__setitem__(key, value)
        super().__setattr__(key, value)

    def __delattr__(self, item):
        try:
            del self[item]
        except KeyError:
            raise AttributeError(f"Attribute {item} not found")
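# T2SMLP and T2SBlock re-implement the transformer feed-forward and attention
# layers functionally, holding raw weight tensors instead of nn.Modules, so the
# decode loop scripts cleanly with an explicit per-layer KV cache.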
class T2SMLP:
    def __init__(self, w1, b1, w2, b2):
        self.w1 = w1
        self.b1 = b1
        self.w2 = w2
        self.b2 = b2

    def forward(self, x):
        x = F.relu(F.linear(x, self.w1, self.b1))
        x = F.linear(x, self.w2, self.b2)
        return x
class T2SBlock:
    def __init__(
        self,
        num_heads: int,
        hidden_dim: int,
        mlp: T2SMLP,
        qkv_w,
        qkv_b,
        out_w,
        out_b,
        norm_w1,
        norm_b1,
        norm_eps1: float,
        norm_w2,
        norm_b2,
        norm_eps2: float,
    ):
        self.num_heads = num_heads
        self.mlp = mlp
        self.hidden_dim: int = hidden_dim
        self.qkv_w = qkv_w
        self.qkv_b = qkv_b
        self.out_w = out_w
        self.out_b = out_b
        self.norm_w1 = norm_w1
        self.norm_b1 = norm_b1
        self.norm_eps1 = norm_eps1
        self.norm_w2 = norm_w2
        self.norm_b2 = norm_b2
        self.norm_eps2 = norm_eps2
        self.false = torch.tensor(False, dtype=torch.bool)

    def to_mask(self, x: torch.Tensor, padding_mask: Optional[torch.Tensor]):
        if padding_mask is None:
            return x
        if padding_mask.dtype == torch.bool:
            return x.masked_fill(padding_mask, 0)
        else:
            return x * padding_mask
    def process_prompt(self, x: torch.Tensor, attn_mask: torch.Tensor, padding_mask: Optional[torch.Tensor] = None):
        q, k, v = F.linear(self.to_mask(x, padding_mask), self.qkv_w, self.qkv_b).chunk(3, dim=-1)
        batch_size = q.shape[0]
        q_len = q.shape[1]
        kv_len = k.shape[1]
        q = self.to_mask(q, padding_mask)
        k_cache = self.to_mask(k, padding_mask)
        v_cache = self.to_mask(v, padding_mask)
        q = q.view(batch_size, q_len, self.num_heads, -1).transpose(1, 2)
        k = k_cache.view(batch_size, kv_len, self.num_heads, -1).transpose(1, 2)
        v = v_cache.view(batch_size, kv_len, self.num_heads, -1).transpose(1, 2)
        attn = F.scaled_dot_product_attention(q, k, v, ~attn_mask)
        attn = attn.permute(2, 0, 1, 3).reshape(batch_size * q_len, self.hidden_dim)
        attn = attn.view(q_len, batch_size, self.hidden_dim).transpose(1, 0)
        attn = F.linear(self.to_mask(attn, padding_mask), self.out_w, self.out_b)
        if padding_mask is not None:
            for i in range(batch_size):
                # mask = padding_mask[i, :, 0]
                if self.false.device != padding_mask.device:
                    self.false = self.false.to(padding_mask.device)
                idx = torch.where(padding_mask[i, :, 0] == self.false)[0]
                x_item = x[i, idx, :].unsqueeze(0)
                attn_item = attn[i, idx, :].unsqueeze(0)
                x_item = x_item + attn_item
                x_item = F.layer_norm(x_item, [self.hidden_dim], self.norm_w1, self.norm_b1, self.norm_eps1)
                x_item = x_item + self.mlp.forward(x_item)
                x_item = F.layer_norm(
                    x_item,
                    [self.hidden_dim],
                    self.norm_w2,
                    self.norm_b2,
                    self.norm_eps2,
                )
                x[i, idx, :] = x_item.squeeze(0)
            x = self.to_mask(x, padding_mask)
        else:
            x = x + attn
            x = F.layer_norm(x, [self.hidden_dim], self.norm_w1, self.norm_b1, self.norm_eps1)
            x = x + self.mlp.forward(x)
            x = F.layer_norm(
                x,
                [self.hidden_dim],
                self.norm_w2,
                self.norm_b2,
                self.norm_eps2,
            )
        return x, k_cache, v_cache
    def decode_next_token(self, x: torch.Tensor, k_cache: torch.Tensor, v_cache: torch.Tensor):
        q, k, v = F.linear(x, self.qkv_w, self.qkv_b).chunk(3, dim=-1)
        k_cache = torch.cat([k_cache, k], dim=1)
        v_cache = torch.cat([v_cache, v], dim=1)
        batch_size = q.shape[0]
        q_len = q.shape[1]
        kv_len = k_cache.shape[1]
        q = q.view(batch_size, q_len, self.num_heads, -1).transpose(1, 2)
        k = k_cache.view(batch_size, kv_len, self.num_heads, -1).transpose(1, 2)
        v = v_cache.view(batch_size, kv_len, self.num_heads, -1).transpose(1, 2)
        attn = F.scaled_dot_product_attention(q, k, v)
        # attn = attn.permute(2, 0, 1, 3).reshape(batch_size * q_len, self.hidden_dim)
        # attn = attn.view(q_len, batch_size, self.hidden_dim).transpose(1, 0)
        attn = attn.transpose(1, 2).reshape(batch_size, q_len, -1)
        attn = F.linear(attn, self.out_w, self.out_b)
        x = x + attn
        x = F.layer_norm(x, [self.hidden_dim], self.norm_w1, self.norm_b1, self.norm_eps1)
        x = x + self.mlp.forward(x)
        x = F.layer_norm(
            x,
            [self.hidden_dim],
            self.norm_w2,
            self.norm_b2,
            self.norm_eps2,
        )
        return x, k_cache, v_cache
class T2STransformer:
    def __init__(self, num_blocks: int, blocks: list[T2SBlock]):
        self.num_blocks: int = num_blocks
        self.blocks = blocks

    def process_prompt(self, x: torch.Tensor, attn_mask: torch.Tensor, padding_mask: Optional[torch.Tensor] = None):
        k_cache: list[torch.Tensor] = []
        v_cache: list[torch.Tensor] = []
        for i in range(self.num_blocks):
            x, k_cache_, v_cache_ = self.blocks[i].process_prompt(x, attn_mask, padding_mask)
            k_cache.append(k_cache_)
            v_cache.append(v_cache_)
        return x, k_cache, v_cache

    def decode_next_token(self, x: torch.Tensor, k_cache: list[torch.Tensor], v_cache: list[torch.Tensor]):
        for i in range(self.num_blocks):
            x, k_cache[i], v_cache[i] = self.blocks[i].decode_next_token(x, k_cache[i], v_cache[i])
        return x, k_cache, v_cache
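# process_prompt fills one [B, T_prompt, H*D] K/V tensor per block;
# decode_next_token then appends a single step to each cache per call.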
class VitsModel(nn.Module):
    def __init__(self, vits_path, version=None, is_half=True, device="cpu"):
        super().__init__()
        # dict_s2 = torch.load(vits_path, map_location="cpu")
        dict_s2 = load_sovits_new(vits_path)
        self.hps = dict_s2["config"]
        if version is None:
            if dict_s2["weight"]["enc_p.text_embedding.weight"].shape[0] == 322:
                self.hps["model"]["version"] = "v1"
            else:
                self.hps["model"]["version"] = "v2"
        else:
            if version in ["v1", "v2", "v3", "v4", "v2Pro", "v2ProPlus"]:
                self.hps["model"]["version"] = version
            else:
                raise ValueError(f"Unsupported version: {version}")
        self.hps = DictToAttrRecursive(self.hps)
        self.hps.model.semantic_frame_rate = "25hz"
        self.vq_model = SynthesizerTrn(
            self.hps.data.filter_length // 2 + 1,
            self.hps.train.segment_size // self.hps.data.hop_length,
            n_speakers=self.hps.data.n_speakers,
            **self.hps.model,
        )
        self.vq_model.load_state_dict(dict_s2["weight"], strict=False)
        self.vq_model.dec.remove_weight_norm()
        if is_half:
            self.vq_model = self.vq_model.half()
        self.vq_model = self.vq_model.to(device)
        self.vq_model.eval()
        self.hann_window = torch.hann_window(
            self.hps.data.win_length, device=device, dtype=torch.float16 if is_half else torch.float32
        )
    def forward(self, text_seq, pred_semantic, ref_audio, speed=1.0, sv_emb=None):
        refer = spectrogram_torch(
            self.hann_window,
            ref_audio,
            self.hps.data.filter_length,
            self.hps.data.sampling_rate,
            self.hps.data.hop_length,
            self.hps.data.win_length,
            center=False,
        )
        return self.vq_model(pred_semantic, text_seq, refer, speed=speed, sv_emb=sv_emb)[0, 0]
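# The [0, 0] indexing above drops the batch and channel dims of the
# SynthesizerTrn output, so VitsModel.forward returns a 1-D waveform.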
class T2SModel(nn.Module):
    def __init__(self, raw_t2s: Text2SemanticLightningModule):
        super(T2SModel, self).__init__()
        self.model_dim = raw_t2s.model.model_dim
        self.embedding_dim = raw_t2s.model.embedding_dim
        self.num_head = raw_t2s.model.num_head
        self.num_layers = raw_t2s.model.num_layers
        self.vocab_size = raw_t2s.model.vocab_size
        self.phoneme_vocab_size = raw_t2s.model.phoneme_vocab_size
        # self.p_dropout = float(raw_t2s.model.p_dropout)
        self.EOS: int = int(raw_t2s.model.EOS)
        self.norm_first = raw_t2s.model.norm_first
        assert self.EOS == self.vocab_size - 1
        self.hz = 50
        self.bert_proj = raw_t2s.model.bert_proj
        self.ar_text_embedding = raw_t2s.model.ar_text_embedding
        self.ar_text_position = raw_t2s.model.ar_text_position
        self.ar_audio_embedding = raw_t2s.model.ar_audio_embedding
        self.ar_audio_position = raw_t2s.model.ar_audio_position
        # self.t2s_transformer = T2STransformer(self.num_layers, blocks)
        # self.t2s_transformer = raw_t2s.model.t2s_transformer
        blocks = []
        h = raw_t2s.model.h
        for i in range(self.num_layers):
            layer = h.layers[i]
            t2smlp = T2SMLP(layer.linear1.weight, layer.linear1.bias, layer.linear2.weight, layer.linear2.bias)
            block = T2SBlock(
                self.num_head,
                self.model_dim,
                t2smlp,
                layer.self_attn.in_proj_weight,
                layer.self_attn.in_proj_bias,
                layer.self_attn.out_proj.weight,
                layer.self_attn.out_proj.bias,
                layer.norm1.weight,
                layer.norm1.bias,
                layer.norm1.eps,
                layer.norm2.weight,
                layer.norm2.bias,
                layer.norm2.eps,
            )
            blocks.append(block)
        self.t2s_transformer = T2STransformer(self.num_layers, blocks)
        # self.ar_predict_layer = nn.Linear(self.model_dim, self.vocab_size, bias=False)
        self.ar_predict_layer = raw_t2s.model.ar_predict_layer
        # self.loss_fct = nn.CrossEntropyLoss(reduction="sum")
        self.max_sec = raw_t2s.config["data"]["max_sec"]
        self.top_k = int(raw_t2s.config["inference"]["top_k"])
        self.early_stop_num = torch.LongTensor([self.hz * self.max_sec])
    def forward(
        self,
        prompts: LongTensor,
        ref_seq: LongTensor,
        text_seq: LongTensor,
        ref_bert: torch.Tensor,
        text_bert: torch.Tensor,
        top_k: LongTensor,
    ):
        bert = torch.cat([ref_bert.T, text_bert.T], 1)
        all_phoneme_ids = torch.cat([ref_seq, text_seq], 1)
        bert = bert.unsqueeze(0)
        x = self.ar_text_embedding(all_phoneme_ids)
        x = x + self.bert_proj(bert.transpose(1, 2))
        x: torch.Tensor = self.ar_text_position(x)
        early_stop_num = self.early_stop_num
        # [1, N, 512] [1, N]
        # y, k, v, y_emb, x_example = self.first_stage_decoder(x, prompts)
        y = prompts
        # x_example = x[:, :, 0] * 0.0
        x_len = x.shape[1]
        x_attn_mask = torch.zeros((x_len, x_len), dtype=torch.bool)
        y_emb = self.ar_audio_embedding(y)
        y_len = y_emb.shape[1]
        prefix_len = y.shape[1]
        y_pos = self.ar_audio_position(y_emb)
        xy_pos = torch.concat([x, y_pos], dim=1)
        bsz = x.shape[0]
        src_len = x_len + y_len
        x_attn_mask_pad = F.pad(
            x_attn_mask,
            (0, y_len),  # extend the all-False x->x block with an all-True x->y block: (x, x+y)
            value=True,
        )
        y_attn_mask = F.pad(  # pad the causal y->y mask on the left with False for y->x: (y, x+y)
            torch.triu(torch.ones(y_len, y_len, dtype=torch.bool), diagonal=1),
            (x_len, 0),
            value=False,
        )
        xy_attn_mask = (
            torch.concat([x_attn_mask_pad, y_attn_mask], dim=0)
            .unsqueeze(0)
            .expand(bsz * self.num_head, -1, -1)
            .view(bsz, self.num_head, src_len, src_len)
            .to(device=x.device, dtype=torch.bool)
        )
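        # Net effect of the combined mask: text positions attend only to text,
        # audio positions attend to all text plus earlier audio (causal).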
        idx = 0
        top_k = int(top_k)
        xy_dec, k_cache, v_cache = self.t2s_transformer.process_prompt(xy_pos, xy_attn_mask, None)
        logits = self.ar_predict_layer(xy_dec[:, -1])
        logits = logits[:, :-1]
        samples = sample(logits, y, top_k=top_k, top_p=1, repetition_penalty=1.35, temperature=1.0)[0]
        y = torch.concat([y, samples], dim=1)
        y_emb = self.ar_audio_embedding(y[:, -1:])
        xy_pos = y_emb * self.ar_audio_position.x_scale + self.ar_audio_position.alpha * self.ar_audio_position.pe[
            :, y_len + idx
        ].to(dtype=y_emb.dtype, device=y_emb.device)
        stop = False
        # for idx in range(1, 50):
        for idx in range(1, 1500):
            # [1, N] [N_layer, N, 1, 512] [N_layer, N, 1, 512] [1, N, 512] [1] [1, N, 512] [1, N]
            # y, k, v, y_emb, logits, samples = self.stage_decoder(y, k, v, y_emb, x_example)
            xy_dec, k_cache, v_cache = self.t2s_transformer.decode_next_token(xy_pos, k_cache, v_cache)
            logits = self.ar_predict_layer(xy_dec[:, -1])
            if idx < 11:  # mask EOS so at least 10 tokens (~0.4 s) are generated before stopping is allowed
                logits = logits[:, :-1]
            samples = sample(logits, y, top_k=top_k, top_p=1, repetition_penalty=1.35, temperature=1.0)[0]
            y = torch.concat([y, samples], dim=1)
            if early_stop_num != -1 and (y.shape[1] - prefix_len) > early_stop_num:
                stop = True
            if torch.argmax(logits, dim=-1)[0] == self.EOS or samples[0, 0] == self.EOS:
                stop = True
            if stop:
                if y.shape[1] == 0:
                    y = torch.concat([y, torch.zeros_like(samples)], dim=1)
                break
            y_emb = self.ar_audio_embedding(y[:, -1:])
            xy_pos = y_emb * self.ar_audio_position.x_scale + self.ar_audio_position.alpha * self.ar_audio_position.pe[
                :, y_len + idx
            ].to(dtype=y_emb.dtype, device=y_emb.device)
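        # Zero the final token (normally the EOS that triggered the stop) and
        # return only the newly generated semantic ids, shaped [1, 1, idx].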
        y[0, -1] = 0
        return y[:, -idx:].unsqueeze(0)
bert_path = os.environ.get("bert_path", "pretrained_models/chinese-roberta-wwm-ext-large")
cnhubert_base_path = "pretrained_models/chinese-hubert-base"
cnhubert.cnhubert_base_path = cnhubert_base_path
def build_phone_level_feature(res: Tensor, word2ph: IntTensor):
    phone_level_feature = []
    for i in range(word2ph.shape[0]):
        repeat_feature = res[i].repeat(word2ph[i].item(), 1)
        phone_level_feature.append(repeat_feature)
    phone_level_feature = torch.cat(phone_level_feature, dim=0)
    # [sum(word2ph), 1024]
    return phone_level_feature
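# Example (hypothetical numbers): res of shape [3, 1024] with word2ph = [2, 1, 2]
# yields a [5, 1024] tensor, repeating each word-level feature once per phoneme.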
class MyBertModel(torch.nn.Module):
    def __init__(self, bert_model):
        super(MyBertModel, self).__init__()
        self.bert = bert_model

    def forward(
        self, input_ids: torch.Tensor, attention_mask: torch.Tensor, token_type_ids: torch.Tensor, word2ph: IntTensor
    ):
        outputs = self.bert(input_ids=input_ids, attention_mask=attention_mask, token_type_ids=token_type_ids)
        # With torchscript=True the model returns a tuple, so hidden_states sits
        # at index 1 instead of the dict key used in the original:
        # res = torch.cat(outputs["hidden_states"][-3:-2], -1)[0][1:-1]
        res = torch.cat(outputs[1][-3:-2], -1)[0][1:-1]
        return build_phone_level_feature(res, word2ph)
class SSLModel(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.ssl = cnhubert.get_model().model

    def forward(self, ref_audio_16k) -> torch.Tensor:
        ssl_content = self.ssl(ref_audio_16k)["last_hidden_state"].transpose(1, 2)
        return ssl_content
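# Output is [B, C, T]: last_hidden_state transposed so channels come first
# (C is the HuBERT hidden size, 768 for the base model).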
class ExportSSLModel(torch.nn.Module):
    def __init__(self, ssl: SSLModel):
        super().__init__()
        self.ssl = ssl

    def forward(self, ref_audio: torch.Tensor):
        return self.ssl(ref_audio)

    def resample(self, ref_audio: torch.Tensor, src_sr: int, dst_sr: int) -> torch.Tensor:
        audio = resamplex(ref_audio, src_sr, dst_sr).float()
        return audio
def export_bert(output_path):
    tokenizer = AutoTokenizer.from_pretrained(bert_path)
    text = "叹息声一声接着一声传出,木兰对着房门织布.听不见织布机织布的声音,只听见木兰在叹息.问木兰在想什么?问木兰在惦记什么?木兰答道,我也没有在想什么,也没有在惦记什么."
    ref_bert_inputs = tokenizer(text, return_tensors="pt")
    word2ph = []
    for c in text:
        if c in ["，", "。", "？", "！", ",", ".", "?"]:
            word2ph.append(1)
        else:
            word2ph.append(2)
    ref_bert_inputs["word2ph"] = torch.Tensor(word2ph).int()
    bert_model = AutoModelForMaskedLM.from_pretrained(bert_path, output_hidden_states=True, torchscript=True)
    my_bert_model = MyBertModel(bert_model)
    ref_bert_inputs = {
        "input_ids": ref_bert_inputs["input_ids"],
        "attention_mask": ref_bert_inputs["attention_mask"],
        "token_type_ids": ref_bert_inputs["token_type_ids"],
        "word2ph": ref_bert_inputs["word2ph"],
    }
    torch._dynamo.mark_dynamic(ref_bert_inputs["input_ids"], 1)
    torch._dynamo.mark_dynamic(ref_bert_inputs["attention_mask"], 1)
    torch._dynamo.mark_dynamic(ref_bert_inputs["token_type_ids"], 1)
    torch._dynamo.mark_dynamic(ref_bert_inputs["word2ph"], 0)
    my_bert_model = torch.jit.trace(my_bert_model, example_kwarg_inputs=ref_bert_inputs)
    output_path = os.path.join(output_path, "bert_model.pt")
    my_bert_model.save(output_path)
    print("#### exported bert ####")
def export(gpt_path, vits_path, ref_audio_path, ref_text, output_path, export_bert_and_ssl=False, device="cpu"):
    if not os.path.exists(output_path):
        os.makedirs(output_path)
        print(f"Directory created: {output_path}")
    else:
        print(f"Directory already exists: {output_path}")
    ref_audio = torch.tensor([load_audio(ref_audio_path, 16000)]).float()
    ssl = SSLModel()
    if export_bert_and_ssl:
        s = ExportSSLModel(torch.jit.trace(ssl, example_inputs=(ref_audio,)))
        ssl_path = os.path.join(output_path, "ssl_model.pt")
        torch.jit.script(s).save(ssl_path)
        print("#### exported ssl ####")
        export_bert(output_path)
    else:
        s = ExportSSLModel(ssl)
    print(f"device: {device}")
    ref_seq_id, ref_bert_T, ref_norm_text = get_phones_and_bert(ref_text, "all_zh", "v2")
    ref_seq = torch.LongTensor([ref_seq_id]).to(device)
    ref_bert = ref_bert_T.T.to(ref_seq.device)
    text_seq_id, text_bert_T, norm_text = get_phones_and_bert(
        "这是一个简单的示例，真没想到这么简单就完成了。The King and His Stories.Once there was a king. He likes to write stories, but his stories were not good. As people were afraid of him, they all said his stories were good.After reading them, the writer at once turned to the soldiers and said: Take me back to prison, please.",
        "auto",
        "v2",
    )
    text_seq = torch.LongTensor([text_seq_id]).to(device)
    text_bert = text_bert_T.T.to(text_seq.device)
    ssl_content = ssl(ref_audio).to(device)
    # vits_path = "SoVITS_weights_v2/xw_e8_s216.pth"
    vits = VitsModel(vits_path, device=device, is_half=False)
    vits.eval()
    # gpt_path = "GPT_weights_v2/xw-e15.ckpt"
    # dict_s1 = torch.load(gpt_path, map_location=device)
    dict_s1 = torch.load(gpt_path, weights_only=False)
    raw_t2s = get_raw_t2s_model(dict_s1).to(device)
    print("#### get_raw_t2s_model ####")
    print(raw_t2s.config)
    t2s_m = T2SModel(raw_t2s)
    t2s_m.eval()
    t2s = torch.jit.script(t2s_m).to(device)
    print("#### script t2s_m ####")
    print("vits.hps.data.sampling_rate:", vits.hps.data.sampling_rate)
    gpt_sovits = GPT_SoVITS(t2s, vits).to(device)
    gpt_sovits.eval()
    ref_audio_sr = s.resample(ref_audio, 16000, 32000).to(device)
    torch._dynamo.mark_dynamic(ssl_content, 2)
    torch._dynamo.mark_dynamic(ref_audio_sr, 1)
    torch._dynamo.mark_dynamic(ref_seq, 1)
    torch._dynamo.mark_dynamic(text_seq, 1)
    torch._dynamo.mark_dynamic(ref_bert, 0)
    torch._dynamo.mark_dynamic(text_bert, 0)
    top_k = torch.LongTensor([5]).to(device)
    with torch.no_grad():
        gpt_sovits_export = torch.jit.trace(
            gpt_sovits, example_inputs=(ssl_content, ref_audio_sr, ref_seq, text_seq, ref_bert, text_bert, top_k)
        )
        gpt_sovits_path = os.path.join(output_path, "gpt_sovits_model.pt")
        gpt_sovits_export.save(gpt_sovits_path)
        print("#### exported gpt_sovits ####")
def export_prov2(
    gpt_path,
    vits_path,
    version,
    ref_audio_path,
    ref_text,
    output_path,
    export_bert_and_ssl=False,
    device="cpu",
    is_half=True,
):
    if sv_cn_model is None:
        init_sv_cn(device, is_half)
    if not os.path.exists(output_path):
        os.makedirs(output_path)
        print(f"Directory created: {output_path}")
    else:
        print(f"Directory already exists: {output_path}")
    ref_audio = torch.tensor([load_audio(ref_audio_path, 16000)]).float()
    ssl = SSLModel()
    if export_bert_and_ssl:
        s = ExportSSLModel(torch.jit.trace(ssl, example_inputs=(ref_audio,)))
        ssl_path = os.path.join(output_path, "ssl_model.pt")
        torch.jit.script(s).save(ssl_path)
        print("#### exported ssl ####")
        export_bert(output_path)
    else:
        s = ExportSSLModel(ssl)
| print(f"device: {device}") | |
| ref_seq_id, ref_bert_T, ref_norm_text = get_phones_and_bert(ref_text, "all_zh", "v2") | |
| ref_seq = torch.LongTensor([ref_seq_id]).to(device) | |
| ref_bert = ref_bert_T.T | |
| if is_half: | |
| ref_bert = ref_bert.half() | |
| ref_bert = ref_bert.to(ref_seq.device) | |
| text_seq_id, text_bert_T, norm_text = get_phones_and_bert( | |
| "่ฟๆฏไธไธช็ฎๅ็็คบไพ๏ผ็ๆฒกๆณๅฐ่ฟไน็ฎๅๅฐฑๅฎๆไบใThe King and His Stories.Once there was a king. He likes to write stories, but his stories were not good. As people were afraid of him, they all said his stories were good.After reading them, the writer at once turned to the soldiers and said: Take me back to prison, please.", | |
| "auto", | |
| "v2", | |
| ) | |
| text_seq = torch.LongTensor([text_seq_id]).to(device) | |
| text_bert = text_bert_T.T | |
| if is_half: | |
| text_bert = text_bert.half() | |
| text_bert = text_bert.to(text_seq.device) | |
| ssl_content = ssl(ref_audio) | |
| if is_half: | |
| ssl_content = ssl_content.half() | |
| ssl_content = ssl_content.to(device) | |
    sv_model = ExportERes2NetV2(sv_cn_model)
    # vits_path = "SoVITS_weights_v2/xw_e8_s216.pth"
    vits = VitsModel(vits_path, version, is_half=is_half, device=device)
    vits.eval()
    # gpt_path = "GPT_weights_v2/xw-e15.ckpt"
    # dict_s1 = torch.load(gpt_path, map_location=device)
    dict_s1 = torch.load(gpt_path, weights_only=False)
    raw_t2s = get_raw_t2s_model(dict_s1).to(device)
    print("#### get_raw_t2s_model ####")
    print(raw_t2s.config)
    if is_half:
        raw_t2s = raw_t2s.half()
    t2s_m = T2SModel(raw_t2s)
    t2s_m.eval()
    t2s = torch.jit.script(t2s_m).to(device)
    print("#### script t2s_m ####")
    print("vits.hps.data.sampling_rate:", vits.hps.data.sampling_rate)
    gpt_sovits = GPT_SoVITS_V2Pro(t2s, vits, sv_model).to(device)
    gpt_sovits.eval()
    ref_audio_sr = s.resample(ref_audio, 16000, 32000)
    if is_half:
        ref_audio_sr = ref_audio_sr.half()
    ref_audio_sr = ref_audio_sr.to(device)
    torch._dynamo.mark_dynamic(ssl_content, 2)
    torch._dynamo.mark_dynamic(ref_audio_sr, 1)
    torch._dynamo.mark_dynamic(ref_seq, 1)
    torch._dynamo.mark_dynamic(text_seq, 1)
    torch._dynamo.mark_dynamic(ref_bert, 0)
    torch._dynamo.mark_dynamic(text_bert, 0)
    # torch._dynamo.mark_dynamic(sv_emb, 0)
    top_k = torch.LongTensor([5]).to(device)
    # Run sv_model once beforehand so its fbank cache is populated; see the
    # comment in ExportERes2NetV2.forward for details.
    gpt_sovits.sv_model(ref_audio_sr)
    with torch.no_grad():
        gpt_sovits_export = torch.jit.trace(
            gpt_sovits,
            example_inputs=(
                ssl_content,
                ref_audio_sr,
                ref_seq,
                text_seq,
                ref_bert,
                text_bert,
                top_k,
            ),
        )
        gpt_sovits_path = os.path.join(output_path, "gpt_sovits_model.pt")
        gpt_sovits_export.save(gpt_sovits_path)
        print("#### exported gpt_sovits ####")
        audio = gpt_sovits_export(ssl_content, ref_audio_sr, ref_seq, text_seq, ref_bert, text_bert, top_k)
        print("start write wav")
        soundfile.write("out.wav", audio.float().detach().cpu().numpy(), 32000)
def parse_audio(ref_audio):
    # Assumes the incoming waveform is 48 kHz; yields 16 kHz and 32 kHz copies.
    ref_audio_16k = torchaudio.functional.resample(ref_audio, 48000, 16000).float()  # .to(ref_audio.device)
    ref_audio_sr = torchaudio.functional.resample(ref_audio, 48000, 32000).float()  # .to(ref_audio.device)
    return ref_audio_16k, ref_audio_sr
def resamplex(ref_audio: torch.Tensor, src_sr: int, dst_sr: int) -> torch.Tensor:
    return torchaudio.functional.resample(ref_audio, src_sr, dst_sr).float()
class GPT_SoVITS(nn.Module):
    def __init__(self, t2s: T2SModel, vits: VitsModel):
        super().__init__()
        self.t2s = t2s
        self.vits = vits

    def forward(
        self,
        ssl_content: torch.Tensor,
        ref_audio_sr: torch.Tensor,
        ref_seq: Tensor,
        text_seq: Tensor,
        ref_bert: Tensor,
        text_bert: Tensor,
        top_k: LongTensor,
        speed=1.0,
    ):
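        # Pipeline: quantize the reference SSL features into semantic codes,
        # continue them autoregressively with t2s, then vocode with VITS.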
        codes = self.vits.vq_model.extract_latent(ssl_content)
        prompt_semantic = codes[0, 0]
        prompts = prompt_semantic.unsqueeze(0)
        pred_semantic = self.t2s(prompts, ref_seq, text_seq, ref_bert, text_bert, top_k)
        audio = self.vits(text_seq, pred_semantic, ref_audio_sr, speed)
        return audio
class ExportERes2NetV2(nn.Module):
    def __init__(self, sv_cn_model: SV):
        super(ExportERes2NetV2, self).__init__()
        self.bn1 = sv_cn_model.embedding_model.bn1
        self.conv1 = sv_cn_model.embedding_model.conv1
        self.layer1 = sv_cn_model.embedding_model.layer1
        self.layer2 = sv_cn_model.embedding_model.layer2
        self.layer3 = sv_cn_model.embedding_model.layer3
        self.layer4 = sv_cn_model.embedding_model.layer4
        self.layer3_ds = sv_cn_model.embedding_model.layer3_ds
        self.fuse34 = sv_cn_model.embedding_model.fuse34

    # audio_16k.shape: [1, N]
    def forward(self, audio_16k):
        # The fbank function keeps a cache, but that is harmless here: the cache
        # depends only on device and dtype, not on the length of audio_16k.
        x = Kaldi.fbank(audio_16k, num_mel_bins=80, sample_frequency=16000, dither=0)
        x = torch.stack([x])
        x = x.permute(0, 2, 1)  # (B, T, F) => (B, F, T)
        x = x.unsqueeze_(1)
        out = F.relu(self.bn1(self.conv1(x)))
        out1 = self.layer1(out)
        out2 = self.layer2(out1)
        out3 = self.layer3(out2)
        out4 = self.layer4(out3)
        out3_ds = self.layer3_ds(out3)
        fuse_out34 = self.fuse34(out4, out3_ds)
        return fuse_out34.flatten(start_dim=1, end_dim=2).mean(-1)
class GPT_SoVITS_V2Pro(nn.Module):
    def __init__(self, t2s: T2SModel, vits: VitsModel, sv_model: ExportERes2NetV2):
        super().__init__()
        self.t2s = t2s
        self.vits = vits
        self.sv_model = sv_model

    def forward(
        self,
        ssl_content: torch.Tensor,
        ref_audio_sr: torch.Tensor,
        ref_seq: Tensor,
        text_seq: Tensor,
        ref_bert: Tensor,
        text_bert: Tensor,
        top_k: LongTensor,
        speed=1.0,
    ):
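        # Same pipeline as GPT_SoVITS, plus a speaker-verification embedding
        # computed from a 16 kHz resample of the reference audio.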
        codes = self.vits.vq_model.extract_latent(ssl_content)
        prompt_semantic = codes[0, 0]
        prompts = prompt_semantic.unsqueeze(0)
        audio_16k = resamplex(ref_audio_sr, 32000, 16000).to(ref_audio_sr.dtype)
        sv_emb = self.sv_model(audio_16k)
        pred_semantic = self.t2s(prompts, ref_seq, text_seq, ref_bert, text_bert, top_k)
        audio = self.vits(text_seq, pred_semantic, ref_audio_sr, speed, sv_emb)
        return audio
def test():
    parser = argparse.ArgumentParser(description="GPT-SoVITS Command Line Tool")
    parser.add_argument("--gpt_model", required=True, help="Path to the GPT model file")
    parser.add_argument("--sovits_model", required=True, help="Path to the SoVITS model file")
    parser.add_argument("--ref_audio", required=True, help="Path to the reference audio file")
    parser.add_argument("--ref_text", required=True, help="Transcript of the reference audio")
    parser.add_argument("--output_path", required=True, help="Path to the output directory")
    args = parser.parse_args()
    gpt_path = args.gpt_model
    vits_path = args.sovits_model
    ref_audio_path = args.ref_audio
    ref_text = args.ref_text
    tokenizer = AutoTokenizer.from_pretrained(bert_path)
    # bert_model = AutoModelForMaskedLM.from_pretrained(bert_path, output_hidden_states=True, torchscript=True)
    # bert = MyBertModel(bert_model)
    my_bert = torch.jit.load("onnx/bert_model.pt", map_location="cuda")
    # dict_s1 = torch.load(gpt_path, map_location="cuda")
    # raw_t2s = get_raw_t2s_model(dict_s1)
    # t2s = T2SModel(raw_t2s)
    # t2s.eval()
    # t2s = torch.jit.load("onnx/xw/t2s_model.pt", map_location="cuda")
    # vits_path = "SoVITS_weights_v2/xw_e8_s216.pth"
    # vits = VitsModel(vits_path)
    # vits.eval()
    # ssl = ExportSSLModel(SSLModel()).to("cuda")
    # ssl.eval()
    ssl = torch.jit.load("onnx/by/ssl_model.pt", map_location="cuda")
    # gpt_sovits = GPT_SoVITS(t2s, vits)
    gpt_sovits = torch.jit.load("onnx/by/gpt_sovits_model.pt", map_location="cuda")
    ref_seq_id, ref_bert_T, ref_norm_text = get_phones_and_bert(ref_text, "all_zh", "v2")
    ref_seq = torch.LongTensor([ref_seq_id])
    ref_bert = ref_bert_T.T.to(ref_seq.device)
    # text_seq_id, text_bert_T, norm_text = get_phones_and_bert("昨天晚上看见征兵文书,知道君主在大规模征兵,那么多卷征兵文册,每一卷上都有父亲的名字.", "all_zh", "v2")
    text = "昨天晚上看见征兵文书,知道君主在大规模征兵,那么多卷征兵文册,每一卷上都有父亲的名字."
    text_seq_id, text_bert_T, norm_text = get_phones_and_bert(text, "all_zh", "v2")
    test_bert = tokenizer(text, return_tensors="pt")
    word2ph = []
    for c in text:
        if c in ["，", "。", "？", "！", "?", ",", "."]:
            word2ph.append(1)
        else:
            word2ph.append(2)
    test_bert["word2ph"] = torch.Tensor(word2ph).int()
    test_bert = my_bert(
        test_bert["input_ids"].to("cuda"),
        test_bert["attention_mask"].to("cuda"),
        test_bert["token_type_ids"].to("cuda"),
        test_bert["word2ph"].to("cuda"),
    )
    text_seq = torch.LongTensor([text_seq_id])
    text_bert = text_bert_T.T.to(text_seq.device)
    print("text_bert:", text_bert.shape, text_bert)
    print("test_bert:", test_bert.shape, test_bert)
    print(torch.allclose(text_bert.to("cuda"), test_bert))
| print("text_seq:", text_seq.shape) | |
| print("text_bert:", text_bert.shape, text_bert.type()) | |
| # [1,N] | |
| ref_audio = torch.tensor([load_audio(ref_audio_path, 16000)]).float().to("cuda") | |
| print("ref_audio:", ref_audio.shape) | |
| ref_audio_sr = ssl.resample(ref_audio, 16000, 32000) | |
| print("start ssl") | |
| ssl_content = ssl(ref_audio) | |
| print("start gpt_sovits:") | |
| print("ssl_content:", ssl_content.shape) | |
| print("ref_audio_sr:", ref_audio_sr.shape) | |
| print("ref_seq:", ref_seq.shape) | |
| ref_seq = ref_seq.to("cuda") | |
| print("text_seq:", text_seq.shape) | |
| text_seq = text_seq.to("cuda") | |
| print("ref_bert:", ref_bert.shape) | |
| ref_bert = ref_bert.to("cuda") | |
| print("text_bert:", text_bert.shape) | |
| text_bert = text_bert.to("cuda") | |
| top_k = torch.LongTensor([5]).to("cuda") | |
| with torch.no_grad(): | |
| audio = gpt_sovits(ssl_content, ref_audio_sr, ref_seq, text_seq, ref_bert, test_bert, top_k) | |
| print("start write wav") | |
| soundfile.write("out.wav", audio.detach().cpu().numpy(), 32000) | |
import text
import json


def export_symbel(version="v2"):
    if version == "v1":
        symbols = text._symbol_to_id_v1
        with open("onnx/symbols_v1.json", "w") as file:
            json.dump(symbols, file, indent=4)
    else:
        symbols = text._symbol_to_id_v2
        with open("onnx/symbols_v2.json", "w") as file:
            json.dump(symbols, file, indent=4)
def main():
    parser = argparse.ArgumentParser(description="GPT-SoVITS Command Line Tool")
    parser.add_argument("--gpt_model", required=True, help="Path to the GPT model file")
    parser.add_argument("--sovits_model", required=True, help="Path to the SoVITS model file")
    parser.add_argument("--ref_audio", required=True, help="Path to the reference audio file")
    parser.add_argument("--ref_text", required=True, help="Transcript of the reference audio")
    parser.add_argument("--output_path", required=True, help="Path to the output directory")
    parser.add_argument("--export_common_model", action="store_true", help="Export the BERT and SSL models as well")
    parser.add_argument("--device", help="Device to use")
    parser.add_argument("--version", help="Version of the model", default="v2")
    parser.add_argument("--no-half", action="store_true", help="Do not use half precision for model weights")
    args = parser.parse_args()
    if args.version in ["v2Pro", "v2ProPlus"]:
        is_half = not args.no_half
        print(f"Using half precision: {is_half}")
        export_prov2(
            gpt_path=args.gpt_model,
            vits_path=args.sovits_model,
            version=args.version,
            ref_audio_path=args.ref_audio,
            ref_text=args.ref_text,
            output_path=args.output_path,
            export_bert_and_ssl=args.export_common_model,
            device=args.device,
            is_half=is_half,
        )
    else:
        export(
            gpt_path=args.gpt_model,
            vits_path=args.sovits_model,
            ref_audio_path=args.ref_audio,
            ref_text=args.ref_text,
            output_path=args.output_path,
            device=args.device,
            export_bert_and_ssl=args.export_common_model,
        )
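# Example invocation (script name and paths are illustrative, not fixed by the
# repo):
#   python export_torch_script.py --gpt_model GPT_weights_v2/foo.ckpt \
#       --sovits_model SoVITS_weights_v2/foo.pth --ref_audio ref.wav \
#       --ref_text "reference transcript" --output_path onnx/foo \
#       --export_common_model --device cuda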
if __name__ == "__main__":
    with torch.no_grad():
        main()
        # test()