| | """ |
| | LMConfig: configuration dataclass for the LLM model architecture. |
| | """ |
| |
|
from __future__ import annotations

import json
import math
from dataclasses import asdict, dataclass, field
from pathlib import Path
from typing import Optional

import yaml
| |
|
| |
|
| | def _round_to_multiple(n: int, multiple: int) -> int: |
| | """Round n up to the nearest multiple of `multiple`.""" |
| | return math.ceil(n / multiple) * multiple |
| |
|
| |
|
@dataclass
class LMConfig:
    """Configuration for the LLM model architecture.

    Derived fields left as ``None`` (``n_kv_heads``, ``d_ffn``) are filled
    in and the whole configuration is validated by ``__post_init__``.
    """

    # Tokenizer / embedding table size.
    vocab_size: int = 32000

    # Core transformer dimensions.
    d_model: int = 768
    n_layers: int = 12
    n_heads: int = 12

    # Grouped-query attention: number of KV heads; None -> n_heads (plain MHA).
    n_kv_heads: Optional[int] = None

    # Feed-forward width; None -> 8/3 * d_model rounded up to a multiple of 256.
    d_ffn: Optional[int] = None

    # Context length and RoPE base frequency.
    max_seq_len: int = 2048
    rope_theta: float = 10000.0

    # Regularisation and linear-layer bias.
    dropout: float = 0.0
    bias: bool = False

    # Kernel / precision toggles.
    use_flash_attn: bool = True
    use_fp8: bool = False

    # Hybrid Mamba/attention stack; pattern is space-separated 'M'/'A' per layer.
    use_hybrid: bool = False
    hybrid_pattern: str = ""

    # Mamba (SSM) layer hyper-parameters.
    mamba_d_state: int = 128
    mamba_head_dim: int = 64
    mamba_expand: int = 2
    mamba_conv_kernel: int = 4
    mamba_n_groups: int = 1
    mamba_chunk_size: int = 256

    def __post_init__(self) -> None:
        """Fill in derived fields and validate the configuration.

        Raises:
            ValueError: if n_heads is not divisible by n_kv_heads, if
                d_model is not divisible by n_heads, if use_hybrid is set
                without a hybrid_pattern, or if FP8 is enabled with
                dimensions not divisible by 16.
        """
        # Default to standard multi-head attention (one KV head per query head).
        if self.n_kv_heads is None:
            self.n_kv_heads = self.n_heads

        if self.n_heads % self.n_kv_heads != 0:
            raise ValueError(
                f"n_heads ({self.n_heads}) must be divisible by "
                f"n_kv_heads ({self.n_kv_heads})"
            )

        # Without this check head_dim (d_model // n_heads) would silently
        # truncate and q/k/v projections would not cover d_model exactly.
        if self.d_model % self.n_heads != 0:
            raise ValueError(
                f"d_model ({self.d_model}) must be divisible by "
                f"n_heads ({self.n_heads})"
            )

        # Default FFN width: 8/3 * d_model, rounded up to a multiple of 256
        # for kernel-friendly alignment.
        if self.d_ffn is None:
            raw = int(8 / 3 * self.d_model)
            self.d_ffn = _round_to_multiple(raw, 256)

        if self.use_hybrid and not self.hybrid_pattern.strip():
            raise ValueError(
                "use_hybrid=True requires a non-empty hybrid_pattern "
                "(space-separated 'M'/'A' per layer)"
            )

        # The FP8 path requires matrix dimensions divisible by 16.
        if self.use_fp8:
            if self.d_model % 16 != 0:
                raise ValueError(f"FP8: d_model ({self.d_model}) must be divisible by 16")
            if self.d_ffn % 16 != 0:
                raise ValueError(f"FP8: d_ffn ({self.d_ffn}) must be divisible by 16")

    @property
    def num_params(self) -> int:
        """Approximate parameter count using the 12 * L * d^2 rule."""
        return 12 * self.n_layers * self.d_model ** 2

    @property
    def head_dim(self) -> int:
        """Dimensionality of each attention head."""
        return self.d_model // self.n_heads

    def to_dict(self) -> dict:
        """Return a plain-Python-dict representation of the config.

        Keys are exactly the dataclass field names, in declaration order;
        ``asdict`` keeps this in sync automatically if fields are added.
        """
        return asdict(self)

    def to_yaml(self, path: str | Path) -> None:
        """Serialise config to a YAML file, creating parent dirs as needed."""
        path = Path(path)
        path.parent.mkdir(parents=True, exist_ok=True)
        with open(path, "w", encoding="utf-8") as f:
            yaml.safe_dump(self.to_dict(), f, default_flow_style=False, sort_keys=False)

    @classmethod
    def from_dict(cls, d: dict) -> "LMConfig":
        """Construct a LMConfig from a plain dict (e.g. loaded from YAML)."""
        return cls(**d)

    @classmethod
    def from_yaml(cls, path: str | Path) -> "LMConfig":
        """Load config from a YAML file.

        Accepts either a bare config mapping or one nested under a
        top-level ``model`` key.

        Raises:
            ValueError: if the file is empty or its top level is not a
                mapping (yaml.safe_load returns None for an empty file,
                which previously surfaced as an opaque TypeError).
        """
        with open(path, "r", encoding="utf-8") as f:
            data = yaml.safe_load(f)
        if not isinstance(data, dict):
            raise ValueError(f"Expected a YAML mapping at the top level of {path}")
        if "model" in data and isinstance(data["model"], dict):
            data = data["model"]
        return cls.from_dict(data)

    @classmethod
    def from_hf_config(cls, path: str | Path) -> "LMConfig":
        """Load config from a HuggingFace-format config.json (LlamaForCausalLM)."""
        path = Path(path)
        with open(path, "r", encoding="utf-8") as f:
            hf = json.load(f)

        # Newer HF configs nest rope settings under "rope_parameters";
        # older ones expose a top-level "rope_theta".
        rope_theta = 10000.0
        if "rope_parameters" in hf and isinstance(hf["rope_parameters"], dict):
            rope_theta = float(hf["rope_parameters"].get("rope_theta", rope_theta))
        elif "rope_theta" in hf:
            rope_theta = float(hf["rope_theta"])

        return cls(
            vocab_size=hf["vocab_size"],
            d_model=hf["hidden_size"],
            n_layers=hf["num_hidden_layers"],
            n_heads=hf["num_attention_heads"],
            n_kv_heads=hf.get("num_key_value_heads", hf["num_attention_heads"]),
            d_ffn=hf["intermediate_size"],
            max_seq_len=hf.get("max_position_embeddings", 4096),
            rope_theta=rope_theta,
            dropout=hf.get("attention_dropout", 0.0),
            bias=hf.get("attention_bias", False),
        )
| |
|