Spaces:
Build error
Build error
| import hashlib | |
| import json | |
| import os | |
| import time | |
| import traceback | |
| import warnings | |
| from pathlib import Path | |
| import numpy as np | |
| import parselmouth | |
| import resampy | |
| import torch | |
| import torchcrepe | |
| import utils | |
| from modules.vocoders.nsf_hifigan import nsf_hifigan | |
| from utils.hparams import hparams | |
| from utils.pitch_utils import f0_to_coarse | |
# Silence every warning category process-wide; this applies to all modules
# imported after this point as well, not just this file.
warnings.filterwarnings(action="ignore")
class BinarizationError(Exception):
    """Raised when an item cannot be turned into usable features (e.g. its f0 is empty)."""
def get_md5(content):
    """Return the hexadecimal MD5 digest of ``content`` (a bytes-like object)."""
    digest = hashlib.new("md5")
    digest.update(content)
    return digest.hexdigest()
def read_temp(file_name, max_size=50 * 1024 * 1024, max_age=14 * 24 * 3600):
    """Load the JSON cache stored at ``file_name``.

    If the file does not exist it is created containing only the
    ``{"info": "temp_dict"}`` marker and an empty dict is returned.
    If the file is larger than ``max_size`` bytes, entries whose ``"time"``
    stamp is older than ``max_age`` seconds are dropped from the returned
    dict (the file itself is not rewritten here).

    :param file_name: path of the JSON cache file
    :param max_size: file size in bytes above which stale entries are pruned
    :param max_age: maximum age in seconds of an entry that survives pruning
    :return: the cache dict; ``{"info": "temp_dict"}`` if the file could not
        be read or did not contain a JSON object
    """
    if not os.path.exists(file_name):
        with open(file_name, "w") as f:
            f.write(json.dumps({"info": "temp_dict"}))
        return {}

    try:
        with open(file_name, "r") as f:
            data_dict = json.loads(f.read())
        if not isinstance(data_dict, dict):
            # Callers index the cache by key, so any other JSON root is unusable.
            raise ValueError("cache root is not a JSON object")
        if os.path.getsize(file_name) > max_size:
            print(f"clean {os.path.basename(file_name)}")
            now = int(time.time())
            for key in list(data_dict):
                entry = data_dict[key]
                if not isinstance(entry, dict):
                    # Metadata such as the "info" marker carries no timestamp;
                    # keep it instead of failing the whole cleanup on it.
                    continue
                try:
                    expired = now - int(entry["time"]) > max_age
                except (KeyError, TypeError, ValueError):
                    # An entry whose age cannot be determined is treated as stale.
                    expired = True
                if expired:
                    del data_dict[key]
    except Exception as e:
        # Deliberate best effort: an unreadable or corrupt cache must never
        # stop the program, so fall back to a fresh cache.
        print(e)
        print(f"{file_name} error,auto rebuild file")
        data_dict = {"info": "temp_dict"}
    return data_dict
def write_temp(file_name, data):
    """Serialize ``data`` as JSON and overwrite ``file_name`` with it.

    :param file_name: path of the JSON cache file
    :param data: JSON-serializable object to store
    :raises TypeError: if ``data`` is not JSON-serializable; the existing
        file is left untouched in that case
    """
    # Serialize before opening: open(..., "w") truncates immediately, so a
    # serialization error raised afterwards would wipe the existing cache.
    payload = json.dumps(data)
    with open(file_name, "w") as f:
        f.write(payload)
