Kernels
wyldecat Claude Opus 4.6 committed on
Commit
e74d98f
·
1 Parent(s): cdaaf4f

Add torch.compile, CUDA graph, and compiled momentum [skip-build]

Browse files

- Newton-Schulz: per-shape torch.compile caching + CUDA graph replay
- Batched momentum: separately compiled nesterov/non-nesterov functions
- Batched Newton-Schulz for MoE experts (bmm/baddbmm)
- Triton matmul_transpose cleanup
- Inline uneven shard handling, remove small_param_numel_threshold
- Raise dynamo recompile_limit for test suite

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>

test/conftest.py CHANGED
@@ -9,6 +9,11 @@ from transformers import AutoModelForCausalLM
9
  logger = logging.getLogger(__name__)
10
  logging.basicConfig(level=logging.INFO)
11
 
 
 
 
 
 
12
  SEED = 0xdeadbeef
13
 
14
 
 
9
  logger = logging.getLogger(__name__)
10
  logging.basicConfig(level=logging.INFO)
11
 
12
+ # Raise dynamo recompile limit so that compiled momentum (batch_pre_ortho)
13
+ # does not fall back to eager mode when the test suite runs 30+ model
14
+ # configurations with different tensor shapes in a single process.
15
+ torch._dynamo.config.recompile_limit = 64
16
+
17
  SEED = 0xdeadbeef
18
 
19
 
torch-ext/optimizer/core.py CHANGED
@@ -1,9 +1,9 @@
1
  import logging
2
  import math
3
  from dataclasses import dataclass
 
4
 
5
  import torch
6
- import torch.distributed as dist
7
  from torch.distributed import ProcessGroup
8
  from torch.distributed.tensor import DTensor
9
 
@@ -31,26 +31,71 @@ class _muon_state:
31
  qk_clip_state: torch.Tensor | None = None
32
 
33
 
34
- def update_g(optimizer_state, p, g, group, momentum):
35
- """Apply momentum update to gradient.
 
 
 
 
 
 
36
 
37
- Args:
38
- optimizer_state: The optimizer's state dict (self.state in Muon).
39
- p: Parameter tensor.
40
- g: Gradient tensor.
41
- group: Parameter group dict.
42
- momentum: Momentum coefficient.
43
 
44
- Returns:
45
- Momentum-updated gradient tensor.
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
46
  """
47
- state = optimizer_state[p]
48
- buf = state.setdefault("momentum_buffer", torch.zeros_like(g))
49
- torch.add(g, buf, alpha=momentum, out=buf)
50
- if group["nesterov"]:
51
- g.add_(buf, alpha=momentum)
52
- return g
53
- return buf
 
 
 
 
 
 
 
 
 
 
 
54
 
55
 
56
  def update_p(p, u, lr, adjusted_lr, weight_decay):
@@ -63,14 +108,13 @@ def update_p(p, u, lr, adjusted_lr, weight_decay):
63
  adjusted_lr: Size-adjusted learning rate.
64
  weight_decay: Weight decay coefficient.
