"""
Configuration for a 1B-parameter LLaMA-style Transformer model.

Architecture: decoder-only Transformer with RoPE, GQA, SwiGLU, and RMSNorm.
"""
| |
|
from dataclasses import dataclass
| |
|
| |
|
@dataclass
class ModelConfig:
    """Architecture hyperparameters for a ~1B-parameter LLaMA-style
    decoder-only Transformer (RoPE positions, grouped-query attention,
    SwiGLU MLP, RMSNorm).
    """

    vocab_size: int = 32000             # tokenizer vocabulary size
    hidden_dim: int = 2048              # residual-stream / model width
    intermediate_dim: int = 5504        # SwiGLU inner width (~2.7x hidden)
    num_layers: int = 22                # number of decoder blocks
    num_attention_heads: int = 32       # query heads
    num_kv_heads: int = 8               # key/value heads (GQA: 32/8 = 4 q-heads per kv-head)
    max_seq_len: int = 2048             # maximum context length
    rope_theta: float = 10000.0         # RoPE frequency base
    rms_norm_eps: float = 1e-5          # epsilon inside RMSNorm
    dropout: float = 0.0                # dropout probability (0.0 = disabled)
    tie_word_embeddings: bool = False   # share input embedding matrix with the LM head

    def __post_init__(self) -> None:
        # Fail fast on configurations the attention math cannot support.
        if self.hidden_dim % self.num_attention_heads != 0:
            raise ValueError(
                "hidden_dim must be divisible by num_attention_heads"
            )
        if self.num_attention_heads % self.num_kv_heads != 0:
            raise ValueError(
                "num_attention_heads must be divisible by num_kv_heads"
            )

    @property
    def head_dim(self) -> int:
        """Per-head dimension: hidden_dim split evenly across query heads."""
        return self.hidden_dim // self.num_attention_heads

    @property
    def num_params_approx(self) -> int:
        """Rough parameter count estimate.

        Counts embeddings, per-layer attention/FFN/norm weights, the final
        norm, and the LM head (skipped when embeddings are tied). RoPE has
        no parameters and LLaMA-style layers carry no biases, so neither
        contributes.
        """
        embed = self.vocab_size * self.hidden_dim
        # Q and output projections are full-width; K/V are shrunk by GQA.
        q_proj = self.hidden_dim * self.head_dim * self.num_attention_heads
        kv_proj = 2 * self.hidden_dim * self.head_dim * self.num_kv_heads
        o_proj = self.head_dim * self.num_attention_heads * self.hidden_dim
        attn_per_layer = q_proj + kv_proj + o_proj
        # SwiGLU uses three matrices: gate, up, and down projections.
        ffn_per_layer = 3 * self.hidden_dim * self.intermediate_dim
        # Two RMSNorm gain vectors per block (pre-attention and pre-FFN).
        norm_per_layer = 2 * self.hidden_dim
        lm_head = 0 if self.tie_word_embeddings else self.vocab_size * self.hidden_dim
        return (
            embed
            + self.num_layers * (attn_per_layer + ffn_per_layer + norm_per_layer)
            + self.hidden_dim  # final RMSNorm before the LM head
            + lm_head
        )
| |
|
| |
|
@dataclass
class TrainConfig:
    """Training-run hyperparameters and filesystem locations."""

    # --- Filesystem paths ---
    checkpoint_dir: str = "/jfs/deepak-kumar/checkpoints"   # model checkpoints
    data_cache_dir: str = "/jfs/deepak-kumar/data"          # preprocessed data cache
    log_dir: str = "/home/jovyan/training/logs"             # training logs

    # --- Data / batching ---
    total_tokens: int = 20_000_000_000      # total training tokens (20B)
    batch_size_per_gpu: int = 8             # micro-batch size per device
    gradient_accumulation_steps: int = 8    # micro-batches per optimizer step
    max_seq_len: int = 2048                 # sequence length fed to the model

    # --- Optimizer & LR schedule ---
    learning_rate: float = 3e-4     # peak learning rate
    min_lr: float = 3e-5            # floor for the decay schedule (peak / 10)
    warmup_steps: int = 1000        # linear warmup duration
    weight_decay: float = 0.1
    beta1: float = 0.9              # Adam-style first-moment decay
    beta2: float = 0.95             # second-moment decay (0.95 is common for LLM pretraining)
    grad_clip: float = 1.0          # global-norm gradient clipping threshold

    # --- Logging / checkpoint / eval cadence ---
    # NOTE(review): presumably measured in optimizer steps — confirm against
    # the training loop that consumes these values.
    log_interval: int = 10
    save_interval: int = 1000
    eval_interval: int = 500

    # --- Misc ---
    num_workers: int = 4    # dataloader worker processes
    seed: int = 42          # RNG seed for reproducibility
    bf16: bool = True       # train in bfloat16 mixed precision
| |
|