import json
from dataclasses import asdict, dataclass
from typing import Optional


@dataclass
class ModelConfig:
    """Hyperparameter container with JSON (de)serialization and a
    closed-form parameter count.

    The first five fields are required; the remaining fields have defaults.
    """

    # Core
    vocab_size: int
    n_positions: int
    d_model: int
    n_layers: int
    n_heads: int
    mlp_ratio: int = 4
    dropout: float = 0.1
    tie_word_embeddings: bool = True
    use_positional_embedding: bool = True
    final_layer_norm: bool = True

    # Derived convenience
    @property
    def d_mlp(self) -> int:
        """Return ``d_model * mlp_ratio``."""
        return self.d_model * self.mlp_ratio

    def to_json(self) -> str:
        """Serialize the dataclass fields to a JSON string (2-space indent).

        Only declared fields are emitted, so the output can always be fed
        back to :meth:`from_json_str`.
        """
        return json.dumps(asdict(self), indent=2)

    @classmethod
    def from_json_str(cls, s: str) -> "ModelConfig":
        """Build a config from a JSON object string.

        The decoded object's keys are passed as keyword arguments to the
        constructor, so missing required keys or unknown keys raise
        ``TypeError``.
        """
        data = json.loads(s)
        return cls(**data)

    @classmethod
    def from_json_file(cls, path: str) -> "ModelConfig":
        """Build a config from a UTF-8 JSON file at ``path``.

        If the decoded document contains a top-level ``"model"`` key, the
        value under that key is used as the config mapping; otherwise the
        whole document is used.
        """
        with open(path, "r", encoding="utf-8") as f:
            data = json.load(f)
        if "model" in data:
            data = data["model"]
        return cls(**data)

    def param_count_formula(self, include_lm_head_bias: bool = False) -> int:
        """Return the closed-form parameter count for this configuration.

        With ``mlp_ratio == 4``, a final layer norm and a tied LM head this
        is exactly::

            Total = V*d + P*d + L*(12*d^2 + 13*d) + 2*d + (bias? V : 0)

        where ``P`` is 0 when ``use_positional_embedding`` is false.

        Args:
            include_lm_head_bias: When true, add ``V`` for an LM-head bias.

        Returns:
            The total number of parameters as an ``int``.
        """
        V = self.vocab_size
        P = self.n_positions if self.use_positional_embedding else 0
        d = self.d_model
        L = self.n_layers
        m = self.d_mlp

        # NOTE(review): per-layer split assumes the conventional GPT-2-style
        # block (attention 4*d^2 + 4*d, MLP 2*d*m + m + d, two layer norms
        # 4*d). It reduces to the original 12*d^2 + 13*d when m == 4*d;
        # confirm against the model definition if the block layout differs.
        per_layer = 4 * d * d + 2 * d * m + m + 9 * d

        total = V * d + P * d + L * per_layer
        if self.final_layer_norm:
            # Scale and shift vectors of the final layer norm.
            total += 2 * d
        if not self.tie_word_embeddings:
            # An untied LM head carries its own V x d weight matrix.
            total += V * d
        if include_lm_head_bias:
            total += V
        return total

    def assert_exact_params(self, expected: int = 25_000_000) -> None:
        """Raise ``AssertionError`` unless the parameter count equals ``expected``.

        The count is computed without the LM-head bias. The error is raised
        explicitly rather than through an ``assert`` statement so the check
        still runs under ``python -O``.
        """
        total = self.param_count_formula(include_lm_head_bias=False)
        if total != expected:
            raise AssertionError(
                f"Parameter mismatch: got {total}, expected {expected}"
            )