| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
|
|
|
|
| import os |
| import math |
| from tqdm.auto import tqdm |
| import torch |
| import torch.nn as nn |
| import torchvision |
| import numpy as np |
| from torchvision.utils import make_grid |
| from einops import rearrange, repeat |
| from accelerate.logging import get_logger |
| from taming.modules.losses.vqperceptual import hinge_d_loss |
|
|
| from .base_trainer import Trainer |
| from lam.utils.profiler import DummyProfiler |
| from lam.runners import REGISTRY_RUNNERS |
| from lam.utils.hf_hub import wrap_model_hub |
| from safetensors.torch import load_file |
| from pytorch3d.ops.knn import knn_points |
| import torch.nn.functional as F |
|
|
| logger = get_logger(__name__) |
|
|
| |
|
|
|
|
| from omegaconf import OmegaConf |
| @REGISTRY_RUNNERS.register('train.lam') |
| class LAMTrainer(Trainer): |
|
|
| EXP_TYPE: str = 'lam' |
|
|
def __init__(self):
    """Build every component needed for training from ``self.cfg``.

    ``self.cfg`` and ``self.has_disc`` are read right after the base
    constructor runs, so they are presumably set up by ``Trainer.__init__``
    (not visible here -- confirm in ``base_trainer``).
    """
    super().__init__()

    # Generator, optional discriminator, and one AdamW optimizer for each.
    self.model = self._build_model(self.cfg)
    if self.has_disc:
        self.model_disc = self._build_model_disc(self.cfg)
    self.optimizer = self._build_optimizer(self.model, self.cfg)
    if self.has_disc:
        self.optimizer_disc = self._build_optimizer(self.model_disc, self.cfg)

    # The loaders must exist before the schedulers: `_build_scheduler`
    # derives the total number of iterations from `len(self.train_loader)`.
    self.train_loader, self.val_loader = self._build_dataloader(self.cfg)
    self.scheduler = self._build_scheduler(self.optimizer, self.cfg)
    if self.has_disc:
        self.scheduler_disc = self._build_scheduler(self.optimizer_disc, self.cfg)
    self.pixel_loss_fn, self.perceptual_loss_fn, self.tv_loss_fn = self._build_loss_fn(self.cfg)
    # Forwarded verbatim as `only_sym_conf=` to the pixel / perceptual losses.
    # NOTE(review): the meaning of the value 2 is defined by those loss
    # implementations, which are not visible here.
    self.only_sym_conf = 2
    print("==="*16*3, "\n"+"only_sym_conf:", self.only_sym_conf, "\n"+"==="*16*3)
| |
| |
def _build_model(self, cfg):
    """Create ``ModelLAM`` and optionally warm-start it from ``cfg.train.resume``.

    When a resume path is configured, checkpoint tensors are copied into the
    freshly built model one by one. Entries unknown to the model, or whose
    shape differs from the model's tensor, are reported and skipped.

    Returns:
        The constructed (and possibly warm-started) model.
    """
    assert cfg.experiment.type == 'lrm', \
        f"Config type {cfg.experiment.type} does not match with runner {self.__class__.__name__}"
    from lam.models import ModelLAM
    model = ModelLAM(**cfg.model)

    # Nothing to load: hand back the randomly initialised model.
    if len(self.cfg.train.resume) == 0:
        return model

    resume = self.cfg.train.resume
    print("==="*16*3)
    self.accelerator.print("loading pretrained weight from:", resume)
    if resume.endswith('safetensors'):
        ckpt = load_file(resume, device='cpu')
    else:
        ckpt = torch.load(resume, map_location='cpu')

    # The tensors in `state_dict()` are written in place via `copy_`.
    target = model.state_dict()
    for name, tensor in ckpt.items():
        if name not in target:
            self.accelerator.print(f"WARN] unexpected param {name}: {tensor.shape}")
            continue
        if target[name].shape != tensor.shape:
            self.accelerator.print(f"WARN] mismatching shape for param {name}: ckpt {tensor.shape} != model {target[name].shape}, ignored.")
            continue
        target[name].copy_(tensor)
    self.accelerator.print("Finish loading ckpt:", resume, "\n"+"==="*16*3)
    return model
|
|
def _build_model_disc(self, cfg):
    """Create the discriminator selected by ``cfg.model.disc.type``.

    Supported types are ``"pix2pix"``, ``"vqgan"`` and ``"stylegan"``.

    Raises:
        NotImplementedError: for any other discriminator type.
    """
    disc_cfg = cfg.model.disc

    if disc_cfg.type == "pix2pix":
        from lam.models.discriminator import NLayerDiscriminator, weights_init
        patch_disc = NLayerDiscriminator(
            input_nc=disc_cfg.in_channels,
            n_layers=disc_cfg.num_layers,
            use_actnorm=disc_cfg.use_actnorm,
        )
        return patch_disc.apply(weights_init)

    if disc_cfg.type == "vqgan":
        from lam.models.discriminator import Discriminator
        return Discriminator(
            in_channels=disc_cfg.in_channels,
            cond_channels=0,
            hidden_channels=512,
            depth=disc_cfg.depth,
        )

    if disc_cfg.type == "stylegan":
        from lam.models.gan.stylegan_discriminator import SingleDiscriminatorV2, SingleDiscriminator
        from lam.models.gan.stylegan_discriminator_torch import Discriminator

        style_disc = Discriminator(512, channel_multiplier=2)
        # `disc_preprocess` resizes its input to this resolution when the
        # attribute is present on the discriminator.
        style_disc.input_size = disc_cfg.img_res
        return style_disc

    raise NotImplementedError
|
|
def _build_optimizer(self, model: nn.Module, cfg):
    """Create an AdamW optimizer with two parameter groups.

    Group 0 receives ``cfg.train.optim.weight_decay``; group 1 (all
    ``nn.LayerNorm`` parameters plus every module's non-None ``bias``) is
    exempt from weight decay. Parameters that do not require gradients are
    left out of both groups.
    """
    # Collect the parameters that must not be decayed.
    exempt = []
    for _name, module in model.named_modules():
        if isinstance(module, nn.LayerNorm):
            exempt.extend(module.parameters())
        elif hasattr(module, 'bias') and module.bias is not None:
            exempt.append(module.bias)

    # Everything else is decayed; identity (not equality) decides membership.
    exempt_ids = {id(p) for p in exempt}
    decay_params = [
        p for p in model.parameters()
        if id(p) not in exempt_ids and p.requires_grad
    ]
    no_decay_params = [p for p in exempt if p.requires_grad]

    logger.info("======== Weight Decay Parameters ========")
    logger.info(f"Total: {len(decay_params)}")
    logger.info("======== No Weight Decay Parameters ========")
    logger.info(f"Total: {len(no_decay_params)}")

    optim_cfg = cfg.train.optim
    return torch.optim.AdamW(
        [
            {'params': decay_params, 'weight_decay': optim_cfg.weight_decay},
            {'params': no_decay_params, 'weight_decay': 0.0},
        ],
        lr=optim_cfg.lr,
        betas=(optim_cfg.beta1, optim_cfg.beta2),
    )
