File size: 16,701 Bytes
78d2329
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
from dataclasses import dataclass
from functools import partial
from typing import Literal, List, Optional

import torch
from torch import Tensor

from optgs.dataset.data_types import BatchedViews
from optgs.misc.general_utils import get_expon_lr_func
from optgs.misc.io import FrequencyScheduler
from optgs.model.decoder.decoder import Decoder
from optgs.model.types import Gaussians
from optgs.scene_trainer.initializer import InitializerCfg
from optgs.scene_trainer.optimizer.layer import AdamInputSmoothing
from optgs.scene_trainer.optimizer.optimizer import (
    OptimizerInput,
    OptimizerOutput,
    OptimizerCfg, NonlearnedOptimizer,
)
from optgs.scene_trainer.optimizer.optimizer_utils import (
    calc_input_gradients,
    squeeze_grad_dict,
    smooth_grads,
)


@dataclass
class AdamOptimizerCfg(OptimizerCfg):
    """Configuration of the non-learned Adam optimizer (``AdamOptimizer``).

    Every per-parameter learning rate below is multiplied by ``base_lr`` before
    it is applied. The learning rate of the means is not constant: it comes
    from the schedule that ``get_expon_lr_func`` builds out of the
    ``means_lr_*`` fields, with the init/final values additionally multiplied
    by the scene scale.
    """
    name: Literal["adam"]

    # adam params
    # betas[0] / betas[1] are passed to AdamInputSmoothing as beta1 / beta2
    betas: List[float | int]  # Typically a list of two floats, e.g., [0.9, 0.999]
    eps: float  # passed to AdamInputSmoothing as eps
    # NOTE(review): weight_decay is not read by AdamOptimizer in the code visible
    # here - confirm whether it is applied elsewhere.
    weight_decay: float

    # learning rates
    base_lr: int | float  # global multiplier applied to all learning rates below
    means_lr_init: float  # schedule start value (multiplied by the scene scale)
    means_lr_final: float  # schedule end value (multiplied by the scene scale)
    means_lr_delay_mult: float  # forwarded as lr_delay_mult to get_expon_lr_func
    means_lr_max_steps: int  # should be equal to total optimization steps
    scales_lr: float
    rotations_lr: float
    opacities_lr: float
    sh0s_lr: float  # learning rate of the first spherical-harmonics coefficient
    shNs_lr: float  # 20 times smaller than sh0s_lr in the original paper

    def update(self, initializer_cfg: InitializerCfg) -> None:
        """Do nothing; ``initializer_cfg`` is not read by this config."""
        pass


