CycleGAN / utils.py
Yash Nagraj
Change the display function to save instead of show
dc00587
from torchvision import transforms
import torch
import torch.nn as nn
from torchvision.utils import make_grid
from torch.utils.data import DataLoader
import matplotlib.pyplot as plt
import glob
import os
from torch.utils.data import Dataset
from PIL import Image
def show_tensor_images(image_tensor, epoch, step, num_images=25, size=(1, 28, 28), out_dir="outputs"):
    """Save a grid of images to ``<out_dir>/Epoch{epoch}_step_{step}.png``.

    Despite the historical name, nothing is displayed; the grid is written
    to disk (the file name matches what earlier versions produced).

    Args:
        image_tensor: batch of images expected in [-1, 1]; it is mapped to
            [0, 1] with ``(x + 1) / 2`` before plotting.
        epoch: used only to build the output file name.
        step: used only to build the output file name.
        num_images: how many images from the batch go into the grid
            (laid out 5 per row).
        size: per-image shape (C, H, W) used to reshape the batch.
        out_dir: directory the PNG is written into; created if missing.
    """
    images = ((image_tensor + 1) / 2).detach().cpu().view(-1, *size)
    grid = make_grid(images[:num_images], nrow=5)

    # Create the directory we actually write into (the old code created an
    # unused absolute "/outputs/Epoch{epoch}" and saved elsewhere).
    os.makedirs(out_dir, exist_ok=True)

    # Draw on a dedicated figure so we never touch/close a caller's figure.
    fig, ax = plt.subplots()
    ax.imshow(grid.permute(1, 2, 0).squeeze().numpy())
    fig.savefig(os.path.join(out_dir, f"Epoch{epoch}_step_{step}.png"))
    plt.close(fig)
class ImageDataset(Dataset):
    """Unpaired two-domain image dataset (CycleGAN style).

    Reads every file matching ``<root>/<mode>A/*.*`` (domain A) and
    ``<root>/<mode>B/*.*`` (domain B). Each item is a pair
    ``(item_A, item_B)`` of 3-channel tensors rescaled with
    ``(x - 0.5) * 2`` (so [0, 1] inputs end up in [-1, 1]).

    The dataset length is the size of the smaller domain. The smaller domain
    is indexed directly; the larger one is sampled through a random
    permutation (``self.randperm``). Domain labels are never swapped:
    ``item_A`` always comes from ``<mode>A`` and ``item_B`` from ``<mode>B``.

    Args:
        root: dataset root directory.
        transform: callable applied to each RGB PIL image, expected to return
            a (C, H, W) tensor. Defaults to ``transforms.ToTensor()``.
        mode: split prefix of the sub-directories, e.g. ``'train'``.
    """

    def __init__(self, root, transform=None, mode='train'):
        # Without a transform the PIL image could not be indexed/scaled below.
        self.transform = transform if transform is not None else transforms.ToTensor()
        self.files_A = sorted(glob.glob(os.path.join(root, f'{mode}A', '*.*')))
        self.files_B = sorted(glob.glob(os.path.join(root, f'{mode}B', '*.*')))
        assert self.files_A and self.files_B, "Make sure you downloaded the horse2zebra images!"
        self.new_perm()

    def new_perm(self):
        """Draw a fresh random subset/order of the larger domain's indices."""
        larger = max(len(self.files_A), len(self.files_B))
        self.randperm = torch.randperm(larger)[:len(self)]

    def _load(self, path):
        """Open ``path`` as RGB, apply the transform, and ensure 3 channels."""
        # convert('RGB') normalises grayscale / palette / RGBA files; the
        # context manager closes the file handle that Image.open keeps open.
        with Image.open(path) as src:
            img = src.convert('RGB')
        item = self.transform(img)
        # If the transform yields a single channel, replicate it so both
        # domains always produce 3-channel tensors.
        if item.shape[0] == 1:
            item = item.repeat(3, 1, 1)
        return item

    def __getitem__(self, index):
        # The smaller domain is indexed directly; the larger one goes through
        # the random permutation. Labels A/B are preserved either way.
        if len(self.files_A) <= len(self.files_B):
            path_A = self.files_A[index]
            path_B = self.files_B[int(self.randperm[index])]
        else:
            path_A = self.files_A[int(self.randperm[index])]
            path_B = self.files_B[index]
        item_A = self._load(path_A)
        item_B = self._load(path_B)
        # Re-draw the permutation after serving the last index. NOTE: this only
        # coincides with an epoch boundary for sequential, single-process
        # access; with shuffling it fires mid-epoch, and with DataLoader worker
        # processes only that worker's copy of the dataset is updated.
        if index == len(self) - 1:
            self.new_perm()
        # Old versions of PyTorch didn't support normalization for different-channeled images
        return (item_A - 0.5) * 2, (item_B - 0.5) * 2

    def __len__(self):
        return min(len(self.files_A), len(self.files_B))
