|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
""" |
|
|
GSLRM (Gaussian Splatting Large Reconstruction Model) |
|
|
|
|
|
This module implements a transformer-based model for generating 3D Gaussian splats |
|
|
from multi-view images. The model uses a combination of image tokenization, |
|
|
transformer processing, and Gaussian splatting for novel view synthesis. |
|
|
|
|
|
Classes: |
|
|
Renderer: Handles Gaussian splatting rendering operations |
|
|
GaussiansUpsampler: Converts transformer tokens to Gaussian parameters |
|
|
LossComputer: Computes various loss functions for training |
|
|
TransformTarget: Handles target image transformations (cropping, etc.) |
|
|
GSLRM: Main model class that orchestrates the entire pipeline |
|
|
""" |
|
|
|
|
|
import copy |
|
|
from typing import List, Optional, Tuple |
|
|
|
|
|
import lpips |
|
|
import numpy as np |
|
|
import torch |
|
|
import torch.nn as nn |
|
|
import torch.nn.functional as F |
|
|
from easydict import EasyDict as edict |
|
|
from einops import rearrange |
|
|
from einops.layers.torch import Rearrange |
|
|
|
|
|
|
|
|
from .gaussians_renderer import ( |
|
|
GaussianModel, |
|
|
deferred_gaussian_render, |
|
|
render_opencv_cam, |
|
|
) |
|
|
from .transform_data import SplitData, TransformInput, TransformTarget |
|
|
from .utils_transformer import ( |
|
|
TransformerBlock, |
|
|
_init_weights, |
|
|
) |
|
|
|
|
|
class Renderer(nn.Module):
    """
    Rasterize Gaussian splats into images for a batch of target cameras.

    ``forward`` exposes two paths:
      * ``deferred=True``  -> ``deferred_gaussian_render`` (the gradient-preserving,
        training-time path).
      * ``deferred=False`` -> a per-item / per-view loop over ``render_opencv_cam``
        (the inference-time path).
    """

    def __init__(self, config: edict):
        super().__init__()
        self.config = config

        gaussians_cfg = config.model.gaussians
        # Optional config entry; stays ``None`` when the config omits it.
        self.scaling_modifier = gaussians_cfg.get("scaling_modifier", None)
        self.gaussians_model = GaussianModel(gaussians_cfg.sh_degree, self.scaling_modifier)

        print(f"Renderer initialized with scaling_modifier: {self.scaling_modifier}")

    @torch.cuda.amp.custom_fwd(cast_inputs=torch.float32)
    def forward(
        self,
        xyz: torch.Tensor,
        features: torch.Tensor,
        scaling: torch.Tensor,
        rotation: torch.Tensor,
        opacity: torch.Tensor,
        height: int,
        width: int,
        C2W: torch.Tensor,
        fxfycxcy: torch.Tensor,
        deferred: bool = True,
    ) -> torch.Tensor:
        """
        Render Gaussian splats into images.

        Args:
            xyz: Gaussian centers.
            features: Spherical-harmonic color features.
            scaling: Per-Gaussian scale parameters.
            rotation: Per-Gaussian rotation quaternions.
            opacity: Per-Gaussian opacity parameters.
            height: Output image height in pixels.
            width: Output image width in pixels.
            C2W: Camera-to-world matrices, indexed [item, view].
            fxfycxcy: Intrinsics (fx, fy, cx, cy), indexed [item, view].
            deferred: Pick the deferred (gradient-preserving) renderer when True,
                otherwise the sequential per-view renderer.

        Returns:
            The rendered images produced by the selected path.
        """
        splats = (xyz, features, scaling, rotation, opacity)
        if not deferred:
            return self._render_sequential(*splats, height, width, C2W, fxfycxcy)
        return deferred_gaussian_render(
            *splats, height, width, C2W, fxfycxcy, self.scaling_modifier
        )

    def _render_sequential(
        self, xyz, features, scaling, rotation, opacity,
        height, width, C2W, fxfycxcy
    ) -> torch.Tensor:
        """Render each (item, view) pair one at a time into a float32 [B, V, 3, H, W] buffer."""
        num_items, num_views = C2W.size(0), C2W.size(1)
        output = torch.zeros(
            num_items, num_views, 3, height, width, dtype=torch.float32, device=xyz.device
        )

        for item in range(num_items):
            # Load this item's Gaussians into the shared model, then render every view.
            point_cloud = self.gaussians_model.set_data(
                xyz[item], features[item], scaling[item], rotation[item], opacity[item]
            )
            for view in range(num_views):
                result = render_opencv_cam(
                    point_cloud, height, width, C2W[item, view], fxfycxcy[item, view]
                )
                output[item, view] = result["render"]

        return output
