"""Model configuration dataclass and YAML loading helpers."""

from dataclasses import dataclass, fields
from pathlib import Path
from typing import Any, Union

import yaml


@dataclass
class TinyConfig:
    """Container for model hyperparameters, optionally loaded from YAML.

    Every field has a default, so ``TinyConfig()`` is a valid configuration.
    Field names follow common transformer-config naming; how each value is
    interpreted is up to the code that consumes this object.
    """

    vocab_size: int = 32000
    hidden_size: int = 576
    num_hidden_layers: int = 16
    num_attention_heads: int = 9
    num_key_value_heads: int = 3
    intermediate_size: int = 1536
    max_position_embeddings: int = 1024
    rope_theta: float = 10000.0
    rms_norm_eps: float = 1e-5
    tie_word_embeddings: bool = True
    attention_bias: bool = False
    mlp_bias: bool = False
    dropout: float = 0.0
    bos_token_id: int = 1
    eos_token_id: int = 2

    @classmethod
    def from_yaml(cls, path: Union[str, Path]) -> "TinyConfig":
        """Build a config from the YAML file at *path*.

        If the document is a mapping with a top-level ``model`` key, the
        value under that key is used as the settings; otherwise the whole
        document is. Keys missing from the file keep their dataclass
        defaults. An empty document (or an empty ``model:`` section) yields
        an all-defaults config.

        Args:
            path: Location of the YAML file.

        Returns:
            A new instance populated from the file.

        Raises:
            TypeError: If the settings are not a mapping, or contain keys
                that are not fields of this dataclass.
            OSError: If the file cannot be read.
            yaml.YAMLError: If the file is not valid YAML.
        """
        data = load_yaml(path)

        # safe_load returns None for an empty document.
        if data is None:
            data = {}

        if isinstance(data, dict) and "model" in data:
            data = data["model"]
            # A bare ``model:`` key parses to None; treat it as "no overrides".
            if data is None:
                data = {}

        if not isinstance(data, dict):
            raise TypeError(
                f"expected a mapping of config values in {path!s}, "
                f"got {type(data).__name__}"
            )

        # Check keys up front so the error names the offending entries and
        # the file, instead of the generic __init__ keyword-argument message.
        known = {f.name for f in fields(cls) if f.init}
        unknown = sorted(str(key) for key in data if key not in known)
        if unknown:
            raise TypeError(
                f"unknown config key(s) in {path!s}: {', '.join(unknown)}"
            )

        return cls(**data)


def load_yaml(path: Union[str, Path]) -> Any:
    """Parse the YAML file at *path* and return the resulting Python object.

    The file is decoded as UTF-8 regardless of the platform locale, and
    parsed with ``yaml.safe_load`` so only plain YAML types are constructed.
    An empty file yields ``None``.
    """
    return yaml.safe_load(Path(path).read_text(encoding="utf-8"))