|
|
import torch |
|
|
import numpy as np |
|
|
|
|
|
def create_optimizer(model, config):
    """Create an optimizer with layer-wise learning-rate decay.

    Builds one parameter group per logical "layer":
      * head params (name contains "head")  -> head_lr (defaults to base lr)
      * transformer block i                 -> base_lr * layer_decay ** (num_layers - i - 1)
      * patch_embed / encoder_norm params   -> base_lr * layer_decay ** num_layers
      * any remaining params                -> base_lr * layer_decay ** num_layers

    Every parameter is placed in exactly one group: torch optimizers raise a
    ValueError if a parameter appears in more than one group, and parameters
    left out of all groups would silently receive no updates.

    Args:
        model: module expected to expose ``model.blocks`` (an iterable of
            transformer blocks) and the standard ``named_parameters()`` API.
        config: dict with a ``'training'`` section containing
            ``learning_rate``, ``weight_decay``, ``optimizer`` ('adamw' or
            'sgd'), and optionally ``layer_decay``, ``head_lr``, ``betas``,
            ``momentum``.

    Returns:
        A configured ``torch.optim.Optimizer`` instance.

    Raises:
        ValueError: if ``training.optimizer`` is neither 'adamw' nor 'sgd'.
    """
    train_config = config['training']
    base_lr = train_config['learning_rate']
    weight_decay = train_config['weight_decay']
    # Per-layer lr multiplier: earlier layers train with smaller lrs.
    layer_decay = train_config.get('layer_decay', 0.8)
    # +1 accounts for the stem (patch_embed / encoder_norm) as its own layer.
    num_layers = len(model.blocks) + 1

    parameter_groups = []
    assigned = set()  # ids of parameters already placed in a group

    def _claim(params):
        """Return params not yet placed in a group, marking them as taken."""
        fresh = [p for p in params if id(p) not in assigned]
        assigned.update(id(p) for p in fresh)
        return fresh

    # Classification head, optionally trained at its own learning rate.
    head_lr = train_config.get('head_lr', base_lr)
    head_params = _claim(p for n, p in model.named_parameters() if "head" in n)
    if head_params:
        parameter_groups.append({
            "params": head_params,
            "lr": head_lr,
            "weight_decay": weight_decay,
        })

    # Transformer blocks: deeper blocks (larger i) get larger learning rates.
    for i, block in enumerate(model.blocks):
        scale = layer_decay ** (num_layers - i - 1)
        # _claim() keeps any block param already grabbed by the head group
        # (e.g. a name containing "head") from appearing twice — duplicates
        # make torch.optim raise at construction time.
        block_params = _claim(block.parameters())
        if block_params:
            parameter_groups.append({
                "params": block_params,
                "lr": base_lr * scale,
                "weight_decay": weight_decay,
            })

    # Stem parameters (patch embedding, encoder norm): deepest decay.
    earliest_params = _claim(
        p for n, p in model.named_parameters()
        if "patch_embed" in n or "encoder_norm" in n
    )
    if earliest_params:
        parameter_groups.append({
            "params": earliest_params,
            "lr": base_lr * (layer_decay ** num_layers),
            "weight_decay": weight_decay,
        })

    # BUG FIX: parameters matching none of the selectors above (e.g.
    # cls_token, pos_embed) were previously dropped from every group and
    # never optimized. Give them the same (deepest) decay as the stem.
    leftover_params = _claim(model.parameters())
    if leftover_params:
        parameter_groups.append({
            "params": leftover_params,
            "lr": base_lr * (layer_decay ** num_layers),
            "weight_decay": weight_decay,
        })

    optimizer_name = train_config['optimizer'].lower()
    if optimizer_name == 'adamw':
        optimizer = torch.optim.AdamW(
            parameter_groups,
            betas=tuple(train_config['betas']),
            # Constructor-level default; every group above also sets its own.
            weight_decay=train_config['weight_decay']
        )
    elif optimizer_name == 'sgd':
        optimizer = torch.optim.SGD(
            parameter_groups,
            momentum=train_config.get('momentum', 0.9),
            weight_decay=train_config['weight_decay']
        )
    else:
        raise ValueError(f"Unsupported optimizer: {train_config['optimizer']}")

    return optimizer
|
|
|
|
|
|
|
|
def create_lr_scheduler(optimizer, config, steps_per_epoch):
    """Create learning rate scheduler.

    Supports a 'cosine' schedule with linear warmup: the lr multiplier ramps
    0 -> 1 over the warmup steps, then follows a cosine curve down, floored
    at min_lr / learning_rate.

    Args:
        optimizer: the optimizer whose learning rates will be scheduled.
        config: dict with a ``'training'`` section containing ``epochs``,
            ``warmup_epochs``, ``lr_scheduler``, ``min_lr``, ``learning_rate``.
        steps_per_epoch: number of scheduler steps taken per training epoch.

    Returns:
        A ``torch.optim.lr_scheduler.LambdaLR`` instance.

    Raises:
        ValueError: if ``training.lr_scheduler`` is not 'cosine'.
    """
    train_config = config['training']
    total_steps = train_config['epochs'] * steps_per_epoch
    warmup_steps = train_config['warmup_epochs'] * steps_per_epoch

    if train_config['lr_scheduler'].lower() != 'cosine':
        raise ValueError(f"Unsupported scheduler: {train_config['lr_scheduler']}")

    # Floor for the lr multiplier, expressed relative to the base lr.
    min_ratio = train_config['min_lr'] / train_config['learning_rate']

    def cosine_with_warmup(step):
        # Linear warmup phase: ramp the multiplier from 0 up to 1.
        if step < warmup_steps:
            return step / max(1, warmup_steps)
        # Cosine decay phase; max(1, ...) guards a zero-length decay window.
        progress = (step - warmup_steps) / max(1, total_steps - warmup_steps)
        return max(min_ratio, 0.5 * (1.0 + np.cos(np.pi * progress)))

    return torch.optim.lr_scheduler.LambdaLR(optimizer, cosine_with_warmup)