File size: 1,057 Bytes
eca55dc
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
import math


class LinearWarmupCosineDecay:
    """Learning-rate multiplier with linear warmup followed by cosine decay.

    Instances are callable: given a step index they return a factor relative
    to the base learning rate (1.0 means "use the base rate unchanged").
    NOTE(review): the original comments suggest this is meant to be passed as
    the lambda of a LambdaLR-style scheduler; confirm against the caller.

    Shape of the schedule:
      * ``step < warmup_steps``: rises linearly from 0.0 towards 1.0.
      * ``warmup_steps <= step <= total_steps``: half-cosine from 1.0 down to
        ``final_lr_ratio``.
      * ``step > total_steps``: stays at ``final_lr_ratio``.

    Args:
        warmup_steps: Number of steps spent in the linear warmup phase.
        total_steps: Step index at which the cosine decay reaches its floor.
            If this is not greater than ``warmup_steps`` the decay length
            falls back to a single step.
        final_lr_ratio: Multiplier reached at the end of the decay.
    """

    def __init__(
        self,
        warmup_steps: int,
        total_steps: int,
        final_lr_ratio: float,
    ):
        self.warmup_steps = warmup_steps
        self.total_steps = total_steps
        self.final_lr_ratio = final_lr_ratio

    def __call__(self, current_step: int) -> float:
        """Return the learning-rate multiplier for ``current_step``.

        Args:
            current_step: Step index. Negative values are treated as the very
                start of warmup and yield 0.0.

        Returns:
            The multiplier for this step (never negative during warmup).
        """
        if current_step < self.warmup_steps:
            # Linear warmup. max(1, ...) guards the division when
            # warmup_steps <= 0; the outer max keeps the factor non-negative
            # for out-of-range (negative) steps, mirroring the clipping that
            # the decay phase applies to its progress.
            warmup_factor = float(current_step) / float(max(1, self.warmup_steps))
            return max(0.0, warmup_factor)

        # Cosine decay. The denominator is at least 1 so a degenerate
        # configuration (total_steps <= warmup_steps) cannot divide by zero.
        decay_steps = max(1, self.total_steps - self.warmup_steps)
        progress = float(current_step - self.warmup_steps) / float(decay_steps)
        progress = min(1.0, max(0.0, progress))  # Clip to [0, 1]

        # Cosine decay from 1.0 to final_lr_ratio:
        #   final + 0.5 * (initial - final) * (1 + cos(pi * progress))
        # with initial == 1.0 because the result is a relative multiplier.
        cosine_part = 0.5 * (1.0 + math.cos(math.pi * progress))
        return self.final_lr_ratio + (1.0 - self.final_lr_ratio) * cosine_part