"""
Optimizer utilities for PyTorch Lightning systems.
"""

import torch
import torch.nn as nn
from typing import Dict, Any


def create_muon_optimizer(models: Dict[str, nn.Module], optimizer_args: Dict[str, Any]):
    """
    Create Muon optimizer with different learning rates for different parameters.
    
    Args:
        models: Dictionary of models
        optimizer_args: Optimizer configuration with muon_lr, muon_weight_decay, 
                       other_lr, other_weight_decay
    
    Returns:
        Configured optimizer
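
    Example (illustrative sketch; ``my_model`` is any ``nn.Module`` and the
    values shown simply mirror the defaults above)::

        optimizer = create_muon_optimizer(
            {'net': my_model},
            {'muon_lr': 0.02, 'muon_weight_decay': 0.01,
             'other_lr': 1e-4, 'other_weight_decay': 0.01},
        )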
    """
    muon_lr = optimizer_args.get('muon_lr', 0.02)
    muon_weight_decay = optimizer_args.get('muon_weight_decay', 0.01)
    other_lr = optimizer_args.get('other_lr', 1e-4)
    other_weight_decay = optimizer_args.get('other_weight_decay', 0.01)
    
    # Separate parameters for Muon and other optimizers
    muon_params = []
    other_params = []
    
    for name, model in models.items():
        for param_name, param in model.named_parameters():
            if not param.requires_grad:
                continue
                
            # Define which parameters should use Muon optimizer
            # This is a heuristic - adjust based on your model architecture
            if ('weight' in param_name and 
                param.dim() >= 2 and 
                param.numel() >= 1024):  # Large weight matrices
                muon_params.append(param)
            else:
                other_params.append(param)
    
    # Try to import and use Muon optimizer
    try:
        from muon import Muon
        
        # Create parameter groups
        param_groups = []
        
        if muon_params:
            param_groups.append({
                'params': muon_params,
                'lr': muon_lr,
                'weight_decay': muon_weight_decay,
                'momentum': 0.95,  # Muon-specific parameter
            })
        
        if other_params:
            param_groups.append({
                'params': other_params,
                'lr': other_lr,
                'weight_decay': other_weight_decay,
            })
        
        return Muon(param_groups)
        
    except ImportError:
        print("Warning: Muon optimizer not available. Falling back to AdamW.")
        # Fallback to AdamW with combined parameters
        all_params = muon_params + other_params
        return torch.optim.AdamW(all_params, lr=other_lr, weight_decay=other_weight_decay)


def create_optimizer(models: Dict[str, nn.Module], optimizer_config: Dict[str, Any]):
    """
    Create optimizer based on configuration.
    
    Args:
        models: Dictionary of models
        optimizer_config: Optimizer configuration
    
    Returns:
        Configured optimizer
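
    Example (illustrative; the config keys follow this function's conventions,
    the values themselves are placeholders)::

        optimizer = create_optimizer(
            {'net': model},
            {'name': 'AdamW', 'args': {'lr': 1e-4, 'weight_decay': 0.01}},
        )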
    """
    optimizer_name = optimizer_config.get('name', 'Adam').lower()
    optimizer_args = optimizer_config.get('args', {})
    
    # Get all parameters
    params = []
    for model in models.values():
        params.extend([p for p in model.parameters() if p.requires_grad])
    
    if optimizer_name == 'adam':
        return torch.optim.Adam(params, **optimizer_args)
    elif optimizer_name == 'adamw':
        return torch.optim.AdamW(params, **optimizer_args)
    elif optimizer_name == 'sgd':
        return torch.optim.SGD(params, **optimizer_args)
    elif optimizer_name == 'muon':
        return create_muon_optimizer(models, optimizer_args)
    else:
        raise ValueError(f"Unknown optimizer: {optimizer_name}")


def create_lr_scheduler(optimizer, lr_scheduler_config: Dict[str, Any]):
    """
    Create learning rate scheduler based on configuration.
    
    Args:
        optimizer: The optimizer to schedule
        lr_scheduler_config: Scheduler configuration
    
    Returns:
        Configured scheduler
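
    Example (illustrative warmup-then-cosine schedule via ``SequentialLR``;
    step counts and factors are placeholders)::

        scheduler = create_lr_scheduler(optimizer, {
            'name': 'SequentialLR',
            'args': {'milestones': [1000]},
            'schedulers': [
                {'name': 'LinearLR',
                 'args': {'start_factor': 0.01, 'total_iters': 1000}},
                {'name': 'CosineAnnealingLR',
                 'args': {'T_max': 100000}},
            ],
        })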
    """
    scheduler_name = lr_scheduler_config.get('name', 'CosineAnnealingLR')
    scheduler_args = lr_scheduler_config.get('args', {})
    
    if scheduler_name == 'CosineAnnealingLR':
        return torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, **scheduler_args)
    elif scheduler_name == 'LinearLR':
        return torch.optim.lr_scheduler.LinearLR(optimizer, **scheduler_args)
    elif scheduler_name == 'ExponentialLR':
        return torch.optim.lr_scheduler.ExponentialLR(optimizer, **scheduler_args)
    elif scheduler_name == 'StepLR':
        return torch.optim.lr_scheduler.StepLR(optimizer, **scheduler_args)
    elif scheduler_name == 'MultiStepLR':
        return torch.optim.lr_scheduler.MultiStepLR(optimizer, **scheduler_args)
    elif scheduler_name == 'SequentialLR':
        # Handle sequential scheduler
        schedulers = []
        for sched_config in lr_scheduler_config['schedulers']:
            sched = create_single_scheduler(sched_config, optimizer)
            schedulers.append(sched)
        return torch.optim.lr_scheduler.SequentialLR(
            optimizer, schedulers, **scheduler_args
        )
    else:
        raise ValueError(f"Unknown scheduler: {scheduler_name}")


def create_single_scheduler(scheduler_config: Dict[str, Any], optimizer):
    """Create a single scheduler for SequentialLR."""
    name = scheduler_config['name']
    args = scheduler_config['args']
    
    if name == 'LinearLR':
        return torch.optim.lr_scheduler.LinearLR(optimizer, **args)
    elif name == 'CosineAnnealingLR':
        return torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, **args)
    elif name == 'ExponentialLR':
        return torch.optim.lr_scheduler.ExponentialLR(optimizer, **args)
    elif name == 'StepLR':
        return torch.optim.lr_scheduler.StepLR(optimizer, **args)
    elif name == 'MultiStepLR':
        return torch.optim.lr_scheduler.MultiStepLR(optimizer, **args)
    else:
        raise ValueError(f"Unknown scheduler: {name}")


class AdaptiveGradClipper:
    """
    Adaptive gradient clipping based on percentile of gradient norms.
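
    The clip threshold used on each call is the smaller of ``max_norm`` and
    the ``clip_percentile``-th percentile of the most recently observed
    gradient norms.

    Example (illustrative; ``model``, ``loss`` and ``optimizer`` are whatever
    the surrounding training step provides)::

        clipper = AdaptiveGradClipper(max_norm=1.0, clip_percentile=95)
        loss.backward()
        grad_norm = clipper(model.parameters())
        optimizer.step()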
    """
    
    def __init__(self, max_norm: float = 1.0, clip_percentile: float = 95):
        self.max_norm = max_norm
        self.clip_percentile = clip_percentile
        self.grad_norms_history = []
        self.history_size = 1000
    
    def __call__(self, parameters):
        """Apply adaptive gradient clipping."""
        if isinstance(parameters, torch.Tensor):
            parameters = [parameters]
        parameters = [p for p in parameters if p.grad is not None]
        
        if not parameters:
            return torch.tensor(0.0)
        
        # Calculate gradient norm
        total_norm = torch.norm(
            torch.stack([torch.norm(p.grad.detach()) for p in parameters])
        )
        
        # Update history
        self.grad_norms_history.append(total_norm.item())
        if len(self.grad_norms_history) > self.history_size:
            self.grad_norms_history.pop(0)
        
        # Calculate adaptive threshold
        if len(self.grad_norms_history) >= 10:
            threshold = torch.quantile(
                torch.tensor(self.grad_norms_history), 
                self.clip_percentile / 100.0
            )
            clip_value = min(self.max_norm, threshold.item())
        else:
            clip_value = self.max_norm
        
        # Apply clipping
        if total_norm > clip_value:
            clip_coef = clip_value / (total_norm + 1e-6)
            for p in parameters:
                p.grad.detach().mul_(clip_coef)
        
        return total_norm


def create_grad_clipper(grad_clip_config: Dict[str, Any]):
    """Create gradient clipper based on configuration."""
    if grad_clip_config is None:
        return None
    
    name = grad_clip_config.get('name', 'norm')
    args = grad_clip_config.get('args', {})
    
    if name.lower() == 'norm':
        max_norm = args.get('max_norm', 1.0)
        return lambda params: torch.nn.utils.clip_grad_norm_(params, max_norm)
    elif name.lower() == 'adaptivegradclipper':
        return AdaptiveGradClipper(**args)
    else:
        raise ValueError(f"Unknown gradient clipper: {name}")
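

if __name__ == "__main__":
    # Minimal usage sketch (illustrative only): the tiny model, config values,
    # and training loop below are placeholders, not part of any real pipeline.
    model = nn.Linear(16, 4)
    models = {'net': model}

    optimizer = create_optimizer(
        models, {'name': 'AdamW', 'args': {'lr': 1e-3, 'weight_decay': 0.01}}
    )
    scheduler = create_lr_scheduler(
        optimizer, {'name': 'CosineAnnealingLR', 'args': {'T_max': 100}}
    )
    clip_grads = create_grad_clipper(
        {'name': 'AdaptiveGradClipper', 'args': {'max_norm': 1.0}}
    )

    x, y = torch.randn(8, 16), torch.randn(8, 4)
    for step in range(3):
        optimizer.zero_grad()
        loss = nn.functional.mse_loss(model(x), y)
        loss.backward()
        grad_norm = clip_grads(model.parameters())
        optimizer.step()
        scheduler.step()
        print(f"step {step}: loss={loss.item():.4f}, grad_norm={grad_norm.item():.4f}")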