File size: 26,344 Bytes
78d2329
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
from dataclasses import dataclass, field
from typing import List
import tqdm as tqdm
import numpy as np
import torch
from torch import Tensor
import math
from pytorch_optimizer import load_optimizer
from torch.optim.lr_scheduler import LambdaLR
import torch.nn.functional as F
from einops import rearrange
from optgs.evaluation.metrics import compute_rgb_metrics
from optgs.misc.io import FrequencyScheduler
from optgs.scene_trainer.gaussian_module import GaussiansModule, gaussians2module, module2gaussians
from optgs.model.types import Gaussians
from optgs.scene_trainer.optimizer.optimizer import OptimizerOutput
from optgs.scene_trainer.optimizer.optimizer_utils import Number3DGSCfg
from optgs.misc.detaching_cpu_list import DetachingCPUList
from optgs.dataset.camera_datasets.camera import get_scene_scale
from optgs.misc.general_utils import get_expon_lr_func
from fused_ssim import fused_ssim
from optgs.model.decoder.decoder import Decoder, DecoderOutput


@dataclass
class PostProcessADCCfg:
    """ADC (Adaptive Density Control) config for postprocessing.

    Consumed by ``PostProcessing3DGS._apply_adc``; the per-field comments below
    describe how that method reads each value.
    Defaults match vanilla 3DGS (config/scene_trainer/scene_optimizer/refiner/default.yaml).
    """
    # Master switches for the three ADC operations.
    do_densify: bool = True        # enables the clone + split pass
    do_prune: bool = True          # enables the pruning pass
    do_opacity_reset: bool = True  # enables the periodic opacity reset

    # Scheduling (all values are compared against the post-processing step index)
    # Refinement is skipped while ``step % reset_every`` is below this value.
    pause_refine_after_reset: int = 0
    # Refinement runs only on steps where ``step % refine_every == 0``.
    refine_every: int = 100
    # Opacity is reset when ``step % reset_every == 0`` and ``step > 0``;
    # scale-based pruning is only considered once ``step > reset_every``.
    reset_every: int = 3000
    # Refinement window: ``refine_start_iter <= step < refine_stop_iter``.
    refine_start_iter: int = 500
    refine_stop_iter: int = 15000
    # The 2D-radius criteria (grow_scale2d / prune_scale2d) are applied only
    # while ``step`` is below this value; the default 0 never satisfies that.
    refine_scale2d_stop_iter: int = 0

    # Densification thresholds
    # A Gaussian is a densification candidate when its averaged 2D gradient
    # norm exceeds this value.
    grow_grad2d: float = 0.0002
    # Candidates whose largest scale is <= grow_scale3d * scene_extent are
    # cloned, the others are split.
    grow_scale3d: float = 0.01  # aka percent_dense
    # Gaussians whose tracked 2D radius exceeds this value are also split.
    grow_scale2d: float = 0.05

    # Pruning thresholds
    # Largest scale above prune_scale3d * scene_extent marks a Gaussian as too big.
    prune_scale3d: float = 0.1
    # Tracked 2D radius above this value also marks a Gaussian as too big.
    prune_scale2d: float = 0.15
    # Gaussians with opacity below this are pruned; the opacity reset clamps
    # opacities to at most 2 * min_opacity.
    min_opacity: float = 0.005

    # Forwarded unchanged to the splitting routine.
    revised_opacity: bool = False


