import torch
from torch import nn
from tqdm.auto import tqdm
from torchvision import transforms
from torchvision.utils import make_grid
from torch.utils.data import DataLoader
import matplotlib.pyplot as plt
torch.manual_seed(0)


def show_tensor_images(image_tensor, num_images=25, size=(1, 28, 28)):
    '''
    Function for visualizing images: given a tensor of images, number of
    images, and size per image, plots the images in a uniform grid.

    Parameters:
        image_tensor: batch of images with values in [-1, 1]
            (the generator's tanh output range)
        num_images: maximum number of images to include in the grid
        size: per-image (channels, height, width)
    '''
    # Shift from [-1, 1] back to [0, 1] for display.
    image_tensor = (image_tensor + 1) / 2
    image_unflat = image_tensor.detach().cpu().view(-1, *size)
    image_grid = make_grid(image_unflat[:num_images], nrow=5)
    plt.imshow(image_grid.permute(1, 2, 0).squeeze())
    plt.show()


import glob
import random
import os
from torch.utils.data import Dataset
from PIL import Image


class ImageDataset(Dataset):
    '''
    Unpaired two-domain image dataset (e.g. horse2zebra).

    Reads `<root>/<mode>A` and `<root>/<mode>B`; the smaller folder always
    becomes domain A so indexing stays in range. B images are paired to A
    images through a random permutation that is re-drawn once per pass, so
    the A/B pairings vary between epochs.
    '''

    def __init__(self, root, transform=None, mode='train'):
        self.transform = transform
        self.files_A = sorted(glob.glob(os.path.join(root, '%sA' % mode) + '/*.*'))
        self.files_B = sorted(glob.glob(os.path.join(root, '%sB' % mode) + '/*.*'))
        # Keep the smaller folder as A so __len__ / randperm indexing work.
        if len(self.files_A) > len(self.files_B):
            self.files_A, self.files_B = self.files_B, self.files_A
        self.new_perm()
        assert len(self.files_A) > 0, "Make sure you downloaded the horse2zebra images!"

    def new_perm(self):
        # One random B index per A image.
        self.randperm = torch.randperm(len(self.files_B))[:len(self.files_A)]

    def __getitem__(self, index):
        item_A = self.transform(Image.open(self.files_A[index % len(self.files_A)]))
        item_B = self.transform(Image.open(self.files_B[self.randperm[index]]))
        # Grayscale inputs: replicate the single channel to 3 channels.
        if item_A.shape[0] != 3:
            item_A = item_A.repeat(3, 1, 1)
        if item_B.shape[0] != 3:
            item_B = item_B.repeat(3, 1, 1)
        # Re-shuffle the A/B pairing at the end of each pass.
        if index == len(self) - 1:
            self.new_perm()
        # Old versions of PyTorch didn't support normalization for different-channeled images
        return (item_A - 0.5) * 2, (item_B - 0.5) * 2

    def __len__(self):
        return min(len(self.files_A), len(self.files_B))
class ResidualBlock(nn.Module):
    '''
    Residual block: two reflection-padded 3x3 convolutions with instance
    normalization; the input is added back to the result (skip connection).
    The channel count and spatial size are preserved.

    Parameters:
        input_channels: number of channels of the input (and output)
    '''

    def __init__(self, input_channels) -> None:
        super(ResidualBlock, self).__init__()
        self.conv1 = nn.Conv2d(input_channels, input_channels, 3, 1, padding=1, padding_mode='reflect')
        self.conv2 = nn.Conv2d(input_channels, input_channels, 3, 1, padding=1, padding_mode='reflect')
        # InstanceNorm2d has no learned parameters by default, so one
        # instance can safely be reused after both convolutions.
        self.instanceNorm = nn.InstanceNorm2d(input_channels)
        self.activation = nn.ReLU()

    def forward(self, x):
        # Bug fix: torch.Tensor has no .copy() method (AttributeError at
        # runtime); .clone() is the correct way to keep the residual input.
        original = x.clone()
        x = self.conv1(x)
        x = self.instanceNorm(x)
        x = self.activation(x)
        x = self.conv2(x)
        x = self.instanceNorm(x)
        return original + x