|
|
def _build_scheduler(self, optimizer, cfg):
    """Create the learning-rate scheduler for ``optimizer``.

    The schedule length is ``cfg.train.epochs`` times the number of global
    (post gradient-accumulation) batches each process sees per epoch.

    Raises:
        NotImplementedError: if ``cfg.train.scheduler.type`` is not ``'cosine'``.
    """
    local_batches_per_epoch = math.floor(len(self.train_loader) / self.accelerator.num_processes)
    global_batches_per_epoch = math.ceil(local_batches_per_epoch / self.cfg.train.accum_steps)
    total_global_batches = cfg.train.epochs * global_batches_per_epoch
    effective_warmup_iters = cfg.train.scheduler.warmup_real_iters
    logger.debug(f"======== Scheduler effective max iters: {total_global_batches} ========")
    logger.debug(f"======== Scheduler effective warmup iters: {effective_warmup_iters} ========")

    if cfg.train.scheduler.type != 'cosine':
        raise NotImplementedError(f"Scheduler type {cfg.train.scheduler.type} not implemented")

    from lam.utils.scheduler import CosineWarmupScheduler
    return CosineWarmupScheduler(
        optimizer=optimizer,
        warmup_iters=effective_warmup_iters,
        max_iters=total_global_batches,
    )
|
|
def _build_dataloader(self, cfg):
    """Create the training and validation ``DataLoader`` objects.

    Both splits use ``MixerDataset`` with identical settings except for
    ``split`` and ``is_val``. Normal maps are requested from the dataset only
    when ``normal_weight`` or ``surfel_normal_weight`` in ``cfg.train.loss``
    is positive.

    Returns:
        ``(train_loader, val_loader)``.
    """
    from lam.datasets import MixerDataset
    gaga_track_type = cfg.dataset.get("gaga_track_type", "vfhq_gagtrack")
    sample_aug_views = cfg.dataset.get("sample_aug_views", 0)

    load_normal = cfg.train.loss.get("normal_weight", False) > 0. if hasattr(cfg.train.loss, "normal_weight") else False
    load_normal = load_normal or (cfg.train.loss.get("surfel_normal_weight", False) > 0. if hasattr(cfg.train.loss, "surfel_normal_weight") else False)
    print("==="*16*3, "\nload_normal:", load_normal)

    # Settings shared by both splits, resolved once so they cannot drift apart.
    shared_kwargs = dict(
        subsets=cfg.dataset.subsets,
        sample_side_views=cfg.dataset.sample_side_views,
        render_image_res_low=cfg.dataset.render_image.low,
        render_image_res_high=cfg.dataset.render_image.high,
        render_region_size=cfg.dataset.render_image.region,
        source_image_res=cfg.dataset.source_image_res,
        repeat_num=cfg.dataset.repeat_num if hasattr(cfg.dataset, "repeat_num") else 1,
        multiply=cfg.dataset.multiply if hasattr(cfg.dataset, "multiply") else 14,
        debug=cfg.dataset.debug if hasattr(cfg.dataset, "debug") else False,
        gaga_track_type=gaga_track_type,
        sample_aug_views=sample_aug_views,
        load_albedo=cfg.model.get("render_albedo", False) if hasattr(cfg.model, "render_albedo") else False,
        load_normal=load_normal,
    )
    train_dataset = MixerDataset(split="train", is_val=False, **shared_kwargs)
    val_dataset = MixerDataset(split="val", is_val=True, **shared_kwargs)

    num_train_workers = cfg.dataset.num_train_workers
    train_loader = torch.utils.data.DataLoader(
        train_dataset,
        batch_size=cfg.train.batch_size,
        shuffle=True,
        drop_last=True,
        num_workers=num_train_workers,
        pin_memory=cfg.dataset.pin_mem,
        # DataLoader raises ValueError for persistent_workers=True with
        # num_workers == 0 (e.g. single-process debugging runs).
        persistent_workers=num_train_workers > 0,
    )
    val_loader = torch.utils.data.DataLoader(
        val_dataset,
        batch_size=cfg.val.batch_size,
        shuffle=False,
        drop_last=False,
        num_workers=cfg.dataset.num_val_workers,
        pin_memory=cfg.dataset.pin_mem,
        persistent_workers=False,
    )

    return train_loader, val_loader
|
|
def _build_loss_fn(self, cfg):
    """Instantiate the pixel, perceptual (LPIPS) and total-variation losses.

    Returns:
        ``(pixel_loss_fn, perceptual_loss_fn, tv_loss_fn)``.
    """
    from lam.losses import PixelLoss, LPIPSLoss, TVLoss

    loss_cfg = cfg.train.loss
    pixel_loss_fn = PixelLoss(option=loss_cfg.get("pixel_loss_fn", "mse"))

    # The main process constructs the LPIPS loss before the other ranks do.
    with self.accelerator.main_process_first():
        perceptual_loss_fn = LPIPSLoss(device=self.device, prefech=True)

    uses_conf_map = cfg.model.get("use_conf_map", False)
    if uses_conf_map:
        assert loss_cfg.get("head_pl", False), "Set head_pl in train.loss to true to use faceperceptualloss when using conf_map."

    return pixel_loss_fn, perceptual_loss_fn, TVLoss()
|
|
def register_hooks(self):
    """No hooks are registered for this trainer; intentionally a no-op."""
    return None
