| """
|
| LLaDA configuration
|
| """
|
| from transformers import AutoConfig, PretrainedConfig
|
|
|
| from enum import Enum
|
| from os import PathLike
|
| from typing import Union
|
| from dataclasses import asdict, dataclass, field
|
| from glob import glob
|
| from pathlib import Path
|
| from typing import (
|
| Any,
|
| Dict,
|
| Iterable,
|
| List,
|
| Optional,
|
| Tuple,
|
| Type,
|
| TypeVar,
|
| Union,
|
| cast,
|
| )
|
|
|
|
|
# Public API of this module. ``EngramConfig`` and ``LLaDAConfig`` are public classes
# defined below and are listed here so that ``from ... import *`` and documentation
# tools which honour ``__all__`` pick them up as well.
__all__ = [
    "ActivationType",
    "ActivationCheckpointingStrategy",
    "BlockType",
    "LayerNormType",
    "InitFnType",
    "ModelConfig",
    "EngramConfig",
    "LLaDAConfig",
]

# Type alias for arguments that accept either a plain string path or an os.PathLike.
PathOrStr = Union[str, PathLike]
|
|
|
|
|
class StrEnum(str, Enum):
    """
    A ``str``-backed :class:`~enum.Enum` whose ``str()`` is the bare member value.

    Stands in for :class:`enum.StrEnum` (added in Python 3.11) so this module also
    runs on older interpreters.
    """

    def __str__(self) -> str:
        # Render as the plain value (e.g. "swiglu") rather than "ClassName.member".
        return self.value

    def __repr__(self) -> str:
        # Wrap the string form in single quotes so reprs read like string literals.
        return "'" + str(self) + "'"
|
|
|
|
|
class LayerNormType(StrEnum):
    """Selects the layer-normalization implementation (see ``ModelConfig.layer_norm_type``)."""

    default = "default"
    """
    Standard LayerNorm, behaving like PyTorch's built-in implementation.
    """

    low_precision = "low_precision"
    """
    Low-precision variant of the default LayerNorm.
    """

    rms = "rms"
    """
    RMSNorm. Likely the quickest option when the model is run under ``torch.compile``.
    """

    gemma_rms = "gemma_rms"
    """
    RMSNorm variant following the Gemma implementation.
    """

    amd_compatible = "amd_compatible"
    """
    Hand-written LayerNorm that works around an issue with ROCm.
    """
|
|
|
|
|
class ActivationType(StrEnum):
    """
    Activation function applied inside the MLP layers
    (selected via ``ModelConfig.activation_type``).
    """

    gelu = "gelu"
    relu = "relu"
    silu = "silu"
    swiglu = "swiglu"  # the ModelConfig default
|
|
|
|
|
class BlockType(StrEnum):
    """Transformer block implementation (selected via ``ModelConfig.block_type``)."""

    sequential = "sequential"  # the ModelConfig default
    parallel = "parallel"

    llama = "llama"
    """
    Close to the sequential block, but with operations such as attention implemented
    slightly differently so that the block mirrors Llama's behavior.
    """
|
|
|
|
|
class InitFnType(StrEnum):
    """Weight-initialization strategy (selected via ``ModelConfig.init_fn``)."""

    mitchell = "mitchell"
    """
    Scheme suggested by Mitchell Wortsman (UW): a truncated normal distribution whose
    standard deviation adapts to both the size of the weights and the depth of the layer.
    """

    normal = "normal"
    """
    Every weight is drawn from one shared normal distribution.
    """

    kaiming_normal = "kaiming_normal"
    """
    Kaiming initialization sampled from a normal distribution.
    Not currently compatible with FSDP.
    """

    fan_in = "fan_in"
    """
    "Fan-in variance scaling": a normal distribution with standard deviation
    ``1/sqrt(d_in)``, where ``d_in`` is the input dimensionality of the kernel.
    """

    full_megatron = "full_megatron"
    """
    The scheme metaseq calls "full megatron init"; it is the init used for Llama 2.
    """
|
|
|
|
|
@dataclass
class EngramConfig:
    """
    Settings for the optional Engram component referenced by ``ModelConfig.engram_config``.

    All fields are plain hyperparameters with defaults; list-valued fields use
    ``default_factory`` so every instance receives its own list.
    """

    # Tokenizer identifier; looks like a Hugging Face hub name — TODO confirm how it is consumed.
    tokenizer_name_or_path: str = "deepseek-ai/DeepSeek-V3"
    # Default: two entries of 5 * 129280 each. 129280 is presumably the base
    # tokenizer's vocabulary size and there is presumably one table per n-gram order — verify.
    engram_vocab_size: List[int] = field(default_factory=lambda: [129280 * 5] * 2)
    max_ngram_size: int = 3
    n_embed_per_ngram: int = 512
    n_head_per_ngram: int = 8
    # Layers that receive an Engram module (assumed to be 0-based layer indices — confirm).
    layer_ids: List[int] = field(default_factory=lambda: list((1, 15)))
    pad_id: int = 2
    seed: int = 0
    kernel_size: int = 7
