"""Training utilities: audio decoding, checkpoint save/load, TensorBoard
summaries, spectrogram/alignment plotting, and training-config loading."""

import glob
import logging
import os
import shutil
import socket
import sys

import ffmpeg
import matplotlib
import matplotlib.pylab as plt
import numpy as np
import torch
from scipy.io.wavfile import read
from torch.nn import functional as F

from modules.shared import ROOT_DIR

from .config import TrainConfig

matplotlib.use("Agg")
logging.getLogger("matplotlib").setLevel(logging.WARNING)
logging.basicConfig(stream=sys.stdout, level=logging.DEBUG)
logger = logging


def load_audio(file: str, sr):
    """Decode an audio file to a mono float32 waveform at sample rate ``sr``.

    Uses the ffmpeg CLI (via ffmpeg-python) to down-mix and resample, per
    https://github.com/openai/whisper/blob/main/whisper/audio.py#L26

    Raises:
        RuntimeError: if ffmpeg fails to decode the file.
    """
    try:
        # Strip surrounding whitespace, quotes, and newlines that sneak in
        # when a path is copy-pasted from a file manager or shell.
        file = file.strip(" ").strip('"').strip("\n").strip('"').strip(" ")
        out, _ = (
            ffmpeg.input(file, threads=0)
            .output("-", format="f32le", acodec="pcm_f32le", ac=1, ar=sr)
            .run(cmd=["ffmpeg", "-nostdin"], capture_stdout=True, capture_stderr=True)
        )
    except Exception as e:
        raise RuntimeError(f"Failed to load audio: {e}")

    return np.frombuffer(out, np.float32).flatten()


def find_empty_port():
    """Return a TCP port number that was free at the time of the call.

    Note: the port is released before returning, so another process could
    grab it in the meantime (standard ephemeral-port race).
    """
    # Context manager guarantees the socket is closed even if bind/listen raises.
    with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
        s.bind(("", 0))  # port 0 => OS picks a free ephemeral port
        s.listen(1)
        port = s.getsockname()[1]
    return port


def load_checkpoint(checkpoint_path, model, optimizer=None, load_opt=1):
    """Load model (and optionally optimizer) state from a checkpoint file.

    Keys missing from the checkpoint keep the model's freshly initialized
    values. 2-D tensors with mismatched shapes (e.g. embedding inputs of
    256 vs 768 channels) are bilinearly interpolated to the model's shape so
    training can continue from pretrained checkpoints with a different
    embedder width.

    Returns:
        (model, optimizer, learning_rate, epoch)
    """
    assert os.path.isfile(checkpoint_path)
    checkpoint_dict = torch.load(checkpoint_path, map_location="cpu")
    saved_state_dict = checkpoint_dict["model"]
    if hasattr(model, "module"):
        state_dict = model.module.state_dict()
    else:
        state_dict = model.state_dict()
    new_state_dict = {}
    for k, v in state_dict.items():  # iterate the shapes the model needs
        try:
            new_state_dict[k] = saved_state_dict[k]
            if saved_state_dict[k].shape != state_dict[k].shape:
                print(
                    f"shape-{k}-mismatch|need-{state_dict[k].shape}|get-{saved_state_dict[k].shape}"
                )
                if saved_state_dict[k].dim() == 2:  # NOTE: check is this ok?
                    # For embedded input 256 <==> 768: interpolation lets us
                    # continue training from the original pretrained
                    # checkpoints when using an embedder with e.g. a 768-dim
                    # output.
                    if saved_state_dict[k].dtype == torch.half:
                        # interpolate in float32 (F.interpolate does not
                        # support half on CPU), then cast back
                        new_state_dict[k] = (
                            F.interpolate(
                                saved_state_dict[k].float().unsqueeze(0).unsqueeze(0),
                                size=state_dict[k].shape,
                                mode="bilinear",
                            )
                            .half()
                            .squeeze(0)
                            .squeeze(0)
                        )
                    else:
                        new_state_dict[k] = (
                            F.interpolate(
                                saved_state_dict[k].unsqueeze(0).unsqueeze(0),
                                size=state_dict[k].shape,
                                mode="bilinear",
                            )
                            .squeeze(0)
                            .squeeze(0)
                        )
                    print(
                        "interpolated new_state_dict",
                        k,
                        "from",
                        saved_state_dict[k].shape,
                        "to",
                        new_state_dict[k].shape,
                    )
                else:
                    # Non-2D shape mismatch: fall back to the model's own
                    # initialization via the except branch below.
                    raise KeyError(k)
        except Exception as e:
            # print(traceback.format_exc())
            print(f"{k} is not in the checkpoint")
            print("error: %s" % e)
            new_state_dict[k] = v  # keep the model's own random init values
    if hasattr(model, "module"):
        model.module.load_state_dict(new_state_dict, strict=False)
    else:
        model.load_state_dict(new_state_dict, strict=False)
    print("Loaded model weights")

    epoch = checkpoint_dict["epoch"]
    learning_rate = checkpoint_dict["learning_rate"]
    if optimizer is not None and load_opt == 1:
        optimizer.load_state_dict(checkpoint_dict["optimizer"])
    print("Loaded checkpoint '{}' (epoch {})".format(checkpoint_path, epoch))
    return model, optimizer, learning_rate, epoch


