| from __future__ import annotations |
|
|
| from dataclasses import asdict, dataclass, field |
| from glob import glob |
| from pathlib import Path |
| from typing import ( |
| Any, |
| Dict, |
| Iterable, |
| List, |
| Optional, |
| Tuple, |
| Type, |
| TypeVar, |
| Union, |
| cast, |
| ) |
|
|
| import torch |
| from omegaconf import DictConfig, ListConfig |
| from omegaconf import OmegaConf as om |
| from omegaconf.errors import OmegaConfBaseException |
| from torch.distributed.fsdp import MixedPrecision, ShardingStrategy |
|
|
| from .aliases import PathOrStr |
| from .beam_search import Sampler |
| from .exceptions import OLMoConfigurationError |
| from .util import StrEnum |
|
|
| __all__ = [ |
| "ActivationType", |
| "ActivationCheckpointingStrategy", |
| "BlockType", |
| "LayerNormType", |
| "InitFnType", |
| "ModelConfig", |
| "OptimizerType", |
| "OptimizerConfig", |
| "SchedulerType", |
| "SchedulerConfig", |
| "DataConfig", |
| "EvaluatorConfig", |
| "TokenizerConfig", |
| "TrainConfig", |
| "PaddingDirection", |
| "TruncationDirection", |
| "SpeedMonitorConfig", |
| "WandbConfig", |
| "CompilerConfig", |
| "WandbConfig", |
| "FSDPPrecision", |
| "FSDPWrapStrategy", |
| "FSDPConfig", |
| "CheckpointType", |
| ] |
|
|
| C = TypeVar("C", bound="BaseConfig") |
| D = TypeVar("D", bound="DictConfig|ListConfig") |
|
|
|
|
| class BaseConfig: |
| @classmethod |
| def _register_resolvers(cls, validate_paths: bool = True): |
| |
| def path_glob(*paths) -> List[str]: |
| out = [] |
| for path in paths: |
| matches = sorted(glob(path)) |
| if not matches and validate_paths: |
| raise FileNotFoundError(f"{path} does not match any files or dirs") |
| out.extend(matches) |
| return out |
|
|
| |
| def path_choose(*paths) -> str: |
| from .util import is_url |
|
|
| for path in paths: |
| if is_url(path) or Path(path).exists(): |
| return path |
| if validate_paths: |
| raise FileNotFoundError(", ".join(paths)) |
| else: |
| return "" |
|
|
| |
| def path_last_checkpoint(path) -> str: |
| from .util import find_latest_checkpoint |
|
|
| latest_checkpoint = find_latest_checkpoint(path) |
| if latest_checkpoint is None: |
| if validate_paths: |
| raise FileNotFoundError(f"Could not find a latest checkpoint at {path}") |
| else: |
| return "" |
| else: |
| return str(latest_checkpoint) |
|
|
| om.register_new_resolver("path.glob", path_glob, replace=True) |
| om.register_new_resolver("path.choose", path_choose, replace=True) |
| om.register_new_resolver("path.last_checkpoint", path_last_checkpoint, replace=True) |
|
|
| @classmethod |
| def update_legacy_settings(cls, config: D) -> D: |
| """ |
| Update the legacy config settings whose schemas have undergone backwards-incompatible changes. |
| """ |
| return config |
|
|
| @classmethod |
| def new(cls: Type[C], **kwargs) -> C: |
| cls._register_resolvers() |
| conf = om.structured(cls) |
| try: |
| if kwargs: |
| conf = om.merge(conf, kwargs) |
| return cast(C, om.to_object(conf)) |
| except OmegaConfBaseException as e: |
| raise OLMoConfigurationError(str(e)) from e |
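| |
| # A minimal usage sketch (the field values are illustrative): |
| # |
| #   model_cfg = ModelConfig.new(d_model=1024, n_heads=16) |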
|
|
| @classmethod |
| def load( |
| cls: Type[C], |
| path: PathOrStr, |
| overrides: Optional[List[str]] = None, |
| key: Optional[str] = None, |
| validate_paths: bool = True, |
| ) -> C: |
| """Load from a YAML file.""" |
| cls._register_resolvers(validate_paths=validate_paths) |
| schema = om.structured(cls) |
| try: |
| raw = om.load(str(path)) |
| if key is not None: |
| raw = raw[key] |
| raw = cls.update_legacy_settings(raw) |
| conf = om.merge(schema, raw) |
| if overrides: |
| conf = om.merge(conf, om.from_dotlist(overrides)) |
| return cast(C, om.to_object(conf)) |
| except OmegaConfBaseException as e: |
| raise OLMoConfigurationError(str(e)) from e |
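| |
| # A minimal usage sketch (the file name and override are hypothetical): |
| # |
| #   cfg = TrainConfig.load("configs/train.yaml", overrides=["optimizer.learning_rate=2e-4"]) |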
|
|
| def save(self, path: PathOrStr) -> None: |
| """Save to a YAML file.""" |
| om.save(config=self, f=str(path)) |
|
|
| def asdict(self, exclude: Optional[Iterable[str]] = None) -> Dict[str, Any]: |
| out = asdict(self) |
| if exclude is not None: |
| for name in exclude: |
| if name in out: |
| del out[name] |
| return out |
|
|
|
|
| class LayerNormType(StrEnum): |
| default = "default" |
| """ |
| The default LayerNorm implementation, equivalent to PyTorch's built-in version. |
| """ |
|
|
| low_precision = "low_precision" |
| """ |
| A low-precision version of the default LayerNorm. |
| """ |
|
|
| rms = "rms" |
| """ |
| An RMSNorm implementation. When using ``torch.compile`` this is |
| probably the fastest implementation. |
| """ |
|
|
|
|
| class ActivationType(StrEnum): |
| gelu = "gelu" |
| relu = "relu" |
| swiglu = "swiglu" |
|
|
|
|
| class BlockType(StrEnum): |
| sequential = "sequential" |
|
|
| llama = "llama" |
| """ |
| A block similar to the sequential block with slightly different |
| implementations of operations like attention to imitate the behavior of Llama. |
| """ |
|
|
|
|
| class InitFnType(StrEnum): |
| mitchell = "mitchell" |
| """ |
| The strategy suggested to us by Mitchell Wortsman from UW. |
| This uses a truncated normal distribution with an adaptive standard deviation that depends |
| on the size of the weights as well as the depth of the layer. |
| """ |
|
|
| normal = "normal" |
| """ |
| All weights are initialized from the same normal distribution. |
| """ |
|
|
| kaiming_normal = "kaiming_normal" |
| """ |
| All weights are initialized with the Kaiming method from a normal distribution. |
| Note this currently won't work with FSDP. |
| """ |
|
|
| fan_in = "fan_in" |
| """ |
| "Fan-in variance scaling", i.e. normal with a standard deviation of ``1/sqrt(d_in)`` where ``d_in`` |
| is the input dimensionality of the kernel. |
| """ |
|
|
| full_megatron = "full_megatron" |
| """ |
| This is what metaseq calls "full megatron init". It is the init used for Llama 2. |
| """ |
|
|
|
|
| @dataclass |
| class ModelConfig(BaseConfig): |
| """ |
| OLMo (model) configuration. |
| """ |
|
|
| |
|
|
| d_model: int = 768 |
| """ |
| The hidden size of the model. |
| """ |
|
|
| n_heads: int = 12 |
| """ |
| The number of self-attention heads. |
| """ |
|
|
| n_kv_heads: Optional[int] = None |
| """ |
| The number of heads to use for keys and values. Defaults to `n_heads`. |
| Set this to ``None`` or ``n_heads`` for normal multi-head attention. |
| Set this to 1 for multi-query attention. |
| Set it to some in-between value for Llama2-style grouped query attention. |
| """ |
|
|
| clip_qkv: Optional[float] = None |
| """ |
| Clip QKV to this value when set. |
| """ |
|
|
| n_layers: int = 12 |
| """ |
| The number of layers/blocks. |
| """ |
|
|
| mlp_ratio: int = 4 |
| """ |
| The ratio of the inner MLP dimensionality to ``d_model``. |
| This is only used when ``mlp_hidden_size`` is not set. |
| """ |
|
|
| mlp_hidden_size: Optional[int] = None |
| """ |
| Set the exact hidden size for the MLP. Otherwise the inner MLP hidden size will be set to `mlp_ratio * d_model`. |
| """ |
|
|
| activation_type: ActivationType = ActivationType.swiglu |
| """ |
| The activation function to use within the MLP layers. |
| """ |
|
|
| block_type: BlockType = BlockType.sequential |
| """ |
| The transformer block implementation. |
| """ |
|
|
| block_group_size: int = 1 |
| """ |
| The number of blocks to group together into a single parent block. |
| This has no effect on the number of parameters in the model and is only used to wrap groups |
| of blocks together with a single FSDP wrapper during training. |
| """ |
|
|
| alibi: bool = False |
| """ |
| If ``True``, use ALiBi embeddings. Mutually exclusive with ``rope``. |
| """ |
|
|
| alibi_bias_max: float = 8.0 |
| """ |
| Maximum absolute value of ALiBi bias. |
| """ |
|
|
| rope: bool = False |
| """ |
| Use rotary positional embeddings (RoPE). Mutually exclusive with ``alibi``. |
| """ |
|
|
| rope_full_precision: bool = True |
| """ |
| If ``True``, apply RoPE embeddings at full precision regardless of the input type. Otherwise, |
| apply RoPE at the precision of the input. |
| """ |
|
|
| flash_attention: bool = False |
| """ |
| If ``True``, use ``FlashAttention``. |
| """ |
|
|
| attention_dropout: float = 0.1 |
| """ |
| The dropout probability within the attention modules. |
| """ |
|
|
| multi_query_attention: Optional[bool] = None |
| """ |
| Deprecated. Use n_kv_heads instead. |
| """ |
|
|
| attention_layer_norm: bool = False |
| """ |
| Apply layer norm to the keys and queries within the attention mechanism. |
| This can help stabilize training. |
| """ |
|
|
| residual_dropout: float = 0.1 |
| """ |
| The dropout probability for the MLP and attention output within each block. |
| """ |
|
|
| embedding_dropout: float = 0.1 |
| """ |
| The dropout probability for embeddings. |
| """ |
|
|
| layer_norm_type: LayerNormType = LayerNormType.default |
| """ |
| The layernorm implementation to use. |
| """ |
|
|
| layer_norm_with_affine: bool = True |
| """ |
| Whether to include bias and weight parameters for the layer norms. |
| This only affects layer norms that are immediately followed by a linear layer in the forward pass, |
| so everything except QK-norms. To turn off affines for QK norms as well, set :attr:`attention_layer_norm_with_affine` |
| to ``False``. |
| """ |
|
|
| attention_layer_norm_with_affine: bool = True |
| """ |
| Toggle affine transform for the QK norms. |
| """ |
|
|
| max_sequence_length: int = 1024 |
| """ |
| The maximum input sequence length supported by the model. |
| """ |
|
|
| include_bias: bool = True |
| """ |
| Whether or not to include bias parameters in linear layers. |
| In PaLM, they got rid of all bias terms because they found that large |
| models tend to have near 0 bias terms anyway. |
| """ |
|
|
| bias_for_layer_norm: Optional[bool] = None |
| """ |
| Whether or not to include bias parameters in layer norm. |
| This is separate from the include_bias parameter, because of a ROCm crash when biases are disabled in |
| layer norm. |
| When this is None (the default), it inherits the setting from include_bias. |
| """ |
|
|
| scale_logits: bool = False |
| """ |
| If ``True``, scale the output logits by ``1 / sqrt(d_model)``. |
| """ |
|
|
| vocab_size: int = 50257 |
| """ |
| Vocabulary size of the model. |
| """ |
|
|
| embedding_size: Optional[int] = 50304 |
| """ |
| The number of embeddings, i.e. the number of tokens. If set to ``None`` it will default |
| to ``vocab_size``. If ``vocab_size`` is not a multiple of 128, setting this to the |
| next multiple of 128 that's greater than ``vocab_size`` can improve throughput |
| substantially. |
| """ |
|
|
| weight_tying: bool = True |
| """ |
| Whether to tie output linear weights to the input embedding. |
| """ |
|
|
| eos_token_id: int = 50256 |
| """ |
| The ID of the end-of-sentence special token. |
| """ |
|
|
| pad_token_id: int = 50256 |
| """ |
| The ID of the token to use for padding. Defaults to the ID of the EOS token. |
| """ |
|
|
| init_device: Optional[str] = None |
| """ |
| The torch device to use when initializing the model parameters, e.g. "cpu", "cuda:0", "meta". |
| """ |
|
|
| init_fn: InitFnType = InitFnType.normal |
| """ |
| The weight initialization strategy. |
| """ |
|
|
| init_std: float = 0.02 |
| """ |
| The standard deviation to use when initializing weights with a "fixed distribution" ``init_fn``, such |
| as "normal". |
| """ |
|
|
| init_cutoff_factor: Optional[float] = None |
| """ |
| A positive factor used to scale the cutoff values when initializing weights with a "fixed distribution" ``init_fn``, such |
| as "normal". Setting this to ``None`` means values are not cut off. |
| """ |
|
|
| precision: Optional[str] = None |
| """ |
| Precision used to train/evaluate with. You shouldn't set this directly. |
| See :data:`TrainConfig.precision` instead. |
| """ |
|
|
| ternary: bool = False |
| """ |
| Use ternary BitLinear layer from "The Era of 1-bit LLMs: All Large Language Models are in 1.58 Bits" (https://arxiv.org/pdf/2402.17764.pdf) |
| """ |
|
|
| @property |
| def effective_n_kv_heads(self) -> int: |
| if self.n_kv_heads is None: |
| if self.multi_query_attention is True: |
| return 1 |
| else: |
| return self.n_heads |
| else: |
| if self.multi_query_attention is None: |
| return self.n_kv_heads |
| if self.multi_query_attention: |
| n_kv_heads_should_be = 1 |
| else: |
| n_kv_heads_should_be = self.n_heads |
| if self.n_kv_heads == n_kv_heads_should_be: |
| return n_kv_heads_should_be |
| else: |
| raise OLMoConfigurationError( |
| "You can't set `multi_query_attention` and `n_kv_heads` at the same time." |
| ) |
|
|
|
|
| class OptimizerType(StrEnum): |
| lionw = "lionw" |
| adamw = "adamw" |
|
|
|
|
| @dataclass |
| class OptimizerConfig(BaseConfig): |
| name: OptimizerType = OptimizerType.lionw |
| learning_rate: float = 1.0e-4 |
| weight_decay: float = 0.01 |
| betas: Tuple[float, float] = (0.9, 0.95) |
|
|
| no_decay_norm_and_bias: Optional[bool] = None |
| """ |
| Deprecated. Use ``decay_norm_and_bias`` and ``decay_embeddings`` instead. |
| """ |
|
|
| decay_norm_and_bias: bool = False |
| decay_embeddings: bool = False |
| metrics_log_interval: Optional[int] = None |
| """ |
| The interval with which to collect and log detailed parameter-specific metrics. |
| This only applies when logging to W&B, since these metrics won't be logged to the console. |
| If not set, defaults to the wandb `log_interval`. |
| """ |
|
|
| def __post_init__(self): |
| self.betas = tuple(self.betas) |
|
|
| @classmethod |
| def update_legacy_settings(cls, config: D) -> D: |
| new_config = config.copy() |
| if om.is_dict(new_config): |
| assert isinstance(new_config, DictConfig) |
|
|
| if hasattr(new_config, "name") and new_config.name == "decoupled_lionw": |
| new_config.name = "lionw" |
| if hasattr(new_config, "eps"): |
| del new_config.eps |
|
|
| return new_config |
|
|
|
|
| class SchedulerType(StrEnum): |
| cosine_with_warmup = "cosine_with_warmup" |
| linear_with_warmup = "linear_with_warmup" |
| inverse_sqrt_with_warmup = "inverse_sqrt_with_warmup" |
| max_scheduler = "max_scheduler" |
| constant = "constant" |
|
|
|
|
| class SchedulerUnits(StrEnum): |
| steps = "steps" |
| tokens = "tokens" |
|
|
|
|
| @dataclass |
| class SchedulerConfig(BaseConfig): |
| name: SchedulerType = SchedulerType.cosine_with_warmup |
| units: SchedulerUnits = SchedulerUnits.steps |
| t_warmup: Union[int, float] = 100 |
| t_max: Optional[Union[int, float]] = None |
| alpha_f: float = 0.1 |
|
|
| grad_clip_warmup_steps: Optional[Union[int, float]] = None |
| """ |
| The warmup period for which the max grad norm (or norm ratio) will be set to its |
| warmup value of `max_grad_norm * grad_clip_warmup_factor`. |
| """ |
|
|
| grad_clip_warmup_factor: Optional[float] = None |
| """ |
| The ratio of the max allowed gradient norm (or norm ratio) for clipping during the warmup period |
| vs after the warmup period. |
| """ |
|
|
|
|
| class PaddingDirection(StrEnum): |
| right = "right" |
| left = "left" |
|
|
|
|
| @dataclass |
| class DataConfig(BaseConfig): |
| paths: Optional[List[str]] = None |
| datasets: Optional[Dict[str, List[str]]] = None |
| label_mask_paths: Optional[List[str]] = None |
| pad_direction: PaddingDirection = PaddingDirection.right |
| generate_attention_mask: bool = False |
| num_workers: int = 0 |
| drop_last: bool = False |
| pin_memory: bool = False |
| prefetch_factor: Optional[int] = None |
| persistent_workers: bool = False |
| timeout: int = 0 |
| seed: Optional[int] = None |
|
|
|
|
| class EvaluatorType(StrEnum): |
| downstream = "downstream" |
| lm = "lm" |
|
|
|
|
| @dataclass |
| class EvaluatorConfig(BaseConfig): |
| label: str |
| type: EvaluatorType = EvaluatorType.lm |
| data: DataConfig = field(default_factory=DataConfig) |
| device_eval_batch_size: Optional[int] = None |
| subset_num_batches: Optional[int] = None |
|
|
|
|
| class TruncationDirection(StrEnum): |
| right = "right" |
| left = "left" |
|
|
|
|
| @dataclass |
| class TokenizerConfig(BaseConfig): |
| identifier: str = "gpt2" |
| truncate_direction: TruncationDirection = TruncationDirection.right |
|
|
|
|
| @dataclass |
| class WandbConfig(BaseConfig): |
| project: Optional[str] = None |
| entity: Optional[str] = "ai2-llm" |
| group: Optional[str] = None |
| name: Optional[str] = None |
| tags: Optional[List[str]] = field(default_factory=lambda: ["watching"]) |
| log_artifacts: bool = False |
| rank_zero_only: bool = True |
| log_interval: int = 1 |
|
|
|
|
| @dataclass |
| class SpeedMonitorConfig(BaseConfig): |
| window_size: int = 100 |
| gpu_flops_available: Optional[Union[float, int]] = None |
|
|
|
|
| @dataclass |
| class CompilerConfig(BaseConfig): |
| mode: Optional[str] = None |
| """ |
| The mode to compile the model in. At the moment this can be "default", |
| "reduce-overhead" (useful for smaller models/batches), or "max-autotune" |
| (the fastest for larger models, but takes a long time to compile). |
| """ |
|
|
| fullgraph: bool = False |
| """ |
| If ``True``, require the model to compile into a single graph, i.e. disallow |
| breaking it into subgraphs. Note that ``fullgraph=True`` is not compatible with FSDP. |
| """ |
|
|
| backend: str = "inductor" |
| """ |
| The backend to use. |
| """ |
|
|
|
|
| class FSDPWrapStrategy(StrEnum): |
| by_block = "by_block" |
| """ |
| Wrap each OLMo block with its own FSDP instance. |
| """ |
|
|
| by_block_and_size = "by_block_and_size" |
| """ |
| Like 'by_block' but `wte` and `ff_out` will be wrapped separately as well. |
| """ |
|
|
| by_block_group = "by_block_group" |
| """ |
| Wrap each block group together into its own FSDP instance. |
| This requires :attr:`~ModelConfig.block_group_size` to be bigger than 1. |
| """ |
|
|
| by_block_group_and_size = "by_block_group_and_size" |
| """ |
| Like 'by_block_group' but `wte` and `ff_out` will be wrapped separately as well. |
| """ |
|
|
| size_based = "size_based" |
| """ |
| Use PyTorch's default size-based auto wrap policy. |
| """ |
|
|
| one_in_two = "one_in_two" |
| one_in_three = "one_in_three" |
| one_in_four = "one_in_four" |
| one_in_five = "one_in_five" |
|
|
|
|
| class FSDPPrecision(StrEnum): |
| pure = "pure" |
| """ |
| Equivalent to :class:`torch.distributed.fsdp.MixedPrecision` with ``param_dtype``, ``reduce_dtype``, |
| and ``buffer_dtype`` all set to the autocast precision data type. |
| """ |
|
|
| mixed = "mixed" |
| """ |
| Equivalent to :class:`torch.distributed.fsdp.MixedPrecision` with ``param_dtype``, and ``buffer_dtype`` |
| set to the autocast precision data type, while ``reduce_dtype`` is set to fp32. |
| """ |
|
|
|
|
| @dataclass |
| class FSDPConfig(BaseConfig): |
| use_orig_params: bool = True |
| """ |
| This must be ``True`` if you're using ``compile`` or if you want to track the parameter norm during training. |
| """ |
|
|
| sharding_strategy: ShardingStrategy = ShardingStrategy.FULL_SHARD |
|
|
| wrapping_strategy: Optional[FSDPWrapStrategy] = None |
| """ |
| The wrapping strategy to use. If ``None``, the default, the model is wrapped with a single top-level |
| FSDP instance. |
| """ |
|
|
| precision: FSDPPrecision = FSDPPrecision.pure |
|
|
|
|
| class CheckpointType(StrEnum): |
| sharded = "sharded" |
| unsharded = "unsharded" |
| sharded_ephemeral = "sharded_ephemeral" |
|
|
|
|
| class ShardedCheckpointerType(StrEnum): |
| torch_new = "torch_new" |
| torch_legacy = "torch_legacy" |
| local = "local" |
|
|
|
|
| class ActivationCheckpointingStrategy(StrEnum): |
| whole_layer = "whole_layer" |
| """ |
| Checkpoint every transformer layer. |
| """ |
|
|
| one_in_two = "one_in_two" |
| """ |
| Checkpoint one in two transformer layers. |
| """ |
|
|
| one_in_three = "one_in_three" |
| """ |
| Checkpoint one in three transformer layers. |
| """ |
|
|
| one_in_four = "one_in_four" |
| """ |
| Checkpoint one in four transformer layers. |
| """ |
|
|
| two_in_three = "two_in_three" |
| """ |
| Checkpoint two out of every three transformer layers. |
| """ |
|
|
| three_in_four = "three_in_four" |
| """ |
| Checkpoint three out of every four transformer layers. |
| """ |
|
|
| fine_grained = "fine_grained" |
| """ |
| Focus checkpointing on the activations that are cheap to recompute and save the most memory. |
| """ |
|
|
|
|
| @dataclass |
| class TrainConfig(BaseConfig): |
| """ |
| OLMo training configuration. |
| """ |
|
|
| run_name: Optional[str] = None |
| """ |
| The name of the run. |
| """ |
|
|
| seed: int = 6198 |
| """ |
| Used to seed all initial RNG states. |
| """ |
|
|
| epoch: Optional[int] = None |
| """ |
| Increment this when starting a new epoch. |
| """ |
|
|
| dry_run: bool = False |
| """ |
| If ``True``, don't actually train. |
| """ |
|
|
| model: ModelConfig = field(default_factory=ModelConfig) |
| """ |
| OLMo Model configuration. |
| """ |
|
|
| optimizer: OptimizerConfig = field(default_factory=OptimizerConfig) |
| """ |
| Optimizer configuration. |
| """ |
|
|
| scheduler: SchedulerConfig = field(default_factory=SchedulerConfig) |
| """ |
| Learning rate scheduler configuration. |
| """ |
|
|
| data: DataConfig = field(default_factory=DataConfig) |
| """ |
| Training data configuration. |
| """ |
|
|
| restore_dataloader: bool = True |
| """ |
| When restarting, restore the data loader to where it left off. |
| If you're restarting in order to train on a different dataset, set this to ``False``. |
| """ |
|
|
| fast_forward_batches: Optional[int] = None |
| """ |
| When restarting, use this to fast-forward the dataloader beyond the last checkpoint. |
| This can be useful when restarting due to a loss spike in order to skip the data that |
| corresponded to the spike. |
| """ |
|
|
| evaluators: List[EvaluatorConfig] = field(default_factory=list) |
| """ |
| Evaluation configurations. |
| """ |
|
|
| eval_interval: int = 1000 |
| """ |
| How often (in terms of batches) to run evaluations. |
| """ |
|
|
| tokenizer: TokenizerConfig = field(default_factory=TokenizerConfig) |
| """ |
| Tokenizer configuration. |
| """ |
|
|
| save_folder: str = "./" |
| """ |
| The directory to save checkpoints to. |
| """ |
|
|
| remote_save_folder: Optional[str] = None |
| """ |
| A folder in a cloud bucket to upload saved checkpoints to. |
| """ |
|
|
| canceled_check_interval: int = 50 |
| """ |
| How often (in batches) to check if the run has been canceled or reached its time limit. |
| """ |
|
|
| save_interval: int = 1000 |
| """ |
| How often (in terms of steps) to save sharded training state checkpoints. |
| """ |
|
|
| save_interval_unsharded: Optional[int] = None |
| """ |
| How often (if at all) to save unsharded training state checkpoints. |
| For large models it can be costly to save these, so it usually makes sense to save |
| these less often than regular (sharded) training checkpoints. |
| """ |
|
|
| save_interval_ephemeral: Optional[int] = None |
| """ |
| How often (if at all) to save ephemeral sharded checkpoints. These checkpoints are the same |
| as those saved every `save_interval` except that only the most recent one is kept. |
| This is useful when you want to checkpoint often for restarts in case of failures, but don't |
| want to keep the majority of these checkpoints. |
| |
| For example, suppose you want to keep a checkpoint every 1000 steps, but you also want to save |
| a temporary checkpoint every 100 steps in case your job fails. In that case you would |
| set `save_interval=1000` and `save_interval_ephemeral=100`. |
| """ |
|
|
| save_num_checkpoints_to_keep: int = -1 |
| """ |
| How many sharded checkpoints to keep. |
| """ |
|
|
| save_num_unsharded_checkpoints_to_keep: int = -1 |
| """ |
| How many unsharded checkpoints to keep. |
| """ |
|
|
| save_overwrite: bool = False |
| """ |
| If ``True``, overwrite any conflicting checkpoint files. |
| """ |
|
|
| force_save_unsharded: bool = False |
| """ |
| Save an unsharded checkpoint before training (even during a dry run). |
| Use this option with `--load_path={PATH}` and `--dry_run` to convert a sharded |
| checkpoint into an unsharded checkpoint. |
| """ |
|
|
| no_pre_train_checkpoint: bool = False |
| """ |
| Skip saving pre-train checkpoint. |
| """ |
|
|
| load_path: Optional[str] = None |
| """ |
| The path to a training checkpoint to restore/resume from. |
| |
| Note that you can make use of the "path.last_checkpoint" OmegaConf YAML resolver here, which takes |
| a local or remote directory and resolves to the latest checkpoint (sharded or unsharded) in that directory. |
| For example, |
| |
| ```bash |
| --load_path='${path.last_checkpoint:s3://ai2-llm/checkpoints/7b/v1_5-mix-run-001}' |
| ``` |
| """ |
|
|
| load_path_sharded_checkpointer: Optional[ShardedCheckpointerType] = None |
| """ |
| The sharded checkpointer type to use to load the initial checkpoint from ``load_path``. |
| """ |
|
|
| reset_optimizer_state: bool = False |
| """ |
| When this is set, we restore the model from a checkpoint (if given), but we leave the optimizer uninitialized. |
| We also set a new learning rate schedule that does a new warmup, such that it intercepts the original learning |
| curve (according to the current learning rate schedule settings), and continues from there. |
| """ |
|
|
| reset_trainer_state: bool = False |
| """ |
| When this is set, we don't restore the trainer state from a checkpoint. |
| """ |
|
|
| sharded_checkpointer: ShardedCheckpointerType = ShardedCheckpointerType.torch_legacy |
| """ |
| The name of the sharded checkpointer to use to save (sharded) checkpoints throughout training. |
| """ |
|
|
| new_style_checkpoints: Optional[bool] = None |
| """ |
| Deprecated. Use ``sharded_checkpointer`` instead. |
| """ |
|
|
| max_duration: Union[int, str] = 10000 |
| """ |
| How long to train for. |
| |
| If specified without a unit (the default), the units are assumed to be steps. |
| You can also specify this in terms of tokens, for example: `max_duration="2e12T"` means train until |
| 2 trillion tokens. |
| """ |
|
|
| global_train_batch_size: int = 512 |
| """ |
| The effective global batch size. |
| """ |
|
|
| device_train_batch_size: Optional[int] = None |
| """ |
| Don't set this manually. This will be set to ``global_train_batch_size // world_size``. |
| """ |
|
|
| device_train_microbatch_size: int = 16 |
| """ |
| The number of instances passed to the model in a single forward-backward pass. You should set |
| this as large as you can based on available GPU memory. |
| """ |
|
|
| device_eval_batch_size: int = 16 |
| """ |
| The number of evaluation instances passed to the model in a single forward pass on each device. |
| """ |
|
|
| eval_subset_num_batches: int = -1 |
| """ |
| The number of batches to use for downstream evaluation from each dataset. |
| """ |
|
|
| eval_on_load: bool = False |
| """ |
| When resuming from a checkpoint, run the evaluation loop right away. |
| """ |
|
|
| device_train_grad_accum: Optional[int] = None |
| """ |
| Don't set this manually. This will be set to ``device_train_batch_size // device_train_microbatch_size``. |
| """ |
|
|
| max_grad_norm: Optional[float] = None |
| """ |
| Clip gradient norms to this value if set. |
| """ |
|
|
| max_grad_norm_ratio: Optional[float] = None |
| """ |
| If set, gradient norms will be clipped to `max_grad_norm_ratio * exp_avg(norm(grad))`. |
| This takes priority over `max_grad_norm` when set. |
| """ |
|
|
| precision: Optional[str] = None |
| """ |
| Precision to train with (e.g. "amp_bf16", "amp_fp16", or "fp32"). |
| """ |
|
|
| wandb: Optional[WandbConfig] = None |
| """ |
| Weights & Biases configuration. |
| """ |
|
|
| speed_monitor: SpeedMonitorConfig = field(default_factory=SpeedMonitorConfig) |
| """ |
| Speed monitor configuration. |
| """ |
|
|
| console_log_interval: int = 1 |
| """ |
| How often to log to the console. |
| """ |
|
|
| compile: Optional[CompilerConfig] = None |
| """ |
| Settings for compiling the model with ``torch.compile()``. |
| """ |
|
|
| fsdp: FSDPConfig = field(default_factory=FSDPConfig) |
| """ |
| Fully sharded data parallel settings. |
| """ |
|
|
| softmax_auxiliary_loss: bool = False |
| """ |
| If ``True``, we add the auxiliary loss function from PaLM that encourages the softmax |
| normalizing term to be close to 0. |
| """ |
|
|
| time_limit: Optional[float] = 60 * 60 * 47.5 |
| """ |
| The maximum amount of time to train for before saving a checkpoint and ending early. |
| On LUMI we have 48 hours max per job, so we default to just under 48 hours to give us time |
| to write out a final checkpoint. |
| """ |
|
|
| extra_steps_after_cancel: int = 10 |
| """ |
| Under certain conditions when a run is canceled we train for a few extra steps after saving |
| the final checkpoint so that when the run is restarted from the latest checkpoint we have some |
| overlap in metrics. |
| """ |
|
|
| early_stopping_factor: Optional[float] = None |
|
|
| save_data_indices: bool = True |
| """ |
| Save training data indices from each batch for each worker. |
| """ |
|
|
| python_profiling: bool = False |
| """ |
| Whether to run the Python profiler on batches 6, 7, and 8. |
| """ |
|
|
| torch_profiling: bool = False |
| """ |
| Whether to run the PyTorch profiler on batches 6, 7, and 8. |
| """ |
|
|
| stop_at: Optional[int] = None |
| """ |
| Stop at a specific step. |
| """ |
|
|
| stop_after: Optional[int] = None |
| """ |
| Stop after a specific number of steps. |
| """ |
|
|
| activation_checkpointing: Optional[ActivationCheckpointingStrategy] = None |
| """ |
| The activation checkpointing strategy to use. |
| """ |
|
|
| fused_loss: Optional[bool] = None |
| """ |
| Whether to use the fused CE loss function from `flash-attn`. |
| """ |
|
|
| @property |
| def autocast_precision(self) -> torch.dtype: |
| if self.precision == "amp_bf16": |
| return torch.bfloat16 |
| elif self.precision == "amp_fp16": |
| return torch.float16 |
| elif self.precision == "fp32": |
| return torch.float32 |
| else: |
| raise ValueError(f"Unexpected precision type '{self.precision}'") |
|
|
| @property |
| def fsdp_precision(self) -> MixedPrecision: |
| if self.fsdp.precision == FSDPPrecision.pure: |
| return MixedPrecision( |
| param_dtype=self.autocast_precision, |
| reduce_dtype=self.autocast_precision, |
| buffer_dtype=self.autocast_precision, |
| ) |
| elif self.fsdp.precision == FSDPPrecision.mixed: |
| return MixedPrecision( |
| param_dtype=self.autocast_precision, |
| reduce_dtype=torch.float32, |
| buffer_dtype=self.autocast_precision, |
| ) |
| else: |
| raise NotImplementedError(f"{self.fsdp.precision}") |
|
|
| @classmethod |
| def update_legacy_settings(cls, config: D) -> D: |
| new_config = config.copy() |
| if om.is_dict(new_config): |
| assert isinstance(new_config, DictConfig) |
|
|
| if hasattr(new_config, "activation_checkpointing"): |
| if new_config.activation_checkpointing is False: |
| new_config.activation_checkpointing = None |
| if new_config.activation_checkpointing is True: |
| new_config.activation_checkpointing = ActivationCheckpointingStrategy.whole_layer |
|
|
| if hasattr(new_config, "optimizer"): |
| new_config.optimizer = OptimizerConfig.update_legacy_settings(new_config.optimizer) |
|
|
| return new_config |
|
|