from dataclasses import dataclass


@dataclass
class ModelConfig:
    """Architecture hyperparameters for the model.

    Every field has a default, so ``ModelConfig()`` yields the reference
    configuration. Field order is part of the positional ``__init__``
    signature and must not be rearranged.
    """

    # cl100k_base (tiktoken); must match tokenizer.VOCAB_SIZE
    vocab_size: int = 100277
    dim: int = 1024
    n_layers: int = 24
    n_heads: int = 16
    # GQA: fewer KV heads than Q heads
    n_kv_heads: int = 8
    # Roughly 8/3.
    ffn_dim_multiplier: float = 2.6667
    # NOTE(review): TrainConfig carries its own max_seq_len with the same
    # default; nothing in this file keeps the two in sync.
    max_seq_len: int = 1024
    rope_theta: float = 10000.0
    norm_eps: float = 1e-5
    dropout: float = 0.0


@dataclass
class TrainConfig:
    """Settings for a training run.

    Groups data, batching, optimiser, schedule, evaluation, logging,
    precision, checkpointing, dry-run and seeding options. Field order is
    part of the positional ``__init__`` signature and is inherited by
    ``SFTConfig``; do not rearrange it.
    """

    # ---- Data ----
    # Local path after save_to_disk.
    dataset_path: str = "data/fineweb-edu-10BT"
    dataset_name_field: str = "text"
    target_tokens: int = 8_000_000_000
    val_fraction: float = 0.001
    split_seed: int = 1337
    use_packed_data: bool = True
    prepare_packed_data: bool = False

    # ---- Batch / accumulation — tuned for RTX 4070 8 GB ----
    batch_size: int = 1
    # With the defaults below, the effective batch is
    # batch_size * grad_accum_steps * max_seq_len
    #   = 1 * 320 * 1024 = 327,680 tokens/step.
    # NOTE(review): the previous comment quoted 256 accumulation steps
    # (262,144 tokens/step); confirm which figure is intended.
    grad_accum_steps: int = 320
    max_seq_len: int = 1024
    num_workers: int = 2

    # ---- Optimiser ----
    # "adamw" or "adamw8bit" (bitsandbytes)
    optimizer: str = "adamw8bit"
    fused_adamw: bool = True
    lr: float = 3e-4
    lr_min: float = 1e-5
    weight_decay: float = 0.1
    beta1: float = 0.9
    beta2: float = 0.95
    grad_clip: float = 1.0

    # ---- Schedule ----
    warmup_steps: int = 2000
    # Computed from target_tokens at runtime; 0 is only a placeholder.
    max_steps: int = 0
    # Final 10% of training for decay (Chinchilla-style).
    decay_fraction: float = 0.1

    # ---- Evaluation ----
    # Eval every N steps.
    eval_interval: int = 500
    # Number of batches for eval.
    eval_steps: int = 100

    # ---- Metrics logging ----
    metrics_csv_path: str = "train_metrics.csv"
    eval_csv_path: str = "eval.csv"
    metrics_interval: int = 500
    attention_entropy_probe_len: int = 256
    # Set >0 for MFU; 0 disables MFU.
    gpu_peak_tflops: float = 0.0

    # ---- Precision & performance ----
    dtype: str = "bfloat16"
    # NOTE(review): the previous comment read "re-enabled: speedup worth the
    # VRAM", yet the default is False; confirm the intended default.
    compile_model: bool = False
    grad_checkpointing: bool = True

    # ---- Checkpointing ----
    ckpt_dir: str = "checkpoints"
    ckpt_interval: int = 1000
    # Keep only last N checkpoints; <=0 keeps all.
    ckpt_keep_last: int = 3
    log_interval: int = 10
    # Print progress inside grad accumulation.
    micro_log_interval: int = 32
    save_rng_state: bool = True
    save_loader_state: bool = True

    # ---- Dry-run: verify VRAM before committing to full training ----
    dry_run: bool = False
    dry_run_steps: int = 2

    # ---- Reproducibility ----
    seed: int = 1337


@dataclass
class SFTConfig(TrainConfig):
    """TrainConfig variant with overrides and extra fields for SFT runs.

    Overridden fields keep their inherited position in ``__init__``; the
    new fields (``dataset_paths``, ``lr_start``, ``warmup_fraction``,
    ``smoltalk_max_rows``) are appended after the inherited ones in the
    order they are declared here.
    """

    # New field: a tuple of dataset locations (immutable, so it is safe as a
    # dataclass default).
    dataset_paths: tuple = ("data/smol-smoltalk",)
    ckpt_dir: str = "sft_checkpoints"
    # New field.
    lr_start: float = 1.5e-4
    # Same value as the inherited default; restated explicitly.
    lr: float = 3e-4
    lr_min: float = 3e-5
    # New field. NOTE(review): the inherited warmup_steps is still present;
    # which of the two the SFT schedule reads is not visible here.
    warmup_fraction: float = 0.08
    decay_fraction: float = 0.16
    # New field; presumably a cap on rows taken from smoltalk — verify
    # against the data loader.
    smoltalk_max_rows: int = 200_000
    grad_accum_steps: int = 312
    metrics_csv_path: str = "SFT_metrics.csv"
    metrics_interval: int = 10
    eval_interval: int = 10