def save_state(model, optimizer, learning_rate, epoch, checkpoint_path):
    """Save model weights, optimizer state, LR, and epoch to ``checkpoint_path``."""
    print(
        "Saving model and optimizer state at epoch {} to {}".format(
            epoch, checkpoint_path
        )
    )
    if hasattr(model, "module"):
        state_dict = model.module.state_dict()
    else:
        state_dict = model.state_dict()
    torch.save(
        {
            "model": state_dict,
            "epoch": epoch,
            "optimizer": optimizer.state_dict(),
            "learning_rate": learning_rate,
        },
        checkpoint_path,
    )


def summarize(
    writer,
    global_step,
    scalars=None,
    histograms=None,
    images=None,
    audios=None,
    audio_sampling_rate=22050,
):
    """Write scalars, histograms, images, and audio clips to a TensorBoard writer.

    Defaults are ``None`` (not ``{}``) to avoid the shared-mutable-default
    pitfall; passing nothing logs nothing for that category.
    """
    for k, v in (scalars or {}).items():
        writer.add_scalar(k, v, global_step)
    for k, v in (histograms or {}).items():
        writer.add_histogram(k, v, global_step)
    for k, v in (images or {}).items():
        writer.add_image(k, v, global_step, dataformats="HWC")
    for k, v in (audios or {}).items():
        writer.add_audio(k, v, global_step, audio_sampling_rate)


def latest_checkpoint_path(dir_path, regex="G_*.pth"):
    """Return the checkpoint in ``dir_path`` matching ``regex`` with the
    highest embedded step number, or ``None`` if there are no matches."""
    filelist = glob.glob(os.path.join(dir_path, regex))
    if len(filelist) == 0:
        return None
    # Sort by the concatenated digits of the full path. Digits in dir_path
    # form a constant prefix for every entry, so the ordering is still
    # driven by the per-file step number.
    filelist.sort(key=lambda f: int("".join(filter(str.isdigit, f))))
    filepath = filelist[-1]
    return filepath


def plot_spectrogram_to_numpy(spectrogram):
    """Render a spectrogram as an HWC uint8 RGB numpy array (for TensorBoard)."""
    fig, ax = plt.subplots(figsize=(10, 2))
    im = ax.imshow(spectrogram, aspect="auto", origin="lower", interpolation="none")
    plt.colorbar(im, ax=ax)
    plt.xlabel("Frames")
    plt.ylabel("Channels")
    plt.tight_layout()

    fig.canvas.draw()
    # np.frombuffer replaces the deprecated np.fromstring on raw bytes.
    data = np.frombuffer(fig.canvas.tostring_rgb(), dtype=np.uint8)
    data = data.reshape(fig.canvas.get_width_height()[::-1] + (3,))
    plt.close(fig)  # close this figure explicitly, not just the "current" one
    return data


def plot_alignment_to_numpy(alignment, info=None):
    """Render an alignment matrix as an HWC uint8 RGB numpy array.

    Args:
        alignment: 2-D array; transposed so encoder steps are on the y-axis.
        info: optional extra text appended to the x-axis label.
    """
    fig, ax = plt.subplots(figsize=(6, 4))
    im = ax.imshow(
        alignment.transpose(), aspect="auto", origin="lower", interpolation="none"
    )
    fig.colorbar(im, ax=ax)
    xlabel = "Decoder timestep"
    if info is not None:
        xlabel += "\n\n" + info
    plt.xlabel(xlabel)
    plt.ylabel("Encoder timestep")
    plt.tight_layout()

    fig.canvas.draw()
    # np.frombuffer replaces the deprecated np.fromstring on raw bytes.
    data = np.frombuffer(fig.canvas.tostring_rgb(), dtype=np.uint8)
    data = data.reshape(fig.canvas.get_width_height()[::-1] + (3,))
    plt.close(fig)  # close this figure explicitly, not just the "current" one
    return data


def load_wav_to_torch(full_path):
    """Read a WAV file and return ``(FloatTensor samples, sampling_rate)``."""
    sampling_rate, data = read(full_path)
    return torch.FloatTensor(data.astype(np.float32)), sampling_rate


def load_config(training_dir: str, sample_rate: int, emb_channels: int):
    """Copy the matching bundled config into ``training_dir`` and parse it.

    Configs named ``{sample_rate}.json`` assume 256 embedding channels;
    other widths use ``{sample_rate}-{emb_channels}.json``.
    """
    if emb_channels == 256:
        config_path = os.path.join(ROOT_DIR, "configs", f"{sample_rate}.json")
    else:
        config_path = os.path.join(
            ROOT_DIR, "configs", f"{sample_rate}-{emb_channels}.json"
        )
    config_save_path = os.path.join(training_dir, "config.json")
    shutil.copyfile(config_path, config_save_path)
    return TrainConfig.parse_file(config_save_path)