| | import argparse |
| |
|
| | import os |
| |
|
| | import numpy as np |
| | from PIL import Image |
| | from skimage import color, io |
| | import torch |
| | from torch import nn, optim |
| | from torch.nn import functional as F |
| | from torch.utils import data |
| | from torchvision import transforms |
| | from tqdm import tqdm |
| |
|
| | |
| | from models import ColorEncoder, ColorUNet |
| | from vgg_model import vgg19 |
| | from data.data_loader import MultiResolutionDataset |
| |
|
| | from utils import tensor_lab2rgb |
| |
|
| | from distributed import ( |
| | get_rank, |
| | synchronize, |
| | reduce_loss_dict, |
| | ) |
| |
|
def mkdirss(dirpath):
    """Create ``dirpath`` (including missing parents) if it does not exist.

    Args:
        dirpath: Directory path to create.

    Raises:
        FileExistsError: If ``dirpath`` exists but is not a directory.
    """
    # exist_ok avoids the check-then-create race when several processes
    # (e.g. distributed ranks) try to create the same directory at once.
    os.makedirs(dirpath, exist_ok=True)
| |
|
def data_sampler(dataset, shuffle, distributed):
    """Pick the sampler that matches the requested training setup.

    Args:
        dataset: Dataset the sampler will index into.
        shuffle: Whether samples should be drawn in random order.
        distributed: Whether to return a ``DistributedSampler``.

    Returns:
        A ``DistributedSampler`` when ``distributed`` is truthy, otherwise a
        ``RandomSampler`` (``shuffle`` truthy) or a ``SequentialSampler``.
    """
    if distributed:
        return data.distributed.DistributedSampler(dataset, shuffle=shuffle)

    sampler_cls = data.RandomSampler if shuffle else data.SequentialSampler
    return sampler_cls(dataset)
| |
|
| |
|
def requires_grad(model, flag=True):
    """Set ``requires_grad`` to ``flag`` on every parameter of ``model``.

    Args:
        model: Object exposing ``parameters()`` (e.g. an ``nn.Module``).
        flag: Value assigned to each parameter's ``requires_grad``.
    """
    for param in model.parameters():
        param.requires_grad = flag
| |
|
| |
|
def sample_data(loader):
    """Yield batches from ``loader`` forever, restarting it after each pass.

    Before every pass, if ``loader.sampler`` exposes ``set_epoch`` (as
    ``DistributedSampler`` does), it is called with the pass index so that
    distributed shuffling produces a different order on each pass instead of
    repeating the first one.

    Args:
        loader: Re-iterable source of batches (e.g. a ``DataLoader``).

    Yields:
        Each batch produced by ``loader``, cycling indefinitely.

    Raises:
        ValueError: If a full pass over ``loader`` produces no batch; without
            this guard the generator would spin forever without yielding.
    """
    epoch = 0
    while True:
        set_epoch = getattr(getattr(loader, "sampler", None), "set_epoch", None)
        if callable(set_epoch):
            set_epoch(epoch)

        produced = False
        for batch in loader:
            produced = True
            yield batch

        if not produced:
            raise ValueError("loader yielded no batches; cannot sample data")
        epoch += 1
| |
|
def Lab2RGB_out(img_lab):
    """Convert the first item of a Lab batch tensor to a uint8 RGB array.

    Channel 0 is shifted by +50 before conversion (callers in this file pass
    a centred lightness channel), the result of ``color.lab2rgb`` is clipped
    to [0, 1] and scaled to 0-255.

    Args:
        img_lab: 4-D tensor indexed as (batch, channel, row, col); channel 0
            is treated as lightness and the remaining channels as chroma.

    Returns:
        ``uint8`` NumPy array in (row, col, channel) layout for batch item 0.
    """
    lab_cpu = img_lab.detach().cpu()

    # Undo the lightness centring; chroma channels pass through unchanged.
    lightness = lab_cpu[:, :1, :, :] + 50
    chroma = lab_cpu[:, 1:, :, :]
    first_lab = torch.cat((lightness, chroma), 1)[0, ...].numpy()

    # skimage expects channels last.
    rgb = color.lab2rgb(first_lab.transpose(1, 2, 0))
    scaled = np.clip(rgb, 0, 1) * 255
    return scaled.astype("uint8")
| |
|
def RGB2Lab(inputs):
    """Convert an RGB image to Lab using skimage's ``color.rgb2lab``.

    Args:
        inputs: RGB image accepted by ``color.rgb2lab``.

    Returns:
        The Lab image returned by ``color.rgb2lab``.
    """
    lab = color.rgb2lab(inputs)
    return lab