def weights_init(m):
    """Initialise a module's parameters in place; suitable for ``nn.Module.apply``.

    - ``Conv2d`` / ``ConvTranspose2d``: weight ~ N(0.0, 0.02).
    - ``BatchNorm2d`` (affine): weight ~ N(1.0, 0.02), bias = 0.
    - Any other module is left untouched.

    The previous version had the means swapped (conv weights ~ N(1.0, 0.2)),
    which makes every conv output roughly a sum of its inputs and blows up
    activations layer after layer.
    """
    if isinstance(m, (nn.Conv2d, nn.ConvTranspose2d)):
        torch.nn.init.normal_(m.weight, 0.0, 0.02)
    # Non-affine BatchNorm has weight/bias set to None; nothing to initialise.
    if isinstance(m, nn.BatchNorm2d) and m.weight is not None:
        torch.nn.init.normal_(m.weight, 1.0, 0.02)
        torch.nn.init.constant_(m.bias, 0)
def get_disc_loss(real_X, fake_X, disc_X, adv_criterion):
    """Return the discriminator loss for one domain.

    The loss is the mean of two terms:
    ``adv_criterion(disc_X(real_X), 1)`` and ``adv_criterion(disc_X(fake_X), 0)``.

    Args:
        real_X: batch of real images from domain X.
        fake_X: batch of generated images for domain X.
        disc_X: discriminator for domain X.
        adv_criterion: adversarial loss, called as ``criterion(pred, target)``.

    Returns:
        The scalar discriminator loss (differentiable w.r.t. ``disc_X`` only).
    """
    real_pred = disc_X(real_X.detach())
    disc_real_loss = adv_criterion(real_pred, torch.ones_like(real_pred))
    # Detach the *input* so no gradient reaches the generator that produced
    # fake_X, but keep fake_pred in the graph so the discriminator learns
    # from fakes (previously fake_pred itself was detached, and the call
    # also failed on the `.deatch()` typo).
    fake_pred = disc_X(fake_X.detach())
    disc_fake_loss = adv_criterion(fake_pred, torch.zeros_like(fake_pred))
    return (disc_real_loss + disc_fake_loss) / 2
def get_gen_adversarial_loss(real_X, disc_Y, gen_XY, adv_criterion):
    """Adversarial loss for the X -> Y generator.

    Translates ``real_X`` with ``gen_XY``, scores the result with ``disc_Y``
    and measures, with ``adv_criterion``, how far the scores are from the
    "real" label (all ones).

    Returns:
        ``(adversarial_loss, fake_Y)`` where ``fake_Y`` is the generated batch.
    """
    fake_Y = gen_XY(real_X.detach())
    scores = disc_Y(fake_Y)
    real_label = torch.ones_like(scores)
    return adv_criterion(scores, real_label), fake_Y
def get_identity_loss(real_X, gen_YX, identity_criterion):
    """Identity loss: feeding ``gen_YX`` an image already in X should return it unchanged.

    Returns:
        ``(identity_loss, identity_X)`` where ``identity_X = gen_YX(real_X)``.
    """
    passed_through = gen_YX(real_X)
    return identity_criterion(passed_through, real_X), passed_through
def get_cycle_consistency_loss(real_X, fake_Y, gen_YX, cycle_criterion):
    """Cycle-consistency loss: mapping ``fake_Y`` back with ``gen_YX`` should recover ``real_X``.

    Returns:
        ``(cycle_loss, cycle_X)`` where ``cycle_X = gen_YX(fake_Y)``.
    """
    reconstructed_X = gen_YX(fake_Y)
    return cycle_criterion(reconstructed_X, real_X), reconstructed_X
def get_gen_loss(
    real_A,
    real_B,
    gen_AB,
    gen_BA,
    disc_B,
    disc_A,
    adv_criterion,
    cycle_criterion,
    identity_criterion,
    lambda_identity=0.2,
    lambda_cycle=10,
):
    """Combined generator loss for both translation directions.

    ``total = lambda_identity * identity + lambda_cycle * cycle + adversarial``,
    where each term sums the A-side and B-side contributions.

    Returns:
        ``(total_loss, fake_A, fake_B)`` with ``fake_A = gen_BA(real_B)`` and
        ``fake_B = gen_AB(real_A)``.
    """
    # Adversarial: B -> A is judged by disc_A, A -> B by disc_B.
    adv_to_A, fake_A = get_gen_adversarial_loss(real_B, disc_A, gen_BA, adv_criterion)
    adv_to_B, fake_B = get_gen_adversarial_loss(real_A, disc_B, gen_AB, adv_criterion)
    adversarial = adv_to_A + adv_to_B

    # Identity: each generator applied to an image already in its output domain.
    id_A, _ = get_identity_loss(real_A, gen_BA, identity_criterion)
    id_B, _ = get_identity_loss(real_B, gen_AB, identity_criterion)
    identity = id_A + id_B

    # Cycle: A -> B -> A and B -> A -> B should reconstruct the originals.
    cyc_A, _ = get_cycle_consistency_loss(real_A, fake_B, gen_BA, cycle_criterion)
    cyc_B, _ = get_cycle_consistency_loss(real_B, fake_A, gen_AB, cycle_criterion)
    cycle = cyc_A + cyc_B

    total = lambda_identity * identity + lambda_cycle * cycle + adversarial
    return total, fake_A, fake_B