alexandretl committed on
Commit
bc8288b
·
1 Parent(s): 959cbe5

MLA | KDA | TPA | GDA | ResFormer | Mamba3 | DragonMimo (WIP) | tokenshift | SeeDNorm | shrink DA/GDN | gate shared across all block types |

Browse files
configuration_dragon.py CHANGED
@@ -3,6 +3,7 @@
3
  # TODO : TP (cf qwen)
4
  # TODO : init
5
 
 
6
  import re
7
 
8
  from transformers.configuration_utils import PretrainedConfig
@@ -89,29 +90,40 @@ class DragonConfig(PretrainedConfig):
89
  model_type = "dragon"
90
  keys_to_ignore_at_inference = ["past_key_values"]
91
 
92
- """
93
- config.num_attention_heads_indexer
94
- self.indexer_head_dim = config.head_dim_indexer
95
- self.q_lora_rank = config.dsa_q_lora_rank
96
- self.topk = config.dsa_topk
97
- """
98
-
99
  def __init__(
100
  self,
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
101
  patch_level_training: bool = False,
102
  patch_level_training_size: int = 4,
103
- nsa_head_dim: int = 128,
104
  nsa_topk: int = 16,
105
  nsa_block_size: int = 64,
106
  nsa_window_size: int = 512,
107
- cca_head_dim: int = 128,
108
  cca_seq_kernel_size: int = 4,
109
  rope_gdn: str = None,
110
  zero_centered_gate: bool = False,
111
  zero_centered_gate_type: int = 1,
112
  scalable_softmax: bool = True,
 
 
 
 
113
  gate_attn: bool = False,
114
  gate_gdn: bool = True,
 
115
  num_attention_heads_gdn: int = 32,
116
  num_key_value_heads_gdn: int = None,
117
  fused_loss_computation=False,
@@ -129,6 +141,7 @@ class DragonConfig(PretrainedConfig):
129
  intermediate_size=8192,
130
  expand_factor=2,
131
  layers_config=4*"lrdlr",
 
132
  num_attention_heads=32,
133
  num_key_value_heads=8,
134
  mlp_hidden_act="relu2",
@@ -147,7 +160,10 @@ class DragonConfig(PretrainedConfig):
147
  eos_token_id=2,
148
  sliding_window_size=1024,
149
  slw_wsize=-1,
 
 
150
  rope_theta_local=163.,
 
151
  uscaling_tau=0.2,
152
  attention_dropout=0.,
153
  hidden_dropout=0.,
@@ -157,21 +173,39 @@ class DragonConfig(PretrainedConfig):
157
  gdn_dt_init_floor=1e-4,
158
  gdn_A_init_range=(1, 16),
159
  old_lns=False,
 
160
  **kwargs,
161
  ):
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
162
  self.patch_level_training = patch_level_training
163
  self.patch_level_training_size = patch_level_training_size
164
- self.nsa_head_dim = nsa_head_dim
165
  self.nsa_topk = nsa_topk
166
  self.nsa_block_size = nsa_block_size
167
  self.nsa_window_size = nsa_window_size
168
- self.cca_head_dim = cca_head_dim
169
  self.cca_seq_kernel_size = cca_seq_kernel_size
170
  self.rope_gdn = rope_gdn
171
  self.zero_centered_gate = zero_centered_gate
172
  self.zero_centered_gate_type = zero_centered_gate_type
 
 
173
  self.gate_attn = gate_attn
174
  self.gate_gdn = gate_gdn
 
 
175
  self.num_attention_heads_gdn = num_attention_heads_gdn
176
  if num_key_value_heads_gdn is None:
177
  num_key_value_heads_gdn = num_attention_heads_gdn
@@ -182,13 +216,18 @@ class DragonConfig(PretrainedConfig):
182
  self.dsa_q_lora_rank = dsa_q_lora_rank
183
  self.dsa_topk = dsa_topk
184
  self.zero_centered_gamma = zero_centered_gamma
185
- self.rope_theta = rope_theta_local
 
 
 
186
  self.qk_norm = qk_norm
187
  self.softcap_local_attn=softcap_local_attn
188
  self.softcap_global_attn=softcap_global_attn
189
  self.use_uscaling = use_uscaling
190
  self.uscaling_tau = uscaling_tau
191
  self.scalable_softmax = scalable_softmax
 
 
192
 
193
  self.vocab_size = vocab_size
194
  self.tie_word_embeddings = tie_word_embeddings
@@ -226,9 +265,11 @@ class DragonConfig(PretrainedConfig):
226
  self.A_init_range = gdn_A_init_range
227
 
228
  self.old_lns = old_lns
 
 
229
 
230
- assert self.hidden_size % self.num_attention_heads == 0
231
- assert self.num_attention_heads % self.num_key_value_heads == 0
232
  #assert self.num_attention_heads % 2 == 0, "Number of attention heads must be even for differential attention."
233
  #assert self.num_key_value_heads % 2 == 0, "Number of kv heads must be even for differential attention."
234
 
 
3
  # TODO : TP (cf qwen)
4
  # TODO : init
5
 
6
+ from typing import Optional
7
  import re
8
 
9
  from transformers.configuration_utils import PretrainedConfig
 
90
  model_type = "dragon"
91
  keys_to_ignore_at_inference = ["past_key_values"]
92
 
 
 
 
 
 
 
 
93
  def __init__(
94
  self,
95
+ mla_kv_rank: int = 128,
96
+ shrink_qk_da: int = 2,
97
+ shrink_qk_gdn: int = 2,
98
+ mixer_gn: bool = True,
99
+ kda_allow_neg_eigval: bool = False,
100
+ kda_num_v_heads: Optional[int] = None,
101
+ seednorm_wd: bool = True,
102
+ normalization_type: str = "rmsnorm",
103
+ tpa_rank: int = 2,
104
+ num_signal_heads_diff: Optional[int] = None,
105
+ scalar_proj_as_hidden_matrix: bool = True,
106
+ token_shift_attn: bool = False,
107
+ token_shift_gdn: bool = False,
108
+ token_conv1d_attn: bool = False,
109
+ token_conv1d_gdn: bool = True,
110
  patch_level_training: bool = False,
111
  patch_level_training_size: int = 4,
 
112
  nsa_topk: int = 16,
113
  nsa_block_size: int = 64,
114
  nsa_window_size: int = 512,
 
115
  cca_seq_kernel_size: int = 4,
116
  rope_gdn: str = None,
117
  zero_centered_gate: bool = False,
118
  zero_centered_gate_type: int = 1,
119
  scalable_softmax: bool = True,
120
+ resformer: bool = False,
121
+ mamba_mimo_dim : int = 4,
122
+ gate_type: str = "elementwise",
123
+ gate_act: str = "silu",
124
  gate_attn: bool = False,
125
  gate_gdn: bool = True,
126
+ head_dim_gdn: Optional[int] = None,
127
  num_attention_heads_gdn: int = 32,
128
  num_key_value_heads_gdn: int = None,
129
  fused_loss_computation=False,
 
141
  intermediate_size=8192,
142
  expand_factor=2,
143
  layers_config=4*"lrdlr",
144
+ head_dim=128,
145
  num_attention_heads=32,
146
  num_key_value_heads=8,
147
  mlp_hidden_act="relu2",
 
160
  eos_token_id=2,
161
  sliding_window_size=1024,
162
  slw_wsize=-1,
163
+ rope_type_local="rope",
164
+ rope_type_global="",
165
  rope_theta_local=163.,
166
+ rope_theta_global=10000.,
167
  uscaling_tau=0.2,
168
  attention_dropout=0.,
169
  hidden_dropout=0.,
 
173
  gdn_dt_init_floor=1e-4,
174
  gdn_A_init_range=(1, 16),
175
  old_lns=False,
176
+ mlp_linking=False,
177
  **kwargs,
178
  ):
