import os
import random

import matplotlib

# Select the non-interactive Agg backend BEFORE pyplot is imported: older
# matplotlib releases ignore (and warn about) matplotlib.use() once pyplot has
# been loaded. Agg only renders to files, which is all log_images needs.
matplotlib.use('Agg')
import matplotlib.pyplot as plt  # noqa: E402  (must follow matplotlib.use)

import torch
from torch import nn, autograd
from torch.utils.data import DataLoader
from torch.utils.tensorboard import SummaryWriter
import torch.nn.functional as F

from utils import common, train_utils
from criteria import id_loss, moco_loss
from configs import data_configs
from datasets.images_dataset import ImagesDataset
from criteria.lpips.lpips import LPIPS
from models.psp import pSp
from models.latent_codes_pool import LatentCodesPool
from models.discriminator import LatentCodesDiscriminator
from models.encoders.psp_encoders import ProgressiveStage
from training.ranger import Ranger

# Seed Python's and PyTorch's RNGs (torch.manual_seed seeds every device) so
# runs are repeatable; note that NumPy and cuDNN nondeterminism are not
# controlled here.
random.seed(0)
torch.manual_seed(0)
class Coach:
    """Training driver for a ``pSp`` network configured by ``opts``.

    Owns the network and its optimizer, the optional latent-code discriminator
    (with its own optimizer and latent-code pools), the train/test
    dataloaders, a TensorBoard ``SummaryWriter`` logging to
    ``<exp_dir>/logs``, and the checkpoint directory ``<exp_dir>/checkpoints``.
    """

    def __init__(self, opts, prev_train_checkpoint=None):
        """Build every training component from ``opts``.

        Args:
            opts: Option namespace. Mutated in place: ``opts.device`` is set,
                and ``opts.save_interval`` falls back to ``opts.max_steps``
                when it is ``None``.
            prev_train_checkpoint: Optional checkpoint dict (the layout written
                by ``checkpoint_me`` when ``opts.save_training_data`` is set)
                to resume training from.
        """
        self.opts = opts

        self.global_step = 0

        # NOTE(review): the device is hard-coded to the first CUDA GPU.
        self.device = 'cuda:0'
        self.opts.device = self.device

        self.net = pSp(self.opts).to(self.device)

        # Loss modules are only created when their weight is positive;
        # calc_loss guards every use with the same condition.
        if self.opts.lpips_lambda > 0:
            self.lpips_loss = LPIPS(net_type=self.opts.lpips_type).to(self.device).eval()
        if self.opts.id_lambda > 0:
            # 'ffhq' / 'celeb' dataset types use IDLoss; all others use MocoLoss.
            if 'ffhq' in self.opts.dataset_type or 'celeb' in self.opts.dataset_type:
                self.id_loss = id_loss.IDLoss().to(self.device).eval()
            else:
                self.id_loss = moco_loss.MocoLoss(opts).to(self.device).eval()
        # Kept as a public attribute; calc_loss itself uses F.mse_loss.
        self.mse_loss = nn.MSELoss().to(self.device).eval()

        self.optimizer = self.configure_optimizers()

        # Optional adversarial loss on the encoder's latent codes.
        if self.opts.w_discriminator_lambda > 0:
            self.discriminator = LatentCodesDiscriminator(512, 4).to(self.device)
            self.discriminator_optimizer = torch.optim.Adam(list(self.discriminator.parameters()),
                                                            lr=opts.w_discriminator_lr)
            self.real_w_pool = LatentCodesPool(self.opts.w_pool_size)
            self.fake_w_pool = LatentCodesPool(self.opts.w_pool_size)

        # Datasets and loaders; drop_last=True means every batch is full-size.
        self.train_dataset, self.test_dataset = self.configure_datasets()
        self.train_dataloader = DataLoader(self.train_dataset,
                                           batch_size=self.opts.batch_size,
                                           shuffle=True,
                                           num_workers=int(self.opts.workers),
                                           drop_last=True)
        self.test_dataloader = DataLoader(self.test_dataset,
                                          batch_size=self.opts.test_batch_size,
                                          shuffle=False,
                                          num_workers=int(self.opts.test_workers),
                                          drop_last=True)

        # TensorBoard logging.
        log_dir = os.path.join(opts.exp_dir, 'logs')
        os.makedirs(log_dir, exist_ok=True)
        self.logger = SummaryWriter(log_dir=log_dir)

        # Checkpointing.
        self.checkpoint_dir = os.path.join(opts.exp_dir, 'checkpoints')
        os.makedirs(self.checkpoint_dir, exist_ok=True)
        self.best_val_loss = None
        if self.opts.save_interval is None:
            self.opts.save_interval = self.opts.max_steps

        if prev_train_checkpoint is not None:
            self.load_from_train_checkpoint(prev_train_checkpoint)

    def load_from_train_checkpoint(self, ckpt):
        """Restore training state from checkpoint dict ``ckpt``.

        Training resumes at the step after the saved one. The main optimizer
        state is restored only when ``opts.keep_optimizer`` is set; the
        discriminator and its optimizer are restored whenever the
        discriminator is enabled.
        """
        print('Loading previous training data...')
        self.global_step = ckpt['global_step'] + 1
        self.best_val_loss = ckpt['best_val_loss']
        self.net.load_state_dict(ckpt['state_dict'])

        if self.opts.keep_optimizer:
            self.optimizer.load_state_dict(ckpt['optimizer'])
        if self.opts.w_discriminator_lambda > 0:
            self.discriminator.load_state_dict(ckpt['discriminator_state_dict'])
            self.discriminator_optimizer.load_state_dict(ckpt['discriminator_optimizer_state_dict'])
        if self.opts.progressive_steps:
            self.check_for_progressive_training_update(is_resume_from_ckpt=True)
        print(f'Resuming training from step {self.global_step}')

    def train(self):
        """Run the training loop until ``opts.max_steps`` is reached.

        Each step optionally updates the discriminator, then runs the
        encoder forward/loss/backward/step. It also handles periodic image
        and metric logging, validation (tracking the best validation loss)
        and checkpointing.

        Raises:
            ValueError: if steps remain but the training dataloader yields no
                batches; ``global_step`` could never advance, so the loop
                would otherwise spin forever.
        """
        self.net.train()
        if self.opts.progressive_steps:
            self.check_for_progressive_training_update()
        if self.global_step < self.opts.max_steps and len(self.train_dataloader) == 0:
            raise ValueError('Training dataloader yields no batches; is the training set '
                             'smaller than batch_size (drop_last=True)?')
        while self.global_step < self.opts.max_steps:
            for batch_idx, batch in enumerate(self.train_dataloader):
                loss_dict = {}
                if self.is_training_discriminator():
                    loss_dict = self.train_discriminator(batch)
                x, y, y_hat, latent = self.forward(batch)
                loss, encoder_loss_dict, id_logs = self.calc_loss(x, y, y_hat, latent)
                loss_dict = {**loss_dict, **encoder_loss_dict}
                self.optimizer.zero_grad()
                loss.backward()
                self.optimizer.step()

                # Logging: images every image_interval steps (and every 25 steps
                # during the first 1000), scalars every board_interval steps.
                if self.global_step % self.opts.image_interval == 0 or (
                        self.global_step < 1000 and self.global_step % 25 == 0):
                    self.parse_and_log_images(id_logs, x, y, y_hat, title='images/train/faces')
                if self.global_step % self.opts.board_interval == 0:
                    self.print_metrics(loss_dict, prefix='train')
                    self.log_metrics(loss_dict, prefix='train')

                # Validation; validate() returns None for the step-0 sanity pass.
                val_loss_dict = None
                if self.global_step % self.opts.val_interval == 0 or self.global_step == self.opts.max_steps:
                    val_loss_dict = self.validate()
                    if val_loss_dict and (self.best_val_loss is None or val_loss_dict['loss'] < self.best_val_loss):
                        self.best_val_loss = val_loss_dict['loss']
                        self.checkpoint_me(val_loss_dict, is_best=True)

                # Periodic checkpoint; prefer validation metrics when available.
                if self.global_step % self.opts.save_interval == 0 or self.global_step == self.opts.max_steps:
                    if val_loss_dict is not None:
                        self.checkpoint_me(val_loss_dict, is_best=False)
                    else:
                        self.checkpoint_me(loss_dict, is_best=False)

                if self.global_step == self.opts.max_steps:
                    print('OMG, finished training!')
                    break

                self.global_step += 1
                if self.opts.progressive_steps:
                    self.check_for_progressive_training_update()

    def check_for_progressive_training_update(self, is_resume_from_ckpt=False):
        """Move the encoder to the progressive stage for the current step.

        When resuming, every stage whose start step has been reached is
        applied in order, so the last one reached wins. Otherwise the stage
        changes only on the exact step listed in ``opts.progressive_steps``.
        """
        for i in range(len(self.opts.progressive_steps)):
            if is_resume_from_ckpt and self.global_step >= self.opts.progressive_steps[i]:
                self.net.encoder.set_progressive_stage(ProgressiveStage(i))
            if self.global_step == self.opts.progressive_steps[i]:
                self.net.encoder.set_progressive_stage(ProgressiveStage(i))

    def validate(self):
        """Evaluate on the test loader, log images and metrics, and return the aggregate.

        At ``global_step == 0`` only a quick sanity pass over the first five
        batches is run and ``None`` is returned; no metrics are logged.
        The network is put back into train mode before returning.
        """
        self.net.eval()
        agg_loss_dict = []
        for batch_idx, batch in enumerate(self.test_dataloader):
            cur_loss_dict = {}
            if self.is_training_discriminator():
                cur_loss_dict = self.validate_discriminator(batch)
            with torch.no_grad():
                x, y, y_hat, latent = self.forward(batch)
                loss, cur_encoder_loss_dict, id_logs = self.calc_loss(x, y, y_hat, latent)
                cur_loss_dict = {**cur_loss_dict, **cur_encoder_loss_dict}
            agg_loss_dict.append(cur_loss_dict)

            self.parse_and_log_images(id_logs, x, y, y_hat,
                                      title='images/test/faces',
                                      subscript='{:04d}'.format(batch_idx))

            # Step-0 sanity check: stop early, results are not representative.
            if self.global_step == 0 and batch_idx >= 4:
                self.net.train()
                return None

        loss_dict = train_utils.aggregate_loss_dict(agg_loss_dict)
        self.log_metrics(loss_dict, prefix='test')
        self.print_metrics(loss_dict, prefix='test')

        self.net.train()
        return loss_dict

    def checkpoint_me(self, loss_dict, is_best):
        """Save a checkpoint and append a summary line to ``timestamp.txt``.

        Best checkpoints overwrite ``best_model.pt``; periodic ones are
        written to ``iteration_<step>.pt``.
        """
        save_name = 'best_model.pt' if is_best else 'iteration_{}.pt'.format(self.global_step)
        save_dict = self.__get_save_dict()
        checkpoint_path = os.path.join(self.checkpoint_dir, save_name)
        torch.save(save_dict, checkpoint_path)
        with open(os.path.join(self.checkpoint_dir, 'timestamp.txt'), 'a') as f:
            if is_best:
                f.write(
                    '**Best**: Step - {}, Loss - {:.3f} \n{}\n'.format(self.global_step, self.best_val_loss, loss_dict))
            else:
                f.write('Step - {}, \n{}\n'.format(self.global_step, loss_dict))

    def configure_optimizers(self):
        """Return the optimizer for the encoder (plus the decoder if ``opts.train_decoder``).

        When the decoder is not trained, its parameters are frozen.
        Adam is used when ``opts.optim_name == 'adam'``, Ranger otherwise.
        """
        params = list(self.net.encoder.parameters())
        if self.opts.train_decoder:
            params += list(self.net.decoder.parameters())
        else:
            self.requires_grad(self.net.decoder, False)
        if self.opts.optim_name == 'adam':
            optimizer = torch.optim.Adam(params, lr=self.opts.learning_rate)
        else:
            optimizer = Ranger(params, lr=self.opts.learning_rate)
        return optimizer

    def configure_datasets(self):
        """Build the train and test ``ImagesDataset`` for ``opts.dataset_type``.

        Raises:
            ValueError: if ``opts.dataset_type`` is not a key of
                ``data_configs.DATASETS``. The original constructed the
                exception without raising it, so the failure surfaced later
                as a bare KeyError.
        """
        if self.opts.dataset_type not in data_configs.DATASETS:
            raise ValueError('{} is not a valid dataset_type'.format(self.opts.dataset_type))
        print('Loading dataset for {}'.format(self.opts.dataset_type))
        dataset_args = data_configs.DATASETS[self.opts.dataset_type]
        transforms_dict = dataset_args['transforms'](self.opts).get_transforms()
        train_dataset = ImagesDataset(source_root=dataset_args['train_source_root'],
                                      target_root=dataset_args['train_target_root'],
                                      source_transform=transforms_dict['transform_source'],
                                      target_transform=transforms_dict['transform_gt_train'],
                                      opts=self.opts)
        test_dataset = ImagesDataset(source_root=dataset_args['test_source_root'],
                                     target_root=dataset_args['test_target_root'],
                                     source_transform=transforms_dict['transform_source'],
                                     target_transform=transforms_dict['transform_test'],
                                     opts=self.opts)
        print("Number of training samples: {}".format(len(train_dataset)))
        print("Number of test samples: {}".format(len(test_dataset)))
        return train_dataset, test_dataset

    def calc_loss(self, x, y, y_hat, latent):
        """Compute the weighted encoder loss for one batch.

        Args:
            x: Input images.
            y: Target images.
            y_hat: Reconstructed images.
            latent: Latent codes returned by the network, indexed as
                ``latent[:, i, :]`` (rank 3; presumably
                ``(batch, n_latent, 512)``; confirm).

        Returns:
            ``(loss, loss_dict, id_logs)``:
            - ``loss``: the total weighted loss (a tensor once any term is
              enabled; the float ``0.0`` otherwise).
            - ``loss_dict``: each term as a float, plus ``'loss'``.
            - ``id_logs``: the identity-loss logs, or ``None`` when
              ``id_lambda <= 0``.
        """
        loss_dict = {}
        loss = 0.0
        id_logs = None
        if self.is_training_discriminator():
            # Non-saturating generator-side adversarial loss, averaged over
            # the discriminated latent positions.
            loss_disc = 0.
            dims_to_discriminate = self.get_dims_to_discriminate() if self.is_progressive_training() else \
                list(range(self.net.decoder.n_latent))

            for i in dims_to_discriminate:
                w = latent[:, i, :]
                fake_pred = self.discriminator(w)
                loss_disc += F.softplus(-fake_pred).mean()
            loss_disc /= len(dims_to_discriminate)
            loss_dict['encoder_discriminator_loss'] = float(loss_disc)
            loss += self.opts.w_discriminator_lambda * loss_disc

        # Delta regularization: penalize the delta_norm-norm of each active
        # stage's code minus the first code. Skipped once the progressive
        # stage value is 18 (presumably the final, all-codes stage; confirm).
        if self.opts.progressive_steps and self.net.encoder.progressive_stage.value != 18:
            total_delta_loss = 0
            deltas_latent_dims = self.net.encoder.get_deltas_starting_dimensions()

            first_w = latent[:, 0, :]
            for i in range(1, self.net.encoder.progressive_stage.value + 1):
                curr_dim = deltas_latent_dims[i]
                delta = latent[:, curr_dim, :] - first_w
                delta_loss = torch.norm(delta, self.opts.delta_norm, dim=1).mean()
                loss_dict[f"delta{i}_loss"] = float(delta_loss)
                total_delta_loss += delta_loss
            loss_dict['total_delta_loss'] = float(total_delta_loss)
            loss += self.opts.delta_norm_lambda * total_delta_loss

        if self.opts.id_lambda > 0:
            loss_id, sim_improvement, id_logs = self.id_loss(y_hat, y, x)
            loss_dict['loss_id'] = float(loss_id)
            loss_dict['id_improve'] = float(sim_improvement)
            loss += loss_id * self.opts.id_lambda
        if self.opts.l2_lambda > 0:
            loss_l2 = F.mse_loss(y_hat, y)
            loss_dict['loss_l2'] = float(loss_l2)
            loss += loss_l2 * self.opts.l2_lambda
        if self.opts.lpips_lambda > 0:
            loss_lpips = self.lpips_loss(y_hat, y)
            loss_dict['loss_lpips'] = float(loss_lpips)
            loss += loss_lpips * self.opts.lpips_lambda
        loss_dict['loss'] = float(loss)
        return loss, loss_dict, id_logs

    def forward(self, batch):
        """Move an ``(x, y)`` batch to the device and run the network.

        Returns:
            ``(x, y, y_hat, latent)``.

        For ``'cars_encode'`` the output is cropped to rows 32:224 along
        dim 2 (presumably to drop padding so y_hat matches y; confirm).
        """
        x, y = batch
        x, y = x.to(self.device).float(), y.to(self.device).float()
        y_hat, latent = self.net.forward(x, return_latents=True)
        if self.opts.dataset_type == "cars_encode":
            y_hat = y_hat[:, :, 32:224, :]
        return x, y, y_hat, latent

    def log_metrics(self, metrics_dict, prefix):
        """Write each metric to TensorBoard as ``<prefix>/<key>`` at the current step."""
        for key, value in metrics_dict.items():
            self.logger.add_scalar('{}/{}'.format(prefix, key), value, self.global_step)

    def print_metrics(self, metrics_dict, prefix):
        """Print each metric on its own line under a step header."""
        print('Metrics for {}, step {}'.format(prefix, self.global_step))
        for key, value in metrics_dict.items():
            print('\t{} = '.format(key), value)

    def parse_and_log_images(self, id_logs, x, y, y_hat, title, subscript=None, display_count=2):
        """Render input/target/output images (plus any ID-loss logs) for a batch.

        At most ``display_count`` samples are shown, clamped to the batch
        size so small batches no longer raise IndexError.
        """
        display_count = min(display_count, len(x))
        im_data = []
        for i in range(display_count):
            cur_im_data = {
                'input_face': common.log_input_image(x[i], self.opts),
                'target_face': common.tensor2im(y[i]),
                'output_face': common.tensor2im(y_hat[i]),
            }
            if id_logs is not None:
                for key in id_logs[i]:
                    cur_im_data[key] = id_logs[i][key]
            im_data.append(cur_im_data)
        self.log_images(title, im_data=im_data, subscript=subscript)

    def log_images(self, name, im_data, subscript=None, log_latest=False):
        """Save a figure of ``im_data`` as a JPEG under ``<log_dir>/<name>/``.

        The file name includes the current step, or step 0 when
        ``log_latest`` is set so the same file is overwritten each time.
        """
        fig = common.vis_faces(im_data)
        step = self.global_step
        if log_latest:
            step = 0
        if subscript:
            path = os.path.join(self.logger.log_dir, name, '{}_{:04d}.jpg'.format(subscript, step))
        else:
            path = os.path.join(self.logger.log_dir, name, '{:04d}.jpg'.format(step))
        os.makedirs(os.path.dirname(path), exist_ok=True)
        fig.savefig(path)
        plt.close(fig)

    def __get_save_dict(self):
        """Assemble the checkpoint payload.

        Always includes the network weights and opts. Includes
        ``latent_avg`` when ``start_from_latent_avg`` is set. When
        ``save_training_data`` is set, also includes the state needed by
        ``load_from_train_checkpoint``.
        """
        save_dict = {
            'state_dict': self.net.state_dict(),
            'opts': vars(self.opts)
        }
        if self.opts.start_from_latent_avg:
            save_dict['latent_avg'] = self.net.latent_avg

        if self.opts.save_training_data:
            save_dict['global_step'] = self.global_step
            save_dict['optimizer'] = self.optimizer.state_dict()
            save_dict['best_val_loss'] = self.best_val_loss
            if self.opts.w_discriminator_lambda > 0:
                save_dict['discriminator_state_dict'] = self.discriminator.state_dict()
                save_dict['discriminator_optimizer_state_dict'] = self.discriminator_optimizer.state_dict()
        return save_dict

    def get_dims_to_discriminate(self):
        """Return the latent indices that start each stage active at the current progressive stage."""
        deltas_starting_dimensions = self.net.encoder.get_deltas_starting_dimensions()
        return deltas_starting_dimensions[:self.net.encoder.progressive_stage.value + 1]

    def is_progressive_training(self):
        """True when ``opts.progressive_steps`` is configured (not None)."""
        return self.opts.progressive_steps is not None

    def is_training_discriminator(self):
        """True when the latent discriminator is enabled (positive weight)."""
        return self.opts.w_discriminator_lambda > 0

    @staticmethod
    def discriminator_loss(real_pred, fake_pred, loss_dict):
        """Logistic GAN discriminator loss: softplus(-D(real)) + softplus(D(fake)).

        Records both terms as floats in ``loss_dict`` and returns their sum.
        """
        real_loss = F.softplus(-real_pred).mean()
        fake_loss = F.softplus(fake_pred).mean()

        loss_dict['d_real_loss'] = float(real_loss)
        loss_dict['d_fake_loss'] = float(fake_loss)

        return real_loss + fake_loss

    @staticmethod
    def discriminator_r1_loss(real_pred, real_w):
        """R1 penalty: batch-mean squared gradient norm of D's output w.r.t. ``real_w``.

        ``create_graph=True`` keeps the penalty differentiable so it can be
        backpropagated.
        """
        grad_real, = autograd.grad(
            outputs=real_pred.sum(), inputs=real_w, create_graph=True
        )
        grad_penalty = grad_real.pow(2).reshape(grad_real.shape[0], -1).sum(1).mean()

        return grad_penalty

    @staticmethod
    def requires_grad(model, flag=True):
        """Set ``requires_grad`` to ``flag`` on every parameter of ``model``."""
        for p in model.parameters():
            p.requires_grad = flag

    def train_discriminator(self, batch):
        """Run one discriminator update (plus lazy R1 regularization) and return its metrics.

        The discriminator's parameters are trainable only for the duration of
        this call and are frozen again before returning.
        """
        loss_dict = {}
        x, _ = batch
        x = x.to(self.device).float()
        self.requires_grad(self.discriminator, True)

        # Latent sampling / encoding is not differentiated here.
        with torch.no_grad():
            real_w, fake_w = self.sample_real_and_fake_latents(x)
        real_pred = self.discriminator(real_w)
        fake_pred = self.discriminator(fake_w)
        loss = self.discriminator_loss(real_pred, fake_pred, loss_dict)
        loss_dict['discriminator_loss'] = float(loss)

        self.discriminator_optimizer.zero_grad()
        loss.backward()
        self.discriminator_optimizer.step()

        # Lazy R1: applied every d_reg_every steps and scaled by d_reg_every
        # to compensate for the reduced frequency.
        d_regularize = self.global_step % self.opts.d_reg_every == 0
        if d_regularize:
            real_w = real_w.detach()
            real_w.requires_grad = True
            real_pred = self.discriminator(real_w)
            r1_loss = self.discriminator_r1_loss(real_pred, real_w)

            self.discriminator.zero_grad()
            # The `0 * real_pred[0]` term is numerically a no-op; presumably
            # kept from the StyleGAN2 reference to tie real_pred into the graph.
            r1_final_loss = self.opts.r1 / 2 * r1_loss * self.opts.d_reg_every + 0 * real_pred[0]
            r1_final_loss.backward()
            self.discriminator_optimizer.step()
            loss_dict['discriminator_r1_loss'] = float(r1_final_loss)

        # Freeze the discriminator again for the encoder update.
        self.requires_grad(self.discriminator, False)

        return loss_dict

    def validate_discriminator(self, test_batch):
        """Compute the discriminator loss on a test batch without gradients."""
        with torch.no_grad():
            loss_dict = {}
            x, _ = test_batch
            x = x.to(self.device).float()
            real_w, fake_w = self.sample_real_and_fake_latents(x)
            real_pred = self.discriminator(real_w)
            fake_pred = self.discriminator(fake_w)
            loss = self.discriminator_loss(real_pred, fake_pred, loss_dict)
            loss_dict['discriminator_loss'] = float(loss)
            return loss_dict

    def sample_real_and_fake_latents(self, x):
        """Return ``(real_w, fake_w)`` code batches for the discriminator.

        ``real_w`` comes from Gaussian ``z`` passed through
        ``self.net.decoder.get_latent``. One sample is drawn per image in
        ``x``, so real and fake batches match in size under ``test_batch_size``
        as well; during training this equals ``opts.batch_size`` because of
        ``drop_last``.

        ``fake_w`` is the encoder output for ``x``. During progressive
        training only the active dims are kept. After the optional pool
        queries, only the first code per sample is returned.
        """
        sample_z = torch.randn(x.size(0), 512, device=self.device)
        real_w = self.net.decoder.get_latent(sample_z)
        fake_w = self.net.encoder(x)
        if self.is_progressive_training():
            dims_to_discriminate = self.get_dims_to_discriminate()
            fake_w = fake_w[:, dims_to_discriminate, :]
        if self.opts.use_w_pool:
            real_w = self.real_w_pool.query(real_w)
            fake_w = self.fake_w_pool.query(fake_w)
        if fake_w.ndim == 3:
            fake_w = fake_w[:, 0, :]
        return real_w, fake_w