File size: 1,057 Bytes
import math
class LinearWarmupCosineDecay:
    """Learning-rate multiplier with a linear warmup followed by cosine decay.

    Calling an instance with a step index returns a multiplier relative to
    the initial learning rate (the initial rate corresponds to ``1.0``):

    * for ``current_step < warmup_steps`` the multiplier rises linearly
      from ``0.0`` towards ``1.0``;
    * afterwards it follows a half cosine from ``1.0`` down to
      ``final_lr_ratio``, reaching that value at ``total_steps`` and
      staying there for any later step.

    NOTE(review): the callable shape (step -> multiplier) looks intended for
    a lambda-style scheduler such as ``torch.optim.lr_scheduler.LambdaLR``;
    confirm against the caller.
    """

    def __init__(
        self,
        warmup_steps: int,
        total_steps: int,
        final_lr_ratio: float,
    ) -> None:
        """Store the schedule parameters.

        Args:
            warmup_steps: Number of steps spent in the linear warmup phase.
            total_steps: Step index at which the cosine decay reaches its
                final value.
            final_lr_ratio: Multiplier returned once the decay has finished.
        """
        self.warmup_steps = warmup_steps
        self.total_steps = total_steps
        self.final_lr_ratio = final_lr_ratio

    def __call__(self, current_step: int) -> float:
        """Return the learning-rate multiplier for ``current_step``.

        Args:
            current_step: Step index; negative values are treated as ``0``.

        Returns:
            The multiplier to apply to the initial learning rate.
        """
        # A negative step would otherwise yield a negative multiplier in the
        # warmup branch; treat it as the very first step instead.
        step = max(0, current_step)

        if step < self.warmup_steps:
            # Linear warmup. max(1, ...) mirrors the decay branch and guards
            # the division.
            return float(step) / float(max(1, self.warmup_steps))

        # Fraction of the decay phase completed. max(1, ...) avoids a zero
        # (or negative) denominator when total_steps <= warmup_steps.
        decay_steps = max(1, self.total_steps - self.warmup_steps)
        progress = float(step - self.warmup_steps) / float(decay_steps)
        progress = min(1.0, max(0.0, progress))  # Clip to [0, 1]

        # Cosine decay from 1.0 to final_lr_ratio:
        #   final + 0.5 * (initial - final) * (1 + cos(pi * progress))
        # with initial == 1.0 because the result is a relative multiplier.
        cosine_part = 0.5 * (1.0 + math.cos(math.pi * progress))
        return self.final_lr_ratio + (1.0 - self.final_lr_ratio) * cosine_part
|