cjayic's picture
init
f4b9544
import os
from pathlib import Path

import torch
import matplotlib

matplotlib.use("Agg")
import matplotlib.pylab as plt
def get_padding(k, d):
    """Return the symmetric padding for a stride-1 conv with kernel *k* and dilation *d*.

    A dilated kernel covers ``k * d - d + 1`` samples, so padding each side by
    half of ``k * d - d`` preserves the sequence length whenever that quantity
    is even (e.g. odd *k*). The half is truncated toward zero via ``int``.
    """
    span = k * d - d
    return int(span / 2)
def plot_spectrogram(spectrogram):
    """Render *spectrogram* as a heat map with a colour bar and return the figure.

    The figure is 10x2 inches, drawn with ``origin="lower"`` (row 0 at the
    bottom) and no interpolation. It is drawn once and then closed in pyplot
    before being returned.
    """
    figure, axes = plt.subplots(figsize=(10, 2))
    image = axes.imshow(
        spectrogram,
        aspect="auto",
        origin="lower",
        interpolation="none",
    )
    figure.colorbar(image, ax=axes)
    # Render eagerly; presumably callers read the drawn canvas (e.g. to log an
    # image) after pyplot has let go of the figure — confirm against callers.
    figure.canvas.draw()
    # Drop pyplot's reference so repeated calls do not accumulate open figures.
    plt.close(figure)
    return figure
def _atomic_torch_save(obj, path):
    """Write *obj* to *path* with ``torch.save`` without ever exposing a partial file.

    The data is serialized to a sibling ``<name>.tmp`` file and then moved onto
    *path* with ``os.replace`` (atomic on the same filesystem), so an interrupted
    save leaves any previous file at *path* intact rather than truncated.
    """
    tmp_path = path.with_name(path.name + ".tmp")
    try:
        torch.save(obj, tmp_path)
        os.replace(tmp_path, path)
    except BaseException:
        # Don't leave a half-written temp file behind; re-raise the original error.
        if tmp_path.exists():
            tmp_path.unlink()
        raise


def save_checkpoint(
    checkpoint_dir,
    generator,
    discriminator,
    optimizer_generator,
    optimizer_discriminator,
    scheduler_generator,
    scheduler_discriminator,
    step,
    loss,
    best,
    logger,
):
    """Save generator/discriminator training state to ``checkpoint_dir``.

    The checkpoint is a dict with ``generator`` and ``discriminator`` sections
    (each holding ``model``, ``optimizer`` and ``scheduler`` state dicts) plus
    ``step`` and ``loss``. It is written to ``model-{step}.pt`` and, when
    *best* is truthy, also to ``model-best.pt``. Each file is written
    atomically, so a crash mid-save cannot corrupt an existing checkpoint.

    Args:
        checkpoint_dir: target directory (``str`` or path-like); created,
            including parents, if missing.
        generator, discriminator: modules whose ``state_dict()`` is saved.
        optimizer_*, scheduler_*: objects whose ``state_dict()`` is saved.
        step: stored in the checkpoint and used in the file name.
        loss: stored in the checkpoint as-is.
        best: if truthy, also write ``model-best.pt``.
        logger: receives an info message naming the saved checkpoint.
    """
    state = {
        "generator": {
            "model": generator.state_dict(),
            "optimizer": optimizer_generator.state_dict(),
            "scheduler": scheduler_generator.state_dict(),
        },
        "discriminator": {
            "model": discriminator.state_dict(),
            "optimizer": optimizer_discriminator.state_dict(),
            "scheduler": scheduler_discriminator.state_dict(),
        },
        "step": step,
        "loss": loss,
    }
    # Accept plain strings as well as Path objects.
    checkpoint_dir = Path(checkpoint_dir)
    checkpoint_dir.mkdir(exist_ok=True, parents=True)
    checkpoint_path = checkpoint_dir / f"model-{step}.pt"
    _atomic_torch_save(state, checkpoint_path)
    if best:
        _atomic_torch_save(state, checkpoint_dir / "model-best.pt")
    logger.info(f"Saved checkpoint: {checkpoint_path.stem}")
def load_checkpoint(
    load_path,
    generator,
    discriminator,
    optimizer_generator,
    optimizer_discriminator,
    scheduler_generator,
    scheduler_discriminator,
    rank,
    logger,
    finetune=False,
):
    """Restore a generator/discriminator checkpoint in place.

    Args:
        load_path: file read with ``torch.load``. It must contain
            ``generator`` and ``discriminator`` sections (each with ``model``,
            and unless *finetune*, ``optimizer`` and ``scheduler`` entries)
            plus ``step`` and ``loss``.
        generator, discriminator: modules whose weights are always restored.
        optimizer_*, scheduler_*: restored only when *finetune* is false.
        rank: storages saved on ``cuda:0`` are remapped to ``cuda:{rank}``;
            presumably the local GPU rank in multi-GPU runs — confirm with callers.
        logger: receives an info message naming *load_path*.
        finetune: if true, load model weights only and leave the optimizers
            and schedulers in their current state.

    Returns:
        ``(step, loss)`` exactly as stored in the checkpoint.
    """
    logger.info(f"Loading checkpoint from {load_path}")
    # Only the "cuda:0" location is remapped; any other location loads as saved.
    device_map = {"cuda:0": f"cuda:{rank}"}
    checkpoint = torch.load(load_path, map_location=device_map)

    for network, section in ((generator, "generator"), (discriminator, "discriminator")):
        network.load_state_dict(checkpoint[section]["model"])

    if not finetune:
        # Restore order: generator optimizer, generator scheduler,
        # discriminator optimizer, discriminator scheduler.
        training_state = (
            (optimizer_generator, "generator", "optimizer"),
            (scheduler_generator, "generator", "scheduler"),
            (optimizer_discriminator, "discriminator", "optimizer"),
            (scheduler_discriminator, "discriminator", "scheduler"),
        )
        for target, section, key in training_state:
            target.load_state_dict(checkpoint[section][key])

    return checkpoint["step"], checkpoint["loss"]