|
|
|
|
|
@dataclass
class ModelConfig:
    """
    LLaDA (model) configuration.

    A plain dataclass of architecture hyperparameters. :class:`LLaDAConfig` seeds its
    Hugging Face config from an instance of this class with all defaults.
    """

    d_model: int = 768
    """
    The hidden size of the model.
    """

    n_heads: int = 12
    """
    The number of self-attention heads.
    """

    n_kv_heads: Optional[int] = None
    """
    The number of heads to use for keys and values. Defaults to `n_heads`.
    Set this to ``None`` or ``n_heads`` for normal multi-head attention.
    Set this to 1 for multi-query attention.
    Set it to some in-between value for Llama2-style grouped query attention.
    """

    n_layers: int = 12
    """
    The number of layers/blocks.
    """

    mlp_ratio: int = 4
    """
    The ratio of the inner MLP dimensionality to ``d_model``.
    This is only used when ``mlp_hidden_size`` is not set.
    """

    mlp_hidden_size: Optional[int] = None
    """
    Set the exact hidden size for the MLP. Otherwise the inner MLP hidden size will be set to `mlp_ratio * d_model`.
    """

    activation_type: ActivationType = ActivationType.swiglu
    """
    The activation function to use within the MLP layers.
    """

    block_type: BlockType = BlockType.sequential
    """
    The transformer block implementation.
    """

    block_group_size: int = 1
    """
    The number of blocks to group together into a single parent block.
    This has no effect on the number of parameters in the model and is only used to wrap groups
    of blocks together with a single FSDP wrapper during training.
    """

    alibi: bool = False
    """
    If ``True``, use ALiBi embeddings. Mutually exclusive with ``rope``.
    """

    alibi_bias_max: float = 8.0
    """
    Maximum absolute value of ALiBi bias.
    """

    rope: bool = False
    """
    Use rotary positional embeddings (RoPE). Mutually exclusive with ``alibi``.
    """

    rope_full_precision: bool = True
    """
    If ``True``, apply RoPE embeddings at full precision regardless of the input type. Otherwise,
    apply RoPE at the precision of the input.
    """

    flash_attention: bool = False
    """
    If ``True``, use ``FlashAttention``.
    """

    attention_dropout: float = 0.1
    """
    The dropout probability within the attention modules.
    """

    multi_query_attention: Optional[bool] = None
    """
    Use the Multi-Query formulation of attention used in PaLM. This reduces the number of parameters
    and is more efficient during inference.
    """

    attention_layer_norm: bool = False
    """
    Apply layer norm to the keys and queries within the attention mechanism.
    This can help stabilize training.
    """

    residual_dropout: float = 0.1
    """
    The dropout probability for the MLP and attention output within each block.
    """

    embedding_dropout: float = 0.1
    """
    The dropout probability for embeddings.
    """

    input_emb_norm: bool = False
    """
    Enable the Gemma-style norm on the input hidden states.
    """

    layer_norm_type: LayerNormType = LayerNormType.default
    """
    The layernorm implementation to use.
    """

    layer_norm_with_affine: bool = True
    """
    Whether to include bias and weight parameters for the layer norms.
    This only affects layer norms that are immediately followed by a linear layer in the forward pass,
    so everything except QK-norms. To turn off affines for QK norms as well, set :attr:`attention_layer_norm_with_affine`
    to ``False``.
    """

    rms_norm_eps: float = 1e-05
    """
    The rms layernorm eps param.
    """

    attention_layer_norm_with_affine: bool = True
    """
    Toggle affine transform for the QK norms.
    """

    max_sequence_length: int = 1024
    """
    The maximum input sequence length supported by the model.
    """

    rope_theta: float = 10000.0
    """
    The rope base param.
    """

    include_qkv_bias: Optional[bool] = False
    """
    Whether or not to include bias parameters in qkv linear layers.
    """

    include_bias: bool = False
    """
    Whether or not to include bias parameters in linear layers.
    In PaLM, they got rid of all bias terms because they found that large
    models tend to have near 0 bias terms anyway.
    """

    bias_for_layer_norm: Optional[bool] = None
    """
    Whether or not to include bias parameters in layer norm.
    This is separate from the include_bias parameter, because of a ROCm crash when biases are disabled in
    layer norm.
    When this is None (the default), it inherits the setting from include_bias.
    """

    scale_logits: bool = False
    """
    If ``True``, scale the output logits by ``1 / sqrt(d_model)``.
    """

    vocab_size: int = 50257
    """
    Vocabulary size of the model.
    """

    embedding_size: Optional[int] = 50304
    """
    The number of embeddings, i.e. the number of tokens. If set to ``None`` it will default
    to ``vocab_size``. If ``vocab_size`` is not a multiple of 128, setting this to the
    next multiple of 128 that's greater than ``vocab_size`` can improve throughput
    substantially.
    """

    weight_tying: bool = True
    """
    Whether to tie output linear weights to the input embedding.
    """

    eos_token_id: int = 50256
    """
    The ID of the end-of-sentence special token.
    """

    pad_token_id: int = 50256
    """
    The ID of the token to use for padding. Defaults to the ID of the EOS token.
    """

    mask_token_id: Optional[int] = 50256
    """
    The ID of the token to use for mask token. Defaults to the ID of the EOS token.
    """

    init_device: Optional[str] = None
    """
    The torch device to use when initializing the model parameters, e.g. "cpu", "cuda:0", "meta".
    """

    init_fn: InitFnType = InitFnType.normal
    """
    The weight initialization strategy.
    """

    init_std: float = 0.02
    """
    The standard deviation to use when initializing weights with a "fixed distribution" ``init_fn``, such
    as "normal".
    """

    init_cutoff_factor: Optional[float] = None
    """
    A positive factor used to scale the cutoff values when initializing weights with a "fixed distribution" ``init_fn``, such
    as "normal". Setting this to None means values are not cutoff.
    """

    precision: Optional[str] = None
    """
    Precision used to train/evaluate with. You shouldn't set this directly.
    See :data:`TrainConfig.precision` instead.
    """

    engram_config: Optional[EngramConfig] = None
    """
    Optional Engram settings (see :class:`EngramConfig`). Defaults to ``None``; presumably
    this leaves the Engram modules disabled — confirm against the modeling code.
    """

    @property
    def effective_n_kv_heads(self) -> int:
        """
        Number of key/value heads the attention layers should use.

        Resolution rules:

        - ``n_kv_heads`` unset: 1 if ``multi_query_attention is True``, otherwise ``n_heads``.
        - ``n_kv_heads`` set, ``multi_query_attention`` unset: ``n_kv_heads``.
        - both set: they must agree (1 when ``multi_query_attention`` is truthy,
          ``n_heads`` otherwise), and that agreed value is returned.

        :raises ValueError: if ``n_kv_heads`` and ``multi_query_attention`` disagree.
            ``ValueError`` subclasses ``Exception``, so callers catching ``Exception`` still work.
        """
        if self.n_kv_heads is None:
            # Identity check on purpose: only an explicit ``True`` selects multi-query attention here.
            return 1 if self.multi_query_attention is True else self.n_heads

        if self.multi_query_attention is None:
            return self.n_kv_heads

        # Both options were given; accept them only when they describe the same layout.
        n_kv_heads_should_be = 1 if self.multi_query_attention else self.n_heads
        if self.n_kv_heads != n_kv_heads_should_be:
            raise ValueError(
                "You can't set `multi_query_attention` and `n_kv_heads` at the same time."
            )
        return n_kv_heads_should_be
