Spaces:
Running on Zero
Running on Zero
| """ | |
Custom optimizer constructor for viewpoint-conditioned training.

Supports per-component learning rates:
- viewpoint_mlp: higher LR for the new viewpoint token module
- viewpoint_head: higher LR for the new viewpoint prediction head
- llm: lower LR for the pretrained LLM
- mar: lower LR for the pretrained MAR
- proj_in/proj_out: medium LR for the projection layers
All remaining parameters use the optimizer's base LR. A component whose LR
is set to 0 is frozen (requires_grad=False) and left out of the optimizer.
| """ | |
| import torch.nn as nn | |
| from mmengine.optim import DefaultOptimWrapperConstructor, OptimWrapper | |
| from mmengine.registry import OPTIM_WRAPPER_CONSTRUCTORS, OPTIM_WRAPPERS, OPTIMIZERS | |
| import inspect | |
class ViewpointOptimWrapperConstructor(DefaultOptimWrapperConstructor):
    """
    Custom optimizer wrapper constructor with parameter-wise learning rates.

    Each parameter is assigned to exactly one component by substring matching
    on its qualified name. The first matching rule in ``_ROUTES`` wins:

    - names containing ``viewpoint_mlp`` / ``viewpoint_head``: ``lr_viewpoint``
    - names containing ``llm``: ``lr_llm``
    - names containing ``mar``: ``lr_mar``
    - names containing ``proj_in`` or ``proj_out``: ``lr_proj``
    - everything else: the optimizer's own ``lr`` (default 1e-5)

    Within each component, the following get ``weight_decay=0``:

    - 1-D tensors
    - ``.bias`` parameters
    - anything whose name contains ``diffloss``

    All other parameters use the optimizer's ``weight_decay`` (default 0.02).

    Expects the following optional keys in optim_wrapper_cfg. They are
    removed before the wrapper itself is built:

    - lr_viewpoint: Learning rate for viewpoint modules (default: 1e-3)
    - lr_llm: Learning rate for LLM (default: 1e-5)
    - lr_mar: Learning rate for MAR (default: 1e-5)
    - lr_proj: Learning rate for projection layers (default: 1e-4)

    A component whose learning rate is exactly 0 is frozen. Its parameters
    get ``requires_grad = False`` and are left out of the optimizer. The
    ``other`` bucket is never frozen.
    """

    # (param-group prefix, lr component, name substrings), checked in order.
    # The first rule with a matching substring claims the parameter. Both
    # freezing and grouping use this single table, so a parameter is frozen
    # only if the group it would be trained in has lr == 0.
    _ROUTES = (
        ('viewpoint_mlp', 'viewpoint', ('viewpoint_mlp',)),
        ('viewpoint_head', 'viewpoint', ('viewpoint_head',)),
        ('llm', 'llm', ('llm',)),
        ('mar', 'mar', ('mar',)),
        ('proj', 'proj', ('proj_in', 'proj_out')),
    )

    @classmethod
    def _route(cls, name: str) -> tuple:
        """Return ``(group_prefix, lr_component)`` for a parameter name.

        Names that match no rule map to ``('other', None)``.
        """
        for prefix, component, patterns in cls._ROUTES:
            if any(pattern in name for pattern in patterns):
                return prefix, component
        return 'other', None

    @staticmethod
    def _uses_weight_decay(name: str, param: nn.Parameter) -> bool:
        """Return False for parameters that should not be weight-decayed.

        This covers 1-D tensors (bias and norm parameters), ``.bias``
        parameters, and ``diffloss`` parameters. Everything else returns True.
        """
        return not (len(param.shape) == 1 or name.endswith(".bias")
                    or 'diffloss' in name)

    def _freeze_zero_lr_components(self, model: nn.Module,
                                   component_lrs: dict) -> None:
        """Set ``requires_grad = False`` on parameters of zero-lr components.

        Also prints how many elements were frozen per zero-lr component.
        """
        print("\n" + "="*80)
        print("Viewpoint Optimizer: Checking for frozen components (lr=0)")
        print("="*80)
        # Insertion order follows component_lrs, which keeps the report order
        # stable: viewpoint, llm, mar, proj.
        frozen = {comp: 0 for comp, lr in component_lrs.items() if lr == 0}
        if frozen:
            for name, param in model.named_parameters():
                _, component = self._route(name)
                # component is None for 'other', which is never in `frozen`.
                if component in frozen:
                    param.requires_grad = False
                    frozen[component] += param.numel()
        for component, count in frozen.items():
            print(f" ✓ Frozen {component}: {count:,} parameters (lr=0)")
        if not frozen:
            print(" No components frozen (all have lr > 0)")
        print("="*80 + "\n")

    def _build_param_groups(self, model: nn.Module, component_lrs: dict,
                            base_lr: float, weight_decay: float) -> list:
        """Bucket trainable parameters into per-component decay groups.

        Each component is split into a decay and a no-decay group. Groups
        are emitted in ``_ROUTES`` order followed by ``other``, each as
        ``<prefix>_decay`` then ``<prefix>_no_decay``. Empty groups are
        omitted.
        """
        buckets = {}
        for name, param in model.named_parameters():
            # Frozen parameters (including those frozen for lr == 0) are
            # not given to the optimizer at all.
            if not param.requires_grad:
                continue
            prefix, _ = self._route(name)
            decay = self._uses_weight_decay(name, param)
            buckets.setdefault((prefix, decay), []).append(param)

        order = [(prefix, component) for prefix, component, _ in self._ROUTES]
        order.append(('other', None))

        param_groups = []
        for prefix, component in order:
            lr = base_lr if component is None else component_lrs[component]
            for decay, suffix in ((True, 'decay'), (False, 'no_decay')):
                params = buckets.get((prefix, decay))
                if params:
                    param_groups.append({
                        'params': params,
                        'lr': lr,
                        'weight_decay': weight_decay if decay else 0.0,
                        'name': f'{prefix}_{suffix}',
                    })
        return param_groups

    @staticmethod
    def _print_param_groups(param_groups: list) -> None:
        """Print lr, weight decay and element count for every param group."""
        print("\n" + "="*80)
        print("Viewpoint Optimizer Parameter Groups:")
        print("="*80)
        for group in param_groups:
            num_params = sum(p.numel() for p in group['params'])
            print(f" {group['name']:30s} | LR: {group['lr']:.2e} | "
                  f"Weight Decay: {group['weight_decay']:.2e} | "
                  f"Params: {num_params:,}")
        print("="*80 + "\n")

    def __call__(self, model: nn.Module) -> OptimWrapper:
        """Build the optimizer wrapper for ``model``.

        Args:
            model: Model to optimize. If it has a ``.module`` attribute
                (e.g. a distributed wrapper), that inner module is used
                instead. Parameters of zero-lr components are frozen
                in place.

        Returns:
            OptimWrapper: Built from ``optim_wrapper_cfg`` (``type`` defaults
            to ``'OptimWrapper'``). It wraps an optimizer over the
            per-component parameter groups.

        Raises:
            ValueError: If no trainable parameters remain after freezing.
        """
        if hasattr(model, 'module'):
            model = model.module
        optim_wrapper_cfg = self.optim_wrapper_cfg.copy()
        optim_wrapper_cfg.setdefault('type', 'OptimWrapper')
        optimizer_cfg = self.optimizer_cfg.copy()

        # The base lr stays in optimizer_cfg; it is also used for 'other'.
        base_lr = optimizer_cfg.get('lr', 1e-5)
        # Weight decay is set explicitly on every group, so the global value
        # is removed from the optimizer kwargs.
        weight_decay = optimizer_cfg.pop('weight_decay', 0.02)

        # The lr_* keys are consumed here. Popping them keeps them out of the
        # kwargs passed to OPTIM_WRAPPERS.build below.
        component_lrs = {
            'viewpoint': optim_wrapper_cfg.pop('lr_viewpoint', 1e-3),
            'llm': optim_wrapper_cfg.pop('lr_llm', 1e-5),
            'mar': optim_wrapper_cfg.pop('lr_mar', 1e-5),
            'proj': optim_wrapper_cfg.pop('lr_proj', 1e-4),
        }

        self._freeze_zero_lr_components(model, component_lrs)
        param_groups = self._build_param_groups(
            model, component_lrs, base_lr, weight_decay)
        self._print_param_groups(param_groups)
        if not param_groups:
            raise ValueError(
                'ViewpointOptimWrapperConstructor: no trainable parameters '
                'remain after freezing; check the lr_* settings.')

        # Resolve the optimizer class to find the name of its first
        # constructor argument, then pass the param groups under that name.
        optimizer_cls = optimizer_cfg['type']
        if isinstance(optimizer_cls, str):
            with OPTIMIZERS.switch_scope_and_registry(None) as registry:
                optimizer_cls = registry.get(optimizer_cfg['type'])
        first_arg_name = next(iter(inspect.signature(optimizer_cls).parameters))
        optimizer_cfg[first_arg_name] = param_groups
        optimizer = OPTIMIZERS.build(optimizer_cfg)

        return OPTIM_WRAPPERS.build(
            optim_wrapper_cfg, default_args=dict(optimizer=optimizer))