rvc_api / lib /rvc /utils.py
aryo100's picture
first commit
b5a064f
import glob
import logging
import os
import shutil
import socket
import sys
import ffmpeg
import matplotlib
import matplotlib.pylab as plt
import numpy as np
import torch
from scipy.io.wavfile import read
from torch.nn import functional as F
from modules.shared import ROOT_DIR
from .config import TrainConfig
# Select matplotlib's Agg backend; the plotting helpers below render figures
# to pixel buffers via fig.canvas.draw() rather than showing a window.
matplotlib.use("Agg")
# Keep matplotlib's own logger at WARNING so the DEBUG root level configured
# below does not flood stdout with matplotlib internals.
logging.getLogger("matplotlib").setLevel(logging.WARNING)
# Import-time side effect: configure the root logger to print DEBUG and above
# to stdout (basicConfig is a no-op if the root logger already has handlers).
logging.basicConfig(stream=sys.stdout, level=logging.DEBUG)
# NOTE(review): this aliases the `logging` module itself, not a Logger object,
# so `logger.info(...)` etc. log through the root logger. Consider
# logging.getLogger(__name__) if per-module loggers are wanted.
logger = logging
def load_audio(file: str, sr):
    """Decode an audio file into a mono float32 waveform resampled to ``sr`` Hz.

    Decoding is delegated to the ``ffmpeg`` CLI (through the ``ffmpeg-python``
    package), which down-mixes to one channel and resamples as needed.

    Args:
        file: Path to the audio file. Surrounding spaces, double quotes and
            newlines (as produced by copy-pasting a path) are stripped.
        sr: Target sample rate in Hz, passed to ffmpeg as ``ar``.

    Returns:
        A 1-D ``np.float32`` array of samples.

    Raises:
        RuntimeError: If anything in the decode step fails. The original
            exception is chained as ``__cause__``, and ffmpeg's stderr output
            (when available) is appended to the message.
    """
    try:
        # https://github.com/openai/whisper/blob/main/whisper/audio.py#L26
        # This launches a subprocess to decode audio while down-mixing and resampling as necessary.
        # Requires the ffmpeg CLI and `ffmpeg-python` package to be installed.
        # Strip spaces, quotes and newlines that often wrap a copy-pasted path.
        file = file.strip(" ").strip('"').strip("\n").strip('"').strip(" ")
        out, _ = (
            ffmpeg.input(file, threads=0)
            .output("-", format="f32le", acodec="pcm_f32le", ac=1, ar=sr)
            .run(cmd=["ffmpeg", "-nostdin"], capture_stdout=True, capture_stderr=True)
        )
    except Exception as e:
        message = f"Failed to load audio: {e}"
        # ffmpeg.Error keeps the captured stderr bytes; its str() alone does not
        # say why decoding failed, so surface that output to the caller.
        stderr = getattr(e, "stderr", None)
        if isinstance(stderr, bytes) and stderr:
            message += "\n" + stderr.decode("utf-8", errors="replace").strip()
        raise RuntimeError(message) from e
    return np.frombuffer(out, np.float32).flatten()
def find_empty_port():
    """Return a TCP port number that the OS reported as free.

    A socket is bound to port 0 so that the OS assigns an unused port. The
    socket is closed before returning, so the port is only known to be free at
    the moment of the call; another process could claim it afterwards.

    Returns:
        The port number as an ``int``.
    """
    # The context manager closes the socket even if bind/listen raise; the
    # previous version leaked it in that case.
    with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
        s.bind(("", 0))
        s.listen(1)
        return s.getsockname()[1]
def _interpolate_2d(tensor, size):
    """Bilinearly resize a 2-D tensor to ``size`` and return the result.

    Half-precision input is resized in float32 and cast back to half, matching
    the previous behaviour of this module.
    """
    is_half = tensor.dtype == torch.half
    source = tensor.float() if is_half else tensor
    # Bilinear interpolation needs a 4-D (N, C, H, W) input, so add two
    # leading singleton dims, resize, and then remove them again.
    resized = (
        F.interpolate(source.unsqueeze(0).unsqueeze(0), size=size, mode="bilinear")
        .squeeze(0)
        .squeeze(0)
    )
    return resized.half() if is_half else resized


