File size: 17,864 Bytes
78d2329
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
from abc import ABC
from dataclasses import dataclass, field
from pathlib import Path
from typing import TypeVar, Generic, Optional, TYPE_CHECKING, Any
import torch
from matplotlib import pyplot as plt
from torch import nn
from torch import Tensor
import numpy as np
import os
from optgs.dataset.camera_datasets.camera import get_scene_scale
from optgs.misc.io import FrequencyScheduler
from optgs.dataset.data_types import BatchedViews
from optgs.model.decoder import Decoder
from optgs.model.decoder.decoder import DecoderOutput
from optgs.model.types import Gaussians
from optgs.scene_trainer.adc.base import BaseStrategyCfg
from optgs.scene_trainer.initializer.initializer import InitializerOutput
from optgs.scene_trainer.optimizer.layer import AdamState
from optgs.scene_trainer.initializer import InitializerCfg
from optgs.misc.detaching_cpu_list import DetachingCPUList
from optgs.scene_trainer.optimizer.lr_scheduler import LrSchedulerCfgType, get_scheduler

if TYPE_CHECKING:
    from optgs.scene_trainer.adc.vanilla import VanillaStrategyState
    from optgs.scene_trainer.adc.mcmc import McmcStrategyState


@dataclass
class OptimizerState:
    """Per-scene optimizer state; every field defaults to None.

    The base ``Optimizer`` creates an empty instance in ``_on_scene_start_impl``
    when any adaptive density control (ADC) step is enabled, and
    ``initialize_adc_state`` then fills ``adc_state``.
    """
    state: torch.Tensor | None = None
    init_state: torch.Tensor | None = None  # state at the beginning of the optimization
    adam_state: AdamState | None = None
    adc_state: Any = None  # VanillaStrategyState | McmcStrategyState | None


@dataclass
class OptimizerPreviousOutput:
    """Optimizer-side replacement for the initializer output in ``OptimizerInput.prev_output``.

    ``Optimizer._on_scene_start_impl`` builds the first instance from a clone of
    the initializer's Gaussians, with ``state=None`` (or an empty
    ``OptimizerState`` when ADC is enabled).
    """
    gaussians: Gaussians
    state: OptimizerState | None = None


@dataclass
class OptimizerInput:
    """Inputs to an optimizer call for one scene.

    ``prev_output`` starts out as the initializer's output and is swapped for an
    ``OptimizerPreviousOutput`` by ``Optimizer._on_scene_start_impl``.
    """
    context: BatchedViews
    renderer: Decoder
    prev_output: InitializerOutput | OptimizerPreviousOutput
    num_refine: int  # NOTE(review): presumably the number of refinement iterations — confirm
    iter_batch_size: int | None  # forwarded as iter_batch_size to renderer.forward_batch_subset
    target: BatchedViews | None = None
    context_remain: dict | None = None
    debug_dict: dict | None = None
    additional_info: tuple | None = None

    @property
    def device(self) -> torch.device:
        """Device holding the context images (``context["image"]``)."""
        return self.context["image"].device


@dataclass
class OptimizerOutput:
    # TODO Naama: should we add here iterations?
    gaussian_list: DetachingCPUList[Gaussians]
    t: int | None = None
    T: int | None = None
    last_prev_output: OptimizerPreviousOutput | None = None
    target_render_list: DetachingCPUList[DecoderOutput] | None = None
    context_render_list: DetachingCPUList[DecoderOutput] | None = None
    info: dict | None = None
    context_index_list: list[int] = field(default_factory=list)
    target_index_list: list[int] = field(default_factory=list)

    def get_render_list(self, which: str) -> DetachingCPUList[DecoderOutput] | None:
        if which == "target":
            return self.target_render_list
        elif which == "context":
            return self.context_render_list
        else:
            raise ValueError(f"Unknown which: {which}, should be 'target' or 'context'")

    def get_index_list(self, which: str):
        if which == "target":
            return self.target_index_list
        elif which == "context":
            return self.context_index_list
        else:
            raise ValueError(f"Unknown which: {which}, should be 'target' or 'context'")

    @classmethod
    def empty(cls, t=None) -> "OptimizerOutput":
        new = cls(gaussian_list=DetachingCPUList(), t=t)
        new.target_render_list = DetachingCPUList()
        new.context_render_list = DetachingCPUList()
        # info is a dict of lists of dicts, should all be stored in cpu
        new.info: dict[str, list[dict[str, Tensor]]] = {}
        return new