179
+ self.mla_kv_rank = mla_kv_rank
180
+ self.shrink_qk_da = shrink_qk_da
181
+ self.shrink_qk_gdn = shrink_qk_gdn
182
+ self.mixer_gn = mixer_gn
183
+ self.kda_allow_neg_eigval = kda_allow_neg_eigval
184
+ self.kda_num_v_heads = kda_num_v_heads
185
+ self.seednorm_wd = seednorm_wd
186
+ self.normalization_type = normalization_type
187
+ self.tpa_rank = tpa_rank
188
+ self.num_signal_heads_diff = num_signal_heads_diff
189
+ self.scalar_proj_as_hidden_matrix = scalar_proj_as_hidden_matrix
190
+ self.token_shift_attn = token_shift_attn
191
+ self.token_shift_gdn = token_shift_gdn
192
+ self.token_conv1d_attn = token_conv1d_attn
193
+ self.token_conv1d_gdn = token_conv1d_gdn
194
  self.patch_level_training = patch_level_training
195
  self.patch_level_training_size = patch_level_training_size
 
196
  self.nsa_topk = nsa_topk
197
  self.nsa_block_size = nsa_block_size
198
  self.nsa_window_size = nsa_window_size
 
199
  self.cca_seq_kernel_size = cca_seq_kernel_size
200
  self.rope_gdn = rope_gdn
201
  self.zero_centered_gate = zero_centered_gate
202
  self.zero_centered_gate_type = zero_centered_gate_type
203
+ self.gate_type = gate_type
204
+ self.gate_act = gate_act
205
  self.gate_attn = gate_attn
206
  self.gate_gdn = gate_gdn
207
+ self.head_dim = head_dim
208
+ self.head_dim_gdn = head_dim_gdn
209
  self.num_attention_heads_gdn = num_attention_heads_gdn
210
  if num_key_value_heads_gdn is None:
211
  num_key_value_heads_gdn = num_attention_heads_gdn
 
216
  self.dsa_q_lora_rank = dsa_q_lora_rank
217
  self.dsa_topk = dsa_topk
218
  self.zero_centered_gamma = zero_centered_gamma
219
+ self.rope_type_local = rope_type_local
220
+ self.rope_type_global = rope_type_global
221
+ self.rope_theta_local = rope_theta_local
222
+ self.rope_theta_global = rope_theta_global
223
  self.qk_norm = qk_norm
224
  self.softcap_local_attn=softcap_local_attn
225
  self.softcap_global_attn=softcap_global_attn
226
  self.use_uscaling = use_uscaling
227
  self.uscaling_tau = uscaling_tau
228
  self.scalable_softmax = scalable_softmax
229
+ self.resformer = resformer
230
+ self.mamba_mimo_dim = mamba_mimo_dim
231
 
232
  self.vocab_size = vocab_size
233
  self.tie_word_embeddings = tie_word_embeddings
 
265
  self.A_init_range = gdn_A_init_range
266
 
267
  self.old_lns = old_lns
268
+
269
+ self.mlp_linking = mlp_linking
270
 
271
+ #assert self.hidden_size % self.num_attention_heads == 0
272
+ #assert self.num_attention_heads % self.num_key_value_heads == 0
273
  #assert self.num_attention_heads % 2 == 0, "Number of attention heads must be even for differential attention."
274
  #assert self.num_key_value_heads % 2 == 0, "Number of kv heads must be even for differential attention."
275
 
inspecting_dragon.py ADDED
@@ -0,0 +1,302 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import List, Dict, Optional
2
+ from dataclasses import dataclass
3
+ import json
4
+ import re
5
+ import torch
6
+ import torch.nn as nn
7
+ from functools import partial
8
+ from collections import defaultdict
9
+ import tyro
10
+
11
+ from .configuration_dragon import DragonConfig
12
+ from .modeling_dragon import DragonForCausalLM
13
+
14
+ @dataclass
15
+ class NanoArgs:
16
+ resume_from: Optional[str] = None
17
+ run_name : str = ""
18
+
19
+ # arch - general
20
+ d_model : int = 768
21
+ n_heads : int = 6 # head dim 128 suggested by @Grad62304977
22
+ layers_config : str = 4*"lrdlr"
23
+ expand_factor : int = 1 # expand factor for Mamba/Dragon
24
+ rope_theta_local: float = 10000.0
25
+ eps_rmsnorm: float = 1e-6
26
+ mlp_expand: int = 4 # expand factor for MLP
27
+ fused_loss_computation : bool = True # whether to use fused linear + cross entropy loss
28
+ use_uscaling: bool = False
29
+ uscaling_tau: float = 0.2
30
+ zero_centered_gamma: bool = False
31
+ zero_centered_gate: bool = False
32
+ zero_centered_gate_type: int = 1 # 1, 2, 3, 4
33
+ gate_attn: bool = False
34
+ gate_gdn: bool = True
35
+ gate_type: str = "elementwise" # elementwise (one per dim), headwise (one per head)
36
+ gate_act: str = "silu" # silu, sigmoid
37
+ scalar_proj_as_hidden_matrix: bool = True
38
+
39
+ # attention related
40
+ n_kv_heads : int = 0
41
+ swa_window_size : int = 1024
42
+ slw_warmup_iters: float = 0
43
+ slw_start: int = 8 # window size at the start of training
44
+ slw_increment: int = 64 # window size increment at each step
45
+ softcap_local_attn: float = 0.0 # logit soft-capping for local attn logits, as per Gemma2 (0.0 = no soft-capping)
46
+ softcap_global_attn: float = 0.0
47
+ qk_norm: bool = True
48
+ scalable_softmax: bool = True
49
+ token_shift: bool = False
50
+ num_attention_heads_indexer: int = 8
51
+ head_dim_indexer: int = 32
52
+ dsa_q_lora_rank: int = 128
53
+ dsa_topk: int = 512
54
+ cca_head_dim: int = 128
55
+ cca_seq_kernel_size: int = 4
56
+ nsa_head_dim: int = 128
57
+ nsa_topk: int = 16
58
+ nsa_block_size: int = 64
59
+ nsa_window_size: int = 512
60
+
61
+ # GDN related
62
+ rope_gdn: Optional[str] = None # None, rope, (srope)
63
+ n_heads_gdn: int = 0
64
+ n_kv_heads_gdn: int = 0
65
+
66
+ # optim
67
+ optim: str = "adamw" # adamw, spam, stable-spam, muon, muon_moonlight, splus
68
+ second_order_optim : Optional[str] = None #Snoo
69
+ batch_size: int = 8*64 # batch size, in sequences, across all devices
70
+ device_batch_size: int = 64 # batch size, in sequences, per device
71
+ total_iterations: int = 1000 # number of iterations to run
72
+ learning_rate: float = 1e-4
73
+ weight_decay: float = 0.
74
+ adam_beta1: float = 0.9
75
+ adam_beta2: float = 0.95
76
+ adam_eps: float = 1e-8
77
+ warmup_iters: int = 200
78
+ warmdown_iters: int = 3000
79
+ grad_norm_clip: float = 1.0
80
+ uscaling_mult_embed: float = 0
81
+ uscaling_mult_scalar: float = 0
82
+ uscaling_mult_head: float = 0
83
+ init_std: float = 0.006
84
+ patch_level_training: bool = False
85
+ patch_level_training_size: int = 4
86
+ patch_level_training_mode: str = "reduced" # reduced = ask L tokens, treat L//K. full = ask K*L tokens, treat L.
87
+
88
+ # data
89
+ vocab_size: int = 50304
90
+ sequence_length: int = 1024
91
+ use_patch_level_training: bool = False
92
+ patch_size: int = 4
93
+ patch_training_fraction: float = 0.67
94
+ input_bin: Optional[str] = None
95
+ input_val_bin: Optional[str] = None
96
+
97
+ # evaluation and logging
98
+ val_loss_every: int = 125
99
+ val_iterations: int = 50 # 1 step = global bs * T tokens
100
+ inspect_every: int = 0
101
+ save_every: int = 1000
102
+ log_dir: str = "logs/"
103
+ wandb_project: str = "dragon_v1.5"
104
+ wandb_name: Optional[str] = None
105
+ log_wandb: bool = False
106
+
107
+ load_arg_from_config: bool = True
108
+ load_optim: bool = True
109
+ load_sched: bool = True
110
+ compile: bool = True
111
+
112
+ # used during training
113
+ slw_window: int = 0
114
+
115
+ args = tyro.cli(NanoArgs)
116
+
117
+ # load model.
118
+ config_hf = DragonConfig(
119
+ scalar_proj_as_hidden_matrix=args.scalar_proj_as_hidden_matrix,
120
+ token_shift=args.token_shift,
121
+ patch_level_training=args.patch_level_training,
122
+ patch_level_training_size=args.patch_level_training_size,
123
+ nsa_head_dim=args.nsa_head_dim,
124
+ nsa_topk=args.nsa_topk,
125
+ nsa_block_size=args.nsa_block_size,
126
+ nsa_window_size=args.nsa_window_size,
127
+ cca_head_dim=args.cca_head_dim,
128
+ cca_seq_kernel_size=args.cca_seq_kernel_size,
129
+ num_attention_heads_gdn=args.n_heads_gdn,
130
+ num_key_value_heads_gdn=args.n_kv_heads_gdn,
131
+ zero_centered_gate=args.zero_centered_gate,
132
+ zero_centered_gate_type=args.zero_centered_gate_type,
133
+ scalable_softmax=args.scalable_softmax,
134
+ gate_type=args.gate_type,
135
+ gate_act=args.gate_act,
136
+ gate_attn=args.gate_attn,
137
+ gate_gdn=args.gate_gdn,
138
+ fused_loss_computation=args.fused_loss_computation,
139
+ qk_norm=args.qk_norm,
140
+ num_attention_heads_indexer=args.num_attention_heads_indexer,
141
+ head_dim_indexer=args.head_dim_indexer,
142
+ dsa_q_lora_rank=args.dsa_q_lora_rank,
143
+ dsa_topk=args.dsa_topk,
144
+ zero_centered_gamma=args.zero_centered_gamma,
145
+ vocab_size=args.vocab_size,
146
+ max_position_embeddings=args.sequence_length,
147
+ use_uscaling=args.use_uscaling,
148
+ hidden_size=args.d_model,
149
+ intermediate_size=args.d_model * args.mlp_expand,
150
+ expand_factor=args.expand_factor,
151
+ layers_config=args.layers_config,
152
+ num_attention_heads=args.n_heads,
153
+ num_key_value_heads=args.n_kv_heads if args.n_kv_heads > 0 else args.n_heads,
154
+ initializer_range=args.init_std,
155
+ softcap_local_attn=args.softcap_local_attn,
156
+ softcap_global_attn=args.softcap_global_attn,
157
+ norm_epsilon=args.eps_rmsnorm,
158
+ use_cache=False,
159
+ sliding_window_size=args.swa_window_size,
160
+ rope_theta_local=args.rope_theta_local,
161
+ uscaling_tau=args.uscaling_tau,
162
+ )
163
+
164
+ model = DragonForCausalLM(config_hf)
165
+ model = model.cuda()
166
+
167
+ B, L = 2, 2048
168
+
169
+ # ---------- helpers ---------- #
170
def l1(x: torch.Tensor) -> float:
    """Mean absolute value of *x*, returned as a plain Python float."""
    return torch.mean(torch.abs(x)).item()