|
|
|
|
|
|
|
|
class GaussiansUpsampler(nn.Module): |
|
|
""" |
|
|
Converts transformer output tokens to Gaussian splatting parameters. |
|
|
|
|
|
Takes high-dimensional transformer features and projects them to the |
|
|
concatenated Gaussian parameter space (xyz + features + scaling + rotation + opacity). |
|
|
""" |
|
|
|
|
|
def __init__(self, config: edict): |
|
|
super().__init__() |
|
|
self.config = config |
|
|
|
|
|
|
|
|
self.layernorm = nn.LayerNorm(config.model.transformer.d, bias=False) |
|
|
|
|
|
|
|
|
sh_dim = (config.model.gaussians.sh_degree + 1) ** 2 * 3 |
|
|
gaussian_param_dim = 3 + sh_dim + 3 + 4 + 1 |
|
|
|
|
|
|
|
|
upsample_factor = config.model.gaussians.upsampler.upsample_factor |
|
|
if upsample_factor > 1: |
|
|
raise NotImplementedError("GaussiansUpsampler only supports upsample_factor=1") |
|
|
|
|
|
|
|
|
self.linear = nn.Linear( |
|
|
config.model.transformer.d, |
|
|
gaussian_param_dim, |
|
|
bias=False, |
|
|
) |
|
|
|
|
|
def forward( |
|
|
self, |
|
|
gaussians: torch.Tensor, |
|
|
images: torch.Tensor |
|
|
) -> torch.Tensor: |
|
|
""" |
|
|
Convert transformer tokens to Gaussian parameters. |
|
|
|
|
|
Args: |
|
|
gaussians: Transformer output tokens for Gaussians |
|
|
images: Image tokens (unused but kept for compatibility) |
|
|
|
|
|
Returns: |
|
|
Raw Gaussian parameters (before conversion to final format) |
|
|
""" |
|
|
upsample_factor = self.config.model.gaussians.upsampler.upsample_factor |
|
|
if upsample_factor > 1: |
|
|
raise NotImplementedError("GaussiansUpsampler only supports upsample_factor=1") |
|
|
|
|
|
return self.linear(self.layernorm(gaussians)) |
|
|
|
|
|
def to_gs(self, gaussians: torch.Tensor) -> Tuple[torch.Tensor, ...]: |
|
|
""" |
|
|
Convert raw Gaussian parameters to final format. |
|
|
|
|
|
Args: |
|
|
gaussians: Raw Gaussian parameters [b, n_gaussians, param_dim] |
|
|
|
|
|
Returns: |
|
|
Tuple of (xyz, features, scaling, rotation, opacity) |
|
|
""" |
|
|
sh_dim = (self.config.model.gaussians.sh_degree + 1) ** 2 * 3 |
|
|
|
|
|
|
|
|
xyz, features, scaling, rotation, opacity = gaussians.split( |
|
|
[3, sh_dim, 3, 4, 1], dim=2 |
|
|
) |
|
|
|
|
|
|
|
|
features = features.reshape( |
|
|
features.size(0), |
|
|
features.size(1), |
|
|
(self.config.model.gaussians.sh_degree + 1) ** 2, |
|
|
3, |
|
|
) |
|
|
|
|
|
|
|
|
|
|
|
scaling = (scaling - 2.3).clamp(max=-1.20) |
|
|
|
|
|
|
|
|
opacity = opacity - 2.0 |
|
|
|
|
|
return xyz, features, scaling, rotation, opacity |
|
|
|
|
|
class GSLRM(nn.Module):
    """
    Gaussian Splatting Large Reconstruction Model.

    A transformer-based model that generates 3D Gaussian splats from multi-view images.
    The model processes input images through tokenization, transformer layers, and
    generates Gaussian parameters for novel view synthesis.

    Architecture:
        1. Patch-based tokenization of posed images (RGB + Plucker ray coordinates)
        2. Joint transformer processing of learned Gaussian tokens and image tokens
        3. Decoding Gaussian parameters from both token groups, with optional hard
           pixel alignment of the per-pixel Gaussians along their camera rays
        4. Rendering into the target views when target data is available
           (no loss is computed here; ``loss_metrics`` is returned as ``None``)
    """

    def __init__(self, config: edict):
        super().__init__()
        self.config = config

        # Input/target splitting and preprocessing.
        self._init_data_processors(config)

        # Network components. Creation order fixes parameter registration order.
        self._init_tokenizer(config)
        self._init_positional_embeddings(config)
        self._init_transformer(config)
        self._init_gaussian_modules(config)
        self._init_rendering_modules(config)

        # Bookkeeping that is not part of the network.
        self._init_training_state(config)

    @staticmethod
    def _gaussian_param_dim(sh_degree: int) -> int:
        """
        Return the length of one raw Gaussian parameter vector.

        Layout: xyz (3) + SH features ((sh_degree + 1) ** 2 * 3) + scaling (3)
        + rotation (4) + opacity (1), the same split used by GaussiansUpsampler.to_gs.
        """
        return 3 + (sh_degree + 1) ** 2 * 3 + 3 + 4 + 1

    @staticmethod
    def _origin_ray_parameter(ray_o: torch.Tensor, ray_d: torch.Tensor) -> torch.Tensor:
        """
        Return ``<-ray_o, ray_d>`` reduced over dim 2 (kept as a size-1 dim).

        For unit-length ``ray_d`` this is the ray parameter of the point closest to
        the world origin. NOTE(review): ray_d normalization is assumed, not checked.
        """
        return torch.sum(-ray_o * ray_d, dim=2, keepdim=True)

    @staticmethod
    def _nearest_points_to_origin(ray_o: torch.Tensor, ray_d: torch.Tensor) -> torch.Tensor:
        """Return ``ray_o + <-ray_o, ray_d> * ray_d`` (closest point to the origin for unit ray_d)."""
        return ray_o + GSLRM._origin_ray_parameter(ray_o, ray_d) * ray_d

    def _init_data_processors(self, config: edict) -> None:
        """Initialize data splitting and transformation modules."""
        self.data_splitter = SplitData(config)
        self.input_transformer = TransformInput(config)
        self.target_transformer = TransformTarget(config)

    def _init_tokenizer(self, config: edict) -> None:
        """
        Initialize the patchify + linear projection image tokenizer.

        Each ``patch_size x patch_size`` patch of every view becomes one token of
        width ``config.model.transformer.d``. Views are folded into the batch axis.
        """
        patch_size = config.model.image_tokenizer.patch_size
        input_channels = config.model.image_tokenizer.in_channels
        hidden_dim = config.model.transformer.d

        self.patch_embedder = nn.Sequential(
            Rearrange(
                "batch views channels (height patch_h) (width patch_w) -> (batch views) (height width) (patch_h patch_w channels)",
                patch_h=patch_size,
                patch_w=patch_size,
            ),
            nn.Linear(
                input_channels * (patch_size ** 2),
                hidden_dim,
                bias=False,
            ),
        )
        self.patch_embedder.apply(_init_weights)

    def _init_positional_embeddings(self, config: edict) -> None:
        """Initialize optional reference/source view markers and learned Gaussian tokens."""
        hidden_dim = config.model.transformer.d

        # Optional markers: row 0 is added to the first view, row 1 to all other
        # views (see _add_view_type_embeddings).
        self.view_type_embeddings = None
        if config.model.get("add_refsrc_marker", False):
            self.view_type_embeddings = nn.Parameter(
                torch.randn(2, hidden_dim)
            )
            nn.init.trunc_normal_(self.view_type_embeddings, std=0.02)

        # One learned token per Gaussian decoded by gaussian_upsampler.
        num_gaussians = config.model.gaussians.n_gaussians
        self.gaussian_position_embeddings = nn.Parameter(
            torch.randn(num_gaussians, hidden_dim)
        )
        nn.init.trunc_normal_(self.gaussian_position_embeddings, std=0.02)

    def _init_transformer(self, config: edict) -> None:
        """Initialize the input LayerNorm and the stack of transformer blocks."""
        hidden_dim = config.model.transformer.d
        head_dim = config.model.transformer.d_head
        num_layers = config.model.transformer.n_layer

        self.input_layer_norm = nn.LayerNorm(hidden_dim, bias=False)
        self.transformer_layers = nn.ModuleList([
            TransformerBlock(hidden_dim, head_dim)
            for _ in range(num_layers)
        ])
        self.transformer_layers.apply(_init_weights)

    def _init_gaussian_modules(self, config: edict) -> None:
        """
        Initialize the two Gaussian parameter heads.

        ``gaussian_upsampler`` decodes one Gaussian per learned Gaussian token.
        ``pixel_gaussian_decoder`` decodes ``patch_size ** 2`` Gaussians (one per
        pixel) from every image patch token.
        """
        hidden_dim = config.model.transformer.d
        patch_size = config.model.image_tokenizer.patch_size
        gaussian_param_dim = self._gaussian_param_dim(config.model.gaussians.sh_degree)

        self.gaussian_upsampler = GaussiansUpsampler(config)
        self.gaussian_upsampler.apply(_init_weights)

        self.pixel_gaussian_decoder = nn.Sequential(
            nn.LayerNorm(hidden_dim, bias=False),
            nn.Linear(
                hidden_dim,
                (patch_size ** 2) * gaussian_param_dim,
                bias=False,
            ),
        )
        self.pixel_gaussian_decoder.apply(_init_weights)

    def _init_rendering_modules(self, config: edict) -> None:
        """Initialize the Gaussian splatting renderer."""
        self.gaussian_renderer = Renderer(config)

    def _init_training_state(self, config: edict) -> None:
        """Initialize training-step bookkeeping and keep a deep copy of the config."""
        self.training_step = None
        self.training_start_step = None
        self.training_max_step = None
        self.original_config = copy.deepcopy(config)

    def _create_transformer_layer_runner(self, start_layer: int, end_layer: int):
        """
        Create a function that runs a contiguous range of transformer layers.

        Args:
            start_layer: Index of the first layer to run.
            end_layer: Index one past the last layer (clamped to the layer count).

        Returns:
            Function mapping a token sequence through layers [start_layer, end_layer).
        """
        def run_transformer_layers(token_sequence: torch.Tensor) -> torch.Tensor:
            for layer_idx in range(start_layer, min(end_layer, len(self.transformer_layers))):
                token_sequence = self.transformer_layers[layer_idx](token_sequence)
            return token_sequence

        return run_transformer_layers

    def _create_posed_images_with_plucker(self, input_data: edict) -> torch.Tensor:
        """
        Concatenate RGB (rescaled from [0, 1] to [-1, 1]) with ray-based coordinates.

        Channel layout along dim 2, selected by config flags (checked in this order):
          * ``use_custom_plucker``: rgb, ray_d, nearest point to origin
          * ``use_aug_plucker``:    rgb, ray_o x ray_d, ray_d, nearest point to origin
          * default (Plucker):      rgb, ray_o x ray_d, ray_d

        Args:
            input_data: Provides ``image`` [batch, views, channels, height, width]
                plus ``ray_o`` / ``ray_d`` (3 channels on dim 2 — assumed to share
                the image's batch/view/spatial dims).

        Returns:
            Posed images [batch, views, channels, height, width].
        """
        model_cfg = self.config.model
        ray_o, ray_d = input_data.ray_o, input_data.ray_d
        normalized_rgb = input_data.image[:, :, :3, :, :] * 2.0 - 1.0

        if model_cfg.get("use_custom_plucker", False):
            nearest_points = self._nearest_points_to_origin(ray_o, ray_d)
            return torch.cat([normalized_rgb, ray_d, nearest_points], dim=2)

        ray_cross_product = torch.cross(ray_o, ray_d, dim=2)

        if model_cfg.get("use_aug_plucker", False):
            nearest_points = self._nearest_points_to_origin(ray_o, ray_d)
            return torch.cat(
                [normalized_rgb, ray_cross_product, ray_d, nearest_points], dim=2
            )

        return torch.cat([normalized_rgb, ray_cross_product, ray_d], dim=2)

    def _add_view_type_embeddings(
        self,
        image_tokens: torch.Tensor,
        batch_size: int,
        num_views: int,
        num_patches: int,
        hidden_dim: int
    ) -> torch.Tensor:
        """Add marker row 0 to the first view's tokens and marker row 1 to every other view."""
        image_tokens = image_tokens.reshape(batch_size, num_views, num_patches, hidden_dim)

        reference_marker = self.view_type_embeddings[0]
        source_marker = self.view_type_embeddings[1]
        view_markers = torch.stack(
            [reference_marker] + [source_marker] * (num_views - 1), dim=0
        )

        # Broadcast [views, hidden] over the batch and patch axes.
        image_tokens = image_tokens + view_markers[None, :, None, :]
        return image_tokens.reshape(batch_size, num_views * num_patches, hidden_dim)

    def _process_through_transformer(
        self,
        gaussian_tokens: torch.Tensor,
        image_tokens: torch.Tensor
    ) -> torch.Tensor:
        """
        Run Gaussian and image tokens jointly through the transformer.

        Tokens are concatenated (Gaussians first) on dim 1 and layer-normed. They
        then pass through the layers in activation-checkpointed chunks of
        ``config.training.runtime.grad_checkpoint_every`` layers.

        Raises:
            ValueError: if the checkpoint interval is not positive. A negative
                value would otherwise silently skip every transformer layer.
        """
        from torch.utils.checkpoint import checkpoint

        combined_tokens = torch.cat((gaussian_tokens, image_tokens), dim=1)
        combined_tokens = self.input_layer_norm(combined_tokens)

        checkpoint_interval = self.config.training.runtime.grad_checkpoint_every
        if checkpoint_interval <= 0:
            raise ValueError(
                "training.runtime.grad_checkpoint_every must be a positive integer, "
                f"got {checkpoint_interval}"
            )

        num_layers = len(self.transformer_layers)
        for start_idx in range(0, num_layers, checkpoint_interval):
            layer_runner = self._create_transformer_layer_runner(
                start_idx, start_idx + checkpoint_interval
            )
            combined_tokens = checkpoint(
                layer_runner,
                combined_tokens,
                use_reentrant=False,
            )

        return combined_tokens

    def _apply_hard_pixel_alignment(
        self,
        pixel_aligned_xyz: torch.Tensor,
        input_data: edict
    ) -> torch.Tensor:
        """
        Place each pixel-aligned Gaussian on its camera ray: ``ray_o + depth * ray_d``.

        The predicted 3 channels (dim 2) are averaged into one pre-activation depth,
        shifted by ``depth_preact_bias`` and squashed with a sigmoid. The result is
        mapped to a depth range chosen by config flags. When ``clip_xyz`` is set
        and the config is not flagged ``inference``, positions are clamped to
        [-1, 1]. A missing ``inference`` key is treated as False.
        """
        model_cfg = self.config.model
        depth_bias = model_cfg.get("depth_preact_bias", 0.0)

        depth_values = torch.sigmoid(
            pixel_aligned_xyz.mean(dim=2, keepdim=True) + depth_bias
        )

        if (model_cfg.get("use_aug_plucker", False) or
                model_cfg.get("use_custom_plucker", False)):
            # +/-1.8 around the ray parameter of the point nearest the origin.
            depth_values = (2.0 * depth_values - 1.0) * 1.8 + self._origin_ray_parameter(
                input_data.ray_o, input_data.ray_d
            )
        elif (model_cfg.get("depth_min", -1.0) > 0.0 and
                model_cfg.get("depth_max", -1.0) > 0.0):
            # Linear map of (0, 1) onto (depth_min, depth_max).
            depth_min = model_cfg.depth_min
            depth_max = model_cfg.depth_max
            depth_values = depth_values * (depth_max - depth_min) + depth_min
        elif model_cfg.get("depth_reference_origin", False):
            # +/-1.8 around the distance of the ray origin from the world origin.
            ray_origin_norm = input_data.ray_o.norm(dim=2, p=2, keepdim=True)
            depth_values = (2.0 * depth_values - 1.0) * 1.8 + ray_origin_norm
        else:
            # Fixed range (1.2, 4.2).
            depth_values = (2.0 * depth_values - 1.0) * 1.5 + 2.7

        aligned_positions = input_data.ray_o + depth_values * input_data.ray_d

        if (model_cfg.get("clip_xyz", False) and
                not self.config.get("inference", False)):
            aligned_positions = aligned_positions.clamp(-1.0, 1.0)

        return aligned_positions

    def _create_gaussian_models_and_stats(
        self,
        xyz: torch.Tensor,
        features: torch.Tensor,
        scaling: torch.Tensor,
        rotation: torch.Tensor,
        opacity: torch.Tensor,
        num_pixel_aligned: int,
        num_views: int,
        height: int,
        width: int,
        patch_size: int
    ) -> Tuple[List, torch.Tensor, List[float]]:
        """
        Build one detached float32 GaussianModel per batch item and gather statistics.

        Returns:
            Tuple of:
              * gaussian_models: one model per batch item;
              * pixel_aligned_positions: [batch, views, 3, height, width] positions of
                the last ``num_pixel_aligned`` Gaussians of each model;
              * usage_statistics: per item, the fraction of Gaussians whose
                ``get_opacity`` exceeds 0.05.
        """
        gaussian_models = []
        pixel_aligned_positions_list = []
        usage_statistics = []
        opacity_threshold = 0.05

        for batch_idx in range(xyz.size(0)):
            # Empty the renderer's template model before copying it.
            # NOTE(review): presumably so data from an earlier call is not duplicated
            # by the deepcopy; empty()/set_data() are defined in gaussians_renderer.
            self.gaussian_renderer.gaussians_model.empty()
            gaussian_model = copy.deepcopy(self.gaussian_renderer.gaussians_model)

            gaussian_model = gaussian_model.set_data(
                xyz[batch_idx].detach().float(),
                features[batch_idx].detach().float(),
                scaling[batch_idx].detach().float(),
                rotation[batch_idx].detach().float(),
                opacity[batch_idx].detach().float(),
            )
            gaussian_models.append(gaussian_model)

            opacity_mask = gaussian_model.get_opacity > opacity_threshold
            usage_statistics.append((opacity_mask.sum() / opacity_mask.numel()).item())

            # The pixel-aligned Gaussians are stored last; fold them back into images.
            pixel_xyz = gaussian_model.get_xyz[-num_pixel_aligned:, :]
            pixel_xyz_reshaped = rearrange(
                pixel_xyz,
                "(views height width patch_h patch_w) coords -> views coords (height patch_h) (width patch_w)",
                views=num_views,
                height=height // patch_size,
                width=width // patch_size,
                patch_h=patch_size,
                patch_w=patch_size,
            )
            pixel_aligned_positions_list.append(pixel_xyz_reshaped)

        pixel_aligned_positions = torch.stack(pixel_aligned_positions_list, dim=0)

        return gaussian_models, pixel_aligned_positions, usage_statistics

    def forward(
        self,
        batch_data: edict,
        create_visual: bool = False,
        split_data: bool = True
    ) -> edict:
        """
        Forward pass of the GSLRM model.

        Args:
            batch_data: Input batch containing:
                - image: Multi-view images [batch, views, channels, height, width]
                - fxfycxcy: Camera intrinsics [batch, views, 4]
                - c2w: Camera-to-world matrices [batch, views, 4, 4]
            create_visual: Accepted for interface compatibility; not used here.
            split_data: Whether to split the batch into input and target views.

        Returns:
            edict with keys ``input``, ``target`` (None when not split),
            ``gaussians`` (per-item GaussianModel list), ``pixelalign_xyz``,
            ``img_tokens``, ``loss_metrics`` (always None) and ``render``
            (None when there is no target data).
        """
        # Data preparation does not need gradients.
        with torch.no_grad():
            target_data = None
            if split_data:
                batch_data, target_data = self.data_splitter(
                    batch_data, self.config.training.dataset.target_has_input
                )
                target_data = self.target_transformer(target_data)
            input_data = self.input_transformer(batch_data)

        posed_images = self._create_posed_images_with_plucker(input_data)
        batch_size, num_views, _, height, width = posed_images.size()

        # Tokenize: patch_embedder folds views into the batch axis; unfold them.
        image_patch_tokens = self.patch_embedder(posed_images)
        _, num_patches, hidden_dim = image_patch_tokens.size()
        image_patch_tokens = image_patch_tokens.reshape(
            batch_size, num_views * num_patches, hidden_dim
        )

        if self.view_type_embeddings is not None:
            image_patch_tokens = self._add_view_type_embeddings(
                image_patch_tokens, batch_size, num_views, num_patches, hidden_dim
            )

        gaussian_tokens = self.gaussian_position_embeddings.expand(batch_size, -1, -1)
        combined_tokens = self._process_through_transformer(
            gaussian_tokens, image_patch_tokens
        )

        num_gaussians = self.config.model.gaussians.n_gaussians
        gaussian_tokens, image_patch_tokens = combined_tokens.split(
            [num_gaussians, num_views * num_patches], dim=1
        )

        # Decode the learned-token Gaussians and the per-pixel Gaussians.
        gaussian_params = self.gaussian_upsampler(gaussian_tokens, image_patch_tokens)
        pixel_aligned_gaussian_params = self.pixel_gaussian_decoder(image_patch_tokens)

        gaussian_param_dim = self._gaussian_param_dim(self.config.model.gaussians.sh_degree)
        pixel_aligned_gaussian_params = pixel_aligned_gaussian_params.reshape(
            batch_size, -1, gaussian_param_dim
        )
        num_pixel_aligned_gaussians = pixel_aligned_gaussian_params.size(1)

        # Pixel-aligned Gaussians are appended after the learned-token Gaussians.
        all_gaussian_params = torch.cat((gaussian_params, pixel_aligned_gaussian_params), dim=1)
        xyz, features, scaling, rotation, opacity = self.gaussian_upsampler.to_gs(all_gaussian_params)

        # Fold the pixel-aligned positions back into image layout.
        patch_size = self.config.model.image_tokenizer.patch_size
        pixel_aligned_xyz = rearrange(
            xyz[:, -num_pixel_aligned_gaussians:, :],
            "batch (views height width patch_h patch_w) coords -> batch views coords (height patch_h) (width patch_w)",
            views=num_views,
            height=height // patch_size,
            width=width // patch_size,
            patch_h=patch_size,
            patch_w=patch_size,
        )

        if self.config.model.hard_pixelalign:
            pixel_aligned_xyz = self._apply_hard_pixel_alignment(
                pixel_aligned_xyz, input_data
            )

        pixel_aligned_xyz_flat = rearrange(
            pixel_aligned_xyz,
            "batch views coords (height patch_h) (width patch_w) -> batch (views height width patch_h patch_w) coords",
            patch_h=patch_size,
            patch_w=patch_size,
        )

        xyz = torch.cat(
            (xyz[:, :-num_pixel_aligned_gaussians, :], pixel_aligned_xyz_flat),
            dim=1
        )

        rendered_images = None
        if target_data is not None:
            target_height, target_width = target_data.image.size(3), target_data.image.size(4)
            rendered_images = self.gaussian_renderer(
                xyz, features, scaling, rotation, opacity,
                target_height, target_width,
                C2W=target_data.c2w,
                fxfycxcy=target_data.fxfycxcy,
            )

        gaussian_models, pixel_aligned_positions, usage_statistics = self._create_gaussian_models_and_stats(
            xyz, features, scaling, rotation, opacity,
            num_pixel_aligned_gaussians, num_views, height, width, patch_size
        )

        return edict(
            input=input_data,
            target=target_data,
            gaussians=gaussian_models,
            pixelalign_xyz=pixel_aligned_positions,
            img_tokens=image_patch_tokens,
            loss_metrics=None,
            render=rendered_images,
        )