import pandas as pd
import numpy as np
import torch
import torch.nn as nn
import torchvision
from torchvision import transforms
from torch.utils.data import Dataset, DataLoader
import os
import random
from PIL import Image
import torchvision.transforms.functional as F
import matplotlib.pyplot as plt
from tqdm.notebook import tqdm
import itertools
from torch import autograd
import torch.distributed as dist
import torch.multiprocessing as mp
from torch.nn.parallel import DistributedDataParallel as DDP

from imagePool import ImagePool
from resnetGen import ResnetGenerator
from nlayerDis import NLayerDiscriminator

# Fall-back device for single-process (non-DDP) runs.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

"""
Create Dataset and Dataloader
"""
import sys
sys.path.append("/iitjhome/m23csa016")

TRAIN_LABELS = "Assignment_4/Train/Train_labels.csv"
TEST_LABELS = "Assignment_4/Test/Test_Labels.csv"
TRAIN_DATA_DIR = "Assignment_4/Train/Train_data"
TEST_DATA_DIR = "Assignment_4/Test/Test"
TRAIN_SKETCH_DIR = "Assignment_4/Train/Contours"
TEST_SKETCH_DIR = "Assignment_4/Test/Test_contours"

# File extensions accepted when sampling a random contour sketch.
_SKETCH_EXTS = ('.jpg', '.jpeg', '.png')


class ISICDataset(Dataset):
    """ISIC lesion dataset pairing each image with a random contour sketch.

    Each item is ``(label, image, sketch)`` where ``label`` is the row's
    7 class values tiled into an (H, W) float map so it can be added to
    the image/sketch tensors directly.
    """

    def __init__(self, datadir, csvpath, sketchdir, transform=None):
        self.datadir = datadir
        self.csv = pd.read_csv(csvpath)
        self.sketchdir = sketchdir
        self.transform = transform

    def __len__(self):
        # NOTE: deliberately capped at the first 300 rows of the CSV.
        return len(self.csv[:300])

    def __getitem__(self, index):
        img_path = os.path.join(self.datadir, self.csv.iloc[index, 0] + ".jpg")
        # Force 3 channels so RGBA/palette files cannot break the generators.
        image = Image.open(img_path).convert("RGB")
        labels = self.csv.iloc[index, 1:].values

        # Sample a random sketch, filtering non-image files in one pass.
        # (The original retry loop could spin forever on a dir with no
        # images; an empty-candidate IndexError here fails fast instead.)
        candidates = [
            f for f in os.listdir(self.sketchdir)
            if os.path.splitext(f)[1].lower() in _SKETCH_EXTS
        ]
        sketch_path = os.path.join(self.sketchdir, random.choice(candidates))
        # Single channel: the training loop repeat-interleaves it to 3.
        sketch = Image.open(sketch_path).convert("L")

        if self.transform:
            image = self.transform(image)
            sketch = self.transform(sketch)
        # Tile the 7 per-class values into an (H, W) map. With H = 56 and
        # y = H // 7, np.tile(labels, (H, y)) yields shape (56, 56).
        x, y = int(image.size(1)), int(image.size(1) / 7)
        labels = np.array(labels, dtype=np.float32)
        labels = np.tile(labels, (x, y))
        label = torch.tensor(labels, dtype=torch.float32)
        return label, image, sketch


transform = transforms.Compose([
    transforms.Resize((56, 56)),
    transforms.ToTensor()
])

# Train Dataset and Dataloader
train_dataset = ISICDataset(TRAIN_DATA_DIR, TRAIN_LABELS, TRAIN_SKETCH_DIR, transform=transform)
train_dataloader = DataLoader(train_dataset, batch_size=1, shuffle=True, num_workers=2)

# Test Dataset and Dataloader
test_dataset = ISICDataset(TEST_DATA_DIR, TEST_LABELS, TEST_SKETCH_DIR, transform=transform)
test_dataloader = DataLoader(test_dataset, batch_size=1, shuffle=True, num_workers=2)

"""
END
"""


def show(r_img, c_limage, fake_img):
    """Plot original image, label-conditioned input, and generated output."""
    fig, axes = plt.subplots(1, 3, figsize=(5, 6))
    panels = (
        (r_img, 'Original Image'),
        (c_limage, 'Image & Label'),
        (fake_img, 'Generated Image'),
    )
    for ax, (tensor, title) in zip(axes, panels):
        pil = F.to_pil_image(tensor.squeeze(0).detach())
        ax.imshow(pil)
        ax.set_title(title)
        ax.axis('off')
    plt.tight_layout()
    plt.show()