| |
|
def Normalize(inputs):
    """Centre the first channel of a channels-last Lab array by subtracting 50.

    Args:
        inputs: 3-D array indexed as (row, col, channel); channel 0 is shifted,
            channels 1-2 are copied unchanged. The input is not modified.

    Returns:
        New ``float32`` array holding the shifted channel 0 followed by
        channels 1-2.
    """
    lightness, chroma = inputs[:, :, 0:1], inputs[:, :, 1:3]
    centred = np.concatenate((lightness - 50, chroma), 2)
    return centred.astype('float32')
| |
|
def numpy2tensor(inputs):
    """Turn a channels-last NumPy array into a channels-first tensor.

    Args:
        inputs: 3-D NumPy array indexed as (row, col, channel).

    Returns:
        Tensor indexed as (channel, row, col), created with
        ``torch.from_numpy`` from the transposed array.
    """
    channels_first = inputs.transpose(2, 0, 1)
    return torch.from_numpy(channels_first)
| |
|
def tensor2numpy(inputs):
    """Return the first batch item of a tensor as a channels-last array.

    Args:
        inputs: Tensor indexed as (batch, channel, row, col).

    Returns:
        NumPy array indexed as (row, col, channel) for batch item 0, detached
        from the graph and moved to the CPU.
    """
    first_item = inputs[0, ...].detach().cpu().numpy()
    return first_item.transpose(1, 2, 0)
| |
|
def preprocessing(inputs):
    """Build batched RGB and centred-Lab tensors from a single RGB image.

    Args:
        inputs: RGB image (the training loop passes a PIL image); it is handed
            unchanged to ``RGB2Lab`` and also converted to a ``float32`` array.

    Returns:
        Tuple ``(rgb, lab)`` of tensors, each with a leading batch dimension
        of 1 and channels first. ``rgb`` holds the raw pixel values as
        ``float32``; ``lab`` is the output of ``Normalize(RGB2Lab(inputs))``.
    """
    lab_tensor = numpy2tensor(Normalize(RGB2Lab(inputs)))
    rgb_tensor = numpy2tensor(np.array(inputs, 'float32'))
    return rgb_tensor.unsqueeze(0), lab_tensor.unsqueeze(0)
| |
|
def uncenter_l(inputs):
    """Add 50 to channel 0 of a 4-D tensor, leaving other channels unchanged.

    Args:
        inputs: Tensor indexed as (batch, channel, row, col).

    Returns:
        New tensor of the same shape with channel 0 shifted by +50.
    """
    shifted_first = inputs[:, :1, :, :] + 50
    remaining = inputs[:, 1:, :, :]
    return torch.cat((shifted_first, remaining), 1)
| |
|
# Divisors applied to the L1 distance of VGG feature levels 0-4; each term is
# additionally scaled by 0.1.
_FEATURE_LOSS_DIVISORS = (32, 16, 8, 4, 1)


def _save_validation_samples(
    args,
    colorEncoder,
    colorUNet,
    device,
    step,
    val_dir='test_datasets/val_Manga',
    num_pairs=15,
    imgsize=256,
):
    """Colorize the fixed validation pairs and write side-by-side previews.

    For k = 1..num_pairs, ``in{k}.jpg`` is colorized with ``ref{k}.jpg`` as the
    reference; input, reference and result are concatenated horizontally and
    saved under ``training_logs/<experiment_name>/<step>/``. Both models are
    switched to eval mode; the caller puts them back in train mode.
    """
    out_dir = 'training_logs/%s/%06d' % (args.experiment_name, step)
    mkdirss(out_dir)

    with torch.no_grad():
        colorEncoder.eval()
        colorUNet.eval()

        for k in range(1, num_pairs + 1):
            with Image.open('%s/in%d.jpg' % (val_dir, k)) as src:
                val_img = src.convert("RGB").resize((imgsize, imgsize))
            with Image.open('%s/ref%d.jpg' % (val_dir, k)) as src:
                val_img_ref = src.convert("RGB").resize((imgsize, imgsize))

            val_img, val_img_lab = preprocessing(val_img)
            val_img_ref, _ = preprocessing(val_img_ref)

            val_img_lab = val_img_lab.to(device)
            val_img_ref = val_img_ref.to(device)

            val_img_l = val_img_lab[:, :1, :, :] / 50.

            ref_color_vector = colorEncoder(val_img_ref / 255.)
            fake_swap_ab = colorUNet((val_img_l, ref_color_vector))

            # Lightness stays centred here; Lab2RGB_out adds the +50 back.
            fake_img = torch.cat((val_img_l * 50, fake_swap_ab * 110), 1)

            sample = np.concatenate(
                (
                    tensor2numpy(val_img),
                    tensor2numpy(val_img_ref),
                    Lab2RGB_out(fake_img),
                ),
                1,
            )
            io.imsave(
                '%s/in%d_ref%d.png' % (out_dir, k, k),
                sample.astype('uint8'),
            )

    torch.cuda.empty_cache()


