Spaces:
Sleeping
Sleeping
File size: 7,598 Bytes
b5a064f |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 |
import glob
import logging
import os
import shutil
import socket
import sys
import ffmpeg
import matplotlib
import matplotlib.pylab as plt
import numpy as np
import torch
from scipy.io.wavfile import read
from torch.nn import functional as F
from modules.shared import ROOT_DIR
from .config import TrainConfig
matplotlib.use("Agg")
logging.getLogger("matplotlib").setLevel(logging.WARNING)
logging.basicConfig(stream=sys.stdout, level=logging.DEBUG)
logger = logging
def load_audio(file: str, sr):
try:
# https://github.com/openai/whisper/blob/main/whisper/audio.py#L26
# This launches a subprocess to decode audio while down-mixing and resampling as necessary.
# Requires the ffmpeg CLI and `ffmpeg-python` package to be installed.
file = (
file.strip(" ").strip('"').strip("\n").strip('"').strip(" ")
) # Prevent small white copy path head and tail with spaces and " and return
out, _ = (
ffmpeg.input(file, threads=0)
.output("-", format="f32le", acodec="pcm_f32le", ac=1, ar=sr)
.run(cmd=["ffmpeg", "-nostdin"], capture_stdout=True, capture_stderr=True)
)
except Exception as e:
raise RuntimeError(f"Failed to load audio: {e}")
return np.frombuffer(out, np.float32).flatten()
def find_empty_port():
s = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
s.bind(("", 0))
s.listen(1)
port = s.getsockname()[1]
s.close()
return port
def load_checkpoint(checkpoint_path, model, optimizer=None, load_opt=1):
assert os.path.isfile(checkpoint_path)
checkpoint_dict = torch.load(checkpoint_path, map_location="cpu")
saved_state_dict = checkpoint_dict["model"]
if hasattr(model, "module"):
state_dict = model.module.state_dict()
else:
state_dict = model.state_dict()
new_state_dict = {}
for k, v in state_dict.items(): # 模型需要的shape
try:
new_state_dict[k] = saved_state_dict[k]
if saved_state_dict[k].shape != state_dict[k].shape:
print(
f"shape-{k}-mismatch|need-{state_dict[k].shape}|get-{saved_state_dict[k].shape}"
)
if saved_state_dict[k].dim() == 2: # NOTE: check is this ok?
# for embedded input 256 <==> 768
# this achieves we can continue training from original's pretrained checkpoints when using embedder that 768-th dim output etc.
if saved_state_dict[k].dtype == torch.half:
new_state_dict[k] = (
F.interpolate(
saved_state_dict[k].float().unsqueeze(0).unsqueeze(0),
size=state_dict[k].shape,
mode="bilinear",
)
.half()
.squeeze(0)
.squeeze(0)
)
else:
new_state_dict[k] = (
F.interpolate(
saved_state_dict[k].unsqueeze(0).unsqueeze(0),
size=state_dict[k].shape,
mode="bilinear",
)
.squeeze(0)
.squeeze(0)
)
print(
"interpolated new_state_dict",
k,
"from",
saved_state_dict[k].shape,
"to",
new_state_dict[k].shape,
)
else:
raise KeyError
except Exception as e:
# print(traceback.format_exc())
print(f"{k} is not in the checkpoint")
print("error: %s" % e)
new_state_dict[k] = v # 模型自带的随机值
if hasattr(model, "module"):
model.module.load_state_dict(new_state_dict, strict=False)
else:
model.load_state_dict(new_state_dict, strict=False)
print("Loaded model weights")
epoch = checkpoint_dict["epoch"]
learning_rate = checkpoint_dict["learning_rate"]
if optimizer is not None and load_opt == 1:
optimizer.load_state_dict(checkpoint_dict["optimizer"])
print("Loaded checkpoint '{}' (epoch {})".format(checkpoint_path, epoch))
return model, optimizer, learning_rate, epoch
def save_state(model, optimizer, learning_rate, epoch, checkpoint_path):
print(
"Saving model and optimizer state at epoch {} to {}".format(
epoch, checkpoint_path
)
)
if hasattr(model, "module"):
state_dict = model.module.state_dict()
else:
state_dict = model.state_dict()
torch.save(
{
"model": state_dict,
"epoch": epoch,
"optimizer": optimizer.state_dict(),
"learning_rate": learning_rate,
},
checkpoint_path,
)
def summarize(
writer,
global_step,
scalars={},
histograms={},
images={},
audios={},
audio_sampling_rate=22050,
):
for k, v in scalars.items():
writer.add_scalar(k, v, global_step)
for k, v in histograms.items():
writer.add_histogram(k, v, global_step)
for k, v in images.items():
writer.add_image(k, v, global_step, dataformats="HWC")
for k, v in audios.items():
writer.add_audio(k, v, global_step, audio_sampling_rate)
def latest_checkpoint_path(dir_path, regex="G_*.pth"):
filelist = glob.glob(os.path.join(dir_path, regex))
if len(filelist) == 0:
return None
filelist.sort(key=lambda f: int("".join(filter(str.isdigit, f))))
filepath = filelist[-1]
return filepath
def plot_spectrogram_to_numpy(spectrogram):
fig, ax = plt.subplots(figsize=(10, 2))
im = ax.imshow(spectrogram, aspect="auto", origin="lower", interpolation="none")
plt.colorbar(im, ax=ax)
plt.xlabel("Frames")
plt.ylabel("Channels")
plt.tight_layout()
fig.canvas.draw()
data = np.fromstring(fig.canvas.tostring_rgb(), dtype=np.uint8, sep="")
data = data.reshape(fig.canvas.get_width_height()[::-1] + (3,))
plt.close()
return data
def plot_alignment_to_numpy(alignment, info=None):
fig, ax = plt.subplots(figsize=(6, 4))
im = ax.imshow(
alignment.transpose(), aspect="auto", origin="lower", interpolation="none"
)
fig.colorbar(im, ax=ax)
xlabel = "Decoder timestep"
if info is not None:
xlabel += "\n\n" + info
plt.xlabel(xlabel)
plt.ylabel("Encoder timestep")
plt.tight_layout()
fig.canvas.draw()
data = np.fromstring(fig.canvas.tostring_rgb(), dtype=np.uint8, sep="")
data = data.reshape(fig.canvas.get_width_height()[::-1] + (3,))
plt.close()
return data
def load_wav_to_torch(full_path):
sampling_rate, data = read(full_path)
return torch.FloatTensor(data.astype(np.float32)), sampling_rate
def load_config(training_dir: str, sample_rate: int, emb_channels: int):
if emb_channels == 256:
config_path = os.path.join(ROOT_DIR, "configs", f"{sample_rate}.json")
else:
config_path = os.path.join(
ROOT_DIR, "configs", f"{sample_rate}-{emb_channels}.json"
)
config_save_path = os.path.join(training_dir, "config.json")
shutil.copyfile(config_path, config_save_path)
return TrainConfig.parse_file(config_save_path)
|