"""Environment-driven training configuration.""" from __future__ import annotations import os import math import uuid from dataclasses import dataclass, field from pathlib import Path from transformers import PretrainedConfig DEFAULT_VOCAB_SIZE = 4096 DEFAULT_HIDDEN_SIZE = 192 DEFAULT_NUM_HIDDEN_LAYERS = 7 DEFAULT_NUM_ATTENTION_HEADS = 3 DEFAULT_NUM_KEY_VALUE_HEADS = 1 DEFAULT_HEAD_DIM = DEFAULT_HIDDEN_SIZE // DEFAULT_NUM_ATTENTION_HEADS DEFAULT_INTERMEDIATE_SIZE = DEFAULT_HIDDEN_SIZE * 5 // 2 DEFAULT_BLOCK_SIZE = 512 DEFAULT_ROPE_THETA = 5000.0 class GPTConfig(PretrainedConfig): model_type = "gpt" def __init__( self, vocab_size: int = DEFAULT_VOCAB_SIZE, hidden_size: int = DEFAULT_HIDDEN_SIZE, num_hidden_layers: int = DEFAULT_NUM_HIDDEN_LAYERS, num_attention_heads: int = DEFAULT_NUM_ATTENTION_HEADS, num_key_value_heads: int | None = DEFAULT_NUM_KEY_VALUE_HEADS, intermediate_size: int | None = DEFAULT_INTERMEDIATE_SIZE, head_dim: int | None = None, block_size: int = DEFAULT_BLOCK_SIZE, rope_theta: float = DEFAULT_ROPE_THETA, rms_norm_eps: float = 1e-6, xsa_projection: bool = True, tie_word_embeddings: bool = True, labels_are_shifted: bool = False, **kwargs, ): if num_key_value_heads is None: num_key_value_heads = num_attention_heads if head_dim is None: if hidden_size % num_attention_heads != 0: raise ValueError("hidden_size must be divisible by num_attention_heads") head_dim = hidden_size // num_attention_heads if intermediate_size is None: intermediate_size = hidden_size * 4 if num_attention_heads % num_key_value_heads != 0: raise ValueError("num_attention_heads must be divisible by num_key_value_heads") if head_dim % 2 != 0: raise ValueError("head_dim must be even for RoPE") super().__init__(tie_word_embeddings=tie_word_embeddings, **kwargs) self.vocab_size = int(vocab_size) self.hidden_size = int(hidden_size) self.num_hidden_layers = int(num_hidden_layers) self.num_attention_heads = int(num_attention_heads) self.num_key_value_heads = int(num_key_value_heads) self.intermediate_size = int(intermediate_size) self.head_dim = int(head_dim) self.block_size = int(block_size) self.max_position_embeddings = int(block_size) self.rope_theta = float(rope_theta) self.rms_norm_eps = float(rms_norm_eps) self.xsa_projection = bool(xsa_projection) self.labels_are_shifted = bool(labels_are_shifted) def _bool_env(name: str, default: bool) -> bool: raw = os.environ.get(name) if raw is None: return default return raw.strip().lower() in {"1", "true", "yes", "on"} def _path_env(name: str, default: str) -> str: return str(Path(os.environ.get(name, default)).expanduser()) @dataclass class Hyperparameters: data_dir: str = field(default_factory=lambda: _path_env("DATA_DIR", ".")) tokenized_dir: str = field(default_factory=lambda: _path_env("TOKENIZED_DIR", "tokenized")) tokenizer_dir: str = field(default_factory=lambda: _path_env("TOKENIZER_DIR", "tokenizer_4k")) tokenizer_path: str = field(default_factory=lambda: os.environ.get("TOKENIZER_PATH", "")) curriculum_path: str = field(default_factory=lambda: os.environ.get("CURRICULUM_PATH", "")) mix_weights_path: str = field(default_factory=lambda: os.environ.get("MIX_WEIGHTS_PATH", "")) run_id: str = field(default_factory=lambda: os.environ.get("RUN_ID", str(uuid.uuid4()))) seed: int = field(default_factory=lambda: int(os.environ.get("SEED", "1337"))) rank: int = field(init=False) iterations: int = field(default_factory=lambda: int(os.environ.get("ITERATIONS", "10000"))) requested_train_tokens: int = field(init=False) train_tokens: int = field(init=False) decay_start_frac: float = field(default_factory=lambda: float(os.environ.get("DECAY_START_FRAC", "0.7"))) warmup_steps: int = field(default_factory=lambda: int(os.environ.get("WARMUP_STEPS", "0"))) lr_warmup_steps: int = field(default_factory=lambda: int(os.environ.get("LR_WARMUP_STEPS", "500"))) train_batch_tokens: int = field(default_factory=lambda: int(os.environ.get("TRAIN_BATCH_TOKENS", str(DEFAULT_BLOCK_SIZE * 512)))) train_seq_len: int = field(init=False) eval_seq_len: int = field(init=False) grad_accum_steps: int = field(default_factory=lambda: int(os.environ.get("GRAD_ACCUM_STEPS", "2"))) train_log_every: int = field(default_factory=lambda: int(os.environ.get("TRAIN_LOG_EVERY", "100"))) train_log_first_steps: int = field(default_factory=lambda: int(os.environ.get("TRAIN_LOG_FIRST_STEPS", "500"))) val_batch_tokens: int = field(default_factory=lambda: int(os.environ.get("VAL_BATCH_TOKENS", str(DEFAULT_BLOCK_SIZE * 256)))) val_loss_every: int = field(default_factory=lambda: int(os.environ.get("VAL_LOSS_EVERY", "1000"))) val_max_tokens: int = field(default_factory=lambda: int(os.environ.get("VAL_MAX_TOKENS", "10_000_000"))) vocab_size: int = field(default_factory=lambda: int(os.environ.get("VOCAB_SIZE", str(DEFAULT_VOCAB_SIZE)))) hidden_size: int = field(default_factory=lambda: int(os.environ.get("HIDDEN_SIZE", os.environ.get("MODEL_DIM", str(DEFAULT_HIDDEN_SIZE))))) num_hidden_layers: int = field(default_factory=lambda: int(os.environ.get("NUM_HIDDEN_LAYERS", os.environ.get("NUM_LAYERS", str(DEFAULT_NUM_HIDDEN_LAYERS))))) num_attention_heads: int = field(default_factory=lambda: int(os.environ.get("NUM_ATTENTION_HEADS", os.environ.get("NUM_HEADS", str(DEFAULT_NUM_ATTENTION_HEADS))))) num_key_value_heads: int = field(default_factory=lambda: int(os.environ.get("NUM_KEY_VALUE_HEADS", os.environ.get("NUM_KV_HEADS", str(DEFAULT_NUM_KEY_VALUE_HEADS))))) head_dim: int = field(init=False) intermediate_size: int = field(default_factory=lambda: int(os.environ.get("INTERMEDIATE_SIZE", os.environ.get("INTERMEDIATE", str(DEFAULT_INTERMEDIATE_SIZE))))) block_size: int = field(default_factory=lambda: int(os.environ.get("BLOCK_SIZE", str(DEFAULT_BLOCK_SIZE)))) rope_theta: float = field(default_factory=lambda: float(os.environ.get("ROPE_THETA", os.environ.get("ROPE_BASE", str(DEFAULT_ROPE_THETA))))) rms_norm_eps: float = field(default_factory=lambda: float(os.environ.get("RMS_NORM_EPS", "1e-6"))) xsa_projection: bool = field(default_factory=lambda: _bool_env("XSA_PROJECTION", True)) tie_word_embeddings: bool = field(default_factory=lambda: _bool_env("TIE_WORD_EMBEDDINGS", _bool_env("TIE_EMBEDDINGS", True))) min_lr: float = field(default_factory=lambda: float(os.environ.get("MIN_LR", "0.0"))) lr: float = field(default_factory=lambda: float(os.environ.get("LR", "0.004"))) beta1: float = field(default_factory=lambda: float(os.environ.get("BETA1", "0.9"))) beta2: float = field(default_factory=lambda: float(os.environ.get("BETA2", "0.95"))) adam_eps: float = field(default_factory=lambda: float(os.environ.get("ADAM_EPS", "1e-8"))) weight_decay: float = field(default_factory=lambda: float(os.environ.get("WEIGHT_DECAY", "0.005"))) compile_model: bool = field(default_factory=lambda: _bool_env("COMPILE_MODEL", True)) autocast: bool = field(default_factory=lambda: _bool_env("AUTOCAST", True)) bf16: bool = field(default_factory=lambda: _bool_env("BF16", True)) device: str = field(default_factory=lambda: os.environ.get("DEVICE", "auto")) output_dir: str = field(default_factory=lambda: _path_env("OUTPUT_DIR", "outputs")) checkpoint_name: str = field(default_factory=lambda: os.environ.get("CHECKPOINT_NAME", "final_model")) logfile: str = field(init=False) model_path: str = field(init=False) is_main_process: bool = True train_files: str = field(init=False) val_files: str = field(init=False) def __post_init__(self) -> None: self.rank = int(os.environ.get("RANK", "0")) if self.rank < 0: raise ValueError("RANK must be non-negative") self.is_main_process = self.rank == 0 self.head_dim = int(os.environ.get("HEAD_DIM", str(self.hidden_size // self.num_attention_heads))) self.train_seq_len = int(os.environ.get("TRAIN_SEQ_LEN", str(self.block_size))) self.eval_seq_len = int(os.environ.get("EVAL_SEQ_LEN", os.environ.get("TRAIN_SEQ_LEN", str(self.train_seq_len)))) token_alignment = self.train_seq_len * self.grad_accum_steps if self.train_batch_tokens % token_alignment != 0: raise ValueError( "TRAIN_BATCH_TOKENS must be divisible by TRAIN_SEQ_LEN * GRAD_ACCUM_STEPS" ) requested_train_tokens = int(os.environ.get("TRAIN_TOKENS", "0")) self.requested_train_tokens = requested_train_tokens or self.iterations * self.train_batch_tokens if self.requested_train_tokens <= 0: raise ValueError("TRAIN_TOKENS must be positive") self.train_tokens = self.requested_train_tokens - (self.requested_train_tokens % token_alignment) if self.train_tokens <= 0: raise ValueError( "TRAIN_TOKENS must provide at least TRAIN_SEQ_LEN * GRAD_ACCUM_STEPS tokens" ) self.iterations = math.ceil(self.train_tokens / self.train_batch_tokens) tokenized = Path(self.tokenized_dir) self.train_files = os.environ.get("TRAIN_FILES", str(tokenized / "*" / "shard_*.bin")) self.val_files = os.environ.get("VAL_FILES", os.environ.get("TRAIN_FILES", self.train_files)) explicit_legacy_mix = bool(os.environ.get("MIX_WEIGHTS_PATH")) if not self.curriculum_path and not explicit_legacy_mix: tokenized_curriculum = tokenized / "curriculum.json" default_curriculum = Path("pretraining_curriculum.json") if tokenized_curriculum.exists(): self.curriculum_path = str(tokenized_curriculum) elif default_curriculum.exists(): self.curriculum_path = str(default_curriculum) if not self.mix_weights_path and not self.curriculum_path: mix_weights_path = tokenized / "mix_weights.json" self.mix_weights_path = str(mix_weights_path) if mix_weights_path.exists() else "" if not self.tokenizer_path: tok_dir = Path(self.tokenizer_dir) json_path = tok_dir / "tokenizer.json" self.tokenizer_path = str(json_path if json_path.exists() else tok_dir) out = Path(self.output_dir) self.logfile = os.environ.get("LOGFILE", str(out / "logs" / f"{self.run_id}.txt")) self.model_path = os.environ.get("MODEL_PATH", str(out / self.checkpoint_name)) def to_model_config(self) -> GPTConfig: return GPTConfig( vocab_size=self.vocab_size, hidden_size=self.hidden_size, num_hidden_layers=self.num_hidden_layers, num_attention_heads=self.num_attention_heads, num_key_value_heads=self.num_key_value_heads, head_dim=self.head_dim, intermediate_size=self.intermediate_size, block_size=self.block_size, rope_theta=self.rope_theta, rms_norm_eps=self.rms_norm_eps, xsa_projection=self.xsa_projection, tie_word_embeddings=self.tie_word_embeddings, labels_are_shifted=True, )