65
  """
66
- if isinstance(p, torch.nn.Parameter):
67
- # apply weight decay
68
- p.data.mul_(1 - lr * weight_decay)
69
- # apply update
70
- p.data.add_(u, alpha=-adjusted_lr)
71
- else:
72
- p.mul_(1 - lr * weight_decay)
73
- p.add_(u, alpha=-adjusted_lr)
74
 
75
 
76
  def adjust_lr_for_muon(lr, param_shape):
@@ -147,7 +191,7 @@ def get_default_muon_param_groups(model, is_muon_func=None, expert_keys=None):
147
  is_muon_func = lambda n, x: default_is_muon(n, x, expert_keys)
148
 
149
  muon_params, muon_names = [], []
150
- non_muon_params = []
151
 
152
  for n, p in model.named_parameters():
153
  if not p.requires_grad:
@@ -157,6 +201,10 @@ def get_default_muon_param_groups(model, is_muon_func=None, expert_keys=None):
157
  muon_names.append(n)
158
  else:
159
  non_muon_params.append(p)
 
 
 
 
160
 
161
  return [
162
  {
 
1
  import logging
2
  import math
3
  from dataclasses import dataclass
4
+ from typing import List
5
 
6
  import torch
 
7
  from torch.distributed import ProcessGroup
8
  from torch.distributed.tensor import DTensor
9
 
 
31
  qk_clip_state: torch.Tensor | None = None
32
 
33
 
34
+ def _batch_momentum(
35
+ grads: List[torch.Tensor],
36
+ momentum_bufs: List[torch.Tensor],
37
+ momentum: torch.Tensor,
38
+ ) -> None:
39
+ """Batched momentum update (no nesterov)."""
40
+ torch._foreach_mul_(momentum_bufs, momentum)
41
+ torch._foreach_add_(momentum_bufs, grads)
42
 
 
 
 
 
 
 
43
 
44
+ def _batch_momentum_nesterov(
45
+ grads: List[torch.Tensor],
46
+ momentum_bufs: List[torch.Tensor],
47
+ momentum: torch.Tensor,
48
+ ) -> None:
49
+ """Batched momentum update with nesterov correction."""
50
+ torch._foreach_mul_(momentum_bufs, momentum)
51
+ torch._foreach_add_(momentum_bufs, grads)
52
+ nesterov_terms = torch._foreach_mul(momentum_bufs, momentum)
53
+ torch._foreach_add_(grads, nesterov_terms)
54
+
55
+
56
# Cache of torch.compile'd batched momentum functions, keyed by the
# nesterov flag (True -> nesterov variant, False -> plain variant);
# populated lazily by batch_pre_ortho.
_compiled_momentum: dict[bool, callable] = {}
# Module-level switch for the compiled path; flip via set_momentum_compile().
_use_momentum_compile = True


def set_momentum_compile(enabled: bool):
    """Toggle torch.compile for batched momentum.

    Args:
        enabled: True to use (and lazily build) compiled momentum
            functions in batch_pre_ortho; False to run them eagerly.
    """
    global _use_momentum_compile
    _use_momentum_compile = enabled
64
+
65
+
66
def batch_pre_ortho(
    grads: List[torch.Tensor],
    momentum_bufs: List[torch.Tensor],
    momentum: torch.Tensor,
    nesterov: bool,
) -> None:
    """Run the batched momentum step that precedes orthogonalization.

    Mirrors dion's ``muon_update_pre_orthogonalize``.
    Inputs must be plain CUDA tensors (not DTensor).
    Mutates ``momentum_bufs`` in place, and also ``grads`` when
    ``nesterov`` is True.

    The nesterov and plain variants are compiled as two separate
    functions so the boolean branch never causes a graph break.
    """
    if nesterov:
        fn = _batch_momentum_nesterov
    else:
        fn = _batch_momentum
    if _use_momentum_compile:
        compiled = _compiled_momentum.get(nesterov)
        if compiled is None:
            compiled = torch.compile(fn)
            _compiled_momentum[nesterov] = compiled
        fn = compiled
    fn(grads, momentum_bufs, momentum)
87
+
88
+
89
+ def _update_p_impl(p_data, u_data, lr, adjusted_lr, weight_decay):
90
+ """Weight-decay + update on plain tensors.
91
+
92
+ Not compiled: per-param @torch.compile caused ~0.25ms TorchDynamo cache
93
+ lookup per call × 256+ params = massive overhead. The pipeline path uses
94
+ batched _foreach_* ops instead; this function remains for base() and
95
+ distributed_muon().
96
+ """
97
+ p_data.mul_(1 - lr * weight_decay)
98
+ p_data.add_(u_data, alpha=-adjusted_lr)
99
 
100
 
101
  def update_p(p, u, lr, adjusted_lr, weight_decay):
 
108
  adjusted_lr: Size-adjusted learning rate.
109
  weight_decay: Weight decay coefficient.
110
  """