class ContractingBlock(nn.Module):
    '''
    Downsampling block: a stride-2 reflection-padded convolution that doubles
    the channel count, optional instance normalization, then ReLU (generator)
    or LeakyReLU(0.2) (discriminator).

    Parameters:
        input_channels: channels in; the block outputs input_channels * 2
        use_bn: whether to apply instance normalization
        kernel_size: convolution kernel size (3 for generator, 4 for disc)
        activation: 'relu' for nn.ReLU, anything else for nn.LeakyReLU(0.2)
    '''

    def __init__(self, input_channels, use_bn=True, kernel_size=3, activation='relu') -> None:
        super(ContractingBlock, self).__init__()
        self.conv1 = nn.Conv2d(input_channels, input_channels * 2, kernel_size, padding=1, stride=2, padding_mode='reflect')
        self.activation = nn.ReLU() if activation == 'relu' else nn.LeakyReLU(0.2)
        if use_bn:
            # Bug fix: conv1 outputs input_channels * 2 channels, so the norm
            # must be sized to match (it was nn.InstanceNorm2d(input_channels)).
            self.normalization = nn.InstanceNorm2d(input_channels * 2)
        self.use_bn = use_bn

    def forward(self, x):
        x = self.conv1(x)
        if self.use_bn:
            # Bug fix: the normalized result was computed but discarded
            # (`self.normalization(x)` with no assignment), so normalization
            # was never actually applied.
            x = self.normalization(x)
        x = self.activation(x)
        return x


