Kernels
wyldecat Claude Opus 4.6 committed on
Commit
81f49fe
·
1 Parent(s): 0f37d63

Update tests for MoE and parallel optimizations [skip-build]

Browse files

- Add MoE test cases (test_muon_moe.py)
- Update parallel test configurations
- Test utility updates

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>

test/test_muon.py CHANGED
@@ -7,7 +7,9 @@ import pytest
7
  import torch
8
  import torch.distributed as dist
9
  from optimizer.muon import Muon, get_default_muon_param_groups
10
- from torch.distributed.tensor import DTensor, Replicate
 
 
11
  from torch.profiler import ProfilerActivity, profile
12
 
13
  from .utils import (ParallelDims, assert_params_equal, parallelize_motif,
@@ -23,7 +25,6 @@ def apply_muon_step(
23
  grads: list[torch.Tensor],
24
  warmup_step: int,
25
  chunk_size: int,
26
- small_param_numel_threshold: int,
27
  qk_logits: dict[int, torch.Tensor] | None = None,
28
  use_distributed_muon: bool = False,
29
  measure_perf: bool = False,
@@ -67,7 +68,6 @@ def apply_muon_step(
67
  none_grad=False,
68
  warmup_step=warmup_step,
69
  chunk_size=chunk_size,
70
- small_param_numel_threshold=small_param_numel_threshold,
71
  use_distributed_muon=use_distributed_muon,
72
  )
73
 
@@ -119,43 +119,45 @@ def apply_muon_step(
119
  def sequential_muon_result(
120
  skip_verify, # from conftest.py
121
  inputs # from conftest.py
122
- ) -> dict[bool, torch.nn.Module]:
123
- """Run Muon optimizer to sequential model for baseline results."""
 
 
 
124
  if skip_verify:
125
  logger.info("Skipping verification tests as per user request")
126
  return None
127
 
128
  model, grads, qk_logits = inputs
 
129
 
130
- result = apply_muon_step(
131
- model=copy.deepcopy(model).cuda(),
132
- parallel_dims=None,
133
- grads=grads,
134
- warmup_step=-1,
135
- chunk_size=-1,
136
- small_param_numel_threshold=-1,
137
- qk_logits=None,
138
- )[0].cpu()
139
-
140
- result_qk_clip = apply_muon_step(
141
- model=copy.deepcopy(model).cuda(),
142
- parallel_dims=None,
143
- grads=grads,
144
- warmup_step=-1,
145
- chunk_size=-1,
146
- small_param_numel_threshold=-1,
147
- qk_logits=qk_logits,
148
- )[0].cpu()
149
 
150
- return {
151
- False: result,
152
- True: result_qk_clip,
153
- }
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
154
 
155
 
156
  OVERLAP_STEPS = [5]
157
  CHUNK_SIZES = [2]
158
- SMALL_PARAM_NUMEL_THRESHOLDS = [65536, 1_000_000_000]
159
 
160
 
161
  @pytest.mark.parametrize("parallel_dims", [
@@ -170,17 +172,16 @@ SMALL_PARAM_NUMEL_THRESHOLDS = [65536, 1_000_000_000]
170
  @pytest.mark.parametrize("use_distributed_muon", [False])
171
  @pytest.mark.parametrize("warmup_step", OVERLAP_STEPS)
172
  @pytest.mark.parametrize("chunk_size", CHUNK_SIZES)
173
- @pytest.mark.parametrize("small_param_numel_threshold",
174
- SMALL_PARAM_NUMEL_THRESHOLDS)
175
  def test_parallel_muon(
176
  request,
177
- sequential_muon_result: dict[bool, torch.nn.Module],
178
  parallel_dims: ParallelDims,
179
  apply_qk_clip: bool,
180
  use_distributed_muon: bool,
181
  warmup_step: int,
182
  chunk_size: int,
183
- small_param_numel_threshold: int,
184
  inputs: tuple[torch.nn.Module, list[torch.Tensor],
185
  dict[int, torch.Tensor]], # from conftest.py
186
  measure_perf, # from conftest.py
@@ -191,6 +192,8 @@ def test_parallel_muon(
191
  if use_distributed_muon and warmup_step != OVERLAP_STEPS[0]:
192
  pytest.skip("Distributed Muon does not effected by warmup step")
193
 
 
 
194
  model, grads, qk_logits = inputs
195
 
196
  if not apply_qk_clip:
@@ -212,7 +215,6 @@ def test_parallel_muon(
212
  grads=grads,
213
  warmup_step=warmup_step,
214
  chunk_size=chunk_size,
215
- small_param_numel_threshold=small_param_numel_threshold,
216
  qk_logits=qk_logits,
217
  use_distributed_muon=use_distributed_muon,
218
  measure_perf=measure_perf,
@@ -236,5 +238,66 @@ def test_parallel_muon(
236
  elif measure_perf:
237
  logger.info("Skipping correctness check as timing is enabled")
238
  else:
 
 
239
  assert_params_equal(parallelized_model,
240
- sequential_muon_result[apply_qk_clip])
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
7
  import torch
8
  import torch.distributed as dist
9
  from optimizer.muon import Muon, get_default_muon_param_groups
10
+ from optimizer.newton_schulz import set_ns_compile
11
+ from torch.distributed.tensor import (DTensor, Replicate, Shard,
12
+ distribute_tensor)
13
  from torch.profiler import ProfilerActivity, profile
14
 
15
  from .utils import (ParallelDims, assert_params_equal, parallelize_motif,
 
25
  grads: list[torch.Tensor],
26
  warmup_step: int,
27
  chunk_size: int,
 
28
  qk_logits: dict[int, torch.Tensor] | None = None,
29
  use_distributed_muon: bool = False,
30
  measure_perf: bool = False,
 
68
  none_grad=False,
69
  warmup_step=warmup_step,
70
  chunk_size=chunk_size,
 
71
  use_distributed_muon=use_distributed_muon,
72
  )
73
 
 
119
  def sequential_muon_result(
120
  skip_verify, # from conftest.py
121
  inputs # from conftest.py
122
+ ) -> dict[tuple[bool, bool], torch.nn.Module]:
123
+ """Run Muon optimizer to sequential model for baseline results.
124
+
125
+ Returns dict keyed by ``(apply_qk_clip, use_compile)``.
126
+ """
127
  if skip_verify:
128
  logger.info("Skipping verification tests as per user request")
129
  return None
130
 
131
  model, grads, qk_logits = inputs
132
+ results: dict[tuple[bool, bool], torch.nn.Module] = {}
133
 
134
+ for use_compile in [False, True]:
135
+ set_ns_compile(use_compile)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
136
 
137
+ results[(False, use_compile)] = apply_muon_step(
138
+ model=copy.deepcopy(model).cuda(),
139
+ parallel_dims=None,
140
+ grads=grads,
141
+ warmup_step=-1,
142
+ chunk_size=-1,
143
+ qk_logits=None,
144
+ )[0].cpu()
145
+
146
+ results[(True, use_compile)] = apply_muon_step(
147
+ model=copy.deepcopy(model).cuda(),
148
+ parallel_dims=None,
149
+ grads=grads,
150
+ warmup_step=-1,
151
+ chunk_size=-1,
152
+ qk_logits=qk_logits,
153
+ )[0].cpu()
154
+
155
+ set_ns_compile(True) # restore default
156
+ return results
157
 
158
 
159
  OVERLAP_STEPS = [5]
160
  CHUNK_SIZES = [2]
 
161
 
162
 
163
  @pytest.mark.parametrize("parallel_dims", [
 
172
  @pytest.mark.parametrize("use_distributed_muon", [False])
173
  @pytest.mark.parametrize("warmup_step", OVERLAP_STEPS)
174
  @pytest.mark.parametrize("chunk_size", CHUNK_SIZES)
175
+ @pytest.mark.parametrize("use_compile", [False, True])
 
176
  def test_parallel_muon(
177
  request,
178
+ sequential_muon_result: dict[tuple[bool, bool], torch.nn.Module],
179
  parallel_dims: ParallelDims,
180
  apply_qk_clip: bool,
181
  use_distributed_muon: bool,
182
  warmup_step: int,
183
  chunk_size: int,
184
+ use_compile: bool,
185
  inputs: tuple[torch.nn.Module, list[torch.Tensor],
186
  dict[int, torch.Tensor]], # from conftest.py
187
  measure_perf, # from conftest.py
 
192
  if use_distributed_muon and warmup_step != OVERLAP_STEPS[0]:
193
  pytest.skip("Distributed Muon does not effected by warmup step")
194
 
195
+ set_ns_compile(use_compile)
196
+
197
  model, grads, qk_logits = inputs
198
 
199
  if not apply_qk_clip:
 
215
  grads=grads,
216
  warmup_step=warmup_step,
217
  chunk_size=chunk_size,
 
218
  qk_logits=qk_logits,
219
  use_distributed_muon=use_distributed_muon,
220
  measure_perf=measure_perf,
 
238
  elif measure_perf:
239
  logger.info("Skipping correctness check as timing is enabled")
240
  else:
241
+ atol = 1e-5 if use_compile else 0
242
+ rtol = 1e-2 if use_compile else 0
243
  assert_params_equal(parallelized_model,
244
+ sequential_muon_result[(apply_qk_clip,
245
+ use_compile)],
246
+ atol=atol,
247
+ rtol=rtol)
248
+
249
+
250
+ def test_parallel_muon_empty_shard(init_dist):
251
+ """Regression: parallel Muon must handle chunks where some ranks have
252
+ empty local shards (dim-0 < world_size).
253
+
254
+ With 8-way Shard(0) and dim-0 of size 4, ranks 4-7 get 0-element local
255
+ shards. Previously ``_launch_gather`` hit ``assert total_send > 0``.
256
+ """
257
+ rank = dist.get_rank()
258
+ world_size = dist.get_world_size()
259
+ mesh = dist.init_device_mesh("cuda", (world_size, ),
260
+ mesh_dim_names=("dp", ))
261
+
262
+ set_ns_compile(False)
263
+
264
+ # dim-0 = 4 < 8 ranks → ranks 4-7 have empty local shards with Shard(0)
265
+ small_dim = 4
266
+ num_params = 4
267
+ torch.manual_seed(42)
268
+
269
+ muon_params = []
270
+ muon_names = []
271
+ for i in range(num_params):
272
+ full = torch.randn(small_dim, 64, device="cuda")
273
+ dt = distribute_tensor(full, mesh, [Shard(0)])
274
+ p = torch.nn.Parameter(dt)
275
+ grad_full = torch.randn(small_dim, 64, device="cuda")
276
+ p.grad = distribute_tensor(grad_full, mesh, [Shard(0)])
277
+ muon_params.append(p)
278
+ muon_names.append(f"layer.{i}.weight")
279
+
280
+ param_groups = [{
281
+ "params": muon_params,
282
+ "names": muon_names,
283
+ "use_muon": True,
284
+ "lr": 0.02,
285
+ "weight_decay": 0.01,
286
+ "momentum": 0.95,
287
+ "nesterov": True,
288
+ "ns_steps": 5,
289
+ "none_grad": False,
290
+ }]
291
+
292
+ optim = Muon(params=param_groups, chunk_size=1, warmup_step=0)
293
+ # Must not raise AssertionError: total_send > 0
294
+ optim.step()
295
+
296
+ # Run a second step to verify cached path also works
297
+ for p in muon_params:
298
+ grad_full = torch.randn(small_dim, 64, device="cuda")
299
+ p.grad = distribute_tensor(grad_full, mesh, [Shard(0)])
300
+ optim.step()
301
+
302
+ set_ns_compile(True)
303
+ logger.info("test_parallel_muon_empty_shard PASSED (rank %d)", rank)
test/test_muon_moe.py CHANGED
@@ -45,7 +45,6 @@ def apply_muon_step_moe(
45
  grads: list[torch.Tensor],
46
  warmup_step: int,
47
  chunk_size: int,
48
- small_param_numel_threshold: int,
49
  use_distributed_muon: bool = False,
50
  measure_perf: bool = False,
51
  do_profile: bool = False,
@@ -63,7 +62,6 @@ def apply_muon_step_moe(
63
  none_grad=False,
64
  warmup_step=warmup_step,
65
  chunk_size=chunk_size,
66
- small_param_numel_threshold=small_param_numel_threshold,
67
  use_distributed_muon=use_distributed_muon,
68
  expert_keys=["experts"],
69
  )
@@ -73,6 +71,10 @@ def apply_muon_step_moe(
73
 
74
  optim.step()
75
 
 
 
 
 
76
  timing_result: tuple[float, float] | None = None
77
 
78
  if measure_perf:
@@ -133,7 +135,6 @@ def sequential_moe_result(
133
  grads=grads,
134
  warmup_step=-1,
135
  chunk_size=-1,
136
- small_param_numel_threshold=-1,
137
  )
138
  result = result.cpu()
139
 
@@ -142,25 +143,26 @@ def sequential_moe_result(
142
 
143
  OVERLAP_STEPS = [5]
144
  CHUNK_SIZES = [2]
145
- SMALL_PARAM_NUMEL_THRESHOLDS = [65536, 1_000_000_000]
146
 
147
 
148
- @pytest.mark.parametrize("parallel_dims", [
149
- pytest.param(ParallelDims(8, 1, 1), id="base"),
150
- pytest.param(ParallelDims(1, 8, 1), id="fsdp"),
151
- pytest.param(ParallelDims(2, 4, 1), id="hsdp"),
152
- pytest.param(ParallelDims(2, 2, 2), id="hsdp+tp"),
153
- pytest.param(ParallelDims(1, 2, 4), id="fsdp+tp"),
154
- pytest.param(ParallelDims(1, 1, 1, ep_degree=8), id="ep"),
155
- pytest.param(ParallelDims(1, 4, 1, ep_degree=2), id="ep+fsdp"),
156
- pytest.param(ParallelDims(1, 2, 1, ep_degree=4), id="ep4+fsdp"),
157
- pytest.param(ParallelDims(2, 2, 1, ep_degree=2), id="ep+hsdp"),
158
- ])
 
 
 
 
159
  @pytest.mark.parametrize("use_distributed_muon", [False])
160
  @pytest.mark.parametrize("warmup_step", OVERLAP_STEPS)
161
  @pytest.mark.parametrize("chunk_size", CHUNK_SIZES)
162
- @pytest.mark.parametrize("small_param_numel_threshold",
163
- SMALL_PARAM_NUMEL_THRESHOLDS)
164
  def test_parallel_muon_moe(
165
  request,
166
  sequential_moe_result: torch.nn.Module | None,
@@ -168,7 +170,6 @@ def test_parallel_muon_moe(
168
  use_distributed_muon: bool,
169
  warmup_step: int,
170
  chunk_size: int,
171
- small_param_numel_threshold: int,
172
  moe_inputs: tuple[torch.nn.Module, list[torch.Tensor]],
173
  measure_perf,
174
  do_profile,
@@ -186,7 +187,6 @@ def test_parallel_muon_moe(
186
  grads=grads,
187
  warmup_step=warmup_step,
188
  chunk_size=chunk_size,
189
- small_param_numel_threshold=small_param_numel_threshold,
190
  use_distributed_muon=use_distributed_muon,
191
  measure_perf=measure_perf,
192
  do_profile=do_profile,
@@ -231,7 +231,6 @@ def sequential_moe_result_few_experts(
231
  grads=grads,
232
  warmup_step=-1,
233
  chunk_size=-1,
234
- small_param_numel_threshold=-1,
235
  )
236
  result = result.cpu()
237
 
@@ -239,14 +238,12 @@ def sequential_moe_result_few_experts(
239
 
240
 
241
  @pytest.mark.parametrize("parallel_dims", [
242
- pytest.param(ParallelDims(1, 4, 1, ep_degree=2), id="ep+fsdp"),
243
- pytest.param(ParallelDims(2, 2, 1, ep_degree=2), id="ep+hsdp"),
244
  ])
245
  @pytest.mark.parametrize("use_distributed_muon", [False])
246
  @pytest.mark.parametrize("warmup_step", OVERLAP_STEPS)
247
  @pytest.mark.parametrize("chunk_size", CHUNK_SIZES)
248
- @pytest.mark.parametrize("small_param_numel_threshold",
249
- SMALL_PARAM_NUMEL_THRESHOLDS)
250
  def test_parallel_muon_moe_few_experts(
251
  request,
252
  sequential_moe_result_few_experts: torch.nn.Module | None,
@@ -254,7 +251,6 @@ def test_parallel_muon_moe_few_experts(
254
  use_distributed_muon: bool,
255
  warmup_step: int,
256
  chunk_size: int,
257
- small_param_numel_threshold: int,
258
  moe_inputs_few_experts: tuple[torch.nn.Module, list[torch.Tensor]],
259
  measure_perf,
260
  do_profile,
@@ -271,7 +267,6 @@ def test_parallel_muon_moe_few_experts(
271
  grads=grads,
272
  warmup_step=warmup_step,
273
  chunk_size=chunk_size,
274
- small_param_numel_threshold=small_param_numel_threshold,
275
  use_distributed_muon=use_distributed_muon,
276
  measure_perf=measure_perf,
277
  do_profile=do_profile,
 
45
  grads: list[torch.Tensor],
46
  warmup_step: int,
47
  chunk_size: int,
 
48
  use_distributed_muon: bool = False,
49
  measure_perf: bool = False,
50
  do_profile: bool = False,
 
62
  none_grad=False,
63
  warmup_step=warmup_step,
64
  chunk_size=chunk_size,
 
65
  use_distributed_muon=use_distributed_muon,
66
  expert_keys=["experts"],
67
  )
 
71
 
72
  optim.step()
73
 
74
+ # Second step to exercise expert expand cache hot path.
75
+ _restore_grads(model, saved_grads)
76
+ optim.step()
77
+
78
  timing_result: tuple[float, float] | None = None
79
 
80
  if measure_perf:
 
135
  grads=grads,
136
  warmup_step=-1,
137
  chunk_size=-1,
 
138
  )
139
  result = result.cpu()
140
 
 
143
 
144
  OVERLAP_STEPS = [5]
145
  CHUNK_SIZES = [2]
 
146
 
147
 
148
+ @pytest.mark.parametrize(
149
+ "parallel_dims",
150
+ [
151
+ # --- No EP (non-expert only) ---
152
+ pytest.param(ParallelDims(8, 1, 1), id="dp8"),
153
+ pytest.param(ParallelDims(1, 8, 1), id="fsdp8"),
154
+ pytest.param(ParallelDims(2, 4, 1), id="hsdp2x4"),
155
+ # --- EP configs ---
156
+ # naming: fsdp{dp_shard}_ep{ep} where dp_shard = dp_shard_mod_ep * ep
157
+ # dp_shard_mod_ep (= expert FSDP) = dp_shard_degree in our ParallelDims
158
+ pytest.param(ParallelDims(1, 1, 1, ep_degree=8), id="fsdp8_ep8"),
159
+ pytest.param(ParallelDims(1, 4, 1, ep_degree=2), id="fsdp8_ep2"),
160
+ pytest.param(ParallelDims(1, 2, 1, ep_degree=4), id="fsdp8_ep4"),
161
+ pytest.param(ParallelDims(2, 2, 1, ep_degree=2), id="hsdp_ep2"),
162
+ ])
163
  @pytest.mark.parametrize("use_distributed_muon", [False])
164
  @pytest.mark.parametrize("warmup_step", OVERLAP_STEPS)
165
  @pytest.mark.parametrize("chunk_size", CHUNK_SIZES)
 
 
166
  def test_parallel_muon_moe(
167
  request,
168
  sequential_moe_result: torch.nn.Module | None,
 
170
  use_distributed_muon: bool,
171
  warmup_step: int,
172
  chunk_size: int,
 
173
  moe_inputs: tuple[torch.nn.Module, list[torch.Tensor]],
174
  measure_perf,
175
  do_profile,
 
187
  grads=grads,
188
  warmup_step=warmup_step,
189
  chunk_size=chunk_size,
 
190
  use_distributed_muon=use_distributed_muon,
191
  measure_perf=measure_perf,
192
  do_profile=do_profile,
 
231
  grads=grads,
232
  warmup_step=-1,
233
  chunk_size=-1,
 
234
  )
235
  result = result.cpu()
236
 
 
238
 
239
 
240
  @pytest.mark.parametrize("parallel_dims", [
241
+ pytest.param(ParallelDims(1, 4, 1, ep_degree=2), id="fsdp8_ep2"),
242
+ pytest.param(ParallelDims(2, 2, 1, ep_degree=2), id="hsdp_ep2"),
243
  ])
244
  @pytest.mark.parametrize("use_distributed_muon", [False])
245
  @pytest.mark.parametrize("warmup_step", OVERLAP_STEPS)
246
  @pytest.mark.parametrize("chunk_size", CHUNK_SIZES)
 
 
247
  def test_parallel_muon_moe_few_experts(
248
  request,
249
  sequential_moe_result_few_experts: torch.nn.Module | None,
 
251
  use_distributed_muon: bool,
252
  warmup_step: int,
253
  chunk_size: int,
 
254
  moe_inputs_few_experts: tuple[torch.nn.Module, list[torch.Tensor]],
255
  measure_perf,
256
  do_profile,
 
267
  grads=grads,
268
  warmup_step=warmup_step,
269
  chunk_size=chunk_size,
 
270
  use_distributed_muon=use_distributed_muon,
271
  measure_perf=measure_perf,
272
  do_profile=do_profile,
test/test_normalize_fqn.py CHANGED
@@ -1,6 +1,5 @@
1
  """Unit tests for FQN normalization (no GPU / distributed required)."""
2
 
3
-
4
  from optimizer.core import default_is_muon, is_expert_param, normalize_fqn
5
  from optimizer.qk_clip import parse_qk_layer
6
 
 
1
  """Unit tests for FQN normalization (no GPU / distributed required)."""
2
 
 
3
  from optimizer.core import default_is_muon, is_expert_param, normalize_fqn
4
  from optimizer.qk_clip import parse_qk_layer
5
 
test/utils.py CHANGED
@@ -259,12 +259,16 @@ def parallelize_qk_logits(
259
 
260
 
261
  def assert_params_equal(actual: torch.nn.Module,
262
- expected: torch.nn.Module) -> None:
 
 
263
  """Asserts that the parameters of two models are equal.
264
 
265
  Args:
266
  actual (torch.nn.Module): The actual model.
267
  expected (torch.nn.Module): The expected model.
 
 
268
  Returns:
269
  None
270
  """
@@ -279,4 +283,4 @@ def assert_params_equal(actual: torch.nn.Module,
279
  p = get_full_param(p.cuda())
280
  s = get_full_param(s.cuda())
281
 
282
- torch.testing.assert_close(p, s, atol=0, rtol=0)
 
259
 
260
 
261
  def assert_params_equal(actual: torch.nn.Module,
262
+ expected: torch.nn.Module,
263
+ atol: float = 0,
264
+ rtol: float = 0) -> None:
265
  """Asserts that the parameters of two models are equal.
266
 
267
  Args:
268
  actual (torch.nn.Module): The actual model.
269
  expected (torch.nn.Module): The expected model.
270
+ atol: Absolute tolerance.
271
+ rtol: Relative tolerance.
272
  Returns:
273
  None
274
  """
 
283
  p = get_full_param(p.cuda())
284
  s = get_full_param(s.cuda())
285
 
286
+ torch.testing.assert_close(p, s, atol=atol, rtol=rtol)