def _save_checkpoint(args, colorEncoder_module, colorUNet_module, g_optim, step):
    """Write model, optimizer and args to ``experiments/<name>/<step>.pt``."""
    out_dir = "experiments/%s" % (args.experiment_name)
    mkdirss(out_dir)
    torch.save(
        {
            "colorEncoder": colorEncoder_module.state_dict(),
            "colorUNet": colorUNet_module.state_dict(),
            "g_optim": g_optim.state_dict(),
            "args": args,
        },
        f"{out_dir}/{step:06d}.pt",
    )


def train(
    args,
    loader,
    colorEncoder,
    colorUNet,
    vggnet,
    g_optim,
    device,
):
    """Run the training loop for the color encoder and the color U-Net.

    Each iteration encodes the reference image into a color vector, predicts
    the ab channels of the target from its L channel and that vector, and
    optimizes a smooth-L1 reconstruction loss on ab plus a weighted L1 loss
    between VGG features of the real and the re-colorized RGB image.

    On rank 0 the progress bar is updated every iteration, running loss
    averages are printed every 50 iterations, validation previews are written
    every 500 iterations and a checkpoint every 2500 iterations.

    Args:
        args: Parsed options; reads ``iter``, ``start_iter``, ``distributed``
            and ``experiment_name``.
        loader: Data loader yielding ``(img, img_ref, img_lab)`` batches.
        colorEncoder: Reference encoder (optionally DDP-wrapped).
        colorUNet: Colorization network (optionally DDP-wrapped).
        vggnet: Feature extractor called with ``layer_name='all'``.
        g_optim: Optimizer over both trainable networks.
        device: Device the batches are moved to.
    """
    loader = sample_data(loader)

    is_main = get_rank() == 0
    pbar = range(args.iter)
    if is_main:
        pbar = tqdm(pbar, initial=args.start_iter, dynamic_ncols=True, smoothing=0.01)

    loss_dict = {}
    recon_val_all = 0
    fea_val_all = 0

    # Checkpoints store the bare modules, not the DDP wrappers.
    if args.distributed:
        colorEncoder_module = colorEncoder.module
        colorUNet_module = colorUNet.module
    else:
        colorEncoder_module = colorEncoder
        colorUNet_module = colorUNet

    for idx in pbar:
        i = idx + args.start_iter + 1

        if i > args.iter:
            print("Done!")
            break

        img, img_ref, img_lab = next(loader)

        img = img.to(device)
        img_lab = img_lab.to(device)
        img_ref = img_ref.to(device)

        # Scale the Lab channels before feeding / comparing them.
        img_l = img_lab[:, :1, :, :] / 50
        img_ab = img_lab[:, 1:, :, :] / 110

        # Validation switches the models to eval mode, so reset every step.
        colorEncoder.train()
        colorUNet.train()

        requires_grad(colorEncoder, True)
        requires_grad(colorUNet, True)

        ref_color_vector = colorEncoder(img_ref / 255.)
        fake_swap_ab = colorUNet((img_l, ref_color_vector))

        recon_loss = (F.smooth_l1_loss(fake_swap_ab, img_ab)) * 1

        real_img_rgb = img / 255.
        features_A = vggnet(real_img_rgb, layer_name='all')

        fake_swap_rgb = tensor_lab2rgb(
            torch.cat((img_l * 50 + 50, fake_swap_ab * 110), 1)
        )
        features_B = vggnet(fake_swap_rgb, layer_name='all')

        fea_loss = sum(
            F.l1_loss(features_A[level], features_B[level]) / divisor * 0.1
            for level, divisor in enumerate(_FEATURE_LOSS_DIVISORS)
        )

        loss_dict["recon"] = recon_loss
        loss_dict["fea"] = fea_loss

        g_optim.zero_grad()
        (recon_loss + fea_loss).backward()
        g_optim.step()

        loss_reduced = reduce_loss_dict(loss_dict)

        recon_val = loss_reduced["recon"].mean().item()
        recon_val_all += recon_val

        fea_val = loss_reduced["fea"].mean().item()
        fea_val_all += fea_val

        # Logging and file output happen on rank 0 only so that several
        # processes never write the same preview or checkpoint file.
        if is_main:
            pbar.set_description(f"recon:{recon_val:.4f}; fea:{fea_val:.4f};")

            if i % 50 == 0:
                print(f"recon_all:{recon_val_all/50:.4f}; fea_all:{fea_val_all/50:.4f};")
                recon_val_all = 0
                fea_val_all = 0

            if i % 500 == 0:
                _save_validation_samples(args, colorEncoder, colorUNet, device, i)

            if i % 2500 == 0:
                _save_checkpoint(
                    args, colorEncoder_module, colorUNet_module, g_optim, i
                )
