# Provenance: algorythmtechnologies — uploaded via huggingface_hub (commit 8174855, verified)
import json
from dataclasses import asdict, dataclass
from typing import Optional
@dataclass
class ModelConfig:
    """Configuration for a GPT-style decoder-only transformer.

    Holds the core architectural hyperparameters, JSON (de)serialization
    helpers, and a closed-form parameter-count formula for a model with
    learned positional embeddings and a (by default) tied LM head.
    """

    # Core architecture
    vocab_size: int                        # V: tokenizer vocabulary size
    n_positions: int                       # P: maximum sequence length (learned pos-emb table rows)
    d_model: int                           # d: residual-stream / embedding width
    n_layers: int                          # L: number of transformer blocks
    n_heads: int                           # attention heads per block (not used by the count formula)
    mlp_ratio: int = 4                     # feed-forward width multiplier (d_mlp = mlp_ratio * d_model)
    dropout: float = 0.1
    tie_word_embeddings: bool = True       # share input embedding with LM head
    use_positional_embedding: bool = True  # if False, the P*d term drops out of the count
    final_layer_norm: bool = True

    # Derived convenience
    @property
    def d_mlp(self) -> int:
        """Hidden width of the feed-forward (MLP) sublayer."""
        return self.d_model * self.mlp_ratio

    def to_json(self) -> str:
        """Serialize this config to a pretty-printed JSON string."""
        # asdict() is the dataclass-idiomatic way to get a plain dict
        # (returns a copy rather than exposing __dict__ directly).
        return json.dumps(asdict(self), indent=2)

    @classmethod
    def from_json_str(cls, s: str) -> "ModelConfig":
        """Build a config from a JSON string as produced by ``to_json``."""
        data = json.loads(s)
        return cls(**data)

    @classmethod
    def from_json_file(cls, path: str) -> "ModelConfig":
        """Load a config from a JSON file.

        Accepts either a bare config object or one nested under a
        top-level ``"model"`` key (as in a combined training config).
        """
        with open(path, "r", encoding="utf-8") as f:
            data = json.load(f)
        if "model" in data:
            data = data["model"]
        return cls(**data)

    def param_count_formula(self, include_lm_head_bias: bool = False) -> int:
        """Return the exact parameter count implied by this config.

        Formula (learned positional embeddings, tied LM head, biased
        linear layers, two LayerNorms per block plus a final LayerNorm):

            Total = V*d + P*d + L*(12*d^2 + 13*d) + 2*d + (bias? V : 0)

        Per-block breakdown of ``12*d^2 + 13*d``:
            attention (Q,K,V,proj):  4*d^2 + 4*d
            MLP (up + down, ratio 4): 8*d^2 + 5*d
            two LayerNorms:           4*d

        The trailing ``2*d`` is the final LayerNorm's scale and shift.
        NOTE(review): the formula assumes mlp_ratio == 4 and
        final_layer_norm == True; it does not consult those fields.

        Args:
            include_lm_head_bias: add ``V`` parameters for an LM-head bias.
        """
        V = self.vocab_size
        P = self.n_positions if self.use_positional_embedding else 0
        d = self.d_model
        L = self.n_layers
        total = V * d + P * d + L * (12 * d * d + 13 * d) + 2 * d
        if include_lm_head_bias:
            total += V
        return total

    def assert_exact_params(self, expected: int = 25_000_000) -> None:
        """Raise AssertionError unless the formula count equals ``expected``.

        Raises explicitly rather than via the ``assert`` statement, which
        would be silently stripped under ``python -O`` and defeat the check.
        """
        total = self.param_count_formula(include_lm_head_bias=False)
        if total != expected:
            raise AssertionError(f"Parameter mismatch: got {total}, expected {expected}")