172
+
173
+ def _capture(name: str, store: Dict[str, torch.Tensor], _m, _inp, out):
174
+ """Save every tensor produced by a module so that we can measure activations."""
175
+ def walk(x, suf=""):
176
+ if torch.is_tensor(x):
177
+ store[f"{name}{suf}"] = x.detach()
178
+ elif isinstance(x, (list, tuple)):
179
+ for i, xi in enumerate(x):
180
+ walk(xi, suf + f"[{i}]")
181
+ walk(out)
182
+
183
+ _stat_pat = re.compile(r"(\.grad\.(?:std|mean|l1)|\.act\.(?:std|mean|l1)|\.(?:std|mean|l1))$")
184
+
185
+ # Support multiple model naming schemes
186
+ _LAYER_PATTERNS = [
187
+ re.compile(r"\.h\.(\d+)\."), # transformer.h.<i>.
188
+ re.compile(r"\.layers\.(\d+)\."), # model.layers.<i>.
189
+ re.compile(r"\.decoder\.layers\.(\d+)\."), # decoder.layers.<i>.
190
+ re.compile(r"\.block\.(\d+)\."), # ...block.<i>.
191
+ ]
192
+
193
+ def _find_layer_span_and_idx(key: str):
194
+ for pat in _LAYer_PATTERNS if False else _LAYER_PATTERNS: # keep exact name
195
+ m = pat.search(key)
196
+ if m:
197
+ return m.span(0), int(m.group(1)) # span of ".layers.<i>." and the idx
198
+ return None, -1
199
+
200
def _layer_idx(key: str) -> int:
    """Layer index embedded in *key*, or -1 when *key* has no layer segment."""
    span_and_idx = _find_layer_span_and_idx(key)
    return span_and_idx[1]
203
+
204
def _base_key(key: str) -> str:
    """Return <parameter-suffix>.<stat> without the layer index, e.g. mixer.linear_qkv.weight.std"""
    span, _ = _find_layer_span_and_idx(key)
    collapsed = key
    if span:
        start, end = span
        # collapse the ".layers.<i>." segment down to a single dot
        collapsed = collapsed[:start] + "." + collapsed[end:]
    # Drop common top-level prefixes (checked in sequence, so a
    # "transformer.model." chain is stripped completely).
    for prefix in ("transformer.", "model.", "module."):
        if collapsed.startswith(prefix):
            collapsed = collapsed[len(prefix):]
    stat_match = _stat_pat.search(collapsed)
    assert stat_match, f"No stat suffix in key {key}"
    suffix = stat_match.group(1)
    trunk = collapsed[: -len(suffix)]
    return trunk + suffix
220
+
221
+ # ---------- main routine ---------- #
222
+
223
def show_layer_stats(model: nn.Module) -> Dict:
    """Run one forward/backward pass and return aggregated layer statistics.

    Returns a dict (not a JSON string) mapping "inspect/<base-key>" to either
      * a list with one entry per layer (None where a layer has no value), or
      * a single float for non-layer parameters (e.g., embeddings).
    Per parameter it records weight std/l1 and grad std/l1; per module it
    records activation std/l1 captured via forward hooks.
    """

    # NOTE(review): PAD is computed but never used below — confirm before removing.
    PAD = len(str(len(config_hf.layers_config) - 1))

    # ----- collect activations ----- #
    # Register a capture hook on every submodule (root excluded).
    # NOTE(review): hooks are never removed — fine for this one-shot script,
    # but would accumulate if the function were called repeatedly.
    acts, hooks = {}, []
    for n, m in model.named_modules():
        if m is model:
            continue # skip root
        hooks.append(m.register_forward_hook(partial(_capture, n, acts)))

    # One forward/backward pass on random token ids to populate acts and grads.
    x = torch.randint(0, config_hf.vocab_size, (B, L), device="cuda")
    y = torch.randint(0, config_hf.vocab_size, (B, L), device="cuda")
    loss = model(input_ids=x, labels=y).loss
    loss.backward()

    # ----- collect stats (weight / grad / act) ----- #
    raw_stats = {}
    for n, p in model.named_parameters():
        raw_stats[f"{n}.std"] = p.std().item()
        #raw_stats[f"{n}.mean"] = p.mean().item()
        raw_stats[f"{n}.l1"] = l1(p)
        if p.grad is not None:
            raw_stats[f"{n}.grad.std"] = p.grad.std().item()
            #raw_stats[f"{n}.grad.mean"] = p.grad.mean().item()
            raw_stats[f"{n}.grad.l1"] = l1(p.grad)
    for n, a in acts.items():
        raw_stats[f"{n}.act.std"] = a.std().item()
        #raw_stats[f"{n}.act.mean"] = a.mean().item()
        raw_stats[f"{n}.act.l1"] = l1(a)

    # ----- aggregate across layers ----- #
    # agg: base key -> per-layer value list; flat: keys without a layer index.
    agg: Dict[str, List] = defaultdict(lambda: [None] * len(config_hf.layers_config))
    flat: Dict[str, float] = {}

    for key, val in raw_stats.items():
        layer = _layer_idx(key)
        if layer == -1:
            # params without layer index stay flat
            flat[key] = val
            continue
        base = _base_key(key)
        if layer < len(config_hf.layers_config):
            agg[base][layer] = val
        else:
            # unexpected layer index; fall back to flat
            flat[key] = val

    # ----- merge flat & aggregated with custom sorting ----- #
    stats = {}

    # First: per-quantity arrays over layers
    for base_key in sorted(agg.keys()):
        stats[f"inspect/{base_key}"] = agg[base_key] # list of length = #layers (None where absent)

    # Then: non-layer ("flat") stats
    for k, v in sorted(flat.items()):
        stats[f"inspect/{k}"] = v

    return stats
