import torch
import matplotlib

# Use a non-interactive backend so figures can be rendered without a display.
matplotlib.use("Agg")
import matplotlib.pylab as plt


def get_padding(k, d):
    # "Same" padding for a convolution with kernel size k and dilation d.
    return int((k * d - d) / 2)


def plot_spectrogram(spectrogram):
    """Render a spectrogram as a matplotlib Figure and return it."""
    fig, ax = plt.subplots(figsize=(10, 2))
    im = ax.imshow(spectrogram, aspect="auto", origin="lower", interpolation="none")
    plt.colorbar(im, ax=ax)
    fig.canvas.draw()
    plt.close()
    return fig


def save_checkpoint(
    checkpoint_dir,
    generator,
    discriminator,
    optimizer_generator,
    optimizer_discriminator,
    scheduler_generator,
    scheduler_discriminator,
    step,
    loss,
    best,
    logger,
):
    """Save generator and discriminator states (model, optimizer, scheduler) to disk."""
    state = {
        "generator": {
            "model": generator.state_dict(),
            "optimizer": optimizer_generator.state_dict(),
            "scheduler": scheduler_generator.state_dict(),
        },
        "discriminator": {
            "model": discriminator.state_dict(),
            "optimizer": optimizer_discriminator.state_dict(),
            "scheduler": scheduler_discriminator.state_dict(),
        },
        "step": step,
        "loss": loss,
    }
    checkpoint_dir.mkdir(exist_ok=True, parents=True)
    checkpoint_path = checkpoint_dir / f"model-{step}.pt"
    torch.save(state, checkpoint_path)
    if best:
        # Keep a separate copy of the best-performing checkpoint.
        best_path = checkpoint_dir / "model-best.pt"
        torch.save(state, best_path)
    logger.info(f"Saved checkpoint: {checkpoint_path.stem}")


def load_checkpoint(
    load_path,
    generator,
    discriminator,
    optimizer_generator,
    optimizer_discriminator,
    scheduler_generator,
    scheduler_discriminator,
    rank,
    logger,
    finetune=False,
):
    """Restore training state; with finetune=True only the model weights are loaded."""
    logger.info(f"Loading checkpoint from {load_path}")
    checkpoint = torch.load(load_path, map_location={"cuda:0": f"cuda:{rank}"})
    generator.load_state_dict(checkpoint["generator"]["model"])
    discriminator.load_state_dict(checkpoint["discriminator"]["model"])
    if not finetune:
        # Optimizer and scheduler states are only restored when resuming training.
        optimizer_generator.load_state_dict(checkpoint["generator"]["optimizer"])
        scheduler_generator.load_state_dict(checkpoint["generator"]["scheduler"])
        optimizer_discriminator.load_state_dict(
            checkpoint["discriminator"]["optimizer"]
        )
        scheduler_discriminator.load_state_dict(
            checkpoint["discriminator"]["scheduler"]
        )
    return checkpoint["step"], checkpoint["loss"]
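

# Minimal usage sketch for the helpers above. It only exercises get_padding and
# plot_spectrogram on synthetic data; save_checkpoint/load_checkpoint would need a
# full generator/discriminator training setup. The shapes used here (80 mel bins,
# a 7-tap kernel, dilation 2) are illustrative assumptions, not values taken from
# any particular model config.
if __name__ == "__main__":
    import torch.nn as nn

    # "Same" padding for a dilated 1-D convolution: kernel size 7 with dilation 2
    # gives get_padding(7, 2) == 6, so the output length matches the input length.
    conv = nn.Conv1d(80, 80, kernel_size=7, dilation=2, padding=get_padding(7, 2))
    x = torch.randn(1, 80, 100)
    assert conv(x).shape == x.shape

    # Render a random "spectrogram" (80 bins x 200 frames) and save the figure.
    fig = plot_spectrogram(torch.randn(80, 200).numpy())
    fig.savefig("example-spectrogram.png")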