def _select_checkpoint_tensor(key, model_value, saved_state_dict):
    """Choose the value to load into the model for parameter ``key``.

    - ``key`` missing from the checkpoint, or a non-tensor entry: keep ``model_value``.
    - Shapes match: use the checkpoint tensor.
    - Shapes differ and the checkpoint tensor is 2-D: bilinearly resize it to
      the model's shape. This allows, for example, continuing from a
      checkpoint trained with a different embedding width (256 <-> 768).
    - Any other mismatch, or a failed resize: keep ``model_value``.
    """
    if key not in saved_state_dict:
        print(f"{key} is not in the checkpoint")
        return model_value
    saved = saved_state_dict[key]
    if not (isinstance(saved, torch.Tensor) and isinstance(model_value, torch.Tensor)):
        print(f"{key} is not a tensor; keeping the model's initial value")
        return model_value
    if saved.shape == model_value.shape:
        return saved
    print(f"shape-{key}-mismatch|need-{model_value.shape}|get-{saved.shape}")
    if saved.dim() != 2:  # NOTE: only 2-D weights are resized; check is this ok?
        print(f"{key} has an incompatible shape; keeping the model's initial value")
        return model_value
    try:
        resized = _interpolate_2d(saved, model_value.shape)
    except Exception as e:  # best effort: fall back to the model's own value
        print(f"{key} could not be interpolated; keeping the model's initial value")
        print("error: %s" % e)
        return model_value
    print("interpolated new_state_dict", key, "from", saved.shape, "to", resized.shape)
    return resized


def load_checkpoint(checkpoint_path, model, optimizer=None, load_opt=1):
    """Load model (and optionally optimizer) state from a checkpoint file.

    The checkpoint is expected to be a dict with keys ``"model"`` (a state
    dict), ``"epoch"`` and ``"learning_rate"``, plus ``"optimizer"`` when the
    optimizer state is restored (see ``save_state``). Model weights are loaded
    non-strictly. Each parameter is taken from the checkpoint when possible,
    resized when it is a 2-D tensor of a different shape, and otherwise left
    at the model's current value.

    Args:
        checkpoint_path: Path to the checkpoint file; it must exist.
        model: The module to load into. If it has a ``.module`` attribute (a
            wrapped model), the weights are loaded into ``model.module``.
        optimizer: Optional optimizer whose state is restored.
        load_opt: Restore the optimizer state only when this equals 1.

    Returns:
        ``(model, optimizer, learning_rate, epoch)``.
    """
    assert os.path.isfile(checkpoint_path), f"checkpoint not found: {checkpoint_path}"
    # NOTE: torch.load unpickles the file; only load trusted checkpoints.
    checkpoint_dict = torch.load(checkpoint_path, map_location="cpu")
    saved_state_dict = checkpoint_dict["model"]
    target = model.module if hasattr(model, "module") else model
    new_state_dict = {
        k: _select_checkpoint_tensor(k, v, saved_state_dict)
        for k, v in target.state_dict().items()
    }
    target.load_state_dict(new_state_dict, strict=False)
    print("Loaded model weights")

    epoch = checkpoint_dict["epoch"]
    learning_rate = checkpoint_dict["learning_rate"]
    if optimizer is not None and load_opt == 1:
        optimizer.load_state_dict(checkpoint_dict["optimizer"])
    print("Loaded checkpoint '{}' (epoch {})".format(checkpoint_path, epoch))
    return model, optimizer, learning_rate, epoch
def save_state(model, optimizer, learning_rate, epoch, checkpoint_path):
    """Write a training checkpoint to ``checkpoint_path`` with ``torch.save``.

    The saved dict has the keys ``model`` (the state dict of ``model``, or of
    ``model.module`` when that attribute exists), ``epoch``, ``optimizer``
    (``optimizer.state_dict()``) and ``learning_rate``.
    """
    print(f"Saving model and optimizer state at epoch {epoch} to {checkpoint_path}")
    network = model.module if hasattr(model, "module") else model
    payload = {
        "model": network.state_dict(),
        "epoch": epoch,
        "optimizer": optimizer.state_dict(),
        "learning_rate": learning_rate,
    }
    torch.save(payload, checkpoint_path)
def summarize(
    writer,
    global_step,
    scalars=None,
    histograms=None,
    images=None,
    audios=None,
    audio_sampling_rate=22050,
):
    """Log scalars, histograms, images and audio clips to ``writer`` at ``global_step``.

    Args:
        writer: An object providing ``add_scalar``, ``add_histogram``,
            ``add_image`` and ``add_audio`` (e.g. a TensorBoard SummaryWriter).
        global_step: Step value attached to every logged item.
        scalars, histograms, images, audios: Optional ``{tag: value}``
            mappings; ``None`` (the default) means nothing of that kind is
            logged. Images are logged with ``dataformats="HWC"``.
        audio_sampling_rate: Sample rate passed to ``add_audio``.
    """
    # None defaults replace the shared mutable ``{}`` defaults.
    for k, v in (scalars or {}).items():
        writer.add_scalar(k, v, global_step)
    for k, v in (histograms or {}).items():
        writer.add_histogram(k, v, global_step)
    for k, v in (images or {}).items():
        writer.add_image(k, v, global_step, dataformats="HWC")
    for k, v in (audios or {}).items():
        writer.add_audio(k, v, global_step, audio_sampling_rate)