@dataclass
class PostProcessCfg:
    """Settings for the post-processing optimisation stage.

    ``name`` selects the optimizer; ``"none"`` or a non-positive ``steps``
    marks the stage as inactive (see ``is_active``).
    """
    name: str
    steps: int
    compute_metrics_every: int
    lr_data: Number3DGSCfg
    scheduler: str | None
    scheduler_warm_up_ratio: float

    # --- SGD only ---
    momentum: float = 0.0
    nesterov: bool = False

    # --- Adam only ---
    betas: List[float] | None = None
    eps: float = 1e-8
    amsgrad: bool = False

    # --- common to both optimizers ---
    weight_decay: float = 0.0

    # Number of steps the scene trainer already ran; shifts the LR schedule.
    prior_steps: int = 0

    # Means LR schedule (defaults follow the vanilla optimizer)
    means_lr_final_ratio: float = 0.0625  # final / initial means LR (vanilla: 1e-5 / 1.6e-4)
    means_lr_delay_mult: float = 0.01    # ramp-up delay multiplier (vanilla default: 0.01)
    means_lr_scale_by_scene_extent: bool = True  # multiply means LR by the scene extent (vanilla default)

    # Views rendered per backward pass; -1 renders all sampled views together.
    chunk_size: int = -1

    # Adaptive Density Control settings; None disables ADC.
    adc: PostProcessADCCfg | None = None

    @property
    def is_active(self) -> bool:
        """True when an optimizer is selected and at least one step is requested."""
        if self.name == "none":
            return False
        return self.steps > 0

    def get_dir_name(self, with_name=True):
        """Return the run-directory label, optionally prefixed with the optimizer name."""
        suffix = self._get_dir_name()
        if not with_name:
            return suffix
        return f"{self.name}_{suffix}"

    def _get_dir_name(self):
        """Build the hyper-parameter part of the label; empty for unknown optimizers."""
        if self.name == "sgd":
            return f"lr{self.lr_data.base}_mom{self.momentum}"
        if self.name == "adam":
            betas_label = "-".join(str(beta) for beta in (self.betas or []))
            return f"lr{self.lr_data.base}_betas{betas_label}_eps{self.eps}"
        return ""


def _module_to_deactivated_gaussians(gm: GaussiansModule) -> Gaussians:
    """Snapshot a GaussiansModule as a raw-valued (non-activated) Gaussians for ADC.

    Every tensor is detached and given a leading dimension of size 1; scales
    and opacities are taken from the module's raw (log / logit) attributes.
    """
    def _snapshot(tensor):
        # Drop the autograd link and prepend a singleton dimension.
        return tensor.detach().unsqueeze(0)

    return Gaussians(
        means=_snapshot(gm.means),
        scales=_snapshot(gm.scales_raw),        # log space
        opacities=_snapshot(gm.opacities_raw),  # logit space
        rotations=_snapshot(gm.rotations),
        rotations_unnorm=_snapshot(gm.rotations_unnorm),
        harmonics=_snapshot(gm.harmonics),
        stores_activated=False,
    )


def _deactivated_gaussians_to_module(gaussians: Gaussians, device: torch.device) -> GaussiansModule:
    """Build a GaussiansModule on ``device`` from raw-valued (non-activated) Gaussians.

    Only the first entry along the leading dimension is used. Opacities and
    scales are activated (sigmoid / exp) before being handed to the module.
    """
    assert not gaussians.stores_activated

    # Activate on the source device first, then move to the target device.
    activated_opacities = torch.sigmoid(gaussians.opacities[0])
    activated_scales = torch.exp(gaussians.scales[0])

    module_kwargs = {
        "means": gaussians.means[0].to(device),
        "harmonics": gaussians.harmonics[0].to(device),
        "opacities": activated_opacities.to(device),
        "scales": activated_scales.to(device),
        "rotations_unnorm": gaussians.rotations_unnorm[0].to(device),
    }
    return GaussiansModule(**module_kwargs)


