File size: 3,768 Bytes
7f974df
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
"""
model/config.py

ModelConfig dataclass + preset configs for SLLM-100M and SLLM-150M.
All hyperparameters live here so every other module imports from one place.
"""

from dataclasses import dataclass, field


def _swiglu_d_ff(d_model: int) -> int:
    """
    SwiGLU hidden dimension.
    LLaMA formula: round_up_256( int(2/3 * 4 * d_model) )
    """
    raw = int(2 / 3 * 4 * d_model)
    return ((raw + 255) // 256) * 256          # round up to nearest 256


@dataclass
class ModelConfig:
    """
    Hyperparameters for the transformer model.

    ``d_ff`` may be left at 0. In that case ``__post_init__`` replaces it
    with the SwiGLU width derived from ``d_model``.

    The resolved value is written back into the field. Consequently,
    ``dataclasses.replace()`` on an existing config carries that ``d_ff``
    forward; pass ``d_ff=0`` explicitly to re-derive it from a new
    ``d_model``.

    Raises:
        ValueError: From ``__post_init__``, when any of the following holds:
            a size field is not positive; ``d_ff`` is negative; ``dropout``
            lies outside [0, 1]; ``rope_theta`` is not positive; or
            ``d_model`` is not divisible by ``n_heads``.
    """

    # ---- Vocabulary ------------------------------------------------- #
    vocab_size: int     = 32_000               # must match trained tokenizer

    # ---- Sequence --------------------------------------------------- #
    context_length: int = 1024                 # max tokens per sequence

    # ---- Transformer dimensions ------------------------------------- #
    d_model: int        = 768                  # embedding / hidden dim
    n_heads: int        = 12                   # number of attention heads
    n_layers: int       = 12                   # number of transformer blocks

    # ---- FFN -------------------------------------------------------- #
    # SwiGLU d_ff is auto-computed from d_model if not set explicitly
    d_ff: int           = 0                    # 0 = auto

    # ---- Regularization --------------------------------------------- #
    dropout: float      = 0.0                  # 0.0 for pre-training

    # ---- Misc ------------------------------------------------------- #
    bias: bool          = False                # no bias (cleaner, matches LLaMA)
    rope_theta: float   = 10_000.0             # RoPE base frequency

    def __post_init__(self) -> None:
        """Validate the fields, then derive ``d_ff`` when it is 0."""
        # Raise explicitly rather than ``assert``: asserts are stripped under
        # ``python -O`` and would let an invalid config through silently.
        # Validation runs first so that, e.g., n_heads=0 is reported as a
        # config error rather than a ZeroDivisionError.
        for name in ("vocab_size", "context_length", "d_model", "n_heads", "n_layers"):
            value = getattr(self, name)
            if value <= 0:
                raise ValueError(f"{name} must be positive, got {value}")
        if self.d_ff < 0:
            raise ValueError(f"d_ff must be >= 0 (0 = auto), got {self.d_ff}")
        if not 0.0 <= self.dropout <= 1.0:
            raise ValueError(f"dropout must be in [0, 1], got {self.dropout}")
        if not self.rope_theta > 0:
            raise ValueError(f"rope_theta must be positive, got {self.rope_theta}")
        if self.d_model % self.n_heads != 0:
            raise ValueError(
                f"d_model ({self.d_model}) must be divisible by n_heads ({self.n_heads})"
            )

        # Auto-compute d_ff if not set
        if self.d_ff == 0:
            self.d_ff = _swiglu_d_ff(self.d_model)

    @property
    def head_dim(self) -> int:
        """Per-head width: d_model split evenly across n_heads."""
        return self.d_model // self.n_heads

    def count_params(self) -> int:
        """
        Return the total trainable parameter count (with tied embeddings).

        The count has the following parts:
          * Per block: Q/K/V/O projections (4 * d_model^2), the three SwiGLU
            matrices (3 * d_model * d_ff), and two norm weight vectors.
          * The token embedding, counted once because it is tied to the
            output head.
          * A final norm.

        No positional-embedding table is counted.

        NOTE(review): bias terms are not counted, so the figure assumes
        ``bias=False``. Confirm against the model code before relying on it
        with ``bias=True``.
        """
        embed      = self.vocab_size * self.d_model
        attn       = 4 * self.d_model * self.d_model   # Q, K, V, O
        mlp        = 3 * self.d_model * self.d_ff      # gate, up, down
        norms      = 2 * self.d_model                  # pre-attn + pre-mlp
        per_block  = attn + mlp + norms
        final_norm = self.d_model
        return embed + self.n_layers * per_block + final_norm

    def __repr__(self) -> str:
        """Compact one-line summary including the parameter count in millions."""
        n = self.count_params()
        return (
            f"ModelConfig("
            f"d={self.d_model}, h={self.n_heads}, l={self.n_layers}, "
            f"ff={self.d_ff}, ctx={self.context_length}, "
            f"params={n/1e6:.1f}M)"
        )


# ------------------------------------------------------------------ #
#  PRESET CONFIGS
# ------------------------------------------------------------------ #

# Both presets share the tokenizer vocabulary and context window.
# Neither passes d_ff, so it stays at its default of 0 and ModelConfig
# derives the SwiGLU width from d_model.
SLLM_100M = ModelConfig(
    d_model=768,            # head_dim = 768 / 12 = 64
    n_heads=12,
    n_layers=12,
    context_length=1024,
    vocab_size=32_000,
)  # derived d_ff -> 2048

SLLM_150M = ModelConfig(
    d_model=1024,           # head_dim = 1024 / 16 = 64
    n_heads=16,
    n_layers=9,
    context_length=1024,
    vocab_size=32_000,
)  # derived d_ff -> 2816


# ------------------------------------------------------------------ #
#  QUICK CHECK
# ------------------------------------------------------------------ #

if __name__ == "__main__":
    # Dump each preset plus its derived sizes so they can be eyeballed.
    # The trailing "" entry yields the blank separator line between presets.
    for preset in (SLLM_100M, SLLM_150M):
        print(
            preset,
            f"  head_dim : {preset.head_dim}",
            f"  d_ff     : {preset.d_ff}",
            "",
            sep="\n",
        )