from typing import *
import os
import copy
import functools
import numpy as np
import torch
import torch.nn.functional as F
from torch.utils.data import DataLoader
import utils3d
from easydict import EasyDict as edict

from ..basic import BasicTrainer
from ...modules import sparse as sp
from ...renderers import MeshRenderer
from ...representations import Mesh
from ...utils.data_utils import recursive_to_device, cycle, BalancedResumableSampler
from ...utils.loss_utils import l1_loss, ssim, lpips
class ShapeVaeTrainer(BasicTrainer):
    """
    Trainer for the shape VAE.

    Args:
        models (dict[str, nn.Module]): Models to train.
        dataset (torch.utils.data.Dataset): Dataset.
        output_dir (str): Output directory.
        load_dir (str): Load directory.
        step (int): Step to load.
        batch_size (int): Batch size.
        batch_size_per_gpu (int): Batch size per GPU. If specified, batch_size is ignored.
        batch_split (int): Number of gradient-accumulation splits per batch.
        max_steps (int): Max steps.
        optimizer (dict): Optimizer config.
        lr_scheduler (dict): Learning rate scheduler config.
        elastic (dict): Elastic memory management config.
        grad_clip (float or dict): Gradient clip config.
        ema_rate (float or list): Exponential moving average rates.
        fp16_mode (str): FP16 mode.
            - None: No FP16.
            - 'inflat_all': Hold an inflated fp32 master copy of all params.
            - 'amp': Automatic mixed precision.
        fp16_scale_growth (float): Scale growth for FP16 gradient backpropagation.
        finetune_ckpt (dict): Finetune checkpoint.
        log_param_stats (bool): Log parameter stats.
        i_print (int): Print interval.
        i_log (int): Log interval.
        i_sample (int): Sample interval.
        i_save (int): Save interval.
        i_ddpcheck (int): DDP check interval.
        lambda_subdiv (float): Subdivision loss weight.
        lambda_intersected (float): Intersected loss weight.
        lambda_vertice (float): Vertex regression loss weight.
        lambda_mask (float): Rendered mask loss weight.
        lambda_depth (float): Rendered depth loss weight.
        lambda_normal (float): Rendered normal loss weight.
        lambda_kl (float): KL loss weight.
        lambda_ssim (float): SSIM loss weight (applied to rendered normals).
        lambda_lpips (float): LPIPS loss weight (applied to rendered normals).
        render_resolution (int): Resolution of rendered supervision images.
        camera_randomization_config (dict): Config for random camera sampling.
    """
    def __init__(
        self,
        *args,
        lambda_subdiv: float = 0.1,
        lambda_intersected: float = 0.1,
        lambda_vertice: float = 1e-2,
        lambda_mask: float = 1,
        lambda_depth: float = 10,
        lambda_normal: float = 1,
        lambda_kl: float = 1e-6,
        lambda_ssim: float = 0.2,
        lambda_lpips: float = 0.2,
        render_resolution: int = 1024,
        camera_randomization_config: dict = {
            'radius_range': [2, 100],
        },
        **kwargs
    ):
        super().__init__(*args, **kwargs)
        self.lambda_subdiv = lambda_subdiv
        self.lambda_intersected = lambda_intersected
        self.lambda_vertice = lambda_vertice
        self.lambda_mask = lambda_mask
        self.lambda_depth = lambda_depth
        self.lambda_normal = lambda_normal
        self.lambda_kl = lambda_kl
        self.lambda_ssim = lambda_ssim
        self.lambda_lpips = lambda_lpips
        self.camera_randomization_config = camera_randomization_config
        self.renderer = MeshRenderer({'near': 1, 'far': 3, 'resolution': render_resolution}, device=self.device)
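
    # The 'near'/'far' defaults passed to MeshRenderer above are placeholders;
    # `_render_batch` overrides them per sample with a window derived from the
    # randomized camera radius.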

    def prepare_dataloader(self, **kwargs):
        """
        Prepare dataloader.
        """
        self.data_sampler = BalancedResumableSampler(
            self.dataset,
            shuffle=True,
            batch_size=self.batch_size_per_gpu,
        )
        self.dataloader = DataLoader(
            self.dataset,
            batch_size=self.batch_size_per_gpu,
            num_workers=int(np.ceil(os.cpu_count() / torch.cuda.device_count())),
            pin_memory=True,
            drop_last=True,
            persistent_workers=True,
            collate_fn=functools.partial(self.dataset.collate_fn, split_size=self.batch_split),
            sampler=self.data_sampler,
        )
        self.data_iterator = cycle(self.dataloader)
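
    # `num_workers` splits the host's CPU cores evenly across the per-GPU
    # dataloader processes, so the total worker count roughly matches the
    # number of cores.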
| def _randomize_camera(self, num_samples: int): | |
| # sample radius and fov | |
| r_min, r_max = self.camera_randomization_config['radius_range'] | |
| k_min = 1 / r_max**2 | |
| k_max = 1 / r_min**2 | |
| ks = torch.rand(num_samples, device=self.device) * (k_max - k_min) + k_min | |
| radius = 1 / torch.sqrt(ks) | |
| fov = 2 * torch.arcsin(0.5 / radius) | |
| origin = radius.unsqueeze(-1) * F.normalize(torch.randn(num_samples, 3, device=self.device), dim=-1) | |
| # build camera | |
| extrinsics = utils3d.torch.extrinsics_look_at(origin, torch.zeros_like(origin), torch.tensor([0, 0, 1], dtype=torch.float32, device=self.device)) | |
| intrinsics = utils3d.torch.intrinsics_from_fov_xy(fov, fov) | |
| near = [np.random.uniform(r - 1, r) for r in radius.tolist()] | |
| return { | |
| 'extrinsics': extrinsics, | |
| 'intrinsics': intrinsics, | |
| 'near': near, | |
| } | |
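
    # A note on the sampling above: drawing k ~ U[1/r_max^2, 1/r_min^2] and
    # setting r = 1/sqrt(k) gives a density p(r) proportional to 1/r^3, which
    # favors closer cameras (stronger perspective). Since sin(fov/2) = 0.5/r,
    # an object normalized to a bounding sphere of radius 0.5 fills the frame
    # at every distance, so the radius effectively controls only the amount of
    # perspective distortion.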

    def _render_batch(self, reps: List[Mesh], extrinsics: torch.Tensor, intrinsics: torch.Tensor, near: List,
                      return_types=['mask', 'normal', 'depth']) -> Dict[str, torch.Tensor]:
        """
        Render a batch of representations.

        Args:
            reps: The list of mesh representations.
            extrinsics: The [N x 4 x 4] tensor of extrinsics.
            intrinsics: The [N x 3 x 3] tensor of intrinsics.
            near: The list of N per-sample near-plane distances.
            return_types: subset of ['mask', 'normal', 'depth', 'normal_map', 'color']

        Returns:
            a dict with
                mask   : [N x 1 x H x W] tensor of rendered masks
                normal : [N x 3 x H x W] tensor of rendered normals
                depth  : [N x 1 x H x W] tensor of rendered depths
        """
        ret = {k: [] for k in return_types}
        for i, rep in enumerate(reps):
            self.renderer.rendering_options['near'] = near[i]
            self.renderer.rendering_options['far'] = near[i] + 2
            out_dict = self.renderer.render(rep, extrinsics[i], intrinsics[i], return_types=return_types)
            for k in out_dict:
                # Single-channel outputs come back as [H x W]; add a channel dim.
                ret[k].append(out_dict[k][None] if k in ['mask', 'depth'] else out_dict[k])
        for k in ret:
            ret[k] = torch.stack(ret[k])
        return ret
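
    # The depth window set above spans far - near = 2 units, enough to cover an
    # object inside a 0.5-radius sphere; `near` is jittered per sample in
    # `_randomize_camera`, and the same window is applied to ground truth and
    # reconstruction, so the depth loss stays consistent.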

    def training_losses(
        self,
        vertices: sp.SparseTensor,
        intersected: sp.SparseTensor,
        mesh: List[Mesh],
    ) -> Tuple[Dict, Dict]:
        """
        Compute training losses.

        Args:
            vertices (SparseTensor): vertices of each active voxel
            intersected (SparseTensor): intersected flag of each active voxel
            mesh (List[Mesh]): the list of meshes to render

        Returns:
            a dict with the key "loss" containing a scalar tensor.
            may also contain other keys for different terms.
        """
        z, mean, logvar = self.training_models['encoder'](vertices, intersected, sample_posterior=True, return_raw=True)
        recon, pred_vertice, pred_intersected, subs_gt, subs = self.training_models['decoder'](z, intersected)

        terms = edict(loss=0.0)

        # direct regression
        if self.lambda_intersected > 0:
            terms["direct/intersected"] = F.binary_cross_entropy_with_logits(pred_intersected.feats.flatten(), intersected.feats.flatten().float())
            terms["loss"] = terms["loss"] + self.lambda_intersected * terms["direct/intersected"]
        if self.lambda_vertice > 0:
            terms["direct/vertice"] = F.mse_loss(pred_vertice.feats, vertices.feats)
            terms["loss"] = terms["loss"] + self.lambda_vertice * terms["direct/vertice"]

        # subdivision prediction loss
        for i, (sub_gt, sub) in enumerate(zip(subs_gt, subs)):
            terms[f"bce_sub{i}"] = F.binary_cross_entropy_with_logits(sub.feats, sub_gt.float())
            terms["loss"] = terms["loss"] + self.lambda_subdiv * terms[f"bce_sub{i}"]

        # rendering loss (same randomized cameras for ground truth and reconstruction)
        cameras = self._randomize_camera(len(mesh))
        gt_renders = self._render_batch(mesh, **cameras, return_types=['mask', 'normal', 'depth'])
        pred_renders = self._render_batch(recon, **cameras, return_types=['mask', 'normal', 'depth'])
        terms['render/mask'] = l1_loss(pred_renders['mask'], gt_renders['mask'])
        terms['render/depth'] = l1_loss(pred_renders['depth'], gt_renders['depth'])
        terms['render/normal/l1'] = l1_loss(pred_renders['normal'], gt_renders['normal'])
        terms['render/normal/ssim'] = 1 - ssim(pred_renders['normal'], gt_renders['normal'])
        terms['render/normal/lpips'] = lpips(pred_renders['normal'], gt_renders['normal'])
        terms['loss'] = terms['loss'] + \
            self.lambda_mask * terms['render/mask'] + \
            self.lambda_depth * terms['render/depth'] + \
            self.lambda_normal * (terms['render/normal/l1'] + self.lambda_ssim * terms['render/normal/ssim'] + self.lambda_lpips * terms['render/normal/lpips'])

        # KL regularization: KL(N(mean, exp(logvar)) || N(0, I)), averaged over
        # all latent elements so lambda_kl is independent of latent size.
        terms["kl"] = 0.5 * torch.mean(mean.pow(2) + logvar.exp() - logvar - 1)
        terms["loss"] = terms["loss"] + self.lambda_kl * terms["kl"]

        return terms, {}

    def run_snapshot(
        self,
        num_samples: int,
        batch_size: int,
        verbose: bool = False,
    ) -> Dict:
        dataloader = DataLoader(
            copy.deepcopy(self.dataset),
            batch_size=batch_size,
            shuffle=True,
            num_workers=1,
            collate_fn=self.dataset.collate_fn if hasattr(self.dataset, 'collate_fn') else None,
        )

        # inference
        gts = []
        recons = []
        recons2 = []
        self.models['encoder'].eval()
        for i in range(0, num_samples, batch_size):
            batch = min(batch_size, num_samples - i)
            data = next(iter(dataloader))
            args = {k: v[:batch] for k, v in data.items()}
            args = recursive_to_device(args, self.device)
            z = self.models['encoder'](args['vertices'], args['intersected'])
            # Decode twice: in train mode conditioned on the ground-truth
            # intersected flags, and in eval mode from z alone.
            self.models['decoder'].train()
            y = self.models['decoder'](z, args['intersected'])[0]
            z.clear_spatial_cache()
            self.models['decoder'].eval()
            y2 = self.models['decoder'](z)
            gts.extend(args['mesh'])
            recons.extend(y)
            recons2.extend(y2)
        self.models['encoder'].train()
        self.models['decoder'].train()

        # render normal maps for visualization
        cameras = self._randomize_camera(num_samples)
        gt_renders = self._render_batch(gts, **cameras, return_types=['normal'])
        recons_renders = self._render_batch(recons, **cameras, return_types=['normal'])
        recons2_renders = self._render_batch(recons2, **cameras, return_types=['normal'])
        sample_dict = {
            'gt': {'value': gt_renders['normal'], 'type': 'image'},
            'rec': {'value': recons_renders['normal'], 'type': 'image'},
            'rec2': {'value': recons2_renders['normal'], 'type': 'image'},
        }
        return sample_dict
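

# A minimal construction sketch (hedged: `encoder`, `decoder`, and `dataset`
# are hypothetical instances built elsewhere in the codebase; only kwargs
# documented above are shown):
#
#     trainer = ShapeVaeTrainer(
#         models={'encoder': encoder, 'decoder': decoder},
#         dataset=dataset,
#         output_dir='outputs/shape_vae',
#         batch_size_per_gpu=4,
#         lambda_kl=1e-6,
#         render_resolution=512,
#         camera_randomization_config={'radius_range': [2, 100]},
#     )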