# Learn2Splat: optgs/scene_trainer/initializer/initializer_resplat.py
# Author: SteEsp
# Commit 78d2329 (verified): "Add Docker-based Learn2Splat demo (viser GUI)"
from dataclasses import dataclass
from typing import Literal, Optional, List
import numpy as np
import torch
import torch.nn.functional as F
import torch.utils.checkpoint
from einops import rearrange
from torch import nn
from optgs.dataset.data_types import BatchedExample, DataShim, BatchedViews
from optgs.dataset.shims.patch_shim import apply_patch_shim
from optgs.geometry.projection import sample_image_grid, get_world_rays
from optgs.misc.general_utils import rotate_quats
from optgs.misc.io import FrequencyScheduler
from optgs.model.encoder.layer import BasicBlock
from optgs.model.encoder.unimatch.dpt_head import DPTHead
from optgs.model.encoder.unimatch.feature_upsampler import ResizeConvFeatureUpsampler
from optgs.model.encoder.unimatch.ldm_unet.unet import UNetModel
from optgs.model.encoder.unimatch.mv_unimatch import MultiViewUniMatch
from optgs.model.encoder.visualization.encoder_visualizer_depthsplat_cfg import EncoderVisualizerDepthSplatCfg
from optgs.model.types import Gaussians
from optgs.scene_trainer.common.gaussian_adapter import GaussianAdapter, GaussianAdapterCfg, build_covariance, RGB2SH
from optgs.scene_trainer.initializer.initializer import InitializerOutput, LearnedInitializer, PerPixelInitializerCfg
try:
from optgs.model.encoder.point_transformer.layer import (PlainPointTransformer, SubsampleBlock, PointLinearWrapper,
MultiScalePointTransformer,
MultViewLowresAttn, MultViewUniMatchAttn,
GaussianErrorCrossAttn)
except:
pass
from optgs.model.encoder.lvsm.transformer import LVSMTransformer
try:
from simple_knn._C import distCUDA2
except:
pass
@dataclass
class ResplatInitializerCfg(PerPixelInitializerCfg):
    """Configuration for ``ResplatInitializer``.

    The helper methods below also read ``latent_gs`` and ``latent_downsample``,
    which are not declared in this class -- presumably inherited from
    ``PerPixelInitializerCfg`` (NOTE(review): confirm).

    Field order is significant for a dataclass (positional construction), so
    fields must not be reordered.
    """

    name: Literal["resplat_v1", "resplat_v2"]
    d_feature: int
    num_depth_candidates: int
    num_surfaces: int
    visualizer: EncoderVisualizerDepthSplatCfg
    gaussian_adapter: GaussianAdapterCfg
    gaussians_per_pixel: int
    unimatch_weights_path: str | None
    downscale_factor: int
    shim_patch_size: int | List[int]
    multiview_trans_attn_split: int
    deform_sample_depth: bool  # non-pixel aligned Gaussians with learned offsets
    deform_sample_depth_debug: bool  # check depth sampling
    # mv_unimatch
    num_scales: int
    upsample_factor: int
    lowest_feature_resolution: int
    depth_unet_channels: int
    grid_sample_disable_cudnn: bool
    # depthsplat color branch
    large_gaussian_head: bool
    color_large_unet: bool
    init_sh_input_img: bool
    feature_upsampler_channels: int
    gaussian_regressor_channels: int
    # loss config
    return_depth: bool
    # only depth
    train_depth_only: bool
    # monodepth config
    monodepth_vit_type: str
    # multi-view matching
    local_mv_match: int
    # point transformer
    pt_head: bool
    init_pt_with_mv_attn: bool
    init_pt_with_mv_attn_lowres: bool
    pt_head_conv: bool
    pt_head_concat_img: bool
    pt_head_channels: int | None
    multi_scale_pt: bool
    attn_proj_channels: int | None
    fps_num_samples: int | None
    knn_samples: int
    post_norm: bool
    no_rpe: bool
    no_knn_attn: bool
    num_blocks: int
    pt_downsample: int
    fps_agg_func: str
    subsample_method: str
    add_pt_residual: bool
    pt_pred_residual_position: bool  # based on the initial point cloud from depth, predict additional residual
    latent_dpt_upsampler: bool
    latent_dpt_upsampler_no_concat: bool
    light_dpt_feature: bool
    # freeze depth
    freeze_depth: bool
    use_gt_depth: bool
    # separate depth and color branches
    separate_depth_color: bool
    separate_depth_type: str
    separate_depth_gaussian_scale: bool
    # unet gaussian regressor
    unet_gaussian_regressor: bool
    resnet_gaussian_regressor: bool
    # lvsm gaussian regressor
    lvsm_gaussian_regressor: bool
    lvsm_layers: int
    sample_log_depth: bool
    bilinear_upsample_depth: bool
    no_upsample_depth: bool
    return_lowres_depth: bool
    # latent gaussian instead of pixel-aligned gaussian
    fixed_latent_size: bool  # same channels for both downsample 4 and 8
    latent_gs_img_interp: str
    dpt_head_depth: bool  # downsample the full resolution depth to low resolution
    avgpool_depth: bool
    nearest_down_depth: bool
    # predict scene scale and use point distance to normalize the scene
    predict_scale: bool
    norm_by_points: bool
    no_pred_depth_range: bool
    # init gaussian scale with point cloud distance
    point_dist_init_gaussian_scale: bool
    # feature upsampler
    resizeconv_upsampler: bool
    # rotate_quat_to_world: bool # rotate the quaternion to world space
    latent_new_reshape: bool  # debug
    # amp
    use_amp: bool
    pt_head_amp: bool
    use_checkpointing: bool
    init_use_checkpointing: bool  # init model uses checkpointing
    no_pixel_offset: bool
    pt_heads: int
    init_gaussian_multiple: int
    depth_pred_half_res: bool

    def get_feature_upsampler_channels(self):
        """Return ``(feature_num, model_configs)`` for the feature branch.

        ``feature_num`` is the channel count of the feature map that is
        concatenated into the Gaussian regressor input. ``model_configs`` maps
        each ViT type to DPT head kwargs (``in_channels``, ``features``,
        ``out_channels``). The dict is rebuilt on every call, so the in-place
        halving applied for ``light_dpt_feature`` never leaks between calls.

        Raises:
            NotImplementedError: if ``latent_downsample`` is not 2, 4 or 8 on
                the raw-latent-feature path (``latent_gs`` without
                ``latent_dpt_upsampler``).
        """
        # upsample features to the original resolution
        model_configs = {
            'vits': {'in_channels': 384, 'features': 64, 'out_channels': [48, 96, 192, 384]},
            'vitb': {'in_channels': 768, 'features': 96, 'out_channels': [96, 192, 384, 768]},
            'vitl': {'in_channels': 1024, 'features': 128, 'out_channels': [128, 256, 512, 1024]},
        }
        vit_type = self.monodepth_vit_type
        in_channels = model_configs[vit_type]['in_channels']
        if self.latent_gs and not self.latent_dpt_upsampler:
            # Raw latent features: pixel-(un)shuffled mono ViT features
            # concatenated with CNN and multi-view features at the latent
            # resolution. The constants are the per-source channel counts.
            if self.latent_downsample == 2:
                feature_num = in_channels // 64 * 4 + 128 // 4 + 64 + 96 + 128 // 4
            elif self.latent_downsample == 4:
                feature_num = in_channels // 4 + 128 + 64 + 96 + 128
            elif self.latent_downsample == 8:
                if self.fixed_latent_size:
                    feature_num = in_channels // 4 + 128 + 64 + 96 + 128
                else:
                    feature_num = in_channels + 128 + 64 + 96 + 128
            else:
                # Read the field on this config directly; there is no ``self.cfg``.
                raise NotImplementedError(f"Unsupported latent_downsample value: {self.latent_downsample}")
        elif self.resizeconv_upsampler:
            feature_num = self.feature_upsampler_channels
        else:
            if self.light_dpt_feature:
                # Lighter DPT head: halve every intermediate output width.
                for config in model_configs.values():
                    config['out_channels'] = [c // 2 for c in config['out_channels']]
            features = model_configs[vit_type]["features"]
            if self.latent_gs and not self.latent_dpt_upsampler_no_concat:
                features *= 4
            feature_num = features
        return feature_num, model_configs

    def get_pt_in_channels(self):
        """Return the input width of the Gaussian head / point-transformer projection.

        The input is concat(regressor output, image, upsampled features,
        match probability): ``3 + feature_channels + gaussian_regressor_channels + 1``.
        With ``latent_gs`` the image is pixel-unshuffled, so its 3 channels
        become ``3 * patch**2`` (patch 4 when ``fixed_latent_size``, else
        ``latent_downsample``).
        """
        feature_upsampler_channels, _ = self.get_feature_upsampler_channels()
        in_channels = 3 + feature_upsampler_channels + self.gaussian_regressor_channels + 1
        if self.latent_gs:
            # image unshuffle
            if self.fixed_latent_size:
                in_channels = in_channels - 3 + 3 * (4 ** 2)
            else:
                in_channels = in_channels - 3 + 3 * (self.latent_downsample ** 2)
        return in_channels

    def get_gaussian_param_num(self):
        """Return how many raw Gaussian parameters the head predicts per pixel/latent.

        Base layout per Gaussian: rotation (4) + scale (3) + SH (3 * sh_d) +
        2D pixel offset (2) + opacity (1), then adjusted by config:
        ``no_pixel_offset`` drops the 2D offset; ``pt_downsample > 0`` also
        drops 2; ``pt_pred_residual_position`` adds a 3D residual in place of
        a 2D offset (+3 - 2). The total is multiplied by
        ``init_gaussian_multiple``.

        NOTE(review): the ``-2`` adjustments are applied independently, so
        combining ``no_pixel_offset`` with ``pt_downsample > 0`` or
        ``pt_pred_residual_position`` subtracts the offset more than once --
        confirm this is intended.

        Raises:
            AssertionError: if ``init_gaussian_multiple > 1`` without ``latent_gs``.
        """
        sh_d = self.get_sh_d()
        init_gaussian_param_num = 3 + 4 + 3 * sh_d + 2 + 1
        if self.no_pixel_offset:
            init_gaussian_param_num -= 2
        if self.pt_downsample > 0:
            # no pixel offset
            init_gaussian_param_num -= 2
        if self.pt_pred_residual_position:
            # 3D residual on the depth-unprojected point instead of a 2D pixel offset
            init_gaussian_param_num = init_gaussian_param_num + 3 - 2
        # multiple gaussians per latent
        if self.init_gaussian_multiple > 1:
            # centers come from the point cloud unprojected from a higher-resolution depth map
            assert self.latent_gs
            init_gaussian_param_num *= self.init_gaussian_multiple
        return init_gaussian_param_num

    def get_sh_d(self):
        """Return the number of SH coefficients per color channel, ``(sh_degree + 1) ** 2``."""
        sh_d = (self.gaussian_adapter.sh_degree + 1) ** 2
        return sh_d
class ResplatInitializer(LearnedInitializer[ResplatInitializerCfg]):
def __init__(self, cfg: ResplatInitializerCfg) -> None:
    """Build the depth predictor and, unless ``train_depth_only``, the Gaussian heads.

    Submodules are created in a fixed order: depth predictor, feature upsampler,
    Gaussian adapter, Gaussian regressor, optional point transformer /
    downsampler, Gaussian head. Keep that order stable. Registration order
    determines ``parameters()`` order, which optimizer state dicts depend on.

    Args:
        cfg: initializer configuration.
    """
    super().__init__(cfg)
    # Assigned before the ``train_depth_only`` early return so every instance
    # has the attribute. It is meant to be replaced externally (by the model
    # wrapper) to save intermediate results during testing.
    self.test_save_every: FrequencyScheduler | None = None

    self.depth_predictor = self._get_depth_predictor(cfg)
    if self.cfg.train_depth_only:
        return

    feature_upsampler_channels, model_configs = self.cfg.get_feature_upsampler_channels()
    if self.cfg.latent_gs and not self.cfg.latent_dpt_upsampler:
        # No module needed: this config only contributes the channel count.
        pass
    elif self.cfg.resizeconv_upsampler:
        self.feature_upsampler = ResizeConvFeatureUpsampler(
            num_scales=cfg.num_scales,
            lowest_feature_resolution=cfg.lowest_feature_resolution,
            out_channels=self.cfg.feature_upsampler_channels,
            vit_type=self.cfg.monodepth_vit_type,
        )
    else:
        self.feature_upsampler = DPTHead(
            **model_configs[cfg.monodepth_vit_type],
            downsample_factor=cfg.upsample_factor,
            return_feature=True,
            num_scales=cfg.num_scales,
            latent_downsample=self.cfg.latent_downsample if self.cfg.latent_gs else None,
            latent_feature_no_concat=self.cfg.latent_dpt_upsampler_no_concat,
        )

    # gaussians adapter (can be removed)
    self.gaussian_adapter = GaussianAdapter(cfg.gaussian_adapter)

    # Regressor input: concat(img, depth, match_prob, features).
    in_channels = 3 + 1 + 1 + feature_upsampler_channels
    channels = self.cfg.gaussian_regressor_channels
    if self.cfg.latent_gs:
        # The image is pixel-unshuffled to the latent grid, so its 3 channels
        # become 3 * patch**2 (patch fixed to 4 with fixed_latent_size).
        patch = 4 if self.cfg.fixed_latent_size else self.cfg.latent_downsample
        in_channels = in_channels - 3 + 3 * (patch ** 2)

    if self.cfg.unet_gaussian_regressor:
        modules = [
            nn.Conv2d(in_channels, channels, 3, 1, 1),
            nn.GroupNorm(8, channels),
            nn.GELU(),
        ]
        unet_channel_mult = [1, 2, 4, 4, 4] if self.cfg.color_large_unet else [1, 1, 1, 1, 1]
        unet_attn_resolutions = [16]
        modules.append(
            UNetModel(
                image_size=None,
                in_channels=channels,
                model_channels=channels,
                out_channels=channels,
                num_res_blocks=1,
                attention_resolutions=unet_attn_resolutions,
                channel_mult=unet_channel_mult,
                num_head_channels=32 if self.cfg.gaussian_regressor_channels >= 32 else 16,
                dims=2,
                postnorm=False,
                num_frames=2,
                use_cross_view_self_attn=True,
            )
        )
        modules.append(nn.Conv2d(channels, channels, 3, 1, 1))
    elif self.cfg.resnet_gaussian_regressor:
        modules = [
            nn.Conv2d(in_channels, channels, 3, 1, 1),
            nn.GroupNorm(8, channels),
            nn.GELU(),
            BasicBlock(channels, channels),
            BasicBlock(channels, channels),
        ]
    elif self.cfg.lvsm_gaussian_regressor:
        # Token-based regressor: operates on flattened (B, V*H*W, C) tokens.
        modules = [
            nn.Linear(in_channels, channels),
            nn.LayerNorm(channels),
            nn.GELU(),
            LVSMTransformer(channels, n_layer=self.cfg.lvsm_layers),
        ]
    else:
        # conv regressor
        modules = [
            nn.Conv2d(in_channels, channels, 3, 1, 1),
            nn.GELU(),
            nn.Conv2d(channels, channels, 3, 1, 1),
        ]
    self.gaussian_regressor = nn.Sequential(*modules)

    init_gaussian_param_num = self.cfg.get_gaussian_param_num()
    # Gaussian head input: concat(regressor_out, img, features, match_prob).
    in_channels = self.cfg.get_pt_in_channels()
    if self.cfg.pt_head:
        channels = self.cfg.gaussian_regressor_channels
        if self.cfg.pt_head_channels is not None:
            channels = self.cfg.pt_head_channels
        self.proj = nn.Linear(in_channels, channels)
        if self.cfg.multi_scale_pt:
            self.pt = MultiScalePointTransformer(
                channels,
                self.cfg.knn_samples,
                downsample_agg_func=self.cfg.fps_agg_func,
                subsample_method=self.cfg.subsample_method,
                fps_num_samples=self.cfg.fps_num_samples,
                attn_proj_channels=self.cfg.attn_proj_channels,
            )
        else:
            self.pt = PlainPointTransformer(
                channels,
                self.cfg.knn_samples,
                post_norm=self.cfg.post_norm,
                no_rpe=self.cfg.no_rpe,
                no_attn=self.cfg.no_knn_attn,
                num_blocks=self.cfg.num_blocks,
                num_heads=self.cfg.pt_heads,
                attn_proj_channels=self.cfg.attn_proj_channels,
                use_checkpointing=self.cfg.use_checkpointing,
                init_use_checkpointing=self.cfg.init_use_checkpointing,
                with_mv_attn=self.cfg.init_pt_with_mv_attn,
                with_mv_attn_lowres=self.cfg.init_pt_with_mv_attn_lowres,
            )
        out_channels = channels

        # point downsample
        if self.cfg.pt_downsample > 0:
            num_downsample = int(np.log2(self.cfg.pt_downsample))
            if num_downsample == 0:
                stride = 1
            else:
                # only a single 2x point downsampling step is supported
                assert num_downsample == 1, f"unsupported num_downsample: {num_downsample}"
                stride = 2
            self.pt_down = SubsampleBlock(
                channels,
                out_channels=channels * 2,
                stride=stride,
                knn_samples=self.cfg.knn_samples,
                post_norm=self.cfg.post_norm,
                agg_func=self.cfg.fps_agg_func,
                subsample_method=self.cfg.subsample_method,
            )
            out_channels = channels * 2
            # TODO: add more pt blocks after downsampling

        if self.cfg.pt_head_concat_img:
            # Concat the input image: 3 channels, or 3 * latent_downsample**2
            # when it is pixel-unshuffled to the latent resolution.
            img_channels = 3 * (self.cfg.latent_downsample ** 2) if self.cfg.latent_gs else 3
            out_channels = out_channels + img_channels

        # NOTE(review): with pt_head_conv, forward() reshapes the features to
        # (BV, C, H, W) before calling gaussian_head, but this head is built
        # from Linear layers -- confirm that combination is intended.
        self.gaussian_head = nn.Sequential(
            nn.Linear(out_channels, init_gaussian_param_num),
            nn.GELU(),
            nn.Linear(init_gaussian_param_num, init_gaussian_param_num),
        )
    else:
        self.gaussian_head = nn.Sequential(
            nn.Conv2d(in_channels, init_gaussian_param_num,
                      3, 1, 1, padding_mode='replicate'),
            nn.GELU(),
            nn.Conv2d(init_gaussian_param_num,
                      init_gaussian_param_num, 3, 1, 1, padding_mode='replicate'),
        )

    # Keep the random init only for the leading rotation params (4 per Gaussian).
    # Zero-init the remaining outputs (scale, opacity, offset, sh), e.g.
    # 3 + 1 + 2 + 3 * 16 = 54 per Gaussian for SH degree 3 with pixel offsets.
    num_rotation_params = 4 * self.cfg.init_gaussian_multiple
    nn.init.zeros_(self.gaussian_head[-1].weight[num_rotation_params:])
    nn.init.zeros_(self.gaussian_head[-1].bias[num_rotation_params:])
def _get_depth_predictor(self, cfg):
    """Build the multi-view UniMatch depth predictor from ``cfg``.

    Every option is read from the ``cfg`` argument, so the predictor is always
    built from one explicit config. It no longer depends on ``self.cfg``, which
    could differ from the argument.

    ``return_raw_mono_features`` is enabled whenever ``latent_dpt_upsampler`` is
    off.

    Args:
        cfg: initializer configuration.

    Returns:
        A ``MultiViewUniMatch`` module.
    """
    return MultiViewUniMatch(
        num_scales=cfg.num_scales,
        upsample_factor=cfg.upsample_factor,
        lowest_feature_resolution=cfg.lowest_feature_resolution,
        num_depth_candidates=cfg.num_depth_candidates,
        vit_type=cfg.monodepth_vit_type,
        unet_channels=cfg.depth_unet_channels,
        grid_sample_disable_cudnn=cfg.grid_sample_disable_cudnn,
        sample_log_depth=cfg.sample_log_depth,
        bilinear_upsample_depth=cfg.bilinear_upsample_depth,
        no_upsample_depth=cfg.no_upsample_depth,
        use_amp=cfg.use_amp,
        return_raw_mono_features=not cfg.latent_dpt_upsampler,
        use_checkpointing=cfg.use_checkpointing,
    )
def forward(
self,
context: BatchedViews,
visualization_dump: Optional[dict] = None,
**kwargs
) -> InitializerOutput:
device = context["image"].device
b, v, _, h, w = context["image"].shape
if v > 3:
with torch.no_grad():
xyzs = context["extrinsics"][:, :, :3, -1].detach()
cameras_dist_matrix = torch.cdist(xyzs, xyzs, p=2)
cameras_dist_index = torch.argsort(cameras_dist_matrix)
cameras_dist_index = cameras_dist_index[:, :, :(self.cfg.local_mv_match + 1)]
else:
cameras_dist_index = None
# depth prediction
if self.cfg.depth_pred_half_res:
half_img = rearrange(context["image"], "b v c h w -> (b v) c h w")
half_img = F.interpolate(half_img, scale_factor=0.5, mode='bilinear', align_corners=True)
half_img = rearrange(half_img, "(b v) c h w -> b v c h w", b=b, v=v)
results_dict = self.depth_predictor(
half_img,
attn_splits_list=[2],
min_depth=1. / context["far"],
max_depth=1. / context["near"],
intrinsics=context["intrinsics"],
extrinsics=context["extrinsics"],
nn_matrix=cameras_dist_index,
)
# upsample depth to the original resolution
for key in results_dict.keys():
# NOTE: no need to upsample depth since depth later is in the low resolution
if key != 'depth_preds':
for i in range(len(results_dict[key])):
results_dict[key][i] = F.interpolate(results_dict[key][i], scale_factor=2, mode='bilinear',
align_corners=True)
# depthsplat: upsample depth to the original resolution
if not self.cfg.latent_gs:
for i in range(len(results_dict['depth_preds'])):
results_dict['depth_preds'][i] = F.interpolate(results_dict['depth_preds'][i], scale_factor=2,
mode='bilinear', align_corners=True)
else:
results_dict = self.depth_predictor(
context["image"],
attn_splits_list=[2],
min_depth=1. / context["far"],
max_depth=1. / context["near"],
intrinsics=context["intrinsics"],
extrinsics=context["extrinsics"],
nn_matrix=cameras_dist_index,
)
if self.cfg.use_gt_depth:
# directly use gt depth as gaussian centers instead of learning them
# to understand the bottleneck of the model
assert 'depth' in context
depth_preds = [context['depth']]
else:
# list of [B, V, H, W], with all the intermediate depths
depth_preds = results_dict['depth_preds']
# [B, V, H, W]
depth = depth_preds[-1]
gaussian_scale_depth = None
# features [BV, C, H, W]
if self.cfg.latent_gs and not self.cfg.latent_dpt_upsampler:
# concat all features
assert self.cfg.num_scales == 1
# use pixelshuffle and pixelunshuffle to align all feature resolutions
# first resize the mono features to 1/16
mono_features = [F.interpolate(x, size=(h // 16, w // 16), mode='bilinear', align_corners=True) for x in
results_dict['raw_mono_features']]
if self.cfg.fixed_latent_size:
scale_factor = 4
mono_features = [F.pixel_shuffle(x, upscale_factor=scale_factor) for x in mono_features]
mono_features = torch.cat(mono_features, dim=1) # channel: 384 / 16 * 4
if self.cfg.latent_downsample == 8:
mono_features = F.interpolate(mono_features, scale_factor=0.5, mode='bilinear', align_corners=True)
else:
if self.cfg.latent_downsample == 4:
scale_factor = 4
mono_features = [F.pixel_shuffle(x, upscale_factor=scale_factor) for x in mono_features]
mono_features = torch.cat(mono_features, dim=1) # channel: 384 / 16 * 4
elif self.cfg.latent_downsample == 2:
scale_factor = 8
mono_features = [F.pixel_shuffle(x, upscale_factor=scale_factor) for x in mono_features]
mono_features = torch.cat(mono_features, dim=1) # channel: 384 / 64 * 4
elif self.cfg.latent_downsample == 8:
scale_factor = 2
mono_features = [F.pixel_shuffle(x, upscale_factor=scale_factor) for x in mono_features]
mono_features = torch.cat(mono_features, dim=1) # channel: 384 / 4 * 4
else:
raise NotImplementedError
cnn_features = results_dict["features_cnn_all_scales"][::-1]
if self.cfg.latent_downsample == 2:
# use pixel shuffle to save channels
# 1/2, 1/2, 1/4
cnn_features[2] = F.pixel_shuffle(cnn_features[2], upscale_factor=2)
# 64 + 96 + 128 // 4
cnn_features = torch.cat(cnn_features, dim=1)
# 128 // 4
mv_features = results_dict["features_mv"][0]
mv_features = F.pixel_shuffle(mv_features, upscale_factor=2)
else:
# resize all cnn features to the latent resolution
target_h, target_w = h // self.cfg.latent_downsample, w // self.cfg.latent_downsample
for i in range(len(cnn_features)):
cnn_features[i] = F.interpolate(cnn_features[i], size=(target_h, target_w), mode='bilinear',
align_corners=True)
cnn_features = torch.cat(cnn_features, dim=1)
mv_features = results_dict["features_mv"][0]
if mv_features.shape[-2] != target_h or mv_features.shape[-1] != target_w:
mv_features = F.interpolate(mv_features, size=(target_h, target_w), mode='bilinear',
align_corners=True)
features = torch.cat((mono_features, cnn_features, mv_features), dim=1)
elif self.cfg.resizeconv_upsampler:
features = self.feature_upsampler(results_dict["features_cnn"],
results_dict["features_mv"],
results_dict["features_mono"],
)
else:
with torch.amp.autocast(device_type='cuda', enabled=self.cfg.use_amp, dtype=torch.bfloat16):
features = self.feature_upsampler(results_dict["features_mono_intermediate"],
cnn_features=results_dict["features_cnn_all_scales"][::-1],
mv_features=results_dict["features_mv"][
0] if self.cfg.num_scales == 1 else results_dict["features_mv"][
::-1]
)
# match prob from softmax
# [BV, D, H, W] in feature resolution
match_prob = results_dict['match_probs'][-1]
match_prob = torch.max(match_prob, dim=1, keepdim=True)[
0] # [BV, 1, H, W]
if not self.cfg.latent_gs:
match_prob = F.interpolate(
match_prob, size=depth.shape[-2:], mode='nearest')
# unet input
if self.cfg.latent_gs:
img_unshuffle = rearrange(context["image"], "b v c h w -> (b v) c h w")
if self.cfg.fixed_latent_size:
if self.cfg.latent_downsample == 8:
img_unshuffle = F.interpolate(img_unshuffle, scale_factor=0.5, mode='area')
img_unshuffle = F.pixel_unshuffle(img_unshuffle, downscale_factor=4)
else:
img_unshuffle = F.pixel_unshuffle(img_unshuffle, downscale_factor=self.cfg.latent_downsample)
# depth is in the full resolution, downsample to latent depth
if self.cfg.depth_pred_half_res:
latent_depth = F.interpolate(depth, scale_factor=1. / (self.cfg.latent_downsample // 2),
mode='bilinear', align_corners=True)
else:
if self.cfg.no_upsample_depth:
assert self.cfg.latent_downsample == 8 or self.cfg.latent_downsample == 4
if self.cfg.latent_downsample == 8:
latent_depth = depth
else:
# 1/8 depth to 1/4
latent_depth = F.interpolate(depth, scale_factor=2, mode='bilinear', align_corners=True)
else:
if self.cfg.avgpool_depth:
latent_depth = F.avg_pool2d(depth, kernel_size=self.cfg.latent_downsample,
stride=self.cfg.latent_downsample)
elif self.cfg.nearest_down_depth:
latent_depth = F.interpolate(depth, scale_factor=1. / self.cfg.latent_downsample,
mode='nearest')
else:
latent_depth = F.interpolate(depth, scale_factor=1. / self.cfg.latent_downsample,
mode='bilinear', align_corners=True)
if match_prob.shape[-2:] != latent_depth.shape[-2]:
match_prob = F.interpolate(
match_prob, size=latent_depth.shape[-2:], mode='nearest')
concat = torch.cat((
img_unshuffle,
rearrange(latent_depth, "b v h w -> (b v) () h w"),
match_prob,
features,
), dim=1)
else:
concat = torch.cat((
rearrange(context["image"], "b v c h w -> (b v) c h w"),
rearrange(depth, "b v h w -> (b v) () h w"),
match_prob,
features,
), dim=1)
if self.cfg.lvsm_gaussian_regressor:
h, w = concat.shape[-2:]
tmp = rearrange(concat, "(b v) c h w -> b (v h w) c", b=b, v=v)
with torch.autocast('cuda', dtype=torch.bfloat16):
out = self.gaussian_regressor(tmp)
out = rearrange(out, "b (v h w) c -> (b v) c h w", b=b, v=v, h=h, w=w)
else:
with torch.amp.autocast(device_type='cuda', enabled=self.cfg.use_amp, dtype=torch.bfloat16):
out = self.gaussian_regressor(concat)
if self.cfg.latent_gs:
concat = [out, img_unshuffle, features, match_prob]
else:
concat = [out,
rearrange(context["image"],
"b v c h w -> (b v) c h w"),
features,
match_prob]
out = torch.cat(concat, dim=1)
# [BV, C, H, W]
condition_features = out
init_scales = None
if self.cfg.pt_head:
if self.cfg.latent_gs:
h, w = latent_depth.shape[-2:]
else:
h, w = depth.shape[-2:]
with torch.amp.autocast(device_type='cuda', enabled=self.cfg.pt_head_amp, dtype=torch.bfloat16):
tmp_feature = self.proj(rearrange(out, "bv c h w -> (bv h w) c"))
# get point cloud
xy_ray, _ = sample_image_grid((h, w), out.device)
xy_ray = rearrange(xy_ray, "h w xy -> (h w) () xy")
# [B, V, H*W, 1, 2]
tmp_coords = xy_ray.unsqueeze(0).unsqueeze(0).repeat(b, v, 1, 1, 1)
# [B, V, H*W, 1, 1]
if self.cfg.latent_gs:
tmp_depth = rearrange(latent_depth, "b v h w -> b v (h w) () ()")
else:
tmp_depth = rearrange(depth, "b v h w -> b v (h w) () ()")
# [B, V, 1, 1, 4, 4]
tmp_extrinsics = context["extrinsics"].unsqueeze(2).unsqueeze(2)
# [B, V, 1, 1, 3, 3]
tmp_intrinsics = context["intrinsics"].unsqueeze(2).unsqueeze(2)
# [B, V, H*W, 1, 3]
origins, directions = get_world_rays(tmp_coords, tmp_extrinsics, tmp_intrinsics)
point_cloud = origins + directions * tmp_depth
# Create offset directly on device to avoid CPU-GPU transfer
offset = torch.arange(1, b + 1, device=depth.device, dtype=torch.long) * (v * h * w)
point_cloud = rearrange(point_cloud, "b v h w c -> (b v h w) c")
with torch.amp.autocast(device_type='cuda', enabled=self.cfg.pt_head_amp, dtype=torch.bfloat16):
if self.cfg.add_pt_residual:
out = tmp_feature + self.pt((point_cloud, tmp_feature, offset), b=b, v=v, h=h, w=w)
else:
out = self.pt((point_cloud, tmp_feature, offset), b=b, v=v, h=h, w=w)
condition_features = rearrange(out, "(bv h w) c -> bv c h w", h=h, w=w)
if self.cfg.pt_downsample > 0:
out, fps_idx = self.pt_down((point_cloud, out, offset))
# [N, 3]
point_cloud, out, offset = out
with torch.amp.autocast(device_type='cuda', enabled=self.cfg.pt_head_amp, dtype=torch.bfloat16):
if self.cfg.pt_head_concat_img:
if self.cfg.latent_gs:
# pixel unshuffle image
img_unshuffle = rearrange(context["image"], "b v c h w -> (b v) c h w")
img_unshuffle = F.pixel_unshuffle(img_unshuffle, downscale_factor=self.cfg.latent_downsample)
img_unshuffle = rearrange(img_unshuffle, "(b v) c h w -> (b v h w) c", b=b, v=v)
out = torch.cat((out, img_unshuffle), dim=-1)
if self.cfg.pt_head_conv:
out = rearrange(out, "(b v h w) c -> (b v) c h w", b=b, v=v, h=h, w=w)
out = self.gaussian_head(out)
if self.cfg.pt_head_conv:
out = rearrange(out, "(b v) c h w -> (b v h w) c", b=b, v=v)
if self.cfg.pt_downsample > 0:
# [N, C]
gaussians = out
else:
if self.cfg.pt_pred_residual_position:
# TODO: add intermediate supervision to the initial point cloud
# TODO: multiple scale factor to the delta position to make it more stable
# residual position
point_cloud = point_cloud + out[..., -3:] # [BVHW, 3]
# remaining gaussians
out = out[..., :-3]
point_cloud = rearrange(point_cloud, "(b v h w) c -> b v (h w) () () c", b=b, v=v, h=h, w=w)
gaussians = rearrange(out, "(b v h w) c -> (b v) c h w", b=b, h=h, w=w)
else:
with torch.amp.autocast(device_type='cuda', enabled=self.cfg.use_amp, dtype=torch.bfloat16):
gaussians = self.gaussian_head(out) # [BV, C, H, W]
# [BV, C, H, W]
gaussians = gaussians.float()
if self.cfg.latent_gs:
if self.cfg.init_gaussian_multiple > 1:
# hard coded for now
if self.cfg.init_gaussian_multiple == 4:
# TODO: try avgpooling downsampling depth
if self.cfg.latent_downsample == 4:
# resize full resolution depth
depths = F.interpolate(depth, scale_factor=0.5, mode='bilinear', align_corners=True)
elif self.cfg.latent_downsample == 8:
depths = F.interpolate(depth, scale_factor=0.25, mode='bilinear', align_corners=True)
elif self.cfg.latent_downsample == 2:
depths = depth
else:
raise NotImplementedError
elif self.cfg.init_gaussian_multiple == 16:
# TODO: try avgpooling downsampling depth
if self.cfg.latent_downsample == 4:
depths = depth
elif self.cfg.latent_downsample == 8:
depths = F.interpolate(depth, scale_factor=0.5, mode='bilinear', align_corners=True)
else:
raise NotImplementedError
else:
raise NotImplementedError
depths = rearrange(depths, "b v h w -> b v (h w) () ()")
else:
depths = rearrange(latent_depth, "b v h w -> b v (h w) () ()")
else:
depths = rearrange(depth, "b v h w -> b v (h w) () ()")
if self.cfg.pt_downsample > 0:
# split batch
assert offset.shape[0] == b
if self.cfg.latent_gs:
sh_input_images = rearrange(context["image"], "b v c h w -> (b v) c h w")
if self.cfg.latent_gs_img_interp == 'bicubic':
sh_input_images = F.interpolate(sh_input_images, scale_factor=1. / self.cfg.latent_downsample,
mode='bicubic', align_corners=True)
elif self.cfg.latent_gs_img_interp == 'area':
sh_input_images = F.interpolate(sh_input_images, scale_factor=1. / self.cfg.latent_downsample,
mode='area')
elif self.cfg.latent_gs_img_interp == 'softmax':
sh_input_images = self.softmax_downsample(sh_input_images)
else:
raise NotImplementedError
h, w = sh_input_images.shape[-2:]
sh_input_images = rearrange(sh_input_images, "(b v) c h w -> b v c h w", b=b, v=v)
else:
sh_input_images = context["image"]
sh_input_images = rearrange(sh_input_images, "b v c h w -> (b v h w) c")
# subsample with fps index
sh_input_images = sh_input_images[fps_idx.long(), :] # [N, 3]
# extrinsics
extrinsics_all = rearrange(repeat(context["extrinsics"], "b v i j -> b v h w i j", h=h, w=w),
"b v h w i j -> (b v h w) i j"
)
extrinsics_all = extrinsics_all[fps_idx.long(), :, :] # [N, 4, 4]
point_list = [point_cloud[:offset[0]]]
gaussian_list = [gaussians[:offset[0]]]
sh_img_list = [sh_input_images[:offset[0]]]
extrinsics_list = [extrinsics_all[:offset[0]]]
for i in range(b - 1):
point_list.append(point_cloud[offset[i]:offset[i + 1]])
gaussian_list.append(gaussians[offset[i]:offset[i + 1]])
sh_img_list.append(sh_input_images[offset[i]:offset[i + 1]])
extrinsics_list.append(extrinsics_all[offset[i]:offset[i + 1]])
point_cloud = torch.stack(point_list, dim=0) # [B, N, 3]
gaussians = torch.stack(gaussian_list, dim=0) # [B, N, C]
sh_imgs = torch.stack(sh_img_list, dim=0) # [B, N, 3]
extrinsics_all = torch.stack(extrinsics_list, dim=0) # [B, N, 4, 4]
# point_cloud = [point_cloud[offset[i]:offset[i+1]] for i in range(b)]
# point_cloud = torch.stack(point_cloud, dim=0) # [B, N, 3]
# gaussians = [gaussians[offset[i]:offset[i+1]] for i in range(b)]
# gaussians = torch.stack(gaussians, dim=0) # [B, N, 3]
opacities = gaussians[..., 0].sigmoid() # [B, N]
gaussians = self.gaussian_adapter.forward(
extrinsics=extrinsics_all,
intrinsics=None,
coordinates=None,
depths=None,
opacities=opacities,
raw_gaussians=gaussians[..., 1:],
image_shape=None,
point_cloud=point_cloud,
input_images=sh_imgs,
)
gaussians = rearrange(gaussians, "(b v) c h w -> b v c h w", b=b, v=v)
# [B, V, H*W, 84]
raw_gaussians = rearrange(
gaussians, "b v c h w -> b v (h w) c")
assert len(depth_preds) == 1, "num_scales must be 1; multi-scale depth supervision is not supported"
# [B, V, H*W, C]
repeat = self.cfg.init_gaussian_multiple
num_sh = self.gaussian_adapter.d_sh
if self.cfg.no_pixel_offset:
rotations_unnorm, scales, opacities_raw, sh = raw_gaussians.split(
[4 * repeat, 3 * repeat, 1 * repeat, 3 * num_sh * repeat],
dim=-1,
)
else:
rotations_unnorm, scales, opacities_raw, offset, sh = raw_gaussians.split(
[4 * repeat, 3 * repeat, 1 * repeat, 2 * repeat, 3 * num_sh * repeat],
dim=-1,
)
latent_h, latent_w = gaussians.shape[-2:]
if repeat > 1:
# reshape all the gaussian parameters
if True or self.cfg.latent_new_reshape:
# this works
r = int(np.sqrt(repeat))
rotations_unnorm = rearrange(rotations_unnorm, "b v (h w) (c x y) -> b v (h x w y) c",
h=latent_h, w=latent_w, x=r, y=r)
scales = rearrange(scales, "b v (h w) (c x y) -> b v (h x w y) c", h=latent_h, w=latent_w, x=r,
y=r)
opacities_raw = rearrange(opacities_raw, "b v (h w) (c x y) -> b v (h x w y) c", h=latent_h,
w=latent_w, x=r, y=r)
offset = rearrange(offset, "b v (h w) (c x y) -> b v (h x w y) c", h=latent_h, w=latent_w, x=r,
y=r)
sh = rearrange(sh, "b v (h w) (c x y) -> b v (h x w y) c", h=latent_h, w=latent_w, x=r, y=r)
else:
# doesn't work
rotations_unnorm = rearrange(rotations_unnorm, "b v hw (k c) -> b v (hw k) c", k=repeat)
scales = rearrange(scales, "b v hw (k c) -> b v (hw k) c", k=repeat)
opacities_raw = rearrange(opacities_raw, "b v hw (k c) -> b v (hw k) c", k=repeat)
offset = rearrange(offset, "b v hw (k c) -> b v (hw k) c", k=repeat)
sh = rearrange(sh, "b v hw (k c) -> b v (hw k) c", k=repeat)
opacities = opacities_raw.sigmoid() # [B, V, H*W*K, 1]
if self.cfg.latent_downsample == 4 and self.cfg.init_gaussian_multiple == 4:
scale_factor = 2
elif self.cfg.latent_downsample == 2 and self.cfg.init_gaussian_multiple == 4:
scale_factor = 2
elif self.cfg.latent_downsample == 4 and self.cfg.init_gaussian_multiple == 16:
scale_factor = 4
elif self.cfg.latent_downsample == 8 and self.cfg.init_gaussian_multiple == 4:
scale_factor = 2
elif self.cfg.latent_downsample == 8 and self.cfg.init_gaussian_multiple == 16:
scale_factor = 4
else:
scale_factor = 1
h, w = latent_h * scale_factor, latent_w * scale_factor
# unproject depth
xy_ray, _ = sample_image_grid((h, w), device) # [H, W, 2] in [0, 1]
xy_ray = rearrange(xy_ray, "h w xy -> (h w) () xy") # [H*W, 1, 2]
if self.cfg.no_pixel_offset:
offset_xy = torch.ones_like(raw_gaussians[..., :2]).unsqueeze(-2).to(
raw_gaussians.device) * 0.5 # [B, V, H*W, 1, 2]
else:
offset_xy = offset.sigmoid().unsqueeze(-2) # [B, V, H*W, 1, 2]
pixel_size = 1 / \
torch.tensor((w, h), dtype=torch.float32, device=device)
# [H*W, 1, 2]
if self.cfg.deform_sample_depth and not self.cfg.deform_sample_depth_debug:
# (offset_xy - 0.5) in -0.5 to 0.5, without multiplying by pixel size such that the points can move in the image space
xy_ray = (xy_ray + (offset_xy - 0.5)).clamp(min=0., max=1.)
else:
xy_ray = xy_ray + (offset_xy - 0.5) * pixel_size
if self.cfg.deform_sample_depth:
# use low-res xy_ray to sample full-res depth
sample_grid = rearrange(xy_ray, "b v (h w) c xy -> (b v) h w (c xy)", h=h, w=w) # in [0, 1]
# to [-1, 1]
sample_grid = 2 * (sample_grid - 0.5) # [BV, h, w, 2]
fullres_depth = rearrange(depth, "b v h w -> (b v) () h w") # [BV, 1, H, W]
sampled_depth = F.grid_sample(fullres_depth, sample_grid, mode='bilinear', align_corners=True,
padding_mode="border") # [BV, 1, h, w]
# reshape
depths = rearrange(sampled_depth, "(b v) () h w -> b v (h w) () ()", b=b, v=v, h=h, w=w)
if self.cfg.latent_gs:
sh_input_images = rearrange(context["image"], "b v c h w -> (b v) c h w")
if self.cfg.latent_downsample == 4 and self.cfg.init_gaussian_multiple == 4:
sh_input_images = F.interpolate(sh_input_images, scale_factor=0.5, mode='area')
elif self.cfg.latent_downsample == 4 and self.cfg.init_gaussian_multiple == 16:
pass
elif self.cfg.latent_downsample == 2 and self.cfg.init_gaussian_multiple == 4:
pass
elif self.cfg.latent_downsample == 8 and self.cfg.init_gaussian_multiple == 4:
sh_input_images = F.interpolate(sh_input_images, scale_factor=0.25, mode='area')
elif self.cfg.latent_downsample == 8 and self.cfg.init_gaussian_multiple == 16:
sh_input_images = F.interpolate(sh_input_images, scale_factor=0.5, mode='area')
else:
sh_input_images = F.interpolate(sh_input_images, scale_factor=1. / self.cfg.latent_downsample,
mode='area')
sh_input_images = rearrange(sh_input_images, "(b v) c h w -> b v c h w", b=b, v=v)
else:
sh_input_images = context["image"]
assert len(depth_preds) == 1, "num_scales must be 1; multi-scale depth supervision is not supported"
# build gaussians
# scale
scales = torch.clamp(F.softplus(scales - self.cfg.gaussian_adapter.exp_scale_bias),
min=self.cfg.gaussian_adapter.clamp_min_scale,
max=self.cfg.gaussian_adapter.gaussian_scale_max
)
# Normalize the quaternion features to yield a valid quaternion.
# rotations = rotations_unnorm / (rotations_unnorm.norm(dim=-1, keepdim=True) + 1e-8)
# Convert rotations to world-space
c2w_rotations = context["extrinsics"][..., :3, :3].unsqueeze(2) # [B, V, 1, 3, 3]
rotations = rotate_quats(c2w_rotations, rotations_unnorm)
rotations_unnorm = rotations.clone()
# Create world-space covariance matrices.
covariances = build_covariance(scale=scales, rotation_xyzw=rotations) # [B, V, H*W, 3, 3]
# means
# [B, V, H*W, 1, 2]
# xy_ray = xy_ray.unsqueeze(0).unsqueeze(0).repeat(b, v, 1, 1, 1)
origins, directions = get_world_rays(xy_ray,
context["extrinsics"].unsqueeze(2).unsqueeze(2),
context["intrinsics"].unsqueeze(2).unsqueeze(2))
means = origins + directions * depths
# sh: [B, V, HW, 3, SH]
sh = rearrange(sh, "... (xyz d_sh) -> ... xyz d_sh", xyz=3).clone()
# sh = sh.broadcast_to((*opacities.shape, 3, self.gaussian_adapter.d_sh)).clone()
# [B, V, H*W, 3]
sh_input_images = rearrange(sh_input_images, "b v c h w -> b v (h w) c")
# init sh with input images
sh[..., 0] = sh[..., 0] + RGB2SH(sh_input_images)
gaussians = Gaussians(
means=rearrange(means, "b v r spp xyz -> b (v r spp) xyz"),
covariances=rearrange(covariances, "b v r i j -> b (v r) i j"),
harmonics=rearrange(sh, "b v r c d_sh -> b (v r) c d_sh"),
opacities=rearrange(opacities, "b v r spp -> b (v r spp)"),
scales=rearrange(scales, "b v r xyz -> b (v r) xyz"),
rotations=rearrange(rotations, "b v r wxyz -> b (v r) wxyz"), # in wxyz format
rotations_unnorm=rearrange(rotations_unnorm, "b v r wxyz -> b (v r) wxyz") # in wxyz format
)
else:
gaussians = rearrange(gaussians, "(b v) c h w -> b v c h w", b=b, v=v)
# [B, V, H*W, 84]
raw_gaussians = rearrange(
gaussians, "b v c h w -> b v (h w) c")
assert len(depth_preds) == 1, "num_scales must be 1; multi-scale depth supervision is not supported"
# [B, V, H*W, C]
repeat = self.cfg.init_gaussian_multiple
num_sh = self.gaussian_adapter.d_sh
if self.cfg.no_pixel_offset:
rotations_unnorm, scales, opacities_raw, sh = raw_gaussians.split(
[4 * repeat, 3 * repeat, 1 * repeat, 3 * num_sh * repeat],
dim=-1,
)
else:
rotations_unnorm, scales, opacities_raw, offset, sh = raw_gaussians.split(
[4 * repeat, 3 * repeat, 1 * repeat, 2 * repeat, 3 * num_sh * repeat],
dim=-1,
)
latent_h, latent_w = gaussians.shape[-2:]
if repeat > 1:
# reshape all the gaussian parameters
if True or self.cfg.latent_new_reshape:
# this works
r = int(np.sqrt(repeat))
rotations_unnorm = rearrange(rotations_unnorm, "b v (h w) (c x y) -> b v (h x w y) c",
h=latent_h, w=latent_w, x=r, y=r)
scales = rearrange(scales, "b v (h w) (c x y) -> b v (h x w y) c", h=latent_h, w=latent_w, x=r,
y=r)
opacities_raw = rearrange(opacities_raw, "b v (h w) (c x y) -> b v (h x w y) c", h=latent_h,
w=latent_w, x=r, y=r)
offset = rearrange(offset, "b v (h w) (c x y) -> b v (h x w y) c", h=latent_h, w=latent_w, x=r,
y=r)
sh = rearrange(sh, "b v (h w) (c x y) -> b v (h x w y) c", h=latent_h, w=latent_w, x=r, y=r)
else:
# doesn't work
rotations_unnorm = rearrange(rotations_unnorm, "b v hw (k c) -> b v (hw k) c", k=repeat)
scales = rearrange(scales, "b v hw (k c) -> b v (hw k) c", k=repeat)
opacities_raw = rearrange(opacities_raw, "b v hw (k c) -> b v (hw k) c", k=repeat)
offset = rearrange(offset, "b v hw (k c) -> b v (hw k) c", k=repeat)
sh = rearrange(sh, "b v hw (k c) -> b v (hw k) c", k=repeat)
opacities = opacities_raw.sigmoid() # [B, V, H*W*K, 1]
if self.cfg.latent_downsample == 4 and self.cfg.init_gaussian_multiple == 4:
scale_factor = 2
elif self.cfg.latent_downsample == 2 and self.cfg.init_gaussian_multiple == 4:
scale_factor = 2
elif self.cfg.latent_downsample == 4 and self.cfg.init_gaussian_multiple == 16:
scale_factor = 4
elif self.cfg.latent_downsample == 8 and self.cfg.init_gaussian_multiple == 4:
scale_factor = 2
elif self.cfg.latent_downsample == 8 and self.cfg.init_gaussian_multiple == 16:
scale_factor = 4
else:
scale_factor = 1
h, w = latent_h * scale_factor, latent_w * scale_factor
# unproject depth
xy_ray, _ = sample_image_grid((h, w), device)
xy_ray = rearrange(xy_ray, "h w xy -> (h w) () xy")
if self.cfg.no_pixel_offset:
offset_xy = torch.ones_like(raw_gaussians[..., :2]).unsqueeze(-2).to(
raw_gaussians.device) * 0.5 # [B, V, H*W, 1, 2]
else:
offset_xy = offset.sigmoid().unsqueeze(-2) # [B, V, H*W, 1, 2]
pixel_size = 1 / \
torch.tensor((w, h), dtype=torch.float32, device=device)
# [H*W, 1, 2]
if self.cfg.deform_sample_depth and not self.cfg.deform_sample_depth_debug:
# (offset_xy - 0.5) in -0.5 to 0.5, without multiplying by pixel size such that the points can move in the image space
xy_ray = (xy_ray + (offset_xy - 0.5)).clamp(min=0., max=1.)
else:
xy_ray = xy_ray + (offset_xy - 0.5) * pixel_size
if self.cfg.deform_sample_depth:
# use low-res xy_ray to sample full-res depth
sample_grid = rearrange(xy_ray, "b v (h w) c xy -> (b v) h w (c xy)", h=h, w=w) # in [0, 1]
# to [-1, 1]
sample_grid = 2 * (sample_grid - 0.5) # [BV, h, w, 2]
fullres_depth = rearrange(depth, "b v h w -> (b v) () h w") # [BV, 1, H, W]
sampled_depth = F.grid_sample(fullres_depth, sample_grid, mode='bilinear', align_corners=True,
padding_mode="border") # [BV, 1, h, w]
# reshape
depths = rearrange(sampled_depth, "(b v) () h w -> b v (h w) () ()", b=b, v=v, h=h, w=w)
if self.cfg.latent_gs:
sh_input_images = rearrange(context["image"], "b v c h w -> (b v) c h w")
if self.cfg.latent_downsample == 4 and self.cfg.init_gaussian_multiple == 4:
sh_input_images = F.interpolate(sh_input_images, scale_factor=0.5, mode='area')
elif self.cfg.latent_downsample == 4 and self.cfg.init_gaussian_multiple == 16:
pass
elif self.cfg.latent_downsample == 2 and self.cfg.init_gaussian_multiple == 4:
pass
elif self.cfg.latent_downsample == 8 and self.cfg.init_gaussian_multiple == 4:
sh_input_images = F.interpolate(sh_input_images, scale_factor=0.25, mode='area')
elif self.cfg.latent_downsample == 8 and self.cfg.init_gaussian_multiple == 16:
sh_input_images = F.interpolate(sh_input_images, scale_factor=0.5, mode='area')
else:
sh_input_images = F.interpolate(sh_input_images, scale_factor=1. / self.cfg.latent_downsample,
mode='area')
sh_input_images = rearrange(sh_input_images, "(b v) c h w -> b v c h w", b=b, v=v)
else:
sh_input_images = context["image"]
assert len(depth_preds) == 1, "num_scales must be 1; multi-scale depth supervision is not supported"
# build gaussians
# scale
scales = torch.clamp(F.softplus(scales - self.cfg.gaussian_adapter.exp_scale_bias),
min=self.cfg.gaussian_adapter.clamp_min_scale,
max=self.cfg.gaussian_adapter.gaussian_scale_max
)
# Convert rotations to world-space
c2w_rotations = context["extrinsics"][..., :3, :3].unsqueeze(2) # [B, V, 1, 3, 3]
# Here quaternions follow the xyzw format (scalar last)
rotations = rotate_quats(c2w_rotations, rotations_unnorm)
rotations_unnorm = rotations.clone()
# Create world-space covariance matrices.
covariances = build_covariance(scale=scales, rotation_xyzw=rotations) # [B, V, H*W, 3, 3]
# means
# [B, V, H*W, 1, 2]
origins, directions = get_world_rays(xy_ray,
context["extrinsics"].unsqueeze(2).unsqueeze(2),
context["intrinsics"].unsqueeze(2).unsqueeze(2))
means = origins + directions * depths
# sh: [B, V, HW, 3, SH]
sh = rearrange(sh, "... (xyz d_sh) -> ... xyz d_sh", xyz=3).clone()
# [B, V, H*W, 3]
sh_input_images = rearrange(sh_input_images, "b v c h w -> b v (h w) c")
# init sh with input images
sh[..., 0] = sh[..., 0] + RGB2SH(sh_input_images)
gaussians = Gaussians(
means=rearrange(means, "b v r spp xyz -> b (v r spp) xyz"),
covariances=rearrange(covariances, "b v r i j -> b (v r) i j"),
harmonics=rearrange(sh, "b v r c d_sh -> b (v r) c d_sh"),
opacities=rearrange(opacities, "b v r spp -> b (v r spp)"),
scales=rearrange(scales, "b v r xyz -> b (v r) xyz"),
rotations=rearrange(rotations, "b v r wxyz -> b (v r) wxyz"),
rotations_unnorm=rearrange(rotations_unnorm, "b v r wxyz -> b (v r) wxyz")
)
# Dump visualizations if needed.
if visualization_dump is not None:
visualization_dump["depth"] = rearrange(
depths, "b v (h w) srf s -> b v h w srf s", h=h, w=w
)
# if self.cfg.pt_downsample > 0:
# visualization_dump["scales"] = gaussians.scales
# visualization_dump["rotations"] = gaussians.rotations
# else:
# visualization_dump["scales"] = rearrange(
# gaussians.scales, "b v r srf spp xyz -> b (v r srf spp) xyz"
# )
# visualization_dump["rotations"] = rearrange(
# gaussians.rotations, "b v r srf spp xyzw -> b (v r srf spp) xyzw"
# )
if self.cfg.return_depth:
# return depth prediction for supervision
depths = depth_preds[-1]
if self.cfg.return_lowres_depth:
assert latent_depth is not None
depths = latent_depth
else:
if depths.shape[-2:] != context["image"].shape[-2:]:
# depths can be at lower resolution since we predict latent
depths = F.interpolate(
depths, size=context["image"].shape[-2:], mode='bilinear', align_corners=True)
return InitializerOutput(
gaussians=gaussians,
depths=depths,
features=condition_features
)
else:
return InitializerOutput(
gaussians=gaussians,
features=condition_features
)
def get_data_shim(self) -> DataShim:
    """Build the batch shim used by this initializer.

    The returned callable runs a batch through ``apply_patch_shim`` with a
    patch size equal to ``cfg.shim_patch_size`` scaled by
    ``cfg.downscale_factor``. ``shim_patch_size`` may be a single ``int``
    (scaled directly) or an iterable of sizes (each element scaled, giving a
    list). The config values are read every time the shim is called, not when
    it is built.
    """

    def data_shim(batch: BatchedExample) -> BatchedExample:
        configured = self.cfg.shim_patch_size
        factor = self.cfg.downscale_factor
        if not isinstance(configured, int):
            scaled = [size * factor for size in configured]
        else:
            scaled = configured * factor
        return apply_patch_shim(batch, patch_size=scaled)

    return data_shim
@staticmethod
def update_gt_depth_range(batch):
    """Replace ``near``/``far`` with the range of the ground-truth depth, in place.

    For ``batch["context"]`` and then ``batch["target"]``:

    - ``near`` is set to the minimum of ``depth`` over its dims 3 and 2, clamped to >= 0.01.
    - ``far`` is set to the maximum of ``depth`` over the same dims, clamped to <= 1000.

    Those dims are the spatial ones for a ``[B, V, H, W]`` depth map; this layout
    is assumed, so confirm it against the dataset.

    Only the context entry is asserted to carry ``"depth"``. A missing target
    depth surfaces as a ``KeyError`` after the context has been updated.
    """
    assert "depth" in batch["context"]
    for split in ("context", "target"):
        views = batch[split]
        gt_depth = views["depth"]
        lowest = gt_depth.min(dim=3).values.min(dim=2).values
        highest = gt_depth.max(dim=3).values.max(dim=2).values
        views["near"] = lowest.clamp(min=0.01)
        views["far"] = highest.clamp(max=1000.)
def update_depth_range_from_disparity(self, batch, train_cfg=None):
    """Derive the context ``near``/``far`` from the stereo baseline and a disparity range.

    Uses ``depth = baseline * focal_px / disparity``:

    - the largest allowed disparity (``train_cfg.max_disparity``) gives the near bound;
    - the smallest allowed disparity (``train_cfg.min_disparity``) gives the far bound.

    Only two-view contexts are supported, and only the context entries are updated.

    Args:
        batch: batch dict. Reads ``batch["context"]["image"]`` ``[B, V, C, H, W]``,
            ``"extrinsics"`` and ``"intrinsics"``. Writes ``"near"`` and ``"far"``
            (each ``[B, V]``) in place.
        train_cfg: config providing ``min_disparity`` / ``max_disparity``.
            When omitted, ``self.train_cfg`` is used (the original behaviour),
            so callers that already hold the config can pass it explicitly.

    Raises:
        AssertionError: if the context does not have exactly two views, or if
            ``self.decoder.cfg.scale_invariant`` is not ``False``.
    """
    context = batch["context"]
    _, num_views, _, _, width = context["image"].shape
    # TODO: support multi-view later
    assert num_views == 2, f"disparity-based depth range needs exactly 2 context views, got {num_views}"
    assert self.decoder.cfg.scale_invariant is False, "disparity-based depth range requires a metric-scale decoder"
    # Resolve the config only after validation so failures keep their original order.
    if train_cfg is None:
        train_cfg = self.train_cfg

    extrinsics = context["extrinsics"]
    # Distance between the two camera centres (translation column): [B, 1].
    baseline = (extrinsics[:, 0, :3, 3] - extrinsics[:, 1, :3, 3]).norm(dim=1, keepdim=True)
    # fx is stored normalised by width, so multiplying by the pixel width gives a pixel focal: [B, V].
    focal = context["intrinsics"][:, :, 0, 0] * width

    context["near"] = baseline * focal / train_cfg.max_disparity
    context["far"] = baseline * focal / train_cfg.min_disparity
    # TODO: also update target near and far
def predict_scale(self, batch):
    """Re-estimate the depth range, and optionally the scene scale, in place.

    ``self.encoder.scale_predictor`` predicts a depth map for the context
    views. Its bounds are passed as inverse depths: ``1 / far`` as
    ``min_depth`` and ``1 / near`` as ``max_depth``. The last entry of the
    returned ``depth_preds`` is used, a ``[B, V, H, W]`` map.

    - Unless ``self.encoder.cfg.no_pred_depth_range`` is set, each context
      view's ``near``/``far`` becomes the min/max of its predicted depth,
      clamped to >= 0.1 and <= 100. Every target view then gets the smallest
      context near and the largest context far.
    - If ``self.encoder.cfg.norm_by_points`` is set, the predicted depth is
      unprojected into a world-space point cloud. ``near``, ``far`` and the
      camera translations of both context and target are then divided by the
      mean distance of those points from the world origin (floored at 1e-6).

    ``batch["target"]`` is only touched inside those two branches.
    """
    ctx = batch["context"]

    predictor_out = self.encoder.scale_predictor(
        ctx["image"],
        attn_splits_list=[2],
        min_depth=1. / ctx["far"],
        max_depth=1. / ctx["near"],
        intrinsics=ctx["intrinsics"],
        extrinsics=ctx["extrinsics"],
    )
    pred_depth = predictor_out['depth_preds'][-1]  # [B, V, H, W]

    if not self.encoder.cfg.no_pred_depth_range:
        view_near = pred_depth.min(dim=3).values.min(dim=2).values.clamp(min=0.1)  # [B, V]
        view_far = pred_depth.max(dim=3).values.max(dim=2).values.clamp(max=100.)
        ctx["near"] = view_near
        ctx["far"] = view_far
        tgt = batch["target"]
        n_target = tgt["near"].shape[1]
        # Target views share one range spanning all context views.
        tgt["near"] = view_near.min(dim=1, keepdim=True).values.repeat(1, n_target)
        tgt["far"] = view_far.max(dim=1, keepdim=True).values.repeat(1, n_target)

    if self.encoder.cfg.norm_by_points:
        b, v, h, w = pred_depth.shape
        # Image-plane coordinates of every pixel, shared by all views.
        pixel_xy, _ = sample_image_grid((h, w), ctx["image"].device)
        pixel_xy = rearrange(pixel_xy, "h w xy -> (h w) () xy")
        pixel_xy = pixel_xy[None, None].repeat(b, v, 1, 1, 1)  # [B, V, H*W, 1, 2]
        ray_depth = rearrange(pred_depth, "b v h w -> b v (h w) () ()")  # [B, V, H*W, 1, 1]
        origins, directions = get_world_rays(
            pixel_xy,
            ctx["extrinsics"][:, :, None, None],  # [B, V, 1, 1, 4, 4]
            ctx["intrinsics"][:, :, None, None],  # [B, V, 1, 1, 3, 3]
        )
        points = rearrange(origins + directions * ray_depth, "b v n s c -> b (v n s) c")
        scene_scale = points.norm(dim=-1).mean(dim=-1).clamp(min=1e-6)  # [B]

        per_view = scene_scale.view(b, 1)
        for split in ("context", "target"):
            batch[split]["near"] = batch[split]["near"] / per_view
            batch[split]["far"] = batch[split]["far"] / per_view
        # Rescale camera translations in place to the same normalised units.
        per_matrix = scene_scale.view(b, 1, 1)
        ctx["extrinsics"][:, :, :3, -1] /= per_matrix
        batch["target"]["extrinsics"][:, :, :3, -1] /= per_matrix
def preprocessing(self, batch, train_cfg):
    """Adjust the depth range (and possibly the scene scale) of ``batch``.

    Each step below runs only when its flag is enabled, in this order, and
    mutates ``batch`` in place:

    1. ``train_cfg.use_gt_depth_range``: take ``near``/``far`` from the
       ground-truth depth maps (``update_gt_depth_range``).
    2. ``train_cfg.depth_range_from_disparity``: derive the context
       ``near``/``far`` from the camera baseline and a disparity range
       (``update_depth_range_from_disparity``).
    3. ``self.cfg.predict_scale``: re-estimate ``near``/``far`` with a
       pretrained depth model and, depending on the encoder config, also
       normalise the scene scale (``predict_scale``).

    A later step overwrites the ranges written by an earlier one.
    """
    # use gt depth range instead of a fixed one
    if train_cfg.use_gt_depth_range:
        self.update_gt_depth_range(batch)
    # compute depth range from camera distance and disparity range
    if train_cfg.depth_range_from_disparity:
        # NOTE(review): ``train_cfg`` is not passed on here, so the disparity
        # bounds used by this call come from ``self.train_cfg`` instead --
        # confirm that attribute is set on this initializer.
        self.update_depth_range_from_disparity(batch)
    # use a pretrained depth model to predict scale
    if self.cfg.predict_scale:
        self.predict_scale(batch)