File size: 8,141 Bytes
78d2329
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
import torch
import numpy as np
import torch.nn.functional as F
from torch.optim.lr_scheduler import LambdaLR


def get_expon_lr_func(
    lr_init, lr_final, lr_delay_steps=0, lr_delay_mult=1.0, max_steps=1000000
):
    """
    Copied from Plenoxels

    Continuous learning rate decay function. Adapted from JaxNeRF.

    The returned rate is lr_init when step=0 and lr_final when step=max_steps, and
    is log-linearly interpolated elsewhere (equivalent to exponential decay).
    Steps beyond max_steps are clamped, so the rate then stays at lr_final.
    If lr_delay_steps>0 then the learning rate is additionally scaled by a smooth
    (sine-shaped) function of lr_delay_mult, such that the initial learning rate is
    lr_init*lr_delay_mult at the beginning of optimization and is eased back to
    the normal learning rate once step >= lr_delay_steps.

    :param lr_init: float, learning rate at step 0 (before any delay scaling).
    :param lr_final: float, learning rate at step max_steps and beyond.
    :param lr_delay_steps: int, length of the warm-up phase; 0 disables it.
    :param lr_delay_mult: float, multiplier applied at step 0 when warm-up is
        enabled; ignored when lr_delay_steps == 0.
    :param max_steps: int, the number of steps during optimization.
    :return: callable ``helper(step)`` returning the learning rate for ``step``.
        It returns 0.0 for negative steps, or for every step when both lr_init
        and lr_final are 0.0.

    NOTE: lr_init and lr_final should both be positive (or both 0.0 to disable).
    If only one of them is 0.0, ``np.log(0)`` is -inf and the interpolation can
    produce 0 or nan instead of a meaningful rate.
    """

    def helper(step):
        if step < 0 or (lr_init == 0.0 and lr_final == 0.0):
            # Negative steps, or both endpoints zero: disable this parameter.
            return 0.0
        if lr_delay_steps > 0:
            # Sine ease-in: the multiplier rises from lr_delay_mult at step 0 to 1
            # at step lr_delay_steps, and stays at 1 afterwards (clipped).
            delay_rate = lr_delay_mult + (1 - lr_delay_mult) * np.sin(
                0.5 * np.pi * np.clip(step / lr_delay_steps, 0, 1)
            )
        else:
            delay_rate = 1.0
        # Fraction of the schedule completed; clipped so steps past max_steps hold lr_final.
        t = np.clip(step / max_steps, 0, 1)
        # Interpolating in log space gives a geometric (exponential) decay.
        log_lerp = np.exp(np.log(lr_init) * (1 - t) + np.log(lr_final) * t)
        return delay_rate * log_lerp

    return helper


def rotate_quats(rots, quats):
    """
    Rotate a batch of quaternions by the given rotation matrices.

    The quaternions are normalized, converted to 3x3 rotation matrices,
    left-multiplied by ``rots`` (i.e. ``rots @ R``), and converted back.

    :param rots: rotation matrices combined with the quaternions' matrices via
        ``@`` (the original comment lists the shape as [B, V, 1, 3, 3] — assumed,
        confirm against callers).
    :param quats: quaternions in xyzw (scalar-last) order (the original comment
        lists [B, N, 4] — assumed, confirm against callers).
    :return: rotated quaternions in xyzw (scalar-last) order.

    NOTE(review): the input is converted with ``quaternion_to_matrix`` and the
    output with ``rotation_matrix_to_quaternion_xyzw``; confirm that
    ``quaternion_to_matrix`` also expects xyzw order, otherwise the round trip
    mixes conventions.
    """
    # Imported locally, presumably to avoid an import cycle with the gaussians module.
    from optgs.scene_trainer.common.gaussians import (
        quaternion_to_matrix,
        rotation_matrix_to_quaternion_xyzw,
    )

    unit_quats = F.normalize(quats, dim=-1)
    local_rotations = quaternion_to_matrix(unit_quats)
    world_rotations = rots @ local_rotations
    return rotation_matrix_to_quaternion_xyzw(world_rotations)



class SkipBatchException(Exception):
    """Raised to signal that the current batch should be skipped instead of processed."""


def _make_exponential_lr(lr_init, gamma):
    """Build an (optimizer, ExponentialLR) pair that drives a single dummy parameter."""
    dummy_param = torch.nn.Parameter(torch.zeros(1))
    optimizer = torch.optim.Adam([dummy_param], lr=lr_init)
    scheduler = torch.optim.lr_scheduler.ExponentialLR(optimizer, gamma=gamma)
    return optimizer, scheduler


def _compare_lr_schedules(optimizer_torch, scheduler_torch, scheduler_custom, steps, custom_label="Custom LR"):
    """
    Print a per-step table comparing a PyTorch scheduler with a custom LR function.

    The PyTorch scheduler is advanced incrementally from step 0, so ``steps``
    must be sorted in non-decreasing order and the scheduler must be fresh.

    :return: the largest relative difference (in percent) observed over ``steps``.
    """
    print(f"\n{'Step':<10} {'PyTorch LR':<20} {custom_label:<20} {'Ratio':<15} {'Abs Diff':<15} {'Rel Diff %':<15}")
    print("-" * 100)

    max_rel_diff = 0.0
    prev_step = 0
    for step in steps:
        # ExponentialLR only exposes its current LR, so step it forward to `step`.
        for _ in range(step - prev_step):
            scheduler_torch.step()
        prev_step = step
        lr_torch = optimizer_torch.param_groups[0]['lr']

        lr_custom = scheduler_custom(step)

        ratio = lr_custom / lr_torch if lr_torch != 0 else 0
        abs_diff = abs(lr_torch - lr_custom)
        rel_diff = abs_diff / lr_torch * 100 if lr_torch != 0 else 0
        max_rel_diff = max(max_rel_diff, rel_diff)

        print(f"{step:<10} {lr_torch:<20.10e} {lr_custom:<20.10e} {ratio:<15.6f} {abs_diff:<15.4e} {rel_diff:<15.8f}")

    return max_rel_diff


