"""
Model configuration for SLM v1.
Defines all hyperparameters based on the architecture specification.
"""
from dataclasses import dataclass

import yaml

@dataclass
class SLMConfig:
"""Configuration class for the SLM model.
    Architecture: ~120M-parameter decoder-only transformer
- 8 layers, 1024 hidden size, 16 attention heads
- RMSNorm (pre-norm), GELU FFN, RoPE positions
- Explicit KV cache for efficient inference
"""
# Model architecture
vocab_size: int = 16384
hidden_size: int = 1024
num_layers: int = 8
num_heads: int = 16
head_dim: int = 64
intermediate_size: int = 4096 # 4 * hidden_size
# Position encoding
max_position_embeddings: int = 1024
rope_theta: float = 10000.0
# Normalization
rms_norm_eps: float = 1e-6
# Embeddings
tie_word_embeddings: bool = True
# Dropout (disabled for inference, optional for training)
dropout: float = 0.0
attention_dropout: float = 0.0
# Precision
torch_dtype: str = "float16"
def __post_init__(self):
"""Validate configuration after initialization."""
assert self.hidden_size % self.num_heads == 0, \
f"hidden_size ({self.hidden_size}) must be divisible by num_heads ({self.num_heads})"
assert self.head_dim == self.hidden_size // self.num_heads, \
f"head_dim ({self.head_dim}) must equal hidden_size // num_heads ({self.hidden_size // self.num_heads})"
@classmethod
def from_yaml(cls, path: str) -> "SLMConfig":
"""Load configuration from YAML file."""
with open(path, "r") as f:
config_dict = yaml.safe_load(f)
model_config = config_dict.get("model", {})
return cls(**model_config)
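
    # Expected YAML layout (illustrative) -- hyperparameters live under a
    # top-level "model" key and map directly onto the dataclass fields:
    #
    #     model:
    #       vocab_size: 16384
    #       hidden_size: 1024
    #       num_layers: 8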
def to_dict(self) -> dict:
"""Convert configuration to dictionary."""
return {
"vocab_size": self.vocab_size,
"hidden_size": self.hidden_size,
"num_layers": self.num_layers,
"num_heads": self.num_heads,
"head_dim": self.head_dim,
"intermediate_size": self.intermediate_size,
"max_position_embeddings": self.max_position_embeddings,
"rope_theta": self.rope_theta,
"rms_norm_eps": self.rms_norm_eps,
"tie_word_embeddings": self.tie_word_embeddings,
"dropout": self.dropout,
"attention_dropout": self.attention_dropout,
"torch_dtype": self.torch_dtype,
}
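
    # Example (illustrative): the plain-dict form is convenient for checkpoint
    # metadata, e.g. json.dumps(config.to_dict()) with the stdlib json module,
    # since every field is a JSON-serializable primitive.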
@property
def num_parameters(self) -> int:
"""Estimate total number of parameters."""
# Embedding: vocab_size * hidden_size
embedding_params = self.vocab_size * self.hidden_size
# Per layer:
# - Attention: 4 * hidden_size^2 (Q, K, V, O projections)
# - FFN: 2 * hidden_size * intermediate_size
# - Norms: 2 * hidden_size
attention_params = 4 * self.hidden_size * self.hidden_size
ffn_params = 2 * self.hidden_size * self.intermediate_size
norm_params = 2 * self.hidden_size
layer_params = attention_params + ffn_params + norm_params
total_layer_params = self.num_layers * layer_params
# Output head (tied with embedding if enabled)
output_params = 0 if self.tie_word_embeddings else self.vocab_size * self.hidden_size
# Final norm
final_norm_params = self.hidden_size
return embedding_params + total_layer_params + output_params + final_norm_params
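
    # Worked example with the defaults (illustrative):
    #   embedding:  16384 * 1024                     = 16,777,216
    #   per layer:  4*1024^2 + 2*1024*4096 + 2*1024  = 12,584,960
    #   8 layers:   8 * 12,584,960                   = 100,679,680
    #   final norm:                                        1,024
    #   total (tied embeddings):                     = 117,457,920  (~117.5M)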
def __repr__(self) -> str:
params_m = self.num_parameters / 1e6
return (
f"SLMConfig(\n"
f" vocab_size={self.vocab_size},\n"
f" hidden_size={self.hidden_size},\n"
f" num_layers={self.num_layers},\n"
f" num_heads={self.num_heads},\n"
f" max_position_embeddings={self.max_position_embeddings},\n"
f" estimated_params={params_m:.1f}M\n"
f")"
)
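

if __name__ == "__main__":
    # Minimal smoke test (illustrative): build the default config, print the
    # summary repr, and sanity-check the dict round-trip. No YAML file is
    # assumed to exist, so from_yaml() is not exercised here.
    config = SLMConfig()
    print(config)
    assert config.to_dict()["hidden_size"] == config.hidden_size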