# anigen/utils/optimizer_utils.py
"""
Optimizer utilities for PyTorch Lightning systems.
"""
import torch
import torch.nn as nn
from collections import deque
from typing import Any, Dict


def create_muon_optimizer(models: Dict[str, nn.Module], optimizer_args: Dict[str, Any]):
"""
Create Muon optimizer with different learning rates for different parameters.
Args:
models: Dictionary of models
optimizer_args: Optimizer configuration with muon_lr, muon_weight_decay,
other_lr, other_weight_decay
Returns:
Configured optimizer
"""
muon_lr = optimizer_args.get('muon_lr', 0.02)
muon_weight_decay = optimizer_args.get('muon_weight_decay', 0.01)
other_lr = optimizer_args.get('other_lr', 1e-4)
other_weight_decay = optimizer_args.get('other_weight_decay', 0.01)
# Separate parameters for Muon and other optimizers
muon_params = []
other_params = []
for name, model in models.items():
for param_name, param in model.named_parameters():
if not param.requires_grad:
continue
            # Route parameters that look like large weight matrices to the
            # Muon group. This is a heuristic; adjust it to match your model
            # architecture.
            if ('weight' in param_name and
                    param.dim() >= 2 and
                    param.numel() >= 1024):
muon_params.append(param)
else:
other_params.append(param)
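
    # Example routing (illustrative): nn.Linear(64, 64).weight has dim 2 and
    # 4096 elements, so it is assigned to muon_params, while its 1-D bias
    # falls through to other_params.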
# Try to import and use Muon optimizer
try:
from muon import Muon
# Create parameter groups
param_groups = []
if muon_params:
param_groups.append({
'params': muon_params,
'lr': muon_lr,
'weight_decay': muon_weight_decay,
'momentum': 0.95, # Muon-specific parameter
})
if other_params:
param_groups.append({
'params': other_params,
'lr': other_lr,
'weight_decay': other_weight_decay,
})
        # NOTE: constructor details vary across Muon implementations; some
        # expect only matrix parameters and pair the rest with an auxiliary
        # optimizer. Verify against the installed `muon` package.
        return Muon(param_groups)
    except ImportError:
        print("Warning: Muon optimizer not available. Falling back to AdamW.")
        # Fallback: a single AdamW over all parameters, using the `other_*`
        # hyperparameters (the Muon-specific learning rate is typically too
        # high for AdamW, so it is not reused).
        all_params = muon_params + other_params
        return torch.optim.AdamW(all_params, lr=other_lr, weight_decay=other_weight_decay)
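

# Illustrative usage of create_muon_optimizer (a sketch; `model` and the
# hyperparameter values are placeholders, not part of this module):
#
#     model = nn.Linear(64, 64)
#     optimizer = create_muon_optimizer(
#         {"net": model},
#         {"muon_lr": 0.02, "muon_weight_decay": 0.01, "other_lr": 1e-4},
#     )
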
def create_optimizer(models: Dict[str, nn.Module], optimizer_config: Dict[str, Any]):
"""
Create optimizer based on configuration.
Args:
models: Dictionary of models
optimizer_config: Optimizer configuration
Returns:
Configured optimizer
"""
optimizer_name = optimizer_config.get('name', 'Adam').lower()
optimizer_args = optimizer_config.get('args', {})
# Get all parameters
params = []
for model in models.values():
params.extend([p for p in model.parameters() if p.requires_grad])
if optimizer_name == 'adam':
return torch.optim.Adam(params, **optimizer_args)
elif optimizer_name == 'adamw':
return torch.optim.AdamW(params, **optimizer_args)
elif optimizer_name == 'sgd':
return torch.optim.SGD(params, **optimizer_args)
elif optimizer_name == 'muon':
return create_muon_optimizer(models, optimizer_args)
else:
raise ValueError(f"Unknown optimizer: {optimizer_name}")
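

# Example create_optimizer call (illustrative values):
#
#     optimizer = create_optimizer(
#         {"net": model},
#         {"name": "AdamW", "args": {"lr": 1e-4, "weight_decay": 0.01}},
#     )
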
def create_lr_scheduler(optimizer, lr_scheduler_config: Dict[str, Any]):
"""
Create learning rate scheduler based on configuration.
Args:
optimizer: The optimizer to schedule
lr_scheduler_config: Scheduler configuration
Returns:
Configured scheduler
"""
scheduler_name = lr_scheduler_config.get('name', 'CosineAnnealingLR')
scheduler_args = lr_scheduler_config.get('args', {})
if scheduler_name == 'CosineAnnealingLR':
return torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, **scheduler_args)
elif scheduler_name == 'LinearLR':
return torch.optim.lr_scheduler.LinearLR(optimizer, **scheduler_args)
elif scheduler_name == 'ExponentialLR':
return torch.optim.lr_scheduler.ExponentialLR(optimizer, **scheduler_args)
elif scheduler_name == 'StepLR':
return torch.optim.lr_scheduler.StepLR(optimizer, **scheduler_args)
elif scheduler_name == 'MultiStepLR':
return torch.optim.lr_scheduler.MultiStepLR(optimizer, **scheduler_args)
    elif scheduler_name == 'SequentialLR':
        # Build each child scheduler, then chain them. `scheduler_args` must
        # include the `milestones` at which SequentialLR switches schedulers.
        schedulers = [
            create_single_scheduler(sched_config, optimizer)
            for sched_config in lr_scheduler_config['schedulers']
        ]
        return torch.optim.lr_scheduler.SequentialLR(
            optimizer, schedulers, **scheduler_args
        )
else:
raise ValueError(f"Unknown scheduler: {scheduler_name}")
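

# Example SequentialLR configuration (illustrative): linear warmup for the
# first 1000 steps, then cosine decay; `milestones` marks the switch point.
#
#     scheduler = create_lr_scheduler(optimizer, {
#         "name": "SequentialLR",
#         "args": {"milestones": [1000]},
#         "schedulers": [
#             {"name": "LinearLR",
#              "args": {"start_factor": 0.01, "total_iters": 1000}},
#             {"name": "CosineAnnealingLR", "args": {"T_max": 10000}},
#         ],
#     })
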
def create_single_scheduler(scheduler_config: Dict[str, Any], optimizer):
    """Create a single child scheduler for SequentialLR."""
    name = scheduler_config['name']
    args = scheduler_config['args']
    scheduler_classes = {
        'LinearLR': torch.optim.lr_scheduler.LinearLR,
        'CosineAnnealingLR': torch.optim.lr_scheduler.CosineAnnealingLR,
        'ExponentialLR': torch.optim.lr_scheduler.ExponentialLR,
        'StepLR': torch.optim.lr_scheduler.StepLR,
        'MultiStepLR': torch.optim.lr_scheduler.MultiStepLR,
    }
    if name not in scheduler_classes:
        raise ValueError(f"Unknown scheduler: {name}")
    return scheduler_classes[name](optimizer, **args)


class AdaptiveGradClipper:
    """
    Adaptive gradient clipping based on a percentile of recent gradient norms.

    Keeps a rolling history of total gradient norms and clips to the smaller
    of ``max_norm`` and the ``clip_percentile``-th percentile of that history.
    """

    def __init__(self, max_norm: float = 1.0, clip_percentile: float = 95):
        self.max_norm = max_norm
        self.clip_percentile = clip_percentile
        self.history_size = 1000
        # deque(maxlen=...) evicts the oldest entry automatically, avoiding
        # the O(n) cost of list.pop(0).
        self.grad_norms_history = deque(maxlen=self.history_size)

    def __call__(self, parameters):
        """Clip gradients in place and return the pre-clipping total norm."""
        if isinstance(parameters, torch.Tensor):
            parameters = [parameters]
        parameters = [p for p in parameters if p.grad is not None]
        if not parameters:
            return torch.tensor(0.0)
        # Total L2 norm over all parameter gradients.
        total_norm = torch.norm(
            torch.stack([torch.norm(p.grad.detach()) for p in parameters])
        )
        self.grad_norms_history.append(total_norm.item())
        # Use the percentile threshold once enough history has accumulated;
        # fall back to the fixed max_norm before that.
        if len(self.grad_norms_history) >= 10:
            threshold = torch.quantile(
                torch.tensor(list(self.grad_norms_history)),
                self.clip_percentile / 100.0,
            )
            clip_value = min(self.max_norm, threshold.item())
        else:
            clip_value = self.max_norm
        # Rescale gradients in place, mirroring torch.nn.utils.clip_grad_norm_.
        if total_norm > clip_value:
            clip_coef = clip_value / (total_norm + 1e-6)
            for p in parameters:
                p.grad.detach().mul_(clip_coef)
        return total_norm
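

# Illustrative training-step usage of AdaptiveGradClipper (a sketch; `model`,
# `loss`, and `optimizer` are placeholders):
#
#     clipper = AdaptiveGradClipper(max_norm=1.0, clip_percentile=95)
#     loss.backward()
#     total_norm = clipper(model.parameters())  # clips gradients in place
#     optimizer.step()
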
def create_grad_clipper(grad_clip_config: Dict[str, Any]):
"""Create gradient clipper based on configuration."""
if grad_clip_config is None:
return None
name = grad_clip_config.get('name', 'norm')
args = grad_clip_config.get('args', {})
    if name.lower() == 'norm':
        # Fixed-threshold clipping via PyTorch's built-in utility; it returns
        # the total norm, matching AdaptiveGradClipper's interface.
        max_norm = args.get('max_norm', 1.0)
        return lambda params: torch.nn.utils.clip_grad_norm_(params, max_norm)
elif name.lower() == 'adaptivegradclipper':
return AdaptiveGradClipper(**args)
else:
raise ValueError(f"Unknown gradient clipper: {name}")
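

# Example create_grad_clipper configurations (illustrative):
#
#     clip_fn = create_grad_clipper({"name": "norm", "args": {"max_norm": 1.0}})
#     clip_fn = create_grad_clipper({
#         "name": "AdaptiveGradClipper",
#         "args": {"max_norm": 1.0, "clip_percentile": 95},
#     })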