Atom3.4m / config.py
Maksymilian
Upload folder using huggingface_hub
bdb11fe verified
Raw
History Blame Contribute Delete
11.5 kB
"""Environment-driven training configuration."""
from __future__ import annotations
import os
import math
import uuid
from dataclasses import dataclass, field
from pathlib import Path
from transformers import PretrainedConfig
# Default model hyperparameters. They serve as the fallback values for the
# GPTConfig constructor arguments and for the environment-driven
# Hyperparameters fields defined later in this module.
DEFAULT_VOCAB_SIZE = 4096
DEFAULT_HIDDEN_SIZE = 192
DEFAULT_NUM_HIDDEN_LAYERS = 7
DEFAULT_NUM_ATTENTION_HEADS = 3
DEFAULT_NUM_KEY_VALUE_HEADS = 1
# Hidden size split evenly across the attention heads: 192 // 3 == 64.
DEFAULT_HEAD_DIM = DEFAULT_HIDDEN_SIZE // DEFAULT_NUM_ATTENTION_HEADS
# 2.5x the hidden size, using integer arithmetic: 192 * 5 // 2 == 480.
DEFAULT_INTERMEDIATE_SIZE = DEFAULT_HIDDEN_SIZE * 5 // 2
DEFAULT_BLOCK_SIZE = 512
DEFAULT_ROPE_THETA = 5000.0
class GPTConfig(PretrainedConfig):
    """Model configuration stored as a ``PretrainedConfig`` subclass.

    The constructor fills in derived values, validates the head layout and
    then stores every setting as a plain ``int`` / ``float`` / ``bool``
    attribute.

    Args:
        vocab_size: Stored as ``int``.
        hidden_size: Stored as ``int``.
        num_hidden_layers: Stored as ``int``.
        num_attention_heads: Must be positive.
        num_key_value_heads: Must be positive and divide
            ``num_attention_heads``. ``None`` means "same as
            ``num_attention_heads``".
        intermediate_size: ``None`` means ``hidden_size * 4``.
        head_dim: Must be even. ``None`` means
            ``hidden_size // num_attention_heads``, which additionally
            requires ``hidden_size`` to be divisible by
            ``num_attention_heads``.
        block_size: Stored both as ``block_size`` and as
            ``max_position_embeddings``.
        rope_theta: Stored as ``float``.
        rms_norm_eps: Stored as ``float``.
        xsa_projection: Stored as ``bool``; interpreted by the model code.
        tie_word_embeddings: Forwarded to ``PretrainedConfig.__init__``.
        labels_are_shifted: Stored as ``bool``; interpreted by the model code.
        **kwargs: Forwarded unchanged to ``PretrainedConfig.__init__``.

    Raises:
        ValueError: If a head count is not positive, if the divisibility
            constraints above are violated, or if ``head_dim`` is odd.
    """

    model_type = "gpt"

    def __init__(
        self,
        vocab_size: int = DEFAULT_VOCAB_SIZE,
        hidden_size: int = DEFAULT_HIDDEN_SIZE,
        num_hidden_layers: int = DEFAULT_NUM_HIDDEN_LAYERS,
        num_attention_heads: int = DEFAULT_NUM_ATTENTION_HEADS,
        num_key_value_heads: int | None = DEFAULT_NUM_KEY_VALUE_HEADS,
        intermediate_size: int | None = DEFAULT_INTERMEDIATE_SIZE,
        head_dim: int | None = None,
        block_size: int = DEFAULT_BLOCK_SIZE,
        rope_theta: float = DEFAULT_ROPE_THETA,
        rms_norm_eps: float = 1e-6,
        xsa_projection: bool = True,
        tie_word_embeddings: bool = True,
        labels_are_shifted: bool = False,
        **kwargs,
    ):
        # Reject non-positive head counts up front: they are used as divisors
        # below and would otherwise surface as a bare ZeroDivisionError (zero)
        # or slip through the modulo checks unnoticed (negative).
        if num_attention_heads <= 0:
            raise ValueError("num_attention_heads must be positive")
        if num_key_value_heads is None:
            num_key_value_heads = num_attention_heads
        if num_key_value_heads <= 0:
            raise ValueError("num_key_value_heads must be positive")

        if head_dim is None:
            # Only the derived head_dim requires an even split of hidden_size;
            # an explicitly supplied head_dim is accepted as-is.
            if hidden_size % num_attention_heads != 0:
                raise ValueError("hidden_size must be divisible by num_attention_heads")
            head_dim = hidden_size // num_attention_heads
        if intermediate_size is None:
            intermediate_size = hidden_size * 4

        if num_attention_heads % num_key_value_heads != 0:
            raise ValueError("num_attention_heads must be divisible by num_key_value_heads")
        if head_dim % 2 != 0:
            raise ValueError("head_dim must be even for RoPE")

        super().__init__(tie_word_embeddings=tie_word_embeddings, **kwargs)

        self.vocab_size = int(vocab_size)
        self.hidden_size = int(hidden_size)
        self.num_hidden_layers = int(num_hidden_layers)
        self.num_attention_heads = int(num_attention_heads)
        self.num_key_value_heads = int(num_key_value_heads)
        self.intermediate_size = int(intermediate_size)
        self.head_dim = int(head_dim)
        self.block_size = int(block_size)
        # Mirror block_size under the attribute name max_position_embeddings.
        self.max_position_embeddings = int(block_size)
        self.rope_theta = float(rope_theta)
        self.rms_norm_eps = float(rms_norm_eps)
        self.xsa_projection = bool(xsa_projection)
        self.labels_are_shifted = bool(labels_are_shifted)
