Kernels
wyldecat Claude Opus 4.6 committed on
Commit
1a97671
·
1 Parent(s): 14040eb

Add uneven shard correctness test [skip-build]

Browse files

Test parallel Muon with param dimensions not divisible by shard count
(dim=33,19,11 with 8 ranks). Verifies against sequential baseline.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>

Files changed (1) hide show
  1. test/test_muon.py +90 -0
test/test_muon.py CHANGED
@@ -301,3 +301,93 @@ def test_parallel_muon_empty_shard(init_dist):
301
 
302
  set_ns_compile(True)
303
  logger.info("test_parallel_muon_empty_shard PASSED (rank %d)", rank)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
301
 
302
  set_ns_compile(True)
303
  logger.info("test_parallel_muon_empty_shard PASSED (rank %d)", rank)
304
+
305
+
306
@pytest.mark.parametrize("uneven_dim", [
    pytest.param(33, id="33"),
    pytest.param(19, id="19"),
    pytest.param(11, id="11"),
])
def test_parallel_muon_uneven_shard(init_dist, uneven_dim):
    """Test that parallel Muon produces correct results when parameter
    dimensions are not evenly divisible by the number of shard ranks.

    For example, dim=33 with 8 ranks gives 7 ranks with 4 rows and
    1 rank with 5 rows. This exercises the remainder-handling logic
    in ``get_slices_of_dtensor`` and the all-to-all pipeline.
    """
    rank = dist.get_rank()
    world_size = dist.get_world_size()
    mesh = dist.init_device_mesh("cuda", (world_size, ),
                                 mesh_dim_names=("dp", ))

    # Disable NS compilation for the test run; restored at the end.
    set_ns_compile(False)
    torch.manual_seed(42)

    other_dim = 64
    num_params = 3

    # Single source of truth for optimizer hyperparameters: the comparison
    # below is only valid if the parallel and sequential runs use identical
    # settings, so the dict is shared instead of duplicated per path.
    muon_hparams = {
        "use_muon": True,
        "lr": 0.02,
        "weight_decay": 0.01,
        "momentum": 0.95,
        "nesterov": True,
        "ns_steps": 5,
        "none_grad": False,
    }

    # --- Build sharded params + grads ---
    # Snapshots of the full (unsharded) params/grads are kept so the
    # sequential baseline can start from bit-identical state.
    muon_params = []
    muon_names = []
    full_params_snapshot = []
    full_grads = []

    for i in range(num_params):
        full = torch.randn(uneven_dim, other_dim, device="cuda")
        full_params_snapshot.append(full.clone())
        dt = distribute_tensor(full, mesh, [Shard(0)])
        p = torch.nn.Parameter(dt)
        grad_full = torch.randn(uneven_dim, other_dim, device="cuda")
        full_grads.append(grad_full.clone())
        p.grad = distribute_tensor(grad_full, mesh, [Shard(0)])
        muon_params.append(p)
        muon_names.append(f"layer.{i}.weight")

    # --- Parallel path (all2all pipeline) ---
    param_groups_par = [{
        "params": muon_params,
        "names": muon_names,
        **muon_hparams,
    }]
    optim_par = Muon(params=param_groups_par, chunk_size=1, warmup_step=0)
    optim_par.step()

    # --- Sequential baseline (base path, no sharding) ---
    seq_params = []
    seq_names = []
    for i in range(num_params):
        p = torch.nn.Parameter(full_params_snapshot[i].clone())
        p.grad = full_grads[i].clone()
        seq_params.append(p)
        seq_names.append(f"layer.{i}.weight")

    param_groups_seq = [{
        "params": seq_params,
        "names": seq_names,
        **muon_hparams,
    }]
    optim_seq = Muon(params=param_groups_seq)
    optim_seq.step()

    # --- Compare: parallel result (gathered) must match sequential ---
    # NOTE(review): atol=0/rtol=0 demands *bitwise* equality; this only holds
    # if the sharded pipeline performs numerically identical float ops in the
    # same order as the base path — confirm against the implementation.
    for i in range(num_params):
        par_full = muon_params[i].data.full_tensor()
        seq_full = seq_params[i].data
        torch.testing.assert_close(par_full, seq_full, atol=0, rtol=0)

    set_ns_compile(True)
    logger.info("test_parallel_muon_uneven_shard (dim=%d) PASSED (rank %d)",
                uneven_dim, rank)