File size: 9,781 Bytes
becf13a
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
"""
Custom optimizer constructor for viewpoint-conditioned training.

Supports parameter-wise learning rates for different model components:
- viewpoint_mlp: Higher LR for new viewpoint token module
- viewpoint_head: Higher LR for new viewpoint prediction head
- llm: Lower LR for pretrained LLM
- mar: Lower LR for pretrained MAR
- proj_in/proj_out: Medium LR for projection layers
"""

import torch.nn as nn
from mmengine.optim import DefaultOptimWrapperConstructor, OptimWrapper
from mmengine.registry import OPTIM_WRAPPER_CONSTRUCTORS, OPTIM_WRAPPERS, OPTIMIZERS
import inspect


@OPTIM_WRAPPER_CONSTRUCTORS.register_module()
class ViewpointOptimWrapperConstructor(DefaultOptimWrapperConstructor):
    """
    Custom optimizer wrapper constructor with parameter-wise learning rates.

    Expects the following parameters in optim_wrapper_cfg:
    - lr_viewpoint: Learning rate for viewpoint modules (default: 1e-3)
    - lr_llm: Learning rate for LLM (default: 1e-5)
    - lr_mar: Learning rate for MAR (default: 1e-5)
    - lr_proj: Learning rate for projection layers (default: 1e-4)

    These keys are popped from the wrapper config before the wrapper is
    built. Parameters matching no component are trained with the
    optimizer's own ``lr`` (default 1e-5). A component whose learning rate
    is exactly 0 is frozen (``requires_grad = False``) and left out of the
    optimizer.

    Every parameter is assigned to exactly one group by the first matching
    rule in ``_GROUP_RULES`` (substring match on the parameter name). The
    same assignment drives both freezing and grouping, so a parameter is
    only frozen because of the learning rate of the group it would
    otherwise have been trained in.
    """

    # (group name, learning-rate component, name substrings), checked in
    # order; the first match wins. This order is also the order of the
    # optimizer's param groups, which saved optimizer states rely on.
    _GROUP_RULES = (
        ('viewpoint_mlp', 'viewpoint', ('viewpoint_mlp',)),
        ('viewpoint_head', 'viewpoint', ('viewpoint_head',)),
        ('llm', 'llm', ('llm',)),
        ('mar', 'mar', ('mar',)),
        ('proj', 'proj', ('proj_in', 'proj_out')),
    )
    # Fallback group for parameters that match no rule (uses the base lr).
    _OTHER_GROUP = 'other'

    @classmethod
    def _classify(cls, name):
        """Return ``(group, component)`` for a parameter name.

        ``component`` is ``None`` for the fallback ``'other'`` group, which
        has no dedicated learning rate and is never frozen.
        """
        for group, component, patterns in cls._GROUP_RULES:
            if any(pattern in name for pattern in patterns):
                return group, component
        return cls._OTHER_GROUP, None

    @classmethod
    def _freeze_zero_lr_components(cls, model, component_lrs):
        """Set ``requires_grad = False`` on every component whose lr is 0.

        Args:
            model: Unwrapped model whose parameters are inspected.
            component_lrs: Mapping component name -> learning rate.

        Prints, per frozen component, the number of parameter elements
        assigned to it (elements that were already frozen are included).
        """
        print("\n" + "="*80)
        print("Viewpoint Optimizer: Checking for frozen components (lr=0)")
        print("="*80)

        # Insertion order follows component_lrs, so output order is stable.
        frozen_counts = {
            component: 0
            for component, component_lr in component_lrs.items()
            if component_lr == 0
        }
        if frozen_counts:
            # One pass over the parameters regardless of how many
            # components are frozen.
            for name, param in model.named_parameters():
                _, component = cls._classify(name)
                if component in frozen_counts:
                    param.requires_grad = False
                    frozen_counts[component] += param.numel()

        for component, num_frozen in frozen_counts.items():
            print(f"  ✓ Frozen {component}: {num_frozen:,} parameters (lr=0)")
        if not frozen_counts:
            print("  No components frozen (all have lr > 0)")
        print("="*80 + "\n")

    @classmethod
    def _build_param_groups(cls, model, group_lrs, weight_decay):
        """Split trainable parameters into per-group decay / no-decay groups.

        Args:
            model: Unwrapped model whose parameters are grouped.
            group_lrs: Mapping group name -> learning rate.
            weight_decay: Weight decay applied to the "decay" groups.

        Returns:
            A list of torch param-group dicts with keys ``params``, ``lr``,
            ``weight_decay`` and ``name``. Empty groups are omitted; the
            order is ``<group>_decay`` then ``<group>_no_decay`` for each
            group in ``_GROUP_RULES`` order, followed by ``other``.
        """
        group_names = [rule[0] for rule in cls._GROUP_RULES]
        group_names.append(cls._OTHER_GROUP)
        # Dict insertion order fixes the param-group order (see Returns).
        buckets = {
            (group, decay): []
            for group in group_names
            for decay in (True, False)
        }

        for name, param in model.named_parameters():
            if not param.requires_grad:
                continue
            # Skip weight decay for bias, norms (1-D params), and diffloss.
            apply_decay = not (
                len(param.shape) == 1
                or name.endswith(".bias")
                or 'diffloss' in name
            )
            group, _ = cls._classify(name)
            buckets[(group, apply_decay)].append(param)

        param_groups = []
        for (group, decay), params in buckets.items():
            if not params:
                continue
            param_groups.append({
                'params': params,
                'lr': group_lrs[group],
                'weight_decay': weight_decay if decay else 0.0,
                'name': f"{group}_decay" if decay else f"{group}_no_decay",
            })
        return param_groups

    def __call__(self, model: nn.Module) -> OptimWrapper:
        """Build the optimizer wrapper with per-component learning rates.

        Args:
            model: Model to optimize. If it exposes ``.module`` (e.g. a
                distributed wrapper), the inner module is used.

        Returns:
            The optimizer wrapper built from ``optim_wrapper_cfg``.

        Raises:
            ValueError: If no trainable parameters remain after freezing.
        """
        if hasattr(model, 'module'):
            model = model.module

        optim_wrapper_cfg = self.optim_wrapper_cfg.copy()
        optim_wrapper_cfg.setdefault('type', 'OptimWrapper')
        optimizer_cfg = self.optimizer_cfg.copy()

        # The base lr stays in optimizer_cfg as the optimizer default; weight
        # decay is removed there because it is set explicitly per group.
        base_lr = optimizer_cfg.get('lr', 1e-5)
        weight_decay = optimizer_cfg.pop('weight_decay', 0.02)

        # Component-specific learning rates (with fallbacks). They are popped
        # so they are not forwarded to the optimizer wrapper's constructor.
        component_lrs = {
            'viewpoint': optim_wrapper_cfg.pop('lr_viewpoint', 1e-3),
            'llm': optim_wrapper_cfg.pop('lr_llm', 1e-5),
            'mar': optim_wrapper_cfg.pop('lr_mar', 1e-5),
            'proj': optim_wrapper_cfg.pop('lr_proj', 1e-4),
        }

        # lr == 0 means "do not train": freezing saves the memory and
        # compute otherwise spent on those gradients.
        self._freeze_zero_lr_components(model, component_lrs)

        group_lrs = {
            group: component_lrs[component]
            for group, component, _ in self._GROUP_RULES
        }
        group_lrs[self._OTHER_GROUP] = base_lr
        param_groups = self._build_param_groups(model, group_lrs, weight_decay)
        if not param_groups:
            raise ValueError(
                "ViewpointOptimWrapperConstructor: no trainable parameters "
                "left after freezing zero-lr components.")

        # Print parameter group statistics
        print("\n" + "="*80)
        print("Viewpoint Optimizer Parameter Groups:")
        print("="*80)
        for group in param_groups:
            num_params = sum(p.numel() for p in group['params'])
            print(f"  {group['name']:30s} | LR: {group['lr']:.2e} | "
                  f"Weight Decay: {group['weight_decay']:.2e} | "
                  f"Params: {num_params:,}")
        print("="*80 + "\n")

        # Resolve the optimizer class so its first constructor argument name
        # can be filled with the param groups (some optimizers do not call
        # it ``params``).
        optimizer_cls = self.optimizer_cfg['type']
        if isinstance(optimizer_cls, str):
            with OPTIMIZERS.switch_scope_and_registry(None) as registry:
                optimizer_cls = registry.get(self.optimizer_cfg['type'])

        first_arg_name = next(iter(inspect.signature(optimizer_cls).parameters))
        optimizer_cfg[first_arg_name] = param_groups
        optimizer = OPTIMIZERS.build(optimizer_cfg)

        # Build optimizer wrapper
        optim_wrapper = OPTIM_WRAPPERS.build(
            optim_wrapper_cfg, default_args=dict(optimizer=optimizer))

        return optim_wrapper