File size: 5,814 Bytes
78d2329
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
from dataclasses import dataclass, fields

import torch
from jaxtyping import Float, Bool, Int64, BFloat16
from torch import Tensor


@dataclass
class Gaussians:
    """Batched set of gaussians, stored as one tensor per attribute.

    Tensor fields follow the axis layout given in their annotations, mostly
    ``[batch, gaussian, ...]``. ``sel`` is a 1-D index tensor into the gaussian
    axis (dim 1); ``update_object_by_curr_mask`` scatters values for the
    selected gaussians back through it. Optional fields are ``None`` when
    absent. ``stores_activated`` and ``nr_valid`` are plain Python scalars.
    """

    means: Float[Tensor, "batch gaussian dim"]
    harmonics: Float[Tensor, "batch gaussian 3 d_sh"]
    opacities: Float[Tensor, "batch gaussian"]
    scales: Float[Tensor, "batch gaussian 3"]
    rotations_unnorm: Float[Tensor, "batch gaussian 4"]
    rotations: Float[Tensor, "batch gaussian 4"] | None = None
    covariances: Float[Tensor, "batch gaussian dim dim"] | None = None
    probabilities: Float[Tensor, "batch gaussian distr"] | None = None
    # mask: Bool[Tensor, "batch gaussian"] | None = None
    sel: Int64[Tensor, "valid_gaussian_1"] | None = None
    filter_3D: Float[Tensor, "batch gaussian"] | None = None
    gradients: Float[Tensor, "batch valid_gaussian_1 total_dim"] | BFloat16[Tensor, "batch valid_gaussian_1 total_dim"] | None = None
    norm_gradients: Float[Tensor, "batch valid_gaussian_1 total_dim"] | BFloat16[Tensor, "batch valid_gaussian_1 total_dim"] | None = None
    deltas: Float[Tensor, "batch valid_gaussian_2 d_delta"] | BFloat16[Tensor, "batch valid_gaussian_2 d_delta"] | None = None  # In case of predicting scale and mag, the raw deltas are 2*total_dim, cannot use SGD loss directly
    visibility: Float[Tensor, "batch gaussian"] | None = None  # visibility information at the end of the current batch for pruning
    visibility_aggregator: Float[Tensor, "batch gaussian"] | None = None  # aggregates visibility over epoch
    stores_activated: bool = True  # whether scales and opacities are stored in activated form
    nr_valid: int = -1  # the number of valid gaussians (without padding)

    # Fields that __getitem__ drops (sets to None) and that update_object_by_curr_mask
    # replaces wholesale instead of scattering through ``sel``. Deltas are predicted
    # for the non-masked values only.
    # NOTE(review): "valid_gaussians" is not a field of this class - possibly stale.
    EXCLUDED_FROM_MASKING = {"sel", "stores_activated", "deltas", "gradients", "norm_gradients", "valid_gaussians"}

    @staticmethod
    def _is_scalar(value) -> bool:
        """Return True for the plain-Python fields (``stores_activated``, ``nr_valid``)."""
        # bool is a subclass of int, so this covers both scalar field types.
        return isinstance(value, int)

    @staticmethod
    def _remap_sel(sel: Tensor, sampled_indices, total_gaussians: int) -> Tensor:
        """Translate ``sel`` (indices into the full gaussian axis) into indices into the sampled subset.

        Returns the positions ``j`` in ``sampled_indices`` whose source gaussian
        ``sampled_indices[j]`` was contained in ``sel``.
        """
        if len(sampled_indices) == 0:
            return sel.new_empty(0)
        keep = torch.zeros(total_gaussians, dtype=torch.bool, device=sel.device)
        keep[sel] = True
        picked = keep[torch.as_tensor(sampled_indices, device=sel.device)]
        return picked.nonzero(as_tuple=True)[0]

    def to(self, device=None, dtype=None) -> "Gaussians":
        """Return a copy with tensors moved to ``device`` and/or cast to ``dtype``.

        Only floating-point tensors are cast to ``dtype``. Integer tensors such as
        the ``sel`` index tensor keep their dtype so they remain usable as indices.
        Scalar fields and ``None`` fields are passed through unchanged.
        """
        def move(value):
            if value is None or self._is_scalar(value):
                return value
            target_dtype = dtype if value.is_floating_point() else None
            return value.to(device=device, dtype=target_dtype)

        return Gaussians(**{f.name: move(getattr(self, f.name)) for f in fields(self)})

    def clone(self) -> "Gaussians":
        """Return a copy in which every tensor field is cloned; scalars and ``None`` are reused."""
        copied = {}
        for f in fields(self):
            value = getattr(self, f.name)
            copied[f.name] = value if value is None or self._is_scalar(value) else value.clone()
        return Gaussians(**copied)

    def __getitem__(self, idx) -> "Gaussians":
        """Apply ``idx`` to every tensor field (``tensor[idx]``).

        For a plain index this selects along the leading (batch) axis. Tensor
        fields listed in ``EXCLUDED_FROM_MASKING`` are set to ``None`` in the
        result. Scalar fields are copied unchanged.
        """
        indexed = {}
        for f in fields(self):
            value = getattr(self, f.name)
            if self._is_scalar(value):
                indexed[f.name] = value
            elif value is not None and f.name not in self.EXCLUDED_FROM_MASKING:
                indexed[f.name] = value[idx]
            else:
                indexed[f.name] = None
        return Gaussians(**indexed)

    def sample_subset(self, sampled_indices) -> "Gaussians":
        """Gather the gaussians at ``sampled_indices`` along the gaussian axis.

        The indices are chosen by the caller; this method only selects them.
        Tensor fields are indexed as ``tensor[:, sampled_indices]``. The 1-D
        ``sel`` index tensor is instead remapped to positions inside the subset.
        Scalar and ``None`` fields are passed through unchanged. Prints a summary
        line with the subset size.
        """
        total_gaussians = self.means.shape[1]
        sample_num = len(sampled_indices)

        subset = {}
        for f in fields(self):
            value = getattr(self, f.name)
            if value is None or self._is_scalar(value):
                subset[f.name] = value
            elif f.name == "sel":
                # sel is 1-D, so ``value[:, idx]`` would raise; re-express it relative to the subset.
                subset[f.name] = self._remap_sel(value, sampled_indices, total_gaussians)
            else:
                # NOTE(review): gradients / norm_gradients / deltas are annotated on a
                # valid-gaussian axis but are still indexed with ``sampled_indices``
                # here, as before - confirm this is intended when ``sel`` is set.
                subset[f.name] = value[:, sampled_indices]
        print(f"Sampled {sample_num} / {total_gaussians} gaussians.")
        return Gaussians(**subset)

    def __len__(self):
        """Return the size of the gaussian axis of ``means`` (every slot, unlike ``nr_valid``)."""
        return self.means.shape[1]

    def update_object_by_curr_mask(self, **new_values) -> "Gaussians":
        """Return a copy with the fields named in ``new_values`` replaced.

        For a tensor field not in ``EXCLUDED_FROM_MASKING``, when ``sel`` is set
        the new value covers only the selected gaussians. It is written into a
        clone of the existing tensor via ``tensor[:, sel, ...] = new_value``.
        In every other case the field is replaced wholesale: no ``sel``, a ``None``
        new value, an excluded field, or a scalar field such as ``nr_valid``.

        A field that is currently ``None`` may only be set for ``deltas``,
        ``gradients`` and ``norm_gradients``. For any other ``None`` field the new
        value must also be ``None``. Keys that are not field names are ignored.

        Raises:
            AssertionError: if a non-None value is given for a ``None`` field other
                than the three listed above.
        """
        sel = self.sel
        updated = {}
        for f in fields(self):
            name = f.name
            current = getattr(self, name)  # [B, G, ...] for gaussian-axis tensors
            if name not in new_values:
                updated[name] = current
                continue

            new_value = new_values[name]  # [B, G_valid, ...] when scattered through sel
            if current is None:
                if name not in ("deltas", "gradients", "norm_gradients"):
                    assert new_value is None, f"Cannot update a None field! {name}, got {new_value}"
                # Special case: deltas / gradients / norm_gradients may be set even if currently None.
                updated[name] = new_value
            elif (
                sel is None
                or new_value is None
                or name in self.EXCLUDED_FROM_MASKING
                or self._is_scalar(current)
            ):
                # Scalars (e.g. nr_valid) cannot be scattered - replace them directly.
                updated[name] = new_value
            else:
                scattered = current.clone()
                scattered[:, sel, ...] = new_value
                updated[name] = scattered
        return Gaussians(**updated)