| import argparse |
| import math |
| import random |
| import os |
| from util import * |
| import numpy as np |
| import torch |
| torch.backends.cudnn.benchmark = True |
| from torch import nn, autograd |
| from torch import optim |
| from torch.nn import functional as F |
| from torch.utils import data |
| import torch.distributed as dist |
|
|
| from torchvision import transforms, utils |
| from tqdm import tqdm |
| from torch.optim import lr_scheduler |
| import copy |
| import kornia.augmentation as K |
| import kornia |
| import lpips |
|
|
| from model import * |
| from dataset import ImageFolder |
| from distributed import ( |
| get_rank, |
| synchronize, |
| reduce_loss_dict, |
| reduce_sum, |
| get_world_size, |
| ) |
|
|
| mse_criterion = nn.MSELoss() |
|
|
|
|
| def test(args, genA2B, genB2A, testA_loader, testB_loader, name, step): |
| testA_loader = iter(testA_loader) |
| testB_loader = iter(testB_loader) |
| with torch.no_grad(): |
| test_sample_num = 16 |
|
|
| genA2B.eval(), genB2A.eval() |
| A2B = [] |
| B2A = [] |
| for i in range(test_sample_num): |
| real_A = testA_loader.next() |
| real_B = testB_loader.next() |
|
|
| real_A, real_B = real_A.cuda(), real_B.cuda() |
|
|
| A2B_content, A2B_style = genA2B.encode(real_A) |
| B2A_content, B2A_style = genB2A.encode(real_B) |
|
|
| if i % 2 == 0: |
| A2B_mod1 = torch.randn([1, args.latent_dim]).cuda() |
| B2A_mod1 = torch.randn([1, args.latent_dim]).cuda() |
| A2B_mod2 = torch.randn([1, args.latent_dim]).cuda() |
| B2A_mod2 = torch.randn([1, args.latent_dim]).cuda() |
|
|
| fake_B2B, _, _ = genA2B(real_B) |
| fake_A2A, _, _ = genB2A(real_A) |
|
|
| colsA = [real_A, fake_A2A] |
| colsB = [real_B, fake_B2B] |
| |
| fake_A2B_1 = genA2B.decode(A2B_content, A2B_mod1) |
| fake_B2A_1 = genB2A.decode(B2A_content, B2A_mod1) |
|
|
| fake_A2B_2 = genA2B.decode(A2B_content, A2B_mod2) |
| fake_B2A_2 = genB2A.decode(B2A_content, B2A_mod2) |
|
|
| fake_A2B_3 = genA2B.decode(A2B_content, B2A_style) |
| fake_B2A_3 = genB2A.decode(B2A_content, A2B_style) |
|
|
| colsA += [fake_A2B_3, fake_A2B_1, fake_A2B_2] |
| colsB += [fake_B2A_3, fake_B2A_1, fake_B2A_2] |
|
|
| fake_A2B2A, _, _ = genB2A(fake_A2B_3, A2B_style) |
| fake_B2A2B, _, _ = genA2B(fake_B2A_3, B2A_style) |
| colsA.append(fake_A2B2A) |
| colsB.append(fake_B2A2B) |
|
|
| fake_A2B2A, _, _ = genB2A(fake_A2B_1, A2B_style) |
| fake_B2A2B, _, _ = genA2B(fake_B2A_1, B2A_style) |
| colsA.append(fake_A2B2A) |
| colsB.append(fake_B2A2B) |
|
|
| fake_A2B2A, _, _ = genB2A(fake_A2B_2, A2B_style) |
| fake_B2A2B, _, _ = genA2B(fake_B2A_2, B2A_style) |
| colsA.append(fake_A2B2A) |
| colsB.append(fake_B2A2B) |
|
|
| fake_A2B2A, _, _ = genB2A(fake_A2B_1) |
| fake_B2A2B, _, _ = genA2B(fake_B2A_1) |
| colsA.append(fake_A2B2A) |
| colsB.append(fake_B2A2B) |
|
|
| colsA = torch.cat(colsA, 2).detach().cpu() |
| colsB = torch.cat(colsB, 2).detach().cpu() |
|
|
| A2B.append(colsA) |
| B2A.append(colsB) |
| A2B = torch.cat(A2B, 0) |
| B2A = torch.cat(B2A, 0) |
|
|
| utils.save_image(A2B, f'{im_path}/{name}_A2B_{str(step).zfill(6)}.jpg', normalize=True, range=(-1, 1), nrow=16) |
| utils.save_image(B2A, f'{im_path}/{name}_B2A_{str(step).zfill(6)}.jpg', normalize=True, range=(-1, 1), nrow=16) |
|
|
| genA2B.train(), genB2A.train() |
|
|
|
|
| def train(args, trainA_loader, trainB_loader, testA_loader, testB_loader, G_A2B, G_B2A, D_A, D_B, G_optim, D_optim, device): |
| G_A2B.train(), G_B2A.train(), D_A.train(), D_B.train() |
| trainA_loader = sample_data(trainA_loader) |
| trainB_loader = sample_data(trainB_loader) |
| G_scheduler = lr_scheduler.StepLR(G_optim, step_size=100000, gamma=0.5) |
| D_scheduler = lr_scheduler.StepLR(D_optim, step_size=100000, gamma=0.5) |
|
|
| pbar = range(args.iter) |
|
|
| if get_rank() == 0: |
| pbar = tqdm(pbar, initial=args.start_iter, dynamic_ncols=True, smoothing=0.1) |
|
|
| loss_dict = {} |
| mean_path_length_A2B = 0 |
| mean_path_length_B2A = 0 |
|
|
| if args.distributed: |
| G_A2B_module = G_A2B.module |
| G_B2A_module = G_B2A.module |
| D_A_module = D_A.module |
| D_B_module = D_B.module |
| D_L_module = D_L.module |
|
|
| else: |
| G_A2B_module = G_A2B |
| G_B2A_module = G_B2A |
| D_A_module = D_A |
| D_B_module = D_B |
| D_L_module = D_L |
|
|
| for idx in pbar: |
| i = idx + args.start_iter |
|
|
| if i > args.iter: |
| print('Done!') |
| break |
|
|
| ori_A = next(trainA_loader) |
| ori_B = next(trainB_loader) |
| if isinstance(ori_A, list): |
| ori_A = ori_A[0] |
| if isinstance(ori_B, list): |
| ori_B = ori_B[0] |
|
|
| ori_A = ori_A.to(device) |
| ori_B = ori_B.to(device) |
| aug_A = aug(ori_A) |
| aug_B = aug(ori_B) |
| A = aug(ori_A[[np.random.randint(args.batch)]].expand_as(ori_A)) |
| B = aug(ori_B[[np.random.randint(args.batch)]].expand_as(ori_B)) |
|
|
| if i % args.d_reg_every == 0: |
| aug_A.requires_grad = True |
| aug_B.requires_grad = True |
| |
| A2B_content, A2B_style = G_A2B.encode(A) |
| B2A_content, B2A_style = G_B2A.encode(B) |
|
|
| |
| aug_A2B_style = G_B2A.style_encode(aug_B) |
| aug_B2A_style = G_A2B.style_encode(aug_A) |
| rand_A2B_style = torch.randn([args.batch, args.latent_dim]).to(device).requires_grad_() |
| rand_B2A_style = torch.randn([args.batch, args.latent_dim]).to(device).requires_grad_() |
|
|
| |
| idx = torch.randperm(2*args.batch) |
| input_A2B_style = torch.cat([rand_A2B_style, aug_A2B_style], 0)[idx][:args.batch] |
|
|
| idx = torch.randperm(2*args.batch) |
| input_B2A_style = torch.cat([rand_B2A_style, aug_B2A_style], 0)[idx][:args.batch] |
|
|
| fake_A2B = G_A2B.decode(A2B_content, input_A2B_style) |
| fake_B2A = G_B2A.decode(B2A_content, input_B2A_style) |
|
|
|
|
| |
| real_A_logit = D_A(aug_A) |
| real_B_logit = D_B(aug_B) |
| real_L_logit1 = D_L(rand_A2B_style) |
| real_L_logit2 = D_L(rand_B2A_style) |
|
|
| fake_B_logit = D_B(fake_A2B.detach()) |
| fake_A_logit = D_A(fake_B2A.detach()) |
| fake_L_logit1 = D_L(aug_A2B_style.detach()) |
| fake_L_logit2 = D_L(aug_B2A_style.detach()) |
|
|
| |
| D_loss = d_logistic_loss(real_A_logit, fake_A_logit) +\ |
| d_logistic_loss(real_B_logit, fake_B_logit) +\ |
| d_logistic_loss(real_L_logit1, fake_L_logit1) +\ |
| d_logistic_loss(real_L_logit2, fake_L_logit2) |
|
|
| loss_dict['D_adv'] = D_loss |
|
|
| if i % args.d_reg_every == 0: |
| r1_A_loss = d_r1_loss(real_A_logit, aug_A) |
| r1_B_loss = d_r1_loss(real_B_logit, aug_B) |
| r1_L_loss = d_r1_loss(real_L_logit1, rand_A2B_style) + d_r1_loss(real_L_logit2, rand_B2A_style) |
| r1_loss = r1_A_loss + r1_B_loss + r1_L_loss |
| D_r1_loss = (args.r1 / 2 * r1_loss * args.d_reg_every) |
| D_loss += D_r1_loss |
|
|
| D_optim.zero_grad() |
| D_loss.backward() |
| D_optim.step() |
|
|
| |
| |
| fake_B_logit = D_B(fake_A2B) |
| fake_A_logit = D_A(fake_B2A) |
| fake_L_logit1 = D_L(aug_A2B_style) |
| fake_L_logit2 = D_L(aug_B2A_style) |
|
|
| lambda_adv = (1, 1, 1) |
| G_adv_loss = 1 * (g_nonsaturating_loss(fake_A_logit, lambda_adv) +\ |
| g_nonsaturating_loss(fake_B_logit, lambda_adv) +\ |
| 2*g_nonsaturating_loss(fake_L_logit1, (1,)) +\ |
| 2*g_nonsaturating_loss(fake_L_logit2, (1,))) |
|
|
| |
| G_con_loss = 50 * (A2B_style.var(0, unbiased=False).sum() + B2A_style.var(0, unbiased=False).sum()) |
| |
| |
| A2B2A_content, A2B2A_style = G_B2A.encode(fake_A2B) |
| B2A2B_content, B2A2B_style = G_A2B.encode(fake_B2A) |
| fake_A2B2A = G_B2A.decode(A2B2A_content, shuffle_batch(A2B_style)) |
| fake_B2A2B = G_A2B.decode(B2A2B_content, shuffle_batch(B2A_style)) |
|
|
| G_cycle_loss = 20 * (F.mse_loss(fake_A2B2A, A) + F.mse_loss(fake_B2A2B, B)) |
| lpips_loss = 10 * (lpips_fn(fake_A2B2A, A).mean() + lpips_fn(fake_B2A2B, B).mean()) |
|
|
| |
| G_style_loss = 5 * (mse_criterion(A2B2A_style, input_A2B_style) +\ |
| mse_criterion(B2A2B_style, input_B2A_style)) |
|
|
|
|
| G_loss = G_adv_loss + G_cycle_loss + G_con_loss + lpips_loss + G_style_loss |
|
|
| loss_dict['G_adv'] = G_adv_loss |
| loss_dict['G_con'] = G_con_loss |
| loss_dict['G_cycle'] = G_cycle_loss |
| loss_dict['lpips'] = lpips_loss |
|
|
| G_optim.zero_grad() |
| G_loss.backward() |
| G_optim.step() |
|
|
| G_scheduler.step() |
| D_scheduler.step() |
|
|
| accumulate(G_A2B_ema, G_A2B_module) |
| accumulate(G_B2A_ema, G_B2A_module) |
|
|
| loss_reduced = reduce_loss_dict(loss_dict) |
| D_adv_loss_val = loss_reduced['D_adv'].mean().item() |
|
|
| G_adv_loss_val = loss_reduced['G_adv'].mean().item() |
| G_cycle_loss_val = loss_reduced['G_cycle'].mean().item() |
| G_con_loss_val = loss_reduced['G_con'].mean().item() |
| lpips_val = loss_reduced['lpips'].mean().item() |
|
|
| if get_rank() == 0: |
| pbar.set_description( |
| ( |
| f'Dadv: {D_adv_loss_val:.2f}; lpips: {lpips_val:.2f} ' |
| f'Gadv: {G_adv_loss_val:.2f}; Gcycle: {G_cycle_loss_val:.2f}; GMS: {G_con_loss_val:.2f} {G_style_loss.item():.2f}' |
| ) |
| ) |
|
|
| if i % 1000 == 0: |
| with torch.no_grad(): |
| test(args, G_A2B, G_B2A, testA_loader, testB_loader, 'normal', i) |
| test(args, G_A2B_ema, G_B2A_ema, testA_loader, testB_loader, 'ema', i) |
|
|
| if (i+1) % 2000 == 0: |
| torch.save( |
| { |
| 'G_A2B': G_A2B_module.state_dict(), |
| 'G_B2A': G_B2A_module.state_dict(), |
| 'G_A2B_ema': G_A2B_ema.state_dict(), |
| 'G_B2A_ema': G_B2A_ema.state_dict(), |
| 'D_A': D_A_module.state_dict(), |
| 'D_B': D_B_module.state_dict(), |
| 'D_L': D_L_module.state_dict(), |
| 'G_optim': G_optim.state_dict(), |
| 'D_optim': D_optim.state_dict(), |
| 'iter': i, |
| }, |
| os.path.join(model_path, 'ck.pt'), |
| ) |
|
|
|
|
| if __name__ == '__main__': |
| device = 'cuda' |
|
|
| parser = argparse.ArgumentParser() |
|
|
| parser.add_argument('--iter', type=int, default=300000) |
| parser.add_argument('--batch', type=int, default=4) |
| parser.add_argument('--n_sample', type=int, default=64) |
| parser.add_argument('--size', type=int, default=256) |
| parser.add_argument('--r1', type=float, default=10) |
| parser.add_argument('--lambda_cycle', type=int, default=1) |
| parser.add_argument('--path_regularize', type=float, default=2) |
| parser.add_argument('--path_batch_shrink', type=int, default=2) |
| parser.add_argument('--d_reg_every', type=int, default=16) |
| parser.add_argument('--g_reg_every', type=int, default=4) |
| parser.add_argument('--mixing', type=float, default=0.9) |
| parser.add_argument('--ckpt', type=str, default=None) |
| parser.add_argument('--lr', type=float, default=2e-3) |
| parser.add_argument('--local_rank', type=int, default=0) |
| parser.add_argument('--num_down', type=int, default=3) |
| parser.add_argument('--name', type=str, required=True) |
| parser.add_argument('--d_path', type=str, required=True) |
| parser.add_argument('--latent_dim', type=int, default=8) |
| parser.add_argument('--lr_mlp', type=float, default=0.01) |
| parser.add_argument('--n_res', type=int, default=1) |
|
|
| args = parser.parse_args() |
|
|
| n_gpu = int(os.environ['WORLD_SIZE']) if 'WORLD_SIZE' in os.environ else 1 |
| args.distributed = False |
|
|
| if args.distributed: |
| torch.cuda.set_device(args.local_rank) |
| torch.distributed.init_process_group(backend='nccl', init_method='env://') |
| synchronize() |
|
|
| save_path = f'./{args.name}' |
| im_path = os.path.join(save_path, 'sample') |
| model_path = os.path.join(save_path, 'checkpoint') |
| os.makedirs(im_path, exist_ok=True) |
| os.makedirs(model_path, exist_ok=True) |
|
|
| args.n_mlp = 5 |
|
|
| args.start_iter = 0 |
|
|
| G_A2B = Generator( args.size, args.num_down, args.latent_dim, args.n_mlp, lr_mlp=args.lr_mlp, n_res=args.n_res).to(device) |
| D_A = Discriminator(args.size).to(device) |
| G_B2A = Generator( args.size, args.num_down, args.latent_dim, args.n_mlp, lr_mlp=args.lr_mlp, n_res=args.n_res).to(device) |
| D_B = Discriminator(args.size).to(device) |
| D_L = LatDiscriminator(args.latent_dim).to(device) |
| lpips_fn = lpips.LPIPS(net='vgg').to(device) |
|
|
| G_A2B_ema = copy.deepcopy(G_A2B).to(device).eval() |
| G_B2A_ema = copy.deepcopy(G_B2A).to(device).eval() |
|
|
| g_reg_ratio = args.g_reg_every / (args.g_reg_every + 1) |
| d_reg_ratio = args.d_reg_every / (args.d_reg_every + 1) |
|
|
| G_optim = optim.Adam( list(G_A2B.parameters()) + list(G_B2A.parameters()), lr=args.lr, betas=(0, 0.99)) |
| D_optim = optim.Adam( |
| list(D_L.parameters()) + list(D_A.parameters()) + list(D_B.parameters()), |
| lr=args.lr, betas=(0**d_reg_ratio, 0.99**d_reg_ratio)) |
|
|
| if args.ckpt is not None: |
| ckpt = torch.load(args.ckpt, map_location=lambda storage, loc: storage) |
|
|
| try: |
| ckpt_name = os.path.basename(args.ckpt) |
| args.start_iter = int(os.path.splitext(ckpt_name)[0]) |
| |
| except ValueError: |
| pass |
| |
| G_A2B.load_state_dict(ckpt['G_A2B']) |
| G_B2A.load_state_dict(ckpt['G_B2A']) |
| G_A2B_ema.load_state_dict(ckpt['G_A2B_ema']) |
| G_B2A_ema.load_state_dict(ckpt['G_B2A_ema']) |
| D_A.load_state_dict(ckpt['D_A']) |
| D_B.load_state_dict(ckpt['D_B']) |
| D_L.load_state_dict(ckpt['D_L']) |
|
|
| G_optim.load_state_dict(ckpt['G_optim']) |
| D_optim.load_state_dict(ckpt['D_optim']) |
| args.start_iter = ckpt['iter'] |
|
|
| if args.distributed: |
| G_A2B = nn.parallel.DistributedDataParallel( |
| G_A2B, |
| device_ids=[args.local_rank], |
| output_device=args.local_rank, |
| broadcast_buffers=False, |
| ) |
|
|
| D_A = nn.parallel.DistributedDataParallel( |
| D_A, |
| device_ids=[args.local_rank], |
| output_device=args.local_rank, |
| broadcast_buffers=False, |
| ) |
|
|
| G_B2A = nn.parallel.DistributedDataParallel( |
| G_B2A, |
| device_ids=[args.local_rank], |
| output_device=args.local_rank, |
| broadcast_buffers=False, |
| ) |
|
|
| D_B = nn.parallel.DistributedDataParallel( |
| D_B, |
| device_ids=[args.local_rank], |
| output_device=args.local_rank, |
| broadcast_buffers=False, |
| ) |
| D_L = nn.parallel.DistributedDataParallel( |
| D_L, |
| device_ids=[args.local_rank], |
| output_device=args.local_rank, |
| broadcast_buffers=False, |
| ) |
| train_transform = transforms.Compose([ |
| transforms.ToTensor(), |
| transforms.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5), inplace=True) |
| ]) |
|
|
| test_transform = transforms.Compose([ |
| transforms.Resize((args.size, args.size)), |
| transforms.ToTensor(), |
| transforms.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5), inplace=True) |
| ]) |
|
|
| aug = nn.Sequential( |
| K.RandomAffine(degrees=(-20,20), scale=(0.8, 1.2), translate=(0.1, 0.1), shear=0.15), |
| kornia.geometry.transform.Resize(256+30), |
| K.RandomCrop((256,256)), |
| K.RandomHorizontalFlip(), |
| ) |
|
|
|
|
| d_path = args.d_path |
| trainA = ImageFolder(os.path.join(d_path, 'trainA'), train_transform) |
| trainB = ImageFolder(os.path.join(d_path, 'trainB'), train_transform) |
| testA = ImageFolder(os.path.join(d_path, 'testA'), test_transform) |
| testB = ImageFolder(os.path.join(d_path, 'testB'), test_transform) |
| |
| trainA_loader = data.DataLoader(trainA, batch_size=args.batch, |
| sampler=data_sampler(trainA, shuffle=True, distributed=args.distributed), drop_last=True, pin_memory=True, num_workers=5) |
| trainB_loader = data.DataLoader(trainB, batch_size=args.batch, |
| sampler=data_sampler(trainB, shuffle=True, distributed=args.distributed), drop_last=True, pin_memory=True, num_workers=5) |
|
|
| testA_loader = data.DataLoader(testA, batch_size=1, shuffle=False) |
| testB_loader = data.DataLoader(testB, batch_size=1, shuffle=False) |
|
|
|
|
| train(args, trainA_loader, trainB_loader, testA_loader, testB_loader, G_A2B, G_B2A, D_A, D_B, G_optim, D_optim, device) |
|
|