# CycleGAN / train.py
# Author: Yash Nagraj
# Training script for cloud training (commit 275907d).
import os

import torch
import torch.nn as nn
from torch.utils.data import DataLoader
from torchvision import transforms
from tqdm.auto import tqdm

from models import Generator, Discriminator
from utils import *
# Loss criteria: LSGAN-style adversarial loss (MSE) and L1 for the
# identity and cycle-consistency terms.
adv_criterion = nn.MSELoss()
recon_criterion = nn.L1Loss()

# Training hyperparameters.
n_epochs = 60        # passes over the dataset
dim_A = 3            # channels of domain-A images (RGB)
dim_B = 3            # channels of domain-B images (RGB)
display_step = 200   # steps between logging / checkpointing
batch_size = 1
lr = 0.0002
load_shape = 286     # resize target before the random crop (augmentation)
target_shape = 256   # final training resolution
# Fall back to CPU when no GPU is present so the script still runs
# (the original hard-coded 'cuda' and crashed on CPU-only machines).
device = 'cuda' if torch.cuda.is_available() else 'cpu'

# Augmentation pipeline: upscale, random crop, random flip, to tensor.
transform = transforms.Compose([
    transforms.Resize(load_shape),
    transforms.RandomCrop(target_shape),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
])
# Paired-folder dataset of horse (domain A) and zebra (domain B) images.
dataset = ImageDataset("horse2zebra", transform=transform)
# Two generators: gen_AB maps A->B, gen_BA maps B->A.
gen_AB = Generator(dim_A,dim_B).to(device)
gen_BA = Generator(dim_B,dim_A).to(device)
# One optimizer drives both generators jointly (standard CycleGAN setup).
gen_opt = torch.optim.Adam(list(gen_AB.parameters()) + list(gen_BA.parameters()),lr = lr,betas=(0.5,0.999))
# One discriminator per domain, each with its own optimizer.
disc_A = Discriminator(dim_A).to(device)
disc_A_opt = torch.optim.Adam(disc_A.parameters(),lr=lr,betas=(0.5,0.999))
disc_B = Discriminator(dim_B).to(device)
disc_B_opt = torch.optim.Adam(disc_B.parameters(),lr=lr,betas=(0.5,0.999))
# Re-initialize weights AFTER optimizer creation; the optimizers hold
# references to the parameter tensors, so this is safe.
gen_AB = gen_AB.apply(weights_init)
gen_BA = gen_BA.apply(weights_init)
disc_A = disc_A.apply(weights_init)
disc_B = disc_B.apply(weights_init)
def train():
    """Run the CycleGAN training loop.

    Alternates discriminator-A, discriminator-B, and joint-generator
    updates on each batch. Every `display_step` steps it prints the mean
    losses over the window, shows sample real/fake images, and saves a
    full checkpoint (models + optimizers) to `checkpoints/`.
    """
    mean_gen_loss = 0
    mean_disc_loss = 0
    dataloader = DataLoader(dataset, batch_size, shuffle=True)
    cur_step = 0
    for epoch in range(n_epochs):
        for real_A, real_B in tqdm(dataloader):
            # Force both batches to the training resolution in case the
            # dataset delivered a different size.
            real_A = nn.functional.interpolate(real_A, size=target_shape)
            real_B = nn.functional.interpolate(real_B, size=target_shape)
            real_A = real_A.to(device)
            real_B = real_B.to(device)

            # --- Discriminator A: real A vs. gen_BA(B). ---
            disc_A_opt.zero_grad()
            with torch.no_grad():
                # BUGFIX: fake-A images must come from domain-B inputs
                # (gen_BA maps B->A); the original fed real_A here.
                fake_A = gen_BA(real_B)
            disc_A_loss = get_disc_loss(real_A, fake_A, disc_A, adv_criterion)
            # fake_A was produced under no_grad, so no graph is shared
            # with the generator step; retain_graph is unnecessary.
            disc_A_loss.backward()
            disc_A_opt.step()

            # --- Discriminator B: real B vs. gen_AB(A). ---
            disc_B_opt.zero_grad()
            with torch.no_grad():
                # BUGFIX: fake-B images must come from domain-A inputs
                # (gen_AB maps A->B); the original fed real_B here.
                fake_B = gen_AB(real_A)
            disc_B_loss = get_disc_loss(real_B, fake_B, disc_B, adv_criterion)
            disc_B_loss.backward()
            disc_B_opt.step()

            # --- Generators: adversarial + identity + cycle losses. ---
            gen_opt.zero_grad()
            # BUGFIX: the original call had `adv_criterion=` with no
            # value -- a syntax error. Pass the module-level criterion.
            gen_loss, fake_A, fake_B = get_gen_loss(
                real_A, real_B, gen_AB, gen_BA, disc_B, disc_A,
                adv_criterion=adv_criterion,
                identity_criterion=recon_criterion,
                cycle_criterion=recon_criterion,
            )
            gen_loss.backward()
            gen_opt.step()

            # Running means over the logging window. Average BOTH
            # discriminators (the original counted only disc_A).
            mean_gen_loss += gen_loss.item() / display_step
            mean_disc_loss += (disc_A_loss.item() + disc_B_loss.item()) / 2 / display_step

            if cur_step % display_step == 0 and cur_step > 0:
                print(f"Epoch: {epoch} | Step: {cur_step} | Gen_loss: {mean_gen_loss} | Disc_loss: {mean_disc_loss} |")
                show_tensor_images(torch.cat([real_A, real_B]), size=(dim_A, target_shape, target_shape))
                show_tensor_images(torch.cat([fake_A, fake_B]), size=(dim_B, target_shape, target_shape))
                mean_gen_loss = 0
                mean_disc_loss = 0
                # Ensure the checkpoint directory exists before saving;
                # torch.save does not create parent directories.
                os.makedirs("checkpoints", exist_ok=True)
                torch.save({
                    'gen_AB': gen_AB,
                    'gen_BA': gen_BA,
                    'gen_opt': gen_opt,
                    'disc_A': disc_A,
                    'disc_A_opt': disc_A_opt,
                    'disc_B': disc_B,
                    'disc_B_opt': disc_B_opt,
                }, f"checkpoints/cycleGAN_{cur_step}.pth")
            cur_step += 1


if __name__ == "__main__":
    train()