"""Shared training/inference utilities.

Covers HuBERT content extraction, HiFi-GAN vocoder loading, mel-spectrogram
resizing helpers, checkpoint save/load/cleanup, TensorBoard summaries,
matplotlib-to-numpy plotting, hyper-parameter loading and logging setup.
"""
import os
import glob
import sys
import argparse
import logging
import json
import subprocess

import numpy as np
from scipy.io.wavfile import read
import torch
import torchvision
from torch.nn import functional as F
from transformers import (
    Wav2Vec2FeatureExtractor,
    HubertModel,
)

from commons import sequence_mask
import hifigan

# NOTE: an earlier comment here announced a pyworld import for F0 computation,
# but no such import exists; F0 extraction goes through torchcrepe, which is
# imported lazily inside get_f0_torch().

# Set to True once the matplotlib backend has been switched to "Agg".
MATPLOTLIB_FLAG = False

logging.basicConfig(stream=sys.stdout, level=logging.WARNING)
# Until get_logger() is called, the `logging` module itself acts as the
# logger (its module-level info()/warning() functions use the root logger).
logger = logging

# Location of the pretrained HuBERT checkpoint used by get_cmodel().
model_dir = "HuBERT/chinese-hubert-large-fariseq-ckpt"


def get_cmodel(rank):
    """Load the HuBERT content model and its feature extractor.

    Args:
        rank: Device (or device index) the model is moved to via ``.to(rank)``.

    Returns:
        Tuple ``(cmodel, feature_extractor)``; the model is in float32 and
        eval mode.
    """
    feature_extractor = Wav2Vec2FeatureExtractor.from_pretrained(model_dir)
    cmodel = HubertModel.from_pretrained(model_dir)
    cmodel = cmodel.to(rank).float()
    cmodel.eval()
    return cmodel, feature_extractor


def get_content(cmodel, feature_extractor, y, device):
    """Extract HuBERT content features from audio.

    Args:
        cmodel: HuBERT model as returned by :func:`get_cmodel`.
        feature_extractor: Matching feature extractor.
        y: Audio passed to the feature extractor with ``sampling_rate=16000``.
        device: Device the preprocessed input is moved to.

    Returns:
        ``last_hidden_state`` with its last two axes swapped, i.e.
        ``[batch, hidden_size, seq_len]``.
    """
    # Preprocess the audio into model inputs.
    input_values = feature_extractor(
        y, return_tensors="pt", sampling_rate=16000
    ).input_values  # [1, T]
    input_values = input_values.to(device).float()

    # Feature extraction only; no gradients needed.
    with torch.no_grad():
        outputs = cmodel(input_values)
    c = outputs.last_hidden_state  # [batch, seq_len, hidden_size]
    return c.transpose(1, 2)


def get_vocoder(rank):
    """Build the HiFi-GAN generator from ``hifigan/`` and move it to GPU ``rank``.

    Returns:
        The generator in eval mode with weight norm removed.
    """
    with open("hifigan/config.json", "r") as f:
        config = hifigan.AttrDict(json.load(f))

    vocoder = hifigan.Generator(config)
    # Load onto CPU first so the tensors are not materialised on whichever GPU
    # the checkpoint was saved from; the model is moved to `rank` below.
    ckpt = torch.load("hifigan/generator_v1.pth", map_location="cpu")
    vocoder.load_state_dict(ckpt["generator"])
    vocoder.eval()
    vocoder.remove_weight_norm()
    vocoder.cuda(rank)
    return vocoder


def transform(mel, height):
    """Resize ``mel`` along its second-to-last axis, keeping the original size.

    The input is resized to ``height`` rows. If that is at least the original
    row count, the result is cropped back to the original row count; otherwise
    the last resized row is repeated (with added Gaussian noise scaled by 0.1)
    to pad back up to the original row count.

    A previous revision sampled ``height`` internally as
    ``int(mel.size(-2) * rate)`` with ``rate`` in 0.85-1.15 (noted as 68-92).

    Args:
        mel: Tensor indexed as ``[:, rows, frames]``.
        height: Target number of rows for the intermediate resize.

    Returns:
        Tensor with the same number of rows and frames as ``mel``.
    """
    n_rows = mel.size(-2)
    tgt = torchvision.transforms.functional.resize(mel, (height, mel.size(-1)))
    if height >= n_rows:
        return tgt[:, :n_rows, :]

    silence = tgt[:, -1:, :].repeat(1, n_rows - height, 1)
    silence += torch.randn_like(silence) / 10
    return torch.cat((tgt, silence), 1)


