| |
| |
| |
| |
| |
| |
| |
|
|
| from abc import abstractmethod |
| from collections.abc import Callable, Mapping |
| from dataclasses import dataclass |
| from typing import Protocol, TypeAlias |
|
|
| import torch |
| import torch.nn as nn |
| from torch.distributed.pipelining.schedules import _PipelineSchedule |
|
|
| from torchtitan.components.dataloader import BaseDataLoader |
| from torchtitan.components.ft import FTManager |
| from torchtitan.components.loss import LossFunction |
| from torchtitan.components.lr_scheduler import LRSchedulersContainer |
| from torchtitan.components.metrics import MetricsProcessor |
| from torchtitan.components.optimizer import OptimizersContainer |
| from torchtitan.components.tokenizer import Tokenizer |
| from torchtitan.config_manager import JobConfig |
|
|
# Accepted ways to specify a device throughout torchtitan: a device ordinal
# (int), a device string such as "cuda:0", or a torch.device object.
DeviceType: TypeAlias = int | str | torch.device
|
|
|
|
@dataclass
class BaseModelArgs:
    """Base dataclass that every model's argument class should inherit from.

    Its only purpose today is type checking, while leaving room to attach
    arguments common to all models later on.

    NOTE(review): this class does not use ``abc.ABC``/``ABCMeta``, so the
    ``@abstractmethod`` decorators below are documentation only — they are
    not enforced at instantiation time.
    """

    # Sentinel default: once a dataclass field has a default, every field
    # declared after it (including in subclasses) must also have one.
    _enforced: str = "This field is used to enforce all fields have defaults."

    @abstractmethod
    def update_from_config(self, job_config: JobConfig, tokenizer: Tokenizer) -> None:
        """Refresh these model args from the job config and tokenizer."""
        ...

    @abstractmethod
    def get_nparams_and_flops(
        self, model: nn.Module, seq_len: int
    ) -> tuple[int, float]:
        """Return ``(num_params, flops_estimate)`` for ``model`` at ``seq_len``."""
        ...
|
|
|
|
class ModelProtocol(Protocol):
    """Structural interface that every model class must satisfy.

    The TorchTitan trainer requires these methods to exist; a ``Protocol``
    keeps the check structural (duck-typed) rather than inheritance-based.
    """

    @classmethod
    def from_model_args(cls, args: BaseModelArgs) -> nn.Module:
        """Construct and return a model instance configured by ``args``."""
        ...
|
|
|
|
# Callable signatures used by TrainSpec. The `...` parameter lists are kept
# intentionally loose: concrete builders vary per model and config.

# Applies parallelism to a model and returns the (possibly wrapped) module.
ParallelizeFunction: TypeAlias = Callable[..., nn.Module]
# Splits a model for pipeline parallelism; returns a pipeline schedule, the
# list of stage modules, and two flags — presumably whether this rank holds
# the first/last stage (confirm against the pipelining implementation).
PipeliningFunction: TypeAlias = Callable[
    ..., tuple[_PipelineSchedule, list[nn.Module], bool, bool]
]
# Builds the training dataloader.
DataLoaderBuilder: TypeAlias = Callable[..., BaseDataLoader]
# Builds the tokenizer.
TokenizerBuilder: TypeAlias = Callable[..., Tokenizer]
# Builds the metrics processor.
MetricsProcessorBuilder: TypeAlias = Callable[..., MetricsProcessor]
# Builds optimizers for the given model parts, job config, and FT manager.
OptimizersBuilder: TypeAlias = Callable[
    [list[nn.Module], JobConfig, FTManager], OptimizersContainer
]
# Builds LR schedulers on top of the already-built optimizers.
LRSchedulersBuilder: TypeAlias = Callable[
    [OptimizersContainer, JobConfig], LRSchedulersContainer
]
# Builds the loss function.
LossFunctionBuilder: TypeAlias = Callable[..., LossFunction]
|
|
|
|
@dataclass
class TrainSpec:
    """Bundle of everything needed to train one model family.

    Instances are registered via ``register_train_spec`` and retrieved by
    ``get_train_spec(name)``.
    """

    name: str  # registry key; must be unique across registered specs
    cls: type[nn.Module]  # model class to instantiate
    config: Mapping[str, BaseModelArgs]  # flavor name -> model args
    parallelize_fn: ParallelizeFunction  # applies SPMD parallelisms
    pipelining_fn: PipeliningFunction | None  # None when PP is unsupported
    build_optimizers_fn: OptimizersBuilder
    build_lr_schedulers_fn: LRSchedulersBuilder
    build_dataloader_fn: DataLoaderBuilder
    build_tokenizer_fn: TokenizerBuilder | None  # None when no tokenizer needed
    build_loss_fn: LossFunctionBuilder
    build_metrics_processor_fn: MetricsProcessorBuilder | None = None  # default processor when None
|
|
|
|
# Module-level registry mapping model name -> TrainSpec.
_train_specs: dict[str, TrainSpec] = {}


def register_train_spec(train_spec: TrainSpec) -> None:
    """Add ``train_spec`` to the registry, keyed by its ``name``.

    Raises:
        ValueError: if a spec with the same name is already registered.
    """
    # No ``global`` needed: we only mutate the dict, never rebind the name.
    name = train_spec.name
    if name in _train_specs:
        raise ValueError(f"Model {name} is already registered.")
    _train_specs[name] = train_spec


def get_train_spec(name: str) -> TrainSpec:
    """Return the registered TrainSpec for ``name``.

    Raises:
        ValueError: if no spec with ``name`` has been registered.
    """
    try:
        return _train_specs[name]
    except KeyError:
        raise ValueError(f"Model {name} is not registered.") from None


def apply_to_train_specs(func: Callable[[TrainSpec], TrainSpec]) -> None:
    """Replace every registered spec with ``func(spec)``, in place."""
    # Snapshot the keys so the dict can be safely updated while iterating.
    for name in list(_train_specs):
        _train_specs[name] = func(_train_specs[name])
|