File size: 2,531 Bytes
1dd2382
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
"""
Configuration for 1B parameter LLaMA-style Transformer model.
Architecture: Decoder-only Transformer with RoPE, GQA, SwiGLU, RMSNorm.
"""

from dataclasses import dataclass


@dataclass
class ModelConfig:
    """Hyperparameters for a ~1B LLaMA-style decoder-only Transformer.

    The architecture uses rotary position embeddings (RoPE), grouped-query
    attention (GQA), SwiGLU feed-forward blocks, and RMSNorm.
    """
    vocab_size: int = 32000
    hidden_dim: int = 2048
    intermediate_dim: int = 5504       # ~2.7x hidden for SwiGLU (adjusted for param count)
    num_layers: int = 22
    num_attention_heads: int = 32
    num_kv_heads: int = 8              # GQA: 4 query heads per KV head
    max_seq_len: int = 2048
    rope_theta: float = 10000.0
    rms_norm_eps: float = 1e-5
    dropout: float = 0.0               # No dropout (modern practice for pretraining)
    tie_word_embeddings: bool = False

    @property
    def head_dim(self) -> int:
        """Dimension of each attention head (hidden_dim split across heads)."""
        return self.hidden_dim // self.num_attention_heads

    @property
    def num_params_approx(self) -> int:
        """Rough parameter count estimate (weights only; no biases in this
        architecture, so projections and norms account for everything)."""
        d = self.hidden_dim
        # Q and O project at full width (num_attention_heads * head_dim);
        # K and V are narrower under GQA (num_kv_heads * head_dim).
        attn = self.head_dim * d * (
            2 * self.num_attention_heads + 2 * self.num_kv_heads
        )
        # SwiGLU FFN = gate + up + down, each hidden_dim x intermediate_dim.
        ffn = 3 * d * self.intermediate_dim
        # Two RMSNorms per layer (pre-attention and pre-FFN).
        per_layer = attn + ffn + 2 * d
        embedding = self.vocab_size * d
        # Untied output head duplicates the embedding matrix.
        lm_head = 0 if self.tie_word_embeddings else self.vocab_size * d
        return embedding + self.num_layers * per_layer + d + lm_head


@dataclass
class TrainConfig:
    """Training hyperparameters and filesystem paths for the pretraining run."""

    # Paths
    checkpoint_dir: str = "/jfs/deepak-kumar/checkpoints"
    data_cache_dir: str = "/jfs/deepak-kumar/data"
    log_dir: str = "/home/jovyan/training/logs"

    # Training
    total_tokens: int = 20_000_000_000   # 20B tokens
    batch_size_per_gpu: int = 8
    gradient_accumulation_steps: int = 8  # effective batch = 8 * 8 * 8 = 512 seqs
    # NOTE(review): the "8 * 8 * 8" above implies an 8-GPU world size, which is
    # not set anywhere in this file — confirm against the launcher config.
    max_seq_len: int = 2048
    
    # WSD Schedule (warmup-stable-decay; warmup_steps ramps up to learning_rate,
    # min_lr is presumably the decay floor — confirm against the scheduler code)
    learning_rate: float = 3e-4
    min_lr: float = 3e-5
    warmup_steps: int = 1000
    weight_decay: float = 0.1
    beta1: float = 0.9      # AdamW-style first-moment coefficient
    beta2: float = 0.95     # second-moment coefficient (0.95 common for LLM pretraining)
    grad_clip: float = 1.0  # global gradient-norm clipping threshold

    # Logging (all intervals are in optimizer steps — verify against the train loop)
    log_interval: int = 10
    save_interval: int = 1000
    eval_interval: int = 500

    # System
    num_workers: int = 4   # dataloader worker processes
    seed: int = 42
    bf16: bool = True      # train in bfloat16 mixed precision