| """Environment-driven training configuration.""" |
|
|
| from __future__ import annotations |
|
|
| import os |
| import math |
| import uuid |
| from dataclasses import dataclass, field |
| from pathlib import Path |
|
|
| from transformers import PretrainedConfig |
|
|
|
|
# Default model architecture shared by GPTConfig and Hyperparameters.
DEFAULT_VOCAB_SIZE = 4_096
DEFAULT_HIDDEN_SIZE = 192
DEFAULT_NUM_HIDDEN_LAYERS = 5
DEFAULT_NUM_ATTENTION_HEADS = 4
DEFAULT_NUM_KEY_VALUE_HEADS = 2

# Derived sizes: per-head width, and a 2.5x feed-forward width (floored).
DEFAULT_HEAD_DIM = DEFAULT_HIDDEN_SIZE // DEFAULT_NUM_ATTENTION_HEADS
DEFAULT_INTERMEDIATE_SIZE = (DEFAULT_HIDDEN_SIZE * 5) // 2

DEFAULT_BLOCK_SIZE = 512
DEFAULT_ROPE_THETA = 5_000.0

# Vocabulary sizes for the optional place / role embedding tables.
DEFAULT_PLACE_VOCAB_SIZE = 66
DEFAULT_ROLE_VOCAB_SIZE = 12

# Fallback token ids for GPTConfig's feature_* settings.
# NOTE(review): presumably these are the tokenizer's digit, "=" and space
# token ids — confirm against the tokenizer in use.
DEFAULT_FEATURE_DIGIT_TOKEN_IDS = tuple(range(20, 30))
DEFAULT_FEATURE_EQUALS_TOKEN_ID = 33
DEFAULT_FEATURE_SPACE_TOKEN_IDS = (225,)
|
|
|
|
class GPTConfig(PretrainedConfig):
    """Hugging Face configuration for the ``gpt`` model type.

    Sizes that are left as ``None`` are derived from the others:

    * ``num_key_value_heads`` defaults to ``num_attention_heads``;
    * ``head_dim`` defaults to ``hidden_size // num_attention_heads`` (and then
      requires exact divisibility);
    * ``intermediate_size`` defaults to ``4 * hidden_size``.

    ``max_position_embeddings`` is always set equal to ``block_size``. The
    ``feature_*_token_ids`` lists fall back to the module-level
    ``DEFAULT_FEATURE_*`` constants and are stored as plain ``int`` lists.

    Raises:
        ValueError: if ``num_attention_heads``, ``num_key_value_heads`` or the
            resolved ``head_dim`` is not positive; if ``hidden_size`` is not
            divisible by ``num_attention_heads`` when ``head_dim`` is derived;
            if ``num_attention_heads`` is not divisible by
            ``num_key_value_heads``; or if ``head_dim`` is odd (RoPE
            requirement).
    """

    model_type = "gpt"

    def __init__(
        self,
        vocab_size: int = DEFAULT_VOCAB_SIZE,
        hidden_size: int = DEFAULT_HIDDEN_SIZE,
        num_hidden_layers: int = DEFAULT_NUM_HIDDEN_LAYERS,
        num_attention_heads: int = DEFAULT_NUM_ATTENTION_HEADS,
        num_key_value_heads: int | None = DEFAULT_NUM_KEY_VALUE_HEADS,
        intermediate_size: int | None = DEFAULT_INTERMEDIATE_SIZE,
        head_dim: int | None = None,
        block_size: int = DEFAULT_BLOCK_SIZE,
        rope_theta: float = DEFAULT_ROPE_THETA,
        rms_norm_eps: float = 1e-6,
        xsa_projection: bool = True,
        tie_word_embeddings: bool = True,
        labels_are_shifted: bool = False,
        use_place_embeddings: bool = True,
        use_role_embeddings: bool = True,
        place_vocab_size: int = DEFAULT_PLACE_VOCAB_SIZE,
        role_vocab_size: int = DEFAULT_ROLE_VOCAB_SIZE,
        feature_digit_token_ids: list[int] | tuple[int, ...] | None = None,
        feature_equals_token_id: int | None = DEFAULT_FEATURE_EQUALS_TOKEN_ID,
        feature_space_token_ids: list[int] | tuple[int, ...] | None = None,
        **kwargs,
    ):
        if num_key_value_heads is None:
            num_key_value_heads = num_attention_heads
        # Check head counts before they are used as divisors: zero would raise
        # ZeroDivisionError and negative counts would pass the modulo checks.
        if num_attention_heads <= 0:
            raise ValueError("num_attention_heads must be positive")
        if num_key_value_heads <= 0:
            raise ValueError("num_key_value_heads must be positive")
        if head_dim is None:
            if hidden_size % num_attention_heads != 0:
                raise ValueError("hidden_size must be divisible by num_attention_heads")
            head_dim = hidden_size // num_attention_heads
        # Zero is even, so without this a zero-width head would be accepted.
        if head_dim <= 0:
            raise ValueError("head_dim must be positive")
        if intermediate_size is None:
            intermediate_size = hidden_size * 4
        # Grouped-query attention: each KV head serves an equal share of query heads.
        if num_attention_heads % num_key_value_heads != 0:
            raise ValueError("num_attention_heads must be divisible by num_key_value_heads")
        if head_dim % 2 != 0:
            raise ValueError("head_dim must be even for RoPE")

        super().__init__(tie_word_embeddings=tie_word_embeddings, **kwargs)
        self.vocab_size = int(vocab_size)
        self.hidden_size = int(hidden_size)
        self.num_hidden_layers = int(num_hidden_layers)
        self.num_attention_heads = int(num_attention_heads)
        self.num_key_value_heads = int(num_key_value_heads)
        self.intermediate_size = int(intermediate_size)
        self.head_dim = int(head_dim)
        self.block_size = int(block_size)
        # Mirrors block_size; overrides any max_position_embeddings in kwargs.
        self.max_position_embeddings = int(block_size)
        self.rope_theta = float(rope_theta)
        self.rms_norm_eps = float(rms_norm_eps)
        self.xsa_projection = bool(xsa_projection)
        self.labels_are_shifted = bool(labels_are_shifted)
        self.use_place_embeddings = bool(use_place_embeddings)
        self.use_role_embeddings = bool(use_role_embeddings)
        self.place_vocab_size = int(place_vocab_size)
        self.role_vocab_size = int(role_vocab_size)
        # Stored as lists of plain ints; a None argument selects the module default.
        self.feature_digit_token_ids = [
            int(token_id)
            for token_id in (
                DEFAULT_FEATURE_DIGIT_TOKEN_IDS
                if feature_digit_token_ids is None
                else feature_digit_token_ids
            )
        ]
        # None is kept as-is (no equals-token feature id).
        self.feature_equals_token_id = (
            None if feature_equals_token_id is None else int(feature_equals_token_id)
        )
        self.feature_space_token_ids = [
            int(token_id)
            for token_id in (
                DEFAULT_FEATURE_SPACE_TOKEN_IDS
                if feature_space_token_ids is None
                else feature_space_token_ids
            )
        ]
