|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
from typing import Optional

import torch

from cosmos_predict1.utils.lazy_config import LazyCall as L
|
|
|
|
|
|
|
|
class LambdaLinearWarmupScheduler:
    """
    A learning-rate multiplier implementing linear warm-up and optional linear cool-down.

    Instances are callables that map a step index to a float factor in ``[0, 1]``.
    They are intended to be passed as ``lr_lambda`` to
    ``torch.optim.lr_scheduler.LambdaLR``, which multiplies each base learning
    rate by the returned factor.

    The schedule has up to four regions:

    0. Before ``warmup_offset``: the factor is 0.
    1. Warm-up (``warmup_offset <= step < warmup_offset + warmup_steps``): the
       factor rises linearly from 0 towards 1.
    2. Constant: the factor is 1.
    3. Cool-down (``max_iter - cooldown_steps <= step < max_iter``, only when
       ``cooldown_steps`` is given): the factor falls linearly from 1 towards 0.
       The factor is 0 from ``max_iter`` onwards.

    If the warm-up and cool-down windows overlap, the warm-up formula takes
    precedence for the overlapping steps.

    Args:
        warmup_steps (int): Number of steps for the warm-up phase.
        warmup_offset (int): Step at which warm-up begins; earlier steps get factor 0.
        max_iter (int, optional): Total number of iterations. Required if cooldown_steps is provided.
        cooldown_steps (int, optional): Number of steps for the cool-down phase. Must be non-negative.

    Raises:
        ValueError: If cooldown_steps is provided without max_iter, or if
            cooldown_steps is negative. ``__call__`` also raises it when a step
            fits no phase (e.g. a NaN step, for which every comparison is False).
    """

    def __init__(
        self,
        warmup_steps: int,
        warmup_offset: int = 0,
        max_iter: Optional[int] = None,
        cooldown_steps: Optional[int] = None,
    ):
        # NOTE: LambdaLR.state_dict() stores a callable lr_lambda's __dict__, so the
        # set of instance attributes below is part of the checkpoint format. Keep it
        # stable (and do not switch this class to __slots__).
        self.warmup_steps = warmup_steps
        self.warmup_offset = warmup_offset
        self.max_iter = max_iter
        self.cooldown_steps = cooldown_steps

        if cooldown_steps is None:
            self.cooldown_start = None
            return

        if max_iter is None:
            raise ValueError("max_iter must be specified when cooldown_steps is provided")
        if cooldown_steps < 0:
            # A negative window would place cooldown_start after max_iter and keep
            # the factor at 1.0 past the end of training.
            raise ValueError(f"cooldown_steps must be non-negative, got {cooldown_steps}")
        self.cooldown_start = max_iter - cooldown_steps

    def __call__(self, step: int) -> float:
        """Return the learning-rate multiplier for ``step``."""
        if step < self.warmup_offset:
            return 0.0

        if step < self.warmup_steps + self.warmup_offset:
            # max(1, ...) guards against division by zero when warmup_steps == 0.
            return float(step - self.warmup_offset) / float(max(1, self.warmup_steps))

        if self.cooldown_steps is None:
            return 1.0

        if step < self.cooldown_start:
            return 1.0

        if self.cooldown_start <= step < self.max_iter:
            # Unreachable when cooldown_steps == 0 (the window is empty), so the
            # division below never divides by zero.
            cooldown_progress = (step - self.cooldown_start) / self.cooldown_steps
            return 1.0 - cooldown_progress

        if step >= self.max_iter:
            return 0.0

        # Only reachable when every comparison above is False, e.g. a NaN step.
        raise ValueError(f"Invalid step {step}")
|
|
|
|
|
|
|
|
# Default lazy LR-scheduler config: a torch LambdaLR whose multiplier comes from
# LambdaLinearWarmupScheduler with a 5000-step linear warm-up and no cool-down.
# optimizer is left as None here; presumably it is supplied when the lazy config
# is instantiated by the trainer — confirm against the caller.
LambdaLinearLR = L(torch.optim.lr_scheduler.LambdaLR)(
    optimizer=None,
    lr_lambda=L(LambdaLinearWarmupScheduler)(
        warmup_steps=5000,
    ),
)
|
|
|