File size: 1,335 Bytes
52c73e3
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
import os
import matplotlib.pyplot as plt
from torchvision.utils import make_grid
import torch

def show_tensor_images(image_tensor, epoch, step, otype, num_images=25, size=(1, 28, 28),
                       out_dir="/outputs"):
    """Save a grid of images taken from a batch tensor to an image file.

    The tensor is detached, moved to CPU and viewed as ``(-1, *size)``. The
    first ``num_images`` images are then tiled 5 per row with
    ``torchvision.utils.make_grid``. The grid is written to
    ``<out_dir>/Epoch<epoch>/Epoch<epoch>_step_<step>_<otype>``. Since the name
    has no extension, matplotlib uses ``rcParams["savefig.format"]`` and appends
    the matching extension.

    Args:
        image_tensor: tensor whose elements can be viewed as ``(-1, *size)``.
        epoch: epoch index; selects the per-epoch output directory.
        step: training step; embedded in the file name.
        otype: free-form tag embedded in the file name.
        num_images: maximum number of images placed in the grid.
        size: per-image shape used to un-flatten the tensor.
        out_dir: root directory under which per-epoch folders are created.
    """
    images = image_tensor.detach().cpu().view(-1, *size)
    image_grid = make_grid(images[:num_images], nrow=5)

    epoch_dir = os.path.join(out_dir, f"Epoch{epoch}")
    # exist_ok avoids the check-then-create race of os.path.exists + makedirs.
    os.makedirs(epoch_dir, exist_ok=True)

    # Draw on a dedicated figure so the caller's current figure is never saved
    # or closed by accident.
    fig, ax = plt.subplots()
    # make_grid returns (C, H, W); imshow expects channels last.
    # NOTE(review): imshow clips float RGB data to [0, 1]. If the generator
    # emits values in [-1, 1], a rescale such as (x + 1) / 2 may be needed.
    # Confirm against the model's output range.
    ax.imshow(image_grid.permute(1, 2, 0))
    fig.savefig(os.path.join(epoch_dir, f"Epoch{epoch}_step_{step}_{otype}"))
    plt.close(fig)


def crop(image, new_shape):
    """Center-crop the spatial dimensions of a 4-D tensor.

    Only dimensions 2 and 3 (height and width, as indexed below) are cropped.
    Dimensions 0 and 1 are kept whole. Only ``new_shape[2]`` and
    ``new_shape[3]`` are read.

    Args:
        image: tensor indexed as ``image[:, :, h, w]``.
        new_shape: sequence (e.g. a ``torch.Size``) whose items 2 and 3 give
            the target height and width.

    Returns:
        A view of ``image`` with spatial size ``new_shape[2] x new_shape[3]``.

    Raises:
        ValueError: if the requested height or width exceeds the input's.
            Otherwise the start index would go negative and slicing would
            silently wrap around, returning a wrongly sized tensor.
    """
    target_h, target_w = new_shape[2], new_shape[3]
    if target_h > image.shape[2] or target_w > image.shape[3]:
        raise ValueError(
            f"cannot crop spatial size {tuple(image.shape[2:])} "
            f"to larger size {(target_h, target_w)}"
        )

    # Anchor the window on the integer midpoint of each axis (kept exactly as
    # before so results for valid sizes are unchanged).
    top = image.shape[2] // 2 - target_h // 2
    left = image.shape[3] // 2 - target_w // 2
    return image[:, :, top:top + target_h, left:left + target_w]


def get_gen_loss(gen, disc, real, condition, dev_criterion, recon_criterion, lambda_recon):
    """Compute the generator loss: adversarial term plus weighted reconstruction.

    Args:
        gen: generator; called as ``gen(condition)``.
        disc: discriminator; called as ``disc(generated, real)``.
        real: target images the generator should reproduce.
        condition: generator input.
        dev_criterion: adversarial loss, called as ``(prediction, target)``;
            compared against all-ones labels, i.e. "discriminator says real".
        recon_criterion: reconstruction loss, called as ``(prediction, target)``.
        lambda_recon: weight applied to the reconstruction term.

    Returns:
        ``dev_criterion(...) + lambda_recon * recon_criterion(...)``.
    """
    gen_pred = gen(condition)
    # NOTE(review): pix2pix-style discriminators are usually conditioned on the
    # generator input (disc(gen_pred, condition)). Confirm this matches how the
    # discriminator is trained.
    disc_pred = disc(gen_pred, real)
    gen_adv_loss = dev_criterion(disc_pred, torch.ones_like(disc_pred))
    # PyTorch losses take (input, target). The order matters for asymmetric
    # losses even though L1/MSE are symmetric.
    gen_recon_loss = recon_criterion(gen_pred, real)
    return gen_adv_loss + lambda_recon * gen_recon_loss