Spaces:
Runtime error
Runtime error
| import torch | |
| import config | |
| from torchvision.utils import save_image | |
| def save_some_examples(gen, val_loader, epoch, folder): | |
| x, y = next(iter(val_loader)) | |
| x, y = x.to(config.DEVICE), y.to(config.DEVICE) | |
| gen.eval() | |
| with torch.no_grad(): | |
| y_fake = gen(x) | |
| y_fake = y_fake * 0.5 + 0.5 # remove normalization# | |
| save_image(y_fake, folder + f"/y_gen_{epoch}.png") | |
| save_image(x * 0.5 + 0.5, folder + f"/input_{epoch}.png") | |
| if epoch == 1: | |
| save_image(y * 0.5 + 0.5, folder + f"/label_{epoch}.png") | |
| gen.train() | |
| def save_checkpoint(model, optimizer, filename="my_checkpoint.pth.tar"): | |
| print("=> Saving checkpoint") | |
| checkpoint = { | |
| "state_dict": model.state_dict(), | |
| "optimizer": optimizer.state_dict(), | |
| } | |
| torch.save(checkpoint, filename) | |
| def load_checkpoint(checkpoint_file, model, optimizer, lr): | |
| print("=> Loading checkpoint") | |
| checkpoint = torch.load(checkpoint_file, map_location=config.DEVICE) | |
| model.load_state_dict(checkpoint["state_dict"]) | |
| optimizer.load_state_dict(checkpoint["optimizer"]) | |
| # If we don't do this then it will just have learning rate of old checkpoint | |
| # and it will lead to many hours of debugging \: | |
| for param_group in optimizer.param_groups: | |
| param_group["lr"] = lr | |