import os

import torch
import torch.nn as nn
from torchvision import transforms
from tqdm.auto import tqdm

from utils import *
from models import Generator, Discriminator
|
|
# Loss functions: MSE adversarial loss (LSGAN-style) and L1 for the
# identity / cycle-consistency reconstruction terms.
adv_criterion = nn.MSELoss()
recon_criterion = nn.L1Loss()


# Training hyperparameters.
n_epochs = 60        # full passes over the dataset
dim_A = 3            # channels of domain-A images (RGB)
dim_B = 3            # channels of domain-B images (RGB)
display_step = 200   # steps between logging / visualization / checkpoint
batch_size = 1
lr = 0.0002
load_shape = 286     # resize target before the random crop
target_shape = 256   # final training resolution
# Robustness fix: fall back to CPU when CUDA is unavailable — the
# original hard-coded 'cuda' and crashed on CPU-only machines.
device = 'cuda' if torch.cuda.is_available() else 'cpu'
|
|
|
|
# Preprocessing / augmentation pipeline: upscale to load_shape, take a
# random target_shape crop, random horizontal flip, then convert to a
# float tensor.
transform = transforms.Compose([
    transforms.Resize(load_shape),
    transforms.RandomCrop(target_shape),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
])


# Two-domain dataset; ImageDataset comes from utils. NOTE(review):
# presumably yields (image_A, image_B) pairs, as unpacked in train() —
# confirm against utils.ImageDataset.
dataset = ImageDataset("horse2zebra", transform=transform)
|
|
# Generators for both translation directions (gen_AB: A->B, gen_BA: B->A)
# and one joint Adam optimizer over the union of their parameters.
# betas=(0.5, 0.999) is the usual GAN-training choice.
gen_AB = Generator(dim_A,dim_B).to(device)
gen_BA = Generator(dim_B,dim_A).to(device)
gen_opt = torch.optim.Adam(list(gen_AB.parameters()) + list(gen_BA.parameters()),lr = lr,betas=(0.5,0.999))
# One discriminator per domain, each with its own optimizer.
disc_A = Discriminator(dim_A).to(device)
disc_A_opt = torch.optim.Adam(disc_A.parameters(),lr=lr,betas=(0.5,0.999))
disc_B = Discriminator(dim_B).to(device)
disc_B_opt = torch.optim.Adam(disc_B.parameters(),lr=lr,betas=(0.5,0.999))




# Apply the custom weight initialization from utils to all four networks.
gen_AB = gen_AB.apply(weights_init)
gen_BA = gen_BA.apply(weights_init)
disc_A = disc_A.apply(weights_init)
disc_B = disc_B.apply(weights_init)
|
|
|
|
|
|
def train():
    """Run the CycleGAN training loop over `dataset` for `n_epochs`.

    Each step alternates three updates: discriminator A, discriminator B
    (both against generator fakes produced without gradients), then a
    joint generator update combining adversarial, identity, and
    cycle-consistency losses. Every `display_step` steps the running
    mean losses are printed, sample images are shown, and a full
    checkpoint (models + optimizers) is saved.
    """
    mean_gen_loss = 0
    mean_disc_loss = 0
    dataloader = DataLoader(dataset, batch_size, shuffle=True)
    cur_step = 0

    for epoch in range(n_epochs):
        for real_A, real_B in tqdm(dataloader):
            # Force both domains to the common training resolution.
            real_A = nn.functional.interpolate(real_A, size=target_shape)
            real_B = nn.functional.interpolate(real_B, size=target_shape)
            real_A = real_A.to(device)
            real_B = real_B.to(device)

            # --- Discriminator A: real A images vs. fakes translated from B ---
            disc_A_opt.zero_grad()
            with torch.no_grad():
                # BUG FIX: gen_BA maps B -> A (see Generator(dim_B, dim_A)),
                # so it must be fed real_B; the original passed real_A.
                fake_A = gen_BA(real_B)
            disc_A_loss = get_disc_loss(real_A, fake_A, disc_A, adv_criterion)
            # Fakes were built under no_grad, so there is no shared graph;
            # the original retain_graph=True was unnecessary.
            disc_A_loss.backward()
            disc_A_opt.step()

            # --- Discriminator B: real B images vs. fakes translated from A ---
            disc_B_opt.zero_grad()
            with torch.no_grad():
                # BUG FIX: gen_AB maps A -> B, so it must be fed real_A;
                # the original passed real_B.
                fake_B = gen_AB(real_A)
            disc_B_loss = get_disc_loss(real_B, fake_B, disc_B, adv_criterion)
            disc_B_loss.backward()
            disc_B_opt.step()

            # --- Joint generator update: adversarial + identity + cycle ---
            gen_opt.zero_grad()
            # BUG FIX: the original line was a syntax error
            # (`adv_criterion=,` with no value).
            gen_loss, fake_A, fake_B = get_gen_loss(
                real_A, real_B, gen_AB, gen_BA, disc_B, disc_A,
                adv_criterion=adv_criterion,
                identity_criterion=recon_criterion,
                cycle_criterion=recon_criterion,
            )
            gen_loss.backward()
            gen_opt.step()

            mean_gen_loss += gen_loss.item() / display_step
            # BUG FIX: the original tracked only disc_A's loss; report the
            # mean over both discriminators instead.
            mean_disc_loss += (disc_A_loss.item() + disc_B_loss.item()) / 2 / display_step

            if cur_step % display_step == 0 and cur_step > 0:
                print(f"Epoch: {epoch} | Step: {cur_step} | Gen_loss: {mean_gen_loss} | Disc_loss: {mean_disc_loss} |")
                show_tensor_images(torch.cat([real_A, real_B]), size=(dim_A, target_shape, target_shape))
                show_tensor_images(torch.cat([fake_A, fake_B]), size=(dim_B, target_shape, target_shape))
                mean_gen_loss = 0
                mean_disc_loss = 0
                # Robustness: create the checkpoint directory on first use so
                # torch.save does not fail on a fresh clone.
                os.makedirs("checkpoints", exist_ok=True)
                torch.save({
                    'gen_AB': gen_AB,
                    'gen_BA': gen_BA,
                    'gen_opt': gen_opt,
                    'disc_A': disc_A,
                    'disc_A_opt': disc_A_opt,
                    'disc_B': disc_B,
                    'disc_B_opt': disc_B_opt
                }, f"checkpoints/cycleGAN_{cur_step}.pth")

            cur_step += 1
|
|
# Entry point: run training only when executed as a script, not on import.
if __name__ == "__main__":
    train()