| |
|
| |
|
if __name__ == "__main__":
    # Script entry point: parse options, build the networks, optimizer and
    # data loader, optionally resume from a checkpoint, then run train().
    device = "cuda"

    torch.backends.cudnn.benchmark = True

    parser = argparse.ArgumentParser()

    parser.add_argument("--datasets", type=str)
    parser.add_argument("--iter", type=int, default=100000)
    parser.add_argument("--batch", type=int, default=16)
    parser.add_argument("--size", type=int, default=256)
    parser.add_argument("--ckpt", type=str, default=None)
    parser.add_argument("--lr", type=float, default=0.0001)
    parser.add_argument("--experiment_name", type=str, default="default")
    parser.add_argument("--wandb", action="store_true")
    # Launchers differ in how they hand over the local rank: as
    # ``--local_rank``, as ``--local-rank``, or through the LOCAL_RANK
    # environment variable. Accept all three.
    parser.add_argument(
        "--local_rank",
        "--local-rank",
        dest="local_rank",
        type=int,
        default=int(os.environ.get("LOCAL_RANK", 0)),
    )

    args = parser.parse_args()

    n_gpu = int(os.environ.get("WORLD_SIZE", 1))
    args.distributed = n_gpu > 1

    if args.distributed:
        torch.cuda.set_device(args.local_rank)
        torch.distributed.init_process_group(backend="nccl", init_method="env://")
        synchronize()

    args.start_iter = 0

    # Frozen VGG19 used only as a feature extractor for the feature loss.
    vggnet = vgg19(
        pretrained_path='./experiments/VGG19/vgg19-dcbb9e9d.pth',
        require_grad=False,
    )
    vggnet = vggnet.to(device)
    vggnet.eval()

    colorEncoder = ColorEncoder(color_dim=512).to(device)
    colorUNet = ColorUNet(bilinear=True).to(device)

    # One optimizer drives both trainable networks.
    g_optim = optim.Adam(
        list(colorEncoder.parameters()) + list(colorUNet.parameters()),
        lr=args.lr,
        betas=(0.9, 0.99),
    )

    if args.ckpt is not None:
        print("load model:", args.ckpt)

        # NOTE(review): torch.load unpickles arbitrary objects (the checkpoint
        # also stores the argparse namespace) - only load trusted files.
        ckpt = torch.load(args.ckpt, map_location=lambda storage, loc: storage)

        # Checkpoints are named "<iteration>.pt"; resume counting from that
        # number when the file name is numeric, otherwise start from 0.
        ckpt_stem = os.path.splitext(os.path.basename(args.ckpt))[0]
        try:
            args.start_iter = int(ckpt_stem)
        except ValueError:
            pass

        colorEncoder.load_state_dict(ckpt["colorEncoder"])
        colorUNet.load_state_dict(ckpt["colorUNet"])
        g_optim.load_state_dict(ckpt["g_optim"])

    if args.distributed:
        colorEncoder = nn.parallel.DistributedDataParallel(
            colorEncoder,
            device_ids=[args.local_rank],
            output_device=args.local_rank,
            broadcast_buffers=False,
        )

        colorUNet = nn.parallel.DistributedDataParallel(
            colorUNet,
            device_ids=[args.local_rank],
            output_device=args.local_rank,
            broadcast_buffers=False,
        )

    transform = transforms.Compose(
        [
            transforms.RandomHorizontalFlip(),
            transforms.RandomVerticalFlip(),
            transforms.RandomRotation(degrees=(0, 360)),
        ]
    )

    datasets = [MultiResolutionDataset(args.datasets, transform, args.size)]

    # Build the sampler over the same concatenated dataset the loader
    # iterates, so indices stay valid if more datasets are appended.
    train_set = data.ConcatDataset(datasets)
    loader = data.DataLoader(
        train_set,
        batch_size=args.batch,
        sampler=data_sampler(train_set, shuffle=True, distributed=args.distributed),
        drop_last=True,
    )

    train(
        args,
        loader,
        colorEncoder,
        colorUNet,
        vggnet,
        g_optim,
        device,
    )
| |
|
| |
|