# Learn2Splat / optgs/model/types.py
# Last commit: 78d2329 by SteEsp (verified) - "Add Docker-based Learn2Splat demo (viser GUI)"
from dataclasses import dataclass, fields
import torch
from jaxtyping import Float, Bool, Int64, BFloat16
from torch import Tensor
@dataclass
class Gaussians:
    """Batched set of 3D Gaussians plus per-step optimisation state.

    Shapes follow the jaxtyping annotations: most tensor fields are laid out
    as ``[batch, gaussian, ...]``. ``sel`` is a 1-D index tensor over the
    Gaussian axis; ``gradients`` / ``norm_gradients`` share their second axis
    with it (``valid_gaussian_1``), while ``deltas`` has its own
    ``valid_gaussian_2`` axis. ``stores_activated`` and ``nr_valid`` are plain
    Python scalars, not tensors.
    """

    means: Float[Tensor, "batch gaussian dim"]
    harmonics: Float[Tensor, "batch gaussian 3 d_sh"]
    opacities: Float[Tensor, "batch gaussian"]
    scales: Float[Tensor, "batch gaussian 3"]
    rotations_unnorm: Float[Tensor, "batch gaussian 4"]
    rotations: Float[Tensor, "batch gaussian 4"] | None = None
    covariances: Float[Tensor, "batch gaussian dim dim"] | None = None
    probabilities: Float[Tensor, "batch gaussian distr"] | None = None
    sel: Int64[Tensor, "valid_gaussian_1"] | None = None
    filter_3D: Float[Tensor, "batch gaussian"] | None = None
    gradients: Float[Tensor, "batch valid_gaussian_1 total_dim"] | BFloat16[Tensor, "batch valid_gaussian_1 total_dim"] | None = None
    norm_gradients: Float[Tensor, "batch valid_gaussian_1 total_dim"] | BFloat16[Tensor, "batch valid_gaussian_1 total_dim"] | None = None
    # In case of predicting scale and mag, the raw deltas are 2*total_dim, so the SGD loss cannot be used on them directly.
    deltas: Float[Tensor, "batch valid_gaussian_2 d_delta"] | BFloat16[Tensor, "batch valid_gaussian_2 d_delta"] | None = None
    visibility: Float[Tensor, "batch gaussian"] | None = None  # visibility at the end of the current batch, used for pruning
    visibility_aggregator: Float[Tensor, "batch gaussian"] | None = None  # visibility aggregated over the epoch
    stores_activated: bool = True  # whether scales and opacities are stored in activated form
    nr_valid: int = -1  # the number of valid gaussians (without padding)

    # Fields that must never be sliced along the full Gaussian axis: the
    # index tensor ``sel`` itself and the tensors that only cover the valid
    # (unmasked) Gaussians -- deltas are predicted for unmasked values only.
    # Deliberately left unannotated so @dataclass treats it as a class
    # attribute rather than as a field.
    EXCLUDED_FROM_MASKING = {"sel", "stores_activated", "deltas", "gradients", "norm_gradients", "valid_gaussians"}

    def to(self, device=None, dtype=None) -> "Gaussians":
        """Return a copy with every tensor field moved to ``device`` / ``dtype``.

        ``dtype`` is applied to floating-point tensors only (the same policy as
        ``torch.nn.Module.to``): casting the Int64 index tensor ``sel`` to a
        float dtype would make it unusable for indexing. ``None`` fields and
        Python scalars are passed through unchanged.
        """
        new_values = {}
        for f in fields(self):
            value = getattr(self, f.name)
            if isinstance(value, torch.Tensor):
                target_dtype = dtype if value.is_floating_point() else None
                value = value.to(device=device, dtype=target_dtype)
            new_values[f.name] = value
        return type(self)(**new_values)

    def clone(self) -> "Gaussians":
        """Return a copy in which every tensor field is ``.clone()``-d.

        ``None`` fields and Python scalars are copied by reference, since they
        are immutable.
        """
        new_values = {}
        for f in fields(self):
            value = getattr(self, f.name)
            new_values[f.name] = value.clone() if isinstance(value, torch.Tensor) else value
        return type(self)(**new_values)

    def __getitem__(self, idx) -> "Gaussians":
        """Apply ``tensor[idx]`` to every maskable tensor field.

        Fields listed in ``EXCLUDED_FROM_MASKING`` are set to ``None`` in the
        result, because they would no longer line up with the indexed
        tensors. ``None`` fields and Python scalars are copied unchanged.
        """
        new_values = {}
        for f in fields(self):
            value = getattr(self, f.name)
            if not isinstance(value, torch.Tensor):
                new_values[f.name] = value
            elif f.name in self.EXCLUDED_FROM_MASKING:
                new_values[f.name] = None
            else:
                new_values[f.name] = value[idx]
        return type(self)(**new_values)

    def sample_subset(self, sampled_indices) -> "Gaussians":
        """Keep only the Gaussians at ``sampled_indices`` along the Gaussian axis (dim 1).

        The indices are chosen by the caller (e.g. at random); this method
        only gathers them, then prints how many Gaussians were kept.

        Fields in ``EXCLUDED_FROM_MASKING`` are indexed over the valid subset
        rather than over all Gaussians. They are gathered like the other
        fields only when ``sel`` is ``None``. Otherwise they are set to
        ``None``, as ``__getitem__`` does: the 1-D ``sel`` cannot be indexed
        as ``[:, idx]``, and the old selection no longer applies to the
        subset.
        """
        total_gaussians = self.means.shape[1]
        sample_num = len(sampled_indices)
        has_selection = self.sel is not None
        new_values = {}
        for f in fields(self):
            value = getattr(self, f.name)
            if not isinstance(value, torch.Tensor):
                new_values[f.name] = value  # None or a Python scalar
            elif has_selection and f.name in self.EXCLUDED_FROM_MASKING:
                new_values[f.name] = None
            else:
                new_values[f.name] = value[:, sampled_indices]
        print(f"Sampled {sample_num} / {total_gaussians} gaussians.")
        return type(self)(**new_values)

    def __len__(self):
        """Return the size of the Gaussian axis (dim 1 of ``means``).

        This presumably includes padding; see ``nr_valid``.
        """
        return self.means.shape[1]

    def update_object_by_curr_mask(self, **new_values) -> "Gaussians":
        """Return a new ``Gaussians`` with the given fields replaced.

        For a tensor field that is currently set, the new value is expected to
        cover only the valid Gaussians (``[B, G_valid, ...]``). It is scattered
        into a clone of the full tensor at positions ``self.sel``. The value
        replaces the field wholesale instead when ``sel`` is ``None``, when the
        value is ``None``, when the field is in ``EXCLUDED_FROM_MASKING``, or
        when the field is a Python scalar (e.g. ``nr_valid``).

        A field that is currently ``None`` may only be set if it is ``deltas``,
        ``gradients`` or ``norm_gradients``. For any other ``None`` field the
        new value must also be ``None`` (checked with ``assert``).

        Fields not mentioned in ``new_values`` are shared, not copied, with
        ``self``. Keys that are not field names are ignored.
        """
        sel = self.sel
        updated = {}
        for f in fields(self):
            name = f.name
            current = getattr(self, name)  # [B, G, ...] for per-Gaussian tensors
            if name not in new_values:
                updated[name] = current
                continue
            new_value = new_values[name]  # [B, G_valid, ...] for per-Gaussian tensors
            if current is None:
                if name in ("deltas", "gradients", "norm_gradients"):
                    # Special case: these may be populated even if previously unset.
                    updated[name] = new_value
                else:
                    assert new_value is None, f"Cannot update a None field! {name}, got {new_value}"
                    updated[name] = None
            elif (
                sel is None
                or new_value is None
                or name in self.EXCLUDED_FROM_MASKING
                or not isinstance(current, torch.Tensor)
            ):
                # Scalars (bool/int) cannot be scattered; replace them directly.
                updated[name] = new_value
            else:
                scattered = current.clone()
                scattered[:, sel, ...] = new_value
                updated[name] = scattered
        return type(self)(**updated)