| import os |
| import time |
| import shutil |
|
|
| import torch |
| import cv2 |
| import torch.optim as optim |
| import numpy as np |
| from glob import glob |
| from torch.cuda.amp import GradScaler, autocast |
| from torch.nn.parallel.distributed import DistributedDataParallel |
| from torch.utils.data import Dataset, DataLoader |
| from tqdm import tqdm |
| from utils.image_processing import denormalize_input, preprocess_images, resize_image |
| from losses import LossSummary, AnimeGanLoss, to_gray_scale |
| from utils import load_checkpoint, save_checkpoint, read_image |
| from utils.common import set_lr |
| from color_transfer import color_transfer_pytorch |
|
|
|
|
def transfer_color_and_rescale(src, target):
    """Transfer color from src image to target then rescale to [-1, 1]"""
    # color_transfer_pytorch presumably yields values in [0, 1] — the
    # x / 0.5 - 1 affine map sends that range onto [-1, 1].
    recolored = color_transfer_pytorch(src, target)
    rescaled = (recolored / 0.5) - 1
    return rescaled
|
|
def gaussian_noise():
    """Draw one scalar sample from a Gaussian with mean 0.0 and std 0.1."""
    mean = torch.tensor(0.0)
    std = torch.tensor(0.1)
    return torch.normal(mean, std)
|
|
def convert_to_readable(seconds):
    """Render a duration in seconds as an 'HH:MM:SS' string.

    Uses gmtime, so durations wrap around at 24 hours.
    """
    as_struct = time.gmtime(seconds)
    return time.strftime('%H:%M:%S', as_struct)
|
|
|
|
def revert_to_np_image(image_tensor):
    """Convert a normalized (3, H, W) tensor back to an HWC numpy image.

    The channel axis is reversed at the end (RGB <-> BGR), matching the
    cv2.imwrite callers in this file.
    """
    arr = image_tensor.cpu().numpy().transpose(1, 2, 0)  # CHW -> HWC
    arr = denormalize_input(arr, dtype=np.int16)
    return arr[..., ::-1]
|
|
|
|
def save_generated_images(images: torch.Tensor, save_dir: str):
    """Save generated images `(*, 3, H, W)` range [-1, 1] into disk"""
    os.makedirs(save_dir, exist_ok=True)
    batch = images.clone().detach().cpu().numpy()
    batch = batch.transpose(0, 2, 3, 1)  # NCHW -> NHWC

    for index, image in enumerate(batch):
        image = denormalize_input(image, dtype=np.int16)
        # Reverse channel order (RGB <-> BGR) before handing to cv2.
        cv2.imwrite(os.path.join(save_dir, f"G{index}.jpg"), image[..., ::-1])
|
|
|
|
class DDPTrainer:
    """Mixin providing DistributedDataParallel + AMP setup for a trainer.

    Expects the concrete subclass to define `self.cfg`, `self.logger`,
    `self.G` (generator) and `self.D` (discriminator) before calling
    `_init_distributed` / `_init_amp`.
    """

    def _init_distributed(self):
        """Initialize the NCCL process group, convert BatchNorm to
        SyncBatchNorm and move both models onto the local GPU."""
        if self.cfg.ddp:
            self.logger.info("Setting up DDP")
            # BUGFIX: init_process_group() returns None, so the old code's
            # `self.pg = init_process_group(...)` only *looked* like it
            # captured a process group. Make that explicit: self.pg is None,
            # which tells convert_sync_batchnorm to use the default (WORLD)
            # group — identical behavior, no longer misleading.
            torch.distributed.init_process_group(
                backend="nccl",
                rank=self.cfg.local_rank,
                world_size=self.cfg.world_size
            )
            self.pg = None  # default (WORLD) process group
            self.G = torch.nn.SyncBatchNorm.convert_sync_batchnorm(self.G, self.pg)
            self.D = torch.nn.SyncBatchNorm.convert_sync_batchnorm(self.D, self.pg)
            torch.cuda.set_device(self.cfg.local_rank)
            self.G.cuda(self.cfg.local_rank)
            self.D.cuda(self.cfg.local_rank)
            self.logger.info("Setting up DDP Done")

    def _init_amp(self, enabled=False):
        """Create the AMP grad scalers and, under DDP, wrap both models in
        DistributedDataParallel.

        Args:
            enabled: Whether mixed-precision gradient scaling is active.
        """
        self.scaler_g = GradScaler(enabled=enabled)
        self.scaler_d = GradScaler(enabled=enabled)
        if self.cfg.ddp:
            self.G = DistributedDataParallel(
                self.G, device_ids=[self.cfg.local_rank],
                output_device=self.cfg.local_rank,
                find_unused_parameters=False)

            self.D = DistributedDataParallel(
                self.D, device_ids=[self.cfg.local_rank],
                output_device=self.cfg.local_rank,
                find_unused_parameters=False)
            self.logger.info("Set DistributedDataParallel")