|
|
def get_flame_params(self, data, is_source=False):
    """Pick the FLAME-related entries out of a batch dict.

    Args:
        data: mapping holding the batch.
        is_source: when true, look for the ``source_``-prefixed variants of
            the keys instead of the plain ones.

    Returns:
        A new dict with only the recognised keys, in the order they appear
        in ``data``.
    """
    base_keys = (
        'root_pose', 'body_pose', 'jaw_pose', 'leye_pose', 'reye_pose',
        'lhand_pose', 'rhand_pose', 'expr', 'trans', 'betas',
        'rotation', 'neck_pose', 'eyes_pose', 'translation', "teeth_bs",
    )
    prefix = 'source_' if is_source else ''
    wanted = {prefix + key for key in base_keys}
    return {key: value for key, value in data.items() if key in wanted}
| |
def cross_copy(self, data):
    """Expand a ``(B, 1, ...)`` tensor to ``(B, B, ...)``.

    Row ``i`` of the result starts with sample ``i`` itself, followed by all
    other samples of the batch in their original order.
    """
    batch_size = data.shape[0]
    assert data.shape[1] == 1

    samples = data[:, 0]  # drop the singleton axis: (B, ...)
    rows = []
    for own in range(batch_size):
        order = [own] + [other for other in range(batch_size) if other != own]
        rows.append(samples[order])
    return torch.stack(rows, dim=0)
| |
def prepare_cross_render_data(self, data):
    """Rewrite a single-view batch in place for cross rendering.

    Every render-side tensor and every FLAME entry except those whose key
    contains ``"betas"`` is passed through ``cross_copy``. The batch must
    contain exactly one render view per sample.

    Returns:
        The same ``data`` dict, modified in place.
    """
    # Unpacking doubles as a check that `render_image` is 5-dimensional.
    _, num_views, _, _, _ = data['render_image'].shape
    assert num_views == 1

    render_keys = (
        "c2ws", "intrs", "render_full_resolutions",
        "render_image", "render_mask", "render_bg_colors",
    )
    for key in render_keys:
        data[key] = self.cross_copy(data[key])

    # Target-side FLAME entries first, then the `source_` ones.
    for is_source in (False, True):
        selected = self.get_flame_params(data, is_source=is_source) if is_source \
            else self.get_flame_params(data)
        for key in selected:
            if "betas" in key:
                continue
            data[key] = self.cross_copy(data[key])

    return data
| |
def get_loss_weight(self, loss_weight):
    """Resolve a loss-weight spec to its value at ``self.global_step``.

    Args:
        loss_weight: either a plain number, returned unchanged, or a string
            ``"start_step:start_value:end_value:end_step"`` describing a
            linear ramp from ``start_value`` to ``end_value`` that is clamped
            outside ``[start_step, end_step]``.

    Returns:
        The weight to apply at the current global step.

    Raises:
        NotImplementedError: if ``loss_weight`` is neither of the above.
    """
    if isinstance(loss_weight, str) and ":" in loss_weight:
        start_step, start_value, end_value, end_step = map(float, loss_weight.split(":"))
        current_step = self.global_step
        if end_step == start_step:
            # Degenerate ramp: treat it as a step function rather than
            # dividing by zero.
            progress = 1.0 if current_step >= start_step else 0.0
        else:
            progress = (current_step - start_step) / (end_step - start_step)
            progress = max(min(1.0, progress), 0.0)
        return start_value + (end_value - start_value) * progress
    if isinstance(loss_weight, (float, int)):
        return loss_weight
    raise NotImplementedError(f"Unsupported loss weight spec: {loss_weight!r}")
