alexandretl commited on
Commit
d79da9a
·
1 Parent(s): 3b164a1

alpha normalize ademamix | mamba norms and gate | VWN | wnorm (nemotron-flash) | MG equivalence | fix IDM config saving | CCAv2 | MoBA | reduce lm head

Browse files
compute_loss.py ADDED
@@ -0,0 +1,422 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import glob
3
+ import pickle
4
+ from dataclasses import dataclass
5
+ from typing import Optional
6
+ import tyro
7
+ from tqdm.auto import tqdm
8
+ import numpy as np
9
+
10
+ import torch
11
+
12
+ from .configuration_dragon import DragonConfig
13
+ from .modeling_dragon import DragonForCausalLM
14
+
15
@dataclass
class Args:
    """CLI arguments for the standalone validation-loss script (parsed via tyro)."""
    # Checkpoint directory passed to `DragonForCausalLM.from_pretrained`;
    # its parent directory is expected to contain the pickled training args (args.pkl).
    load_dir: str
    # Path to a single validation .bin token shard (consumed once, stop_on_end=True).
    val_bin: str
19
+
20
@dataclass
class NanoArgs:
    """Full training-run configuration.

    Instances of this class are unpickled from the ``args.pkl`` saved at
    training time (see the script body below), so field names, order, and
    defaults must stay stable across versions or old checkpoints will not load.
    Most fields are forwarded verbatim into ``DragonConfig``.
    """
    resume_from: Optional[str] = None
    run_name : str = ""

    # arch - general
    d_model : int = 768
    n_heads : int = 6 # head dim 128 suggested by @Grad62304977
    head_dim: Optional[int] = None
    layers_config : str = 4*"lrdlr"  # per-layer mixer types, one char per layer
    expand_factor : int = 2 # expand factor for Mamba/Dragon
    rope_type_local: str = "" #p-rope
    rope_type_global: str = "" #p-rope
    rope_theta_local: float = 10000.0
    rope_theta_global: float = 0.0
    eps_rmsnorm: float = 1e-6
    mlp_expand: int = 4 # expand factor for MLP
    fused_loss_computation : bool = True # whether to use fused linear + cross entropy loss
    use_uscaling: bool = False
    uscaling_tau: float = 0.2
    zero_centered_gamma: bool = False
    zero_centered_gate: bool = False
    zero_centered_gate_type: int = 1 # 1, 2, 3, 4
    gate_attn: bool = False
    gate_gdn: bool = True
    gate_type: str = "elementwise" # elementwise (one per dim), headwise (one per head), kimi (lora)
    gate_act: str = "silu" # silu, sigmoid
    scalar_proj_as_hidden_matrix: bool = True
    normalization_type: str = "rmsnorm" # rmsnorm, seednorm
    seednorm_wd: bool = True
    seednorm_type: int = 1
    seednorm_rank: int = 1
    mixer_gn: bool = True
    mlp_linking : bool = False
    final_norm: bool = True
    layer_norm_scaling: bool = False # not read when using muP
    mlp_type: str = "simple" # simple, gated
    tie_lm_head: bool = False

    # MoE
    moe: bool = False
    moe_num_routed_experts: int = 2
    moe_routed_scaling_factor: float = 2.5
    moe_routed_intermediate_size: int = 768
    moe_shared_intermediate_size: int = 768

    # attention related
    n_kv_heads : int = 0  # 0 means "use n_heads" (see config_hf construction below)
    swa_window_size : int = 1024
    slw_warmup_iters: float = 0
    slw_start: int = 8 # window size at the start of training
    slw_increment: int = 64 # window size increment at each step
    softcap_local_attn: float = 0.0 # logit soft-capping for local attn logits, as per Gemma2 (0.0 = no soft-capping)
    softcap_global_attn: float = 0.0
    qk_norm: bool = True
    scalable_softmax: bool = True
    resformer : bool = False # Works only on f layers (DiffAttention)
    token_shift_attn: bool = False
    token_shift_gdn: bool = False
    token_conv1d_attn: bool = False
    token_conv1d_gdn: bool = True
    num_attention_heads_indexer: int = 8
    head_dim_indexer: int = 32
    dsa_q_lora_rank: int = 128
    dsa_topk: int = 512
    cca_seq_kernel_size: int = 4
    nsa_topk: int = 16
    nsa_block_size: int = 64
    nsa_window_size: int = 512
    num_signal_heads_diff: Optional[int] = None
    tpa_rank: int = 2
    shrink_qk_da: int = 2
    mla_kv_rank: int = 128

    # GDN related
    rope_gdn: Optional[str] = None # None, rope, (srope)
    head_dim_gdn: Optional[int] = None
    n_heads_gdn: int = 0
    n_kv_heads_gdn: int = 0
    shrink_qk_gdn: int = 2
    kda_allow_neg_eigval: bool = False
    kda_num_v_heads: Optional[int] = None
    mamba_mimo_dim: Optional[int] = 2
    mamba_ngroups: Optional[int] = 1
    mamba_d_state: int = 128
    mamba_headdim: int = 64
    mamba3_rope: bool = True
    mamba3_remove_BC_bias: bool = False
    mamba3_is_id_rms: bool = True
    mamba3_remove_conv: bool = True
    mamba3_is_A_dd: bool = True
    mamba3_add_trapezoid: bool = True

    # optim (not used by the eval script but required for unpickling)
    optim: str = "adamw" # adamw, spam, stable-spam, muon, muon_moonlight, splus
    second_order_optim : Optional[str] = None # snoo
    batch_size: int = 8*64 # batch size, in sequences, across all devices
    device_batch_size: int = 64 # batch size, in sequences, per device
    total_iterations: int = 1000 # number of iterations to run
    learning_rate: float = 1e-4
    weight_decay: float = 0.
    adam_beta1: float = 0.9
    adam_beta2: float = 0.95
    adam_eps: float = 1e-8
    warmup_iters: int = 200
    warmdown_iters: int = 3000
    warmdown_type: str = "linear" # linear, cosine
    grad_norm_clip: float = 1.0
    uscaling_mult_embed: float = 0
    uscaling_mult_scalar: float = 0
    uscaling_mult_head: float = 0
    init_std: float = 0.006
    patch_level_training: bool = False
    patch_level_training_size: int = 4
    second_order_lr: float = 0.68
    second_order_momentum: float = 0.37
    second_order_interval: int = 25

    # data
    vocab_size: int = 50304
    bos_id: int = 50256
    sequence_length: int = 1024
    intra_doc_masking: bool = False
    input_bin: Optional[str] = None
    input_val_bin: Optional[str] = None

    # evaluation and logging
    val_loss_every: int = 125
    val_iterations: int = 50 # 1 step = global bs * T tokens
    inspect_every: int = 0
    save_every: int = 1000
    log_dir: str = "logs/"
    wandb_project: str = "dragon_v1.5"
    wandb_name: Optional[str] = None
    log_wandb: bool = False

    load_arg_from_config: bool = True
    load_optim: bool = True
    load_sched: bool = True
    compile: bool = True
    compile_dynamic: bool = False

    # used during training
    slw_window: int = 0
164
+
165
+ def _peek_data_shard(filename):
166
+ with open(filename, "rb") as f:
167
+ header = np.frombuffer(f.read(256 * 4), dtype=np.int32)
168
+ if header[0] != 20240520:
169
+ print("ERROR: magic number mismatch in the data .bin file!")
170
+ print("---> HINT: Are you passing in a correct file with --input_bin?")
171
+ print("---> HINT: Dataset encoding changed recently, re-run data prepro or refer again to README")
172
+ exit(1)
173
+ assert header[1] == 1, "unsupported version"
174
+ ntok = int(header[2])
175
+ return ntok
176
+
177
+ def _load_data_shard(filename):
178
+ with open(filename, "rb") as f:
179
+ header = np.frombuffer(f.read(256 * 4), dtype=np.int32)
180
+ assert header[0] == 20240520, "magic number mismatch in the data .bin file"
181
+ assert header[1] == 1, "unsupported version"
182
+ ntok = int(header[2])
183
+ # memmap the token payload directly (uint16) after the 256*4B header
184
+ tokens = np.memmap(filename, dtype=np.uint16, mode="r", offset=256 * 4, shape=(ntok,))
185
+ assert tokens.size == ntok, "number of tokens read does not match header?"
186
+ return tokens
187
+
188
class DistributedDataLoader:
    """Shard-cycling token loader serving (inputs, targets) batches per rank.

    Each of ``num_processes`` ranks reads a disjoint, interleaved stripe of
    every shard matched by ``filename_pattern``. Targets are returned equal to
    the inputs — the model is presumably expected to shift labels internally
    (HF causal-LM convention); confirm against the model's loss computation.
    With ``stop_on_end=True`` the single shard is consumed once and
    ``next_batch`` raises StopIteration at the end (used for validation).
    """

    def __init__(self, filename_pattern, intra_doc_masking,B, T, process_rank, num_processes, bos_id, stop_on_end=False):
        self.process_rank = process_rank
        self.num_processes = num_processes
        self.intra_doc_masking = intra_doc_masking
        self.bos_id = bos_id
        self.B = B # micro batch size
        self.T = T
        self.stop_on_end = stop_on_end

        # glob files that match the pattern
        self.files = sorted(glob.glob(filename_pattern))
        assert len(self.files) > 0, f"did not find any files that match the pattern {filename_pattern}"
        if self.stop_on_end:
            assert len(self.files) == 1, "Pass a single .bin path (not a pattern) when stop_on_end=True."

        # load and validate all data shards, count number of tokens in total
        ntok_total = 0
        self.shard_ntoks = []
        for fname in self.files:
            shard_ntok = _peek_data_shard(fname)
            #print(f"shard {fname} has {shard_ntok} tokens")
            # every shard must hold at least one full global batch (+1 token)
            assert shard_ntok >= num_processes * B * T + 1
            self.shard_ntoks.append(shard_ntok)
            ntok_total += int(shard_ntok)
        self.ntok_total = ntok_total

        # kick things off
        self.reset()

    def reset(self, shard=0):
        # Rewind to the given shard; each rank starts at its own stripe offset.
        self.current_shard = shard
        self.current_position = self.process_rank * self.B * self.T
        self.tokens = _load_data_shard(self.files[self.current_shard])

    def advance(self): # advance to next data shard
        # Wraps around to shard 0 after the last one (infinite epoching).
        self.current_shard = (self.current_shard + 1) % len(self.files)
        self.current_position = self.process_rank * self.B * self.T
        self.tokens = _load_data_shard(self.files[self.current_shard])

        # Only rank 0 logs progress to avoid duplicated output.
        if self.process_rank == 0:
            shard_tokens = self.shard_ntoks[self.current_shard]
            cum_tokens = sum(self.shard_ntoks[: self.current_shard + 1])

            # Human-readable token counts (B = billions, M = millions).
            def _fmt(n):
                return f"{n/1e9:.2f}B" if n >= 1_000_000_000 else (
                    f"{n/1e6:.2f}M" if n >= 1_000_000 else str(n))

            print(
                f"Advancing to shard {self.current_shard}/{len(self.files)-1} "
                f"(this={_fmt(shard_tokens)} tok, cum={_fmt(cum_tokens)}/{_fmt(self.ntok_total)})"
            )

    def next_batch(self):
        """Return (x, y, cu_seqlens, max_seqlen, position_ids), all on CUDA
        (the last three are None unless intra_doc_masking is enabled)."""
        B = self.B
        T = self.T
        buf = self.tokens[self.current_position : self.current_position+B*T]
        # Copy out of the memmap and widen to int64 for torch embedding lookup.
        buf = np.asarray(buf, dtype=np.int64)
        x = torch.from_numpy(buf.reshape(B, T)) # inputs
        y = torch.from_numpy(buf.reshape(B, T)) # targets

        # compute cumulative document positions for intra-document masking
        cu = None
        maxlen = None
        position_ids = None
        if self.intra_doc_masking:
            # varlen attention metadata only supports a single sequence per batch
            assert self.B == 1
            # document boundaries = positions of the BOS token
            starts = (x == self.bos_id).nonzero(as_tuple=True)[1].to(torch.long)
            # ensure the window itself starts a document even if it begins mid-doc
            if starts.numel() == 0 or starts[0] != 0:
                starts = torch.cat([torch.zeros(1, dtype=torch.long), starts])
            ends = torch.cat([starts[1:], torch.tensor([x.numel()])])
            seqlens = (ends - starts).to(torch.int32)
            # cu_seqlens, max_seqlen.
            cu = torch.cat([torch.zeros(1, dtype=torch.int32), seqlens.cumsum(0)]).cuda().to(torch.int32)
            maxlen = int(seqlens.max())
            # position_ids.
            # token index relative to the start of its own document
            lengths = seqlens.to(torch.long)
            starts_per_token = torch.repeat_interleave(starts.to(torch.long), lengths)
            idx = torch.arange(T, device=x.device, dtype=torch.long)
            position_ids = (idx - starts_per_token).unsqueeze(0)

        # advance current position and load next shard if necessary
        self.current_position += B * T * self.num_processes
        if self.current_position + (B * T * self.num_processes + 1) > len(self.tokens):
            if self.stop_on_end:
                raise StopIteration
            else:
                self.advance()

        return x.cuda(), y.cuda(), cu, maxlen, position_ids
278
+
279
run_args = tyro.cli(Args)

# The training run pickles its full NanoArgs next to the checkpoint directory;
# reload it so the rebuilt config matches the checkpoint exactly.
saved_args_path = os.path.join(os.path.dirname(run_args.load_dir), "args.pkl")
print(f"Loading args from {saved_args_path}")
if not os.path.exists(saved_args_path):
    # Previously a missing file silently fell through and `print(args)` below
    # died with a NameError; fail early with an actionable message instead.
    raise FileNotFoundError(
        f"Saved training args not found at {saved_args_path}; "
        "pass --load_dir pointing inside the original run directory."
    )
with open(saved_args_path, "rb") as f:
    # NOTE(review): pickle is only safe on our own trusted checkpoints.
    saved_args = pickle.load(f)
args: NanoArgs = saved_args

print(args)

B, T = args.device_batch_size, args.sequence_length
# kept for parity with the training script; not used during evaluation
accumulation_steps = args.batch_size // (B * 1)

# stop_on_end=True: consume the validation file exactly once, then StopIteration.
val_loader = DistributedDataLoader(run_args.val_bin, False, B, T, 0, 1, args.bos_id, stop_on_end=True)
print(f"Validation DataLoader: total number of tokens: {val_loader.ntok_total} across {len(val_loader.files)} files")

# load model: rebuild the HF config from the saved training args so the
# checkpoint weights map onto an identically-shaped architecture.
config_hf = DragonConfig(
    tie_lm_head=args.tie_lm_head,
    mlp_type=args.mlp_type,
    layer_norm_scaling=args.layer_norm_scaling,
    mamba_d_state=args.mamba_d_state,
    mamba_headdim=args.mamba_headdim,
    mamba3_rope=args.mamba3_rope,
    mamba3_remove_BC_bias=args.mamba3_remove_BC_bias,
    mamba3_is_id_rms=args.mamba3_is_id_rms,
    mamba3_remove_conv=args.mamba3_remove_conv,
    mamba3_is_A_dd=args.mamba3_is_A_dd,
    mamba3_add_trapezoid=args.mamba3_add_trapezoid,
    moe=args.moe,
    moe_num_routed_experts=args.moe_num_routed_experts,
    moe_routed_scaling_factor=args.moe_routed_scaling_factor,
    moe_routed_intermediate_size=args.moe_routed_intermediate_size,
    moe_shared_intermediate_size=args.moe_shared_intermediate_size,
    intra_doc_masking=args.intra_doc_masking,
    seednorm_rank=args.seednorm_rank,
    seednorm_type=args.seednorm_type,
    final_norm=args.final_norm,
    mla_kv_rank=args.mla_kv_rank,
    rope_gdn=args.rope_gdn,
    shrink_qk_da=args.shrink_qk_da,
    shrink_qk_gdn=args.shrink_qk_gdn,
    mixer_gn=args.mixer_gn,
    kda_allow_neg_eigval=args.kda_allow_neg_eigval,
    kda_num_v_heads=args.kda_num_v_heads,
    seednorm_wd=args.seednorm_wd,
    normalization_type=args.normalization_type,
    tpa_rank=args.tpa_rank,
    num_signal_heads_diff=args.num_signal_heads_diff,
    scalar_proj_as_hidden_matrix=args.scalar_proj_as_hidden_matrix,
    token_shift_attn=args.token_shift_attn,
    token_shift_gdn=args.token_shift_gdn,
    token_conv1d_attn=args.token_conv1d_attn,
    token_conv1d_gdn=args.token_conv1d_gdn,
    patch_level_training=args.patch_level_training,
    patch_level_training_size=args.patch_level_training_size,
    nsa_topk=args.nsa_topk,
    nsa_block_size=args.nsa_block_size,
    nsa_window_size=args.nsa_window_size,
    cca_seq_kernel_size=args.cca_seq_kernel_size,
    head_dim=args.head_dim,
    head_dim_gdn=args.head_dim_gdn,
    num_attention_heads_gdn=args.n_heads_gdn,
    num_key_value_heads_gdn=args.n_kv_heads_gdn,
    zero_centered_gate=args.zero_centered_gate,
    zero_centered_gate_type=args.zero_centered_gate_type,
    scalable_softmax=args.scalable_softmax,
    mamba_mimo_dim=args.mamba_mimo_dim,
    mamba_ngroups=args.mamba_ngroups,
    resformer=args.resformer,
    gate_type=args.gate_type,
    gate_act=args.gate_act,
    gate_attn=args.gate_attn,
    gate_gdn=args.gate_gdn,
    fused_loss_computation=args.fused_loss_computation,
    qk_norm=args.qk_norm,
    num_attention_heads_indexer=args.num_attention_heads_indexer,
    head_dim_indexer=args.head_dim_indexer,
    dsa_q_lora_rank=args.dsa_q_lora_rank,
    dsa_topk=args.dsa_topk,
    zero_centered_gamma=args.zero_centered_gamma,
    vocab_size=args.vocab_size,
    max_position_embeddings=args.sequence_length,
    use_uscaling=args.use_uscaling,
    hidden_size=args.d_model,
    intermediate_size=args.d_model * args.mlp_expand,
    expand_factor=args.expand_factor,
    layers_config=args.layers_config,
    num_attention_heads=args.n_heads,
    # n_kv_heads == 0 is the training-side sentinel for "no GQA, use n_heads"
    num_key_value_heads=args.n_kv_heads if args.n_kv_heads > 0 else args.n_heads,
    initializer_range=args.init_std,
    softcap_local_attn=args.softcap_local_attn,
    softcap_global_attn=args.softcap_global_attn,
    norm_epsilon=args.eps_rmsnorm,
    use_cache=False,
    sliding_window_size=args.swa_window_size,
    rope_type_global=args.rope_type_global,
    rope_type_local=args.rope_type_local,
    rope_theta_global=args.rope_theta_global,
    rope_theta_local=args.rope_theta_local,
    uscaling_tau=args.uscaling_tau,
    mlp_linking=args.mlp_linking
)

model = DragonForCausalLM.from_pretrained(run_args.load_dir, config=config_hf, torch_dtype=torch.bfloat16)
model = model.cuda()

model = torch.compile(model, dynamic=args.compile_dynamic) if args.compile else model
model.eval()
ctx = torch.amp.autocast(device_type='cuda', dtype=torch.bfloat16)

val_loader.reset()
# Upper bound on full batches in the single validation shard (num_processes == 1 here).
total_steps = (val_loader.shard_ntoks[val_loader.current_shard] - 1) // (B * T * val_loader.num_processes)
pbar = tqdm(total=total_steps, desc="Validating", unit="step")
# fp32 scalar accumulator on device to avoid bf16 precision loss over many steps
val_loss_sum = torch.zeros((), device="cuda", dtype=torch.float32)
n_steps = 0
tok_per_step = B * T

with torch.no_grad():
    while True:
        try:
            inputs, targets, cu, maxlen, position_ids = val_loader.next_batch()
        except StopIteration:
            break
        with ctx:
            step_loss = model(
                input_ids=inputs,
                labels=targets,
                just_loss=True,
                cu_seqlens=cu,
                max_seqlen=maxlen,
                position_ids=position_ids,
            ).loss.detach()
        val_loss_sum += step_loss
        n_steps += 1
        avg = (val_loss_sum / n_steps).item()  # forces a device sync; acceptable at eval pace
        pbar.update(1)
        pbar.set_postfix(avg_loss=f"{avg:.4f}", ppl=f"{np.exp(avg):.2f}")
pbar.close()

assert n_steps > 0, "No batches read from the file; check B/T vs file size."
val_loss = (val_loss_sum / n_steps).item()
print(f"Validation Loss: {val_loss:.6f}. Perplexity: {np.exp(val_loss):.6f} (steps={n_steps}, tokens={n_steps*tok_per_step})")
configuration_dragon.py CHANGED
@@ -92,6 +92,15 @@ class DragonConfig(PretrainedConfig):
92
 