class CGANTrainer():
    """Conditional CycleGAN trainer with a Wasserstein-GP critic objective.

    genA: (image + label map) -> sketch;  genB: (sketch + label map) -> image.
    disA critics sketches, disB critics images.
    """

    def __init__(self, rank=None):
        """Build networks, losses, and optimizers.

        Parameters:
            rank -- distributed device rank. When None (single-process run)
                    models go to the global ``device`` and are NOT wrapped
                    in DDP; the original unconditional DDP wrap crashed
                    without a prior init_process_group().
        """
        super().__init__()
        self.optimizers = []
        self.lamb = 10.0  # cycle-consistency weight

        # NOTE(review): label_embed is constructed but never used in the
        # training loop; kept so existing checkpoints still load.
        self.label_embed = nn.Sequential(
            nn.Embedding(7, 100),
            nn.Linear(100, 64 * 64)
        ).to(device)

        use_ddp = (rank is not None and dist.is_available()
                   and dist.is_initialized())
        target = device if rank is None else rank

        def _wrap(net):
            # Move to the target device; wrap in DDP only when a process
            # group actually exists.
            net = net.to(target)
            return DDP(net, device_ids=[rank]) if use_ddp else net

        self.genA = _wrap(ResnetGenerator(input_nc=3, output_nc=3))
        self.genB = _wrap(ResnetGenerator(input_nc=3, output_nc=3))
        self.disA = _wrap(NLayerDiscriminator(input_nc=3))
        self.disB = _wrap(NLayerDiscriminator(input_nc=3))

        self.fakeA_pool = ImagePool(pool_size=50)
        self.fakeB_pool = ImagePool(pool_size=50)

        # NOTE(review): GANloss is unused by the current WGAN objective;
        # kept for compatibility with the earlier BCE formulation.
        self.GANloss = nn.BCEWithLogitsLoss()
        self.cycleLoss = nn.L1Loss()

        self.optimizer_G = torch.optim.Adam(
            itertools.chain(self.genA.parameters(), self.genB.parameters()),
            lr=0.0002, betas=(0.5, 0.999))
        self.optimizer_D = torch.optim.Adam(
            itertools.chain(self.disA.parameters(), self.disB.parameters()),
            lr=0.0002, betas=(0.5, 0.999))
        self.optimizers.append(self.optimizer_G)
        self.optimizers.append(self.optimizer_D)

    # Gradient Penalty for WGAN
    def gradient_penalty(self, dis, real, fake):
        """WGAN-GP: penalize critic gradient norm on real/fake interpolates.

        Returns 0.5 * mean((||grad||_2 - 1)^2) for random interpolates xhat.
        """
        # One interpolation coefficient per (sample, channel), broadcast
        # over the spatial dims.
        alpha = torch.rand(real.size(0), real.size(1), 1, 1)
        alpha = alpha.expand(real.size()).float().to(device)
        xhat = (alpha * real + (1 - alpha) * fake).float().to(device)
        # Leaf tensor so autograd.grad can differentiate w.r.t. it
        # (modern replacement for deprecated autograd.Variable).
        xhat = xhat.detach().requires_grad_(True)
        xhat_D = dis(xhat)
        grad = autograd.grad(
            outputs=xhat_D,
            inputs=xhat,
            grad_outputs=torch.ones_like(xhat_D),
            create_graph=True,
            retain_graph=True,
            only_inputs=True
        )[0]
        penalty = ((grad.norm(2, dim=1) - 1) ** 2).mean() * 0.5
        return penalty

    def set_requires_grad(self, nets, requires_grad=False):
        """Set requires_grad for all networks to avoid unnecessary computations.

        Parameters:
            nets (network list)   -- a list of networks (or a single network)
            requires_grad (bool)  -- whether the networks require gradients
        """
        if not isinstance(nets, list):
            nets = [nets]
        for net in nets:
            if net is not None:
                for param in net.parameters():
                    param.requires_grad = requires_grad

    def backward_D_basic(self, netD, real, fake):
        """Calculate the WGAN-GP critic loss for one discriminator.

        Parameters:
            netD -- the discriminator
            real -- real images
            fake -- generated images (detached, so only D receives gradients)

        Calls loss_D.backward() and returns the loss tensor.
        """
        pred_real = netD(real).mean()
        pred_fake = netD(fake.detach()).mean()
        gp = self.gradient_penalty(netD, real, fake.detach())
        # Critic maximises (pred_real - pred_fake); we minimise the negation.
        loss_D = (pred_fake - pred_real) + gp
        loss_D.backward(retain_graph=True)
        return loss_D

    def backward_disA(self):
        """Calculate GAN loss for discriminator disA (sketch critic)."""
        fake_B = self.fakeB_pool.query(self.fake_sketch)
        self.loss_disA = self.backward_D_basic(self.disA, self.concat_ls, fake_B)

    def backward_disB(self):
        """Calculate GAN loss for discriminator disB (image critic)."""
        fake_A = self.fakeA_pool.query(self.fake_image)
        self.loss_disB = self.backward_D_basic(self.disB, self.concat_li, fake_A)

    # Generator Backpropagation Function
    def backward_G(self):
        """Calculate the combined adversarial + cycle loss for genA and genB."""
        # Adversarial terms: generators try to make the critics' Wasserstein
        # estimates (OTdis_*) large, hence they are subtracted below.
        fake_prediction_A = self.disA(self.fake_sketch).mean()
        real_prediction_A = self.disA(self.sketch).mean()
        gp_A = self.gradient_penalty(self.disA, self.sketch, self.fake_sketch)
        OTdis_A = (fake_prediction_A - real_prediction_A) + gp_A

        fake_prediction_B = self.disB(self.fake_image).mean()
        real_prediction_B = self.disB(self.image).mean()
        gp_B = self.gradient_penalty(self.disB, self.image, self.fake_image)
        OTdis_B = (fake_prediction_B - real_prediction_B) + gp_B

        # Forward cycle loss || genB(genA(image)) - image ||
        self.loss_cycle_A = self.cycleLoss(self.rec_image, self.concat_li) * self.lamb
        # Backward cycle loss || genA(genB(sketch)) - sketch ||
        self.loss_cycle_B = self.cycleLoss(self.rec_sketch, self.concat_ls) * self.lamb
        # Combined loss and calculate gradients
        self.loss_G = self.loss_cycle_A + self.loss_cycle_B - (OTdis_A + OTdis_B)
        self.loss_G.backward(retain_graph=True)

    def train(self, dataloader, epochs=10):
        """Run the alternating G/D training loop over ``dataloader``.

        Parameters:
            dataloader -- yields (label_map, image, sketch) batches
            epochs     -- number of passes over the dataloader
        """
        for epoch in range(1, epochs + 1):
            total_dloss = 0.0
            total_gloss = 0.0
            b_dloss, b_gloss = 0.0, 0.0
            for index, batch in tqdm(enumerate(dataloader), total=len(dataloader)):
                self.label, self.image, self.sketch = batch
                # Expand the single-channel sketch to 3 channels for the nets.
                self.sketch = torch.repeat_interleave(self.sketch, 3, dim=1)
                self.label = self.label.to(device)
                self.image = self.image.to(device)
                self.sketch = self.sketch.to(device)

                # Condition both domains by adding the broadcast label map.
                self.label = self.label.unsqueeze(1)
                self.concat_li = self.image + self.label
                self.concat_ls = self.sketch + self.label

                self.fake_sketch = self.genA(self.concat_li)
                self.rec_image = self.genB(self.fake_sketch)
                self.fake_image = self.genB(self.concat_ls)
                self.rec_sketch = self.genA(self.fake_image)

                # Freeze discriminators to avoid unnecessary gradient work.
                self.set_requires_grad([self.disA, self.disB], False)
                self.optimizer_G.zero_grad()
                self.backward_G()
                self.optimizer_G.step()

                # Train discriminators (disA & disB).
                self.set_requires_grad([self.disA, self.disB], True)
                self.optimizer_D.zero_grad()
                self.backward_disA()
                self.backward_disB()
                self.optimizer_D.step()

                # .item() frees the autograd graphs; accumulating raw loss
                # tensors here leaked memory across the whole epoch.
                d_loss = (self.loss_disA + self.loss_disB).item() / 2
                g_loss = self.loss_G.item()
                total_dloss += d_loss
                total_gloss += g_loss
                b_dloss += d_loss
                b_gloss += g_loss

                # Intermediate logging and visualization
                if index % 10 == 0:
                    show(self.image[0], self.concat_li[0], self.fake_image[0])
                    print(f"{index}/{len(dataloader)} Batch Dis Loss: {b_dloss}, Batch Gen Loss: {b_gloss}\n")
                    b_dloss, b_gloss = 0.0, 0.0

            avg_dloss = total_dloss / len(dataloader)
            avg_gloss = total_gloss / len(dataloader)
            print(f"{epoch}/{epochs} Average D Loss: {avg_dloss}, Average G Loss: {avg_gloss}\n")


if __name__ == "__main__":
    # Single-process run: rank defaults to None, so models are not DDP-wrapped
    # (the original called CGANTrainer() / train() without required arguments).
    cgantrainer = CGANTrainer()
    cgantrainer.train(train_dataloader)