"""
Optimizer utilities for PyTorch Lightning systems.
"""

from typing import Any, Dict, Optional

import torch
import torch.nn as nn


def create_muon_optimizer(models: Dict[str, nn.Module], optimizer_args: Dict[str, Any]):
    """
    Create a Muon optimizer with different learning rates for different parameters.

    Args:
        models: Dictionary of models
        optimizer_args: Optimizer configuration with muon_lr, muon_weight_decay,
            other_lr, other_weight_decay

    Returns:
        Configured optimizer
    """
    muon_lr = optimizer_args.get('muon_lr', 0.02)
    muon_weight_decay = optimizer_args.get('muon_weight_decay', 0.01)
    other_lr = optimizer_args.get('other_lr', 1e-4)
    other_weight_decay = optimizer_args.get('other_weight_decay', 0.01)

    # Split parameters: large 2-D (or higher) weight matrices go to the Muon
    # group; everything else (biases, norms, small weights) goes to the
    # second group with its own learning rate and weight decay.
    muon_params = []
    other_params = []

    for model in models.values():
        for param_name, param in model.named_parameters():
            if not param.requires_grad:
                continue

            if ('weight' in param_name and
                    param.dim() >= 2 and
                    param.numel() >= 1024):
                muon_params.append(param)
            else:
                other_params.append(param)

    try:
        from muon import Muon

        param_groups = []

        if muon_params:
            param_groups.append({
                'params': muon_params,
                'lr': muon_lr,
                'weight_decay': muon_weight_decay,
                'momentum': 0.95,
            })

        if other_params:
            param_groups.append({
                'params': other_params,
                'lr': other_lr,
                'weight_decay': other_weight_decay,
            })

        return Muon(param_groups)

    except ImportError:
        # The muon package is optional; fall back to AdamW on all parameters.
        print("Warning: Muon optimizer not available. Falling back to AdamW.")
        all_params = muon_params + other_params
        return torch.optim.AdamW(all_params, lr=other_lr, weight_decay=other_weight_decay)


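# Example usage (illustrative sketch; 'encoder' is a placeholder model name and
# the values shown simply repeat the defaults read above):
#
#   optimizer = create_muon_optimizer(
#       {'encoder': encoder},
#       {'muon_lr': 0.02, 'muon_weight_decay': 0.01,
#        'other_lr': 1e-4, 'other_weight_decay': 0.01},
#   )

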
def create_optimizer(models: Dict[str, nn.Module], optimizer_config: Dict[str, Any]):
    """
    Create an optimizer based on configuration.

    Args:
        models: Dictionary of models
        optimizer_config: Optimizer configuration

    Returns:
        Configured optimizer
    """
    optimizer_name = optimizer_config.get('name', 'Adam').lower()
    optimizer_args = optimizer_config.get('args', {})

    # Collect all trainable parameters across the given models.
    params = []
    for model in models.values():
        params.extend(p for p in model.parameters() if p.requires_grad)

    if optimizer_name == 'adam':
        return torch.optim.Adam(params, **optimizer_args)
    elif optimizer_name == 'adamw':
        return torch.optim.AdamW(params, **optimizer_args)
    elif optimizer_name == 'sgd':
        return torch.optim.SGD(params, **optimizer_args)
    elif optimizer_name == 'muon':
        return create_muon_optimizer(models, optimizer_args)
    else:
        raise ValueError(f"Unknown optimizer: {optimizer_name}")


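# Example usage (illustrative; the config keys mirror those read above and
# 'my_model' is a placeholder):
#
#   optimizer_config = {
#       'name': 'AdamW',
#       'args': {'lr': 1e-4, 'weight_decay': 0.01},
#   }
#   optimizer = create_optimizer({'model': my_model}, optimizer_config)

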
def create_lr_scheduler(optimizer, lr_scheduler_config: Dict[str, Any]):
    """
    Create a learning rate scheduler based on configuration.

    Args:
        optimizer: The optimizer to schedule
        lr_scheduler_config: Scheduler configuration

    Returns:
        Configured scheduler
    """
    scheduler_name = lr_scheduler_config.get('name', 'CosineAnnealingLR')
    scheduler_args = lr_scheduler_config.get('args', {})

    if scheduler_name == 'CosineAnnealingLR':
        return torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, **scheduler_args)
    elif scheduler_name == 'LinearLR':
        return torch.optim.lr_scheduler.LinearLR(optimizer, **scheduler_args)
    elif scheduler_name == 'ExponentialLR':
        return torch.optim.lr_scheduler.ExponentialLR(optimizer, **scheduler_args)
    elif scheduler_name == 'StepLR':
        return torch.optim.lr_scheduler.StepLR(optimizer, **scheduler_args)
    elif scheduler_name == 'MultiStepLR':
        return torch.optim.lr_scheduler.MultiStepLR(optimizer, **scheduler_args)
    elif scheduler_name == 'SequentialLR':
        # Build each child scheduler, then chain them; 'milestones' is expected
        # to be supplied through scheduler_args.
        schedulers = [
            create_single_scheduler(sched_config, optimizer)
            for sched_config in lr_scheduler_config['schedulers']
        ]
        return torch.optim.lr_scheduler.SequentialLR(
            optimizer, schedulers, **scheduler_args
        )
    else:
        raise ValueError(f"Unknown scheduler: {scheduler_name}")


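# Example usage (illustrative sketch): a linear warmup followed by cosine
# annealing via SequentialLR. The epoch counts are placeholders; 'milestones'
# marks where control passes from the first child scheduler to the second.
#
#   lr_scheduler_config = {
#       'name': 'SequentialLR',
#       'schedulers': [
#           {'name': 'LinearLR',
#            'args': {'start_factor': 0.1, 'total_iters': 5}},
#           {'name': 'CosineAnnealingLR',
#            'args': {'T_max': 95}},
#       ],
#       'args': {'milestones': [5]},
#   }
#   scheduler = create_lr_scheduler(optimizer, lr_scheduler_config)

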
def create_single_scheduler(scheduler_config: Dict[str, Any], optimizer):
    """Create a single child scheduler for SequentialLR."""
    name = scheduler_config['name']
    args = scheduler_config.get('args', {})

    if name == 'LinearLR':
        return torch.optim.lr_scheduler.LinearLR(optimizer, **args)
    elif name == 'CosineAnnealingLR':
        return torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, **args)
    elif name == 'ExponentialLR':
        return torch.optim.lr_scheduler.ExponentialLR(optimizer, **args)
    elif name == 'StepLR':
        return torch.optim.lr_scheduler.StepLR(optimizer, **args)
    elif name == 'MultiStepLR':
        return torch.optim.lr_scheduler.MultiStepLR(optimizer, **args)
    else:
        raise ValueError(f"Unknown scheduler: {name}")


class AdaptiveGradClipper:
    """
    Adaptive gradient clipping based on a percentile of recent gradient norms.
    """

    def __init__(self, max_norm: float = 1.0, clip_percentile: float = 95):
        self.max_norm = max_norm
        self.clip_percentile = clip_percentile
        self.grad_norms_history = []
        self.history_size = 1000

    def __call__(self, parameters):
        """Apply adaptive gradient clipping and return the total gradient norm."""
        if isinstance(parameters, torch.Tensor):
            parameters = [parameters]
        parameters = [p for p in parameters if p.grad is not None]

        if not parameters:
            return torch.tensor(0.0)

        # Total norm over all parameters (L2 norm of the per-parameter norms).
        total_norm = torch.norm(
            torch.stack([torch.norm(p.grad.detach()) for p in parameters])
        )

        # Keep a bounded history of recent gradient norms.
        self.grad_norms_history.append(total_norm.item())
        if len(self.grad_norms_history) > self.history_size:
            self.grad_norms_history.pop(0)

        # Once enough history has accumulated, clip at the configured
        # percentile of the history, capped by max_norm.
        if len(self.grad_norms_history) >= 10:
            threshold = torch.quantile(
                torch.tensor(self.grad_norms_history),
                self.clip_percentile / 100.0
            )
            clip_value = min(self.max_norm, threshold.item())
        else:
            clip_value = self.max_norm

        # Rescale gradients in place if the total norm exceeds the clip value.
        if total_norm > clip_value:
            clip_coef = clip_value / (total_norm + 1e-6)
            for p in parameters:
                p.grad.detach().mul_(clip_coef)

        return total_norm


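# Example usage (illustrative sketch; 'model', 'loss', and 'optimizer' are
# placeholders): call the clipper between backward() and optimizer.step(),
# e.g. in a training loop or a Lightning hook that does manual optimization.
#
#   clipper = AdaptiveGradClipper(max_norm=1.0, clip_percentile=95)
#   loss.backward()
#   grad_norm = clipper(model.parameters())
#   optimizer.step()

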
def create_grad_clipper(grad_clip_config: Optional[Dict[str, Any]]):
    """Create a gradient clipper based on configuration, or return None if disabled."""
    if grad_clip_config is None:
        return None

    name = grad_clip_config.get('name', 'norm')
    args = grad_clip_config.get('args', {})

    if name.lower() == 'norm':
        max_norm = args.get('max_norm', 1.0)
        return lambda params: torch.nn.utils.clip_grad_norm_(params, max_norm)
    elif name.lower() == 'adaptivegradclipper':
        return AdaptiveGradClipper(**args)
    else:
        raise ValueError(f"Unknown gradient clipper: {name}")


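# Example configurations (illustrative; each matches one branch above):
#
#   clipper = create_grad_clipper({'name': 'norm', 'args': {'max_norm': 1.0}})
#   clipper = create_grad_clipper(
#       {'name': 'AdaptiveGradClipper',
#        'args': {'max_norm': 1.0, 'clip_percentile': 95}}
#   )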