| |
def forward_loss_local_step(self, data):
    """Run the generator on one batch and assemble the weighted training loss.

    Args:
        data: batch dict. Must contain ``real_num_view`` (set by the callers
            in ``train_epoch`` / ``evaluate``) in addition to the tensors
            read below. ``*bbox*`` entries are overwritten in place when
            ``cfg.model.use_sym_proj`` is enabled.

    Returns:
        ``(outputs, loss, loss_pixel, loss_perceptual, loss_offset_reg,
        loss_mask, extra_loss_dict)``. Individual terms are ``None`` when
        their weight is not positive; ``extra_loss_dict`` is always empty
        here. Callers unpack the fifth element under the name ``loss_tv``.
    """
    render_image = data['render_image']
    render_albedo = data.get('render_albedo', None)
    render_mask = data['render_mask']
    render_normal = data.get('render_normal', None)
    # Unpacking also asserts that `render_image` is 5-dimensional.
    B, N_v, C, H, W = render_image.shape
    flame_params = self.get_flame_params(data)
    source_flame_params = self.get_flame_params(data, is_source=True)

    outputs = self.model(
        image=data['source_rgbs'],
        source_c2ws=data['source_c2ws'],
        source_intrs=data['source_intrs'],
        render_c2ws=data['c2ws'],
        render_intrs=data['intrs'],
        render_bg_colors=data['render_bg_colors'],
        flame_params=flame_params,
        source_flame_params=source_flame_params,
        render_images=render_image,
        data = data
    )

    loss = 0.
    loss_pixel = None
    loss_perceptual = None
    loss_mask = None
    extra_loss_dict = {}

    # Only the first `real_num_view` views take part in the losses below;
    # the configured number of augmented views is subtracted first.
    num_aug_view = self.cfg.dataset.get("sample_aug_views", 0)
    real_num_view = data["real_num_view"] - num_aug_view

    conf_sigma_l1 = outputs.get("conf_sigma_l1", None)
    conf_sigma_percl = outputs.get("conf_sigma_percl", None)
    if self.cfg.model.use_sym_proj:
        # Every view is duplicated (r=2) along the view axis, so the view
        # count and all ground-truth tensors are doubled to match.
        real_num_view *= 2
        if self.cfg.model.use_conf_map:
            # Split the channel axis into (c, r=2) and fold r into the views.
            conf_sigma_l1 = rearrange(conf_sigma_l1, "b v (c r) h w -> b (v r) c h w", r=2)[:, :real_num_view]
            conf_sigma_percl = rearrange(conf_sigma_percl, "b v (c r) h w -> b (v r) c h w", r=2)[:, :real_num_view]
        render_image = repeat(data['render_image'], "b v c h w -> b (v r) c h w", r=2)
        render_albedo = repeat(render_albedo, "b v c h w -> b (v r) c h w", r=2) if render_albedo is not None else None
        render_mask = repeat(data['render_mask'], "b v c h w -> b (v r) c h w", r=2)
        if "render_normal" in data.keys():
            render_normal = repeat(data['render_normal'], "b v c h w -> b (v r) c h w", r=2)
        # Existing keys are reassigned (no insertion), so mutating `data`
        # while iterating over it is safe. This does modify the caller's dict.
        for k, v in data.items():
            if "bbox" in k:
                data[k] = repeat(v, "b v c -> b (v r) c", r=2)

    only_sym_conf = self.only_sym_conf

    if self.get_loss_weight(self.cfg.train.loss.get("masked_pixel_weight", 0)) > 0.:
        # Composite prediction and ground truth over a white (1.0) background
        # using the ground-truth mask.
        gt_rgb = render_image[:, :real_num_view] * render_mask[:, :real_num_view] + 1.0 * (1 - render_mask[:, :real_num_view])
        pred_rgb = outputs['comp_rgb'][:, :real_num_view] * render_mask[:, :real_num_view] + 1.0 * (1 - render_mask[:, :real_num_view])

        loss_pixel = self.pixel_loss_fn(pred_rgb, gt_rgb, conf_sigma_l1, only_sym_conf=only_sym_conf) * self.get_loss_weight(self.cfg.train.loss.masked_pixel_weight)
        loss += loss_pixel

        # NOTE(review): the masked perceptual term is scaled by
        # `masked_pixel_weight`, not by a perceptual weight -- confirm intent.
        loss_perceptual = self.perceptual_loss_fn(pred_rgb, gt_rgb, conf_sigma=conf_sigma_percl, only_sym_conf=only_sym_conf) * self.get_loss_weight(self.cfg.train.loss.masked_pixel_weight)
        loss += loss_perceptual

    if self.get_loss_weight(self.cfg.train.loss.pixel_weight) > 0.:
        # Keep the masked term (if any) so it can be folded into the
        # reported pixel loss.
        total_loss_pixel = loss_pixel
        # The full-image term is active when `rgb_weight` is absent or positive.
        if (hasattr(self.cfg.train.loss, 'rgb_weight') and self.get_loss_weight(self.cfg.train.loss.rgb_weight) > 0.) or not hasattr(self.cfg.train.loss, "rgb_weight"):
            loss_pixel = self.pixel_loss_fn(
                outputs['comp_rgb'][:, :real_num_view], render_image[:, :real_num_view], conf_sigma=conf_sigma_l1, only_sym_conf=only_sym_conf
            ) * self.get_loss_weight(self.cfg.train.loss.pixel_weight)
            loss += loss_pixel
            # NOTE(review): nesting reconstructed from a flattened source;
            # this only affects the reported value, `loss` already holds
            # both terms.
            if total_loss_pixel is not None:
                loss_pixel += total_loss_pixel

    if self.get_loss_weight(self.cfg.train.loss.perceptual_weight) > 0.:
        total_loss_perceptual = loss_perceptual
        if (hasattr(self.cfg.train.loss, 'rgb_weight') and self.get_loss_weight(self.cfg.train.loss.rgb_weight) > 0.) or not hasattr(self.cfg.train.loss, "rgb_weight"):
            loss_perceptual = self.perceptual_loss_fn(
                outputs['comp_rgb'][:, :real_num_view], render_image[:, :real_num_view], conf_sigma=conf_sigma_percl, only_sym_conf=only_sym_conf
            ) * self.get_loss_weight(self.cfg.train.loss.perceptual_weight)
            loss += loss_perceptual
            if total_loss_perceptual is not None:
                loss_perceptual += total_loss_perceptual

    if self.get_loss_weight(self.cfg.train.loss.mask_weight) > 0. and 'comp_mask' in outputs.keys():
        loss_mask = self.pixel_loss_fn(outputs['comp_mask'][:, :real_num_view], render_mask[:, :real_num_view], conf_sigma=conf_sigma_l1, only_sym_conf=only_sym_conf
        ) * self.get_loss_weight(self.cfg.train.loss.mask_weight)
        loss += loss_mask

    if hasattr(self.cfg.train.loss, 'offset_reg_weight') and self.get_loss_weight(self.cfg.train.loss.offset_reg_weight) > 0.:
        # Mean squared magnitude of the `offset` attribute of the first
        # element of each sample's `3dgs` output, averaged over the batch.
        loss_offset_reg = 0
        for b_idx in range(len(outputs['3dgs'])):
            loss_offset_reg += torch.nn.functional.mse_loss(outputs['3dgs'][b_idx][0].offset.float(), torch.zeros_like(outputs['3dgs'][b_idx][0].offset.float()))
        loss_offset_reg = loss_offset_reg / len(outputs['3dgs'])
        loss += loss_offset_reg * self.get_loss_weight(self.cfg.train.loss.offset_reg_weight)
    else:
        loss_offset_reg = None

    return outputs, loss, loss_pixel, loss_perceptual, loss_offset_reg, loss_mask, extra_loss_dict
|
|
def adopt_weight(self, weight, global_step, threshold=0, value=0.):
    """Return ``value`` until ``global_step`` reaches ``threshold``, then ``weight``."""
    return value if global_step < threshold else weight
|
|
def calculate_adaptive_weight(self, nll_loss, g_loss, last_layer, discriminator_weight=1):
    """Balance the adversarial loss against the reconstruction loss.

    The weight is the ratio of the gradient norms of ``nll_loss`` and
    ``g_loss`` with respect to ``last_layer``, clamped to ``[0, 1e4]``,
    detached, and scaled by ``discriminator_weight``.
    """
    # retain_graph=True: the caller still back-propagates through both losses.
    (recon_grad,) = torch.autograd.grad(nll_loss, last_layer, retain_graph=True)
    (adv_grad,) = torch.autograd.grad(g_loss, last_layer, retain_graph=True)

    ratio = torch.norm(recon_grad) / (torch.norm(adv_grad) + 1e-4)
    ratio = ratio.clamp(0.0, 1e4).detach()
    return ratio * discriminator_weight
| |
def disc_preprocess(self, img):
    """Prepare images for the discriminator.

    Merges the first two axes, maps values via ``2 * x - 1``, resizes to
    the discriminator's ``input_size`` when it defines one, and casts to
    float32.
    """
    flat = 2 * torch.flatten(img, 0, 1) - 1

    disc = self.accelerator.unwrap_model(self.model_disc)
    if hasattr(disc, "input_size"):
        side = self.accelerator.unwrap_model(self.model_disc).input_size
        flat = nn.functional.interpolate(flat, (side, side))

    return flat.float()
