import os

import torch
import torch.nn.functional as F
import torch.nn as nn
from torchvision import transforms
import torchvision
from torch.utils.data import DataLoader
# Project-local models/helpers (UNet, Discriminator, get_gen_loss, show_tensor_images).
# Kept before the tqdm import, as in the original, so tqdm.auto's tqdm is the one bound.
from UNet import *
from tqdm.auto import tqdm
|
|
# Loss functions: BCEWithLogitsLoss consumes raw discriminator logits; the L1
# loss and lambda_recon are handed to get_gen_loss (presumably as the weighted
# reconstruction term of the generator objective -- confirm in UNet.py).
adv_criterion = nn.BCEWithLogitsLoss()
recon_criterion = nn.L1Loss()
lambda_recon = 200

# Training hyper-parameters.
n_epochs = 50
input_dim = 3  # channels of the condition image given to the generator
real_dim = 3  # channels of the target/generated image
display_step = 200  # log and visualise every `display_step` optimisation steps
batch_size = 4
lr = 0.0002
target_shape = 256  # both image halves are resized to target_shape x target_shape
# Fall back to CPU so the script still runs on machines without a CUDA device.
device = "cuda" if torch.cuda.is_available() else "cpu"

transform = transforms.Compose([
    transforms.ToTensor(),
])

# Each sample is split in the training loop: left half = condition, right half = target.
dataset = torchvision.datasets.ImageFolder("/datasets/maps", transform=transform)

gen = UNet(input_dim, real_dim).to(device)
# NOTE(review): the pix2pix paper uses Adam betas=(0.5, 0.999); PyTorch defaults
# (0.9, 0.999) are used here -- confirm this is intended.
gen_opt = torch.optim.Adam(gen.parameters(), lr=lr)
# The discriminator is called as disc(image, condition), so it is built for
# input_dim + real_dim input channels (presumably the two are concatenated).
disc = Discriminator(input_dim + real_dim).to(device)
disc_opt = torch.optim.Adam(disc.parameters(), lr=lr)
|
|
def weights_init(m, init_gain=0.02):
    """Initialise one module's parameters in place; meant for ``nn.Module.apply``.

    * ``Conv2d`` / ``ConvTranspose2d``: weight ~ N(0, init_gain).
    * ``BatchNorm2d``: scale (gamma) ~ N(1, init_gain), shift (beta) = 0, so
      every BN layer starts close to an identity affine transform.
    * Any other module is left untouched.

    Args:
        m: the module visited by ``nn.Module.apply``.
        init_gain: standard deviation of the normal initialisers (default 0.02,
            the value used previously).
    """
    if isinstance(m, (nn.Conv2d, nn.ConvTranspose2d)):
        torch.nn.init.normal_(m.weight, 0.0, init_gain)
    elif isinstance(m, nn.BatchNorm2d):
        # gamma must be centred on 1, not 0: a near-zero scale would squash every
        # normalised activation towards beta (= 0) at the start of training.
        # BatchNorm2d(affine=False) has weight/bias set to None, so guard both.
        if m.weight is not None:
            torch.nn.init.normal_(m.weight, 1.0, init_gain)
        if m.bias is not None:
            torch.nn.init.constant_(m.bias, 0.0)
|
|
# Run the custom initialiser over every submodule of both networks. Module.apply
# mutates in place and hands back the same module, so the names stay bound to
# the very same objects (generator first, then discriminator).
gen, disc = gen.apply(weights_init), disc.apply(weights_init)
|
|
|
|
mean_gen_loss = 0
mean_disc_loss = 0
dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True)

# torch.save does not create parent directories; make sure the target exists.
os.makedirs("checkpoints", exist_ok=True)

cur_step = 0
for epoch in range(n_epochs):
    for image, _ in tqdm(dataloader):
        # Each sample stores condition (left half) and target (right half) side by
        # side. Move to the device first so the split/resize runs there, then
        # resize both halves to target_shape x target_shape.
        image = image.to(device)
        image_width = image.shape[3]
        condition = nn.functional.interpolate(
            image[:, :, :, :image_width // 2], size=target_shape)
        real = nn.functional.interpolate(
            image[:, :, :, image_width // 2:], size=target_shape)

        # ---- Discriminator update ------------------------------------------
        disc_opt.zero_grad()
        # The generator is not optimised in this step, so no autograd graph is
        # built for it; `fake` therefore needs no extra detach().
        with torch.no_grad():
            fake = gen(condition)
        disc_fake_pred = disc(fake, condition)
        disc_fake_loss = adv_criterion(disc_fake_pred, torch.zeros_like(disc_fake_pred))
        disc_real_pred = disc(real, condition)
        disc_real_loss = adv_criterion(disc_real_pred, torch.ones_like(disc_real_pred))
        disc_loss = (disc_fake_loss + disc_real_loss) / 2
        # No retain_graph: get_gen_loss receives none of these tensors and builds
        # its own graph, so this one can be freed immediately.
        disc_loss.backward()
        disc_opt.step()

        # ---- Generator update ----------------------------------------------
        gen_opt.zero_grad()
        gen_loss = get_gen_loss(gen, disc, real, condition,
                                adv_criterion, recon_criterion, lambda_recon)
        # get_gen_loss is given `disc`, so disc may also collect gradients here;
        # they are cleared by disc_opt.zero_grad() at the next iteration.
        gen_loss.backward()
        gen_opt.step()

        # Running means over the current display window.
        mean_gen_loss += gen_loss.item() / display_step
        mean_disc_loss += disc_loss.item() / display_step

        if cur_step % display_step == 0:
            print(f"Epoch: {epoch} | Step: {cur_step} | Gen-Loss: {mean_gen_loss} | Disc-loss: {mean_disc_loss}")
            show_tensor_images(condition, epoch, cur_step, "condition", size=(input_dim, target_shape, target_shape))
            show_tensor_images(real, epoch, cur_step, "real", size=(real_dim, target_shape, target_shape))
            show_tensor_images(fake, epoch, cur_step, "generated", size=(real_dim, target_shape, target_shape))
            mean_gen_loss = 0
            mean_disc_loss = 0
        cur_step += 1

    # One checkpoint per epoch, matching the per-epoch file name. State dicts are
    # stored rather than pickled module/optimizer objects; restore with e.g.
    # gen.load_state_dict(ckpt["gen"]) and gen_opt.load_state_dict(ckpt["gen_opt"]).
    torch.save({
        "gen": gen.state_dict(),
        "disc": disc.state_dict(),
        "gen_opt": gen_opt.state_dict(),
        "disc_opt": disc_opt.state_dict(),
    }, f"checkpoints/Pix2Pix_Epoch{epoch}.pth")
|
|