Kernels
TaehyunKim commited on
Commit
1a3da4d
·
unverified ·
2 Parent(s): d65066c e93bd1e

Merge pull request #9 from MotifTechnologies/all2all_gather_scatter

Browse files
Files changed (41) hide show
  1. build/torch27-cxx11-cu118-x86_64-linux/optimizer/_ops.py +3 -3
  2. build/torch27-cxx11-cu118-x86_64-linux/optimizer/{_optimizer_9c21645_dirty.abi3.so → _optimizer_15336dc_dirty.abi3.so} +1 -1
  3. build/torch27-cxx11-cu118-x86_64-linux/optimizer/matmul_transpose_triton.py +128 -0
  4. build/torch27-cxx11-cu118-x86_64-linux/optimizer/muon.py +319 -104
  5. build/torch27-cxx11-cu126-x86_64-linux/optimizer/_ops.py +3 -3
  6. build/{torch28-cxx11-cu126-x86_64-linux/optimizer/_optimizer_9c21645_dirty.abi3.so → torch27-cxx11-cu126-x86_64-linux/optimizer/_optimizer_15336dc_dirty.abi3.so} +1 -1
  7. build/torch27-cxx11-cu126-x86_64-linux/optimizer/matmul_transpose_triton.py +128 -0
  8. build/torch27-cxx11-cu126-x86_64-linux/optimizer/muon.py +319 -104
  9. build/torch27-cxx11-cu128-x86_64-linux/optimizer/_ops.py +3 -3
  10. build/{torch28-cxx11-cu128-x86_64-linux/optimizer/_optimizer_9c21645_dirty.abi3.so → torch27-cxx11-cu128-x86_64-linux/optimizer/_optimizer_15336dc_dirty.abi3.so} +1 -1
  11. build/torch27-cxx11-cu128-x86_64-linux/optimizer/matmul_transpose_triton.py +128 -0
  12. build/torch27-cxx11-cu128-x86_64-linux/optimizer/muon.py +319 -104
  13. build/torch27-cxx11-rocm63-x86_64-linux/optimizer/_ops.py +3 -3
  14. build/torch27-cxx11-rocm63-x86_64-linux/optimizer/{_optimizer_9c21645_dirty.abi3.so → _optimizer_15336dc_dirty.abi3.so} +1 -1
  15. build/torch27-cxx11-rocm63-x86_64-linux/optimizer/matmul_transpose_triton.py +128 -0
  16. build/torch27-cxx11-rocm63-x86_64-linux/optimizer/muon.py +319 -104
  17. build/torch28-cxx11-cu126-x86_64-linux/optimizer/_ops.py +3 -3
  18. build/{torch27-cxx11-cu126-x86_64-linux/optimizer/_optimizer_9c21645_dirty.abi3.so → torch28-cxx11-cu126-x86_64-linux/optimizer/_optimizer_15336dc_dirty.abi3.so} +1 -1
  19. build/torch28-cxx11-cu126-x86_64-linux/optimizer/matmul_transpose_triton.py +128 -0
  20. build/torch28-cxx11-cu126-x86_64-linux/optimizer/muon.py +319 -104
  21. build/torch28-cxx11-cu128-x86_64-linux/optimizer/_ops.py +3 -3
  22. build/{torch28-cxx11-cu129-x86_64-linux/optimizer/_optimizer_9c21645_dirty.abi3.so → torch28-cxx11-cu128-x86_64-linux/optimizer/_optimizer_15336dc_dirty.abi3.so} +1 -1
  23. build/torch28-cxx11-cu128-x86_64-linux/optimizer/matmul_transpose_triton.py +128 -0
  24. build/torch28-cxx11-cu128-x86_64-linux/optimizer/muon.py +319 -104
  25. build/torch28-cxx11-cu129-x86_64-linux/optimizer/_ops.py +3 -3
  26. build/{torch27-cxx11-cu128-x86_64-linux/optimizer/_optimizer_9c21645_dirty.abi3.so → torch28-cxx11-cu129-x86_64-linux/optimizer/_optimizer_15336dc_dirty.abi3.so} +1 -1
  27. build/torch28-cxx11-cu129-x86_64-linux/optimizer/matmul_transpose_triton.py +128 -0
  28. build/torch28-cxx11-cu129-x86_64-linux/optimizer/muon.py +319 -104
  29. build/torch28-cxx11-rocm63-x86_64-linux/optimizer/_ops.py +3 -3
  30. build/torch28-cxx11-rocm63-x86_64-linux/optimizer/{_optimizer_9c21645_dirty.abi3.so → _optimizer_15336dc_dirty.abi3.so} +1 -1
  31. build/torch28-cxx11-rocm63-x86_64-linux/optimizer/matmul_transpose_triton.py +128 -0
  32. build/torch28-cxx11-rocm63-x86_64-linux/optimizer/muon.py +319 -104
  33. build/torch28-cxx11-rocm64-x86_64-linux/optimizer/_ops.py +3 -3
  34. build/torch28-cxx11-rocm64-x86_64-linux/optimizer/{_optimizer_9c21645_dirty.abi3.so → _optimizer_15336dc_dirty.abi3.so} +1 -1
  35. build/torch28-cxx11-rocm64-x86_64-linux/optimizer/matmul_transpose_triton.py +128 -0
  36. build/torch28-cxx11-rocm64-x86_64-linux/optimizer/muon.py +319 -104
  37. test/test_muon/muon.py +0 -1
  38. test/test_muon/optimizer +1 -0
  39. test/test_muon/test.py +1 -1
  40. torch-ext/optimizer/matmul_transpose_triton.py +128 -0
  41. torch-ext/optimizer/muon.py +319 -104
build/torch27-cxx11-cu118-x86_64-linux/optimizer/_ops.py CHANGED
@@ -1,9 +1,9 @@
1
  import torch
2
- from . import _optimizer_9c21645_dirty
3
- ops = torch.ops._optimizer_9c21645_dirty
4
 
5
  def add_op_namespace_prefix(op_name: str):
6
  """
7
  Prefix op by namespace.
8
  """
9
- return f"_optimizer_9c21645_dirty::{op_name}"
 
1
  import torch
2
+ from . import _optimizer_15336dc_dirty
3
+ ops = torch.ops._optimizer_15336dc_dirty
4
 
5
  def add_op_namespace_prefix(op_name: str):
6
  """
7
  Prefix op by namespace.
8
  """
9
+ return f"_optimizer_15336dc_dirty::{op_name}"
build/torch27-cxx11-cu118-x86_64-linux/optimizer/{_optimizer_9c21645_dirty.abi3.so → _optimizer_15336dc_dirty.abi3.so} RENAMED
@@ -1,3 +1,3 @@
1
  version https://git-lfs.github.com/spec/v1
2
- oid sha256:bf8b97161714dff91953d26ae0bf59ebc9f3653ce57a3998723cc08aa97b71e6
3
  size 1787368
 
1
  version https://git-lfs.github.com/spec/v1
2
+ oid sha256:94a28c3602d8c7a6b216976b1fb09cdd1e9f61bfc9359a80f41b5b628efdfc28
3
  size 1787368
build/torch27-cxx11-cu118-x86_64-linux/optimizer/matmul_transpose_triton.py ADDED
@@ -0,0 +1,128 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # MIT License
2
+ #
3
+ # Copyright (c) 2025 Tianyang Lin
4
+ #
5
+ # Permission is hereby granted, free of charge, to any person obtaining a copy
6
+ # of this software and associated documentation files (the "Software"), to deal
7
+ # in the Software without restriction, including without limitation the rights
8
+ # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9
+ # copies of the Software, and to permit persons to whom the Software is
10
+ # furnished to do so, subject to the following conditions:
11
+ #
12
+ # The above copyright notice and this permission notice shall be included in all
13
+ # copies or substantial portions of the Software.
14
+ #
15
+ # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16
+ # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17
+ # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18
+ # AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19
+ # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20
+ # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21
+ # SOFTWARE.
22
+
23
+ import torch
24
+ import triton
25
+ import triton.language as tl
26
+
27
+
28
+ def get_autotune_config():
29
+ return [
30
+ triton.Config(
31
+ {
32
+ 'BLOCK_SIZE_M': blk_m,
33
+ 'BLOCK_SIZE_K': blk_k,
34
+ 'GROUP_SIZE_M': grp_sz
35
+ },
36
+ num_stages=n_stages,
37
+ num_warps=n_warps) for blk_m in [32, 64, 128]
38
+ for blk_k in [32, 64] for grp_sz in [8] for n_stages in [3, 4, 5]
39
+ for n_warps in [4, 8]
40
+ ]
41
+
42
+
43
+ @triton.autotune(
44
+ configs=get_autotune_config(),
45
+ key=['M', 'K'],
46
+ )
47
+ @triton.jit
48
+ def mmt_kernel(x, y, M, K, stride_xm, stride_xk, stride_ym, stride_yn,
49
+ BLOCK_SIZE_M: tl.constexpr, BLOCK_SIZE_K: tl.constexpr,
50
+ GROUP_SIZE_M: tl.constexpr):
51
+ """
52
+ Core kernel jit function of matmul_transpose that computes y = x @ x.T
53
+ The code is a simple adaptation from the triton `matmul` tutorial:
54
+ https://triton-lang.org/main/getting-started/tutorials/03-matrix-multiplication.html
55
+ """
56
+ pid = tl.program_id(axis=0)
57
+ num_pid_m = tl.cdiv(M, BLOCK_SIZE_M)
58
+ num_pid_n = tl.cdiv(M, BLOCK_SIZE_M)
59
+ num_pid_in_group = GROUP_SIZE_M * num_pid_n
60
+ group_id = pid // num_pid_in_group
61
+ first_pid_m = group_id * GROUP_SIZE_M
62
+ group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M)
63
+ pid_m = first_pid_m + ((pid % num_pid_in_group) % group_size_m)
64
+ pid_n = (pid % num_pid_in_group) // group_size_m
65
+ if pid_m > pid_n:
66
+ return
67
+
68
+ offs_xm = (pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)) % M
69
+ offs_xn = (pid_n * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)) % M
70
+ offs_k = tl.arange(0, BLOCK_SIZE_K)
71
+ # we use a & b ptrs to denote different rows of x.
72
+ a_ptrs = x + (offs_xm[:, None] * stride_xm + offs_k[None, :] * stride_xk)
73
+ b_ptrs = x + (offs_xn[:, None] * stride_xm + offs_k[None, :] * stride_xk)
74
+
75
+ accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_M), dtype=tl.float32)
76
+
77
+ for k in range(0, tl.cdiv(K, BLOCK_SIZE_K)):
78
+ a = tl.load(a_ptrs,
79
+ mask=offs_k[None, :] < K - k * BLOCK_SIZE_K,
80
+ other=0.0)
81
+ b = tl.load(b_ptrs,
82
+ mask=offs_k[None, :] < K - k * BLOCK_SIZE_K,
83
+ other=0.0)
84
+ accumulator = tl.dot(a, tl.permute(b, (1, 0)), accumulator)
85
+ a_ptrs += BLOCK_SIZE_K * stride_xk
86
+ b_ptrs += BLOCK_SIZE_K * stride_xk
87
+ # use dtype.element_ty to accommodate different input datatypes as in cpp templates
88
+ # https://github.com/triton-lang/triton/issues/2252
89
+ c = accumulator.to(x.dtype.element_ty)
90
+
91
+ offs_cm = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
92
+ offs_cn = pid_n * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
93
+ c_ptrs = y + stride_ym * offs_cm[:, None] + stride_yn * offs_cn[None, :]
94
+ c_mask = (offs_cm[:, None] < M) & (offs_cn[None, :] < M)
95
+ tl.store(c_ptrs, c, mask=c_mask)
96
+
97
+ # transpose and copy
98
+ if pid_m < pid_n:
99
+ ct_ptrs = y + stride_ym * offs_cn[:,
100
+ None] + stride_yn * offs_cm[None, :]
101
+ ct_mask = (offs_cn[:, None] < M) & (offs_cm[None, :] < M)
102
+ tl.store(ct_ptrs, tl.permute(c, (1, 0)), mask=ct_mask)
103
+
104
+
105
+ def matmul_transpose_assign(d_in, d_out):
106
+ assert d_in.is_cuda, "Input `d_in` must be a CUDA tensor"
107
+ assert d_out.is_cuda, "Input `d_out` must be a CUDA tensor"
108
+ assert d_in.device == d_out.device, "Inputs `d_in` and `d_out` must be on the same CUDA device"
109
+ assert d_in.dtype == d_out.dtype, "Inputs must have the same data type"
110
+ assert d_in.ndim == 2, "Input `d_in` must be a 2D tensor"
111
+ assert d_out.ndim == 2, "Input `d_out` must be a 2D tensor"
112
+ assert d_in.size(0) == d_out.size(0) == d_out.size(0), \
113
+ "First dimension of `d_in` must match first and second dimension of `d_out`"
114
+
115
+ d_in = d_in.contiguous()
116
+ M, K = d_in.shape
117
+ grid = lambda META: (triton.cdiv(M, META['BLOCK_SIZE_M']) * triton.cdiv(
118
+ M, META['BLOCK_SIZE_M']), )
119
+ with torch.cuda.device(d_in.device.index):
120
+ mmt_kernel[grid](d_in, d_out, M, K, d_in.stride(0), d_in.stride(1),
121
+ d_out.stride(0), d_out.stride(1))
122
+
123
+
124
+ def matmul_transpose(d_in):
125
+ M, _ = d_in.shape
126
+ d_out = torch.empty((M, M), device=d_in.device, dtype=d_in.dtype)
127
+ matmul_transpose_assign(d_in, d_out)
128
+ return d_out
build/torch27-cxx11-cu118-x86_64-linux/optimizer/muon.py CHANGED
@@ -8,14 +8,19 @@ import torch
8
  import torch.distributed as dist
9
  from torch.distributed._tensor import DTensor, Replicate, Shard
10
 
 
 
11
  logger = logging.getLogger(__name__)
12
 
 
 
13
 
14
  # This code snippet is a modified version adapted from the following GitHub repositories:
15
  # https://github.com/KellerJordan/Muon/blob/master/muon.py
16
  # Muon's Newton–Schulz iteration causes high variance in singular values
17
  # Idea: give each iteration its own 3 coefficients and optimize them via gradient descent.
18
  @torch.no_grad()
 
19
  def _zeropower_via_newtonschulz5(G, steps):
20
  """
21
  Newton-Schulz iteration to compute the zeroth power / orthogonalization of G. We opt to use a
@@ -27,13 +32,15 @@ def _zeropower_via_newtonschulz5(G, steps):
27
  performance at all relative to UV^T, where USV^T = G is the SVD.
28
  """
29
  assert len(G.shape) == 2
30
- assert G.dtype == torch.bfloat16
31
  X = G # no manual typecast
32
 
33
  if G.size(0) > G.size(1):
34
  X = X.T
35
  # Ensure spectral norm is at most 1
36
  X = X / (X.norm() + 1e-7)
 
 
37
  # Perform the NS iterations
38
  for a, b, c in [
39
  (4.0848, -6.8946, 2.9270),
@@ -42,13 +49,10 @@ def _zeropower_via_newtonschulz5(G, steps):
42
  (2.8769, -3.1427, 1.2046),
43
  (2.8366, -3.0525, 1.2012),
44
  ]:
45
- A = X @ X.T
46
- # B = (
47
- # b * A + c * A @ A
48
- # )
49
- B = torch.addmm(A, A, A, alpha=c, beta=b)
50
- # X = a * X + B @ X
51
- X = torch.addmm(X, B, X, alpha=1.0, beta=a)
52
 
53
  if G.size(0) > G.size(1):
54
  X = X.T
@@ -69,51 +73,142 @@ class _muon_state:
69
  qk_clip_state = None
70
 
71
 
 
 
 
 
 
 
 
 
72
  @torch.no_grad()
73
- def _gather(p, state, rank, comm_stream, none_grad):
74
  """
75
- Gather the gradients to worker_rank.
76
- If none_grad is True, free p.grad after the gather.
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
77
  """
78
  with torch.cuda.stream(comm_stream):
79
- g = p.grad
 
80
 
81
- if rank == state.worker_rank:
82
- num_ranks = dist.get_world_size(group=state.process_group)
83
- gather_list = [
84
- torch.empty_like(g.to_local(), dtype=torch.bfloat16)
85
- for _ in range(num_ranks)
86
- ]
87
- else:
88
- gather_list = None
89
-
90
- g = g.to(torch.bfloat16)
91
- torch.distributed.gather(
92
- g.to_local(),
93
- dst=state.worker_rank,
94
- gather_list=gather_list,
95
- group=state.process_group,
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
96
  )
97
- if rank == state.worker_rank:
98
- if state.gathered_grad is not None:
99
- raise RuntimeError(
100
- "Gather event already exists, which should not happen.")
101
- state.gathered_grad = torch.cat(gather_list, dim=0)
102
- state.gather_event = torch.cuda.Event()
103
- state.gather_event.record()
104
- else:
105
- state.gathered_grad = None
106
- state.gather_event = None
107
- gather_list = None
108
- if none_grad:
109
- # We can safely free p.grad without calling record_stream:
110
- # p.grad.to_local().record_stream(comm_stream)
111
- # Explanation:
112
- # 1. p.grad is created on the default stream, but the default stream
113
- # is synchronized with the comm stream later.
114
- # 2. There is no further activity on the default stream before the optimizer finishes.
115
- # Therefore, it is safe to free p.grad directly on the comm stream.
116
- p.grad = None
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
117
 
118
 
119
  @torch.no_grad()
@@ -127,45 +222,145 @@ def _compute_u(p, state, steps, rank, compute_stream):
127
  raise RuntimeError("Gather event must be set before compute.")
128
  compute_stream.wait_event(state.gather_event)
129
  u = _zeropower_via_newtonschulz5(state.gathered_grad, steps)
 
130
  state.computed_u = u
131
- state.scattered_u = torch.empty_like(p.to_local(),
132
- dtype=torch.bfloat16)
133
- state.compute_event = torch.cuda.Event()
134
- state.compute_event.record()
135
- u = None
136
 
137
 
138
  @torch.no_grad()
139
- def _scatter(p, state, rank, comm_stream):
140
  """
141
- Scatter the computed_u from worker_rank to all ranks.
 
142
  """
 
 
 
 
 
 
 
 
 
 
143
 
 
 
 
 
144
  with torch.cuda.stream(comm_stream):
145
- if state.compute_event is None:
146
- raise RuntimeError("Compute event must be set before scatter.")
147
- comm_stream.wait_event(state.compute_event)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
148
 
149
- if rank == state.worker_rank:
150
- num_ranks = dist.get_world_size(group=state.process_group)
151
- # Clear the gathered gradient to free memory
152
- state.gathered_grad = None
 
 
 
 
153
 
154
- u = state.computed_u
155
- scatter_list = list(torch.split(u, p.size(0) // num_ranks, dim=0))
156
- scatter_list = [s.contiguous() for s in scatter_list]
 
 
 
 
 
 
157
  else:
158
- scatter_list = None
 
 
159
 
160
- torch.distributed.scatter(
161
- state.scattered_u,
162
- scatter_list=scatter_list,
163
- src=state.worker_rank,
164
- group=state.process_group,
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
165
  )
166
- state.scatter_event = torch.cuda.Event()
167
- state.scatter_event.record()
168
- scatter_list = None
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
169
 
170
 
171
  def _update_param(p, state, lr, adjusted_lr, weight_decay, rank,
@@ -321,6 +516,11 @@ class Muon(torch.optim.Optimizer):
321
  "head_dim": 128,
322
  "threshold": 100
323
  }
 
 
 
 
 
324
  """
325
 
326
  def __init__(self,
@@ -339,7 +539,8 @@ class Muon(torch.optim.Optimizer):
339
  "k_indices": [],
340
  "head_dim": 128,
341
  "threshold": 100
342
- }):
 
343
  defaults = dict(
344
  lr=lr,
345
  weight_decay=weight_decay,
@@ -363,15 +564,13 @@ class Muon(torch.optim.Optimizer):
363
 
364
  super().__init__(params, defaults)
365
 
366
- if dist.is_initialized():
367
- self.rank = dist.get_rank()
368
- else:
369
- self.rank = None
370
 
371
  self.comm_stream = torch.cuda.Stream()
372
  self.compute_stream = torch.cuda.Stream()
373
  self.debug = debug
374
  self.clip_config = clip_config
 
375
 
376
  def _calc_flops(self, G, steps):
377
  assert len(G.shape) == 2
@@ -444,11 +643,18 @@ class Muon(torch.optim.Optimizer):
444
  if mesh is None:
445
  mesh = p.device_mesh
446
  shard_mesh, process_group = self.get_shard_mesh(p, self.rank)
 
 
 
 
 
447
  elif mesh != p.device_mesh:
448
  raise ValueError("All parameters must be on the same mesh.")
449
 
 
450
  param_to_state[id(p)] = _muon_state()
451
- param_to_state[id(p)].worker_rank = shard_mesh[round_robin].item()
 
452
  param_to_state[id(p)].process_group = process_group
453
  qk_clip_state = self.get_qk_clip_info(n, qk_logits)
454
  param_to_state[id(p)].qk_clip_state = qk_clip_state
@@ -478,7 +684,7 @@ class Muon(torch.optim.Optimizer):
478
  else:
479
  g = buf
480
 
481
- u = _zeropower_via_newtonschulz5(g.bfloat16(),
482
  steps=group["ns_steps"])
483
 
484
  adjusted_lr = self.adjust_lr_for_muon(lr, p.shape)
@@ -493,15 +699,12 @@ class Muon(torch.optim.Optimizer):
493
  def _update_g(self, p, g, group, momentum):
494
  # calc update
495
  state = self.state[p]
496
- if "momentum_buffer" not in state:
497
- state["momentum_buffer"] = torch.zeros_like(g)
498
- buf = state["momentum_buffer"]
499
- buf.mul_(momentum).add_(g)
500
  if group["nesterov"]:
501
- g = g.add(buf, alpha=momentum)
502
- else:
503
- g = buf
504
- return g
505
 
506
  @staticmethod
507
  def _update_p(p, u, lr, adjusted_lr, weight_decay):
@@ -585,11 +788,17 @@ class Muon(torch.optim.Optimizer):
585
  param_to_state, ordered_params = self.init_state_and_assign_params(
586
  names, params, group, qk_logits)
587
 
588
- def enqueue_gathers(start_idx, chunk_size):
589
- for p in ordered_params[start_idx:start_idx + chunk_size]:
590
- state = param_to_state[id(p)]
591
- _gather(p, state, self.rank, self.comm_stream,
592
- group["none_grad"])
 
 
 
 
 
 
593
 
594
  def enqueue_computes(start_idx, chunk_size):
595
  for p in ordered_params[start_idx:start_idx + chunk_size]:
@@ -597,10 +806,14 @@ class Muon(torch.optim.Optimizer):
597
  _compute_u(p, state, group["ns_steps"], self.rank,
598
  self.compute_stream)
599
 
600
- def enqueue_scatters(start_idx, chunk_size):
601
- for p in ordered_params[start_idx:start_idx + chunk_size]:
602
- state = param_to_state[id(p)]
603
- _scatter(p, state, self.rank, self.comm_stream)
 
 
 
 
604
 
605
  def enqueue_update_param(start_idx, chunk_size):
606
  for p in ordered_params[start_idx:start_idx + chunk_size]:
@@ -615,14 +828,16 @@ class Muon(torch.optim.Optimizer):
615
  # Wait grad update
616
  self.comm_stream.wait_stream(torch.cuda.current_stream())
617
 
618
- enqueue_gathers(0, chunk_size)
 
 
 
 
619
  for i in range(0, len(params) + chunk_size - 1, chunk_size):
620
- enqueue_computes(i, chunk_size)
621
- if i > 0:
622
- enqueue_update_param(i - chunk_size, chunk_size)
623
- enqueue_gathers(i + chunk_size, chunk_size)
624
- enqueue_scatters(i, chunk_size)
625
- enqueue_update_param(i, chunk_size)
626
 
627
  # Wait the last update_param to finish
628
  torch.cuda.current_stream().wait_stream(self.compute_stream)
 
8
  import torch.distributed as dist
9
  from torch.distributed._tensor import DTensor, Replicate, Shard
10
 
11
+ from .matmul_transpose_triton import matmul_transpose_assign
12
+
13
  logger = logging.getLogger(__name__)
14
 
15
+ COMM_DTYPE = torch.bfloat16
16
+
17
 
18
  # This code snippet is a modified version adapted from the following GitHub repositories:
19
  # https://github.com/KellerJordan/Muon/blob/master/muon.py
20
  # Muon's Newton–Schulz iteration causes high variance in singular values
21
  # Idea: give each iteration its own 3 coefficients and optimize them via gradient descent.
22
  @torch.no_grad()
23
+ # matmul_transpose_assign from : https://github.com/nil0x9/flash-muon
24
  def _zeropower_via_newtonschulz5(G, steps):
25
  """
26
  Newton-Schulz iteration to compute the zeroth power / orthogonalization of G. We opt to use a
 
32
  performance at all relative to UV^T, where USV^T = G is the SVD.
33
  """
34
  assert len(G.shape) == 2
35
+ assert G.dtype == COMM_DTYPE
36
  X = G # no manual typecast
37
 
38
  if G.size(0) > G.size(1):
39
  X = X.T
40
  # Ensure spectral norm is at most 1
41
  X = X / (X.norm() + 1e-7)
42
+ buf1 = torch.empty(X.size(0), X.size(0), dtype=X.dtype, device=X.device)
43
+ buf2 = torch.empty(X.size(0), X.size(0), dtype=X.dtype, device=X.device)
44
  # Perform the NS iterations
45
  for a, b, c in [
46
  (4.0848, -6.8946, 2.9270),
 
49
  (2.8769, -3.1427, 1.2046),
50
  (2.8366, -3.0525, 1.2012),
51
  ]:
52
+ matmul_transpose_assign(X, buf1)
53
+ matmul_transpose_assign(buf1, buf2)
54
+ buf1.mul_(b).add_(buf2, alpha=c)
55
+ X = torch.addmm(X, buf1, X, alpha=1.0, beta=a)
 
 
 
56
 
57
  if G.size(0) > G.size(1):
58
  X = X.T
 
73
  qk_clip_state = None
74
 
75
 
76
+ def split_elems_for_src(param, src_rank, num_ranks) -> int:
77
+ rows = param.shape[0]
78
+ cols = int(param.numel() // rows)
79
+ base, rem = divmod(rows, num_ranks)
80
+ my_rows = base + (1 if src_rank < rem else 0)
81
+ return my_rows * cols
82
+
83
+
84
  @torch.no_grad()
85
+ def _alloc_gathered_grad(params, param_to_state, rank, compute_stream):
86
  """
87
+ Pre-allocate gathered_grad buffer on compute_stream
88
+ before launching all2all gather
89
+ """
90
+ with torch.cuda.stream(compute_stream):
91
+ for p in params:
92
+ state = param_to_state[id(p)]
93
+ if rank == state.worker_rank:
94
+ num_ranks = dist.get_world_size(group=state.process_group)
95
+ state.gathered_grad = torch.empty(p.grad.numel(),
96
+ dtype=COMM_DTYPE,
97
+ device="cuda")
98
+ else:
99
+ state.gathered_grad = None
100
+
101
+ alloc_event = torch.cuda.Event()
102
+ alloc_event.record(compute_stream)
103
+ return alloc_event
104
+
105
+
106
+ @torch.no_grad()
107
+ def _all2all_gather(params, param_to_state, rank, comm_stream, none_grad,
108
+ alloc_event):
109
+ """
110
+ All2all gathers shards so each owner rank reconstructs its full gradient
111
  """
112
  with torch.cuda.stream(comm_stream):
113
+ process_group = param_to_state[id(params[0])].process_group
114
+ num_ranks = dist.get_world_size(group=process_group)
115
 
116
+ # Construct sending buffers
117
+ per_dst = [[] for _ in range(num_ranks)]
118
+ send_counts = [0] * num_ranks
119
+
120
+ for p in params:
121
+ state = param_to_state[id(p)]
122
+ dst = state.worker_rank
123
+ assert dst < num_ranks
124
+ shard_elems = split_elems_for_src(p, rank, num_ranks)
125
+ g = p.grad
126
+ g = g.to_local().to(COMM_DTYPE).contiguous().view(-1)
127
+ assert g.numel() == shard_elems
128
+ per_dst[dst].append(g)
129
+ send_counts[dst] += shard_elems
130
+
131
+ assert all(
132
+ len(v) > 0
133
+ for v in per_dst), "all params should be sharded to all devices"
134
+
135
+ send_buf = torch.cat([torch.cat(v, dim=0) for v in per_dst], dim=0)
136
+ owned_params = [
137
+ p for p in params if param_to_state[id(p)].worker_rank == rank
138
+ ]
139
+
140
+ # Compute receive sizes and allocate receiving buffers
141
+ recv_counts = [0] * num_ranks
142
+
143
+ for src in range(num_ranks):
144
+ total = 0
145
+ for p in owned_params:
146
+ state = param_to_state[id(p)]
147
+ assert state.worker_rank == rank
148
+ total += split_elems_for_src(p, src, num_ranks)
149
+ recv_counts[src] = total
150
+
151
+ recv_total = sum(recv_counts)
152
+ recv_buf = torch.empty(recv_total, dtype=COMM_DTYPE, device="cuda")
153
+
154
+ #All2All
155
+ dist.all_to_all_single(
156
+ recv_buf,
157
+ send_buf,
158
+ output_split_sizes=recv_counts,
159
+ input_split_sizes=send_counts,
160
+ group=process_group,
161
  )
162
+
163
+ # Reconstructs gathered grad from the received buffer
164
+ #
165
+ # recv_buf (num ranks = 3)
166
+ #
167
+ # From rank 0 From rank 1 From rank 2
168
+ # | p1_0, p2_0, p3_0 | p1_1, p2_1, p3_1 | p1_2, p2_2, p3_2 |
169
+ #
170
+ # Outer loop:
171
+ # rank 0 -> rank 1 -> rank2
172
+ #
173
+ # Inner loop:
174
+ # p1_n -> p2_n -> p3_n
175
+
176
+ comm_stream.wait_event(alloc_event)
177
+
178
+ off = 0
179
+ write_offsets = {id(p): 0 for p in owned_params}
180
+ for src in range(num_ranks):
181
+ if recv_counts[src] == 0:
182
+ continue
183
+
184
+ block = recv_counts[src]
185
+ inner_off = 0
186
+ for p in owned_params:
187
+ state = param_to_state[id(p)]
188
+ assert state.worker_rank == rank
189
+ n = split_elems_for_src(p, src, num_ranks)
190
+ assert n > 0
191
+
192
+ sg = recv_buf.narrow(0, off + inner_off, n)
193
+ woff = write_offsets[id(p)]
194
+ dst = state.gathered_grad.narrow(0, woff, n)
195
+ dst.copy_(sg)
196
+
197
+ write_offsets[id(p)] += n
198
+ inner_off += n
199
+ off += block
200
+
201
+ for p in params:
202
+ state = param_to_state[id(p)]
203
+ if state.worker_rank == rank:
204
+ state.gathered_grad = state.gathered_grad.view_as(p)
205
+ state.gather_event = torch.cuda.Event()
206
+ state.gather_event.record(comm_stream)
207
+ else:
208
+ state.gathered_grad = None
209
+ state.gather_event = None
210
+ if none_grad:
211
+ p.grad = None
212
 
213
 
214
  @torch.no_grad()
 
222
  raise RuntimeError("Gather event must be set before compute.")
223
  compute_stream.wait_event(state.gather_event)
224
  u = _zeropower_via_newtonschulz5(state.gathered_grad, steps)
225
+ state.gathered_grad = None
226
  state.computed_u = u
227
+ state.compute_event = torch.cuda.Event()
228
+ state.compute_event.record()
229
+ else:
230
+ state.computed_u = None
231
+ state.compute_event = None
232
 
233
 
234
  @torch.no_grad()
235
+ def _alloc_scattered_u(params, param_to_state, rank, compute_stream):
236
  """
237
+ Pre-allocate scattered_u buffer on compute_stream
238
+ before launching all2all gather
239
  """
240
+ with torch.cuda.stream(compute_stream):
241
+ for p in params:
242
+ state = param_to_state[id(p)]
243
+ state.scattered_u = torch.empty_like(p.to_local(),
244
+ dtype=COMM_DTYPE)
245
+
246
+ alloc_event = torch.cuda.Event()
247
+ alloc_event.record(compute_stream)
248
+ return alloc_event
249
+
250
 
251
+ def _all2all_scatter(params, param_to_state, rank, comm_stream, alloc_event):
252
+ """
253
+ All2all scatters full gradients to all ranks
254
+ """
255
  with torch.cuda.stream(comm_stream):
256
+ process_group = param_to_state[id(params[0])].process_group
257
+ num_ranks = dist.get_world_size(group=process_group)
258
+ owned_params = [
259
+ p for p in params if param_to_state[id(p)].worker_rank == rank
260
+ ]
261
+
262
+ # Construct sending buffer
263
+ per_dst = [[] for _ in range(num_ranks)]
264
+ send_counts = [0] * num_ranks
265
+
266
+ if owned_params:
267
+ for p in owned_params:
268
+ state = param_to_state[id(p)]
269
+ if state.compute_event is None:
270
+ raise RuntimeError(
271
+ "Compute event must be set before scatter.")
272
+ comm_stream.wait_event(state.compute_event)
273
+ state.gathered_grad = None
274
 
275
+ assert state.computed_u is not None
276
+
277
+ u_full = state.computed_u.to(COMM_DTYPE).contiguous().view(-1)
278
+
279
+ offset = 0
280
+ for dst in range(num_ranks):
281
+ n = split_elems_for_src(p, dst, num_ranks)
282
+ assert n > 0
283
 
284
+ su = u_full.narrow(0, offset, n)
285
+ per_dst[dst].append(su)
286
+ send_counts[dst] += n
287
+ offset += n
288
+
289
+ assert offset == u_full.numel()
290
+
291
+ if any(len(v) > 0 for v in per_dst):
292
+ send_buf = torch.cat([torch.cat(v, dim=0) for v in per_dst], dim=0)
293
  else:
294
+ # all_to_all requires participation from all ranks
295
+ # Even non-owner ranks must join the collective call
296
+ send_buf = torch.empty(0, dtype=COMM_DTYPE, device="cuda")
297
 
298
+ # Compute receive sizes and allocate receiving buffers
299
+ recv_counts = [0] * num_ranks
300
+
301
+ for src in range(num_ranks):
302
+ total = 0
303
+ for p in params:
304
+ state = param_to_state[id(p)]
305
+ if state.worker_rank != src:
306
+ continue
307
+ total += split_elems_for_src(p, rank, num_ranks)
308
+ recv_counts[src] = total
309
+
310
+ recv_total = sum(recv_counts)
311
+ assert recv_total > 0
312
+ recv_buf = torch.empty(recv_total, dtype=COMM_DTYPE, device="cuda")
313
+
314
+ #All2All
315
+ dist.all_to_all_single(
316
+ recv_buf,
317
+ send_buf,
318
+ output_split_sizes=recv_counts,
319
+ input_split_sizes=send_counts,
320
+ group=process_group,
321
  )
322
+
323
+ # Copy to pre-allocated scattered_u buffer from the received buffer
324
+ #
325
+ # recv_buf (num ranks = 3, local_rank = 0)
326
+ #
327
+ # From rank 0 From rank 1 From rank 2
328
+ # | p1_0, p2_0, p3_0 | p4_0 | p5_0, p6_0 |
329
+ #
330
+ # Outer loop:
331
+ # rank 0 -> rank 1 -> rank2
332
+ #
333
+ # Inner loop:
334
+ # src(0) : p1_0 -> p2_0 -> p3_0
335
+ # src(1) : p4_0
336
+ # src(2) : p5_0 -> p6_0
337
+
338
+ comm_stream.wait_event(alloc_event)
339
+
340
+ off = 0
341
+ for src in range(num_ranks):
342
+ block = recv_counts[src]
343
+ if block == 0:
344
+ continue
345
+
346
+ inner_off = 0
347
+ for p in params:
348
+ state = param_to_state[id(p)]
349
+ if state.worker_rank != src:
350
+ continue
351
+ n = split_elems_for_src(p, rank, num_ranks)
352
+ assert n > 0
353
+
354
+ flat_local = recv_buf.narrow(0, off + inner_off,
355
+ n).view_as(p.to_local())
356
+ state.scattered_u.copy_(flat_local)
357
+
358
+ state.scatter_event = torch.cuda.Event()
359
+ state.scatter_event.record(comm_stream)
360
+ inner_off += n
361
+
362
+ assert inner_off == block
363
+ off += block
364
 
365
 
366
  def _update_param(p, state, lr, adjusted_lr, weight_decay, rank,
 
516
  "head_dim": 128,
517
  "threshold": 100
518
  }
519
+ overlap_step : How many all2all gather, compute operations are launched in advance
520
+ before the corresponding all2all scatter steps begin.
521
+ A higher overlap_step increases memory usage but can improve
522
+ performance by overlapping communication.
523
+ Parallel muon only.
524
  """
525
 
526
  def __init__(self,
 
539
  "k_indices": [],
540
  "head_dim": 128,
541
  "threshold": 100
542
+ },
543
+ overlap_step=5):
544
  defaults = dict(
545
  lr=lr,
546
  weight_decay=weight_decay,
 
564
 
565
  super().__init__(params, defaults)
566
 
567
+ self.rank = None
 
 
 
568
 
569
  self.comm_stream = torch.cuda.Stream()
570
  self.compute_stream = torch.cuda.Stream()
571
  self.debug = debug
572
  self.clip_config = clip_config
573
+ self.overlap_step = overlap_step
574
 
575
  def _calc_flops(self, G, steps):
576
  assert len(G.shape) == 2
 
643
  if mesh is None:
644
  mesh = p.device_mesh
645
  shard_mesh, process_group = self.get_shard_mesh(p, self.rank)
646
+ local_rank = dist.get_rank(group=process_group)
647
+ if self.rank is None:
648
+ self.rank = dist.get_rank(group=process_group)
649
+ else:
650
+ assert self.rank == local_rank
651
  elif mesh != p.device_mesh:
652
  raise ValueError("All parameters must be on the same mesh.")
653
 
654
+ num_ranks = dist.get_world_size(group=process_group)
655
  param_to_state[id(p)] = _muon_state()
656
+ param_to_state[id(
657
+ p)].worker_rank = shard_mesh[round_robin].item() % num_ranks
658
  param_to_state[id(p)].process_group = process_group
659
  qk_clip_state = self.get_qk_clip_info(n, qk_logits)
660
  param_to_state[id(p)].qk_clip_state = qk_clip_state
 
684
  else:
685
  g = buf
686
 
687
+ u = _zeropower_via_newtonschulz5(g.to(COMM_DTYPE),
688
  steps=group["ns_steps"])
689
 
690
  adjusted_lr = self.adjust_lr_for_muon(lr, p.shape)
 
699
  def _update_g(self, p, g, group, momentum):
700
  # calc update
701
  state = self.state[p]
702
+ buf = state.setdefault("momentum_buffer", torch.zeros_like(g))
703
+ torch.add(g, buf, alpha=momentum, out=buf)
 
 
704
  if group["nesterov"]:
705
+ g.add_(buf, alpha=momentum)
706
+ return g
707
+ return buf
 
708
 
709
  @staticmethod
710
  def _update_p(p, u, lr, adjusted_lr, weight_decay):
 
788
  param_to_state, ordered_params = self.init_state_and_assign_params(
789
  names, params, group, qk_logits)
790
 
791
+ assert self.rank is not None
792
+
793
+ def enqueue_all2all_gather(start_idx, chunk_size):
794
+ target_params = ordered_params[start_idx:start_idx + chunk_size]
795
+ if target_params:
796
+ alloc_event = _alloc_gathered_grad(target_params,
797
+ param_to_state, self.rank,
798
+ self.compute_stream)
799
+ _all2all_gather(target_params, param_to_state, self.rank,
800
+ self.comm_stream, group["none_grad"],
801
+ alloc_event)
802
 
803
  def enqueue_computes(start_idx, chunk_size):
804
  for p in ordered_params[start_idx:start_idx + chunk_size]:
 
806
  _compute_u(p, state, group["ns_steps"], self.rank,
807
  self.compute_stream)
808
 
809
+ def enqueue_all2all_scatter(start_idx, chunk_size):
810
+ target_params = ordered_params[start_idx:start_idx + chunk_size]
811
+ if target_params:
812
+ alloc_event = _alloc_scattered_u(target_params, param_to_state,
813
+ self.rank,
814
+ self.compute_stream)
815
+ _all2all_scatter(target_params, param_to_state, self.rank,
816
+ self.comm_stream, alloc_event)
817
 
818
  def enqueue_update_param(start_idx, chunk_size):
819
  for p in ordered_params[start_idx:start_idx + chunk_size]:
 
828
  # Wait grad update
829
  self.comm_stream.wait_stream(torch.cuda.current_stream())
830
 
831
+ overlap_step = self.overlap_step
832
+ for i in range(0, overlap_step):
833
+ enqueue_all2all_gather(i * chunk_size, chunk_size)
834
+ enqueue_computes(i * chunk_size, chunk_size)
835
+
836
  for i in range(0, len(params) + chunk_size - 1, chunk_size):
837
+ enqueue_all2all_scatter(i, chunk_size)
838
+ enqueue_all2all_gather(i + overlap_step * chunk_size, chunk_size)
839
+ enqueue_update_param(i, chunk_size)
840
+ enqueue_computes(i + overlap_step * chunk_size, chunk_size)
 
 
841
 
842
  # Wait the last update_param to finish
843
  torch.cuda.current_stream().wait_stream(self.compute_stream)
build/torch27-cxx11-cu126-x86_64-linux/optimizer/_ops.py CHANGED
@@ -1,9 +1,9 @@
1
  import torch
2
- from . import _optimizer_9c21645_dirty
3
- ops = torch.ops._optimizer_9c21645_dirty
4
 
5
  def add_op_namespace_prefix(op_name: str):
6
  """
7
  Prefix op by namespace.
8
  """
9
- return f"_optimizer_9c21645_dirty::{op_name}"
 
1
  import torch
2
+ from . import _optimizer_15336dc_dirty
3
+ ops = torch.ops._optimizer_15336dc_dirty
4
 
5
def add_op_namespace_prefix(op_name: str):
    """
    Qualify `op_name` with the compiled extension's torch ops namespace.
    """
    namespace = "_optimizer_15336dc_dirty"
    return f"{namespace}::{op_name}"
build/{torch28-cxx11-cu126-x86_64-linux/optimizer/_optimizer_9c21645_dirty.abi3.so → torch27-cxx11-cu126-x86_64-linux/optimizer/_optimizer_15336dc_dirty.abi3.so} RENAMED
@@ -1,3 +1,3 @@
1
  version https://git-lfs.github.com/spec/v1
2
- oid sha256:42ae6ac1cf967d7d23cac7930c8db635105f60631220a60b9cee060d082f40ae
3
  size 1824256
 
1
  version https://git-lfs.github.com/spec/v1
2
+ oid sha256:8ca6ca8225dc9b7888566f5c7fd824234a3b4ac76718a5d18e6c75ca7acd488d
3
  size 1824256
build/torch27-cxx11-cu126-x86_64-linux/optimizer/matmul_transpose_triton.py ADDED
@@ -0,0 +1,128 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # MIT License
2
+ #
3
+ # Copyright (c) 2025 Tianyang Lin
4
+ #
5
+ # Permission is hereby granted, free of charge, to any person obtaining a copy
6
+ # of this software and associated documentation files (the "Software"), to deal
7
+ # in the Software without restriction, including without limitation the rights
8
+ # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9
+ # copies of the Software, and to permit persons to whom the Software is
10
+ # furnished to do so, subject to the following conditions:
11
+ #
12
+ # The above copyright notice and this permission notice shall be included in all
13
+ # copies or substantial portions of the Software.
14
+ #
15
+ # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16
+ # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17
+ # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18
+ # AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19
+ # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20
+ # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21
+ # SOFTWARE.
22
+
23
+ import torch
24
+ import triton
25
+ import triton.language as tl
26
+
27
+
28
def get_autotune_config():
    """
    Enumerate the autotune search space for `mmt_kernel`.

    Cartesian product of block sizes, pipeline stages and warp counts,
    in the same nesting order as the original comprehension.
    """
    configs = []
    for block_m in (32, 64, 128):
        for block_k in (32, 64):
            for group_size in (8, ):
                for stages in (3, 4, 5):
                    for warps in (4, 8):
                        configs.append(
                            triton.Config(
                                {
                                    'BLOCK_SIZE_M': block_m,
                                    'BLOCK_SIZE_K': block_k,
                                    'GROUP_SIZE_M': group_size
                                },
                                num_stages=stages,
                                num_warps=warps))
    return configs
41
+
42
+
43
@triton.autotune(
    configs=get_autotune_config(),
    key=['M', 'K'],
)
@triton.jit
def mmt_kernel(x, y, M, K, stride_xm, stride_xk, stride_ym, stride_yn,
               BLOCK_SIZE_M: tl.constexpr, BLOCK_SIZE_K: tl.constexpr,
               GROUP_SIZE_M: tl.constexpr):
    """
    Core kernel jit function of matmul_transpose that computes y = x @ x.T
    The code is a simple adaptation from the triton `matmul` tutorial:
    https://triton-lang.org/main/getting-started/tutorials/03-matrix-multiplication.html

    Because y is symmetric, only output tiles with pid_m <= pid_n are
    computed; each off-diagonal tile is mirrored into the lower triangle
    before the kernel exits.
    """
    # Map the 1D program id to a 2D (pid_m, pid_n) tile using grouped
    # ordering for better L2 reuse (as in the triton matmul tutorial).
    pid = tl.program_id(axis=0)
    num_pid_m = tl.cdiv(M, BLOCK_SIZE_M)
    num_pid_n = tl.cdiv(M, BLOCK_SIZE_M)
    num_pid_in_group = GROUP_SIZE_M * num_pid_n
    group_id = pid // num_pid_in_group
    first_pid_m = group_id * GROUP_SIZE_M
    group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M)
    pid_m = first_pid_m + ((pid % num_pid_in_group) % group_size_m)
    pid_n = (pid % num_pid_in_group) // group_size_m
    # Strictly-lower-triangle tiles are produced by the mirrored store of
    # the matching upper tile, so those programs exit immediately.
    if pid_m > pid_n:
        return

    offs_xm = (pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)) % M
    offs_xn = (pid_n * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)) % M
    offs_k = tl.arange(0, BLOCK_SIZE_K)
    # we use a & b ptrs to denote different rows of x.
    a_ptrs = x + (offs_xm[:, None] * stride_xm + offs_k[None, :] * stride_xk)
    b_ptrs = x + (offs_xn[:, None] * stride_xm + offs_k[None, :] * stride_xk)

    accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_M), dtype=tl.float32)

    # March along K in BLOCK_SIZE_K steps, accumulating a @ b.T in fp32.
    for k in range(0, tl.cdiv(K, BLOCK_SIZE_K)):
        a = tl.load(a_ptrs,
                    mask=offs_k[None, :] < K - k * BLOCK_SIZE_K,
                    other=0.0)
        b = tl.load(b_ptrs,
                    mask=offs_k[None, :] < K - k * BLOCK_SIZE_K,
                    other=0.0)
        accumulator = tl.dot(a, tl.permute(b, (1, 0)), accumulator)
        a_ptrs += BLOCK_SIZE_K * stride_xk
        b_ptrs += BLOCK_SIZE_K * stride_xk
    # use dtype.element_ty to accommodate different input datatypes as in cpp templates
    # https://github.com/triton-lang/triton/issues/2252
    c = accumulator.to(x.dtype.element_ty)

    offs_cm = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
    offs_cn = pid_n * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
    c_ptrs = y + stride_ym * offs_cm[:, None] + stride_yn * offs_cn[None, :]
    c_mask = (offs_cm[:, None] < M) & (offs_cn[None, :] < M)
    tl.store(c_ptrs, c, mask=c_mask)

    # transpose and copy: mirror this tile into the lower triangle. Skipped
    # for diagonal tiles (pid_m == pid_n), which would overwrite themselves.
    if pid_m < pid_n:
        ct_ptrs = y + stride_ym * offs_cn[:, None] + stride_yn * offs_cm[None, :]
        ct_mask = (offs_cn[:, None] < M) & (offs_cm[None, :] < M)
        tl.store(ct_ptrs, tl.permute(c, (1, 0)), mask=ct_mask)
103
+
104
+
105
def matmul_transpose_assign(d_in, d_out):
    """
    Launch the triton kernel computing `d_out = d_in @ d_in.T` in place.

    Args:
        d_in: 2D CUDA tensor of shape (M, K).
        d_out: 2D CUDA tensor of shape (M, M), same dtype and device as
            `d_in`; overwritten with the symmetric product.
    """
    assert d_in.is_cuda, "Input `d_in` must be a CUDA tensor"
    assert d_out.is_cuda, "Input `d_out` must be a CUDA tensor"
    assert d_in.device == d_out.device, "Inputs `d_in` and `d_out` must be on the same CUDA device"
    assert d_in.dtype == d_out.dtype, "Inputs must have the same data type"
    assert d_in.ndim == 2, "Input `d_in` must be a 2D tensor"
    assert d_out.ndim == 2, "Input `d_out` must be a 2D tensor"
    # Bug fix: previously compared d_out.size(0) against itself, so a
    # non-square d_out slipped through. x @ x.T is (M, M) by construction.
    assert d_in.size(0) == d_out.size(0) == d_out.size(1), \
        "First dimension of `d_in` must match first and second dimension of `d_out`"

    d_in = d_in.contiguous()
    M, K = d_in.shape
    # One program per pair of M-tiles; the kernel itself skips the lower
    # triangle, so roughly half of these programs exit early.
    grid = lambda META: (triton.cdiv(M, META['BLOCK_SIZE_M']) * triton.cdiv(
        M, META['BLOCK_SIZE_M']), )
    # Launch on d_in's device explicitly (autotuning / current device may differ).
    with torch.cuda.device(d_in.device.index):
        mmt_kernel[grid](d_in, d_out, M, K, d_in.stride(0), d_in.stride(1),
                         d_out.stride(0), d_out.stride(1))
122
+
123
+
124
def matmul_transpose(d_in):
    """
    Return `d_in @ d_in.T` as a freshly allocated (M, M) tensor.
    """
    num_rows, _ = d_in.shape
    result = torch.empty((num_rows, num_rows),
                         device=d_in.device,
                         dtype=d_in.dtype)
    matmul_transpose_assign(d_in, result)
    return result
build/torch27-cxx11-cu126-x86_64-linux/optimizer/muon.py CHANGED
@@ -8,14 +8,19 @@ import torch
8
  import torch.distributed as dist
9
  from torch.distributed._tensor import DTensor, Replicate, Shard
10
 
 
 
11
  logger = logging.getLogger(__name__)
12
 
 
 
13
 
14
  # This code snippet is a modified version adapted from the following GitHub repositories:
15
  # https://github.com/KellerJordan/Muon/blob/master/muon.py
16
  # Muon's Newton–Schulz iteration causes high variance in singular values
17
  # Idea: give each iteration its own 3 coefficients and optimize them via gradient descent.
18
  @torch.no_grad()
 
19
  def _zeropower_via_newtonschulz5(G, steps):
20
  """
21
  Newton-Schulz iteration to compute the zeroth power / orthogonalization of G. We opt to use a
@@ -27,13 +32,15 @@ def _zeropower_via_newtonschulz5(G, steps):
27
  performance at all relative to UV^T, where USV^T = G is the SVD.
28
  """
29
  assert len(G.shape) == 2
30
- assert G.dtype == torch.bfloat16
31
  X = G # no manual typecast
32
 
33
  if G.size(0) > G.size(1):
34
  X = X.T
35
  # Ensure spectral norm is at most 1
36
  X = X / (X.norm() + 1e-7)
 
 
37
  # Perform the NS iterations
38
  for a, b, c in [
39
  (4.0848, -6.8946, 2.9270),
@@ -42,13 +49,10 @@ def _zeropower_via_newtonschulz5(G, steps):
42
  (2.8769, -3.1427, 1.2046),
43
  (2.8366, -3.0525, 1.2012),
44
  ]:
45
- A = X @ X.T
46
- # B = (
47
- # b * A + c * A @ A
48
- # )
49
- B = torch.addmm(A, A, A, alpha=c, beta=b)
50
- # X = a * X + B @ X
51
- X = torch.addmm(X, B, X, alpha=1.0, beta=a)
52
 
53
  if G.size(0) > G.size(1):
54
  X = X.T
@@ -69,51 +73,142 @@ class _muon_state:
69
  qk_clip_state = None
70
 
71
 
 
 
 
 
 
 
 
 
72
  @torch.no_grad()
73
- def _gather(p, state, rank, comm_stream, none_grad):
74
  """
75
- Gather the gradients to worker_rank.
76
- If none_grad is True, free p.grad after the gather.
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
77
  """
78
  with torch.cuda.stream(comm_stream):
79
- g = p.grad
 
80
 
81
- if rank == state.worker_rank:
82
- num_ranks = dist.get_world_size(group=state.process_group)
83
- gather_list = [
84
- torch.empty_like(g.to_local(), dtype=torch.bfloat16)
85
- for _ in range(num_ranks)
86
- ]
87
- else:
88
- gather_list = None
89
-
90
- g = g.to(torch.bfloat16)
91
- torch.distributed.gather(
92
- g.to_local(),
93
- dst=state.worker_rank,
94
- gather_list=gather_list,
95
- group=state.process_group,
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
96
  )
97
- if rank == state.worker_rank:
98
- if state.gathered_grad is not None:
99
- raise RuntimeError(
100
- "Gather event already exists, which should not happen.")
101
- state.gathered_grad = torch.cat(gather_list, dim=0)
102
- state.gather_event = torch.cuda.Event()
103
- state.gather_event.record()
104
- else:
105
- state.gathered_grad = None
106
- state.gather_event = None
107
- gather_list = None
108
- if none_grad:
109
- # We can safely free p.grad without calling record_stream:
110
- # p.grad.to_local().record_stream(comm_stream)
111
- # Explanation:
112
- # 1. p.grad is created on the default stream, but the default stream
113
- # is synchronized with the comm stream later.
114
- # 2. There is no further activity on the default stream before the optimizer finishes.
115
- # Therefore, it is safe to free p.grad directly on the comm stream.
116
- p.grad = None
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
117
 
118
 
119
  @torch.no_grad()
@@ -127,45 +222,145 @@ def _compute_u(p, state, steps, rank, compute_stream):
127
  raise RuntimeError("Gather event must be set before compute.")
128
  compute_stream.wait_event(state.gather_event)
129
  u = _zeropower_via_newtonschulz5(state.gathered_grad, steps)
 
130
  state.computed_u = u
131
- state.scattered_u = torch.empty_like(p.to_local(),
132
- dtype=torch.bfloat16)
133
- state.compute_event = torch.cuda.Event()
134
- state.compute_event.record()
135
- u = None
136
 
137
 
138
  @torch.no_grad()
139
- def _scatter(p, state, rank, comm_stream):
140
  """
141
- Scatter the computed_u from worker_rank to all ranks.
 
142
  """
 
 
 
 
 
 
 
 
 
 
143
 
 
 
 
 
144
  with torch.cuda.stream(comm_stream):
145
- if state.compute_event is None:
146
- raise RuntimeError("Compute event must be set before scatter.")
147
- comm_stream.wait_event(state.compute_event)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
148
 
149
- if rank == state.worker_rank:
150
- num_ranks = dist.get_world_size(group=state.process_group)
151
- # Clear the gathered gradient to free memory
152
- state.gathered_grad = None
 
 
 
 
153
 
154
- u = state.computed_u
155
- scatter_list = list(torch.split(u, p.size(0) // num_ranks, dim=0))
156
- scatter_list = [s.contiguous() for s in scatter_list]
 
 
 
 
 
 
157
  else:
158
- scatter_list = None
 
 
159
 
160
- torch.distributed.scatter(
161
- state.scattered_u,
162
- scatter_list=scatter_list,
163
- src=state.worker_rank,
164
- group=state.process_group,
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
165
  )
166
- state.scatter_event = torch.cuda.Event()
167
- state.scatter_event.record()
168
- scatter_list = None
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
169
 
170
 
171
  def _update_param(p, state, lr, adjusted_lr, weight_decay, rank,
@@ -321,6 +516,11 @@ class Muon(torch.optim.Optimizer):
321
  "head_dim": 128,
322
  "threshold": 100
323
  }
 
 
 
 
 
324
  """
325
 
326
  def __init__(self,
@@ -339,7 +539,8 @@ class Muon(torch.optim.Optimizer):
339
  "k_indices": [],
340
  "head_dim": 128,
341
  "threshold": 100
342
- }):
 
343
  defaults = dict(
344
  lr=lr,
345
  weight_decay=weight_decay,
@@ -363,15 +564,13 @@ class Muon(torch.optim.Optimizer):
363
 
364
  super().__init__(params, defaults)
365
 
366
- if dist.is_initialized():
367
- self.rank = dist.get_rank()
368
- else:
369
- self.rank = None
370
 
371
  self.comm_stream = torch.cuda.Stream()
372
  self.compute_stream = torch.cuda.Stream()
373
  self.debug = debug
374
  self.clip_config = clip_config
 
375
 
376
  def _calc_flops(self, G, steps):
377
  assert len(G.shape) == 2
@@ -444,11 +643,18 @@ class Muon(torch.optim.Optimizer):
444
  if mesh is None:
445
  mesh = p.device_mesh
446
  shard_mesh, process_group = self.get_shard_mesh(p, self.rank)
 
 
 
 
 
447
  elif mesh != p.device_mesh:
448
  raise ValueError("All parameters must be on the same mesh.")
449
 
 
450
  param_to_state[id(p)] = _muon_state()
451
- param_to_state[id(p)].worker_rank = shard_mesh[round_robin].item()
 
452
  param_to_state[id(p)].process_group = process_group
453
  qk_clip_state = self.get_qk_clip_info(n, qk_logits)
454
  param_to_state[id(p)].qk_clip_state = qk_clip_state
@@ -478,7 +684,7 @@ class Muon(torch.optim.Optimizer):
478
  else:
479
  g = buf
480
 
481
- u = _zeropower_via_newtonschulz5(g.bfloat16(),
482
  steps=group["ns_steps"])
483
 
484
  adjusted_lr = self.adjust_lr_for_muon(lr, p.shape)
@@ -493,15 +699,12 @@ class Muon(torch.optim.Optimizer):
493
  def _update_g(self, p, g, group, momentum):
494
  # calc update
495
  state = self.state[p]
496
- if "momentum_buffer" not in state:
497
- state["momentum_buffer"] = torch.zeros_like(g)
498
- buf = state["momentum_buffer"]
499
- buf.mul_(momentum).add_(g)
500
  if group["nesterov"]:
501
- g = g.add(buf, alpha=momentum)
502
- else:
503
- g = buf
504
- return g
505
 
506
  @staticmethod
507
  def _update_p(p, u, lr, adjusted_lr, weight_decay):
@@ -585,11 +788,17 @@ class Muon(torch.optim.Optimizer):
585
  param_to_state, ordered_params = self.init_state_and_assign_params(
586
  names, params, group, qk_logits)
587
 
588
- def enqueue_gathers(start_idx, chunk_size):
589
- for p in ordered_params[start_idx:start_idx + chunk_size]:
590
- state = param_to_state[id(p)]
591
- _gather(p, state, self.rank, self.comm_stream,
592
- group["none_grad"])
 
 
 
 
 
 
593
 
594
  def enqueue_computes(start_idx, chunk_size):
595
  for p in ordered_params[start_idx:start_idx + chunk_size]:
@@ -597,10 +806,14 @@ class Muon(torch.optim.Optimizer):
597
  _compute_u(p, state, group["ns_steps"], self.rank,
598
  self.compute_stream)
599
 
600
- def enqueue_scatters(start_idx, chunk_size):
601
- for p in ordered_params[start_idx:start_idx + chunk_size]:
602
- state = param_to_state[id(p)]
603
- _scatter(p, state, self.rank, self.comm_stream)
 
 
 
 
604
 
605
  def enqueue_update_param(start_idx, chunk_size):
606
  for p in ordered_params[start_idx:start_idx + chunk_size]:
@@ -615,14 +828,16 @@ class Muon(torch.optim.Optimizer):
615
  # Wait grad update
616
  self.comm_stream.wait_stream(torch.cuda.current_stream())
617
 
618
- enqueue_gathers(0, chunk_size)
 
 
 
 
619
  for i in range(0, len(params) + chunk_size - 1, chunk_size):
620
- enqueue_computes(i, chunk_size)
621
- if i > 0:
622
- enqueue_update_param(i - chunk_size, chunk_size)
623
- enqueue_gathers(i + chunk_size, chunk_size)
624
- enqueue_scatters(i, chunk_size)
625
- enqueue_update_param(i, chunk_size)
626
 
627
  # Wait the last update_param to finish
628
  torch.cuda.current_stream().wait_stream(self.compute_stream)
 
8
  import torch.distributed as dist
9
  from torch.distributed._tensor import DTensor, Replicate, Shard
10
 
11
+ from .matmul_transpose_triton import matmul_transpose_assign
12
+
13
  logger = logging.getLogger(__name__)
14
 
15
+ COMM_DTYPE = torch.bfloat16
16
+
17
 
18
  # This code snippet is a modified version adapted from the following GitHub repositories:
19
  # https://github.com/KellerJordan/Muon/blob/master/muon.py
20
  # Muon's Newton–Schulz iteration causes high variance in singular values
21
  # Idea: give each iteration its own 3 coefficients and optimize them via gradient descent.
22
  @torch.no_grad()
23
+ # matmul_transpose_assign from : https://github.com/nil0x9/flash-muon
24
  def _zeropower_via_newtonschulz5(G, steps):
25
  """
26
  Newton-Schulz iteration to compute the zeroth power / orthogonalization of G. We opt to use a
 
32
  performance at all relative to UV^T, where USV^T = G is the SVD.
33
  """
34
  assert len(G.shape) == 2
35
+ assert G.dtype == COMM_DTYPE
36
  X = G # no manual typecast
37
 
38
  if G.size(0) > G.size(1):
39
  X = X.T
40
  # Ensure spectral norm is at most 1
41
  X = X / (X.norm() + 1e-7)
42
+ buf1 = torch.empty(X.size(0), X.size(0), dtype=X.dtype, device=X.device)
43
+ buf2 = torch.empty(X.size(0), X.size(0), dtype=X.dtype, device=X.device)
44
  # Perform the NS iterations
45
  for a, b, c in [
46
  (4.0848, -6.8946, 2.9270),
 
49
  (2.8769, -3.1427, 1.2046),
50
  (2.8366, -3.0525, 1.2012),
51
  ]:
52
+ matmul_transpose_assign(X, buf1)
53
+ matmul_transpose_assign(buf1, buf2)
54
+ buf1.mul_(b).add_(buf2, alpha=c)
55
+ X = torch.addmm(X, buf1, X, alpha=1.0, beta=a)
 
 
 
56
 
57
  if G.size(0) > G.size(1):
58
  X = X.T
 
73
  qk_clip_state = None
74
 
75
 
76
def split_elems_for_src(param, src_rank, num_ranks) -> int:
    """
    Number of elements of `param` owned by shard `src_rank` when the first
    dimension is split as evenly as possible across `num_ranks` ranks.

    The first `rows % num_ranks` ranks each receive one extra row.
    """
    total_rows = param.shape[0]
    row_elems = int(param.numel() // total_rows)
    even_rows, extra = divmod(total_rows, num_ranks)
    shard_rows = even_rows + 1 if src_rank < extra else even_rows
    return shard_rows * row_elems
82
+
83
+
84
@torch.no_grad()
def _alloc_gathered_grad(params, param_to_state, rank, compute_stream):
    """
    Pre-allocate the full-size `gathered_grad` buffers on `compute_stream`
    before the all2all gather is launched on the communication stream.

    Only the owner rank (`worker_rank`) of a parameter allocates a buffer;
    all other ranks clear the field. Returns a CUDA event recorded on
    `compute_stream` that the comm stream must wait on before writing
    into these buffers.
    """
    with torch.cuda.stream(compute_stream):
        for p in params:
            state = param_to_state[id(p)]
            if rank == state.worker_rank:
                # Removed a dead `num_ranks = dist.get_world_size(...)`
                # local that was never used.
                # Flat buffer; reshaped to the parameter's shape once the
                # gather has filled it (see _all2all_gather).
                state.gathered_grad = torch.empty(p.grad.numel(),
                                                  dtype=COMM_DTYPE,
                                                  device="cuda")
            else:
                state.gathered_grad = None

    alloc_event = torch.cuda.Event()
    alloc_event.record(compute_stream)
    return alloc_event
104
+
105
+
106
+ @torch.no_grad()
107
+ def _all2all_gather(params, param_to_state, rank, comm_stream, none_grad,
108
+ alloc_event):
109
+ """
110
+ All2all gathers shards so each owner rank reconstructs its full gradient
111
  """
112
  with torch.cuda.stream(comm_stream):
113
+ process_group = param_to_state[id(params[0])].process_group
114
+ num_ranks = dist.get_world_size(group=process_group)
115
 
116
+ # Construct sending buffers
117
+ per_dst = [[] for _ in range(num_ranks)]
118
+ send_counts = [0] * num_ranks
119
+
120
+ for p in params:
121
+ state = param_to_state[id(p)]
122
+ dst = state.worker_rank
123
+ assert dst < num_ranks
124
+ shard_elems = split_elems_for_src(p, rank, num_ranks)
125
+ g = p.grad
126
+ g = g.to_local().to(COMM_DTYPE).contiguous().view(-1)
127
+ assert g.numel() == shard_elems
128
+ per_dst[dst].append(g)
129
+ send_counts[dst] += shard_elems
130
+
131
+ assert all(
132
+ len(v) > 0
133
+ for v in per_dst), "all params should be sharded to all devices"
134
+
135
+ send_buf = torch.cat([torch.cat(v, dim=0) for v in per_dst], dim=0)
136
+ owned_params = [
137
+ p for p in params if param_to_state[id(p)].worker_rank == rank
138
+ ]
139
+
140
+ # Compute receive sizes and allocate receiving buffers
141
+ recv_counts = [0] * num_ranks
142
+
143
+ for src in range(num_ranks):
144
+ total = 0
145
+ for p in owned_params:
146
+ state = param_to_state[id(p)]
147
+ assert state.worker_rank == rank
148
+ total += split_elems_for_src(p, src, num_ranks)
149
+ recv_counts[src] = total
150
+
151
+ recv_total = sum(recv_counts)
152
+ recv_buf = torch.empty(recv_total, dtype=COMM_DTYPE, device="cuda")
153
+
154
+ #All2All
155
+ dist.all_to_all_single(
156
+ recv_buf,
157
+ send_buf,
158
+ output_split_sizes=recv_counts,
159
+ input_split_sizes=send_counts,
160
+ group=process_group,
161
  )
162
+
163
+ # Reconstructs gathered grad from the received buffer
164
+ #
165
+ # recv_buf (num ranks = 3)
166
+ #
167
+ # From rank 0 From rank 1 From rank 2
168
+ # | p1_0, p2_0, p3_0 | p1_1, p2_1, p3_1 | p1_2, p2_2, p3_2 |
169
+ #
170
+ # Outer loop:
171
+ # rank 0 -> rank 1 -> rank2
172
+ #
173
+ # Inner loop:
174
+ # p1_n -> p2_n -> p3_n
175
+
176
+ comm_stream.wait_event(alloc_event)
177
+
178
+ off = 0
179
+ write_offsets = {id(p): 0 for p in owned_params}
180
+ for src in range(num_ranks):
181
+ if recv_counts[src] == 0:
182
+ continue
183
+
184
+ block = recv_counts[src]
185
+ inner_off = 0
186
+ for p in owned_params:
187
+ state = param_to_state[id(p)]
188
+ assert state.worker_rank == rank
189
+ n = split_elems_for_src(p, src, num_ranks)
190
+ assert n > 0
191
+
192
+ sg = recv_buf.narrow(0, off + inner_off, n)
193
+ woff = write_offsets[id(p)]
194
+ dst = state.gathered_grad.narrow(0, woff, n)
195
+ dst.copy_(sg)
196
+
197
+ write_offsets[id(p)] += n
198
+ inner_off += n
199
+ off += block
200
+
201
+ for p in params:
202
+ state = param_to_state[id(p)]
203
+ if state.worker_rank == rank:
204
+ state.gathered_grad = state.gathered_grad.view_as(p)
205
+ state.gather_event = torch.cuda.Event()
206
+ state.gather_event.record(comm_stream)
207
+ else:
208
+ state.gathered_grad = None
209
+ state.gather_event = None
210
+ if none_grad:
211
+ p.grad = None
212
 
213
 
214
  @torch.no_grad()
 
222
  raise RuntimeError("Gather event must be set before compute.")
223
  compute_stream.wait_event(state.gather_event)
224
  u = _zeropower_via_newtonschulz5(state.gathered_grad, steps)
225
+ state.gathered_grad = None
226
  state.computed_u = u
227
+ state.compute_event = torch.cuda.Event()
228
+ state.compute_event.record()
229
+ else:
230
+ state.computed_u = None
231
+ state.compute_event = None
232
 
233
 
234
@torch.no_grad()
def _alloc_scattered_u(params, param_to_state, rank, compute_stream):
    """
    Pre-allocate each parameter's `scattered_u` shard buffer on
    `compute_stream` before launching the all2all scatter.

    Every rank allocates a buffer sized like its local shard of the
    parameter, regardless of ownership. Returns a CUDA event recorded on
    `compute_stream`; the comm stream waits on it before writing into
    these buffers.
    """
    # NOTE(review): `rank` is accepted for signature symmetry with
    # _alloc_gathered_grad but is not used here.
    with torch.cuda.stream(compute_stream):
        for p in params:
            state = param_to_state[id(p)]
            state.scattered_u = torch.empty_like(p.to_local(),
                                                 dtype=COMM_DTYPE)

    alloc_event = torch.cuda.Event()
    alloc_event.record(compute_stream)
    return alloc_event
+ return alloc_event
249
+
250
 
251
def _all2all_scatter(params, param_to_state, rank, comm_stream, alloc_event):
    """
    All2all scatters the orthogonalized updates back to all ranks.

    Each owner rank slices its full `computed_u` into per-destination row
    shards and sends them; every rank receives its own shard of every
    parameter and copies it into the pre-allocated `scattered_u` buffer.

    `alloc_event` was recorded by `_alloc_scattered_u` on the compute
    stream; the comm stream waits on it before writing into the buffers.
    """
    # NOTE(review): unlike the sibling helpers this function carries no
    # @torch.no_grad() decorator — confirm whether that is intentional.
    with torch.cuda.stream(comm_stream):
        # All params in one chunk are assumed to share a single process
        # group (taken from the first param).
        process_group = param_to_state[id(params[0])].process_group
        num_ranks = dist.get_world_size(group=process_group)
        owned_params = [
            p for p in params if param_to_state[id(p)].worker_rank == rank
        ]

        # Construct sending buffer: owners split their full update into
        # one row-shard per destination rank.
        per_dst = [[] for _ in range(num_ranks)]
        send_counts = [0] * num_ranks

        if owned_params:
            for p in owned_params:
                state = param_to_state[id(p)]
                if state.compute_event is None:
                    raise RuntimeError(
                        "Compute event must be set before scatter.")
                # Do not read computed_u before _compute_u finished it.
                comm_stream.wait_event(state.compute_event)
                state.gathered_grad = None

                assert state.computed_u is not None

                u_full = state.computed_u.to(COMM_DTYPE).contiguous().view(-1)

                offset = 0
                for dst in range(num_ranks):
                    n = split_elems_for_src(p, dst, num_ranks)
                    assert n > 0

                    su = u_full.narrow(0, offset, n)
                    per_dst[dst].append(su)
                    send_counts[dst] += n
                    offset += n

                assert offset == u_full.numel()

        if any(len(v) > 0 for v in per_dst):
            send_buf = torch.cat([torch.cat(v, dim=0) for v in per_dst], dim=0)
        else:
            # all_to_all requires participation from all ranks.
            # Even non-owner ranks must join the collective call.
            send_buf = torch.empty(0, dtype=COMM_DTYPE, device="cuda")

        # Compute receive sizes and allocate receiving buffers: from each
        # source rank we receive our shard of every param that rank owns.
        recv_counts = [0] * num_ranks

        for src in range(num_ranks):
            total = 0
            for p in params:
                state = param_to_state[id(p)]
                if state.worker_rank != src:
                    continue
                total += split_elems_for_src(p, rank, num_ranks)
            recv_counts[src] = total

        recv_total = sum(recv_counts)
        assert recv_total > 0
        recv_buf = torch.empty(recv_total, dtype=COMM_DTYPE, device="cuda")

        # All2All
        dist.all_to_all_single(
            recv_buf,
            send_buf,
            output_split_sizes=recv_counts,
            input_split_sizes=send_counts,
            group=process_group,
        )

        # Copy to pre-allocated scattered_u buffer from the received buffer.
        #
        # recv_buf (num ranks = 3, local_rank = 0)
        #
        #   From rank 0        From rank 1   From rank 2
        # | p1_0, p2_0, p3_0 | p4_0        | p5_0, p6_0 |
        #
        # Outer loop:
        #   rank 0 -> rank 1 -> rank2
        #
        # Inner loop:
        #   src(0) : p1_0 -> p2_0 -> p3_0
        #   src(1) : p4_0
        #   src(2) : p5_0 -> p6_0

        # scattered_u buffers were allocated on the compute stream; do not
        # write into them before that allocation has completed.
        comm_stream.wait_event(alloc_event)

        off = 0
        for src in range(num_ranks):
            block = recv_counts[src]
            if block == 0:
                continue

            inner_off = 0
            for p in params:
                state = param_to_state[id(p)]
                if state.worker_rank != src:
                    continue
                n = split_elems_for_src(p, rank, num_ranks)
                assert n > 0

                flat_local = recv_buf.narrow(0, off + inner_off,
                                             n).view_as(p.to_local())
                state.scattered_u.copy_(flat_local)

                # Per-param event: _update_param waits on this before
                # consuming scattered_u.
                state.scatter_event = torch.cuda.Event()
                state.scatter_event.record(comm_stream)
                inner_off += n

            assert inner_off == block
            off += block
 
365
 
366
  def _update_param(p, state, lr, adjusted_lr, weight_decay, rank,
 
516
  "head_dim": 128,
517
  "threshold": 100
518
  }
519
+ overlap_step : How many all2all gather, compute operations are launched in advance
520
+ before the corresponding all2all scatter steps begin.
521
+ A higher overlap_step increases memory usage but can improve
522
+ performance by overlapping communication.
523
+ Parallel muon only.
524
  """
525
 
526
  def __init__(self,
 
539
  "k_indices": [],
540
  "head_dim": 128,
541
  "threshold": 100
542
+ },
543
+ overlap_step=5):
544
  defaults = dict(
545
  lr=lr,
546
  weight_decay=weight_decay,
 
564
 
565
  super().__init__(params, defaults)
566
 
567
+ self.rank = None
 
 
 
568
 
569
  self.comm_stream = torch.cuda.Stream()
570
  self.compute_stream = torch.cuda.Stream()
571
  self.debug = debug
572
  self.clip_config = clip_config
573
+ self.overlap_step = overlap_step
574
 
575
  def _calc_flops(self, G, steps):
576
  assert len(G.shape) == 2
 
643
  if mesh is None:
644
  mesh = p.device_mesh
645
  shard_mesh, process_group = self.get_shard_mesh(p, self.rank)
646
+ local_rank = dist.get_rank(group=process_group)
647
+ if self.rank is None:
648
+ self.rank = dist.get_rank(group=process_group)
649
+ else:
650
+ assert self.rank == local_rank
651
  elif mesh != p.device_mesh:
652
  raise ValueError("All parameters must be on the same mesh.")
653
 
654
+ num_ranks = dist.get_world_size(group=process_group)
655
  param_to_state[id(p)] = _muon_state()
656
+ param_to_state[id(
657
+ p)].worker_rank = shard_mesh[round_robin].item() % num_ranks
658
  param_to_state[id(p)].process_group = process_group
659
  qk_clip_state = self.get_qk_clip_info(n, qk_logits)
660
  param_to_state[id(p)].qk_clip_state = qk_clip_state
 
684
  else:
685
  g = buf
686
 
687
+ u = _zeropower_via_newtonschulz5(g.to(COMM_DTYPE),
688
  steps=group["ns_steps"])
689
 
690
  adjusted_lr = self.adjust_lr_for_muon(lr, p.shape)
 
699
    def _update_g(self, p, g, group, momentum):
        """
        Fold gradient `g` into the momentum buffer and return the update
        direction.

        The buffer update is buf <- g + momentum * buf (equivalent to
        buf.mul_(momentum).add_(g)). With nesterov, `g` is additionally
        updated in place to g + momentum * buf and returned; otherwise
        the momentum buffer itself is returned.
        """
        state = self.state[p]
        # Lazily create the per-parameter momentum buffer on first use.
        buf = state.setdefault("momentum_buffer", torch.zeros_like(g))
        torch.add(g, buf, alpha=momentum, out=buf)
        if group["nesterov"]:
            # NOTE(review): in-place update of `g` (presumably p's
            # gradient) — callers are expected to consume it immediately
            # within this optimizer step; confirm no later reader of p.grad.
            g.add_(buf, alpha=momentum)
            return g
        return buf
 
708
 
709
  @staticmethod
710
  def _update_p(p, u, lr, adjusted_lr, weight_decay):
 
788
  param_to_state, ordered_params = self.init_state_and_assign_params(
789
  names, params, group, qk_logits)
790
 
791
+ assert self.rank is not None
792
+
793
+ def enqueue_all2all_gather(start_idx, chunk_size):
794
+ target_params = ordered_params[start_idx:start_idx + chunk_size]
795
+ if target_params:
796
+ alloc_event = _alloc_gathered_grad(target_params,
797
+ param_to_state, self.rank,
798
+ self.compute_stream)
799
+ _all2all_gather(target_params, param_to_state, self.rank,
800
+ self.comm_stream, group["none_grad"],
801
+ alloc_event)
802
 
803
  def enqueue_computes(start_idx, chunk_size):
804
  for p in ordered_params[start_idx:start_idx + chunk_size]:
 
806
  _compute_u(p, state, group["ns_steps"], self.rank,
807
  self.compute_stream)
808
 
809
+ def enqueue_all2all_scatter(start_idx, chunk_size):
810
+ target_params = ordered_params[start_idx:start_idx + chunk_size]
811
+ if target_params:
812
+ alloc_event = _alloc_scattered_u(target_params, param_to_state,
813
+ self.rank,
814
+ self.compute_stream)
815
+ _all2all_scatter(target_params, param_to_state, self.rank,
816
+ self.comm_stream, alloc_event)
817
 
818
  def enqueue_update_param(start_idx, chunk_size):
819
  for p in ordered_params[start_idx:start_idx + chunk_size]:
 
828
  # Wait grad update
829
  self.comm_stream.wait_stream(torch.cuda.current_stream())
830
 
831
+ overlap_step = self.overlap_step
832
+ for i in range(0, overlap_step):
833
+ enqueue_all2all_gather(i * chunk_size, chunk_size)
834
+ enqueue_computes(i * chunk_size, chunk_size)
835
+
836
  for i in range(0, len(params) + chunk_size - 1, chunk_size):
837
+ enqueue_all2all_scatter(i, chunk_size)
838
+ enqueue_all2all_gather(i + overlap_step * chunk_size, chunk_size)
839
+ enqueue_update_param(i, chunk_size)
840
+ enqueue_computes(i + overlap_step * chunk_size, chunk_size)
 
 
841
 
842
  # Wait the last update_param to finish
843
  torch.cuda.current_stream().wait_stream(self.compute_stream)
build/torch27-cxx11-cu128-x86_64-linux/optimizer/_ops.py CHANGED
@@ -1,9 +1,9 @@
1
  import torch
2
- from . import _optimizer_9c21645_dirty
3
- ops = torch.ops._optimizer_9c21645_dirty
4
 
5
  def add_op_namespace_prefix(op_name: str):
6
  """
7
  Prefix op by namespace.
8
  """
9
- return f"_optimizer_9c21645_dirty::{op_name}"
 
1
  import torch
2
+ from . import _optimizer_15336dc_dirty
3
+ ops = torch.ops._optimizer_15336dc_dirty
4
 
5
  def add_op_namespace_prefix(op_name: str):
6
  """
7
  Prefix op by namespace.
8
  """
9
+ return f"_optimizer_15336dc_dirty::{op_name}"
build/{torch28-cxx11-cu128-x86_64-linux/optimizer/_optimizer_9c21645_dirty.abi3.so → torch27-cxx11-cu128-x86_64-linux/optimizer/_optimizer_15336dc_dirty.abi3.so} RENAMED
@@ -1,3 +1,3 @@
1
  version https://git-lfs.github.com/spec/v1
2
- oid sha256:dae71b7e998e72130093a86f8c983c3379510e23525e3cdcd4afe5c21bf4d3db
3
  size 1883344
 
1
  version https://git-lfs.github.com/spec/v1
2
+ oid sha256:e06baa32b0950126ee192654bd9f7adc79cc05d8ec39d2078c70d62ee81fdcd5
3
  size 1883344
build/torch27-cxx11-cu128-x86_64-linux/optimizer/matmul_transpose_triton.py ADDED
@@ -0,0 +1,128 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # MIT License
2
+ #
3
+ # Copyright (c) 2025 Tianyang Lin
4
+ #
5
+ # Permission is hereby granted, free of charge, to any person obtaining a copy
6
+ # of this software and associated documentation files (the "Software"), to deal
7
+ # in the Software without restriction, including without limitation the rights
8
+ # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9
+ # copies of the Software, and to permit persons to whom the Software is
10
+ # furnished to do so, subject to the following conditions:
11
+ #
12
+ # The above copyright notice and this permission notice shall be included in all
13
+ # copies or substantial portions of the Software.
14
+ #
15
+ # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16
+ # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17
+ # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18
+ # AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19
+ # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20
+ # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21
+ # SOFTWARE.
22
+
23
+ import torch
24
+ import triton
25
+ import triton.language as tl
26
+
27
+
28
+ def get_autotune_config():
29
+ return [
30
+ triton.Config(
31
+ {
32
+ 'BLOCK_SIZE_M': blk_m,
33
+ 'BLOCK_SIZE_K': blk_k,
34
+ 'GROUP_SIZE_M': grp_sz
35
+ },
36
+ num_stages=n_stages,
37
+ num_warps=n_warps) for blk_m in [32, 64, 128]
38
+ for blk_k in [32, 64] for grp_sz in [8] for n_stages in [3, 4, 5]
39
+ for n_warps in [4, 8]
40
+ ]
41
+
42
+
43
+ @triton.autotune(
44
+ configs=get_autotune_config(),
45
+ key=['M', 'K'],
46
+ )
47
+ @triton.jit
48
+ def mmt_kernel(x, y, M, K, stride_xm, stride_xk, stride_ym, stride_yn,
49
+ BLOCK_SIZE_M: tl.constexpr, BLOCK_SIZE_K: tl.constexpr,
50
+ GROUP_SIZE_M: tl.constexpr):
51
+ """
52
+ Core kernel jit function of matmul_transpose that computes y = x @ x.T
53
+ The code is a simple adaptation from the triton `matmul` tutorial:
54
+ https://triton-lang.org/main/getting-started/tutorials/03-matrix-multiplication.html
55
+ """
56
+ pid = tl.program_id(axis=0)
57
+ num_pid_m = tl.cdiv(M, BLOCK_SIZE_M)
58
+ num_pid_n = tl.cdiv(M, BLOCK_SIZE_M)
59
+ num_pid_in_group = GROUP_SIZE_M * num_pid_n
60
+ group_id = pid // num_pid_in_group
61
+ first_pid_m = group_id * GROUP_SIZE_M
62
+ group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M)
63
+ pid_m = first_pid_m + ((pid % num_pid_in_group) % group_size_m)
64
+ pid_n = (pid % num_pid_in_group) // group_size_m
65
+ if pid_m > pid_n:
66
+ return
67
+
68
+ offs_xm = (pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)) % M
69
+ offs_xn = (pid_n * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)) % M
70
+ offs_k = tl.arange(0, BLOCK_SIZE_K)
71
+ # we use a & b ptrs to denote different rows of x.
72
+ a_ptrs = x + (offs_xm[:, None] * stride_xm + offs_k[None, :] * stride_xk)
73
+ b_ptrs = x + (offs_xn[:, None] * stride_xm + offs_k[None, :] * stride_xk)
74
+
75
+ accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_M), dtype=tl.float32)
76
+
77
+ for k in range(0, tl.cdiv(K, BLOCK_SIZE_K)):
78
+ a = tl.load(a_ptrs,
79
+ mask=offs_k[None, :] < K - k * BLOCK_SIZE_K,
80
+ other=0.0)
81
+ b = tl.load(b_ptrs,
82
+ mask=offs_k[None, :] < K - k * BLOCK_SIZE_K,
83
+ other=0.0)
84
+ accumulator = tl.dot(a, tl.permute(b, (1, 0)), accumulator)
85
+ a_ptrs += BLOCK_SIZE_K * stride_xk
86
+ b_ptrs += BLOCK_SIZE_K * stride_xk
87
+ # use dtype.element_ty to accommodate different input datatypes as in cpp templates
88
+ # https://github.com/triton-lang/triton/issues/2252
89
+ c = accumulator.to(x.dtype.element_ty)
90
+
91
+ offs_cm = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
92
+ offs_cn = pid_n * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
93
+ c_ptrs = y + stride_ym * offs_cm[:, None] + stride_yn * offs_cn[None, :]
94
+ c_mask = (offs_cm[:, None] < M) & (offs_cn[None, :] < M)
95
+ tl.store(c_ptrs, c, mask=c_mask)
96
+
97
+ # transpose and copy
98
+ if pid_m < pid_n:
99
+ ct_ptrs = y + stride_ym * offs_cn[:,
100
+ None] + stride_yn * offs_cm[None, :]
101
+ ct_mask = (offs_cn[:, None] < M) & (offs_cm[None, :] < M)
102
+ tl.store(ct_ptrs, tl.permute(c, (1, 0)), mask=ct_mask)
103
+
104
+
105
+ def matmul_transpose_assign(d_in, d_out):
106
+ assert d_in.is_cuda, "Input `d_in` must be a CUDA tensor"
107
+ assert d_out.is_cuda, "Input `d_out` must be a CUDA tensor"
108
+ assert d_in.device == d_out.device, "Inputs `d_in` and `d_out` must be on the same CUDA device"
109
+ assert d_in.dtype == d_out.dtype, "Inputs must have the same data type"
110
+ assert d_in.ndim == 2, "Input `d_in` must be a 2D tensor"
111
+ assert d_out.ndim == 2, "Input `d_out` must be a 2D tensor"
112
+ assert d_in.size(0) == d_out.size(0) == d_out.size(0), \
113
+ "First dimension of `d_in` must match first and second dimension of `d_out`"
114
+
115
+ d_in = d_in.contiguous()
116
+ M, K = d_in.shape
117
+ grid = lambda META: (triton.cdiv(M, META['BLOCK_SIZE_M']) * triton.cdiv(
118
+ M, META['BLOCK_SIZE_M']), )
119
+ with torch.cuda.device(d_in.device.index):
120
+ mmt_kernel[grid](d_in, d_out, M, K, d_in.stride(0), d_in.stride(1),
121
+ d_out.stride(0), d_out.stride(1))
122
+
123
+
124
+ def matmul_transpose(d_in):
125
+ M, _ = d_in.shape
126
+ d_out = torch.empty((M, M), device=d_in.device, dtype=d_in.dtype)
127
+ matmul_transpose_assign(d_in, d_out)
128
+ return d_out
build/torch27-cxx11-cu128-x86_64-linux/optimizer/muon.py CHANGED
@@ -8,14 +8,19 @@ import torch
8
  import torch.distributed as dist
9
  from torch.distributed._tensor import DTensor, Replicate, Shard
10
 
 
 
11
  logger = logging.getLogger(__name__)
12
 
 
 
13
 
14
  # This code snippet is a modified version adapted from the following GitHub repositories:
15
  # https://github.com/KellerJordan/Muon/blob/master/muon.py
16
  # Muon's Newton–Schulz iteration causes high variance in singular values
17
  # Idea: give each iteration its own 3 coefficients and optimize them via gradient descent.
18
  @torch.no_grad()
 
19
  def _zeropower_via_newtonschulz5(G, steps):
20
  """
21
  Newton-Schulz iteration to compute the zeroth power / orthogonalization of G. We opt to use a
@@ -27,13 +32,15 @@ def _zeropower_via_newtonschulz5(G, steps):
27
  performance at all relative to UV^T, where USV^T = G is the SVD.
28
  """
29
  assert len(G.shape) == 2
30
- assert G.dtype == torch.bfloat16
31
  X = G # no manual typecast
32
 
33
  if G.size(0) > G.size(1):
34
  X = X.T
35
  # Ensure spectral norm is at most 1
36
  X = X / (X.norm() + 1e-7)
 
 
37
  # Perform the NS iterations
38
  for a, b, c in [
39
  (4.0848, -6.8946, 2.9270),
@@ -42,13 +49,10 @@ def _zeropower_via_newtonschulz5(G, steps):
42
  (2.8769, -3.1427, 1.2046),
43
  (2.8366, -3.0525, 1.2012),
44
  ]:
45
- A = X @ X.T
46
- # B = (
47
- # b * A + c * A @ A
48
- # )
49
- B = torch.addmm(A, A, A, alpha=c, beta=b)
50
- # X = a * X + B @ X
51
- X = torch.addmm(X, B, X, alpha=1.0, beta=a)
52
 
53
  if G.size(0) > G.size(1):
54
  X = X.T
@@ -69,51 +73,142 @@ class _muon_state:
69
  qk_clip_state = None
70
 
71
 
 
 
 
 
 
 
 
 
72
  @torch.no_grad()
73
- def _gather(p, state, rank, comm_stream, none_grad):
74
  """
75
- Gather the gradients to worker_rank.
76
- If none_grad is True, free p.grad after the gather.
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
77
  """
78
  with torch.cuda.stream(comm_stream):
79
- g = p.grad
 
80
 
81
- if rank == state.worker_rank:
82
- num_ranks = dist.get_world_size(group=state.process_group)
83
- gather_list = [
84
- torch.empty_like(g.to_local(), dtype=torch.bfloat16)
85
- for _ in range(num_ranks)
86
- ]
87
- else:
88
- gather_list = None
89
-
90
- g = g.to(torch.bfloat16)
91
- torch.distributed.gather(
92
- g.to_local(),
93
- dst=state.worker_rank,
94
- gather_list=gather_list,
95
- group=state.process_group,
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
96
  )
97
- if rank == state.worker_rank:
98
- if state.gathered_grad is not None:
99
- raise RuntimeError(
100
- "Gather event already exists, which should not happen.")
101
- state.gathered_grad = torch.cat(gather_list, dim=0)
102
- state.gather_event = torch.cuda.Event()
103
- state.gather_event.record()
104
- else:
105
- state.gathered_grad = None
106
- state.gather_event = None
107
- gather_list = None
108
- if none_grad:
109
- # We can safely free p.grad without calling record_stream:
110
- # p.grad.to_local().record_stream(comm_stream)
111
- # Explanation:
112
- # 1. p.grad is created on the default stream, but the default stream
113
- # is synchronized with the comm stream later.
114
- # 2. There is no further activity on the default stream before the optimizer finishes.
115
- # Therefore, it is safe to free p.grad directly on the comm stream.
116
- p.grad = None
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
117
 
118
 
119
  @torch.no_grad()
@@ -127,45 +222,145 @@ def _compute_u(p, state, steps, rank, compute_stream):
127
  raise RuntimeError("Gather event must be set before compute.")
128
  compute_stream.wait_event(state.gather_event)
129
  u = _zeropower_via_newtonschulz5(state.gathered_grad, steps)
 
130
  state.computed_u = u
131
- state.scattered_u = torch.empty_like(p.to_local(),
132
- dtype=torch.bfloat16)
133
- state.compute_event = torch.cuda.Event()
134
- state.compute_event.record()
135
- u = None
136
 
137
 
138
  @torch.no_grad()
139
- def _scatter(p, state, rank, comm_stream):
140
  """
141
- Scatter the computed_u from worker_rank to all ranks.
 
142
  """
 
 
 
 
 
 
 
 
 
 
143
 
 
 
 
 
144
  with torch.cuda.stream(comm_stream):
145
- if state.compute_event is None:
146
- raise RuntimeError("Compute event must be set before scatter.")
147
- comm_stream.wait_event(state.compute_event)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
148
 
149
- if rank == state.worker_rank:
150
- num_ranks = dist.get_world_size(group=state.process_group)
151
- # Clear the gathered gradient to free memory
152
- state.gathered_grad = None
 
 
 
 
153
 
154
- u = state.computed_u
155
- scatter_list = list(torch.split(u, p.size(0) // num_ranks, dim=0))
156
- scatter_list = [s.contiguous() for s in scatter_list]
 
 
 
 
 
 
157
  else:
158
- scatter_list = None
 
 
159
 
160
- torch.distributed.scatter(
161
- state.scattered_u,
162
- scatter_list=scatter_list,
163
- src=state.worker_rank,
164
- group=state.process_group,
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
165
  )
166
- state.scatter_event = torch.cuda.Event()
167
- state.scatter_event.record()
168
- scatter_list = None
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
169
 
170
 
171
  def _update_param(p, state, lr, adjusted_lr, weight_decay, rank,
@@ -321,6 +516,11 @@ class Muon(torch.optim.Optimizer):
321
  "head_dim": 128,
322
  "threshold": 100
323
  }
 
 
 
 
 
324
  """
325
 
326
  def __init__(self,
@@ -339,7 +539,8 @@ class Muon(torch.optim.Optimizer):
339
  "k_indices": [],
340
  "head_dim": 128,
341
  "threshold": 100
342
- }):
 
343
  defaults = dict(
344
  lr=lr,
345
  weight_decay=weight_decay,
@@ -363,15 +564,13 @@ class Muon(torch.optim.Optimizer):
363
 
364
  super().__init__(params, defaults)
365
 
366
- if dist.is_initialized():
367
- self.rank = dist.get_rank()
368
- else:
369
- self.rank = None
370
 
371
  self.comm_stream = torch.cuda.Stream()
372
  self.compute_stream = torch.cuda.Stream()
373
  self.debug = debug
374
  self.clip_config = clip_config
 
375
 
376
  def _calc_flops(self, G, steps):
377
  assert len(G.shape) == 2
@@ -444,11 +643,18 @@ class Muon(torch.optim.Optimizer):
444
  if mesh is None:
445
  mesh = p.device_mesh
446
  shard_mesh, process_group = self.get_shard_mesh(p, self.rank)
 
 
 
 
 
447
  elif mesh != p.device_mesh:
448
  raise ValueError("All parameters must be on the same mesh.")
449
 
 
450
  param_to_state[id(p)] = _muon_state()
451
- param_to_state[id(p)].worker_rank = shard_mesh[round_robin].item()
 
452
  param_to_state[id(p)].process_group = process_group
453
  qk_clip_state = self.get_qk_clip_info(n, qk_logits)
454
  param_to_state[id(p)].qk_clip_state = qk_clip_state
@@ -478,7 +684,7 @@ class Muon(torch.optim.Optimizer):
478
  else:
479
  g = buf
480
 
481
- u = _zeropower_via_newtonschulz5(g.bfloat16(),
482
  steps=group["ns_steps"])
483
 
484
  adjusted_lr = self.adjust_lr_for_muon(lr, p.shape)
@@ -493,15 +699,12 @@ class Muon(torch.optim.Optimizer):
493
  def _update_g(self, p, g, group, momentum):
494
  # calc update
495
  state = self.state[p]
496
- if "momentum_buffer" not in state:
497
- state["momentum_buffer"] = torch.zeros_like(g)
498
- buf = state["momentum_buffer"]
499
- buf.mul_(momentum).add_(g)
500
  if group["nesterov"]:
501
- g = g.add(buf, alpha=momentum)
502
- else:
503
- g = buf
504
- return g
505
 
506
  @staticmethod
507
  def _update_p(p, u, lr, adjusted_lr, weight_decay):
@@ -585,11 +788,17 @@ class Muon(torch.optim.Optimizer):
585
  param_to_state, ordered_params = self.init_state_and_assign_params(
586
  names, params, group, qk_logits)
587
 
588
- def enqueue_gathers(start_idx, chunk_size):
589
- for p in ordered_params[start_idx:start_idx + chunk_size]:
590
- state = param_to_state[id(p)]
591
- _gather(p, state, self.rank, self.comm_stream,
592
- group["none_grad"])
 
 
 
 
 
 
593
 
594
  def enqueue_computes(start_idx, chunk_size):
595
  for p in ordered_params[start_idx:start_idx + chunk_size]:
@@ -597,10 +806,14 @@ class Muon(torch.optim.Optimizer):
597
  _compute_u(p, state, group["ns_steps"], self.rank,
598
  self.compute_stream)
599
 
600
- def enqueue_scatters(start_idx, chunk_size):
601
- for p in ordered_params[start_idx:start_idx + chunk_size]:
602
- state = param_to_state[id(p)]
603
- _scatter(p, state, self.rank, self.comm_stream)
 
 
 
 
604
 
605
  def enqueue_update_param(start_idx, chunk_size):
606
  for p in ordered_params[start_idx:start_idx + chunk_size]:
@@ -615,14 +828,16 @@ class Muon(torch.optim.Optimizer):
615
  # Wait grad update
616
  self.comm_stream.wait_stream(torch.cuda.current_stream())
617
 
618
- enqueue_gathers(0, chunk_size)
 
 
 
 
619
  for i in range(0, len(params) + chunk_size - 1, chunk_size):
620
- enqueue_computes(i, chunk_size)
621
- if i > 0:
622
- enqueue_update_param(i - chunk_size, chunk_size)
623
- enqueue_gathers(i + chunk_size, chunk_size)
624
- enqueue_scatters(i, chunk_size)
625
- enqueue_update_param(i, chunk_size)
626
 
627
  # Wait the last update_param to finish
628
  torch.cuda.current_stream().wait_stream(self.compute_stream)
 
8
  import torch.distributed as dist
9
  from torch.distributed._tensor import DTensor, Replicate, Shard
10
 
11
+ from .matmul_transpose_triton import matmul_transpose_assign
12
+
13
  logger = logging.getLogger(__name__)
14
 
15
+ COMM_DTYPE = torch.bfloat16
16
+
17
 
18
  # This code snippet is a modified version adapted from the following GitHub repositories:
19
  # https://github.com/KellerJordan/Muon/blob/master/muon.py
20
  # Muon's Newton–Schulz iteration causes high variance in singular values
21
  # Idea: give each iteration its own 3 coefficients and optimize them via gradient descent.
22
  @torch.no_grad()
23
+ # matmul_transpose_assign from : https://github.com/nil0x9/flash-muon
24
  def _zeropower_via_newtonschulz5(G, steps):
25
  """
26
  Newton-Schulz iteration to compute the zeroth power / orthogonalization of G. We opt to use a
 
32
  performance at all relative to UV^T, where USV^T = G is the SVD.
33
  """
34
  assert len(G.shape) == 2
35
+ assert G.dtype == COMM_DTYPE
36
  X = G # no manual typecast
37
 
38
  if G.size(0) > G.size(1):
39
  X = X.T
40
  # Ensure spectral norm is at most 1
41
  X = X / (X.norm() + 1e-7)
42
+ buf1 = torch.empty(X.size(0), X.size(0), dtype=X.dtype, device=X.device)
43
+ buf2 = torch.empty(X.size(0), X.size(0), dtype=X.dtype, device=X.device)
44
  # Perform the NS iterations
45
  for a, b, c in [
46
  (4.0848, -6.8946, 2.9270),
 
49
  (2.8769, -3.1427, 1.2046),
50
  (2.8366, -3.0525, 1.2012),
51
  ]:
52
+ matmul_transpose_assign(X, buf1)
53
+ matmul_transpose_assign(buf1, buf2)
54
+ buf1.mul_(b).add_(buf2, alpha=c)
55
+ X = torch.addmm(X, buf1, X, alpha=1.0, beta=a)
 
 
 
56
 
57
  if G.size(0) > G.size(1):
58
  X = X.T
 
73
  qk_clip_state = None
74
 
75
 
76
def split_elems_for_src(param, src_rank, num_ranks) -> int:
    """Element count of ``param``'s shard owned by ``src_rank``.

    Rows (dim 0) are split as evenly as possible across ``num_ranks``;
    the first ``rows % num_ranks`` ranks each take one extra row.
    """
    total_rows = param.shape[0]
    elems_per_row = int(param.numel() // total_rows)
    even_rows, leftover = divmod(total_rows, num_ranks)
    rows_here = even_rows + (1 if src_rank < leftover else 0)
    return rows_here * elems_per_row
82
+
83
+
84
  @torch.no_grad()
85
+ def _alloc_gathered_grad(params, param_to_state, rank, compute_stream):
86
  """
87
+ Pre-allocate gathered_grad buffer on compute_stream
88
+ before launching all2all gather
89
+ """
90
+ with torch.cuda.stream(compute_stream):
91
+ for p in params:
92
+ state = param_to_state[id(p)]
93
+ if rank == state.worker_rank:
94
+ num_ranks = dist.get_world_size(group=state.process_group)
95
+ state.gathered_grad = torch.empty(p.grad.numel(),
96
+ dtype=COMM_DTYPE,
97
+ device="cuda")
98
+ else:
99
+ state.gathered_grad = None
100
+
101
+ alloc_event = torch.cuda.Event()
102
+ alloc_event.record(compute_stream)
103
+ return alloc_event
104
+
105
+
106
@torch.no_grad()
def _all2all_gather(params, param_to_state, rank, comm_stream, none_grad,
                    alloc_event):
    """
    All2all gathers shards so each owner rank reconstructs its full gradient.

    Every rank sends its local shard of every parameter's gradient to that
    parameter's worker_rank; afterwards each owner rank copies the received
    shards into the pre-allocated ``state.gathered_grad`` buffer (allocated
    by ``_alloc_gathered_grad``; ``alloc_event`` guards that allocation).
    If ``none_grad`` is True, ``p.grad`` is freed after the exchange.
    """
    with torch.cuda.stream(comm_stream):
        # All params are assumed to share one process group (first's group).
        process_group = param_to_state[id(params[0])].process_group
        num_ranks = dist.get_world_size(group=process_group)

        # Construct sending buffers
        per_dst = [[] for _ in range(num_ranks)]
        send_counts = [0] * num_ranks

        for p in params:
            state = param_to_state[id(p)]
            dst = state.worker_rank
            assert dst < num_ranks
            # Size of this rank's local shard of p (row-sharded on dim 0).
            shard_elems = split_elems_for_src(p, rank, num_ranks)
            g = p.grad
            g = g.to_local().to(COMM_DTYPE).contiguous().view(-1)
            assert g.numel() == shard_elems
            per_dst[dst].append(g)
            send_counts[dst] += shard_elems

        assert all(
            len(v) > 0
            for v in per_dst), "all params should be sharded to all devices"

        # One flat buffer: shards grouped by destination rank, in order.
        send_buf = torch.cat([torch.cat(v, dim=0) for v in per_dst], dim=0)
        owned_params = [
            p for p in params if param_to_state[id(p)].worker_rank == rank
        ]

        # Compute receive sizes and allocate receiving buffers
        recv_counts = [0] * num_ranks

        for src in range(num_ranks):
            total = 0
            for p in owned_params:
                state = param_to_state[id(p)]
                assert state.worker_rank == rank
                total += split_elems_for_src(p, src, num_ranks)
            recv_counts[src] = total

        recv_total = sum(recv_counts)
        # NOTE(review): allocated on the current CUDA device ("cuda"),
        # not p.device — confirm they always coincide.
        recv_buf = torch.empty(recv_total, dtype=COMM_DTYPE, device="cuda")

        # All2All
        dist.all_to_all_single(
            recv_buf,
            send_buf,
            output_split_sizes=recv_counts,
            input_split_sizes=send_counts,
            group=process_group,
        )

        # Reconstructs gathered grad from the received buffer
        #
        # recv_buf (num ranks = 3)
        #
        #   From rank 0         From rank 1         From rank 2
        # | p1_0, p2_0, p3_0 | p1_1, p2_1, p3_1 | p1_2, p2_2, p3_2 |
        #
        # Outer loop:
        #   rank 0 -> rank 1 -> rank2
        #
        # Inner loop:
        #   p1_n -> p2_n -> p3_n

        # Destination buffers must exist before the copies below.
        comm_stream.wait_event(alloc_event)

        off = 0
        write_offsets = {id(p): 0 for p in owned_params}
        for src in range(num_ranks):
            if recv_counts[src] == 0:
                continue

            block = recv_counts[src]
            inner_off = 0
            for p in owned_params:
                state = param_to_state[id(p)]
                assert state.worker_rank == rank
                n = split_elems_for_src(p, src, num_ranks)
                assert n > 0

                # Copy src's shard of p into the right slice of the flat
                # gathered_grad buffer (shards land in rank order).
                sg = recv_buf.narrow(0, off + inner_off, n)
                woff = write_offsets[id(p)]
                dst = state.gathered_grad.narrow(0, woff, n)
                dst.copy_(sg)

                write_offsets[id(p)] += n
                inner_off += n
            off += block

        for p in params:
            state = param_to_state[id(p)]
            if state.worker_rank == rank:
                # Flat buffer -> parameter shape; record completion so the
                # compute stream can wait on gather_event.
                state.gathered_grad = state.gathered_grad.view_as(p)
                state.gather_event = torch.cuda.Event()
                state.gather_event.record(comm_stream)
            else:
                state.gathered_grad = None
                state.gather_event = None
            if none_grad:
                # Safe to free here: the shard has been copied into send_buf
                # on this stream already.
                p.grad = None
212
 
213
 
214
  @torch.no_grad()
 
222
  raise RuntimeError("Gather event must be set before compute.")
223
  compute_stream.wait_event(state.gather_event)
224
  u = _zeropower_via_newtonschulz5(state.gathered_grad, steps)
225
+ state.gathered_grad = None
226
  state.computed_u = u
227
+ state.compute_event = torch.cuda.Event()
228
+ state.compute_event.record()
229
+ else:
230
+ state.computed_u = None
231
+ state.compute_event = None
232
 
233
 
234
@torch.no_grad()
def _alloc_scattered_u(params, param_to_state, rank, compute_stream):
    """
    Pre-allocate each parameter's local ``scattered_u`` buffer on
    compute_stream before launching the all2all scatter.

    Every rank allocates a COMM_DTYPE buffer shaped like its local shard
    (``p.to_local()``).  Returns a CUDA event recorded on ``compute_stream``
    that the comm stream waits on before writing into the buffers.
    ``rank`` is unused here; the signature mirrors ``_alloc_gathered_grad``.
    """
    with torch.cuda.stream(compute_stream):
        for p in params:
            state = param_to_state[id(p)]
            state.scattered_u = torch.empty_like(p.to_local(),
                                                 dtype=COMM_DTYPE)

    alloc_event = torch.cuda.Event()
    alloc_event.record(compute_stream)
    return alloc_event
249
+
250
 
251
def _all2all_scatter(params, param_to_state, rank, comm_stream, alloc_event):
    """
    All2all scatters the computed orthogonalized updates back to all ranks.

    Each owner rank splits its full ``computed_u`` into per-destination row
    shards; after the exchange every rank copies its received shards into
    the pre-allocated ``state.scattered_u`` buffers (guarded by
    ``alloc_event`` from ``_alloc_scattered_u``).

    NOTE(review): unlike ``_all2all_gather`` this function is not wrapped in
    ``@torch.no_grad()`` — confirm whether that is intentional.
    """
    with torch.cuda.stream(comm_stream):
        # All params are assumed to share one process group (first's group).
        process_group = param_to_state[id(params[0])].process_group
        num_ranks = dist.get_world_size(group=process_group)
        owned_params = [
            p for p in params if param_to_state[id(p)].worker_rank == rank
        ]

        # Construct sending buffer
        per_dst = [[] for _ in range(num_ranks)]
        send_counts = [0] * num_ranks

        if owned_params:
            for p in owned_params:
                state = param_to_state[id(p)]
                if state.compute_event is None:
                    raise RuntimeError(
                        "Compute event must be set before scatter.")
                # u must be fully computed before we read it for sending.
                comm_stream.wait_event(state.compute_event)
                state.gathered_grad = None

                assert state.computed_u is not None

                u_full = state.computed_u.to(COMM_DTYPE).contiguous().view(-1)

                # Slice u row-shard by row-shard, one slice per destination.
                offset = 0
                for dst in range(num_ranks):
                    n = split_elems_for_src(p, dst, num_ranks)
                    assert n > 0

                    su = u_full.narrow(0, offset, n)
                    per_dst[dst].append(su)
                    send_counts[dst] += n
                    offset += n

                assert offset == u_full.numel()

        if any(len(v) > 0 for v in per_dst):
            send_buf = torch.cat([torch.cat(v, dim=0) for v in per_dst], dim=0)
        else:
            # all_to_all requires participation from all ranks
            # Even non-owner ranks must join the collective call
            send_buf = torch.empty(0, dtype=COMM_DTYPE, device="cuda")

        # Compute receive sizes and allocate receiving buffers
        recv_counts = [0] * num_ranks

        for src in range(num_ranks):
            total = 0
            for p in params:
                state = param_to_state[id(p)]
                if state.worker_rank != src:
                    continue
                total += split_elems_for_src(p, rank, num_ranks)
            recv_counts[src] = total

        recv_total = sum(recv_counts)
        assert recv_total > 0
        # NOTE(review): allocated on the current CUDA device ("cuda"),
        # not p.device — confirm they always coincide.
        recv_buf = torch.empty(recv_total, dtype=COMM_DTYPE, device="cuda")

        # All2All
        dist.all_to_all_single(
            recv_buf,
            send_buf,
            output_split_sizes=recv_counts,
            input_split_sizes=send_counts,
            group=process_group,
        )

        # Copy to pre-allocated scattered_u buffer from the received buffer
        #
        # recv_buf (num ranks = 3, local_rank = 0)
        #
        #   From rank 0         From rank 1   From rank 2
        # | p1_0, p2_0, p3_0 | p4_0        | p5_0, p6_0 |
        #
        # Outer loop:
        #   rank 0 -> rank 1 -> rank2
        #
        # Inner loop:
        #   src(0) : p1_0 -> p2_0 -> p3_0
        #   src(1) : p4_0
        #   src(2) : p5_0 -> p6_0

        # Destination buffers must exist before the copies below.
        comm_stream.wait_event(alloc_event)

        off = 0
        for src in range(num_ranks):
            block = recv_counts[src]
            if block == 0:
                continue

            inner_off = 0
            for p in params:
                state = param_to_state[id(p)]
                if state.worker_rank != src:
                    continue
                n = split_elems_for_src(p, rank, num_ranks)
                assert n > 0

                # Received shard matches this rank's local slice of p.
                flat_local = recv_buf.narrow(0, off + inner_off,
                                             n).view_as(p.to_local())
                state.scattered_u.copy_(flat_local)

                # NOTE(review): a fresh scatter_event is recorded per owned
                # param inside the loop — each param's event marks its own
                # copy; confirm consumers wait on the right one.
                state.scatter_event = torch.cuda.Event()
                state.scatter_event.record(comm_stream)
                inner_off += n

            assert inner_off == block
            off += block
364
 
365
 
366
  def _update_param(p, state, lr, adjusted_lr, weight_decay, rank,
 
516
  "head_dim": 128,
517
  "threshold": 100
518
  }
519
+ overlap_step : How many all2all gather, compute operations are launched in advance
520
+ before the corresponding all2all scatter steps begin.
521
+ A higher overlap_step increases memory usage but can improve
522
+ performance by overlapping communication.
523
+ Parallel muon only.
524
  """
525
 
526
  def __init__(self,
 
539
  "k_indices": [],
540
  "head_dim": 128,
541
  "threshold": 100
542
+ },
543
+ overlap_step=5):
544
  defaults = dict(
545
  lr=lr,
546
  weight_decay=weight_decay,
 
564
 
565
  super().__init__(params, defaults)
566
 
567
+ self.rank = None
 
 
 
568
 
569
  self.comm_stream = torch.cuda.Stream()
570
  self.compute_stream = torch.cuda.Stream()
571
  self.debug = debug
572
  self.clip_config = clip_config
573
+ self.overlap_step = overlap_step
574
 
575
  def _calc_flops(self, G, steps):
576
  assert len(G.shape) == 2
 
643
  if mesh is None:
644
  mesh = p.device_mesh
645
  shard_mesh, process_group = self.get_shard_mesh(p, self.rank)
646
+ local_rank = dist.get_rank(group=process_group)
647
+ if self.rank is None:
648
+ self.rank = dist.get_rank(group=process_group)
649
+ else:
650
+ assert self.rank == local_rank
651
  elif mesh != p.device_mesh:
652
  raise ValueError("All parameters must be on the same mesh.")
653
 
654
+ num_ranks = dist.get_world_size(group=process_group)
655
  param_to_state[id(p)] = _muon_state()
656
+ param_to_state[id(
657
+ p)].worker_rank = shard_mesh[round_robin].item() % num_ranks
658
  param_to_state[id(p)].process_group = process_group
659
  qk_clip_state = self.get_qk_clip_info(n, qk_logits)
660
  param_to_state[id(p)].qk_clip_state = qk_clip_state
 
684
  else:
685
  g = buf
686
 
687
+ u = _zeropower_via_newtonschulz5(g.to(COMM_DTYPE),
688
  steps=group["ns_steps"])
689
 
690
  adjusted_lr = self.adjust_lr_for_muon(lr, p.shape)
 
699
    def _update_g(self, p, g, group, momentum):
        # calc update
        state = self.state[p]
        # Lazily create the momentum buffer on first use for this param.
        buf = state.setdefault("momentum_buffer", torch.zeros_like(g))
        # buf <- g + momentum * buf  (momentum accumulation, written in place).
        torch.add(g, buf, alpha=momentum, out=buf)
        if group["nesterov"]:
            # Nesterov look-ahead: g + momentum * buf.
            # NOTE(review): this mutates the caller's gradient tensor `g`
            # in place (the pre-merge version used non-mutating `g.add`) —
            # confirm no caller reuses the original gradient values.
            g.add_(buf, alpha=momentum)
            return g
        return buf
 
708
 
709
  @staticmethod
710
  def _update_p(p, u, lr, adjusted_lr, weight_decay):
 
788
  param_to_state, ordered_params = self.init_state_and_assign_params(
789
  names, params, group, qk_logits)
790
 
791
+ assert self.rank is not None
792
+
793
+ def enqueue_all2all_gather(start_idx, chunk_size):
794
+ target_params = ordered_params[start_idx:start_idx + chunk_size]
795
+ if target_params:
796
+ alloc_event = _alloc_gathered_grad(target_params,
797
+ param_to_state, self.rank,
798
+ self.compute_stream)
799
+ _all2all_gather(target_params, param_to_state, self.rank,
800
+ self.comm_stream, group["none_grad"],
801
+ alloc_event)
802
 
803
  def enqueue_computes(start_idx, chunk_size):
804
  for p in ordered_params[start_idx:start_idx + chunk_size]:
 
806
  _compute_u(p, state, group["ns_steps"], self.rank,
807
  self.compute_stream)
808
 
809
+ def enqueue_all2all_scatter(start_idx, chunk_size):
810
+ target_params = ordered_params[start_idx:start_idx + chunk_size]
811
+ if target_params:
812
+ alloc_event = _alloc_scattered_u(target_params, param_to_state,
813
+ self.rank,
814
+ self.compute_stream)
815
+ _all2all_scatter(target_params, param_to_state, self.rank,
816
+ self.comm_stream, alloc_event)
817
 
818
  def enqueue_update_param(start_idx, chunk_size):
819
  for p in ordered_params[start_idx:start_idx + chunk_size]:
 
828
  # Wait grad update
829
  self.comm_stream.wait_stream(torch.cuda.current_stream())
830
 
831
+ overlap_step = self.overlap_step
832
+ for i in range(0, overlap_step):
833
+ enqueue_all2all_gather(i * chunk_size, chunk_size)
834
+ enqueue_computes(i * chunk_size, chunk_size)
835
+
836
  for i in range(0, len(params) + chunk_size - 1, chunk_size):
837
+ enqueue_all2all_scatter(i, chunk_size)
838
+ enqueue_all2all_gather(i + overlap_step * chunk_size, chunk_size)
839
+ enqueue_update_param(i, chunk_size)
840
+ enqueue_computes(i + overlap_step * chunk_size, chunk_size)
 
 
841
 
842
  # Wait the last update_param to finish
843
  torch.cuda.current_stream().wait_stream(self.compute_stream)
build/torch27-cxx11-rocm63-x86_64-linux/optimizer/_ops.py CHANGED
@@ -1,9 +1,9 @@
1
  import torch
2
- from . import _optimizer_9c21645_dirty
3
- ops = torch.ops._optimizer_9c21645_dirty
4
 
5
  def add_op_namespace_prefix(op_name: str):
6
  """
7
  Prefix op by namespace.
8
  """
9
- return f"_optimizer_9c21645_dirty::{op_name}"
 
1
  import torch
2
+ from . import _optimizer_15336dc_dirty
3
+ ops = torch.ops._optimizer_15336dc_dirty
4
 
5
  def add_op_namespace_prefix(op_name: str):
6
  """
7
  Prefix op by namespace.
8
  """
9
+ return f"_optimizer_15336dc_dirty::{op_name}"
build/torch27-cxx11-rocm63-x86_64-linux/optimizer/{_optimizer_9c21645_dirty.abi3.so → _optimizer_15336dc_dirty.abi3.so} RENAMED
@@ -1,3 +1,3 @@
1
  version https://git-lfs.github.com/spec/v1
2
- oid sha256:41492cb1479920b654768a5597d88670dd0caeedbdcd73fd63afa31ffc6961d6
3
  size 1749776
 
1
  version https://git-lfs.github.com/spec/v1
2
+ oid sha256:c7cf2f7b8519dbc3f20e9d151914b55e56d10c012e2232d550b7c8d262746d71
3
  size 1749776
build/torch27-cxx11-rocm63-x86_64-linux/optimizer/matmul_transpose_triton.py ADDED
@@ -0,0 +1,128 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # MIT License
2
+ #
3
+ # Copyright (c) 2025 Tianyang Lin
4
+ #
5
+ # Permission is hereby granted, free of charge, to any person obtaining a copy
6
+ # of this software and associated documentation files (the "Software"), to deal
7
+ # in the Software without restriction, including without limitation the rights
8
+ # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9
+ # copies of the Software, and to permit persons to whom the Software is
10
+ # furnished to do so, subject to the following conditions:
11
+ #
12
+ # The above copyright notice and this permission notice shall be included in all
13
+ # copies or substantial portions of the Software.
14
+ #
15
+ # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16
+ # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17
+ # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18
+ # AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19
+ # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20
+ # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21
+ # SOFTWARE.
22
+
23
+ import torch
24
+ import triton
25
+ import triton.language as tl
26
+
27
+
28
+ def get_autotune_config():
29
+ return [
30
+ triton.Config(
31
+ {
32
+ 'BLOCK_SIZE_M': blk_m,
33
+ 'BLOCK_SIZE_K': blk_k,
34
+ 'GROUP_SIZE_M': grp_sz
35
+ },
36
+ num_stages=n_stages,
37
+ num_warps=n_warps) for blk_m in [32, 64, 128]
38
+ for blk_k in [32, 64] for grp_sz in [8] for n_stages in [3, 4, 5]
39
+ for n_warps in [4, 8]
40
+ ]
41
+
42
+
43
+ @triton.autotune(
44
+ configs=get_autotune_config(),
45
+ key=['M', 'K'],
46
+ )
47
+ @triton.jit
48
+ def mmt_kernel(x, y, M, K, stride_xm, stride_xk, stride_ym, stride_yn,
49
+ BLOCK_SIZE_M: tl.constexpr, BLOCK_SIZE_K: tl.constexpr,
50
+ GROUP_SIZE_M: tl.constexpr):
51
+ """
52
+ Core kernel jit function of matmul_transpose that computes y = x @ x.T
53
+ The code is a simple adaptation from the triton `matmul` tutorial:
54
+ https://triton-lang.org/main/getting-started/tutorials/03-matrix-multiplication.html
55
+ """
56
+ pid = tl.program_id(axis=0)
57
+ num_pid_m = tl.cdiv(M, BLOCK_SIZE_M)
58
+ num_pid_n = tl.cdiv(M, BLOCK_SIZE_M)
59
+ num_pid_in_group = GROUP_SIZE_M * num_pid_n
60
+ group_id = pid // num_pid_in_group
61
+ first_pid_m = group_id * GROUP_SIZE_M
62
+ group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M)
63
+ pid_m = first_pid_m + ((pid % num_pid_in_group) % group_size_m)
64
+ pid_n = (pid % num_pid_in_group) // group_size_m
65
+ if pid_m > pid_n:
66
+ return
67
+
68
+ offs_xm = (pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)) % M
69
+ offs_xn = (pid_n * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)) % M
70
+ offs_k = tl.arange(0, BLOCK_SIZE_K)
71
+ # we use a & b ptrs to denote different rows of x.
72
+ a_ptrs = x + (offs_xm[:, None] * stride_xm + offs_k[None, :] * stride_xk)
73
+ b_ptrs = x + (offs_xn[:, None] * stride_xm + offs_k[None, :] * stride_xk)
74
+
75
+ accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_M), dtype=tl.float32)
76
+
77
+ for k in range(0, tl.cdiv(K, BLOCK_SIZE_K)):
78
+ a = tl.load(a_ptrs,
79
+ mask=offs_k[None, :] < K - k * BLOCK_SIZE_K,
80
+ other=0.0)
81
+ b = tl.load(b_ptrs,
82
+ mask=offs_k[None, :] < K - k * BLOCK_SIZE_K,
83
+ other=0.0)
84
+ accumulator = tl.dot(a, tl.permute(b, (1, 0)), accumulator)
85
+ a_ptrs += BLOCK_SIZE_K * stride_xk
86
+ b_ptrs += BLOCK_SIZE_K * stride_xk
87
+ # use dtype.element_ty to accommodate different input datatypes as in cpp templates
88
+ # https://github.com/triton-lang/triton/issues/2252
89
+ c = accumulator.to(x.dtype.element_ty)
90
+
91
+ offs_cm = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
92
+ offs_cn = pid_n * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
93
+ c_ptrs = y + stride_ym * offs_cm[:, None] + stride_yn * offs_cn[None, :]
94
+ c_mask = (offs_cm[:, None] < M) & (offs_cn[None, :] < M)
95
+ tl.store(c_ptrs, c, mask=c_mask)
96
+
97
+ # transpose and copy
98
+ if pid_m < pid_n:
99
+ ct_ptrs = y + stride_ym * offs_cn[:,
100
+ None] + stride_yn * offs_cm[None, :]
101
+ ct_mask = (offs_cn[:, None] < M) & (offs_cm[None, :] < M)
102
+ tl.store(ct_ptrs, tl.permute(c, (1, 0)), mask=ct_mask)
103
+
104
+
105
+ def matmul_transpose_assign(d_in, d_out):
106
+ assert d_in.is_cuda, "Input `d_in` must be a CUDA tensor"
107
+ assert d_out.is_cuda, "Input `d_out` must be a CUDA tensor"
108
+ assert d_in.device == d_out.device, "Inputs `d_in` and `d_out` must be on the same CUDA device"
109
+ assert d_in.dtype == d_out.dtype, "Inputs must have the same data type"
110
+ assert d_in.ndim == 2, "Input `d_in` must be a 2D tensor"
111
+ assert d_out.ndim == 2, "Input `d_out` must be a 2D tensor"
112
+ assert d_in.size(0) == d_out.size(0) == d_out.size(0), \
113
+ "First dimension of `d_in` must match first and second dimension of `d_out`"
114
+
115
+ d_in = d_in.contiguous()
116
+ M, K = d_in.shape
117
+ grid = lambda META: (triton.cdiv(M, META['BLOCK_SIZE_M']) * triton.cdiv(
118
+ M, META['BLOCK_SIZE_M']), )
119
+ with torch.cuda.device(d_in.device.index):
120
+ mmt_kernel[grid](d_in, d_out, M, K, d_in.stride(0), d_in.stride(1),
121
+ d_out.stride(0), d_out.stride(1))
122
+
123
+
124
+ def matmul_transpose(d_in):
125
+ M, _ = d_in.shape
126
+ d_out = torch.empty((M, M), device=d_in.device, dtype=d_in.dtype)
127
+ matmul_transpose_assign(d_in, d_out)
128
+ return d_out
build/torch27-cxx11-rocm63-x86_64-linux/optimizer/muon.py CHANGED
@@ -8,14 +8,19 @@ import torch
8
  import torch.distributed as dist
9
  from torch.distributed._tensor import DTensor, Replicate, Shard
10
 
 
 
11
  logger = logging.getLogger(__name__)
12
 
 
 
13
 
14
  # This code snippet is a modified version adapted from the following GitHub repositories:
15
  # https://github.com/KellerJordan/Muon/blob/master/muon.py
16
  # Muon's Newton–Schulz iteration causes high variance in singular values
17
  # Idea: give each iteration its own 3 coefficients and optimize them via gradient descent.
18
  @torch.no_grad()
 
19
  def _zeropower_via_newtonschulz5(G, steps):
20
  """
21
  Newton-Schulz iteration to compute the zeroth power / orthogonalization of G. We opt to use a
@@ -27,13 +32,15 @@ def _zeropower_via_newtonschulz5(G, steps):
27
  performance at all relative to UV^T, where USV^T = G is the SVD.
28
  """
29
  assert len(G.shape) == 2
30
- assert G.dtype == torch.bfloat16
31
  X = G # no manual typecast
32
 
33
  if G.size(0) > G.size(1):
34
  X = X.T
35
  # Ensure spectral norm is at most 1
36
  X = X / (X.norm() + 1e-7)
 
 
37
  # Perform the NS iterations
38
  for a, b, c in [
39
  (4.0848, -6.8946, 2.9270),
@@ -42,13 +49,10 @@ def _zeropower_via_newtonschulz5(G, steps):
42
  (2.8769, -3.1427, 1.2046),
43
  (2.8366, -3.0525, 1.2012),
44
  ]:
45
- A = X @ X.T
46
- # B = (
47
- # b * A + c * A @ A
48
- # )
49
- B = torch.addmm(A, A, A, alpha=c, beta=b)
50
- # X = a * X + B @ X
51
- X = torch.addmm(X, B, X, alpha=1.0, beta=a)
52
 
53
  if G.size(0) > G.size(1):
54
  X = X.T
@@ -69,51 +73,142 @@ class _muon_state:
69
  qk_clip_state = None
70
 
71
 
 
 
 
 
 
 
 
 
72
  @torch.no_grad()
73
- def _gather(p, state, rank, comm_stream, none_grad):
74
  """
75
- Gather the gradients to worker_rank.
76
- If none_grad is True, free p.grad after the gather.
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
77
  """
78
  with torch.cuda.stream(comm_stream):
79
- g = p.grad
 
80
 
81
- if rank == state.worker_rank:
82
- num_ranks = dist.get_world_size(group=state.process_group)
83
- gather_list = [
84
- torch.empty_like(g.to_local(), dtype=torch.bfloat16)
85
- for _ in range(num_ranks)
86
- ]
87
- else:
88
- gather_list = None
89
-
90
- g = g.to(torch.bfloat16)
91
- torch.distributed.gather(
92
- g.to_local(),
93
- dst=state.worker_rank,
94
- gather_list=gather_list,
95
- group=state.process_group,
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
96
  )
97
- if rank == state.worker_rank:
98
- if state.gathered_grad is not None:
99
- raise RuntimeError(
100
- "Gather event already exists, which should not happen.")
101
- state.gathered_grad = torch.cat(gather_list, dim=0)
102
- state.gather_event = torch.cuda.Event()
103
- state.gather_event.record()
104
- else:
105
- state.gathered_grad = None
106
- state.gather_event = None
107
- gather_list = None
108
- if none_grad:
109
- # We can safely free p.grad without calling record_stream:
110
- # p.grad.to_local().record_stream(comm_stream)
111
- # Explanation:
112
- # 1. p.grad is created on the default stream, but the default stream
113
- # is synchronized with the comm stream later.
114
- # 2. There is no further activity on the default stream before the optimizer finishes.
115
- # Therefore, it is safe to free p.grad directly on the comm stream.
116
- p.grad = None
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
117
 
118
 
119
  @torch.no_grad()
@@ -127,45 +222,145 @@ def _compute_u(p, state, steps, rank, compute_stream):
127
  raise RuntimeError("Gather event must be set before compute.")
128
  compute_stream.wait_event(state.gather_event)
129
  u = _zeropower_via_newtonschulz5(state.gathered_grad, steps)
 
130
  state.computed_u = u
131
- state.scattered_u = torch.empty_like(p.to_local(),
132
- dtype=torch.bfloat16)
133
- state.compute_event = torch.cuda.Event()
134
- state.compute_event.record()
135
- u = None
136
 
137
 
138
  @torch.no_grad()
139
- def _scatter(p, state, rank, comm_stream):
140
  """
141
- Scatter the computed_u from worker_rank to all ranks.
 
142
  """
 
 
 
 
 
 
 
 
 
 
143
 
 
 
 
 
144
  with torch.cuda.stream(comm_stream):
145
- if state.compute_event is None:
146
- raise RuntimeError("Compute event must be set before scatter.")
147
- comm_stream.wait_event(state.compute_event)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
148
 
149
- if rank == state.worker_rank:
150
- num_ranks = dist.get_world_size(group=state.process_group)
151
- # Clear the gathered gradient to free memory
152
- state.gathered_grad = None
 
 
 
 
153
 
154
- u = state.computed_u
155
- scatter_list = list(torch.split(u, p.size(0) // num_ranks, dim=0))
156
- scatter_list = [s.contiguous() for s in scatter_list]
 
 
 
 
 
 
157
  else:
158
- scatter_list = None
 
 
159
 
160
- torch.distributed.scatter(
161
- state.scattered_u,
162
- scatter_list=scatter_list,
163
- src=state.worker_rank,
164
- group=state.process_group,
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
165
  )
166
- state.scatter_event = torch.cuda.Event()
167
- state.scatter_event.record()
168
- scatter_list = None
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
169
 
170
 
171
  def _update_param(p, state, lr, adjusted_lr, weight_decay, rank,
@@ -321,6 +516,11 @@ class Muon(torch.optim.Optimizer):
321
  "head_dim": 128,
322
  "threshold": 100
323
  }
 
 
 
 
 
324
  """
325
 
326
  def __init__(self,
@@ -339,7 +539,8 @@ class Muon(torch.optim.Optimizer):
339
  "k_indices": [],
340
  "head_dim": 128,
341
  "threshold": 100
342
- }):
 
343
  defaults = dict(
344
  lr=lr,
345
  weight_decay=weight_decay,
@@ -363,15 +564,13 @@ class Muon(torch.optim.Optimizer):
363
 
364
  super().__init__(params, defaults)
365
 
366
- if dist.is_initialized():
367
- self.rank = dist.get_rank()
368
- else:
369
- self.rank = None
370
 
371
  self.comm_stream = torch.cuda.Stream()
372
  self.compute_stream = torch.cuda.Stream()
373
  self.debug = debug
374
  self.clip_config = clip_config
 
375
 
376
  def _calc_flops(self, G, steps):
377
  assert len(G.shape) == 2
@@ -444,11 +643,18 @@ class Muon(torch.optim.Optimizer):
444
  if mesh is None:
445
  mesh = p.device_mesh
446
  shard_mesh, process_group = self.get_shard_mesh(p, self.rank)
 
 
 
 
 
447
  elif mesh != p.device_mesh:
448
  raise ValueError("All parameters must be on the same mesh.")
449
 
 
450
  param_to_state[id(p)] = _muon_state()
451
- param_to_state[id(p)].worker_rank = shard_mesh[round_robin].item()
 
452
  param_to_state[id(p)].process_group = process_group
453
  qk_clip_state = self.get_qk_clip_info(n, qk_logits)
454
  param_to_state[id(p)].qk_clip_state = qk_clip_state
@@ -478,7 +684,7 @@ class Muon(torch.optim.Optimizer):
478
  else:
479
  g = buf
480
 
481
- u = _zeropower_via_newtonschulz5(g.bfloat16(),
482
  steps=group["ns_steps"])
483
 
484
  adjusted_lr = self.adjust_lr_for_muon(lr, p.shape)
@@ -493,15 +699,12 @@ class Muon(torch.optim.Optimizer):
493
  def _update_g(self, p, g, group, momentum):
494
  # calc update
495
  state = self.state[p]
496
- if "momentum_buffer" not in state:
497
- state["momentum_buffer"] = torch.zeros_like(g)
498
- buf = state["momentum_buffer"]
499
- buf.mul_(momentum).add_(g)
500
  if group["nesterov"]:
501
- g = g.add(buf, alpha=momentum)
502
- else:
503
- g = buf
504
- return g
505
 
506
  @staticmethod
507
  def _update_p(p, u, lr, adjusted_lr, weight_decay):
@@ -585,11 +788,17 @@ class Muon(torch.optim.Optimizer):
585
  param_to_state, ordered_params = self.init_state_and_assign_params(
586
  names, params, group, qk_logits)
587
 
588
- def enqueue_gathers(start_idx, chunk_size):
589
- for p in ordered_params[start_idx:start_idx + chunk_size]:
590
- state = param_to_state[id(p)]
591
- _gather(p, state, self.rank, self.comm_stream,
592
- group["none_grad"])
 
 
 
 
 
 
593
 
594
  def enqueue_computes(start_idx, chunk_size):
595
  for p in ordered_params[start_idx:start_idx + chunk_size]:
@@ -597,10 +806,14 @@ class Muon(torch.optim.Optimizer):
597
  _compute_u(p, state, group["ns_steps"], self.rank,
598
  self.compute_stream)
599
 
600
- def enqueue_scatters(start_idx, chunk_size):
601
- for p in ordered_params[start_idx:start_idx + chunk_size]:
602
- state = param_to_state[id(p)]
603
- _scatter(p, state, self.rank, self.comm_stream)
 
 
 
 
604
 
605
  def enqueue_update_param(start_idx, chunk_size):
606
  for p in ordered_params[start_idx:start_idx + chunk_size]:
@@ -615,14 +828,16 @@ class Muon(torch.optim.Optimizer):
615
  # Wait grad update
616
  self.comm_stream.wait_stream(torch.cuda.current_stream())
617
 
618
- enqueue_gathers(0, chunk_size)
 
 
 
 
619
  for i in range(0, len(params) + chunk_size - 1, chunk_size):
620
- enqueue_computes(i, chunk_size)
621
- if i > 0:
622
- enqueue_update_param(i - chunk_size, chunk_size)
623
- enqueue_gathers(i + chunk_size, chunk_size)
624
- enqueue_scatters(i, chunk_size)
625
- enqueue_update_param(i, chunk_size)
626
 
627
  # Wait the last update_param to finish
628
  torch.cuda.current_stream().wait_stream(self.compute_stream)
 
8
  import torch.distributed as dist
9
  from torch.distributed._tensor import DTensor, Replicate, Shard
10
 
11
+ from .matmul_transpose_triton import matmul_transpose_assign
12
+
13
  logger = logging.getLogger(__name__)
14
 
15
+ COMM_DTYPE = torch.bfloat16
16
+
17
 
18
  # This code snippet is a modified version adapted from the following GitHub repositories:
19
  # https://github.com/KellerJordan/Muon/blob/master/muon.py
20
  # Muon's Newton–Schulz iteration causes high variance in singular values
21
  # Idea: give each iteration its own 3 coefficients and optimize them via gradient descent.
22
  @torch.no_grad()
23
+ # matmul_transpose_assign from : https://github.com/nil0x9/flash-muon
24
  def _zeropower_via_newtonschulz5(G, steps):
25
  """
26
  Newton-Schulz iteration to compute the zeroth power / orthogonalization of G. We opt to use a
 
32
  performance at all relative to UV^T, where USV^T = G is the SVD.
33
  """
34
  assert len(G.shape) == 2
35
+ assert G.dtype == COMM_DTYPE
36
  X = G # no manual typecast
37
 
38
  if G.size(0) > G.size(1):
39
  X = X.T
40
  # Ensure spectral norm is at most 1
41
  X = X / (X.norm() + 1e-7)
42
+ buf1 = torch.empty(X.size(0), X.size(0), dtype=X.dtype, device=X.device)
43
+ buf2 = torch.empty(X.size(0), X.size(0), dtype=X.dtype, device=X.device)
44
  # Perform the NS iterations
45
  for a, b, c in [
46
  (4.0848, -6.8946, 2.9270),
 
49
  (2.8769, -3.1427, 1.2046),
50
  (2.8366, -3.0525, 1.2012),
51
  ]:
52
+ matmul_transpose_assign(X, buf1)
53
+ matmul_transpose_assign(buf1, buf2)
54
+ buf1.mul_(b).add_(buf2, alpha=c)
55
+ X = torch.addmm(X, buf1, X, alpha=1.0, beta=a)
 
 
 
56
 
57
  if G.size(0) > G.size(1):
58
  X = X.T
 
73
  qk_clip_state = None
74
 
75
 
76
+ def split_elems_for_src(param, src_rank, num_ranks) -> int:
77
+ rows = param.shape[0]
78
+ cols = int(param.numel() // rows)
79
+ base, rem = divmod(rows, num_ranks)
80
+ my_rows = base + (1 if src_rank < rem else 0)
81
+ return my_rows * cols
82
+
83
+
84
  @torch.no_grad()
85
+ def _alloc_gathered_grad(params, param_to_state, rank, compute_stream):
86
  """
87
+ Pre-allocate gathered_grad buffer on compute_stream
88
+ before launching all2all gather
89
+ """
90
+ with torch.cuda.stream(compute_stream):
91
+ for p in params:
92
+ state = param_to_state[id(p)]
93
+ if rank == state.worker_rank:
94
+ num_ranks = dist.get_world_size(group=state.process_group)
95
+ state.gathered_grad = torch.empty(p.grad.numel(),
96
+ dtype=COMM_DTYPE,
97
+ device="cuda")
98
+ else:
99
+ state.gathered_grad = None
100
+
101
+ alloc_event = torch.cuda.Event()
102
+ alloc_event.record(compute_stream)
103
+ return alloc_event
104
+
105
+
106
+ @torch.no_grad()
107
+ def _all2all_gather(params, param_to_state, rank, comm_stream, none_grad,
108
+ alloc_event):
109
+ """
110
+ All2all gathers shards so each owner rank reconstructs its full gradient
111
  """
112
  with torch.cuda.stream(comm_stream):
113
+ process_group = param_to_state[id(params[0])].process_group
114
+ num_ranks = dist.get_world_size(group=process_group)
115
 
116
+ # Construct sending buffers
117
+ per_dst = [[] for _ in range(num_ranks)]
118
+ send_counts = [0] * num_ranks
119
+
120
+ for p in params:
121
+ state = param_to_state[id(p)]
122
+ dst = state.worker_rank
123
+ assert dst < num_ranks
124
+ shard_elems = split_elems_for_src(p, rank, num_ranks)
125
+ g = p.grad
126
+ g = g.to_local().to(COMM_DTYPE).contiguous().view(-1)
127
+ assert g.numel() == shard_elems
128
+ per_dst[dst].append(g)
129
+ send_counts[dst] += shard_elems
130
+
131
+ assert all(
132
+ len(v) > 0
133
+ for v in per_dst), "all params should be sharded to all devices"
134
+
135
+ send_buf = torch.cat([torch.cat(v, dim=0) for v in per_dst], dim=0)
136
+ owned_params = [
137
+ p for p in params if param_to_state[id(p)].worker_rank == rank
138
+ ]
139
+
140
+ # Compute receive sizes and allocate receiving buffers
141
+ recv_counts = [0] * num_ranks
142
+
143
+ for src in range(num_ranks):
144
+ total = 0
145
+ for p in owned_params:
146
+ state = param_to_state[id(p)]
147
+ assert state.worker_rank == rank
148
+ total += split_elems_for_src(p, src, num_ranks)
149
+ recv_counts[src] = total
150
+
151
+ recv_total = sum(recv_counts)
152
+ recv_buf = torch.empty(recv_total, dtype=COMM_DTYPE, device="cuda")
153
+
154
+ #All2All
155
+ dist.all_to_all_single(
156
+ recv_buf,
157
+ send_buf,
158
+ output_split_sizes=recv_counts,
159
+ input_split_sizes=send_counts,
160
+ group=process_group,
161
  )
162
+
163
+ # Reconstructs gathered grad from the received buffer
164
+ #
165
+ # recv_buf (num ranks = 3)
166
+ #
167
+ # From rank 0 From rank 1 From rank 2
168
+ # | p1_0, p2_0, p3_0 | p1_1, p2_1, p3_1 | p1_2, p2_2, p3_2 |
169
+ #
170
+ # Outer loop:
171
+ # rank 0 -> rank 1 -> rank2
172
+ #
173
+ # Inner loop:
174
+ # p1_n -> p2_n -> p3_n
175
+
176
+ comm_stream.wait_event(alloc_event)
177
+
178
+ off = 0
179
+ write_offsets = {id(p): 0 for p in owned_params}
180
+ for src in range(num_ranks):
181
+ if recv_counts[src] == 0:
182
+ continue
183
+
184
+ block = recv_counts[src]
185
+ inner_off = 0
186
+ for p in owned_params:
187
+ state = param_to_state[id(p)]
188
+ assert state.worker_rank == rank
189
+ n = split_elems_for_src(p, src, num_ranks)
190
+ assert n > 0
191
+
192
+ sg = recv_buf.narrow(0, off + inner_off, n)
193
+ woff = write_offsets[id(p)]
194
+ dst = state.gathered_grad.narrow(0, woff, n)
195
+ dst.copy_(sg)
196
+
197
+ write_offsets[id(p)] += n
198
+ inner_off += n
199
+ off += block
200
+
201
+ for p in params:
202
+ state = param_to_state[id(p)]
203
+ if state.worker_rank == rank:
204
+ state.gathered_grad = state.gathered_grad.view_as(p)
205
+ state.gather_event = torch.cuda.Event()
206
+ state.gather_event.record(comm_stream)
207
+ else:
208
+ state.gathered_grad = None
209
+ state.gather_event = None
210
+ if none_grad:
211
+ p.grad = None
212
 
213
 
214
  @torch.no_grad()
 
222
  raise RuntimeError("Gather event must be set before compute.")
223
  compute_stream.wait_event(state.gather_event)
224
  u = _zeropower_via_newtonschulz5(state.gathered_grad, steps)
225
+ state.gathered_grad = None
226
  state.computed_u = u
227
+ state.compute_event = torch.cuda.Event()
228
+ state.compute_event.record()
229
+ else:
230
+ state.computed_u = None
231
+ state.compute_event = None
232
 
233
 
234
  @torch.no_grad()
235
+ def _alloc_scattered_u(params, param_to_state, rank, compute_stream):
236
  """
237
+ Pre-allocate scattered_u buffer on compute_stream
238
+ before launching all2all gather
239
  """
240
+ with torch.cuda.stream(compute_stream):
241
+ for p in params:
242
+ state = param_to_state[id(p)]
243
+ state.scattered_u = torch.empty_like(p.to_local(),
244
+ dtype=COMM_DTYPE)
245
+
246
+ alloc_event = torch.cuda.Event()
247
+ alloc_event.record(compute_stream)
248
+ return alloc_event
249
+
250
 
251
+ def _all2all_scatter(params, param_to_state, rank, comm_stream, alloc_event):
252
+ """
253
+ All2all scatters full gradients to all ranks
254
+ """
255
  with torch.cuda.stream(comm_stream):
256
+ process_group = param_to_state[id(params[0])].process_group
257
+ num_ranks = dist.get_world_size(group=process_group)
258
+ owned_params = [
259
+ p for p in params if param_to_state[id(p)].worker_rank == rank
260
+ ]
261
+
262
+ # Construct sending buffer
263
+ per_dst = [[] for _ in range(num_ranks)]
264
+ send_counts = [0] * num_ranks
265
+
266
+ if owned_params:
267
+ for p in owned_params:
268
+ state = param_to_state[id(p)]
269
+ if state.compute_event is None:
270
+ raise RuntimeError(
271
+ "Compute event must be set before scatter.")
272
+ comm_stream.wait_event(state.compute_event)
273
+ state.gathered_grad = None
274
 
275
+ assert state.computed_u is not None
276
+
277
+ u_full = state.computed_u.to(COMM_DTYPE).contiguous().view(-1)
278
+
279
+ offset = 0
280
+ for dst in range(num_ranks):
281
+ n = split_elems_for_src(p, dst, num_ranks)
282
+ assert n > 0
283
 
284
+ su = u_full.narrow(0, offset, n)
285
+ per_dst[dst].append(su)
286
+ send_counts[dst] += n
287
+ offset += n
288
+
289
+ assert offset == u_full.numel()
290
+
291
+ if any(len(v) > 0 for v in per_dst):
292
+ send_buf = torch.cat([torch.cat(v, dim=0) for v in per_dst], dim=0)
293
  else:
294
+ # all_to_all requires participation from all ranks
295
+ # Even non-owner ranks must join the collective call
296
+ send_buf = torch.empty(0, dtype=COMM_DTYPE, device="cuda")
297
 
298
+ # Compute receive sizes and allocate receiving buffers
299
+ recv_counts = [0] * num_ranks
300
+
301
+ for src in range(num_ranks):
302
+ total = 0
303
+ for p in params:
304
+ state = param_to_state[id(p)]
305
+ if state.worker_rank != src:
306
+ continue
307
+ total += split_elems_for_src(p, rank, num_ranks)
308
+ recv_counts[src] = total
309
+
310
+ recv_total = sum(recv_counts)
311
+ assert recv_total > 0
312
+ recv_buf = torch.empty(recv_total, dtype=COMM_DTYPE, device="cuda")
313
+
314
+ #All2All
315
+ dist.all_to_all_single(
316
+ recv_buf,
317
+ send_buf,
318
+ output_split_sizes=recv_counts,
319
+ input_split_sizes=send_counts,
320
+ group=process_group,
321
  )
322
+
323
+ # Copy to pre-allocated scattered_u buffer from the received buffer
324
+ #
325
+ # recv_buf (num ranks = 3, local_rank = 0)
326
+ #
327
+ # From rank 0 From rank 1 From rank 2
328
+ # | p1_0, p2_0, p3_0 | p4_0 | p5_0, p6_0 |
329
+ #
330
+ # Outer loop:
331
+ # rank 0 -> rank 1 -> rank2
332
+ #
333
+ # Inner loop:
334
+ # src(0) : p1_0 -> p2_0 -> p3_0
335
+ # src(1) : p4_0
336
+ # src(2) : p5_0 -> p6_0
337
+
338
+ comm_stream.wait_event(alloc_event)
339
+
340
+ off = 0
341
+ for src in range(num_ranks):
342
+ block = recv_counts[src]
343
+ if block == 0:
344
+ continue
345
+
346
+ inner_off = 0
347
+ for p in params:
348
+ state = param_to_state[id(p)]
349
+ if state.worker_rank != src:
350
+ continue
351
+ n = split_elems_for_src(p, rank, num_ranks)
352
+ assert n > 0
353
+
354
+ flat_local = recv_buf.narrow(0, off + inner_off,
355
+ n).view_as(p.to_local())
356
+ state.scattered_u.copy_(flat_local)
357
+
358
+ state.scatter_event = torch.cuda.Event()
359
+ state.scatter_event.record(comm_stream)
360
+ inner_off += n
361
+
362
+ assert inner_off == block
363
+ off += block
364
 
365
 
366
  def _update_param(p, state, lr, adjusted_lr, weight_decay, rank,
 
516
  "head_dim": 128,
517
  "threshold": 100
518
  }
519
+ overlap_step : How many all2all gather, compute operations are launched in advance
520
+ before the corresponding all2all scatter steps begin.
521
+ A higher overlap_step increases memory usage but can improve
522
+ performance by overlapping communication.
523
+ Parallel muon only.
524
  """
525
 
526
  def __init__(self,
 
539
  "k_indices": [],
540
  "head_dim": 128,
541
  "threshold": 100
542
+ },
543
+ overlap_step=5):
544
  defaults = dict(
545
  lr=lr,
546
  weight_decay=weight_decay,
 
564
 
565
  super().__init__(params, defaults)
566
 
567
+ self.rank = None
 
 
 
568
 
569
  self.comm_stream = torch.cuda.Stream()
570
  self.compute_stream = torch.cuda.Stream()
571
  self.debug = debug
572
  self.clip_config = clip_config
573
+ self.overlap_step = overlap_step
574
 
575
  def _calc_flops(self, G, steps):
576
  assert len(G.shape) == 2
 
643
  if mesh is None:
644
  mesh = p.device_mesh
645
  shard_mesh, process_group = self.get_shard_mesh(p, self.rank)
646
+ local_rank = dist.get_rank(group=process_group)
647
+ if self.rank is None:
648
+ self.rank = dist.get_rank(group=process_group)
649
+ else:
650
+ assert self.rank == local_rank
651
  elif mesh != p.device_mesh:
652
  raise ValueError("All parameters must be on the same mesh.")
653
 
654
+ num_ranks = dist.get_world_size(group=process_group)
655
  param_to_state[id(p)] = _muon_state()
656
+ param_to_state[id(
657
+ p)].worker_rank = shard_mesh[round_robin].item() % num_ranks
658
  param_to_state[id(p)].process_group = process_group
659
  qk_clip_state = self.get_qk_clip_info(n, qk_logits)
660
  param_to_state[id(p)].qk_clip_state = qk_clip_state
 
684
  else:
685
  g = buf
686
 
687
+ u = _zeropower_via_newtonschulz5(g.to(COMM_DTYPE),
688
  steps=group["ns_steps"])
689
 
690
  adjusted_lr = self.adjust_lr_for_muon(lr, p.shape)
 
699
  def _update_g(self, p, g, group, momentum):
700
  # calc update
701
  state = self.state[p]
702
+ buf = state.setdefault("momentum_buffer", torch.zeros_like(g))
703
+ torch.add(g, buf, alpha=momentum, out=buf)
 
 
704
  if group["nesterov"]:
705
+ g.add_(buf, alpha=momentum)
706
+ return g
707
+ return buf
 
708
 
709
  @staticmethod
710
  def _update_p(p, u, lr, adjusted_lr, weight_decay):
 
788
  param_to_state, ordered_params = self.init_state_and_assign_params(
789
  names, params, group, qk_logits)
790
 
791
+ assert self.rank is not None
792
+
793
+ def enqueue_all2all_gather(start_idx, chunk_size):
794
+ target_params = ordered_params[start_idx:start_idx + chunk_size]
795
+ if target_params:
796
+ alloc_event = _alloc_gathered_grad(target_params,
797
+ param_to_state, self.rank,
798
+ self.compute_stream)
799
+ _all2all_gather(target_params, param_to_state, self.rank,
800
+ self.comm_stream, group["none_grad"],
801
+ alloc_event)
802
 
803
  def enqueue_computes(start_idx, chunk_size):
804
  for p in ordered_params[start_idx:start_idx + chunk_size]:
 
806
  _compute_u(p, state, group["ns_steps"], self.rank,
807
  self.compute_stream)
808
 
809
+ def enqueue_all2all_scatter(start_idx, chunk_size):
810
+ target_params = ordered_params[start_idx:start_idx + chunk_size]
811
+ if target_params:
812
+ alloc_event = _alloc_scattered_u(target_params, param_to_state,
813
+ self.rank,
814
+ self.compute_stream)
815
+ _all2all_scatter(target_params, param_to_state, self.rank,
816
+ self.comm_stream, alloc_event)
817
 
818
  def enqueue_update_param(start_idx, chunk_size):
819
  for p in ordered_params[start_idx:start_idx + chunk_size]:
 
828
  # Wait grad update
829
  self.comm_stream.wait_stream(torch.cuda.current_stream())
830
 
831
+ overlap_step = self.overlap_step
832
+ for i in range(0, overlap_step):
833
+ enqueue_all2all_gather(i * chunk_size, chunk_size)
834
+ enqueue_computes(i * chunk_size, chunk_size)
835
+
836
  for i in range(0, len(params) + chunk_size - 1, chunk_size):
837
+ enqueue_all2all_scatter(i, chunk_size)
838
+ enqueue_all2all_gather(i + overlap_step * chunk_size, chunk_size)
839
+ enqueue_update_param(i, chunk_size)
840
+ enqueue_computes(i + overlap_step * chunk_size, chunk_size)
 
 
841
 
842
  # Wait the last update_param to finish
843
  torch.cuda.current_stream().wait_stream(self.compute_stream)
build/torch28-cxx11-cu126-x86_64-linux/optimizer/_ops.py CHANGED
@@ -1,9 +1,9 @@
1
  import torch
2
- from . import _optimizer_9c21645_dirty
3
- ops = torch.ops._optimizer_9c21645_dirty
4
 
5
  def add_op_namespace_prefix(op_name: str):
6
  """
7
  Prefix op by namespace.
8
  """
9
- return f"_optimizer_9c21645_dirty::{op_name}"
 
1
  import torch
2
+ from . import _optimizer_15336dc_dirty
3
+ ops = torch.ops._optimizer_15336dc_dirty
4
 
5
  def add_op_namespace_prefix(op_name: str):
6
  """
7
  Prefix op by namespace.
8
  """
9
+ return f"_optimizer_15336dc_dirty::{op_name}"
build/{torch27-cxx11-cu126-x86_64-linux/optimizer/_optimizer_9c21645_dirty.abi3.so → torch28-cxx11-cu126-x86_64-linux/optimizer/_optimizer_15336dc_dirty.abi3.so} RENAMED
@@ -1,3 +1,3 @@
1
  version https://git-lfs.github.com/spec/v1
2
- oid sha256:42ae6ac1cf967d7d23cac7930c8db635105f60631220a60b9cee060d082f40ae
3
  size 1824256
 
1
  version https://git-lfs.github.com/spec/v1
2
+ oid sha256:8ca6ca8225dc9b7888566f5c7fd824234a3b4ac76718a5d18e6c75ca7acd488d
3
  size 1824256
build/torch28-cxx11-cu126-x86_64-linux/optimizer/matmul_transpose_triton.py ADDED
@@ -0,0 +1,128 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # MIT License
2
+ #
3
+ # Copyright (c) 2025 Tianyang Lin
4
+ #
5
+ # Permission is hereby granted, free of charge, to any person obtaining a copy
6
+ # of this software and associated documentation files (the "Software"), to deal
7
+ # in the Software without restriction, including without limitation the rights
8
+ # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9
+ # copies of the Software, and to permit persons to whom the Software is
10
+ # furnished to do so, subject to the following conditions:
11
+ #
12
+ # The above copyright notice and this permission notice shall be included in all
13
+ # copies or substantial portions of the Software.
14
+ #
15
+ # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16
+ # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17
+ # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18
+ # AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19
+ # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20
+ # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21
+ # SOFTWARE.
22
+
23
+ import torch
24
+ import triton
25
+ import triton.language as tl
26
+
27
+
28
+ def get_autotune_config():
29
+ return [
30
+ triton.Config(
31
+ {
32
+ 'BLOCK_SIZE_M': blk_m,
33
+ 'BLOCK_SIZE_K': blk_k,
34
+ 'GROUP_SIZE_M': grp_sz
35
+ },
36
+ num_stages=n_stages,
37
+ num_warps=n_warps) for blk_m in [32, 64, 128]
38
+ for blk_k in [32, 64] for grp_sz in [8] for n_stages in [3, 4, 5]
39
+ for n_warps in [4, 8]
40
+ ]
41
+
42
+
43
+ @triton.autotune(
44
+ configs=get_autotune_config(),
45
+ key=['M', 'K'],
46
+ )
47
+ @triton.jit
48
+ def mmt_kernel(x, y, M, K, stride_xm, stride_xk, stride_ym, stride_yn,
49
+ BLOCK_SIZE_M: tl.constexpr, BLOCK_SIZE_K: tl.constexpr,
50
+ GROUP_SIZE_M: tl.constexpr):
51
+ """
52
+ Core kernel jit function of matmul_transpose that computes y = x @ x.T
53
+ The code is a simple adaptation from the triton `matmul` tutorial:
54
+ https://triton-lang.org/main/getting-started/tutorials/03-matrix-multiplication.html
55
+ """
56
+ pid = tl.program_id(axis=0)
57
+ num_pid_m = tl.cdiv(M, BLOCK_SIZE_M)
58
+ num_pid_n = tl.cdiv(M, BLOCK_SIZE_M)
59
+ num_pid_in_group = GROUP_SIZE_M * num_pid_n
60
+ group_id = pid // num_pid_in_group
61
+ first_pid_m = group_id * GROUP_SIZE_M
62
+ group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M)
63
+ pid_m = first_pid_m + ((pid % num_pid_in_group) % group_size_m)
64
+ pid_n = (pid % num_pid_in_group) // group_size_m
65
+ if pid_m > pid_n:
66
+ return
67
+
68
+ offs_xm = (pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)) % M
69
+ offs_xn = (pid_n * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)) % M
70
+ offs_k = tl.arange(0, BLOCK_SIZE_K)
71
+ # we use a & b ptrs to denote different rows of x.
72
+ a_ptrs = x + (offs_xm[:, None] * stride_xm + offs_k[None, :] * stride_xk)
73
+ b_ptrs = x + (offs_xn[:, None] * stride_xm + offs_k[None, :] * stride_xk)
74
+
75
+ accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_M), dtype=tl.float32)
76
+
77
+ for k in range(0, tl.cdiv(K, BLOCK_SIZE_K)):
78
+ a = tl.load(a_ptrs,
79
+ mask=offs_k[None, :] < K - k * BLOCK_SIZE_K,
80
+ other=0.0)
81
+ b = tl.load(b_ptrs,
82
+ mask=offs_k[None, :] < K - k * BLOCK_SIZE_K,
83
+ other=0.0)
84
+ accumulator = tl.dot(a, tl.permute(b, (1, 0)), accumulator)
85
+ a_ptrs += BLOCK_SIZE_K * stride_xk
86
+ b_ptrs += BLOCK_SIZE_K * stride_xk
87
+ # use dtype.element_ty to accommodate different input datatypes as in cpp templates
88
+ # https://github.com/triton-lang/triton/issues/2252
89
+ c = accumulator.to(x.dtype.element_ty)
90
+
91
+ offs_cm = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
92
+ offs_cn = pid_n * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
93
+ c_ptrs = y + stride_ym * offs_cm[:, None] + stride_yn * offs_cn[None, :]
94
+ c_mask = (offs_cm[:, None] < M) & (offs_cn[None, :] < M)
95
+ tl.store(c_ptrs, c, mask=c_mask)
96
+
97
+ # transpose and copy
98
+ if pid_m < pid_n:
99
+ ct_ptrs = y + stride_ym * offs_cn[:,
100
+ None] + stride_yn * offs_cm[None, :]
101
+ ct_mask = (offs_cn[:, None] < M) & (offs_cm[None, :] < M)
102
+ tl.store(ct_ptrs, tl.permute(c, (1, 0)), mask=ct_mask)
103
+
104
+
105
+ def matmul_transpose_assign(d_in, d_out):
106
+ assert d_in.is_cuda, "Input `d_in` must be a CUDA tensor"
107
+ assert d_out.is_cuda, "Input `d_out` must be a CUDA tensor"
108
+ assert d_in.device == d_out.device, "Inputs `d_in` and `d_out` must be on the same CUDA device"
109
+ assert d_in.dtype == d_out.dtype, "Inputs must have the same data type"
110
+ assert d_in.ndim == 2, "Input `d_in` must be a 2D tensor"
111
+ assert d_out.ndim == 2, "Input `d_out` must be a 2D tensor"
112
+ assert d_in.size(0) == d_out.size(0) == d_out.size(0), \
113
+ "First dimension of `d_in` must match first and second dimension of `d_out`"
114
+
115
+ d_in = d_in.contiguous()
116
+ M, K = d_in.shape
117
+ grid = lambda META: (triton.cdiv(M, META['BLOCK_SIZE_M']) * triton.cdiv(
118
+ M, META['BLOCK_SIZE_M']), )
119
+ with torch.cuda.device(d_in.device.index):
120
+ mmt_kernel[grid](d_in, d_out, M, K, d_in.stride(0), d_in.stride(1),
121
+ d_out.stride(0), d_out.stride(1))
122
+
123
+
124
+ def matmul_transpose(d_in):
125
+ M, _ = d_in.shape
126
+ d_out = torch.empty((M, M), device=d_in.device, dtype=d_in.dtype)
127
+ matmul_transpose_assign(d_in, d_out)
128
+ return d_out
build/torch28-cxx11-cu126-x86_64-linux/optimizer/muon.py CHANGED
@@ -8,14 +8,19 @@ import torch
8
  import torch.distributed as dist
9
  from torch.distributed._tensor import DTensor, Replicate, Shard
10
 
 
 
11
  logger = logging.getLogger(__name__)
12
 
 
 
13
 
14
  # This code snippet is a modified version adapted from the following GitHub repositories:
15
  # https://github.com/KellerJordan/Muon/blob/master/muon.py
16
  # Muon's Newton–Schulz iteration causes high variance in singular values
17
  # Idea: give each iteration its own 3 coefficients and optimize them via gradient descent.
18
  @torch.no_grad()
 
19
  def _zeropower_via_newtonschulz5(G, steps):
20
  """
21
  Newton-Schulz iteration to compute the zeroth power / orthogonalization of G. We opt to use a
@@ -27,13 +32,15 @@ def _zeropower_via_newtonschulz5(G, steps):
27
  performance at all relative to UV^T, where USV^T = G is the SVD.
28
  """
29
  assert len(G.shape) == 2
30
- assert G.dtype == torch.bfloat16
31
  X = G # no manual typecast
32
 
33
  if G.size(0) > G.size(1):
34
  X = X.T
35
  # Ensure spectral norm is at most 1
36
  X = X / (X.norm() + 1e-7)
 
 
37
  # Perform the NS iterations
38
  for a, b, c in [
39
  (4.0848, -6.8946, 2.9270),
@@ -42,13 +49,10 @@ def _zeropower_via_newtonschulz5(G, steps):
42
  (2.8769, -3.1427, 1.2046),
43
  (2.8366, -3.0525, 1.2012),
44
  ]:
45
- A = X @ X.T
46
- # B = (
47
- # b * A + c * A @ A
48
- # )
49
- B = torch.addmm(A, A, A, alpha=c, beta=b)
50
- # X = a * X + B @ X
51
- X = torch.addmm(X, B, X, alpha=1.0, beta=a)
52
 
53
  if G.size(0) > G.size(1):
54
  X = X.T
@@ -69,51 +73,142 @@ class _muon_state:
69
  qk_clip_state = None
70
 
71
 
 
 
 
 
 
 
 
 
72
  @torch.no_grad()
73
- def _gather(p, state, rank, comm_stream, none_grad):
74
  """
75
- Gather the gradients to worker_rank.
76
- If none_grad is True, free p.grad after the gather.
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
77
  """
78
  with torch.cuda.stream(comm_stream):
79
- g = p.grad
 
80
 
81
- if rank == state.worker_rank:
82
- num_ranks = dist.get_world_size(group=state.process_group)
83
- gather_list = [
84
- torch.empty_like(g.to_local(), dtype=torch.bfloat16)
85
- for _ in range(num_ranks)
86
- ]
87
- else:
88
- gather_list = None
89
-
90
- g = g.to(torch.bfloat16)
91
- torch.distributed.gather(
92
- g.to_local(),
93
- dst=state.worker_rank,
94
- gather_list=gather_list,
95
- group=state.process_group,
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
96
  )
97
- if rank == state.worker_rank:
98
- if state.gathered_grad is not None:
99
- raise RuntimeError(
100
- "Gather event already exists, which should not happen.")
101
- state.gathered_grad = torch.cat(gather_list, dim=0)
102
- state.gather_event = torch.cuda.Event()
103
- state.gather_event.record()
104
- else:
105
- state.gathered_grad = None
106
- state.gather_event = None
107
- gather_list = None
108
- if none_grad:
109
- # We can safely free p.grad without calling record_stream:
110
- # p.grad.to_local().record_stream(comm_stream)
111
- # Explanation:
112
- # 1. p.grad is created on the default stream, but the default stream
113
- # is synchronized with the comm stream later.
114
- # 2. There is no further activity on the default stream before the optimizer finishes.
115
- # Therefore, it is safe to free p.grad directly on the comm stream.
116
- p.grad = None
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
117
 
118
 
119
  @torch.no_grad()
@@ -127,45 +222,145 @@ def _compute_u(p, state, steps, rank, compute_stream):
127
  raise RuntimeError("Gather event must be set before compute.")
128
  compute_stream.wait_event(state.gather_event)
129
  u = _zeropower_via_newtonschulz5(state.gathered_grad, steps)
 
130
  state.computed_u = u
131
- state.scattered_u = torch.empty_like(p.to_local(),
132
- dtype=torch.bfloat16)
133
- state.compute_event = torch.cuda.Event()
134
- state.compute_event.record()
135
- u = None
136
 
137
 
138
  @torch.no_grad()
139
- def _scatter(p, state, rank, comm_stream):
140
  """
141
- Scatter the computed_u from worker_rank to all ranks.
 
142
  """
 
 
 
 
 
 
 
 
 
 
143
 
 
 
 
 
144
  with torch.cuda.stream(comm_stream):
145
- if state.compute_event is None:
146
- raise RuntimeError("Compute event must be set before scatter.")
147
- comm_stream.wait_event(state.compute_event)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
148
 
149
- if rank == state.worker_rank:
150
- num_ranks = dist.get_world_size(group=state.process_group)
151
- # Clear the gathered gradient to free memory
152
- state.gathered_grad = None
 
 
 
 
153
 
154
- u = state.computed_u
155
- scatter_list = list(torch.split(u, p.size(0) // num_ranks, dim=0))
156
- scatter_list = [s.contiguous() for s in scatter_list]
 
 
 
 
 
 
157
  else:
158
- scatter_list = None
 
 
159
 
160
- torch.distributed.scatter(
161
- state.scattered_u,
162
- scatter_list=scatter_list,
163
- src=state.worker_rank,
164
- group=state.process_group,
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
165
  )
166
- state.scatter_event = torch.cuda.Event()
167
- state.scatter_event.record()
168
- scatter_list = None
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
169
 
170
 
171
  def _update_param(p, state, lr, adjusted_lr, weight_decay, rank,
@@ -321,6 +516,11 @@ class Muon(torch.optim.Optimizer):
321
  "head_dim": 128,
322
  "threshold": 100
323
  }
 
 
 
 
 
324
  """
325
 
326
  def __init__(self,
@@ -339,7 +539,8 @@ class Muon(torch.optim.Optimizer):
339
  "k_indices": [],
340
  "head_dim": 128,
341
  "threshold": 100
342
- }):
 
343
  defaults = dict(
344
  lr=lr,
345
  weight_decay=weight_decay,
@@ -363,15 +564,13 @@ class Muon(torch.optim.Optimizer):
363
 
364
  super().__init__(params, defaults)
365
 
366
- if dist.is_initialized():
367
- self.rank = dist.get_rank()
368
- else:
369
- self.rank = None
370
 
371
  self.comm_stream = torch.cuda.Stream()
372
  self.compute_stream = torch.cuda.Stream()
373
  self.debug = debug
374
  self.clip_config = clip_config
 
375
 
376
  def _calc_flops(self, G, steps):
377
  assert len(G.shape) == 2
@@ -444,11 +643,18 @@ class Muon(torch.optim.Optimizer):
444
  if mesh is None:
445
  mesh = p.device_mesh
446
  shard_mesh, process_group = self.get_shard_mesh(p, self.rank)
 
 
 
 
 
447
  elif mesh != p.device_mesh:
448
  raise ValueError("All parameters must be on the same mesh.")
449
 
 
450
  param_to_state[id(p)] = _muon_state()
451
- param_to_state[id(p)].worker_rank = shard_mesh[round_robin].item()
 
452
  param_to_state[id(p)].process_group = process_group
453
  qk_clip_state = self.get_qk_clip_info(n, qk_logits)
454
  param_to_state[id(p)].qk_clip_state = qk_clip_state
@@ -478,7 +684,7 @@ class Muon(torch.optim.Optimizer):
478
  else:
479
  g = buf
480
 
481
- u = _zeropower_via_newtonschulz5(g.bfloat16(),
482
  steps=group["ns_steps"])
483
 
484
  adjusted_lr = self.adjust_lr_for_muon(lr, p.shape)
@@ -493,15 +699,12 @@ class Muon(torch.optim.Optimizer):
493
  def _update_g(self, p, g, group, momentum):
494
  # calc update
495
  state = self.state[p]
496
- if "momentum_buffer" not in state:
497
- state["momentum_buffer"] = torch.zeros_like(g)
498
- buf = state["momentum_buffer"]
499
- buf.mul_(momentum).add_(g)
500
  if group["nesterov"]:
501
- g = g.add(buf, alpha=momentum)
502
- else:
503
- g = buf
504
- return g
505
 
506
  @staticmethod
507
  def _update_p(p, u, lr, adjusted_lr, weight_decay):
@@ -585,11 +788,17 @@ class Muon(torch.optim.Optimizer):
585
  param_to_state, ordered_params = self.init_state_and_assign_params(
586
  names, params, group, qk_logits)
587
 
588
- def enqueue_gathers(start_idx, chunk_size):
589
- for p in ordered_params[start_idx:start_idx + chunk_size]:
590
- state = param_to_state[id(p)]
591
- _gather(p, state, self.rank, self.comm_stream,
592
- group["none_grad"])
 
 
 
 
 
 
593
 
594
  def enqueue_computes(start_idx, chunk_size):
595
  for p in ordered_params[start_idx:start_idx + chunk_size]:
@@ -597,10 +806,14 @@ class Muon(torch.optim.Optimizer):
597
  _compute_u(p, state, group["ns_steps"], self.rank,
598
  self.compute_stream)
599
 
600
- def enqueue_scatters(start_idx, chunk_size):
601
- for p in ordered_params[start_idx:start_idx + chunk_size]:
602
- state = param_to_state[id(p)]
603
- _scatter(p, state, self.rank, self.comm_stream)
 
 
 
 
604
 
605
  def enqueue_update_param(start_idx, chunk_size):
606
  for p in ordered_params[start_idx:start_idx + chunk_size]:
@@ -615,14 +828,16 @@ class Muon(torch.optim.Optimizer):
615
  # Wait grad update
616
  self.comm_stream.wait_stream(torch.cuda.current_stream())
617
 
618
- enqueue_gathers(0, chunk_size)
 
 
 
 
619
  for i in range(0, len(params) + chunk_size - 1, chunk_size):
620
- enqueue_computes(i, chunk_size)
621
- if i > 0:
622
- enqueue_update_param(i - chunk_size, chunk_size)
623
- enqueue_gathers(i + chunk_size, chunk_size)
624
- enqueue_scatters(i, chunk_size)
625
- enqueue_update_param(i, chunk_size)
626
 
627
  # Wait the last update_param to finish
628
  torch.cuda.current_stream().wait_stream(self.compute_stream)
 
8
  import torch.distributed as dist
9
  from torch.distributed._tensor import DTensor, Replicate, Shard
10
 
11
+ from .matmul_transpose_triton import matmul_transpose_assign
12
+
13
  logger = logging.getLogger(__name__)
14
 
15
+ COMM_DTYPE = torch.bfloat16
16
+
17
 
18
  # This code snippet is a modified version adapted from the following GitHub repositories:
19
  # https://github.com/KellerJordan/Muon/blob/master/muon.py
20
  # Muon's Newton–Schulz iteration causes high variance in singular values
21
  # Idea: give each iteration its own 3 coefficients and optimize them via gradient descent.
22
  @torch.no_grad()
23
+ # matmul_transpose_assign from : https://github.com/nil0x9/flash-muon
24
  def _zeropower_via_newtonschulz5(G, steps):
25
  """
26
  Newton-Schulz iteration to compute the zeroth power / orthogonalization of G. We opt to use a
 
32
  performance at all relative to UV^T, where USV^T = G is the SVD.
33
  """
34
  assert len(G.shape) == 2
35
+ assert G.dtype == COMM_DTYPE
36
  X = G # no manual typecast
37
 
38
  if G.size(0) > G.size(1):
39
  X = X.T
40
  # Ensure spectral norm is at most 1
41
  X = X / (X.norm() + 1e-7)
42
+ buf1 = torch.empty(X.size(0), X.size(0), dtype=X.dtype, device=X.device)
43
+ buf2 = torch.empty(X.size(0), X.size(0), dtype=X.dtype, device=X.device)
44
  # Perform the NS iterations
45
  for a, b, c in [
46
  (4.0848, -6.8946, 2.9270),
 
49
  (2.8769, -3.1427, 1.2046),
50
  (2.8366, -3.0525, 1.2012),
51
  ]:
52
+ matmul_transpose_assign(X, buf1)
53
+ matmul_transpose_assign(buf1, buf2)
54
+ buf1.mul_(b).add_(buf2, alpha=c)
55
+ X = torch.addmm(X, buf1, X, alpha=1.0, beta=a)
 
 
 
56
 
57
  if G.size(0) > G.size(1):
58
  X = X.T
 
73
  qk_clip_state = None
74
 
75
 
76
+ def split_elems_for_src(param, src_rank, num_ranks) -> int:
77
+ rows = param.shape[0]
78
+ cols = int(param.numel() // rows)
79
+ base, rem = divmod(rows, num_ranks)
80
+ my_rows = base + (1 if src_rank < rem else 0)
81
+ return my_rows * cols
82
+
83
+
84
  @torch.no_grad()
85
+ def _alloc_gathered_grad(params, param_to_state, rank, compute_stream):
86
  """
87
+ Pre-allocate gathered_grad buffer on compute_stream
88
+ before launching all2all gather
89
+ """
90
+ with torch.cuda.stream(compute_stream):
91
+ for p in params:
92
+ state = param_to_state[id(p)]
93
+ if rank == state.worker_rank:
94
+ num_ranks = dist.get_world_size(group=state.process_group)
95
+ state.gathered_grad = torch.empty(p.grad.numel(),
96
+ dtype=COMM_DTYPE,
97
+ device="cuda")
98
+ else:
99
+ state.gathered_grad = None
100
+
101
+ alloc_event = torch.cuda.Event()
102
+ alloc_event.record(compute_stream)
103
+ return alloc_event
104
+
105
+
106
+ @torch.no_grad()
107
+ def _all2all_gather(params, param_to_state, rank, comm_stream, none_grad,
108
+ alloc_event):
109
+ """
110
+ All2all gathers shards so each owner rank reconstructs its full gradient
111
  """
112
  with torch.cuda.stream(comm_stream):
113
+ process_group = param_to_state[id(params[0])].process_group
114
+ num_ranks = dist.get_world_size(group=process_group)
115
 
116
+ # Construct sending buffers
117
+ per_dst = [[] for _ in range(num_ranks)]
118
+ send_counts = [0] * num_ranks
119
+
120
+ for p in params:
121
+ state = param_to_state[id(p)]
122
+ dst = state.worker_rank
123
+ assert dst < num_ranks
124
+ shard_elems = split_elems_for_src(p, rank, num_ranks)
125
+ g = p.grad
126
+ g = g.to_local().to(COMM_DTYPE).contiguous().view(-1)
127
+ assert g.numel() == shard_elems
128
+ per_dst[dst].append(g)
129
+ send_counts[dst] += shard_elems
130
+
131
+ assert all(
132
+ len(v) > 0
133
+ for v in per_dst), "all params should be sharded to all devices"
134
+
135
+ send_buf = torch.cat([torch.cat(v, dim=0) for v in per_dst], dim=0)
136
+ owned_params = [
137
+ p for p in params if param_to_state[id(p)].worker_rank == rank
138
+ ]
139
+
140
+ # Compute receive sizes and allocate receiving buffers
141
+ recv_counts = [0] * num_ranks
142
+
143
+ for src in range(num_ranks):
144
+ total = 0
145
+ for p in owned_params:
146
+ state = param_to_state[id(p)]
147
+ assert state.worker_rank == rank
148
+ total += split_elems_for_src(p, src, num_ranks)
149
+ recv_counts[src] = total
150
+
151
+ recv_total = sum(recv_counts)
152
+ recv_buf = torch.empty(recv_total, dtype=COMM_DTYPE, device="cuda")
153
+
154
+ #All2All
155
+ dist.all_to_all_single(
156
+ recv_buf,
157
+ send_buf,
158
+ output_split_sizes=recv_counts,
159
+ input_split_sizes=send_counts,
160
+ group=process_group,
161
  )
162
+
163
+ # Reconstructs gathered grad from the received buffer
164
+ #
165
+ # recv_buf (num ranks = 3)
166
+ #
167
+ # From rank 0 From rank 1 From rank 2
168
+ # | p1_0, p2_0, p3_0 | p1_1, p2_1, p3_1 | p1_2, p2_2, p3_2 |
169
+ #
170
+ # Outer loop:
171
+ # rank 0 -> rank 1 -> rank2
172
+ #
173
+ # Inner loop:
174
+ # p1_n -> p2_n -> p3_n
175
+
176
+ comm_stream.wait_event(alloc_event)
177
+
178
+ off = 0
179
+ write_offsets = {id(p): 0 for p in owned_params}
180
+ for src in range(num_ranks):
181
+ if recv_counts[src] == 0:
182
+ continue
183
+
184
+ block = recv_counts[src]
185
+ inner_off = 0
186
+ for p in owned_params:
187
+ state = param_to_state[id(p)]
188
+ assert state.worker_rank == rank
189
+ n = split_elems_for_src(p, src, num_ranks)
190
+ assert n > 0
191
+
192
+ sg = recv_buf.narrow(0, off + inner_off, n)
193
+ woff = write_offsets[id(p)]
194
+ dst = state.gathered_grad.narrow(0, woff, n)
195
+ dst.copy_(sg)
196
+
197
+ write_offsets[id(p)] += n
198
+ inner_off += n
199
+ off += block
200
+
201
+ for p in params:
202
+ state = param_to_state[id(p)]
203
+ if state.worker_rank == rank:
204
+ state.gathered_grad = state.gathered_grad.view_as(p)
205
+ state.gather_event = torch.cuda.Event()
206
+ state.gather_event.record(comm_stream)
207
+ else:
208
+ state.gathered_grad = None
209
+ state.gather_event = None
210
+ if none_grad:
211
+ p.grad = None
212
 
213
 
214
  @torch.no_grad()
 
222
  raise RuntimeError("Gather event must be set before compute.")
223
  compute_stream.wait_event(state.gather_event)
224
  u = _zeropower_via_newtonschulz5(state.gathered_grad, steps)
225
+ state.gathered_grad = None
226
  state.computed_u = u
227
+ state.compute_event = torch.cuda.Event()
228
+ state.compute_event.record()
229
+ else:
230
+ state.computed_u = None
231
+ state.compute_event = None
232
 
233
 
234
  @torch.no_grad()
235
+ def _alloc_scattered_u(params, param_to_state, rank, compute_stream):
236
  """
237
+ Pre-allocate scattered_u buffer on compute_stream
238
+ before launching all2all gather
239
  """
240
+ with torch.cuda.stream(compute_stream):
241
+ for p in params:
242
+ state = param_to_state[id(p)]
243
+ state.scattered_u = torch.empty_like(p.to_local(),
244
+ dtype=COMM_DTYPE)
245
+
246
+ alloc_event = torch.cuda.Event()
247
+ alloc_event.record(compute_stream)
248
+ return alloc_event
249
+
250
 
251
+ def _all2all_scatter(params, param_to_state, rank, comm_stream, alloc_event):
252
+ """
253
+ All2all scatters full gradients to all ranks
254
+ """
255
  with torch.cuda.stream(comm_stream):
256
+ process_group = param_to_state[id(params[0])].process_group
257
+ num_ranks = dist.get_world_size(group=process_group)
258
+ owned_params = [
259
+ p for p in params if param_to_state[id(p)].worker_rank == rank
260
+ ]
261
+
262
+ # Construct sending buffer
263
+ per_dst = [[] for _ in range(num_ranks)]
264
+ send_counts = [0] * num_ranks
265
+
266
+ if owned_params:
267
+ for p in owned_params:
268
+ state = param_to_state[id(p)]
269
+ if state.compute_event is None:
270
+ raise RuntimeError(
271
+ "Compute event must be set before scatter.")
272
+ comm_stream.wait_event(state.compute_event)
273
+ state.gathered_grad = None
274
 
275
+ assert state.computed_u is not None
276
+
277
+ u_full = state.computed_u.to(COMM_DTYPE).contiguous().view(-1)
278
+
279
+ offset = 0
280
+ for dst in range(num_ranks):
281
+ n = split_elems_for_src(p, dst, num_ranks)
282
+ assert n > 0
283
 
284
+ su = u_full.narrow(0, offset, n)
285
+ per_dst[dst].append(su)
286
+ send_counts[dst] += n
287
+ offset += n
288
+
289
+ assert offset == u_full.numel()
290
+
291
+ if any(len(v) > 0 for v in per_dst):
292
+ send_buf = torch.cat([torch.cat(v, dim=0) for v in per_dst], dim=0)
293
  else:
294
+ # all_to_all requires participation from all ranks
295
+ # Even non-owner ranks must join the collective call
296
+ send_buf = torch.empty(0, dtype=COMM_DTYPE, device="cuda")
297
 
298
+ # Compute receive sizes and allocate receiving buffers
299
+ recv_counts = [0] * num_ranks
300
+
301
+ for src in range(num_ranks):
302
+ total = 0
303
+ for p in params:
304
+ state = param_to_state[id(p)]
305
+ if state.worker_rank != src:
306
+ continue
307
+ total += split_elems_for_src(p, rank, num_ranks)
308
+ recv_counts[src] = total
309
+
310
+ recv_total = sum(recv_counts)
311
+ assert recv_total > 0
312
+ recv_buf = torch.empty(recv_total, dtype=COMM_DTYPE, device="cuda")
313
+
314
+ #All2All
315
+ dist.all_to_all_single(
316
+ recv_buf,
317
+ send_buf,
318
+ output_split_sizes=recv_counts,
319
+ input_split_sizes=send_counts,
320
+ group=process_group,
321
  )
322
+
323
+ # Copy to pre-allocated scattered_u buffer from the received buffer
324
+ #
325
+ # recv_buf (num ranks = 3, local_rank = 0)
326
+ #
327
+ # From rank 0 From rank 1 From rank 2
328
+ # | p1_0, p2_0, p3_0 | p4_0 | p5_0, p6_0 |
329
+ #
330
+ # Outer loop:
331
+ # rank 0 -> rank 1 -> rank2
332
+ #
333
+ # Inner loop:
334
+ # src(0) : p1_0 -> p2_0 -> p3_0
335
+ # src(1) : p4_0
336
+ # src(2) : p5_0 -> p6_0
337
+
338
+ comm_stream.wait_event(alloc_event)
339
+
340
+ off = 0
341
+ for src in range(num_ranks):
342
+ block = recv_counts[src]
343
+ if block == 0:
344
+ continue
345
+
346
+ inner_off = 0
347
+ for p in params:
348
+ state = param_to_state[id(p)]
349
+ if state.worker_rank != src:
350
+ continue
351
+ n = split_elems_for_src(p, rank, num_ranks)
352
+ assert n > 0
353
+
354
+ flat_local = recv_buf.narrow(0, off + inner_off,
355
+ n).view_as(p.to_local())
356
+ state.scattered_u.copy_(flat_local)
357
+
358
+ state.scatter_event = torch.cuda.Event()
359
+ state.scatter_event.record(comm_stream)
360
+ inner_off += n
361
+
362
+ assert inner_off == block
363
+ off += block
364
 
365
 
366
  def _update_param(p, state, lr, adjusted_lr, weight_decay, rank,
 
516
  "head_dim": 128,
517
  "threshold": 100
518
  }
519
+ overlap_step : How many all2all gather, compute operations are launched in advance
520
+ before the corresponding all2all scatter steps begin.
521
+ A higher overlap_step increases memory usage but can improve
522
+ performance by overlapping communication.
523
+ Parallel muon only.
524
  """
525
 
526
  def __init__(self,
 
539
  "k_indices": [],
540
  "head_dim": 128,
541
  "threshold": 100
542
+ },
543
+ overlap_step=5):
544
  defaults = dict(
545
  lr=lr,
546
  weight_decay=weight_decay,
 
564
 
565
  super().__init__(params, defaults)
566
 
567
+ self.rank = None
 
 
 
568
 
569
  self.comm_stream = torch.cuda.Stream()
570
  self.compute_stream = torch.cuda.Stream()
571
  self.debug = debug
572
  self.clip_config = clip_config
573
+ self.overlap_step = overlap_step
574
 
575
  def _calc_flops(self, G, steps):
576
  assert len(G.shape) == 2
 
643
  if mesh is None:
644
  mesh = p.device_mesh
645
  shard_mesh, process_group = self.get_shard_mesh(p, self.rank)
646
+ local_rank = dist.get_rank(group=process_group)
647
+ if self.rank is None:
648
+ self.rank = dist.get_rank(group=process_group)
649
+ else:
650
+ assert self.rank == local_rank
651
  elif mesh != p.device_mesh:
652
  raise ValueError("All parameters must be on the same mesh.")
653
 
654
+ num_ranks = dist.get_world_size(group=process_group)
655
  param_to_state[id(p)] = _muon_state()
656
+ param_to_state[id(
657
+ p)].worker_rank = shard_mesh[round_robin].item() % num_ranks
658
  param_to_state[id(p)].process_group = process_group
659
  qk_clip_state = self.get_qk_clip_info(n, qk_logits)
660
  param_to_state[id(p)].qk_clip_state = qk_clip_state
 
684
  else:
685
  g = buf
686
 
687
+ u = _zeropower_via_newtonschulz5(g.to(COMM_DTYPE),
688
  steps=group["ns_steps"])
689
 
690
  adjusted_lr = self.adjust_lr_for_muon(lr, p.shape)
 
699
  def _update_g(self, p, g, group, momentum):
700
  # calc update
701
  state = self.state[p]
702
+ buf = state.setdefault("momentum_buffer", torch.zeros_like(g))
703
+ torch.add(g, buf, alpha=momentum, out=buf)
 
 
704
  if group["nesterov"]:
705
+ g.add_(buf, alpha=momentum)
706
+ return g
707
+ return buf
 
708
 
709
  @staticmethod
710
  def _update_p(p, u, lr, adjusted_lr, weight_decay):
 
788
  param_to_state, ordered_params = self.init_state_and_assign_params(
789
  names, params, group, qk_logits)
790
 
791
+ assert self.rank is not None
792
+
793
+ def enqueue_all2all_gather(start_idx, chunk_size):
794
+ target_params = ordered_params[start_idx:start_idx + chunk_size]
795
+ if target_params:
796
+ alloc_event = _alloc_gathered_grad(target_params,
797
+ param_to_state, self.rank,
798
+ self.compute_stream)
799
+ _all2all_gather(target_params, param_to_state, self.rank,
800
+ self.comm_stream, group["none_grad"],
801
+ alloc_event)
802
 
803
  def enqueue_computes(start_idx, chunk_size):
804
  for p in ordered_params[start_idx:start_idx + chunk_size]:
 
806
  _compute_u(p, state, group["ns_steps"], self.rank,
807
  self.compute_stream)
808
 
809
+ def enqueue_all2all_scatter(start_idx, chunk_size):
810
+ target_params = ordered_params[start_idx:start_idx + chunk_size]
811
+ if target_params:
812
+ alloc_event = _alloc_scattered_u(target_params, param_to_state,
813
+ self.rank,
814
+ self.compute_stream)
815
+ _all2all_scatter(target_params, param_to_state, self.rank,
816
+ self.comm_stream, alloc_event)
817
 
818
  def enqueue_update_param(start_idx, chunk_size):
819
  for p in ordered_params[start_idx:start_idx + chunk_size]:
 
828
  # Wait grad update
829
  self.comm_stream.wait_stream(torch.cuda.current_stream())
830
 
831
+ overlap_step = self.overlap_step
832
+ for i in range(0, overlap_step):
833
+ enqueue_all2all_gather(i * chunk_size, chunk_size)
834
+ enqueue_computes(i * chunk_size, chunk_size)
835
+
836
  for i in range(0, len(params) + chunk_size - 1, chunk_size):
837
+ enqueue_all2all_scatter(i, chunk_size)
838
+ enqueue_all2all_gather(i + overlap_step * chunk_size, chunk_size)
839
+ enqueue_update_param(i, chunk_size)
840
+ enqueue_computes(i + overlap_step * chunk_size, chunk_size)
 
 
841
 
842
  # Wait the last update_param to finish
843
  torch.cuda.current_stream().wait_stream(self.compute_stream)
build/torch28-cxx11-cu128-x86_64-linux/optimizer/_ops.py CHANGED
@@ -1,9 +1,9 @@
1
  import torch
2
- from . import _optimizer_9c21645_dirty
3
- ops = torch.ops._optimizer_9c21645_dirty
4
 
5
  def add_op_namespace_prefix(op_name: str):
6
  """
7
  Prefix op by namespace.
8
  """
9
- return f"_optimizer_9c21645_dirty::{op_name}"
 
1
  import torch
2
+ from . import _optimizer_15336dc_dirty
3
+ ops = torch.ops._optimizer_15336dc_dirty
4
 
5
  def add_op_namespace_prefix(op_name: str):
6
  """
7
  Prefix op by namespace.
8
  """
9
+ return f"_optimizer_15336dc_dirty::{op_name}"
build/{torch28-cxx11-cu129-x86_64-linux/optimizer/_optimizer_9c21645_dirty.abi3.so → torch28-cxx11-cu128-x86_64-linux/optimizer/_optimizer_15336dc_dirty.abi3.so} RENAMED
@@ -1,3 +1,3 @@
1
  version https://git-lfs.github.com/spec/v1
2
- oid sha256:eb40a06623bb3668b82ff248b5a3c1bcf41e7f3f860888b261505b3a71257bc7
3
  size 1883344
 
1
  version https://git-lfs.github.com/spec/v1
2
+ oid sha256:e06baa32b0950126ee192654bd9f7adc79cc05d8ec39d2078c70d62ee81fdcd5
3
  size 1883344
build/torch28-cxx11-cu128-x86_64-linux/optimizer/matmul_transpose_triton.py ADDED
@@ -0,0 +1,128 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # MIT License
2
+ #
3
+ # Copyright (c) 2025 Tianyang Lin
4
+ #
5
+ # Permission is hereby granted, free of charge, to any person obtaining a copy
6
+ # of this software and associated documentation files (the "Software"), to deal
7
+ # in the Software without restriction, including without limitation the rights
8
+ # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9
+ # copies of the Software, and to permit persons to whom the Software is
10
+ # furnished to do so, subject to the following conditions:
11
+ #
12
+ # The above copyright notice and this permission notice shall be included in all
13
+ # copies or substantial portions of the Software.
14
+ #
15
+ # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16
+ # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17
+ # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18
+ # AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19
+ # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20
+ # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21
+ # SOFTWARE.
22
+
23
+ import torch
24
+ import triton
25
+ import triton.language as tl
26
+
27
+
28
def get_autotune_config():
    """Build the autotune search space for the ``mmt_kernel``.

    Sweeps block sizes, pipeline stages and warp counts; triton picks the
    fastest config per (M, K) key at runtime.
    """
    configs = []
    for blk_m in (32, 64, 128):
        for blk_k in (32, 64):
            for grp_sz in (8,):
                for n_stages in (3, 4, 5):
                    for n_warps in (4, 8):
                        configs.append(
                            triton.Config(
                                {
                                    'BLOCK_SIZE_M': blk_m,
                                    'BLOCK_SIZE_K': blk_k,
                                    'GROUP_SIZE_M': grp_sz
                                },
                                num_stages=n_stages,
                                num_warps=n_warps))
    return configs
41
+
42
+
43
@triton.autotune(
    configs=get_autotune_config(),
    key=['M', 'K'],
)
@triton.jit
def mmt_kernel(x, y, M, K, stride_xm, stride_xk, stride_ym, stride_yn,
               BLOCK_SIZE_M: tl.constexpr, BLOCK_SIZE_K: tl.constexpr,
               GROUP_SIZE_M: tl.constexpr):
    """
    Core kernel jit function of matmul_transpose that computes y = x @ x.T
    The code is a simple adaptation from the triton `matmul` tutorial:
    https://triton-lang.org/main/getting-started/tutorials/03-matrix-multiplication.html

    Because the output is symmetric (M x M), each program only computes the
    block with pid_m <= pid_n and mirrors it into the lower triangle at the
    end, roughly halving the dot-product work.
    """
    pid = tl.program_id(axis=0)
    # Grouped (swizzled) mapping of the 1-D program id onto the 2-D block
    # grid, as in the tutorial; output is square so num_pid_n == num_pid_m.
    num_pid_m = tl.cdiv(M, BLOCK_SIZE_M)
    num_pid_n = tl.cdiv(M, BLOCK_SIZE_M)
    num_pid_in_group = GROUP_SIZE_M * num_pid_n
    group_id = pid // num_pid_in_group
    first_pid_m = group_id * GROUP_SIZE_M
    group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M)
    pid_m = first_pid_m + ((pid % num_pid_in_group) % group_size_m)
    pid_n = (pid % num_pid_in_group) // group_size_m
    # Strictly-lower-triangle programs exit early; their block is produced
    # by the mirrored store of the transposed upper-triangle block below.
    if pid_m > pid_n:
        return

    # Row offsets are wrapped with % M so out-of-range rows load valid
    # memory; the final store mask discards those lanes.
    offs_xm = (pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)) % M
    offs_xn = (pid_n * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)) % M
    offs_k = tl.arange(0, BLOCK_SIZE_K)
    # we use a & b ptrs to denote different rows of x.
    a_ptrs = x + (offs_xm[:, None] * stride_xm + offs_k[None, :] * stride_xk)
    b_ptrs = x + (offs_xn[:, None] * stride_xm + offs_k[None, :] * stride_xk)

    # Accumulate in float32 regardless of the input dtype.
    accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_M), dtype=tl.float32)

    for k in range(0, tl.cdiv(K, BLOCK_SIZE_K)):
        # Mask the tail of the K dimension on the last iteration.
        a = tl.load(a_ptrs,
                    mask=offs_k[None, :] < K - k * BLOCK_SIZE_K,
                    other=0.0)
        b = tl.load(b_ptrs,
                    mask=offs_k[None, :] < K - k * BLOCK_SIZE_K,
                    other=0.0)
        # a @ b.T accumulates this K-slice of x[m,:] . x[n,:].
        accumulator = tl.dot(a, tl.permute(b, (1, 0)), accumulator)
        a_ptrs += BLOCK_SIZE_K * stride_xk
        b_ptrs += BLOCK_SIZE_K * stride_xk
    # use dtype.element_ty to accommodate different input datatypes as in cpp templates
    # https://github.com/triton-lang/triton/issues/2252
    c = accumulator.to(x.dtype.element_ty)

    # Unwrapped offsets for the output block plus bounds mask.
    offs_cm = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
    offs_cn = pid_n * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
    c_ptrs = y + stride_ym * offs_cm[:, None] + stride_yn * offs_cn[None, :]
    c_mask = (offs_cm[:, None] < M) & (offs_cn[None, :] < M)
    tl.store(c_ptrs, c, mask=c_mask)

    # transpose and copy: mirror this block into the lower triangle
    # (skipped on the diagonal, where pid_m == pid_n).
    if pid_m < pid_n:
        ct_ptrs = y + stride_ym * offs_cn[:,
                                          None] + stride_yn * offs_cm[None, :]
        ct_mask = (offs_cn[:, None] < M) & (offs_cm[None, :] < M)
        tl.store(ct_ptrs, tl.permute(c, (1, 0)), mask=ct_mask)
103
+
104
+
105
def matmul_transpose_assign(d_in, d_out):
    """Compute ``d_out = d_in @ d_in.T`` in place with the triton kernel.

    Parameters
    ----------
    d_in : 2-D CUDA tensor of shape (M, K).
    d_out : 2-D CUDA tensor of shape (M, M); same dtype and device as
        ``d_in``. Overwritten with the symmetric product.
    """
    assert d_in.is_cuda, "Input `d_in` must be a CUDA tensor"
    assert d_out.is_cuda, "Input `d_out` must be a CUDA tensor"
    assert d_in.device == d_out.device, "Inputs `d_in` and `d_out` must be on the same CUDA device"
    assert d_in.dtype == d_out.dtype, "Inputs must have the same data type"
    assert d_in.ndim == 2, "Input `d_in` must be a 2D tensor"
    assert d_out.ndim == 2, "Input `d_out` must be a 2D tensor"
    # Bug fix: previously this compared d_out.size(0) with itself twice, so a
    # non-square d_out passed validation; check both output dimensions.
    assert d_in.size(0) == d_out.size(0) == d_out.size(1), \
        "First dimension of `d_in` must match first and second dimension of `d_out`"

    # Kernel indexing assumes row-major contiguous input.
    d_in = d_in.contiguous()
    M, K = d_in.shape
    # One program per (block-row, block-col) pair of the M x M output.
    grid = lambda META: (triton.cdiv(M, META['BLOCK_SIZE_M']) * triton.cdiv(
        M, META['BLOCK_SIZE_M']), )
    # Launch on the device that owns the tensors.
    with torch.cuda.device(d_in.device.index):
        mmt_kernel[grid](d_in, d_out, M, K, d_in.stride(0), d_in.stride(1),
                         d_out.stride(0), d_out.stride(1))
122
+
123
+
124
def matmul_transpose(d_in):
    """Allocate and return ``d_in @ d_in.T`` computed by the triton kernel."""
    num_rows, _ = d_in.shape
    result = torch.empty((num_rows, num_rows),
                         dtype=d_in.dtype,
                         device=d_in.device)
    matmul_transpose_assign(d_in, result)
    return result
build/torch28-cxx11-cu128-x86_64-linux/optimizer/muon.py CHANGED
@@ -8,14 +8,19 @@ import torch
8
  import torch.distributed as dist
9
  from torch.distributed._tensor import DTensor, Replicate, Shard
10
 
 
 
11
  logger = logging.getLogger(__name__)
12
 
 
 
13
 
14
  # This code snippet is a modified version adapted from the following GitHub repositories:
15
  # https://github.com/KellerJordan/Muon/blob/master/muon.py
16
  # Muon's Newton–Schulz iteration causes high variance in singular values
17
  # Idea: give each iteration its own 3 coefficients and optimize them via gradient descent.
18
  @torch.no_grad()
 
19
  def _zeropower_via_newtonschulz5(G, steps):
20
  """
21
  Newton-Schulz iteration to compute the zeroth power / orthogonalization of G. We opt to use a
@@ -27,13 +32,15 @@ def _zeropower_via_newtonschulz5(G, steps):
27
  performance at all relative to UV^T, where USV^T = G is the SVD.
28
  """
29
  assert len(G.shape) == 2
30
- assert G.dtype == torch.bfloat16
31
  X = G # no manual typecast
32
 
33
  if G.size(0) > G.size(1):
34
  X = X.T
35
  # Ensure spectral norm is at most 1
36
  X = X / (X.norm() + 1e-7)
 
 
37
  # Perform the NS iterations
38
  for a, b, c in [
39
  (4.0848, -6.8946, 2.9270),
@@ -42,13 +49,10 @@ def _zeropower_via_newtonschulz5(G, steps):
42
  (2.8769, -3.1427, 1.2046),
43
  (2.8366, -3.0525, 1.2012),
44
  ]:
45
- A = X @ X.T
46
- # B = (
47
- # b * A + c * A @ A
48
- # )
49
- B = torch.addmm(A, A, A, alpha=c, beta=b)
50
- # X = a * X + B @ X
51
- X = torch.addmm(X, B, X, alpha=1.0, beta=a)
52
 
53
  if G.size(0) > G.size(1):
54
  X = X.T
@@ -69,51 +73,142 @@ class _muon_state:
69
  qk_clip_state = None
70
 
71
 
 
 
 
 
 
 
 
 
72
  @torch.no_grad()
73
- def _gather(p, state, rank, comm_stream, none_grad):
74
  """
75
- Gather the gradients to worker_rank.
76
- If none_grad is True, free p.grad after the gather.
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
77
  """
78
  with torch.cuda.stream(comm_stream):
79
- g = p.grad
 
80
 
81
- if rank == state.worker_rank:
82
- num_ranks = dist.get_world_size(group=state.process_group)
83
- gather_list = [
84
- torch.empty_like(g.to_local(), dtype=torch.bfloat16)
85
- for _ in range(num_ranks)
86
- ]
87
- else:
88
- gather_list = None
89
-
90
- g = g.to(torch.bfloat16)
91
- torch.distributed.gather(
92
- g.to_local(),
93
- dst=state.worker_rank,
94
- gather_list=gather_list,
95
- group=state.process_group,
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
96
  )
97
- if rank == state.worker_rank:
98
- if state.gathered_grad is not None:
99
- raise RuntimeError(
100
- "Gather event already exists, which should not happen.")
101
- state.gathered_grad = torch.cat(gather_list, dim=0)
102
- state.gather_event = torch.cuda.Event()
103
- state.gather_event.record()
104
- else:
105
- state.gathered_grad = None
106
- state.gather_event = None
107
- gather_list = None
108
- if none_grad:
109
- # We can safely free p.grad without calling record_stream:
110
- # p.grad.to_local().record_stream(comm_stream)
111
- # Explanation:
112
- # 1. p.grad is created on the default stream, but the default stream
113
- # is synchronized with the comm stream later.
114
- # 2. There is no further activity on the default stream before the optimizer finishes.
115
- # Therefore, it is safe to free p.grad directly on the comm stream.
116
- p.grad = None
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
117
 
118
 
119
  @torch.no_grad()
@@ -127,45 +222,145 @@ def _compute_u(p, state, steps, rank, compute_stream):
127
  raise RuntimeError("Gather event must be set before compute.")
128
  compute_stream.wait_event(state.gather_event)
129
  u = _zeropower_via_newtonschulz5(state.gathered_grad, steps)
 
130
  state.computed_u = u
131
- state.scattered_u = torch.empty_like(p.to_local(),
132
- dtype=torch.bfloat16)
133
- state.compute_event = torch.cuda.Event()
134
- state.compute_event.record()
135
- u = None
136
 
137
 
138
  @torch.no_grad()
139
- def _scatter(p, state, rank, comm_stream):
140
  """
141
- Scatter the computed_u from worker_rank to all ranks.
 
142
  """
 
 
 
 
 
 
 
 
 
 
143
 
 
 
 
 
144
  with torch.cuda.stream(comm_stream):
145
- if state.compute_event is None:
146
- raise RuntimeError("Compute event must be set before scatter.")
147
- comm_stream.wait_event(state.compute_event)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
148
 
149
- if rank == state.worker_rank:
150
- num_ranks = dist.get_world_size(group=state.process_group)
151
- # Clear the gathered gradient to free memory
152
- state.gathered_grad = None
 
 
 
 
153
 
154
- u = state.computed_u
155
- scatter_list = list(torch.split(u, p.size(0) // num_ranks, dim=0))
156
- scatter_list = [s.contiguous() for s in scatter_list]
 
 
 
 
 
 
157
  else:
158
- scatter_list = None
 
 
159
 
160
- torch.distributed.scatter(
161
- state.scattered_u,
162
- scatter_list=scatter_list,
163
- src=state.worker_rank,
164
- group=state.process_group,
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
165
  )
166
- state.scatter_event = torch.cuda.Event()
167
- state.scatter_event.record()
168
- scatter_list = None
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
169
 
170
 
171
  def _update_param(p, state, lr, adjusted_lr, weight_decay, rank,
@@ -321,6 +516,11 @@ class Muon(torch.optim.Optimizer):
321
  "head_dim": 128,
322
  "threshold": 100
323
  }
 
 
 
 
 
324
  """
325
 
326
  def __init__(self,
@@ -339,7 +539,8 @@ class Muon(torch.optim.Optimizer):
339
  "k_indices": [],
340
  "head_dim": 128,
341
  "threshold": 100
342
- }):
 
343
  defaults = dict(
344
  lr=lr,
345
  weight_decay=weight_decay,
@@ -363,15 +564,13 @@ class Muon(torch.optim.Optimizer):
363
 
364
  super().__init__(params, defaults)
365
 
366
- if dist.is_initialized():
367
- self.rank = dist.get_rank()
368
- else:
369
- self.rank = None
370
 
371
  self.comm_stream = torch.cuda.Stream()
372
  self.compute_stream = torch.cuda.Stream()
373
  self.debug = debug
374
  self.clip_config = clip_config
 
375
 
376
  def _calc_flops(self, G, steps):
377
  assert len(G.shape) == 2
@@ -444,11 +643,18 @@ class Muon(torch.optim.Optimizer):
444
  if mesh is None:
445
  mesh = p.device_mesh
446
  shard_mesh, process_group = self.get_shard_mesh(p, self.rank)
 
 
 
 
 
447
  elif mesh != p.device_mesh:
448
  raise ValueError("All parameters must be on the same mesh.")
449
 
 
450
  param_to_state[id(p)] = _muon_state()
451
- param_to_state[id(p)].worker_rank = shard_mesh[round_robin].item()
 
452
  param_to_state[id(p)].process_group = process_group
453
  qk_clip_state = self.get_qk_clip_info(n, qk_logits)
454
  param_to_state[id(p)].qk_clip_state = qk_clip_state
@@ -478,7 +684,7 @@ class Muon(torch.optim.Optimizer):
478
  else:
479
  g = buf
480
 
481
- u = _zeropower_via_newtonschulz5(g.bfloat16(),
482
  steps=group["ns_steps"])
483
 
484
  adjusted_lr = self.adjust_lr_for_muon(lr, p.shape)
@@ -493,15 +699,12 @@ class Muon(torch.optim.Optimizer):
493
  def _update_g(self, p, g, group, momentum):
494
  # calc update
495
  state = self.state[p]
496
- if "momentum_buffer" not in state:
497
- state["momentum_buffer"] = torch.zeros_like(g)
498
- buf = state["momentum_buffer"]
499
- buf.mul_(momentum).add_(g)
500
  if group["nesterov"]:
501
- g = g.add(buf, alpha=momentum)
502
- else:
503
- g = buf
504
- return g
505
 
506
  @staticmethod
507
  def _update_p(p, u, lr, adjusted_lr, weight_decay):
@@ -585,11 +788,17 @@ class Muon(torch.optim.Optimizer):
585
  param_to_state, ordered_params = self.init_state_and_assign_params(
586
  names, params, group, qk_logits)
587
 
588
- def enqueue_gathers(start_idx, chunk_size):
589
- for p in ordered_params[start_idx:start_idx + chunk_size]:
590
- state = param_to_state[id(p)]
591
- _gather(p, state, self.rank, self.comm_stream,
592
- group["none_grad"])
 
 
 
 
 
 
593
 
594
  def enqueue_computes(start_idx, chunk_size):
595
  for p in ordered_params[start_idx:start_idx + chunk_size]:
@@ -597,10 +806,14 @@ class Muon(torch.optim.Optimizer):
597
  _compute_u(p, state, group["ns_steps"], self.rank,
598
  self.compute_stream)
599
 
600
- def enqueue_scatters(start_idx, chunk_size):
601
- for p in ordered_params[start_idx:start_idx + chunk_size]:
602
- state = param_to_state[id(p)]
603
- _scatter(p, state, self.rank, self.comm_stream)
 
 
 
 
604
 
605
  def enqueue_update_param(start_idx, chunk_size):
606
  for p in ordered_params[start_idx:start_idx + chunk_size]:
@@ -615,14 +828,16 @@ class Muon(torch.optim.Optimizer):
615
  # Wait grad update
616
  self.comm_stream.wait_stream(torch.cuda.current_stream())
617
 
618
- enqueue_gathers(0, chunk_size)
 
 
 
 
619
  for i in range(0, len(params) + chunk_size - 1, chunk_size):
620
- enqueue_computes(i, chunk_size)
621
- if i > 0:
622
- enqueue_update_param(i - chunk_size, chunk_size)
623
- enqueue_gathers(i + chunk_size, chunk_size)
624
- enqueue_scatters(i, chunk_size)
625
- enqueue_update_param(i, chunk_size)
626
 
627
  # Wait the last update_param to finish
628
  torch.cuda.current_stream().wait_stream(self.compute_stream)
 
8
  import torch.distributed as dist
9
  from torch.distributed._tensor import DTensor, Replicate, Shard
10
 
11
+ from .matmul_transpose_triton import matmul_transpose_assign
12
+
13
  logger = logging.getLogger(__name__)
14
 
15
+ COMM_DTYPE = torch.bfloat16
16
+
17
 
18
  # This code snippet is a modified version adapted from the following GitHub repositories:
19
  # https://github.com/KellerJordan/Muon/blob/master/muon.py
20
  # Muon's Newton–Schulz iteration causes high variance in singular values
21
  # Idea: give each iteration its own 3 coefficients and optimize them via gradient descent.
22
  @torch.no_grad()
23
+ # matmul_transpose_assign from : https://github.com/nil0x9/flash-muon
24
  def _zeropower_via_newtonschulz5(G, steps):
25
  """
26
  Newton-Schulz iteration to compute the zeroth power / orthogonalization of G. We opt to use a
 
32
  performance at all relative to UV^T, where USV^T = G is the SVD.
33
  """
34
  assert len(G.shape) == 2
35
+ assert G.dtype == COMM_DTYPE
36
  X = G # no manual typecast
37
 
38
  if G.size(0) > G.size(1):
39
  X = X.T
40
  # Ensure spectral norm is at most 1
41
  X = X / (X.norm() + 1e-7)
42
+ buf1 = torch.empty(X.size(0), X.size(0), dtype=X.dtype, device=X.device)
43
+ buf2 = torch.empty(X.size(0), X.size(0), dtype=X.dtype, device=X.device)
44
  # Perform the NS iterations
45
  for a, b, c in [
46
  (4.0848, -6.8946, 2.9270),
 
49
  (2.8769, -3.1427, 1.2046),
50
  (2.8366, -3.0525, 1.2012),
51
  ]:
52
+ matmul_transpose_assign(X, buf1)
53
+ matmul_transpose_assign(buf1, buf2)
54
+ buf1.mul_(b).add_(buf2, alpha=c)
55
+ X = torch.addmm(X, buf1, X, alpha=1.0, beta=a)
 
 
 
56
 
57
  if G.size(0) > G.size(1):
58
  X = X.T
 
73
  qk_clip_state = None
74
 
75
 
76
def split_elems_for_src(param, src_rank, num_ranks) -> int:
    """Return how many elements of ``param``'s row-shard belong to ``src_rank``.

    Rows (dim 0) are split as evenly as possible across ``num_ranks``; the
    first ``rows % num_ranks`` ranks each take one extra row.
    """
    total_rows = param.shape[0]
    elems_per_row = int(param.numel() // total_rows)
    even_rows, leftover = divmod(total_rows, num_ranks)
    shard_rows = even_rows + (1 if src_rank < leftover else 0)
    return shard_rows * elems_per_row
82
+
83
+
84
  @torch.no_grad()
85
def _alloc_gathered_grad(params, param_to_state, rank, compute_stream):
    """Pre-allocate the full-gradient receive buffer for each owned param.

    Runs on ``compute_stream`` before the all2all gather so the allocations
    are ordered ahead of the communication that fills them. Returns a CUDA
    event recorded after the allocations, for the comm stream to wait on.
    """
    with torch.cuda.stream(compute_stream):
        for p in params:
            state = param_to_state[id(p)]
            if rank == state.worker_rank:
                # Owner rank reassembles the whole (flattened) gradient here.
                # Fix: dropped an unused `num_ranks = dist.get_world_size(...)`
                # local that was computed on every call.
                # NOTE(review): allocated via device="cuda" (current device) —
                # assumes one CUDA device per process; confirm.
                state.gathered_grad = torch.empty(p.grad.numel(),
                                                  dtype=COMM_DTYPE,
                                                  device="cuda")
            else:
                state.gathered_grad = None

    alloc_event = torch.cuda.Event()
    alloc_event.record(compute_stream)
    return alloc_event
104
+
105
+
106
+ @torch.no_grad()
107
def _all2all_gather(params, param_to_state, rank, comm_stream, none_grad,
                    alloc_event):
    """
    All2all gathers shards so each owner rank reconstructs its full gradient.

    Each rank sends its local shard of every parameter's gradient to that
    parameter's owner (``state.worker_rank``); owners reassemble the full
    gradient into the pre-allocated ``state.gathered_grad`` (see
    ``_alloc_gathered_grad``) and record ``state.gather_event``. If
    ``none_grad`` is True, ``p.grad`` is dropped afterwards.
    All work is enqueued on ``comm_stream``.
    """
    with torch.cuda.stream(comm_stream):
        process_group = param_to_state[id(params[0])].process_group
        num_ranks = dist.get_world_size(group=process_group)

        # Construct sending buffers: one bucket of shard tensors per
        # destination (owner) rank, concatenated in `params` order.
        per_dst = [[] for _ in range(num_ranks)]
        send_counts = [0] * num_ranks

        for p in params:
            state = param_to_state[id(p)]
            dst = state.worker_rank
            assert dst < num_ranks
            shard_elems = split_elems_for_src(p, rank, num_ranks)
            g = p.grad
            # Flatten this rank's local shard, converted to the comm dtype.
            g = g.to_local().to(COMM_DTYPE).contiguous().view(-1)
            assert g.numel() == shard_elems
            per_dst[dst].append(g)
            send_counts[dst] += shard_elems

        assert all(
            len(v) > 0
            for v in per_dst), "all params should be sharded to all devices"

        send_buf = torch.cat([torch.cat(v, dim=0) for v in per_dst], dim=0)
        owned_params = [
            p for p in params if param_to_state[id(p)].worker_rank == rank
        ]

        # Compute receive sizes and allocate receiving buffers: from each
        # source rank we receive that rank's shard of every param we own.
        recv_counts = [0] * num_ranks

        for src in range(num_ranks):
            total = 0
            for p in owned_params:
                state = param_to_state[id(p)]
                assert state.worker_rank == rank
                total += split_elems_for_src(p, src, num_ranks)
            recv_counts[src] = total

        recv_total = sum(recv_counts)
        recv_buf = torch.empty(recv_total, dtype=COMM_DTYPE, device="cuda")

        #All2All
        dist.all_to_all_single(
            recv_buf,
            send_buf,
            output_split_sizes=recv_counts,
            input_split_sizes=send_counts,
            group=process_group,
        )

        # Reconstructs gathered grad from the received buffer
        #
        # recv_buf (num ranks = 3)
        #
        #   From rank 0         From rank 1         From rank 2
        # | p1_0, p2_0, p3_0 | p1_1, p2_1, p3_1 | p1_2, p2_2, p3_2 |
        #
        # Outer loop:
        #   rank 0 -> rank 1 -> rank2
        #
        # Inner loop:
        #   p1_n -> p2_n -> p3_n

        # Do not write into gathered_grad before its allocation (on the
        # compute stream) is ordered before us.
        comm_stream.wait_event(alloc_event)

        off = 0
        write_offsets = {id(p): 0 for p in owned_params}
        for src in range(num_ranks):
            if recv_counts[src] == 0:
                continue

            block = recv_counts[src]
            inner_off = 0
            for p in owned_params:
                state = param_to_state[id(p)]
                assert state.worker_rank == rank
                n = split_elems_for_src(p, src, num_ranks)
                assert n > 0

                # Copy rank `src`'s shard of `p` into its slot of the
                # flattened full gradient.
                sg = recv_buf.narrow(0, off + inner_off, n)
                woff = write_offsets[id(p)]
                dst = state.gathered_grad.narrow(0, woff, n)
                dst.copy_(sg)

                write_offsets[id(p)] += n
                inner_off += n
            off += block

        for p in params:
            state = param_to_state[id(p)]
            if state.worker_rank == rank:
                # Reshape the flat buffer to the parameter's (global) shape
                # and publish the completion event for the compute stream.
                state.gathered_grad = state.gathered_grad.view_as(p)
                state.gather_event = torch.cuda.Event()
                state.gather_event.record(comm_stream)
            else:
                state.gathered_grad = None
                state.gather_event = None
            if none_grad:
                # Safe without record_stream: the default stream is synced
                # with the comm stream before the optimizer finishes.
                p.grad = None
212
 
213
 
214
  @torch.no_grad()
 
222
  raise RuntimeError("Gather event must be set before compute.")
223
  compute_stream.wait_event(state.gather_event)
224
  u = _zeropower_via_newtonschulz5(state.gathered_grad, steps)
225
+ state.gathered_grad = None
226
  state.computed_u = u
227
+ state.compute_event = torch.cuda.Event()
228
+ state.compute_event.record()
229
+ else:
230
+ state.computed_u = None
231
+ state.compute_event = None
232
 
233
 
234
  @torch.no_grad()
235
def _alloc_scattered_u(params, param_to_state, rank, compute_stream):
    """Pre-allocate each parameter's local ``scattered_u`` receive buffer.

    Runs on ``compute_stream`` before the all2all scatter is launched;
    returns a CUDA event recorded after the allocations so the comm stream
    can order its write-back behind them.

    ``rank`` is accepted for signature symmetry with
    ``_alloc_gathered_grad`` but is unused: every rank receives a shard of
    every parameter's update.
    """
    with torch.cuda.stream(compute_stream):
        for p in params:
            param_to_state[id(p)].scattered_u = torch.empty_like(
                p.to_local(), dtype=COMM_DTYPE)

    done = torch.cuda.Event()
    done.record(compute_stream)
    return done
249
+
250
 
251
def _all2all_scatter(params, param_to_state, rank, comm_stream, alloc_event):
    """
    All2all-scatters each owner's computed update ``u`` back to all ranks.

    Inverse of ``_all2all_gather``: the owner of each parameter splits its
    ``state.computed_u`` row-wise and sends every rank its shard; each rank
    copies its received shards into the pre-allocated ``state.scattered_u``
    buffers and records ``state.scatter_event`` per parameter. All work is
    enqueued on ``comm_stream``.

    NOTE(review): unlike ``_all2all_gather`` this is not decorated with
    ``@torch.no_grad()`` — presumably fine because the optimizer step runs
    under no-grad, but confirm the intent.
    """
    with torch.cuda.stream(comm_stream):
        process_group = param_to_state[id(params[0])].process_group
        num_ranks = dist.get_world_size(group=process_group)
        owned_params = [
            p for p in params if param_to_state[id(p)].worker_rank == rank
        ]

        # Construct sending buffer: shards of each owned param's update,
        # bucketed by destination rank.
        per_dst = [[] for _ in range(num_ranks)]
        send_counts = [0] * num_ranks

        if owned_params:
            for p in owned_params:
                state = param_to_state[id(p)]
                if state.compute_event is None:
                    raise RuntimeError(
                        "Compute event must be set before scatter.")
                # The update must be finished on the compute stream first.
                comm_stream.wait_event(state.compute_event)
                state.gathered_grad = None

                assert state.computed_u is not None

                u_full = state.computed_u.to(COMM_DTYPE).contiguous().view(-1)

                # Slice the flat update row-shard by row-shard, one slice
                # per destination rank (same split as the gather).
                offset = 0
                for dst in range(num_ranks):
                    n = split_elems_for_src(p, dst, num_ranks)
                    assert n > 0

                    su = u_full.narrow(0, offset, n)
                    per_dst[dst].append(su)
                    send_counts[dst] += n
                    offset += n

                assert offset == u_full.numel()

        if any(len(v) > 0 for v in per_dst):
            send_buf = torch.cat([torch.cat(v, dim=0) for v in per_dst], dim=0)
        else:
            # all_to_all requires participation from all ranks
            # Even non-owner ranks must join the collective call
            send_buf = torch.empty(0, dtype=COMM_DTYPE, device="cuda")

        # Compute receive sizes and allocate receiving buffers: from each
        # owner rank we receive our shard of every param it owns.
        recv_counts = [0] * num_ranks

        for src in range(num_ranks):
            total = 0
            for p in params:
                state = param_to_state[id(p)]
                if state.worker_rank != src:
                    continue
                total += split_elems_for_src(p, rank, num_ranks)
            recv_counts[src] = total

        recv_total = sum(recv_counts)
        assert recv_total > 0
        recv_buf = torch.empty(recv_total, dtype=COMM_DTYPE, device="cuda")

        #All2All
        dist.all_to_all_single(
            recv_buf,
            send_buf,
            output_split_sizes=recv_counts,
            input_split_sizes=send_counts,
            group=process_group,
        )

        # Copy to pre-allocated scattered_u buffer from the received buffer
        #
        # recv_buf (num ranks = 3, local_rank = 0)
        #
        #   From rank 0         From rank 1   From rank 2
        # | p1_0, p2_0, p3_0 | p4_0        | p5_0, p6_0 |
        #
        # Outer loop:
        #   rank 0 -> rank 1 -> rank2
        #
        # Inner loop:
        #   src(0) : p1_0 -> p2_0 -> p3_0
        #   src(1) : p4_0
        #   src(2) : p5_0 -> p6_0

        # Do not write into scattered_u before its allocation (on the
        # compute stream) is ordered before us.
        comm_stream.wait_event(alloc_event)

        off = 0
        for src in range(num_ranks):
            block = recv_counts[src]
            if block == 0:
                continue

            inner_off = 0
            for p in params:
                state = param_to_state[id(p)]
                if state.worker_rank != src:
                    continue
                n = split_elems_for_src(p, rank, num_ranks)
                assert n > 0

                # Our local shard of p's update, reshaped to the local view.
                flat_local = recv_buf.narrow(0, off + inner_off,
                                             n).view_as(p.to_local())
                state.scattered_u.copy_(flat_local)

                # Per-parameter completion event for _update_param.
                state.scatter_event = torch.cuda.Event()
                state.scatter_event.record(comm_stream)
                inner_off += n

            assert inner_off == block
            off += block
364
 
365
 
366
  def _update_param(p, state, lr, adjusted_lr, weight_decay, rank,
 
516
  "head_dim": 128,
517
  "threshold": 100
518
  }
519
+ overlap_step : How many all2all gather, compute operations are launched in advance
520
+ before the corresponding all2all scatter steps begin.
521
+ A higher overlap_step increases memory usage but can improve
522
+ performance by overlapping communication.
523
+ Parallel muon only.
524
  """
525
 
526
  def __init__(self,
 
539
  "k_indices": [],
540
  "head_dim": 128,
541
  "threshold": 100
542
+ },
543
+ overlap_step=5):
544
  defaults = dict(
545
  lr=lr,
546
  weight_decay=weight_decay,
 
564
 
565
  super().__init__(params, defaults)
566
 
567
+ self.rank = None
 
 
 
568
 
569
  self.comm_stream = torch.cuda.Stream()
570
  self.compute_stream = torch.cuda.Stream()
571
  self.debug = debug
572
  self.clip_config = clip_config
573
+ self.overlap_step = overlap_step
574
 
575
  def _calc_flops(self, G, steps):
576
  assert len(G.shape) == 2
 
643
  if mesh is None:
644
  mesh = p.device_mesh
645
  shard_mesh, process_group = self.get_shard_mesh(p, self.rank)
646
+ local_rank = dist.get_rank(group=process_group)
647
+ if self.rank is None:
648
+ self.rank = dist.get_rank(group=process_group)
649
+ else:
650
+ assert self.rank == local_rank
651
  elif mesh != p.device_mesh:
652
  raise ValueError("All parameters must be on the same mesh.")
653
 
654
+ num_ranks = dist.get_world_size(group=process_group)
655
  param_to_state[id(p)] = _muon_state()
656
+ param_to_state[id(
657
+ p)].worker_rank = shard_mesh[round_robin].item() % num_ranks
658
  param_to_state[id(p)].process_group = process_group
659
  qk_clip_state = self.get_qk_clip_info(n, qk_logits)
660
  param_to_state[id(p)].qk_clip_state = qk_clip_state
 
684
  else:
685
  g = buf
686
 
687
+ u = _zeropower_via_newtonschulz5(g.to(COMM_DTYPE),
688
  steps=group["ns_steps"])
689
 
690
  adjusted_lr = self.adjust_lr_for_muon(lr, p.shape)
 
699
  def _update_g(self, p, g, group, momentum):
700
  # calc update
701
  state = self.state[p]
702
+ buf = state.setdefault("momentum_buffer", torch.zeros_like(g))
703
+ torch.add(g, buf, alpha=momentum, out=buf)
 
 
704
  if group["nesterov"]:
705
+ g.add_(buf, alpha=momentum)
706
+ return g
707
+ return buf
 
708
 
709
  @staticmethod
710
  def _update_p(p, u, lr, adjusted_lr, weight_decay):
 
788
  param_to_state, ordered_params = self.init_state_and_assign_params(
789
  names, params, group, qk_logits)
790
 
791
+ assert self.rank is not None
792
+
793
+ def enqueue_all2all_gather(start_idx, chunk_size):
794
+ target_params = ordered_params[start_idx:start_idx + chunk_size]
795
+ if target_params:
796
+ alloc_event = _alloc_gathered_grad(target_params,
797
+ param_to_state, self.rank,
798
+ self.compute_stream)
799
+ _all2all_gather(target_params, param_to_state, self.rank,
800
+ self.comm_stream, group["none_grad"],
801
+ alloc_event)
802
 
803
  def enqueue_computes(start_idx, chunk_size):
804
  for p in ordered_params[start_idx:start_idx + chunk_size]:
 
806
  _compute_u(p, state, group["ns_steps"], self.rank,
807
  self.compute_stream)
808
 
809
+ def enqueue_all2all_scatter(start_idx, chunk_size):
810
+ target_params = ordered_params[start_idx:start_idx + chunk_size]
811
+ if target_params:
812
+ alloc_event = _alloc_scattered_u(target_params, param_to_state,
813
+ self.rank,
814
+ self.compute_stream)
815
+ _all2all_scatter(target_params, param_to_state, self.rank,
816
+ self.comm_stream, alloc_event)
817
 
818
  def enqueue_update_param(start_idx, chunk_size):
819
  for p in ordered_params[start_idx:start_idx + chunk_size]:
 
828
  # Wait grad update
829
  self.comm_stream.wait_stream(torch.cuda.current_stream())
830
 
831
+ overlap_step = self.overlap_step
832
+ for i in range(0, overlap_step):
833
+ enqueue_all2all_gather(i * chunk_size, chunk_size)
834
+ enqueue_computes(i * chunk_size, chunk_size)
835
+
836
  for i in range(0, len(params) + chunk_size - 1, chunk_size):
837
+ enqueue_all2all_scatter(i, chunk_size)
838
+ enqueue_all2all_gather(i + overlap_step * chunk_size, chunk_size)
839
+ enqueue_update_param(i, chunk_size)
840
+ enqueue_computes(i + overlap_step * chunk_size, chunk_size)
 
 
841
 
842
  # Wait the last update_param to finish
843
  torch.cuda.current_stream().wait_stream(self.compute_stream)
build/torch28-cxx11-cu129-x86_64-linux/optimizer/_ops.py CHANGED
@@ -1,9 +1,9 @@
1
  import torch
2
- from . import _optimizer_9c21645_dirty
3
- ops = torch.ops._optimizer_9c21645_dirty
4
 
5
  def add_op_namespace_prefix(op_name: str):
6
  """
7
  Prefix op by namespace.
8
  """
9
- return f"_optimizer_9c21645_dirty::{op_name}"
 
1
  import torch
2
+ from . import _optimizer_15336dc_dirty
3
+ ops = torch.ops._optimizer_15336dc_dirty
4
 
5
  def add_op_namespace_prefix(op_name: str):
6
  """
7
  Prefix op by namespace.
8
  """
9
+ return f"_optimizer_15336dc_dirty::{op_name}"
build/{torch27-cxx11-cu128-x86_64-linux/optimizer/_optimizer_9c21645_dirty.abi3.so → torch28-cxx11-cu129-x86_64-linux/optimizer/_optimizer_15336dc_dirty.abi3.so} RENAMED
@@ -1,3 +1,3 @@
1
  version https://git-lfs.github.com/spec/v1
2
- oid sha256:dae71b7e998e72130093a86f8c983c3379510e23525e3cdcd4afe5c21bf4d3db
3
  size 1883344
 
1
  version https://git-lfs.github.com/spec/v1
2
+ oid sha256:6880c22f63ccd66e8ac62792a564d1ade58325b47369a1773c7753d4243893b9
3
  size 1883344
build/torch28-cxx11-cu129-x86_64-linux/optimizer/matmul_transpose_triton.py ADDED
@@ -0,0 +1,128 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # MIT License
2
+ #
3
+ # Copyright (c) 2025 Tianyang Lin
4
+ #
5
+ # Permission is hereby granted, free of charge, to any person obtaining a copy
6
+ # of this software and associated documentation files (the "Software"), to deal
7
+ # in the Software without restriction, including without limitation the rights
8
+ # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9
+ # copies of the Software, and to permit persons to whom the Software is
10
+ # furnished to do so, subject to the following conditions:
11
+ #
12
+ # The above copyright notice and this permission notice shall be included in all
13
+ # copies or substantial portions of the Software.
14
+ #
15
+ # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16
+ # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17
+ # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18
+ # AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19
+ # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20
+ # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21
+ # SOFTWARE.
22
+
23
+ import torch
24
+ import triton
25
+ import triton.language as tl
26
+
27
+
28
+ def get_autotune_config():
29
+ return [
30
+ triton.Config(
31
+ {
32
+ 'BLOCK_SIZE_M': blk_m,
33
+ 'BLOCK_SIZE_K': blk_k,
34
+ 'GROUP_SIZE_M': grp_sz
35
+ },
36
+ num_stages=n_stages,
37
+ num_warps=n_warps) for blk_m in [32, 64, 128]
38
+ for blk_k in [32, 64] for grp_sz in [8] for n_stages in [3, 4, 5]
39
+ for n_warps in [4, 8]
40
+ ]
41
+
42
+
43
+ @triton.autotune(
44
+ configs=get_autotune_config(),
45
+ key=['M', 'K'],
46
+ )
47
+ @triton.jit
48
+ def mmt_kernel(x, y, M, K, stride_xm, stride_xk, stride_ym, stride_yn,
49
+ BLOCK_SIZE_M: tl.constexpr, BLOCK_SIZE_K: tl.constexpr,
50
+ GROUP_SIZE_M: tl.constexpr):
51
+ """
52
+ Core kernel jit function of matmul_transpose that computes y = x @ x.T
53
+ The code is a simple adaptation from the triton `matmul` tutorial:
54
+ https://triton-lang.org/main/getting-started/tutorials/03-matrix-multiplication.html
55
+ """
56
+ pid = tl.program_id(axis=0)
57
+ num_pid_m = tl.cdiv(M, BLOCK_SIZE_M)
58
+ num_pid_n = tl.cdiv(M, BLOCK_SIZE_M)
59
+ num_pid_in_group = GROUP_SIZE_M * num_pid_n
60
+ group_id = pid // num_pid_in_group
61
+ first_pid_m = group_id * GROUP_SIZE_M
62
+ group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M)
63
+ pid_m = first_pid_m + ((pid % num_pid_in_group) % group_size_m)
64
+ pid_n = (pid % num_pid_in_group) // group_size_m
65
+ if pid_m > pid_n:
66
+ return
67
+
68
+ offs_xm = (pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)) % M
69
+ offs_xn = (pid_n * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)) % M
70
+ offs_k = tl.arange(0, BLOCK_SIZE_K)
71
+ # we use a & b ptrs to denote different rows of x.
72
+ a_ptrs = x + (offs_xm[:, None] * stride_xm + offs_k[None, :] * stride_xk)
73
+ b_ptrs = x + (offs_xn[:, None] * stride_xm + offs_k[None, :] * stride_xk)
74
+
75
+ accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_M), dtype=tl.float32)
76
+
77
+ for k in range(0, tl.cdiv(K, BLOCK_SIZE_K)):
78
+ a = tl.load(a_ptrs,
79
+ mask=offs_k[None, :] < K - k * BLOCK_SIZE_K,
80
+ other=0.0)
81
+ b = tl.load(b_ptrs,
82
+ mask=offs_k[None, :] < K - k * BLOCK_SIZE_K,
83
+ other=0.0)
84
+ accumulator = tl.dot(a, tl.permute(b, (1, 0)), accumulator)
85
+ a_ptrs += BLOCK_SIZE_K * stride_xk
86
+ b_ptrs += BLOCK_SIZE_K * stride_xk
87
+ # use dtype.element_ty to accommodate different input datatypes as in cpp templates
88
+ # https://github.com/triton-lang/triton/issues/2252
89
+ c = accumulator.to(x.dtype.element_ty)
90
+
91
+ offs_cm = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
92
+ offs_cn = pid_n * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
93
+ c_ptrs = y + stride_ym * offs_cm[:, None] + stride_yn * offs_cn[None, :]
94
+ c_mask = (offs_cm[:, None] < M) & (offs_cn[None, :] < M)
95
+ tl.store(c_ptrs, c, mask=c_mask)
96
+
97
+ # transpose and copy
98
+ if pid_m < pid_n:
99
+ ct_ptrs = y + stride_ym * offs_cn[:,
100
+ None] + stride_yn * offs_cm[None, :]
101
+ ct_mask = (offs_cn[:, None] < M) & (offs_cm[None, :] < M)
102
+ tl.store(ct_ptrs, tl.permute(c, (1, 0)), mask=ct_mask)
103
+
104
+
105
+ def matmul_transpose_assign(d_in, d_out):
106
+ assert d_in.is_cuda, "Input `d_in` must be a CUDA tensor"
107
+ assert d_out.is_cuda, "Input `d_out` must be a CUDA tensor"
108
+ assert d_in.device == d_out.device, "Inputs `d_in` and `d_out` must be on the same CUDA device"
109
+ assert d_in.dtype == d_out.dtype, "Inputs must have the same data type"
110
+ assert d_in.ndim == 2, "Input `d_in` must be a 2D tensor"
111
+ assert d_out.ndim == 2, "Input `d_out` must be a 2D tensor"
112
+ assert d_in.size(0) == d_out.size(0) == d_out.size(0), \
113
+ "First dimension of `d_in` must match first and second dimension of `d_out`"
114
+
115
+ d_in = d_in.contiguous()
116
+ M, K = d_in.shape
117
+ grid = lambda META: (triton.cdiv(M, META['BLOCK_SIZE_M']) * triton.cdiv(
118
+ M, META['BLOCK_SIZE_M']), )
119
+ with torch.cuda.device(d_in.device.index):
120
+ mmt_kernel[grid](d_in, d_out, M, K, d_in.stride(0), d_in.stride(1),
121
+ d_out.stride(0), d_out.stride(1))
122
+
123
+
124
+ def matmul_transpose(d_in):
125
+ M, _ = d_in.shape
126
+ d_out = torch.empty((M, M), device=d_in.device, dtype=d_in.dtype)
127
+ matmul_transpose_assign(d_in, d_out)
128
+ return d_out
build/torch28-cxx11-cu129-x86_64-linux/optimizer/muon.py CHANGED
@@ -8,14 +8,19 @@ import torch
8
  import torch.distributed as dist
9
  from torch.distributed._tensor import DTensor, Replicate, Shard
10
 
 
 
11
  logger = logging.getLogger(__name__)
12
 
 
 
13
 
14
  # This code snippet is a modified version adapted from the following GitHub repositories:
15
  # https://github.com/KellerJordan/Muon/blob/master/muon.py
16
  # Muon's Newton–Schulz iteration causes high variance in singular values
17
  # Idea: give each iteration its own 3 coefficients and optimize them via gradient descent.
18
  @torch.no_grad()
 
19
  def _zeropower_via_newtonschulz5(G, steps):
20
  """
21
  Newton-Schulz iteration to compute the zeroth power / orthogonalization of G. We opt to use a
@@ -27,13 +32,15 @@ def _zeropower_via_newtonschulz5(G, steps):
27
  performance at all relative to UV^T, where USV^T = G is the SVD.
28
  """
29
  assert len(G.shape) == 2
30
- assert G.dtype == torch.bfloat16
31
  X = G # no manual typecast
32
 
33
  if G.size(0) > G.size(1):
34
  X = X.T
35
  # Ensure spectral norm is at most 1
36
  X = X / (X.norm() + 1e-7)
 
 
37
  # Perform the NS iterations
38
  for a, b, c in [
39
  (4.0848, -6.8946, 2.9270),
@@ -42,13 +49,10 @@ def _zeropower_via_newtonschulz5(G, steps):
42
  (2.8769, -3.1427, 1.2046),
43
  (2.8366, -3.0525, 1.2012),
44
  ]:
45
- A = X @ X.T
46
- # B = (
47
- # b * A + c * A @ A
48
- # )
49
- B = torch.addmm(A, A, A, alpha=c, beta=b)
50
- # X = a * X + B @ X
51
- X = torch.addmm(X, B, X, alpha=1.0, beta=a)
52
 
53
  if G.size(0) > G.size(1):
54
  X = X.T
@@ -69,51 +73,142 @@ class _muon_state:
69
  qk_clip_state = None
70
 
71
 
 
 
 
 
 
 
 
 
72
  @torch.no_grad()
73
- def _gather(p, state, rank, comm_stream, none_grad):
74
  """
75
- Gather the gradients to worker_rank.
76
- If none_grad is True, free p.grad after the gather.
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
77
  """
78
  with torch.cuda.stream(comm_stream):
79
- g = p.grad
 
80
 
81
- if rank == state.worker_rank:
82
- num_ranks = dist.get_world_size(group=state.process_group)
83
- gather_list = [
84
- torch.empty_like(g.to_local(), dtype=torch.bfloat16)
85
- for _ in range(num_ranks)
86
- ]
87
- else:
88
- gather_list = None
89
-
90
- g = g.to(torch.bfloat16)
91
- torch.distributed.gather(
92
- g.to_local(),
93
- dst=state.worker_rank,
94
- gather_list=gather_list,
95
- group=state.process_group,
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
96
  )
97
- if rank == state.worker_rank:
98
- if state.gathered_grad is not None:
99
- raise RuntimeError(
100
- "Gather event already exists, which should not happen.")
101
- state.gathered_grad = torch.cat(gather_list, dim=0)
102
- state.gather_event = torch.cuda.Event()
103
- state.gather_event.record()
104
- else:
105
- state.gathered_grad = None
106
- state.gather_event = None
107
- gather_list = None
108
- if none_grad:
109
- # We can safely free p.grad without calling record_stream:
110
- # p.grad.to_local().record_stream(comm_stream)
111
- # Explanation:
112
- # 1. p.grad is created on the default stream, but the default stream
113
- # is synchronized with the comm stream later.
114
- # 2. There is no further activity on the default stream before the optimizer finishes.
115
- # Therefore, it is safe to free p.grad directly on the comm stream.
116
- p.grad = None
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
117
 
118
 
119
  @torch.no_grad()
@@ -127,45 +222,145 @@ def _compute_u(p, state, steps, rank, compute_stream):
127
  raise RuntimeError("Gather event must be set before compute.")
128
  compute_stream.wait_event(state.gather_event)
129
  u = _zeropower_via_newtonschulz5(state.gathered_grad, steps)
 
130
  state.computed_u = u
131
- state.scattered_u = torch.empty_like(p.to_local(),
132
- dtype=torch.bfloat16)
133
- state.compute_event = torch.cuda.Event()
134
- state.compute_event.record()
135
- u = None
136
 
137
 
138
  @torch.no_grad()
139
- def _scatter(p, state, rank, comm_stream):
140
  """
141
- Scatter the computed_u from worker_rank to all ranks.
 
142
  """
 
 
 
 
 
 
 
 
 
 
143
 
 
 
 
 
144
  with torch.cuda.stream(comm_stream):
145
- if state.compute_event is None:
146
- raise RuntimeError("Compute event must be set before scatter.")
147
- comm_stream.wait_event(state.compute_event)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
148
 
149
- if rank == state.worker_rank:
150
- num_ranks = dist.get_world_size(group=state.process_group)
151
- # Clear the gathered gradient to free memory
152
- state.gathered_grad = None
 
 
 
 
153
 
154
- u = state.computed_u
155
- scatter_list = list(torch.split(u, p.size(0) // num_ranks, dim=0))
156
- scatter_list = [s.contiguous() for s in scatter_list]
 
 
 
 
 
 
157
  else:
158
- scatter_list = None
 
 
159
 
160
- torch.distributed.scatter(
161
- state.scattered_u,
162
- scatter_list=scatter_list,
163
- src=state.worker_rank,
164
- group=state.process_group,
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
165
  )
166
- state.scatter_event = torch.cuda.Event()
167
- state.scatter_event.record()
168
- scatter_list = None
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
169
 
170
 
171
  def _update_param(p, state, lr, adjusted_lr, weight_decay, rank,
@@ -321,6 +516,11 @@ class Muon(torch.optim.Optimizer):
321
  "head_dim": 128,
322
  "threshold": 100
323
  }
 
 
 
 
 
324
  """
325
 
326
  def __init__(self,
@@ -339,7 +539,8 @@ class Muon(torch.optim.Optimizer):
339
  "k_indices": [],
340
  "head_dim": 128,
341
  "threshold": 100
342
- }):
 
343
  defaults = dict(
344
  lr=lr,
345
  weight_decay=weight_decay,
@@ -363,15 +564,13 @@ class Muon(torch.optim.Optimizer):
363
 
364
  super().__init__(params, defaults)
365
 
366
- if dist.is_initialized():
367
- self.rank = dist.get_rank()
368
- else:
369
- self.rank = None
370
 
371
  self.comm_stream = torch.cuda.Stream()
372
  self.compute_stream = torch.cuda.Stream()
373
  self.debug = debug
374
  self.clip_config = clip_config
 
375
 
376
  def _calc_flops(self, G, steps):
377
  assert len(G.shape) == 2
@@ -444,11 +643,18 @@ class Muon(torch.optim.Optimizer):
444
  if mesh is None:
445
  mesh = p.device_mesh
446
  shard_mesh, process_group = self.get_shard_mesh(p, self.rank)
 
 
 
 
 
447
  elif mesh != p.device_mesh:
448
  raise ValueError("All parameters must be on the same mesh.")
449
 
 
450
  param_to_state[id(p)] = _muon_state()
451
- param_to_state[id(p)].worker_rank = shard_mesh[round_robin].item()
 
452
  param_to_state[id(p)].process_group = process_group
453
  qk_clip_state = self.get_qk_clip_info(n, qk_logits)
454
  param_to_state[id(p)].qk_clip_state = qk_clip_state
@@ -478,7 +684,7 @@ class Muon(torch.optim.Optimizer):
478
  else:
479
  g = buf
480
 
481
- u = _zeropower_via_newtonschulz5(g.bfloat16(),
482
  steps=group["ns_steps"])
483
 
484
  adjusted_lr = self.adjust_lr_for_muon(lr, p.shape)
@@ -493,15 +699,12 @@ class Muon(torch.optim.Optimizer):
493
  def _update_g(self, p, g, group, momentum):
494
  # calc update
495
  state = self.state[p]
496
- if "momentum_buffer" not in state:
497
- state["momentum_buffer"] = torch.zeros_like(g)
498
- buf = state["momentum_buffer"]
499
- buf.mul_(momentum).add_(g)
500
  if group["nesterov"]:
501
- g = g.add(buf, alpha=momentum)
502
- else:
503
- g = buf
504
- return g
505
 
506
  @staticmethod
507
  def _update_p(p, u, lr, adjusted_lr, weight_decay):
@@ -585,11 +788,17 @@ class Muon(torch.optim.Optimizer):
585
  param_to_state, ordered_params = self.init_state_and_assign_params(
586
  names, params, group, qk_logits)
587
 
588
- def enqueue_gathers(start_idx, chunk_size):
589
- for p in ordered_params[start_idx:start_idx + chunk_size]:
590
- state = param_to_state[id(p)]
591
- _gather(p, state, self.rank, self.comm_stream,
592
- group["none_grad"])
 
 
 
 
 
 
593
 
594
  def enqueue_computes(start_idx, chunk_size):
595
  for p in ordered_params[start_idx:start_idx + chunk_size]:
@@ -597,10 +806,14 @@ class Muon(torch.optim.Optimizer):
597
  _compute_u(p, state, group["ns_steps"], self.rank,
598
  self.compute_stream)
599
 
600
- def enqueue_scatters(start_idx, chunk_size):
601
- for p in ordered_params[start_idx:start_idx + chunk_size]:
602
- state = param_to_state[id(p)]
603
- _scatter(p, state, self.rank, self.comm_stream)
 
 
 
 
604
 
605
  def enqueue_update_param(start_idx, chunk_size):
606
  for p in ordered_params[start_idx:start_idx + chunk_size]:
@@ -615,14 +828,16 @@ class Muon(torch.optim.Optimizer):
615
  # Wait grad update
616
  self.comm_stream.wait_stream(torch.cuda.current_stream())
617
 
618
- enqueue_gathers(0, chunk_size)
 
 
 
 
619
  for i in range(0, len(params) + chunk_size - 1, chunk_size):
620
- enqueue_computes(i, chunk_size)
621
- if i > 0:
622
- enqueue_update_param(i - chunk_size, chunk_size)
623
- enqueue_gathers(i + chunk_size, chunk_size)
624
- enqueue_scatters(i, chunk_size)
625
- enqueue_update_param(i, chunk_size)
626
 
627
  # Wait the last update_param to finish
628
  torch.cuda.current_stream().wait_stream(self.compute_stream)
 
8
  import torch.distributed as dist
9
  from torch.distributed._tensor import DTensor, Replicate, Shard
10
 
11
+ from .matmul_transpose_triton import matmul_transpose_assign
12
+
13
  logger = logging.getLogger(__name__)
14
 
15
+ COMM_DTYPE = torch.bfloat16
16
+
17
 
18
  # This code snippet is a modified version adapted from the following GitHub repositories:
19
  # https://github.com/KellerJordan/Muon/blob/master/muon.py
20
  # Muon's Newton–Schulz iteration causes high variance in singular values
21
  # Idea: give each iteration its own 3 coefficients and optimize them via gradient descent.
22
  @torch.no_grad()
23
+ # matmul_transpose_assign from : https://github.com/nil0x9/flash-muon
24
  def _zeropower_via_newtonschulz5(G, steps):
25
  """
26
  Newton-Schulz iteration to compute the zeroth power / orthogonalization of G. We opt to use a
 
32
  performance at all relative to UV^T, where USV^T = G is the SVD.
33
  """
34
  assert len(G.shape) == 2
35
+ assert G.dtype == COMM_DTYPE
36
  X = G # no manual typecast
37
 
38
  if G.size(0) > G.size(1):
39
  X = X.T
40
  # Ensure spectral norm is at most 1
41
  X = X / (X.norm() + 1e-7)
42
+ buf1 = torch.empty(X.size(0), X.size(0), dtype=X.dtype, device=X.device)
43
+ buf2 = torch.empty(X.size(0), X.size(0), dtype=X.dtype, device=X.device)
44
  # Perform the NS iterations
45
  for a, b, c in [
46
  (4.0848, -6.8946, 2.9270),
 
49
  (2.8769, -3.1427, 1.2046),
50
  (2.8366, -3.0525, 1.2012),
51
  ]:
52
+ matmul_transpose_assign(X, buf1)
53
+ matmul_transpose_assign(buf1, buf2)
54
+ buf1.mul_(b).add_(buf2, alpha=c)
55
+ X = torch.addmm(X, buf1, X, alpha=1.0, beta=a)
 
 
 
56
 
57
  if G.size(0) > G.size(1):
58
  X = X.T
 
73
  qk_clip_state = None
74
 
75
 
76
+ def split_elems_for_src(param, src_rank, num_ranks) -> int:
77
+ rows = param.shape[0]
78
+ cols = int(param.numel() // rows)
79
+ base, rem = divmod(rows, num_ranks)
80
+ my_rows = base + (1 if src_rank < rem else 0)
81
+ return my_rows * cols
82
+
83
+
84
  @torch.no_grad()
85
+ def _alloc_gathered_grad(params, param_to_state, rank, compute_stream):
86
  """
87
+ Pre-allocate gathered_grad buffer on compute_stream
88
+ before launching all2all gather
89
+ """
90
+ with torch.cuda.stream(compute_stream):
91
+ for p in params:
92
+ state = param_to_state[id(p)]
93
+ if rank == state.worker_rank:
94
+ num_ranks = dist.get_world_size(group=state.process_group)
95
+ state.gathered_grad = torch.empty(p.grad.numel(),
96
+ dtype=COMM_DTYPE,
97
+ device="cuda")
98
+ else:
99
+ state.gathered_grad = None
100
+
101
+ alloc_event = torch.cuda.Event()
102
+ alloc_event.record(compute_stream)
103
+ return alloc_event
104
+
105
+
106
+ @torch.no_grad()
107
+ def _all2all_gather(params, param_to_state, rank, comm_stream, none_grad,
108
+ alloc_event):
109
+ """
110
+ All2all gathers shards so each owner rank reconstructs its full gradient
111
  """
112
  with torch.cuda.stream(comm_stream):
113
+ process_group = param_to_state[id(params[0])].process_group
114
+ num_ranks = dist.get_world_size(group=process_group)
115
 
116
+ # Construct sending buffers
117
+ per_dst = [[] for _ in range(num_ranks)]
118
+ send_counts = [0] * num_ranks
119
+
120
+ for p in params:
121
+ state = param_to_state[id(p)]
122
+ dst = state.worker_rank
123
+ assert dst < num_ranks
124
+ shard_elems = split_elems_for_src(p, rank, num_ranks)
125
+ g = p.grad
126
+ g = g.to_local().to(COMM_DTYPE).contiguous().view(-1)
127
+ assert g.numel() == shard_elems
128
+ per_dst[dst].append(g)
129
+ send_counts[dst] += shard_elems
130
+
131
+ assert all(
132
+ len(v) > 0
133
+ for v in per_dst), "all params should be sharded to all devices"
134
+
135
+ send_buf = torch.cat([torch.cat(v, dim=0) for v in per_dst], dim=0)
136
+ owned_params = [
137
+ p for p in params if param_to_state[id(p)].worker_rank == rank
138
+ ]
139
+
140
+ # Compute receive sizes and allocate receiving buffers
141
+ recv_counts = [0] * num_ranks
142
+
143
+ for src in range(num_ranks):
144
+ total = 0
145
+ for p in owned_params:
146
+ state = param_to_state[id(p)]
147
+ assert state.worker_rank == rank
148
+ total += split_elems_for_src(p, src, num_ranks)
149
+ recv_counts[src] = total
150
+
151
+ recv_total = sum(recv_counts)
152
+ recv_buf = torch.empty(recv_total, dtype=COMM_DTYPE, device="cuda")
153
+
154
+ #All2All
155
+ dist.all_to_all_single(
156
+ recv_buf,
157
+ send_buf,
158
+ output_split_sizes=recv_counts,
159
+ input_split_sizes=send_counts,
160
+ group=process_group,
161
  )
162
+
163
+ # Reconstructs gathered grad from the received buffer
164
+ #
165
+ # recv_buf (num ranks = 3)
166
+ #
167
+ # From rank 0 From rank 1 From rank 2
168
+ # | p1_0, p2_0, p3_0 | p1_1, p2_1, p3_1 | p1_2, p2_2, p3_2 |
169
+ #
170
+ # Outer loop:
171
+ # rank 0 -> rank 1 -> rank2
172
+ #
173
+ # Inner loop:
174
+ # p1_n -> p2_n -> p3_n
175
+
176
+ comm_stream.wait_event(alloc_event)
177
+
178
+ off = 0
179
+ write_offsets = {id(p): 0 for p in owned_params}
180
+ for src in range(num_ranks):
181
+ if recv_counts[src] == 0:
182
+ continue
183
+
184
+ block = recv_counts[src]
185
+ inner_off = 0
186
+ for p in owned_params:
187
+ state = param_to_state[id(p)]
188
+ assert state.worker_rank == rank
189
+ n = split_elems_for_src(p, src, num_ranks)
190
+ assert n > 0
191
+
192
+ sg = recv_buf.narrow(0, off + inner_off, n)
193
+ woff = write_offsets[id(p)]
194
+ dst = state.gathered_grad.narrow(0, woff, n)
195
+ dst.copy_(sg)
196
+
197
+ write_offsets[id(p)] += n
198
+ inner_off += n
199
+ off += block
200
+
201
+ for p in params:
202
+ state = param_to_state[id(p)]
203
+ if state.worker_rank == rank:
204
+ state.gathered_grad = state.gathered_grad.view_as(p)
205
+ state.gather_event = torch.cuda.Event()
206
+ state.gather_event.record(comm_stream)
207
+ else:
208
+ state.gathered_grad = None
209
+ state.gather_event = None
210
+ if none_grad:
211
+ p.grad = None
212
 
213
 
214
  @torch.no_grad()
 
222
  raise RuntimeError("Gather event must be set before compute.")
223
  compute_stream.wait_event(state.gather_event)
224
  u = _zeropower_via_newtonschulz5(state.gathered_grad, steps)
225
+ state.gathered_grad = None
226
  state.computed_u = u
227
+ state.compute_event = torch.cuda.Event()
228
+ state.compute_event.record()
229
+ else:
230
+ state.computed_u = None
231
+ state.compute_event = None
232
 
233
 
234
  @torch.no_grad()
235
+ def _alloc_scattered_u(params, param_to_state, rank, compute_stream):
236
  """
237
+ Pre-allocate scattered_u buffer on compute_stream
238
+ before launching all2all gather
239
  """
240
+ with torch.cuda.stream(compute_stream):
241
+ for p in params:
242
+ state = param_to_state[id(p)]
243
+ state.scattered_u = torch.empty_like(p.to_local(),
244
+ dtype=COMM_DTYPE)
245
+
246
+ alloc_event = torch.cuda.Event()
247
+ alloc_event.record(compute_stream)
248
+ return alloc_event
249
+
250
 
251
+ def _all2all_scatter(params, param_to_state, rank, comm_stream, alloc_event):
252
+ """
253
+ All2all scatters full gradients to all ranks
254
+ """
255
  with torch.cuda.stream(comm_stream):
256
+ process_group = param_to_state[id(params[0])].process_group
257
+ num_ranks = dist.get_world_size(group=process_group)
258
+ owned_params = [
259
+ p for p in params if param_to_state[id(p)].worker_rank == rank
260
+ ]
261
+
262
+ # Construct sending buffer
263
+ per_dst = [[] for _ in range(num_ranks)]
264
+ send_counts = [0] * num_ranks
265
+
266
+ if owned_params:
267
+ for p in owned_params:
268
+ state = param_to_state[id(p)]
269
+ if state.compute_event is None:
270
+ raise RuntimeError(
271
+ "Compute event must be set before scatter.")
272
+ comm_stream.wait_event(state.compute_event)
273
+ state.gathered_grad = None
274
 
275
+ assert state.computed_u is not None
276
+
277
+ u_full = state.computed_u.to(COMM_DTYPE).contiguous().view(-1)
278
+
279
+ offset = 0
280
+ for dst in range(num_ranks):
281
+ n = split_elems_for_src(p, dst, num_ranks)
282
+ assert n > 0
283
 
284
+ su = u_full.narrow(0, offset, n)
285
+ per_dst[dst].append(su)
286
+ send_counts[dst] += n
287
+ offset += n
288
+
289
+ assert offset == u_full.numel()
290
+
291
+ if any(len(v) > 0 for v in per_dst):
292
+ send_buf = torch.cat([torch.cat(v, dim=0) for v in per_dst], dim=0)
293
  else:
294
+ # all_to_all requires participation from all ranks
295
+ # Even non-owner ranks must join the collective call
296
+ send_buf = torch.empty(0, dtype=COMM_DTYPE, device="cuda")
297
 
298
+ # Compute receive sizes and allocate receiving buffers
299
+ recv_counts = [0] * num_ranks
300
+
301
+ for src in range(num_ranks):
302
+ total = 0
303
+ for p in params:
304
+ state = param_to_state[id(p)]
305
+ if state.worker_rank != src:
306
+ continue
307
+ total += split_elems_for_src(p, rank, num_ranks)
308
+ recv_counts[src] = total
309
+
310
+ recv_total = sum(recv_counts)
311
+ assert recv_total > 0
312
+ recv_buf = torch.empty(recv_total, dtype=COMM_DTYPE, device="cuda")
313
+
314
+ #All2All
315
+ dist.all_to_all_single(
316
+ recv_buf,
317
+ send_buf,
318
+ output_split_sizes=recv_counts,
319
+ input_split_sizes=send_counts,
320
+ group=process_group,
321
  )
322
+
323
+ # Copy to pre-allocated scattered_u buffer from the received buffer
324
+ #
325
+ # recv_buf (num ranks = 3, local_rank = 0)
326
+ #
327
+ # From rank 0 From rank 1 From rank 2
328
+ # | p1_0, p2_0, p3_0 | p4_0 | p5_0, p6_0 |
329
+ #
330
+ # Outer loop:
331
+ # rank 0 -> rank 1 -> rank2
332
+ #
333
+ # Inner loop:
334
+ # src(0) : p1_0 -> p2_0 -> p3_0
335
+ # src(1) : p4_0
336
+ # src(2) : p5_0 -> p6_0
337
+
338
+ comm_stream.wait_event(alloc_event)
339
+
340
+ off = 0
341
+ for src in range(num_ranks):
342
+ block = recv_counts[src]
343
+ if block == 0:
344
+ continue
345
+
346
+ inner_off = 0
347
+ for p in params:
348
+ state = param_to_state[id(p)]
349
+ if state.worker_rank != src:
350
+ continue
351
+ n = split_elems_for_src(p, rank, num_ranks)
352
+ assert n > 0
353
+
354
+ flat_local = recv_buf.narrow(0, off + inner_off,
355
+ n).view_as(p.to_local())
356
+ state.scattered_u.copy_(flat_local)
357
+
358
+ state.scatter_event = torch.cuda.Event()
359
+ state.scatter_event.record(comm_stream)
360
+ inner_off += n
361
+
362
+ assert inner_off == block
363
+ off += block
364
 
365
 
366
  def _update_param(p, state, lr, adjusted_lr, weight_decay, rank,
 
516
  "head_dim": 128,
517
  "threshold": 100
518
  }
519
+ overlap_step : How many all2all gather, compute operations are launched in advance
520
+ before the corresponding all2all scatter steps begin.
521
+ A higher overlap_step increases memory usage but can improve
522
+ performance by overlapping communication.
523
+ Parallel muon only.
524
  """
525
 
526
  def __init__(self,
 
539
  "k_indices": [],
540
  "head_dim": 128,
541
  "threshold": 100
542
+ },
543
+ overlap_step=5):
544
  defaults = dict(
545
  lr=lr,
546
  weight_decay=weight_decay,
 
564
 
565
  super().__init__(params, defaults)
566
 
567
+ self.rank = None
 
 
 
568
 
569
  self.comm_stream = torch.cuda.Stream()
570
  self.compute_stream = torch.cuda.Stream()
571
  self.debug = debug
572
  self.clip_config = clip_config
573
+ self.overlap_step = overlap_step
574
 
575
  def _calc_flops(self, G, steps):
576
  assert len(G.shape) == 2
 
643
  if mesh is None:
644
  mesh = p.device_mesh
645
  shard_mesh, process_group = self.get_shard_mesh(p, self.rank)
646
+ local_rank = dist.get_rank(group=process_group)
647
+ if self.rank is None:
648
+ self.rank = dist.get_rank(group=process_group)
649
+ else:
650
+ assert self.rank == local_rank
651
  elif mesh != p.device_mesh:
652
  raise ValueError("All parameters must be on the same mesh.")
653
 
654
+ num_ranks = dist.get_world_size(group=process_group)
655
  param_to_state[id(p)] = _muon_state()
656
+ param_to_state[id(
657
+ p)].worker_rank = shard_mesh[round_robin].item() % num_ranks
658
  param_to_state[id(p)].process_group = process_group
659
  qk_clip_state = self.get_qk_clip_info(n, qk_logits)
660
  param_to_state[id(p)].qk_clip_state = qk_clip_state
 
684
  else:
685
  g = buf
686
 
687
+ u = _zeropower_via_newtonschulz5(g.to(COMM_DTYPE),
688
  steps=group["ns_steps"])
689
 
690
  adjusted_lr = self.adjust_lr_for_muon(lr, p.shape)
 
699
  def _update_g(self, p, g, group, momentum):
700
  # calc update
701
  state = self.state[p]
702
+ buf = state.setdefault("momentum_buffer", torch.zeros_like(g))
703
+ torch.add(g, buf, alpha=momentum, out=buf)
 
 
704
  if group["nesterov"]:
705
+ g.add_(buf, alpha=momentum)
706
+ return g
707
+ return buf
 
708
 
709
  @staticmethod
710
  def _update_p(p, u, lr, adjusted_lr, weight_decay):
 
788
  param_to_state, ordered_params = self.init_state_and_assign_params(
789
  names, params, group, qk_logits)
790
 
791
+ assert self.rank is not None
792
+
793
+ def enqueue_all2all_gather(start_idx, chunk_size):
794
+ target_params = ordered_params[start_idx:start_idx + chunk_size]
795
+ if target_params:
796
+ alloc_event = _alloc_gathered_grad(target_params,
797
+ param_to_state, self.rank,
798
+ self.compute_stream)
799
+ _all2all_gather(target_params, param_to_state, self.rank,
800
+ self.comm_stream, group["none_grad"],
801
+ alloc_event)
802
 
803
  def enqueue_computes(start_idx, chunk_size):
804
  for p in ordered_params[start_idx:start_idx + chunk_size]:
 
806
  _compute_u(p, state, group["ns_steps"], self.rank,
807
  self.compute_stream)
808
 
809
+ def enqueue_all2all_scatter(start_idx, chunk_size):
810
+ target_params = ordered_params[start_idx:start_idx + chunk_size]
811
+ if target_params:
812
+ alloc_event = _alloc_scattered_u(target_params, param_to_state,
813
+ self.rank,
814
+ self.compute_stream)
815
+ _all2all_scatter(target_params, param_to_state, self.rank,
816
+ self.comm_stream, alloc_event)
817
 
818
  def enqueue_update_param(start_idx, chunk_size):
819
  for p in ordered_params[start_idx:start_idx + chunk_size]:
 
828
  # Wait grad update
829
  self.comm_stream.wait_stream(torch.cuda.current_stream())
830
 
831
+ overlap_step = self.overlap_step
832
+ for i in range(0, overlap_step):
833
+ enqueue_all2all_gather(i * chunk_size, chunk_size)
834
+ enqueue_computes(i * chunk_size, chunk_size)
835
+
836
  for i in range(0, len(params) + chunk_size - 1, chunk_size):
837
+ enqueue_all2all_scatter(i, chunk_size)
838
+ enqueue_all2all_gather(i + overlap_step * chunk_size, chunk_size)
839
+ enqueue_update_param(i, chunk_size)
840
+ enqueue_computes(i + overlap_step * chunk_size, chunk_size)
 
 
841
 
842
  # Wait the last update_param to finish
843
  torch.cuda.current_stream().wait_stream(self.compute_stream)
build/torch28-cxx11-rocm63-x86_64-linux/optimizer/_ops.py CHANGED
@@ -1,9 +1,9 @@
1
  import torch
2
- from . import _optimizer_9c21645_dirty
3
- ops = torch.ops._optimizer_9c21645_dirty
4
 
5
  def add_op_namespace_prefix(op_name: str):
6
  """
7
  Prefix op by namespace.
8
  """
9
- return f"_optimizer_9c21645_dirty::{op_name}"
 
1
  import torch
2
+ from . import _optimizer_15336dc_dirty
3
+ ops = torch.ops._optimizer_15336dc_dirty
4
 
5
  def add_op_namespace_prefix(op_name: str):
6
  """
7
  Prefix op by namespace.
8
  """
9
+ return f"_optimizer_15336dc_dirty::{op_name}"
build/torch28-cxx11-rocm63-x86_64-linux/optimizer/{_optimizer_9c21645_dirty.abi3.so → _optimizer_15336dc_dirty.abi3.so} RENAMED
@@ -1,3 +1,3 @@
1
  version https://git-lfs.github.com/spec/v1
2
- oid sha256:d8f845b8df6426eb5db57e4525b8dd3c80004c44759b01a3e39cc37a817813b5
3
  size 1749936
 
1
  version https://git-lfs.github.com/spec/v1
2
+ oid sha256:ae22a3afdffd54435c6e5b145fc0b7772d03eb8c8bad0d388d9b2d1c8d2f60d5
3
  size 1749936
build/torch28-cxx11-rocm63-x86_64-linux/optimizer/matmul_transpose_triton.py ADDED
@@ -0,0 +1,128 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # MIT License
2
+ #
3
+ # Copyright (c) 2025 Tianyang Lin
4
+ #
5
+ # Permission is hereby granted, free of charge, to any person obtaining a copy
6
+ # of this software and associated documentation files (the "Software"), to deal
7
+ # in the Software without restriction, including without limitation the rights
8
+ # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9
+ # copies of the Software, and to permit persons to whom the Software is
10
+ # furnished to do so, subject to the following conditions:
11
+ #
12
+ # The above copyright notice and this permission notice shall be included in all
13
+ # copies or substantial portions of the Software.
14
+ #
15
+ # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16
+ # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17
+ # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18
+ # AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19
+ # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20
+ # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21
+ # SOFTWARE.
22
+
23
+ import torch
24
+ import triton
25
+ import triton.language as tl
26
+
27
+
28
def get_autotune_config():
    """
    Enumerate the candidate triton launch configurations for autotuning.

    Sweeps BLOCK_SIZE_M x BLOCK_SIZE_K x num_stages x num_warps with a
    fixed GROUP_SIZE_M of 8; the enumeration order matches the original
    nested-comprehension order.
    """
    configs = []
    for blk_m in (32, 64, 128):
        for blk_k in (32, 64):
            for grp_sz in (8,):
                for n_stages in (3, 4, 5):
                    for n_warps in (4, 8):
                        configs.append(
                            triton.Config(
                                {
                                    'BLOCK_SIZE_M': blk_m,
                                    'BLOCK_SIZE_K': blk_k,
                                    'GROUP_SIZE_M': grp_sz
                                },
                                num_stages=n_stages,
                                num_warps=n_warps))
    return configs
41
+
42
+
43
@triton.autotune(
    configs=get_autotune_config(),
    key=['M', 'K'],
)
@triton.jit
def mmt_kernel(x, y, M, K, stride_xm, stride_xk, stride_ym, stride_yn,
               BLOCK_SIZE_M: tl.constexpr, BLOCK_SIZE_K: tl.constexpr,
               GROUP_SIZE_M: tl.constexpr):
    """
    Core kernel jit function of matmul_transpose that computes y = x @ x.T
    The code is a simple adaptation from the triton `matmul` tutorial:
    https://triton-lang.org/main/getting-started/tutorials/03-matrix-multiplication.html

    x is (M, K); y is (M, M). Because the result is symmetric, only the
    upper triangle of output tiles (pid_m <= pid_n) is computed; the
    lower triangle is filled by transposing each finished tile.
    """
    pid = tl.program_id(axis=0)
    # Output is M x M, so both tile-grid extents use BLOCK_SIZE_M.
    num_pid_m = tl.cdiv(M, BLOCK_SIZE_M)
    num_pid_n = tl.cdiv(M, BLOCK_SIZE_M)
    # Grouped ordering of program ids (L2-locality trick from the tutorial).
    num_pid_in_group = GROUP_SIZE_M * num_pid_n
    group_id = pid // num_pid_in_group
    first_pid_m = group_id * GROUP_SIZE_M
    group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M)
    pid_m = first_pid_m + ((pid % num_pid_in_group) % group_size_m)
    pid_n = (pid % num_pid_in_group) // group_size_m
    # Skip strictly-lower-triangle tiles; they are produced by the
    # transpose-store of the mirrored upper tile below.
    if pid_m > pid_n:
        return

    offs_xm = (pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)) % M
    offs_xn = (pid_n * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)) % M
    offs_k = tl.arange(0, BLOCK_SIZE_K)
    # we use a & b ptrs to denote different rows of x.
    a_ptrs = x + (offs_xm[:, None] * stride_xm + offs_k[None, :] * stride_xk)
    b_ptrs = x + (offs_xn[:, None] * stride_xm + offs_k[None, :] * stride_xk)

    # Accumulate in fp32 regardless of input dtype.
    accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_M), dtype=tl.float32)

    for k in range(0, tl.cdiv(K, BLOCK_SIZE_K)):
        # Mask out-of-range K columns on the last iteration.
        a = tl.load(a_ptrs,
                    mask=offs_k[None, :] < K - k * BLOCK_SIZE_K,
                    other=0.0)
        b = tl.load(b_ptrs,
                    mask=offs_k[None, :] < K - k * BLOCK_SIZE_K,
                    other=0.0)
        # a @ b.T accumulated into fp32.
        accumulator = tl.dot(a, tl.permute(b, (1, 0)), accumulator)
        a_ptrs += BLOCK_SIZE_K * stride_xk
        b_ptrs += BLOCK_SIZE_K * stride_xk
    # use dtype.element_ty to accommodate different input datatypes as in cpp templates
    # https://github.com/triton-lang/triton/issues/2252
    c = accumulator.to(x.dtype.element_ty)

    offs_cm = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
    offs_cn = pid_n * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
    c_ptrs = y + stride_ym * offs_cm[:, None] + stride_yn * offs_cn[None, :]
    c_mask = (offs_cm[:, None] < M) & (offs_cn[None, :] < M)
    tl.store(c_ptrs, c, mask=c_mask)

    # transpose and copy: mirror the tile into the lower triangle
    # (diagonal tiles pid_m == pid_n are stored only once above).
    if pid_m < pid_n:
        ct_ptrs = y + stride_ym * offs_cn[:,
                                          None] + stride_yn * offs_cm[None, :]
        ct_mask = (offs_cn[:, None] < M) & (offs_cm[None, :] < M)
        tl.store(ct_ptrs, tl.permute(c, (1, 0)), mask=ct_mask)
103
+
104
+
105
def matmul_transpose_assign(d_in, d_out):
    """
    Compute ``d_out = d_in @ d_in.T`` in place via the triton kernel.

    Args:
        d_in: 2D CUDA tensor of shape (M, K).
        d_out: 2D CUDA tensor of shape (M, M), same dtype/device as d_in.

    Raises:
        AssertionError: if device/dtype/shape preconditions are violated.
    """
    assert d_in.is_cuda, "Input `d_in` must be a CUDA tensor"
    assert d_out.is_cuda, "Input `d_out` must be a CUDA tensor"
    assert d_in.device == d_out.device, "Inputs `d_in` and `d_out` must be on the same CUDA device"
    assert d_in.dtype == d_out.dtype, "Inputs must have the same data type"
    assert d_in.ndim == 2, "Input `d_in` must be a 2D tensor"
    assert d_out.ndim == 2, "Input `d_out` must be a 2D tensor"
    # Bug fix: the original compared d_out.size(0) to itself twice, so a
    # non-square d_out passed the check; verify both output dimensions.
    assert d_in.size(0) == d_out.size(0) == d_out.size(1), \
        "First dimension of `d_in` must match first and second dimension of `d_out`"

    # Strides below are taken from the (possibly copied) contiguous tensor.
    d_in = d_in.contiguous()
    M, K = d_in.shape
    # One program per (M-tile, M-tile) pair of the square output.
    grid = lambda META: (triton.cdiv(M, META['BLOCK_SIZE_M']) * triton.cdiv(
        M, META['BLOCK_SIZE_M']), )
    with torch.cuda.device(d_in.device.index):
        mmt_kernel[grid](d_in, d_out, M, K, d_in.stride(0), d_in.stride(1),
                         d_out.stride(0), d_out.stride(1))
122
+
123
+
124
def matmul_transpose(d_in):
    """
    Return a freshly allocated ``d_in @ d_in.T``.

    Allocates an (M, M) tensor on d_in's device/dtype and delegates the
    actual computation to :func:`matmul_transpose_assign`.
    """
    rows = d_in.shape[0]
    result = torch.empty((rows, rows), device=d_in.device, dtype=d_in.dtype)
    matmul_transpose_assign(d_in, result)
    return result
build/torch28-cxx11-rocm63-x86_64-linux/optimizer/muon.py CHANGED
@@ -8,14 +8,19 @@ import torch
8
  import torch.distributed as dist
9
  from torch.distributed._tensor import DTensor, Replicate, Shard
10
 
 
 
11
  logger = logging.getLogger(__name__)
12
 
 
 
13
 
14
  # This code snippet is a modified version adapted from the following GitHub repositories:
15
  # https://github.com/KellerJordan/Muon/blob/master/muon.py
16
  # Muon's Newton–Schulz iteration causes high variance in singular values
17
  # Idea: give each iteration its own 3 coefficients and optimize them via gradient descent.
18
  @torch.no_grad()
 
19
  def _zeropower_via_newtonschulz5(G, steps):
20
  """
21
  Newton-Schulz iteration to compute the zeroth power / orthogonalization of G. We opt to use a
@@ -27,13 +32,15 @@ def _zeropower_via_newtonschulz5(G, steps):
27
  performance at all relative to UV^T, where USV^T = G is the SVD.
28
  """
29
  assert len(G.shape) == 2
30
- assert G.dtype == torch.bfloat16
31
  X = G # no manual typecast
32
 
33
  if G.size(0) > G.size(1):
34
  X = X.T
35
  # Ensure spectral norm is at most 1
36
  X = X / (X.norm() + 1e-7)
 
 
37
  # Perform the NS iterations
38
  for a, b, c in [
39
  (4.0848, -6.8946, 2.9270),
@@ -42,13 +49,10 @@ def _zeropower_via_newtonschulz5(G, steps):
42
  (2.8769, -3.1427, 1.2046),
43
  (2.8366, -3.0525, 1.2012),
44
  ]:
45
- A = X @ X.T
46
- # B = (
47
- # b * A + c * A @ A
48
- # )
49
- B = torch.addmm(A, A, A, alpha=c, beta=b)
50
- # X = a * X + B @ X
51
- X = torch.addmm(X, B, X, alpha=1.0, beta=a)
52
 
53
  if G.size(0) > G.size(1):
54
  X = X.T
@@ -69,51 +73,142 @@ class _muon_state:
69
  qk_clip_state = None
70
 
71
 
 
 
 
 
 
 
 
 
72
  @torch.no_grad()
73
- def _gather(p, state, rank, comm_stream, none_grad):
74
  """
75
- Gather the gradients to worker_rank.
76
- If none_grad is True, free p.grad after the gather.
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
77
  """
78
  with torch.cuda.stream(comm_stream):
79
- g = p.grad
 
80
 
81
- if rank == state.worker_rank:
82
- num_ranks = dist.get_world_size(group=state.process_group)
83
- gather_list = [
84
- torch.empty_like(g.to_local(), dtype=torch.bfloat16)
85
- for _ in range(num_ranks)
86
- ]
87
- else:
88
- gather_list = None
89
-
90
- g = g.to(torch.bfloat16)
91
- torch.distributed.gather(
92
- g.to_local(),
93
- dst=state.worker_rank,
94
- gather_list=gather_list,
95
- group=state.process_group,
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
96
  )
97
- if rank == state.worker_rank:
98
- if state.gathered_grad is not None:
99
- raise RuntimeError(
100
- "Gather event already exists, which should not happen.")
101
- state.gathered_grad = torch.cat(gather_list, dim=0)
102
- state.gather_event = torch.cuda.Event()
103
- state.gather_event.record()
104
- else:
105
- state.gathered_grad = None
106
- state.gather_event = None
107
- gather_list = None
108
- if none_grad:
109
- # We can safely free p.grad without calling record_stream:
110
- # p.grad.to_local().record_stream(comm_stream)
111
- # Explanation:
112
- # 1. p.grad is created on the default stream, but the default stream
113
- # is synchronized with the comm stream later.
114
- # 2. There is no further activity on the default stream before the optimizer finishes.
115
- # Therefore, it is safe to free p.grad directly on the comm stream.
116
- p.grad = None
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
117
 
118
 
119
  @torch.no_grad()
@@ -127,45 +222,145 @@ def _compute_u(p, state, steps, rank, compute_stream):
127
  raise RuntimeError("Gather event must be set before compute.")
128
  compute_stream.wait_event(state.gather_event)
129
  u = _zeropower_via_newtonschulz5(state.gathered_grad, steps)
 
130
  state.computed_u = u
131
- state.scattered_u = torch.empty_like(p.to_local(),
132
- dtype=torch.bfloat16)
133
- state.compute_event = torch.cuda.Event()
134
- state.compute_event.record()
135
- u = None
136
 
137
 
138
  @torch.no_grad()
139
- def _scatter(p, state, rank, comm_stream):
140
  """
141
- Scatter the computed_u from worker_rank to all ranks.
 
142
  """
 
 
 
 
 
 
 
 
 
 
143
 
 
 
 
 
144
  with torch.cuda.stream(comm_stream):
145
- if state.compute_event is None:
146
- raise RuntimeError("Compute event must be set before scatter.")
147
- comm_stream.wait_event(state.compute_event)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
148
 
149
- if rank == state.worker_rank:
150
- num_ranks = dist.get_world_size(group=state.process_group)
151
- # Clear the gathered gradient to free memory
152
- state.gathered_grad = None
 
 
 
 
153
 
154
- u = state.computed_u
155
- scatter_list = list(torch.split(u, p.size(0) // num_ranks, dim=0))
156
- scatter_list = [s.contiguous() for s in scatter_list]
 
 
 
 
 
 
157
  else:
158
- scatter_list = None
 
 
159
 
160
- torch.distributed.scatter(
161
- state.scattered_u,
162
- scatter_list=scatter_list,
163
- src=state.worker_rank,
164
- group=state.process_group,
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
165
  )
166
- state.scatter_event = torch.cuda.Event()
167
- state.scatter_event.record()
168
- scatter_list = None
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
169
 
170
 
171
  def _update_param(p, state, lr, adjusted_lr, weight_decay, rank,
@@ -321,6 +516,11 @@ class Muon(torch.optim.Optimizer):
321
  "head_dim": 128,
322
  "threshold": 100
323
  }
 
 
 
 
 
324
  """
325
 
326
  def __init__(self,
@@ -339,7 +539,8 @@ class Muon(torch.optim.Optimizer):
339
  "k_indices": [],
340
  "head_dim": 128,
341
  "threshold": 100
342
- }):
 
343
  defaults = dict(
344
  lr=lr,
345
  weight_decay=weight_decay,
@@ -363,15 +564,13 @@ class Muon(torch.optim.Optimizer):
363
 
364
  super().__init__(params, defaults)
365
 
366
- if dist.is_initialized():
367
- self.rank = dist.get_rank()
368
- else:
369
- self.rank = None
370
 
371
  self.comm_stream = torch.cuda.Stream()
372
  self.compute_stream = torch.cuda.Stream()
373
  self.debug = debug
374
  self.clip_config = clip_config
 
375
 
376
  def _calc_flops(self, G, steps):
377
  assert len(G.shape) == 2
@@ -444,11 +643,18 @@ class Muon(torch.optim.Optimizer):
444
  if mesh is None:
445
  mesh = p.device_mesh
446
  shard_mesh, process_group = self.get_shard_mesh(p, self.rank)
 
 
 
 
 
447
  elif mesh != p.device_mesh:
448
  raise ValueError("All parameters must be on the same mesh.")
449
 
 
450
  param_to_state[id(p)] = _muon_state()
451
- param_to_state[id(p)].worker_rank = shard_mesh[round_robin].item()
 
452
  param_to_state[id(p)].process_group = process_group
453
  qk_clip_state = self.get_qk_clip_info(n, qk_logits)
454
  param_to_state[id(p)].qk_clip_state = qk_clip_state
@@ -478,7 +684,7 @@ class Muon(torch.optim.Optimizer):
478
  else:
479
  g = buf
480
 
481
- u = _zeropower_via_newtonschulz5(g.bfloat16(),
482
  steps=group["ns_steps"])
483
 
484
  adjusted_lr = self.adjust_lr_for_muon(lr, p.shape)
@@ -493,15 +699,12 @@ class Muon(torch.optim.Optimizer):
493
  def _update_g(self, p, g, group, momentum):
494
  # calc update
495
  state = self.state[p]
496
- if "momentum_buffer" not in state:
497
- state["momentum_buffer"] = torch.zeros_like(g)
498
- buf = state["momentum_buffer"]
499
- buf.mul_(momentum).add_(g)
500
  if group["nesterov"]:
501
- g = g.add(buf, alpha=momentum)
502
- else:
503
- g = buf
504
- return g
505
 
506
  @staticmethod
507
  def _update_p(p, u, lr, adjusted_lr, weight_decay):
@@ -585,11 +788,17 @@ class Muon(torch.optim.Optimizer):
585
  param_to_state, ordered_params = self.init_state_and_assign_params(
586
  names, params, group, qk_logits)
587
 
588
- def enqueue_gathers(start_idx, chunk_size):
589
- for p in ordered_params[start_idx:start_idx + chunk_size]:
590
- state = param_to_state[id(p)]
591
- _gather(p, state, self.rank, self.comm_stream,
592
- group["none_grad"])
 
 
 
 
 
 
593
 
594
  def enqueue_computes(start_idx, chunk_size):
595
  for p in ordered_params[start_idx:start_idx + chunk_size]:
@@ -597,10 +806,14 @@ class Muon(torch.optim.Optimizer):
597
  _compute_u(p, state, group["ns_steps"], self.rank,
598
  self.compute_stream)
599
 
600
- def enqueue_scatters(start_idx, chunk_size):
601
- for p in ordered_params[start_idx:start_idx + chunk_size]:
602
- state = param_to_state[id(p)]
603
- _scatter(p, state, self.rank, self.comm_stream)
 
 
 
 
604
 
605
  def enqueue_update_param(start_idx, chunk_size):
606
  for p in ordered_params[start_idx:start_idx + chunk_size]:
@@ -615,14 +828,16 @@ class Muon(torch.optim.Optimizer):
615
  # Wait grad update
616
  self.comm_stream.wait_stream(torch.cuda.current_stream())
617
 
618
- enqueue_gathers(0, chunk_size)
 
 
 
 
619
  for i in range(0, len(params) + chunk_size - 1, chunk_size):
620
- enqueue_computes(i, chunk_size)
621
- if i > 0:
622
- enqueue_update_param(i - chunk_size, chunk_size)
623
- enqueue_gathers(i + chunk_size, chunk_size)
624
- enqueue_scatters(i, chunk_size)
625
- enqueue_update_param(i, chunk_size)
626
 
627
  # Wait the last update_param to finish
628
  torch.cuda.current_stream().wait_stream(self.compute_stream)
 
8
  import torch.distributed as dist
9
  from torch.distributed._tensor import DTensor, Replicate, Shard
10
 
11
+ from .matmul_transpose_triton import matmul_transpose_assign
12
+
13
  logger = logging.getLogger(__name__)
14
 
15
+ COMM_DTYPE = torch.bfloat16
16
+
17
 
18
  # This code snippet is a modified version adapted from the following GitHub repositories:
19
  # https://github.com/KellerJordan/Muon/blob/master/muon.py
20
  # Muon's Newton–Schulz iteration causes high variance in singular values
21
  # Idea: give each iteration its own 3 coefficients and optimize them via gradient descent.
22
  @torch.no_grad()
23
+ # matmul_transpose_assign from : https://github.com/nil0x9/flash-muon
24
  def _zeropower_via_newtonschulz5(G, steps):
25
  """
26
  Newton-Schulz iteration to compute the zeroth power / orthogonalization of G. We opt to use a
 
32
  performance at all relative to UV^T, where USV^T = G is the SVD.
33
  """
34
  assert len(G.shape) == 2
35
+ assert G.dtype == COMM_DTYPE
36
  X = G # no manual typecast
37
 
38
  if G.size(0) > G.size(1):
39
  X = X.T
40
  # Ensure spectral norm is at most 1
41
  X = X / (X.norm() + 1e-7)
42
+ buf1 = torch.empty(X.size(0), X.size(0), dtype=X.dtype, device=X.device)
43
+ buf2 = torch.empty(X.size(0), X.size(0), dtype=X.dtype, device=X.device)
44
  # Perform the NS iterations
45
  for a, b, c in [
46
  (4.0848, -6.8946, 2.9270),
 
49
  (2.8769, -3.1427, 1.2046),
50
  (2.8366, -3.0525, 1.2012),
51
  ]:
52
+ matmul_transpose_assign(X, buf1)
53
+ matmul_transpose_assign(buf1, buf2)
54
+ buf1.mul_(b).add_(buf2, alpha=c)
55
+ X = torch.addmm(X, buf1, X, alpha=1.0, beta=a)
 
 
 
56
 
57
  if G.size(0) > G.size(1):
58
  X = X.T
 
73
  qk_clip_state = None
74
 
75
 
76
def split_elems_for_src(param, src_rank, num_ranks) -> int:
    """
    Size (in elements) of the row-shard of ``param`` held by ``src_rank``.

    Rows are split as evenly as possible across ``num_ranks``; the first
    ``rows % num_ranks`` ranks receive one extra row each. All trailing
    dimensions stay whole, so the result is shard_rows * elems_per_row.
    """
    total_rows = param.shape[0]
    elems_per_row = int(param.numel() // total_rows)
    rows_each, extra = divmod(total_rows, num_ranks)
    shard_rows = rows_each + 1 if src_rank < extra else rows_each
    return shard_rows * elems_per_row
82
+
83
+
84
@torch.no_grad()
def _alloc_gathered_grad(params, param_to_state, rank, compute_stream):
    """
    Pre-allocate each owned param's flat ``gathered_grad`` buffer on
    compute_stream before launching the all2all gather.

    Only the owner rank (``state.worker_rank``) allocates; other ranks
    clear the field. Returns a CUDA event recorded on compute_stream so
    the comm stream can wait until the buffers exist before writing.
    """
    with torch.cuda.stream(compute_stream):
        for p in params:
            state = param_to_state[id(p)]
            if rank == state.worker_rank:
                # Flat COMM_DTYPE buffer; reshaped to p's shape after the
                # all2all completes. (Removed an unused dist.get_world_size
                # call the original made here.)
                # NOTE(review): allocates on the current CUDA device —
                # assumes one process per device; confirm for multi-GPU.
                state.gathered_grad = torch.empty(p.grad.numel(),
                                                  dtype=COMM_DTYPE,
                                                  device="cuda")
            else:
                state.gathered_grad = None

    alloc_event = torch.cuda.Event()
    alloc_event.record(compute_stream)
    return alloc_event
104
+
105
+
106
@torch.no_grad()
def _all2all_gather(params, param_to_state, rank, comm_stream, none_grad,
                    alloc_event):
    """
    All2all gathers shards so each owner rank reconstructs its full gradient.

    Every rank sends its local shard of each param's gradient to that
    param's owner (``state.worker_rank``); owners receive one shard per
    source rank and copy them into the pre-allocated flat
    ``state.gathered_grad`` buffer (see ``_alloc_gathered_grad``).
    If ``none_grad`` is True, ``p.grad`` is freed after its shard is sent.
    All work is enqueued on ``comm_stream``.
    """
    with torch.cuda.stream(comm_stream):
        # All params in one chunk are assumed to share a process group.
        process_group = param_to_state[id(params[0])].process_group
        num_ranks = dist.get_world_size(group=process_group)

        # Construct sending buffers
        # per_dst[d] collects this rank's local shards destined for rank d.
        per_dst = [[] for _ in range(num_ranks)]
        send_counts = [0] * num_ranks

        for p in params:
            state = param_to_state[id(p)]
            dst = state.worker_rank
            assert dst < num_ranks
            # Expected element count of this rank's shard of p.
            shard_elems = split_elems_for_src(p, rank, num_ranks)
            g = p.grad
            # Flatten the local DTensor shard in the communication dtype.
            g = g.to_local().to(COMM_DTYPE).contiguous().view(-1)
            assert g.numel() == shard_elems
            per_dst[dst].append(g)
            send_counts[dst] += shard_elems

        assert all(
            len(v) > 0
            for v in per_dst), "all params should be sharded to all devices"

        # Concatenate per-destination groups into one flat send buffer,
        # ordered by destination rank (all_to_all_single contract).
        send_buf = torch.cat([torch.cat(v, dim=0) for v in per_dst], dim=0)
        owned_params = [
            p for p in params if param_to_state[id(p)].worker_rank == rank
        ]

        # Compute receive sizes and allocate receiving buffers
        recv_counts = [0] * num_ranks

        for src in range(num_ranks):
            total = 0
            for p in owned_params:
                state = param_to_state[id(p)]
                assert state.worker_rank == rank
                # Size of src's shard of each param we own.
                total += split_elems_for_src(p, src, num_ranks)
            recv_counts[src] = total

        recv_total = sum(recv_counts)
        recv_buf = torch.empty(recv_total, dtype=COMM_DTYPE, device="cuda")

        # All2All
        dist.all_to_all_single(
            recv_buf,
            send_buf,
            output_split_sizes=recv_counts,
            input_split_sizes=send_counts,
            group=process_group,
        )

        # Reconstructs gathered grad from the received buffer
        #
        # recv_buf (num ranks = 3)
        #
        #   From rank 0        From rank 1        From rank 2
        # | p1_0, p2_0, p3_0 | p1_1, p2_1, p3_1 | p1_2, p2_2, p3_2 |
        #
        # Outer loop:
        #   rank 0 -> rank 1 -> rank2
        #
        # Inner loop:
        #   p1_n -> p2_n -> p3_n

        # Destination buffers are allocated on the compute stream; do not
        # write into them before that allocation is done.
        comm_stream.wait_event(alloc_event)

        off = 0
        # Per-param write cursor into its flat gathered_grad buffer.
        write_offsets = {id(p): 0 for p in owned_params}
        for src in range(num_ranks):
            if recv_counts[src] == 0:
                continue

            block = recv_counts[src]
            inner_off = 0
            for p in owned_params:
                state = param_to_state[id(p)]
                assert state.worker_rank == rank
                n = split_elems_for_src(p, src, num_ranks)
                assert n > 0

                sg = recv_buf.narrow(0, off + inner_off, n)
                woff = write_offsets[id(p)]
                # NOTE: `dst` here is a tensor view, shadowing the rank
                # index name used earlier in this function.
                dst = state.gathered_grad.narrow(0, woff, n)
                dst.copy_(sg)

                write_offsets[id(p)] += n
                inner_off += n
            off += block

        for p in params:
            state = param_to_state[id(p)]
            if state.worker_rank == rank:
                # Give the flat buffer the param's logical shape and
                # publish completion for the compute stream to wait on.
                state.gathered_grad = state.gathered_grad.view_as(p)
                state.gather_event = torch.cuda.Event()
                state.gather_event.record(comm_stream)
            else:
                state.gathered_grad = None
                state.gather_event = None
            if none_grad:
                # Safe to free here: the default stream is synchronized
                # with comm_stream before any further use (see caller).
                p.grad = None
212
 
213
 
214
  @torch.no_grad()
 
222
  raise RuntimeError("Gather event must be set before compute.")
223
  compute_stream.wait_event(state.gather_event)
224
  u = _zeropower_via_newtonschulz5(state.gathered_grad, steps)
225
+ state.gathered_grad = None
226
  state.computed_u = u
227
+ state.compute_event = torch.cuda.Event()
228
+ state.compute_event.record()
229
+ else:
230
+ state.computed_u = None
231
+ state.compute_event = None
232
 
233
 
234
@torch.no_grad()
def _alloc_scattered_u(params, param_to_state, rank, compute_stream):
    """
    Pre-allocate scattered_u buffer on compute_stream
    before launching all2all scatter.

    (Doc fix: the original docstring said "all2all gather"; this buffer
    receives each rank's shard of the orthogonalized update during the
    all2all *scatter*.) Returns a CUDA event recorded on compute_stream
    so the comm stream can wait for the allocations.
    """
    with torch.cuda.stream(compute_stream):
        for p in params:
            state = param_to_state[id(p)]
            # One local-shard-shaped buffer per param, in the comm dtype.
            state.scattered_u = torch.empty_like(p.to_local(),
                                                 dtype=COMM_DTYPE)

    alloc_event = torch.cuda.Event()
    alloc_event.record(compute_stream)
    return alloc_event
249
+
250
 
251
def _all2all_scatter(params, param_to_state, rank, comm_stream, alloc_event):
    """
    All2all scatters the computed orthogonalized updates (``computed_u``)
    back to all ranks: each owner rank splits every update it computed
    into per-rank row shards and sends them; every rank receives its own
    shard of every param into the pre-allocated ``scattered_u`` buffer.

    NOTE(review): unlike the gather path this function is not decorated
    with @torch.no_grad() — confirm whether that is intentional.
    """
    with torch.cuda.stream(comm_stream):
        process_group = param_to_state[id(params[0])].process_group
        num_ranks = dist.get_world_size(group=process_group)
        owned_params = [
            p for p in params if param_to_state[id(p)].worker_rank == rank
        ]

        # Construct sending buffer
        per_dst = [[] for _ in range(num_ranks)]
        send_counts = [0] * num_ranks

        if owned_params:
            for p in owned_params:
                state = param_to_state[id(p)]
                if state.compute_event is None:
                    raise RuntimeError(
                        "Compute event must be set before scatter.")
                # Do not read computed_u before Newton-Schulz finished.
                comm_stream.wait_event(state.compute_event)
                # Free the full gradient; it is no longer needed.
                state.gathered_grad = None

                assert state.computed_u is not None

                u_full = state.computed_u.to(COMM_DTYPE).contiguous().view(-1)

                # Slice the flat update into per-destination row shards,
                # in destination-rank order.
                offset = 0
                for dst in range(num_ranks):
                    n = split_elems_for_src(p, dst, num_ranks)
                    assert n > 0

                    su = u_full.narrow(0, offset, n)
                    per_dst[dst].append(su)
                    send_counts[dst] += n
                    offset += n

                assert offset == u_full.numel()

        if any(len(v) > 0 for v in per_dst):
            send_buf = torch.cat([torch.cat(v, dim=0) for v in per_dst], dim=0)
        else:
            # all_to_all requires participation from all ranks
            # Even non-owner ranks must join the collective call
            send_buf = torch.empty(0, dtype=COMM_DTYPE, device="cuda")

        # Compute receive sizes and allocate receiving buffers
        recv_counts = [0] * num_ranks

        for src in range(num_ranks):
            total = 0
            for p in params:
                state = param_to_state[id(p)]
                if state.worker_rank != src:
                    continue
                # We receive our own shard of each param owned by src.
                total += split_elems_for_src(p, rank, num_ranks)
            recv_counts[src] = total

        recv_total = sum(recv_counts)
        assert recv_total > 0
        recv_buf = torch.empty(recv_total, dtype=COMM_DTYPE, device="cuda")

        # All2All
        dist.all_to_all_single(
            recv_buf,
            send_buf,
            output_split_sizes=recv_counts,
            input_split_sizes=send_counts,
            group=process_group,
        )

        # Copy to pre-allocated scattered_u buffer from the received buffer
        #
        # recv_buf (num ranks = 3, local_rank = 0)
        #
        #   From rank 0        From rank 1   From rank 2
        # | p1_0, p2_0, p3_0 |     p4_0    | p5_0, p6_0 |
        #
        # Outer loop:
        #   rank 0 -> rank 1 -> rank2
        #
        # Inner loop:
        #   src(0) : p1_0 -> p2_0 -> p3_0
        #   src(1) : p4_0
        #   src(2) : p5_0 -> p6_0

        # scattered_u buffers are allocated on the compute stream; wait
        # for the allocation before writing into them.
        comm_stream.wait_event(alloc_event)

        off = 0
        for src in range(num_ranks):
            block = recv_counts[src]
            if block == 0:
                continue

            inner_off = 0
            for p in params:
                state = param_to_state[id(p)]
                if state.worker_rank != src:
                    continue
                n = split_elems_for_src(p, rank, num_ranks)
                assert n > 0

                flat_local = recv_buf.narrow(0, off + inner_off,
                                             n).view_as(p.to_local())
                state.scattered_u.copy_(flat_local)

                # Publish per-param completion for _update_param to wait on.
                state.scatter_event = torch.cuda.Event()
                state.scatter_event.record(comm_stream)
                inner_off += n

            assert inner_off == block
            off += block
 
365
 
366
  def _update_param(p, state, lr, adjusted_lr, weight_decay, rank,
 
516
  "head_dim": 128,
517
  "threshold": 100
518
  }
519
+ overlap_step : How many all2all gather, compute operations are launched in advance
520
+ before the corresponding all2all scatter steps begin.
521
+ A higher overlap_step increases memory usage but can improve
522
+ performance by overlapping communication.
523
+ Parallel muon only.
524
  """
525
 
526
  def __init__(self,
 
539
  "k_indices": [],
540
  "head_dim": 128,
541
  "threshold": 100
542
+ },
543
+ overlap_step=5):
544
  defaults = dict(
545
  lr=lr,
546
  weight_decay=weight_decay,
 
564
 
565
  super().__init__(params, defaults)
566
 
567
+ self.rank = None
 
 
 
568
 
569
  self.comm_stream = torch.cuda.Stream()
570
  self.compute_stream = torch.cuda.Stream()
571
  self.debug = debug
572
  self.clip_config = clip_config
573
+ self.overlap_step = overlap_step
574
 
575
  def _calc_flops(self, G, steps):
576
  assert len(G.shape) == 2
 
643
  if mesh is None:
644
  mesh = p.device_mesh
645
  shard_mesh, process_group = self.get_shard_mesh(p, self.rank)
646
+ local_rank = dist.get_rank(group=process_group)
647
+ if self.rank is None:
648
+ self.rank = dist.get_rank(group=process_group)
649
+ else:
650
+ assert self.rank == local_rank
651
  elif mesh != p.device_mesh:
652
  raise ValueError("All parameters must be on the same mesh.")
653
 
654
+ num_ranks = dist.get_world_size(group=process_group)
655
  param_to_state[id(p)] = _muon_state()
656
+ param_to_state[id(
657
+ p)].worker_rank = shard_mesh[round_robin].item() % num_ranks
658
  param_to_state[id(p)].process_group = process_group
659
  qk_clip_state = self.get_qk_clip_info(n, qk_logits)
660
  param_to_state[id(p)].qk_clip_state = qk_clip_state
 
684
  else:
685
  g = buf
686
 
687
+ u = _zeropower_via_newtonschulz5(g.to(COMM_DTYPE),
688
  steps=group["ns_steps"])
689
 
690
  adjusted_lr = self.adjust_lr_for_muon(lr, p.shape)
 
699
  def _update_g(self, p, g, group, momentum):
700
  # calc update
701
  state = self.state[p]
702
+ buf = state.setdefault("momentum_buffer", torch.zeros_like(g))
703
+ torch.add(g, buf, alpha=momentum, out=buf)
 
 
704
  if group["nesterov"]:
705
+ g.add_(buf, alpha=momentum)
706
+ return g
707
+ return buf
 
708
 
709
  @staticmethod
710
  def _update_p(p, u, lr, adjusted_lr, weight_decay):
 
788
  param_to_state, ordered_params = self.init_state_and_assign_params(
789
  names, params, group, qk_logits)
790
 
791
+ assert self.rank is not None
792
+
793
+ def enqueue_all2all_gather(start_idx, chunk_size):
794
+ target_params = ordered_params[start_idx:start_idx + chunk_size]
795
+ if target_params:
796
+ alloc_event = _alloc_gathered_grad(target_params,
797
+ param_to_state, self.rank,
798
+ self.compute_stream)
799
+ _all2all_gather(target_params, param_to_state, self.rank,
800
+ self.comm_stream, group["none_grad"],
801
+ alloc_event)
802
 
803
  def enqueue_computes(start_idx, chunk_size):
804
  for p in ordered_params[start_idx:start_idx + chunk_size]:
 
806
  _compute_u(p, state, group["ns_steps"], self.rank,
807
  self.compute_stream)
808
 
809
+ def enqueue_all2all_scatter(start_idx, chunk_size):
810
+ target_params = ordered_params[start_idx:start_idx + chunk_size]
811
+ if target_params:
812
+ alloc_event = _alloc_scattered_u(target_params, param_to_state,
813
+ self.rank,
814
+ self.compute_stream)
815
+ _all2all_scatter(target_params, param_to_state, self.rank,
816
+ self.comm_stream, alloc_event)
817
 
818
  def enqueue_update_param(start_idx, chunk_size):
819
  for p in ordered_params[start_idx:start_idx + chunk_size]:
 
828
  # Wait grad update
829
  self.comm_stream.wait_stream(torch.cuda.current_stream())
830
 
831
+ overlap_step = self.overlap_step
832
+ for i in range(0, overlap_step):
833
+ enqueue_all2all_gather(i * chunk_size, chunk_size)
834
+ enqueue_computes(i * chunk_size, chunk_size)
835
+
836
  for i in range(0, len(params) + chunk_size - 1, chunk_size):
837
+ enqueue_all2all_scatter(i, chunk_size)
838
+ enqueue_all2all_gather(i + overlap_step * chunk_size, chunk_size)
839
+ enqueue_update_param(i, chunk_size)
840
+ enqueue_computes(i + overlap_step * chunk_size, chunk_size)
 
 
841
 
842
  # Wait the last update_param to finish
843
  torch.cuda.current_stream().wait_stream(self.compute_stream)
build/torch28-cxx11-rocm64-x86_64-linux/optimizer/_ops.py CHANGED
@@ -1,9 +1,9 @@
1
  import torch
2
- from . import _optimizer_9c21645_dirty
3
- ops = torch.ops._optimizer_9c21645_dirty
4
 
5
  def add_op_namespace_prefix(op_name: str):
6
  """
7
  Prefix op by namespace.
8
  """
9
- return f"_optimizer_9c21645_dirty::{op_name}"
 
1
  import torch
2
+ from . import _optimizer_15336dc_dirty
3
+ ops = torch.ops._optimizer_15336dc_dirty
4
 
5
  def add_op_namespace_prefix(op_name: str):
6
  """
7
  Prefix op by namespace.
8
  """
9
+ return f"_optimizer_15336dc_dirty::{op_name}"
build/torch28-cxx11-rocm64-x86_64-linux/optimizer/{_optimizer_9c21645_dirty.abi3.so → _optimizer_15336dc_dirty.abi3.so} RENAMED
@@ -1,3 +1,3 @@
1
  version https://git-lfs.github.com/spec/v1
2
- oid sha256:9a477575e3cc30e54d355b3e778240dc25fb0dab30362f3540dc5f925ac03ba1
3
  size 1750024
 
1
  version https://git-lfs.github.com/spec/v1
2
+ oid sha256:8092bc6ee3e353b2188f0874bc7f145e4eafd0366a40da9750c225732961f7c7
3
  size 1750024
build/torch28-cxx11-rocm64-x86_64-linux/optimizer/matmul_transpose_triton.py ADDED
@@ -0,0 +1,128 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # MIT License
2
+ #
3
+ # Copyright (c) 2025 Tianyang Lin
4
+ #
5
+ # Permission is hereby granted, free of charge, to any person obtaining a copy
6
+ # of this software and associated documentation files (the "Software"), to deal
7
+ # in the Software without restriction, including without limitation the rights
8
+ # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9
+ # copies of the Software, and to permit persons to whom the Software is
10
+ # furnished to do so, subject to the following conditions:
11
+ #
12
+ # The above copyright notice and this permission notice shall be included in all
13
+ # copies or substantial portions of the Software.
14
+ #
15
+ # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16
+ # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17
+ # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18
+ # AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19
+ # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20
+ # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21
+ # SOFTWARE.
22
+
23
+ import torch
24
+ import triton
25
+ import triton.language as tl
26
+
27
+
28
def get_autotune_config():
    """Build the candidate Triton config sweep for autotuning ``mmt_kernel``.

    Sweeps BLOCK_SIZE_M x BLOCK_SIZE_K x num_stages x num_warps while
    keeping GROUP_SIZE_M fixed at 8 (the only value in the original sweep).
    """
    candidates = []
    for block_m in (32, 64, 128):
        for block_k in (32, 64):
            for stages in (3, 4, 5):
                for warps in (4, 8):
                    candidates.append(
                        triton.Config(
                            {
                                'BLOCK_SIZE_M': block_m,
                                'BLOCK_SIZE_K': block_k,
                                'GROUP_SIZE_M': 8,
                            },
                            num_stages=stages,
                            num_warps=warps))
    return candidates
41
+
42
+
43
@triton.autotune(
    configs=get_autotune_config(),
    key=['M', 'K'],
)
@triton.jit
def mmt_kernel(x, y, M, K, stride_xm, stride_xk, stride_ym, stride_yn,
               BLOCK_SIZE_M: tl.constexpr, BLOCK_SIZE_K: tl.constexpr,
               GROUP_SIZE_M: tl.constexpr):
    """
    Core kernel jit function of matmul_transpose that computes y = x @ x.T
    The code is a simple adaptation from the triton `matmul` tutorial:
    https://triton-lang.org/main/getting-started/tutorials/03-matrix-multiplication.html

    Since y is symmetric, only tiles with pid_m <= pid_n are computed;
    each off-diagonal tile is stored twice (once transposed) to fill the
    mirror position.
    """
    pid = tl.program_id(axis=0)
    num_pid_m = tl.cdiv(M, BLOCK_SIZE_M)
    # Output is square (M x M), so the n-axis uses the same tiling as m.
    num_pid_n = tl.cdiv(M, BLOCK_SIZE_M)
    num_pid_in_group = GROUP_SIZE_M * num_pid_n
    group_id = pid // num_pid_in_group
    first_pid_m = group_id * GROUP_SIZE_M
    group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M)
    pid_m = first_pid_m + ((pid % num_pid_in_group) % group_size_m)
    pid_n = (pid % num_pid_in_group) // group_size_m
    # Strictly-lower-triangular tiles are produced by the transposed store
    # of the mirrored upper tile below, so these programs exit early.
    if pid_m > pid_n:
        return

    offs_xm = (pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)) % M
    offs_xn = (pid_n * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)) % M
    offs_k = tl.arange(0, BLOCK_SIZE_K)
    # we use a & b ptrs to denote different rows of x.
    a_ptrs = x + (offs_xm[:, None] * stride_xm + offs_k[None, :] * stride_xk)
    b_ptrs = x + (offs_xn[:, None] * stride_xm + offs_k[None, :] * stride_xk)

    # Accumulate in fp32 regardless of the input dtype for accuracy.
    accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_M), dtype=tl.float32)

    # March along K in BLOCK_SIZE_K steps; the mask zero-pads the ragged tail.
    for k in range(0, tl.cdiv(K, BLOCK_SIZE_K)):
        a = tl.load(a_ptrs,
                    mask=offs_k[None, :] < K - k * BLOCK_SIZE_K,
                    other=0.0)
        b = tl.load(b_ptrs,
                    mask=offs_k[None, :] < K - k * BLOCK_SIZE_K,
                    other=0.0)
        accumulator = tl.dot(a, tl.permute(b, (1, 0)), accumulator)
        a_ptrs += BLOCK_SIZE_K * stride_xk
        b_ptrs += BLOCK_SIZE_K * stride_xk
    # use dtype.element_ty to accommodate different input datatypes as in cpp templates
    # https://github.com/triton-lang/triton/issues/2252
    c = accumulator.to(x.dtype.element_ty)

    offs_cm = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
    offs_cn = pid_n * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
    c_ptrs = y + stride_ym * offs_cm[:, None] + stride_yn * offs_cn[None, :]
    c_mask = (offs_cm[:, None] < M) & (offs_cn[None, :] < M)
    tl.store(c_ptrs, c, mask=c_mask)

    # transpose and copy: mirror the tile into the lower triangle
    # (diagonal tiles, pid_m == pid_n, are stored only once above).
    if pid_m < pid_n:
        ct_ptrs = y + stride_ym * offs_cn[:,
                                          None] + stride_yn * offs_cm[None, :]
        ct_mask = (offs_cn[:, None] < M) & (offs_cm[None, :] < M)
        tl.store(ct_ptrs, tl.permute(c, (1, 0)), mask=ct_mask)
103
+
104
+
105
def matmul_transpose_assign(d_in, d_out):
    """Compute ``d_out = d_in @ d_in.T`` in place with the Triton kernel.

    Args:
        d_in: 2-D CUDA tensor of shape (M, K).
        d_out: 2-D CUDA tensor of shape (M, M) that receives the symmetric
            product; must be on the same device and have the same dtype as
            ``d_in``.
    """
    assert d_in.is_cuda, "Input `d_in` must be a CUDA tensor"
    assert d_out.is_cuda, "Input `d_out` must be a CUDA tensor"
    assert d_in.device == d_out.device, "Inputs `d_in` and `d_out` must be on the same CUDA device"
    assert d_in.dtype == d_out.dtype, "Inputs must have the same data type"
    assert d_in.ndim == 2, "Input `d_in` must be a 2D tensor"
    assert d_out.ndim == 2, "Input `d_out` must be a 2D tensor"
    # Fix: the original asserted d_out.size(0) twice and never validated
    # d_out's second dimension; the output must be square M x M.
    assert d_in.size(0) == d_out.size(0) == d_out.size(1), \
        "First dimension of `d_in` must match first and second dimension of `d_out`"

    # Make the K dimension unit-stride for the kernel (no-op if already so).
    d_in = d_in.contiguous()
    M, K = d_in.shape
    # One program per output tile pair; the kernel itself skips the lower
    # triangle, so roughly half the programs exit immediately.
    grid = lambda META: (triton.cdiv(M, META['BLOCK_SIZE_M']) * triton.cdiv(
        M, META['BLOCK_SIZE_M']), )
    # Pin the launch to the input's device so multi-GPU callers are safe.
    with torch.cuda.device(d_in.device.index):
        mmt_kernel[grid](d_in, d_out, M, K, d_in.stride(0), d_in.stride(1),
                         d_out.stride(0), d_out.stride(1))
122
+
123
+
124
def matmul_transpose(d_in):
    """Return a newly allocated (M, M) tensor holding ``d_in @ d_in.T``."""
    rows = d_in.shape[0]
    result = torch.empty((rows, rows), dtype=d_in.dtype, device=d_in.device)
    matmul_transpose_assign(d_in, result)
    return result
build/torch28-cxx11-rocm64-x86_64-linux/optimizer/muon.py CHANGED
@@ -8,14 +8,19 @@ import torch
8
  import torch.distributed as dist
9
  from torch.distributed._tensor import DTensor, Replicate, Shard
10
 
 
 
11
  logger = logging.getLogger(__name__)
12
 
 
 
13
 
14
  # This code snippet is a modified version adapted from the following GitHub repositories:
15
  # https://github.com/KellerJordan/Muon/blob/master/muon.py
16
  # Muon's Newton–Schulz iteration causes high variance in singular values
17
  # Idea: give each iteration its own 3 coefficients and optimize them via gradient descent.
18
  @torch.no_grad()
 
19
  def _zeropower_via_newtonschulz5(G, steps):
20
  """
21
  Newton-Schulz iteration to compute the zeroth power / orthogonalization of G. We opt to use a
@@ -27,13 +32,15 @@ def _zeropower_via_newtonschulz5(G, steps):
27
  performance at all relative to UV^T, where USV^T = G is the SVD.
28
  """
29
  assert len(G.shape) == 2
30
- assert G.dtype == torch.bfloat16
31
  X = G # no manual typecast
32
 
33
  if G.size(0) > G.size(1):
34
  X = X.T
35
  # Ensure spectral norm is at most 1
36
  X = X / (X.norm() + 1e-7)
 
 
37
  # Perform the NS iterations
38
  for a, b, c in [
39
  (4.0848, -6.8946, 2.9270),
@@ -42,13 +49,10 @@ def _zeropower_via_newtonschulz5(G, steps):
42
  (2.8769, -3.1427, 1.2046),
43
  (2.8366, -3.0525, 1.2012),
44
  ]:
45
- A = X @ X.T
46
- # B = (
47
- # b * A + c * A @ A
48
- # )
49
- B = torch.addmm(A, A, A, alpha=c, beta=b)
50
- # X = a * X + B @ X
51
- X = torch.addmm(X, B, X, alpha=1.0, beta=a)
52
 
53
  if G.size(0) > G.size(1):
54
  X = X.T
@@ -69,51 +73,142 @@ class _muon_state:
69
  qk_clip_state = None
70
 
71
 
 
 
 
 
 
 
 
 
72
  @torch.no_grad()
73
- def _gather(p, state, rank, comm_stream, none_grad):
74
  """
75
- Gather the gradients to worker_rank.
76
- If none_grad is True, free p.grad after the gather.
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
77
  """
78
  with torch.cuda.stream(comm_stream):
79
- g = p.grad
 
80
 
81
- if rank == state.worker_rank:
82
- num_ranks = dist.get_world_size(group=state.process_group)
83
- gather_list = [
84
- torch.empty_like(g.to_local(), dtype=torch.bfloat16)
85
- for _ in range(num_ranks)
86
- ]
87
- else:
88
- gather_list = None
89
-
90
- g = g.to(torch.bfloat16)
91
- torch.distributed.gather(
92
- g.to_local(),
93
- dst=state.worker_rank,
94
- gather_list=gather_list,
95
- group=state.process_group,
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
96
  )
97
- if rank == state.worker_rank:
98
- if state.gathered_grad is not None:
99
- raise RuntimeError(
100
- "Gather event already exists, which should not happen.")
101
- state.gathered_grad = torch.cat(gather_list, dim=0)
102
- state.gather_event = torch.cuda.Event()
103
- state.gather_event.record()
104
- else:
105
- state.gathered_grad = None
106
- state.gather_event = None
107
- gather_list = None
108
- if none_grad:
109
- # We can safely free p.grad without calling record_stream:
110
- # p.grad.to_local().record_stream(comm_stream)
111
- # Explanation:
112
- # 1. p.grad is created on the default stream, but the default stream
113
- # is synchronized with the comm stream later.
114
- # 2. There is no further activity on the default stream before the optimizer finishes.
115
- # Therefore, it is safe to free p.grad directly on the comm stream.
116
- p.grad = None
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
117
 
118
 
119
  @torch.no_grad()
@@ -127,45 +222,145 @@ def _compute_u(p, state, steps, rank, compute_stream):
127
  raise RuntimeError("Gather event must be set before compute.")
128
  compute_stream.wait_event(state.gather_event)
129
  u = _zeropower_via_newtonschulz5(state.gathered_grad, steps)
 
130
  state.computed_u = u
131
- state.scattered_u = torch.empty_like(p.to_local(),
132
- dtype=torch.bfloat16)
133
- state.compute_event = torch.cuda.Event()
134
- state.compute_event.record()
135
- u = None
136
 
137
 
138
  @torch.no_grad()
139
- def _scatter(p, state, rank, comm_stream):
140
  """
141
- Scatter the computed_u from worker_rank to all ranks.
 
142
  """
 
 
 
 
 
 
 
 
 
 
143
 
 
 
 
 
144
  with torch.cuda.stream(comm_stream):
145
- if state.compute_event is None:
146
- raise RuntimeError("Compute event must be set before scatter.")
147
- comm_stream.wait_event(state.compute_event)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
148
 
149
- if rank == state.worker_rank:
150
- num_ranks = dist.get_world_size(group=state.process_group)
151
- # Clear the gathered gradient to free memory
152
- state.gathered_grad = None
 
 
 
 
153
 
154
- u = state.computed_u
155
- scatter_list = list(torch.split(u, p.size(0) // num_ranks, dim=0))
156
- scatter_list = [s.contiguous() for s in scatter_list]
 
 
 
 
 
 
157
  else:
158
- scatter_list = None
 
 
159
 
160
- torch.distributed.scatter(
161
- state.scattered_u,
162
- scatter_list=scatter_list,
163
- src=state.worker_rank,
164
- group=state.process_group,
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
165
  )
166
- state.scatter_event = torch.cuda.Event()
167
- state.scatter_event.record()
168
- scatter_list = None
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
169
 
170
 
171
  def _update_param(p, state, lr, adjusted_lr, weight_decay, rank,
@@ -321,6 +516,11 @@ class Muon(torch.optim.Optimizer):
321
  "head_dim": 128,
322
  "threshold": 100
323
  }
 
 
 
 
 
324
  """
325
 
326
  def __init__(self,
@@ -339,7 +539,8 @@ class Muon(torch.optim.Optimizer):
339
  "k_indices": [],
340
  "head_dim": 128,
341
  "threshold": 100
342
- }):
 
343
  defaults = dict(
344
  lr=lr,
345
  weight_decay=weight_decay,
@@ -363,15 +564,13 @@ class Muon(torch.optim.Optimizer):
363
 
364
  super().__init__(params, defaults)
365
 
366
- if dist.is_initialized():
367
- self.rank = dist.get_rank()
368
- else:
369
- self.rank = None
370
 
371
  self.comm_stream = torch.cuda.Stream()
372
  self.compute_stream = torch.cuda.Stream()
373
  self.debug = debug
374
  self.clip_config = clip_config
 
375
 
376
  def _calc_flops(self, G, steps):
377
  assert len(G.shape) == 2
@@ -444,11 +643,18 @@ class Muon(torch.optim.Optimizer):
444
  if mesh is None:
445
  mesh = p.device_mesh
446
  shard_mesh, process_group = self.get_shard_mesh(p, self.rank)
 
 
 
 
 
447
  elif mesh != p.device_mesh:
448
  raise ValueError("All parameters must be on the same mesh.")
449
 
 
450
  param_to_state[id(p)] = _muon_state()
451
- param_to_state[id(p)].worker_rank = shard_mesh[round_robin].item()
 
452
  param_to_state[id(p)].process_group = process_group
453
  qk_clip_state = self.get_qk_clip_info(n, qk_logits)
454
  param_to_state[id(p)].qk_clip_state = qk_clip_state
@@ -478,7 +684,7 @@ class Muon(torch.optim.Optimizer):
478
  else:
479
  g = buf
480
 
481
- u = _zeropower_via_newtonschulz5(g.bfloat16(),
482
  steps=group["ns_steps"])
483
 
484
  adjusted_lr = self.adjust_lr_for_muon(lr, p.shape)
@@ -493,15 +699,12 @@ class Muon(torch.optim.Optimizer):
493
  def _update_g(self, p, g, group, momentum):
494
  # calc update
495
  state = self.state[p]
496
- if "momentum_buffer" not in state:
497
- state["momentum_buffer"] = torch.zeros_like(g)
498
- buf = state["momentum_buffer"]
499
- buf.mul_(momentum).add_(g)
500
  if group["nesterov"]:
501
- g = g.add(buf, alpha=momentum)
502
- else:
503
- g = buf
504
- return g
505
 
506
  @staticmethod
507
  def _update_p(p, u, lr, adjusted_lr, weight_decay):
@@ -585,11 +788,17 @@ class Muon(torch.optim.Optimizer):
585
  param_to_state, ordered_params = self.init_state_and_assign_params(
586
  names, params, group, qk_logits)
587
 
588
- def enqueue_gathers(start_idx, chunk_size):
589
- for p in ordered_params[start_idx:start_idx + chunk_size]:
590
- state = param_to_state[id(p)]
591
- _gather(p, state, self.rank, self.comm_stream,
592
- group["none_grad"])
 
 
 
 
 
 
593
 
594
  def enqueue_computes(start_idx, chunk_size):
595
  for p in ordered_params[start_idx:start_idx + chunk_size]:
@@ -597,10 +806,14 @@ class Muon(torch.optim.Optimizer):
597
  _compute_u(p, state, group["ns_steps"], self.rank,
598
  self.compute_stream)
599
 
600
- def enqueue_scatters(start_idx, chunk_size):
601
- for p in ordered_params[start_idx:start_idx + chunk_size]:
602
- state = param_to_state[id(p)]
603
- _scatter(p, state, self.rank, self.comm_stream)
 
 
 
 
604
 
605
  def enqueue_update_param(start_idx, chunk_size):
606
  for p in ordered_params[start_idx:start_idx + chunk_size]:
@@ -615,14 +828,16 @@ class Muon(torch.optim.Optimizer):
615
  # Wait grad update
616
  self.comm_stream.wait_stream(torch.cuda.current_stream())
617
 
618
- enqueue_gathers(0, chunk_size)
 
 
 
 
619
  for i in range(0, len(params) + chunk_size - 1, chunk_size):
620
- enqueue_computes(i, chunk_size)
621
- if i > 0:
622
- enqueue_update_param(i - chunk_size, chunk_size)
623
- enqueue_gathers(i + chunk_size, chunk_size)
624
- enqueue_scatters(i, chunk_size)
625
- enqueue_update_param(i, chunk_size)
626
 
627
  # Wait the last update_param to finish
628
  torch.cuda.current_stream().wait_stream(self.compute_stream)
 
8
  import torch.distributed as dist
9
  from torch.distributed._tensor import DTensor, Replicate, Shard
10
 
11
+ from .matmul_transpose_triton import matmul_transpose_assign
12
+
13
  logger = logging.getLogger(__name__)
14
 
15
+ COMM_DTYPE = torch.bfloat16
16
+
17
 
18
  # This code snippet is a modified version adapted from the following GitHub repositories:
19
  # https://github.com/KellerJordan/Muon/blob/master/muon.py
20
  # Muon's Newton–Schulz iteration causes high variance in singular values
21
  # Idea: give each iteration its own 3 coefficients and optimize them via gradient descent.
22
  @torch.no_grad()
23
+ # matmul_transpose_assign from : https://github.com/nil0x9/flash-muon
24
  def _zeropower_via_newtonschulz5(G, steps):
25
  """
26
  Newton-Schulz iteration to compute the zeroth power / orthogonalization of G. We opt to use a
 
32
  performance at all relative to UV^T, where USV^T = G is the SVD.
33
  """
34
  assert len(G.shape) == 2
35
+ assert G.dtype == COMM_DTYPE
36
  X = G # no manual typecast
37
 
38
  if G.size(0) > G.size(1):
39
  X = X.T
40
  # Ensure spectral norm is at most 1
41
  X = X / (X.norm() + 1e-7)
42
+ buf1 = torch.empty(X.size(0), X.size(0), dtype=X.dtype, device=X.device)
43
+ buf2 = torch.empty(X.size(0), X.size(0), dtype=X.dtype, device=X.device)
44
  # Perform the NS iterations
45
  for a, b, c in [
46
  (4.0848, -6.8946, 2.9270),
 
49
  (2.8769, -3.1427, 1.2046),
50
  (2.8366, -3.0525, 1.2012),
51
  ]:
52
+ matmul_transpose_assign(X, buf1)
53
+ matmul_transpose_assign(buf1, buf2)
54
+ buf1.mul_(b).add_(buf2, alpha=c)
55
+ X = torch.addmm(X, buf1, X, alpha=1.0, beta=a)
 
 
 
56
 
57
  if G.size(0) > G.size(1):
58
  X = X.T
 
73
  qk_clip_state = None
74
 
75
 
76
def split_elems_for_src(param, src_rank, num_ranks) -> int:
    """Return how many elements of ``param`` rank ``src_rank`` owns.

    Rows are distributed as evenly as possible across ``num_ranks``; the
    first ``rows % num_ranks`` ranks each take one extra row. The result
    is that rank's row count times the elements per row.
    """
    total_rows = param.shape[0]
    elems_per_row = int(param.numel() // total_rows)
    base_rows, leftover = divmod(total_rows, num_ranks)
    owned_rows = base_rows + 1 if src_rank < leftover else base_rows
    return owned_rows * elems_per_row
82
+
83
+
84
@torch.no_grad()
def _alloc_gathered_grad(params, param_to_state, rank, compute_stream):
    """Pre-allocate flat ``gathered_grad`` buffers on ``compute_stream``.

    Only the owner rank (``state.worker_rank``) of each parameter allocates
    a buffer sized to hold the full flattened gradient in COMM_DTYPE;
    non-owners get ``None``. Returns a CUDA event recorded after the
    allocations so the comm stream can wait on it before the all2all
    gather writes into these buffers.

    Fix: removed the unused ``num_ranks = dist.get_world_size(...)`` local
    — its value was never read.
    """
    with torch.cuda.stream(compute_stream):
        for p in params:
            state = param_to_state[id(p)]
            if rank == state.worker_rank:
                # Flat buffer; reshaped with view_as(p) once the gather
                # has deposited every shard.
                state.gathered_grad = torch.empty(p.grad.numel(),
                                                  dtype=COMM_DTYPE,
                                                  device="cuda")
            else:
                state.gathered_grad = None

    alloc_event = torch.cuda.Event()
    alloc_event.record(compute_stream)
    return alloc_event
104
+
105
+
106
@torch.no_grad()
def _all2all_gather(params, param_to_state, rank, comm_stream, none_grad,
                    alloc_event):
    """
    All2all gathers shards so each owner rank reconstructs its full gradient

    Every rank sends its local shard of each parameter's gradient to that
    parameter's owner (``state.worker_rank``). After the collective, owners
    hold the full flattened gradient in ``state.gathered_grad`` (reshaped
    to the parameter's shape) and ``state.gather_event`` is recorded on
    ``comm_stream``; non-owners get None for both.

    Args:
        params: chunk of parameters; all must share one process group.
        param_to_state: maps id(param) -> _muon_state.
        rank: this process's rank inside the shard process group.
        comm_stream: CUDA stream carrying all communication work.
        none_grad: if True, drop p.grad after its shard has been sent.
        alloc_event: recorded by _alloc_gathered_grad after the destination
            buffers exist; waited on before writing into them.
    """
    with torch.cuda.stream(comm_stream):
        # All params in a chunk share one process group (see caller).
        process_group = param_to_state[id(params[0])].process_group
        num_ranks = dist.get_world_size(group=process_group)

        # Construct sending buffers
        # per_dst[d] collects this rank's shards destined for rank d.
        per_dst = [[] for _ in range(num_ranks)]
        send_counts = [0] * num_ranks

        for p in params:
            state = param_to_state[id(p)]
            dst = state.worker_rank
            assert dst < num_ranks
            shard_elems = split_elems_for_src(p, rank, num_ranks)
            g = p.grad
            # Flatten the local DTensor shard and cast to the wire dtype.
            g = g.to_local().to(COMM_DTYPE).contiguous().view(-1)
            assert g.numel() == shard_elems
            per_dst[dst].append(g)
            send_counts[dst] += shard_elems

        assert all(
            len(v) > 0
            for v in per_dst), "all params should be sharded to all devices"

        send_buf = torch.cat([torch.cat(v, dim=0) for v in per_dst], dim=0)
        owned_params = [
            p for p in params if param_to_state[id(p)].worker_rank == rank
        ]

        # Compute receive sizes and allocate receiving buffers
        recv_counts = [0] * num_ranks

        for src in range(num_ranks):
            total = 0
            for p in owned_params:
                state = param_to_state[id(p)]
                assert state.worker_rank == rank
                total += split_elems_for_src(p, src, num_ranks)
            recv_counts[src] = total

        recv_total = sum(recv_counts)
        # NOTE(review): device="cuda" assumes the current CUDA device is the
        # right one for this rank — confirm callers set it.
        recv_buf = torch.empty(recv_total, dtype=COMM_DTYPE, device="cuda")

        #All2All
        dist.all_to_all_single(
            recv_buf,
            send_buf,
            output_split_sizes=recv_counts,
            input_split_sizes=send_counts,
            group=process_group,
        )

        # Reconstructs gathered grad from the received buffer
        #
        # recv_buf (num ranks = 3)
        #
        #   From rank 0        From rank 1        From rank 2
        # | p1_0, p2_0, p3_0 | p1_1, p2_1, p3_1 | p1_2, p2_2, p3_2 |
        #
        # Outer loop:
        #   rank 0 -> rank 1 -> rank2
        #
        # Inner loop:
        #   p1_n -> p2_n -> p3_n

        # The gathered_grad buffers were allocated on the compute stream;
        # do not write into them until that allocation is visible here.
        comm_stream.wait_event(alloc_event)

        off = 0
        # Per-parameter write cursor into its gathered_grad flat buffer.
        write_offsets = {id(p): 0 for p in owned_params}
        for src in range(num_ranks):
            if recv_counts[src] == 0:
                continue

            block = recv_counts[src]
            inner_off = 0
            for p in owned_params:
                state = param_to_state[id(p)]
                assert state.worker_rank == rank
                n = split_elems_for_src(p, src, num_ranks)
                assert n > 0

                sg = recv_buf.narrow(0, off + inner_off, n)
                woff = write_offsets[id(p)]
                dst = state.gathered_grad.narrow(0, woff, n)
                dst.copy_(sg)

                write_offsets[id(p)] += n
                inner_off += n
            off += block

        for p in params:
            state = param_to_state[id(p)]
            if state.worker_rank == rank:
                # Flat buffer now complete; expose it with the param's shape.
                state.gathered_grad = state.gathered_grad.view_as(p)
                state.gather_event = torch.cuda.Event()
                state.gather_event.record(comm_stream)
            else:
                state.gathered_grad = None
                state.gather_event = None
            if none_grad:
                # Shard already copied into send_buf above, so the original
                # grad can be released to reduce peak memory.
                p.grad = None
212
 
213
 
214
  @torch.no_grad()
 
222
  raise RuntimeError("Gather event must be set before compute.")
223
  compute_stream.wait_event(state.gather_event)
224
  u = _zeropower_via_newtonschulz5(state.gathered_grad, steps)
225
+ state.gathered_grad = None
226
  state.computed_u = u
227
+ state.compute_event = torch.cuda.Event()
228
+ state.compute_event.record()
229
+ else:
230
+ state.computed_u = None
231
+ state.compute_event = None
232
 
233
 
234
@torch.no_grad()
def _alloc_scattered_u(params, param_to_state, rank, compute_stream):
    """Pre-allocate per-parameter ``scattered_u`` buffers on ``compute_stream``.

    Every rank allocates a COMM_DTYPE buffer shaped like its local shard of
    each parameter, then records an event so the comm stream can wait for
    the allocations before the all2all scatter writes into them.
    """
    with torch.cuda.stream(compute_stream):
        for param in params:
            st = param_to_state[id(param)]
            st.scattered_u = torch.empty_like(param.to_local(),
                                              dtype=COMM_DTYPE)

    ready_event = torch.cuda.Event()
    ready_event.record(compute_stream)
    return ready_event
249
+
250
 
251
def _all2all_scatter(params, param_to_state, rank, comm_stream, alloc_event):
    """
    All2all scatters full gradients to all ranks

    Each owner rank splits its computed orthogonalized update
    (``state.computed_u``) into per-rank row shards and all2all's them out;
    every rank ends up with its local shard in ``state.scattered_u`` and a
    ``state.scatter_event`` recorded on ``comm_stream``.

    NOTE(review): unlike _all2all_gather this function carries no
    @torch.no_grad() decorator — confirm whether that asymmetry is
    intentional.
    """
    with torch.cuda.stream(comm_stream):
        process_group = param_to_state[id(params[0])].process_group
        num_ranks = dist.get_world_size(group=process_group)
        owned_params = [
            p for p in params if param_to_state[id(p)].worker_rank == rank
        ]

        # Construct sending buffer
        per_dst = [[] for _ in range(num_ranks)]
        send_counts = [0] * num_ranks

        if owned_params:
            for p in owned_params:
                state = param_to_state[id(p)]
                if state.compute_event is None:
                    raise RuntimeError(
                        "Compute event must be set before scatter.")
                # computed_u is produced on the compute stream; wait for it
                # before reading it here on the comm stream.
                comm_stream.wait_event(state.compute_event)
                state.gathered_grad = None

                assert state.computed_u is not None

                u_full = state.computed_u.to(COMM_DTYPE).contiguous().view(-1)

                # Walk the flat update in row-shard order; dst ranks receive
                # the same split that split_elems_for_src defines.
                offset = 0
                for dst in range(num_ranks):
                    n = split_elems_for_src(p, dst, num_ranks)
                    assert n > 0

                    su = u_full.narrow(0, offset, n)
                    per_dst[dst].append(su)
                    send_counts[dst] += n
                    offset += n

                assert offset == u_full.numel()

        if any(len(v) > 0 for v in per_dst):
            send_buf = torch.cat([torch.cat(v, dim=0) for v in per_dst], dim=0)
        else:
            # all_to_all requires participation from all ranks
            # Even non-owner ranks must join the collective call
            send_buf = torch.empty(0, dtype=COMM_DTYPE, device="cuda")

        # Compute receive sizes and allocate receiving buffers
        recv_counts = [0] * num_ranks

        for src in range(num_ranks):
            total = 0
            for p in params:
                state = param_to_state[id(p)]
                if state.worker_rank != src:
                    continue
                # This rank's shard of a parameter owned by `src`.
                total += split_elems_for_src(p, rank, num_ranks)
            recv_counts[src] = total

        recv_total = sum(recv_counts)
        assert recv_total > 0
        recv_buf = torch.empty(recv_total, dtype=COMM_DTYPE, device="cuda")

        #All2All
        dist.all_to_all_single(
            recv_buf,
            send_buf,
            output_split_sizes=recv_counts,
            input_split_sizes=send_counts,
            group=process_group,
        )

        # Copy to pre-allocated scattered_u buffer from the received buffer
        #
        # recv_buf (num ranks = 3, local_rank = 0)
        #
        #   From rank 0        From rank 1  From rank 2
        # | p1_0, p2_0, p3_0 | p4_0       | p5_0, p6_0 |
        #
        # Outer loop:
        #   rank 0 -> rank 1 -> rank2
        #
        # Inner loop:
        #   src(0) : p1_0 -> p2_0 -> p3_0
        #   src(1) : p4_0
        #   src(2) : p5_0 -> p6_0

        # scattered_u buffers were allocated on the compute stream; make
        # sure they exist before copying into them.
        comm_stream.wait_event(alloc_event)

        off = 0
        for src in range(num_ranks):
            block = recv_counts[src]
            if block == 0:
                continue

            inner_off = 0
            for p in params:
                state = param_to_state[id(p)]
                if state.worker_rank != src:
                    continue
                n = split_elems_for_src(p, rank, num_ranks)
                assert n > 0

                flat_local = recv_buf.narrow(0, off + inner_off,
                                             n).view_as(p.to_local())
                state.scattered_u.copy_(flat_local)

                state.scatter_event = torch.cuda.Event()
                state.scatter_event.record(comm_stream)
                inner_off += n

            assert inner_off == block
            off += block
+ off += block
364
 
365
 
366
  def _update_param(p, state, lr, adjusted_lr, weight_decay, rank,
 
516
  "head_dim": 128,
517
  "threshold": 100
518
  }
519
+ overlap_step : How many all2all gather, compute operations are launched in advance
520
+ before the corresponding all2all scatter steps begin.
521
+ A higher overlap_step increases memory usage but can improve
522
+ performance by overlapping communication.
523
+ Parallel muon only.
524
  """
525
 
526
  def __init__(self,
 
539
  "k_indices": [],
540
  "head_dim": 128,
541
  "threshold": 100
542
+ },
543
+ overlap_step=5):
544
  defaults = dict(
545
  lr=lr,
546
  weight_decay=weight_decay,
 
564
 
565
  super().__init__(params, defaults)
566
 
567
+ self.rank = None
 
 
 
568
 
569
  self.comm_stream = torch.cuda.Stream()
570
  self.compute_stream = torch.cuda.Stream()
571
  self.debug = debug
572
  self.clip_config = clip_config
573
+ self.overlap_step = overlap_step
574
 
575
  def _calc_flops(self, G, steps):
576
  assert len(G.shape) == 2
 
643
  if mesh is None:
644
  mesh = p.device_mesh
645
  shard_mesh, process_group = self.get_shard_mesh(p, self.rank)
646
+ local_rank = dist.get_rank(group=process_group)
647
+ if self.rank is None:
648
+ self.rank = dist.get_rank(group=process_group)
649
+ else:
650
+ assert self.rank == local_rank
651
  elif mesh != p.device_mesh:
652
  raise ValueError("All parameters must be on the same mesh.")
653
 
654
+ num_ranks = dist.get_world_size(group=process_group)
655
  param_to_state[id(p)] = _muon_state()
656
+ param_to_state[id(
657
+ p)].worker_rank = shard_mesh[round_robin].item() % num_ranks
658
  param_to_state[id(p)].process_group = process_group
659
  qk_clip_state = self.get_qk_clip_info(n, qk_logits)
660
  param_to_state[id(p)].qk_clip_state = qk_clip_state
 
684
  else:
685
  g = buf
686
 
687
+ u = _zeropower_via_newtonschulz5(g.to(COMM_DTYPE),
688
  steps=group["ns_steps"])
689
 
690
  adjusted_lr = self.adjust_lr_for_muon(lr, p.shape)
 
699
  def _update_g(self, p, g, group, momentum):
700
  # calc update
701
  state = self.state[p]
702
+ buf = state.setdefault("momentum_buffer", torch.zeros_like(g))
703
+ torch.add(g, buf, alpha=momentum, out=buf)
 
 
704
  if group["nesterov"]:
705
+ g.add_(buf, alpha=momentum)
706
+ return g
707
+ return buf
 
708
 
709
  @staticmethod
710
  def _update_p(p, u, lr, adjusted_lr, weight_decay):
 
788
  param_to_state, ordered_params = self.init_state_and_assign_params(
789
  names, params, group, qk_logits)
790
 
791
+ assert self.rank is not None
792
+
793
+ def enqueue_all2all_gather(start_idx, chunk_size):
794
+ target_params = ordered_params[start_idx:start_idx + chunk_size]
795
+ if target_params:
796
+ alloc_event = _alloc_gathered_grad(target_params,
797
+ param_to_state, self.rank,
798
+ self.compute_stream)
799
+ _all2all_gather(target_params, param_to_state, self.rank,
800
+ self.comm_stream, group["none_grad"],
801
+ alloc_event)
802
 
803
  def enqueue_computes(start_idx, chunk_size):
804
  for p in ordered_params[start_idx:start_idx + chunk_size]:
 
806
  _compute_u(p, state, group["ns_steps"], self.rank,
807
  self.compute_stream)
808
 
809
+ def enqueue_all2all_scatter(start_idx, chunk_size):
810
+ target_params = ordered_params[start_idx:start_idx + chunk_size]
811
+ if target_params:
812
+ alloc_event = _alloc_scattered_u(target_params, param_to_state,
813
+ self.rank,
814
+ self.compute_stream)
815
+ _all2all_scatter(target_params, param_to_state, self.rank,
816
+ self.comm_stream, alloc_event)
817
 
818
  def enqueue_update_param(start_idx, chunk_size):
819
  for p in ordered_params[start_idx:start_idx + chunk_size]:
 
828
  # Wait grad update
829
  self.comm_stream.wait_stream(torch.cuda.current_stream())
830
 
831
+ overlap_step = self.overlap_step
832
+ for i in range(0, overlap_step):
833
+ enqueue_all2all_gather(i * chunk_size, chunk_size)
834
+ enqueue_computes(i * chunk_size, chunk_size)
835
+
836
  for i in range(0, len(params) + chunk_size - 1, chunk_size):
837
+ enqueue_all2all_scatter(i, chunk_size)
838
+ enqueue_all2all_gather(i + overlap_step * chunk_size, chunk_size)
839
+ enqueue_update_param(i, chunk_size)
840
+ enqueue_computes(i + overlap_step * chunk_size, chunk_size)
 
 
841
 
842
  # Wait the last update_param to finish
843
  torch.cuda.current_stream().wait_stream(self.compute_stream)
test/test_muon/muon.py DELETED
@@ -1 +0,0 @@
1
- ../../torch-ext/optimizer/muon.py
 
 
test/test_muon/optimizer ADDED
@@ -0,0 +1 @@
 
 
1
+ ../../torch-ext/optimizer/
test/test_muon/test.py CHANGED
@@ -2,7 +2,7 @@ import logging
2
 
3
  import torch
4
  import torch.distributed as dist
5
- from muon import Muon, get_default_muon_param_groups
6
  from torch.distributed.fsdp import FSDPModule, fully_shard
7
  from torch.distributed.tensor import DTensor
8
  from torch.distributed.tensor.placement_types import Replicate
 
2
 
3
  import torch
4
  import torch.distributed as dist
5
+ from optimizer.muon import Muon, get_default_muon_param_groups
6
  from torch.distributed.fsdp import FSDPModule, fully_shard
7
  from torch.distributed.tensor import DTensor
8
  from torch.distributed.tensor.placement_types import Replicate
torch-ext/optimizer/matmul_transpose_triton.py ADDED
@@ -0,0 +1,128 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # MIT License
2
+ #
3
+ # Copyright (c) 2025 Tianyang Lin
4
+ #
5
+ # Permission is hereby granted, free of charge, to any person obtaining a copy
6
+ # of this software and associated documentation files (the "Software"), to deal
7
+ # in the Software without restriction, including without limitation the rights
8
+ # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9
+ # copies of the Software, and to permit persons to whom the Software is
10
+ # furnished to do so, subject to the following conditions:
11
+ #
12
+ # The above copyright notice and this permission notice shall be included in all
13
+ # copies or substantial portions of the Software.
14
+ #
15
+ # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16
+ # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17
+ # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18
+ # AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19
+ # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20
+ # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21
+ # SOFTWARE.
22
+
23
+ import torch
24
+ import triton
25
+ import triton.language as tl
26
+
27
+
28
def get_autotune_config():
    """Return the Triton config sweep used to autotune ``mmt_kernel``.

    Cartesian product of block sizes, pipeline stages, and warp counts;
    GROUP_SIZE_M is fixed at 8.
    """
    return [
        triton.Config(
            {
                'BLOCK_SIZE_M': blk_m,
                'BLOCK_SIZE_K': blk_k,
                'GROUP_SIZE_M': grp_sz
            },
            num_stages=n_stages,
            num_warps=n_warps) for blk_m in [32, 64, 128]
        for blk_k in [32, 64] for grp_sz in [8] for n_stages in [3, 4, 5]
        for n_warps in [4, 8]
    ]
41
+
42
+
43
@triton.autotune(
    configs=get_autotune_config(),
    key=['M', 'K'],
)
@triton.jit
def mmt_kernel(x, y, M, K, stride_xm, stride_xk, stride_ym, stride_yn,
               BLOCK_SIZE_M: tl.constexpr, BLOCK_SIZE_K: tl.constexpr,
               GROUP_SIZE_M: tl.constexpr):
    """
    Core kernel jit function of matmul_transpose that computes y = x @ x.T
    The code is a simple adaptation from the triton `matmul` tutorial:
    https://triton-lang.org/main/getting-started/tutorials/03-matrix-multiplication.html

    x is (M, K) and y is (M, M). Because the result is symmetric, each
    program instance computes only an upper-triangular output tile
    (pid_m <= pid_n) and mirrors it into the lower triangle at the end.
    """
    pid = tl.program_id(axis=0)
    # The output is square (M, M), so both tile-grid axes use BLOCK_SIZE_M.
    num_pid_m = tl.cdiv(M, BLOCK_SIZE_M)
    num_pid_n = tl.cdiv(M, BLOCK_SIZE_M)
    # Grouped program-id ordering for better L2 reuse (as in the tutorial).
    num_pid_in_group = GROUP_SIZE_M * num_pid_n
    group_id = pid // num_pid_in_group
    first_pid_m = group_id * GROUP_SIZE_M
    group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M)
    pid_m = first_pid_m + ((pid % num_pid_in_group) % group_size_m)
    pid_n = (pid % num_pid_in_group) // group_size_m
    # Skip strictly-lower-triangular tiles: they are produced by the mirror
    # store of the corresponding upper tile below.
    if pid_m > pid_n:
        return

    offs_xm = (pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)) % M
    offs_xn = (pid_n * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)) % M
    offs_k = tl.arange(0, BLOCK_SIZE_K)
    # we use a & b ptrs to denote different rows of x.
    a_ptrs = x + (offs_xm[:, None] * stride_xm + offs_k[None, :] * stride_xk)
    b_ptrs = x + (offs_xn[:, None] * stride_xm + offs_k[None, :] * stride_xk)

    # Accumulate in fp32 regardless of the input dtype.
    accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_M), dtype=tl.float32)

    for k in range(0, tl.cdiv(K, BLOCK_SIZE_K)):
        # Mask out-of-range K columns on the final (partial) iteration.
        a = tl.load(a_ptrs,
                    mask=offs_k[None, :] < K - k * BLOCK_SIZE_K,
                    other=0.0)
        b = tl.load(b_ptrs,
                    mask=offs_k[None, :] < K - k * BLOCK_SIZE_K,
                    other=0.0)
        # accumulator += a @ b.T
        accumulator = tl.dot(a, tl.permute(b, (1, 0)), accumulator)
        a_ptrs += BLOCK_SIZE_K * stride_xk
        b_ptrs += BLOCK_SIZE_K * stride_xk
    # use dtype.element_ty to accommodate different input datatypes as in cpp templates
    # https://github.com/triton-lang/triton/issues/2252
    c = accumulator.to(x.dtype.element_ty)

    offs_cm = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
    offs_cn = pid_n * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
    c_ptrs = y + stride_ym * offs_cm[:, None] + stride_yn * offs_cn[None, :]
    c_mask = (offs_cm[:, None] < M) & (offs_cn[None, :] < M)
    tl.store(c_ptrs, c, mask=c_mask)

    # transpose and copy: mirror the tile into the lower triangle.
    # Diagonal tiles (pid_m == pid_n) need no mirror store.
    if pid_m < pid_n:
        ct_ptrs = y + stride_ym * offs_cn[:,
                                          None] + stride_yn * offs_cm[None, :]
        ct_mask = (offs_cn[:, None] < M) & (offs_cm[None, :] < M)
        tl.store(ct_ptrs, tl.permute(c, (1, 0)), mask=ct_mask)
103
+
104
+
105
def matmul_transpose_assign(d_in, d_out):
    """Compute ``d_out = d_in @ d_in.T`` in place using the Triton kernel.

    Args:
        d_in: 2D CUDA tensor of shape (M, K).
        d_out: pre-allocated square CUDA tensor of shape (M, M) with the
            same dtype and device as ``d_in``; overwritten with the result.
    """
    assert d_in.is_cuda, "Input `d_in` must be a CUDA tensor"
    assert d_out.is_cuda, "Input `d_out` must be a CUDA tensor"
    assert d_in.device == d_out.device, "Inputs `d_in` and `d_out` must be on the same CUDA device"
    assert d_in.dtype == d_out.dtype, "Inputs must have the same data type"
    assert d_in.ndim == 2, "Input `d_in` must be a 2D tensor"
    assert d_out.ndim == 2, "Input `d_out` must be a 2D tensor"
    # Bug fix: the original compared `d_out.size(0)` with itself, so a
    # non-square `d_out` slipped through. The output must be square with
    # both dimensions equal to d_in.size(0).
    assert d_in.size(0) == d_out.size(0) == d_out.size(1), \
        "First dimension of `d_in` must match first and second dimension of `d_out`"

    d_in = d_in.contiguous()
    M, K = d_in.shape
    # One program instance per (BLOCK_SIZE_M x BLOCK_SIZE_M) output tile;
    # the kernel itself skips the redundant lower-triangular tiles.
    grid = lambda META: (triton.cdiv(M, META['BLOCK_SIZE_M']) * triton.cdiv(
        M, META['BLOCK_SIZE_M']), )
    # Select the right device so the kernel launches where the data lives.
    with torch.cuda.device(d_in.device.index):
        mmt_kernel[grid](d_in, d_out, M, K, d_in.stride(0), d_in.stride(1),
                         d_out.stride(0), d_out.stride(1))
122
+
123
+
124
def matmul_transpose(d_in):
    """Allocate and return ``d_in @ d_in.T`` for a 2D CUDA tensor."""
    n_rows, _ = d_in.shape
    result = torch.empty((n_rows, n_rows),
                         device=d_in.device,
                         dtype=d_in.dtype)
    matmul_transpose_assign(d_in, result)
    return result
torch-ext/optimizer/muon.py CHANGED
@@ -8,14 +8,19 @@ import torch
8
  import torch.distributed as dist
9
  from torch.distributed._tensor import DTensor, Replicate, Shard
10
 
 
 
11
  logger = logging.getLogger(__name__)
12
 
 
 
13
 
14
  # This code snippet is a modified version adapted from the following GitHub repositories:
15
  # https://github.com/KellerJordan/Muon/blob/master/muon.py
16
  # Muon's Newton–Schulz iteration causes high variance in singular values
17
  # Idea: give each iteration its own 3 coefficients and optimize them via gradient descent.
18
  @torch.no_grad()
 
19
  def _zeropower_via_newtonschulz5(G, steps):
20
  """
21
  Newton-Schulz iteration to compute the zeroth power / orthogonalization of G. We opt to use a
@@ -27,13 +32,15 @@ def _zeropower_via_newtonschulz5(G, steps):
27
  performance at all relative to UV^T, where USV^T = G is the SVD.
28
  """
29
  assert len(G.shape) == 2
30
- assert G.dtype == torch.bfloat16
31
  X = G # no manual typecast
32
 
33
  if G.size(0) > G.size(1):
34
  X = X.T
35
  # Ensure spectral norm is at most 1
36
  X = X / (X.norm() + 1e-7)
 
 
37
  # Perform the NS iterations
38
  for a, b, c in [
39
  (4.0848, -6.8946, 2.9270),
@@ -42,13 +49,10 @@ def _zeropower_via_newtonschulz5(G, steps):
42
  (2.8769, -3.1427, 1.2046),
43
  (2.8366, -3.0525, 1.2012),
44
  ]:
45
- A = X @ X.T
46
- # B = (
47
- # b * A + c * A @ A
48
- # )
49
- B = torch.addmm(A, A, A, alpha=c, beta=b)
50
- # X = a * X + B @ X
51
- X = torch.addmm(X, B, X, alpha=1.0, beta=a)
52
 
53
  if G.size(0) > G.size(1):
54
  X = X.T
@@ -69,51 +73,142 @@ class _muon_state:
69
  qk_clip_state = None
70
 
71
 
 
 
 
 
 
 
 
 
72
  @torch.no_grad()
73
- def _gather(p, state, rank, comm_stream, none_grad):
74
  """
75
- Gather the gradients to worker_rank.
76
- If none_grad is True, free p.grad after the gather.
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
77
  """
78
  with torch.cuda.stream(comm_stream):
79
- g = p.grad
 
80
 
81
- if rank == state.worker_rank:
82
- num_ranks = dist.get_world_size(group=state.process_group)
83
- gather_list = [
84
- torch.empty_like(g.to_local(), dtype=torch.bfloat16)
85
- for _ in range(num_ranks)
86
- ]
87
- else:
88
- gather_list = None
89
-
90
- g = g.to(torch.bfloat16)
91
- torch.distributed.gather(
92
- g.to_local(),
93
- dst=state.worker_rank,
94
- gather_list=gather_list,
95
- group=state.process_group,
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
96
  )
97
- if rank == state.worker_rank:
98
- if state.gathered_grad is not None:
99
- raise RuntimeError(
100
- "Gather event already exists, which should not happen.")
101
- state.gathered_grad = torch.cat(gather_list, dim=0)
102
- state.gather_event = torch.cuda.Event()
103
- state.gather_event.record()
104
- else:
105
- state.gathered_grad = None
106
- state.gather_event = None
107
- gather_list = None
108
- if none_grad:
109
- # We can safely free p.grad without calling record_stream:
110
- # p.grad.to_local().record_stream(comm_stream)
111
- # Explanation:
112
- # 1. p.grad is created on the default stream, but the default stream
113
- # is synchronized with the comm stream later.
114
- # 2. There is no further activity on the default stream before the optimizer finishes.
115
- # Therefore, it is safe to free p.grad directly on the comm stream.
116
- p.grad = None
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
117
 
118
 
119
  @torch.no_grad()
@@ -127,45 +222,145 @@ def _compute_u(p, state, steps, rank, compute_stream):
127
  raise RuntimeError("Gather event must be set before compute.")
128
  compute_stream.wait_event(state.gather_event)
129
  u = _zeropower_via_newtonschulz5(state.gathered_grad, steps)
 
130
  state.computed_u = u
131
- state.scattered_u = torch.empty_like(p.to_local(),
132
- dtype=torch.bfloat16)
133
- state.compute_event = torch.cuda.Event()
134
- state.compute_event.record()
135
- u = None
136
 
137
 
138
  @torch.no_grad()
139
- def _scatter(p, state, rank, comm_stream):
140
  """
141
- Scatter the computed_u from worker_rank to all ranks.
 
142
  """
 
 
 
 
 
 
 
 
 
 
143
 
 
 
 
 
144
  with torch.cuda.stream(comm_stream):
145
- if state.compute_event is None:
146
- raise RuntimeError("Compute event must be set before scatter.")
147
- comm_stream.wait_event(state.compute_event)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
148
 
149
- if rank == state.worker_rank:
150
- num_ranks = dist.get_world_size(group=state.process_group)
151
- # Clear the gathered gradient to free memory
152
- state.gathered_grad = None
 
 
 
 
153
 
154
- u = state.computed_u
155
- scatter_list = list(torch.split(u, p.size(0) // num_ranks, dim=0))
156
- scatter_list = [s.contiguous() for s in scatter_list]
 
 
 
 
 
 
157
  else:
158
- scatter_list = None
 
 
159
 
160
- torch.distributed.scatter(
161
- state.scattered_u,
162
- scatter_list=scatter_list,
163
- src=state.worker_rank,
164
- group=state.process_group,
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
165
  )
166
- state.scatter_event = torch.cuda.Event()
167
- state.scatter_event.record()
168
- scatter_list = None
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
169
 
170
 
171
  def _update_param(p, state, lr, adjusted_lr, weight_decay, rank,
@@ -321,6 +516,11 @@ class Muon(torch.optim.Optimizer):
321
  "head_dim": 128,
322
  "threshold": 100
323
  }
 
 
 
 
 
324
  """
325
 
326
  def __init__(self,
@@ -339,7 +539,8 @@ class Muon(torch.optim.Optimizer):
339
  "k_indices": [],
340
  "head_dim": 128,
341
  "threshold": 100
342
- }):
 
343
  defaults = dict(
344
  lr=lr,
345
  weight_decay=weight_decay,
@@ -363,15 +564,13 @@ class Muon(torch.optim.Optimizer):
363
 
364
  super().__init__(params, defaults)
365
 
366
- if dist.is_initialized():
367
- self.rank = dist.get_rank()
368
- else:
369
- self.rank = None
370
 
371
  self.comm_stream = torch.cuda.Stream()
372
  self.compute_stream = torch.cuda.Stream()
373
  self.debug = debug
374
  self.clip_config = clip_config
 
375
 
376
  def _calc_flops(self, G, steps):
377
  assert len(G.shape) == 2
@@ -444,11 +643,18 @@ class Muon(torch.optim.Optimizer):
444
  if mesh is None:
445
  mesh = p.device_mesh
446
  shard_mesh, process_group = self.get_shard_mesh(p, self.rank)
 
 
 
 
 
447
  elif mesh != p.device_mesh:
448
  raise ValueError("All parameters must be on the same mesh.")
449
 
 
450
  param_to_state[id(p)] = _muon_state()
451
- param_to_state[id(p)].worker_rank = shard_mesh[round_robin].item()
 
452
  param_to_state[id(p)].process_group = process_group
453
  qk_clip_state = self.get_qk_clip_info(n, qk_logits)
454
  param_to_state[id(p)].qk_clip_state = qk_clip_state
@@ -478,7 +684,7 @@ class Muon(torch.optim.Optimizer):
478
  else:
479
  g = buf
480
 
481
- u = _zeropower_via_newtonschulz5(g.bfloat16(),
482
  steps=group["ns_steps"])
483
 
484
  adjusted_lr = self.adjust_lr_for_muon(lr, p.shape)
@@ -493,15 +699,12 @@ class Muon(torch.optim.Optimizer):
493
  def _update_g(self, p, g, group, momentum):
494
  # calc update
495
  state = self.state[p]
496
- if "momentum_buffer" not in state:
497
- state["momentum_buffer"] = torch.zeros_like(g)
498
- buf = state["momentum_buffer"]
499
- buf.mul_(momentum).add_(g)
500
  if group["nesterov"]:
501
- g = g.add(buf, alpha=momentum)
502
- else:
503
- g = buf
504
- return g
505
 
506
  @staticmethod
507
  def _update_p(p, u, lr, adjusted_lr, weight_decay):
@@ -585,11 +788,17 @@ class Muon(torch.optim.Optimizer):
585
  param_to_state, ordered_params = self.init_state_and_assign_params(
586
  names, params, group, qk_logits)
587
 
588
- def enqueue_gathers(start_idx, chunk_size):
589
- for p in ordered_params[start_idx:start_idx + chunk_size]:
590
- state = param_to_state[id(p)]
591
- _gather(p, state, self.rank, self.comm_stream,
592
- group["none_grad"])
 
 
 
 
 
 
593
 
594
  def enqueue_computes(start_idx, chunk_size):
595
  for p in ordered_params[start_idx:start_idx + chunk_size]:
@@ -597,10 +806,14 @@ class Muon(torch.optim.Optimizer):
597
  _compute_u(p, state, group["ns_steps"], self.rank,
598
  self.compute_stream)
599
 
600
- def enqueue_scatters(start_idx, chunk_size):
601
- for p in ordered_params[start_idx:start_idx + chunk_size]:
602
- state = param_to_state[id(p)]
603
- _scatter(p, state, self.rank, self.comm_stream)
 
 
 
 
604
 
605
  def enqueue_update_param(start_idx, chunk_size):
606
  for p in ordered_params[start_idx:start_idx + chunk_size]:
@@ -615,14 +828,16 @@ class Muon(torch.optim.Optimizer):
615
  # Wait grad update
616
  self.comm_stream.wait_stream(torch.cuda.current_stream())
617
 
618
- enqueue_gathers(0, chunk_size)
 
 
 
 
619
  for i in range(0, len(params) + chunk_size - 1, chunk_size):
620
- enqueue_computes(i, chunk_size)
621
- if i > 0:
622
- enqueue_update_param(i - chunk_size, chunk_size)
623
- enqueue_gathers(i + chunk_size, chunk_size)
624
- enqueue_scatters(i, chunk_size)
625
- enqueue_update_param(i, chunk_size)
626
 
627
  # Wait the last update_param to finish
628
  torch.cuda.current_stream().wait_stream(self.compute_stream)
 
8
  import torch.distributed as dist
9
  from torch.distributed._tensor import DTensor, Replicate, Shard
10
 
11
+ from .matmul_transpose_triton import matmul_transpose_assign
12
+
13
  logger = logging.getLogger(__name__)
14
 
15
+ COMM_DTYPE = torch.bfloat16
16
+
17
 
18
  # This code snippet is a modified version adapted from the following GitHub repositories:
19
  # https://github.com/KellerJordan/Muon/blob/master/muon.py
20
  # Muon's Newton–Schulz iteration causes high variance in singular values
21
  # Idea: give each iteration its own 3 coefficients and optimize them via gradient descent.
22
  @torch.no_grad()
23
+ # matmul_transpose_assign from : https://github.com/nil0x9/flash-muon
24
  def _zeropower_via_newtonschulz5(G, steps):
25
  """
26
  Newton-Schulz iteration to compute the zeroth power / orthogonalization of G. We opt to use a
 
32
  performance at all relative to UV^T, where USV^T = G is the SVD.
33
  """
34
  assert len(G.shape) == 2
35
+ assert G.dtype == COMM_DTYPE
36
  X = G # no manual typecast
37
 
38
  if G.size(0) > G.size(1):
39
  X = X.T
40
  # Ensure spectral norm is at most 1
41
  X = X / (X.norm() + 1e-7)
42
+ buf1 = torch.empty(X.size(0), X.size(0), dtype=X.dtype, device=X.device)
43
+ buf2 = torch.empty(X.size(0), X.size(0), dtype=X.dtype, device=X.device)
44
  # Perform the NS iterations
45
  for a, b, c in [
46
  (4.0848, -6.8946, 2.9270),
 
49
  (2.8769, -3.1427, 1.2046),
50
  (2.8366, -3.0525, 1.2012),
51
  ]:
52
+ matmul_transpose_assign(X, buf1)
53
+ matmul_transpose_assign(buf1, buf2)
54
+ buf1.mul_(b).add_(buf2, alpha=c)
55
+ X = torch.addmm(X, buf1, X, alpha=1.0, beta=a)
 
 
 
56
 
57
  if G.size(0) > G.size(1):
58
  X = X.T
 
73
  qk_clip_state = None
74
 
75
 
76
def split_elems_for_src(param, src_rank, num_ranks) -> int:
    """Return the number of elements of ``param`` held by rank ``src_rank``.

    The parameter is sharded row-wise across ``num_ranks`` ranks; when the
    row count does not divide evenly, the first ``rows % num_ranks`` ranks
    each carry one extra row.
    """
    total_rows = param.shape[0]
    elems_per_row = int(param.numel() // total_rows)
    rows_each, leftover = divmod(total_rows, num_ranks)
    if src_rank < leftover:
        rows_each += 1
    return rows_each * elems_per_row
82
+
83
+
84
@torch.no_grad()
def _alloc_gathered_grad(params, param_to_state, rank, compute_stream):
    """
    Pre-allocate each owned parameter's full-size ``gathered_grad`` buffer on
    ``compute_stream`` before launching the all2all gather.

    Only the owner (``state.worker_rank``) of a parameter reconstructs the
    full gradient; non-owner ranks get ``None``. Returns a CUDA event
    recorded on ``compute_stream`` that the comm stream must wait on before
    writing into these buffers.
    """
    with torch.cuda.stream(compute_stream):
        for p in params:
            state = param_to_state[id(p)]
            if rank == state.worker_rank:
                # Fix: dropped the unused `num_ranks = dist.get_world_size(...)`
                # query the original computed here and never read.
                # Flat buffer in the communication dtype; it is viewed back
                # to the parameter's shape once the gather completes.
                state.gathered_grad = torch.empty(p.grad.numel(),
                                                  dtype=COMM_DTYPE,
                                                  device="cuda")
            else:
                state.gathered_grad = None

    alloc_event = torch.cuda.Event()
    alloc_event.record(compute_stream)
    return alloc_event
104
+
105
+
106
@torch.no_grad()
def _all2all_gather(params, param_to_state, rank, comm_stream, none_grad,
                    alloc_event):
    """
    All2all gathers shards so each owner rank reconstructs its full gradient.

    Every rank sends its local shard of every parameter's gradient to that
    parameter's owner (``state.worker_rank``). After the collective, each
    owner unpacks the received blocks into the pre-allocated
    ``state.gathered_grad`` buffer and records ``state.gather_event``.
    If ``none_grad`` is True, ``p.grad`` is freed afterwards.
    """
    with torch.cuda.stream(comm_stream):
        # All parameters in this batch share one process group.
        process_group = param_to_state[id(params[0])].process_group
        num_ranks = dist.get_world_size(group=process_group)

        # Construct sending buffers: one list of flat shards per destination.
        per_dst = [[] for _ in range(num_ranks)]
        send_counts = [0] * num_ranks

        for p in params:
            state = param_to_state[id(p)]
            dst = state.worker_rank
            assert dst < num_ranks
            # Size of this rank's row-shard of p (see split_elems_for_src).
            shard_elems = split_elems_for_src(p, rank, num_ranks)
            g = p.grad
            g = g.to_local().to(COMM_DTYPE).contiguous().view(-1)
            assert g.numel() == shard_elems
            per_dst[dst].append(g)
            send_counts[dst] += shard_elems

        assert all(
            len(v) > 0
            for v in per_dst), "all params should be sharded to all devices"

        # Concatenate into a single flat send buffer ordered by destination.
        send_buf = torch.cat([torch.cat(v, dim=0) for v in per_dst], dim=0)
        owned_params = [
            p for p in params if param_to_state[id(p)].worker_rank == rank
        ]

        # Compute receive sizes and allocate receiving buffers
        recv_counts = [0] * num_ranks

        for src in range(num_ranks):
            total = 0
            for p in owned_params:
                state = param_to_state[id(p)]
                assert state.worker_rank == rank
                total += split_elems_for_src(p, src, num_ranks)
            recv_counts[src] = total

        recv_total = sum(recv_counts)
        recv_buf = torch.empty(recv_total, dtype=COMM_DTYPE, device="cuda")

        # All2All
        dist.all_to_all_single(
            recv_buf,
            send_buf,
            output_split_sizes=recv_counts,
            input_split_sizes=send_counts,
            group=process_group,
        )

        # Reconstructs gathered grad from the received buffer
        #
        # recv_buf (num ranks = 3)
        #
        #   From rank 0        From rank 1        From rank 2
        # | p1_0, p2_0, p3_0 | p1_1, p2_1, p3_1 | p1_2, p2_2, p3_2 |
        #
        # Outer loop:
        #   rank 0 -> rank 1 -> rank 2
        #
        # Inner loop:
        #   p1_n -> p2_n -> p3_n

        # The gathered_grad buffers were allocated on compute_stream; wait
        # until that allocation is visible before writing into them.
        comm_stream.wait_event(alloc_event)

        off = 0
        # Per-parameter write cursor into its flat gathered_grad buffer.
        write_offsets = {id(p): 0 for p in owned_params}
        for src in range(num_ranks):
            if recv_counts[src] == 0:
                continue

            block = recv_counts[src]
            inner_off = 0
            for p in owned_params:
                state = param_to_state[id(p)]
                assert state.worker_rank == rank
                n = split_elems_for_src(p, src, num_ranks)
                assert n > 0

                sg = recv_buf.narrow(0, off + inner_off, n)
                woff = write_offsets[id(p)]
                dst = state.gathered_grad.narrow(0, woff, n)
                dst.copy_(sg)

                write_offsets[id(p)] += n
                inner_off += n
            off += block

        for p in params:
            state = param_to_state[id(p)]
            if state.worker_rank == rank:
                # Reshape the flat buffer to the parameter's full shape and
                # signal consumers (_compute_u) via gather_event.
                state.gathered_grad = state.gathered_grad.view_as(p)
                state.gather_event = torch.cuda.Event()
                state.gather_event.record(comm_stream)
            else:
                state.gathered_grad = None
                state.gather_event = None
            if none_grad:
                # Safe to free here: the shard has already been copied into
                # send_buf on this stream.
                p.grad = None
212
 
213
 
214
  @torch.no_grad()
 
222
  raise RuntimeError("Gather event must be set before compute.")
223
  compute_stream.wait_event(state.gather_event)
224
  u = _zeropower_via_newtonschulz5(state.gathered_grad, steps)
225
+ state.gathered_grad = None
226
  state.computed_u = u
227
+ state.compute_event = torch.cuda.Event()
228
+ state.compute_event.record()
229
+ else:
230
+ state.computed_u = None
231
+ state.compute_event = None
232
 
233
 
234
@torch.no_grad()
def _alloc_scattered_u(params, param_to_state, rank, compute_stream):
    """
    Pre-allocate each parameter's local ``scattered_u`` buffer on
    ``compute_stream`` before launching the all2all scatter.

    Every rank receives its shard of the computed update, so the buffer is
    allocated for all parameters (sized like the local shard). Returns a
    CUDA event recorded on ``compute_stream`` that the comm stream waits on
    before writing into the buffers.
    """
    with torch.cuda.stream(compute_stream):
        for p in params:
            state = param_to_state[id(p)]
            state.scattered_u = torch.empty_like(p.to_local(),
                                                 dtype=COMM_DTYPE)

    alloc_event = torch.cuda.Event()
    alloc_event.record(compute_stream)
    return alloc_event
249
+
250
 
251
def _all2all_scatter(params, param_to_state, rank, comm_stream, alloc_event):
    """
    All2all scatters each owner's computed update ``u`` back to all ranks.

    Each owner rank splits every ``state.computed_u`` row-wise into per-rank
    shards and sends them; every rank copies the shards it receives into its
    pre-allocated ``state.scattered_u`` buffers and records
    ``state.scatter_event``.
    """
    with torch.cuda.stream(comm_stream):
        process_group = param_to_state[id(params[0])].process_group
        num_ranks = dist.get_world_size(group=process_group)
        owned_params = [
            p for p in params if param_to_state[id(p)].worker_rank == rank
        ]

        # Construct sending buffer
        per_dst = [[] for _ in range(num_ranks)]
        send_counts = [0] * num_ranks

        if owned_params:
            for p in owned_params:
                state = param_to_state[id(p)]
                if state.compute_event is None:
                    raise RuntimeError(
                        "Compute event must be set before scatter.")
                # The update must be fully computed before we read it.
                comm_stream.wait_event(state.compute_event)
                state.gathered_grad = None

                assert state.computed_u is not None

                u_full = state.computed_u.to(COMM_DTYPE).contiguous().view(-1)

                # Split the flat update into one shard per destination rank.
                offset = 0
                for dst in range(num_ranks):
                    n = split_elems_for_src(p, dst, num_ranks)
                    assert n > 0

                    su = u_full.narrow(0, offset, n)
                    per_dst[dst].append(su)
                    send_counts[dst] += n
                    offset += n

                assert offset == u_full.numel()

        if any(len(v) > 0 for v in per_dst):
            send_buf = torch.cat([torch.cat(v, dim=0) for v in per_dst], dim=0)
        else:
            # all_to_all requires participation from all ranks
            # Even non-owner ranks must join the collective call
            send_buf = torch.empty(0, dtype=COMM_DTYPE, device="cuda")

        # Compute receive sizes and allocate receiving buffers
        recv_counts = [0] * num_ranks

        for src in range(num_ranks):
            total = 0
            for p in params:
                state = param_to_state[id(p)]
                if state.worker_rank != src:
                    continue
                # This rank receives its own shard of every param owned by src.
                total += split_elems_for_src(p, rank, num_ranks)
            recv_counts[src] = total

        recv_total = sum(recv_counts)
        assert recv_total > 0
        recv_buf = torch.empty(recv_total, dtype=COMM_DTYPE, device="cuda")

        # All2All
        dist.all_to_all_single(
            recv_buf,
            send_buf,
            output_split_sizes=recv_counts,
            input_split_sizes=send_counts,
            group=process_group,
        )

        # Copy to pre-allocated scattered_u buffer from the received buffer
        #
        # recv_buf (num ranks = 3, local_rank = 0)
        #
        #   From rank 0        From rank 1   From rank 2
        # | p1_0, p2_0, p3_0 | p4_0        | p5_0, p6_0 |
        #
        # Outer loop:
        #   rank 0 -> rank 1 -> rank 2
        #
        # Inner loop:
        #   src(0) : p1_0 -> p2_0 -> p3_0
        #   src(1) : p4_0
        #   src(2) : p5_0 -> p6_0

        # scattered_u buffers were allocated on compute_stream; wait for the
        # allocation before writing into them.
        comm_stream.wait_event(alloc_event)

        off = 0
        for src in range(num_ranks):
            block = recv_counts[src]
            if block == 0:
                continue

            inner_off = 0
            for p in params:
                state = param_to_state[id(p)]
                if state.worker_rank != src:
                    continue
                n = split_elems_for_src(p, rank, num_ranks)
                assert n > 0

                flat_local = recv_buf.narrow(0, off + inner_off,
                                             n).view_as(p.to_local())
                state.scattered_u.copy_(flat_local)

                # Signal _update_param that this parameter's shard is ready.
                state.scatter_event = torch.cuda.Event()
                state.scatter_event.record(comm_stream)
                inner_off += n

            assert inner_off == block
            off += block
364
 
365
 
366
  def _update_param(p, state, lr, adjusted_lr, weight_decay, rank,
 
516
  "head_dim": 128,
517
  "threshold": 100
518
  }
519
+ overlap_step : How many all2all gather, compute operations are launched in advance
520
+ before the corresponding all2all scatter steps begin.
521
+ A higher overlap_step increases memory usage but can improve
522
+ performance by overlapping communication.
523
+ Parallel muon only.
524
  """
525
 
526
  def __init__(self,
 
539
  "k_indices": [],
540
  "head_dim": 128,
541
  "threshold": 100
542
+ },
543
+ overlap_step=5):
544
  defaults = dict(
545
  lr=lr,
546
  weight_decay=weight_decay,
 
564
 
565
  super().__init__(params, defaults)
566
 
567
+ self.rank = None
 
 
 
568
 
569
  self.comm_stream = torch.cuda.Stream()
570
  self.compute_stream = torch.cuda.Stream()
571
  self.debug = debug
572
  self.clip_config = clip_config
573
+ self.overlap_step = overlap_step
574
 
575
  def _calc_flops(self, G, steps):
576
  assert len(G.shape) == 2
 
643
  if mesh is None:
644
  mesh = p.device_mesh
645
  shard_mesh, process_group = self.get_shard_mesh(p, self.rank)
646
+ local_rank = dist.get_rank(group=process_group)
647
+ if self.rank is None:
648
+ self.rank = dist.get_rank(group=process_group)
649
+ else:
650
+ assert self.rank == local_rank
651
  elif mesh != p.device_mesh:
652
  raise ValueError("All parameters must be on the same mesh.")
653
 
654
+ num_ranks = dist.get_world_size(group=process_group)
655
  param_to_state[id(p)] = _muon_state()
656
+ param_to_state[id(
657
+ p)].worker_rank = shard_mesh[round_robin].item() % num_ranks
658
  param_to_state[id(p)].process_group = process_group
659
  qk_clip_state = self.get_qk_clip_info(n, qk_logits)
660
  param_to_state[id(p)].qk_clip_state = qk_clip_state
 
684
  else:
685
  g = buf
686
 
687
+ u = _zeropower_via_newtonschulz5(g.to(COMM_DTYPE),
688
  steps=group["ns_steps"])
689
 
690
  adjusted_lr = self.adjust_lr_for_muon(lr, p.shape)
 
699
  def _update_g(self, p, g, group, momentum):
700
  # calc update
701
  state = self.state[p]
702
+ buf = state.setdefault("momentum_buffer", torch.zeros_like(g))
703
+ torch.add(g, buf, alpha=momentum, out=buf)
 
 
704
  if group["nesterov"]:
705
+ g.add_(buf, alpha=momentum)
706
+ return g
707
+ return buf
 
708
 
709
  @staticmethod
710
  def _update_p(p, u, lr, adjusted_lr, weight_decay):
 
788
  param_to_state, ordered_params = self.init_state_and_assign_params(
789
  names, params, group, qk_logits)
790
 
791
+ assert self.rank is not None
792
+
793
+ def enqueue_all2all_gather(start_idx, chunk_size):
794
+ target_params = ordered_params[start_idx:start_idx + chunk_size]
795
+ if target_params:
796
+ alloc_event = _alloc_gathered_grad(target_params,
797
+ param_to_state, self.rank,
798
+ self.compute_stream)
799
+ _all2all_gather(target_params, param_to_state, self.rank,
800
+ self.comm_stream, group["none_grad"],
801
+ alloc_event)
802
 
803
  def enqueue_computes(start_idx, chunk_size):
804
  for p in ordered_params[start_idx:start_idx + chunk_size]:
 
806
  _compute_u(p, state, group["ns_steps"], self.rank,
807
  self.compute_stream)
808
 
809
+ def enqueue_all2all_scatter(start_idx, chunk_size):
810
+ target_params = ordered_params[start_idx:start_idx + chunk_size]
811
+ if target_params:
812
+ alloc_event = _alloc_scattered_u(target_params, param_to_state,
813
+ self.rank,
814
+ self.compute_stream)
815
+ _all2all_scatter(target_params, param_to_state, self.rank,
816
+ self.comm_stream, alloc_event)
817
 
818
  def enqueue_update_param(start_idx, chunk_size):
819
  for p in ordered_params[start_idx:start_idx + chunk_size]:
 
828
  # Wait grad update
829
  self.comm_stream.wait_stream(torch.cuda.current_stream())
830
 
831
+ overlap_step = self.overlap_step
832
+ for i in range(0, overlap_step):
833
+ enqueue_all2all_gather(i * chunk_size, chunk_size)
834
+ enqueue_computes(i * chunk_size, chunk_size)
835
+
836
  for i in range(0, len(params) + chunk_size - 1, chunk_size):
837
+ enqueue_all2all_scatter(i, chunk_size)
838
+ enqueue_all2all_gather(i + overlap_step * chunk_size, chunk_size)
839
+ enqueue_update_param(i, chunk_size)
840
+ enqueue_computes(i + overlap_step * chunk_size, chunk_size)
 
 
841
 
842
  # Wait the last update_param to finish
843
  torch.cuda.current_stream().wait_stream(self.compute_stream)