def stretch(mel, width):
    """Resize ``mel`` along its last axis to ``width`` (original note: 0.5-2)."""
    return torchvision.transforms.functional.resize(mel, (mel.size(-2), width))


def load_checkpoint(checkpoint_path, model, optimizer=None, strict=False):
    """Restore model (and optionally optimizer) state from a checkpoint file.

    Parameters missing from the checkpoint keep the model's current values and
    are reported via ``logger.info``.

    Args:
        checkpoint_path: Path to a file written by :func:`save_checkpoint`.
        model: Model to load into; ``model.module`` is used when present.
        optimizer: Optional optimizer whose state is restored.
        strict: When True, assert that model and checkpoint keys match exactly.

    Returns:
        Tuple ``(model, optimizer, learning_rate, iteration)``.

    Raises:
        AssertionError: If the file does not exist, or ``strict`` is set and
            the key sets differ.
    """
    assert os.path.isfile(checkpoint_path)
    checkpoint_dict = torch.load(checkpoint_path, map_location='cpu')
    iteration = checkpoint_dict['iteration']
    learning_rate = checkpoint_dict['learning_rate']
    if optimizer is not None:
        optimizer.load_state_dict(checkpoint_dict['optimizer'])
    saved_state_dict = checkpoint_dict['model']

    # Unwrap (Distributed)DataParallel-style containers.
    target = model.module if hasattr(model, 'module') else model
    state_dict = target.state_dict()

    if strict:
        assert state_dict.keys() == saved_state_dict.keys(), "Mismatched model config and checkpoint."

    new_state_dict = {}
    for k, v in state_dict.items():
        if k in saved_state_dict:
            new_state_dict[k] = saved_state_dict[k]
        else:
            logger.info("%s is not in the checkpoint", k)
            new_state_dict[k] = v
    target.load_state_dict(new_state_dict)

    logger.info("Loaded checkpoint '%s' (iteration %s)", checkpoint_path, iteration)
    return model, optimizer, learning_rate, iteration


def save_checkpoint(model, optimizer, learning_rate, iteration, checkpoint_path):
    """Write model/optimizer state, iteration and learning rate to a file.

    Args:
        model: Model to save; ``model.module`` is used when present.
        optimizer: Optimizer whose ``state_dict()`` is stored.
        learning_rate: Stored under ``'learning_rate'``.
        iteration: Stored under ``'iteration'``.
        checkpoint_path: Destination passed to ``torch.save``.
    """
    logger.info(
        "Saving model and optimizer state at iteration %s to %s",
        iteration, checkpoint_path,
    )
    source = model.module if hasattr(model, 'module') else model
    torch.save(
        {
            'model': source.state_dict(),
            'iteration': iteration,
            'optimizer': optimizer.state_dict(),
            'learning_rate': learning_rate,
        },
        checkpoint_path,
    )


def summarize(writer, global_step, scalars=None, histograms=None, images=None,
              audios=None, audio_sampling_rate=22050):
    """Push scalars, histograms, images and audio clips to a summary writer.

    Args:
        writer: Object exposing ``add_scalar``/``add_histogram``/``add_image``/
            ``add_audio``.
        global_step: Step value attached to every entry.
        scalars: Optional mapping tag -> scalar.
        histograms: Optional mapping tag -> values.
        images: Optional mapping tag -> image, written with ``dataformats='HWC'``.
        audios: Optional mapping tag -> audio.
        audio_sampling_rate: Sample rate passed to ``add_audio``.
    """
    # None defaults instead of shared mutable dict defaults.
    for k, v in (scalars or {}).items():
        writer.add_scalar(k, v, global_step)
    for k, v in (histograms or {}).items():
        writer.add_histogram(k, v, global_step)
    for k, v in (images or {}).items():
        writer.add_image(k, v, global_step, dataformats='HWC')
    for k, v in (audios or {}).items():
        writer.add_audio(k, v, global_step, audio_sampling_rate)