class ExpandingBlock(nn.Module):
    '''
    Upsampling block: a stride-2 transposed convolution that halves the
    channel count, optional instance normalization, then ReLU.

    Parameters:
        input_channels: channels in; the block outputs input_channels // 2
        use_bn: whether to apply instance normalization
    '''

    def __init__(self, input_channels, use_bn=True) -> None:
        super(ExpandingBlock, self).__init__()
        # output_padding=1 makes the spatial size exactly double.
        self.conv1 = nn.ConvTranspose2d(input_channels, input_channels // 2, kernel_size=3, stride=2, padding=1, output_padding=1)
        if use_bn:
            self.normalization = nn.InstanceNorm2d(input_channels // 2)
        self.use_bn = use_bn
        self.activation = nn.ReLU()

    def forward(self, x):
        x = self.conv1(x)
        if self.use_bn:
            x = self.normalization(x)
        x = self.activation(x)
        return x


class FeatureMapBlock(nn.Module):
    '''
    7x7 reflection-padded convolution mapping between image channels and the
    hidden feature width (first/last layer of generator and discriminator);
    no normalization or activation, spatial size unchanged.
    '''

    def __init__(self, input_channels, output_channels) -> None:
        super(FeatureMapBlock, self).__init__()
        self.conv = nn.Conv2d(input_channels, output_channels, kernel_size=7, padding=3, padding_mode='reflect')

    def forward(self, x):
        return self.conv(x)
class Generator(nn.Module):
    '''
    CycleGAN generator: a 7x7 feature-map layer, two contracting blocks,
    nine residual blocks at the bottleneck resolution, two expanding blocks,
    a 7x7 feature-map layer back to image channels, and a tanh output.

    Parameters:
        input_channels: channels of the input image
        output_channels: channels of the output image
        hidden_dim: feature width after the first feature-map layer
    '''

    def __init__(self, input_channels, output_channels, hidden_dim=64) -> None:
        super(Generator, self).__init__()
        self.upfeature = FeatureMapBlock(input_channels, hidden_dim)
        self.contract1 = ContractingBlock(hidden_dim)
        self.contract2 = ContractingBlock(hidden_dim * 2)
        res_mult = 4
        # Registered individually as res0..res8 (not a ModuleList) so the
        # state_dict keys stay stable for existing checkpoints.
        for i in range(9):
            setattr(self, 'res%d' % i, ResidualBlock(hidden_dim * res_mult))
        self.expand1 = ExpandingBlock(hidden_dim * res_mult)
        self.expand2 = ExpandingBlock(hidden_dim * 2)
        self.downfeature = FeatureMapBlock(hidden_dim, output_channels)
        self.tanh = nn.Tanh()

    def forward(self, x):
        out = self.upfeature(x)
        out = self.contract1(out)
        out = self.contract2(out)
        for i in range(9):
            out = getattr(self, 'res%d' % i)(out)
        out = self.expand1(out)
        out = self.expand2(out)
        out = self.downfeature(out)
        return self.tanh(out)


class Discriminator(nn.Module):
    '''
    PatchGAN-style discriminator: a 7x7 feature-map layer followed by three
    kernel-4 contracting blocks with LeakyReLU (no norm on the first), then
    a 1x1 convolution to one logit per patch.

    Parameters:
        input_channels: channels of the input image
        hidden_channels: feature width after the first feature-map layer
    '''

    def __init__(self, input_channels, hidden_channels=64) -> None:
        super(Discriminator, self).__init__()
        self.upfeature = FeatureMapBlock(input_channels, hidden_channels)
        self.contract1 = ContractingBlock(hidden_channels, False, kernel_size=4, activation='lrelu')
        self.contract2 = ContractingBlock(hidden_channels * 2, kernel_size=4, activation='lrelu')
        self.contract3 = ContractingBlock(hidden_channels * 4, kernel_size=4, activation='lrelu')
        self.conv = nn.Conv2d(hidden_channels * 8, 1, kernel_size=1)

    def forward(self, x):
        out = self.upfeature(x)
        out = self.contract1(out)
        out = self.contract2(out)
        out = self.contract3(out)
        return self.conv(out)


import torch.nn.functional as F

# Training criteria and hyperparameters.
adv_criterion = nn.MSELoss()
recon_criterion = nn.L1Loss()

n_epochs = 60
dim_A = 3
dim_B = 3
display_step = 200
batch_size = 1
lr = 0.0002
load_shape = 286      # resize images to this before random-cropping
target_shape = 256    # final training resolution
device = 'cuda'
transform = transforms.Compose([
    transforms.Resize(load_shape),
    transforms.RandomCrop(target_shape),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
])

dataset = ImageDataset("horse2zebra", transform=transform)

# Two generators (A->B and B->A) share one optimizer; each discriminator
# has its own.
gen_AB = Generator(dim_A, dim_B).to(device)
gen_BA = Generator(dim_B, dim_A).to(device)
gen_opt = torch.optim.Adam(list(gen_AB.parameters()) + list(gen_BA.parameters()), lr=lr, betas=(0.5, 0.999))
disc_A = Discriminator(dim_A).to(device)
disc_A_opt = torch.optim.Adam(disc_A.parameters(), lr=lr, betas=(0.5, 0.999))
disc_B = Discriminator(dim_B).to(device)
disc_B_opt = torch.optim.Adam(disc_B.parameters(), lr=lr, betas=(0.5, 0.999))


def weights_init(m):
    '''DCGAN-style weight init, applied via Module.apply.'''
    if isinstance(m, nn.Conv2d) or isinstance(m, nn.ConvTranspose2d):
        # Bug fix: conv weights were initialized with N(1.0, 0.2); the
        # DCGAN/CycleGAN convention (used by the BatchNorm branch below)
        # is N(0.0, 0.02).
        torch.nn.init.normal_(m.weight, 0.0, 0.02)
    if isinstance(m, nn.BatchNorm2d):
        torch.nn.init.normal_(m.weight, 0.0, 0.02)
        torch.nn.init.constant_(m.bias, 0)


gen_AB = gen_AB.apply(weights_init)
gen_BA = gen_BA.apply(weights_init)
disc_A = disc_A.apply(weights_init)
disc_B = disc_B.apply(weights_init)


def get_disc_loss(real_X, fake_X, disc_X, adv_criterion):
    '''
    Discriminator loss: the mean of the adversarial loss on real images
    (target 1) and on detached fake images (target 0).

    Parameters:
        real_X: a batch of real images from domain X
        fake_X: a batch of generated images for domain X
        disc_X: the discriminator for domain X
        adv_criterion: adversarial criterion (e.g. nn.MSELoss for LSGAN)
    '''
    real_pred = disc_X(real_X.detach())
    disc_real_loss = adv_criterion(real_pred, torch.ones_like(real_pred))
    # Bug fix: 'fake_X.deatch()' typo raised AttributeError at runtime.
    # Detach the fake image so generator weights get no gradient here.
    fake_pred = disc_X(fake_X.detach())
    # Bug fix: fake_pred was detached inside the criterion, so the fake-image
    # term contributed no gradient to the discriminator at all.
    disc_fake_loss = adv_criterion(fake_pred, torch.zeros_like(fake_pred))
    disc_loss = (disc_real_loss + disc_fake_loss) / 2
    return disc_loss
def get_gen_adversarial_loss(real_X, disc_Y, gen_XY, adv_criterion):
    '''
    Adversarial part of the generator loss: translate real_X with gen_XY and
    score the result with disc_Y against a target of all ones.

    Returns: (adversarial loss, the generated fake_Y batch)
    '''
    fake_Y = gen_XY(real_X.detach())
    fake_pred = disc_Y(fake_Y)
    adversarial_loss = adv_criterion(fake_pred, torch.ones_like(fake_pred))
    return adversarial_loss, fake_Y


def get_identity_loss(real_X, gen_YX, identity_criterion):
    '''
    Identity part of the generator loss: feeding gen_YX an image already in
    domain X should leave it (approximately) unchanged.

    Returns: (identity loss, gen_YX(real_X))
    '''
    identity_X = gen_YX(real_X)
    return identity_criterion(identity_X, real_X), identity_X


def get_cycle_consistency_loss(real_X, fake_Y, gen_YX, cycle_criterion):
    '''
    Cycle-consistency part of the generator loss: translating fake_Y back
    with gen_YX should reconstruct real_X.

    Returns: (cycle loss, the reconstructed cycle_X batch)
    '''
    cycle_X = gen_YX(fake_Y)
    return cycle_criterion(cycle_X, real_X), cycle_X


def get_gen_loss(real_A, real_B, gen_AB, gen_BA, disc_B, disc_A, adv_criterion, cycle_criterion, identity_criterion, lambda_identity=0.2, lambda_cycle=10):
    '''
    Total generator loss:
        adversarial (both directions)
        + lambda_identity * identity (both directions)
        + lambda_cycle * cycle-consistency (both directions)

    Returns: (total loss, fake_A, fake_B)
    '''
    # Adversarial terms, one per translation direction.
    adv_loss_BA, fake_A = get_gen_adversarial_loss(real_B, disc_A, gen_BA, adv_criterion)
    adv_loss_AB, fake_B = get_gen_adversarial_loss(real_A, disc_B, gen_AB, adv_criterion)
    total_adversarial = adv_loss_BA + adv_loss_AB

    # Identity terms: each generator should leave in-domain images alone.
    identity_loss_A, _ = get_identity_loss(real_A, gen_BA, identity_criterion)
    identity_loss_B, _ = get_identity_loss(real_B, gen_AB, identity_criterion)
    total_identity = identity_loss_A + identity_loss_B

    # Cycle terms: A -> fake_B -> A and B -> fake_A -> B round trips.
    cycle_loss_BA, _ = get_cycle_consistency_loss(real_A, fake_B, gen_BA, cycle_criterion)
    cycle_loss_AB, _ = get_cycle_consistency_loss(real_B, fake_A, gen_AB, cycle_criterion)
    total_cycle = cycle_loss_BA + cycle_loss_AB

    gen_loss = lambda_identity * total_identity + lambda_cycle * total_cycle + total_adversarial
    return gen_loss, fake_A, fake_B
def train():
    '''
    CycleGAN training loop: for each batch, update disc_A and disc_B on
    detached fakes, then update both generators jointly; periodically print
    running mean losses, show image grids, and checkpoint everything.
    Relies on the module-level models, optimizers, criteria, and dataset.
    '''
    mean_gen_loss = 0
    mean_disc_loss = 0
    dataloader = DataLoader(dataset, batch_size, shuffle=True)
    cur_step = 0

    for epoch in range(n_epochs):
        for real_A, real_B in tqdm(dataloader):
            real_A = nn.functional.interpolate(real_A, size=target_shape)
            real_B = nn.functional.interpolate(real_B, size=target_shape)
            real_A = real_A.to(device)
            real_B = real_B.to(device)

            # --- Discriminator A update ---
            disc_A_opt.zero_grad()
            with torch.no_grad():
                # Bug fix: gen_BA maps B -> A, so fake A images must be
                # generated from real_B (was gen_BA(real_A)).
                fake_A = gen_BA(real_B)
            disc_A_loss = get_disc_loss(real_A, fake_A, disc_A, adv_criterion)
            disc_A_loss.backward(retain_graph=True)
            disc_A_opt.step()

            # --- Discriminator B update ---
            disc_B_opt.zero_grad()
            with torch.no_grad():
                # Bug fix: gen_AB maps A -> B, so fake B images must be
                # generated from real_A (was gen_AB(real_B)).
                fake_B = gen_AB(real_A)
            disc_B_loss = get_disc_loss(real_B, fake_B, disc_B, adv_criterion)
            disc_B_loss.backward(retain_graph=True)
            disc_B_opt.step()

            # --- Generator update (both directions jointly) ---
            gen_opt.zero_grad()
            # Bug fix: the call had 'adv_criterion=' with no value, a
            # syntax error that made this cell unrunnable.
            gen_loss, fake_A, fake_B = get_gen_loss(
                real_A, real_B, gen_AB, gen_BA, disc_B, disc_A,
                adv_criterion=adv_criterion,
                identity_criterion=recon_criterion,
                cycle_criterion=recon_criterion,
            )
            gen_loss.backward()
            gen_opt.step()

            # Running means over the display window (disc tracks disc_A only).
            mean_gen_loss += gen_loss.item() / display_step
            mean_disc_loss += disc_A_loss.item() / display_step

            if cur_step % display_step == 0 and cur_step > 0:
                print(f"Epoch: {epoch} | Step: {cur_step} | Gen_loss: {mean_gen_loss} | Disc_loss: {mean_disc_loss} |")
                show_tensor_images(torch.cat([real_A, real_B]), size=(dim_A, target_shape, target_shape))
                show_tensor_images(torch.cat([fake_A, fake_B]), size=(dim_B, target_shape, target_shape))
                mean_gen_loss = 0
                mean_disc_loss = 0
                # Make sure the checkpoint directory exists before saving.
                os.makedirs("checkpoints", exist_ok=True)
                torch.save({
                    'gen_AB': gen_AB,
                    'gen_BA': gen_BA,
                    'gen_opt': gen_opt,
                    'disc_A': disc_A,
                    'disc_A_opt': disc_A_opt,
                    'disc_B': disc_B,
                    'disc_B_opt': disc_B_opt
                }, f"checkpoints/cycleGAN_{cur_step}.pth")

            cur_step += 1