@dataclass
class OptimizerCfg:
    """Configuration shared by every optimizer.

    Carries the per-attribute freeze flags, the LR scheduler config, the
    density-control (refiner) config, gradient chunking, and the MCMC opacity
    regularization weight.
    """

    # subset optimization flags
    no_refine_mean: bool
    no_refine_scale: bool
    no_refine_rotation: bool
    no_refine_opacity: bool
    no_refine_sh0: bool
    no_refine_shN: bool

    # lr scheduler
    lr_scheduler: LrSchedulerCfgType

    refiner: BaseStrategyCfg

    # gradients
    input_gradients_chunk_size: int | None  # if None, use full image

    # L1 opacity regularization from 3DGS-MCMC (arXiv:2404.09591); 0.0 to disable
    opacity_reg_lambda: float

    def update(self, initializer_cfg: InitializerCfg):
        """Hook for adapting this config to the initializer config; a no-op here."""

    @property
    def any_adc(self) -> bool:
        """Whether the refiner enables densification, pruning or opacity reset."""
        refiner = self.refiner
        return (
            refiner.do_densify
            or refiner.do_prune
            or refiner.do_opacity_reset
        )

    @property
    def need_2d_grads(self) -> bool:
        """2D mean gradients are only needed when densification is on."""
        return self.refiner.do_densify

    @property
    def optimize_all(self):
        """True when no Gaussian attribute is frozen (every no_refine_* flag is off)."""
        frozen_flags = (
            self.no_refine_mean,
            self.no_refine_scale,
            self.no_refine_rotation,
            self.no_refine_opacity,
            self.no_refine_sh0,
            self.no_refine_shN,
        )
        return all(not flag for flag in frozen_flags)


# Type variable for the config type stored in ``Optimizer.cfg``.
T = TypeVar("T")