295
+
296
filename = "layer_stats.json"

# NOTE(review): despite the name, json_blob is a dict — show_layer_stats
# returns the stats mapping, which json.dump serializes below.
json_blob = show_layer_stats(model)
with open(args.log_dir + filename, "w") as f:
    if json_blob:  # skip writing when no stats were collected
        json.dump(json_blob, f, indent=2)
print(f"✅ Saved layer stats to {args.log_dir + filename} ✅")
modeling_dragon.py CHANGED
The diff for this file is too large to render. See raw diff
 
optimizers/Ademamix.py ADDED
@@ -0,0 +1,165 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Adapted from: https://pytorch.org/docs/1.6.0/_modules/torch/optim/adam.html
3
+ """
4
+ import math
5
+ import torch
6
+ from torch.optim import Optimizer
7
+
8
+
9
def linear_warmup_scheduler(step, alpha_end, alpha_start=0, warmup=1):
    """Linearly ramp a coefficient from *alpha_start* to *alpha_end*.

    Before *warmup* steps the value is linearly interpolated; from step
    *warmup* onward it stays at *alpha_end*.
    """
    if step >= warmup:
        return alpha_end
    frac = step / float(warmup)
    return (1.0 - frac) * alpha_start + frac * alpha_end
14
+
15
+
16
def linear_hl_warmup_scheduler(step, beta_end, beta_start=0, warmup=1):
    """Warm up an EMA coefficient by interpolating linearly in half-life space.

    The half-life of a coefficient beta is the t such that beta**(t+1) == 0.5.
    During warmup we interpolate between the half-lives of *beta_start* and
    *beta_end*, then map back to a beta; afterwards *beta_end* is returned.
    """
    if step >= warmup:
        return beta_end

    def half_life(beta, eps=1e-8):
        # eps keeps log(beta) finite when beta == 0
        return math.log(0.5) / math.log(beta + eps) - 1

    def beta_from_half_life(t):
        return math.pow(0.5, 1 / (t + 1))

    mix = step / float(warmup)
    target = (1.0 - mix) * half_life(beta_start) + mix * half_life(beta_end)
    return beta_from_half_life(target)
28
+
29
+
30
class AdEMAMix(Optimizer):
    r"""Implements the AdEMAMix algorithm.

    AdEMAMix keeps two exponential moving averages of the gradient: a fast
    one (``beta1``, as in Adam) and a slow one (``beta3``), mixed into the
    update with coefficient ``alpha``. Both ``beta3`` and ``alpha`` may be
    warmed up over training.

    Arguments:
        params (iterable): iterable of parameters to optimize or dicts defining
            parameter groups
        lr (float, optional): learning rate (default: 1e-3)
        betas (Tuple[float, float, float], optional): coefficients used for computing
            running averages of gradient and its square, corresponding to
            beta_1, beta_2, beta_3 in AdEMAMix
        alpha (float): AdEMAMix alpha coeficient mixing the slow and fast EMAs
        beta3_warmup (int, optional): number of warmup steps used to increase beta3 (default: None)
        alpha_warmup: (int, optional): number of warmup steps used to increase alpha (default: None)
        eps (float, optional): term added to the denominator to improve
            numerical stability (default: 1e-8)
        weight_decay (float, optional): weight decay as in AdamW (default: 0)
    """

    def __init__(self, params, lr=1e-3, betas=(0.9, 0.999, 0.999), alpha=8.0,
                 beta3_warmup=None, alpha_warmup=None, eps=1e-8,
                 weight_decay=0):
        # Fail fast on out-of-range hyper-parameters (messages kept verbatim).
        if not 0.0 <= lr:
            raise ValueError("Invalid learning rate: {}".format(lr))
        if not 0.0 <= eps:
            raise ValueError("Invalid epsilon value: {}".format(eps))
        if not 0.0 <= betas[0] < 1.0:
            raise ValueError("Invalid beta parameter at index 0: {}".format(betas[0]))
        if not 0.0 <= betas[1] < 1.0:
            raise ValueError("Invalid beta parameter at index 1: {}".format(betas[1]))
        if not 0.0 <= betas[2] < 1.0:
            raise ValueError("Invalid beta parameter at index 2: {}".format(betas[2]))
        if not 0.0 <= weight_decay:
            raise ValueError("Invalid weight_decay value: {}".format(weight_decay))
        if not 0.0 <= alpha:
            raise ValueError("Invalid alpha value: {}".format(alpha))
        group_defaults = {
            "lr": lr,
            "betas": betas,
            "eps": eps,
            "alpha": alpha,
            "beta3_warmup": beta3_warmup,
            "alpha_warmup": alpha_warmup,
            "weight_decay": weight_decay,
        }
        super().__init__(params, group_defaults)

    def __setstate__(self, state):
        super().__setstate__(state)

    @torch.no_grad()
    def step(self, closure=None):
        """Performs a single optimization step.

        Arguments:
            closure (callable, optional): A closure that reevaluates the model
                and returns the loss.
        """
        loss = None
        if closure is not None:
            with torch.enable_grad():
                loss = closure()

        for group in self.param_groups:
            lr = group["lr"]
            wd = group["weight_decay"]
            eps = group["eps"]
            b1, b2, b3_final = group["betas"]
            b3_warmup = group["beta3_warmup"]
            a_final = group["alpha"]
            a_warmup = group["alpha_warmup"]

            for param in group['params']:
                if param.grad is None:
                    continue
                grad = param.grad
                if grad.is_sparse:
                    raise RuntimeError('AdEMAMix does not support sparse gradients.')

                state = self.state[param]

                # Lazy state initialization on first touch of this parameter.
                if not state:
                    state['step'] = 0
                    # Fast EMA buffer is skipped entirely when beta1 == 0
                    # (saves memory; the raw gradient is used instead).
                    if b1 != 0.0:
                        state['exp_avg_fast'] = torch.zeros_like(param, memory_format=torch.preserve_format)
                    else:
                        state['exp_avg_fast'] = None
                    state['exp_avg_slow'] = torch.zeros_like(param, memory_format=torch.preserve_format)
                    # Exponential moving average of squared gradient values
                    state['exp_avg_sq'] = torch.zeros_like(param, memory_format=torch.preserve_format)

                m_fast = state['exp_avg_fast']
                m_slow = state['exp_avg_slow']
                v_sq = state['exp_avg_sq']

                state['step'] += 1
                t = state['step']
                bias_c1 = 1 - b1 ** t
                bias_c2 = 1 - b2 ** t

                # Effective alpha / beta3 under optional warmup schedules.
                if a_warmup is not None:
                    alpha = linear_warmup_scheduler(t, alpha_end=a_final, alpha_start=0, warmup=a_warmup)
                else:
                    alpha = a_final
                if b3_warmup is not None:
                    beta3 = linear_hl_warmup_scheduler(t, beta_end=b3_final, beta_start=b1, warmup=b3_warmup)
                else:
                    beta3 = b3_final

                # Update the fast/slow first moments and the second moment.
                if b1 != 0.0:
                    m_fast.mul_(b1).add_(grad, alpha=1 - b1)
                else:
                    m_fast = grad
                m_slow.mul_(beta3).add_(grad, alpha=1 - beta3)
                v_sq.mul_(b2).addcmul_(grad, grad, value=1 - b2)

                denom = (v_sq.sqrt() / math.sqrt(bias_c2)).add_(eps)

                update = (m_fast.div(bias_c1) + alpha * m_slow) / denom

                # AdamW-style decoupled weight decay folded into the update.
                update.add_(param, alpha=wd)

                param.add_(-lr * update)

        return loss
151
+
152
+
153
+ if __name__ == "__main__": # small dummy test
154
+
155
+ x = torch.randn((10,7))
156
+ model = torch.nn.Linear(7, 1, bias=False)
157
+ opt = AdEMAMix(params=model.parameters(), lr=1e-2, betas=(0.9, 0.999, 0.9999), alpha=2.0, beta3_warmup=45, alpha_warmup=45, weight_decay=0.1)
158
+ print(model.weight)
159
+ for itr in range(50):
160
+ y = model(x).mean()
161
+ opt.zero_grad()
162
+ y.backward()
163
+ opt.step()
164
+
165
+ print(model.weight)
optimizers/Snoo.py ADDED
@@ -0,0 +1,67 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn as nn
3
+
4
+
5
+
6
class Snoo:
    """
    @DominikKallusky, @vishal9-team, @vinaysrao
    Sparse Nesterov Outer Optimizer (Snoo) is a momentum-based wrapper to any optimizer that can
    improve the stability and smoothness of the optimization process and thus the quality
    of large language models (LLM) and other models. Snoo implicitly adds temporal regularization
    to the parameters, thus smoothing the training trajectory and instilling a bias towards flatter
    minima and lower parameter norms. Snoo is computationally efficient, incurring minimal overhead
    in compute and moderate memory usage.
    """

    @torch.no_grad()
    def __init__(self, model: nn.Module, lr: float, momentum: float, k: int) -> None:
        # model: network whose parameters are smoothed by the outer loop
        # lr / momentum: hyper-parameters of the outer Nesterov-SGD step
        # k: the outer update fires once every k calls to step()
        self.model = model
        self.lr = lr
        self.momentum = momentum
        self.k = k
        self.current_step = 0
        # Snapshot of the parameters at the last outer update.
        self.outer_buf = [p.clone() for p in model.parameters()]
        self.model_params = list(self.model.parameters())
        # Outer optimizer applied to the parameter displacement.
        # NOTE(review): fused=True constrains supported devices/dtypes —
        # confirm it is available on the target hardware/torch version.
        self.optimizer = torch.optim.SGD(
            self.model.parameters(),
            lr=lr,
            momentum=momentum,
            nesterov=True,
            fused=True,
        )

    @torch.no_grad()
    def step(
        self,
    ) -> None:
        # Every k-th call: use the displacement since the last snapshot as the
        # outer "gradient", rewind the parameters to the snapshot, take one
        # Nesterov-SGD step, then refresh the snapshot from the new params.
        if self.current_step % self.k == 0:
            for p_new, p_old in zip(self.model_params, self.outer_buf):
                # grad = -(displacement accumulated by the inner optimizer)
                p_new.grad = p_old.data - p_new.data
                p_new.copy_(p_old, non_blocking=True)

            self.optimizer.step()

            for p_new, p_old in zip(self.model_params, self.outer_buf):
                p_old.copy_(p_new, non_blocking=True)
        self.current_step += 1

    def state_dict(self):
        # Serializable snapshot: hyper-parameters, step counter, the outer
        # parameter buffer and the inner SGD state.
        state_dict = {
            "current_step": self.current_step,
            "lr": self.lr,
            "momentum": self.momentum,
            "k": self.k,
            "outer_buf": [p.clone() for p in self.outer_buf],
            "optimizer_state_dict": self.optimizer.state_dict(),
        }
        return state_dict

    def load_state_dict(self, state_dict):
        # Restore everything produced by state_dict(); outer_buf is copied
        # tensor-by-tensor in place.
        self.current_step = state_dict["current_step"]
        self.lr = state_dict["lr"]
        self.momentum = state_dict["momentum"]
        self.k = state_dict["k"]
        for p_src, p_dst in zip(state_dict["outer_buf"], self.outer_buf):
            p_dst.copy_(p_src)
        self.optimizer.load_state_dict(state_dict["optimizer_state_dict"])
optimizers/__init__.py ADDED
@@ -0,0 +1,2 @@
 
 
 
1
+ from .Ademamix import AdEMAMix
2
+ from .Snoo import Snoo
training_dragon.py CHANGED
@@ -32,9 +32,13 @@ class NanoArgs:
32
  # arch - general
33
  d_model : int = 768
34
  n_heads : int = 6 # head dim 128 suggested by @Grad62304977
 
35
  layers_config : str = 4*"lrdlr"
36
- expand_factor : int = 1 # expand factor for Mamba/Dragon
 
 
37
  rope_theta_local: float = 10000.0
 
38
  eps_rmsnorm: float = 1e-6
39
  mlp_expand: int = 4 # expand factor for MLP
40
  fused_loss_computation : bool = True # whether to use fused linear + cross entropy loss
@@ -42,9 +46,16 @@ class NanoArgs:
42
  uscaling_tau: float = 0.2
43
  zero_centered_gamma: bool = False
44
  zero_centered_gate: bool = False
45
- zero_centered_gate_type: int = 1 # 1, 2, 3
46
  gate_attn: bool = False
47
  gate_gdn: bool = True
 
 
 
 
 
 
 
48
 
49
  # attention related
50
  n_kv_heads : int = 0
@@ -56,24 +67,36 @@ class NanoArgs:
56
  softcap_global_attn: float = 0.0
57
  qk_norm: bool = True
58
  scalable_softmax: bool = True
 
 
 
 
 
59
  num_attention_heads_indexer: int = 8
60
  head_dim_indexer: int = 32
61
  dsa_q_lora_rank: int = 128
62
  dsa_topk: int = 512
63
- cca_head_dim: int = 128
64
  cca_seq_kernel_size: int = 4
65
- nsa_head_dim: int = 128
66
  nsa_topk: int = 16
67
  nsa_block_size: int = 64
68
  nsa_window_size: int = 512
 
 
 
 
69
 
70
  # GDN related
71
  rope_gdn: Optional[str] = None # None, rope, (srope)
 
72
  n_heads_gdn: int = 0
73
  n_kv_heads_gdn: int = 0
 
 
 
74
 
75
  # optim
76
  optim: str = "adamw" # adamw, spam, stable-spam, muon, muon_moonlight, splus
 
77
  batch_size: int = 8*64 # batch size, in sequences, across all devices
78
  device_batch_size: int = 64 # batch size, in sequences, per device
79
  total_iterations: int = 1000 # number of iterations to run
@@ -91,14 +114,13 @@ class NanoArgs:
91
  init_std: float = 0.006
92
  patch_level_training: bool = False
93
  patch_level_training_size: int = 4
94
- patch_level_training_mode: str = "reduced" # reduced = ask L tokens, treat L//K. full = ask K*L tokens, treat L.
 
 
95
 
96
  # data
97
  vocab_size: int = 50304
98
  sequence_length: int = 1024
99
- use_patch_level_training: bool = False
100
- patch_size: int = 4
101
- patch_training_fraction: float = 0.67
102
  input_bin: Optional[str] = None
103
  input_val_bin: Optional[str] = None
104
 
@@ -213,11 +235,15 @@ def param_groups_mup(model, base_lr_hidden, base_lr_scalar, base_lr_embed, base_
213
  for mod in model.modules():
214
  if isinstance(mod, nn.Linear):
215
  pname = id2name.get(id(mod.weight), "")
 
216
  fan_in = mod.weight.shape[1]
217
  scale = 1 / math.sqrt(fan_in)
218
  if "lm_head" in pname:
219
  lr_scaled = base_lr_head
220
  wd_scaled = 0.0
 
 
 
221
  else:
222
  lr_scaled = base_lr_hidden * scale
223
  wd_scaled = wd / lr_scaled
@@ -226,7 +252,7 @@ def param_groups_mup(model, base_lr_hidden, base_lr_scalar, base_lr_embed, base_
226
  seen.add(mod.weight)
227
 
228
  if mod.bias is not None:
229
- groups.append({"params": [mod.bias], "lr": lr_scaled, "weight_decay": 0.0})
230
  seen.add(mod.bias)
231
 
232
  for p in model.parameters():
@@ -235,13 +261,17 @@ def param_groups_mup(model, base_lr_hidden, base_lr_scalar, base_lr_embed, base_
235
  pname = id2name.get(id(p), "<unnamed>")
236
 
237
  if "embedding" in pname:
238
- fan_out = p.shape[1] # nn.Embedding is transposed
239
  #lr_scaled = base_lr / math.sqrt(fan_out) # u-muP
240
  lr_scaled = base_lr_embed
241
  else:
242
  lr_scaled = base_lr_scalar
243
 
244
- groups.append({"params": [p], "lr": lr_scaled, "weight_decay": 0.})
 
 
 
 
245
 
246
  return groups
247
 
@@ -299,11 +329,13 @@ if master_process:
299
  with open(f'{logdir}/args.json', 'w') as f: json.dump(vars(args), f)
300
  with open(f'{logdir}/args.pkl', 'wb') as f: pickle.dump(args, f)
301
  def print0(s, console=True):
302
- if master_process:
303
- with open(logfile, "a") as f:
304
- if console:
305
- print(s)
306
- print(s, file=f)
 
 
307
  if resume_dir is not None and args.load_arg_from_config:
308
  saved_args_path = os.path.join(os.path.dirname(resume_dir), "args.pkl")
309
  print0(f"Loading args from {saved_args_path}")
@@ -326,16 +358,14 @@ np.random.seed(seed)
326
 
327
  # define convenience variables.
328
  B, T = args.device_batch_size, args.sequence_length
 
 
329
  assert args.batch_size % (B * ddp_world_size) == 0
330
  accumulation_steps = args.batch_size // (B * ddp_world_size)
331
 
332
  # load dataloaders.
333
- if args.patch_level_training:
334
- if args.patch_level_training_mode == "reduced":
335
- assert T % args.patch_level_training_size == 0, "sequence length must be divisible by patch level training size in reduced mode"
336
- T = T
337
- elif args.patch_level_training_mode == "full":
338
- T = T * args.patch_level_training_size
339
  train_loader = DistributedDataLoader(args.input_bin, B, T, ddp_rank, ddp_world_size)
340
  val_loader = DistributedDataLoader(args.input_val_bin, B, T, ddp_rank, ddp_world_size)
341
  print0(f"Training DataLoader: total number of tokens: {train_loader.ntok_total} across {len(train_loader.files)} files")
@@ -343,19 +373,38 @@ print0(f"Validation DataLoader: total number of tokens: {val_loader.ntok_total}
343
 
344
  # load model.
345
  config_hf = DragonConfig(
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
346
  patch_level_training=args.patch_level_training,
347
  patch_level_training_size=args.patch_level_training_size,
348
- nsa_head_dim=args.nsa_head_dim,
349
  nsa_topk=args.nsa_topk,
350
  nsa_block_size=args.nsa_block_size,
351
  nsa_window_size=args.nsa_window_size,
352
- cca_head_dim=args.cca_head_dim,
353
  cca_seq_kernel_size=args.cca_seq_kernel_size,
 
 
354
  num_attention_heads_gdn=args.n_heads_gdn,
355
  num_key_value_heads_gdn=args.n_kv_heads_gdn,
356
  zero_centered_gate=args.zero_centered_gate,
357
  zero_centered_gate_type=args.zero_centered_gate_type,
358
  scalable_softmax=args.scalable_softmax,
 
 
 
359
  gate_attn=args.gate_attn,
360
  gate_gdn=args.gate_gdn,
361
  fused_loss_computation=args.fused_loss_computation,
@@ -380,15 +429,19 @@ config_hf = DragonConfig(
380
  norm_epsilon=args.eps_rmsnorm,
381
  use_cache=False,
382
  sliding_window_size=args.swa_window_size,
 
 
 
383
  rope_theta_local=args.rope_theta_local,
384
  uscaling_tau=args.uscaling_tau,
 
385
  )
386
 
387
  if resume_dir is None:
388
  model = DragonForCausalLM(config_hf)
389
  model = model.cuda()
390
  else:
391
- model = DragonForCausalLM.from_pretrained(resume_dir, torch_dtype=torch.bfloat16)
392
  model = model.cuda()
393
  print0(model)
394
 
@@ -421,12 +474,13 @@ print0(f"number of total parameters: {num_params}")
421
  uncompiled_model = model
422
  model = torch.compile(model, dynamic=True) if args.compile else model
423
  model.train()
424
- model = DDP(model, device_ids=[ddp_local_rank])
425
  raw_model = model.module
426
  ctx = torch.amp.autocast(device_type='cuda', dtype=torch.bfloat16)
427
 
428
  # load optimizers & schedulers.
429
  if args.use_uscaling:
 
430
  param_list = param_groups_mup(
431
  raw_model,
432
  base_lr_hidden=args.learning_rate,
@@ -435,9 +489,30 @@ if args.use_uscaling:
435
  base_lr_head=args.uscaling_mult_head*args.learning_rate if args.uscaling_mult_head > 0 else args.learning_rate,
436
  wd=args.weight_decay,
437
  )
438
- optimizer = torch.optim.AdamW(param_list, betas=(args.adam_beta1, args.adam_beta2), eps=args.adam_eps)
 
 
 
 
 
 
 
439
  else:
440
- optimizer = torch.optim.AdamW(raw_model.parameters(), lr=args.learning_rate, weight_decay=args.weight_decay, betas=(args.adam_beta1, args.adam_beta2), eps=args.adam_eps)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
441
  optimizers = [optimizer]
442
 
443
  def get_lr_wsd(num_iterations, warmup_iters, warmdown_iters, it):
@@ -478,12 +553,13 @@ WARMUP_SKIP = 10
478
 
479
  # begin training.
480
  train_loader.reset()
481
- tokenizer = transformers.AutoTokenizer.from_pretrained("openai-community/gpt2", use_fast=True) # for saving
 
482
  x, y = train_loader.next_batch()
483
 
484
- for iter_ in range(start_iter, args.total_iterations+1):
485
- last_iter = (iter_ == args.total_iterations)
486
- if iter_ == WARMUP_SKIP:
487
  training_time_ms = 0
488
  t0 = time.perf_counter()
489
  to_log = {}
@@ -521,7 +597,7 @@ for iter_ in range(start_iter, args.total_iterations+1):
521
  model.train()
522
 
523
  # log.
524
- print0(f'iteration:{iter_:0{len(str(args.total_iterations))}d}/{args.total_iterations} val_loss:{val_loss:.4f} train_time:{training_time_ms:.0f}ms')
525
  if master_process:
526
  wandb.log({"val_loss": val_loss}, step=iter_)
527
 
@@ -530,7 +606,7 @@ for iter_ in range(start_iter, args.total_iterations+1):
530
  t0 = time.perf_counter()
531
 
532
  # ----------- SAVING SECTION -----------
533
- if master_process and iter_ > start_iter and (last_iter or (args.save_every > 0 and iter_ % args.save_every == 0)):
534
  # stop the clock.
535
  torch.cuda.synchronize()
536
  training_time_ms += 1000 * (time.perf_counter() - t0)
@@ -584,14 +660,16 @@ for iter_ in range(start_iter, args.total_iterations+1):
584
  for opt, sched in zip(optimizers, schedulers):
585
  opt.step()
586
  sched.step()
 
 
587
  # null those gradients.
588
  model.zero_grad(set_to_none=True)
589
 
590
  # ----------- LOGGING SECTION -----------
591
  approx_training_time_ms = training_time_ms + 1000 * (time.perf_counter() - t0)
592
- avg_step_time = approx_training_time_ms / (iter_ + 1 - WARMUP_SKIP) if iter_ >= WARMUP_SKIP else 0
593
  extra = " ".join(f"{k}:{v}" for k, v in (to_log or {}).items())
594
- print0(f"iteration:{iter_+1:0{len(str(args.total_iterations))}d}/{args.total_iterations} train_loss:{train_loss.item():.4f} lr: {schedulers[0].get_last_lr()[0]:.4f} train_time:{approx_training_time_ms:.0f}ms step_avg:{avg_step_time:.2f}ms {extra}")
595
  if master_process:
596
  wandb.log({'train_loss': train_loss.item(), 'step_avg_time': avg_step_time, **{f'lr_{i}': sched.get_last_lr()[0] for i, sched in enumerate(schedulers)}, 'grad_norm': grad_norm.item(), **to_log}, step=iter_)
597
 
 
32
  # arch - general
33
  d_model : int = 768
34
  n_heads : int = 6 # head dim 128 suggested by @Grad62304977
35
+ head_dim: Optional[int] = None
36
  layers_config : str = 4*"lrdlr"
37
+ expand_factor : int = 2 # expand factor for Mamba/Dragon
38
+ rope_type_local: str = "rope" #p-rope
39
+ rope_type_global: str = "rope" #p-rope
40
  rope_theta_local: float = 10000.0
41
+ rope_theta_global: float = 0.0
42
  eps_rmsnorm: float = 1e-6
43
  mlp_expand: int = 4 # expand factor for MLP
44
  fused_loss_computation : bool = True # whether to use fused linear + cross entropy loss
 
46
  uscaling_tau: float = 0.2
47
  zero_centered_gamma: bool = False
48
  zero_centered_gate: bool = False
49
+ zero_centered_gate_type: int = 1 # 1, 2, 3, 4
50
  gate_attn: bool = False
51
  gate_gdn: bool = True
52
+ gate_type: str = "elementwise" # elementwise (one per dim), headwise (one per head), kimi (lora)
53
+ gate_act: str = "silu" # silu, sigmoid
54
+ scalar_proj_as_hidden_matrix: bool = True
55
+ normalization_type: str = "rmsnorm" # rmsnorm, seednorm
56
+ seednorm_wd: bool = True
57
+ mixer_gn: bool = True
58
+ mlp_linking : bool = False
59
 
60
  # attention related
61
  n_kv_heads : int = 0
 
67
  softcap_global_attn: float = 0.0
68
  qk_norm: bool = True
69
  scalable_softmax: bool = True
70
+ resformer : bool = False # Works only on f layers (DiffAttention)
71
+ token_shift_attn: bool = False
72
+ token_shift_gdn: bool = False
73
+ token_conv1d_attn: bool = False
74
+ token_conv1d_gdn: bool = True
75
  num_attention_heads_indexer: int = 8
76
  head_dim_indexer: int = 32
77
  dsa_q_lora_rank: int = 128
78
  dsa_topk: int = 512
 
79
  cca_seq_kernel_size: int = 4
 
80
  nsa_topk: int = 16
81
  nsa_block_size: int = 64
82
  nsa_window_size: int = 512
83
+ num_signal_heads_diff: Optional[int] = None
84
+ tpa_rank: int = 2
85
+ shrink_qk_da: int = 2
86
+ mla_kv_rank: int = 128
87
 
88
  # GDN related
89
  rope_gdn: Optional[str] = None # None, rope, (srope)
90
+ head_dim_gdn: Optional[int] = None
91
  n_heads_gdn: int = 0
92
  n_kv_heads_gdn: int = 0
93
+ shrink_qk_gdn: int = 2
94
+ kda_allow_neg_eigval: bool = False
95
+ kda_num_v_heads: Optional[int] = None
96
 
97
  # optim
98
  optim: str = "adamw" # adamw, spam, stable-spam, muon, muon_moonlight, splus
99
+ second_order_optim : Optional[str] = None # snoo
100
  batch_size: int = 8*64 # batch size, in sequences, across all devices
101
  device_batch_size: int = 64 # batch size, in sequences, per device
102
  total_iterations: int = 1000 # number of iterations to run
 
114
  init_std: float = 0.006
115
  patch_level_training: bool = False
116
  patch_level_training_size: int = 4
117
+ second_order_lr: float = 0.68
118
+ second_order_momentum: float = 0.37
119
+ second_order_interval: int = 25
120
 
121
  # data
122
  vocab_size: int = 50304
123
  sequence_length: int = 1024
 
 
 
124
  input_bin: Optional[str] = None
125
  input_val_bin: Optional[str] = None
126
 
 
235
  for mod in model.modules():
236
  if isinstance(mod, nn.Linear):
237
  pname = id2name.get(id(mod.weight), "")
238
+ is_scalar = getattr(mod, "is_scalar_weight", False)
239
  fan_in = mod.weight.shape[1]
240
  scale = 1 / math.sqrt(fan_in)
241
  if "lm_head" in pname:
242
  lr_scaled = base_lr_head
243
  wd_scaled = 0.0
244
+ elif is_scalar:
245
+ lr_scaled = base_lr_scalar
246
+ wd_scaled = 0.0
247
  else:
248
  lr_scaled = base_lr_hidden * scale
249
  wd_scaled = wd / lr_scaled
 
252
  seen.add(mod.weight)
253
 
254
  if mod.bias is not None:
255
+ groups.append({"params": [mod.bias], "lr": base_lr_scalar, "weight_decay": 0.0})
256
  seen.add(mod.bias)
257
 
258
  for p in model.parameters():
 
261
  pname = id2name.get(id(p), "<unnamed>")
262
 
263
  if "embedding" in pname:
264
+ #fan_out = p.shape[1] # nn.Embedding is transposed
265
  #lr_scaled = base_lr / math.sqrt(fan_out) # u-muP
266
  lr_scaled = base_lr_embed
267
  else:
268
  lr_scaled = base_lr_scalar
269
 
270
+ wd_scaled = 0.
271
+ if getattr(p, "requires_weight_decay", False):
272
+ wd_scaled = wd / lr_scaled
273
+
274
+ groups.append({"params": [p], "lr": lr_scaled, "weight_decay": wd_scaled})
275
 
276
  return groups
277
 
 
329
  with open(f'{logdir}/args.json', 'w') as f: json.dump(vars(args), f)
330
  with open(f'{logdir}/args.pkl', 'wb') as f: pickle.dump(args, f)
331
  def print0(s, console=True):
332
+ if not master_process: return
333
+ if console:
334
+ print(s)
335
+ try:
336
+ d=os.path.dirname(logfile); d and os.makedirs(d, exist_ok=True)
337
+ with open(logfile, "a", encoding="utf-8") as f: print(s, file=f)
338
+ except: pass
339
  if resume_dir is not None and args.load_arg_from_config:
340
  saved_args_path = os.path.join(os.path.dirname(resume_dir), "args.pkl")
341
  print0(f"Loading args from {saved_args_path}")
 
358
 
359
  # define convenience variables.
360
  B, T = args.device_batch_size, args.sequence_length
361
+ if args.patch_level_training:
362
+ T = args.patch_level_training_size * T
363
  assert args.batch_size % (B * ddp_world_size) == 0
364
  accumulation_steps = args.batch_size // (B * ddp_world_size)
365
 
366
  # load dataloaders.
367
+ #if args.patch_level_training:
368
+ # assert T % args.patch_level_training_size == 0, "sequence length must be divisible by patch level training size in reduced mode"
 
 
 
 
369
  train_loader = DistributedDataLoader(args.input_bin, B, T, ddp_rank, ddp_world_size)
370
  val_loader = DistributedDataLoader(args.input_val_bin, B, T, ddp_rank, ddp_world_size)
371
  print0(f"Training DataLoader: total number of tokens: {train_loader.ntok_total} across {len(train_loader.files)} files")
 
373
 
374
  # load model.
375
  config_hf = DragonConfig(
376
+ mla_kv_rank=args.mla_kv_rank,
377
+ rope_gdn=args.rope_gdn,
378
+ shrink_qk_da=args.shrink_qk_da,
379
+ shrink_qk_gdn=args.shrink_qk_gdn,
380
+ mixer_gn=args.mixer_gn,
381
+ kda_allow_neg_eigval=args.kda_allow_neg_eigval,
382
+ kda_num_v_heads=args.kda_num_v_heads,
383
+ seednorm_wd=args.seednorm_wd,
384
+ normalization_type=args.normalization_type,
385
+ tpa_rank=args.tpa_rank,
386
+ num_signal_heads_diff=args.num_signal_heads_diff,
387
+ scalar_proj_as_hidden_matrix=args.scalar_proj_as_hidden_matrix,
388
+ token_shift_attn=args.token_shift_attn,
389
+ token_shift_gdn=args.token_shift_gdn,
390
+ token_conv1d_attn=args.token_conv1d_attn,
391
+ token_conv1d_gdn=args.token_conv1d_gdn,
392
  patch_level_training=args.patch_level_training,
393
  patch_level_training_size=args.patch_level_training_size,
 
394
  nsa_topk=args.nsa_topk,
395
  nsa_block_size=args.nsa_block_size,
396
  nsa_window_size=args.nsa_window_size,
 
397
  cca_seq_kernel_size=args.cca_seq_kernel_size,
398
+ head_dim=args.head_dim,
399
+ head_dim_gdn=args.head_dim_gdn,
400
  num_attention_heads_gdn=args.n_heads_gdn,
401
  num_key_value_heads_gdn=args.n_kv_heads_gdn,
402
  zero_centered_gate=args.zero_centered_gate,
403
  zero_centered_gate_type=args.zero_centered_gate_type,
404
  scalable_softmax=args.scalable_softmax,
405
+ resformer=args.resformer,
406
+ gate_type=args.gate_type,
407
+ gate_act=args.gate_act,
408
  gate_attn=args.gate_attn,
409
  gate_gdn=args.gate_gdn,
410
  fused_loss_computation=args.fused_loss_computation,
 
429
  norm_epsilon=args.eps_rmsnorm,
430
  use_cache=False,
431
  sliding_window_size=args.swa_window_size,
432
+ rope_type_global=args.rope_type_global,
433
+ rope_type_local=args.rope_type_local,
434
+ rope_theta_global=args.rope_theta_global,
435
  rope_theta_local=args.rope_theta_local,
436
  uscaling_tau=args.uscaling_tau,
437
+ mlp_linking=args.mlp_linking
438
  )
439
 
440
  if resume_dir is None:
441
  model = DragonForCausalLM(config_hf)
442
  model = model.cuda()
443
  else:
444
+ model = DragonForCausalLM.from_pretrained(resume_dir, config=config_hf, torch_dtype=torch.bfloat16)
445
  model = model.cuda()
446
  print0(model)
447
 
 
474
  uncompiled_model = model
475
  model = torch.compile(model, dynamic=True) if args.compile else model
476
  model.train()
477
+ model = DDP(model, device_ids=[ddp_local_rank], find_unused_parameters=args.resformer)
478
  raw_model = model.module
479
  ctx = torch.amp.autocast(device_type='cuda', dtype=torch.bfloat16)
480
 
481
  # load optimizers & schedulers.
482
  if args.use_uscaling:
483
+ #assert args.optim == "adamw", "uscaling is only supported with AdamW optimizer currently"
484
  param_list = param_groups_mup(
485
  raw_model,
486
  base_lr_hidden=args.learning_rate,
 
489
  base_lr_head=args.uscaling_mult_head*args.learning_rate if args.uscaling_mult_head > 0 else args.learning_rate,
490
  wd=args.weight_decay,
491
  )
492
+ if args.optim == "adamw":
493
+ optimizer = torch.optim.AdamW(param_list, betas=(args.adam_beta1, args.adam_beta2), eps=args.adam_eps)
494
+ elif args.optim == "ademamix":
495
+ from .optimizers.Ademamix import AdEMAMix
496
+ beta3_warmup = alpha_warmup = args.total_iterations
497
+ optimizer = AdEMAMix(param_list, beta3_warmup=beta3_warmup, alpha_warmup=alpha_warmup, weight_decay=args.weight_decay)
498
+ else:
499
+ raise ValueError(f"Unknown optimizer for unit scaling: {args.optim}")
500
  else:
501
+ if args.optim == "adamw":
502
+ optimizer = torch.optim.AdamW(raw_model.parameters(), lr=args.learning_rate, weight_decay=args.weight_decay, betas=(args.adam_beta1, args.adam_beta2), eps=args.adam_eps)
503
+ elif args.optim == "ademamix":
504
+ from .optimizers.Ademamix import AdEMAMix
505
+
506
+ beta3_warmup = alpha_warmup = args.total_iterations
507
+ optimizer = AdEMAMix(raw_model.parameters(), lr=args.learning_rate, beta3_warmup=beta3_warmup, alpha_warmup=alpha_warmup, weight_decay=args.weight_decay)
508
+ else:
509
+ raise ValueError(f"Unknown Optimizer: {args.optim}")
510
+ if args.second_order_optim == "snoo":
511
+ from .optimizers.Snoo import Snoo
512
+ second_order_optim = Snoo(raw_model, lr=args.second_order_lr, momentum=args.second_order_momentum, k=args.second_order_interval)
513
+ else:
514
+ second_order_optim = None
515
+
516
  optimizers = [optimizer]
517
 
518
  def get_lr_wsd(num_iterations, warmup_iters, warmdown_iters, it):
 
553
 
554
  # begin training.
555
  train_loader.reset()
556
+ #tokenizer = transformers.AutoTokenizer.from_pretrained("openai-community/gpt2", use_fast=True) # for saving
557
+ tokenizer = transformers.AutoTokenizer.from_pretrained("/leonardo_work/BOOST_LCustodi/script/training/temp/hf_models/gpt2", use_fast=True)
558
  x, y = train_loader.next_batch()
559
 
560
+ for iter_ in range(start_iter, start_iter+args.total_iterations+1):
561
+ last_iter = (iter_ == start_iter+args.total_iterations)
562
+ if iter_ == start_iter+WARMUP_SKIP:
563
  training_time_ms = 0
564
  t0 = time.perf_counter()
565
  to_log = {}
 
597
  model.train()
598
 
599
  # log.
600
+ print0(f'iteration:{iter_:0{len(str(start_iter+args.total_iterations))}d}/{args.total_iterations} val_loss:{val_loss:.4f} train_time:{training_time_ms:.0f}ms')
601
  if master_process:
602
  wandb.log({"val_loss": val_loss}, step=iter_)
603
 
 
606
  t0 = time.perf_counter()
607
 
608
  # ----------- SAVING SECTION -----------
609
+ if master_process and (last_iter or (args.save_every > 0 and iter_ % args.save_every == 0)):
610
  # stop the clock.
611
  torch.cuda.synchronize()
612
  training_time_ms += 1000 * (time.perf_counter() - t0)
 
660
  for opt, sched in zip(optimizers, schedulers):
661
  opt.step()
662
  sched.step()
663
+ if second_order_optim:
664
+ second_order_optim.step()
665
  # null those gradients.
666
  model.zero_grad(set_to_none=True)
667
 
668
  # ----------- LOGGING SECTION -----------
669
  approx_training_time_ms = training_time_ms + 1000 * (time.perf_counter() - t0)
670
+ avg_step_time = approx_training_time_ms / (iter_ + 1 - WARMUP_SKIP) if iter_ >= start_iter+WARMUP_SKIP else 0
671
  extra = " ".join(f"{k}:{v}" for k, v in (to_log or {}).items())
672
+ print0(f"iteration:{iter_+1:0{len(str(start_iter+args.total_iterations))}d}/{args.total_iterations} train_loss:{train_loss.item():.4f} lr: {schedulers[0].get_last_lr()[0]:.4f} train_time:{approx_training_time_ms:.0f}ms step_avg:{avg_step_time:.2f}ms {extra}")
673
  if master_process:
674
  wandb.log({'train_loss': train_loss.item(), 'step_avg_time': avg_step_time, **{f'lr_{i}': sched.get_last_lr()[0] for i, sched in enumerate(schedulers)}, 'grad_norm': grad_norm.item(), **to_log}, step=iter_)
675