import os

import torch
import torch.nn.functional as F
import torch.nn as nn
from torchvision import transforms
import torchvision
from torch.utils.data import DataLoader
from UNet import *  # presumably provides UNet, Discriminator, get_gen_loss, show_tensor_images
from tqdm.auto import tqdm

# Pix2Pix training script: a UNet generator maps the left half of each image
# (the condition) to the right half (the target); a conditional discriminator
# scores (image, condition) pairs.

adv_criterion = nn.BCEWithLogitsLoss()
recon_criterion = nn.L1Loss()
lambda_recon = 200  # weight of the reconstruction term, forwarded to get_gen_loss

n_epochs = 50
input_dim = 3  # channels of the condition image
real_dim = 3   # channels of the target image
display_step = 200
batch_size = 4
lr = 0.0002
target_shape = 256
# Fall back to CPU so the script still starts on machines without CUDA.
device = "cuda" if torch.cuda.is_available() else "cpu"
checkpoint_dir = "checkpoints"

transform = transforms.Compose([
    transforms.ToTensor(),
])

dataset = torchvision.datasets.ImageFolder("/datasets/maps", transform=transform)

gen = UNet(input_dim, real_dim).to(device)
gen_opt = torch.optim.Adam(gen.parameters(), lr=lr)
# The discriminator sees the candidate image and the condition stacked on channels.
disc = Discriminator(input_dim + real_dim).to(device)
disc_opt = torch.optim.Adam(disc.parameters(), lr=lr)


def weights_init(m):
    """Initialise conv and batch-norm layers (DCGAN / pix2pix convention).

    Conv and transposed-conv weights are drawn from N(0, 0.02). BatchNorm
    scales are drawn from N(1, 0.02) with zero bias so normalised activations
    pass through roughly unchanged at the start of training (a mean of 0
    would shrink them all towards zero).
    """
    if isinstance(m, (nn.Conv2d, nn.ConvTranspose2d)):
        torch.nn.init.normal_(m.weight, 0.0, 0.02)
    if isinstance(m, nn.BatchNorm2d):
        torch.nn.init.normal_(m.weight, 1.0, 0.02)
        torch.nn.init.constant_(m.bias, 0.0)


gen = gen.apply(weights_init)
disc = disc.apply(weights_init)

mean_gen_loss = 0
mean_disc_loss = 0
dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True)
# torch.save does not create missing directories; the first save happens at step 0.
os.makedirs(checkpoint_dir, exist_ok=True)
cur_step = 0

for epoch in range(n_epochs):
    for image, _ in tqdm(dataloader):
        # Each sample holds condition and target side by side: split on width,
        # then resize both halves to target_shape x target_shape.
        image = image.to(device)
        image_width = image.shape[3]
        condition = nn.functional.interpolate(image[:, :, :, :image_width // 2], size=target_shape)
        real = nn.functional.interpolate(image[:, :, :, image_width // 2:], size=target_shape)

        # --- Discriminator update ---
        disc_opt.zero_grad()
        # The generator is not trained in this step, so build no graph for it.
        with torch.no_grad():
            fake = gen(condition)
        disc_fake_pred = disc(fake, condition)
        disc_fake_loss = adv_criterion(disc_fake_pred, torch.zeros_like(disc_fake_pred))
        disc_real_pred = disc(real, condition)
        disc_real_loss = adv_criterion(disc_real_pred, torch.ones_like(disc_real_pred))
        disc_loss = (disc_fake_loss + disc_real_loss) / 2
        # No retain_graph: get_gen_loss recomputes its own forward pass, so this
        # graph is never back-propagated through again.
        disc_loss.backward()
        disc_opt.step()

        # --- Generator update ---
        gen_opt.zero_grad()
        gen_loss = get_gen_loss(gen, disc, real, condition, adv_criterion, recon_criterion, lambda_recon)
        gen_loss.backward()
        gen_opt.step()

        # Running averages over the current display window.
        mean_gen_loss += gen_loss.item() / display_step
        mean_disc_loss += disc_loss.item() / display_step

        if cur_step % display_step == 0:
            print(f"Epoch: {epoch} | Step: {cur_step} | Gen-Loss: {mean_gen_loss} | Disc-loss: {mean_disc_loss}")
            show_tensor_images(condition, epoch, cur_step, "condition", size=(input_dim, target_shape, target_shape))
            show_tensor_images(real, epoch, cur_step, "real", size=(real_dim, target_shape, target_shape))
            show_tensor_images(fake, epoch, cur_step, "generated", size=(real_dim, target_shape, target_shape))
            mean_gen_loss = 0
            mean_disc_loss = 0
            # NOTE(review): this pickles whole module/optimizer objects; saving
            # .state_dict() is more portable but would change the checkpoint format.
            torch.save({
                "gen": gen,
                "disc": disc,
                "gen_opt": gen_opt,
                "disc_opt": disc_opt,
            }, os.path.join(checkpoint_dir, f"Pix2Pix_Epoch{epoch}.pth"))
        cur_step += 1