111
+ # Unwrap Parameter -> underlying data tensor.
112
+ p_data = p.data if isinstance(p, torch.nn.Parameter) else p
113
+ # Unwrap DTensor -> local CUDA tensor for compiled kernel.
114
+ if isinstance(p_data, DTensor):
115
+ p_data = p_data._local_tensor
116
+ u_data = u._local_tensor if isinstance(u, DTensor) else u
117
+ _update_p_impl(p_data, u_data, lr, adjusted_lr, weight_decay)
 
118
 
119
 
120
  def adjust_lr_for_muon(lr, param_shape):
 
191
  is_muon_func = lambda n, x: default_is_muon(n, x, expert_keys)
192
 
193
  muon_params, muon_names = [], []
194
+ non_muon_params, non_muon_names = [], []
195
 
196
  for n, p in model.named_parameters():
197
  if not p.requires_grad:
 
201
  muon_names.append(n)
202
  else:
203
  non_muon_params.append(p)
204
+ non_muon_names.append(n)
205
+
206
+ logger.info("[param_groups] expert_keys=%s, Muon=%d, AdamW=%d",
207
+ expert_keys, len(muon_names), len(non_muon_names))
208
 
209
  return [
210
  {
torch-ext/optimizer/distributed/utils.py CHANGED
@@ -72,12 +72,6 @@ def get_slices_of_dtensor(
72
  else:
73
  curr_size = target.size()[shard_dim]
74
 
75
- if curr_size % num_chunks != 0:
76
- raise NotImplementedError(
77
- f"Dimension size {curr_size} is not divisible "
78
- f"by number of ranks {num_chunks} for shard "
79
- f"placement on dim {shard_dim}. (shape: {target.shape})")
80
-
81
  # Compute indices for this level of sharding
82
  if isinstance(placement, _StridedShard):
83
  _shard_size, offsets = _StridedShard.local_shard_size_and_offset(
 
72
  else:
73
  curr_size = target.size()[shard_dim]
74
 
 
 
 
 
 
 
75
  # Compute indices for this level of sharding
76
  if isinstance(placement, _StridedShard):
77
  _shard_size, offsets = _StridedShard.local_shard_size_and_offset(
torch-ext/optimizer/matmul_transpose_triton.py CHANGED
@@ -43,6 +43,7 @@ def get_autotune_config():
43
  @triton.autotune(
44
  configs=get_autotune_config(),
45
  key=['M', 'K'],
 
46
  )
47
  @triton.jit
48
  def mmt_kernel(x, y, M, K, stride_xm, stride_xk, stride_ym, stride_yn,
@@ -102,16 +103,10 @@ def mmt_kernel(x, y, M, K, stride_xm, stride_xk, stride_ym, stride_yn,
102
  tl.store(ct_ptrs, tl.permute(c, (1, 0)), mask=ct_mask)
103
 
104
 
105
- def matmul_transpose_assign(d_in, d_out):
106
- assert d_in.is_cuda, "Input `d_in` must be a CUDA tensor"
107
- assert d_out.is_cuda, "Input `d_out` must be a CUDA tensor"
108
- assert d_in.device == d_out.device, "Inputs `d_in` and `d_out` must be on the same CUDA device"
109
- assert d_in.dtype == d_out.dtype, "Inputs must have the same data type"
110
- assert d_in.ndim == 2, "Input `d_in` must be a 2D tensor"
111
- assert d_out.ndim == 2, "Input `d_out` must be a 2D tensor"
112
- assert d_in.size(0) == d_out.size(0) == d_out.size(0), \
113
- "First dimension of `d_in` must match first and second dimension of `d_out`"
114
-
115
  d_in = d_in.contiguous()
116
  M, K = d_in.shape
117
  grid = lambda META: (triton.cdiv(M, META['BLOCK_SIZE_M']) * triton.cdiv(
@@ -119,3 +114,9 @@ def matmul_transpose_assign(d_in, d_out):
119
  with torch.cuda.device(d_in.device.index):
120
  mmt_kernel[grid](d_in, d_out, M, K, d_in.stride(0), d_in.stride(1),
121
  d_out.stride(0), d_out.stride(1))
 
 
 
 
 
 
 
43
  @triton.autotune(
44
  configs=get_autotune_config(),
45
  key=['M', 'K'],
46
+ restore_value=['y'],
47
  )
48
  @triton.jit
49
  def mmt_kernel(x, y, M, K, stride_xm, stride_xk, stride_ym, stride_yn,
 
103
  tl.store(ct_ptrs, tl.permute(c, (1, 0)), mask=ct_mask)
104
 
105
 
106
+ @torch.library.custom_op("muon::matmul_transpose_assign",
107
+ mutates_args=("d_out", ))
108
+ def matmul_transpose_assign(d_in: torch.Tensor, d_out: torch.Tensor) -> None:
109
+ """Compute d_out = d_in @ d_in.T using an optimized Triton kernel."""
 
 
 
 
 
 
110
  d_in = d_in.contiguous()
111
  M, K = d_in.shape
112
  grid = lambda META: (triton.cdiv(M, META['BLOCK_SIZE_M']) * triton.cdiv(
 
114
  with torch.cuda.device(d_in.device.index):
115
  mmt_kernel[grid](d_in, d_out, M, K, d_in.stride(0), d_in.stride(1),
116
  d_out.stride(0), d_out.stride(1))
117
+
118
+
119
+ @matmul_transpose_assign.register_fake
120
+ def _(d_in: torch.Tensor, d_out: torch.Tensor) -> None:
121
+ """FakeTensor impl: d_out is already allocated, mutation is declared."""
122
+ pass
torch-ext/optimizer/newton_schulz.py CHANGED
@@ -162,3 +162,75 @@ def _zeropower_via_newtonschulz5(G, steps):
162
  X = X.T
163
 
164
  return X
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
162
  X = X.T
163
 
164
  return X
165
+
166
+
167
@torch.no_grad()
def _zeropower_via_newtonschulz5_batched(G, steps):
    """Batched polar factor computation for 3D (E, out, in) tensors.

    Same algorithm as ``_zeropower_via_newtonschulz5`` but uses
    ``torch.bmm`` / ``torch.baddbmm`` instead of the 2D Triton kernel,
    processing all E expert matrices in a single batched call.

    Args:
        G: Stacked expert matrices of shape (E, out, in), dtype COMM_DTYPE.
        steps: Number of Newton-Schulz iterations to run.

    Returns:
        Tensor of the same shape as ``G`` holding the iterated result for
        each expert matrix.
    """
    assert len(G.shape) == 3
    assert G.dtype == COMM_DTYPE
    X = G

    # Work on the wide orientation (rows <= cols); transposed back below.
    if G.size(1) > G.size(2):
        X = X.transpose(-2, -1)

    # Per-expert Frobenius norm.
    X = X / (X.norm(dim=(-2, -1), keepdim=True) + 1e-7)

    # _coeffs_list holds per-step (a, b, c) triples; when more steps are
    # requested than triples available, the last triple is repeated.
    hs = _coeffs_list[:steps] + list(
        repeat(_coeffs_list[-1], steps - len(_coeffs_list)))
    for a, b, c in hs:
        # buf1 = X X^T; buf2 = buf1 buf1^T; buf1 <- b*buf1 + c*buf2;
        # X <- a*X + buf1 @ X (fused via baddbmm).
        buf1 = torch.bmm(X, X.transpose(-2, -1))
        buf2 = torch.bmm(buf1, buf1.transpose(-2, -1))
        buf1.mul_(b).add_(buf2, alpha=c)
        X = torch.baddbmm(X, buf1, X, alpha=1.0, beta=a)

    # Undo the orientation swap for tall inputs.
    if G.size(1) > G.size(2):
        X = X.transpose(-2, -1)

    return X
197
+
198
+
199
# Per-shape cache of torch.compile'd Newton-Schulz callables. Keyed by the
# full input shape: 2D keys come from zeropower_via_newtonschulz5, 3D keys
# from the batched variant, so the two entry points never collide.
_ns_per_shape: dict[tuple[int, ...], callable] = {}
# Module-level switch for the compiled path; flip via set_ns_compile().
_use_compile = True


def set_ns_compile(enabled: bool):
    """Toggle torch.compile for Newton-Schulz iteration.

    Args:
        enabled: True to use (and lazily build) per-shape compiled
            Newton-Schulz functions; False to run the eager versions.
    """
    global _use_compile
    _use_compile = enabled
207
+
208
+
209
def _compiled_ns(fn, shape):
    """Return a torch.compile'd wrapper of ``fn`` cached per input shape.

    Shared by the 2D and batched 3D entry points (the duplicated cache
    logic previously lived in both). Each distinct shape gets its own
    compiled callable so CUDA graph replay never observes a shape change.
    """
    if shape not in _ns_per_shape:
        _ns_per_shape[shape] = torch.compile(fn,
                                             options={
                                                 "triton.cudagraphs": True,
                                                 "shape_padding": False
                                             })
    return _ns_per_shape[shape]


def zeropower_via_newtonschulz5(G, steps=5):
    """Compile-cached Newton-Schulz iteration for a single 2D tensor.

    Falls back to the eager implementation when compilation is disabled
    via set_ns_compile(False).
    """
    if not _use_compile:
        return _zeropower_via_newtonschulz5(G, steps)
    fn = _compiled_ns(_zeropower_via_newtonschulz5, G.shape)
    # Mark a new CUDA-graph step so graph-owned memory from the previous
    # call may be reused; clone() copies the result out of that memory.
    torch.compiler.cudagraph_mark_step_begin()
    return fn(G, steps).clone()


def zeropower_via_newtonschulz5_batched(G, steps=5):
    """Compile-cached batched Newton-Schulz for 3D expert tensors."""
    if not _use_compile:
        return _zeropower_via_newtonschulz5_batched(G, steps)
    fn = _compiled_ns(_zeropower_via_newtonschulz5_batched, G.shape)
    torch.compiler.cudagraph_mark_step_begin()
    return fn(G, steps).clone()
torch-ext/optimizer/qk_clip.py CHANGED
@@ -102,23 +102,27 @@ def compute_scales(p, qk_clip_state):
102
  threshold = qk_clip_state.threshold
103
  logit = qk_clip_state.logit
104
 
105
- H_global = p.shape[0] // head_dim
106
- scales_full = torch.ones(H_global, device=p.data.device)
107
- scaling = 0
108
-
109
  for logit_idx, head_idx in enumerate(indices):
110
  v_ele = float(logit[logit_idx])
111
  if v_ele > threshold:
112
  new_scale = math.sqrt(threshold / v_ele)
113
- if new_scale < scales_full[head_idx]:
114
- scales_full[head_idx] = new_scale
115
  logger.info(
116
  f"[{kind}] Head {head_idx} exceeded threshold "
117
  f"(value={v_ele:.4f}, threshold={threshold:.4f}) -> applying scale={new_scale:.4f}"
118
  )
119
- scaling += 1
120
 
121
- return scales_full if scaling > 0 else None
 
 
 
 
 
 
 
122
 
123
 
124
  def qk_clip(p, scales, head_dim):
 
102
  threshold = qk_clip_state.threshold
103
  logit = qk_clip_state.logit
104
 
105
+ # Check if any head exceeds threshold before allocating.
106
+ head_scales = {}
 
 
107
  for logit_idx, head_idx in enumerate(indices):
108
  v_ele = float(logit[logit_idx])
109
  if v_ele > threshold:
110
  new_scale = math.sqrt(threshold / v_ele)
111
+ if head_idx not in head_scales or new_scale < head_scales[head_idx]:
112
+ head_scales[head_idx] = new_scale
113
  logger.info(
114
  f"[{kind}] Head {head_idx} exceeded threshold "
115
  f"(value={v_ele:.4f}, threshold={threshold:.4f}) -> applying scale={new_scale:.4f}"
116
  )
 
117
 
118
+ if not head_scales:
119
+ return None
120
+
121
+ H_global = p.shape[0] // head_dim
122
+ scales_full = torch.ones(H_global, device=p.data.device)
123
+ for head_idx, scale in head_scales.items():
124
+ scales_full[head_idx] = scale
125
+ return scales_full
126
 
127
 
128
  def qk_clip(p, scales, head_dim):