Kernels
wyldecat Claude Opus 4.6 committed on
Commit
bdada12
·
1 Parent(s): 1a97671

Add MoE uneven shard test with mixed expert and non-expert params [skip-build]

Browse files

Test parallel Muon with uneven dims (33, 19) mixing 2D DTensor params
(parallel pipeline) and 3D expert plain tensors (batched NS path).

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>

Files changed (1) hide show
  1. test/test_muon_moe.py +115 -1
test/test_muon_moe.py CHANGED
@@ -7,7 +7,7 @@ import pytest
7
  import torch
8
  import torch.distributed as dist
9
  from optimizer.muon import Muon, get_default_muon_param_groups
10
- from torch.distributed.tensor import DTensor, Replicate
11
  from torch.profiler import ProfilerActivity, profile
12
 
13
  from .utils import ParallelDims, assert_params_equal, parallelize_llama4
@@ -287,3 +287,117 @@ def test_parallel_muon_moe_few_experts(
287
  else:
288
  assert_params_equal(parallelized_model,
289
  sequential_moe_result_few_experts)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
7
  import torch
8
  import torch.distributed as dist
9
  from optimizer.muon import Muon, get_default_muon_param_groups
10
+ from torch.distributed.tensor import DTensor, Replicate, Shard, distribute_tensor
11
  from torch.profiler import ProfilerActivity, profile
12
 
13
  from .utils import ParallelDims, assert_params_equal, parallelize_llama4
 
287
  else:
288
  assert_params_equal(parallelized_model,
289
  sequential_moe_result_few_experts)
290
+
291
+
292
# ---------------------------------------------------------------------------
# Uneven shard test: mixed expert (3D plain) + non-expert (2D DTensor)
# with dimensions not evenly divisible by shard count.
# ---------------------------------------------------------------------------


@pytest.mark.parametrize("uneven_dim", [
    pytest.param(33, id="33"),
    pytest.param(19, id="19"),
])
def test_parallel_muon_moe_uneven_shard(init_dist, uneven_dim):
    """Test MoE parallel Muon with uneven shard dimensions.

    Mixes non-expert 2D DTensor params (uneven FSDP sharding, parallel
    pipeline path) with expert 3D plain-tensor params (batched NS path).
    Verifies the combination produces correct results vs sequential baseline.
    """
    from optimizer.newton_schulz import set_ns_compile

    rank = dist.get_rank()
    world_size = dist.get_world_size()
    mesh = dist.init_device_mesh("cuda", (world_size, ),
                                 mesh_dim_names=("dp", ))

    # Disable NS compilation for this test.  Restored in the `finally`
    # below so a failing assertion cannot leak the disabled state into
    # later tests running in the same process.
    set_ns_compile(False)
    try:
        torch.manual_seed(42)

        other_dim = 64
        num_experts = 4

        muon_params = []
        muon_names = []
        full_params = []   # unsharded copies, inputs to the sequential baseline
        full_grads = []    # matching unsharded gradients

        # 2D non-expert params with uneven dims → parallel pipeline
        for i in range(2):
            full = torch.randn(uneven_dim, other_dim, device="cuda")
            full_params.append(full.clone())
            dt = distribute_tensor(full, mesh, [Shard(0)])
            p = torch.nn.Parameter(dt)
            g = torch.randn(uneven_dim, other_dim, device="cuda")
            full_grads.append(g.clone())
            p.grad = distribute_tensor(g, mesh, [Shard(0)])
            muon_params.append(p)
            muon_names.append(f"layers.{i}.weight")

        # 3D expert params (plain tensors) → batched NS path
        full = torch.randn(num_experts, uneven_dim, other_dim, device="cuda")
        full_params.append(full.clone())
        p = torch.nn.Parameter(full)
        g = torch.randn(num_experts, uneven_dim, other_dim, device="cuda")
        full_grads.append(g.clone())
        p.grad = g
        muon_params.append(p)
        # Name contains "experts" so it is routed via `expert_keys` below.
        muon_names.append("layers.2.experts.w1.weight")

        # --- Parallel path ---
        param_groups_par = [{
            "params": muon_params,
            "names": muon_names,
            "use_muon": True,
            "lr": 0.02,
            "weight_decay": 0.01,
            "momentum": 0.95,
            "nesterov": True,
            "ns_steps": 5,
            "none_grad": False,
        }]
        optim_par = Muon(params=param_groups_par,
                         chunk_size=1,
                         warmup_step=0,
                         expert_keys=["experts"])
        optim_par.step()

        # --- Sequential baseline ---
        seq_params = [torch.nn.Parameter(fp.clone()) for fp in full_params]
        for p, g in zip(seq_params, full_grads):
            p.grad = g.clone()

        param_groups_seq = [{
            "params": seq_params,
            "names": list(muon_names),
            "use_muon": True,
            "lr": 0.02,
            "weight_decay": 0.01,
            "momentum": 0.95,
            "nesterov": True,
            "ns_steps": 5,
            "none_grad": False,
        }]
        # NOTE(review): the baseline relies on Muon's default chunk_size /
        # warmup_step — confirm those defaults are equivalent to the
        # explicit chunk_size=1, warmup_step=0 used on the parallel path.
        optim_seq = Muon(params=param_groups_seq, expert_keys=["experts"])
        optim_seq.step()

        # --- Compare ---
        # atol=0 / rtol=0: the parallel and sequential paths are expected
        # to produce bitwise-identical updates.
        for i in range(len(muon_params)):
            par_data = muon_params[i].data
            if isinstance(par_data, DTensor):
                # Gather the sharded result before comparing against the
                # unsharded sequential parameter.
                par_data = par_data.full_tensor()
            torch.testing.assert_close(par_data,
                                       seq_params[i].data,
                                       atol=0,
                                       rtol=0)
    finally:
        # Always re-enable NS compilation, even on assertion failure.
        set_ns_compile(True)

    logger.info(
        "test_parallel_muon_moe_uneven_shard (dim=%d) PASSED (rank %d)",
        uneven_dim, rank)