| import os |
| import matplotlib.pyplot as plt |
| from torchvision.utils import make_grid |
| import torch |
|
|
def show_tensor_images(image_tensor, epoch, step, otype, num_images=25, size=(1, 28, 28),
                       nrow=5, output_dir="/outputs"):
    """Render a grid of images from ``image_tensor`` and save it to disk.

    The tensor is reshaped to ``(-1, *size)``. The first ``num_images``
    entries are tiled with ``make_grid`` (``nrow`` images per row), drawn
    with ``imshow``, and saved as ``<output_dir>/Epoch<epoch>/step_<step>_<otype>``.

    Args:
        image_tensor: Tensor of images. Assumed to hold a multiple of
            ``prod(size)`` elements (required by the reshape); it is
            detached and moved to CPU before plotting.
        epoch: Used to name the per-epoch output directory.
        step: Used in the output file name.
        otype: Label appended to the file name (presumably something like
            "real"/"fake" -- confirm with callers).
        num_images: Maximum number of images placed in the grid.
        size: Per-image shape ``(C, H, W)``.
        nrow: Number of images per grid row.
        output_dir: Root output directory (defaults to the original "/outputs").
    """
    # Values are plotted as-is (no shifting/rescaling); matplotlib clips float
    # RGB data outside [0, 1].
    images = image_tensor.detach().cpu().reshape(-1, *size)
    grid = make_grid(images[:num_images], nrow=nrow)

    epoch_dir = os.path.join(output_dir, f"Epoch{epoch}")
    # exist_ok avoids the race between an exists() check and makedirs().
    os.makedirs(epoch_dir, exist_ok=True)

    fig = plt.figure()
    try:
        # make_grid returns (C, H, W); imshow expects (H, W, C).
        plt.imshow(grid.permute(1, 2, 0).squeeze())
        # Save inside the per-epoch directory that was just created (the
        # original wrote a sibling file and never drew the grid).
        fig.savefig(os.path.join(epoch_dir, f"step_{step}_{otype}"))
    finally:
        # Always release the figure so repeated calls don't accumulate them.
        plt.close(fig)
|
|
|
|
def crop(image, new_shape):
    """Center-crop dimensions 2 and 3 of ``image`` to the sizes in ``new_shape``.

    Args:
        image: Array/tensor sliced as ``image[:, :, h, w]``. Presumably a
            torch tensor laid out as (N, C, H, W) -- confirm with callers.
        new_shape: Sequence (e.g. another tensor's ``.shape``) whose entries
            at index 2 and 3 give the target height and width.

    Returns:
        The centered ``image[:, :, top:top + new_h, left:left + new_w]`` slice.

    Raises:
        ValueError: If the target height or width is larger than the image's.
            Previously a negative start index wrapped around and silently
            produced a wrongly sized crop.
    """
    height, width = image.shape[2], image.shape[3]
    new_h, new_w = new_shape[2], new_shape[3]
    if new_h > height or new_w > width:
        raise ValueError(
            f"cannot crop spatial size ({height}, {width}) to larger size ({new_h}, {new_w})"
        )
    # Offset from the center; with new <= old this is always >= 0 and the
    # slice end never exceeds the original size.
    top = height // 2 - new_h // 2
    left = width // 2 - new_w // 2
    return image[:, :, top:top + new_h, left:left + new_w]
|
|
|
|
def get_gen_loss(gen, disc, real, condition, dev_criterion, recon_criterion, lambda_recon):
    """Compute the generator loss: adversarial term plus weighted reconstruction term.

    Args:
        gen: Callable mapping ``condition`` to a generated output.
        disc: Callable scoring ``(generated, real)``; its output is compared
            against an all-ones target.
        real: Ground-truth target the generated output is reconstructed toward.
        condition: Input passed to ``gen``.
        dev_criterion: Adversarial criterion, called as ``(scores, ones_target)``.
        recon_criterion: Reconstruction criterion, called as ``(real, generated)``.
        lambda_recon: Weight applied to the reconstruction term.

    Returns:
        ``dev_criterion(...) + lambda_recon * recon_criterion(...)``.
    """
    fake = gen(condition)
    # NOTE(review): the discriminator receives ``real`` as its second argument,
    # not ``condition``. In the pix2pix formulation D is conditioned on the
    # generator input; confirm this matches how the discriminator is trained.
    fake_scores = disc(fake, real)
    # The generator wants the discriminator to label its output as real (1).
    adversarial = dev_criterion(fake_scores, torch.ones_like(fake_scores))
    reconstruction = recon_criterion(real, fake)
    return adversarial + lambda_recon * reconstruction