File size: 3,768 Bytes
7f974df | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 | """
model/config.py
ModelConfig dataclass + preset configs for SLLM-100M and SLLM-150M.
All hyperparameters live here so every other module imports from one place.
"""
from dataclasses import dataclass, field
def _swiglu_d_ff(d_model: int) -> int:
"""
SwiGLU hidden dimension.
LLaMA formula: round_up_256( int(2/3 * 4 * d_model) )
"""
raw = int(2 / 3 * 4 * d_model)
return ((raw + 255) // 256) * 256 # round up to nearest 256
@dataclass
class ModelConfig:
    """
    Hyperparameters for one SLLM model variant.

    ``d_ff == 0`` means "derive it from ``d_model``": ``__post_init__``
    replaces the 0 with ``_swiglu_d_ff(d_model)``.

    Raises:
        ValueError: on construction, if ``n_heads`` is not positive,
            ``d_ff`` is negative, or ``d_model`` is not divisible by
            ``n_heads``.
    """

    # ---- Vocabulary ------------------------------------------------- #
    vocab_size: int = 32_000  # must match trained tokenizer

    # ---- Sequence --------------------------------------------------- #
    context_length: int = 1024  # max tokens per sequence

    # ---- Transformer dimensions ------------------------------------- #
    d_model: int = 768  # embedding / hidden dim
    n_heads: int = 12  # number of attention heads
    n_layers: int = 12  # number of transformer blocks

    # ---- FFN -------------------------------------------------------- #
    # SwiGLU d_ff is auto-computed from d_model if not set explicitly
    d_ff: int = 0  # 0 = auto

    # ---- Regularization --------------------------------------------- #
    dropout: float = 0.0  # 0.0 for pre-training

    # ---- Misc ------------------------------------------------------- #
    bias: bool = False  # no bias (cleaner, matches LLaMA)
    rope_theta: float = 10_000.0  # RoPE base frequency

    def __post_init__(self) -> None:
        """Resolve the automatic ``d_ff`` and validate the head layout."""
        # Explicit raises instead of ``assert``: assert statements are
        # stripped under ``python -O``, which would let an invalid config
        # through and make ``head_dim`` silently truncate.
        if self.n_heads <= 0:
            # Checked first so the modulo below cannot divide by zero.
            raise ValueError(f"n_heads must be positive, got {self.n_heads}")
        if self.d_ff < 0:
            raise ValueError(f"d_ff must be >= 0 (0 = auto), got {self.d_ff}")

        # Auto-compute d_ff if not set
        if self.d_ff == 0:
            self.d_ff = _swiglu_d_ff(self.d_model)

        if self.d_model % self.n_heads != 0:
            raise ValueError(
                f"d_model ({self.d_model}) must be divisible by n_heads ({self.n_heads})"
            )

    @property
    def head_dim(self) -> int:
        """Width of a single attention head (``d_model // n_heads``)."""
        return self.d_model // self.n_heads

    def count_params(self) -> int:
        """
        Return the total trainable parameter count (with tied embeddings).

        The token embedding (``vocab_size * d_model``) is counted once,
        because embeddings are tied. Each block contributes:

        - the Q/K/V/O projections (``4 * d_model**2``);
        - the SwiGLU gate/up/down projections (``3 * d_model * d_ff``);
        - two norms of ``d_model`` parameters each.

        A final norm adds one more ``d_model``.

        NOTE(review): bias terms are never counted, even when
        ``bias=True``. Confirm against the model implementation before
        relying on this number with biases enabled.
        """
        embed = self.vocab_size * self.d_model
        attn = 4 * self.d_model * self.d_model  # Q, K, V, O
        mlp = 3 * self.d_model * self.d_ff  # gate, up, down
        norms = 2 * self.d_model  # pre-attn + pre-mlp
        per_block = attn + mlp + norms
        final_norm = self.d_model
        return embed + self.n_layers * per_block + final_norm

    def __repr__(self) -> str:
        """Compact one-line summary including the parameter count in millions."""
        n = self.count_params()
        return (
            f"ModelConfig("
            f"d={self.d_model}, h={self.n_heads}, l={self.n_layers}, "
            f"ff={self.d_ff}, ctx={self.context_length}, "
            f"params={n/1e6:.1f}M)"
        )
# ------------------------------------------------------------------ #
# PRESET CONFIGS
# ------------------------------------------------------------------ #
# Both presets share the 32k tokenizer vocabulary and a 1024-token
# window, and leave d_ff unset so ModelConfig derives it from d_model.

# 12 blocks, 768 wide: d_ff resolves to 2048 (~109.5M params per count_params).
SLLM_100M = ModelConfig(
    d_model=768,
    n_heads=12,
    n_layers=12,
    vocab_size=32_000,
    context_length=1024,
)

# 9 blocks, 1024 wide: d_ff resolves to 2816 (~148.4M params per count_params).
SLLM_150M = ModelConfig(
    d_model=1024,
    n_heads=16,
    n_layers=9,
    vocab_size=32_000,
    context_length=1024,
)
# ------------------------------------------------------------------ #
# QUICK CHECK
# ------------------------------------------------------------------ #
if __name__ == "__main__":
    # For each preset: its summary line, the derived head width and FFN
    # width, then a blank separator line (the trailing "" plus sep="\n").
    for preset in (SLLM_100M, SLLM_150M):
        print(
            preset,
            f" head_dim : {preset.head_dim}",
            f" d_ff : {preset.d_ff}",
            "",
            sep="\n",
        )
|