# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

import copy
import functools
import math
from typing import Any, Callable, Iterator

from torch.distributed.checkpoint.stateful import Stateful
from torch.optim.lr_scheduler import LambdaLR, LRScheduler

from torchtitan.components.optimizer import OptimizersContainer
from torchtitan.config_manager import JobConfig
from torchtitan.tools.logging import logger

__all__ = [
    "LRSchedulersContainer",
    "build_lr_schedulers",
]


class LRSchedulersContainer(Stateful):
| """Container for multiple learning rate schedulers. |
| |
| This class is used to wrap multiple LRSchedulers into a single object that can be |
| used to reduce the complexity of the training loop. This mimics the behavior of |
| ``torch.optim.lr_scheduler.LRScheduler``. The design concept is the same as |
| ``OptimizersContainer``. This class currently only supports ``LambdaLR``. |
| |
| **Note** |
| Users who want to customize the lr_scheduler behavior can inherit from this class and |
| extend the functionality as needed. The following methods must follow the same |
| signature as ``torch.optim.lr_scheduler.LRScheduler`` class: ``step()``, ``state_dict()``, |
| ``load_state_dict()``. |
| |
| **Limitations** |
| This class assumes all the lr schedulers are the same. There is no easy way to support |
| resharding for multiple different LRSchedulers because LRScheduler.state_dict() is not |
| resharding friendly. Therefore, the limitation is used to allow TorchTitan to support |
| lr scheduler resharding. |
| |
| Args: |
| optimizers (OptimizersContainer): The corresponding optimizers for the lr_schedulers. |
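
    Example:
        A minimal usage sketch (``optimizers`` is assumed to be an
        already-built ``OptimizersContainer`` and ``num_steps`` a known step
        budget; this is illustrative, not the actual training loop)::

            lr_lambda = lambda step: 1.0  # constant schedule, for brevity
            lr_schedulers = LRSchedulersContainer(optimizers, lr_lambda)
            for _ in range(num_steps):
                ...  # forward/backward and optimizers.step()
                lr_schedulers.step()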
| """ |
|
|
| schedulers: list[LRScheduler] |
|
|
| def __init__(self, optimizers: OptimizersContainer, lr_lambda: Callable) -> None: |
        assert (
            len(optimizers) > 0
        ), "Must have at least one optimizer to create LRScheduler"

        self.schedulers = [LambdaLR(optimizer, lr_lambda) for optimizer in optimizers]

    def __iter__(self) -> Iterator[LRScheduler]:
        return iter(self.schedulers)

    def __len__(self) -> int:
        return len(self.schedulers)

    def step(self) -> None:
        for scheduler in self.schedulers:
            scheduler.step()

    def state_dict(self) -> dict[str, Any]:
        # All schedulers are assumed to be identical (see the class
        # docstring), so saving the state of the first one is sufficient.
        return self.schedulers[0].state_dict()

    def load_state_dict(self, state_dict: dict[str, Any]) -> None:
        # Load the same state_dict into every scheduler; deepcopy so that the
        # schedulers do not end up aliasing shared mutable state.
        for scheduler in self.schedulers:
            scheduler.load_state_dict(copy.deepcopy(state_dict))


def build_lr_schedulers(
    optimizers: OptimizersContainer, job_config: JobConfig
) -> LRSchedulersContainer:
    """Create a ``LRSchedulersContainer`` for the given optimizers and job config.

    This function creates a ``LRSchedulersContainer`` for the given optimizers.
    ``job_config`` should define the correct lr scheduler parameters.

    **Note**
    Users who want to customize the lr scheduler behavior can create their own
    ``LRSchedulersContainer`` subclass together with a matching
    ``build_lr_schedulers`` function. Passing the customized
    ``build_lr_schedulers`` to ``TrainSpec`` will create the customized
    ``LRSchedulersContainer``; see the sketch below.

    Args:
        optimizers (OptimizersContainer): The corresponding optimizers for the
            lr_schedulers.
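
    Example:
        A minimal customization sketch (``MyLRSchedulersContainer``,
        ``my_build_lr_schedulers``, and the ``TrainSpec`` field name shown
        are illustrative assumptions, not the exact registration API)::

            class MyLRSchedulersContainer(LRSchedulersContainer):
                def step(self) -> None:
                    ...  # custom stepping behavior

            def my_build_lr_schedulers(
                optimizers: OptimizersContainer, job_config: JobConfig
            ) -> MyLRSchedulersContainer:
                ...

            # then register it on the TrainSpec, e.g.
            # train_spec.build_lr_schedulers_fn = my_build_lr_schedulers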
| """ |
| training_steps = job_config.training.steps |
| warmup_steps = int(job_config.lr_scheduler.warmup_steps) |
| if job_config.lr_scheduler.decay_ratio is not None: |
| decay_steps = round(training_steps * job_config.lr_scheduler.decay_ratio) |
| if warmup_steps + decay_steps > training_steps: |
| logger.warning( |
| f"Warmup ({warmup_steps}) + decay ({decay_steps}) steps exceed " |
| f"total training steps ({training_steps}). " |
| f"Adjusting decay steps to {training_steps - warmup_steps}." |
| ) |
| decay_steps = training_steps - warmup_steps |
| else: |
| decay_steps = training_steps - warmup_steps |
| stable_steps = training_steps - warmup_steps - decay_steps |
| lr_decay_type = job_config.lr_scheduler.decay_type |
| lr_min = job_config.lr_scheduler.lr_min |
|
|
| def linear_warmup_stable_decay( |
| current_step: int, |
| warmup_steps: int, |
| stable_steps: int, |
| decay_steps: int, |
| lr_decay_type: str, |
| lr_min: float, |
| ): |
| """ |
| Computes linear warmup followed by stable learning rate for a while, |
| then some type of decay. |
| |
| Per LambdaLR requirement, this is accomplished by returning |
| a multiplicative factor `curr_adjustment` ranging from 1 to 0 |
| to adjust the learning rate to create the desired schedule. |
| |
| We offer three types of learning rate decay schedules: |
| 1. `linear`: decays linearly from 1 to 0 over the decay period. |
| 2. `sqrt`: decays as 1 minus the square root of the decay progress. |
| 3. `cosine`: follows a cosine curve, decaying according to the values of the half-period of the cosine function. |
| |
| If `lr_min` is specified, the decay range is scaled from 1 to `lr_min` |
| to ensure the learning rate does not drop below this minimum value. |
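
        As a small worked example: with ``warmup_steps=2``, ``stable_steps=2``,
        ``decay_steps=4``, ``lr_decay_type="linear"``, and ``lr_min=0``,
        steps 0 through 7 yield the factors
        1/3, 2/3, 1.0, 1.0, 0.8, 0.6, 0.4, 0.2.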
| """ |
| warmup_stable_steps = warmup_steps + stable_steps |
| if current_step < warmup_steps: |
| |
| |
| current_step += 1 |
| curr_adjustment = float(current_step / (warmup_steps + 1)) |
| elif current_step < warmup_stable_steps: |
| curr_adjustment = 1.0 |
| else: |
| |
| current_step += 1 |
| progress = float(current_step - warmup_stable_steps) / (decay_steps + 1) |
|
|
| if lr_decay_type == "linear": |
| curr_adjustment = 1 - progress |
| elif lr_decay_type == "sqrt": |
| curr_adjustment = 1 - math.sqrt(progress) |
| elif lr_decay_type == "cosine": |
| curr_adjustment = 0.5 * (1.0 + math.cos(math.pi * progress)) |
| curr_adjustment = lr_min + (1 - lr_min) * curr_adjustment |
| return curr_adjustment |
|
|
| lr_lambda = functools.partial( |
| linear_warmup_stable_decay, |
| warmup_steps=warmup_steps, |
| stable_steps=stable_steps, |
| decay_steps=decay_steps, |
| lr_decay_type=lr_decay_type, |
| lr_min=lr_min, |
| ) |
| return LRSchedulersContainer(optimizers, lr_lambda) |