# Pix2Pix / train.py
# Author: Yash Nagraj
# Add train code with dataset.sh (commit 52c73e3)
import os

import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision
from torch.utils.data import DataLoader
from torchvision import transforms
from tqdm.auto import tqdm

from UNet import *
# --- Losses and hyperparameters ---------------------------------------------
adv_criterion = nn.BCEWithLogitsLoss()  # adversarial loss on raw (pre-sigmoid) logits
recon_criterion = nn.L1Loss()           # pixel-wise reconstruction loss
lambda_recon = 200     # weight of the L1 term relative to the adversarial term
n_epochs = 50
input_dim = 3          # channels of the condition image
real_dim = 3           # channels of the target image
display_step = 200     # log and visualize every N steps
batch_size = 4
lr = 0.0002
target_shape = 256     # both halves are resized to target_shape x target_shape
# Fall back to CPU when CUDA is unavailable instead of crashing at .to(device).
device = 'cuda' if torch.cuda.is_available() else 'cpu'

transform = transforms.Compose([
    transforms.ToTensor(),
])
# NOTE(review): each sample presumably is a side-by-side (condition | real)
# pair — the loop below splits images at the horizontal midpoint. Class labels
# from ImageFolder are ignored.
dataset = torchvision.datasets.ImageFolder("/datasets/maps", transform=transform)
# Generator: UNet mapping a condition image (input_dim ch) to real_dim channels.
gen = UNet(input_dim, real_dim).to(device)
gen_opt = torch.optim.Adam(gen.parameters(), lr=lr)

# Discriminator input width is input_dim + real_dim — it presumably scores the
# condition and the (real or generated) image concatenated on the channel dim.
disc = Discriminator(input_dim + real_dim).to(device)
disc_opt = torch.optim.Adam(disc.parameters(), lr=lr)
def weights_init(m):
    """DCGAN-style layer initialization, applied via ``module.apply(weights_init)``.

    Conv / transposed-conv weights are drawn from N(0, 0.02). BatchNorm scale
    (gamma) is drawn from N(1.0, 0.02) with bias (beta) zeroed, following the
    DCGAN guidelines and the reference pix2pix implementation.
    """
    if isinstance(m, (nn.Conv2d, nn.ConvTranspose2d)):
        torch.nn.init.normal_(m.weight, 0.0, 0.02)
    elif isinstance(m, nn.BatchNorm2d):
        # Fix: gamma must be centred at 1.0 (identity scale). The previous
        # N(0, 0.02) init made every BatchNorm start as a near-zero scaling
        # of its input, crippling early training signal.
        torch.nn.init.normal_(m.weight, 1.0, 0.02)
        torch.nn.init.constant_(m.bias, 0.0)
# Initialize every conv / norm layer of both networks.
gen = gen.apply(weights_init)
disc = disc.apply(weights_init)

dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True)

# Running loss averages over the last `display_step` steps, and a global
# step counter across epochs.
mean_gen_loss = 0
mean_disc_loss = 0
cur_step = 0
# torch.save does not create parent directories; make sure the target exists.
os.makedirs("checkpoints", exist_ok=True)
for epoch in range(n_epochs):
    for image, _ in tqdm(dataloader):
        # Each image is a side-by-side pair: left half = condition,
        # right half = real target. Split at the midpoint and resize both.
        image_width = image.shape[3]
        condition = image[:, :, :, :image_width // 2]
        condition = F.interpolate(condition, size=target_shape)
        real = image[:, :, :, image_width // 2:]
        real = F.interpolate(real, size=target_shape)
        condition = condition.to(device)
        real = real.to(device)

        ### Discriminator step ###
        disc_opt.zero_grad()
        with torch.no_grad():
            # The generator output is only a fixed fake sample for the
            # discriminator here, so skip building its autograd graph.
            fake = gen(condition)
        disc_fake_pred = disc(fake.detach(), condition)
        disc_fake_loss = adv_criterion(disc_fake_pred, torch.zeros_like(disc_fake_pred))
        disc_real_pred = disc(real, condition)
        disc_real_loss = adv_criterion(disc_real_pred, torch.ones_like(disc_real_pred))
        disc_loss = (disc_fake_loss + disc_real_loss) / 2
        # `fake` is detached, so the discriminator graph is independent of the
        # generator's; retain_graph is unnecessary and was only wasting memory.
        disc_loss.backward()
        disc_opt.step()

        ### Generator step ###
        # Use the optimizer's zero_grad (mirrors disc_opt above); it covers
        # exactly the parameters being stepped.
        gen_opt.zero_grad()
        gen_loss = get_gen_loss(gen, disc, real, condition, adv_criterion, recon_criterion, lambda_recon)
        gen_loss.backward()
        gen_opt.step()

        mean_gen_loss += gen_loss.item() / display_step
        mean_disc_loss += disc_loss.item() / display_step

        if cur_step % display_step == 0:
            print(f"Epoch: {epoch} | Step: {cur_step} | Gen-Loss: {mean_gen_loss} | Disc-loss: {mean_disc_loss}")
            show_tensor_images(condition, epoch, cur_step, "condition", size=(input_dim, target_shape, target_shape))
            show_tensor_images(real, epoch, cur_step, "real", size=(real_dim, target_shape, target_shape))
            show_tensor_images(fake, epoch, cur_step, "generated", size=(real_dim, target_shape, target_shape))
            mean_gen_loss = 0
            mean_disc_loss = 0
            # Save state_dicts instead of whole pickled modules: state_dicts
            # survive class/module refactors and are the PyTorch-recommended
            # checkpoint format.
            torch.save({
                "gen": gen.state_dict(),
                "disc": disc.state_dict(),
                "gen_opt": gen_opt.state_dict(),
                "disc_opt": disc_opt.state_dict(),
            }, f"checkpoints/Pix2Pix_Epoch{epoch}.pth")
        cur_step += 1