"""Model configuration for SAGE."""

from __future__ import annotations

from dataclasses import asdict, dataclass, fields
from pathlib import Path
from typing import Any

import yaml


@dataclass
class ModelConfig:
    """Configuration for the dense SAGE decoder-only transformer.

    Every field has a default, so ``ModelConfig()`` is a valid configuration.
    Cross-field consistency is checked in ``__post_init__``; an inconsistent
    combination raises ``ValueError`` at construction time.
    """

    name: str = "sage-1b"
    num_layers: int = 24
    d_model: int = 2048
    num_attn_heads: int = 16
    num_kv_heads: int = 8
    head_dim: int = 128
    ffn_hidden_dim: int = 5632
    vocab_size: int = 50_000
    context_length: int = 4096
    rope_base_frequency: int = 500_000
    rope_scaling_factor: float = 1.0
    dropout: float = 0.0
    tie_word_embeddings: bool = True
    rms_norm_eps: float = 1.0e-5
    initializer_range: float = 0.02

    def __post_init__(self) -> None:
        """Validate relationships between fields.

        Raises:
            ValueError: If ``num_attn_heads * head_dim != d_model``, if
                ``num_kv_heads`` is not positive, if ``num_attn_heads`` is not
                divisible by ``num_kv_heads``, or if ``ffn_hidden_dim`` is not
                a multiple of 256.
        """
        if self.num_attn_heads * self.head_dim != self.d_model:
            raise ValueError("num_attn_heads * head_dim must equal d_model.")
        # Guard the divisor first: a zero here would otherwise surface as a
        # ZeroDivisionError from the modulo below, and a negative value would
        # slip through the divisibility check.
        if self.num_kv_heads <= 0:
            raise ValueError("num_kv_heads must be a positive integer.")
        # NOTE(review): presumably required so query heads can be grouped
        # evenly over KV heads — confirm against the attention implementation.
        if self.num_attn_heads % self.num_kv_heads != 0:
            raise ValueError("num_attn_heads must be divisible by num_kv_heads.")
        if self.ffn_hidden_dim % 256 != 0:
            raise ValueError("ffn_hidden_dim must be a multiple of 256.")

    @classmethod
    def from_yaml(cls, path: str | Path) -> "ModelConfig":
        """Load a config from a YAML file.

        The file is read as UTF-8 and parsed with ``yaml.safe_load``. The
        top-level document must be a mapping of field names to values; fields
        that are omitted keep their defaults. An empty document is treated as
        an empty mapping and yields the all-defaults configuration.

        Args:
            path: Location of the YAML file.

        Returns:
            A validated ``ModelConfig``.

        Raises:
            TypeError: If the top-level YAML value is not a mapping, or if it
                contains keys that are not ``ModelConfig`` fields.
            ValueError: If the loaded values fail ``__post_init__`` validation.
        """
        payload = yaml.safe_load(Path(path).read_text(encoding="utf-8"))
        if payload is None:
            # safe_load returns None for an empty (or comment-only) document.
            payload = {}
        if not isinstance(payload, dict):
            raise TypeError(
                f"Expected a mapping at the top level of {path}, "
                f"got {type(payload).__name__}."
            )

        known = {f.name for f in fields(cls) if f.init}
        # str() keeps the message buildable even when YAML produced
        # non-string keys (e.g. integers).
        unknown = sorted(str(key) for key in payload if key not in known)
        if unknown:
            raise TypeError(
                f"Unknown ModelConfig field(s) in {path}: {', '.join(unknown)}"
            )
        return cls(**payload)

    def to_dict(self) -> dict[str, Any]:
        """Serialize the config to a plain dict keyed by field name."""
        return asdict(self)