import os import matplotlib.pyplot as plt from torchvision.utils import make_grid import torch def show_tensor_images(image_tensor, epoch,step,otype,num_images=25, size=(1, 28, 28)): image_shifted = image_tensor image_unflat = image_shifted.detach().cpu().view(-1, *size) image_grid = make_grid(image_unflat[:num_images], nrow=5) if not os.path.exists(f"/outputs/Epoch{epoch}"): os.makedirs(f"/outputs/Epoch{epoch}") plt.savefig(os.path.join(f"/outputs/Epoch{epoch}_step_{step}_{otype}")) plt.close() def crop(image, new_shape): middle_height = image.shape[2] // 2 middle_width = image.shape[3] // 2 starting_height = middle_height - new_shape[2] // 2 final_height = starting_height + new_shape[2] starting_width = middle_width - new_shape[3] // 2 final_width = starting_width + new_shape[3] cropped_image = image[:, :, starting_height:final_height, starting_width:final_width] return cropped_image def get_gen_loss(gen, disc, real, condition, dev_criterion, recon_criterion, lambda_recon): gen_pred = gen(condition) disc_pred = disc(gen_pred, real) gen_adv_loss = dev_criterion(disc_pred, torch.ones_like(disc_pred)) gen_recon_loss = recon_criterion(real,gen_pred) gen_loss = gen_adv_loss + lambda_recon * gen_recon_loss return gen_loss