| import os
|
|
|
| import torch
|
| import numpy as np
|
| from modules.hifigan.hifigan import HifiGanGenerator
|
| from vocoders.hifigan import HifiGAN
|
| from inference.svs.opencpop.map import cpop_pinyin2ph_func
|
|
|
| from utils import load_ckpt
|
| from utils.hparams import set_hparams, hparams
|
| from utils.text_encoder import TokenTextEncoder
|
| from pypinyin import pinyin, lazy_pinyin, Style
|
| import librosa
|
| import glob
|
| import re
|
|
|
|
|
class BaseSVSInfer:
    """Base class for singing-voice-synthesis inference.

    The constructor builds a phoneme encoder, a pinyin-to-phoneme table, a
    speaker map, the acoustic model and a HifiGAN vocoder.  Subclasses must
    implement :meth:`build_model` and :meth:`forward_model`.

    Configuration is read from ``self.hparams``, i.e. the mapping handed to
    the constructor (``example_run`` passes the module-level ``hparams``).
    """

    def __init__(self, hparams, device=None):
        """Build the text front-end, the acoustic model and the vocoder.

        :param hparams: configuration mapping; stored as ``self.hparams``.
        :param device: torch device; defaults to ``'cuda'`` when available,
            otherwise ``'cpu'``.
        """
        if device is None:
            device = 'cuda' if torch.cuda.is_available() else 'cpu'
        self.hparams = hparams
        self.device = device

        phone_list = ["AP", "SP", "a", "ai", "an", "ang", "ao", "b", "c", "ch", "d", "e", "ei", "en", "eng", "er", "f", "g",
                      "h", "i", "ia", "ian", "iang", "iao", "ie", "in", "ing", "iong", "iu", "j", "k", "l", "m", "n", "o",
                      "ong", "ou", "p", "q", "r", "s", "sh", "t", "u", "ua", "uai", "uan", "uang", "ui", "un", "uo", "v",
                      "van", "ve", "vn", "w", "x", "y", "z", "zh"]
        self.ph_encoder = TokenTextEncoder(None, vocab_list=phone_list, replace_oov=',')
        self.pinyin2phs = cpop_pinyin2ph_func()
        self.spk_map = {'opencpop': 0}

        self.model = self.build_model()
        self.model.eval()
        self.model.to(self.device)
        # build_vocoder() may be overridden, so eval()/to() are applied here
        # as well instead of relying on the override to do it.
        self.vocoder = self.build_vocoder()
        self.vocoder.eval()
        self.vocoder.to(self.device)

    def build_model(self):
        """Return the acoustic model; must be implemented by subclasses."""
        raise NotImplementedError

    def forward_model(self, inp):
        """Run the model on a preprocessed item; must be implemented by subclasses."""
        raise NotImplementedError

    @staticmethod
    def _latest_checkpoint(base_dir):
        """Return the ``model_ckpt_steps_<N>.ckpt`` file in *base_dir* with the largest N.

        Files matched by the glob whose step part is not purely numeric are
        ignored.

        :raises FileNotFoundError: if no numbered checkpoint exists.
        """
        # Escape the directory so glob metacharacters in it are taken literally.
        candidates = glob.glob(f'{glob.escape(str(base_dir))}/model_ckpt_steps_*.ckpt')
        step_by_path = {}
        for path in candidates:
            # Match on the basename only: independent of the path separator
            # and of regex metacharacters in the directory name.
            found = re.fullmatch(r'model_ckpt_steps_(\d+)\.ckpt', os.path.basename(path))
            if found:
                step_by_path[path] = int(found.group(1))
        if not step_by_path:
            raise FileNotFoundError(
                f'No model_ckpt_steps_*.ckpt checkpoint found in {base_dir!r}')
        return max(step_by_path, key=step_by_path.get)

    def build_vocoder(self):
        """Load the newest HifiGAN generator checkpoint from ``hparams['vocoder_ckpt']``.

        The directory must contain ``config.yaml`` and at least one
        ``model_ckpt_steps_<N>.ckpt`` file.

        :return: the generator with weight norm removed, in eval mode, on
            ``self.device``.
        :raises FileNotFoundError: if no numbered checkpoint exists.
        """
        base_dir = self.hparams['vocoder_ckpt']
        config_path = f'{base_dir}/config.yaml'
        ckpt = self._latest_checkpoint(base_dir)
        print('| load HifiGAN: ', ckpt)
        ckpt_dict = torch.load(ckpt, map_location="cpu")
        config = set_hparams(config_path, global_hparams=False)
        state = ckpt_dict["state_dict"]["model_gen"]
        vocoder = HifiGanGenerator(config)
        vocoder.load_state_dict(state, strict=True)
        vocoder.remove_weight_norm()
        vocoder = vocoder.eval().to(self.device)
        return vocoder

    def run_vocoder(self, c, **kwargs):
        """Run the vocoder on *c* and return the flattened output with a leading axis.

        :param c: 3-D tensor; axes 1 and 2 are swapped before the vocoder
            call (presumably mel frames as (batch, frames, bins) -- confirm
            against the caller).
        :param kwargs: optional ``f0``; it is forwarded to the vocoder only
            when ``hparams['use_nsf']`` is truthy.
        :return: tensor of shape ``(1, num_samples)``.
        """
        c = c.transpose(2, 1)
        f0 = kwargs.get('f0')
        if f0 is not None and self.hparams.get('use_nsf'):
            y = self.vocoder(c, f0).view(-1)
        else:
            y = self.vocoder(c).view(-1)
        return y[None]

    @staticmethod
    def _expand_words_to_phonemes(ph_per_word_lst, note_per_word_lst, mididur_per_word_lst):
        """Expand word-aligned phonemes/notes/durations to phoneme-aligned lists.

        Every phoneme of a word gets the word's first note and duration with
        slur flag 0.  Each further note of the word repeats the word's last
        phoneme with slur flag 1.

        :return: ``(ph_seq, note_lst, midi_dur_lst, is_slur)``, or ``None``
            when a word has a different number of notes and durations.
        """
        note_lst = []
        ph_lst = []
        midi_dur_lst = []
        is_slur = []
        for word_idx, ph_per_word in enumerate(ph_per_word_lst):
            ph_in_this_word = ph_per_word.split()
            note_in_this_word = note_per_word_lst[word_idx].split()
            midi_dur_in_this_word = mididur_per_word_lst[word_idx].split()

            # Reject instead of crashing with IndexError on a short duration list.
            if len(note_in_this_word) != len(midi_dur_in_this_word):
                print(f'Word {word_idx} has {len(note_in_this_word)} note(s) but '
                      f'{len(midi_dur_in_this_word)} duration(s).')
                return None

            for ph in ph_in_this_word:
                ph_lst.append(ph)
                note_lst.append(note_in_this_word[0])
                midi_dur_lst.append(midi_dur_in_this_word[0])
                is_slur.append(0)

            # Extra notes in the same word are slurs on the last phoneme.
            for note, dur in zip(note_in_this_word[1:], midi_dur_in_this_word[1:]):
                ph_lst.append(ph_in_this_word[-1])
                note_lst.append(note)
                midi_dur_lst.append(dur)
                is_slur.append(1)

        ph_seq = ' '.join(ph_lst)
        # The four lists are appended in lockstep, so their lengths always agree.
        print(len(ph_lst), len(note_lst), len(midi_dur_lst))
        print('Pass word-notes check.')
        return ph_seq, note_lst, midi_dur_lst, is_slur

    def preprocess_word_level_input(self, inp):
        """Convert word-level input to phoneme-aligned sequences.

        :param inp: dict with ``'text'`` (lyrics), ``'notes'`` and
            ``'notes_duration'``; the latter two are ``|``-separated per
            word, with whitespace-separated entries inside a word.
        :return: ``(ph_seq, note_lst, midi_dur_lst, is_slur)`` or ``None``
            when the word, note and duration counts do not line up.
        """
        # Character substitution before pinyin conversion; presumably this
        # forces the intended reading of a polyphonic character -- confirm.
        text_raw = inp['text'].replace('最长', '最常').replace('长睫毛', '常睫毛') \
            .replace('那么长', '那么常').replace('多长', '多常') \
            .replace('很长', '很常')

        # Syllables missing from the pinyin-to-phoneme table are dropped here;
        # the count check below then reports the resulting mismatch.
        syllables = (syllable.strip() for syllable in lazy_pinyin(text_raw, strict=False))
        ph_per_word_lst = [self.pinyin2phs[syllable] for syllable in syllables
                           if syllable in self.pinyin2phs]

        note_per_word_lst = [x.strip() for x in inp['notes'].split('|') if x.strip() != '']
        mididur_per_word_lst = [x.strip() for x in inp['notes_duration'].split('|') if x.strip() != '']

        if len(note_per_word_lst) == len(ph_per_word_lst) == len(mididur_per_word_lst):
            print('Pass word-notes check.')
        else:
            print('The number of words does\'t match the number of notes\' windows. ',
                  'You should split the note(s) for each word by | mark.')
            print(ph_per_word_lst, note_per_word_lst, mididur_per_word_lst)
            print(len(ph_per_word_lst), len(note_per_word_lst), len(mididur_per_word_lst))
            return None

        return self._expand_words_to_phonemes(ph_per_word_lst, note_per_word_lst, mididur_per_word_lst)

    def preprocess_phoneme_level_input(self, inp):
        """Parse already phoneme-aligned input.

        :param inp: dict with whitespace-separated ``'ph_seq'``,
            ``'note_seq'``, ``'note_dur_seq'`` and ``'is_slur_seq'``.
        :return: ``(ph_seq, note_lst, midi_dur_lst, is_slur)`` or ``None``
            when the four sequences differ in length.
        """
        ph_seq = inp['ph_seq']
        note_lst = inp['note_seq'].split()
        midi_dur_lst = inp['note_dur_seq'].split()
        is_slur = [float(x) for x in inp['is_slur_seq'].split()]
        print(len(note_lst), len(ph_seq.split()), len(midi_dur_lst))
        # The slur flags must be aligned too, so they take part in the check.
        if len(note_lst) == len(ph_seq.split()) == len(midi_dur_lst) == len(is_slur):
            print('Pass word-notes check.')
        else:
            print('The number of words does\'t match the number of notes\' windows. ',
                  'You should split the note(s) for each word by | mark.')
            return None
        return ph_seq, note_lst, midi_dur_lst, is_slur

    def preprocess_input(self, inp, input_type='word'):
        """Turn a raw request into the item dict consumed by :meth:`input_to_batch`.

        :param inp: dict with ``'text'`` plus the keys required by the chosen
            input type; ``'item_name'`` and ``'spk_name'`` are optional.
        :param input_type: ``'word'`` or ``'phoneme'``.
        :return: dict with ``item_name``, ``text``, ``ph``, ``spk_id``,
            ``ph_token``, ``pitch_midi``, ``midi_dur``, ``is_slur`` and
            ``ph_len``; ``None`` when the input is invalid.
        :raises KeyError: if ``spk_name`` is not in ``self.spk_map``.
        """
        item_name = inp.get('item_name', '<ITEM_NAME>')
        spk_name = inp.get('spk_name', 'opencpop')
        spk_id = self.spk_map[spk_name]

        if input_type == 'word':
            ret = self.preprocess_word_level_input(inp)
        elif input_type == 'phoneme':
            ret = self.preprocess_phoneme_level_input(inp)
        else:
            print('Invalid input type.')
            return None

        if ret is None:
            print('==========> Preprocess_word_level or phone_level input wrong.')
            return None
        ph_seq, note_lst, midi_dur_lst, is_slur = ret

        # Note names and durations come straight from the user; anything
        # that fails to parse is reported and rejected.
        try:
            # 'rest' maps to MIDI 0; for "A/B" spellings the first one is used.
            midis = [librosa.note_to_midi(x.split("/")[0]) if x != 'rest' else 0
                     for x in note_lst]
            midi_dur_lst = [float(x) for x in midi_dur_lst]
        except Exception as e:
            print(e)
            print('Invalid Input Type.')
            return None

        ph_token = self.ph_encoder.encode(ph_seq)
        item = {'item_name': item_name, 'text': inp['text'], 'ph': ph_seq, 'spk_id': spk_id,
                'ph_token': ph_token, 'pitch_midi': np.asarray(midis), 'midi_dur': np.asarray(midi_dur_lst),
                'is_slur': np.asarray(is_slur), }
        item['ph_len'] = len(item['ph_token'])
        return item

    def input_to_batch(self, item):
        """Wrap a single preprocessed item into a batch of size one on ``self.device``.

        Sequence fields get a leading batch axis, giving shape ``(1, T)``.
        ``pitch_midi``, ``midi_dur`` and ``is_slur`` are cut to
        ``hparams['max_frames']`` entries.  Per-item scalars
        (``txt_lengths``, ``spk_ids``) have shape ``(1,)``.
        """
        item_names = [item['item_name']]
        text = [item['text']]
        ph = [item['ph']]
        txt_tokens = torch.LongTensor(item['ph_token'])[None, :].to(self.device)
        txt_lengths = torch.LongTensor([txt_tokens.shape[1]]).to(self.device)
        # The id must be wrapped in a list: torch.LongTensor(<int>) treats the
        # int as a size and returns an uninitialised tensor of that length.
        spk_ids = torch.LongTensor([item['spk_id']]).to(self.device)

        max_frames = self.hparams['max_frames']
        pitch_midi = torch.LongTensor(item['pitch_midi'])[None, :max_frames].to(self.device)
        midi_dur = torch.FloatTensor(item['midi_dur'])[None, :max_frames].to(self.device)
        is_slur = torch.LongTensor(item['is_slur'])[None, :max_frames].to(self.device)

        batch = {
            'item_name': item_names,
            'text': text,
            'ph': ph,
            'txt_tokens': txt_tokens,
            'txt_lengths': txt_lengths,
            'spk_ids': spk_ids,
            'pitch_midi': pitch_midi,
            'midi_dur': midi_dur,
            'is_slur': is_slur
        }
        return batch

    def postprocess_output(self, output):
        """Hook for subclasses; the base implementation returns *output* unchanged."""
        return output

    def infer_once(self, inp):
        """Preprocess *inp*, run the model and post-process the result.

        ``inp['input_type']`` selects ``'word'`` or ``'phoneme'`` input and
        defaults to ``'word'`` when missing or empty.

        :raises ValueError: if preprocessing rejects the input.
        """
        item = self.preprocess_input(inp, input_type=inp.get('input_type') or 'word')
        if item is None:
            # Fail here with a clear message rather than inside forward_model.
            raise ValueError('Input preprocessing failed; see the messages printed above.')
        output = self.forward_model(item)
        output = self.postprocess_output(output)
        return output

    @classmethod
    def example_run(cls, inp):
        """Load the global hparams, synthesize *inp* and write ``infer_out/example_out.wav``."""
        from utils.audio import save_wav
        set_hparams(print_hparams=False)
        infer_ins = cls(hparams)
        out = infer_ins.infer_once(inp)
        os.makedirs('infer_out', exist_ok=True)
        save_wav(out, 'infer_out/example_out.wav', hparams['audio_sample_rate'])
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| |