93
  def __init__(
94
  self,
 
 
 
 
 
 
 
 
 
95
  tie_lm_head: bool = False,
96
  mlp_type: str = "simple",
97
  layer_norm_scaling: bool = False,
@@ -103,6 +112,7 @@ class DragonConfig(PretrainedConfig):
103
  mamba3_remove_conv: bool = True,
104
  mamba3_is_A_dd: bool = True,
105
  mamba3_add_trapezoid: bool = True,
 
106
  moe: bool = False,
107
  moe_num_routed_experts: int = 2,
108
  moe_routed_scaling_factor: float = 2.5,
@@ -116,6 +126,7 @@ class DragonConfig(PretrainedConfig):
116
  shrink_qk_da: int = 2,
117
  shrink_qk_gdn: int = 2,
118
  mixer_gn: bool = True,
 
119
  kda_allow_neg_eigval: bool = False,
120
  kda_num_v_heads: Optional[int] = None,
121
  seednorm_wd: bool = True,
@@ -197,6 +208,15 @@ class DragonConfig(PretrainedConfig):
197
  mlp_linking=False,
198
  **kwargs,
199
  ):
 
 
 
 
 
 
 
 
 
200
  self.tie_lm_head = tie_lm_head
201
  self.mlp_type = mlp_type
202
  self.layer_norm_scaling = layer_norm_scaling
@@ -208,6 +228,7 @@ class DragonConfig(PretrainedConfig):
208
  self.mamba3_remove_conv = mamba3_remove_conv
209
  self.mamba3_is_A_dd = mamba3_is_A_dd
210
  self.mamba3_add_trapezoid = mamba3_add_trapezoid
 
211
  self.moe = moe
212
  self.moe_num_routed_experts = moe_num_routed_experts
213
  self.moe_routed_scaling_factor = moe_routed_scaling_factor
@@ -221,6 +242,7 @@ class DragonConfig(PretrainedConfig):
221
  self.shrink_qk_da = shrink_qk_da
222
  self.shrink_qk_gdn = shrink_qk_gdn
223
  self.mixer_gn = mixer_gn
 
224
  self.kda_allow_neg_eigval = kda_allow_neg_eigval
225
  self.kda_num_v_heads = kda_num_v_heads
226
  self.seednorm_wd = seednorm_wd
 
92
 
93
  def __init__(
94
  self,
95
+ reduce_lm_head: int = 0,
96
+ dataset_type: str = "hf",
97
+ vwn: bool = False,
98
+ vwn_m: int = 2,
99
+ vwn_n: int = 3,
100
+ vwn_wd_alpha_beta: bool = False,
101
+ vwn_dynamic: bool = True,
102
+ legacy_gate: bool = False,
103
+ init_gpt2: bool = False,
104
  tie_lm_head: bool = False,
105
  mlp_type: str = "simple",
106
  layer_norm_scaling: bool = False,
 
112
  mamba3_remove_conv: bool = True,
113
  mamba3_is_A_dd: bool = True,
114
  mamba3_add_trapezoid: bool = True,
115
+ mamba3_postgate_norm: bool = False,
116
  moe: bool = False,
117
  moe_num_routed_experts: int = 2,
118
  moe_routed_scaling_factor: float = 2.5,
 
126
  shrink_qk_da: int = 2,
127
  shrink_qk_gdn: int = 2,
128
  mixer_gn: bool = True,
129
+ gate_before_norm: bool = True,
130
  kda_allow_neg_eigval: bool = False,
131
  kda_num_v_heads: Optional[int] = None,
132
  seednorm_wd: bool = True,
 
208
  mlp_linking=False,
209
  **kwargs,
210
  ):
211
+ self.reduce_lm_head = reduce_lm_head
212
+ self.dataset_type = dataset_type
213
+ self.vwn = vwn
214
+ self.vwn_m = vwn_m
215
+ self.vwn_n = vwn_n
216
+ self.vwn_wd_alpha_beta = vwn_wd_alpha_beta
217
+ self.vwn_dynamic = vwn_dynamic
218
+ self.legacy_gate = legacy_gate
219
+ self.init_gpt2 = init_gpt2
220
  self.tie_lm_head = tie_lm_head
221
  self.mlp_type = mlp_type
222
  self.layer_norm_scaling = layer_norm_scaling
 
228
  self.mamba3_remove_conv = mamba3_remove_conv
229
  self.mamba3_is_A_dd = mamba3_is_A_dd
230
  self.mamba3_add_trapezoid = mamba3_add_trapezoid
231
+ self.mamba3_postgate_norm = mamba3_postgate_norm
232
  self.moe = moe
233
  self.moe_num_routed_experts = moe_num_routed_experts
234
  self.moe_routed_scaling_factor = moe_routed_scaling_factor
 
242
  self.shrink_qk_da = shrink_qk_da
243
  self.shrink_qk_gdn = shrink_qk_gdn
244
  self.mixer_gn = mixer_gn
245
+ self.gate_before_norm = gate_before_norm
246
  self.kda_allow_neg_eigval = kda_allow_neg_eigval
247
  self.kda_num_v_heads = kda_num_v_heads
248
  self.seednorm_wd = seednorm_wd
modeling_dragon.py CHANGED
@@ -21,6 +21,11 @@ from fla.ops.nsa.parallel import parallel_nsa
21
 
22
  from flash_attn.modules.mlp import GatedMlp
23
 
 
 
 
 
 
24
  try:
25
  from mamba_ssm.ops.triton.ssd_combined import mamba_chunk_scan_combined
26
  except ImportError:
@@ -54,6 +59,8 @@ try:
54
  except ImportError:
55
  chunk_kda, fused_recurrent_kda, fused_kda_gate, prepare_sequence_ids = None, None, None, None
56
 
 
 
57
  from torch.compiler import disable
58
 
59
  logger = logging.get_logger(__name__)
@@ -268,6 +275,13 @@ class DragonLinear(nn.Linear):
268
  out = super().forward(x)
269
  return ScaledGrad.apply(out, self.alpha_fwd, self.alpha_bwd)
270
 
 
 
 
 
 
 
 
271
  class HybridDragonDynamicCache(DynamicCache):
272
  """
273
  A dynamic cache that handle both the attention cache (which has a seq_len dimension) and the GDN cache
@@ -299,6 +313,10 @@ class HybridDragonDynamicCache(DynamicCache):
299
  self.q_conv_caches = []
300
  self.k_conv_caches = []
301
  self.v_conv_caches = []
 
 
 
 
302
 
303
  for idx, layer_type in enumerate(config.layers_config):
304
  if not layer_type == "r":
@@ -313,6 +331,8 @@ class HybridDragonDynamicCache(DynamicCache):
313
  self.q_conv_caches.append(None)
314
  self.k_conv_caches.append(None)
315
  self.v_conv_caches.append(None)
 
 
316
 
317
  self.window_size = config.sliding_window_size
318
  self.layers_config = config.layers_config
@@ -359,6 +379,15 @@ class HybridDragonDynamicCache(DynamicCache):
359
 
360
  def set_prev_hidden(self, layer_idx, h):
361
  self.cca_prev_hidden[layer_idx] = h
 
 
 
 
 
 
 
 
 
362
 
363
  # kv shift
364
  def get_last_kv(self, layer_idx):
@@ -568,6 +597,7 @@ class DragonAttention(nn.Module):
568
 
569
  projection_dim = self.head_dim * (self.num_attention_heads + 2 * (0 if reuse_kv else self.num_key_value_heads))
570
  self.linear_qkv = DragonLinear(config, config.hidden_size, projection_dim, bias=False)
 
571
 
572
  if self.config.token_shift_attn:
573
  if self.config.scalar_proj_as_hidden_matrix:
@@ -755,6 +785,187 @@ class DragonAttention(nn.Module):
755
 
756
  return attn_output, last_key_states, last_value_states
757
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
758
  class DragonTensorProductAttention(nn.Module):
759
  """
760
  Multi-headed attention from 'Attention Is All You Need' paper.
@@ -785,6 +996,8 @@ class DragonTensorProductAttention(nn.Module):
785
  self.W_A_v = DragonLinear(config, self.hidden_size, self.num_attention_heads * self.rank, bias=False)
786
  self.W_B_k = DragonLinear(config, self.hidden_size, self.rank * self.head_dim, bias=False)
787
  self.W_B_v = DragonLinear(config, self.hidden_size, self.rank * self.head_dim, bias=False)
 
 
788
 
789
  if self.config.token_shift_attn:
790
  if self.config.scalar_proj_as_hidden_matrix:
@@ -1156,6 +1369,246 @@ class DragonCompressedConvolutionalAttention(nn.Module):
1156
 
1157
  return attn_output, None, None
1158
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1159
  class DragonNativeSparseAttention(nn.Module):
1160
  """
1161
  Multi-headed attention from 'Attention Is All You Need' paper.
@@ -1696,6 +2149,7 @@ class DragonDifferentialAttention(nn.Module):
1696
 
1697
  projection_dim = self.head_qk_dim * self.num_attention_heads + self.head_qk_dim * self.num_key_value_heads + (self.head_v_dim * self.num_noise_heads//2)
1698
  self.linear_qkv = DragonLinear(config, config.hidden_size, projection_dim, bias=False)
 
1699
 
1700
  if self.config.token_shift_attn:
1701
  if self.config.scalar_proj_as_hidden_matrix:
@@ -2373,6 +2827,8 @@ class DragonDifferentialTensorProductAttention(nn.Module):
2373
  self.W_A_v = DragonLinear(config, self.hidden_size, self.num_noise_heads * self.rank, bias=False)
2374
  self.W_B_k = DragonLinear(config, self.hidden_size, self.rank * self.head_qk_dim, bias=False)
2375
  self.W_B_v = DragonLinear(config, self.hidden_size, self.rank * self.head_v_dim, bias=False)
 
 
2376
 
2377
  if self.config.token_shift_attn:
2378
  if self.config.scalar_proj_as_hidden_matrix:
@@ -3161,12 +3617,29 @@ class DragonGatedDeltaNet(nn.Module):
3161
  self.num_attention_heads*self.dk + self.n_kv_heads*self.dk + self.n_kv_heads*self.dv,
3162
  bias=False
3163
  )
 
3164
  self.linear_ba = DragonLinear(
3165
  config, config.hidden_size,
3166
  self.num_attention_heads + self.num_attention_heads, #+ self.num_attention_heads*self.dv, # b(H), a(H), g(H*dv)
3167
  bias=False
3168
  )
3169
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
3170
  dt_min = config.time_step_min
3171
  dt_max = config.time_step_max
3172
  dt_init_floor = config.time_step_floor
@@ -3181,11 +3654,13 @@ class DragonGatedDeltaNet(nn.Module):
3181
  inv_dt = dt + torch.log(-torch.expm1(-dt))
3182
  with torch.no_grad():
3183
  self.dt_bias = nn.Parameter(inv_dt)
 
3184
 
3185
  assert A_init_range[0] > 0 and A_init_range[1] >= A_init_range[0]
3186
  A = torch.empty(self.n_heads_local, dtype=torch.float32).uniform_(*A_init_range)
3187
  A_log = torch.log(A) # Keep A_log in fp32
3188
  self.A_log = nn.Parameter(A_log)
 
3189
 
3190
  if self.config.rope_gdn == "rope":
3191
  self.rope_proj = DragonLinear(config, config.hidden_size, self.dk//4, bias=False)
@@ -3348,6 +3823,11 @@ class DragonGatedDeltaNet(nn.Module):
3348
  use_qk_l2norm_in_kernel=True
3349
  ) # (B L H dv)
3350
 
 
 
 
 
 
3351
  # update GDN cache
3352
  if cache_params is not None:
3353
  cache_params.ssm_caches[self.layer_idx] = ssm_cache
@@ -3381,6 +3861,9 @@ class DragonKimiDeltaAttention(nn.Module):
3381
  self.q_proj = DragonLinear(config, config.hidden_size, self.key_dim, bias=False)
3382
  self.k_proj = DragonLinear(config, config.hidden_size, self.key_dim, bias=False)
3383
  self.v_proj = DragonLinear(config, config.hidden_size, self.value_dim, bias=False)
 
 
 
3384
 
3385
  self.q_conv1d = ShortConvolution(
3386
  hidden_size=self.key_dim,
@@ -3413,10 +3896,21 @@ class DragonKimiDeltaAttention(nn.Module):
3413
  self.A_log = nn.Parameter(torch.log(torch.empty(self.num_q_heads, dtype=torch.float32).uniform_(1, 16)))
3414
  self.dt_bias = nn.Parameter(torch.zeros(self.key_dim, dtype=torch.float32))
3415
 
3416
- """self.g_proj = nn.Sequential(
3417
- DragonLinear(config, config.hidden_size, self.head_v_dim, bias=False),
3418
- DragonLinear(config, self.head_v_dim, self.value_dim, bias=True),
3419
- )"""
 
 
 
 
 
 
 
 
 
 
 
3420
 
3421
  @disable
3422
  def _kda_gate_call(self, g, A_log, head_k_dim, g_bias):
@@ -3427,6 +3921,7 @@ class DragonKimiDeltaAttention(nn.Module):
3427
  hidden_states: torch.Tensor,
3428
  position_embeddings: tuple[torch.Tensor, torch.Tensor],
3429
  cache_params: Optional[HybridDragonDynamicCache] = None,
 
3430
  **kwargs,
3431
  ):
3432
  _, q_len, _ = hidden_states.shape
@@ -3443,20 +3938,26 @@ class DragonKimiDeltaAttention(nn.Module):
3443
  conv_state_k = cache_params.k_conv_caches[self.layer_idx]
3444
  conv_state_v = cache_params.v_conv_caches[self.layer_idx]
3445
 
 
 
 
3446
  q, conv_state_q = self.q_conv1d(
3447
  x=self.q_proj(hidden_states),
3448
  cache=conv_state_q,
3449
  output_final_state=cache_params is not None,
 
3450
  )
3451
  k, conv_state_k = self.k_conv1d(
3452
  x=self.k_proj(hidden_states),
3453
  cache=conv_state_k,
3454
  output_final_state=cache_params is not None,
 
3455
  )
3456
  v, conv_state_v = self.v_conv1d(
3457
  x=self.v_proj(hidden_states),
3458
  cache=conv_state_v,
3459
  output_final_state=cache_params is not None,
 
3460
  )
3461
 
3462
  g = self.f_proj(hidden_states)
@@ -3482,6 +3983,7 @@ class DragonKimiDeltaAttention(nn.Module):
3482
  initial_state=None,
3483
  output_final_state=cache_params is not None,
3484
  use_qk_l2norm_in_kernel=True,
 
3485
  )
3486
  elif mode == 'fused_recurrent':
3487
  o, ssm_cache = fused_recurrent_kda(
@@ -3500,6 +4002,11 @@ class DragonKimiDeltaAttention(nn.Module):
3500
  #o = o * F.silu(rearrange(self.g_proj(hidden_states), '... (h d) -> ... h d', d=self.head_v_dim))
3501
  # TODO: other types of gates? as well as ZCG?
3502
 
 
 
 
 
 
3503
  if cache_params is not None:
3504
  cache_params.ssm_caches[self.layer_idx] = ssm_cache
3505
  cache_params.q_conv_caches[self.layer_idx] = conv_state_q
@@ -3549,8 +4056,8 @@ class DragonMamba3(nn.Module):
3549
  if config.mamba3_rope:
3550
  self.rope_proj = DragonLinear(config, self.d_model, self.num_rope_angles, bias=False)
3551
 
3552
- # Order: [z, x, B, C, dt]
3553
- d_in_proj = 2 * self.d_inner + 2 * self.d_state * self.ngroups + self.nheads
3554
 
3555
  if self.config.mamba3_is_A_dd:
3556
  self.A_proj = DragonLinear(config, self.d_model, self.nheads, bias=False, dtype=torch.float32)
@@ -3575,6 +4082,7 @@ class DragonMamba3(nn.Module):
3575
  self.dt_bias._no_weight_decay = True
3576
 
3577
  self.in_proj = DragonLinear(config, self.d_model, d_in_proj, bias=self.bias)
 
3578
 
3579
  self.B_bias, self.C_bias = None, None
3580
  if not config.mamba3_remove_BC_bias:
@@ -3604,18 +4112,36 @@ class DragonMamba3(nn.Module):
3604
  self.D = nn.Parameter(torch.ones(self.nheads))
3605
  self.D._no_weight_decay = True
3606
 
3607
- def forward(
 
 
 
 
 
 
 
 
 
 
3608
  self,
3609
  hidden_states: torch.Tensor,
3610
  cache_params: Optional[HybridDragonDynamicCache] = None,
 
3611
  **kwargs
3612
  ):
 
 
 
 
 
 
 
 
3613
  # Apply in_proj
3614
- zxBCdt = self.in_proj(hidden_states)
3615
- z, xBC, dd_dt = torch.split(
3616
- zxBCdt,
3617
  [
3618
- self.d_inner,
3619
  self.d_inner + 2 * self.d_state * self.ngroups,
3620
  self.nheads,
3621
  ],
@@ -3628,12 +4154,17 @@ class DragonMamba3(nn.Module):
3628
  _A = -torch.exp(self.A_log).unsqueeze(0).unsqueeze(0)
3629
  dt = F.softplus(dd_dt + self.dt_bias) # (B, L, N)
3630
 
 
 
 
 
3631
  if not self.config.mamba3_remove_conv:
3632
  xBC = causal_conv1d_fn(
3633
  x=xBC.transpose(1, 2),
3634
  weight=rearrange(self.conv1d.weight, "d 1 w -> d w"),
3635
  bias=self.conv1d.bias,
3636
  activation=self.activation,
 
3637
  ).transpose(1, 2) # (B, L, self.d_inner + 2 * ngroups * d_state)
3638
 
3639
  x, B, C = torch.split(
@@ -3699,10 +4230,6 @@ class DragonMamba3(nn.Module):
3699
 
3700
  x_scalar = (gamma_arr*_alpha_arr).to(torch.bfloat16)
3701
 
3702
- ssm_cache = None
3703
- if cache_params is not None:
3704
- ssm_cache = cache_params.ssm_caches[self.layer_idx]
3705
-
3706
  out = mamba_chunk_scan_discretized_combined(
3707
  x=x.bfloat16(),
3708
  A=A,
@@ -3714,19 +4241,26 @@ class DragonMamba3(nn.Module):
3714
  CB_sum=CB_sum,
3715
  D=self.D,
3716
  z=None,
3717
- initial_states=ssm_cache,
3718
- return_final_states=cache_params is not None,
 
3719
  )
 
 
 
 
 
 
 
 
 
 
 
 
 
3720
 
3721
- if cache_params is not None:
3722
- y, ssm_cache = out
3723
- cache_params.ssm_caches[self.layer_idx] = ssm_cache
3724
- else:
3725
- y = out
3726
-
3727
- y = rearrange(y, "b l h p -> b l (h p)")
3728
- y = y*self.act(z)
3729
- y = rearrange(y, "b l (h p) -> b l h p", h=self.nheads).to(x.dtype)
3730
 
3731
  return y, None, None
3732
 
@@ -3747,6 +4281,7 @@ class DragonMamba2(nn.Module):
3747
  # Order: [x, B, C, dt]
3748
  d_in_proj = self.d_inner + 2 * self.ngroups * self.d_state + self.nheads
3749
  self.in_proj = DragonLinear(config, self.d_model, d_in_proj, bias=False)
 
3750
 
3751
  if not self.config.mamba3_remove_conv:
3752
  conv_dim = self.d_inner + 2 * self.ngroups * self.d_state
@@ -3784,6 +4319,15 @@ class DragonMamba2(nn.Module):
3784
  self.D = nn.Parameter(torch.ones(self.nheads))
3785
  self.D._no_weight_decay = True
3786
 
 
 
 
 
 
 
 
 
 
3787
  def forward(self, hidden_states, **kwargs):
3788
  """
3789
  u: (B, L, D)
@@ -3830,6 +4374,12 @@ class DragonMamba2(nn.Module):
3830
  initial_states=None,
3831
  )
3832
 
 
 
 
 
 
 
3833
  return y, None, None
3834
 
3835
  class DragonMamba3Mimo(nn.Module):
@@ -3844,11 +4394,13 @@ class DragonMamba3Mimo(nn.Module):
3844
  "when creating this class."
3845
  )
3846
 
 
 
3847
  self.d_model = config.hidden_size
3848
- self.d_state = 64
3849
  self.conv_init = None
3850
  self.expand = 2
3851
- self.headdim = 128
3852
  self.ngroups = config.mamba_ngroups
3853
  self.activation = "swish"
3854
  self.bias = False
@@ -3863,14 +4415,12 @@ class DragonMamba3Mimo(nn.Module):
3863
  self.dt_init_floor = 1e-4
3864
  self.mimo_dim = config.mamba_mimo_dim
3865
  self.mimo_proj_block_order = 1
3866
-
3867
 
3868
  self.d_inner = int(self.expand * self.d_model)
3869
  assert self.d_inner % self.headdim == 0
3870
  self.nheads = self.d_inner // self.headdim
3871
  self.dr_out_dim = self.d_inner // self.mimo_proj_block_order
3872
 
3873
-
3874
  self.split_tensor_size = int(self.d_state * self.rope_fraction)
3875
  if self.split_tensor_size % 2 != 0:
3876
  self.split_tensor_size -= 1
@@ -3896,6 +4446,7 @@ class DragonMamba3Mimo(nn.Module):
3896
  self.dt_bias._no_weight_decay = True
3897
 
3898
  self.in_proj = DragonLinear(config, self.d_model, d_in_proj, bias=self.bias)
 
3899
 
3900
  self.B_bias = nn.Parameter(torch.ones((self.mimo_dim, self.nheads, self.d_state)), requires_grad=True)
3901
  self.C_bias = nn.Parameter(torch.ones((self.mimo_dim, self.nheads, self.d_state)), requires_grad=True)
@@ -3927,11 +4478,14 @@ class DragonMamba3Mimo(nn.Module):
3927
  self.in_proj_mimo_z = nn.Parameter(in_proj_mimo_z_init_weights, requires_grad=True)
3928
  self.out_proj_mimo = nn.Parameter(out_proj_mimo_init_weights, requires_grad=True)
3929
 
3930
-
3931
  # D "skip" parameter
3932
  self.D = nn.Parameter(torch.ones(self.nheads))
3933
  self.D._no_weight_decay = True
3934
 
 
 
 
 
3935
  def forward(self, hidden_states, **kwargs):
3936
  # Apply in_proj
3937
  zxBCdt = self.in_proj(hidden_states)
@@ -4024,7 +4578,7 @@ class DragonMamba3Mimo(nn.Module):
4024
  _beta_arr = torch.roll(beta_arr, shifts=-1, dims=1)
4025
 
4026
  x_scalar = (gamma_arr*_alpha_arr + _beta_arr).to(torch.bfloat16)
4027
-
4028
  z = rearrange(z, "b l r (h p) -> b l r h p", p=self.headdim)
4029
 
4030
  y = mamba_mimo_chunk_scan_discretized_fused_combined(
@@ -4037,10 +4591,15 @@ class DragonMamba3Mimo(nn.Module):
4037
  gamma=gamma_arr,
4038
  CB_sum=CB_sum,
4039
  D=self.D,
4040
- z=z,
4041
  )
4042
 
4043
  y = rearrange(y, "b l r h p -> b l r (h p)")
 
 
 
 
 
4044
  #if seqlen_og is not None:
4045
  # y = rearrange(y, "b l r d -> (b l) r d")
4046
 
@@ -4067,7 +4626,9 @@ class DragonMLP(nn.Module):
4067
  self.lambda1 = nn.Parameter(torch.zeros(self.link_size)) # sigmoid->0.5
4068
  else :
4069
  self.fc_1 = DragonLinear(config, config.hidden_size, intermediate_size, bias=False)
 
4070
  self.fc_2 = DragonLinear(config, intermediate_size, config.hidden_size, bias=False)
 
4071
  self.register_buffer("_2_sqrt_5", torch.tensor(2/math.sqrt(5)) if config.use_uscaling else torch.tensor(1.), persistent=False)
4072
 
4073
  def forward(self, hidden_states):
@@ -4096,7 +4657,9 @@ class DragonGatedMLP(nn.Module):
4096
  self.intermediate_size = intermediate_size
4097
 
4098
  self.fc_1 = DragonLinear(config, config.hidden_size, num_active_experts*self.intermediate_size, bias=False)
 
4099
  self.fc_2 = DragonLinear(config, num_active_experts*self.intermediate_size, config.hidden_size, bias=False)
 
4100
  self.register_buffer("_2_sqrt_5", torch.tensor(2/math.sqrt(5)) if config.use_uscaling else torch.tensor(1.), persistent=False)
4101
 
4102
  def forward(self, hidden_states, gates):
@@ -4174,6 +4737,11 @@ class DragonMonoBlock(GradientCheckpointingLayer):
4174
  head_dim = self.mixer.head_dim
4175
  num_attention_heads = self.mixer.num_q_heads
4176
  use_gate = config.gate_attn
 
 
 
 
 
4177
  elif layer_type == 'n':
4178
  self.mixer = DragonNativeSparseAttention(config, reuse_kv=False, layer_idx=layer_idx)
4179
  head_dim = self.mixer.head_dim
@@ -4203,7 +4771,7 @@ class DragonMonoBlock(GradientCheckpointingLayer):
4203
  self.mixer = DragonMamba3(config, layer_idx=layer_idx)
4204
  head_dim = self.mixer.headdim
4205
  num_attention_heads = self.mixer.nheads
4206
- use_gate = False
4207
  elif layer_type == '2':
4208
  self.mixer = DragonMamba2(config, layer_idx=layer_idx)
4209
  head_dim = self.mixer.headdim
@@ -4214,6 +4782,11 @@ class DragonMonoBlock(GradientCheckpointingLayer):
4214
  head_dim = self.mixer.headdim
4215
  num_attention_heads = self.mixer.nheads
4216
  use_gate = False # inside Mamba3Mimo
 
 
 
 
 
4217
  else:
4218
  raise ValueError(f"Unknown layer type: {layer_type}")
4219
 
@@ -4233,6 +4806,7 @@ class DragonMonoBlock(GradientCheckpointingLayer):
4233
  self.gate_proj.is_scalar_weight = True
4234
  else:
4235
  raise ValueError(f"Unknown gate_type: {self.config.gate_type}")
 
4236
  if self.config.zero_centered_gate:
4237
  val = 1.
4238
  if self.config.zero_centered_gate_type==3:
@@ -4253,6 +4827,7 @@ class DragonMonoBlock(GradientCheckpointingLayer):
4253
  self.use_gate = use_gate
4254
 
4255
  self.mixer_proj = DragonLinear(config, head_dim*num_attention_heads, config.hidden_size, bias=False)
 
4256
  if config.mixer_gn:
4257
  self.mixer_group_norm = DragonHeadWiseRMSNorm(n_heads=num_attention_heads, d_head=head_dim, eps=config.norm_epsilon, zero_centered_gamma=config.zero_centered_gamma)
4258
 
@@ -4299,6 +4874,8 @@ class DragonMonoBlock(GradientCheckpointingLayer):
4299
  cu_seqlens=cu_seqlens,
4300
  max_seqlen=max_seqlen,
4301
  ) # (B, L, E*D)
 
 
4302
  if self.use_gate:
4303
  if self.config.gate_type == "elementwise" or self.config.gate_type == "kimi":
4304
  g_proj = self.gate_proj(hidden_states).view(hidden_states.size(0), hidden_states.size(1), self.num_attention_heads, self.head_dim).to(y_mixer.dtype)
@@ -4313,7 +4890,7 @@ class DragonMonoBlock(GradientCheckpointingLayer):
4313
  y_mixer = y_mixer * (self.gate_act(g_proj) + self.gate_bias)
4314
  elif self.config.zero_centered_gate_type == 3 or self.config.zero_centered_gate_type == 4:
4315
  y_mixer = y_mixer * self.gate_act(g_proj + self.gate_bias)
4316
- if self.config.mixer_gn:
4317
  y_mixer = self.mixer_group_norm(y_mixer)
4318
  y_mixer = y_mixer.view(y_mixer.size(0), y_mixer.size(1), -1)
4319
  y_mixer = self.mixer_proj(y_mixer)
@@ -4327,6 +4904,282 @@ class DragonMonoBlock(GradientCheckpointingLayer):
4327
 
4328
  return hidden_states, last_key_states, last_value_states
4329
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
4330
  class DragonBlock(GradientCheckpointingLayer):
4331
  def __init__(self, config: DragonConfig, layer_idx: int, layer_type: str):
4332
  super().__init__()
@@ -4412,13 +5265,13 @@ class DragonPreTrainedModel(PreTrainedModel):
4412
  "attentions": DragonBlock,
4413
  }
4414
 
4415
- def _init_weights(self, module): # TODO: ??
4416
  if isinstance(module, (DragonLinear, nn.Conv1d)):
4417
  if module.bias is not None:
4418
  nn.init.zeros_(module.bias)
4419
- nn.init.normal_(module.weight, mean=0., std=1. if self.config.use_uscaling else 0.006)
4420
  elif isinstance(module, nn.Embedding):
4421
- nn.init.normal_(module.weight, mean=0., std=1. if self.config.use_uscaling else 0.006)
4422
 
4423
  @dataclass
4424
  class DragonOutput(ModelOutput):
@@ -4473,19 +5326,31 @@ class DragonModel(DragonPreTrainedModel):
4473
  self.vocab_size = config.vocab_size
4474
 
4475
  self.embedding = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx)
4476
- self.layers = nn.ModuleList([DragonBlock(config, layer_idx=i, layer_type=layer) if layer in ['l', 'r', 'd'] else DragonMonoBlock(config, layer_idx=i, layer_type=layer) for i, layer in enumerate(config.layers_config)])
 
 
 
 
 
 
 
4477
 
4478
  if self.config.rope_type_global != '' or self.config.rope_type_local != '':
4479
  self.rotary_emb = DragonRotaryEmbedding(config, head_dim=config.head_dim if config.head_dim else (config.expand_factor*config.hidden_size)//config.num_attention_heads, theta=config.rope_theta_local) # only for SWA
4480
  else:
4481
  self.rotary_emb = None
4482
 
 
 
 
 
 
4483
  if self.config.final_norm:
4484
  self.final_norm = DragonNorm(config, config.hidden_size)
4485
 
4486
  self.gradient_checkpointing = False
4487
  self.post_init()
4488
-
4489
  def get_input_embeddings(self):
4490
  return self.embedding
4491
 
@@ -4514,6 +5379,8 @@ class DragonModel(DragonPreTrainedModel):
4514
 
4515
  if inputs_embeds is None:
4516
  inputs_embeds = self.embedding(input_ids)
 
 
4517
 
4518
  if self.config.patch_level_training:
4519
  # (B, KL, D) => (B, L, D) OR (B, L, D) ==> (B, L//K, D)
@@ -4570,12 +5437,21 @@ class DragonModel(DragonPreTrainedModel):
4570
  )
4571
  shared_kv = (last_k, last_v)
4572
 
 
 
 
 
 
 
4573
  if self.config.final_norm:
4574
  hidden_states = self.final_norm(hidden_states)
4575
 
4576
  if output_hidden_states:
4577
  all_hidden_states = all_hidden_states + (hidden_states,)
4578
 
 
 
 
4579
  return DragonOutput(
4580
  last_hidden_state=hidden_states,
4581
  past_key_values=past_key_values if use_cache else None,
@@ -4589,11 +5465,23 @@ class DragonForCausalLM(DragonPreTrainedModel, GenerationMixin):
4589
  self.config = config
4590
  self.model = DragonModel(config)
4591
  self.vocab_size = config.vocab_size
4592
- self.lm_head = DragonLinear(config, config.hidden_size, config.vocab_size, bias=False, alpha_fwd=1/config.hidden_size, alpha_bwd=1/math.sqrt(config.hidden_size))
 
 
 
 
 
 
 
4593
  self.post_init()
4594
  if config.tie_lm_head:
4595
  self.lm_head.weight = self.model.embedding.weight
4596
 
 
 
 
 
 
4597
  def forward(
4598
  self,
4599
  input_ids: Optional[torch.LongTensor] = None,
@@ -4639,7 +5527,10 @@ class DragonForCausalLM(DragonPreTrainedModel, GenerationMixin):
4639
  labels = labels.to(hidden_states.device)
4640
 
4641
  if linear_cross_entropy is None or not self.config.fused_loss_computation:
4642
- logits = self.lm_head(hidden_states.to(self.lm_head.weight.dtype)[:, slice_indices, :]).float()
 
 
 
4643
  if not self.config.patch_level_training:
4644
  shift_logits = logits[..., :-1, :].contiguous()
4645
  shift_labels = labels[..., 1:].contiguous()
@@ -4653,6 +5544,7 @@ class DragonForCausalLM(DragonPreTrainedModel, GenerationMixin):
4653
  loss = loss + F.nll_loss(log_probs, shift_labels[:, i])
4654
  loss = loss / self.config.patch_level_training_size
4655
  else:
 
4656
  assert not self.config.patch_level_training, "Fused loss computation is not supported with patch-level training."
4657
  loss = linear_cross_entropy(
4658
  hidden_states[:, slice_indices, :].view(-1, hidden_states.size(-1)),
 
21
 
22
  from flash_attn.modules.mlp import GatedMlp
23
 
24
+ try:
25
+ import flash_moba
26
+ except ImportError:
27
+ flash_moba = None
28
+
29
  try:
30
  from mamba_ssm.ops.triton.ssd_combined import mamba_chunk_scan_combined
31
  except ImportError:
 
59
  except ImportError:
60
  chunk_kda, fused_recurrent_kda, fused_kda_gate, prepare_sequence_ids = None, None, None, None
61
 
62
+ from mamba_ssm.ops.triton.layernorm_gated import RMSNorm as RMSNormGated
63
+
64
  from torch.compiler import disable
65
 
66
  logger = logging.get_logger(__name__)
 
275
  out = super().forward(x)
276
  return ScaledGrad.apply(out, self.alpha_fwd, self.alpha_bwd)
277
 
278
+ class DragonScale(nn.Module):
279
+ def __init__(self, s: float):
280
+ super().__init__()
281
+ self.s = s
282
+ def forward(self, x):
283
+ return x * self.s
284
+
285
  class HybridDragonDynamicCache(DynamicCache):
286
  """
287
  A dynamic cache that handle both the attention cache (which has a seq_len dimension) and the GDN cache
 
313
  self.q_conv_caches = []
314
  self.k_conv_caches = []
315
  self.v_conv_caches = []
316
+ # cca v2
317
+ self.conv_states = []
318
+ self.prev_hs = []
319
+ self.has_previous_state = False
320
 
321
  for idx, layer_type in enumerate(config.layers_config):
322
  if not layer_type == "r":
 
331
  self.q_conv_caches.append(None)
332
  self.k_conv_caches.append(None)
333
  self.v_conv_caches.append(None)
334
+ self.conv_states.append(None)
335
+ self.prev_hs.append(None)
336
 
337
  self.window_size = config.sliding_window_size
338
  self.layers_config = config.layers_config
 
379
 
380
  def set_prev_hidden(self, layer_idx, h):
381
  self.cca_prev_hidden[layer_idx] = h
382
+
383
+ # cca v2
384
+ def update_conv_state(self, layer_idx: int, new_conv_state: torch.Tensor) -> torch.Tensor:
385
+ if not self.has_previous_state:
386
+ self.conv_states[layer_idx] = new_conv_state#.to(self.conv_states.device)
387
+ else:
388
+ self.conv_states[layer_idx] = self.conv_states[layer_idx].roll(shifts=-1, dims=-1)
389
+ self.conv_states[layer_idx][:, :, -1] = new_conv_state[:, 0, :]#.to(self.conv_states.device)
390
+ return self.conv_states[layer_idx]
391
 
392
  # kv shift
393
  def get_last_kv(self, layer_idx):
 
597
 
598
  projection_dim = self.head_dim * (self.num_attention_heads + 2 * (0 if reuse_kv else self.num_key_value_heads))
599
  self.linear_qkv = DragonLinear(config, config.hidden_size, projection_dim, bias=False)
600
+ self.linear_qkv.norm_case_1 = True
601
 
602
  if self.config.token_shift_attn:
603
  if self.config.scalar_proj_as_hidden_matrix:
 
785
 
786
  return attn_output, last_key_states, last_value_states
787
 
788
+ class DragonMoBAttention(nn.Module):
789
+ def __init__(self, config: DragonConfig, reuse_kv: bool, layer_idx: Optional[int], **kwargs):
790
+ super().__init__()
791
+ self.config = config
792
+ self.layer_idx = layer_idx
793
+ if layer_idx is None:
794
+ logger.warning_once(
795
+ f"Instantiating {self.__class__.__name__} without passing a `layer_idx` is not recommended and will "
796
+ "lead to errors during the forward call if caching is used. Please make sure to provide a `layer_idx` "
797
+ "when creating this class."
798
+ )
799
+ self.num_attention_heads = config.num_attention_heads
800
+ self.num_key_value_heads = config.num_key_value_heads
801
+ self.hidden_size = config.hidden_size
802
+ self.head_dim = config.head_dim # if config.head_dim else config.hidden_size * config.expand_factor // self.num_attention_heads
803
+ self.qk_norm = config.qk_norm
804
+ self.window_size = config.sliding_window_size
805
+ self.block_size = config.nsa_block_size
806
+ self.topk = config.nsa_topk
807
+ self.reuse_kv = reuse_kv
808
+
809
+ projection_dim = self.head_dim * (self.num_attention_heads + 2 * (0 if reuse_kv else self.num_key_value_heads))
810
+ self.linear_qkv = DragonLinear(config, config.hidden_size, projection_dim, bias=False)
811
+ self.linear_qkv.norm_case_1 = True
812
+
813
+ if self.config.token_shift_attn:
814
+ if self.config.scalar_proj_as_hidden_matrix:
815
+ self.shift_proj_k = DragonLinear(config, self.hidden_size, self.num_key_value_heads, bias=False)
816
+ self.shift_proj_v = DragonLinear(config, self.hidden_size, self.num_key_value_heads, bias=False)
817
+ else:
818
+ self.shift_proj_k = DragonLinear(config, self.hidden_size, self.num_key_value_heads, bias=False, alpha_bwd=1., alpha_fwd=1.)
819
+ self.shift_proj_v = DragonLinear(config, self.hidden_size, self.num_key_value_heads, bias=False, alpha_bwd=1., alpha_fwd=1.)
820
+ self.shift_proj_k.is_scalar_weight = True
821
+ self.shift_proj_v.is_scalar_weight = True
822
+
823
+ if self.config.token_conv1d_attn:
824
+ self.conv_size = config.conv_kernel
825
+ self.conv_dim = self.num_attention_heads * self.head_dim + self.num_key_value_heads * self.head_dim + self.num_key_value_heads * self.head_dim
826
+ self.qkv_conv1d = nn.Conv1d(in_channels=self.conv_dim, out_channels=self.conv_dim, bias=False, kernel_size=self.conv_size, groups=self.conv_dim, padding=self.conv_size-1)
827
+ self.causal_conv1d_fn = causal_conv1d_fn
828
+ self.causal_conv1d_update = causal_conv1d_update or torch_causal_conv1d_update
829
+
830
+ if self.qk_norm:
831
+ self.q_norm = DragonNorm(config, self.head_dim)
832
+ if not reuse_kv:
833
+ self.k_norm = DragonNorm(config, self.head_dim)
834
+
835
+ def forward(
836
+ self,
837
+ hidden_states: torch.Tensor,
838
+ position_embeddings: tuple[torch.Tensor, torch.Tensor],
839
+ position_ids: Optional[torch.LongTensor] = None,
840
+ cache_params: Optional[HybridDragonDynamicCache] = None,
841
+ key_value_last_layer: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
842
+ **kwargs,
843
+ ):
844
+ _, q_len, _ = hidden_states.shape
845
+ use_precomputed_states = (cache_params is not None and q_len == 1)
846
+
847
+ # Q, K, V projections.
848
+ if not self.reuse_kv:
849
+ query_states, key_states, value_states = get_query_key_value_tensors(self, hidden_states)
850
+ else:
851
+ query_states = get_query_key_value_tensors(self, hidden_states)
852
+ key_states, value_states = key_value_last_layer
853
+ last_key_states, last_value_states = None, None
854
+
855
+ # token-shift.
856
+ if self.config.token_shift_attn and not self.reuse_kv:
857
+ alpha_k = torch.sigmoid(self.shift_proj_k(hidden_states).float()).float().to(key_states.dtype).unsqueeze(-1) # (B, L, Hkv, 1)
858
+ alpha_v = torch.sigmoid(self.shift_proj_v(hidden_states).float()).float().to(value_states.dtype).unsqueeze(-1) # (B, L, Hkv, 1)
859
+
860
+ if cache_params is not None:
861
+ k_prev, v_prev = cache_params.get_last_kv(self.layer_idx)
862
+ if k_prev is None:
863
+ k_prev, v_prev = torch.zeros_like(key_states[:, :1]), torch.zeros_like(value_states[:, :1])
864
+ cache_params.set_last_kv(self.layer_idx, key_states[:, -1:], value_states[:, -1:])
865
+ else:
866
+ k_prev = F.pad(key_states, (0, 0, 0, 0, 1, 0))[:, :-1] # (B, L, H, D)
867
+ v_prev = F.pad(value_states, (0, 0, 0, 0, 1, 0))[:, :-1] # (B, L, H, D)
868
+
869
+ key_states = alpha_k * k_prev + (1 - alpha_k) * key_states
870
+ value_states = alpha_v * v_prev + (1 - alpha_v) * value_states
871
+
872
+ # conv.
873
+ if self.config.token_conv1d_attn:
874
+ assert not self.reuse_kv, "not supported"
875
+ # --- pack for conv ---
876
+ q_proj = rearrange(query_states, "b l h d -> b l (h d)")
877
+ k_proj = rearrange(key_states, "b l g d -> b l (g d)")
878
+ v_proj = rearrange(value_states, "b l g d -> b l (g d)")
879
+ mixed_qkv = torch.cat([q_proj, k_proj, v_proj], dim=-1).transpose(1, 2) # (B,C,L)
880
+
881
+ if cache_params is not None:
882
+ conv_cache = cache_params.conv_caches[self.layer_idx]
883
+
884
+ if use_precomputed_states:
885
+ mixed_qkv = self.causal_conv1d_update(
886
+ mixed_qkv,
887
+ conv_cache,
888
+ self.qkv_conv1d.weight.squeeze(1),
889
+ self.qkv_conv1d.bias,
890
+ 'silu',
891
+ ) # conv_cache is updated in-place here
892
+ else:
893
+ if cache_params is not None:
894
+ conv_cache = F.pad(mixed_qkv, (self.conv_size - mixed_qkv.shape[-1], 0))
895
+ cache_params.conv_caches[self.layer_idx] = conv_cache
896
+ if self.causal_conv1d_fn is not None:
897
+ mixed_qkv = self.causal_conv1d_fn(
898
+ x=mixed_qkv,
899
+ weight=self.qkv_conv1d.weight.squeeze(1),
900
+ bias=self.qkv_conv1d.bias,
901
+ activation='silu',
902
+ seq_idx=None,
903
+ )
904
+ else:
905
+ mixed_qkv = F.silu(self.qkv_conv1d(mixed_qkv)[:, :, :q_len])
906
+
907
+ # split back
908
+ mixed_qkv = mixed_qkv.transpose(1, 2)
909
+ q_proj, k_proj, v_proj = torch.split(
910
+ mixed_qkv,
911
+ [self.num_attention_heads*self.head_dim, self.num_key_value_heads*self.head_dim, self.num_key_value_heads*self.head_dim],
912
+ dim=-1,
913
+ )
914
+ query_states = rearrange(q_proj, "b l (h d) -> b l h d", h=self.num_attention_heads)
915
+ key_states = rearrange(k_proj, "b l (g d) -> b l g d", g=self.num_key_value_heads)
916
+ value_states = rearrange(v_proj, "b l (g d) -> b l g d", g=self.num_key_value_heads)
917
+
918
+ # QK-norm.
919
+ if self.qk_norm:
920
+ query_states = self.q_norm(query_states)
921
+ if not self.reuse_kv:
922
+ key_states = self.k_norm(key_states)
923
+
924
+ # RoPE.
925
+ if self.config.rope_theta_local > 0.0:
926
+ cos, sin = position_embeddings
927
+ if self.config.rope_type_local == "rope":
928
+ query_states = apply_rotary_emb(query_states, cos, sin)
929
+ if not self.reuse_kv:
930
+ key_states = apply_rotary_emb(key_states, cos, sin)
931
+ elif self.config.rope_type_local == "p-rope":
932
+ query_states = apply_p_rotary_emb(query_states, cos, sin, p=0.5)
933
+ if not self.reuse_kv:
934
+ key_states = apply_p_rotary_emb(key_states, cos, sin)
935
+ else:
936
+ raise ValueError(f"Unknow rope type : {self.config.rope_type_local}")
937
+
938
+ # KV-cache.
939
+ if not self.reuse_kv and cache_params is not None:
940
+ key_states, value_states = cache_params.update(key_states, value_states, self.layer_idx)
941
+
942
+ # save k,v for next layer (*after* norm and RoPE and kv-cache update)
943
+ if not self.reuse_kv:
944
+ last_key_states, last_value_states = key_states, value_states
945
+
946
+ # attention computation.
947
+ B, L, _, _ = query_states.shape
948
+ cu_seqlens = torch.arange(0, (B + 1) * L, step=L, dtype=torch.int32, device=query_states.device)
949
+ attn_output = flash_moba.flash_moba_varlen_func(
950
+ q=query_states.bfloat16().view(B*L, self.num_attention_heads, self.head_dim),
951
+ k=key_states.bfloat16().view(B*L, self.num_key_value_heads, self.head_dim),
952
+ v=value_states.bfloat16().view(B*L, self.num_key_value_heads, self.head_dim),
953
+ cu_seqlens_q=cu_seqlens,
954
+ cu_seqlens_k=cu_seqlens,
955
+ max_seqlen_q=L,
956
+ max_seqlen_k=L,
957
+ moba_chunk_size=self.block_size,
958
+ moba_topk=self.topk,
959
+ causal=True,
960
+ ).view(B, L, self.num_attention_heads, self.head_dim)
961
+ # softmax scale...
962
+ # softcap...
963
+
964
+ #if cache_params is not None and not self.reuse_kv:
965
+ # cache_params.trim(self.layer_idx)
966
+
967
+ return attn_output, last_key_states, last_value_states
968
+
969
  class DragonTensorProductAttention(nn.Module):
970
  """
971
  Multi-headed attention from 'Attention Is All You Need' paper.
 
996
  self.W_A_v = DragonLinear(config, self.hidden_size, self.num_attention_heads * self.rank, bias=False)
997
  self.W_B_k = DragonLinear(config, self.hidden_size, self.rank * self.head_dim, bias=False)
998
  self.W_B_v = DragonLinear(config, self.hidden_size, self.rank * self.head_dim, bias=False)
999
+ self.c_q.norm_case_1 = True
1000
+ # todo : norm others?
1001
 
1002
  if self.config.token_shift_attn:
1003
  if self.config.scalar_proj_as_hidden_matrix:
 
1369
 
1370
  return attn_output, None, None
1371
 
1372
+ class DragonCompressedConvolutionalAttention2(nn.Module):
1373
+ def __init__(self, config: DragonConfig, layer_idx: Optional[int], **kwargs):
1374
+ super().__init__()
1375
+ self.config = config
1376
+ assert layer_idx is not None
1377
+ self.layer_idx = layer_idx
1378
+ if layer_idx is None:
1379
+ logger.warning_once(
1380
+ f"Instantiating {self.__class__.__name__} without passing a `layer_idx` is not recommended and will "
1381
+ "lead to errors during the forward call if caching is used. Please make sure to provide a `layer_idx` "
1382
+ "when creating this class."
1383
+ )
1384
+
1385
+ self.hidden_size = config.hidden_size
1386
+ self.window_size = config.sliding_window_size
1387
+
1388
+ self.cca_time0 = 2
1389
+ self.cca_time1 = 2
1390
+ self.padding0 = self.cca_time0 - 1
1391
+ self.padding1 = self.cca_time1 - 1
1392
+ self.total_padding = self.padding0 + self.padding1
1393
+
1394
+ self.num_kv_heads = 5 # config.num_key_value_heads
1395
+ self.num_q_heads = 10 # config.num_attention_heads
1396
+ self.num_heads = config.num_attention_heads
1397
+
1398
+ # Geometry
1399
+ self.head_dim = config.head_dim
1400
+ self.latent_k_dim = self.num_kv_heads * self.head_dim
1401
+ self.latent_q_dim = self.num_q_heads * self.head_dim
1402
+ self.sqrt_head_dim = float(math.sqrt(self.head_dim))
1403
+ self.gqa_groups = self.num_q_heads // self.num_kv_heads
1404
+ assert self.num_q_heads % self.num_kv_heads == 0, "q_heads must be a multiple of k_heads"
1405
+ assert (self.latent_k_dim + self.latent_q_dim) == (self.num_kv_heads + self.num_q_heads) * self.head_dim
1406
+
1407
+ # Projections
1408
+ self.linear_q = nn.Linear(self.hidden_size, self.latent_q_dim, bias=self.config.attention_bias)
1409
+ self.linear_k = nn.Linear(self.hidden_size, self.latent_k_dim, bias=self.config.attention_bias)
1410
+ self.val_proj1 = nn.Linear(self.hidden_size, self.latent_k_dim // 2, bias=self.config.attention_bias)
1411
+ self.val_proj2 = nn.Linear(self.hidden_size, self.latent_k_dim // 2, bias=self.config.attention_bias)
1412
+
1413
+ # Depthwise + grouped conv along sequence
1414
+ in_out_ch = self.latent_k_dim + self.latent_q_dim
1415
+ self.conv_qk = nn.Sequential(
1416
+ nn.Conv1d(
1417
+ in_channels=in_out_ch,
1418
+ out_channels=in_out_ch,
1419
+ kernel_size=self.cca_time0,
1420
+ groups=in_out_ch,
1421
+ padding=0,
1422
+ stride=1,
1423
+ ),
1424
+ nn.Conv1d(
1425
+ in_channels=in_out_ch,
1426
+ out_channels=in_out_ch,
1427
+ kernel_size=self.cca_time1,
1428
+ groups=(self.num_kv_heads + self.num_q_heads),
1429
+ padding=0,
1430
+ stride=1,
1431
+ ),
1432
+ )
1433
+
1434
+ # Per-k head temperature
1435
+ self.temp = nn.Parameter(torch.zeros(self.num_kv_heads))
1436
+
1437
+ def forward(
1438
+ self,
1439
+ hidden_states: torch.Tensor,
1440
+ position_embeddings: tuple[torch.Tensor, torch.Tensor],
1441
+ cache_params: Optional[HybridDragonDynamicCache],
1442
+ **kwargs,
1443
+ ):
1444
+ """
1445
+ hidden_states: [B, S, E] (HF layout)
1446
+ returns:
1447
+ query: [B, S, num_q_heads*head_dim]
1448
+ key : [B, S, num_k_heads*head_dim]
1449
+ value: [B, S, num_k_heads*head_dim]
1450
+ """
1451
+
1452
+ past_key_values = cache_params
1453
+ batch_size, seq_length, _ = hidden_states.shape
1454
+
1455
+ # ---- Switch to [S, B, H] ----
1456
+ hs = hidden_states.transpose(0, 1).contiguous() # [S, B, H]
1457
+ # Time-shifted stream for v2 (pad one at the front along sequence)
1458
+ hs_d = F.pad(hs[:-1], pad=(0, 0, 0, 0, 1, 0)) # [S, B, H]
1459
+
1460
+ # Q/K in the full space
1461
+ q = self.linear_q(hs) # [S, B, latent_q_dim]
1462
+ k = self.linear_k(hs) # [S, B, latent_k_dim]
1463
+ qk_packed0 = torch.cat([q, k], dim=-1) # [S, B, latent_q + latent_k]
1464
+
1465
+ # Pre-mean tensors in head form (for "qk_mean_{q,k}" calc)
1466
+ query_pre = qk_packed0[..., : self.latent_q_dim].view(
1467
+ *qk_packed0.shape[:2], self.num_q_heads, self.head_dim
1468
+ ) # [S, B, qh, dh]
1469
+
1470
+ key_pre = qk_packed0[..., self.latent_q_dim :].view(
1471
+ *qk_packed0.shape[:2], self.num_kv_heads, self.head_dim
1472
+ ) # [S, B, kh, dh]
1473
+ key_pre = (
1474
+ key_pre.unsqueeze(-2)
1475
+ .repeat(1, 1, 1, self.gqa_groups, 1)
1476
+ .view(*qk_packed0.shape[:2], self.num_q_heads, self.head_dim)
1477
+ ) # [S, B, qh, dh]
1478
+
1479
+ # Means for residual mixing
1480
+ qk_mean_q = (query_pre + key_pre) / 2
1481
+ qk_mean_k = qk_mean_q.view(*qk_mean_q.shape[:2], self.num_kv_heads, self.gqa_groups, -1).mean(dim=-2)
1482
+
1483
+ if past_key_values is not None:
1484
+ if past_key_values.has_previous_state:
1485
+ # Generation
1486
+ qk_packed0 = qk_packed0.transpose(0, 1) # [B, 1, H]
1487
+ qk_packed0_cached = past_key_values.conv_states[self.layer_idx] # [B, H, 2]
1488
+ qk_packed0_cat = torch.cat([qk_packed0_cached, qk_packed0.transpose(1, 2)], dim=-1) # [B, H, 3]
1489
+ qk_packed3 = self.conv_qk(qk_packed0_cat).permute(2, 0, 1) # [S, B, E]
1490
+ qk_packed0_cache = past_key_values.update_conv_state(
1491
+ layer_idx=self.layer_idx, new_conv_state=qk_packed0
1492
+ ) # [B, H, 2]
1493
+
1494
+ else:
1495
+ # Prefill
1496
+ qk_packed0_transposed = qk_packed0.permute(1, 2, 0) # [S, B, H] -> [B, H, S]
1497
+ conv_states = nn.functional.pad(
1498
+ qk_packed0_transposed,
1499
+ (
1500
+ self.cca_time0 - qk_packed0_transposed.shape[-1],
1501
+ 0,
1502
+ ),
1503
+ )
1504
+ qk_packed0_cache = past_key_values.update_conv_state(
1505
+ layer_idx=self.layer_idx, new_conv_state=conv_states
1506
+ )
1507
+ # Convs over sequence: [S, B, E] -> [B, E, S] -> pad -> conv ->
1508
+ # [S, B, E]
1509
+ qk_packed1 = qk_packed0.permute(1, 2, 0) # [B, E, S]
1510
+ qk_packed2 = F.pad(qk_packed1, (self.total_padding, 0))
1511
+ qk_packed3 = self.conv_qk(qk_packed2).permute(2, 0, 1) # [S, B, E]
1512
+
1513
+ else:
1514
+ # Convs over sequence: [S, B, E] -> [B, E, S] -> pad -> conv -> [S,
1515
+ # B, E]
1516
+ qk_packed1 = qk_packed0.permute(1, 2, 0) # [B, E, S]
1517
+ qk_packed2 = F.pad(qk_packed1, (self.total_padding, 0))
1518
+ qk_packed3 = self.conv_qk(qk_packed2).permute(2, 0, 1) # [S, B, E]
1519
+
1520
+ # Build queries/keys from conv output + means
1521
+ query = (
1522
+ qk_packed3[..., : self.latent_q_dim].view(*qk_packed3.shape[:2], self.num_q_heads, self.head_dim)
1523
+ + qk_mean_q
1524
+ ) # [S, B, qh, dh]
1525
+
1526
+ key = (
1527
+ qk_packed3[..., self.latent_q_dim :].view(*qk_packed3.shape[:2], self.num_kv_heads, self.head_dim)
1528
+ + qk_mean_k
1529
+ ) # [S, B, kh, dh]
1530
+
1531
+ # Values from the two time streams
1532
+ v1 = self.val_proj1(hs) # [S, B, latent_k_dim/2]
1533
+ if past_key_values is not None:
1534
+ if past_key_values.has_previous_state:
1535
+ # Generation
1536
+ # [B, H]
1537
+ hs_d = past_key_values.prev_hs[self.layer_idx].clone()
1538
+ hs_d = hs_d.unsqueeze(0) # [1, B, H]
1539
+ else:
1540
+ past_key_values.prev_hs[self.layer_idx] = torch.zeros(batch_size, self.hidden_size, device=hs.device, dtype=hs.dtype)
1541
+ past_key_values.prev_hs[self.layer_idx].copy_(hs[-1, :, :])
1542
+
1543
+ v2 = self.val_proj2(hs_d) # [S, B, latent_k_dim/2]
1544
+ value = (
1545
+ torch.cat([v1, v2], dim=-1).contiguous().view(*hs.shape[:2], self.num_kv_heads, self.head_dim)
1546
+ ) # [S, B, kh, dh]
1547
+
1548
+ # L2-normalize per head, then scale
1549
+ query_norm = query.norm(p=2, dim=-1, keepdim=True)
1550
+ key_norm = key.norm(p=2, dim=-1, keepdim=True)
1551
+
1552
+ key = (key * (self.sqrt_head_dim / key_norm)) * self.temp[None, None].unsqueeze(-1)
1553
+ query = query * (self.sqrt_head_dim / query_norm)
1554
+
1555
+ # Flatten head axis, then return to HF layout [B, S, ...]
1556
+ query = query.view(*query.shape[:2], self.num_q_heads * self.head_dim).transpose(0, 1).contiguous()
1557
+ key = key.view(*key.shape[:2], self.num_kv_heads * self.head_dim).transpose(0, 1).contiguous()
1558
+ value = value.view(*value.shape[:2], self.num_kv_heads * self.head_dim).transpose(0, 1).contiguous()
1559
+
1560
+ query_states = query
1561
+ key_states = key
1562
+ value_states = value
1563
+
1564
+ query_states = query_states.view(batch_size, seq_length, self.num_q_heads, self.head_dim)
1565
+ key_states = key_states.view(batch_size, seq_length, self.num_kv_heads, self.head_dim)
1566
+ value_states = value_states.view(batch_size, seq_length, self.num_kv_heads, self.head_dim)
1567
+
1568
+ # RoPE.
1569
+ if self.config.rope_theta_local > 0.0:
1570
+ cos, sin = position_embeddings
1571
+ if self.config.rope_type_local == "rope":
1572
+ query_states = apply_rotary_emb(query_states, cos, sin)
1573
+ key_states = apply_rotary_emb(key_states, cos, sin)
1574
+ elif self.config.rope_type_local == "p-rope":
1575
+ query_states = apply_p_rotary_emb(query_states, cos, sin, p=0.5)
1576
+ key_states = apply_p_rotary_emb(key_states, cos, sin)
1577
+ else:
1578
+ raise ValueError(f"Unknow rope type : {self.config.rope_type_local}")
1579
+
1580
+ # KV-cache.
1581
+ if past_key_values is not None:
1582
+ key_states, value_states = past_key_values.update(key_states, value_states, self.layer_idx)
1583
+
1584
+ # attention computation.
1585
+ wsize = min(self.window_size, self.config.slw_wsize) if self.config.slw_wsize > 0 else self.window_size
1586
+
1587
+ if ATTN_IMPL == "eager":
1588
+ attention_interface = lambda q, k, v, wsize, **kw: eager_attention_forward(q, k, v, window_size=(wsize, 0), **kw)
1589
+ elif ATTN_IMPL == "flex":
1590
+ if wsize != self.last_wsize:
1591
+ self.last_wsize = self.build_mask(wsize)
1592
+ attention_interface = lambda q, k, v, softmax_scale, **kw: flex_attention(q.transpose(1, 2), k.transpose(1, 2), v.transpose(1, 2), block_mask=create_block_mask(self.attn_mask, B=None, H=None, Q_LEN=q.size(1), KV_LEN=k.size(1)), score_mod=self.score_mod, scale=softmax_scale, enable_gqa=self.num_attention_heads > self.num_key_value_heads).transpose(1, 2)
1593
+ elif ATTN_IMPL == "fa2":
1594
+ attention_interface = lambda q, k, v, wsize, **kw: flash_attn_func(q, k, v, window_size=(wsize, 0), **kw)
1595
+ elif ATTN_IMPL == "fa3":
1596
+ attention_interface = lambda q, k, v, wsize, **kw: flash_attn_func(q, k, v, window_size=(wsize, 0), **kw)[0]
1597
+ else:
1598
+ raise ValueError(f"Unknown ATTN_IMPL: {ATTN_IMPL}")
1599
+
1600
+ attn_output = attention_interface(
1601
+ query_states.bfloat16(),
1602
+ key_states.bfloat16(),
1603
+ value_states.bfloat16(),
1604
+ causal=True,
1605
+ wsize=wsize,
1606
+ softcap=self.config.softcap_local_attn,
1607
+ softmax_scale=None if not self.config.use_uscaling else 1/self.head_dim,
1608
+ )
1609
+
1610
+ return attn_output, None, None
1611
+
1612
  class DragonNativeSparseAttention(nn.Module):
1613
  """
1614
  Multi-headed attention from 'Attention Is All You Need' paper.
 
2149
 
2150
  projection_dim = self.head_qk_dim * self.num_attention_heads + self.head_qk_dim * self.num_key_value_heads + (self.head_v_dim * self.num_noise_heads//2)
2151
  self.linear_qkv = DragonLinear(config, config.hidden_size, projection_dim, bias=False)
2152
+ self.linear_qkv.norm_case_1 = True
2153
 
2154
  if self.config.token_shift_attn:
2155
  if self.config.scalar_proj_as_hidden_matrix:
 
2827
  self.W_A_v = DragonLinear(config, self.hidden_size, self.num_noise_heads * self.rank, bias=False)
2828
  self.W_B_k = DragonLinear(config, self.hidden_size, self.rank * self.head_qk_dim, bias=False)
2829
  self.W_B_v = DragonLinear(config, self.hidden_size, self.rank * self.head_v_dim, bias=False)
2830
+ self.c_q.norm_case_1 = True
2831
+ # todo: norm others?
2832
 
2833
  if self.config.token_shift_attn:
2834
  if self.config.scalar_proj_as_hidden_matrix:
 
3617
  self.num_attention_heads*self.dk + self.n_kv_heads*self.dk + self.n_kv_heads*self.dv,
3618
  bias=False
3619
  )
3620
+ self.linear_qkv.norm_case_1 = True
3621
  self.linear_ba = DragonLinear(
3622
  config, config.hidden_size,
3623
  self.num_attention_heads + self.num_attention_heads, #+ self.num_attention_heads*self.dv, # b(H), a(H), g(H*dv)
3624
  bias=False
3625
  )
3626
 
3627
+ if config.legacy_gate:
3628
+ if config.gate_type == 'kimi':
3629
+ self.linear_g = nn.Sequential(
3630
+ DragonLinear(config, config.hidden_size, self.dv, bias=False),
3631
+ DragonLinear(config, self.dv, self.n_kv_heads*self.dv, bias=True),
3632
+ )
3633
+ self.output_norm = FusedRMSNormGated(hidden_size=self.dv, eps=config.norm_epsilon, activation='sigmoid')
3634
+ else:
3635
+ self.linear_g = DragonLinear(
3636
+ config, config.hidden_size,
3637
+ self.n_kv_heads * self.dv,
3638
+ bias=False
3639
+ )
3640
+ self.output_norm = FusedRMSNormGated(hidden_size=self.dv, eps=config.norm_epsilon)
3641
+ self.linear_g.norm_case_1 = True
3642
+
3643
  dt_min = config.time_step_min
3644
  dt_max = config.time_step_max
3645
  dt_init_floor = config.time_step_floor
 
3654
  inv_dt = dt + torch.log(-torch.expm1(-dt))
3655
  with torch.no_grad():
3656
  self.dt_bias = nn.Parameter(inv_dt)
3657
+ self.dt_bias._no_weight_decay = True
3658
 
3659
  assert A_init_range[0] > 0 and A_init_range[1] >= A_init_range[0]
3660
  A = torch.empty(self.n_heads_local, dtype=torch.float32).uniform_(*A_init_range)
3661
  A_log = torch.log(A) # Keep A_log in fp32
3662
  self.A_log = nn.Parameter(A_log)
3663
+ self.A_log._no_weight_decay = True
3664
 
3665
  if self.config.rope_gdn == "rope":
3666
  self.rope_proj = DragonLinear(config, config.hidden_size, self.dk//4, bias=False)
 
3823
  use_qk_l2norm_in_kernel=True
3824
  ) # (B L H dv)
3825
 
3826
+ if self.config.legacy_gate:
3827
+ g = self.linear_g(hidden_states) # (B, L, H*dv)
3828
+ g = rearrange(g, "b l (h d) -> b l h d", h=self.n_kv_heads)
3829
+ o = self.output_norm(o, g)
3830
+
3831
  # update GDN cache
3832
  if cache_params is not None:
3833
  cache_params.ssm_caches[self.layer_idx] = ssm_cache
 
3861
  self.q_proj = DragonLinear(config, config.hidden_size, self.key_dim, bias=False)
3862
  self.k_proj = DragonLinear(config, config.hidden_size, self.key_dim, bias=False)
3863
  self.v_proj = DragonLinear(config, config.hidden_size, self.value_dim, bias=False)
3864
+ self.q_proj.norm_case_1 = True
3865
+ self.k_proj.norm_case_1 = True
3866
+ self.v_proj.norm_case_1 = True
3867
 
3868
  self.q_conv1d = ShortConvolution(
3869
  hidden_size=self.key_dim,
 
3896
  self.A_log = nn.Parameter(torch.log(torch.empty(self.num_q_heads, dtype=torch.float32).uniform_(1, 16)))
3897
  self.dt_bias = nn.Parameter(torch.zeros(self.key_dim, dtype=torch.float32))
3898
 
3899
+ if config.legacy_gate:
3900
+ if config.gate_type == 'kimi':
3901
+ self.linear_g = nn.Sequential(
3902
+ DragonLinear(config, config.hidden_size, self.head_v_dim, bias=False),
3903
+ DragonLinear(config, self.head_v_dim, self.num_attention_heads*self.head_v_dim, bias=True),
3904
+ )
3905
+ self.output_norm = FusedRMSNormGated(hidden_size=self.head_v_dim, eps=config.norm_epsilon, activation='sigmoid')
3906
+ else:
3907
+ self.linear_g = DragonLinear(
3908
+ config, config.hidden_size,
3909
+ self.num_attention_heads * self.head_v_dim,
3910
+ bias=False
3911
+ )
3912
+ self.output_norm = FusedRMSNormGated(hidden_size=self.head_v_dim, eps=config.norm_epsilon)
3913
+ self.linear_g.norm_case_1 = True
3914
 
3915
  @disable
3916
  def _kda_gate_call(self, g, A_log, head_k_dim, g_bias):
 
3921
  hidden_states: torch.Tensor,
3922
  position_embeddings: tuple[torch.Tensor, torch.Tensor],
3923
  cache_params: Optional[HybridDragonDynamicCache] = None,
3924
+ cu_seqlens: Optional[torch.Tensor] = None,
3925
  **kwargs,
3926
  ):
3927
  _, q_len, _ = hidden_states.shape
 
3938
  conv_state_k = cache_params.k_conv_caches[self.layer_idx]
3939
  conv_state_v = cache_params.v_conv_caches[self.layer_idx]
3940
 
3941
+ seq_idx = None
3942
+ if cu_seqlens is not None:
3943
+ seq_idx = prepare_sequence_ids(cu_seqlens).to(torch.int32).unsqueeze(0)
3944
  q, conv_state_q = self.q_conv1d(
3945
  x=self.q_proj(hidden_states),
3946
  cache=conv_state_q,
3947
  output_final_state=cache_params is not None,
3948
+ seq_idx=seq_idx,
3949
  )
3950
  k, conv_state_k = self.k_conv1d(
3951
  x=self.k_proj(hidden_states),
3952
  cache=conv_state_k,
3953
  output_final_state=cache_params is not None,
3954
+ seq_idx=seq_idx,
3955
  )
3956
  v, conv_state_v = self.v_conv1d(
3957
  x=self.v_proj(hidden_states),
3958
  cache=conv_state_v,
3959
  output_final_state=cache_params is not None,
3960
+ seq_idx=seq_idx,
3961
  )
3962
 
3963
  g = self.f_proj(hidden_states)
 
3983
  initial_state=None,
3984
  output_final_state=cache_params is not None,
3985
  use_qk_l2norm_in_kernel=True,
3986
+ cu_seqlens=cu_seqlens,
3987
  )
3988
  elif mode == 'fused_recurrent':
3989
  o, ssm_cache = fused_recurrent_kda(
 
4002
  #o = o * F.silu(rearrange(self.g_proj(hidden_states), '... (h d) -> ... h d', d=self.head_v_dim))
4003
  # TODO: other types of gates? as well as ZCG?
4004
 
4005
+ if self.config.legacy_gate:
4006
+ g = self.linear_g(hidden_states) # (B, L, H*dv)
4007
+ g = rearrange(g, "b l (h d) -> b l h d", h=self.num_attention_heads)
4008
+ o = self.output_norm(o, g)
4009
+
4010
  if cache_params is not None:
4011
  cache_params.ssm_caches[self.layer_idx] = ssm_cache
4012
  cache_params.q_conv_caches[self.layer_idx] = conv_state_q
 
4056
  if config.mamba3_rope:
4057
  self.rope_proj = DragonLinear(config, self.d_model, self.num_rope_angles, bias=False)
4058
 
4059
+ # Order: [x, B, C, dt]
4060
+ d_in_proj = self.d_inner + 2 * self.d_state * self.ngroups + self.nheads
4061
 
4062
  if self.config.mamba3_is_A_dd:
4063
  self.A_proj = DragonLinear(config, self.d_model, self.nheads, bias=False, dtype=torch.float32)
 
4082
  self.dt_bias._no_weight_decay = True
4083
 
4084
  self.in_proj = DragonLinear(config, self.d_model, d_in_proj, bias=self.bias)
4085
+ self.in_proj.norm_case_1 = True
4086
 
4087
  self.B_bias, self.C_bias = None, None
4088
  if not config.mamba3_remove_BC_bias:
 
4112
  self.D = nn.Parameter(torch.ones(self.nheads))
4113
  self.D._no_weight_decay = True
4114
 
4115
+ if config.legacy_gate:
4116
+ self.linear_g = DragonLinear(
4117
+ config, config.hidden_size,
4118
+ self.d_inner,
4119
+ bias=False,
4120
+ )
4121
+ self.linear_g.norm_case_1 = True
4122
+ if config.mamba3_postgate_norm:
4123
+ self.output_norm = RMSNormGated(self.d_inner, eps=config.norm_epsilon, norm_before_gate=False)
4124
+
4125
+ def forward(
4126
  self,
4127
  hidden_states: torch.Tensor,
4128
  cache_params: Optional[HybridDragonDynamicCache] = None,
4129
+ cu_seqlens: Optional[torch.Tensor] = None,
4130
  **kwargs
4131
  ):
4132
+ cached_len = None
4133
+ if cache_params is not None:
4134
+ hidden_states_cached = cache_params.ssm_caches[self.layer_idx] # (B, L, D)
4135
+ if hidden_states_cached is not None:
4136
+ cached_len = hidden_states_cached.shape[1]
4137
+ hidden_states = torch.cat([hidden_states_cached, hidden_states], dim=1) # (B, L+1, D)
4138
+ cache_params.ssm_caches[self.layer_idx] = hidden_states
4139
+
4140
  # Apply in_proj
4141
+ xBCdt = self.in_proj(hidden_states) # (B, l, D), l=1 when decoding
4142
+ xBC, dd_dt = torch.split(
4143
+ xBCdt,
4144
  [
 
4145
  self.d_inner + 2 * self.d_state * self.ngroups,
4146
  self.nheads,
4147
  ],
 
4154
  _A = -torch.exp(self.A_log).unsqueeze(0).unsqueeze(0)
4155
  dt = F.softplus(dd_dt + self.dt_bias) # (B, L, N)
4156
 
4157
+ seq_idx = None
4158
+ if cu_seqlens is not None:
4159
+ seq_idx = prepare_sequence_ids(cu_seqlens).to(torch.int32).unsqueeze(0)
4160
+
4161
  if not self.config.mamba3_remove_conv:
4162
  xBC = causal_conv1d_fn(
4163
  x=xBC.transpose(1, 2),
4164
  weight=rearrange(self.conv1d.weight, "d 1 w -> d w"),
4165
  bias=self.conv1d.bias,
4166
  activation=self.activation,
4167
+ seq_idx=seq_idx,
4168
  ).transpose(1, 2) # (B, L, self.d_inner + 2 * ngroups * d_state)
4169
 
4170
  x, B, C = torch.split(
 
4230
 
4231
  x_scalar = (gamma_arr*_alpha_arr).to(torch.bfloat16)
4232
 
 
 
 
 
4233
  out = mamba_chunk_scan_discretized_combined(
4234
  x=x.bfloat16(),
4235
  A=A,
 
4241
  CB_sum=CB_sum,
4242
  D=self.D,
4243
  z=None,
4244
+ initial_states=None, # ssm_cache,
4245
+ return_final_states=False, # cache_params is not None,
4246
+ seq_idx=seq_idx,
4247
  )
4248
+ y = out
4249
+
4250
+ if self.config.legacy_gate:
4251
+ if not self.config.mamba3_postgate_norm:
4252
+ g = self.linear_g(hidden_states) # (B, L, d_inner)
4253
+ y = rearrange(y, "b l h p -> b l (h p)")
4254
+ y = y * F.silu(g)
4255
+ y = rearrange(y, "b l (h p) -> b l h p", h=self.nheads)
4256
+ else:
4257
+ g = self.linear_g(hidden_states) # (B, L, d_inner)
4258
+ y = rearrange(y, "b l h p -> b l (h p)")
4259
+ y = self.output_norm(y, g)
4260
+ y = rearrange(y, "b l (h p) -> b l h p", h=self.nheads)
4261
 
4262
+ if cached_len and cached_len > 0:
4263
+ y = y[:, cached_len:, :] # keep only the new Ln steps
 
 
 
 
 
 
 
4264
 
4265
  return y, None, None
4266
 
 
4281
  # Order: [x, B, C, dt]
4282
  d_in_proj = self.d_inner + 2 * self.ngroups * self.d_state + self.nheads
4283
  self.in_proj = DragonLinear(config, self.d_model, d_in_proj, bias=False)
4284
+ self.in_proj.norm_case_1 = True
4285
 
4286
  if not self.config.mamba3_remove_conv:
4287
  conv_dim = self.d_inner + 2 * self.ngroups * self.d_state
 
4319
  self.D = nn.Parameter(torch.ones(self.nheads))
4320
  self.D._no_weight_decay = True
4321
 
4322
+ if config.legacy_gate:
4323
+ self.linear_g = DragonLinear(
4324
+ config, config.hidden_size,
4325
+ self.d_inner,
4326
+ bias=False,
4327
+ )
4328
+ self.linear_g.norm_case_1 = True
4329
+ self.output_norm = RMSNormGated(self.d_inner, eps=config.norm_epsilon, norm_before_gate=False)
4330
+
4331
  def forward(self, hidden_states, **kwargs):
4332
  """
4333
  u: (B, L, D)
 
4374
  initial_states=None,
4375
  )
4376
 
4377
+ if self.config.legacy_gate:
4378
+ g = self.linear_g(hidden_states) # (B, L, d_inner)
4379
+ y = rearrange(y, "b l h p -> b l (h p)")
4380
+ y = self.output_norm(y, g)
4381
+ y = rearrange(y, "b l (h p) -> b l h p", h=self.nheads)
4382
+
4383
  return y, None, None
4384
 
4385
  class DragonMamba3Mimo(nn.Module):
 
4394
  "when creating this class."
4395
  )
4396
 
4397
+ assert not self.config.gate_gdn, "gate must done inside the mimo mamba3 block."
4398
+
4399
  self.d_model = config.hidden_size
4400
+ self.d_state = config.mamba_d_state
4401
  self.conv_init = None
4402
  self.expand = 2
4403
+ self.headdim = config.mamba_headdim
4404
  self.ngroups = config.mamba_ngroups
4405
  self.activation = "swish"
4406
  self.bias = False
 
4415
  self.dt_init_floor = 1e-4
4416
  self.mimo_dim = config.mamba_mimo_dim
4417
  self.mimo_proj_block_order = 1
 
4418
 
4419
  self.d_inner = int(self.expand * self.d_model)
4420
  assert self.d_inner % self.headdim == 0
4421
  self.nheads = self.d_inner // self.headdim
4422
  self.dr_out_dim = self.d_inner // self.mimo_proj_block_order
4423
 
 
4424
  self.split_tensor_size = int(self.d_state * self.rope_fraction)
4425
  if self.split_tensor_size % 2 != 0:
4426
  self.split_tensor_size -= 1
 
4446
  self.dt_bias._no_weight_decay = True
4447
 
4448
  self.in_proj = DragonLinear(config, self.d_model, d_in_proj, bias=self.bias)
4449
+ self.in_proj.norm_case_1 = True
4450
 
4451
  self.B_bias = nn.Parameter(torch.ones((self.mimo_dim, self.nheads, self.d_state)), requires_grad=True)
4452
  self.C_bias = nn.Parameter(torch.ones((self.mimo_dim, self.nheads, self.d_state)), requires_grad=True)
 
4478
  self.in_proj_mimo_z = nn.Parameter(in_proj_mimo_z_init_weights, requires_grad=True)
4479
  self.out_proj_mimo = nn.Parameter(out_proj_mimo_init_weights, requires_grad=True)
4480
 
 
4481
  # D "skip" parameter
4482
  self.D = nn.Parameter(torch.ones(self.nheads))
4483
  self.D._no_weight_decay = True
4484
 
4485
+ if config.legacy_gate:
4486
+ if config.mamba3_postgate_norm:
4487
+ self.output_norm = RMSNormGated(self.d_inner, eps=config.norm_epsilon, norm_before_gate=False)
4488
+
4489
  def forward(self, hidden_states, **kwargs):
4490
  # Apply in_proj
4491
  zxBCdt = self.in_proj(hidden_states)
 
4578
  _beta_arr = torch.roll(beta_arr, shifts=-1, dims=1)
4579
 
4580
  x_scalar = (gamma_arr*_alpha_arr + _beta_arr).to(torch.bfloat16)
4581
+
4582
  z = rearrange(z, "b l r (h p) -> b l r h p", p=self.headdim)
4583
 
4584
  y = mamba_mimo_chunk_scan_discretized_fused_combined(
 
4591
  gamma=gamma_arr,
4592
  CB_sum=CB_sum,
4593
  D=self.D,
4594
+ z=z if not (self.config.legacy_gate and self.config.mamba3_postgate_norm) else None,
4595
  )
4596
 
4597
  y = rearrange(y, "b l r h p -> b l r (h p)")
4598
+
4599
+ if self.config.legacy_gate and self.config.mamba3_postgate_norm:
4600
+ z = rearrange(z, "b l r h p -> b l r (h p)")
4601
+ y = self.output_norm(y, z)
4602
+
4603
  #if seqlen_og is not None:
4604
  # y = rearrange(y, "b l r d -> (b l) r d")
4605
 
 
4626
  self.lambda1 = nn.Parameter(torch.zeros(self.link_size)) # sigmoid->0.5
4627
  else :
4628
  self.fc_1 = DragonLinear(config, config.hidden_size, intermediate_size, bias=False)
4629
+ self.fc_1.norm_case_1 = True
4630
  self.fc_2 = DragonLinear(config, intermediate_size, config.hidden_size, bias=False)
4631
+ self.fc_2.norm_case_2 = True
4632
  self.register_buffer("_2_sqrt_5", torch.tensor(2/math.sqrt(5)) if config.use_uscaling else torch.tensor(1.), persistent=False)
4633
 
4634
  def forward(self, hidden_states):
 
4657
  self.intermediate_size = intermediate_size
4658
 
4659
  self.fc_1 = DragonLinear(config, config.hidden_size, num_active_experts*self.intermediate_size, bias=False)
4660
+ self.fc_1.norm_case_1 = True
4661
  self.fc_2 = DragonLinear(config, num_active_experts*self.intermediate_size, config.hidden_size, bias=False)
4662
+ self.fc_2.norm_case_2 = True
4663
  self.register_buffer("_2_sqrt_5", torch.tensor(2/math.sqrt(5)) if config.use_uscaling else torch.tensor(1.), persistent=False)
4664
 
4665
  def forward(self, hidden_states, gates):
 
4737
  head_dim = self.mixer.head_dim
4738
  num_attention_heads = self.mixer.num_q_heads
4739
  use_gate = config.gate_attn
4740
+ elif layer_type == 'C':
4741
+ self.mixer = DragonCompressedConvolutionalAttention2(config, layer_idx=layer_idx)
4742
+ head_dim = self.mixer.head_dim
4743
+ num_attention_heads = self.mixer.num_q_heads
4744
+ use_gate = config.gate_attn
4745
  elif layer_type == 'n':
4746
  self.mixer = DragonNativeSparseAttention(config, reuse_kv=False, layer_idx=layer_idx)
4747
  head_dim = self.mixer.head_dim
 
4771
  self.mixer = DragonMamba3(config, layer_idx=layer_idx)
4772
  head_dim = self.mixer.headdim
4773
  num_attention_heads = self.mixer.nheads
4774
+ use_gate = config.gate_gdn
4775
  elif layer_type == '2':
4776
  self.mixer = DragonMamba2(config, layer_idx=layer_idx)
4777
  head_dim = self.mixer.headdim
 
4782
  head_dim = self.mixer.headdim
4783
  num_attention_heads = self.mixer.nheads
4784
  use_gate = False # inside Mamba3Mimo
4785
+ elif layer_type == 'b':
4786
+ self.mixer = DragonMoBAttention(config, reuse_kv=False, layer_idx=layer_idx)
4787
+ head_dim = self.mixer.head_dim
4788
+ num_attention_heads = self.mixer.num_attention_heads
4789
+ use_gate = config.gate_attn
4790
  else:
4791
  raise ValueError(f"Unknown layer type: {layer_type}")
4792
 
 
4806
  self.gate_proj.is_scalar_weight = True
4807
  else:
4808
  raise ValueError(f"Unknown gate_type: {self.config.gate_type}")
4809
+ self.gate_proj.norm_case_1 = True
4810
  if self.config.zero_centered_gate:
4811
  val = 1.
4812
  if self.config.zero_centered_gate_type==3:
 
4827
  self.use_gate = use_gate
4828
 
4829
  self.mixer_proj = DragonLinear(config, head_dim*num_attention_heads, config.hidden_size, bias=False)
4830
+ self.mixer_proj.norm_case_2 = True
4831
  if config.mixer_gn:
4832
  self.mixer_group_norm = DragonHeadWiseRMSNorm(n_heads=num_attention_heads, d_head=head_dim, eps=config.norm_epsilon, zero_centered_gamma=config.zero_centered_gamma)
4833
 
 
4874
  cu_seqlens=cu_seqlens,
4875
  max_seqlen=max_seqlen,
4876
  ) # (B, L, E*D)
4877
+ if self.config.mixer_gn and not self.config.gate_before_norm:
4878
+ y_mixer = self.mixer_group_norm(y_mixer)
4879
  if self.use_gate:
4880
  if self.config.gate_type == "elementwise" or self.config.gate_type == "kimi":
4881
  g_proj = self.gate_proj(hidden_states).view(hidden_states.size(0), hidden_states.size(1), self.num_attention_heads, self.head_dim).to(y_mixer.dtype)
 
4890
  y_mixer = y_mixer * (self.gate_act(g_proj) + self.gate_bias)
4891
  elif self.config.zero_centered_gate_type == 3 or self.config.zero_centered_gate_type == 4:
4892
  y_mixer = y_mixer * self.gate_act(g_proj + self.gate_bias)
4893
+ if self.config.mixer_gn and self.config.gate_before_norm:
4894
  y_mixer = self.mixer_group_norm(y_mixer)
4895
  y_mixer = y_mixer.view(y_mixer.size(0), y_mixer.size(1), -1)
4896
  y_mixer = self.mixer_proj(y_mixer)
 
4904
 
4905
  return hidden_states, last_key_states, last_value_states
4906
 
4907
class DragonGHyperConnection(nn.Module):
    """Generalized (gated) hyper-connection over `n_in` virtual residual streams.

    Mixes `n_in` width-expanded residual streams into `m` "real" streams for the
    mixer/MLP input (width connection) and redistributes the sub-layer output back
    across the streams (depth connection). The mixing matrices are a learned static
    part (`static_alpha` / `static_beta`) plus an optional input-dependent dynamic
    part (tanh-gated, enabled by `config.vwn_dynamic`).

    NOTE(review): this appears to follow the hyper-connections / virtual-width
    formulation (static + dynamic alpha/beta) — confirm against the reference paper.
    """

    def __init__(self, config: DragonConfig, m, n_in=3):
        """
        Args:
            config: model config; reads `hidden_size`, `vwn_dynamic`, `vwn_wd_alpha_beta`.
            m: number of "real" streams fed to the sub-layer (hidden width is split
               into chunks of `hidden_size // m`).
            n_in: number of virtual residual streams carried between layers.
        """
        super().__init__()
        self.config = config
        self.m, self.n_in = m, n_in
        dim = self.config.hidden_size
        # 1/sqrt(chunk_dim) scale applied before tanh in the dynamic projection.
        self.factor = 1.0 / math.sqrt(dim // self.m)

        # Initialize static beta: cyclic pattern — stream j receives the sub-layer
        # output through real-stream slot (j % m). Stored transposed: (n_in, m).
        static_beta_tensor = torch.zeros(self.m, n_in)
        for j in range(n_in):
            static_beta_tensor[j % self.m, j] = 1.0
        self.static_beta = nn.Parameter(static_beta_tensor.T.contiguous())

        # Initialize static alpha: block matrix of shape (n_in, m + n_in).
        # First m columns select the sub-layer input (identity on the first m
        # streams); remaining n_in columns carry the identity residual path.
        init_alpha = torch.cat([torch.eye(self.m), torch.eye(self.m), torch.zeros((self.m, self.n_in - self.m))], dim=1)
        if self.n_in > self.m:
            # Extra virtual streams (beyond m) only pass through the residual part.
            part2 = torch.cat([torch.zeros((self.n_in - self.m, self.m * 2)), torch.eye(self.n_in - self.m)], dim=1)
            init_alpha = torch.cat([init_alpha, part2], dim=0)
        self.static_alpha = nn.Parameter(init_alpha.contiguous())

        # Dynamic parameters: per-chunk linear maps producing input-dependent
        # corrections to alpha (m + n_in cols) and beta (m cols). Zero-init so the
        # module starts as the purely static connection.
        self.dynamic_alpha_fn = nn.Parameter(torch.zeros((dim // self.m, self.m + self.n_in)))
        self.dynamic_beta_fn = nn.Parameter(torch.zeros((dim // self.m, self.m)))
        # Flag consumed by the optimizer setup elsewhere — these DO get weight decay.
        self.dynamic_alpha_fn.requires_weight_decay = True
        self.dynamic_beta_fn.requires_weight_decay = True
        if self.config.vwn_dynamic:
            # Learnable per-entry scale on the dynamic correction (ones-init).
            self.dynamic_alpha_scale = nn.Parameter(torch.ones_like(self.static_alpha))
            self.dynamic_beta_scale = nn.Parameter(torch.ones_like(self.static_beta))
            if config.vwn_wd_alpha_beta:
                self.dynamic_alpha_scale.requires_weight_decay = True
                self.dynamic_beta_scale.requires_weight_decay = True
        else:
            # Zero buffers disable the dynamic term entirely (static-only mixing).
            self.register_buffer("dynamic_alpha_scale", torch.zeros_like(self.static_alpha), persistent=False)
            self.register_buffer("dynamic_beta_scale", torch.zeros_like(self.static_beta), persistent=False)

        # Norm applied per stream chunk before the dynamic projection.
        self.layer_norm = DragonNorm(config, dim//self.m)

    def _base_width_connection(self, h, dynamic_fn, dynamic_scale, static_scale):
        """Compute the mixed streams and the beta (depth) coefficients.

        Args:
            h: (..., n_in * chunk_dim) expanded hidden states.
            dynamic_fn: (2m + n_in, chunk_dim) stacked dynamic alpha/beta projection.
            dynamic_scale / static_scale: (n_in, 2m + n_in) stacked alpha|beta scales.

        Returns:
            (mix_h, beta): mix_h has shape (..., m + n_in, chunk_dim) — first m rows
            are the sub-layer input streams, last n_in rows the residual streams;
            beta has shape (flat_batch, n_in, m).
        """
        h_shape = h.shape
        # Recover N (= n_in) and M (= m) from the concatenated scale matrix:
        # NMM = (M + N) alpha columns + M beta columns.
        N, NMM = static_scale.shape
        M = (NMM - N) // 2
        # Flatten leading dims; split hidden dim into N stream chunks.
        h_reshape = h.reshape((h_shape[:-1].numel(),) + (N, h_shape[-1] // N))
        norm_h = self.layer_norm(h_reshape)
        # Dynamic correction (tanh-bounded, scaled) added to the static matrices.
        alpha_beta = (F.tanh(norm_h @ dynamic_fn.T.to(dtype=norm_h.dtype) * self.factor) * dynamic_scale[None, ...] + static_scale[None, ...])
        alpha, beta = torch.split(alpha_beta, (M + N, M), dim=-1)
        # Mix streams: (chunk_dim, N) @ (N, M + N) per flattened batch element.
        mix_h = (h_reshape.transpose(1, 2) @ alpha.to(dtype=h_reshape.dtype)).transpose(1, 2)
        return mix_h.reshape(h_shape[:-1] + mix_h.shape[1:]), beta

    def width_connection(self, h):
        """Public entry: stack the alpha/beta static+dynamic params and mix `h`."""
        dynamic_fn = torch.concat([self.dynamic_alpha_fn.T, self.dynamic_beta_fn.T], dim=0)
        dynamic_scale = torch.concat([self.dynamic_alpha_scale, self.dynamic_beta_scale], dim=-1).contiguous()
        static_scale = torch.concat([self.static_alpha, self.static_beta], dim=-1)
        return self._base_width_connection(h, dynamic_fn.to(dtype=h.dtype), dynamic_scale.to(dtype=h.dtype), static_scale.to(dtype=h.dtype))

    def depth_connection(self, mix_h, h_o, beta, sqrt_one_minus_tau, sqrt_tau):
        """Fold the sub-layer output back into the virtual streams.

        Args:
            mix_h: (..., m + n_in, chunk_dim) output of `width_connection`; rows
                   [m:] are the residual streams.
            h_o: (B, L, m * chunk_dim) sub-layer output on the real streams.
            beta: (flat_batch, n_in, m) depth coefficients from `width_connection`.
            sqrt_one_minus_tau / sqrt_tau: u-scaling residual weights (1.0 when
                   u-scaling is off — see buffers registered by the caller block).

        Returns:
            (B, L, n_in * chunk_dim) updated expanded hidden states.
        """
        h_o_shape = h_o.shape
        h_o = h_o.reshape(h_o_shape[:-1] + (self.m, h_o_shape[-1] // self.m))
        # Broadcast beta back to (B, L, n_in, m) and project output onto streams.
        h_i = beta.view(h_o.shape[:2] + beta.shape[1:]).to(dtype=h_o.dtype) @ h_o
        # Weighted residual add against the pass-through streams (rows m:).
        h = sqrt_tau * h_i + sqrt_one_minus_tau * mix_h[..., self.m:, :]
        h_shape = h.shape
        return h.reshape(h_shape[:-2] + (h_shape[-2] * h_shape[-1],)).contiguous()
4969
+
4970
+ class DragonMonoVirtualBlock(GradientCheckpointingLayer):
4971
+ def __init__(self, config: DragonConfig, layer_idx: int, layer_type: str):
4972
+ super().__init__()
4973
+ self.config = config
4974
+ self.layer_idx = layer_idx
4975
+
4976
+ assert self.config.vwn
4977
+
4978
+ if layer_type == 'g':
4979
+ self.mixer = DragonGatedDeltaNet(config, layer_idx=layer_idx)
4980
+ head_dim = self.mixer.head_dim
4981
+ num_attention_heads = self.mixer.num_attention_heads
4982
+ use_gate = config.gate_gdn
4983
+ elif layer_type == 'f':
4984
+ self.mixer = DragonDifferentialAttention(config, layer_idx=layer_idx)
4985
+ head_dim = self.mixer.head_dim
4986
+ num_attention_heads = self.mixer.num_signal_heads
4987
+ use_gate = config.gate_attn
4988
+ elif layer_type == 's':
4989
+ self.mixer = DragonDeepSeekSparseAttention(config, reuse_kv=False, layer_idx=layer_idx)
4990
+ head_dim = self.mixer.head_dim
4991
+ num_attention_heads = self.mixer.num_attention_heads
4992
+ use_gate = config.gate_attn
4993
+ elif layer_type == 'm':
4994
+ self.mixer = DragonDynamicMaskAttention(config, reuse_kv=False, layer_idx=layer_idx)
4995
+ head_dim = self.mixer.head_dim
4996
+ num_attention_heads = self.mixer.num_attention_heads
4997
+ use_gate = config.gate_attn
4998
+ elif layer_type == 'w':
4999
+ self.mixer = DragonAttention(config, reuse_kv=False, layer_idx=layer_idx)
5000
+ head_dim = self.mixer.head_dim
5001
+ num_attention_heads = self.mixer.num_attention_heads
5002
+ use_gate = config.gate_attn
5003
+ elif layer_type == 'p':
5004
+ self.mixer = DragonSlidingWindowRecurrenceAttention(config)
5005
+ head_dim = self.mixer.head_dim
5006
+ num_attention_heads = self.mixer.num_attention_heads
5007
+ use_gate = config.gate_attn
5008
+ elif layer_type == 'c':
5009
+ self.mixer = DragonCompressedConvolutionalAttention(config, layer_idx=layer_idx)
5010
+ head_dim = self.mixer.head_dim
5011
+ num_attention_heads = self.mixer.num_q_heads
5012
+ use_gate = config.gate_attn
5013
+ elif layer_type == 'C':
5014
+ self.mixer = DragonCompressedConvolutionalAttention2(config, layer_idx=layer_idx)
5015
+ head_dim = self.mixer.head_dim
5016
+ num_attention_heads = self.mixer.num_q_heads
5017
+ use_gate = config.gate_attn
5018
+ elif layer_type == 'n':
5019
+ self.mixer = DragonNativeSparseAttention(config, reuse_kv=False, layer_idx=layer_idx)
5020
+ head_dim = self.mixer.head_dim
5021
+ num_attention_heads = self.mixer.num_attention_heads
5022
+ use_gate = config.gate_attn
5023
+ elif layer_type == 't':
5024
+ self.mixer = DragonTensorProductAttention(config, reuse_kv=False, layer_idx=layer_idx)
5025
+ head_dim = self.mixer.head_dim
5026
+ num_attention_heads = self.mixer.num_attention_heads
5027
+ use_gate = config.gate_attn
5028
+ elif layer_type == 'T':
5029
+ self.mixer = DragonDifferentialTensorProductAttention(config, layer_idx=layer_idx)
5030
+ head_dim = self.mixer.head_dim
5031
+ num_attention_heads = self.mixer.num_signal_heads
5032
+ use_gate = config.gate_attn
5033
+ elif layer_type == 'A':
5034
+ self.mixer = DragonDifferentialMultiLatentAttention(config, layer_idx=layer_idx)
5035
+ head_dim = self.mixer.head_dim
5036
+ num_attention_heads = self.mixer.num_signal_heads
5037
+ use_gate = config.gate_attn
5038
+ elif layer_type == 'k':
5039
+ self.mixer = DragonKimiDeltaAttention(config, layer_idx=layer_idx)
5040
+ head_dim = self.mixer.head_dim
5041
+ num_attention_heads = self.mixer.num_attention_heads
5042
+ use_gate = config.gate_gdn
5043
+ elif layer_type == '3':
5044
+ self.mixer = DragonMamba3(config, layer_idx=layer_idx)
5045
+ head_dim = self.mixer.headdim
5046
+ num_attention_heads = self.mixer.nheads
5047
+ use_gate = config.gate_gdn
5048
+ elif layer_type == '2':
5049
+ self.mixer = DragonMamba2(config, layer_idx=layer_idx)
5050
+ head_dim = self.mixer.headdim
5051
+ num_attention_heads = self.mixer.nheads
5052
+ use_gate = config.gate_gdn
5053
+ elif layer_type == 'M':
5054
+ self.mixer = DragonMamba3Mimo(config, layer_idx=layer_idx)
5055
+ head_dim = self.mixer.headdim
5056
+ num_attention_heads = self.mixer.nheads
5057
+ use_gate = False # inside Mamba3Mimo
5058
+ else:
5059
+ raise ValueError(f"Unknown layer type: {layer_type}")
5060
+
5061
+ if use_gate:
5062
+ if self.config.gate_type == "elementwise":
5063
+ self.gate_proj = DragonLinear(self.config, config.hidden_size, num_attention_heads*head_dim, bias=False)
5064
+ elif self.config.gate_type == "kimi":
5065
+ self.gate_proj = nn.Sequential(
5066
+ DragonLinear(config, config.hidden_size, head_dim, bias=False),
5067
+ DragonLinear(config, head_dim, num_attention_heads*head_dim, bias=True),
5068
+ )
5069
+ elif self.config.gate_type == "headwise":
5070
+ if self.config.scalar_proj_as_hidden_matrix:
5071
+ self.gate_proj = DragonLinear(self.config, config.hidden_size, num_attention_heads, bias=False)
5072
+ else:
5073
+ self.gate_proj = DragonLinear(self.config, config.hidden_size, num_attention_heads, bias=False, alpha_fwd=1., alpha_bwd=1.)
5074
+ self.gate_proj.is_scalar_weight = True
5075
+ else:
5076
+ raise ValueError(f"Unknown gate_type: {self.config.gate_type}")
5077
+ self.gate_proj.norm_case_1 = True
5078
+ if self.config.zero_centered_gate:
5079
+ val = 1.
5080
+ if self.config.zero_centered_gate_type==3:
5081
+ val = 1.28 # F.silu(E(g) + 1.28) = 1
5082
+ elif self.config.zero_centered_gate_type==4:
5083
+ val = 1.15 # E(silu(g + 1.15)) = 1
5084
+ self.register_buffer("gate_bias", torch.tensor(val), persistent=False)
5085
+ else:
5086
+ self.register_buffer("gate_bias", torch.tensor(0.), persistent=False)
5087
+ if self.config.gate_act == "silu":
5088
+ self.gate_act = F.silu
5089
+ elif self.config.gate_act == "sigmoid":
5090
+ self.gate_act = F.sigmoid
5091
+ else:
5092
+ raise ValueError(f"Unknown gate_act: {self.config.gate_act}")
5093
+ self.num_attention_heads = num_attention_heads
5094
+ self.head_dim = head_dim
5095
+ self.use_gate = use_gate
5096
+
5097
+ self.mixer_proj = DragonLinear(config, head_dim*num_attention_heads, config.hidden_size, bias=False)
5098
+ self.mixer_proj.norm_case_2 = True
5099
+ if config.mixer_gn:
5100
+ self.mixer_group_norm = DragonHeadWiseRMSNorm(n_heads=num_attention_heads, d_head=head_dim, eps=config.norm_epsilon, zero_centered_gamma=config.zero_centered_gamma)
5101
+
5102
+ self.input_norm = DragonNorm(config, config.hidden_size)
5103
+ self.postmixer_norm = DragonNorm(config, config.hidden_size)
5104
+ self.mixer_ghyper_connection = DragonGHyperConnection(config, m=config.vwn_m, n_in=config.vwn_n)
5105
+ self.mlp_ghyper_connection = DragonGHyperConnection(config, m=config.vwn_m, n_in=config.vwn_n)
5106
+ if not config.moe:
5107
+ if config.mlp_type == "simple":
5108
+ self.mlp = DragonMLP(config)
5109
+ elif config.mlp_type == "gated":
5110
+ self.mlp = GatedMlp(in_features=config.hidden_size, hidden_features=config.intermediate_size, out_features=config.hidden_size, activation=F.silu, bias1=False, bias2=False)
5111
+ else:
5112
+ self.mlp = DragonMoE(config)
5113
+
5114
+ if config.use_uscaling or not config.layer_norm_scaling:
5115
+ self.register_buffer("lns", torch.tensor(1.0), persistent=False)
5116
+ else:
5117
+ self.register_buffer("lns", torch.tensor(1. / math.sqrt(layer_idx + (2 if config.old_lns else 1))), persistent=False)
5118
+ self.register_buffer("sqrt_tau", torch.sqrt(torch.tensor(self.config.uscaling_tau)) if config.use_uscaling else torch.tensor(1.0), persistent=False)
5119
+ self.register_buffer("sqrt_one_minus_tau", torch.sqrt(torch.tensor(1.0 - self.config.uscaling_tau)) if config.use_uscaling else torch.tensor(1.0), persistent=False)
5120
+
5121
+ def forward(
5122
+ self,
5123
+ hidden_states: torch.Tensor,
5124
+ position_ids: Optional[torch.LongTensor] = None,
5125
+ cache_params: Optional[HybridDragonDynamicCache] = None,
5126
+ cache_position: Optional[torch.LongTensor] = None,
5127
+ position_embeddings: Optional[tuple[torch.Tensor, torch.Tensor]] = None,
5128
+ key_value_last_layer: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
5129
+ cu_seqlens: Optional[torch.Tensor] = None,
5130
+ max_seqlen: Optional[int] = None,
5131
+ **kwargs,
5132
+ ):
5133
+ # hidden_states : (B, L, D'). D' = n/m D (expanded width)
5134
+
5135
+ # MIXER.
5136
+ mix_h, beta = self.mixer_ghyper_connection.width_connection(hidden_states)
5137
+ mix_h_shape = mix_h.shape
5138
+ h = mix_h[..., :self.config.vwn_m, :].reshape(mix_h_shape[:-2] + (self.config.vwn_m * mix_h_shape[-1],))
5139
+ # h is (B, L, D)
5140
+ h = self.lns * self.input_norm(h)
5141
+ y_mixer, last_key_states, last_value_states = self.mixer(
5142
+ hidden_states=h,
5143
+ position_embeddings=position_embeddings,
5144
+ position_ids=position_ids,
5145
+ cache_params=cache_params,
5146
+ key_value_last_layer=key_value_last_layer,
5147
+ cu_seqlens=cu_seqlens,
5148
+ max_seqlen=max_seqlen,
5149
+ ) # (B, L, E*D)
5150
+ if self.config.mixer_gn and not self.config.gate_before_norm:
5151
+ y_mixer = self.mixer_group_norm(y_mixer)
5152
+ if self.use_gate:
5153
+ if self.config.gate_type == "elementwise" or self.config.gate_type == "kimi":
5154
+ g_proj = self.gate_proj(h).view(h.size(0), h.size(1), self.num_attention_heads, self.head_dim).to(y_mixer.dtype)
5155
+ elif self.config.gate_type == "headwise":
5156
+ g_proj = self.gate_proj(h).unsqueeze(-1).to(y_mixer.dtype)
5157
+ else:
5158
+ raise ValueError(f"Unknown gate_type: {self.config.gate_type}")
5159
+ if self.config.zero_centered_gate_type == 1:
5160
+ y_mixer = y_mixer * self.gate_act(g_proj)
5161
+ y_mixer = y_mixer + self.gate_bias
5162
+ elif self.config.zero_centered_gate_type == 2:
5163
+ y_mixer = y_mixer * (self.gate_act(g_proj) + self.gate_bias)
5164
+ elif self.config.zero_centered_gate_type == 3 or self.config.zero_centered_gate_type == 4:
5165
+ y_mixer = y_mixer * self.gate_act(g_proj + self.gate_bias)
5166
+ if self.config.mixer_gn and self.config.gate_before_norm:
5167
+ y_mixer = self.mixer_group_norm(y_mixer)
5168
+ y_mixer = y_mixer.view(y_mixer.size(0), y_mixer.size(1), -1)
5169
+ y_mixer = self.mixer_proj(y_mixer) # (B, L, D)
5170
+ h = self.mixer_ghyper_connection.depth_connection(mix_h, y_mixer, beta, self.sqrt_one_minus_tau, self.sqrt_tau) # (B, L, D')
5171
+
5172
+ # MLP.
5173
+ mix_h, beta = self.mlp_ghyper_connection.width_connection(h)
5174
+ mix_h_shape = mix_h.shape
5175
+ h = mix_h[..., :self.config.vwn_m, :].reshape(mix_h_shape[:-2] + (self.config.vwn_m * mix_h_shape[-1],))
5176
+ # h is (B, L, D)
5177
+ h = self.lns * self.postmixer_norm(h)
5178
+ y_mlp = self.mlp(h) # (B, L, D)
5179
+ h = self.mlp_ghyper_connection.depth_connection(mix_h, y_mlp, beta, self.sqrt_one_minus_tau, self.sqrt_tau) # (B, L, D')
5180
+
5181
+ return h, 0, 0
5182
+
5183
  class DragonBlock(GradientCheckpointingLayer):
5184
  def __init__(self, config: DragonConfig, layer_idx: int, layer_type: str):
5185
  super().__init__()
 
5265
  "attentions": DragonBlock,
5266
  }
5267
 
5268
+ def _init_weights(self, module):
5269
  if isinstance(module, (DragonLinear, nn.Conv1d)):
5270
  if module.bias is not None:
5271
  nn.init.zeros_(module.bias)
5272
+ nn.init.normal_(module.weight, mean=0., std=self.config.initializer_range)
5273
  elif isinstance(module, nn.Embedding):
5274
+ nn.init.normal_(module.weight, mean=0., std=self.config.initializer_range)
5275
 
5276
  @dataclass
5277
  class DragonOutput(ModelOutput):
 
5326
  self.vocab_size = config.vocab_size
5327
 
5328
  self.embedding = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx)
5329
+ if self.config.vwn:
5330
+ self.hidden_size_expanded = int(config.vwn_n/config.vwn_m * config.hidden_size)
5331
+ self.expand_embedding = DragonLinear(config, config.hidden_size, self.hidden_size_expanded, bias=False)
5332
+
5333
+ if not self.config.vwn:
5334
+ self.layers = nn.ModuleList([DragonBlock(config, layer_idx=i, layer_type=layer) if layer in ['l', 'r', 'd'] else DragonMonoBlock(config, layer_idx=i, layer_type=layer) for i, layer in enumerate(config.layers_config)])
5335
+ else:
5336
+ self.layers = nn.ModuleList([DragonBlock(config, layer_idx=i, layer_type=layer) if layer in ['l', 'r', 'd'] else DragonMonoVirtualBlock(config, layer_idx=i, layer_type=layer) for i, layer in enumerate(config.layers_config)])
5337
 
5338
  if self.config.rope_type_global != '' or self.config.rope_type_local != '':
5339
  self.rotary_emb = DragonRotaryEmbedding(config, head_dim=config.head_dim if config.head_dim else (config.expand_factor*config.hidden_size)//config.num_attention_heads, theta=config.rope_theta_local) # only for SWA
5340
  else:
5341
  self.rotary_emb = None
5342
 
5343
+ if self.config.vwn:
5344
+ if int(self.config.vwn_n/self.config.vwn_m) == 8:
5345
+ self.gn = torch.nn.GroupNorm(num_groups=self.hidden_size_expanded//config.hidden_size, num_channels=self.hidden_size_expanded, eps=config.norm_epsilon, affine=False) # todo : zcg ?
5346
+ self.reduce_h = DragonLinear(config, self.hidden_size_expanded, config.hidden_size, bias=False)
5347
+
5348
  if self.config.final_norm:
5349
  self.final_norm = DragonNorm(config, config.hidden_size)
5350
 
5351
  self.gradient_checkpointing = False
5352
  self.post_init()
5353
+
5354
  def get_input_embeddings(self):
5355
  return self.embedding
5356
 
 
5379
 
5380
  if inputs_embeds is None:
5381
  inputs_embeds = self.embedding(input_ids)
5382
+ if self.config.vwn:
5383
+ inputs_embeds = self.expand_embedding(inputs_embeds) # (B, L, D')
5384
 
5385
  if self.config.patch_level_training:
5386
  # (B, KL, D) => (B, L, D) OR (B, L, D) ==> (B, L//K, D)
 
5437
  )
5438
  shared_kv = (last_k, last_v)
5439
 
5440
+ if self.config.vwn:
5441
+ if int(self.config.vwn_n/self.config.vwn_m) == 8:
5442
+ B, L, D = hidden_states.shape
5443
+ hidden_states = self.gn(hidden_states.reshape(-1, D)).view(B, L, D)
5444
+ hidden_states = self.reduce_h(hidden_states) # back to (B, L, D)
5445
+
5446
  if self.config.final_norm:
5447
  hidden_states = self.final_norm(hidden_states)
5448
 
5449
  if output_hidden_states:
5450
  all_hidden_states = all_hidden_states + (hidden_states,)
5451
 
5452
+ if past_key_values and not past_key_values.has_previous_state:
5453
+ past_key_values.has_previous_state = True
5454
+
5455
  return DragonOutput(
5456
  last_hidden_state=hidden_states,
5457
  past_key_values=past_key_values if use_cache else None,
 
5465
  self.config = config
5466
  self.model = DragonModel(config)
5467
  self.vocab_size = config.vocab_size
5468
+ bwd = 1/math.sqrt(config.hidden_size) if config.dataset_type == "hf" else 1/config.hidden_size
5469
+ if config.reduce_lm_head == 0:
5470
+ self.lm_head = DragonLinear(config, config.hidden_size, config.vocab_size, bias=False, alpha_fwd=1/config.hidden_size, alpha_bwd=bwd)
5471
+ else:
5472
+ self.lm_head = nn.Sequential(
5473
+ DragonLinear(config, config.hidden_size, config.reduce_lm_head, bias=False, alpha_fwd=1./math.sqrt(config.reduce_lm_head)),
5474
+ DragonLinear(config, config.reduce_lm_head, config.vocab_size, bias=False, alpha_fwd=1/config.hidden_size, alpha_bwd=bwd),
5475
+ )
5476
  self.post_init()
5477
  if config.tie_lm_head:
5478
  self.lm_head.weight = self.model.embedding.weight
5479
 
5480
+ if config.init_gpt2:
5481
+ for pn, p in self.named_parameters():
5482
+ if pn.endswith('fc2.weight') or pn.endswith('mixer_proj.weight'):
5483
+ torch.nn.init.normal_(p, mean=0.0, std=config.initializer_range/math.sqrt(2 * len(config.layers_config)))
5484
+
5485
  def forward(
5486
  self,
5487
  input_ids: Optional[torch.LongTensor] = None,
 
5527
  labels = labels.to(hidden_states.device)
5528
 
5529
  if linear_cross_entropy is None or not self.config.fused_loss_computation:
5530
+ if not self.config.reduce_lm_head:
5531
+ logits = self.lm_head(hidden_states.to(self.lm_head.weight.dtype)[:, slice_indices, :]).float()
5532
+ else:
5533
+ logits = self.lm_head(hidden_states.to(self.lm_head[0].weight.dtype)[:, slice_indices, :]).float()
5534
  if not self.config.patch_level_training:
5535
  shift_logits = logits[..., :-1, :].contiguous()
5536
  shift_labels = labels[..., 1:].contiguous()
 
5544
  loss = loss + F.nll_loss(log_probs, shift_labels[:, i])
5545
  loss = loss / self.config.patch_level_training_size
5546
  else:
5547
+ assert not self.config.reduce_lm_head
5548
  assert not self.config.patch_level_training, "Fused loss computation is not supported with patch-level training."
5549
  loss = linear_cross_entropy(
5550
  hidden_states[:, slice_indices, :].view(-1, hidden_states.size(-1)),
optimizers/Ademamix.py CHANGED
@@ -46,7 +46,7 @@ class AdEMAMix(Optimizer):
46
  """
47
 
48
  def __init__(self, params, lr=1e-3, betas=(0.9, 0.999, 0.999), alpha=8.0,
49
- beta3_warmup=None, alpha_warmup=None, eps=1e-8,
50
  weight_decay=0):
51
  if not 0.0 <= lr:
52
  raise ValueError("Invalid learning rate: {}".format(lr))
@@ -62,6 +62,7 @@ class AdEMAMix(Optimizer):
62
  raise ValueError("Invalid weight_decay value: {}".format(weight_decay))
63
  if not 0.0 <= alpha:
64
  raise ValueError("Invalid alpha value: {}".format(alpha))
 
65
  defaults = dict(lr=lr, betas=betas, eps=eps, alpha=alpha, beta3_warmup=beta3_warmup,
66
  alpha_warmup=alpha_warmup, weight_decay=weight_decay)
67
  super(AdEMAMix, self).__init__(params, defaults)
@@ -139,6 +140,8 @@ class AdEMAMix(Optimizer):
139
  exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=1 - beta2)
140
 
141
  denom = (exp_avg_sq.sqrt() / math.sqrt(bias_correction2)).add_(eps)
 
 
142
 
143
  update = (exp_avg_fast.div(bias_correction1) + alpha * exp_avg_slow) / denom
144
 
 
46
  """
47
 
48
  def __init__(self, params, lr=1e-3, betas=(0.9, 0.999, 0.999), alpha=8.0,
49
+ beta3_warmup=None, alpha_warmup=None, eps=1e-8, normalize_alpha=False,
50
  weight_decay=0):
51
  if not 0.0 <= lr:
52
  raise ValueError("Invalid learning rate: {}".format(lr))
 
62
  raise ValueError("Invalid weight_decay value: {}".format(weight_decay))
63
  if not 0.0 <= alpha:
64
  raise ValueError("Invalid alpha value: {}".format(alpha))
65
+ self.normalize_alpha = normalize_alpha
66
  defaults = dict(lr=lr, betas=betas, eps=eps, alpha=alpha, beta3_warmup=beta3_warmup,
67
  alpha_warmup=alpha_warmup, weight_decay=weight_decay)
68
  super(AdEMAMix, self).__init__(params, defaults)
 
140
  exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=1 - beta2)
141
 
142
  denom = (exp_avg_sq.sqrt() / math.sqrt(bias_correction2)).add_(eps)
143
+ if self.normalize_alpha:
144
+ denom = denom * (1.0 + alpha)
145
 
146
  update = (exp_avg_fast.div(bias_correction1) + alpha * exp_avg_slow) / denom
147
 
training_dragon.py CHANGED
@@ -41,7 +41,8 @@ class NanoArgs:
41
  rope_theta_local: float = 10000.0
42
  rope_theta_global: float = 0.0
43
  eps_rmsnorm: float = 1e-6
44
- mlp_expand: int = 4 # expand factor for MLP
 
45
  fused_loss_computation : bool = True # whether to use fused linear + cross entropy loss
46
  use_uscaling: bool = False
47
  uscaling_tau: float = 0.2
@@ -58,11 +59,19 @@ class NanoArgs:
58
  seednorm_type: int = 1
59
  seednorm_rank: int = 1
60
  mixer_gn: bool = True
 
61
  mlp_linking : bool = False
62
  final_norm: bool = True
63
  layer_norm_scaling: bool = False # not read when using muP
64
  mlp_type: str = "simple" # simple, gated
65
  tie_lm_head: bool = False
 
 
 
 
 
 
 
66
 
67
  # MoE
68
  moe: bool = False
@@ -117,6 +126,7 @@ class NanoArgs:
117
  mamba3_remove_conv: bool = True
118
  mamba3_is_A_dd: bool = True
119
  mamba3_add_trapezoid: bool = True
 
120
 
121
  # optim
122
  optim: str = "adamw" # adamw, spam, stable-spam, muon, muon_moonlight, splus
@@ -129,6 +139,8 @@ class NanoArgs:
129
  adam_beta1: float = 0.9
130
  adam_beta2: float = 0.95
131
  adam_eps: float = 1e-8
 
 
132
  warmup_iters: int = 200
133
  warmdown_iters: int = 3000
134
  warmdown_type: str = "linear" # linear, cosine
@@ -142,6 +154,8 @@ class NanoArgs:
142
  second_order_lr: float = 0.68
143
  second_order_momentum: float = 0.37
144
  second_order_interval: int = 25
 
 
145
 
146
  # data
147
  vocab_size: int = 50304
@@ -150,6 +164,7 @@ class NanoArgs:
150
  intra_doc_masking: bool = False
151
  input_bin: Optional[str] = None
152
  input_val_bin: Optional[str] = None
 
153
 
154
  # evaluation and logging
155
  val_loss_every: int = 125
@@ -170,7 +185,34 @@ class NanoArgs:
170
  # used during training
171
  slw_window: int = 0
172
 
173
- def _peek_data_shard(filename):
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
174
  with open(filename, "rb") as f:
175
  header = np.frombuffer(f.read(256 * 4), dtype=np.int32)
176
  if header[0] != 20240520:
@@ -182,25 +224,22 @@ def _peek_data_shard(filename):
182
  ntok = int(header[2])
183
  return ntok
184
 
185
- def _load_data_shard(filename):
186
- with open(filename, "rb") as f:
187
- header = np.frombuffer(f.read(256 * 4), dtype=np.int32)
188
- assert header[0] == 20240520, "magic number mismatch in the data .bin file"
189
- assert header[1] == 1, "unsupported version"
190
- ntok = int(header[2])
191
- # memmap the token payload directly (uint16) after the 256*4B header
192
- tokens = np.memmap(filename, dtype=np.uint16, mode="r", offset=256 * 4, shape=(ntok,))
193
- assert tokens.size == ntok, "number of tokens read does not match header?"
194
- return tokens
195
 
196
  class DistributedDataLoader:
197
- def __init__(self, filename_pattern, intra_doc_masking,B, T, process_rank, num_processes, bos_id):
198
  self.process_rank = process_rank
199
  self.num_processes = num_processes
200
  self.intra_doc_masking = intra_doc_masking
201
  self.bos_id = bos_id
202
  self.B = B # micro batch size
203
  self.T = T
 
204
 
205
  # glob files that match the pattern
206
  self.files = sorted(glob.glob(filename_pattern))
@@ -210,7 +249,7 @@ class DistributedDataLoader:
210
  ntok_total = 0
211
  self.shard_ntoks = []
212
  for fname in self.files:
213
- shard_ntok = _peek_data_shard(fname)
214
  #print(f"shard {fname} has {shard_ntok} tokens")
215
  assert shard_ntok >= num_processes * B * T + 1
216
  self.shard_ntoks.append(shard_ntok)
@@ -223,12 +262,12 @@ class DistributedDataLoader:
223
  def reset(self, shard=0):
224
  self.current_shard = shard
225
  self.current_position = self.process_rank * self.B * self.T
226
- self.tokens = _load_data_shard(self.files[self.current_shard])
227
 
228
  def advance(self): # advance to next data shard
229
  self.current_shard = (self.current_shard + 1) % len(self.files)
230
  self.current_position = self.process_rank * self.B * self.T
231
- self.tokens = _load_data_shard(self.files[self.current_shard])
232
 
233
  if self.process_rank == 0:
234
  shard_tokens = self.shard_ntoks[self.current_shard]
@@ -282,30 +321,38 @@ def param_groups_mup(model, base_lr_hidden, base_lr_scalar, base_lr_embed, base_
282
  groups, seen = [], set()
283
  id2name = {id(p): n for n, p in model.named_parameters()}
284
 
285
- for mod in model.modules():
286
  if isinstance(mod, nn.Linear):
287
  pname = id2name.get(id(mod.weight), "")
288
  is_scalar = getattr(mod, "is_scalar_weight", False)
289
  fan_in = mod.weight.shape[1]
290
- scale = 1 / math.sqrt(fan_in)
291
  if "lm_head" in pname:
 
292
  lr_scaled = base_lr_head
293
  wd_scaled = 0.0
 
294
  elif is_scalar:
 
295
  lr_scaled = base_lr_scalar
296
  wd_scaled = 0.0
 
297
  else:
 
298
  lr_scaled = base_lr_hidden * scale
299
  wd_scaled = wd / lr_scaled
 
300
 
301
  groups.append({"params": [mod.weight], "lr": lr_scaled, "weight_decay": wd_scaled})
302
  seen.add(mod.weight)
303
 
 
 
304
  if mod.bias is not None:
 
305
  groups.append({"params": [mod.bias], "lr": base_lr_scalar, "weight_decay": 0.0})
306
  seen.add(mod.bias)
307
 
308
- for p in model.parameters():
309
  if p in seen:
310
  continue
311
  pname = id2name.get(id(p), "<unnamed>")
@@ -318,11 +365,15 @@ def param_groups_mup(model, base_lr_hidden, base_lr_scalar, base_lr_embed, base_
318
  lr_scaled = base_lr_scalar
319
 
320
  wd_scaled = 0.
 
321
  if getattr(p, "requires_weight_decay", False):
322
  wd_scaled = wd / lr_scaled
 
323
 
324
  groups.append({"params": [p], "lr": lr_scaled, "weight_decay": wd_scaled})
325
 
 
 
326
  return groups
327
 
328
  args = tyro.cli(NanoArgs)
@@ -341,6 +392,9 @@ if args.mlp_type == "gated":
341
  print("problem: gated MLP with MoE is not supported, because we use FA backend")
342
  exit(0)
343
 
 
 
 
344
  # set up DDP (distributed data parallel).
345
  assert torch.cuda.is_available()
346
  dist.init_process_group(
@@ -434,13 +488,22 @@ tokenizer = transformers.AutoTokenizer.from_pretrained("/leonardo_work/BOOST_LCu
434
  # load dataloaders.
435
  #if args.patch_level_training:
436
  # assert T % args.patch_level_training_size == 0, "sequence length must be divisible by patch level training size in reduced mode"
437
- train_loader = DistributedDataLoader(args.input_bin, args.intra_doc_masking, B, T, ddp_rank, ddp_world_size, args.bos_id)
438
- val_loader = DistributedDataLoader(args.input_val_bin, args.intra_doc_masking, B, T, ddp_rank, ddp_world_size, args.bos_id)
439
  print0(f"Training DataLoader: total number of tokens: {train_loader.ntok_total} across {len(train_loader.files)} files")
440
  print0(f"Validation DataLoader: total number of tokens: {val_loader.ntok_total} across {len(val_loader.files)} files")
441
 
442
  # load model.
443
  config_hf = DragonConfig(
 
 
 
 
 
 
 
 
 
444
  tie_lm_head=args.tie_lm_head,
445
  mlp_type=args.mlp_type,
446
  layer_norm_scaling=args.layer_norm_scaling,
@@ -452,6 +515,7 @@ config_hf = DragonConfig(
452
  mamba3_remove_conv=args.mamba3_remove_conv,
453
  mamba3_is_A_dd=args.mamba3_is_A_dd,
454
  mamba3_add_trapezoid=args.mamba3_add_trapezoid,
 
455
  moe=args.moe,
456
  moe_num_routed_experts=args.moe_num_routed_experts,
457
  moe_routed_scaling_factor=args.moe_routed_scaling_factor,
@@ -466,6 +530,7 @@ config_hf = DragonConfig(
466
  shrink_qk_da=args.shrink_qk_da,
467
  shrink_qk_gdn=args.shrink_qk_gdn,
468
  mixer_gn=args.mixer_gn,
 
469
  kda_allow_neg_eigval=args.kda_allow_neg_eigval,
470
  kda_num_v_heads=args.kda_num_v_heads,
471
  seednorm_wd=args.seednorm_wd,
@@ -508,7 +573,7 @@ config_hf = DragonConfig(
508
  max_position_embeddings=args.sequence_length,
509
  use_uscaling=args.use_uscaling,
510
  hidden_size=args.d_model,
511
- intermediate_size=args.d_model * args.mlp_expand,
512
  expand_factor=args.expand_factor,
513
  layers_config=args.layers_config,
514
  num_attention_heads=args.n_heads,
@@ -535,18 +600,14 @@ else:
535
  model = model.cuda()
536
  print0(model)
537
 
538
- """# check here that the init std is as expected: # TODO TEMPORARY
539
  with torch.no_grad():
540
- wstd = model.model.embedding.weight.std().item()
541
- print0(f"Model weight init std: {wstd:.6f} (expected {args.init_std})")
542
- assert abs(wstd - args.init_std) / args.init_std < 0.1, f"weight init std {wstd} deviates from expected {args.init_std} by more than 10%"
543
-
544
- # check on another we
545
- lstd = model.model.layers[0].attn.linear_qkv.weight.std().item()
546
- print0(f"Model first layer attention QKV weight init std: {lstd:.6f} (expected {args.init_std})")
547
-
548
- lstd = model.model.layers[0].lin_attn.qkv_conv1d.weight.std().item()
549
- print0(f"Model first layer conv QKV weight init std: {lstd:.6f} (expected {args.init_std})")"""
550
 
551
  # count params. (total & active)
552
  num_params = sum(p.numel() for p in model.parameters())
@@ -570,7 +631,7 @@ ctx = torch.amp.autocast(device_type='cuda', dtype=torch.bfloat16)
570
 
571
  if args.intra_doc_masking:
572
  print0("!!! Using intra-document masking !!!")
573
- print0("It is only compatible with GDN (conv+chunk), DA and GDTPA layers. For DA/GDTPA, kv shift is also compatible. All other config will not have intra-doc masking support!!")
574
 
575
  # load optimizers & schedulers.
576
  if args.use_uscaling:
@@ -587,18 +648,38 @@ if args.use_uscaling:
587
  optimizer = torch.optim.AdamW(param_list, betas=(args.adam_beta1, args.adam_beta2), eps=args.adam_eps)
588
  elif args.optim == "ademamix":
589
  from .optimizers.Ademamix import AdEMAMix
590
- beta3_warmup = alpha_warmup = args.total_iterations
591
- optimizer = AdEMAMix(param_list, beta3_warmup=beta3_warmup, alpha_warmup=alpha_warmup, weight_decay=args.weight_decay)
 
592
  else:
593
  raise ValueError(f"Unknown optimizer for unit scaling: {args.optim}")
594
  else:
595
  if args.optim == "adamw":
596
- optimizer = torch.optim.AdamW(raw_model.parameters(), lr=args.learning_rate, weight_decay=args.weight_decay, betas=(args.adam_beta1, args.adam_beta2), eps=args.adam_eps)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
597
  elif args.optim == "ademamix":
598
  from .optimizers.Ademamix import AdEMAMix
599
 
600
- beta3_warmup = alpha_warmup = args.total_iterations
601
- optimizer = AdEMAMix(raw_model.parameters(), lr=args.learning_rate, beta3_warmup=beta3_warmup, alpha_warmup=alpha_warmup, weight_decay=args.weight_decay)
 
602
  else:
603
  raise ValueError(f"Unknown Optimizer: {args.optim}")
604
  if args.second_order_optim == "snoo":
@@ -624,7 +705,7 @@ def get_lr_wsd(num_iterations, warmup_iters, warmdown_iters, it):
624
  if args.warmdown_type == "linear":
625
  sched_func = partial(get_lr_wsd, args.total_iterations, args.warmup_iters, args.warmdown_iters)
626
  schedulers = [torch.optim.lr_scheduler.LambdaLR(opt, sched_func) for opt in optimizers]
627
- elif args.warmdown_type == "cosine":
628
  sched = get_wsd_schedule(
629
  optimizers[0],
630
  num_warmup_steps=args.warmup_iters,
@@ -632,7 +713,7 @@ elif args.warmdown_type == "cosine":
632
  num_training_steps=args.total_iterations,
633
  min_lr_ratio=0.,
634
  warmup_type='linear',
635
- decay_type='cosine',
636
  )
637
  schedulers = [sched]
638
  else:
@@ -721,8 +802,11 @@ for iter_ in range(start_iter, start_iter+args.total_iterations+1):
721
  # save model & tokenizer to make evaluation easier.
722
  tokenizer.save_pretrained(save_dir)
723
  state_dict_bf16 = {k: v.detach().to(torch.bfloat16).cpu() for k, v in uncompiled_model.state_dict().items()}
 
 
724
  uncompiled_model.config.torch_dtype = torch.bfloat16
725
  uncompiled_model.save_pretrained(save_dir, safe_serialization=True, state_dict=state_dict_bf16)
 
726
  # save training state.
727
  train_state = dict(
728
  iteration=iter_,
@@ -757,6 +841,18 @@ for iter_ in range(start_iter, start_iter+args.total_iterations+1):
757
  (loss / accumulation_steps).backward()
758
  else:
759
  (loss / accumulation_steps).backward() # just sync on the last step
 
 
 
 
 
 
 
 
 
 
 
 
760
  # clip those gradients.
761
  if args.grad_norm_clip is not None:
762
  grad_norm = nn.utils.clip_grad_norm_(model.parameters(), max_norm=args.grad_norm_clip, foreach=True)
@@ -771,13 +867,26 @@ for iter_ in range(start_iter, start_iter+args.total_iterations+1):
771
  # null those gradients.
772
  model.zero_grad(set_to_none=True)
773
 
 
 
 
 
 
 
 
 
 
 
 
 
 
774
  # ----------- LOGGING SECTION -----------
775
  approx_training_time_ms = training_time_ms + 1000 * (time.perf_counter() - t0)
776
  avg_step_time = approx_training_time_ms / (iter_ + 1 - WARMUP_SKIP) if iter_ >= start_iter+WARMUP_SKIP else 0
777
  extra = " ".join(f"{k}:{v}" for k, v in (to_log or {}).items())
778
  print0(f"iteration:{iter_+1:0{len(str(start_iter+args.total_iterations))}d}/{args.total_iterations} train_loss:{train_loss.item():.4f} lr: {schedulers[0].get_last_lr()[0]:.4f} train_time:{approx_training_time_ms:.0f}ms step_avg:{avg_step_time:.2f}ms {extra}")
779
  if master_process:
780
- wandb.log({'train_loss': train_loss.item(), 'step_avg_time': avg_step_time, **{f'lr_{i}': sched.get_last_lr()[0] for i, sched in enumerate(schedulers)}, 'grad_norm': grad_norm.item(), **to_log}, step=iter_)
781
 
782
  print0(f"peak memory consumption during training: {torch.cuda.max_memory_allocated() // 1024 // 1024} MiB")
783
  print0("Training complete.")
 
41
  rope_theta_local: float = 10000.0
42
  rope_theta_global: float = 0.0
43
  eps_rmsnorm: float = 1e-6
44
+ mlp_expand: float = 4. # expand factor for MLP
45
+ intermediate_size: Optional[int] = None
46
  fused_loss_computation : bool = True # whether to use fused linear + cross entropy loss
47
  use_uscaling: bool = False
48
  uscaling_tau: float = 0.2
 
59
  seednorm_type: int = 1
60
  seednorm_rank: int = 1
61
  mixer_gn: bool = True
62
+ gate_before_norm: bool = True
63
  mlp_linking : bool = False
64
  final_norm: bool = True
65
  layer_norm_scaling: bool = False # not read when using muP
66
  mlp_type: str = "simple" # simple, gated
67
  tie_lm_head: bool = False
68
+ legacy_gate: bool = False
69
+ vwn: bool = False
70
+ vwn_m: int = 2
71
+ vwn_n: int = 3
72
+ vwn_wd_alpha_beta: bool = False
73
+ vwn_dynamic: bool = True
74
+ reduce_lm_head: int = 0
75
 
76
  # MoE
77
  moe: bool = False
 
126
  mamba3_remove_conv: bool = True
127
  mamba3_is_A_dd: bool = True
128
  mamba3_add_trapezoid: bool = True
129
+ mamba3_postgate_norm: bool = False # only works if legacy_gate is True!!
130
 
131
  # optim
132
  optim: str = "adamw" # adamw, spam, stable-spam, muon, muon_moonlight, splus
 
139
  adam_beta1: float = 0.9
140
  adam_beta2: float = 0.95
141
  adam_eps: float = 1e-8
142
+ alpha_normalize: bool = False # whether to normalize update by (1+alpha) in AdEMAMix
143
+ alpha_ademamix: float = 8.0
144
  warmup_iters: int = 200
145
  warmdown_iters: int = 3000
146
  warmdown_type: str = "linear" # linear, cosine
 
154
  second_order_lr: float = 0.68
155
  second_order_momentum: float = 0.37
156
  second_order_interval: int = 25
157
+ init_gpt2: bool = False
158
+ wnorm: bool = False # as in nemotron-flash (2511.18890)
159
 
160
  # data
161
  vocab_size: int = 50304
 
164
  intra_doc_masking: bool = False
165
  input_bin: Optional[str] = None
166
  input_val_bin: Optional[str] = None
167
+ dataset_type: str = "hf" # hf, mg
168
 
169
  # evaluation and logging
170
  val_loss_every: int = 125
 
185
  # used during training
186
  slw_window: int = 0
187
 
188
+ def _peek_data_shard(filename, dataset_type='hf'):
189
+ if dataset_type == 'hf':
190
+ return _peek_hf_shard(filename)
191
+ elif dataset_type == 'mg':
192
+ return _peek_mg_shard(filename)
193
+ else:
194
+ raise ValueError(f"unknown dataset type: {dataset_type}")
195
+
196
+ def _load_data_shard(filename, dataset_type='hf'):
197
+ if dataset_type == 'hf':
198
+ return _load_hf_shard(filename)
199
+ elif dataset_type == 'mg':
200
+ return _load_mg_shard(filename)
201
+ else:
202
+ raise ValueError(f"unknown dataset type: {dataset_type}")
203
+
204
+ def _load_hf_shard(filename):
205
+ with open(filename, "rb") as f:
206
+ header = np.frombuffer(f.read(256 * 4), dtype=np.int32)
207
+ assert header[0] == 20240520, "magic number mismatch in the data .bin file"
208
+ assert header[1] == 1, "unsupported version"
209
+ ntok = int(header[2])
210
+ # memmap the token payload directly (uint16) after the 256*4B header
211
+ tokens = np.memmap(filename, dtype=np.uint16, mode="r", offset=256 * 4, shape=(ntok,))
212
+ assert tokens.size == ntok, "number of tokens read does not match header?"
213
+ return tokens
214
+
215
+ def _peek_hf_shard(filename):
216
  with open(filename, "rb") as f:
217
  header = np.frombuffer(f.read(256 * 4), dtype=np.int32)
218
  if header[0] != 20240520:
 
224
  ntok = int(header[2])
225
  return ntok
226
 
227
+ def _peek_mg_shard(filename):
228
+ tokens = np.memmap(filename, dtype=np.uint16, mode="r")
229
+ return int(tokens.size)
230
+
231
+ def _load_mg_shard(filename):
232
+ return np.memmap(filename, dtype=np.uint16, mode="r")
 
 
 
 
233
 
234
  class DistributedDataLoader:
235
+ def __init__(self, filename_pattern, intra_doc_masking,B, T, process_rank, num_processes, bos_id, dataset_type='hf'):
236
  self.process_rank = process_rank
237
  self.num_processes = num_processes
238
  self.intra_doc_masking = intra_doc_masking
239
  self.bos_id = bos_id
240
  self.B = B # micro batch size
241
  self.T = T
242
+ self.dataset_type = dataset_type
243
 
244
  # glob files that match the pattern
245
  self.files = sorted(glob.glob(filename_pattern))
 
249
  ntok_total = 0
250
  self.shard_ntoks = []
251
  for fname in self.files:
252
+ shard_ntok = _peek_data_shard(fname, dataset_type=self.dataset_type)
253
  #print(f"shard {fname} has {shard_ntok} tokens")
254
  assert shard_ntok >= num_processes * B * T + 1
255
  self.shard_ntoks.append(shard_ntok)
 
262
  def reset(self, shard=0):
263
  self.current_shard = shard
264
  self.current_position = self.process_rank * self.B * self.T
265
+ self.tokens = _load_data_shard(self.files[self.current_shard], dataset_type=self.dataset_type)
266
 
267
  def advance(self): # advance to next data shard
268
  self.current_shard = (self.current_shard + 1) % len(self.files)
269
  self.current_position = self.process_rank * self.B * self.T
270
+ self.tokens = _load_data_shard(self.files[self.current_shard], dataset_type=self.dataset_type)
271
 
272
  if self.process_rank == 0:
273
  shard_tokens = self.shard_ntoks[self.current_shard]
 
321
  groups, seen = [], set()
322
  id2name = {id(p): n for n, p in model.named_parameters()}
323
 
324
+ for name, mod in model.named_modules():
325
  if isinstance(mod, nn.Linear):
326
  pname = id2name.get(id(mod.weight), "")
327
  is_scalar = getattr(mod, "is_scalar_weight", False)
328
  fan_in = mod.weight.shape[1]
 
329
  if "lm_head" in pname:
330
+ scale = 1
331
  lr_scaled = base_lr_head
332
  wd_scaled = 0.0
333
+ wd_mult = 0.0
334
  elif is_scalar:
335
+ scale = 1
336
  lr_scaled = base_lr_scalar
337
  wd_scaled = 0.0
338
+ wd_mult = 0.0
339
  else:
340
+ scale = 1 / math.sqrt(fan_in)
341
  lr_scaled = base_lr_hidden * scale
342
  wd_scaled = wd / lr_scaled
343
+ wd_mult = 1/lr_scaled
344
 
345
  groups.append({"params": [mod.weight], "lr": lr_scaled, "weight_decay": wd_scaled})
346
  seen.add(mod.weight)
347
 
348
+ print(f"param {name}.weight | shape {mod.weight.shape} | scale {scale} | wd_mult={wd_mult:.3e}")
349
+
350
  if mod.bias is not None:
351
+ assert False
352
  groups.append({"params": [mod.bias], "lr": base_lr_scalar, "weight_decay": 0.0})
353
  seen.add(mod.bias)
354
 
355
+ for name, p in model.named_parameters():
356
  if p in seen:
357
  continue
358
  pname = id2name.get(id(p), "<unnamed>")
 
365
  lr_scaled = base_lr_scalar
366
 
367
  wd_scaled = 0.
368
+ wd_mult = 0.
369
  if getattr(p, "requires_weight_decay", False):
370
  wd_scaled = wd / lr_scaled
371
+ wd_mult = 1/lr_scaled
372
 
373
  groups.append({"params": [p], "lr": lr_scaled, "weight_decay": wd_scaled})
374
 
375
+ print(f"param {name} | shape {p.shape} | scale {1.} | wd_mult={wd_mult:.3e}")
376
+
377
  return groups
378
 
379
  args = tyro.cli(NanoArgs)
 
392
  print("problem: gated MLP with MoE is not supported, because we use FA backend")
393
  exit(0)
394
 
395
+ if args.legacy_gate:
396
+ assert not args.gate_gdn, "legacy_gate is not compatible with gate_gdn."
397
+
398
  # set up DDP (distributed data parallel).
399
  assert torch.cuda.is_available()
400
  dist.init_process_group(
 
# load dataloaders.
#if args.patch_level_training:
#    assert T % args.patch_level_training_size == 0, "sequence length must be divisible by patch level training size in reduced mode"
# Train/val loaders share every argument except the input path.
shared_loader_args = (args.intra_doc_masking, B, T, ddp_rank, ddp_world_size, args.bos_id, args.dataset_type)
train_loader = DistributedDataLoader(args.input_bin, *shared_loader_args)
val_loader = DistributedDataLoader(args.input_val_bin, *shared_loader_args)
print0(f"Training DataLoader: total number of tokens: {train_loader.ntok_total} across {len(train_loader.files)} files")
print0(f"Validation DataLoader: total number of tokens: {val_loader.ntok_total} across {len(val_loader.files)} files")
495
 
496
  # load model.
497
  config_hf = DragonConfig(
498
+ reduce_lm_head=args.reduce_lm_head,
499
+ dataset_type=args.dataset_type,
500
+ vwn=args.vwn,
501
+ vwn_m=args.vwn_m,
502
+ vwn_n=args.vwn_n,
503
+ vwn_wd_alpha_beta=args.vwn_wd_alpha_beta,
504
+ vwn_dynamic=args.vwn_dynamic,
505
+ legacy_gate=args.legacy_gate,
506
+ init_gpt2=args.init_gpt2,
507
  tie_lm_head=args.tie_lm_head,
508
  mlp_type=args.mlp_type,
509
  layer_norm_scaling=args.layer_norm_scaling,
 
515
  mamba3_remove_conv=args.mamba3_remove_conv,
516
  mamba3_is_A_dd=args.mamba3_is_A_dd,
517
  mamba3_add_trapezoid=args.mamba3_add_trapezoid,
518
+ mamba3_postgate_norm=args.mamba3_postgate_norm,
519
  moe=args.moe,
520
  moe_num_routed_experts=args.moe_num_routed_experts,
521
  moe_routed_scaling_factor=args.moe_routed_scaling_factor,
 
530
  shrink_qk_da=args.shrink_qk_da,
531
  shrink_qk_gdn=args.shrink_qk_gdn,
532
  mixer_gn=args.mixer_gn,
533
+ gate_before_norm=args.gate_before_norm,
534
  kda_allow_neg_eigval=args.kda_allow_neg_eigval,
535
  kda_num_v_heads=args.kda_num_v_heads,
536
  seednorm_wd=args.seednorm_wd,
 
573
  max_position_embeddings=args.sequence_length,
574
  use_uscaling=args.use_uscaling,
575
  hidden_size=args.d_model,
576
+ intermediate_size=int(args.d_model * args.mlp_expand) if args.intermediate_size is None else args.intermediate_size,
577
  expand_factor=args.expand_factor,
578
  layers_config=args.layers_config,
579
  num_attention_heads=args.n_heads,
 
600
  model = model.cuda()
601
  print0(model)
602
 
 
# Dump mean/std of every parameter tensor — a quick sanity check that the
# initialization matches expectations.
with torch.no_grad():
    for pname, param in model.named_parameters():
        # Skip degenerate entries (empty tensors).
        if param is None or param.numel() == 0:
            continue
        vals = param.detach().float()
        mean, std = vals.mean().item(), vals.std(unbiased=False).item()
        print0(f"{pname:60s} shape={tuple(param.shape)} mean={mean:+.4e} std={std:.4e}")
 
 
 
611
 
612
  # count params. (total & active)
613
  num_params = sum(p.numel() for p in model.parameters())
 
631
 
# Warn loudly when intra-document masking is on: only a subset of layer
# types actually honors it (see message below).
if args.intra_doc_masking:
    idm_notes = (
        "!!! Using intra-document masking !!!",
        "It is only compatible with GDN (conv+chunk), KDA (conv+chunk), DA and GDTPA layers. For DA/GDTPA, kv shift is also compatible. All other config will not have intra-doc masking support!!",
    )
    for note in idm_notes:
        print0(note)
635
 
636
  # load optimizers & schedulers.
637
  if args.use_uscaling:
 
648
  optimizer = torch.optim.AdamW(param_list, betas=(args.adam_beta1, args.adam_beta2), eps=args.adam_eps)
649
  elif args.optim == "ademamix":
650
  from .optimizers.Ademamix import AdEMAMix
651
+ beta3_warmup = args.total_iterations
652
+ alpha_warmup = args.total_iterations
653
+ optimizer = AdEMAMix(param_list, beta3_warmup=beta3_warmup, alpha_warmup=alpha_warmup, normalize_alpha=args.alpha_normalize, alpha=args.alpha_ademamix, weight_decay=args.weight_decay)
654
  else:
655
  raise ValueError(f"Unknown optimizer for unit scaling: {args.optim}")
656
  else:
657
  if args.optim == "adamw":
658
+ #optimizer = torch.optim.AdamW(raw_model.parameters(), lr=args.learning_rate, weight_decay=args.weight_decay, betas=(args.adam_beta1, args.adam_beta2), eps=args.adam_eps)
659
+ decay_params = []
660
+ no_decay_params = []
661
+ for name, p in raw_model.named_parameters():
662
+ if not p.requires_grad:
663
+ continue
664
+ if getattr(p, "_no_weight_decay", False):
665
+ no_decay_params.append(p)
666
+ else:
667
+ decay_params.append(p)
668
+ optimizer = torch.optim.AdamW(
669
+ [
670
+ {"params": decay_params, "weight_decay": args.weight_decay},
671
+ {"params": no_decay_params, "weight_decay": 0.0},
672
+ ],
673
+ lr=args.learning_rate,
674
+ betas=(args.adam_beta1, args.adam_beta2),
675
+ eps=args.adam_eps,
676
+ )
677
  elif args.optim == "ademamix":
678
  from .optimizers.Ademamix import AdEMAMix
679
 
680
+ beta3_warmup = args.total_iterations
681
+ alpha_warmup = args.total_iterations
682
+ optimizer = AdEMAMix(raw_model.parameters(), lr=args.learning_rate, beta3_warmup=beta3_warmup, alpha_warmup=alpha_warmup, normalize_alpha=args.alpha_normalize, alpha=args.alpha_ademamix, weight_decay=args.weight_decay)
683
  else:
684
  raise ValueError(f"Unknown Optimizer: {args.optim}")
685
  if args.second_order_optim == "snoo":
 
705
  if args.warmdown_type == "linear":
706
  sched_func = partial(get_lr_wsd, args.total_iterations, args.warmup_iters, args.warmdown_iters)
707
  schedulers = [torch.optim.lr_scheduler.LambdaLR(opt, sched_func) for opt in optimizers]
708
+ elif args.warmdown_type == "cosine" or args.warmdown_type == "1-sqrt":
709
  sched = get_wsd_schedule(
710
  optimizers[0],
711
  num_warmup_steps=args.warmup_iters,
 
713
  num_training_steps=args.total_iterations,
714
  min_lr_ratio=0.,
715
  warmup_type='linear',
716
+ decay_type=args.warmdown_type,
717
  )
718
  schedulers = [sched]
719
  else:
 
# save model & tokenizer to make evaluation easier.
tokenizer.save_pretrained(save_dir)
state_dict_bf16 = {k: v.detach().to(torch.bfloat16).cpu() for k, v in uncompiled_model.state_dict().items()}
# Temporarily clear intra_doc_masking in the saved config (presumably so
# reloaded checkpoints don't apply training-only masking — confirm against
# modeling code). Restore the live value even if save_pretrained raises,
# otherwise a failed save would leave the in-memory config corrupted.
idm_og = uncompiled_model.config.intra_doc_masking
uncompiled_model.config.intra_doc_masking = False
uncompiled_model.config.torch_dtype = torch.bfloat16
try:
    uncompiled_model.save_pretrained(save_dir, safe_serialization=True, state_dict=state_dict_bf16)
finally:
    uncompiled_model.config.intra_doc_masking = idm_og
810
  # save training state.
811
  train_state = dict(
812
  iteration=iter_,
 
841
  (loss / accumulation_steps).backward()
842
  else:
843
  (loss / accumulation_steps).backward() # just sync on the last step
# Per-parameter grad-norm logging is currently disabled; keep the empty dict
# so the wandb.log call below can splat it unconditionally. (Removed the
# dead triple-quoted block of commented-out code that used to sit here — it
# was a no-op string expression.)
individual_grad_norms = {}
856
  # clip those gradients.
857
  if args.grad_norm_clip is not None:
858
  grad_norm = nn.utils.clip_grad_norm_(model.parameters(), max_norm=args.grad_norm_clip, foreach=True)
 
867
  # null those gradients.
868
  model.zero_grad(set_to_none=True)
869
 
# Wnorm: after the optimizer update, renormalize flagged weight matrices to
# unit L2 norm. `norm_case_1` normalizes along dim=1 (per row),
# `norm_case_2` along dim=0 (per column); the flags are presumably set on
# modules at model-build time — confirm against the model code.
if args.wnorm:
    with torch.no_grad():
        for m in model.modules():
            if getattr(m, "norm_case_1", False):
                norm_dim = 1
            elif getattr(m, "norm_case_2", False):
                norm_dim = 0
            else:
                continue
            W = getattr(m, "weight", None)
            if W is None:
                # FIX: the original dereferenced W unconditionally and would
                # crash with AttributeError on a flagged module without a
                # `weight` tensor.
                continue
            # Norm in fp32 for stability; clamp to avoid division by zero.
            denom = W.float().norm(p=2, dim=norm_dim, keepdim=True).clamp_min(1e-8).to(W.dtype)
            W.div_(denom)
882
+
883
  # ----------- LOGGING SECTION -----------
884
  approx_training_time_ms = training_time_ms + 1000 * (time.perf_counter() - t0)
885
  avg_step_time = approx_training_time_ms / (iter_ + 1 - WARMUP_SKIP) if iter_ >= start_iter+WARMUP_SKIP else 0
886
  extra = " ".join(f"{k}:{v}" for k, v in (to_log or {}).items())
887
  print0(f"iteration:{iter_+1:0{len(str(start_iter+args.total_iterations))}d}/{args.total_iterations} train_loss:{train_loss.item():.4f} lr: {schedulers[0].get_last_lr()[0]:.4f} train_time:{approx_training_time_ms:.0f}ms step_avg:{avg_step_time:.2f}ms {extra}")
888
  if master_process:
889
+ wandb.log({'train_loss': train_loss.item(), 'step_avg_time': avg_step_time, **{f'lr_{i}': sched.get_last_lr()[0] for i, sched in enumerate(schedulers)}, 'grad_norm': grad_norm.item(), **to_log, **individual_grad_norms}, step=iter_)
890
 
891
  print0(f"peak memory consumption during training: {torch.cuda.max_memory_allocated() // 1024 // 1024} MiB")
892
  print0("Training complete.")