File size: 6,402 Bytes
c0f89d0
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
"""
LMConfig: configuration dataclass for the LLM model architecture.
"""

from __future__ import annotations

import json
import math
from dataclasses import asdict, dataclass, field
from pathlib import Path
from typing import Optional

import yaml


def _round_to_multiple(n: int, multiple: int) -> int:
    """Round *n* up to the nearest multiple of *multiple*.

    Uses pure integer ceiling division instead of ``math.ceil(n / multiple)``:
    float division loses precision once ``n`` exceeds 2**53, whereas
    ``-(-n // multiple)`` is exact for arbitrarily large integers.
    """
    return -(-n // multiple) * multiple


@dataclass
class LMConfig:
    """Architecture configuration for the LLM.

    Derived fields (``n_kv_heads``, ``d_ffn``) are resolved in
    ``__post_init__``, which also validates cross-field constraints
    (GQA divisibility, hybrid-pattern tokens, FP8 dimension alignment).
    """

    # Vocabulary
    vocab_size: int = 32000

    # Model dimensions
    d_model: int = 768
    n_layers: int = 12
    n_heads: int = 12

    # Grouped-query attention: None → standard MHA (n_kv_heads == n_heads)
    n_kv_heads: Optional[int] = None

    # Feed-forward hidden dimension: None → auto-computed in __post_init__
    d_ffn: Optional[int] = None

    # Sequence length
    max_seq_len: int = 2048

    # RoPE base frequency
    rope_theta: float = 10000.0

    # Regularisation
    dropout: float = 0.0
    bias: bool = False

    # Attention backend
    use_flash_attn: bool = True

    # FP8 quantization
    use_fp8: bool = False

    # Hybrid Mamba-Transformer settings
    use_hybrid: bool = False
    # Space-separated per-layer tokens, 'M' (Mamba) or 'A' (attention) —
    # one token per layer, e.g. "M M A M M M M A" (Nemotron-H style).
    hybrid_pattern: str = ""
    # Mamba-2 SSM parameters
    mamba_d_state: int = 128
    mamba_head_dim: int = 64
    mamba_expand: int = 2
    mamba_conv_kernel: int = 4
    mamba_n_groups: int = 1
    mamba_chunk_size: int = 256

    def __post_init__(self) -> None:
        """Resolve derived fields and validate the configuration.

        Raises:
            ValueError: if n_heads is not divisible by n_kv_heads, if
                use_hybrid is set without a valid hybrid_pattern, or if
                FP8 is enabled with dimensions not divisible by 16.
        """
        # GQA: None means "no grouping" → one KV head per query head (MHA).
        if self.n_kv_heads is None:
            self.n_kv_heads = self.n_heads

        # Each KV head must serve an integer number of query heads.
        if self.n_heads % self.n_kv_heads != 0:
            raise ValueError(
                f"n_heads ({self.n_heads}) must be divisible by "
                f"n_kv_heads ({self.n_kv_heads})"
            )

        # LLaMA-style SwiGLU sizing: 8/3 * d_model, rounded up to a
        # multiple of 256 for GEMM-friendly shapes.
        if self.d_ffn is None:
            raw = int(8 / 3 * self.d_model)
            self.d_ffn = _round_to_multiple(raw, 256)

        # Hybrid Mamba-Transformer validation: pattern must be present
        # and contain only the documented 'M'/'A' tokens (fail fast on typos).
        if self.use_hybrid:
            tokens = self.hybrid_pattern.split()
            if not tokens:
                raise ValueError(
                    "use_hybrid=True requires a non-empty hybrid_pattern "
                    "(space-separated 'M'/'A' per layer)"
                )
            unknown = sorted(set(tokens) - {"M", "A"})
            if unknown:
                raise ValueError(
                    f"hybrid_pattern may only contain 'M'/'A' tokens, "
                    f"got unknown tokens: {unknown}"
                )

        # FP8 alignment: TE requires dimensions divisible by 16
        if self.use_fp8:
            if self.d_model % 16 != 0:
                raise ValueError(f"FP8: d_model ({self.d_model}) must be divisible by 16")
            if self.d_ffn % 16 != 0:
                raise ValueError(f"FP8: d_ffn ({self.d_ffn}) must be divisible by 16")

    # ------------------------------------------------------------------
    # Properties
    # ------------------------------------------------------------------

    @property
    def num_params(self) -> int:
        """Approximate parameter count using the 12 * L * d^2 rule.

        Rough estimate of non-embedding transformer weights only.
        """
        return 12 * self.n_layers * self.d_model ** 2

    @property
    def head_dim(self) -> int:
        """Dimensionality of each attention head (d_model / n_heads)."""
        return self.d_model // self.n_heads

    # ------------------------------------------------------------------
    # Serialisation helpers
    # ------------------------------------------------------------------

    def to_dict(self) -> dict:
        """Return a plain-Python-dict representation of the config.

        Uses ``dataclasses.asdict`` so the dict automatically stays in
        sync with the declared fields (a hand-written field list would
        silently drift when fields are added or renamed).
        """
        return asdict(self)

    def to_yaml(self, path: str | Path) -> None:
        """Serialise config to a YAML file, creating parent dirs as needed."""
        path = Path(path)
        path.parent.mkdir(parents=True, exist_ok=True)
        with open(path, "w", encoding="utf-8") as f:
            yaml.safe_dump(self.to_dict(), f, default_flow_style=False, sort_keys=False)

    @classmethod
    def from_dict(cls, d: dict) -> "LMConfig":
        """Construct a LMConfig from a plain dict (e.g. loaded from YAML)."""
        return cls(**d)

    @classmethod
    def from_yaml(cls, path: str | Path) -> "LMConfig":
        """Load config from a YAML file.

        Accepts either a flat mapping of LMConfig fields or a document
        with a nested 'model' section (shared multi-section configs).

        Raises:
            ValueError: if the file is empty or its top level is not a
                mapping (yaml.safe_load would otherwise surface this as
                an opaque TypeError).
        """
        with open(path, "r", encoding="utf-8") as f:
            data = yaml.safe_load(f)
        if not isinstance(data, dict):
            raise ValueError(
                f"{path}: expected a YAML mapping, got {type(data).__name__}"
            )
        # Support nested YAML with 'model' section (e.g., shared multi-section configs)
        if "model" in data and isinstance(data["model"], dict):
            data = data["model"]
        return cls.from_dict(data)

    @classmethod
    def from_hf_config(cls, path: str | Path) -> "LMConfig":
        """Load config from a HuggingFace-format config.json (LlamaForCausalLM).

        Maps HF key names onto LMConfig fields; missing optional keys fall
        back to HF-conventional defaults (e.g. MHA when num_key_value_heads
        is absent).
        """
        path = Path(path)
        with open(path, "r", encoding="utf-8") as f:
            hf = json.load(f)

        # rope_theta may live at the top level or nested under
        # "rope_parameters" depending on the HF config version.
        rope_theta = 10000.0
        if "rope_parameters" in hf and isinstance(hf["rope_parameters"], dict):
            rope_theta = float(hf["rope_parameters"].get("rope_theta", rope_theta))
        elif "rope_theta" in hf:
            rope_theta = float(hf["rope_theta"])

        return cls(
            vocab_size=hf["vocab_size"],
            d_model=hf["hidden_size"],
            n_layers=hf["num_hidden_layers"],
            n_heads=hf["num_attention_heads"],
            n_kv_heads=hf.get("num_key_value_heads", hf["num_attention_heads"]),
            d_ffn=hf["intermediate_size"],
            max_seq_len=hf.get("max_position_embeddings", 4096),
            rope_theta=rope_theta,
            dropout=hf.get("attention_dropout", 0.0),
            bias=hf.get("attention_bias", False),
        )