|
|
import json
|
|
|
from dataclasses import dataclass
|
|
|
from typing import Optional
|
|
|
|
|
|
|
|
|
@dataclass
class ModelConfig:
    """Configuration for a GPT-style decoder-only transformer.

    Holds the architectural hyperparameters, JSON (de)serialization
    helpers, and a closed-form exact parameter count for the model the
    config describes.
    """

    vocab_size: int     # V: number of tokens in the vocabulary
    n_positions: int    # P: maximum sequence length (context window)
    d_model: int        # d: hidden / embedding dimension
    n_layers: int       # L: number of transformer blocks
    n_heads: int        # attention heads per block
    mlp_ratio: int = 4  # feed-forward width multiplier: d_mlp = d_model * mlp_ratio
    dropout: float = 0.1
    tie_word_embeddings: bool = True       # LM head shares the input embedding matrix
    use_positional_embedding: bool = True  # learned absolute positional embeddings
    final_layer_norm: bool = True          # LayerNorm after the last block

    @property
    def d_mlp(self) -> int:
        """Hidden width of the feed-forward (MLP) sublayer."""
        return self.d_model * self.mlp_ratio

    def to_json(self) -> str:
        """Serialize this config to a pretty-printed JSON string."""
        return json.dumps(self.__dict__, indent=2)

    @classmethod
    def from_json_str(cls, s: str) -> "ModelConfig":
        """Build a config from a JSON string produced by :meth:`to_json`."""
        data = json.loads(s)
        return cls(**data)

    @classmethod
    def from_json_file(cls, path: str) -> "ModelConfig":
        """Load a config from a JSON file.

        If the file holds a larger config object with a top-level
        ``"model"`` key, that sub-object is used.
        """
        with open(path, "r", encoding="utf-8") as f:
            data = json.load(f)
        if "model" in data:
            data = data["model"]
        return cls(**data)

    def param_count_formula(self, include_lm_head_bias: bool = False) -> int:
        """Exact parameter count of the model described by this config.

        Per transformer block (r = mlp_ratio):
          - attention Q/K/V + output projections: 4*d*d weights + 4*d biases
          - MLP (d -> r*d -> d):                  2*r*d*d weights + (r+1)*d biases
          - two LayerNorms (scale + bias):        4*d
        Plus token embeddings (V*d), positional embeddings (P*d, if
        enabled), the final LayerNorm (2*d, if enabled), and a separate
        LM head matrix (V*d) when embeddings are untied.

        With the default ``mlp_ratio=4``, tied embeddings, and a final
        LayerNorm this reduces to the familiar
        ``V*d + P*d + L*(12*d*d + 13*d) + 2*d``.

        Args:
            include_lm_head_bias: add a per-vocab bias term (V) for the
                output projection.

        Returns:
            Total number of trainable parameters.
        """
        V = self.vocab_size
        P = self.n_positions if self.use_positional_embedding else 0
        d = self.d_model
        L = self.n_layers
        r = self.mlp_ratio
        # Previously the per-block term hard-coded r=4 (12*d*d + 13*d);
        # generalized so non-default mlp_ratio configs count correctly.
        per_block = (4 + 2 * r) * d * d + (9 + r) * d
        total = V * d + P * d + L * per_block
        if self.final_layer_norm:
            total += 2 * d
        if not self.tie_word_embeddings:
            total += V * d  # untied LM head has its own projection matrix
        if include_lm_head_bias:
            total += V
        return total

    def assert_exact_params(self, expected: int = 25_000_000) -> None:
        """Raise ``AssertionError`` if the formula count != *expected*.

        Uses an explicit ``raise`` rather than the ``assert`` statement so
        the check is not stripped when Python runs with ``-O``.
        """
        total = self.param_count_formula(include_lm_head_bias=False)
        if total != expected:
            raise AssertionError(
                f"Parameter mismatch: got {total}, expected {expected}"
            )
|
|
|
|