# Pix2PIx / utils.py
# Yash Nagraj
# Add train code with dataset.sh
# 52c73e3
import os
import matplotlib.pyplot as plt
from torchvision.utils import make_grid
import torch
def show_tensor_images(image_tensor, epoch, step, otype, num_images=25,
                       size=(1, 28, 28), output_dir="/outputs"):
    """Save a grid of the first ``num_images`` images of ``image_tensor`` as a PNG.

    The tensor is detached, moved to CPU, reshaped to ``(-1, *size)`` and tiled
    into a 5-column grid with ``make_grid``. Values are plotted as-is (no
    de-normalization is applied). The figure is written to
    ``{output_dir}/Epoch{epoch}/step_{step}_{otype}.png``; the epoch directory
    is created if needed.

    Args:
        image_tensor: tensor whose element count is a multiple of prod(size).
        epoch: epoch number; selects the ``Epoch{epoch}`` sub-directory.
        step: training step, embedded in the file name.
        otype: free-form tag for the image set, embedded in the file name.
        num_images: maximum number of images placed in the grid.
        size: per-image shape used to un-flatten the tensor.
        output_dir: root directory for saved figures (default ``/outputs``).
    """
    # reshape (not view) also handles non-contiguous tensors that are already on CPU.
    image_unflat = image_tensor.detach().cpu().reshape(-1, *size)
    image_grid = make_grid(image_unflat[:num_images], nrow=5)

    epoch_dir = os.path.join(output_dir, f"Epoch{epoch}")
    # exist_ok avoids the check-then-create race of os.path.exists + makedirs.
    os.makedirs(epoch_dir, exist_ok=True)

    fig, ax = plt.subplots()
    # make_grid returns (C, H, W); imshow expects channels last.
    ax.imshow(image_grid.permute(1, 2, 0).squeeze())
    # Explicit extension so the format is not inferred from a dot inside `otype`.
    fig.savefig(os.path.join(epoch_dir, f"step_{step}_{otype}.png"))
    plt.close(fig)
def crop(image, new_shape):
    """Center-crop the two trailing spatial axes of ``image`` to ``new_shape``.

    Only ``new_shape[2]`` (height) and ``new_shape[3]`` (width) are read; the
    first two axes of ``image`` are kept whole. On each cropped axis the window
    starts at ``size // 2 - target // 2`` and spans ``target`` elements.

    Args:
        image: tensor indexed with four indices, presumably
            (batch, channels, height, width) — confirm with callers.
        new_shape: sequence whose entries 2 and 3 give the target height/width.

    Returns:
        The center-cropped slice of ``image``.
    """
    def window(full, target):
        # Window of length `target` positioned around the midpoint `full // 2`.
        start = full // 2 - target // 2
        return slice(start, start + target)

    rows = window(image.shape[2], new_shape[2])
    cols = window(image.shape[3], new_shape[3])
    return image[:, :, rows, cols]
def get_gen_loss(gen, disc, real, condition, dev_criterion, recon_criterion, lambda_recon):
    """Return the Pix2Pix generator loss: adversarial + weighted reconstruction.

    Args:
        gen: generator, called as ``gen(condition)``.
        disc: conditional discriminator, called as ``disc(image, condition)``.
        real: target images the generator should reproduce.
        condition: input images fed to the generator; also used to condition
            the discriminator.
        dev_criterion: adversarial criterion, called as
            ``dev_criterion(prediction, target)``. The parameter name is kept
            for backward compatibility with keyword callers.
        recon_criterion: reconstruction criterion, called as
            ``recon_criterion(prediction, target)``.
        lambda_recon: weight applied to the reconstruction term.

    Returns:
        ``adversarial_loss + lambda_recon * reconstruction_loss``.
    """
    fake = gen(condition)
    # As in Pix2Pix, D is conditioned on the generator's *input*; conditioning
    # it on `real` would hand the discriminator the target image.
    disc_fake_pred = disc(fake, condition)
    # The generator wants its output classified as real (label 1).
    gen_adv_loss = dev_criterion(disc_fake_pred, torch.ones_like(disc_fake_pred))
    # PyTorch losses take (input, target); identical for symmetric L1/MSE.
    gen_recon_loss = recon_criterion(fake, real)
    return gen_adv_loss + lambda_recon * gen_recon_loss