TaehyunKim
commited on
Update torch-ext/optimizer/muon.py
Browse files
torch-ext/optimizer/muon.py
CHANGED
|
@@ -610,7 +610,7 @@ class Muon(torch.optim.Optimizer):
|
|
| 610 |
elif p.placements == (Replicate(), Shard(dim=0)):
|
| 611 |
# Case for HSDP
|
| 612 |
process_group = p.device_mesh.get_group(mesh_dim=1)
|
| 613 |
-
if self.rank is None:
|
| 614 |
self.rank = dist.get_rank(group=process_group)
|
| 615 |
else:
|
| 616 |
assert self.rank == dist.get_rank(group=process_group)
|
|
|
|
| 610 |
elif p.placements == (Replicate(), Shard(dim=0)):
|
| 611 |
# Case for HSDP
|
| 612 |
process_group = p.device_mesh.get_group(mesh_dim=1)
|
| 613 |
+
if self.rank is None:
|
| 614 |
self.rank = dist.get_rank(group=process_group)
|
| 615 |
else:
|
| 616 |
assert self.rank == dist.get_rank(group=process_group)
|