Spaces:
Build error
Build error
File size: 1,437 Bytes
2d55a06 | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 | from torchvision.utils import save_image
import os
import torch
import config
def denormalize(imgs, mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5)):
    """Undo per-channel normalization of a batch of images, in place.

    Channel ``c`` is mapped back via ``x * std[c] + mean[c]``, which inverts
    ``(x - mean[c]) / std[c]``.

    Args:
        imgs: Image batch indexed as ``imgs[:, c, :, :]`` -- assumed to be
            (N, C, H, W); TODO confirm against callers. Modified in place.
        mean: Per-channel means; one entry per channel to restore. A tuple
            default is used so the default cannot be mutated across calls.
        std: Per-channel standard deviations, same length as ``mean``.

    Returns:
        The same ``imgs`` tensor object, after in-place denormalization.

    Raises:
        ValueError: If ``mean`` and ``std`` have different lengths.
    """
    if len(mean) != len(std):
        raise ValueError(
            f"mean and std must have the same length, got {len(mean)} and {len(std)}"
        )
    # Iterate over as many channels as statistics were supplied instead of a
    # hard-coded 3, so single-channel or 4-channel images work too.
    for c, (m, s) in enumerate(zip(mean, std)):
        imgs[:, c, :, :] = imgs[:, c, :, :] * s + m
    return imgs
def save_val_predictions(gen, val_loader, epoch, folder_path):
    """Save target/generated image pairs for the first validation batch.

    Takes the first batch from ``val_loader`` and runs ``gen`` on its inputs
    under ``torch.no_grad()``. It then writes one PNG per sample to
    ``folder_path``, named ``image_{i}_epoch_{epoch}.png``.

    Each PNG contains the target image and the generated image, both passed
    through ``denormalize`` and concatenated along dim 2 of the batch tensor
    (height, if the batch is (N, C, H, W) -- TODO confirm).

    Args:
        gen: Generator module; called as ``gen(x)``. Its train/eval mode is
            restored to whatever it was on entry.
        val_loader: Iterable yielding ``(input, target)`` batches.
        epoch: Epoch number embedded in the output filenames.
        folder_path: Output directory; created if it does not exist.
    """
    inputs, targets = next(iter(val_loader))
    x, y_real = inputs.to(config.device), targets.to(config.device)
    os.makedirs(folder_path, exist_ok=True)
    # Remember the caller's mode so it is restored exactly (instead of forcing
    # train mode), and restore it even if inference or saving raises.
    was_training = gen.training
    gen.eval()
    try:
        with torch.no_grad():
            y_fake = gen(x)
            # Undo the default 0.5/0.5 normalization before writing images.
            y_real = denormalize(y_real)
            y_fake = denormalize(y_fake)
            concat_imgs = torch.cat([y_real, y_fake], dim=2)
        for i, concat_img in enumerate(concat_imgs):
            save_image(concat_img, os.path.join(folder_path, f"image_{i}_epoch_{epoch}.png"))
    finally:
        gen.train(was_training)
def save_loss(loss, epoch, file_path):
    """Append a single ``"<epoch> <loss>"`` line to the text file ``file_path``.

    The file is opened in append mode, so it is created if missing and
    earlier lines are preserved.
    """
    record = f'{epoch} {loss}\n'
    with open(file_path, mode="a") as log_file:
        log_file.write(record)
def save_checkpoint(net, optimizer, path):
    """Serialize model and optimizer state to ``path`` with ``torch.save``.

    The saved object is a dict with key ``'net'`` (``net.state_dict()``) and
    key ``'optimizer'`` (``optimizer.state_dict()``). These are the keys
    that ``load_checkpoint`` reads back. A confirmation message is printed
    afterwards.
    """
    model_state = net.state_dict()
    optim_state = optimizer.state_dict()
    torch.save({'net': model_state, 'optimizer': optim_state}, path)
    print(f'Save model into {path} successfully')
def load_checkpoint(net, path, optimizer=None):
    """Restore model state, and optionally optimizer state, from ``path``.

    Tensors are mapped onto ``config.device`` while loading. The ``'net'``
    entry is applied to ``net``. When ``optimizer`` is given, the
    ``'optimizer'`` entry is applied to it as well. A confirmation message
    is printed at the end.
    """
    # NOTE(review): torch.load is pickle-based. With weights_only=False (the
    # default before PyTorch 2.6) it can execute arbitrary code, so only load
    # checkpoints from trusted sources.
    checkpoint = torch.load(path, map_location=config.device)
    restore_plan = [(net, 'net')]
    if optimizer is not None:
        restore_plan.append((optimizer, 'optimizer'))
    for target, key in restore_plan:
        target.load_state_dict(checkpoint[key])
    print(f'Load model from {path} successfully')