| """ |
| model/config.py |
| |
| ModelConfig dataclass + preset configs for SLLM-100M and SLLM-150M. |
| All hyperparameters live here so every other module imports from one place. |
| """ |
|
|
| from dataclasses import dataclass, field |
|
|
|
|
| def _swiglu_d_ff(d_model: int) -> int: |
| """ |
| SwiGLU hidden dimension. |
| LLaMA formula: round_up_256( int(2/3 * 4 * d_model) ) |
| """ |
| raw = int(2 / 3 * 4 * d_model) |
| return ((raw + 255) // 256) * 256 |
|
|
|
|
@dataclass
class ModelConfig:
    """Hyperparameters describing one model size.

    ``d_ff`` may be left at 0, in which case ``__post_init__`` derives it
    from ``d_model`` via ``_swiglu_d_ff``. Construction fails if ``d_model``
    cannot be split evenly across ``n_heads``.
    """

    # Number of embedding rows (embed = vocab_size * d_model in count_params).
    vocab_size: int = 32_000

    # Only reported in __repr__ here; presumably the maximum sequence
    # length -- confirm against the model code.
    context_length: int = 1024

    # Model width, attention head count, and number of blocks.
    d_model: int = 768
    n_heads: int = 12
    n_layers: int = 12

    # Feed-forward hidden width. 0 is a sentinel meaning "derive it from
    # d_model in __post_init__".
    d_ff: int = 0

    # Not read anywhere in this class; presumably consumed by the model
    # modules -- confirm there.
    dropout: float = 0.0

    # Not read anywhere in this class (count_params ignores `bias`).
    bias: bool = False
    rope_theta: float = 10_000.0

    def __post_init__(self):
        """Fill in the derived ``d_ff`` and validate the head split.

        Raises:
            AssertionError: if ``d_model`` is not a multiple of ``n_heads``.
        """
        if self.d_ff == 0:
            self.d_ff = _swiglu_d_ff(self.d_model)

        # Raised explicitly instead of via an `assert` statement so the check
        # is not stripped under `python -O`. The exception type and message
        # are unchanged for callers that already rely on them.
        if self.d_model % self.n_heads != 0:
            raise AssertionError(
                f"d_model ({self.d_model}) must be divisible by n_heads ({self.n_heads})"
            )

    @property
    def head_dim(self) -> int:
        """Per-head width: ``d_model // n_heads``."""
        return self.d_model // self.n_heads

    def count_params(self) -> int:
        """Return the parameter count implied by this config.

        The embedding table is counted once (tied embeddings). Each block
        contributes four ``d_model x d_model`` attention matrices, three
        ``d_model x d_ff`` MLP matrices, and two norm vectors of size
        ``d_model``; one more norm vector is added after the last block.
        No bias terms are counted: the formula does not depend on ``bias``.
        """
        embed = self.vocab_size * self.d_model
        attn = 4 * self.d_model * self.d_model
        mlp = 3 * self.d_model * self.d_ff
        norms = 2 * self.d_model
        per_block = attn + mlp + norms
        final_norm = self.d_model
        return embed + self.n_layers * per_block + final_norm

    def __repr__(self) -> str:
        """Compact summary with the parameter count in millions.

        Replaces the dataclass-generated repr, so not every field is shown.
        """
        n = self.count_params()
        return (
            f"ModelConfig("
            f"d={self.d_model}, h={self.n_heads}, l={self.n_layers}, "
            f"ff={self.d_ff}, ctx={self.context_length}, "
            f"params={n/1e6:.1f}M)"
        )
|
|
|
|
| |
| |
| |
|
|
# SLLM-100M preset. d_ff is not passed, so ModelConfig.__post_init__ derives
# it from d_model (2048 for d_model=768); count_params() gives about 109.5M.
SLLM_100M = ModelConfig(
    vocab_size=32_000,
    context_length=1024,
    d_model=768,
    n_heads=12,
    n_layers=12,
)
|
|
# SLLM-150M preset: wider and shallower than SLLM_100M. d_ff is not passed,
# so ModelConfig.__post_init__ derives it from d_model (2816 for
# d_model=1024); count_params() gives about 148.4M.
SLLM_150M = ModelConfig(
    vocab_size=32_000,
    context_length=1024,
    d_model=1024,
    n_heads=16,
    n_layers=9,
)
|
|
|
|
| |
| |
| |
|
|
# Script entry point: print a summary of every preset, followed by its
# derived head_dim and d_ff, with a blank line between presets.
if __name__ == "__main__":
    for preset in (SLLM_100M, SLLM_150M):
        print(preset)
        print(f" head_dim : {preset.head_dim}")
        print(f" d_ff : {preset.d_ff}")
        print()
|
|