| |
def forward_to_get_loss_with_gen_loss(self, data):
    """Run one generator optimisation step including the adversarial term.

    Computes the reconstruction losses, adds the generator loss
    ``-mean(D(fake))`` scaled by an adaptive weight, back-propagates, clips
    gradients and steps ``self.optimizer``.

    Returns:
        ``(outs, loss, loss_pixel, loss_perceptual, loss_tv, loss_mask,
        loss_gen, extra_loss_dict)``. ``loss_tv`` is the fifth value returned
        by ``forward_loss_local_step`` (the offset regulariser).
    """
    outs, loss, loss_pixel, loss_perceptual, loss_tv, loss_mask, extra_loss_dict = self.forward_loss_local_step(data)

    # The discriminator forward pass runs under a float32 autocast context.
    with torch.autocast(device_type=outs["comp_rgb"].device.type, dtype=torch.float32):
        logits_fake = self.model_disc(self.disc_preprocess(outs["comp_rgb"]))

    # Generator objective: raise the discriminator's score on rendered images.
    loss_gen = -torch.mean(logits_fake)

    try:
        if loss < 1e-5:
            # Reconstruction loss is effectively zero: use the configured
            # weight as-is instead of a gradient-norm ratio.
            d_weight = self.cfg.model.disc.disc_weight
        else:
            # Prefer the pixel loss as the reference; fall back to the total.
            nll_loss = loss_pixel
            if nll_loss is None:
                nll_loss = loss
            d_weight = self.calculate_adaptive_weight(nll_loss, loss_gen,
                                                      last_layer=self.accelerator.unwrap_model(self.model).get_last_layer(),
                                                      discriminator_weight=self.cfg.model.disc.disc_weight)
    except RuntimeError:
        # Best effort: if the gradient computation fails, drop the
        # adversarial term for this step instead of aborting training.
        print("*************Error when calculate_adaptive_weight************")
        d_weight = torch.tensor(0.0)

    # Factor is 0 until `disc_iter_start` global steps have passed, then 1.
    disc_factor = self.adopt_weight(1.0, self.global_step, threshold=self.cfg.model.disc.disc_iter_start)

    loss += disc_factor * d_weight * loss_gen

    self.accelerator.backward(loss)
    # Clip only on steps where gradients are synchronised.
    if self.accelerator.sync_gradients and self.cfg.train.optim.clip_grad_norm > 0.:
        self.accelerator.clip_grad_norm_(self.model.parameters(), self.cfg.train.optim.clip_grad_norm)

    self.optimizer.step()
    self.optimizer.zero_grad()

    return outs, loss, loss_pixel, loss_perceptual, loss_tv, loss_mask, loss_gen, extra_loss_dict
| |
|
|
def forward_to_get_loss(self, data):
    """Run one generator optimisation step without an adversarial term.

    Returns the tuple produced by ``forward_loss_local_step`` unchanged,
    after back-propagating and stepping ``self.optimizer``.
    """
    (outs, loss, loss_pixel, loss_perceptual,
     loss_tv, loss_mask, extra_loss_dict) = self.forward_loss_local_step(data)

    self.accelerator.backward(loss)

    # Gradient clipping applies only on synchronisation steps.
    clip_now = self.accelerator.sync_gradients and self.cfg.train.optim.clip_grad_norm > 0.
    if clip_now:
        self.accelerator.clip_grad_norm_(
            self.model.parameters(), self.cfg.train.optim.clip_grad_norm
        )

    self.optimizer.step()
    self.optimizer.zero_grad()

    return outs, loss, loss_pixel, loss_perceptual, loss_tv, loss_mask, extra_loss_dict
|
|
|
|
def forward_disc_loss_local_step(self, pred_img, gt_img):
    """Hinge discriminator loss on ground-truth vs. rendered images.

    Both inputs are detached after preprocessing, so no gradient reaches the
    generator through this loss.
    """
    device_type = pred_img.device.type
    with torch.autocast(device_type=device_type, dtype=torch.float32):
        real_input = self.disc_preprocess(gt_img).detach()
        logits_real = self.model_disc(real_input)
        fake_input = self.disc_preprocess(pred_img).detach()
        logits_fake = self.model_disc(fake_input)

    return hinge_d_loss(logits_real, logits_fake)
|
|
|
|
def forward_to_get_disc_loss(self, pred_img, gt_img):
    """Run one discriminator optimisation step.

    The loss that is back-propagated is gated to zero until
    ``cfg.model.disc.disc_iter_start`` global steps have passed.

    Returns:
        The ungated discriminator loss.
    """
    loss_disc = self.forward_disc_loss_local_step(pred_img, gt_img)

    gate = self.adopt_weight(
        1.0, self.global_step, threshold=self.cfg.model.disc.disc_iter_start
    )
    self.accelerator.backward(gate * loss_disc)

    clip_now = self.accelerator.sync_gradients and self.cfg.train.optim.clip_grad_norm > 0.
    if clip_now:
        self.accelerator.clip_grad_norm_(
            self.model_disc.parameters(), self.cfg.train.optim.clip_grad_norm
        )

    self.optimizer_disc.step()
    self.optimizer_disc.zero_grad()

    return loss_disc
