viewtoken-harmon-demo / src /optimisers /viewpoint_constructor.py
XinxuanLu's picture
Initial demo
becf13a verified
"""
Custom optimizer constructor for viewpoint-conditioned training.
Supports parameter-wise learning rates for different model components:
- viewpoint_mlp: Higher LR for new viewpoint token module
- viewpoint_head: Higher LR for new viewpoint prediction head
- llm: Lower LR for pretrained LLM
- mar: Lower LR for pretrained MAR
- proj_in/proj_out: Medium LR for projection layers
"""
import torch.nn as nn
from mmengine.optim import DefaultOptimWrapperConstructor, OptimWrapper
from mmengine.registry import OPTIM_WRAPPER_CONSTRUCTORS, OPTIM_WRAPPERS, OPTIMIZERS
import inspect
@OPTIM_WRAPPER_CONSTRUCTORS.register_module()
class ViewpointOptimWrapperConstructor(DefaultOptimWrapperConstructor):
    """
    Custom optimizer wrapper constructor with parameter-wise learning rates.

    Parameters are assigned to a component by substring match on their
    name, using the first matching rule in ``_GROUP_RULES``; anything that
    matches no rule falls into the ``other`` group and uses the optimizer's
    base ``lr``.

    Expects the following parameters in optim_wrapper_cfg:
    - lr_viewpoint: Learning rate for viewpoint modules (default: 1e-3)
    - lr_llm: Learning rate for LLM (default: 1e-5)
    - lr_mar: Learning rate for MAR (default: 1e-5)
    - lr_proj: Learning rate for projection layers (default: 1e-4)

    A component whose learning rate is exactly 0 is frozen: its parameters
    get ``requires_grad = False`` and are left out of the optimizer.
    """

    # (group name, LR component, name substrings), in matching precedence
    # order. The order also fixes the order of the emitted param groups.
    _GROUP_RULES = (
        ('viewpoint_mlp', 'viewpoint', ('viewpoint_mlp',)),
        ('viewpoint_head', 'viewpoint', ('viewpoint_head',)),
        ('llm', 'llm', ('llm',)),
        ('mar', 'mar', ('mar',)),
        ('proj', 'proj', ('proj_in', 'proj_out')),
    )

    @classmethod
    def _classify(cls, name: str) -> tuple:
        """Return ``(group, component)`` for a parameter name.

        The first rule in ``_GROUP_RULES`` with a substring contained in
        ``name`` wins. Unmatched names yield ``('other', None)``.
        """
        for group, component, patterns in cls._GROUP_RULES:
            if any(pattern in name for pattern in patterns):
                return group, component
        return 'other', None

    def _freeze_zero_lr_components(self, model: nn.Module,
                                   component_lrs: dict) -> None:
        """Disable gradients for every component whose learning rate is 0.

        Uses the same classification as the grouping step, so a parameter
        is frozen only when the component it would be grouped under has
        lr == 0 (a name merely containing another component's substring
        no longer causes it to be frozen).
        """
        frozen_counts = {
            component: 0
            for component, lr in component_lrs.items() if lr == 0
        }

        print("\n" + "="*80)
        print("Viewpoint Optimizer: Checking for frozen components (lr=0)")
        print("="*80)

        if frozen_counts:
            for name, param in model.named_parameters():
                _, component = self._classify(name)
                if component in frozen_counts:
                    param.requires_grad = False
                    frozen_counts[component] += param.numel()

        for component, num_frozen in frozen_counts.items():
            print(f" ✓ Frozen {component}: {num_frozen:,} parameters (lr=0)")
        if not frozen_counts:
            print(" No components frozen (all have lr > 0)")
        print("="*80 + "\n")

    def _build_param_groups(self, model: nn.Module, component_lrs: dict,
                            base_lr: float, weight_decay: float) -> list:
        """Bucket trainable parameters into per-component param groups.

        Each component yields up to two groups, ``<group>_decay`` and
        ``<group>_no_decay``; empty groups are omitted. 1-D parameters,
        names ending in ``.bias`` and names containing ``diffloss`` get
        zero weight decay.
        """
        group_lrs = {
            group: component_lrs[component]
            for group, component, _ in self._GROUP_RULES
        }
        group_lrs['other'] = base_lr

        # group -> (params with decay, params without decay)
        buckets = {group: ([], []) for group in group_lrs}
        for name, param in model.named_parameters():
            if not param.requires_grad:
                continue
            skip_decay = (len(param.shape) == 1
                          or name.endswith(".bias")
                          or 'diffloss' in name)
            group, _ = self._classify(name)
            buckets[group][1 if skip_decay else 0].append(param)

        param_groups = []
        for group, lr in group_lrs.items():
            decay_params, no_decay_params = buckets[group]
            if decay_params:
                param_groups.append({
                    'params': decay_params,
                    'lr': lr,
                    'weight_decay': weight_decay,
                    'name': f'{group}_decay',
                })
            if no_decay_params:
                param_groups.append({
                    'params': no_decay_params,
                    'lr': lr,
                    'weight_decay': 0.0,
                    'name': f'{group}_no_decay',
                })
        return param_groups

    @staticmethod
    def _print_param_groups(param_groups: list) -> None:
        """Print name, LR, weight decay and parameter count per group."""
        print("\n" + "="*80)
        print("Viewpoint Optimizer Parameter Groups:")
        print("="*80)
        for group in param_groups:
            num_params = sum(p.numel() for p in group['params'])
            print(f" {group['name']:30s} | LR: {group['lr']:.2e} | "
                  f"Weight Decay: {group['weight_decay']:.2e} | "
                  f"Params: {num_params:,}")
        print("="*80 + "\n")

    def __call__(self, model: nn.Module) -> OptimWrapper:
        """Build the optimizer wrapper for ``model``.

        Args:
            model: Model to optimize. If it exposes a ``module`` attribute
                (i.e. it is a wrapper), the wrapped module is used instead.

        Returns:
            The optimizer wrapper built from ``optim_wrapper_cfg`` around
            an optimizer that holds the per-component parameter groups.

        Side effects:
            Sets ``requires_grad = False`` on parameters of components
            whose learning rate is 0, and prints a summary to stdout.
        """
        if hasattr(model, 'module'):
            model = model.module

        optim_wrapper_cfg = self.optim_wrapper_cfg.copy()
        optim_wrapper_cfg.setdefault('type', 'OptimWrapper')
        optimizer_cfg = self.optimizer_cfg.copy()

        # Base LR stays in optimizer_cfg; weight decay is removed because
        # every param group carries its own explicit value.
        base_lr = optimizer_cfg.get('lr', 1e-5)
        weight_decay = optimizer_cfg.pop('weight_decay', 0.02)

        # Pop the custom keys so they are not forwarded to the wrapper.
        component_lrs = {
            'viewpoint': optim_wrapper_cfg.pop('lr_viewpoint', 1e-3),
            'llm': optim_wrapper_cfg.pop('lr_llm', 1e-5),
            'mar': optim_wrapper_cfg.pop('lr_mar', 1e-5),
            'proj': optim_wrapper_cfg.pop('lr_proj', 1e-4),
        }

        self._freeze_zero_lr_components(model, component_lrs)
        param_groups = self._build_param_groups(
            model, component_lrs, base_lr, weight_decay)
        self._print_param_groups(param_groups)

        # Resolve the optimizer class so the parameter groups can be passed
        # under whatever name its first constructor argument has.
        optimizer_cls = self.optimizer_cfg['type']
        if isinstance(optimizer_cls, str):
            with OPTIMIZERS.switch_scope_and_registry(None) as registry:
                optimizer_cls = registry.get(self.optimizer_cfg['type'])
        first_arg_name = next(
            iter(inspect.signature(optimizer_cls).parameters))
        optimizer_cfg[first_arg_name] = param_groups
        optimizer = OPTIMIZERS.build(optimizer_cfg)

        optim_wrapper = OPTIM_WRAPPERS.build(
            optim_wrapper_cfg, default_args=dict(optimizer=optimizer))
        return optim_wrapper