# Lyra / src/models/recon/model_latent_recon.py
# Muhammad Taqi Raza
# adding lyra files
# af758d1
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import math
from functools import partial
import torch
import torch.nn as nn
import torch.nn.functional as F
from einops import rearrange, repeat
from src.models.utils.attention import PatchEmbed3D
from src.rendering.gs import GaussianRenderer
from src.rendering.gs_deferred import GaussianRendererDeferred
from src.models.utils.cosmos_1_tokenizer import load_cosmos_1_decoder
from src.models.utils.render import subsample_pixels_spatio_temporal, query_z_with_indices, subsample_x_and_rays
from src.models.utils.model import get_model_blocks, ConvTranspose3dFactorized, MultiStageConvTranspose3d, ConvTranspose3dReduced, forward_checkpointing, PositionalEmbedding
class LatentRecon(nn.Module):
    """Predict renderable 3D Gaussians from latent video tokens.

    The module embeds the input latents (plus optional time, Plucker-ray and
    positional conditioning) into tokens, runs them through the blocks
    returned by ``get_model_blocks``, decodes the tokens back to a dense grid
    with a transposed convolution and/or the Cosmos decoder, turns every grid
    cell into Gaussian parameters along its camera ray, and renders the result
    with ``GaussianRenderer`` / ``GaussianRendererDeferred``.
    """

    def __init__(
        self,
        opt,
    ):
        """Create all sub-modules described by ``opt``.

        Args:
            opt: Configuration object. Its fields are read through attribute
                access (``forward_gaussians`` additionally calls ``opt.get``).
                ``opt.decoder_cosmos_kwargs`` is updated in place when the
                Cosmos decoder is enabled.

        Raises:
            ValueError: If VAE-encoded Plucker embeddings are requested with a
                fuse type other than ``'concat'``.
        """
        super().__init__()
        self.opt = opt

        # Main blocks
        norm_layer = partial(nn.LayerNorm, eps=1e-6)
        self.enc_blocks = get_model_blocks(
            self.opt.enc_embed_dim,
            self.opt.enc_depth,
            self.opt.enc_num_heads,
            self.opt.mlp_ratio,
            self.opt.use_mamba,
            self.opt.llrm_7m1t,
            norm_layer,
            self.opt.use_qk_norm,
            self.opt.llrm_7m1t_index,
        )
        self.enc_norm = norm_layer(self.opt.enc_embed_dim)

        # Reduce number of channels after main blocks
        if self.opt.num_block_channels_reduce is not None:
            self.blocks_out_channels = self.opt.num_block_channels_reduce
            self.block_out = nn.Linear(self.opt.enc_embed_dim, self.blocks_out_channels)
        else:
            self.blocks_out_channels = self.opt.enc_embed_dim
            self.block_out = None

        # Patch dimensions
        patch_size_video = [
            self.opt.patch_size_temporal,
            self.opt.patch_size,
            self.opt.patch_size,
        ]
        patch_size_plucker = [
            self.opt.latent_time_compression,
            self.opt.latent_spat_compression * self.opt.patch_size,
            self.opt.latent_spat_compression * self.opt.patch_size,
        ]

        # If time embedded already with VAE, use same patchification as image latents
        if self.opt.time_embedding:
            if self.opt.time_embedding_vae:
                patch_size_time = patch_size_video
                time_embedding_dim = self.opt.num_latent_c
            else:
                patch_size_time = [
                    self.opt.latent_time_compression,
                    self.opt.patch_size,
                    self.opt.patch_size,
                ]
                time_embedding_dim = self.opt.time_embedding_dim

        self.stride_size_out = [
            self.opt.latent_time_compression // self.opt.patch_size_out_factor[0],
            self.opt.latent_spat_compression * self.opt.patch_size // self.opt.patch_size_out_factor[1],
            self.opt.latent_spat_compression * self.opt.patch_size // self.opt.patch_size_out_factor[2],
        ]

        # Patch embeddings for non-latent input (t h w -> t_latent h_latent/2 w_latent/2)
        if self.opt.use_rgb_decoder:
            self.padding_time = self.padding_plucker = (0, 0, 0)
            self.patch_size_extra_t = 0
        else:
            # If time embedded already with VAE, use same patchification as image latents
            if self.opt.time_embedding_vae:
                self.padding_time = (0, 0, 0)
            else:
                # NOTE(review): temporal padding is hard-coded to 2 here while
                # the Plucker padding below is latent_time_compression // 2 --
                # confirm the two are meant to differ.
                self.padding_time = (2, 0, 0)
            self.padding_plucker = (self.opt.latent_time_compression // 2, 0, 0)
            self.patch_size_extra_t = 1
        self.patch_size_out = [
            self.stride_size_out[0] + self.patch_size_extra_t,
            self.stride_size_out[1],
            self.stride_size_out[2],
        ]

        # Patch embeddings for video latents (t_latent h_latent w_latent -> t_latent h_latent/2 w_latent/2)
        if self.opt.use_patch_embeddings_encoder:
            self.patch_embed = PatchEmbed3D(patch_size_video, self.opt.num_latent_c, self.opt.enc_embed_dim)

        # Plucker
        if self.opt.plucker_embedding_vae:
            patch_size_plucker = patch_size_video
            self.padding_plucker = 0
            # Encoded rays_dxo and rays_d with vae separately
            fuse_type = self.opt.plucker_embedding_vae_fuse_type
            if fuse_type == 'concat':
                num_plucker_in_channels = 2 * self.opt.num_latent_c
            else:
                # Previously the channel count was simply left unbound here and
                # the constructor died with a NameError a few lines later.
                raise ValueError(
                    "plucker_embedding_vae requires plucker_embedding_vae_fuse_type == 'concat', "
                    f"got {fuse_type!r}"
                )
        else:
            num_plucker_in_channels = 6
        self.patch_plucker_embed = PatchEmbed3D(
            patch_size_plucker,
            num_plucker_in_channels,
            self.opt.enc_embed_dim,
            zero_init=True,
            padding=self.padding_plucker,
        )

        # Encode time conditionings
        if self.opt.time_embedding:
            self.patch_time_embed = PatchEmbed3D(
                patch_size_time, time_embedding_dim, self.opt.enc_embed_dim, zero_init=True, padding=self.padding_time
            )
            self.patch_time_embed_tgt = PatchEmbed3D(
                patch_size_time, time_embedding_dim, self.opt.enc_embed_dim, zero_init=True, padding=self.padding_time
            )

        # Positional embedding (if plucker is not used)
        if self.opt.use_pos_embedding:
            self.pos_embedding = PositionalEmbedding(**self.opt.pos_embedding_kwargs)

        # Output layers
        self.output_dims = self.opt.output_dims
        # Learn mask that prunes gaussians
        if self.opt.sub_sample_gaussians_type == 'learned':
            self.output_dims += 1
        # Learn position offsets
        if self.opt.gaussians_predict_offset:
            self.output_dims += 3

        # Set transposed conv decoding module
        transposed_conv_kwargs = {}
        if self.opt.transposed_conv_type == 'factorized':
            transposed_conv_module = ConvTranspose3dFactorized
        elif self.opt.transposed_conv_type == 'reduce_transposed':
            transposed_conv_module = ConvTranspose3dReduced
            if self.opt.transposed_conv_hidden_channels is not None:
                transposed_conv_kwargs['hidden_channels'] = self.opt.transposed_conv_hidden_channels
        elif self.opt.transposed_conv_type == 'multi_stage_transpose':
            transposed_conv_module = MultiStageConvTranspose3d
        else:
            transposed_conv_module = nn.ConvTranspose3d

        # Decode into RGB with cosmos decoder or just one deconv
        if self.opt.use_cosmos_decoder:
            self.opt.decoder_cosmos_kwargs['out_channels'] = self.output_dims
            self.decoder_cosmos, tokenizer_config = load_cosmos_1_decoder(
                self.opt.vae_path, self.opt.decoder_cosmos_kwargs
            )
            deconv_out_channels = tokenizer_config['channels']
            # Set up new patchification based on cosmos upsampling
            temp_factor = int(self.opt.latent_time_compression / self.opt.decoder_cosmos_kwargs['temporal_compression'])
            spat_factor = int(self.opt.latent_spat_compression / self.opt.decoder_cosmos_kwargs['spatial_compression'])
            self.patch_size_out = [
                int(1 / self.opt.patch_size_out_factor[0] * temp_factor),
                int(self.opt.patch_size / self.opt.patch_size_out_factor[1] * spat_factor),
                int(self.opt.patch_size / self.opt.patch_size_out_factor[2] * spat_factor),
            ]
            # Kernel and stride deliberately share the same list object here.
            self.stride_size_out = self.patch_size_out
            self.deconv = transposed_conv_module(
                self.blocks_out_channels,
                deconv_out_channels,
                self.patch_size_out,
                stride=self.stride_size_out,
                padding=0,
            )
        elif self.opt.use_patch_embeddings_encoder:
            self.padding_deconv = (self.opt.latent_time_compression // 2, 0, 0)
            self.deconv = transposed_conv_module(
                self.blocks_out_channels,
                self.output_dims,
                self.patch_size_out,
                stride=self.stride_size_out,
                padding=self.padding_deconv,
                **transposed_conv_kwargs,
            )

        # Initialize weights
        # NOTE(review): this visits every child registered so far, including
        # the loaded Cosmos decoder when enabled; only nn.Linear / nn.LayerNorm
        # layers are touched by _init_weights -- confirm that is intended.
        for module in self.children():
            module.apply(self._init_weights)

        # Gaussian renderer (created after the init loop, so it is not re-initialised)
        if self.opt.deferred_bp:
            self.gs = GaussianRendererDeferred(opt)
        else:
            self.gs = GaussianRenderer(opt)

        # Gaussian misc
        scale_cap = opt.gaussian_scale_cap
        scale_shift = 1 - math.log(scale_cap)
        # exp(x - scale_shift) capped at scale_cap; clamp avoids allocating a
        # fresh one-element tensor on every call.
        self.scale_act = lambda x: torch.exp(x - scale_shift).clamp(max=scale_cap)
        self.opacity_act = lambda x: torch.sigmoid(x - 2.0)
        self.rot_act = lambda x: F.normalize(x, dim=-1)
        self.rgb_act = lambda x: 0.5 * torch.tanh(x) + 0.5
        self.dnear = self.opt.dnear
        self.dfar = self.opt.dfar

    def forward_gaussians(self, images_input, plucker_embedding, rays_os, rays_ds, time_embeddings, num_input_multi_views=None):
        """Predict Gaussians from the input latents and their conditioning.

        Args:
            images_input: 5-D tensor unpacked as ``(B, V, C, H, W)`` after the
                multi-view axis has been folded into the batch.
            plucker_embedding: Input to ``patch_plucker_embed``; only used
                when ``opt.use_plucker`` is set.
            rays_os: Ray origins, laid out as ``b t c h w``.
            rays_ds: Ray directions, laid out as ``b t c h w``.
            time_embeddings: Forwarded to ``get_time_embedding`` when time
                conditioning is enabled.
            num_input_multi_views: Number of views folded into the second
                axis of the inputs; forwarded to the reshape helpers.

        Returns:
            Tuple ``(gaussians, x_mask)``. ``gaussians`` is ``[B, N, 14]``;
            ``x_mask`` is ``None`` unless Gaussian sub-sampling ran with a
            learned mask.
        """
        # Compute embeddings per view independently, reshape multi-view from temporal into batch
        images_input = self.reshape_mv_temp_to_batch(images_input, num_input_multi_views=num_input_multi_views)
        plucker_embedding = self.reshape_mv_temp_to_batch(plucker_embedding, num_input_multi_views=num_input_multi_views)
        rays_os = self.reshape_mv_temp_to_batch(rays_os, num_input_multi_views=num_input_multi_views)
        rays_ds = self.reshape_mv_temp_to_batch(rays_ds, num_input_multi_views=num_input_multi_views)

        _, V, _, H, W = images_input.shape
        h = int(H // self.opt.patch_size)
        w = int(W // self.opt.patch_size)
        ckpt_transformer = self.opt.gradient_checkpoint_transformer
        ckpt_conv = self.opt.gradient_checkpoint_conv

        # Patchify input images
        if self.opt.use_patch_embeddings_encoder:
            x = forward_checkpointing(self.patch_embed, images_input, gradient_checkpoint=ckpt_transformer)
        else:
            x = rearrange(images_input, 'b t c h w -> b (t h w) c')

        # Time embedding
        if self.opt.time_embedding and self.opt.get('use_time_embedding', True):
            # H/W are needed when the time embedding is not VAE-encoded and
            # has to be broadcast over the latent grid.
            x_time_emb, x_time_emb_tgt = self.get_time_embedding(
                time_embeddings, V, num_input_multi_views=num_input_multi_views, H=H, W=W
            )
            x = x + x_time_emb + x_time_emb_tgt

        # Add Plucker embeddings
        if self.opt.use_plucker:
            x = x + forward_checkpointing(self.patch_plucker_embed, plucker_embedding, gradient_checkpoint=ckpt_transformer)

        # Reshape views into THW dimension for joint attention across input views
        if self.opt.process_multi_views:
            x = self.reshape_mv_batch_to_temp(x, num_input_multi_views=num_input_multi_views)

        # Add positional embedding
        if self.opt.use_pos_embedding:
            x = self.pos_embedding(x)

        # Main blocks
        for blk in self.enc_blocks:
            x = forward_checkpointing(blk, x, gradient_checkpoint=ckpt_transformer)
        x = forward_checkpointing(self.enc_norm, x, gradient_checkpoint=ckpt_transformer)

        # Additional output block
        if self.block_out is not None:
            x = forward_checkpointing(self.block_out, x, gradient_checkpoint=ckpt_transformer)

        # Reshape multi-view dimension back to batch dimension to decode independently
        if self.opt.process_multi_views:
            x = self.reshape_mv_temp_to_batch(x, num_input_multi_views=num_input_multi_views)

        # Decode temporally and spatially into higher dimensions with convolutions
        x = rearrange(x, 'b (t h w) c -> b c t h w', h=h, w=w)
        if self.opt.use_patch_embeddings_encoder:
            x = forward_checkpointing(self.deconv, x, gradient_checkpoint=ckpt_conv)
        # Independent `if` (was `elif`): with the Cosmos decoder enabled the
        # deconv above only produces the decoder's input channels, so the
        # decoder still has to run to reach `output_dims` channels.
        if self.opt.use_cosmos_decoder:
            x = self.decoder_cosmos(x, gradient_checkpoint=ckpt_conv)

        # Get predicted subsampling mask (last channel) when it is learned
        x_mask = None
        if self.opt.sub_sample_gaussians_factor is not None and self.opt.sub_sample_gaussians_type == 'learned':
            x, x_mask = x[:, :-1], x[:, [-1]]

        # Subsample number of gaussians
        if self.opt.sub_sample_gaussians and self.opt.sub_sample_gaussians_factor is not None:
            subsampled = forward_checkpointing(
                self.subsample_x_and_rays_wrapper, x, rays_os, rays_ds, x_mask, gradient_checkpoint=ckpt_conv
            )
            # The result used to be bound to `x` alone, leaving the rays
            # un-subsampled. NOTE(review): subsample_x_and_rays is not visible
            # here; it is assumed to return (x, rays_os, rays_ds) optionally
            # followed by the mask -- confirm against its definition.
            x, rays_os, rays_ds, *maybe_mask = subsampled
            if maybe_mask:
                x_mask = maybe_mask[0]
        else:
            x = rearrange(x, 'b c t h w -> b (t h w) c')
            rays_os = rearrange(rays_os, 'b t c h w -> b (t h w) c')
            rays_ds = rearrange(rays_ds, 'b t c h w -> b (t h w) c')
            x_mask = None

        # Set up gaussian attributes
        x = forward_checkpointing(self.gaussian_processing, x, rays_os, rays_ds, gradient_checkpoint=ckpt_conv)

        # Merge viewpoints into one gaussian vector
        if self.opt.fuse_multi_views or not self.training:
            x = self.reshape_mv_batch_to_temp(x, num_input_multi_views)

        # Optionally prune gaussians as in Long LRM
        x = self.gaussian_pruning(x)
        return x, x_mask

    def forward(self, data, skip_loss=False):
        """Predict Gaussians for a batch and render them.

        Args:
            data: Mapping providing ``rays_os``, ``rays_ds``,
                ``plucker_embedding``, ``images_input_embed``,
                ``time_embeddings``, ``cam_view``, ``intrinsics`` and
                ``num_input_multi_views``.
            skip_loss: Accepted for interface compatibility; not used here.

        Returns:
            A dict. In training mode it holds ``images_pred``, ``depths_pred``
            and ``opacity_pred`` (``None`` unless ``opt.lambda_opacity > 0``);
            otherwise it is the renderer's full result dict. In both cases
            ``gaussians_mask_pred`` and ``gaussians`` are added.
        """
        cam_view = data['cam_view']
        if cam_view.dim() == 5:
            # B 1 T C D -> B T C D
            cam_view = cam_view.squeeze(1)

        # use the first view to predict gaussians
        gaussians, gaussians_mask = self.forward_gaussians(
            data['images_input_embed'],
            data['plucker_embedding'],
            data['rays_os'],
            data['rays_ds'],
            time_embeddings=data['time_embeddings'],
            num_input_multi_views=data['num_input_multi_views'],
        )  # [B, N, 14]

        # always use white background
        add_gs_render_kwargs = {}
        if self.opt.deferred_bp:
            bg_color = [1, 1, 1]
            add_gs_render_kwargs = {'patch_size': self.opt.gs_render_patch_size}
        else:
            bg_color = torch.ones(3, dtype=gaussians.dtype, device=gaussians.device)

        # render predictions
        results = self.gs.render(
            gaussians, cam_view, bg_color=bg_color, intrinsics=data['intrinsics'], **add_gs_render_kwargs
        )

        # output
        if self.training:
            out = {
                'images_pred': results['images_pred'],
                'depths_pred': results['depths_pred'],
                # opacity loss (channel 3 of the gaussian vector is opacity)
                'opacity_pred': gaussians[..., 3] if self.opt.lambda_opacity > 0 else None,
            }
        else:
            out = results
        out['gaussians_mask_pred'] = gaussians_mask
        out['gaussians'] = gaussians
        return out

    def gaussian_processing(self, x: torch.Tensor, rays_os: torch.Tensor, rays_ds: torch.Tensor):
        """Convert raw per-cell predictions into activated Gaussian parameters.

        Args:
            x: Raw predictions whose last axis holds
                ``[distance(1), rgb(3), scaling(3), rotation(4), opacity(1)]``,
                followed by a 3-channel position offset when
                ``opt.gaussians_predict_offset`` is set.
            rays_os: Ray origins, broadcastable against ``x[..., :3]``.
            rays_ds: Ray directions, same layout as ``rays_os``.

        Returns:
            Tensor whose last axis is
            ``[pos(3), opacity(1), scale(3), rotation(4), rgb(3)]``.
        """
        predict_offset = self.opt.gaussians_predict_offset

        # Pixel-aligned gaussians
        if predict_offset:
            x, pos_offset = x[..., :-3], x[..., -3:]
        distance, rgb, scaling, rotation, opacity = x.split([1, 3, 3, 4, 1], dim=-1)

        # Squash the predicted distance into [dnear, dfar] and walk along the ray.
        w = torch.sigmoid(distance + self.opt.pre_sigmoid_distance_shift)
        depths = self.dnear * (1 - w) + self.dfar * w
        pos = rays_os + rays_ds * depths

        # Add offsets to gaussian positions
        if predict_offset and self.opt.use_gaussians_predict_offset:
            offset_act = self.opt.gaussians_predict_offset_act
            if offset_act == 'clamp':
                lower, upper = (
                    self.opt.gaussians_predict_offset_range[0],
                    self.opt.gaussians_predict_offset_range[1],
                )
                pos_offset = pos_offset.clamp(lower, upper)
            elif offset_act == 'tanh':
                # Only the upper bound is used, i.e. the range is treated as symmetric.
                pos_offset = self.opt.gaussians_predict_offset_range[1] * torch.tanh(pos_offset)
            pos = pos + pos_offset

        # Activations
        opacity = self.opacity_act(opacity)
        scale = self.scale_act(scaling)
        rotation = self.rot_act(rotation)
        rgbs = self.rgb_act(rgb)

        return torch.cat([pos, opacity, scale, rotation, rgbs], dim=-1)  # [B, N, 14]

    def get_time_embedding(self, time_embeddings: torch.Tensor, V: int, num_input_multi_views: int = None, H: int = None, W: int = None):
        """Embed the input and target time conditioning into token space.

        Args:
            time_embeddings: Time conditioning. On the VAE path it is used as
                given (``b t c h w``); otherwise a 3-D ``b t c`` tensor is
                broadcast over an ``H x W`` grid first.
            V: Number of input time steps on the VAE path.
            num_input_multi_views: Forwarded to ``repeat_to_mv``.
            H: Grid height used to broadcast a 3-D embedding (non-VAE path).
            W: Grid width used to broadcast a 3-D embedding (non-VAE path).

        Returns:
            Tuple ``(x_time_emb, x_time_emb_tgt)`` of input and target time
            tokens, repeated for every view.

        Raises:
            ValueError: If a 3-D embedding has to be broadcast but ``H`` or
                ``W`` was not supplied.
        """
        # Split into input and target time embeddings
        if self.opt.time_embedding_vae:
            num_in_times = V
        else:
            if time_embeddings.dim() == 3:
                # H and W used to be read as (undefined) free names here.
                if H is None or W is None:
                    raise ValueError(
                        "get_time_embedding needs H and W to broadcast non-VAE time embeddings "
                        "of shape (b, t, c) over the spatial grid"
                    )
                time_embeddings = repeat(time_embeddings, 'b t c -> b t c h w', h=H, w=W)
            num_in_times = self.opt.num_input_views

        time_embeddings_input = time_embeddings[:, :num_in_times]
        time_embeddings_target = time_embeddings[:, num_in_times: num_in_times + 1]
        # Use the same target timestep for all input images
        time_embeddings_target = repeat(time_embeddings_target, 'b 1 c h w -> b t c h w', t=num_in_times)

        # Encode times
        x_time_emb = self.patch_time_embed(time_embeddings_input)
        x_time_emb_tgt = self.patch_time_embed_tgt(time_embeddings_target)

        # Repeat time embeddings for each view
        x_time_emb = self.repeat_to_mv(x_time_emb, num_input_multi_views=num_input_multi_views)
        x_time_emb_tgt = self.repeat_to_mv(x_time_emb_tgt, num_input_multi_views=num_input_multi_views)
        return x_time_emb, x_time_emb_tgt

    def gaussian_pruning(self, gaussians: torch.Tensor):
        """Drop low-opacity Gaussians, following Long LRM.

        A fraction ``opt.gaussians_prune_ratio`` of the Gaussians is removed.
        Of the kept budget, a share ``opt.gaussians_random_ratio`` is drawn at
        random from the non-top Gaussians; the rest are the highest-opacity
        ones. The same random selection is used for every batch element.

        Args:
            gaussians: ``[B, N, C]`` tensor with opacity in channel 3.

        Returns:
            The input unchanged when the prune ratio is not positive,
            otherwise a ``[B, K, C]`` tensor of the kept Gaussians.
        """
        prune_ratio = self.opt.gaussians_prune_ratio
        if prune_ratio <= 0:
            return gaussians

        num_gaussians = gaussians.shape[1]
        keep_budget = 1 - prune_ratio
        random_ratio = keep_budget * self.opt.gaussians_random_ratio
        top_ratio = keep_budget - random_ratio
        num_keep = int(num_gaussians * top_ratio)
        num_keep_random = int(num_gaussians * random_ratio)

        # rank by opacity; the trailing singleton axis is kept for gather below
        idx_sort = gaussians[:, :, [3]].argsort(dim=1, descending=True)
        keep_idx = idx_sort[:, :num_keep]
        if num_keep_random > 0:
            rest_idx = idx_sort[:, num_keep:]
            # Build the permutation on the tensor's own device instead of CPU.
            perm = torch.randperm(rest_idx.shape[1], device=gaussians.device)
            keep_idx = torch.cat([keep_idx, rest_idx[:, perm[:num_keep_random]]], dim=1)

        return gaussians.gather(1, keep_idx.expand(-1, -1, gaussians.shape[-1]))

    def subsample_x_and_rays_wrapper(self, x, rays_os, rays_ds, x_mask):
        """Call ``subsample_x_and_rays`` with the sub-sampling options from ``opt``.

        Exists so the call can be handed to ``forward_checkpointing`` with
        tensor arguments only; the current training flag is passed last.
        """
        return subsample_x_and_rays(
            x, rays_os, rays_ds, x_mask,
            self.opt.sub_sample_gaussians_factor,
            self.opt.sub_sample_gaussians_type,
            self.opt.sub_sample_gaussians_type_tokens,
            self.opt.sub_sample_gaussians_temperature,
            self.training,
        )

    def _init_weights(self, m):
        """Initialise a single module; intended for ``nn.Module.apply``.

        ``nn.Linear`` weights get a truncated normal (std 0.02) and zero bias;
        ``nn.LayerNorm`` gets unit weight and zero bias. Other module types,
        and missing (``None``) weight/bias parameters, are left untouched.
        """
        if isinstance(m, nn.Linear):
            # torch's built-in replaces the deprecated timm.models.layers import;
            # it uses the same algorithm and the same default bounds (-2, 2).
            nn.init.trunc_normal_(m.weight, std=.02)
            if m.bias is not None:
                nn.init.constant_(m.bias, 0)
        elif isinstance(m, nn.LayerNorm):
            # Either parameter is None when the norm is built without it.
            if m.bias is not None:
                nn.init.constant_(m.bias, 0)
            if m.weight is not None:
                nn.init.constant_(m.weight, 1.0)

    def subsample_views(self, x: torch.Tensor, num_input_multi_views: int):
        """Keep the first ``opt.num_target_multi_views`` views of a ``(b v) ...`` tensor."""
        per_view = rearrange(x, '(b v) ... -> b v ...', v=num_input_multi_views)
        per_view = per_view[:, :self.opt.num_target_multi_views]
        return rearrange(per_view, 'b v ... -> (b v) ...')

    # The helpers below move the multi-view axis between the batch axis and
    # the temporal/token axis. They handle 5-D (b t c h w) and 3-D (b d c)
    # tensors and return any other rank unchanged.
    # NOTE(review): their fallback `self.num_input_multi_views` is never
    # assigned inside this class; callers here always pass the value
    # explicitly -- confirm whether it is set elsewhere.

    def reshape_mv_temp_to_batch(self, x, num_input_multi_views=None):
        """Move views from the second axis into the batch: ``b (v .) -> (b v) .``."""
        if num_input_multi_views is None:
            num_input_multi_views = self.num_input_multi_views
        if num_input_multi_views == 1:
            return x
        if x.dim() == 5:
            return rearrange(x, 'b (v t) c h w -> (b v) t c h w', v=num_input_multi_views)
        if x.dim() == 3:
            return rearrange(x, 'b (v d) c -> (b v) d c', v=num_input_multi_views)
        return x

    def reshape_mv_batch_to_temp(self, x, num_input_multi_views=None):
        """Move views from the batch into the second axis: ``(b v) . -> b (v .)``."""
        if num_input_multi_views is None:
            num_input_multi_views = self.num_input_multi_views
        if num_input_multi_views == 1:
            return x
        if x.dim() == 5:
            return rearrange(x, '(b v) t c h w -> b (v t) c h w', v=num_input_multi_views)
        if x.dim() == 3:
            return rearrange(x, '(b v) d c -> b (v d) c', v=num_input_multi_views)
        return x

    def reshape_mv_batch_to_mv(self, x, num_input_multi_views=None):
        """Split views out of the batch into their own axis: ``(b v) . -> b v .``."""
        if num_input_multi_views is None:
            num_input_multi_views = self.num_input_multi_views
        if num_input_multi_views == 1:
            return x
        if x.dim() == 5:
            return rearrange(x, '(b v) t c h w -> b v t c h w', v=num_input_multi_views)
        if x.dim() == 3:
            return rearrange(x, '(b v) d c -> b v d c', v=num_input_multi_views)
        return x

    def reshape_mv_batch_to_view(self, x, num_input_multi_views=None):
        """Regroup a 3-D tensor so views form the token axis: ``(b v) d c -> (b d) v c``."""
        if num_input_multi_views is None:
            num_input_multi_views = self.num_input_multi_views
        if num_input_multi_views != 1 and x.dim() == 3:
            return rearrange(x, '(b v) d c -> (b d) v c', v=num_input_multi_views)
        return x

    def repeat_to_mv(self, x, num_input_multi_views=None):
        """Tile a per-sample tensor once per view along the batch: ``b . -> (b v) .``."""
        if num_input_multi_views is None:
            num_input_multi_views = self.num_input_multi_views
        if num_input_multi_views == 1:
            return x
        if x.dim() == 5:
            return repeat(x, 'b t c h w -> (b v) t c h w', v=num_input_multi_views)
        if x.dim() == 3:
            return repeat(x, 'b d c -> (b v) d c', v=num_input_multi_views)
        return x