|
|
|
class ActivationCheckpointingStrategy(StrEnum):
    """Which transformer layers have their activations checkpointed."""

    whole_layer = "whole_layer"
    """
    Checkpoint every transformer layer.
    """

    one_in_two = "one_in_two"
    """
    Checkpoint one out of every two transformer layers.
    """

    one_in_three = "one_in_three"
    """
    Checkpoint one out of every three transformer layers.
    """

    one_in_four = "one_in_four"
    """
    Checkpoint one out of every four transformer layers.
    """

    two_in_three = "two_in_three"
    """
    Checkpoint two out of every three transformer layers.
    """

    three_in_four = "three_in_four"
    """
    Checkpoint three out of every four transformer layers.
    """

    four_in_five = "four_in_five"
    """
    Checkpoint four out of every five transformer layers.
    """

    nine_in_ten = "nine_in_ten"
    """
    Checkpoint nine out of every ten transformer layers.
    """

    fine_grained = "fine_grained"
    """
    Concentrate checkpointing where recomputation is cheap and the memory savings are largest.
    """
|
|
|
class LLaDAConfig(PretrainedConfig):
    """
    Hugging Face ``PretrainedConfig`` for LLaDA models.

    Every :class:`ModelConfig` field is forwarded to ``PretrainedConfig`` with its
    dataclass default unless the caller overrides it through ``kwargs``.
    """

    model_type = "llada"
    keys_to_ignore_at_inference = ["past_key_values"]

    def __init__(self, use_cache: bool = False, **kwargs):
        # ModelConfig defaults first, caller overrides second, then the cache flag.
        merged = {**vars(ModelConfig()), **kwargs, "use_cache": use_cache}
        # Fill in ``architectures`` only when the caller did not supply it.
        merged.setdefault("architectures", ["LLaDAModelLM"])
        super().__init__(**merged)

    @property
    def num_attention_heads(self):
        """Expose ``n_heads`` under the name ``num_attention_heads``."""
        return self.n_heads

    @property
    def num_hidden_layers(self):
        """Expose ``n_layers`` under the name ``num_hidden_layers``."""
        return self.n_layers

    @property
    def hidden_size(self):
        """Expose ``d_model`` under the name ``hidden_size``."""
        return self.d_model
|
|
|
|
|
|
|
# Make ``AutoConfig`` resolve the "llada" model type to LLaDAConfig; the key is taken
# from the class itself so the two can never drift apart.
AutoConfig.register(LLaDAConfig.model_type, LLaDAConfig)
|
|
|