import functools
from typing import Any, Generic, Iterator, TypeVar

import torch
import torch.nn as nn
from torch.distributed.checkpoint.state_dict import (
    get_optimizer_state_dict,
    set_optimizer_state_dict,
    StateDictOptions,
)
from torch.distributed.checkpoint.stateful import Stateful
from torch.optim import Optimizer

from torchtitan.components.ft import FTManager, has_torchft
from torchtitan.config_manager import JobConfig

__all__ = [
    "OptimizersContainer",
    "build_optimizers",
]


if has_torchft:
    import torchft as ft


T = TypeVar("T", bound=Optimizer)

class OptimizersContainer(Optimizer, Stateful, Generic[T]):
    """A container for multiple optimizers.

    This class wraps multiple optimizers into a single object to reduce the
    complexity of the training loop, and mimics the behavior of
    ``torch.optim.Optimizer``. It currently only supports ``Adam`` and ``AdamW``.

    **Note**
    Users who want to customize the optimizer behavior can inherit from this
    class and extend the functionality as needed. The following methods must
    follow the same signatures as those of ``torch.optim.Optimizer``:
    ``step()``, ``zero_grad()``, ``state_dict()``, ``load_state_dict()``.

    **Limitations**
    This class assumes that all the optimizers are of the same type and have
    the same configurations. With this assumption, TorchTitan can support lr
    scheduler resharding (e.g., loading a checkpoint with a different number of
    GPUs and/or a different parallelization strategy). Note that
    ``get_optimizer_state_dict`` already enables resharding for the optimizer
    state but not for the lr scheduler state, hence the limitation.

    Args:
        model_parts (list[nn.Module]): List of model parts to be optimized.
        optimizer_cls (type[T]): Optimizer class to instantiate, one per model part.
        optimizer_kwargs (dict[str, Any]): Keyword arguments for the optimizers.
    """

    optimizers: list[T]
    model_parts: list[nn.Module]

    def __init__(
        self,
        model_parts: list[nn.Module],
        optimizer_cls: type[T],
        optimizer_kwargs: dict[str, Any],
    ) -> None:
        all_params = []
        self.optimizers = []
        self.model_parts = model_parts
        for model in self.model_parts:
            params = [p for p in model.parameters() if p.requires_grad]
            self.optimizers.append(optimizer_cls(params, **optimizer_kwargs))
            all_params.extend(params)
        self._validate_length(len(self.model_parts))
        self._post_init(all_params, optimizer_kwargs)

    def __iter__(self) -> Iterator[T]:
        return iter(self.optimizers)

    def __len__(self) -> int:
        return len(self.optimizers)

    def step(self, *args, **kwargs) -> None:
        for optimizer in self.optimizers:
            optimizer.step(*args, **kwargs)

    def zero_grad(self, *args, **kwargs) -> None:
        for optimizer in self.optimizers:
            optimizer.zero_grad(*args, **kwargs)

    def state_dict(self) -> dict[str, Any]:
        func = functools.partial(
            get_optimizer_state_dict,
            options=StateDictOptions(flatten_optimizer_state_dict=True),
        )
        return {
            k: v
            for sd in map(func, self.model_parts, self.optimizers)
            for k, v in sd.items()
        }

    def load_state_dict(self, state_dict: dict[str, Any]) -> None:
        func = functools.partial(
            set_optimizer_state_dict,
            optim_state_dict=state_dict,
            options=StateDictOptions(flatten_optimizer_state_dict=True),
        )
        list(map(func, self.model_parts, self.optimizers))

    def _validate_length(self, expected_length: int) -> None:
        assert expected_length == len(self.optimizers), (
            "Must pass one optimizer per model part or per param if "
            "using OptimizersInBackwardContainer."
        )

    def _post_init(
        self, all_params: list[nn.Parameter], optimizer_kwargs: dict[str, Any]
    ) -> None:
        # Call Optimizer.__init__() directly so the base class initializes
        # param_groups, state, and the hook machinery that torch.optim expects.
        Optimizer.__init__(self, all_params, optimizer_kwargs)
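

# A minimal usage sketch for ``OptimizersContainer`` (illustrative only; the toy
# model parts and hyperparameters are assumptions, not part of this module):
#
#     model_parts = [nn.Linear(8, 8), nn.Linear(8, 4)]
#     optimizers = OptimizersContainer(
#         model_parts, torch.optim.AdamW, {"lr": 3e-4, "weight_decay": 0.1}
#     )
#     out = model_parts[1](model_parts[0](torch.randn(2, 8)))
#     out.sum().backward()
#     optimizers.step()       # steps every wrapped optimizer
#     optimizers.zero_grad()  # zeroes grads on every wrapped optimizer
#
# Because the container is ``Stateful``, it can also be placed directly in a
# ``torch.distributed.checkpoint`` state dict, e.g.
# ``dcp.save({"optimizer": optimizers}, checkpoint_id="step-100")``.
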
class OptimizersInBackwardContainer(OptimizersContainer):
    """OptimizersContainer that executes ``optim.step()`` during the backward pass.

    This class extends ``OptimizersContainer`` to support stepping optimizers in
    the backward pass. ``step()`` and ``zero_grad()`` are no-ops in this class.
    Instead, ``register_post_accumulate_grad_hook`` is used to register hooks
    that step and zero each parameter's optimizer as soon as that parameter's
    gradient has been accumulated.
    """

    def __init__(
        self,
        model_parts: list[nn.Module],
        optimizer_cls: type[T],
        optimizer_kwargs: dict[str, Any],
    ) -> None:
        all_params = []
        self.model_parts = model_parts

        optim_dict = {}
        for model in self.model_parts:
            for p in model.parameters():
                if p.requires_grad:
                    optim_dict[p] = optimizer_cls([p], **optimizer_kwargs)
                    all_params.append(p)

        def optim_hook(param) -> None:
            optim_dict[param].step()
            optim_dict[param].zero_grad()

        for model in self.model_parts:
            for param in model.parameters():
                if param.requires_grad:
                    param.register_post_accumulate_grad_hook(optim_hook)

        self.optimizers = list(optim_dict.values())

        # One optimizer is created per trainable parameter, so validate against
        # the trainable-parameter count, not the total parameter count.
        self._validate_length(
            sum(
                len([p for p in model.parameters() if p.requires_grad])
                for model in self.model_parts
            )
        )
        self._post_init(all_params, optimizer_kwargs)

    def step(self) -> None:
        pass

    def zero_grad(self) -> None:
        pass
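

# A sketch of the hook mechanism this container relies on (toy example, not part
# of this module): ``register_post_accumulate_grad_hook`` fires once a tensor's
# gradient is fully accumulated, which is when its dedicated optimizer can step.
#
#     w = nn.Parameter(torch.randn(4))
#     opt = torch.optim.AdamW([w], lr=3e-4)
#
#     def hook(param: torch.Tensor) -> None:
#         opt.step()       # update this parameter immediately...
#         opt.zero_grad()  # ...and release its gradient right away
#
#     w.register_post_accumulate_grad_hook(hook)
#     (w ** 2).sum().backward()  # hook runs during backward; no manual step()
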
class FTOptimizersContainer(OptimizersContainer):
    """OptimizersContainer that integrates TorchFT for fault-tolerant training.

    ``step()`` and ``zero_grad()`` calls are routed through a ``ft.Optimizer``
    wrapper so that TorchFT can run its per-train-step logic exactly once per
    ``ft.Manager``.
    """

    def __init__(
        self,
        model_parts: list[nn.Module],
        optimizer_cls: type[T],
        optimizer_kwargs: dict[str, Any],
        ft_manager: "ft.Manager",
    ) -> None:
        super().__init__(model_parts, optimizer_cls, optimizer_kwargs)

        # Eagerly materialize the optimizer state via get_optimizer_state_dict()
        # so that later state_dict()/load_state_dict() calls see fully
        # initialized state; the resulting dict itself is discarded.
        _ = {
            k: v
            for sd in map(get_optimizer_state_dict, model_parts, self.optimizers)
            for k, v in sd.items()
        }
        self.cache_state_dict: dict[str, Any] = {}
        self._ft_optimizer = ft.Optimizer(ft_manager, self)
        self._call_from_ft: bool = False

    def init_cache_state_dict(self) -> None:
        self.cache_state_dict = super().state_dict()

    def state_dict(self) -> dict[str, Any]:
        return self.cache_state_dict

    def load_state_dict(self, state_dict: dict[str, Any]) -> None:
        # Invalidate the cache before loading and rebuild it afterwards so
        # state_dict() never returns state that predates this load.
        self.cache_state_dict = {}
        super().load_state_dict(state_dict)
        self.init_cache_state_dict()

    def step(self, *args, **kwargs) -> None:
        """Call the correct ``step()`` depending on the caller.

        TorchFT's ``OptimizerWrapper.step()`` is designed to be called only once
        per train step per ``ft.Manager``, regardless of how many optimizers are
        used. Hence we need to dispatch the call appropriately.
        """
        if self._call_from_ft:
            super().step(*args, **kwargs)
        else:
            self._call_from_ft = True
            self._ft_optimizer.step(*args, **kwargs)
            self._call_from_ft = False

    def zero_grad(self, *args, **kwargs) -> None:
        """Call the correct ``zero_grad()`` depending on the caller.

        See the comment in ``step()``.
        """
        if self._call_from_ft:
            super().zero_grad(*args, **kwargs)
        else:
            self._call_from_ft = True
            self._ft_optimizer.zero_grad(*args, **kwargs)
            self._call_from_ft = False
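

# Call flow of the dispatch above, as a sketch (``container`` is an
# FTOptimizersContainer; this traces the code above, it is not additional API):
#
#     container.step()                 # _call_from_ft is False
#       -> ft.Optimizer.step()         # TorchFT runs its once-per-train-step logic
#            -> container.step()       # re-entered with _call_from_ft == True
#                 -> optimizer.step()  # each wrapped optimizer actually steps
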
def build_optimizers(
    model_parts: list[nn.Module],
    job_config: JobConfig,
    ft_manager: FTManager,
) -> OptimizersContainer:
    """Create an OptimizersContainer for the given model parts and job config.

    This function creates an ``OptimizersContainer`` for the given model parts.
    ``job_config`` should define the correct optimizer name and parameters.
    This function currently supports creating ``OptimizersContainer``,
    ``OptimizersInBackwardContainer``, and ``FTOptimizersContainer``.

    **Note**
    Users who want to customize the optimizer behavior can create their own
    ``OptimizersContainer`` subclass and their own ``build_optimizers`` function.
    Passing the customized ``build_optimizers`` to ``TrainSpec`` will create the
    customized ``OptimizersContainer``.

    Args:
        model_parts (list[nn.Module]): List of model parts to be optimized.
        job_config (JobConfig): Job config containing the optimizer name and parameters.
        ft_manager (FTManager): Fault-tolerance manager; when enabled, an
            ``FTOptimizersContainer`` is returned.
    """
    optim_in_bwd = job_config.optimizer.early_step_in_backward
    if optim_in_bwd and job_config.parallelism.pipeline_parallel_degree > 1:
        raise NotImplementedError(
            "Optimizers in backward is not supported with pipeline parallelism."
        )
    name = job_config.optimizer.name
    lr = job_config.optimizer.lr
    eps = job_config.optimizer.eps

    optim_implementation = job_config.optimizer.implementation
    assert optim_implementation in ["fused", "foreach", "for-loop"]

    fused = optim_implementation == "fused"
    foreach = optim_implementation == "foreach"

    optimizer_kwargs = {
        "lr": lr,
        "eps": eps,
        "betas": (0.9, 0.95),
        "weight_decay": 0.1,
        "fused": fused,
        "foreach": foreach,
    }

    optimizer_classes = {
        "Adam": torch.optim.Adam,
        "AdamW": torch.optim.AdamW,
    }
    if name not in optimizer_classes:
        raise NotImplementedError(f"Optimizer {name} not added.")
    optimizer_cls = optimizer_classes[name]

    if optim_in_bwd and ft_manager.enabled:
        raise ValueError("TorchFT is not supported with optimizers in backward.")
    elif optim_in_bwd:
        return OptimizersInBackwardContainer(
            model_parts, optimizer_cls, optimizer_kwargs
        )
    elif ft_manager.enabled:
        return FTOptimizersContainer(
            model_parts, optimizer_cls, optimizer_kwargs, ft_manager.manager
        )
    else:
        return OptimizersContainer(model_parts, optimizer_cls, optimizer_kwargs)
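

# A hedged sketch of the job config fields this factory reads (values are
# illustrative; the TOML table names mirror the attribute accesses above):
#
#     [optimizer]
#     name = "AdamW"
#     lr = 3e-4
#     eps = 1e-8
#     implementation = "fused"        # one of "fused", "foreach", "for-loop"
#     early_step_in_backward = false
#
#     [parallelism]
#     pipeline_parallel_degree = 1
#
# With such a config loaded into ``job_config``:
#
#     optimizers = build_optimizers(model_parts, job_config, ft_manager)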