from lib.net import NormalNet
from lib.common.train_util import *
import logging
import torch
import numpy as np
from torch import nn
from skimage.transform import resize
import pytorch_lightning as pl

torch.backends.cudnn.benchmark = True

logging.getLogger("lightning").setLevel(logging.ERROR)


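# Lightning wrapper around NormalNet: trains the front (netF) and back (netB)
# normal-map predictors with separate Adam optimizers, stepped manually.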
class Normal(pl.LightningModule):
    def __init__(self, cfg):
        super(Normal, self).__init__()
        self.cfg = cfg
        self.batch_size = self.cfg.batch_size
        self.lr_N = self.cfg.lr_N

        # both normal networks are stepped by hand in training_step,
        # so Lightning's automatic optimization must be disabled
        self.automatic_optimization = False

        self.schedulers = []

        self.netG = NormalNet(self.cfg, error_term=nn.SmoothL1Loss())

        self.in_nml = [item[0] for item in cfg.net.in_nml]

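    # hide the "v_num" entry from the tqdm progress bar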
    def get_progress_bar_dict(self):
        tqdm_dict = super().get_progress_bar_dict()
        if "v_num" in tqdm_dict:
            del tqdm_dict["v_num"]
        return tqdm_dict

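    # one Adam optimizer and one MultiStepLR scheduler per normal network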
    def configure_optimizers(self):

        weight_decay = self.cfg.weight_decay

        optim_params_N_F = [
            {"params": self.netG.netF.parameters(), "lr": self.lr_N}]
        optim_params_N_B = [
            {"params": self.netG.netB.parameters(), "lr": self.lr_N}]

        optimizer_N_F = torch.optim.Adam(
            optim_params_N_F, lr=self.lr_N, weight_decay=weight_decay
        )
        optimizer_N_B = torch.optim.Adam(
            optim_params_N_B, lr=self.lr_N, weight_decay=weight_decay
        )

        scheduler_N_F = torch.optim.lr_scheduler.MultiStepLR(
            optimizer_N_F, milestones=self.cfg.schedule, gamma=self.cfg.gamma
        )
        scheduler_N_B = torch.optim.lr_scheduler.MultiStepLR(
            optimizer_N_B, milestones=self.cfg.schedule, gamma=self.cfg.gamma
        )

        self.schedulers = [scheduler_N_F, scheduler_N_B]
        optims = [optimizer_N_F, optimizer_N_B]

        return optims, self.schedulers

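    # visualization helper: maps each [-1, 1] CHW tensor to a [0, 1] HWC image,
    # resizes it to a square of the input height, and tiles the results in a row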
    def render_func(self, render_tensor):

        height = render_tensor["image"].shape[2]
        result_list = []

        for name in render_tensor.keys():
            result_list.append(
                resize(
                    ((render_tensor[name].cpu().numpy()[0] + 1.0) / 2.0).transpose(
                        1, 2, 0
                    ),
                    (height, height),
                    anti_aliasing=True,
                )
            )
        result_array = np.concatenate(result_list, axis=1)

        return result_array

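    # manual optimization: both normal networks are stepped explicitly,
    # so this hook runs once per batch (no optimizer_idx argument)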
    def training_step(self, batch, batch_idx):

        export_cfg(self.logger, self.cfg)

        # retrieve the inputs
        in_tensor = {}
        for name in self.in_nml:
            in_tensor[name] = batch[name]

        FB_tensor = {"normal_F": batch["normal_F"],
                     "normal_B": batch["normal_B"]}

        self.netG.train()

        preds_F, preds_B = self.netG(in_tensor)
        error_NF, error_NB = self.netG.get_norm_error(
            preds_F, preds_B, FB_tensor)

        (opt_nf, opt_nb) = self.optimizers()

        opt_nf.zero_grad()
        opt_nb.zero_grad()

        # back-propagate each loss through its own network
        self.manual_backward(error_NF)
        self.manual_backward(error_NB)

        opt_nf.step()
        opt_nb.step()

        # periodically render the current predictions to the logger
        if batch_idx > 0 and batch_idx % int(self.cfg.freq_show_train) == 0:

            self.netG.eval()
            with torch.no_grad():
                nmlF, nmlB = self.netG(in_tensor)
                in_tensor.update({"nmlF": nmlF, "nmlB": nmlB})
                result_array = self.render_func(in_tensor)

            self.logger.experiment.add_image(
                tag=f"Normal-train/{self.global_step}",
                img_tensor=result_array.transpose(2, 0, 1),
                global_step=self.global_step,
            )

        metrics_log = {
            "train_loss-NF": error_NF.item(),
            "train_loss-NB": error_NB.item(),
        }

        tf_log = tf_log_convert(metrics_log)
        bar_log = bar_log_convert(metrics_log)

        return {
            "loss": error_NF + error_NB,
            "loss-NF": error_NF,
            "loss-NB": error_NB,
            "log": tf_log,
            "progress_bar": bar_log,
        }

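    # average the per-batch losses and log the current learning rates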
    def training_epoch_end(self, outputs):

        # with multiple optimizers, outputs may arrive nested per optimizer
        if [] in outputs:
            outputs = outputs[0]

        metrics_log = {
            "train_avgloss": batch_mean(outputs, "loss"),
            "train_avgloss-NF": batch_mean(outputs, "loss-NF"),
            "train_avgloss-NB": batch_mean(outputs, "loss-NB"),
        }

        tf_log = tf_log_convert(metrics_log)

        tf_log["lr-NF"] = self.schedulers[0].get_last_lr()[0]
        tf_log["lr-NB"] = self.schedulers[1].get_last_lr()[0]

        return {"log": tf_log}

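    # validation mirrors the training forward pass but never steps the optimizers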
    def validation_step(self, batch, batch_idx):

        # retrieve the inputs
        in_tensor = {}
        for name in self.in_nml:
            in_tensor[name] = batch[name]

        FB_tensor = {"normal_F": batch["normal_F"],
                     "normal_B": batch["normal_B"]}

        # eval mode: keep batch-norm/dropout deterministic during validation
        self.netG.eval()

        preds_F, preds_B = self.netG(in_tensor)
        error_NF, error_NB = self.netG.get_norm_error(
            preds_F, preds_B, FB_tensor)

        # 0 % freq == 0, so this also covers the first batch
        if batch_idx % int(self.cfg.freq_show_train) == 0:
            with torch.no_grad():
                nmlF, nmlB = self.netG(in_tensor)
                in_tensor.update({"nmlF": nmlF, "nmlB": nmlB})
                result_array = self.render_func(in_tensor)

            self.logger.experiment.add_image(
                tag=f"Normal-val/{self.global_step}",
                img_tensor=result_array.transpose(2, 0, 1),
                global_step=self.global_step,
            )

        return {
            "val_loss": error_NF + error_NB,
            "val_loss-NF": error_NF,
            "val_loss-NB": error_NB,
        }

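    # average the per-batch validation losses for the epoch-level log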
    def validation_epoch_end(self, outputs):

        metrics_log = {
            "val_avgloss": batch_mean(outputs, "val_loss"),
            "val_avgloss-NF": batch_mean(outputs, "val_loss-NF"),
            "val_avgloss-NB": batch_mean(outputs, "val_loss-NB"),
        }

        tf_log = tf_log_convert(metrics_log)

        return {"log": tf_log}