Commit
Β·
02ac540
1
Parent(s):
64757cb
refactor(muon): change argument adam_wd to weight_decay and handle params' type
Browse files- build/torch26-cxx11-cu118-x86_64-linux/optimizer/_ops.py +3 -3
- build/torch26-cxx11-cu118-x86_64-linux/optimizer/{_optimizer_036642a_dirty.abi3.so β _optimizer_64757cb_dirty.abi3.so} +1 -1
- build/torch26-cxx11-cu118-x86_64-linux/optimizer/muon.py +52 -21
- build/torch26-cxx11-cu124-x86_64-linux/optimizer/_ops.py +3 -3
- build/torch26-cxx11-cu124-x86_64-linux/optimizer/{_optimizer_036642a_dirty.abi3.so β _optimizer_64757cb_dirty.abi3.so} +1 -1
- build/torch26-cxx11-cu124-x86_64-linux/optimizer/muon.py +52 -21
- build/torch26-cxx11-cu126-x86_64-linux/optimizer/_ops.py +3 -3
- build/torch26-cxx11-cu126-x86_64-linux/optimizer/{_optimizer_036642a_dirty.abi3.so β _optimizer_64757cb_dirty.abi3.so} +1 -1
- build/torch26-cxx11-cu126-x86_64-linux/optimizer/muon.py +52 -21
- build/torch26-cxx11-rocm62-x86_64-linux/optimizer/_ops.py +3 -3
- build/torch26-cxx11-rocm62-x86_64-linux/optimizer/{_optimizer_036642a_dirty.abi3.so β _optimizer_64757cb_dirty.abi3.so} +1 -1
- build/torch26-cxx11-rocm62-x86_64-linux/optimizer/muon.py +52 -21
- build/torch26-cxx98-cu118-x86_64-linux/optimizer/_ops.py +3 -3
- build/torch26-cxx98-cu118-x86_64-linux/optimizer/{_optimizer_036642a_dirty.abi3.so β _optimizer_64757cb_dirty.abi3.so} +1 -1
- build/torch26-cxx98-cu118-x86_64-linux/optimizer/muon.py +52 -21
- build/torch26-cxx98-cu124-x86_64-linux/optimizer/_ops.py +3 -3
- build/torch26-cxx98-cu124-x86_64-linux/optimizer/{_optimizer_036642a_dirty.abi3.so β _optimizer_64757cb_dirty.abi3.so} +1 -1
- build/torch26-cxx98-cu124-x86_64-linux/optimizer/muon.py +52 -21
- build/torch26-cxx98-cu126-x86_64-linux/optimizer/_ops.py +3 -3
- build/torch26-cxx98-cu126-x86_64-linux/optimizer/{_optimizer_036642a_dirty.abi3.so β _optimizer_64757cb_dirty.abi3.so} +1 -1
- build/torch26-cxx98-cu126-x86_64-linux/optimizer/muon.py +52 -21
- build/torch27-cxx11-cu118-x86_64-linux/optimizer/_ops.py +3 -3
- build/torch27-cxx11-cu118-x86_64-linux/optimizer/{_optimizer_036642a_dirty.abi3.so β _optimizer_64757cb_dirty.abi3.so} +1 -1
- build/torch27-cxx11-cu118-x86_64-linux/optimizer/muon.py +52 -21
- build/torch27-cxx11-cu126-x86_64-linux/optimizer/_ops.py +3 -3
- build/torch27-cxx11-cu126-x86_64-linux/optimizer/{_optimizer_036642a_dirty.abi3.so β _optimizer_64757cb_dirty.abi3.so} +1 -1
- build/torch27-cxx11-cu126-x86_64-linux/optimizer/muon.py +52 -21
- build/torch27-cxx11-cu128-x86_64-linux/optimizer/_ops.py +3 -3
- build/torch27-cxx11-cu128-x86_64-linux/optimizer/{_optimizer_036642a_dirty.abi3.so β _optimizer_64757cb_dirty.abi3.so} +1 -1
- build/torch27-cxx11-cu128-x86_64-linux/optimizer/muon.py +52 -21
- build/torch27-cxx11-rocm63-x86_64-linux/optimizer/__pycache__/__init__.cpython-312.pyc +0 -0
- build/torch27-cxx11-rocm63-x86_64-linux/optimizer/__pycache__/muon.cpython-312.pyc +0 -0
- build/torch27-cxx11-rocm63-x86_64-linux/optimizer/_ops.py +3 -3
- build/torch27-cxx11-rocm63-x86_64-linux/optimizer/{_optimizer_036642a_dirty.abi3.so β _optimizer_64757cb_dirty.abi3.so} +1 -1
- build/torch27-cxx11-rocm63-x86_64-linux/optimizer/muon.py +52 -21
- torch-ext/optimizer/muon.py +52 -21
build/torch26-cxx11-cu118-x86_64-linux/optimizer/_ops.py
CHANGED
|
@@ -1,9 +1,9 @@
|
|
| 1 |
import torch
|
| 2 |
-
from . import
|
| 3 |
-
ops = torch.ops.
|
| 4 |
|
| 5 |
def add_op_namespace_prefix(op_name: str):
|
| 6 |
"""
|
| 7 |
Prefix op by namespace.
|
| 8 |
"""
|
| 9 |
-
return f"
|
|
|
|
| 1 |
import torch
|
| 2 |
+
from . import _optimizer_64757cb_dirty
|
| 3 |
+
ops = torch.ops._optimizer_64757cb_dirty
|
| 4 |
|
| 5 |
def add_op_namespace_prefix(op_name: str):
|
| 6 |
"""
|
| 7 |
Prefix op by namespace.
|
| 8 |
"""
|
| 9 |
+
return f"_optimizer_64757cb_dirty::{op_name}"
|
build/torch26-cxx11-cu118-x86_64-linux/optimizer/{_optimizer_036642a_dirty.abi3.so β _optimizer_64757cb_dirty.abi3.so}
RENAMED
|
@@ -1,3 +1,3 @@
|
|
| 1 |
version https://git-lfs.github.com/spec/v1
|
| 2 |
-
oid sha256:
|
| 3 |
size 1787272
|
|
|
|
| 1 |
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:f1f5df341112d93e43c0801e285abd66e79bfbe399d228f8be09ff26ece7421b
|
| 3 |
size 1787272
|
build/torch26-cxx11-cu118-x86_64-linux/optimizer/muon.py
CHANGED
|
@@ -3,7 +3,7 @@ from dataclasses import dataclass
|
|
| 3 |
|
| 4 |
import torch
|
| 5 |
import torch.distributed as dist
|
| 6 |
-
from torch.distributed._tensor import DTensor
|
| 7 |
|
| 8 |
|
| 9 |
# This code snippet is a modified version adapted from the following GitHub repositories:
|
|
@@ -103,7 +103,7 @@ def _compute_u(state, steps, rank, compute_stream):
|
|
| 103 |
|
| 104 |
|
| 105 |
@torch.no_grad()
|
| 106 |
-
def _scatter(p, state, lr,
|
| 107 |
u = state.computed_u
|
| 108 |
mesh = p.device_mesh
|
| 109 |
|
|
@@ -131,10 +131,14 @@ def _scatter(p, state, lr, wd, rank, comm_stream):
|
|
| 131 |
placements=p.placements,
|
| 132 |
device_mesh=mesh,
|
| 133 |
)
|
| 134 |
-
p.data.mul_(1 - lr *
|
| 135 |
p.data.add_(u, alpha=-lr)
|
| 136 |
|
| 137 |
|
|
|
|
|
|
|
|
|
|
|
|
|
| 138 |
class Muon(torch.optim.Optimizer):
|
| 139 |
"""
|
| 140 |
Muon - MomentUm Orthogonalized by Newton-schulz
|
|
@@ -159,18 +163,18 @@ class Muon(torch.optim.Optimizer):
|
|
| 159 |
adamw_lr: The learning rate for the internal AdamW.
|
| 160 |
adamw_betas: The betas for the internal AdamW.
|
| 161 |
adamw_eps: The epsilon for the internal AdamW.
|
| 162 |
-
|
| 163 |
"""
|
| 164 |
|
| 165 |
def __init__(
|
| 166 |
self,
|
| 167 |
model,
|
| 168 |
-
is_muon_func,
|
| 169 |
lr=1e-3,
|
| 170 |
momentum=0.95,
|
| 171 |
nesterov=True,
|
| 172 |
ns_steps=5,
|
| 173 |
-
|
| 174 |
adamw_betas=(0.9, 0.95),
|
| 175 |
adamw_eps=1e-8,
|
| 176 |
none_grad=True,
|
|
@@ -178,7 +182,7 @@ class Muon(torch.optim.Optimizer):
|
|
| 178 |
):
|
| 179 |
defaults = dict(
|
| 180 |
lr=lr,
|
| 181 |
-
|
| 182 |
momentum=momentum,
|
| 183 |
nesterov=nesterov,
|
| 184 |
ns_steps=ns_steps,
|
|
@@ -272,7 +276,7 @@ class Muon(torch.optim.Optimizer):
|
|
| 272 |
|
| 273 |
return param_to_state, ordered_params
|
| 274 |
|
| 275 |
-
def base(self, params, group, lr,
|
| 276 |
# generate weight updates in distributed fashion
|
| 277 |
for p in params:
|
| 278 |
g = p.grad
|
|
@@ -299,7 +303,7 @@ class Muon(torch.optim.Optimizer):
|
|
| 299 |
adjusted_lr = self.adjust_lr_for_muon(lr, p.shape)
|
| 300 |
|
| 301 |
# apply weight decay
|
| 302 |
-
p.data.mul_(1 - lr *
|
| 303 |
|
| 304 |
# apply update
|
| 305 |
p.data.add_(u, alpha=-adjusted_lr)
|
|
@@ -317,15 +321,15 @@ class Muon(torch.optim.Optimizer):
|
|
| 317 |
g = buf
|
| 318 |
return g
|
| 319 |
|
| 320 |
-
def _update_p(self, p, u, lr,
|
| 321 |
# scale update
|
| 322 |
adjusted_lr = self.adjust_lr_for_muon(lr, p.shape)
|
| 323 |
# apply weight decay
|
| 324 |
-
p.data.mul_(1 - lr *
|
| 325 |
# apply update
|
| 326 |
p.data.add_(u, alpha=-adjusted_lr)
|
| 327 |
|
| 328 |
-
def parallel(self, params, group, lr,
|
| 329 |
"""
|
| 330 |
Perform a parallel optimization step using Muon.
|
| 331 |
"""
|
|
@@ -364,7 +368,9 @@ class Muon(torch.optim.Optimizer):
|
|
| 364 |
for p in ordered_params[start_idx : start_idx + chunk_size]:
|
| 365 |
state = param_to_state[id(p)]
|
| 366 |
adjusted_lr = self.adjust_lr_for_muon(lr, p.shape)
|
| 367 |
-
_scatter(
|
|
|
|
|
|
|
| 368 |
|
| 369 |
chunk_size = params[0].device_mesh.mesh.numel()
|
| 370 |
|
|
@@ -398,23 +404,48 @@ class Muon(torch.optim.Optimizer):
|
|
| 398 |
|
| 399 |
params = [p for p in group["params"] if self.state[p]["use_muon"]]
|
| 400 |
lr = group["lr"]
|
| 401 |
-
|
| 402 |
momentum = group["momentum"]
|
| 403 |
|
| 404 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 405 |
self.parallel(
|
| 406 |
-
|
| 407 |
group,
|
| 408 |
lr=lr,
|
| 409 |
-
|
| 410 |
momentum=momentum,
|
| 411 |
)
|
| 412 |
-
|
|
|
|
| 413 |
self.base(
|
| 414 |
-
|
| 415 |
group,
|
| 416 |
lr=lr,
|
| 417 |
-
|
| 418 |
momentum=momentum,
|
| 419 |
)
|
| 420 |
|
|
@@ -426,7 +457,7 @@ class Muon(torch.optim.Optimizer):
|
|
| 426 |
lr = group["lr"]
|
| 427 |
beta1, beta2 = group["adamw_betas"]
|
| 428 |
eps = group["adamw_eps"]
|
| 429 |
-
weight_decay = group["
|
| 430 |
|
| 431 |
for p in params:
|
| 432 |
g = p.grad
|
|
|
|
| 3 |
|
| 4 |
import torch
|
| 5 |
import torch.distributed as dist
|
| 6 |
+
from torch.distributed._tensor import DTensor, Replicate
|
| 7 |
|
| 8 |
|
| 9 |
# This code snippet is a modified version adapted from the following GitHub repositories:
|
|
|
|
| 103 |
|
| 104 |
|
| 105 |
@torch.no_grad()
|
| 106 |
+
def _scatter(p, state, lr, weight_decay, rank, comm_stream):
|
| 107 |
u = state.computed_u
|
| 108 |
mesh = p.device_mesh
|
| 109 |
|
|
|
|
| 131 |
placements=p.placements,
|
| 132 |
device_mesh=mesh,
|
| 133 |
)
|
| 134 |
+
p.data.mul_(1 - lr * weight_decay)
|
| 135 |
p.data.add_(u, alpha=-lr)
|
| 136 |
|
| 137 |
|
| 138 |
+
def default_is_muon(x, name):
|
| 139 |
+
return x.ndim >= 2 and "embed_tokens" not in name and "lm_head" not in name
|
| 140 |
+
|
| 141 |
+
|
| 142 |
class Muon(torch.optim.Optimizer):
|
| 143 |
"""
|
| 144 |
Muon - MomentUm Orthogonalized by Newton-schulz
|
|
|
|
| 163 |
adamw_lr: The learning rate for the internal AdamW.
|
| 164 |
adamw_betas: The betas for the internal AdamW.
|
| 165 |
adamw_eps: The epsilon for the internal AdamW.
|
| 166 |
+
adamw_weight_decay: The weight decay for the internal AdamW.
|
| 167 |
"""
|
| 168 |
|
| 169 |
def __init__(
|
| 170 |
self,
|
| 171 |
model,
|
| 172 |
+
is_muon_func=default_is_muon,
|
| 173 |
lr=1e-3,
|
| 174 |
momentum=0.95,
|
| 175 |
nesterov=True,
|
| 176 |
ns_steps=5,
|
| 177 |
+
weight_decay=0.1,
|
| 178 |
adamw_betas=(0.9, 0.95),
|
| 179 |
adamw_eps=1e-8,
|
| 180 |
none_grad=True,
|
|
|
|
| 182 |
):
|
| 183 |
defaults = dict(
|
| 184 |
lr=lr,
|
| 185 |
+
weight_decay=weight_decay,
|
| 186 |
momentum=momentum,
|
| 187 |
nesterov=nesterov,
|
| 188 |
ns_steps=ns_steps,
|
|
|
|
| 276 |
|
| 277 |
return param_to_state, ordered_params
|
| 278 |
|
| 279 |
+
def base(self, params, group, lr, weight_decay, momentum):
|
| 280 |
# generate weight updates in distributed fashion
|
| 281 |
for p in params:
|
| 282 |
g = p.grad
|
|
|
|
| 303 |
adjusted_lr = self.adjust_lr_for_muon(lr, p.shape)
|
| 304 |
|
| 305 |
# apply weight decay
|
| 306 |
+
p.data.mul_(1 - lr * weight_decay)
|
| 307 |
|
| 308 |
# apply update
|
| 309 |
p.data.add_(u, alpha=-adjusted_lr)
|
|
|
|
| 321 |
g = buf
|
| 322 |
return g
|
| 323 |
|
| 324 |
+
def _update_p(self, p, u, lr, weight_decay):
|
| 325 |
# scale update
|
| 326 |
adjusted_lr = self.adjust_lr_for_muon(lr, p.shape)
|
| 327 |
# apply weight decay
|
| 328 |
+
p.data.mul_(1 - lr * weight_decay)
|
| 329 |
# apply update
|
| 330 |
p.data.add_(u, alpha=-adjusted_lr)
|
| 331 |
|
| 332 |
+
def parallel(self, params, group, lr, weight_decay, momentum):
|
| 333 |
"""
|
| 334 |
Perform a parallel optimization step using Muon.
|
| 335 |
"""
|
|
|
|
| 368 |
for p in ordered_params[start_idx : start_idx + chunk_size]:
|
| 369 |
state = param_to_state[id(p)]
|
| 370 |
adjusted_lr = self.adjust_lr_for_muon(lr, p.shape)
|
| 371 |
+
_scatter(
|
| 372 |
+
p, state, adjusted_lr, weight_decay, self.rank, self.comm_stream
|
| 373 |
+
)
|
| 374 |
|
| 375 |
chunk_size = params[0].device_mesh.mesh.numel()
|
| 376 |
|
|
|
|
| 404 |
|
| 405 |
params = [p for p in group["params"] if self.state[p]["use_muon"]]
|
| 406 |
lr = group["lr"]
|
| 407 |
+
weight_decay = group["weight_decay"]
|
| 408 |
momentum = group["momentum"]
|
| 409 |
|
| 410 |
+
param_dtensors = []
|
| 411 |
+
param_tensors = []
|
| 412 |
+
|
| 413 |
+
for p in params:
|
| 414 |
+
if p is None or p.grad is None:
|
| 415 |
+
continue
|
| 416 |
+
if isinstance(p.data, DTensor):
|
| 417 |
+
if all(
|
| 418 |
+
isinstance(placement, Replicate) for placement in p.placements
|
| 419 |
+
):
|
| 420 |
+
param_tensors.append(p)
|
| 421 |
+
else:
|
| 422 |
+
param_dtensors.append(p)
|
| 423 |
+
elif isinstance(p.data, torch.Tensor):
|
| 424 |
+
param_tensors.append(p)
|
| 425 |
+
else:
|
| 426 |
+
raise TypeError(f"Unsupported parameter type: {type(p.data)}")
|
| 427 |
+
|
| 428 |
+
if self.debug:
|
| 429 |
+
print(
|
| 430 |
+
f"[Muon] {len(param_dtensors)} DTensors, {len(param_tensors)} Tensors",
|
| 431 |
+
flush=True,
|
| 432 |
+
)
|
| 433 |
+
|
| 434 |
+
if len(param_dtensors) > 0:
|
| 435 |
self.parallel(
|
| 436 |
+
param_dtensors,
|
| 437 |
group,
|
| 438 |
lr=lr,
|
| 439 |
+
weight_decay=weight_decay,
|
| 440 |
momentum=momentum,
|
| 441 |
)
|
| 442 |
+
|
| 443 |
+
if len(param_tensors) > 0:
|
| 444 |
self.base(
|
| 445 |
+
param_tensors,
|
| 446 |
group,
|
| 447 |
lr=lr,
|
| 448 |
+
weight_decay=weight_decay,
|
| 449 |
momentum=momentum,
|
| 450 |
)
|
| 451 |
|
|
|
|
| 457 |
lr = group["lr"]
|
| 458 |
beta1, beta2 = group["adamw_betas"]
|
| 459 |
eps = group["adamw_eps"]
|
| 460 |
+
weight_decay = group["weight_decay"]
|
| 461 |
|
| 462 |
for p in params:
|
| 463 |
g = p.grad
|
build/torch26-cxx11-cu124-x86_64-linux/optimizer/_ops.py
CHANGED
|
@@ -1,9 +1,9 @@
|
|
| 1 |
import torch
|
| 2 |
-
from . import
|
| 3 |
-
ops = torch.ops.
|
| 4 |
|
| 5 |
def add_op_namespace_prefix(op_name: str):
|
| 6 |
"""
|
| 7 |
Prefix op by namespace.
|
| 8 |
"""
|
| 9 |
-
return f"
|
|
|
|
| 1 |
import torch
|
| 2 |
+
from . import _optimizer_64757cb_dirty
|
| 3 |
+
ops = torch.ops._optimizer_64757cb_dirty
|
| 4 |
|
| 5 |
def add_op_namespace_prefix(op_name: str):
|
| 6 |
"""
|
| 7 |
Prefix op by namespace.
|
| 8 |
"""
|
| 9 |
+
return f"_optimizer_64757cb_dirty::{op_name}"
|
build/torch26-cxx11-cu124-x86_64-linux/optimizer/{_optimizer_036642a_dirty.abi3.so β _optimizer_64757cb_dirty.abi3.so}
RENAMED
|
@@ -1,3 +1,3 @@
|
|
| 1 |
version https://git-lfs.github.com/spec/v1
|
| 2 |
-
oid sha256:
|
| 3 |
size 1824224
|
|
|
|
| 1 |
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:2921aa2aa2587e261dc9ca4e5f60303b0d1c9a305d1584918a8c56b6dc79ebfb
|
| 3 |
size 1824224
|
build/torch26-cxx11-cu124-x86_64-linux/optimizer/muon.py
CHANGED
|
@@ -3,7 +3,7 @@ from dataclasses import dataclass
|
|
| 3 |
|
| 4 |
import torch
|
| 5 |
import torch.distributed as dist
|
| 6 |
-
from torch.distributed._tensor import DTensor
|
| 7 |
|
| 8 |
|
| 9 |
# This code snippet is a modified version adapted from the following GitHub repositories:
|
|
@@ -103,7 +103,7 @@ def _compute_u(state, steps, rank, compute_stream):
|
|
| 103 |
|
| 104 |
|
| 105 |
@torch.no_grad()
|
| 106 |
-
def _scatter(p, state, lr,
|
| 107 |
u = state.computed_u
|
| 108 |
mesh = p.device_mesh
|
| 109 |
|
|
@@ -131,10 +131,14 @@ def _scatter(p, state, lr, wd, rank, comm_stream):
|
|
| 131 |
placements=p.placements,
|
| 132 |
device_mesh=mesh,
|
| 133 |
)
|
| 134 |
-
p.data.mul_(1 - lr *
|
| 135 |
p.data.add_(u, alpha=-lr)
|
| 136 |
|
| 137 |
|
|
|
|
|
|
|
|
|
|
|
|
|
| 138 |
class Muon(torch.optim.Optimizer):
|
| 139 |
"""
|
| 140 |
Muon - MomentUm Orthogonalized by Newton-schulz
|
|
@@ -159,18 +163,18 @@ class Muon(torch.optim.Optimizer):
|
|
| 159 |
adamw_lr: The learning rate for the internal AdamW.
|
| 160 |
adamw_betas: The betas for the internal AdamW.
|
| 161 |
adamw_eps: The epsilon for the internal AdamW.
|
| 162 |
-
|
| 163 |
"""
|
| 164 |
|
| 165 |
def __init__(
|
| 166 |
self,
|
| 167 |
model,
|
| 168 |
-
is_muon_func,
|
| 169 |
lr=1e-3,
|
| 170 |
momentum=0.95,
|
| 171 |
nesterov=True,
|
| 172 |
ns_steps=5,
|
| 173 |
-
|
| 174 |
adamw_betas=(0.9, 0.95),
|
| 175 |
adamw_eps=1e-8,
|
| 176 |
none_grad=True,
|
|
@@ -178,7 +182,7 @@ class Muon(torch.optim.Optimizer):
|
|
| 178 |
):
|
| 179 |
defaults = dict(
|
| 180 |
lr=lr,
|
| 181 |
-
|
| 182 |
momentum=momentum,
|
| 183 |
nesterov=nesterov,
|
| 184 |
ns_steps=ns_steps,
|
|
@@ -272,7 +276,7 @@ class Muon(torch.optim.Optimizer):
|
|
| 272 |
|
| 273 |
return param_to_state, ordered_params
|
| 274 |
|
| 275 |
-
def base(self, params, group, lr,
|
| 276 |
# generate weight updates in distributed fashion
|
| 277 |
for p in params:
|
| 278 |
g = p.grad
|
|
@@ -299,7 +303,7 @@ class Muon(torch.optim.Optimizer):
|
|
| 299 |
adjusted_lr = self.adjust_lr_for_muon(lr, p.shape)
|
| 300 |
|
| 301 |
# apply weight decay
|
| 302 |
-
p.data.mul_(1 - lr *
|
| 303 |
|
| 304 |
# apply update
|
| 305 |
p.data.add_(u, alpha=-adjusted_lr)
|
|
@@ -317,15 +321,15 @@ class Muon(torch.optim.Optimizer):
|
|
| 317 |
g = buf
|
| 318 |
return g
|
| 319 |
|
| 320 |
-
def _update_p(self, p, u, lr,
|
| 321 |
# scale update
|
| 322 |
adjusted_lr = self.adjust_lr_for_muon(lr, p.shape)
|
| 323 |
# apply weight decay
|
| 324 |
-
p.data.mul_(1 - lr *
|
| 325 |
# apply update
|
| 326 |
p.data.add_(u, alpha=-adjusted_lr)
|
| 327 |
|
| 328 |
-
def parallel(self, params, group, lr,
|
| 329 |
"""
|
| 330 |
Perform a parallel optimization step using Muon.
|
| 331 |
"""
|
|
@@ -364,7 +368,9 @@ class Muon(torch.optim.Optimizer):
|
|
| 364 |
for p in ordered_params[start_idx : start_idx + chunk_size]:
|
| 365 |
state = param_to_state[id(p)]
|
| 366 |
adjusted_lr = self.adjust_lr_for_muon(lr, p.shape)
|
| 367 |
-
_scatter(
|
|
|
|
|
|
|
| 368 |
|
| 369 |
chunk_size = params[0].device_mesh.mesh.numel()
|
| 370 |
|
|
@@ -398,23 +404,48 @@ class Muon(torch.optim.Optimizer):
|
|
| 398 |
|
| 399 |
params = [p for p in group["params"] if self.state[p]["use_muon"]]
|
| 400 |
lr = group["lr"]
|
| 401 |
-
|
| 402 |
momentum = group["momentum"]
|
| 403 |
|
| 404 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 405 |
self.parallel(
|
| 406 |
-
|
| 407 |
group,
|
| 408 |
lr=lr,
|
| 409 |
-
|
| 410 |
momentum=momentum,
|
| 411 |
)
|
| 412 |
-
|
|
|
|
| 413 |
self.base(
|
| 414 |
-
|
| 415 |
group,
|
| 416 |
lr=lr,
|
| 417 |
-
|
| 418 |
momentum=momentum,
|
| 419 |
)
|
| 420 |
|
|
@@ -426,7 +457,7 @@ class Muon(torch.optim.Optimizer):
|
|
| 426 |
lr = group["lr"]
|
| 427 |
beta1, beta2 = group["adamw_betas"]
|
| 428 |
eps = group["adamw_eps"]
|
| 429 |
-
weight_decay = group["
|
| 430 |
|
| 431 |
for p in params:
|
| 432 |
g = p.grad
|
|
|
|
| 3 |
|
| 4 |
import torch
|
| 5 |
import torch.distributed as dist
|
| 6 |
+
from torch.distributed._tensor import DTensor, Replicate
|
| 7 |
|
| 8 |
|
| 9 |
# This code snippet is a modified version adapted from the following GitHub repositories:
|
|
|
|
| 103 |
|
| 104 |
|
| 105 |
@torch.no_grad()
|
| 106 |
+
def _scatter(p, state, lr, weight_decay, rank, comm_stream):
|
| 107 |
u = state.computed_u
|
| 108 |
mesh = p.device_mesh
|
| 109 |
|
|
|
|
| 131 |
placements=p.placements,
|
| 132 |
device_mesh=mesh,
|
| 133 |
)
|
| 134 |
+
p.data.mul_(1 - lr * weight_decay)
|
| 135 |
p.data.add_(u, alpha=-lr)
|
| 136 |
|
| 137 |
|
| 138 |
+
def default_is_muon(x, name):
|
| 139 |
+
return x.ndim >= 2 and "embed_tokens" not in name and "lm_head" not in name
|
| 140 |
+
|
| 141 |
+
|
| 142 |
class Muon(torch.optim.Optimizer):
|
| 143 |
"""
|
| 144 |
Muon - MomentUm Orthogonalized by Newton-schulz
|
|
|
|
| 163 |
adamw_lr: The learning rate for the internal AdamW.
|
| 164 |
adamw_betas: The betas for the internal AdamW.
|
| 165 |
adamw_eps: The epsilon for the internal AdamW.
|
| 166 |
+
adamw_weight_decay: The weight decay for the internal AdamW.
|
| 167 |
"""
|
| 168 |
|
| 169 |
def __init__(
|
| 170 |
self,
|
| 171 |
model,
|
| 172 |
+
is_muon_func=default_is_muon,
|
| 173 |
lr=1e-3,
|
| 174 |
momentum=0.95,
|
| 175 |
nesterov=True,
|
| 176 |
ns_steps=5,
|
| 177 |
+
weight_decay=0.1,
|
| 178 |
adamw_betas=(0.9, 0.95),
|
| 179 |
adamw_eps=1e-8,
|
| 180 |
none_grad=True,
|
|
|
|
| 182 |
):
|
| 183 |
defaults = dict(
|
| 184 |
lr=lr,
|
| 185 |
+
weight_decay=weight_decay,
|
| 186 |
momentum=momentum,
|
| 187 |
nesterov=nesterov,
|
| 188 |
ns_steps=ns_steps,
|
|
|
|
| 276 |
|
| 277 |
return param_to_state, ordered_params
|
| 278 |
|
| 279 |
+
def base(self, params, group, lr, weight_decay, momentum):
|
| 280 |
# generate weight updates in distributed fashion
|
| 281 |
for p in params:
|
| 282 |
g = p.grad
|
|
|
|
| 303 |
adjusted_lr = self.adjust_lr_for_muon(lr, p.shape)
|
| 304 |
|
| 305 |
# apply weight decay
|
| 306 |
+
p.data.mul_(1 - lr * weight_decay)
|
| 307 |
|
| 308 |
# apply update
|
| 309 |
p.data.add_(u, alpha=-adjusted_lr)
|
|
|
|
| 321 |
g = buf
|
| 322 |
return g
|
| 323 |
|
| 324 |
+
def _update_p(self, p, u, lr, weight_decay):
|
| 325 |
# scale update
|
| 326 |
adjusted_lr = self.adjust_lr_for_muon(lr, p.shape)
|
| 327 |
# apply weight decay
|
| 328 |
+
p.data.mul_(1 - lr * weight_decay)
|
| 329 |
# apply update
|
| 330 |
p.data.add_(u, alpha=-adjusted_lr)
|
| 331 |
|
| 332 |
+
def parallel(self, params, group, lr, weight_decay, momentum):
|
| 333 |
"""
|
| 334 |
Perform a parallel optimization step using Muon.
|
| 335 |
"""
|
|
|
|
| 368 |
for p in ordered_params[start_idx : start_idx + chunk_size]:
|
| 369 |
state = param_to_state[id(p)]
|
| 370 |
adjusted_lr = self.adjust_lr_for_muon(lr, p.shape)
|
| 371 |
+
_scatter(
|
| 372 |
+
p, state, adjusted_lr, weight_decay, self.rank, self.comm_stream
|
| 373 |
+
)
|
| 374 |
|
| 375 |
chunk_size = params[0].device_mesh.mesh.numel()
|
| 376 |
|
|
|
|
| 404 |
|
| 405 |
params = [p for p in group["params"] if self.state[p]["use_muon"]]
|
| 406 |
lr = group["lr"]
|
| 407 |
+
weight_decay = group["weight_decay"]
|
| 408 |
momentum = group["momentum"]
|
| 409 |
|
| 410 |
+
param_dtensors = []
|
| 411 |
+
param_tensors = []
|
| 412 |
+
|
| 413 |
+
for p in params:
|
| 414 |
+
if p is None or p.grad is None:
|
| 415 |
+
continue
|
| 416 |
+
if isinstance(p.data, DTensor):
|
| 417 |
+
if all(
|
| 418 |
+
isinstance(placement, Replicate) for placement in p.placements
|
| 419 |
+
):
|
| 420 |
+
param_tensors.append(p)
|
| 421 |
+
else:
|
| 422 |
+
param_dtensors.append(p)
|
| 423 |
+
elif isinstance(p.data, torch.Tensor):
|
| 424 |
+
param_tensors.append(p)
|
| 425 |
+
else:
|
| 426 |
+
raise TypeError(f"Unsupported parameter type: {type(p.data)}")
|
| 427 |
+
|
| 428 |
+
if self.debug:
|
| 429 |
+
print(
|
| 430 |
+
f"[Muon] {len(param_dtensors)} DTensors, {len(param_tensors)} Tensors",
|
| 431 |
+
flush=True,
|
| 432 |
+
)
|
| 433 |
+
|
| 434 |
+
if len(param_dtensors) > 0:
|
| 435 |
self.parallel(
|
| 436 |
+
param_dtensors,
|
| 437 |
group,
|
| 438 |
lr=lr,
|
| 439 |
+
weight_decay=weight_decay,
|
| 440 |
momentum=momentum,
|
| 441 |
)
|
| 442 |
+
|
| 443 |
+
if len(param_tensors) > 0:
|
| 444 |
self.base(
|
| 445 |
+
param_tensors,
|
| 446 |
group,
|
| 447 |
lr=lr,
|
| 448 |
+
weight_decay=weight_decay,
|
| 449 |
momentum=momentum,
|
| 450 |
)
|
| 451 |
|
|
|
|
| 457 |
lr = group["lr"]
|
| 458 |
beta1, beta2 = group["adamw_betas"]
|
| 459 |
eps = group["adamw_eps"]
|
| 460 |
+
weight_decay = group["weight_decay"]
|
| 461 |
|
| 462 |
for p in params:
|
| 463 |
g = p.grad
|
build/torch26-cxx11-cu126-x86_64-linux/optimizer/_ops.py
CHANGED
|
@@ -1,9 +1,9 @@
|
|
| 1 |
import torch
|
| 2 |
-
from . import
|
| 3 |
-
ops = torch.ops.
|
| 4 |
|
| 5 |
def add_op_namespace_prefix(op_name: str):
|
| 6 |
"""
|
| 7 |
Prefix op by namespace.
|
| 8 |
"""
|
| 9 |
-
return f"
|
|
|
|
| 1 |
import torch
|
| 2 |
+
from . import _optimizer_64757cb_dirty
|
| 3 |
+
ops = torch.ops._optimizer_64757cb_dirty
|
| 4 |
|
| 5 |
def add_op_namespace_prefix(op_name: str):
|
| 6 |
"""
|
| 7 |
Prefix op by namespace.
|
| 8 |
"""
|
| 9 |
+
return f"_optimizer_64757cb_dirty::{op_name}"
|
build/torch26-cxx11-cu126-x86_64-linux/optimizer/{_optimizer_036642a_dirty.abi3.so β _optimizer_64757cb_dirty.abi3.so}
RENAMED
|
@@ -1,3 +1,3 @@
|
|
| 1 |
version https://git-lfs.github.com/spec/v1
|
| 2 |
-
oid sha256:
|
| 3 |
size 1824224
|
|
|
|
| 1 |
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:a93530e6981fdac23236dd7e3657c5b47513cda4accec78293234ce5f233400b
|
| 3 |
size 1824224
|
build/torch26-cxx11-cu126-x86_64-linux/optimizer/muon.py
CHANGED
|
@@ -3,7 +3,7 @@ from dataclasses import dataclass
|
|
| 3 |
|
| 4 |
import torch
|
| 5 |
import torch.distributed as dist
|
| 6 |
-
from torch.distributed._tensor import DTensor
|
| 7 |
|
| 8 |
|
| 9 |
# This code snippet is a modified version adapted from the following GitHub repositories:
|
|
@@ -103,7 +103,7 @@ def _compute_u(state, steps, rank, compute_stream):
|
|
| 103 |
|
| 104 |
|
| 105 |
@torch.no_grad()
|
| 106 |
-
def _scatter(p, state, lr,
|
| 107 |
u = state.computed_u
|
| 108 |
mesh = p.device_mesh
|
| 109 |
|
|
@@ -131,10 +131,14 @@ def _scatter(p, state, lr, wd, rank, comm_stream):
|
|
| 131 |
placements=p.placements,
|
| 132 |
device_mesh=mesh,
|
| 133 |
)
|
| 134 |
-
p.data.mul_(1 - lr *
|
| 135 |
p.data.add_(u, alpha=-lr)
|
| 136 |
|
| 137 |
|
|
|
|
|
|
|
|
|
|
|
|
|
| 138 |
class Muon(torch.optim.Optimizer):
|
| 139 |
"""
|
| 140 |
Muon - MomentUm Orthogonalized by Newton-schulz
|
|
@@ -159,18 +163,18 @@ class Muon(torch.optim.Optimizer):
|
|
| 159 |
adamw_lr: The learning rate for the internal AdamW.
|
| 160 |
adamw_betas: The betas for the internal AdamW.
|
| 161 |
adamw_eps: The epsilon for the internal AdamW.
|
| 162 |
-
|
| 163 |
"""
|
| 164 |
|
| 165 |
def __init__(
|
| 166 |
self,
|
| 167 |
model,
|
| 168 |
-
is_muon_func,
|
| 169 |
lr=1e-3,
|
| 170 |
momentum=0.95,
|
| 171 |
nesterov=True,
|
| 172 |
ns_steps=5,
|
| 173 |
-
|
| 174 |
adamw_betas=(0.9, 0.95),
|
| 175 |
adamw_eps=1e-8,
|
| 176 |
none_grad=True,
|
|
@@ -178,7 +182,7 @@ class Muon(torch.optim.Optimizer):
|
|
| 178 |
):
|
| 179 |
defaults = dict(
|
| 180 |
lr=lr,
|
| 181 |
-
|
| 182 |
momentum=momentum,
|
| 183 |
nesterov=nesterov,
|
| 184 |
ns_steps=ns_steps,
|
|
@@ -272,7 +276,7 @@ class Muon(torch.optim.Optimizer):
|
|
| 272 |
|
| 273 |
return param_to_state, ordered_params
|
| 274 |
|
| 275 |
-
def base(self, params, group, lr,
|
| 276 |
# generate weight updates in distributed fashion
|
| 277 |
for p in params:
|
| 278 |
g = p.grad
|
|
@@ -299,7 +303,7 @@ class Muon(torch.optim.Optimizer):
|
|
| 299 |
adjusted_lr = self.adjust_lr_for_muon(lr, p.shape)
|
| 300 |
|
| 301 |
# apply weight decay
|
| 302 |
-
p.data.mul_(1 - lr *
|
| 303 |
|
| 304 |
# apply update
|
| 305 |
p.data.add_(u, alpha=-adjusted_lr)
|
|
@@ -317,15 +321,15 @@ class Muon(torch.optim.Optimizer):
|
|
| 317 |
g = buf
|
| 318 |
return g
|
| 319 |
|
| 320 |
-
def _update_p(self, p, u, lr,
|
| 321 |
# scale update
|
| 322 |
adjusted_lr = self.adjust_lr_for_muon(lr, p.shape)
|
| 323 |
# apply weight decay
|
| 324 |
-
p.data.mul_(1 - lr *
|
| 325 |
# apply update
|
| 326 |
p.data.add_(u, alpha=-adjusted_lr)
|
| 327 |
|
| 328 |
-
def parallel(self, params, group, lr,
|
| 329 |
"""
|
| 330 |
Perform a parallel optimization step using Muon.
|
| 331 |
"""
|
|
@@ -364,7 +368,9 @@ class Muon(torch.optim.Optimizer):
|
|
| 364 |
for p in ordered_params[start_idx : start_idx + chunk_size]:
|
| 365 |
state = param_to_state[id(p)]
|
| 366 |
adjusted_lr = self.adjust_lr_for_muon(lr, p.shape)
|
| 367 |
-
_scatter(
|
|
|
|
|
|
|
| 368 |
|
| 369 |
chunk_size = params[0].device_mesh.mesh.numel()
|
| 370 |
|
|
@@ -398,23 +404,48 @@ class Muon(torch.optim.Optimizer):
|
|
| 398 |
|
| 399 |
params = [p for p in group["params"] if self.state[p]["use_muon"]]
|
| 400 |
lr = group["lr"]
|
| 401 |
-
|
| 402 |
momentum = group["momentum"]
|
| 403 |
|
| 404 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 405 |
self.parallel(
|
| 406 |
-
|
| 407 |
group,
|
| 408 |
lr=lr,
|
| 409 |
-
|
| 410 |
momentum=momentum,
|
| 411 |
)
|
| 412 |
-
|
|
|
|
| 413 |
self.base(
|
| 414 |
-
|
| 415 |
group,
|
| 416 |
lr=lr,
|
| 417 |
-
|
| 418 |
momentum=momentum,
|
| 419 |
)
|
| 420 |
|
|
@@ -426,7 +457,7 @@ class Muon(torch.optim.Optimizer):
|
|
| 426 |
lr = group["lr"]
|
| 427 |
beta1, beta2 = group["adamw_betas"]
|
| 428 |
eps = group["adamw_eps"]
|
| 429 |
-
weight_decay = group["
|
| 430 |
|
| 431 |
for p in params:
|
| 432 |
g = p.grad
|
|
|
|
| 3 |
|
| 4 |
import torch
|
| 5 |
import torch.distributed as dist
|
| 6 |
+
from torch.distributed._tensor import DTensor, Replicate
|
| 7 |
|
| 8 |
|
| 9 |
# This code snippet is a modified version adapted from the following GitHub repositories:
|
|
|
|
| 103 |
|
| 104 |
|
| 105 |
@torch.no_grad()
|
| 106 |
+
def _scatter(p, state, lr, weight_decay, rank, comm_stream):
|
| 107 |
u = state.computed_u
|
| 108 |
mesh = p.device_mesh
|
| 109 |
|
|
|
|
| 131 |
placements=p.placements,
|
| 132 |
device_mesh=mesh,
|
| 133 |
)
|
| 134 |
+
p.data.mul_(1 - lr * weight_decay)
|
| 135 |
p.data.add_(u, alpha=-lr)
|
| 136 |
|
| 137 |
|
| 138 |
+
def default_is_muon(x, name):
|
| 139 |
+
return x.ndim >= 2 and "embed_tokens" not in name and "lm_head" not in name
|
| 140 |
+
|
| 141 |
+
|
| 142 |
class Muon(torch.optim.Optimizer):
|
| 143 |
"""
|
| 144 |
Muon - MomentUm Orthogonalized by Newton-schulz
|
|
|
|
| 163 |
adamw_lr: The learning rate for the internal AdamW.
|
| 164 |
adamw_betas: The betas for the internal AdamW.
|
| 165 |
adamw_eps: The epsilon for the internal AdamW.
|
| 166 |
+
adamw_weight_decay: The weight decay for the internal AdamW.
|
| 167 |
"""
|
| 168 |
|
| 169 |
def __init__(
|
| 170 |
self,
|
| 171 |
model,
|
| 172 |
+
is_muon_func=default_is_muon,
|
| 173 |
lr=1e-3,
|
| 174 |
momentum=0.95,
|
| 175 |
nesterov=True,
|
| 176 |
ns_steps=5,
|
| 177 |
+
weight_decay=0.1,
|
| 178 |
adamw_betas=(0.9, 0.95),
|
| 179 |
adamw_eps=1e-8,
|
| 180 |
none_grad=True,
|
|
|
|
| 182 |
):
|
| 183 |
defaults = dict(
|
| 184 |
lr=lr,
|
| 185 |
+
weight_decay=weight_decay,
|
| 186 |
momentum=momentum,
|
| 187 |
nesterov=nesterov,
|
| 188 |
ns_steps=ns_steps,
|
|
|
|
| 276 |
|
| 277 |
return param_to_state, ordered_params
|
| 278 |
|
| 279 |
+
def base(self, params, group, lr, weight_decay, momentum):
|
| 280 |
# generate weight updates in distributed fashion
|
| 281 |
for p in params:
|
| 282 |
g = p.grad
|
|
|
|
| 303 |
adjusted_lr = self.adjust_lr_for_muon(lr, p.shape)
|
| 304 |
|
| 305 |
# apply weight decay
|
| 306 |
+
p.data.mul_(1 - lr * weight_decay)
|
| 307 |
|
| 308 |
# apply update
|
| 309 |
p.data.add_(u, alpha=-adjusted_lr)
|
|
|
|
| 321 |
g = buf
|
| 322 |
return g
|
| 323 |
|
| 324 |
+
def _update_p(self, p, u, lr, weight_decay):
|
| 325 |
# scale update
|
| 326 |
adjusted_lr = self.adjust_lr_for_muon(lr, p.shape)
|
| 327 |
# apply weight decay
|
| 328 |
+
p.data.mul_(1 - lr * weight_decay)
|
| 329 |
# apply update
|
| 330 |
p.data.add_(u, alpha=-adjusted_lr)
|
| 331 |
|
| 332 |
+
def parallel(self, params, group, lr, weight_decay, momentum):
|
| 333 |
"""
|
| 334 |
Perform a parallel optimization step using Muon.
|
| 335 |
"""
|
|
|
|
| 368 |
for p in ordered_params[start_idx : start_idx + chunk_size]:
|
| 369 |
state = param_to_state[id(p)]
|
| 370 |
adjusted_lr = self.adjust_lr_for_muon(lr, p.shape)
|
| 371 |
+
_scatter(
|
| 372 |
+
p, state, adjusted_lr, weight_decay, self.rank, self.comm_stream
|
| 373 |
+
)
|
| 374 |
|
| 375 |
chunk_size = params[0].device_mesh.mesh.numel()
|
| 376 |
|
|
|
|
| 404 |
|
| 405 |
params = [p for p in group["params"] if self.state[p]["use_muon"]]
|
| 406 |
lr = group["lr"]
|
| 407 |
+
weight_decay = group["weight_decay"]
|
| 408 |
momentum = group["momentum"]
|
| 409 |
|
| 410 |
+
param_dtensors = []
|
| 411 |
+
param_tensors = []
|
| 412 |
+
|
| 413 |
+
for p in params:
|
| 414 |
+
if p is None or p.grad is None:
|
| 415 |
+
continue
|
| 416 |
+
if isinstance(p.data, DTensor):
|
| 417 |
+
if all(
|
| 418 |
+
isinstance(placement, Replicate) for placement in p.placements
|
| 419 |
+
):
|
| 420 |
+
param_tensors.append(p)
|
| 421 |
+
else:
|
| 422 |
+
param_dtensors.append(p)
|
| 423 |
+
elif isinstance(p.data, torch.Tensor):
|
| 424 |
+
param_tensors.append(p)
|
| 425 |
+
else:
|
| 426 |
+
raise TypeError(f"Unsupported parameter type: {type(p.data)}")
|
| 427 |
+
|
| 428 |
+
if self.debug:
|
| 429 |
+
print(
|
| 430 |
+
f"[Muon] {len(param_dtensors)} DTensors, {len(param_tensors)} Tensors",
|
| 431 |
+
flush=True,
|
| 432 |
+
)
|
| 433 |
+
|
| 434 |
+
if len(param_dtensors) > 0:
|
| 435 |
self.parallel(
|
| 436 |
+
param_dtensors,
|
| 437 |
group,
|
| 438 |
lr=lr,
|
| 439 |
+
weight_decay=weight_decay,
|
| 440 |
momentum=momentum,
|
| 441 |
)
|
| 442 |
+
|
| 443 |
+
if len(param_tensors) > 0:
|
| 444 |
self.base(
|
| 445 |
+
param_tensors,
|
| 446 |
group,
|
| 447 |
lr=lr,
|
| 448 |
+
weight_decay=weight_decay,
|
| 449 |
momentum=momentum,
|
| 450 |
)
|
| 451 |
|
|
|
|
| 457 |
lr = group["lr"]
|
| 458 |
beta1, beta2 = group["adamw_betas"]
|
| 459 |
eps = group["adamw_eps"]
|
| 460 |
+
weight_decay = group["weight_decay"]
|
| 461 |
|
| 462 |
for p in params:
|
| 463 |
g = p.grad
|
build/torch26-cxx11-rocm62-x86_64-linux/optimizer/_ops.py
CHANGED
|
@@ -1,9 +1,9 @@
|
|
| 1 |
import torch
|
| 2 |
-
from . import
|
| 3 |
-
ops = torch.ops.
|
| 4 |
|
| 5 |
def add_op_namespace_prefix(op_name: str):
|
| 6 |
"""
|
| 7 |
Prefix op by namespace.
|
| 8 |
"""
|
| 9 |
-
return f"
|
|
|
|
| 1 |
import torch
|
| 2 |
+
from . import _optimizer_64757cb_dirty
|
| 3 |
+
ops = torch.ops._optimizer_64757cb_dirty
|
| 4 |
|
| 5 |
def add_op_namespace_prefix(op_name: str):
|
| 6 |
"""
|
| 7 |
Prefix op by namespace.
|
| 8 |
"""
|
| 9 |
+
return f"_optimizer_64757cb_dirty::{op_name}"
|
build/torch26-cxx11-rocm62-x86_64-linux/optimizer/{_optimizer_036642a_dirty.abi3.so β _optimizer_64757cb_dirty.abi3.so}
RENAMED
|
@@ -1,3 +1,3 @@
|
|
| 1 |
version https://git-lfs.github.com/spec/v1
|
| 2 |
-
oid sha256:
|
| 3 |
size 1749744
|
|
|
|
| 1 |
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:caa40905ac8f209fecccae42c6892c3766ad5c7069382e60d2339e73da6ee7d6
|
| 3 |
size 1749744
|
build/torch26-cxx11-rocm62-x86_64-linux/optimizer/muon.py
CHANGED
|
@@ -3,7 +3,7 @@ from dataclasses import dataclass
|
|
| 3 |
|
| 4 |
import torch
|
| 5 |
import torch.distributed as dist
|
| 6 |
-
from torch.distributed._tensor import DTensor
|
| 7 |
|
| 8 |
|
| 9 |
# This code snippet is a modified version adapted from the following GitHub repositories:
|
|
@@ -103,7 +103,7 @@ def _compute_u(state, steps, rank, compute_stream):
|
|
| 103 |
|
| 104 |
|
| 105 |
@torch.no_grad()
|
| 106 |
-
def _scatter(p, state, lr,
|
| 107 |
u = state.computed_u
|
| 108 |
mesh = p.device_mesh
|
| 109 |
|
|
@@ -131,10 +131,14 @@ def _scatter(p, state, lr, wd, rank, comm_stream):
|
|
| 131 |
placements=p.placements,
|
| 132 |
device_mesh=mesh,
|
| 133 |
)
|
| 134 |
-
p.data.mul_(1 - lr *
|
| 135 |
p.data.add_(u, alpha=-lr)
|
| 136 |
|
| 137 |
|
|
|
|
|
|
|
|
|
|
|
|
|
| 138 |
class Muon(torch.optim.Optimizer):
|
| 139 |
"""
|
| 140 |
Muon - MomentUm Orthogonalized by Newton-schulz
|
|
@@ -159,18 +163,18 @@ class Muon(torch.optim.Optimizer):
|
|
| 159 |
adamw_lr: The learning rate for the internal AdamW.
|
| 160 |
adamw_betas: The betas for the internal AdamW.
|
| 161 |
adamw_eps: The epsilon for the internal AdamW.
|
| 162 |
-
|
| 163 |
"""
|
| 164 |
|
| 165 |
def __init__(
|
| 166 |
self,
|
| 167 |
model,
|
| 168 |
-
is_muon_func,
|
| 169 |
lr=1e-3,
|
| 170 |
momentum=0.95,
|
| 171 |
nesterov=True,
|
| 172 |
ns_steps=5,
|
| 173 |
-
|
| 174 |
adamw_betas=(0.9, 0.95),
|
| 175 |
adamw_eps=1e-8,
|
| 176 |
none_grad=True,
|
|
@@ -178,7 +182,7 @@ class Muon(torch.optim.Optimizer):
|
|
| 178 |
):
|
| 179 |
defaults = dict(
|
| 180 |
lr=lr,
|
| 181 |
-
|
| 182 |
momentum=momentum,
|
| 183 |
nesterov=nesterov,
|
| 184 |
ns_steps=ns_steps,
|
|
@@ -272,7 +276,7 @@ class Muon(torch.optim.Optimizer):
|
|
| 272 |
|
| 273 |
return param_to_state, ordered_params
|
| 274 |
|
| 275 |
-
def base(self, params, group, lr,
|
| 276 |
# generate weight updates in distributed fashion
|
| 277 |
for p in params:
|
| 278 |
g = p.grad
|
|
@@ -299,7 +303,7 @@ class Muon(torch.optim.Optimizer):
|
|
| 299 |
adjusted_lr = self.adjust_lr_for_muon(lr, p.shape)
|
| 300 |
|
| 301 |
# apply weight decay
|
| 302 |
-
p.data.mul_(1 - lr *
|
| 303 |
|
| 304 |
# apply update
|
| 305 |
p.data.add_(u, alpha=-adjusted_lr)
|
|
@@ -317,15 +321,15 @@ class Muon(torch.optim.Optimizer):
|
|
| 317 |
g = buf
|
| 318 |
return g
|
| 319 |
|
| 320 |
-
def _update_p(self, p, u, lr,
|
| 321 |
# scale update
|
| 322 |
adjusted_lr = self.adjust_lr_for_muon(lr, p.shape)
|
| 323 |
# apply weight decay
|
| 324 |
-
p.data.mul_(1 - lr *
|
| 325 |
# apply update
|
| 326 |
p.data.add_(u, alpha=-adjusted_lr)
|
| 327 |
|
| 328 |
-
def parallel(self, params, group, lr,
|
| 329 |
"""
|
| 330 |
Perform a parallel optimization step using Muon.
|
| 331 |
"""
|
|
@@ -364,7 +368,9 @@ class Muon(torch.optim.Optimizer):
|
|
| 364 |
for p in ordered_params[start_idx : start_idx + chunk_size]:
|
| 365 |
state = param_to_state[id(p)]
|
| 366 |
adjusted_lr = self.adjust_lr_for_muon(lr, p.shape)
|
| 367 |
-
_scatter(
|
|
|
|
|
|
|
| 368 |
|
| 369 |
chunk_size = params[0].device_mesh.mesh.numel()
|
| 370 |
|
|
@@ -398,23 +404,48 @@ class Muon(torch.optim.Optimizer):
|
|
| 398 |
|
| 399 |
params = [p for p in group["params"] if self.state[p]["use_muon"]]
|
| 400 |
lr = group["lr"]
|
| 401 |
-
|
| 402 |
momentum = group["momentum"]
|
| 403 |
|
| 404 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 405 |
self.parallel(
|
| 406 |
-
|
| 407 |
group,
|
| 408 |
lr=lr,
|
| 409 |
-
|
| 410 |
momentum=momentum,
|
| 411 |
)
|
| 412 |
-
|
|
|
|
| 413 |
self.base(
|
| 414 |
-
|
| 415 |
group,
|
| 416 |
lr=lr,
|
| 417 |
-
|
| 418 |
momentum=momentum,
|
| 419 |
)
|
| 420 |
|
|
@@ -426,7 +457,7 @@ class Muon(torch.optim.Optimizer):
|
|
| 426 |
lr = group["lr"]
|
| 427 |
beta1, beta2 = group["adamw_betas"]
|
| 428 |
eps = group["adamw_eps"]
|
| 429 |
-
weight_decay = group["
|
| 430 |
|
| 431 |
for p in params:
|
| 432 |
g = p.grad
|
|
|
|
| 3 |
|
| 4 |
import torch
|
| 5 |
import torch.distributed as dist
|
| 6 |
+
from torch.distributed._tensor import DTensor, Replicate
|
| 7 |
|
| 8 |
|
| 9 |
# This code snippet is a modified version adapted from the following GitHub repositories:
|
|
|
|
| 103 |
|
| 104 |
|
| 105 |
@torch.no_grad()
|
| 106 |
+
def _scatter(p, state, lr, weight_decay, rank, comm_stream):
|
| 107 |
u = state.computed_u
|
| 108 |
mesh = p.device_mesh
|
| 109 |
|
|
|
|
| 131 |
placements=p.placements,
|
| 132 |
device_mesh=mesh,
|
| 133 |
)
|
| 134 |
+
p.data.mul_(1 - lr * weight_decay)
|
| 135 |
p.data.add_(u, alpha=-lr)
|
| 136 |
|
| 137 |
|
| 138 |
+
def default_is_muon(x, name):
|
| 139 |
+
return x.ndim >= 2 and "embed_tokens" not in name and "lm_head" not in name
|
| 140 |
+
|
| 141 |
+
|
| 142 |
class Muon(torch.optim.Optimizer):
|
| 143 |
"""
|
| 144 |
Muon - MomentUm Orthogonalized by Newton-schulz
|
|
|
|
| 163 |
adamw_lr: The learning rate for the internal AdamW.
|
| 164 |
adamw_betas: The betas for the internal AdamW.
|
| 165 |
adamw_eps: The epsilon for the internal AdamW.
|
| 166 |
+
adamw_weight_decay: The weight decay for the internal AdamW.
|
| 167 |
"""
|
| 168 |
|
| 169 |
def __init__(
|
| 170 |
self,
|
| 171 |
model,
|
| 172 |
+
is_muon_func=default_is_muon,
|
| 173 |
lr=1e-3,
|
| 174 |
momentum=0.95,
|
| 175 |
nesterov=True,
|
| 176 |
ns_steps=5,
|
| 177 |
+
weight_decay=0.1,
|
| 178 |
adamw_betas=(0.9, 0.95),
|
| 179 |
adamw_eps=1e-8,
|
| 180 |
none_grad=True,
|
|
|
|
| 182 |
):
|
| 183 |
defaults = dict(
|
| 184 |
lr=lr,
|
| 185 |
+
weight_decay=weight_decay,
|
| 186 |
momentum=momentum,
|
| 187 |
nesterov=nesterov,
|
| 188 |
ns_steps=ns_steps,
|
|
|
|
| 276 |
|
| 277 |
return param_to_state, ordered_params
|
| 278 |
|
| 279 |
+
def base(self, params, group, lr, weight_decay, momentum):
|
| 280 |
# generate weight updates in distributed fashion
|
| 281 |
for p in params:
|
| 282 |
g = p.grad
|
|
|
|
| 303 |
adjusted_lr = self.adjust_lr_for_muon(lr, p.shape)
|
| 304 |
|
| 305 |
# apply weight decay
|
| 306 |
+
p.data.mul_(1 - lr * weight_decay)
|
| 307 |
|
| 308 |
# apply update
|
| 309 |
p.data.add_(u, alpha=-adjusted_lr)
|
|
|
|
| 321 |
g = buf
|
| 322 |
return g
|
| 323 |
|
| 324 |
+
def _update_p(self, p, u, lr, weight_decay):
|
| 325 |
# scale update
|
| 326 |
adjusted_lr = self.adjust_lr_for_muon(lr, p.shape)
|
| 327 |
# apply weight decay
|
| 328 |
+
p.data.mul_(1 - lr * weight_decay)
|
| 329 |
# apply update
|
| 330 |
p.data.add_(u, alpha=-adjusted_lr)
|
| 331 |
|
| 332 |
+
def parallel(self, params, group, lr, weight_decay, momentum):
|
| 333 |
"""
|
| 334 |
Perform a parallel optimization step using Muon.
|
| 335 |
"""
|
|
|
|
| 368 |
for p in ordered_params[start_idx : start_idx + chunk_size]:
|
| 369 |
state = param_to_state[id(p)]
|
| 370 |
adjusted_lr = self.adjust_lr_for_muon(lr, p.shape)
|
| 371 |
+
_scatter(
|
| 372 |
+
p, state, adjusted_lr, weight_decay, self.rank, self.comm_stream
|
| 373 |
+
)
|
| 374 |
|
| 375 |
chunk_size = params[0].device_mesh.mesh.numel()
|
| 376 |
|
|
|
|
| 404 |
|
| 405 |
params = [p for p in group["params"] if self.state[p]["use_muon"]]
|
| 406 |
lr = group["lr"]
|
| 407 |
+
weight_decay = group["weight_decay"]
|
| 408 |
momentum = group["momentum"]
|
| 409 |
|
| 410 |
+
param_dtensors = []
|
| 411 |
+
param_tensors = []
|
| 412 |
+
|
| 413 |
+
for p in params:
|
| 414 |
+
if p is None or p.grad is None:
|
| 415 |
+
continue
|
| 416 |
+
if isinstance(p.data, DTensor):
|
| 417 |
+
if all(
|
| 418 |
+
isinstance(placement, Replicate) for placement in p.placements
|
| 419 |
+
):
|
| 420 |
+
param_tensors.append(p)
|
| 421 |
+
else:
|
| 422 |
+
param_dtensors.append(p)
|
| 423 |
+
elif isinstance(p.data, torch.Tensor):
|
| 424 |
+
param_tensors.append(p)
|
| 425 |
+
else:
|
| 426 |
+
raise TypeError(f"Unsupported parameter type: {type(p.data)}")
|
| 427 |
+
|
| 428 |
+
if self.debug:
|
| 429 |
+
print(
|
| 430 |
+
f"[Muon] {len(param_dtensors)} DTensors, {len(param_tensors)} Tensors",
|
| 431 |
+
flush=True,
|
| 432 |
+
)
|
| 433 |
+
|
| 434 |
+
if len(param_dtensors) > 0:
|
| 435 |
self.parallel(
|
| 436 |
+
param_dtensors,
|
| 437 |
group,
|
| 438 |
lr=lr,
|
| 439 |
+
weight_decay=weight_decay,
|
| 440 |
momentum=momentum,
|
| 441 |
)
|
| 442 |
+
|
| 443 |
+
if len(param_tensors) > 0:
|
| 444 |
self.base(
|
| 445 |
+
param_tensors,
|
| 446 |
group,
|
| 447 |
lr=lr,
|
| 448 |
+
weight_decay=weight_decay,
|
| 449 |
momentum=momentum,
|
| 450 |
)
|
| 451 |
|
|
|
|
| 457 |
lr = group["lr"]
|
| 458 |
beta1, beta2 = group["adamw_betas"]
|
| 459 |
eps = group["adamw_eps"]
|
| 460 |
+
weight_decay = group["weight_decay"]
|
| 461 |
|
| 462 |
for p in params:
|
| 463 |
g = p.grad
|
build/torch26-cxx98-cu118-x86_64-linux/optimizer/_ops.py
CHANGED
|
@@ -1,9 +1,9 @@
|
|
| 1 |
import torch
|
| 2 |
-
from . import
|
| 3 |
-
ops = torch.ops.
|
| 4 |
|
| 5 |
def add_op_namespace_prefix(op_name: str):
|
| 6 |
"""
|
| 7 |
Prefix op by namespace.
|
| 8 |
"""
|
| 9 |
-
return f"
|
|
|
|
| 1 |
import torch
|
| 2 |
+
from . import _optimizer_64757cb_dirty
|
| 3 |
+
ops = torch.ops._optimizer_64757cb_dirty
|
| 4 |
|
| 5 |
def add_op_namespace_prefix(op_name: str):
|
| 6 |
"""
|
| 7 |
Prefix op by namespace.
|
| 8 |
"""
|
| 9 |
+
return f"_optimizer_64757cb_dirty::{op_name}"
|
build/torch26-cxx98-cu118-x86_64-linux/optimizer/{_optimizer_036642a_dirty.abi3.so β _optimizer_64757cb_dirty.abi3.so}
RENAMED
|
@@ -1,3 +1,3 @@
|
|
| 1 |
version https://git-lfs.github.com/spec/v1
|
| 2 |
-
oid sha256:
|
| 3 |
size 1787192
|
|
|
|
| 1 |
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:6919551ed599e7e0dc1a750d1972bdb31605f57583b3617054cb70dd40d54d26
|
| 3 |
size 1787192
|
build/torch26-cxx98-cu118-x86_64-linux/optimizer/muon.py
CHANGED
|
@@ -3,7 +3,7 @@ from dataclasses import dataclass
|
|
| 3 |
|
| 4 |
import torch
|
| 5 |
import torch.distributed as dist
|
| 6 |
-
from torch.distributed._tensor import DTensor
|
| 7 |
|
| 8 |
|
| 9 |
# This code snippet is a modified version adapted from the following GitHub repositories:
|
|
@@ -103,7 +103,7 @@ def _compute_u(state, steps, rank, compute_stream):
|
|
| 103 |
|
| 104 |
|
| 105 |
@torch.no_grad()
|
| 106 |
-
def _scatter(p, state, lr,
|
| 107 |
u = state.computed_u
|
| 108 |
mesh = p.device_mesh
|
| 109 |
|
|
@@ -131,10 +131,14 @@ def _scatter(p, state, lr, wd, rank, comm_stream):
|
|
| 131 |
placements=p.placements,
|
| 132 |
device_mesh=mesh,
|
| 133 |
)
|
| 134 |
-
p.data.mul_(1 - lr *
|
| 135 |
p.data.add_(u, alpha=-lr)
|
| 136 |
|
| 137 |
|
|
|
|
|
|
|
|
|
|
|
|
|
| 138 |
class Muon(torch.optim.Optimizer):
|
| 139 |
"""
|
| 140 |
Muon - MomentUm Orthogonalized by Newton-schulz
|
|
@@ -159,18 +163,18 @@ class Muon(torch.optim.Optimizer):
|
|
| 159 |
adamw_lr: The learning rate for the internal AdamW.
|
| 160 |
adamw_betas: The betas for the internal AdamW.
|
| 161 |
adamw_eps: The epsilon for the internal AdamW.
|
| 162 |
-
|
| 163 |
"""
|
| 164 |
|
| 165 |
def __init__(
|
| 166 |
self,
|
| 167 |
model,
|
| 168 |
-
is_muon_func,
|
| 169 |
lr=1e-3,
|
| 170 |
momentum=0.95,
|
| 171 |
nesterov=True,
|
| 172 |
ns_steps=5,
|
| 173 |
-
|
| 174 |
adamw_betas=(0.9, 0.95),
|
| 175 |
adamw_eps=1e-8,
|
| 176 |
none_grad=True,
|
|
@@ -178,7 +182,7 @@ class Muon(torch.optim.Optimizer):
|
|
| 178 |
):
|
| 179 |
defaults = dict(
|
| 180 |
lr=lr,
|
| 181 |
-
|
| 182 |
momentum=momentum,
|
| 183 |
nesterov=nesterov,
|
| 184 |
ns_steps=ns_steps,
|
|
@@ -272,7 +276,7 @@ class Muon(torch.optim.Optimizer):
|
|
| 272 |
|
| 273 |
return param_to_state, ordered_params
|
| 274 |
|
| 275 |
-
def base(self, params, group, lr,
|
| 276 |
# generate weight updates in distributed fashion
|
| 277 |
for p in params:
|
| 278 |
g = p.grad
|
|
@@ -299,7 +303,7 @@ class Muon(torch.optim.Optimizer):
|
|
| 299 |
adjusted_lr = self.adjust_lr_for_muon(lr, p.shape)
|
| 300 |
|
| 301 |
# apply weight decay
|
| 302 |
-
p.data.mul_(1 - lr *
|
| 303 |
|
| 304 |
# apply update
|
| 305 |
p.data.add_(u, alpha=-adjusted_lr)
|
|
@@ -317,15 +321,15 @@ class Muon(torch.optim.Optimizer):
|
|
| 317 |
g = buf
|
| 318 |
return g
|
| 319 |
|
| 320 |
-
def _update_p(self, p, u, lr,
|
| 321 |
# scale update
|
| 322 |
adjusted_lr = self.adjust_lr_for_muon(lr, p.shape)
|
| 323 |
# apply weight decay
|
| 324 |
-
p.data.mul_(1 - lr *
|
| 325 |
# apply update
|
| 326 |
p.data.add_(u, alpha=-adjusted_lr)
|
| 327 |
|
| 328 |
-
def parallel(self, params, group, lr,
|
| 329 |
"""
|
| 330 |
Perform a parallel optimization step using Muon.
|
| 331 |
"""
|
|
@@ -364,7 +368,9 @@ class Muon(torch.optim.Optimizer):
|
|
| 364 |
for p in ordered_params[start_idx : start_idx + chunk_size]:
|
| 365 |
state = param_to_state[id(p)]
|
| 366 |
adjusted_lr = self.adjust_lr_for_muon(lr, p.shape)
|
| 367 |
-
_scatter(
|
|
|
|
|
|
|
| 368 |
|
| 369 |
chunk_size = params[0].device_mesh.mesh.numel()
|
| 370 |
|
|
@@ -398,23 +404,48 @@ class Muon(torch.optim.Optimizer):
|
|
| 398 |
|
| 399 |
params = [p for p in group["params"] if self.state[p]["use_muon"]]
|
| 400 |
lr = group["lr"]
|
| 401 |
-
|
| 402 |
momentum = group["momentum"]
|
| 403 |
|
| 404 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 405 |
self.parallel(
|
| 406 |
-
|
| 407 |
group,
|
| 408 |
lr=lr,
|
| 409 |
-
|
| 410 |
momentum=momentum,
|
| 411 |
)
|
| 412 |
-
|
|
|
|
| 413 |
self.base(
|
| 414 |
-
|
| 415 |
group,
|
| 416 |
lr=lr,
|
| 417 |
-
|
| 418 |
momentum=momentum,
|
| 419 |
)
|
| 420 |
|
|
@@ -426,7 +457,7 @@ class Muon(torch.optim.Optimizer):
|
|
| 426 |
lr = group["lr"]
|
| 427 |
beta1, beta2 = group["adamw_betas"]
|
| 428 |
eps = group["adamw_eps"]
|
| 429 |
-
weight_decay = group["
|
| 430 |
|
| 431 |
for p in params:
|
| 432 |
g = p.grad
|
|
|
|
| 3 |
|
| 4 |
import torch
|
| 5 |
import torch.distributed as dist
|
| 6 |
+
from torch.distributed._tensor import DTensor, Replicate
|
| 7 |
|
| 8 |
|
| 9 |
# This code snippet is a modified version adapted from the following GitHub repositories:
|
|
|
|
| 103 |
|
| 104 |
|
| 105 |
@torch.no_grad()
|
| 106 |
+
def _scatter(p, state, lr, weight_decay, rank, comm_stream):
|
| 107 |
u = state.computed_u
|
| 108 |
mesh = p.device_mesh
|
| 109 |
|
|
|
|
| 131 |
placements=p.placements,
|
| 132 |
device_mesh=mesh,
|
| 133 |
)
|
| 134 |
+
p.data.mul_(1 - lr * weight_decay)
|
| 135 |
p.data.add_(u, alpha=-lr)
|
| 136 |
|
| 137 |
|
| 138 |
+
def default_is_muon(x, name):
|
| 139 |
+
return x.ndim >= 2 and "embed_tokens" not in name and "lm_head" not in name
|
| 140 |
+
|
| 141 |
+
|
| 142 |
class Muon(torch.optim.Optimizer):
|
| 143 |
"""
|
| 144 |
Muon - MomentUm Orthogonalized by Newton-schulz
|
|
|
|
| 163 |
adamw_lr: The learning rate for the internal AdamW.
|
| 164 |
adamw_betas: The betas for the internal AdamW.
|
| 165 |
adamw_eps: The epsilon for the internal AdamW.
|
| 166 |
+
adamw_weight_decay: The weight decay for the internal AdamW.
|
| 167 |
"""
|
| 168 |
|
| 169 |
def __init__(
|
| 170 |
self,
|
| 171 |
model,
|
| 172 |
+
is_muon_func=default_is_muon,
|
| 173 |
lr=1e-3,
|
| 174 |
momentum=0.95,
|
| 175 |
nesterov=True,
|
| 176 |
ns_steps=5,
|
| 177 |
+
weight_decay=0.1,
|
| 178 |
adamw_betas=(0.9, 0.95),
|
| 179 |
adamw_eps=1e-8,
|
| 180 |
none_grad=True,
|
|
|
|
| 182 |
):
|
| 183 |
defaults = dict(
|
| 184 |
lr=lr,
|
| 185 |
+
weight_decay=weight_decay,
|
| 186 |
momentum=momentum,
|
| 187 |
nesterov=nesterov,
|
| 188 |
ns_steps=ns_steps,
|
|
|
|
| 276 |
|
| 277 |
return param_to_state, ordered_params
|
| 278 |
|
| 279 |
+
def base(self, params, group, lr, weight_decay, momentum):
|
| 280 |
# generate weight updates in distributed fashion
|
| 281 |
for p in params:
|
| 282 |
g = p.grad
|
|
|
|
| 303 |
adjusted_lr = self.adjust_lr_for_muon(lr, p.shape)
|
| 304 |
|
| 305 |
# apply weight decay
|
| 306 |
+
p.data.mul_(1 - lr * weight_decay)
|
| 307 |
|
| 308 |
# apply update
|
| 309 |
p.data.add_(u, alpha=-adjusted_lr)
|
|
|
|
| 321 |
g = buf
|
| 322 |
return g
|
| 323 |
|
| 324 |
+
def _update_p(self, p, u, lr, weight_decay):
|
| 325 |
# scale update
|
| 326 |
adjusted_lr = self.adjust_lr_for_muon(lr, p.shape)
|
| 327 |
# apply weight decay
|
| 328 |
+
p.data.mul_(1 - lr * weight_decay)
|
| 329 |
# apply update
|
| 330 |
p.data.add_(u, alpha=-adjusted_lr)
|
| 331 |
|
| 332 |
+
def parallel(self, params, group, lr, weight_decay, momentum):
|
| 333 |
"""
|
| 334 |
Perform a parallel optimization step using Muon.
|
| 335 |
"""
|
|
|
|
| 368 |
for p in ordered_params[start_idx : start_idx + chunk_size]:
|
| 369 |
state = param_to_state[id(p)]
|
| 370 |
adjusted_lr = self.adjust_lr_for_muon(lr, p.shape)
|
| 371 |
+
_scatter(
|
| 372 |
+
p, state, adjusted_lr, weight_decay, self.rank, self.comm_stream
|
| 373 |
+
)
|
| 374 |
|
| 375 |
chunk_size = params[0].device_mesh.mesh.numel()
|
| 376 |
|
|
|
|
| 404 |
|
| 405 |
params = [p for p in group["params"] if self.state[p]["use_muon"]]
|
| 406 |
lr = group["lr"]
|
| 407 |
+
weight_decay = group["weight_decay"]
|
| 408 |
momentum = group["momentum"]
|
| 409 |
|
| 410 |
+
param_dtensors = []
|
| 411 |
+
param_tensors = []
|
| 412 |
+
|
| 413 |
+
for p in params:
|
| 414 |
+
if p is None or p.grad is None:
|
| 415 |
+
continue
|
| 416 |
+
if isinstance(p.data, DTensor):
|
| 417 |
+
if all(
|
| 418 |
+
isinstance(placement, Replicate) for placement in p.placements
|
| 419 |
+
):
|
| 420 |
+
param_tensors.append(p)
|
| 421 |
+
else:
|
| 422 |
+
param_dtensors.append(p)
|
| 423 |
+
elif isinstance(p.data, torch.Tensor):
|
| 424 |
+
param_tensors.append(p)
|
| 425 |
+
else:
|
| 426 |
+
raise TypeError(f"Unsupported parameter type: {type(p.data)}")
|
| 427 |
+
|
| 428 |
+
if self.debug:
|
| 429 |
+
print(
|
| 430 |
+
f"[Muon] {len(param_dtensors)} DTensors, {len(param_tensors)} Tensors",
|
| 431 |
+
flush=True,
|
| 432 |
+
)
|
| 433 |
+
|
| 434 |
+
if len(param_dtensors) > 0:
|
| 435 |
self.parallel(
|
| 436 |
+
param_dtensors,
|
| 437 |
group,
|
| 438 |
lr=lr,
|
| 439 |
+
weight_decay=weight_decay,
|
| 440 |
momentum=momentum,
|
| 441 |
)
|
| 442 |
+
|
| 443 |
+
if len(param_tensors) > 0:
|
| 444 |
self.base(
|
| 445 |
+
param_tensors,
|
| 446 |
group,
|
| 447 |
lr=lr,
|
| 448 |
+
weight_decay=weight_decay,
|
| 449 |
momentum=momentum,
|
| 450 |
)
|
| 451 |
|
|
|
|
| 457 |
lr = group["lr"]
|
| 458 |
beta1, beta2 = group["adamw_betas"]
|
| 459 |
eps = group["adamw_eps"]
|
| 460 |
+
weight_decay = group["weight_decay"]
|
| 461 |
|
| 462 |
for p in params:
|
| 463 |
g = p.grad
|
build/torch26-cxx98-cu124-x86_64-linux/optimizer/_ops.py
CHANGED
|
@@ -1,9 +1,9 @@
|
|
| 1 |
import torch
|
| 2 |
-
from . import
|
| 3 |
-
ops = torch.ops.
|
| 4 |
|
| 5 |
def add_op_namespace_prefix(op_name: str):
|
| 6 |
"""
|
| 7 |
Prefix op by namespace.
|
| 8 |
"""
|
| 9 |
-
return f"
|
|
|
|
| 1 |
import torch
|
| 2 |
+
from . import _optimizer_64757cb_dirty
|
| 3 |
+
ops = torch.ops._optimizer_64757cb_dirty
|
| 4 |
|
| 5 |
def add_op_namespace_prefix(op_name: str):
|
| 6 |
"""
|
| 7 |
Prefix op by namespace.
|
| 8 |
"""
|
| 9 |
+
return f"_optimizer_64757cb_dirty::{op_name}"
|
build/torch26-cxx98-cu124-x86_64-linux/optimizer/{_optimizer_036642a_dirty.abi3.so β _optimizer_64757cb_dirty.abi3.so}
RENAMED
|
@@ -1,3 +1,3 @@
|
|
| 1 |
version https://git-lfs.github.com/spec/v1
|
| 2 |
-
oid sha256:
|
| 3 |
size 1824184
|
|
|
|
| 1 |
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:f07cc2637669130fc9e209cb2c4358caba1c4c2d5837a108043b073d7897c3a7
|
| 3 |
size 1824184
|
build/torch26-cxx98-cu124-x86_64-linux/optimizer/muon.py
CHANGED
|
@@ -3,7 +3,7 @@ from dataclasses import dataclass
|
|
| 3 |
|
| 4 |
import torch
|
| 5 |
import torch.distributed as dist
|
| 6 |
-
from torch.distributed._tensor import DTensor
|
| 7 |
|
| 8 |
|
| 9 |
# This code snippet is a modified version adapted from the following GitHub repositories:
|
|
@@ -103,7 +103,7 @@ def _compute_u(state, steps, rank, compute_stream):
|
|
| 103 |
|
| 104 |
|
| 105 |
@torch.no_grad()
|
| 106 |
-
def _scatter(p, state, lr,
|
| 107 |
u = state.computed_u
|
| 108 |
mesh = p.device_mesh
|
| 109 |
|
|
@@ -131,10 +131,14 @@ def _scatter(p, state, lr, wd, rank, comm_stream):
|
|
| 131 |
placements=p.placements,
|
| 132 |
device_mesh=mesh,
|
| 133 |
)
|
| 134 |
-
p.data.mul_(1 - lr *
|
| 135 |
p.data.add_(u, alpha=-lr)
|
| 136 |
|
| 137 |
|
|
|
|
|
|
|
|
|
|
|
|
|
| 138 |
class Muon(torch.optim.Optimizer):
|
| 139 |
"""
|
| 140 |
Muon - MomentUm Orthogonalized by Newton-schulz
|
|
@@ -159,18 +163,18 @@ class Muon(torch.optim.Optimizer):
|
|
| 159 |
adamw_lr: The learning rate for the internal AdamW.
|
| 160 |
adamw_betas: The betas for the internal AdamW.
|
| 161 |
adamw_eps: The epsilon for the internal AdamW.
|
| 162 |
-
|
| 163 |
"""
|
| 164 |
|
| 165 |
def __init__(
|
| 166 |
self,
|
| 167 |
model,
|
| 168 |
-
is_muon_func,
|
| 169 |
lr=1e-3,
|
| 170 |
momentum=0.95,
|
| 171 |
nesterov=True,
|
| 172 |
ns_steps=5,
|
| 173 |
-
|
| 174 |
adamw_betas=(0.9, 0.95),
|
| 175 |
adamw_eps=1e-8,
|
| 176 |
none_grad=True,
|
|
@@ -178,7 +182,7 @@ class Muon(torch.optim.Optimizer):
|
|
| 178 |
):
|
| 179 |
defaults = dict(
|
| 180 |
lr=lr,
|
| 181 |
-
|
| 182 |
momentum=momentum,
|
| 183 |
nesterov=nesterov,
|
| 184 |
ns_steps=ns_steps,
|
|
@@ -272,7 +276,7 @@ class Muon(torch.optim.Optimizer):
|
|
| 272 |
|
| 273 |
return param_to_state, ordered_params
|
| 274 |
|
| 275 |
-
def base(self, params, group, lr,
|
| 276 |
# generate weight updates in distributed fashion
|
| 277 |
for p in params:
|
| 278 |
g = p.grad
|
|
@@ -299,7 +303,7 @@ class Muon(torch.optim.Optimizer):
|
|
| 299 |
adjusted_lr = self.adjust_lr_for_muon(lr, p.shape)
|
| 300 |
|
| 301 |
# apply weight decay
|
| 302 |
-
p.data.mul_(1 - lr *
|
| 303 |
|
| 304 |
# apply update
|
| 305 |
p.data.add_(u, alpha=-adjusted_lr)
|
|
@@ -317,15 +321,15 @@ class Muon(torch.optim.Optimizer):
|
|
| 317 |
g = buf
|
| 318 |
return g
|
| 319 |
|
| 320 |
-
def _update_p(self, p, u, lr,
|
| 321 |
# scale update
|
| 322 |
adjusted_lr = self.adjust_lr_for_muon(lr, p.shape)
|
| 323 |
# apply weight decay
|
| 324 |
-
p.data.mul_(1 - lr *
|
| 325 |
# apply update
|
| 326 |
p.data.add_(u, alpha=-adjusted_lr)
|
| 327 |
|
| 328 |
-
def parallel(self, params, group, lr,
|
| 329 |
"""
|
| 330 |
Perform a parallel optimization step using Muon.
|
| 331 |
"""
|
|
@@ -364,7 +368,9 @@ class Muon(torch.optim.Optimizer):
|
|
| 364 |
for p in ordered_params[start_idx : start_idx + chunk_size]:
|
| 365 |
state = param_to_state[id(p)]
|
| 366 |
adjusted_lr = self.adjust_lr_for_muon(lr, p.shape)
|
| 367 |
-
_scatter(
|
|
|
|
|
|
|
| 368 |
|
| 369 |
chunk_size = params[0].device_mesh.mesh.numel()
|
| 370 |
|
|
@@ -398,23 +404,48 @@ class Muon(torch.optim.Optimizer):
|
|
| 398 |
|
| 399 |
params = [p for p in group["params"] if self.state[p]["use_muon"]]
|
| 400 |
lr = group["lr"]
|
| 401 |
-
|
| 402 |
momentum = group["momentum"]
|
| 403 |
|
| 404 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 405 |
self.parallel(
|
| 406 |
-
|
| 407 |
group,
|
| 408 |
lr=lr,
|
| 409 |
-
|
| 410 |
momentum=momentum,
|
| 411 |
)
|
| 412 |
-
|
|
|
|
| 413 |
self.base(
|
| 414 |
-
|
| 415 |
group,
|
| 416 |
lr=lr,
|
| 417 |
-
|
| 418 |
momentum=momentum,
|
| 419 |
)
|
| 420 |
|
|
@@ -426,7 +457,7 @@ class Muon(torch.optim.Optimizer):
|
|
| 426 |
lr = group["lr"]
|
| 427 |
beta1, beta2 = group["adamw_betas"]
|
| 428 |
eps = group["adamw_eps"]
|
| 429 |
-
weight_decay = group["
|
| 430 |
|
| 431 |
for p in params:
|
| 432 |
g = p.grad
|
|
|
|
| 3 |
|
| 4 |
import torch
|
| 5 |
import torch.distributed as dist
|
| 6 |
+
from torch.distributed._tensor import DTensor, Replicate
|
| 7 |
|
| 8 |
|
| 9 |
# This code snippet is a modified version adapted from the following GitHub repositories:
|
|
|
|
| 103 |
|
| 104 |
|
| 105 |
@torch.no_grad()
|
| 106 |
+
def _scatter(p, state, lr, weight_decay, rank, comm_stream):
|
| 107 |
u = state.computed_u
|
| 108 |
mesh = p.device_mesh
|
| 109 |
|
|
|
|
| 131 |
placements=p.placements,
|
| 132 |
device_mesh=mesh,
|
| 133 |
)
|
| 134 |
+
p.data.mul_(1 - lr * weight_decay)
|
| 135 |
p.data.add_(u, alpha=-lr)
|
| 136 |
|
| 137 |
|
| 138 |
+
def default_is_muon(x, name):
|
| 139 |
+
return x.ndim >= 2 and "embed_tokens" not in name and "lm_head" not in name
|
| 140 |
+
|
| 141 |
+
|
| 142 |
class Muon(torch.optim.Optimizer):
|
| 143 |
"""
|
| 144 |
Muon - MomentUm Orthogonalized by Newton-schulz
|
|
|
|
| 163 |
adamw_lr: The learning rate for the internal AdamW.
|
| 164 |
adamw_betas: The betas for the internal AdamW.
|
| 165 |
adamw_eps: The epsilon for the internal AdamW.
|
| 166 |
+
adamw_weight_decay: The weight decay for the internal AdamW.
|
| 167 |
"""
|
| 168 |
|
| 169 |
def __init__(
|
| 170 |
self,
|
| 171 |
model,
|
| 172 |
+
is_muon_func=default_is_muon,
|
| 173 |
lr=1e-3,
|
| 174 |
momentum=0.95,
|
| 175 |
nesterov=True,
|
| 176 |
ns_steps=5,
|
| 177 |
+
weight_decay=0.1,
|
| 178 |
adamw_betas=(0.9, 0.95),
|
| 179 |
adamw_eps=1e-8,
|
| 180 |
none_grad=True,
|
|
|
|
| 182 |
):
|
| 183 |
defaults = dict(
|
| 184 |
lr=lr,
|
| 185 |
+
weight_decay=weight_decay,
|
| 186 |
momentum=momentum,
|
| 187 |
nesterov=nesterov,
|
| 188 |
ns_steps=ns_steps,
|
|
|
|
| 276 |
|
| 277 |
return param_to_state, ordered_params
|
| 278 |
|
| 279 |
+
def base(self, params, group, lr, weight_decay, momentum):
|
| 280 |
# generate weight updates in distributed fashion
|
| 281 |
for p in params:
|
| 282 |
g = p.grad
|
|
|
|
| 303 |
adjusted_lr = self.adjust_lr_for_muon(lr, p.shape)
|
| 304 |
|
| 305 |
# apply weight decay
|
| 306 |
+
p.data.mul_(1 - lr * weight_decay)
|
| 307 |
|
| 308 |
# apply update
|
| 309 |
p.data.add_(u, alpha=-adjusted_lr)
|
|
|
|
| 321 |
g = buf
|
| 322 |
return g
|
| 323 |
|
| 324 |
+
def _update_p(self, p, u, lr, weight_decay):
|
| 325 |
# scale update
|
| 326 |
adjusted_lr = self.adjust_lr_for_muon(lr, p.shape)
|
| 327 |
# apply weight decay
|
| 328 |
+
p.data.mul_(1 - lr * weight_decay)
|
| 329 |
# apply update
|
| 330 |
p.data.add_(u, alpha=-adjusted_lr)
|
| 331 |
|
| 332 |
+
def parallel(self, params, group, lr, weight_decay, momentum):
|
| 333 |
"""
|
| 334 |
Perform a parallel optimization step using Muon.
|
| 335 |
"""
|
|
|
|
| 368 |
for p in ordered_params[start_idx : start_idx + chunk_size]:
|
| 369 |
state = param_to_state[id(p)]
|
| 370 |
adjusted_lr = self.adjust_lr_for_muon(lr, p.shape)
|
| 371 |
+
_scatter(
|
| 372 |
+
p, state, adjusted_lr, weight_decay, self.rank, self.comm_stream
|
| 373 |
+
)
|
| 374 |
|
| 375 |
chunk_size = params[0].device_mesh.mesh.numel()
|
| 376 |
|
|
|
|
| 404 |
|
| 405 |
params = [p for p in group["params"] if self.state[p]["use_muon"]]
|
| 406 |
lr = group["lr"]
|
| 407 |
+
weight_decay = group["weight_decay"]
|
| 408 |
momentum = group["momentum"]
|
| 409 |
|
| 410 |
+
param_dtensors = []
|
| 411 |
+
param_tensors = []
|
| 412 |
+
|
| 413 |
+
for p in params:
|
| 414 |
+
if p is None or p.grad is None:
|
| 415 |
+
continue
|
| 416 |
+
if isinstance(p.data, DTensor):
|
| 417 |
+
if all(
|
| 418 |
+
isinstance(placement, Replicate) for placement in p.placements
|
| 419 |
+
):
|
| 420 |
+
param_tensors.append(p)
|
| 421 |
+
else:
|
| 422 |
+
param_dtensors.append(p)
|
| 423 |
+
elif isinstance(p.data, torch.Tensor):
|
| 424 |
+
param_tensors.append(p)
|
| 425 |
+
else:
|
| 426 |
+
raise TypeError(f"Unsupported parameter type: {type(p.data)}")
|
| 427 |
+
|
| 428 |
+
if self.debug:
|
| 429 |
+
print(
|
| 430 |
+
f"[Muon] {len(param_dtensors)} DTensors, {len(param_tensors)} Tensors",
|
| 431 |
+
flush=True,
|
| 432 |
+
)
|
| 433 |
+
|
| 434 |
+
if len(param_dtensors) > 0:
|
| 435 |
self.parallel(
|
| 436 |
+
param_dtensors,
|
| 437 |
group,
|
| 438 |
lr=lr,
|
| 439 |
+
weight_decay=weight_decay,
|
| 440 |
momentum=momentum,
|
| 441 |
)
|
| 442 |
+
|
| 443 |
+
if len(param_tensors) > 0:
|
| 444 |
self.base(
|
| 445 |
+
param_tensors,
|
| 446 |
group,
|
| 447 |
lr=lr,
|
| 448 |
+
weight_decay=weight_decay,
|
| 449 |
momentum=momentum,
|
| 450 |
)
|
| 451 |
|
|
|
|
| 457 |
lr = group["lr"]
|
| 458 |
beta1, beta2 = group["adamw_betas"]
|
| 459 |
eps = group["adamw_eps"]
|
| 460 |
+
weight_decay = group["weight_decay"]
|
| 461 |
|
| 462 |
for p in params:
|
| 463 |
g = p.grad
|
build/torch26-cxx98-cu126-x86_64-linux/optimizer/_ops.py
CHANGED
|
@@ -1,9 +1,9 @@
|
|
| 1 |
import torch
|
| 2 |
-
from . import
|
| 3 |
-
ops = torch.ops.
|
| 4 |
|
| 5 |
def add_op_namespace_prefix(op_name: str):
|
| 6 |
"""
|
| 7 |
Prefix op by namespace.
|
| 8 |
"""
|
| 9 |
-
return f"
|
|
|
|
| 1 |
import torch
|
| 2 |
+
from . import _optimizer_64757cb_dirty
|
| 3 |
+
ops = torch.ops._optimizer_64757cb_dirty
|
| 4 |
|
| 5 |
def add_op_namespace_prefix(op_name: str):
|
| 6 |
"""
|
| 7 |
Prefix op by namespace.
|
| 8 |
"""
|
| 9 |
+
return f"_optimizer_64757cb_dirty::{op_name}"
|
build/torch26-cxx98-cu126-x86_64-linux/optimizer/{_optimizer_036642a_dirty.abi3.so β _optimizer_64757cb_dirty.abi3.so}
RENAMED
|
@@ -1,3 +1,3 @@
|
|
| 1 |
version https://git-lfs.github.com/spec/v1
|
| 2 |
-
oid sha256:
|
| 3 |
size 1824184
|
|
|
|
| 1 |
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:8b9ef8fa2dd4d80cb3c1c3c2a72b99e0d76b3e676acd551f3a9ff4cdd21773eb
|
| 3 |
size 1824184
|
build/torch26-cxx98-cu126-x86_64-linux/optimizer/muon.py
CHANGED
|
@@ -3,7 +3,7 @@ from dataclasses import dataclass
|
|
| 3 |
|
| 4 |
import torch
|
| 5 |
import torch.distributed as dist
|
| 6 |
-
from torch.distributed._tensor import DTensor
|
| 7 |
|
| 8 |
|
| 9 |
# This code snippet is a modified version adapted from the following GitHub repositories:
|
|
@@ -103,7 +103,7 @@ def _compute_u(state, steps, rank, compute_stream):
|
|
| 103 |
|
| 104 |
|
| 105 |
@torch.no_grad()
|
| 106 |
-
def _scatter(p, state, lr,
|
| 107 |
u = state.computed_u
|
| 108 |
mesh = p.device_mesh
|
| 109 |
|
|
@@ -131,10 +131,14 @@ def _scatter(p, state, lr, wd, rank, comm_stream):
|
|
| 131 |
placements=p.placements,
|
| 132 |
device_mesh=mesh,
|
| 133 |
)
|
| 134 |
-
p.data.mul_(1 - lr *
|
| 135 |
p.data.add_(u, alpha=-lr)
|
| 136 |
|
| 137 |
|
|
|
|
|
|
|
|
|
|
|
|
|
| 138 |
class Muon(torch.optim.Optimizer):
|
| 139 |
"""
|
| 140 |
Muon - MomentUm Orthogonalized by Newton-schulz
|
|
@@ -159,18 +163,18 @@ class Muon(torch.optim.Optimizer):
|
|
| 159 |
adamw_lr: The learning rate for the internal AdamW.
|
| 160 |
adamw_betas: The betas for the internal AdamW.
|
| 161 |
adamw_eps: The epsilon for the internal AdamW.
|
| 162 |
-
|
| 163 |
"""
|
| 164 |
|
| 165 |
def __init__(
|
| 166 |
self,
|
| 167 |
model,
|
| 168 |
-
is_muon_func,
|
| 169 |
lr=1e-3,
|
| 170 |
momentum=0.95,
|
| 171 |
nesterov=True,
|
| 172 |
ns_steps=5,
|
| 173 |
-
|
| 174 |
adamw_betas=(0.9, 0.95),
|
| 175 |
adamw_eps=1e-8,
|
| 176 |
none_grad=True,
|
|
@@ -178,7 +182,7 @@ class Muon(torch.optim.Optimizer):
|
|
| 178 |
):
|
| 179 |
defaults = dict(
|
| 180 |
lr=lr,
|
| 181 |
-
|
| 182 |
momentum=momentum,
|
| 183 |
nesterov=nesterov,
|
| 184 |
ns_steps=ns_steps,
|
|
@@ -272,7 +276,7 @@ class Muon(torch.optim.Optimizer):
|
|
| 272 |
|
| 273 |
return param_to_state, ordered_params
|
| 274 |
|
| 275 |
-
def base(self, params, group, lr,
|
| 276 |
# generate weight updates in distributed fashion
|
| 277 |
for p in params:
|
| 278 |
g = p.grad
|
|
@@ -299,7 +303,7 @@ class Muon(torch.optim.Optimizer):
|
|
| 299 |
adjusted_lr = self.adjust_lr_for_muon(lr, p.shape)
|
| 300 |
|
| 301 |
# apply weight decay
|
| 302 |
-
p.data.mul_(1 - lr *
|
| 303 |
|
| 304 |
# apply update
|
| 305 |
p.data.add_(u, alpha=-adjusted_lr)
|
|
@@ -317,15 +321,15 @@ class Muon(torch.optim.Optimizer):
|
|
| 317 |
g = buf
|
| 318 |
return g
|
| 319 |
|
| 320 |
-
def _update_p(self, p, u, lr,
|
| 321 |
# scale update
|
| 322 |
adjusted_lr = self.adjust_lr_for_muon(lr, p.shape)
|
| 323 |
# apply weight decay
|
| 324 |
-
p.data.mul_(1 - lr *
|
| 325 |
# apply update
|
| 326 |
p.data.add_(u, alpha=-adjusted_lr)
|
| 327 |
|
| 328 |
-
def parallel(self, params, group, lr,
|
| 329 |
"""
|
| 330 |
Perform a parallel optimization step using Muon.
|
| 331 |
"""
|
|
@@ -364,7 +368,9 @@ class Muon(torch.optim.Optimizer):
|
|
| 364 |
for p in ordered_params[start_idx : start_idx + chunk_size]:
|
| 365 |
state = param_to_state[id(p)]
|
| 366 |
adjusted_lr = self.adjust_lr_for_muon(lr, p.shape)
|
| 367 |
-
_scatter(
|
|
|
|
|
|
|
| 368 |
|
| 369 |
chunk_size = params[0].device_mesh.mesh.numel()
|
| 370 |
|
|
@@ -398,23 +404,48 @@ class Muon(torch.optim.Optimizer):
|
|
| 398 |
|
| 399 |
params = [p for p in group["params"] if self.state[p]["use_muon"]]
|
| 400 |
lr = group["lr"]
|
| 401 |
-
|
| 402 |
momentum = group["momentum"]
|
| 403 |
|
| 404 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 405 |
self.parallel(
|
| 406 |
-
|
| 407 |
group,
|
| 408 |
lr=lr,
|
| 409 |
-
|
| 410 |
momentum=momentum,
|
| 411 |
)
|
| 412 |
-
|
|
|
|
| 413 |
self.base(
|
| 414 |
-
|
| 415 |
group,
|
| 416 |
lr=lr,
|
| 417 |
-
|
| 418 |
momentum=momentum,
|
| 419 |
)
|
| 420 |
|
|
@@ -426,7 +457,7 @@ class Muon(torch.optim.Optimizer):
|
|
| 426 |
lr = group["lr"]
|
| 427 |
beta1, beta2 = group["adamw_betas"]
|
| 428 |
eps = group["adamw_eps"]
|
| 429 |
-
weight_decay = group["
|
| 430 |
|
| 431 |
for p in params:
|
| 432 |
g = p.grad
|
|
|
|
| 3 |
|
| 4 |
import torch
|
| 5 |
import torch.distributed as dist
|
| 6 |
+
from torch.distributed._tensor import DTensor, Replicate
|
| 7 |
|
| 8 |
|
| 9 |
# This code snippet is a modified version adapted from the following GitHub repositories:
|
|
|
|
| 103 |
|
| 104 |
|
| 105 |
@torch.no_grad()
|
| 106 |
+
def _scatter(p, state, lr, weight_decay, rank, comm_stream):
|
| 107 |
u = state.computed_u
|
| 108 |
mesh = p.device_mesh
|
| 109 |
|
|
|
|
| 131 |
placements=p.placements,
|
| 132 |
device_mesh=mesh,
|
| 133 |
)
|
| 134 |
+
p.data.mul_(1 - lr * weight_decay)
|
| 135 |
p.data.add_(u, alpha=-lr)
|
| 136 |
|
| 137 |
|
| 138 |
+
def default_is_muon(x, name):
|
| 139 |
+
return x.ndim >= 2 and "embed_tokens" not in name and "lm_head" not in name
|
| 140 |
+
|
| 141 |
+
|
| 142 |
class Muon(torch.optim.Optimizer):
|
| 143 |
"""
|
| 144 |
Muon - MomentUm Orthogonalized by Newton-schulz
|
|
|
|
| 163 |
adamw_lr: The learning rate for the internal AdamW.
|
| 164 |
adamw_betas: The betas for the internal AdamW.
|
| 165 |
adamw_eps: The epsilon for the internal AdamW.
|
| 166 |
+
adamw_weight_decay: The weight decay for the internal AdamW.
|
| 167 |
"""
|
| 168 |
|
| 169 |
def __init__(
|
| 170 |
self,
|
| 171 |
model,
|
| 172 |
+
is_muon_func=default_is_muon,
|
| 173 |
lr=1e-3,
|
| 174 |
momentum=0.95,
|
| 175 |
nesterov=True,
|
| 176 |
ns_steps=5,
|
| 177 |
+
weight_decay=0.1,
|
| 178 |
adamw_betas=(0.9, 0.95),
|
| 179 |
adamw_eps=1e-8,
|
| 180 |
none_grad=True,
|
|
|
|
| 182 |
):
|
| 183 |
defaults = dict(
|
| 184 |
lr=lr,
|
| 185 |
+
weight_decay=weight_decay,
|
| 186 |
momentum=momentum,
|
| 187 |
nesterov=nesterov,
|
| 188 |
ns_steps=ns_steps,
|
|
|
|
| 276 |
|
| 277 |
return param_to_state, ordered_params
|
| 278 |
|
| 279 |
+
def base(self, params, group, lr, weight_decay, momentum):
|
| 280 |
# generate weight updates in distributed fashion
|
| 281 |
for p in params:
|
| 282 |
g = p.grad
|
|
|
|
| 303 |
adjusted_lr = self.adjust_lr_for_muon(lr, p.shape)
|
| 304 |
|
| 305 |
# apply weight decay
|
| 306 |
+
p.data.mul_(1 - lr * weight_decay)
|
| 307 |
|
| 308 |
# apply update
|
| 309 |
p.data.add_(u, alpha=-adjusted_lr)
|
|
|
|
| 321 |
g = buf
|
| 322 |
return g
|
| 323 |
|
| 324 |
+
def _update_p(self, p, u, lr, weight_decay):
|
| 325 |
# scale update
|
| 326 |
adjusted_lr = self.adjust_lr_for_muon(lr, p.shape)
|
| 327 |
# apply weight decay
|
| 328 |
+
p.data.mul_(1 - lr * weight_decay)
|
| 329 |
# apply update
|
| 330 |
p.data.add_(u, alpha=-adjusted_lr)
|
| 331 |
|
| 332 |
+
def parallel(self, params, group, lr, weight_decay, momentum):
|
| 333 |
"""
|
| 334 |
Perform a parallel optimization step using Muon.
|
| 335 |
"""
|
|
|
|
| 368 |
for p in ordered_params[start_idx : start_idx + chunk_size]:
|
| 369 |
state = param_to_state[id(p)]
|
| 370 |
adjusted_lr = self.adjust_lr_for_muon(lr, p.shape)
|
| 371 |
+
_scatter(
|
| 372 |
+
p, state, adjusted_lr, weight_decay, self.rank, self.comm_stream
|
| 373 |
+
)
|
| 374 |
|
| 375 |
chunk_size = params[0].device_mesh.mesh.numel()
|
| 376 |
|
|
|
|
| 404 |
|
| 405 |
params = [p for p in group["params"] if self.state[p]["use_muon"]]
|
| 406 |
lr = group["lr"]
|
| 407 |
+
weight_decay = group["weight_decay"]
|
| 408 |
momentum = group["momentum"]
|
| 409 |
|
| 410 |
+
param_dtensors = []
|
| 411 |
+
param_tensors = []
|
| 412 |
+
|
| 413 |
+
for p in params:
|
| 414 |
+
if p is None or p.grad is None:
|
| 415 |
+
continue
|
| 416 |
+
if isinstance(p.data, DTensor):
|
| 417 |
+
if all(
|
| 418 |
+
isinstance(placement, Replicate) for placement in p.placements
|
| 419 |
+
):
|
| 420 |
+
param_tensors.append(p)
|
| 421 |
+
else:
|
| 422 |
+
param_dtensors.append(p)
|
| 423 |
+
elif isinstance(p.data, torch.Tensor):
|
| 424 |
+
param_tensors.append(p)
|
| 425 |
+
else:
|
| 426 |
+
raise TypeError(f"Unsupported parameter type: {type(p.data)}")
|
| 427 |
+
|
| 428 |
+
if self.debug:
|
| 429 |
+
print(
|
| 430 |
+
f"[Muon] {len(param_dtensors)} DTensors, {len(param_tensors)} Tensors",
|
| 431 |
+
flush=True,
|
| 432 |
+
)
|
| 433 |
+
|
| 434 |
+
if len(param_dtensors) > 0:
|
| 435 |
self.parallel(
|
| 436 |
+
param_dtensors,
|
| 437 |
group,
|
| 438 |
lr=lr,
|
| 439 |
+
weight_decay=weight_decay,
|
| 440 |
momentum=momentum,
|
| 441 |
)
|
| 442 |
+
|
| 443 |
+
if len(param_tensors) > 0:
|
| 444 |
self.base(
|
| 445 |
+
param_tensors,
|
| 446 |
group,
|
| 447 |
lr=lr,
|
| 448 |
+
weight_decay=weight_decay,
|
| 449 |
momentum=momentum,
|
| 450 |
)
|
| 451 |
|
|
|
|
| 457 |
lr = group["lr"]
|
| 458 |
beta1, beta2 = group["adamw_betas"]
|
| 459 |
eps = group["adamw_eps"]
|
| 460 |
+
weight_decay = group["weight_decay"]
|
| 461 |
|
| 462 |
for p in params:
|
| 463 |
g = p.grad
|
build/torch27-cxx11-cu118-x86_64-linux/optimizer/_ops.py
CHANGED
|
@@ -1,9 +1,9 @@
|
|
| 1 |
import torch
|
| 2 |
-
from . import
|
| 3 |
-
ops = torch.ops.
|
| 4 |
|
| 5 |
def add_op_namespace_prefix(op_name: str):
|
| 6 |
"""
|
| 7 |
Prefix op by namespace.
|
| 8 |
"""
|
| 9 |
-
return f"
|
|
|
|
| 1 |
import torch
|
| 2 |
+
from . import _optimizer_64757cb_dirty
|
| 3 |
+
ops = torch.ops._optimizer_64757cb_dirty
|
| 4 |
|
| 5 |
def add_op_namespace_prefix(op_name: str):
|
| 6 |
"""
|
| 7 |
Prefix op by namespace.
|
| 8 |
"""
|
| 9 |
+
return f"_optimizer_64757cb_dirty::{op_name}"
|
build/torch27-cxx11-cu118-x86_64-linux/optimizer/{_optimizer_036642a_dirty.abi3.so β _optimizer_64757cb_dirty.abi3.so}
RENAMED
|
@@ -1,3 +1,3 @@
|
|
| 1 |
version https://git-lfs.github.com/spec/v1
|
| 2 |
-
oid sha256:
|
| 3 |
size 1787368
|
|
|
|
| 1 |
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:8413f32011996384f13a985a99b4e2f863f8e4717acdb8439b63a10f77db6f15
|
| 3 |
size 1787368
|
build/torch27-cxx11-cu118-x86_64-linux/optimizer/muon.py
CHANGED
|
@@ -3,7 +3,7 @@ from dataclasses import dataclass
|
|
| 3 |
|
| 4 |
import torch
|
| 5 |
import torch.distributed as dist
|
| 6 |
-
from torch.distributed._tensor import DTensor
|
| 7 |
|
| 8 |
|
| 9 |
# This code snippet is a modified version adapted from the following GitHub repositories:
|
|
@@ -103,7 +103,7 @@ def _compute_u(state, steps, rank, compute_stream):
|
|
| 103 |
|
| 104 |
|
| 105 |
@torch.no_grad()
|
| 106 |
-
def _scatter(p, state, lr,
|
| 107 |
u = state.computed_u
|
| 108 |
mesh = p.device_mesh
|
| 109 |
|
|
@@ -131,10 +131,14 @@ def _scatter(p, state, lr, wd, rank, comm_stream):
|
|
| 131 |
placements=p.placements,
|
| 132 |
device_mesh=mesh,
|
| 133 |
)
|
| 134 |
-
p.data.mul_(1 - lr *
|
| 135 |
p.data.add_(u, alpha=-lr)
|
| 136 |
|
| 137 |
|
|
|
|
|
|
|
|
|
|
|
|
|
| 138 |
class Muon(torch.optim.Optimizer):
|
| 139 |
"""
|
| 140 |
Muon - MomentUm Orthogonalized by Newton-schulz
|
|
@@ -159,18 +163,18 @@ class Muon(torch.optim.Optimizer):
|
|
| 159 |
adamw_lr: The learning rate for the internal AdamW.
|
| 160 |
adamw_betas: The betas for the internal AdamW.
|
| 161 |
adamw_eps: The epsilon for the internal AdamW.
|
| 162 |
-
|
| 163 |
"""
|
| 164 |
|
| 165 |
def __init__(
|
| 166 |
self,
|
| 167 |
model,
|
| 168 |
-
is_muon_func,
|
| 169 |
lr=1e-3,
|
| 170 |
momentum=0.95,
|
| 171 |
nesterov=True,
|
| 172 |
ns_steps=5,
|
| 173 |
-
|
| 174 |
adamw_betas=(0.9, 0.95),
|
| 175 |
adamw_eps=1e-8,
|
| 176 |
none_grad=True,
|
|
@@ -178,7 +182,7 @@ class Muon(torch.optim.Optimizer):
|
|
| 178 |
):
|
| 179 |
defaults = dict(
|
| 180 |
lr=lr,
|
| 181 |
-
|
| 182 |
momentum=momentum,
|
| 183 |
nesterov=nesterov,
|
| 184 |
ns_steps=ns_steps,
|
|
@@ -272,7 +276,7 @@ class Muon(torch.optim.Optimizer):
|
|
| 272 |
|
| 273 |
return param_to_state, ordered_params
|
| 274 |
|
| 275 |
-
def base(self, params, group, lr,
|
| 276 |
# generate weight updates in distributed fashion
|
| 277 |
for p in params:
|
| 278 |
g = p.grad
|
|
@@ -299,7 +303,7 @@ class Muon(torch.optim.Optimizer):
|
|
| 299 |
adjusted_lr = self.adjust_lr_for_muon(lr, p.shape)
|
| 300 |
|
| 301 |
# apply weight decay
|
| 302 |
-
p.data.mul_(1 - lr *
|
| 303 |
|
| 304 |
# apply update
|
| 305 |
p.data.add_(u, alpha=-adjusted_lr)
|
|
@@ -317,15 +321,15 @@ class Muon(torch.optim.Optimizer):
|
|
| 317 |
g = buf
|
| 318 |
return g
|
| 319 |
|
| 320 |
-
def _update_p(self, p, u, lr,
|
| 321 |
# scale update
|
| 322 |
adjusted_lr = self.adjust_lr_for_muon(lr, p.shape)
|
| 323 |
# apply weight decay
|
| 324 |
-
p.data.mul_(1 - lr *
|
| 325 |
# apply update
|
| 326 |
p.data.add_(u, alpha=-adjusted_lr)
|
| 327 |
|
| 328 |
-
def parallel(self, params, group, lr,
|
| 329 |
"""
|
| 330 |
Perform a parallel optimization step using Muon.
|
| 331 |
"""
|
|
@@ -364,7 +368,9 @@ class Muon(torch.optim.Optimizer):
|
|
| 364 |
for p in ordered_params[start_idx : start_idx + chunk_size]:
|
| 365 |
state = param_to_state[id(p)]
|
| 366 |
adjusted_lr = self.adjust_lr_for_muon(lr, p.shape)
|
| 367 |
-
_scatter(
|
|
|
|
|
|
|
| 368 |
|
| 369 |
chunk_size = params[0].device_mesh.mesh.numel()
|
| 370 |
|
|
@@ -398,23 +404,48 @@ class Muon(torch.optim.Optimizer):
|
|
| 398 |
|
| 399 |
params = [p for p in group["params"] if self.state[p]["use_muon"]]
|
| 400 |
lr = group["lr"]
|
| 401 |
-
|
| 402 |
momentum = group["momentum"]
|
| 403 |
|
| 404 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 405 |
self.parallel(
|
| 406 |
-
|
| 407 |
group,
|
| 408 |
lr=lr,
|
| 409 |
-
|
| 410 |
momentum=momentum,
|
| 411 |
)
|
| 412 |
-
|
|
|
|
| 413 |
self.base(
|
| 414 |
-
|
| 415 |
group,
|
| 416 |
lr=lr,
|
| 417 |
-
|
| 418 |
momentum=momentum,
|
| 419 |
)
|
| 420 |
|
|
@@ -426,7 +457,7 @@ class Muon(torch.optim.Optimizer):
|
|
| 426 |
lr = group["lr"]
|
| 427 |
beta1, beta2 = group["adamw_betas"]
|
| 428 |
eps = group["adamw_eps"]
|
| 429 |
-
weight_decay = group["
|
| 430 |
|
| 431 |
for p in params:
|
| 432 |
g = p.grad
|
|
|
|
| 3 |
|
| 4 |
import torch
|
| 5 |
import torch.distributed as dist
|
| 6 |
+
from torch.distributed._tensor import DTensor, Replicate
|
| 7 |
|
| 8 |
|
| 9 |
# This code snippet is a modified version adapted from the following GitHub repositories:
|
|
|
|
| 103 |
|
| 104 |
|
| 105 |
@torch.no_grad()
|
| 106 |
+
def _scatter(p, state, lr, weight_decay, rank, comm_stream):
|
| 107 |
u = state.computed_u
|
| 108 |
mesh = p.device_mesh
|
| 109 |
|
|
|
|
| 131 |
placements=p.placements,
|
| 132 |
device_mesh=mesh,
|
| 133 |
)
|
| 134 |
+
p.data.mul_(1 - lr * weight_decay)
|
| 135 |
p.data.add_(u, alpha=-lr)
|
| 136 |
|
| 137 |
|
| 138 |
+
def default_is_muon(x, name):
|
| 139 |
+
return x.ndim >= 2 and "embed_tokens" not in name and "lm_head" not in name
|
| 140 |
+
|
| 141 |
+
|
| 142 |
class Muon(torch.optim.Optimizer):
|
| 143 |
"""
|
| 144 |
Muon - MomentUm Orthogonalized by Newton-schulz
|
|
|
|
| 163 |
adamw_lr: The learning rate for the internal AdamW.
|
| 164 |
adamw_betas: The betas for the internal AdamW.
|
| 165 |
adamw_eps: The epsilon for the internal AdamW.
|
| 166 |
+
adamw_weight_decay: The weight decay for the internal AdamW.
|
| 167 |
"""
|
| 168 |
|
| 169 |
def __init__(
|
| 170 |
self,
|
| 171 |
model,
|
| 172 |
+
is_muon_func=default_is_muon,
|
| 173 |
lr=1e-3,
|
| 174 |
momentum=0.95,
|
| 175 |
nesterov=True,
|
| 176 |
ns_steps=5,
|
| 177 |
+
weight_decay=0.1,
|
| 178 |
adamw_betas=(0.9, 0.95),
|
| 179 |
adamw_eps=1e-8,
|
| 180 |
none_grad=True,
|
|
|
|
| 182 |
):
|
| 183 |
defaults = dict(
|
| 184 |
lr=lr,
|
| 185 |
+
weight_decay=weight_decay,
|
| 186 |
momentum=momentum,
|
| 187 |
nesterov=nesterov,
|
| 188 |
ns_steps=ns_steps,
|
|
|
|
| 276 |
|
| 277 |
return param_to_state, ordered_params
|
| 278 |
|
| 279 |
+
def base(self, params, group, lr, weight_decay, momentum):
|
| 280 |
# generate weight updates in distributed fashion
|
| 281 |
for p in params:
|
| 282 |
g = p.grad
|
|
|
|
| 303 |
adjusted_lr = self.adjust_lr_for_muon(lr, p.shape)
|
| 304 |
|
| 305 |
# apply weight decay
|
| 306 |
+
p.data.mul_(1 - lr * weight_decay)
|
| 307 |
|
| 308 |
# apply update
|
| 309 |
p.data.add_(u, alpha=-adjusted_lr)
|
|
|
|
| 321 |
g = buf
|
| 322 |
return g
|
| 323 |
|
| 324 |
+
def _update_p(self, p, u, lr, weight_decay):
|
| 325 |
# scale update
|
| 326 |
adjusted_lr = self.adjust_lr_for_muon(lr, p.shape)
|
| 327 |
# apply weight decay
|
| 328 |
+
p.data.mul_(1 - lr * weight_decay)
|
| 329 |
# apply update
|
| 330 |
p.data.add_(u, alpha=-adjusted_lr)
|
| 331 |
|
| 332 |
+
def parallel(self, params, group, lr, weight_decay, momentum):
|
| 333 |
"""
|
| 334 |
Perform a parallel optimization step using Muon.
|
| 335 |
"""
|
|
|
|
| 368 |
for p in ordered_params[start_idx : start_idx + chunk_size]:
|
| 369 |
state = param_to_state[id(p)]
|
| 370 |
adjusted_lr = self.adjust_lr_for_muon(lr, p.shape)
|
| 371 |
+
_scatter(
|
| 372 |
+
p, state, adjusted_lr, weight_decay, self.rank, self.comm_stream
|
| 373 |
+
)
|
| 374 |
|
| 375 |
chunk_size = params[0].device_mesh.mesh.numel()
|
| 376 |
|
|
|
|
| 404 |
|
| 405 |
params = [p for p in group["params"] if self.state[p]["use_muon"]]
|
| 406 |
lr = group["lr"]
|
| 407 |
+
weight_decay = group["weight_decay"]
|
| 408 |
momentum = group["momentum"]
|
| 409 |
|
| 410 |
+
param_dtensors = []
|
| 411 |
+
param_tensors = []
|
| 412 |
+
|
| 413 |
+
for p in params:
|
| 414 |
+
if p is None or p.grad is None:
|
| 415 |
+
continue
|
| 416 |
+
if isinstance(p.data, DTensor):
|
| 417 |
+
if all(
|
| 418 |
+
isinstance(placement, Replicate) for placement in p.placements
|
| 419 |
+
):
|
| 420 |
+
param_tensors.append(p)
|
| 421 |
+
else:
|
| 422 |
+
param_dtensors.append(p)
|
| 423 |
+
elif isinstance(p.data, torch.Tensor):
|
| 424 |
+
param_tensors.append(p)
|
| 425 |
+
else:
|
| 426 |
+
raise TypeError(f"Unsupported parameter type: {type(p.data)}")
|
| 427 |
+
|
| 428 |
+
if self.debug:
|
| 429 |
+
print(
|
| 430 |
+
f"[Muon] {len(param_dtensors)} DTensors, {len(param_tensors)} Tensors",
|
| 431 |
+
flush=True,
|
| 432 |
+
)
|
| 433 |
+
|
| 434 |
+
if len(param_dtensors) > 0:
|
| 435 |
self.parallel(
|
| 436 |
+
param_dtensors,
|
| 437 |
group,
|
| 438 |
lr=lr,
|
| 439 |
+
weight_decay=weight_decay,
|
| 440 |
momentum=momentum,
|
| 441 |
)
|
| 442 |
+
|
| 443 |
+
if len(param_tensors) > 0:
|
| 444 |
self.base(
|
| 445 |
+
param_tensors,
|
| 446 |
group,
|
| 447 |
lr=lr,
|
| 448 |
+
weight_decay=weight_decay,
|
| 449 |
momentum=momentum,
|
| 450 |
)
|
| 451 |
|
|
|
|
| 457 |
lr = group["lr"]
|
| 458 |
beta1, beta2 = group["adamw_betas"]
|
| 459 |
eps = group["adamw_eps"]
|
| 460 |
+
weight_decay = group["weight_decay"]
|
| 461 |
|
| 462 |
for p in params:
|
| 463 |
g = p.grad
|
build/torch27-cxx11-cu126-x86_64-linux/optimizer/_ops.py
CHANGED
|
@@ -1,9 +1,9 @@
|
|
| 1 |
import torch
|
| 2 |
-
from . import
|
| 3 |
-
ops = torch.ops.
|
| 4 |
|
| 5 |
def add_op_namespace_prefix(op_name: str):
|
| 6 |
"""
|
| 7 |
Prefix op by namespace.
|
| 8 |
"""
|
| 9 |
-
return f"
|
|
|
|
| 1 |
import torch
|
| 2 |
+
from . import _optimizer_64757cb_dirty
|
| 3 |
+
ops = torch.ops._optimizer_64757cb_dirty
|
| 4 |
|
| 5 |
def add_op_namespace_prefix(op_name: str):
|
| 6 |
"""
|
| 7 |
Prefix op by namespace.
|
| 8 |
"""
|
| 9 |
+
return f"_optimizer_64757cb_dirty::{op_name}"
|
build/torch27-cxx11-cu126-x86_64-linux/optimizer/{_optimizer_036642a_dirty.abi3.so β _optimizer_64757cb_dirty.abi3.so}
RENAMED
|
@@ -1,3 +1,3 @@
|
|
| 1 |
version https://git-lfs.github.com/spec/v1
|
| 2 |
-
oid sha256:
|
| 3 |
size 1824256
|
|
|
|
| 1 |
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:c9d303b11a0a82e9c51c7b32c7555bd351ec375b1879bf46eb64ea4aff32100f
|
| 3 |
size 1824256
|
build/torch27-cxx11-cu126-x86_64-linux/optimizer/muon.py
CHANGED
|
@@ -3,7 +3,7 @@ from dataclasses import dataclass
|
|
| 3 |
|
| 4 |
import torch
|
| 5 |
import torch.distributed as dist
|
| 6 |
-
from torch.distributed._tensor import DTensor
|
| 7 |
|
| 8 |
|
| 9 |
# This code snippet is a modified version adapted from the following GitHub repositories:
|
|
@@ -103,7 +103,7 @@ def _compute_u(state, steps, rank, compute_stream):
|
|
| 103 |
|
| 104 |
|
| 105 |
@torch.no_grad()
|
| 106 |
-
def _scatter(p, state, lr,
|
| 107 |
u = state.computed_u
|
| 108 |
mesh = p.device_mesh
|
| 109 |
|
|
@@ -131,10 +131,14 @@ def _scatter(p, state, lr, wd, rank, comm_stream):
|
|
| 131 |
placements=p.placements,
|
| 132 |
device_mesh=mesh,
|
| 133 |
)
|
| 134 |
-
p.data.mul_(1 - lr *
|
| 135 |
p.data.add_(u, alpha=-lr)
|
| 136 |
|
| 137 |
|
|
|
|
|
|
|
|
|
|
|
|
|
| 138 |
class Muon(torch.optim.Optimizer):
|
| 139 |
"""
|
| 140 |
Muon - MomentUm Orthogonalized by Newton-schulz
|
|
@@ -159,18 +163,18 @@ class Muon(torch.optim.Optimizer):
|
|
| 159 |
adamw_lr: The learning rate for the internal AdamW.
|
| 160 |
adamw_betas: The betas for the internal AdamW.
|
| 161 |
adamw_eps: The epsilon for the internal AdamW.
|
| 162 |
-
|
| 163 |
"""
|
| 164 |
|
| 165 |
def __init__(
|
| 166 |
self,
|
| 167 |
model,
|
| 168 |
-
is_muon_func,
|
| 169 |
lr=1e-3,
|
| 170 |
momentum=0.95,
|
| 171 |
nesterov=True,
|
| 172 |
ns_steps=5,
|
| 173 |
-
|
| 174 |
adamw_betas=(0.9, 0.95),
|
| 175 |
adamw_eps=1e-8,
|
| 176 |
none_grad=True,
|
|
@@ -178,7 +182,7 @@ class Muon(torch.optim.Optimizer):
|
|
| 178 |
):
|
| 179 |
defaults = dict(
|
| 180 |
lr=lr,
|
| 181 |
-
|
| 182 |
momentum=momentum,
|
| 183 |
nesterov=nesterov,
|
| 184 |
ns_steps=ns_steps,
|
|
@@ -272,7 +276,7 @@ class Muon(torch.optim.Optimizer):
|
|
| 272 |
|
| 273 |
return param_to_state, ordered_params
|
| 274 |
|
| 275 |
-
def base(self, params, group, lr,
|
| 276 |
# generate weight updates in distributed fashion
|
| 277 |
for p in params:
|
| 278 |
g = p.grad
|
|
@@ -299,7 +303,7 @@ class Muon(torch.optim.Optimizer):
|
|
| 299 |
adjusted_lr = self.adjust_lr_for_muon(lr, p.shape)
|
| 300 |
|
| 301 |
# apply weight decay
|
| 302 |
-
p.data.mul_(1 - lr *
|
| 303 |
|
| 304 |
# apply update
|
| 305 |
p.data.add_(u, alpha=-adjusted_lr)
|
|
@@ -317,15 +321,15 @@ class Muon(torch.optim.Optimizer):
|
|
| 317 |
g = buf
|
| 318 |
return g
|
| 319 |
|
| 320 |
-
def _update_p(self, p, u, lr,
|
| 321 |
# scale update
|
| 322 |
adjusted_lr = self.adjust_lr_for_muon(lr, p.shape)
|
| 323 |
# apply weight decay
|
| 324 |
-
p.data.mul_(1 - lr *
|
| 325 |
# apply update
|
| 326 |
p.data.add_(u, alpha=-adjusted_lr)
|
| 327 |
|
| 328 |
-
def parallel(self, params, group, lr,
|
| 329 |
"""
|
| 330 |
Perform a parallel optimization step using Muon.
|
| 331 |
"""
|
|
@@ -364,7 +368,9 @@ class Muon(torch.optim.Optimizer):
|
|
| 364 |
for p in ordered_params[start_idx : start_idx + chunk_size]:
|
| 365 |
state = param_to_state[id(p)]
|
| 366 |
adjusted_lr = self.adjust_lr_for_muon(lr, p.shape)
|
| 367 |
-
_scatter(
|
|
|
|
|
|
|
| 368 |
|
| 369 |
chunk_size = params[0].device_mesh.mesh.numel()
|
| 370 |
|
|
@@ -398,23 +404,48 @@ class Muon(torch.optim.Optimizer):
|
|
| 398 |
|
| 399 |
params = [p for p in group["params"] if self.state[p]["use_muon"]]
|
| 400 |
lr = group["lr"]
|
| 401 |
-
|
| 402 |
momentum = group["momentum"]
|
| 403 |
|
| 404 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 405 |
self.parallel(
|
| 406 |
-
|
| 407 |
group,
|
| 408 |
lr=lr,
|
| 409 |
-
|
| 410 |
momentum=momentum,
|
| 411 |
)
|
| 412 |
-
|
|
|
|
| 413 |
self.base(
|
| 414 |
-
|
| 415 |
group,
|
| 416 |
lr=lr,
|
| 417 |
-
|
| 418 |
momentum=momentum,
|
| 419 |
)
|
| 420 |
|
|
@@ -426,7 +457,7 @@ class Muon(torch.optim.Optimizer):
|
|
| 426 |
lr = group["lr"]
|
| 427 |
beta1, beta2 = group["adamw_betas"]
|
| 428 |
eps = group["adamw_eps"]
|
| 429 |
-
weight_decay = group["
|
| 430 |
|
| 431 |
for p in params:
|
| 432 |
g = p.grad
|
|
|
|
| 3 |
|
| 4 |
import torch
|
| 5 |
import torch.distributed as dist
|
| 6 |
+
from torch.distributed._tensor import DTensor, Replicate
|
| 7 |
|
| 8 |
|
| 9 |
# This code snippet is a modified version adapted from the following GitHub repositories:
|
|
|
|
| 103 |
|
| 104 |
|
| 105 |
@torch.no_grad()
|
| 106 |
+
def _scatter(p, state, lr, weight_decay, rank, comm_stream):
|
| 107 |
u = state.computed_u
|
| 108 |
mesh = p.device_mesh
|
| 109 |
|
|
|
|
| 131 |
placements=p.placements,
|
| 132 |
device_mesh=mesh,
|
| 133 |
)
|
| 134 |
+
p.data.mul_(1 - lr * weight_decay)
|
| 135 |
p.data.add_(u, alpha=-lr)
|
| 136 |
|
| 137 |
|
| 138 |
+
def default_is_muon(x, name):
|
| 139 |
+
return x.ndim >= 2 and "embed_tokens" not in name and "lm_head" not in name
|
| 140 |
+
|
| 141 |
+
|
| 142 |
class Muon(torch.optim.Optimizer):
|
| 143 |
"""
|
| 144 |
Muon - MomentUm Orthogonalized by Newton-schulz
|
|
|
|
| 163 |
adamw_lr: The learning rate for the internal AdamW.
|
| 164 |
adamw_betas: The betas for the internal AdamW.
|
| 165 |
adamw_eps: The epsilon for the internal AdamW.
|
| 166 |
+
adamw_weight_decay: The weight decay for the internal AdamW.
|
| 167 |
"""
|
| 168 |
|
| 169 |
def __init__(
|
| 170 |
self,
|
| 171 |
model,
|
| 172 |
+
is_muon_func=default_is_muon,
|
| 173 |
lr=1e-3,
|
| 174 |
momentum=0.95,
|
| 175 |
nesterov=True,
|
| 176 |
ns_steps=5,
|
| 177 |
+
weight_decay=0.1,
|
| 178 |
adamw_betas=(0.9, 0.95),
|
| 179 |
adamw_eps=1e-8,
|
| 180 |
none_grad=True,
|
|
|
|
| 182 |
):
|
| 183 |
defaults = dict(
|
| 184 |
lr=lr,
|
| 185 |
+
weight_decay=weight_decay,
|
| 186 |
momentum=momentum,
|
| 187 |
nesterov=nesterov,
|
| 188 |
ns_steps=ns_steps,
|
|
|
|
| 276 |
|
| 277 |
return param_to_state, ordered_params
|
| 278 |
|
| 279 |
+
def base(self, params, group, lr, weight_decay, momentum):
|
| 280 |
# generate weight updates in distributed fashion
|
| 281 |
for p in params:
|
| 282 |
g = p.grad
|
|
|
|
| 303 |
adjusted_lr = self.adjust_lr_for_muon(lr, p.shape)
|
| 304 |
|
| 305 |
# apply weight decay
|
| 306 |
+
p.data.mul_(1 - lr * weight_decay)
|
| 307 |
|
| 308 |
# apply update
|
| 309 |
p.data.add_(u, alpha=-adjusted_lr)
|
|
|
|
| 321 |
g = buf
|
| 322 |
return g
|
| 323 |
|
| 324 |
+
def _update_p(self, p, u, lr, weight_decay):
|
| 325 |
# scale update
|
| 326 |
adjusted_lr = self.adjust_lr_for_muon(lr, p.shape)
|
| 327 |
# apply weight decay
|
| 328 |
+
p.data.mul_(1 - lr * weight_decay)
|
| 329 |
# apply update
|
| 330 |
p.data.add_(u, alpha=-adjusted_lr)
|
| 331 |
|
| 332 |
+
def parallel(self, params, group, lr, weight_decay, momentum):
|
| 333 |
"""
|
| 334 |
Perform a parallel optimization step using Muon.
|
| 335 |
"""
|
|
|
|
| 368 |
for p in ordered_params[start_idx : start_idx + chunk_size]:
|
| 369 |
state = param_to_state[id(p)]
|
| 370 |
adjusted_lr = self.adjust_lr_for_muon(lr, p.shape)
|
| 371 |
+
_scatter(
|
| 372 |
+
p, state, adjusted_lr, weight_decay, self.rank, self.comm_stream
|
| 373 |
+
)
|
| 374 |
|
| 375 |
chunk_size = params[0].device_mesh.mesh.numel()
|
| 376 |
|
|
|
|
| 404 |
|
| 405 |
params = [p for p in group["params"] if self.state[p]["use_muon"]]
|
| 406 |
lr = group["lr"]
|
| 407 |
+
weight_decay = group["weight_decay"]
|
| 408 |
momentum = group["momentum"]
|
| 409 |
|
| 410 |
+
param_dtensors = []
|
| 411 |
+
param_tensors = []
|
| 412 |
+
|
| 413 |
+
for p in params:
|
| 414 |
+
if p is None or p.grad is None:
|
| 415 |
+
continue
|
| 416 |
+
if isinstance(p.data, DTensor):
|
| 417 |
+
if all(
|
| 418 |
+
isinstance(placement, Replicate) for placement in p.placements
|
| 419 |
+
):
|
| 420 |
+
param_tensors.append(p)
|
| 421 |
+
else:
|
| 422 |
+
param_dtensors.append(p)
|
| 423 |
+
elif isinstance(p.data, torch.Tensor):
|
| 424 |
+
param_tensors.append(p)
|
| 425 |
+
else:
|
| 426 |
+
raise TypeError(f"Unsupported parameter type: {type(p.data)}")
|
| 427 |
+
|
| 428 |
+
if self.debug:
|
| 429 |
+
print(
|
| 430 |
+
f"[Muon] {len(param_dtensors)} DTensors, {len(param_tensors)} Tensors",
|
| 431 |
+
flush=True,
|
| 432 |
+
)
|
| 433 |
+
|
| 434 |
+
if len(param_dtensors) > 0:
|
| 435 |
self.parallel(
|
| 436 |
+
param_dtensors,
|
| 437 |
group,
|
| 438 |
lr=lr,
|
| 439 |
+
weight_decay=weight_decay,
|
| 440 |
momentum=momentum,
|
| 441 |
)
|
| 442 |
+
|
| 443 |
+
if len(param_tensors) > 0:
|
| 444 |
self.base(
|
| 445 |
+
param_tensors,
|
| 446 |
group,
|
| 447 |
lr=lr,
|
| 448 |
+
weight_decay=weight_decay,
|
| 449 |
momentum=momentum,
|
| 450 |
)
|
| 451 |
|
|
|
|
| 457 |
lr = group["lr"]
|
| 458 |
beta1, beta2 = group["adamw_betas"]
|
| 459 |
eps = group["adamw_eps"]
|
| 460 |
+
weight_decay = group["weight_decay"]
|
| 461 |
|
| 462 |
for p in params:
|
| 463 |
g = p.grad
|
build/torch27-cxx11-cu128-x86_64-linux/optimizer/_ops.py
CHANGED
|
@@ -1,9 +1,9 @@
|
|
| 1 |
import torch
|
| 2 |
-
from . import
|
| 3 |
-
ops = torch.ops.
|
| 4 |
|
| 5 |
def add_op_namespace_prefix(op_name: str):
|
| 6 |
"""
|
| 7 |
Prefix op by namespace.
|
| 8 |
"""
|
| 9 |
-
return f"
|
|
|
|
| 1 |
import torch
|
| 2 |
+
from . import _optimizer_64757cb_dirty
|
| 3 |
+
ops = torch.ops._optimizer_64757cb_dirty
|
| 4 |
|
| 5 |
def add_op_namespace_prefix(op_name: str):
|
| 6 |
"""
|
| 7 |
Prefix op by namespace.
|
| 8 |
"""
|
| 9 |
+
return f"_optimizer_64757cb_dirty::{op_name}"
|
build/torch27-cxx11-cu128-x86_64-linux/optimizer/{_optimizer_036642a_dirty.abi3.so β _optimizer_64757cb_dirty.abi3.so}
RENAMED
|
@@ -1,3 +1,3 @@
|
|
| 1 |
version https://git-lfs.github.com/spec/v1
|
| 2 |
-
oid sha256:
|
| 3 |
size 1883352
|
|
|
|
| 1 |
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:b4334fe8d7157c2a9c85217cb981692daf9eb4c6d3f205d0fd41d4b717daefa1
|
| 3 |
size 1883352
|
build/torch27-cxx11-cu128-x86_64-linux/optimizer/muon.py
CHANGED
|
@@ -3,7 +3,7 @@ from dataclasses import dataclass
|
|
| 3 |
|
| 4 |
import torch
|
| 5 |
import torch.distributed as dist
|
| 6 |
-
from torch.distributed._tensor import DTensor
|
| 7 |
|
| 8 |
|
| 9 |
# This code snippet is a modified version adapted from the following GitHub repositories:
|
|
@@ -103,7 +103,7 @@ def _compute_u(state, steps, rank, compute_stream):
|
|
| 103 |
|
| 104 |
|
| 105 |
@torch.no_grad()
|
| 106 |
-
def _scatter(p, state, lr,
|
| 107 |
u = state.computed_u
|
| 108 |
mesh = p.device_mesh
|
| 109 |
|
|
@@ -131,10 +131,14 @@ def _scatter(p, state, lr, wd, rank, comm_stream):
|
|
| 131 |
placements=p.placements,
|
| 132 |
device_mesh=mesh,
|
| 133 |
)
|
| 134 |
-
p.data.mul_(1 - lr *
|
| 135 |
p.data.add_(u, alpha=-lr)
|
| 136 |
|
| 137 |
|
|
|
|
|
|
|
|
|
|
|
|
|
| 138 |
class Muon(torch.optim.Optimizer):
|
| 139 |
"""
|
| 140 |
Muon - MomentUm Orthogonalized by Newton-schulz
|
|
@@ -159,18 +163,18 @@ class Muon(torch.optim.Optimizer):
|
|
| 159 |
adamw_lr: The learning rate for the internal AdamW.
|
| 160 |
adamw_betas: The betas for the internal AdamW.
|
| 161 |
adamw_eps: The epsilon for the internal AdamW.
|
| 162 |
-
|
| 163 |
"""
|
| 164 |
|
| 165 |
def __init__(
|
| 166 |
self,
|
| 167 |
model,
|
| 168 |
-
is_muon_func,
|
| 169 |
lr=1e-3,
|
| 170 |
momentum=0.95,
|
| 171 |
nesterov=True,
|
| 172 |
ns_steps=5,
|
| 173 |
-
|
| 174 |
adamw_betas=(0.9, 0.95),
|
| 175 |
adamw_eps=1e-8,
|
| 176 |
none_grad=True,
|
|
@@ -178,7 +182,7 @@ class Muon(torch.optim.Optimizer):
|
|
| 178 |
):
|
| 179 |
defaults = dict(
|
| 180 |
lr=lr,
|
| 181 |
-
|
| 182 |
momentum=momentum,
|
| 183 |
nesterov=nesterov,
|
| 184 |
ns_steps=ns_steps,
|
|
@@ -272,7 +276,7 @@ class Muon(torch.optim.Optimizer):
|
|
| 272 |
|
| 273 |
return param_to_state, ordered_params
|
| 274 |
|
| 275 |
-
def base(self, params, group, lr,
|
| 276 |
# generate weight updates in distributed fashion
|
| 277 |
for p in params:
|
| 278 |
g = p.grad
|
|
@@ -299,7 +303,7 @@ class Muon(torch.optim.Optimizer):
|
|
| 299 |
adjusted_lr = self.adjust_lr_for_muon(lr, p.shape)
|
| 300 |
|
| 301 |
# apply weight decay
|
| 302 |
-
p.data.mul_(1 - lr *
|
| 303 |
|
| 304 |
# apply update
|
| 305 |
p.data.add_(u, alpha=-adjusted_lr)
|
|
@@ -317,15 +321,15 @@ class Muon(torch.optim.Optimizer):
|
|
| 317 |
g = buf
|
| 318 |
return g
|
| 319 |
|
| 320 |
-
def _update_p(self, p, u, lr,
|
| 321 |
# scale update
|
| 322 |
adjusted_lr = self.adjust_lr_for_muon(lr, p.shape)
|
| 323 |
# apply weight decay
|
| 324 |
-
p.data.mul_(1 - lr *
|
| 325 |
# apply update
|
| 326 |
p.data.add_(u, alpha=-adjusted_lr)
|
| 327 |
|
| 328 |
-
def parallel(self, params, group, lr,
|
| 329 |
"""
|
| 330 |
Perform a parallel optimization step using Muon.
|
| 331 |
"""
|
|
@@ -364,7 +368,9 @@ class Muon(torch.optim.Optimizer):
|
|
| 364 |
for p in ordered_params[start_idx : start_idx + chunk_size]:
|
| 365 |
state = param_to_state[id(p)]
|
| 366 |
adjusted_lr = self.adjust_lr_for_muon(lr, p.shape)
|
| 367 |
-
_scatter(
|
|
|
|
|
|
|
| 368 |
|
| 369 |
chunk_size = params[0].device_mesh.mesh.numel()
|
| 370 |
|
|
@@ -398,23 +404,48 @@ class Muon(torch.optim.Optimizer):
|
|
| 398 |
|
| 399 |
params = [p for p in group["params"] if self.state[p]["use_muon"]]
|
| 400 |
lr = group["lr"]
|
| 401 |
-
|
| 402 |
momentum = group["momentum"]
|
| 403 |
|
| 404 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 405 |
self.parallel(
|
| 406 |
-
|
| 407 |
group,
|
| 408 |
lr=lr,
|
| 409 |
-
|
| 410 |
momentum=momentum,
|
| 411 |
)
|
| 412 |
-
|
|
|
|
| 413 |
self.base(
|
| 414 |
-
|
| 415 |
group,
|
| 416 |
lr=lr,
|
| 417 |
-
|
| 418 |
momentum=momentum,
|
| 419 |
)
|
| 420 |
|
|
@@ -426,7 +457,7 @@ class Muon(torch.optim.Optimizer):
|
|
| 426 |
lr = group["lr"]
|
| 427 |
beta1, beta2 = group["adamw_betas"]
|
| 428 |
eps = group["adamw_eps"]
|
| 429 |
-
weight_decay = group["
|
| 430 |
|
| 431 |
for p in params:
|
| 432 |
g = p.grad
|
|
|
|
| 3 |
|
| 4 |
import torch
|
| 5 |
import torch.distributed as dist
|
| 6 |
+
from torch.distributed._tensor import DTensor, Replicate
|
| 7 |
|
| 8 |
|
| 9 |
# This code snippet is a modified version adapted from the following GitHub repositories:
|
|
|
|
| 103 |
|
| 104 |
|
| 105 |
@torch.no_grad()
|
| 106 |
+
def _scatter(p, state, lr, weight_decay, rank, comm_stream):
|
| 107 |
u = state.computed_u
|
| 108 |
mesh = p.device_mesh
|
| 109 |
|
|
|
|
| 131 |
placements=p.placements,
|
| 132 |
device_mesh=mesh,
|
| 133 |
)
|
| 134 |
+
p.data.mul_(1 - lr * weight_decay)
|
| 135 |
p.data.add_(u, alpha=-lr)
|
| 136 |
|
| 137 |
|
| 138 |
+
def default_is_muon(x, name):
|
| 139 |
+
return x.ndim >= 2 and "embed_tokens" not in name and "lm_head" not in name
|
| 140 |
+
|
| 141 |
+
|
| 142 |
class Muon(torch.optim.Optimizer):
|
| 143 |
"""
|
| 144 |
Muon - MomentUm Orthogonalized by Newton-schulz
|
|
|
|
| 163 |
adamw_lr: The learning rate for the internal AdamW.
|
| 164 |
adamw_betas: The betas for the internal AdamW.
|
| 165 |
adamw_eps: The epsilon for the internal AdamW.
|
| 166 |
+
adamw_weight_decay: The weight decay for the internal AdamW.
|
| 167 |
"""
|
| 168 |
|
| 169 |
def __init__(
|
| 170 |
self,
|
| 171 |
model,
|
| 172 |
+
is_muon_func=default_is_muon,
|
| 173 |
lr=1e-3,
|
| 174 |
momentum=0.95,
|
| 175 |
nesterov=True,
|
| 176 |
ns_steps=5,
|
| 177 |
+
weight_decay=0.1,
|
| 178 |
adamw_betas=(0.9, 0.95),
|
| 179 |
adamw_eps=1e-8,
|
| 180 |
none_grad=True,
|
|
|
|
| 182 |
):
|
| 183 |
defaults = dict(
|
| 184 |
lr=lr,
|
| 185 |
+
weight_decay=weight_decay,
|
| 186 |
momentum=momentum,
|
| 187 |
nesterov=nesterov,
|
| 188 |
ns_steps=ns_steps,
|
|
|
|
| 276 |
|
| 277 |
return param_to_state, ordered_params
|
| 278 |
|
| 279 |
+
def base(self, params, group, lr, weight_decay, momentum):
|
| 280 |
# generate weight updates in distributed fashion
|
| 281 |
for p in params:
|
| 282 |
g = p.grad
|
|
|
|
| 303 |
adjusted_lr = self.adjust_lr_for_muon(lr, p.shape)
|
| 304 |
|
| 305 |
# apply weight decay
|
| 306 |
+
p.data.mul_(1 - lr * weight_decay)
|
| 307 |
|
| 308 |
# apply update
|
| 309 |
p.data.add_(u, alpha=-adjusted_lr)
|
|
|
|
| 321 |
g = buf
|
| 322 |
return g
|
| 323 |
|
| 324 |
+
def _update_p(self, p, u, lr, weight_decay):
|
| 325 |
# scale update
|
| 326 |
adjusted_lr = self.adjust_lr_for_muon(lr, p.shape)
|
| 327 |
# apply weight decay
|
| 328 |
+
p.data.mul_(1 - lr * weight_decay)
|
| 329 |
# apply update
|
| 330 |
p.data.add_(u, alpha=-adjusted_lr)
|
| 331 |
|
| 332 |
+
def parallel(self, params, group, lr, weight_decay, momentum):
|
| 333 |
"""
|
| 334 |
Perform a parallel optimization step using Muon.
|
| 335 |
"""
|
|
|
|
| 368 |
for p in ordered_params[start_idx : start_idx + chunk_size]:
|
| 369 |
state = param_to_state[id(p)]
|
| 370 |
adjusted_lr = self.adjust_lr_for_muon(lr, p.shape)
|
| 371 |
+
_scatter(
|
| 372 |
+
p, state, adjusted_lr, weight_decay, self.rank, self.comm_stream
|
| 373 |
+
)
|
| 374 |
|
| 375 |
chunk_size = params[0].device_mesh.mesh.numel()
|
| 376 |
|
|
|
|
| 404 |
|
| 405 |
params = [p for p in group["params"] if self.state[p]["use_muon"]]
|
| 406 |
lr = group["lr"]
|
| 407 |
+
weight_decay = group["weight_decay"]
|
| 408 |
momentum = group["momentum"]
|
| 409 |
|
| 410 |
+
param_dtensors = []
|
| 411 |
+
param_tensors = []
|
| 412 |
+
|
| 413 |
+
for p in params:
|
| 414 |
+
if p is None or p.grad is None:
|
| 415 |
+
continue
|
| 416 |
+
if isinstance(p.data, DTensor):
|
| 417 |
+
if all(
|
| 418 |
+
isinstance(placement, Replicate) for placement in p.placements
|
| 419 |
+
):
|
| 420 |
+
param_tensors.append(p)
|
| 421 |
+
else:
|
| 422 |
+
param_dtensors.append(p)
|
| 423 |
+
elif isinstance(p.data, torch.Tensor):
|
| 424 |
+
param_tensors.append(p)
|
| 425 |
+
else:
|
| 426 |
+
raise TypeError(f"Unsupported parameter type: {type(p.data)}")
|
| 427 |
+
|
| 428 |
+
if self.debug:
|
| 429 |
+
print(
|
| 430 |
+
f"[Muon] {len(param_dtensors)} DTensors, {len(param_tensors)} Tensors",
|
| 431 |
+
flush=True,
|
| 432 |
+
)
|
| 433 |
+
|
| 434 |
+
if len(param_dtensors) > 0:
|
| 435 |
self.parallel(
|
| 436 |
+
param_dtensors,
|
| 437 |
group,
|
| 438 |
lr=lr,
|
| 439 |
+
weight_decay=weight_decay,
|
| 440 |
momentum=momentum,
|
| 441 |
)
|
| 442 |
+
|
| 443 |
+
if len(param_tensors) > 0:
|
| 444 |
self.base(
|
| 445 |
+
param_tensors,
|
| 446 |
group,
|
| 447 |
lr=lr,
|
| 448 |
+
weight_decay=weight_decay,
|
| 449 |
momentum=momentum,
|
| 450 |
)
|
| 451 |
|
|
|
|
| 457 |
lr = group["lr"]
|
| 458 |
beta1, beta2 = group["adamw_betas"]
|
| 459 |
eps = group["adamw_eps"]
|
| 460 |
+
weight_decay = group["weight_decay"]
|
| 461 |
|
| 462 |
for p in params:
|
| 463 |
g = p.grad
|
build/torch27-cxx11-rocm63-x86_64-linux/optimizer/__pycache__/__init__.cpython-312.pyc
ADDED
|
Binary file (252 Bytes). View file
|
|
|
build/torch27-cxx11-rocm63-x86_64-linux/optimizer/__pycache__/muon.cpython-312.pyc
ADDED
|
Binary file (22 kB). View file
|
|
|
build/torch27-cxx11-rocm63-x86_64-linux/optimizer/_ops.py
CHANGED
|
@@ -1,9 +1,9 @@
|
|
| 1 |
import torch
|
| 2 |
-
from . import
|
| 3 |
-
ops = torch.ops.
|
| 4 |
|
| 5 |
def add_op_namespace_prefix(op_name: str):
|
| 6 |
"""
|
| 7 |
Prefix op by namespace.
|
| 8 |
"""
|
| 9 |
-
return f"
|
|
|
|
| 1 |
import torch
|
| 2 |
+
from . import _optimizer_64757cb_dirty
|
| 3 |
+
ops = torch.ops._optimizer_64757cb_dirty
|
| 4 |
|
| 5 |
def add_op_namespace_prefix(op_name: str):
|
| 6 |
"""
|
| 7 |
Prefix op by namespace.
|
| 8 |
"""
|
| 9 |
+
return f"_optimizer_64757cb_dirty::{op_name}"
|
build/torch27-cxx11-rocm63-x86_64-linux/optimizer/{_optimizer_036642a_dirty.abi3.so β _optimizer_64757cb_dirty.abi3.so}
RENAMED
|
@@ -1,3 +1,3 @@
|
|
| 1 |
version https://git-lfs.github.com/spec/v1
|
| 2 |
-
oid sha256:
|
| 3 |
size 1749648
|
|
|
|
| 1 |
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:272fcc69e3774fa43e222efefceeaca97a8c84ee3f1fe528a7478a8e80a70976
|
| 3 |
size 1749648
|
build/torch27-cxx11-rocm63-x86_64-linux/optimizer/muon.py
CHANGED
|
@@ -3,7 +3,7 @@ from dataclasses import dataclass
|
|
| 3 |
|
| 4 |
import torch
|
| 5 |
import torch.distributed as dist
|
| 6 |
-
from torch.distributed._tensor import DTensor
|
| 7 |
|
| 8 |
|
| 9 |
# This code snippet is a modified version adapted from the following GitHub repositories:
|
|
@@ -103,7 +103,7 @@ def _compute_u(state, steps, rank, compute_stream):
|
|
| 103 |
|
| 104 |
|
| 105 |
@torch.no_grad()
|
| 106 |
-
def _scatter(p, state, lr,
|
| 107 |
u = state.computed_u
|
| 108 |
mesh = p.device_mesh
|
| 109 |
|
|
@@ -131,10 +131,14 @@ def _scatter(p, state, lr, wd, rank, comm_stream):
|
|
| 131 |
placements=p.placements,
|
| 132 |
device_mesh=mesh,
|
| 133 |
)
|
| 134 |
-
p.data.mul_(1 - lr *
|
| 135 |
p.data.add_(u, alpha=-lr)
|
| 136 |
|
| 137 |
|
|
|
|
|
|
|
|
|
|
|
|
|
| 138 |
class Muon(torch.optim.Optimizer):
|
| 139 |
"""
|
| 140 |
Muon - MomentUm Orthogonalized by Newton-schulz
|
|
@@ -159,18 +163,18 @@ class Muon(torch.optim.Optimizer):
|
|
| 159 |
adamw_lr: The learning rate for the internal AdamW.
|
| 160 |
adamw_betas: The betas for the internal AdamW.
|
| 161 |
adamw_eps: The epsilon for the internal AdamW.
|
| 162 |
-
|
| 163 |
"""
|
| 164 |
|
| 165 |
def __init__(
|
| 166 |
self,
|
| 167 |
model,
|
| 168 |
-
is_muon_func,
|
| 169 |
lr=1e-3,
|
| 170 |
momentum=0.95,
|
| 171 |
nesterov=True,
|
| 172 |
ns_steps=5,
|
| 173 |
-
|
| 174 |
adamw_betas=(0.9, 0.95),
|
| 175 |
adamw_eps=1e-8,
|
| 176 |
none_grad=True,
|
|
@@ -178,7 +182,7 @@ class Muon(torch.optim.Optimizer):
|
|
| 178 |
):
|
| 179 |
defaults = dict(
|
| 180 |
lr=lr,
|
| 181 |
-
|
| 182 |
momentum=momentum,
|
| 183 |
nesterov=nesterov,
|
| 184 |
ns_steps=ns_steps,
|
|
@@ -272,7 +276,7 @@ class Muon(torch.optim.Optimizer):
|
|
| 272 |
|
| 273 |
return param_to_state, ordered_params
|
| 274 |
|
| 275 |
-
def base(self, params, group, lr,
|
| 276 |
# generate weight updates in distributed fashion
|
| 277 |
for p in params:
|
| 278 |
g = p.grad
|
|
@@ -299,7 +303,7 @@ class Muon(torch.optim.Optimizer):
|
|
| 299 |
adjusted_lr = self.adjust_lr_for_muon(lr, p.shape)
|
| 300 |
|
| 301 |
# apply weight decay
|
| 302 |
-
p.data.mul_(1 - lr *
|
| 303 |
|
| 304 |
# apply update
|
| 305 |
p.data.add_(u, alpha=-adjusted_lr)
|
|
@@ -317,15 +321,15 @@ class Muon(torch.optim.Optimizer):
|
|
| 317 |
g = buf
|
| 318 |
return g
|
| 319 |
|
| 320 |
-
def _update_p(self, p, u, lr,
|
| 321 |
# scale update
|
| 322 |
adjusted_lr = self.adjust_lr_for_muon(lr, p.shape)
|
| 323 |
# apply weight decay
|
| 324 |
-
p.data.mul_(1 - lr *
|
| 325 |
# apply update
|
| 326 |
p.data.add_(u, alpha=-adjusted_lr)
|
| 327 |
|
| 328 |
-
def parallel(self, params, group, lr,
|
| 329 |
"""
|
| 330 |
Perform a parallel optimization step using Muon.
|
| 331 |
"""
|
|
@@ -364,7 +368,9 @@ class Muon(torch.optim.Optimizer):
|
|
| 364 |
for p in ordered_params[start_idx : start_idx + chunk_size]:
|
| 365 |
state = param_to_state[id(p)]
|
| 366 |
adjusted_lr = self.adjust_lr_for_muon(lr, p.shape)
|
| 367 |
-
_scatter(
|
|
|
|
|
|
|
| 368 |
|
| 369 |
chunk_size = params[0].device_mesh.mesh.numel()
|
| 370 |
|
|
@@ -398,23 +404,48 @@ class Muon(torch.optim.Optimizer):
|
|
| 398 |
|
| 399 |
params = [p for p in group["params"] if self.state[p]["use_muon"]]
|
| 400 |
lr = group["lr"]
|
| 401 |
-
|
| 402 |
momentum = group["momentum"]
|
| 403 |
|
| 404 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 405 |
self.parallel(
|
| 406 |
-
|
| 407 |
group,
|
| 408 |
lr=lr,
|
| 409 |
-
|
| 410 |
momentum=momentum,
|
| 411 |
)
|
| 412 |
-
|
|
|
|
| 413 |
self.base(
|
| 414 |
-
|
| 415 |
group,
|
| 416 |
lr=lr,
|
| 417 |
-
|
| 418 |
momentum=momentum,
|
| 419 |
)
|
| 420 |
|
|
@@ -426,7 +457,7 @@ class Muon(torch.optim.Optimizer):
|
|
| 426 |
lr = group["lr"]
|
| 427 |
beta1, beta2 = group["adamw_betas"]
|
| 428 |
eps = group["adamw_eps"]
|
| 429 |
-
weight_decay = group["
|
| 430 |
|
| 431 |
for p in params:
|
| 432 |
g = p.grad
|
|
|
|
| 3 |
|
| 4 |
import torch
|
| 5 |
import torch.distributed as dist
|
| 6 |
+
from torch.distributed._tensor import DTensor, Replicate
|
| 7 |
|
| 8 |
|
| 9 |
# This code snippet is a modified version adapted from the following GitHub repositories:
|
|
|
|
| 103 |
|
| 104 |
|
| 105 |
@torch.no_grad()
|
| 106 |
+
def _scatter(p, state, lr, weight_decay, rank, comm_stream):
|
| 107 |
u = state.computed_u
|
| 108 |
mesh = p.device_mesh
|
| 109 |
|
|
|
|
| 131 |
placements=p.placements,
|
| 132 |
device_mesh=mesh,
|
| 133 |
)
|
| 134 |
+
p.data.mul_(1 - lr * weight_decay)
|
| 135 |
p.data.add_(u, alpha=-lr)
|
| 136 |
|
| 137 |
|
| 138 |
+
def default_is_muon(x, name):
|
| 139 |
+
return x.ndim >= 2 and "embed_tokens" not in name and "lm_head" not in name
|
| 140 |
+
|
| 141 |
+
|
| 142 |
class Muon(torch.optim.Optimizer):
|
| 143 |
"""
|
| 144 |
Muon - MomentUm Orthogonalized by Newton-schulz
|
|
|
|
| 163 |
adamw_lr: The learning rate for the internal AdamW.
|
| 164 |
adamw_betas: The betas for the internal AdamW.
|
| 165 |
adamw_eps: The epsilon for the internal AdamW.
|
| 166 |
+
adamw_weight_decay: The weight decay for the internal AdamW.
|
| 167 |
"""
|
| 168 |
|
| 169 |
def __init__(
|
| 170 |
self,
|
| 171 |
model,
|
| 172 |
+
is_muon_func=default_is_muon,
|
| 173 |
lr=1e-3,
|
| 174 |
momentum=0.95,
|
| 175 |
nesterov=True,
|
| 176 |
ns_steps=5,
|
| 177 |
+
weight_decay=0.1,
|
| 178 |
adamw_betas=(0.9, 0.95),
|
| 179 |
adamw_eps=1e-8,
|
| 180 |
none_grad=True,
|
|
|
|
| 182 |
):
|
| 183 |
defaults = dict(
|
| 184 |
lr=lr,
|
| 185 |
+
weight_decay=weight_decay,
|
| 186 |
momentum=momentum,
|
| 187 |
nesterov=nesterov,
|
| 188 |
ns_steps=ns_steps,
|
|
|
|
| 276 |
|
| 277 |
return param_to_state, ordered_params
|
| 278 |
|
| 279 |
+
def base(self, params, group, lr, weight_decay, momentum):
|
| 280 |
# generate weight updates in distributed fashion
|
| 281 |
for p in params:
|
| 282 |
g = p.grad
|
|
|
|
| 303 |
adjusted_lr = self.adjust_lr_for_muon(lr, p.shape)
|
| 304 |
|
| 305 |
# apply weight decay
|
| 306 |
+
p.data.mul_(1 - lr * weight_decay)
|
| 307 |
|
| 308 |
# apply update
|
| 309 |
p.data.add_(u, alpha=-adjusted_lr)
|
|
|
|
| 321 |
g = buf
|
| 322 |
return g
|
| 323 |
|
| 324 |
+
def _update_p(self, p, u, lr, weight_decay):
|
| 325 |
# scale update
|
| 326 |
adjusted_lr = self.adjust_lr_for_muon(lr, p.shape)
|
| 327 |
# apply weight decay
|
| 328 |
+
p.data.mul_(1 - lr * weight_decay)
|
| 329 |
# apply update
|
| 330 |
p.data.add_(u, alpha=-adjusted_lr)
|
| 331 |
|
| 332 |
+
def parallel(self, params, group, lr, weight_decay, momentum):
|
| 333 |
"""
|
| 334 |
Perform a parallel optimization step using Muon.
|
| 335 |
"""
|
|
|
|
| 368 |
for p in ordered_params[start_idx : start_idx + chunk_size]:
|
| 369 |
state = param_to_state[id(p)]
|
| 370 |
adjusted_lr = self.adjust_lr_for_muon(lr, p.shape)
|
| 371 |
+
_scatter(
|
| 372 |
+
p, state, adjusted_lr, weight_decay, self.rank, self.comm_stream
|
| 373 |
+
)
|
| 374 |
|
| 375 |
chunk_size = params[0].device_mesh.mesh.numel()
|
| 376 |
|
|
|
|
| 404 |
|
| 405 |
params = [p for p in group["params"] if self.state[p]["use_muon"]]
|
| 406 |
lr = group["lr"]
|
| 407 |
+
weight_decay = group["weight_decay"]
|
| 408 |
momentum = group["momentum"]
|
| 409 |
|
| 410 |
+
param_dtensors = []
|
| 411 |
+
param_tensors = []
|
| 412 |
+
|
| 413 |
+
for p in params:
|
| 414 |
+
if p is None or p.grad is None:
|
| 415 |
+
continue
|
| 416 |
+
if isinstance(p.data, DTensor):
|
| 417 |
+
if all(
|
| 418 |
+
isinstance(placement, Replicate) for placement in p.placements
|
| 419 |
+
):
|
| 420 |
+
param_tensors.append(p)
|
| 421 |
+
else:
|
| 422 |
+
param_dtensors.append(p)
|
| 423 |
+
elif isinstance(p.data, torch.Tensor):
|
| 424 |
+
param_tensors.append(p)
|
| 425 |
+
else:
|
| 426 |
+
raise TypeError(f"Unsupported parameter type: {type(p.data)}")
|
| 427 |
+
|
| 428 |
+
if self.debug:
|
| 429 |
+
print(
|
| 430 |
+
f"[Muon] {len(param_dtensors)} DTensors, {len(param_tensors)} Tensors",
|
| 431 |
+
flush=True,
|
| 432 |
+
)
|
| 433 |
+
|
| 434 |
+
if len(param_dtensors) > 0:
|
| 435 |
self.parallel(
|
| 436 |
+
param_dtensors,
|
| 437 |
group,
|
| 438 |
lr=lr,
|
| 439 |
+
weight_decay=weight_decay,
|
| 440 |
momentum=momentum,
|
| 441 |
)
|
| 442 |
+
|
| 443 |
+
if len(param_tensors) > 0:
|
| 444 |
self.base(
|
| 445 |
+
param_tensors,
|
| 446 |
group,
|
| 447 |
lr=lr,
|
| 448 |
+
weight_decay=weight_decay,
|
| 449 |
momentum=momentum,
|
| 450 |
)
|
| 451 |
|
|
|
|
| 457 |
lr = group["lr"]
|
| 458 |
beta1, beta2 = group["adamw_betas"]
|
| 459 |
eps = group["adamw_eps"]
|
| 460 |
+
weight_decay = group["weight_decay"]
|
| 461 |
|
| 462 |
for p in params:
|
| 463 |
g = p.grad
|
torch-ext/optimizer/muon.py
CHANGED
|
@@ -3,7 +3,7 @@ from dataclasses import dataclass
|
|
| 3 |
|
| 4 |
import torch
|
| 5 |
import torch.distributed as dist
|
| 6 |
-
from torch.distributed._tensor import DTensor
|
| 7 |
|
| 8 |
|
| 9 |
# This code snippet is a modified version adapted from the following GitHub repositories:
|
|
@@ -103,7 +103,7 @@ def _compute_u(state, steps, rank, compute_stream):
|
|
| 103 |
|
| 104 |
|
| 105 |
@torch.no_grad()
|
| 106 |
-
def _scatter(p, state, lr,
|
| 107 |
u = state.computed_u
|
| 108 |
mesh = p.device_mesh
|
| 109 |
|
|
@@ -131,10 +131,14 @@ def _scatter(p, state, lr, wd, rank, comm_stream):
|
|
| 131 |
placements=p.placements,
|
| 132 |
device_mesh=mesh,
|
| 133 |
)
|
| 134 |
-
p.data.mul_(1 - lr *
|
| 135 |
p.data.add_(u, alpha=-lr)
|
| 136 |
|
| 137 |
|
|
|
|
|
|
|
|
|
|
|
|
|
| 138 |
class Muon(torch.optim.Optimizer):
|
| 139 |
"""
|
| 140 |
Muon - MomentUm Orthogonalized by Newton-schulz
|
|
@@ -159,18 +163,18 @@ class Muon(torch.optim.Optimizer):
|
|
| 159 |
adamw_lr: The learning rate for the internal AdamW.
|
| 160 |
adamw_betas: The betas for the internal AdamW.
|
| 161 |
adamw_eps: The epsilon for the internal AdamW.
|
| 162 |
-
|
| 163 |
"""
|
| 164 |
|
| 165 |
def __init__(
|
| 166 |
self,
|
| 167 |
model,
|
| 168 |
-
is_muon_func,
|
| 169 |
lr=1e-3,
|
| 170 |
momentum=0.95,
|
| 171 |
nesterov=True,
|
| 172 |
ns_steps=5,
|
| 173 |
-
|
| 174 |
adamw_betas=(0.9, 0.95),
|
| 175 |
adamw_eps=1e-8,
|
| 176 |
none_grad=True,
|
|
@@ -178,7 +182,7 @@ class Muon(torch.optim.Optimizer):
|
|
| 178 |
):
|
| 179 |
defaults = dict(
|
| 180 |
lr=lr,
|
| 181 |
-
|
| 182 |
momentum=momentum,
|
| 183 |
nesterov=nesterov,
|
| 184 |
ns_steps=ns_steps,
|
|
@@ -272,7 +276,7 @@ class Muon(torch.optim.Optimizer):
|
|
| 272 |
|
| 273 |
return param_to_state, ordered_params
|
| 274 |
|
| 275 |
-
def base(self, params, group, lr,
|
| 276 |
# generate weight updates in distributed fashion
|
| 277 |
for p in params:
|
| 278 |
g = p.grad
|
|
@@ -299,7 +303,7 @@ class Muon(torch.optim.Optimizer):
|
|
| 299 |
adjusted_lr = self.adjust_lr_for_muon(lr, p.shape)
|
| 300 |
|
| 301 |
# apply weight decay
|
| 302 |
-
p.data.mul_(1 - lr *
|
| 303 |
|
| 304 |
# apply update
|
| 305 |
p.data.add_(u, alpha=-adjusted_lr)
|
|
@@ -317,15 +321,15 @@ class Muon(torch.optim.Optimizer):
|
|
| 317 |
g = buf
|
| 318 |
return g
|
| 319 |
|
| 320 |
-
def _update_p(self, p, u, lr,
|
| 321 |
# scale update
|
| 322 |
adjusted_lr = self.adjust_lr_for_muon(lr, p.shape)
|
| 323 |
# apply weight decay
|
| 324 |
-
p.data.mul_(1 - lr *
|
| 325 |
# apply update
|
| 326 |
p.data.add_(u, alpha=-adjusted_lr)
|
| 327 |
|
| 328 |
-
def parallel(self, params, group, lr,
|
| 329 |
"""
|
| 330 |
Perform a parallel optimization step using Muon.
|
| 331 |
"""
|
|
@@ -364,7 +368,9 @@ class Muon(torch.optim.Optimizer):
|
|
| 364 |
for p in ordered_params[start_idx : start_idx + chunk_size]:
|
| 365 |
state = param_to_state[id(p)]
|
| 366 |
adjusted_lr = self.adjust_lr_for_muon(lr, p.shape)
|
| 367 |
-
_scatter(
|
|
|
|
|
|
|
| 368 |
|
| 369 |
chunk_size = params[0].device_mesh.mesh.numel()
|
| 370 |
|
|
@@ -398,23 +404,48 @@ class Muon(torch.optim.Optimizer):
|
|
| 398 |
|
| 399 |
params = [p for p in group["params"] if self.state[p]["use_muon"]]
|
| 400 |
lr = group["lr"]
|
| 401 |
-
|
| 402 |
momentum = group["momentum"]
|
| 403 |
|
| 404 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 405 |
self.parallel(
|
| 406 |
-
|
| 407 |
group,
|
| 408 |
lr=lr,
|
| 409 |
-
|
| 410 |
momentum=momentum,
|
| 411 |
)
|
| 412 |
-
|
|
|
|
| 413 |
self.base(
|
| 414 |
-
|
| 415 |
group,
|
| 416 |
lr=lr,
|
| 417 |
-
|
| 418 |
momentum=momentum,
|
| 419 |
)
|
| 420 |
|
|
@@ -426,7 +457,7 @@ class Muon(torch.optim.Optimizer):
|
|
| 426 |
lr = group["lr"]
|
| 427 |
beta1, beta2 = group["adamw_betas"]
|
| 428 |
eps = group["adamw_eps"]
|
| 429 |
-
weight_decay = group["
|
| 430 |
|
| 431 |
for p in params:
|
| 432 |
g = p.grad
|
|
|
|
| 3 |
|
| 4 |
import torch
|
| 5 |
import torch.distributed as dist
|
| 6 |
+
from torch.distributed._tensor import DTensor, Replicate
|
| 7 |
|
| 8 |
|
| 9 |
# This code snippet is a modified version adapted from the following GitHub repositories:
|
|
|
|
| 103 |
|
| 104 |
|
| 105 |
@torch.no_grad()
|
| 106 |
+
def _scatter(p, state, lr, weight_decay, rank, comm_stream):
|
| 107 |
u = state.computed_u
|
| 108 |
mesh = p.device_mesh
|
| 109 |
|
|
|
|
| 131 |
placements=p.placements,
|
| 132 |
device_mesh=mesh,
|
| 133 |
)
|
| 134 |
+
p.data.mul_(1 - lr * weight_decay)
|
| 135 |
p.data.add_(u, alpha=-lr)
|
| 136 |
|
| 137 |
|
| 138 |
+
def default_is_muon(x, name):
|
| 139 |
+
return x.ndim >= 2 and "embed_tokens" not in name and "lm_head" not in name
|
| 140 |
+
|
| 141 |
+
|
| 142 |
class Muon(torch.optim.Optimizer):
|
| 143 |
"""
|
| 144 |
Muon - MomentUm Orthogonalized by Newton-schulz
|
|
|
|
| 163 |
adamw_lr: The learning rate for the internal AdamW.
|
| 164 |
adamw_betas: The betas for the internal AdamW.
|
| 165 |
adamw_eps: The epsilon for the internal AdamW.
|
| 166 |
+
adamw_weight_decay: The weight decay for the internal AdamW.
|
| 167 |
"""
|
| 168 |
|
| 169 |
def __init__(
|
| 170 |
self,
|
| 171 |
model,
|
| 172 |
+
is_muon_func=default_is_muon,
|
| 173 |
lr=1e-3,
|
| 174 |
momentum=0.95,
|
| 175 |
nesterov=True,
|
| 176 |
ns_steps=5,
|
| 177 |
+
weight_decay=0.1,
|
| 178 |
adamw_betas=(0.9, 0.95),
|
| 179 |
adamw_eps=1e-8,
|
| 180 |
none_grad=True,
|
|
|
|
| 182 |
):
|
| 183 |
defaults = dict(
|
| 184 |
lr=lr,
|
| 185 |
+
weight_decay=weight_decay,
|
| 186 |
momentum=momentum,
|
| 187 |
nesterov=nesterov,
|
| 188 |
ns_steps=ns_steps,
|
|
|
|
| 276 |
|
| 277 |
return param_to_state, ordered_params
|
| 278 |
|
| 279 |
+
def base(self, params, group, lr, weight_decay, momentum):
|
| 280 |
# generate weight updates in distributed fashion
|
| 281 |
for p in params:
|
| 282 |
g = p.grad
|
|
|
|
| 303 |
adjusted_lr = self.adjust_lr_for_muon(lr, p.shape)
|
| 304 |
|
| 305 |
# apply weight decay
|
| 306 |
+
p.data.mul_(1 - lr * weight_decay)
|
| 307 |
|
| 308 |
# apply update
|
| 309 |
p.data.add_(u, alpha=-adjusted_lr)
|
|
|
|
| 321 |
g = buf
|
| 322 |
return g
|
| 323 |
|
| 324 |
+
def _update_p(self, p, u, lr, weight_decay):
|
| 325 |
# scale update
|
| 326 |
adjusted_lr = self.adjust_lr_for_muon(lr, p.shape)
|
| 327 |
# apply weight decay
|
| 328 |
+
p.data.mul_(1 - lr * weight_decay)
|
| 329 |
# apply update
|
| 330 |
p.data.add_(u, alpha=-adjusted_lr)
|
| 331 |
|
| 332 |
+
def parallel(self, params, group, lr, weight_decay, momentum):
|
| 333 |
"""
|
| 334 |
Perform a parallel optimization step using Muon.
|
| 335 |
"""
|
|
|
|
| 368 |
for p in ordered_params[start_idx : start_idx + chunk_size]:
|
| 369 |
state = param_to_state[id(p)]
|
| 370 |
adjusted_lr = self.adjust_lr_for_muon(lr, p.shape)
|
| 371 |
+
_scatter(
|
| 372 |
+
p, state, adjusted_lr, weight_decay, self.rank, self.comm_stream
|
| 373 |
+
)
|
| 374 |
|
| 375 |
chunk_size = params[0].device_mesh.mesh.numel()
|
| 376 |
|
|
|
|
| 404 |
|
| 405 |
params = [p for p in group["params"] if self.state[p]["use_muon"]]
|
| 406 |
lr = group["lr"]
|
| 407 |
+
weight_decay = group["weight_decay"]
|
| 408 |
momentum = group["momentum"]
|
| 409 |
|
| 410 |
+
param_dtensors = []
|
| 411 |
+
param_tensors = []
|
| 412 |
+
|
| 413 |
+
for p in params:
|
| 414 |
+
if p is None or p.grad is None:
|
| 415 |
+
continue
|
| 416 |
+
if isinstance(p.data, DTensor):
|
| 417 |
+
if all(
|
| 418 |
+
isinstance(placement, Replicate) for placement in p.placements
|
| 419 |
+
):
|
| 420 |
+
param_tensors.append(p)
|
| 421 |
+
else:
|
| 422 |
+
param_dtensors.append(p)
|
| 423 |
+
elif isinstance(p.data, torch.Tensor):
|
| 424 |
+
param_tensors.append(p)
|
| 425 |
+
else:
|
| 426 |
+
raise TypeError(f"Unsupported parameter type: {type(p.data)}")
|
| 427 |
+
|
| 428 |
+
if self.debug:
|
| 429 |
+
print(
|
| 430 |
+
f"[Muon] {len(param_dtensors)} DTensors, {len(param_tensors)} Tensors",
|
| 431 |
+
flush=True,
|
| 432 |
+
)
|
| 433 |
+
|
| 434 |
+
if len(param_dtensors) > 0:
|
| 435 |
self.parallel(
|
| 436 |
+
param_dtensors,
|
| 437 |
group,
|
| 438 |
lr=lr,
|
| 439 |
+
weight_decay=weight_decay,
|
| 440 |
momentum=momentum,
|
| 441 |
)
|
| 442 |
+
|
| 443 |
+
if len(param_tensors) > 0:
|
| 444 |
self.base(
|
| 445 |
+
param_tensors,
|
| 446 |
group,
|
| 447 |
lr=lr,
|
| 448 |
+
weight_decay=weight_decay,
|
| 449 |
momentum=momentum,
|
| 450 |
)
|
| 451 |
|
|
|
|
| 457 |
lr = group["lr"]
|
| 458 |
beta1, beta2 = group["adamw_betas"]
|
| 459 |
eps = group["adamw_eps"]
|
| 460 |
+
weight_decay = group["weight_decay"]
|
| 461 |
|
| 462 |
for p in params:
|
| 463 |
g = p.grad
|