class AdamOptimizer(NonlearnedOptimizer[AdamOptimizerCfg]):
    def __init__(
            self, cfg: AdamOptimizerCfg, save_every: Optional[FrequencyScheduler] = None
    ) -> None:
        """Store the configuration and reset all per-scene state to empty.

        Args:
            cfg: Adam hyper-parameters and learning rates.
            save_every: Optional frequency scheduler, forwarded to the base class.
        """
        super().__init__(cfg, save_every)

        # NOTE: AdamOptimizer is evaluation-only (3DGS baseline); not used during meta-training.

        # Buffers reused across steps (radii, visibility buffers).
        self._meta_bufs: dict = {}
        # Both are created when a scene starts and dropped again when it ends.
        self.means_lr_scheduler = None
        self.smoothers = None

    def _on_scene_start_impl(self, optimizer_input: OptimizerInput) -> None:
        """Create the per-scene Adam state.

        Builds one ``AdamInputSmoothing`` per parameter group, each shaped like
        the matching Gaussian tensor without its leading dimension, and the
        learning-rate schedule used for the means.

        Args:
            optimizer_input: Supplies the context views (camera tensors and
                ``"scene_scale"``) and the initial Gaussians via ``prev_output``.
        """
        super()._on_scene_start_impl(optimizer_input)

        # only a single scene per batch is handled
        views = optimizer_input.context
        assert (
                views["extrinsics"].shape[0] == views["intrinsics"].shape[0] == 1
        ), "scene batch size > 1 not supported yet..."

        start_gaussians = optimizer_input.prev_output.gaussians
        device = start_gaussians.means.device
        harmonics = start_gaussians.harmonics

        def make_smoother(reference: Tensor):
            # State is shaped like `reference` with the leading dimension dropped.
            return AdamInputSmoothing(
                shape=reference.shape[1:],
                beta1=self.cfg.betas[0],
                beta2=self.cfg.betas[1],
                eps=self.cfg.eps,
                device=device,
            )

        # The first SH coefficient and the remaining ones are smoothed separately;
        # when there is only one coefficient there is no "shNs" state at all.
        self.smoothers = {
            "means": make_smoother(start_gaussians.means),
            "scales": make_smoother(start_gaussians.scales),
            "rotations": make_smoother(start_gaussians.rotations),
            "opacities": make_smoother(start_gaussians.opacities),
            "sh0s": make_smoother(harmonics[..., :, :1]),
            "shNs": make_smoother(harmonics[..., :, 1:]) if harmonics.shape[-1] > 1 else None,
        }

        # scene extent; a missing value falls back to a unit scale
        extent = views["scene_scale"]
        if extent is None:
            extent = torch.ones(1, 1, device=device)
        extent = extent[0].item()

        # learning-rate schedule for the means, scaled by the scene extent
        self.means_lr_scheduler = get_expon_lr_func(
            lr_init=extent * self.cfg.means_lr_init,
            lr_final=extent * self.cfg.means_lr_final,
            lr_delay_mult=self.cfg.means_lr_delay_mult,
            max_steps=self.cfg.means_lr_max_steps,
        )

    def on_scene_end(self) -> None:
        """Release all per-scene state after the base class has finished its own teardown."""
        super().on_scene_end()
        # Empty the reusable buffers in place, then drop the smoothers and the schedule.
        self._meta_bufs.clear()
        self.means_lr_scheduler = None
        self.smoothers = None

    def _forward_impl(
            self,
            i: int,
            optimizer_input: OptimizerInput,
            optimizer_output: OptimizerOutput,
            full_context: BatchedViews,
            full_target: BatchedViews,
            **kwargs
    ) -> OptimizerOutput:
        """Run one optimization iteration and record its results.

        Performs one Adam update step, optionally applies densification /
        pruning (ADC), logs gradient statistics, stores the updated Gaussians
        for the next iteration and, when ``save_every`` fires for the "info"
        tag, appends deltas, gradients and learning rates to
        ``optimizer_output.info``.

        Args:
            i: Zero-based iteration index.
            optimizer_input: Per-iteration input; its ``prev_output.gaussians``
                is replaced with the updated Gaussians.
            optimizer_output: Output object, modified in place.
            full_context: Forwarded to ``_save_post_update_renders``.
            full_target: Forwarded to ``_save_post_update_renders``.
            **kwargs: ``sh_degree`` (optional) is forwarded to the update step.

        Returns:
            The same ``optimizer_output`` instance that was passed in.
        """
        # Timing
        self.iter_start.record()

        # Dimensions of the per-iteration context images
        batch_size, num_views, _, height, width = optimizer_input.context["image"].shape
        assert batch_size == 1, "Batch size > 1 not supported for post-processing"

        # Log number of gaussians
        self.nr_gaussians_log.append(
            optimizer_input.prev_output.gaussians.means.shape[1]
        )

        # One optimization step
        (
            gaussians,
            meta_for_adc,
            updates,
            grads_raw,
            normalized_grads,
            learning_rates,
        ) = self.apply_one_update_step(
            i, optimizer_input, optimizer_output, sh_degree=kwargs.get("sh_degree", None)
        )

        # Densification and Pruning
        if self.cfg.any_adc:
            self.apply_adc(
                i=i, v=num_views, h=height, w=width,
                adc_state=optimizer_input.prev_output.state.adc_state,
                gaussians=gaussians,
                meta=meta_for_adc,
                object_dict_to_adjust=self.smoothers
            )
            # ADC changes N -> cached buffers are invalid; re-make tensors as fresh leaves.
            # torch.cat (used by add_new/relocate) produces a non-leaf even with requires_grad=True,
            # so .grad is never populated by backward(). detach() cuts the grad_fn first.
            if self._meta_bufs['N'] != gaussians.means.shape[1]:
                self._meta_bufs.clear()
                # TODO Naama: need to think if the detach is necessary (was added during mcmc implementation)
                for attr in ("means", "scales", "rotations_unnorm", "opacities", "harmonics"):
                    setattr(gaussians, attr, getattr(gaussians, attr).detach().requires_grad_(True))

        # Timing
        self._record_iter_timing()

        # TODO Naama: we can log stats with save_every, but need to change stuff later.
        # Log stats: number of gaussians that received a non-zero gradient in any parameter
        if grads_raw is not None:  # and self.save_every(i + 1, tag="info"):
            G = grads_raw["means"].shape[0]
            per_param_nonzero = torch.stack(
                [(g.reshape(G, -1) != 0).any(dim=-1) for g in grads_raw.values() if g is not None]
            )  # [num_params, G]
            any_nonzero = per_param_nonzero.any(dim=0)  # [G]
            self.nr_nonzero_grad_log.append(any_nonzero.sum().item())

        # Save updated gaussians (for next iteration)
        optimizer_input.prev_output.gaussians = gaussians

        # Info
        if self.save_every(i + 1, tag="info"):

            # save gaussians
            optimizer_output.gaussian_list.append(gaussians, detach_and_cpu=True, save_to_disk=False, no_cache=False)

            assert optimizer_output.info is not None

            # log deltas, raw gradients and Adam-normalized gradients (CPU copies, None entries skipped)
            tensor_logs = (
                ("deltas", updates),
                ("grads", grads_raw),
                ("normalized_grads", normalized_grads),
            )
            for log_key, tensors in tensor_logs:
                if log_key not in optimizer_output.info:
                    optimizer_output.info[log_key] = []
                optimizer_output.info[log_key].append(
                    {name: t.cpu() for name, t in tensors.items() if t is not None}
                )

            # log learning rates
            if "learning_rates" not in optimizer_output.info:
                optimizer_output.info["learning_rates"] = []
            optimizer_output.info["learning_rates"].append(learning_rates)

            # Plot stats (disabled)
            # if self.cfg.any_adc:
            #     self.plot_info(
            #         i,
            #         output_path=kwargs.get("output_path", None),
            #         scene_name=kwargs.get("scene_name", None),
            #     )

        # Post-update context + target renders
        self._save_post_update_renders(
            i, optimizer_input, optimizer_output, gaussians,
            full_context, full_target,
        )

        # Optimizer output is being changed in place, but for clarity we return it
        return optimizer_output

    def apply_one_update_step(
            self, i, optimizer_input: OptimizerInput, optimizer_output: OptimizerOutput, sh_degree: int | None = None
    ) -> tuple[Gaussians, dict | None, dict, dict[str, Tensor], dict[str, Tensor], dict[str, float]]:
        """Run one Adam step on the Gaussian parameters, updating them in place.

        On the first step (``i == 0``) the stored scales and opacities are
        converted in place with ``log_`` / ``logit_`` and gradients are enabled
        on all parameter tensors. Gradients are obtained from
        ``calc_input_gradients``, passed through the per-parameter smoothers and
        applied as ``param += -lr * smoothed_grad``.

        Args:
            i: Zero-based step index; also the argument of the means LR schedule.
            optimizer_input: Supplies the context views, the renderer and the
                Gaussians (``prev_output.gaussians``).
            optimizer_output: Not read by this method.
            sh_degree: Forwarded to ``calc_input_gradients``.

        Returns:
            A tuple ``(gaussians, meta_for_adc, updates, grads_raw, grads_adam,
            learning_rates)``. ``updates`` maps each parameter group to the
            applied delta, or ``None`` when the group is frozen or absent;
            ``learning_rates`` maps each group to the learning rate of this step.

        Raises:
            AssertionError: If the scene batch size is not 1, if the scheduler or
                smoothers are not initialized, or if the Gaussians are in the
                wrong activation state for step ``i``.
        """
        iter_context = optimizer_input.context
        renderer = optimizer_input.renderer
        gaussians = optimizer_input.prev_output.gaussians

        # Check all preconditions before the Gaussians are mutated in place.
        # Both camera tensors must describe exactly one scene.
        assert (
                iter_context["extrinsics"].shape[0] == iter_context["intrinsics"].shape[0] == 1
        ), "scene batch size > 1 not supported for yet..."
        assert self.means_lr_scheduler is not None, "means_lr_scheduler is not initialized"
        assert self.smoothers is not None, "Smoothers not initialized"

        # if first iteration
        if i == 0:
            # assert gaussians stores activated values
            assert gaussians.stores_activated, "Gaussians must store activated values."
            # deactivate values in-place (avoids allocating new tensors)
            gaussians.scales.log_()  # [B, N, 3]
            gaussians.opacities.logit_()
            gaussians.stores_activated = False
            # enable requires_grad once - .grad buffers persist across steps,
            # so backward() reuses them instead of allocating new tensors each call
            gaussians.means.requires_grad_(True)
            gaussians.scales.requires_grad_(True)
            gaussians.rotations_unnorm.requires_grad_(True)
            gaussians.opacities.requires_grad_(True)
            gaussians.harmonics.requires_grad_(True)
        else:
            # assert gaussians does not store activated values
            assert not gaussians.stores_activated, "Gaussians must not store activated values."

        # learning rates
        # TODO Naama: use current cfg field lr_scheduler, which also defines the lr per param
        base_lr = self.cfg.base_lr
        learning_rates = {
            "means": self.means_lr_scheduler(i) * base_lr,
            "scales": self.cfg.scales_lr * base_lr,
            "rotations": self.cfg.rotations_lr * base_lr,
            "opacities": self.cfg.opacities_lr * base_lr,
            "sh0s": self.cfg.sh0s_lr * base_lr,
            "shNs": self.cfg.shNs_lr * base_lr,
        }
        # NOTE: scaling the learning rates by the number of views in the batch is disabled.

        # unpack gaussians
        means = gaussians.means  # [B, N, 3]
        rotations_unnorm = gaussians.rotations_unnorm  # [B, N, 4]
        scales_raw = gaussians.scales  # [B, N, 3]
        opacities_raw = gaussians.opacities  # [B, N]
        shs = gaussians.harmonics  # [B, N, 3, sh_d]

        self.decoder_event_start.record()
        _loss, grads_raw, meta_for_adc = calc_input_gradients(
            iter_context,
            means,
            scales_raw,
            rotations_unnorm,
            opacities_raw,
            shs,
            renderer,
            need_2d_grads=self.cfg.need_2d_grads,
            chunk_size=self.cfg.input_gradients_chunk_size,
            any_adc=self.cfg.any_adc,
            sh_degree=sh_degree,
            meta_bufs=self._meta_bufs,
            opacity_reg_lambda=self.cfg.opacity_reg_lambda,
        )
        self.decoder_event_end.record()

        # get updates from adam optimizer
        grads_raw = squeeze_grad_dict(grads_raw)
        grads_adam = smooth_grads(grads_raw, self.smoothers)

        # NOTE(review): the in-place steps below act on tensors that require grad,
        # which presumably relies on the caller running this under torch.no_grad() - confirm.

        # Contiguous parameter groups are scaled and stepped together with the
        # _foreach_* ops to reduce kernel launches. Frozen groups (no_refine_*)
        # are left out entirely, so their delta stays None.
        dense_groups = (
            ("means", means, self.cfg.no_refine_mean),
            ("scales", scales_raw, self.cfg.no_refine_scale),
            ("rotations", rotations_unnorm, self.cfg.no_refine_rotation),
            ("opacities", opacities_raw, self.cfg.no_refine_opacity),
        )
        trainable = [(name, param) for name, param, frozen in dense_groups if not frozen]
        dense_deltas: dict[str, Tensor] = {}
        if trainable:
            names = [name for name, _ in trainable]
            scaled = list(torch._foreach_mul(
                [grads_adam[name] for name in names],
                [-learning_rates[name] for name in names],
            ))
            torch._foreach_add_([param for _, param in trainable], scaled)
            dense_deltas = dict(zip(names, scaled))

        # SH deltas stay separate (non-contiguous slice views)
        delta_sh0s = None if self.cfg.no_refine_sh0 else -learning_rates["sh0s"] * grads_adam["sh0s"]
        delta_shNs = None
        if grads_adam["shNs"] is not None and not self.cfg.no_refine_shN:
            delta_shNs = -learning_rates["shNs"] * grads_adam["shNs"]
        self.safe_inplace_update(delta_sh0s, shs[..., 0:1])
        self.safe_inplace_update(delta_shNs, shs[..., 1:])

        # assign (means/scales/rotations/harmonics are the same objects; in-place ops above
        # already updated their storage. opacities_raw is a view - do NOT reassign
        # gaussians.opacities here, as that would replace the persistent leaf with a non-leaf
        # view and break retain_grad() on subsequent steps.)
        gaussians.means = means
        gaussians.scales = scales_raw
        gaussians.rotations_unnorm = rotations_unnorm
        gaussians.harmonics = shs

        # group updates (None marks a group that was not updated in this step)
        updates = {
            "means": dense_deltas.get("means"),
            "scales": dense_deltas.get("scales"),
            "rotations": dense_deltas.get("rotations"),
            "opacities": dense_deltas.get("opacities"),
            "sh0s": delta_sh0s,
            "shNs": delta_shNs,
        }

        return gaussians, meta_for_adc, updates, grads_raw, grads_adam, learning_rates

    @staticmethod
    def safe_inplace_update(delta_means: Tensor | None, means: Tensor):
        if delta_means is not None:
            means += delta_means