File size: 1,437 Bytes
2d55a06
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
from torchvision.utils import save_image
import os
import torch
import config

def denormalize(imgs, mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5)):
    """Undo per-channel normalization: return ``imgs * std + mean``.

    This is the inverse of ``(x - mean) / std`` applied channel-wise. The
    input tensor is left untouched; a new tensor is returned.

    Args:
        imgs: 4-D tensor indexed as (batch, channel, ...). Expected to be a
            floating-point tensor.
        mean: Per-channel means. Its length decides how many leading
            channels are denormalized.
        std: Per-channel standard deviations, same length as ``mean``.

    Returns:
        A new tensor with the same shape, dtype and device as ``imgs``.
        Channels beyond ``len(mean)`` are copied through unchanged.
    """
    # Shape the statistics as (1, C, 1, 1) so they broadcast over batch and
    # spatial dimensions in one vectorized expression.
    mean_t = torch.as_tensor(mean, dtype=imgs.dtype, device=imgs.device).view(1, -1, 1, 1)
    std_t = torch.as_tensor(std, dtype=imgs.dtype, device=imgs.device).view(1, -1, 1, 1)
    n_channels = mean_t.shape[1]

    # Work on a copy: writing into `imgs` would silently alter the caller's
    # batch and make a second call denormalize twice.
    out = imgs.clone()
    out[:, :n_channels] = imgs[:, :n_channels] * std_t + mean_t
    return out

def save_val_predictions(gen, val_loader, epoch, folder_path):
    """Save target/prediction comparison images for one validation batch.

    Takes the first batch from ``val_loader``, runs ``gen`` on its inputs
    with gradients disabled, denormalizes both the targets and the
    predictions, concatenates them along dim 2 and writes one PNG per
    sample named ``image_{i}_epoch_{epoch}.png`` into ``folder_path``.

    Args:
        gen: Generator network; expected to be a ``torch.nn.Module``
            (it is called, and switched between eval and train mode).
        val_loader: Iterable yielding ``(input, target)`` batches.
        epoch: Value embedded in the output file names.
        folder_path: Directory receiving the images; created if missing.
    """
    x, y = next(iter(val_loader))
    x, y_real = x.to(config.device), y.to(config.device)

    # save_image does not create missing directories. An empty path means
    # the current directory, which makedirs would reject.
    if folder_path:
        os.makedirs(folder_path, exist_ok=True)

    # Remember the current mode so it can be restored even if inference or
    # saving fails, instead of unconditionally forcing train mode.
    was_training = gen.training
    gen.eval()
    try:
        with torch.no_grad():
            y_fake = gen(x)
            y_real = denormalize(y_real)
            y_fake = denormalize(y_fake)
            # dim=2 is presumably image height for (N, C, H, W) batches, which
            # would place the target above the prediction — confirm with caller.
            concat_imgs = torch.cat([y_real, y_fake], dim=2)
            for i, concat_img in enumerate(concat_imgs):
                save_image(concat_img, os.path.join(folder_path, f"image_{i}_epoch_{epoch}.png"))
    finally:
        gen.train(was_training)

def save_loss(loss, epoch, file_path):
    """Append one ``"<epoch> <loss>"`` line to the text log at ``file_path``.

    The file is opened in append mode, so it is created on first use and
    earlier records are kept.
    """
    with open(file_path, "a") as log_file:
        record = f'{epoch} {loss}\n'
        log_file.write(record)

def save_checkpoint(net, optimizer, path):
    """Write the state dicts of ``net`` and ``optimizer`` to ``path``.

    The saved object is a dict with the keys ``'net'`` and ``'optimizer'``,
    serialized with ``torch.save``. A confirmation line is printed afterwards.
    """
    state = {}
    state['net'] = net.state_dict()
    state['optimizer'] = optimizer.state_dict()
    torch.save(state, path)
    print(f'Save model into {path} successfully')

def load_checkpoint(net, path, optimizer=None):
    """Restore ``net`` (and optionally ``optimizer``) from a checkpoint file.

    The file at ``path`` is read with ``torch.load``, mapping storages to
    ``config.device``. Its ``'net'`` entry is loaded into ``net``; when an
    optimizer is supplied, its ``'optimizer'`` entry is loaded as well.
    A confirmation line is printed afterwards.

    NOTE(review): ``torch.load`` may unpickle arbitrary objects depending on
    the installed torch version — only load checkpoints from trusted sources.
    """
    state = torch.load(path, map_location=config.device)
    net.load_state_dict(state['net'])

    restore_optimizer = optimizer is not None
    if restore_optimizer:
        optimizer.load_state_dict(state['optimizer'])

    print(f'Load model from {path} successfully')