|
|
|
|
def _bool_env(name: str, default: bool) -> bool:
    """Read environment variable *name* as a boolean flag.

    Returns *default* only when the variable is absent. When it is present,
    the value is trimmed and lower-cased. It is truthy only if it is one of
    ``1``, ``true``, ``yes`` or ``on``. Anything else, including an empty
    string, reads as ``False``.
    """
    value = os.environ.get(name)
    if value is not None:
        return value.strip().lower() in ("1", "true", "yes", "on")
    return default
|
|
|
|
def _path_env(name: str, default: str) -> str:
    """Read a filesystem path from environment variable *name*.

    Falls back to *default* when the variable is absent. The chosen value is
    run through ``pathlib.Path`` (which expands ``~`` and normalises
    separators, so an empty string becomes ``"."``) and returned as a string.
    """
    chosen = os.environ[name] if name in os.environ else default
    return os.fspath(Path(chosen).expanduser())
|
|
|
|
@dataclass
class Hyperparameters:
    """Training-run settings resolved from environment variables.

    Each init field's ``default_factory`` reads ``os.environ`` when an instance
    is created, so ``Hyperparameters()`` snapshots the current environment.
    Explicit constructor arguments take precedence. Several settings accept a
    legacy alias that is consulted only when the primary variable is unset
    (e.g. ``MODEL_DIM`` for ``HIDDEN_SIZE``). Fields declared with
    ``init=False`` are derived in ``__post_init__``.

    Raises:
        ValueError: from ``__post_init__`` in these cases:

            * ``RANK`` is negative;
            * a size that is used as a divisor is not positive;
            * ``TRAIN_BATCH_TOKENS`` is not a multiple of
              ``TRAIN_SEQ_LEN * GRAD_ACCUM_STEPS``;
            * the resulting token budget is empty.
    """

    # --- Paths and run identity ----------------------------------------
    data_dir: str = field(default_factory=lambda: _path_env("DATA_DIR", "."))
    tokenized_dir: str = field(default_factory=lambda: _path_env("TOKENIZED_DIR", "tokenized2"))
    tokenizer_dir: str = field(default_factory=lambda: _path_env("TOKENIZER_DIR", "tokenizer_4k"))
    # Empty strings below may be filled in by __post_init__ (tokenizer path,
    # curriculum / mix-weights auto-discovery).
    tokenizer_path: str = field(default_factory=lambda: os.environ.get("TOKENIZER_PATH", ""))
    curriculum_path: str = field(default_factory=lambda: os.environ.get("CURRICULUM_PATH", ""))
    mix_weights_path: str = field(default_factory=lambda: os.environ.get("MIX_WEIGHTS_PATH", ""))
    # A fresh UUID per instance when RUN_ID is unset.
    run_id: str = field(default_factory=lambda: os.environ.get("RUN_ID", str(uuid.uuid4())))
    seed: int = field(default_factory=lambda: int(os.environ.get("SEED", "1337")))
    rank: int = field(init=False)  # from RANK in __post_init__

    # --- Schedule and batching -------------------------------------------
    # ITERATIONS only sizes the default token budget; __post_init__ recomputes it.
    iterations: int = field(default_factory=lambda: int(os.environ.get("ITERATIONS", "10000")))
    requested_train_tokens: int = field(init=False)
    train_tokens: int = field(init=False)
    decay_start_frac: float = field(default_factory=lambda: float(os.environ.get("DECAY_START_FRAC", "0.9")))
    warmup_steps: int = field(default_factory=lambda: int(os.environ.get("WARMUP_STEPS", "0")))
    lr_warmup_steps: int = field(default_factory=lambda: int(os.environ.get("LR_WARMUP_STEPS", "500")))
    train_batch_tokens: int = field(
        default_factory=lambda: int(os.environ.get("TRAIN_BATCH_TOKENS", str(DEFAULT_BLOCK_SIZE * 512)))
    )
    train_seq_len: int = field(init=False)
    eval_seq_len: int = field(init=False)
    grad_accum_steps: int = field(default_factory=lambda: int(os.environ.get("GRAD_ACCUM_STEPS", "2")))
    train_log_every: int = field(default_factory=lambda: int(os.environ.get("TRAIN_LOG_EVERY", "100")))
    train_log_dense_steps: int = field(default_factory=lambda: int(os.environ.get("TRAIN_LOG_DENSE_STEPS", "100")))
    # TRAIN_LOG_FIRST_STEPS is the legacy name for TRAIN_LOG_RAMP_STEPS.
    train_log_ramp_steps: int = field(
        default_factory=lambda: int(
            os.environ.get(
                "TRAIN_LOG_RAMP_STEPS",
                os.environ.get("TRAIN_LOG_FIRST_STEPS", "500"),
            )
        )
    )

    # --- Validation --------------------------------------------------------
    val_batch_tokens: int = field(
        default_factory=lambda: int(os.environ.get("VAL_BATCH_TOKENS", str(DEFAULT_BLOCK_SIZE * 256)))
    )
    val_loss_every: int = field(default_factory=lambda: int(os.environ.get("VAL_LOSS_EVERY", "1000")))
    val_max_tokens: int = field(default_factory=lambda: int(os.environ.get("VAL_MAX_TOKENS", "10_000_000")))

    # --- Model architecture (forwarded by to_model_config) -------------------
    vocab_size: int = field(default_factory=lambda: int(os.environ.get("VOCAB_SIZE", str(DEFAULT_VOCAB_SIZE))))
    hidden_size: int = field(
        default_factory=lambda: int(os.environ.get("HIDDEN_SIZE", os.environ.get("MODEL_DIM", str(DEFAULT_HIDDEN_SIZE))))
    )
    num_hidden_layers: int = field(
        default_factory=lambda: int(
            os.environ.get("NUM_HIDDEN_LAYERS", os.environ.get("NUM_LAYERS", str(DEFAULT_NUM_HIDDEN_LAYERS)))
        )
    )
    num_attention_heads: int = field(
        default_factory=lambda: int(
            os.environ.get("NUM_ATTENTION_HEADS", os.environ.get("NUM_HEADS", str(DEFAULT_NUM_ATTENTION_HEADS)))
        )
    )
    num_key_value_heads: int = field(
        default_factory=lambda: int(
            os.environ.get("NUM_KEY_VALUE_HEADS", os.environ.get("NUM_KV_HEADS", str(DEFAULT_NUM_KEY_VALUE_HEADS)))
        )
    )
    head_dim: int = field(init=False)  # HEAD_DIM or hidden_size // num_attention_heads
    intermediate_size: int = field(
        default_factory=lambda: int(
            os.environ.get("INTERMEDIATE_SIZE", os.environ.get("INTERMEDIATE", str(DEFAULT_INTERMEDIATE_SIZE)))
        )
    )
    block_size: int = field(default_factory=lambda: int(os.environ.get("BLOCK_SIZE", str(DEFAULT_BLOCK_SIZE))))
    rope_theta: float = field(
        default_factory=lambda: float(os.environ.get("ROPE_THETA", os.environ.get("ROPE_BASE", str(DEFAULT_ROPE_THETA))))
    )
    rms_norm_eps: float = field(default_factory=lambda: float(os.environ.get("RMS_NORM_EPS", "1e-6")))
    xsa_projection: bool = field(default_factory=lambda: _bool_env("XSA_PROJECTION", True))
    tie_word_embeddings: bool = field(
        default_factory=lambda: _bool_env("TIE_WORD_EMBEDDINGS", _bool_env("TIE_EMBEDDINGS", True))
    )
    use_place_embeddings: bool = field(default_factory=lambda: _bool_env("USE_PLACE_EMBEDDINGS", True))
    use_role_embeddings: bool = field(default_factory=lambda: _bool_env("USE_ROLE_EMBEDDINGS", True))
    place_vocab_size: int = field(
        default_factory=lambda: int(os.environ.get("PLACE_VOCAB_SIZE", str(DEFAULT_PLACE_VOCAB_SIZE)))
    )
    role_vocab_size: int = field(
        default_factory=lambda: int(os.environ.get("ROLE_VOCAB_SIZE", str(DEFAULT_ROLE_VOCAB_SIZE)))
    )

    # --- Optimizer -----------------------------------------------------------
    min_lr: float = field(default_factory=lambda: float(os.environ.get("MIN_LR", "0.01")))
    lr: float = field(default_factory=lambda: float(os.environ.get("LR", "0.005")))
    beta1: float = field(default_factory=lambda: float(os.environ.get("BETA1", "0.9")))
    beta2: float = field(default_factory=lambda: float(os.environ.get("BETA2", "0.98")))
    adam_eps: float = field(default_factory=lambda: float(os.environ.get("ADAM_EPS", "1e-8")))
    weight_decay: float = field(default_factory=lambda: float(os.environ.get("WEIGHT_DECAY", "0.001")))

    # --- Runtime -------------------------------------------------------------
    compile_model: bool = field(default_factory=lambda: _bool_env("COMPILE_MODEL", True))
    autocast: bool = field(default_factory=lambda: _bool_env("AUTOCAST", True))
    bf16: bool = field(default_factory=lambda: _bool_env("BF16", True))
    device: str = field(default_factory=lambda: os.environ.get("DEVICE", "auto"))

    # --- Outputs and derived data globs --------------------------------------
    output_dir: str = field(default_factory=lambda: _path_env("OUTPUT_DIR", "outputs"))
    checkpoint_name: str = field(default_factory=lambda: os.environ.get("CHECKPOINT_NAME", "final_model"))
    logfile: str = field(init=False)
    model_path: str = field(init=False)
    # Overwritten from RANK in __post_init__; a constructor value is ignored.
    is_main_process: bool = True
    train_files: str = field(init=False)
    val_files: str = field(init=False)

    def __post_init__(self) -> None:
        """Validate the settings and derive every ``init=False`` field."""
        self._resolve_rank_and_head_dim()
        self._resolve_token_budget()
        self._resolve_data_sources()
        self._resolve_output_paths()

    def _resolve_rank_and_head_dim(self) -> None:
        """Set rank / is_main_process from RANK and derive head_dim."""
        self.rank = int(os.environ.get("RANK", "0"))
        if self.rank < 0:
            raise ValueError("RANK must be non-negative")
        self.is_main_process = self.rank == 0
        # Used as a divisor for the default head_dim below.
        if self.num_attention_heads <= 0:
            raise ValueError("NUM_ATTENTION_HEADS must be positive")
        self.head_dim = int(os.environ.get("HEAD_DIM", str(self.hidden_size // self.num_attention_heads)))

    def _resolve_token_budget(self) -> None:
        """Derive sequence lengths, the aligned token budget and the step count."""
        self.train_seq_len = int(os.environ.get("TRAIN_SEQ_LEN", str(self.block_size)))
        # train_seq_len already honours TRAIN_SEQ_LEN, so it is the fallback.
        self.eval_seq_len = int(os.environ.get("EVAL_SEQ_LEN", str(self.train_seq_len)))
        # These act as divisors below: zero would raise ZeroDivisionError and
        # negative values would slip past the alignment check.
        for env_name, value in (
            ("TRAIN_SEQ_LEN", self.train_seq_len),
            ("GRAD_ACCUM_STEPS", self.grad_accum_steps),
            ("TRAIN_BATCH_TOKENS", self.train_batch_tokens),
        ):
            if value <= 0:
                raise ValueError(f"{env_name} must be positive")
        token_alignment = self.train_seq_len * self.grad_accum_steps
        if self.train_batch_tokens % token_alignment != 0:
            raise ValueError(
                "TRAIN_BATCH_TOKENS must be divisible by TRAIN_SEQ_LEN * GRAD_ACCUM_STEPS"
            )
        # TRAIN_TOKENS unset or 0 means "ITERATIONS full batches".
        requested_train_tokens = int(os.environ.get("TRAIN_TOKENS", "0"))
        self.requested_train_tokens = requested_train_tokens or self.iterations * self.train_batch_tokens
        if self.requested_train_tokens <= 0:
            raise ValueError("TRAIN_TOKENS must be positive")
        # Round down to a whole number of (train_seq_len * grad_accum_steps) chunks.
        self.train_tokens = self.requested_train_tokens - (self.requested_train_tokens % token_alignment)
        if self.train_tokens <= 0:
            raise ValueError(
                "TRAIN_TOKENS must provide at least TRAIN_SEQ_LEN * GRAD_ACCUM_STEPS tokens"
            )
        # Ceiling: a trailing partial batch still counts as one iteration.
        self.iterations = math.ceil(self.train_tokens / self.train_batch_tokens)

    def _resolve_data_sources(self) -> None:
        """Resolve shard globs, the curriculum / mix-weights file and the tokenizer."""
        tokenized = Path(self.tokenized_dir)
        self.train_files = os.environ.get("TRAIN_FILES", str(tokenized / "*" / "shard_*.bin"))
        # Without VAL_FILES, validation uses the training glob (which already
        # reflects TRAIN_FILES when that variable is set).
        self.val_files = os.environ.get("VAL_FILES", self.train_files)
        # Auto-discover a curriculum unless one was given, or a legacy
        # mix-weights file was chosen explicitly (env var or constructor arg;
        # the field's default_factory already reads MIX_WEIGHTS_PATH).
        if not self.curriculum_path and not self.mix_weights_path:
            for candidate in (
                tokenized / "curriculum.json",
                Path("pretraining_curriculum.json"),  # relative to the working directory
            ):
                if candidate.exists():
                    self.curriculum_path = str(candidate)
                    break
        # Mix weights next to the shards are used only when no curriculum applies.
        if not self.mix_weights_path and not self.curriculum_path:
            mix_weights_path = tokenized / "mix_weights.json"
            self.mix_weights_path = str(mix_weights_path) if mix_weights_path.exists() else ""
        if not self.tokenizer_path:
            tok_dir = Path(self.tokenizer_dir)
            json_path = tok_dir / "tokenizer.json"
            # Prefer tokenizer.json; otherwise fall back to the directory itself.
            self.tokenizer_path = str(json_path if json_path.exists() else tok_dir)

    def _resolve_output_paths(self) -> None:
        """Derive the log file and checkpoint paths under output_dir."""
        out = Path(self.output_dir)
        self.logfile = os.environ.get("LOGFILE", str(out / "logs" / f"{self.run_id}.txt"))
        self.model_path = os.environ.get("MODEL_PATH", str(out / self.checkpoint_name))

    def to_model_config(self) -> GPTConfig:
        """Build a ``GPTConfig`` from the architecture fields.

        ``labels_are_shifted`` is always ``True`` here. The ``feature_*`` token
        id settings are not forwarded, so ``GPTConfig`` uses its defaults for
        them.
        """
        return GPTConfig(
            vocab_size=self.vocab_size,
            hidden_size=self.hidden_size,
            num_hidden_layers=self.num_hidden_layers,
            num_attention_heads=self.num_attention_heads,
            num_key_value_heads=self.num_key_value_heads,
            head_dim=self.head_dim,
            intermediate_size=self.intermediate_size,
            block_size=self.block_size,
            rope_theta=self.rope_theta,
            rms_norm_eps=self.rms_norm_eps,
            xsa_projection=self.xsa_projection,
            tie_word_embeddings=self.tie_word_embeddings,
            use_place_embeddings=self.use_place_embeddings,
            use_role_embeddings=self.use_role_embeddings,
            place_vocab_size=self.place_vocab_size,
            role_vocab_size=self.role_vocab_size,
            # NOTE(review): presumably the training loop pre-shifts labels — confirm.
            labels_are_shifted=True,
        )
|
|