def latest_checkpoint_path(dir_path, regex="G_*.pth"):
    """Return the path of the highest-numbered checkpoint in ``dir_path``.

    Args:
        dir_path: Directory to search.
        regex: A glob pattern (not a regular expression) for checkpoint file
            names, e.g. ``"G_*.pth"``.

    Returns:
        The matching path whose file name contains the largest number, or
        ``None`` if nothing matches. The number is formed by concatenating all
        digits in the file's base name. Names with no digits (e.g.
        ``G_latest.pth``) rank below every numbered file instead of raising
        ``ValueError``.
    """
    filelist = glob.glob(os.path.join(dir_path, regex))
    if not filelist:
        return None

    def _step(path):
        # Use only the base name so digits in dir_path do not join the number.
        digits = "".join(filter(str.isdigit, os.path.basename(path)))
        return int(digits) if digits else -1

    # Stable sort + last element: ties keep the previous "last wins" behaviour.
    filelist.sort(key=_step)
    return filelist[-1]
def plot_spectrogram_to_numpy(spectrogram):
    """Render ``spectrogram`` as a heat map and return the image as pixels.

    Args:
        spectrogram: A 2-D array-like accepted by ``Axes.imshow``. It is drawn
            with ``origin="lower"``, so row 0 is at the bottom.

    Returns:
        A ``(height, width, 3)`` ``np.uint8`` RGB array of the rendered figure.
    """
    fig, ax = plt.subplots(figsize=(10, 2))
    try:
        im = ax.imshow(spectrogram, aspect="auto", origin="lower", interpolation="none")
        plt.colorbar(im, ax=ax)
        plt.xlabel("Frames")
        plt.ylabel("Channels")
        plt.tight_layout()
        fig.canvas.draw()
        # buffer_rgba() replaces the deprecated tostring_rgb()/np.fromstring
        # pair. Drop the alpha channel, and make a contiguous copy so the
        # result no longer references the canvas buffer.
        data = np.ascontiguousarray(np.asarray(fig.canvas.buffer_rgba())[:, :, :3])
    finally:
        # Close this specific figure, even on error, so figures do not pile up.
        plt.close(fig)
    return data
def plot_alignment_to_numpy(alignment, info=None):
    """Render an alignment matrix as a heat map and return the image as pixels.

    Args:
        alignment: An array with a ``.transpose()`` method. The transposed
            matrix is drawn with ``origin="lower"``.
        info: Optional extra text appended below the x-axis label.

    Returns:
        A ``(height, width, 3)`` ``np.uint8`` RGB array of the rendered figure.
    """
    fig, ax = plt.subplots(figsize=(6, 4))
    try:
        im = ax.imshow(
            alignment.transpose(), aspect="auto", origin="lower", interpolation="none"
        )
        fig.colorbar(im, ax=ax)
        xlabel = "Decoder timestep"
        if info is not None:
            xlabel += "\n\n" + info
        plt.xlabel(xlabel)
        plt.ylabel("Encoder timestep")
        plt.tight_layout()
        fig.canvas.draw()
        # buffer_rgba() replaces the deprecated tostring_rgb()/np.fromstring
        # pair. Drop the alpha channel, and make a contiguous copy so the
        # result no longer references the canvas buffer.
        data = np.ascontiguousarray(np.asarray(fig.canvas.buffer_rgba())[:, :, :3])
    finally:
        # Close this specific figure, even on error, so figures do not pile up.
        plt.close(fig)
    return data
def load_wav_to_torch(full_path):
    """Read a WAV file and return ``(samples, sampling_rate)``.

    The samples are converted to a float32 tensor without rescaling, so
    integer PCM keeps its integer range (e.g. -32768..32767 for int16).
    """
    rate, samples = read(full_path)
    as_float32 = samples.astype(np.float32)
    return torch.from_numpy(as_float32), rate
def load_config(training_dir: str, sample_rate: int, emb_channels: int):
    """Copy the matching training config into ``training_dir`` and parse it.

    The source file is ``<ROOT_DIR>/configs/<sample_rate>.json`` when
    ``emb_channels`` is 256, and ``<ROOT_DIR>/configs/<sample_rate>-<emb_channels>.json``
    otherwise. It is copied to ``<training_dir>/config.json``, and that copy is
    parsed with ``TrainConfig.parse_file``.
    """
    template_name = (
        f"{sample_rate}.json"
        if emb_channels == 256
        else f"{sample_rate}-{emb_channels}.json"
    )
    template_path = os.path.join(ROOT_DIR, "configs", template_name)
    copied_path = os.path.join(training_dir, "config.json")
    shutil.copyfile(template_path, copied_path)
    return TrainConfig.parse_file(copied_path)