| | """ |
| | @Desc: 2.0版本兼容 对应2.0.1 2.0.2-fix |
| | """ |
| | import torch |
| | import commons |
| | from .text import cleaned_text_to_sequence, get_bert |
| | from .text.cleaner import clean_text |
| |
|
| |
|
def get_text(text, language_str, hps, device):
    """Convert raw text into the phoneme, tone, language and BERT inputs of the 2.0 model.

    Args:
        text: Raw input text.
        language_str: Language code; must be "ZH", "JP" or "EN".
        hps: Hyper-parameter object; ``hps.data.add_blank`` selects whether a
            blank token (id 0) is interspersed into the sequences.
        device: Device forwarded to ``get_bert``.

    Returns:
        Tuple ``(bert, ja_bert, en_bert, phone, tone, language)``.
        Exactly one of the three BERT tensors holds the features returned by
        ``get_bert`` for ``language_str``. The other two are
        ``torch.zeros(1024, len(phone))`` placeholders.
        ``phone``, ``tone`` and ``language`` are ``torch.LongTensor`` sequences.

    Raises:
        ValueError: if ``language_str`` is not one of "ZH", "JP", "EN", or if the
            feature length from ``get_bert`` does not match ``len(phone)``.
    """
    # Order matters: it maps onto the (bert, ja_bert, en_bert) outputs below.
    supported = ("ZH", "JP", "EN")
    # Fail fast, before running text cleaning and BERT feature extraction.
    if language_str not in supported:
        raise ValueError("language_str should be ZH, JP or EN")

    norm_text, phone, tone, word2ph = clean_text(text, language_str)
    phone, tone, language = cleaned_text_to_sequence(phone, tone, language_str)

    if hps.data.add_blank:
        phone = commons.intersperse(phone, 0)
        tone = commons.intersperse(tone, 0)
        language = commons.intersperse(language, 0)
        # NOTE(review): assumes intersperse yields 2n+1 items (a blank between
        # and around every symbol). Doubling each word's phone count and
        # crediting the extra leading blank to the first word keeps word2ph in
        # step with the interspersed sequence. The length check below guards it.
        word2ph = [count * 2 for count in word2ph]
        word2ph[0] += 1

    bert_ori = get_bert(norm_text, word2ph, language_str, device)
    # Explicit check instead of `assert` so it still runs under `python -O`.
    if bert_ori.shape[-1] != len(phone):
        raise ValueError(f"Bert seq len {bert_ori.shape[-1]} != {len(phone)}")

    # The model takes one BERT input per language. Only the slot for the
    # current language carries real features; the others are zero-filled.
    bert, ja_bert, en_bert = (
        bert_ori if lang == language_str else torch.zeros(1024, len(phone))
        for lang in supported
    )

    phone = torch.LongTensor(phone)
    tone = torch.LongTensor(tone)
    language = torch.LongTensor(language)
    return bert, ja_bert, en_bert, phone, tone, language
| |
|
| |
|
def infer(
    text,
    sdp_ratio,
    noise_scale,
    noise_scale_w,
    length_scale,
    sid,
    language,
    hps,
    net_g,
    device,
):
    """Synthesize audio for ``text`` with the generator ``net_g``.

    Args:
        text: Input text to synthesize.
        sdp_ratio: Forwarded unchanged to ``net_g.infer``.
        noise_scale: Forwarded unchanged to ``net_g.infer``.
        noise_scale_w: Forwarded unchanged to ``net_g.infer``.
        length_scale: Forwarded unchanged to ``net_g.infer``.
        sid: Speaker key, mapped to an id through ``hps.data.spk2id``.
        language: Language code handed to ``get_text``.
        hps: Hyper-parameter object, passed to ``get_text``.
            ``hps.data.spk2id`` is read here.
        net_g: Generator model exposing an ``infer`` method.
        device: Device every model input is moved to.

    Returns:
        A float32 NumPy array: element ``[0][0, 0]`` of the ``net_g.infer``
        result, copied to the CPU. Presumably this is the waveform of the single
        batch item (confirm against the model).
    """
    bert, ja_bert, en_bert, phones, tones, lang_ids = get_text(
        text, language, hps, device
    )
    with torch.no_grad():
        # Move each per-utterance tensor to the device and prepend a batch
        # axis of size 1. The names are rebound so the CPU copies can be released.
        x_tst, tones, lang_ids, bert, ja_bert, en_bert = (
            tensor.to(device).unsqueeze(0)
            for tensor in (phones, tones, lang_ids, bert, ja_bert, en_bert)
        )
        x_tst_lengths = torch.LongTensor([phones.size(0)]).to(device)
        del phones
        speakers = torch.LongTensor([hps.data.spk2id[sid]]).to(device)

        waveform = net_g.infer(
            x_tst,
            x_tst_lengths,
            speakers,
            tones,
            lang_ids,
            bert,
            ja_bert,
            en_bert,
            sdp_ratio=sdp_ratio,
            noise_scale=noise_scale,
            noise_scale_w=noise_scale_w,
            length_scale=length_scale,
        )[0][0, 0]
        audio = waveform.data.cpu().float().numpy()

        # Drop every device-side reference so empty_cache can actually return
        # the memory to the allocator.
        del waveform
        del x_tst, tones, lang_ids, bert, x_tst_lengths, speakers, ja_bert, en_bert
        if torch.cuda.is_available():
            torch.cuda.empty_cache()
        return audio
| |
|