class Optimizer(nn.Module, ABC, Generic[T]):
    """Base class for optimizers that refine a scene's Gaussians.

    Subclasses implement ``_forward_impl``. The base class owns the config, an
    optional ``save_every`` scheduler that gates render saving outside training,
    CUDA events for per-iteration timing, per-iteration statistic logs (plotted
    by ``plot_info``), and the LR scheduler built from ``cfg.lr_scheduler``.
    """

    cfg: T

    def __init__(self, cfg: T, save_every: Optional[FrequencyScheduler] = None) -> None:
        super().__init__()
        self.cfg = cfg
        self.save_every = save_every

        # for timing
        self.iter_start = torch.cuda.Event(enable_timing=True)
        self.iter_end = torch.cuda.Event(enable_timing=True)
        # decoder_event_start/end bracket only the rendering-for-gradients call inside
        # apply_one_update_step, letting us split iter_time into decoder vs optimizer.
        self.decoder_event_start = torch.cuda.Event(enable_timing=True)
        self.decoder_event_end = torch.cuda.Event(enable_timing=True)
        # scene_start_event_start/end bracket optimizer.on_scene_start() (KNN, Adam init).
        # Read after the post-loop cuda.synchronize() in scene_trainer.get_optimized_gaussians.
        self.scene_start_event_start = torch.cuda.Event(enable_timing=True)
        self.scene_start_event_end = torch.cuda.Event(enable_timing=True)

        # Init logs for densification/pruning and timing. Call the base
        # implementation directly (not self.reset_logs) so that a subclass
        # override never runs before the subclass's own __init__ has finished.
        Optimizer.reset_logs(self)

        # LR scheduler
        self.scheduler = get_scheduler(self.cfg.lr_scheduler)

    def forward(self, i, optimizer_input: OptimizerInput, optimizer_output: OptimizerOutput, **kwargs) -> OptimizerOutput:
        """Run one optimizer call; delegates to ``_forward_impl``."""
        return self._forward_impl(i, optimizer_input, optimizer_output, **kwargs)

    def _record_iter_timing(self) -> None:
        """Record per-iteration timing into iter/decoder/optimizer_time_log.
        Call right after the timed region; iter_start must already be recorded."""
        self.iter_end.record()
        # Events must have completed before elapsed_time can be read.
        torch.cuda.synchronize()
        elapsed_time = self.iter_start.elapsed_time(self.iter_end)
        self.iter_time_log.append(elapsed_time)
        decoder_ms = self.decoder_event_start.elapsed_time(self.decoder_event_end)
        self.decoder_time_log.append(decoder_ms)
        self.optimizer_time_log.append(elapsed_time - decoder_ms)

    def on_scene_start(self, optimizer_input: OptimizerInput) -> None:
        """Prepare for a new scene; delegates to ``_on_scene_start_impl``."""
        self._on_scene_start_impl(optimizer_input)

    def _on_scene_start_impl(self, optimizer_input: OptimizerInput) -> None:
        """Convert the initializer output into an ``OptimizerPreviousOutput``.

        Replaces ``optimizer_input.prev_output`` in place, using a clone of the
        initializer's Gaussians. When any ADC step is enabled, it also resets the
        logs and initializes a fresh ADC state.
        """
        init_output = optimizer_input.prev_output
        assert isinstance(init_output, InitializerOutput), (
            "base Optimizer class on_scene_start only converts the InitializerOutput to "
            "OptimizerPreviousOutput, without handling the state. "
            "It also initializes a new state for density control. "
            f"Got type {type(init_output)}"
        )

        # Converting the initializer output to optimizer previous output
        optimizer_prev_output = OptimizerPreviousOutput(
            gaussians=init_output.gaussians.clone(),
            state=None,
        )
        optimizer_input.prev_output = optimizer_prev_output

        # NOTE(review): logs are only reset when ADC is enabled; without ADC the
        # timing logs keep accumulating across scenes — confirm this is intended.
        if self.cfg.any_adc:
            self.reset_logs()
            optimizer_prev_output.state = OptimizerState()  # init to empty state
            self.initialize_adc_state(self.cfg, optimizer_input)

    def on_scene_end(self) -> None:
        """Hook called when a scene finishes; no-op in the base class."""
        pass

    def reset_logs(self):
        """Clear all per-iteration statistic logs and the scene-start timing."""
        self.radii_max_log = []
        self.grads_max_log = []
        self.nr_cloned_log = []
        self.nr_splitted_log = []
        self.nr_pruned_log = []
        self.nr_gaussians_log = []
        self.iter_time_log = []       # total ms per iteration
        self.decoder_time_log = []    # ms spent in rendering-for-gradients per iteration
        self.optimizer_time_log = []  # ms spent in update step (iter_time - decoder_time)
        self.scene_start_ms = 0.0    # ms for on_scene_start (KNN lookup, Adam state init)
        self.nr_nonzero_grad_log = []

    @staticmethod
    def initialize_adc_state(cfg: OptimizerCfg, optimizer_input: OptimizerInput) -> None:
        """Create the ADC strategy state on ``optimizer_input.prev_output.state``.

        Expects ``prev_output.state`` to already be an ``OptimizerState``. Uses the
        current Gaussian count and ``context["scene_scale"]`` as the scene extent.
        Only a single scene per batch is supported (asserted).
        """
        # Lazy import to avoid circular dependency
        from optgs.scene_trainer.adc import init_strategy_state

        # get number of points
        init_gaussians = optimizer_input.prev_output.gaussians
        nr_points = init_gaussians.means.shape[1]
        # get scene extent
        context = optimizer_input.context
        assert (
                context["extrinsics"].shape[0] == context["intrinsics"].shape[0] == 1
        ), "scene batch size > 1 not supported yet..."

        scene_scale = context["scene_scale"][0].item()
        # Initialize ADC state
        optimizer_input.prev_output.state.adc_state = init_strategy_state(
            cfg=cfg.refiner,
            nr_points=nr_points,
            device=init_gaussians.means.device,
            scene_extent=scene_scale
        )
        print("Initialized ADC state with", nr_points, "points and scene extent", scene_scale)

    def _forward_impl(self, i, optimizer_input: OptimizerInput, optimizer_output: OptimizerOutput, **kwargs) -> OptimizerOutput:
        """Subclass hook implementing the actual optimization."""
        raise NotImplementedError()

    def validate_input(self, optimizer_input: OptimizerInput) -> None:
        """Hook for validating inputs; no-op in the base class."""
        pass

    def _save_post_update_renders(
            self,
            i: int,
            optimizer_input: OptimizerInput,
            optimizer_output: OptimizerOutput,
            updated_gaussians: Gaussians,
            full_context: BatchedViews,
            full_target: BatchedViews,
    ) -> None:
        """Render and append post-update context+target views.

        Renders every iteration during training (so per-step renders can feed the meta-loss);
        otherwise renders only when save_every fires for the given tag (never, if no
        save_every scheduler was given). The per-iter subset (optimizer_input.context/target)
        is used in training when sampling indices exist, otherwise the full views.
        Renders are detached and moved to CPU outside training.
        """
        for tag, full, iter_views in (
            ("context", full_context, optimizer_input.context),
            ("target", full_target, optimizer_input.target),
        ):
            # save_every is Optional; only query it outside training, as before.
            should_render = self.training or (
                self.save_every is not None and self.save_every(i + 1, tag=tag)
            )
            if not should_render:
                continue
            index_list = optimizer_output.get_index_list(tag)
            subset = iter_views if (index_list and self.training) else full
            render_output = optimizer_input.renderer.forward_batch_subset(
                updated_gaussians,
                subset,
                iter_batch_size=optimizer_input.iter_batch_size,
            )
            optimizer_output.get_render_list(tag).append(
                render_output,
                detach_and_cpu=not self.training,
            )

    @torch.no_grad()
    def apply_adc(self, i, v, h, w, adc_state, gaussians, meta, object_dict_to_adjust=None):
        """
        Apply adaptive density control (ADC) based on 2D gradient norms.
        Implements densification and pruning of Gaussians during optimization, as in vanilla 3DGS.

        Args:
            i: Current optimization iteration.
            v: Number of views.
            h: Height of the rendered images.
            w: Width of the rendered images.
            adc_state: ADC strategy state, forwarded to post_backward.
            gaussians: Gaussians to be densified/pruned in place.
            meta: Metadata dict from the rendering, including visibility masks and radii.
            object_dict_to_adjust: Dict of object to adjust after pruning and densification, if needed.

        Side effects:
            Appends the cloned/split/pruned counts and the max radius / max 2D grad
            (0.0 when unavailable) to the corresponding logs.
        """
        # Lazy import to avoid circular dependency
        from optgs.scene_trainer.adc import post_backward

        # NOTE(review): renderer meta shapes are [B, V, N] / [B, V, N, 2] per the
        # original comments — confirm how post_backward treats the batch dim.
        visibility_mask = meta["visibility_filter"]
        radii_2d = meta["radii"].float()
        means2d_grads = meta["means_2d_grads"]  # may be None

        # means lr for MCMC noise injection
        if hasattr(self, "means_lr_scheduler"):
            assert self.means_lr_scheduler is not None, "means_lr_scheduler is None."
            lr = self.means_lr_scheduler(i)
        else:
            # Use fallback_means_lr from the refiner config so noise magnitude matches the
            # original paper (means_lr * noise_lr ≈ 1.6e-4 * 5e5 = 80 covariance-units).
            lr = self.cfg.refiner.fallback_means_lr

        # Post-backward (ADC)
        nr_cloned, nr_splitted, nr_pruned, max_radii, max_grad2d = post_backward(
            cfg=self.cfg.refiner,
            step=i,
            gaussians=gaussians,
            adc_state=adc_state,
            smoothers=object_dict_to_adjust,
            radii_2d=radii_2d,
            means2d_grads=means2d_grads,
            visibility_mask=visibility_mask,
            iter_batch_size=v,
            w=w,
            h=h,
            lr=lr
        )

        self.nr_cloned_log.append(nr_cloned)
        self.nr_splitted_log.append(nr_splitted)
        self.nr_pruned_log.append(nr_pruned)
        self.radii_max_log.append(0.0 if max_radii is None else max_radii)
        self.grads_max_log.append(0.0 if max_grad2d is None else max_grad2d)

    def plot_info(self, step, output_path: Path | None = None, scene_name: str | None = None) -> None:
        """Save a stacked plot of the per-iteration optimizer statistics.

        Does nothing unless both ``output_path`` and ``scene_name`` are given. Writes
        ``<output_path>/plots/<scene_name>/stats_<step>.png``. A statistic is plotted
        only if its log has one entry per iteration (same length as
        ``iter_time_log``); iteration time is always plotted.
        """
        if output_path is None or scene_name is None:
            return

        save_dir = Path(output_path) / "plots" / scene_name
        save_dir.mkdir(parents=True, exist_ok=True)

        nr_iters = len(self.iter_time_log)
        iters = range(nr_iters)
        candidates = [
            (self.radii_max_log, "Max Radius"),
            (self.grads_max_log, "Max Grad magnitude"),
            (self.nr_cloned_log, "Cloned"),
            (self.nr_splitted_log, "Splitted"),
            (self.nr_pruned_log, "Pruned"),
            (self.nr_gaussians_log, "Total"),
        ]
        # Skip logs not recorded once per iteration (e.g. ADC stats with ADC off);
        # plotting them against `iters` would raise a length-mismatch error.
        data = [(iters, log, label) for log, label in candidates if len(log) == nr_iters]
        data.append((iters, self.iter_time_log, "Iteration Time (ms)"))

        # Define some styles for visual variety
        styles = ['tab:blue', 'tab:orange', 'tab:green', 'tab:red', 'tab:purple', 'tab:brown', 'tab:pink']
        nr_rows = len(data)
        assert nr_rows <= len(styles), "Not enough styles defined for the number of subplots."

        # squeeze=False keeps `axes` indexable even when there is a single row.
        fig, axes = plt.subplots(nr_rows, 1, figsize=(10, 15), sharex=True, squeeze=False)
        axes = axes[:, 0]
        save_file = save_dir / f"stats_{step}.png"
        try:
            for ax, (x, y, label), color in zip(axes, data, styles):
                ax.plot(x, y, label=label, color=color, linewidth=2)
                ax.set_ylabel("Value", fontsize=11)
                ax.grid(True, linestyle="--", alpha=0.6)
                ax.legend(loc="upper right", fontsize=10)
                ax.set_title(f"{label} Gaussians", fontsize=13, pad=5)
                # show x-axis ticks on all plots
                ax.tick_params(axis='x', which='both', bottom=True, top=False, labelbottom=True)

            # Shared x-axis label
            axes[-1].set_xlabel("Iteration", fontsize=12)
            fig.tight_layout()
            fig.subplots_adjust(hspace=0.3)
            fig.savefig(save_file, dpi=300, bbox_inches='tight')
        finally:
            # Always release the figure, even if plotting/saving fails.
            plt.close(fig)
        print("Saved optimizer stats plot to:", save_file)