class PostProcessing3DGS:

    def __init__(self, cfg: PostProcessCfg, save_every: FrequencyScheduler):
        """Store the configuration, create the timing events and clear all logs.

        Args:
            cfg: post-processing configuration.
            save_every: callable queried as ``save_every(step, tag=...)`` to decide
                when context / target renders are recorded.
        """
        self.cfg = cfg
        self.save_every = save_every

        # Pair of CUDA events used to time each optimisation iteration.
        timing_events = [torch.cuda.Event(enable_timing=True) for _ in range(2)]
        self.iter_start, self.iter_end = timing_events

        self.reset_logs()

    def reset_logs(self):
        self.radii_max_log = []
        self.grads_max_log = []
        self.nr_cloned_log = []
        self.nr_splitted_log = []
        self.nr_pruned_log = []
        self.nr_gaussians_log = []
        self.nr_nonzero_grad_log = []
        self.iter_time_log = []

    def _calc_loss(self, context, output_renderer: DecoderOutput) -> Tensor:
        """Return the scalar photometric loss ``0.8 * L1 + 0.2 * (1 - SSIM)``.

        Args:
            context: mapping whose ``"image"`` entry holds the ground-truth views,
                laid out as ``b v c h w`` with ``b == 1``.
            output_renderer: render result; its ``color`` must have the same shape
                as the ground truth.
        """
        target = context["image"]
        rendered = output_renderer.color

        # Only a single scene per batch is supported.
        assert target.shape[0] == 1
        assert target.shape == rendered.shape

        l1_term = (rendered - target).abs().mean()

        # fused_ssim is fed 4D input, so batch and view dimensions are merged.
        merge_views = "b v c h w -> (b v) c h w"
        ssim_term = fused_ssim(
            rearrange(rendered, merge_views),
            rearrange(target, merge_views),
            padding="valid"
        )

        return 0.8 * l1_term + 0.2 * (1 - ssim_term)

    def _chunked_forward_backward(self, gaussian_module, iter_context, decoder, render_res, adc_state):
        """Render views in chunks, accumulate gradients, and collect ADC metadata.

        Matches the gradient accumulation approach of calc_input_gradients in the vanilla optimizer:
        each chunk computes a mean loss, gradients accumulate, then are averaged by nr_chunks.
        """
        v = iter_context["image"].shape[1]
        chunk_size = self.cfg.chunk_size if self.cfg.chunk_size > 0 else v
        nr_chunks = math.ceil(v / chunk_size)

        # Accumulate means2d grads and radii for ADC across chunks
        need_adc = adc_state is not None
        h, w = render_res
        if need_adc:
            N = gaussian_module.means.shape[0]
            means2d_grads_all = torch.zeros((1, v, N, 2), device=gaussian_module.means.device)
            radii_all = torch.zeros((1, v, N, 2), device=gaussian_module.means.device)
            visibility_all = torch.zeros((1, v, N), dtype=torch.bool, device=gaussian_module.means.device)

        for chunk_start in range(0, v, chunk_size):
            chunk_end = min(chunk_start + chunk_size, v)

            # Slice views for this chunk
            chunk_context = {
                "image": iter_context["image"][:, chunk_start:chunk_end],
                "extrinsics": iter_context["extrinsics"][:, chunk_start:chunk_end],
                "intrinsics": iter_context["intrinsics"][:, chunk_start:chunk_end],
                "near": iter_context["near"][:, chunk_start:chunk_end],
                "far": iter_context["far"][:, chunk_start:chunk_end],
            }

            # Render
            chunk_output = decoder.forward_batch_subset(gaussian_module, chunk_context, render_res)

            # Retain means2d grad for ADC
            if need_adc and chunk_output.means2d is not None:
                chunk_output.means2d.retain_grad()

            # Loss and backward (gradients accumulate across chunks)
            chunk_loss = self._calc_loss(chunk_context, chunk_output)
            chunk_loss.backward()

            # Collect ADC metadata from this chunk
            if need_adc:
                if chunk_output.radii is not None:
                    radii_all[:, chunk_start:chunk_end] = chunk_output.radii.detach()
                if chunk_output.visibility_filter is not None:
                    visibility_all[:, chunk_start:chunk_end] = chunk_output.visibility_filter.detach()
                if chunk_output.means2d is not None and chunk_output.means2d.grad is not None:
                    means2d_grads_all[:, chunk_start:chunk_end] = chunk_output.means2d.grad.detach()

        # Average gradients across chunks (matches vanilla behavior)
        if nr_chunks > 1:
            for param in gaussian_module.parameters():
                if param.grad is not None:
                    param.grad /= nr_chunks

        # Return ADC metadata
        if need_adc:
            return {
                "radii": radii_all,
                "visibility_filter": visibility_all,
                "means_2d_grads": means2d_grads_all,
            }
        return None

    def _apply_adc(self, step, gaussian_module, adc_state, device):
        """Apply ADC (clone/split/prune/opacity reset) using the same logic as vanilla 3DGS.

        The module is converted to raw-valued Gaussians, the ADC operations that are
        due at ``step`` are applied to that copy, and a new GaussiansModule is built
        from it if any operation ran. One entry is appended to ``nr_cloned_log``,
        ``nr_splitted_log`` and ``nr_pruned_log`` on every call.

        Args:
            step: post-processing step index used for all scheduling checks.
            gaussian_module: current GaussiansModule.
            adc_state: ADC state providing ``grad2d_norm_accum``, ``denom``,
                ``radii2d`` and ``scene_extent``.
            device: device on which a rebuilt module is placed.

        Returns (gaussian_module, optimizer_needs_rebuild).
        The flag is True whenever a densify, prune or opacity-reset branch ran,
        regardless of whether any Gaussian was actually affected.
        """
        from optgs.scene_trainer.adc.vanilla import cloning, splitting, prune, reset_adc_state

        adc_cfg = self.cfg.adc
        changed = False
        nr_cloned, nr_splitted, nr_pruned = 0, 0, 0

        # Convert to deactivated Gaussians for ADC (ADC functions expect Gaussians, not GaussiansModule)
        gaussians = _module_to_deactivated_gaussians(gaussian_module)

        if step < adc_cfg.refine_stop_iter:
            # Average accumulated 2D gradient norm; the clamp avoids dividing by a zero count.
            grads = adc_state.grad2d_norm_accum / adc_state.denom.clamp_min(1.0)
            scene_extent = adc_state.scene_extent

            # Refinement fires inside the [start, stop) window, every refine_every steps,
            # and not within pause_refine_after_reset steps after a reset boundary.
            if (
                step >= adc_cfg.refine_start_iter
                and step % adc_cfg.refine_every == 0
                and step % adc_cfg.reset_every >= adc_cfg.pause_refine_after_reset
            ):
                if adc_cfg.do_densify:
                    scales = torch.exp(gaussians.scales.squeeze(0))  # activate
                    is_grad_high = grads > adc_cfg.grow_grad2d
                    is_small = scales.max(dim=-1).values <= adc_cfg.grow_scale3d * scene_extent

                    # High-gradient small Gaussians are cloned, high-gradient large ones are split.
                    clone_mask = is_grad_high & is_small
                    split_mask = is_grad_high & ~is_small

                    if step < adc_cfg.refine_scale2d_stop_iter:
                        split_mask |= adc_state.radii2d > adc_cfg.grow_scale2d

                    # Clone
                    cloning(gaussians, adc_state, clone_mask)
                    nr_cloned = int(clone_mask.sum().item())

                    # Extend split_mask for newly cloned points (they should not be split)
                    # NOTE(review): this presumes cloning() appends the clones at the end
                    # of the Gaussian set — confirm against its implementation.
                    split_mask = torch.cat([
                        split_mask,
                        torch.zeros(nr_cloned, dtype=torch.bool, device=split_mask.device),
                    ])

                    # Split
                    splitting(gaussians, adc_state, split_mask, N=2,
                              revised_opacity=adc_cfg.revised_opacity)
                    nr_splitted = int(split_mask.sum().item())

                    changed = True

                if adc_cfg.do_prune:
                    # Recomputed here because densification may have changed the set.
                    opacities = torch.sigmoid(gaussians.opacities.squeeze(0))  # activate
                    scales = torch.exp(gaussians.scales.squeeze(0))            # activate

                    prune_mask = opacities < adc_cfg.min_opacity
                    # Size-based pruning is only considered after the first reset interval.
                    if step > adc_cfg.reset_every:
                        is_too_big = scales.max(dim=-1).values > adc_cfg.prune_scale3d * scene_extent
                        if step < adc_cfg.refine_scale2d_stop_iter:
                            is_too_big |= adc_state.radii2d > adc_cfg.prune_scale2d
                        prune_mask = prune_mask | is_too_big

                    prune(gaussians, adc_state, prune_mask)
                    nr_pruned = int(prune_mask.sum().item())
                    changed = True

                # Clear the accumulated statistics after every refinement pass.
                reset_adc_state(adc_state)
                print(
                    f"ADC @ iter {step}: cloned {nr_cloned}, split {nr_splitted}, "
                    f"pruned {nr_pruned}, total {gaussians.means.shape[1]}"
                )

        # Opacity reset
        # NOTE(review): this branch is not gated by refine_stop_iter, so resets keep
        # happening after refinement has stopped — confirm this is intended.
        if adc_cfg.do_opacity_reset:
            if step % adc_cfg.reset_every == 0 and step > 0:
                opacities = torch.sigmoid(gaussians.opacities)  # activate
                # Clamp every opacity to at most twice the pruning threshold.
                value = adc_cfg.min_opacity * 2.0
                new_opacities = torch.min(opacities, torch.ones_like(opacities) * value)
                gaussians.opacities = torch.logit(new_opacities)  # deactivate back
                changed = True
                print(f"Opacity reset @ iter {step}")

        self.nr_cloned_log.append(nr_cloned)
        self.nr_splitted_log.append(nr_splitted)
        self.nr_pruned_log.append(nr_pruned)

        if changed:
            # Rebuild GaussiansModule from modified Gaussians
            gaussian_module = _deactivated_gaussians_to_module(gaussians, device)

        return gaussian_module, changed

    @torch.no_grad()
    def apply(
        self,
        batch,
        gaussians: Gaussians,
        decoder,
        metrics=["psnr", "ssim"],
        iter_batch_size: int = -1,
        batchify_fn=None,
        visualization_dump=None
    ) -> OptimizerOutput | None:

        target_render_list = DetachingCPUList()
        context_render_list = DetachingCPUList()

        if self.cfg.steps == 0:
            return None

        # [Improvement 1] Calculate scene_scale from both context + target (matches vanilla optimizer)
        camtoworlds_context = batch['context']['extrinsics'][0].cpu().numpy()  # [Vc, 4, 4]
        camtoworlds_target = batch['target']['extrinsics'][0].cpu().numpy()    # [Vt, 4, 4]
        camtoworlds = np.concatenate([camtoworlds_context, camtoworlds_target], axis=0)
        scene_scale = get_scene_scale(camtoworlds)
        print("scene_scale:", scene_scale)

        device = batch['context']['image'].device

        # convert Gaussians to GaussiansModule
        gaussian_module = gaussians2module(gaussians, device=device)

        optimizer = self.get_optimizer(gaussian_module, scene_scale)
        scheduler = self.get_scheduler(optimizer, scene_scale=scene_scale, prior_steps=self.cfg.prior_steps)

        # print all optimizer param groups
        for i, param_group in enumerate(optimizer.param_groups):
            print(f"Param group {i}: lr={param_group['lr']}, weight_decay={param_group.get('weight_decay', 0.0)}, requires_grad={param_group['params'][0].requires_grad}")

        assert batch["context"]["extrinsics"].shape[0] == batch["context"]["extrinsics"].shape[0] == 1, \
            "Batch size > 1 not supported for post-processing"

        nr_context_views, _, h, w = batch["context"]["image"][0].shape

        # controlling number of context views seen at each iteration (for rendering chunk size)
        _iter_batch_size = iter_batch_size if iter_batch_size > 0 else nr_context_views
        print("using iter_batch_size =", _iter_batch_size)

        render_res = (h, w)

        # [Improvement 3] Initialize ADC state if configured
        adc_state = None
        if self.cfg.adc is not None:
            from optgs.scene_trainer.adc.vanilla import VanillaStrategyState
            nr_points = gaussian_module.means.shape[0]
            adc_state = VanillaStrategyState.initialize(
                nr_points=nr_points,
                device=device,
                scene_extent=scene_scale,
            )
            print(f"Initialized ADC state with {nr_points} points")

        # render before first step
        context_render_output = decoder.forward_batch_subset(gaussian_module, batch["context"], render_res, iter_batch_size=_iter_batch_size)
        context_render_list.append(context_render_output, detach_and_cpu=True)  # initial rendering

        target_render_output = decoder.forward_batch_subset(gaussian_module, batch["target"], render_res, iter_batch_size=_iter_batch_size)
        target_render_list.append(target_render_output, detach_and_cpu=True)  # initial rendering

        # Reset viewpoint stack for fresh sampling in postprocessing
        batch["context"].viewpoint_stack = None

        pbar = tqdm.tqdm(range(self.cfg.steps), desc=f"PP {self.cfg.name}", ncols=120)
        pbar_postfix = {}
        for i in pbar:

            self.iter_start.record()

            with torch.enable_grad():

                # Log number of gaussians
                self.nr_gaussians_log.append(gaussian_module.means.shape[0])

                # reset gradients
                optimizer.zero_grad()

                # Sample context views using the same strategy as the optimizer
                iter_context, _ = batchify_fn(batch, "context")

                # [Improvement 4] Render in chunks, accumulate gradients, collect ADC metadata
                meta_for_adc = self._chunked_forward_backward(
                    gaussian_module, iter_context, decoder, render_res, adc_state
                )

                # step
                optimizer.step()

                # update scheduler
                if scheduler is not None:
                    scheduler.step()

            # [Improvement 3] ADC: update state and apply densification/pruning
            if adc_state is not None and meta_for_adc is not None:
                from optgs.scene_trainer.adc.vanilla import update_vanilla_strategy_state

                v_chunk = iter_context["image"].shape[1]
                update_vanilla_strategy_state(
                    adc_state,
                    radii_2d=meta_for_adc["radii"],
                    means2d_grads=meta_for_adc["means_2d_grads"],
                    visibility_mask=meta_for_adc["visibility_filter"],
                    v=v_chunk,
                    w=w,
                    h=h,
                )

                gaussian_module, adc_changed = self._apply_adc(i, gaussian_module, adc_state, device)
                if adc_changed:
                    # Rebuild optimizer and scheduler after ADC changed Gaussian count
                    optimizer = self.get_optimizer(gaussian_module, scene_scale)
                    scheduler = self.get_scheduler(
                        optimizer, scene_scale=scene_scale, prior_steps=self.cfg.prior_steps
                    )
                    # Fast-forward scheduler to current step
                    for _ in range(i + 1):
                        scheduler.step() if scheduler is not None else None

            # Timing
            self.iter_end.record()
            torch.cuda.synchronize()

            elapsed_time = self.iter_start.elapsed_time(self.iter_end)
            self.iter_time_log.append(elapsed_time)

            if self.save_every(i + 1, tag="context"):
                with torch.no_grad():
                    context_render_output = decoder.forward_context(gaussian_module, batch, (h, w))
                    context_render_list.append(context_render_output, detach_and_cpu=True)
                    context_rgb = context_render_output.color[0]  # [Vc, 3, Hc, Wc]
                    ctx_scores: dict = compute_rgb_metrics(
                        rgb=context_rgb,
                        rgb_gt=batch["context"]["image"][0],
                        metrics=metrics,
                        iter_batch_size=iter_batch_size if "lpips" in metrics else -1
                    )
                    for k, v in ctx_scores.items():
                        pbar_postfix[f"ctx_{k}"] = f"{v.item():.2f}"

            if self.save_every(i + 1, tag="target"):
                with torch.no_grad():
                    target_render_output = decoder.forward_target(gaussian_module, batch, (h, w))
                    target_render_list.append(target_render_output, detach_and_cpu=True)
                    target_rgb = target_render_output.color[0]  # [Vt, 3, Ht, Wt]
                    tgt_scores: dict = compute_rgb_metrics(
                        rgb=target_rgb,
                        rgb_gt=batch["target"]["image"][0],
                        metrics=metrics,
                        iter_batch_size=iter_batch_size if "lpips" in metrics else -1
                    )
                    for k, v in tgt_scores.items():
                        pbar_postfix[f"tgt_{k}"] = f"{v.item():.2f}"

            pbar_postfix["gs"] = gaussian_module.means.shape[0]
            pbar.set_postfix(pbar_postfix)

            if visualization_dump is not None and "grads" in visualization_dump:
                self.debug_grads(gaussian_module, visualization_dump, i)

        # convert back to Gaussians

        postprocessed_gaussians = module2gaussians(gaussian_module)
        postprocessed_gaussians_list = DetachingCPUList()
        postprocessed_gaussians_list.append(postprocessed_gaussians, detach_and_cpu=True)
        output = OptimizerOutput(
            target_render_list=target_render_list,
            context_render_list=context_render_list,
            gaussian_list=postprocessed_gaussians_list,
            info = {}
        )

        return output

    def debug_grads(self, gaussians: GaussiansModule, debug_dict, step):
        """Append this step's per-Gaussian gradients to ``debug_dict["grads"]``.

        ``debug_dict["grads"]`` is a list with one inner list per run; a new inner
        list is started when the entry is still None or when ``step`` is 0. Each
        stored item is a CPU tensor of shape [num_gaussians, total_param_dim] built
        from all parameters that currently have a gradient.
        """
        history = debug_dict["grads"]
        if history is None:
            # Very first call: create the outer list with one empty run.
            debug_dict["grads"] = [[]]
        elif step == 0:
            # A new run begins.
            history.append([])

        nr_gaussians = gaussians.means.shape[0]
        flattened = [
            param.grad.view(nr_gaussians, -1).detach().cpu()
            for _, param in gaussians.named_parameters()
            if param.grad is not None
        ]
        stacked = torch.cat(flattened, dim=-1)  # [num_gaussians, total_param_dim]

        debug_dict["grads"][-1].append(stacked)

    def get_optimizer(self, gaussians: GaussiansModule, scene_scale: float):
        """Create one optimizer with a separate parameter group per Gaussian attribute.

        Each named parameter gets its own group whose learning rate is read from
        ``cfg.lr_data`` (the ``_raw`` / ``_unnorm`` parts of the parameter name are
        dropped to find the attribute). Learning rates and the optimizer
        hyper-parameters are rescaled by the effective batch size, which is
        ``1 * world_size``. ``scene_scale`` is accepted but not used here.
        """
        # TODO Naama: support different batch sizes
        batch_size: int = 1

        # Collect (name, parameter, base lr) triples.
        specs = []
        for param_name, param in dict(gaussians.named_parameters()).items():
            lr_key = param_name.replace("_raw", "").replace("_unnorm", "")
            specs.append((param_name, param, getattr(self.cfg.lr_data, lr_key)))

        if torch.distributed.is_initialized():
            world_size = torch.distributed.get_world_size()
        else:
            world_size = 1
        print(f"World size: {world_size}")

        BS = batch_size * world_size

        # One group per parameter so that each keeps an individual learning rate.
        param_groups = []
        for param_name, param, base_lr in specs:
            param_groups.append({
                "name": param_name,
                "params": param,
                "lr": base_lr * math.sqrt(BS),
            })

        # Remaining optimizer hyper-parameters, adjusted for the effective batch size.
        opt_params = self.extract_opt_params()
        if "weight_decay" in opt_params:
            opt_params["weight_decay"] *= BS
        if "eps" in opt_params:
            opt_params["eps"] /= math.sqrt(BS)
        if "betas" in opt_params:
            beta1, beta2 = opt_params["betas"]
            opt_params["betas"] = (1 - BS * (1 - beta1), 1 - BS * (1 - beta2))

        optimizer = load_optimizer(self.cfg.name)(param_groups, **opt_params)

        # Debug summary of the created groups.
        print("Optimizer with parameter groups:")
        for idx, group in enumerate(optimizer.param_groups):
            print(
                f"Group {idx} ({group.get('name', 'unnamed')}): "
                f"lr={group['lr']} params={len(group['params'])}"
            )

        return optimizer


    # Optimizer name -> names of the config attributes that extract_opt_params()
    # forwards as constructor keyword arguments. Optimizer names that are not
    # listed here receive no extra keyword arguments.
    _OPT_PARAMS = {
        "sgd":  ("momentum", "weight_decay", "nesterov"),
        "adam": ("betas", "eps", "weight_decay", "amsgrad"),
    }

    def extract_opt_params(self):
        allowed = self._OPT_PARAMS.get(self.cfg.name, ())
        return {k: getattr(self.cfg, k) for k in allowed if getattr(self.cfg, k, None) is not None}

    def get_scheduler(self, optimizer, scene_scale: float = 1.0, prior_steps: int = 0):
        if self.cfg.scheduler is None:
            return None

        total_steps = prior_steps + self.cfg.steps

        if self.cfg.scheduler == "exponential":
            print(f"Using exponential LR scheduler (total_steps={total_steps}, prior_steps={prior_steps})")

            # [Improvement 2] Per-param-group scheduling:
            # - Means: exponential decay optionally scaled by scene_extent (matching vanilla optimizer)
            # - Other params: constant LR
            lambdas = []
            for group in optimizer.param_groups:
                if group["name"] == "means" and self.cfg.means_lr_scale_by_scene_extent:
                    # Vanilla-style means LR: exponential decay with scene_extent scaling
                    base_lr = group["lr"]  # initial means LR from param group
                    means_lr_func = get_expon_lr_func(
                        lr_init=base_lr * scene_scale,
                        lr_final=base_lr * scene_scale * self.cfg.means_lr_final_ratio,
                        lr_delay_mult=self.cfg.means_lr_delay_mult,
                        max_steps=total_steps,
                    )
                    # LambdaLR computes: effective_lr = base_lr * lambda(step)
                    # We want: effective_lr = means_lr_func(step)
                    # So: lambda(step) = means_lr_func(step) / base_lr
                    _base_lr = base_lr  # capture for closure
                    _func = means_lr_func
                    lambdas.append(lambda step, f=_func, b=_base_lr: f(step) / b)
                else:
                    # Constant LR for all other param groups
                    lambdas.append(lambda step: 1.0)

            scheduler = LambdaLR(optimizer, lr_lambda=lambdas)
            # Fast-forward to prior_steps so LR continues from where scene trainer left off
            for _ in range(prior_steps):
                scheduler.step()
            return scheduler

        else:
            raise ValueError(f"Unknown scheduler: {self.cfg.scheduler}")