def latest_checkpoint_path(dir_path, regex="G_*.pth"):
    """Return the matching checkpoint with the highest number in its file name.

    Args:
        dir_path: Directory to search.
        regex: Glob pattern (despite the name) joined onto ``dir_path``.

    Returns:
        Path of the last file after sorting by the digits in its base name.
        The path is also printed to stdout.

    Raises:
        IndexError: If nothing matches the pattern.
    """
    def _step(path):
        # Only the base name is considered so digits in directory names
        # (e.g. "logs/44k") cannot influence the ordering.
        digits = "".join(filter(str.isdigit, os.path.basename(path)))
        return int(digits) if digits else -1

    f_list = sorted(glob.glob(os.path.join(dir_path, regex)), key=_step)
    x = f_list[-1]
    print(x)
    return x


def _get_pyplot():
    """Return ``matplotlib.pyplot``, selecting the Agg backend on first use."""
    global MATPLOTLIB_FLAG
    if not MATPLOTLIB_FLAG:
        import matplotlib
        matplotlib.use("Agg")
        MATPLOTLIB_FLAG = True
        # Silence matplotlib's own debug/info logging.
        logging.getLogger('matplotlib').setLevel(logging.WARNING)
    import matplotlib.pyplot as plt
    return plt


def _figure_to_rgb_array(fig):
    """Render ``fig`` and return its pixels as an ``(H, W, 3)`` uint8 array."""
    fig.canvas.draw()
    # buffer_rgba() replaces the deprecated tostring_rgb(); np.asarray on it
    # avoids np.fromstring's deprecated binary mode. Drop the alpha channel and
    # copy so the result is contiguous and independent of the canvas buffer.
    rgba = np.asarray(fig.canvas.buffer_rgba())
    return np.ascontiguousarray(rgba[..., :3])


def plot_spectrogram_to_numpy(spectrogram):
    """Render a spectrogram as an RGB image array.

    Args:
        spectrogram: 2-D array-like shown with ``imshow`` (origin at the bottom).

    Returns:
        ``(H, W, 3)`` uint8 numpy array of the rendered figure.
    """
    plt = _get_pyplot()

    fig, ax = plt.subplots(figsize=(10, 2))
    im = ax.imshow(spectrogram, aspect="auto", origin="lower",
                   interpolation='none')
    fig.colorbar(im, ax=ax)
    ax.set_xlabel("Frames")
    ax.set_ylabel("Channels")
    fig.tight_layout()

    data = _figure_to_rgb_array(fig)
    plt.close(fig)
    return data


def plot_alignment_to_numpy(alignment, info=None):
    """Render an alignment matrix as an RGB image array.

    Args:
        alignment: 2-D array; its transpose is shown with ``imshow``.
        info: Optional text appended below the x-axis label.

    Returns:
        ``(H, W, 3)`` uint8 numpy array of the rendered figure.
    """
    plt = _get_pyplot()

    fig, ax = plt.subplots(figsize=(6, 4))
    im = ax.imshow(alignment.transpose(), aspect='auto', origin='lower',
                   interpolation='none')
    fig.colorbar(im, ax=ax)
    xlabel = 'Decoder timestep'
    if info is not None:
        xlabel += '\n\n' + info
    ax.set_xlabel(xlabel)
    ax.set_ylabel('Encoder timestep')
    fig.tight_layout()

    data = _figure_to_rgb_array(fig)
    plt.close(fig)
    return data


def load_wav_to_torch(full_path):
    """Read a WAV file and return ``(float32 tensor, sampling_rate)``."""
    sampling_rate, data = read(full_path)
    return torch.FloatTensor(data.astype(np.float32)), sampling_rate