class LearnedOptimizer(Optimizer[T], ABC):
    """Optimizer whose update rule is itself a trainable module."""

    @property
    def strategy(self) -> str:
        return "learned"

    @property
    def device(self) -> torch.device:
        """Device of the module's first parameter (falling back to its first buffer).

        Raises:
            RuntimeError: if the module has neither parameters nor buffers, instead of
                leaking a bare StopIteration out of the property.
        """
        tensor = next(self.parameters(), None)
        if tensor is None:
            tensor = next(self.buffers(), None)
        if tensor is None:
            raise RuntimeError(
                f"{type(self).__name__} has no parameters or buffers; cannot infer its device."
            )
        return tensor.device


class NonlearnedOptimizer(Optimizer[T], ABC):
    """Optimizer with a fixed, non-trainable update rule; always kept in eval mode."""

    def __init__(self, *args, **kwargs) -> None:
        super().__init__(*args, **kwargs)
        # `training` starts as True from nn.Module.__init__ (set directly, not via
        # train()), so switch to eval explicitly once construction is done.
        self.eval()

    @property
    def strategy(self) -> str:
        return "nonlearned"

    def train(self, mode: bool = True):
        """Ignore ``mode`` and always switch to eval.

        ``self.training`` gates meta-training-only code paths, e.g.
        ``_save_post_update_renders`` keeping per-step renders on the device instead
        of detaching them to CPU. Pinning to eval here also covers a recursive
        ``.train()`` issued by a parent module.
        """
        return super().train(mode=False)