|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
import math |
|
|
from functools import partial |
|
|
import torch |
|
|
import torch.nn as nn |
|
|
import torch.nn.functional as F |
|
|
from einops import rearrange, repeat |
|
|
|
|
|
from src.models.utils.attention import PatchEmbed3D |
|
|
from src.rendering.gs import GaussianRenderer |
|
|
from src.rendering.gs_deferred import GaussianRendererDeferred |
|
|
|
|
|
from src.models.utils.cosmos_1_tokenizer import load_cosmos_1_decoder |
|
|
from src.models.utils.render import subsample_pixels_spatio_temporal, query_z_with_indices, subsample_x_and_rays |
|
|
from src.models.utils.model import get_model_blocks, ConvTranspose3dFactorized, MultiStageConvTranspose3d, ConvTranspose3dReduced, forward_checkpointing, PositionalEmbedding |
|
|
|
|
|
class LatentRecon(nn.Module): |
|
|
    def __init__(
        self,
        opt,
    ):
        """Latent 3D-Gaussian reconstruction model.

        Builds a token encoder over (latent) video patches, an optional
        channel-reducing head, a transposed-conv (or Cosmos) decoder that
        emits per-pixel Gaussian parameters, and a Gaussian-splat renderer.

        Args:
            opt: config object read via attribute access (and dict-style
                ``get`` elsewhere in this class); selects every sub-module.
        """
        super().__init__()
        self.opt = opt

        # --- encoder trunk over patch tokens --------------------------------
        norm_layer = partial(nn.LayerNorm, eps=1e-6)
        self.enc_blocks = get_model_blocks(
            self.opt.enc_embed_dim,
            self.opt.enc_depth,
            self.opt.enc_num_heads,
            self.opt.mlp_ratio,
            self.opt.use_mamba,
            self.opt.llrm_7m1t,
            norm_layer,
            self.opt.use_qk_norm,
            self.opt.llrm_7m1t_index,
        )
        self.enc_norm = norm_layer(self.opt.enc_embed_dim)

        # Optional linear projection shrinking token channels before decoding.
        if self.opt.num_block_channels_reduce is not None:
            self.blocks_out_channels = self.opt.num_block_channels_reduce
            self.block_out = nn.Linear(self.opt.enc_embed_dim, self.blocks_out_channels)
        else:
            self.blocks_out_channels = self.opt.enc_embed_dim
            self.block_out = None

        # Patch sizes are (t, h, w) triples.
        patch_size_video = [
            self.opt.patch_size_temporal,
            self.opt.patch_size,
            self.opt.patch_size,
        ]
        # Plucker maps carry the extra latent compression factors, so their
        # patches are larger than the latent-video patches.
        patch_size_plucker = [
            self.opt.latent_time_compression,
            self.opt.latent_spat_compression * self.opt.patch_size,
            self.opt.latent_spat_compression * self.opt.patch_size,
        ]

        if self.opt.time_embedding:
            if self.opt.time_embedding_vae:
                # Time codes share the latent-video layout and channel count.
                patch_size_time = patch_size_video
                time_embedding_dim = self.opt.num_latent_c
            else:
                patch_size_time = [
                    self.opt.latent_time_compression,
                    self.opt.patch_size,
                    self.opt.patch_size,
                ]
                time_embedding_dim = self.opt.time_embedding_dim
        # Decoder output stride per (t, h, w).
        self.stride_size_out = [
            self.opt.latent_time_compression // self.opt.patch_size_out_factor[0],
            self.opt.latent_spat_compression * self.opt.patch_size // self.opt.patch_size_out_factor[1],
            self.opt.latent_spat_compression * self.opt.patch_size // self.opt.patch_size_out_factor[2],
        ]

        if self.opt.use_rgb_decoder:
            self.padding_time = self.padding_plucker = (0, 0, 0)
            self.patch_size_extra_t = 0
        else:
            if self.opt.time_embedding_vae:
                self.padding_time = (0, 0, 0)
            else:
                # NOTE(review): fixed temporal padding of 2 — presumably tied
                # to patch_size_time's temporal extent; confirm.
                self.padding_time = (2, 0, 0)
            self.padding_plucker = (self.opt.latent_time_compression//2, 0, 0)
            # One extra temporal tap in the output kernel beyond its stride.
            self.patch_size_extra_t = 1
        self.patch_size_out = [self.stride_size_out[0] + self.patch_size_extra_t, self.stride_size_out[1], self.stride_size_out[2]]

        if self.opt.use_patch_embeddings_encoder:
            self.patch_embed = PatchEmbed3D(patch_size_video, self.opt.num_latent_c, self.opt.enc_embed_dim)

        if self.opt.plucker_embedding_vae:
            # Plucker maps already VAE-encoded: same layout as video latents,
            # and no padding (int 0 instead of a 3-tuple).
            patch_size_plucker = patch_size_video
            self.padding_plucker = 0

        if self.opt.plucker_embedding_vae_fuse_type == 'concat':
            num_plucker_in_channels = 2 * self.opt.num_latent_c
        else:
            # 6-channel ray parameterization.
            num_plucker_in_channels = 6
        self.patch_plucker_embed = PatchEmbed3D(patch_size_plucker, num_plucker_in_channels, self.opt.enc_embed_dim, zero_init=True, padding=self.padding_plucker)

        if self.opt.time_embedding:
            # Separate zero-initialized embeddings for input-frame and
            # target-frame time codes (added together in forward_gaussians).
            self.patch_time_embed = PatchEmbed3D(patch_size_time, time_embedding_dim, self.opt.enc_embed_dim, zero_init=True, padding=self.padding_time)
            self.patch_time_embed_tgt = PatchEmbed3D(patch_size_time, time_embedding_dim, self.opt.enc_embed_dim, zero_init=True, padding=self.padding_time)

        if self.opt.use_pos_embedding:
            self.pos_embedding = PositionalEmbedding(**self.opt.pos_embedding_kwargs)

        # Channels emitted per output location; extended by optional heads.
        self.output_dims = self.opt.output_dims

        # +1 channel: learned keep/drop logit for the sub-sampler.
        if self.opt.sub_sample_gaussians_type == 'learned':
            self.output_dims += 1

        # +3 channels: xyz offset added to the ray-based position.
        if self.opt.gaussians_predict_offset:
            self.output_dims += 3

        # Select the transposed-convolution implementation for the decoder.
        transposed_conv_kwargs = {}
        if self.opt.transposed_conv_type == 'factorized':
            transposed_conv_module = ConvTranspose3dFactorized
        elif self.opt.transposed_conv_type == 'reduce_transposed':
            transposed_conv_module = ConvTranspose3dReduced
            if self.opt.transposed_conv_hidden_channels is not None:
                transposed_conv_kwargs['hidden_channels'] = self.opt.transposed_conv_hidden_channels
        elif self.opt.transposed_conv_type == 'multi_stage_transpose':
            transposed_conv_module = MultiStageConvTranspose3d
        else:
            transposed_conv_module = nn.ConvTranspose3d

        if self.opt.use_cosmos_decoder:
            # Pretrained Cosmos decoder; the deconv below only bridges token
            # channels to the decoder's expected input channels.
            self.opt.decoder_cosmos_kwargs['out_channels'] = self.output_dims
            self.decoder_cosmos, tokenizer_config = load_cosmos_1_decoder(self.opt.vae_path, self.opt.decoder_cosmos_kwargs)
            deconv_out_channels = tokenizer_config['channels']

            # Remaining up-sampling the deconv must provide on top of the
            # Cosmos decoder's own temporal/spatial compression.
            temp_factor = int(self.opt.latent_time_compression/self.opt.decoder_cosmos_kwargs['temporal_compression'])
            spat_factor = int(self.opt.latent_spat_compression/self.opt.decoder_cosmos_kwargs['spatial_compression'])
            self.patch_size_out = [
                int(1 / self.opt.patch_size_out_factor[0] * temp_factor),
                int(self.opt.patch_size / self.opt.patch_size_out_factor[1] * spat_factor),
                int(self.opt.patch_size / self.opt.patch_size_out_factor[2] * spat_factor),
            ]
            self.stride_size_out = self.patch_size_out
            patch_size_out_deconv = self.patch_size_out
            stride_size_out_deconv = self.stride_size_out
            self.deconv = transposed_conv_module(self.blocks_out_channels, deconv_out_channels, patch_size_out_deconv, stride=stride_size_out_deconv, padding=0)
        else:
            if self.opt.use_patch_embeddings_encoder:
                # NOTE(review): no self.deconv is created when neither
                # use_cosmos_decoder nor use_patch_embeddings_encoder is set —
                # presumably such configs never reach the deconv path; confirm.
                self.padding_deconv = (self.opt.latent_time_compression//2, 0, 0)
                self.deconv = transposed_conv_module(self.blocks_out_channels, self.output_dims, self.patch_size_out, stride=self.stride_size_out, padding=self.padding_deconv, **transposed_conv_kwargs)

        # Recursively initialize every registered child (Linear / LayerNorm).
        for module_name, module in self.named_children():
            module.apply(self._init_weights)

        # Gaussian-splat renderer; the deferred variant renders in patches
        # (see the 'patch_size' kwarg passed in forward).
        if self.opt.deferred_bp:
            self.gs = GaussianRendererDeferred(opt)
        else:
            self.gs = GaussianRenderer(opt)

        # Activations mapping raw decoder channels to Gaussian parameters.
        # scale: exp(x - scale_shift) = exp(x - 1) * cap, clipped at the cap
        # (so a raw value of 1 hits the cap exactly).
        scale_cap = opt.gaussian_scale_cap
        scale_shift = 1 - math.log(scale_cap)
        self.scale_act = lambda x: torch.minimum(torch.exp(x-scale_shift),torch.tensor([scale_cap],device=x.device,dtype=x.dtype))
        # opacity: sigmoid biased towards transparency at init (shift by -2).
        self.opacity_act = lambda x: torch.sigmoid(x-2.0)
        # rotation: normalize the 4-vector to unit norm.
        self.rot_act = lambda x: F.normalize(x, dim=-1)
        # rgb: tanh squashed into [0, 1].
        self.rgb_act = lambda x: 0.5 * torch.tanh(x) + 0.5
        # Near/far bounds for placing Gaussians along camera rays.
        self.dnear = self.opt.dnear
        self.dfar = self.opt.dfar
|
|
|
|
|
    def forward_gaussians(self, images_input, plucker_embedding, rays_os, rays_ds, time_embeddings, num_input_multi_views=None):
        """Encode input latents and decode them into a set of 3D Gaussians.

        Args:
            images_input: latent video; assumed (b, v*t, c, h, w) before the
                multi-view fold below — TODO confirm with caller.
            plucker_embedding: per-pixel ray encoding, same leading layout.
            rays_os, rays_ds: ray origins / directions used by
                gaussian_processing to place each Gaussian.
            time_embeddings: per-frame time codes (see get_time_embedding).
            num_input_multi_views: number of views folded into the batch axis.

        Returns:
            (gaussians, x_mask): packed Gaussian parameters per token and the
            learned sub-sampling mask (None unless the 'learned' sub-sampler
            splits one off below).
        """
        # Fold the view axis into the batch so the encoder sees (b*v, t, ...).
        images_input = self.reshape_mv_temp_to_batch(images_input, num_input_multi_views=num_input_multi_views)
        plucker_embedding = self.reshape_mv_temp_to_batch(plucker_embedding, num_input_multi_views=num_input_multi_views)
        rays_os = self.reshape_mv_temp_to_batch(rays_os, num_input_multi_views=num_input_multi_views)
        rays_ds = self.reshape_mv_temp_to_batch(rays_ds, num_input_multi_views=num_input_multi_views)

        # V is the per-view frame count after the fold; h/w are token-grid sizes.
        B, V, C, H, W = images_input.shape
        h = int(H//self.opt.patch_size)
        w = int(W//self.opt.patch_size)

        # Tokenize: 3D patch embedding, or flatten latents directly to tokens.
        if self.opt.use_patch_embeddings_encoder:
            x = forward_checkpointing(self.patch_embed, images_input, gradient_checkpoint=self.opt.gradient_checkpoint_transformer)
        else:
            x = rearrange(images_input, 'b t c h w -> b (t h w) c')

        # Additive time conditioning for input and target frames.
        if self.opt.time_embedding and self.opt.get('use_time_embedding', True):
            x_time_emb, x_time_emb_tgt = self.get_time_embedding(time_embeddings, V, num_input_multi_views=num_input_multi_views)
            x = x + x_time_emb + x_time_emb_tgt

        # Additive camera-ray (Plucker) conditioning.
        if self.opt.use_plucker:
            x = x + forward_checkpointing(self.patch_plucker_embed, plucker_embedding, gradient_checkpoint=self.opt.gradient_checkpoint_transformer)

        # Joint multi-view processing: merge views back into the token axis
        # so attention spans all views.
        if self.opt.process_multi_views:
            x = self.reshape_mv_batch_to_temp(x, num_input_multi_views=num_input_multi_views)

        if self.opt.use_pos_embedding:
            x = self.pos_embedding(x)

        # Encoder trunk + final norm (optionally gradient-checkpointed).
        for blk_idx, blk in enumerate(self.enc_blocks):
            x = forward_checkpointing(blk, x , gradient_checkpoint=self.opt.gradient_checkpoint_transformer)
        x = forward_checkpointing(self.enc_norm, x, gradient_checkpoint=self.opt.gradient_checkpoint_transformer)

        # Optional channel reduction before decoding.
        if self.block_out is not None:
            x = forward_checkpointing(self.block_out, x, gradient_checkpoint=self.opt.gradient_checkpoint_transformer)

        # Undo the joint-view merge before the per-view decoder.
        if self.opt.process_multi_views:
            x = self.reshape_mv_temp_to_batch(x, num_input_multi_views=num_input_multi_views)

        # Tokens back to a dense (b, c, t, h, w) volume, then decode.
        x = rearrange(x, 'b (t h w) c -> b c t h w', h=h, w=w)
        if self.opt.use_patch_embeddings_encoder:
            x = forward_checkpointing(self.deconv, x, gradient_checkpoint=self.opt.gradient_checkpoint_conv)
        elif self.opt.use_cosmos_decoder:
            x = self.decoder_cosmos(x, gradient_checkpoint=self.opt.gradient_checkpoint_conv)

        # 'learned' sub-sampler: last channel is the keep/drop logit.
        if self.opt.sub_sample_gaussians_factor is not None:
            if self.opt.sub_sample_gaussians_type == 'learned':
                x, x_mask = x[:, :-1], x[:, [-1]]
            else:
                x_mask = None

        if self.opt.sub_sample_gaussians and self.opt.sub_sample_gaussians_factor is not None:
            x = forward_checkpointing(self.subsample_x_and_rays_wrapper, x, rays_os, rays_ds, x_mask, gradient_checkpoint=self.opt.gradient_checkpoint_conv)
        else:
            # No sub-sampling: flatten volume and rays to per-token rows.
            # NOTE(review): this branch resets x_mask to None even when the
            # 'learned' split above produced one — confirm that is intended.
            x = rearrange(x, 'b c t h w -> b (t h w) c')
            rays_os = rearrange(rays_os, 'b t c h w -> b (t h w) c')
            rays_ds = rearrange(rays_ds, 'b t c h w -> b (t h w) c')
            x_mask = None

        # Raw channels -> activated Gaussian parameters placed along rays.
        x = forward_checkpointing(self.gaussian_processing, x, rays_os, rays_ds, gradient_checkpoint=self.opt.gradient_checkpoint_conv)

        # Merge per-view Gaussian sets into one set per sample.
        if self.opt.fuse_multi_views or not self.training:
            x = self.reshape_mv_batch_to_temp(x, num_input_multi_views)

        # Drop low-opacity Gaussians (plus an optional random remainder).
        x = self.gaussian_pruning(x)
        return x, x_mask
|
|
|
|
|
def forward(self, data, skip_loss=False): |
|
|
|
|
|
results = {} |
|
|
loss = 0 |
|
|
|
|
|
rays_os = data['rays_os'] |
|
|
rays_ds = data['rays_ds'] |
|
|
plucker_embedding = data['plucker_embedding'] |
|
|
images = data['images_input_embed'] |
|
|
time_embeddings = data['time_embeddings'] |
|
|
cam_view = data['cam_view'] |
|
|
if len(cam_view.shape) == 5: |
|
|
|
|
|
cam_view = cam_view.squeeze(1) |
|
|
intrinsics = data['intrinsics'] |
|
|
num_input_multi_views = data['num_input_multi_views'] |
|
|
B = images.shape[0] |
|
|
|
|
|
|
|
|
gaussians, gaussians_mask = self.forward_gaussians(images, plucker_embedding, rays_os, rays_ds, time_embeddings=time_embeddings, num_input_multi_views=num_input_multi_views) |
|
|
|
|
|
|
|
|
add_gs_render_kwargs = {} |
|
|
if self.opt.deferred_bp: |
|
|
bg_color = [1, 1, 1] |
|
|
add_gs_render_kwargs = {'patch_size': self.opt.gs_render_patch_size} |
|
|
else: |
|
|
bg_color = torch.ones(3, dtype=gaussians.dtype, device=gaussians.device) |
|
|
|
|
|
|
|
|
results = self.gs.render(gaussians, cam_view, bg_color=bg_color, intrinsics=intrinsics, **add_gs_render_kwargs) |
|
|
|
|
|
|
|
|
if self.training: |
|
|
out = {} |
|
|
out['images_pred'] = results['images_pred'] |
|
|
out['depths_pred'] = results['depths_pred'] |
|
|
|
|
|
|
|
|
if self.opt.lambda_opacity > 0: |
|
|
opacity = gaussians[..., 3] |
|
|
else: |
|
|
opacity = None |
|
|
out["opacity_pred"] = opacity |
|
|
else: |
|
|
out = results |
|
|
out['gaussians_mask_pred'] = gaussians_mask |
|
|
out['gaussians'] = gaussians |
|
|
return out |
|
|
|
|
|
def gaussian_processing(self, x: torch.Tensor, rays_os: torch.Tensor, rays_ds: torch.Tensor): |
|
|
|
|
|
if self.opt.gaussians_predict_offset: |
|
|
pos_offset = x[..., -3:] |
|
|
x = x[..., :-3] |
|
|
distance, rgb, scaling, rotation, opacity = x.split([1, 3, 3, 4, 1], dim=-1) |
|
|
w = torch.sigmoid(distance + self.opt.pre_sigmoid_distance_shift) |
|
|
depths = self.dnear * (1 - w) + self.dfar * w |
|
|
pos = rays_os + rays_ds * depths |
|
|
|
|
|
|
|
|
if self.opt.gaussians_predict_offset and self.opt.use_gaussians_predict_offset: |
|
|
if self.opt.gaussians_predict_offset_act == 'clamp': |
|
|
pos_offset = pos_offset.clamp(self.opt.gaussians_predict_offset_range[0], self.opt.gaussians_predict_offset_range[1]) |
|
|
elif self.opt.gaussians_predict_offset_act == 'tanh': |
|
|
pos_offset = self.opt.gaussians_predict_offset_range[1] * torch.tanh(pos_offset) |
|
|
pos = pos + pos_offset |
|
|
|
|
|
|
|
|
opacity = self.opacity_act(opacity) |
|
|
scale = self.scale_act(scaling) |
|
|
rotation = self.rot_act(rotation) |
|
|
rgbs = self.rgb_act(rgb) |
|
|
gaussians = torch.cat([pos, opacity, scale, rotation, rgbs], dim=-1) |
|
|
return gaussians |
|
|
|
|
|
def get_time_embedding(self, time_embeddings: torch.Tensor, V: int, num_input_multi_views: int = None): |
|
|
|
|
|
if self.opt.time_embedding_vae: |
|
|
num_in_times = V |
|
|
else: |
|
|
time_embeddings = repeat(time_embeddings, 'b t c -> b t c h w', h=H, w=W) |
|
|
num_in_times = self.opt.num_input_views |
|
|
time_embeddings_input = time_embeddings[:, :num_in_times] |
|
|
time_embeddings_target = time_embeddings[:, num_in_times: num_in_times + 1] |
|
|
|
|
|
time_embeddings_target = repeat(time_embeddings_target, 'b 1 c h w -> b t c h w', t=num_in_times) |
|
|
|
|
|
x_time_emb = self.patch_time_embed(time_embeddings_input) |
|
|
x_time_emb_tgt = self.patch_time_embed_tgt(time_embeddings_target) |
|
|
|
|
|
x_time_emb = self.repeat_to_mv(x_time_emb, num_input_multi_views=num_input_multi_views) |
|
|
x_time_emb_tgt = self.repeat_to_mv(x_time_emb_tgt, num_input_multi_views=num_input_multi_views) |
|
|
return x_time_emb, x_time_emb_tgt |
|
|
|
|
|
def gaussian_pruning(self, gaussians: torch.Tensor): |
|
|
|
|
|
prune_ratio = self.opt.gaussians_prune_ratio |
|
|
if prune_ratio > 0: |
|
|
opacity = gaussians[:, :, [3]] |
|
|
num_gaussians = gaussians.shape[1] |
|
|
keep_ratio = 1 - prune_ratio |
|
|
random_ratio = self.opt.gaussians_random_ratio |
|
|
random_ratio = keep_ratio * random_ratio |
|
|
keep_ratio = keep_ratio - random_ratio |
|
|
num_keep = int(num_gaussians * keep_ratio) |
|
|
num_keep_random = int(num_gaussians * random_ratio) |
|
|
|
|
|
idx_sort = opacity.argsort(dim=1, descending=True) |
|
|
keep_idx = idx_sort[:, :num_keep] |
|
|
if num_keep_random > 0: |
|
|
rest_idx = idx_sort[:, num_keep:] |
|
|
random_idx = rest_idx[:, torch.randperm(rest_idx.shape[1])[:num_keep_random]] |
|
|
keep_idx = torch.cat([keep_idx, random_idx], dim=1) |
|
|
gaussians = gaussians.gather(1, keep_idx.expand(-1, -1, gaussians.shape[-1])) |
|
|
return gaussians |
|
|
|
|
|
def subsample_x_and_rays_wrapper(self, x, rays_os, rays_ds, x_mask): |
|
|
return subsample_x_and_rays( |
|
|
x, rays_os, rays_ds, x_mask, |
|
|
self.opt.sub_sample_gaussians_factor, |
|
|
self.opt.sub_sample_gaussians_type, |
|
|
self.opt.sub_sample_gaussians_type_tokens, |
|
|
self.opt.sub_sample_gaussians_temperature, |
|
|
self.training, |
|
|
) |
|
|
|
|
|
def _init_weights(self, m): |
|
|
from timm.models.layers import trunc_normal_ |
|
|
if isinstance(m, nn.Linear): |
|
|
trunc_normal_(m.weight, std=.02) |
|
|
if isinstance(m, nn.Linear) and m.bias is not None: |
|
|
nn.init.constant_(m.bias, 0) |
|
|
elif isinstance(m, nn.LayerNorm): |
|
|
nn.init.constant_(m.bias, 0) |
|
|
nn.init.constant_(m.weight, 1.0) |
|
|
|
|
|
def subsample_views(self, x: torch.Tensor, num_input_multi_views: int): |
|
|
x = rearrange(x, '(b v) ... -> b v ...', v=num_input_multi_views) |
|
|
x = x[:, :self.opt.num_target_multi_views] |
|
|
x = rearrange(x, 'b v ... -> (b v) ...') |
|
|
return x |
|
|
|
|
|
def reshape_mv_temp_to_batch(self, x, num_input_multi_views=None): |
|
|
if num_input_multi_views is None: |
|
|
num_input_multi_views = self.num_input_multi_views |
|
|
if num_input_multi_views != 1: |
|
|
if len(x.shape) == 5: |
|
|
x = rearrange(x, 'b (v t) c h w -> (b v) t c h w', v=num_input_multi_views) |
|
|
elif len(x.shape) == 3: |
|
|
x = rearrange(x, 'b (v d) c -> (b v) d c', v=num_input_multi_views) |
|
|
return x |
|
|
|
|
|
def reshape_mv_batch_to_temp(self, x, num_input_multi_views=None): |
|
|
if num_input_multi_views is None: |
|
|
num_input_multi_views = self.num_input_multi_views |
|
|
if num_input_multi_views != 1: |
|
|
if len(x.shape) == 5: |
|
|
x = rearrange(x, '(b v) t c h w -> b (v t) c h w', v=num_input_multi_views) |
|
|
elif len(x.shape) == 3: |
|
|
x = rearrange(x, '(b v) d c -> b (v d) c', v=num_input_multi_views) |
|
|
return x |
|
|
|
|
|
def reshape_mv_batch_to_mv(self, x, num_input_multi_views=None): |
|
|
if num_input_multi_views is None: |
|
|
num_input_multi_views = self.num_input_multi_views |
|
|
if num_input_multi_views != 1: |
|
|
if len(x.shape) == 5: |
|
|
x = rearrange(x, '(b v) t c h w -> b v t c h w', v=num_input_multi_views) |
|
|
elif len(x.shape) == 3: |
|
|
x = rearrange(x, '(b v) d c -> b v d c', v=num_input_multi_views) |
|
|
return x |
|
|
|
|
|
def reshape_mv_batch_to_view(self, x, num_input_multi_views=None): |
|
|
if num_input_multi_views is None: |
|
|
num_input_multi_views = self.num_input_multi_views |
|
|
if num_input_multi_views != 1: |
|
|
if len(x.shape) == 3: |
|
|
x = rearrange(x, '(b v) d c -> (b d) v c', v=num_input_multi_views) |
|
|
return x |
|
|
|
|
|
def repeat_to_mv(self, x, num_input_multi_views=None): |
|
|
if num_input_multi_views is None: |
|
|
num_input_multi_views = self.num_input_multi_views |
|
|
if num_input_multi_views != 1: |
|
|
if len(x.shape) == 5: |
|
|
x = repeat(x, 'b t c h w -> (b v) t c h w', v=num_input_multi_views) |
|
|
elif len(x.shape) == 3: |
|
|
x = repeat(x, 'b d c -> (b v) d c', v=num_input_multi_views) |
|
|
return x |