"""
model/config.py

ModelConfig dataclass + preset configs for SLLM-100M and SLLM-150M.
All hyperparameters live here so every other module imports from one place.
"""

from dataclasses import dataclass, field


class ConfigError(ValueError, AssertionError):
    """Raised when a ModelConfig is built from inconsistent hyperparameters.

    Inherits from ``AssertionError`` as well as ``ValueError`` so that code
    written against the earlier ``assert``-based check keeps catching it.
    """


def _swiglu_d_ff(d_model: int, multiple_of: int = 256) -> int:
    """Return the SwiGLU hidden dimension for a given model width.

    LLaMA formula: round_up_256( int(2/3 * 4 * d_model) )

    The 2/3 * 4 factor is evaluated as the exact integer expression
    ``8 * d_model // 3`` so the result never depends on float rounding.

    Args:
        d_model: Embedding / hidden dimension of the model.
        multiple_of: Granularity the result is rounded up to (default 256).

    Returns:
        The smallest multiple of ``multiple_of`` that is >= floor(8/3 * d_model).

    Raises:
        ValueError: If ``multiple_of`` is not positive.
    """
    if multiple_of <= 0:
        raise ValueError(f"multiple_of ({multiple_of}) must be positive")
    raw = (8 * d_model) // 3
    # Round up to the nearest multiple.
    return ((raw + multiple_of - 1) // multiple_of) * multiple_of


@dataclass
class ModelConfig:
    """Hyperparameters of one model.

    ``d_ff`` may be left at 0, in which case it is derived from ``d_model``
    in ``__post_init__``. The instance is validated at construction time.
    """

    # ---- Vocabulary ------------------------------------------------- #
    vocab_size: int = 32_000        # must match trained tokenizer

    # ---- Sequence --------------------------------------------------- #
    context_length: int = 1024      # max tokens per sequence

    # ---- Transformer dimensions ------------------------------------- #
    d_model: int = 768              # embedding / hidden dim
    n_heads: int = 12               # number of attention heads
    n_layers: int = 12              # number of transformer blocks

    # ---- FFN -------------------------------------------------------- #
    # SwiGLU d_ff is auto-computed from d_model if not set explicitly
    d_ff: int = 0                   # 0 = auto

    # ---- Regularization --------------------------------------------- #
    dropout: float = 0.0            # 0.0 for pre-training

    # ---- Misc ------------------------------------------------------- #
    bias: bool = False              # no bias (cleaner, matches LLaMA)
    rope_theta: float = 10_000.0    # RoPE base frequency

    def __post_init__(self):
        """Validate the dimensions and fill in ``d_ff`` when it is 0.

        Raises:
            ConfigError: If ``d_model`` or ``n_heads`` is not positive, if
                ``d_ff`` is negative, or if ``d_model`` is not divisible by
                ``n_heads``.
        """
        # Explicit raises instead of `assert`: assertions are stripped when
        # Python runs with -O, which would silently disable these checks.
        if self.d_model <= 0:
            raise ConfigError(f"d_model ({self.d_model}) must be positive")
        if self.n_heads <= 0:
            raise ConfigError(f"n_heads ({self.n_heads}) must be positive")
        if self.d_ff < 0:
            raise ConfigError(
                f"d_ff ({self.d_ff}) must be non-negative (0 = auto)"
            )
        if self.d_model % self.n_heads != 0:
            raise ConfigError(
                f"d_model ({self.d_model}) must be divisible by n_heads ({self.n_heads})"
            )

        # Auto-compute d_ff if not set
        if self.d_ff == 0:
            self.d_ff = _swiglu_d_ff(self.d_model)

    @property
    def head_dim(self) -> int:
        """Per-head dimension, ``d_model // n_heads``."""
        return self.d_model // self.n_heads

    def count_params(self) -> int:
        """Returns total trainable parameter count (with tied embeddings).

        The embedding matrix is counted once. Bias terms are not included,
        so the figure corresponds to ``bias=False``.
        """
        embed = self.vocab_size * self.d_model
        attn = 4 * self.d_model * self.d_model      # Q, K, V, O
        mlp = 3 * self.d_model * self.d_ff          # gate, up, down
        norms = 2 * self.d_model                    # pre-attn + pre-mlp
        per_block = attn + mlp + norms
        final_norm = self.d_model
        return embed + self.n_layers * per_block + final_norm

    def __repr__(self) -> str:
        """Compact one-line summary including the parameter count in millions."""
        n = self.count_params()
        return (
            f"ModelConfig("
            f"d={self.d_model}, h={self.n_heads}, l={self.n_layers}, "
            f"ff={self.d_ff}, ctx={self.context_length}, "
            f"params={n/1e6:.1f}M)"
        )


# ------------------------------------------------------------------ #
#  PRESET CONFIGS
# ------------------------------------------------------------------ #

SLLM_100M = ModelConfig(
    vocab_size=32_000,
    context_length=1024,
    d_model=768,
    n_heads=12,
    n_layers=12,
    # d_ff auto = 2048
)

SLLM_150M = ModelConfig(
    vocab_size=32_000,
    context_length=1024,
    d_model=1024,
    n_heads=16,
    n_layers=9,
    # d_ff auto = 2816
)


# ------------------------------------------------------------------ #
#  QUICK CHECK
# ------------------------------------------------------------------ #

if __name__ == "__main__":
    for cfg in [SLLM_100M, SLLM_150M]:
        print(cfg)
        print(f" head_dim : {cfg.head_dim}")
        print(f" d_ff : {cfg.d_ff}")
        print()