Spaces:
Sleeping
Sleeping
| from dataclasses import dataclass, fields | |
| import torch | |
| from jaxtyping import Float, Bool, Int64, BFloat16 | |
| from torch import Tensor | |
# eq=False keeps identity-based equality/hashing: the auto-generated __eq__
# would compare tensor fields with ``==`` and fail on multi-element tensors.
@dataclass(eq=False)
class Gaussians:
    """Batched container for the parameters and bookkeeping of a set of gaussians.

    All per-gaussian tensors share a leading ``batch`` axis followed by a
    ``gaussian`` axis. Fields listed in ``EXCLUDED_FROM_MASKING`` are not
    aligned with the ``gaussian`` axis (see the shape annotations) and are
    therefore treated specially when indexing or applying masked updates.

    Every method returns a new ``Gaussians`` instance; ``self`` is never
    modified.
    """

    means: Float[Tensor, "batch gaussian dim"]
    harmonics: Float[Tensor, "batch gaussian 3 d_sh"]
    opacities: Float[Tensor, "batch gaussian"]
    scales: Float[Tensor, "batch gaussian 3"]
    rotations_unnorm: Float[Tensor, "batch gaussian 4"]
    rotations: Float[Tensor, "batch gaussian 4"] | None = None
    covariances: Float[Tensor, "batch gaussian dim dim"] | None = None
    probabilities: Float[Tensor, "batch gaussian distr"] | None = None
    # mask: Bool[Tensor, "batch gaussian"] | None = None
    # Indices (along the gaussian axis) of the currently selected gaussians.
    sel: Int64[Tensor, "valid_gaussian_1"] | None = None
    filter_3D: Float[Tensor, "batch gaussian"] | None = None
    gradients: Float[Tensor, "batch valid_gaussian_1 total_dim"] | BFloat16[Tensor, "batch valid_gaussian_1 total_dim"] | None = None
    norm_gradients: Float[Tensor, "batch valid_gaussian_1 total_dim"] | BFloat16[Tensor, "batch valid_gaussian_1 total_dim"] | None = None
    # In case of predicting scale and mag, the raw deltas are 2*total_dim, cannot use SGD loss directly
    deltas: Float[Tensor, "batch valid_gaussian_2 d_delta"] | BFloat16[Tensor, "batch valid_gaussian_2 d_delta"] | None = None
    # visibility information at the end of the current batch for pruning
    visibility: Float[Tensor, "batch gaussian"] | None = None
    # aggregates visibility over epoch
    visibility_aggregator: Float[Tensor, "batch gaussian"] | None = None
    # whether scales and opacities are stored in activated form
    stores_activated: bool = True
    # the number of valid gaussians (without padding)
    nr_valid: int = -1

    # Unannotated on purpose so that it is a class attribute, not a dataclass field.
    # deltas are predicted to non masked values
    # NOTE(review): "valid_gaussians" matches no field of this class; kept as-is
    # in case external code inspects this set.
    EXCLUDED_FROM_MASKING = {"sel", "stores_activated", "deltas", "gradients", "norm_gradients", "valid_gaussians"}

    def _rebuild(self, transform) -> "Gaussians":
        """Create a new instance by applying ``transform(name, value)`` to each tensor field.

        ``None`` values and plain ``bool``/``int`` fields are copied unchanged.
        """
        values = {}
        for field in fields(self):
            value = getattr(self, field.name)
            # bool is a subclass of int, so a single isinstance check covers both.
            if value is None or isinstance(value, int):
                values[field.name] = value
            else:
                values[field.name] = transform(field.name, value)
        return Gaussians(**values)

    def to(self, device=None, dtype=None) -> "Gaussians":
        """Return a new instance with all tensors moved to ``device`` and/or cast to ``dtype``.

        ``dtype`` is applied to floating-point tensors only. Integer tensors
        such as ``sel`` are used for indexing and must stay integral, so they
        are only moved between devices.
        """
        def move(_name, tensor):
            target_dtype = dtype if tensor.is_floating_point() else None
            return tensor.to(device=device, dtype=target_dtype)

        return self._rebuild(move)

    def clone(self) -> "Gaussians":
        """Return a new instance in which every tensor has been cloned."""
        return self._rebuild(lambda _name, tensor: tensor.clone())

    def __getitem__(self, idx) -> "Gaussians":
        """Index every tensor field with ``idx`` (applied from the first axis).

        Fields in ``EXCLUDED_FROM_MASKING`` are reset to ``None`` in the result.
        """
        def pick(name, tensor):
            if name in self.EXCLUDED_FROM_MASKING:
                return None
            return tensor[idx]

        return self._rebuild(pick)

    def sample_subset(self, sampled_indices) -> "Gaussians":
        """Return the subset of gaussians at ``sampled_indices``.

        The indices are supplied by the caller and address the gaussian axis
        (axis 1). When a selection (``sel``) is active, the fields in
        ``EXCLUDED_FROM_MASKING`` are not aligned with the gaussian axis and
        cannot be sampled with these indices; they are reset to ``None``, as
        ``__getitem__`` does. Without a selection every tensor field is
        sampled along axis 1.
        """
        total_gaussians = self.means.shape[1]
        sample_num = len(sampled_indices)
        has_selection = self.sel is not None

        def pick(name, tensor):
            if has_selection and name in self.EXCLUDED_FROM_MASKING:
                return None
            return tensor[:, sampled_indices]

        subset = self._rebuild(pick)
        print(f"Sampled {sample_num} / {total_gaussians} gaussians.")
        return subset

    def __len__(self):
        """Return the size of the gaussian axis of ``means``."""
        return self.means.shape[1]

    def update_object_by_curr_mask(self, **new_values) -> "Gaussians":
        """Return a new instance with the given fields updated under the current selection.

        For each ``field=value`` keyword:

        * If ``sel`` is set and the field is aligned with the gaussian axis,
          ``value`` is written into a clone of the field at ``[:, sel, ...]``.
        * Otherwise (no selection, ``value`` is ``None``, the field is in
          ``EXCLUDED_FROM_MASKING``, or the field is a plain ``bool``/``int``)
          the field is replaced by ``value`` as a whole.
        * A field that is currently ``None`` can only be set if it is one of
          ``deltas``, ``gradients`` or ``norm_gradients``; for any other field
          a non-``None`` value trips an assertion.

        Keywords that do not name a field are ignored.
        """
        sel = self.sel
        values = {}
        for field in fields(self):
            name = field.name
            current = getattr(self, name)  # [B, G, ...]
            if name not in new_values:
                values[name] = current
                continue

            new_value = new_values[name]  # [B, G_valid, ...]
            if current is None:
                # Special case: allow updating deltas/gradients even if currently None.
                if name in ("deltas", "gradients", "norm_gradients"):
                    values[name] = new_value
                    continue
                assert new_value is None, f"Cannot update a None field! {name}, got {new_value}"
                values[name] = None
                continue

            replace_whole = (
                sel is None
                or new_value is None
                or name in self.EXCLUDED_FROM_MASKING
                # Plain scalars (e.g. nr_valid) have no gaussian axis to mask.
                or isinstance(current, int)
            )
            if replace_whole:
                values[name] = new_value
            else:
                # Clone first so the tensors held by ``self`` stay untouched.
                updated = current.clone()
                updated[:, sel, ...] = new_value
                values[name] = updated
        return Gaussians(**values)