"""
Model configuration for SLM v1.
Defines all hyperparameters based on architecture specification.
"""

from dataclasses import dataclass
from typing import Optional
import yaml


@dataclass
class SLMConfig:
    """Configuration class for the SLM model.

    Architecture: 120M parameter decoder-only transformer
    - 8 layers, 1024 hidden size, 16 attention heads
    - RMSNorm (pre-norm), GELU FFN, RoPE positions
    - Explicit KV cache for efficient inference
    """

    # Model architecture
    vocab_size: int = 16384
    hidden_size: int = 1024
    num_layers: int = 8
    num_heads: int = 16
    head_dim: int = 64  # must equal hidden_size // num_heads (validated below)
    intermediate_size: int = 4096  # 4 * hidden_size

    # Position encoding
    max_position_embeddings: int = 1024
    rope_theta: float = 10000.0  # RoPE base frequency

    # Normalization
    rms_norm_eps: float = 1e-6

    # Embeddings
    tie_word_embeddings: bool = True  # share input embedding with the LM head

    # Dropout (disabled for inference, optional for training)
    dropout: float = 0.0
    attention_dropout: float = 0.0

    # Precision
    torch_dtype: str = "float16"

    def __post_init__(self) -> None:
        """Validate head/hidden dimension consistency after initialization.

        Raises:
            ValueError: If ``hidden_size`` is not divisible by ``num_heads``,
                or ``head_dim`` does not equal ``hidden_size // num_heads``.
        """
        # Raise explicitly instead of using `assert`: asserts are stripped
        # under `python -O`, which would silently disable validation.
        if self.hidden_size % self.num_heads != 0:
            raise ValueError(
                f"hidden_size ({self.hidden_size}) must be divisible by "
                f"num_heads ({self.num_heads})"
            )
        if self.head_dim != self.hidden_size // self.num_heads:
            raise ValueError(
                f"head_dim ({self.head_dim}) must equal hidden_size // num_heads "
                f"({self.hidden_size // self.num_heads})"
            )

    @classmethod
    def from_yaml(cls, path: str) -> "SLMConfig":
        """Load configuration from the ``model`` section of a YAML file.

        Args:
            path: Path to a YAML file. Its top-level ``model`` mapping (if
                present) supplies keyword overrides; missing keys keep their
                dataclass defaults.

        Returns:
            A validated SLMConfig instance.
        """
        with open(path, "r") as f:
            config_dict = yaml.safe_load(f)

        # An empty YAML file parses to None; fall back to an empty mapping
        # so a bare/blank config file yields all defaults instead of crashing.
        model_config = (config_dict or {}).get("model", {})
        return cls(**model_config)

    def to_dict(self) -> dict:
        """Convert configuration to a plain dictionary (round-trips via `cls(**d)`)."""
        return {
            "vocab_size": self.vocab_size,
            "hidden_size": self.hidden_size,
            "num_layers": self.num_layers,
            "num_heads": self.num_heads,
            "head_dim": self.head_dim,
            "intermediate_size": self.intermediate_size,
            "max_position_embeddings": self.max_position_embeddings,
            "rope_theta": self.rope_theta,
            "rms_norm_eps": self.rms_norm_eps,
            "tie_word_embeddings": self.tie_word_embeddings,
            "dropout": self.dropout,
            "attention_dropout": self.attention_dropout,
            "torch_dtype": self.torch_dtype,
        }

    @property
    def num_parameters(self) -> int:
        """Estimate the total number of model parameters (biases not counted)."""
        # Token embedding table: vocab_size * hidden_size
        embedding_params = self.vocab_size * self.hidden_size

        # Per transformer layer:
        # - Attention: 4 * hidden_size^2 (Q, K, V, O projections)
        # - FFN: 2 * hidden_size * intermediate_size (up + down projections)
        # - Norms: 2 * hidden_size (two RMSNorm scale vectors)
        attention_params = 4 * self.hidden_size * self.hidden_size
        ffn_params = 2 * self.hidden_size * self.intermediate_size
        norm_params = 2 * self.hidden_size

        layer_params = attention_params + ffn_params + norm_params
        total_layer_params = self.num_layers * layer_params

        # Output head contributes nothing when tied to the input embedding.
        output_params = 0 if self.tie_word_embeddings else self.vocab_size * self.hidden_size

        # Final RMSNorm before the LM head.
        final_norm_params = self.hidden_size

        return embedding_params + total_layer_params + output_params + final_norm_params

    def __repr__(self) -> str:
        params_m = self.num_parameters / 1e6
        return (
            f"SLMConfig(\n"
            f"  vocab_size={self.vocab_size},\n"
            f"  hidden_size={self.hidden_size},\n"
            f"  num_layers={self.num_layers},\n"
            f"  num_heads={self.num_heads},\n"
            f"  max_position_embeddings={self.max_position_embeddings},\n"
            f"  estimated_params={params_m:.1f}M\n"
            f")"
        )