|
|
def train_epoch(self, pbar: tqdm, loader: torch.utils.data.DataLoader, profiler: torch.profiler.profile, iepoch: int):
    """Train for one epoch, logging per global step and once at the end.

    Each tracked loss row has the fixed layout
    ``[loss, loss_pixel, loss_perceptual, loss_tv, loss_mask, loss_gen,
    loss_disc]`` (missing entries are NaN). With a discriminator, batches
    alternate between generator and discriminator updates according to
    ``disc_step``.

    Args:
        pbar: progress bar advanced once per global step.
        loader: the batches for this epoch.
        profiler: stepped once per global step.
        iepoch: zero-based epoch index, used for the discriminator warm-up
            schedule.
    """

    def _loss_row(values):
        # Stack scalar losses into one row, filling missing entries with NaN.
        return torch.stack([
            v.detach() if v is not None else torch.tensor(float('nan'), device=self.device)
            for v in values
        ])

    self.model.train()
    if self.has_disc:
        self.model_disc.train()

    local_step_losses = []
    global_step_losses = []
    local_step_extra_losses = []
    global_step_extra_losses = []
    extra_loss_keys = []

    logger.debug(f"======== Starting epoch {self.current_epoch} ========")
    loss_disc = None
    # Generator losses of the most recent generator step, reported again on
    # discriminator steps. Must live outside the loop: re-initialising it per
    # batch would discard the backup taken after every generator step.
    local_step_losses_bak = torch.zeros(6, device=self.device)
    for idx, data in enumerate(loader):
        data["source_rgbs"] = data["source_rgbs"].to(self.weight_dtype)
        if self.has_disc and hasattr(self.cfg.model.disc, "cross_render") and self.cfg.model.disc.cross_render:
            data = self.prepare_cross_render_data(data)
            data["real_num_view"] = 1
        else:
            data["real_num_view"] = data["render_image"].shape[1]

        logger.debug(f"======== Starting global step {self.global_step} ========")

        if not self.has_disc:
            disc_step = False
            with self.accelerator.accumulate(self.model):
                outs, loss, loss_pixel, loss_perceptual, loss_tv, loss_mask, extra_loss_dict = self.forward_to_get_loss(data)

                loss_disc, loss_gen = None, None
                local_step_losses.append(_loss_row(
                    [loss, loss_pixel, loss_perceptual, loss_tv, loss_mask, loss_gen, loss_disc]
                ))
                extra_loss_keys = sorted(extra_loss_dict.keys())
                if len(extra_loss_keys) > 0:
                    local_step_extra_losses.append(_loss_row(
                        [extra_loss_dict[k] for k in extra_loss_keys]
                    ))
        else:
            # Every 5th batch trains the discriminator; during the first 100
            # batches overall, every even batch does as well.
            disc_step = (idx % 5) == 0 or (iepoch * len(loader) + idx < 100 and idx % 2 == 0)
            if not disc_step:
                with self.accelerator.accumulate(self.model):
                    outs, loss, loss_pixel, loss_perceptual, loss_tv, loss_mask, loss_gen, extra_loss_dict = self.forward_to_get_loss_with_gen_loss(data)

                    local_step_losses.append(_loss_row(
                        [loss, loss_pixel, loss_perceptual, loss_tv, loss_mask, loss_gen, loss_disc]
                    ))
                    local_step_losses_bak = local_step_losses[-1].detach()
                    torch.cuda.empty_cache()
                    extra_loss_keys = sorted(extra_loss_dict.keys())
                    if len(extra_loss_keys) > 0:
                        local_step_extra_losses.append(_loss_row(
                            [extra_loss_dict[k] for k in extra_loss_keys]
                        ))
            else:
                with self.accelerator.accumulate(self.model_disc):
                    outs, _, _, _, _, _, _ = self.forward_loss_local_step(data)
                    loss_disc = self.forward_to_get_disc_loss(pred_img=outs["comp_rgb"],
                                                              gt_img=data["render_image"])
                # Keep only a detached scalar; the graph is no longer needed.
                loss_disc = loss_disc.detach()
                local_step_losses.append(torch.concat([local_step_losses_bak[:6], loss_disc.unsqueeze(0)], dim=0))
                torch.cuda.empty_cache()

        # Everything below happens once per global (synchronised) step.
        if self.accelerator.sync_gradients:
            profiler.step()
            if not disc_step:
                self.scheduler.step()
            if self.has_disc and disc_step:
                self.scheduler_disc.step()
            logger.debug(f"======== Scheduler step ========")
            self.global_step += 1
            global_step_loss = self.accelerator.gather(torch.stack(local_step_losses)).mean(dim=0).cpu()
            # Discriminator steps record no extra losses, so guard on the
            # list as well as on the (possibly stale) key list.
            if len(extra_loss_keys) > 0 and len(local_step_extra_losses) > 0:
                global_step_extra_loss = self.accelerator.gather(torch.stack(local_step_extra_losses)).mean(dim=0).cpu()
                global_step_extra_loss_items = global_step_extra_loss.unbind()
                global_step_extra_losses.append(global_step_extra_loss)
            else:
                global_step_extra_loss_items = []
            loss, loss_pixel, loss_perceptual, loss_tv, loss_mask, loss_gen, loss_disc_ = global_step_loss.unbind()
            loss_kwargs = {
                'loss': loss.item(),
                'loss_pixel': loss_pixel.item(),
                'loss_perceptual': loss_perceptual.item(),
                'loss_tv': loss_tv.item(),
                'loss_mask': loss_mask.item(),
                'loss_disc': loss_disc_.item(),
                'loss_gen': loss_gen.item(),
            }
            for k, extra_loss in zip(extra_loss_keys, global_step_extra_loss_items):
                loss_kwargs[k] = extra_loss.item()
            self.log_scalar_kwargs(
                step=self.global_step, split='train',
                **loss_kwargs
            )
            self.log_optimizer(step=self.global_step, attrs=['lr'], group_ids=[0, 1])
            local_step_losses = []
            global_step_losses.append(global_step_loss)
            local_step_extra_losses = []

            # Progress display (NaN entries are omitted).
            pbar.update(1)
            description = {
                **loss_kwargs,
                'lr': self.optimizer.param_groups[0]['lr'],
            }
            description = '[TRAIN STEP]' + \
                ', '.join(f'{k}={tqdm.format_num(v)}' for k, v in description.items() if not math.isnan(v))
            pbar.set_description(description)

            # Periodic checkpointing, validation and image logging.
            if self.global_step % self.cfg.saver.checkpoint_global_steps == 0:
                self.save_checkpoint()
            if self.global_step % self.cfg.val.global_step_period == 0:
                self.evaluate()
                # `evaluate` switches the generator to eval mode.
                self.model.train()
                if self.has_disc:
                    self.model_disc.train()
            if (self.global_step % self.cfg.logger.image_monitor.train_global_steps == 0) or (self.global_step < 1000 and self.global_step % 20 == 0):
                conf_sigma_l1 = outs.get('conf_sigma_l1', None)
                conf_sigma_l1 = conf_sigma_l1.cpu() if conf_sigma_l1 is not None else None
                conf_sigma_percl = outs.get('conf_sigma_percl', None)
                conf_sigma_percl = conf_sigma_percl.cpu() if conf_sigma_percl is not None else None
                self.log_image_monitor(
                    step=self.global_step, split='train',
                    renders=outs['comp_rgb'].detach()[:self.cfg.logger.image_monitor.samples_per_log].cpu(),
                    conf_sigma_l1=conf_sigma_l1, conf_sigma_percl=conf_sigma_percl,
                    gts=data['render_image'][:self.cfg.logger.image_monitor.samples_per_log].cpu(),
                )
                if 'comp_mask' in outs.keys():
                    self.log_image_monitor(
                        step=self.global_step, split='train',
                        renders=outs['comp_mask'].detach()[:self.cfg.logger.image_monitor.samples_per_log].cpu(),
                        gts=data['render_mask'][:self.cfg.logger.image_monitor.samples_per_log].cpu(),
                        prefix="_mask",
                    )

            # Stop once the global step budget is exhausted.
            if self.global_step >= self.N_max_global_steps:
                self.accelerator.set_trigger()
                break

    # Epoch summary.
    self.current_epoch += 1
    if len(global_step_losses) == 0:
        # No global step completed (e.g. empty loader): nothing to average.
        logger.warning(f"======== Epoch {self.current_epoch} finished without any global step ========")
        return
    epoch_losses = torch.stack(global_step_losses).mean(dim=0)
    # Same column order as the per-step rows: generator before discriminator.
    epoch_loss, epoch_loss_pixel, epoch_loss_perceptual, epoch_loss_tv, epoch_loss_mask, epoch_loss_gen, epoch_loss_disc = epoch_losses.unbind()
    epoch_loss_dict = {
        'loss': epoch_loss.item(),
        'loss_pixel': epoch_loss_pixel.item(),
        'loss_perceptual': epoch_loss_perceptual.item(),
        'loss_tv': epoch_loss_tv.item(),
        'loss_mask': epoch_loss_mask.item(),
        'loss_disc': epoch_loss_disc.item(),
        'loss_gen': epoch_loss_gen.item(),
    }
    if len(extra_loss_keys) > 0 and len(global_step_extra_losses) > 0:
        epoch_extra_losses = torch.stack(global_step_extra_losses).mean(dim=0)
        for k, v in zip(extra_loss_keys, epoch_extra_losses.unbind()):
            epoch_loss_dict[k] = v.item()
    self.log_scalar_kwargs(
        epoch=self.current_epoch, split='train',
        **epoch_loss_dict,
    )
    logger.info(
        f'[TRAIN EPOCH] {self.current_epoch}/{self.cfg.train.epochs}: ' + \
            ', '.join(f'{k}={tqdm.format_num(v)}' for k, v in epoch_loss_dict.items() if not math.isnan(v))
    )