def test_lr_schedulers():
    """
    Compare PyTorch ExponentialLR with get_expon_lr_func and print the results.

    TEST 1 configures get_expon_lr_func to match a gsplat-style ExponentialLR
    (decay to 1% of lr_init over max_steps, no warm-up). TEST 2 compares the
    same ExponentialLR against the original config (lr_final = 1e-5). The
    summary reports the measured maximum relative differences instead of
    hard-coded claims.
    """
    # Settings
    lr_init = 1.6e-4
    max_steps = 30000
    banner = "=" * 100

    print(banner)
    print("COMPARISON: PyTorch ExponentialLR vs get_expon_lr_func")
    print(banner)

    # ========================================================================
    # TEST 1: gsplat-style (no warm-up, lr_final = 0.01 * lr_init)
    # ========================================================================
    print("\n" + banner)
    print("TEST 1: gsplat-style (lr_final = 0.01 * lr_init, no warm-up)")
    print(banner)

    lr_final_gsplat = 0.01 * lr_init  # 1.6e-6
    # Per-step factor so that lr_init * gamma**max_steps == lr_final_gsplat.
    gamma = 0.01 ** (1.0 / max_steps)
    optimizer_torch, scheduler_torch = _make_exponential_lr(lr_init, gamma)

    # Custom scheduler (configured to match gsplat)
    scheduler_custom = get_expon_lr_func(
        lr_init=lr_init,
        lr_final=lr_final_gsplat,
        lr_delay_steps=0,
        lr_delay_mult=1.0,
        max_steps=max_steps
    )

    print("\nSettings:")
    print(f"  lr_init = {lr_init:.2e}")
    print(f"  lr_final = {lr_final_gsplat:.2e}")
    print("  lr_delay_steps = 0")
    print("  lr_delay_mult = 1.0")
    print(f"  max_steps = {max_steps}")
    print(f"  gamma = {gamma:.10f}")

    test_steps_1 = [0, 1, 10, 100, 500, 1000, 2000, 5000, 10000, 15000, 20000, 25000, 29000, 29900, 29990, 30000]
    max_rel_1 = _compare_lr_schedules(optimizer_torch, scheduler_torch, scheduler_custom, test_steps_1)

    # ========================================================================
    # TEST 2: Original config (higher lr_final)
    # ========================================================================
    lr_final_original = 1.0e-5
    lr_delay_steps = 0
    lr_delay_mult = 0.01

    # get_expon_lr_func ignores lr_delay_mult when lr_delay_steps == 0, so only
    # advertise a warm-up when one is actually configured.
    if lr_delay_steps > 0:
        warmup_desc = f"warm-up over {lr_delay_steps} steps with delay_mult={lr_delay_mult}"
    else:
        warmup_desc = f"no warm-up: lr_delay_steps=0, so delay_mult={lr_delay_mult} is unused"

    print("\n" + banner)
    print(f"TEST 2: Original config (lr_final = {lr_final_original:.1e}, {warmup_desc})")
    print(banner)

    optimizer_torch2, scheduler_torch2 = _make_exponential_lr(lr_init, gamma)

    # Custom scheduler (original config)
    scheduler_custom_original = get_expon_lr_func(
        lr_init=lr_init,
        lr_final=lr_final_original,
        lr_delay_steps=lr_delay_steps,
        lr_delay_mult=lr_delay_mult,
        max_steps=max_steps
    )

    print("\nSettings:")
    print(f"  lr_init = {lr_init:.2e}")
    print(f"  lr_final = {lr_final_original:.2e}")
    print(f"  lr_delay_steps = {lr_delay_steps}")
    print(f"  lr_delay_mult = {lr_delay_mult}")
    print(f"  max_steps = {max_steps}")

    test_steps_2 = [0, 1, 10, 50, 100, 250, 500, 750, 1000, 1500, 2000, 5000, 10000, 15000, 20000, 25000, 29000, 30000]
    max_rel_2 = _compare_lr_schedules(
        optimizer_torch2, scheduler_torch2, scheduler_custom_original, test_steps_2,
        custom_label="Original Custom LR",
    )

    # ========================================================================
    # SUMMARY
    # ========================================================================
    print("\n" + banner)
    print("SUMMARY")
    print(banner)

    # Report what was measured rather than asserting a fixed tolerance.
    matches = max_rel_1 < 1e-6
    print("\nTEST 1 (gsplat-style matching):")
    if matches:
        print("  ✓ Custom scheduler matches PyTorch ExponentialLR when configured identically")
    else:
        print("  ✗ Custom scheduler deviates from PyTorch ExponentialLR when configured identically")
    print(f"  ✓ Both decay from {lr_init:.1e} to {lr_final_gsplat:.1e} (1% of initial)")
    print(f"  Max relative difference over tested steps: {max_rel_1:.2e}%")

    final_ratio = lr_final_original / lr_final_gsplat
    print("\nTEST 2 (Original config vs gsplat):")
    print("  ⚠ Original config has HIGHER final LR:")
    print(f"    - gsplat final LR: {lr_final_gsplat:.2e}")
    print(f"    - Original final LR: {lr_final_original:.2e}")
    print(f"    - At step {max_steps}: Original LR is {final_ratio:.2f}x higher than gsplat")
    print(f"    - Max relative difference over tested steps: {max_rel_2:.2f}%")

    print("\n" + banner)
    
if __name__ == "__main__":
    # Running this module directly prints the LR scheduler comparison tables.
    test_lr_schedulers()