import torch
import numpy as np
import torch.nn.functional as F
from torch.optim.lr_scheduler import LambdaLR


def get_expon_lr_func(
    lr_init, lr_final, lr_delay_steps=0, lr_delay_mult=1.0, max_steps=1000000
):
    """
    Copied from Plenoxels

    Continuous learning rate decay function. Adapted from JaxNeRF.
    The returned rate is lr_init when step=0 and lr_final when step=max_steps, and
    is log-linearly interpolated elsewhere (equivalent to exponential decay).
    If lr_delay_steps>0 then the learning rate will be scaled by some smooth
    function of lr_delay_mult, such that the initial learning rate is
    lr_init*lr_delay_mult at the beginning of optimization but will be eased back
    to the normal learning rate when steps>lr_delay_steps.

    If exactly one of lr_init / lr_final is zero, log-space interpolation is
    undefined. In that case the schedule returns the exact endpoint value at
    step=0 and at step>=max_steps, and 0.0 strictly in between (the limit of the
    log-linear curve), rather than NaN.

    :param lr_init: learning rate at step 0 (before the delay multiplier).
    :param lr_final: learning rate at step >= max_steps.
    :param lr_delay_steps: int, number of steps over which the delay multiplier
        is eased from lr_delay_mult up to 1; 0 disables the delay.
    :param lr_delay_mult: float, multiplier applied at step 0 when lr_delay_steps > 0.
    :param max_steps: int, the number of steps during optimization.
    :return: HoF which takes step as input and returns the learning rate
        (0.0 for negative steps, or when both lr_init and lr_final are 0).
    """

    def helper(step):
        if step < 0 or (lr_init == 0.0 and lr_final == 0.0):
            # Disable this parameter
            return 0.0
        if lr_delay_steps > 0:
            # A kind of reverse cosine decay: rises from lr_delay_mult (step 0)
            # to 1.0 (step >= lr_delay_steps).
            delay_rate = lr_delay_mult + (1 - lr_delay_mult) * np.sin(
                0.5 * np.pi * np.clip(step / lr_delay_steps, 0, 1)
            )
        else:
            delay_rate = 1.0
        t = np.clip(step / max_steps, 0, 1)
        if lr_init == 0.0 or lr_final == 0.0:
            # log(0) = -inf, and -inf * 0 = NaN at the endpoint whose weight is 0.
            # Use the limit of the log-linear curve instead: exact endpoints,
            # zero strictly in between.
            if t == 0:
                log_lerp = lr_init
            elif t == 1:
                log_lerp = lr_final
            else:
                log_lerp = 0.0
        else:
            log_lerp = np.exp(np.log(lr_init) * (1 - t) + np.log(lr_final) * t)
        return delay_rate * log_lerp

    return helper


def rotate_quats(rots, quats):
    """Rotate quaternions by left-multiplying their rotation matrices with ``rots``.

    The quaternions are normalized, converted to 3x3 rotation matrices,
    left-multiplied by ``rots`` (``rots @ R``), and converted back to quaternions.

    :param rots: rotation matrices with trailing dims (3, 3). Earlier shape notes
        said [B, V, 1, 3, 3] -- TODO confirm. They must broadcast under ``@``
        against the matrices built from ``quats``.
    :param quats: quaternions with a trailing dim of 4, in xyzw (scalar-last)
        order per the original notes; assumed to match the ordering
        ``quaternion_to_matrix`` expects -- TODO confirm. Earlier shape notes said
        [B, N, 4] / [1, V, HW, 4] -- TODO confirm.
    :return: rotated quaternions in xyzw (scalar-last) order, as produced by
        ``rotation_matrix_to_quaternion_xyzw``.
    """
    # Imported inside the function; presumably to avoid an import cycle with
    # the scene_trainer package -- TODO confirm.
    from optgs.scene_trainer.common.gaussians import quaternion_to_matrix
    from optgs.scene_trainer.common.gaussians import rotation_matrix_to_quaternion_xyzw

    # Normalize so the conversion receives unit quaternions.
    unit_quats = F.normalize(quats, dim=-1)
    local_rotations = quaternion_to_matrix(unit_quats)  # [..., 3, 3]
    # Left-multiplication only (rots @ R). A conjugation variant
    # (rots @ R @ rots^T) was present as commented-out code and is not applied.
    world_rotations = rots @ local_rotations  # [..., 3, 3]
    return rotation_matrix_to_quaternion_xyzw(world_rotations)  # [..., 4], xyzw


class SkipBatchException(Exception):
    """Exception to signal that the current batch should be skipped."""

    pass


def _print_lr_comparison(optimizer, scheduler, lr_fn, steps, custom_label):
    """Print a table comparing a torch scheduler's LR with ``lr_fn(step)``.

    The torch scheduler is advanced incrementally from the previously visited
    step, so ``steps`` should be non-decreasing and start at or after 0.

    :param optimizer: optimizer whose ``param_groups[0]['lr']`` is read.
    :param scheduler: torch LR scheduler attached to ``optimizer``.
    :param lr_fn: callable mapping a step to a learning rate.
    :param steps: iterable of step indices to report.
    :param custom_label: column header for the ``lr_fn`` values.
    """
    print(f"\n{'Step':<10} {'PyTorch LR':<20} {custom_label:<20} {'Ratio':<15} {'Abs Diff':<15} {'Rel Diff %':<15}")
    print("-" * 100)

    prev_step = 0
    for step in steps:
        # Advance the torch scheduler from the previous step up to `step`.
        for _ in range(step - prev_step):
            scheduler.step()
        prev_step = step

        lr_torch = optimizer.param_groups[0]['lr']
        lr_custom = lr_fn(step)

        ratio = lr_custom / lr_torch if lr_torch != 0 else 0
        abs_diff = abs(lr_torch - lr_custom)
        rel_diff = abs_diff / lr_torch * 100 if lr_torch != 0 else 0

        print(f"{step:<10} {lr_torch:<20.10e} {lr_custom:<20.10e} {ratio:<15.6f} {abs_diff:<15.4e} {rel_diff:<15.8f}")