|
|
def train(self):
    """Run the full training loop, then save a final checkpoint and evaluate.

    On resume, the first epoch continues from a loader with the already
    consumed local batches skipped; later epochs use the full train loader.
    """
    starting_local_step_in_epoch = self.global_step_in_epoch * self.cfg.train.accum_steps
    skipped_loader = self.accelerator.skip_first_batches(self.train_loader, starting_local_step_in_epoch)
    logger.info(f"======== Skipped {starting_local_step_in_epoch} local batches ========")

    progress = tqdm(
        range(0, self.N_max_global_steps),
        initial=self.global_step,
        disable=(not self.accelerator.is_main_process),
    )
    with progress as pbar:

        # Real profiler only when enabled; otherwise a no-op stand-in.
        if self.cfg.logger.enable_profiler:
            trace_dir = os.path.join(
                self.cfg.logger.tracker_root,
                self.cfg.experiment.parent, self.cfg.experiment.child,
            )
            profiler = torch.profiler.profile(
                activities=[torch.profiler.ProfilerActivity.CPU, torch.profiler.ProfilerActivity.CUDA],
                schedule=torch.profiler.schedule(
                    wait=10, warmup=10, active=100,
                ),
                on_trace_ready=torch.profiler.tensorboard_trace_handler(trace_dir),
                record_shapes=True,
                profile_memory=True,
                with_stack=True,
            )
        else:
            profiler = DummyProfiler()

        with profiler:
            self.optimizer.zero_grad()
            if self.has_disc:
                self.optimizer_disc.zero_grad()
            for iepoch in range(self.current_epoch, self.cfg.train.epochs):
                # The skipped loader is used at most once.
                # NOTE(review): `or` tests truthiness, so an empty skipped
                # loader also falls back to the full loader -- confirm intent.
                loader = skipped_loader or self.train_loader
                skipped_loader = None
                self.train_epoch(pbar=pbar, loader=loader, profiler=profiler, iepoch=iepoch)
                if self.accelerator.check_trigger():
                    break

        logger.info(f"======== Training finished at global step {self.global_step} ========")

        self.save_checkpoint()
        self.evaluate()
