File size: 1,749 Bytes
8174855
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
import json
from dataclasses import asdict, dataclass
from typing import Optional


@dataclass
class ModelConfig:
    """Hyperparameters for a GPT-style transformer language model.

    All fields are plain ints/floats/bools so the config round-trips
    losslessly through JSON (see ``to_json`` / ``from_json_str``).
    """

    # Core
    vocab_size: int       # V: number of tokens in the vocabulary
    n_positions: int      # P: maximum sequence length (learned positions)
    d_model: int          # d: embedding / residual-stream width
    n_layers: int         # L: number of transformer blocks
    n_heads: int          # attention heads per block
    mlp_ratio: int = 4    # MLP hidden width = mlp_ratio * d_model
    dropout: float = 0.1
    tie_word_embeddings: bool = True       # LM head shares the input embedding
    use_positional_embedding: bool = True  # learned absolute positions
    final_layer_norm: bool = True          # LayerNorm after the last block

    # Derived convenience
    @property
    def d_mlp(self) -> int:
        """Hidden width of the MLP (feed-forward) sublayer."""
        return self.d_model * self.mlp_ratio

    def to_json(self) -> str:
        """Serialize this config to a pretty-printed JSON string."""
        # asdict() instead of __dict__: identical here, but stays correct
        # if the config ever grows nested dataclass fields.
        return json.dumps(asdict(self), indent=2)

    @classmethod
    def from_json_str(cls, s: str) -> "ModelConfig":
        """Inverse of ``to_json``: build a config from a JSON string."""
        return cls(**json.loads(s))

    @classmethod
    def from_json_file(cls, path: str) -> "ModelConfig":
        """Load a config from a JSON file.

        Accepts either a bare config object or a wrapper document that
        nests the config under a top-level "model" key.
        """
        with open(path, "r", encoding="utf-8") as f:
            data = json.load(f)
        if "model" in data:
            data = data["model"]
        return cls(**data)

    def param_count_formula(self, include_lm_head_bias: bool = False) -> int:
        """Closed-form trainable-parameter count for the configured model.

        Per layer (attention + MLP + 2 LayerNorms, all with biases):
            attn:  4*d^2 + 4*d           (q, k, v, out projections)
            mlp:   2*r*d^2 + (r + 1)*d   (d -> r*d -> d, r = mlp_ratio)
            norms: 4*d                   (2 LayerNorms, weight + bias)
        Plus V*d token embeddings, P*d positional embeddings (if enabled),
        2*d for the final LayerNorm (if enabled), and another V*d for a
        separate LM head weight when embeddings are untied.

        Note: the closed form previously hard-coded mlp_ratio=4 and
        ignored the tie/final-norm flags; it is now exact for any
        mlp_ratio and all flag combinations (unchanged at the defaults).

        Args:
            include_lm_head_bias: add V bias terms for the output head.

        Returns:
            Total trainable parameter count.
        """
        V, d, L, r = self.vocab_size, self.d_model, self.n_layers, self.mlp_ratio
        P = self.n_positions if self.use_positional_embedding else 0
        per_layer = (4 + 2 * r) * d * d + (9 + r) * d  # == 12d^2 + 13d at r=4
        total = V * d + P * d + L * per_layer
        if self.final_layer_norm:
            total += 2 * d
        if not self.tie_word_embeddings:
            total += V * d  # separate (untied) LM head weight matrix
        if include_lm_head_bias:
            total += V
        return total

    def assert_exact_params(self, expected: int = 25_000_000) -> None:
        """Raise AssertionError unless the formula count equals ``expected``.

        Raises explicitly rather than via the ``assert`` statement so the
        check still runs under ``python -O`` (asserts are stripped there).
        """
        total = self.param_count_formula(include_lm_head_bias=False)
        if total != expected:
            raise AssertionError(
                f"Parameter mismatch: got {total}, expected {expected}"
            )