from torchvision import transforms import torch import torch.nn as nn from torchvision.utils import make_grid from torch.utils.data import DataLoader import matplotlib.pyplot as plt import glob import os from torch.utils.data import Dataset from PIL import Image def show_tensor_images(image_tensor, epoch,step,num_images=25, size=(1, 28, 28)): image_tensor = (image_tensor + 1) / 2 image_shifted = image_tensor image_unflat = image_shifted.detach().cpu().view(-1, *size) image_grid = make_grid(image_unflat[:num_images], nrow=5) if not os.path.exists(f"/outputs/Epoch{epoch}"): os.makedirs(f"/outputs/Epoch{epoch}") plt.imshow(image_grid.permute(1, 2, 0).squeeze()) plt.savefig(os.path.join(f"outputs/Epoch{epoch}_step_{step}")) plt.close() class ImageDataset(Dataset): def __init__(self, root, transform=None, mode='train'): self.transform = transform self.files_A = sorted(glob.glob(os.path.join(root, '%sA' % mode) + '/*.*')) self.files_B = sorted(glob.glob(os.path.join(root, '%sB' % mode) + '/*.*')) if len(self.files_A) > len(self.files_B): self.files_A, self.files_B = self.files_B, self.files_A self.new_perm() assert len(self.files_A) > 0, "Make sure you downloaded the horse2zebra images!" def new_perm(self): self.randperm = torch.randperm(len(self.files_B))[:len(self.files_A)] def __getitem__(self, index): item_A = self.transform(Image.open(self.files_A[index % len(self.files_A)])) item_B = self.transform(Image.open(self.files_B[self.randperm[index]])) if item_A.shape[0] != 3: item_A = item_A.repeat(3, 1, 1) if item_B.shape[0] != 3: item_B = item_B.repeat(3, 1, 1) if index == len(self) - 1: self.new_perm() # Old versions of PyTorch didn't support normalization for different-channeled images return (item_A - 0.5) * 2, (item_B - 0.5) * 2 def __len__(self): return min(len(self.files_A), len(self.files_B)) def weights_init(m): if isinstance(m,nn.Conv2d) or isinstance(m,nn.ConvTranspose2d): torch.nn.init.normal_(m.weight,1.0,0.2) if isinstance(m, nn.BatchNorm2d): torch.nn.init.normal_(m.weight, 0.0, 0.02) torch.nn.init.constant_(m.bias, 0) def get_disc_loss(real_X, fake_X,disc_X, adv_criterion): real_pred = disc_X(real_X.detach()) disc_real_loss = adv_criterion(real_pred,torch.ones_like(real_pred)) fake_pred = disc_X(fake_X.deatch()) disc_fake_loss = adv_criterion(fake_pred.detach(),torch.zeros_like(fake_pred)) disc_loss = (disc_real_loss + disc_fake_loss) / 2 return disc_loss def get_gen_adversarial_loss(real_X, disc_Y, gen_XY, adv_criterion): fake_Y = gen_XY(real_X.detach()) disc_pred = disc_Y(fake_Y) adverserial_loss = adv_criterion(disc_pred,torch.ones_like(disc_pred)) return adverserial_loss,fake_Y def get_identity_loss(real_X, gen_YX,identity_criterion): identity_X = gen_YX(real_X) identity_loss = identity_criterion(identity_X,real_X) return identity_loss,identity_X def get_cycle_consistency_loss(real_X, fake_Y, gen_YX, cycle_criterion): cycle_X = gen_YX(fake_Y) cycle_loss = cycle_criterion(cycle_X,real_X) return cycle_loss,cycle_X def get_gen_loss(real_A, real_B,gen_AB,gen_BA,disc_B,disc_A,adv_criterion,cycle_criterion,identity_criterion,lambda_identity=0.2,lambda_cycle=10): adv_loss_BA, fake_A = get_gen_adversarial_loss(real_B, disc_A, gen_BA, adv_criterion) adv_loss_AB, fake_B = get_gen_adversarial_loss(real_A, disc_B, gen_AB, adv_criterion) gen_adversarial_loss = adv_loss_BA + adv_loss_AB # Identity Loss -- get_identity_loss(real_X, gen_YX, identity_criterion) identity_loss_A, identity_A = get_identity_loss(real_A, gen_BA, identity_criterion) identity_loss_B, identity_B = get_identity_loss(real_B, gen_AB, identity_criterion) gen_identity_loss = identity_loss_A + identity_loss_B # Cycle-consistency Loss -- get_cycle_consistency_loss(real_X, fake_Y, gen_YX, cycle_criterion) cycle_loss_BA, cycle_A = get_cycle_consistency_loss(real_A, fake_B, gen_BA, cycle_criterion) cycle_loss_AB, cycle_B = get_cycle_consistency_loss(real_B, fake_A, gen_AB, cycle_criterion) gen_cycle_loss = cycle_loss_BA + cycle_loss_AB # Total loss gen_loss = lambda_identity * gen_identity_loss + lambda_cycle * gen_cycle_loss + gen_adversarial_loss return gen_loss , fake_A,fake_B