from dataclasses import dataclass, field, asdict


@dataclass
class ModelConfig:
    """Model configuration values, each with a default.

    Field meanings are not established by this file; the inline notes
    below are reviewer assumptions based on the names and should be
    confirmed against the code that consumes this config.
    """

    vocab_size: int = 8192
    n_layer: int = 8
    n_head: int = 8
    n_embd: int = 512
    block_size: int = 1024  # presumably the maximum sequence length -- confirm
    rope_base: float = 10000.0  # presumably the rotary-embedding base -- confirm
    mlp_mult: int = 4  # presumably MLP hidden width as a multiple of n_embd -- confirm
    dropout: float = 0.0
    tie_embeddings: bool = True

    @property
    def head_dim(self) -> int:
        """Return ``n_embd // n_head``.

        Raises:
            AssertionError: if ``n_embd`` is not an exact multiple of
                ``n_head``.
        """
        # NOTE(review): assert statements are removed when Python runs with
        # -O, in which case a non-divisible pair silently floor-divides.
        # Kept as an assert so the exception type seen by callers is unchanged.
        assert self.n_embd % self.n_head == 0, (
            f"n_embd ({self.n_embd}) must be divisible by n_head ({self.n_head})"
        )
        return self.n_embd // self.n_head


@dataclass
class TrainConfig:
    """Training-run configuration values, each with a default.

    The grouping comments below are reviewer assumptions based on the field
    names; this file does not show how the values are consumed.
    """

    # Filesystem locations.
    out_dir: str = "checkpoints"
    data_dir: str = "data"
    tokenizer_path: str = "data/tokenizer.json"

    # Batch sizing and run length.
    batch_size: int = 32
    grad_accum: int = 4
    max_steps: int = 20000

    # Evaluation / logging / checkpoint cadence (presumably in steps -- confirm).
    eval_interval: int = 500
    eval_iters: int = 100
    log_interval: int = 20
    save_interval: int = 2000

    # Optimizer and learning-rate schedule values.
    lr: float = 6e-4
    min_lr: float = 6e-5
    warmup_steps: int = 200
    weight_decay: float = 0.1
    beta1: float = 0.9
    beta2: float = 0.95
    grad_clip: float = 1.0

    # Runtime settings.
    dtype: str = "bfloat16"
    compile: bool = True
    seed: int = 1337
    device: str = "cuda"

    def to_dict(self) -> dict:
        """Return every field of this config as a ``{name: value}`` dict."""
        return asdict(self)