|
|
|
|
class Trainer(DDPTrainer):
    """
    Base Trainer class.

    Owns the generator/discriminator pair, their Adam optimizers and AMP
    gradient scalers, and drives both the reconstruction pretraining of G
    and the adversarial G/D training loop.
    """

    def __init__(
        self,
        generator,
        discriminator,
        config,
        logger,
    ) -> None:
        self.G = generator
        self.D = discriminator
        self.cfg = config
        # Gradient-clipping threshold applied to both networks.
        self.max_norm = 10
        self.device_type = 'cuda' if self.cfg.device.startswith('cuda') else 'cpu'
        self.optimizer_g = optim.Adam(self.G.parameters(), lr=self.cfg.lr_g, betas=(0.5, 0.999))
        self.optimizer_d = optim.Adam(self.D.parameters(), lr=self.cfg.lr_d, betas=(0.5, 0.999))
        self.loss_tracker = LossSummary()
        if self.cfg.ddp:
            # One process per GPU under DDP.
            self.device = torch.device(f"cuda:{self.cfg.local_rank}")
            logger.info(f"---------{self.cfg.local_rank} {self.device}")
        else:
            self.device = torch.device(self.cfg.device)
        self.loss_fn = AnimeGanLoss(self.cfg, self.device, self.cfg.gray_adv)
        self.logger = logger
        self._init_working_dir()
        self._init_distributed()
        self._init_amp(enabled=self.cfg.amp)

    def _init_working_dir(self):
        """Init working directory for saving checkpoints and images."""
        os.makedirs(self.cfg.exp_dir, exist_ok=True)
        Gname = self.G.name
        Dname = self.D.name
        self.checkpoint_path_G_init = os.path.join(self.cfg.exp_dir, f"{Gname}_init.pt")
        self.checkpoint_path_G = os.path.join(self.cfg.exp_dir, f"{Gname}.pt")
        self.checkpoint_path_D = os.path.join(self.cfg.exp_dir, f"{Dname}.pt")
        self.save_image_dir = os.path.join(self.cfg.exp_dir, "generated_images")
        self.example_image_dir = os.path.join(self.cfg.exp_dir, "train_images")
        os.makedirs(self.save_image_dir, exist_ok=True)
        os.makedirs(self.example_image_dir, exist_ok=True)

    def init_weight_G(self, weight: str):
        """Init Generator weight"""
        return load_checkpoint(self.G, weight)

    def init_weight_D(self, weight: str):
        """Init Discriminator weight"""
        return load_checkpoint(self.D, weight)

    def pretrain_generator(self, train_loader, start_epoch):
        """
        Pretrain Generator to reconstruct its input image (content loss only).

        Args:
            train_loader: DataLoader yielding dicts with an "image" tensor.
            start_epoch: Epoch index to resume pretraining from.
        """
        set_lr(self.optimizer_g, self.cfg.init_lr)
        for epoch in range(start_epoch, self.cfg.init_epochs):
            # BUGFIX: the loss history used to accumulate across *all* epochs
            # and be re-summed on every batch (skewed average + O(n^2) cost).
            # Keep a per-epoch running mean instead.
            loss_total = 0.0
            loss_count = 0

            pbar = tqdm(train_loader)
            for data in pbar:
                img = data["image"].to(self.device)

                self.optimizer_g.zero_grad()

                with autocast(enabled=self.cfg.amp):
                    fake_img = self.G(img)
                    loss = self.loss_fn.content_loss_vgg(img, fake_img)

                self.scaler_g.scale(loss).backward()
                self.scaler_g.step(self.optimizer_g)
                self.scaler_g.update()

                if self.cfg.ddp:
                    torch.distributed.barrier()

                loss_total += loss.detach().item()
                loss_count += 1
                # BUGFIX: ':2f' (field width 2) -> ':.2f' (two decimals).
                pbar.set_description(
                    f'[Init Training G] content loss: {loss_total / loss_count:.2f}')

            save_checkpoint(self.G, self.checkpoint_path_G_init, self.optimizer_g, epoch)
            if self.cfg.local_rank == 0:
                self.generate_and_save(self.cfg.test_image_dir, subname='initg')
                self.logger.info(f"Epoch {epoch}/{self.cfg.init_epochs}")

        # Restore the adversarial-phase learning rate.
        set_lr(self.optimizer_g, self.cfg.lr_g)

    def train_epoch(self, epoch, train_loader):
        """Run one adversarial epoch: a D update then a G update per batch.

        Raises:
            ValueError: If the adversarial loss becomes NaN.
        """
        pbar = tqdm(train_loader, total=len(train_loader))
        for data in pbar:
            img = data["image"].to(self.device)
            anime = data["anime"].to(self.device)
            anime_gray = data["anime_gray"].to(self.device)
            anime_smt_gray = data["smooth_gray"].to(self.device)

            # ---- Discriminator step ----
            self.optimizer_d.zero_grad()

            with autocast(enabled=self.cfg.amp):
                # BUGFIX: detach G's output so the D loss does not backprop
                # into the generator (those grads were only being discarded
                # by optimizer_g.zero_grad() below — pure wasted compute).
                fake_img = self.G(img).detach()

                if self.cfg.d_noise:
                    # Noise injection on D's inputs.
                    fake_img += gaussian_noise()
                    anime += gaussian_noise()
                    anime_gray += gaussian_noise()
                    anime_smt_gray += gaussian_noise()

                if self.cfg.gray_adv:
                    fake_img = to_gray_scale(fake_img)

                fake_d = self.D(fake_img)
                real_anime_d = self.D(anime)
                real_anime_gray_d = self.D(anime_gray)
                real_anime_smt_gray_d = self.D(anime_smt_gray)

                loss_d = self.loss_fn.compute_loss_D(
                    fake_d,
                    real_anime_d,
                    real_anime_gray_d,
                    real_anime_smt_gray_d
                )

            self.scaler_d.scale(loss_d).backward()
            self.scaler_d.unscale_(self.optimizer_d)
            torch.nn.utils.clip_grad_norm_(self.D.parameters(), max_norm=self.max_norm)
            self.scaler_d.step(self.optimizer_d)
            self.scaler_d.update()
            if self.cfg.ddp:
                torch.distributed.barrier()
            self.loss_tracker.update_loss_D(loss_d)

            # ---- Generator step ----
            self.optimizer_g.zero_grad()

            with autocast(enabled=self.cfg.amp):
                fake_img = self.G(img)

                if self.cfg.gray_adv:
                    fake_d = self.D(to_gray_scale(fake_img))
                else:
                    fake_d = self.D(fake_img)

                (
                    adv_loss, con_loss,
                    gra_loss, col_loss,
                    tv_loss
                ) = self.loss_fn.compute_loss_G(
                    fake_img,
                    img,
                    fake_d,
                    anime_gray,
                )
                loss_g = adv_loss + con_loss + gra_loss + col_loss + tv_loss
                if torch.isnan(adv_loss).any():
                    self.logger.info("----------------------------------------------")
                    self.logger.info(fake_d)
                    self.logger.info(adv_loss)
                    self.logger.info("----------------------------------------------")
                    raise ValueError("NAN loss!!")

            self.scaler_g.scale(loss_g).backward()
            # BUGFIX: unscale with the *generator* scaler (was scaler_d).
            # With the wrong scaler, grads were unscaled by D's scale factor
            # and scaler_g.step() then unscaled them a second time, so both
            # the clipping threshold and the applied step were wrong.
            self.scaler_g.unscale_(self.optimizer_g)
            grad = torch.nn.utils.clip_grad_norm_(self.G.parameters(), max_norm=self.max_norm)
            self.scaler_g.step(self.optimizer_g)
            self.scaler_g.update()
            if self.cfg.ddp:
                torch.distributed.barrier()

            self.loss_tracker.update_loss_G(adv_loss, gra_loss, col_loss, con_loss)
            pbar.set_description(f"{self.loss_tracker.get_loss_description()} - {grad:.3f}")

    def get_train_loader(self, dataset):
        """Build a DataLoader (with a DistributedSampler when DDP is on)."""
        if self.cfg.ddp:
            train_sampler = torch.utils.data.distributed.DistributedSampler(dataset)
        else:
            train_sampler = None
        return DataLoader(
            dataset,
            batch_size=self.cfg.batch_size,
            num_workers=self.cfg.num_workers,
            pin_memory=True,
            # The DistributedSampler shuffles by itself; only shuffle here
            # when no sampler is used.
            shuffle=train_sampler is None,
            sampler=train_sampler,
            drop_last=True,
        )

    def maybe_increase_imgsz(self, epoch, train_dataset):
        """
        Increase image size at specific epoch
        + 50% epochs train at imgsz[0]
        + the rest 50% will increase every `len(epochs) / 2 / (len(imgsz) - 1)`

        Args:
            epoch: Current epoch
            train_dataset: Dataset

        Examples:
        ```
        epochs = 100
        imgsz = [256, 352, 416, 512]
        => [(0, 256), (50, 352), (66, 416), (82, 512)]
        ```
        """
        epochs = self.cfg.epochs
        imgsz = self.cfg.imgsz
        num_size_remains = len(imgsz) - 1
        half_epochs = epochs // 2

        if len(imgsz) == 1:
            new_size = imgsz[0]
        elif epoch < half_epochs:
            # First half of training always runs at the smallest size.
            new_size = imgsz[0]
        else:
            per_epoch_increment = int(half_epochs / num_size_remains)
            found = None
            for i, size in enumerate(imgsz):
                if epoch < half_epochs + per_epoch_increment * i:
                    found = size
                    break
            if found is None:
                # Past the last threshold: stay at the largest size.
                found = imgsz[-1]
            new_size = found

        self.logger.info(f"Check {imgsz}, {new_size}, {train_dataset.imgsz}")
        if new_size != train_dataset.imgsz:
            train_dataset.set_imgsz(new_size)
            self.logger.info(f"Increase image size to {new_size} at epoch {epoch}")

    def train(self, train_dataset: Dataset, start_epoch=0, start_epoch_g=0):
        """
        Train Generator and Discriminator.

        Args:
            train_dataset: Dataset yielding dicts of image tensors
                ("image", "anime", "anime_gray", "smooth_gray").
            start_epoch: Epoch to resume adversarial training from.
            start_epoch_g: Epoch to resume generator pretraining from.
        """
        self.logger.info(self.device)
        self.G.to(self.device)
        self.D.to(self.device)

        self.pretrain_generator(self.get_train_loader(train_dataset), start_epoch_g)

        if self.cfg.local_rank == 0:
            self.logger.info(f"Start training for {self.cfg.epochs} epochs")

        # Dump the first few training samples to disk for visual inspection.
        for i, data in enumerate(train_dataset):
            for k in data.keys():
                image = data[k]
                cv2.imwrite(
                    os.path.join(self.example_image_dir, f"data_{k}_{i}.jpg"),
                    revert_to_np_image(image)
                )
            if i == 2:
                break

        per_epoch_times = []
        for epoch in range(start_epoch, self.cfg.epochs):
            self.maybe_increase_imgsz(epoch, train_dataset)

            start = time.time()
            self.train_epoch(epoch, self.get_train_loader(train_dataset))

            if epoch % self.cfg.save_interval == 0 and self.cfg.local_rank == 0:
                save_checkpoint(self.G, self.checkpoint_path_G, self.optimizer_g, epoch)
                save_checkpoint(self.D, self.checkpoint_path_D, self.optimizer_d, epoch)
                self.generate_and_save(self.cfg.test_image_dir)

            if epoch % 10 == 0:
                self.copy_results(epoch)

            # ETA is reported on rank 0 only, from the mean epoch duration.
            # (Removed an unreachable `end is None` branch and an unused
            # `num_iter` counter from the original.)
            if self.cfg.local_rank == 0:
                per_epoch_times.append(time.time() - start)
                eta = np.mean(per_epoch_times) * (self.cfg.epochs - epoch)
                eta = convert_to_readable(eta)
                self.logger.info(f"epoch {epoch}/{self.cfg.epochs}, ETA: {eta}")

    def generate_and_save(
        self,
        image_dir,
        max_imgs=15,
        subname='gen'
    ):
        '''
        Run G over images from `image_dir` and save side-by-side
        (real | generated) comparisons into `self.save_image_dir`.

        Args:
            image_dir: Directory containing the test images.
            max_imgs: Maximum number of images to process.
            subname: Filename prefix for the saved comparisons.
        '''
        start = time.time()
        self.G.eval()

        fake_imgs = []
        real_imgs = []
        image_files = glob(os.path.join(image_dir, "*"))

        for i, image_file in enumerate(image_files):
            image = read_image(image_file)
            image = resize_image(image)
            real_imgs.append(image.copy())
            image = preprocess_images(image)
            image = image.to(self.device)
            with torch.no_grad():
                with autocast(enabled=self.cfg.amp):
                    fake_img = self.G(image)

            fake_img = fake_img.detach().cpu().numpy()
            fake_img = fake_img.transpose(0, 2, 3, 1)  # NCHW -> NHWC
            fake_imgs.append(denormalize_input(fake_img, dtype=np.int16)[0])

            if i + 1 == max_imgs:
                break

        # BUGFIX: restore training mode. G was previously left in eval mode
        # after the first checkpoint, silently changing BatchNorm behavior
        # for all subsequent training steps.
        self.G.train()

        for i, (real_img, fake_img) in enumerate(zip(real_imgs, fake_imgs)):
            img = np.concatenate((real_img, fake_img), axis=1)
            save_path = os.path.join(self.save_image_dir, f'{subname}_{i}.jpg')
            # Channel flip (RGB <-> BGR) for cv2.
            if not cv2.imwrite(save_path, img[..., ::-1]):
                self.logger.info(f"Save generated image failed, {save_path}, {img.shape}")
        elapsed = time.time() - start
        self.logger.info(f"Generated {len(fake_imgs)} images in {elapsed:.3f}s.")

    def copy_results(self, epoch):
        """Copy result (Weight + Generated images) to each epoch folder
        Every N epoch
        """
        copy_dir = os.path.join(self.cfg.exp_dir, f"epoch_{epoch}")
        os.makedirs(copy_dir, exist_ok=True)

        shutil.copy2(
            self.checkpoint_path_G,
            copy_dir
        )

        dest = os.path.join(copy_dir, os.path.basename(self.save_image_dir))
        shutil.copytree(
            self.save_image_dir,
            dest,
            dirs_exist_ok=True
        )
|
|