dkumar15 committed on
Commit
1dd2382
·
verified ·
1 Parent(s): f6f7e6c

Upload training_code/model/config.py with huggingface_hub

Browse files
Files changed (1) hide show
  1. training_code/model/config.py +78 -0
training_code/model/config.py ADDED
@@ -0,0 +1,78 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Configuration for 1B parameter LLaMA-style Transformer model.
3
+ Architecture: Decoder-only Transformer with RoPE, GQA, SwiGLU, RMSNorm.
4
+ """
5
+
6
+ from dataclasses import dataclass
7
+
8
+
9
@dataclass
class ModelConfig:
    """Architecture hyperparameters for a ~1B-param LLaMA-style decoder-only
    Transformer (RoPE positions, grouped-query attention, SwiGLU FFN, RMSNorm).
    """

    vocab_size: int = 32000
    hidden_dim: int = 2048
    intermediate_dim: int = 5504  # ~2.7x hidden for SwiGLU (adjusted for param count)
    num_layers: int = 22
    num_attention_heads: int = 32
    num_kv_heads: int = 8  # GQA: 4 query heads per KV head
    max_seq_len: int = 2048
    rope_theta: float = 10000.0
    rms_norm_eps: float = 1e-5
    dropout: float = 0.0  # No dropout (modern practice for pretraining)
    tie_word_embeddings: bool = False

    @property
    def head_dim(self) -> int:
        """Per-head dimension: hidden width split evenly across query heads."""
        return self.hidden_dim // self.num_attention_heads

    @property
    def num_params_approx(self) -> int:
        """Rough parameter count estimate (weights only, norms included)."""
        d = self.hidden_dim
        hd = self.head_dim
        # Q and O projections are both (d x head_dim*num_heads);
        # K and V are both (d x head_dim*num_kv_heads), shrunk by GQA.
        q_and_o = 2 * d * hd * self.num_attention_heads
        k_and_v = 2 * d * hd * self.num_kv_heads
        ffn = 3 * d * self.intermediate_dim  # gate + up + down
        norms = 2 * d  # two RMSNorm weight vectors per layer
        per_layer = q_and_o + k_and_v + ffn + norms

        embed = self.vocab_size * d
        # Untied output head duplicates the embedding matrix size.
        lm_head = 0 if self.tie_word_embeddings else self.vocab_size * d
        # Trailing "+ d" is the final RMSNorm before the head.
        return embed + self.num_layers * per_layer + d + lm_head
46
+
47
+
48
@dataclass
class TrainConfig:
    """Run-level pretraining settings: paths, token budget, optimizer/schedule,
    logging cadence, and system knobs.
    """

    # Filesystem locations
    checkpoint_dir: str = "/jfs/deepak-kumar/checkpoints"
    data_cache_dir: str = "/jfs/deepak-kumar/data"
    log_dir: str = "/home/jovyan/training/logs"

    # Token budget and batching
    total_tokens: int = 20_000_000_000  # 20B tokens
    batch_size_per_gpu: int = 8
    gradient_accumulation_steps: int = 8  # effective batch = 8 * 8 * 8 = 512 seqs
    max_seq_len: int = 2048

    # Optimizer + WSD (warmup-stable-decay) schedule
    learning_rate: float = 3e-4
    min_lr: float = 3e-5  # floor at lr / 10
    warmup_steps: int = 1000
    weight_decay: float = 0.1
    beta1: float = 0.9
    beta2: float = 0.95
    grad_clip: float = 1.0

    # Logging / checkpoint cadence (presumably in optimizer steps — confirm in trainer)
    log_interval: int = 10
    save_interval: int = 1000
    eval_interval: int = 500

    # System / runtime
    num_workers: int = 4
    seed: int = 42
    bf16: bool = True