Spaces:
Sleeping
Sleeping
| from dataclasses import dataclass | |
| from typing import Literal, Optional, List | |
| import numpy as np | |
| import torch | |
| import torch.nn.functional as F | |
| import torch.utils.checkpoint | |
| from einops import rearrange | |
| from torch import nn | |
| from optgs.dataset.data_types import BatchedExample, DataShim, BatchedViews | |
| from optgs.dataset.shims.patch_shim import apply_patch_shim | |
| from optgs.geometry.projection import sample_image_grid, get_world_rays | |
| from optgs.misc.general_utils import rotate_quats | |
| from optgs.misc.io import FrequencyScheduler | |
| from optgs.model.encoder.layer import BasicBlock | |
| from optgs.model.encoder.unimatch.dpt_head import DPTHead | |
| from optgs.model.encoder.unimatch.feature_upsampler import ResizeConvFeatureUpsampler | |
| from optgs.model.encoder.unimatch.ldm_unet.unet import UNetModel | |
| from optgs.model.encoder.unimatch.mv_unimatch import MultiViewUniMatch | |
| from optgs.model.encoder.visualization.encoder_visualizer_depthsplat_cfg import EncoderVisualizerDepthSplatCfg | |
| from optgs.model.types import Gaussians | |
| from optgs.scene_trainer.common.gaussian_adapter import GaussianAdapter, GaussianAdapterCfg, build_covariance, RGB2SH | |
| from optgs.scene_trainer.initializer.initializer import InitializerOutput, LearnedInitializer, PerPixelInitializerCfg | |
| try: | |
| from optgs.model.encoder.point_transformer.layer import (PlainPointTransformer, SubsampleBlock, PointLinearWrapper, | |
| MultiScalePointTransformer, | |
| MultViewLowresAttn, MultViewUniMatchAttn, | |
| GaussianErrorCrossAttn) | |
| except: | |
| pass | |
| from optgs.model.encoder.lvsm.transformer import LVSMTransformer | |
| try: | |
| from simple_knn._C import distCUDA2 | |
| except: | |
| pass | |
class ResplatInitializerCfg(PerPixelInitializerCfg):
    """Options for the ReSplat initializer plus helpers deriving layer widths.

    Besides the fields declared here, the helper methods read ``latent_gs``
    and ``latent_downsample``. Those two are not declared in this class and
    are presumably inherited from ``PerPixelInitializerCfg`` -- confirm.
    """

    name: Literal["resplat_v1", "resplat_v2"]
    d_feature: int
    num_depth_candidates: int
    num_surfaces: int
    visualizer: EncoderVisualizerDepthSplatCfg
    gaussian_adapter: GaussianAdapterCfg
    gaussians_per_pixel: int
    unimatch_weights_path: str | None
    downscale_factor: int
    shim_patch_size: int | List[int]
    multiview_trans_attn_split: int
    deform_sample_depth: bool  # non-pixel aligned Gaussians with learned offsets
    deform_sample_depth_debug: bool  # check depth sampling

    # mv_unimatch
    num_scales: int
    upsample_factor: int
    lowest_feature_resolution: int
    depth_unet_channels: int
    grid_sample_disable_cudnn: bool

    # depthsplat color branch
    large_gaussian_head: bool
    color_large_unet: bool
    init_sh_input_img: bool
    feature_upsampler_channels: int
    gaussian_regressor_channels: int

    # loss config
    return_depth: bool

    # only depth
    train_depth_only: bool

    # monodepth config
    monodepth_vit_type: str

    # multi-view matching
    local_mv_match: int

    # point transformer
    pt_head: bool
    init_pt_with_mv_attn: bool
    init_pt_with_mv_attn_lowres: bool
    pt_head_conv: bool
    pt_head_concat_img: bool
    pt_head_channels: int | None
    multi_scale_pt: bool
    attn_proj_channels: int | None
    fps_num_samples: int | None
    knn_samples: int
    post_norm: bool
    no_rpe: bool
    no_knn_attn: bool
    num_blocks: int
    pt_downsample: int
    fps_agg_func: str
    subsample_method: str
    add_pt_residual: bool
    # based on the initial point cloud from depth, predict additional residual
    pt_pred_residual_position: bool
    latent_dpt_upsampler: bool
    latent_dpt_upsampler_no_concat: bool
    light_dpt_feature: bool

    # freeze depth
    freeze_depth: bool
    use_gt_depth: bool

    # separate depth and color branches
    separate_depth_color: bool
    separate_depth_type: str
    separate_depth_gaussian_scale: bool

    # unet gaussian regressor
    unet_gaussian_regressor: bool
    resnet_gaussian_regressor: bool

    # lvsm gaussian regressor
    lvsm_gaussian_regressor: bool
    lvsm_layers: int

    sample_log_depth: bool
    bilinear_upsample_depth: bool
    no_upsample_depth: bool
    return_lowres_depth: bool

    # latent gaussian instead of pixel-aligned gaussian
    fixed_latent_size: bool  # same channels for both downsample 4 and 8
    latent_gs_img_interp: str
    dpt_head_depth: bool  # downsample the full resolution depth to low resolution
    avgpool_depth: bool
    nearest_down_depth: bool

    # predict scene scale and use point distance to normalize the scene
    predict_scale: bool
    norm_by_points: bool
    no_pred_depth_range: bool

    # init gaussian scale with point cloud distance
    point_dist_init_gaussian_scale: bool

    # feature upsampler
    resizeconv_upsampler: bool
    # rotate_quat_to_world: bool # rotate the quaternion to world space
    latent_new_reshape: bool  # debug

    # amp
    use_amp: bool
    pt_head_amp: bool
    use_checkpointing: bool
    init_use_checkpointing: bool  # init model uses checkpointing

    no_pixel_offset: bool
    pt_heads: int
    init_gaussian_multiple: int
    depth_pred_half_res: bool

    def get_feature_upsampler_channels(self):
        """Return ``(feature_num, model_configs)`` for the feature branch.

        ``feature_num`` is the channel count of the feature map produced by
        the selected feature path (raw latent concatenation, resize-conv
        upsampler, or DPT head). ``model_configs`` is the per-ViT-size table
        of DPT head arguments; its ``out_channels`` entries are halved when
        the DPT path is selected and ``light_dpt_feature`` is set.

        Raises:
            KeyError: ``monodepth_vit_type`` is not one of the table keys.
            NotImplementedError: the raw latent path is selected and
                ``latent_downsample`` is not 2, 4 or 8.
        """
        # upsample features to the original resolution
        model_configs = {
            'vits': {'in_channels': 384, 'features': 64, 'out_channels': [48, 96, 192, 384]},
            'vitb': {'in_channels': 768, 'features': 96, 'out_channels': [96, 192, 384, 768]},
            'vitl': {'in_channels': 1024, 'features': 128, 'out_channels': [128, 256, 512, 1024]},
        }
        vit_type = self.monodepth_vit_type
        in_channels = model_configs[vit_type]['in_channels']

        if self.latent_gs and not self.latent_dpt_upsampler:
            # Raw latent path: the width is the sum of the channel counts of
            # the feature groups concatenated at the latent resolution.
            if self.latent_downsample == 2:
                feature_num = in_channels // 64 * 4 + 128 // 4 + 64 + 96 + 128 // 4
            elif self.latent_downsample == 4:
                feature_num = in_channels // 4 + 128 + 64 + 96 + 128
            elif self.latent_downsample == 8:
                if self.fixed_latent_size:
                    feature_num = in_channels // 4 + 128 + 64 + 96 + 128
                else:
                    feature_num = in_channels + 128 + 64 + 96 + 128
            else:
                # The value lives directly on this config object; going
                # through ``self.cfg`` would fail while building the message.
                raise NotImplementedError(
                    f"Unsupported latent_downsample value: {self.latent_downsample}")
        elif self.resizeconv_upsampler:
            feature_num = self.feature_upsampler_channels
        else:
            if self.light_dpt_feature:
                # Halve every DPT stage width (applied to all ViT sizes).
                for config in model_configs.values():
                    config['out_channels'] = [c // 2 for c in config['out_channels']]
            features = model_configs[vit_type]["features"]
            if self.latent_gs and not self.latent_dpt_upsampler_no_concat:
                features *= 4
            feature_num = features

        return feature_num, model_configs

    def get_pt_in_channels(self):
        """Return the input channel count of the Gaussian head / point projection.

        The count is image (3) + feature channels + regressor channels + 1.
        With ``latent_gs`` the 3 image channels are replaced by
        ``3 * p ** 2`` channels, where ``p`` is 4 when ``fixed_latent_size``
        is set and ``latent_downsample`` otherwise.
        """
        feature_upsampler_channels, _ = self.get_feature_upsampler_channels()
        in_channels = 3 + feature_upsampler_channels + self.gaussian_regressor_channels + 1

        if self.latent_gs:
            # image unshuffle
            if self.fixed_latent_size:
                in_channels = in_channels - 3 + 3 * (4 ** 2)
            else:
                in_channels = in_channels - 3 + 3 * (self.latent_downsample ** 2)

        return in_channels

    def get_gaussian_param_num(self):
        """Return how many raw values the Gaussian head predicts per location.

        The base count is scale (3) + quaternion (4) + SH (3 * sh_d) +
        pixel offset (2) + opacity (1), adjusted by the flags below and
        multiplied by ``init_gaussian_multiple`` when that is greater than 1.

        Raises:
            AssertionError: ``init_gaussian_multiple > 1`` without ``latent_gs``.
        """
        # predict gaussian parameters: scale, q, sh, offset, opacity
        # d_in: (scale, q, sh)
        sh_d = self.get_sh_d()
        init_gaussian_param_num = 3 + 4 + 3 * sh_d + 2 + 1

        # NOTE(review): each of the three flags below removes the 2 offset
        # channels independently, so enabling more than one subtracts them
        # more than once -- confirm this combination is never used.
        if self.no_pixel_offset:
            init_gaussian_param_num -= 2

        if self.pt_downsample > 0:
            # no pixel offset
            init_gaussian_param_num -= 2

        if self.pt_pred_residual_position:
            # based on the initial point cloud from depth, predict additional residual
            # without pixel offset on 2d
            init_gaussian_param_num = init_gaussian_param_num + 3 - 2

        # multiple gaussians per latent
        if self.init_gaussian_multiple > 1:
            # we use the point cloud unprojected from higher resolution depth map as center
            # assert self.cfg.gaussian_adapter.init_rotation_identity
            assert self.latent_gs
            init_gaussian_param_num *= self.init_gaussian_multiple

        return init_gaussian_param_num

    def get_sh_d(self):
        """Return the number of SH coefficients per color channel: ``(sh_degree + 1) ** 2``."""
        sh_d = (self.gaussian_adapter.sh_degree + 1) ** 2
        return sh_d
| class ResplatInitializer(LearnedInitializer[ResplatInitializerCfg]): | |
def __init__(self, cfg: ResplatInitializerCfg) -> None:
    """Assemble the sub-networks of the initializer.

    The depth predictor is always built. Unless ``cfg.train_depth_only`` is
    set, this additionally creates (depending on the config flags) a feature
    upsampler, the Gaussian adapter, the Gaussian regressor, the optional
    point-transformer stack, and the Gaussian head.
    """
    super().__init__(cfg)
    self.depth_predictor = self._get_depth_predictor(cfg)

    # Depth-only training stops here; nothing below is created.
    if self.cfg.train_depth_only:
        return

    feature_upsampler_channels, model_configs = self.cfg.get_feature_upsampler_channels()

    # With latent Gaussians and no DPT upsampler only the channel count is
    # needed, so no upsampling module is instantiated in that case.
    needs_upsampler = not (self.cfg.latent_gs and not self.cfg.latent_dpt_upsampler)
    if needs_upsampler:
        if self.cfg.resizeconv_upsampler:
            self.feature_upsampler = ResizeConvFeatureUpsampler(
                num_scales=cfg.num_scales,
                lowest_feature_resolution=cfg.lowest_feature_resolution,
                out_channels=self.cfg.feature_upsampler_channels,
                vit_type=self.cfg.monodepth_vit_type,
            )
        else:
            self.feature_upsampler = DPTHead(
                **model_configs[cfg.monodepth_vit_type],
                downsample_factor=cfg.upsample_factor,
                return_feature=True,
                num_scales=cfg.num_scales,
                latent_downsample=self.cfg.latent_downsample if self.cfg.latent_gs else None,
                latent_feature_no_concat=self.cfg.latent_dpt_upsampler_no_concat,
            )

    # gaussians adapter (can be removed)
    self.gaussian_adapter = GaussianAdapter(cfg.gaussian_adapter)

    # Regressor input is concat(img, depth, match_prob, features).
    regressor_in = 3 + 1 + 1 + feature_upsampler_channels
    width = self.cfg.gaussian_regressor_channels
    if self.cfg.latent_gs:
        # The image is pixel-unshuffled: 3 channels become 3 * patch ** 2.
        # A fixed latent size always uses patch size 4.
        patch = 4 if self.cfg.fixed_latent_size else self.cfg.latent_downsample
        regressor_in = regressor_in - 3 + 3 * (patch ** 2)

    if self.cfg.unet_gaussian_regressor:
        # UNet regressor: conv stem -> UNet -> conv.
        layers = [
            nn.Conv2d(regressor_in, width, 3, 1, 1),
            nn.GroupNorm(8, width),
            nn.GELU(),
        ]
        channel_mult = [1, 2, 4, 4, 4] if self.cfg.color_large_unet else [1, 1, 1, 1, 1]
        layers.append(
            UNetModel(
                image_size=None,
                in_channels=width,
                model_channels=width,
                out_channels=width,
                num_res_blocks=1,
                attention_resolutions=[16],
                channel_mult=channel_mult,
                num_head_channels=32 if width >= 32 else 16,
                dims=2,
                postnorm=False,
                num_frames=2,
                use_cross_view_self_attn=True,
            )
        )
        layers.append(nn.Conv2d(width, width, 3, 1, 1))
    elif self.cfg.resnet_gaussian_regressor:
        # ResNet regressor: conv stem followed by two basic blocks.
        layers = [
            nn.Conv2d(regressor_in, width, 3, 1, 1),
            nn.GroupNorm(8, width),
            nn.GELU(),
            BasicBlock(width, width),
            BasicBlock(width, width),
        ]
    elif self.cfg.lvsm_gaussian_regressor:
        # LVSM regressor: linear stem followed by the LVSM transformer.
        layers = [
            nn.Linear(regressor_in, width),
            nn.LayerNorm(width),
            nn.GELU(),
            LVSMTransformer(width, n_layer=self.cfg.lvsm_layers),
        ]
    else:
        # Plain two-layer conv regressor.
        layers = [
            nn.Conv2d(regressor_in, width, 3, 1, 1),
            nn.GELU(),
            nn.Conv2d(width, width, 3, 1, 1),
        ]
    self.gaussian_regressor = nn.Sequential(*layers)

    init_gaussian_param_num = self.cfg.get_gaussian_param_num()

    # Head input is concat(img, features, regressor_out, match_prob).
    head_in = self.cfg.get_pt_in_channels()

    if self.cfg.pt_head:
        # Point-transformer head; pt_head_channels overrides the width if given.
        pt_width = self.cfg.pt_head_channels
        if pt_width is None:
            pt_width = self.cfg.gaussian_regressor_channels

        self.proj = nn.Linear(head_in, pt_width)

        if self.cfg.multi_scale_pt:
            self.pt = MultiScalePointTransformer(
                pt_width,
                self.cfg.knn_samples,
                downsample_agg_func=self.cfg.fps_agg_func,
                subsample_method=self.cfg.subsample_method,
                fps_num_samples=self.cfg.fps_num_samples,
                attn_proj_channels=self.cfg.attn_proj_channels,
            )
        else:
            self.pt = PlainPointTransformer(
                pt_width,
                self.cfg.knn_samples,
                post_norm=self.cfg.post_norm,
                no_rpe=self.cfg.no_rpe,
                no_attn=self.cfg.no_knn_attn,
                num_blocks=self.cfg.num_blocks,
                num_heads=self.cfg.pt_heads,
                attn_proj_channels=self.cfg.attn_proj_channels,
                use_checkpointing=self.cfg.use_checkpointing,
                init_use_checkpointing=self.cfg.init_use_checkpointing,
                with_mv_attn=self.cfg.init_pt_with_mv_attn,
                with_mv_attn_lowres=self.cfg.init_pt_with_mv_attn_lowres,
            )

        linear_in = pt_width

        # Optional point downsampling; at most one halving step is supported.
        if self.cfg.pt_downsample > 0:
            num_downsample = int(np.log2(self.cfg.pt_downsample))
            if num_downsample != 0:
                assert num_downsample == 1, f"unsupported num_downsample: {num_downsample}"
            stride = 1 if num_downsample == 0 else 2

            self.pt_down = SubsampleBlock(
                pt_width,
                out_channels=pt_width * 2,
                stride=stride,
                knn_samples=self.cfg.knn_samples,
                post_norm=self.cfg.post_norm,
                agg_func=self.cfg.fps_agg_func,
                subsample_method=self.cfg.subsample_method,
            )
            linear_in = pt_width * 2
            # TODO: add more pt blocks after downsampling

        if self.cfg.pt_head_concat_img:
            # The image is concatenated to the point features: 3 channels,
            # or 3 * latent_downsample ** 2 when it is pixel-unshuffled to
            # the latent resolution.
            if self.cfg.latent_gs:
                linear_in = linear_in + 3 - 3 + 3 * (self.cfg.latent_downsample ** 2)
            else:
                linear_in = linear_in + 3

        self.gaussian_head = nn.Sequential(
            nn.Linear(linear_in, init_gaussian_param_num),
            nn.GELU(),
            nn.Linear(init_gaussian_param_num, init_gaussian_param_num),
        )
    else:
        # Convolutional head operating on the dense feature map.
        self.gaussian_head = nn.Sequential(
            nn.Conv2d(head_in, init_gaussian_param_num,
                      3, 1, 1, padding_mode='replicate'),
            nn.GELU(),
            nn.Conv2d(init_gaussian_param_num,
                      init_gaussian_param_num, 3, 1, 1, padding_mode='replicate'),
        )

    # The first 4 * init_gaussian_multiple outputs (rotations) keep their
    # default random initialization; every remaining output row of the last
    # layer is zero-initialized (weights and biases).
    num_rotation_params = 4 * self.cfg.init_gaussian_multiple
    last_layer = self.gaussian_head[-1]
    nn.init.zeros_(last_layer.weight[num_rotation_params:])
    nn.init.zeros_(last_layer.bias[num_rotation_params:])

    # Scheduler for saving intermediate results during testing; it is left
    # as None here and is meant to be assigned later by the model wrapper.
    self.test_save_every: FrequencyScheduler | None = None
def _get_depth_predictor(self, cfg):
    """Create the ``MultiViewUniMatch`` depth network from the configuration.

    NOTE(review): some options are read from the ``cfg`` argument and others
    from ``self.cfg``; presumably both refer to the same object when this is
    called from ``__init__`` -- confirm before passing a different ``cfg``.
    """
    own_cfg = self.cfg
    predictor_kwargs = dict(
        num_scales=cfg.num_scales,
        upsample_factor=cfg.upsample_factor,
        lowest_feature_resolution=cfg.lowest_feature_resolution,
        num_depth_candidates=cfg.num_depth_candidates,
        vit_type=cfg.monodepth_vit_type,
        unet_channels=cfg.depth_unet_channels,
        grid_sample_disable_cudnn=cfg.grid_sample_disable_cudnn,
        sample_log_depth=own_cfg.sample_log_depth,
        bilinear_upsample_depth=own_cfg.bilinear_upsample_depth,
        no_upsample_depth=own_cfg.no_upsample_depth,
        use_amp=own_cfg.use_amp,
        # Raw mono features are requested whenever the DPT upsampler is off.
        return_raw_mono_features=not own_cfg.latent_dpt_upsampler,
        use_checkpointing=own_cfg.use_checkpointing,
    )
    return MultiViewUniMatch(**predictor_kwargs)
| def forward( | |
| self, | |
| context: BatchedViews, | |
| visualization_dump: Optional[dict] = None, | |
| **kwargs | |
| ) -> InitializerOutput: | |
| device = context["image"].device | |
| b, v, _, h, w = context["image"].shape | |
| if v > 3: | |
| with torch.no_grad(): | |
| xyzs = context["extrinsics"][:, :, :3, -1].detach() | |
| cameras_dist_matrix = torch.cdist(xyzs, xyzs, p=2) | |
| cameras_dist_index = torch.argsort(cameras_dist_matrix) | |
| cameras_dist_index = cameras_dist_index[:, :, :(self.cfg.local_mv_match + 1)] | |
| else: | |
| cameras_dist_index = None | |
| # depth prediction | |
| if self.cfg.depth_pred_half_res: | |
| half_img = rearrange(context["image"], "b v c h w -> (b v) c h w") | |
| half_img = F.interpolate(half_img, scale_factor=0.5, mode='bilinear', align_corners=True) | |
| half_img = rearrange(half_img, "(b v) c h w -> b v c h w", b=b, v=v) | |
| results_dict = self.depth_predictor( | |
| half_img, | |
| attn_splits_list=[2], | |
| min_depth=1. / context["far"], | |
| max_depth=1. / context["near"], | |
| intrinsics=context["intrinsics"], | |
| extrinsics=context["extrinsics"], | |
| nn_matrix=cameras_dist_index, | |
| ) | |
| # upsample depth to the original resolution | |
| for key in results_dict.keys(): | |
| # NOTE: no need to upsample depth since depth later is in the low resolution | |
| if key != 'depth_preds': | |
| for i in range(len(results_dict[key])): | |
| results_dict[key][i] = F.interpolate(results_dict[key][i], scale_factor=2, mode='bilinear', | |
| align_corners=True) | |
| # depthsplat: upsample depth to the original resolution | |
| if not self.cfg.latent_gs: | |
| for i in range(len(results_dict['depth_preds'])): | |
| results_dict['depth_preds'][i] = F.interpolate(results_dict['depth_preds'][i], scale_factor=2, | |
| mode='bilinear', align_corners=True) | |
| else: | |
| results_dict = self.depth_predictor( | |
| context["image"], | |
| attn_splits_list=[2], | |
| min_depth=1. / context["far"], | |
| max_depth=1. / context["near"], | |
| intrinsics=context["intrinsics"], | |
| extrinsics=context["extrinsics"], | |
| nn_matrix=cameras_dist_index, | |
| ) | |
| if self.cfg.use_gt_depth: | |
| # directly use gt depth as gaussian centers instead of learning them | |
| # to understand the bottleneck of the model | |
| assert 'depth' in context | |
| depth_preds = [context['depth']] | |
| else: | |
| # list of [B, V, H, W], with all the intermediate depths | |
| depth_preds = results_dict['depth_preds'] | |
| # [B, V, H, W] | |
| depth = depth_preds[-1] | |
| gaussian_scale_depth = None | |
| # features [BV, C, H, W] | |
| if self.cfg.latent_gs and not self.cfg.latent_dpt_upsampler: | |
| # concat all features | |
| assert self.cfg.num_scales == 1 | |
| # use pixelshuffle and pixelunshuffle to align all feature resolutions | |
| # first resize the mono features to 1/16 | |
| mono_features = [F.interpolate(x, size=(h // 16, w // 16), mode='bilinear', align_corners=True) for x in | |
| results_dict['raw_mono_features']] | |
| if self.cfg.fixed_latent_size: | |
| scale_factor = 4 | |
| mono_features = [F.pixel_shuffle(x, upscale_factor=scale_factor) for x in mono_features] | |
| mono_features = torch.cat(mono_features, dim=1) # channel: 384 / 16 * 4 | |
| if self.cfg.latent_downsample == 8: | |
| mono_features = F.interpolate(mono_features, scale_factor=0.5, mode='bilinear', align_corners=True) | |
| else: | |
| if self.cfg.latent_downsample == 4: | |
| scale_factor = 4 | |
| mono_features = [F.pixel_shuffle(x, upscale_factor=scale_factor) for x in mono_features] | |
| mono_features = torch.cat(mono_features, dim=1) # channel: 384 / 16 * 4 | |
| elif self.cfg.latent_downsample == 2: | |
| scale_factor = 8 | |
| mono_features = [F.pixel_shuffle(x, upscale_factor=scale_factor) for x in mono_features] | |
| mono_features = torch.cat(mono_features, dim=1) # channel: 384 / 64 * 4 | |
| elif self.cfg.latent_downsample == 8: | |
| scale_factor = 2 | |
| mono_features = [F.pixel_shuffle(x, upscale_factor=scale_factor) for x in mono_features] | |
| mono_features = torch.cat(mono_features, dim=1) # channel: 384 / 4 * 4 | |
| else: | |
| raise NotImplementedError | |
| cnn_features = results_dict["features_cnn_all_scales"][::-1] | |
| if self.cfg.latent_downsample == 2: | |
| # use pixel shuffle to save channels | |
| # 1/2, 1/2, 1/4 | |
| cnn_features[2] = F.pixel_shuffle(cnn_features[2], upscale_factor=2) | |
| # 64 + 96 + 128 // 4 | |
| cnn_features = torch.cat(cnn_features, dim=1) | |
| # 128 // 4 | |
| mv_features = results_dict["features_mv"][0] | |
| mv_features = F.pixel_shuffle(mv_features, upscale_factor=2) | |
| else: | |
| # resize all cnn features to the latent resolution | |
| target_h, target_w = h // self.cfg.latent_downsample, w // self.cfg.latent_downsample | |
| for i in range(len(cnn_features)): | |
| cnn_features[i] = F.interpolate(cnn_features[i], size=(target_h, target_w), mode='bilinear', | |
| align_corners=True) | |
| cnn_features = torch.cat(cnn_features, dim=1) | |
| mv_features = results_dict["features_mv"][0] | |
| if mv_features.shape[-2] != target_h or mv_features.shape[-1] != target_w: | |
| mv_features = F.interpolate(mv_features, size=(target_h, target_w), mode='bilinear', | |
| align_corners=True) | |
| features = torch.cat((mono_features, cnn_features, mv_features), dim=1) | |
| elif self.cfg.resizeconv_upsampler: | |
| features = self.feature_upsampler(results_dict["features_cnn"], | |
| results_dict["features_mv"], | |
| results_dict["features_mono"], | |
| ) | |
| else: | |
| with torch.amp.autocast(device_type='cuda', enabled=self.cfg.use_amp, dtype=torch.bfloat16): | |
| features = self.feature_upsampler(results_dict["features_mono_intermediate"], | |
| cnn_features=results_dict["features_cnn_all_scales"][::-1], | |
| mv_features=results_dict["features_mv"][ | |
| 0] if self.cfg.num_scales == 1 else results_dict["features_mv"][ | |
| ::-1] | |
| ) | |
| # match prob from softmax | |
| # [BV, D, H, W] in feature resolution | |
| match_prob = results_dict['match_probs'][-1] | |
| match_prob = torch.max(match_prob, dim=1, keepdim=True)[ | |
| 0] # [BV, 1, H, W] | |
| if not self.cfg.latent_gs: | |
| match_prob = F.interpolate( | |
| match_prob, size=depth.shape[-2:], mode='nearest') | |
| # unet input | |
| if self.cfg.latent_gs: | |
| img_unshuffle = rearrange(context["image"], "b v c h w -> (b v) c h w") | |
| if self.cfg.fixed_latent_size: | |
| if self.cfg.latent_downsample == 8: | |
| img_unshuffle = F.interpolate(img_unshuffle, scale_factor=0.5, mode='area') | |
| img_unshuffle = F.pixel_unshuffle(img_unshuffle, downscale_factor=4) | |
| else: | |
| img_unshuffle = F.pixel_unshuffle(img_unshuffle, downscale_factor=self.cfg.latent_downsample) | |
| # depth is in the full resolution, downsample to latent depth | |
| if self.cfg.depth_pred_half_res: | |
| latent_depth = F.interpolate(depth, scale_factor=1. / (self.cfg.latent_downsample // 2), | |
| mode='bilinear', align_corners=True) | |
| else: | |
| if self.cfg.no_upsample_depth: | |
| assert self.cfg.latent_downsample == 8 or self.cfg.latent_downsample == 4 | |
| if self.cfg.latent_downsample == 8: | |
| latent_depth = depth | |
| else: | |
| # 1/8 depth to 1/4 | |
| latent_depth = F.interpolate(depth, scale_factor=2, mode='bilinear', align_corners=True) | |
| else: | |
| if self.cfg.avgpool_depth: | |
| latent_depth = F.avg_pool2d(depth, kernel_size=self.cfg.latent_downsample, | |
| stride=self.cfg.latent_downsample) | |
| elif self.cfg.nearest_down_depth: | |
| latent_depth = F.interpolate(depth, scale_factor=1. / self.cfg.latent_downsample, | |
| mode='nearest') | |
| else: | |
| latent_depth = F.interpolate(depth, scale_factor=1. / self.cfg.latent_downsample, | |
| mode='bilinear', align_corners=True) | |
| if match_prob.shape[-2:] != latent_depth.shape[-2]: | |
| match_prob = F.interpolate( | |
| match_prob, size=latent_depth.shape[-2:], mode='nearest') | |
| concat = torch.cat(( | |
| img_unshuffle, | |
| rearrange(latent_depth, "b v h w -> (b v) () h w"), | |
| match_prob, | |
| features, | |
| ), dim=1) | |
| else: | |
| concat = torch.cat(( | |
| rearrange(context["image"], "b v c h w -> (b v) c h w"), | |
| rearrange(depth, "b v h w -> (b v) () h w"), | |
| match_prob, | |
| features, | |
| ), dim=1) | |
| if self.cfg.lvsm_gaussian_regressor: | |
| h, w = concat.shape[-2:] | |
| tmp = rearrange(concat, "(b v) c h w -> b (v h w) c", b=b, v=v) | |
| with torch.autocast('cuda', dtype=torch.bfloat16): | |
| out = self.gaussian_regressor(tmp) | |
| out = rearrange(out, "b (v h w) c -> (b v) c h w", b=b, v=v, h=h, w=w) | |
| else: | |
| with torch.amp.autocast(device_type='cuda', enabled=self.cfg.use_amp, dtype=torch.bfloat16): | |
| out = self.gaussian_regressor(concat) | |
| if self.cfg.latent_gs: | |
| concat = [out, img_unshuffle, features, match_prob] | |
| else: | |
| concat = [out, | |
| rearrange(context["image"], | |
| "b v c h w -> (b v) c h w"), | |
| features, | |
| match_prob] | |
| out = torch.cat(concat, dim=1) | |
| # [BV, C, H, W] | |
| condition_features = out | |
| init_scales = None | |
| if self.cfg.pt_head: | |
| if self.cfg.latent_gs: | |
| h, w = latent_depth.shape[-2:] | |
| else: | |
| h, w = depth.shape[-2:] | |
| with torch.amp.autocast(device_type='cuda', enabled=self.cfg.pt_head_amp, dtype=torch.bfloat16): | |
| tmp_feature = self.proj(rearrange(out, "bv c h w -> (bv h w) c")) | |
| # get point cloud | |
| xy_ray, _ = sample_image_grid((h, w), out.device) | |
| xy_ray = rearrange(xy_ray, "h w xy -> (h w) () xy") | |
| # [B, V, H*W, 1, 2] | |
| tmp_coords = xy_ray.unsqueeze(0).unsqueeze(0).repeat(b, v, 1, 1, 1) | |
| # [B, V, H*W, 1, 1] | |
| if self.cfg.latent_gs: | |
| tmp_depth = rearrange(latent_depth, "b v h w -> b v (h w) () ()") | |
| else: | |
| tmp_depth = rearrange(depth, "b v h w -> b v (h w) () ()") | |
| # [B, V, 1, 1, 4, 4] | |
| tmp_extrinsics = context["extrinsics"].unsqueeze(2).unsqueeze(2) | |
| # [B, V, 1, 1, 3, 3] | |
| tmp_intrinsics = context["intrinsics"].unsqueeze(2).unsqueeze(2) | |
| # [B, V, H*W, 1, 3] | |
| origins, directions = get_world_rays(tmp_coords, tmp_extrinsics, tmp_intrinsics) | |
| point_cloud = origins + directions * tmp_depth | |
| # Create offset directly on device to avoid CPU-GPU transfer | |
| offset = torch.arange(1, b + 1, device=depth.device, dtype=torch.long) * (v * h * w) | |
| point_cloud = rearrange(point_cloud, "b v h w c -> (b v h w) c") | |
| with torch.amp.autocast(device_type='cuda', enabled=self.cfg.pt_head_amp, dtype=torch.bfloat16): | |
| if self.cfg.add_pt_residual: | |
| out = tmp_feature + self.pt((point_cloud, tmp_feature, offset), b=b, v=v, h=h, w=w) | |
| else: | |
| out = self.pt((point_cloud, tmp_feature, offset), b=b, v=v, h=h, w=w) | |
| condition_features = rearrange(out, "(bv h w) c -> bv c h w", h=h, w=w) | |
| if self.cfg.pt_downsample > 0: | |
| out, fps_idx = self.pt_down((point_cloud, out, offset)) | |
| # [N, 3] | |
| point_cloud, out, offset = out | |
| with torch.amp.autocast(device_type='cuda', enabled=self.cfg.pt_head_amp, dtype=torch.bfloat16): | |
| if self.cfg.pt_head_concat_img: | |
| if self.cfg.latent_gs: | |
| # pixel unshuffle image | |
| img_unshuffle = rearrange(context["image"], "b v c h w -> (b v) c h w") | |
| img_unshuffle = F.pixel_unshuffle(img_unshuffle, downscale_factor=self.cfg.latent_downsample) | |
| img_unshuffle = rearrange(img_unshuffle, "(b v) c h w -> (b v h w) c", b=b, v=v) | |
| out = torch.cat((out, img_unshuffle), dim=-1) | |
| if self.cfg.pt_head_conv: | |
| out = rearrange(out, "(b v h w) c -> (b v) c h w", b=b, v=v, h=h, w=w) | |
| out = self.gaussian_head(out) | |
| if self.cfg.pt_head_conv: | |
| out = rearrange(out, "(b v) c h w -> (b v h w) c", b=b, v=v) | |
| if self.cfg.pt_downsample > 0: | |
| # [N, C] | |
| gaussians = out | |
| else: | |
| if self.cfg.pt_pred_residual_position: | |
| # TODO: add intermediate supervision to the initial point cloud | |
| # TODO: multiple scale factor to the delta position to make it more stable | |
| # residual position | |
| point_cloud = point_cloud + out[..., -3:] # [BVHW, 3] | |
| # remaining gaussians | |
| out = out[..., :-3] | |
| point_cloud = rearrange(point_cloud, "(b v h w) c -> b v (h w) () () c", b=b, v=v, h=h, w=w) | |
| gaussians = rearrange(out, "(b v h w) c -> (b v) c h w", b=b, h=h, w=w) | |
| else: | |
| with torch.amp.autocast(device_type='cuda', enabled=self.cfg.use_amp, dtype=torch.bfloat16): | |
| gaussians = self.gaussian_head(out) # [BV, C, H, W] | |
| # [BV, C, H, W] | |
| gaussians = gaussians.float() | |
| if self.cfg.latent_gs: | |
| if self.cfg.init_gaussian_multiple > 1: | |
| # hard coded for now | |
| if self.cfg.init_gaussian_multiple == 4: | |
| # TODO: try avgpooling downsampling depth | |
| if self.cfg.latent_downsample == 4: | |
| # resize full resolution depth | |
| depths = F.interpolate(depth, scale_factor=0.5, mode='bilinear', align_corners=True) | |
| elif self.cfg.latent_downsample == 8: | |
| depths = F.interpolate(depth, scale_factor=0.25, mode='bilinear', align_corners=True) | |
| elif self.cfg.latent_downsample == 2: | |
| depths = depth | |
| else: | |
| raise NotImplementedError | |
| elif self.cfg.init_gaussian_multiple == 16: | |
| # TODO: try avgpooling downsampling depth | |
| if self.cfg.latent_downsample == 4: | |
| depths = depth | |
| elif self.cfg.latent_downsample == 8: | |
| depths = F.interpolate(depth, scale_factor=0.5, mode='bilinear', align_corners=True) | |
| else: | |
| raise NotImplementedError | |
| else: | |
| raise NotImplementedError | |
| depths = rearrange(depths, "b v h w -> b v (h w) () ()") | |
| else: | |
| depths = rearrange(latent_depth, "b v h w -> b v (h w) () ()") | |
| else: | |
| depths = rearrange(depth, "b v h w -> b v (h w) () ()") | |
| if self.cfg.pt_downsample > 0: | |
| # split batch | |
| assert offset.shape[0] == b | |
| if self.cfg.latent_gs: | |
| sh_input_images = rearrange(context["image"], "b v c h w -> (b v) c h w") | |
| if self.cfg.latent_gs_img_interp == 'bicubic': | |
| sh_input_images = F.interpolate(sh_input_images, scale_factor=1. / self.cfg.latent_downsample, | |
| mode='bicubic', align_corners=True) | |
| elif self.cfg.latent_gs_img_interp == 'area': | |
| sh_input_images = F.interpolate(sh_input_images, scale_factor=1. / self.cfg.latent_downsample, | |
| mode='area') | |
| elif self.cfg.latent_gs_img_interp == 'softmax': | |
| sh_input_images = self.softmax_downsample(sh_input_images) | |
| else: | |
| raise NotImplementedError | |
| h, w = sh_input_images.shape[-2:] | |
| sh_input_images = rearrange(sh_input_images, "(b v) c h w -> b v c h w", b=b, v=v) | |
| else: | |
| sh_input_images = context["image"] | |
| sh_input_images = rearrange(sh_input_images, "b v c h w -> (b v h w) c") | |
| # subsample with fps index | |
| sh_input_images = sh_input_images[fps_idx.long(), :] # [N, 3] | |
| # extrinsics | |
| extrinsics_all = rearrange(repeat(context["extrinsics"], "b v i j -> b v h w i j", h=h, w=w), | |
| "b v h w i j -> (b v h w) i j" | |
| ) | |
| extrinsics_all = extrinsics_all[fps_idx.long(), :, :] # [N, 4, 4] | |
| point_list = [point_cloud[:offset[0]]] | |
| gaussian_list = [gaussians[:offset[0]]] | |
| sh_img_list = [sh_input_images[:offset[0]]] | |
| extrinsics_list = [extrinsics_all[:offset[0]]] | |
| for i in range(b - 1): | |
| point_list.append(point_cloud[offset[i]:offset[i + 1]]) | |
| gaussian_list.append(gaussians[offset[i]:offset[i + 1]]) | |
| sh_img_list.append(sh_input_images[offset[i]:offset[i + 1]]) | |
| extrinsics_list.append(extrinsics_all[offset[i]:offset[i + 1]]) | |
| point_cloud = torch.stack(point_list, dim=0) # [B, N, 3] | |
| gaussians = torch.stack(gaussian_list, dim=0) # [B, N, C] | |
| sh_imgs = torch.stack(sh_img_list, dim=0) # [B, N, 3] | |
| extrinsics_all = torch.stack(extrinsics_list, dim=0) # [B, N, 4, 4] | |
| # point_cloud = [point_cloud[offset[i]:offset[i+1]] for i in range(b)] | |
| # point_cloud = torch.stack(point_cloud, dim=0) # [B, N, 3] | |
| # gaussians = [gaussians[offset[i]:offset[i+1]] for i in range(b)] | |
| # gaussians = torch.stack(gaussians, dim=0) # [B, N, 3] | |
| opacities = gaussians[..., 0].sigmoid() # [B, N] | |
| gaussians = self.gaussian_adapter.forward( | |
| extrinsics=extrinsics_all, | |
| intrinsics=None, | |
| coordinates=None, | |
| depths=None, | |
| opacities=opacities, | |
| raw_gaussians=gaussians[..., 1:], | |
| image_shape=None, | |
| point_cloud=point_cloud, | |
| input_images=sh_imgs, | |
| ) | |
| gaussians = rearrange(gaussians, "(b v) c h w -> b v c h w", b=b, v=v) | |
| # [B, V, H*W, 84] | |
| raw_gaussians = rearrange( | |
| gaussians, "b v c h w -> b v (h w) c") | |
| assert len(depth_preds) == 1, "num_scales must be 1; multi-scale depth supervision is not supported" | |
| # [B, V, H*W, C] | |
| repeat = self.cfg.init_gaussian_multiple | |
| num_sh = self.gaussian_adapter.d_sh | |
| if self.cfg.no_pixel_offset: | |
| rotations_unnorm, scales, opacities_raw, sh = raw_gaussians.split( | |
| [4 * repeat, 3 * repeat, 1 * repeat, 3 * num_sh * repeat], | |
| dim=-1, | |
| ) | |
| else: | |
| rotations_unnorm, scales, opacities_raw, offset, sh = raw_gaussians.split( | |
| [4 * repeat, 3 * repeat, 1 * repeat, 2 * repeat, 3 * num_sh * repeat], | |
| dim=-1, | |
| ) | |
| latent_h, latent_w = gaussians.shape[-2:] | |
| if repeat > 1: | |
| # reshape all the gaussian parameters | |
| if True or self.cfg.latent_new_reshape: | |
| # this works | |
| r = int(np.sqrt(repeat)) | |
| rotations_unnorm = rearrange(rotations_unnorm, "b v (h w) (c x y) -> b v (h x w y) c", | |
| h=latent_h, w=latent_w, x=r, y=r) | |
| scales = rearrange(scales, "b v (h w) (c x y) -> b v (h x w y) c", h=latent_h, w=latent_w, x=r, | |
| y=r) | |
| opacities_raw = rearrange(opacities_raw, "b v (h w) (c x y) -> b v (h x w y) c", h=latent_h, | |
| w=latent_w, x=r, y=r) | |
| offset = rearrange(offset, "b v (h w) (c x y) -> b v (h x w y) c", h=latent_h, w=latent_w, x=r, | |
| y=r) | |
| sh = rearrange(sh, "b v (h w) (c x y) -> b v (h x w y) c", h=latent_h, w=latent_w, x=r, y=r) | |
| else: | |
| # doesn't work | |
| rotations_unnorm = rearrange(rotations_unnorm, "b v hw (k c) -> b v (hw k) c", k=repeat) | |
| scales = rearrange(scales, "b v hw (k c) -> b v (hw k) c", k=repeat) | |
| opacities_raw = rearrange(opacities_raw, "b v hw (k c) -> b v (hw k) c", k=repeat) | |
| offset = rearrange(offset, "b v hw (k c) -> b v (hw k) c", k=repeat) | |
| sh = rearrange(sh, "b v hw (k c) -> b v (hw k) c", k=repeat) | |
| opacities = opacities_raw.sigmoid() # [B, V, H*W*K, 1] | |
| if self.cfg.latent_downsample == 4 and self.cfg.init_gaussian_multiple == 4: | |
| scale_factor = 2 | |
| elif self.cfg.latent_downsample == 2 and self.cfg.init_gaussian_multiple == 4: | |
| scale_factor = 2 | |
| elif self.cfg.latent_downsample == 4 and self.cfg.init_gaussian_multiple == 16: | |
| scale_factor = 4 | |
| elif self.cfg.latent_downsample == 8 and self.cfg.init_gaussian_multiple == 4: | |
| scale_factor = 2 | |
| elif self.cfg.latent_downsample == 8 and self.cfg.init_gaussian_multiple == 16: | |
| scale_factor = 4 | |
| else: | |
| scale_factor = 1 | |
| h, w = latent_h * scale_factor, latent_w * scale_factor | |
| # unproject depth | |
| xy_ray, _ = sample_image_grid((h, w), device) # [H, W, 2] in [0, 1] | |
| xy_ray = rearrange(xy_ray, "h w xy -> (h w) () xy") # [H*W, 1, 2] | |
| if self.cfg.no_pixel_offset: | |
| offset_xy = torch.ones_like(raw_gaussians[..., :2]).unsqueeze(-2).to( | |
| raw_gaussians.device) * 0.5 # [B, V, H*W, 1, 2] | |
| else: | |
| offset_xy = offset.sigmoid().unsqueeze(-2) # [B, V, H*W, 1, 2] | |
| pixel_size = 1 / \ | |
| torch.tensor((w, h), dtype=torch.float32, device=device) | |
| # [H*W, 1, 2] | |
| if self.cfg.deform_sample_depth and not self.cfg.deform_sample_depth_debug: | |
| # (offset_xy - 0.5) in -0.5 to 0.5, without multiplying by pixel size such that the points can move in the image space | |
| xy_ray = (xy_ray + (offset_xy - 0.5)).clamp(min=0., max=1.) | |
| else: | |
| xy_ray = xy_ray + (offset_xy - 0.5) * pixel_size | |
| if self.cfg.deform_sample_depth: | |
| # use low-res xy_ray to sample full-res depth | |
| sample_grid = rearrange(xy_ray, "b v (h w) c xy -> (b v) h w (c xy)", h=h, w=w) # in [0, 1] | |
| # to [-1, 1] | |
| sample_grid = 2 * (sample_grid - 0.5) # [BV, h, w, 2] | |
| fullres_depth = rearrange(depth, "b v h w -> (b v) () h w") # [BV, 1, H, W] | |
| sampled_depth = F.grid_sample(fullres_depth, sample_grid, mode='bilinear', align_corners=True, | |
| padding_mode="border") # [BV, 1, h, w] | |
| # reshape | |
| depths = rearrange(sampled_depth, "(b v) () h w -> b v (h w) () ()", b=b, v=v, h=h, w=w) | |
| if self.cfg.latent_gs: | |
| sh_input_images = rearrange(context["image"], "b v c h w -> (b v) c h w") | |
| if self.cfg.latent_downsample == 4 and self.cfg.init_gaussian_multiple == 4: | |
| sh_input_images = F.interpolate(sh_input_images, scale_factor=0.5, mode='area') | |
| elif self.cfg.latent_downsample == 4 and self.cfg.init_gaussian_multiple == 16: | |
| pass | |
| elif self.cfg.latent_downsample == 2 and self.cfg.init_gaussian_multiple == 4: | |
| pass | |
| elif self.cfg.latent_downsample == 8 and self.cfg.init_gaussian_multiple == 4: | |
| sh_input_images = F.interpolate(sh_input_images, scale_factor=0.25, mode='area') | |
| elif self.cfg.latent_downsample == 8 and self.cfg.init_gaussian_multiple == 16: | |
| sh_input_images = F.interpolate(sh_input_images, scale_factor=0.5, mode='area') | |
| else: | |
| sh_input_images = F.interpolate(sh_input_images, scale_factor=1. / self.cfg.latent_downsample, | |
| mode='area') | |
| sh_input_images = rearrange(sh_input_images, "(b v) c h w -> b v c h w", b=b, v=v) | |
| else: | |
| sh_input_images = context["image"] | |
| assert len(depth_preds) == 1, "num_scales must be 1; multi-scale depth supervision is not supported" | |
| # build gaussians | |
| # scale | |
| scales = torch.clamp(F.softplus(scales - self.cfg.gaussian_adapter.exp_scale_bias), | |
| min=self.cfg.gaussian_adapter.clamp_min_scale, | |
| max=self.cfg.gaussian_adapter.gaussian_scale_max | |
| ) | |
| # Normalize the quaternion features to yield a valid quaternion. | |
| # rotations = rotations_unnorm / (rotations_unnorm.norm(dim=-1, keepdim=True) + 1e-8) | |
| # Convert rotations to world-space | |
| c2w_rotations = context["extrinsics"][..., :3, :3].unsqueeze(2) # [B, V, 1, 3, 3] | |
| rotations = rotate_quats(c2w_rotations, rotations_unnorm) | |
| rotations_unnorm = rotations.clone() | |
| # Create world-space covariance matrices. | |
| covariances = build_covariance(scale=scales, rotation_xyzw=rotations) # [B, V, H*W, 3, 3] | |
| # means | |
| # [B, V, H*W, 1, 2] | |
| # xy_ray = xy_ray.unsqueeze(0).unsqueeze(0).repeat(b, v, 1, 1, 1) | |
| origins, directions = get_world_rays(xy_ray, | |
| context["extrinsics"].unsqueeze(2).unsqueeze(2), | |
| context["intrinsics"].unsqueeze(2).unsqueeze(2)) | |
| means = origins + directions * depths | |
| # sh: [B, V, HW, 3, SH] | |
| sh = rearrange(sh, "... (xyz d_sh) -> ... xyz d_sh", xyz=3).clone() | |
| # sh = sh.broadcast_to((*opacities.shape, 3, self.gaussian_adapter.d_sh)).clone() | |
| # [B, V, H*W, 3] | |
| sh_input_images = rearrange(sh_input_images, "b v c h w -> b v (h w) c") | |
| # init sh with input images | |
| sh[..., 0] = sh[..., 0] + RGB2SH(sh_input_images) | |
| gaussians = Gaussians( | |
| means=rearrange(means, "b v r spp xyz -> b (v r spp) xyz"), | |
| covariances=rearrange(covariances, "b v r i j -> b (v r) i j"), | |
| harmonics=rearrange(sh, "b v r c d_sh -> b (v r) c d_sh"), | |
| opacities=rearrange(opacities, "b v r spp -> b (v r spp)"), | |
| scales=rearrange(scales, "b v r xyz -> b (v r) xyz"), | |
| rotations=rearrange(rotations, "b v r wxyz -> b (v r) wxyz"), # in wxyz format | |
| rotations_unnorm=rearrange(rotations_unnorm, "b v r wxyz -> b (v r) wxyz") # in wxyz format | |
| ) | |
| else: | |
| gaussians = rearrange(gaussians, "(b v) c h w -> b v c h w", b=b, v=v) | |
| # [B, V, H*W, 84] | |
| raw_gaussians = rearrange( | |
| gaussians, "b v c h w -> b v (h w) c") | |
| assert len(depth_preds) == 1, "num_scales must be 1; multi-scale depth supervision is not supported" | |
| # [B, V, H*W, C] | |
| repeat = self.cfg.init_gaussian_multiple | |
| num_sh = self.gaussian_adapter.d_sh | |
| if self.cfg.no_pixel_offset: | |
| rotations_unnorm, scales, opacities_raw, sh = raw_gaussians.split( | |
| [4 * repeat, 3 * repeat, 1 * repeat, 3 * num_sh * repeat], | |
| dim=-1, | |
| ) | |
| else: | |
| rotations_unnorm, scales, opacities_raw, offset, sh = raw_gaussians.split( | |
| [4 * repeat, 3 * repeat, 1 * repeat, 2 * repeat, 3 * num_sh * repeat], | |
| dim=-1, | |
| ) | |
| latent_h, latent_w = gaussians.shape[-2:] | |
| if repeat > 1: | |
| # reshape all the gaussian parameters | |
| if True or self.cfg.latent_new_reshape: | |
| # this works | |
| r = int(np.sqrt(repeat)) | |
| rotations_unnorm = rearrange(rotations_unnorm, "b v (h w) (c x y) -> b v (h x w y) c", | |
| h=latent_h, w=latent_w, x=r, y=r) | |
| scales = rearrange(scales, "b v (h w) (c x y) -> b v (h x w y) c", h=latent_h, w=latent_w, x=r, | |
| y=r) | |
| opacities_raw = rearrange(opacities_raw, "b v (h w) (c x y) -> b v (h x w y) c", h=latent_h, | |
| w=latent_w, x=r, y=r) | |
| offset = rearrange(offset, "b v (h w) (c x y) -> b v (h x w y) c", h=latent_h, w=latent_w, x=r, | |
| y=r) | |
| sh = rearrange(sh, "b v (h w) (c x y) -> b v (h x w y) c", h=latent_h, w=latent_w, x=r, y=r) | |
| else: | |
| # doesn't work | |
| rotations_unnorm = rearrange(rotations_unnorm, "b v hw (k c) -> b v (hw k) c", k=repeat) | |
| scales = rearrange(scales, "b v hw (k c) -> b v (hw k) c", k=repeat) | |
| opacities_raw = rearrange(opacities_raw, "b v hw (k c) -> b v (hw k) c", k=repeat) | |
| offset = rearrange(offset, "b v hw (k c) -> b v (hw k) c", k=repeat) | |
| sh = rearrange(sh, "b v hw (k c) -> b v (hw k) c", k=repeat) | |
| opacities = opacities_raw.sigmoid() # [B, V, H*W*K, 1] | |
| if self.cfg.latent_downsample == 4 and self.cfg.init_gaussian_multiple == 4: | |
| scale_factor = 2 | |
| elif self.cfg.latent_downsample == 2 and self.cfg.init_gaussian_multiple == 4: | |
| scale_factor = 2 | |
| elif self.cfg.latent_downsample == 4 and self.cfg.init_gaussian_multiple == 16: | |
| scale_factor = 4 | |
| elif self.cfg.latent_downsample == 8 and self.cfg.init_gaussian_multiple == 4: | |
| scale_factor = 2 | |
| elif self.cfg.latent_downsample == 8 and self.cfg.init_gaussian_multiple == 16: | |
| scale_factor = 4 | |
| else: | |
| scale_factor = 1 | |
| h, w = latent_h * scale_factor, latent_w * scale_factor | |
| # unproject depth | |
| xy_ray, _ = sample_image_grid((h, w), device) | |
| xy_ray = rearrange(xy_ray, "h w xy -> (h w) () xy") | |
| if self.cfg.no_pixel_offset: | |
| offset_xy = torch.ones_like(raw_gaussians[..., :2]).unsqueeze(-2).to( | |
| raw_gaussians.device) * 0.5 # [B, V, H*W, 1, 2] | |
| else: | |
| offset_xy = offset.sigmoid().unsqueeze(-2) # [B, V, H*W, 1, 2] | |
| pixel_size = 1 / \ | |
| torch.tensor((w, h), dtype=torch.float32, device=device) | |
| # [H*W, 1, 2] | |
| if self.cfg.deform_sample_depth and not self.cfg.deform_sample_depth_debug: | |
| # (offset_xy - 0.5) in -0.5 to 0.5, without multiplying by pixel size such that the points can move in the image space | |
| xy_ray = (xy_ray + (offset_xy - 0.5)).clamp(min=0., max=1.) | |
| else: | |
| xy_ray = xy_ray + (offset_xy - 0.5) * pixel_size | |
| if self.cfg.deform_sample_depth: | |
| # use low-res xy_ray to sample full-res depth | |
| sample_grid = rearrange(xy_ray, "b v (h w) c xy -> (b v) h w (c xy)", h=h, w=w) # in [0, 1] | |
| # to [-1, 1] | |
| sample_grid = 2 * (sample_grid - 0.5) # [BV, h, w, 2] | |
| fullres_depth = rearrange(depth, "b v h w -> (b v) () h w") # [BV, 1, H, W] | |
| sampled_depth = F.grid_sample(fullres_depth, sample_grid, mode='bilinear', align_corners=True, | |
| padding_mode="border") # [BV, 1, h, w] | |
| # reshape | |
| depths = rearrange(sampled_depth, "(b v) () h w -> b v (h w) () ()", b=b, v=v, h=h, w=w) | |
| if self.cfg.latent_gs: | |
| sh_input_images = rearrange(context["image"], "b v c h w -> (b v) c h w") | |
| if self.cfg.latent_downsample == 4 and self.cfg.init_gaussian_multiple == 4: | |
| sh_input_images = F.interpolate(sh_input_images, scale_factor=0.5, mode='area') | |
| elif self.cfg.latent_downsample == 4 and self.cfg.init_gaussian_multiple == 16: | |
| pass | |
| elif self.cfg.latent_downsample == 2 and self.cfg.init_gaussian_multiple == 4: | |
| pass | |
| elif self.cfg.latent_downsample == 8 and self.cfg.init_gaussian_multiple == 4: | |
| sh_input_images = F.interpolate(sh_input_images, scale_factor=0.25, mode='area') | |
| elif self.cfg.latent_downsample == 8 and self.cfg.init_gaussian_multiple == 16: | |
| sh_input_images = F.interpolate(sh_input_images, scale_factor=0.5, mode='area') | |
| else: | |
| sh_input_images = F.interpolate(sh_input_images, scale_factor=1. / self.cfg.latent_downsample, | |
| mode='area') | |
| sh_input_images = rearrange(sh_input_images, "(b v) c h w -> b v c h w", b=b, v=v) | |
| else: | |
| sh_input_images = context["image"] | |
| assert len(depth_preds) == 1, "num_scales must be 1; multi-scale depth supervision is not supported" | |
| # build gaussians | |
| # scale | |
| scales = torch.clamp(F.softplus(scales - self.cfg.gaussian_adapter.exp_scale_bias), | |
| min=self.cfg.gaussian_adapter.clamp_min_scale, | |
| max=self.cfg.gaussian_adapter.gaussian_scale_max | |
| ) | |
| # Convert rotations to world-space | |
| c2w_rotations = context["extrinsics"][..., :3, :3].unsqueeze(2) # [B, V, 1, 3, 3] | |
| # Here quaternions follow the xyzw format (scalar last) | |
| rotations = rotate_quats(c2w_rotations, rotations_unnorm) | |
| rotations_unnorm = rotations.clone() | |
| # Create world-space covariance matrices. | |
| covariances = build_covariance(scale=scales, rotation_xyzw=rotations) # [B, V, H*W, 3, 3] | |
| # means | |
| # [B, V, H*W, 1, 2] | |
| origins, directions = get_world_rays(xy_ray, | |
| context["extrinsics"].unsqueeze(2).unsqueeze(2), | |
| context["intrinsics"].unsqueeze(2).unsqueeze(2)) | |
| means = origins + directions * depths | |
| # sh: [B, V, HW, 3, SH] | |
| sh = rearrange(sh, "... (xyz d_sh) -> ... xyz d_sh", xyz=3).clone() | |
| # [B, V, H*W, 3] | |
| sh_input_images = rearrange(sh_input_images, "b v c h w -> b v (h w) c") | |
| # init sh with input images | |
| sh[..., 0] = sh[..., 0] + RGB2SH(sh_input_images) | |
| gaussians = Gaussians( | |
| means=rearrange(means, "b v r spp xyz -> b (v r spp) xyz"), | |
| covariances=rearrange(covariances, "b v r i j -> b (v r) i j"), | |
| harmonics=rearrange(sh, "b v r c d_sh -> b (v r) c d_sh"), | |
| opacities=rearrange(opacities, "b v r spp -> b (v r spp)"), | |
| scales=rearrange(scales, "b v r xyz -> b (v r) xyz"), | |
| rotations=rearrange(rotations, "b v r wxyz -> b (v r) wxyz"), | |
| rotations_unnorm=rearrange(rotations_unnorm, "b v r wxyz -> b (v r) wxyz") | |
| ) | |
| # Dump visualizations if needed. | |
| if visualization_dump is not None: | |
| visualization_dump["depth"] = rearrange( | |
| depths, "b v (h w) srf s -> b v h w srf s", h=h, w=w | |
| ) | |
| # if self.cfg.pt_downsample > 0: | |
| # visualization_dump["scales"] = gaussians.scales | |
| # visualization_dump["rotations"] = gaussians.rotations | |
| # else: | |
| # visualization_dump["scales"] = rearrange( | |
| # gaussians.scales, "b v r srf spp xyz -> b (v r srf spp) xyz" | |
| # ) | |
| # visualization_dump["rotations"] = rearrange( | |
| # gaussians.rotations, "b v r srf spp xyzw -> b (v r srf spp) xyzw" | |
| # ) | |
| if self.cfg.return_depth: | |
| # return depth prediction for supervision | |
| depths = depth_preds[-1] | |
| if self.cfg.return_lowres_depth: | |
| assert latent_depth is not None | |
| depths = latent_depth | |
| else: | |
| if depths.shape[-2:] != context["image"].shape[-2:]: | |
| # depths can be at lower resolution since we predict latent | |
| depths = F.interpolate( | |
| depths, size=context["image"].shape[-2:], mode='bilinear', align_corners=True) | |
| return InitializerOutput( | |
| gaussians=gaussians, | |
| depths=depths, | |
| features=condition_features | |
| ) | |
| else: | |
| return InitializerOutput( | |
| gaussians=gaussians, | |
| features=condition_features | |
| ) | |
def get_data_shim(self) -> DataShim:
    """Build the batch pre-processing hook for this initializer.

    The returned callable multiplies ``cfg.shim_patch_size`` by
    ``cfg.downscale_factor`` (element-wise when the patch size is a
    sequence rather than an ``int``) and forwards the batch to
    ``apply_patch_shim`` with the resulting patch size.
    """

    def data_shim(batch: BatchedExample) -> BatchedExample:
        base_patch = self.cfg.shim_patch_size
        factor = self.cfg.downscale_factor
        if isinstance(base_patch, int):
            # Scalar patch size: a single multiplication is enough.
            scaled_patch = base_patch * factor
        else:
            # Per-axis patch size: scale every entry.
            scaled_patch = [side * factor for side in base_patch]
        return apply_patch_shim(batch, patch_size=scaled_patch)

    return data_shim
@staticmethod
def update_gt_depth_range(batch):
    """Overwrite the near/far planes of ``batch`` from its ground-truth depth.

    For both ``batch["context"]`` and ``batch["target"]`` the ``"depth"``
    tensor is reduced over dim 3 and then dim 2 (presumably the spatial
    H/W axes of a [B, V, H, W] map -- confirm against the dataset) and the
    result is stored in place under ``"near"`` (minimum, clamped to at
    least 0.01) and ``"far"`` (maximum, clamped to at most 1000).

    Declared as a ``staticmethod`` because it takes no ``self`` but is
    invoked as ``self.update_gt_depth_range(batch)``; without the decorator
    that bound call fails with a ``TypeError``.

    Args:
        batch: mapping with ``"context"`` and ``"target"`` entries, each
            holding a ``"depth"`` tensor with at least four dimensions.

    Raises:
        AssertionError: if the context views carry no ``"depth"`` entry.
    """
    assert "depth" in batch["context"]
    for split in ("context", "target"):
        views = batch[split]
        depth = views["depth"]
        # Reduce dim 3 first, then dim 2, so that one value per leading
        # index remains.
        views["near"] = depth.min(dim=3)[0].min(dim=2)[0].clamp(min=0.01)
        views["far"] = depth.max(dim=3)[0].max(dim=2)[0].clamp(max=1000.)
def update_depth_range_from_disparity(self, batch):
    """Derive the context near/far planes from a disparity range.

    Depth is computed as ``baseline * focal / disparity``: the baseline is
    the distance between the translation columns of the two context
    extrinsics, and the focal length is ``intrinsics[..., 0, 0]`` times the
    image width. ``train_cfg.max_disparity`` yields the near plane and
    ``train_cfg.min_disparity`` the far plane; both are written in place
    into ``batch["context"]``. Target near/far are left untouched.

    Raises:
        AssertionError: if there are not exactly two context views, or if
            ``self.decoder.cfg.scale_invariant`` is not ``False``.
    """
    context_views = batch["context"]
    # The image tensor must be 5-D; only the view count and width are used.
    _, num_views, _, _, width = context_views["image"].shape

    # TODO: support multi-view later
    assert num_views == 2
    assert self.decoder.cfg.scale_invariant is False

    # Baseline between the two context cameras, kept as a column vector so
    # it broadcasts against the per-view focal lengths.
    translations = context_views["extrinsics"][:, :, :3, 3]
    baseline = (translations[:, 0] - translations[:, 1]).norm(dim=1, keepdim=True)

    # Focal length scaled by the image width (intrinsics presumably
    # normalized by image size -- confirm).
    focal_px = context_views["intrinsics"][:, :, 0, 0] * width

    # Large disparity <-> close surface, small disparity <-> far surface.
    nearest = baseline * focal_px / self.train_cfg.max_disparity
    farthest = baseline * focal_px / self.train_cfg.min_disparity

    context_views["near"] = nearest
    context_views["far"] = farthest
    # TODO: also update target near and far
def predict_scale(self, batch):
    """Re-estimate the depth range (and optionally the scene scale) of ``batch``.

    Runs ``self.encoder.scale_predictor`` on the context views and takes the
    last entry of its ``'depth_preds'`` output as an initial depth map.
    Depending on the encoder config:

    * unless ``no_pred_depth_range`` is set, near/far of the context views
      are replaced by the per-view min/max of that depth (clamped to
      [0.1, 100]), and near/far of the target views by the min/max across
      context views;
    * if ``norm_by_points`` is set, the depth is unprojected to a point
      cloud and near, far and the extrinsics translations of both context
      and target are divided by the mean point norm per batch element.

    ``batch`` is modified in place; nothing is returned.
    """
    ctx = batch["context"]
    tgt = batch["target"]

    prediction = self.encoder.scale_predictor(
        ctx["image"],
        attn_splits_list=[2],
        # The predictor receives reciprocal bounds: 1/far as the minimum
        # and 1/near as the maximum.
        min_depth=1. / ctx["far"],
        max_depth=1. / ctx["near"],
        intrinsics=ctx["intrinsics"],
        extrinsics=ctx["extrinsics"],
    )
    # [B, V, H, W]
    init_depth = prediction['depth_preds'][-1]

    if not self.encoder.cfg.no_pred_depth_range:
        # Per-view extrema over dims 3 and 2 -> [B, V]
        view_min = init_depth.min(dim=3).values.min(dim=2).values
        view_max = init_depth.max(dim=3).values.max(dim=2).values
        new_near = view_min.clamp(min=0.1)
        new_far = view_max.clamp(max=100.)
        ctx["near"] = new_near
        ctx["far"] = new_far

        # Every target view shares the widest range seen by the context views.
        num_target_views = tgt["near"].shape[1]
        tgt["near"] = new_near.min(dim=1, keepdim=True).values.repeat(1, num_target_views)
        tgt["far"] = new_far.max(dim=1, keepdim=True).values.repeat(1, num_target_views)

    if self.encoder.cfg.norm_by_points:
        b, v, h, w = init_depth.shape

        # Unproject the predicted depth into a point cloud.
        pixel_grid, _ = sample_image_grid((h, w), ctx["image"].device)
        pixel_grid = rearrange(pixel_grid, "h w xy -> (h w) () xy")
        # [B, V, H*W, 1, 2]
        ray_coords = pixel_grid[None, None].repeat(b, v, 1, 1, 1)
        # [B, V, H*W, 1, 1]
        ray_depth = rearrange(init_depth, "b v h w -> b v (h w) () ()")
        # Camera matrices get two singleton axes to broadcast over rays:
        # [B, V, 1, 1, 4, 4] and [B, V, 1, 1, 3, 3]
        ray_extrinsics = ctx["extrinsics"][:, :, None, None]
        ray_intrinsics = ctx["intrinsics"][:, :, None, None]
        # [B, V, H*W, 1, 3]
        origins, directions = get_world_rays(ray_coords, ray_extrinsics, ray_intrinsics)
        points = origins + directions * ray_depth
        points = rearrange(points, "b v h w c -> b (v h w) c")

        # Mean distance of the points from the origin, one value per batch
        # element, guarded against division by zero.  [B]
        norm_factor = points.norm(dim=-1).mean(dim=-1).clamp(min=1e-6)

        # normalize near, far and extrinsics
        range_scale = norm_factor.view(b, 1)
        for split in ("context", "target"):
            for bound in ("near", "far"):
                batch[split][bound] = batch[split][bound] / range_scale

        # Translations are rescaled in place on the existing tensors.
        translation_scale = norm_factor.view(b, 1, 1)
        batch["context"]["extrinsics"][:, :, :3, -1] /= translation_scale
        batch["target"]["extrinsics"][:, :, :3, -1] /= translation_scale
def preprocessing(self, batch, train_cfg):
    """Apply the configured depth-range adjustments to ``batch`` in order.

    Each step is optional and is switched on by a config flag:

    1. ``train_cfg.use_gt_depth_range`` -> ``update_gt_depth_range``
    2. ``train_cfg.depth_range_from_disparity`` ->
       ``update_depth_range_from_disparity``
    3. ``self.cfg.predict_scale`` -> ``predict_scale``

    Every step receives ``batch`` as its only argument; nothing is returned.
    """
    # Step 1: replace the fixed depth range with the ground-truth one.
    if train_cfg.use_gt_depth_range:
        self.update_gt_depth_range(batch)

    # Step 2: depth range from camera distance and the disparity range.
    if train_cfg.depth_range_from_disparity:
        self.update_depth_range_from_disparity(batch)

    # Step 3: let the pretrained depth model predict the scale.
    if self.cfg.predict_scale:
        self.predict_scale(batch)