Kernels
TaehyunKim committed
Commit c16b438 · unverified · 1 parent: 4f71bc9

Update muon.py

Files changed (1):
  1. torch-ext/optimizer/muon.py +8 -9
torch-ext/optimizer/muon.py CHANGED
@@ -597,7 +597,7 @@ class Muon(torch.optim.Optimizer):
         adjusted_lr = lr * adjusted_ratio
         return adjusted_lr

-    def get_shard_mesh(self, p, rank):
+    def get_shard_mesh(self, p):
         """
         Get the shard mesh for a parameter p on the given rank.
         """
@@ -609,8 +609,13 @@ class Muon(torch.optim.Optimizer):
             return p.device_mesh.mesh, p.device_mesh.get_group(mesh_dim=0)
         elif p.placements == (Replicate(), Shard(dim=0)):
             # Case for HSDP
+            process_group = p.device_mesh.get_group(mesh_dim=1)
+            if self.rank is None:
+                self.rank = dist.get_rank(group=process_group)
+            else:
+                assert self.rank == dist.get_rank(group=process_group)
             for i, shard_mesh in enumerate(p.device_mesh.mesh):
-                if rank in shard_mesh:
+                if self.rank in shard_mesh:
                     return shard_mesh, p.device_mesh.get_group(mesh_dim=1)
         else:
             raise ValueError(f"Unsupported placements ({p.placements}).")
@@ -651,15 +656,9 @@ class Muon(torch.optim.Optimizer):
         for n, p in zip(ordered_names, ordered_params):
             if mesh is None:
                 mesh = p.device_mesh
-                shard_mesh, process_group = self.get_shard_mesh(p, self.rank)
-                local_rank = dist.get_rank(group=process_group)
-                if self.rank is None:
-                    self.rank = dist.get_rank(group=process_group)
-                else:
-                    assert self.rank == local_rank
+                shard_mesh, process_group = self.get_shard_mesh(p)
             elif mesh != p.device_mesh:
                 raise ValueError("All parameters must be on the same mesh.")
-
             num_ranks = dist.get_world_size(group=process_group)
             param_to_state[id(p)] = _muon_state()
             param_to_state[id(
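In short, the commit moves rank discovery into get_shard_mesh: under HSDP the method now derives the caller's rank from the shard process group (mesh_dim=1), caches it on self.rank, and scans the rows of the 2-D device mesh for that rank, so the parameter loop no longer threads a rank argument through the call. Below is a minimal sketch of that row lookup, using a plain tensor as a stand-in for p.device_mesh.mesh and a hardcoded rank (both hypothetical values, not the repo's API):

import torch

# Hypothetical 2-D HSDP mesh: 2 replica groups x 4 shard ranks,
# standing in for p.device_mesh.mesh (global ranks 0..7).
mesh = torch.arange(8).reshape(2, 4)

# Stand-in for the rank the new code reads via
# dist.get_rank(group=process_group) and caches on self.rank.
rank = 5

# Mirrors the updated loop: pick the mesh row (shard group) holding this rank.
for shard_mesh in mesh:
    if rank in shard_mesh:  # `in` tests membership on a 1-D tensor
        print(f"rank {rank} shards with ranks {shard_mesh.tolist()}")
        break
# -> rank 5 shards with ranks [4, 5, 6, 7]

The break stands in for the early return in the real method; with the lookup and the rank caching both inside get_shard_mesh, the parameter loop in the last hunk reduces to a single self.get_shard_mesh(p) call.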