TaehyunKim
committed on
Update muon.py
Browse files
torch-ext/optimizer/muon.py
CHANGED
|
@@ -597,7 +597,7 @@ class Muon(torch.optim.Optimizer):
|
|
| 597 |
adjusted_lr = lr * adjusted_ratio
|
| 598 |
return adjusted_lr
|
| 599 |
|
| 600 |
-
def get_shard_mesh(self, p
|
| 601 |
"""
|
| 602 |
Get the shard mesh for a parameter p on the given rank.
|
| 603 |
"""
|
|
@@ -609,8 +609,13 @@ class Muon(torch.optim.Optimizer):
|
|
| 609 |
return p.device_mesh.mesh, p.device_mesh.get_group(mesh_dim=0)
|
| 610 |
elif p.placements == (Replicate(), Shard(dim=0)):
|
| 611 |
# Case for HSDP
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 612 |
for i, shard_mesh in enumerate(p.device_mesh.mesh):
|
| 613 |
-
if rank in shard_mesh:
|
| 614 |
return shard_mesh, p.device_mesh.get_group(mesh_dim=1)
|
| 615 |
else:
|
| 616 |
raise ValueError(f"Unsupported placements ({p.placements}).")
|
|
@@ -651,15 +656,9 @@ class Muon(torch.optim.Optimizer):
|
|
| 651 |
for n, p in zip(ordered_names, ordered_params):
|
| 652 |
if mesh is None:
|
| 653 |
mesh = p.device_mesh
|
| 654 |
-
shard_mesh, process_group = self.get_shard_mesh(p
|
| 655 |
-
local_rank = dist.get_rank(group=process_group)
|
| 656 |
-
if self.rank is None:
|
| 657 |
-
self.rank = dist.get_rank(group=process_group)
|
| 658 |
-
else:
|
| 659 |
-
assert self.rank == local_rank
|
| 660 |
elif mesh != p.device_mesh:
|
| 661 |
raise ValueError("All parameters must be on the same mesh.")
|
| 662 |
-
|
| 663 |
num_ranks = dist.get_world_size(group=process_group)
|
| 664 |
param_to_state[id(p)] = _muon_state()
|
| 665 |
param_to_state[id(
|
|
|
|
| 597 |
adjusted_lr = lr * adjusted_ratio
|
| 598 |
return adjusted_lr
|
| 599 |
|
| 600 |
+
def get_shard_mesh(self, p):
|
| 601 |
"""
|
| 602 |
Get the shard mesh for a parameter p on the given rank.
|
| 603 |
"""
|
|
|
|
| 609 |
return p.device_mesh.mesh, p.device_mesh.get_group(mesh_dim=0)
|
| 610 |
elif p.placements == (Replicate(), Shard(dim=0)):
|
| 611 |
# Case for HSDP
|
| 612 |
+
process_group = p.device_mesh.get_group(mesh_dim=1)
|
| 613 |
+
if self.rank is None:
|
| 614 |
+
self.rank = dist.get_rank(group=process_group)
|
| 615 |
+
else:
|
| 616 |
+
assert self.rank == dist.get_rank(group=process_group)
|
| 617 |
for i, shard_mesh in enumerate(p.device_mesh.mesh):
|
| 618 |
+
if self.rank in shard_mesh:
|
| 619 |
return shard_mesh, p.device_mesh.get_group(mesh_dim=1)
|
| 620 |
else:
|
| 621 |
raise ValueError(f"Unsupported placements ({p.placements}).")
|
|
|
|
| 656 |
for n, p in zip(ordered_names, ordered_params):
|
| 657 |
if mesh is None:
|
| 658 |
mesh = p.device_mesh
|
| 659 |
+
shard_mesh, process_group = self.get_shard_mesh(p)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 660 |
elif mesh != p.device_mesh:
|
| 661 |
raise ValueError("All parameters must be on the same mesh.")
|
|
|
|
| 662 |
num_ranks = dist.get_world_size(group=process_group)
|
| 663 |
param_to_state[id(p)] = _muon_state()
|
| 664 |
param_to_state[id(
|