"""CycleGAN training script (horse2zebra): two generators, two discriminators."""
import torch
import torch.nn as nn
from torch.utils.data import DataLoader
from torchvision import transforms
from tqdm.auto import tqdm

from utils import *
from models import Generator, Discriminator

# Losses: MSE for the adversarial terms (LSGAN-style), L1 for identity/cycle terms.
adv_criterion = nn.MSELoss()
recon_criterion = nn.L1Loss()

# Hyperparameters.
n_epochs = 60
dim_A = 3            # channel count of domain-A images
dim_B = 3            # channel count of domain-B images
display_step = 200   # log / visualize / checkpoint every this many steps
batch_size = 1
lr = 0.0002
load_shape = 286     # resize target before the random crop (standard CycleGAN jitter)
target_shape = 256   # final training resolution
device = 'cuda'

transform = transforms.Compose([
    transforms.Resize(load_shape),
    transforms.RandomCrop(target_shape),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
])

dataset = ImageDataset("horse2zebra", transform=transform)

# Generators map A->B and B->A; both are trained by a single optimizer.
gen_AB = Generator(dim_A, dim_B).to(device)
gen_BA = Generator(dim_B, dim_A).to(device)
gen_opt = torch.optim.Adam(
    list(gen_AB.parameters()) + list(gen_BA.parameters()),
    lr=lr, betas=(0.5, 0.999),
)
disc_A = Discriminator(dim_A).to(device)
disc_A_opt = torch.optim.Adam(disc_A.parameters(), lr=lr, betas=(0.5, 0.999))
disc_B = Discriminator(dim_B).to(device)
disc_B_opt = torch.optim.Adam(disc_B.parameters(), lr=lr, betas=(0.5, 0.999))

gen_AB = gen_AB.apply(weights_init)
gen_BA = gen_BA.apply(weights_init)
disc_A = disc_A.apply(weights_init)
disc_B = disc_B.apply(weights_init)


def train():
    """Run the CycleGAN training loop.

    Alternates per batch: update disc_A, update disc_B, then update both
    generators jointly. Every ``display_step`` steps, prints the running
    mean losses, shows real/fake image grids, and saves a checkpoint.
    """
    mean_gen_loss = 0
    mean_disc_loss = 0
    dataloader = DataLoader(dataset, batch_size, shuffle=True)
    cur_step = 0
    for epoch in range(n_epochs):
        for real_A, real_B in tqdm(dataloader):
            real_A = nn.functional.interpolate(real_A, size=target_shape)
            real_B = nn.functional.interpolate(real_B, size=target_shape)
            real_A = real_A.to(device)
            real_B = real_B.to(device)

            # --- Discriminator A: real A vs. fakes translated from real B ---
            disc_A_opt.zero_grad()
            with torch.no_grad():
                # BUGFIX: fakes in domain A must come from domain-B inputs
                # (original fed real_A to gen_BA).
                fake_A = gen_BA(real_B)
            disc_A_loss = get_disc_loss(real_A, fake_A, disc_A, adv_criterion)
            # retain_graph is not needed: fakes were created under no_grad,
            # so this graph is used exactly once.
            disc_A_loss.backward()
            disc_A_opt.step()

            # --- Discriminator B: real B vs. fakes translated from real A ---
            disc_B_opt.zero_grad()
            with torch.no_grad():
                # BUGFIX: fakes in domain B must come from domain-A inputs
                # (original fed real_B to gen_AB).
                fake_B = gen_AB(real_A)
            disc_B_loss = get_disc_loss(real_B, fake_B, disc_B, adv_criterion)
            disc_B_loss.backward()
            disc_B_opt.step()

            # --- Generators: adversarial + identity + cycle-consistency ---
            gen_opt.zero_grad()
            # BUGFIX: original call passed `adv_criterion=` with no value
            # (a syntax error). Discriminator argument order (disc_B, disc_A)
            # is kept from the original — NOTE(review): confirm it matches the
            # get_gen_loss signature in utils.
            gen_loss, fake_A, fake_B = get_gen_loss(
                real_A, real_B, gen_AB, gen_BA, disc_B, disc_A,
                adv_criterion=adv_criterion,
                identity_criterion=recon_criterion,
                cycle_criterion=recon_criterion,
            )
            gen_loss.backward()
            gen_opt.step()

            mean_gen_loss += gen_loss.item() / display_step
            # BUGFIX: the printed "Disc_loss" previously tracked disc_A only;
            # average both discriminators so the metric reflects them both.
            mean_disc_loss += (
                disc_A_loss.item() + disc_B_loss.item()
            ) / 2 / display_step

            if cur_step % display_step == 0 and cur_step > 0:
                print(
                    f"Epoch: {epoch} | Step: {cur_step} | "
                    f"Gen_loss: {mean_gen_loss} | Disc_loss: {mean_disc_loss} |"
                )
                show_tensor_images(
                    torch.cat([real_A, real_B]),
                    size=(dim_A, target_shape, target_shape),
                )
                show_tensor_images(
                    torch.cat([fake_A, fake_B]),
                    size=(dim_B, target_shape, target_shape),
                )
                mean_gen_loss = 0
                mean_disc_loss = 0
                # Checkpoint whole modules/optimizers (pickled), matching the
                # original format so existing loading code keeps working.
                torch.save({
                    'gen_AB': gen_AB,
                    'gen_BA': gen_BA,
                    'gen_opt': gen_opt,
                    'disc_A': disc_A,
                    'disc_A_opt': disc_A_opt,
                    'disc_B': disc_B,
                    'disc_B_opt': disc_B_opt,
                }, f"checkpoints/cycleGAN_{cur_step}.pth")
            cur_step += 1


if __name__ == "__main__":
    train()