def _bool_env(name: str, default: bool) -> bool:
    """Interpret environment variable *name* as a boolean flag.

    Returns *default* when the variable is not set. When it is set, the
    value is stripped and lower-cased; ``"1"``, ``"true"``, ``"yes"`` and
    ``"on"`` yield ``True`` and every other value (including the empty
    string) yields ``False``.
    """
    try:
        value = os.environ[name]
    except KeyError:
        return default
    normalized = value.strip().lower()
    return normalized in ("1", "true", "yes", "on")
def _path_env(name: str, default: str) -> str:
    """Return the path held in environment variable *name* as a string.

    Falls back to *default* when the variable is unset. A leading ``~`` is
    expanded via ``Path.expanduser`` before the path is converted back to
    ``str``.
    """
    configured = os.environ.get(name, default)
    expanded = Path(configured).expanduser()
    return str(expanded)
@dataclass
class Hyperparameters:
    """Training configuration whose defaults are read from the environment.

    Every ``init`` field has a ``default_factory`` that consults an
    environment variable (named after the field, upper-cased; some fields
    also accept a legacy alias) at instantiation time, so explicit
    constructor arguments override the environment. Fields declared with
    ``init=False`` are computed in ``__post_init__``.

    Raises:
        ValueError: From ``__post_init__`` when ``RANK`` is negative, when
            ``num_attention_heads``, ``train_seq_len``, ``grad_accum_steps``
            or ``train_batch_tokens`` is not positive, when
            ``train_batch_tokens`` is not a multiple of
            ``train_seq_len * grad_accum_steps``, or when the token budget
            is empty.
    """

    # --- Data, tokenizer and mixture locations ---------------------------
    data_dir: str = field(default_factory=lambda: _path_env("DATA_DIR", "."))
    tokenized_dir: str = field(default_factory=lambda: _path_env("TOKENIZED_DIR", "tokenized"))
    tokenizer_dir: str = field(default_factory=lambda: _path_env("TOKENIZER_DIR", "tokenizer_4k"))
    # Empty string means "resolve in __post_init__".
    tokenizer_path: str = field(default_factory=lambda: os.environ.get("TOKENIZER_PATH", ""))
    curriculum_path: str = field(default_factory=lambda: os.environ.get("CURRICULUM_PATH", ""))
    mix_weights_path: str = field(default_factory=lambda: os.environ.get("MIX_WEIGHTS_PATH", ""))

    # --- Run identity ------------------------------------------------------
    # A fresh UUID is generated when RUN_ID is not set.
    run_id: str = field(default_factory=lambda: os.environ.get("RUN_ID", str(uuid.uuid4())))
    seed: int = field(default_factory=lambda: int(os.environ.get("SEED", "1337")))
    rank: int = field(init=False)

    # --- Schedule and batching ---------------------------------------------
    # iterations is recomputed in __post_init__ from the aligned token budget.
    iterations: int = field(default_factory=lambda: int(os.environ.get("ITERATIONS", "10000")))
    requested_train_tokens: int = field(init=False)
    train_tokens: int = field(init=False)
    decay_start_frac: float = field(
        default_factory=lambda: float(os.environ.get("DECAY_START_FRAC", "0.7"))
    )
    warmup_steps: int = field(default_factory=lambda: int(os.environ.get("WARMUP_STEPS", "0")))
    lr_warmup_steps: int = field(
        default_factory=lambda: int(os.environ.get("LR_WARMUP_STEPS", "500"))
    )
    train_batch_tokens: int = field(
        default_factory=lambda: int(
            os.environ.get("TRAIN_BATCH_TOKENS", str(DEFAULT_BLOCK_SIZE * 512))
        )
    )
    train_seq_len: int = field(init=False)
    eval_seq_len: int = field(init=False)
    grad_accum_steps: int = field(
        default_factory=lambda: int(os.environ.get("GRAD_ACCUM_STEPS", "2"))
    )

    # --- Logging and validation cadence --------------------------------------
    train_log_every: int = field(
        default_factory=lambda: int(os.environ.get("TRAIN_LOG_EVERY", "100"))
    )
    train_log_first_steps: int = field(
        default_factory=lambda: int(os.environ.get("TRAIN_LOG_FIRST_STEPS", "500"))
    )
    val_batch_tokens: int = field(
        default_factory=lambda: int(
            os.environ.get("VAL_BATCH_TOKENS", str(DEFAULT_BLOCK_SIZE * 256))
        )
    )
    val_loss_every: int = field(
        default_factory=lambda: int(os.environ.get("VAL_LOSS_EVERY", "1000"))
    )
    val_max_tokens: int = field(
        default_factory=lambda: int(os.environ.get("VAL_MAX_TOKENS", "10_000_000"))
    )

    # --- Model shape (the second variable in each lookup is a legacy alias) ---
    vocab_size: int = field(
        default_factory=lambda: int(os.environ.get("VOCAB_SIZE", str(DEFAULT_VOCAB_SIZE)))
    )
    hidden_size: int = field(
        default_factory=lambda: int(
            os.environ.get(
                "HIDDEN_SIZE", os.environ.get("MODEL_DIM", str(DEFAULT_HIDDEN_SIZE))
            )
        )
    )
    num_hidden_layers: int = field(
        default_factory=lambda: int(
            os.environ.get(
                "NUM_HIDDEN_LAYERS",
                os.environ.get("NUM_LAYERS", str(DEFAULT_NUM_HIDDEN_LAYERS)),
            )
        )
    )
    num_attention_heads: int = field(
        default_factory=lambda: int(
            os.environ.get(
                "NUM_ATTENTION_HEADS",
                os.environ.get("NUM_HEADS", str(DEFAULT_NUM_ATTENTION_HEADS)),
            )
        )
    )
    num_key_value_heads: int = field(
        default_factory=lambda: int(
            os.environ.get(
                "NUM_KEY_VALUE_HEADS",
                os.environ.get("NUM_KV_HEADS", str(DEFAULT_NUM_KEY_VALUE_HEADS)),
            )
        )
    )
    head_dim: int = field(init=False)
    intermediate_size: int = field(
        default_factory=lambda: int(
            os.environ.get(
                "INTERMEDIATE_SIZE",
                os.environ.get("INTERMEDIATE", str(DEFAULT_INTERMEDIATE_SIZE)),
            )
        )
    )
    block_size: int = field(
        default_factory=lambda: int(os.environ.get("BLOCK_SIZE", str(DEFAULT_BLOCK_SIZE)))
    )
    rope_theta: float = field(
        default_factory=lambda: float(
            os.environ.get("ROPE_THETA", os.environ.get("ROPE_BASE", str(DEFAULT_ROPE_THETA)))
        )
    )
    rms_norm_eps: float = field(
        default_factory=lambda: float(os.environ.get("RMS_NORM_EPS", "1e-6"))
    )
    xsa_projection: bool = field(default_factory=lambda: _bool_env("XSA_PROJECTION", True))
    tie_word_embeddings: bool = field(
        default_factory=lambda: _bool_env(
            "TIE_WORD_EMBEDDINGS", _bool_env("TIE_EMBEDDINGS", True)
        )
    )

    # --- Optimizer -----------------------------------------------------------
    min_lr: float = field(default_factory=lambda: float(os.environ.get("MIN_LR", "0.0")))
    lr: float = field(default_factory=lambda: float(os.environ.get("LR", "0.004")))
    beta1: float = field(default_factory=lambda: float(os.environ.get("BETA1", "0.9")))
    beta2: float = field(default_factory=lambda: float(os.environ.get("BETA2", "0.95")))
    adam_eps: float = field(default_factory=lambda: float(os.environ.get("ADAM_EPS", "1e-8")))
    weight_decay: float = field(
        default_factory=lambda: float(os.environ.get("WEIGHT_DECAY", "0.005"))
    )

    # --- Runtime ---------------------------------------------------------------
    compile_model: bool = field(default_factory=lambda: _bool_env("COMPILE_MODEL", True))
    autocast: bool = field(default_factory=lambda: _bool_env("AUTOCAST", True))
    bf16: bool = field(default_factory=lambda: _bool_env("BF16", True))
    device: str = field(default_factory=lambda: os.environ.get("DEVICE", "auto"))

    # --- Outputs ---------------------------------------------------------------
    output_dir: str = field(default_factory=lambda: _path_env("OUTPUT_DIR", "outputs"))
    checkpoint_name: str = field(
        default_factory=lambda: os.environ.get("CHECKPOINT_NAME", "final_model")
    )
    logfile: str = field(init=False)
    model_path: str = field(init=False)
    # Always overwritten in __post_init__ with ``rank == 0``.
    is_main_process: bool = True
    train_files: str = field(init=False)
    val_files: str = field(init=False)

    def __post_init__(self) -> None:
        """Compute the ``init=False`` fields and validate the token budget.

        Raises:
            ValueError: See the class docstring for the full list of checks.
        """
        self.rank = int(os.environ.get("RANK", "0"))
        if self.rank < 0:
            raise ValueError("RANK must be non-negative")
        self.is_main_process = self.rank == 0

        # Validate before dividing: a zero head count used to raise a bare
        # ZeroDivisionError here, even when HEAD_DIM was supplied explicitly.
        if self.num_attention_heads <= 0:
            raise ValueError("NUM_ATTENTION_HEADS must be positive")
        derived_head_dim = self.hidden_size // self.num_attention_heads
        self.head_dim = int(os.environ.get("HEAD_DIM", str(derived_head_dim)))

        self.train_seq_len = int(os.environ.get("TRAIN_SEQ_LEN", str(self.block_size)))
        # train_seq_len already reflects TRAIN_SEQ_LEN, so it is the only
        # fallback needed for EVAL_SEQ_LEN.
        self.eval_seq_len = int(os.environ.get("EVAL_SEQ_LEN", str(self.train_seq_len)))

        # These three values are used as divisors below; non-positive values
        # would either crash with ZeroDivisionError or pass the modulo checks
        # silently (negative numbers).
        if self.train_seq_len <= 0:
            raise ValueError("TRAIN_SEQ_LEN must be positive")
        if self.grad_accum_steps <= 0:
            raise ValueError("GRAD_ACCUM_STEPS must be positive")
        if self.train_batch_tokens <= 0:
            raise ValueError("TRAIN_BATCH_TOKENS must be positive")

        token_alignment = self.train_seq_len * self.grad_accum_steps
        if self.train_batch_tokens % token_alignment != 0:
            raise ValueError(
                "TRAIN_BATCH_TOKENS must be divisible by TRAIN_SEQ_LEN * GRAD_ACCUM_STEPS"
            )

        # TRAIN_TOKENS unset or "0" means: derive the budget from ITERATIONS.
        requested_train_tokens = int(os.environ.get("TRAIN_TOKENS", "0"))
        self.requested_train_tokens = (
            requested_train_tokens or self.iterations * self.train_batch_tokens
        )
        if self.requested_train_tokens <= 0:
            raise ValueError("TRAIN_TOKENS must be positive")
        # Round the budget down to a whole number of alignment units.
        self.train_tokens = self.requested_train_tokens - (
            self.requested_train_tokens % token_alignment
        )
        if self.train_tokens <= 0:
            raise ValueError(
                "TRAIN_TOKENS must provide at least TRAIN_SEQ_LEN * GRAD_ACCUM_STEPS tokens"
            )
        # Exact integer ceiling division (no float rounding for large budgets).
        self.iterations = -(-self.train_tokens // self.train_batch_tokens)

        tokenized = Path(self.tokenized_dir)
        self.train_files = os.environ.get("TRAIN_FILES", str(tokenized / "*" / "shard_*.bin"))
        # train_files already reflects TRAIN_FILES, so validation falls back to it.
        self.val_files = os.environ.get("VAL_FILES", self.train_files)

        # NOTE(review): only the MIX_WEIGHTS_PATH environment variable counts as
        # "explicit" here; a mix_weights_path passed to the constructor does not
        # suppress curriculum auto-discovery. Confirm this is intended.
        explicit_legacy_mix = bool(os.environ.get("MIX_WEIGHTS_PATH"))
        if not self.curriculum_path and not explicit_legacy_mix:
            # Prefer a curriculum stored next to the tokenized data, then one
            # in the current working directory.
            tokenized_curriculum = tokenized / "curriculum.json"
            default_curriculum = Path("pretraining_curriculum.json")
            if tokenized_curriculum.exists():
                self.curriculum_path = str(tokenized_curriculum)
            elif default_curriculum.exists():
                self.curriculum_path = str(default_curriculum)
        if not self.mix_weights_path and not self.curriculum_path:
            mix_weights_path = tokenized / "mix_weights.json"
            self.mix_weights_path = str(mix_weights_path) if mix_weights_path.exists() else ""

        if not self.tokenizer_path:
            # Use tokenizer.json when present, otherwise the directory itself.
            tok_dir = Path(self.tokenizer_dir)
            json_path = tok_dir / "tokenizer.json"
            self.tokenizer_path = str(json_path if json_path.exists() else tok_dir)

        out = Path(self.output_dir)
        self.logfile = os.environ.get("LOGFILE", str(out / "logs" / f"{self.run_id}.txt"))
        self.model_path = os.environ.get("MODEL_PATH", str(out / self.checkpoint_name))

    def to_model_config(self) -> GPTConfig:
        """Build a ``GPTConfig`` from the model-shape fields.

        ``labels_are_shifted`` is always set to ``True`` here, regardless of
        the ``GPTConfig`` default.
        """
        return GPTConfig(
            vocab_size=self.vocab_size,
            hidden_size=self.hidden_size,
            num_hidden_layers=self.num_hidden_layers,
            num_attention_heads=self.num_attention_heads,
            num_key_value_heads=self.num_key_value_heads,
            head_dim=self.head_dim,
            intermediate_size=self.intermediate_size,
            block_size=self.block_size,
            rope_theta=self.rope_theta,
            rms_norm_eps=self.rms_norm_eps,
            xsa_projection=self.xsa_projection,
            tie_word_embeddings=self.tie_word_embeddings,
            labels_are_shifted=True,
        )