""" Custom optimizer constructor for viewpoint-conditioned training. Supports parameter-wise learning rates for different model components: - viewpoint_mlp: Higher LR for new viewpoint token module - viewpoint_head: Higher LR for new viewpoint prediction head - llm: Lower LR for pretrained LLM - mar: Lower LR for pretrained MAR - proj_in/proj_out: Medium LR for projection layers """ import torch.nn as nn from mmengine.optim import DefaultOptimWrapperConstructor, OptimWrapper from mmengine.registry import OPTIM_WRAPPER_CONSTRUCTORS, OPTIM_WRAPPERS, OPTIMIZERS import inspect @OPTIM_WRAPPER_CONSTRUCTORS.register_module() class ViewpointOptimWrapperConstructor(DefaultOptimWrapperConstructor): """ Custom optimizer wrapper constructor with parameter-wise learning rates. Expects the following parameters in optim_wrapper_cfg: - lr_viewpoint: Learning rate for viewpoint modules (default: 1e-3) - lr_llm: Learning rate for LLM (default: 1e-5) - lr_mar: Learning rate for MAR (default: 1e-5) - lr_proj: Learning rate for projection layers (default: 1e-4) """ def __call__(self, model: nn.Module) -> OptimWrapper: if hasattr(model, 'module'): model = model.module optim_wrapper_cfg = self.optim_wrapper_cfg.copy() optim_wrapper_cfg.setdefault('type', 'OptimWrapper') optimizer_cfg = self.optimizer_cfg.copy() # Get base learning rate and weight decay base_lr = optimizer_cfg.get('lr', 1e-5) weight_decay = optimizer_cfg.pop('weight_decay', 0.02) # Get component-specific learning rates (with fallbacks) lr_viewpoint = optim_wrapper_cfg.pop('lr_viewpoint', 1e-3) lr_llm = optim_wrapper_cfg.pop('lr_llm', 1e-5) lr_mar = optim_wrapper_cfg.pop('lr_mar', 1e-5) lr_proj = optim_wrapper_cfg.pop('lr_proj', 1e-4) # Freeze parameters for components with lr=0 # This saves memory and computation by not computing gradients frozen_components = [] component_lr_map = { 'viewpoint': lr_viewpoint, # Covers both viewpoint_mlp and viewpoint_head 'llm': lr_llm, 'mar': lr_mar, 'proj': lr_proj, } print("\n" + "="*80) print("Viewpoint Optimizer: Checking for frozen components (lr=0)") print("="*80) for component_name, component_lr in component_lr_map.items(): if component_lr == 0: frozen_components.append(component_name) # Freeze parameters for this component num_frozen = 0 for name, param in model.named_parameters(): # Match component name patterns should_freeze = False if component_name == 'viewpoint' and ('viewpoint_mlp' in name or 'viewpoint_head' in name): should_freeze = True elif component_name == 'llm' and 'llm' in name: should_freeze = True elif component_name == 'mar' and 'mar' in name: should_freeze = True elif component_name == 'proj' and ('proj_in' in name or 'proj_out' in name): should_freeze = True if should_freeze: param.requires_grad = False num_frozen += param.numel() print(f" ✓ Frozen {component_name}: {num_frozen:,} parameters (lr=0)") if not frozen_components: print(" No components frozen (all have lr > 0)") print("="*80 + "\n") # Categorize parameters by component viewpoint_mlp_params = [] viewpoint_head_params = [] llm_params = [] mar_params = [] proj_params = [] other_params = [] # Track no-decay parameters viewpoint_mlp_no_decay = [] viewpoint_head_no_decay = [] llm_no_decay = [] mar_no_decay = [] proj_no_decay = [] other_no_decay = [] for name, param in model.named_parameters(): if not param.requires_grad: continue # Determine if parameter should have weight decay # Skip bias, norms, and diffloss apply_decay = True if len(param.shape) == 1 or name.endswith(".bias") or 'diffloss' in name: apply_decay = False # Categorize by component if 'viewpoint_mlp' in name: if apply_decay: viewpoint_mlp_params.append(param) else: viewpoint_mlp_no_decay.append(param) elif 'viewpoint_head' in name: if apply_decay: viewpoint_head_params.append(param) else: viewpoint_head_no_decay.append(param) elif 'llm' in name: if apply_decay: llm_params.append(param) else: llm_no_decay.append(param) elif 'mar' in name: if apply_decay: mar_params.append(param) else: mar_no_decay.append(param) elif 'proj_in' in name or 'proj_out' in name: if apply_decay: proj_params.append(param) else: proj_no_decay.append(param) else: if apply_decay: other_params.append(param) else: other_no_decay.append(param) # Build parameter groups param_groups = [] # Viewpoint MLP (with decay) if viewpoint_mlp_params: param_groups.append({ 'params': viewpoint_mlp_params, 'lr': lr_viewpoint, 'weight_decay': weight_decay, 'name': 'viewpoint_mlp_decay' }) if viewpoint_mlp_no_decay: param_groups.append({ 'params': viewpoint_mlp_no_decay, 'lr': lr_viewpoint, 'weight_decay': 0.0, 'name': 'viewpoint_mlp_no_decay' }) # Viewpoint Head (with decay) if viewpoint_head_params: param_groups.append({ 'params': viewpoint_head_params, 'lr': lr_viewpoint, 'weight_decay': weight_decay, 'name': 'viewpoint_head_decay' }) if viewpoint_head_no_decay: param_groups.append({ 'params': viewpoint_head_no_decay, 'lr': lr_viewpoint, 'weight_decay': 0.0, 'name': 'viewpoint_head_no_decay' }) # LLM if llm_params: param_groups.append({ 'params': llm_params, 'lr': lr_llm, 'weight_decay': weight_decay, 'name': 'llm_decay' }) if llm_no_decay: param_groups.append({ 'params': llm_no_decay, 'lr': lr_llm, 'weight_decay': 0.0, 'name': 'llm_no_decay' }) # MAR if mar_params: param_groups.append({ 'params': mar_params, 'lr': lr_mar, 'weight_decay': weight_decay, 'name': 'mar_decay' }) if mar_no_decay: param_groups.append({ 'params': mar_no_decay, 'lr': lr_mar, 'weight_decay': 0.0, 'name': 'mar_no_decay' }) # Projection layers if proj_params: param_groups.append({ 'params': proj_params, 'lr': lr_proj, 'weight_decay': weight_decay, 'name': 'proj_decay' }) if proj_no_decay: param_groups.append({ 'params': proj_no_decay, 'lr': lr_proj, 'weight_decay': 0.0, 'name': 'proj_no_decay' }) # Other parameters if other_params: param_groups.append({ 'params': other_params, 'lr': base_lr, 'weight_decay': weight_decay, 'name': 'other_decay' }) if other_no_decay: param_groups.append({ 'params': other_no_decay, 'lr': base_lr, 'weight_decay': 0.0, 'name': 'other_no_decay' }) # Print parameter group statistics print("\n" + "="*80) print("Viewpoint Optimizer Parameter Groups:") print("="*80) for group in param_groups: num_params = sum(p.numel() for p in group['params']) print(f" {group['name']:30s} | LR: {group['lr']:.2e} | " f"Weight Decay: {group['weight_decay']:.2e} | " f"Params: {num_params:,}") print("="*80 + "\n") # Build optimizer optimizer_cls = self.optimizer_cfg['type'] if isinstance(optimizer_cls, str): with OPTIMIZERS.switch_scope_and_registry(None) as registry: optimizer_cls = registry.get(self.optimizer_cfg['type']) first_arg_name = next(iter(inspect.signature(optimizer_cls).parameters)) optimizer_cfg[first_arg_name] = param_groups optimizer = OPTIMIZERS.build(optimizer_cfg) # Build optimizer wrapper optim_wrapper = OPTIM_WRAPPERS.build( optim_wrapper_cfg, default_args=dict(optimizer=optimizer)) return optim_wrapper