from dataclasses import dataclass


@dataclass
class ModelConfig:
    vocab_size: int = 50304      # rounded-up GPT-2 vocab for better matmul shapes
    d_model: int = 768
    n_layers: int = 10
    n_heads: int = 12            # head_dim = 64
    d_ff: int = 2048             # canonical SwiGLU 8/3 * d_model
    seq_len: int = 2048
    dropout: float = 0.0
    rope_theta: float = 10000.0
    tie_embeddings: bool = True
    norm_eps: float = 1e-5


@dataclass
class TrainConfig:
    # Paths
    data_dir: str = "data"
    out_dir: str = "checkpoints"

    # Model (mirrors ModelConfig so a single dataclass configures runs)
    vocab_size: int = 50304
    d_model: int = 768
    n_layers: int = 10
    n_heads: int = 12
    d_ff: int = 2048
    seq_len: int = 2048
    dropout: float = 0.0

    # Training budget
    target_tokens: int = 1_000_000_000

    # Memory at seq=2048 for ~100M params: keep microbatches small and use
    # grad accumulation to keep effective batch = 32 × 2048 = 65_536 tok/step.
    batch_size: int = 4
    grad_accum_steps: int = 8

    # Optimizer / schedule
    learning_rate: float = 6e-4
    min_lr_ratio: float = 0.1
    warmup_tokens: int = 3_000_000
    weight_decay: float = 0.1
    beta1: float = 0.9
    beta2: float = 0.95
    grad_clip: float = 1.0

    # Checkpoint / eval cadence (in tokens)
    ckpt_every_tokens: int = 100_000_000
    eval_every_tokens: int = 6_000_000
    eval_batches: int = 50

    # Logging
    log_every_steps: int = 10

    # System
    device: str = "mps"
    dtype: str = "bfloat16"
    seed: int = 1337
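For reference, these defaults imply 4 × 8 × 2048 = 65,536 tokens per optimizer step, and the schedule fields (learning_rate, min_lr_ratio, warmup_tokens, target_tokens) are enough to drive a token-based schedule. Below is a minimal sketch assuming linear warmup followed by cosine decay to min_lr_ratio × learning_rate; the decay shape is an assumption not fixed by the config itself, and lr_at is a hypothetical helper name.

import math


def lr_at(tokens_seen: int, cfg: TrainConfig) -> float:
    # Hypothetical helper: token-based LR schedule.
    # Assumes linear warmup over warmup_tokens, then cosine decay
    # to min_lr_ratio * learning_rate at target_tokens.
    if tokens_seen < cfg.warmup_tokens:
        return cfg.learning_rate * tokens_seen / cfg.warmup_tokens
    # Progress through the decay phase, clamped to [0, 1].
    progress = (tokens_seen - cfg.warmup_tokens) / max(
        1, cfg.target_tokens - cfg.warmup_tokens
    )
    progress = min(progress, 1.0)
    min_lr = cfg.learning_rate * cfg.min_lr_ratio
    return min_lr + 0.5 * (cfg.learning_rate - min_lr) * (1.0 + math.cos(math.pi * progress))


cfg = TrainConfig()
tokens_per_step = cfg.batch_size * cfg.grad_accum_steps * cfg.seq_len  # 65_536
print(tokens_per_step, lr_at(500_000_000, cfg))

Keeping the schedule token-based rather than step-based means changing batch_size or grad_accum_steps does not silently move where warmup ends or where decay bottoms out.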