File size: 1,565 Bytes
a2ce935
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
from dataclasses import dataclass


@dataclass
class ModelConfig:
    """Architecture hyperparameters for the transformer model.

    Defaults describe a 10-layer, 768-wide model with 12 attention heads of
    size 64, a SwiGLU feed-forward width of 2048 and tied embeddings.

    Inconsistent settings are rejected at construction time with a
    ``ValueError`` rather than surfacing later as a shape error inside the
    model. Note that validation runs only in ``__init__``; mutating fields
    afterwards is not re-checked.
    """

    vocab_size: int = 50304  # GPT-2 vocab rounded up to a multiple of 64 for better matmul shapes
    d_model: int = 768
    n_layers: int = 10
    n_heads: int = 12        # head_dim = d_model // n_heads = 64
    d_ff: int = 2048         # canonical SwiGLU 8/3 * d_model
    seq_len: int = 2048
    dropout: float = 0.0
    rope_theta: float = 10000.0  # RoPE base frequency
    tie_embeddings: bool = True
    norm_eps: float = 1e-5

    def __post_init__(self) -> None:
        """Validate the configuration.

        Raises:
            ValueError: if a size is non-positive, ``d_model`` is not divisible
                by ``n_heads``, the head dimension is odd, ``dropout`` is
                outside [0, 1], or ``rope_theta`` / ``norm_eps`` is not positive.
        """
        for name in ("vocab_size", "d_model", "n_layers", "n_heads", "d_ff", "seq_len"):
            value = getattr(self, name)
            if value <= 0:
                raise ValueError(f"{name} must be positive, got {value!r}")
        # Checked after the positivity loop so n_heads == 0 cannot raise ZeroDivisionError.
        if self.d_model % self.n_heads != 0:
            raise ValueError(
                f"d_model ({self.d_model}) must be divisible by n_heads ({self.n_heads})"
            )
        # Rotary embeddings rotate channels in pairs, so each head needs an even width.
        if self.head_dim % 2 != 0:
            raise ValueError(f"head_dim ({self.head_dim}) must be even for rotary embeddings")
        if not 0.0 <= self.dropout <= 1.0:
            raise ValueError(f"dropout must be in [0, 1], got {self.dropout!r}")
        if self.rope_theta <= 0:
            raise ValueError(f"rope_theta must be positive, got {self.rope_theta!r}")
        if self.norm_eps <= 0:
            raise ValueError(f"norm_eps must be positive, got {self.norm_eps!r}")

    @property
    def head_dim(self) -> int:
        """Per-head attention width (``d_model // n_heads``)."""
        return self.d_model // self.n_heads


@dataclass
class TrainConfig:
    """Single source of configuration for a training run.

    Groups paths, the model architecture (mirroring every ``ModelConfig``
    field), the token budget, optimizer and LR-schedule settings,
    checkpoint/eval cadence, logging and system options. Budgets and cadences
    are expressed in tokens; ``tokens_per_step`` gives the number of tokens
    one optimizer step consumes.

    Clearly invalid values raise ``ValueError`` at construction time.
    Validation runs only in ``__init__``; later mutation is not re-checked.
    """

    # Paths
    data_dir: str = "data"
    out_dir: str = "checkpoints"

    # Model (mirrors every ModelConfig field so a single dataclass configures runs)
    vocab_size: int = 50304
    d_model: int = 768
    n_layers: int = 10
    n_heads: int = 12
    d_ff: int = 2048
    seq_len: int = 2048
    dropout: float = 0.0
    rope_theta: float = 10000.0
    tie_embeddings: bool = True
    norm_eps: float = 1e-5

    # Training budget
    target_tokens: int = 1_000_000_000  # ≈ 15_259 optimizer steps at the default 65_536 tok/step
    # At seq_len=2048, a ~100M-parameter model needs small microbatches to fit
    # in memory; gradient accumulation restores the effective batch to
    # batch_size × grad_accum_steps = 32 sequences × 2048 = 65_536 tokens/step.
    batch_size: int = 4
    grad_accum_steps: int = 8

    # Optimizer / schedule
    learning_rate: float = 6e-4
    min_lr_ratio: float = 0.1
    warmup_tokens: int = 3_000_000  # ≈ 46 optimizer steps at the default 65_536 tok/step
    weight_decay: float = 0.1
    beta1: float = 0.9
    beta2: float = 0.95
    grad_clip: float = 1.0

    # Checkpoint / eval cadence (in tokens)
    ckpt_every_tokens: int = 100_000_000
    eval_every_tokens: int = 6_000_000
    eval_batches: int = 50

    # Logging
    log_every_steps: int = 10

    # System
    device: str = "mps"
    dtype: str = "bfloat16"
    seed: int = 1337

    def __post_init__(self) -> None:
        """Validate settings whose invalid ranges are unambiguous.

        Cadence fields (``*_every_*``, ``eval_batches``) and ``grad_clip`` are
        deliberately not checked: whether 0 means "disabled" is decided by the
        training loop, not here.

        Raises:
            ValueError: on non-positive model/batch sizes, ``d_model`` not
                divisible by ``n_heads``, an odd head dimension, or optimizer
                values outside their valid ranges.
        """
        for name in (
            "vocab_size", "d_model", "n_layers", "n_heads", "d_ff", "seq_len",
            "batch_size", "grad_accum_steps",
        ):
            value = getattr(self, name)
            if value <= 0:
                raise ValueError(f"{name} must be positive, got {value!r}")
        # Checked after the positivity loop so n_heads == 0 cannot raise ZeroDivisionError.
        if self.d_model % self.n_heads != 0:
            raise ValueError(
                f"d_model ({self.d_model}) must be divisible by n_heads ({self.n_heads})"
            )
        head_dim = self.d_model // self.n_heads
        # Rotary embeddings rotate channels in pairs, so each head needs an even width.
        if head_dim % 2 != 0:
            raise ValueError(f"head_dim ({head_dim}) must be even for rotary embeddings")
        if not 0.0 <= self.dropout <= 1.0:
            raise ValueError(f"dropout must be in [0, 1], got {self.dropout!r}")
        if self.rope_theta <= 0:
            raise ValueError(f"rope_theta must be positive, got {self.rope_theta!r}")
        if self.norm_eps <= 0:
            raise ValueError(f"norm_eps must be positive, got {self.norm_eps!r}")
        if self.learning_rate < 0:
            raise ValueError(f"learning_rate must be non-negative, got {self.learning_rate!r}")
        if not 0.0 <= self.min_lr_ratio <= 1.0:
            raise ValueError(f"min_lr_ratio must be in [0, 1], got {self.min_lr_ratio!r}")
        if self.warmup_tokens < 0:
            raise ValueError(f"warmup_tokens must be non-negative, got {self.warmup_tokens!r}")
        if self.weight_decay < 0:
            raise ValueError(f"weight_decay must be non-negative, got {self.weight_decay!r}")
        for name in ("beta1", "beta2"):
            value = getattr(self, name)
            if not 0.0 <= value < 1.0:
                raise ValueError(f"{name} must be in [0, 1), got {value!r}")

    @property
    def tokens_per_step(self) -> int:
        """Tokens per optimizer step: ``batch_size * grad_accum_steps * seq_len``."""
        return self.batch_size * self.grad_accum_steps * self.seq_len