File size: 3,323 Bytes
538668e
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
import torch
import numpy as np

def create_optimizer(model, config):
    """Create an optimizer with layer-wise learning-rate decay (LLRD).

    Parameters are split into groups, in this order:

    1. the classification head: every parameter whose name contains
       ``"head"`` and that is not owned by ``model.blocks``. It uses
       ``head_lr``, which defaults to ``learning_rate``;
    2. one group per entry of ``model.blocks``. Block ``i`` uses
       ``learning_rate * layer_decay ** (len(model.blocks) - i)``, so later
       blocks (closer to the head) get larger learning rates;
    3. every remaining parameter (``patch_embed``, ``encoder_norm``,
       positional embeddings, ...) at the lowest rate,
       ``learning_rate * layer_decay ** (len(model.blocks) + 1)``.

    Every parameter of ``model`` lands in exactly one group. None is silently
    excluded from optimization, and torch never sees a tensor twice. Empty
    groups are omitted.

    Args:
        model: ``torch.nn.Module`` exposing an iterable ``blocks`` attribute.
            Blocks are assumed to be ordered input -> output (e.g. a
            HieraClassifier) -- confirm for other models.
        config: dict whose ``'training'`` entry provides ``learning_rate``,
            ``weight_decay`` and ``optimizer`` (``'adamw'`` or ``'sgd'``,
            case-insensitive). Optional keys are ``layer_decay`` (default
            0.8), ``head_lr``, ``betas`` (AdamW, default ``(0.9, 0.999)``) and
            ``momentum`` (SGD, default 0.9).

    Returns:
        The configured ``torch.optim.AdamW`` or ``torch.optim.SGD`` instance.

    Raises:
        ValueError: if ``optimizer`` names an unsupported optimizer.
    """
    train_config = config['training']
    base_lr = train_config['learning_rate']
    weight_decay = train_config['weight_decay']
    layer_decay = train_config.get('layer_decay', 0.8)

    # +1 accounts for the patch embedding, treated as "layer 0" below the blocks.
    num_layers = len(model.blocks) + 1

    parameter_groups = []
    assigned_ids = set()

    def add_group(params, lr):
        # Keep only tensors not claimed by an earlier group. torch.optim
        # rejects a parameter that appears in more than one group.
        fresh = []
        for p in params:
            if id(p) not in assigned_ids:
                assigned_ids.add(id(p))
                fresh.append(p)
        if fresh:
            parameter_groups.append({
                "params": fresh,
                "lr": lr,
                "weight_decay": weight_decay,
            })

    # A parameter owned by a block stays with that block even if its name
    # happens to contain "head" (e.g. an attention sub-module).
    block_param_ids = {id(p) for block in model.blocks for p in block.parameters()}

    # 1. Classification head: the largest learning rate.
    head_lr = train_config.get('head_lr', base_lr)
    add_group(
        [p for n, p in model.named_parameters()
         if "head" in n and id(p) not in block_param_ids],
        head_lr,
    )

    # 2. Encoder blocks with layer-wise decay. The last block gets
    #    layer_decay ** 1 and the first gets layer_decay ** (num_layers - 1).
    for i, block in enumerate(model.blocks):
        add_group(block.parameters(), base_lr * layer_decay ** (num_layers - i - 1))

    # 3. Everything not yet assigned (patch_embed, encoder_norm, positional
    #    embeddings, ...) gets the lowest rate, layer_decay ** num_layers.
    # NOTE(review): encoder_norm stays at the lowest LR as before. If it is the
    # final norm feeding the head, a higher LR may be intended -- confirm.
    add_group(model.parameters(), base_lr * layer_decay ** num_layers)

    optimizer_name = train_config['optimizer'].lower()
    if optimizer_name == 'adamw':
        return torch.optim.AdamW(
            parameter_groups,
            betas=tuple(train_config.get('betas', (0.9, 0.999))),
            weight_decay=weight_decay,
        )
    if optimizer_name == 'sgd':
        return torch.optim.SGD(
            parameter_groups,
            momentum=train_config.get('momentum', 0.9),
            weight_decay=weight_decay,
        )
    raise ValueError(f"Unsupported optimizer: {train_config['optimizer']}")


def create_lr_scheduler(optimizer, config, steps_per_epoch):
    """Create a per-step learning-rate scheduler.

    Only ``'cosine'`` (case-insensitive) is supported. The schedule is:

    - a linear warmup from 0 over ``warmup_epochs * steps_per_epoch`` steps;
    - cosine annealing over the remaining steps up to
      ``epochs * steps_per_epoch``.

    The cosine factor is floored at ``min_lr / learning_rate``. Once stepping
    runs past the last scheduled step, the factor is held at its final value.

    The factor multiplies each param group's initial LR (``LambdaLR``), so
    per-group LRs keep their relative ratios.

    Args:
        optimizer: the ``torch.optim.Optimizer`` to schedule.
        config: dict whose ``'training'`` entry provides ``epochs``,
            ``warmup_epochs``, ``lr_scheduler``, ``learning_rate`` and
            ``min_lr``.
        steps_per_epoch: scheduler steps per epoch. The scheduler is
            presumably stepped once per optimizer step -- confirm with caller.

    Returns:
        A ``torch.optim.lr_scheduler.LambdaLR``.

    Raises:
        ValueError: if ``lr_scheduler`` names an unsupported scheduler.
    """
    train_config = config['training']
    total_steps = train_config['epochs'] * steps_per_epoch
    warmup_steps = train_config['warmup_epochs'] * steps_per_epoch

    if train_config['lr_scheduler'].lower() != 'cosine':
        raise ValueError(f"Unsupported scheduler: {train_config['lr_scheduler']}")

    # Resolve config values once so the schedule is fixed at construction time.
    min_ratio = train_config['min_lr'] / train_config['learning_rate']
    decay_steps = max(1, total_steps - warmup_steps)

    def lr_lambda(current_step):
        if current_step < warmup_steps:
            # Linear warmup from 0 to 1.
            return float(current_step) / float(max(1, warmup_steps))
        # Clamp progress to 1. Otherwise steps beyond total_steps would follow
        # the cosine back up, since cos(pi * p) rises again for p > 1.
        progress = min(1.0, float(current_step - warmup_steps) / float(decay_steps))
        return max(min_ratio, 0.5 * (1.0 + float(np.cos(np.pi * progress))))

    return torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda)