File size: 35,693 Bytes
d7b3a74
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
713
714
715
716
717
718
719
720
721
722
723
724
725
726
727
728
729
730
731
732
733
734
735
736
737
738
739
740
741
742
743
744
745
746
747
748
749
750
751
752
753
754
755
756
757
758
759
760
761
762
763
764
765
766
767
768
769
770
771
772
773
diff --git a/megatron/core/dist_checkpointing/strategies/common.py b/megatron/core/dist_checkpointing/strategies/common.py
index 41c21d93d..ef80f72d6 100644
--- a/megatron/core/dist_checkpointing/strategies/common.py
+++ b/megatron/core/dist_checkpointing/strategies/common.py
@@ -86,7 +86,7 @@ class TorchCommonLoadStrategy(LoadCommonStrategy):
                 msc = MultiStorageClientFeature.import_package()
                 return msc.torch.load(load_path, map_location='cpu')
             else:
-                return torch.load(load_path, map_location='cpu')
+                return torch.load(load_path, map_location='cpu', weights_only=False)
         except FileNotFoundError as e:
             err_msg = f'Common file {load_path} does not exist'
             if MultiStorageClientFeature.is_enabled():
diff --git a/megatron/core/dist_checkpointing/strategies/torch.py b/megatron/core/dist_checkpointing/strategies/torch.py
index 5a1ea308d..aa701237f 100644
--- a/megatron/core/dist_checkpointing/strategies/torch.py
+++ b/megatron/core/dist_checkpointing/strategies/torch.py
@@ -597,10 +597,12 @@ class MCoreLoadPlanner(DefaultLoadPlanner):
     def _validate_global_shapes(self, metadata, sharded_tensors):
         for sh_ten in sharded_tensors:
             if sh_ten.key not in metadata.state_dict_metadata:
-                raise KeyError(
-                    f"{sh_ten.key} from model not in state dict:"
-                    f" {sorted(metadata.state_dict_metadata.keys())}"
-                )
+                # raise KeyError(
+                #     f"{sh_ten.key} from model not in state dict:"
+                #     f" {sorted(metadata.state_dict_metadata.keys())}"
+                # )
+                print(f"{sh_ten.key} from model not in state dict, will skip")
+                continue
             loaded_shape = metadata.state_dict_metadata[sh_ten.key].size
             expected_shape = self._expected_shape(sh_ten)
             if loaded_shape != expected_shape:
@@ -630,7 +632,7 @@ class MCoreLoadPlanner(DefaultLoadPlanner):
         tensor_metadata = self.metadata.state_dict_metadata
         metadata_with_sizes = [
             (tensor_metadata[key], tensor_metadata[key].size, sharded_tensor)
-            for key, sharded_tensor in self.allow_shape_mismatch_sharded_tensors.items()
+            for key, sharded_tensor in self.allow_shape_mismatch_sharded_tensors.items() if key in tensor_metadata
         ]
         try:
             # Temporarily set sizes to expected shapes
@@ -959,6 +961,7 @@ class TorchDistLoadShardedStrategy(LoadShardedStrategy):
             planner=MCoreLoadPlanner(
                 shapes_validation_sharded_tensors=flexible_shape_sharded_tensors,
                 allow_shape_mismatch_sharded_tensors=allow_shape_mismatch_sharded_tensors,
+                allow_partial_load=True,
             ),
         )
 
diff --git a/megatron/core/extensions/transformer_engine.py b/megatron/core/extensions/transformer_engine.py
index acb93ef78..d239db4ab 100644
--- a/megatron/core/extensions/transformer_engine.py
+++ b/megatron/core/extensions/transformer_engine.py
@@ -408,6 +408,7 @@ class TELinear(te.pytorch.Linear):
         )
 
         for param in self.parameters():
+            setattr(param, "parallel_mode", parallel_mode)
             if is_expert:
                 # Reduce the gradient on the expert_data_parallel group for expert linear layers
                 setattr(param, "allreduce", not self.expert_parallel)
@@ -1161,6 +1162,61 @@ class TEDotProductAttention(te.pytorch.DotProductAttention):
 
 
 if HAVE_TE and is_te_min_version("1.9.0.dev0"):
+    def ceil_div(x: int, y: int) -> int:
+        return (x + y - 1) // y
+
+    class _FakeInt4QuantizationSTE(torch.autograd.Function):
+        @staticmethod
+        def forward(ctx, x, group_size):
+            m, n = x.shape
+            block_size_m, block_size_n = 1, group_size
+
+
+            m_padded = ceil_div(m, block_size_m) * block_size_m
+            n_padded = ceil_div(n, block_size_n) * block_size_n
+
+            x_padded = torch.zeros(
+                (m_padded, n_padded),
+                dtype=x.dtype, device=x.device
+            )
+            x_padded[:m, :n] = x
+
+            x_view = x_padded.view(
+                m_padded // block_size_m,
+                block_size_m,
+                n_padded // block_size_n,
+                block_size_n
+            )
+
+            x_max = x_view.abs().float().amax(dim=(1, 3), keepdim=True)
+            q_max = 7
+            x_scale = x_max / q_max
+
+            x_scale = x_scale.clamp(min=1e-5)
+
+            x_div = x_view / x_scale
+            x_round = torch.round(x_div)
+
+            x_q_clamped = x_round.clamp(-q_max, q_max)
+
+            x_dequant_view = x_q_clamped * x_scale
+
+            x_dequant_full = x_dequant_view.view_as(x_padded)
+            x_out = x_dequant_full[:m, :n].contiguous().to(x.dtype)
+
+            return x_out
+
+        @staticmethod
+        def backward(ctx, grad_output):
+            return grad_output, None
+
+    def fake_int4_quantization_ste(x, group_size):
+        x_out = _FakeInt4QuantizationSTE.apply(x, group_size)
+        
+        if hasattr(x, 'main_grad'):
+            x_out.main_grad = x.main_grad
+            
+        return x_out
 
     class TEGroupedLinear(te.pytorch.GroupedLinear):
         """
@@ -1351,6 +1407,7 @@ if HAVE_TE and is_te_min_version("1.9.0.dev0"):
             _is_first_microbatch = (
                 None if self.disable_parameter_transpose_cache else self.is_first_microbatch
             )
+
             out = super().forward(x, m_splits, is_first_microbatch=_is_first_microbatch)
             self.is_first_microbatch = False
 
@@ -1361,6 +1418,20 @@ if HAVE_TE and is_te_min_version("1.9.0.dev0"):
                 return out
             return out, None
 
+        def _get_weight_tensors(self):
+            """Get the weight tensors of the module."""
+            weight_tensors = super()._get_weight_tensors()
+
+            if os.getenv("OPEN_TRAINING_INT4_FAKE_QAT_FLAG", "0") == "1":
+                group_size = int(os.getenv("OPEN_TRAINING_INT4_GROUP_SIZE", "128"))
+
+                weight_tensors = [
+                    fake_int4_quantization_ste(w, group_size) 
+                    for w in weight_tensors
+                ]
+                
+            return weight_tensors
+
         def _encode_extra_state(self, state):
             # TE 2.0 changed the format of extra_state to be a byte tensor
             if is_te_min_version("2.0.0"):
diff --git a/megatron/core/fusions/fused_mla_yarn_rope_apply.py b/megatron/core/fusions/fused_mla_yarn_rope_apply.py
index 1fd5dcfae..c9aeef1f0 100644
--- a/megatron/core/fusions/fused_mla_yarn_rope_apply.py
+++ b/megatron/core/fusions/fused_mla_yarn_rope_apply.py
@@ -385,6 +385,7 @@ def rotary_fwd_kv_kernel(
     SIN,
     emb_dim: tl.constexpr,
     k_dim: tl.constexpr,
+    k_dim_ceil: tl.constexpr,
     v_dim: tl.constexpr,
     head_num: tl.constexpr,
     batch_size,
@@ -434,21 +435,27 @@ def rotary_fwd_kv_kernel(
     cos_right = tl.load(COS + token_idx * emb_dim + emb_dim // 2 + tl.arange(0, emb_dim // 2))
     sin_right = tl.load(SIN + token_idx * emb_dim + emb_dim // 2 + tl.arange(0, emb_dim // 2))
 
-    KV_ptr = KV + pid_m * stride_kv_seq + pid_head * BLOCK_H * stride_kv_nheads
-    kv_off = tl.arange(0, BLOCK_H)[:, None] * stride_kv_nheads
-    mask = kv_off < head_num * stride_kv_nheads
-    k_in_off = kv_off + tl.arange(0, k_dim)[None, :]
-    v_in_off = kv_off + k_dim + tl.arange(0, v_dim)[None, :]
-    k = tl.load(KV_ptr + k_in_off, mask=mask)
-    v = tl.load(KV_ptr + v_in_off, mask=mask)
+    KV_ptr = KV + pid_m * stride_kv_seq # + pid_head * BLOCK_H * stride_kv_nheads
+    ki_range = tl.arange(0, BLOCK_H)[:, None] + pid_head * BLOCK_H
+    kj_range = tl.arange(0, k_dim_ceil)[None, :]
+    mask_k = (ki_range < head_num) & (kj_range < k_dim)
+    mask_v = ki_range < head_num
+    k_off = ki_range * stride_kv_nheads + kj_range
+    if v_dim > 0:
+        v_off = ki_range * stride_kv_nheads + k_dim + tl.arange(0, v_dim)[None, :]
+        v = tl.load(KV_ptr + v_off, mask=mask_v)
+    else:
+        v = tl.zeros((BLOCK_H, 1), dtype=KV.dtype.element_ty)
+    k = tl.load(KV_ptr + k_off, mask=mask_k)
 
-    K_ptr = O_KEY + pid_m * stride_k_seq + pid_head * BLOCK_H * stride_k_nheads
-    V_ptr = O_VALUE + pid_m * stride_v_seq + pid_head * BLOCK_H * stride_v_nheads
+    K_ptr = O_KEY + pid_m * stride_k_seq # + pid_head * BLOCK_H * stride_k_nheads
+    V_ptr = O_VALUE + pid_m * stride_v_seq # + pid_head * BLOCK_H * stride_v_nheads
 
-    k_out_off = tl.arange(0, BLOCK_H)[:, None] * stride_k_nheads + tl.arange(0, k_dim)[None, :]
-    v_out_off = tl.arange(0, BLOCK_H)[:, None] * stride_v_nheads + tl.arange(0, v_dim)[None, :]
-    tl.store(K_ptr + k_out_off, k, mask=mask)
-    tl.store(V_ptr + v_out_off, v, mask=mask)
+    k_out_off = ki_range * stride_k_nheads + kj_range
+    tl.store(K_ptr + k_out_off, k, mask=mask_k)
+    if v_dim > 0:
+        v_out_off = ki_range * stride_v_nheads + tl.arange(0, v_dim)[None, :]
+        tl.store(V_ptr + v_out_off, v, mask=mask_v)
 
     EMB = K_POS_EMB + pid_m * stride_emb_seq
     # x1 = t[..., 0::2], x2 = t[..., 1::2]
@@ -460,14 +467,16 @@ def rotary_fwd_kv_kernel(
     x_left = x_left.expand_dims(0).broadcast_to(BLOCK_H, emb_dim // 2)
     x_right = x_right.expand_dims(0).broadcast_to(BLOCK_H, emb_dim // 2)
 
+    x_range = tl.arange(0, BLOCK_H)[:, None] + pid_head * BLOCK_H
+    mask_x = x_range < head_num
     x_left_off = (
-        tl.arange(0, BLOCK_H)[:, None] * stride_k_nheads
+        x_range * stride_k_nheads
         + k_dim
         + tl.arange(0, emb_dim // 2)[None, :]
     )
     x_right_off = x_left_off + emb_dim // 2
-    tl.store(K_ptr + x_left_off, x_left, mask=mask)
-    tl.store(K_ptr + x_right_off, x_right, mask=mask)
+    tl.store(K_ptr + x_left_off, x_left, mask=mask_x)
+    tl.store(K_ptr + x_right_off, x_right, mask=mask_x)
 
 
 @triton.autotune(
@@ -493,6 +502,7 @@ def rotary_bwd_kv_kernel(
     SIN,
     emb_dim: tl.constexpr,
     k_dim: tl.constexpr,
+    k_dim_ceil: tl.constexpr,
     v_dim: tl.constexpr,
     head_num: tl.constexpr,
     batch_size,
@@ -533,27 +543,32 @@ def rotary_bwd_kv_kernel(
     else:
         token_idx = _get_thd_token_idx(cu_seqlens_kv, pid_m, seq_num, cp_rank, cp_size)
 
-    dKV_ptr = dKV + pid_m * stride_dkv_seq + pid_head * BLOCK_H * stride_dkv_nheads
-    dkv_off = tl.arange(0, BLOCK_H)[:, None] * stride_dkv_nheads
-    mask = dkv_off < head_num * stride_dkv_nheads
-    dk_out_off = dkv_off + tl.arange(0, k_dim)[None, :]
-    dv_out_off = dkv_off + k_dim + tl.arange(0, v_dim)[None, :]
-
-    dK_ptr = dK + pid_m * stride_dk_seq + pid_head * BLOCK_H * stride_dk_nheads
-    dV_ptr = dV + pid_m * stride_dv_seq + pid_head * BLOCK_H * stride_dv_nheads
-    dk_in_off = tl.arange(0, BLOCK_H)[:, None] * stride_dk_nheads + tl.arange(0, k_dim)[None, :]
-    dv_in_off = tl.arange(0, BLOCK_H)[:, None] * stride_dv_nheads + tl.arange(0, v_dim)[None, :]
-    dk = tl.load(dK_ptr + dk_in_off, mask=mask)
-    dv = tl.load(dV_ptr + dv_in_off, mask=mask)
-    tl.store(dKV_ptr + dk_out_off, dk, mask=mask)
-    tl.store(dKV_ptr + dv_out_off, dv, mask=mask)
+    dKV_ptr = dKV + pid_m * stride_dkv_seq # + pid_head * BLOCK_H * stride_dkv_nheads
+    ki_range = tl.arange(0, BLOCK_H)[:, None] + pid_head * BLOCK_H 
+    kj_range = tl.arange(0, k_dim_ceil)[None, :]
+    mask_k = (ki_range < head_num) & (kj_range < k_dim)
+    mask_v = ki_range < head_num
+    dk_out_off = ki_range * stride_dkv_nheads + kj_range
+
+    dK_ptr = dK + pid_m * stride_dk_seq # + pid_head * BLOCK_H * stride_dk_nheads
+    dV_ptr = dV + pid_m * stride_dv_seq # + pid_head * BLOCK_H * stride_dv_nheads
+    dk_in_off = ki_range * stride_dk_nheads + kj_range
+
+    dk = tl.load(dK_ptr + dk_in_off, mask=mask_k)
+    tl.store(dKV_ptr + dk_out_off, dk, mask=mask_k)
+    
+    if v_dim > 0:
+        dv_out_off = ki_range * stride_dkv_nheads + k_dim + tl.arange(0, v_dim)[None, :]
+        dv_in_off = ki_range * stride_dv_nheads + tl.arange(0, v_dim)[None, :]
+        dv = tl.load(dV_ptr + dv_in_off, mask=mask_v)
+        tl.store(dKV_ptr + dv_out_off, dv, mask=mask_v)
 
     if pid_head == 0:
         x_left_accum = tl.zeros((BLOCK_H, emb_dim // 2), dtype=tl.float32)
         x_right_accum = tl.zeros((BLOCK_H, emb_dim // 2), dtype=tl.float32)
         for i in tl.static_range(triton.cdiv(head_num, BLOCK_H)):
-            dK_ptr = dK + pid_m * stride_dk_seq + i * BLOCK_H * stride_dk_nheads
-            x_off = tl.arange(0, BLOCK_H)[:, None] * stride_dk_nheads + k_dim
+            dK_ptr = dK + pid_m * stride_dk_seq # + i * BLOCK_H * stride_dk_nheads
+            x_off = tl.arange(0, BLOCK_H)[:, None] * stride_dk_nheads + k_dim + i * BLOCK_H * stride_dk_nheads
             mask = x_off < head_num * stride_dk_nheads
             x_left_off = x_off + tl.arange(0, emb_dim // 2)[None, :]
             x_right_off = x_left_off + emb_dim // 2
@@ -632,6 +647,7 @@ class ApplyMLARotaryEmbKV(torch.autograd.Function):
 
         o_key = kv.new_empty(total_seqlen, nheads, emb_dim + k_dim)
         o_value = kv.new_empty(total_seqlen, nheads, v_dim)
+        k_dim_ceil = triton.next_power_of_2(k_dim)
 
         grid = lambda META: (total_seqlen, triton.cdiv(nheads, META["BLOCK_H"]))
         rotary_fwd_kv_kernel[grid](
@@ -643,6 +659,7 @@ class ApplyMLARotaryEmbKV(torch.autograd.Function):
             sin,
             emb_dim,
             k_dim,
+            k_dim_ceil,
             v_dim,
             nheads,
             batch_size,
@@ -700,6 +717,7 @@ class ApplyMLARotaryEmbKV(torch.autograd.Function):
 
         d_kv = dk.new_empty(total_seqlen, nheads, ctx.k_dim + ctx.v_dim)
         d_emb = dk.new_empty(total_seqlen, 1, ctx.emb_dim)
+        k_dim_ceil = triton.next_power_of_2(ctx.k_dim)
 
         grid = lambda META: (total_seqlen, triton.cdiv(nheads, META["BLOCK_H"]))
         rotary_bwd_kv_kernel[grid](
@@ -711,6 +729,7 @@ class ApplyMLARotaryEmbKV(torch.autograd.Function):
             sin,
             ctx.emb_dim,
             ctx.k_dim,
+            k_dim_ceil,
             ctx.v_dim,
             nheads,
             batch_size,
diff --git a/megatron/core/models/common/language_module/language_module.py b/megatron/core/models/common/language_module/language_module.py
index 13d74aa52..060898a7a 100644
--- a/megatron/core/models/common/language_module/language_module.py
+++ b/megatron/core/models/common/language_module/language_module.py
@@ -184,7 +184,15 @@ class LanguageModule(MegatronModule):
             assert (
                 column_parallel_linear is not None
             ), "column_parallel_linear cannot be None when not using fused linear cross entropy."
-            logits, _ = column_parallel_linear(hidden, **col_linear_kwargs)
+            # output
+            output_layer_params = {k: v.detach() for k, v in column_parallel_linear.named_parameters()}
+            output_layer_buffers = dict(column_parallel_linear.named_buffers())
+            logits, _ = torch.func.functional_call(
+                column_parallel_linear,
+                {**output_layer_params, **output_layer_buffers},
+                (hidden,),
+                col_linear_kwargs,
+            )
 
             return self.compute_language_model_loss(labels, logits)
 
diff --git a/megatron/core/models/gpt/gpt_layer_specs.py b/megatron/core/models/gpt/gpt_layer_specs.py
index e21127b87..712793853 100755
--- a/megatron/core/models/gpt/gpt_layer_specs.py
+++ b/megatron/core/models/gpt/gpt_layer_specs.py
@@ -188,6 +188,8 @@ def get_gpt_layer_with_transformer_engine_spec(
     use_kitchen: bool = False,
     use_te_activation_func: bool = False,
     fallback_to_eager_attn: bool = False,
+    post_self_attn_layernorm: bool = False,
+    post_mlp_layernorm: bool = False,
 ) -> ModuleSpec:
     """Use this spec to use lower-level Transformer Engine modules (required for fp8 training).
 
@@ -260,6 +262,8 @@ def get_gpt_layer_with_transformer_engine_spec(
         mlp=mlp,
         sharded_state_dict_keys_map=sharded_state_dict_keys_map,
         normalization=normalization,
+        post_self_attn_layernorm=post_self_attn_layernorm,
+        post_mlp_layernorm=post_mlp_layernorm,
     )
 
 
@@ -349,6 +353,8 @@ def get_transformer_layer_spec_for_backend(
     mlp: ModuleSpec,
     sharded_state_dict_keys_map: Optional[dict] = None,
     normalization: Optional[str] = None,
+    post_self_attn_layernorm: bool = False,
+    post_mlp_layernorm: bool = False,
 ) -> ModuleSpec:
     """Helper function to get module spec for TransformerLayer"""
 
@@ -371,9 +377,11 @@ def get_transformer_layer_spec_for_backend(
             input_layernorm=input_layernorm,
             self_attention=attention,
             self_attn_bda=get_bias_dropout_add,
+            post_self_attn_layernorm=TENorm if post_self_attn_layernorm else IdentityOp,
             pre_mlp_layernorm=pre_mlp_layernorm,
             mlp=mlp,
             mlp_bda=get_bias_dropout_add,
+            post_mlp_layernorm=TENorm if post_mlp_layernorm else IdentityOp,
             sharded_state_dict_keys_map=sharded_state_dict_keys_map,
         ),
     )
diff --git a/megatron/core/models/gpt/gpt_model.py b/megatron/core/models/gpt/gpt_model.py
index a1230568c..1fd52f65a 100644
--- a/megatron/core/models/gpt/gpt_model.py
+++ b/megatron/core/models/gpt/gpt_model.py
@@ -446,6 +446,7 @@ class GPTModel(LanguageModule):
         *,
         inference_params: Optional[BaseInferenceContext] = None,
         loss_mask: Optional[Tensor] = None,
+        mtp_kwargs: Optional[dict] = {},
     ) -> Tensor:
         """Forward function of the GPT Model This function passes the input tensors
         through the embedding layer, and then the decoder and finally into the post
@@ -508,6 +509,7 @@ class GPTModel(LanguageModule):
             runtime_gather_output=runtime_gather_output,
             extra_block_kwargs=extra_block_kwargs,
             inference_context=inference_context,
+            mtp_kwargs=mtp_kwargs,
         )
 
     def _postprocess(
@@ -529,6 +531,7 @@ class GPTModel(LanguageModule):
         runtime_gather_output=None,
         extra_block_kwargs=None,
         inference_context=None,
+        mtp_kwargs={},
     ):
         """Postprocesses decoder hidden states to generate logits or compute loss.
 
@@ -543,7 +546,8 @@ class GPTModel(LanguageModule):
         output_weight = None
         if self.share_embeddings_and_output_weights:
             output_weight = self.shared_embedding_or_output_weight()
-        if mtp_in_postprocess:
+
+        if mtp_in_postprocess and mtp_kwargs.get('mtp_labels', None) is not None:
             hidden_states = self.mtp(
                 input_ids=input_ids,
                 position_ids=position_ids,
@@ -563,13 +567,18 @@ class GPTModel(LanguageModule):
             return hidden_states
 
         # Skip when mtp_num_layers is None or 0
-        if self.config.mtp_num_layers:
-            mtp_labels = labels.clone()
+        if self.config.mtp_num_layers and mtp_kwargs.get('mtp_labels', None) is not None:
+            mtp_labels = mtp_kwargs['mtp_labels'].clone()
+            mtp_labels, _ = roll_tensor(mtp_labels, shifts=-1, dims=-1, cp_group=self.cp_group, packed_seq_params=packed_seq_params)
+
             hidden_states_list = torch.chunk(hidden_states, 1 + self.config.mtp_num_layers, dim=0)
             hidden_states = hidden_states_list[0]
             if loss_mask is None:
                 # if loss_mask is not provided, use all ones as loss_mask
                 loss_mask = torch.ones_like(mtp_labels)
+            else:
+                # Otherwise, roll the loss_mask to keep up with the mtp_labels
+                loss_mask, _ = roll_tensor(loss_mask, shifts=-1, dims=-1, cp_group=self.cp_group, packed_seq_params=packed_seq_params)
             for mtp_layer_number in range(self.config.mtp_num_layers):
                 # Calc loss for the current Multi-Token Prediction (MTP) layers.
                 mtp_labels, _ = roll_tensor(
@@ -595,7 +604,7 @@ class GPTModel(LanguageModule):
                     sequence_parallel_enabled=self.output_layer.sequence_parallel,
                     column_parallel_linear=self.output_layer,
                     col_linear_kwargs={
-                        'weight': output_weight,
+                        'weight': output_weight.detach() if output_weight else None,
                         'runtime_gather_output': runtime_gather_output,
                     },
                 )
diff --git a/megatron/core/optimizer/distrib_optimizer.py b/megatron/core/optimizer/distrib_optimizer.py
index 6e093f96f..eac21a3ea 100644
--- a/megatron/core/optimizer/distrib_optimizer.py
+++ b/megatron/core/optimizer/distrib_optimizer.py
@@ -677,6 +677,8 @@ class DistributedOptimizer(MixedPrecisionOptimizer):
                 # TE FusedAdam will not accumulate step for empty param groups, so we need to
                 # align the step across param groups.
                 param_group["step"] = int(step)
+            if "step" in param_group and param_group["step"] is None:
+                del param_group["step"]
 
         # Grad scaler state.
         if self.grad_scaler:
@@ -1646,6 +1648,8 @@ class DistributedOptimizer(MixedPrecisionOptimizer):
                             if key == 'padding':
                                 tensors[key] = LocalNonpersistentObject(tensors[key])
                                 continue
+                            if key == 'step':
+                                continue
                             assert tensors[key].shape == (gbuf_local_end - gbuf_local_start,), (
                                 tensors[key].shape,
                                 gbuf_local_start,
diff --git a/megatron/core/parallel_state.py b/megatron/core/parallel_state.py
index a273002b9..4f821cfd5 100644
--- a/megatron/core/parallel_state.py
+++ b/megatron/core/parallel_state.py
@@ -11,6 +11,7 @@ from typing import Callable, List, Optional
 
 import numpy as np
 import torch
+import torch.distributed as dist
 
 from .utils import GlobalMemoryBuffer, is_torch_min_version
 
diff --git a/megatron/core/pipeline_parallel/p2p_communication.py b/megatron/core/pipeline_parallel/p2p_communication.py
index ac839c21f..f18309217 100644
--- a/megatron/core/pipeline_parallel/p2p_communication.py
+++ b/megatron/core/pipeline_parallel/p2p_communication.py
@@ -26,22 +26,22 @@ def _batched_p2p_ops(
     ops = []
     if tensor_send_prev is not None:
         send_prev_op = torch.distributed.P2POp(
-            torch.distributed.isend, tensor_send_prev, prev_pipeline_rank, group
+            torch.distributed.isend, tensor_send_prev, prev_pipeline_rank,
         )
         ops.append(send_prev_op)
     if tensor_recv_prev is not None:
         recv_prev_op = torch.distributed.P2POp(
-            torch.distributed.irecv, tensor_recv_prev, prev_pipeline_rank, group
+            torch.distributed.irecv, tensor_recv_prev, prev_pipeline_rank,
         )
         ops.append(recv_prev_op)
     if tensor_send_next is not None:
         send_next_op = torch.distributed.P2POp(
-            torch.distributed.isend, tensor_send_next, next_pipeline_rank, group
+            torch.distributed.isend, tensor_send_next, next_pipeline_rank,
         )
         ops.append(send_next_op)
     if tensor_recv_next is not None:
         recv_next_op = torch.distributed.P2POp(
-            torch.distributed.irecv, tensor_recv_next, next_pipeline_rank, group
+            torch.distributed.irecv, tensor_recv_next, next_pipeline_rank,
         )
         ops.append(recv_next_op)
     if len(ops) > 0:
diff --git a/megatron/core/transformer/moe/moe_utils.py b/megatron/core/transformer/moe/moe_utils.py
index 28cff06f5..58dc4bb70 100644
--- a/megatron/core/transformer/moe/moe_utils.py
+++ b/megatron/core/transformer/moe/moe_utils.py
@@ -587,6 +587,9 @@ def topk_routing_with_score_function(
         else:
             return torch.topk(scores, k=topk, dim=1)
 
+    from slime.utils.routing_replay import get_routing_replay_compute_topk
+    compute_topk = get_routing_replay_compute_topk(compute_topk)
+
     if score_function == "softmax":
         if use_pre_softmax:
             scores = torch.softmax(logits, dim=-1, dtype=torch.float32).type_as(logits)
diff --git a/megatron/core/transformer/moe/router.py b/megatron/core/transformer/moe/router.py
index 16fc9d9af..517944f25 100644
--- a/megatron/core/transformer/moe/router.py
+++ b/megatron/core/transformer/moe/router.py
@@ -201,6 +201,9 @@ class TopKRouter(Router):
             self.global_tokens_per_expert = None
             self.ga_steps = None
 
+        from slime.utils.routing_replay import register_routing_replay
+        register_routing_replay(self)
+
     def _maintain_float32_expert_bias(self):
         """
         Maintain the expert bias in float32.
diff --git a/megatron/core/transformer/multi_token_prediction.py b/megatron/core/transformer/multi_token_prediction.py
index a8f4abfcd..f33f6f05e 100755
--- a/megatron/core/transformer/multi_token_prediction.py
+++ b/megatron/core/transformer/multi_token_prediction.py
@@ -6,6 +6,7 @@ from typing import Callable, List, Optional, Union
 
 import torch
 from torch import Tensor
+import warnings
 
 from megatron.core import InferenceParams, parallel_state, tensor_parallel
 from megatron.core.dist_checkpointing.mapping import ShardedStateDict
@@ -714,17 +715,19 @@ class MultiTokenPredictionLayer(MegatronModule):
             cp_group=self.cp_group,
             packed_seq_params=packed_seq_params,
         )
-        position_ids, _ = roll_tensor(
-            position_ids,
-            shifts=-1,
-            dims=-1,
-            cp_group=self.cp_group,
-            packed_seq_params=packed_seq_params,
-        )
+        if position_ids is not None:
+            position_ids, _ = roll_tensor(
+                position_ids,
+                shifts=-1,
+                dims=-1,
+                cp_group=self.cp_group,
+                packed_seq_params=packed_seq_params,
+            )
         # embedding
         decoder_input = embedding(input_ids=input_ids, position_ids=position_ids)
+        decoder_input = decoder_input.detach()
 
-        hidden_states = make_viewless_tensor(inp=hidden_states, requires_grad=True, keep_graph=True)
+        hidden_states = make_viewless_tensor(inp=hidden_states, requires_grad=True, keep_graph=False)
 
         return input_ids, position_ids, decoder_input, hidden_states
 
@@ -826,6 +829,51 @@ class MultiTokenPredictionLayer(MegatronModule):
         return hidden_states
 
     def _checkpointed_forward(self, forward_func, *args, **kwargs):
+        """Wrap `forward_func` with activation checkpointing while only passing tensors.
+
+        Non-tensor arguments (e.g., configuration objects, None) are captured via closure so
+        that checkpoint implementations never receive them directly, avoiding save_for_backward
+        issues with non-tensor inputs.
+        """
+
+        # TODO(jiajun): Is there any better implementation here?
+        positional_specs = []
+        kw_specs = []
+        tensor_args: List[torch.Tensor] = []
+
+        for arg in args:
+            if torch.is_tensor(arg):
+                positional_specs.append(('tensor', len(tensor_args)))
+                tensor_args.append(arg)
+            else:
+                positional_specs.append(('const', arg))
+
+        for key, value in kwargs.items():
+            if torch.is_tensor(value):
+                kw_specs.append((key, ('tensor', len(tensor_args))))
+                tensor_args.append(value)
+            else:
+                kw_specs.append((key, ('const', value)))
+
+        def run(*flat_tensor_args):
+            rebuilt_args = []
+            for spec_type, payload in positional_specs:
+                if spec_type == 'tensor':
+                    rebuilt_args.append(flat_tensor_args[payload])
+                else:
+                    rebuilt_args.append(payload)
+
+            rebuilt_kwargs = {}
+            for key, (spec_type, payload) in kw_specs:
+                if spec_type == 'tensor':
+                    rebuilt_kwargs[key] = flat_tensor_args[payload]
+                else:
+                    rebuilt_kwargs[key] = payload
+
+            return forward_func(*rebuilt_args, **rebuilt_kwargs)
+
+        tensor_args_tuple = tuple(tensor_args)
+
         def checkpoint_handler():
             """Determines whether to use the `te_checkpoint` or `tensor_parallel.checkpoint`"""
             if self.config.fp8:
@@ -836,12 +884,11 @@ class MultiTokenPredictionLayer(MegatronModule):
                     self.config.distribute_saved_activations,
                     tensor_parallel.random.get_cuda_rng_tracker,
                     parallel_state.get_tensor_model_parallel_group(),
-                    *args,
-                    **kwargs,
+                    *tensor_args_tuple,
                 )
             else:
                 return tensor_parallel.checkpoint(
-                    forward_func, self.config.distribute_saved_activations, *args, *kwargs.values()
+                    run, self.config.distribute_saved_activations, *tensor_args_tuple
                 )
 
         if self.config.recompute_method == 'uniform':
diff --git a/megatron/core/transformer/transformer_config.py b/megatron/core/transformer/transformer_config.py
index e2705bd9f..a0aa109b5 100644
--- a/megatron/core/transformer/transformer_config.py
+++ b/megatron/core/transformer/transformer_config.py
@@ -210,6 +210,9 @@ class TransformerConfig(ModelParallelConfig):
     attention_output_gate: bool = False
     """Whether to apply output gate to the attention layers."""
 
+    post_self_attn_layernorm: bool = False
+    post_mlp_layernorm: bool = False
+
     test_mode: bool = False
     """Whether to run real-time tests."""
 
diff --git a/megatron/core/transformer/transformer_layer.py b/megatron/core/transformer/transformer_layer.py
index 3ea405770..5a42001b9 100644
--- a/megatron/core/transformer/transformer_layer.py
+++ b/megatron/core/transformer/transformer_layer.py
@@ -223,6 +223,7 @@ class TransformerLayerSubmodules:
     input_layernorm: Union[ModuleSpec, type] = IdentityOp
     self_attention: Union[ModuleSpec, type] = IdentityOp
     self_attn_bda: Union[ModuleSpec, type] = IdentityFuncOp
+    post_self_attn_layernorm: Union[ModuleSpec, type] = IdentityOp
 
     pre_cross_attn_layernorm: Union[ModuleSpec, type] = IdentityOp
     cross_attention: Union[ModuleSpec, type] = IdentityOp
@@ -231,6 +232,7 @@ class TransformerLayerSubmodules:
     pre_mlp_layernorm: Union[ModuleSpec, type] = IdentityOp
     mlp: Union[ModuleSpec, type] = IdentityOp
     mlp_bda: Union[ModuleSpec, type] = IdentityFuncOp
+    post_mlp_layernorm: Union[ModuleSpec, type] = IdentityOp
 
     # Mapping for sharded tensor keys to be applied in `sharded_state_dict` method
     sharded_state_dict_keys_map: Dict[str, str] = field(default_factory=dict)
@@ -310,6 +312,13 @@ class TransformerLayer(GraphableMegatronModule, BaseTransformerLayer):
         # [Module 3: BiasDropoutFusion]
         self.self_attn_bda = build_module(submodules.self_attn_bda)
 
+        self.post_self_attn_layernorm = build_module(
+            submodules.post_self_attn_layernorm,
+            config=self.config,
+            hidden_size=self.config.hidden_size,
+            eps=self.config.layernorm_epsilon,
+        )
+
         # [Module 4: Post SelfAttention] Optional Layernorm after self-attn
         self.pre_cross_attn_layernorm = build_module(
             submodules.pre_cross_attn_layernorm,
@@ -375,6 +384,13 @@ class TransformerLayer(GraphableMegatronModule, BaseTransformerLayer):
 
         self.is_moe_layer = isinstance(self.mlp, MoELayer)
 
+        self.post_mlp_layernorm = build_module(
+            submodules.post_mlp_layernorm,
+            config=self.config,
+            hidden_size=self.config.hidden_size,
+            eps=self.config.layernorm_epsilon
+        )
+
         self.recompute_input_layernorm = False
         self.recompute_pre_mlp_layernorm = False
         self.recompute_mlp = False
@@ -551,6 +567,10 @@ class TransformerLayer(GraphableMegatronModule, BaseTransformerLayer):
                 attention_output_with_bias[0]
             )
 
+        attention_output, attention_output_bias = attention_output_with_bias
+        attention_output = self.post_self_attn_layernorm(attention_output)
+        attention_output_with_bias = (attention_output, attention_output_bias)
+
         # TODO: could we move `bias_dropout_add_exec_handler` itself
         # inside the module provided in the `bias_dropout_add_spec` module?
         nvtx_range_push(suffix="self_attn_bda")
@@ -677,6 +697,10 @@ class TransformerLayer(GraphableMegatronModule, BaseTransformerLayer):
         else:
             mlp_output_with_bias = self.mlp(pre_mlp_layernorm_output)
 
+        mlp_output, mlp_output_bias = mlp_output_with_bias
+        mlp_output = self.post_mlp_layernorm(mlp_output)
+        mlp_output_with_bias = (mlp_output, mlp_output_bias)
+
         if self.recompute_pre_mlp_layernorm:
             # discard the output of the pre-mlp layernorm and register the recompute
             # as a gradient hook of mlp_output_with_bias[0]
diff --git a/megatron/training/arguments.py b/megatron/training/arguments.py
index b267c8a81..83736acdc 100644
--- a/megatron/training/arguments.py
+++ b/megatron/training/arguments.py
@@ -1398,6 +1398,9 @@ def core_transformer_config_from_args(args, config_class=None):
 
     kw_args['inference_sampling_seed'] = args.seed
 
+    kw_args['post_self_attn_layernorm'] = args.post_self_attn_layernorm
+    kw_args['post_mlp_layernorm'] = args.post_mlp_layernorm
+
     # handle quantization config
     # NOTE: Kitchen arguments are only added to the namespace when
     # Kitchen library is available.
@@ -1764,6 +1767,12 @@ def _add_network_size_args(parser):
                        action='store_true',
                        help='If set, use original BERT residula connection '
                        'ordering.')
+    group.add_argument('--post-self-attn-layernorm', action='store_true',
+                       help='If set, use post self attention layernorm.')
+    group.add_argument('--post-mlp-layernorm', action='store_true',
+                       help='If set, use post MLP layernorm.')
+    group.add_argument('--use-gated-attention', action='store_true',
+                       help='If set, use gated attention as in Qwen3Next')
     group.add_argument('--openai-gelu', action='store_true',
                        help='Use OpenAIs GeLU implementation. This option'
                        'should not be used unless for backward compatibility'
diff --git a/megatron/training/tokenizer/tokenizer.py b/megatron/training/tokenizer/tokenizer.py
index 13b7526ca..6c590f653 100644
--- a/megatron/training/tokenizer/tokenizer.py
+++ b/megatron/training/tokenizer/tokenizer.py
@@ -136,7 +136,7 @@ class _HuggingFaceTokenizer(MegatronLegacyTokenizer):
         # TODO(bnorick): download tokenizer once to lustre and use force offline to make sure all tasks read it from there
         self._tokenizer = transformers.AutoTokenizer.from_pretrained(
             pretrained_model_name_or_path=pretrained_model_name_or_path,
-            trust_remote_code=trust_remote_code,
+            trust_remote_code=True,
             **kwargs,
         )
         self._vocab = self._tokenizer.get_vocab()