Kernels
TaehyunKim commited on
Commit
ff2fcfb
·
unverified ·
1 Parent(s): c16b438

Update torch-ext/optimizer/muon.py

Browse files
Files changed (1) hide show
  1. torch-ext/optimizer/muon.py +1 -1
torch-ext/optimizer/muon.py CHANGED
@@ -610,7 +610,7 @@ class Muon(torch.optim.Optimizer):
610
  elif p.placements == (Replicate(), Shard(dim=0)):
611
  # Case for HSDP
612
  process_group = p.device_mesh.get_group(mesh_dim=1)
613
- if self.rank is None:
614
  self.rank = dist.get_rank(group=process_group)
615
  else:
616
  assert self.rank == dist.get_rank(group=process_group)
 
610
  elif p.placements == (Replicate(), Shard(dim=0)):
611
  # Case for HSDP
612
  process_group = p.device_mesh.get_group(mesh_dim=1)
613
+ if self.rank is None:
614
  self.rank = dist.get_rank(group=process_group)
615
  else:
616
  assert self.rank == dist.get_rank(group=process_group)