Spaces:
Sleeping
Sleeping
File size: 5,814 Bytes
78d2329 | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 | from dataclasses import dataclass, fields
import torch
from jaxtyping import Float, Bool, Int64, BFloat16
from torch import Tensor
@dataclass
class Gaussians:
    """Batched per-Gaussian tensors plus optional auxiliary state.

    Tensor shapes follow the jaxtyping annotation on each field: ``batch`` is the
    leading axis and ``gaussian`` (dim 1) enumerates the Gaussians. Fields whose
    shape uses a ``valid_gaussian_*`` axis are laid out over a selected subset of
    the Gaussians rather than all of them; ``sel`` holds the indices of that subset
    along the ``gaussian`` axis (it is used as ``tensor[:, sel, ...]`` in
    ``update_object_by_curr_mask``).
    """

    means: Float[Tensor, "batch gaussian dim"]
    harmonics: Float[Tensor, "batch gaussian 3 d_sh"]
    opacities: Float[Tensor, "batch gaussian"]
    scales: Float[Tensor, "batch gaussian 3"]
    rotations_unnorm: Float[Tensor, "batch gaussian 4"]
    rotations: Float[Tensor, "batch gaussian 4"] | None = None
    covariances: Float[Tensor, "batch gaussian dim dim"] | None = None
    probabilities: Float[Tensor, "batch gaussian distr"] | None = None
    sel: Int64[Tensor, "valid_gaussian_1"] | None = None
    filter_3D: Float[Tensor, "batch gaussian"] | None = None
    gradients: Float[Tensor, "batch valid_gaussian_1 total_dim"] | BFloat16[Tensor, "batch valid_gaussian_1 total_dim"] | None = None
    norm_gradients: Float[Tensor, "batch valid_gaussian_1 total_dim"] | BFloat16[Tensor, "batch valid_gaussian_1 total_dim"] | None = None
    deltas: Float[Tensor, "batch valid_gaussian_2 d_delta"] | BFloat16[Tensor, "batch valid_gaussian_2 d_delta"] | None = None  # In case of predicting scale and mag, the raw deltas are 2*total_dim, cannot use SGD loss directly
    visibility: Float[Tensor, "batch gaussian"] | None = None  # visibility information at the end of the current batch for pruning
    visibility_aggregator: Float[Tensor, "batch gaussian"] | None = None  # aggregates visibility over epoch
    stores_activated: bool = True  # whether scales and opacities are stored in activated form
    nr_valid: int = -1  # the number of valid gaussians (without padding)

    # Class-level constant (unannotated, so NOT a dataclass field): names of fields
    # that are never indexed or scattered along the ``gaussian`` axis. ``sel`` and
    # the gradient/delta tensors are laid out over the selected subset rather than
    # over all Gaussians (deltas are predicted for the non-masked values only).
    # NOTE(review): "valid_gaussians" is not a field of this class — possibly stale.
    EXCLUDED_FROM_MASKING = {
        "sel",
        "stores_activated",
        "deltas",
        "gradients",
        "norm_gradients",
        "valid_gaussians",
    }

    @staticmethod
    def _is_plain_value(value) -> bool:
        """Return True for plain Python ``int``/``bool`` fields (``nr_valid``, ``stores_activated``)."""
        # ``bool`` subclasses ``int``, so one isinstance check covers both kinds.
        return isinstance(value, int)

    def to(self, device=None, dtype=None) -> "Gaussians":
        """Return a copy with tensors moved to ``device`` and floats cast to ``dtype``.

        Args:
            device: Target device; ``None`` keeps each tensor's current device.
            dtype: Target dtype for floating-point tensors; ``None`` keeps each
                tensor's dtype. Non-floating tensors (the Int64 ``sel`` index
                tensor) are never cast, so they remain valid index tensors.

        Returns:
            A new ``Gaussians``. ``None`` fields stay ``None`` and plain
            ``int``/``bool`` fields are passed through unchanged.
        """
        def convert(value):
            if value is None or self._is_plain_value(value):
                return value
            # Only floating-point tensors take the requested dtype; casting the
            # integer ``sel`` indices to a float dtype would break indexing.
            target_dtype = dtype if value.is_floating_point() else None
            return value.to(device=device, dtype=target_dtype)

        return Gaussians(**{f.name: convert(getattr(self, f.name)) for f in fields(self)})

    def clone(self) -> "Gaussians":
        """Return a copy in which every tensor field is ``.clone()``-d.

        ``None`` fields stay ``None``; plain ``int``/``bool`` fields are copied as-is.
        """
        def copy(value):
            if value is None or self._is_plain_value(value):
                return value
            return value.clone()

        return Gaussians(**{f.name: copy(getattr(self, f.name)) for f in fields(self)})

    def __getitem__(self, idx) -> "Gaussians":
        """Return a new ``Gaussians`` with ``tensor[idx]`` applied to every tensor field.

        Fields listed in ``EXCLUDED_FROM_MASKING`` are set to ``None`` because their
        layout does not follow the full tensors that ``idx`` refers to. Plain
        ``int``/``bool`` fields are copied unchanged.
        NOTE(review): ``nr_valid`` is copied as-is and may not match the indexed result.
        """
        new_tensors = {}
        for f in fields(self):
            value = getattr(self, f.name)
            if self._is_plain_value(value):
                new_tensors[f.name] = value
            elif value is None or f.name in self.EXCLUDED_FROM_MASKING:
                new_tensors[f.name] = None
            else:
                new_tensors[f.name] = value[idx]
        return Gaussians(**new_tensors)

    def sample_subset(self, sampled_indices) -> "Gaussians":
        """Return the Gaussians at ``sampled_indices`` by indexing dim 1 of every tensor.

        The indices are chosen by the caller; no random sampling happens here.
        ``None`` fields stay ``None`` and plain ``int``/``bool`` fields are copied
        unchanged. Prints how many of the Gaussians were kept.

        NOTE(review): unlike ``__getitem__`` this does not skip
        ``EXCLUDED_FROM_MASKING``. ``sel`` is 1-D, so ``sel[:, idx]`` raises when
        ``sel`` is set — confirm whether it should be dropped or remapped here.
        """
        total_gaussians = self.means.shape[1]
        sample_num = len(sampled_indices)
        new_tensors = {}
        for f in fields(self):
            value = getattr(self, f.name)
            if value is None or self._is_plain_value(value):
                new_tensors[f.name] = value
            else:
                new_tensors[f.name] = value[:, sampled_indices]
        print(f"Sampled {sample_num} / {total_gaussians} gaussians.")
        return Gaussians(**new_tensors)

    def __len__(self):
        """Return the number of Gaussians: the size of dim 1 of ``means``."""
        return self.means.shape[1]

    def update_object_by_curr_mask(self, **new_values) -> "Gaussians":
        """Return a copy with the fields named in ``new_values`` replaced.

        For each field named in ``new_values``:

        * If the current value is a tensor, ``self.sel`` is set, the new value is
          not ``None`` and the field is not in ``EXCLUDED_FROM_MASKING``, the new
          value (laid out over the selected subset) is written into a clone of the
          current tensor at ``[:, self.sel, ...]``.
        * Otherwise, if the current value is not ``None``, the new value replaces
          it outright. This includes plain ``int``/``bool`` fields such as
          ``nr_valid``, which have no ``gaussian`` axis to scatter into.
        * If the current value is ``None``, ``deltas``, ``gradients`` and
          ``norm_gradients`` may be set to anything; any other field may only be
          "updated" to ``None`` (an ``AssertionError`` is raised otherwise).

        Fields not named in ``new_values`` are shared with ``self`` (not copied),
        and keys that are not field names are ignored. The scatter always uses the
        ``sel`` stored on ``self``, even if ``new_values`` also supplies ``sel``.
        """
        sel = self.sel
        new_tensors = {}
        for f in fields(self):
            name = f.name
            current = getattr(self, name)  # [B, G, ...] for per-Gaussian tensors
            if name not in new_values:
                new_tensors[name] = current
                continue
            new_value = new_values[name]  # [B, G_valid, ...] when scattered through sel
            if current is None:
                if name not in ("deltas", "gradients", "norm_gradients"):
                    assert new_value is None, f"Cannot update a None field! {name}, got {new_value}"
                # Special case: deltas/gradients may be created even when currently None.
                new_tensors[name] = new_value
            elif (
                sel is None
                or new_value is None
                or name in self.EXCLUDED_FROM_MASKING
                or self._is_plain_value(current)
            ):
                # Plain ints/bools have no ``.clone()`` and no gaussian axis, so
                # they are replaced directly rather than scattered through ``sel``.
                new_tensors[name] = new_value
            else:
                updated = current.clone()
                updated[:, sel, ...] = new_value
                new_tensors[name] = updated
        return Gaussians(**new_tensors)
|