Commit
·
bdd2678
1
Parent(s):
8535e80
fix(muon): delete intermediate tensors immediately to lower peak mem usage
Browse files
- build/torch26-cxx11-cu118-x86_64-linux/optimizer/_ops.py +3 -3
- build/torch26-cxx11-cu118-x86_64-linux/optimizer/_optimizer_8535e80_dirty.abi3.so +3 -0
- build/torch26-cxx11-cu118-x86_64-linux/optimizer/muon.py +10 -19
- build/torch26-cxx11-cu124-x86_64-linux/optimizer/_ops.py +3 -3
- build/torch26-cxx11-cu124-x86_64-linux/optimizer/_optimizer_8535e80_dirty.abi3.so +3 -0
- build/torch26-cxx11-cu124-x86_64-linux/optimizer/muon.py +10 -19
- build/torch26-cxx11-cu126-x86_64-linux/optimizer/_ops.py +3 -3
- build/torch26-cxx11-cu126-x86_64-linux/optimizer/_optimizer_8535e80_dirty.abi3.so +3 -0
- build/torch26-cxx11-cu126-x86_64-linux/optimizer/muon.py +10 -19
- build/torch26-cxx11-rocm62-x86_64-linux/optimizer/_ops.py +3 -3
- build/torch26-cxx11-rocm62-x86_64-linux/optimizer/_optimizer_8535e80_dirty.abi3.so +3 -0
- build/torch26-cxx11-rocm62-x86_64-linux/optimizer/muon.py +10 -19
- build/torch26-cxx98-cu118-x86_64-linux/optimizer/_ops.py +3 -3
- build/torch26-cxx98-cu118-x86_64-linux/optimizer/_optimizer_8535e80_dirty.abi3.so +3 -0
- build/torch26-cxx98-cu118-x86_64-linux/optimizer/muon.py +10 -19
- build/torch26-cxx98-cu124-x86_64-linux/optimizer/_ops.py +3 -3
- build/torch26-cxx98-cu124-x86_64-linux/optimizer/_optimizer_8535e80_dirty.abi3.so +3 -0
- build/torch26-cxx98-cu124-x86_64-linux/optimizer/muon.py +10 -19
- build/torch26-cxx98-cu126-x86_64-linux/optimizer/_ops.py +3 -3
- build/torch26-cxx98-cu126-x86_64-linux/optimizer/_optimizer_8535e80_dirty.abi3.so +3 -0
- build/torch26-cxx98-cu126-x86_64-linux/optimizer/muon.py +10 -19
- build/torch27-cxx11-cu118-x86_64-linux/optimizer/_ops.py +3 -3
- build/torch27-cxx11-cu118-x86_64-linux/optimizer/_optimizer_8535e80_dirty.abi3.so +3 -0
- build/torch27-cxx11-cu118-x86_64-linux/optimizer/muon.py +10 -19
- build/torch27-cxx11-cu126-x86_64-linux/optimizer/_ops.py +3 -3
- build/torch27-cxx11-cu126-x86_64-linux/optimizer/_optimizer_8535e80_dirty.abi3.so +3 -0
- build/torch27-cxx11-cu126-x86_64-linux/optimizer/muon.py +10 -19
- build/torch27-cxx11-cu128-x86_64-linux/optimizer/_ops.py +3 -3
- build/torch27-cxx11-cu128-x86_64-linux/optimizer/_optimizer_8535e80_dirty.abi3.so +3 -0
- build/torch27-cxx11-cu128-x86_64-linux/optimizer/muon.py +10 -19
- build/torch27-cxx11-rocm63-x86_64-linux/optimizer/__pycache__/__init__.cpython-310.pyc +0 -0
- build/torch27-cxx11-rocm63-x86_64-linux/optimizer/__pycache__/muon.cpython-310.pyc +0 -0
- build/torch27-cxx11-rocm63-x86_64-linux/optimizer/_ops.py +3 -3
- build/torch27-cxx11-rocm63-x86_64-linux/optimizer/_optimizer_8535e80_dirty.abi3.so +3 -0
- build/torch27-cxx11-rocm63-x86_64-linux/optimizer/muon.py +10 -19
- torch-ext/optimizer/muon.py +10 -19
build/torch26-cxx11-cu118-x86_64-linux/optimizer/_ops.py
CHANGED
|
@@ -1,9 +1,9 @@
|
|
| 1 |
import torch
|
| 2 |
-
from . import
|
| 3 |
-
ops = torch.ops.
|
| 4 |
|
| 5 |
def add_op_namespace_prefix(op_name: str):
|
| 6 |
"""
|
| 7 |
Prefix op by namespace.
|
| 8 |
"""
|
| 9 |
-
return f"
|
|
|
|
| 1 |
import torch
|
| 2 |
+
from . import _optimizer_8535e80_dirty
|
| 3 |
+
ops = torch.ops._optimizer_8535e80_dirty
|
| 4 |
|
| 5 |
def add_op_namespace_prefix(op_name: str):
|
| 6 |
"""
|
| 7 |
Prefix op by namespace.
|
| 8 |
"""
|
| 9 |
+
return f"_optimizer_8535e80_dirty::{op_name}"
|
build/torch26-cxx11-cu118-x86_64-linux/optimizer/_optimizer_8535e80_dirty.abi3.so
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:a46d9e65efcfa82522950d9ebf2b2b4594d9ed5abc28704352a1f7de2dae707a
|
| 3 |
+
size 1787272
|
build/torch26-cxx11-cu118-x86_64-linux/optimizer/muon.py
CHANGED
|
@@ -48,7 +48,6 @@ class _muon_state:
|
|
| 48 |
worker_rank: int | None = None
|
| 49 |
gathered_grad: torch.Tensor | None = None
|
| 50 |
computed_u: torch.Tensor | None = None
|
| 51 |
-
scattered_u: torch.Tensor | None = None
|
| 52 |
gather_event: torch.cuda.Event | None = None
|
| 53 |
compute_event: torch.cuda.Event | None = None
|
| 54 |
|
|
@@ -93,12 +92,14 @@ def _compute_u(state, steps, rank, compute_stream):
|
|
| 93 |
state.computed_u = u
|
| 94 |
state.compute_event = torch.cuda.Event()
|
| 95 |
state.compute_event.record()
|
|
|
|
|
|
|
| 96 |
else:
|
| 97 |
state.computed_u = None
|
| 98 |
state.compute_event = None
|
| 99 |
|
| 100 |
|
| 101 |
-
def _scatter(p, state, rank, comm_stream):
|
| 102 |
u = state.computed_u
|
| 103 |
mesh = p.device_mesh
|
| 104 |
|
|
@@ -118,13 +119,16 @@ def _scatter(p, state, rank, comm_stream):
|
|
| 118 |
src=state.worker_rank,
|
| 119 |
group=mesh.get_group(),
|
| 120 |
)
|
|
|
|
|
|
|
|
|
|
| 121 |
u = DTensor.from_local(
|
| 122 |
u,
|
| 123 |
placements=p.placements,
|
| 124 |
device_mesh=mesh,
|
| 125 |
)
|
| 126 |
-
|
| 127 |
-
|
| 128 |
|
| 129 |
|
| 130 |
class Muon(torch.optim.Optimizer):
|
|
@@ -353,7 +357,8 @@ class Muon(torch.optim.Optimizer):
|
|
| 353 |
def enqueue_scatters(start_idx, chunk_size):
|
| 354 |
for p in ordered_params[start_idx : start_idx + chunk_size]:
|
| 355 |
state = param_to_state[id(p)]
|
| 356 |
-
|
|
|
|
| 357 |
|
| 358 |
chunk_size = params[0].device_mesh.mesh.numel()
|
| 359 |
|
|
@@ -368,20 +373,6 @@ class Muon(torch.optim.Optimizer):
|
|
| 368 |
|
| 369 |
torch.cuda.current_stream().wait_stream(self.comm_stream)
|
| 370 |
|
| 371 |
-
for p in params:
|
| 372 |
-
g = p.grad
|
| 373 |
-
if g is None:
|
| 374 |
-
continue
|
| 375 |
-
|
| 376 |
-
# Update p with sharded u
|
| 377 |
-
state = param_to_state[id(p)]
|
| 378 |
-
self._update_p(
|
| 379 |
-
p,
|
| 380 |
-
state.scattered_u,
|
| 381 |
-
lr=lr,
|
| 382 |
-
wd=wd,
|
| 383 |
-
)
|
| 384 |
-
|
| 385 |
def step(self, closure=None):
|
| 386 |
"""Perform a single optimization step.
|
| 387 |
|
|
|
|
| 48 |
worker_rank: int | None = None
|
| 49 |
gathered_grad: torch.Tensor | None = None
|
| 50 |
computed_u: torch.Tensor | None = None
|
|
|
|
| 51 |
gather_event: torch.cuda.Event | None = None
|
| 52 |
compute_event: torch.cuda.Event | None = None
|
| 53 |
|
|
|
|
| 92 |
state.computed_u = u
|
| 93 |
state.compute_event = torch.cuda.Event()
|
| 94 |
state.compute_event.record()
|
| 95 |
+
state.gathered_grad.record_stream(compute_stream)
|
| 96 |
+
del state.gathered_grad
|
| 97 |
else:
|
| 98 |
state.computed_u = None
|
| 99 |
state.compute_event = None
|
| 100 |
|
| 101 |
|
| 102 |
+
def _scatter(p, state, lr, wd, rank, comm_stream):
|
| 103 |
u = state.computed_u
|
| 104 |
mesh = p.device_mesh
|
| 105 |
|
|
|
|
| 119 |
src=state.worker_rank,
|
| 120 |
group=mesh.get_group(),
|
| 121 |
)
|
| 122 |
+
if rank == state.worker_rank:
|
| 123 |
+
state.computed_u.record_stream(comm_stream)
|
| 124 |
+
del state.computed_u
|
| 125 |
u = DTensor.from_local(
|
| 126 |
u,
|
| 127 |
placements=p.placements,
|
| 128 |
device_mesh=mesh,
|
| 129 |
)
|
| 130 |
+
p.data.mul_(1 - lr * wd)
|
| 131 |
+
p.data.add_(u, alpha=-lr)
|
| 132 |
|
| 133 |
|
| 134 |
class Muon(torch.optim.Optimizer):
|
|
|
|
| 357 |
def enqueue_scatters(start_idx, chunk_size):
|
| 358 |
for p in ordered_params[start_idx : start_idx + chunk_size]:
|
| 359 |
state = param_to_state[id(p)]
|
| 360 |
+
adjusted_lr = self.adjust_lr_for_muon(lr, p.shape)
|
| 361 |
+
_scatter(p, state, adjusted_lr, wd, self.rank, self.comm_stream)
|
| 362 |
|
| 363 |
chunk_size = params[0].device_mesh.mesh.numel()
|
| 364 |
|
|
|
|
| 373 |
|
| 374 |
torch.cuda.current_stream().wait_stream(self.comm_stream)
|
| 375 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 376 |
def step(self, closure=None):
|
| 377 |
"""Perform a single optimization step.
|
| 378 |
|
build/torch26-cxx11-cu124-x86_64-linux/optimizer/_ops.py
CHANGED
|
@@ -1,9 +1,9 @@
|
|
| 1 |
import torch
|
| 2 |
-
from . import
|
| 3 |
-
ops = torch.ops.
|
| 4 |
|
| 5 |
def add_op_namespace_prefix(op_name: str):
|
| 6 |
"""
|
| 7 |
Prefix op by namespace.
|
| 8 |
"""
|
| 9 |
-
return f"
|
|
|
|
| 1 |
import torch
|
| 2 |
+
from . import _optimizer_8535e80_dirty
|
| 3 |
+
ops = torch.ops._optimizer_8535e80_dirty
|
| 4 |
|
| 5 |
def add_op_namespace_prefix(op_name: str):
|
| 6 |
"""
|
| 7 |
Prefix op by namespace.
|
| 8 |
"""
|
| 9 |
+
return f"_optimizer_8535e80_dirty::{op_name}"
|
build/torch26-cxx11-cu124-x86_64-linux/optimizer/_optimizer_8535e80_dirty.abi3.so
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:d351a600884b7378f546a345afe65c176e1399bb42fb7dfe4333b0e90975803b
|
| 3 |
+
size 1824224
|
build/torch26-cxx11-cu124-x86_64-linux/optimizer/muon.py
CHANGED
|
@@ -48,7 +48,6 @@ class _muon_state:
|
|
| 48 |
worker_rank: int | None = None
|
| 49 |
gathered_grad: torch.Tensor | None = None
|
| 50 |
computed_u: torch.Tensor | None = None
|
| 51 |
-
scattered_u: torch.Tensor | None = None
|
| 52 |
gather_event: torch.cuda.Event | None = None
|
| 53 |
compute_event: torch.cuda.Event | None = None
|
| 54 |
|
|
@@ -93,12 +92,14 @@ def _compute_u(state, steps, rank, compute_stream):
|
|
| 93 |
state.computed_u = u
|
| 94 |
state.compute_event = torch.cuda.Event()
|
| 95 |
state.compute_event.record()
|
|
|
|
|
|
|
| 96 |
else:
|
| 97 |
state.computed_u = None
|
| 98 |
state.compute_event = None
|
| 99 |
|
| 100 |
|
| 101 |
-
def _scatter(p, state, rank, comm_stream):
|
| 102 |
u = state.computed_u
|
| 103 |
mesh = p.device_mesh
|
| 104 |
|
|
@@ -118,13 +119,16 @@ def _scatter(p, state, rank, comm_stream):
|
|
| 118 |
src=state.worker_rank,
|
| 119 |
group=mesh.get_group(),
|
| 120 |
)
|
|
|
|
|
|
|
|
|
|
| 121 |
u = DTensor.from_local(
|
| 122 |
u,
|
| 123 |
placements=p.placements,
|
| 124 |
device_mesh=mesh,
|
| 125 |
)
|
| 126 |
-
|
| 127 |
-
|
| 128 |
|
| 129 |
|
| 130 |
class Muon(torch.optim.Optimizer):
|
|
@@ -353,7 +357,8 @@ class Muon(torch.optim.Optimizer):
|
|
| 353 |
def enqueue_scatters(start_idx, chunk_size):
|
| 354 |
for p in ordered_params[start_idx : start_idx + chunk_size]:
|
| 355 |
state = param_to_state[id(p)]
|
| 356 |
-
|
|
|
|
| 357 |
|
| 358 |
chunk_size = params[0].device_mesh.mesh.numel()
|
| 359 |
|
|
@@ -368,20 +373,6 @@ class Muon(torch.optim.Optimizer):
|
|
| 368 |
|
| 369 |
torch.cuda.current_stream().wait_stream(self.comm_stream)
|
| 370 |
|
| 371 |
-
for p in params:
|
| 372 |
-
g = p.grad
|
| 373 |
-
if g is None:
|
| 374 |
-
continue
|
| 375 |
-
|
| 376 |
-
# Update p with sharded u
|
| 377 |
-
state = param_to_state[id(p)]
|
| 378 |
-
self._update_p(
|
| 379 |
-
p,
|
| 380 |
-
state.scattered_u,
|
| 381 |
-
lr=lr,
|
| 382 |
-
wd=wd,
|
| 383 |
-
)
|
| 384 |
-
|
| 385 |
def step(self, closure=None):
|
| 386 |
"""Perform a single optimization step.
|
| 387 |
|
|
|
|
| 48 |
worker_rank: int | None = None
|
| 49 |
gathered_grad: torch.Tensor | None = None
|
| 50 |
computed_u: torch.Tensor | None = None
|
|
|
|
| 51 |
gather_event: torch.cuda.Event | None = None
|
| 52 |
compute_event: torch.cuda.Event | None = None
|
| 53 |
|
|
|
|
| 92 |
state.computed_u = u
|
| 93 |
state.compute_event = torch.cuda.Event()
|
| 94 |
state.compute_event.record()
|
| 95 |
+
state.gathered_grad.record_stream(compute_stream)
|
| 96 |
+
del state.gathered_grad
|
| 97 |
else:
|
| 98 |
state.computed_u = None
|
| 99 |
state.compute_event = None
|
| 100 |
|
| 101 |
|
| 102 |
+
def _scatter(p, state, lr, wd, rank, comm_stream):
|
| 103 |
u = state.computed_u
|
| 104 |
mesh = p.device_mesh
|
| 105 |
|
|
|
|
| 119 |
src=state.worker_rank,
|
| 120 |
group=mesh.get_group(),
|
| 121 |
)
|
| 122 |
+
if rank == state.worker_rank:
|
| 123 |
+
state.computed_u.record_stream(comm_stream)
|
| 124 |
+
del state.computed_u
|
| 125 |
u = DTensor.from_local(
|
| 126 |
u,
|
| 127 |
placements=p.placements,
|
| 128 |
device_mesh=mesh,
|
| 129 |
)
|
| 130 |
+
p.data.mul_(1 - lr * wd)
|
| 131 |
+
p.data.add_(u, alpha=-lr)
|
| 132 |
|
| 133 |
|
| 134 |
class Muon(torch.optim.Optimizer):
|
|
|
|
| 357 |
def enqueue_scatters(start_idx, chunk_size):
|
| 358 |
for p in ordered_params[start_idx : start_idx + chunk_size]:
|
| 359 |
state = param_to_state[id(p)]
|
| 360 |
+
adjusted_lr = self.adjust_lr_for_muon(lr, p.shape)
|
| 361 |
+
_scatter(p, state, adjusted_lr, wd, self.rank, self.comm_stream)
|
| 362 |
|
| 363 |
chunk_size = params[0].device_mesh.mesh.numel()
|
| 364 |
|
|
|
|
| 373 |
|
| 374 |
torch.cuda.current_stream().wait_stream(self.comm_stream)
|
| 375 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 376 |
def step(self, closure=None):
|
| 377 |
"""Perform a single optimization step.
|
| 378 |
|
build/torch26-cxx11-cu126-x86_64-linux/optimizer/_ops.py
CHANGED
|
@@ -1,9 +1,9 @@
|
|
| 1 |
import torch
|
| 2 |
-
from . import
|
| 3 |
-
ops = torch.ops.
|
| 4 |
|
| 5 |
def add_op_namespace_prefix(op_name: str):
|
| 6 |
"""
|
| 7 |
Prefix op by namespace.
|
| 8 |
"""
|
| 9 |
-
return f"
|
|
|
|
| 1 |
import torch
|
| 2 |
+
from . import _optimizer_8535e80_dirty
|
| 3 |
+
ops = torch.ops._optimizer_8535e80_dirty
|
| 4 |
|
| 5 |
def add_op_namespace_prefix(op_name: str):
|
| 6 |
"""
|
| 7 |
Prefix op by namespace.
|
| 8 |
"""
|
| 9 |
+
return f"_optimizer_8535e80_dirty::{op_name}"
|
build/torch26-cxx11-cu126-x86_64-linux/optimizer/_optimizer_8535e80_dirty.abi3.so
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:2c0843f38cee494b7a5939eb62d27039d76dc3f69401d411efbacaa25cb0d67a
|
| 3 |
+
size 1824224
|
build/torch26-cxx11-cu126-x86_64-linux/optimizer/muon.py
CHANGED
|
@@ -48,7 +48,6 @@ class _muon_state:
|
|
| 48 |
worker_rank: int | None = None
|
| 49 |
gathered_grad: torch.Tensor | None = None
|
| 50 |
computed_u: torch.Tensor | None = None
|
| 51 |
-
scattered_u: torch.Tensor | None = None
|
| 52 |
gather_event: torch.cuda.Event | None = None
|
| 53 |
compute_event: torch.cuda.Event | None = None
|
| 54 |
|
|
@@ -93,12 +92,14 @@ def _compute_u(state, steps, rank, compute_stream):
|
|
| 93 |
state.computed_u = u
|
| 94 |
state.compute_event = torch.cuda.Event()
|
| 95 |
state.compute_event.record()
|
|
|
|
|
|
|
| 96 |
else:
|
| 97 |
state.computed_u = None
|
| 98 |
state.compute_event = None
|
| 99 |
|
| 100 |
|
| 101 |
-
def _scatter(p, state, rank, comm_stream):
|
| 102 |
u = state.computed_u
|
| 103 |
mesh = p.device_mesh
|
| 104 |
|
|
@@ -118,13 +119,16 @@ def _scatter(p, state, rank, comm_stream):
|
|
| 118 |
src=state.worker_rank,
|
| 119 |
group=mesh.get_group(),
|
| 120 |
)
|
|
|
|
|
|
|
|
|
|
| 121 |
u = DTensor.from_local(
|
| 122 |
u,
|
| 123 |
placements=p.placements,
|
| 124 |
device_mesh=mesh,
|
| 125 |
)
|
| 126 |
-
|
| 127 |
-
|
| 128 |
|
| 129 |
|
| 130 |
class Muon(torch.optim.Optimizer):
|
|
@@ -353,7 +357,8 @@ class Muon(torch.optim.Optimizer):
|
|
| 353 |
def enqueue_scatters(start_idx, chunk_size):
|
| 354 |
for p in ordered_params[start_idx : start_idx + chunk_size]:
|
| 355 |
state = param_to_state[id(p)]
|
| 356 |
-
|
|
|
|
| 357 |
|
| 358 |
chunk_size = params[0].device_mesh.mesh.numel()
|
| 359 |
|
|
@@ -368,20 +373,6 @@ class Muon(torch.optim.Optimizer):
|
|
| 368 |
|
| 369 |
torch.cuda.current_stream().wait_stream(self.comm_stream)
|
| 370 |
|
| 371 |
-
for p in params:
|
| 372 |
-
g = p.grad
|
| 373 |
-
if g is None:
|
| 374 |
-
continue
|
| 375 |
-
|
| 376 |
-
# Update p with sharded u
|
| 377 |
-
state = param_to_state[id(p)]
|
| 378 |
-
self._update_p(
|
| 379 |
-
p,
|
| 380 |
-
state.scattered_u,
|
| 381 |
-
lr=lr,
|
| 382 |
-
wd=wd,
|
| 383 |
-
)
|
| 384 |
-
|
| 385 |
def step(self, closure=None):
|
| 386 |
"""Perform a single optimization step.
|
| 387 |
|
|
|
|
| 48 |
worker_rank: int | None = None
|
| 49 |
gathered_grad: torch.Tensor | None = None
|
| 50 |
computed_u: torch.Tensor | None = None
|
|
|
|
| 51 |
gather_event: torch.cuda.Event | None = None
|
| 52 |
compute_event: torch.cuda.Event | None = None
|
| 53 |
|
|
|
|
| 92 |
state.computed_u = u
|
| 93 |
state.compute_event = torch.cuda.Event()
|
| 94 |
state.compute_event.record()
|
| 95 |
+
state.gathered_grad.record_stream(compute_stream)
|
| 96 |
+
del state.gathered_grad
|
| 97 |
else:
|
| 98 |
state.computed_u = None
|
| 99 |
state.compute_event = None
|
| 100 |
|
| 101 |
|
| 102 |
+
def _scatter(p, state, lr, wd, rank, comm_stream):
|
| 103 |
u = state.computed_u
|
| 104 |
mesh = p.device_mesh
|
| 105 |
|
|
|
|
| 119 |
src=state.worker_rank,
|
| 120 |
group=mesh.get_group(),
|
| 121 |
)
|
| 122 |
+
if rank == state.worker_rank:
|
| 123 |
+
state.computed_u.record_stream(comm_stream)
|
| 124 |
+
del state.computed_u
|
| 125 |
u = DTensor.from_local(
|
| 126 |
u,
|
| 127 |
placements=p.placements,
|
| 128 |
device_mesh=mesh,
|
| 129 |
)
|
| 130 |
+
p.data.mul_(1 - lr * wd)
|
| 131 |
+
p.data.add_(u, alpha=-lr)
|
| 132 |
|
| 133 |
|
| 134 |
class Muon(torch.optim.Optimizer):
|
|
|
|
| 357 |
def enqueue_scatters(start_idx, chunk_size):
|
| 358 |
for p in ordered_params[start_idx : start_idx + chunk_size]:
|
| 359 |
state = param_to_state[id(p)]
|
| 360 |
+
adjusted_lr = self.adjust_lr_for_muon(lr, p.shape)
|
| 361 |
+
_scatter(p, state, adjusted_lr, wd, self.rank, self.comm_stream)
|
| 362 |
|
| 363 |
chunk_size = params[0].device_mesh.mesh.numel()
|
| 364 |
|
|
|
|
| 373 |
|
| 374 |
torch.cuda.current_stream().wait_stream(self.comm_stream)
|
| 375 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 376 |
def step(self, closure=None):
|
| 377 |
"""Perform a single optimization step.
|
| 378 |
|
build/torch26-cxx11-rocm62-x86_64-linux/optimizer/_ops.py
CHANGED
|
@@ -1,9 +1,9 @@
|
|
| 1 |
import torch
|
| 2 |
-
from . import
|
| 3 |
-
ops = torch.ops.
|
| 4 |
|
| 5 |
def add_op_namespace_prefix(op_name: str):
|
| 6 |
"""
|
| 7 |
Prefix op by namespace.
|
| 8 |
"""
|
| 9 |
-
return f"
|
|
|
|
| 1 |
import torch
|
| 2 |
+
from . import _optimizer_8535e80_dirty
|
| 3 |
+
ops = torch.ops._optimizer_8535e80_dirty
|
| 4 |
|
| 5 |
def add_op_namespace_prefix(op_name: str):
|
| 6 |
"""
|
| 7 |
Prefix op by namespace.
|
| 8 |
"""
|
| 9 |
+
return f"_optimizer_8535e80_dirty::{op_name}"
|
build/torch26-cxx11-rocm62-x86_64-linux/optimizer/_optimizer_8535e80_dirty.abi3.so
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:acdba99ce95532a9ca6a8987a7ab61a257657872f2cc672c91e8e5fe809aa24e
|
| 3 |
+
size 1749744
|
build/torch26-cxx11-rocm62-x86_64-linux/optimizer/muon.py
CHANGED
|
@@ -48,7 +48,6 @@ class _muon_state:
|
|
| 48 |
worker_rank: int | None = None
|
| 49 |
gathered_grad: torch.Tensor | None = None
|
| 50 |
computed_u: torch.Tensor | None = None
|
| 51 |
-
scattered_u: torch.Tensor | None = None
|
| 52 |
gather_event: torch.cuda.Event | None = None
|
| 53 |
compute_event: torch.cuda.Event | None = None
|
| 54 |
|
|
@@ -93,12 +92,14 @@ def _compute_u(state, steps, rank, compute_stream):
|
|
| 93 |
state.computed_u = u
|
| 94 |
state.compute_event = torch.cuda.Event()
|
| 95 |
state.compute_event.record()
|
|
|
|
|
|
|
| 96 |
else:
|
| 97 |
state.computed_u = None
|
| 98 |
state.compute_event = None
|
| 99 |
|
| 100 |
|
| 101 |
-
def _scatter(p, state, rank, comm_stream):
|
| 102 |
u = state.computed_u
|
| 103 |
mesh = p.device_mesh
|
| 104 |
|
|
@@ -118,13 +119,16 @@ def _scatter(p, state, rank, comm_stream):
|
|
| 118 |
src=state.worker_rank,
|
| 119 |
group=mesh.get_group(),
|
| 120 |
)
|
|
|
|
|
|
|
|
|
|
| 121 |
u = DTensor.from_local(
|
| 122 |
u,
|
| 123 |
placements=p.placements,
|
| 124 |
device_mesh=mesh,
|
| 125 |
)
|
| 126 |
-
|
| 127 |
-
|
| 128 |
|
| 129 |
|
| 130 |
class Muon(torch.optim.Optimizer):
|
|
@@ -353,7 +357,8 @@ class Muon(torch.optim.Optimizer):
|
|
| 353 |
def enqueue_scatters(start_idx, chunk_size):
|
| 354 |
for p in ordered_params[start_idx : start_idx + chunk_size]:
|
| 355 |
state = param_to_state[id(p)]
|
| 356 |
-
|
|
|
|
| 357 |
|
| 358 |
chunk_size = params[0].device_mesh.mesh.numel()
|
| 359 |
|
|
@@ -368,20 +373,6 @@ class Muon(torch.optim.Optimizer):
|
|
| 368 |
|
| 369 |
torch.cuda.current_stream().wait_stream(self.comm_stream)
|
| 370 |
|
| 371 |
-
for p in params:
|
| 372 |
-
g = p.grad
|
| 373 |
-
if g is None:
|
| 374 |
-
continue
|
| 375 |
-
|
| 376 |
-
# Update p with sharded u
|
| 377 |
-
state = param_to_state[id(p)]
|
| 378 |
-
self._update_p(
|
| 379 |
-
p,
|
| 380 |
-
state.scattered_u,
|
| 381 |
-
lr=lr,
|
| 382 |
-
wd=wd,
|
| 383 |
-
)
|
| 384 |
-
|
| 385 |
def step(self, closure=None):
|
| 386 |
"""Perform a single optimization step.
|
| 387 |
|
|
|
|
| 48 |
worker_rank: int | None = None
|
| 49 |
gathered_grad: torch.Tensor | None = None
|
| 50 |
computed_u: torch.Tensor | None = None
|
|
|
|
| 51 |
gather_event: torch.cuda.Event | None = None
|
| 52 |
compute_event: torch.cuda.Event | None = None
|
| 53 |
|
|
|
|
| 92 |
state.computed_u = u
|
| 93 |
state.compute_event = torch.cuda.Event()
|
| 94 |
state.compute_event.record()
|
| 95 |
+
state.gathered_grad.record_stream(compute_stream)
|
| 96 |
+
del state.gathered_grad
|
| 97 |
else:
|
| 98 |
state.computed_u = None
|
| 99 |
state.compute_event = None
|
| 100 |
|
| 101 |
|
| 102 |
+
def _scatter(p, state, lr, wd, rank, comm_stream):
|
| 103 |
u = state.computed_u
|
| 104 |
mesh = p.device_mesh
|
| 105 |
|
|
|
|
| 119 |
src=state.worker_rank,
|
| 120 |
group=mesh.get_group(),
|
| 121 |
)
|
| 122 |
+
if rank == state.worker_rank:
|
| 123 |
+
state.computed_u.record_stream(comm_stream)
|
| 124 |
+
del state.computed_u
|
| 125 |
u = DTensor.from_local(
|
| 126 |
u,
|
| 127 |
placements=p.placements,
|
| 128 |
device_mesh=mesh,
|
| 129 |
)
|
| 130 |
+
p.data.mul_(1 - lr * wd)
|
| 131 |
+
p.data.add_(u, alpha=-lr)
|
| 132 |
|
| 133 |
|
| 134 |
class Muon(torch.optim.Optimizer):
|
|
|
|
| 357 |
def enqueue_scatters(start_idx, chunk_size):
|
| 358 |
for p in ordered_params[start_idx : start_idx + chunk_size]:
|
| 359 |
state = param_to_state[id(p)]
|
| 360 |
+
adjusted_lr = self.adjust_lr_for_muon(lr, p.shape)
|
| 361 |
+
_scatter(p, state, adjusted_lr, wd, self.rank, self.comm_stream)
|
| 362 |
|
| 363 |
chunk_size = params[0].device_mesh.mesh.numel()
|
| 364 |
|
|
|
|
| 373 |
|
| 374 |
torch.cuda.current_stream().wait_stream(self.comm_stream)
|
| 375 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 376 |
def step(self, closure=None):
|
| 377 |
"""Perform a single optimization step.
|
| 378 |
|
build/torch26-cxx98-cu118-x86_64-linux/optimizer/_ops.py
CHANGED
|
@@ -1,9 +1,9 @@
|
|
| 1 |
import torch
|
| 2 |
-
from . import
|
| 3 |
-
ops = torch.ops.
|
| 4 |
|
| 5 |
def add_op_namespace_prefix(op_name: str):
|
| 6 |
"""
|
| 7 |
Prefix op by namespace.
|
| 8 |
"""
|
| 9 |
-
return f"
|
|
|
|
| 1 |
import torch
|
| 2 |
+
from . import _optimizer_8535e80_dirty
|
| 3 |
+
ops = torch.ops._optimizer_8535e80_dirty
|
| 4 |
|
| 5 |
def add_op_namespace_prefix(op_name: str):
|
| 6 |
"""
|
| 7 |
Prefix op by namespace.
|
| 8 |
"""
|
| 9 |
+
return f"_optimizer_8535e80_dirty::{op_name}"
|
build/torch26-cxx98-cu118-x86_64-linux/optimizer/_optimizer_8535e80_dirty.abi3.so
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:f7d5e76c002507f66f2a227d02c2b11aa3fdc3f07a2a0b82faaa34133adb77ef
|
| 3 |
+
size 1787192
|
build/torch26-cxx98-cu118-x86_64-linux/optimizer/muon.py
CHANGED
|
@@ -48,7 +48,6 @@ class _muon_state:
|
|
| 48 |
worker_rank: int | None = None
|
| 49 |
gathered_grad: torch.Tensor | None = None
|
| 50 |
computed_u: torch.Tensor | None = None
|
| 51 |
-
scattered_u: torch.Tensor | None = None
|
| 52 |
gather_event: torch.cuda.Event | None = None
|
| 53 |
compute_event: torch.cuda.Event | None = None
|
| 54 |
|
|
@@ -93,12 +92,14 @@ def _compute_u(state, steps, rank, compute_stream):
|
|
| 93 |
state.computed_u = u
|
| 94 |
state.compute_event = torch.cuda.Event()
|
| 95 |
state.compute_event.record()
|
|
|
|
|
|
|
| 96 |
else:
|
| 97 |
state.computed_u = None
|
| 98 |
state.compute_event = None
|
| 99 |
|
| 100 |
|
| 101 |
-
def _scatter(p, state, rank, comm_stream):
|
| 102 |
u = state.computed_u
|
| 103 |
mesh = p.device_mesh
|
| 104 |
|
|
@@ -118,13 +119,16 @@ def _scatter(p, state, rank, comm_stream):
|
|
| 118 |
src=state.worker_rank,
|
| 119 |
group=mesh.get_group(),
|
| 120 |
)
|
|
|
|
|
|
|
|
|
|
| 121 |
u = DTensor.from_local(
|
| 122 |
u,
|
| 123 |
placements=p.placements,
|
| 124 |
device_mesh=mesh,
|
| 125 |
)
|
| 126 |
-
|
| 127 |
-
|
| 128 |
|
| 129 |
|
| 130 |
class Muon(torch.optim.Optimizer):
|
|
@@ -353,7 +357,8 @@ class Muon(torch.optim.Optimizer):
|
|
| 353 |
def enqueue_scatters(start_idx, chunk_size):
|
| 354 |
for p in ordered_params[start_idx : start_idx + chunk_size]:
|
| 355 |
state = param_to_state[id(p)]
|
| 356 |
-
|
|
|
|
| 357 |
|
| 358 |
chunk_size = params[0].device_mesh.mesh.numel()
|
| 359 |
|
|
@@ -368,20 +373,6 @@ class Muon(torch.optim.Optimizer):
|
|
| 368 |
|
| 369 |
torch.cuda.current_stream().wait_stream(self.comm_stream)
|
| 370 |
|
| 371 |
-
for p in params:
|
| 372 |
-
g = p.grad
|
| 373 |
-
if g is None:
|
| 374 |
-
continue
|
| 375 |
-
|
| 376 |
-
# Update p with sharded u
|
| 377 |
-
state = param_to_state[id(p)]
|
| 378 |
-
self._update_p(
|
| 379 |
-
p,
|
| 380 |
-
state.scattered_u,
|
| 381 |
-
lr=lr,
|
| 382 |
-
wd=wd,
|
| 383 |
-
)
|
| 384 |
-
|
| 385 |
def step(self, closure=None):
|
| 386 |
"""Perform a single optimization step.
|
| 387 |
|
|
|
|
| 48 |
worker_rank: int | None = None
|
| 49 |
gathered_grad: torch.Tensor | None = None
|
| 50 |
computed_u: torch.Tensor | None = None
|
|
|
|
| 51 |
gather_event: torch.cuda.Event | None = None
|
| 52 |
compute_event: torch.cuda.Event | None = None
|
| 53 |
|
|
|
|
| 92 |
state.computed_u = u
|
| 93 |
state.compute_event = torch.cuda.Event()
|
| 94 |
state.compute_event.record()
|
| 95 |
+
state.gathered_grad.record_stream(compute_stream)
|
| 96 |
+
del state.gathered_grad
|
| 97 |
else:
|
| 98 |
state.computed_u = None
|
| 99 |
state.compute_event = None
|
| 100 |
|
| 101 |
|
| 102 |
+
def _scatter(p, state, lr, wd, rank, comm_stream):
|
| 103 |
u = state.computed_u
|
| 104 |
mesh = p.device_mesh
|
| 105 |
|
|
|
|
| 119 |
src=state.worker_rank,
|
| 120 |
group=mesh.get_group(),
|
| 121 |
)
|
| 122 |
+
if rank == state.worker_rank:
|
| 123 |
+
state.computed_u.record_stream(comm_stream)
|
| 124 |
+
del state.computed_u
|
| 125 |
u = DTensor.from_local(
|
| 126 |
u,
|
| 127 |
placements=p.placements,
|
| 128 |
device_mesh=mesh,
|
| 129 |
)
|
| 130 |
+
p.data.mul_(1 - lr * wd)
|
| 131 |
+
p.data.add_(u, alpha=-lr)
|
| 132 |
|
| 133 |
|
| 134 |
class Muon(torch.optim.Optimizer):
|
|
|
|
| 357 |
def enqueue_scatters(start_idx, chunk_size):
|
| 358 |
for p in ordered_params[start_idx : start_idx + chunk_size]:
|
| 359 |
state = param_to_state[id(p)]
|
| 360 |
+
adjusted_lr = self.adjust_lr_for_muon(lr, p.shape)
|
| 361 |
+
_scatter(p, state, adjusted_lr, wd, self.rank, self.comm_stream)
|
| 362 |
|
| 363 |
chunk_size = params[0].device_mesh.mesh.numel()
|
| 364 |
|
|
|
|
| 373 |
|
| 374 |
torch.cuda.current_stream().wait_stream(self.comm_stream)
|
| 375 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 376 |
def step(self, closure=None):
|
| 377 |
"""Perform a single optimization step.
|
| 378 |
|
build/torch26-cxx98-cu124-x86_64-linux/optimizer/_ops.py
CHANGED
|
@@ -1,9 +1,9 @@
|
|
| 1 |
import torch
|
| 2 |
-
from . import
|
| 3 |
-
ops = torch.ops.
|
| 4 |
|
| 5 |
def add_op_namespace_prefix(op_name: str):
|
| 6 |
"""
|
| 7 |
Prefix op by namespace.
|
| 8 |
"""
|
| 9 |
-
return f"
|
|
|
|
| 1 |
import torch
|
| 2 |
+
from . import _optimizer_8535e80_dirty
|
| 3 |
+
ops = torch.ops._optimizer_8535e80_dirty
|
| 4 |
|
| 5 |
def add_op_namespace_prefix(op_name: str):
|
| 6 |
"""
|
| 7 |
Prefix op by namespace.
|
| 8 |
"""
|
| 9 |
+
return f"_optimizer_8535e80_dirty::{op_name}"
|
build/torch26-cxx98-cu124-x86_64-linux/optimizer/_optimizer_8535e80_dirty.abi3.so
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:becccd250f38a84803350cfb5fac3a6682b1e594968a714642724cbc71246b4a
|
| 3 |
+
size 1824184
|
build/torch26-cxx98-cu124-x86_64-linux/optimizer/muon.py
CHANGED
|
@@ -48,7 +48,6 @@ class _muon_state:
|
|
| 48 |
worker_rank: int | None = None
|
| 49 |
gathered_grad: torch.Tensor | None = None
|
| 50 |
computed_u: torch.Tensor | None = None
|
| 51 |
-
scattered_u: torch.Tensor | None = None
|
| 52 |
gather_event: torch.cuda.Event | None = None
|
| 53 |
compute_event: torch.cuda.Event | None = None
|
| 54 |
|
|
@@ -93,12 +92,14 @@ def _compute_u(state, steps, rank, compute_stream):
|
|
| 93 |
state.computed_u = u
|
| 94 |
state.compute_event = torch.cuda.Event()
|
| 95 |
state.compute_event.record()
|
|
|
|
|
|
|
| 96 |
else:
|
| 97 |
state.computed_u = None
|
| 98 |
state.compute_event = None
|
| 99 |
|
| 100 |
|
| 101 |
-
def _scatter(p, state, rank, comm_stream):
|
| 102 |
u = state.computed_u
|
| 103 |
mesh = p.device_mesh
|
| 104 |
|
|
@@ -118,13 +119,16 @@ def _scatter(p, state, rank, comm_stream):
|
|
| 118 |
src=state.worker_rank,
|
| 119 |
group=mesh.get_group(),
|
| 120 |
)
|
|
|
|
|
|
|
|
|
|
| 121 |
u = DTensor.from_local(
|
| 122 |
u,
|
| 123 |
placements=p.placements,
|
| 124 |
device_mesh=mesh,
|
| 125 |
)
|
| 126 |
-
|
| 127 |
-
|
| 128 |
|
| 129 |
|
| 130 |
class Muon(torch.optim.Optimizer):
|
|
@@ -353,7 +357,8 @@ class Muon(torch.optim.Optimizer):
|
|
| 353 |
def enqueue_scatters(start_idx, chunk_size):
|
| 354 |
for p in ordered_params[start_idx : start_idx + chunk_size]:
|
| 355 |
state = param_to_state[id(p)]
|
| 356 |
-
|
|
|
|
| 357 |
|
| 358 |
chunk_size = params[0].device_mesh.mesh.numel()
|
| 359 |
|
|
@@ -368,20 +373,6 @@ class Muon(torch.optim.Optimizer):
|
|
| 368 |
|
| 369 |
torch.cuda.current_stream().wait_stream(self.comm_stream)
|
| 370 |
|
| 371 |
-
for p in params:
|
| 372 |
-
g = p.grad
|
| 373 |
-
if g is None:
|
| 374 |
-
continue
|
| 375 |
-
|
| 376 |
-
# Update p with sharded u
|
| 377 |
-
state = param_to_state[id(p)]
|
| 378 |
-
self._update_p(
|
| 379 |
-
p,
|
| 380 |
-
state.scattered_u,
|
| 381 |
-
lr=lr,
|
| 382 |
-
wd=wd,
|
| 383 |
-
)
|
| 384 |
-
|
| 385 |
def step(self, closure=None):
|
| 386 |
"""Perform a single optimization step.
|
| 387 |
|
|
|
|
| 48 |
worker_rank: int | None = None
|
| 49 |
gathered_grad: torch.Tensor | None = None
|
| 50 |
computed_u: torch.Tensor | None = None
|
|
|
|
| 51 |
gather_event: torch.cuda.Event | None = None
|
| 52 |
compute_event: torch.cuda.Event | None = None
|
| 53 |
|
|
|
|
| 92 |
state.computed_u = u
|
| 93 |
state.compute_event = torch.cuda.Event()
|
| 94 |
state.compute_event.record()
|
| 95 |
+
state.gathered_grad.record_stream(compute_stream)
|
| 96 |
+
del state.gathered_grad
|
| 97 |
else:
|
| 98 |
state.computed_u = None
|
| 99 |
state.compute_event = None
|
| 100 |
|
| 101 |
|
| 102 |
+
def _scatter(p, state, lr, wd, rank, comm_stream):
|
| 103 |
u = state.computed_u
|
| 104 |
mesh = p.device_mesh
|
| 105 |
|
|
|
|
| 119 |
src=state.worker_rank,
|
| 120 |
group=mesh.get_group(),
|
| 121 |
)
|
| 122 |
+
if rank == state.worker_rank:
|
| 123 |
+
state.computed_u.record_stream(comm_stream)
|
| 124 |
+
del state.computed_u
|
| 125 |
u = DTensor.from_local(
|
| 126 |
u,
|
| 127 |
placements=p.placements,
|
| 128 |
device_mesh=mesh,
|
| 129 |
)
|
| 130 |
+
p.data.mul_(1 - lr * wd)
|
| 131 |
+
p.data.add_(u, alpha=-lr)
|
| 132 |
|
| 133 |
|
| 134 |
class Muon(torch.optim.Optimizer):
|
|
|
|
| 357 |
def enqueue_scatters(start_idx, chunk_size):
|
| 358 |
for p in ordered_params[start_idx : start_idx + chunk_size]:
|
| 359 |
state = param_to_state[id(p)]
|
| 360 |
+
adjusted_lr = self.adjust_lr_for_muon(lr, p.shape)
|
| 361 |
+
_scatter(p, state, adjusted_lr, wd, self.rank, self.comm_stream)
|
| 362 |
|
| 363 |
chunk_size = params[0].device_mesh.mesh.numel()
|
| 364 |
|
|
|
|
| 373 |
|
| 374 |
torch.cuda.current_stream().wait_stream(self.comm_stream)
|
| 375 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 376 |
def step(self, closure=None):
|
| 377 |
"""Perform a single optimization step.
|
| 378 |
|
build/torch26-cxx98-cu126-x86_64-linux/optimizer/_ops.py
CHANGED
|
@@ -1,9 +1,9 @@
|
|
| 1 |
import torch
|
| 2 |
-
from . import
|
| 3 |
-
ops = torch.ops.
|
| 4 |
|
| 5 |
def add_op_namespace_prefix(op_name: str):
|
| 6 |
"""
|
| 7 |
Prefix op by namespace.
|
| 8 |
"""
|
| 9 |
-
return f"
|
|
|
|
| 1 |
import torch
|
| 2 |
+
from . import _optimizer_8535e80_dirty
|
| 3 |
+
ops = torch.ops._optimizer_8535e80_dirty
|
| 4 |
|
| 5 |
def add_op_namespace_prefix(op_name: str):
|
| 6 |
"""
|
| 7 |
Prefix op by namespace.
|
| 8 |
"""
|
| 9 |
+
return f"_optimizer_8535e80_dirty::{op_name}"
|
build/torch26-cxx98-cu126-x86_64-linux/optimizer/_optimizer_8535e80_dirty.abi3.so
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:34215ecc274ef516967962c8457dad214e9bbf618bf5eee8f467371f4f620284
|
| 3 |
+
size 1824184
|
build/torch26-cxx98-cu126-x86_64-linux/optimizer/muon.py
CHANGED
|
@@ -48,7 +48,6 @@ class _muon_state:
|
|
| 48 |
worker_rank: int | None = None
|
| 49 |
gathered_grad: torch.Tensor | None = None
|
| 50 |
computed_u: torch.Tensor | None = None
|
| 51 |
-
scattered_u: torch.Tensor | None = None
|
| 52 |
gather_event: torch.cuda.Event | None = None
|
| 53 |
compute_event: torch.cuda.Event | None = None
|
| 54 |
|
|
@@ -93,12 +92,14 @@ def _compute_u(state, steps, rank, compute_stream):
|
|
| 93 |
state.computed_u = u
|
| 94 |
state.compute_event = torch.cuda.Event()
|
| 95 |
state.compute_event.record()
|
|
|
|
|
|
|
| 96 |
else:
|
| 97 |
state.computed_u = None
|
| 98 |
state.compute_event = None
|
| 99 |
|
| 100 |
|
| 101 |
-
def _scatter(p, state, rank, comm_stream):
|
| 102 |
u = state.computed_u
|
| 103 |
mesh = p.device_mesh
|
| 104 |
|
|
@@ -118,13 +119,16 @@ def _scatter(p, state, rank, comm_stream):
|
|
| 118 |
src=state.worker_rank,
|
| 119 |
group=mesh.get_group(),
|
| 120 |
)
|
|
|
|
|
|
|
|
|
|
| 121 |
u = DTensor.from_local(
|
| 122 |
u,
|
| 123 |
placements=p.placements,
|
| 124 |
device_mesh=mesh,
|
| 125 |
)
|
| 126 |
-
|
| 127 |
-
|
| 128 |
|
| 129 |
|
| 130 |
class Muon(torch.optim.Optimizer):
|
|
@@ -353,7 +357,8 @@ class Muon(torch.optim.Optimizer):
|
|
| 353 |
def enqueue_scatters(start_idx, chunk_size):
|
| 354 |
for p in ordered_params[start_idx : start_idx + chunk_size]:
|
| 355 |
state = param_to_state[id(p)]
|
| 356 |
-
|
|
|
|
| 357 |
|
| 358 |
chunk_size = params[0].device_mesh.mesh.numel()
|
| 359 |
|
|
@@ -368,20 +373,6 @@ class Muon(torch.optim.Optimizer):
|
|
| 368 |
|
| 369 |
torch.cuda.current_stream().wait_stream(self.comm_stream)
|
| 370 |
|
| 371 |
-
for p in params:
|
| 372 |
-
g = p.grad
|
| 373 |
-
if g is None:
|
| 374 |
-
continue
|
| 375 |
-
|
| 376 |
-
# Update p with sharded u
|
| 377 |
-
state = param_to_state[id(p)]
|
| 378 |
-
self._update_p(
|
| 379 |
-
p,
|
| 380 |
-
state.scattered_u,
|
| 381 |
-
lr=lr,
|
| 382 |
-
wd=wd,
|
| 383 |
-
)
|
| 384 |
-
|
| 385 |
def step(self, closure=None):
|
| 386 |
"""Perform a single optimization step.
|
| 387 |
|
|
|
|
| 48 |
worker_rank: int | None = None
|
| 49 |
gathered_grad: torch.Tensor | None = None
|
| 50 |
computed_u: torch.Tensor | None = None
|
|
|
|
| 51 |
gather_event: torch.cuda.Event | None = None
|
| 52 |
compute_event: torch.cuda.Event | None = None
|
| 53 |
|
|
|
|
| 92 |
state.computed_u = u
|
| 93 |
state.compute_event = torch.cuda.Event()
|
| 94 |
state.compute_event.record()
|
| 95 |
+
state.gathered_grad.record_stream(compute_stream)
|
| 96 |
+
del state.gathered_grad
|
| 97 |
else:
|
| 98 |
state.computed_u = None
|
| 99 |
state.compute_event = None
|
| 100 |
|
| 101 |
|
| 102 |
+
def _scatter(p, state, lr, wd, rank, comm_stream):
|
| 103 |
u = state.computed_u
|
| 104 |
mesh = p.device_mesh
|
| 105 |
|
|
|
|
| 119 |
src=state.worker_rank,
|
| 120 |
group=mesh.get_group(),
|
| 121 |
)
|
| 122 |
+
if rank == state.worker_rank:
|
| 123 |
+
state.computed_u.record_stream(comm_stream)
|
| 124 |
+
del state.computed_u
|
| 125 |
u = DTensor.from_local(
|
| 126 |
u,
|
| 127 |
placements=p.placements,
|
| 128 |
device_mesh=mesh,
|
| 129 |
)
|
| 130 |
+
p.data.mul_(1 - lr * wd)
|
| 131 |
+
p.data.add_(u, alpha=-lr)
|
| 132 |
|
| 133 |
|
| 134 |
class Muon(torch.optim.Optimizer):
|
|
|
|
| 357 |
def enqueue_scatters(start_idx, chunk_size):
|
| 358 |
for p in ordered_params[start_idx : start_idx + chunk_size]:
|
| 359 |
state = param_to_state[id(p)]
|
| 360 |
+
adjusted_lr = self.adjust_lr_for_muon(lr, p.shape)
|
| 361 |
+
_scatter(p, state, adjusted_lr, wd, self.rank, self.comm_stream)
|
| 362 |
|
| 363 |
chunk_size = params[0].device_mesh.mesh.numel()
|
| 364 |
|
|
|
|
| 373 |
|
| 374 |
torch.cuda.current_stream().wait_stream(self.comm_stream)
|
| 375 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 376 |
def step(self, closure=None):
|
| 377 |
"""Perform a single optimization step.
|
| 378 |
|
build/torch27-cxx11-cu118-x86_64-linux/optimizer/_ops.py
CHANGED
|
@@ -1,9 +1,9 @@
|
|
| 1 |
import torch
|
| 2 |
-
from . import
|
| 3 |
-
ops = torch.ops.
|
| 4 |
|
| 5 |
def add_op_namespace_prefix(op_name: str):
|
| 6 |
"""
|
| 7 |
Prefix op by namespace.
|
| 8 |
"""
|
| 9 |
-
return f"
|
|
|
|
| 1 |
import torch
|
| 2 |
+
from . import _optimizer_8535e80_dirty
|
| 3 |
+
ops = torch.ops._optimizer_8535e80_dirty
|
| 4 |
|
| 5 |
def add_op_namespace_prefix(op_name: str):
|
| 6 |
"""
|
| 7 |
Prefix op by namespace.
|
| 8 |
"""
|
| 9 |
+
return f"_optimizer_8535e80_dirty::{op_name}"
|
build/torch27-cxx11-cu118-x86_64-linux/optimizer/_optimizer_8535e80_dirty.abi3.so
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:c23a3adbe4dc1a64b4851a9f8e4aed0e3e1eeeded27322c54f5b942282a2a332
|
| 3 |
+
size 1787368
|
build/torch27-cxx11-cu118-x86_64-linux/optimizer/muon.py
CHANGED
|
@@ -48,7 +48,6 @@ class _muon_state:
|
|
| 48 |
worker_rank: int | None = None
|
| 49 |
gathered_grad: torch.Tensor | None = None
|
| 50 |
computed_u: torch.Tensor | None = None
|
| 51 |
-
scattered_u: torch.Tensor | None = None
|
| 52 |
gather_event: torch.cuda.Event | None = None
|
| 53 |
compute_event: torch.cuda.Event | None = None
|
| 54 |
|
|
@@ -93,12 +92,14 @@ def _compute_u(state, steps, rank, compute_stream):
|
|
| 93 |
state.computed_u = u
|
| 94 |
state.compute_event = torch.cuda.Event()
|
| 95 |
state.compute_event.record()
|
|
|
|
|
|
|
| 96 |
else:
|
| 97 |
state.computed_u = None
|
| 98 |
state.compute_event = None
|
| 99 |
|
| 100 |
|
| 101 |
-
def _scatter(p, state, rank, comm_stream):
|
| 102 |
u = state.computed_u
|
| 103 |
mesh = p.device_mesh
|
| 104 |
|
|
@@ -118,13 +119,16 @@ def _scatter(p, state, rank, comm_stream):
|
|
| 118 |
src=state.worker_rank,
|
| 119 |
group=mesh.get_group(),
|
| 120 |
)
|
|
|
|
|
|
|
|
|
|
| 121 |
u = DTensor.from_local(
|
| 122 |
u,
|
| 123 |
placements=p.placements,
|
| 124 |
device_mesh=mesh,
|
| 125 |
)
|
| 126 |
-
|
| 127 |
-
|
| 128 |
|
| 129 |
|
| 130 |
class Muon(torch.optim.Optimizer):
|
|
@@ -353,7 +357,8 @@ class Muon(torch.optim.Optimizer):
|
|
| 353 |
def enqueue_scatters(start_idx, chunk_size):
|
| 354 |
for p in ordered_params[start_idx : start_idx + chunk_size]:
|
| 355 |
state = param_to_state[id(p)]
|
| 356 |
-
|
|
|
|
| 357 |
|
| 358 |
chunk_size = params[0].device_mesh.mesh.numel()
|
| 359 |
|
|
@@ -368,20 +373,6 @@ class Muon(torch.optim.Optimizer):
|
|
| 368 |
|
| 369 |
torch.cuda.current_stream().wait_stream(self.comm_stream)
|
| 370 |
|
| 371 |
-
for p in params:
|
| 372 |
-
g = p.grad
|
| 373 |
-
if g is None:
|
| 374 |
-
continue
|
| 375 |
-
|
| 376 |
-
# Update p with sharded u
|
| 377 |
-
state = param_to_state[id(p)]
|
| 378 |
-
self._update_p(
|
| 379 |
-
p,
|
| 380 |
-
state.scattered_u,
|
| 381 |
-
lr=lr,
|
| 382 |
-
wd=wd,
|
| 383 |
-
)
|
| 384 |
-
|
| 385 |
def step(self, closure=None):
|
| 386 |
"""Perform a single optimization step.
|
| 387 |
|
|
|
|
| 48 |
worker_rank: int | None = None
|
| 49 |
gathered_grad: torch.Tensor | None = None
|
| 50 |
computed_u: torch.Tensor | None = None
|
|
|
|
| 51 |
gather_event: torch.cuda.Event | None = None
|
| 52 |
compute_event: torch.cuda.Event | None = None
|
| 53 |
|
|
|
|
| 92 |
state.computed_u = u
|
| 93 |
state.compute_event = torch.cuda.Event()
|
| 94 |
state.compute_event.record()
|
| 95 |
+
state.gathered_grad.record_stream(compute_stream)
|
| 96 |
+
del state.gathered_grad
|
| 97 |
else:
|
| 98 |
state.computed_u = None
|
| 99 |
state.compute_event = None
|
| 100 |
|
| 101 |
|
| 102 |
+
def _scatter(p, state, lr, wd, rank, comm_stream):
|
| 103 |
u = state.computed_u
|
| 104 |
mesh = p.device_mesh
|
| 105 |
|
|
|
|
| 119 |
src=state.worker_rank,
|
| 120 |
group=mesh.get_group(),
|
| 121 |
)
|
| 122 |
+
if rank == state.worker_rank:
|
| 123 |
+
state.computed_u.record_stream(comm_stream)
|
| 124 |
+
del state.computed_u
|
| 125 |
u = DTensor.from_local(
|
| 126 |
u,
|
| 127 |
placements=p.placements,
|
| 128 |
device_mesh=mesh,
|
| 129 |
)
|
| 130 |
+
p.data.mul_(1 - lr * wd)
|
| 131 |
+
p.data.add_(u, alpha=-lr)
|
| 132 |
|
| 133 |
|
| 134 |
class Muon(torch.optim.Optimizer):
|
|
|
|
| 357 |
def enqueue_scatters(start_idx, chunk_size):
|
| 358 |
for p in ordered_params[start_idx : start_idx + chunk_size]:
|
| 359 |
state = param_to_state[id(p)]
|
| 360 |
+
adjusted_lr = self.adjust_lr_for_muon(lr, p.shape)
|
| 361 |
+
_scatter(p, state, adjusted_lr, wd, self.rank, self.comm_stream)
|
| 362 |
|
| 363 |
chunk_size = params[0].device_mesh.mesh.numel()
|
| 364 |
|
|
|
|
| 373 |
|
| 374 |
torch.cuda.current_stream().wait_stream(self.comm_stream)
|
| 375 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 376 |
def step(self, closure=None):
|
| 377 |
"""Perform a single optimization step.
|
| 378 |
|
build/torch27-cxx11-cu126-x86_64-linux/optimizer/_ops.py
CHANGED
|
@@ -1,9 +1,9 @@
|
|
| 1 |
import torch
|
| 2 |
-
from . import
|
| 3 |
-
ops = torch.ops.
|
| 4 |
|
| 5 |
def add_op_namespace_prefix(op_name: str):
|
| 6 |
"""
|
| 7 |
Prefix op by namespace.
|
| 8 |
"""
|
| 9 |
-
return f"
|
|
|
|
| 1 |
import torch
|
| 2 |
+
from . import _optimizer_8535e80_dirty
|
| 3 |
+
ops = torch.ops._optimizer_8535e80_dirty
|
| 4 |
|
| 5 |
def add_op_namespace_prefix(op_name: str):
|
| 6 |
"""
|
| 7 |
Prefix op by namespace.
|
| 8 |
"""
|
| 9 |
+
return f"_optimizer_8535e80_dirty::{op_name}"
|
build/torch27-cxx11-cu126-x86_64-linux/optimizer/_optimizer_8535e80_dirty.abi3.so
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:d4aa09c22745d5efe1ef0669c4ca05615f67595dc90cabeee6e878301fa9bd22
|
| 3 |
+
size 1824256
|
build/torch27-cxx11-cu126-x86_64-linux/optimizer/muon.py
CHANGED
|
@@ -48,7 +48,6 @@ class _muon_state:
|
|
| 48 |
worker_rank: int | None = None
|
| 49 |
gathered_grad: torch.Tensor | None = None
|
| 50 |
computed_u: torch.Tensor | None = None
|
| 51 |
-
scattered_u: torch.Tensor | None = None
|
| 52 |
gather_event: torch.cuda.Event | None = None
|
| 53 |
compute_event: torch.cuda.Event | None = None
|
| 54 |
|
|
@@ -93,12 +92,14 @@ def _compute_u(state, steps, rank, compute_stream):
|
|
| 93 |
state.computed_u = u
|
| 94 |
state.compute_event = torch.cuda.Event()
|
| 95 |
state.compute_event.record()
|
|
|
|
|
|
|
| 96 |
else:
|
| 97 |
state.computed_u = None
|
| 98 |
state.compute_event = None
|
| 99 |
|
| 100 |
|
| 101 |
-
def _scatter(p, state, rank, comm_stream):
|
| 102 |
u = state.computed_u
|
| 103 |
mesh = p.device_mesh
|
| 104 |
|
|
@@ -118,13 +119,16 @@ def _scatter(p, state, rank, comm_stream):
|
|
| 118 |
src=state.worker_rank,
|
| 119 |
group=mesh.get_group(),
|
| 120 |
)
|
|
|
|
|
|
|
|
|
|
| 121 |
u = DTensor.from_local(
|
| 122 |
u,
|
| 123 |
placements=p.placements,
|
| 124 |
device_mesh=mesh,
|
| 125 |
)
|
| 126 |
-
|
| 127 |
-
|
| 128 |
|
| 129 |
|
| 130 |
class Muon(torch.optim.Optimizer):
|
|
@@ -353,7 +357,8 @@ class Muon(torch.optim.Optimizer):
|
|
| 353 |
def enqueue_scatters(start_idx, chunk_size):
|
| 354 |
for p in ordered_params[start_idx : start_idx + chunk_size]:
|
| 355 |
state = param_to_state[id(p)]
|
| 356 |
-
|
|
|
|
| 357 |
|
| 358 |
chunk_size = params[0].device_mesh.mesh.numel()
|
| 359 |
|
|
@@ -368,20 +373,6 @@ class Muon(torch.optim.Optimizer):
|
|
| 368 |
|
| 369 |
torch.cuda.current_stream().wait_stream(self.comm_stream)
|
| 370 |
|
| 371 |
-
for p in params:
|
| 372 |
-
g = p.grad
|
| 373 |
-
if g is None:
|
| 374 |
-
continue
|
| 375 |
-
|
| 376 |
-
# Update p with sharded u
|
| 377 |
-
state = param_to_state[id(p)]
|
| 378 |
-
self._update_p(
|
| 379 |
-
p,
|
| 380 |
-
state.scattered_u,
|
| 381 |
-
lr=lr,
|
| 382 |
-
wd=wd,
|
| 383 |
-
)
|
| 384 |
-
|
| 385 |
def step(self, closure=None):
|
| 386 |
"""Perform a single optimization step.
|
| 387 |
|
|
|
|
| 48 |
worker_rank: int | None = None
|
| 49 |
gathered_grad: torch.Tensor | None = None
|
| 50 |
computed_u: torch.Tensor | None = None
|
|
|
|
| 51 |
gather_event: torch.cuda.Event | None = None
|
| 52 |
compute_event: torch.cuda.Event | None = None
|
| 53 |
|
|
|
|
| 92 |
state.computed_u = u
|
| 93 |
state.compute_event = torch.cuda.Event()
|
| 94 |
state.compute_event.record()
|
| 95 |
+
state.gathered_grad.record_stream(compute_stream)
|
| 96 |
+
del state.gathered_grad
|
| 97 |
else:
|
| 98 |
state.computed_u = None
|
| 99 |
state.compute_event = None
|
| 100 |
|
| 101 |
|
| 102 |
+
def _scatter(p, state, lr, wd, rank, comm_stream):
|
| 103 |
u = state.computed_u
|
| 104 |
mesh = p.device_mesh
|
| 105 |
|
|
|
|
| 119 |
src=state.worker_rank,
|
| 120 |
group=mesh.get_group(),
|
| 121 |
)
|
| 122 |
+
if rank == state.worker_rank:
|
| 123 |
+
state.computed_u.record_stream(comm_stream)
|
| 124 |
+
del state.computed_u
|
| 125 |
u = DTensor.from_local(
|
| 126 |
u,
|
| 127 |
placements=p.placements,
|
| 128 |
device_mesh=mesh,
|
| 129 |
)
|
| 130 |
+
p.data.mul_(1 - lr * wd)
|
| 131 |
+
p.data.add_(u, alpha=-lr)
|
| 132 |
|
| 133 |
|
| 134 |
class Muon(torch.optim.Optimizer):
|
|
|
|
| 357 |
def enqueue_scatters(start_idx, chunk_size):
|
| 358 |
for p in ordered_params[start_idx : start_idx + chunk_size]:
|
| 359 |
state = param_to_state[id(p)]
|
| 360 |
+
adjusted_lr = self.adjust_lr_for_muon(lr, p.shape)
|
| 361 |
+
_scatter(p, state, adjusted_lr, wd, self.rank, self.comm_stream)
|
| 362 |
|
| 363 |
chunk_size = params[0].device_mesh.mesh.numel()
|
| 364 |
|
|
|
|
| 373 |
|
| 374 |
torch.cuda.current_stream().wait_stream(self.comm_stream)
|
| 375 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 376 |
def step(self, closure=None):
|
| 377 |
"""Perform a single optimization step.
|
| 378 |
|
build/torch27-cxx11-cu128-x86_64-linux/optimizer/_ops.py
CHANGED
|
@@ -1,9 +1,9 @@
|
|
| 1 |
import torch
|
| 2 |
-
from . import
|
| 3 |
-
ops = torch.ops.
|
| 4 |
|
| 5 |
def add_op_namespace_prefix(op_name: str):
|
| 6 |
"""
|
| 7 |
Prefix op by namespace.
|
| 8 |
"""
|
| 9 |
-
return f"
|
|
|
|
| 1 |
import torch
|
| 2 |
+
from . import _optimizer_8535e80_dirty
|
| 3 |
+
ops = torch.ops._optimizer_8535e80_dirty
|
| 4 |
|
| 5 |
def add_op_namespace_prefix(op_name: str):
|
| 6 |
"""
|
| 7 |
Prefix op by namespace.
|
| 8 |
"""
|
| 9 |
+
return f"_optimizer_8535e80_dirty::{op_name}"
|
build/torch27-cxx11-cu128-x86_64-linux/optimizer/_optimizer_8535e80_dirty.abi3.so
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:b4baf569b70749c4657062fb0f56943fc486adb0c482e50c7aa8e31ddf5cc870
|
| 3 |
+
size 1883352
|
build/torch27-cxx11-cu128-x86_64-linux/optimizer/muon.py
CHANGED
|
@@ -48,7 +48,6 @@ class _muon_state:
|
|
| 48 |
worker_rank: int | None = None
|
| 49 |
gathered_grad: torch.Tensor | None = None
|
| 50 |
computed_u: torch.Tensor | None = None
|
| 51 |
-
scattered_u: torch.Tensor | None = None
|
| 52 |
gather_event: torch.cuda.Event | None = None
|
| 53 |
compute_event: torch.cuda.Event | None = None
|
| 54 |
|
|
@@ -93,12 +92,14 @@ def _compute_u(state, steps, rank, compute_stream):
|
|
| 93 |
state.computed_u = u
|
| 94 |
state.compute_event = torch.cuda.Event()
|
| 95 |
state.compute_event.record()
|
|
|
|
|
|
|
| 96 |
else:
|
| 97 |
state.computed_u = None
|
| 98 |
state.compute_event = None
|
| 99 |
|
| 100 |
|
| 101 |
-
def _scatter(p, state, rank, comm_stream):
|
| 102 |
u = state.computed_u
|
| 103 |
mesh = p.device_mesh
|
| 104 |
|
|
@@ -118,13 +119,16 @@ def _scatter(p, state, rank, comm_stream):
|
|
| 118 |
src=state.worker_rank,
|
| 119 |
group=mesh.get_group(),
|
| 120 |
)
|
|
|
|
|
|
|
|
|
|
| 121 |
u = DTensor.from_local(
|
| 122 |
u,
|
| 123 |
placements=p.placements,
|
| 124 |
device_mesh=mesh,
|
| 125 |
)
|
| 126 |
-
|
| 127 |
-
|
| 128 |
|
| 129 |
|
| 130 |
class Muon(torch.optim.Optimizer):
|
|
@@ -353,7 +357,8 @@ class Muon(torch.optim.Optimizer):
|
|
| 353 |
def enqueue_scatters(start_idx, chunk_size):
|
| 354 |
for p in ordered_params[start_idx : start_idx + chunk_size]:
|
| 355 |
state = param_to_state[id(p)]
|
| 356 |
-
|
|
|
|
| 357 |
|
| 358 |
chunk_size = params[0].device_mesh.mesh.numel()
|
| 359 |
|
|
@@ -368,20 +373,6 @@ class Muon(torch.optim.Optimizer):
|
|
| 368 |
|
| 369 |
torch.cuda.current_stream().wait_stream(self.comm_stream)
|
| 370 |
|
| 371 |
-
for p in params:
|
| 372 |
-
g = p.grad
|
| 373 |
-
if g is None:
|
| 374 |
-
continue
|
| 375 |
-
|
| 376 |
-
# Update p with sharded u
|
| 377 |
-
state = param_to_state[id(p)]
|
| 378 |
-
self._update_p(
|
| 379 |
-
p,
|
| 380 |
-
state.scattered_u,
|
| 381 |
-
lr=lr,
|
| 382 |
-
wd=wd,
|
| 383 |
-
)
|
| 384 |
-
|
| 385 |
def step(self, closure=None):
|
| 386 |
"""Perform a single optimization step.
|
| 387 |
|
|
|
|
| 48 |
worker_rank: int | None = None
|
| 49 |
gathered_grad: torch.Tensor | None = None
|
| 50 |
computed_u: torch.Tensor | None = None
|
|
|
|
| 51 |
gather_event: torch.cuda.Event | None = None
|
| 52 |
compute_event: torch.cuda.Event | None = None
|
| 53 |
|
|
|
|
| 92 |
state.computed_u = u
|
| 93 |
state.compute_event = torch.cuda.Event()
|
| 94 |
state.compute_event.record()
|
| 95 |
+
state.gathered_grad.record_stream(compute_stream)
|
| 96 |
+
del state.gathered_grad
|
| 97 |
else:
|
| 98 |
state.computed_u = None
|
| 99 |
state.compute_event = None
|
| 100 |
|
| 101 |
|
| 102 |
+
def _scatter(p, state, lr, wd, rank, comm_stream):
|
| 103 |
u = state.computed_u
|
| 104 |
mesh = p.device_mesh
|
| 105 |
|
|
|
|
| 119 |
src=state.worker_rank,
|
| 120 |
group=mesh.get_group(),
|
| 121 |
)
|
| 122 |
+
if rank == state.worker_rank:
|
| 123 |
+
state.computed_u.record_stream(comm_stream)
|
| 124 |
+
del state.computed_u
|
| 125 |
u = DTensor.from_local(
|
| 126 |
u,
|
| 127 |
placements=p.placements,
|
| 128 |
device_mesh=mesh,
|
| 129 |
)
|
| 130 |
+
p.data.mul_(1 - lr * wd)
|
| 131 |
+
p.data.add_(u, alpha=-lr)
|
| 132 |
|
| 133 |
|
| 134 |
class Muon(torch.optim.Optimizer):
|
|
|
|
| 357 |
def enqueue_scatters(start_idx, chunk_size):
|
| 358 |
for p in ordered_params[start_idx : start_idx + chunk_size]:
|
| 359 |
state = param_to_state[id(p)]
|
| 360 |
+
adjusted_lr = self.adjust_lr_for_muon(lr, p.shape)
|
| 361 |
+
_scatter(p, state, adjusted_lr, wd, self.rank, self.comm_stream)
|
| 362 |
|
| 363 |
chunk_size = params[0].device_mesh.mesh.numel()
|
| 364 |
|
|
|
|
| 373 |
|
| 374 |
torch.cuda.current_stream().wait_stream(self.comm_stream)
|
| 375 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 376 |
def step(self, closure=None):
|
| 377 |
"""Perform a single optimization step.
|
| 378 |
|
build/torch27-cxx11-rocm63-x86_64-linux/optimizer/__pycache__/__init__.cpython-310.pyc
CHANGED
|
Binary files a/build/torch27-cxx11-rocm63-x86_64-linux/optimizer/__pycache__/__init__.cpython-310.pyc and b/build/torch27-cxx11-rocm63-x86_64-linux/optimizer/__pycache__/__init__.cpython-310.pyc differ
|
|
|
build/torch27-cxx11-rocm63-x86_64-linux/optimizer/__pycache__/muon.cpython-310.pyc
CHANGED
|
Binary files a/build/torch27-cxx11-rocm63-x86_64-linux/optimizer/__pycache__/muon.cpython-310.pyc and b/build/torch27-cxx11-rocm63-x86_64-linux/optimizer/__pycache__/muon.cpython-310.pyc differ
|
|
|
build/torch27-cxx11-rocm63-x86_64-linux/optimizer/_ops.py
CHANGED
|
@@ -1,9 +1,9 @@
|
|
| 1 |
import torch
|
| 2 |
-
from . import
|
| 3 |
-
ops = torch.ops.
|
| 4 |
|
| 5 |
def add_op_namespace_prefix(op_name: str):
|
| 6 |
"""
|
| 7 |
Prefix op by namespace.
|
| 8 |
"""
|
| 9 |
-
return f"
|
|
|
|
| 1 |
import torch
|
| 2 |
+
from . import _optimizer_8535e80_dirty
|
| 3 |
+
ops = torch.ops._optimizer_8535e80_dirty
|
| 4 |
|
| 5 |
def add_op_namespace_prefix(op_name: str):
|
| 6 |
"""
|
| 7 |
Prefix op by namespace.
|
| 8 |
"""
|
| 9 |
+
return f"_optimizer_8535e80_dirty::{op_name}"
|
build/torch27-cxx11-rocm63-x86_64-linux/optimizer/_optimizer_8535e80_dirty.abi3.so
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:8566c9bc05e13c9394572f9f9c6bac24c31932548be485f49eb49fb249880832
|
| 3 |
+
size 1749648
|
build/torch27-cxx11-rocm63-x86_64-linux/optimizer/muon.py
CHANGED
|
@@ -48,7 +48,6 @@ class _muon_state:
|
|
| 48 |
worker_rank: int | None = None
|
| 49 |
gathered_grad: torch.Tensor | None = None
|
| 50 |
computed_u: torch.Tensor | None = None
|
| 51 |
-
scattered_u: torch.Tensor | None = None
|
| 52 |
gather_event: torch.cuda.Event | None = None
|
| 53 |
compute_event: torch.cuda.Event | None = None
|
| 54 |
|
|
@@ -93,12 +92,14 @@ def _compute_u(state, steps, rank, compute_stream):
|
|
| 93 |
state.computed_u = u
|
| 94 |
state.compute_event = torch.cuda.Event()
|
| 95 |
state.compute_event.record()
|
|
|
|
|
|
|
| 96 |
else:
|
| 97 |
state.computed_u = None
|
| 98 |
state.compute_event = None
|
| 99 |
|
| 100 |
|
| 101 |
-
def _scatter(p, state, rank, comm_stream):
|
| 102 |
u = state.computed_u
|
| 103 |
mesh = p.device_mesh
|
| 104 |
|
|
@@ -118,13 +119,16 @@ def _scatter(p, state, rank, comm_stream):
|
|
| 118 |
src=state.worker_rank,
|
| 119 |
group=mesh.get_group(),
|
| 120 |
)
|
|
|
|
|
|
|
|
|
|
| 121 |
u = DTensor.from_local(
|
| 122 |
u,
|
| 123 |
placements=p.placements,
|
| 124 |
device_mesh=mesh,
|
| 125 |
)
|
| 126 |
-
|
| 127 |
-
|
| 128 |
|
| 129 |
|
| 130 |
class Muon(torch.optim.Optimizer):
|
|
@@ -353,7 +357,8 @@ class Muon(torch.optim.Optimizer):
|
|
| 353 |
def enqueue_scatters(start_idx, chunk_size):
|
| 354 |
for p in ordered_params[start_idx : start_idx + chunk_size]:
|
| 355 |
state = param_to_state[id(p)]
|
| 356 |
-
|
|
|
|
| 357 |
|
| 358 |
chunk_size = params[0].device_mesh.mesh.numel()
|
| 359 |
|
|
@@ -368,20 +373,6 @@ class Muon(torch.optim.Optimizer):
|
|
| 368 |
|
| 369 |
torch.cuda.current_stream().wait_stream(self.comm_stream)
|
| 370 |
|
| 371 |
-
for p in params:
|
| 372 |
-
g = p.grad
|
| 373 |
-
if g is None:
|
| 374 |
-
continue
|
| 375 |
-
|
| 376 |
-
# Update p with sharded u
|
| 377 |
-
state = param_to_state[id(p)]
|
| 378 |
-
self._update_p(
|
| 379 |
-
p,
|
| 380 |
-
state.scattered_u,
|
| 381 |
-
lr=lr,
|
| 382 |
-
wd=wd,
|
| 383 |
-
)
|
| 384 |
-
|
| 385 |
def step(self, closure=None):
|
| 386 |
"""Perform a single optimization step.
|
| 387 |
|
|
|
|
| 48 |
worker_rank: int | None = None
|
| 49 |
gathered_grad: torch.Tensor | None = None
|
| 50 |
computed_u: torch.Tensor | None = None
|
|
|
|
| 51 |
gather_event: torch.cuda.Event | None = None
|
| 52 |
compute_event: torch.cuda.Event | None = None
|
| 53 |
|
|
|
|
| 92 |
state.computed_u = u
|
| 93 |
state.compute_event = torch.cuda.Event()
|
| 94 |
state.compute_event.record()
|
| 95 |
+
state.gathered_grad.record_stream(compute_stream)
|
| 96 |
+
del state.gathered_grad
|
| 97 |
else:
|
| 98 |
state.computed_u = None
|
| 99 |
state.compute_event = None
|
| 100 |
|
| 101 |
|
| 102 |
+
def _scatter(p, state, lr, wd, rank, comm_stream):
|
| 103 |
u = state.computed_u
|
| 104 |
mesh = p.device_mesh
|
| 105 |
|
|
|
|
| 119 |
src=state.worker_rank,
|
| 120 |
group=mesh.get_group(),
|
| 121 |
)
|
| 122 |
+
if rank == state.worker_rank:
|
| 123 |
+
state.computed_u.record_stream(comm_stream)
|
| 124 |
+
del state.computed_u
|
| 125 |
u = DTensor.from_local(
|
| 126 |
u,
|
| 127 |
placements=p.placements,
|
| 128 |
device_mesh=mesh,
|
| 129 |
)
|
| 130 |
+
p.data.mul_(1 - lr * wd)
|
| 131 |
+
p.data.add_(u, alpha=-lr)
|
| 132 |
|
| 133 |
|
| 134 |
class Muon(torch.optim.Optimizer):
|
|
|
|
| 357 |
def enqueue_scatters(start_idx, chunk_size):
|
| 358 |
for p in ordered_params[start_idx : start_idx + chunk_size]:
|
| 359 |
state = param_to_state[id(p)]
|
| 360 |
+
adjusted_lr = self.adjust_lr_for_muon(lr, p.shape)
|
| 361 |
+
_scatter(p, state, adjusted_lr, wd, self.rank, self.comm_stream)
|
| 362 |
|
| 363 |
chunk_size = params[0].device_mesh.mesh.numel()
|
| 364 |
|
|
|
|
| 373 |
|
| 374 |
torch.cuda.current_stream().wait_stream(self.comm_stream)
|
| 375 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 376 |
def step(self, closure=None):
|
| 377 |
"""Perform a single optimization step.
|
| 378 |
|
torch-ext/optimizer/muon.py
CHANGED
|
@@ -48,7 +48,6 @@ class _muon_state:
|
|
| 48 |
worker_rank: int | None = None
|
| 49 |
gathered_grad: torch.Tensor | None = None
|
| 50 |
computed_u: torch.Tensor | None = None
|
| 51 |
-
scattered_u: torch.Tensor | None = None
|
| 52 |
gather_event: torch.cuda.Event | None = None
|
| 53 |
compute_event: torch.cuda.Event | None = None
|
| 54 |
|
|
@@ -93,12 +92,14 @@ def _compute_u(state, steps, rank, compute_stream):
|
|
| 93 |
state.computed_u = u
|
| 94 |
state.compute_event = torch.cuda.Event()
|
| 95 |
state.compute_event.record()
|
|
|
|
|
|
|
| 96 |
else:
|
| 97 |
state.computed_u = None
|
| 98 |
state.compute_event = None
|
| 99 |
|
| 100 |
|
| 101 |
-
def _scatter(p, state, rank, comm_stream):
|
| 102 |
u = state.computed_u
|
| 103 |
mesh = p.device_mesh
|
| 104 |
|
|
@@ -118,13 +119,16 @@ def _scatter(p, state, rank, comm_stream):
|
|
| 118 |
src=state.worker_rank,
|
| 119 |
group=mesh.get_group(),
|
| 120 |
)
|
|
|
|
|
|
|
|
|
|
| 121 |
u = DTensor.from_local(
|
| 122 |
u,
|
| 123 |
placements=p.placements,
|
| 124 |
device_mesh=mesh,
|
| 125 |
)
|
| 126 |
-
|
| 127 |
-
|
| 128 |
|
| 129 |
|
| 130 |
class Muon(torch.optim.Optimizer):
|
|
@@ -353,7 +357,8 @@ class Muon(torch.optim.Optimizer):
|
|
| 353 |
def enqueue_scatters(start_idx, chunk_size):
|
| 354 |
for p in ordered_params[start_idx : start_idx + chunk_size]:
|
| 355 |
state = param_to_state[id(p)]
|
| 356 |
-
|
|
|
|
| 357 |
|
| 358 |
chunk_size = params[0].device_mesh.mesh.numel()
|
| 359 |
|
|
@@ -368,20 +373,6 @@ class Muon(torch.optim.Optimizer):
|
|
| 368 |
|
| 369 |
torch.cuda.current_stream().wait_stream(self.comm_stream)
|
| 370 |
|
| 371 |
-
for p in params:
|
| 372 |
-
g = p.grad
|
| 373 |
-
if g is None:
|
| 374 |
-
continue
|
| 375 |
-
|
| 376 |
-
# Update p with sharded u
|
| 377 |
-
state = param_to_state[id(p)]
|
| 378 |
-
self._update_p(
|
| 379 |
-
p,
|
| 380 |
-
state.scattered_u,
|
| 381 |
-
lr=lr,
|
| 382 |
-
wd=wd,
|
| 383 |
-
)
|
| 384 |
-
|
| 385 |
def step(self, closure=None):
|
| 386 |
"""Perform a single optimization step.
|
| 387 |
|
|
|
|
| 48 |
worker_rank: int | None = None
|
| 49 |
gathered_grad: torch.Tensor | None = None
|
| 50 |
computed_u: torch.Tensor | None = None
|
|
|
|
| 51 |
gather_event: torch.cuda.Event | None = None
|
| 52 |
compute_event: torch.cuda.Event | None = None
|
| 53 |
|
|
|
|
| 92 |
state.computed_u = u
|
| 93 |
state.compute_event = torch.cuda.Event()
|
| 94 |
state.compute_event.record()
|
| 95 |
+
state.gathered_grad.record_stream(compute_stream)
|
| 96 |
+
del state.gathered_grad
|
| 97 |
else:
|
| 98 |
state.computed_u = None
|
| 99 |
state.compute_event = None
|
| 100 |
|
| 101 |
|
| 102 |
+
def _scatter(p, state, lr, wd, rank, comm_stream):
|
| 103 |
u = state.computed_u
|
| 104 |
mesh = p.device_mesh
|
| 105 |
|
|
|
|
| 119 |
src=state.worker_rank,
|
| 120 |
group=mesh.get_group(),
|
| 121 |
)
|
| 122 |
+
if rank == state.worker_rank:
|
| 123 |
+
state.computed_u.record_stream(comm_stream)
|
| 124 |
+
del state.computed_u
|
| 125 |
u = DTensor.from_local(
|
| 126 |
u,
|
| 127 |
placements=p.placements,
|
| 128 |
device_mesh=mesh,
|
| 129 |
)
|
| 130 |
+
p.data.mul_(1 - lr * wd)
|
| 131 |
+
p.data.add_(u, alpha=-lr)
|
| 132 |
|
| 133 |
|
| 134 |
class Muon(torch.optim.Optimizer):
|
|
|
|
| 357 |
def enqueue_scatters(start_idx, chunk_size):
|
| 358 |
for p in ordered_params[start_idx : start_idx + chunk_size]:
|
| 359 |
state = param_to_state[id(p)]
|
| 360 |
+
adjusted_lr = self.adjust_lr_for_muon(lr, p.shape)
|
| 361 |
+
_scatter(p, state, adjusted_lr, wd, self.rank, self.comm_stream)
|
| 362 |
|
| 363 |
chunk_size = params[0].device_mesh.mesh.numel()
|
| 364 |
|
|
|
|
| 373 |
|
| 374 |
torch.cuda.current_stream().wait_stream(self.comm_stream)
|
| 375 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 376 |
def step(self, closure=None):
|
| 377 |
"""Perform a single optimization step.
|
| 378 |
|