| import math | |
| class LinearWarmupCosineDecay: | |
| def __init__( | |
| self, | |
| warmup_steps: int, | |
| total_steps: int, | |
| final_lr_ratio: float, | |
| ): | |
| self.warmup_steps = warmup_steps | |
| self.total_steps = total_steps | |
| self.final_lr_ratio = final_lr_ratio | |
| def __call__(self, current_step: int) -> float: | |
| if current_step < self.warmup_steps: | |
| # Linear warmup | |
| return float(current_step) / float(max(1, self.warmup_steps)) | |
| # Cosine decay | |
| progress = float(current_step - self.warmup_steps) / float( | |
| max(1, self.total_steps - self.warmup_steps) | |
| ) | |
| progress = min(1.0, max(0.0, progress)) # Clip to [0, 1] | |
| # Cosine decay from 1.0 to final_lr_ratio | |
| # formula: final + 0.5 * (initial - final) * (1 + cos(pi * progress)) | |
| # scaled relative to initial lr (which is 1.0 in lambda) | |
| cosine_part = 0.5 * (1.0 + math.cos(math.pi * progress)) | |
| return self.final_lr_ratio + (1.0 - self.final_lr_ratio) * cosine_part | |