# Module-level f0 cache, loaded once at import time. It is keyed by the MD5 of
# the waveform and is consulted/updated by File2Batch during inference.
f0_dict = read_temp(file_name="./infer_tools/f0_temp.json")
def get_pitch_parselmouth(wav_data, mel, hparams):
    """
    Extract f0 with Praat's autocorrelation pitch tracker and pad it to the mel length.

    :param wav_data: [T] waveform samples
    :param mel: [T, 80] spectrogram; only its length is used here
    :param hparams: config providing 'hop_size', 'audio_sample_rate', 'f0_min', 'f0_max'
    :return: (f0, pitch_coarse) where pitch_coarse is ``f0_to_coarse(f0, hparams)``
    """
    hop_size = hparams['hop_size']
    sample_rate = hparams['audio_sample_rate']

    # One pitch estimate per hop, expressed in seconds.
    sound = parselmouth.Sound(wav_data, sample_rate)
    pitch = sound.to_pitch_ac(
        time_step=hop_size / sample_rate,
        voicing_threshold=0.6,
        pitch_floor=hparams['f0_min'],
        pitch_ceiling=hparams['f0_max'],
    )
    f0 = pitch.selected_array['frequency']

    # The tracker returns fewer frames than the mel has; zero-pad both ends so
    # that the result has exactly len(mel) frames.
    n_tracked = len(f0)
    lead = (int(len(wav_data) // hop_size) - n_tracked + 1) // 2
    trail = len(mel) - n_tracked - lead
    f0 = np.pad(f0, [[lead, trail]], mode='constant')

    return f0, f0_to_coarse(f0, hparams)
def get_pitch_crepe(wav_data, mel, hparams, threshold=0.05):
    """
    Estimate f0 with torchcrepe and interpolate it onto the mel frame grid.

    :param wav_data: waveform samples at ``hparams['audio_sample_rate']``
    :param mel: spectrogram; only ``len(mel)`` (the number of frames) is used
    :param hparams: config providing 'audio_sample_rate', 'hop_size', 'f0_min', 'f0_max'
    :param threshold: periodicity threshold below which a frame counts as unvoiced
    :return: (f0, pitch_coarse); f0 is a NumPy array with ``len(mel)`` entries and
        pitch_coarse is ``f0_to_coarse(f0, hparams)``
    """
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # CREPE only supports a 16 kHz sample rate, so the audio is resampled first.
    crepe_sr = 16000
    # With hop size 80 at 16 kHz, f0 is analysed in 5 ms frames.
    crepe_hop = 80
    wav16k = resampy.resample(wav_data, hparams['audio_sample_rate'], crepe_sr)
    wav16k_torch = torch.FloatTensor(wav16k).unsqueeze(0).to(device)

    # Frequency search range.
    f0_min = hparams['f0_min']
    f0_max = hparams['f0_max']

    f0, periodicity = torchcrepe.predict(
        wav16k_torch, crepe_sr, crepe_hop, f0_min, f0_max, pad=True, model='full',
        batch_size=1024, device=device, return_periodicity=True)

    # Filter, remove silence and apply the unvoiced threshold
    # (following the torchcrepe README).
    periodicity = torchcrepe.filter.median(periodicity, 3)
    periodicity = torchcrepe.threshold.Silence(-60.)(periodicity, wav16k_torch, crepe_sr, crepe_hop)
    f0 = torchcrepe.threshold.At(threshold)(f0, periodicity)
    f0 = torchcrepe.filter.mean(f0, 3)

    # Turn NaN frequencies (the unvoiced part) into 0.
    f0 = torch.where(torch.isnan(f0), torch.full_like(f0, 0), f0)

    # Keep only the voiced (non-zero) frames; they are the interpolation knots.
    # squeeze(-1) rather than squeeze(): with exactly one voiced frame a bare
    # squeeze() would produce a 0-dim index, which np.interp cannot use.
    voiced_idx = torch.nonzero(f0[0]).squeeze(-1)
    voiced_f0 = torch.index_select(f0[0], dim=0, index=voiced_idx).cpu().numpy()
    time_org = (crepe_hop / crepe_sr) * voiced_idx.cpu().numpy()
    time_frame = np.arange(len(mel)) * hparams['hop_size'] / hparams['audio_sample_rate']

    if voiced_f0.shape[0] == 0:
        # Nothing voiced: return zeros as a NumPy array so that the return
        # type is the same as on the interpolated path.
        f0 = np.zeros(time_frame.shape[0])
        print('f0 all zero!')
    else:
        # Linear interpolation onto the mel frame times; frames outside the
        # voiced range take the first / last voiced value.
        f0 = np.interp(time_frame, time_org, voiced_f0, left=voiced_f0[0], right=voiced_f0[-1])

    pitch_coarse = f0_to_coarse(f0, hparams)
    return f0, pitch_coarse
class File2Batch:
    """
    pipeline: file -> temporary_dict -> processed_input -> batch

    The class holds no state; every stage is a static method, so each one can
    be called either on the class or on an instance.
    """

    @staticmethod
    def file2temporary_dict(raw_data_dir, ds_id):
        """
        Read from file, store data in temporary dicts.

        Recursively collects every ``*.wav`` and then every ``*.ogg`` file
        below ``raw_data_dir``.

        :param raw_data_dir: root directory to scan
        :param ds_id: value stored as 'spk_id' on every item
        :return: dict mapping the file path (as str) to
            ``{'wav_fn': <path>, 'spk_id': ds_id}``
        """
        raw_data_dir = Path(raw_data_dir)
        all_temp_dict = {}
        for pattern in ("*.wav", "*.ogg"):
            for audio_path in raw_data_dir.rglob(pattern):
                item_name = str(audio_path)
                all_temp_dict[item_name] = {'wav_fn': item_name, 'spk_id': ds_id}
        return all_temp_dict

    @staticmethod
    def temporary_dict2processed_input(item_name, temp_dict, encoder, infer=False, **kwargs):
        """
        Process data in temporary_dicts.

        :param item_name: identifier stored under 'item_name'
        :param temp_dict: dict holding at least 'wav_fn'
        :param encoder: object whose ``encode(wav_fn)`` returns the hubert features
        :param infer: when True, 'use_crepe' is read from ``kwargs`` instead of
            hparams and CREPE results are cached in the module-level ``f0_dict``
        :param kwargs: must contain 'use_crepe' when ``infer`` is True
        :return: the processed_input dict, or None if pitch extraction, encoding
            or alignment failed (the reason is printed)
        """

        def get_pitch(wav, mel):
            # Return (f0, coarse pitch) from CREPE or parselmouth.
            use_crepe = kwargs['use_crepe'] if infer else hparams['use_crepe']
            if use_crepe:
                # The hash is only a cache key, and the cache is only used
                # during inference.
                md5 = get_md5(wav) if infer else None
                if infer and md5 in f0_dict:
                    print("load temp crepe f0")
                    gt_f0 = np.array(f0_dict[md5]["f0"])
                    coarse_f0 = np.array(f0_dict[md5]["coarse"])
                else:
                    if torch.cuda.is_available():
                        torch.cuda.empty_cache()
                    gt_f0, coarse_f0 = get_pitch_crepe(wav, mel, hparams, threshold=0.05)
                if infer:
                    # Stored on cache hits too, which refreshes the entry's timestamp.
                    f0_dict[md5] = {"f0": gt_f0.tolist(), "coarse": coarse_f0.tolist(),
                                    "time": int(time.time())}
                    write_temp("./infer_tools/f0_temp.json", f0_dict)
            else:
                gt_f0, coarse_f0 = get_pitch_parselmouth(wav, mel, hparams)
            if sum(gt_f0) == 0:
                raise BinarizationError("Empty **gt** f0")
            return gt_f0, coarse_f0

        def get_align(mel, phone_encoded):
            # Spread the encoder frames evenly over the mel frames; mel2ph[t]
            # is the 1-based index of the encoder frame assigned to mel frame t.
            n_frames = mel.shape[0]
            n_units = phone_encoded.shape[0]
            mel2ph = np.zeros([n_frames], int)
            ph_durs = n_frames / n_units
            start_frame = 0
            for i_ph in range(n_units):
                end_frame = int(i_ph * ph_durs + ph_durs + 0.5)
                mel2ph[start_frame:end_frame + 1] = i_ph + 1
                start_frame = end_frame + 1
            return mel2ph

        wav, mel = nsf_hifigan.wav2spec(temp_dict['wav_fn'])
        processed_input = {
            **temp_dict,
            'item_name': item_name,
            'mel': mel,
            'sec': len(wav) / hparams['audio_sample_rate'],
            'len': mel.shape[0],
            'spec_min': np.min(mel, axis=0),
            'spec_max': np.max(mel, axis=0),
        }

        try:
            processed_input['f0'], processed_input['pitch'] = get_pitch(wav, mel)
            try:
                hubert_encoded = processed_input['hubert'] = encoder.encode(temp_dict['wav_fn'])
            except Exception as err:
                # `except Exception` (not a bare except) so that
                # KeyboardInterrupt / SystemExit are not swallowed here.
                traceback.print_exc()
                raise Exception("hubert encode error") from err
            processed_input['mel2ph'] = get_align(mel, hubert_encoded)
        except Exception as e:
            print(f"| Skip item ({e}). item_name: {item_name}, wav_fn: {temp_dict['wav_fn']}")
            return None

        if hparams['use_energy_embed']:
            max_frames = hparams['max_frames']
            spec = torch.Tensor(processed_input['mel'])[:max_frames]
            # Per-frame L2 norm of exp(mel).
            processed_input['energy'] = (spec.exp() ** 2).sum(-1).sqrt()
        return processed_input

    @staticmethod
    def processed_input2batch(samples):
        """
        Collate a list of processed_input dicts into one batch.

        Args:
            samples: one batch of processed_input
        NOTE:
            the batch size is controlled by hparams['max_sentences']
        Returns:
            ``{}`` for an empty ``samples``; otherwise a dict with 'id',
            'item_name', 'nsamples', 'hubert', 'mels', 'mel_lengths', 'mel2ph',
            'pitch', 'f0', 'uv', plus 'energy' / 'spk_ids' when the matching
            hparams flags are set.
        """
        if not samples:
            return {}
        ids = torch.LongTensor([s['id'] for s in samples])
        item_names = [s['item_name'] for s in samples]
        hubert = utils.collate_2d([s['hubert'] for s in samples], 0.0)
        f0 = utils.collate_1d([s['f0'] for s in samples], 0.0)
        pitch = utils.collate_1d([s['pitch'] for s in samples])
        uv = utils.collate_1d([s['uv'] for s in samples])
        # Whether mel2ph is present is decided from the first sample only.
        if samples[0]['mel2ph'] is not None:
            mel2ph = utils.collate_1d([s['mel2ph'] for s in samples], 0.0)
        else:
            mel2ph = None
        mels = utils.collate_2d([s['mel'] for s in samples], 0.0)
        mel_lengths = torch.LongTensor([s['mel'].shape[0] for s in samples])

        batch = {
            'id': ids,
            'item_name': item_names,
            'nsamples': len(samples),
            'hubert': hubert,
            'mels': mels,
            'mel_lengths': mel_lengths,
            'mel2ph': mel2ph,
            'pitch': pitch,
            'f0': f0,
            'uv': uv,
        }

        if hparams['use_energy_embed']:
            batch['energy'] = utils.collate_1d([s['energy'] for s in samples], 0.0)
        if hparams['use_spk_id']:
            batch['spk_ids'] = torch.LongTensor([s['spk_id'] for s in samples])
        return batch