def load_filepaths_and_text(filename, split="|"):
    """Read a UTF-8 file list and split every stripped line on ``split``.

    Returns:
        List of lists, one per line of the file.
    """
    with open(filename, encoding='utf-8') as f:
        return [line.strip().split(split) for line in f]


def get_hparams(init=True):
    """Parse CLI arguments and load hyper-parameters for a training run.

    Reads ``-c/--config`` (default ``./configs/base.json``) and the required
    ``-m/--model`` name. The run directory is ``/data/logs/<model>``.

    Args:
        init: When True, copy the given config into the run directory as
            ``config.json``; when False, read the previously saved copy.

    Returns:
        :class:`HParams` with ``model_dir`` set to the run directory.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument('-c', '--config', type=str, default="./configs/base.json",
                        help='JSON file for configuration')
    parser.add_argument('-m', '--model', type=str, required=True,
                        help='Model name')
    args = parser.parse_args()

    # Local name; intentionally distinct from the module-level HuBERT `model_dir`.
    model_dir = os.path.join("/data/logs/", args.model)
    os.makedirs(model_dir, exist_ok=True)

    config_path = args.config
    config_save_path = os.path.join(model_dir, "config.json")
    if init:
        with open(config_path, "r") as f:
            data = f.read()
        with open(config_save_path, "w") as f:
            f.write(data)
    else:
        with open(config_save_path, "r") as f:
            data = f.read()

    hparams = HParams(**json.loads(data))
    hparams.model_dir = model_dir
    return hparams


def get_hparams_from_dir(model_dir):
    """Load ``<model_dir>/config.json`` as :class:`HParams` and set ``model_dir``."""
    config_save_path = os.path.join(model_dir, "config.json")
    with open(config_save_path, "r") as f:
        config = json.load(f)

    hparams = HParams(**config)
    hparams.model_dir = model_dir
    return hparams


def get_hparams_from_file(config_path):
    """Load a JSON config file as :class:`HParams`."""
    with open(config_path, "r") as f:
        config = json.load(f)
    return HParams(**config)


def check_git_hash(model_dir):
    """Warn when the current git commit differs from the one saved for a run.

    The hash is stored in ``<model_dir>/githash``; it is written on first use
    and compared on later calls. Nothing happens when the directory of this
    source file has no ``.git`` entry.
    """
    source_dir = os.path.dirname(os.path.realpath(__file__))
    if not os.path.exists(os.path.join(source_dir, ".git")):
        logger.warning(
            "{} is not a git repository, therefore hash value comparison will be ignored.".format(
                source_dir
            )
        )
        return

    # NOTE(review): git runs in the current working directory, not source_dir.
    cur_hash = subprocess.getoutput("git rev-parse HEAD")

    path = os.path.join(model_dir, "githash")
    if os.path.exists(path):
        with open(path) as f:
            saved_hash = f.read()
        if saved_hash != cur_hash:
            logger.warning(
                "git hash values are different. {}(saved) != {}(current)".format(
                    saved_hash[:8], cur_hash[:8]
                )
            )
    else:
        with open(path, "w") as f:
            f.write(cur_hash)


def get_logger(model_dir, filename="train.log"):
    """Create the run logger and make it the module-wide ``logger``.

    The logger is named after the base name of ``model_dir``, set to DEBUG and
    writes to ``<model_dir>/<filename>``.

    Args:
        model_dir: Run directory; created if missing.
        filename: Log file name inside ``model_dir``.

    Returns:
        The configured ``logging.Logger``.
    """
    global logger
    logger = logging.getLogger(os.path.basename(model_dir))
    logger.setLevel(logging.DEBUG)

    os.makedirs(model_dir, exist_ok=True)
    log_path = os.path.abspath(os.path.join(model_dir, filename))

    # Calling this twice must not attach a second handler for the same file,
    # otherwise every record would be written twice.
    already_attached = any(
        isinstance(handler, logging.FileHandler)
        and handler.baseFilename == log_path
        for handler in logger.handlers
    )
    if not already_attached:
        formatter = logging.Formatter("%(asctime)s\t%(name)s\t%(levelname)s\t%(message)s")
        h = logging.FileHandler(log_path)
        h.setLevel(logging.DEBUG)
        h.setFormatter(formatter)
        logger.addHandler(h)
    return logger


class HParams():
    """Attribute-style container for (nested) hyper-parameters.

    Keyword arguments become attributes; dict values are converted to nested
    ``HParams`` recursively. Also supports ``hp[key]``, ``key in hp``,
    ``len(hp)`` and the ``keys()``/``items()``/``values()`` dict views.
    """

    def __init__(self, **kwargs):
        for k, v in kwargs.items():
            if isinstance(v, dict):
                v = HParams(**v)
            self[k] = v

    def keys(self):
        return self.__dict__.keys()

    def items(self):
        return self.__dict__.items()

    def values(self):
        return self.__dict__.values()

    def __len__(self):
        return len(self.__dict__)

    def __getitem__(self, key):
        return getattr(self, key)

    def __setitem__(self, key, value):
        return setattr(self, key, value)

    def __contains__(self, key):
        return key in self.__dict__

    def __repr__(self):
        return self.__dict__.__repr__()


def get_f0_torch(wav, sr=16000, hop_length=160):
    """Extract F0 with torchcrepe on the device that holds ``wav``.

    Args:
        wav: Audio tensor of shape ``(batch_size, time)`` (per the original
            docstring), typically on GPU.
        sr: Sample rate forwarded to torchcrepe.
        hop_length: Hop length forwarded to torchcrepe.

    Returns:
        F0 tensor of shape ``(batch_size, f0_length)`` (per the original
        docstring).

    Raises:
        ImportError: If torchcrepe is not installed.
    """
    # Imported lazily: the module never imported torchcrepe, so the original
    # call failed with NameError. Keeping it local also keeps the dependency
    # optional for callers that never extract F0.
    import torchcrepe

    f0 = torchcrepe.predict(
        wav,
        sample_rate=sr,
        hop_length=hop_length,
        batch_size=64,
        device=wav.device,
        model='full',
    )
    return f0


def clean_checkpoints(path_to_models="logs/44k/", n_ckpts_to_keep=2, sort_by_time=True):
    """Free up disk space by deleting old saved checkpoints.

    Only files named ``G_<number>.pth`` / ``D_<number>.pth`` are considered;
    ``G_0.pth`` and ``D_0.pth`` are never deleted. The newest
    ``n_ckpts_to_keep`` generator and discriminator checkpoints are kept.

    Args:
        path_to_models: Directory holding the checkpoints.
        n_ckpts_to_keep: Number of checkpoints to keep per prefix, excluding
            ``G_0.pth`` and ``D_0.pth``. Values below 0 are treated as 0.
        sort_by_time: True deletes the oldest by modification time; False
            deletes the lowest by the number in the file name.
    """
    import re

    ckpt_pattern = re.compile(r"[GD]_(\d+)\.pth")

    def name_key(_f):
        return int(ckpt_pattern.fullmatch(_f).group(1))

    def time_key(_f):
        return os.path.getmtime(os.path.join(path_to_models, _f))

    sort_key = time_key if sort_by_time else name_key
    n_keep = max(int(n_ckpts_to_keep), 0)

    # Restrict to real checkpoint files so unrelated files that merely start
    # with "G"/"D" are neither deleted nor able to break name-based sorting.
    candidates = [
        f for f in os.listdir(path_to_models)
        if os.path.isfile(os.path.join(path_to_models, f))
        and ckpt_pattern.fullmatch(f)
        and not f.endswith("_0.pth")
    ]

    to_del = []
    for prefix in ("G", "D"):
        group = sorted(
            (f for f in candidates if f.startswith(prefix)), key=sort_key
        )
        # Explicit count instead of group[:-n]: a slice ending at -0 is empty,
        # which would delete nothing when asked to keep zero checkpoints.
        n_delete = max(len(group) - n_keep, 0)
        to_del.extend(os.path.join(path_to_models, fn) for fn in group[:n_delete])

    for fn in to_del:
        os.remove(fn)
        logger.info(f".. Free up space by deleting ckpt {fn}")