def test_lr_schedulers():
    """
    Compare PyTorch ExponentialLR with get_expon_lr_func.

    Manual diagnostic: prints comparison tables for a gsplat-style schedule and
    for the original config; it asserts nothing.
    """
    # Settings
    lr_init = 1.6e-4
    max_steps = 30000

    print("=" * 100)
    print("COMPARISON: PyTorch ExponentialLR vs get_expon_lr_func")
    print("=" * 100)

    # ========================================================================
    # TEST 1: gsplat-style (no warm-up, lr_final = 0.01 * lr_init)
    # ========================================================================
    print("\n" + "=" * 100)
    print("TEST 1: gsplat-style (lr_final = 0.01 * lr_init, no warm-up)")
    print("=" * 100)

    lr_final_gsplat = 0.01 * lr_init  # 1.6e-6

    # PyTorch ExponentialLR (gsplat style): gamma chosen so that
    # lr_init * gamma**max_steps == 0.01 * lr_init.
    dummy_param = torch.nn.Parameter(torch.zeros(1))
    optimizer_torch = torch.optim.Adam([dummy_param], lr=lr_init)
    gamma = 0.01 ** (1.0 / max_steps)
    scheduler_torch = torch.optim.lr_scheduler.ExponentialLR(optimizer_torch, gamma=gamma)

    # Custom scheduler (configured to match gsplat)
    scheduler_custom = get_expon_lr_func(
        lr_init=lr_init,
        lr_final=lr_final_gsplat,
        lr_delay_steps=0,
        lr_delay_mult=1.0,
        max_steps=max_steps,
    )

    print("\nSettings:")
    print(f" lr_init = {lr_init:.2e}")
    print(f" lr_final = {lr_final_gsplat:.2e}")
    print(" lr_delay_steps = 0")
    print(" lr_delay_mult = 1.0")
    print(f" max_steps = {max_steps}")
    print(f" gamma = {gamma:.10f}")

    test_steps_1 = [0, 1, 10, 100, 500, 1000, 2000, 5000, 10000, 15000, 20000, 25000, 29000, 29900, 29990, 30000]
    _print_lr_comparison(optimizer_torch, scheduler_torch, scheduler_custom, test_steps_1, 'Custom LR')

    # ========================================================================
    # TEST 2: Original config (higher lr_final)
    # ========================================================================
    # NOTE: lr_delay_steps is 0 here, so get_expon_lr_func ignores lr_delay_mult
    # and no warm-up is actually applied, despite what the printed header says.
    print("\n" + "=" * 100)
    print("TEST 2: Original config (lr_final = 1.0e-5, warm-up with delay_mult=0.01)")
    print("=" * 100)

    lr_final_original = 1.0e-5
    lr_delay_steps = 0
    lr_delay_mult = 0.01

    # PyTorch ExponentialLR (same gsplat-style gamma as TEST 1)
    dummy_param2 = torch.nn.Parameter(torch.zeros(1))
    optimizer_torch2 = torch.optim.Adam([dummy_param2], lr=lr_init)
    scheduler_torch2 = torch.optim.lr_scheduler.ExponentialLR(optimizer_torch2, gamma=gamma)

    # Custom scheduler (original config)
    scheduler_custom_yours = get_expon_lr_func(
        lr_init=lr_init,
        lr_final=lr_final_original,
        lr_delay_steps=lr_delay_steps,
        lr_delay_mult=lr_delay_mult,
        max_steps=max_steps,
    )

    print("\nSettings:")
    print(f" lr_init = {lr_init:.2e}")
    print(f" lr_final = {lr_final_original:.2e}")
    print(f" lr_delay_steps = {lr_delay_steps}")
    print(f" lr_delay_mult = {lr_delay_mult}")
    print(f" max_steps = {max_steps}")

    test_steps_2 = [0, 1, 10, 50, 100, 250, 500, 750, 1000, 1500, 2000, 5000, 10000, 15000, 20000, 25000, 29000, 30000]
    _print_lr_comparison(optimizer_torch2, scheduler_torch2, scheduler_custom_yours, test_steps_2, 'Original Custom LR')

    # ========================================================================
    # SUMMARY
    # ========================================================================
    print("\n" + "=" * 100)
    print("SUMMARY")
    print("=" * 100)
    print("\nTEST 1 (gsplat-style matching):")
    print(" ✓ Custom scheduler matches PyTorch ExponentialLR when configured identically")
    print(" ✓ Both decay from 1.6e-4 to 1.6e-6 (1% of initial)")
    print(" ✓ Relative difference is < 0.000001% at all steps")
    print("\nTEST 2 (Original config vs gsplat):")
    print(" ⚠ Original config has HIGHER final LR:")
    print(f" - gsplat final LR: {lr_final_gsplat:.2e}")
    print(f" - Original final LR: {lr_final_original:.2e} (~{lr_final_original/lr_final_gsplat:.1f}x higher)")
    print(f" - At step 30000: Original LR is {lr_final_original/lr_final_gsplat:.2f}x higher than gsplat")
    print("\n" + "=" * 100)


if __name__ == "__main__":
    test_lr_schedulers()