|
|
@torch.no_grad()
@torch.compiler.disable
def evaluate(self, epoch: int = None):
    """Run validation and log averaged losses plus sample images.

    At most ``cfg.val.debug_batches`` batches are evaluated (all batches when
    that setting is falsy). The generator is left in eval mode; callers that
    continue training must switch it back.

    Args:
        epoch: when given, scalars are logged against this epoch; otherwise
            against ``self.global_step``.
    """
    self.model.eval()

    max_val_batches = self.cfg.val.debug_batches or len(self.val_loader)
    running_losses = []
    running_extra_losses = []
    extra_loss_keys = []

    for data in tqdm(self.val_loader, disable=(not self.accelerator.is_main_process), total=max_val_batches):
        # Check the budget first so no work is spent on a batch that is
        # going to be dropped.
        if len(running_losses) >= max_val_batches:
            logger.info(f"======== Early stop validation at {len(running_losses)} batches ========")
            break

        data["source_rgbs"] = data["source_rgbs"].to(self.weight_dtype)
        if self.has_disc and hasattr(self.cfg.model.disc, "cross_render") and self.cfg.model.disc.cross_render:
            data = self.prepare_cross_render_data(data)
            data["real_num_view"] = 1
        else:
            data["real_num_view"] = data["render_image"].shape[1]

        outs, loss, loss_pixel, loss_perceptual, loss_tv, loss_mask, extra_loss_dict = self.forward_loss_local_step(data)
        # Keep the dict intact and record its keys separately; overwriting
        # the dict with the key list would make the lookups below fail and
        # leave `extra_loss_keys` permanently empty.
        extra_loss_keys = sorted(extra_loss_dict.keys())

        running_losses.append(torch.stack([
            _loss if _loss is not None else torch.tensor(float('nan'), device=self.device)
            for _loss in [loss, loss_pixel, loss_perceptual, loss_tv, loss_mask]
        ]))
        if len(extra_loss_keys) > 0:
            running_extra_losses.append(torch.stack([
                extra_loss_dict[k] if extra_loss_dict[k] is not None else torch.tensor(float('nan'), device=self.device)
                for k in extra_loss_keys
            ]))

        # Per-batch image logging on every process; the prefix carries the
        # running batch count and the process rank.
        # NOTE(review): placement inside the loop was reconstructed from a
        # flattened source -- confirm against the original file.
        samples_per_log = self.cfg.logger.image_monitor.samples_per_log
        conf_sigma_l1 = outs.get('conf_sigma_l1', None)
        conf_sigma_l1 = conf_sigma_l1.cpu() if conf_sigma_l1 is not None else None
        conf_sigma_percl = outs.get('conf_sigma_percl', None)
        conf_sigma_percl = conf_sigma_percl.cpu() if conf_sigma_percl is not None else None
        self.log_image_monitor_each_process(
            step=self.global_step, split='val',
            renders=outs['comp_rgb'][:samples_per_log].cpu(),
            gts=data['render_image'][:samples_per_log].cpu(),
            conf_sigma_l1=conf_sigma_l1, conf_sigma_percl=conf_sigma_percl,
            prefix=f"_{len(running_losses)}_rank{self.accelerator.process_index}"
        )
        if "comp_mask" in outs.keys():
            self.log_image_monitor_each_process(
                step=self.global_step, split='val',
                renders=outs['comp_mask'][:samples_per_log].cpu(),
                gts=data['render_mask'][:samples_per_log].cpu(),
                prefix=f"_mask_{len(running_losses)}_rank{self.accelerator.process_index}"
            )

    total_losses = self.accelerator.gather(torch.stack(running_losses)).mean(dim=0).cpu()
    # The fourth entry is the offset regulariser returned by
    # `forward_loss_local_step`.
    total_loss, total_loss_pixel, total_loss_perceptual, total_loss_offset, total_loss_mask = total_losses.unbind()
    total_loss_dict = {
        'loss': total_loss.item(),
        'loss_pixel': total_loss_pixel.item(),
        'loss_perceptual': total_loss_perceptual.item(),
        'loss_offset': total_loss_offset.item(),
        'loss_mask': total_loss_mask.item(),
    }
    if len(extra_loss_keys) > 0:
        total_extra_losses = self.accelerator.gather(torch.stack(running_extra_losses)).mean(dim=0).cpu()
        for k, v in zip(extra_loss_keys, total_extra_losses.unbind()):
            total_loss_dict[k] = v.item()

    if epoch is not None:
        self.log_scalar_kwargs(
            epoch=epoch, split='val',
            **total_loss_dict,
        )
        logger.info(
            f'[VAL EPOCH] {epoch}/{self.cfg.train.epochs}: ' + \
                ', '.join(f'{k}={tqdm.format_num(v)}' for k, v in total_loss_dict.items() if not math.isnan(v))
        )
    else:
        self.log_scalar_kwargs(
            step=self.global_step, split='val',
            **total_loss_dict,
        )
        logger.info(
            f'[VAL STEP] {self.global_step}/{self.N_max_global_steps}: ' + \
                ', '.join(f'{k}={tqdm.format_num(v)}' for k, v in total_loss_dict.items() if not math.isnan(v))
        )
|
|
def log_image_monitor_each_process(
    self, epoch: int = None, step: int = None, split: str = None,
    renders: torch.Tensor = None, gts: torch.Tensor = None, prefix=None,
    conf_sigma_l1: torch.Tensor = None, conf_sigma_percl: torch.Tensor = None
):
    """Log image grids of renders, ground truth and confidence maps.

    Runs on whichever process calls it (no main-process gating here).

    Args:
        epoch, step: progress markers passed to ``_get_str_progress``.
        split: optional split name used in the log tag.
        renders, gts: 5-D tensors; axis 1 is flattened together with axis 0
            into grid rows. When their axis-1 sizes differ, ``gts`` is
            repeated twice along axis 1.
        prefix: optional suffix appended to the split part of the tag.
        conf_sigma_l1, conf_sigma_percl: optional confidence sigmas, shown as
            ``1 / (1 + sigma)``. Each is logged independently when provided.
    """
    M = renders.shape[1]
    if gts.shape[1] != M:
        gts = repeat(gts, "b v c h w -> b (v r) c h w", r=2)
    # Render/GT comparison for the first sample of the batch only.
    merged = torch.stack([renders, gts], dim=1)[0].view(-1, *renders.shape[2:])
    renders, gts = renders.view(-1, *renders.shape[2:]), gts.view(-1, *gts.shape[2:])
    renders, gts, merged = make_grid(renders, nrow=M), make_grid(gts, nrow=M), make_grid(merged, nrow=M)
    log_type, log_progress = self._get_str_progress(epoch, step)
    split = f'/{split}' if split else ''
    split = split + prefix if prefix is not None else split
    log_img_dict = {
        f'Images_split{split}/rendered': renders.unsqueeze(0),
        f'Images_split{split}/gt': gts.unsqueeze(0),
        f'Images_split{split}/merged': merged.unsqueeze(0),
    }

    EPS = 1e-7
    # Handle each map on its own so a missing perceptual map no longer
    # crashes when only the L1 map is supplied.
    for tag, sigma in (("conf_l1", conf_sigma_l1), ("conf_percl", conf_sigma_percl)):
        if sigma is None:
            continue
        vis_conf = 1 / (1 + sigma.detach() + EPS).cpu()
        # Split the channel axis into (r=2, c) and move r next to the views.
        vis_conf = rearrange(vis_conf, "b v (r c) h w -> (b v r) c h w", r=2)
        # Replicate each channel three times for display.
        vis_conf = repeat(vis_conf, "b c1 h w -> b (c1 c2) h w", c2=3)
        log_img_dict[f'Images_split{split}/{tag}'] = make_grid(vis_conf, nrow=M).unsqueeze(0)

    self.log_images_each_process(log_img_dict, log_progress, {"imwrite_image": False})
|
|
| |
@Trainer.control('on_main_process')
def log_image_monitor(
    self, epoch: int = None, step: int = None, split: str = None,
    renders: torch.Tensor = None, gts: torch.Tensor = None, prefix=None,
    conf_sigma_l1: torch.Tensor = None, conf_sigma_percl: torch.Tensor = None
):
    """Main-process wrapper around ``log_image_monitor_each_process``.

    All arguments are forwarded unchanged.
    """
    self.log_image_monitor_each_process(
        epoch=epoch,
        step=step,
        split=split,
        renders=renders,
        gts=gts,
        prefix=prefix,
        conf_sigma_l1=conf_sigma_l1,
        conf_sigma_percl=conf_sigma_percl,
    )
|
|