diff --git a/build.toml b/build.toml index b80854db0a67cdde4e5c3dcb8d95f18704812383..ebabc676bfe40eb07e2bb447ff0c17605ac42844 100644 --- a/build.toml +++ b/build.toml @@ -1,23 +1,33 @@ [general] name = "optimizer" -universal = false - -[torch] -src = [ - "torch-ext/torch_binding.cpp", - "torch-ext/torch_binding.h", +backends = [ + "cuda", + "rocm", ] -[kernel.activation] -backend = "rocm" +[torch] src = [ - "optimizer/dummy.cu", + "torch-ext/torch_binding.cpp", + "torch-ext/torch_binding.h", ] -depends = [ "torch" ] -[kernel.activation_cuda] +[kernel.optimizer] backend = "cuda" -src = [ - "optimizer/dummy.cu", +depends = ["torch"] +src = ["optimizer/dummy.cu"] + +[kernel.optimizer_rocm] +backend = "rocm" +rocm-archs = [ + "gfx906", + "gfx908", + "gfx90a", + "gfx940", + "gfx941", + "gfx942", + "gfx1030", + "gfx1100", + "gfx1101", ] -depends = [ "torch" ] +depends = ["torch"] +src = ["optimizer/dummy.cu"] diff --git a/build/torch210-cxx11-cu126-x86_64-linux/__init__.py b/build/torch210-cxx11-cu126-x86_64-linux/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..239c7a65f8293e7d0df28f05fce645af56d628c0 --- /dev/null +++ b/build/torch210-cxx11-cu126-x86_64-linux/__init__.py @@ -0,0 +1,5 @@ +from .muon import Muon + +__all__ = [ + "Muon", +] diff --git a/build/torch210-cxx11-cu126-x86_64-linux/_ops.py b/build/torch210-cxx11-cu126-x86_64-linux/_ops.py new file mode 100644 index 0000000000000000000000000000000000000000..e6f6fcf6280e969b1761926112147d3146e27b59 --- /dev/null +++ b/build/torch210-cxx11-cu126-x86_64-linux/_ops.py @@ -0,0 +1,9 @@ +import torch +from . import _optimizer_06a260a_dirty +ops = torch.ops._optimizer_06a260a_dirty + +def add_op_namespace_prefix(op_name: str): + """ + Prefix op by namespace. 
+ """ + return f"_optimizer_06a260a_dirty::{op_name}" \ No newline at end of file diff --git a/build/torch210-cxx11-cu126-x86_64-linux/_optimizer_06a260a_dirty.abi3.so b/build/torch210-cxx11-cu126-x86_64-linux/_optimizer_06a260a_dirty.abi3.so new file mode 100755 index 0000000000000000000000000000000000000000..6015e5b4ea5da27e0002b298d9a1ab55142f88ab --- /dev/null +++ b/build/torch210-cxx11-cu126-x86_64-linux/_optimizer_06a260a_dirty.abi3.so @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:5384da54f22f488e0646e09915b821b3235cb404b163a570aa377967f853e3cf +size 1940944 diff --git a/build/torch28-cxx11-cu128-x86_64-linux/optimizer/distributed/utils.py b/build/torch210-cxx11-cu126-x86_64-linux/distributed/utils.py similarity index 96% rename from build/torch28-cxx11-cu128-x86_64-linux/optimizer/distributed/utils.py rename to build/torch210-cxx11-cu126-x86_64-linux/distributed/utils.py index 0b4b58bfb329b1c015129e4c4fc99f7bfa2ab30a..6d5843506c13d9d31603b2b4e30c1c91d0baab28 100644 --- a/build/torch28-cxx11-cu128-x86_64-linux/optimizer/distributed/utils.py +++ b/build/torch210-cxx11-cu126-x86_64-linux/distributed/utils.py @@ -50,7 +50,7 @@ def get_slices_of_dtensor( raise NotImplementedError( f"Dimension size {dim_size} is not divisible " f"by number of ranks {num_ranks} for shard " - f"placement on dim {dim}.") + f"placement on dim {dim}. 
(shape: {target.shape})") shard_size = dim_size // num_ranks @@ -64,7 +64,8 @@ def get_slices_of_dtensor( return tuple(slices) -_ranks_to_dist_cache: dict[tuple[int, ...], tuple[DeviceMesh, ProcessGroup]] = dict() +_ranks_to_dist_cache: dict[tuple[int, ...], tuple[DeviceMesh, + ProcessGroup]] = dict() def construct_shard_mesh( diff --git a/build/torch28-cxx11-cu126-x86_64-linux/optimizer/matmul_transpose_triton.py b/build/torch210-cxx11-cu126-x86_64-linux/matmul_transpose_triton.py similarity index 100% rename from build/torch28-cxx11-cu126-x86_64-linux/optimizer/matmul_transpose_triton.py rename to build/torch210-cxx11-cu126-x86_64-linux/matmul_transpose_triton.py diff --git a/build/torch210-cxx11-cu126-x86_64-linux/metadata.json b/build/torch210-cxx11-cu126-x86_64-linux/metadata.json new file mode 100644 index 0000000000000000000000000000000000000000..76bafa5f33b6818aa6bb4cab04be811b87519b44 --- /dev/null +++ b/build/torch210-cxx11-cu126-x86_64-linux/metadata.json @@ -0,0 +1 @@ +{"python-depends":[]} \ No newline at end of file diff --git a/build/torch28-cxx11-cu128-x86_64-linux/optimizer/muon.py b/build/torch210-cxx11-cu126-x86_64-linux/muon.py similarity index 92% rename from build/torch28-cxx11-cu128-x86_64-linux/optimizer/muon.py rename to build/torch210-cxx11-cu126-x86_64-linux/muon.py index cfbcca71741be70048bfd290c62148b2aceda631..dbf25575f185ff379789482068e4ecf55b9455a9 100644 --- a/build/torch28-cxx11-cu128-x86_64-linux/optimizer/muon.py +++ b/build/torch210-cxx11-cu126-x86_64-linux/muon.py @@ -583,6 +583,7 @@ class Muon(torch.optim.Optimizer): Use shard ranks * DEFAULT_CHUNK_SIZE_RATIO when -1 is specified. use_distributed_muon: Use distributed muon by Liu et al. (2024). For testing purpose only. 
+ small_param_numel_threshold: Threshold for classifying parameters as small and falling back to distributed Muon """ def __init__(self, @@ -604,7 +605,8 @@ class Muon(torch.optim.Optimizer): }, warmup_step=5, chunk_size=-1, - use_distributed_muon=False): + use_distributed_muon=False, + small_param_numel_threshold=65536): defaults = dict( lr=lr, weight_decay=weight_decay, @@ -637,6 +639,7 @@ class Muon(torch.optim.Optimizer): self.warmup_step = warmup_step self.chunk_size = chunk_size self.use_distributed_muon = use_distributed_muon + self.small_param_numel_threshold = small_param_numel_threshold def _calc_flops(self, G, steps): assert len(G.shape) == 2 @@ -745,16 +748,7 @@ class Muon(torch.optim.Optimizer): g = g.view(g.size(0), -1) assert g is not None - # calc update - state = self.state[p] - if "momentum_buffer" not in state: - state["momentum_buffer"] = torch.zeros_like(g) - buf = state["momentum_buffer"] - buf.mul_(momentum).add_(g) - if group["nesterov"]: - g = g.add(buf, alpha=momentum) - else: - g = buf + g = self._update_g(p, g, group, momentum) u = _zeropower_via_newtonschulz5(g.to(COMM_DTYPE), steps=group["ns_steps"]) @@ -780,14 +774,6 @@ class Muon(torch.optim.Optimizer): qk_logits: list[torch.Tensor | DTensor] | None, ): """ Implementation of Distributed Muon by Liu et al. 
""" - if qk_logits is not None: - raise NotImplementedError("QK clipping is not supported yet") - - if isinstance(params[0], DTensor): - shard_mesh, _, shard_placements = construct_shard_mesh( - placements=params[0].placements, - mesh=params[0].device_mesh, - ) for n, p in zip(names, params): g = p.grad @@ -797,39 +783,44 @@ class Muon(torch.optim.Optimizer): g = g.view(g.size(0), -1) assert g is not None - # calc update - state = self.state[p] - if "momentum_buffer" not in state: - state["momentum_buffer"] = torch.zeros_like(g) - buf = state["momentum_buffer"] - buf.mul_(momentum).add_(g) - if group["nesterov"]: - g = g.add(buf, alpha=momentum) - else: - g = buf + g = self._update_g(p, g, group, momentum) # Gather G if isinstance(p.data, DTensor): - g = g.full_tensor() - u = _zeropower_via_newtonschulz5(g.to(COMM_DTYPE), - steps=group["ns_steps"]) + g_full = g.full_tensor() + p_full = p.data.full_tensor() + else: + g_full = g + p_full = p + + u_full = _zeropower_via_newtonschulz5(g_full.to(COMM_DTYPE), + steps=group["ns_steps"]) + + adjusted_lr = self.adjust_lr_for_muon(lr, p_full.shape) + Muon._update_p(p_full, u_full, lr, adjusted_lr, weight_decay) + + qk_clip_state = self.get_qk_clip_info(n, qk_logits) + + scales_full = self._compute_scales( + p_full, qk_clip_state) if qk_clip_state is not None else None + + if scales_full is not None: + Muon._qk_clip(p_full, scales_full, qk_clip_state.head_dim) if isinstance(p.data, DTensor): - slices = get_slices_of_dtensor( - target=p, - local_rank=dist.get_rank(), - shard_mesh=shard_mesh, - shard_placements=shard_placements, + ndims = len(p.device_mesh.mesh.shape) + p_replicate = DTensor.from_local( + p_full, + device_mesh=p.device_mesh, + placements=[Replicate() for _ in range(ndims)], ) - u_shard = u[slices] - u = DTensor.from_local( - u_shard, + + p_sharded = p_replicate.redistribute( device_mesh=p.device_mesh, placements=p.placements, ) - adjusted_lr = self.adjust_lr_for_muon(lr, p.shape) - Muon._update_p(p, u, lr, 
adjusted_lr, weight_decay) + p.copy_(p_sharded) def _update_g(self, p, g, group, momentum): # calc update @@ -843,10 +834,14 @@ class Muon(torch.optim.Optimizer): @staticmethod def _update_p(p, u, lr, adjusted_lr, weight_decay): - # apply weight decay - p.data.mul_(1 - lr * weight_decay) - # apply update - p.data.add_(u, alpha=-adjusted_lr) + if isinstance(p, torch.nn.Parameter): + # apply weight decay + p.data.mul_(1 - lr * weight_decay) + # apply update + p.data.add_(u, alpha=-adjusted_lr) + else: + p.mul_(1 - lr * weight_decay) + p.add_(u, alpha=-adjusted_lr) def get_qk_clip_info(self, n, qk_logits): if self.clip_config is None: @@ -903,8 +898,12 @@ class Muon(torch.optim.Optimizer): @staticmethod def _qk_clip(p, scales, head_dim): - W = p.data.view(-1, head_dim, p.data.shape[1]) - W.mul_(scales.view(-1, 1, 1)) + if isinstance(p, torch.nn.Parameter): + W = p.data.view(-1, head_dim, p.data.shape[1]) + W.mul_(scales.view(-1, 1, 1)) + else: + W = p.view(-1, head_dim, p.shape[1]) + W.mul_(scales.view(-1, 1, 1)) def parallel(self, names, params, group, lr, weight_decay, momentum, qk_logits): @@ -1070,10 +1069,14 @@ class Muon(torch.optim.Optimizer): names = group["names"] param_dtensors = [] - param_tensors = [] name_dtensors = [] + + param_tensors = [] name_tensors = [] + param_dtensors_small = [] + name_dtensors_small = [] + if self.use_distributed_muon: self.distributed_muon(names=names, params=params, @@ -1084,6 +1087,8 @@ class Muon(torch.optim.Optimizer): qk_logits=qk_logits) return + # For simplicity, we use distributed Muon for small parameters + # whose number of elements is below a threshold. 
for n, p in zip(names, params): if p is None or p.grad is None: continue @@ -1093,6 +1098,9 @@ class Muon(torch.optim.Optimizer): for placement in p.placements): param_tensors.append(p) name_tensors.append(n) + elif p.data.numel() <= self.small_param_numel_threshold: + param_dtensors_small.append(p) + name_dtensors_small.append(n) else: param_dtensors.append(p) name_dtensors.append(n) @@ -1103,29 +1111,48 @@ class Muon(torch.optim.Optimizer): raise TypeError(f"Unsupported parameter type: {type(p.data)}") logger.debug( - f"[Muon] {len(param_dtensors)} DTensors, {len(param_tensors)} Tensors" - ) - - if len(param_dtensors) > 0: - if not dist.is_initialized(): - raise RuntimeError( - "Parallel Muon requires torch.distributed to be initialized." - ) + f"[Muon] {len(param_dtensors)} DTensors, {len(param_tensors)} Tensors, " + f"{len(param_dtensors_small)} Small DTensors") + def group_dtensors(dtensors, names): # To support different placements, we group parameters by placements # and run parallel Muon on each group. placement_to_params = defaultdict(lambda: ([], [])) # type: dict[tuple[Placement, DeviceMesh], tuple[list[str], list[DTensor]]] - assert len(name_dtensors) == len(param_dtensors) - for n, p in zip(name_dtensors, param_dtensors): + assert len(dtensors) == len(names) + for p, n in zip(dtensors, names): placement_to_params[tuple([p.placements, p.device_mesh])][0].append(n) placement_to_params[tuple([p.placements, p.device_mesh])][1].append(p) + return placement_to_params + + if len(param_dtensors_small) > 0: + if not dist.is_initialized(): + raise RuntimeError( + "Parallel Muon requires torch.distributed to be initialized." + ) + + self.distributed_muon( + params=param_dtensors_small, + names=name_dtensors_small, + group=group, + lr=lr, + weight_decay=weight_decay, + momentum=momentum, + qk_logits=qk_logits, + ) + + if len(param_dtensors) > 0: + if not dist.is_initialized(): + raise RuntimeError( + "Parallel Muon requires torch.distributed to be initialized." 
+ ) - for _, (names, params) in placement_to_params.items(): + dtensor_group = group_dtensors(param_dtensors, name_dtensors) + for _, (names, params) in dtensor_group.items(): self.parallel( names, params, @@ -1215,6 +1242,7 @@ class Muon(torch.optim.Optimizer): for params in placement_to_params.values(): self._step_adamw_params(params, group) + @torch.no_grad def step(self, closure=None, qk_logits=None): """Perform a single optimization step. diff --git a/build/torch210-cxx11-cu126-x86_64-linux/optimizer/__init__.py b/build/torch210-cxx11-cu126-x86_64-linux/optimizer/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..03dbc1afe1cf156661a2b1b22003cd5f599a0309 --- /dev/null +++ b/build/torch210-cxx11-cu126-x86_64-linux/optimizer/__init__.py @@ -0,0 +1,26 @@ +import ctypes +import sys + +import importlib +from pathlib import Path +from types import ModuleType + +def _import_from_path(file_path: Path) -> ModuleType: + # We cannot use the module name as-is, after adding it to `sys.modules`, + # it would also be used for other imports. So, we make a module name that + # depends on the path for it to be unique using the hex-encoded hash of + # the path. 
+ path_hash = "{:x}".format(ctypes.c_size_t(hash(file_path.absolute())).value) + module_name = path_hash + spec = importlib.util.spec_from_file_location(module_name, file_path) + if spec is None: + raise ImportError(f"Cannot load spec for {module_name} from {file_path}") + module = importlib.util.module_from_spec(spec) + if module is None: + raise ImportError(f"Cannot load module {module_name} from spec") + sys.modules[module_name] = module + spec.loader.exec_module(module) # type: ignore + return module + + +globals().update(vars(_import_from_path(Path(__file__).parent.parent / "__init__.py"))) diff --git a/build/torch210-cxx11-cu128-x86_64-linux/__init__.py b/build/torch210-cxx11-cu128-x86_64-linux/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..239c7a65f8293e7d0df28f05fce645af56d628c0 --- /dev/null +++ b/build/torch210-cxx11-cu128-x86_64-linux/__init__.py @@ -0,0 +1,5 @@ +from .muon import Muon + +__all__ = [ + "Muon", +] diff --git a/build/torch210-cxx11-cu128-x86_64-linux/_ops.py b/build/torch210-cxx11-cu128-x86_64-linux/_ops.py new file mode 100644 index 0000000000000000000000000000000000000000..e6f6fcf6280e969b1761926112147d3146e27b59 --- /dev/null +++ b/build/torch210-cxx11-cu128-x86_64-linux/_ops.py @@ -0,0 +1,9 @@ +import torch +from . import _optimizer_06a260a_dirty +ops = torch.ops._optimizer_06a260a_dirty + +def add_op_namespace_prefix(op_name: str): + """ + Prefix op by namespace. 
+ """ + return f"_optimizer_06a260a_dirty::{op_name}" \ No newline at end of file diff --git a/build/torch210-cxx11-cu128-x86_64-linux/_optimizer_06a260a_dirty.abi3.so b/build/torch210-cxx11-cu128-x86_64-linux/_optimizer_06a260a_dirty.abi3.so new file mode 100755 index 0000000000000000000000000000000000000000..a2b4992c68bd2d564fa8ac804bce7a9f9d0a09d9 --- /dev/null +++ b/build/torch210-cxx11-cu128-x86_64-linux/_optimizer_06a260a_dirty.abi3.so @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:976df6a1ec3ec4c462dea18477b56dfb75bcff76f504d55b592ce417931597c0 +size 2004144 diff --git a/build/torch28-cxx11-cu129-x86_64-linux/optimizer/distributed/utils.py b/build/torch210-cxx11-cu128-x86_64-linux/distributed/utils.py similarity index 96% rename from build/torch28-cxx11-cu129-x86_64-linux/optimizer/distributed/utils.py rename to build/torch210-cxx11-cu128-x86_64-linux/distributed/utils.py index 0b4b58bfb329b1c015129e4c4fc99f7bfa2ab30a..6d5843506c13d9d31603b2b4e30c1c91d0baab28 100644 --- a/build/torch28-cxx11-cu129-x86_64-linux/optimizer/distributed/utils.py +++ b/build/torch210-cxx11-cu128-x86_64-linux/distributed/utils.py @@ -50,7 +50,7 @@ def get_slices_of_dtensor( raise NotImplementedError( f"Dimension size {dim_size} is not divisible " f"by number of ranks {num_ranks} for shard " - f"placement on dim {dim}.") + f"placement on dim {dim}. 
(shape: {target.shape})") shard_size = dim_size // num_ranks @@ -64,7 +64,8 @@ def get_slices_of_dtensor( return tuple(slices) -_ranks_to_dist_cache: dict[tuple[int, ...], tuple[DeviceMesh, ProcessGroup]] = dict() +_ranks_to_dist_cache: dict[tuple[int, ...], tuple[DeviceMesh, + ProcessGroup]] = dict() def construct_shard_mesh( diff --git a/build/torch28-cxx11-cu128-x86_64-linux/optimizer/matmul_transpose_triton.py b/build/torch210-cxx11-cu128-x86_64-linux/matmul_transpose_triton.py similarity index 100% rename from build/torch28-cxx11-cu128-x86_64-linux/optimizer/matmul_transpose_triton.py rename to build/torch210-cxx11-cu128-x86_64-linux/matmul_transpose_triton.py diff --git a/build/torch210-cxx11-cu128-x86_64-linux/metadata.json b/build/torch210-cxx11-cu128-x86_64-linux/metadata.json new file mode 100644 index 0000000000000000000000000000000000000000..76bafa5f33b6818aa6bb4cab04be811b87519b44 --- /dev/null +++ b/build/torch210-cxx11-cu128-x86_64-linux/metadata.json @@ -0,0 +1 @@ +{"python-depends":[]} \ No newline at end of file diff --git a/build/torch28-cxx11-cu129-x86_64-linux/optimizer/muon.py b/build/torch210-cxx11-cu128-x86_64-linux/muon.py similarity index 92% rename from build/torch28-cxx11-cu129-x86_64-linux/optimizer/muon.py rename to build/torch210-cxx11-cu128-x86_64-linux/muon.py index cfbcca71741be70048bfd290c62148b2aceda631..dbf25575f185ff379789482068e4ecf55b9455a9 100644 --- a/build/torch28-cxx11-cu129-x86_64-linux/optimizer/muon.py +++ b/build/torch210-cxx11-cu128-x86_64-linux/muon.py @@ -583,6 +583,7 @@ class Muon(torch.optim.Optimizer): Use shard ranks * DEFAULT_CHUNK_SIZE_RATIO when -1 is specified. use_distributed_muon: Use distributed muon by Liu et al. (2024). For testing purpose only. 
+ small_param_numel_threshold: Threshold for classifying parameters as small and falling back to distributed Muon """ def __init__(self, @@ -604,7 +605,8 @@ class Muon(torch.optim.Optimizer): }, warmup_step=5, chunk_size=-1, - use_distributed_muon=False): + use_distributed_muon=False, + small_param_numel_threshold=65536): defaults = dict( lr=lr, weight_decay=weight_decay, @@ -637,6 +639,7 @@ class Muon(torch.optim.Optimizer): self.warmup_step = warmup_step self.chunk_size = chunk_size self.use_distributed_muon = use_distributed_muon + self.small_param_numel_threshold = small_param_numel_threshold def _calc_flops(self, G, steps): assert len(G.shape) == 2 @@ -745,16 +748,7 @@ class Muon(torch.optim.Optimizer): g = g.view(g.size(0), -1) assert g is not None - # calc update - state = self.state[p] - if "momentum_buffer" not in state: - state["momentum_buffer"] = torch.zeros_like(g) - buf = state["momentum_buffer"] - buf.mul_(momentum).add_(g) - if group["nesterov"]: - g = g.add(buf, alpha=momentum) - else: - g = buf + g = self._update_g(p, g, group, momentum) u = _zeropower_via_newtonschulz5(g.to(COMM_DTYPE), steps=group["ns_steps"]) @@ -780,14 +774,6 @@ class Muon(torch.optim.Optimizer): qk_logits: list[torch.Tensor | DTensor] | None, ): """ Implementation of Distributed Muon by Liu et al. 
""" - if qk_logits is not None: - raise NotImplementedError("QK clipping is not supported yet") - - if isinstance(params[0], DTensor): - shard_mesh, _, shard_placements = construct_shard_mesh( - placements=params[0].placements, - mesh=params[0].device_mesh, - ) for n, p in zip(names, params): g = p.grad @@ -797,39 +783,44 @@ class Muon(torch.optim.Optimizer): g = g.view(g.size(0), -1) assert g is not None - # calc update - state = self.state[p] - if "momentum_buffer" not in state: - state["momentum_buffer"] = torch.zeros_like(g) - buf = state["momentum_buffer"] - buf.mul_(momentum).add_(g) - if group["nesterov"]: - g = g.add(buf, alpha=momentum) - else: - g = buf + g = self._update_g(p, g, group, momentum) # Gather G if isinstance(p.data, DTensor): - g = g.full_tensor() - u = _zeropower_via_newtonschulz5(g.to(COMM_DTYPE), - steps=group["ns_steps"]) + g_full = g.full_tensor() + p_full = p.data.full_tensor() + else: + g_full = g + p_full = p + + u_full = _zeropower_via_newtonschulz5(g_full.to(COMM_DTYPE), + steps=group["ns_steps"]) + + adjusted_lr = self.adjust_lr_for_muon(lr, p_full.shape) + Muon._update_p(p_full, u_full, lr, adjusted_lr, weight_decay) + + qk_clip_state = self.get_qk_clip_info(n, qk_logits) + + scales_full = self._compute_scales( + p_full, qk_clip_state) if qk_clip_state is not None else None + + if scales_full is not None: + Muon._qk_clip(p_full, scales_full, qk_clip_state.head_dim) if isinstance(p.data, DTensor): - slices = get_slices_of_dtensor( - target=p, - local_rank=dist.get_rank(), - shard_mesh=shard_mesh, - shard_placements=shard_placements, + ndims = len(p.device_mesh.mesh.shape) + p_replicate = DTensor.from_local( + p_full, + device_mesh=p.device_mesh, + placements=[Replicate() for _ in range(ndims)], ) - u_shard = u[slices] - u = DTensor.from_local( - u_shard, + + p_sharded = p_replicate.redistribute( device_mesh=p.device_mesh, placements=p.placements, ) - adjusted_lr = self.adjust_lr_for_muon(lr, p.shape) - Muon._update_p(p, u, lr, 
adjusted_lr, weight_decay) + p.copy_(p_sharded) def _update_g(self, p, g, group, momentum): # calc update @@ -843,10 +834,14 @@ class Muon(torch.optim.Optimizer): @staticmethod def _update_p(p, u, lr, adjusted_lr, weight_decay): - # apply weight decay - p.data.mul_(1 - lr * weight_decay) - # apply update - p.data.add_(u, alpha=-adjusted_lr) + if isinstance(p, torch.nn.Parameter): + # apply weight decay + p.data.mul_(1 - lr * weight_decay) + # apply update + p.data.add_(u, alpha=-adjusted_lr) + else: + p.mul_(1 - lr * weight_decay) + p.add_(u, alpha=-adjusted_lr) def get_qk_clip_info(self, n, qk_logits): if self.clip_config is None: @@ -903,8 +898,12 @@ class Muon(torch.optim.Optimizer): @staticmethod def _qk_clip(p, scales, head_dim): - W = p.data.view(-1, head_dim, p.data.shape[1]) - W.mul_(scales.view(-1, 1, 1)) + if isinstance(p, torch.nn.Parameter): + W = p.data.view(-1, head_dim, p.data.shape[1]) + W.mul_(scales.view(-1, 1, 1)) + else: + W = p.view(-1, head_dim, p.shape[1]) + W.mul_(scales.view(-1, 1, 1)) def parallel(self, names, params, group, lr, weight_decay, momentum, qk_logits): @@ -1070,10 +1069,14 @@ class Muon(torch.optim.Optimizer): names = group["names"] param_dtensors = [] - param_tensors = [] name_dtensors = [] + + param_tensors = [] name_tensors = [] + param_dtensors_small = [] + name_dtensors_small = [] + if self.use_distributed_muon: self.distributed_muon(names=names, params=params, @@ -1084,6 +1087,8 @@ class Muon(torch.optim.Optimizer): qk_logits=qk_logits) return + # For simplicity, we use distributed Muon for small parameters + # whose number of elements is below a threshold. 
for n, p in zip(names, params): if p is None or p.grad is None: continue @@ -1093,6 +1098,9 @@ class Muon(torch.optim.Optimizer): for placement in p.placements): param_tensors.append(p) name_tensors.append(n) + elif p.data.numel() <= self.small_param_numel_threshold: + param_dtensors_small.append(p) + name_dtensors_small.append(n) else: param_dtensors.append(p) name_dtensors.append(n) @@ -1103,29 +1111,48 @@ class Muon(torch.optim.Optimizer): raise TypeError(f"Unsupported parameter type: {type(p.data)}") logger.debug( - f"[Muon] {len(param_dtensors)} DTensors, {len(param_tensors)} Tensors" - ) - - if len(param_dtensors) > 0: - if not dist.is_initialized(): - raise RuntimeError( - "Parallel Muon requires torch.distributed to be initialized." - ) + f"[Muon] {len(param_dtensors)} DTensors, {len(param_tensors)} Tensors, " + f"{len(param_dtensors_small)} Small DTensors") + def group_dtensors(dtensors, names): # To support different placements, we group parameters by placements # and run parallel Muon on each group. placement_to_params = defaultdict(lambda: ([], [])) # type: dict[tuple[Placement, DeviceMesh], tuple[list[str], list[DTensor]]] - assert len(name_dtensors) == len(param_dtensors) - for n, p in zip(name_dtensors, param_dtensors): + assert len(dtensors) == len(names) + for p, n in zip(dtensors, names): placement_to_params[tuple([p.placements, p.device_mesh])][0].append(n) placement_to_params[tuple([p.placements, p.device_mesh])][1].append(p) + return placement_to_params + + if len(param_dtensors_small) > 0: + if not dist.is_initialized(): + raise RuntimeError( + "Parallel Muon requires torch.distributed to be initialized." + ) + + self.distributed_muon( + params=param_dtensors_small, + names=name_dtensors_small, + group=group, + lr=lr, + weight_decay=weight_decay, + momentum=momentum, + qk_logits=qk_logits, + ) + + if len(param_dtensors) > 0: + if not dist.is_initialized(): + raise RuntimeError( + "Parallel Muon requires torch.distributed to be initialized." 
+ ) - for _, (names, params) in placement_to_params.items(): + dtensor_group = group_dtensors(param_dtensors, name_dtensors) + for _, (names, params) in dtensor_group.items(): self.parallel( names, params, @@ -1215,6 +1242,7 @@ class Muon(torch.optim.Optimizer): for params in placement_to_params.values(): self._step_adamw_params(params, group) + @torch.no_grad def step(self, closure=None, qk_logits=None): """Perform a single optimization step. diff --git a/build/torch210-cxx11-cu128-x86_64-linux/optimizer/__init__.py b/build/torch210-cxx11-cu128-x86_64-linux/optimizer/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..03dbc1afe1cf156661a2b1b22003cd5f599a0309 --- /dev/null +++ b/build/torch210-cxx11-cu128-x86_64-linux/optimizer/__init__.py @@ -0,0 +1,26 @@ +import ctypes +import sys + +import importlib +from pathlib import Path +from types import ModuleType + +def _import_from_path(file_path: Path) -> ModuleType: + # We cannot use the module name as-is, after adding it to `sys.modules`, + # it would also be used for other imports. So, we make a module name that + # depends on the path for it to be unique using the hex-encoded hash of + # the path. 
+ path_hash = "{:x}".format(ctypes.c_size_t(hash(file_path.absolute())).value) + module_name = path_hash + spec = importlib.util.spec_from_file_location(module_name, file_path) + if spec is None: + raise ImportError(f"Cannot load spec for {module_name} from {file_path}") + module = importlib.util.module_from_spec(spec) + if module is None: + raise ImportError(f"Cannot load module {module_name} from spec") + sys.modules[module_name] = module + spec.loader.exec_module(module) # type: ignore + return module + + +globals().update(vars(_import_from_path(Path(__file__).parent.parent / "__init__.py"))) diff --git a/build/torch210-cxx11-cu130-x86_64-linux/__init__.py b/build/torch210-cxx11-cu130-x86_64-linux/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..239c7a65f8293e7d0df28f05fce645af56d628c0 --- /dev/null +++ b/build/torch210-cxx11-cu130-x86_64-linux/__init__.py @@ -0,0 +1,5 @@ +from .muon import Muon + +__all__ = [ + "Muon", +] diff --git a/build/torch210-cxx11-cu130-x86_64-linux/_ops.py b/build/torch210-cxx11-cu130-x86_64-linux/_ops.py new file mode 100644 index 0000000000000000000000000000000000000000..e6f6fcf6280e969b1761926112147d3146e27b59 --- /dev/null +++ b/build/torch210-cxx11-cu130-x86_64-linux/_ops.py @@ -0,0 +1,9 @@ +import torch +from . import _optimizer_06a260a_dirty +ops = torch.ops._optimizer_06a260a_dirty + +def add_op_namespace_prefix(op_name: str): + """ + Prefix op by namespace. 
+ """ + return f"_optimizer_06a260a_dirty::{op_name}" \ No newline at end of file diff --git a/build/torch210-cxx11-cu130-x86_64-linux/_optimizer_06a260a_dirty.abi3.so b/build/torch210-cxx11-cu130-x86_64-linux/_optimizer_06a260a_dirty.abi3.so new file mode 100755 index 0000000000000000000000000000000000000000..62bbc727da9606819a23c43dda20add2be7c1fe3 --- /dev/null +++ b/build/torch210-cxx11-cu130-x86_64-linux/_optimizer_06a260a_dirty.abi3.so @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:330aaa6cb247ba3b5df7a13ced6ef7eff3e5d7a72a0b88f674f948aeaed66ee2 +size 2004728 diff --git a/build/torch28-cxx11-cu126-x86_64-linux/optimizer/distributed/utils.py b/build/torch210-cxx11-cu130-x86_64-linux/distributed/utils.py similarity index 96% rename from build/torch28-cxx11-cu126-x86_64-linux/optimizer/distributed/utils.py rename to build/torch210-cxx11-cu130-x86_64-linux/distributed/utils.py index 0b4b58bfb329b1c015129e4c4fc99f7bfa2ab30a..6d5843506c13d9d31603b2b4e30c1c91d0baab28 100644 --- a/build/torch28-cxx11-cu126-x86_64-linux/optimizer/distributed/utils.py +++ b/build/torch210-cxx11-cu130-x86_64-linux/distributed/utils.py @@ -50,7 +50,7 @@ def get_slices_of_dtensor( raise NotImplementedError( f"Dimension size {dim_size} is not divisible " f"by number of ranks {num_ranks} for shard " - f"placement on dim {dim}.") + f"placement on dim {dim}. 
(shape: {target.shape})") shard_size = dim_size // num_ranks @@ -64,7 +64,8 @@ def get_slices_of_dtensor( return tuple(slices) -_ranks_to_dist_cache: dict[tuple[int, ...], tuple[DeviceMesh, ProcessGroup]] = dict() +_ranks_to_dist_cache: dict[tuple[int, ...], tuple[DeviceMesh, + ProcessGroup]] = dict() def construct_shard_mesh( diff --git a/build/torch28-cxx11-cu129-x86_64-linux/optimizer/matmul_transpose_triton.py b/build/torch210-cxx11-cu130-x86_64-linux/matmul_transpose_triton.py similarity index 100% rename from build/torch28-cxx11-cu129-x86_64-linux/optimizer/matmul_transpose_triton.py rename to build/torch210-cxx11-cu130-x86_64-linux/matmul_transpose_triton.py diff --git a/build/torch210-cxx11-cu130-x86_64-linux/metadata.json b/build/torch210-cxx11-cu130-x86_64-linux/metadata.json new file mode 100644 index 0000000000000000000000000000000000000000..76bafa5f33b6818aa6bb4cab04be811b87519b44 --- /dev/null +++ b/build/torch210-cxx11-cu130-x86_64-linux/metadata.json @@ -0,0 +1 @@ +{"python-depends":[]} \ No newline at end of file diff --git a/build/torch28-cxx11-rocm63-x86_64-linux/optimizer/muon.py b/build/torch210-cxx11-cu130-x86_64-linux/muon.py similarity index 92% rename from build/torch28-cxx11-rocm63-x86_64-linux/optimizer/muon.py rename to build/torch210-cxx11-cu130-x86_64-linux/muon.py index cfbcca71741be70048bfd290c62148b2aceda631..dbf25575f185ff379789482068e4ecf55b9455a9 100644 --- a/build/torch28-cxx11-rocm63-x86_64-linux/optimizer/muon.py +++ b/build/torch210-cxx11-cu130-x86_64-linux/muon.py @@ -583,6 +583,7 @@ class Muon(torch.optim.Optimizer): Use shard ranks * DEFAULT_CHUNK_SIZE_RATIO when -1 is specified. use_distributed_muon: Use distributed muon by Liu et al. (2024). For testing purpose only. 
+ small_param_numel_threshold: Threshold for classifying parameters as small and falling back to distributed Muon """ def __init__(self, @@ -604,7 +605,8 @@ class Muon(torch.optim.Optimizer): }, warmup_step=5, chunk_size=-1, - use_distributed_muon=False): + use_distributed_muon=False, + small_param_numel_threshold=65536): defaults = dict( lr=lr, weight_decay=weight_decay, @@ -637,6 +639,7 @@ class Muon(torch.optim.Optimizer): self.warmup_step = warmup_step self.chunk_size = chunk_size self.use_distributed_muon = use_distributed_muon + self.small_param_numel_threshold = small_param_numel_threshold def _calc_flops(self, G, steps): assert len(G.shape) == 2 @@ -745,16 +748,7 @@ class Muon(torch.optim.Optimizer): g = g.view(g.size(0), -1) assert g is not None - # calc update - state = self.state[p] - if "momentum_buffer" not in state: - state["momentum_buffer"] = torch.zeros_like(g) - buf = state["momentum_buffer"] - buf.mul_(momentum).add_(g) - if group["nesterov"]: - g = g.add(buf, alpha=momentum) - else: - g = buf + g = self._update_g(p, g, group, momentum) u = _zeropower_via_newtonschulz5(g.to(COMM_DTYPE), steps=group["ns_steps"]) @@ -780,14 +774,6 @@ class Muon(torch.optim.Optimizer): qk_logits: list[torch.Tensor | DTensor] | None, ): """ Implementation of Distributed Muon by Liu et al. 
""" - if qk_logits is not None: - raise NotImplementedError("QK clipping is not supported yet") - - if isinstance(params[0], DTensor): - shard_mesh, _, shard_placements = construct_shard_mesh( - placements=params[0].placements, - mesh=params[0].device_mesh, - ) for n, p in zip(names, params): g = p.grad @@ -797,39 +783,44 @@ class Muon(torch.optim.Optimizer): g = g.view(g.size(0), -1) assert g is not None - # calc update - state = self.state[p] - if "momentum_buffer" not in state: - state["momentum_buffer"] = torch.zeros_like(g) - buf = state["momentum_buffer"] - buf.mul_(momentum).add_(g) - if group["nesterov"]: - g = g.add(buf, alpha=momentum) - else: - g = buf + g = self._update_g(p, g, group, momentum) # Gather G if isinstance(p.data, DTensor): - g = g.full_tensor() - u = _zeropower_via_newtonschulz5(g.to(COMM_DTYPE), - steps=group["ns_steps"]) + g_full = g.full_tensor() + p_full = p.data.full_tensor() + else: + g_full = g + p_full = p + + u_full = _zeropower_via_newtonschulz5(g_full.to(COMM_DTYPE), + steps=group["ns_steps"]) + + adjusted_lr = self.adjust_lr_for_muon(lr, p_full.shape) + Muon._update_p(p_full, u_full, lr, adjusted_lr, weight_decay) + + qk_clip_state = self.get_qk_clip_info(n, qk_logits) + + scales_full = self._compute_scales( + p_full, qk_clip_state) if qk_clip_state is not None else None + + if scales_full is not None: + Muon._qk_clip(p_full, scales_full, qk_clip_state.head_dim) if isinstance(p.data, DTensor): - slices = get_slices_of_dtensor( - target=p, - local_rank=dist.get_rank(), - shard_mesh=shard_mesh, - shard_placements=shard_placements, + ndims = len(p.device_mesh.mesh.shape) + p_replicate = DTensor.from_local( + p_full, + device_mesh=p.device_mesh, + placements=[Replicate() for _ in range(ndims)], ) - u_shard = u[slices] - u = DTensor.from_local( - u_shard, + + p_sharded = p_replicate.redistribute( device_mesh=p.device_mesh, placements=p.placements, ) - adjusted_lr = self.adjust_lr_for_muon(lr, p.shape) - Muon._update_p(p, u, lr, 
adjusted_lr, weight_decay) + p.copy_(p_sharded) def _update_g(self, p, g, group, momentum): # calc update @@ -843,10 +834,14 @@ class Muon(torch.optim.Optimizer): @staticmethod def _update_p(p, u, lr, adjusted_lr, weight_decay): - # apply weight decay - p.data.mul_(1 - lr * weight_decay) - # apply update - p.data.add_(u, alpha=-adjusted_lr) + if isinstance(p, torch.nn.Parameter): + # apply weight decay + p.data.mul_(1 - lr * weight_decay) + # apply update + p.data.add_(u, alpha=-adjusted_lr) + else: + p.mul_(1 - lr * weight_decay) + p.add_(u, alpha=-adjusted_lr) def get_qk_clip_info(self, n, qk_logits): if self.clip_config is None: @@ -903,8 +898,12 @@ class Muon(torch.optim.Optimizer): @staticmethod def _qk_clip(p, scales, head_dim): - W = p.data.view(-1, head_dim, p.data.shape[1]) - W.mul_(scales.view(-1, 1, 1)) + if isinstance(p, torch.nn.Parameter): + W = p.data.view(-1, head_dim, p.data.shape[1]) + W.mul_(scales.view(-1, 1, 1)) + else: + W = p.view(-1, head_dim, p.shape[1]) + W.mul_(scales.view(-1, 1, 1)) def parallel(self, names, params, group, lr, weight_decay, momentum, qk_logits): @@ -1070,10 +1069,14 @@ class Muon(torch.optim.Optimizer): names = group["names"] param_dtensors = [] - param_tensors = [] name_dtensors = [] + + param_tensors = [] name_tensors = [] + param_dtensors_small = [] + name_dtensors_small = [] + if self.use_distributed_muon: self.distributed_muon(names=names, params=params, @@ -1084,6 +1087,8 @@ class Muon(torch.optim.Optimizer): qk_logits=qk_logits) return + # For simplicity, we use distributed Muon for small parameters + # whose number of elements is below a threshold. 
for n, p in zip(names, params): if p is None or p.grad is None: continue @@ -1093,6 +1098,9 @@ class Muon(torch.optim.Optimizer): for placement in p.placements): param_tensors.append(p) name_tensors.append(n) + elif p.data.numel() <= self.small_param_numel_threshold: + param_dtensors_small.append(p) + name_dtensors_small.append(n) else: param_dtensors.append(p) name_dtensors.append(n) @@ -1103,29 +1111,48 @@ class Muon(torch.optim.Optimizer): raise TypeError(f"Unsupported parameter type: {type(p.data)}") logger.debug( - f"[Muon] {len(param_dtensors)} DTensors, {len(param_tensors)} Tensors" - ) - - if len(param_dtensors) > 0: - if not dist.is_initialized(): - raise RuntimeError( - "Parallel Muon requires torch.distributed to be initialized." - ) + f"[Muon] {len(param_dtensors)} DTensors, {len(param_tensors)} Tensors, " + f"{len(param_dtensors_small)} Small DTensors") + def group_dtensors(dtensors, names): # To support different placements, we group parameters by placements # and run parallel Muon on each group. placement_to_params = defaultdict(lambda: ([], [])) # type: dict[tuple[Placement, DeviceMesh], tuple[list[str], list[DTensor]]] - assert len(name_dtensors) == len(param_dtensors) - for n, p in zip(name_dtensors, param_dtensors): + assert len(dtensors) == len(names) + for p, n in zip(dtensors, names): placement_to_params[tuple([p.placements, p.device_mesh])][0].append(n) placement_to_params[tuple([p.placements, p.device_mesh])][1].append(p) + return placement_to_params + + if len(param_dtensors_small) > 0: + if not dist.is_initialized(): + raise RuntimeError( + "Parallel Muon requires torch.distributed to be initialized." + ) + + self.distributed_muon( + params=param_dtensors_small, + names=name_dtensors_small, + group=group, + lr=lr, + weight_decay=weight_decay, + momentum=momentum, + qk_logits=qk_logits, + ) + + if len(param_dtensors) > 0: + if not dist.is_initialized(): + raise RuntimeError( + "Parallel Muon requires torch.distributed to be initialized." 
+ ) - for _, (names, params) in placement_to_params.items(): + dtensor_group = group_dtensors(param_dtensors, name_dtensors) + for _, (names, params) in dtensor_group.items(): self.parallel( names, params, @@ -1215,6 +1242,7 @@ class Muon(torch.optim.Optimizer): for params in placement_to_params.values(): self._step_adamw_params(params, group) + @torch.no_grad def step(self, closure=None, qk_logits=None): """Perform a single optimization step. diff --git a/build/torch210-cxx11-cu130-x86_64-linux/optimizer/__init__.py b/build/torch210-cxx11-cu130-x86_64-linux/optimizer/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..03dbc1afe1cf156661a2b1b22003cd5f599a0309 --- /dev/null +++ b/build/torch210-cxx11-cu130-x86_64-linux/optimizer/__init__.py @@ -0,0 +1,26 @@ +import ctypes +import sys + +import importlib +from pathlib import Path +from types import ModuleType + +def _import_from_path(file_path: Path) -> ModuleType: + # We cannot use the module name as-is, after adding it to `sys.modules`, + # it would also be used for other imports. So, we make a module name that + # depends on the path for it to be unique using the hex-encoded hash of + # the path. 
+ path_hash = "{:x}".format(ctypes.c_size_t(hash(file_path.absolute())).value) + module_name = path_hash + spec = importlib.util.spec_from_file_location(module_name, file_path) + if spec is None: + raise ImportError(f"Cannot load spec for {module_name} from {file_path}") + module = importlib.util.module_from_spec(spec) + if module is None: + raise ImportError(f"Cannot load module {module_name} from spec") + sys.modules[module_name] = module + spec.loader.exec_module(module) # type: ignore + return module + + +globals().update(vars(_import_from_path(Path(__file__).parent.parent / "__init__.py"))) diff --git a/build/torch210-cxx11-rocm70-x86_64-linux/__init__.py b/build/torch210-cxx11-rocm70-x86_64-linux/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..239c7a65f8293e7d0df28f05fce645af56d628c0 --- /dev/null +++ b/build/torch210-cxx11-rocm70-x86_64-linux/__init__.py @@ -0,0 +1,5 @@ +from .muon import Muon + +__all__ = [ + "Muon", +] diff --git a/build/torch210-cxx11-rocm70-x86_64-linux/_ops.py b/build/torch210-cxx11-rocm70-x86_64-linux/_ops.py new file mode 100644 index 0000000000000000000000000000000000000000..e6f6fcf6280e969b1761926112147d3146e27b59 --- /dev/null +++ b/build/torch210-cxx11-rocm70-x86_64-linux/_ops.py @@ -0,0 +1,9 @@ +import torch +from . import _optimizer_06a260a_dirty +ops = torch.ops._optimizer_06a260a_dirty + +def add_op_namespace_prefix(op_name: str): + """ + Prefix op by namespace. 
+ """ + return f"_optimizer_06a260a_dirty::{op_name}" \ No newline at end of file diff --git a/build/torch210-cxx11-rocm70-x86_64-linux/_optimizer_06a260a_dirty.abi3.so b/build/torch210-cxx11-rocm70-x86_64-linux/_optimizer_06a260a_dirty.abi3.so new file mode 100755 index 0000000000000000000000000000000000000000..a2bbc913106abe6d784d7634ad119d969ff23a3c --- /dev/null +++ b/build/torch210-cxx11-rocm70-x86_64-linux/_optimizer_06a260a_dirty.abi3.so @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:3562c68e8ee85fc5b268e079150ffff69d52860092d59e44fb9b3c4526c5d497 +size 1866400 diff --git a/build/torch28-cxx11-rocm63-x86_64-linux/optimizer/distributed/utils.py b/build/torch210-cxx11-rocm70-x86_64-linux/distributed/utils.py similarity index 96% rename from build/torch28-cxx11-rocm63-x86_64-linux/optimizer/distributed/utils.py rename to build/torch210-cxx11-rocm70-x86_64-linux/distributed/utils.py index 0b4b58bfb329b1c015129e4c4fc99f7bfa2ab30a..6d5843506c13d9d31603b2b4e30c1c91d0baab28 100644 --- a/build/torch28-cxx11-rocm63-x86_64-linux/optimizer/distributed/utils.py +++ b/build/torch210-cxx11-rocm70-x86_64-linux/distributed/utils.py @@ -50,7 +50,7 @@ def get_slices_of_dtensor( raise NotImplementedError( f"Dimension size {dim_size} is not divisible " f"by number of ranks {num_ranks} for shard " - f"placement on dim {dim}.") + f"placement on dim {dim}. 
(shape: {target.shape})") shard_size = dim_size // num_ranks @@ -64,7 +64,8 @@ def get_slices_of_dtensor( return tuple(slices) -_ranks_to_dist_cache: dict[tuple[int, ...], tuple[DeviceMesh, ProcessGroup]] = dict() +_ranks_to_dist_cache: dict[tuple[int, ...], tuple[DeviceMesh, + ProcessGroup]] = dict() def construct_shard_mesh( diff --git a/build/torch28-cxx11-rocm63-x86_64-linux/optimizer/matmul_transpose_triton.py b/build/torch210-cxx11-rocm70-x86_64-linux/matmul_transpose_triton.py similarity index 100% rename from build/torch28-cxx11-rocm63-x86_64-linux/optimizer/matmul_transpose_triton.py rename to build/torch210-cxx11-rocm70-x86_64-linux/matmul_transpose_triton.py diff --git a/build/torch210-cxx11-rocm70-x86_64-linux/metadata.json b/build/torch210-cxx11-rocm70-x86_64-linux/metadata.json new file mode 100644 index 0000000000000000000000000000000000000000..76bafa5f33b6818aa6bb4cab04be811b87519b44 --- /dev/null +++ b/build/torch210-cxx11-rocm70-x86_64-linux/metadata.json @@ -0,0 +1 @@ +{"python-depends":[]} \ No newline at end of file diff --git a/build/torch28-cxx11-cu126-x86_64-linux/optimizer/muon.py b/build/torch210-cxx11-rocm70-x86_64-linux/muon.py similarity index 92% rename from build/torch28-cxx11-cu126-x86_64-linux/optimizer/muon.py rename to build/torch210-cxx11-rocm70-x86_64-linux/muon.py index cfbcca71741be70048bfd290c62148b2aceda631..dbf25575f185ff379789482068e4ecf55b9455a9 100644 --- a/build/torch28-cxx11-cu126-x86_64-linux/optimizer/muon.py +++ b/build/torch210-cxx11-rocm70-x86_64-linux/muon.py @@ -583,6 +583,7 @@ class Muon(torch.optim.Optimizer): Use shard ranks * DEFAULT_CHUNK_SIZE_RATIO when -1 is specified. use_distributed_muon: Use distributed muon by Liu et al. (2024). For testing purpose only. 
+ small_param_numel_threshold: Threshold for classifying parameters as small and falling back to distributed Muon """ def __init__(self, @@ -604,7 +605,8 @@ class Muon(torch.optim.Optimizer): }, warmup_step=5, chunk_size=-1, - use_distributed_muon=False): + use_distributed_muon=False, + small_param_numel_threshold=65536): defaults = dict( lr=lr, weight_decay=weight_decay, @@ -637,6 +639,7 @@ class Muon(torch.optim.Optimizer): self.warmup_step = warmup_step self.chunk_size = chunk_size self.use_distributed_muon = use_distributed_muon + self.small_param_numel_threshold = small_param_numel_threshold def _calc_flops(self, G, steps): assert len(G.shape) == 2 @@ -745,16 +748,7 @@ class Muon(torch.optim.Optimizer): g = g.view(g.size(0), -1) assert g is not None - # calc update - state = self.state[p] - if "momentum_buffer" not in state: - state["momentum_buffer"] = torch.zeros_like(g) - buf = state["momentum_buffer"] - buf.mul_(momentum).add_(g) - if group["nesterov"]: - g = g.add(buf, alpha=momentum) - else: - g = buf + g = self._update_g(p, g, group, momentum) u = _zeropower_via_newtonschulz5(g.to(COMM_DTYPE), steps=group["ns_steps"]) @@ -780,14 +774,6 @@ class Muon(torch.optim.Optimizer): qk_logits: list[torch.Tensor | DTensor] | None, ): """ Implementation of Distributed Muon by Liu et al. 
""" - if qk_logits is not None: - raise NotImplementedError("QK clipping is not supported yet") - - if isinstance(params[0], DTensor): - shard_mesh, _, shard_placements = construct_shard_mesh( - placements=params[0].placements, - mesh=params[0].device_mesh, - ) for n, p in zip(names, params): g = p.grad @@ -797,39 +783,44 @@ class Muon(torch.optim.Optimizer): g = g.view(g.size(0), -1) assert g is not None - # calc update - state = self.state[p] - if "momentum_buffer" not in state: - state["momentum_buffer"] = torch.zeros_like(g) - buf = state["momentum_buffer"] - buf.mul_(momentum).add_(g) - if group["nesterov"]: - g = g.add(buf, alpha=momentum) - else: - g = buf + g = self._update_g(p, g, group, momentum) # Gather G if isinstance(p.data, DTensor): - g = g.full_tensor() - u = _zeropower_via_newtonschulz5(g.to(COMM_DTYPE), - steps=group["ns_steps"]) + g_full = g.full_tensor() + p_full = p.data.full_tensor() + else: + g_full = g + p_full = p + + u_full = _zeropower_via_newtonschulz5(g_full.to(COMM_DTYPE), + steps=group["ns_steps"]) + + adjusted_lr = self.adjust_lr_for_muon(lr, p_full.shape) + Muon._update_p(p_full, u_full, lr, adjusted_lr, weight_decay) + + qk_clip_state = self.get_qk_clip_info(n, qk_logits) + + scales_full = self._compute_scales( + p_full, qk_clip_state) if qk_clip_state is not None else None + + if scales_full is not None: + Muon._qk_clip(p_full, scales_full, qk_clip_state.head_dim) if isinstance(p.data, DTensor): - slices = get_slices_of_dtensor( - target=p, - local_rank=dist.get_rank(), - shard_mesh=shard_mesh, - shard_placements=shard_placements, + ndims = len(p.device_mesh.mesh.shape) + p_replicate = DTensor.from_local( + p_full, + device_mesh=p.device_mesh, + placements=[Replicate() for _ in range(ndims)], ) - u_shard = u[slices] - u = DTensor.from_local( - u_shard, + + p_sharded = p_replicate.redistribute( device_mesh=p.device_mesh, placements=p.placements, ) - adjusted_lr = self.adjust_lr_for_muon(lr, p.shape) - Muon._update_p(p, u, lr, 
adjusted_lr, weight_decay) + p.copy_(p_sharded) def _update_g(self, p, g, group, momentum): # calc update @@ -843,10 +834,14 @@ class Muon(torch.optim.Optimizer): @staticmethod def _update_p(p, u, lr, adjusted_lr, weight_decay): - # apply weight decay - p.data.mul_(1 - lr * weight_decay) - # apply update - p.data.add_(u, alpha=-adjusted_lr) + if isinstance(p, torch.nn.Parameter): + # apply weight decay + p.data.mul_(1 - lr * weight_decay) + # apply update + p.data.add_(u, alpha=-adjusted_lr) + else: + p.mul_(1 - lr * weight_decay) + p.add_(u, alpha=-adjusted_lr) def get_qk_clip_info(self, n, qk_logits): if self.clip_config is None: @@ -903,8 +898,12 @@ class Muon(torch.optim.Optimizer): @staticmethod def _qk_clip(p, scales, head_dim): - W = p.data.view(-1, head_dim, p.data.shape[1]) - W.mul_(scales.view(-1, 1, 1)) + if isinstance(p, torch.nn.Parameter): + W = p.data.view(-1, head_dim, p.data.shape[1]) + W.mul_(scales.view(-1, 1, 1)) + else: + W = p.view(-1, head_dim, p.shape[1]) + W.mul_(scales.view(-1, 1, 1)) def parallel(self, names, params, group, lr, weight_decay, momentum, qk_logits): @@ -1070,10 +1069,14 @@ class Muon(torch.optim.Optimizer): names = group["names"] param_dtensors = [] - param_tensors = [] name_dtensors = [] + + param_tensors = [] name_tensors = [] + param_dtensors_small = [] + name_dtensors_small = [] + if self.use_distributed_muon: self.distributed_muon(names=names, params=params, @@ -1084,6 +1087,8 @@ class Muon(torch.optim.Optimizer): qk_logits=qk_logits) return + # For simplicity, we use distributed Muon for small parameters + # whose number of elements is below a threshold. 
for n, p in zip(names, params): if p is None or p.grad is None: continue @@ -1093,6 +1098,9 @@ class Muon(torch.optim.Optimizer): for placement in p.placements): param_tensors.append(p) name_tensors.append(n) + elif p.data.numel() <= self.small_param_numel_threshold: + param_dtensors_small.append(p) + name_dtensors_small.append(n) else: param_dtensors.append(p) name_dtensors.append(n) @@ -1103,29 +1111,48 @@ class Muon(torch.optim.Optimizer): raise TypeError(f"Unsupported parameter type: {type(p.data)}") logger.debug( - f"[Muon] {len(param_dtensors)} DTensors, {len(param_tensors)} Tensors" - ) - - if len(param_dtensors) > 0: - if not dist.is_initialized(): - raise RuntimeError( - "Parallel Muon requires torch.distributed to be initialized." - ) + f"[Muon] {len(param_dtensors)} DTensors, {len(param_tensors)} Tensors, " + f"{len(param_dtensors_small)} Small DTensors") + def group_dtensors(dtensors, names): # To support different placements, we group parameters by placements # and run parallel Muon on each group. placement_to_params = defaultdict(lambda: ([], [])) # type: dict[tuple[Placement, DeviceMesh], tuple[list[str], list[DTensor]]] - assert len(name_dtensors) == len(param_dtensors) - for n, p in zip(name_dtensors, param_dtensors): + assert len(dtensors) == len(names) + for p, n in zip(dtensors, names): placement_to_params[tuple([p.placements, p.device_mesh])][0].append(n) placement_to_params[tuple([p.placements, p.device_mesh])][1].append(p) + return placement_to_params + + if len(param_dtensors_small) > 0: + if not dist.is_initialized(): + raise RuntimeError( + "Parallel Muon requires torch.distributed to be initialized." + ) + + self.distributed_muon( + params=param_dtensors_small, + names=name_dtensors_small, + group=group, + lr=lr, + weight_decay=weight_decay, + momentum=momentum, + qk_logits=qk_logits, + ) + + if len(param_dtensors) > 0: + if not dist.is_initialized(): + raise RuntimeError( + "Parallel Muon requires torch.distributed to be initialized." 
+ ) - for _, (names, params) in placement_to_params.items(): + dtensor_group = group_dtensors(param_dtensors, name_dtensors) + for _, (names, params) in dtensor_group.items(): self.parallel( names, params, @@ -1215,6 +1242,7 @@ class Muon(torch.optim.Optimizer): for params in placement_to_params.values(): self._step_adamw_params(params, group) + @torch.no_grad def step(self, closure=None, qk_logits=None): """Perform a single optimization step. diff --git a/build/torch210-cxx11-rocm70-x86_64-linux/optimizer/__init__.py b/build/torch210-cxx11-rocm70-x86_64-linux/optimizer/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..03dbc1afe1cf156661a2b1b22003cd5f599a0309 --- /dev/null +++ b/build/torch210-cxx11-rocm70-x86_64-linux/optimizer/__init__.py @@ -0,0 +1,26 @@ +import ctypes +import sys + +import importlib +from pathlib import Path +from types import ModuleType + +def _import_from_path(file_path: Path) -> ModuleType: + # We cannot use the module name as-is, after adding it to `sys.modules`, + # it would also be used for other imports. So, we make a module name that + # depends on the path for it to be unique using the hex-encoded hash of + # the path. 
+ path_hash = "{:x}".format(ctypes.c_size_t(hash(file_path.absolute())).value) + module_name = path_hash + spec = importlib.util.spec_from_file_location(module_name, file_path) + if spec is None: + raise ImportError(f"Cannot load spec for {module_name} from {file_path}") + module = importlib.util.module_from_spec(spec) + if module is None: + raise ImportError(f"Cannot load module {module_name} from spec") + sys.modules[module_name] = module + spec.loader.exec_module(module) # type: ignore + return module + + +globals().update(vars(_import_from_path(Path(__file__).parent.parent / "__init__.py"))) diff --git a/build/torch210-cxx11-rocm71-x86_64-linux/__init__.py b/build/torch210-cxx11-rocm71-x86_64-linux/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..239c7a65f8293e7d0df28f05fce645af56d628c0 --- /dev/null +++ b/build/torch210-cxx11-rocm71-x86_64-linux/__init__.py @@ -0,0 +1,5 @@ +from .muon import Muon + +__all__ = [ + "Muon", +] diff --git a/build/torch210-cxx11-rocm71-x86_64-linux/_ops.py b/build/torch210-cxx11-rocm71-x86_64-linux/_ops.py new file mode 100644 index 0000000000000000000000000000000000000000..e6f6fcf6280e969b1761926112147d3146e27b59 --- /dev/null +++ b/build/torch210-cxx11-rocm71-x86_64-linux/_ops.py @@ -0,0 +1,9 @@ +import torch +from . import _optimizer_06a260a_dirty +ops = torch.ops._optimizer_06a260a_dirty + +def add_op_namespace_prefix(op_name: str): + """ + Prefix op by namespace. 
+ """ + return f"_optimizer_06a260a_dirty::{op_name}" \ No newline at end of file diff --git a/build/torch210-cxx11-rocm71-x86_64-linux/_optimizer_06a260a_dirty.abi3.so b/build/torch210-cxx11-rocm71-x86_64-linux/_optimizer_06a260a_dirty.abi3.so new file mode 100755 index 0000000000000000000000000000000000000000..ed70a8ee48aca9da47db195b5e73c86aca32b153 --- /dev/null +++ b/build/torch210-cxx11-rocm71-x86_64-linux/_optimizer_06a260a_dirty.abi3.so @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:d804ba4d3ed9716c80e9819ba16a2bef300fb23fa4c456c550f4a96167a2eb00 +size 1866112 diff --git a/build/torch210-cxx11-rocm71-x86_64-linux/distributed/utils.py b/build/torch210-cxx11-rocm71-x86_64-linux/distributed/utils.py new file mode 100644 index 0000000000000000000000000000000000000000..6d5843506c13d9d31603b2b4e30c1c91d0baab28 --- /dev/null +++ b/build/torch210-cxx11-rocm71-x86_64-linux/distributed/utils.py @@ -0,0 +1,175 @@ +import torch +import torch.distributed as dist +from torch.distributed import ProcessGroup +from torch.distributed.device_mesh import DeviceMesh +from torch.distributed.tensor import DTensor +from torch.distributed.tensor.placement_types import (Placement, Shard, + _StridedShard) + + +def get_slices_of_dtensor( + target: DTensor | torch.Tensor, + local_rank: int, + shard_mesh: DeviceMesh, + shard_placements: tuple[Placement], +) -> tuple[slice]: + """ + Get the slice of local tensor for a given rank from a tensor. + Args: + target (DTensor | torch.Tensor): The target tensor. + rank (int): The local rank of the shard group. + shard_mesh (DeviceMesh): The shard mesh. It consists of global ranks. + shard_placements (tuple[Placement]): The shard placements. 
+ """ + + slices: list[slice] = [slice(0, dim_size) for dim_size in target.size()] + + # find the global rank of the local rank in the shard mesh + rank = sorted(shard_mesh.mesh.flatten().tolist())[local_rank] + + rank_coords = (shard_mesh.mesh == rank).nonzero() + + assert len(rank_coords) == 1 + rank_coords = tuple(rank_coords[0].tolist()) + + assert len(rank_coords) == len(shard_placements) + + # Caution: Assuming replicate-to-shard of the shard mesh goes with + # left-to-right sharding. This is ensured by the sorting logic of + # construct_shard_mesh function. + for i, (rank_coord, + placement) in enumerate(zip(rank_coords, shard_placements)): + assert isinstance(placement, Shard) + + num_ranks = shard_mesh.mesh.shape[i] + + dim = placement.dim + dim_size = (slices[dim].stop - slices[dim].start) + + if dim_size % num_ranks != 0: + raise NotImplementedError( + f"Dimension size {dim_size} is not divisible " + f"by number of ranks {num_ranks} for shard " + f"placement on dim {dim}. (shape: {target.shape})") + + shard_size = dim_size // num_ranks + + start = slices[dim].start + rank_coord * shard_size + end = start + shard_size + + assert start < end <= slices[dim].stop + + slices[dim] = slice(start, end) + + return tuple(slices) + + +_ranks_to_dist_cache: dict[tuple[int, ...], tuple[DeviceMesh, + ProcessGroup]] = dict() + + +def construct_shard_mesh( + placements: tuple[Placement], + mesh: DeviceMesh, +) -> (DeviceMesh, ProcessGroup, tuple[Placement]): + """ + Construct Shard Mesh and Placements for unsharding. + It removes Replicate placements and constructs a new Mesh and ProcessGroup. + """ + my_rank = dist.get_rank() + + assert mesh.mesh.device.type == 'cpu' + + # Copy mesh to avoid modifying the original mesh + mesh = mesh.mesh.clone() + + # 1. Sort placements. Replicate first, then Shard by dim ascending. + + # For Shard, strided shard comes after regular shard on the same dim + # to preserve left-to-right order of replicate-to-shard. 
+ # This is because that strided shard is using stride to represent + # more fine-grained sharding on the same dim. + # Please check the URL below for _StridedShard. + # https://github.com/pytorch/pytorch/blob/v2.8.0/torch/distributed/tensor/placement_types.py#L366 + + def placement_sort_key( + placement_with_index: tuple[float, Placement] + ) -> tuple[int, float, int]: # (dim, split factor, original index) + index, placement = placement_with_index + is_replicate = placement.is_replicate() + is_shard = placement.is_shard() + is_partial = placement.is_partial() + + assert is_replicate or is_shard, f"Unsupported placement type: {type(placement)}" + assert not is_partial, "Partial placement is not supported." + + if is_replicate: + return (-1.0, 0, index) + elif is_shard: + if isinstance(placement, _StridedShard): + return (placement.dim, 1 / placement.split_factor, index) + return (placement.dim, 0, index) + else: + raise TypeError(f"Unknown placement type: {type(placement)}") + + placements_with_index: list[tuple[int, + Placement]] = list(enumerate(placements)) + placements_with_index = sorted(placements_with_index, + key=placement_sort_key) + + sorted_indices, sorted_placements = zip(*placements_with_index) + + # 2. Permute mesh according to sorted placements. + sorted_mesh = mesh.permute(sorted_indices) + + # 3. 
Collect list of shard meshes by removing replicate dims + # For example, (2, 3, 4, 4) with placements [R, R, S(0), S(1)] + # shard_meshes should be list with 2 * 3 = 6 shard meshes of shape (4, 4) + num_replicates = sum(1 for p in sorted_placements if p.is_replicate()) + + # merge replicate dims + # shard_meshes became a list of shard meshes with a length of replicate degree + if num_replicates > 0: + sorted_mesh = sorted_mesh.flatten( + 0, num_replicates - 1) if num_replicates > 1 else sorted_mesh + shard_meshes = list(torch.unbind(sorted_mesh, dim=0)) + else: + shard_meshes = [sorted_mesh] + shard_placements = sorted_placements[num_replicates:] + + # assume all shard placements are different + assert len(shard_placements) == len(set(shard_placements)) + + # 4. Construct ProcessGroups + # Caution: all groups should be created in the same order in all processes, + # even though each process only needs its own group. + + # To use tensor as dict key, convert it to tuple + def tensor_to_tuple(t): + if isinstance(t, torch.Tensor): + t = t.tolist() + if isinstance(t, list): + return tuple(tensor_to_tuple(x) for x in t) + return t + + my_shard_mesh_as_tuple = None + for shard_mesh in shard_meshes: + assert isinstance(shard_mesh, torch.Tensor) + shard_mesh_as_tuple = tensor_to_tuple(shard_mesh) + + if (my_rank == shard_mesh).any().item(): + assert my_shard_mesh_as_tuple is None + my_shard_mesh_as_tuple = shard_mesh_as_tuple + + # update global cache + if shard_mesh_as_tuple not in _ranks_to_dist_cache: + shard_process_group = dist.new_group(shard_mesh.flatten().tolist()) + _ranks_to_dist_cache[shard_mesh_as_tuple] = ( + DeviceMesh(device_type="cuda", mesh=shard_mesh), + shard_process_group, + ) + + my_shard_mesh, my_shard_process_group = _ranks_to_dist_cache[ + my_shard_mesh_as_tuple] + + return my_shard_mesh, my_shard_process_group, shard_placements diff --git a/build/torch28-cxx11-rocm64-x86_64-linux/optimizer/matmul_transpose_triton.py 
b/build/torch210-cxx11-rocm71-x86_64-linux/matmul_transpose_triton.py similarity index 100% rename from build/torch28-cxx11-rocm64-x86_64-linux/optimizer/matmul_transpose_triton.py rename to build/torch210-cxx11-rocm71-x86_64-linux/matmul_transpose_triton.py diff --git a/build/torch210-cxx11-rocm71-x86_64-linux/metadata.json b/build/torch210-cxx11-rocm71-x86_64-linux/metadata.json new file mode 100644 index 0000000000000000000000000000000000000000..76bafa5f33b6818aa6bb4cab04be811b87519b44 --- /dev/null +++ b/build/torch210-cxx11-rocm71-x86_64-linux/metadata.json @@ -0,0 +1 @@ +{"python-depends":[]} \ No newline at end of file diff --git a/build/torch210-cxx11-rocm71-x86_64-linux/muon.py b/build/torch210-cxx11-rocm71-x86_64-linux/muon.py new file mode 100644 index 0000000000000000000000000000000000000000..dbf25575f185ff379789482068e4ecf55b9455a9 --- /dev/null +++ b/build/torch210-cxx11-rocm71-x86_64-linux/muon.py @@ -0,0 +1,1268 @@ +import logging +import math +import types +from collections import defaultdict +from dataclasses import dataclass +from typing import Any, cast + +import torch +import torch.distributed as dist +from torch.distributed import ProcessGroup +from torch.distributed.device_mesh import DeviceMesh +from torch.distributed.tensor import DTensor, Replicate +from torch.distributed.tensor.placement_types import Placement + +from .distributed.utils import construct_shard_mesh, get_slices_of_dtensor +from .matmul_transpose_triton import matmul_transpose_assign + +logger = logging.getLogger(__name__) + +COMM_DTYPE = torch.bfloat16 +DEFAULT_CHUNK_SIZE_RATIO = 4 + + +# This code snippet is a modified version adapted from the following GitHub repositories: +# https://github.com/KellerJordan/Muon/blob/master/muon.py +# Muon's Newton–Schulz iteration causes high variance in singular values +# Idea: give each iteration its own 3 coefficients and optimize them via gradient descent. 
+@torch.no_grad() +# matmul_transpose_assign from : https://github.com/nil0x9/flash-muon +def _zeropower_via_newtonschulz5(G, steps): + """ + Newton-Schulz iteration to compute the zeroth power / orthogonalization of G. We opt to use a + quintic iteration whose coefficients are selected to maximize the slope at zero. For the purpose + of minimizing steps, it turns out to be empirically effective to keep increasing the slope at + zero even beyond the point where the iteration no longer converges all the way to one everywhere + on the interval. This iteration therefore does not produce UV^T but rather something like US'V^T + where S' is diagonal with S_{ii}' ~ Uniform(0.5, 1.5), which turns out not to hurt model + performance at all relative to UV^T, where USV^T = G is the SVD. + """ + assert len(G.shape) == 2 + assert G.dtype == COMM_DTYPE + X = G # no manual typecast + + if G.size(0) > G.size(1): + X = X.T + # Ensure spectral norm is at most 1 + X = X / (X.norm() + 1e-7) + buf1 = torch.empty(X.size(0), X.size(0), dtype=X.dtype, device=X.device) + buf2 = torch.empty(X.size(0), X.size(0), dtype=X.dtype, device=X.device) + # Perform the NS iterations + for a, b, c in [ + (4.0848, -6.8946, 2.9270), + (3.9505, -6.3029, 2.6377), + (3.7418, -5.5913, 2.3037), + (2.8769, -3.1427, 1.2046), + (2.8366, -3.0525, 1.2012), + ]: + matmul_transpose_assign(X, buf1) + matmul_transpose_assign(buf1, buf2) + buf1.mul_(b).add_(buf2, alpha=c) + X = torch.addmm(X, buf1, X, alpha=1.0, beta=a) + + if G.size(0) > G.size(1): + X = X.T + return X + + +@dataclass +class _muon_state: + # TODO: use Optional + worker_rank: int + process_group: ProcessGroup + shard_mesh: DeviceMesh + shard_placements: tuple[Placement, ...] 
+ name: str + qk_clip_state: torch.Tensor | None = None + gathered_grad: torch.Tensor | None = None + scattered_u: DTensor | None = None + computed_u: torch.Tensor | None = None + gather_event: torch.cuda.Event | None = None + compute_event: torch.cuda.Event | None = None + scatter_event: torch.cuda.Event | None = None + + +def numel_for_rank( + param: DTensor, + local_rank: int, + state: _muon_state, +) -> int: + slices = get_slices_of_dtensor( + param, + local_rank, + state.shard_mesh, + state.shard_placements, + ) + + numel = 1 + for s, dim in zip(slices, param.shape): + start, stop, step = s.indices(dim) + length = max(0, (stop - start + (step - 1)) // step) + numel *= length + + return numel + + +@torch.no_grad() +def _alloc_gathered_grad(params, param_to_state, rank, compute_stream): + """ + Pre-allocate gathered_grad buffer on compute_stream + before launching all2all gather + """ + with torch.cuda.stream(compute_stream): + for p in params: + state = param_to_state[id(p)] + if rank == state.worker_rank: + state.gathered_grad = torch.empty(p.shape, + dtype=COMM_DTYPE, + device="cuda") + else: + state.gathered_grad = None + + alloc_event = torch.cuda.Event() + alloc_event.record(compute_stream) + return alloc_event + + +@torch.no_grad() +def _all2all_gather(params, param_to_state, rank, comm_stream, none_grad, + alloc_event): + """ + All2all gathers shards so each owner rank reconstructs its full gradient + """ + with torch.cuda.stream(comm_stream): + process_group = param_to_state[id(params[0])].process_group + num_ranks = dist.get_world_size(group=process_group) + + # Construct sending buffers + per_dst = [[] for _ in range(num_ranks)] + send_counts = [0] * num_ranks + + for p in params: + state = param_to_state[id(p)] + dst = state.worker_rank + assert dst < num_ranks + shard_elems = numel_for_rank(p, rank, state) + g = p.grad + g = g.to_local().to(COMM_DTYPE).contiguous() + assert g.numel() == shard_elems + per_dst[dst].append(g.view(-1)) + 
send_counts[dst] += shard_elems + + assert any( + len(v) > 0 for v in per_dst + ), "At least one destination rank must receive a sharded tensor" + # list[list[Tensor]] -> list[Tensor] + per_dst = [t for dst in per_dst for t in dst] + + send_buf = torch.cat(per_dst, dim=0) + + owned_params = [ + p for p in params if param_to_state[id(p)].worker_rank == rank + ] + + # Compute receive sizes and allocate receiving buffers + recv_counts = [0] * num_ranks + + for src in range(num_ranks): + total = 0 + for p in owned_params: + state = param_to_state[id(p)] + assert state.worker_rank == rank + total += numel_for_rank(p, src, state) + recv_counts[src] = total + + recv_total = sum(recv_counts) + recv_buf = torch.empty(recv_total, dtype=COMM_DTYPE, device="cuda") + + #All2All + logger.debug(f"send_buf size: {send_buf.numel()}, " + f"recv_buf size: {recv_buf.numel()}, " + f"recv_counts: {recv_counts}, " + f"send_counts: {send_counts}, " + f"process_group: {str(process_group)}") + dist.all_to_all_single( + recv_buf, + send_buf, + output_split_sizes=recv_counts, + input_split_sizes=send_counts, + group=process_group, + ) + + # Reconstructs gathered grad from the received buffer + # + # recv_buf (num ranks = 3) + # + # From rank 0 From rank 1 From rank 2 + # | p1_0, p2_0, p3_0 | p1_1, p2_1, p3_1 | p1_2, p2_2, p3_2 | + # + # Outer loop: + # rank 0 -> rank 1 -> rank2 + # + # Inner loop: + # p1_n -> p2_n -> p3_n + + comm_stream.wait_event(alloc_event) + + off = 0 + for src in range(num_ranks): + if recv_counts[src] == 0: + continue + + block = recv_counts[src] + inner_off = 0 + for p in owned_params: + state = param_to_state[id(p)] + assert state.worker_rank == rank + + # get the slice of the full dtensor corresponding to rank src. 
+ slices = get_slices_of_dtensor(state.gathered_grad, src, + state.shard_mesh, + state.shard_placements) + + dst = state.gathered_grad[slices] + assert dst._base is state.gathered_grad + + n = dst.numel() + assert n > 0 + + sg = recv_buf.narrow(0, off + inner_off, n) + sg = sg.reshape_as(dst) + dst.copy_(sg) + + inner_off += n + off += block + + for p in params: + state = param_to_state[id(p)] + if state.worker_rank == rank: + state.gather_event = torch.cuda.Event() + state.gather_event.record(comm_stream) + else: + state.gathered_grad = None + state.gather_event = None + if none_grad: + p.grad = None + + +@torch.no_grad() +def _compute_u(p, state, steps, rank, compute_stream): + """ + On worker_rank, compute the orthogonalized update using Newton-Schulz iteration. + """ + with torch.cuda.stream(compute_stream): + if rank == state.worker_rank: + if state.gather_event is None: + raise RuntimeError("Gather event must be set before compute.") + compute_stream.wait_event(state.gather_event) + u = _zeropower_via_newtonschulz5(state.gathered_grad, steps) + state.gathered_grad = None + state.computed_u = u + state.compute_event = torch.cuda.Event() + state.compute_event.record() + else: + state.computed_u = None + state.compute_event = None + + +@torch.no_grad() +def _alloc_scattered_u(params, param_to_state, rank, compute_stream): + """ + Pre-allocate scattered_u buffer on compute_stream + before launching all2all gather + """ + with torch.cuda.stream(compute_stream): + for p in params: + state = param_to_state[id(p)] + state.scattered_u = torch.empty_like(p.to_local(), + dtype=COMM_DTYPE) + + alloc_event = torch.cuda.Event() + alloc_event.record(compute_stream) + return alloc_event + + +def _all2all_scatter(params, param_to_state, rank, comm_stream, alloc_event): + """ + All2all scatters full gradients to all ranks + """ + with torch.cuda.stream(comm_stream): + process_group = param_to_state[id(params[0])].process_group + num_ranks = 
dist.get_world_size(group=process_group) + owned_params = [ + p for p in params if param_to_state[id(p)].worker_rank == rank + ] + + # Construct sending buffer + per_dst = [[] for _ in range(num_ranks)] + send_counts = [0] * num_ranks + + if owned_params: + for p in owned_params: + state = param_to_state[id(p)] + if state.compute_event is None: + raise RuntimeError( + "Compute event must be set before scatter.") + comm_stream.wait_event(state.compute_event) + state.gathered_grad = None + + assert state.computed_u is not None + + u_full = state.computed_u.to(COMM_DTYPE).contiguous() + + offset = 0 + for dst in range(num_ranks): + # get the slice of the full tensor corresponding to rank dst. + slices = get_slices_of_dtensor(u_full, dst, + state.shard_mesh, + state.shard_placements) + su = u_full[slices].flatten() + + n = su.numel() + assert n > 0 + + per_dst[dst].append(su) + send_counts[dst] += n + offset += n + + assert offset == u_full.numel() + + lengths = [len(v) for v in per_dst] + if all(l > 0 for l in lengths): + assert all( + l == lengths[0] for l in lengths + ), "All destination ranks must have the same number of sharded tensor" + # list[list[Tensor]] -> list[Tensor] + per_dst = [t for dst in per_dst for t in dst] + send_buf = torch.cat(per_dst, dim=0) + else: + # all_to_all requires participation from all ranks + # Even non-owner ranks must join the collective call + send_buf = torch.empty(0, dtype=COMM_DTYPE, device="cuda") + + # Compute receive sizes and allocate receiving buffers + recv_counts = [0] * num_ranks + + for src in range(num_ranks): + total = 0 + for p in params: + state = param_to_state[id(p)] + if state.worker_rank != src: + continue + total += numel_for_rank(p, rank, state) + recv_counts[src] = total + + recv_total = sum(recv_counts) + assert recv_total > 0 + recv_buf = torch.empty(recv_total, dtype=COMM_DTYPE, device="cuda") + + #All2All + dist.all_to_all_single( + recv_buf, + send_buf, + output_split_sizes=recv_counts, + 
input_split_sizes=send_counts, + group=process_group, + ) + + # Copy to pre-allocated scattered_u buffer from the received buffer + # + # recv_buf (num ranks = 3, local_rank = 0) + # + # From rank 0 From rank 1 From rank 2 + # | p1_0, p2_0, p3_0 | p4_0 | p5_0, p6_0 | + # + # Outer loop: + # rank 0 -> rank 1 -> rank2 + # + # Inner loop: + # src(0) : p1_0 -> p2_0 -> p3_0 + # src(1) : p4_0 + # src(2) : p5_0 -> p6_0 + + comm_stream.wait_event(alloc_event) + + off = 0 + for src in range(num_ranks): + block = recv_counts[src] + if block == 0: + continue + + inner_off = 0 + for p in params: + state = param_to_state[id(p)] + if state.worker_rank != src: + continue + n = numel_for_rank(p, rank, state) + assert n > 0 + + flat_local = recv_buf.narrow(0, off + inner_off, + n).view_as(p.to_local()) + state.scattered_u.copy_(flat_local) + + state.scatter_event = torch.cuda.Event() + state.scatter_event.record(comm_stream) + inner_off += n + + assert inner_off == block + off += block + + +def _update_param(p, state, lr, adjusted_lr, weight_decay, rank, + compute_stream): + """ + Update sharded parameter p with the scattered_u. + Only worker_rank frees computed_u. 
+ """ + with torch.cuda.stream(compute_stream): + if state.scatter_event is None: + raise RuntimeError("Scatter event must be set before update") + compute_stream.wait_event(state.scatter_event) + u_dtensor = DTensor.from_local( + state.scattered_u, + placements=p.placements, + device_mesh=p.device_mesh, + ) + + state.scattered_u = u_dtensor + + if rank == state.worker_rank: + # Free computed_u + state.computed_u = None + + Muon._update_p(p, state.scattered_u, lr, adjusted_lr, weight_decay) + state.scattered_u = None + u_dtensor = None + + scales_full = Muon._compute_scales( + p, + state.qk_clip_state) if state.qk_clip_state is not None else None + if scales_full is not None: + # Have to slice scales_full among dim 0 + weight_slices = get_slices_of_dtensor(p, rank, state.shard_mesh, + state.shard_placements) + ratio = p.shape[0] // scales_full.shape[0] + scales_slice = slice( + None if weight_slices[0].start is None else + weight_slices[0].start // ratio, + None if weight_slices[0].stop is None else + weight_slices[0].stop // ratio, + None, + ) + + scales_local = scales_full[scales_slice] + scales_local = DTensor.from_local( + scales_local, + placements=p.placements, + device_mesh=p.device_mesh, + ) + Muon._qk_clip(p, scales_local, state.qk_clip_state.head_dim) + + +def default_is_muon(name, x): + skip_keys = ["embed_tokens", "lm_head", "tok_embeddings", "output"] + return x.ndim >= 2 and not any(key in name for key in skip_keys) + + +def get_default_muon_param_groups(model, is_muon_func=default_is_muon): + muon_params, muon_names = [], [] + non_muon_params = [] + + for n, p in model.named_parameters(): + if not p.requires_grad: + continue + if is_muon_func(n, p): + muon_params.append(p) + muon_names.append(n) + else: + non_muon_params.append(p) + + return [ + { + "params": muon_params, + "names": muon_names, + "use_muon": True, + }, + { + "params": non_muon_params, + "use_muon": False, + }, + ] + + +def parse_qk_layer(name: str) -> tuple[str | None, int]: + """ + 
Parse a parameter name to check if it is a query/key projection layer + ('wq', 'wk', 'q_proj', 'k_proj') and return (kind, layer_index). + + Returns: + (kind, layer_idx) or (None, -1) if not matched. + + Example: + 'model.3.attn.wq.weight' -> ('wq', 3) + 'model.5.attn.wk.weight' -> ('wk', 5) + 'model.2.attn.q_proj.weight' -> ('q_proj', 2) + 'model.7.attn.k_proj.weight' -> ('k_proj', 7) + 'model.4.attn.v_proj.weight' -> (None, -1) + """ + parts = name.split('.') + if len(parts) < 3: + return None, -1 + + kind = parts[-2] + + layer_idx = -1 + for part in reversed(parts): + if part.isdigit(): + layer_idx = int(part) + break + + if kind in ('wq', 'wk', 'q_proj', 'k_proj'): + return kind, layer_idx + + return None, -1 + + +@dataclass +class QKClipInfo: + """Per-parameter dynamic info computed from config + runtime logits.""" + kind: str | None # 'wq'/'q_proj' or 'wk'/'k_proj' or None + indices: list[int] # which heads to consider for clipping + head_dim: int # from config + threshold: float # from config + logit: torch.Tensor | None + + +class Muon(torch.optim.Optimizer): + """ + Muon - MomentUm Orthogonalized by Newton-schulz + + Muon internally runs standard SGD-momentum, and then performs an orthogonalization post- + processing step, in which each 2D parameter's update is replaced with the nearest orthogonal + matrix. To efficiently orthogonalize each update, we use a Newton-Schulz iteration, which has + the advantage that it can be stably run in bfloat16 on the GPU. + + Some warnings: + - We believe this optimizer is unlikely to work well for training with small batch size. + - We believe it may not work well for finetuning pretrained models, but we haven't tested this. + + Arguments: + model: The model to be optimized by Muon. + is_muon_func: A function that takes a parameter and its name, and returns whether the parameter should be optimized by Muon. + lr: The learning rate. The updates will have spectral norm of `lr`. 
(0.02 is a good default) + momentum: The momentum used by the internal SGD. (0.95 is a good default) + nesterov: Whether to use Nesterov-style momentum in the internal SGD. (recommended) + ns_steps: The number of Newton-Schulz iterations to run. (6 is probably always enough) + weight_decay: The weight decay for Muon and AdamW. + {0, 1}-D or are detected as being the embed or lm_head will be optimized by AdamW as well. + adamw_lr: The learning rate for the internal AdamW. + adamw_betas: The betas for the internal AdamW. + adamw_eps: The epsilon for the internal AdamW. + none_grad: Whether to set p.grad to None after gathering the gradients. This can save memory. + debug: Whether to print debug information. + clip_info : Configuration for QK clipping. Expected keys: + - "q_indices" (list[int]): Indices of query heads to consider. + - "k_indices" (list[int]): Indices of key heads to consider. + - "head_dim" (int): Dimensionality of each attention head. + - "threshold" (float): Threshold value; heads whose QK logits exceed + this value will be scaled down. + Default is: + { + "q_indices": [], + "k_indices": [], + "head_dim": 128, + "threshold": 100 + } + warmup_step : How many all2all gather, compute operations are launched in advance + before the corresponding all2all scatter steps begin. + A higher warmup_step increases memory usage but can improve + performance by overlapping communication. + Parallel muon only. + chunk_size : Batch size of parameters to process in each + all2all gather/compute/scatter step. + Use shard ranks * DEFAULT_CHUNK_SIZE_RATIO when -1 is specified. + use_distributed_muon: Use distributed muon by Liu et al. (2024). + For testing purpose only. 
+ small_param_numel_threshold: Threshold for classifying parameters as small and falling back to distributed Muon + """ + + def __init__(self, + params, + lr=1e-3, + momentum=0.95, + nesterov=True, + ns_steps=5, + weight_decay=0.1, + adamw_betas=(0.9, 0.95), + adamw_eps=1e-8, + none_grad=True, + debug=False, + clip_config={ + "q_indices": [], + "k_indices": [], + "head_dim": 128, + "threshold": 100 + }, + warmup_step=5, + chunk_size=-1, + use_distributed_muon=False, + small_param_numel_threshold=65536): + defaults = dict( + lr=lr, + weight_decay=weight_decay, + momentum=momentum, + nesterov=nesterov, + ns_steps=ns_steps, + adamw_betas=adamw_betas, + adamw_eps=adamw_eps, + none_grad=none_grad, + use_muon=True, + ) + error_message = "The key 'use_muon' is not set in parameter group {idx}. Assuming all parameters in the group will use muon optimization, which may lead to unexpected behavior." + instruction_code = "\n\n please follow this code snippet \n```optimizer = get_kernel('motif-technologies/optimizer')\n\n\nparams = optimizer.muon.get_default_muon_param_groups(model)\n\noptim = optimizer.Muon(params, ...)```" + + if isinstance(params, types.GeneratorType): + raise ValueError(error_message.format(idx=0) + instruction_code) + for _idx, param_group in enumerate(params): + if param_group.get("use_muon", None) is None: + raise ValueError( + error_message.format(idx=_idx) + instruction_code) + + super().__init__(params, defaults) + + self.rank = None + + self.comm_stream = torch.cuda.Stream() + self.compute_stream = torch.cuda.Stream() + self.debug = debug + self.clip_config = clip_config + self.warmup_step = warmup_step + self.chunk_size = chunk_size + self.use_distributed_muon = use_distributed_muon + self.small_param_numel_threshold = small_param_numel_threshold + + def _calc_flops(self, G, steps): + assert len(G.shape) == 2 + M, N = G.shape + if M > N: + M, N = N, M + + return steps * ((M**3) * 2 + (M**2 * N) * 4 + M * N * 2 + M**2 * 3) + + def 
adjust_lr_for_muon(self, lr, param_shape): + A, B = param_shape[:2] + # We adjust the learning rate and weight decay based on the size of the parameter matrix + # as describted in the paper + adjusted_ratio = 0.2 * math.sqrt(max(A, B)) + adjusted_lr = lr * adjusted_ratio + return adjusted_lr + + def set_rank_once(self, rank): + if self.rank is None: + self.rank = rank + else: + assert self.rank == rank + + def get_shard_mesh(self, p): + """ + Get the shard mesh for a parameter p on the given rank. + """ + assert isinstance( + p, DTensor), "Parallel Muon only supports DTensor parameters." + + shard_mesh, shard_pg, shard_placements = construct_shard_mesh( + p.placements, p.device_mesh) + + # set rank with the local rank in the shard process group + self.set_rank_once(dist.get_rank(group=shard_pg)) + + return shard_mesh, shard_pg, shard_placements + + def init_state_and_assign_params(self, names, params, group, qk_logits): + param_to_state = {} + param_to_flops = {} + + total_flops = 0 + for p in params: + g = p.grad + if g is None: + continue + assert g.ndim == 2, "Muon only supports 2D parameters." 
+ + flops = self._calc_flops(g, group["ns_steps"]) + param_to_flops[id(p)] = flops + total_flops += flops + + if self.debug: + print(f"Total TFLOPs for Muon: {total_flops / 1e12:.2f} TFLOPs", + flush=True) + + paired = list(zip(names, params)) + + paired_sorted = sorted(paired, + key=lambda x: param_to_flops[id(x[1])], + reverse=True) + + names_sorted, params_sorted = zip(*paired_sorted) + ordered_names = list(names_sorted) + ordered_params = list(params_sorted) + + round_robin = 0 + mesh = ordered_params[0].device_mesh + placements = ordered_params[0].placements + + shard_mesh, shard_pg, shard_placements = self.get_shard_mesh( + ordered_params[0]) + shard_mesh_flattened = shard_mesh.mesh.flatten() + num_ranks = dist.get_world_size(group=shard_pg) + + for n, p in zip(ordered_names, ordered_params): + if mesh != p.device_mesh: + raise ValueError("All parameters must be on the same mesh.") + if placements != p.placements: + raise ValueError("All parameters must have same placements.") + + worker_rank = shard_mesh_flattened[round_robin].item() % num_ranks + round_robin = (round_robin + 1) % len(shard_mesh_flattened) + qk_clip_state = self.get_qk_clip_info(n, qk_logits) + + param_to_state[id(p)] = _muon_state( + worker_rank=worker_rank, + process_group=shard_pg, + shard_mesh=shard_mesh, + shard_placements=shard_placements, + name=n, + qk_clip_state=qk_clip_state, + ) + + return param_to_state, ordered_params + + def base(self, names, params, group, lr, weight_decay, momentum, + qk_logits): + # generate weight updates in distributed fashion + for n, p in zip(names, params): + g = p.grad + if g is None: + continue + if g.ndim > 2: + g = g.view(g.size(0), -1) + assert g is not None + + g = self._update_g(p, g, group, momentum) + + u = _zeropower_via_newtonschulz5(g.to(COMM_DTYPE), + steps=group["ns_steps"]) + + adjusted_lr = self.adjust_lr_for_muon(lr, p.shape) + Muon._update_p(p, u, lr, adjusted_lr, weight_decay) + + qk_clip_state = self.get_qk_clip_info(n, qk_logits) + 
+ scales_full = self._compute_scales( + p, qk_clip_state) if qk_clip_state is not None else None + if scales_full is not None: + Muon._qk_clip(p, scales_full, qk_clip_state.head_dim) + + def distributed_muon( + self, + names: list[str], + params: list[torch.nn.Parameter], + group: dict[str, Any], + lr: float, + weight_decay: float, + momentum: float, + qk_logits: list[torch.Tensor | DTensor] | None, + ): + """ Implementation of Distributed Muon by Liu et al. """ + + for n, p in zip(names, params): + g = p.grad + if g is None: + continue + if g.ndim > 2: + g = g.view(g.size(0), -1) + assert g is not None + + g = self._update_g(p, g, group, momentum) + + # Gather G + if isinstance(p.data, DTensor): + g_full = g.full_tensor() + p_full = p.data.full_tensor() + else: + g_full = g + p_full = p + + u_full = _zeropower_via_newtonschulz5(g_full.to(COMM_DTYPE), + steps=group["ns_steps"]) + + adjusted_lr = self.adjust_lr_for_muon(lr, p_full.shape) + Muon._update_p(p_full, u_full, lr, adjusted_lr, weight_decay) + + qk_clip_state = self.get_qk_clip_info(n, qk_logits) + + scales_full = self._compute_scales( + p_full, qk_clip_state) if qk_clip_state is not None else None + + if scales_full is not None: + Muon._qk_clip(p_full, scales_full, qk_clip_state.head_dim) + + if isinstance(p.data, DTensor): + ndims = len(p.device_mesh.mesh.shape) + p_replicate = DTensor.from_local( + p_full, + device_mesh=p.device_mesh, + placements=[Replicate() for _ in range(ndims)], + ) + + p_sharded = p_replicate.redistribute( + device_mesh=p.device_mesh, + placements=p.placements, + ) + + p.copy_(p_sharded) + + def _update_g(self, p, g, group, momentum): + # calc update + state = self.state[p] + buf = state.setdefault("momentum_buffer", torch.zeros_like(g)) + torch.add(g, buf, alpha=momentum, out=buf) + if group["nesterov"]: + g.add_(buf, alpha=momentum) + return g + return buf + + @staticmethod + def _update_p(p, u, lr, adjusted_lr, weight_decay): + if isinstance(p, torch.nn.Parameter): + # apply 
weight decay + p.data.mul_(1 - lr * weight_decay) + # apply update + p.data.add_(u, alpha=-adjusted_lr) + else: + p.mul_(1 - lr * weight_decay) + p.add_(u, alpha=-adjusted_lr) + + def get_qk_clip_info(self, n, qk_logits): + if self.clip_config is None: + return None + + head_dim = self.clip_config.get('head_dim') + threshold = self.clip_config.get('threshold') + kind, layer_idx = parse_qk_layer(n) + + logit, indices = None, [] + if qk_logits is not None and kind is not None: + logit = qk_logits[layer_idx] + indices_key = 'q_indices' if 'q' in kind else 'k_indices' + indices = self.clip_config.get(indices_key, []) or [] + + if isinstance(logit, DTensor): + # In TP settings, qk_logits may be DTensor + # We convert it to full tensor here for simplicity + logit = logit.full_tensor() + + return QKClipInfo( + kind=kind, + indices=indices, + head_dim=head_dim, + threshold=threshold, + logit=logit, + ) + + @staticmethod + def _compute_scales(p, qk_clip_state): + kind = qk_clip_state.kind + indices = qk_clip_state.indices + head_dim = qk_clip_state.head_dim + threshold = qk_clip_state.threshold + logit = qk_clip_state.logit + + H_global = p.shape[0] // head_dim + scales_full = torch.ones(H_global, device=p.data.device) + scaling = 0 + + for logit_idx, head_idx in enumerate(indices): + v_ele = float(logit[logit_idx]) + if v_ele > threshold: + new_scale = math.sqrt(threshold / v_ele) + if new_scale < scales_full[head_idx]: + scales_full[head_idx] = new_scale + logger.info( + f"[{kind}] Head {head_idx} exceeded threshold " + f"(value={v_ele:.4f}, threshold={threshold:.4f}) -> applying scale={new_scale:.4f}" + ) + scaling += 1 + + return scales_full if scaling > 0 else None + + @staticmethod + def _qk_clip(p, scales, head_dim): + if isinstance(p, torch.nn.Parameter): + W = p.data.view(-1, head_dim, p.data.shape[1]) + W.mul_(scales.view(-1, 1, 1)) + else: + W = p.view(-1, head_dim, p.shape[1]) + W.mul_(scales.view(-1, 1, 1)) + + def parallel(self, names, params, group, lr, 
weight_decay, momentum, + qk_logits): + """ + Perform a parallel optimization step using Muon. + """ + + for p in params: + g = p.grad + if g is None: + continue + if g.ndim > 2: + g = g.view(g.size(0), -1) + + # Update g in the local rank + g = self._update_g( + p, + g, + group, + momentum=momentum, + ) + p.grad = g + + param_to_state, ordered_params = self.init_state_and_assign_params( + names, params, group, qk_logits) + + assert self.rank is not None + + def enqueue_all2all_gather(start_idx, chunk_size): + target_params = ordered_params[start_idx:start_idx + chunk_size] + if target_params: + alloc_event = _alloc_gathered_grad(target_params, + param_to_state, self.rank, + self.compute_stream) + _all2all_gather(target_params, param_to_state, self.rank, + self.comm_stream, group["none_grad"], + alloc_event) + + def enqueue_computes(start_idx, chunk_size): + for p in ordered_params[start_idx:start_idx + chunk_size]: + state = param_to_state[id(p)] + _compute_u(p, state, group["ns_steps"], self.rank, + self.compute_stream) + + def enqueue_all2all_scatter(start_idx, chunk_size): + target_params = ordered_params[start_idx:start_idx + chunk_size] + if target_params: + alloc_event = _alloc_scattered_u(target_params, param_to_state, + self.rank, + self.compute_stream) + _all2all_scatter(target_params, param_to_state, self.rank, + self.comm_stream, alloc_event) + + def enqueue_update_param(start_idx, chunk_size): + for p in ordered_params[start_idx:start_idx + chunk_size]: + state = param_to_state[id(p)] + adjusted_lr = self.adjust_lr_for_muon(lr, p.shape) + _update_param(p, state, lr, adjusted_lr, weight_decay, + self.rank, self.compute_stream) + + if self.chunk_size == -1: + shard_ranks = dist.get_world_size(param_to_state[id( + params[0])].process_group) + chunk_size = shard_ranks * DEFAULT_CHUNK_SIZE_RATIO + elif self.chunk_size > 0: + chunk_size = self.chunk_size + else: + raise ValueError("chunk_size must be -1 or a positive integer.") + + # Wait grad update + 
self.comm_stream.wait_stream(torch.cuda.current_stream()) + + warmup_step = self.warmup_step + for i in range(0, warmup_step): + enqueue_all2all_gather(i * chunk_size, chunk_size) + enqueue_computes(i * chunk_size, chunk_size) + + for i in range(0, len(params) + chunk_size - 1, chunk_size): + enqueue_all2all_scatter(i, chunk_size) + enqueue_all2all_gather(i + warmup_step * chunk_size, chunk_size) + enqueue_update_param(i, chunk_size) + enqueue_computes(i + warmup_step * chunk_size, chunk_size) + + # Wait the last update_param to finish + torch.cuda.current_stream().wait_stream(self.compute_stream) + + @staticmethod + def _fused_adamw( + params: list[torch.Tensor], + grads: list[torch.Tensor], + exp_avgs: list[torch.Tensor], + exp_avg_sqs: list[torch.Tensor], + max_exp_avg_sqs: list[torch.Tensor], + state_steps: list[torch.Tensor], + amsgrad: bool, + beta1: float, + beta2: float, + lr: float | torch.Tensor, + weight_decay: float, + eps: float, + maximize: bool, + ) -> None: + if not params: + return + + # We only shuffle around the lr when it is a Tensor and on CUDA, otherwise, we prefer + # treating it as a scalar. 
+ lr_dict: DeviceDict | None = ({ + lr.device: lr + } if isinstance(lr, torch.Tensor) and str(lr.device) != "cpu" else + None) + grouped_tensors = torch.optim.Optimizer._group_tensors_by_device_and_dtype( + [ + params, grads, exp_avgs, exp_avg_sqs, max_exp_avg_sqs, + state_steps + ] # type: ignore[list-item] + ) + for (device, _), ( + ( + device_params_, + device_grads_, + device_exp_avgs_, + device_exp_avg_sqs_, + device_max_exp_avg_sqs, + device_state_steps_, + ), + _, + ) in grouped_tensors.items(): + device_params = cast(list[torch.Tensor], device_params_) + device_grads = cast(list[torch.Tensor], device_grads_) + device_exp_avgs = cast(list[torch.Tensor], device_exp_avgs_) + device_exp_avg_sqs = cast(list[torch.Tensor], device_exp_avg_sqs_) + device_state_steps = cast(list[torch.Tensor], device_state_steps_) + + if lr_dict is not None and device not in lr_dict: + lr_dict[device] = lr.to( + device=device, + non_blocking=True) # type: ignore[union-attr] + lr = lr_dict[device] + torch._foreach_add_(device_state_steps, 1) + func = torch._fused_adamw_ + func( + device_params, + device_grads, + device_exp_avgs, + device_exp_avg_sqs, + device_max_exp_avg_sqs, # type: ignore[arg-type] + device_state_steps, + amsgrad=amsgrad, + lr=lr, # type: ignore[arg-type] + beta1=beta1, + beta2=beta2, + weight_decay=weight_decay, + eps=eps, + maximize=maximize, + ) + + def _step_muon(self, group, qk_logits=None): + params = group["params"] + lr = group["lr"] + weight_decay = group["weight_decay"] + momentum = group["momentum"] + names = group["names"] + + param_dtensors = [] + name_dtensors = [] + + param_tensors = [] + name_tensors = [] + + param_dtensors_small = [] + name_dtensors_small = [] + + if self.use_distributed_muon: + self.distributed_muon(names=names, + params=params, + group=group, + lr=lr, + weight_decay=weight_decay, + momentum=momentum, + qk_logits=qk_logits) + return + + # For simplicity, we use distributed Muon for small parameters + # whose number of elements is 
below a threshold. + for n, p in zip(names, params): + if p is None or p.grad is None: + continue + if isinstance(p.data, DTensor): + if all( + isinstance(placement, Replicate) + for placement in p.placements): + param_tensors.append(p) + name_tensors.append(n) + elif p.data.numel() <= self.small_param_numel_threshold: + param_dtensors_small.append(p) + name_dtensors_small.append(n) + else: + param_dtensors.append(p) + name_dtensors.append(n) + elif isinstance(p.data, torch.Tensor): + param_tensors.append(p) + name_tensors.append(n) + else: + raise TypeError(f"Unsupported parameter type: {type(p.data)}") + + logger.debug( + f"[Muon] {len(param_dtensors)} DTensors, {len(param_tensors)} Tensors, " + f"{len(param_dtensors_small)} Small DTensors") + + def group_dtensors(dtensors, names): + # To support different placements, we group parameters by placements + # and run parallel Muon on each group. + + placement_to_params = defaultdict(lambda: ([], [])) + # type: dict[tuple[Placement, DeviceMesh], tuple[list[str], list[DTensor]]] + + assert len(dtensors) == len(names) + for p, n in zip(dtensors, names): + placement_to_params[tuple([p.placements, + p.device_mesh])][0].append(n) + placement_to_params[tuple([p.placements, + p.device_mesh])][1].append(p) + return placement_to_params + + if len(param_dtensors_small) > 0: + if not dist.is_initialized(): + raise RuntimeError( + "Parallel Muon requires torch.distributed to be initialized." + ) + + self.distributed_muon( + params=param_dtensors_small, + names=name_dtensors_small, + group=group, + lr=lr, + weight_decay=weight_decay, + momentum=momentum, + qk_logits=qk_logits, + ) + + if len(param_dtensors) > 0: + if not dist.is_initialized(): + raise RuntimeError( + "Parallel Muon requires torch.distributed to be initialized." 
+ ) + + dtensor_group = group_dtensors(param_dtensors, name_dtensors) + for _, (names, params) in dtensor_group.items(): + self.parallel( + names, + params, + group, + lr=lr, + weight_decay=weight_decay, + momentum=momentum, + qk_logits=qk_logits, + ) + + if len(param_tensors) > 0: + self.base( + name_tensors, + param_tensors, + group, + lr=lr, + weight_decay=weight_decay, + momentum=momentum, + qk_logits=qk_logits, + ) + + def _step_adamw_params(self, params, group): + params_with_grads = [] + grads = [] + moment1 = [] + moment2 = [] + max_exp_avg_sqs = [] + state_steps = [] + lr = group["lr"] + beta1, beta2 = group["adamw_betas"] + eps = group["adamw_eps"] + weight_decay = group["weight_decay"] + + for p in params: + g = p.grad + if g is None: + continue + state = self.state[p] + params_with_grads.append(p) + grads.append(g) + if "step" not in state: + state["step"] = (torch.zeros((), + dtype=torch.float32, + device=p.device)) + state["moment1"] = torch.zeros_like(g) + state["moment2"] = torch.zeros_like(g) + moment1.append(state["moment1"]) + moment2.append(state["moment2"]) + if not isinstance(state["step"], torch.Tensor): + step_tensor = torch.tensor(state["step"], + dtype=torch.float32, + device=p.device) + else: + step_tensor = state["step"] + state_steps.append(step_tensor) + + self._fused_adamw( + params_with_grads, + grads, + moment1, + moment2, + max_exp_avg_sqs, + state_steps, + amsgrad=False, + beta1=beta1, + beta2=beta2, + lr=lr, + weight_decay=weight_decay, + eps=eps, + maximize=False, + ) + + def _step_adamw(self, group): + params = group["params"] + + # group params with it's type and placement + placement_to_params: dict[tuple[Placement | type, + DeviceMesh | None]] = defaultdict(list) + for p in params: + match p: + case DTensor(): + placement_to_params[tuple([p.placements, + p.device_mesh])].append(p) + case torch.Tensor(): + placement_to_params[tuple([torch.Tensor, None])].append(p) + + for params in placement_to_params.values(): + 
self._step_adamw_params(params, group) + + @torch.no_grad + def step(self, closure=None, qk_logits=None): + """Perform a single optimization step. + + Args: + closure (Callable, optional): A closure that reevaluates the model + and returns the loss. + qk_logits (dict[int, Tensor], optional): A dictionary mapping layer indices + to 1D tensors of shape (num_heads,), representing the maximum + QK logits across all tokens, computed as + (1 / sqrt(head_dim)) * (Q @ K^T). + """ + loss = None + if closure is not None: + with torch.enable_grad(): + loss = closure() + + for group in self.param_groups: + if group["use_muon"]: + self._step_muon(group, qk_logits=qk_logits) + else: + self._step_adamw(group) + + return loss diff --git a/build/torch210-cxx11-rocm71-x86_64-linux/optimizer/__init__.py b/build/torch210-cxx11-rocm71-x86_64-linux/optimizer/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..03dbc1afe1cf156661a2b1b22003cd5f599a0309 --- /dev/null +++ b/build/torch210-cxx11-rocm71-x86_64-linux/optimizer/__init__.py @@ -0,0 +1,26 @@ +import ctypes +import sys + +import importlib +from pathlib import Path +from types import ModuleType + +def _import_from_path(file_path: Path) -> ModuleType: + # We cannot use the module name as-is, after adding it to `sys.modules`, + # it would also be used for other imports. So, we make a module name that + # depends on the path for it to be unique using the hex-encoded hash of + # the path. 
+ path_hash = "{:x}".format(ctypes.c_size_t(hash(file_path.absolute())).value) + module_name = path_hash + spec = importlib.util.spec_from_file_location(module_name, file_path) + if spec is None: + raise ImportError(f"Cannot load spec for {module_name} from {file_path}") + module = importlib.util.module_from_spec(spec) + if module is None: + raise ImportError(f"Cannot load module {module_name} from spec") + sys.modules[module_name] = module + spec.loader.exec_module(module) # type: ignore + return module + + +globals().update(vars(_import_from_path(Path(__file__).parent.parent / "__init__.py"))) diff --git a/build/torch28-cxx11-cu126-x86_64-linux/__init__.py b/build/torch28-cxx11-cu126-x86_64-linux/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..239c7a65f8293e7d0df28f05fce645af56d628c0 --- /dev/null +++ b/build/torch28-cxx11-cu126-x86_64-linux/__init__.py @@ -0,0 +1,5 @@ +from .muon import Muon + +__all__ = [ + "Muon", +] diff --git a/build/torch28-cxx11-cu126-x86_64-linux/_ops.py b/build/torch28-cxx11-cu126-x86_64-linux/_ops.py new file mode 100644 index 0000000000000000000000000000000000000000..e6f6fcf6280e969b1761926112147d3146e27b59 --- /dev/null +++ b/build/torch28-cxx11-cu126-x86_64-linux/_ops.py @@ -0,0 +1,9 @@ +import torch +from . import _optimizer_06a260a_dirty +ops = torch.ops._optimizer_06a260a_dirty + +def add_op_namespace_prefix(op_name: str): + """ + Prefix op by namespace. 
+ """ + return f"_optimizer_06a260a_dirty::{op_name}" \ No newline at end of file diff --git a/build/torch28-cxx11-cu126-x86_64-linux/_optimizer_06a260a_dirty.abi3.so b/build/torch28-cxx11-cu126-x86_64-linux/_optimizer_06a260a_dirty.abi3.so new file mode 100755 index 0000000000000000000000000000000000000000..a218cd77694938fb0914270a5c6416a684d50cb3 --- /dev/null +++ b/build/torch28-cxx11-cu126-x86_64-linux/_optimizer_06a260a_dirty.abi3.so @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:222315672693e6d4544b1eee4772dc7be744b3794cfd6ff370a6f46d782386a1 +size 1936664 diff --git a/build/torch28-cxx11-cu126-x86_64-linux/distributed/utils.py b/build/torch28-cxx11-cu126-x86_64-linux/distributed/utils.py new file mode 100644 index 0000000000000000000000000000000000000000..6d5843506c13d9d31603b2b4e30c1c91d0baab28 --- /dev/null +++ b/build/torch28-cxx11-cu126-x86_64-linux/distributed/utils.py @@ -0,0 +1,175 @@ +import torch +import torch.distributed as dist +from torch.distributed import ProcessGroup +from torch.distributed.device_mesh import DeviceMesh +from torch.distributed.tensor import DTensor +from torch.distributed.tensor.placement_types import (Placement, Shard, + _StridedShard) + + +def get_slices_of_dtensor( + target: DTensor | torch.Tensor, + local_rank: int, + shard_mesh: DeviceMesh, + shard_placements: tuple[Placement], +) -> tuple[slice]: + """ + Get the slice of local tensor for a given rank from a tensor. + Args: + target (DTensor | torch.Tensor): The target tensor. + rank (int): The local rank of the shard group. + shard_mesh (DeviceMesh): The shard mesh. It consists of global ranks. + shard_placements (tuple[Placement]): The shard placements. 
+ """ + + slices: list[slice] = [slice(0, dim_size) for dim_size in target.size()] + + # find the global rank of the local rank in the shard mesh + rank = sorted(shard_mesh.mesh.flatten().tolist())[local_rank] + + rank_coords = (shard_mesh.mesh == rank).nonzero() + + assert len(rank_coords) == 1 + rank_coords = tuple(rank_coords[0].tolist()) + + assert len(rank_coords) == len(shard_placements) + + # Caution: Assuming replicate-to-shard of the shard mesh goes with + # left-to-right sharding. This is ensured by the sorting logic of + # construct_shard_mesh function. + for i, (rank_coord, + placement) in enumerate(zip(rank_coords, shard_placements)): + assert isinstance(placement, Shard) + + num_ranks = shard_mesh.mesh.shape[i] + + dim = placement.dim + dim_size = (slices[dim].stop - slices[dim].start) + + if dim_size % num_ranks != 0: + raise NotImplementedError( + f"Dimension size {dim_size} is not divisible " + f"by number of ranks {num_ranks} for shard " + f"placement on dim {dim}. (shape: {target.shape})") + + shard_size = dim_size // num_ranks + + start = slices[dim].start + rank_coord * shard_size + end = start + shard_size + + assert start < end <= slices[dim].stop + + slices[dim] = slice(start, end) + + return tuple(slices) + + +_ranks_to_dist_cache: dict[tuple[int, ...], tuple[DeviceMesh, + ProcessGroup]] = dict() + + +def construct_shard_mesh( + placements: tuple[Placement], + mesh: DeviceMesh, +) -> (DeviceMesh, ProcessGroup, tuple[Placement]): + """ + Construct Shard Mesh and Placements for unsharding. + It removes Replicate placements and constructs a new Mesh and ProcessGroup. + """ + my_rank = dist.get_rank() + + assert mesh.mesh.device.type == 'cpu' + + # Copy mesh to avoid modifying the original mesh + mesh = mesh.mesh.clone() + + # 1. Sort placements. Replicate first, then Shard by dim ascending. + + # For Shard, strided shard comes after regular shard on the same dim + # to preserve left-to-right order of replicate-to-shard. 
+ # This is because that strided shard is using stride to represent + # more fine-grained sharding on the same dim. + # Please check the URL below for _StridedShard. + # https://github.com/pytorch/pytorch/blob/v2.8.0/torch/distributed/tensor/placement_types.py#L366 + + def placement_sort_key( + placement_with_index: tuple[float, Placement] + ) -> tuple[int, float, int]: # (dim, split factor, original index) + index, placement = placement_with_index + is_replicate = placement.is_replicate() + is_shard = placement.is_shard() + is_partial = placement.is_partial() + + assert is_replicate or is_shard, f"Unsupported placement type: {type(placement)}" + assert not is_partial, "Partial placement is not supported." + + if is_replicate: + return (-1.0, 0, index) + elif is_shard: + if isinstance(placement, _StridedShard): + return (placement.dim, 1 / placement.split_factor, index) + return (placement.dim, 0, index) + else: + raise TypeError(f"Unknown placement type: {type(placement)}") + + placements_with_index: list[tuple[int, + Placement]] = list(enumerate(placements)) + placements_with_index = sorted(placements_with_index, + key=placement_sort_key) + + sorted_indices, sorted_placements = zip(*placements_with_index) + + # 2. Permute mesh according to sorted placements. + sorted_mesh = mesh.permute(sorted_indices) + + # 3. 
Collect list of shard meshes by removing replicate dims + # For example, (2, 3, 4, 4) with placements [R, R, S(0), S(1)] + # shard_meshes should be list with 2 * 3 = 6 shard meshes of shape (4, 4) + num_replicates = sum(1 for p in sorted_placements if p.is_replicate()) + + # merge replicate dims + # shard_meshes became a list of shard meshes with a length of replicate degree + if num_replicates > 0: + sorted_mesh = sorted_mesh.flatten( + 0, num_replicates - 1) if num_replicates > 1 else sorted_mesh + shard_meshes = list(torch.unbind(sorted_mesh, dim=0)) + else: + shard_meshes = [sorted_mesh] + shard_placements = sorted_placements[num_replicates:] + + # assume all shard placements are different + assert len(shard_placements) == len(set(shard_placements)) + + # 4. Construct ProcessGroups + # Caution: all groups should be created in the same order in all processes, + # even though each process only needs its own group. + + # To use tensor as dict key, convert it to tuple + def tensor_to_tuple(t): + if isinstance(t, torch.Tensor): + t = t.tolist() + if isinstance(t, list): + return tuple(tensor_to_tuple(x) for x in t) + return t + + my_shard_mesh_as_tuple = None + for shard_mesh in shard_meshes: + assert isinstance(shard_mesh, torch.Tensor) + shard_mesh_as_tuple = tensor_to_tuple(shard_mesh) + + if (my_rank == shard_mesh).any().item(): + assert my_shard_mesh_as_tuple is None + my_shard_mesh_as_tuple = shard_mesh_as_tuple + + # update global cache + if shard_mesh_as_tuple not in _ranks_to_dist_cache: + shard_process_group = dist.new_group(shard_mesh.flatten().tolist()) + _ranks_to_dist_cache[shard_mesh_as_tuple] = ( + DeviceMesh(device_type="cuda", mesh=shard_mesh), + shard_process_group, + ) + + my_shard_mesh, my_shard_process_group = _ranks_to_dist_cache[ + my_shard_mesh_as_tuple] + + return my_shard_mesh, my_shard_process_group, shard_placements diff --git a/build/torch29-cxx11-cu126-x86_64-linux/optimizer/matmul_transpose_triton.py 
b/build/torch28-cxx11-cu126-x86_64-linux/matmul_transpose_triton.py similarity index 100% rename from build/torch29-cxx11-cu126-x86_64-linux/optimizer/matmul_transpose_triton.py rename to build/torch28-cxx11-cu126-x86_64-linux/matmul_transpose_triton.py diff --git a/build/torch28-cxx11-cu126-x86_64-linux/metadata.json b/build/torch28-cxx11-cu126-x86_64-linux/metadata.json new file mode 100644 index 0000000000000000000000000000000000000000..76bafa5f33b6818aa6bb4cab04be811b87519b44 --- /dev/null +++ b/build/torch28-cxx11-cu126-x86_64-linux/metadata.json @@ -0,0 +1 @@ +{"python-depends":[]} \ No newline at end of file diff --git a/build/torch28-cxx11-cu126-x86_64-linux/muon.py b/build/torch28-cxx11-cu126-x86_64-linux/muon.py new file mode 100644 index 0000000000000000000000000000000000000000..dbf25575f185ff379789482068e4ecf55b9455a9 --- /dev/null +++ b/build/torch28-cxx11-cu126-x86_64-linux/muon.py @@ -0,0 +1,1268 @@ +import logging +import math +import types +from collections import defaultdict +from dataclasses import dataclass +from typing import Any, cast + +import torch +import torch.distributed as dist +from torch.distributed import ProcessGroup +from torch.distributed.device_mesh import DeviceMesh +from torch.distributed.tensor import DTensor, Replicate +from torch.distributed.tensor.placement_types import Placement + +from .distributed.utils import construct_shard_mesh, get_slices_of_dtensor +from .matmul_transpose_triton import matmul_transpose_assign + +logger = logging.getLogger(__name__) + +COMM_DTYPE = torch.bfloat16 +DEFAULT_CHUNK_SIZE_RATIO = 4 + + +# This code snippet is a modified version adapted from the following GitHub repositories: +# https://github.com/KellerJordan/Muon/blob/master/muon.py +# Muon's Newton–Schulz iteration causes high variance in singular values +# Idea: give each iteration its own 3 coefficients and optimize them via gradient descent. 
+@torch.no_grad() +# matmul_transpose_assign from : https://github.com/nil0x9/flash-muon +def _zeropower_via_newtonschulz5(G, steps): + """ + Newton-Schulz iteration to compute the zeroth power / orthogonalization of G. We opt to use a + quintic iteration whose coefficients are selected to maximize the slope at zero. For the purpose + of minimizing steps, it turns out to be empirically effective to keep increasing the slope at + zero even beyond the point where the iteration no longer converges all the way to one everywhere + on the interval. This iteration therefore does not produce UV^T but rather something like US'V^T + where S' is diagonal with S_{ii}' ~ Uniform(0.5, 1.5), which turns out not to hurt model + performance at all relative to UV^T, where USV^T = G is the SVD. + """ + assert len(G.shape) == 2 + assert G.dtype == COMM_DTYPE + X = G # no manual typecast + + if G.size(0) > G.size(1): + X = X.T + # Ensure spectral norm is at most 1 + X = X / (X.norm() + 1e-7) + buf1 = torch.empty(X.size(0), X.size(0), dtype=X.dtype, device=X.device) + buf2 = torch.empty(X.size(0), X.size(0), dtype=X.dtype, device=X.device) + # Perform the NS iterations + for a, b, c in [ + (4.0848, -6.8946, 2.9270), + (3.9505, -6.3029, 2.6377), + (3.7418, -5.5913, 2.3037), + (2.8769, -3.1427, 1.2046), + (2.8366, -3.0525, 1.2012), + ]: + matmul_transpose_assign(X, buf1) + matmul_transpose_assign(buf1, buf2) + buf1.mul_(b).add_(buf2, alpha=c) + X = torch.addmm(X, buf1, X, alpha=1.0, beta=a) + + if G.size(0) > G.size(1): + X = X.T + return X + + +@dataclass +class _muon_state: + # TODO: use Optional + worker_rank: int + process_group: ProcessGroup + shard_mesh: DeviceMesh + shard_placements: tuple[Placement, ...] 
+ name: str + qk_clip_state: torch.Tensor | None = None + gathered_grad: torch.Tensor | None = None + scattered_u: DTensor | None = None + computed_u: torch.Tensor | None = None + gather_event: torch.cuda.Event | None = None + compute_event: torch.cuda.Event | None = None + scatter_event: torch.cuda.Event | None = None + + +def numel_for_rank( + param: DTensor, + local_rank: int, + state: _muon_state, +) -> int: + slices = get_slices_of_dtensor( + param, + local_rank, + state.shard_mesh, + state.shard_placements, + ) + + numel = 1 + for s, dim in zip(slices, param.shape): + start, stop, step = s.indices(dim) + length = max(0, (stop - start + (step - 1)) // step) + numel *= length + + return numel + + +@torch.no_grad() +def _alloc_gathered_grad(params, param_to_state, rank, compute_stream): + """ + Pre-allocate gathered_grad buffer on compute_stream + before launching all2all gather + """ + with torch.cuda.stream(compute_stream): + for p in params: + state = param_to_state[id(p)] + if rank == state.worker_rank: + state.gathered_grad = torch.empty(p.shape, + dtype=COMM_DTYPE, + device="cuda") + else: + state.gathered_grad = None + + alloc_event = torch.cuda.Event() + alloc_event.record(compute_stream) + return alloc_event + + +@torch.no_grad() +def _all2all_gather(params, param_to_state, rank, comm_stream, none_grad, + alloc_event): + """ + All2all gathers shards so each owner rank reconstructs its full gradient + """ + with torch.cuda.stream(comm_stream): + process_group = param_to_state[id(params[0])].process_group + num_ranks = dist.get_world_size(group=process_group) + + # Construct sending buffers + per_dst = [[] for _ in range(num_ranks)] + send_counts = [0] * num_ranks + + for p in params: + state = param_to_state[id(p)] + dst = state.worker_rank + assert dst < num_ranks + shard_elems = numel_for_rank(p, rank, state) + g = p.grad + g = g.to_local().to(COMM_DTYPE).contiguous() + assert g.numel() == shard_elems + per_dst[dst].append(g.view(-1)) + 
send_counts[dst] += shard_elems + + assert any( + len(v) > 0 for v in per_dst + ), "At least one destination rank must receive a sharded tensor" + # list[list[Tensor]] -> list[Tensor] + per_dst = [t for dst in per_dst for t in dst] + + send_buf = torch.cat(per_dst, dim=0) + + owned_params = [ + p for p in params if param_to_state[id(p)].worker_rank == rank + ] + + # Compute receive sizes and allocate receiving buffers + recv_counts = [0] * num_ranks + + for src in range(num_ranks): + total = 0 + for p in owned_params: + state = param_to_state[id(p)] + assert state.worker_rank == rank + total += numel_for_rank(p, src, state) + recv_counts[src] = total + + recv_total = sum(recv_counts) + recv_buf = torch.empty(recv_total, dtype=COMM_DTYPE, device="cuda") + + #All2All + logger.debug(f"send_buf size: {send_buf.numel()}, " + f"recv_buf size: {recv_buf.numel()}, " + f"recv_counts: {recv_counts}, " + f"send_counts: {send_counts}, " + f"process_group: {str(process_group)}") + dist.all_to_all_single( + recv_buf, + send_buf, + output_split_sizes=recv_counts, + input_split_sizes=send_counts, + group=process_group, + ) + + # Reconstructs gathered grad from the received buffer + # + # recv_buf (num ranks = 3) + # + # From rank 0 From rank 1 From rank 2 + # | p1_0, p2_0, p3_0 | p1_1, p2_1, p3_1 | p1_2, p2_2, p3_2 | + # + # Outer loop: + # rank 0 -> rank 1 -> rank2 + # + # Inner loop: + # p1_n -> p2_n -> p3_n + + comm_stream.wait_event(alloc_event) + + off = 0 + for src in range(num_ranks): + if recv_counts[src] == 0: + continue + + block = recv_counts[src] + inner_off = 0 + for p in owned_params: + state = param_to_state[id(p)] + assert state.worker_rank == rank + + # get the slice of the full dtensor corresponding to rank src. 
+ slices = get_slices_of_dtensor(state.gathered_grad, src, + state.shard_mesh, + state.shard_placements) + + dst = state.gathered_grad[slices] + assert dst._base is state.gathered_grad + + n = dst.numel() + assert n > 0 + + sg = recv_buf.narrow(0, off + inner_off, n) + sg = sg.reshape_as(dst) + dst.copy_(sg) + + inner_off += n + off += block + + for p in params: + state = param_to_state[id(p)] + if state.worker_rank == rank: + state.gather_event = torch.cuda.Event() + state.gather_event.record(comm_stream) + else: + state.gathered_grad = None + state.gather_event = None + if none_grad: + p.grad = None + + +@torch.no_grad() +def _compute_u(p, state, steps, rank, compute_stream): + """ + On worker_rank, compute the orthogonalized update using Newton-Schulz iteration. + """ + with torch.cuda.stream(compute_stream): + if rank == state.worker_rank: + if state.gather_event is None: + raise RuntimeError("Gather event must be set before compute.") + compute_stream.wait_event(state.gather_event) + u = _zeropower_via_newtonschulz5(state.gathered_grad, steps) + state.gathered_grad = None + state.computed_u = u + state.compute_event = torch.cuda.Event() + state.compute_event.record() + else: + state.computed_u = None + state.compute_event = None + + +@torch.no_grad() +def _alloc_scattered_u(params, param_to_state, rank, compute_stream): + """ + Pre-allocate scattered_u buffer on compute_stream + before launching all2all gather + """ + with torch.cuda.stream(compute_stream): + for p in params: + state = param_to_state[id(p)] + state.scattered_u = torch.empty_like(p.to_local(), + dtype=COMM_DTYPE) + + alloc_event = torch.cuda.Event() + alloc_event.record(compute_stream) + return alloc_event + + +def _all2all_scatter(params, param_to_state, rank, comm_stream, alloc_event): + """ + All2all scatters full gradients to all ranks + """ + with torch.cuda.stream(comm_stream): + process_group = param_to_state[id(params[0])].process_group + num_ranks = 
dist.get_world_size(group=process_group) + owned_params = [ + p for p in params if param_to_state[id(p)].worker_rank == rank + ] + + # Construct sending buffer + per_dst = [[] for _ in range(num_ranks)] + send_counts = [0] * num_ranks + + if owned_params: + for p in owned_params: + state = param_to_state[id(p)] + if state.compute_event is None: + raise RuntimeError( + "Compute event must be set before scatter.") + comm_stream.wait_event(state.compute_event) + state.gathered_grad = None + + assert state.computed_u is not None + + u_full = state.computed_u.to(COMM_DTYPE).contiguous() + + offset = 0 + for dst in range(num_ranks): + # get the slice of the full tensor corresponding to rank dst. + slices = get_slices_of_dtensor(u_full, dst, + state.shard_mesh, + state.shard_placements) + su = u_full[slices].flatten() + + n = su.numel() + assert n > 0 + + per_dst[dst].append(su) + send_counts[dst] += n + offset += n + + assert offset == u_full.numel() + + lengths = [len(v) for v in per_dst] + if all(l > 0 for l in lengths): + assert all( + l == lengths[0] for l in lengths + ), "All destination ranks must have the same number of sharded tensor" + # list[list[Tensor]] -> list[Tensor] + per_dst = [t for dst in per_dst for t in dst] + send_buf = torch.cat(per_dst, dim=0) + else: + # all_to_all requires participation from all ranks + # Even non-owner ranks must join the collective call + send_buf = torch.empty(0, dtype=COMM_DTYPE, device="cuda") + + # Compute receive sizes and allocate receiving buffers + recv_counts = [0] * num_ranks + + for src in range(num_ranks): + total = 0 + for p in params: + state = param_to_state[id(p)] + if state.worker_rank != src: + continue + total += numel_for_rank(p, rank, state) + recv_counts[src] = total + + recv_total = sum(recv_counts) + assert recv_total > 0 + recv_buf = torch.empty(recv_total, dtype=COMM_DTYPE, device="cuda") + + #All2All + dist.all_to_all_single( + recv_buf, + send_buf, + output_split_sizes=recv_counts, + 
input_split_sizes=send_counts, + group=process_group, + ) + + # Copy to pre-allocated scattered_u buffer from the received buffer + # + # recv_buf (num ranks = 3, local_rank = 0) + # + # From rank 0 From rank 1 From rank 2 + # | p1_0, p2_0, p3_0 | p4_0 | p5_0, p6_0 | + # + # Outer loop: + # rank 0 -> rank 1 -> rank2 + # + # Inner loop: + # src(0) : p1_0 -> p2_0 -> p3_0 + # src(1) : p4_0 + # src(2) : p5_0 -> p6_0 + + comm_stream.wait_event(alloc_event) + + off = 0 + for src in range(num_ranks): + block = recv_counts[src] + if block == 0: + continue + + inner_off = 0 + for p in params: + state = param_to_state[id(p)] + if state.worker_rank != src: + continue + n = numel_for_rank(p, rank, state) + assert n > 0 + + flat_local = recv_buf.narrow(0, off + inner_off, + n).view_as(p.to_local()) + state.scattered_u.copy_(flat_local) + + state.scatter_event = torch.cuda.Event() + state.scatter_event.record(comm_stream) + inner_off += n + + assert inner_off == block + off += block + + +def _update_param(p, state, lr, adjusted_lr, weight_decay, rank, + compute_stream): + """ + Update sharded parameter p with the scattered_u. + Only worker_rank frees computed_u. 
+ """ + with torch.cuda.stream(compute_stream): + if state.scatter_event is None: + raise RuntimeError("Scatter event must be set before update") + compute_stream.wait_event(state.scatter_event) + u_dtensor = DTensor.from_local( + state.scattered_u, + placements=p.placements, + device_mesh=p.device_mesh, + ) + + state.scattered_u = u_dtensor + + if rank == state.worker_rank: + # Free computed_u + state.computed_u = None + + Muon._update_p(p, state.scattered_u, lr, adjusted_lr, weight_decay) + state.scattered_u = None + u_dtensor = None + + scales_full = Muon._compute_scales( + p, + state.qk_clip_state) if state.qk_clip_state is not None else None + if scales_full is not None: + # Have to slice scales_full among dim 0 + weight_slices = get_slices_of_dtensor(p, rank, state.shard_mesh, + state.shard_placements) + ratio = p.shape[0] // scales_full.shape[0] + scales_slice = slice( + None if weight_slices[0].start is None else + weight_slices[0].start // ratio, + None if weight_slices[0].stop is None else + weight_slices[0].stop // ratio, + None, + ) + + scales_local = scales_full[scales_slice] + scales_local = DTensor.from_local( + scales_local, + placements=p.placements, + device_mesh=p.device_mesh, + ) + Muon._qk_clip(p, scales_local, state.qk_clip_state.head_dim) + + +def default_is_muon(name, x): + skip_keys = ["embed_tokens", "lm_head", "tok_embeddings", "output"] + return x.ndim >= 2 and not any(key in name for key in skip_keys) + + +def get_default_muon_param_groups(model, is_muon_func=default_is_muon): + muon_params, muon_names = [], [] + non_muon_params = [] + + for n, p in model.named_parameters(): + if not p.requires_grad: + continue + if is_muon_func(n, p): + muon_params.append(p) + muon_names.append(n) + else: + non_muon_params.append(p) + + return [ + { + "params": muon_params, + "names": muon_names, + "use_muon": True, + }, + { + "params": non_muon_params, + "use_muon": False, + }, + ] + + +def parse_qk_layer(name: str) -> tuple[str | None, int]: + """ + 
Parse a parameter name to check if it is a query/key projection layer + ('wq', 'wk', 'q_proj', 'k_proj') and return (kind, layer_index). + + Returns: + (kind, layer_idx) or (None, -1) if not matched. + + Example: + 'model.3.attn.wq.weight' -> ('wq', 3) + 'model.5.attn.wk.weight' -> ('wk', 5) + 'model.2.attn.q_proj.weight' -> ('q_proj', 2) + 'model.7.attn.k_proj.weight' -> ('k_proj', 7) + 'model.4.attn.v_proj.weight' -> (None, -1) + """ + parts = name.split('.') + if len(parts) < 3: + return None, -1 + + kind = parts[-2] + + layer_idx = -1 + for part in reversed(parts): + if part.isdigit(): + layer_idx = int(part) + break + + if kind in ('wq', 'wk', 'q_proj', 'k_proj'): + return kind, layer_idx + + return None, -1 + + +@dataclass +class QKClipInfo: + """Per-parameter dynamic info computed from config + runtime logits.""" + kind: str | None # 'wq'/'q_proj' or 'wk'/'k_proj' or None + indices: list[int] # which heads to consider for clipping + head_dim: int # from config + threshold: float # from config + logit: torch.Tensor | None + + +class Muon(torch.optim.Optimizer): + """ + Muon - MomentUm Orthogonalized by Newton-schulz + + Muon internally runs standard SGD-momentum, and then performs an orthogonalization post- + processing step, in which each 2D parameter's update is replaced with the nearest orthogonal + matrix. To efficiently orthogonalize each update, we use a Newton-Schulz iteration, which has + the advantage that it can be stably run in bfloat16 on the GPU. + + Some warnings: + - We believe this optimizer is unlikely to work well for training with small batch size. + - We believe it may not work well for finetuning pretrained models, but we haven't tested this. + + Arguments: + model: The model to be optimized by Muon. + is_muon_func: A function that takes a parameter and its name, and returns whether the parameter should be optimized by Muon. + lr: The learning rate. The updates will have spectral norm of `lr`. 
(0.02 is a good default) + momentum: The momentum used by the internal SGD. (0.95 is a good default) + nesterov: Whether to use Nesterov-style momentum in the internal SGD. (recommended) + ns_steps: The number of Newton-Schulz iterations to run. (6 is probably always enough) + weight_decay: The weight decay for Muon and AdamW. + {0, 1}-D or are detected as being the embed or lm_head will be optimized by AdamW as well. + adamw_lr: The learning rate for the internal AdamW. + adamw_betas: The betas for the internal AdamW. + adamw_eps: The epsilon for the internal AdamW. + none_grad: Whether to set p.grad to None after gathering the gradients. This can save memory. + debug: Whether to print debug information. + clip_info : Configuration for QK clipping. Expected keys: + - "q_indices" (list[int]): Indices of query heads to consider. + - "k_indices" (list[int]): Indices of key heads to consider. + - "head_dim" (int): Dimensionality of each attention head. + - "threshold" (float): Threshold value; heads whose QK logits exceed + this value will be scaled down. + Default is: + { + "q_indices": [], + "k_indices": [], + "head_dim": 128, + "threshold": 100 + } + warmup_step : How many all2all gather, compute operations are launched in advance + before the corresponding all2all scatter steps begin. + A higher warmup_step increases memory usage but can improve + performance by overlapping communication. + Parallel muon only. + chunk_size : Batch size of parameters to process in each + all2all gather/compute/scatter step. + Use shard ranks * DEFAULT_CHUNK_SIZE_RATIO when -1 is specified. + use_distributed_muon: Use distributed muon by Liu et al. (2024). + For testing purpose only. 
+ small_param_numel_threshold: Threshold for classifying parameters as small and falling back to distributed Muon + """ + + def __init__(self, + params, + lr=1e-3, + momentum=0.95, + nesterov=True, + ns_steps=5, + weight_decay=0.1, + adamw_betas=(0.9, 0.95), + adamw_eps=1e-8, + none_grad=True, + debug=False, + clip_config={ + "q_indices": [], + "k_indices": [], + "head_dim": 128, + "threshold": 100 + }, + warmup_step=5, + chunk_size=-1, + use_distributed_muon=False, + small_param_numel_threshold=65536): + defaults = dict( + lr=lr, + weight_decay=weight_decay, + momentum=momentum, + nesterov=nesterov, + ns_steps=ns_steps, + adamw_betas=adamw_betas, + adamw_eps=adamw_eps, + none_grad=none_grad, + use_muon=True, + ) + error_message = "The key 'use_muon' is not set in parameter group {idx}. Assuming all parameters in the group will use muon optimization, which may lead to unexpected behavior." + instruction_code = "\n\n please follow this code snippet \n```optimizer = get_kernel('motif-technologies/optimizer')\n\n\nparams = optimizer.muon.get_default_muon_param_groups(model)\n\noptim = optimizer.Muon(params, ...)```" + + if isinstance(params, types.GeneratorType): + raise ValueError(error_message.format(idx=0) + instruction_code) + for _idx, param_group in enumerate(params): + if param_group.get("use_muon", None) is None: + raise ValueError( + error_message.format(idx=_idx) + instruction_code) + + super().__init__(params, defaults) + + self.rank = None + + self.comm_stream = torch.cuda.Stream() + self.compute_stream = torch.cuda.Stream() + self.debug = debug + self.clip_config = clip_config + self.warmup_step = warmup_step + self.chunk_size = chunk_size + self.use_distributed_muon = use_distributed_muon + self.small_param_numel_threshold = small_param_numel_threshold + + def _calc_flops(self, G, steps): + assert len(G.shape) == 2 + M, N = G.shape + if M > N: + M, N = N, M + + return steps * ((M**3) * 2 + (M**2 * N) * 4 + M * N * 2 + M**2 * 3) + + def 
adjust_lr_for_muon(self, lr, param_shape): + A, B = param_shape[:2] + # We adjust the learning rate and weight decay based on the size of the parameter matrix + # as describted in the paper + adjusted_ratio = 0.2 * math.sqrt(max(A, B)) + adjusted_lr = lr * adjusted_ratio + return adjusted_lr + + def set_rank_once(self, rank): + if self.rank is None: + self.rank = rank + else: + assert self.rank == rank + + def get_shard_mesh(self, p): + """ + Get the shard mesh for a parameter p on the given rank. + """ + assert isinstance( + p, DTensor), "Parallel Muon only supports DTensor parameters." + + shard_mesh, shard_pg, shard_placements = construct_shard_mesh( + p.placements, p.device_mesh) + + # set rank with the local rank in the shard process group + self.set_rank_once(dist.get_rank(group=shard_pg)) + + return shard_mesh, shard_pg, shard_placements + + def init_state_and_assign_params(self, names, params, group, qk_logits): + param_to_state = {} + param_to_flops = {} + + total_flops = 0 + for p in params: + g = p.grad + if g is None: + continue + assert g.ndim == 2, "Muon only supports 2D parameters." 
+ + flops = self._calc_flops(g, group["ns_steps"]) + param_to_flops[id(p)] = flops + total_flops += flops + + if self.debug: + print(f"Total TFLOPs for Muon: {total_flops / 1e12:.2f} TFLOPs", + flush=True) + + paired = list(zip(names, params)) + + paired_sorted = sorted(paired, + key=lambda x: param_to_flops[id(x[1])], + reverse=True) + + names_sorted, params_sorted = zip(*paired_sorted) + ordered_names = list(names_sorted) + ordered_params = list(params_sorted) + + round_robin = 0 + mesh = ordered_params[0].device_mesh + placements = ordered_params[0].placements + + shard_mesh, shard_pg, shard_placements = self.get_shard_mesh( + ordered_params[0]) + shard_mesh_flattened = shard_mesh.mesh.flatten() + num_ranks = dist.get_world_size(group=shard_pg) + + for n, p in zip(ordered_names, ordered_params): + if mesh != p.device_mesh: + raise ValueError("All parameters must be on the same mesh.") + if placements != p.placements: + raise ValueError("All parameters must have same placements.") + + worker_rank = shard_mesh_flattened[round_robin].item() % num_ranks + round_robin = (round_robin + 1) % len(shard_mesh_flattened) + qk_clip_state = self.get_qk_clip_info(n, qk_logits) + + param_to_state[id(p)] = _muon_state( + worker_rank=worker_rank, + process_group=shard_pg, + shard_mesh=shard_mesh, + shard_placements=shard_placements, + name=n, + qk_clip_state=qk_clip_state, + ) + + return param_to_state, ordered_params + + def base(self, names, params, group, lr, weight_decay, momentum, + qk_logits): + # generate weight updates in distributed fashion + for n, p in zip(names, params): + g = p.grad + if g is None: + continue + if g.ndim > 2: + g = g.view(g.size(0), -1) + assert g is not None + + g = self._update_g(p, g, group, momentum) + + u = _zeropower_via_newtonschulz5(g.to(COMM_DTYPE), + steps=group["ns_steps"]) + + adjusted_lr = self.adjust_lr_for_muon(lr, p.shape) + Muon._update_p(p, u, lr, adjusted_lr, weight_decay) + + qk_clip_state = self.get_qk_clip_info(n, qk_logits) + 
+ scales_full = self._compute_scales( + p, qk_clip_state) if qk_clip_state is not None else None + if scales_full is not None: + Muon._qk_clip(p, scales_full, qk_clip_state.head_dim) + + def distributed_muon( + self, + names: list[str], + params: list[torch.nn.Parameter], + group: dict[str, Any], + lr: float, + weight_decay: float, + momentum: float, + qk_logits: list[torch.Tensor | DTensor] | None, + ): + """ Implementation of Distributed Muon by Liu et al. """ + + for n, p in zip(names, params): + g = p.grad + if g is None: + continue + if g.ndim > 2: + g = g.view(g.size(0), -1) + assert g is not None + + g = self._update_g(p, g, group, momentum) + + # Gather G + if isinstance(p.data, DTensor): + g_full = g.full_tensor() + p_full = p.data.full_tensor() + else: + g_full = g + p_full = p + + u_full = _zeropower_via_newtonschulz5(g_full.to(COMM_DTYPE), + steps=group["ns_steps"]) + + adjusted_lr = self.adjust_lr_for_muon(lr, p_full.shape) + Muon._update_p(p_full, u_full, lr, adjusted_lr, weight_decay) + + qk_clip_state = self.get_qk_clip_info(n, qk_logits) + + scales_full = self._compute_scales( + p_full, qk_clip_state) if qk_clip_state is not None else None + + if scales_full is not None: + Muon._qk_clip(p_full, scales_full, qk_clip_state.head_dim) + + if isinstance(p.data, DTensor): + ndims = len(p.device_mesh.mesh.shape) + p_replicate = DTensor.from_local( + p_full, + device_mesh=p.device_mesh, + placements=[Replicate() for _ in range(ndims)], + ) + + p_sharded = p_replicate.redistribute( + device_mesh=p.device_mesh, + placements=p.placements, + ) + + p.copy_(p_sharded) + + def _update_g(self, p, g, group, momentum): + # calc update + state = self.state[p] + buf = state.setdefault("momentum_buffer", torch.zeros_like(g)) + torch.add(g, buf, alpha=momentum, out=buf) + if group["nesterov"]: + g.add_(buf, alpha=momentum) + return g + return buf + + @staticmethod + def _update_p(p, u, lr, adjusted_lr, weight_decay): + if isinstance(p, torch.nn.Parameter): + # apply 
weight decay + p.data.mul_(1 - lr * weight_decay) + # apply update + p.data.add_(u, alpha=-adjusted_lr) + else: + p.mul_(1 - lr * weight_decay) + p.add_(u, alpha=-adjusted_lr) + + def get_qk_clip_info(self, n, qk_logits): + if self.clip_config is None: + return None + + head_dim = self.clip_config.get('head_dim') + threshold = self.clip_config.get('threshold') + kind, layer_idx = parse_qk_layer(n) + + logit, indices = None, [] + if qk_logits is not None and kind is not None: + logit = qk_logits[layer_idx] + indices_key = 'q_indices' if 'q' in kind else 'k_indices' + indices = self.clip_config.get(indices_key, []) or [] + + if isinstance(logit, DTensor): + # In TP settings, qk_logits may be DTensor + # We convert it to full tensor here for simplicity + logit = logit.full_tensor() + + return QKClipInfo( + kind=kind, + indices=indices, + head_dim=head_dim, + threshold=threshold, + logit=logit, + ) + + @staticmethod + def _compute_scales(p, qk_clip_state): + kind = qk_clip_state.kind + indices = qk_clip_state.indices + head_dim = qk_clip_state.head_dim + threshold = qk_clip_state.threshold + logit = qk_clip_state.logit + + H_global = p.shape[0] // head_dim + scales_full = torch.ones(H_global, device=p.data.device) + scaling = 0 + + for logit_idx, head_idx in enumerate(indices): + v_ele = float(logit[logit_idx]) + if v_ele > threshold: + new_scale = math.sqrt(threshold / v_ele) + if new_scale < scales_full[head_idx]: + scales_full[head_idx] = new_scale + logger.info( + f"[{kind}] Head {head_idx} exceeded threshold " + f"(value={v_ele:.4f}, threshold={threshold:.4f}) -> applying scale={new_scale:.4f}" + ) + scaling += 1 + + return scales_full if scaling > 0 else None + + @staticmethod + def _qk_clip(p, scales, head_dim): + if isinstance(p, torch.nn.Parameter): + W = p.data.view(-1, head_dim, p.data.shape[1]) + W.mul_(scales.view(-1, 1, 1)) + else: + W = p.view(-1, head_dim, p.shape[1]) + W.mul_(scales.view(-1, 1, 1)) + + def parallel(self, names, params, group, lr, 
weight_decay, momentum, + qk_logits): + """ + Perform a parallel optimization step using Muon. + """ + + for p in params: + g = p.grad + if g is None: + continue + if g.ndim > 2: + g = g.view(g.size(0), -1) + + # Update g in the local rank + g = self._update_g( + p, + g, + group, + momentum=momentum, + ) + p.grad = g + + param_to_state, ordered_params = self.init_state_and_assign_params( + names, params, group, qk_logits) + + assert self.rank is not None + + def enqueue_all2all_gather(start_idx, chunk_size): + target_params = ordered_params[start_idx:start_idx + chunk_size] + if target_params: + alloc_event = _alloc_gathered_grad(target_params, + param_to_state, self.rank, + self.compute_stream) + _all2all_gather(target_params, param_to_state, self.rank, + self.comm_stream, group["none_grad"], + alloc_event) + + def enqueue_computes(start_idx, chunk_size): + for p in ordered_params[start_idx:start_idx + chunk_size]: + state = param_to_state[id(p)] + _compute_u(p, state, group["ns_steps"], self.rank, + self.compute_stream) + + def enqueue_all2all_scatter(start_idx, chunk_size): + target_params = ordered_params[start_idx:start_idx + chunk_size] + if target_params: + alloc_event = _alloc_scattered_u(target_params, param_to_state, + self.rank, + self.compute_stream) + _all2all_scatter(target_params, param_to_state, self.rank, + self.comm_stream, alloc_event) + + def enqueue_update_param(start_idx, chunk_size): + for p in ordered_params[start_idx:start_idx + chunk_size]: + state = param_to_state[id(p)] + adjusted_lr = self.adjust_lr_for_muon(lr, p.shape) + _update_param(p, state, lr, adjusted_lr, weight_decay, + self.rank, self.compute_stream) + + if self.chunk_size == -1: + shard_ranks = dist.get_world_size(param_to_state[id( + params[0])].process_group) + chunk_size = shard_ranks * DEFAULT_CHUNK_SIZE_RATIO + elif self.chunk_size > 0: + chunk_size = self.chunk_size + else: + raise ValueError("chunk_size must be -1 or a positive integer.") + + # Wait grad update + 
self.comm_stream.wait_stream(torch.cuda.current_stream()) + + warmup_step = self.warmup_step + for i in range(0, warmup_step): + enqueue_all2all_gather(i * chunk_size, chunk_size) + enqueue_computes(i * chunk_size, chunk_size) + + for i in range(0, len(params) + chunk_size - 1, chunk_size): + enqueue_all2all_scatter(i, chunk_size) + enqueue_all2all_gather(i + warmup_step * chunk_size, chunk_size) + enqueue_update_param(i, chunk_size) + enqueue_computes(i + warmup_step * chunk_size, chunk_size) + + # Wait the last update_param to finish + torch.cuda.current_stream().wait_stream(self.compute_stream) + + @staticmethod + def _fused_adamw( + params: list[torch.Tensor], + grads: list[torch.Tensor], + exp_avgs: list[torch.Tensor], + exp_avg_sqs: list[torch.Tensor], + max_exp_avg_sqs: list[torch.Tensor], + state_steps: list[torch.Tensor], + amsgrad: bool, + beta1: float, + beta2: float, + lr: float | torch.Tensor, + weight_decay: float, + eps: float, + maximize: bool, + ) -> None: + if not params: + return + + # We only shuffle around the lr when it is a Tensor and on CUDA, otherwise, we prefer + # treating it as a scalar. 
+ lr_dict: DeviceDict | None = ({ + lr.device: lr + } if isinstance(lr, torch.Tensor) and str(lr.device) != "cpu" else + None) + grouped_tensors = torch.optim.Optimizer._group_tensors_by_device_and_dtype( + [ + params, grads, exp_avgs, exp_avg_sqs, max_exp_avg_sqs, + state_steps + ] # type: ignore[list-item] + ) + for (device, _), ( + ( + device_params_, + device_grads_, + device_exp_avgs_, + device_exp_avg_sqs_, + device_max_exp_avg_sqs, + device_state_steps_, + ), + _, + ) in grouped_tensors.items(): + device_params = cast(list[torch.Tensor], device_params_) + device_grads = cast(list[torch.Tensor], device_grads_) + device_exp_avgs = cast(list[torch.Tensor], device_exp_avgs_) + device_exp_avg_sqs = cast(list[torch.Tensor], device_exp_avg_sqs_) + device_state_steps = cast(list[torch.Tensor], device_state_steps_) + + if lr_dict is not None and device not in lr_dict: + lr_dict[device] = lr.to( + device=device, + non_blocking=True) # type: ignore[union-attr] + lr = lr_dict[device] + torch._foreach_add_(device_state_steps, 1) + func = torch._fused_adamw_ + func( + device_params, + device_grads, + device_exp_avgs, + device_exp_avg_sqs, + device_max_exp_avg_sqs, # type: ignore[arg-type] + device_state_steps, + amsgrad=amsgrad, + lr=lr, # type: ignore[arg-type] + beta1=beta1, + beta2=beta2, + weight_decay=weight_decay, + eps=eps, + maximize=maximize, + ) + + def _step_muon(self, group, qk_logits=None): + params = group["params"] + lr = group["lr"] + weight_decay = group["weight_decay"] + momentum = group["momentum"] + names = group["names"] + + param_dtensors = [] + name_dtensors = [] + + param_tensors = [] + name_tensors = [] + + param_dtensors_small = [] + name_dtensors_small = [] + + if self.use_distributed_muon: + self.distributed_muon(names=names, + params=params, + group=group, + lr=lr, + weight_decay=weight_decay, + momentum=momentum, + qk_logits=qk_logits) + return + + # For simplicity, we use distributed Muon for small parameters + # whose number of elements is 
below a threshold. + for n, p in zip(names, params): + if p is None or p.grad is None: + continue + if isinstance(p.data, DTensor): + if all( + isinstance(placement, Replicate) + for placement in p.placements): + param_tensors.append(p) + name_tensors.append(n) + elif p.data.numel() <= self.small_param_numel_threshold: + param_dtensors_small.append(p) + name_dtensors_small.append(n) + else: + param_dtensors.append(p) + name_dtensors.append(n) + elif isinstance(p.data, torch.Tensor): + param_tensors.append(p) + name_tensors.append(n) + else: + raise TypeError(f"Unsupported parameter type: {type(p.data)}") + + logger.debug( + f"[Muon] {len(param_dtensors)} DTensors, {len(param_tensors)} Tensors, " + f"{len(param_dtensors_small)} Small DTensors") + + def group_dtensors(dtensors, names): + # To support different placements, we group parameters by placements + # and run parallel Muon on each group. + + placement_to_params = defaultdict(lambda: ([], [])) + # type: dict[tuple[Placement, DeviceMesh], tuple[list[str], list[DTensor]]] + + assert len(dtensors) == len(names) + for p, n in zip(dtensors, names): + placement_to_params[tuple([p.placements, + p.device_mesh])][0].append(n) + placement_to_params[tuple([p.placements, + p.device_mesh])][1].append(p) + return placement_to_params + + if len(param_dtensors_small) > 0: + if not dist.is_initialized(): + raise RuntimeError( + "Parallel Muon requires torch.distributed to be initialized." + ) + + self.distributed_muon( + params=param_dtensors_small, + names=name_dtensors_small, + group=group, + lr=lr, + weight_decay=weight_decay, + momentum=momentum, + qk_logits=qk_logits, + ) + + if len(param_dtensors) > 0: + if not dist.is_initialized(): + raise RuntimeError( + "Parallel Muon requires torch.distributed to be initialized." 
+ ) + + dtensor_group = group_dtensors(param_dtensors, name_dtensors) + for _, (names, params) in dtensor_group.items(): + self.parallel( + names, + params, + group, + lr=lr, + weight_decay=weight_decay, + momentum=momentum, + qk_logits=qk_logits, + ) + + if len(param_tensors) > 0: + self.base( + name_tensors, + param_tensors, + group, + lr=lr, + weight_decay=weight_decay, + momentum=momentum, + qk_logits=qk_logits, + ) + + def _step_adamw_params(self, params, group): + params_with_grads = [] + grads = [] + moment1 = [] + moment2 = [] + max_exp_avg_sqs = [] + state_steps = [] + lr = group["lr"] + beta1, beta2 = group["adamw_betas"] + eps = group["adamw_eps"] + weight_decay = group["weight_decay"] + + for p in params: + g = p.grad + if g is None: + continue + state = self.state[p] + params_with_grads.append(p) + grads.append(g) + if "step" not in state: + state["step"] = (torch.zeros((), + dtype=torch.float32, + device=p.device)) + state["moment1"] = torch.zeros_like(g) + state["moment2"] = torch.zeros_like(g) + moment1.append(state["moment1"]) + moment2.append(state["moment2"]) + if not isinstance(state["step"], torch.Tensor): + step_tensor = torch.tensor(state["step"], + dtype=torch.float32, + device=p.device) + else: + step_tensor = state["step"] + state_steps.append(step_tensor) + + self._fused_adamw( + params_with_grads, + grads, + moment1, + moment2, + max_exp_avg_sqs, + state_steps, + amsgrad=False, + beta1=beta1, + beta2=beta2, + lr=lr, + weight_decay=weight_decay, + eps=eps, + maximize=False, + ) + + def _step_adamw(self, group): + params = group["params"] + + # group params with it's type and placement + placement_to_params: dict[tuple[Placement | type, + DeviceMesh | None]] = defaultdict(list) + for p in params: + match p: + case DTensor(): + placement_to_params[tuple([p.placements, + p.device_mesh])].append(p) + case torch.Tensor(): + placement_to_params[tuple([torch.Tensor, None])].append(p) + + for params in placement_to_params.values(): + 
self._step_adamw_params(params, group) + + @torch.no_grad + def step(self, closure=None, qk_logits=None): + """Perform a single optimization step. + + Args: + closure (Callable, optional): A closure that reevaluates the model + and returns the loss. + qk_logits (dict[int, Tensor], optional): A dictionary mapping layer indices + to 1D tensors of shape (num_heads,), representing the maximum + QK logits across all tokens, computed as + (1 / sqrt(head_dim)) * (Q @ K^T). + """ + loss = None + if closure is not None: + with torch.enable_grad(): + loss = closure() + + for group in self.param_groups: + if group["use_muon"]: + self._step_muon(group, qk_logits=qk_logits) + else: + self._step_adamw(group) + + return loss diff --git a/build/torch28-cxx11-cu126-x86_64-linux/optimizer/__init__.py b/build/torch28-cxx11-cu126-x86_64-linux/optimizer/__init__.py index 239c7a65f8293e7d0df28f05fce645af56d628c0..03dbc1afe1cf156661a2b1b22003cd5f599a0309 100644 --- a/build/torch28-cxx11-cu126-x86_64-linux/optimizer/__init__.py +++ b/build/torch28-cxx11-cu126-x86_64-linux/optimizer/__init__.py @@ -1,5 +1,26 @@ -from .muon import Muon +import ctypes +import sys -__all__ = [ - "Muon", -] +import importlib +from pathlib import Path +from types import ModuleType + +def _import_from_path(file_path: Path) -> ModuleType: + # We cannot use the module name as-is, after adding it to `sys.modules`, + # it would also be used for other imports. So, we make a module name that + # depends on the path for it to be unique using the hex-encoded hash of + # the path. 
+ path_hash = "{:x}".format(ctypes.c_size_t(hash(file_path.absolute())).value) + module_name = path_hash + spec = importlib.util.spec_from_file_location(module_name, file_path) + if spec is None: + raise ImportError(f"Cannot load spec for {module_name} from {file_path}") + module = importlib.util.module_from_spec(spec) + if module is None: + raise ImportError(f"Cannot load module {module_name} from spec") + sys.modules[module_name] = module + spec.loader.exec_module(module) # type: ignore + return module + + +globals().update(vars(_import_from_path(Path(__file__).parent.parent / "__init__.py"))) diff --git a/build/torch28-cxx11-cu126-x86_64-linux/optimizer/_ops.py b/build/torch28-cxx11-cu126-x86_64-linux/optimizer/_ops.py deleted file mode 100644 index 7d598206add1bca142661a3df6c510e3d9575d54..0000000000000000000000000000000000000000 --- a/build/torch28-cxx11-cu126-x86_64-linux/optimizer/_ops.py +++ /dev/null @@ -1,9 +0,0 @@ -import torch -from . import _optimizer_23d68bb_dirty -ops = torch.ops._optimizer_23d68bb_dirty - -def add_op_namespace_prefix(op_name: str): - """ - Prefix op by namespace. 
- """ - return f"_optimizer_23d68bb_dirty::{op_name}" \ No newline at end of file diff --git a/build/torch28-cxx11-cu126-x86_64-linux/optimizer/_optimizer_23d68bb_dirty.abi3.so b/build/torch28-cxx11-cu126-x86_64-linux/optimizer/_optimizer_23d68bb_dirty.abi3.so deleted file mode 100755 index 62346e32d4dc69c4cefb083f0c788f6564fb142c..0000000000000000000000000000000000000000 --- a/build/torch28-cxx11-cu126-x86_64-linux/optimizer/_optimizer_23d68bb_dirty.abi3.so +++ /dev/null @@ -1,3 +0,0 @@ -version https://git-lfs.github.com/spec/v1 -oid sha256:35708a107d9ac807fa3e63bbacfc6234fd7622a689a79eae3e43fce11f85d3da -size 1924376 diff --git a/build/torch28-cxx11-cu128-x86_64-linux/__init__.py b/build/torch28-cxx11-cu128-x86_64-linux/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..239c7a65f8293e7d0df28f05fce645af56d628c0 --- /dev/null +++ b/build/torch28-cxx11-cu128-x86_64-linux/__init__.py @@ -0,0 +1,5 @@ +from .muon import Muon + +__all__ = [ + "Muon", +] diff --git a/build/torch28-cxx11-cu128-x86_64-linux/_ops.py b/build/torch28-cxx11-cu128-x86_64-linux/_ops.py new file mode 100644 index 0000000000000000000000000000000000000000..e6f6fcf6280e969b1761926112147d3146e27b59 --- /dev/null +++ b/build/torch28-cxx11-cu128-x86_64-linux/_ops.py @@ -0,0 +1,9 @@ +import torch +from . import _optimizer_06a260a_dirty +ops = torch.ops._optimizer_06a260a_dirty + +def add_op_namespace_prefix(op_name: str): + """ + Prefix op by namespace. 
+ """ + return f"_optimizer_06a260a_dirty::{op_name}" \ No newline at end of file diff --git a/build/torch28-cxx11-cu128-x86_64-linux/_optimizer_06a260a_dirty.abi3.so b/build/torch28-cxx11-cu128-x86_64-linux/_optimizer_06a260a_dirty.abi3.so new file mode 100755 index 0000000000000000000000000000000000000000..1cf60567b59ce1b343c5a44301e443953b674f78 --- /dev/null +++ b/build/torch28-cxx11-cu128-x86_64-linux/_optimizer_06a260a_dirty.abi3.so @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:119adc22cd57de6d6d78c1f5094310b57083050f40836a5455bdb6c35bed104b +size 1999872 diff --git a/build/torch28-cxx11-cu128-x86_64-linux/distributed/utils.py b/build/torch28-cxx11-cu128-x86_64-linux/distributed/utils.py new file mode 100644 index 0000000000000000000000000000000000000000..6d5843506c13d9d31603b2b4e30c1c91d0baab28 --- /dev/null +++ b/build/torch28-cxx11-cu128-x86_64-linux/distributed/utils.py @@ -0,0 +1,175 @@ +import torch +import torch.distributed as dist +from torch.distributed import ProcessGroup +from torch.distributed.device_mesh import DeviceMesh +from torch.distributed.tensor import DTensor +from torch.distributed.tensor.placement_types import (Placement, Shard, + _StridedShard) + + +def get_slices_of_dtensor( + target: DTensor | torch.Tensor, + local_rank: int, + shard_mesh: DeviceMesh, + shard_placements: tuple[Placement], +) -> tuple[slice]: + """ + Get the slice of local tensor for a given rank from a tensor. + Args: + target (DTensor | torch.Tensor): The target tensor. + rank (int): The local rank of the shard group. + shard_mesh (DeviceMesh): The shard mesh. It consists of global ranks. + shard_placements (tuple[Placement]): The shard placements. 
+ """ + + slices: list[slice] = [slice(0, dim_size) for dim_size in target.size()] + + # find the global rank of the local rank in the shard mesh + rank = sorted(shard_mesh.mesh.flatten().tolist())[local_rank] + + rank_coords = (shard_mesh.mesh == rank).nonzero() + + assert len(rank_coords) == 1 + rank_coords = tuple(rank_coords[0].tolist()) + + assert len(rank_coords) == len(shard_placements) + + # Caution: Assuming replicate-to-shard of the shard mesh goes with + # left-to-right sharding. This is ensured by the sorting logic of + # construct_shard_mesh function. + for i, (rank_coord, + placement) in enumerate(zip(rank_coords, shard_placements)): + assert isinstance(placement, Shard) + + num_ranks = shard_mesh.mesh.shape[i] + + dim = placement.dim + dim_size = (slices[dim].stop - slices[dim].start) + + if dim_size % num_ranks != 0: + raise NotImplementedError( + f"Dimension size {dim_size} is not divisible " + f"by number of ranks {num_ranks} for shard " + f"placement on dim {dim}. (shape: {target.shape})") + + shard_size = dim_size // num_ranks + + start = slices[dim].start + rank_coord * shard_size + end = start + shard_size + + assert start < end <= slices[dim].stop + + slices[dim] = slice(start, end) + + return tuple(slices) + + +_ranks_to_dist_cache: dict[tuple[int, ...], tuple[DeviceMesh, + ProcessGroup]] = dict() + + +def construct_shard_mesh( + placements: tuple[Placement], + mesh: DeviceMesh, +) -> (DeviceMesh, ProcessGroup, tuple[Placement]): + """ + Construct Shard Mesh and Placements for unsharding. + It removes Replicate placements and constructs a new Mesh and ProcessGroup. + """ + my_rank = dist.get_rank() + + assert mesh.mesh.device.type == 'cpu' + + # Copy mesh to avoid modifying the original mesh + mesh = mesh.mesh.clone() + + # 1. Sort placements. Replicate first, then Shard by dim ascending. + + # For Shard, strided shard comes after regular shard on the same dim + # to preserve left-to-right order of replicate-to-shard. 
+ # This is because that strided shard is using stride to represent + # more fine-grained sharding on the same dim. + # Please check the URL below for _StridedShard. + # https://github.com/pytorch/pytorch/blob/v2.8.0/torch/distributed/tensor/placement_types.py#L366 + + def placement_sort_key( + placement_with_index: tuple[float, Placement] + ) -> tuple[int, float, int]: # (dim, split factor, original index) + index, placement = placement_with_index + is_replicate = placement.is_replicate() + is_shard = placement.is_shard() + is_partial = placement.is_partial() + + assert is_replicate or is_shard, f"Unsupported placement type: {type(placement)}" + assert not is_partial, "Partial placement is not supported." + + if is_replicate: + return (-1.0, 0, index) + elif is_shard: + if isinstance(placement, _StridedShard): + return (placement.dim, 1 / placement.split_factor, index) + return (placement.dim, 0, index) + else: + raise TypeError(f"Unknown placement type: {type(placement)}") + + placements_with_index: list[tuple[int, + Placement]] = list(enumerate(placements)) + placements_with_index = sorted(placements_with_index, + key=placement_sort_key) + + sorted_indices, sorted_placements = zip(*placements_with_index) + + # 2. Permute mesh according to sorted placements. + sorted_mesh = mesh.permute(sorted_indices) + + # 3. 
Collect list of shard meshes by removing replicate dims + # For example, (2, 3, 4, 4) with placements [R, R, S(0), S(1)] + # shard_meshes should be list with 2 * 3 = 6 shard meshes of shape (4, 4) + num_replicates = sum(1 for p in sorted_placements if p.is_replicate()) + + # merge replicate dims + # shard_meshes became a list of shard meshes with a length of replicate degree + if num_replicates > 0: + sorted_mesh = sorted_mesh.flatten( + 0, num_replicates - 1) if num_replicates > 1 else sorted_mesh + shard_meshes = list(torch.unbind(sorted_mesh, dim=0)) + else: + shard_meshes = [sorted_mesh] + shard_placements = sorted_placements[num_replicates:] + + # assume all shard placements are different + assert len(shard_placements) == len(set(shard_placements)) + + # 4. Construct ProcessGroups + # Caution: all groups should be created in the same order in all processes, + # even though each process only needs its own group. + + # To use tensor as dict key, convert it to tuple + def tensor_to_tuple(t): + if isinstance(t, torch.Tensor): + t = t.tolist() + if isinstance(t, list): + return tuple(tensor_to_tuple(x) for x in t) + return t + + my_shard_mesh_as_tuple = None + for shard_mesh in shard_meshes: + assert isinstance(shard_mesh, torch.Tensor) + shard_mesh_as_tuple = tensor_to_tuple(shard_mesh) + + if (my_rank == shard_mesh).any().item(): + assert my_shard_mesh_as_tuple is None + my_shard_mesh_as_tuple = shard_mesh_as_tuple + + # update global cache + if shard_mesh_as_tuple not in _ranks_to_dist_cache: + shard_process_group = dist.new_group(shard_mesh.flatten().tolist()) + _ranks_to_dist_cache[shard_mesh_as_tuple] = ( + DeviceMesh(device_type="cuda", mesh=shard_mesh), + shard_process_group, + ) + + my_shard_mesh, my_shard_process_group = _ranks_to_dist_cache[ + my_shard_mesh_as_tuple] + + return my_shard_mesh, my_shard_process_group, shard_placements diff --git a/build/torch29-cxx11-cu128-x86_64-linux/optimizer/matmul_transpose_triton.py 
b/build/torch28-cxx11-cu128-x86_64-linux/matmul_transpose_triton.py similarity index 100% rename from build/torch29-cxx11-cu128-x86_64-linux/optimizer/matmul_transpose_triton.py rename to build/torch28-cxx11-cu128-x86_64-linux/matmul_transpose_triton.py diff --git a/build/torch28-cxx11-cu128-x86_64-linux/metadata.json b/build/torch28-cxx11-cu128-x86_64-linux/metadata.json new file mode 100644 index 0000000000000000000000000000000000000000..76bafa5f33b6818aa6bb4cab04be811b87519b44 --- /dev/null +++ b/build/torch28-cxx11-cu128-x86_64-linux/metadata.json @@ -0,0 +1 @@ +{"python-depends":[]} \ No newline at end of file diff --git a/build/torch28-cxx11-cu128-x86_64-linux/muon.py b/build/torch28-cxx11-cu128-x86_64-linux/muon.py new file mode 100644 index 0000000000000000000000000000000000000000..dbf25575f185ff379789482068e4ecf55b9455a9 --- /dev/null +++ b/build/torch28-cxx11-cu128-x86_64-linux/muon.py @@ -0,0 +1,1268 @@ +import logging +import math +import types +from collections import defaultdict +from dataclasses import dataclass +from typing import Any, cast + +import torch +import torch.distributed as dist +from torch.distributed import ProcessGroup +from torch.distributed.device_mesh import DeviceMesh +from torch.distributed.tensor import DTensor, Replicate +from torch.distributed.tensor.placement_types import Placement + +from .distributed.utils import construct_shard_mesh, get_slices_of_dtensor +from .matmul_transpose_triton import matmul_transpose_assign + +logger = logging.getLogger(__name__) + +COMM_DTYPE = torch.bfloat16 +DEFAULT_CHUNK_SIZE_RATIO = 4 + + +# This code snippet is a modified version adapted from the following GitHub repositories: +# https://github.com/KellerJordan/Muon/blob/master/muon.py +# Muon's Newton–Schulz iteration causes high variance in singular values +# Idea: give each iteration its own 3 coefficients and optimize them via gradient descent. 
+@torch.no_grad() +# matmul_transpose_assign from : https://github.com/nil0x9/flash-muon +def _zeropower_via_newtonschulz5(G, steps): + """ + Newton-Schulz iteration to compute the zeroth power / orthogonalization of G. We opt to use a + quintic iteration whose coefficients are selected to maximize the slope at zero. For the purpose + of minimizing steps, it turns out to be empirically effective to keep increasing the slope at + zero even beyond the point where the iteration no longer converges all the way to one everywhere + on the interval. This iteration therefore does not produce UV^T but rather something like US'V^T + where S' is diagonal with S_{ii}' ~ Uniform(0.5, 1.5), which turns out not to hurt model + performance at all relative to UV^T, where USV^T = G is the SVD. + """ + assert len(G.shape) == 2 + assert G.dtype == COMM_DTYPE + X = G # no manual typecast + + if G.size(0) > G.size(1): + X = X.T + # Ensure spectral norm is at most 1 + X = X / (X.norm() + 1e-7) + buf1 = torch.empty(X.size(0), X.size(0), dtype=X.dtype, device=X.device) + buf2 = torch.empty(X.size(0), X.size(0), dtype=X.dtype, device=X.device) + # Perform the NS iterations + for a, b, c in [ + (4.0848, -6.8946, 2.9270), + (3.9505, -6.3029, 2.6377), + (3.7418, -5.5913, 2.3037), + (2.8769, -3.1427, 1.2046), + (2.8366, -3.0525, 1.2012), + ]: + matmul_transpose_assign(X, buf1) + matmul_transpose_assign(buf1, buf2) + buf1.mul_(b).add_(buf2, alpha=c) + X = torch.addmm(X, buf1, X, alpha=1.0, beta=a) + + if G.size(0) > G.size(1): + X = X.T + return X + + +@dataclass +class _muon_state: + # TODO: use Optional + worker_rank: int + process_group: ProcessGroup + shard_mesh: DeviceMesh + shard_placements: tuple[Placement, ...] 
+ name: str + qk_clip_state: torch.Tensor | None = None + gathered_grad: torch.Tensor | None = None + scattered_u: DTensor | None = None + computed_u: torch.Tensor | None = None + gather_event: torch.cuda.Event | None = None + compute_event: torch.cuda.Event | None = None + scatter_event: torch.cuda.Event | None = None + + +def numel_for_rank( + param: DTensor, + local_rank: int, + state: _muon_state, +) -> int: + slices = get_slices_of_dtensor( + param, + local_rank, + state.shard_mesh, + state.shard_placements, + ) + + numel = 1 + for s, dim in zip(slices, param.shape): + start, stop, step = s.indices(dim) + length = max(0, (stop - start + (step - 1)) // step) + numel *= length + + return numel + + +@torch.no_grad() +def _alloc_gathered_grad(params, param_to_state, rank, compute_stream): + """ + Pre-allocate gathered_grad buffer on compute_stream + before launching all2all gather + """ + with torch.cuda.stream(compute_stream): + for p in params: + state = param_to_state[id(p)] + if rank == state.worker_rank: + state.gathered_grad = torch.empty(p.shape, + dtype=COMM_DTYPE, + device="cuda") + else: + state.gathered_grad = None + + alloc_event = torch.cuda.Event() + alloc_event.record(compute_stream) + return alloc_event + + +@torch.no_grad() +def _all2all_gather(params, param_to_state, rank, comm_stream, none_grad, + alloc_event): + """ + All2all gathers shards so each owner rank reconstructs its full gradient + """ + with torch.cuda.stream(comm_stream): + process_group = param_to_state[id(params[0])].process_group + num_ranks = dist.get_world_size(group=process_group) + + # Construct sending buffers + per_dst = [[] for _ in range(num_ranks)] + send_counts = [0] * num_ranks + + for p in params: + state = param_to_state[id(p)] + dst = state.worker_rank + assert dst < num_ranks + shard_elems = numel_for_rank(p, rank, state) + g = p.grad + g = g.to_local().to(COMM_DTYPE).contiguous() + assert g.numel() == shard_elems + per_dst[dst].append(g.view(-1)) + 
send_counts[dst] += shard_elems + + assert any( + len(v) > 0 for v in per_dst + ), "At least one destination rank must receive a sharded tensor" + # list[list[Tensor]] -> list[Tensor] + per_dst = [t for dst in per_dst for t in dst] + + send_buf = torch.cat(per_dst, dim=0) + + owned_params = [ + p for p in params if param_to_state[id(p)].worker_rank == rank + ] + + # Compute receive sizes and allocate receiving buffers + recv_counts = [0] * num_ranks + + for src in range(num_ranks): + total = 0 + for p in owned_params: + state = param_to_state[id(p)] + assert state.worker_rank == rank + total += numel_for_rank(p, src, state) + recv_counts[src] = total + + recv_total = sum(recv_counts) + recv_buf = torch.empty(recv_total, dtype=COMM_DTYPE, device="cuda") + + #All2All + logger.debug(f"send_buf size: {send_buf.numel()}, " + f"recv_buf size: {recv_buf.numel()}, " + f"recv_counts: {recv_counts}, " + f"send_counts: {send_counts}, " + f"process_group: {str(process_group)}") + dist.all_to_all_single( + recv_buf, + send_buf, + output_split_sizes=recv_counts, + input_split_sizes=send_counts, + group=process_group, + ) + + # Reconstructs gathered grad from the received buffer + # + # recv_buf (num ranks = 3) + # + # From rank 0 From rank 1 From rank 2 + # | p1_0, p2_0, p3_0 | p1_1, p2_1, p3_1 | p1_2, p2_2, p3_2 | + # + # Outer loop: + # rank 0 -> rank 1 -> rank2 + # + # Inner loop: + # p1_n -> p2_n -> p3_n + + comm_stream.wait_event(alloc_event) + + off = 0 + for src in range(num_ranks): + if recv_counts[src] == 0: + continue + + block = recv_counts[src] + inner_off = 0 + for p in owned_params: + state = param_to_state[id(p)] + assert state.worker_rank == rank + + # get the slice of the full dtensor corresponding to rank src. 
+ slices = get_slices_of_dtensor(state.gathered_grad, src, + state.shard_mesh, + state.shard_placements) + + dst = state.gathered_grad[slices] + assert dst._base is state.gathered_grad + + n = dst.numel() + assert n > 0 + + sg = recv_buf.narrow(0, off + inner_off, n) + sg = sg.reshape_as(dst) + dst.copy_(sg) + + inner_off += n + off += block + + for p in params: + state = param_to_state[id(p)] + if state.worker_rank == rank: + state.gather_event = torch.cuda.Event() + state.gather_event.record(comm_stream) + else: + state.gathered_grad = None + state.gather_event = None + if none_grad: + p.grad = None + + +@torch.no_grad() +def _compute_u(p, state, steps, rank, compute_stream): + """ + On worker_rank, compute the orthogonalized update using Newton-Schulz iteration. + """ + with torch.cuda.stream(compute_stream): + if rank == state.worker_rank: + if state.gather_event is None: + raise RuntimeError("Gather event must be set before compute.") + compute_stream.wait_event(state.gather_event) + u = _zeropower_via_newtonschulz5(state.gathered_grad, steps) + state.gathered_grad = None + state.computed_u = u + state.compute_event = torch.cuda.Event() + state.compute_event.record() + else: + state.computed_u = None + state.compute_event = None + + +@torch.no_grad() +def _alloc_scattered_u(params, param_to_state, rank, compute_stream): + """ + Pre-allocate scattered_u buffer on compute_stream + before launching all2all gather + """ + with torch.cuda.stream(compute_stream): + for p in params: + state = param_to_state[id(p)] + state.scattered_u = torch.empty_like(p.to_local(), + dtype=COMM_DTYPE) + + alloc_event = torch.cuda.Event() + alloc_event.record(compute_stream) + return alloc_event + + +def _all2all_scatter(params, param_to_state, rank, comm_stream, alloc_event): + """ + All2all scatters full gradients to all ranks + """ + with torch.cuda.stream(comm_stream): + process_group = param_to_state[id(params[0])].process_group + num_ranks = 
dist.get_world_size(group=process_group) + owned_params = [ + p for p in params if param_to_state[id(p)].worker_rank == rank + ] + + # Construct sending buffer + per_dst = [[] for _ in range(num_ranks)] + send_counts = [0] * num_ranks + + if owned_params: + for p in owned_params: + state = param_to_state[id(p)] + if state.compute_event is None: + raise RuntimeError( + "Compute event must be set before scatter.") + comm_stream.wait_event(state.compute_event) + state.gathered_grad = None + + assert state.computed_u is not None + + u_full = state.computed_u.to(COMM_DTYPE).contiguous() + + offset = 0 + for dst in range(num_ranks): + # get the slice of the full tensor corresponding to rank dst. + slices = get_slices_of_dtensor(u_full, dst, + state.shard_mesh, + state.shard_placements) + su = u_full[slices].flatten() + + n = su.numel() + assert n > 0 + + per_dst[dst].append(su) + send_counts[dst] += n + offset += n + + assert offset == u_full.numel() + + lengths = [len(v) for v in per_dst] + if all(l > 0 for l in lengths): + assert all( + l == lengths[0] for l in lengths + ), "All destination ranks must have the same number of sharded tensor" + # list[list[Tensor]] -> list[Tensor] + per_dst = [t for dst in per_dst for t in dst] + send_buf = torch.cat(per_dst, dim=0) + else: + # all_to_all requires participation from all ranks + # Even non-owner ranks must join the collective call + send_buf = torch.empty(0, dtype=COMM_DTYPE, device="cuda") + + # Compute receive sizes and allocate receiving buffers + recv_counts = [0] * num_ranks + + for src in range(num_ranks): + total = 0 + for p in params: + state = param_to_state[id(p)] + if state.worker_rank != src: + continue + total += numel_for_rank(p, rank, state) + recv_counts[src] = total + + recv_total = sum(recv_counts) + assert recv_total > 0 + recv_buf = torch.empty(recv_total, dtype=COMM_DTYPE, device="cuda") + + #All2All + dist.all_to_all_single( + recv_buf, + send_buf, + output_split_sizes=recv_counts, + 
input_split_sizes=send_counts, + group=process_group, + ) + + # Copy to pre-allocated scattered_u buffer from the received buffer + # + # recv_buf (num ranks = 3, local_rank = 0) + # + # From rank 0 From rank 1 From rank 2 + # | p1_0, p2_0, p3_0 | p4_0 | p5_0, p6_0 | + # + # Outer loop: + # rank 0 -> rank 1 -> rank2 + # + # Inner loop: + # src(0) : p1_0 -> p2_0 -> p3_0 + # src(1) : p4_0 + # src(2) : p5_0 -> p6_0 + + comm_stream.wait_event(alloc_event) + + off = 0 + for src in range(num_ranks): + block = recv_counts[src] + if block == 0: + continue + + inner_off = 0 + for p in params: + state = param_to_state[id(p)] + if state.worker_rank != src: + continue + n = numel_for_rank(p, rank, state) + assert n > 0 + + flat_local = recv_buf.narrow(0, off + inner_off, + n).view_as(p.to_local()) + state.scattered_u.copy_(flat_local) + + state.scatter_event = torch.cuda.Event() + state.scatter_event.record(comm_stream) + inner_off += n + + assert inner_off == block + off += block + + +def _update_param(p, state, lr, adjusted_lr, weight_decay, rank, + compute_stream): + """ + Update sharded parameter p with the scattered_u. + Only worker_rank frees computed_u. 
+ """ + with torch.cuda.stream(compute_stream): + if state.scatter_event is None: + raise RuntimeError("Scatter event must be set before update") + compute_stream.wait_event(state.scatter_event) + u_dtensor = DTensor.from_local( + state.scattered_u, + placements=p.placements, + device_mesh=p.device_mesh, + ) + + state.scattered_u = u_dtensor + + if rank == state.worker_rank: + # Free computed_u + state.computed_u = None + + Muon._update_p(p, state.scattered_u, lr, adjusted_lr, weight_decay) + state.scattered_u = None + u_dtensor = None + + scales_full = Muon._compute_scales( + p, + state.qk_clip_state) if state.qk_clip_state is not None else None + if scales_full is not None: + # Have to slice scales_full among dim 0 + weight_slices = get_slices_of_dtensor(p, rank, state.shard_mesh, + state.shard_placements) + ratio = p.shape[0] // scales_full.shape[0] + scales_slice = slice( + None if weight_slices[0].start is None else + weight_slices[0].start // ratio, + None if weight_slices[0].stop is None else + weight_slices[0].stop // ratio, + None, + ) + + scales_local = scales_full[scales_slice] + scales_local = DTensor.from_local( + scales_local, + placements=p.placements, + device_mesh=p.device_mesh, + ) + Muon._qk_clip(p, scales_local, state.qk_clip_state.head_dim) + + +def default_is_muon(name, x): + skip_keys = ["embed_tokens", "lm_head", "tok_embeddings", "output"] + return x.ndim >= 2 and not any(key in name for key in skip_keys) + + +def get_default_muon_param_groups(model, is_muon_func=default_is_muon): + muon_params, muon_names = [], [] + non_muon_params = [] + + for n, p in model.named_parameters(): + if not p.requires_grad: + continue + if is_muon_func(n, p): + muon_params.append(p) + muon_names.append(n) + else: + non_muon_params.append(p) + + return [ + { + "params": muon_params, + "names": muon_names, + "use_muon": True, + }, + { + "params": non_muon_params, + "use_muon": False, + }, + ] + + +def parse_qk_layer(name: str) -> tuple[str | None, int]: + """ + 
Parse a parameter name to check if it is a query/key projection layer + ('wq', 'wk', 'q_proj', 'k_proj') and return (kind, layer_index). + + Returns: + (kind, layer_idx) or (None, -1) if not matched. + + Example: + 'model.3.attn.wq.weight' -> ('wq', 3) + 'model.5.attn.wk.weight' -> ('wk', 5) + 'model.2.attn.q_proj.weight' -> ('q_proj', 2) + 'model.7.attn.k_proj.weight' -> ('k_proj', 7) + 'model.4.attn.v_proj.weight' -> (None, -1) + """ + parts = name.split('.') + if len(parts) < 3: + return None, -1 + + kind = parts[-2] + + layer_idx = -1 + for part in reversed(parts): + if part.isdigit(): + layer_idx = int(part) + break + + if kind in ('wq', 'wk', 'q_proj', 'k_proj'): + return kind, layer_idx + + return None, -1 + + +@dataclass +class QKClipInfo: + """Per-parameter dynamic info computed from config + runtime logits.""" + kind: str | None # 'wq'/'q_proj' or 'wk'/'k_proj' or None + indices: list[int] # which heads to consider for clipping + head_dim: int # from config + threshold: float # from config + logit: torch.Tensor | None + + +class Muon(torch.optim.Optimizer): + """ + Muon - MomentUm Orthogonalized by Newton-schulz + + Muon internally runs standard SGD-momentum, and then performs an orthogonalization post- + processing step, in which each 2D parameter's update is replaced with the nearest orthogonal + matrix. To efficiently orthogonalize each update, we use a Newton-Schulz iteration, which has + the advantage that it can be stably run in bfloat16 on the GPU. + + Some warnings: + - We believe this optimizer is unlikely to work well for training with small batch size. + - We believe it may not work well for finetuning pretrained models, but we haven't tested this. + + Arguments: + model: The model to be optimized by Muon. + is_muon_func: A function that takes a parameter and its name, and returns whether the parameter should be optimized by Muon. + lr: The learning rate. The updates will have spectral norm of `lr`. 
(0.02 is a good default)
+        momentum: The momentum used by the internal SGD. (0.95 is a good default)
+        nesterov: Whether to use Nesterov-style momentum in the internal SGD. (recommended)
+        ns_steps: The number of Newton-Schulz iterations to run. (6 is probably always enough)
+        weight_decay: The weight decay for Muon and AdamW.
+        Parameters that are {0, 1}-D or are detected as being the embed or lm_head will be optimized by AdamW as well.
+        adamw_lr: The learning rate for the internal AdamW.
+        adamw_betas: The betas for the internal AdamW.
+        adamw_eps: The epsilon for the internal AdamW.
+        none_grad: Whether to set p.grad to None after gathering the gradients. This can save memory.
+        debug: Whether to print debug information.
+        clip_config : Configuration for QK clipping. Expected keys:
+            - "q_indices" (list[int]): Indices of query heads to consider.
+            - "k_indices" (list[int]): Indices of key heads to consider.
+            - "head_dim" (int): Dimensionality of each attention head.
+            - "threshold" (float): Threshold value; heads whose QK logits exceed
+              this value will be scaled down.
+            Default is:
+                {
+                    "q_indices": [],
+                    "k_indices": [],
+                    "head_dim": 128,
+                    "threshold": 100
+                }
+        warmup_step : How many all2all gather, compute operations are launched in advance
+            before the corresponding all2all scatter steps begin.
+            A higher warmup_step increases memory usage but can improve
+            performance by overlapping communication.
+            Parallel muon only.
+        chunk_size : Batch size of parameters to process in each
+            all2all gather/compute/scatter step.
+            Use shard ranks * DEFAULT_CHUNK_SIZE_RATIO when -1 is specified.
+        use_distributed_muon: Use distributed muon by Liu et al. (2024).
+            For testing purposes only. 
+ small_param_numel_threshold: Threshold for classifying parameters as small and falling back to distributed Muon + """ + + def __init__(self, + params, + lr=1e-3, + momentum=0.95, + nesterov=True, + ns_steps=5, + weight_decay=0.1, + adamw_betas=(0.9, 0.95), + adamw_eps=1e-8, + none_grad=True, + debug=False, + clip_config={ + "q_indices": [], + "k_indices": [], + "head_dim": 128, + "threshold": 100 + }, + warmup_step=5, + chunk_size=-1, + use_distributed_muon=False, + small_param_numel_threshold=65536): + defaults = dict( + lr=lr, + weight_decay=weight_decay, + momentum=momentum, + nesterov=nesterov, + ns_steps=ns_steps, + adamw_betas=adamw_betas, + adamw_eps=adamw_eps, + none_grad=none_grad, + use_muon=True, + ) + error_message = "The key 'use_muon' is not set in parameter group {idx}. Assuming all parameters in the group will use muon optimization, which may lead to unexpected behavior." + instruction_code = "\n\n please follow this code snippet \n```optimizer = get_kernel('motif-technologies/optimizer')\n\n\nparams = optimizer.muon.get_default_muon_param_groups(model)\n\noptim = optimizer.Muon(params, ...)```" + + if isinstance(params, types.GeneratorType): + raise ValueError(error_message.format(idx=0) + instruction_code) + for _idx, param_group in enumerate(params): + if param_group.get("use_muon", None) is None: + raise ValueError( + error_message.format(idx=_idx) + instruction_code) + + super().__init__(params, defaults) + + self.rank = None + + self.comm_stream = torch.cuda.Stream() + self.compute_stream = torch.cuda.Stream() + self.debug = debug + self.clip_config = clip_config + self.warmup_step = warmup_step + self.chunk_size = chunk_size + self.use_distributed_muon = use_distributed_muon + self.small_param_numel_threshold = small_param_numel_threshold + + def _calc_flops(self, G, steps): + assert len(G.shape) == 2 + M, N = G.shape + if M > N: + M, N = N, M + + return steps * ((M**3) * 2 + (M**2 * N) * 4 + M * N * 2 + M**2 * 3) + + def 
adjust_lr_for_muon(self, lr, param_shape):
+        A, B = param_shape[:2]
+        # We adjust the learning rate and weight decay based on the size of the parameter matrix
+        # as described in the paper
+        adjusted_ratio = 0.2 * math.sqrt(max(A, B))
+        adjusted_lr = lr * adjusted_ratio
+        return adjusted_lr
+
+    def set_rank_once(self, rank):
+        if self.rank is None:
+            self.rank = rank
+        else:
+            assert self.rank == rank
+
+    def get_shard_mesh(self, p):
+        """
+        Get the shard mesh for a parameter p on the given rank.
+        """
+        assert isinstance(
+            p, DTensor), "Parallel Muon only supports DTensor parameters."
+
+        shard_mesh, shard_pg, shard_placements = construct_shard_mesh(
+            p.placements, p.device_mesh)
+
+        # set rank with the local rank in the shard process group
+        self.set_rank_once(dist.get_rank(group=shard_pg))
+
+        return shard_mesh, shard_pg, shard_placements
+
+    def init_state_and_assign_params(self, names, params, group, qk_logits):
+        param_to_state = {}
+        param_to_flops = {}
+
+        total_flops = 0
+        for p in params:
+            g = p.grad
+            if g is None:
+                continue
+            assert g.ndim == 2, "Muon only supports 2D parameters." 
+ + flops = self._calc_flops(g, group["ns_steps"]) + param_to_flops[id(p)] = flops + total_flops += flops + + if self.debug: + print(f"Total TFLOPs for Muon: {total_flops / 1e12:.2f} TFLOPs", + flush=True) + + paired = list(zip(names, params)) + + paired_sorted = sorted(paired, + key=lambda x: param_to_flops[id(x[1])], + reverse=True) + + names_sorted, params_sorted = zip(*paired_sorted) + ordered_names = list(names_sorted) + ordered_params = list(params_sorted) + + round_robin = 0 + mesh = ordered_params[0].device_mesh + placements = ordered_params[0].placements + + shard_mesh, shard_pg, shard_placements = self.get_shard_mesh( + ordered_params[0]) + shard_mesh_flattened = shard_mesh.mesh.flatten() + num_ranks = dist.get_world_size(group=shard_pg) + + for n, p in zip(ordered_names, ordered_params): + if mesh != p.device_mesh: + raise ValueError("All parameters must be on the same mesh.") + if placements != p.placements: + raise ValueError("All parameters must have same placements.") + + worker_rank = shard_mesh_flattened[round_robin].item() % num_ranks + round_robin = (round_robin + 1) % len(shard_mesh_flattened) + qk_clip_state = self.get_qk_clip_info(n, qk_logits) + + param_to_state[id(p)] = _muon_state( + worker_rank=worker_rank, + process_group=shard_pg, + shard_mesh=shard_mesh, + shard_placements=shard_placements, + name=n, + qk_clip_state=qk_clip_state, + ) + + return param_to_state, ordered_params + + def base(self, names, params, group, lr, weight_decay, momentum, + qk_logits): + # generate weight updates in distributed fashion + for n, p in zip(names, params): + g = p.grad + if g is None: + continue + if g.ndim > 2: + g = g.view(g.size(0), -1) + assert g is not None + + g = self._update_g(p, g, group, momentum) + + u = _zeropower_via_newtonschulz5(g.to(COMM_DTYPE), + steps=group["ns_steps"]) + + adjusted_lr = self.adjust_lr_for_muon(lr, p.shape) + Muon._update_p(p, u, lr, adjusted_lr, weight_decay) + + qk_clip_state = self.get_qk_clip_info(n, qk_logits) + 
+ scales_full = self._compute_scales( + p, qk_clip_state) if qk_clip_state is not None else None + if scales_full is not None: + Muon._qk_clip(p, scales_full, qk_clip_state.head_dim) + + def distributed_muon( + self, + names: list[str], + params: list[torch.nn.Parameter], + group: dict[str, Any], + lr: float, + weight_decay: float, + momentum: float, + qk_logits: list[torch.Tensor | DTensor] | None, + ): + """ Implementation of Distributed Muon by Liu et al. """ + + for n, p in zip(names, params): + g = p.grad + if g is None: + continue + if g.ndim > 2: + g = g.view(g.size(0), -1) + assert g is not None + + g = self._update_g(p, g, group, momentum) + + # Gather G + if isinstance(p.data, DTensor): + g_full = g.full_tensor() + p_full = p.data.full_tensor() + else: + g_full = g + p_full = p + + u_full = _zeropower_via_newtonschulz5(g_full.to(COMM_DTYPE), + steps=group["ns_steps"]) + + adjusted_lr = self.adjust_lr_for_muon(lr, p_full.shape) + Muon._update_p(p_full, u_full, lr, adjusted_lr, weight_decay) + + qk_clip_state = self.get_qk_clip_info(n, qk_logits) + + scales_full = self._compute_scales( + p_full, qk_clip_state) if qk_clip_state is not None else None + + if scales_full is not None: + Muon._qk_clip(p_full, scales_full, qk_clip_state.head_dim) + + if isinstance(p.data, DTensor): + ndims = len(p.device_mesh.mesh.shape) + p_replicate = DTensor.from_local( + p_full, + device_mesh=p.device_mesh, + placements=[Replicate() for _ in range(ndims)], + ) + + p_sharded = p_replicate.redistribute( + device_mesh=p.device_mesh, + placements=p.placements, + ) + + p.copy_(p_sharded) + + def _update_g(self, p, g, group, momentum): + # calc update + state = self.state[p] + buf = state.setdefault("momentum_buffer", torch.zeros_like(g)) + torch.add(g, buf, alpha=momentum, out=buf) + if group["nesterov"]: + g.add_(buf, alpha=momentum) + return g + return buf + + @staticmethod + def _update_p(p, u, lr, adjusted_lr, weight_decay): + if isinstance(p, torch.nn.Parameter): + # apply 
weight decay + p.data.mul_(1 - lr * weight_decay) + # apply update + p.data.add_(u, alpha=-adjusted_lr) + else: + p.mul_(1 - lr * weight_decay) + p.add_(u, alpha=-adjusted_lr) + + def get_qk_clip_info(self, n, qk_logits): + if self.clip_config is None: + return None + + head_dim = self.clip_config.get('head_dim') + threshold = self.clip_config.get('threshold') + kind, layer_idx = parse_qk_layer(n) + + logit, indices = None, [] + if qk_logits is not None and kind is not None: + logit = qk_logits[layer_idx] + indices_key = 'q_indices' if 'q' in kind else 'k_indices' + indices = self.clip_config.get(indices_key, []) or [] + + if isinstance(logit, DTensor): + # In TP settings, qk_logits may be DTensor + # We convert it to full tensor here for simplicity + logit = logit.full_tensor() + + return QKClipInfo( + kind=kind, + indices=indices, + head_dim=head_dim, + threshold=threshold, + logit=logit, + ) + + @staticmethod + def _compute_scales(p, qk_clip_state): + kind = qk_clip_state.kind + indices = qk_clip_state.indices + head_dim = qk_clip_state.head_dim + threshold = qk_clip_state.threshold + logit = qk_clip_state.logit + + H_global = p.shape[0] // head_dim + scales_full = torch.ones(H_global, device=p.data.device) + scaling = 0 + + for logit_idx, head_idx in enumerate(indices): + v_ele = float(logit[logit_idx]) + if v_ele > threshold: + new_scale = math.sqrt(threshold / v_ele) + if new_scale < scales_full[head_idx]: + scales_full[head_idx] = new_scale + logger.info( + f"[{kind}] Head {head_idx} exceeded threshold " + f"(value={v_ele:.4f}, threshold={threshold:.4f}) -> applying scale={new_scale:.4f}" + ) + scaling += 1 + + return scales_full if scaling > 0 else None + + @staticmethod + def _qk_clip(p, scales, head_dim): + if isinstance(p, torch.nn.Parameter): + W = p.data.view(-1, head_dim, p.data.shape[1]) + W.mul_(scales.view(-1, 1, 1)) + else: + W = p.view(-1, head_dim, p.shape[1]) + W.mul_(scales.view(-1, 1, 1)) + + def parallel(self, names, params, group, lr, 
weight_decay, momentum, + qk_logits): + """ + Perform a parallel optimization step using Muon. + """ + + for p in params: + g = p.grad + if g is None: + continue + if g.ndim > 2: + g = g.view(g.size(0), -1) + + # Update g in the local rank + g = self._update_g( + p, + g, + group, + momentum=momentum, + ) + p.grad = g + + param_to_state, ordered_params = self.init_state_and_assign_params( + names, params, group, qk_logits) + + assert self.rank is not None + + def enqueue_all2all_gather(start_idx, chunk_size): + target_params = ordered_params[start_idx:start_idx + chunk_size] + if target_params: + alloc_event = _alloc_gathered_grad(target_params, + param_to_state, self.rank, + self.compute_stream) + _all2all_gather(target_params, param_to_state, self.rank, + self.comm_stream, group["none_grad"], + alloc_event) + + def enqueue_computes(start_idx, chunk_size): + for p in ordered_params[start_idx:start_idx + chunk_size]: + state = param_to_state[id(p)] + _compute_u(p, state, group["ns_steps"], self.rank, + self.compute_stream) + + def enqueue_all2all_scatter(start_idx, chunk_size): + target_params = ordered_params[start_idx:start_idx + chunk_size] + if target_params: + alloc_event = _alloc_scattered_u(target_params, param_to_state, + self.rank, + self.compute_stream) + _all2all_scatter(target_params, param_to_state, self.rank, + self.comm_stream, alloc_event) + + def enqueue_update_param(start_idx, chunk_size): + for p in ordered_params[start_idx:start_idx + chunk_size]: + state = param_to_state[id(p)] + adjusted_lr = self.adjust_lr_for_muon(lr, p.shape) + _update_param(p, state, lr, adjusted_lr, weight_decay, + self.rank, self.compute_stream) + + if self.chunk_size == -1: + shard_ranks = dist.get_world_size(param_to_state[id( + params[0])].process_group) + chunk_size = shard_ranks * DEFAULT_CHUNK_SIZE_RATIO + elif self.chunk_size > 0: + chunk_size = self.chunk_size + else: + raise ValueError("chunk_size must be -1 or a positive integer.") + + # Wait grad update + 
self.comm_stream.wait_stream(torch.cuda.current_stream()) + + warmup_step = self.warmup_step + for i in range(0, warmup_step): + enqueue_all2all_gather(i * chunk_size, chunk_size) + enqueue_computes(i * chunk_size, chunk_size) + + for i in range(0, len(params) + chunk_size - 1, chunk_size): + enqueue_all2all_scatter(i, chunk_size) + enqueue_all2all_gather(i + warmup_step * chunk_size, chunk_size) + enqueue_update_param(i, chunk_size) + enqueue_computes(i + warmup_step * chunk_size, chunk_size) + + # Wait the last update_param to finish + torch.cuda.current_stream().wait_stream(self.compute_stream) + + @staticmethod + def _fused_adamw( + params: list[torch.Tensor], + grads: list[torch.Tensor], + exp_avgs: list[torch.Tensor], + exp_avg_sqs: list[torch.Tensor], + max_exp_avg_sqs: list[torch.Tensor], + state_steps: list[torch.Tensor], + amsgrad: bool, + beta1: float, + beta2: float, + lr: float | torch.Tensor, + weight_decay: float, + eps: float, + maximize: bool, + ) -> None: + if not params: + return + + # We only shuffle around the lr when it is a Tensor and on CUDA, otherwise, we prefer + # treating it as a scalar. 
+ lr_dict: DeviceDict | None = ({ + lr.device: lr + } if isinstance(lr, torch.Tensor) and str(lr.device) != "cpu" else + None) + grouped_tensors = torch.optim.Optimizer._group_tensors_by_device_and_dtype( + [ + params, grads, exp_avgs, exp_avg_sqs, max_exp_avg_sqs, + state_steps + ] # type: ignore[list-item] + ) + for (device, _), ( + ( + device_params_, + device_grads_, + device_exp_avgs_, + device_exp_avg_sqs_, + device_max_exp_avg_sqs, + device_state_steps_, + ), + _, + ) in grouped_tensors.items(): + device_params = cast(list[torch.Tensor], device_params_) + device_grads = cast(list[torch.Tensor], device_grads_) + device_exp_avgs = cast(list[torch.Tensor], device_exp_avgs_) + device_exp_avg_sqs = cast(list[torch.Tensor], device_exp_avg_sqs_) + device_state_steps = cast(list[torch.Tensor], device_state_steps_) + + if lr_dict is not None and device not in lr_dict: + lr_dict[device] = lr.to( + device=device, + non_blocking=True) # type: ignore[union-attr] + lr = lr_dict[device] + torch._foreach_add_(device_state_steps, 1) + func = torch._fused_adamw_ + func( + device_params, + device_grads, + device_exp_avgs, + device_exp_avg_sqs, + device_max_exp_avg_sqs, # type: ignore[arg-type] + device_state_steps, + amsgrad=amsgrad, + lr=lr, # type: ignore[arg-type] + beta1=beta1, + beta2=beta2, + weight_decay=weight_decay, + eps=eps, + maximize=maximize, + ) + + def _step_muon(self, group, qk_logits=None): + params = group["params"] + lr = group["lr"] + weight_decay = group["weight_decay"] + momentum = group["momentum"] + names = group["names"] + + param_dtensors = [] + name_dtensors = [] + + param_tensors = [] + name_tensors = [] + + param_dtensors_small = [] + name_dtensors_small = [] + + if self.use_distributed_muon: + self.distributed_muon(names=names, + params=params, + group=group, + lr=lr, + weight_decay=weight_decay, + momentum=momentum, + qk_logits=qk_logits) + return + + # For simplicity, we use distributed Muon for small parameters + # whose number of elements is 
below a threshold. + for n, p in zip(names, params): + if p is None or p.grad is None: + continue + if isinstance(p.data, DTensor): + if all( + isinstance(placement, Replicate) + for placement in p.placements): + param_tensors.append(p) + name_tensors.append(n) + elif p.data.numel() <= self.small_param_numel_threshold: + param_dtensors_small.append(p) + name_dtensors_small.append(n) + else: + param_dtensors.append(p) + name_dtensors.append(n) + elif isinstance(p.data, torch.Tensor): + param_tensors.append(p) + name_tensors.append(n) + else: + raise TypeError(f"Unsupported parameter type: {type(p.data)}") + + logger.debug( + f"[Muon] {len(param_dtensors)} DTensors, {len(param_tensors)} Tensors, " + f"{len(param_dtensors_small)} Small DTensors") + + def group_dtensors(dtensors, names): + # To support different placements, we group parameters by placements + # and run parallel Muon on each group. + + placement_to_params = defaultdict(lambda: ([], [])) + # type: dict[tuple[Placement, DeviceMesh], tuple[list[str], list[DTensor]]] + + assert len(dtensors) == len(names) + for p, n in zip(dtensors, names): + placement_to_params[tuple([p.placements, + p.device_mesh])][0].append(n) + placement_to_params[tuple([p.placements, + p.device_mesh])][1].append(p) + return placement_to_params + + if len(param_dtensors_small) > 0: + if not dist.is_initialized(): + raise RuntimeError( + "Parallel Muon requires torch.distributed to be initialized." + ) + + self.distributed_muon( + params=param_dtensors_small, + names=name_dtensors_small, + group=group, + lr=lr, + weight_decay=weight_decay, + momentum=momentum, + qk_logits=qk_logits, + ) + + if len(param_dtensors) > 0: + if not dist.is_initialized(): + raise RuntimeError( + "Parallel Muon requires torch.distributed to be initialized." 
+ ) + + dtensor_group = group_dtensors(param_dtensors, name_dtensors) + for _, (names, params) in dtensor_group.items(): + self.parallel( + names, + params, + group, + lr=lr, + weight_decay=weight_decay, + momentum=momentum, + qk_logits=qk_logits, + ) + + if len(param_tensors) > 0: + self.base( + name_tensors, + param_tensors, + group, + lr=lr, + weight_decay=weight_decay, + momentum=momentum, + qk_logits=qk_logits, + ) + + def _step_adamw_params(self, params, group): + params_with_grads = [] + grads = [] + moment1 = [] + moment2 = [] + max_exp_avg_sqs = [] + state_steps = [] + lr = group["lr"] + beta1, beta2 = group["adamw_betas"] + eps = group["adamw_eps"] + weight_decay = group["weight_decay"] + + for p in params: + g = p.grad + if g is None: + continue + state = self.state[p] + params_with_grads.append(p) + grads.append(g) + if "step" not in state: + state["step"] = (torch.zeros((), + dtype=torch.float32, + device=p.device)) + state["moment1"] = torch.zeros_like(g) + state["moment2"] = torch.zeros_like(g) + moment1.append(state["moment1"]) + moment2.append(state["moment2"]) + if not isinstance(state["step"], torch.Tensor): + step_tensor = torch.tensor(state["step"], + dtype=torch.float32, + device=p.device) + else: + step_tensor = state["step"] + state_steps.append(step_tensor) + + self._fused_adamw( + params_with_grads, + grads, + moment1, + moment2, + max_exp_avg_sqs, + state_steps, + amsgrad=False, + beta1=beta1, + beta2=beta2, + lr=lr, + weight_decay=weight_decay, + eps=eps, + maximize=False, + ) + + def _step_adamw(self, group): + params = group["params"] + + # group params with it's type and placement + placement_to_params: dict[tuple[Placement | type, + DeviceMesh | None]] = defaultdict(list) + for p in params: + match p: + case DTensor(): + placement_to_params[tuple([p.placements, + p.device_mesh])].append(p) + case torch.Tensor(): + placement_to_params[tuple([torch.Tensor, None])].append(p) + + for params in placement_to_params.values(): + 
self._step_adamw_params(params, group) + + @torch.no_grad + def step(self, closure=None, qk_logits=None): + """Perform a single optimization step. + + Args: + closure (Callable, optional): A closure that reevaluates the model + and returns the loss. + qk_logits (dict[int, Tensor], optional): A dictionary mapping layer indices + to 1D tensors of shape (num_heads,), representing the maximum + QK logits across all tokens, computed as + (1 / sqrt(head_dim)) * (Q @ K^T). + """ + loss = None + if closure is not None: + with torch.enable_grad(): + loss = closure() + + for group in self.param_groups: + if group["use_muon"]: + self._step_muon(group, qk_logits=qk_logits) + else: + self._step_adamw(group) + + return loss diff --git a/build/torch28-cxx11-cu128-x86_64-linux/optimizer/__init__.py b/build/torch28-cxx11-cu128-x86_64-linux/optimizer/__init__.py index 239c7a65f8293e7d0df28f05fce645af56d628c0..03dbc1afe1cf156661a2b1b22003cd5f599a0309 100644 --- a/build/torch28-cxx11-cu128-x86_64-linux/optimizer/__init__.py +++ b/build/torch28-cxx11-cu128-x86_64-linux/optimizer/__init__.py @@ -1,5 +1,26 @@ -from .muon import Muon +import ctypes +import sys -__all__ = [ - "Muon", -] +import importlib +from pathlib import Path +from types import ModuleType + +def _import_from_path(file_path: Path) -> ModuleType: + # We cannot use the module name as-is, after adding it to `sys.modules`, + # it would also be used for other imports. So, we make a module name that + # depends on the path for it to be unique using the hex-encoded hash of + # the path. 
+ path_hash = "{:x}".format(ctypes.c_size_t(hash(file_path.absolute())).value) + module_name = path_hash + spec = importlib.util.spec_from_file_location(module_name, file_path) + if spec is None: + raise ImportError(f"Cannot load spec for {module_name} from {file_path}") + module = importlib.util.module_from_spec(spec) + if module is None: + raise ImportError(f"Cannot load module {module_name} from spec") + sys.modules[module_name] = module + spec.loader.exec_module(module) # type: ignore + return module + + +globals().update(vars(_import_from_path(Path(__file__).parent.parent / "__init__.py"))) diff --git a/build/torch28-cxx11-cu128-x86_64-linux/optimizer/_ops.py b/build/torch28-cxx11-cu128-x86_64-linux/optimizer/_ops.py deleted file mode 100644 index 7d598206add1bca142661a3df6c510e3d9575d54..0000000000000000000000000000000000000000 --- a/build/torch28-cxx11-cu128-x86_64-linux/optimizer/_ops.py +++ /dev/null @@ -1,9 +0,0 @@ -import torch -from . import _optimizer_23d68bb_dirty -ops = torch.ops._optimizer_23d68bb_dirty - -def add_op_namespace_prefix(op_name: str): - """ - Prefix op by namespace. 
- """ - return f"_optimizer_23d68bb_dirty::{op_name}" \ No newline at end of file diff --git a/build/torch28-cxx11-cu128-x86_64-linux/optimizer/_optimizer_23d68bb_dirty.abi3.so b/build/torch28-cxx11-cu128-x86_64-linux/optimizer/_optimizer_23d68bb_dirty.abi3.so deleted file mode 100755 index d31f69b06fba65c78b497ee3f83cdb2b894170b2..0000000000000000000000000000000000000000 --- a/build/torch28-cxx11-cu128-x86_64-linux/optimizer/_optimizer_23d68bb_dirty.abi3.so +++ /dev/null @@ -1,3 +0,0 @@ -version https://git-lfs.github.com/spec/v1 -oid sha256:03c3bbbbc5c4ceb5cebfe3a2e411f155bebb390f1921c14d59fcf791dd556da1 -size 1983488 diff --git a/build/torch28-cxx11-cu129-x86_64-linux/__init__.py b/build/torch28-cxx11-cu129-x86_64-linux/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..239c7a65f8293e7d0df28f05fce645af56d628c0 --- /dev/null +++ b/build/torch28-cxx11-cu129-x86_64-linux/__init__.py @@ -0,0 +1,5 @@ +from .muon import Muon + +__all__ = [ + "Muon", +] diff --git a/build/torch28-cxx11-cu129-x86_64-linux/_ops.py b/build/torch28-cxx11-cu129-x86_64-linux/_ops.py new file mode 100644 index 0000000000000000000000000000000000000000..e6f6fcf6280e969b1761926112147d3146e27b59 --- /dev/null +++ b/build/torch28-cxx11-cu129-x86_64-linux/_ops.py @@ -0,0 +1,9 @@ +import torch +from . import _optimizer_06a260a_dirty +ops = torch.ops._optimizer_06a260a_dirty + +def add_op_namespace_prefix(op_name: str): + """ + Prefix op by namespace. 
+ """ + return f"_optimizer_06a260a_dirty::{op_name}" \ No newline at end of file diff --git a/build/torch28-cxx11-cu129-x86_64-linux/_optimizer_06a260a_dirty.abi3.so b/build/torch28-cxx11-cu129-x86_64-linux/_optimizer_06a260a_dirty.abi3.so new file mode 100755 index 0000000000000000000000000000000000000000..e996c45edb033c93ec8a41716764cdcbbd04593d --- /dev/null +++ b/build/torch28-cxx11-cu129-x86_64-linux/_optimizer_06a260a_dirty.abi3.so @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:7e8463be5f48aba32d645183945d258cdb532b238ef40665db396b459367cad1 +size 1999872 diff --git a/build/torch28-cxx11-cu129-x86_64-linux/distributed/utils.py b/build/torch28-cxx11-cu129-x86_64-linux/distributed/utils.py new file mode 100644 index 0000000000000000000000000000000000000000..6d5843506c13d9d31603b2b4e30c1c91d0baab28 --- /dev/null +++ b/build/torch28-cxx11-cu129-x86_64-linux/distributed/utils.py @@ -0,0 +1,175 @@ +import torch +import torch.distributed as dist +from torch.distributed import ProcessGroup +from torch.distributed.device_mesh import DeviceMesh +from torch.distributed.tensor import DTensor +from torch.distributed.tensor.placement_types import (Placement, Shard, + _StridedShard) + + +def get_slices_of_dtensor( + target: DTensor | torch.Tensor, + local_rank: int, + shard_mesh: DeviceMesh, + shard_placements: tuple[Placement], +) -> tuple[slice]: + """ + Get the slice of local tensor for a given rank from a tensor. + Args: + target (DTensor | torch.Tensor): The target tensor. + rank (int): The local rank of the shard group. + shard_mesh (DeviceMesh): The shard mesh. It consists of global ranks. + shard_placements (tuple[Placement]): The shard placements. 
+ """ + + slices: list[slice] = [slice(0, dim_size) for dim_size in target.size()] + + # find the global rank of the local rank in the shard mesh + rank = sorted(shard_mesh.mesh.flatten().tolist())[local_rank] + + rank_coords = (shard_mesh.mesh == rank).nonzero() + + assert len(rank_coords) == 1 + rank_coords = tuple(rank_coords[0].tolist()) + + assert len(rank_coords) == len(shard_placements) + + # Caution: Assuming replicate-to-shard of the shard mesh goes with + # left-to-right sharding. This is ensured by the sorting logic of + # construct_shard_mesh function. + for i, (rank_coord, + placement) in enumerate(zip(rank_coords, shard_placements)): + assert isinstance(placement, Shard) + + num_ranks = shard_mesh.mesh.shape[i] + + dim = placement.dim + dim_size = (slices[dim].stop - slices[dim].start) + + if dim_size % num_ranks != 0: + raise NotImplementedError( + f"Dimension size {dim_size} is not divisible " + f"by number of ranks {num_ranks} for shard " + f"placement on dim {dim}. (shape: {target.shape})") + + shard_size = dim_size // num_ranks + + start = slices[dim].start + rank_coord * shard_size + end = start + shard_size + + assert start < end <= slices[dim].stop + + slices[dim] = slice(start, end) + + return tuple(slices) + + +_ranks_to_dist_cache: dict[tuple[int, ...], tuple[DeviceMesh, + ProcessGroup]] = dict() + + +def construct_shard_mesh( + placements: tuple[Placement], + mesh: DeviceMesh, +) -> (DeviceMesh, ProcessGroup, tuple[Placement]): + """ + Construct Shard Mesh and Placements for unsharding. + It removes Replicate placements and constructs a new Mesh and ProcessGroup. + """ + my_rank = dist.get_rank() + + assert mesh.mesh.device.type == 'cpu' + + # Copy mesh to avoid modifying the original mesh + mesh = mesh.mesh.clone() + + # 1. Sort placements. Replicate first, then Shard by dim ascending. + + # For Shard, strided shard comes after regular shard on the same dim + # to preserve left-to-right order of replicate-to-shard. 
+ # This is because that strided shard is using stride to represent + # more fine-grained sharding on the same dim. + # Please check the URL below for _StridedShard. + # https://github.com/pytorch/pytorch/blob/v2.8.0/torch/distributed/tensor/placement_types.py#L366 + + def placement_sort_key( + placement_with_index: tuple[float, Placement] + ) -> tuple[int, float, int]: # (dim, split factor, original index) + index, placement = placement_with_index + is_replicate = placement.is_replicate() + is_shard = placement.is_shard() + is_partial = placement.is_partial() + + assert is_replicate or is_shard, f"Unsupported placement type: {type(placement)}" + assert not is_partial, "Partial placement is not supported." + + if is_replicate: + return (-1.0, 0, index) + elif is_shard: + if isinstance(placement, _StridedShard): + return (placement.dim, 1 / placement.split_factor, index) + return (placement.dim, 0, index) + else: + raise TypeError(f"Unknown placement type: {type(placement)}") + + placements_with_index: list[tuple[int, + Placement]] = list(enumerate(placements)) + placements_with_index = sorted(placements_with_index, + key=placement_sort_key) + + sorted_indices, sorted_placements = zip(*placements_with_index) + + # 2. Permute mesh according to sorted placements. + sorted_mesh = mesh.permute(sorted_indices) + + # 3. 
Collect list of shard meshes by removing replicate dims + # For example, (2, 3, 4, 4) with placements [R, R, S(0), S(1)] + # shard_meshes should be list with 2 * 3 = 6 shard meshes of shape (4, 4) + num_replicates = sum(1 for p in sorted_placements if p.is_replicate()) + + # merge replicate dims + # shard_meshes became a list of shard meshes with a length of replicate degree + if num_replicates > 0: + sorted_mesh = sorted_mesh.flatten( + 0, num_replicates - 1) if num_replicates > 1 else sorted_mesh + shard_meshes = list(torch.unbind(sorted_mesh, dim=0)) + else: + shard_meshes = [sorted_mesh] + shard_placements = sorted_placements[num_replicates:] + + # assume all shard placements are different + assert len(shard_placements) == len(set(shard_placements)) + + # 4. Construct ProcessGroups + # Caution: all groups should be created in the same order in all processes, + # even though each process only needs its own group. + + # To use tensor as dict key, convert it to tuple + def tensor_to_tuple(t): + if isinstance(t, torch.Tensor): + t = t.tolist() + if isinstance(t, list): + return tuple(tensor_to_tuple(x) for x in t) + return t + + my_shard_mesh_as_tuple = None + for shard_mesh in shard_meshes: + assert isinstance(shard_mesh, torch.Tensor) + shard_mesh_as_tuple = tensor_to_tuple(shard_mesh) + + if (my_rank == shard_mesh).any().item(): + assert my_shard_mesh_as_tuple is None + my_shard_mesh_as_tuple = shard_mesh_as_tuple + + # update global cache + if shard_mesh_as_tuple not in _ranks_to_dist_cache: + shard_process_group = dist.new_group(shard_mesh.flatten().tolist()) + _ranks_to_dist_cache[shard_mesh_as_tuple] = ( + DeviceMesh(device_type="cuda", mesh=shard_mesh), + shard_process_group, + ) + + my_shard_mesh, my_shard_process_group = _ranks_to_dist_cache[ + my_shard_mesh_as_tuple] + + return my_shard_mesh, my_shard_process_group, shard_placements diff --git a/build/torch29-cxx11-cu130-x86_64-linux/optimizer/matmul_transpose_triton.py 
b/build/torch28-cxx11-cu129-x86_64-linux/matmul_transpose_triton.py similarity index 100% rename from build/torch29-cxx11-cu130-x86_64-linux/optimizer/matmul_transpose_triton.py rename to build/torch28-cxx11-cu129-x86_64-linux/matmul_transpose_triton.py diff --git a/build/torch28-cxx11-cu129-x86_64-linux/metadata.json b/build/torch28-cxx11-cu129-x86_64-linux/metadata.json new file mode 100644 index 0000000000000000000000000000000000000000..76bafa5f33b6818aa6bb4cab04be811b87519b44 --- /dev/null +++ b/build/torch28-cxx11-cu129-x86_64-linux/metadata.json @@ -0,0 +1 @@ +{"python-depends":[]} \ No newline at end of file diff --git a/build/torch28-cxx11-cu129-x86_64-linux/muon.py b/build/torch28-cxx11-cu129-x86_64-linux/muon.py new file mode 100644 index 0000000000000000000000000000000000000000..dbf25575f185ff379789482068e4ecf55b9455a9 --- /dev/null +++ b/build/torch28-cxx11-cu129-x86_64-linux/muon.py @@ -0,0 +1,1268 @@ +import logging +import math +import types +from collections import defaultdict +from dataclasses import dataclass +from typing import Any, cast + +import torch +import torch.distributed as dist +from torch.distributed import ProcessGroup +from torch.distributed.device_mesh import DeviceMesh +from torch.distributed.tensor import DTensor, Replicate +from torch.distributed.tensor.placement_types import Placement + +from .distributed.utils import construct_shard_mesh, get_slices_of_dtensor +from .matmul_transpose_triton import matmul_transpose_assign + +logger = logging.getLogger(__name__) + +COMM_DTYPE = torch.bfloat16 +DEFAULT_CHUNK_SIZE_RATIO = 4 + + +# This code snippet is a modified version adapted from the following GitHub repositories: +# https://github.com/KellerJordan/Muon/blob/master/muon.py +# Muon's Newton–Schulz iteration causes high variance in singular values +# Idea: give each iteration its own 3 coefficients and optimize them via gradient descent. 
@torch.no_grad()
# matmul_transpose_assign from : https://github.com/nil0x9/flash-muon
def _zeropower_via_newtonschulz5(G, steps):
    """
    Orthogonalize ``G`` via a quintic Newton-Schulz iteration.

    The coefficients are tuned to maximize the slope at zero rather than to
    converge all singular values exactly to 1, so the result is roughly
    US'V^T with S'_{ii} ~ Uniform(0.5, 1.5) where USV^T = G — empirically as
    good as the true UV^T for optimizer updates, and stable in bfloat16.

    NOTE: `steps` is currently unused; the five hard-coded coefficient
    triples below fix the iteration count.
    """
    assert len(G.shape) == 2
    assert G.dtype == COMM_DTYPE

    # Work on the wide orientation (rows <= cols); transpose back at the end.
    transposed = G.size(0) > G.size(1)
    x = G.T if transposed else G

    # Bound the spectral norm by ~1 so the iteration converges.
    x = x / (x.norm() + 1e-7)

    n = x.size(0)
    tmp_a = torch.empty(n, n, dtype=x.dtype, device=x.device)
    tmp_b = torch.empty(n, n, dtype=x.dtype, device=x.device)

    coefficients = (
        (4.0848, -6.8946, 2.9270),
        (3.9505, -6.3029, 2.6377),
        (3.7418, -5.5913, 2.3037),
        (2.8769, -3.1427, 1.2046),
        (2.8366, -3.0525, 1.2012),
    )
    for a, b, c in coefficients:
        # assumes matmul_transpose_assign(src, out) writes src @ src.T into
        # out (flash-muon fused kernel) — TODO confirm against that repo
        matmul_transpose_assign(x, tmp_a)
        matmul_transpose_assign(tmp_a, tmp_b)
        tmp_a.mul_(b).add_(tmp_b, alpha=c)
        # x <- a*x + (b*A + c*A^2) @ x  with A = x @ x.T
        x = torch.addmm(x, tmp_a, x, alpha=1.0, beta=a)

    return x.T if transposed else x
def numel_for_rank(
    param: DTensor,
    local_rank: int,
    state: _muon_state,
) -> int:
    """Return the number of elements of ``param`` owned by ``local_rank``
    under the shard mesh/placements recorded in ``state``."""
    slices = get_slices_of_dtensor(
        param,
        local_rank,
        state.shard_mesh,
        state.shard_placements,
    )

    # Multiply the slice lengths along every dim to get the local numel.
    numel = 1
    for s, dim in zip(slices, param.shape):
        start, stop, step = s.indices(dim)
        # Ceiling division; max(0, ...) guards against empty slices.
        length = max(0, (stop - start + (step - 1)) // step)
        numel *= length

    return numel


@torch.no_grad()
def _alloc_gathered_grad(params, param_to_state, rank, compute_stream):
    """
    Pre-allocate gathered_grad buffer on compute_stream
    before launching all2all gather

    Only the owning worker rank of each param allocates a full-shape buffer;
    all other ranks hold None. Returns a CUDA event recorded on
    ``compute_stream`` so the comm stream can order itself after the
    allocations.
    """
    with torch.cuda.stream(compute_stream):
        for p in params:
            state = param_to_state[id(p)]
            if rank == state.worker_rank:
                state.gathered_grad = torch.empty(p.shape,
                                                  dtype=COMM_DTYPE,
                                                  device="cuda")
            else:
                state.gathered_grad = None

    alloc_event = torch.cuda.Event()
    alloc_event.record(compute_stream)
    return alloc_event
send_counts[dst] += shard_elems + + assert any( + len(v) > 0 for v in per_dst + ), "At least one destination rank must receive a sharded tensor" + # list[list[Tensor]] -> list[Tensor] + per_dst = [t for dst in per_dst for t in dst] + + send_buf = torch.cat(per_dst, dim=0) + + owned_params = [ + p for p in params if param_to_state[id(p)].worker_rank == rank + ] + + # Compute receive sizes and allocate receiving buffers + recv_counts = [0] * num_ranks + + for src in range(num_ranks): + total = 0 + for p in owned_params: + state = param_to_state[id(p)] + assert state.worker_rank == rank + total += numel_for_rank(p, src, state) + recv_counts[src] = total + + recv_total = sum(recv_counts) + recv_buf = torch.empty(recv_total, dtype=COMM_DTYPE, device="cuda") + + #All2All + logger.debug(f"send_buf size: {send_buf.numel()}, " + f"recv_buf size: {recv_buf.numel()}, " + f"recv_counts: {recv_counts}, " + f"send_counts: {send_counts}, " + f"process_group: {str(process_group)}") + dist.all_to_all_single( + recv_buf, + send_buf, + output_split_sizes=recv_counts, + input_split_sizes=send_counts, + group=process_group, + ) + + # Reconstructs gathered grad from the received buffer + # + # recv_buf (num ranks = 3) + # + # From rank 0 From rank 1 From rank 2 + # | p1_0, p2_0, p3_0 | p1_1, p2_1, p3_1 | p1_2, p2_2, p3_2 | + # + # Outer loop: + # rank 0 -> rank 1 -> rank2 + # + # Inner loop: + # p1_n -> p2_n -> p3_n + + comm_stream.wait_event(alloc_event) + + off = 0 + for src in range(num_ranks): + if recv_counts[src] == 0: + continue + + block = recv_counts[src] + inner_off = 0 + for p in owned_params: + state = param_to_state[id(p)] + assert state.worker_rank == rank + + # get the slice of the full dtensor corresponding to rank src. 
@torch.no_grad()
def _compute_u(p, state, steps, rank, compute_stream):
    """
    On worker_rank, compute the orthogonalized update using Newton-Schulz iteration.

    Non-owner ranks only clear their state. The result is published via
    ``state.computed_u`` and ordered by ``state.compute_event``.
    """
    with torch.cuda.stream(compute_stream):
        if rank == state.worker_rank:
            if state.gather_event is None:
                raise RuntimeError("Gather event must be set before compute.")
            # Don't start until the all2all gather has filled gathered_grad.
            compute_stream.wait_event(state.gather_event)
            u = _zeropower_via_newtonschulz5(state.gathered_grad, steps)
            # Free the gathered gradient as soon as it has been consumed.
            state.gathered_grad = None
            state.computed_u = u
            state.compute_event = torch.cuda.Event()
            # record() with no argument records on the current stream, which
            # is compute_stream inside this context manager.
            state.compute_event.record()
        else:
            state.computed_u = None
            state.compute_event = None


@torch.no_grad()
def _alloc_scattered_u(params, param_to_state, rank, compute_stream):
    """
    Pre-allocate scattered_u buffer on compute_stream
    before launching all2all gather

    Every rank allocates a buffer shaped like its local shard of each param.
    Returns a CUDA event recorded on ``compute_stream`` for the comm stream
    to wait on before writing into the buffers.
    """
    with torch.cuda.stream(compute_stream):
        for p in params:
            state = param_to_state[id(p)]
            state.scattered_u = torch.empty_like(p.to_local(),
                                                 dtype=COMM_DTYPE)

    alloc_event = torch.cuda.Event()
    alloc_event.record(compute_stream)
    return alloc_event
dist.get_world_size(group=process_group) + owned_params = [ + p for p in params if param_to_state[id(p)].worker_rank == rank + ] + + # Construct sending buffer + per_dst = [[] for _ in range(num_ranks)] + send_counts = [0] * num_ranks + + if owned_params: + for p in owned_params: + state = param_to_state[id(p)] + if state.compute_event is None: + raise RuntimeError( + "Compute event must be set before scatter.") + comm_stream.wait_event(state.compute_event) + state.gathered_grad = None + + assert state.computed_u is not None + + u_full = state.computed_u.to(COMM_DTYPE).contiguous() + + offset = 0 + for dst in range(num_ranks): + # get the slice of the full tensor corresponding to rank dst. + slices = get_slices_of_dtensor(u_full, dst, + state.shard_mesh, + state.shard_placements) + su = u_full[slices].flatten() + + n = su.numel() + assert n > 0 + + per_dst[dst].append(su) + send_counts[dst] += n + offset += n + + assert offset == u_full.numel() + + lengths = [len(v) for v in per_dst] + if all(l > 0 for l in lengths): + assert all( + l == lengths[0] for l in lengths + ), "All destination ranks must have the same number of sharded tensor" + # list[list[Tensor]] -> list[Tensor] + per_dst = [t for dst in per_dst for t in dst] + send_buf = torch.cat(per_dst, dim=0) + else: + # all_to_all requires participation from all ranks + # Even non-owner ranks must join the collective call + send_buf = torch.empty(0, dtype=COMM_DTYPE, device="cuda") + + # Compute receive sizes and allocate receiving buffers + recv_counts = [0] * num_ranks + + for src in range(num_ranks): + total = 0 + for p in params: + state = param_to_state[id(p)] + if state.worker_rank != src: + continue + total += numel_for_rank(p, rank, state) + recv_counts[src] = total + + recv_total = sum(recv_counts) + assert recv_total > 0 + recv_buf = torch.empty(recv_total, dtype=COMM_DTYPE, device="cuda") + + #All2All + dist.all_to_all_single( + recv_buf, + send_buf, + output_split_sizes=recv_counts, + 
input_split_sizes=send_counts, + group=process_group, + ) + + # Copy to pre-allocated scattered_u buffer from the received buffer + # + # recv_buf (num ranks = 3, local_rank = 0) + # + # From rank 0 From rank 1 From rank 2 + # | p1_0, p2_0, p3_0 | p4_0 | p5_0, p6_0 | + # + # Outer loop: + # rank 0 -> rank 1 -> rank2 + # + # Inner loop: + # src(0) : p1_0 -> p2_0 -> p3_0 + # src(1) : p4_0 + # src(2) : p5_0 -> p6_0 + + comm_stream.wait_event(alloc_event) + + off = 0 + for src in range(num_ranks): + block = recv_counts[src] + if block == 0: + continue + + inner_off = 0 + for p in params: + state = param_to_state[id(p)] + if state.worker_rank != src: + continue + n = numel_for_rank(p, rank, state) + assert n > 0 + + flat_local = recv_buf.narrow(0, off + inner_off, + n).view_as(p.to_local()) + state.scattered_u.copy_(flat_local) + + state.scatter_event = torch.cuda.Event() + state.scatter_event.record(comm_stream) + inner_off += n + + assert inner_off == block + off += block + + +def _update_param(p, state, lr, adjusted_lr, weight_decay, rank, + compute_stream): + """ + Update sharded parameter p with the scattered_u. + Only worker_rank frees computed_u. 
+ """ + with torch.cuda.stream(compute_stream): + if state.scatter_event is None: + raise RuntimeError("Scatter event must be set before update") + compute_stream.wait_event(state.scatter_event) + u_dtensor = DTensor.from_local( + state.scattered_u, + placements=p.placements, + device_mesh=p.device_mesh, + ) + + state.scattered_u = u_dtensor + + if rank == state.worker_rank: + # Free computed_u + state.computed_u = None + + Muon._update_p(p, state.scattered_u, lr, adjusted_lr, weight_decay) + state.scattered_u = None + u_dtensor = None + + scales_full = Muon._compute_scales( + p, + state.qk_clip_state) if state.qk_clip_state is not None else None + if scales_full is not None: + # Have to slice scales_full among dim 0 + weight_slices = get_slices_of_dtensor(p, rank, state.shard_mesh, + state.shard_placements) + ratio = p.shape[0] // scales_full.shape[0] + scales_slice = slice( + None if weight_slices[0].start is None else + weight_slices[0].start // ratio, + None if weight_slices[0].stop is None else + weight_slices[0].stop // ratio, + None, + ) + + scales_local = scales_full[scales_slice] + scales_local = DTensor.from_local( + scales_local, + placements=p.placements, + device_mesh=p.device_mesh, + ) + Muon._qk_clip(p, scales_local, state.qk_clip_state.head_dim) + + +def default_is_muon(name, x): + skip_keys = ["embed_tokens", "lm_head", "tok_embeddings", "output"] + return x.ndim >= 2 and not any(key in name for key in skip_keys) + + +def get_default_muon_param_groups(model, is_muon_func=default_is_muon): + muon_params, muon_names = [], [] + non_muon_params = [] + + for n, p in model.named_parameters(): + if not p.requires_grad: + continue + if is_muon_func(n, p): + muon_params.append(p) + muon_names.append(n) + else: + non_muon_params.append(p) + + return [ + { + "params": muon_params, + "names": muon_names, + "use_muon": True, + }, + { + "params": non_muon_params, + "use_muon": False, + }, + ] + + +def parse_qk_layer(name: str) -> tuple[str | None, int]: + """ + 
Parse a parameter name to check if it is a query/key projection layer + ('wq', 'wk', 'q_proj', 'k_proj') and return (kind, layer_index). + + Returns: + (kind, layer_idx) or (None, -1) if not matched. + + Example: + 'model.3.attn.wq.weight' -> ('wq', 3) + 'model.5.attn.wk.weight' -> ('wk', 5) + 'model.2.attn.q_proj.weight' -> ('q_proj', 2) + 'model.7.attn.k_proj.weight' -> ('k_proj', 7) + 'model.4.attn.v_proj.weight' -> (None, -1) + """ + parts = name.split('.') + if len(parts) < 3: + return None, -1 + + kind = parts[-2] + + layer_idx = -1 + for part in reversed(parts): + if part.isdigit(): + layer_idx = int(part) + break + + if kind in ('wq', 'wk', 'q_proj', 'k_proj'): + return kind, layer_idx + + return None, -1 + + +@dataclass +class QKClipInfo: + """Per-parameter dynamic info computed from config + runtime logits.""" + kind: str | None # 'wq'/'q_proj' or 'wk'/'k_proj' or None + indices: list[int] # which heads to consider for clipping + head_dim: int # from config + threshold: float # from config + logit: torch.Tensor | None + + +class Muon(torch.optim.Optimizer): + """ + Muon - MomentUm Orthogonalized by Newton-schulz + + Muon internally runs standard SGD-momentum, and then performs an orthogonalization post- + processing step, in which each 2D parameter's update is replaced with the nearest orthogonal + matrix. To efficiently orthogonalize each update, we use a Newton-Schulz iteration, which has + the advantage that it can be stably run in bfloat16 on the GPU. + + Some warnings: + - We believe this optimizer is unlikely to work well for training with small batch size. + - We believe it may not work well for finetuning pretrained models, but we haven't tested this. + + Arguments: + model: The model to be optimized by Muon. + is_muon_func: A function that takes a parameter and its name, and returns whether the parameter should be optimized by Muon. + lr: The learning rate. The updates will have spectral norm of `lr`. 
(0.02 is a good default) + momentum: The momentum used by the internal SGD. (0.95 is a good default) + nesterov: Whether to use Nesterov-style momentum in the internal SGD. (recommended) + ns_steps: The number of Newton-Schulz iterations to run. (6 is probably always enough) + weight_decay: The weight decay for Muon and AdamW. + {0, 1}-D or are detected as being the embed or lm_head will be optimized by AdamW as well. + adamw_lr: The learning rate for the internal AdamW. + adamw_betas: The betas for the internal AdamW. + adamw_eps: The epsilon for the internal AdamW. + none_grad: Whether to set p.grad to None after gathering the gradients. This can save memory. + debug: Whether to print debug information. + clip_info : Configuration for QK clipping. Expected keys: + - "q_indices" (list[int]): Indices of query heads to consider. + - "k_indices" (list[int]): Indices of key heads to consider. + - "head_dim" (int): Dimensionality of each attention head. + - "threshold" (float): Threshold value; heads whose QK logits exceed + this value will be scaled down. + Default is: + { + "q_indices": [], + "k_indices": [], + "head_dim": 128, + "threshold": 100 + } + warmup_step : How many all2all gather, compute operations are launched in advance + before the corresponding all2all scatter steps begin. + A higher warmup_step increases memory usage but can improve + performance by overlapping communication. + Parallel muon only. + chunk_size : Batch size of parameters to process in each + all2all gather/compute/scatter step. + Use shard ranks * DEFAULT_CHUNK_SIZE_RATIO when -1 is specified. + use_distributed_muon: Use distributed muon by Liu et al. (2024). + For testing purpose only. 
+ small_param_numel_threshold: Threshold for classifying parameters as small and falling back to distributed Muon + """ + + def __init__(self, + params, + lr=1e-3, + momentum=0.95, + nesterov=True, + ns_steps=5, + weight_decay=0.1, + adamw_betas=(0.9, 0.95), + adamw_eps=1e-8, + none_grad=True, + debug=False, + clip_config={ + "q_indices": [], + "k_indices": [], + "head_dim": 128, + "threshold": 100 + }, + warmup_step=5, + chunk_size=-1, + use_distributed_muon=False, + small_param_numel_threshold=65536): + defaults = dict( + lr=lr, + weight_decay=weight_decay, + momentum=momentum, + nesterov=nesterov, + ns_steps=ns_steps, + adamw_betas=adamw_betas, + adamw_eps=adamw_eps, + none_grad=none_grad, + use_muon=True, + ) + error_message = "The key 'use_muon' is not set in parameter group {idx}. Assuming all parameters in the group will use muon optimization, which may lead to unexpected behavior." + instruction_code = "\n\n please follow this code snippet \n```optimizer = get_kernel('motif-technologies/optimizer')\n\n\nparams = optimizer.muon.get_default_muon_param_groups(model)\n\noptim = optimizer.Muon(params, ...)```" + + if isinstance(params, types.GeneratorType): + raise ValueError(error_message.format(idx=0) + instruction_code) + for _idx, param_group in enumerate(params): + if param_group.get("use_muon", None) is None: + raise ValueError( + error_message.format(idx=_idx) + instruction_code) + + super().__init__(params, defaults) + + self.rank = None + + self.comm_stream = torch.cuda.Stream() + self.compute_stream = torch.cuda.Stream() + self.debug = debug + self.clip_config = clip_config + self.warmup_step = warmup_step + self.chunk_size = chunk_size + self.use_distributed_muon = use_distributed_muon + self.small_param_numel_threshold = small_param_numel_threshold + + def _calc_flops(self, G, steps): + assert len(G.shape) == 2 + M, N = G.shape + if M > N: + M, N = N, M + + return steps * ((M**3) * 2 + (M**2 * N) * 4 + M * N * 2 + M**2 * 3) + + def 
adjust_lr_for_muon(self, lr, param_shape): + A, B = param_shape[:2] + # We adjust the learning rate and weight decay based on the size of the parameter matrix + # as describted in the paper + adjusted_ratio = 0.2 * math.sqrt(max(A, B)) + adjusted_lr = lr * adjusted_ratio + return adjusted_lr + + def set_rank_once(self, rank): + if self.rank is None: + self.rank = rank + else: + assert self.rank == rank + + def get_shard_mesh(self, p): + """ + Get the shard mesh for a parameter p on the given rank. + """ + assert isinstance( + p, DTensor), "Parallel Muon only supports DTensor parameters." + + shard_mesh, shard_pg, shard_placements = construct_shard_mesh( + p.placements, p.device_mesh) + + # set rank with the local rank in the shard process group + self.set_rank_once(dist.get_rank(group=shard_pg)) + + return shard_mesh, shard_pg, shard_placements + + def init_state_and_assign_params(self, names, params, group, qk_logits): + param_to_state = {} + param_to_flops = {} + + total_flops = 0 + for p in params: + g = p.grad + if g is None: + continue + assert g.ndim == 2, "Muon only supports 2D parameters." 
    def base(self, names, params, group, lr, weight_decay, momentum,
             qk_logits):
        """Non-distributed Muon step: momentum update, Newton-Schulz
        orthogonalization, parameter update, and optional QK clipping,
        applied per parameter on the local rank."""
        # generate weight updates in distributed fashion
        for n, p in zip(names, params):
            g = p.grad
            if g is None:
                continue
            # Muon operates on matrices; collapse trailing dims of >2-D grads.
            if g.ndim > 2:
                g = g.view(g.size(0), -1)
            assert g is not None

            g = self._update_g(p, g, group, momentum)

            u = _zeropower_via_newtonschulz5(g.to(COMM_DTYPE),
                                             steps=group["ns_steps"])

            adjusted_lr = self.adjust_lr_for_muon(lr, p.shape)
            Muon._update_p(p, u, lr, adjusted_lr, weight_decay)

            qk_clip_state = self.get_qk_clip_info(n, qk_logits)

            scales_full = self._compute_scales(
                p, qk_clip_state) if qk_clip_state is not None else None
            if scales_full is not None:
                Muon._qk_clip(p, scales_full, qk_clip_state.head_dim)

    def distributed_muon(
        self,
        names: list[str],
        params: list[torch.nn.Parameter],
        group: dict[str, Any],
        lr: float,
        weight_decay: float,
        momentum: float,
        qk_logits: list[torch.Tensor | DTensor] | None,
    ):
        """ Implementation of Distributed Muon by Liu et al.

        Each DTensor parameter is fully gathered, updated as a plain tensor,
        then redistributed back to its original placements. Simple but
        memory-heavy; used as a fallback / for testing.
        """

        for n, p in zip(names, params):
            g = p.grad
            if g is None:
                continue
            if g.ndim > 2:
                g = g.view(g.size(0), -1)
            assert g is not None

            g = self._update_g(p, g, group, momentum)

            # Gather G
            if isinstance(p.data, DTensor):
                g_full = g.full_tensor()
                p_full = p.data.full_tensor()
            else:
                g_full = g
                p_full = p

            u_full = _zeropower_via_newtonschulz5(g_full.to(COMM_DTYPE),
                                                  steps=group["ns_steps"])

            adjusted_lr = self.adjust_lr_for_muon(lr, p_full.shape)
            Muon._update_p(p_full, u_full, lr, adjusted_lr, weight_decay)

            qk_clip_state = self.get_qk_clip_info(n, qk_logits)

            scales_full = self._compute_scales(
                p_full, qk_clip_state) if qk_clip_state is not None else None

            if scales_full is not None:
                Muon._qk_clip(p_full, scales_full, qk_clip_state.head_dim)

            if isinstance(p.data, DTensor):
                # Wrap the updated full tensor as fully-replicated, then
                # redistribute back to the parameter's original sharding.
                ndims = len(p.device_mesh.mesh.shape)
                p_replicate = DTensor.from_local(
                    p_full,
                    device_mesh=p.device_mesh,
                    placements=[Replicate() for _ in range(ndims)],
                )

                p_sharded = p_replicate.redistribute(
                    device_mesh=p.device_mesh,
                    placements=p.placements,
                )

                p.copy_(p_sharded)

    def _update_g(self, p, g, group, momentum):
        """SGD-momentum update of the gradient.

        Maintains ``momentum_buffer`` in optimizer state as
        buf = g + momentum * buf. With nesterov, mutates ``g`` in place
        (g += momentum * buf) and returns it; otherwise returns the buffer.
        NOTE(review): ``g`` may alias ``p.grad``, so the nesterov path also
        modifies the stored gradient — confirm callers rely on this.
        """
        # calc update
        state = self.state[p]
        buf = state.setdefault("momentum_buffer", torch.zeros_like(g))
        torch.add(g, buf, alpha=momentum, out=buf)
        if group["nesterov"]:
            g.add_(buf, alpha=momentum)
            return g
        return buf
weight decay + p.data.mul_(1 - lr * weight_decay) + # apply update + p.data.add_(u, alpha=-adjusted_lr) + else: + p.mul_(1 - lr * weight_decay) + p.add_(u, alpha=-adjusted_lr) + + def get_qk_clip_info(self, n, qk_logits): + if self.clip_config is None: + return None + + head_dim = self.clip_config.get('head_dim') + threshold = self.clip_config.get('threshold') + kind, layer_idx = parse_qk_layer(n) + + logit, indices = None, [] + if qk_logits is not None and kind is not None: + logit = qk_logits[layer_idx] + indices_key = 'q_indices' if 'q' in kind else 'k_indices' + indices = self.clip_config.get(indices_key, []) or [] + + if isinstance(logit, DTensor): + # In TP settings, qk_logits may be DTensor + # We convert it to full tensor here for simplicity + logit = logit.full_tensor() + + return QKClipInfo( + kind=kind, + indices=indices, + head_dim=head_dim, + threshold=threshold, + logit=logit, + ) + + @staticmethod + def _compute_scales(p, qk_clip_state): + kind = qk_clip_state.kind + indices = qk_clip_state.indices + head_dim = qk_clip_state.head_dim + threshold = qk_clip_state.threshold + logit = qk_clip_state.logit + + H_global = p.shape[0] // head_dim + scales_full = torch.ones(H_global, device=p.data.device) + scaling = 0 + + for logit_idx, head_idx in enumerate(indices): + v_ele = float(logit[logit_idx]) + if v_ele > threshold: + new_scale = math.sqrt(threshold / v_ele) + if new_scale < scales_full[head_idx]: + scales_full[head_idx] = new_scale + logger.info( + f"[{kind}] Head {head_idx} exceeded threshold " + f"(value={v_ele:.4f}, threshold={threshold:.4f}) -> applying scale={new_scale:.4f}" + ) + scaling += 1 + + return scales_full if scaling > 0 else None + + @staticmethod + def _qk_clip(p, scales, head_dim): + if isinstance(p, torch.nn.Parameter): + W = p.data.view(-1, head_dim, p.data.shape[1]) + W.mul_(scales.view(-1, 1, 1)) + else: + W = p.view(-1, head_dim, p.shape[1]) + W.mul_(scales.view(-1, 1, 1)) + + def parallel(self, names, params, group, lr, 
weight_decay, momentum, + qk_logits): + """ + Perform a parallel optimization step using Muon. + """ + + for p in params: + g = p.grad + if g is None: + continue + if g.ndim > 2: + g = g.view(g.size(0), -1) + + # Update g in the local rank + g = self._update_g( + p, + g, + group, + momentum=momentum, + ) + p.grad = g + + param_to_state, ordered_params = self.init_state_and_assign_params( + names, params, group, qk_logits) + + assert self.rank is not None + + def enqueue_all2all_gather(start_idx, chunk_size): + target_params = ordered_params[start_idx:start_idx + chunk_size] + if target_params: + alloc_event = _alloc_gathered_grad(target_params, + param_to_state, self.rank, + self.compute_stream) + _all2all_gather(target_params, param_to_state, self.rank, + self.comm_stream, group["none_grad"], + alloc_event) + + def enqueue_computes(start_idx, chunk_size): + for p in ordered_params[start_idx:start_idx + chunk_size]: + state = param_to_state[id(p)] + _compute_u(p, state, group["ns_steps"], self.rank, + self.compute_stream) + + def enqueue_all2all_scatter(start_idx, chunk_size): + target_params = ordered_params[start_idx:start_idx + chunk_size] + if target_params: + alloc_event = _alloc_scattered_u(target_params, param_to_state, + self.rank, + self.compute_stream) + _all2all_scatter(target_params, param_to_state, self.rank, + self.comm_stream, alloc_event) + + def enqueue_update_param(start_idx, chunk_size): + for p in ordered_params[start_idx:start_idx + chunk_size]: + state = param_to_state[id(p)] + adjusted_lr = self.adjust_lr_for_muon(lr, p.shape) + _update_param(p, state, lr, adjusted_lr, weight_decay, + self.rank, self.compute_stream) + + if self.chunk_size == -1: + shard_ranks = dist.get_world_size(param_to_state[id( + params[0])].process_group) + chunk_size = shard_ranks * DEFAULT_CHUNK_SIZE_RATIO + elif self.chunk_size > 0: + chunk_size = self.chunk_size + else: + raise ValueError("chunk_size must be -1 or a positive integer.") + + # Wait grad update + 
    @staticmethod
    def _fused_adamw(
        params: list[torch.Tensor],
        grads: list[torch.Tensor],
        exp_avgs: list[torch.Tensor],
        exp_avg_sqs: list[torch.Tensor],
        max_exp_avg_sqs: list[torch.Tensor],
        state_steps: list[torch.Tensor],
        amsgrad: bool,
        beta1: float,
        beta2: float,
        lr: float | torch.Tensor,
        weight_decay: float,
        eps: float,
        maximize: bool,
    ) -> None:
        """Dispatch AdamW via torch's fused kernel, grouping the tensor lists
        by device/dtype first. Mirrors the fused path of
        ``torch.optim.adamw._fused_adamw``."""
        if not params:
            return

        # We only shuffle around the lr when it is a Tensor and on CUDA, otherwise, we prefer
        # treating it as a scalar.
        # NOTE(review): upstream annotates this as `DeviceDict`, a private
        # torch.optim alias not imported here; spelled out concretely below.
        lr_dict: dict[torch.device, torch.Tensor] | None = ({
            lr.device: lr
        } if isinstance(lr, torch.Tensor) and str(lr.device) != "cpu" else
                                                            None)
        grouped_tensors = torch.optim.Optimizer._group_tensors_by_device_and_dtype(
            [
                params, grads, exp_avgs, exp_avg_sqs, max_exp_avg_sqs,
                state_steps
            ]  # type: ignore[list-item]
        )
        for (device, _), (
            (
                device_params_,
                device_grads_,
                device_exp_avgs_,
                device_exp_avg_sqs_,
                device_max_exp_avg_sqs,
                device_state_steps_,
            ),
            _,
        ) in grouped_tensors.items():
            device_params = cast(list[torch.Tensor], device_params_)
            device_grads = cast(list[torch.Tensor], device_grads_)
            device_exp_avgs = cast(list[torch.Tensor], device_exp_avgs_)
            device_exp_avg_sqs = cast(list[torch.Tensor], device_exp_avg_sqs_)
            device_state_steps = cast(list[torch.Tensor], device_state_steps_)

            # Move a tensor lr to this device once and reuse it afterwards.
            if lr_dict is not None and device not in lr_dict:
                lr_dict[device] = lr.to(
                    device=device,
                    non_blocking=True)  # type: ignore[union-attr]
                lr = lr_dict[device]
            torch._foreach_add_(device_state_steps, 1)
            func = torch._fused_adamw_
            func(
                device_params,
                device_grads,
                device_exp_avgs,
                device_exp_avg_sqs,
                device_max_exp_avg_sqs,  # type: ignore[arg-type]
                device_state_steps,
                amsgrad=amsgrad,
                lr=lr,  # type: ignore[arg-type]
                beta1=beta1,
                beta2=beta2,
                weight_decay=weight_decay,
                eps=eps,
                maximize=maximize,
            )
below a threshold. + for n, p in zip(names, params): + if p is None or p.grad is None: + continue + if isinstance(p.data, DTensor): + if all( + isinstance(placement, Replicate) + for placement in p.placements): + param_tensors.append(p) + name_tensors.append(n) + elif p.data.numel() <= self.small_param_numel_threshold: + param_dtensors_small.append(p) + name_dtensors_small.append(n) + else: + param_dtensors.append(p) + name_dtensors.append(n) + elif isinstance(p.data, torch.Tensor): + param_tensors.append(p) + name_tensors.append(n) + else: + raise TypeError(f"Unsupported parameter type: {type(p.data)}") + + logger.debug( + f"[Muon] {len(param_dtensors)} DTensors, {len(param_tensors)} Tensors, " + f"{len(param_dtensors_small)} Small DTensors") + + def group_dtensors(dtensors, names): + # To support different placements, we group parameters by placements + # and run parallel Muon on each group. + + placement_to_params = defaultdict(lambda: ([], [])) + # type: dict[tuple[Placement, DeviceMesh], tuple[list[str], list[DTensor]]] + + assert len(dtensors) == len(names) + for p, n in zip(dtensors, names): + placement_to_params[tuple([p.placements, + p.device_mesh])][0].append(n) + placement_to_params[tuple([p.placements, + p.device_mesh])][1].append(p) + return placement_to_params + + if len(param_dtensors_small) > 0: + if not dist.is_initialized(): + raise RuntimeError( + "Parallel Muon requires torch.distributed to be initialized." + ) + + self.distributed_muon( + params=param_dtensors_small, + names=name_dtensors_small, + group=group, + lr=lr, + weight_decay=weight_decay, + momentum=momentum, + qk_logits=qk_logits, + ) + + if len(param_dtensors) > 0: + if not dist.is_initialized(): + raise RuntimeError( + "Parallel Muon requires torch.distributed to be initialized." 
+ ) + + dtensor_group = group_dtensors(param_dtensors, name_dtensors) + for _, (names, params) in dtensor_group.items(): + self.parallel( + names, + params, + group, + lr=lr, + weight_decay=weight_decay, + momentum=momentum, + qk_logits=qk_logits, + ) + + if len(param_tensors) > 0: + self.base( + name_tensors, + param_tensors, + group, + lr=lr, + weight_decay=weight_decay, + momentum=momentum, + qk_logits=qk_logits, + ) + + def _step_adamw_params(self, params, group): + params_with_grads = [] + grads = [] + moment1 = [] + moment2 = [] + max_exp_avg_sqs = [] + state_steps = [] + lr = group["lr"] + beta1, beta2 = group["adamw_betas"] + eps = group["adamw_eps"] + weight_decay = group["weight_decay"] + + for p in params: + g = p.grad + if g is None: + continue + state = self.state[p] + params_with_grads.append(p) + grads.append(g) + if "step" not in state: + state["step"] = (torch.zeros((), + dtype=torch.float32, + device=p.device)) + state["moment1"] = torch.zeros_like(g) + state["moment2"] = torch.zeros_like(g) + moment1.append(state["moment1"]) + moment2.append(state["moment2"]) + if not isinstance(state["step"], torch.Tensor): + step_tensor = torch.tensor(state["step"], + dtype=torch.float32, + device=p.device) + else: + step_tensor = state["step"] + state_steps.append(step_tensor) + + self._fused_adamw( + params_with_grads, + grads, + moment1, + moment2, + max_exp_avg_sqs, + state_steps, + amsgrad=False, + beta1=beta1, + beta2=beta2, + lr=lr, + weight_decay=weight_decay, + eps=eps, + maximize=False, + ) + + def _step_adamw(self, group): + params = group["params"] + + # group params with it's type and placement + placement_to_params: dict[tuple[Placement | type, + DeviceMesh | None]] = defaultdict(list) + for p in params: + match p: + case DTensor(): + placement_to_params[tuple([p.placements, + p.device_mesh])].append(p) + case torch.Tensor(): + placement_to_params[tuple([torch.Tensor, None])].append(p) + + for params in placement_to_params.values(): + 
self._step_adamw_params(params, group) + + @torch.no_grad + def step(self, closure=None, qk_logits=None): + """Perform a single optimization step. + + Args: + closure (Callable, optional): A closure that reevaluates the model + and returns the loss. + qk_logits (dict[int, Tensor], optional): A dictionary mapping layer indices + to 1D tensors of shape (num_heads,), representing the maximum + QK logits across all tokens, computed as + (1 / sqrt(head_dim)) * (Q @ K^T). + """ + loss = None + if closure is not None: + with torch.enable_grad(): + loss = closure() + + for group in self.param_groups: + if group["use_muon"]: + self._step_muon(group, qk_logits=qk_logits) + else: + self._step_adamw(group) + + return loss diff --git a/build/torch28-cxx11-cu129-x86_64-linux/optimizer/__init__.py b/build/torch28-cxx11-cu129-x86_64-linux/optimizer/__init__.py index 239c7a65f8293e7d0df28f05fce645af56d628c0..03dbc1afe1cf156661a2b1b22003cd5f599a0309 100644 --- a/build/torch28-cxx11-cu129-x86_64-linux/optimizer/__init__.py +++ b/build/torch28-cxx11-cu129-x86_64-linux/optimizer/__init__.py @@ -1,5 +1,26 @@ -from .muon import Muon +import ctypes +import sys -__all__ = [ - "Muon", -] +import importlib +from pathlib import Path +from types import ModuleType + +def _import_from_path(file_path: Path) -> ModuleType: + # We cannot use the module name as-is, after adding it to `sys.modules`, + # it would also be used for other imports. So, we make a module name that + # depends on the path for it to be unique using the hex-encoded hash of + # the path. 
+ path_hash = "{:x}".format(ctypes.c_size_t(hash(file_path.absolute())).value) + module_name = path_hash + spec = importlib.util.spec_from_file_location(module_name, file_path) + if spec is None: + raise ImportError(f"Cannot load spec for {module_name} from {file_path}") + module = importlib.util.module_from_spec(spec) + if module is None: + raise ImportError(f"Cannot load module {module_name} from spec") + sys.modules[module_name] = module + spec.loader.exec_module(module) # type: ignore + return module + + +globals().update(vars(_import_from_path(Path(__file__).parent.parent / "__init__.py"))) diff --git a/build/torch28-cxx11-cu129-x86_64-linux/optimizer/_ops.py b/build/torch28-cxx11-cu129-x86_64-linux/optimizer/_ops.py deleted file mode 100644 index 7d598206add1bca142661a3df6c510e3d9575d54..0000000000000000000000000000000000000000 --- a/build/torch28-cxx11-cu129-x86_64-linux/optimizer/_ops.py +++ /dev/null @@ -1,9 +0,0 @@ -import torch -from . import _optimizer_23d68bb_dirty -ops = torch.ops._optimizer_23d68bb_dirty - -def add_op_namespace_prefix(op_name: str): - """ - Prefix op by namespace. 
- """ - return f"_optimizer_23d68bb_dirty::{op_name}" \ No newline at end of file diff --git a/build/torch28-cxx11-cu129-x86_64-linux/optimizer/_optimizer_23d68bb_dirty.abi3.so b/build/torch28-cxx11-cu129-x86_64-linux/optimizer/_optimizer_23d68bb_dirty.abi3.so deleted file mode 100755 index 1cc1c027dd79defcd367dd32836ace4dc43d3cf1..0000000000000000000000000000000000000000 --- a/build/torch28-cxx11-cu129-x86_64-linux/optimizer/_optimizer_23d68bb_dirty.abi3.so +++ /dev/null @@ -1,3 +0,0 @@ -version https://git-lfs.github.com/spec/v1 -oid sha256:1cbcd3df518412314d547a86b947998802e488e8aec0f22bf8b59fbc2d1c91e8 -size 1983488 diff --git a/build/torch28-cxx11-rocm63-x86_64-linux/__init__.py b/build/torch28-cxx11-rocm63-x86_64-linux/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..239c7a65f8293e7d0df28f05fce645af56d628c0 --- /dev/null +++ b/build/torch28-cxx11-rocm63-x86_64-linux/__init__.py @@ -0,0 +1,5 @@ +from .muon import Muon + +__all__ = [ + "Muon", +] diff --git a/build/torch28-cxx11-rocm63-x86_64-linux/_ops.py b/build/torch28-cxx11-rocm63-x86_64-linux/_ops.py new file mode 100644 index 0000000000000000000000000000000000000000..e6f6fcf6280e969b1761926112147d3146e27b59 --- /dev/null +++ b/build/torch28-cxx11-rocm63-x86_64-linux/_ops.py @@ -0,0 +1,9 @@ +import torch +from . import _optimizer_06a260a_dirty +ops = torch.ops._optimizer_06a260a_dirty + +def add_op_namespace_prefix(op_name: str): + """ + Prefix op by namespace. 
+ """ + return f"_optimizer_06a260a_dirty::{op_name}" \ No newline at end of file diff --git a/build/torch28-cxx11-rocm63-x86_64-linux/_optimizer_06a260a_dirty.abi3.so b/build/torch28-cxx11-rocm63-x86_64-linux/_optimizer_06a260a_dirty.abi3.so new file mode 100755 index 0000000000000000000000000000000000000000..19ee075424c40e1714e4ef6561d68c368e933792 --- /dev/null +++ b/build/torch28-cxx11-rocm63-x86_64-linux/_optimizer_06a260a_dirty.abi3.so @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:90ac494e1381bedf95832a91c108ff18d900442203f9b0612efa5519956def2e +size 1865080 diff --git a/build/torch28-cxx11-rocm63-x86_64-linux/distributed/utils.py b/build/torch28-cxx11-rocm63-x86_64-linux/distributed/utils.py new file mode 100644 index 0000000000000000000000000000000000000000..6d5843506c13d9d31603b2b4e30c1c91d0baab28 --- /dev/null +++ b/build/torch28-cxx11-rocm63-x86_64-linux/distributed/utils.py @@ -0,0 +1,175 @@ +import torch +import torch.distributed as dist +from torch.distributed import ProcessGroup +from torch.distributed.device_mesh import DeviceMesh +from torch.distributed.tensor import DTensor +from torch.distributed.tensor.placement_types import (Placement, Shard, + _StridedShard) + + +def get_slices_of_dtensor( + target: DTensor | torch.Tensor, + local_rank: int, + shard_mesh: DeviceMesh, + shard_placements: tuple[Placement], +) -> tuple[slice]: + """ + Get the slice of local tensor for a given rank from a tensor. + Args: + target (DTensor | torch.Tensor): The target tensor. + rank (int): The local rank of the shard group. + shard_mesh (DeviceMesh): The shard mesh. It consists of global ranks. + shard_placements (tuple[Placement]): The shard placements. 
+ """ + + slices: list[slice] = [slice(0, dim_size) for dim_size in target.size()] + + # find the global rank of the local rank in the shard mesh + rank = sorted(shard_mesh.mesh.flatten().tolist())[local_rank] + + rank_coords = (shard_mesh.mesh == rank).nonzero() + + assert len(rank_coords) == 1 + rank_coords = tuple(rank_coords[0].tolist()) + + assert len(rank_coords) == len(shard_placements) + + # Caution: Assuming replicate-to-shard of the shard mesh goes with + # left-to-right sharding. This is ensured by the sorting logic of + # construct_shard_mesh function. + for i, (rank_coord, + placement) in enumerate(zip(rank_coords, shard_placements)): + assert isinstance(placement, Shard) + + num_ranks = shard_mesh.mesh.shape[i] + + dim = placement.dim + dim_size = (slices[dim].stop - slices[dim].start) + + if dim_size % num_ranks != 0: + raise NotImplementedError( + f"Dimension size {dim_size} is not divisible " + f"by number of ranks {num_ranks} for shard " + f"placement on dim {dim}. (shape: {target.shape})") + + shard_size = dim_size // num_ranks + + start = slices[dim].start + rank_coord * shard_size + end = start + shard_size + + assert start < end <= slices[dim].stop + + slices[dim] = slice(start, end) + + return tuple(slices) + + +_ranks_to_dist_cache: dict[tuple[int, ...], tuple[DeviceMesh, + ProcessGroup]] = dict() + + +def construct_shard_mesh( + placements: tuple[Placement], + mesh: DeviceMesh, +) -> (DeviceMesh, ProcessGroup, tuple[Placement]): + """ + Construct Shard Mesh and Placements for unsharding. + It removes Replicate placements and constructs a new Mesh and ProcessGroup. + """ + my_rank = dist.get_rank() + + assert mesh.mesh.device.type == 'cpu' + + # Copy mesh to avoid modifying the original mesh + mesh = mesh.mesh.clone() + + # 1. Sort placements. Replicate first, then Shard by dim ascending. + + # For Shard, strided shard comes after regular shard on the same dim + # to preserve left-to-right order of replicate-to-shard. 
+ # This is because that strided shard is using stride to represent + # more fine-grained sharding on the same dim. + # Please check the URL below for _StridedShard. + # https://github.com/pytorch/pytorch/blob/v2.8.0/torch/distributed/tensor/placement_types.py#L366 + + def placement_sort_key( + placement_with_index: tuple[float, Placement] + ) -> tuple[int, float, int]: # (dim, split factor, original index) + index, placement = placement_with_index + is_replicate = placement.is_replicate() + is_shard = placement.is_shard() + is_partial = placement.is_partial() + + assert is_replicate or is_shard, f"Unsupported placement type: {type(placement)}" + assert not is_partial, "Partial placement is not supported." + + if is_replicate: + return (-1.0, 0, index) + elif is_shard: + if isinstance(placement, _StridedShard): + return (placement.dim, 1 / placement.split_factor, index) + return (placement.dim, 0, index) + else: + raise TypeError(f"Unknown placement type: {type(placement)}") + + placements_with_index: list[tuple[int, + Placement]] = list(enumerate(placements)) + placements_with_index = sorted(placements_with_index, + key=placement_sort_key) + + sorted_indices, sorted_placements = zip(*placements_with_index) + + # 2. Permute mesh according to sorted placements. + sorted_mesh = mesh.permute(sorted_indices) + + # 3. 
Collect list of shard meshes by removing replicate dims + # For example, (2, 3, 4, 4) with placements [R, R, S(0), S(1)] + # shard_meshes should be list with 2 * 3 = 6 shard meshes of shape (4, 4) + num_replicates = sum(1 for p in sorted_placements if p.is_replicate()) + + # merge replicate dims + # shard_meshes became a list of shard meshes with a length of replicate degree + if num_replicates > 0: + sorted_mesh = sorted_mesh.flatten( + 0, num_replicates - 1) if num_replicates > 1 else sorted_mesh + shard_meshes = list(torch.unbind(sorted_mesh, dim=0)) + else: + shard_meshes = [sorted_mesh] + shard_placements = sorted_placements[num_replicates:] + + # assume all shard placements are different + assert len(shard_placements) == len(set(shard_placements)) + + # 4. Construct ProcessGroups + # Caution: all groups should be created in the same order in all processes, + # even though each process only needs its own group. + + # To use tensor as dict key, convert it to tuple + def tensor_to_tuple(t): + if isinstance(t, torch.Tensor): + t = t.tolist() + if isinstance(t, list): + return tuple(tensor_to_tuple(x) for x in t) + return t + + my_shard_mesh_as_tuple = None + for shard_mesh in shard_meshes: + assert isinstance(shard_mesh, torch.Tensor) + shard_mesh_as_tuple = tensor_to_tuple(shard_mesh) + + if (my_rank == shard_mesh).any().item(): + assert my_shard_mesh_as_tuple is None + my_shard_mesh_as_tuple = shard_mesh_as_tuple + + # update global cache + if shard_mesh_as_tuple not in _ranks_to_dist_cache: + shard_process_group = dist.new_group(shard_mesh.flatten().tolist()) + _ranks_to_dist_cache[shard_mesh_as_tuple] = ( + DeviceMesh(device_type="cuda", mesh=shard_mesh), + shard_process_group, + ) + + my_shard_mesh, my_shard_process_group = _ranks_to_dist_cache[ + my_shard_mesh_as_tuple] + + return my_shard_mesh, my_shard_process_group, shard_placements diff --git a/build/torch29-cxx11-rocm63-x86_64-linux/optimizer/matmul_transpose_triton.py 
b/build/torch28-cxx11-rocm63-x86_64-linux/matmul_transpose_triton.py similarity index 100% rename from build/torch29-cxx11-rocm63-x86_64-linux/optimizer/matmul_transpose_triton.py rename to build/torch28-cxx11-rocm63-x86_64-linux/matmul_transpose_triton.py diff --git a/build/torch28-cxx11-rocm63-x86_64-linux/metadata.json b/build/torch28-cxx11-rocm63-x86_64-linux/metadata.json new file mode 100644 index 0000000000000000000000000000000000000000..76bafa5f33b6818aa6bb4cab04be811b87519b44 --- /dev/null +++ b/build/torch28-cxx11-rocm63-x86_64-linux/metadata.json @@ -0,0 +1 @@ +{"python-depends":[]} \ No newline at end of file diff --git a/build/torch28-cxx11-rocm63-x86_64-linux/muon.py b/build/torch28-cxx11-rocm63-x86_64-linux/muon.py new file mode 100644 index 0000000000000000000000000000000000000000..dbf25575f185ff379789482068e4ecf55b9455a9 --- /dev/null +++ b/build/torch28-cxx11-rocm63-x86_64-linux/muon.py @@ -0,0 +1,1268 @@ +import logging +import math +import types +from collections import defaultdict +from dataclasses import dataclass +from typing import Any, cast + +import torch +import torch.distributed as dist +from torch.distributed import ProcessGroup +from torch.distributed.device_mesh import DeviceMesh +from torch.distributed.tensor import DTensor, Replicate +from torch.distributed.tensor.placement_types import Placement + +from .distributed.utils import construct_shard_mesh, get_slices_of_dtensor +from .matmul_transpose_triton import matmul_transpose_assign + +logger = logging.getLogger(__name__) + +COMM_DTYPE = torch.bfloat16 +DEFAULT_CHUNK_SIZE_RATIO = 4 + + +# This code snippet is a modified version adapted from the following GitHub repositories: +# https://github.com/KellerJordan/Muon/blob/master/muon.py +# Muon's Newton–Schulz iteration causes high variance in singular values +# Idea: give each iteration its own 3 coefficients and optimize them via gradient descent. 
+@torch.no_grad() +# matmul_transpose_assign from : https://github.com/nil0x9/flash-muon +def _zeropower_via_newtonschulz5(G, steps): + """ + Newton-Schulz iteration to compute the zeroth power / orthogonalization of G. We opt to use a + quintic iteration whose coefficients are selected to maximize the slope at zero. For the purpose + of minimizing steps, it turns out to be empirically effective to keep increasing the slope at + zero even beyond the point where the iteration no longer converges all the way to one everywhere + on the interval. This iteration therefore does not produce UV^T but rather something like US'V^T + where S' is diagonal with S_{ii}' ~ Uniform(0.5, 1.5), which turns out not to hurt model + performance at all relative to UV^T, where USV^T = G is the SVD. + """ + assert len(G.shape) == 2 + assert G.dtype == COMM_DTYPE + X = G # no manual typecast + + if G.size(0) > G.size(1): + X = X.T + # Ensure spectral norm is at most 1 + X = X / (X.norm() + 1e-7) + buf1 = torch.empty(X.size(0), X.size(0), dtype=X.dtype, device=X.device) + buf2 = torch.empty(X.size(0), X.size(0), dtype=X.dtype, device=X.device) + # Perform the NS iterations + for a, b, c in [ + (4.0848, -6.8946, 2.9270), + (3.9505, -6.3029, 2.6377), + (3.7418, -5.5913, 2.3037), + (2.8769, -3.1427, 1.2046), + (2.8366, -3.0525, 1.2012), + ]: + matmul_transpose_assign(X, buf1) + matmul_transpose_assign(buf1, buf2) + buf1.mul_(b).add_(buf2, alpha=c) + X = torch.addmm(X, buf1, X, alpha=1.0, beta=a) + + if G.size(0) > G.size(1): + X = X.T + return X + + +@dataclass +class _muon_state: + # TODO: use Optional + worker_rank: int + process_group: ProcessGroup + shard_mesh: DeviceMesh + shard_placements: tuple[Placement, ...] 
+ name: str + qk_clip_state: torch.Tensor | None = None + gathered_grad: torch.Tensor | None = None + scattered_u: DTensor | None = None + computed_u: torch.Tensor | None = None + gather_event: torch.cuda.Event | None = None + compute_event: torch.cuda.Event | None = None + scatter_event: torch.cuda.Event | None = None + + +def numel_for_rank( + param: DTensor, + local_rank: int, + state: _muon_state, +) -> int: + slices = get_slices_of_dtensor( + param, + local_rank, + state.shard_mesh, + state.shard_placements, + ) + + numel = 1 + for s, dim in zip(slices, param.shape): + start, stop, step = s.indices(dim) + length = max(0, (stop - start + (step - 1)) // step) + numel *= length + + return numel + + +@torch.no_grad() +def _alloc_gathered_grad(params, param_to_state, rank, compute_stream): + """ + Pre-allocate gathered_grad buffer on compute_stream + before launching all2all gather + """ + with torch.cuda.stream(compute_stream): + for p in params: + state = param_to_state[id(p)] + if rank == state.worker_rank: + state.gathered_grad = torch.empty(p.shape, + dtype=COMM_DTYPE, + device="cuda") + else: + state.gathered_grad = None + + alloc_event = torch.cuda.Event() + alloc_event.record(compute_stream) + return alloc_event + + +@torch.no_grad() +def _all2all_gather(params, param_to_state, rank, comm_stream, none_grad, + alloc_event): + """ + All2all gathers shards so each owner rank reconstructs its full gradient + """ + with torch.cuda.stream(comm_stream): + process_group = param_to_state[id(params[0])].process_group + num_ranks = dist.get_world_size(group=process_group) + + # Construct sending buffers + per_dst = [[] for _ in range(num_ranks)] + send_counts = [0] * num_ranks + + for p in params: + state = param_to_state[id(p)] + dst = state.worker_rank + assert dst < num_ranks + shard_elems = numel_for_rank(p, rank, state) + g = p.grad + g = g.to_local().to(COMM_DTYPE).contiguous() + assert g.numel() == shard_elems + per_dst[dst].append(g.view(-1)) + 
send_counts[dst] += shard_elems + + assert any( + len(v) > 0 for v in per_dst + ), "At least one destination rank must receive a sharded tensor" + # list[list[Tensor]] -> list[Tensor] + per_dst = [t for dst in per_dst for t in dst] + + send_buf = torch.cat(per_dst, dim=0) + + owned_params = [ + p for p in params if param_to_state[id(p)].worker_rank == rank + ] + + # Compute receive sizes and allocate receiving buffers + recv_counts = [0] * num_ranks + + for src in range(num_ranks): + total = 0 + for p in owned_params: + state = param_to_state[id(p)] + assert state.worker_rank == rank + total += numel_for_rank(p, src, state) + recv_counts[src] = total + + recv_total = sum(recv_counts) + recv_buf = torch.empty(recv_total, dtype=COMM_DTYPE, device="cuda") + + #All2All + logger.debug(f"send_buf size: {send_buf.numel()}, " + f"recv_buf size: {recv_buf.numel()}, " + f"recv_counts: {recv_counts}, " + f"send_counts: {send_counts}, " + f"process_group: {str(process_group)}") + dist.all_to_all_single( + recv_buf, + send_buf, + output_split_sizes=recv_counts, + input_split_sizes=send_counts, + group=process_group, + ) + + # Reconstructs gathered grad from the received buffer + # + # recv_buf (num ranks = 3) + # + # From rank 0 From rank 1 From rank 2 + # | p1_0, p2_0, p3_0 | p1_1, p2_1, p3_1 | p1_2, p2_2, p3_2 | + # + # Outer loop: + # rank 0 -> rank 1 -> rank2 + # + # Inner loop: + # p1_n -> p2_n -> p3_n + + comm_stream.wait_event(alloc_event) + + off = 0 + for src in range(num_ranks): + if recv_counts[src] == 0: + continue + + block = recv_counts[src] + inner_off = 0 + for p in owned_params: + state = param_to_state[id(p)] + assert state.worker_rank == rank + + # get the slice of the full dtensor corresponding to rank src. 
+ slices = get_slices_of_dtensor(state.gathered_grad, src, + state.shard_mesh, + state.shard_placements) + + dst = state.gathered_grad[slices] + assert dst._base is state.gathered_grad + + n = dst.numel() + assert n > 0 + + sg = recv_buf.narrow(0, off + inner_off, n) + sg = sg.reshape_as(dst) + dst.copy_(sg) + + inner_off += n + off += block + + for p in params: + state = param_to_state[id(p)] + if state.worker_rank == rank: + state.gather_event = torch.cuda.Event() + state.gather_event.record(comm_stream) + else: + state.gathered_grad = None + state.gather_event = None + if none_grad: + p.grad = None + + +@torch.no_grad() +def _compute_u(p, state, steps, rank, compute_stream): + """ + On worker_rank, compute the orthogonalized update using Newton-Schulz iteration. + """ + with torch.cuda.stream(compute_stream): + if rank == state.worker_rank: + if state.gather_event is None: + raise RuntimeError("Gather event must be set before compute.") + compute_stream.wait_event(state.gather_event) + u = _zeropower_via_newtonschulz5(state.gathered_grad, steps) + state.gathered_grad = None + state.computed_u = u + state.compute_event = torch.cuda.Event() + state.compute_event.record() + else: + state.computed_u = None + state.compute_event = None + + +@torch.no_grad() +def _alloc_scattered_u(params, param_to_state, rank, compute_stream): + """ + Pre-allocate scattered_u buffer on compute_stream + before launching all2all gather + """ + with torch.cuda.stream(compute_stream): + for p in params: + state = param_to_state[id(p)] + state.scattered_u = torch.empty_like(p.to_local(), + dtype=COMM_DTYPE) + + alloc_event = torch.cuda.Event() + alloc_event.record(compute_stream) + return alloc_event + + +def _all2all_scatter(params, param_to_state, rank, comm_stream, alloc_event): + """ + All2all scatters full gradients to all ranks + """ + with torch.cuda.stream(comm_stream): + process_group = param_to_state[id(params[0])].process_group + num_ranks = 
dist.get_world_size(group=process_group) + owned_params = [ + p for p in params if param_to_state[id(p)].worker_rank == rank + ] + + # Construct sending buffer + per_dst = [[] for _ in range(num_ranks)] + send_counts = [0] * num_ranks + + if owned_params: + for p in owned_params: + state = param_to_state[id(p)] + if state.compute_event is None: + raise RuntimeError( + "Compute event must be set before scatter.") + comm_stream.wait_event(state.compute_event) + state.gathered_grad = None + + assert state.computed_u is not None + + u_full = state.computed_u.to(COMM_DTYPE).contiguous() + + offset = 0 + for dst in range(num_ranks): + # get the slice of the full tensor corresponding to rank dst. + slices = get_slices_of_dtensor(u_full, dst, + state.shard_mesh, + state.shard_placements) + su = u_full[slices].flatten() + + n = su.numel() + assert n > 0 + + per_dst[dst].append(su) + send_counts[dst] += n + offset += n + + assert offset == u_full.numel() + + lengths = [len(v) for v in per_dst] + if all(l > 0 for l in lengths): + assert all( + l == lengths[0] for l in lengths + ), "All destination ranks must have the same number of sharded tensor" + # list[list[Tensor]] -> list[Tensor] + per_dst = [t for dst in per_dst for t in dst] + send_buf = torch.cat(per_dst, dim=0) + else: + # all_to_all requires participation from all ranks + # Even non-owner ranks must join the collective call + send_buf = torch.empty(0, dtype=COMM_DTYPE, device="cuda") + + # Compute receive sizes and allocate receiving buffers + recv_counts = [0] * num_ranks + + for src in range(num_ranks): + total = 0 + for p in params: + state = param_to_state[id(p)] + if state.worker_rank != src: + continue + total += numel_for_rank(p, rank, state) + recv_counts[src] = total + + recv_total = sum(recv_counts) + assert recv_total > 0 + recv_buf = torch.empty(recv_total, dtype=COMM_DTYPE, device="cuda") + + #All2All + dist.all_to_all_single( + recv_buf, + send_buf, + output_split_sizes=recv_counts, + 
input_split_sizes=send_counts, + group=process_group, + ) + + # Copy to pre-allocated scattered_u buffer from the received buffer + # + # recv_buf (num ranks = 3, local_rank = 0) + # + # From rank 0 From rank 1 From rank 2 + # | p1_0, p2_0, p3_0 | p4_0 | p5_0, p6_0 | + # + # Outer loop: + # rank 0 -> rank 1 -> rank2 + # + # Inner loop: + # src(0) : p1_0 -> p2_0 -> p3_0 + # src(1) : p4_0 + # src(2) : p5_0 -> p6_0 + + comm_stream.wait_event(alloc_event) + + off = 0 + for src in range(num_ranks): + block = recv_counts[src] + if block == 0: + continue + + inner_off = 0 + for p in params: + state = param_to_state[id(p)] + if state.worker_rank != src: + continue + n = numel_for_rank(p, rank, state) + assert n > 0 + + flat_local = recv_buf.narrow(0, off + inner_off, + n).view_as(p.to_local()) + state.scattered_u.copy_(flat_local) + + state.scatter_event = torch.cuda.Event() + state.scatter_event.record(comm_stream) + inner_off += n + + assert inner_off == block + off += block + + +def _update_param(p, state, lr, adjusted_lr, weight_decay, rank, + compute_stream): + """ + Update sharded parameter p with the scattered_u. + Only worker_rank frees computed_u. 
+ """ + with torch.cuda.stream(compute_stream): + if state.scatter_event is None: + raise RuntimeError("Scatter event must be set before update") + compute_stream.wait_event(state.scatter_event) + u_dtensor = DTensor.from_local( + state.scattered_u, + placements=p.placements, + device_mesh=p.device_mesh, + ) + + state.scattered_u = u_dtensor + + if rank == state.worker_rank: + # Free computed_u + state.computed_u = None + + Muon._update_p(p, state.scattered_u, lr, adjusted_lr, weight_decay) + state.scattered_u = None + u_dtensor = None + + scales_full = Muon._compute_scales( + p, + state.qk_clip_state) if state.qk_clip_state is not None else None + if scales_full is not None: + # Have to slice scales_full among dim 0 + weight_slices = get_slices_of_dtensor(p, rank, state.shard_mesh, + state.shard_placements) + ratio = p.shape[0] // scales_full.shape[0] + scales_slice = slice( + None if weight_slices[0].start is None else + weight_slices[0].start // ratio, + None if weight_slices[0].stop is None else + weight_slices[0].stop // ratio, + None, + ) + + scales_local = scales_full[scales_slice] + scales_local = DTensor.from_local( + scales_local, + placements=p.placements, + device_mesh=p.device_mesh, + ) + Muon._qk_clip(p, scales_local, state.qk_clip_state.head_dim) + + +def default_is_muon(name, x): + skip_keys = ["embed_tokens", "lm_head", "tok_embeddings", "output"] + return x.ndim >= 2 and not any(key in name for key in skip_keys) + + +def get_default_muon_param_groups(model, is_muon_func=default_is_muon): + muon_params, muon_names = [], [] + non_muon_params = [] + + for n, p in model.named_parameters(): + if not p.requires_grad: + continue + if is_muon_func(n, p): + muon_params.append(p) + muon_names.append(n) + else: + non_muon_params.append(p) + + return [ + { + "params": muon_params, + "names": muon_names, + "use_muon": True, + }, + { + "params": non_muon_params, + "use_muon": False, + }, + ] + + +def parse_qk_layer(name: str) -> tuple[str | None, int]: + """ + 
Parse a parameter name to check if it is a query/key projection layer + ('wq', 'wk', 'q_proj', 'k_proj') and return (kind, layer_index). + + Returns: + (kind, layer_idx) or (None, -1) if not matched. + + Example: + 'model.3.attn.wq.weight' -> ('wq', 3) + 'model.5.attn.wk.weight' -> ('wk', 5) + 'model.2.attn.q_proj.weight' -> ('q_proj', 2) + 'model.7.attn.k_proj.weight' -> ('k_proj', 7) + 'model.4.attn.v_proj.weight' -> (None, -1) + """ + parts = name.split('.') + if len(parts) < 3: + return None, -1 + + kind = parts[-2] + + layer_idx = -1 + for part in reversed(parts): + if part.isdigit(): + layer_idx = int(part) + break + + if kind in ('wq', 'wk', 'q_proj', 'k_proj'): + return kind, layer_idx + + return None, -1 + + +@dataclass +class QKClipInfo: + """Per-parameter dynamic info computed from config + runtime logits.""" + kind: str | None # 'wq'/'q_proj' or 'wk'/'k_proj' or None + indices: list[int] # which heads to consider for clipping + head_dim: int # from config + threshold: float # from config + logit: torch.Tensor | None + + +class Muon(torch.optim.Optimizer): + """ + Muon - MomentUm Orthogonalized by Newton-schulz + + Muon internally runs standard SGD-momentum, and then performs an orthogonalization post- + processing step, in which each 2D parameter's update is replaced with the nearest orthogonal + matrix. To efficiently orthogonalize each update, we use a Newton-Schulz iteration, which has + the advantage that it can be stably run in bfloat16 on the GPU. + + Some warnings: + - We believe this optimizer is unlikely to work well for training with small batch size. + - We believe it may not work well for finetuning pretrained models, but we haven't tested this. + + Arguments: + model: The model to be optimized by Muon. + is_muon_func: A function that takes a parameter and its name, and returns whether the parameter should be optimized by Muon. + lr: The learning rate. The updates will have spectral norm of `lr`. 
(0.02 is a good default) + momentum: The momentum used by the internal SGD. (0.95 is a good default) + nesterov: Whether to use Nesterov-style momentum in the internal SGD. (recommended) + ns_steps: The number of Newton-Schulz iterations to run. (6 is probably always enough) + weight_decay: The weight decay for Muon and AdamW. + {0, 1}-D or are detected as being the embed or lm_head will be optimized by AdamW as well. + adamw_lr: The learning rate for the internal AdamW. + adamw_betas: The betas for the internal AdamW. + adamw_eps: The epsilon for the internal AdamW. + none_grad: Whether to set p.grad to None after gathering the gradients. This can save memory. + debug: Whether to print debug information. + clip_info : Configuration for QK clipping. Expected keys: + - "q_indices" (list[int]): Indices of query heads to consider. + - "k_indices" (list[int]): Indices of key heads to consider. + - "head_dim" (int): Dimensionality of each attention head. + - "threshold" (float): Threshold value; heads whose QK logits exceed + this value will be scaled down. + Default is: + { + "q_indices": [], + "k_indices": [], + "head_dim": 128, + "threshold": 100 + } + warmup_step : How many all2all gather, compute operations are launched in advance + before the corresponding all2all scatter steps begin. + A higher warmup_step increases memory usage but can improve + performance by overlapping communication. + Parallel muon only. + chunk_size : Batch size of parameters to process in each + all2all gather/compute/scatter step. + Use shard ranks * DEFAULT_CHUNK_SIZE_RATIO when -1 is specified. + use_distributed_muon: Use distributed muon by Liu et al. (2024). + For testing purpose only. 
+ small_param_numel_threshold: Threshold for classifying parameters as small and falling back to distributed Muon + """ + + def __init__(self, + params, + lr=1e-3, + momentum=0.95, + nesterov=True, + ns_steps=5, + weight_decay=0.1, + adamw_betas=(0.9, 0.95), + adamw_eps=1e-8, + none_grad=True, + debug=False, + clip_config={ + "q_indices": [], + "k_indices": [], + "head_dim": 128, + "threshold": 100 + }, + warmup_step=5, + chunk_size=-1, + use_distributed_muon=False, + small_param_numel_threshold=65536): + defaults = dict( + lr=lr, + weight_decay=weight_decay, + momentum=momentum, + nesterov=nesterov, + ns_steps=ns_steps, + adamw_betas=adamw_betas, + adamw_eps=adamw_eps, + none_grad=none_grad, + use_muon=True, + ) + error_message = "The key 'use_muon' is not set in parameter group {idx}. Assuming all parameters in the group will use muon optimization, which may lead to unexpected behavior." + instruction_code = "\n\n please follow this code snippet \n```optimizer = get_kernel('motif-technologies/optimizer')\n\n\nparams = optimizer.muon.get_default_muon_param_groups(model)\n\noptim = optimizer.Muon(params, ...)```" + + if isinstance(params, types.GeneratorType): + raise ValueError(error_message.format(idx=0) + instruction_code) + for _idx, param_group in enumerate(params): + if param_group.get("use_muon", None) is None: + raise ValueError( + error_message.format(idx=_idx) + instruction_code) + + super().__init__(params, defaults) + + self.rank = None + + self.comm_stream = torch.cuda.Stream() + self.compute_stream = torch.cuda.Stream() + self.debug = debug + self.clip_config = clip_config + self.warmup_step = warmup_step + self.chunk_size = chunk_size + self.use_distributed_muon = use_distributed_muon + self.small_param_numel_threshold = small_param_numel_threshold + + def _calc_flops(self, G, steps): + assert len(G.shape) == 2 + M, N = G.shape + if M > N: + M, N = N, M + + return steps * ((M**3) * 2 + (M**2 * N) * 4 + M * N * 2 + M**2 * 3) + + def 
adjust_lr_for_muon(self, lr, param_shape):
        """Return the Muon learning rate scaled by 0.2 * sqrt(max(A, B)),
        where (A, B) are the first two dims of the parameter shape."""
        A, B = param_shape[:2]
        # We adjust the learning rate and weight decay based on the size of the parameter matrix
        # as described in the paper
        adjusted_ratio = 0.2 * math.sqrt(max(A, B))
        adjusted_lr = lr * adjusted_ratio
        return adjusted_lr

    def set_rank_once(self, rank):
        """Pin this optimizer's local shard rank on first call; every later
        call must pass the same rank (all parameters must agree)."""
        if self.rank is None:
            self.rank = rank
        else:
            assert self.rank == rank

    def get_shard_mesh(self, p):
        """
        Get the shard mesh for a parameter p on the given rank.

        Also records this process's local rank within the shard process
        group via set_rank_once. Only DTensor parameters are supported.
        """
        assert isinstance(
            p, DTensor), "Parallel Muon only supports DTensor parameters."

        shard_mesh, shard_pg, shard_placements = construct_shard_mesh(
            p.placements, p.device_mesh)

        # set rank with the local rank in the shard process group
        self.set_rank_once(dist.get_rank(group=shard_pg))

        return shard_mesh, shard_pg, shard_placements

    def init_state_and_assign_params(self, names, params, group, qk_logits):
        """Build a per-parameter _muon_state and assign each parameter an
        owner (worker) rank, balancing owners round-robin after sorting
        parameters by estimated Newton-Schulz FLOPs."""
        param_to_state = {}
        param_to_flops = {}

        total_flops = 0
        for p in params:
            g = p.grad
            if g is None:
                continue
            assert g.ndim == 2, "Muon only supports 2D parameters."
+ + flops = self._calc_flops(g, group["ns_steps"]) + param_to_flops[id(p)] = flops + total_flops += flops + + if self.debug: + print(f"Total TFLOPs for Muon: {total_flops / 1e12:.2f} TFLOPs", + flush=True) + + paired = list(zip(names, params)) + + paired_sorted = sorted(paired, + key=lambda x: param_to_flops[id(x[1])], + reverse=True) + + names_sorted, params_sorted = zip(*paired_sorted) + ordered_names = list(names_sorted) + ordered_params = list(params_sorted) + + round_robin = 0 + mesh = ordered_params[0].device_mesh + placements = ordered_params[0].placements + + shard_mesh, shard_pg, shard_placements = self.get_shard_mesh( + ordered_params[0]) + shard_mesh_flattened = shard_mesh.mesh.flatten() + num_ranks = dist.get_world_size(group=shard_pg) + + for n, p in zip(ordered_names, ordered_params): + if mesh != p.device_mesh: + raise ValueError("All parameters must be on the same mesh.") + if placements != p.placements: + raise ValueError("All parameters must have same placements.") + + worker_rank = shard_mesh_flattened[round_robin].item() % num_ranks + round_robin = (round_robin + 1) % len(shard_mesh_flattened) + qk_clip_state = self.get_qk_clip_info(n, qk_logits) + + param_to_state[id(p)] = _muon_state( + worker_rank=worker_rank, + process_group=shard_pg, + shard_mesh=shard_mesh, + shard_placements=shard_placements, + name=n, + qk_clip_state=qk_clip_state, + ) + + return param_to_state, ordered_params + + def base(self, names, params, group, lr, weight_decay, momentum, + qk_logits): + # generate weight updates in distributed fashion + for n, p in zip(names, params): + g = p.grad + if g is None: + continue + if g.ndim > 2: + g = g.view(g.size(0), -1) + assert g is not None + + g = self._update_g(p, g, group, momentum) + + u = _zeropower_via_newtonschulz5(g.to(COMM_DTYPE), + steps=group["ns_steps"]) + + adjusted_lr = self.adjust_lr_for_muon(lr, p.shape) + Muon._update_p(p, u, lr, adjusted_lr, weight_decay) + + qk_clip_state = self.get_qk_clip_info(n, qk_logits) + 
+ scales_full = self._compute_scales( + p, qk_clip_state) if qk_clip_state is not None else None + if scales_full is not None: + Muon._qk_clip(p, scales_full, qk_clip_state.head_dim) + + def distributed_muon( + self, + names: list[str], + params: list[torch.nn.Parameter], + group: dict[str, Any], + lr: float, + weight_decay: float, + momentum: float, + qk_logits: list[torch.Tensor | DTensor] | None, + ): + """ Implementation of Distributed Muon by Liu et al. """ + + for n, p in zip(names, params): + g = p.grad + if g is None: + continue + if g.ndim > 2: + g = g.view(g.size(0), -1) + assert g is not None + + g = self._update_g(p, g, group, momentum) + + # Gather G + if isinstance(p.data, DTensor): + g_full = g.full_tensor() + p_full = p.data.full_tensor() + else: + g_full = g + p_full = p + + u_full = _zeropower_via_newtonschulz5(g_full.to(COMM_DTYPE), + steps=group["ns_steps"]) + + adjusted_lr = self.adjust_lr_for_muon(lr, p_full.shape) + Muon._update_p(p_full, u_full, lr, adjusted_lr, weight_decay) + + qk_clip_state = self.get_qk_clip_info(n, qk_logits) + + scales_full = self._compute_scales( + p_full, qk_clip_state) if qk_clip_state is not None else None + + if scales_full is not None: + Muon._qk_clip(p_full, scales_full, qk_clip_state.head_dim) + + if isinstance(p.data, DTensor): + ndims = len(p.device_mesh.mesh.shape) + p_replicate = DTensor.from_local( + p_full, + device_mesh=p.device_mesh, + placements=[Replicate() for _ in range(ndims)], + ) + + p_sharded = p_replicate.redistribute( + device_mesh=p.device_mesh, + placements=p.placements, + ) + + p.copy_(p_sharded) + + def _update_g(self, p, g, group, momentum): + # calc update + state = self.state[p] + buf = state.setdefault("momentum_buffer", torch.zeros_like(g)) + torch.add(g, buf, alpha=momentum, out=buf) + if group["nesterov"]: + g.add_(buf, alpha=momentum) + return g + return buf + + @staticmethod + def _update_p(p, u, lr, adjusted_lr, weight_decay): + if isinstance(p, torch.nn.Parameter): + # apply 
weight decay + p.data.mul_(1 - lr * weight_decay) + # apply update + p.data.add_(u, alpha=-adjusted_lr) + else: + p.mul_(1 - lr * weight_decay) + p.add_(u, alpha=-adjusted_lr) + + def get_qk_clip_info(self, n, qk_logits): + if self.clip_config is None: + return None + + head_dim = self.clip_config.get('head_dim') + threshold = self.clip_config.get('threshold') + kind, layer_idx = parse_qk_layer(n) + + logit, indices = None, [] + if qk_logits is not None and kind is not None: + logit = qk_logits[layer_idx] + indices_key = 'q_indices' if 'q' in kind else 'k_indices' + indices = self.clip_config.get(indices_key, []) or [] + + if isinstance(logit, DTensor): + # In TP settings, qk_logits may be DTensor + # We convert it to full tensor here for simplicity + logit = logit.full_tensor() + + return QKClipInfo( + kind=kind, + indices=indices, + head_dim=head_dim, + threshold=threshold, + logit=logit, + ) + + @staticmethod + def _compute_scales(p, qk_clip_state): + kind = qk_clip_state.kind + indices = qk_clip_state.indices + head_dim = qk_clip_state.head_dim + threshold = qk_clip_state.threshold + logit = qk_clip_state.logit + + H_global = p.shape[0] // head_dim + scales_full = torch.ones(H_global, device=p.data.device) + scaling = 0 + + for logit_idx, head_idx in enumerate(indices): + v_ele = float(logit[logit_idx]) + if v_ele > threshold: + new_scale = math.sqrt(threshold / v_ele) + if new_scale < scales_full[head_idx]: + scales_full[head_idx] = new_scale + logger.info( + f"[{kind}] Head {head_idx} exceeded threshold " + f"(value={v_ele:.4f}, threshold={threshold:.4f}) -> applying scale={new_scale:.4f}" + ) + scaling += 1 + + return scales_full if scaling > 0 else None + + @staticmethod + def _qk_clip(p, scales, head_dim): + if isinstance(p, torch.nn.Parameter): + W = p.data.view(-1, head_dim, p.data.shape[1]) + W.mul_(scales.view(-1, 1, 1)) + else: + W = p.view(-1, head_dim, p.shape[1]) + W.mul_(scales.view(-1, 1, 1)) + + def parallel(self, names, params, group, lr, 
weight_decay, momentum, + qk_logits): + """ + Perform a parallel optimization step using Muon. + """ + + for p in params: + g = p.grad + if g is None: + continue + if g.ndim > 2: + g = g.view(g.size(0), -1) + + # Update g in the local rank + g = self._update_g( + p, + g, + group, + momentum=momentum, + ) + p.grad = g + + param_to_state, ordered_params = self.init_state_and_assign_params( + names, params, group, qk_logits) + + assert self.rank is not None + + def enqueue_all2all_gather(start_idx, chunk_size): + target_params = ordered_params[start_idx:start_idx + chunk_size] + if target_params: + alloc_event = _alloc_gathered_grad(target_params, + param_to_state, self.rank, + self.compute_stream) + _all2all_gather(target_params, param_to_state, self.rank, + self.comm_stream, group["none_grad"], + alloc_event) + + def enqueue_computes(start_idx, chunk_size): + for p in ordered_params[start_idx:start_idx + chunk_size]: + state = param_to_state[id(p)] + _compute_u(p, state, group["ns_steps"], self.rank, + self.compute_stream) + + def enqueue_all2all_scatter(start_idx, chunk_size): + target_params = ordered_params[start_idx:start_idx + chunk_size] + if target_params: + alloc_event = _alloc_scattered_u(target_params, param_to_state, + self.rank, + self.compute_stream) + _all2all_scatter(target_params, param_to_state, self.rank, + self.comm_stream, alloc_event) + + def enqueue_update_param(start_idx, chunk_size): + for p in ordered_params[start_idx:start_idx + chunk_size]: + state = param_to_state[id(p)] + adjusted_lr = self.adjust_lr_for_muon(lr, p.shape) + _update_param(p, state, lr, adjusted_lr, weight_decay, + self.rank, self.compute_stream) + + if self.chunk_size == -1: + shard_ranks = dist.get_world_size(param_to_state[id( + params[0])].process_group) + chunk_size = shard_ranks * DEFAULT_CHUNK_SIZE_RATIO + elif self.chunk_size > 0: + chunk_size = self.chunk_size + else: + raise ValueError("chunk_size must be -1 or a positive integer.") + + # Wait grad update + 
self.comm_stream.wait_stream(torch.cuda.current_stream()) + + warmup_step = self.warmup_step + for i in range(0, warmup_step): + enqueue_all2all_gather(i * chunk_size, chunk_size) + enqueue_computes(i * chunk_size, chunk_size) + + for i in range(0, len(params) + chunk_size - 1, chunk_size): + enqueue_all2all_scatter(i, chunk_size) + enqueue_all2all_gather(i + warmup_step * chunk_size, chunk_size) + enqueue_update_param(i, chunk_size) + enqueue_computes(i + warmup_step * chunk_size, chunk_size) + + # Wait the last update_param to finish + torch.cuda.current_stream().wait_stream(self.compute_stream) + + @staticmethod + def _fused_adamw( + params: list[torch.Tensor], + grads: list[torch.Tensor], + exp_avgs: list[torch.Tensor], + exp_avg_sqs: list[torch.Tensor], + max_exp_avg_sqs: list[torch.Tensor], + state_steps: list[torch.Tensor], + amsgrad: bool, + beta1: float, + beta2: float, + lr: float | torch.Tensor, + weight_decay: float, + eps: float, + maximize: bool, + ) -> None: + if not params: + return + + # We only shuffle around the lr when it is a Tensor and on CUDA, otherwise, we prefer + # treating it as a scalar. 
+ lr_dict: DeviceDict | None = ({ + lr.device: lr + } if isinstance(lr, torch.Tensor) and str(lr.device) != "cpu" else + None) + grouped_tensors = torch.optim.Optimizer._group_tensors_by_device_and_dtype( + [ + params, grads, exp_avgs, exp_avg_sqs, max_exp_avg_sqs, + state_steps + ] # type: ignore[list-item] + ) + for (device, _), ( + ( + device_params_, + device_grads_, + device_exp_avgs_, + device_exp_avg_sqs_, + device_max_exp_avg_sqs, + device_state_steps_, + ), + _, + ) in grouped_tensors.items(): + device_params = cast(list[torch.Tensor], device_params_) + device_grads = cast(list[torch.Tensor], device_grads_) + device_exp_avgs = cast(list[torch.Tensor], device_exp_avgs_) + device_exp_avg_sqs = cast(list[torch.Tensor], device_exp_avg_sqs_) + device_state_steps = cast(list[torch.Tensor], device_state_steps_) + + if lr_dict is not None and device not in lr_dict: + lr_dict[device] = lr.to( + device=device, + non_blocking=True) # type: ignore[union-attr] + lr = lr_dict[device] + torch._foreach_add_(device_state_steps, 1) + func = torch._fused_adamw_ + func( + device_params, + device_grads, + device_exp_avgs, + device_exp_avg_sqs, + device_max_exp_avg_sqs, # type: ignore[arg-type] + device_state_steps, + amsgrad=amsgrad, + lr=lr, # type: ignore[arg-type] + beta1=beta1, + beta2=beta2, + weight_decay=weight_decay, + eps=eps, + maximize=maximize, + ) + + def _step_muon(self, group, qk_logits=None): + params = group["params"] + lr = group["lr"] + weight_decay = group["weight_decay"] + momentum = group["momentum"] + names = group["names"] + + param_dtensors = [] + name_dtensors = [] + + param_tensors = [] + name_tensors = [] + + param_dtensors_small = [] + name_dtensors_small = [] + + if self.use_distributed_muon: + self.distributed_muon(names=names, + params=params, + group=group, + lr=lr, + weight_decay=weight_decay, + momentum=momentum, + qk_logits=qk_logits) + return + + # For simplicity, we use distributed Muon for small parameters + # whose number of elements is 
below a threshold. + for n, p in zip(names, params): + if p is None or p.grad is None: + continue + if isinstance(p.data, DTensor): + if all( + isinstance(placement, Replicate) + for placement in p.placements): + param_tensors.append(p) + name_tensors.append(n) + elif p.data.numel() <= self.small_param_numel_threshold: + param_dtensors_small.append(p) + name_dtensors_small.append(n) + else: + param_dtensors.append(p) + name_dtensors.append(n) + elif isinstance(p.data, torch.Tensor): + param_tensors.append(p) + name_tensors.append(n) + else: + raise TypeError(f"Unsupported parameter type: {type(p.data)}") + + logger.debug( + f"[Muon] {len(param_dtensors)} DTensors, {len(param_tensors)} Tensors, " + f"{len(param_dtensors_small)} Small DTensors") + + def group_dtensors(dtensors, names): + # To support different placements, we group parameters by placements + # and run parallel Muon on each group. + + placement_to_params = defaultdict(lambda: ([], [])) + # type: dict[tuple[Placement, DeviceMesh], tuple[list[str], list[DTensor]]] + + assert len(dtensors) == len(names) + for p, n in zip(dtensors, names): + placement_to_params[tuple([p.placements, + p.device_mesh])][0].append(n) + placement_to_params[tuple([p.placements, + p.device_mesh])][1].append(p) + return placement_to_params + + if len(param_dtensors_small) > 0: + if not dist.is_initialized(): + raise RuntimeError( + "Parallel Muon requires torch.distributed to be initialized." + ) + + self.distributed_muon( + params=param_dtensors_small, + names=name_dtensors_small, + group=group, + lr=lr, + weight_decay=weight_decay, + momentum=momentum, + qk_logits=qk_logits, + ) + + if len(param_dtensors) > 0: + if not dist.is_initialized(): + raise RuntimeError( + "Parallel Muon requires torch.distributed to be initialized." 
+ ) + + dtensor_group = group_dtensors(param_dtensors, name_dtensors) + for _, (names, params) in dtensor_group.items(): + self.parallel( + names, + params, + group, + lr=lr, + weight_decay=weight_decay, + momentum=momentum, + qk_logits=qk_logits, + ) + + if len(param_tensors) > 0: + self.base( + name_tensors, + param_tensors, + group, + lr=lr, + weight_decay=weight_decay, + momentum=momentum, + qk_logits=qk_logits, + ) + + def _step_adamw_params(self, params, group): + params_with_grads = [] + grads = [] + moment1 = [] + moment2 = [] + max_exp_avg_sqs = [] + state_steps = [] + lr = group["lr"] + beta1, beta2 = group["adamw_betas"] + eps = group["adamw_eps"] + weight_decay = group["weight_decay"] + + for p in params: + g = p.grad + if g is None: + continue + state = self.state[p] + params_with_grads.append(p) + grads.append(g) + if "step" not in state: + state["step"] = (torch.zeros((), + dtype=torch.float32, + device=p.device)) + state["moment1"] = torch.zeros_like(g) + state["moment2"] = torch.zeros_like(g) + moment1.append(state["moment1"]) + moment2.append(state["moment2"]) + if not isinstance(state["step"], torch.Tensor): + step_tensor = torch.tensor(state["step"], + dtype=torch.float32, + device=p.device) + else: + step_tensor = state["step"] + state_steps.append(step_tensor) + + self._fused_adamw( + params_with_grads, + grads, + moment1, + moment2, + max_exp_avg_sqs, + state_steps, + amsgrad=False, + beta1=beta1, + beta2=beta2, + lr=lr, + weight_decay=weight_decay, + eps=eps, + maximize=False, + ) + + def _step_adamw(self, group): + params = group["params"] + + # group params with it's type and placement + placement_to_params: dict[tuple[Placement | type, + DeviceMesh | None]] = defaultdict(list) + for p in params: + match p: + case DTensor(): + placement_to_params[tuple([p.placements, + p.device_mesh])].append(p) + case torch.Tensor(): + placement_to_params[tuple([torch.Tensor, None])].append(p) + + for params in placement_to_params.values(): + 
self._step_adamw_params(params, group) + + @torch.no_grad + def step(self, closure=None, qk_logits=None): + """Perform a single optimization step. + + Args: + closure (Callable, optional): A closure that reevaluates the model + and returns the loss. + qk_logits (dict[int, Tensor], optional): A dictionary mapping layer indices + to 1D tensors of shape (num_heads,), representing the maximum + QK logits across all tokens, computed as + (1 / sqrt(head_dim)) * (Q @ K^T). + """ + loss = None + if closure is not None: + with torch.enable_grad(): + loss = closure() + + for group in self.param_groups: + if group["use_muon"]: + self._step_muon(group, qk_logits=qk_logits) + else: + self._step_adamw(group) + + return loss diff --git a/build/torch28-cxx11-rocm63-x86_64-linux/optimizer/__init__.py b/build/torch28-cxx11-rocm63-x86_64-linux/optimizer/__init__.py index 239c7a65f8293e7d0df28f05fce645af56d628c0..03dbc1afe1cf156661a2b1b22003cd5f599a0309 100644 --- a/build/torch28-cxx11-rocm63-x86_64-linux/optimizer/__init__.py +++ b/build/torch28-cxx11-rocm63-x86_64-linux/optimizer/__init__.py @@ -1,5 +1,26 @@ -from .muon import Muon +import ctypes +import sys -__all__ = [ - "Muon", -] +import importlib +from pathlib import Path +from types import ModuleType + +def _import_from_path(file_path: Path) -> ModuleType: + # We cannot use the module name as-is, after adding it to `sys.modules`, + # it would also be used for other imports. So, we make a module name that + # depends on the path for it to be unique using the hex-encoded hash of + # the path. 
+ path_hash = "{:x}".format(ctypes.c_size_t(hash(file_path.absolute())).value) + module_name = path_hash + spec = importlib.util.spec_from_file_location(module_name, file_path) + if spec is None: + raise ImportError(f"Cannot load spec for {module_name} from {file_path}") + module = importlib.util.module_from_spec(spec) + if module is None: + raise ImportError(f"Cannot load module {module_name} from spec") + sys.modules[module_name] = module + spec.loader.exec_module(module) # type: ignore + return module + + +globals().update(vars(_import_from_path(Path(__file__).parent.parent / "__init__.py"))) diff --git a/build/torch28-cxx11-rocm63-x86_64-linux/optimizer/_ops.py b/build/torch28-cxx11-rocm63-x86_64-linux/optimizer/_ops.py deleted file mode 100644 index 7d598206add1bca142661a3df6c510e3d9575d54..0000000000000000000000000000000000000000 --- a/build/torch28-cxx11-rocm63-x86_64-linux/optimizer/_ops.py +++ /dev/null @@ -1,9 +0,0 @@ -import torch -from . import _optimizer_23d68bb_dirty -ops = torch.ops._optimizer_23d68bb_dirty - -def add_op_namespace_prefix(op_name: str): - """ - Prefix op by namespace. 
- """ - return f"_optimizer_23d68bb_dirty::{op_name}" \ No newline at end of file diff --git a/build/torch28-cxx11-rocm63-x86_64-linux/optimizer/_optimizer_23d68bb_dirty.abi3.so b/build/torch28-cxx11-rocm63-x86_64-linux/optimizer/_optimizer_23d68bb_dirty.abi3.so deleted file mode 100755 index 965a07d2753d33cb8afcabbeb81d4c2f28517ce2..0000000000000000000000000000000000000000 --- a/build/torch28-cxx11-rocm63-x86_64-linux/optimizer/_optimizer_23d68bb_dirty.abi3.so +++ /dev/null @@ -1,3 +0,0 @@ -version https://git-lfs.github.com/spec/v1 -oid sha256:8a2999010ee158e13e3ef247e877dfab073b5bde7babefe2b2b5273b760c7ddf -size 1852152 diff --git a/build/torch28-cxx11-rocm64-x86_64-linux/__init__.py b/build/torch28-cxx11-rocm64-x86_64-linux/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..239c7a65f8293e7d0df28f05fce645af56d628c0 --- /dev/null +++ b/build/torch28-cxx11-rocm64-x86_64-linux/__init__.py @@ -0,0 +1,5 @@ +from .muon import Muon + +__all__ = [ + "Muon", +] diff --git a/build/torch28-cxx11-rocm64-x86_64-linux/_ops.py b/build/torch28-cxx11-rocm64-x86_64-linux/_ops.py new file mode 100644 index 0000000000000000000000000000000000000000..e6f6fcf6280e969b1761926112147d3146e27b59 --- /dev/null +++ b/build/torch28-cxx11-rocm64-x86_64-linux/_ops.py @@ -0,0 +1,9 @@ +import torch +from . import _optimizer_06a260a_dirty +ops = torch.ops._optimizer_06a260a_dirty + +def add_op_namespace_prefix(op_name: str): + """ + Prefix op by namespace. 
+ """ + return f"_optimizer_06a260a_dirty::{op_name}" \ No newline at end of file diff --git a/build/torch28-cxx11-rocm64-x86_64-linux/_optimizer_06a260a_dirty.abi3.so b/build/torch28-cxx11-rocm64-x86_64-linux/_optimizer_06a260a_dirty.abi3.so new file mode 100755 index 0000000000000000000000000000000000000000..d23cf944ec31a3606755cdac0f39bae6455816d5 --- /dev/null +++ b/build/torch28-cxx11-rocm64-x86_64-linux/_optimizer_06a260a_dirty.abi3.so @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:5ddeadf7e678e0ff7e84b9e4f869ef45ed6840b06e9093e20210769fd15b8cad +size 1865168 diff --git a/build/torch28-cxx11-rocm64-x86_64-linux/distributed/utils.py b/build/torch28-cxx11-rocm64-x86_64-linux/distributed/utils.py new file mode 100644 index 0000000000000000000000000000000000000000..6d5843506c13d9d31603b2b4e30c1c91d0baab28 --- /dev/null +++ b/build/torch28-cxx11-rocm64-x86_64-linux/distributed/utils.py @@ -0,0 +1,175 @@ +import torch +import torch.distributed as dist +from torch.distributed import ProcessGroup +from torch.distributed.device_mesh import DeviceMesh +from torch.distributed.tensor import DTensor +from torch.distributed.tensor.placement_types import (Placement, Shard, + _StridedShard) + + +def get_slices_of_dtensor( + target: DTensor | torch.Tensor, + local_rank: int, + shard_mesh: DeviceMesh, + shard_placements: tuple[Placement], +) -> tuple[slice]: + """ + Get the slice of local tensor for a given rank from a tensor. + Args: + target (DTensor | torch.Tensor): The target tensor. + rank (int): The local rank of the shard group. + shard_mesh (DeviceMesh): The shard mesh. It consists of global ranks. + shard_placements (tuple[Placement]): The shard placements. 
+ """ + + slices: list[slice] = [slice(0, dim_size) for dim_size in target.size()] + + # find the global rank of the local rank in the shard mesh + rank = sorted(shard_mesh.mesh.flatten().tolist())[local_rank] + + rank_coords = (shard_mesh.mesh == rank).nonzero() + + assert len(rank_coords) == 1 + rank_coords = tuple(rank_coords[0].tolist()) + + assert len(rank_coords) == len(shard_placements) + + # Caution: Assuming replicate-to-shard of the shard mesh goes with + # left-to-right sharding. This is ensured by the sorting logic of + # construct_shard_mesh function. + for i, (rank_coord, + placement) in enumerate(zip(rank_coords, shard_placements)): + assert isinstance(placement, Shard) + + num_ranks = shard_mesh.mesh.shape[i] + + dim = placement.dim + dim_size = (slices[dim].stop - slices[dim].start) + + if dim_size % num_ranks != 0: + raise NotImplementedError( + f"Dimension size {dim_size} is not divisible " + f"by number of ranks {num_ranks} for shard " + f"placement on dim {dim}. (shape: {target.shape})") + + shard_size = dim_size // num_ranks + + start = slices[dim].start + rank_coord * shard_size + end = start + shard_size + + assert start < end <= slices[dim].stop + + slices[dim] = slice(start, end) + + return tuple(slices) + + +_ranks_to_dist_cache: dict[tuple[int, ...], tuple[DeviceMesh, + ProcessGroup]] = dict() + + +def construct_shard_mesh( + placements: tuple[Placement], + mesh: DeviceMesh, +) -> (DeviceMesh, ProcessGroup, tuple[Placement]): + """ + Construct Shard Mesh and Placements for unsharding. + It removes Replicate placements and constructs a new Mesh and ProcessGroup. + """ + my_rank = dist.get_rank() + + assert mesh.mesh.device.type == 'cpu' + + # Copy mesh to avoid modifying the original mesh + mesh = mesh.mesh.clone() + + # 1. Sort placements. Replicate first, then Shard by dim ascending. + + # For Shard, strided shard comes after regular shard on the same dim + # to preserve left-to-right order of replicate-to-shard. 
+ # This is because that strided shard is using stride to represent + # more fine-grained sharding on the same dim. + # Please check the URL below for _StridedShard. + # https://github.com/pytorch/pytorch/blob/v2.8.0/torch/distributed/tensor/placement_types.py#L366 + + def placement_sort_key( + placement_with_index: tuple[float, Placement] + ) -> tuple[int, float, int]: # (dim, split factor, original index) + index, placement = placement_with_index + is_replicate = placement.is_replicate() + is_shard = placement.is_shard() + is_partial = placement.is_partial() + + assert is_replicate or is_shard, f"Unsupported placement type: {type(placement)}" + assert not is_partial, "Partial placement is not supported." + + if is_replicate: + return (-1.0, 0, index) + elif is_shard: + if isinstance(placement, _StridedShard): + return (placement.dim, 1 / placement.split_factor, index) + return (placement.dim, 0, index) + else: + raise TypeError(f"Unknown placement type: {type(placement)}") + + placements_with_index: list[tuple[int, + Placement]] = list(enumerate(placements)) + placements_with_index = sorted(placements_with_index, + key=placement_sort_key) + + sorted_indices, sorted_placements = zip(*placements_with_index) + + # 2. Permute mesh according to sorted placements. + sorted_mesh = mesh.permute(sorted_indices) + + # 3. 
Collect list of shard meshes by removing replicate dims + # For example, (2, 3, 4, 4) with placements [R, R, S(0), S(1)] + # shard_meshes should be list with 2 * 3 = 6 shard meshes of shape (4, 4) + num_replicates = sum(1 for p in sorted_placements if p.is_replicate()) + + # merge replicate dims + # shard_meshes became a list of shard meshes with a length of replicate degree + if num_replicates > 0: + sorted_mesh = sorted_mesh.flatten( + 0, num_replicates - 1) if num_replicates > 1 else sorted_mesh + shard_meshes = list(torch.unbind(sorted_mesh, dim=0)) + else: + shard_meshes = [sorted_mesh] + shard_placements = sorted_placements[num_replicates:] + + # assume all shard placements are different + assert len(shard_placements) == len(set(shard_placements)) + + # 4. Construct ProcessGroups + # Caution: all groups should be created in the same order in all processes, + # even though each process only needs its own group. + + # To use tensor as dict key, convert it to tuple + def tensor_to_tuple(t): + if isinstance(t, torch.Tensor): + t = t.tolist() + if isinstance(t, list): + return tuple(tensor_to_tuple(x) for x in t) + return t + + my_shard_mesh_as_tuple = None + for shard_mesh in shard_meshes: + assert isinstance(shard_mesh, torch.Tensor) + shard_mesh_as_tuple = tensor_to_tuple(shard_mesh) + + if (my_rank == shard_mesh).any().item(): + assert my_shard_mesh_as_tuple is None + my_shard_mesh_as_tuple = shard_mesh_as_tuple + + # update global cache + if shard_mesh_as_tuple not in _ranks_to_dist_cache: + shard_process_group = dist.new_group(shard_mesh.flatten().tolist()) + _ranks_to_dist_cache[shard_mesh_as_tuple] = ( + DeviceMesh(device_type="cuda", mesh=shard_mesh), + shard_process_group, + ) + + my_shard_mesh, my_shard_process_group = _ranks_to_dist_cache[ + my_shard_mesh_as_tuple] + + return my_shard_mesh, my_shard_process_group, shard_placements diff --git a/build/torch29-cxx11-rocm64-x86_64-linux/optimizer/matmul_transpose_triton.py 
b/build/torch28-cxx11-rocm64-x86_64-linux/matmul_transpose_triton.py similarity index 100% rename from build/torch29-cxx11-rocm64-x86_64-linux/optimizer/matmul_transpose_triton.py rename to build/torch28-cxx11-rocm64-x86_64-linux/matmul_transpose_triton.py diff --git a/build/torch28-cxx11-rocm64-x86_64-linux/metadata.json b/build/torch28-cxx11-rocm64-x86_64-linux/metadata.json new file mode 100644 index 0000000000000000000000000000000000000000..76bafa5f33b6818aa6bb4cab04be811b87519b44 --- /dev/null +++ b/build/torch28-cxx11-rocm64-x86_64-linux/metadata.json @@ -0,0 +1 @@ +{"python-depends":[]} \ No newline at end of file diff --git a/build/torch28-cxx11-rocm64-x86_64-linux/muon.py b/build/torch28-cxx11-rocm64-x86_64-linux/muon.py new file mode 100644 index 0000000000000000000000000000000000000000..dbf25575f185ff379789482068e4ecf55b9455a9 --- /dev/null +++ b/build/torch28-cxx11-rocm64-x86_64-linux/muon.py @@ -0,0 +1,1268 @@ +import logging +import math +import types +from collections import defaultdict +from dataclasses import dataclass +from typing import Any, cast + +import torch +import torch.distributed as dist +from torch.distributed import ProcessGroup +from torch.distributed.device_mesh import DeviceMesh +from torch.distributed.tensor import DTensor, Replicate +from torch.distributed.tensor.placement_types import Placement + +from .distributed.utils import construct_shard_mesh, get_slices_of_dtensor +from .matmul_transpose_triton import matmul_transpose_assign + +logger = logging.getLogger(__name__) + +COMM_DTYPE = torch.bfloat16 +DEFAULT_CHUNK_SIZE_RATIO = 4 + + +# This code snippet is a modified version adapted from the following GitHub repositories: +# https://github.com/KellerJordan/Muon/blob/master/muon.py +# Muon's Newton–Schulz iteration causes high variance in singular values +# Idea: give each iteration its own 3 coefficients and optimize them via gradient descent. 
+@torch.no_grad() +# matmul_transpose_assign from : https://github.com/nil0x9/flash-muon +def _zeropower_via_newtonschulz5(G, steps): + """ + Newton-Schulz iteration to compute the zeroth power / orthogonalization of G. We opt to use a + quintic iteration whose coefficients are selected to maximize the slope at zero. For the purpose + of minimizing steps, it turns out to be empirically effective to keep increasing the slope at + zero even beyond the point where the iteration no longer converges all the way to one everywhere + on the interval. This iteration therefore does not produce UV^T but rather something like US'V^T + where S' is diagonal with S_{ii}' ~ Uniform(0.5, 1.5), which turns out not to hurt model + performance at all relative to UV^T, where USV^T = G is the SVD. + """ + assert len(G.shape) == 2 + assert G.dtype == COMM_DTYPE + X = G # no manual typecast + + if G.size(0) > G.size(1): + X = X.T + # Ensure spectral norm is at most 1 + X = X / (X.norm() + 1e-7) + buf1 = torch.empty(X.size(0), X.size(0), dtype=X.dtype, device=X.device) + buf2 = torch.empty(X.size(0), X.size(0), dtype=X.dtype, device=X.device) + # Perform the NS iterations + for a, b, c in [ + (4.0848, -6.8946, 2.9270), + (3.9505, -6.3029, 2.6377), + (3.7418, -5.5913, 2.3037), + (2.8769, -3.1427, 1.2046), + (2.8366, -3.0525, 1.2012), + ]: + matmul_transpose_assign(X, buf1) + matmul_transpose_assign(buf1, buf2) + buf1.mul_(b).add_(buf2, alpha=c) + X = torch.addmm(X, buf1, X, alpha=1.0, beta=a) + + if G.size(0) > G.size(1): + X = X.T + return X + + +@dataclass +class _muon_state: + # TODO: use Optional + worker_rank: int + process_group: ProcessGroup + shard_mesh: DeviceMesh + shard_placements: tuple[Placement, ...] 
+ name: str + qk_clip_state: torch.Tensor | None = None + gathered_grad: torch.Tensor | None = None + scattered_u: DTensor | None = None + computed_u: torch.Tensor | None = None + gather_event: torch.cuda.Event | None = None + compute_event: torch.cuda.Event | None = None + scatter_event: torch.cuda.Event | None = None + + +def numel_for_rank( + param: DTensor, + local_rank: int, + state: _muon_state, +) -> int: + slices = get_slices_of_dtensor( + param, + local_rank, + state.shard_mesh, + state.shard_placements, + ) + + numel = 1 + for s, dim in zip(slices, param.shape): + start, stop, step = s.indices(dim) + length = max(0, (stop - start + (step - 1)) // step) + numel *= length + + return numel + + +@torch.no_grad() +def _alloc_gathered_grad(params, param_to_state, rank, compute_stream): + """ + Pre-allocate gathered_grad buffer on compute_stream + before launching all2all gather + """ + with torch.cuda.stream(compute_stream): + for p in params: + state = param_to_state[id(p)] + if rank == state.worker_rank: + state.gathered_grad = torch.empty(p.shape, + dtype=COMM_DTYPE, + device="cuda") + else: + state.gathered_grad = None + + alloc_event = torch.cuda.Event() + alloc_event.record(compute_stream) + return alloc_event + + +@torch.no_grad() +def _all2all_gather(params, param_to_state, rank, comm_stream, none_grad, + alloc_event): + """ + All2all gathers shards so each owner rank reconstructs its full gradient + """ + with torch.cuda.stream(comm_stream): + process_group = param_to_state[id(params[0])].process_group + num_ranks = dist.get_world_size(group=process_group) + + # Construct sending buffers + per_dst = [[] for _ in range(num_ranks)] + send_counts = [0] * num_ranks + + for p in params: + state = param_to_state[id(p)] + dst = state.worker_rank + assert dst < num_ranks + shard_elems = numel_for_rank(p, rank, state) + g = p.grad + g = g.to_local().to(COMM_DTYPE).contiguous() + assert g.numel() == shard_elems + per_dst[dst].append(g.view(-1)) + 
send_counts[dst] += shard_elems + + assert any( + len(v) > 0 for v in per_dst + ), "At least one destination rank must receive a sharded tensor" + # list[list[Tensor]] -> list[Tensor] + per_dst = [t for dst in per_dst for t in dst] + + send_buf = torch.cat(per_dst, dim=0) + + owned_params = [ + p for p in params if param_to_state[id(p)].worker_rank == rank + ] + + # Compute receive sizes and allocate receiving buffers + recv_counts = [0] * num_ranks + + for src in range(num_ranks): + total = 0 + for p in owned_params: + state = param_to_state[id(p)] + assert state.worker_rank == rank + total += numel_for_rank(p, src, state) + recv_counts[src] = total + + recv_total = sum(recv_counts) + recv_buf = torch.empty(recv_total, dtype=COMM_DTYPE, device="cuda") + + #All2All + logger.debug(f"send_buf size: {send_buf.numel()}, " + f"recv_buf size: {recv_buf.numel()}, " + f"recv_counts: {recv_counts}, " + f"send_counts: {send_counts}, " + f"process_group: {str(process_group)}") + dist.all_to_all_single( + recv_buf, + send_buf, + output_split_sizes=recv_counts, + input_split_sizes=send_counts, + group=process_group, + ) + + # Reconstructs gathered grad from the received buffer + # + # recv_buf (num ranks = 3) + # + # From rank 0 From rank 1 From rank 2 + # | p1_0, p2_0, p3_0 | p1_1, p2_1, p3_1 | p1_2, p2_2, p3_2 | + # + # Outer loop: + # rank 0 -> rank 1 -> rank2 + # + # Inner loop: + # p1_n -> p2_n -> p3_n + + comm_stream.wait_event(alloc_event) + + off = 0 + for src in range(num_ranks): + if recv_counts[src] == 0: + continue + + block = recv_counts[src] + inner_off = 0 + for p in owned_params: + state = param_to_state[id(p)] + assert state.worker_rank == rank + + # get the slice of the full dtensor corresponding to rank src. 
+ slices = get_slices_of_dtensor(state.gathered_grad, src, + state.shard_mesh, + state.shard_placements) + + dst = state.gathered_grad[slices] + assert dst._base is state.gathered_grad + + n = dst.numel() + assert n > 0 + + sg = recv_buf.narrow(0, off + inner_off, n) + sg = sg.reshape_as(dst) + dst.copy_(sg) + + inner_off += n + off += block + + for p in params: + state = param_to_state[id(p)] + if state.worker_rank == rank: + state.gather_event = torch.cuda.Event() + state.gather_event.record(comm_stream) + else: + state.gathered_grad = None + state.gather_event = None + if none_grad: + p.grad = None + + +@torch.no_grad() +def _compute_u(p, state, steps, rank, compute_stream): + """ + On worker_rank, compute the orthogonalized update using Newton-Schulz iteration. + """ + with torch.cuda.stream(compute_stream): + if rank == state.worker_rank: + if state.gather_event is None: + raise RuntimeError("Gather event must be set before compute.") + compute_stream.wait_event(state.gather_event) + u = _zeropower_via_newtonschulz5(state.gathered_grad, steps) + state.gathered_grad = None + state.computed_u = u + state.compute_event = torch.cuda.Event() + state.compute_event.record() + else: + state.computed_u = None + state.compute_event = None + + +@torch.no_grad() +def _alloc_scattered_u(params, param_to_state, rank, compute_stream): + """ + Pre-allocate scattered_u buffer on compute_stream + before launching all2all gather + """ + with torch.cuda.stream(compute_stream): + for p in params: + state = param_to_state[id(p)] + state.scattered_u = torch.empty_like(p.to_local(), + dtype=COMM_DTYPE) + + alloc_event = torch.cuda.Event() + alloc_event.record(compute_stream) + return alloc_event + + +def _all2all_scatter(params, param_to_state, rank, comm_stream, alloc_event): + """ + All2all scatters full gradients to all ranks + """ + with torch.cuda.stream(comm_stream): + process_group = param_to_state[id(params[0])].process_group + num_ranks = 
dist.get_world_size(group=process_group) + owned_params = [ + p for p in params if param_to_state[id(p)].worker_rank == rank + ] + + # Construct sending buffer + per_dst = [[] for _ in range(num_ranks)] + send_counts = [0] * num_ranks + + if owned_params: + for p in owned_params: + state = param_to_state[id(p)] + if state.compute_event is None: + raise RuntimeError( + "Compute event must be set before scatter.") + comm_stream.wait_event(state.compute_event) + state.gathered_grad = None + + assert state.computed_u is not None + + u_full = state.computed_u.to(COMM_DTYPE).contiguous() + + offset = 0 + for dst in range(num_ranks): + # get the slice of the full tensor corresponding to rank dst. + slices = get_slices_of_dtensor(u_full, dst, + state.shard_mesh, + state.shard_placements) + su = u_full[slices].flatten() + + n = su.numel() + assert n > 0 + + per_dst[dst].append(su) + send_counts[dst] += n + offset += n + + assert offset == u_full.numel() + + lengths = [len(v) for v in per_dst] + if all(l > 0 for l in lengths): + assert all( + l == lengths[0] for l in lengths + ), "All destination ranks must have the same number of sharded tensor" + # list[list[Tensor]] -> list[Tensor] + per_dst = [t for dst in per_dst for t in dst] + send_buf = torch.cat(per_dst, dim=0) + else: + # all_to_all requires participation from all ranks + # Even non-owner ranks must join the collective call + send_buf = torch.empty(0, dtype=COMM_DTYPE, device="cuda") + + # Compute receive sizes and allocate receiving buffers + recv_counts = [0] * num_ranks + + for src in range(num_ranks): + total = 0 + for p in params: + state = param_to_state[id(p)] + if state.worker_rank != src: + continue + total += numel_for_rank(p, rank, state) + recv_counts[src] = total + + recv_total = sum(recv_counts) + assert recv_total > 0 + recv_buf = torch.empty(recv_total, dtype=COMM_DTYPE, device="cuda") + + #All2All + dist.all_to_all_single( + recv_buf, + send_buf, + output_split_sizes=recv_counts, + 
input_split_sizes=send_counts, + group=process_group, + ) + + # Copy to pre-allocated scattered_u buffer from the received buffer + # + # recv_buf (num ranks = 3, local_rank = 0) + # + # From rank 0 From rank 1 From rank 2 + # | p1_0, p2_0, p3_0 | p4_0 | p5_0, p6_0 | + # + # Outer loop: + # rank 0 -> rank 1 -> rank2 + # + # Inner loop: + # src(0) : p1_0 -> p2_0 -> p3_0 + # src(1) : p4_0 + # src(2) : p5_0 -> p6_0 + + comm_stream.wait_event(alloc_event) + + off = 0 + for src in range(num_ranks): + block = recv_counts[src] + if block == 0: + continue + + inner_off = 0 + for p in params: + state = param_to_state[id(p)] + if state.worker_rank != src: + continue + n = numel_for_rank(p, rank, state) + assert n > 0 + + flat_local = recv_buf.narrow(0, off + inner_off, + n).view_as(p.to_local()) + state.scattered_u.copy_(flat_local) + + state.scatter_event = torch.cuda.Event() + state.scatter_event.record(comm_stream) + inner_off += n + + assert inner_off == block + off += block + + +def _update_param(p, state, lr, adjusted_lr, weight_decay, rank, + compute_stream): + """ + Update sharded parameter p with the scattered_u. + Only worker_rank frees computed_u. 
+ """ + with torch.cuda.stream(compute_stream): + if state.scatter_event is None: + raise RuntimeError("Scatter event must be set before update") + compute_stream.wait_event(state.scatter_event) + u_dtensor = DTensor.from_local( + state.scattered_u, + placements=p.placements, + device_mesh=p.device_mesh, + ) + + state.scattered_u = u_dtensor + + if rank == state.worker_rank: + # Free computed_u + state.computed_u = None + + Muon._update_p(p, state.scattered_u, lr, adjusted_lr, weight_decay) + state.scattered_u = None + u_dtensor = None + + scales_full = Muon._compute_scales( + p, + state.qk_clip_state) if state.qk_clip_state is not None else None + if scales_full is not None: + # Have to slice scales_full among dim 0 + weight_slices = get_slices_of_dtensor(p, rank, state.shard_mesh, + state.shard_placements) + ratio = p.shape[0] // scales_full.shape[0] + scales_slice = slice( + None if weight_slices[0].start is None else + weight_slices[0].start // ratio, + None if weight_slices[0].stop is None else + weight_slices[0].stop // ratio, + None, + ) + + scales_local = scales_full[scales_slice] + scales_local = DTensor.from_local( + scales_local, + placements=p.placements, + device_mesh=p.device_mesh, + ) + Muon._qk_clip(p, scales_local, state.qk_clip_state.head_dim) + + +def default_is_muon(name, x): + skip_keys = ["embed_tokens", "lm_head", "tok_embeddings", "output"] + return x.ndim >= 2 and not any(key in name for key in skip_keys) + + +def get_default_muon_param_groups(model, is_muon_func=default_is_muon): + muon_params, muon_names = [], [] + non_muon_params = [] + + for n, p in model.named_parameters(): + if not p.requires_grad: + continue + if is_muon_func(n, p): + muon_params.append(p) + muon_names.append(n) + else: + non_muon_params.append(p) + + return [ + { + "params": muon_params, + "names": muon_names, + "use_muon": True, + }, + { + "params": non_muon_params, + "use_muon": False, + }, + ] + + +def parse_qk_layer(name: str) -> tuple[str | None, int]: + """ + 
Parse a parameter name to check if it is a query/key projection layer + ('wq', 'wk', 'q_proj', 'k_proj') and return (kind, layer_index). + + Returns: + (kind, layer_idx) or (None, -1) if not matched. + + Example: + 'model.3.attn.wq.weight' -> ('wq', 3) + 'model.5.attn.wk.weight' -> ('wk', 5) + 'model.2.attn.q_proj.weight' -> ('q_proj', 2) + 'model.7.attn.k_proj.weight' -> ('k_proj', 7) + 'model.4.attn.v_proj.weight' -> (None, -1) + """ + parts = name.split('.') + if len(parts) < 3: + return None, -1 + + kind = parts[-2] + + layer_idx = -1 + for part in reversed(parts): + if part.isdigit(): + layer_idx = int(part) + break + + if kind in ('wq', 'wk', 'q_proj', 'k_proj'): + return kind, layer_idx + + return None, -1 + + +@dataclass +class QKClipInfo: + """Per-parameter dynamic info computed from config + runtime logits.""" + kind: str | None # 'wq'/'q_proj' or 'wk'/'k_proj' or None + indices: list[int] # which heads to consider for clipping + head_dim: int # from config + threshold: float # from config + logit: torch.Tensor | None + + +class Muon(torch.optim.Optimizer): + """ + Muon - MomentUm Orthogonalized by Newton-schulz + + Muon internally runs standard SGD-momentum, and then performs an orthogonalization post- + processing step, in which each 2D parameter's update is replaced with the nearest orthogonal + matrix. To efficiently orthogonalize each update, we use a Newton-Schulz iteration, which has + the advantage that it can be stably run in bfloat16 on the GPU. + + Some warnings: + - We believe this optimizer is unlikely to work well for training with small batch size. + - We believe it may not work well for finetuning pretrained models, but we haven't tested this. + + Arguments: + model: The model to be optimized by Muon. + is_muon_func: A function that takes a parameter and its name, and returns whether the parameter should be optimized by Muon. + lr: The learning rate. The updates will have spectral norm of `lr`. 
(0.02 is a good default) + momentum: The momentum used by the internal SGD. (0.95 is a good default) + nesterov: Whether to use Nesterov-style momentum in the internal SGD. (recommended) + ns_steps: The number of Newton-Schulz iterations to run. (6 is probably always enough) + weight_decay: The weight decay for Muon and AdamW. + {0, 1}-D or are detected as being the embed or lm_head will be optimized by AdamW as well. + adamw_lr: The learning rate for the internal AdamW. + adamw_betas: The betas for the internal AdamW. + adamw_eps: The epsilon for the internal AdamW. + none_grad: Whether to set p.grad to None after gathering the gradients. This can save memory. + debug: Whether to print debug information. + clip_info : Configuration for QK clipping. Expected keys: + - "q_indices" (list[int]): Indices of query heads to consider. + - "k_indices" (list[int]): Indices of key heads to consider. + - "head_dim" (int): Dimensionality of each attention head. + - "threshold" (float): Threshold value; heads whose QK logits exceed + this value will be scaled down. + Default is: + { + "q_indices": [], + "k_indices": [], + "head_dim": 128, + "threshold": 100 + } + warmup_step : How many all2all gather, compute operations are launched in advance + before the corresponding all2all scatter steps begin. + A higher warmup_step increases memory usage but can improve + performance by overlapping communication. + Parallel muon only. + chunk_size : Batch size of parameters to process in each + all2all gather/compute/scatter step. + Use shard ranks * DEFAULT_CHUNK_SIZE_RATIO when -1 is specified. + use_distributed_muon: Use distributed muon by Liu et al. (2024). + For testing purpose only. 
+ small_param_numel_threshold: Threshold for classifying parameters as small and falling back to distributed Muon + """ + + def __init__(self, + params, + lr=1e-3, + momentum=0.95, + nesterov=True, + ns_steps=5, + weight_decay=0.1, + adamw_betas=(0.9, 0.95), + adamw_eps=1e-8, + none_grad=True, + debug=False, + clip_config={ + "q_indices": [], + "k_indices": [], + "head_dim": 128, + "threshold": 100 + }, + warmup_step=5, + chunk_size=-1, + use_distributed_muon=False, + small_param_numel_threshold=65536): + defaults = dict( + lr=lr, + weight_decay=weight_decay, + momentum=momentum, + nesterov=nesterov, + ns_steps=ns_steps, + adamw_betas=adamw_betas, + adamw_eps=adamw_eps, + none_grad=none_grad, + use_muon=True, + ) + error_message = "The key 'use_muon' is not set in parameter group {idx}. Assuming all parameters in the group will use muon optimization, which may lead to unexpected behavior." + instruction_code = "\n\n please follow this code snippet \n```optimizer = get_kernel('motif-technologies/optimizer')\n\n\nparams = optimizer.muon.get_default_muon_param_groups(model)\n\noptim = optimizer.Muon(params, ...)```" + + if isinstance(params, types.GeneratorType): + raise ValueError(error_message.format(idx=0) + instruction_code) + for _idx, param_group in enumerate(params): + if param_group.get("use_muon", None) is None: + raise ValueError( + error_message.format(idx=_idx) + instruction_code) + + super().__init__(params, defaults) + + self.rank = None + + self.comm_stream = torch.cuda.Stream() + self.compute_stream = torch.cuda.Stream() + self.debug = debug + self.clip_config = clip_config + self.warmup_step = warmup_step + self.chunk_size = chunk_size + self.use_distributed_muon = use_distributed_muon + self.small_param_numel_threshold = small_param_numel_threshold + + def _calc_flops(self, G, steps): + assert len(G.shape) == 2 + M, N = G.shape + if M > N: + M, N = N, M + + return steps * ((M**3) * 2 + (M**2 * N) * 4 + M * N * 2 + M**2 * 3) + + def 
adjust_lr_for_muon(self, lr, param_shape): + A, B = param_shape[:2] + # We adjust the learning rate and weight decay based on the size of the parameter matrix + # as describted in the paper + adjusted_ratio = 0.2 * math.sqrt(max(A, B)) + adjusted_lr = lr * adjusted_ratio + return adjusted_lr + + def set_rank_once(self, rank): + if self.rank is None: + self.rank = rank + else: + assert self.rank == rank + + def get_shard_mesh(self, p): + """ + Get the shard mesh for a parameter p on the given rank. + """ + assert isinstance( + p, DTensor), "Parallel Muon only supports DTensor parameters." + + shard_mesh, shard_pg, shard_placements = construct_shard_mesh( + p.placements, p.device_mesh) + + # set rank with the local rank in the shard process group + self.set_rank_once(dist.get_rank(group=shard_pg)) + + return shard_mesh, shard_pg, shard_placements + + def init_state_and_assign_params(self, names, params, group, qk_logits): + param_to_state = {} + param_to_flops = {} + + total_flops = 0 + for p in params: + g = p.grad + if g is None: + continue + assert g.ndim == 2, "Muon only supports 2D parameters." 
+ + flops = self._calc_flops(g, group["ns_steps"]) + param_to_flops[id(p)] = flops + total_flops += flops + + if self.debug: + print(f"Total TFLOPs for Muon: {total_flops / 1e12:.2f} TFLOPs", + flush=True) + + paired = list(zip(names, params)) + + paired_sorted = sorted(paired, + key=lambda x: param_to_flops[id(x[1])], + reverse=True) + + names_sorted, params_sorted = zip(*paired_sorted) + ordered_names = list(names_sorted) + ordered_params = list(params_sorted) + + round_robin = 0 + mesh = ordered_params[0].device_mesh + placements = ordered_params[0].placements + + shard_mesh, shard_pg, shard_placements = self.get_shard_mesh( + ordered_params[0]) + shard_mesh_flattened = shard_mesh.mesh.flatten() + num_ranks = dist.get_world_size(group=shard_pg) + + for n, p in zip(ordered_names, ordered_params): + if mesh != p.device_mesh: + raise ValueError("All parameters must be on the same mesh.") + if placements != p.placements: + raise ValueError("All parameters must have same placements.") + + worker_rank = shard_mesh_flattened[round_robin].item() % num_ranks + round_robin = (round_robin + 1) % len(shard_mesh_flattened) + qk_clip_state = self.get_qk_clip_info(n, qk_logits) + + param_to_state[id(p)] = _muon_state( + worker_rank=worker_rank, + process_group=shard_pg, + shard_mesh=shard_mesh, + shard_placements=shard_placements, + name=n, + qk_clip_state=qk_clip_state, + ) + + return param_to_state, ordered_params + + def base(self, names, params, group, lr, weight_decay, momentum, + qk_logits): + # generate weight updates in distributed fashion + for n, p in zip(names, params): + g = p.grad + if g is None: + continue + if g.ndim > 2: + g = g.view(g.size(0), -1) + assert g is not None + + g = self._update_g(p, g, group, momentum) + + u = _zeropower_via_newtonschulz5(g.to(COMM_DTYPE), + steps=group["ns_steps"]) + + adjusted_lr = self.adjust_lr_for_muon(lr, p.shape) + Muon._update_p(p, u, lr, adjusted_lr, weight_decay) + + qk_clip_state = self.get_qk_clip_info(n, qk_logits) + 
+ scales_full = self._compute_scales( + p, qk_clip_state) if qk_clip_state is not None else None + if scales_full is not None: + Muon._qk_clip(p, scales_full, qk_clip_state.head_dim) + + def distributed_muon( + self, + names: list[str], + params: list[torch.nn.Parameter], + group: dict[str, Any], + lr: float, + weight_decay: float, + momentum: float, + qk_logits: list[torch.Tensor | DTensor] | None, + ): + """ Implementation of Distributed Muon by Liu et al. """ + + for n, p in zip(names, params): + g = p.grad + if g is None: + continue + if g.ndim > 2: + g = g.view(g.size(0), -1) + assert g is not None + + g = self._update_g(p, g, group, momentum) + + # Gather G + if isinstance(p.data, DTensor): + g_full = g.full_tensor() + p_full = p.data.full_tensor() + else: + g_full = g + p_full = p + + u_full = _zeropower_via_newtonschulz5(g_full.to(COMM_DTYPE), + steps=group["ns_steps"]) + + adjusted_lr = self.adjust_lr_for_muon(lr, p_full.shape) + Muon._update_p(p_full, u_full, lr, adjusted_lr, weight_decay) + + qk_clip_state = self.get_qk_clip_info(n, qk_logits) + + scales_full = self._compute_scales( + p_full, qk_clip_state) if qk_clip_state is not None else None + + if scales_full is not None: + Muon._qk_clip(p_full, scales_full, qk_clip_state.head_dim) + + if isinstance(p.data, DTensor): + ndims = len(p.device_mesh.mesh.shape) + p_replicate = DTensor.from_local( + p_full, + device_mesh=p.device_mesh, + placements=[Replicate() for _ in range(ndims)], + ) + + p_sharded = p_replicate.redistribute( + device_mesh=p.device_mesh, + placements=p.placements, + ) + + p.copy_(p_sharded) + + def _update_g(self, p, g, group, momentum): + # calc update + state = self.state[p] + buf = state.setdefault("momentum_buffer", torch.zeros_like(g)) + torch.add(g, buf, alpha=momentum, out=buf) + if group["nesterov"]: + g.add_(buf, alpha=momentum) + return g + return buf + + @staticmethod + def _update_p(p, u, lr, adjusted_lr, weight_decay): + if isinstance(p, torch.nn.Parameter): + # apply 
weight decay + p.data.mul_(1 - lr * weight_decay) + # apply update + p.data.add_(u, alpha=-adjusted_lr) + else: + p.mul_(1 - lr * weight_decay) + p.add_(u, alpha=-adjusted_lr) + + def get_qk_clip_info(self, n, qk_logits): + if self.clip_config is None: + return None + + head_dim = self.clip_config.get('head_dim') + threshold = self.clip_config.get('threshold') + kind, layer_idx = parse_qk_layer(n) + + logit, indices = None, [] + if qk_logits is not None and kind is not None: + logit = qk_logits[layer_idx] + indices_key = 'q_indices' if 'q' in kind else 'k_indices' + indices = self.clip_config.get(indices_key, []) or [] + + if isinstance(logit, DTensor): + # In TP settings, qk_logits may be DTensor + # We convert it to full tensor here for simplicity + logit = logit.full_tensor() + + return QKClipInfo( + kind=kind, + indices=indices, + head_dim=head_dim, + threshold=threshold, + logit=logit, + ) + + @staticmethod + def _compute_scales(p, qk_clip_state): + kind = qk_clip_state.kind + indices = qk_clip_state.indices + head_dim = qk_clip_state.head_dim + threshold = qk_clip_state.threshold + logit = qk_clip_state.logit + + H_global = p.shape[0] // head_dim + scales_full = torch.ones(H_global, device=p.data.device) + scaling = 0 + + for logit_idx, head_idx in enumerate(indices): + v_ele = float(logit[logit_idx]) + if v_ele > threshold: + new_scale = math.sqrt(threshold / v_ele) + if new_scale < scales_full[head_idx]: + scales_full[head_idx] = new_scale + logger.info( + f"[{kind}] Head {head_idx} exceeded threshold " + f"(value={v_ele:.4f}, threshold={threshold:.4f}) -> applying scale={new_scale:.4f}" + ) + scaling += 1 + + return scales_full if scaling > 0 else None + + @staticmethod + def _qk_clip(p, scales, head_dim): + if isinstance(p, torch.nn.Parameter): + W = p.data.view(-1, head_dim, p.data.shape[1]) + W.mul_(scales.view(-1, 1, 1)) + else: + W = p.view(-1, head_dim, p.shape[1]) + W.mul_(scales.view(-1, 1, 1)) + + def parallel(self, names, params, group, lr, 
weight_decay, momentum, + qk_logits): + """ + Perform a parallel optimization step using Muon. + """ + + for p in params: + g = p.grad + if g is None: + continue + if g.ndim > 2: + g = g.view(g.size(0), -1) + + # Update g in the local rank + g = self._update_g( + p, + g, + group, + momentum=momentum, + ) + p.grad = g + + param_to_state, ordered_params = self.init_state_and_assign_params( + names, params, group, qk_logits) + + assert self.rank is not None + + def enqueue_all2all_gather(start_idx, chunk_size): + target_params = ordered_params[start_idx:start_idx + chunk_size] + if target_params: + alloc_event = _alloc_gathered_grad(target_params, + param_to_state, self.rank, + self.compute_stream) + _all2all_gather(target_params, param_to_state, self.rank, + self.comm_stream, group["none_grad"], + alloc_event) + + def enqueue_computes(start_idx, chunk_size): + for p in ordered_params[start_idx:start_idx + chunk_size]: + state = param_to_state[id(p)] + _compute_u(p, state, group["ns_steps"], self.rank, + self.compute_stream) + + def enqueue_all2all_scatter(start_idx, chunk_size): + target_params = ordered_params[start_idx:start_idx + chunk_size] + if target_params: + alloc_event = _alloc_scattered_u(target_params, param_to_state, + self.rank, + self.compute_stream) + _all2all_scatter(target_params, param_to_state, self.rank, + self.comm_stream, alloc_event) + + def enqueue_update_param(start_idx, chunk_size): + for p in ordered_params[start_idx:start_idx + chunk_size]: + state = param_to_state[id(p)] + adjusted_lr = self.adjust_lr_for_muon(lr, p.shape) + _update_param(p, state, lr, adjusted_lr, weight_decay, + self.rank, self.compute_stream) + + if self.chunk_size == -1: + shard_ranks = dist.get_world_size(param_to_state[id( + params[0])].process_group) + chunk_size = shard_ranks * DEFAULT_CHUNK_SIZE_RATIO + elif self.chunk_size > 0: + chunk_size = self.chunk_size + else: + raise ValueError("chunk_size must be -1 or a positive integer.") + + # Wait grad update + 
self.comm_stream.wait_stream(torch.cuda.current_stream()) + + warmup_step = self.warmup_step + for i in range(0, warmup_step): + enqueue_all2all_gather(i * chunk_size, chunk_size) + enqueue_computes(i * chunk_size, chunk_size) + + for i in range(0, len(params) + chunk_size - 1, chunk_size): + enqueue_all2all_scatter(i, chunk_size) + enqueue_all2all_gather(i + warmup_step * chunk_size, chunk_size) + enqueue_update_param(i, chunk_size) + enqueue_computes(i + warmup_step * chunk_size, chunk_size) + + # Wait the last update_param to finish + torch.cuda.current_stream().wait_stream(self.compute_stream) + + @staticmethod + def _fused_adamw( + params: list[torch.Tensor], + grads: list[torch.Tensor], + exp_avgs: list[torch.Tensor], + exp_avg_sqs: list[torch.Tensor], + max_exp_avg_sqs: list[torch.Tensor], + state_steps: list[torch.Tensor], + amsgrad: bool, + beta1: float, + beta2: float, + lr: float | torch.Tensor, + weight_decay: float, + eps: float, + maximize: bool, + ) -> None: + if not params: + return + + # We only shuffle around the lr when it is a Tensor and on CUDA, otherwise, we prefer + # treating it as a scalar. 
+ lr_dict: DeviceDict | None = ({ + lr.device: lr + } if isinstance(lr, torch.Tensor) and str(lr.device) != "cpu" else + None) + grouped_tensors = torch.optim.Optimizer._group_tensors_by_device_and_dtype( + [ + params, grads, exp_avgs, exp_avg_sqs, max_exp_avg_sqs, + state_steps + ] # type: ignore[list-item] + ) + for (device, _), ( + ( + device_params_, + device_grads_, + device_exp_avgs_, + device_exp_avg_sqs_, + device_max_exp_avg_sqs, + device_state_steps_, + ), + _, + ) in grouped_tensors.items(): + device_params = cast(list[torch.Tensor], device_params_) + device_grads = cast(list[torch.Tensor], device_grads_) + device_exp_avgs = cast(list[torch.Tensor], device_exp_avgs_) + device_exp_avg_sqs = cast(list[torch.Tensor], device_exp_avg_sqs_) + device_state_steps = cast(list[torch.Tensor], device_state_steps_) + + if lr_dict is not None and device not in lr_dict: + lr_dict[device] = lr.to( + device=device, + non_blocking=True) # type: ignore[union-attr] + lr = lr_dict[device] + torch._foreach_add_(device_state_steps, 1) + func = torch._fused_adamw_ + func( + device_params, + device_grads, + device_exp_avgs, + device_exp_avg_sqs, + device_max_exp_avg_sqs, # type: ignore[arg-type] + device_state_steps, + amsgrad=amsgrad, + lr=lr, # type: ignore[arg-type] + beta1=beta1, + beta2=beta2, + weight_decay=weight_decay, + eps=eps, + maximize=maximize, + ) + + def _step_muon(self, group, qk_logits=None): + params = group["params"] + lr = group["lr"] + weight_decay = group["weight_decay"] + momentum = group["momentum"] + names = group["names"] + + param_dtensors = [] + name_dtensors = [] + + param_tensors = [] + name_tensors = [] + + param_dtensors_small = [] + name_dtensors_small = [] + + if self.use_distributed_muon: + self.distributed_muon(names=names, + params=params, + group=group, + lr=lr, + weight_decay=weight_decay, + momentum=momentum, + qk_logits=qk_logits) + return + + # For simplicity, we use distributed Muon for small parameters + # whose number of elements is 
below a threshold. + for n, p in zip(names, params): + if p is None or p.grad is None: + continue + if isinstance(p.data, DTensor): + if all( + isinstance(placement, Replicate) + for placement in p.placements): + param_tensors.append(p) + name_tensors.append(n) + elif p.data.numel() <= self.small_param_numel_threshold: + param_dtensors_small.append(p) + name_dtensors_small.append(n) + else: + param_dtensors.append(p) + name_dtensors.append(n) + elif isinstance(p.data, torch.Tensor): + param_tensors.append(p) + name_tensors.append(n) + else: + raise TypeError(f"Unsupported parameter type: {type(p.data)}") + + logger.debug( + f"[Muon] {len(param_dtensors)} DTensors, {len(param_tensors)} Tensors, " + f"{len(param_dtensors_small)} Small DTensors") + + def group_dtensors(dtensors, names): + # To support different placements, we group parameters by placements + # and run parallel Muon on each group. + + placement_to_params = defaultdict(lambda: ([], [])) + # type: dict[tuple[Placement, DeviceMesh], tuple[list[str], list[DTensor]]] + + assert len(dtensors) == len(names) + for p, n in zip(dtensors, names): + placement_to_params[tuple([p.placements, + p.device_mesh])][0].append(n) + placement_to_params[tuple([p.placements, + p.device_mesh])][1].append(p) + return placement_to_params + + if len(param_dtensors_small) > 0: + if not dist.is_initialized(): + raise RuntimeError( + "Parallel Muon requires torch.distributed to be initialized." + ) + + self.distributed_muon( + params=param_dtensors_small, + names=name_dtensors_small, + group=group, + lr=lr, + weight_decay=weight_decay, + momentum=momentum, + qk_logits=qk_logits, + ) + + if len(param_dtensors) > 0: + if not dist.is_initialized(): + raise RuntimeError( + "Parallel Muon requires torch.distributed to be initialized." 
+ ) + + dtensor_group = group_dtensors(param_dtensors, name_dtensors) + for _, (names, params) in dtensor_group.items(): + self.parallel( + names, + params, + group, + lr=lr, + weight_decay=weight_decay, + momentum=momentum, + qk_logits=qk_logits, + ) + + if len(param_tensors) > 0: + self.base( + name_tensors, + param_tensors, + group, + lr=lr, + weight_decay=weight_decay, + momentum=momentum, + qk_logits=qk_logits, + ) + + def _step_adamw_params(self, params, group): + params_with_grads = [] + grads = [] + moment1 = [] + moment2 = [] + max_exp_avg_sqs = [] + state_steps = [] + lr = group["lr"] + beta1, beta2 = group["adamw_betas"] + eps = group["adamw_eps"] + weight_decay = group["weight_decay"] + + for p in params: + g = p.grad + if g is None: + continue + state = self.state[p] + params_with_grads.append(p) + grads.append(g) + if "step" not in state: + state["step"] = (torch.zeros((), + dtype=torch.float32, + device=p.device)) + state["moment1"] = torch.zeros_like(g) + state["moment2"] = torch.zeros_like(g) + moment1.append(state["moment1"]) + moment2.append(state["moment2"]) + if not isinstance(state["step"], torch.Tensor): + step_tensor = torch.tensor(state["step"], + dtype=torch.float32, + device=p.device) + else: + step_tensor = state["step"] + state_steps.append(step_tensor) + + self._fused_adamw( + params_with_grads, + grads, + moment1, + moment2, + max_exp_avg_sqs, + state_steps, + amsgrad=False, + beta1=beta1, + beta2=beta2, + lr=lr, + weight_decay=weight_decay, + eps=eps, + maximize=False, + ) + + def _step_adamw(self, group): + params = group["params"] + + # group params with it's type and placement + placement_to_params: dict[tuple[Placement | type, + DeviceMesh | None]] = defaultdict(list) + for p in params: + match p: + case DTensor(): + placement_to_params[tuple([p.placements, + p.device_mesh])].append(p) + case torch.Tensor(): + placement_to_params[tuple([torch.Tensor, None])].append(p) + + for params in placement_to_params.values(): + 
self._step_adamw_params(params, group)
+
+    @torch.no_grad
+    def step(self, closure=None, qk_logits=None):
+        """Perform a single optimization step.
+
+        Args:
+            closure (Callable, optional): A closure that reevaluates the model
+                and returns the loss.
+            qk_logits (dict[int, Tensor], optional): A dictionary mapping layer indices
+                to 1D tensors of shape (num_heads,), representing the maximum
+                QK logits across all tokens, computed as
+                (1 / sqrt(head_dim)) * (Q @ K^T).
+        """
+        loss = None
+        if closure is not None:
+            with torch.enable_grad():
+                loss = closure()
+
+        for group in self.param_groups:
+            if group["use_muon"]:
+                self._step_muon(group, qk_logits=qk_logits)
+            else:
+                self._step_adamw(group)
+
+        return loss
diff --git a/build/torch28-cxx11-rocm64-x86_64-linux/optimizer/__init__.py b/build/torch28-cxx11-rocm64-x86_64-linux/optimizer/__init__.py
index 239c7a65f8293e7d0df28f05fce645af56d628c0..03dbc1afe1cf156661a2b1b22003cd5f599a0309 100644
--- a/build/torch28-cxx11-rocm64-x86_64-linux/optimizer/__init__.py
+++ b/build/torch28-cxx11-rocm64-x86_64-linux/optimizer/__init__.py
@@ -1,5 +1,26 @@
-from .muon import Muon
+import ctypes
+import sys

-__all__ = [
-    "Muon",
-]
+import importlib.util
+from pathlib import Path
+from types import ModuleType
+
+def _import_from_path(file_path: Path) -> ModuleType:
+    # We cannot use the module name as-is, after adding it to `sys.modules`,
+    # it would also be used for other imports. So, we make a module name that
+    # depends on the path for it to be unique using the hex-encoded hash of
+    # the path. 
+ path_hash = "{:x}".format(ctypes.c_size_t(hash(file_path.absolute())).value) + module_name = path_hash + spec = importlib.util.spec_from_file_location(module_name, file_path) + if spec is None: + raise ImportError(f"Cannot load spec for {module_name} from {file_path}") + module = importlib.util.module_from_spec(spec) + if module is None: + raise ImportError(f"Cannot load module {module_name} from spec") + sys.modules[module_name] = module + spec.loader.exec_module(module) # type: ignore + return module + + +globals().update(vars(_import_from_path(Path(__file__).parent.parent / "__init__.py"))) diff --git a/build/torch28-cxx11-rocm64-x86_64-linux/optimizer/_ops.py b/build/torch28-cxx11-rocm64-x86_64-linux/optimizer/_ops.py deleted file mode 100644 index 7d598206add1bca142661a3df6c510e3d9575d54..0000000000000000000000000000000000000000 --- a/build/torch28-cxx11-rocm64-x86_64-linux/optimizer/_ops.py +++ /dev/null @@ -1,9 +0,0 @@ -import torch -from . import _optimizer_23d68bb_dirty -ops = torch.ops._optimizer_23d68bb_dirty - -def add_op_namespace_prefix(op_name: str): - """ - Prefix op by namespace. 
- """ - return f"_optimizer_23d68bb_dirty::{op_name}" \ No newline at end of file diff --git a/build/torch28-cxx11-rocm64-x86_64-linux/optimizer/_optimizer_23d68bb_dirty.abi3.so b/build/torch28-cxx11-rocm64-x86_64-linux/optimizer/_optimizer_23d68bb_dirty.abi3.so deleted file mode 100755 index 61f550abf74c8dc521bf56e2c9e2a904b2582331..0000000000000000000000000000000000000000 --- a/build/torch28-cxx11-rocm64-x86_64-linux/optimizer/_optimizer_23d68bb_dirty.abi3.so +++ /dev/null @@ -1,3 +0,0 @@ -version https://git-lfs.github.com/spec/v1 -oid sha256:55f869cf4220f2033d4e499da522da46794a682495c2b688dbcac0ec89135cf4 -size 1852240 diff --git a/build/torch28-cxx11-rocm64-x86_64-linux/optimizer/distributed/utils.py b/build/torch28-cxx11-rocm64-x86_64-linux/optimizer/distributed/utils.py deleted file mode 100644 index 0b4b58bfb329b1c015129e4c4fc99f7bfa2ab30a..0000000000000000000000000000000000000000 --- a/build/torch28-cxx11-rocm64-x86_64-linux/optimizer/distributed/utils.py +++ /dev/null @@ -1,174 +0,0 @@ -import torch -import torch.distributed as dist -from torch.distributed import ProcessGroup -from torch.distributed.device_mesh import DeviceMesh -from torch.distributed.tensor import DTensor -from torch.distributed.tensor.placement_types import (Placement, Shard, - _StridedShard) - - -def get_slices_of_dtensor( - target: DTensor | torch.Tensor, - local_rank: int, - shard_mesh: DeviceMesh, - shard_placements: tuple[Placement], -) -> tuple[slice]: - """ - Get the slice of local tensor for a given rank from a tensor. - Args: - target (DTensor | torch.Tensor): The target tensor. - rank (int): The local rank of the shard group. - shard_mesh (DeviceMesh): The shard mesh. It consists of global ranks. - shard_placements (tuple[Placement]): The shard placements. 
- """ - - slices: list[slice] = [slice(0, dim_size) for dim_size in target.size()] - - # find the global rank of the local rank in the shard mesh - rank = sorted(shard_mesh.mesh.flatten().tolist())[local_rank] - - rank_coords = (shard_mesh.mesh == rank).nonzero() - - assert len(rank_coords) == 1 - rank_coords = tuple(rank_coords[0].tolist()) - - assert len(rank_coords) == len(shard_placements) - - # Caution: Assuming replicate-to-shard of the shard mesh goes with - # left-to-right sharding. This is ensured by the sorting logic of - # construct_shard_mesh function. - for i, (rank_coord, - placement) in enumerate(zip(rank_coords, shard_placements)): - assert isinstance(placement, Shard) - - num_ranks = shard_mesh.mesh.shape[i] - - dim = placement.dim - dim_size = (slices[dim].stop - slices[dim].start) - - if dim_size % num_ranks != 0: - raise NotImplementedError( - f"Dimension size {dim_size} is not divisible " - f"by number of ranks {num_ranks} for shard " - f"placement on dim {dim}.") - - shard_size = dim_size // num_ranks - - start = slices[dim].start + rank_coord * shard_size - end = start + shard_size - - assert start < end <= slices[dim].stop - - slices[dim] = slice(start, end) - - return tuple(slices) - - -_ranks_to_dist_cache: dict[tuple[int, ...], tuple[DeviceMesh, ProcessGroup]] = dict() - - -def construct_shard_mesh( - placements: tuple[Placement], - mesh: DeviceMesh, -) -> (DeviceMesh, ProcessGroup, tuple[Placement]): - """ - Construct Shard Mesh and Placements for unsharding. - It removes Replicate placements and constructs a new Mesh and ProcessGroup. - """ - my_rank = dist.get_rank() - - assert mesh.mesh.device.type == 'cpu' - - # Copy mesh to avoid modifying the original mesh - mesh = mesh.mesh.clone() - - # 1. Sort placements. Replicate first, then Shard by dim ascending. - - # For Shard, strided shard comes after regular shard on the same dim - # to preserve left-to-right order of replicate-to-shard. 
- # This is because that strided shard is using stride to represent - # more fine-grained sharding on the same dim. - # Please check the URL below for _StridedShard. - # https://github.com/pytorch/pytorch/blob/v2.8.0/torch/distributed/tensor/placement_types.py#L366 - - def placement_sort_key( - placement_with_index: tuple[float, Placement] - ) -> tuple[int, float, int]: # (dim, split factor, original index) - index, placement = placement_with_index - is_replicate = placement.is_replicate() - is_shard = placement.is_shard() - is_partial = placement.is_partial() - - assert is_replicate or is_shard, f"Unsupported placement type: {type(placement)}" - assert not is_partial, "Partial placement is not supported." - - if is_replicate: - return (-1.0, 0, index) - elif is_shard: - if isinstance(placement, _StridedShard): - return (placement.dim, 1 / placement.split_factor, index) - return (placement.dim, 0, index) - else: - raise TypeError(f"Unknown placement type: {type(placement)}") - - placements_with_index: list[tuple[int, - Placement]] = list(enumerate(placements)) - placements_with_index = sorted(placements_with_index, - key=placement_sort_key) - - sorted_indices, sorted_placements = zip(*placements_with_index) - - # 2. Permute mesh according to sorted placements. - sorted_mesh = mesh.permute(sorted_indices) - - # 3. 
Collect list of shard meshes by removing replicate dims - # For example, (2, 3, 4, 4) with placements [R, R, S(0), S(1)] - # shard_meshes should be list with 2 * 3 = 6 shard meshes of shape (4, 4) - num_replicates = sum(1 for p in sorted_placements if p.is_replicate()) - - # merge replicate dims - # shard_meshes became a list of shard meshes with a length of replicate degree - if num_replicates > 0: - sorted_mesh = sorted_mesh.flatten( - 0, num_replicates - 1) if num_replicates > 1 else sorted_mesh - shard_meshes = list(torch.unbind(sorted_mesh, dim=0)) - else: - shard_meshes = [sorted_mesh] - shard_placements = sorted_placements[num_replicates:] - - # assume all shard placements are different - assert len(shard_placements) == len(set(shard_placements)) - - # 4. Construct ProcessGroups - # Caution: all groups should be created in the same order in all processes, - # even though each process only needs its own group. - - # To use tensor as dict key, convert it to tuple - def tensor_to_tuple(t): - if isinstance(t, torch.Tensor): - t = t.tolist() - if isinstance(t, list): - return tuple(tensor_to_tuple(x) for x in t) - return t - - my_shard_mesh_as_tuple = None - for shard_mesh in shard_meshes: - assert isinstance(shard_mesh, torch.Tensor) - shard_mesh_as_tuple = tensor_to_tuple(shard_mesh) - - if (my_rank == shard_mesh).any().item(): - assert my_shard_mesh_as_tuple is None - my_shard_mesh_as_tuple = shard_mesh_as_tuple - - # update global cache - if shard_mesh_as_tuple not in _ranks_to_dist_cache: - shard_process_group = dist.new_group(shard_mesh.flatten().tolist()) - _ranks_to_dist_cache[shard_mesh_as_tuple] = ( - DeviceMesh(device_type="cuda", mesh=shard_mesh), - shard_process_group, - ) - - my_shard_mesh, my_shard_process_group = _ranks_to_dist_cache[ - my_shard_mesh_as_tuple] - - return my_shard_mesh, my_shard_process_group, shard_placements diff --git a/build/torch28-cxx11-rocm64-x86_64-linux/optimizer/muon.py 
b/build/torch28-cxx11-rocm64-x86_64-linux/optimizer/muon.py deleted file mode 100644 index cfbcca71741be70048bfd290c62148b2aceda631..0000000000000000000000000000000000000000 --- a/build/torch28-cxx11-rocm64-x86_64-linux/optimizer/muon.py +++ /dev/null @@ -1,1240 +0,0 @@ -import logging -import math -import types -from collections import defaultdict -from dataclasses import dataclass -from typing import Any, cast - -import torch -import torch.distributed as dist -from torch.distributed import ProcessGroup -from torch.distributed.device_mesh import DeviceMesh -from torch.distributed.tensor import DTensor, Replicate -from torch.distributed.tensor.placement_types import Placement - -from .distributed.utils import construct_shard_mesh, get_slices_of_dtensor -from .matmul_transpose_triton import matmul_transpose_assign - -logger = logging.getLogger(__name__) - -COMM_DTYPE = torch.bfloat16 -DEFAULT_CHUNK_SIZE_RATIO = 4 - - -# This code snippet is a modified version adapted from the following GitHub repositories: -# https://github.com/KellerJordan/Muon/blob/master/muon.py -# Muon's Newton–Schulz iteration causes high variance in singular values -# Idea: give each iteration its own 3 coefficients and optimize them via gradient descent. -@torch.no_grad() -# matmul_transpose_assign from : https://github.com/nil0x9/flash-muon -def _zeropower_via_newtonschulz5(G, steps): - """ - Newton-Schulz iteration to compute the zeroth power / orthogonalization of G. We opt to use a - quintic iteration whose coefficients are selected to maximize the slope at zero. For the purpose - of minimizing steps, it turns out to be empirically effective to keep increasing the slope at - zero even beyond the point where the iteration no longer converges all the way to one everywhere - on the interval. 
This iteration therefore does not produce UV^T but rather something like US'V^T - where S' is diagonal with S_{ii}' ~ Uniform(0.5, 1.5), which turns out not to hurt model - performance at all relative to UV^T, where USV^T = G is the SVD. - """ - assert len(G.shape) == 2 - assert G.dtype == COMM_DTYPE - X = G # no manual typecast - - if G.size(0) > G.size(1): - X = X.T - # Ensure spectral norm is at most 1 - X = X / (X.norm() + 1e-7) - buf1 = torch.empty(X.size(0), X.size(0), dtype=X.dtype, device=X.device) - buf2 = torch.empty(X.size(0), X.size(0), dtype=X.dtype, device=X.device) - # Perform the NS iterations - for a, b, c in [ - (4.0848, -6.8946, 2.9270), - (3.9505, -6.3029, 2.6377), - (3.7418, -5.5913, 2.3037), - (2.8769, -3.1427, 1.2046), - (2.8366, -3.0525, 1.2012), - ]: - matmul_transpose_assign(X, buf1) - matmul_transpose_assign(buf1, buf2) - buf1.mul_(b).add_(buf2, alpha=c) - X = torch.addmm(X, buf1, X, alpha=1.0, beta=a) - - if G.size(0) > G.size(1): - X = X.T - return X - - -@dataclass -class _muon_state: - # TODO: use Optional - worker_rank: int - process_group: ProcessGroup - shard_mesh: DeviceMesh - shard_placements: tuple[Placement, ...] 
- name: str - qk_clip_state: torch.Tensor | None = None - gathered_grad: torch.Tensor | None = None - scattered_u: DTensor | None = None - computed_u: torch.Tensor | None = None - gather_event: torch.cuda.Event | None = None - compute_event: torch.cuda.Event | None = None - scatter_event: torch.cuda.Event | None = None - - -def numel_for_rank( - param: DTensor, - local_rank: int, - state: _muon_state, -) -> int: - slices = get_slices_of_dtensor( - param, - local_rank, - state.shard_mesh, - state.shard_placements, - ) - - numel = 1 - for s, dim in zip(slices, param.shape): - start, stop, step = s.indices(dim) - length = max(0, (stop - start + (step - 1)) // step) - numel *= length - - return numel - - -@torch.no_grad() -def _alloc_gathered_grad(params, param_to_state, rank, compute_stream): - """ - Pre-allocate gathered_grad buffer on compute_stream - before launching all2all gather - """ - with torch.cuda.stream(compute_stream): - for p in params: - state = param_to_state[id(p)] - if rank == state.worker_rank: - state.gathered_grad = torch.empty(p.shape, - dtype=COMM_DTYPE, - device="cuda") - else: - state.gathered_grad = None - - alloc_event = torch.cuda.Event() - alloc_event.record(compute_stream) - return alloc_event - - -@torch.no_grad() -def _all2all_gather(params, param_to_state, rank, comm_stream, none_grad, - alloc_event): - """ - All2all gathers shards so each owner rank reconstructs its full gradient - """ - with torch.cuda.stream(comm_stream): - process_group = param_to_state[id(params[0])].process_group - num_ranks = dist.get_world_size(group=process_group) - - # Construct sending buffers - per_dst = [[] for _ in range(num_ranks)] - send_counts = [0] * num_ranks - - for p in params: - state = param_to_state[id(p)] - dst = state.worker_rank - assert dst < num_ranks - shard_elems = numel_for_rank(p, rank, state) - g = p.grad - g = g.to_local().to(COMM_DTYPE).contiguous() - assert g.numel() == shard_elems - per_dst[dst].append(g.view(-1)) - 
send_counts[dst] += shard_elems - - assert any( - len(v) > 0 for v in per_dst - ), "At least one destination rank must receive a sharded tensor" - # list[list[Tensor]] -> list[Tensor] - per_dst = [t for dst in per_dst for t in dst] - - send_buf = torch.cat(per_dst, dim=0) - - owned_params = [ - p for p in params if param_to_state[id(p)].worker_rank == rank - ] - - # Compute receive sizes and allocate receiving buffers - recv_counts = [0] * num_ranks - - for src in range(num_ranks): - total = 0 - for p in owned_params: - state = param_to_state[id(p)] - assert state.worker_rank == rank - total += numel_for_rank(p, src, state) - recv_counts[src] = total - - recv_total = sum(recv_counts) - recv_buf = torch.empty(recv_total, dtype=COMM_DTYPE, device="cuda") - - #All2All - logger.debug(f"send_buf size: {send_buf.numel()}, " - f"recv_buf size: {recv_buf.numel()}, " - f"recv_counts: {recv_counts}, " - f"send_counts: {send_counts}, " - f"process_group: {str(process_group)}") - dist.all_to_all_single( - recv_buf, - send_buf, - output_split_sizes=recv_counts, - input_split_sizes=send_counts, - group=process_group, - ) - - # Reconstructs gathered grad from the received buffer - # - # recv_buf (num ranks = 3) - # - # From rank 0 From rank 1 From rank 2 - # | p1_0, p2_0, p3_0 | p1_1, p2_1, p3_1 | p1_2, p2_2, p3_2 | - # - # Outer loop: - # rank 0 -> rank 1 -> rank2 - # - # Inner loop: - # p1_n -> p2_n -> p3_n - - comm_stream.wait_event(alloc_event) - - off = 0 - for src in range(num_ranks): - if recv_counts[src] == 0: - continue - - block = recv_counts[src] - inner_off = 0 - for p in owned_params: - state = param_to_state[id(p)] - assert state.worker_rank == rank - - # get the slice of the full dtensor corresponding to rank src. 
- slices = get_slices_of_dtensor(state.gathered_grad, src, - state.shard_mesh, - state.shard_placements) - - dst = state.gathered_grad[slices] - assert dst._base is state.gathered_grad - - n = dst.numel() - assert n > 0 - - sg = recv_buf.narrow(0, off + inner_off, n) - sg = sg.reshape_as(dst) - dst.copy_(sg) - - inner_off += n - off += block - - for p in params: - state = param_to_state[id(p)] - if state.worker_rank == rank: - state.gather_event = torch.cuda.Event() - state.gather_event.record(comm_stream) - else: - state.gathered_grad = None - state.gather_event = None - if none_grad: - p.grad = None - - -@torch.no_grad() -def _compute_u(p, state, steps, rank, compute_stream): - """ - On worker_rank, compute the orthogonalized update using Newton-Schulz iteration. - """ - with torch.cuda.stream(compute_stream): - if rank == state.worker_rank: - if state.gather_event is None: - raise RuntimeError("Gather event must be set before compute.") - compute_stream.wait_event(state.gather_event) - u = _zeropower_via_newtonschulz5(state.gathered_grad, steps) - state.gathered_grad = None - state.computed_u = u - state.compute_event = torch.cuda.Event() - state.compute_event.record() - else: - state.computed_u = None - state.compute_event = None - - -@torch.no_grad() -def _alloc_scattered_u(params, param_to_state, rank, compute_stream): - """ - Pre-allocate scattered_u buffer on compute_stream - before launching all2all gather - """ - with torch.cuda.stream(compute_stream): - for p in params: - state = param_to_state[id(p)] - state.scattered_u = torch.empty_like(p.to_local(), - dtype=COMM_DTYPE) - - alloc_event = torch.cuda.Event() - alloc_event.record(compute_stream) - return alloc_event - - -def _all2all_scatter(params, param_to_state, rank, comm_stream, alloc_event): - """ - All2all scatters full gradients to all ranks - """ - with torch.cuda.stream(comm_stream): - process_group = param_to_state[id(params[0])].process_group - num_ranks = 
dist.get_world_size(group=process_group) - owned_params = [ - p for p in params if param_to_state[id(p)].worker_rank == rank - ] - - # Construct sending buffer - per_dst = [[] for _ in range(num_ranks)] - send_counts = [0] * num_ranks - - if owned_params: - for p in owned_params: - state = param_to_state[id(p)] - if state.compute_event is None: - raise RuntimeError( - "Compute event must be set before scatter.") - comm_stream.wait_event(state.compute_event) - state.gathered_grad = None - - assert state.computed_u is not None - - u_full = state.computed_u.to(COMM_DTYPE).contiguous() - - offset = 0 - for dst in range(num_ranks): - # get the slice of the full tensor corresponding to rank dst. - slices = get_slices_of_dtensor(u_full, dst, - state.shard_mesh, - state.shard_placements) - su = u_full[slices].flatten() - - n = su.numel() - assert n > 0 - - per_dst[dst].append(su) - send_counts[dst] += n - offset += n - - assert offset == u_full.numel() - - lengths = [len(v) for v in per_dst] - if all(l > 0 for l in lengths): - assert all( - l == lengths[0] for l in lengths - ), "All destination ranks must have the same number of sharded tensor" - # list[list[Tensor]] -> list[Tensor] - per_dst = [t for dst in per_dst for t in dst] - send_buf = torch.cat(per_dst, dim=0) - else: - # all_to_all requires participation from all ranks - # Even non-owner ranks must join the collective call - send_buf = torch.empty(0, dtype=COMM_DTYPE, device="cuda") - - # Compute receive sizes and allocate receiving buffers - recv_counts = [0] * num_ranks - - for src in range(num_ranks): - total = 0 - for p in params: - state = param_to_state[id(p)] - if state.worker_rank != src: - continue - total += numel_for_rank(p, rank, state) - recv_counts[src] = total - - recv_total = sum(recv_counts) - assert recv_total > 0 - recv_buf = torch.empty(recv_total, dtype=COMM_DTYPE, device="cuda") - - #All2All - dist.all_to_all_single( - recv_buf, - send_buf, - output_split_sizes=recv_counts, - 
input_split_sizes=send_counts, - group=process_group, - ) - - # Copy to pre-allocated scattered_u buffer from the received buffer - # - # recv_buf (num ranks = 3, local_rank = 0) - # - # From rank 0 From rank 1 From rank 2 - # | p1_0, p2_0, p3_0 | p4_0 | p5_0, p6_0 | - # - # Outer loop: - # rank 0 -> rank 1 -> rank2 - # - # Inner loop: - # src(0) : p1_0 -> p2_0 -> p3_0 - # src(1) : p4_0 - # src(2) : p5_0 -> p6_0 - - comm_stream.wait_event(alloc_event) - - off = 0 - for src in range(num_ranks): - block = recv_counts[src] - if block == 0: - continue - - inner_off = 0 - for p in params: - state = param_to_state[id(p)] - if state.worker_rank != src: - continue - n = numel_for_rank(p, rank, state) - assert n > 0 - - flat_local = recv_buf.narrow(0, off + inner_off, - n).view_as(p.to_local()) - state.scattered_u.copy_(flat_local) - - state.scatter_event = torch.cuda.Event() - state.scatter_event.record(comm_stream) - inner_off += n - - assert inner_off == block - off += block - - -def _update_param(p, state, lr, adjusted_lr, weight_decay, rank, - compute_stream): - """ - Update sharded parameter p with the scattered_u. - Only worker_rank frees computed_u. 
- """ - with torch.cuda.stream(compute_stream): - if state.scatter_event is None: - raise RuntimeError("Scatter event must be set before update") - compute_stream.wait_event(state.scatter_event) - u_dtensor = DTensor.from_local( - state.scattered_u, - placements=p.placements, - device_mesh=p.device_mesh, - ) - - state.scattered_u = u_dtensor - - if rank == state.worker_rank: - # Free computed_u - state.computed_u = None - - Muon._update_p(p, state.scattered_u, lr, adjusted_lr, weight_decay) - state.scattered_u = None - u_dtensor = None - - scales_full = Muon._compute_scales( - p, - state.qk_clip_state) if state.qk_clip_state is not None else None - if scales_full is not None: - # Have to slice scales_full among dim 0 - weight_slices = get_slices_of_dtensor(p, rank, state.shard_mesh, - state.shard_placements) - ratio = p.shape[0] // scales_full.shape[0] - scales_slice = slice( - None if weight_slices[0].start is None else - weight_slices[0].start // ratio, - None if weight_slices[0].stop is None else - weight_slices[0].stop // ratio, - None, - ) - - scales_local = scales_full[scales_slice] - scales_local = DTensor.from_local( - scales_local, - placements=p.placements, - device_mesh=p.device_mesh, - ) - Muon._qk_clip(p, scales_local, state.qk_clip_state.head_dim) - - -def default_is_muon(name, x): - skip_keys = ["embed_tokens", "lm_head", "tok_embeddings", "output"] - return x.ndim >= 2 and not any(key in name for key in skip_keys) - - -def get_default_muon_param_groups(model, is_muon_func=default_is_muon): - muon_params, muon_names = [], [] - non_muon_params = [] - - for n, p in model.named_parameters(): - if not p.requires_grad: - continue - if is_muon_func(n, p): - muon_params.append(p) - muon_names.append(n) - else: - non_muon_params.append(p) - - return [ - { - "params": muon_params, - "names": muon_names, - "use_muon": True, - }, - { - "params": non_muon_params, - "use_muon": False, - }, - ] - - -def parse_qk_layer(name: str) -> tuple[str | None, int]: - """ - 
Parse a parameter name to check if it is a query/key projection layer - ('wq', 'wk', 'q_proj', 'k_proj') and return (kind, layer_index). - - Returns: - (kind, layer_idx) or (None, -1) if not matched. - - Example: - 'model.3.attn.wq.weight' -> ('wq', 3) - 'model.5.attn.wk.weight' -> ('wk', 5) - 'model.2.attn.q_proj.weight' -> ('q_proj', 2) - 'model.7.attn.k_proj.weight' -> ('k_proj', 7) - 'model.4.attn.v_proj.weight' -> (None, -1) - """ - parts = name.split('.') - if len(parts) < 3: - return None, -1 - - kind = parts[-2] - - layer_idx = -1 - for part in reversed(parts): - if part.isdigit(): - layer_idx = int(part) - break - - if kind in ('wq', 'wk', 'q_proj', 'k_proj'): - return kind, layer_idx - - return None, -1 - - -@dataclass -class QKClipInfo: - """Per-parameter dynamic info computed from config + runtime logits.""" - kind: str | None # 'wq'/'q_proj' or 'wk'/'k_proj' or None - indices: list[int] # which heads to consider for clipping - head_dim: int # from config - threshold: float # from config - logit: torch.Tensor | None - - -class Muon(torch.optim.Optimizer): - """ - Muon - MomentUm Orthogonalized by Newton-schulz - - Muon internally runs standard SGD-momentum, and then performs an orthogonalization post- - processing step, in which each 2D parameter's update is replaced with the nearest orthogonal - matrix. To efficiently orthogonalize each update, we use a Newton-Schulz iteration, which has - the advantage that it can be stably run in bfloat16 on the GPU. - - Some warnings: - - We believe this optimizer is unlikely to work well for training with small batch size. - - We believe it may not work well for finetuning pretrained models, but we haven't tested this. - - Arguments: - model: The model to be optimized by Muon. - is_muon_func: A function that takes a parameter and its name, and returns whether the parameter should be optimized by Muon. - lr: The learning rate. The updates will have spectral norm of `lr`. 
(0.02 is a good default) - momentum: The momentum used by the internal SGD. (0.95 is a good default) - nesterov: Whether to use Nesterov-style momentum in the internal SGD. (recommended) - ns_steps: The number of Newton-Schulz iterations to run. (6 is probably always enough) - weight_decay: The weight decay for Muon and AdamW. - {0, 1}-D or are detected as being the embed or lm_head will be optimized by AdamW as well. - adamw_lr: The learning rate for the internal AdamW. - adamw_betas: The betas for the internal AdamW. - adamw_eps: The epsilon for the internal AdamW. - none_grad: Whether to set p.grad to None after gathering the gradients. This can save memory. - debug: Whether to print debug information. - clip_info : Configuration for QK clipping. Expected keys: - - "q_indices" (list[int]): Indices of query heads to consider. - - "k_indices" (list[int]): Indices of key heads to consider. - - "head_dim" (int): Dimensionality of each attention head. - - "threshold" (float): Threshold value; heads whose QK logits exceed - this value will be scaled down. - Default is: - { - "q_indices": [], - "k_indices": [], - "head_dim": 128, - "threshold": 100 - } - warmup_step : How many all2all gather, compute operations are launched in advance - before the corresponding all2all scatter steps begin. - A higher warmup_step increases memory usage but can improve - performance by overlapping communication. - Parallel muon only. - chunk_size : Batch size of parameters to process in each - all2all gather/compute/scatter step. - Use shard ranks * DEFAULT_CHUNK_SIZE_RATIO when -1 is specified. - use_distributed_muon: Use distributed muon by Liu et al. (2024). - For testing purpose only. 
- """ - - def __init__(self, - params, - lr=1e-3, - momentum=0.95, - nesterov=True, - ns_steps=5, - weight_decay=0.1, - adamw_betas=(0.9, 0.95), - adamw_eps=1e-8, - none_grad=True, - debug=False, - clip_config={ - "q_indices": [], - "k_indices": [], - "head_dim": 128, - "threshold": 100 - }, - warmup_step=5, - chunk_size=-1, - use_distributed_muon=False): - defaults = dict( - lr=lr, - weight_decay=weight_decay, - momentum=momentum, - nesterov=nesterov, - ns_steps=ns_steps, - adamw_betas=adamw_betas, - adamw_eps=adamw_eps, - none_grad=none_grad, - use_muon=True, - ) - error_message = "The key 'use_muon' is not set in parameter group {idx}. Assuming all parameters in the group will use muon optimization, which may lead to unexpected behavior." - instruction_code = "\n\n please follow this code snippet \n```optimizer = get_kernel('motif-technologies/optimizer')\n\n\nparams = optimizer.muon.get_default_muon_param_groups(model)\n\noptim = optimizer.Muon(params, ...)```" - - if isinstance(params, types.GeneratorType): - raise ValueError(error_message.format(idx=0) + instruction_code) - for _idx, param_group in enumerate(params): - if param_group.get("use_muon", None) is None: - raise ValueError( - error_message.format(idx=_idx) + instruction_code) - - super().__init__(params, defaults) - - self.rank = None - - self.comm_stream = torch.cuda.Stream() - self.compute_stream = torch.cuda.Stream() - self.debug = debug - self.clip_config = clip_config - self.warmup_step = warmup_step - self.chunk_size = chunk_size - self.use_distributed_muon = use_distributed_muon - - def _calc_flops(self, G, steps): - assert len(G.shape) == 2 - M, N = G.shape - if M > N: - M, N = N, M - - return steps * ((M**3) * 2 + (M**2 * N) * 4 + M * N * 2 + M**2 * 3) - - def adjust_lr_for_muon(self, lr, param_shape): - A, B = param_shape[:2] - # We adjust the learning rate and weight decay based on the size of the parameter matrix - # as describted in the paper - adjusted_ratio = 0.2 * math.sqrt(max(A, 
B)) - adjusted_lr = lr * adjusted_ratio - return adjusted_lr - - def set_rank_once(self, rank): - if self.rank is None: - self.rank = rank - else: - assert self.rank == rank - - def get_shard_mesh(self, p): - """ - Get the shard mesh for a parameter p on the given rank. - """ - assert isinstance( - p, DTensor), "Parallel Muon only supports DTensor parameters." - - shard_mesh, shard_pg, shard_placements = construct_shard_mesh( - p.placements, p.device_mesh) - - # set rank with the local rank in the shard process group - self.set_rank_once(dist.get_rank(group=shard_pg)) - - return shard_mesh, shard_pg, shard_placements - - def init_state_and_assign_params(self, names, params, group, qk_logits): - param_to_state = {} - param_to_flops = {} - - total_flops = 0 - for p in params: - g = p.grad - if g is None: - continue - assert g.ndim == 2, "Muon only supports 2D parameters." - - flops = self._calc_flops(g, group["ns_steps"]) - param_to_flops[id(p)] = flops - total_flops += flops - - if self.debug: - print(f"Total TFLOPs for Muon: {total_flops / 1e12:.2f} TFLOPs", - flush=True) - - paired = list(zip(names, params)) - - paired_sorted = sorted(paired, - key=lambda x: param_to_flops[id(x[1])], - reverse=True) - - names_sorted, params_sorted = zip(*paired_sorted) - ordered_names = list(names_sorted) - ordered_params = list(params_sorted) - - round_robin = 0 - mesh = ordered_params[0].device_mesh - placements = ordered_params[0].placements - - shard_mesh, shard_pg, shard_placements = self.get_shard_mesh( - ordered_params[0]) - shard_mesh_flattened = shard_mesh.mesh.flatten() - num_ranks = dist.get_world_size(group=shard_pg) - - for n, p in zip(ordered_names, ordered_params): - if mesh != p.device_mesh: - raise ValueError("All parameters must be on the same mesh.") - if placements != p.placements: - raise ValueError("All parameters must have same placements.") - - worker_rank = shard_mesh_flattened[round_robin].item() % num_ranks - round_robin = (round_robin + 1) % 
len(shard_mesh_flattened) - qk_clip_state = self.get_qk_clip_info(n, qk_logits) - - param_to_state[id(p)] = _muon_state( - worker_rank=worker_rank, - process_group=shard_pg, - shard_mesh=shard_mesh, - shard_placements=shard_placements, - name=n, - qk_clip_state=qk_clip_state, - ) - - return param_to_state, ordered_params - - def base(self, names, params, group, lr, weight_decay, momentum, - qk_logits): - # generate weight updates in distributed fashion - for n, p in zip(names, params): - g = p.grad - if g is None: - continue - if g.ndim > 2: - g = g.view(g.size(0), -1) - assert g is not None - - # calc update - state = self.state[p] - if "momentum_buffer" not in state: - state["momentum_buffer"] = torch.zeros_like(g) - buf = state["momentum_buffer"] - buf.mul_(momentum).add_(g) - if group["nesterov"]: - g = g.add(buf, alpha=momentum) - else: - g = buf - - u = _zeropower_via_newtonschulz5(g.to(COMM_DTYPE), - steps=group["ns_steps"]) - - adjusted_lr = self.adjust_lr_for_muon(lr, p.shape) - Muon._update_p(p, u, lr, adjusted_lr, weight_decay) - - qk_clip_state = self.get_qk_clip_info(n, qk_logits) - - scales_full = self._compute_scales( - p, qk_clip_state) if qk_clip_state is not None else None - if scales_full is not None: - Muon._qk_clip(p, scales_full, qk_clip_state.head_dim) - - def distributed_muon( - self, - names: list[str], - params: list[torch.nn.Parameter], - group: dict[str, Any], - lr: float, - weight_decay: float, - momentum: float, - qk_logits: list[torch.Tensor | DTensor] | None, - ): - """ Implementation of Distributed Muon by Liu et al. 
""" - if qk_logits is not None: - raise NotImplementedError("QK clipping is not supported yet") - - if isinstance(params[0], DTensor): - shard_mesh, _, shard_placements = construct_shard_mesh( - placements=params[0].placements, - mesh=params[0].device_mesh, - ) - - for n, p in zip(names, params): - g = p.grad - if g is None: - continue - if g.ndim > 2: - g = g.view(g.size(0), -1) - assert g is not None - - # calc update - state = self.state[p] - if "momentum_buffer" not in state: - state["momentum_buffer"] = torch.zeros_like(g) - buf = state["momentum_buffer"] - buf.mul_(momentum).add_(g) - if group["nesterov"]: - g = g.add(buf, alpha=momentum) - else: - g = buf - - # Gather G - if isinstance(p.data, DTensor): - g = g.full_tensor() - u = _zeropower_via_newtonschulz5(g.to(COMM_DTYPE), - steps=group["ns_steps"]) - - if isinstance(p.data, DTensor): - slices = get_slices_of_dtensor( - target=p, - local_rank=dist.get_rank(), - shard_mesh=shard_mesh, - shard_placements=shard_placements, - ) - u_shard = u[slices] - u = DTensor.from_local( - u_shard, - device_mesh=p.device_mesh, - placements=p.placements, - ) - - adjusted_lr = self.adjust_lr_for_muon(lr, p.shape) - Muon._update_p(p, u, lr, adjusted_lr, weight_decay) - - def _update_g(self, p, g, group, momentum): - # calc update - state = self.state[p] - buf = state.setdefault("momentum_buffer", torch.zeros_like(g)) - torch.add(g, buf, alpha=momentum, out=buf) - if group["nesterov"]: - g.add_(buf, alpha=momentum) - return g - return buf - - @staticmethod - def _update_p(p, u, lr, adjusted_lr, weight_decay): - # apply weight decay - p.data.mul_(1 - lr * weight_decay) - # apply update - p.data.add_(u, alpha=-adjusted_lr) - - def get_qk_clip_info(self, n, qk_logits): - if self.clip_config is None: - return None - - head_dim = self.clip_config.get('head_dim') - threshold = self.clip_config.get('threshold') - kind, layer_idx = parse_qk_layer(n) - - logit, indices = None, [] - if qk_logits is not None and kind is not None: - 
logit = qk_logits[layer_idx] - indices_key = 'q_indices' if 'q' in kind else 'k_indices' - indices = self.clip_config.get(indices_key, []) or [] - - if isinstance(logit, DTensor): - # In TP settings, qk_logits may be DTensor - # We convert it to full tensor here for simplicity - logit = logit.full_tensor() - - return QKClipInfo( - kind=kind, - indices=indices, - head_dim=head_dim, - threshold=threshold, - logit=logit, - ) - - @staticmethod - def _compute_scales(p, qk_clip_state): - kind = qk_clip_state.kind - indices = qk_clip_state.indices - head_dim = qk_clip_state.head_dim - threshold = qk_clip_state.threshold - logit = qk_clip_state.logit - - H_global = p.shape[0] // head_dim - scales_full = torch.ones(H_global, device=p.data.device) - scaling = 0 - - for logit_idx, head_idx in enumerate(indices): - v_ele = float(logit[logit_idx]) - if v_ele > threshold: - new_scale = math.sqrt(threshold / v_ele) - if new_scale < scales_full[head_idx]: - scales_full[head_idx] = new_scale - logger.info( - f"[{kind}] Head {head_idx} exceeded threshold " - f"(value={v_ele:.4f}, threshold={threshold:.4f}) -> applying scale={new_scale:.4f}" - ) - scaling += 1 - - return scales_full if scaling > 0 else None - - @staticmethod - def _qk_clip(p, scales, head_dim): - W = p.data.view(-1, head_dim, p.data.shape[1]) - W.mul_(scales.view(-1, 1, 1)) - - def parallel(self, names, params, group, lr, weight_decay, momentum, - qk_logits): - """ - Perform a parallel optimization step using Muon. 
- """ - - for p in params: - g = p.grad - if g is None: - continue - if g.ndim > 2: - g = g.view(g.size(0), -1) - - # Update g in the local rank - g = self._update_g( - p, - g, - group, - momentum=momentum, - ) - p.grad = g - - param_to_state, ordered_params = self.init_state_and_assign_params( - names, params, group, qk_logits) - - assert self.rank is not None - - def enqueue_all2all_gather(start_idx, chunk_size): - target_params = ordered_params[start_idx:start_idx + chunk_size] - if target_params: - alloc_event = _alloc_gathered_grad(target_params, - param_to_state, self.rank, - self.compute_stream) - _all2all_gather(target_params, param_to_state, self.rank, - self.comm_stream, group["none_grad"], - alloc_event) - - def enqueue_computes(start_idx, chunk_size): - for p in ordered_params[start_idx:start_idx + chunk_size]: - state = param_to_state[id(p)] - _compute_u(p, state, group["ns_steps"], self.rank, - self.compute_stream) - - def enqueue_all2all_scatter(start_idx, chunk_size): - target_params = ordered_params[start_idx:start_idx + chunk_size] - if target_params: - alloc_event = _alloc_scattered_u(target_params, param_to_state, - self.rank, - self.compute_stream) - _all2all_scatter(target_params, param_to_state, self.rank, - self.comm_stream, alloc_event) - - def enqueue_update_param(start_idx, chunk_size): - for p in ordered_params[start_idx:start_idx + chunk_size]: - state = param_to_state[id(p)] - adjusted_lr = self.adjust_lr_for_muon(lr, p.shape) - _update_param(p, state, lr, adjusted_lr, weight_decay, - self.rank, self.compute_stream) - - if self.chunk_size == -1: - shard_ranks = dist.get_world_size(param_to_state[id( - params[0])].process_group) - chunk_size = shard_ranks * DEFAULT_CHUNK_SIZE_RATIO - elif self.chunk_size > 0: - chunk_size = self.chunk_size - else: - raise ValueError("chunk_size must be -1 or a positive integer.") - - # Wait grad update - self.comm_stream.wait_stream(torch.cuda.current_stream()) - - warmup_step = self.warmup_step - for i 
in range(0, warmup_step): - enqueue_all2all_gather(i * chunk_size, chunk_size) - enqueue_computes(i * chunk_size, chunk_size) - - for i in range(0, len(params) + chunk_size - 1, chunk_size): - enqueue_all2all_scatter(i, chunk_size) - enqueue_all2all_gather(i + warmup_step * chunk_size, chunk_size) - enqueue_update_param(i, chunk_size) - enqueue_computes(i + warmup_step * chunk_size, chunk_size) - - # Wait the last update_param to finish - torch.cuda.current_stream().wait_stream(self.compute_stream) - - @staticmethod - def _fused_adamw( - params: list[torch.Tensor], - grads: list[torch.Tensor], - exp_avgs: list[torch.Tensor], - exp_avg_sqs: list[torch.Tensor], - max_exp_avg_sqs: list[torch.Tensor], - state_steps: list[torch.Tensor], - amsgrad: bool, - beta1: float, - beta2: float, - lr: float | torch.Tensor, - weight_decay: float, - eps: float, - maximize: bool, - ) -> None: - if not params: - return - - # We only shuffle around the lr when it is a Tensor and on CUDA, otherwise, we prefer - # treating it as a scalar. 
- lr_dict: DeviceDict | None = ({ - lr.device: lr - } if isinstance(lr, torch.Tensor) and str(lr.device) != "cpu" else - None) - grouped_tensors = torch.optim.Optimizer._group_tensors_by_device_and_dtype( - [ - params, grads, exp_avgs, exp_avg_sqs, max_exp_avg_sqs, - state_steps - ] # type: ignore[list-item] - ) - for (device, _), ( - ( - device_params_, - device_grads_, - device_exp_avgs_, - device_exp_avg_sqs_, - device_max_exp_avg_sqs, - device_state_steps_, - ), - _, - ) in grouped_tensors.items(): - device_params = cast(list[torch.Tensor], device_params_) - device_grads = cast(list[torch.Tensor], device_grads_) - device_exp_avgs = cast(list[torch.Tensor], device_exp_avgs_) - device_exp_avg_sqs = cast(list[torch.Tensor], device_exp_avg_sqs_) - device_state_steps = cast(list[torch.Tensor], device_state_steps_) - - if lr_dict is not None and device not in lr_dict: - lr_dict[device] = lr.to( - device=device, - non_blocking=True) # type: ignore[union-attr] - lr = lr_dict[device] - torch._foreach_add_(device_state_steps, 1) - func = torch._fused_adamw_ - func( - device_params, - device_grads, - device_exp_avgs, - device_exp_avg_sqs, - device_max_exp_avg_sqs, # type: ignore[arg-type] - device_state_steps, - amsgrad=amsgrad, - lr=lr, # type: ignore[arg-type] - beta1=beta1, - beta2=beta2, - weight_decay=weight_decay, - eps=eps, - maximize=maximize, - ) - - def _step_muon(self, group, qk_logits=None): - params = group["params"] - lr = group["lr"] - weight_decay = group["weight_decay"] - momentum = group["momentum"] - names = group["names"] - - param_dtensors = [] - param_tensors = [] - name_dtensors = [] - name_tensors = [] - - if self.use_distributed_muon: - self.distributed_muon(names=names, - params=params, - group=group, - lr=lr, - weight_decay=weight_decay, - momentum=momentum, - qk_logits=qk_logits) - return - - for n, p in zip(names, params): - if p is None or p.grad is None: - continue - if isinstance(p.data, DTensor): - if all( - isinstance(placement, 
Replicate) - for placement in p.placements): - param_tensors.append(p) - name_tensors.append(n) - else: - param_dtensors.append(p) - name_dtensors.append(n) - elif isinstance(p.data, torch.Tensor): - param_tensors.append(p) - name_tensors.append(n) - else: - raise TypeError(f"Unsupported parameter type: {type(p.data)}") - - logger.debug( - f"[Muon] {len(param_dtensors)} DTensors, {len(param_tensors)} Tensors" - ) - - if len(param_dtensors) > 0: - if not dist.is_initialized(): - raise RuntimeError( - "Parallel Muon requires torch.distributed to be initialized." - ) - - # To support different placements, we group parameters by placements - # and run parallel Muon on each group. - - placement_to_params = defaultdict(lambda: ([], [])) - # type: dict[tuple[Placement, DeviceMesh], tuple[list[str], list[DTensor]]] - - assert len(name_dtensors) == len(param_dtensors) - for n, p in zip(name_dtensors, param_dtensors): - placement_to_params[tuple([p.placements, - p.device_mesh])][0].append(n) - placement_to_params[tuple([p.placements, - p.device_mesh])][1].append(p) - - for _, (names, params) in placement_to_params.items(): - self.parallel( - names, - params, - group, - lr=lr, - weight_decay=weight_decay, - momentum=momentum, - qk_logits=qk_logits, - ) - - if len(param_tensors) > 0: - self.base( - name_tensors, - param_tensors, - group, - lr=lr, - weight_decay=weight_decay, - momentum=momentum, - qk_logits=qk_logits, - ) - - def _step_adamw_params(self, params, group): - params_with_grads = [] - grads = [] - moment1 = [] - moment2 = [] - max_exp_avg_sqs = [] - state_steps = [] - lr = group["lr"] - beta1, beta2 = group["adamw_betas"] - eps = group["adamw_eps"] - weight_decay = group["weight_decay"] - - for p in params: - g = p.grad - if g is None: - continue - state = self.state[p] - params_with_grads.append(p) - grads.append(g) - if "step" not in state: - state["step"] = (torch.zeros((), - dtype=torch.float32, - device=p.device)) - state["moment1"] = torch.zeros_like(g) - 
state["moment2"] = torch.zeros_like(g) - moment1.append(state["moment1"]) - moment2.append(state["moment2"]) - if not isinstance(state["step"], torch.Tensor): - step_tensor = torch.tensor(state["step"], - dtype=torch.float32, - device=p.device) - else: - step_tensor = state["step"] - state_steps.append(step_tensor) - - self._fused_adamw( - params_with_grads, - grads, - moment1, - moment2, - max_exp_avg_sqs, - state_steps, - amsgrad=False, - beta1=beta1, - beta2=beta2, - lr=lr, - weight_decay=weight_decay, - eps=eps, - maximize=False, - ) - - def _step_adamw(self, group): - params = group["params"] - - # group params with it's type and placement - placement_to_params: dict[tuple[Placement | type, - DeviceMesh | None]] = defaultdict(list) - for p in params: - match p: - case DTensor(): - placement_to_params[tuple([p.placements, - p.device_mesh])].append(p) - case torch.Tensor(): - placement_to_params[tuple([torch.Tensor, None])].append(p) - - for params in placement_to_params.values(): - self._step_adamw_params(params, group) - - def step(self, closure=None, qk_logits=None): - """Perform a single optimization step. - - Args: - closure (Callable, optional): A closure that reevaluates the model - and returns the loss. - qk_logits (dict[int, Tensor], optional): A dictionary mapping layer indices - to 1D tensors of shape (num_heads,), representing the maximum - QK logits across all tokens, computed as - (1 / sqrt(head_dim)) * (Q @ K^T). 
- """ - loss = None - if closure is not None: - with torch.enable_grad(): - loss = closure() - - for group in self.param_groups: - if group["use_muon"]: - self._step_muon(group, qk_logits=qk_logits) - else: - self._step_adamw(group) - - return loss diff --git a/build/torch29-cxx11-cu126-x86_64-linux/__init__.py b/build/torch29-cxx11-cu126-x86_64-linux/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..239c7a65f8293e7d0df28f05fce645af56d628c0 --- /dev/null +++ b/build/torch29-cxx11-cu126-x86_64-linux/__init__.py @@ -0,0 +1,5 @@ +from .muon import Muon + +__all__ = [ + "Muon", +] diff --git a/build/torch29-cxx11-cu126-x86_64-linux/_ops.py b/build/torch29-cxx11-cu126-x86_64-linux/_ops.py new file mode 100644 index 0000000000000000000000000000000000000000..e6f6fcf6280e969b1761926112147d3146e27b59 --- /dev/null +++ b/build/torch29-cxx11-cu126-x86_64-linux/_ops.py @@ -0,0 +1,9 @@ +import torch +from . import _optimizer_06a260a_dirty +ops = torch.ops._optimizer_06a260a_dirty + +def add_op_namespace_prefix(op_name: str): + """ + Prefix op by namespace. 
+ """ + return f"_optimizer_06a260a_dirty::{op_name}" \ No newline at end of file diff --git a/build/torch29-cxx11-cu126-x86_64-linux/_optimizer_06a260a_dirty.abi3.so b/build/torch29-cxx11-cu126-x86_64-linux/_optimizer_06a260a_dirty.abi3.so new file mode 100755 index 0000000000000000000000000000000000000000..ca73c2a576e1ad27e2c5a403c459246792b9a9d1 --- /dev/null +++ b/build/torch29-cxx11-cu126-x86_64-linux/_optimizer_06a260a_dirty.abi3.so @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:07135e56b4c66b79fcb062c0bd39e61dae7e4251f164638cd09f8e360075f215 +size 1936664 diff --git a/build/torch29-cxx11-cu126-x86_64-linux/distributed/utils.py b/build/torch29-cxx11-cu126-x86_64-linux/distributed/utils.py new file mode 100644 index 0000000000000000000000000000000000000000..6d5843506c13d9d31603b2b4e30c1c91d0baab28 --- /dev/null +++ b/build/torch29-cxx11-cu126-x86_64-linux/distributed/utils.py @@ -0,0 +1,175 @@ +import torch +import torch.distributed as dist +from torch.distributed import ProcessGroup +from torch.distributed.device_mesh import DeviceMesh +from torch.distributed.tensor import DTensor +from torch.distributed.tensor.placement_types import (Placement, Shard, + _StridedShard) + + +def get_slices_of_dtensor( + target: DTensor | torch.Tensor, + local_rank: int, + shard_mesh: DeviceMesh, + shard_placements: tuple[Placement], +) -> tuple[slice]: + """ + Get the slice of local tensor for a given rank from a tensor. + Args: + target (DTensor | torch.Tensor): The target tensor. + rank (int): The local rank of the shard group. + shard_mesh (DeviceMesh): The shard mesh. It consists of global ranks. + shard_placements (tuple[Placement]): The shard placements. 
+ """ + + slices: list[slice] = [slice(0, dim_size) for dim_size in target.size()] + + # find the global rank of the local rank in the shard mesh + rank = sorted(shard_mesh.mesh.flatten().tolist())[local_rank] + + rank_coords = (shard_mesh.mesh == rank).nonzero() + + assert len(rank_coords) == 1 + rank_coords = tuple(rank_coords[0].tolist()) + + assert len(rank_coords) == len(shard_placements) + + # Caution: Assuming replicate-to-shard of the shard mesh goes with + # left-to-right sharding. This is ensured by the sorting logic of + # construct_shard_mesh function. + for i, (rank_coord, + placement) in enumerate(zip(rank_coords, shard_placements)): + assert isinstance(placement, Shard) + + num_ranks = shard_mesh.mesh.shape[i] + + dim = placement.dim + dim_size = (slices[dim].stop - slices[dim].start) + + if dim_size % num_ranks != 0: + raise NotImplementedError( + f"Dimension size {dim_size} is not divisible " + f"by number of ranks {num_ranks} for shard " + f"placement on dim {dim}. (shape: {target.shape})") + + shard_size = dim_size // num_ranks + + start = slices[dim].start + rank_coord * shard_size + end = start + shard_size + + assert start < end <= slices[dim].stop + + slices[dim] = slice(start, end) + + return tuple(slices) + + +_ranks_to_dist_cache: dict[tuple[int, ...], tuple[DeviceMesh, + ProcessGroup]] = dict() + + +def construct_shard_mesh( + placements: tuple[Placement], + mesh: DeviceMesh, +) -> (DeviceMesh, ProcessGroup, tuple[Placement]): + """ + Construct Shard Mesh and Placements for unsharding. + It removes Replicate placements and constructs a new Mesh and ProcessGroup. + """ + my_rank = dist.get_rank() + + assert mesh.mesh.device.type == 'cpu' + + # Copy mesh to avoid modifying the original mesh + mesh = mesh.mesh.clone() + + # 1. Sort placements. Replicate first, then Shard by dim ascending. + + # For Shard, strided shard comes after regular shard on the same dim + # to preserve left-to-right order of replicate-to-shard. 
+ # This is because that strided shard is using stride to represent + # more fine-grained sharding on the same dim. + # Please check the URL below for _StridedShard. + # https://github.com/pytorch/pytorch/blob/v2.8.0/torch/distributed/tensor/placement_types.py#L366 + + def placement_sort_key( + placement_with_index: tuple[float, Placement] + ) -> tuple[int, float, int]: # (dim, split factor, original index) + index, placement = placement_with_index + is_replicate = placement.is_replicate() + is_shard = placement.is_shard() + is_partial = placement.is_partial() + + assert is_replicate or is_shard, f"Unsupported placement type: {type(placement)}" + assert not is_partial, "Partial placement is not supported." + + if is_replicate: + return (-1.0, 0, index) + elif is_shard: + if isinstance(placement, _StridedShard): + return (placement.dim, 1 / placement.split_factor, index) + return (placement.dim, 0, index) + else: + raise TypeError(f"Unknown placement type: {type(placement)}") + + placements_with_index: list[tuple[int, + Placement]] = list(enumerate(placements)) + placements_with_index = sorted(placements_with_index, + key=placement_sort_key) + + sorted_indices, sorted_placements = zip(*placements_with_index) + + # 2. Permute mesh according to sorted placements. + sorted_mesh = mesh.permute(sorted_indices) + + # 3. 
Collect list of shard meshes by removing replicate dims + # For example, (2, 3, 4, 4) with placements [R, R, S(0), S(1)] + # shard_meshes should be list with 2 * 3 = 6 shard meshes of shape (4, 4) + num_replicates = sum(1 for p in sorted_placements if p.is_replicate()) + + # merge replicate dims + # shard_meshes became a list of shard meshes with a length of replicate degree + if num_replicates > 0: + sorted_mesh = sorted_mesh.flatten( + 0, num_replicates - 1) if num_replicates > 1 else sorted_mesh + shard_meshes = list(torch.unbind(sorted_mesh, dim=0)) + else: + shard_meshes = [sorted_mesh] + shard_placements = sorted_placements[num_replicates:] + + # assume all shard placements are different + assert len(shard_placements) == len(set(shard_placements)) + + # 4. Construct ProcessGroups + # Caution: all groups should be created in the same order in all processes, + # even though each process only needs its own group. + + # To use tensor as dict key, convert it to tuple + def tensor_to_tuple(t): + if isinstance(t, torch.Tensor): + t = t.tolist() + if isinstance(t, list): + return tuple(tensor_to_tuple(x) for x in t) + return t + + my_shard_mesh_as_tuple = None + for shard_mesh in shard_meshes: + assert isinstance(shard_mesh, torch.Tensor) + shard_mesh_as_tuple = tensor_to_tuple(shard_mesh) + + if (my_rank == shard_mesh).any().item(): + assert my_shard_mesh_as_tuple is None + my_shard_mesh_as_tuple = shard_mesh_as_tuple + + # update global cache + if shard_mesh_as_tuple not in _ranks_to_dist_cache: + shard_process_group = dist.new_group(shard_mesh.flatten().tolist()) + _ranks_to_dist_cache[shard_mesh_as_tuple] = ( + DeviceMesh(device_type="cuda", mesh=shard_mesh), + shard_process_group, + ) + + my_shard_mesh, my_shard_process_group = _ranks_to_dist_cache[ + my_shard_mesh_as_tuple] + + return my_shard_mesh, my_shard_process_group, shard_placements diff --git a/build/torch29-cxx11-cu126-x86_64-linux/matmul_transpose_triton.py 
b/build/torch29-cxx11-cu126-x86_64-linux/matmul_transpose_triton.py new file mode 100644 index 0000000000000000000000000000000000000000..4565b2c4fd506a4218340d380d6c962b16774b1d --- /dev/null +++ b/build/torch29-cxx11-cu126-x86_64-linux/matmul_transpose_triton.py @@ -0,0 +1,128 @@ +# MIT License +# +# Copyright (c) 2025 Tianyang Lin +# +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the "Software"), to deal +# in the Software without restriction, including without limitation the rights +# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +# copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included in all +# copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +# SOFTWARE. 
+ +import torch +import triton +import triton.language as tl + + +def get_autotune_config(): + return [ + triton.Config( + { + 'BLOCK_SIZE_M': blk_m, + 'BLOCK_SIZE_K': blk_k, + 'GROUP_SIZE_M': grp_sz + }, + num_stages=n_stages, + num_warps=n_warps) for blk_m in [32, 64, 128] + for blk_k in [32, 64] for grp_sz in [8] for n_stages in [3, 4, 5] + for n_warps in [4, 8] + ] + + +@triton.autotune( + configs=get_autotune_config(), + key=['M', 'K'], +) +@triton.jit +def mmt_kernel(x, y, M, K, stride_xm, stride_xk, stride_ym, stride_yn, + BLOCK_SIZE_M: tl.constexpr, BLOCK_SIZE_K: tl.constexpr, + GROUP_SIZE_M: tl.constexpr): + """ + Core kernel jit function of matmul_transpose that computes y = x @ x.T + The code is a simple adaptation from the triton `matmul` tutorial: + https://triton-lang.org/main/getting-started/tutorials/03-matrix-multiplication.html + """ + pid = tl.program_id(axis=0) + num_pid_m = tl.cdiv(M, BLOCK_SIZE_M) + num_pid_n = tl.cdiv(M, BLOCK_SIZE_M) + num_pid_in_group = GROUP_SIZE_M * num_pid_n + group_id = pid // num_pid_in_group + first_pid_m = group_id * GROUP_SIZE_M + group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M) + pid_m = first_pid_m + ((pid % num_pid_in_group) % group_size_m) + pid_n = (pid % num_pid_in_group) // group_size_m + if pid_m > pid_n: + return + + offs_xm = (pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)) % M + offs_xn = (pid_n * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)) % M + offs_k = tl.arange(0, BLOCK_SIZE_K) + # we use a & b ptrs to denote different rows of x. 
+ a_ptrs = x + (offs_xm[:, None] * stride_xm + offs_k[None, :] * stride_xk) + b_ptrs = x + (offs_xn[:, None] * stride_xm + offs_k[None, :] * stride_xk) + + accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_M), dtype=tl.float32) + + for k in range(0, tl.cdiv(K, BLOCK_SIZE_K)): + a = tl.load(a_ptrs, + mask=offs_k[None, :] < K - k * BLOCK_SIZE_K, + other=0.0) + b = tl.load(b_ptrs, + mask=offs_k[None, :] < K - k * BLOCK_SIZE_K, + other=0.0) + accumulator = tl.dot(a, tl.permute(b, (1, 0)), accumulator) + a_ptrs += BLOCK_SIZE_K * stride_xk + b_ptrs += BLOCK_SIZE_K * stride_xk + # use dtype.element_ty to accommodate different input datatypes as in cpp templates + # https://github.com/triton-lang/triton/issues/2252 + c = accumulator.to(x.dtype.element_ty) + + offs_cm = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M) + offs_cn = pid_n * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M) + c_ptrs = y + stride_ym * offs_cm[:, None] + stride_yn * offs_cn[None, :] + c_mask = (offs_cm[:, None] < M) & (offs_cn[None, :] < M) + tl.store(c_ptrs, c, mask=c_mask) + + # transpose and copy + if pid_m < pid_n: + ct_ptrs = y + stride_ym * offs_cn[:, + None] + stride_yn * offs_cm[None, :] + ct_mask = (offs_cn[:, None] < M) & (offs_cm[None, :] < M) + tl.store(ct_ptrs, tl.permute(c, (1, 0)), mask=ct_mask) + + +def matmul_transpose_assign(d_in, d_out): + assert d_in.is_cuda, "Input `d_in` must be a CUDA tensor" + assert d_out.is_cuda, "Input `d_out` must be a CUDA tensor" + assert d_in.device == d_out.device, "Inputs `d_in` and `d_out` must be on the same CUDA device" + assert d_in.dtype == d_out.dtype, "Inputs must have the same data type" + assert d_in.ndim == 2, "Input `d_in` must be a 2D tensor" + assert d_out.ndim == 2, "Input `d_out` must be a 2D tensor" + assert d_in.size(0) == d_out.size(0) == d_out.size(0), \ + "First dimension of `d_in` must match first and second dimension of `d_out`" + + d_in = d_in.contiguous() + M, K = d_in.shape + grid = lambda META: (triton.cdiv(M, 
META['BLOCK_SIZE_M']) * triton.cdiv( + M, META['BLOCK_SIZE_M']), ) + with torch.cuda.device(d_in.device.index): + mmt_kernel[grid](d_in, d_out, M, K, d_in.stride(0), d_in.stride(1), + d_out.stride(0), d_out.stride(1)) + + +def matmul_transpose(d_in): + M, _ = d_in.shape + d_out = torch.empty((M, M), device=d_in.device, dtype=d_in.dtype) + matmul_transpose_assign(d_in, d_out) + return d_out diff --git a/build/torch29-cxx11-cu126-x86_64-linux/metadata.json b/build/torch29-cxx11-cu126-x86_64-linux/metadata.json new file mode 100644 index 0000000000000000000000000000000000000000..76bafa5f33b6818aa6bb4cab04be811b87519b44 --- /dev/null +++ b/build/torch29-cxx11-cu126-x86_64-linux/metadata.json @@ -0,0 +1 @@ +{"python-depends":[]} \ No newline at end of file diff --git a/build/torch29-cxx11-cu126-x86_64-linux/muon.py b/build/torch29-cxx11-cu126-x86_64-linux/muon.py new file mode 100644 index 0000000000000000000000000000000000000000..dbf25575f185ff379789482068e4ecf55b9455a9 --- /dev/null +++ b/build/torch29-cxx11-cu126-x86_64-linux/muon.py @@ -0,0 +1,1268 @@ +import logging +import math +import types +from collections import defaultdict +from dataclasses import dataclass +from typing import Any, cast + +import torch +import torch.distributed as dist +from torch.distributed import ProcessGroup +from torch.distributed.device_mesh import DeviceMesh +from torch.distributed.tensor import DTensor, Replicate +from torch.distributed.tensor.placement_types import Placement + +from .distributed.utils import construct_shard_mesh, get_slices_of_dtensor +from .matmul_transpose_triton import matmul_transpose_assign + +logger = logging.getLogger(__name__) + +COMM_DTYPE = torch.bfloat16 +DEFAULT_CHUNK_SIZE_RATIO = 4 + + +# This code snippet is a modified version adapted from the following GitHub repositories: +# https://github.com/KellerJordan/Muon/blob/master/muon.py +# Muon's Newton–Schulz iteration causes high variance in singular values +# Idea: give each iteration its own 3 
coefficients and optimize them via gradient descent. +@torch.no_grad() +# matmul_transpose_assign from : https://github.com/nil0x9/flash-muon +def _zeropower_via_newtonschulz5(G, steps): + """ + Newton-Schulz iteration to compute the zeroth power / orthogonalization of G. We opt to use a + quintic iteration whose coefficients are selected to maximize the slope at zero. For the purpose + of minimizing steps, it turns out to be empirically effective to keep increasing the slope at + zero even beyond the point where the iteration no longer converges all the way to one everywhere + on the interval. This iteration therefore does not produce UV^T but rather something like US'V^T + where S' is diagonal with S_{ii}' ~ Uniform(0.5, 1.5), which turns out not to hurt model + performance at all relative to UV^T, where USV^T = G is the SVD. + """ + assert len(G.shape) == 2 + assert G.dtype == COMM_DTYPE + X = G # no manual typecast + + if G.size(0) > G.size(1): + X = X.T + # Ensure spectral norm is at most 1 + X = X / (X.norm() + 1e-7) + buf1 = torch.empty(X.size(0), X.size(0), dtype=X.dtype, device=X.device) + buf2 = torch.empty(X.size(0), X.size(0), dtype=X.dtype, device=X.device) + # Perform the NS iterations + for a, b, c in [ + (4.0848, -6.8946, 2.9270), + (3.9505, -6.3029, 2.6377), + (3.7418, -5.5913, 2.3037), + (2.8769, -3.1427, 1.2046), + (2.8366, -3.0525, 1.2012), + ]: + matmul_transpose_assign(X, buf1) + matmul_transpose_assign(buf1, buf2) + buf1.mul_(b).add_(buf2, alpha=c) + X = torch.addmm(X, buf1, X, alpha=1.0, beta=a) + + if G.size(0) > G.size(1): + X = X.T + return X + + +@dataclass +class _muon_state: + # TODO: use Optional + worker_rank: int + process_group: ProcessGroup + shard_mesh: DeviceMesh + shard_placements: tuple[Placement, ...] 
+ name: str + qk_clip_state: torch.Tensor | None = None + gathered_grad: torch.Tensor | None = None + scattered_u: DTensor | None = None + computed_u: torch.Tensor | None = None + gather_event: torch.cuda.Event | None = None + compute_event: torch.cuda.Event | None = None + scatter_event: torch.cuda.Event | None = None + + +def numel_for_rank( + param: DTensor, + local_rank: int, + state: _muon_state, +) -> int: + slices = get_slices_of_dtensor( + param, + local_rank, + state.shard_mesh, + state.shard_placements, + ) + + numel = 1 + for s, dim in zip(slices, param.shape): + start, stop, step = s.indices(dim) + length = max(0, (stop - start + (step - 1)) // step) + numel *= length + + return numel + + +@torch.no_grad() +def _alloc_gathered_grad(params, param_to_state, rank, compute_stream): + """ + Pre-allocate gathered_grad buffer on compute_stream + before launching all2all gather + """ + with torch.cuda.stream(compute_stream): + for p in params: + state = param_to_state[id(p)] + if rank == state.worker_rank: + state.gathered_grad = torch.empty(p.shape, + dtype=COMM_DTYPE, + device="cuda") + else: + state.gathered_grad = None + + alloc_event = torch.cuda.Event() + alloc_event.record(compute_stream) + return alloc_event + + +@torch.no_grad() +def _all2all_gather(params, param_to_state, rank, comm_stream, none_grad, + alloc_event): + """ + All2all gathers shards so each owner rank reconstructs its full gradient + """ + with torch.cuda.stream(comm_stream): + process_group = param_to_state[id(params[0])].process_group + num_ranks = dist.get_world_size(group=process_group) + + # Construct sending buffers + per_dst = [[] for _ in range(num_ranks)] + send_counts = [0] * num_ranks + + for p in params: + state = param_to_state[id(p)] + dst = state.worker_rank + assert dst < num_ranks + shard_elems = numel_for_rank(p, rank, state) + g = p.grad + g = g.to_local().to(COMM_DTYPE).contiguous() + assert g.numel() == shard_elems + per_dst[dst].append(g.view(-1)) + 
send_counts[dst] += shard_elems + + assert any( + len(v) > 0 for v in per_dst + ), "At least one destination rank must receive a sharded tensor" + # list[list[Tensor]] -> list[Tensor] + per_dst = [t for dst in per_dst for t in dst] + + send_buf = torch.cat(per_dst, dim=0) + + owned_params = [ + p for p in params if param_to_state[id(p)].worker_rank == rank + ] + + # Compute receive sizes and allocate receiving buffers + recv_counts = [0] * num_ranks + + for src in range(num_ranks): + total = 0 + for p in owned_params: + state = param_to_state[id(p)] + assert state.worker_rank == rank + total += numel_for_rank(p, src, state) + recv_counts[src] = total + + recv_total = sum(recv_counts) + recv_buf = torch.empty(recv_total, dtype=COMM_DTYPE, device="cuda") + + #All2All + logger.debug(f"send_buf size: {send_buf.numel()}, " + f"recv_buf size: {recv_buf.numel()}, " + f"recv_counts: {recv_counts}, " + f"send_counts: {send_counts}, " + f"process_group: {str(process_group)}") + dist.all_to_all_single( + recv_buf, + send_buf, + output_split_sizes=recv_counts, + input_split_sizes=send_counts, + group=process_group, + ) + + # Reconstructs gathered grad from the received buffer + # + # recv_buf (num ranks = 3) + # + # From rank 0 From rank 1 From rank 2 + # | p1_0, p2_0, p3_0 | p1_1, p2_1, p3_1 | p1_2, p2_2, p3_2 | + # + # Outer loop: + # rank 0 -> rank 1 -> rank2 + # + # Inner loop: + # p1_n -> p2_n -> p3_n + + comm_stream.wait_event(alloc_event) + + off = 0 + for src in range(num_ranks): + if recv_counts[src] == 0: + continue + + block = recv_counts[src] + inner_off = 0 + for p in owned_params: + state = param_to_state[id(p)] + assert state.worker_rank == rank + + # get the slice of the full dtensor corresponding to rank src. 
+ slices = get_slices_of_dtensor(state.gathered_grad, src, + state.shard_mesh, + state.shard_placements) + + dst = state.gathered_grad[slices] + assert dst._base is state.gathered_grad + + n = dst.numel() + assert n > 0 + + sg = recv_buf.narrow(0, off + inner_off, n) + sg = sg.reshape_as(dst) + dst.copy_(sg) + + inner_off += n + off += block + + for p in params: + state = param_to_state[id(p)] + if state.worker_rank == rank: + state.gather_event = torch.cuda.Event() + state.gather_event.record(comm_stream) + else: + state.gathered_grad = None + state.gather_event = None + if none_grad: + p.grad = None + + +@torch.no_grad() +def _compute_u(p, state, steps, rank, compute_stream): + """ + On worker_rank, compute the orthogonalized update using Newton-Schulz iteration. + """ + with torch.cuda.stream(compute_stream): + if rank == state.worker_rank: + if state.gather_event is None: + raise RuntimeError("Gather event must be set before compute.") + compute_stream.wait_event(state.gather_event) + u = _zeropower_via_newtonschulz5(state.gathered_grad, steps) + state.gathered_grad = None + state.computed_u = u + state.compute_event = torch.cuda.Event() + state.compute_event.record() + else: + state.computed_u = None + state.compute_event = None + + +@torch.no_grad() +def _alloc_scattered_u(params, param_to_state, rank, compute_stream): + """ + Pre-allocate scattered_u buffer on compute_stream + before launching all2all gather + """ + with torch.cuda.stream(compute_stream): + for p in params: + state = param_to_state[id(p)] + state.scattered_u = torch.empty_like(p.to_local(), + dtype=COMM_DTYPE) + + alloc_event = torch.cuda.Event() + alloc_event.record(compute_stream) + return alloc_event + + +def _all2all_scatter(params, param_to_state, rank, comm_stream, alloc_event): + """ + All2all scatters full gradients to all ranks + """ + with torch.cuda.stream(comm_stream): + process_group = param_to_state[id(params[0])].process_group + num_ranks = 
dist.get_world_size(group=process_group) + owned_params = [ + p for p in params if param_to_state[id(p)].worker_rank == rank + ] + + # Construct sending buffer + per_dst = [[] for _ in range(num_ranks)] + send_counts = [0] * num_ranks + + if owned_params: + for p in owned_params: + state = param_to_state[id(p)] + if state.compute_event is None: + raise RuntimeError( + "Compute event must be set before scatter.") + comm_stream.wait_event(state.compute_event) + state.gathered_grad = None + + assert state.computed_u is not None + + u_full = state.computed_u.to(COMM_DTYPE).contiguous() + + offset = 0 + for dst in range(num_ranks): + # get the slice of the full tensor corresponding to rank dst. + slices = get_slices_of_dtensor(u_full, dst, + state.shard_mesh, + state.shard_placements) + su = u_full[slices].flatten() + + n = su.numel() + assert n > 0 + + per_dst[dst].append(su) + send_counts[dst] += n + offset += n + + assert offset == u_full.numel() + + lengths = [len(v) for v in per_dst] + if all(l > 0 for l in lengths): + assert all( + l == lengths[0] for l in lengths + ), "All destination ranks must have the same number of sharded tensor" + # list[list[Tensor]] -> list[Tensor] + per_dst = [t for dst in per_dst for t in dst] + send_buf = torch.cat(per_dst, dim=0) + else: + # all_to_all requires participation from all ranks + # Even non-owner ranks must join the collective call + send_buf = torch.empty(0, dtype=COMM_DTYPE, device="cuda") + + # Compute receive sizes and allocate receiving buffers + recv_counts = [0] * num_ranks + + for src in range(num_ranks): + total = 0 + for p in params: + state = param_to_state[id(p)] + if state.worker_rank != src: + continue + total += numel_for_rank(p, rank, state) + recv_counts[src] = total + + recv_total = sum(recv_counts) + assert recv_total > 0 + recv_buf = torch.empty(recv_total, dtype=COMM_DTYPE, device="cuda") + + #All2All + dist.all_to_all_single( + recv_buf, + send_buf, + output_split_sizes=recv_counts, + 
input_split_sizes=send_counts, + group=process_group, + ) + + # Copy to pre-allocated scattered_u buffer from the received buffer + # + # recv_buf (num ranks = 3, local_rank = 0) + # + # From rank 0 From rank 1 From rank 2 + # | p1_0, p2_0, p3_0 | p4_0 | p5_0, p6_0 | + # + # Outer loop: + # rank 0 -> rank 1 -> rank2 + # + # Inner loop: + # src(0) : p1_0 -> p2_0 -> p3_0 + # src(1) : p4_0 + # src(2) : p5_0 -> p6_0 + + comm_stream.wait_event(alloc_event) + + off = 0 + for src in range(num_ranks): + block = recv_counts[src] + if block == 0: + continue + + inner_off = 0 + for p in params: + state = param_to_state[id(p)] + if state.worker_rank != src: + continue + n = numel_for_rank(p, rank, state) + assert n > 0 + + flat_local = recv_buf.narrow(0, off + inner_off, + n).view_as(p.to_local()) + state.scattered_u.copy_(flat_local) + + state.scatter_event = torch.cuda.Event() + state.scatter_event.record(comm_stream) + inner_off += n + + assert inner_off == block + off += block + + +def _update_param(p, state, lr, adjusted_lr, weight_decay, rank, + compute_stream): + """ + Update sharded parameter p with the scattered_u. + Only worker_rank frees computed_u. 
+ """ + with torch.cuda.stream(compute_stream): + if state.scatter_event is None: + raise RuntimeError("Scatter event must be set before update") + compute_stream.wait_event(state.scatter_event) + u_dtensor = DTensor.from_local( + state.scattered_u, + placements=p.placements, + device_mesh=p.device_mesh, + ) + + state.scattered_u = u_dtensor + + if rank == state.worker_rank: + # Free computed_u + state.computed_u = None + + Muon._update_p(p, state.scattered_u, lr, adjusted_lr, weight_decay) + state.scattered_u = None + u_dtensor = None + + scales_full = Muon._compute_scales( + p, + state.qk_clip_state) if state.qk_clip_state is not None else None + if scales_full is not None: + # Have to slice scales_full among dim 0 + weight_slices = get_slices_of_dtensor(p, rank, state.shard_mesh, + state.shard_placements) + ratio = p.shape[0] // scales_full.shape[0] + scales_slice = slice( + None if weight_slices[0].start is None else + weight_slices[0].start // ratio, + None if weight_slices[0].stop is None else + weight_slices[0].stop // ratio, + None, + ) + + scales_local = scales_full[scales_slice] + scales_local = DTensor.from_local( + scales_local, + placements=p.placements, + device_mesh=p.device_mesh, + ) + Muon._qk_clip(p, scales_local, state.qk_clip_state.head_dim) + + +def default_is_muon(name, x): + skip_keys = ["embed_tokens", "lm_head", "tok_embeddings", "output"] + return x.ndim >= 2 and not any(key in name for key in skip_keys) + + +def get_default_muon_param_groups(model, is_muon_func=default_is_muon): + muon_params, muon_names = [], [] + non_muon_params = [] + + for n, p in model.named_parameters(): + if not p.requires_grad: + continue + if is_muon_func(n, p): + muon_params.append(p) + muon_names.append(n) + else: + non_muon_params.append(p) + + return [ + { + "params": muon_params, + "names": muon_names, + "use_muon": True, + }, + { + "params": non_muon_params, + "use_muon": False, + }, + ] + + +def parse_qk_layer(name: str) -> tuple[str | None, int]: + """ + 
Parse a parameter name to check if it is a query/key projection layer + ('wq', 'wk', 'q_proj', 'k_proj') and return (kind, layer_index). + + Returns: + (kind, layer_idx) or (None, -1) if not matched. + + Example: + 'model.3.attn.wq.weight' -> ('wq', 3) + 'model.5.attn.wk.weight' -> ('wk', 5) + 'model.2.attn.q_proj.weight' -> ('q_proj', 2) + 'model.7.attn.k_proj.weight' -> ('k_proj', 7) + 'model.4.attn.v_proj.weight' -> (None, -1) + """ + parts = name.split('.') + if len(parts) < 3: + return None, -1 + + kind = parts[-2] + + layer_idx = -1 + for part in reversed(parts): + if part.isdigit(): + layer_idx = int(part) + break + + if kind in ('wq', 'wk', 'q_proj', 'k_proj'): + return kind, layer_idx + + return None, -1 + + +@dataclass +class QKClipInfo: + """Per-parameter dynamic info computed from config + runtime logits.""" + kind: str | None # 'wq'/'q_proj' or 'wk'/'k_proj' or None + indices: list[int] # which heads to consider for clipping + head_dim: int # from config + threshold: float # from config + logit: torch.Tensor | None + + +class Muon(torch.optim.Optimizer): + """ + Muon - MomentUm Orthogonalized by Newton-schulz + + Muon internally runs standard SGD-momentum, and then performs an orthogonalization post- + processing step, in which each 2D parameter's update is replaced with the nearest orthogonal + matrix. To efficiently orthogonalize each update, we use a Newton-Schulz iteration, which has + the advantage that it can be stably run in bfloat16 on the GPU. + + Some warnings: + - We believe this optimizer is unlikely to work well for training with small batch size. + - We believe it may not work well for finetuning pretrained models, but we haven't tested this. + + Arguments: + model: The model to be optimized by Muon. + is_muon_func: A function that takes a parameter and its name, and returns whether the parameter should be optimized by Muon. + lr: The learning rate. The updates will have spectral norm of `lr`. 
(0.02 is a good default) + momentum: The momentum used by the internal SGD. (0.95 is a good default) + nesterov: Whether to use Nesterov-style momentum in the internal SGD. (recommended) + ns_steps: The number of Newton-Schulz iterations to run. (6 is probably always enough) + weight_decay: The weight decay for Muon and AdamW. + {0, 1}-D or are detected as being the embed or lm_head will be optimized by AdamW as well. + adamw_lr: The learning rate for the internal AdamW. + adamw_betas: The betas for the internal AdamW. + adamw_eps: The epsilon for the internal AdamW. + none_grad: Whether to set p.grad to None after gathering the gradients. This can save memory. + debug: Whether to print debug information. + clip_info : Configuration for QK clipping. Expected keys: + - "q_indices" (list[int]): Indices of query heads to consider. + - "k_indices" (list[int]): Indices of key heads to consider. + - "head_dim" (int): Dimensionality of each attention head. + - "threshold" (float): Threshold value; heads whose QK logits exceed + this value will be scaled down. + Default is: + { + "q_indices": [], + "k_indices": [], + "head_dim": 128, + "threshold": 100 + } + warmup_step : How many all2all gather, compute operations are launched in advance + before the corresponding all2all scatter steps begin. + A higher warmup_step increases memory usage but can improve + performance by overlapping communication. + Parallel muon only. + chunk_size : Batch size of parameters to process in each + all2all gather/compute/scatter step. + Use shard ranks * DEFAULT_CHUNK_SIZE_RATIO when -1 is specified. + use_distributed_muon: Use distributed muon by Liu et al. (2024). + For testing purpose only. 
+ small_param_numel_threshold: Threshold for classifying parameters as small and falling back to distributed Muon + """ + + def __init__(self, + params, + lr=1e-3, + momentum=0.95, + nesterov=True, + ns_steps=5, + weight_decay=0.1, + adamw_betas=(0.9, 0.95), + adamw_eps=1e-8, + none_grad=True, + debug=False, + clip_config={ + "q_indices": [], + "k_indices": [], + "head_dim": 128, + "threshold": 100 + }, + warmup_step=5, + chunk_size=-1, + use_distributed_muon=False, + small_param_numel_threshold=65536): + defaults = dict( + lr=lr, + weight_decay=weight_decay, + momentum=momentum, + nesterov=nesterov, + ns_steps=ns_steps, + adamw_betas=adamw_betas, + adamw_eps=adamw_eps, + none_grad=none_grad, + use_muon=True, + ) + error_message = "The key 'use_muon' is not set in parameter group {idx}. Assuming all parameters in the group will use muon optimization, which may lead to unexpected behavior." + instruction_code = "\n\n please follow this code snippet \n```optimizer = get_kernel('motif-technologies/optimizer')\n\n\nparams = optimizer.muon.get_default_muon_param_groups(model)\n\noptim = optimizer.Muon(params, ...)```" + + if isinstance(params, types.GeneratorType): + raise ValueError(error_message.format(idx=0) + instruction_code) + for _idx, param_group in enumerate(params): + if param_group.get("use_muon", None) is None: + raise ValueError( + error_message.format(idx=_idx) + instruction_code) + + super().__init__(params, defaults) + + self.rank = None + + self.comm_stream = torch.cuda.Stream() + self.compute_stream = torch.cuda.Stream() + self.debug = debug + self.clip_config = clip_config + self.warmup_step = warmup_step + self.chunk_size = chunk_size + self.use_distributed_muon = use_distributed_muon + self.small_param_numel_threshold = small_param_numel_threshold + + def _calc_flops(self, G, steps): + assert len(G.shape) == 2 + M, N = G.shape + if M > N: + M, N = N, M + + return steps * ((M**3) * 2 + (M**2 * N) * 4 + M * N * 2 + M**2 * 3) + + def 
adjust_lr_for_muon(self, lr, param_shape): + A, B = param_shape[:2] + # We adjust the learning rate and weight decay based on the size of the parameter matrix + # as describted in the paper + adjusted_ratio = 0.2 * math.sqrt(max(A, B)) + adjusted_lr = lr * adjusted_ratio + return adjusted_lr + + def set_rank_once(self, rank): + if self.rank is None: + self.rank = rank + else: + assert self.rank == rank + + def get_shard_mesh(self, p): + """ + Get the shard mesh for a parameter p on the given rank. + """ + assert isinstance( + p, DTensor), "Parallel Muon only supports DTensor parameters." + + shard_mesh, shard_pg, shard_placements = construct_shard_mesh( + p.placements, p.device_mesh) + + # set rank with the local rank in the shard process group + self.set_rank_once(dist.get_rank(group=shard_pg)) + + return shard_mesh, shard_pg, shard_placements + + def init_state_and_assign_params(self, names, params, group, qk_logits): + param_to_state = {} + param_to_flops = {} + + total_flops = 0 + for p in params: + g = p.grad + if g is None: + continue + assert g.ndim == 2, "Muon only supports 2D parameters." 
+ + flops = self._calc_flops(g, group["ns_steps"]) + param_to_flops[id(p)] = flops + total_flops += flops + + if self.debug: + print(f"Total TFLOPs for Muon: {total_flops / 1e12:.2f} TFLOPs", + flush=True) + + paired = list(zip(names, params)) + + paired_sorted = sorted(paired, + key=lambda x: param_to_flops[id(x[1])], + reverse=True) + + names_sorted, params_sorted = zip(*paired_sorted) + ordered_names = list(names_sorted) + ordered_params = list(params_sorted) + + round_robin = 0 + mesh = ordered_params[0].device_mesh + placements = ordered_params[0].placements + + shard_mesh, shard_pg, shard_placements = self.get_shard_mesh( + ordered_params[0]) + shard_mesh_flattened = shard_mesh.mesh.flatten() + num_ranks = dist.get_world_size(group=shard_pg) + + for n, p in zip(ordered_names, ordered_params): + if mesh != p.device_mesh: + raise ValueError("All parameters must be on the same mesh.") + if placements != p.placements: + raise ValueError("All parameters must have same placements.") + + worker_rank = shard_mesh_flattened[round_robin].item() % num_ranks + round_robin = (round_robin + 1) % len(shard_mesh_flattened) + qk_clip_state = self.get_qk_clip_info(n, qk_logits) + + param_to_state[id(p)] = _muon_state( + worker_rank=worker_rank, + process_group=shard_pg, + shard_mesh=shard_mesh, + shard_placements=shard_placements, + name=n, + qk_clip_state=qk_clip_state, + ) + + return param_to_state, ordered_params + + def base(self, names, params, group, lr, weight_decay, momentum, + qk_logits): + # generate weight updates in distributed fashion + for n, p in zip(names, params): + g = p.grad + if g is None: + continue + if g.ndim > 2: + g = g.view(g.size(0), -1) + assert g is not None + + g = self._update_g(p, g, group, momentum) + + u = _zeropower_via_newtonschulz5(g.to(COMM_DTYPE), + steps=group["ns_steps"]) + + adjusted_lr = self.adjust_lr_for_muon(lr, p.shape) + Muon._update_p(p, u, lr, adjusted_lr, weight_decay) + + qk_clip_state = self.get_qk_clip_info(n, qk_logits) + 
+ scales_full = self._compute_scales( + p, qk_clip_state) if qk_clip_state is not None else None + if scales_full is not None: + Muon._qk_clip(p, scales_full, qk_clip_state.head_dim) + + def distributed_muon( + self, + names: list[str], + params: list[torch.nn.Parameter], + group: dict[str, Any], + lr: float, + weight_decay: float, + momentum: float, + qk_logits: list[torch.Tensor | DTensor] | None, + ): + """ Implementation of Distributed Muon by Liu et al. """ + + for n, p in zip(names, params): + g = p.grad + if g is None: + continue + if g.ndim > 2: + g = g.view(g.size(0), -1) + assert g is not None + + g = self._update_g(p, g, group, momentum) + + # Gather G + if isinstance(p.data, DTensor): + g_full = g.full_tensor() + p_full = p.data.full_tensor() + else: + g_full = g + p_full = p + + u_full = _zeropower_via_newtonschulz5(g_full.to(COMM_DTYPE), + steps=group["ns_steps"]) + + adjusted_lr = self.adjust_lr_for_muon(lr, p_full.shape) + Muon._update_p(p_full, u_full, lr, adjusted_lr, weight_decay) + + qk_clip_state = self.get_qk_clip_info(n, qk_logits) + + scales_full = self._compute_scales( + p_full, qk_clip_state) if qk_clip_state is not None else None + + if scales_full is not None: + Muon._qk_clip(p_full, scales_full, qk_clip_state.head_dim) + + if isinstance(p.data, DTensor): + ndims = len(p.device_mesh.mesh.shape) + p_replicate = DTensor.from_local( + p_full, + device_mesh=p.device_mesh, + placements=[Replicate() for _ in range(ndims)], + ) + + p_sharded = p_replicate.redistribute( + device_mesh=p.device_mesh, + placements=p.placements, + ) + + p.copy_(p_sharded) + + def _update_g(self, p, g, group, momentum): + # calc update + state = self.state[p] + buf = state.setdefault("momentum_buffer", torch.zeros_like(g)) + torch.add(g, buf, alpha=momentum, out=buf) + if group["nesterov"]: + g.add_(buf, alpha=momentum) + return g + return buf + + @staticmethod + def _update_p(p, u, lr, adjusted_lr, weight_decay): + if isinstance(p, torch.nn.Parameter): + # apply 
weight decay + p.data.mul_(1 - lr * weight_decay) + # apply update + p.data.add_(u, alpha=-adjusted_lr) + else: + p.mul_(1 - lr * weight_decay) + p.add_(u, alpha=-adjusted_lr) + + def get_qk_clip_info(self, n, qk_logits): + if self.clip_config is None: + return None + + head_dim = self.clip_config.get('head_dim') + threshold = self.clip_config.get('threshold') + kind, layer_idx = parse_qk_layer(n) + + logit, indices = None, [] + if qk_logits is not None and kind is not None: + logit = qk_logits[layer_idx] + indices_key = 'q_indices' if 'q' in kind else 'k_indices' + indices = self.clip_config.get(indices_key, []) or [] + + if isinstance(logit, DTensor): + # In TP settings, qk_logits may be DTensor + # We convert it to full tensor here for simplicity + logit = logit.full_tensor() + + return QKClipInfo( + kind=kind, + indices=indices, + head_dim=head_dim, + threshold=threshold, + logit=logit, + ) + + @staticmethod + def _compute_scales(p, qk_clip_state): + kind = qk_clip_state.kind + indices = qk_clip_state.indices + head_dim = qk_clip_state.head_dim + threshold = qk_clip_state.threshold + logit = qk_clip_state.logit + + H_global = p.shape[0] // head_dim + scales_full = torch.ones(H_global, device=p.data.device) + scaling = 0 + + for logit_idx, head_idx in enumerate(indices): + v_ele = float(logit[logit_idx]) + if v_ele > threshold: + new_scale = math.sqrt(threshold / v_ele) + if new_scale < scales_full[head_idx]: + scales_full[head_idx] = new_scale + logger.info( + f"[{kind}] Head {head_idx} exceeded threshold " + f"(value={v_ele:.4f}, threshold={threshold:.4f}) -> applying scale={new_scale:.4f}" + ) + scaling += 1 + + return scales_full if scaling > 0 else None + + @staticmethod + def _qk_clip(p, scales, head_dim): + if isinstance(p, torch.nn.Parameter): + W = p.data.view(-1, head_dim, p.data.shape[1]) + W.mul_(scales.view(-1, 1, 1)) + else: + W = p.view(-1, head_dim, p.shape[1]) + W.mul_(scales.view(-1, 1, 1)) + + def parallel(self, names, params, group, lr, 
weight_decay, momentum, + qk_logits): + """ + Perform a parallel optimization step using Muon. + """ + + for p in params: + g = p.grad + if g is None: + continue + if g.ndim > 2: + g = g.view(g.size(0), -1) + + # Update g in the local rank + g = self._update_g( + p, + g, + group, + momentum=momentum, + ) + p.grad = g + + param_to_state, ordered_params = self.init_state_and_assign_params( + names, params, group, qk_logits) + + assert self.rank is not None + + def enqueue_all2all_gather(start_idx, chunk_size): + target_params = ordered_params[start_idx:start_idx + chunk_size] + if target_params: + alloc_event = _alloc_gathered_grad(target_params, + param_to_state, self.rank, + self.compute_stream) + _all2all_gather(target_params, param_to_state, self.rank, + self.comm_stream, group["none_grad"], + alloc_event) + + def enqueue_computes(start_idx, chunk_size): + for p in ordered_params[start_idx:start_idx + chunk_size]: + state = param_to_state[id(p)] + _compute_u(p, state, group["ns_steps"], self.rank, + self.compute_stream) + + def enqueue_all2all_scatter(start_idx, chunk_size): + target_params = ordered_params[start_idx:start_idx + chunk_size] + if target_params: + alloc_event = _alloc_scattered_u(target_params, param_to_state, + self.rank, + self.compute_stream) + _all2all_scatter(target_params, param_to_state, self.rank, + self.comm_stream, alloc_event) + + def enqueue_update_param(start_idx, chunk_size): + for p in ordered_params[start_idx:start_idx + chunk_size]: + state = param_to_state[id(p)] + adjusted_lr = self.adjust_lr_for_muon(lr, p.shape) + _update_param(p, state, lr, adjusted_lr, weight_decay, + self.rank, self.compute_stream) + + if self.chunk_size == -1: + shard_ranks = dist.get_world_size(param_to_state[id( + params[0])].process_group) + chunk_size = shard_ranks * DEFAULT_CHUNK_SIZE_RATIO + elif self.chunk_size > 0: + chunk_size = self.chunk_size + else: + raise ValueError("chunk_size must be -1 or a positive integer.") + + # Wait grad update + 
self.comm_stream.wait_stream(torch.cuda.current_stream()) + + warmup_step = self.warmup_step + for i in range(0, warmup_step): + enqueue_all2all_gather(i * chunk_size, chunk_size) + enqueue_computes(i * chunk_size, chunk_size) + + for i in range(0, len(params) + chunk_size - 1, chunk_size): + enqueue_all2all_scatter(i, chunk_size) + enqueue_all2all_gather(i + warmup_step * chunk_size, chunk_size) + enqueue_update_param(i, chunk_size) + enqueue_computes(i + warmup_step * chunk_size, chunk_size) + + # Wait the last update_param to finish + torch.cuda.current_stream().wait_stream(self.compute_stream) + + @staticmethod + def _fused_adamw( + params: list[torch.Tensor], + grads: list[torch.Tensor], + exp_avgs: list[torch.Tensor], + exp_avg_sqs: list[torch.Tensor], + max_exp_avg_sqs: list[torch.Tensor], + state_steps: list[torch.Tensor], + amsgrad: bool, + beta1: float, + beta2: float, + lr: float | torch.Tensor, + weight_decay: float, + eps: float, + maximize: bool, + ) -> None: + if not params: + return + + # We only shuffle around the lr when it is a Tensor and on CUDA, otherwise, we prefer + # treating it as a scalar. 
+ lr_dict: DeviceDict | None = ({ + lr.device: lr + } if isinstance(lr, torch.Tensor) and str(lr.device) != "cpu" else + None) + grouped_tensors = torch.optim.Optimizer._group_tensors_by_device_and_dtype( + [ + params, grads, exp_avgs, exp_avg_sqs, max_exp_avg_sqs, + state_steps + ] # type: ignore[list-item] + ) + for (device, _), ( + ( + device_params_, + device_grads_, + device_exp_avgs_, + device_exp_avg_sqs_, + device_max_exp_avg_sqs, + device_state_steps_, + ), + _, + ) in grouped_tensors.items(): + device_params = cast(list[torch.Tensor], device_params_) + device_grads = cast(list[torch.Tensor], device_grads_) + device_exp_avgs = cast(list[torch.Tensor], device_exp_avgs_) + device_exp_avg_sqs = cast(list[torch.Tensor], device_exp_avg_sqs_) + device_state_steps = cast(list[torch.Tensor], device_state_steps_) + + if lr_dict is not None and device not in lr_dict: + lr_dict[device] = lr.to( + device=device, + non_blocking=True) # type: ignore[union-attr] + lr = lr_dict[device] + torch._foreach_add_(device_state_steps, 1) + func = torch._fused_adamw_ + func( + device_params, + device_grads, + device_exp_avgs, + device_exp_avg_sqs, + device_max_exp_avg_sqs, # type: ignore[arg-type] + device_state_steps, + amsgrad=amsgrad, + lr=lr, # type: ignore[arg-type] + beta1=beta1, + beta2=beta2, + weight_decay=weight_decay, + eps=eps, + maximize=maximize, + ) + + def _step_muon(self, group, qk_logits=None): + params = group["params"] + lr = group["lr"] + weight_decay = group["weight_decay"] + momentum = group["momentum"] + names = group["names"] + + param_dtensors = [] + name_dtensors = [] + + param_tensors = [] + name_tensors = [] + + param_dtensors_small = [] + name_dtensors_small = [] + + if self.use_distributed_muon: + self.distributed_muon(names=names, + params=params, + group=group, + lr=lr, + weight_decay=weight_decay, + momentum=momentum, + qk_logits=qk_logits) + return + + # For simplicity, we use distributed Muon for small parameters + # whose number of elements is 
below a threshold. + for n, p in zip(names, params): + if p is None or p.grad is None: + continue + if isinstance(p.data, DTensor): + if all( + isinstance(placement, Replicate) + for placement in p.placements): + param_tensors.append(p) + name_tensors.append(n) + elif p.data.numel() <= self.small_param_numel_threshold: + param_dtensors_small.append(p) + name_dtensors_small.append(n) + else: + param_dtensors.append(p) + name_dtensors.append(n) + elif isinstance(p.data, torch.Tensor): + param_tensors.append(p) + name_tensors.append(n) + else: + raise TypeError(f"Unsupported parameter type: {type(p.data)}") + + logger.debug( + f"[Muon] {len(param_dtensors)} DTensors, {len(param_tensors)} Tensors, " + f"{len(param_dtensors_small)} Small DTensors") + + def group_dtensors(dtensors, names): + # To support different placements, we group parameters by placements + # and run parallel Muon on each group. + + placement_to_params = defaultdict(lambda: ([], [])) + # type: dict[tuple[Placement, DeviceMesh], tuple[list[str], list[DTensor]]] + + assert len(dtensors) == len(names) + for p, n in zip(dtensors, names): + placement_to_params[tuple([p.placements, + p.device_mesh])][0].append(n) + placement_to_params[tuple([p.placements, + p.device_mesh])][1].append(p) + return placement_to_params + + if len(param_dtensors_small) > 0: + if not dist.is_initialized(): + raise RuntimeError( + "Parallel Muon requires torch.distributed to be initialized." + ) + + self.distributed_muon( + params=param_dtensors_small, + names=name_dtensors_small, + group=group, + lr=lr, + weight_decay=weight_decay, + momentum=momentum, + qk_logits=qk_logits, + ) + + if len(param_dtensors) > 0: + if not dist.is_initialized(): + raise RuntimeError( + "Parallel Muon requires torch.distributed to be initialized." 
+ ) + + dtensor_group = group_dtensors(param_dtensors, name_dtensors) + for _, (names, params) in dtensor_group.items(): + self.parallel( + names, + params, + group, + lr=lr, + weight_decay=weight_decay, + momentum=momentum, + qk_logits=qk_logits, + ) + + if len(param_tensors) > 0: + self.base( + name_tensors, + param_tensors, + group, + lr=lr, + weight_decay=weight_decay, + momentum=momentum, + qk_logits=qk_logits, + ) + + def _step_adamw_params(self, params, group): + params_with_grads = [] + grads = [] + moment1 = [] + moment2 = [] + max_exp_avg_sqs = [] + state_steps = [] + lr = group["lr"] + beta1, beta2 = group["adamw_betas"] + eps = group["adamw_eps"] + weight_decay = group["weight_decay"] + + for p in params: + g = p.grad + if g is None: + continue + state = self.state[p] + params_with_grads.append(p) + grads.append(g) + if "step" not in state: + state["step"] = (torch.zeros((), + dtype=torch.float32, + device=p.device)) + state["moment1"] = torch.zeros_like(g) + state["moment2"] = torch.zeros_like(g) + moment1.append(state["moment1"]) + moment2.append(state["moment2"]) + if not isinstance(state["step"], torch.Tensor): + step_tensor = torch.tensor(state["step"], + dtype=torch.float32, + device=p.device) + else: + step_tensor = state["step"] + state_steps.append(step_tensor) + + self._fused_adamw( + params_with_grads, + grads, + moment1, + moment2, + max_exp_avg_sqs, + state_steps, + amsgrad=False, + beta1=beta1, + beta2=beta2, + lr=lr, + weight_decay=weight_decay, + eps=eps, + maximize=False, + ) + + def _step_adamw(self, group): + params = group["params"] + + # group params with it's type and placement + placement_to_params: dict[tuple[Placement | type, + DeviceMesh | None]] = defaultdict(list) + for p in params: + match p: + case DTensor(): + placement_to_params[tuple([p.placements, + p.device_mesh])].append(p) + case torch.Tensor(): + placement_to_params[tuple([torch.Tensor, None])].append(p) + + for params in placement_to_params.values(): + 
self._step_adamw_params(params, group) + + @torch.no_grad + def step(self, closure=None, qk_logits=None): + """Perform a single optimization step. + + Args: + closure (Callable, optional): A closure that reevaluates the model + and returns the loss. + qk_logits (dict[int, Tensor], optional): A dictionary mapping layer indices + to 1D tensors of shape (num_heads,), representing the maximum + QK logits across all tokens, computed as + (1 / sqrt(head_dim)) * (Q @ K^T). + """ + loss = None + if closure is not None: + with torch.enable_grad(): + loss = closure() + + for group in self.param_groups: + if group["use_muon"]: + self._step_muon(group, qk_logits=qk_logits) + else: + self._step_adamw(group) + + return loss diff --git a/build/torch29-cxx11-cu126-x86_64-linux/optimizer/__init__.py b/build/torch29-cxx11-cu126-x86_64-linux/optimizer/__init__.py index 239c7a65f8293e7d0df28f05fce645af56d628c0..03dbc1afe1cf156661a2b1b22003cd5f599a0309 100644 --- a/build/torch29-cxx11-cu126-x86_64-linux/optimizer/__init__.py +++ b/build/torch29-cxx11-cu126-x86_64-linux/optimizer/__init__.py @@ -1,5 +1,26 @@ -from .muon import Muon +import ctypes +import sys -__all__ = [ - "Muon", -] +import importlib +from pathlib import Path +from types import ModuleType + +def _import_from_path(file_path: Path) -> ModuleType: + # We cannot use the module name as-is, after adding it to `sys.modules`, + # it would also be used for other imports. So, we make a module name that + # depends on the path for it to be unique using the hex-encoded hash of + # the path. 
+ path_hash = "{:x}".format(ctypes.c_size_t(hash(file_path.absolute())).value) + module_name = path_hash + spec = importlib.util.spec_from_file_location(module_name, file_path) + if spec is None: + raise ImportError(f"Cannot load spec for {module_name} from {file_path}") + module = importlib.util.module_from_spec(spec) + if module is None: + raise ImportError(f"Cannot load module {module_name} from spec") + sys.modules[module_name] = module + spec.loader.exec_module(module) # type: ignore + return module + + +globals().update(vars(_import_from_path(Path(__file__).parent.parent / "__init__.py"))) diff --git a/build/torch29-cxx11-cu126-x86_64-linux/optimizer/_ops.py b/build/torch29-cxx11-cu126-x86_64-linux/optimizer/_ops.py deleted file mode 100644 index 7d598206add1bca142661a3df6c510e3d9575d54..0000000000000000000000000000000000000000 --- a/build/torch29-cxx11-cu126-x86_64-linux/optimizer/_ops.py +++ /dev/null @@ -1,9 +0,0 @@ -import torch -from . import _optimizer_23d68bb_dirty -ops = torch.ops._optimizer_23d68bb_dirty - -def add_op_namespace_prefix(op_name: str): - """ - Prefix op by namespace. 
- """ - return f"_optimizer_23d68bb_dirty::{op_name}" \ No newline at end of file diff --git a/build/torch29-cxx11-cu126-x86_64-linux/optimizer/_optimizer_23d68bb_dirty.abi3.so b/build/torch29-cxx11-cu126-x86_64-linux/optimizer/_optimizer_23d68bb_dirty.abi3.so deleted file mode 100755 index d5bbae6d37395b7e65f64cabcca135df1faac8b3..0000000000000000000000000000000000000000 --- a/build/torch29-cxx11-cu126-x86_64-linux/optimizer/_optimizer_23d68bb_dirty.abi3.so +++ /dev/null @@ -1,3 +0,0 @@ -version https://git-lfs.github.com/spec/v1 -oid sha256:ca847c77875fc19f211a4c8ac217e9664b46c6862aa3234c270aacfea519d0f5 -size 1924376 diff --git a/build/torch29-cxx11-cu126-x86_64-linux/optimizer/distributed/utils.py b/build/torch29-cxx11-cu126-x86_64-linux/optimizer/distributed/utils.py deleted file mode 100644 index 0b4b58bfb329b1c015129e4c4fc99f7bfa2ab30a..0000000000000000000000000000000000000000 --- a/build/torch29-cxx11-cu126-x86_64-linux/optimizer/distributed/utils.py +++ /dev/null @@ -1,174 +0,0 @@ -import torch -import torch.distributed as dist -from torch.distributed import ProcessGroup -from torch.distributed.device_mesh import DeviceMesh -from torch.distributed.tensor import DTensor -from torch.distributed.tensor.placement_types import (Placement, Shard, - _StridedShard) - - -def get_slices_of_dtensor( - target: DTensor | torch.Tensor, - local_rank: int, - shard_mesh: DeviceMesh, - shard_placements: tuple[Placement], -) -> tuple[slice]: - """ - Get the slice of local tensor for a given rank from a tensor. - Args: - target (DTensor | torch.Tensor): The target tensor. - rank (int): The local rank of the shard group. - shard_mesh (DeviceMesh): The shard mesh. It consists of global ranks. - shard_placements (tuple[Placement]): The shard placements. 
- """ - - slices: list[slice] = [slice(0, dim_size) for dim_size in target.size()] - - # find the global rank of the local rank in the shard mesh - rank = sorted(shard_mesh.mesh.flatten().tolist())[local_rank] - - rank_coords = (shard_mesh.mesh == rank).nonzero() - - assert len(rank_coords) == 1 - rank_coords = tuple(rank_coords[0].tolist()) - - assert len(rank_coords) == len(shard_placements) - - # Caution: Assuming replicate-to-shard of the shard mesh goes with - # left-to-right sharding. This is ensured by the sorting logic of - # construct_shard_mesh function. - for i, (rank_coord, - placement) in enumerate(zip(rank_coords, shard_placements)): - assert isinstance(placement, Shard) - - num_ranks = shard_mesh.mesh.shape[i] - - dim = placement.dim - dim_size = (slices[dim].stop - slices[dim].start) - - if dim_size % num_ranks != 0: - raise NotImplementedError( - f"Dimension size {dim_size} is not divisible " - f"by number of ranks {num_ranks} for shard " - f"placement on dim {dim}.") - - shard_size = dim_size // num_ranks - - start = slices[dim].start + rank_coord * shard_size - end = start + shard_size - - assert start < end <= slices[dim].stop - - slices[dim] = slice(start, end) - - return tuple(slices) - - -_ranks_to_dist_cache: dict[tuple[int, ...], tuple[DeviceMesh, ProcessGroup]] = dict() - - -def construct_shard_mesh( - placements: tuple[Placement], - mesh: DeviceMesh, -) -> (DeviceMesh, ProcessGroup, tuple[Placement]): - """ - Construct Shard Mesh and Placements for unsharding. - It removes Replicate placements and constructs a new Mesh and ProcessGroup. - """ - my_rank = dist.get_rank() - - assert mesh.mesh.device.type == 'cpu' - - # Copy mesh to avoid modifying the original mesh - mesh = mesh.mesh.clone() - - # 1. Sort placements. Replicate first, then Shard by dim ascending. - - # For Shard, strided shard comes after regular shard on the same dim - # to preserve left-to-right order of replicate-to-shard. 
- # This is because that strided shard is using stride to represent - # more fine-grained sharding on the same dim. - # Please check the URL below for _StridedShard. - # https://github.com/pytorch/pytorch/blob/v2.8.0/torch/distributed/tensor/placement_types.py#L366 - - def placement_sort_key( - placement_with_index: tuple[float, Placement] - ) -> tuple[int, float, int]: # (dim, split factor, original index) - index, placement = placement_with_index - is_replicate = placement.is_replicate() - is_shard = placement.is_shard() - is_partial = placement.is_partial() - - assert is_replicate or is_shard, f"Unsupported placement type: {type(placement)}" - assert not is_partial, "Partial placement is not supported." - - if is_replicate: - return (-1.0, 0, index) - elif is_shard: - if isinstance(placement, _StridedShard): - return (placement.dim, 1 / placement.split_factor, index) - return (placement.dim, 0, index) - else: - raise TypeError(f"Unknown placement type: {type(placement)}") - - placements_with_index: list[tuple[int, - Placement]] = list(enumerate(placements)) - placements_with_index = sorted(placements_with_index, - key=placement_sort_key) - - sorted_indices, sorted_placements = zip(*placements_with_index) - - # 2. Permute mesh according to sorted placements. - sorted_mesh = mesh.permute(sorted_indices) - - # 3. 
Collect list of shard meshes by removing replicate dims - # For example, (2, 3, 4, 4) with placements [R, R, S(0), S(1)] - # shard_meshes should be list with 2 * 3 = 6 shard meshes of shape (4, 4) - num_replicates = sum(1 for p in sorted_placements if p.is_replicate()) - - # merge replicate dims - # shard_meshes became a list of shard meshes with a length of replicate degree - if num_replicates > 0: - sorted_mesh = sorted_mesh.flatten( - 0, num_replicates - 1) if num_replicates > 1 else sorted_mesh - shard_meshes = list(torch.unbind(sorted_mesh, dim=0)) - else: - shard_meshes = [sorted_mesh] - shard_placements = sorted_placements[num_replicates:] - - # assume all shard placements are different - assert len(shard_placements) == len(set(shard_placements)) - - # 4. Construct ProcessGroups - # Caution: all groups should be created in the same order in all processes, - # even though each process only needs its own group. - - # To use tensor as dict key, convert it to tuple - def tensor_to_tuple(t): - if isinstance(t, torch.Tensor): - t = t.tolist() - if isinstance(t, list): - return tuple(tensor_to_tuple(x) for x in t) - return t - - my_shard_mesh_as_tuple = None - for shard_mesh in shard_meshes: - assert isinstance(shard_mesh, torch.Tensor) - shard_mesh_as_tuple = tensor_to_tuple(shard_mesh) - - if (my_rank == shard_mesh).any().item(): - assert my_shard_mesh_as_tuple is None - my_shard_mesh_as_tuple = shard_mesh_as_tuple - - # update global cache - if shard_mesh_as_tuple not in _ranks_to_dist_cache: - shard_process_group = dist.new_group(shard_mesh.flatten().tolist()) - _ranks_to_dist_cache[shard_mesh_as_tuple] = ( - DeviceMesh(device_type="cuda", mesh=shard_mesh), - shard_process_group, - ) - - my_shard_mesh, my_shard_process_group = _ranks_to_dist_cache[ - my_shard_mesh_as_tuple] - - return my_shard_mesh, my_shard_process_group, shard_placements diff --git a/build/torch29-cxx11-cu126-x86_64-linux/optimizer/muon.py 
b/build/torch29-cxx11-cu126-x86_64-linux/optimizer/muon.py deleted file mode 100644 index cfbcca71741be70048bfd290c62148b2aceda631..0000000000000000000000000000000000000000 --- a/build/torch29-cxx11-cu126-x86_64-linux/optimizer/muon.py +++ /dev/null @@ -1,1240 +0,0 @@ -import logging -import math -import types -from collections import defaultdict -from dataclasses import dataclass -from typing import Any, cast - -import torch -import torch.distributed as dist -from torch.distributed import ProcessGroup -from torch.distributed.device_mesh import DeviceMesh -from torch.distributed.tensor import DTensor, Replicate -from torch.distributed.tensor.placement_types import Placement - -from .distributed.utils import construct_shard_mesh, get_slices_of_dtensor -from .matmul_transpose_triton import matmul_transpose_assign - -logger = logging.getLogger(__name__) - -COMM_DTYPE = torch.bfloat16 -DEFAULT_CHUNK_SIZE_RATIO = 4 - - -# This code snippet is a modified version adapted from the following GitHub repositories: -# https://github.com/KellerJordan/Muon/blob/master/muon.py -# Muon's Newton–Schulz iteration causes high variance in singular values -# Idea: give each iteration its own 3 coefficients and optimize them via gradient descent. -@torch.no_grad() -# matmul_transpose_assign from : https://github.com/nil0x9/flash-muon -def _zeropower_via_newtonschulz5(G, steps): - """ - Newton-Schulz iteration to compute the zeroth power / orthogonalization of G. We opt to use a - quintic iteration whose coefficients are selected to maximize the slope at zero. For the purpose - of minimizing steps, it turns out to be empirically effective to keep increasing the slope at - zero even beyond the point where the iteration no longer converges all the way to one everywhere - on the interval. 
This iteration therefore does not produce UV^T but rather something like US'V^T - where S' is diagonal with S_{ii}' ~ Uniform(0.5, 1.5), which turns out not to hurt model - performance at all relative to UV^T, where USV^T = G is the SVD. - """ - assert len(G.shape) == 2 - assert G.dtype == COMM_DTYPE - X = G # no manual typecast - - if G.size(0) > G.size(1): - X = X.T - # Ensure spectral norm is at most 1 - X = X / (X.norm() + 1e-7) - buf1 = torch.empty(X.size(0), X.size(0), dtype=X.dtype, device=X.device) - buf2 = torch.empty(X.size(0), X.size(0), dtype=X.dtype, device=X.device) - # Perform the NS iterations - for a, b, c in [ - (4.0848, -6.8946, 2.9270), - (3.9505, -6.3029, 2.6377), - (3.7418, -5.5913, 2.3037), - (2.8769, -3.1427, 1.2046), - (2.8366, -3.0525, 1.2012), - ]: - matmul_transpose_assign(X, buf1) - matmul_transpose_assign(buf1, buf2) - buf1.mul_(b).add_(buf2, alpha=c) - X = torch.addmm(X, buf1, X, alpha=1.0, beta=a) - - if G.size(0) > G.size(1): - X = X.T - return X - - -@dataclass -class _muon_state: - # TODO: use Optional - worker_rank: int - process_group: ProcessGroup - shard_mesh: DeviceMesh - shard_placements: tuple[Placement, ...] 
- name: str - qk_clip_state: torch.Tensor | None = None - gathered_grad: torch.Tensor | None = None - scattered_u: DTensor | None = None - computed_u: torch.Tensor | None = None - gather_event: torch.cuda.Event | None = None - compute_event: torch.cuda.Event | None = None - scatter_event: torch.cuda.Event | None = None - - -def numel_for_rank( - param: DTensor, - local_rank: int, - state: _muon_state, -) -> int: - slices = get_slices_of_dtensor( - param, - local_rank, - state.shard_mesh, - state.shard_placements, - ) - - numel = 1 - for s, dim in zip(slices, param.shape): - start, stop, step = s.indices(dim) - length = max(0, (stop - start + (step - 1)) // step) - numel *= length - - return numel - - -@torch.no_grad() -def _alloc_gathered_grad(params, param_to_state, rank, compute_stream): - """ - Pre-allocate gathered_grad buffer on compute_stream - before launching all2all gather - """ - with torch.cuda.stream(compute_stream): - for p in params: - state = param_to_state[id(p)] - if rank == state.worker_rank: - state.gathered_grad = torch.empty(p.shape, - dtype=COMM_DTYPE, - device="cuda") - else: - state.gathered_grad = None - - alloc_event = torch.cuda.Event() - alloc_event.record(compute_stream) - return alloc_event - - -@torch.no_grad() -def _all2all_gather(params, param_to_state, rank, comm_stream, none_grad, - alloc_event): - """ - All2all gathers shards so each owner rank reconstructs its full gradient - """ - with torch.cuda.stream(comm_stream): - process_group = param_to_state[id(params[0])].process_group - num_ranks = dist.get_world_size(group=process_group) - - # Construct sending buffers - per_dst = [[] for _ in range(num_ranks)] - send_counts = [0] * num_ranks - - for p in params: - state = param_to_state[id(p)] - dst = state.worker_rank - assert dst < num_ranks - shard_elems = numel_for_rank(p, rank, state) - g = p.grad - g = g.to_local().to(COMM_DTYPE).contiguous() - assert g.numel() == shard_elems - per_dst[dst].append(g.view(-1)) - 
send_counts[dst] += shard_elems - - assert any( - len(v) > 0 for v in per_dst - ), "At least one destination rank must receive a sharded tensor" - # list[list[Tensor]] -> list[Tensor] - per_dst = [t for dst in per_dst for t in dst] - - send_buf = torch.cat(per_dst, dim=0) - - owned_params = [ - p for p in params if param_to_state[id(p)].worker_rank == rank - ] - - # Compute receive sizes and allocate receiving buffers - recv_counts = [0] * num_ranks - - for src in range(num_ranks): - total = 0 - for p in owned_params: - state = param_to_state[id(p)] - assert state.worker_rank == rank - total += numel_for_rank(p, src, state) - recv_counts[src] = total - - recv_total = sum(recv_counts) - recv_buf = torch.empty(recv_total, dtype=COMM_DTYPE, device="cuda") - - #All2All - logger.debug(f"send_buf size: {send_buf.numel()}, " - f"recv_buf size: {recv_buf.numel()}, " - f"recv_counts: {recv_counts}, " - f"send_counts: {send_counts}, " - f"process_group: {str(process_group)}") - dist.all_to_all_single( - recv_buf, - send_buf, - output_split_sizes=recv_counts, - input_split_sizes=send_counts, - group=process_group, - ) - - # Reconstructs gathered grad from the received buffer - # - # recv_buf (num ranks = 3) - # - # From rank 0 From rank 1 From rank 2 - # | p1_0, p2_0, p3_0 | p1_1, p2_1, p3_1 | p1_2, p2_2, p3_2 | - # - # Outer loop: - # rank 0 -> rank 1 -> rank2 - # - # Inner loop: - # p1_n -> p2_n -> p3_n - - comm_stream.wait_event(alloc_event) - - off = 0 - for src in range(num_ranks): - if recv_counts[src] == 0: - continue - - block = recv_counts[src] - inner_off = 0 - for p in owned_params: - state = param_to_state[id(p)] - assert state.worker_rank == rank - - # get the slice of the full dtensor corresponding to rank src. 
- slices = get_slices_of_dtensor(state.gathered_grad, src, - state.shard_mesh, - state.shard_placements) - - dst = state.gathered_grad[slices] - assert dst._base is state.gathered_grad - - n = dst.numel() - assert n > 0 - - sg = recv_buf.narrow(0, off + inner_off, n) - sg = sg.reshape_as(dst) - dst.copy_(sg) - - inner_off += n - off += block - - for p in params: - state = param_to_state[id(p)] - if state.worker_rank == rank: - state.gather_event = torch.cuda.Event() - state.gather_event.record(comm_stream) - else: - state.gathered_grad = None - state.gather_event = None - if none_grad: - p.grad = None - - -@torch.no_grad() -def _compute_u(p, state, steps, rank, compute_stream): - """ - On worker_rank, compute the orthogonalized update using Newton-Schulz iteration. - """ - with torch.cuda.stream(compute_stream): - if rank == state.worker_rank: - if state.gather_event is None: - raise RuntimeError("Gather event must be set before compute.") - compute_stream.wait_event(state.gather_event) - u = _zeropower_via_newtonschulz5(state.gathered_grad, steps) - state.gathered_grad = None - state.computed_u = u - state.compute_event = torch.cuda.Event() - state.compute_event.record() - else: - state.computed_u = None - state.compute_event = None - - -@torch.no_grad() -def _alloc_scattered_u(params, param_to_state, rank, compute_stream): - """ - Pre-allocate scattered_u buffer on compute_stream - before launching all2all gather - """ - with torch.cuda.stream(compute_stream): - for p in params: - state = param_to_state[id(p)] - state.scattered_u = torch.empty_like(p.to_local(), - dtype=COMM_DTYPE) - - alloc_event = torch.cuda.Event() - alloc_event.record(compute_stream) - return alloc_event - - -def _all2all_scatter(params, param_to_state, rank, comm_stream, alloc_event): - """ - All2all scatters full gradients to all ranks - """ - with torch.cuda.stream(comm_stream): - process_group = param_to_state[id(params[0])].process_group - num_ranks = 
dist.get_world_size(group=process_group) - owned_params = [ - p for p in params if param_to_state[id(p)].worker_rank == rank - ] - - # Construct sending buffer - per_dst = [[] for _ in range(num_ranks)] - send_counts = [0] * num_ranks - - if owned_params: - for p in owned_params: - state = param_to_state[id(p)] - if state.compute_event is None: - raise RuntimeError( - "Compute event must be set before scatter.") - comm_stream.wait_event(state.compute_event) - state.gathered_grad = None - - assert state.computed_u is not None - - u_full = state.computed_u.to(COMM_DTYPE).contiguous() - - offset = 0 - for dst in range(num_ranks): - # get the slice of the full tensor corresponding to rank dst. - slices = get_slices_of_dtensor(u_full, dst, - state.shard_mesh, - state.shard_placements) - su = u_full[slices].flatten() - - n = su.numel() - assert n > 0 - - per_dst[dst].append(su) - send_counts[dst] += n - offset += n - - assert offset == u_full.numel() - - lengths = [len(v) for v in per_dst] - if all(l > 0 for l in lengths): - assert all( - l == lengths[0] for l in lengths - ), "All destination ranks must have the same number of sharded tensor" - # list[list[Tensor]] -> list[Tensor] - per_dst = [t for dst in per_dst for t in dst] - send_buf = torch.cat(per_dst, dim=0) - else: - # all_to_all requires participation from all ranks - # Even non-owner ranks must join the collective call - send_buf = torch.empty(0, dtype=COMM_DTYPE, device="cuda") - - # Compute receive sizes and allocate receiving buffers - recv_counts = [0] * num_ranks - - for src in range(num_ranks): - total = 0 - for p in params: - state = param_to_state[id(p)] - if state.worker_rank != src: - continue - total += numel_for_rank(p, rank, state) - recv_counts[src] = total - - recv_total = sum(recv_counts) - assert recv_total > 0 - recv_buf = torch.empty(recv_total, dtype=COMM_DTYPE, device="cuda") - - #All2All - dist.all_to_all_single( - recv_buf, - send_buf, - output_split_sizes=recv_counts, - 
input_split_sizes=send_counts, - group=process_group, - ) - - # Copy to pre-allocated scattered_u buffer from the received buffer - # - # recv_buf (num ranks = 3, local_rank = 0) - # - # From rank 0 From rank 1 From rank 2 - # | p1_0, p2_0, p3_0 | p4_0 | p5_0, p6_0 | - # - # Outer loop: - # rank 0 -> rank 1 -> rank2 - # - # Inner loop: - # src(0) : p1_0 -> p2_0 -> p3_0 - # src(1) : p4_0 - # src(2) : p5_0 -> p6_0 - - comm_stream.wait_event(alloc_event) - - off = 0 - for src in range(num_ranks): - block = recv_counts[src] - if block == 0: - continue - - inner_off = 0 - for p in params: - state = param_to_state[id(p)] - if state.worker_rank != src: - continue - n = numel_for_rank(p, rank, state) - assert n > 0 - - flat_local = recv_buf.narrow(0, off + inner_off, - n).view_as(p.to_local()) - state.scattered_u.copy_(flat_local) - - state.scatter_event = torch.cuda.Event() - state.scatter_event.record(comm_stream) - inner_off += n - - assert inner_off == block - off += block - - -def _update_param(p, state, lr, adjusted_lr, weight_decay, rank, - compute_stream): - """ - Update sharded parameter p with the scattered_u. - Only worker_rank frees computed_u. 
- """ - with torch.cuda.stream(compute_stream): - if state.scatter_event is None: - raise RuntimeError("Scatter event must be set before update") - compute_stream.wait_event(state.scatter_event) - u_dtensor = DTensor.from_local( - state.scattered_u, - placements=p.placements, - device_mesh=p.device_mesh, - ) - - state.scattered_u = u_dtensor - - if rank == state.worker_rank: - # Free computed_u - state.computed_u = None - - Muon._update_p(p, state.scattered_u, lr, adjusted_lr, weight_decay) - state.scattered_u = None - u_dtensor = None - - scales_full = Muon._compute_scales( - p, - state.qk_clip_state) if state.qk_clip_state is not None else None - if scales_full is not None: - # Have to slice scales_full among dim 0 - weight_slices = get_slices_of_dtensor(p, rank, state.shard_mesh, - state.shard_placements) - ratio = p.shape[0] // scales_full.shape[0] - scales_slice = slice( - None if weight_slices[0].start is None else - weight_slices[0].start // ratio, - None if weight_slices[0].stop is None else - weight_slices[0].stop // ratio, - None, - ) - - scales_local = scales_full[scales_slice] - scales_local = DTensor.from_local( - scales_local, - placements=p.placements, - device_mesh=p.device_mesh, - ) - Muon._qk_clip(p, scales_local, state.qk_clip_state.head_dim) - - -def default_is_muon(name, x): - skip_keys = ["embed_tokens", "lm_head", "tok_embeddings", "output"] - return x.ndim >= 2 and not any(key in name for key in skip_keys) - - -def get_default_muon_param_groups(model, is_muon_func=default_is_muon): - muon_params, muon_names = [], [] - non_muon_params = [] - - for n, p in model.named_parameters(): - if not p.requires_grad: - continue - if is_muon_func(n, p): - muon_params.append(p) - muon_names.append(n) - else: - non_muon_params.append(p) - - return [ - { - "params": muon_params, - "names": muon_names, - "use_muon": True, - }, - { - "params": non_muon_params, - "use_muon": False, - }, - ] - - -def parse_qk_layer(name: str) -> tuple[str | None, int]: - """ - 
Parse a parameter name to check if it is a query/key projection layer - ('wq', 'wk', 'q_proj', 'k_proj') and return (kind, layer_index). - - Returns: - (kind, layer_idx) or (None, -1) if not matched. - - Example: - 'model.3.attn.wq.weight' -> ('wq', 3) - 'model.5.attn.wk.weight' -> ('wk', 5) - 'model.2.attn.q_proj.weight' -> ('q_proj', 2) - 'model.7.attn.k_proj.weight' -> ('k_proj', 7) - 'model.4.attn.v_proj.weight' -> (None, -1) - """ - parts = name.split('.') - if len(parts) < 3: - return None, -1 - - kind = parts[-2] - - layer_idx = -1 - for part in reversed(parts): - if part.isdigit(): - layer_idx = int(part) - break - - if kind in ('wq', 'wk', 'q_proj', 'k_proj'): - return kind, layer_idx - - return None, -1 - - -@dataclass -class QKClipInfo: - """Per-parameter dynamic info computed from config + runtime logits.""" - kind: str | None # 'wq'/'q_proj' or 'wk'/'k_proj' or None - indices: list[int] # which heads to consider for clipping - head_dim: int # from config - threshold: float # from config - logit: torch.Tensor | None - - -class Muon(torch.optim.Optimizer): - """ - Muon - MomentUm Orthogonalized by Newton-schulz - - Muon internally runs standard SGD-momentum, and then performs an orthogonalization post- - processing step, in which each 2D parameter's update is replaced with the nearest orthogonal - matrix. To efficiently orthogonalize each update, we use a Newton-Schulz iteration, which has - the advantage that it can be stably run in bfloat16 on the GPU. - - Some warnings: - - We believe this optimizer is unlikely to work well for training with small batch size. - - We believe it may not work well for finetuning pretrained models, but we haven't tested this. - - Arguments: - model: The model to be optimized by Muon. - is_muon_func: A function that takes a parameter and its name, and returns whether the parameter should be optimized by Muon. - lr: The learning rate. The updates will have spectral norm of `lr`. 
(0.02 is a good default) - momentum: The momentum used by the internal SGD. (0.95 is a good default) - nesterov: Whether to use Nesterov-style momentum in the internal SGD. (recommended) - ns_steps: The number of Newton-Schulz iterations to run. (6 is probably always enough) - weight_decay: The weight decay for Muon and AdamW. - {0, 1}-D or are detected as being the embed or lm_head will be optimized by AdamW as well. - adamw_lr: The learning rate for the internal AdamW. - adamw_betas: The betas for the internal AdamW. - adamw_eps: The epsilon for the internal AdamW. - none_grad: Whether to set p.grad to None after gathering the gradients. This can save memory. - debug: Whether to print debug information. - clip_info : Configuration for QK clipping. Expected keys: - - "q_indices" (list[int]): Indices of query heads to consider. - - "k_indices" (list[int]): Indices of key heads to consider. - - "head_dim" (int): Dimensionality of each attention head. - - "threshold" (float): Threshold value; heads whose QK logits exceed - this value will be scaled down. - Default is: - { - "q_indices": [], - "k_indices": [], - "head_dim": 128, - "threshold": 100 - } - warmup_step : How many all2all gather, compute operations are launched in advance - before the corresponding all2all scatter steps begin. - A higher warmup_step increases memory usage but can improve - performance by overlapping communication. - Parallel muon only. - chunk_size : Batch size of parameters to process in each - all2all gather/compute/scatter step. - Use shard ranks * DEFAULT_CHUNK_SIZE_RATIO when -1 is specified. - use_distributed_muon: Use distributed muon by Liu et al. (2024). - For testing purpose only. 
- """ - - def __init__(self, - params, - lr=1e-3, - momentum=0.95, - nesterov=True, - ns_steps=5, - weight_decay=0.1, - adamw_betas=(0.9, 0.95), - adamw_eps=1e-8, - none_grad=True, - debug=False, - clip_config={ - "q_indices": [], - "k_indices": [], - "head_dim": 128, - "threshold": 100 - }, - warmup_step=5, - chunk_size=-1, - use_distributed_muon=False): - defaults = dict( - lr=lr, - weight_decay=weight_decay, - momentum=momentum, - nesterov=nesterov, - ns_steps=ns_steps, - adamw_betas=adamw_betas, - adamw_eps=adamw_eps, - none_grad=none_grad, - use_muon=True, - ) - error_message = "The key 'use_muon' is not set in parameter group {idx}. Assuming all parameters in the group will use muon optimization, which may lead to unexpected behavior." - instruction_code = "\n\n please follow this code snippet \n```optimizer = get_kernel('motif-technologies/optimizer')\n\n\nparams = optimizer.muon.get_default_muon_param_groups(model)\n\noptim = optimizer.Muon(params, ...)```" - - if isinstance(params, types.GeneratorType): - raise ValueError(error_message.format(idx=0) + instruction_code) - for _idx, param_group in enumerate(params): - if param_group.get("use_muon", None) is None: - raise ValueError( - error_message.format(idx=_idx) + instruction_code) - - super().__init__(params, defaults) - - self.rank = None - - self.comm_stream = torch.cuda.Stream() - self.compute_stream = torch.cuda.Stream() - self.debug = debug - self.clip_config = clip_config - self.warmup_step = warmup_step - self.chunk_size = chunk_size - self.use_distributed_muon = use_distributed_muon - - def _calc_flops(self, G, steps): - assert len(G.shape) == 2 - M, N = G.shape - if M > N: - M, N = N, M - - return steps * ((M**3) * 2 + (M**2 * N) * 4 + M * N * 2 + M**2 * 3) - - def adjust_lr_for_muon(self, lr, param_shape): - A, B = param_shape[:2] - # We adjust the learning rate and weight decay based on the size of the parameter matrix - # as describted in the paper - adjusted_ratio = 0.2 * math.sqrt(max(A, 
B)) - adjusted_lr = lr * adjusted_ratio - return adjusted_lr - - def set_rank_once(self, rank): - if self.rank is None: - self.rank = rank - else: - assert self.rank == rank - - def get_shard_mesh(self, p): - """ - Get the shard mesh for a parameter p on the given rank. - """ - assert isinstance( - p, DTensor), "Parallel Muon only supports DTensor parameters." - - shard_mesh, shard_pg, shard_placements = construct_shard_mesh( - p.placements, p.device_mesh) - - # set rank with the local rank in the shard process group - self.set_rank_once(dist.get_rank(group=shard_pg)) - - return shard_mesh, shard_pg, shard_placements - - def init_state_and_assign_params(self, names, params, group, qk_logits): - param_to_state = {} - param_to_flops = {} - - total_flops = 0 - for p in params: - g = p.grad - if g is None: - continue - assert g.ndim == 2, "Muon only supports 2D parameters." - - flops = self._calc_flops(g, group["ns_steps"]) - param_to_flops[id(p)] = flops - total_flops += flops - - if self.debug: - print(f"Total TFLOPs for Muon: {total_flops / 1e12:.2f} TFLOPs", - flush=True) - - paired = list(zip(names, params)) - - paired_sorted = sorted(paired, - key=lambda x: param_to_flops[id(x[1])], - reverse=True) - - names_sorted, params_sorted = zip(*paired_sorted) - ordered_names = list(names_sorted) - ordered_params = list(params_sorted) - - round_robin = 0 - mesh = ordered_params[0].device_mesh - placements = ordered_params[0].placements - - shard_mesh, shard_pg, shard_placements = self.get_shard_mesh( - ordered_params[0]) - shard_mesh_flattened = shard_mesh.mesh.flatten() - num_ranks = dist.get_world_size(group=shard_pg) - - for n, p in zip(ordered_names, ordered_params): - if mesh != p.device_mesh: - raise ValueError("All parameters must be on the same mesh.") - if placements != p.placements: - raise ValueError("All parameters must have same placements.") - - worker_rank = shard_mesh_flattened[round_robin].item() % num_ranks - round_robin = (round_robin + 1) % 
len(shard_mesh_flattened) - qk_clip_state = self.get_qk_clip_info(n, qk_logits) - - param_to_state[id(p)] = _muon_state( - worker_rank=worker_rank, - process_group=shard_pg, - shard_mesh=shard_mesh, - shard_placements=shard_placements, - name=n, - qk_clip_state=qk_clip_state, - ) - - return param_to_state, ordered_params - - def base(self, names, params, group, lr, weight_decay, momentum, - qk_logits): - # generate weight updates in distributed fashion - for n, p in zip(names, params): - g = p.grad - if g is None: - continue - if g.ndim > 2: - g = g.view(g.size(0), -1) - assert g is not None - - # calc update - state = self.state[p] - if "momentum_buffer" not in state: - state["momentum_buffer"] = torch.zeros_like(g) - buf = state["momentum_buffer"] - buf.mul_(momentum).add_(g) - if group["nesterov"]: - g = g.add(buf, alpha=momentum) - else: - g = buf - - u = _zeropower_via_newtonschulz5(g.to(COMM_DTYPE), - steps=group["ns_steps"]) - - adjusted_lr = self.adjust_lr_for_muon(lr, p.shape) - Muon._update_p(p, u, lr, adjusted_lr, weight_decay) - - qk_clip_state = self.get_qk_clip_info(n, qk_logits) - - scales_full = self._compute_scales( - p, qk_clip_state) if qk_clip_state is not None else None - if scales_full is not None: - Muon._qk_clip(p, scales_full, qk_clip_state.head_dim) - - def distributed_muon( - self, - names: list[str], - params: list[torch.nn.Parameter], - group: dict[str, Any], - lr: float, - weight_decay: float, - momentum: float, - qk_logits: list[torch.Tensor | DTensor] | None, - ): - """ Implementation of Distributed Muon by Liu et al. 
""" - if qk_logits is not None: - raise NotImplementedError("QK clipping is not supported yet") - - if isinstance(params[0], DTensor): - shard_mesh, _, shard_placements = construct_shard_mesh( - placements=params[0].placements, - mesh=params[0].device_mesh, - ) - - for n, p in zip(names, params): - g = p.grad - if g is None: - continue - if g.ndim > 2: - g = g.view(g.size(0), -1) - assert g is not None - - # calc update - state = self.state[p] - if "momentum_buffer" not in state: - state["momentum_buffer"] = torch.zeros_like(g) - buf = state["momentum_buffer"] - buf.mul_(momentum).add_(g) - if group["nesterov"]: - g = g.add(buf, alpha=momentum) - else: - g = buf - - # Gather G - if isinstance(p.data, DTensor): - g = g.full_tensor() - u = _zeropower_via_newtonschulz5(g.to(COMM_DTYPE), - steps=group["ns_steps"]) - - if isinstance(p.data, DTensor): - slices = get_slices_of_dtensor( - target=p, - local_rank=dist.get_rank(), - shard_mesh=shard_mesh, - shard_placements=shard_placements, - ) - u_shard = u[slices] - u = DTensor.from_local( - u_shard, - device_mesh=p.device_mesh, - placements=p.placements, - ) - - adjusted_lr = self.adjust_lr_for_muon(lr, p.shape) - Muon._update_p(p, u, lr, adjusted_lr, weight_decay) - - def _update_g(self, p, g, group, momentum): - # calc update - state = self.state[p] - buf = state.setdefault("momentum_buffer", torch.zeros_like(g)) - torch.add(g, buf, alpha=momentum, out=buf) - if group["nesterov"]: - g.add_(buf, alpha=momentum) - return g - return buf - - @staticmethod - def _update_p(p, u, lr, adjusted_lr, weight_decay): - # apply weight decay - p.data.mul_(1 - lr * weight_decay) - # apply update - p.data.add_(u, alpha=-adjusted_lr) - - def get_qk_clip_info(self, n, qk_logits): - if self.clip_config is None: - return None - - head_dim = self.clip_config.get('head_dim') - threshold = self.clip_config.get('threshold') - kind, layer_idx = parse_qk_layer(n) - - logit, indices = None, [] - if qk_logits is not None and kind is not None: - 
logit = qk_logits[layer_idx] - indices_key = 'q_indices' if 'q' in kind else 'k_indices' - indices = self.clip_config.get(indices_key, []) or [] - - if isinstance(logit, DTensor): - # In TP settings, qk_logits may be DTensor - # We convert it to full tensor here for simplicity - logit = logit.full_tensor() - - return QKClipInfo( - kind=kind, - indices=indices, - head_dim=head_dim, - threshold=threshold, - logit=logit, - ) - - @staticmethod - def _compute_scales(p, qk_clip_state): - kind = qk_clip_state.kind - indices = qk_clip_state.indices - head_dim = qk_clip_state.head_dim - threshold = qk_clip_state.threshold - logit = qk_clip_state.logit - - H_global = p.shape[0] // head_dim - scales_full = torch.ones(H_global, device=p.data.device) - scaling = 0 - - for logit_idx, head_idx in enumerate(indices): - v_ele = float(logit[logit_idx]) - if v_ele > threshold: - new_scale = math.sqrt(threshold / v_ele) - if new_scale < scales_full[head_idx]: - scales_full[head_idx] = new_scale - logger.info( - f"[{kind}] Head {head_idx} exceeded threshold " - f"(value={v_ele:.4f}, threshold={threshold:.4f}) -> applying scale={new_scale:.4f}" - ) - scaling += 1 - - return scales_full if scaling > 0 else None - - @staticmethod - def _qk_clip(p, scales, head_dim): - W = p.data.view(-1, head_dim, p.data.shape[1]) - W.mul_(scales.view(-1, 1, 1)) - - def parallel(self, names, params, group, lr, weight_decay, momentum, - qk_logits): - """ - Perform a parallel optimization step using Muon. 
- """ - - for p in params: - g = p.grad - if g is None: - continue - if g.ndim > 2: - g = g.view(g.size(0), -1) - - # Update g in the local rank - g = self._update_g( - p, - g, - group, - momentum=momentum, - ) - p.grad = g - - param_to_state, ordered_params = self.init_state_and_assign_params( - names, params, group, qk_logits) - - assert self.rank is not None - - def enqueue_all2all_gather(start_idx, chunk_size): - target_params = ordered_params[start_idx:start_idx + chunk_size] - if target_params: - alloc_event = _alloc_gathered_grad(target_params, - param_to_state, self.rank, - self.compute_stream) - _all2all_gather(target_params, param_to_state, self.rank, - self.comm_stream, group["none_grad"], - alloc_event) - - def enqueue_computes(start_idx, chunk_size): - for p in ordered_params[start_idx:start_idx + chunk_size]: - state = param_to_state[id(p)] - _compute_u(p, state, group["ns_steps"], self.rank, - self.compute_stream) - - def enqueue_all2all_scatter(start_idx, chunk_size): - target_params = ordered_params[start_idx:start_idx + chunk_size] - if target_params: - alloc_event = _alloc_scattered_u(target_params, param_to_state, - self.rank, - self.compute_stream) - _all2all_scatter(target_params, param_to_state, self.rank, - self.comm_stream, alloc_event) - - def enqueue_update_param(start_idx, chunk_size): - for p in ordered_params[start_idx:start_idx + chunk_size]: - state = param_to_state[id(p)] - adjusted_lr = self.adjust_lr_for_muon(lr, p.shape) - _update_param(p, state, lr, adjusted_lr, weight_decay, - self.rank, self.compute_stream) - - if self.chunk_size == -1: - shard_ranks = dist.get_world_size(param_to_state[id( - params[0])].process_group) - chunk_size = shard_ranks * DEFAULT_CHUNK_SIZE_RATIO - elif self.chunk_size > 0: - chunk_size = self.chunk_size - else: - raise ValueError("chunk_size must be -1 or a positive integer.") - - # Wait grad update - self.comm_stream.wait_stream(torch.cuda.current_stream()) - - warmup_step = self.warmup_step - for i 
in range(0, warmup_step): - enqueue_all2all_gather(i * chunk_size, chunk_size) - enqueue_computes(i * chunk_size, chunk_size) - - for i in range(0, len(params) + chunk_size - 1, chunk_size): - enqueue_all2all_scatter(i, chunk_size) - enqueue_all2all_gather(i + warmup_step * chunk_size, chunk_size) - enqueue_update_param(i, chunk_size) - enqueue_computes(i + warmup_step * chunk_size, chunk_size) - - # Wait the last update_param to finish - torch.cuda.current_stream().wait_stream(self.compute_stream) - - @staticmethod - def _fused_adamw( - params: list[torch.Tensor], - grads: list[torch.Tensor], - exp_avgs: list[torch.Tensor], - exp_avg_sqs: list[torch.Tensor], - max_exp_avg_sqs: list[torch.Tensor], - state_steps: list[torch.Tensor], - amsgrad: bool, - beta1: float, - beta2: float, - lr: float | torch.Tensor, - weight_decay: float, - eps: float, - maximize: bool, - ) -> None: - if not params: - return - - # We only shuffle around the lr when it is a Tensor and on CUDA, otherwise, we prefer - # treating it as a scalar. 
- lr_dict: DeviceDict | None = ({ - lr.device: lr - } if isinstance(lr, torch.Tensor) and str(lr.device) != "cpu" else - None) - grouped_tensors = torch.optim.Optimizer._group_tensors_by_device_and_dtype( - [ - params, grads, exp_avgs, exp_avg_sqs, max_exp_avg_sqs, - state_steps - ] # type: ignore[list-item] - ) - for (device, _), ( - ( - device_params_, - device_grads_, - device_exp_avgs_, - device_exp_avg_sqs_, - device_max_exp_avg_sqs, - device_state_steps_, - ), - _, - ) in grouped_tensors.items(): - device_params = cast(list[torch.Tensor], device_params_) - device_grads = cast(list[torch.Tensor], device_grads_) - device_exp_avgs = cast(list[torch.Tensor], device_exp_avgs_) - device_exp_avg_sqs = cast(list[torch.Tensor], device_exp_avg_sqs_) - device_state_steps = cast(list[torch.Tensor], device_state_steps_) - - if lr_dict is not None and device not in lr_dict: - lr_dict[device] = lr.to( - device=device, - non_blocking=True) # type: ignore[union-attr] - lr = lr_dict[device] - torch._foreach_add_(device_state_steps, 1) - func = torch._fused_adamw_ - func( - device_params, - device_grads, - device_exp_avgs, - device_exp_avg_sqs, - device_max_exp_avg_sqs, # type: ignore[arg-type] - device_state_steps, - amsgrad=amsgrad, - lr=lr, # type: ignore[arg-type] - beta1=beta1, - beta2=beta2, - weight_decay=weight_decay, - eps=eps, - maximize=maximize, - ) - - def _step_muon(self, group, qk_logits=None): - params = group["params"] - lr = group["lr"] - weight_decay = group["weight_decay"] - momentum = group["momentum"] - names = group["names"] - - param_dtensors = [] - param_tensors = [] - name_dtensors = [] - name_tensors = [] - - if self.use_distributed_muon: - self.distributed_muon(names=names, - params=params, - group=group, - lr=lr, - weight_decay=weight_decay, - momentum=momentum, - qk_logits=qk_logits) - return - - for n, p in zip(names, params): - if p is None or p.grad is None: - continue - if isinstance(p.data, DTensor): - if all( - isinstance(placement, 
Replicate) - for placement in p.placements): - param_tensors.append(p) - name_tensors.append(n) - else: - param_dtensors.append(p) - name_dtensors.append(n) - elif isinstance(p.data, torch.Tensor): - param_tensors.append(p) - name_tensors.append(n) - else: - raise TypeError(f"Unsupported parameter type: {type(p.data)}") - - logger.debug( - f"[Muon] {len(param_dtensors)} DTensors, {len(param_tensors)} Tensors" - ) - - if len(param_dtensors) > 0: - if not dist.is_initialized(): - raise RuntimeError( - "Parallel Muon requires torch.distributed to be initialized." - ) - - # To support different placements, we group parameters by placements - # and run parallel Muon on each group. - - placement_to_params = defaultdict(lambda: ([], [])) - # type: dict[tuple[Placement, DeviceMesh], tuple[list[str], list[DTensor]]] - - assert len(name_dtensors) == len(param_dtensors) - for n, p in zip(name_dtensors, param_dtensors): - placement_to_params[tuple([p.placements, - p.device_mesh])][0].append(n) - placement_to_params[tuple([p.placements, - p.device_mesh])][1].append(p) - - for _, (names, params) in placement_to_params.items(): - self.parallel( - names, - params, - group, - lr=lr, - weight_decay=weight_decay, - momentum=momentum, - qk_logits=qk_logits, - ) - - if len(param_tensors) > 0: - self.base( - name_tensors, - param_tensors, - group, - lr=lr, - weight_decay=weight_decay, - momentum=momentum, - qk_logits=qk_logits, - ) - - def _step_adamw_params(self, params, group): - params_with_grads = [] - grads = [] - moment1 = [] - moment2 = [] - max_exp_avg_sqs = [] - state_steps = [] - lr = group["lr"] - beta1, beta2 = group["adamw_betas"] - eps = group["adamw_eps"] - weight_decay = group["weight_decay"] - - for p in params: - g = p.grad - if g is None: - continue - state = self.state[p] - params_with_grads.append(p) - grads.append(g) - if "step" not in state: - state["step"] = (torch.zeros((), - dtype=torch.float32, - device=p.device)) - state["moment1"] = torch.zeros_like(g) - 
state["moment2"] = torch.zeros_like(g) - moment1.append(state["moment1"]) - moment2.append(state["moment2"]) - if not isinstance(state["step"], torch.Tensor): - step_tensor = torch.tensor(state["step"], - dtype=torch.float32, - device=p.device) - else: - step_tensor = state["step"] - state_steps.append(step_tensor) - - self._fused_adamw( - params_with_grads, - grads, - moment1, - moment2, - max_exp_avg_sqs, - state_steps, - amsgrad=False, - beta1=beta1, - beta2=beta2, - lr=lr, - weight_decay=weight_decay, - eps=eps, - maximize=False, - ) - - def _step_adamw(self, group): - params = group["params"] - - # group params with it's type and placement - placement_to_params: dict[tuple[Placement | type, - DeviceMesh | None]] = defaultdict(list) - for p in params: - match p: - case DTensor(): - placement_to_params[tuple([p.placements, - p.device_mesh])].append(p) - case torch.Tensor(): - placement_to_params[tuple([torch.Tensor, None])].append(p) - - for params in placement_to_params.values(): - self._step_adamw_params(params, group) - - def step(self, closure=None, qk_logits=None): - """Perform a single optimization step. - - Args: - closure (Callable, optional): A closure that reevaluates the model - and returns the loss. - qk_logits (dict[int, Tensor], optional): A dictionary mapping layer indices - to 1D tensors of shape (num_heads,), representing the maximum - QK logits across all tokens, computed as - (1 / sqrt(head_dim)) * (Q @ K^T). 
- """ - loss = None - if closure is not None: - with torch.enable_grad(): - loss = closure() - - for group in self.param_groups: - if group["use_muon"]: - self._step_muon(group, qk_logits=qk_logits) - else: - self._step_adamw(group) - - return loss diff --git a/build/torch29-cxx11-cu128-x86_64-linux/__init__.py b/build/torch29-cxx11-cu128-x86_64-linux/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..239c7a65f8293e7d0df28f05fce645af56d628c0 --- /dev/null +++ b/build/torch29-cxx11-cu128-x86_64-linux/__init__.py @@ -0,0 +1,5 @@ +from .muon import Muon + +__all__ = [ + "Muon", +] diff --git a/build/torch29-cxx11-cu128-x86_64-linux/_ops.py b/build/torch29-cxx11-cu128-x86_64-linux/_ops.py new file mode 100644 index 0000000000000000000000000000000000000000..e6f6fcf6280e969b1761926112147d3146e27b59 --- /dev/null +++ b/build/torch29-cxx11-cu128-x86_64-linux/_ops.py @@ -0,0 +1,9 @@ +import torch +from . import _optimizer_06a260a_dirty +ops = torch.ops._optimizer_06a260a_dirty + +def add_op_namespace_prefix(op_name: str): + """ + Prefix op by namespace. 
+ """ + return f"_optimizer_06a260a_dirty::{op_name}" \ No newline at end of file diff --git a/build/torch29-cxx11-cu128-x86_64-linux/_optimizer_06a260a_dirty.abi3.so b/build/torch29-cxx11-cu128-x86_64-linux/_optimizer_06a260a_dirty.abi3.so new file mode 100755 index 0000000000000000000000000000000000000000..b4ccc5bd24c68e412968b43af9a352dd5ac27863 --- /dev/null +++ b/build/torch29-cxx11-cu128-x86_64-linux/_optimizer_06a260a_dirty.abi3.so @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:f048516a9820c335263f335df545e404e22ee146355b49669c95a54852448542 +size 1999872 diff --git a/build/torch29-cxx11-cu128-x86_64-linux/distributed/utils.py b/build/torch29-cxx11-cu128-x86_64-linux/distributed/utils.py new file mode 100644 index 0000000000000000000000000000000000000000..6d5843506c13d9d31603b2b4e30c1c91d0baab28 --- /dev/null +++ b/build/torch29-cxx11-cu128-x86_64-linux/distributed/utils.py @@ -0,0 +1,175 @@ +import torch +import torch.distributed as dist +from torch.distributed import ProcessGroup +from torch.distributed.device_mesh import DeviceMesh +from torch.distributed.tensor import DTensor +from torch.distributed.tensor.placement_types import (Placement, Shard, + _StridedShard) + + +def get_slices_of_dtensor( + target: DTensor | torch.Tensor, + local_rank: int, + shard_mesh: DeviceMesh, + shard_placements: tuple[Placement], +) -> tuple[slice]: + """ + Get the slice of local tensor for a given rank from a tensor. + Args: + target (DTensor | torch.Tensor): The target tensor. + rank (int): The local rank of the shard group. + shard_mesh (DeviceMesh): The shard mesh. It consists of global ranks. + shard_placements (tuple[Placement]): The shard placements. 
+ """ + + slices: list[slice] = [slice(0, dim_size) for dim_size in target.size()] + + # find the global rank of the local rank in the shard mesh + rank = sorted(shard_mesh.mesh.flatten().tolist())[local_rank] + + rank_coords = (shard_mesh.mesh == rank).nonzero() + + assert len(rank_coords) == 1 + rank_coords = tuple(rank_coords[0].tolist()) + + assert len(rank_coords) == len(shard_placements) + + # Caution: Assuming replicate-to-shard of the shard mesh goes with + # left-to-right sharding. This is ensured by the sorting logic of + # construct_shard_mesh function. + for i, (rank_coord, + placement) in enumerate(zip(rank_coords, shard_placements)): + assert isinstance(placement, Shard) + + num_ranks = shard_mesh.mesh.shape[i] + + dim = placement.dim + dim_size = (slices[dim].stop - slices[dim].start) + + if dim_size % num_ranks != 0: + raise NotImplementedError( + f"Dimension size {dim_size} is not divisible " + f"by number of ranks {num_ranks} for shard " + f"placement on dim {dim}. (shape: {target.shape})") + + shard_size = dim_size // num_ranks + + start = slices[dim].start + rank_coord * shard_size + end = start + shard_size + + assert start < end <= slices[dim].stop + + slices[dim] = slice(start, end) + + return tuple(slices) + + +_ranks_to_dist_cache: dict[tuple[int, ...], tuple[DeviceMesh, + ProcessGroup]] = dict() + + +def construct_shard_mesh( + placements: tuple[Placement], + mesh: DeviceMesh, +) -> (DeviceMesh, ProcessGroup, tuple[Placement]): + """ + Construct Shard Mesh and Placements for unsharding. + It removes Replicate placements and constructs a new Mesh and ProcessGroup. + """ + my_rank = dist.get_rank() + + assert mesh.mesh.device.type == 'cpu' + + # Copy mesh to avoid modifying the original mesh + mesh = mesh.mesh.clone() + + # 1. Sort placements. Replicate first, then Shard by dim ascending. + + # For Shard, strided shard comes after regular shard on the same dim + # to preserve left-to-right order of replicate-to-shard. 
+ # This is because that strided shard is using stride to represent + # more fine-grained sharding on the same dim. + # Please check the URL below for _StridedShard. + # https://github.com/pytorch/pytorch/blob/v2.8.0/torch/distributed/tensor/placement_types.py#L366 + + def placement_sort_key( + placement_with_index: tuple[float, Placement] + ) -> tuple[int, float, int]: # (dim, split factor, original index) + index, placement = placement_with_index + is_replicate = placement.is_replicate() + is_shard = placement.is_shard() + is_partial = placement.is_partial() + + assert is_replicate or is_shard, f"Unsupported placement type: {type(placement)}" + assert not is_partial, "Partial placement is not supported." + + if is_replicate: + return (-1.0, 0, index) + elif is_shard: + if isinstance(placement, _StridedShard): + return (placement.dim, 1 / placement.split_factor, index) + return (placement.dim, 0, index) + else: + raise TypeError(f"Unknown placement type: {type(placement)}") + + placements_with_index: list[tuple[int, + Placement]] = list(enumerate(placements)) + placements_with_index = sorted(placements_with_index, + key=placement_sort_key) + + sorted_indices, sorted_placements = zip(*placements_with_index) + + # 2. Permute mesh according to sorted placements. + sorted_mesh = mesh.permute(sorted_indices) + + # 3. 
Collect list of shard meshes by removing replicate dims + # For example, (2, 3, 4, 4) with placements [R, R, S(0), S(1)] + # shard_meshes should be list with 2 * 3 = 6 shard meshes of shape (4, 4) + num_replicates = sum(1 for p in sorted_placements if p.is_replicate()) + + # merge replicate dims + # shard_meshes became a list of shard meshes with a length of replicate degree + if num_replicates > 0: + sorted_mesh = sorted_mesh.flatten( + 0, num_replicates - 1) if num_replicates > 1 else sorted_mesh + shard_meshes = list(torch.unbind(sorted_mesh, dim=0)) + else: + shard_meshes = [sorted_mesh] + shard_placements = sorted_placements[num_replicates:] + + # assume all shard placements are different + assert len(shard_placements) == len(set(shard_placements)) + + # 4. Construct ProcessGroups + # Caution: all groups should be created in the same order in all processes, + # even though each process only needs its own group. + + # To use tensor as dict key, convert it to tuple + def tensor_to_tuple(t): + if isinstance(t, torch.Tensor): + t = t.tolist() + if isinstance(t, list): + return tuple(tensor_to_tuple(x) for x in t) + return t + + my_shard_mesh_as_tuple = None + for shard_mesh in shard_meshes: + assert isinstance(shard_mesh, torch.Tensor) + shard_mesh_as_tuple = tensor_to_tuple(shard_mesh) + + if (my_rank == shard_mesh).any().item(): + assert my_shard_mesh_as_tuple is None + my_shard_mesh_as_tuple = shard_mesh_as_tuple + + # update global cache + if shard_mesh_as_tuple not in _ranks_to_dist_cache: + shard_process_group = dist.new_group(shard_mesh.flatten().tolist()) + _ranks_to_dist_cache[shard_mesh_as_tuple] = ( + DeviceMesh(device_type="cuda", mesh=shard_mesh), + shard_process_group, + ) + + my_shard_mesh, my_shard_process_group = _ranks_to_dist_cache[ + my_shard_mesh_as_tuple] + + return my_shard_mesh, my_shard_process_group, shard_placements diff --git a/build/torch29-cxx11-cu128-x86_64-linux/matmul_transpose_triton.py 
b/build/torch29-cxx11-cu128-x86_64-linux/matmul_transpose_triton.py new file mode 100644 index 0000000000000000000000000000000000000000..4565b2c4fd506a4218340d380d6c962b16774b1d --- /dev/null +++ b/build/torch29-cxx11-cu128-x86_64-linux/matmul_transpose_triton.py @@ -0,0 +1,128 @@ +# MIT License +# +# Copyright (c) 2025 Tianyang Lin +# +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the "Software"), to deal +# in the Software without restriction, including without limitation the rights +# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +# copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included in all +# copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +# SOFTWARE. 
+ +import torch +import triton +import triton.language as tl + + +def get_autotune_config(): + return [ + triton.Config( + { + 'BLOCK_SIZE_M': blk_m, + 'BLOCK_SIZE_K': blk_k, + 'GROUP_SIZE_M': grp_sz + }, + num_stages=n_stages, + num_warps=n_warps) for blk_m in [32, 64, 128] + for blk_k in [32, 64] for grp_sz in [8] for n_stages in [3, 4, 5] + for n_warps in [4, 8] + ] + + +@triton.autotune( + configs=get_autotune_config(), + key=['M', 'K'], +) +@triton.jit +def mmt_kernel(x, y, M, K, stride_xm, stride_xk, stride_ym, stride_yn, + BLOCK_SIZE_M: tl.constexpr, BLOCK_SIZE_K: tl.constexpr, + GROUP_SIZE_M: tl.constexpr): + """ + Core kernel jit function of matmul_transpose that computes y = x @ x.T + The code is a simple adaptation from the triton `matmul` tutorial: + https://triton-lang.org/main/getting-started/tutorials/03-matrix-multiplication.html + """ + pid = tl.program_id(axis=0) + num_pid_m = tl.cdiv(M, BLOCK_SIZE_M) + num_pid_n = tl.cdiv(M, BLOCK_SIZE_M) + num_pid_in_group = GROUP_SIZE_M * num_pid_n + group_id = pid // num_pid_in_group + first_pid_m = group_id * GROUP_SIZE_M + group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M) + pid_m = first_pid_m + ((pid % num_pid_in_group) % group_size_m) + pid_n = (pid % num_pid_in_group) // group_size_m + if pid_m > pid_n: + return + + offs_xm = (pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)) % M + offs_xn = (pid_n * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)) % M + offs_k = tl.arange(0, BLOCK_SIZE_K) + # we use a & b ptrs to denote different rows of x. 
+ a_ptrs = x + (offs_xm[:, None] * stride_xm + offs_k[None, :] * stride_xk) + b_ptrs = x + (offs_xn[:, None] * stride_xm + offs_k[None, :] * stride_xk) + + accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_M), dtype=tl.float32) + + for k in range(0, tl.cdiv(K, BLOCK_SIZE_K)): + a = tl.load(a_ptrs, + mask=offs_k[None, :] < K - k * BLOCK_SIZE_K, + other=0.0) + b = tl.load(b_ptrs, + mask=offs_k[None, :] < K - k * BLOCK_SIZE_K, + other=0.0) + accumulator = tl.dot(a, tl.permute(b, (1, 0)), accumulator) + a_ptrs += BLOCK_SIZE_K * stride_xk + b_ptrs += BLOCK_SIZE_K * stride_xk + # use dtype.element_ty to accommodate different input datatypes as in cpp templates + # https://github.com/triton-lang/triton/issues/2252 + c = accumulator.to(x.dtype.element_ty) + + offs_cm = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M) + offs_cn = pid_n * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M) + c_ptrs = y + stride_ym * offs_cm[:, None] + stride_yn * offs_cn[None, :] + c_mask = (offs_cm[:, None] < M) & (offs_cn[None, :] < M) + tl.store(c_ptrs, c, mask=c_mask) + + # transpose and copy + if pid_m < pid_n: + ct_ptrs = y + stride_ym * offs_cn[:, + None] + stride_yn * offs_cm[None, :] + ct_mask = (offs_cn[:, None] < M) & (offs_cm[None, :] < M) + tl.store(ct_ptrs, tl.permute(c, (1, 0)), mask=ct_mask) + + +def matmul_transpose_assign(d_in, d_out): + assert d_in.is_cuda, "Input `d_in` must be a CUDA tensor" + assert d_out.is_cuda, "Input `d_out` must be a CUDA tensor" + assert d_in.device == d_out.device, "Inputs `d_in` and `d_out` must be on the same CUDA device" + assert d_in.dtype == d_out.dtype, "Inputs must have the same data type" + assert d_in.ndim == 2, "Input `d_in` must be a 2D tensor" + assert d_out.ndim == 2, "Input `d_out` must be a 2D tensor" + assert d_in.size(0) == d_out.size(0) == d_out.size(0), \ + "First dimension of `d_in` must match first and second dimension of `d_out`" + + d_in = d_in.contiguous() + M, K = d_in.shape + grid = lambda META: (triton.cdiv(M, 
META['BLOCK_SIZE_M']) * triton.cdiv( + M, META['BLOCK_SIZE_M']), ) + with torch.cuda.device(d_in.device.index): + mmt_kernel[grid](d_in, d_out, M, K, d_in.stride(0), d_in.stride(1), + d_out.stride(0), d_out.stride(1)) + + +def matmul_transpose(d_in): + M, _ = d_in.shape + d_out = torch.empty((M, M), device=d_in.device, dtype=d_in.dtype) + matmul_transpose_assign(d_in, d_out) + return d_out diff --git a/build/torch29-cxx11-cu128-x86_64-linux/metadata.json b/build/torch29-cxx11-cu128-x86_64-linux/metadata.json new file mode 100644 index 0000000000000000000000000000000000000000..76bafa5f33b6818aa6bb4cab04be811b87519b44 --- /dev/null +++ b/build/torch29-cxx11-cu128-x86_64-linux/metadata.json @@ -0,0 +1 @@ +{"python-depends":[]} \ No newline at end of file diff --git a/build/torch29-cxx11-cu128-x86_64-linux/muon.py b/build/torch29-cxx11-cu128-x86_64-linux/muon.py new file mode 100644 index 0000000000000000000000000000000000000000..dbf25575f185ff379789482068e4ecf55b9455a9 --- /dev/null +++ b/build/torch29-cxx11-cu128-x86_64-linux/muon.py @@ -0,0 +1,1268 @@ +import logging +import math +import types +from collections import defaultdict +from dataclasses import dataclass +from typing import Any, cast + +import torch +import torch.distributed as dist +from torch.distributed import ProcessGroup +from torch.distributed.device_mesh import DeviceMesh +from torch.distributed.tensor import DTensor, Replicate +from torch.distributed.tensor.placement_types import Placement + +from .distributed.utils import construct_shard_mesh, get_slices_of_dtensor +from .matmul_transpose_triton import matmul_transpose_assign + +logger = logging.getLogger(__name__) + +COMM_DTYPE = torch.bfloat16 +DEFAULT_CHUNK_SIZE_RATIO = 4 + + +# This code snippet is a modified version adapted from the following GitHub repositories: +# https://github.com/KellerJordan/Muon/blob/master/muon.py +# Muon's Newton–Schulz iteration causes high variance in singular values +# Idea: give each iteration its own 3 
coefficients and optimize them via gradient descent. +@torch.no_grad() +# matmul_transpose_assign from : https://github.com/nil0x9/flash-muon +def _zeropower_via_newtonschulz5(G, steps): + """ + Newton-Schulz iteration to compute the zeroth power / orthogonalization of G. We opt to use a + quintic iteration whose coefficients are selected to maximize the slope at zero. For the purpose + of minimizing steps, it turns out to be empirically effective to keep increasing the slope at + zero even beyond the point where the iteration no longer converges all the way to one everywhere + on the interval. This iteration therefore does not produce UV^T but rather something like US'V^T + where S' is diagonal with S_{ii}' ~ Uniform(0.5, 1.5), which turns out not to hurt model + performance at all relative to UV^T, where USV^T = G is the SVD. + """ + assert len(G.shape) == 2 + assert G.dtype == COMM_DTYPE + X = G # no manual typecast + + if G.size(0) > G.size(1): + X = X.T + # Ensure spectral norm is at most 1 + X = X / (X.norm() + 1e-7) + buf1 = torch.empty(X.size(0), X.size(0), dtype=X.dtype, device=X.device) + buf2 = torch.empty(X.size(0), X.size(0), dtype=X.dtype, device=X.device) + # Perform the NS iterations + for a, b, c in [ + (4.0848, -6.8946, 2.9270), + (3.9505, -6.3029, 2.6377), + (3.7418, -5.5913, 2.3037), + (2.8769, -3.1427, 1.2046), + (2.8366, -3.0525, 1.2012), + ]: + matmul_transpose_assign(X, buf1) + matmul_transpose_assign(buf1, buf2) + buf1.mul_(b).add_(buf2, alpha=c) + X = torch.addmm(X, buf1, X, alpha=1.0, beta=a) + + if G.size(0) > G.size(1): + X = X.T + return X + + +@dataclass +class _muon_state: + # TODO: use Optional + worker_rank: int + process_group: ProcessGroup + shard_mesh: DeviceMesh + shard_placements: tuple[Placement, ...] 
+ name: str + qk_clip_state: torch.Tensor | None = None + gathered_grad: torch.Tensor | None = None + scattered_u: DTensor | None = None + computed_u: torch.Tensor | None = None + gather_event: torch.cuda.Event | None = None + compute_event: torch.cuda.Event | None = None + scatter_event: torch.cuda.Event | None = None + + +def numel_for_rank( + param: DTensor, + local_rank: int, + state: _muon_state, +) -> int: + slices = get_slices_of_dtensor( + param, + local_rank, + state.shard_mesh, + state.shard_placements, + ) + + numel = 1 + for s, dim in zip(slices, param.shape): + start, stop, step = s.indices(dim) + length = max(0, (stop - start + (step - 1)) // step) + numel *= length + + return numel + + +@torch.no_grad() +def _alloc_gathered_grad(params, param_to_state, rank, compute_stream): + """ + Pre-allocate gathered_grad buffer on compute_stream + before launching all2all gather + """ + with torch.cuda.stream(compute_stream): + for p in params: + state = param_to_state[id(p)] + if rank == state.worker_rank: + state.gathered_grad = torch.empty(p.shape, + dtype=COMM_DTYPE, + device="cuda") + else: + state.gathered_grad = None + + alloc_event = torch.cuda.Event() + alloc_event.record(compute_stream) + return alloc_event + + +@torch.no_grad() +def _all2all_gather(params, param_to_state, rank, comm_stream, none_grad, + alloc_event): + """ + All2all gathers shards so each owner rank reconstructs its full gradient + """ + with torch.cuda.stream(comm_stream): + process_group = param_to_state[id(params[0])].process_group + num_ranks = dist.get_world_size(group=process_group) + + # Construct sending buffers + per_dst = [[] for _ in range(num_ranks)] + send_counts = [0] * num_ranks + + for p in params: + state = param_to_state[id(p)] + dst = state.worker_rank + assert dst < num_ranks + shard_elems = numel_for_rank(p, rank, state) + g = p.grad + g = g.to_local().to(COMM_DTYPE).contiguous() + assert g.numel() == shard_elems + per_dst[dst].append(g.view(-1)) + 
send_counts[dst] += shard_elems + + assert any( + len(v) > 0 for v in per_dst + ), "At least one destination rank must receive a sharded tensor" + # list[list[Tensor]] -> list[Tensor] + per_dst = [t for dst in per_dst for t in dst] + + send_buf = torch.cat(per_dst, dim=0) + + owned_params = [ + p for p in params if param_to_state[id(p)].worker_rank == rank + ] + + # Compute receive sizes and allocate receiving buffers + recv_counts = [0] * num_ranks + + for src in range(num_ranks): + total = 0 + for p in owned_params: + state = param_to_state[id(p)] + assert state.worker_rank == rank + total += numel_for_rank(p, src, state) + recv_counts[src] = total + + recv_total = sum(recv_counts) + recv_buf = torch.empty(recv_total, dtype=COMM_DTYPE, device="cuda") + + #All2All + logger.debug(f"send_buf size: {send_buf.numel()}, " + f"recv_buf size: {recv_buf.numel()}, " + f"recv_counts: {recv_counts}, " + f"send_counts: {send_counts}, " + f"process_group: {str(process_group)}") + dist.all_to_all_single( + recv_buf, + send_buf, + output_split_sizes=recv_counts, + input_split_sizes=send_counts, + group=process_group, + ) + + # Reconstructs gathered grad from the received buffer + # + # recv_buf (num ranks = 3) + # + # From rank 0 From rank 1 From rank 2 + # | p1_0, p2_0, p3_0 | p1_1, p2_1, p3_1 | p1_2, p2_2, p3_2 | + # + # Outer loop: + # rank 0 -> rank 1 -> rank2 + # + # Inner loop: + # p1_n -> p2_n -> p3_n + + comm_stream.wait_event(alloc_event) + + off = 0 + for src in range(num_ranks): + if recv_counts[src] == 0: + continue + + block = recv_counts[src] + inner_off = 0 + for p in owned_params: + state = param_to_state[id(p)] + assert state.worker_rank == rank + + # get the slice of the full dtensor corresponding to rank src. 
+ slices = get_slices_of_dtensor(state.gathered_grad, src, + state.shard_mesh, + state.shard_placements) + + dst = state.gathered_grad[slices] + assert dst._base is state.gathered_grad + + n = dst.numel() + assert n > 0 + + sg = recv_buf.narrow(0, off + inner_off, n) + sg = sg.reshape_as(dst) + dst.copy_(sg) + + inner_off += n + off += block + + for p in params: + state = param_to_state[id(p)] + if state.worker_rank == rank: + state.gather_event = torch.cuda.Event() + state.gather_event.record(comm_stream) + else: + state.gathered_grad = None + state.gather_event = None + if none_grad: + p.grad = None + + +@torch.no_grad() +def _compute_u(p, state, steps, rank, compute_stream): + """ + On worker_rank, compute the orthogonalized update using Newton-Schulz iteration. + """ + with torch.cuda.stream(compute_stream): + if rank == state.worker_rank: + if state.gather_event is None: + raise RuntimeError("Gather event must be set before compute.") + compute_stream.wait_event(state.gather_event) + u = _zeropower_via_newtonschulz5(state.gathered_grad, steps) + state.gathered_grad = None + state.computed_u = u + state.compute_event = torch.cuda.Event() + state.compute_event.record() + else: + state.computed_u = None + state.compute_event = None + + +@torch.no_grad() +def _alloc_scattered_u(params, param_to_state, rank, compute_stream): + """ + Pre-allocate scattered_u buffer on compute_stream + before launching all2all gather + """ + with torch.cuda.stream(compute_stream): + for p in params: + state = param_to_state[id(p)] + state.scattered_u = torch.empty_like(p.to_local(), + dtype=COMM_DTYPE) + + alloc_event = torch.cuda.Event() + alloc_event.record(compute_stream) + return alloc_event + + +def _all2all_scatter(params, param_to_state, rank, comm_stream, alloc_event): + """ + All2all scatters full gradients to all ranks + """ + with torch.cuda.stream(comm_stream): + process_group = param_to_state[id(params[0])].process_group + num_ranks = 
dist.get_world_size(group=process_group) + owned_params = [ + p for p in params if param_to_state[id(p)].worker_rank == rank + ] + + # Construct sending buffer + per_dst = [[] for _ in range(num_ranks)] + send_counts = [0] * num_ranks + + if owned_params: + for p in owned_params: + state = param_to_state[id(p)] + if state.compute_event is None: + raise RuntimeError( + "Compute event must be set before scatter.") + comm_stream.wait_event(state.compute_event) + state.gathered_grad = None + + assert state.computed_u is not None + + u_full = state.computed_u.to(COMM_DTYPE).contiguous() + + offset = 0 + for dst in range(num_ranks): + # get the slice of the full tensor corresponding to rank dst. + slices = get_slices_of_dtensor(u_full, dst, + state.shard_mesh, + state.shard_placements) + su = u_full[slices].flatten() + + n = su.numel() + assert n > 0 + + per_dst[dst].append(su) + send_counts[dst] += n + offset += n + + assert offset == u_full.numel() + + lengths = [len(v) for v in per_dst] + if all(l > 0 for l in lengths): + assert all( + l == lengths[0] for l in lengths + ), "All destination ranks must have the same number of sharded tensor" + # list[list[Tensor]] -> list[Tensor] + per_dst = [t for dst in per_dst for t in dst] + send_buf = torch.cat(per_dst, dim=0) + else: + # all_to_all requires participation from all ranks + # Even non-owner ranks must join the collective call + send_buf = torch.empty(0, dtype=COMM_DTYPE, device="cuda") + + # Compute receive sizes and allocate receiving buffers + recv_counts = [0] * num_ranks + + for src in range(num_ranks): + total = 0 + for p in params: + state = param_to_state[id(p)] + if state.worker_rank != src: + continue + total += numel_for_rank(p, rank, state) + recv_counts[src] = total + + recv_total = sum(recv_counts) + assert recv_total > 0 + recv_buf = torch.empty(recv_total, dtype=COMM_DTYPE, device="cuda") + + #All2All + dist.all_to_all_single( + recv_buf, + send_buf, + output_split_sizes=recv_counts, + 
input_split_sizes=send_counts, + group=process_group, + ) + + # Copy to pre-allocated scattered_u buffer from the received buffer + # + # recv_buf (num ranks = 3, local_rank = 0) + # + # From rank 0 From rank 1 From rank 2 + # | p1_0, p2_0, p3_0 | p4_0 | p5_0, p6_0 | + # + # Outer loop: + # rank 0 -> rank 1 -> rank2 + # + # Inner loop: + # src(0) : p1_0 -> p2_0 -> p3_0 + # src(1) : p4_0 + # src(2) : p5_0 -> p6_0 + + comm_stream.wait_event(alloc_event) + + off = 0 + for src in range(num_ranks): + block = recv_counts[src] + if block == 0: + continue + + inner_off = 0 + for p in params: + state = param_to_state[id(p)] + if state.worker_rank != src: + continue + n = numel_for_rank(p, rank, state) + assert n > 0 + + flat_local = recv_buf.narrow(0, off + inner_off, + n).view_as(p.to_local()) + state.scattered_u.copy_(flat_local) + + state.scatter_event = torch.cuda.Event() + state.scatter_event.record(comm_stream) + inner_off += n + + assert inner_off == block + off += block + + +def _update_param(p, state, lr, adjusted_lr, weight_decay, rank, + compute_stream): + """ + Update sharded parameter p with the scattered_u. + Only worker_rank frees computed_u. 
+ """ + with torch.cuda.stream(compute_stream): + if state.scatter_event is None: + raise RuntimeError("Scatter event must be set before update") + compute_stream.wait_event(state.scatter_event) + u_dtensor = DTensor.from_local( + state.scattered_u, + placements=p.placements, + device_mesh=p.device_mesh, + ) + + state.scattered_u = u_dtensor + + if rank == state.worker_rank: + # Free computed_u + state.computed_u = None + + Muon._update_p(p, state.scattered_u, lr, adjusted_lr, weight_decay) + state.scattered_u = None + u_dtensor = None + + scales_full = Muon._compute_scales( + p, + state.qk_clip_state) if state.qk_clip_state is not None else None + if scales_full is not None: + # Have to slice scales_full among dim 0 + weight_slices = get_slices_of_dtensor(p, rank, state.shard_mesh, + state.shard_placements) + ratio = p.shape[0] // scales_full.shape[0] + scales_slice = slice( + None if weight_slices[0].start is None else + weight_slices[0].start // ratio, + None if weight_slices[0].stop is None else + weight_slices[0].stop // ratio, + None, + ) + + scales_local = scales_full[scales_slice] + scales_local = DTensor.from_local( + scales_local, + placements=p.placements, + device_mesh=p.device_mesh, + ) + Muon._qk_clip(p, scales_local, state.qk_clip_state.head_dim) + + +def default_is_muon(name, x): + skip_keys = ["embed_tokens", "lm_head", "tok_embeddings", "output"] + return x.ndim >= 2 and not any(key in name for key in skip_keys) + + +def get_default_muon_param_groups(model, is_muon_func=default_is_muon): + muon_params, muon_names = [], [] + non_muon_params = [] + + for n, p in model.named_parameters(): + if not p.requires_grad: + continue + if is_muon_func(n, p): + muon_params.append(p) + muon_names.append(n) + else: + non_muon_params.append(p) + + return [ + { + "params": muon_params, + "names": muon_names, + "use_muon": True, + }, + { + "params": non_muon_params, + "use_muon": False, + }, + ] + + +def parse_qk_layer(name: str) -> tuple[str | None, int]: + """ + 
Parse a parameter name to check if it is a query/key projection layer + ('wq', 'wk', 'q_proj', 'k_proj') and return (kind, layer_index). + + Returns: + (kind, layer_idx) or (None, -1) if not matched. + + Example: + 'model.3.attn.wq.weight' -> ('wq', 3) + 'model.5.attn.wk.weight' -> ('wk', 5) + 'model.2.attn.q_proj.weight' -> ('q_proj', 2) + 'model.7.attn.k_proj.weight' -> ('k_proj', 7) + 'model.4.attn.v_proj.weight' -> (None, -1) + """ + parts = name.split('.') + if len(parts) < 3: + return None, -1 + + kind = parts[-2] + + layer_idx = -1 + for part in reversed(parts): + if part.isdigit(): + layer_idx = int(part) + break + + if kind in ('wq', 'wk', 'q_proj', 'k_proj'): + return kind, layer_idx + + return None, -1 + + +@dataclass +class QKClipInfo: + """Per-parameter dynamic info computed from config + runtime logits.""" + kind: str | None # 'wq'/'q_proj' or 'wk'/'k_proj' or None + indices: list[int] # which heads to consider for clipping + head_dim: int # from config + threshold: float # from config + logit: torch.Tensor | None + + +class Muon(torch.optim.Optimizer): + """ + Muon - MomentUm Orthogonalized by Newton-schulz + + Muon internally runs standard SGD-momentum, and then performs an orthogonalization post- + processing step, in which each 2D parameter's update is replaced with the nearest orthogonal + matrix. To efficiently orthogonalize each update, we use a Newton-Schulz iteration, which has + the advantage that it can be stably run in bfloat16 on the GPU. + + Some warnings: + - We believe this optimizer is unlikely to work well for training with small batch size. + - We believe it may not work well for finetuning pretrained models, but we haven't tested this. + + Arguments: + model: The model to be optimized by Muon. + is_muon_func: A function that takes a parameter and its name, and returns whether the parameter should be optimized by Muon. + lr: The learning rate. The updates will have spectral norm of `lr`. 
(0.02 is a good default) + momentum: The momentum used by the internal SGD. (0.95 is a good default) + nesterov: Whether to use Nesterov-style momentum in the internal SGD. (recommended) + ns_steps: The number of Newton-Schulz iterations to run. (6 is probably always enough) + weight_decay: The weight decay for Muon and AdamW. + {0, 1}-D or are detected as being the embed or lm_head will be optimized by AdamW as well. + adamw_lr: The learning rate for the internal AdamW. + adamw_betas: The betas for the internal AdamW. + adamw_eps: The epsilon for the internal AdamW. + none_grad: Whether to set p.grad to None after gathering the gradients. This can save memory. + debug: Whether to print debug information. + clip_info : Configuration for QK clipping. Expected keys: + - "q_indices" (list[int]): Indices of query heads to consider. + - "k_indices" (list[int]): Indices of key heads to consider. + - "head_dim" (int): Dimensionality of each attention head. + - "threshold" (float): Threshold value; heads whose QK logits exceed + this value will be scaled down. + Default is: + { + "q_indices": [], + "k_indices": [], + "head_dim": 128, + "threshold": 100 + } + warmup_step : How many all2all gather, compute operations are launched in advance + before the corresponding all2all scatter steps begin. + A higher warmup_step increases memory usage but can improve + performance by overlapping communication. + Parallel muon only. + chunk_size : Batch size of parameters to process in each + all2all gather/compute/scatter step. + Use shard ranks * DEFAULT_CHUNK_SIZE_RATIO when -1 is specified. + use_distributed_muon: Use distributed muon by Liu et al. (2024). + For testing purpose only. 
+ small_param_numel_threshold: Threshold for classifying parameters as small and falling back to distributed Muon + """ + + def __init__(self, + params, + lr=1e-3, + momentum=0.95, + nesterov=True, + ns_steps=5, + weight_decay=0.1, + adamw_betas=(0.9, 0.95), + adamw_eps=1e-8, + none_grad=True, + debug=False, + clip_config={ + "q_indices": [], + "k_indices": [], + "head_dim": 128, + "threshold": 100 + }, + warmup_step=5, + chunk_size=-1, + use_distributed_muon=False, + small_param_numel_threshold=65536): + defaults = dict( + lr=lr, + weight_decay=weight_decay, + momentum=momentum, + nesterov=nesterov, + ns_steps=ns_steps, + adamw_betas=adamw_betas, + adamw_eps=adamw_eps, + none_grad=none_grad, + use_muon=True, + ) + error_message = "The key 'use_muon' is not set in parameter group {idx}. Assuming all parameters in the group will use muon optimization, which may lead to unexpected behavior." + instruction_code = "\n\n please follow this code snippet \n```optimizer = get_kernel('motif-technologies/optimizer')\n\n\nparams = optimizer.muon.get_default_muon_param_groups(model)\n\noptim = optimizer.Muon(params, ...)```" + + if isinstance(params, types.GeneratorType): + raise ValueError(error_message.format(idx=0) + instruction_code) + for _idx, param_group in enumerate(params): + if param_group.get("use_muon", None) is None: + raise ValueError( + error_message.format(idx=_idx) + instruction_code) + + super().__init__(params, defaults) + + self.rank = None + + self.comm_stream = torch.cuda.Stream() + self.compute_stream = torch.cuda.Stream() + self.debug = debug + self.clip_config = clip_config + self.warmup_step = warmup_step + self.chunk_size = chunk_size + self.use_distributed_muon = use_distributed_muon + self.small_param_numel_threshold = small_param_numel_threshold + + def _calc_flops(self, G, steps): + assert len(G.shape) == 2 + M, N = G.shape + if M > N: + M, N = N, M + + return steps * ((M**3) * 2 + (M**2 * N) * 4 + M * N * 2 + M**2 * 3) + + def 
adjust_lr_for_muon(self, lr, param_shape): + A, B = param_shape[:2] + # We adjust the learning rate and weight decay based on the size of the parameter matrix + # as describted in the paper + adjusted_ratio = 0.2 * math.sqrt(max(A, B)) + adjusted_lr = lr * adjusted_ratio + return adjusted_lr + + def set_rank_once(self, rank): + if self.rank is None: + self.rank = rank + else: + assert self.rank == rank + + def get_shard_mesh(self, p): + """ + Get the shard mesh for a parameter p on the given rank. + """ + assert isinstance( + p, DTensor), "Parallel Muon only supports DTensor parameters." + + shard_mesh, shard_pg, shard_placements = construct_shard_mesh( + p.placements, p.device_mesh) + + # set rank with the local rank in the shard process group + self.set_rank_once(dist.get_rank(group=shard_pg)) + + return shard_mesh, shard_pg, shard_placements + + def init_state_and_assign_params(self, names, params, group, qk_logits): + param_to_state = {} + param_to_flops = {} + + total_flops = 0 + for p in params: + g = p.grad + if g is None: + continue + assert g.ndim == 2, "Muon only supports 2D parameters." 
+ + flops = self._calc_flops(g, group["ns_steps"]) + param_to_flops[id(p)] = flops + total_flops += flops + + if self.debug: + print(f"Total TFLOPs for Muon: {total_flops / 1e12:.2f} TFLOPs", + flush=True) + + paired = list(zip(names, params)) + + paired_sorted = sorted(paired, + key=lambda x: param_to_flops[id(x[1])], + reverse=True) + + names_sorted, params_sorted = zip(*paired_sorted) + ordered_names = list(names_sorted) + ordered_params = list(params_sorted) + + round_robin = 0 + mesh = ordered_params[0].device_mesh + placements = ordered_params[0].placements + + shard_mesh, shard_pg, shard_placements = self.get_shard_mesh( + ordered_params[0]) + shard_mesh_flattened = shard_mesh.mesh.flatten() + num_ranks = dist.get_world_size(group=shard_pg) + + for n, p in zip(ordered_names, ordered_params): + if mesh != p.device_mesh: + raise ValueError("All parameters must be on the same mesh.") + if placements != p.placements: + raise ValueError("All parameters must have same placements.") + + worker_rank = shard_mesh_flattened[round_robin].item() % num_ranks + round_robin = (round_robin + 1) % len(shard_mesh_flattened) + qk_clip_state = self.get_qk_clip_info(n, qk_logits) + + param_to_state[id(p)] = _muon_state( + worker_rank=worker_rank, + process_group=shard_pg, + shard_mesh=shard_mesh, + shard_placements=shard_placements, + name=n, + qk_clip_state=qk_clip_state, + ) + + return param_to_state, ordered_params + + def base(self, names, params, group, lr, weight_decay, momentum, + qk_logits): + # generate weight updates in distributed fashion + for n, p in zip(names, params): + g = p.grad + if g is None: + continue + if g.ndim > 2: + g = g.view(g.size(0), -1) + assert g is not None + + g = self._update_g(p, g, group, momentum) + + u = _zeropower_via_newtonschulz5(g.to(COMM_DTYPE), + steps=group["ns_steps"]) + + adjusted_lr = self.adjust_lr_for_muon(lr, p.shape) + Muon._update_p(p, u, lr, adjusted_lr, weight_decay) + + qk_clip_state = self.get_qk_clip_info(n, qk_logits) + 
+ scales_full = self._compute_scales( + p, qk_clip_state) if qk_clip_state is not None else None + if scales_full is not None: + Muon._qk_clip(p, scales_full, qk_clip_state.head_dim) + + def distributed_muon( + self, + names: list[str], + params: list[torch.nn.Parameter], + group: dict[str, Any], + lr: float, + weight_decay: float, + momentum: float, + qk_logits: list[torch.Tensor | DTensor] | None, + ): + """ Implementation of Distributed Muon by Liu et al. """ + + for n, p in zip(names, params): + g = p.grad + if g is None: + continue + if g.ndim > 2: + g = g.view(g.size(0), -1) + assert g is not None + + g = self._update_g(p, g, group, momentum) + + # Gather G + if isinstance(p.data, DTensor): + g_full = g.full_tensor() + p_full = p.data.full_tensor() + else: + g_full = g + p_full = p + + u_full = _zeropower_via_newtonschulz5(g_full.to(COMM_DTYPE), + steps=group["ns_steps"]) + + adjusted_lr = self.adjust_lr_for_muon(lr, p_full.shape) + Muon._update_p(p_full, u_full, lr, adjusted_lr, weight_decay) + + qk_clip_state = self.get_qk_clip_info(n, qk_logits) + + scales_full = self._compute_scales( + p_full, qk_clip_state) if qk_clip_state is not None else None + + if scales_full is not None: + Muon._qk_clip(p_full, scales_full, qk_clip_state.head_dim) + + if isinstance(p.data, DTensor): + ndims = len(p.device_mesh.mesh.shape) + p_replicate = DTensor.from_local( + p_full, + device_mesh=p.device_mesh, + placements=[Replicate() for _ in range(ndims)], + ) + + p_sharded = p_replicate.redistribute( + device_mesh=p.device_mesh, + placements=p.placements, + ) + + p.copy_(p_sharded) + + def _update_g(self, p, g, group, momentum): + # calc update + state = self.state[p] + buf = state.setdefault("momentum_buffer", torch.zeros_like(g)) + torch.add(g, buf, alpha=momentum, out=buf) + if group["nesterov"]: + g.add_(buf, alpha=momentum) + return g + return buf + + @staticmethod + def _update_p(p, u, lr, adjusted_lr, weight_decay): + if isinstance(p, torch.nn.Parameter): + # apply 
weight decay + p.data.mul_(1 - lr * weight_decay) + # apply update + p.data.add_(u, alpha=-adjusted_lr) + else: + p.mul_(1 - lr * weight_decay) + p.add_(u, alpha=-adjusted_lr) + + def get_qk_clip_info(self, n, qk_logits): + if self.clip_config is None: + return None + + head_dim = self.clip_config.get('head_dim') + threshold = self.clip_config.get('threshold') + kind, layer_idx = parse_qk_layer(n) + + logit, indices = None, [] + if qk_logits is not None and kind is not None: + logit = qk_logits[layer_idx] + indices_key = 'q_indices' if 'q' in kind else 'k_indices' + indices = self.clip_config.get(indices_key, []) or [] + + if isinstance(logit, DTensor): + # In TP settings, qk_logits may be DTensor + # We convert it to full tensor here for simplicity + logit = logit.full_tensor() + + return QKClipInfo( + kind=kind, + indices=indices, + head_dim=head_dim, + threshold=threshold, + logit=logit, + ) + + @staticmethod + def _compute_scales(p, qk_clip_state): + kind = qk_clip_state.kind + indices = qk_clip_state.indices + head_dim = qk_clip_state.head_dim + threshold = qk_clip_state.threshold + logit = qk_clip_state.logit + + H_global = p.shape[0] // head_dim + scales_full = torch.ones(H_global, device=p.data.device) + scaling = 0 + + for logit_idx, head_idx in enumerate(indices): + v_ele = float(logit[logit_idx]) + if v_ele > threshold: + new_scale = math.sqrt(threshold / v_ele) + if new_scale < scales_full[head_idx]: + scales_full[head_idx] = new_scale + logger.info( + f"[{kind}] Head {head_idx} exceeded threshold " + f"(value={v_ele:.4f}, threshold={threshold:.4f}) -> applying scale={new_scale:.4f}" + ) + scaling += 1 + + return scales_full if scaling > 0 else None + + @staticmethod + def _qk_clip(p, scales, head_dim): + if isinstance(p, torch.nn.Parameter): + W = p.data.view(-1, head_dim, p.data.shape[1]) + W.mul_(scales.view(-1, 1, 1)) + else: + W = p.view(-1, head_dim, p.shape[1]) + W.mul_(scales.view(-1, 1, 1)) + + def parallel(self, names, params, group, lr, 
weight_decay, momentum, + qk_logits): + """ + Perform a parallel optimization step using Muon. + """ + + for p in params: + g = p.grad + if g is None: + continue + if g.ndim > 2: + g = g.view(g.size(0), -1) + + # Update g in the local rank + g = self._update_g( + p, + g, + group, + momentum=momentum, + ) + p.grad = g + + param_to_state, ordered_params = self.init_state_and_assign_params( + names, params, group, qk_logits) + + assert self.rank is not None + + def enqueue_all2all_gather(start_idx, chunk_size): + target_params = ordered_params[start_idx:start_idx + chunk_size] + if target_params: + alloc_event = _alloc_gathered_grad(target_params, + param_to_state, self.rank, + self.compute_stream) + _all2all_gather(target_params, param_to_state, self.rank, + self.comm_stream, group["none_grad"], + alloc_event) + + def enqueue_computes(start_idx, chunk_size): + for p in ordered_params[start_idx:start_idx + chunk_size]: + state = param_to_state[id(p)] + _compute_u(p, state, group["ns_steps"], self.rank, + self.compute_stream) + + def enqueue_all2all_scatter(start_idx, chunk_size): + target_params = ordered_params[start_idx:start_idx + chunk_size] + if target_params: + alloc_event = _alloc_scattered_u(target_params, param_to_state, + self.rank, + self.compute_stream) + _all2all_scatter(target_params, param_to_state, self.rank, + self.comm_stream, alloc_event) + + def enqueue_update_param(start_idx, chunk_size): + for p in ordered_params[start_idx:start_idx + chunk_size]: + state = param_to_state[id(p)] + adjusted_lr = self.adjust_lr_for_muon(lr, p.shape) + _update_param(p, state, lr, adjusted_lr, weight_decay, + self.rank, self.compute_stream) + + if self.chunk_size == -1: + shard_ranks = dist.get_world_size(param_to_state[id( + params[0])].process_group) + chunk_size = shard_ranks * DEFAULT_CHUNK_SIZE_RATIO + elif self.chunk_size > 0: + chunk_size = self.chunk_size + else: + raise ValueError("chunk_size must be -1 or a positive integer.") + + # Wait grad update + 
self.comm_stream.wait_stream(torch.cuda.current_stream()) + + warmup_step = self.warmup_step + for i in range(0, warmup_step): + enqueue_all2all_gather(i * chunk_size, chunk_size) + enqueue_computes(i * chunk_size, chunk_size) + + for i in range(0, len(params) + chunk_size - 1, chunk_size): + enqueue_all2all_scatter(i, chunk_size) + enqueue_all2all_gather(i + warmup_step * chunk_size, chunk_size) + enqueue_update_param(i, chunk_size) + enqueue_computes(i + warmup_step * chunk_size, chunk_size) + + # Wait the last update_param to finish + torch.cuda.current_stream().wait_stream(self.compute_stream) + + @staticmethod + def _fused_adamw( + params: list[torch.Tensor], + grads: list[torch.Tensor], + exp_avgs: list[torch.Tensor], + exp_avg_sqs: list[torch.Tensor], + max_exp_avg_sqs: list[torch.Tensor], + state_steps: list[torch.Tensor], + amsgrad: bool, + beta1: float, + beta2: float, + lr: float | torch.Tensor, + weight_decay: float, + eps: float, + maximize: bool, + ) -> None: + if not params: + return + + # We only shuffle around the lr when it is a Tensor and on CUDA, otherwise, we prefer + # treating it as a scalar. 
+ lr_dict: DeviceDict | None = ({ + lr.device: lr + } if isinstance(lr, torch.Tensor) and str(lr.device) != "cpu" else + None) + grouped_tensors = torch.optim.Optimizer._group_tensors_by_device_and_dtype( + [ + params, grads, exp_avgs, exp_avg_sqs, max_exp_avg_sqs, + state_steps + ] # type: ignore[list-item] + ) + for (device, _), ( + ( + device_params_, + device_grads_, + device_exp_avgs_, + device_exp_avg_sqs_, + device_max_exp_avg_sqs, + device_state_steps_, + ), + _, + ) in grouped_tensors.items(): + device_params = cast(list[torch.Tensor], device_params_) + device_grads = cast(list[torch.Tensor], device_grads_) + device_exp_avgs = cast(list[torch.Tensor], device_exp_avgs_) + device_exp_avg_sqs = cast(list[torch.Tensor], device_exp_avg_sqs_) + device_state_steps = cast(list[torch.Tensor], device_state_steps_) + + if lr_dict is not None and device not in lr_dict: + lr_dict[device] = lr.to( + device=device, + non_blocking=True) # type: ignore[union-attr] + lr = lr_dict[device] + torch._foreach_add_(device_state_steps, 1) + func = torch._fused_adamw_ + func( + device_params, + device_grads, + device_exp_avgs, + device_exp_avg_sqs, + device_max_exp_avg_sqs, # type: ignore[arg-type] + device_state_steps, + amsgrad=amsgrad, + lr=lr, # type: ignore[arg-type] + beta1=beta1, + beta2=beta2, + weight_decay=weight_decay, + eps=eps, + maximize=maximize, + ) + + def _step_muon(self, group, qk_logits=None): + params = group["params"] + lr = group["lr"] + weight_decay = group["weight_decay"] + momentum = group["momentum"] + names = group["names"] + + param_dtensors = [] + name_dtensors = [] + + param_tensors = [] + name_tensors = [] + + param_dtensors_small = [] + name_dtensors_small = [] + + if self.use_distributed_muon: + self.distributed_muon(names=names, + params=params, + group=group, + lr=lr, + weight_decay=weight_decay, + momentum=momentum, + qk_logits=qk_logits) + return + + # For simplicity, we use distributed Muon for small parameters + # whose number of elements is 
below a threshold. + for n, p in zip(names, params): + if p is None or p.grad is None: + continue + if isinstance(p.data, DTensor): + if all( + isinstance(placement, Replicate) + for placement in p.placements): + param_tensors.append(p) + name_tensors.append(n) + elif p.data.numel() <= self.small_param_numel_threshold: + param_dtensors_small.append(p) + name_dtensors_small.append(n) + else: + param_dtensors.append(p) + name_dtensors.append(n) + elif isinstance(p.data, torch.Tensor): + param_tensors.append(p) + name_tensors.append(n) + else: + raise TypeError(f"Unsupported parameter type: {type(p.data)}") + + logger.debug( + f"[Muon] {len(param_dtensors)} DTensors, {len(param_tensors)} Tensors, " + f"{len(param_dtensors_small)} Small DTensors") + + def group_dtensors(dtensors, names): + # To support different placements, we group parameters by placements + # and run parallel Muon on each group. + + placement_to_params = defaultdict(lambda: ([], [])) + # type: dict[tuple[Placement, DeviceMesh], tuple[list[str], list[DTensor]]] + + assert len(dtensors) == len(names) + for p, n in zip(dtensors, names): + placement_to_params[tuple([p.placements, + p.device_mesh])][0].append(n) + placement_to_params[tuple([p.placements, + p.device_mesh])][1].append(p) + return placement_to_params + + if len(param_dtensors_small) > 0: + if not dist.is_initialized(): + raise RuntimeError( + "Parallel Muon requires torch.distributed to be initialized." + ) + + self.distributed_muon( + params=param_dtensors_small, + names=name_dtensors_small, + group=group, + lr=lr, + weight_decay=weight_decay, + momentum=momentum, + qk_logits=qk_logits, + ) + + if len(param_dtensors) > 0: + if not dist.is_initialized(): + raise RuntimeError( + "Parallel Muon requires torch.distributed to be initialized." 
+ ) + + dtensor_group = group_dtensors(param_dtensors, name_dtensors) + for _, (names, params) in dtensor_group.items(): + self.parallel( + names, + params, + group, + lr=lr, + weight_decay=weight_decay, + momentum=momentum, + qk_logits=qk_logits, + ) + + if len(param_tensors) > 0: + self.base( + name_tensors, + param_tensors, + group, + lr=lr, + weight_decay=weight_decay, + momentum=momentum, + qk_logits=qk_logits, + ) + + def _step_adamw_params(self, params, group): + params_with_grads = [] + grads = [] + moment1 = [] + moment2 = [] + max_exp_avg_sqs = [] + state_steps = [] + lr = group["lr"] + beta1, beta2 = group["adamw_betas"] + eps = group["adamw_eps"] + weight_decay = group["weight_decay"] + + for p in params: + g = p.grad + if g is None: + continue + state = self.state[p] + params_with_grads.append(p) + grads.append(g) + if "step" not in state: + state["step"] = (torch.zeros((), + dtype=torch.float32, + device=p.device)) + state["moment1"] = torch.zeros_like(g) + state["moment2"] = torch.zeros_like(g) + moment1.append(state["moment1"]) + moment2.append(state["moment2"]) + if not isinstance(state["step"], torch.Tensor): + step_tensor = torch.tensor(state["step"], + dtype=torch.float32, + device=p.device) + else: + step_tensor = state["step"] + state_steps.append(step_tensor) + + self._fused_adamw( + params_with_grads, + grads, + moment1, + moment2, + max_exp_avg_sqs, + state_steps, + amsgrad=False, + beta1=beta1, + beta2=beta2, + lr=lr, + weight_decay=weight_decay, + eps=eps, + maximize=False, + ) + + def _step_adamw(self, group): + params = group["params"] + + # group params with it's type and placement + placement_to_params: dict[tuple[Placement | type, + DeviceMesh | None]] = defaultdict(list) + for p in params: + match p: + case DTensor(): + placement_to_params[tuple([p.placements, + p.device_mesh])].append(p) + case torch.Tensor(): + placement_to_params[tuple([torch.Tensor, None])].append(p) + + for params in placement_to_params.values(): + 
self._step_adamw_params(params, group) + + @torch.no_grad + def step(self, closure=None, qk_logits=None): + """Perform a single optimization step. + + Args: + closure (Callable, optional): A closure that reevaluates the model + and returns the loss. + qk_logits (dict[int, Tensor], optional): A dictionary mapping layer indices + to 1D tensors of shape (num_heads,), representing the maximum + QK logits across all tokens, computed as + (1 / sqrt(head_dim)) * (Q @ K^T). + """ + loss = None + if closure is not None: + with torch.enable_grad(): + loss = closure() + + for group in self.param_groups: + if group["use_muon"]: + self._step_muon(group, qk_logits=qk_logits) + else: + self._step_adamw(group) + + return loss diff --git a/build/torch29-cxx11-cu128-x86_64-linux/optimizer/__init__.py b/build/torch29-cxx11-cu128-x86_64-linux/optimizer/__init__.py index 239c7a65f8293e7d0df28f05fce645af56d628c0..03dbc1afe1cf156661a2b1b22003cd5f599a0309 100644 --- a/build/torch29-cxx11-cu128-x86_64-linux/optimizer/__init__.py +++ b/build/torch29-cxx11-cu128-x86_64-linux/optimizer/__init__.py @@ -1,5 +1,26 @@ -from .muon import Muon +import ctypes +import sys -__all__ = [ - "Muon", -] +import importlib +from pathlib import Path +from types import ModuleType + +def _import_from_path(file_path: Path) -> ModuleType: + # We cannot use the module name as-is, after adding it to `sys.modules`, + # it would also be used for other imports. So, we make a module name that + # depends on the path for it to be unique using the hex-encoded hash of + # the path. 
+ path_hash = "{:x}".format(ctypes.c_size_t(hash(file_path.absolute())).value) + module_name = path_hash + spec = importlib.util.spec_from_file_location(module_name, file_path) + if spec is None: + raise ImportError(f"Cannot load spec for {module_name} from {file_path}") + module = importlib.util.module_from_spec(spec) + if module is None: + raise ImportError(f"Cannot load module {module_name} from spec") + sys.modules[module_name] = module + spec.loader.exec_module(module) # type: ignore + return module + + +globals().update(vars(_import_from_path(Path(__file__).parent.parent / "__init__.py"))) diff --git a/build/torch29-cxx11-cu128-x86_64-linux/optimizer/_ops.py b/build/torch29-cxx11-cu128-x86_64-linux/optimizer/_ops.py deleted file mode 100644 index 7d598206add1bca142661a3df6c510e3d9575d54..0000000000000000000000000000000000000000 --- a/build/torch29-cxx11-cu128-x86_64-linux/optimizer/_ops.py +++ /dev/null @@ -1,9 +0,0 @@ -import torch -from . import _optimizer_23d68bb_dirty -ops = torch.ops._optimizer_23d68bb_dirty - -def add_op_namespace_prefix(op_name: str): - """ - Prefix op by namespace. 
- """ - return f"_optimizer_23d68bb_dirty::{op_name}" \ No newline at end of file diff --git a/build/torch29-cxx11-cu128-x86_64-linux/optimizer/_optimizer_23d68bb_dirty.abi3.so b/build/torch29-cxx11-cu128-x86_64-linux/optimizer/_optimizer_23d68bb_dirty.abi3.so deleted file mode 100755 index 3fd487db1ff3b2ee3b5ab65ea2272e7fe95e5c76..0000000000000000000000000000000000000000 --- a/build/torch29-cxx11-cu128-x86_64-linux/optimizer/_optimizer_23d68bb_dirty.abi3.so +++ /dev/null @@ -1,3 +0,0 @@ -version https://git-lfs.github.com/spec/v1 -oid sha256:fc97ff00a3255d5eb363958b1e619eadbc4315f1930d0fb59cfc9560c3951721 -size 1983488 diff --git a/build/torch29-cxx11-cu128-x86_64-linux/optimizer/distributed/utils.py b/build/torch29-cxx11-cu128-x86_64-linux/optimizer/distributed/utils.py deleted file mode 100644 index 0b4b58bfb329b1c015129e4c4fc99f7bfa2ab30a..0000000000000000000000000000000000000000 --- a/build/torch29-cxx11-cu128-x86_64-linux/optimizer/distributed/utils.py +++ /dev/null @@ -1,174 +0,0 @@ -import torch -import torch.distributed as dist -from torch.distributed import ProcessGroup -from torch.distributed.device_mesh import DeviceMesh -from torch.distributed.tensor import DTensor -from torch.distributed.tensor.placement_types import (Placement, Shard, - _StridedShard) - - -def get_slices_of_dtensor( - target: DTensor | torch.Tensor, - local_rank: int, - shard_mesh: DeviceMesh, - shard_placements: tuple[Placement], -) -> tuple[slice]: - """ - Get the slice of local tensor for a given rank from a tensor. - Args: - target (DTensor | torch.Tensor): The target tensor. - rank (int): The local rank of the shard group. - shard_mesh (DeviceMesh): The shard mesh. It consists of global ranks. - shard_placements (tuple[Placement]): The shard placements. 
- """ - - slices: list[slice] = [slice(0, dim_size) for dim_size in target.size()] - - # find the global rank of the local rank in the shard mesh - rank = sorted(shard_mesh.mesh.flatten().tolist())[local_rank] - - rank_coords = (shard_mesh.mesh == rank).nonzero() - - assert len(rank_coords) == 1 - rank_coords = tuple(rank_coords[0].tolist()) - - assert len(rank_coords) == len(shard_placements) - - # Caution: Assuming replicate-to-shard of the shard mesh goes with - # left-to-right sharding. This is ensured by the sorting logic of - # construct_shard_mesh function. - for i, (rank_coord, - placement) in enumerate(zip(rank_coords, shard_placements)): - assert isinstance(placement, Shard) - - num_ranks = shard_mesh.mesh.shape[i] - - dim = placement.dim - dim_size = (slices[dim].stop - slices[dim].start) - - if dim_size % num_ranks != 0: - raise NotImplementedError( - f"Dimension size {dim_size} is not divisible " - f"by number of ranks {num_ranks} for shard " - f"placement on dim {dim}.") - - shard_size = dim_size // num_ranks - - start = slices[dim].start + rank_coord * shard_size - end = start + shard_size - - assert start < end <= slices[dim].stop - - slices[dim] = slice(start, end) - - return tuple(slices) - - -_ranks_to_dist_cache: dict[tuple[int, ...], tuple[DeviceMesh, ProcessGroup]] = dict() - - -def construct_shard_mesh( - placements: tuple[Placement], - mesh: DeviceMesh, -) -> (DeviceMesh, ProcessGroup, tuple[Placement]): - """ - Construct Shard Mesh and Placements for unsharding. - It removes Replicate placements and constructs a new Mesh and ProcessGroup. - """ - my_rank = dist.get_rank() - - assert mesh.mesh.device.type == 'cpu' - - # Copy mesh to avoid modifying the original mesh - mesh = mesh.mesh.clone() - - # 1. Sort placements. Replicate first, then Shard by dim ascending. - - # For Shard, strided shard comes after regular shard on the same dim - # to preserve left-to-right order of replicate-to-shard. 
- # This is because that strided shard is using stride to represent - # more fine-grained sharding on the same dim. - # Please check the URL below for _StridedShard. - # https://github.com/pytorch/pytorch/blob/v2.8.0/torch/distributed/tensor/placement_types.py#L366 - - def placement_sort_key( - placement_with_index: tuple[float, Placement] - ) -> tuple[int, float, int]: # (dim, split factor, original index) - index, placement = placement_with_index - is_replicate = placement.is_replicate() - is_shard = placement.is_shard() - is_partial = placement.is_partial() - - assert is_replicate or is_shard, f"Unsupported placement type: {type(placement)}" - assert not is_partial, "Partial placement is not supported." - - if is_replicate: - return (-1.0, 0, index) - elif is_shard: - if isinstance(placement, _StridedShard): - return (placement.dim, 1 / placement.split_factor, index) - return (placement.dim, 0, index) - else: - raise TypeError(f"Unknown placement type: {type(placement)}") - - placements_with_index: list[tuple[int, - Placement]] = list(enumerate(placements)) - placements_with_index = sorted(placements_with_index, - key=placement_sort_key) - - sorted_indices, sorted_placements = zip(*placements_with_index) - - # 2. Permute mesh according to sorted placements. - sorted_mesh = mesh.permute(sorted_indices) - - # 3. 
Collect list of shard meshes by removing replicate dims - # For example, (2, 3, 4, 4) with placements [R, R, S(0), S(1)] - # shard_meshes should be list with 2 * 3 = 6 shard meshes of shape (4, 4) - num_replicates = sum(1 for p in sorted_placements if p.is_replicate()) - - # merge replicate dims - # shard_meshes became a list of shard meshes with a length of replicate degree - if num_replicates > 0: - sorted_mesh = sorted_mesh.flatten( - 0, num_replicates - 1) if num_replicates > 1 else sorted_mesh - shard_meshes = list(torch.unbind(sorted_mesh, dim=0)) - else: - shard_meshes = [sorted_mesh] - shard_placements = sorted_placements[num_replicates:] - - # assume all shard placements are different - assert len(shard_placements) == len(set(shard_placements)) - - # 4. Construct ProcessGroups - # Caution: all groups should be created in the same order in all processes, - # even though each process only needs its own group. - - # To use tensor as dict key, convert it to tuple - def tensor_to_tuple(t): - if isinstance(t, torch.Tensor): - t = t.tolist() - if isinstance(t, list): - return tuple(tensor_to_tuple(x) for x in t) - return t - - my_shard_mesh_as_tuple = None - for shard_mesh in shard_meshes: - assert isinstance(shard_mesh, torch.Tensor) - shard_mesh_as_tuple = tensor_to_tuple(shard_mesh) - - if (my_rank == shard_mesh).any().item(): - assert my_shard_mesh_as_tuple is None - my_shard_mesh_as_tuple = shard_mesh_as_tuple - - # update global cache - if shard_mesh_as_tuple not in _ranks_to_dist_cache: - shard_process_group = dist.new_group(shard_mesh.flatten().tolist()) - _ranks_to_dist_cache[shard_mesh_as_tuple] = ( - DeviceMesh(device_type="cuda", mesh=shard_mesh), - shard_process_group, - ) - - my_shard_mesh, my_shard_process_group = _ranks_to_dist_cache[ - my_shard_mesh_as_tuple] - - return my_shard_mesh, my_shard_process_group, shard_placements diff --git a/build/torch29-cxx11-cu128-x86_64-linux/optimizer/muon.py 
b/build/torch29-cxx11-cu128-x86_64-linux/optimizer/muon.py deleted file mode 100644 index cfbcca71741be70048bfd290c62148b2aceda631..0000000000000000000000000000000000000000 --- a/build/torch29-cxx11-cu128-x86_64-linux/optimizer/muon.py +++ /dev/null @@ -1,1240 +0,0 @@ -import logging -import math -import types -from collections import defaultdict -from dataclasses import dataclass -from typing import Any, cast - -import torch -import torch.distributed as dist -from torch.distributed import ProcessGroup -from torch.distributed.device_mesh import DeviceMesh -from torch.distributed.tensor import DTensor, Replicate -from torch.distributed.tensor.placement_types import Placement - -from .distributed.utils import construct_shard_mesh, get_slices_of_dtensor -from .matmul_transpose_triton import matmul_transpose_assign - -logger = logging.getLogger(__name__) - -COMM_DTYPE = torch.bfloat16 -DEFAULT_CHUNK_SIZE_RATIO = 4 - - -# This code snippet is a modified version adapted from the following GitHub repositories: -# https://github.com/KellerJordan/Muon/blob/master/muon.py -# Muon's Newton–Schulz iteration causes high variance in singular values -# Idea: give each iteration its own 3 coefficients and optimize them via gradient descent. -@torch.no_grad() -# matmul_transpose_assign from : https://github.com/nil0x9/flash-muon -def _zeropower_via_newtonschulz5(G, steps): - """ - Newton-Schulz iteration to compute the zeroth power / orthogonalization of G. We opt to use a - quintic iteration whose coefficients are selected to maximize the slope at zero. For the purpose - of minimizing steps, it turns out to be empirically effective to keep increasing the slope at - zero even beyond the point where the iteration no longer converges all the way to one everywhere - on the interval. 
This iteration therefore does not produce UV^T but rather something like US'V^T - where S' is diagonal with S_{ii}' ~ Uniform(0.5, 1.5), which turns out not to hurt model - performance at all relative to UV^T, where USV^T = G is the SVD. - """ - assert len(G.shape) == 2 - assert G.dtype == COMM_DTYPE - X = G # no manual typecast - - if G.size(0) > G.size(1): - X = X.T - # Ensure spectral norm is at most 1 - X = X / (X.norm() + 1e-7) - buf1 = torch.empty(X.size(0), X.size(0), dtype=X.dtype, device=X.device) - buf2 = torch.empty(X.size(0), X.size(0), dtype=X.dtype, device=X.device) - # Perform the NS iterations - for a, b, c in [ - (4.0848, -6.8946, 2.9270), - (3.9505, -6.3029, 2.6377), - (3.7418, -5.5913, 2.3037), - (2.8769, -3.1427, 1.2046), - (2.8366, -3.0525, 1.2012), - ]: - matmul_transpose_assign(X, buf1) - matmul_transpose_assign(buf1, buf2) - buf1.mul_(b).add_(buf2, alpha=c) - X = torch.addmm(X, buf1, X, alpha=1.0, beta=a) - - if G.size(0) > G.size(1): - X = X.T - return X - - -@dataclass -class _muon_state: - # TODO: use Optional - worker_rank: int - process_group: ProcessGroup - shard_mesh: DeviceMesh - shard_placements: tuple[Placement, ...] 
- name: str - qk_clip_state: torch.Tensor | None = None - gathered_grad: torch.Tensor | None = None - scattered_u: DTensor | None = None - computed_u: torch.Tensor | None = None - gather_event: torch.cuda.Event | None = None - compute_event: torch.cuda.Event | None = None - scatter_event: torch.cuda.Event | None = None - - -def numel_for_rank( - param: DTensor, - local_rank: int, - state: _muon_state, -) -> int: - slices = get_slices_of_dtensor( - param, - local_rank, - state.shard_mesh, - state.shard_placements, - ) - - numel = 1 - for s, dim in zip(slices, param.shape): - start, stop, step = s.indices(dim) - length = max(0, (stop - start + (step - 1)) // step) - numel *= length - - return numel - - -@torch.no_grad() -def _alloc_gathered_grad(params, param_to_state, rank, compute_stream): - """ - Pre-allocate gathered_grad buffer on compute_stream - before launching all2all gather - """ - with torch.cuda.stream(compute_stream): - for p in params: - state = param_to_state[id(p)] - if rank == state.worker_rank: - state.gathered_grad = torch.empty(p.shape, - dtype=COMM_DTYPE, - device="cuda") - else: - state.gathered_grad = None - - alloc_event = torch.cuda.Event() - alloc_event.record(compute_stream) - return alloc_event - - -@torch.no_grad() -def _all2all_gather(params, param_to_state, rank, comm_stream, none_grad, - alloc_event): - """ - All2all gathers shards so each owner rank reconstructs its full gradient - """ - with torch.cuda.stream(comm_stream): - process_group = param_to_state[id(params[0])].process_group - num_ranks = dist.get_world_size(group=process_group) - - # Construct sending buffers - per_dst = [[] for _ in range(num_ranks)] - send_counts = [0] * num_ranks - - for p in params: - state = param_to_state[id(p)] - dst = state.worker_rank - assert dst < num_ranks - shard_elems = numel_for_rank(p, rank, state) - g = p.grad - g = g.to_local().to(COMM_DTYPE).contiguous() - assert g.numel() == shard_elems - per_dst[dst].append(g.view(-1)) - 
send_counts[dst] += shard_elems - - assert any( - len(v) > 0 for v in per_dst - ), "At least one destination rank must receive a sharded tensor" - # list[list[Tensor]] -> list[Tensor] - per_dst = [t for dst in per_dst for t in dst] - - send_buf = torch.cat(per_dst, dim=0) - - owned_params = [ - p for p in params if param_to_state[id(p)].worker_rank == rank - ] - - # Compute receive sizes and allocate receiving buffers - recv_counts = [0] * num_ranks - - for src in range(num_ranks): - total = 0 - for p in owned_params: - state = param_to_state[id(p)] - assert state.worker_rank == rank - total += numel_for_rank(p, src, state) - recv_counts[src] = total - - recv_total = sum(recv_counts) - recv_buf = torch.empty(recv_total, dtype=COMM_DTYPE, device="cuda") - - #All2All - logger.debug(f"send_buf size: {send_buf.numel()}, " - f"recv_buf size: {recv_buf.numel()}, " - f"recv_counts: {recv_counts}, " - f"send_counts: {send_counts}, " - f"process_group: {str(process_group)}") - dist.all_to_all_single( - recv_buf, - send_buf, - output_split_sizes=recv_counts, - input_split_sizes=send_counts, - group=process_group, - ) - - # Reconstructs gathered grad from the received buffer - # - # recv_buf (num ranks = 3) - # - # From rank 0 From rank 1 From rank 2 - # | p1_0, p2_0, p3_0 | p1_1, p2_1, p3_1 | p1_2, p2_2, p3_2 | - # - # Outer loop: - # rank 0 -> rank 1 -> rank2 - # - # Inner loop: - # p1_n -> p2_n -> p3_n - - comm_stream.wait_event(alloc_event) - - off = 0 - for src in range(num_ranks): - if recv_counts[src] == 0: - continue - - block = recv_counts[src] - inner_off = 0 - for p in owned_params: - state = param_to_state[id(p)] - assert state.worker_rank == rank - - # get the slice of the full dtensor corresponding to rank src. 
- slices = get_slices_of_dtensor(state.gathered_grad, src, - state.shard_mesh, - state.shard_placements) - - dst = state.gathered_grad[slices] - assert dst._base is state.gathered_grad - - n = dst.numel() - assert n > 0 - - sg = recv_buf.narrow(0, off + inner_off, n) - sg = sg.reshape_as(dst) - dst.copy_(sg) - - inner_off += n - off += block - - for p in params: - state = param_to_state[id(p)] - if state.worker_rank == rank: - state.gather_event = torch.cuda.Event() - state.gather_event.record(comm_stream) - else: - state.gathered_grad = None - state.gather_event = None - if none_grad: - p.grad = None - - -@torch.no_grad() -def _compute_u(p, state, steps, rank, compute_stream): - """ - On worker_rank, compute the orthogonalized update using Newton-Schulz iteration. - """ - with torch.cuda.stream(compute_stream): - if rank == state.worker_rank: - if state.gather_event is None: - raise RuntimeError("Gather event must be set before compute.") - compute_stream.wait_event(state.gather_event) - u = _zeropower_via_newtonschulz5(state.gathered_grad, steps) - state.gathered_grad = None - state.computed_u = u - state.compute_event = torch.cuda.Event() - state.compute_event.record() - else: - state.computed_u = None - state.compute_event = None - - -@torch.no_grad() -def _alloc_scattered_u(params, param_to_state, rank, compute_stream): - """ - Pre-allocate scattered_u buffer on compute_stream - before launching all2all gather - """ - with torch.cuda.stream(compute_stream): - for p in params: - state = param_to_state[id(p)] - state.scattered_u = torch.empty_like(p.to_local(), - dtype=COMM_DTYPE) - - alloc_event = torch.cuda.Event() - alloc_event.record(compute_stream) - return alloc_event - - -def _all2all_scatter(params, param_to_state, rank, comm_stream, alloc_event): - """ - All2all scatters full gradients to all ranks - """ - with torch.cuda.stream(comm_stream): - process_group = param_to_state[id(params[0])].process_group - num_ranks = 
dist.get_world_size(group=process_group) - owned_params = [ - p for p in params if param_to_state[id(p)].worker_rank == rank - ] - - # Construct sending buffer - per_dst = [[] for _ in range(num_ranks)] - send_counts = [0] * num_ranks - - if owned_params: - for p in owned_params: - state = param_to_state[id(p)] - if state.compute_event is None: - raise RuntimeError( - "Compute event must be set before scatter.") - comm_stream.wait_event(state.compute_event) - state.gathered_grad = None - - assert state.computed_u is not None - - u_full = state.computed_u.to(COMM_DTYPE).contiguous() - - offset = 0 - for dst in range(num_ranks): - # get the slice of the full tensor corresponding to rank dst. - slices = get_slices_of_dtensor(u_full, dst, - state.shard_mesh, - state.shard_placements) - su = u_full[slices].flatten() - - n = su.numel() - assert n > 0 - - per_dst[dst].append(su) - send_counts[dst] += n - offset += n - - assert offset == u_full.numel() - - lengths = [len(v) for v in per_dst] - if all(l > 0 for l in lengths): - assert all( - l == lengths[0] for l in lengths - ), "All destination ranks must have the same number of sharded tensor" - # list[list[Tensor]] -> list[Tensor] - per_dst = [t for dst in per_dst for t in dst] - send_buf = torch.cat(per_dst, dim=0) - else: - # all_to_all requires participation from all ranks - # Even non-owner ranks must join the collective call - send_buf = torch.empty(0, dtype=COMM_DTYPE, device="cuda") - - # Compute receive sizes and allocate receiving buffers - recv_counts = [0] * num_ranks - - for src in range(num_ranks): - total = 0 - for p in params: - state = param_to_state[id(p)] - if state.worker_rank != src: - continue - total += numel_for_rank(p, rank, state) - recv_counts[src] = total - - recv_total = sum(recv_counts) - assert recv_total > 0 - recv_buf = torch.empty(recv_total, dtype=COMM_DTYPE, device="cuda") - - #All2All - dist.all_to_all_single( - recv_buf, - send_buf, - output_split_sizes=recv_counts, - 
input_split_sizes=send_counts, - group=process_group, - ) - - # Copy to pre-allocated scattered_u buffer from the received buffer - # - # recv_buf (num ranks = 3, local_rank = 0) - # - # From rank 0 From rank 1 From rank 2 - # | p1_0, p2_0, p3_0 | p4_0 | p5_0, p6_0 | - # - # Outer loop: - # rank 0 -> rank 1 -> rank2 - # - # Inner loop: - # src(0) : p1_0 -> p2_0 -> p3_0 - # src(1) : p4_0 - # src(2) : p5_0 -> p6_0 - - comm_stream.wait_event(alloc_event) - - off = 0 - for src in range(num_ranks): - block = recv_counts[src] - if block == 0: - continue - - inner_off = 0 - for p in params: - state = param_to_state[id(p)] - if state.worker_rank != src: - continue - n = numel_for_rank(p, rank, state) - assert n > 0 - - flat_local = recv_buf.narrow(0, off + inner_off, - n).view_as(p.to_local()) - state.scattered_u.copy_(flat_local) - - state.scatter_event = torch.cuda.Event() - state.scatter_event.record(comm_stream) - inner_off += n - - assert inner_off == block - off += block - - -def _update_param(p, state, lr, adjusted_lr, weight_decay, rank, - compute_stream): - """ - Update sharded parameter p with the scattered_u. - Only worker_rank frees computed_u. 
- """ - with torch.cuda.stream(compute_stream): - if state.scatter_event is None: - raise RuntimeError("Scatter event must be set before update") - compute_stream.wait_event(state.scatter_event) - u_dtensor = DTensor.from_local( - state.scattered_u, - placements=p.placements, - device_mesh=p.device_mesh, - ) - - state.scattered_u = u_dtensor - - if rank == state.worker_rank: - # Free computed_u - state.computed_u = None - - Muon._update_p(p, state.scattered_u, lr, adjusted_lr, weight_decay) - state.scattered_u = None - u_dtensor = None - - scales_full = Muon._compute_scales( - p, - state.qk_clip_state) if state.qk_clip_state is not None else None - if scales_full is not None: - # Have to slice scales_full among dim 0 - weight_slices = get_slices_of_dtensor(p, rank, state.shard_mesh, - state.shard_placements) - ratio = p.shape[0] // scales_full.shape[0] - scales_slice = slice( - None if weight_slices[0].start is None else - weight_slices[0].start // ratio, - None if weight_slices[0].stop is None else - weight_slices[0].stop // ratio, - None, - ) - - scales_local = scales_full[scales_slice] - scales_local = DTensor.from_local( - scales_local, - placements=p.placements, - device_mesh=p.device_mesh, - ) - Muon._qk_clip(p, scales_local, state.qk_clip_state.head_dim) - - -def default_is_muon(name, x): - skip_keys = ["embed_tokens", "lm_head", "tok_embeddings", "output"] - return x.ndim >= 2 and not any(key in name for key in skip_keys) - - -def get_default_muon_param_groups(model, is_muon_func=default_is_muon): - muon_params, muon_names = [], [] - non_muon_params = [] - - for n, p in model.named_parameters(): - if not p.requires_grad: - continue - if is_muon_func(n, p): - muon_params.append(p) - muon_names.append(n) - else: - non_muon_params.append(p) - - return [ - { - "params": muon_params, - "names": muon_names, - "use_muon": True, - }, - { - "params": non_muon_params, - "use_muon": False, - }, - ] - - -def parse_qk_layer(name: str) -> tuple[str | None, int]: - """ - 
Parse a parameter name to check if it is a query/key projection layer - ('wq', 'wk', 'q_proj', 'k_proj') and return (kind, layer_index). - - Returns: - (kind, layer_idx) or (None, -1) if not matched. - - Example: - 'model.3.attn.wq.weight' -> ('wq', 3) - 'model.5.attn.wk.weight' -> ('wk', 5) - 'model.2.attn.q_proj.weight' -> ('q_proj', 2) - 'model.7.attn.k_proj.weight' -> ('k_proj', 7) - 'model.4.attn.v_proj.weight' -> (None, -1) - """ - parts = name.split('.') - if len(parts) < 3: - return None, -1 - - kind = parts[-2] - - layer_idx = -1 - for part in reversed(parts): - if part.isdigit(): - layer_idx = int(part) - break - - if kind in ('wq', 'wk', 'q_proj', 'k_proj'): - return kind, layer_idx - - return None, -1 - - -@dataclass -class QKClipInfo: - """Per-parameter dynamic info computed from config + runtime logits.""" - kind: str | None # 'wq'/'q_proj' or 'wk'/'k_proj' or None - indices: list[int] # which heads to consider for clipping - head_dim: int # from config - threshold: float # from config - logit: torch.Tensor | None - - -class Muon(torch.optim.Optimizer): - """ - Muon - MomentUm Orthogonalized by Newton-schulz - - Muon internally runs standard SGD-momentum, and then performs an orthogonalization post- - processing step, in which each 2D parameter's update is replaced with the nearest orthogonal - matrix. To efficiently orthogonalize each update, we use a Newton-Schulz iteration, which has - the advantage that it can be stably run in bfloat16 on the GPU. - - Some warnings: - - We believe this optimizer is unlikely to work well for training with small batch size. - - We believe it may not work well for finetuning pretrained models, but we haven't tested this. - - Arguments: - model: The model to be optimized by Muon. - is_muon_func: A function that takes a parameter and its name, and returns whether the parameter should be optimized by Muon. - lr: The learning rate. The updates will have spectral norm of `lr`. 
(0.02 is a good default) - momentum: The momentum used by the internal SGD. (0.95 is a good default) - nesterov: Whether to use Nesterov-style momentum in the internal SGD. (recommended) - ns_steps: The number of Newton-Schulz iterations to run. (6 is probably always enough) - weight_decay: The weight decay for Muon and AdamW. - {0, 1}-D or are detected as being the embed or lm_head will be optimized by AdamW as well. - adamw_lr: The learning rate for the internal AdamW. - adamw_betas: The betas for the internal AdamW. - adamw_eps: The epsilon for the internal AdamW. - none_grad: Whether to set p.grad to None after gathering the gradients. This can save memory. - debug: Whether to print debug information. - clip_info : Configuration for QK clipping. Expected keys: - - "q_indices" (list[int]): Indices of query heads to consider. - - "k_indices" (list[int]): Indices of key heads to consider. - - "head_dim" (int): Dimensionality of each attention head. - - "threshold" (float): Threshold value; heads whose QK logits exceed - this value will be scaled down. - Default is: - { - "q_indices": [], - "k_indices": [], - "head_dim": 128, - "threshold": 100 - } - warmup_step : How many all2all gather, compute operations are launched in advance - before the corresponding all2all scatter steps begin. - A higher warmup_step increases memory usage but can improve - performance by overlapping communication. - Parallel muon only. - chunk_size : Batch size of parameters to process in each - all2all gather/compute/scatter step. - Use shard ranks * DEFAULT_CHUNK_SIZE_RATIO when -1 is specified. - use_distributed_muon: Use distributed muon by Liu et al. (2024). - For testing purpose only. 
- """ - - def __init__(self, - params, - lr=1e-3, - momentum=0.95, - nesterov=True, - ns_steps=5, - weight_decay=0.1, - adamw_betas=(0.9, 0.95), - adamw_eps=1e-8, - none_grad=True, - debug=False, - clip_config={ - "q_indices": [], - "k_indices": [], - "head_dim": 128, - "threshold": 100 - }, - warmup_step=5, - chunk_size=-1, - use_distributed_muon=False): - defaults = dict( - lr=lr, - weight_decay=weight_decay, - momentum=momentum, - nesterov=nesterov, - ns_steps=ns_steps, - adamw_betas=adamw_betas, - adamw_eps=adamw_eps, - none_grad=none_grad, - use_muon=True, - ) - error_message = "The key 'use_muon' is not set in parameter group {idx}. Assuming all parameters in the group will use muon optimization, which may lead to unexpected behavior." - instruction_code = "\n\n please follow this code snippet \n```optimizer = get_kernel('motif-technologies/optimizer')\n\n\nparams = optimizer.muon.get_default_muon_param_groups(model)\n\noptim = optimizer.Muon(params, ...)```" - - if isinstance(params, types.GeneratorType): - raise ValueError(error_message.format(idx=0) + instruction_code) - for _idx, param_group in enumerate(params): - if param_group.get("use_muon", None) is None: - raise ValueError( - error_message.format(idx=_idx) + instruction_code) - - super().__init__(params, defaults) - - self.rank = None - - self.comm_stream = torch.cuda.Stream() - self.compute_stream = torch.cuda.Stream() - self.debug = debug - self.clip_config = clip_config - self.warmup_step = warmup_step - self.chunk_size = chunk_size - self.use_distributed_muon = use_distributed_muon - - def _calc_flops(self, G, steps): - assert len(G.shape) == 2 - M, N = G.shape - if M > N: - M, N = N, M - - return steps * ((M**3) * 2 + (M**2 * N) * 4 + M * N * 2 + M**2 * 3) - - def adjust_lr_for_muon(self, lr, param_shape): - A, B = param_shape[:2] - # We adjust the learning rate and weight decay based on the size of the parameter matrix - # as describted in the paper - adjusted_ratio = 0.2 * math.sqrt(max(A, 
B)) - adjusted_lr = lr * adjusted_ratio - return adjusted_lr - - def set_rank_once(self, rank): - if self.rank is None: - self.rank = rank - else: - assert self.rank == rank - - def get_shard_mesh(self, p): - """ - Get the shard mesh for a parameter p on the given rank. - """ - assert isinstance( - p, DTensor), "Parallel Muon only supports DTensor parameters." - - shard_mesh, shard_pg, shard_placements = construct_shard_mesh( - p.placements, p.device_mesh) - - # set rank with the local rank in the shard process group - self.set_rank_once(dist.get_rank(group=shard_pg)) - - return shard_mesh, shard_pg, shard_placements - - def init_state_and_assign_params(self, names, params, group, qk_logits): - param_to_state = {} - param_to_flops = {} - - total_flops = 0 - for p in params: - g = p.grad - if g is None: - continue - assert g.ndim == 2, "Muon only supports 2D parameters." - - flops = self._calc_flops(g, group["ns_steps"]) - param_to_flops[id(p)] = flops - total_flops += flops - - if self.debug: - print(f"Total TFLOPs for Muon: {total_flops / 1e12:.2f} TFLOPs", - flush=True) - - paired = list(zip(names, params)) - - paired_sorted = sorted(paired, - key=lambda x: param_to_flops[id(x[1])], - reverse=True) - - names_sorted, params_sorted = zip(*paired_sorted) - ordered_names = list(names_sorted) - ordered_params = list(params_sorted) - - round_robin = 0 - mesh = ordered_params[0].device_mesh - placements = ordered_params[0].placements - - shard_mesh, shard_pg, shard_placements = self.get_shard_mesh( - ordered_params[0]) - shard_mesh_flattened = shard_mesh.mesh.flatten() - num_ranks = dist.get_world_size(group=shard_pg) - - for n, p in zip(ordered_names, ordered_params): - if mesh != p.device_mesh: - raise ValueError("All parameters must be on the same mesh.") - if placements != p.placements: - raise ValueError("All parameters must have same placements.") - - worker_rank = shard_mesh_flattened[round_robin].item() % num_ranks - round_robin = (round_robin + 1) % 
len(shard_mesh_flattened) - qk_clip_state = self.get_qk_clip_info(n, qk_logits) - - param_to_state[id(p)] = _muon_state( - worker_rank=worker_rank, - process_group=shard_pg, - shard_mesh=shard_mesh, - shard_placements=shard_placements, - name=n, - qk_clip_state=qk_clip_state, - ) - - return param_to_state, ordered_params - - def base(self, names, params, group, lr, weight_decay, momentum, - qk_logits): - # generate weight updates in distributed fashion - for n, p in zip(names, params): - g = p.grad - if g is None: - continue - if g.ndim > 2: - g = g.view(g.size(0), -1) - assert g is not None - - # calc update - state = self.state[p] - if "momentum_buffer" not in state: - state["momentum_buffer"] = torch.zeros_like(g) - buf = state["momentum_buffer"] - buf.mul_(momentum).add_(g) - if group["nesterov"]: - g = g.add(buf, alpha=momentum) - else: - g = buf - - u = _zeropower_via_newtonschulz5(g.to(COMM_DTYPE), - steps=group["ns_steps"]) - - adjusted_lr = self.adjust_lr_for_muon(lr, p.shape) - Muon._update_p(p, u, lr, adjusted_lr, weight_decay) - - qk_clip_state = self.get_qk_clip_info(n, qk_logits) - - scales_full = self._compute_scales( - p, qk_clip_state) if qk_clip_state is not None else None - if scales_full is not None: - Muon._qk_clip(p, scales_full, qk_clip_state.head_dim) - - def distributed_muon( - self, - names: list[str], - params: list[torch.nn.Parameter], - group: dict[str, Any], - lr: float, - weight_decay: float, - momentum: float, - qk_logits: list[torch.Tensor | DTensor] | None, - ): - """ Implementation of Distributed Muon by Liu et al. 
""" - if qk_logits is not None: - raise NotImplementedError("QK clipping is not supported yet") - - if isinstance(params[0], DTensor): - shard_mesh, _, shard_placements = construct_shard_mesh( - placements=params[0].placements, - mesh=params[0].device_mesh, - ) - - for n, p in zip(names, params): - g = p.grad - if g is None: - continue - if g.ndim > 2: - g = g.view(g.size(0), -1) - assert g is not None - - # calc update - state = self.state[p] - if "momentum_buffer" not in state: - state["momentum_buffer"] = torch.zeros_like(g) - buf = state["momentum_buffer"] - buf.mul_(momentum).add_(g) - if group["nesterov"]: - g = g.add(buf, alpha=momentum) - else: - g = buf - - # Gather G - if isinstance(p.data, DTensor): - g = g.full_tensor() - u = _zeropower_via_newtonschulz5(g.to(COMM_DTYPE), - steps=group["ns_steps"]) - - if isinstance(p.data, DTensor): - slices = get_slices_of_dtensor( - target=p, - local_rank=dist.get_rank(), - shard_mesh=shard_mesh, - shard_placements=shard_placements, - ) - u_shard = u[slices] - u = DTensor.from_local( - u_shard, - device_mesh=p.device_mesh, - placements=p.placements, - ) - - adjusted_lr = self.adjust_lr_for_muon(lr, p.shape) - Muon._update_p(p, u, lr, adjusted_lr, weight_decay) - - def _update_g(self, p, g, group, momentum): - # calc update - state = self.state[p] - buf = state.setdefault("momentum_buffer", torch.zeros_like(g)) - torch.add(g, buf, alpha=momentum, out=buf) - if group["nesterov"]: - g.add_(buf, alpha=momentum) - return g - return buf - - @staticmethod - def _update_p(p, u, lr, adjusted_lr, weight_decay): - # apply weight decay - p.data.mul_(1 - lr * weight_decay) - # apply update - p.data.add_(u, alpha=-adjusted_lr) - - def get_qk_clip_info(self, n, qk_logits): - if self.clip_config is None: - return None - - head_dim = self.clip_config.get('head_dim') - threshold = self.clip_config.get('threshold') - kind, layer_idx = parse_qk_layer(n) - - logit, indices = None, [] - if qk_logits is not None and kind is not None: - 
logit = qk_logits[layer_idx] - indices_key = 'q_indices' if 'q' in kind else 'k_indices' - indices = self.clip_config.get(indices_key, []) or [] - - if isinstance(logit, DTensor): - # In TP settings, qk_logits may be DTensor - # We convert it to full tensor here for simplicity - logit = logit.full_tensor() - - return QKClipInfo( - kind=kind, - indices=indices, - head_dim=head_dim, - threshold=threshold, - logit=logit, - ) - - @staticmethod - def _compute_scales(p, qk_clip_state): - kind = qk_clip_state.kind - indices = qk_clip_state.indices - head_dim = qk_clip_state.head_dim - threshold = qk_clip_state.threshold - logit = qk_clip_state.logit - - H_global = p.shape[0] // head_dim - scales_full = torch.ones(H_global, device=p.data.device) - scaling = 0 - - for logit_idx, head_idx in enumerate(indices): - v_ele = float(logit[logit_idx]) - if v_ele > threshold: - new_scale = math.sqrt(threshold / v_ele) - if new_scale < scales_full[head_idx]: - scales_full[head_idx] = new_scale - logger.info( - f"[{kind}] Head {head_idx} exceeded threshold " - f"(value={v_ele:.4f}, threshold={threshold:.4f}) -> applying scale={new_scale:.4f}" - ) - scaling += 1 - - return scales_full if scaling > 0 else None - - @staticmethod - def _qk_clip(p, scales, head_dim): - W = p.data.view(-1, head_dim, p.data.shape[1]) - W.mul_(scales.view(-1, 1, 1)) - - def parallel(self, names, params, group, lr, weight_decay, momentum, - qk_logits): - """ - Perform a parallel optimization step using Muon. 
- """ - - for p in params: - g = p.grad - if g is None: - continue - if g.ndim > 2: - g = g.view(g.size(0), -1) - - # Update g in the local rank - g = self._update_g( - p, - g, - group, - momentum=momentum, - ) - p.grad = g - - param_to_state, ordered_params = self.init_state_and_assign_params( - names, params, group, qk_logits) - - assert self.rank is not None - - def enqueue_all2all_gather(start_idx, chunk_size): - target_params = ordered_params[start_idx:start_idx + chunk_size] - if target_params: - alloc_event = _alloc_gathered_grad(target_params, - param_to_state, self.rank, - self.compute_stream) - _all2all_gather(target_params, param_to_state, self.rank, - self.comm_stream, group["none_grad"], - alloc_event) - - def enqueue_computes(start_idx, chunk_size): - for p in ordered_params[start_idx:start_idx + chunk_size]: - state = param_to_state[id(p)] - _compute_u(p, state, group["ns_steps"], self.rank, - self.compute_stream) - - def enqueue_all2all_scatter(start_idx, chunk_size): - target_params = ordered_params[start_idx:start_idx + chunk_size] - if target_params: - alloc_event = _alloc_scattered_u(target_params, param_to_state, - self.rank, - self.compute_stream) - _all2all_scatter(target_params, param_to_state, self.rank, - self.comm_stream, alloc_event) - - def enqueue_update_param(start_idx, chunk_size): - for p in ordered_params[start_idx:start_idx + chunk_size]: - state = param_to_state[id(p)] - adjusted_lr = self.adjust_lr_for_muon(lr, p.shape) - _update_param(p, state, lr, adjusted_lr, weight_decay, - self.rank, self.compute_stream) - - if self.chunk_size == -1: - shard_ranks = dist.get_world_size(param_to_state[id( - params[0])].process_group) - chunk_size = shard_ranks * DEFAULT_CHUNK_SIZE_RATIO - elif self.chunk_size > 0: - chunk_size = self.chunk_size - else: - raise ValueError("chunk_size must be -1 or a positive integer.") - - # Wait grad update - self.comm_stream.wait_stream(torch.cuda.current_stream()) - - warmup_step = self.warmup_step - for i 
in range(0, warmup_step): - enqueue_all2all_gather(i * chunk_size, chunk_size) - enqueue_computes(i * chunk_size, chunk_size) - - for i in range(0, len(params) + chunk_size - 1, chunk_size): - enqueue_all2all_scatter(i, chunk_size) - enqueue_all2all_gather(i + warmup_step * chunk_size, chunk_size) - enqueue_update_param(i, chunk_size) - enqueue_computes(i + warmup_step * chunk_size, chunk_size) - - # Wait the last update_param to finish - torch.cuda.current_stream().wait_stream(self.compute_stream) - - @staticmethod - def _fused_adamw( - params: list[torch.Tensor], - grads: list[torch.Tensor], - exp_avgs: list[torch.Tensor], - exp_avg_sqs: list[torch.Tensor], - max_exp_avg_sqs: list[torch.Tensor], - state_steps: list[torch.Tensor], - amsgrad: bool, - beta1: float, - beta2: float, - lr: float | torch.Tensor, - weight_decay: float, - eps: float, - maximize: bool, - ) -> None: - if not params: - return - - # We only shuffle around the lr when it is a Tensor and on CUDA, otherwise, we prefer - # treating it as a scalar. 
- lr_dict: DeviceDict | None = ({ - lr.device: lr - } if isinstance(lr, torch.Tensor) and str(lr.device) != "cpu" else - None) - grouped_tensors = torch.optim.Optimizer._group_tensors_by_device_and_dtype( - [ - params, grads, exp_avgs, exp_avg_sqs, max_exp_avg_sqs, - state_steps - ] # type: ignore[list-item] - ) - for (device, _), ( - ( - device_params_, - device_grads_, - device_exp_avgs_, - device_exp_avg_sqs_, - device_max_exp_avg_sqs, - device_state_steps_, - ), - _, - ) in grouped_tensors.items(): - device_params = cast(list[torch.Tensor], device_params_) - device_grads = cast(list[torch.Tensor], device_grads_) - device_exp_avgs = cast(list[torch.Tensor], device_exp_avgs_) - device_exp_avg_sqs = cast(list[torch.Tensor], device_exp_avg_sqs_) - device_state_steps = cast(list[torch.Tensor], device_state_steps_) - - if lr_dict is not None and device not in lr_dict: - lr_dict[device] = lr.to( - device=device, - non_blocking=True) # type: ignore[union-attr] - lr = lr_dict[device] - torch._foreach_add_(device_state_steps, 1) - func = torch._fused_adamw_ - func( - device_params, - device_grads, - device_exp_avgs, - device_exp_avg_sqs, - device_max_exp_avg_sqs, # type: ignore[arg-type] - device_state_steps, - amsgrad=amsgrad, - lr=lr, # type: ignore[arg-type] - beta1=beta1, - beta2=beta2, - weight_decay=weight_decay, - eps=eps, - maximize=maximize, - ) - - def _step_muon(self, group, qk_logits=None): - params = group["params"] - lr = group["lr"] - weight_decay = group["weight_decay"] - momentum = group["momentum"] - names = group["names"] - - param_dtensors = [] - param_tensors = [] - name_dtensors = [] - name_tensors = [] - - if self.use_distributed_muon: - self.distributed_muon(names=names, - params=params, - group=group, - lr=lr, - weight_decay=weight_decay, - momentum=momentum, - qk_logits=qk_logits) - return - - for n, p in zip(names, params): - if p is None or p.grad is None: - continue - if isinstance(p.data, DTensor): - if all( - isinstance(placement, 
Replicate) - for placement in p.placements): - param_tensors.append(p) - name_tensors.append(n) - else: - param_dtensors.append(p) - name_dtensors.append(n) - elif isinstance(p.data, torch.Tensor): - param_tensors.append(p) - name_tensors.append(n) - else: - raise TypeError(f"Unsupported parameter type: {type(p.data)}") - - logger.debug( - f"[Muon] {len(param_dtensors)} DTensors, {len(param_tensors)} Tensors" - ) - - if len(param_dtensors) > 0: - if not dist.is_initialized(): - raise RuntimeError( - "Parallel Muon requires torch.distributed to be initialized." - ) - - # To support different placements, we group parameters by placements - # and run parallel Muon on each group. - - placement_to_params = defaultdict(lambda: ([], [])) - # type: dict[tuple[Placement, DeviceMesh], tuple[list[str], list[DTensor]]] - - assert len(name_dtensors) == len(param_dtensors) - for n, p in zip(name_dtensors, param_dtensors): - placement_to_params[tuple([p.placements, - p.device_mesh])][0].append(n) - placement_to_params[tuple([p.placements, - p.device_mesh])][1].append(p) - - for _, (names, params) in placement_to_params.items(): - self.parallel( - names, - params, - group, - lr=lr, - weight_decay=weight_decay, - momentum=momentum, - qk_logits=qk_logits, - ) - - if len(param_tensors) > 0: - self.base( - name_tensors, - param_tensors, - group, - lr=lr, - weight_decay=weight_decay, - momentum=momentum, - qk_logits=qk_logits, - ) - - def _step_adamw_params(self, params, group): - params_with_grads = [] - grads = [] - moment1 = [] - moment2 = [] - max_exp_avg_sqs = [] - state_steps = [] - lr = group["lr"] - beta1, beta2 = group["adamw_betas"] - eps = group["adamw_eps"] - weight_decay = group["weight_decay"] - - for p in params: - g = p.grad - if g is None: - continue - state = self.state[p] - params_with_grads.append(p) - grads.append(g) - if "step" not in state: - state["step"] = (torch.zeros((), - dtype=torch.float32, - device=p.device)) - state["moment1"] = torch.zeros_like(g) - 
state["moment2"] = torch.zeros_like(g) - moment1.append(state["moment1"]) - moment2.append(state["moment2"]) - if not isinstance(state["step"], torch.Tensor): - step_tensor = torch.tensor(state["step"], - dtype=torch.float32, - device=p.device) - else: - step_tensor = state["step"] - state_steps.append(step_tensor) - - self._fused_adamw( - params_with_grads, - grads, - moment1, - moment2, - max_exp_avg_sqs, - state_steps, - amsgrad=False, - beta1=beta1, - beta2=beta2, - lr=lr, - weight_decay=weight_decay, - eps=eps, - maximize=False, - ) - - def _step_adamw(self, group): - params = group["params"] - - # group params with it's type and placement - placement_to_params: dict[tuple[Placement | type, - DeviceMesh | None]] = defaultdict(list) - for p in params: - match p: - case DTensor(): - placement_to_params[tuple([p.placements, - p.device_mesh])].append(p) - case torch.Tensor(): - placement_to_params[tuple([torch.Tensor, None])].append(p) - - for params in placement_to_params.values(): - self._step_adamw_params(params, group) - - def step(self, closure=None, qk_logits=None): - """Perform a single optimization step. - - Args: - closure (Callable, optional): A closure that reevaluates the model - and returns the loss. - qk_logits (dict[int, Tensor], optional): A dictionary mapping layer indices - to 1D tensors of shape (num_heads,), representing the maximum - QK logits across all tokens, computed as - (1 / sqrt(head_dim)) * (Q @ K^T). 
- """ - loss = None - if closure is not None: - with torch.enable_grad(): - loss = closure() - - for group in self.param_groups: - if group["use_muon"]: - self._step_muon(group, qk_logits=qk_logits) - else: - self._step_adamw(group) - - return loss diff --git a/build/torch29-cxx11-cu130-x86_64-linux/__init__.py b/build/torch29-cxx11-cu130-x86_64-linux/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..239c7a65f8293e7d0df28f05fce645af56d628c0 --- /dev/null +++ b/build/torch29-cxx11-cu130-x86_64-linux/__init__.py @@ -0,0 +1,5 @@ +from .muon import Muon + +__all__ = [ + "Muon", +] diff --git a/build/torch29-cxx11-cu130-x86_64-linux/_ops.py b/build/torch29-cxx11-cu130-x86_64-linux/_ops.py new file mode 100644 index 0000000000000000000000000000000000000000..e6f6fcf6280e969b1761926112147d3146e27b59 --- /dev/null +++ b/build/torch29-cxx11-cu130-x86_64-linux/_ops.py @@ -0,0 +1,9 @@ +import torch +from . import _optimizer_06a260a_dirty +ops = torch.ops._optimizer_06a260a_dirty + +def add_op_namespace_prefix(op_name: str): + """ + Prefix op by namespace. 
+ """ + return f"_optimizer_06a260a_dirty::{op_name}" \ No newline at end of file diff --git a/build/torch29-cxx11-cu130-x86_64-linux/_optimizer_06a260a_dirty.abi3.so b/build/torch29-cxx11-cu130-x86_64-linux/_optimizer_06a260a_dirty.abi3.so new file mode 100755 index 0000000000000000000000000000000000000000..67ccafc522c41f14eaf682f265f2bc7d3f56b114 --- /dev/null +++ b/build/torch29-cxx11-cu130-x86_64-linux/_optimizer_06a260a_dirty.abi3.so @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:ef9fba09368a2296ebad017f6576f119ebe2b9513be0d51b66b403fe942bb6d5 +size 2000456 diff --git a/build/torch29-cxx11-cu130-x86_64-linux/distributed/utils.py b/build/torch29-cxx11-cu130-x86_64-linux/distributed/utils.py new file mode 100644 index 0000000000000000000000000000000000000000..6d5843506c13d9d31603b2b4e30c1c91d0baab28 --- /dev/null +++ b/build/torch29-cxx11-cu130-x86_64-linux/distributed/utils.py @@ -0,0 +1,175 @@ +import torch +import torch.distributed as dist +from torch.distributed import ProcessGroup +from torch.distributed.device_mesh import DeviceMesh +from torch.distributed.tensor import DTensor +from torch.distributed.tensor.placement_types import (Placement, Shard, + _StridedShard) + + +def get_slices_of_dtensor( + target: DTensor | torch.Tensor, + local_rank: int, + shard_mesh: DeviceMesh, + shard_placements: tuple[Placement], +) -> tuple[slice]: + """ + Get the slice of local tensor for a given rank from a tensor. + Args: + target (DTensor | torch.Tensor): The target tensor. + rank (int): The local rank of the shard group. + shard_mesh (DeviceMesh): The shard mesh. It consists of global ranks. + shard_placements (tuple[Placement]): The shard placements. 
+ """ + + slices: list[slice] = [slice(0, dim_size) for dim_size in target.size()] + + # find the global rank of the local rank in the shard mesh + rank = sorted(shard_mesh.mesh.flatten().tolist())[local_rank] + + rank_coords = (shard_mesh.mesh == rank).nonzero() + + assert len(rank_coords) == 1 + rank_coords = tuple(rank_coords[0].tolist()) + + assert len(rank_coords) == len(shard_placements) + + # Caution: Assuming replicate-to-shard of the shard mesh goes with + # left-to-right sharding. This is ensured by the sorting logic of + # construct_shard_mesh function. + for i, (rank_coord, + placement) in enumerate(zip(rank_coords, shard_placements)): + assert isinstance(placement, Shard) + + num_ranks = shard_mesh.mesh.shape[i] + + dim = placement.dim + dim_size = (slices[dim].stop - slices[dim].start) + + if dim_size % num_ranks != 0: + raise NotImplementedError( + f"Dimension size {dim_size} is not divisible " + f"by number of ranks {num_ranks} for shard " + f"placement on dim {dim}. (shape: {target.shape})") + + shard_size = dim_size // num_ranks + + start = slices[dim].start + rank_coord * shard_size + end = start + shard_size + + assert start < end <= slices[dim].stop + + slices[dim] = slice(start, end) + + return tuple(slices) + + +_ranks_to_dist_cache: dict[tuple[int, ...], tuple[DeviceMesh, + ProcessGroup]] = dict() + + +def construct_shard_mesh( + placements: tuple[Placement], + mesh: DeviceMesh, +) -> (DeviceMesh, ProcessGroup, tuple[Placement]): + """ + Construct Shard Mesh and Placements for unsharding. + It removes Replicate placements and constructs a new Mesh and ProcessGroup. + """ + my_rank = dist.get_rank() + + assert mesh.mesh.device.type == 'cpu' + + # Copy mesh to avoid modifying the original mesh + mesh = mesh.mesh.clone() + + # 1. Sort placements. Replicate first, then Shard by dim ascending. + + # For Shard, strided shard comes after regular shard on the same dim + # to preserve left-to-right order of replicate-to-shard. 
+ # This is because that strided shard is using stride to represent + # more fine-grained sharding on the same dim. + # Please check the URL below for _StridedShard. + # https://github.com/pytorch/pytorch/blob/v2.8.0/torch/distributed/tensor/placement_types.py#L366 + + def placement_sort_key( + placement_with_index: tuple[float, Placement] + ) -> tuple[int, float, int]: # (dim, split factor, original index) + index, placement = placement_with_index + is_replicate = placement.is_replicate() + is_shard = placement.is_shard() + is_partial = placement.is_partial() + + assert is_replicate or is_shard, f"Unsupported placement type: {type(placement)}" + assert not is_partial, "Partial placement is not supported." + + if is_replicate: + return (-1.0, 0, index) + elif is_shard: + if isinstance(placement, _StridedShard): + return (placement.dim, 1 / placement.split_factor, index) + return (placement.dim, 0, index) + else: + raise TypeError(f"Unknown placement type: {type(placement)}") + + placements_with_index: list[tuple[int, + Placement]] = list(enumerate(placements)) + placements_with_index = sorted(placements_with_index, + key=placement_sort_key) + + sorted_indices, sorted_placements = zip(*placements_with_index) + + # 2. Permute mesh according to sorted placements. + sorted_mesh = mesh.permute(sorted_indices) + + # 3. 
Collect list of shard meshes by removing replicate dims + # For example, (2, 3, 4, 4) with placements [R, R, S(0), S(1)] + # shard_meshes should be list with 2 * 3 = 6 shard meshes of shape (4, 4) + num_replicates = sum(1 for p in sorted_placements if p.is_replicate()) + + # merge replicate dims + # shard_meshes became a list of shard meshes with a length of replicate degree + if num_replicates > 0: + sorted_mesh = sorted_mesh.flatten( + 0, num_replicates - 1) if num_replicates > 1 else sorted_mesh + shard_meshes = list(torch.unbind(sorted_mesh, dim=0)) + else: + shard_meshes = [sorted_mesh] + shard_placements = sorted_placements[num_replicates:] + + # assume all shard placements are different + assert len(shard_placements) == len(set(shard_placements)) + + # 4. Construct ProcessGroups + # Caution: all groups should be created in the same order in all processes, + # even though each process only needs its own group. + + # To use tensor as dict key, convert it to tuple + def tensor_to_tuple(t): + if isinstance(t, torch.Tensor): + t = t.tolist() + if isinstance(t, list): + return tuple(tensor_to_tuple(x) for x in t) + return t + + my_shard_mesh_as_tuple = None + for shard_mesh in shard_meshes: + assert isinstance(shard_mesh, torch.Tensor) + shard_mesh_as_tuple = tensor_to_tuple(shard_mesh) + + if (my_rank == shard_mesh).any().item(): + assert my_shard_mesh_as_tuple is None + my_shard_mesh_as_tuple = shard_mesh_as_tuple + + # update global cache + if shard_mesh_as_tuple not in _ranks_to_dist_cache: + shard_process_group = dist.new_group(shard_mesh.flatten().tolist()) + _ranks_to_dist_cache[shard_mesh_as_tuple] = ( + DeviceMesh(device_type="cuda", mesh=shard_mesh), + shard_process_group, + ) + + my_shard_mesh, my_shard_process_group = _ranks_to_dist_cache[ + my_shard_mesh_as_tuple] + + return my_shard_mesh, my_shard_process_group, shard_placements diff --git a/build/torch29-cxx11-cu130-x86_64-linux/matmul_transpose_triton.py 
b/build/torch29-cxx11-cu130-x86_64-linux/matmul_transpose_triton.py new file mode 100644 index 0000000000000000000000000000000000000000..4565b2c4fd506a4218340d380d6c962b16774b1d --- /dev/null +++ b/build/torch29-cxx11-cu130-x86_64-linux/matmul_transpose_triton.py @@ -0,0 +1,128 @@ +# MIT License +# +# Copyright (c) 2025 Tianyang Lin +# +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the "Software"), to deal +# in the Software without restriction, including without limitation the rights +# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +# copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included in all +# copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +# SOFTWARE. 
+ +import torch +import triton +import triton.language as tl + + +def get_autotune_config(): + return [ + triton.Config( + { + 'BLOCK_SIZE_M': blk_m, + 'BLOCK_SIZE_K': blk_k, + 'GROUP_SIZE_M': grp_sz + }, + num_stages=n_stages, + num_warps=n_warps) for blk_m in [32, 64, 128] + for blk_k in [32, 64] for grp_sz in [8] for n_stages in [3, 4, 5] + for n_warps in [4, 8] + ] + + +@triton.autotune( + configs=get_autotune_config(), + key=['M', 'K'], +) +@triton.jit +def mmt_kernel(x, y, M, K, stride_xm, stride_xk, stride_ym, stride_yn, + BLOCK_SIZE_M: tl.constexpr, BLOCK_SIZE_K: tl.constexpr, + GROUP_SIZE_M: tl.constexpr): + """ + Core kernel jit function of matmul_transpose that computes y = x @ x.T + The code is a simple adaptation from the triton `matmul` tutorial: + https://triton-lang.org/main/getting-started/tutorials/03-matrix-multiplication.html + """ + pid = tl.program_id(axis=0) + num_pid_m = tl.cdiv(M, BLOCK_SIZE_M) + num_pid_n = tl.cdiv(M, BLOCK_SIZE_M) + num_pid_in_group = GROUP_SIZE_M * num_pid_n + group_id = pid // num_pid_in_group + first_pid_m = group_id * GROUP_SIZE_M + group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M) + pid_m = first_pid_m + ((pid % num_pid_in_group) % group_size_m) + pid_n = (pid % num_pid_in_group) // group_size_m + if pid_m > pid_n: + return + + offs_xm = (pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)) % M + offs_xn = (pid_n * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)) % M + offs_k = tl.arange(0, BLOCK_SIZE_K) + # we use a & b ptrs to denote different rows of x. 
+ a_ptrs = x + (offs_xm[:, None] * stride_xm + offs_k[None, :] * stride_xk) + b_ptrs = x + (offs_xn[:, None] * stride_xm + offs_k[None, :] * stride_xk) + + accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_M), dtype=tl.float32) + + for k in range(0, tl.cdiv(K, BLOCK_SIZE_K)): + a = tl.load(a_ptrs, + mask=offs_k[None, :] < K - k * BLOCK_SIZE_K, + other=0.0) + b = tl.load(b_ptrs, + mask=offs_k[None, :] < K - k * BLOCK_SIZE_K, + other=0.0) + accumulator = tl.dot(a, tl.permute(b, (1, 0)), accumulator) + a_ptrs += BLOCK_SIZE_K * stride_xk + b_ptrs += BLOCK_SIZE_K * stride_xk + # use dtype.element_ty to accommodate different input datatypes as in cpp templates + # https://github.com/triton-lang/triton/issues/2252 + c = accumulator.to(x.dtype.element_ty) + + offs_cm = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M) + offs_cn = pid_n * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M) + c_ptrs = y + stride_ym * offs_cm[:, None] + stride_yn * offs_cn[None, :] + c_mask = (offs_cm[:, None] < M) & (offs_cn[None, :] < M) + tl.store(c_ptrs, c, mask=c_mask) + + # transpose and copy + if pid_m < pid_n: + ct_ptrs = y + stride_ym * offs_cn[:, + None] + stride_yn * offs_cm[None, :] + ct_mask = (offs_cn[:, None] < M) & (offs_cm[None, :] < M) + tl.store(ct_ptrs, tl.permute(c, (1, 0)), mask=ct_mask) + + +def matmul_transpose_assign(d_in, d_out): + assert d_in.is_cuda, "Input `d_in` must be a CUDA tensor" + assert d_out.is_cuda, "Input `d_out` must be a CUDA tensor" + assert d_in.device == d_out.device, "Inputs `d_in` and `d_out` must be on the same CUDA device" + assert d_in.dtype == d_out.dtype, "Inputs must have the same data type" + assert d_in.ndim == 2, "Input `d_in` must be a 2D tensor" + assert d_out.ndim == 2, "Input `d_out` must be a 2D tensor" + assert d_in.size(0) == d_out.size(0) == d_out.size(0), \ + "First dimension of `d_in` must match first and second dimension of `d_out`" + + d_in = d_in.contiguous() + M, K = d_in.shape + grid = lambda META: (triton.cdiv(M, 
META['BLOCK_SIZE_M']) * triton.cdiv( + M, META['BLOCK_SIZE_M']), ) + with torch.cuda.device(d_in.device.index): + mmt_kernel[grid](d_in, d_out, M, K, d_in.stride(0), d_in.stride(1), + d_out.stride(0), d_out.stride(1)) + + +def matmul_transpose(d_in): + M, _ = d_in.shape + d_out = torch.empty((M, M), device=d_in.device, dtype=d_in.dtype) + matmul_transpose_assign(d_in, d_out) + return d_out diff --git a/build/torch29-cxx11-cu130-x86_64-linux/metadata.json b/build/torch29-cxx11-cu130-x86_64-linux/metadata.json new file mode 100644 index 0000000000000000000000000000000000000000..76bafa5f33b6818aa6bb4cab04be811b87519b44 --- /dev/null +++ b/build/torch29-cxx11-cu130-x86_64-linux/metadata.json @@ -0,0 +1 @@ +{"python-depends":[]} \ No newline at end of file diff --git a/build/torch29-cxx11-cu130-x86_64-linux/muon.py b/build/torch29-cxx11-cu130-x86_64-linux/muon.py new file mode 100644 index 0000000000000000000000000000000000000000..dbf25575f185ff379789482068e4ecf55b9455a9 --- /dev/null +++ b/build/torch29-cxx11-cu130-x86_64-linux/muon.py @@ -0,0 +1,1268 @@ +import logging +import math +import types +from collections import defaultdict +from dataclasses import dataclass +from typing import Any, cast + +import torch +import torch.distributed as dist +from torch.distributed import ProcessGroup +from torch.distributed.device_mesh import DeviceMesh +from torch.distributed.tensor import DTensor, Replicate +from torch.distributed.tensor.placement_types import Placement + +from .distributed.utils import construct_shard_mesh, get_slices_of_dtensor +from .matmul_transpose_triton import matmul_transpose_assign + +logger = logging.getLogger(__name__) + +COMM_DTYPE = torch.bfloat16 +DEFAULT_CHUNK_SIZE_RATIO = 4 + + +# This code snippet is a modified version adapted from the following GitHub repositories: +# https://github.com/KellerJordan/Muon/blob/master/muon.py +# Muon's Newton–Schulz iteration causes high variance in singular values +# Idea: give each iteration its own 3 
coefficients and optimize them via gradient descent. +@torch.no_grad() +# matmul_transpose_assign from : https://github.com/nil0x9/flash-muon +def _zeropower_via_newtonschulz5(G, steps): + """ + Newton-Schulz iteration to compute the zeroth power / orthogonalization of G. We opt to use a + quintic iteration whose coefficients are selected to maximize the slope at zero. For the purpose + of minimizing steps, it turns out to be empirically effective to keep increasing the slope at + zero even beyond the point where the iteration no longer converges all the way to one everywhere + on the interval. This iteration therefore does not produce UV^T but rather something like US'V^T + where S' is diagonal with S_{ii}' ~ Uniform(0.5, 1.5), which turns out not to hurt model + performance at all relative to UV^T, where USV^T = G is the SVD. + """ + assert len(G.shape) == 2 + assert G.dtype == COMM_DTYPE + X = G # no manual typecast + + if G.size(0) > G.size(1): + X = X.T + # Ensure spectral norm is at most 1 + X = X / (X.norm() + 1e-7) + buf1 = torch.empty(X.size(0), X.size(0), dtype=X.dtype, device=X.device) + buf2 = torch.empty(X.size(0), X.size(0), dtype=X.dtype, device=X.device) + # Perform the NS iterations + for a, b, c in [ + (4.0848, -6.8946, 2.9270), + (3.9505, -6.3029, 2.6377), + (3.7418, -5.5913, 2.3037), + (2.8769, -3.1427, 1.2046), + (2.8366, -3.0525, 1.2012), + ]: + matmul_transpose_assign(X, buf1) + matmul_transpose_assign(buf1, buf2) + buf1.mul_(b).add_(buf2, alpha=c) + X = torch.addmm(X, buf1, X, alpha=1.0, beta=a) + + if G.size(0) > G.size(1): + X = X.T + return X + + +@dataclass +class _muon_state: + # TODO: use Optional + worker_rank: int + process_group: ProcessGroup + shard_mesh: DeviceMesh + shard_placements: tuple[Placement, ...] 
+ name: str + qk_clip_state: torch.Tensor | None = None + gathered_grad: torch.Tensor | None = None + scattered_u: DTensor | None = None + computed_u: torch.Tensor | None = None + gather_event: torch.cuda.Event | None = None + compute_event: torch.cuda.Event | None = None + scatter_event: torch.cuda.Event | None = None + + +def numel_for_rank( + param: DTensor, + local_rank: int, + state: _muon_state, +) -> int: + slices = get_slices_of_dtensor( + param, + local_rank, + state.shard_mesh, + state.shard_placements, + ) + + numel = 1 + for s, dim in zip(slices, param.shape): + start, stop, step = s.indices(dim) + length = max(0, (stop - start + (step - 1)) // step) + numel *= length + + return numel + + +@torch.no_grad() +def _alloc_gathered_grad(params, param_to_state, rank, compute_stream): + """ + Pre-allocate gathered_grad buffer on compute_stream + before launching all2all gather + """ + with torch.cuda.stream(compute_stream): + for p in params: + state = param_to_state[id(p)] + if rank == state.worker_rank: + state.gathered_grad = torch.empty(p.shape, + dtype=COMM_DTYPE, + device="cuda") + else: + state.gathered_grad = None + + alloc_event = torch.cuda.Event() + alloc_event.record(compute_stream) + return alloc_event + + +@torch.no_grad() +def _all2all_gather(params, param_to_state, rank, comm_stream, none_grad, + alloc_event): + """ + All2all gathers shards so each owner rank reconstructs its full gradient + """ + with torch.cuda.stream(comm_stream): + process_group = param_to_state[id(params[0])].process_group + num_ranks = dist.get_world_size(group=process_group) + + # Construct sending buffers + per_dst = [[] for _ in range(num_ranks)] + send_counts = [0] * num_ranks + + for p in params: + state = param_to_state[id(p)] + dst = state.worker_rank + assert dst < num_ranks + shard_elems = numel_for_rank(p, rank, state) + g = p.grad + g = g.to_local().to(COMM_DTYPE).contiguous() + assert g.numel() == shard_elems + per_dst[dst].append(g.view(-1)) + 
send_counts[dst] += shard_elems + + assert any( + len(v) > 0 for v in per_dst + ), "At least one destination rank must receive a sharded tensor" + # list[list[Tensor]] -> list[Tensor] + per_dst = [t for dst in per_dst for t in dst] + + send_buf = torch.cat(per_dst, dim=0) + + owned_params = [ + p for p in params if param_to_state[id(p)].worker_rank == rank + ] + + # Compute receive sizes and allocate receiving buffers + recv_counts = [0] * num_ranks + + for src in range(num_ranks): + total = 0 + for p in owned_params: + state = param_to_state[id(p)] + assert state.worker_rank == rank + total += numel_for_rank(p, src, state) + recv_counts[src] = total + + recv_total = sum(recv_counts) + recv_buf = torch.empty(recv_total, dtype=COMM_DTYPE, device="cuda") + + #All2All + logger.debug(f"send_buf size: {send_buf.numel()}, " + f"recv_buf size: {recv_buf.numel()}, " + f"recv_counts: {recv_counts}, " + f"send_counts: {send_counts}, " + f"process_group: {str(process_group)}") + dist.all_to_all_single( + recv_buf, + send_buf, + output_split_sizes=recv_counts, + input_split_sizes=send_counts, + group=process_group, + ) + + # Reconstructs gathered grad from the received buffer + # + # recv_buf (num ranks = 3) + # + # From rank 0 From rank 1 From rank 2 + # | p1_0, p2_0, p3_0 | p1_1, p2_1, p3_1 | p1_2, p2_2, p3_2 | + # + # Outer loop: + # rank 0 -> rank 1 -> rank2 + # + # Inner loop: + # p1_n -> p2_n -> p3_n + + comm_stream.wait_event(alloc_event) + + off = 0 + for src in range(num_ranks): + if recv_counts[src] == 0: + continue + + block = recv_counts[src] + inner_off = 0 + for p in owned_params: + state = param_to_state[id(p)] + assert state.worker_rank == rank + + # get the slice of the full dtensor corresponding to rank src. 
+ slices = get_slices_of_dtensor(state.gathered_grad, src, + state.shard_mesh, + state.shard_placements) + + dst = state.gathered_grad[slices] + assert dst._base is state.gathered_grad + + n = dst.numel() + assert n > 0 + + sg = recv_buf.narrow(0, off + inner_off, n) + sg = sg.reshape_as(dst) + dst.copy_(sg) + + inner_off += n + off += block + + for p in params: + state = param_to_state[id(p)] + if state.worker_rank == rank: + state.gather_event = torch.cuda.Event() + state.gather_event.record(comm_stream) + else: + state.gathered_grad = None + state.gather_event = None + if none_grad: + p.grad = None + + +@torch.no_grad() +def _compute_u(p, state, steps, rank, compute_stream): + """ + On worker_rank, compute the orthogonalized update using Newton-Schulz iteration. + """ + with torch.cuda.stream(compute_stream): + if rank == state.worker_rank: + if state.gather_event is None: + raise RuntimeError("Gather event must be set before compute.") + compute_stream.wait_event(state.gather_event) + u = _zeropower_via_newtonschulz5(state.gathered_grad, steps) + state.gathered_grad = None + state.computed_u = u + state.compute_event = torch.cuda.Event() + state.compute_event.record() + else: + state.computed_u = None + state.compute_event = None + + +@torch.no_grad() +def _alloc_scattered_u(params, param_to_state, rank, compute_stream): + """ + Pre-allocate scattered_u buffer on compute_stream + before launching all2all gather + """ + with torch.cuda.stream(compute_stream): + for p in params: + state = param_to_state[id(p)] + state.scattered_u = torch.empty_like(p.to_local(), + dtype=COMM_DTYPE) + + alloc_event = torch.cuda.Event() + alloc_event.record(compute_stream) + return alloc_event + + +def _all2all_scatter(params, param_to_state, rank, comm_stream, alloc_event): + """ + All2all scatters full gradients to all ranks + """ + with torch.cuda.stream(comm_stream): + process_group = param_to_state[id(params[0])].process_group + num_ranks = 
dist.get_world_size(group=process_group) + owned_params = [ + p for p in params if param_to_state[id(p)].worker_rank == rank + ] + + # Construct sending buffer + per_dst = [[] for _ in range(num_ranks)] + send_counts = [0] * num_ranks + + if owned_params: + for p in owned_params: + state = param_to_state[id(p)] + if state.compute_event is None: + raise RuntimeError( + "Compute event must be set before scatter.") + comm_stream.wait_event(state.compute_event) + state.gathered_grad = None + + assert state.computed_u is not None + + u_full = state.computed_u.to(COMM_DTYPE).contiguous() + + offset = 0 + for dst in range(num_ranks): + # get the slice of the full tensor corresponding to rank dst. + slices = get_slices_of_dtensor(u_full, dst, + state.shard_mesh, + state.shard_placements) + su = u_full[slices].flatten() + + n = su.numel() + assert n > 0 + + per_dst[dst].append(su) + send_counts[dst] += n + offset += n + + assert offset == u_full.numel() + + lengths = [len(v) for v in per_dst] + if all(l > 0 for l in lengths): + assert all( + l == lengths[0] for l in lengths + ), "All destination ranks must have the same number of sharded tensor" + # list[list[Tensor]] -> list[Tensor] + per_dst = [t for dst in per_dst for t in dst] + send_buf = torch.cat(per_dst, dim=0) + else: + # all_to_all requires participation from all ranks + # Even non-owner ranks must join the collective call + send_buf = torch.empty(0, dtype=COMM_DTYPE, device="cuda") + + # Compute receive sizes and allocate receiving buffers + recv_counts = [0] * num_ranks + + for src in range(num_ranks): + total = 0 + for p in params: + state = param_to_state[id(p)] + if state.worker_rank != src: + continue + total += numel_for_rank(p, rank, state) + recv_counts[src] = total + + recv_total = sum(recv_counts) + assert recv_total > 0 + recv_buf = torch.empty(recv_total, dtype=COMM_DTYPE, device="cuda") + + #All2All + dist.all_to_all_single( + recv_buf, + send_buf, + output_split_sizes=recv_counts, + 
input_split_sizes=send_counts, + group=process_group, + ) + + # Copy to pre-allocated scattered_u buffer from the received buffer + # + # recv_buf (num ranks = 3, local_rank = 0) + # + # From rank 0 From rank 1 From rank 2 + # | p1_0, p2_0, p3_0 | p4_0 | p5_0, p6_0 | + # + # Outer loop: + # rank 0 -> rank 1 -> rank2 + # + # Inner loop: + # src(0) : p1_0 -> p2_0 -> p3_0 + # src(1) : p4_0 + # src(2) : p5_0 -> p6_0 + + comm_stream.wait_event(alloc_event) + + off = 0 + for src in range(num_ranks): + block = recv_counts[src] + if block == 0: + continue + + inner_off = 0 + for p in params: + state = param_to_state[id(p)] + if state.worker_rank != src: + continue + n = numel_for_rank(p, rank, state) + assert n > 0 + + flat_local = recv_buf.narrow(0, off + inner_off, + n).view_as(p.to_local()) + state.scattered_u.copy_(flat_local) + + state.scatter_event = torch.cuda.Event() + state.scatter_event.record(comm_stream) + inner_off += n + + assert inner_off == block + off += block + + +def _update_param(p, state, lr, adjusted_lr, weight_decay, rank, + compute_stream): + """ + Update sharded parameter p with the scattered_u. + Only worker_rank frees computed_u. 
+ """ + with torch.cuda.stream(compute_stream): + if state.scatter_event is None: + raise RuntimeError("Scatter event must be set before update") + compute_stream.wait_event(state.scatter_event) + u_dtensor = DTensor.from_local( + state.scattered_u, + placements=p.placements, + device_mesh=p.device_mesh, + ) + + state.scattered_u = u_dtensor + + if rank == state.worker_rank: + # Free computed_u + state.computed_u = None + + Muon._update_p(p, state.scattered_u, lr, adjusted_lr, weight_decay) + state.scattered_u = None + u_dtensor = None + + scales_full = Muon._compute_scales( + p, + state.qk_clip_state) if state.qk_clip_state is not None else None + if scales_full is not None: + # Have to slice scales_full among dim 0 + weight_slices = get_slices_of_dtensor(p, rank, state.shard_mesh, + state.shard_placements) + ratio = p.shape[0] // scales_full.shape[0] + scales_slice = slice( + None if weight_slices[0].start is None else + weight_slices[0].start // ratio, + None if weight_slices[0].stop is None else + weight_slices[0].stop // ratio, + None, + ) + + scales_local = scales_full[scales_slice] + scales_local = DTensor.from_local( + scales_local, + placements=p.placements, + device_mesh=p.device_mesh, + ) + Muon._qk_clip(p, scales_local, state.qk_clip_state.head_dim) + + +def default_is_muon(name, x): + skip_keys = ["embed_tokens", "lm_head", "tok_embeddings", "output"] + return x.ndim >= 2 and not any(key in name for key in skip_keys) + + +def get_default_muon_param_groups(model, is_muon_func=default_is_muon): + muon_params, muon_names = [], [] + non_muon_params = [] + + for n, p in model.named_parameters(): + if not p.requires_grad: + continue + if is_muon_func(n, p): + muon_params.append(p) + muon_names.append(n) + else: + non_muon_params.append(p) + + return [ + { + "params": muon_params, + "names": muon_names, + "use_muon": True, + }, + { + "params": non_muon_params, + "use_muon": False, + }, + ] + + +def parse_qk_layer(name: str) -> tuple[str | None, int]: + """ + 
Parse a parameter name to check if it is a query/key projection layer + ('wq', 'wk', 'q_proj', 'k_proj') and return (kind, layer_index). + + Returns: + (kind, layer_idx) or (None, -1) if not matched. + + Example: + 'model.3.attn.wq.weight' -> ('wq', 3) + 'model.5.attn.wk.weight' -> ('wk', 5) + 'model.2.attn.q_proj.weight' -> ('q_proj', 2) + 'model.7.attn.k_proj.weight' -> ('k_proj', 7) + 'model.4.attn.v_proj.weight' -> (None, -1) + """ + parts = name.split('.') + if len(parts) < 3: + return None, -1 + + kind = parts[-2] + + layer_idx = -1 + for part in reversed(parts): + if part.isdigit(): + layer_idx = int(part) + break + + if kind in ('wq', 'wk', 'q_proj', 'k_proj'): + return kind, layer_idx + + return None, -1 + + +@dataclass +class QKClipInfo: + """Per-parameter dynamic info computed from config + runtime logits.""" + kind: str | None # 'wq'/'q_proj' or 'wk'/'k_proj' or None + indices: list[int] # which heads to consider for clipping + head_dim: int # from config + threshold: float # from config + logit: torch.Tensor | None + + +class Muon(torch.optim.Optimizer): + """ + Muon - MomentUm Orthogonalized by Newton-schulz + + Muon internally runs standard SGD-momentum, and then performs an orthogonalization post- + processing step, in which each 2D parameter's update is replaced with the nearest orthogonal + matrix. To efficiently orthogonalize each update, we use a Newton-Schulz iteration, which has + the advantage that it can be stably run in bfloat16 on the GPU. + + Some warnings: + - We believe this optimizer is unlikely to work well for training with small batch size. + - We believe it may not work well for finetuning pretrained models, but we haven't tested this. + + Arguments: + model: The model to be optimized by Muon. + is_muon_func: A function that takes a parameter and its name, and returns whether the parameter should be optimized by Muon. + lr: The learning rate. The updates will have spectral norm of `lr`. 
(0.02 is a good default) + momentum: The momentum used by the internal SGD. (0.95 is a good default) + nesterov: Whether to use Nesterov-style momentum in the internal SGD. (recommended) + ns_steps: The number of Newton-Schulz iterations to run. (6 is probably always enough) + weight_decay: The weight decay for Muon and AdamW. + {0, 1}-D or are detected as being the embed or lm_head will be optimized by AdamW as well. + adamw_lr: The learning rate for the internal AdamW. + adamw_betas: The betas for the internal AdamW. + adamw_eps: The epsilon for the internal AdamW. + none_grad: Whether to set p.grad to None after gathering the gradients. This can save memory. + debug: Whether to print debug information. + clip_info : Configuration for QK clipping. Expected keys: + - "q_indices" (list[int]): Indices of query heads to consider. + - "k_indices" (list[int]): Indices of key heads to consider. + - "head_dim" (int): Dimensionality of each attention head. + - "threshold" (float): Threshold value; heads whose QK logits exceed + this value will be scaled down. + Default is: + { + "q_indices": [], + "k_indices": [], + "head_dim": 128, + "threshold": 100 + } + warmup_step : How many all2all gather, compute operations are launched in advance + before the corresponding all2all scatter steps begin. + A higher warmup_step increases memory usage but can improve + performance by overlapping communication. + Parallel muon only. + chunk_size : Batch size of parameters to process in each + all2all gather/compute/scatter step. + Use shard ranks * DEFAULT_CHUNK_SIZE_RATIO when -1 is specified. + use_distributed_muon: Use distributed muon by Liu et al. (2024). + For testing purpose only. 
+ small_param_numel_threshold: Threshold for classifying parameters as small and falling back to distributed Muon + """ + + def __init__(self, + params, + lr=1e-3, + momentum=0.95, + nesterov=True, + ns_steps=5, + weight_decay=0.1, + adamw_betas=(0.9, 0.95), + adamw_eps=1e-8, + none_grad=True, + debug=False, + clip_config={ + "q_indices": [], + "k_indices": [], + "head_dim": 128, + "threshold": 100 + }, + warmup_step=5, + chunk_size=-1, + use_distributed_muon=False, + small_param_numel_threshold=65536): + defaults = dict( + lr=lr, + weight_decay=weight_decay, + momentum=momentum, + nesterov=nesterov, + ns_steps=ns_steps, + adamw_betas=adamw_betas, + adamw_eps=adamw_eps, + none_grad=none_grad, + use_muon=True, + ) + error_message = "The key 'use_muon' is not set in parameter group {idx}. Assuming all parameters in the group will use muon optimization, which may lead to unexpected behavior." + instruction_code = "\n\n please follow this code snippet \n```optimizer = get_kernel('motif-technologies/optimizer')\n\n\nparams = optimizer.muon.get_default_muon_param_groups(model)\n\noptim = optimizer.Muon(params, ...)```" + + if isinstance(params, types.GeneratorType): + raise ValueError(error_message.format(idx=0) + instruction_code) + for _idx, param_group in enumerate(params): + if param_group.get("use_muon", None) is None: + raise ValueError( + error_message.format(idx=_idx) + instruction_code) + + super().__init__(params, defaults) + + self.rank = None + + self.comm_stream = torch.cuda.Stream() + self.compute_stream = torch.cuda.Stream() + self.debug = debug + self.clip_config = clip_config + self.warmup_step = warmup_step + self.chunk_size = chunk_size + self.use_distributed_muon = use_distributed_muon + self.small_param_numel_threshold = small_param_numel_threshold + + def _calc_flops(self, G, steps): + assert len(G.shape) == 2 + M, N = G.shape + if M > N: + M, N = N, M + + return steps * ((M**3) * 2 + (M**2 * N) * 4 + M * N * 2 + M**2 * 3) + + def 
adjust_lr_for_muon(self, lr, param_shape): + A, B = param_shape[:2] + # We adjust the learning rate and weight decay based on the size of the parameter matrix + # as describted in the paper + adjusted_ratio = 0.2 * math.sqrt(max(A, B)) + adjusted_lr = lr * adjusted_ratio + return adjusted_lr + + def set_rank_once(self, rank): + if self.rank is None: + self.rank = rank + else: + assert self.rank == rank + + def get_shard_mesh(self, p): + """ + Get the shard mesh for a parameter p on the given rank. + """ + assert isinstance( + p, DTensor), "Parallel Muon only supports DTensor parameters." + + shard_mesh, shard_pg, shard_placements = construct_shard_mesh( + p.placements, p.device_mesh) + + # set rank with the local rank in the shard process group + self.set_rank_once(dist.get_rank(group=shard_pg)) + + return shard_mesh, shard_pg, shard_placements + + def init_state_and_assign_params(self, names, params, group, qk_logits): + param_to_state = {} + param_to_flops = {} + + total_flops = 0 + for p in params: + g = p.grad + if g is None: + continue + assert g.ndim == 2, "Muon only supports 2D parameters." 
+ + flops = self._calc_flops(g, group["ns_steps"]) + param_to_flops[id(p)] = flops + total_flops += flops + + if self.debug: + print(f"Total TFLOPs for Muon: {total_flops / 1e12:.2f} TFLOPs", + flush=True) + + paired = list(zip(names, params)) + + paired_sorted = sorted(paired, + key=lambda x: param_to_flops[id(x[1])], + reverse=True) + + names_sorted, params_sorted = zip(*paired_sorted) + ordered_names = list(names_sorted) + ordered_params = list(params_sorted) + + round_robin = 0 + mesh = ordered_params[0].device_mesh + placements = ordered_params[0].placements + + shard_mesh, shard_pg, shard_placements = self.get_shard_mesh( + ordered_params[0]) + shard_mesh_flattened = shard_mesh.mesh.flatten() + num_ranks = dist.get_world_size(group=shard_pg) + + for n, p in zip(ordered_names, ordered_params): + if mesh != p.device_mesh: + raise ValueError("All parameters must be on the same mesh.") + if placements != p.placements: + raise ValueError("All parameters must have same placements.") + + worker_rank = shard_mesh_flattened[round_robin].item() % num_ranks + round_robin = (round_robin + 1) % len(shard_mesh_flattened) + qk_clip_state = self.get_qk_clip_info(n, qk_logits) + + param_to_state[id(p)] = _muon_state( + worker_rank=worker_rank, + process_group=shard_pg, + shard_mesh=shard_mesh, + shard_placements=shard_placements, + name=n, + qk_clip_state=qk_clip_state, + ) + + return param_to_state, ordered_params + + def base(self, names, params, group, lr, weight_decay, momentum, + qk_logits): + # generate weight updates in distributed fashion + for n, p in zip(names, params): + g = p.grad + if g is None: + continue + if g.ndim > 2: + g = g.view(g.size(0), -1) + assert g is not None + + g = self._update_g(p, g, group, momentum) + + u = _zeropower_via_newtonschulz5(g.to(COMM_DTYPE), + steps=group["ns_steps"]) + + adjusted_lr = self.adjust_lr_for_muon(lr, p.shape) + Muon._update_p(p, u, lr, adjusted_lr, weight_decay) + + qk_clip_state = self.get_qk_clip_info(n, qk_logits) + 
+ scales_full = self._compute_scales( + p, qk_clip_state) if qk_clip_state is not None else None + if scales_full is not None: + Muon._qk_clip(p, scales_full, qk_clip_state.head_dim) + + def distributed_muon( + self, + names: list[str], + params: list[torch.nn.Parameter], + group: dict[str, Any], + lr: float, + weight_decay: float, + momentum: float, + qk_logits: list[torch.Tensor | DTensor] | None, + ): + """ Implementation of Distributed Muon by Liu et al. """ + + for n, p in zip(names, params): + g = p.grad + if g is None: + continue + if g.ndim > 2: + g = g.view(g.size(0), -1) + assert g is not None + + g = self._update_g(p, g, group, momentum) + + # Gather G + if isinstance(p.data, DTensor): + g_full = g.full_tensor() + p_full = p.data.full_tensor() + else: + g_full = g + p_full = p + + u_full = _zeropower_via_newtonschulz5(g_full.to(COMM_DTYPE), + steps=group["ns_steps"]) + + adjusted_lr = self.adjust_lr_for_muon(lr, p_full.shape) + Muon._update_p(p_full, u_full, lr, adjusted_lr, weight_decay) + + qk_clip_state = self.get_qk_clip_info(n, qk_logits) + + scales_full = self._compute_scales( + p_full, qk_clip_state) if qk_clip_state is not None else None + + if scales_full is not None: + Muon._qk_clip(p_full, scales_full, qk_clip_state.head_dim) + + if isinstance(p.data, DTensor): + ndims = len(p.device_mesh.mesh.shape) + p_replicate = DTensor.from_local( + p_full, + device_mesh=p.device_mesh, + placements=[Replicate() for _ in range(ndims)], + ) + + p_sharded = p_replicate.redistribute( + device_mesh=p.device_mesh, + placements=p.placements, + ) + + p.copy_(p_sharded) + + def _update_g(self, p, g, group, momentum): + # calc update + state = self.state[p] + buf = state.setdefault("momentum_buffer", torch.zeros_like(g)) + torch.add(g, buf, alpha=momentum, out=buf) + if group["nesterov"]: + g.add_(buf, alpha=momentum) + return g + return buf + + @staticmethod + def _update_p(p, u, lr, adjusted_lr, weight_decay): + if isinstance(p, torch.nn.Parameter): + # apply 
weight decay + p.data.mul_(1 - lr * weight_decay) + # apply update + p.data.add_(u, alpha=-adjusted_lr) + else: + p.mul_(1 - lr * weight_decay) + p.add_(u, alpha=-adjusted_lr) + + def get_qk_clip_info(self, n, qk_logits): + if self.clip_config is None: + return None + + head_dim = self.clip_config.get('head_dim') + threshold = self.clip_config.get('threshold') + kind, layer_idx = parse_qk_layer(n) + + logit, indices = None, [] + if qk_logits is not None and kind is not None: + logit = qk_logits[layer_idx] + indices_key = 'q_indices' if 'q' in kind else 'k_indices' + indices = self.clip_config.get(indices_key, []) or [] + + if isinstance(logit, DTensor): + # In TP settings, qk_logits may be DTensor + # We convert it to full tensor here for simplicity + logit = logit.full_tensor() + + return QKClipInfo( + kind=kind, + indices=indices, + head_dim=head_dim, + threshold=threshold, + logit=logit, + ) + + @staticmethod + def _compute_scales(p, qk_clip_state): + kind = qk_clip_state.kind + indices = qk_clip_state.indices + head_dim = qk_clip_state.head_dim + threshold = qk_clip_state.threshold + logit = qk_clip_state.logit + + H_global = p.shape[0] // head_dim + scales_full = torch.ones(H_global, device=p.data.device) + scaling = 0 + + for logit_idx, head_idx in enumerate(indices): + v_ele = float(logit[logit_idx]) + if v_ele > threshold: + new_scale = math.sqrt(threshold / v_ele) + if new_scale < scales_full[head_idx]: + scales_full[head_idx] = new_scale + logger.info( + f"[{kind}] Head {head_idx} exceeded threshold " + f"(value={v_ele:.4f}, threshold={threshold:.4f}) -> applying scale={new_scale:.4f}" + ) + scaling += 1 + + return scales_full if scaling > 0 else None + + @staticmethod + def _qk_clip(p, scales, head_dim): + if isinstance(p, torch.nn.Parameter): + W = p.data.view(-1, head_dim, p.data.shape[1]) + W.mul_(scales.view(-1, 1, 1)) + else: + W = p.view(-1, head_dim, p.shape[1]) + W.mul_(scales.view(-1, 1, 1)) + + def parallel(self, names, params, group, lr, 
weight_decay, momentum, + qk_logits): + """ + Perform a parallel optimization step using Muon. + """ + + for p in params: + g = p.grad + if g is None: + continue + if g.ndim > 2: + g = g.view(g.size(0), -1) + + # Update g in the local rank + g = self._update_g( + p, + g, + group, + momentum=momentum, + ) + p.grad = g + + param_to_state, ordered_params = self.init_state_and_assign_params( + names, params, group, qk_logits) + + assert self.rank is not None + + def enqueue_all2all_gather(start_idx, chunk_size): + target_params = ordered_params[start_idx:start_idx + chunk_size] + if target_params: + alloc_event = _alloc_gathered_grad(target_params, + param_to_state, self.rank, + self.compute_stream) + _all2all_gather(target_params, param_to_state, self.rank, + self.comm_stream, group["none_grad"], + alloc_event) + + def enqueue_computes(start_idx, chunk_size): + for p in ordered_params[start_idx:start_idx + chunk_size]: + state = param_to_state[id(p)] + _compute_u(p, state, group["ns_steps"], self.rank, + self.compute_stream) + + def enqueue_all2all_scatter(start_idx, chunk_size): + target_params = ordered_params[start_idx:start_idx + chunk_size] + if target_params: + alloc_event = _alloc_scattered_u(target_params, param_to_state, + self.rank, + self.compute_stream) + _all2all_scatter(target_params, param_to_state, self.rank, + self.comm_stream, alloc_event) + + def enqueue_update_param(start_idx, chunk_size): + for p in ordered_params[start_idx:start_idx + chunk_size]: + state = param_to_state[id(p)] + adjusted_lr = self.adjust_lr_for_muon(lr, p.shape) + _update_param(p, state, lr, adjusted_lr, weight_decay, + self.rank, self.compute_stream) + + if self.chunk_size == -1: + shard_ranks = dist.get_world_size(param_to_state[id( + params[0])].process_group) + chunk_size = shard_ranks * DEFAULT_CHUNK_SIZE_RATIO + elif self.chunk_size > 0: + chunk_size = self.chunk_size + else: + raise ValueError("chunk_size must be -1 or a positive integer.") + + # Wait grad update + 
self.comm_stream.wait_stream(torch.cuda.current_stream()) + + warmup_step = self.warmup_step + for i in range(0, warmup_step): + enqueue_all2all_gather(i * chunk_size, chunk_size) + enqueue_computes(i * chunk_size, chunk_size) + + for i in range(0, len(params) + chunk_size - 1, chunk_size): + enqueue_all2all_scatter(i, chunk_size) + enqueue_all2all_gather(i + warmup_step * chunk_size, chunk_size) + enqueue_update_param(i, chunk_size) + enqueue_computes(i + warmup_step * chunk_size, chunk_size) + + # Wait the last update_param to finish + torch.cuda.current_stream().wait_stream(self.compute_stream) + + @staticmethod + def _fused_adamw( + params: list[torch.Tensor], + grads: list[torch.Tensor], + exp_avgs: list[torch.Tensor], + exp_avg_sqs: list[torch.Tensor], + max_exp_avg_sqs: list[torch.Tensor], + state_steps: list[torch.Tensor], + amsgrad: bool, + beta1: float, + beta2: float, + lr: float | torch.Tensor, + weight_decay: float, + eps: float, + maximize: bool, + ) -> None: + if not params: + return + + # We only shuffle around the lr when it is a Tensor and on CUDA, otherwise, we prefer + # treating it as a scalar. 
+ lr_dict: DeviceDict | None = ({ + lr.device: lr + } if isinstance(lr, torch.Tensor) and str(lr.device) != "cpu" else + None) + grouped_tensors = torch.optim.Optimizer._group_tensors_by_device_and_dtype( + [ + params, grads, exp_avgs, exp_avg_sqs, max_exp_avg_sqs, + state_steps + ] # type: ignore[list-item] + ) + for (device, _), ( + ( + device_params_, + device_grads_, + device_exp_avgs_, + device_exp_avg_sqs_, + device_max_exp_avg_sqs, + device_state_steps_, + ), + _, + ) in grouped_tensors.items(): + device_params = cast(list[torch.Tensor], device_params_) + device_grads = cast(list[torch.Tensor], device_grads_) + device_exp_avgs = cast(list[torch.Tensor], device_exp_avgs_) + device_exp_avg_sqs = cast(list[torch.Tensor], device_exp_avg_sqs_) + device_state_steps = cast(list[torch.Tensor], device_state_steps_) + + if lr_dict is not None and device not in lr_dict: + lr_dict[device] = lr.to( + device=device, + non_blocking=True) # type: ignore[union-attr] + lr = lr_dict[device] + torch._foreach_add_(device_state_steps, 1) + func = torch._fused_adamw_ + func( + device_params, + device_grads, + device_exp_avgs, + device_exp_avg_sqs, + device_max_exp_avg_sqs, # type: ignore[arg-type] + device_state_steps, + amsgrad=amsgrad, + lr=lr, # type: ignore[arg-type] + beta1=beta1, + beta2=beta2, + weight_decay=weight_decay, + eps=eps, + maximize=maximize, + ) + + def _step_muon(self, group, qk_logits=None): + params = group["params"] + lr = group["lr"] + weight_decay = group["weight_decay"] + momentum = group["momentum"] + names = group["names"] + + param_dtensors = [] + name_dtensors = [] + + param_tensors = [] + name_tensors = [] + + param_dtensors_small = [] + name_dtensors_small = [] + + if self.use_distributed_muon: + self.distributed_muon(names=names, + params=params, + group=group, + lr=lr, + weight_decay=weight_decay, + momentum=momentum, + qk_logits=qk_logits) + return + + # For simplicity, we use distributed Muon for small parameters + # whose number of elements is 
below a threshold. + for n, p in zip(names, params): + if p is None or p.grad is None: + continue + if isinstance(p.data, DTensor): + if all( + isinstance(placement, Replicate) + for placement in p.placements): + param_tensors.append(p) + name_tensors.append(n) + elif p.data.numel() <= self.small_param_numel_threshold: + param_dtensors_small.append(p) + name_dtensors_small.append(n) + else: + param_dtensors.append(p) + name_dtensors.append(n) + elif isinstance(p.data, torch.Tensor): + param_tensors.append(p) + name_tensors.append(n) + else: + raise TypeError(f"Unsupported parameter type: {type(p.data)}") + + logger.debug( + f"[Muon] {len(param_dtensors)} DTensors, {len(param_tensors)} Tensors, " + f"{len(param_dtensors_small)} Small DTensors") + + def group_dtensors(dtensors, names): + # To support different placements, we group parameters by placements + # and run parallel Muon on each group. + + placement_to_params = defaultdict(lambda: ([], [])) + # type: dict[tuple[Placement, DeviceMesh], tuple[list[str], list[DTensor]]] + + assert len(dtensors) == len(names) + for p, n in zip(dtensors, names): + placement_to_params[tuple([p.placements, + p.device_mesh])][0].append(n) + placement_to_params[tuple([p.placements, + p.device_mesh])][1].append(p) + return placement_to_params + + if len(param_dtensors_small) > 0: + if not dist.is_initialized(): + raise RuntimeError( + "Parallel Muon requires torch.distributed to be initialized." + ) + + self.distributed_muon( + params=param_dtensors_small, + names=name_dtensors_small, + group=group, + lr=lr, + weight_decay=weight_decay, + momentum=momentum, + qk_logits=qk_logits, + ) + + if len(param_dtensors) > 0: + if not dist.is_initialized(): + raise RuntimeError( + "Parallel Muon requires torch.distributed to be initialized." 
+ ) + + dtensor_group = group_dtensors(param_dtensors, name_dtensors) + for _, (names, params) in dtensor_group.items(): + self.parallel( + names, + params, + group, + lr=lr, + weight_decay=weight_decay, + momentum=momentum, + qk_logits=qk_logits, + ) + + if len(param_tensors) > 0: + self.base( + name_tensors, + param_tensors, + group, + lr=lr, + weight_decay=weight_decay, + momentum=momentum, + qk_logits=qk_logits, + ) + + def _step_adamw_params(self, params, group): + params_with_grads = [] + grads = [] + moment1 = [] + moment2 = [] + max_exp_avg_sqs = [] + state_steps = [] + lr = group["lr"] + beta1, beta2 = group["adamw_betas"] + eps = group["adamw_eps"] + weight_decay = group["weight_decay"] + + for p in params: + g = p.grad + if g is None: + continue + state = self.state[p] + params_with_grads.append(p) + grads.append(g) + if "step" not in state: + state["step"] = (torch.zeros((), + dtype=torch.float32, + device=p.device)) + state["moment1"] = torch.zeros_like(g) + state["moment2"] = torch.zeros_like(g) + moment1.append(state["moment1"]) + moment2.append(state["moment2"]) + if not isinstance(state["step"], torch.Tensor): + step_tensor = torch.tensor(state["step"], + dtype=torch.float32, + device=p.device) + else: + step_tensor = state["step"] + state_steps.append(step_tensor) + + self._fused_adamw( + params_with_grads, + grads, + moment1, + moment2, + max_exp_avg_sqs, + state_steps, + amsgrad=False, + beta1=beta1, + beta2=beta2, + lr=lr, + weight_decay=weight_decay, + eps=eps, + maximize=False, + ) + + def _step_adamw(self, group): + params = group["params"] + + # group params with it's type and placement + placement_to_params: dict[tuple[Placement | type, + DeviceMesh | None]] = defaultdict(list) + for p in params: + match p: + case DTensor(): + placement_to_params[tuple([p.placements, + p.device_mesh])].append(p) + case torch.Tensor(): + placement_to_params[tuple([torch.Tensor, None])].append(p) + + for params in placement_to_params.values(): + 
self._step_adamw_params(params, group) + + @torch.no_grad + def step(self, closure=None, qk_logits=None): + """Perform a single optimization step. + + Args: + closure (Callable, optional): A closure that reevaluates the model + and returns the loss. + qk_logits (dict[int, Tensor], optional): A dictionary mapping layer indices + to 1D tensors of shape (num_heads,), representing the maximum + QK logits across all tokens, computed as + (1 / sqrt(head_dim)) * (Q @ K^T). + """ + loss = None + if closure is not None: + with torch.enable_grad(): + loss = closure() + + for group in self.param_groups: + if group["use_muon"]: + self._step_muon(group, qk_logits=qk_logits) + else: + self._step_adamw(group) + + return loss diff --git a/build/torch29-cxx11-cu130-x86_64-linux/optimizer/__init__.py b/build/torch29-cxx11-cu130-x86_64-linux/optimizer/__init__.py index 239c7a65f8293e7d0df28f05fce645af56d628c0..03dbc1afe1cf156661a2b1b22003cd5f599a0309 100644 --- a/build/torch29-cxx11-cu130-x86_64-linux/optimizer/__init__.py +++ b/build/torch29-cxx11-cu130-x86_64-linux/optimizer/__init__.py @@ -1,5 +1,26 @@ -from .muon import Muon +import ctypes +import sys -__all__ = [ - "Muon", -] +import importlib +from pathlib import Path +from types import ModuleType + +def _import_from_path(file_path: Path) -> ModuleType: + # We cannot use the module name as-is, after adding it to `sys.modules`, + # it would also be used for other imports. So, we make a module name that + # depends on the path for it to be unique using the hex-encoded hash of + # the path. 
+ path_hash = "{:x}".format(ctypes.c_size_t(hash(file_path.absolute())).value) + module_name = path_hash + spec = importlib.util.spec_from_file_location(module_name, file_path) + if spec is None: + raise ImportError(f"Cannot load spec for {module_name} from {file_path}") + module = importlib.util.module_from_spec(spec) + if module is None: + raise ImportError(f"Cannot load module {module_name} from spec") + sys.modules[module_name] = module + spec.loader.exec_module(module) # type: ignore + return module + + +globals().update(vars(_import_from_path(Path(__file__).parent.parent / "__init__.py"))) diff --git a/build/torch29-cxx11-cu130-x86_64-linux/optimizer/_ops.py b/build/torch29-cxx11-cu130-x86_64-linux/optimizer/_ops.py deleted file mode 100644 index 7d598206add1bca142661a3df6c510e3d9575d54..0000000000000000000000000000000000000000 --- a/build/torch29-cxx11-cu130-x86_64-linux/optimizer/_ops.py +++ /dev/null @@ -1,9 +0,0 @@ -import torch -from . import _optimizer_23d68bb_dirty -ops = torch.ops._optimizer_23d68bb_dirty - -def add_op_namespace_prefix(op_name: str): - """ - Prefix op by namespace. 
- """ - return f"_optimizer_23d68bb_dirty::{op_name}" \ No newline at end of file diff --git a/build/torch29-cxx11-cu130-x86_64-linux/optimizer/_optimizer_23d68bb_dirty.abi3.so b/build/torch29-cxx11-cu130-x86_64-linux/optimizer/_optimizer_23d68bb_dirty.abi3.so deleted file mode 100755 index 8910b678b8c1c618c797441b964171862abbb32d..0000000000000000000000000000000000000000 --- a/build/torch29-cxx11-cu130-x86_64-linux/optimizer/_optimizer_23d68bb_dirty.abi3.so +++ /dev/null @@ -1,3 +0,0 @@ -version https://git-lfs.github.com/spec/v1 -oid sha256:aa394498c52692c29094cbd2cc3da6c4c37aefaa4454c97487f8e91827fbd814 -size 1988672 diff --git a/build/torch29-cxx11-cu130-x86_64-linux/optimizer/distributed/utils.py b/build/torch29-cxx11-cu130-x86_64-linux/optimizer/distributed/utils.py deleted file mode 100644 index 0b4b58bfb329b1c015129e4c4fc99f7bfa2ab30a..0000000000000000000000000000000000000000 --- a/build/torch29-cxx11-cu130-x86_64-linux/optimizer/distributed/utils.py +++ /dev/null @@ -1,174 +0,0 @@ -import torch -import torch.distributed as dist -from torch.distributed import ProcessGroup -from torch.distributed.device_mesh import DeviceMesh -from torch.distributed.tensor import DTensor -from torch.distributed.tensor.placement_types import (Placement, Shard, - _StridedShard) - - -def get_slices_of_dtensor( - target: DTensor | torch.Tensor, - local_rank: int, - shard_mesh: DeviceMesh, - shard_placements: tuple[Placement], -) -> tuple[slice]: - """ - Get the slice of local tensor for a given rank from a tensor. - Args: - target (DTensor | torch.Tensor): The target tensor. - rank (int): The local rank of the shard group. - shard_mesh (DeviceMesh): The shard mesh. It consists of global ranks. - shard_placements (tuple[Placement]): The shard placements. 
- """ - - slices: list[slice] = [slice(0, dim_size) for dim_size in target.size()] - - # find the global rank of the local rank in the shard mesh - rank = sorted(shard_mesh.mesh.flatten().tolist())[local_rank] - - rank_coords = (shard_mesh.mesh == rank).nonzero() - - assert len(rank_coords) == 1 - rank_coords = tuple(rank_coords[0].tolist()) - - assert len(rank_coords) == len(shard_placements) - - # Caution: Assuming replicate-to-shard of the shard mesh goes with - # left-to-right sharding. This is ensured by the sorting logic of - # construct_shard_mesh function. - for i, (rank_coord, - placement) in enumerate(zip(rank_coords, shard_placements)): - assert isinstance(placement, Shard) - - num_ranks = shard_mesh.mesh.shape[i] - - dim = placement.dim - dim_size = (slices[dim].stop - slices[dim].start) - - if dim_size % num_ranks != 0: - raise NotImplementedError( - f"Dimension size {dim_size} is not divisible " - f"by number of ranks {num_ranks} for shard " - f"placement on dim {dim}.") - - shard_size = dim_size // num_ranks - - start = slices[dim].start + rank_coord * shard_size - end = start + shard_size - - assert start < end <= slices[dim].stop - - slices[dim] = slice(start, end) - - return tuple(slices) - - -_ranks_to_dist_cache: dict[tuple[int, ...], tuple[DeviceMesh, ProcessGroup]] = dict() - - -def construct_shard_mesh( - placements: tuple[Placement], - mesh: DeviceMesh, -) -> (DeviceMesh, ProcessGroup, tuple[Placement]): - """ - Construct Shard Mesh and Placements for unsharding. - It removes Replicate placements and constructs a new Mesh and ProcessGroup. - """ - my_rank = dist.get_rank() - - assert mesh.mesh.device.type == 'cpu' - - # Copy mesh to avoid modifying the original mesh - mesh = mesh.mesh.clone() - - # 1. Sort placements. Replicate first, then Shard by dim ascending. - - # For Shard, strided shard comes after regular shard on the same dim - # to preserve left-to-right order of replicate-to-shard. 
- # This is because that strided shard is using stride to represent - # more fine-grained sharding on the same dim. - # Please check the URL below for _StridedShard. - # https://github.com/pytorch/pytorch/blob/v2.8.0/torch/distributed/tensor/placement_types.py#L366 - - def placement_sort_key( - placement_with_index: tuple[float, Placement] - ) -> tuple[int, float, int]: # (dim, split factor, original index) - index, placement = placement_with_index - is_replicate = placement.is_replicate() - is_shard = placement.is_shard() - is_partial = placement.is_partial() - - assert is_replicate or is_shard, f"Unsupported placement type: {type(placement)}" - assert not is_partial, "Partial placement is not supported." - - if is_replicate: - return (-1.0, 0, index) - elif is_shard: - if isinstance(placement, _StridedShard): - return (placement.dim, 1 / placement.split_factor, index) - return (placement.dim, 0, index) - else: - raise TypeError(f"Unknown placement type: {type(placement)}") - - placements_with_index: list[tuple[int, - Placement]] = list(enumerate(placements)) - placements_with_index = sorted(placements_with_index, - key=placement_sort_key) - - sorted_indices, sorted_placements = zip(*placements_with_index) - - # 2. Permute mesh according to sorted placements. - sorted_mesh = mesh.permute(sorted_indices) - - # 3. 
Collect list of shard meshes by removing replicate dims - # For example, (2, 3, 4, 4) with placements [R, R, S(0), S(1)] - # shard_meshes should be list with 2 * 3 = 6 shard meshes of shape (4, 4) - num_replicates = sum(1 for p in sorted_placements if p.is_replicate()) - - # merge replicate dims - # shard_meshes became a list of shard meshes with a length of replicate degree - if num_replicates > 0: - sorted_mesh = sorted_mesh.flatten( - 0, num_replicates - 1) if num_replicates > 1 else sorted_mesh - shard_meshes = list(torch.unbind(sorted_mesh, dim=0)) - else: - shard_meshes = [sorted_mesh] - shard_placements = sorted_placements[num_replicates:] - - # assume all shard placements are different - assert len(shard_placements) == len(set(shard_placements)) - - # 4. Construct ProcessGroups - # Caution: all groups should be created in the same order in all processes, - # even though each process only needs its own group. - - # To use tensor as dict key, convert it to tuple - def tensor_to_tuple(t): - if isinstance(t, torch.Tensor): - t = t.tolist() - if isinstance(t, list): - return tuple(tensor_to_tuple(x) for x in t) - return t - - my_shard_mesh_as_tuple = None - for shard_mesh in shard_meshes: - assert isinstance(shard_mesh, torch.Tensor) - shard_mesh_as_tuple = tensor_to_tuple(shard_mesh) - - if (my_rank == shard_mesh).any().item(): - assert my_shard_mesh_as_tuple is None - my_shard_mesh_as_tuple = shard_mesh_as_tuple - - # update global cache - if shard_mesh_as_tuple not in _ranks_to_dist_cache: - shard_process_group = dist.new_group(shard_mesh.flatten().tolist()) - _ranks_to_dist_cache[shard_mesh_as_tuple] = ( - DeviceMesh(device_type="cuda", mesh=shard_mesh), - shard_process_group, - ) - - my_shard_mesh, my_shard_process_group = _ranks_to_dist_cache[ - my_shard_mesh_as_tuple] - - return my_shard_mesh, my_shard_process_group, shard_placements diff --git a/build/torch29-cxx11-cu130-x86_64-linux/optimizer/muon.py 
b/build/torch29-cxx11-cu130-x86_64-linux/optimizer/muon.py deleted file mode 100644 index cfbcca71741be70048bfd290c62148b2aceda631..0000000000000000000000000000000000000000 --- a/build/torch29-cxx11-cu130-x86_64-linux/optimizer/muon.py +++ /dev/null @@ -1,1240 +0,0 @@ -import logging -import math -import types -from collections import defaultdict -from dataclasses import dataclass -from typing import Any, cast - -import torch -import torch.distributed as dist -from torch.distributed import ProcessGroup -from torch.distributed.device_mesh import DeviceMesh -from torch.distributed.tensor import DTensor, Replicate -from torch.distributed.tensor.placement_types import Placement - -from .distributed.utils import construct_shard_mesh, get_slices_of_dtensor -from .matmul_transpose_triton import matmul_transpose_assign - -logger = logging.getLogger(__name__) - -COMM_DTYPE = torch.bfloat16 -DEFAULT_CHUNK_SIZE_RATIO = 4 - - -# This code snippet is a modified version adapted from the following GitHub repositories: -# https://github.com/KellerJordan/Muon/blob/master/muon.py -# Muon's Newton–Schulz iteration causes high variance in singular values -# Idea: give each iteration its own 3 coefficients and optimize them via gradient descent. -@torch.no_grad() -# matmul_transpose_assign from : https://github.com/nil0x9/flash-muon -def _zeropower_via_newtonschulz5(G, steps): - """ - Newton-Schulz iteration to compute the zeroth power / orthogonalization of G. We opt to use a - quintic iteration whose coefficients are selected to maximize the slope at zero. For the purpose - of minimizing steps, it turns out to be empirically effective to keep increasing the slope at - zero even beyond the point where the iteration no longer converges all the way to one everywhere - on the interval. 
This iteration therefore does not produce UV^T but rather something like US'V^T - where S' is diagonal with S_{ii}' ~ Uniform(0.5, 1.5), which turns out not to hurt model - performance at all relative to UV^T, where USV^T = G is the SVD. - """ - assert len(G.shape) == 2 - assert G.dtype == COMM_DTYPE - X = G # no manual typecast - - if G.size(0) > G.size(1): - X = X.T - # Ensure spectral norm is at most 1 - X = X / (X.norm() + 1e-7) - buf1 = torch.empty(X.size(0), X.size(0), dtype=X.dtype, device=X.device) - buf2 = torch.empty(X.size(0), X.size(0), dtype=X.dtype, device=X.device) - # Perform the NS iterations - for a, b, c in [ - (4.0848, -6.8946, 2.9270), - (3.9505, -6.3029, 2.6377), - (3.7418, -5.5913, 2.3037), - (2.8769, -3.1427, 1.2046), - (2.8366, -3.0525, 1.2012), - ]: - matmul_transpose_assign(X, buf1) - matmul_transpose_assign(buf1, buf2) - buf1.mul_(b).add_(buf2, alpha=c) - X = torch.addmm(X, buf1, X, alpha=1.0, beta=a) - - if G.size(0) > G.size(1): - X = X.T - return X - - -@dataclass -class _muon_state: - # TODO: use Optional - worker_rank: int - process_group: ProcessGroup - shard_mesh: DeviceMesh - shard_placements: tuple[Placement, ...] 
- name: str - qk_clip_state: torch.Tensor | None = None - gathered_grad: torch.Tensor | None = None - scattered_u: DTensor | None = None - computed_u: torch.Tensor | None = None - gather_event: torch.cuda.Event | None = None - compute_event: torch.cuda.Event | None = None - scatter_event: torch.cuda.Event | None = None - - -def numel_for_rank( - param: DTensor, - local_rank: int, - state: _muon_state, -) -> int: - slices = get_slices_of_dtensor( - param, - local_rank, - state.shard_mesh, - state.shard_placements, - ) - - numel = 1 - for s, dim in zip(slices, param.shape): - start, stop, step = s.indices(dim) - length = max(0, (stop - start + (step - 1)) // step) - numel *= length - - return numel - - -@torch.no_grad() -def _alloc_gathered_grad(params, param_to_state, rank, compute_stream): - """ - Pre-allocate gathered_grad buffer on compute_stream - before launching all2all gather - """ - with torch.cuda.stream(compute_stream): - for p in params: - state = param_to_state[id(p)] - if rank == state.worker_rank: - state.gathered_grad = torch.empty(p.shape, - dtype=COMM_DTYPE, - device="cuda") - else: - state.gathered_grad = None - - alloc_event = torch.cuda.Event() - alloc_event.record(compute_stream) - return alloc_event - - -@torch.no_grad() -def _all2all_gather(params, param_to_state, rank, comm_stream, none_grad, - alloc_event): - """ - All2all gathers shards so each owner rank reconstructs its full gradient - """ - with torch.cuda.stream(comm_stream): - process_group = param_to_state[id(params[0])].process_group - num_ranks = dist.get_world_size(group=process_group) - - # Construct sending buffers - per_dst = [[] for _ in range(num_ranks)] - send_counts = [0] * num_ranks - - for p in params: - state = param_to_state[id(p)] - dst = state.worker_rank - assert dst < num_ranks - shard_elems = numel_for_rank(p, rank, state) - g = p.grad - g = g.to_local().to(COMM_DTYPE).contiguous() - assert g.numel() == shard_elems - per_dst[dst].append(g.view(-1)) - 
send_counts[dst] += shard_elems - - assert any( - len(v) > 0 for v in per_dst - ), "At least one destination rank must receive a sharded tensor" - # list[list[Tensor]] -> list[Tensor] - per_dst = [t for dst in per_dst for t in dst] - - send_buf = torch.cat(per_dst, dim=0) - - owned_params = [ - p for p in params if param_to_state[id(p)].worker_rank == rank - ] - - # Compute receive sizes and allocate receiving buffers - recv_counts = [0] * num_ranks - - for src in range(num_ranks): - total = 0 - for p in owned_params: - state = param_to_state[id(p)] - assert state.worker_rank == rank - total += numel_for_rank(p, src, state) - recv_counts[src] = total - - recv_total = sum(recv_counts) - recv_buf = torch.empty(recv_total, dtype=COMM_DTYPE, device="cuda") - - #All2All - logger.debug(f"send_buf size: {send_buf.numel()}, " - f"recv_buf size: {recv_buf.numel()}, " - f"recv_counts: {recv_counts}, " - f"send_counts: {send_counts}, " - f"process_group: {str(process_group)}") - dist.all_to_all_single( - recv_buf, - send_buf, - output_split_sizes=recv_counts, - input_split_sizes=send_counts, - group=process_group, - ) - - # Reconstructs gathered grad from the received buffer - # - # recv_buf (num ranks = 3) - # - # From rank 0 From rank 1 From rank 2 - # | p1_0, p2_0, p3_0 | p1_1, p2_1, p3_1 | p1_2, p2_2, p3_2 | - # - # Outer loop: - # rank 0 -> rank 1 -> rank2 - # - # Inner loop: - # p1_n -> p2_n -> p3_n - - comm_stream.wait_event(alloc_event) - - off = 0 - for src in range(num_ranks): - if recv_counts[src] == 0: - continue - - block = recv_counts[src] - inner_off = 0 - for p in owned_params: - state = param_to_state[id(p)] - assert state.worker_rank == rank - - # get the slice of the full dtensor corresponding to rank src. 
- slices = get_slices_of_dtensor(state.gathered_grad, src, - state.shard_mesh, - state.shard_placements) - - dst = state.gathered_grad[slices] - assert dst._base is state.gathered_grad - - n = dst.numel() - assert n > 0 - - sg = recv_buf.narrow(0, off + inner_off, n) - sg = sg.reshape_as(dst) - dst.copy_(sg) - - inner_off += n - off += block - - for p in params: - state = param_to_state[id(p)] - if state.worker_rank == rank: - state.gather_event = torch.cuda.Event() - state.gather_event.record(comm_stream) - else: - state.gathered_grad = None - state.gather_event = None - if none_grad: - p.grad = None - - -@torch.no_grad() -def _compute_u(p, state, steps, rank, compute_stream): - """ - On worker_rank, compute the orthogonalized update using Newton-Schulz iteration. - """ - with torch.cuda.stream(compute_stream): - if rank == state.worker_rank: - if state.gather_event is None: - raise RuntimeError("Gather event must be set before compute.") - compute_stream.wait_event(state.gather_event) - u = _zeropower_via_newtonschulz5(state.gathered_grad, steps) - state.gathered_grad = None - state.computed_u = u - state.compute_event = torch.cuda.Event() - state.compute_event.record() - else: - state.computed_u = None - state.compute_event = None - - -@torch.no_grad() -def _alloc_scattered_u(params, param_to_state, rank, compute_stream): - """ - Pre-allocate scattered_u buffer on compute_stream - before launching all2all gather - """ - with torch.cuda.stream(compute_stream): - for p in params: - state = param_to_state[id(p)] - state.scattered_u = torch.empty_like(p.to_local(), - dtype=COMM_DTYPE) - - alloc_event = torch.cuda.Event() - alloc_event.record(compute_stream) - return alloc_event - - -def _all2all_scatter(params, param_to_state, rank, comm_stream, alloc_event): - """ - All2all scatters full gradients to all ranks - """ - with torch.cuda.stream(comm_stream): - process_group = param_to_state[id(params[0])].process_group - num_ranks = 
dist.get_world_size(group=process_group) - owned_params = [ - p for p in params if param_to_state[id(p)].worker_rank == rank - ] - - # Construct sending buffer - per_dst = [[] for _ in range(num_ranks)] - send_counts = [0] * num_ranks - - if owned_params: - for p in owned_params: - state = param_to_state[id(p)] - if state.compute_event is None: - raise RuntimeError( - "Compute event must be set before scatter.") - comm_stream.wait_event(state.compute_event) - state.gathered_grad = None - - assert state.computed_u is not None - - u_full = state.computed_u.to(COMM_DTYPE).contiguous() - - offset = 0 - for dst in range(num_ranks): - # get the slice of the full tensor corresponding to rank dst. - slices = get_slices_of_dtensor(u_full, dst, - state.shard_mesh, - state.shard_placements) - su = u_full[slices].flatten() - - n = su.numel() - assert n > 0 - - per_dst[dst].append(su) - send_counts[dst] += n - offset += n - - assert offset == u_full.numel() - - lengths = [len(v) for v in per_dst] - if all(l > 0 for l in lengths): - assert all( - l == lengths[0] for l in lengths - ), "All destination ranks must have the same number of sharded tensor" - # list[list[Tensor]] -> list[Tensor] - per_dst = [t for dst in per_dst for t in dst] - send_buf = torch.cat(per_dst, dim=0) - else: - # all_to_all requires participation from all ranks - # Even non-owner ranks must join the collective call - send_buf = torch.empty(0, dtype=COMM_DTYPE, device="cuda") - - # Compute receive sizes and allocate receiving buffers - recv_counts = [0] * num_ranks - - for src in range(num_ranks): - total = 0 - for p in params: - state = param_to_state[id(p)] - if state.worker_rank != src: - continue - total += numel_for_rank(p, rank, state) - recv_counts[src] = total - - recv_total = sum(recv_counts) - assert recv_total > 0 - recv_buf = torch.empty(recv_total, dtype=COMM_DTYPE, device="cuda") - - #All2All - dist.all_to_all_single( - recv_buf, - send_buf, - output_split_sizes=recv_counts, - 
input_split_sizes=send_counts, - group=process_group, - ) - - # Copy to pre-allocated scattered_u buffer from the received buffer - # - # recv_buf (num ranks = 3, local_rank = 0) - # - # From rank 0 From rank 1 From rank 2 - # | p1_0, p2_0, p3_0 | p4_0 | p5_0, p6_0 | - # - # Outer loop: - # rank 0 -> rank 1 -> rank2 - # - # Inner loop: - # src(0) : p1_0 -> p2_0 -> p3_0 - # src(1) : p4_0 - # src(2) : p5_0 -> p6_0 - - comm_stream.wait_event(alloc_event) - - off = 0 - for src in range(num_ranks): - block = recv_counts[src] - if block == 0: - continue - - inner_off = 0 - for p in params: - state = param_to_state[id(p)] - if state.worker_rank != src: - continue - n = numel_for_rank(p, rank, state) - assert n > 0 - - flat_local = recv_buf.narrow(0, off + inner_off, - n).view_as(p.to_local()) - state.scattered_u.copy_(flat_local) - - state.scatter_event = torch.cuda.Event() - state.scatter_event.record(comm_stream) - inner_off += n - - assert inner_off == block - off += block - - -def _update_param(p, state, lr, adjusted_lr, weight_decay, rank, - compute_stream): - """ - Update sharded parameter p with the scattered_u. - Only worker_rank frees computed_u. 
- """ - with torch.cuda.stream(compute_stream): - if state.scatter_event is None: - raise RuntimeError("Scatter event must be set before update") - compute_stream.wait_event(state.scatter_event) - u_dtensor = DTensor.from_local( - state.scattered_u, - placements=p.placements, - device_mesh=p.device_mesh, - ) - - state.scattered_u = u_dtensor - - if rank == state.worker_rank: - # Free computed_u - state.computed_u = None - - Muon._update_p(p, state.scattered_u, lr, adjusted_lr, weight_decay) - state.scattered_u = None - u_dtensor = None - - scales_full = Muon._compute_scales( - p, - state.qk_clip_state) if state.qk_clip_state is not None else None - if scales_full is not None: - # Have to slice scales_full among dim 0 - weight_slices = get_slices_of_dtensor(p, rank, state.shard_mesh, - state.shard_placements) - ratio = p.shape[0] // scales_full.shape[0] - scales_slice = slice( - None if weight_slices[0].start is None else - weight_slices[0].start // ratio, - None if weight_slices[0].stop is None else - weight_slices[0].stop // ratio, - None, - ) - - scales_local = scales_full[scales_slice] - scales_local = DTensor.from_local( - scales_local, - placements=p.placements, - device_mesh=p.device_mesh, - ) - Muon._qk_clip(p, scales_local, state.qk_clip_state.head_dim) - - -def default_is_muon(name, x): - skip_keys = ["embed_tokens", "lm_head", "tok_embeddings", "output"] - return x.ndim >= 2 and not any(key in name for key in skip_keys) - - -def get_default_muon_param_groups(model, is_muon_func=default_is_muon): - muon_params, muon_names = [], [] - non_muon_params = [] - - for n, p in model.named_parameters(): - if not p.requires_grad: - continue - if is_muon_func(n, p): - muon_params.append(p) - muon_names.append(n) - else: - non_muon_params.append(p) - - return [ - { - "params": muon_params, - "names": muon_names, - "use_muon": True, - }, - { - "params": non_muon_params, - "use_muon": False, - }, - ] - - -def parse_qk_layer(name: str) -> tuple[str | None, int]: - """ - 
Parse a parameter name to check if it is a query/key projection layer - ('wq', 'wk', 'q_proj', 'k_proj') and return (kind, layer_index). - - Returns: - (kind, layer_idx) or (None, -1) if not matched. - - Example: - 'model.3.attn.wq.weight' -> ('wq', 3) - 'model.5.attn.wk.weight' -> ('wk', 5) - 'model.2.attn.q_proj.weight' -> ('q_proj', 2) - 'model.7.attn.k_proj.weight' -> ('k_proj', 7) - 'model.4.attn.v_proj.weight' -> (None, -1) - """ - parts = name.split('.') - if len(parts) < 3: - return None, -1 - - kind = parts[-2] - - layer_idx = -1 - for part in reversed(parts): - if part.isdigit(): - layer_idx = int(part) - break - - if kind in ('wq', 'wk', 'q_proj', 'k_proj'): - return kind, layer_idx - - return None, -1 - - -@dataclass -class QKClipInfo: - """Per-parameter dynamic info computed from config + runtime logits.""" - kind: str | None # 'wq'/'q_proj' or 'wk'/'k_proj' or None - indices: list[int] # which heads to consider for clipping - head_dim: int # from config - threshold: float # from config - logit: torch.Tensor | None - - -class Muon(torch.optim.Optimizer): - """ - Muon - MomentUm Orthogonalized by Newton-schulz - - Muon internally runs standard SGD-momentum, and then performs an orthogonalization post- - processing step, in which each 2D parameter's update is replaced with the nearest orthogonal - matrix. To efficiently orthogonalize each update, we use a Newton-Schulz iteration, which has - the advantage that it can be stably run in bfloat16 on the GPU. - - Some warnings: - - We believe this optimizer is unlikely to work well for training with small batch size. - - We believe it may not work well for finetuning pretrained models, but we haven't tested this. - - Arguments: - model: The model to be optimized by Muon. - is_muon_func: A function that takes a parameter and its name, and returns whether the parameter should be optimized by Muon. - lr: The learning rate. The updates will have spectral norm of `lr`. 
(0.02 is a good default) - momentum: The momentum used by the internal SGD. (0.95 is a good default) - nesterov: Whether to use Nesterov-style momentum in the internal SGD. (recommended) - ns_steps: The number of Newton-Schulz iterations to run. (6 is probably always enough) - weight_decay: The weight decay for Muon and AdamW. - {0, 1}-D or are detected as being the embed or lm_head will be optimized by AdamW as well. - adamw_lr: The learning rate for the internal AdamW. - adamw_betas: The betas for the internal AdamW. - adamw_eps: The epsilon for the internal AdamW. - none_grad: Whether to set p.grad to None after gathering the gradients. This can save memory. - debug: Whether to print debug information. - clip_info : Configuration for QK clipping. Expected keys: - - "q_indices" (list[int]): Indices of query heads to consider. - - "k_indices" (list[int]): Indices of key heads to consider. - - "head_dim" (int): Dimensionality of each attention head. - - "threshold" (float): Threshold value; heads whose QK logits exceed - this value will be scaled down. - Default is: - { - "q_indices": [], - "k_indices": [], - "head_dim": 128, - "threshold": 100 - } - warmup_step : How many all2all gather, compute operations are launched in advance - before the corresponding all2all scatter steps begin. - A higher warmup_step increases memory usage but can improve - performance by overlapping communication. - Parallel muon only. - chunk_size : Batch size of parameters to process in each - all2all gather/compute/scatter step. - Use shard ranks * DEFAULT_CHUNK_SIZE_RATIO when -1 is specified. - use_distributed_muon: Use distributed muon by Liu et al. (2024). - For testing purpose only. 
- """ - - def __init__(self, - params, - lr=1e-3, - momentum=0.95, - nesterov=True, - ns_steps=5, - weight_decay=0.1, - adamw_betas=(0.9, 0.95), - adamw_eps=1e-8, - none_grad=True, - debug=False, - clip_config={ - "q_indices": [], - "k_indices": [], - "head_dim": 128, - "threshold": 100 - }, - warmup_step=5, - chunk_size=-1, - use_distributed_muon=False): - defaults = dict( - lr=lr, - weight_decay=weight_decay, - momentum=momentum, - nesterov=nesterov, - ns_steps=ns_steps, - adamw_betas=adamw_betas, - adamw_eps=adamw_eps, - none_grad=none_grad, - use_muon=True, - ) - error_message = "The key 'use_muon' is not set in parameter group {idx}. Assuming all parameters in the group will use muon optimization, which may lead to unexpected behavior." - instruction_code = "\n\n please follow this code snippet \n```optimizer = get_kernel('motif-technologies/optimizer')\n\n\nparams = optimizer.muon.get_default_muon_param_groups(model)\n\noptim = optimizer.Muon(params, ...)```" - - if isinstance(params, types.GeneratorType): - raise ValueError(error_message.format(idx=0) + instruction_code) - for _idx, param_group in enumerate(params): - if param_group.get("use_muon", None) is None: - raise ValueError( - error_message.format(idx=_idx) + instruction_code) - - super().__init__(params, defaults) - - self.rank = None - - self.comm_stream = torch.cuda.Stream() - self.compute_stream = torch.cuda.Stream() - self.debug = debug - self.clip_config = clip_config - self.warmup_step = warmup_step - self.chunk_size = chunk_size - self.use_distributed_muon = use_distributed_muon - - def _calc_flops(self, G, steps): - assert len(G.shape) == 2 - M, N = G.shape - if M > N: - M, N = N, M - - return steps * ((M**3) * 2 + (M**2 * N) * 4 + M * N * 2 + M**2 * 3) - - def adjust_lr_for_muon(self, lr, param_shape): - A, B = param_shape[:2] - # We adjust the learning rate and weight decay based on the size of the parameter matrix - # as describted in the paper - adjusted_ratio = 0.2 * math.sqrt(max(A, 
B)) - adjusted_lr = lr * adjusted_ratio - return adjusted_lr - - def set_rank_once(self, rank): - if self.rank is None: - self.rank = rank - else: - assert self.rank == rank - - def get_shard_mesh(self, p): - """ - Get the shard mesh for a parameter p on the given rank. - """ - assert isinstance( - p, DTensor), "Parallel Muon only supports DTensor parameters." - - shard_mesh, shard_pg, shard_placements = construct_shard_mesh( - p.placements, p.device_mesh) - - # set rank with the local rank in the shard process group - self.set_rank_once(dist.get_rank(group=shard_pg)) - - return shard_mesh, shard_pg, shard_placements - - def init_state_and_assign_params(self, names, params, group, qk_logits): - param_to_state = {} - param_to_flops = {} - - total_flops = 0 - for p in params: - g = p.grad - if g is None: - continue - assert g.ndim == 2, "Muon only supports 2D parameters." - - flops = self._calc_flops(g, group["ns_steps"]) - param_to_flops[id(p)] = flops - total_flops += flops - - if self.debug: - print(f"Total TFLOPs for Muon: {total_flops / 1e12:.2f} TFLOPs", - flush=True) - - paired = list(zip(names, params)) - - paired_sorted = sorted(paired, - key=lambda x: param_to_flops[id(x[1])], - reverse=True) - - names_sorted, params_sorted = zip(*paired_sorted) - ordered_names = list(names_sorted) - ordered_params = list(params_sorted) - - round_robin = 0 - mesh = ordered_params[0].device_mesh - placements = ordered_params[0].placements - - shard_mesh, shard_pg, shard_placements = self.get_shard_mesh( - ordered_params[0]) - shard_mesh_flattened = shard_mesh.mesh.flatten() - num_ranks = dist.get_world_size(group=shard_pg) - - for n, p in zip(ordered_names, ordered_params): - if mesh != p.device_mesh: - raise ValueError("All parameters must be on the same mesh.") - if placements != p.placements: - raise ValueError("All parameters must have same placements.") - - worker_rank = shard_mesh_flattened[round_robin].item() % num_ranks - round_robin = (round_robin + 1) % 
len(shard_mesh_flattened) - qk_clip_state = self.get_qk_clip_info(n, qk_logits) - - param_to_state[id(p)] = _muon_state( - worker_rank=worker_rank, - process_group=shard_pg, - shard_mesh=shard_mesh, - shard_placements=shard_placements, - name=n, - qk_clip_state=qk_clip_state, - ) - - return param_to_state, ordered_params - - def base(self, names, params, group, lr, weight_decay, momentum, - qk_logits): - # generate weight updates in distributed fashion - for n, p in zip(names, params): - g = p.grad - if g is None: - continue - if g.ndim > 2: - g = g.view(g.size(0), -1) - assert g is not None - - # calc update - state = self.state[p] - if "momentum_buffer" not in state: - state["momentum_buffer"] = torch.zeros_like(g) - buf = state["momentum_buffer"] - buf.mul_(momentum).add_(g) - if group["nesterov"]: - g = g.add(buf, alpha=momentum) - else: - g = buf - - u = _zeropower_via_newtonschulz5(g.to(COMM_DTYPE), - steps=group["ns_steps"]) - - adjusted_lr = self.adjust_lr_for_muon(lr, p.shape) - Muon._update_p(p, u, lr, adjusted_lr, weight_decay) - - qk_clip_state = self.get_qk_clip_info(n, qk_logits) - - scales_full = self._compute_scales( - p, qk_clip_state) if qk_clip_state is not None else None - if scales_full is not None: - Muon._qk_clip(p, scales_full, qk_clip_state.head_dim) - - def distributed_muon( - self, - names: list[str], - params: list[torch.nn.Parameter], - group: dict[str, Any], - lr: float, - weight_decay: float, - momentum: float, - qk_logits: list[torch.Tensor | DTensor] | None, - ): - """ Implementation of Distributed Muon by Liu et al. 
""" - if qk_logits is not None: - raise NotImplementedError("QK clipping is not supported yet") - - if isinstance(params[0], DTensor): - shard_mesh, _, shard_placements = construct_shard_mesh( - placements=params[0].placements, - mesh=params[0].device_mesh, - ) - - for n, p in zip(names, params): - g = p.grad - if g is None: - continue - if g.ndim > 2: - g = g.view(g.size(0), -1) - assert g is not None - - # calc update - state = self.state[p] - if "momentum_buffer" not in state: - state["momentum_buffer"] = torch.zeros_like(g) - buf = state["momentum_buffer"] - buf.mul_(momentum).add_(g) - if group["nesterov"]: - g = g.add(buf, alpha=momentum) - else: - g = buf - - # Gather G - if isinstance(p.data, DTensor): - g = g.full_tensor() - u = _zeropower_via_newtonschulz5(g.to(COMM_DTYPE), - steps=group["ns_steps"]) - - if isinstance(p.data, DTensor): - slices = get_slices_of_dtensor( - target=p, - local_rank=dist.get_rank(), - shard_mesh=shard_mesh, - shard_placements=shard_placements, - ) - u_shard = u[slices] - u = DTensor.from_local( - u_shard, - device_mesh=p.device_mesh, - placements=p.placements, - ) - - adjusted_lr = self.adjust_lr_for_muon(lr, p.shape) - Muon._update_p(p, u, lr, adjusted_lr, weight_decay) - - def _update_g(self, p, g, group, momentum): - # calc update - state = self.state[p] - buf = state.setdefault("momentum_buffer", torch.zeros_like(g)) - torch.add(g, buf, alpha=momentum, out=buf) - if group["nesterov"]: - g.add_(buf, alpha=momentum) - return g - return buf - - @staticmethod - def _update_p(p, u, lr, adjusted_lr, weight_decay): - # apply weight decay - p.data.mul_(1 - lr * weight_decay) - # apply update - p.data.add_(u, alpha=-adjusted_lr) - - def get_qk_clip_info(self, n, qk_logits): - if self.clip_config is None: - return None - - head_dim = self.clip_config.get('head_dim') - threshold = self.clip_config.get('threshold') - kind, layer_idx = parse_qk_layer(n) - - logit, indices = None, [] - if qk_logits is not None and kind is not None: - 
logit = qk_logits[layer_idx] - indices_key = 'q_indices' if 'q' in kind else 'k_indices' - indices = self.clip_config.get(indices_key, []) or [] - - if isinstance(logit, DTensor): - # In TP settings, qk_logits may be DTensor - # We convert it to full tensor here for simplicity - logit = logit.full_tensor() - - return QKClipInfo( - kind=kind, - indices=indices, - head_dim=head_dim, - threshold=threshold, - logit=logit, - ) - - @staticmethod - def _compute_scales(p, qk_clip_state): - kind = qk_clip_state.kind - indices = qk_clip_state.indices - head_dim = qk_clip_state.head_dim - threshold = qk_clip_state.threshold - logit = qk_clip_state.logit - - H_global = p.shape[0] // head_dim - scales_full = torch.ones(H_global, device=p.data.device) - scaling = 0 - - for logit_idx, head_idx in enumerate(indices): - v_ele = float(logit[logit_idx]) - if v_ele > threshold: - new_scale = math.sqrt(threshold / v_ele) - if new_scale < scales_full[head_idx]: - scales_full[head_idx] = new_scale - logger.info( - f"[{kind}] Head {head_idx} exceeded threshold " - f"(value={v_ele:.4f}, threshold={threshold:.4f}) -> applying scale={new_scale:.4f}" - ) - scaling += 1 - - return scales_full if scaling > 0 else None - - @staticmethod - def _qk_clip(p, scales, head_dim): - W = p.data.view(-1, head_dim, p.data.shape[1]) - W.mul_(scales.view(-1, 1, 1)) - - def parallel(self, names, params, group, lr, weight_decay, momentum, - qk_logits): - """ - Perform a parallel optimization step using Muon. 
- """ - - for p in params: - g = p.grad - if g is None: - continue - if g.ndim > 2: - g = g.view(g.size(0), -1) - - # Update g in the local rank - g = self._update_g( - p, - g, - group, - momentum=momentum, - ) - p.grad = g - - param_to_state, ordered_params = self.init_state_and_assign_params( - names, params, group, qk_logits) - - assert self.rank is not None - - def enqueue_all2all_gather(start_idx, chunk_size): - target_params = ordered_params[start_idx:start_idx + chunk_size] - if target_params: - alloc_event = _alloc_gathered_grad(target_params, - param_to_state, self.rank, - self.compute_stream) - _all2all_gather(target_params, param_to_state, self.rank, - self.comm_stream, group["none_grad"], - alloc_event) - - def enqueue_computes(start_idx, chunk_size): - for p in ordered_params[start_idx:start_idx + chunk_size]: - state = param_to_state[id(p)] - _compute_u(p, state, group["ns_steps"], self.rank, - self.compute_stream) - - def enqueue_all2all_scatter(start_idx, chunk_size): - target_params = ordered_params[start_idx:start_idx + chunk_size] - if target_params: - alloc_event = _alloc_scattered_u(target_params, param_to_state, - self.rank, - self.compute_stream) - _all2all_scatter(target_params, param_to_state, self.rank, - self.comm_stream, alloc_event) - - def enqueue_update_param(start_idx, chunk_size): - for p in ordered_params[start_idx:start_idx + chunk_size]: - state = param_to_state[id(p)] - adjusted_lr = self.adjust_lr_for_muon(lr, p.shape) - _update_param(p, state, lr, adjusted_lr, weight_decay, - self.rank, self.compute_stream) - - if self.chunk_size == -1: - shard_ranks = dist.get_world_size(param_to_state[id( - params[0])].process_group) - chunk_size = shard_ranks * DEFAULT_CHUNK_SIZE_RATIO - elif self.chunk_size > 0: - chunk_size = self.chunk_size - else: - raise ValueError("chunk_size must be -1 or a positive integer.") - - # Wait grad update - self.comm_stream.wait_stream(torch.cuda.current_stream()) - - warmup_step = self.warmup_step - for i 
in range(0, warmup_step): - enqueue_all2all_gather(i * chunk_size, chunk_size) - enqueue_computes(i * chunk_size, chunk_size) - - for i in range(0, len(params) + chunk_size - 1, chunk_size): - enqueue_all2all_scatter(i, chunk_size) - enqueue_all2all_gather(i + warmup_step * chunk_size, chunk_size) - enqueue_update_param(i, chunk_size) - enqueue_computes(i + warmup_step * chunk_size, chunk_size) - - # Wait the last update_param to finish - torch.cuda.current_stream().wait_stream(self.compute_stream) - - @staticmethod - def _fused_adamw( - params: list[torch.Tensor], - grads: list[torch.Tensor], - exp_avgs: list[torch.Tensor], - exp_avg_sqs: list[torch.Tensor], - max_exp_avg_sqs: list[torch.Tensor], - state_steps: list[torch.Tensor], - amsgrad: bool, - beta1: float, - beta2: float, - lr: float | torch.Tensor, - weight_decay: float, - eps: float, - maximize: bool, - ) -> None: - if not params: - return - - # We only shuffle around the lr when it is a Tensor and on CUDA, otherwise, we prefer - # treating it as a scalar. 
- lr_dict: DeviceDict | None = ({ - lr.device: lr - } if isinstance(lr, torch.Tensor) and str(lr.device) != "cpu" else - None) - grouped_tensors = torch.optim.Optimizer._group_tensors_by_device_and_dtype( - [ - params, grads, exp_avgs, exp_avg_sqs, max_exp_avg_sqs, - state_steps - ] # type: ignore[list-item] - ) - for (device, _), ( - ( - device_params_, - device_grads_, - device_exp_avgs_, - device_exp_avg_sqs_, - device_max_exp_avg_sqs, - device_state_steps_, - ), - _, - ) in grouped_tensors.items(): - device_params = cast(list[torch.Tensor], device_params_) - device_grads = cast(list[torch.Tensor], device_grads_) - device_exp_avgs = cast(list[torch.Tensor], device_exp_avgs_) - device_exp_avg_sqs = cast(list[torch.Tensor], device_exp_avg_sqs_) - device_state_steps = cast(list[torch.Tensor], device_state_steps_) - - if lr_dict is not None and device not in lr_dict: - lr_dict[device] = lr.to( - device=device, - non_blocking=True) # type: ignore[union-attr] - lr = lr_dict[device] - torch._foreach_add_(device_state_steps, 1) - func = torch._fused_adamw_ - func( - device_params, - device_grads, - device_exp_avgs, - device_exp_avg_sqs, - device_max_exp_avg_sqs, # type: ignore[arg-type] - device_state_steps, - amsgrad=amsgrad, - lr=lr, # type: ignore[arg-type] - beta1=beta1, - beta2=beta2, - weight_decay=weight_decay, - eps=eps, - maximize=maximize, - ) - - def _step_muon(self, group, qk_logits=None): - params = group["params"] - lr = group["lr"] - weight_decay = group["weight_decay"] - momentum = group["momentum"] - names = group["names"] - - param_dtensors = [] - param_tensors = [] - name_dtensors = [] - name_tensors = [] - - if self.use_distributed_muon: - self.distributed_muon(names=names, - params=params, - group=group, - lr=lr, - weight_decay=weight_decay, - momentum=momentum, - qk_logits=qk_logits) - return - - for n, p in zip(names, params): - if p is None or p.grad is None: - continue - if isinstance(p.data, DTensor): - if all( - isinstance(placement, 
Replicate) - for placement in p.placements): - param_tensors.append(p) - name_tensors.append(n) - else: - param_dtensors.append(p) - name_dtensors.append(n) - elif isinstance(p.data, torch.Tensor): - param_tensors.append(p) - name_tensors.append(n) - else: - raise TypeError(f"Unsupported parameter type: {type(p.data)}") - - logger.debug( - f"[Muon] {len(param_dtensors)} DTensors, {len(param_tensors)} Tensors" - ) - - if len(param_dtensors) > 0: - if not dist.is_initialized(): - raise RuntimeError( - "Parallel Muon requires torch.distributed to be initialized." - ) - - # To support different placements, we group parameters by placements - # and run parallel Muon on each group. - - placement_to_params = defaultdict(lambda: ([], [])) - # type: dict[tuple[Placement, DeviceMesh], tuple[list[str], list[DTensor]]] - - assert len(name_dtensors) == len(param_dtensors) - for n, p in zip(name_dtensors, param_dtensors): - placement_to_params[tuple([p.placements, - p.device_mesh])][0].append(n) - placement_to_params[tuple([p.placements, - p.device_mesh])][1].append(p) - - for _, (names, params) in placement_to_params.items(): - self.parallel( - names, - params, - group, - lr=lr, - weight_decay=weight_decay, - momentum=momentum, - qk_logits=qk_logits, - ) - - if len(param_tensors) > 0: - self.base( - name_tensors, - param_tensors, - group, - lr=lr, - weight_decay=weight_decay, - momentum=momentum, - qk_logits=qk_logits, - ) - - def _step_adamw_params(self, params, group): - params_with_grads = [] - grads = [] - moment1 = [] - moment2 = [] - max_exp_avg_sqs = [] - state_steps = [] - lr = group["lr"] - beta1, beta2 = group["adamw_betas"] - eps = group["adamw_eps"] - weight_decay = group["weight_decay"] - - for p in params: - g = p.grad - if g is None: - continue - state = self.state[p] - params_with_grads.append(p) - grads.append(g) - if "step" not in state: - state["step"] = (torch.zeros((), - dtype=torch.float32, - device=p.device)) - state["moment1"] = torch.zeros_like(g) - 
state["moment2"] = torch.zeros_like(g) - moment1.append(state["moment1"]) - moment2.append(state["moment2"]) - if not isinstance(state["step"], torch.Tensor): - step_tensor = torch.tensor(state["step"], - dtype=torch.float32, - device=p.device) - else: - step_tensor = state["step"] - state_steps.append(step_tensor) - - self._fused_adamw( - params_with_grads, - grads, - moment1, - moment2, - max_exp_avg_sqs, - state_steps, - amsgrad=False, - beta1=beta1, - beta2=beta2, - lr=lr, - weight_decay=weight_decay, - eps=eps, - maximize=False, - ) - - def _step_adamw(self, group): - params = group["params"] - - # group params with it's type and placement - placement_to_params: dict[tuple[Placement | type, - DeviceMesh | None]] = defaultdict(list) - for p in params: - match p: - case DTensor(): - placement_to_params[tuple([p.placements, - p.device_mesh])].append(p) - case torch.Tensor(): - placement_to_params[tuple([torch.Tensor, None])].append(p) - - for params in placement_to_params.values(): - self._step_adamw_params(params, group) - - def step(self, closure=None, qk_logits=None): - """Perform a single optimization step. - - Args: - closure (Callable, optional): A closure that reevaluates the model - and returns the loss. - qk_logits (dict[int, Tensor], optional): A dictionary mapping layer indices - to 1D tensors of shape (num_heads,), representing the maximum - QK logits across all tokens, computed as - (1 / sqrt(head_dim)) * (Q @ K^T). 
- """ - loss = None - if closure is not None: - with torch.enable_grad(): - loss = closure() - - for group in self.param_groups: - if group["use_muon"]: - self._step_muon(group, qk_logits=qk_logits) - else: - self._step_adamw(group) - - return loss diff --git a/build/torch29-cxx11-rocm63-x86_64-linux/__init__.py b/build/torch29-cxx11-rocm63-x86_64-linux/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..239c7a65f8293e7d0df28f05fce645af56d628c0 --- /dev/null +++ b/build/torch29-cxx11-rocm63-x86_64-linux/__init__.py @@ -0,0 +1,5 @@ +from .muon import Muon + +__all__ = [ + "Muon", +] diff --git a/build/torch29-cxx11-rocm63-x86_64-linux/_ops.py b/build/torch29-cxx11-rocm63-x86_64-linux/_ops.py new file mode 100644 index 0000000000000000000000000000000000000000..e6f6fcf6280e969b1761926112147d3146e27b59 --- /dev/null +++ b/build/torch29-cxx11-rocm63-x86_64-linux/_ops.py @@ -0,0 +1,9 @@ +import torch +from . import _optimizer_06a260a_dirty +ops = torch.ops._optimizer_06a260a_dirty + +def add_op_namespace_prefix(op_name: str): + """ + Prefix op by namespace. 
+ """ + return f"_optimizer_06a260a_dirty::{op_name}" \ No newline at end of file diff --git a/build/torch29-cxx11-rocm63-x86_64-linux/_optimizer_06a260a_dirty.abi3.so b/build/torch29-cxx11-rocm63-x86_64-linux/_optimizer_06a260a_dirty.abi3.so new file mode 100755 index 0000000000000000000000000000000000000000..926869eca5ee9c6a8f6899f3966ba361bc640faa --- /dev/null +++ b/build/torch29-cxx11-rocm63-x86_64-linux/_optimizer_06a260a_dirty.abi3.so @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:c1574fefc74653a663d8c4c53dda381d92c60cdc29358f15618b1b746dc4ae4e +size 1865112 diff --git a/build/torch29-cxx11-rocm63-x86_64-linux/distributed/utils.py b/build/torch29-cxx11-rocm63-x86_64-linux/distributed/utils.py new file mode 100644 index 0000000000000000000000000000000000000000..6d5843506c13d9d31603b2b4e30c1c91d0baab28 --- /dev/null +++ b/build/torch29-cxx11-rocm63-x86_64-linux/distributed/utils.py @@ -0,0 +1,175 @@ +import torch +import torch.distributed as dist +from torch.distributed import ProcessGroup +from torch.distributed.device_mesh import DeviceMesh +from torch.distributed.tensor import DTensor +from torch.distributed.tensor.placement_types import (Placement, Shard, + _StridedShard) + + +def get_slices_of_dtensor( + target: DTensor | torch.Tensor, + local_rank: int, + shard_mesh: DeviceMesh, + shard_placements: tuple[Placement], +) -> tuple[slice]: + """ + Get the slice of local tensor for a given rank from a tensor. + Args: + target (DTensor | torch.Tensor): The target tensor. + rank (int): The local rank of the shard group. + shard_mesh (DeviceMesh): The shard mesh. It consists of global ranks. + shard_placements (tuple[Placement]): The shard placements. 
+ """ + + slices: list[slice] = [slice(0, dim_size) for dim_size in target.size()] + + # find the global rank of the local rank in the shard mesh + rank = sorted(shard_mesh.mesh.flatten().tolist())[local_rank] + + rank_coords = (shard_mesh.mesh == rank).nonzero() + + assert len(rank_coords) == 1 + rank_coords = tuple(rank_coords[0].tolist()) + + assert len(rank_coords) == len(shard_placements) + + # Caution: Assuming replicate-to-shard of the shard mesh goes with + # left-to-right sharding. This is ensured by the sorting logic of + # construct_shard_mesh function. + for i, (rank_coord, + placement) in enumerate(zip(rank_coords, shard_placements)): + assert isinstance(placement, Shard) + + num_ranks = shard_mesh.mesh.shape[i] + + dim = placement.dim + dim_size = (slices[dim].stop - slices[dim].start) + + if dim_size % num_ranks != 0: + raise NotImplementedError( + f"Dimension size {dim_size} is not divisible " + f"by number of ranks {num_ranks} for shard " + f"placement on dim {dim}. (shape: {target.shape})") + + shard_size = dim_size // num_ranks + + start = slices[dim].start + rank_coord * shard_size + end = start + shard_size + + assert start < end <= slices[dim].stop + + slices[dim] = slice(start, end) + + return tuple(slices) + + +_ranks_to_dist_cache: dict[tuple[int, ...], tuple[DeviceMesh, + ProcessGroup]] = dict() + + +def construct_shard_mesh( + placements: tuple[Placement], + mesh: DeviceMesh, +) -> (DeviceMesh, ProcessGroup, tuple[Placement]): + """ + Construct Shard Mesh and Placements for unsharding. + It removes Replicate placements and constructs a new Mesh and ProcessGroup. + """ + my_rank = dist.get_rank() + + assert mesh.mesh.device.type == 'cpu' + + # Copy mesh to avoid modifying the original mesh + mesh = mesh.mesh.clone() + + # 1. Sort placements. Replicate first, then Shard by dim ascending. + + # For Shard, strided shard comes after regular shard on the same dim + # to preserve left-to-right order of replicate-to-shard. 
+ # This is because that strided shard is using stride to represent + # more fine-grained sharding on the same dim. + # Please check the URL below for _StridedShard. + # https://github.com/pytorch/pytorch/blob/v2.8.0/torch/distributed/tensor/placement_types.py#L366 + + def placement_sort_key( + placement_with_index: tuple[float, Placement] + ) -> tuple[int, float, int]: # (dim, split factor, original index) + index, placement = placement_with_index + is_replicate = placement.is_replicate() + is_shard = placement.is_shard() + is_partial = placement.is_partial() + + assert is_replicate or is_shard, f"Unsupported placement type: {type(placement)}" + assert not is_partial, "Partial placement is not supported." + + if is_replicate: + return (-1.0, 0, index) + elif is_shard: + if isinstance(placement, _StridedShard): + return (placement.dim, 1 / placement.split_factor, index) + return (placement.dim, 0, index) + else: + raise TypeError(f"Unknown placement type: {type(placement)}") + + placements_with_index: list[tuple[int, + Placement]] = list(enumerate(placements)) + placements_with_index = sorted(placements_with_index, + key=placement_sort_key) + + sorted_indices, sorted_placements = zip(*placements_with_index) + + # 2. Permute mesh according to sorted placements. + sorted_mesh = mesh.permute(sorted_indices) + + # 3. 
Collect list of shard meshes by removing replicate dims + # For example, (2, 3, 4, 4) with placements [R, R, S(0), S(1)] + # shard_meshes should be list with 2 * 3 = 6 shard meshes of shape (4, 4) + num_replicates = sum(1 for p in sorted_placements if p.is_replicate()) + + # merge replicate dims + # shard_meshes became a list of shard meshes with a length of replicate degree + if num_replicates > 0: + sorted_mesh = sorted_mesh.flatten( + 0, num_replicates - 1) if num_replicates > 1 else sorted_mesh + shard_meshes = list(torch.unbind(sorted_mesh, dim=0)) + else: + shard_meshes = [sorted_mesh] + shard_placements = sorted_placements[num_replicates:] + + # assume all shard placements are different + assert len(shard_placements) == len(set(shard_placements)) + + # 4. Construct ProcessGroups + # Caution: all groups should be created in the same order in all processes, + # even though each process only needs its own group. + + # To use tensor as dict key, convert it to tuple + def tensor_to_tuple(t): + if isinstance(t, torch.Tensor): + t = t.tolist() + if isinstance(t, list): + return tuple(tensor_to_tuple(x) for x in t) + return t + + my_shard_mesh_as_tuple = None + for shard_mesh in shard_meshes: + assert isinstance(shard_mesh, torch.Tensor) + shard_mesh_as_tuple = tensor_to_tuple(shard_mesh) + + if (my_rank == shard_mesh).any().item(): + assert my_shard_mesh_as_tuple is None + my_shard_mesh_as_tuple = shard_mesh_as_tuple + + # update global cache + if shard_mesh_as_tuple not in _ranks_to_dist_cache: + shard_process_group = dist.new_group(shard_mesh.flatten().tolist()) + _ranks_to_dist_cache[shard_mesh_as_tuple] = ( + DeviceMesh(device_type="cuda", mesh=shard_mesh), + shard_process_group, + ) + + my_shard_mesh, my_shard_process_group = _ranks_to_dist_cache[ + my_shard_mesh_as_tuple] + + return my_shard_mesh, my_shard_process_group, shard_placements diff --git a/build/torch29-cxx11-rocm63-x86_64-linux/matmul_transpose_triton.py 
b/build/torch29-cxx11-rocm63-x86_64-linux/matmul_transpose_triton.py new file mode 100644 index 0000000000000000000000000000000000000000..4565b2c4fd506a4218340d380d6c962b16774b1d --- /dev/null +++ b/build/torch29-cxx11-rocm63-x86_64-linux/matmul_transpose_triton.py @@ -0,0 +1,128 @@ +# MIT License +# +# Copyright (c) 2025 Tianyang Lin +# +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the "Software"), to deal +# in the Software without restriction, including without limitation the rights +# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +# copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included in all +# copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +# SOFTWARE. 
+ +import torch +import triton +import triton.language as tl + + +def get_autotune_config(): + return [ + triton.Config( + { + 'BLOCK_SIZE_M': blk_m, + 'BLOCK_SIZE_K': blk_k, + 'GROUP_SIZE_M': grp_sz + }, + num_stages=n_stages, + num_warps=n_warps) for blk_m in [32, 64, 128] + for blk_k in [32, 64] for grp_sz in [8] for n_stages in [3, 4, 5] + for n_warps in [4, 8] + ] + + +@triton.autotune( + configs=get_autotune_config(), + key=['M', 'K'], +) +@triton.jit +def mmt_kernel(x, y, M, K, stride_xm, stride_xk, stride_ym, stride_yn, + BLOCK_SIZE_M: tl.constexpr, BLOCK_SIZE_K: tl.constexpr, + GROUP_SIZE_M: tl.constexpr): + """ + Core kernel jit function of matmul_transpose that computes y = x @ x.T + The code is a simple adaptation from the triton `matmul` tutorial: + https://triton-lang.org/main/getting-started/tutorials/03-matrix-multiplication.html + """ + pid = tl.program_id(axis=0) + num_pid_m = tl.cdiv(M, BLOCK_SIZE_M) + num_pid_n = tl.cdiv(M, BLOCK_SIZE_M) + num_pid_in_group = GROUP_SIZE_M * num_pid_n + group_id = pid // num_pid_in_group + first_pid_m = group_id * GROUP_SIZE_M + group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M) + pid_m = first_pid_m + ((pid % num_pid_in_group) % group_size_m) + pid_n = (pid % num_pid_in_group) // group_size_m + if pid_m > pid_n: + return + + offs_xm = (pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)) % M + offs_xn = (pid_n * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)) % M + offs_k = tl.arange(0, BLOCK_SIZE_K) + # we use a & b ptrs to denote different rows of x. 
+ a_ptrs = x + (offs_xm[:, None] * stride_xm + offs_k[None, :] * stride_xk) + b_ptrs = x + (offs_xn[:, None] * stride_xm + offs_k[None, :] * stride_xk) + + accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_M), dtype=tl.float32) + + for k in range(0, tl.cdiv(K, BLOCK_SIZE_K)): + a = tl.load(a_ptrs, + mask=offs_k[None, :] < K - k * BLOCK_SIZE_K, + other=0.0) + b = tl.load(b_ptrs, + mask=offs_k[None, :] < K - k * BLOCK_SIZE_K, + other=0.0) + accumulator = tl.dot(a, tl.permute(b, (1, 0)), accumulator) + a_ptrs += BLOCK_SIZE_K * stride_xk + b_ptrs += BLOCK_SIZE_K * stride_xk + # use dtype.element_ty to accommodate different input datatypes as in cpp templates + # https://github.com/triton-lang/triton/issues/2252 + c = accumulator.to(x.dtype.element_ty) + + offs_cm = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M) + offs_cn = pid_n * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M) + c_ptrs = y + stride_ym * offs_cm[:, None] + stride_yn * offs_cn[None, :] + c_mask = (offs_cm[:, None] < M) & (offs_cn[None, :] < M) + tl.store(c_ptrs, c, mask=c_mask) + + # transpose and copy + if pid_m < pid_n: + ct_ptrs = y + stride_ym * offs_cn[:, + None] + stride_yn * offs_cm[None, :] + ct_mask = (offs_cn[:, None] < M) & (offs_cm[None, :] < M) + tl.store(ct_ptrs, tl.permute(c, (1, 0)), mask=ct_mask) + + +def matmul_transpose_assign(d_in, d_out): + assert d_in.is_cuda, "Input `d_in` must be a CUDA tensor" + assert d_out.is_cuda, "Input `d_out` must be a CUDA tensor" + assert d_in.device == d_out.device, "Inputs `d_in` and `d_out` must be on the same CUDA device" + assert d_in.dtype == d_out.dtype, "Inputs must have the same data type" + assert d_in.ndim == 2, "Input `d_in` must be a 2D tensor" + assert d_out.ndim == 2, "Input `d_out` must be a 2D tensor" + assert d_in.size(0) == d_out.size(0) == d_out.size(0), \ + "First dimension of `d_in` must match first and second dimension of `d_out`" + + d_in = d_in.contiguous() + M, K = d_in.shape + grid = lambda META: (triton.cdiv(M, 
META['BLOCK_SIZE_M']) * triton.cdiv( + M, META['BLOCK_SIZE_M']), ) + with torch.cuda.device(d_in.device.index): + mmt_kernel[grid](d_in, d_out, M, K, d_in.stride(0), d_in.stride(1), + d_out.stride(0), d_out.stride(1)) + + +def matmul_transpose(d_in): + M, _ = d_in.shape + d_out = torch.empty((M, M), device=d_in.device, dtype=d_in.dtype) + matmul_transpose_assign(d_in, d_out) + return d_out diff --git a/build/torch29-cxx11-rocm63-x86_64-linux/metadata.json b/build/torch29-cxx11-rocm63-x86_64-linux/metadata.json new file mode 100644 index 0000000000000000000000000000000000000000..76bafa5f33b6818aa6bb4cab04be811b87519b44 --- /dev/null +++ b/build/torch29-cxx11-rocm63-x86_64-linux/metadata.json @@ -0,0 +1 @@ +{"python-depends":[]} \ No newline at end of file diff --git a/build/torch29-cxx11-rocm63-x86_64-linux/muon.py b/build/torch29-cxx11-rocm63-x86_64-linux/muon.py new file mode 100644 index 0000000000000000000000000000000000000000..dbf25575f185ff379789482068e4ecf55b9455a9 --- /dev/null +++ b/build/torch29-cxx11-rocm63-x86_64-linux/muon.py @@ -0,0 +1,1268 @@ +import logging +import math +import types +from collections import defaultdict +from dataclasses import dataclass +from typing import Any, cast + +import torch +import torch.distributed as dist +from torch.distributed import ProcessGroup +from torch.distributed.device_mesh import DeviceMesh +from torch.distributed.tensor import DTensor, Replicate +from torch.distributed.tensor.placement_types import Placement + +from .distributed.utils import construct_shard_mesh, get_slices_of_dtensor +from .matmul_transpose_triton import matmul_transpose_assign + +logger = logging.getLogger(__name__) + +COMM_DTYPE = torch.bfloat16 +DEFAULT_CHUNK_SIZE_RATIO = 4 + + +# This code snippet is a modified version adapted from the following GitHub repositories: +# https://github.com/KellerJordan/Muon/blob/master/muon.py +# Muon's Newton–Schulz iteration causes high variance in singular values +# Idea: give each iteration its own 3 
coefficients and optimize them via gradient descent. +@torch.no_grad() +# matmul_transpose_assign from : https://github.com/nil0x9/flash-muon +def _zeropower_via_newtonschulz5(G, steps): + """ + Newton-Schulz iteration to compute the zeroth power / orthogonalization of G. We opt to use a + quintic iteration whose coefficients are selected to maximize the slope at zero. For the purpose + of minimizing steps, it turns out to be empirically effective to keep increasing the slope at + zero even beyond the point where the iteration no longer converges all the way to one everywhere + on the interval. This iteration therefore does not produce UV^T but rather something like US'V^T + where S' is diagonal with S_{ii}' ~ Uniform(0.5, 1.5), which turns out not to hurt model + performance at all relative to UV^T, where USV^T = G is the SVD. + """ + assert len(G.shape) == 2 + assert G.dtype == COMM_DTYPE + X = G # no manual typecast + + if G.size(0) > G.size(1): + X = X.T + # Ensure spectral norm is at most 1 + X = X / (X.norm() + 1e-7) + buf1 = torch.empty(X.size(0), X.size(0), dtype=X.dtype, device=X.device) + buf2 = torch.empty(X.size(0), X.size(0), dtype=X.dtype, device=X.device) + # Perform the NS iterations + for a, b, c in [ + (4.0848, -6.8946, 2.9270), + (3.9505, -6.3029, 2.6377), + (3.7418, -5.5913, 2.3037), + (2.8769, -3.1427, 1.2046), + (2.8366, -3.0525, 1.2012), + ]: + matmul_transpose_assign(X, buf1) + matmul_transpose_assign(buf1, buf2) + buf1.mul_(b).add_(buf2, alpha=c) + X = torch.addmm(X, buf1, X, alpha=1.0, beta=a) + + if G.size(0) > G.size(1): + X = X.T + return X + + +@dataclass +class _muon_state: + # TODO: use Optional + worker_rank: int + process_group: ProcessGroup + shard_mesh: DeviceMesh + shard_placements: tuple[Placement, ...] 
+ name: str + qk_clip_state: torch.Tensor | None = None + gathered_grad: torch.Tensor | None = None + scattered_u: DTensor | None = None + computed_u: torch.Tensor | None = None + gather_event: torch.cuda.Event | None = None + compute_event: torch.cuda.Event | None = None + scatter_event: torch.cuda.Event | None = None + + +def numel_for_rank( + param: DTensor, + local_rank: int, + state: _muon_state, +) -> int: + slices = get_slices_of_dtensor( + param, + local_rank, + state.shard_mesh, + state.shard_placements, + ) + + numel = 1 + for s, dim in zip(slices, param.shape): + start, stop, step = s.indices(dim) + length = max(0, (stop - start + (step - 1)) // step) + numel *= length + + return numel + + +@torch.no_grad() +def _alloc_gathered_grad(params, param_to_state, rank, compute_stream): + """ + Pre-allocate gathered_grad buffer on compute_stream + before launching all2all gather + """ + with torch.cuda.stream(compute_stream): + for p in params: + state = param_to_state[id(p)] + if rank == state.worker_rank: + state.gathered_grad = torch.empty(p.shape, + dtype=COMM_DTYPE, + device="cuda") + else: + state.gathered_grad = None + + alloc_event = torch.cuda.Event() + alloc_event.record(compute_stream) + return alloc_event + + +@torch.no_grad() +def _all2all_gather(params, param_to_state, rank, comm_stream, none_grad, + alloc_event): + """ + All2all gathers shards so each owner rank reconstructs its full gradient + """ + with torch.cuda.stream(comm_stream): + process_group = param_to_state[id(params[0])].process_group + num_ranks = dist.get_world_size(group=process_group) + + # Construct sending buffers + per_dst = [[] for _ in range(num_ranks)] + send_counts = [0] * num_ranks + + for p in params: + state = param_to_state[id(p)] + dst = state.worker_rank + assert dst < num_ranks + shard_elems = numel_for_rank(p, rank, state) + g = p.grad + g = g.to_local().to(COMM_DTYPE).contiguous() + assert g.numel() == shard_elems + per_dst[dst].append(g.view(-1)) + 
send_counts[dst] += shard_elems + + assert any( + len(v) > 0 for v in per_dst + ), "At least one destination rank must receive a sharded tensor" + # list[list[Tensor]] -> list[Tensor] + per_dst = [t for dst in per_dst for t in dst] + + send_buf = torch.cat(per_dst, dim=0) + + owned_params = [ + p for p in params if param_to_state[id(p)].worker_rank == rank + ] + + # Compute receive sizes and allocate receiving buffers + recv_counts = [0] * num_ranks + + for src in range(num_ranks): + total = 0 + for p in owned_params: + state = param_to_state[id(p)] + assert state.worker_rank == rank + total += numel_for_rank(p, src, state) + recv_counts[src] = total + + recv_total = sum(recv_counts) + recv_buf = torch.empty(recv_total, dtype=COMM_DTYPE, device="cuda") + + #All2All + logger.debug(f"send_buf size: {send_buf.numel()}, " + f"recv_buf size: {recv_buf.numel()}, " + f"recv_counts: {recv_counts}, " + f"send_counts: {send_counts}, " + f"process_group: {str(process_group)}") + dist.all_to_all_single( + recv_buf, + send_buf, + output_split_sizes=recv_counts, + input_split_sizes=send_counts, + group=process_group, + ) + + # Reconstructs gathered grad from the received buffer + # + # recv_buf (num ranks = 3) + # + # From rank 0 From rank 1 From rank 2 + # | p1_0, p2_0, p3_0 | p1_1, p2_1, p3_1 | p1_2, p2_2, p3_2 | + # + # Outer loop: + # rank 0 -> rank 1 -> rank2 + # + # Inner loop: + # p1_n -> p2_n -> p3_n + + comm_stream.wait_event(alloc_event) + + off = 0 + for src in range(num_ranks): + if recv_counts[src] == 0: + continue + + block = recv_counts[src] + inner_off = 0 + for p in owned_params: + state = param_to_state[id(p)] + assert state.worker_rank == rank + + # get the slice of the full dtensor corresponding to rank src. 
+ slices = get_slices_of_dtensor(state.gathered_grad, src, + state.shard_mesh, + state.shard_placements) + + dst = state.gathered_grad[slices] + assert dst._base is state.gathered_grad + + n = dst.numel() + assert n > 0 + + sg = recv_buf.narrow(0, off + inner_off, n) + sg = sg.reshape_as(dst) + dst.copy_(sg) + + inner_off += n + off += block + + for p in params: + state = param_to_state[id(p)] + if state.worker_rank == rank: + state.gather_event = torch.cuda.Event() + state.gather_event.record(comm_stream) + else: + state.gathered_grad = None + state.gather_event = None + if none_grad: + p.grad = None + + +@torch.no_grad() +def _compute_u(p, state, steps, rank, compute_stream): + """ + On worker_rank, compute the orthogonalized update using Newton-Schulz iteration. + """ + with torch.cuda.stream(compute_stream): + if rank == state.worker_rank: + if state.gather_event is None: + raise RuntimeError("Gather event must be set before compute.") + compute_stream.wait_event(state.gather_event) + u = _zeropower_via_newtonschulz5(state.gathered_grad, steps) + state.gathered_grad = None + state.computed_u = u + state.compute_event = torch.cuda.Event() + state.compute_event.record() + else: + state.computed_u = None + state.compute_event = None + + +@torch.no_grad() +def _alloc_scattered_u(params, param_to_state, rank, compute_stream): + """ + Pre-allocate scattered_u buffer on compute_stream + before launching all2all gather + """ + with torch.cuda.stream(compute_stream): + for p in params: + state = param_to_state[id(p)] + state.scattered_u = torch.empty_like(p.to_local(), + dtype=COMM_DTYPE) + + alloc_event = torch.cuda.Event() + alloc_event.record(compute_stream) + return alloc_event + + +def _all2all_scatter(params, param_to_state, rank, comm_stream, alloc_event): + """ + All2all scatters full gradients to all ranks + """ + with torch.cuda.stream(comm_stream): + process_group = param_to_state[id(params[0])].process_group + num_ranks = 
dist.get_world_size(group=process_group) + owned_params = [ + p for p in params if param_to_state[id(p)].worker_rank == rank + ] + + # Construct sending buffer + per_dst = [[] for _ in range(num_ranks)] + send_counts = [0] * num_ranks + + if owned_params: + for p in owned_params: + state = param_to_state[id(p)] + if state.compute_event is None: + raise RuntimeError( + "Compute event must be set before scatter.") + comm_stream.wait_event(state.compute_event) + state.gathered_grad = None + + assert state.computed_u is not None + + u_full = state.computed_u.to(COMM_DTYPE).contiguous() + + offset = 0 + for dst in range(num_ranks): + # get the slice of the full tensor corresponding to rank dst. + slices = get_slices_of_dtensor(u_full, dst, + state.shard_mesh, + state.shard_placements) + su = u_full[slices].flatten() + + n = su.numel() + assert n > 0 + + per_dst[dst].append(su) + send_counts[dst] += n + offset += n + + assert offset == u_full.numel() + + lengths = [len(v) for v in per_dst] + if all(l > 0 for l in lengths): + assert all( + l == lengths[0] for l in lengths + ), "All destination ranks must have the same number of sharded tensor" + # list[list[Tensor]] -> list[Tensor] + per_dst = [t for dst in per_dst for t in dst] + send_buf = torch.cat(per_dst, dim=0) + else: + # all_to_all requires participation from all ranks + # Even non-owner ranks must join the collective call + send_buf = torch.empty(0, dtype=COMM_DTYPE, device="cuda") + + # Compute receive sizes and allocate receiving buffers + recv_counts = [0] * num_ranks + + for src in range(num_ranks): + total = 0 + for p in params: + state = param_to_state[id(p)] + if state.worker_rank != src: + continue + total += numel_for_rank(p, rank, state) + recv_counts[src] = total + + recv_total = sum(recv_counts) + assert recv_total > 0 + recv_buf = torch.empty(recv_total, dtype=COMM_DTYPE, device="cuda") + + #All2All + dist.all_to_all_single( + recv_buf, + send_buf, + output_split_sizes=recv_counts, + 
input_split_sizes=send_counts, + group=process_group, + ) + + # Copy to pre-allocated scattered_u buffer from the received buffer + # + # recv_buf (num ranks = 3, local_rank = 0) + # + # From rank 0 From rank 1 From rank 2 + # | p1_0, p2_0, p3_0 | p4_0 | p5_0, p6_0 | + # + # Outer loop: + # rank 0 -> rank 1 -> rank2 + # + # Inner loop: + # src(0) : p1_0 -> p2_0 -> p3_0 + # src(1) : p4_0 + # src(2) : p5_0 -> p6_0 + + comm_stream.wait_event(alloc_event) + + off = 0 + for src in range(num_ranks): + block = recv_counts[src] + if block == 0: + continue + + inner_off = 0 + for p in params: + state = param_to_state[id(p)] + if state.worker_rank != src: + continue + n = numel_for_rank(p, rank, state) + assert n > 0 + + flat_local = recv_buf.narrow(0, off + inner_off, + n).view_as(p.to_local()) + state.scattered_u.copy_(flat_local) + + state.scatter_event = torch.cuda.Event() + state.scatter_event.record(comm_stream) + inner_off += n + + assert inner_off == block + off += block + + +def _update_param(p, state, lr, adjusted_lr, weight_decay, rank, + compute_stream): + """ + Update sharded parameter p with the scattered_u. + Only worker_rank frees computed_u. 
+ """ + with torch.cuda.stream(compute_stream): + if state.scatter_event is None: + raise RuntimeError("Scatter event must be set before update") + compute_stream.wait_event(state.scatter_event) + u_dtensor = DTensor.from_local( + state.scattered_u, + placements=p.placements, + device_mesh=p.device_mesh, + ) + + state.scattered_u = u_dtensor + + if rank == state.worker_rank: + # Free computed_u + state.computed_u = None + + Muon._update_p(p, state.scattered_u, lr, adjusted_lr, weight_decay) + state.scattered_u = None + u_dtensor = None + + scales_full = Muon._compute_scales( + p, + state.qk_clip_state) if state.qk_clip_state is not None else None + if scales_full is not None: + # Have to slice scales_full among dim 0 + weight_slices = get_slices_of_dtensor(p, rank, state.shard_mesh, + state.shard_placements) + ratio = p.shape[0] // scales_full.shape[0] + scales_slice = slice( + None if weight_slices[0].start is None else + weight_slices[0].start // ratio, + None if weight_slices[0].stop is None else + weight_slices[0].stop // ratio, + None, + ) + + scales_local = scales_full[scales_slice] + scales_local = DTensor.from_local( + scales_local, + placements=p.placements, + device_mesh=p.device_mesh, + ) + Muon._qk_clip(p, scales_local, state.qk_clip_state.head_dim) + + +def default_is_muon(name, x): + skip_keys = ["embed_tokens", "lm_head", "tok_embeddings", "output"] + return x.ndim >= 2 and not any(key in name for key in skip_keys) + + +def get_default_muon_param_groups(model, is_muon_func=default_is_muon): + muon_params, muon_names = [], [] + non_muon_params = [] + + for n, p in model.named_parameters(): + if not p.requires_grad: + continue + if is_muon_func(n, p): + muon_params.append(p) + muon_names.append(n) + else: + non_muon_params.append(p) + + return [ + { + "params": muon_params, + "names": muon_names, + "use_muon": True, + }, + { + "params": non_muon_params, + "use_muon": False, + }, + ] + + +def parse_qk_layer(name: str) -> tuple[str | None, int]: + """ + 
Parse a parameter name to check if it is a query/key projection layer + ('wq', 'wk', 'q_proj', 'k_proj') and return (kind, layer_index). + + Returns: + (kind, layer_idx) or (None, -1) if not matched. + + Example: + 'model.3.attn.wq.weight' -> ('wq', 3) + 'model.5.attn.wk.weight' -> ('wk', 5) + 'model.2.attn.q_proj.weight' -> ('q_proj', 2) + 'model.7.attn.k_proj.weight' -> ('k_proj', 7) + 'model.4.attn.v_proj.weight' -> (None, -1) + """ + parts = name.split('.') + if len(parts) < 3: + return None, -1 + + kind = parts[-2] + + layer_idx = -1 + for part in reversed(parts): + if part.isdigit(): + layer_idx = int(part) + break + + if kind in ('wq', 'wk', 'q_proj', 'k_proj'): + return kind, layer_idx + + return None, -1 + + +@dataclass +class QKClipInfo: + """Per-parameter dynamic info computed from config + runtime logits.""" + kind: str | None # 'wq'/'q_proj' or 'wk'/'k_proj' or None + indices: list[int] # which heads to consider for clipping + head_dim: int # from config + threshold: float # from config + logit: torch.Tensor | None + + +class Muon(torch.optim.Optimizer): + """ + Muon - MomentUm Orthogonalized by Newton-schulz + + Muon internally runs standard SGD-momentum, and then performs an orthogonalization post- + processing step, in which each 2D parameter's update is replaced with the nearest orthogonal + matrix. To efficiently orthogonalize each update, we use a Newton-Schulz iteration, which has + the advantage that it can be stably run in bfloat16 on the GPU. + + Some warnings: + - We believe this optimizer is unlikely to work well for training with small batch size. + - We believe it may not work well for finetuning pretrained models, but we haven't tested this. + + Arguments: + model: The model to be optimized by Muon. + is_muon_func: A function that takes a parameter and its name, and returns whether the parameter should be optimized by Muon. + lr: The learning rate. The updates will have spectral norm of `lr`. 
(0.02 is a good default) + momentum: The momentum used by the internal SGD. (0.95 is a good default) + nesterov: Whether to use Nesterov-style momentum in the internal SGD. (recommended) + ns_steps: The number of Newton-Schulz iterations to run. (6 is probably always enough) + weight_decay: The weight decay for Muon and AdamW. + {0, 1}-D or are detected as being the embed or lm_head will be optimized by AdamW as well. + adamw_lr: The learning rate for the internal AdamW. + adamw_betas: The betas for the internal AdamW. + adamw_eps: The epsilon for the internal AdamW. + none_grad: Whether to set p.grad to None after gathering the gradients. This can save memory. + debug: Whether to print debug information. + clip_info : Configuration for QK clipping. Expected keys: + - "q_indices" (list[int]): Indices of query heads to consider. + - "k_indices" (list[int]): Indices of key heads to consider. + - "head_dim" (int): Dimensionality of each attention head. + - "threshold" (float): Threshold value; heads whose QK logits exceed + this value will be scaled down. + Default is: + { + "q_indices": [], + "k_indices": [], + "head_dim": 128, + "threshold": 100 + } + warmup_step : How many all2all gather, compute operations are launched in advance + before the corresponding all2all scatter steps begin. + A higher warmup_step increases memory usage but can improve + performance by overlapping communication. + Parallel muon only. + chunk_size : Batch size of parameters to process in each + all2all gather/compute/scatter step. + Use shard ranks * DEFAULT_CHUNK_SIZE_RATIO when -1 is specified. + use_distributed_muon: Use distributed muon by Liu et al. (2024). + For testing purpose only. 
+ small_param_numel_threshold: Threshold for classifying parameters as small and falling back to distributed Muon + """ + + def __init__(self, + params, + lr=1e-3, + momentum=0.95, + nesterov=True, + ns_steps=5, + weight_decay=0.1, + adamw_betas=(0.9, 0.95), + adamw_eps=1e-8, + none_grad=True, + debug=False, + clip_config={ + "q_indices": [], + "k_indices": [], + "head_dim": 128, + "threshold": 100 + }, + warmup_step=5, + chunk_size=-1, + use_distributed_muon=False, + small_param_numel_threshold=65536): + defaults = dict( + lr=lr, + weight_decay=weight_decay, + momentum=momentum, + nesterov=nesterov, + ns_steps=ns_steps, + adamw_betas=adamw_betas, + adamw_eps=adamw_eps, + none_grad=none_grad, + use_muon=True, + ) + error_message = "The key 'use_muon' is not set in parameter group {idx}. Assuming all parameters in the group will use muon optimization, which may lead to unexpected behavior." + instruction_code = "\n\n please follow this code snippet \n```optimizer = get_kernel('motif-technologies/optimizer')\n\n\nparams = optimizer.muon.get_default_muon_param_groups(model)\n\noptim = optimizer.Muon(params, ...)```" + + if isinstance(params, types.GeneratorType): + raise ValueError(error_message.format(idx=0) + instruction_code) + for _idx, param_group in enumerate(params): + if param_group.get("use_muon", None) is None: + raise ValueError( + error_message.format(idx=_idx) + instruction_code) + + super().__init__(params, defaults) + + self.rank = None + + self.comm_stream = torch.cuda.Stream() + self.compute_stream = torch.cuda.Stream() + self.debug = debug + self.clip_config = clip_config + self.warmup_step = warmup_step + self.chunk_size = chunk_size + self.use_distributed_muon = use_distributed_muon + self.small_param_numel_threshold = small_param_numel_threshold + + def _calc_flops(self, G, steps): + assert len(G.shape) == 2 + M, N = G.shape + if M > N: + M, N = N, M + + return steps * ((M**3) * 2 + (M**2 * N) * 4 + M * N * 2 + M**2 * 3) + + def 
adjust_lr_for_muon(self, lr, param_shape): + A, B = param_shape[:2] + # We adjust the learning rate and weight decay based on the size of the parameter matrix + # as described in the paper + adjusted_ratio = 0.2 * math.sqrt(max(A, B)) + adjusted_lr = lr * adjusted_ratio + return adjusted_lr + + def set_rank_once(self, rank): + if self.rank is None: + self.rank = rank + else: + assert self.rank == rank + + def get_shard_mesh(self, p): + """ + Get the shard mesh for a parameter p on the given rank. + """ + assert isinstance( + p, DTensor), "Parallel Muon only supports DTensor parameters." + + shard_mesh, shard_pg, shard_placements = construct_shard_mesh( + p.placements, p.device_mesh) + + # set rank with the local rank in the shard process group + self.set_rank_once(dist.get_rank(group=shard_pg)) + + return shard_mesh, shard_pg, shard_placements + + def init_state_and_assign_params(self, names, params, group, qk_logits): + param_to_state = {} + param_to_flops = {} + + total_flops = 0 + for p in params: + g = p.grad + if g is None: + continue + assert g.ndim == 2, "Muon only supports 2D parameters." 
+ + flops = self._calc_flops(g, group["ns_steps"]) + param_to_flops[id(p)] = flops + total_flops += flops + + if self.debug: + print(f"Total TFLOPs for Muon: {total_flops / 1e12:.2f} TFLOPs", + flush=True) + + paired = list(zip(names, params)) + + paired_sorted = sorted(paired, + key=lambda x: param_to_flops[id(x[1])], + reverse=True) + + names_sorted, params_sorted = zip(*paired_sorted) + ordered_names = list(names_sorted) + ordered_params = list(params_sorted) + + round_robin = 0 + mesh = ordered_params[0].device_mesh + placements = ordered_params[0].placements + + shard_mesh, shard_pg, shard_placements = self.get_shard_mesh( + ordered_params[0]) + shard_mesh_flattened = shard_mesh.mesh.flatten() + num_ranks = dist.get_world_size(group=shard_pg) + + for n, p in zip(ordered_names, ordered_params): + if mesh != p.device_mesh: + raise ValueError("All parameters must be on the same mesh.") + if placements != p.placements: + raise ValueError("All parameters must have same placements.") + + worker_rank = shard_mesh_flattened[round_robin].item() % num_ranks + round_robin = (round_robin + 1) % len(shard_mesh_flattened) + qk_clip_state = self.get_qk_clip_info(n, qk_logits) + + param_to_state[id(p)] = _muon_state( + worker_rank=worker_rank, + process_group=shard_pg, + shard_mesh=shard_mesh, + shard_placements=shard_placements, + name=n, + qk_clip_state=qk_clip_state, + ) + + return param_to_state, ordered_params + + def base(self, names, params, group, lr, weight_decay, momentum, + qk_logits): + # generate weight updates in distributed fashion + for n, p in zip(names, params): + g = p.grad + if g is None: + continue + if g.ndim > 2: + g = g.view(g.size(0), -1) + assert g is not None + + g = self._update_g(p, g, group, momentum) + + u = _zeropower_via_newtonschulz5(g.to(COMM_DTYPE), + steps=group["ns_steps"]) + + adjusted_lr = self.adjust_lr_for_muon(lr, p.shape) + Muon._update_p(p, u, lr, adjusted_lr, weight_decay) + + qk_clip_state = self.get_qk_clip_info(n, qk_logits) + 
+ scales_full = self._compute_scales( + p, qk_clip_state) if qk_clip_state is not None else None + if scales_full is not None: + Muon._qk_clip(p, scales_full, qk_clip_state.head_dim) + + def distributed_muon( + self, + names: list[str], + params: list[torch.nn.Parameter], + group: dict[str, Any], + lr: float, + weight_decay: float, + momentum: float, + qk_logits: list[torch.Tensor | DTensor] | None, + ): + """ Implementation of Distributed Muon by Liu et al. """ + + for n, p in zip(names, params): + g = p.grad + if g is None: + continue + if g.ndim > 2: + g = g.view(g.size(0), -1) + assert g is not None + + g = self._update_g(p, g, group, momentum) + + # Gather G + if isinstance(p.data, DTensor): + g_full = g.full_tensor() + p_full = p.data.full_tensor() + else: + g_full = g + p_full = p + + u_full = _zeropower_via_newtonschulz5(g_full.to(COMM_DTYPE), + steps=group["ns_steps"]) + + adjusted_lr = self.adjust_lr_for_muon(lr, p_full.shape) + Muon._update_p(p_full, u_full, lr, adjusted_lr, weight_decay) + + qk_clip_state = self.get_qk_clip_info(n, qk_logits) + + scales_full = self._compute_scales( + p_full, qk_clip_state) if qk_clip_state is not None else None + + if scales_full is not None: + Muon._qk_clip(p_full, scales_full, qk_clip_state.head_dim) + + if isinstance(p.data, DTensor): + ndims = len(p.device_mesh.mesh.shape) + p_replicate = DTensor.from_local( + p_full, + device_mesh=p.device_mesh, + placements=[Replicate() for _ in range(ndims)], + ) + + p_sharded = p_replicate.redistribute( + device_mesh=p.device_mesh, + placements=p.placements, + ) + + p.copy_(p_sharded) + + def _update_g(self, p, g, group, momentum): + # calc update + state = self.state[p] + buf = state.setdefault("momentum_buffer", torch.zeros_like(g)) + torch.add(g, buf, alpha=momentum, out=buf) + if group["nesterov"]: + g.add_(buf, alpha=momentum) + return g + return buf + + @staticmethod + def _update_p(p, u, lr, adjusted_lr, weight_decay): + if isinstance(p, torch.nn.Parameter): + # apply 
weight decay + p.data.mul_(1 - lr * weight_decay) + # apply update + p.data.add_(u, alpha=-adjusted_lr) + else: + p.mul_(1 - lr * weight_decay) + p.add_(u, alpha=-adjusted_lr) + + def get_qk_clip_info(self, n, qk_logits): + if self.clip_config is None: + return None + + head_dim = self.clip_config.get('head_dim') + threshold = self.clip_config.get('threshold') + kind, layer_idx = parse_qk_layer(n) + + logit, indices = None, [] + if qk_logits is not None and kind is not None: + logit = qk_logits[layer_idx] + indices_key = 'q_indices' if 'q' in kind else 'k_indices' + indices = self.clip_config.get(indices_key, []) or [] + + if isinstance(logit, DTensor): + # In TP settings, qk_logits may be DTensor + # We convert it to full tensor here for simplicity + logit = logit.full_tensor() + + return QKClipInfo( + kind=kind, + indices=indices, + head_dim=head_dim, + threshold=threshold, + logit=logit, + ) + + @staticmethod + def _compute_scales(p, qk_clip_state): + kind = qk_clip_state.kind + indices = qk_clip_state.indices + head_dim = qk_clip_state.head_dim + threshold = qk_clip_state.threshold + logit = qk_clip_state.logit + + H_global = p.shape[0] // head_dim + scales_full = torch.ones(H_global, device=p.data.device) + scaling = 0 + + for logit_idx, head_idx in enumerate(indices): + v_ele = float(logit[logit_idx]) + if v_ele > threshold: + new_scale = math.sqrt(threshold / v_ele) + if new_scale < scales_full[head_idx]: + scales_full[head_idx] = new_scale + logger.info( + f"[{kind}] Head {head_idx} exceeded threshold " + f"(value={v_ele:.4f}, threshold={threshold:.4f}) -> applying scale={new_scale:.4f}" + ) + scaling += 1 + + return scales_full if scaling > 0 else None + + @staticmethod + def _qk_clip(p, scales, head_dim): + if isinstance(p, torch.nn.Parameter): + W = p.data.view(-1, head_dim, p.data.shape[1]) + W.mul_(scales.view(-1, 1, 1)) + else: + W = p.view(-1, head_dim, p.shape[1]) + W.mul_(scales.view(-1, 1, 1)) + + def parallel(self, names, params, group, lr, 
weight_decay, momentum, + qk_logits): + """ + Perform a parallel optimization step using Muon. + """ + + for p in params: + g = p.grad + if g is None: + continue + if g.ndim > 2: + g = g.view(g.size(0), -1) + + # Update g in the local rank + g = self._update_g( + p, + g, + group, + momentum=momentum, + ) + p.grad = g + + param_to_state, ordered_params = self.init_state_and_assign_params( + names, params, group, qk_logits) + + assert self.rank is not None + + def enqueue_all2all_gather(start_idx, chunk_size): + target_params = ordered_params[start_idx:start_idx + chunk_size] + if target_params: + alloc_event = _alloc_gathered_grad(target_params, + param_to_state, self.rank, + self.compute_stream) + _all2all_gather(target_params, param_to_state, self.rank, + self.comm_stream, group["none_grad"], + alloc_event) + + def enqueue_computes(start_idx, chunk_size): + for p in ordered_params[start_idx:start_idx + chunk_size]: + state = param_to_state[id(p)] + _compute_u(p, state, group["ns_steps"], self.rank, + self.compute_stream) + + def enqueue_all2all_scatter(start_idx, chunk_size): + target_params = ordered_params[start_idx:start_idx + chunk_size] + if target_params: + alloc_event = _alloc_scattered_u(target_params, param_to_state, + self.rank, + self.compute_stream) + _all2all_scatter(target_params, param_to_state, self.rank, + self.comm_stream, alloc_event) + + def enqueue_update_param(start_idx, chunk_size): + for p in ordered_params[start_idx:start_idx + chunk_size]: + state = param_to_state[id(p)] + adjusted_lr = self.adjust_lr_for_muon(lr, p.shape) + _update_param(p, state, lr, adjusted_lr, weight_decay, + self.rank, self.compute_stream) + + if self.chunk_size == -1: + shard_ranks = dist.get_world_size(param_to_state[id( + params[0])].process_group) + chunk_size = shard_ranks * DEFAULT_CHUNK_SIZE_RATIO + elif self.chunk_size > 0: + chunk_size = self.chunk_size + else: + raise ValueError("chunk_size must be -1 or a positive integer.") + + # Wait grad update + 
self.comm_stream.wait_stream(torch.cuda.current_stream()) + + warmup_step = self.warmup_step + for i in range(0, warmup_step): + enqueue_all2all_gather(i * chunk_size, chunk_size) + enqueue_computes(i * chunk_size, chunk_size) + + for i in range(0, len(params) + chunk_size - 1, chunk_size): + enqueue_all2all_scatter(i, chunk_size) + enqueue_all2all_gather(i + warmup_step * chunk_size, chunk_size) + enqueue_update_param(i, chunk_size) + enqueue_computes(i + warmup_step * chunk_size, chunk_size) + + # Wait the last update_param to finish + torch.cuda.current_stream().wait_stream(self.compute_stream) + + @staticmethod + def _fused_adamw( + params: list[torch.Tensor], + grads: list[torch.Tensor], + exp_avgs: list[torch.Tensor], + exp_avg_sqs: list[torch.Tensor], + max_exp_avg_sqs: list[torch.Tensor], + state_steps: list[torch.Tensor], + amsgrad: bool, + beta1: float, + beta2: float, + lr: float | torch.Tensor, + weight_decay: float, + eps: float, + maximize: bool, + ) -> None: + if not params: + return + + # We only shuffle around the lr when it is a Tensor and on CUDA, otherwise, we prefer + # treating it as a scalar. 
+ lr_dict: DeviceDict | None = ({ + lr.device: lr + } if isinstance(lr, torch.Tensor) and str(lr.device) != "cpu" else + None) + grouped_tensors = torch.optim.Optimizer._group_tensors_by_device_and_dtype( + [ + params, grads, exp_avgs, exp_avg_sqs, max_exp_avg_sqs, + state_steps + ] # type: ignore[list-item] + ) + for (device, _), ( + ( + device_params_, + device_grads_, + device_exp_avgs_, + device_exp_avg_sqs_, + device_max_exp_avg_sqs, + device_state_steps_, + ), + _, + ) in grouped_tensors.items(): + device_params = cast(list[torch.Tensor], device_params_) + device_grads = cast(list[torch.Tensor], device_grads_) + device_exp_avgs = cast(list[torch.Tensor], device_exp_avgs_) + device_exp_avg_sqs = cast(list[torch.Tensor], device_exp_avg_sqs_) + device_state_steps = cast(list[torch.Tensor], device_state_steps_) + + if lr_dict is not None and device not in lr_dict: + lr_dict[device] = lr.to( + device=device, + non_blocking=True) # type: ignore[union-attr] + lr = lr_dict[device] + torch._foreach_add_(device_state_steps, 1) + func = torch._fused_adamw_ + func( + device_params, + device_grads, + device_exp_avgs, + device_exp_avg_sqs, + device_max_exp_avg_sqs, # type: ignore[arg-type] + device_state_steps, + amsgrad=amsgrad, + lr=lr, # type: ignore[arg-type] + beta1=beta1, + beta2=beta2, + weight_decay=weight_decay, + eps=eps, + maximize=maximize, + ) + + def _step_muon(self, group, qk_logits=None): + params = group["params"] + lr = group["lr"] + weight_decay = group["weight_decay"] + momentum = group["momentum"] + names = group["names"] + + param_dtensors = [] + name_dtensors = [] + + param_tensors = [] + name_tensors = [] + + param_dtensors_small = [] + name_dtensors_small = [] + + if self.use_distributed_muon: + self.distributed_muon(names=names, + params=params, + group=group, + lr=lr, + weight_decay=weight_decay, + momentum=momentum, + qk_logits=qk_logits) + return + + # For simplicity, we use distributed Muon for small parameters + # whose number of elements is 
below a threshold. + for n, p in zip(names, params): + if p is None or p.grad is None: + continue + if isinstance(p.data, DTensor): + if all( + isinstance(placement, Replicate) + for placement in p.placements): + param_tensors.append(p) + name_tensors.append(n) + elif p.data.numel() <= self.small_param_numel_threshold: + param_dtensors_small.append(p) + name_dtensors_small.append(n) + else: + param_dtensors.append(p) + name_dtensors.append(n) + elif isinstance(p.data, torch.Tensor): + param_tensors.append(p) + name_tensors.append(n) + else: + raise TypeError(f"Unsupported parameter type: {type(p.data)}") + + logger.debug( + f"[Muon] {len(param_dtensors)} DTensors, {len(param_tensors)} Tensors, " + f"{len(param_dtensors_small)} Small DTensors") + + def group_dtensors(dtensors, names): + # To support different placements, we group parameters by placements + # and run parallel Muon on each group. + + placement_to_params = defaultdict(lambda: ([], [])) + # type: dict[tuple[Placement, DeviceMesh], tuple[list[str], list[DTensor]]] + + assert len(dtensors) == len(names) + for p, n in zip(dtensors, names): + placement_to_params[tuple([p.placements, + p.device_mesh])][0].append(n) + placement_to_params[tuple([p.placements, + p.device_mesh])][1].append(p) + return placement_to_params + + if len(param_dtensors_small) > 0: + if not dist.is_initialized(): + raise RuntimeError( + "Parallel Muon requires torch.distributed to be initialized." + ) + + self.distributed_muon( + params=param_dtensors_small, + names=name_dtensors_small, + group=group, + lr=lr, + weight_decay=weight_decay, + momentum=momentum, + qk_logits=qk_logits, + ) + + if len(param_dtensors) > 0: + if not dist.is_initialized(): + raise RuntimeError( + "Parallel Muon requires torch.distributed to be initialized." 
+ ) + + dtensor_group = group_dtensors(param_dtensors, name_dtensors) + for _, (names, params) in dtensor_group.items(): + self.parallel( + names, + params, + group, + lr=lr, + weight_decay=weight_decay, + momentum=momentum, + qk_logits=qk_logits, + ) + + if len(param_tensors) > 0: + self.base( + name_tensors, + param_tensors, + group, + lr=lr, + weight_decay=weight_decay, + momentum=momentum, + qk_logits=qk_logits, + ) + + def _step_adamw_params(self, params, group): + params_with_grads = [] + grads = [] + moment1 = [] + moment2 = [] + max_exp_avg_sqs = [] + state_steps = [] + lr = group["lr"] + beta1, beta2 = group["adamw_betas"] + eps = group["adamw_eps"] + weight_decay = group["weight_decay"] + + for p in params: + g = p.grad + if g is None: + continue + state = self.state[p] + params_with_grads.append(p) + grads.append(g) + if "step" not in state: + state["step"] = (torch.zeros((), + dtype=torch.float32, + device=p.device)) + state["moment1"] = torch.zeros_like(g) + state["moment2"] = torch.zeros_like(g) + moment1.append(state["moment1"]) + moment2.append(state["moment2"]) + if not isinstance(state["step"], torch.Tensor): + step_tensor = torch.tensor(state["step"], + dtype=torch.float32, + device=p.device) + else: + step_tensor = state["step"] + state_steps.append(step_tensor) + + self._fused_adamw( + params_with_grads, + grads, + moment1, + moment2, + max_exp_avg_sqs, + state_steps, + amsgrad=False, + beta1=beta1, + beta2=beta2, + lr=lr, + weight_decay=weight_decay, + eps=eps, + maximize=False, + ) + + def _step_adamw(self, group): + params = group["params"] + + # group params with it's type and placement + placement_to_params: dict[tuple[Placement | type, + DeviceMesh | None]] = defaultdict(list) + for p in params: + match p: + case DTensor(): + placement_to_params[tuple([p.placements, + p.device_mesh])].append(p) + case torch.Tensor(): + placement_to_params[tuple([torch.Tensor, None])].append(p) + + for params in placement_to_params.values(): + 
self._step_adamw_params(params, group) + + @torch.no_grad + def step(self, closure=None, qk_logits=None): + """Perform a single optimization step. + + Args: + closure (Callable, optional): A closure that reevaluates the model + and returns the loss. + qk_logits (dict[int, Tensor], optional): A dictionary mapping layer indices + to 1D tensors of shape (num_heads,), representing the maximum + QK logits across all tokens, computed as + (1 / sqrt(head_dim)) * (Q @ K^T). + """ + loss = None + if closure is not None: + with torch.enable_grad(): + loss = closure() + + for group in self.param_groups: + if group["use_muon"]: + self._step_muon(group, qk_logits=qk_logits) + else: + self._step_adamw(group) + + return loss diff --git a/build/torch29-cxx11-rocm63-x86_64-linux/optimizer/__init__.py b/build/torch29-cxx11-rocm63-x86_64-linux/optimizer/__init__.py index 239c7a65f8293e7d0df28f05fce645af56d628c0..03dbc1afe1cf156661a2b1b22003cd5f599a0309 100644 --- a/build/torch29-cxx11-rocm63-x86_64-linux/optimizer/__init__.py +++ b/build/torch29-cxx11-rocm63-x86_64-linux/optimizer/__init__.py @@ -1,5 +1,26 @@ -from .muon import Muon +import ctypes +import sys -__all__ = [ - "Muon", -] +import importlib +from pathlib import Path +from types import ModuleType + +def _import_from_path(file_path: Path) -> ModuleType: + # We cannot use the module name as-is, after adding it to `sys.modules`, + # it would also be used for other imports. So, we make a module name that + # depends on the path for it to be unique using the hex-encoded hash of + # the path. 
+ path_hash = "{:x}".format(ctypes.c_size_t(hash(file_path.absolute())).value) + module_name = path_hash + spec = importlib.util.spec_from_file_location(module_name, file_path) + if spec is None: + raise ImportError(f"Cannot load spec for {module_name} from {file_path}") + module = importlib.util.module_from_spec(spec) + if module is None: + raise ImportError(f"Cannot load module {module_name} from spec") + sys.modules[module_name] = module + spec.loader.exec_module(module) # type: ignore + return module + + +globals().update(vars(_import_from_path(Path(__file__).parent.parent / "__init__.py"))) diff --git a/build/torch29-cxx11-rocm63-x86_64-linux/optimizer/_ops.py b/build/torch29-cxx11-rocm63-x86_64-linux/optimizer/_ops.py deleted file mode 100644 index 7d598206add1bca142661a3df6c510e3d9575d54..0000000000000000000000000000000000000000 --- a/build/torch29-cxx11-rocm63-x86_64-linux/optimizer/_ops.py +++ /dev/null @@ -1,9 +0,0 @@ -import torch -from . import _optimizer_23d68bb_dirty -ops = torch.ops._optimizer_23d68bb_dirty - -def add_op_namespace_prefix(op_name: str): - """ - Prefix op by namespace. 
- """ - return f"_optimizer_23d68bb_dirty::{op_name}" \ No newline at end of file diff --git a/build/torch29-cxx11-rocm63-x86_64-linux/optimizer/_optimizer_23d68bb_dirty.abi3.so b/build/torch29-cxx11-rocm63-x86_64-linux/optimizer/_optimizer_23d68bb_dirty.abi3.so deleted file mode 100755 index fd962f33db89d6740c8d181f6b4e3ded3220fec0..0000000000000000000000000000000000000000 --- a/build/torch29-cxx11-rocm63-x86_64-linux/optimizer/_optimizer_23d68bb_dirty.abi3.so +++ /dev/null @@ -1,3 +0,0 @@ -version https://git-lfs.github.com/spec/v1 -oid sha256:d297c32252c7f030f3ec60ab1cc908cf145c8ecc710a25690a528d06115ab998 -size 1852184 diff --git a/build/torch29-cxx11-rocm63-x86_64-linux/optimizer/distributed/utils.py b/build/torch29-cxx11-rocm63-x86_64-linux/optimizer/distributed/utils.py deleted file mode 100644 index 0b4b58bfb329b1c015129e4c4fc99f7bfa2ab30a..0000000000000000000000000000000000000000 --- a/build/torch29-cxx11-rocm63-x86_64-linux/optimizer/distributed/utils.py +++ /dev/null @@ -1,174 +0,0 @@ -import torch -import torch.distributed as dist -from torch.distributed import ProcessGroup -from torch.distributed.device_mesh import DeviceMesh -from torch.distributed.tensor import DTensor -from torch.distributed.tensor.placement_types import (Placement, Shard, - _StridedShard) - - -def get_slices_of_dtensor( - target: DTensor | torch.Tensor, - local_rank: int, - shard_mesh: DeviceMesh, - shard_placements: tuple[Placement], -) -> tuple[slice]: - """ - Get the slice of local tensor for a given rank from a tensor. - Args: - target (DTensor | torch.Tensor): The target tensor. - rank (int): The local rank of the shard group. - shard_mesh (DeviceMesh): The shard mesh. It consists of global ranks. - shard_placements (tuple[Placement]): The shard placements. 
- """ - - slices: list[slice] = [slice(0, dim_size) for dim_size in target.size()] - - # find the global rank of the local rank in the shard mesh - rank = sorted(shard_mesh.mesh.flatten().tolist())[local_rank] - - rank_coords = (shard_mesh.mesh == rank).nonzero() - - assert len(rank_coords) == 1 - rank_coords = tuple(rank_coords[0].tolist()) - - assert len(rank_coords) == len(shard_placements) - - # Caution: Assuming replicate-to-shard of the shard mesh goes with - # left-to-right sharding. This is ensured by the sorting logic of - # construct_shard_mesh function. - for i, (rank_coord, - placement) in enumerate(zip(rank_coords, shard_placements)): - assert isinstance(placement, Shard) - - num_ranks = shard_mesh.mesh.shape[i] - - dim = placement.dim - dim_size = (slices[dim].stop - slices[dim].start) - - if dim_size % num_ranks != 0: - raise NotImplementedError( - f"Dimension size {dim_size} is not divisible " - f"by number of ranks {num_ranks} for shard " - f"placement on dim {dim}.") - - shard_size = dim_size // num_ranks - - start = slices[dim].start + rank_coord * shard_size - end = start + shard_size - - assert start < end <= slices[dim].stop - - slices[dim] = slice(start, end) - - return tuple(slices) - - -_ranks_to_dist_cache: dict[tuple[int, ...], tuple[DeviceMesh, ProcessGroup]] = dict() - - -def construct_shard_mesh( - placements: tuple[Placement], - mesh: DeviceMesh, -) -> (DeviceMesh, ProcessGroup, tuple[Placement]): - """ - Construct Shard Mesh and Placements for unsharding. - It removes Replicate placements and constructs a new Mesh and ProcessGroup. - """ - my_rank = dist.get_rank() - - assert mesh.mesh.device.type == 'cpu' - - # Copy mesh to avoid modifying the original mesh - mesh = mesh.mesh.clone() - - # 1. Sort placements. Replicate first, then Shard by dim ascending. - - # For Shard, strided shard comes after regular shard on the same dim - # to preserve left-to-right order of replicate-to-shard. 
- # This is because that strided shard is using stride to represent - # more fine-grained sharding on the same dim. - # Please check the URL below for _StridedShard. - # https://github.com/pytorch/pytorch/blob/v2.8.0/torch/distributed/tensor/placement_types.py#L366 - - def placement_sort_key( - placement_with_index: tuple[float, Placement] - ) -> tuple[int, float, int]: # (dim, split factor, original index) - index, placement = placement_with_index - is_replicate = placement.is_replicate() - is_shard = placement.is_shard() - is_partial = placement.is_partial() - - assert is_replicate or is_shard, f"Unsupported placement type: {type(placement)}" - assert not is_partial, "Partial placement is not supported." - - if is_replicate: - return (-1.0, 0, index) - elif is_shard: - if isinstance(placement, _StridedShard): - return (placement.dim, 1 / placement.split_factor, index) - return (placement.dim, 0, index) - else: - raise TypeError(f"Unknown placement type: {type(placement)}") - - placements_with_index: list[tuple[int, - Placement]] = list(enumerate(placements)) - placements_with_index = sorted(placements_with_index, - key=placement_sort_key) - - sorted_indices, sorted_placements = zip(*placements_with_index) - - # 2. Permute mesh according to sorted placements. - sorted_mesh = mesh.permute(sorted_indices) - - # 3. 
Collect list of shard meshes by removing replicate dims - # For example, (2, 3, 4, 4) with placements [R, R, S(0), S(1)] - # shard_meshes should be list with 2 * 3 = 6 shard meshes of shape (4, 4) - num_replicates = sum(1 for p in sorted_placements if p.is_replicate()) - - # merge replicate dims - # shard_meshes became a list of shard meshes with a length of replicate degree - if num_replicates > 0: - sorted_mesh = sorted_mesh.flatten( - 0, num_replicates - 1) if num_replicates > 1 else sorted_mesh - shard_meshes = list(torch.unbind(sorted_mesh, dim=0)) - else: - shard_meshes = [sorted_mesh] - shard_placements = sorted_placements[num_replicates:] - - # assume all shard placements are different - assert len(shard_placements) == len(set(shard_placements)) - - # 4. Construct ProcessGroups - # Caution: all groups should be created in the same order in all processes, - # even though each process only needs its own group. - - # To use tensor as dict key, convert it to tuple - def tensor_to_tuple(t): - if isinstance(t, torch.Tensor): - t = t.tolist() - if isinstance(t, list): - return tuple(tensor_to_tuple(x) for x in t) - return t - - my_shard_mesh_as_tuple = None - for shard_mesh in shard_meshes: - assert isinstance(shard_mesh, torch.Tensor) - shard_mesh_as_tuple = tensor_to_tuple(shard_mesh) - - if (my_rank == shard_mesh).any().item(): - assert my_shard_mesh_as_tuple is None - my_shard_mesh_as_tuple = shard_mesh_as_tuple - - # update global cache - if shard_mesh_as_tuple not in _ranks_to_dist_cache: - shard_process_group = dist.new_group(shard_mesh.flatten().tolist()) - _ranks_to_dist_cache[shard_mesh_as_tuple] = ( - DeviceMesh(device_type="cuda", mesh=shard_mesh), - shard_process_group, - ) - - my_shard_mesh, my_shard_process_group = _ranks_to_dist_cache[ - my_shard_mesh_as_tuple] - - return my_shard_mesh, my_shard_process_group, shard_placements diff --git a/build/torch29-cxx11-rocm63-x86_64-linux/optimizer/muon.py 
b/build/torch29-cxx11-rocm63-x86_64-linux/optimizer/muon.py deleted file mode 100644 index cfbcca71741be70048bfd290c62148b2aceda631..0000000000000000000000000000000000000000 --- a/build/torch29-cxx11-rocm63-x86_64-linux/optimizer/muon.py +++ /dev/null @@ -1,1240 +0,0 @@ -import logging -import math -import types -from collections import defaultdict -from dataclasses import dataclass -from typing import Any, cast - -import torch -import torch.distributed as dist -from torch.distributed import ProcessGroup -from torch.distributed.device_mesh import DeviceMesh -from torch.distributed.tensor import DTensor, Replicate -from torch.distributed.tensor.placement_types import Placement - -from .distributed.utils import construct_shard_mesh, get_slices_of_dtensor -from .matmul_transpose_triton import matmul_transpose_assign - -logger = logging.getLogger(__name__) - -COMM_DTYPE = torch.bfloat16 -DEFAULT_CHUNK_SIZE_RATIO = 4 - - -# This code snippet is a modified version adapted from the following GitHub repositories: -# https://github.com/KellerJordan/Muon/blob/master/muon.py -# Muon's Newton–Schulz iteration causes high variance in singular values -# Idea: give each iteration its own 3 coefficients and optimize them via gradient descent. -@torch.no_grad() -# matmul_transpose_assign from : https://github.com/nil0x9/flash-muon -def _zeropower_via_newtonschulz5(G, steps): - """ - Newton-Schulz iteration to compute the zeroth power / orthogonalization of G. We opt to use a - quintic iteration whose coefficients are selected to maximize the slope at zero. For the purpose - of minimizing steps, it turns out to be empirically effective to keep increasing the slope at - zero even beyond the point where the iteration no longer converges all the way to one everywhere - on the interval. 
This iteration therefore does not produce UV^T but rather something like US'V^T - where S' is diagonal with S_{ii}' ~ Uniform(0.5, 1.5), which turns out not to hurt model - performance at all relative to UV^T, where USV^T = G is the SVD. - """ - assert len(G.shape) == 2 - assert G.dtype == COMM_DTYPE - X = G # no manual typecast - - if G.size(0) > G.size(1): - X = X.T - # Ensure spectral norm is at most 1 - X = X / (X.norm() + 1e-7) - buf1 = torch.empty(X.size(0), X.size(0), dtype=X.dtype, device=X.device) - buf2 = torch.empty(X.size(0), X.size(0), dtype=X.dtype, device=X.device) - # Perform the NS iterations - for a, b, c in [ - (4.0848, -6.8946, 2.9270), - (3.9505, -6.3029, 2.6377), - (3.7418, -5.5913, 2.3037), - (2.8769, -3.1427, 1.2046), - (2.8366, -3.0525, 1.2012), - ]: - matmul_transpose_assign(X, buf1) - matmul_transpose_assign(buf1, buf2) - buf1.mul_(b).add_(buf2, alpha=c) - X = torch.addmm(X, buf1, X, alpha=1.0, beta=a) - - if G.size(0) > G.size(1): - X = X.T - return X - - -@dataclass -class _muon_state: - # TODO: use Optional - worker_rank: int - process_group: ProcessGroup - shard_mesh: DeviceMesh - shard_placements: tuple[Placement, ...] 
- name: str - qk_clip_state: torch.Tensor | None = None - gathered_grad: torch.Tensor | None = None - scattered_u: DTensor | None = None - computed_u: torch.Tensor | None = None - gather_event: torch.cuda.Event | None = None - compute_event: torch.cuda.Event | None = None - scatter_event: torch.cuda.Event | None = None - - -def numel_for_rank( - param: DTensor, - local_rank: int, - state: _muon_state, -) -> int: - slices = get_slices_of_dtensor( - param, - local_rank, - state.shard_mesh, - state.shard_placements, - ) - - numel = 1 - for s, dim in zip(slices, param.shape): - start, stop, step = s.indices(dim) - length = max(0, (stop - start + (step - 1)) // step) - numel *= length - - return numel - - -@torch.no_grad() -def _alloc_gathered_grad(params, param_to_state, rank, compute_stream): - """ - Pre-allocate gathered_grad buffer on compute_stream - before launching all2all gather - """ - with torch.cuda.stream(compute_stream): - for p in params: - state = param_to_state[id(p)] - if rank == state.worker_rank: - state.gathered_grad = torch.empty(p.shape, - dtype=COMM_DTYPE, - device="cuda") - else: - state.gathered_grad = None - - alloc_event = torch.cuda.Event() - alloc_event.record(compute_stream) - return alloc_event - - -@torch.no_grad() -def _all2all_gather(params, param_to_state, rank, comm_stream, none_grad, - alloc_event): - """ - All2all gathers shards so each owner rank reconstructs its full gradient - """ - with torch.cuda.stream(comm_stream): - process_group = param_to_state[id(params[0])].process_group - num_ranks = dist.get_world_size(group=process_group) - - # Construct sending buffers - per_dst = [[] for _ in range(num_ranks)] - send_counts = [0] * num_ranks - - for p in params: - state = param_to_state[id(p)] - dst = state.worker_rank - assert dst < num_ranks - shard_elems = numel_for_rank(p, rank, state) - g = p.grad - g = g.to_local().to(COMM_DTYPE).contiguous() - assert g.numel() == shard_elems - per_dst[dst].append(g.view(-1)) - 
send_counts[dst] += shard_elems - - assert any( - len(v) > 0 for v in per_dst - ), "At least one destination rank must receive a sharded tensor" - # list[list[Tensor]] -> list[Tensor] - per_dst = [t for dst in per_dst for t in dst] - - send_buf = torch.cat(per_dst, dim=0) - - owned_params = [ - p for p in params if param_to_state[id(p)].worker_rank == rank - ] - - # Compute receive sizes and allocate receiving buffers - recv_counts = [0] * num_ranks - - for src in range(num_ranks): - total = 0 - for p in owned_params: - state = param_to_state[id(p)] - assert state.worker_rank == rank - total += numel_for_rank(p, src, state) - recv_counts[src] = total - - recv_total = sum(recv_counts) - recv_buf = torch.empty(recv_total, dtype=COMM_DTYPE, device="cuda") - - #All2All - logger.debug(f"send_buf size: {send_buf.numel()}, " - f"recv_buf size: {recv_buf.numel()}, " - f"recv_counts: {recv_counts}, " - f"send_counts: {send_counts}, " - f"process_group: {str(process_group)}") - dist.all_to_all_single( - recv_buf, - send_buf, - output_split_sizes=recv_counts, - input_split_sizes=send_counts, - group=process_group, - ) - - # Reconstructs gathered grad from the received buffer - # - # recv_buf (num ranks = 3) - # - # From rank 0 From rank 1 From rank 2 - # | p1_0, p2_0, p3_0 | p1_1, p2_1, p3_1 | p1_2, p2_2, p3_2 | - # - # Outer loop: - # rank 0 -> rank 1 -> rank2 - # - # Inner loop: - # p1_n -> p2_n -> p3_n - - comm_stream.wait_event(alloc_event) - - off = 0 - for src in range(num_ranks): - if recv_counts[src] == 0: - continue - - block = recv_counts[src] - inner_off = 0 - for p in owned_params: - state = param_to_state[id(p)] - assert state.worker_rank == rank - - # get the slice of the full dtensor corresponding to rank src. 
- slices = get_slices_of_dtensor(state.gathered_grad, src, - state.shard_mesh, - state.shard_placements) - - dst = state.gathered_grad[slices] - assert dst._base is state.gathered_grad - - n = dst.numel() - assert n > 0 - - sg = recv_buf.narrow(0, off + inner_off, n) - sg = sg.reshape_as(dst) - dst.copy_(sg) - - inner_off += n - off += block - - for p in params: - state = param_to_state[id(p)] - if state.worker_rank == rank: - state.gather_event = torch.cuda.Event() - state.gather_event.record(comm_stream) - else: - state.gathered_grad = None - state.gather_event = None - if none_grad: - p.grad = None - - -@torch.no_grad() -def _compute_u(p, state, steps, rank, compute_stream): - """ - On worker_rank, compute the orthogonalized update using Newton-Schulz iteration. - """ - with torch.cuda.stream(compute_stream): - if rank == state.worker_rank: - if state.gather_event is None: - raise RuntimeError("Gather event must be set before compute.") - compute_stream.wait_event(state.gather_event) - u = _zeropower_via_newtonschulz5(state.gathered_grad, steps) - state.gathered_grad = None - state.computed_u = u - state.compute_event = torch.cuda.Event() - state.compute_event.record() - else: - state.computed_u = None - state.compute_event = None - - -@torch.no_grad() -def _alloc_scattered_u(params, param_to_state, rank, compute_stream): - """ - Pre-allocate scattered_u buffer on compute_stream - before launching all2all gather - """ - with torch.cuda.stream(compute_stream): - for p in params: - state = param_to_state[id(p)] - state.scattered_u = torch.empty_like(p.to_local(), - dtype=COMM_DTYPE) - - alloc_event = torch.cuda.Event() - alloc_event.record(compute_stream) - return alloc_event - - -def _all2all_scatter(params, param_to_state, rank, comm_stream, alloc_event): - """ - All2all scatters full gradients to all ranks - """ - with torch.cuda.stream(comm_stream): - process_group = param_to_state[id(params[0])].process_group - num_ranks = 
dist.get_world_size(group=process_group) - owned_params = [ - p for p in params if param_to_state[id(p)].worker_rank == rank - ] - - # Construct sending buffer - per_dst = [[] for _ in range(num_ranks)] - send_counts = [0] * num_ranks - - if owned_params: - for p in owned_params: - state = param_to_state[id(p)] - if state.compute_event is None: - raise RuntimeError( - "Compute event must be set before scatter.") - comm_stream.wait_event(state.compute_event) - state.gathered_grad = None - - assert state.computed_u is not None - - u_full = state.computed_u.to(COMM_DTYPE).contiguous() - - offset = 0 - for dst in range(num_ranks): - # get the slice of the full tensor corresponding to rank dst. - slices = get_slices_of_dtensor(u_full, dst, - state.shard_mesh, - state.shard_placements) - su = u_full[slices].flatten() - - n = su.numel() - assert n > 0 - - per_dst[dst].append(su) - send_counts[dst] += n - offset += n - - assert offset == u_full.numel() - - lengths = [len(v) for v in per_dst] - if all(l > 0 for l in lengths): - assert all( - l == lengths[0] for l in lengths - ), "All destination ranks must have the same number of sharded tensor" - # list[list[Tensor]] -> list[Tensor] - per_dst = [t for dst in per_dst for t in dst] - send_buf = torch.cat(per_dst, dim=0) - else: - # all_to_all requires participation from all ranks - # Even non-owner ranks must join the collective call - send_buf = torch.empty(0, dtype=COMM_DTYPE, device="cuda") - - # Compute receive sizes and allocate receiving buffers - recv_counts = [0] * num_ranks - - for src in range(num_ranks): - total = 0 - for p in params: - state = param_to_state[id(p)] - if state.worker_rank != src: - continue - total += numel_for_rank(p, rank, state) - recv_counts[src] = total - - recv_total = sum(recv_counts) - assert recv_total > 0 - recv_buf = torch.empty(recv_total, dtype=COMM_DTYPE, device="cuda") - - #All2All - dist.all_to_all_single( - recv_buf, - send_buf, - output_split_sizes=recv_counts, - 
input_split_sizes=send_counts, - group=process_group, - ) - - # Copy to pre-allocated scattered_u buffer from the received buffer - # - # recv_buf (num ranks = 3, local_rank = 0) - # - # From rank 0 From rank 1 From rank 2 - # | p1_0, p2_0, p3_0 | p4_0 | p5_0, p6_0 | - # - # Outer loop: - # rank 0 -> rank 1 -> rank2 - # - # Inner loop: - # src(0) : p1_0 -> p2_0 -> p3_0 - # src(1) : p4_0 - # src(2) : p5_0 -> p6_0 - - comm_stream.wait_event(alloc_event) - - off = 0 - for src in range(num_ranks): - block = recv_counts[src] - if block == 0: - continue - - inner_off = 0 - for p in params: - state = param_to_state[id(p)] - if state.worker_rank != src: - continue - n = numel_for_rank(p, rank, state) - assert n > 0 - - flat_local = recv_buf.narrow(0, off + inner_off, - n).view_as(p.to_local()) - state.scattered_u.copy_(flat_local) - - state.scatter_event = torch.cuda.Event() - state.scatter_event.record(comm_stream) - inner_off += n - - assert inner_off == block - off += block - - -def _update_param(p, state, lr, adjusted_lr, weight_decay, rank, - compute_stream): - """ - Update sharded parameter p with the scattered_u. - Only worker_rank frees computed_u. 
- """ - with torch.cuda.stream(compute_stream): - if state.scatter_event is None: - raise RuntimeError("Scatter event must be set before update") - compute_stream.wait_event(state.scatter_event) - u_dtensor = DTensor.from_local( - state.scattered_u, - placements=p.placements, - device_mesh=p.device_mesh, - ) - - state.scattered_u = u_dtensor - - if rank == state.worker_rank: - # Free computed_u - state.computed_u = None - - Muon._update_p(p, state.scattered_u, lr, adjusted_lr, weight_decay) - state.scattered_u = None - u_dtensor = None - - scales_full = Muon._compute_scales( - p, - state.qk_clip_state) if state.qk_clip_state is not None else None - if scales_full is not None: - # Have to slice scales_full among dim 0 - weight_slices = get_slices_of_dtensor(p, rank, state.shard_mesh, - state.shard_placements) - ratio = p.shape[0] // scales_full.shape[0] - scales_slice = slice( - None if weight_slices[0].start is None else - weight_slices[0].start // ratio, - None if weight_slices[0].stop is None else - weight_slices[0].stop // ratio, - None, - ) - - scales_local = scales_full[scales_slice] - scales_local = DTensor.from_local( - scales_local, - placements=p.placements, - device_mesh=p.device_mesh, - ) - Muon._qk_clip(p, scales_local, state.qk_clip_state.head_dim) - - -def default_is_muon(name, x): - skip_keys = ["embed_tokens", "lm_head", "tok_embeddings", "output"] - return x.ndim >= 2 and not any(key in name for key in skip_keys) - - -def get_default_muon_param_groups(model, is_muon_func=default_is_muon): - muon_params, muon_names = [], [] - non_muon_params = [] - - for n, p in model.named_parameters(): - if not p.requires_grad: - continue - if is_muon_func(n, p): - muon_params.append(p) - muon_names.append(n) - else: - non_muon_params.append(p) - - return [ - { - "params": muon_params, - "names": muon_names, - "use_muon": True, - }, - { - "params": non_muon_params, - "use_muon": False, - }, - ] - - -def parse_qk_layer(name: str) -> tuple[str | None, int]: - """ - 
Parse a parameter name to check if it is a query/key projection layer - ('wq', 'wk', 'q_proj', 'k_proj') and return (kind, layer_index). - - Returns: - (kind, layer_idx) or (None, -1) if not matched. - - Example: - 'model.3.attn.wq.weight' -> ('wq', 3) - 'model.5.attn.wk.weight' -> ('wk', 5) - 'model.2.attn.q_proj.weight' -> ('q_proj', 2) - 'model.7.attn.k_proj.weight' -> ('k_proj', 7) - 'model.4.attn.v_proj.weight' -> (None, -1) - """ - parts = name.split('.') - if len(parts) < 3: - return None, -1 - - kind = parts[-2] - - layer_idx = -1 - for part in reversed(parts): - if part.isdigit(): - layer_idx = int(part) - break - - if kind in ('wq', 'wk', 'q_proj', 'k_proj'): - return kind, layer_idx - - return None, -1 - - -@dataclass -class QKClipInfo: - """Per-parameter dynamic info computed from config + runtime logits.""" - kind: str | None # 'wq'/'q_proj' or 'wk'/'k_proj' or None - indices: list[int] # which heads to consider for clipping - head_dim: int # from config - threshold: float # from config - logit: torch.Tensor | None - - -class Muon(torch.optim.Optimizer): - """ - Muon - MomentUm Orthogonalized by Newton-schulz - - Muon internally runs standard SGD-momentum, and then performs an orthogonalization post- - processing step, in which each 2D parameter's update is replaced with the nearest orthogonal - matrix. To efficiently orthogonalize each update, we use a Newton-Schulz iteration, which has - the advantage that it can be stably run in bfloat16 on the GPU. - - Some warnings: - - We believe this optimizer is unlikely to work well for training with small batch size. - - We believe it may not work well for finetuning pretrained models, but we haven't tested this. - - Arguments: - model: The model to be optimized by Muon. - is_muon_func: A function that takes a parameter and its name, and returns whether the parameter should be optimized by Muon. - lr: The learning rate. The updates will have spectral norm of `lr`. 
(0.02 is a good default) - momentum: The momentum used by the internal SGD. (0.95 is a good default) - nesterov: Whether to use Nesterov-style momentum in the internal SGD. (recommended) - ns_steps: The number of Newton-Schulz iterations to run. (6 is probably always enough) - weight_decay: The weight decay for Muon and AdamW. - {0, 1}-D or are detected as being the embed or lm_head will be optimized by AdamW as well. - adamw_lr: The learning rate for the internal AdamW. - adamw_betas: The betas for the internal AdamW. - adamw_eps: The epsilon for the internal AdamW. - none_grad: Whether to set p.grad to None after gathering the gradients. This can save memory. - debug: Whether to print debug information. - clip_info : Configuration for QK clipping. Expected keys: - - "q_indices" (list[int]): Indices of query heads to consider. - - "k_indices" (list[int]): Indices of key heads to consider. - - "head_dim" (int): Dimensionality of each attention head. - - "threshold" (float): Threshold value; heads whose QK logits exceed - this value will be scaled down. - Default is: - { - "q_indices": [], - "k_indices": [], - "head_dim": 128, - "threshold": 100 - } - warmup_step : How many all2all gather, compute operations are launched in advance - before the corresponding all2all scatter steps begin. - A higher warmup_step increases memory usage but can improve - performance by overlapping communication. - Parallel muon only. - chunk_size : Batch size of parameters to process in each - all2all gather/compute/scatter step. - Use shard ranks * DEFAULT_CHUNK_SIZE_RATIO when -1 is specified. - use_distributed_muon: Use distributed muon by Liu et al. (2024). - For testing purpose only. 
- """ - - def __init__(self, - params, - lr=1e-3, - momentum=0.95, - nesterov=True, - ns_steps=5, - weight_decay=0.1, - adamw_betas=(0.9, 0.95), - adamw_eps=1e-8, - none_grad=True, - debug=False, - clip_config={ - "q_indices": [], - "k_indices": [], - "head_dim": 128, - "threshold": 100 - }, - warmup_step=5, - chunk_size=-1, - use_distributed_muon=False): - defaults = dict( - lr=lr, - weight_decay=weight_decay, - momentum=momentum, - nesterov=nesterov, - ns_steps=ns_steps, - adamw_betas=adamw_betas, - adamw_eps=adamw_eps, - none_grad=none_grad, - use_muon=True, - ) - error_message = "The key 'use_muon' is not set in parameter group {idx}. Assuming all parameters in the group will use muon optimization, which may lead to unexpected behavior." - instruction_code = "\n\n please follow this code snippet \n```optimizer = get_kernel('motif-technologies/optimizer')\n\n\nparams = optimizer.muon.get_default_muon_param_groups(model)\n\noptim = optimizer.Muon(params, ...)```" - - if isinstance(params, types.GeneratorType): - raise ValueError(error_message.format(idx=0) + instruction_code) - for _idx, param_group in enumerate(params): - if param_group.get("use_muon", None) is None: - raise ValueError( - error_message.format(idx=_idx) + instruction_code) - - super().__init__(params, defaults) - - self.rank = None - - self.comm_stream = torch.cuda.Stream() - self.compute_stream = torch.cuda.Stream() - self.debug = debug - self.clip_config = clip_config - self.warmup_step = warmup_step - self.chunk_size = chunk_size - self.use_distributed_muon = use_distributed_muon - - def _calc_flops(self, G, steps): - assert len(G.shape) == 2 - M, N = G.shape - if M > N: - M, N = N, M - - return steps * ((M**3) * 2 + (M**2 * N) * 4 + M * N * 2 + M**2 * 3) - - def adjust_lr_for_muon(self, lr, param_shape): - A, B = param_shape[:2] - # We adjust the learning rate and weight decay based on the size of the parameter matrix - # as describted in the paper - adjusted_ratio = 0.2 * math.sqrt(max(A, 
B)) - adjusted_lr = lr * adjusted_ratio - return adjusted_lr - - def set_rank_once(self, rank): - if self.rank is None: - self.rank = rank - else: - assert self.rank == rank - - def get_shard_mesh(self, p): - """ - Get the shard mesh for a parameter p on the given rank. - """ - assert isinstance( - p, DTensor), "Parallel Muon only supports DTensor parameters." - - shard_mesh, shard_pg, shard_placements = construct_shard_mesh( - p.placements, p.device_mesh) - - # set rank with the local rank in the shard process group - self.set_rank_once(dist.get_rank(group=shard_pg)) - - return shard_mesh, shard_pg, shard_placements - - def init_state_and_assign_params(self, names, params, group, qk_logits): - param_to_state = {} - param_to_flops = {} - - total_flops = 0 - for p in params: - g = p.grad - if g is None: - continue - assert g.ndim == 2, "Muon only supports 2D parameters." - - flops = self._calc_flops(g, group["ns_steps"]) - param_to_flops[id(p)] = flops - total_flops += flops - - if self.debug: - print(f"Total TFLOPs for Muon: {total_flops / 1e12:.2f} TFLOPs", - flush=True) - - paired = list(zip(names, params)) - - paired_sorted = sorted(paired, - key=lambda x: param_to_flops[id(x[1])], - reverse=True) - - names_sorted, params_sorted = zip(*paired_sorted) - ordered_names = list(names_sorted) - ordered_params = list(params_sorted) - - round_robin = 0 - mesh = ordered_params[0].device_mesh - placements = ordered_params[0].placements - - shard_mesh, shard_pg, shard_placements = self.get_shard_mesh( - ordered_params[0]) - shard_mesh_flattened = shard_mesh.mesh.flatten() - num_ranks = dist.get_world_size(group=shard_pg) - - for n, p in zip(ordered_names, ordered_params): - if mesh != p.device_mesh: - raise ValueError("All parameters must be on the same mesh.") - if placements != p.placements: - raise ValueError("All parameters must have same placements.") - - worker_rank = shard_mesh_flattened[round_robin].item() % num_ranks - round_robin = (round_robin + 1) % 
len(shard_mesh_flattened) - qk_clip_state = self.get_qk_clip_info(n, qk_logits) - - param_to_state[id(p)] = _muon_state( - worker_rank=worker_rank, - process_group=shard_pg, - shard_mesh=shard_mesh, - shard_placements=shard_placements, - name=n, - qk_clip_state=qk_clip_state, - ) - - return param_to_state, ordered_params - - def base(self, names, params, group, lr, weight_decay, momentum, - qk_logits): - # generate weight updates in distributed fashion - for n, p in zip(names, params): - g = p.grad - if g is None: - continue - if g.ndim > 2: - g = g.view(g.size(0), -1) - assert g is not None - - # calc update - state = self.state[p] - if "momentum_buffer" not in state: - state["momentum_buffer"] = torch.zeros_like(g) - buf = state["momentum_buffer"] - buf.mul_(momentum).add_(g) - if group["nesterov"]: - g = g.add(buf, alpha=momentum) - else: - g = buf - - u = _zeropower_via_newtonschulz5(g.to(COMM_DTYPE), - steps=group["ns_steps"]) - - adjusted_lr = self.adjust_lr_for_muon(lr, p.shape) - Muon._update_p(p, u, lr, adjusted_lr, weight_decay) - - qk_clip_state = self.get_qk_clip_info(n, qk_logits) - - scales_full = self._compute_scales( - p, qk_clip_state) if qk_clip_state is not None else None - if scales_full is not None: - Muon._qk_clip(p, scales_full, qk_clip_state.head_dim) - - def distributed_muon( - self, - names: list[str], - params: list[torch.nn.Parameter], - group: dict[str, Any], - lr: float, - weight_decay: float, - momentum: float, - qk_logits: list[torch.Tensor | DTensor] | None, - ): - """ Implementation of Distributed Muon by Liu et al. 
""" - if qk_logits is not None: - raise NotImplementedError("QK clipping is not supported yet") - - if isinstance(params[0], DTensor): - shard_mesh, _, shard_placements = construct_shard_mesh( - placements=params[0].placements, - mesh=params[0].device_mesh, - ) - - for n, p in zip(names, params): - g = p.grad - if g is None: - continue - if g.ndim > 2: - g = g.view(g.size(0), -1) - assert g is not None - - # calc update - state = self.state[p] - if "momentum_buffer" not in state: - state["momentum_buffer"] = torch.zeros_like(g) - buf = state["momentum_buffer"] - buf.mul_(momentum).add_(g) - if group["nesterov"]: - g = g.add(buf, alpha=momentum) - else: - g = buf - - # Gather G - if isinstance(p.data, DTensor): - g = g.full_tensor() - u = _zeropower_via_newtonschulz5(g.to(COMM_DTYPE), - steps=group["ns_steps"]) - - if isinstance(p.data, DTensor): - slices = get_slices_of_dtensor( - target=p, - local_rank=dist.get_rank(), - shard_mesh=shard_mesh, - shard_placements=shard_placements, - ) - u_shard = u[slices] - u = DTensor.from_local( - u_shard, - device_mesh=p.device_mesh, - placements=p.placements, - ) - - adjusted_lr = self.adjust_lr_for_muon(lr, p.shape) - Muon._update_p(p, u, lr, adjusted_lr, weight_decay) - - def _update_g(self, p, g, group, momentum): - # calc update - state = self.state[p] - buf = state.setdefault("momentum_buffer", torch.zeros_like(g)) - torch.add(g, buf, alpha=momentum, out=buf) - if group["nesterov"]: - g.add_(buf, alpha=momentum) - return g - return buf - - @staticmethod - def _update_p(p, u, lr, adjusted_lr, weight_decay): - # apply weight decay - p.data.mul_(1 - lr * weight_decay) - # apply update - p.data.add_(u, alpha=-adjusted_lr) - - def get_qk_clip_info(self, n, qk_logits): - if self.clip_config is None: - return None - - head_dim = self.clip_config.get('head_dim') - threshold = self.clip_config.get('threshold') - kind, layer_idx = parse_qk_layer(n) - - logit, indices = None, [] - if qk_logits is not None and kind is not None: - 
logit = qk_logits[layer_idx] - indices_key = 'q_indices' if 'q' in kind else 'k_indices' - indices = self.clip_config.get(indices_key, []) or [] - - if isinstance(logit, DTensor): - # In TP settings, qk_logits may be DTensor - # We convert it to full tensor here for simplicity - logit = logit.full_tensor() - - return QKClipInfo( - kind=kind, - indices=indices, - head_dim=head_dim, - threshold=threshold, - logit=logit, - ) - - @staticmethod - def _compute_scales(p, qk_clip_state): - kind = qk_clip_state.kind - indices = qk_clip_state.indices - head_dim = qk_clip_state.head_dim - threshold = qk_clip_state.threshold - logit = qk_clip_state.logit - - H_global = p.shape[0] // head_dim - scales_full = torch.ones(H_global, device=p.data.device) - scaling = 0 - - for logit_idx, head_idx in enumerate(indices): - v_ele = float(logit[logit_idx]) - if v_ele > threshold: - new_scale = math.sqrt(threshold / v_ele) - if new_scale < scales_full[head_idx]: - scales_full[head_idx] = new_scale - logger.info( - f"[{kind}] Head {head_idx} exceeded threshold " - f"(value={v_ele:.4f}, threshold={threshold:.4f}) -> applying scale={new_scale:.4f}" - ) - scaling += 1 - - return scales_full if scaling > 0 else None - - @staticmethod - def _qk_clip(p, scales, head_dim): - W = p.data.view(-1, head_dim, p.data.shape[1]) - W.mul_(scales.view(-1, 1, 1)) - - def parallel(self, names, params, group, lr, weight_decay, momentum, - qk_logits): - """ - Perform a parallel optimization step using Muon. 
- """ - - for p in params: - g = p.grad - if g is None: - continue - if g.ndim > 2: - g = g.view(g.size(0), -1) - - # Update g in the local rank - g = self._update_g( - p, - g, - group, - momentum=momentum, - ) - p.grad = g - - param_to_state, ordered_params = self.init_state_and_assign_params( - names, params, group, qk_logits) - - assert self.rank is not None - - def enqueue_all2all_gather(start_idx, chunk_size): - target_params = ordered_params[start_idx:start_idx + chunk_size] - if target_params: - alloc_event = _alloc_gathered_grad(target_params, - param_to_state, self.rank, - self.compute_stream) - _all2all_gather(target_params, param_to_state, self.rank, - self.comm_stream, group["none_grad"], - alloc_event) - - def enqueue_computes(start_idx, chunk_size): - for p in ordered_params[start_idx:start_idx + chunk_size]: - state = param_to_state[id(p)] - _compute_u(p, state, group["ns_steps"], self.rank, - self.compute_stream) - - def enqueue_all2all_scatter(start_idx, chunk_size): - target_params = ordered_params[start_idx:start_idx + chunk_size] - if target_params: - alloc_event = _alloc_scattered_u(target_params, param_to_state, - self.rank, - self.compute_stream) - _all2all_scatter(target_params, param_to_state, self.rank, - self.comm_stream, alloc_event) - - def enqueue_update_param(start_idx, chunk_size): - for p in ordered_params[start_idx:start_idx + chunk_size]: - state = param_to_state[id(p)] - adjusted_lr = self.adjust_lr_for_muon(lr, p.shape) - _update_param(p, state, lr, adjusted_lr, weight_decay, - self.rank, self.compute_stream) - - if self.chunk_size == -1: - shard_ranks = dist.get_world_size(param_to_state[id( - params[0])].process_group) - chunk_size = shard_ranks * DEFAULT_CHUNK_SIZE_RATIO - elif self.chunk_size > 0: - chunk_size = self.chunk_size - else: - raise ValueError("chunk_size must be -1 or a positive integer.") - - # Wait grad update - self.comm_stream.wait_stream(torch.cuda.current_stream()) - - warmup_step = self.warmup_step - for i 
in range(0, warmup_step): - enqueue_all2all_gather(i * chunk_size, chunk_size) - enqueue_computes(i * chunk_size, chunk_size) - - for i in range(0, len(params) + chunk_size - 1, chunk_size): - enqueue_all2all_scatter(i, chunk_size) - enqueue_all2all_gather(i + warmup_step * chunk_size, chunk_size) - enqueue_update_param(i, chunk_size) - enqueue_computes(i + warmup_step * chunk_size, chunk_size) - - # Wait the last update_param to finish - torch.cuda.current_stream().wait_stream(self.compute_stream) - - @staticmethod - def _fused_adamw( - params: list[torch.Tensor], - grads: list[torch.Tensor], - exp_avgs: list[torch.Tensor], - exp_avg_sqs: list[torch.Tensor], - max_exp_avg_sqs: list[torch.Tensor], - state_steps: list[torch.Tensor], - amsgrad: bool, - beta1: float, - beta2: float, - lr: float | torch.Tensor, - weight_decay: float, - eps: float, - maximize: bool, - ) -> None: - if not params: - return - - # We only shuffle around the lr when it is a Tensor and on CUDA, otherwise, we prefer - # treating it as a scalar. 
- lr_dict: DeviceDict | None = ({ - lr.device: lr - } if isinstance(lr, torch.Tensor) and str(lr.device) != "cpu" else - None) - grouped_tensors = torch.optim.Optimizer._group_tensors_by_device_and_dtype( - [ - params, grads, exp_avgs, exp_avg_sqs, max_exp_avg_sqs, - state_steps - ] # type: ignore[list-item] - ) - for (device, _), ( - ( - device_params_, - device_grads_, - device_exp_avgs_, - device_exp_avg_sqs_, - device_max_exp_avg_sqs, - device_state_steps_, - ), - _, - ) in grouped_tensors.items(): - device_params = cast(list[torch.Tensor], device_params_) - device_grads = cast(list[torch.Tensor], device_grads_) - device_exp_avgs = cast(list[torch.Tensor], device_exp_avgs_) - device_exp_avg_sqs = cast(list[torch.Tensor], device_exp_avg_sqs_) - device_state_steps = cast(list[torch.Tensor], device_state_steps_) - - if lr_dict is not None and device not in lr_dict: - lr_dict[device] = lr.to( - device=device, - non_blocking=True) # type: ignore[union-attr] - lr = lr_dict[device] - torch._foreach_add_(device_state_steps, 1) - func = torch._fused_adamw_ - func( - device_params, - device_grads, - device_exp_avgs, - device_exp_avg_sqs, - device_max_exp_avg_sqs, # type: ignore[arg-type] - device_state_steps, - amsgrad=amsgrad, - lr=lr, # type: ignore[arg-type] - beta1=beta1, - beta2=beta2, - weight_decay=weight_decay, - eps=eps, - maximize=maximize, - ) - - def _step_muon(self, group, qk_logits=None): - params = group["params"] - lr = group["lr"] - weight_decay = group["weight_decay"] - momentum = group["momentum"] - names = group["names"] - - param_dtensors = [] - param_tensors = [] - name_dtensors = [] - name_tensors = [] - - if self.use_distributed_muon: - self.distributed_muon(names=names, - params=params, - group=group, - lr=lr, - weight_decay=weight_decay, - momentum=momentum, - qk_logits=qk_logits) - return - - for n, p in zip(names, params): - if p is None or p.grad is None: - continue - if isinstance(p.data, DTensor): - if all( - isinstance(placement, 
Replicate) - for placement in p.placements): - param_tensors.append(p) - name_tensors.append(n) - else: - param_dtensors.append(p) - name_dtensors.append(n) - elif isinstance(p.data, torch.Tensor): - param_tensors.append(p) - name_tensors.append(n) - else: - raise TypeError(f"Unsupported parameter type: {type(p.data)}") - - logger.debug( - f"[Muon] {len(param_dtensors)} DTensors, {len(param_tensors)} Tensors" - ) - - if len(param_dtensors) > 0: - if not dist.is_initialized(): - raise RuntimeError( - "Parallel Muon requires torch.distributed to be initialized." - ) - - # To support different placements, we group parameters by placements - # and run parallel Muon on each group. - - placement_to_params = defaultdict(lambda: ([], [])) - # type: dict[tuple[Placement, DeviceMesh], tuple[list[str], list[DTensor]]] - - assert len(name_dtensors) == len(param_dtensors) - for n, p in zip(name_dtensors, param_dtensors): - placement_to_params[tuple([p.placements, - p.device_mesh])][0].append(n) - placement_to_params[tuple([p.placements, - p.device_mesh])][1].append(p) - - for _, (names, params) in placement_to_params.items(): - self.parallel( - names, - params, - group, - lr=lr, - weight_decay=weight_decay, - momentum=momentum, - qk_logits=qk_logits, - ) - - if len(param_tensors) > 0: - self.base( - name_tensors, - param_tensors, - group, - lr=lr, - weight_decay=weight_decay, - momentum=momentum, - qk_logits=qk_logits, - ) - - def _step_adamw_params(self, params, group): - params_with_grads = [] - grads = [] - moment1 = [] - moment2 = [] - max_exp_avg_sqs = [] - state_steps = [] - lr = group["lr"] - beta1, beta2 = group["adamw_betas"] - eps = group["adamw_eps"] - weight_decay = group["weight_decay"] - - for p in params: - g = p.grad - if g is None: - continue - state = self.state[p] - params_with_grads.append(p) - grads.append(g) - if "step" not in state: - state["step"] = (torch.zeros((), - dtype=torch.float32, - device=p.device)) - state["moment1"] = torch.zeros_like(g) - 
state["moment2"] = torch.zeros_like(g) - moment1.append(state["moment1"]) - moment2.append(state["moment2"]) - if not isinstance(state["step"], torch.Tensor): - step_tensor = torch.tensor(state["step"], - dtype=torch.float32, - device=p.device) - else: - step_tensor = state["step"] - state_steps.append(step_tensor) - - self._fused_adamw( - params_with_grads, - grads, - moment1, - moment2, - max_exp_avg_sqs, - state_steps, - amsgrad=False, - beta1=beta1, - beta2=beta2, - lr=lr, - weight_decay=weight_decay, - eps=eps, - maximize=False, - ) - - def _step_adamw(self, group): - params = group["params"] - - # group params with it's type and placement - placement_to_params: dict[tuple[Placement | type, - DeviceMesh | None]] = defaultdict(list) - for p in params: - match p: - case DTensor(): - placement_to_params[tuple([p.placements, - p.device_mesh])].append(p) - case torch.Tensor(): - placement_to_params[tuple([torch.Tensor, None])].append(p) - - for params in placement_to_params.values(): - self._step_adamw_params(params, group) - - def step(self, closure=None, qk_logits=None): - """Perform a single optimization step. - - Args: - closure (Callable, optional): A closure that reevaluates the model - and returns the loss. - qk_logits (dict[int, Tensor], optional): A dictionary mapping layer indices - to 1D tensors of shape (num_heads,), representing the maximum - QK logits across all tokens, computed as - (1 / sqrt(head_dim)) * (Q @ K^T). 
- """ - loss = None - if closure is not None: - with torch.enable_grad(): - loss = closure() - - for group in self.param_groups: - if group["use_muon"]: - self._step_muon(group, qk_logits=qk_logits) - else: - self._step_adamw(group) - - return loss diff --git a/build/torch29-cxx11-rocm64-x86_64-linux/__init__.py b/build/torch29-cxx11-rocm64-x86_64-linux/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..239c7a65f8293e7d0df28f05fce645af56d628c0 --- /dev/null +++ b/build/torch29-cxx11-rocm64-x86_64-linux/__init__.py @@ -0,0 +1,5 @@ +from .muon import Muon + +__all__ = [ + "Muon", +] diff --git a/build/torch29-cxx11-rocm64-x86_64-linux/_ops.py b/build/torch29-cxx11-rocm64-x86_64-linux/_ops.py new file mode 100644 index 0000000000000000000000000000000000000000..e6f6fcf6280e969b1761926112147d3146e27b59 --- /dev/null +++ b/build/torch29-cxx11-rocm64-x86_64-linux/_ops.py @@ -0,0 +1,9 @@ +import torch +from . import _optimizer_06a260a_dirty +ops = torch.ops._optimizer_06a260a_dirty + +def add_op_namespace_prefix(op_name: str): + """ + Prefix op by namespace. 
+ """ + return f"_optimizer_06a260a_dirty::{op_name}" \ No newline at end of file diff --git a/build/torch29-cxx11-rocm64-x86_64-linux/_optimizer_06a260a_dirty.abi3.so b/build/torch29-cxx11-rocm64-x86_64-linux/_optimizer_06a260a_dirty.abi3.so new file mode 100755 index 0000000000000000000000000000000000000000..95d54a0288c1e9cea520f5e3042a163cb9222346 --- /dev/null +++ b/build/torch29-cxx11-rocm64-x86_64-linux/_optimizer_06a260a_dirty.abi3.so @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:6ad69fa088ef05b1697f74d59c1a5a12f17dbf2a3cddb8c6b92ed7543b4cbdbc +size 1865232 diff --git a/build/torch29-cxx11-rocm64-x86_64-linux/distributed/utils.py b/build/torch29-cxx11-rocm64-x86_64-linux/distributed/utils.py new file mode 100644 index 0000000000000000000000000000000000000000..6d5843506c13d9d31603b2b4e30c1c91d0baab28 --- /dev/null +++ b/build/torch29-cxx11-rocm64-x86_64-linux/distributed/utils.py @@ -0,0 +1,175 @@ +import torch +import torch.distributed as dist +from torch.distributed import ProcessGroup +from torch.distributed.device_mesh import DeviceMesh +from torch.distributed.tensor import DTensor +from torch.distributed.tensor.placement_types import (Placement, Shard, + _StridedShard) + + +def get_slices_of_dtensor( + target: DTensor | torch.Tensor, + local_rank: int, + shard_mesh: DeviceMesh, + shard_placements: tuple[Placement], +) -> tuple[slice]: + """ + Get the slice of local tensor for a given rank from a tensor. + Args: + target (DTensor | torch.Tensor): The target tensor. + rank (int): The local rank of the shard group. + shard_mesh (DeviceMesh): The shard mesh. It consists of global ranks. + shard_placements (tuple[Placement]): The shard placements. 
+ """ + + slices: list[slice] = [slice(0, dim_size) for dim_size in target.size()] + + # find the global rank of the local rank in the shard mesh + rank = sorted(shard_mesh.mesh.flatten().tolist())[local_rank] + + rank_coords = (shard_mesh.mesh == rank).nonzero() + + assert len(rank_coords) == 1 + rank_coords = tuple(rank_coords[0].tolist()) + + assert len(rank_coords) == len(shard_placements) + + # Caution: Assuming replicate-to-shard of the shard mesh goes with + # left-to-right sharding. This is ensured by the sorting logic of + # construct_shard_mesh function. + for i, (rank_coord, + placement) in enumerate(zip(rank_coords, shard_placements)): + assert isinstance(placement, Shard) + + num_ranks = shard_mesh.mesh.shape[i] + + dim = placement.dim + dim_size = (slices[dim].stop - slices[dim].start) + + if dim_size % num_ranks != 0: + raise NotImplementedError( + f"Dimension size {dim_size} is not divisible " + f"by number of ranks {num_ranks} for shard " + f"placement on dim {dim}. (shape: {target.shape})") + + shard_size = dim_size // num_ranks + + start = slices[dim].start + rank_coord * shard_size + end = start + shard_size + + assert start < end <= slices[dim].stop + + slices[dim] = slice(start, end) + + return tuple(slices) + + +_ranks_to_dist_cache: dict[tuple[int, ...], tuple[DeviceMesh, + ProcessGroup]] = dict() + + +def construct_shard_mesh( + placements: tuple[Placement], + mesh: DeviceMesh, +) -> (DeviceMesh, ProcessGroup, tuple[Placement]): + """ + Construct Shard Mesh and Placements for unsharding. + It removes Replicate placements and constructs a new Mesh and ProcessGroup. + """ + my_rank = dist.get_rank() + + assert mesh.mesh.device.type == 'cpu' + + # Copy mesh to avoid modifying the original mesh + mesh = mesh.mesh.clone() + + # 1. Sort placements. Replicate first, then Shard by dim ascending. + + # For Shard, strided shard comes after regular shard on the same dim + # to preserve left-to-right order of replicate-to-shard. 
+ # This is because that strided shard is using stride to represent + # more fine-grained sharding on the same dim. + # Please check the URL below for _StridedShard. + # https://github.com/pytorch/pytorch/blob/v2.8.0/torch/distributed/tensor/placement_types.py#L366 + + def placement_sort_key( + placement_with_index: tuple[float, Placement] + ) -> tuple[int, float, int]: # (dim, split factor, original index) + index, placement = placement_with_index + is_replicate = placement.is_replicate() + is_shard = placement.is_shard() + is_partial = placement.is_partial() + + assert is_replicate or is_shard, f"Unsupported placement type: {type(placement)}" + assert not is_partial, "Partial placement is not supported." + + if is_replicate: + return (-1.0, 0, index) + elif is_shard: + if isinstance(placement, _StridedShard): + return (placement.dim, 1 / placement.split_factor, index) + return (placement.dim, 0, index) + else: + raise TypeError(f"Unknown placement type: {type(placement)}") + + placements_with_index: list[tuple[int, + Placement]] = list(enumerate(placements)) + placements_with_index = sorted(placements_with_index, + key=placement_sort_key) + + sorted_indices, sorted_placements = zip(*placements_with_index) + + # 2. Permute mesh according to sorted placements. + sorted_mesh = mesh.permute(sorted_indices) + + # 3. 
Collect list of shard meshes by removing replicate dims + # For example, (2, 3, 4, 4) with placements [R, R, S(0), S(1)] + # shard_meshes should be list with 2 * 3 = 6 shard meshes of shape (4, 4) + num_replicates = sum(1 for p in sorted_placements if p.is_replicate()) + + # merge replicate dims + # shard_meshes became a list of shard meshes with a length of replicate degree + if num_replicates > 0: + sorted_mesh = sorted_mesh.flatten( + 0, num_replicates - 1) if num_replicates > 1 else sorted_mesh + shard_meshes = list(torch.unbind(sorted_mesh, dim=0)) + else: + shard_meshes = [sorted_mesh] + shard_placements = sorted_placements[num_replicates:] + + # assume all shard placements are different + assert len(shard_placements) == len(set(shard_placements)) + + # 4. Construct ProcessGroups + # Caution: all groups should be created in the same order in all processes, + # even though each process only needs its own group. + + # To use tensor as dict key, convert it to tuple + def tensor_to_tuple(t): + if isinstance(t, torch.Tensor): + t = t.tolist() + if isinstance(t, list): + return tuple(tensor_to_tuple(x) for x in t) + return t + + my_shard_mesh_as_tuple = None + for shard_mesh in shard_meshes: + assert isinstance(shard_mesh, torch.Tensor) + shard_mesh_as_tuple = tensor_to_tuple(shard_mesh) + + if (my_rank == shard_mesh).any().item(): + assert my_shard_mesh_as_tuple is None + my_shard_mesh_as_tuple = shard_mesh_as_tuple + + # update global cache + if shard_mesh_as_tuple not in _ranks_to_dist_cache: + shard_process_group = dist.new_group(shard_mesh.flatten().tolist()) + _ranks_to_dist_cache[shard_mesh_as_tuple] = ( + DeviceMesh(device_type="cuda", mesh=shard_mesh), + shard_process_group, + ) + + my_shard_mesh, my_shard_process_group = _ranks_to_dist_cache[ + my_shard_mesh_as_tuple] + + return my_shard_mesh, my_shard_process_group, shard_placements diff --git a/build/torch29-cxx11-rocm64-x86_64-linux/matmul_transpose_triton.py 
# MIT License
#
# Copyright (c) 2025 Tianyang Lin
#
# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
# in the Software without restriction, including without limitation the rights
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
# copies of the Software, and to permit persons to whom the Software is
# furnished to do so, subject to the following conditions:
#
# The above copyright notice and this permission notice shall be included in all
# copies or substantial portions of the Software.
#
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
# SOFTWARE.

import torch
import triton
import triton.language as tl


def get_autotune_config():
    """Autotune search space (block sizes, pipeline stages, warps) for mmt_kernel."""
    return [
        triton.Config(
            {
                'BLOCK_SIZE_M': blk_m,
                'BLOCK_SIZE_K': blk_k,
                'GROUP_SIZE_M': grp_sz
            },
            num_stages=n_stages,
            num_warps=n_warps) for blk_m in [32, 64, 128]
        for blk_k in [32, 64] for grp_sz in [8] for n_stages in [3, 4, 5]
        for n_warps in [4, 8]
    ]


@triton.autotune(
    configs=get_autotune_config(),
    key=['M', 'K'],
)
@triton.jit
def mmt_kernel(x, y, M, K, stride_xm, stride_xk, stride_ym, stride_yn,
               BLOCK_SIZE_M: tl.constexpr, BLOCK_SIZE_K: tl.constexpr,
               GROUP_SIZE_M: tl.constexpr):
    """
    Core kernel jit function of matmul_transpose that computes y = x @ x.T
    The code is a simple adaptation from the triton `matmul` tutorial:
    https://triton-lang.org/main/getting-started/tutorials/03-matrix-multiplication.html
    """
    pid = tl.program_id(axis=0)
    num_pid_m = tl.cdiv(M, BLOCK_SIZE_M)
    num_pid_n = tl.cdiv(M, BLOCK_SIZE_M)
    num_pid_in_group = GROUP_SIZE_M * num_pid_n
    group_id = pid // num_pid_in_group
    first_pid_m = group_id * GROUP_SIZE_M
    group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M)
    pid_m = first_pid_m + ((pid % num_pid_in_group) % group_size_m)
    pid_n = (pid % num_pid_in_group) // group_size_m
    # The output is symmetric: only compute the upper triangle (pid_m <= pid_n)
    # and mirror the result below.
    if pid_m > pid_n:
        return

    offs_xm = (pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)) % M
    offs_xn = (pid_n * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)) % M
    offs_k = tl.arange(0, BLOCK_SIZE_K)
    # we use a & b ptrs to denote different rows of x.
    a_ptrs = x + (offs_xm[:, None] * stride_xm + offs_k[None, :] * stride_xk)
    b_ptrs = x + (offs_xn[:, None] * stride_xm + offs_k[None, :] * stride_xk)

    accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_M), dtype=tl.float32)

    for k in range(0, tl.cdiv(K, BLOCK_SIZE_K)):
        a = tl.load(a_ptrs,
                    mask=offs_k[None, :] < K - k * BLOCK_SIZE_K,
                    other=0.0)
        b = tl.load(b_ptrs,
                    mask=offs_k[None, :] < K - k * BLOCK_SIZE_K,
                    other=0.0)
        accumulator = tl.dot(a, tl.permute(b, (1, 0)), accumulator)
        a_ptrs += BLOCK_SIZE_K * stride_xk
        b_ptrs += BLOCK_SIZE_K * stride_xk
    # use dtype.element_ty to accommodate different input datatypes as in cpp templates
    # https://github.com/triton-lang/triton/issues/2252
    c = accumulator.to(x.dtype.element_ty)

    offs_cm = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
    offs_cn = pid_n * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
    c_ptrs = y + stride_ym * offs_cm[:, None] + stride_yn * offs_cn[None, :]
    c_mask = (offs_cm[:, None] < M) & (offs_cn[None, :] < M)
    tl.store(c_ptrs, c, mask=c_mask)

    # transpose and copy
    if pid_m < pid_n:
        ct_ptrs = y + stride_ym * offs_cn[:,
                                          None] + stride_yn * offs_cm[None, :]
        ct_mask = (offs_cn[:, None] < M) & (offs_cm[None, :] < M)
        tl.store(ct_ptrs, tl.permute(c, (1, 0)), mask=ct_mask)


def matmul_transpose_assign(d_in, d_out):
    """
    Compute d_out = d_in @ d_in.T in place.

    Args:
        d_in: 2D CUDA tensor of shape (M, K).
        d_out: pre-allocated square CUDA tensor of shape (M, M), same dtype
            and device as d_in.
    """
    assert d_in.is_cuda, "Input `d_in` must be a CUDA tensor"
    assert d_out.is_cuda, "Input `d_out` must be a CUDA tensor"
    assert d_in.device == d_out.device, "Inputs `d_in` and `d_out` must be on the same CUDA device"
    assert d_in.dtype == d_out.dtype, "Inputs must have the same data type"
    assert d_in.ndim == 2, "Input `d_in` must be a 2D tensor"
    assert d_out.ndim == 2, "Input `d_out` must be a 2D tensor"
    # BUG FIX: the original compared d_out.size(0) with itself, so a
    # non-square d_out slipped through; per the message below, d_out must be
    # (M, M) with M == d_in.size(0).
    assert d_in.size(0) == d_out.size(0) == d_out.size(1), \
        "First dimension of `d_in` must match first and second dimension of `d_out`"

    d_in = d_in.contiguous()
    M, K = d_in.shape
    grid = lambda META: (triton.cdiv(M, META['BLOCK_SIZE_M']) * triton.cdiv(
        M, META['BLOCK_SIZE_M']), )
    with torch.cuda.device(d_in.device.index):
        mmt_kernel[grid](d_in, d_out, M, K, d_in.stride(0), d_in.stride(1),
                         d_out.stride(0), d_out.stride(1))


def matmul_transpose(d_in):
    """Allocate the (M, M) output and return d_in @ d_in.T via the Triton kernel."""
    M, _ = d_in.shape
    d_out = torch.empty((M, M), device=d_in.device, dtype=d_in.dtype)
    matmul_transpose_assign(d_in, d_out)
    return d_out


import logging
import math
import types
from collections import defaultdict
from dataclasses import dataclass
from typing import Any, cast

import torch
import torch.distributed as dist
from torch.distributed import ProcessGroup
from torch.distributed.device_mesh import DeviceMesh
from torch.distributed.tensor import DTensor, Replicate
from torch.distributed.tensor.placement_types import Placement

from .distributed.utils import construct_shard_mesh, get_slices_of_dtensor
from .matmul_transpose_triton import matmul_transpose_assign

logger = logging.getLogger(__name__)

COMM_DTYPE = torch.bfloat16
DEFAULT_CHUNK_SIZE_RATIO = 4


# This code snippet is a modified version adapted from the following GitHub repositories:
# https://github.com/KellerJordan/Muon/blob/master/muon.py
# Muon's Newton–Schulz iteration causes high variance in singular values
# Idea: give each iteration its own 3
# coefficients and optimize them via gradient descent.
@torch.no_grad()
# matmul_transpose_assign from : https://github.com/nil0x9/flash-muon
def _zeropower_via_newtonschulz5(G, steps):
    """
    Newton-Schulz iteration to compute the zeroth power / orthogonalization of G. We opt to use a
    quintic iteration whose coefficients are selected to maximize the slope at zero. For the purpose
    of minimizing steps, it turns out to be empirically effective to keep increasing the slope at
    zero even beyond the point where the iteration no longer converges all the way to one everywhere
    on the interval. This iteration therefore does not produce UV^T but rather something like US'V^T
    where S' is diagonal with S_{ii}' ~ Uniform(0.5, 1.5), which turns out not to hurt model
    performance at all relative to UV^T, where USV^T = G is the SVD.
    """
    assert len(G.shape) == 2
    assert G.dtype == COMM_DTYPE
    X = G  # no manual typecast

    # Work on the wide orientation so the (M, M) buffers use the smaller dim.
    if G.size(0) > G.size(1):
        X = X.T
    # Ensure spectral norm is at most 1
    X = X / (X.norm() + 1e-7)
    # Reusable (M, M) scratch buffers for the X @ X.T products.
    buf1 = torch.empty(X.size(0), X.size(0), dtype=X.dtype, device=X.device)
    buf2 = torch.empty(X.size(0), X.size(0), dtype=X.dtype, device=X.device)
    # Perform the NS iterations
    for a, b, c in [
        (4.0848, -6.8946, 2.9270),
        (3.9505, -6.3029, 2.6377),
        (3.7418, -5.5913, 2.3037),
        (2.8769, -3.1427, 1.2046),
        (2.8366, -3.0525, 1.2012),
    ]:
        # buf1 = X X^T ; buf2 = (X X^T)^2 ; X <- a X + (b*buf1 + c*buf2) X
        matmul_transpose_assign(X, buf1)
        matmul_transpose_assign(buf1, buf2)
        buf1.mul_(b).add_(buf2, alpha=c)
        X = torch.addmm(X, buf1, X, alpha=1.0, beta=a)

    if G.size(0) > G.size(1):
        X = X.T
    return X


@dataclass
class _muon_state:
    # Per-parameter bookkeeping for the parallel Muon pipeline
    # (gather -> compute -> scatter -> update), including the CUDA events
    # used to order work across the comm and compute streams.
    # TODO: use Optional
    worker_rank: int
    process_group: ProcessGroup
    shard_mesh: DeviceMesh
    shard_placements: tuple[Placement, ...]
    name: str
    qk_clip_state: torch.Tensor | None = None
    gathered_grad: torch.Tensor | None = None
    scattered_u: DTensor | None = None
    computed_u: torch.Tensor | None = None
    gather_event: torch.cuda.Event | None = None
    compute_event: torch.cuda.Event | None = None
    scatter_event: torch.cuda.Event | None = None


def numel_for_rank(
    param: DTensor,
    local_rank: int,
    state: _muon_state,
) -> int:
    # Number of elements in `local_rank`'s shard of `param`, derived from the
    # per-dimension slices of the sharded tensor.
    slices = get_slices_of_dtensor(
        param,
        local_rank,
        state.shard_mesh,
        state.shard_placements,
    )

    numel = 1
    for s, dim in zip(slices, param.shape):
        start, stop, step = s.indices(dim)
        # ceil((stop - start) / step), clamped at zero for empty slices
        length = max(0, (stop - start + (step - 1)) // step)
        numel *= length

    return numel


@torch.no_grad()
def _alloc_gathered_grad(params, param_to_state, rank, compute_stream):
    """
    Pre-allocate gathered_grad buffer on compute_stream
    before launching all2all gather
    """
    with torch.cuda.stream(compute_stream):
        for p in params:
            state = param_to_state[id(p)]
            if rank == state.worker_rank:
                # Only the owner rank materializes the full gradient.
                state.gathered_grad = torch.empty(p.shape,
                                                  dtype=COMM_DTYPE,
                                                  device="cuda")
            else:
                state.gathered_grad = None

    alloc_event = torch.cuda.Event()
    alloc_event.record(compute_stream)
    return alloc_event


@torch.no_grad()
def _all2all_gather(params, param_to_state, rank, comm_stream, none_grad,
                    alloc_event):
    """
    All2all gathers shards so each owner rank reconstructs its full gradient
    """
    with torch.cuda.stream(comm_stream):
        process_group = param_to_state[id(params[0])].process_group
        num_ranks = dist.get_world_size(group=process_group)

        # Construct sending buffers
        per_dst = [[] for _ in range(num_ranks)]
        send_counts = [0] * num_ranks

        for p in params:
            state = param_to_state[id(p)]
            dst = state.worker_rank
            assert dst < num_ranks
            shard_elems = numel_for_rank(p, rank, state)
            g = p.grad
            g = g.to_local().to(COMM_DTYPE).contiguous()
            assert g.numel() == shard_elems
            per_dst[dst].append(g.view(-1))
            send_counts[dst] += shard_elems

        assert any(
            len(v) > 0 for v in per_dst
        ), "At least one destination rank must receive a sharded tensor"
        # list[list[Tensor]] -> list[Tensor]
        per_dst = [t for dst in per_dst for t in dst]

        send_buf = torch.cat(per_dst, dim=0)

        owned_params = [
            p for p in params if param_to_state[id(p)].worker_rank == rank
        ]

        # Compute receive sizes and allocate receiving buffers
        recv_counts = [0] * num_ranks

        for src in range(num_ranks):
            total = 0
            for p in owned_params:
                state = param_to_state[id(p)]
                assert state.worker_rank == rank
                total += numel_for_rank(p, src, state)
            recv_counts[src] = total

        recv_total = sum(recv_counts)
        recv_buf = torch.empty(recv_total, dtype=COMM_DTYPE, device="cuda")

        #All2All
        logger.debug(f"send_buf size: {send_buf.numel()}, "
                     f"recv_buf size: {recv_buf.numel()}, "
                     f"recv_counts: {recv_counts}, "
                     f"send_counts: {send_counts}, "
                     f"process_group: {str(process_group)}")
        dist.all_to_all_single(
            recv_buf,
            send_buf,
            output_split_sizes=recv_counts,
            input_split_sizes=send_counts,
            group=process_group,
        )

        # Reconstructs gathered grad from the received buffer
        #
        # recv_buf (num ranks = 3)
        #
        #   From rank 0        From rank 1        From rank 2
        # | p1_0, p2_0, p3_0 | p1_1, p2_1, p3_1 | p1_2, p2_2, p3_2 |
        #
        # Outer loop:
        #   rank 0 -> rank 1 -> rank2
        #
        # Inner loop:
        #   p1_n -> p2_n -> p3_n

        # Wait until the gathered_grad buffers allocated on compute_stream
        # are actually valid before writing into them.
        comm_stream.wait_event(alloc_event)

        off = 0
        for src in range(num_ranks):
            if recv_counts[src] == 0:
                continue

            block = recv_counts[src]
            inner_off = 0
            for p in owned_params:
                state = param_to_state[id(p)]
                assert state.worker_rank == rank

                # get the slice of the full dtensor corresponding to rank src.
                slices = get_slices_of_dtensor(state.gathered_grad, src,
                                               state.shard_mesh,
                                               state.shard_placements)

                dst = state.gathered_grad[slices]
                # `dst` must be a view into gathered_grad so copy_ fills it.
                assert dst._base is state.gathered_grad

                n = dst.numel()
                assert n > 0

                sg = recv_buf.narrow(0, off + inner_off, n)
                sg = sg.reshape_as(dst)
                dst.copy_(sg)

                inner_off += n
            off += block

        for p in params:
            state = param_to_state[id(p)]
            if state.worker_rank == rank:
                state.gather_event = torch.cuda.Event()
                state.gather_event.record(comm_stream)
            else:
                state.gathered_grad = None
                state.gather_event = None
            # NOTE(review): indentation of this block is ambiguous in the
            # mangled source; placed at loop level so every rank frees its
            # local grad after it has been sent — confirm against upstream.
            if none_grad:
                p.grad = None


@torch.no_grad()
def _compute_u(p, state, steps, rank, compute_stream):
    """
    On worker_rank, compute the orthogonalized update using Newton-Schulz iteration.
    """
    with torch.cuda.stream(compute_stream):
        if rank == state.worker_rank:
            if state.gather_event is None:
                raise RuntimeError("Gather event must be set before compute.")
            # Do not start NS iterations until the all2all gather finished.
            compute_stream.wait_event(state.gather_event)
            u = _zeropower_via_newtonschulz5(state.gathered_grad, steps)
            state.gathered_grad = None
            state.computed_u = u
            state.compute_event = torch.cuda.Event()
            state.compute_event.record()
        else:
            state.computed_u = None
            state.compute_event = None


@torch.no_grad()
def _alloc_scattered_u(params, param_to_state, rank, compute_stream):
    """
    Pre-allocate scattered_u buffer on compute_stream
    before launching all2all gather
    """
    with torch.cuda.stream(compute_stream):
        for p in params:
            state = param_to_state[id(p)]
            state.scattered_u = torch.empty_like(p.to_local(),
                                                 dtype=COMM_DTYPE)

    alloc_event = torch.cuda.Event()
    alloc_event.record(compute_stream)
    return alloc_event


def _all2all_scatter(params, param_to_state, rank, comm_stream, alloc_event):
    """
    All2all scatters full gradients to all ranks
    """
    with torch.cuda.stream(comm_stream):
        process_group = param_to_state[id(params[0])].process_group
        num_ranks = dist.get_world_size(group=process_group)
        owned_params = [
            p for p in params if param_to_state[id(p)].worker_rank == rank
        ]

        # Construct sending buffer
        per_dst = [[] for _ in range(num_ranks)]
        send_counts = [0] * num_ranks

        if owned_params:
            for p in owned_params:
                state = param_to_state[id(p)]
                if state.compute_event is None:
                    raise RuntimeError(
                        "Compute event must be set before scatter.")
                # Ensure the NS update for this param is done before slicing.
                comm_stream.wait_event(state.compute_event)
                state.gathered_grad = None

                assert state.computed_u is not None

                u_full = state.computed_u.to(COMM_DTYPE).contiguous()

                offset = 0
                for dst in range(num_ranks):
                    # get the slice of the full tensor corresponding to rank dst.
                    slices = get_slices_of_dtensor(u_full, dst,
                                                   state.shard_mesh,
                                                   state.shard_placements)
                    su = u_full[slices].flatten()

                    n = su.numel()
                    assert n > 0

                    per_dst[dst].append(su)
                    send_counts[dst] += n
                    offset += n

                assert offset == u_full.numel()

        lengths = [len(v) for v in per_dst]
        if all(l > 0 for l in lengths):
            assert all(
                l == lengths[0] for l in lengths
            ), "All destination ranks must have the same number of sharded tensor"
            # list[list[Tensor]] -> list[Tensor]
            per_dst = [t for dst in per_dst for t in dst]
            send_buf = torch.cat(per_dst, dim=0)
        else:
            # all_to_all requires participation from all ranks
            # Even non-owner ranks must join the collective call
            send_buf = torch.empty(0, dtype=COMM_DTYPE, device="cuda")

        # Compute receive sizes and allocate receiving buffers
        recv_counts = [0] * num_ranks

        for src in range(num_ranks):
            total = 0
            for p in params:
                state = param_to_state[id(p)]
                if state.worker_rank != src:
                    continue
                total += numel_for_rank(p, rank, state)
            recv_counts[src] = total

        recv_total = sum(recv_counts)
        assert recv_total > 0
        recv_buf = torch.empty(recv_total, dtype=COMM_DTYPE, device="cuda")

        #All2All
        dist.all_to_all_single(
            recv_buf,
            send_buf,
            output_split_sizes=recv_counts,
            input_split_sizes=send_counts,
            group=process_group,
        )

        # Copy to pre-allocated scattered_u buffer from the received buffer
        #
        # recv_buf (num ranks = 3, local_rank = 0)
        #
        #   From rank 0        From rank 1  From rank 2
        # | p1_0, p2_0, p3_0 | p4_0       | p5_0, p6_0 |
        #
        # Outer loop:
        #   rank 0 -> rank 1 -> rank2
        #
        # Inner loop:
        #   src(0) : p1_0 -> p2_0 -> p3_0
        #   src(1) : p4_0
        #   src(2) : p5_0 -> p6_0

        # Wait until scattered_u buffers allocated on compute_stream exist.
        comm_stream.wait_event(alloc_event)

        off = 0
        for src in range(num_ranks):
            block = recv_counts[src]
            if block == 0:
                continue

            inner_off = 0
            for p in params:
                state = param_to_state[id(p)]
                if state.worker_rank != src:
                    continue
                n = numel_for_rank(p, rank, state)
                assert n > 0

                flat_local = recv_buf.narrow(0, off + inner_off,
                                             n).view_as(p.to_local())
                state.scattered_u.copy_(flat_local)

                state.scatter_event = torch.cuda.Event()
                state.scatter_event.record(comm_stream)
                inner_off += n

            assert inner_off == block
            off += block


def _update_param(p, state, lr, adjusted_lr, weight_decay, rank,
                  compute_stream):
    """
    Update sharded parameter p with the scattered_u.
    Only worker_rank frees computed_u.
    """
    with torch.cuda.stream(compute_stream):
        if state.scatter_event is None:
            raise RuntimeError("Scatter event must be set before update")
        # Order the weight update after the all2all scatter on comm_stream.
        compute_stream.wait_event(state.scatter_event)
        u_dtensor = DTensor.from_local(
            state.scattered_u,
            placements=p.placements,
            device_mesh=p.device_mesh,
        )

        state.scattered_u = u_dtensor

        if rank == state.worker_rank:
            # Free computed_u
            state.computed_u = None

        Muon._update_p(p, state.scattered_u, lr, adjusted_lr, weight_decay)
        state.scattered_u = None
        u_dtensor = None

        scales_full = Muon._compute_scales(
            p,
            state.qk_clip_state) if state.qk_clip_state is not None else None
        if scales_full is not None:
            # Have to slice scales_full among dim 0
            weight_slices = get_slices_of_dtensor(p, rank, state.shard_mesh,
                                                  state.shard_placements)
            # ratio maps row indices of the weight to head indices in scales.
            ratio = p.shape[0] // scales_full.shape[0]
            scales_slice = slice(
                None if weight_slices[0].start is None else
                weight_slices[0].start // ratio,
                None if weight_slices[0].stop is None else
                weight_slices[0].stop // ratio,
                None,
            )

            scales_local = scales_full[scales_slice]
            scales_local = DTensor.from_local(
                scales_local,
                placements=p.placements,
                device_mesh=p.device_mesh,
            )
            Muon._qk_clip(p, scales_local, state.qk_clip_state.head_dim)


def default_is_muon(name, x):
    # Muon handles >=2D weights; embeddings and output heads fall back to AdamW.
    skip_keys = ["embed_tokens", "lm_head", "tok_embeddings", "output"]
    return x.ndim >= 2 and not any(key in name for key in skip_keys)


def get_default_muon_param_groups(model, is_muon_func=default_is_muon):
    # Split a model's trainable parameters into a Muon group and a non-Muon
    # (AdamW) group, preserving parameter names for the Muon group.
    muon_params, muon_names = [], []
    non_muon_params = []

    for n, p in model.named_parameters():
        if not p.requires_grad:
            continue
        if is_muon_func(n, p):
            muon_params.append(p)
            muon_names.append(n)
        else:
            non_muon_params.append(p)

    return [
        {
            "params": muon_params,
            "names": muon_names,
            "use_muon": True,
        },
        {
            "params": non_muon_params,
            "use_muon": False,
        },
    ]


def parse_qk_layer(name: str) -> tuple[str | None, int]:
    """
Parse a parameter name to check if it is a query/key projection layer + ('wq', 'wk', 'q_proj', 'k_proj') and return (kind, layer_index). + + Returns: + (kind, layer_idx) or (None, -1) if not matched. + + Example: + 'model.3.attn.wq.weight' -> ('wq', 3) + 'model.5.attn.wk.weight' -> ('wk', 5) + 'model.2.attn.q_proj.weight' -> ('q_proj', 2) + 'model.7.attn.k_proj.weight' -> ('k_proj', 7) + 'model.4.attn.v_proj.weight' -> (None, -1) + """ + parts = name.split('.') + if len(parts) < 3: + return None, -1 + + kind = parts[-2] + + layer_idx = -1 + for part in reversed(parts): + if part.isdigit(): + layer_idx = int(part) + break + + if kind in ('wq', 'wk', 'q_proj', 'k_proj'): + return kind, layer_idx + + return None, -1 + + +@dataclass +class QKClipInfo: + """Per-parameter dynamic info computed from config + runtime logits.""" + kind: str | None # 'wq'/'q_proj' or 'wk'/'k_proj' or None + indices: list[int] # which heads to consider for clipping + head_dim: int # from config + threshold: float # from config + logit: torch.Tensor | None + + +class Muon(torch.optim.Optimizer): + """ + Muon - MomentUm Orthogonalized by Newton-schulz + + Muon internally runs standard SGD-momentum, and then performs an orthogonalization post- + processing step, in which each 2D parameter's update is replaced with the nearest orthogonal + matrix. To efficiently orthogonalize each update, we use a Newton-Schulz iteration, which has + the advantage that it can be stably run in bfloat16 on the GPU. + + Some warnings: + - We believe this optimizer is unlikely to work well for training with small batch size. + - We believe it may not work well for finetuning pretrained models, but we haven't tested this. + + Arguments: + model: The model to be optimized by Muon. + is_muon_func: A function that takes a parameter and its name, and returns whether the parameter should be optimized by Muon. + lr: The learning rate. The updates will have spectral norm of `lr`. 
(0.02 is a good default) + momentum: The momentum used by the internal SGD. (0.95 is a good default) + nesterov: Whether to use Nesterov-style momentum in the internal SGD. (recommended) + ns_steps: The number of Newton-Schulz iterations to run. (6 is probably always enough) + weight_decay: The weight decay for Muon and AdamW. + {0, 1}-D or are detected as being the embed or lm_head will be optimized by AdamW as well. + adamw_lr: The learning rate for the internal AdamW. + adamw_betas: The betas for the internal AdamW. + adamw_eps: The epsilon for the internal AdamW. + none_grad: Whether to set p.grad to None after gathering the gradients. This can save memory. + debug: Whether to print debug information. + clip_info : Configuration for QK clipping. Expected keys: + - "q_indices" (list[int]): Indices of query heads to consider. + - "k_indices" (list[int]): Indices of key heads to consider. + - "head_dim" (int): Dimensionality of each attention head. + - "threshold" (float): Threshold value; heads whose QK logits exceed + this value will be scaled down. + Default is: + { + "q_indices": [], + "k_indices": [], + "head_dim": 128, + "threshold": 100 + } + warmup_step : How many all2all gather, compute operations are launched in advance + before the corresponding all2all scatter steps begin. + A higher warmup_step increases memory usage but can improve + performance by overlapping communication. + Parallel muon only. + chunk_size : Batch size of parameters to process in each + all2all gather/compute/scatter step. + Use shard ranks * DEFAULT_CHUNK_SIZE_RATIO when -1 is specified. + use_distributed_muon: Use distributed muon by Liu et al. (2024). + For testing purpose only. 
+ small_param_numel_threshold: Threshold for classifying parameters as small and falling back to distributed Muon + """ + + def __init__(self, + params, + lr=1e-3, + momentum=0.95, + nesterov=True, + ns_steps=5, + weight_decay=0.1, + adamw_betas=(0.9, 0.95), + adamw_eps=1e-8, + none_grad=True, + debug=False, + clip_config={ + "q_indices": [], + "k_indices": [], + "head_dim": 128, + "threshold": 100 + }, + warmup_step=5, + chunk_size=-1, + use_distributed_muon=False, + small_param_numel_threshold=65536): + defaults = dict( + lr=lr, + weight_decay=weight_decay, + momentum=momentum, + nesterov=nesterov, + ns_steps=ns_steps, + adamw_betas=adamw_betas, + adamw_eps=adamw_eps, + none_grad=none_grad, + use_muon=True, + ) + error_message = "The key 'use_muon' is not set in parameter group {idx}. Assuming all parameters in the group will use muon optimization, which may lead to unexpected behavior." + instruction_code = "\n\n please follow this code snippet \n```optimizer = get_kernel('motif-technologies/optimizer')\n\n\nparams = optimizer.muon.get_default_muon_param_groups(model)\n\noptim = optimizer.Muon(params, ...)```" + + if isinstance(params, types.GeneratorType): + raise ValueError(error_message.format(idx=0) + instruction_code) + for _idx, param_group in enumerate(params): + if param_group.get("use_muon", None) is None: + raise ValueError( + error_message.format(idx=_idx) + instruction_code) + + super().__init__(params, defaults) + + self.rank = None + + self.comm_stream = torch.cuda.Stream() + self.compute_stream = torch.cuda.Stream() + self.debug = debug + self.clip_config = clip_config + self.warmup_step = warmup_step + self.chunk_size = chunk_size + self.use_distributed_muon = use_distributed_muon + self.small_param_numel_threshold = small_param_numel_threshold + + def _calc_flops(self, G, steps): + assert len(G.shape) == 2 + M, N = G.shape + if M > N: + M, N = N, M + + return steps * ((M**3) * 2 + (M**2 * N) * 4 + M * N * 2 + M**2 * 3) + + def 
adjust_lr_for_muon(self, lr, param_shape): + A, B = param_shape[:2] + # We adjust the learning rate and weight decay based on the size of the parameter matrix + # as describted in the paper + adjusted_ratio = 0.2 * math.sqrt(max(A, B)) + adjusted_lr = lr * adjusted_ratio + return adjusted_lr + + def set_rank_once(self, rank): + if self.rank is None: + self.rank = rank + else: + assert self.rank == rank + + def get_shard_mesh(self, p): + """ + Get the shard mesh for a parameter p on the given rank. + """ + assert isinstance( + p, DTensor), "Parallel Muon only supports DTensor parameters." + + shard_mesh, shard_pg, shard_placements = construct_shard_mesh( + p.placements, p.device_mesh) + + # set rank with the local rank in the shard process group + self.set_rank_once(dist.get_rank(group=shard_pg)) + + return shard_mesh, shard_pg, shard_placements + + def init_state_and_assign_params(self, names, params, group, qk_logits): + param_to_state = {} + param_to_flops = {} + + total_flops = 0 + for p in params: + g = p.grad + if g is None: + continue + assert g.ndim == 2, "Muon only supports 2D parameters." 
+ + flops = self._calc_flops(g, group["ns_steps"]) + param_to_flops[id(p)] = flops + total_flops += flops + + if self.debug: + print(f"Total TFLOPs for Muon: {total_flops / 1e12:.2f} TFLOPs", + flush=True) + + paired = list(zip(names, params)) + + paired_sorted = sorted(paired, + key=lambda x: param_to_flops[id(x[1])], + reverse=True) + + names_sorted, params_sorted = zip(*paired_sorted) + ordered_names = list(names_sorted) + ordered_params = list(params_sorted) + + round_robin = 0 + mesh = ordered_params[0].device_mesh + placements = ordered_params[0].placements + + shard_mesh, shard_pg, shard_placements = self.get_shard_mesh( + ordered_params[0]) + shard_mesh_flattened = shard_mesh.mesh.flatten() + num_ranks = dist.get_world_size(group=shard_pg) + + for n, p in zip(ordered_names, ordered_params): + if mesh != p.device_mesh: + raise ValueError("All parameters must be on the same mesh.") + if placements != p.placements: + raise ValueError("All parameters must have same placements.") + + worker_rank = shard_mesh_flattened[round_robin].item() % num_ranks + round_robin = (round_robin + 1) % len(shard_mesh_flattened) + qk_clip_state = self.get_qk_clip_info(n, qk_logits) + + param_to_state[id(p)] = _muon_state( + worker_rank=worker_rank, + process_group=shard_pg, + shard_mesh=shard_mesh, + shard_placements=shard_placements, + name=n, + qk_clip_state=qk_clip_state, + ) + + return param_to_state, ordered_params + + def base(self, names, params, group, lr, weight_decay, momentum, + qk_logits): + # generate weight updates in distributed fashion + for n, p in zip(names, params): + g = p.grad + if g is None: + continue + if g.ndim > 2: + g = g.view(g.size(0), -1) + assert g is not None + + g = self._update_g(p, g, group, momentum) + + u = _zeropower_via_newtonschulz5(g.to(COMM_DTYPE), + steps=group["ns_steps"]) + + adjusted_lr = self.adjust_lr_for_muon(lr, p.shape) + Muon._update_p(p, u, lr, adjusted_lr, weight_decay) + + qk_clip_state = self.get_qk_clip_info(n, qk_logits) + 
+ scales_full = self._compute_scales( + p, qk_clip_state) if qk_clip_state is not None else None + if scales_full is not None: + Muon._qk_clip(p, scales_full, qk_clip_state.head_dim) + + def distributed_muon( + self, + names: list[str], + params: list[torch.nn.Parameter], + group: dict[str, Any], + lr: float, + weight_decay: float, + momentum: float, + qk_logits: list[torch.Tensor | DTensor] | None, + ): + """ Implementation of Distributed Muon by Liu et al. """ + + for n, p in zip(names, params): + g = p.grad + if g is None: + continue + if g.ndim > 2: + g = g.view(g.size(0), -1) + assert g is not None + + g = self._update_g(p, g, group, momentum) + + # Gather G + if isinstance(p.data, DTensor): + g_full = g.full_tensor() + p_full = p.data.full_tensor() + else: + g_full = g + p_full = p + + u_full = _zeropower_via_newtonschulz5(g_full.to(COMM_DTYPE), + steps=group["ns_steps"]) + + adjusted_lr = self.adjust_lr_for_muon(lr, p_full.shape) + Muon._update_p(p_full, u_full, lr, adjusted_lr, weight_decay) + + qk_clip_state = self.get_qk_clip_info(n, qk_logits) + + scales_full = self._compute_scales( + p_full, qk_clip_state) if qk_clip_state is not None else None + + if scales_full is not None: + Muon._qk_clip(p_full, scales_full, qk_clip_state.head_dim) + + if isinstance(p.data, DTensor): + ndims = len(p.device_mesh.mesh.shape) + p_replicate = DTensor.from_local( + p_full, + device_mesh=p.device_mesh, + placements=[Replicate() for _ in range(ndims)], + ) + + p_sharded = p_replicate.redistribute( + device_mesh=p.device_mesh, + placements=p.placements, + ) + + p.copy_(p_sharded) + + def _update_g(self, p, g, group, momentum): + # calc update + state = self.state[p] + buf = state.setdefault("momentum_buffer", torch.zeros_like(g)) + torch.add(g, buf, alpha=momentum, out=buf) + if group["nesterov"]: + g.add_(buf, alpha=momentum) + return g + return buf + + @staticmethod + def _update_p(p, u, lr, adjusted_lr, weight_decay): + if isinstance(p, torch.nn.Parameter): + # apply 
weight decay + p.data.mul_(1 - lr * weight_decay) + # apply update + p.data.add_(u, alpha=-adjusted_lr) + else: + p.mul_(1 - lr * weight_decay) + p.add_(u, alpha=-adjusted_lr) + + def get_qk_clip_info(self, n, qk_logits): + if self.clip_config is None: + return None + + head_dim = self.clip_config.get('head_dim') + threshold = self.clip_config.get('threshold') + kind, layer_idx = parse_qk_layer(n) + + logit, indices = None, [] + if qk_logits is not None and kind is not None: + logit = qk_logits[layer_idx] + indices_key = 'q_indices' if 'q' in kind else 'k_indices' + indices = self.clip_config.get(indices_key, []) or [] + + if isinstance(logit, DTensor): + # In TP settings, qk_logits may be DTensor + # We convert it to full tensor here for simplicity + logit = logit.full_tensor() + + return QKClipInfo( + kind=kind, + indices=indices, + head_dim=head_dim, + threshold=threshold, + logit=logit, + ) + + @staticmethod + def _compute_scales(p, qk_clip_state): + kind = qk_clip_state.kind + indices = qk_clip_state.indices + head_dim = qk_clip_state.head_dim + threshold = qk_clip_state.threshold + logit = qk_clip_state.logit + + H_global = p.shape[0] // head_dim + scales_full = torch.ones(H_global, device=p.data.device) + scaling = 0 + + for logit_idx, head_idx in enumerate(indices): + v_ele = float(logit[logit_idx]) + if v_ele > threshold: + new_scale = math.sqrt(threshold / v_ele) + if new_scale < scales_full[head_idx]: + scales_full[head_idx] = new_scale + logger.info( + f"[{kind}] Head {head_idx} exceeded threshold " + f"(value={v_ele:.4f}, threshold={threshold:.4f}) -> applying scale={new_scale:.4f}" + ) + scaling += 1 + + return scales_full if scaling > 0 else None + + @staticmethod + def _qk_clip(p, scales, head_dim): + if isinstance(p, torch.nn.Parameter): + W = p.data.view(-1, head_dim, p.data.shape[1]) + W.mul_(scales.view(-1, 1, 1)) + else: + W = p.view(-1, head_dim, p.shape[1]) + W.mul_(scales.view(-1, 1, 1)) + + def parallel(self, names, params, group, lr, 
weight_decay, momentum, + qk_logits): + """ + Perform a parallel optimization step using Muon. + """ + + for p in params: + g = p.grad + if g is None: + continue + if g.ndim > 2: + g = g.view(g.size(0), -1) + + # Update g in the local rank + g = self._update_g( + p, + g, + group, + momentum=momentum, + ) + p.grad = g + + param_to_state, ordered_params = self.init_state_and_assign_params( + names, params, group, qk_logits) + + assert self.rank is not None + + def enqueue_all2all_gather(start_idx, chunk_size): + target_params = ordered_params[start_idx:start_idx + chunk_size] + if target_params: + alloc_event = _alloc_gathered_grad(target_params, + param_to_state, self.rank, + self.compute_stream) + _all2all_gather(target_params, param_to_state, self.rank, + self.comm_stream, group["none_grad"], + alloc_event) + + def enqueue_computes(start_idx, chunk_size): + for p in ordered_params[start_idx:start_idx + chunk_size]: + state = param_to_state[id(p)] + _compute_u(p, state, group["ns_steps"], self.rank, + self.compute_stream) + + def enqueue_all2all_scatter(start_idx, chunk_size): + target_params = ordered_params[start_idx:start_idx + chunk_size] + if target_params: + alloc_event = _alloc_scattered_u(target_params, param_to_state, + self.rank, + self.compute_stream) + _all2all_scatter(target_params, param_to_state, self.rank, + self.comm_stream, alloc_event) + + def enqueue_update_param(start_idx, chunk_size): + for p in ordered_params[start_idx:start_idx + chunk_size]: + state = param_to_state[id(p)] + adjusted_lr = self.adjust_lr_for_muon(lr, p.shape) + _update_param(p, state, lr, adjusted_lr, weight_decay, + self.rank, self.compute_stream) + + if self.chunk_size == -1: + shard_ranks = dist.get_world_size(param_to_state[id( + params[0])].process_group) + chunk_size = shard_ranks * DEFAULT_CHUNK_SIZE_RATIO + elif self.chunk_size > 0: + chunk_size = self.chunk_size + else: + raise ValueError("chunk_size must be -1 or a positive integer.") + + # Wait grad update + 
self.comm_stream.wait_stream(torch.cuda.current_stream()) + + warmup_step = self.warmup_step + for i in range(0, warmup_step): + enqueue_all2all_gather(i * chunk_size, chunk_size) + enqueue_computes(i * chunk_size, chunk_size) + + for i in range(0, len(params) + chunk_size - 1, chunk_size): + enqueue_all2all_scatter(i, chunk_size) + enqueue_all2all_gather(i + warmup_step * chunk_size, chunk_size) + enqueue_update_param(i, chunk_size) + enqueue_computes(i + warmup_step * chunk_size, chunk_size) + + # Wait the last update_param to finish + torch.cuda.current_stream().wait_stream(self.compute_stream) + + @staticmethod + def _fused_adamw( + params: list[torch.Tensor], + grads: list[torch.Tensor], + exp_avgs: list[torch.Tensor], + exp_avg_sqs: list[torch.Tensor], + max_exp_avg_sqs: list[torch.Tensor], + state_steps: list[torch.Tensor], + amsgrad: bool, + beta1: float, + beta2: float, + lr: float | torch.Tensor, + weight_decay: float, + eps: float, + maximize: bool, + ) -> None: + if not params: + return + + # We only shuffle around the lr when it is a Tensor and on CUDA, otherwise, we prefer + # treating it as a scalar. 
+ lr_dict: DeviceDict | None = ({ + lr.device: lr + } if isinstance(lr, torch.Tensor) and str(lr.device) != "cpu" else + None) + grouped_tensors = torch.optim.Optimizer._group_tensors_by_device_and_dtype( + [ + params, grads, exp_avgs, exp_avg_sqs, max_exp_avg_sqs, + state_steps + ] # type: ignore[list-item] + ) + for (device, _), ( + ( + device_params_, + device_grads_, + device_exp_avgs_, + device_exp_avg_sqs_, + device_max_exp_avg_sqs, + device_state_steps_, + ), + _, + ) in grouped_tensors.items(): + device_params = cast(list[torch.Tensor], device_params_) + device_grads = cast(list[torch.Tensor], device_grads_) + device_exp_avgs = cast(list[torch.Tensor], device_exp_avgs_) + device_exp_avg_sqs = cast(list[torch.Tensor], device_exp_avg_sqs_) + device_state_steps = cast(list[torch.Tensor], device_state_steps_) + + if lr_dict is not None and device not in lr_dict: + lr_dict[device] = lr.to( + device=device, + non_blocking=True) # type: ignore[union-attr] + lr = lr_dict[device] + torch._foreach_add_(device_state_steps, 1) + func = torch._fused_adamw_ + func( + device_params, + device_grads, + device_exp_avgs, + device_exp_avg_sqs, + device_max_exp_avg_sqs, # type: ignore[arg-type] + device_state_steps, + amsgrad=amsgrad, + lr=lr, # type: ignore[arg-type] + beta1=beta1, + beta2=beta2, + weight_decay=weight_decay, + eps=eps, + maximize=maximize, + ) + + def _step_muon(self, group, qk_logits=None): + params = group["params"] + lr = group["lr"] + weight_decay = group["weight_decay"] + momentum = group["momentum"] + names = group["names"] + + param_dtensors = [] + name_dtensors = [] + + param_tensors = [] + name_tensors = [] + + param_dtensors_small = [] + name_dtensors_small = [] + + if self.use_distributed_muon: + self.distributed_muon(names=names, + params=params, + group=group, + lr=lr, + weight_decay=weight_decay, + momentum=momentum, + qk_logits=qk_logits) + return + + # For simplicity, we use distributed Muon for small parameters + # whose number of elements is 
below a threshold. + for n, p in zip(names, params): + if p is None or p.grad is None: + continue + if isinstance(p.data, DTensor): + if all( + isinstance(placement, Replicate) + for placement in p.placements): + param_tensors.append(p) + name_tensors.append(n) + elif p.data.numel() <= self.small_param_numel_threshold: + param_dtensors_small.append(p) + name_dtensors_small.append(n) + else: + param_dtensors.append(p) + name_dtensors.append(n) + elif isinstance(p.data, torch.Tensor): + param_tensors.append(p) + name_tensors.append(n) + else: + raise TypeError(f"Unsupported parameter type: {type(p.data)}") + + logger.debug( + f"[Muon] {len(param_dtensors)} DTensors, {len(param_tensors)} Tensors, " + f"{len(param_dtensors_small)} Small DTensors") + + def group_dtensors(dtensors, names): + # To support different placements, we group parameters by placements + # and run parallel Muon on each group. + + placement_to_params = defaultdict(lambda: ([], [])) + # type: dict[tuple[Placement, DeviceMesh], tuple[list[str], list[DTensor]]] + + assert len(dtensors) == len(names) + for p, n in zip(dtensors, names): + placement_to_params[tuple([p.placements, + p.device_mesh])][0].append(n) + placement_to_params[tuple([p.placements, + p.device_mesh])][1].append(p) + return placement_to_params + + if len(param_dtensors_small) > 0: + if not dist.is_initialized(): + raise RuntimeError( + "Parallel Muon requires torch.distributed to be initialized." + ) + + self.distributed_muon( + params=param_dtensors_small, + names=name_dtensors_small, + group=group, + lr=lr, + weight_decay=weight_decay, + momentum=momentum, + qk_logits=qk_logits, + ) + + if len(param_dtensors) > 0: + if not dist.is_initialized(): + raise RuntimeError( + "Parallel Muon requires torch.distributed to be initialized." 
+ ) + + dtensor_group = group_dtensors(param_dtensors, name_dtensors) + for _, (names, params) in dtensor_group.items(): + self.parallel( + names, + params, + group, + lr=lr, + weight_decay=weight_decay, + momentum=momentum, + qk_logits=qk_logits, + ) + + if len(param_tensors) > 0: + self.base( + name_tensors, + param_tensors, + group, + lr=lr, + weight_decay=weight_decay, + momentum=momentum, + qk_logits=qk_logits, + ) + + def _step_adamw_params(self, params, group): + params_with_grads = [] + grads = [] + moment1 = [] + moment2 = [] + max_exp_avg_sqs = [] + state_steps = [] + lr = group["lr"] + beta1, beta2 = group["adamw_betas"] + eps = group["adamw_eps"] + weight_decay = group["weight_decay"] + + for p in params: + g = p.grad + if g is None: + continue + state = self.state[p] + params_with_grads.append(p) + grads.append(g) + if "step" not in state: + state["step"] = (torch.zeros((), + dtype=torch.float32, + device=p.device)) + state["moment1"] = torch.zeros_like(g) + state["moment2"] = torch.zeros_like(g) + moment1.append(state["moment1"]) + moment2.append(state["moment2"]) + if not isinstance(state["step"], torch.Tensor): + step_tensor = torch.tensor(state["step"], + dtype=torch.float32, + device=p.device) + else: + step_tensor = state["step"] + state_steps.append(step_tensor) + + self._fused_adamw( + params_with_grads, + grads, + moment1, + moment2, + max_exp_avg_sqs, + state_steps, + amsgrad=False, + beta1=beta1, + beta2=beta2, + lr=lr, + weight_decay=weight_decay, + eps=eps, + maximize=False, + ) + + def _step_adamw(self, group): + params = group["params"] + + # group params with it's type and placement + placement_to_params: dict[tuple[Placement | type, + DeviceMesh | None]] = defaultdict(list) + for p in params: + match p: + case DTensor(): + placement_to_params[tuple([p.placements, + p.device_mesh])].append(p) + case torch.Tensor(): + placement_to_params[tuple([torch.Tensor, None])].append(p) + + for params in placement_to_params.values(): + 
self._step_adamw_params(params, group) + + @torch.no_grad + def step(self, closure=None, qk_logits=None): + """Perform a single optimization step. + + Args: + closure (Callable, optional): A closure that reevaluates the model + and returns the loss. + qk_logits (dict[int, Tensor], optional): A dictionary mapping layer indices + to 1D tensors of shape (num_heads,), representing the maximum + QK logits across all tokens, computed as + (1 / sqrt(head_dim)) * (Q @ K^T). + """ + loss = None + if closure is not None: + with torch.enable_grad(): + loss = closure() + + for group in self.param_groups: + if group["use_muon"]: + self._step_muon(group, qk_logits=qk_logits) + else: + self._step_adamw(group) + + return loss diff --git a/build/torch29-cxx11-rocm64-x86_64-linux/optimizer/__init__.py b/build/torch29-cxx11-rocm64-x86_64-linux/optimizer/__init__.py index 239c7a65f8293e7d0df28f05fce645af56d628c0..03dbc1afe1cf156661a2b1b22003cd5f599a0309 100644 --- a/build/torch29-cxx11-rocm64-x86_64-linux/optimizer/__init__.py +++ b/build/torch29-cxx11-rocm64-x86_64-linux/optimizer/__init__.py @@ -1,5 +1,26 @@ -from .muon import Muon +import ctypes +import sys -__all__ = [ - "Muon", -] +import importlib +from pathlib import Path +from types import ModuleType + +def _import_from_path(file_path: Path) -> ModuleType: + # We cannot use the module name as-is, after adding it to `sys.modules`, + # it would also be used for other imports. So, we make a module name that + # depends on the path for it to be unique using the hex-encoded hash of + # the path. 
+ path_hash = "{:x}".format(ctypes.c_size_t(hash(file_path.absolute())).value) + module_name = path_hash + spec = importlib.util.spec_from_file_location(module_name, file_path) + if spec is None: + raise ImportError(f"Cannot load spec for {module_name} from {file_path}") + module = importlib.util.module_from_spec(spec) + if module is None: + raise ImportError(f"Cannot load module {module_name} from spec") + sys.modules[module_name] = module + spec.loader.exec_module(module) # type: ignore + return module + + +globals().update(vars(_import_from_path(Path(__file__).parent.parent / "__init__.py"))) diff --git a/build/torch29-cxx11-rocm64-x86_64-linux/optimizer/_ops.py b/build/torch29-cxx11-rocm64-x86_64-linux/optimizer/_ops.py deleted file mode 100644 index 7d598206add1bca142661a3df6c510e3d9575d54..0000000000000000000000000000000000000000 --- a/build/torch29-cxx11-rocm64-x86_64-linux/optimizer/_ops.py +++ /dev/null @@ -1,9 +0,0 @@ -import torch -from . import _optimizer_23d68bb_dirty -ops = torch.ops._optimizer_23d68bb_dirty - -def add_op_namespace_prefix(op_name: str): - """ - Prefix op by namespace. 
- """ - return f"_optimizer_23d68bb_dirty::{op_name}" \ No newline at end of file diff --git a/build/torch29-cxx11-rocm64-x86_64-linux/optimizer/_optimizer_23d68bb_dirty.abi3.so b/build/torch29-cxx11-rocm64-x86_64-linux/optimizer/_optimizer_23d68bb_dirty.abi3.so deleted file mode 100755 index eaf3eac7689223f26618fe6b233e8a98058cb637..0000000000000000000000000000000000000000 --- a/build/torch29-cxx11-rocm64-x86_64-linux/optimizer/_optimizer_23d68bb_dirty.abi3.so +++ /dev/null @@ -1,3 +0,0 @@ -version https://git-lfs.github.com/spec/v1 -oid sha256:8de22742ad0d387021a7b812ee3b7d0c8c54191914c8c0469886f6d2c082e9e3 -size 1852272 diff --git a/build/torch29-cxx11-rocm64-x86_64-linux/optimizer/distributed/utils.py b/build/torch29-cxx11-rocm64-x86_64-linux/optimizer/distributed/utils.py deleted file mode 100644 index 0b4b58bfb329b1c015129e4c4fc99f7bfa2ab30a..0000000000000000000000000000000000000000 --- a/build/torch29-cxx11-rocm64-x86_64-linux/optimizer/distributed/utils.py +++ /dev/null @@ -1,174 +0,0 @@ -import torch -import torch.distributed as dist -from torch.distributed import ProcessGroup -from torch.distributed.device_mesh import DeviceMesh -from torch.distributed.tensor import DTensor -from torch.distributed.tensor.placement_types import (Placement, Shard, - _StridedShard) - - -def get_slices_of_dtensor( - target: DTensor | torch.Tensor, - local_rank: int, - shard_mesh: DeviceMesh, - shard_placements: tuple[Placement], -) -> tuple[slice]: - """ - Get the slice of local tensor for a given rank from a tensor. - Args: - target (DTensor | torch.Tensor): The target tensor. - rank (int): The local rank of the shard group. - shard_mesh (DeviceMesh): The shard mesh. It consists of global ranks. - shard_placements (tuple[Placement]): The shard placements. 
- """ - - slices: list[slice] = [slice(0, dim_size) for dim_size in target.size()] - - # find the global rank of the local rank in the shard mesh - rank = sorted(shard_mesh.mesh.flatten().tolist())[local_rank] - - rank_coords = (shard_mesh.mesh == rank).nonzero() - - assert len(rank_coords) == 1 - rank_coords = tuple(rank_coords[0].tolist()) - - assert len(rank_coords) == len(shard_placements) - - # Caution: Assuming replicate-to-shard of the shard mesh goes with - # left-to-right sharding. This is ensured by the sorting logic of - # construct_shard_mesh function. - for i, (rank_coord, - placement) in enumerate(zip(rank_coords, shard_placements)): - assert isinstance(placement, Shard) - - num_ranks = shard_mesh.mesh.shape[i] - - dim = placement.dim - dim_size = (slices[dim].stop - slices[dim].start) - - if dim_size % num_ranks != 0: - raise NotImplementedError( - f"Dimension size {dim_size} is not divisible " - f"by number of ranks {num_ranks} for shard " - f"placement on dim {dim}.") - - shard_size = dim_size // num_ranks - - start = slices[dim].start + rank_coord * shard_size - end = start + shard_size - - assert start < end <= slices[dim].stop - - slices[dim] = slice(start, end) - - return tuple(slices) - - -_ranks_to_dist_cache: dict[tuple[int, ...], tuple[DeviceMesh, ProcessGroup]] = dict() - - -def construct_shard_mesh( - placements: tuple[Placement], - mesh: DeviceMesh, -) -> (DeviceMesh, ProcessGroup, tuple[Placement]): - """ - Construct Shard Mesh and Placements for unsharding. - It removes Replicate placements and constructs a new Mesh and ProcessGroup. - """ - my_rank = dist.get_rank() - - assert mesh.mesh.device.type == 'cpu' - - # Copy mesh to avoid modifying the original mesh - mesh = mesh.mesh.clone() - - # 1. Sort placements. Replicate first, then Shard by dim ascending. - - # For Shard, strided shard comes after regular shard on the same dim - # to preserve left-to-right order of replicate-to-shard. 
- # This is because that strided shard is using stride to represent - # more fine-grained sharding on the same dim. - # Please check the URL below for _StridedShard. - # https://github.com/pytorch/pytorch/blob/v2.8.0/torch/distributed/tensor/placement_types.py#L366 - - def placement_sort_key( - placement_with_index: tuple[float, Placement] - ) -> tuple[int, float, int]: # (dim, split factor, original index) - index, placement = placement_with_index - is_replicate = placement.is_replicate() - is_shard = placement.is_shard() - is_partial = placement.is_partial() - - assert is_replicate or is_shard, f"Unsupported placement type: {type(placement)}" - assert not is_partial, "Partial placement is not supported." - - if is_replicate: - return (-1.0, 0, index) - elif is_shard: - if isinstance(placement, _StridedShard): - return (placement.dim, 1 / placement.split_factor, index) - return (placement.dim, 0, index) - else: - raise TypeError(f"Unknown placement type: {type(placement)}") - - placements_with_index: list[tuple[int, - Placement]] = list(enumerate(placements)) - placements_with_index = sorted(placements_with_index, - key=placement_sort_key) - - sorted_indices, sorted_placements = zip(*placements_with_index) - - # 2. Permute mesh according to sorted placements. - sorted_mesh = mesh.permute(sorted_indices) - - # 3. 
Collect list of shard meshes by removing replicate dims - # For example, (2, 3, 4, 4) with placements [R, R, S(0), S(1)] - # shard_meshes should be list with 2 * 3 = 6 shard meshes of shape (4, 4) - num_replicates = sum(1 for p in sorted_placements if p.is_replicate()) - - # merge replicate dims - # shard_meshes became a list of shard meshes with a length of replicate degree - if num_replicates > 0: - sorted_mesh = sorted_mesh.flatten( - 0, num_replicates - 1) if num_replicates > 1 else sorted_mesh - shard_meshes = list(torch.unbind(sorted_mesh, dim=0)) - else: - shard_meshes = [sorted_mesh] - shard_placements = sorted_placements[num_replicates:] - - # assume all shard placements are different - assert len(shard_placements) == len(set(shard_placements)) - - # 4. Construct ProcessGroups - # Caution: all groups should be created in the same order in all processes, - # even though each process only needs its own group. - - # To use tensor as dict key, convert it to tuple - def tensor_to_tuple(t): - if isinstance(t, torch.Tensor): - t = t.tolist() - if isinstance(t, list): - return tuple(tensor_to_tuple(x) for x in t) - return t - - my_shard_mesh_as_tuple = None - for shard_mesh in shard_meshes: - assert isinstance(shard_mesh, torch.Tensor) - shard_mesh_as_tuple = tensor_to_tuple(shard_mesh) - - if (my_rank == shard_mesh).any().item(): - assert my_shard_mesh_as_tuple is None - my_shard_mesh_as_tuple = shard_mesh_as_tuple - - # update global cache - if shard_mesh_as_tuple not in _ranks_to_dist_cache: - shard_process_group = dist.new_group(shard_mesh.flatten().tolist()) - _ranks_to_dist_cache[shard_mesh_as_tuple] = ( - DeviceMesh(device_type="cuda", mesh=shard_mesh), - shard_process_group, - ) - - my_shard_mesh, my_shard_process_group = _ranks_to_dist_cache[ - my_shard_mesh_as_tuple] - - return my_shard_mesh, my_shard_process_group, shard_placements diff --git a/build/torch29-cxx11-rocm64-x86_64-linux/optimizer/muon.py 
b/build/torch29-cxx11-rocm64-x86_64-linux/optimizer/muon.py deleted file mode 100644 index cfbcca71741be70048bfd290c62148b2aceda631..0000000000000000000000000000000000000000 --- a/build/torch29-cxx11-rocm64-x86_64-linux/optimizer/muon.py +++ /dev/null @@ -1,1240 +0,0 @@ -import logging -import math -import types -from collections import defaultdict -from dataclasses import dataclass -from typing import Any, cast - -import torch -import torch.distributed as dist -from torch.distributed import ProcessGroup -from torch.distributed.device_mesh import DeviceMesh -from torch.distributed.tensor import DTensor, Replicate -from torch.distributed.tensor.placement_types import Placement - -from .distributed.utils import construct_shard_mesh, get_slices_of_dtensor -from .matmul_transpose_triton import matmul_transpose_assign - -logger = logging.getLogger(__name__) - -COMM_DTYPE = torch.bfloat16 -DEFAULT_CHUNK_SIZE_RATIO = 4 - - -# This code snippet is a modified version adapted from the following GitHub repositories: -# https://github.com/KellerJordan/Muon/blob/master/muon.py -# Muon's Newton–Schulz iteration causes high variance in singular values -# Idea: give each iteration its own 3 coefficients and optimize them via gradient descent. -@torch.no_grad() -# matmul_transpose_assign from : https://github.com/nil0x9/flash-muon -def _zeropower_via_newtonschulz5(G, steps): - """ - Newton-Schulz iteration to compute the zeroth power / orthogonalization of G. We opt to use a - quintic iteration whose coefficients are selected to maximize the slope at zero. For the purpose - of minimizing steps, it turns out to be empirically effective to keep increasing the slope at - zero even beyond the point where the iteration no longer converges all the way to one everywhere - on the interval. 
This iteration therefore does not produce UV^T but rather something like US'V^T - where S' is diagonal with S_{ii}' ~ Uniform(0.5, 1.5), which turns out not to hurt model - performance at all relative to UV^T, where USV^T = G is the SVD. - """ - assert len(G.shape) == 2 - assert G.dtype == COMM_DTYPE - X = G # no manual typecast - - if G.size(0) > G.size(1): - X = X.T - # Ensure spectral norm is at most 1 - X = X / (X.norm() + 1e-7) - buf1 = torch.empty(X.size(0), X.size(0), dtype=X.dtype, device=X.device) - buf2 = torch.empty(X.size(0), X.size(0), dtype=X.dtype, device=X.device) - # Perform the NS iterations - for a, b, c in [ - (4.0848, -6.8946, 2.9270), - (3.9505, -6.3029, 2.6377), - (3.7418, -5.5913, 2.3037), - (2.8769, -3.1427, 1.2046), - (2.8366, -3.0525, 1.2012), - ]: - matmul_transpose_assign(X, buf1) - matmul_transpose_assign(buf1, buf2) - buf1.mul_(b).add_(buf2, alpha=c) - X = torch.addmm(X, buf1, X, alpha=1.0, beta=a) - - if G.size(0) > G.size(1): - X = X.T - return X - - -@dataclass -class _muon_state: - # TODO: use Optional - worker_rank: int - process_group: ProcessGroup - shard_mesh: DeviceMesh - shard_placements: tuple[Placement, ...] 
- name: str - qk_clip_state: torch.Tensor | None = None - gathered_grad: torch.Tensor | None = None - scattered_u: DTensor | None = None - computed_u: torch.Tensor | None = None - gather_event: torch.cuda.Event | None = None - compute_event: torch.cuda.Event | None = None - scatter_event: torch.cuda.Event | None = None - - -def numel_for_rank( - param: DTensor, - local_rank: int, - state: _muon_state, -) -> int: - slices = get_slices_of_dtensor( - param, - local_rank, - state.shard_mesh, - state.shard_placements, - ) - - numel = 1 - for s, dim in zip(slices, param.shape): - start, stop, step = s.indices(dim) - length = max(0, (stop - start + (step - 1)) // step) - numel *= length - - return numel - - -@torch.no_grad() -def _alloc_gathered_grad(params, param_to_state, rank, compute_stream): - """ - Pre-allocate gathered_grad buffer on compute_stream - before launching all2all gather - """ - with torch.cuda.stream(compute_stream): - for p in params: - state = param_to_state[id(p)] - if rank == state.worker_rank: - state.gathered_grad = torch.empty(p.shape, - dtype=COMM_DTYPE, - device="cuda") - else: - state.gathered_grad = None - - alloc_event = torch.cuda.Event() - alloc_event.record(compute_stream) - return alloc_event - - -@torch.no_grad() -def _all2all_gather(params, param_to_state, rank, comm_stream, none_grad, - alloc_event): - """ - All2all gathers shards so each owner rank reconstructs its full gradient - """ - with torch.cuda.stream(comm_stream): - process_group = param_to_state[id(params[0])].process_group - num_ranks = dist.get_world_size(group=process_group) - - # Construct sending buffers - per_dst = [[] for _ in range(num_ranks)] - send_counts = [0] * num_ranks - - for p in params: - state = param_to_state[id(p)] - dst = state.worker_rank - assert dst < num_ranks - shard_elems = numel_for_rank(p, rank, state) - g = p.grad - g = g.to_local().to(COMM_DTYPE).contiguous() - assert g.numel() == shard_elems - per_dst[dst].append(g.view(-1)) - 
send_counts[dst] += shard_elems - - assert any( - len(v) > 0 for v in per_dst - ), "At least one destination rank must receive a sharded tensor" - # list[list[Tensor]] -> list[Tensor] - per_dst = [t for dst in per_dst for t in dst] - - send_buf = torch.cat(per_dst, dim=0) - - owned_params = [ - p for p in params if param_to_state[id(p)].worker_rank == rank - ] - - # Compute receive sizes and allocate receiving buffers - recv_counts = [0] * num_ranks - - for src in range(num_ranks): - total = 0 - for p in owned_params: - state = param_to_state[id(p)] - assert state.worker_rank == rank - total += numel_for_rank(p, src, state) - recv_counts[src] = total - - recv_total = sum(recv_counts) - recv_buf = torch.empty(recv_total, dtype=COMM_DTYPE, device="cuda") - - #All2All - logger.debug(f"send_buf size: {send_buf.numel()}, " - f"recv_buf size: {recv_buf.numel()}, " - f"recv_counts: {recv_counts}, " - f"send_counts: {send_counts}, " - f"process_group: {str(process_group)}") - dist.all_to_all_single( - recv_buf, - send_buf, - output_split_sizes=recv_counts, - input_split_sizes=send_counts, - group=process_group, - ) - - # Reconstructs gathered grad from the received buffer - # - # recv_buf (num ranks = 3) - # - # From rank 0 From rank 1 From rank 2 - # | p1_0, p2_0, p3_0 | p1_1, p2_1, p3_1 | p1_2, p2_2, p3_2 | - # - # Outer loop: - # rank 0 -> rank 1 -> rank2 - # - # Inner loop: - # p1_n -> p2_n -> p3_n - - comm_stream.wait_event(alloc_event) - - off = 0 - for src in range(num_ranks): - if recv_counts[src] == 0: - continue - - block = recv_counts[src] - inner_off = 0 - for p in owned_params: - state = param_to_state[id(p)] - assert state.worker_rank == rank - - # get the slice of the full dtensor corresponding to rank src. 
- slices = get_slices_of_dtensor(state.gathered_grad, src, - state.shard_mesh, - state.shard_placements) - - dst = state.gathered_grad[slices] - assert dst._base is state.gathered_grad - - n = dst.numel() - assert n > 0 - - sg = recv_buf.narrow(0, off + inner_off, n) - sg = sg.reshape_as(dst) - dst.copy_(sg) - - inner_off += n - off += block - - for p in params: - state = param_to_state[id(p)] - if state.worker_rank == rank: - state.gather_event = torch.cuda.Event() - state.gather_event.record(comm_stream) - else: - state.gathered_grad = None - state.gather_event = None - if none_grad: - p.grad = None - - -@torch.no_grad() -def _compute_u(p, state, steps, rank, compute_stream): - """ - On worker_rank, compute the orthogonalized update using Newton-Schulz iteration. - """ - with torch.cuda.stream(compute_stream): - if rank == state.worker_rank: - if state.gather_event is None: - raise RuntimeError("Gather event must be set before compute.") - compute_stream.wait_event(state.gather_event) - u = _zeropower_via_newtonschulz5(state.gathered_grad, steps) - state.gathered_grad = None - state.computed_u = u - state.compute_event = torch.cuda.Event() - state.compute_event.record() - else: - state.computed_u = None - state.compute_event = None - - -@torch.no_grad() -def _alloc_scattered_u(params, param_to_state, rank, compute_stream): - """ - Pre-allocate scattered_u buffer on compute_stream - before launching all2all gather - """ - with torch.cuda.stream(compute_stream): - for p in params: - state = param_to_state[id(p)] - state.scattered_u = torch.empty_like(p.to_local(), - dtype=COMM_DTYPE) - - alloc_event = torch.cuda.Event() - alloc_event.record(compute_stream) - return alloc_event - - -def _all2all_scatter(params, param_to_state, rank, comm_stream, alloc_event): - """ - All2all scatters full gradients to all ranks - """ - with torch.cuda.stream(comm_stream): - process_group = param_to_state[id(params[0])].process_group - num_ranks = 
dist.get_world_size(group=process_group) - owned_params = [ - p for p in params if param_to_state[id(p)].worker_rank == rank - ] - - # Construct sending buffer - per_dst = [[] for _ in range(num_ranks)] - send_counts = [0] * num_ranks - - if owned_params: - for p in owned_params: - state = param_to_state[id(p)] - if state.compute_event is None: - raise RuntimeError( - "Compute event must be set before scatter.") - comm_stream.wait_event(state.compute_event) - state.gathered_grad = None - - assert state.computed_u is not None - - u_full = state.computed_u.to(COMM_DTYPE).contiguous() - - offset = 0 - for dst in range(num_ranks): - # get the slice of the full tensor corresponding to rank dst. - slices = get_slices_of_dtensor(u_full, dst, - state.shard_mesh, - state.shard_placements) - su = u_full[slices].flatten() - - n = su.numel() - assert n > 0 - - per_dst[dst].append(su) - send_counts[dst] += n - offset += n - - assert offset == u_full.numel() - - lengths = [len(v) for v in per_dst] - if all(l > 0 for l in lengths): - assert all( - l == lengths[0] for l in lengths - ), "All destination ranks must have the same number of sharded tensor" - # list[list[Tensor]] -> list[Tensor] - per_dst = [t for dst in per_dst for t in dst] - send_buf = torch.cat(per_dst, dim=0) - else: - # all_to_all requires participation from all ranks - # Even non-owner ranks must join the collective call - send_buf = torch.empty(0, dtype=COMM_DTYPE, device="cuda") - - # Compute receive sizes and allocate receiving buffers - recv_counts = [0] * num_ranks - - for src in range(num_ranks): - total = 0 - for p in params: - state = param_to_state[id(p)] - if state.worker_rank != src: - continue - total += numel_for_rank(p, rank, state) - recv_counts[src] = total - - recv_total = sum(recv_counts) - assert recv_total > 0 - recv_buf = torch.empty(recv_total, dtype=COMM_DTYPE, device="cuda") - - #All2All - dist.all_to_all_single( - recv_buf, - send_buf, - output_split_sizes=recv_counts, - 
input_split_sizes=send_counts, - group=process_group, - ) - - # Copy to pre-allocated scattered_u buffer from the received buffer - # - # recv_buf (num ranks = 3, local_rank = 0) - # - # From rank 0 From rank 1 From rank 2 - # | p1_0, p2_0, p3_0 | p4_0 | p5_0, p6_0 | - # - # Outer loop: - # rank 0 -> rank 1 -> rank2 - # - # Inner loop: - # src(0) : p1_0 -> p2_0 -> p3_0 - # src(1) : p4_0 - # src(2) : p5_0 -> p6_0 - - comm_stream.wait_event(alloc_event) - - off = 0 - for src in range(num_ranks): - block = recv_counts[src] - if block == 0: - continue - - inner_off = 0 - for p in params: - state = param_to_state[id(p)] - if state.worker_rank != src: - continue - n = numel_for_rank(p, rank, state) - assert n > 0 - - flat_local = recv_buf.narrow(0, off + inner_off, - n).view_as(p.to_local()) - state.scattered_u.copy_(flat_local) - - state.scatter_event = torch.cuda.Event() - state.scatter_event.record(comm_stream) - inner_off += n - - assert inner_off == block - off += block - - -def _update_param(p, state, lr, adjusted_lr, weight_decay, rank, - compute_stream): - """ - Update sharded parameter p with the scattered_u. - Only worker_rank frees computed_u. 
- """ - with torch.cuda.stream(compute_stream): - if state.scatter_event is None: - raise RuntimeError("Scatter event must be set before update") - compute_stream.wait_event(state.scatter_event) - u_dtensor = DTensor.from_local( - state.scattered_u, - placements=p.placements, - device_mesh=p.device_mesh, - ) - - state.scattered_u = u_dtensor - - if rank == state.worker_rank: - # Free computed_u - state.computed_u = None - - Muon._update_p(p, state.scattered_u, lr, adjusted_lr, weight_decay) - state.scattered_u = None - u_dtensor = None - - scales_full = Muon._compute_scales( - p, - state.qk_clip_state) if state.qk_clip_state is not None else None - if scales_full is not None: - # Have to slice scales_full among dim 0 - weight_slices = get_slices_of_dtensor(p, rank, state.shard_mesh, - state.shard_placements) - ratio = p.shape[0] // scales_full.shape[0] - scales_slice = slice( - None if weight_slices[0].start is None else - weight_slices[0].start // ratio, - None if weight_slices[0].stop is None else - weight_slices[0].stop // ratio, - None, - ) - - scales_local = scales_full[scales_slice] - scales_local = DTensor.from_local( - scales_local, - placements=p.placements, - device_mesh=p.device_mesh, - ) - Muon._qk_clip(p, scales_local, state.qk_clip_state.head_dim) - - -def default_is_muon(name, x): - skip_keys = ["embed_tokens", "lm_head", "tok_embeddings", "output"] - return x.ndim >= 2 and not any(key in name for key in skip_keys) - - -def get_default_muon_param_groups(model, is_muon_func=default_is_muon): - muon_params, muon_names = [], [] - non_muon_params = [] - - for n, p in model.named_parameters(): - if not p.requires_grad: - continue - if is_muon_func(n, p): - muon_params.append(p) - muon_names.append(n) - else: - non_muon_params.append(p) - - return [ - { - "params": muon_params, - "names": muon_names, - "use_muon": True, - }, - { - "params": non_muon_params, - "use_muon": False, - }, - ] - - -def parse_qk_layer(name: str) -> tuple[str | None, int]: - """ - 
Parse a parameter name to check if it is a query/key projection layer - ('wq', 'wk', 'q_proj', 'k_proj') and return (kind, layer_index). - - Returns: - (kind, layer_idx) or (None, -1) if not matched. - - Example: - 'model.3.attn.wq.weight' -> ('wq', 3) - 'model.5.attn.wk.weight' -> ('wk', 5) - 'model.2.attn.q_proj.weight' -> ('q_proj', 2) - 'model.7.attn.k_proj.weight' -> ('k_proj', 7) - 'model.4.attn.v_proj.weight' -> (None, -1) - """ - parts = name.split('.') - if len(parts) < 3: - return None, -1 - - kind = parts[-2] - - layer_idx = -1 - for part in reversed(parts): - if part.isdigit(): - layer_idx = int(part) - break - - if kind in ('wq', 'wk', 'q_proj', 'k_proj'): - return kind, layer_idx - - return None, -1 - - -@dataclass -class QKClipInfo: - """Per-parameter dynamic info computed from config + runtime logits.""" - kind: str | None # 'wq'/'q_proj' or 'wk'/'k_proj' or None - indices: list[int] # which heads to consider for clipping - head_dim: int # from config - threshold: float # from config - logit: torch.Tensor | None - - -class Muon(torch.optim.Optimizer): - """ - Muon - MomentUm Orthogonalized by Newton-schulz - - Muon internally runs standard SGD-momentum, and then performs an orthogonalization post- - processing step, in which each 2D parameter's update is replaced with the nearest orthogonal - matrix. To efficiently orthogonalize each update, we use a Newton-Schulz iteration, which has - the advantage that it can be stably run in bfloat16 on the GPU. - - Some warnings: - - We believe this optimizer is unlikely to work well for training with small batch size. - - We believe it may not work well for finetuning pretrained models, but we haven't tested this. - - Arguments: - model: The model to be optimized by Muon. - is_muon_func: A function that takes a parameter and its name, and returns whether the parameter should be optimized by Muon. - lr: The learning rate. The updates will have spectral norm of `lr`. 
(0.02 is a good default) - momentum: The momentum used by the internal SGD. (0.95 is a good default) - nesterov: Whether to use Nesterov-style momentum in the internal SGD. (recommended) - ns_steps: The number of Newton-Schulz iterations to run. (6 is probably always enough) - weight_decay: The weight decay for Muon and AdamW. - {0, 1}-D or are detected as being the embed or lm_head will be optimized by AdamW as well. - adamw_lr: The learning rate for the internal AdamW. - adamw_betas: The betas for the internal AdamW. - adamw_eps: The epsilon for the internal AdamW. - none_grad: Whether to set p.grad to None after gathering the gradients. This can save memory. - debug: Whether to print debug information. - clip_info : Configuration for QK clipping. Expected keys: - - "q_indices" (list[int]): Indices of query heads to consider. - - "k_indices" (list[int]): Indices of key heads to consider. - - "head_dim" (int): Dimensionality of each attention head. - - "threshold" (float): Threshold value; heads whose QK logits exceed - this value will be scaled down. - Default is: - { - "q_indices": [], - "k_indices": [], - "head_dim": 128, - "threshold": 100 - } - warmup_step : How many all2all gather, compute operations are launched in advance - before the corresponding all2all scatter steps begin. - A higher warmup_step increases memory usage but can improve - performance by overlapping communication. - Parallel muon only. - chunk_size : Batch size of parameters to process in each - all2all gather/compute/scatter step. - Use shard ranks * DEFAULT_CHUNK_SIZE_RATIO when -1 is specified. - use_distributed_muon: Use distributed muon by Liu et al. (2024). - For testing purpose only. 
- """ - - def __init__(self, - params, - lr=1e-3, - momentum=0.95, - nesterov=True, - ns_steps=5, - weight_decay=0.1, - adamw_betas=(0.9, 0.95), - adamw_eps=1e-8, - none_grad=True, - debug=False, - clip_config={ - "q_indices": [], - "k_indices": [], - "head_dim": 128, - "threshold": 100 - }, - warmup_step=5, - chunk_size=-1, - use_distributed_muon=False): - defaults = dict( - lr=lr, - weight_decay=weight_decay, - momentum=momentum, - nesterov=nesterov, - ns_steps=ns_steps, - adamw_betas=adamw_betas, - adamw_eps=adamw_eps, - none_grad=none_grad, - use_muon=True, - ) - error_message = "The key 'use_muon' is not set in parameter group {idx}. Assuming all parameters in the group will use muon optimization, which may lead to unexpected behavior." - instruction_code = "\n\n please follow this code snippet \n```optimizer = get_kernel('motif-technologies/optimizer')\n\n\nparams = optimizer.muon.get_default_muon_param_groups(model)\n\noptim = optimizer.Muon(params, ...)```" - - if isinstance(params, types.GeneratorType): - raise ValueError(error_message.format(idx=0) + instruction_code) - for _idx, param_group in enumerate(params): - if param_group.get("use_muon", None) is None: - raise ValueError( - error_message.format(idx=_idx) + instruction_code) - - super().__init__(params, defaults) - - self.rank = None - - self.comm_stream = torch.cuda.Stream() - self.compute_stream = torch.cuda.Stream() - self.debug = debug - self.clip_config = clip_config - self.warmup_step = warmup_step - self.chunk_size = chunk_size - self.use_distributed_muon = use_distributed_muon - - def _calc_flops(self, G, steps): - assert len(G.shape) == 2 - M, N = G.shape - if M > N: - M, N = N, M - - return steps * ((M**3) * 2 + (M**2 * N) * 4 + M * N * 2 + M**2 * 3) - - def adjust_lr_for_muon(self, lr, param_shape): - A, B = param_shape[:2] - # We adjust the learning rate and weight decay based on the size of the parameter matrix - # as describted in the paper - adjusted_ratio = 0.2 * math.sqrt(max(A, 
B)) - adjusted_lr = lr * adjusted_ratio - return adjusted_lr - - def set_rank_once(self, rank): - if self.rank is None: - self.rank = rank - else: - assert self.rank == rank - - def get_shard_mesh(self, p): - """ - Get the shard mesh for a parameter p on the given rank. - """ - assert isinstance( - p, DTensor), "Parallel Muon only supports DTensor parameters." - - shard_mesh, shard_pg, shard_placements = construct_shard_mesh( - p.placements, p.device_mesh) - - # set rank with the local rank in the shard process group - self.set_rank_once(dist.get_rank(group=shard_pg)) - - return shard_mesh, shard_pg, shard_placements - - def init_state_and_assign_params(self, names, params, group, qk_logits): - param_to_state = {} - param_to_flops = {} - - total_flops = 0 - for p in params: - g = p.grad - if g is None: - continue - assert g.ndim == 2, "Muon only supports 2D parameters." - - flops = self._calc_flops(g, group["ns_steps"]) - param_to_flops[id(p)] = flops - total_flops += flops - - if self.debug: - print(f"Total TFLOPs for Muon: {total_flops / 1e12:.2f} TFLOPs", - flush=True) - - paired = list(zip(names, params)) - - paired_sorted = sorted(paired, - key=lambda x: param_to_flops[id(x[1])], - reverse=True) - - names_sorted, params_sorted = zip(*paired_sorted) - ordered_names = list(names_sorted) - ordered_params = list(params_sorted) - - round_robin = 0 - mesh = ordered_params[0].device_mesh - placements = ordered_params[0].placements - - shard_mesh, shard_pg, shard_placements = self.get_shard_mesh( - ordered_params[0]) - shard_mesh_flattened = shard_mesh.mesh.flatten() - num_ranks = dist.get_world_size(group=shard_pg) - - for n, p in zip(ordered_names, ordered_params): - if mesh != p.device_mesh: - raise ValueError("All parameters must be on the same mesh.") - if placements != p.placements: - raise ValueError("All parameters must have same placements.") - - worker_rank = shard_mesh_flattened[round_robin].item() % num_ranks - round_robin = (round_robin + 1) % 
len(shard_mesh_flattened) - qk_clip_state = self.get_qk_clip_info(n, qk_logits) - - param_to_state[id(p)] = _muon_state( - worker_rank=worker_rank, - process_group=shard_pg, - shard_mesh=shard_mesh, - shard_placements=shard_placements, - name=n, - qk_clip_state=qk_clip_state, - ) - - return param_to_state, ordered_params - - def base(self, names, params, group, lr, weight_decay, momentum, - qk_logits): - # generate weight updates in distributed fashion - for n, p in zip(names, params): - g = p.grad - if g is None: - continue - if g.ndim > 2: - g = g.view(g.size(0), -1) - assert g is not None - - # calc update - state = self.state[p] - if "momentum_buffer" not in state: - state["momentum_buffer"] = torch.zeros_like(g) - buf = state["momentum_buffer"] - buf.mul_(momentum).add_(g) - if group["nesterov"]: - g = g.add(buf, alpha=momentum) - else: - g = buf - - u = _zeropower_via_newtonschulz5(g.to(COMM_DTYPE), - steps=group["ns_steps"]) - - adjusted_lr = self.adjust_lr_for_muon(lr, p.shape) - Muon._update_p(p, u, lr, adjusted_lr, weight_decay) - - qk_clip_state = self.get_qk_clip_info(n, qk_logits) - - scales_full = self._compute_scales( - p, qk_clip_state) if qk_clip_state is not None else None - if scales_full is not None: - Muon._qk_clip(p, scales_full, qk_clip_state.head_dim) - - def distributed_muon( - self, - names: list[str], - params: list[torch.nn.Parameter], - group: dict[str, Any], - lr: float, - weight_decay: float, - momentum: float, - qk_logits: list[torch.Tensor | DTensor] | None, - ): - """ Implementation of Distributed Muon by Liu et al. 
""" - if qk_logits is not None: - raise NotImplementedError("QK clipping is not supported yet") - - if isinstance(params[0], DTensor): - shard_mesh, _, shard_placements = construct_shard_mesh( - placements=params[0].placements, - mesh=params[0].device_mesh, - ) - - for n, p in zip(names, params): - g = p.grad - if g is None: - continue - if g.ndim > 2: - g = g.view(g.size(0), -1) - assert g is not None - - # calc update - state = self.state[p] - if "momentum_buffer" not in state: - state["momentum_buffer"] = torch.zeros_like(g) - buf = state["momentum_buffer"] - buf.mul_(momentum).add_(g) - if group["nesterov"]: - g = g.add(buf, alpha=momentum) - else: - g = buf - - # Gather G - if isinstance(p.data, DTensor): - g = g.full_tensor() - u = _zeropower_via_newtonschulz5(g.to(COMM_DTYPE), - steps=group["ns_steps"]) - - if isinstance(p.data, DTensor): - slices = get_slices_of_dtensor( - target=p, - local_rank=dist.get_rank(), - shard_mesh=shard_mesh, - shard_placements=shard_placements, - ) - u_shard = u[slices] - u = DTensor.from_local( - u_shard, - device_mesh=p.device_mesh, - placements=p.placements, - ) - - adjusted_lr = self.adjust_lr_for_muon(lr, p.shape) - Muon._update_p(p, u, lr, adjusted_lr, weight_decay) - - def _update_g(self, p, g, group, momentum): - # calc update - state = self.state[p] - buf = state.setdefault("momentum_buffer", torch.zeros_like(g)) - torch.add(g, buf, alpha=momentum, out=buf) - if group["nesterov"]: - g.add_(buf, alpha=momentum) - return g - return buf - - @staticmethod - def _update_p(p, u, lr, adjusted_lr, weight_decay): - # apply weight decay - p.data.mul_(1 - lr * weight_decay) - # apply update - p.data.add_(u, alpha=-adjusted_lr) - - def get_qk_clip_info(self, n, qk_logits): - if self.clip_config is None: - return None - - head_dim = self.clip_config.get('head_dim') - threshold = self.clip_config.get('threshold') - kind, layer_idx = parse_qk_layer(n) - - logit, indices = None, [] - if qk_logits is not None and kind is not None: - 
logit = qk_logits[layer_idx] - indices_key = 'q_indices' if 'q' in kind else 'k_indices' - indices = self.clip_config.get(indices_key, []) or [] - - if isinstance(logit, DTensor): - # In TP settings, qk_logits may be DTensor - # We convert it to full tensor here for simplicity - logit = logit.full_tensor() - - return QKClipInfo( - kind=kind, - indices=indices, - head_dim=head_dim, - threshold=threshold, - logit=logit, - ) - - @staticmethod - def _compute_scales(p, qk_clip_state): - kind = qk_clip_state.kind - indices = qk_clip_state.indices - head_dim = qk_clip_state.head_dim - threshold = qk_clip_state.threshold - logit = qk_clip_state.logit - - H_global = p.shape[0] // head_dim - scales_full = torch.ones(H_global, device=p.data.device) - scaling = 0 - - for logit_idx, head_idx in enumerate(indices): - v_ele = float(logit[logit_idx]) - if v_ele > threshold: - new_scale = math.sqrt(threshold / v_ele) - if new_scale < scales_full[head_idx]: - scales_full[head_idx] = new_scale - logger.info( - f"[{kind}] Head {head_idx} exceeded threshold " - f"(value={v_ele:.4f}, threshold={threshold:.4f}) -> applying scale={new_scale:.4f}" - ) - scaling += 1 - - return scales_full if scaling > 0 else None - - @staticmethod - def _qk_clip(p, scales, head_dim): - W = p.data.view(-1, head_dim, p.data.shape[1]) - W.mul_(scales.view(-1, 1, 1)) - - def parallel(self, names, params, group, lr, weight_decay, momentum, - qk_logits): - """ - Perform a parallel optimization step using Muon. 
- """ - - for p in params: - g = p.grad - if g is None: - continue - if g.ndim > 2: - g = g.view(g.size(0), -1) - - # Update g in the local rank - g = self._update_g( - p, - g, - group, - momentum=momentum, - ) - p.grad = g - - param_to_state, ordered_params = self.init_state_and_assign_params( - names, params, group, qk_logits) - - assert self.rank is not None - - def enqueue_all2all_gather(start_idx, chunk_size): - target_params = ordered_params[start_idx:start_idx + chunk_size] - if target_params: - alloc_event = _alloc_gathered_grad(target_params, - param_to_state, self.rank, - self.compute_stream) - _all2all_gather(target_params, param_to_state, self.rank, - self.comm_stream, group["none_grad"], - alloc_event) - - def enqueue_computes(start_idx, chunk_size): - for p in ordered_params[start_idx:start_idx + chunk_size]: - state = param_to_state[id(p)] - _compute_u(p, state, group["ns_steps"], self.rank, - self.compute_stream) - - def enqueue_all2all_scatter(start_idx, chunk_size): - target_params = ordered_params[start_idx:start_idx + chunk_size] - if target_params: - alloc_event = _alloc_scattered_u(target_params, param_to_state, - self.rank, - self.compute_stream) - _all2all_scatter(target_params, param_to_state, self.rank, - self.comm_stream, alloc_event) - - def enqueue_update_param(start_idx, chunk_size): - for p in ordered_params[start_idx:start_idx + chunk_size]: - state = param_to_state[id(p)] - adjusted_lr = self.adjust_lr_for_muon(lr, p.shape) - _update_param(p, state, lr, adjusted_lr, weight_decay, - self.rank, self.compute_stream) - - if self.chunk_size == -1: - shard_ranks = dist.get_world_size(param_to_state[id( - params[0])].process_group) - chunk_size = shard_ranks * DEFAULT_CHUNK_SIZE_RATIO - elif self.chunk_size > 0: - chunk_size = self.chunk_size - else: - raise ValueError("chunk_size must be -1 or a positive integer.") - - # Wait grad update - self.comm_stream.wait_stream(torch.cuda.current_stream()) - - warmup_step = self.warmup_step - for i 
in range(0, warmup_step): - enqueue_all2all_gather(i * chunk_size, chunk_size) - enqueue_computes(i * chunk_size, chunk_size) - - for i in range(0, len(params) + chunk_size - 1, chunk_size): - enqueue_all2all_scatter(i, chunk_size) - enqueue_all2all_gather(i + warmup_step * chunk_size, chunk_size) - enqueue_update_param(i, chunk_size) - enqueue_computes(i + warmup_step * chunk_size, chunk_size) - - # Wait the last update_param to finish - torch.cuda.current_stream().wait_stream(self.compute_stream) - - @staticmethod - def _fused_adamw( - params: list[torch.Tensor], - grads: list[torch.Tensor], - exp_avgs: list[torch.Tensor], - exp_avg_sqs: list[torch.Tensor], - max_exp_avg_sqs: list[torch.Tensor], - state_steps: list[torch.Tensor], - amsgrad: bool, - beta1: float, - beta2: float, - lr: float | torch.Tensor, - weight_decay: float, - eps: float, - maximize: bool, - ) -> None: - if not params: - return - - # We only shuffle around the lr when it is a Tensor and on CUDA, otherwise, we prefer - # treating it as a scalar. 
- lr_dict: DeviceDict | None = ({ - lr.device: lr - } if isinstance(lr, torch.Tensor) and str(lr.device) != "cpu" else - None) - grouped_tensors = torch.optim.Optimizer._group_tensors_by_device_and_dtype( - [ - params, grads, exp_avgs, exp_avg_sqs, max_exp_avg_sqs, - state_steps - ] # type: ignore[list-item] - ) - for (device, _), ( - ( - device_params_, - device_grads_, - device_exp_avgs_, - device_exp_avg_sqs_, - device_max_exp_avg_sqs, - device_state_steps_, - ), - _, - ) in grouped_tensors.items(): - device_params = cast(list[torch.Tensor], device_params_) - device_grads = cast(list[torch.Tensor], device_grads_) - device_exp_avgs = cast(list[torch.Tensor], device_exp_avgs_) - device_exp_avg_sqs = cast(list[torch.Tensor], device_exp_avg_sqs_) - device_state_steps = cast(list[torch.Tensor], device_state_steps_) - - if lr_dict is not None and device not in lr_dict: - lr_dict[device] = lr.to( - device=device, - non_blocking=True) # type: ignore[union-attr] - lr = lr_dict[device] - torch._foreach_add_(device_state_steps, 1) - func = torch._fused_adamw_ - func( - device_params, - device_grads, - device_exp_avgs, - device_exp_avg_sqs, - device_max_exp_avg_sqs, # type: ignore[arg-type] - device_state_steps, - amsgrad=amsgrad, - lr=lr, # type: ignore[arg-type] - beta1=beta1, - beta2=beta2, - weight_decay=weight_decay, - eps=eps, - maximize=maximize, - ) - - def _step_muon(self, group, qk_logits=None): - params = group["params"] - lr = group["lr"] - weight_decay = group["weight_decay"] - momentum = group["momentum"] - names = group["names"] - - param_dtensors = [] - param_tensors = [] - name_dtensors = [] - name_tensors = [] - - if self.use_distributed_muon: - self.distributed_muon(names=names, - params=params, - group=group, - lr=lr, - weight_decay=weight_decay, - momentum=momentum, - qk_logits=qk_logits) - return - - for n, p in zip(names, params): - if p is None or p.grad is None: - continue - if isinstance(p.data, DTensor): - if all( - isinstance(placement, 
Replicate) - for placement in p.placements): - param_tensors.append(p) - name_tensors.append(n) - else: - param_dtensors.append(p) - name_dtensors.append(n) - elif isinstance(p.data, torch.Tensor): - param_tensors.append(p) - name_tensors.append(n) - else: - raise TypeError(f"Unsupported parameter type: {type(p.data)}") - - logger.debug( - f"[Muon] {len(param_dtensors)} DTensors, {len(param_tensors)} Tensors" - ) - - if len(param_dtensors) > 0: - if not dist.is_initialized(): - raise RuntimeError( - "Parallel Muon requires torch.distributed to be initialized." - ) - - # To support different placements, we group parameters by placements - # and run parallel Muon on each group. - - placement_to_params = defaultdict(lambda: ([], [])) - # type: dict[tuple[Placement, DeviceMesh], tuple[list[str], list[DTensor]]] - - assert len(name_dtensors) == len(param_dtensors) - for n, p in zip(name_dtensors, param_dtensors): - placement_to_params[tuple([p.placements, - p.device_mesh])][0].append(n) - placement_to_params[tuple([p.placements, - p.device_mesh])][1].append(p) - - for _, (names, params) in placement_to_params.items(): - self.parallel( - names, - params, - group, - lr=lr, - weight_decay=weight_decay, - momentum=momentum, - qk_logits=qk_logits, - ) - - if len(param_tensors) > 0: - self.base( - name_tensors, - param_tensors, - group, - lr=lr, - weight_decay=weight_decay, - momentum=momentum, - qk_logits=qk_logits, - ) - - def _step_adamw_params(self, params, group): - params_with_grads = [] - grads = [] - moment1 = [] - moment2 = [] - max_exp_avg_sqs = [] - state_steps = [] - lr = group["lr"] - beta1, beta2 = group["adamw_betas"] - eps = group["adamw_eps"] - weight_decay = group["weight_decay"] - - for p in params: - g = p.grad - if g is None: - continue - state = self.state[p] - params_with_grads.append(p) - grads.append(g) - if "step" not in state: - state["step"] = (torch.zeros((), - dtype=torch.float32, - device=p.device)) - state["moment1"] = torch.zeros_like(g) - 
state["moment2"] = torch.zeros_like(g) - moment1.append(state["moment1"]) - moment2.append(state["moment2"]) - if not isinstance(state["step"], torch.Tensor): - step_tensor = torch.tensor(state["step"], - dtype=torch.float32, - device=p.device) - else: - step_tensor = state["step"] - state_steps.append(step_tensor) - - self._fused_adamw( - params_with_grads, - grads, - moment1, - moment2, - max_exp_avg_sqs, - state_steps, - amsgrad=False, - beta1=beta1, - beta2=beta2, - lr=lr, - weight_decay=weight_decay, - eps=eps, - maximize=False, - ) - - def _step_adamw(self, group): - params = group["params"] - - # group params with it's type and placement - placement_to_params: dict[tuple[Placement | type, - DeviceMesh | None]] = defaultdict(list) - for p in params: - match p: - case DTensor(): - placement_to_params[tuple([p.placements, - p.device_mesh])].append(p) - case torch.Tensor(): - placement_to_params[tuple([torch.Tensor, None])].append(p) - - for params in placement_to_params.values(): - self._step_adamw_params(params, group) - - def step(self, closure=None, qk_logits=None): - """Perform a single optimization step. - - Args: - closure (Callable, optional): A closure that reevaluates the model - and returns the loss. - qk_logits (dict[int, Tensor], optional): A dictionary mapping layer indices - to 1D tensors of shape (num_heads,), representing the maximum - QK logits across all tokens, computed as - (1 / sqrt(head_dim)) * (Q @ K^T). 
- """ - loss = None - if closure is not None: - with torch.enable_grad(): - loss = closure() - - for group in self.param_groups: - if group["use_muon"]: - self._step_muon(group, qk_logits=qk_logits) - else: - self._step_adamw(group) - - return loss diff --git a/test/test_muon.py b/test/test_muon.py index 58f0ad64d2e9eded31d0b142b0c12210b73a3803..3c4085963941120b0c089bfbdfad3a840c00da20 100644 --- a/test/test_muon.py +++ b/test/test_muon.py @@ -23,6 +23,7 @@ def apply_muon_step( grads: list[torch.Tensor], warmup_step: int, chunk_size: int, + small_param_numel_threshold: int, qk_logits: dict[int, torch.Tensor] | None = None, use_distributed_muon: bool = False, measure_perf: bool = False, @@ -65,6 +66,7 @@ def apply_muon_step( none_grad=False, warmup_step=warmup_step, chunk_size=chunk_size, + small_param_numel_threshold=small_param_numel_threshold, use_distributed_muon=use_distributed_muon, ) @@ -136,6 +138,7 @@ def sequential_muon_result( grads=grads, warmup_step=-1, chunk_size=-1, + small_param_numel_threshold=-1, qk_logits=None, )[0].cpu() @@ -145,6 +148,7 @@ def sequential_muon_result( grads=grads, warmup_step=-1, chunk_size=-1, + small_param_numel_threshold=-1, qk_logits=qk_logits, )[0].cpu() @@ -156,6 +160,7 @@ def sequential_muon_result( OVERLAP_STEPS = [5] CHUNK_SIZES = [8] +SMALL_PARAM_NUMEL_THRESHOLDS = [65536, 1_000_000_000] @pytest.mark.parametrize("parallel_dims", [ @@ -170,6 +175,8 @@ CHUNK_SIZES = [8] @pytest.mark.parametrize("use_distributed_muon", [False]) @pytest.mark.parametrize("warmup_step", OVERLAP_STEPS) @pytest.mark.parametrize("chunk_size", CHUNK_SIZES) +@pytest.mark.parametrize("small_param_numel_threshold", + SMALL_PARAM_NUMEL_THRESHOLDS) def test_parallel_muon( request, sequential_muon_result: dict[bool, torch.nn.Module], @@ -178,6 +185,7 @@ def test_parallel_muon( use_distributed_muon: bool, warmup_step: int, chunk_size: int, + small_param_numel_threshold: int, inputs: tuple[torch.nn.Module, list[torch.Tensor], dict[int, torch.Tensor]], # 
from conftest.py measure_perf, # from conftest.py @@ -209,6 +217,7 @@ def test_parallel_muon( grads=grads, warmup_step=warmup_step, chunk_size=chunk_size, + small_param_numel_threshold=small_param_numel_threshold, qk_logits=qk_logits, use_distributed_muon=use_distributed_muon, measure_perf=measure_perf, diff --git a/torch-ext/optimizer/distributed/utils.py b/torch-ext/optimizer/distributed/utils.py index 0b4b58bfb329b1c015129e4c4fc99f7bfa2ab30a..6d5843506c13d9d31603b2b4e30c1c91d0baab28 100644 --- a/torch-ext/optimizer/distributed/utils.py +++ b/torch-ext/optimizer/distributed/utils.py @@ -50,7 +50,7 @@ def get_slices_of_dtensor( raise NotImplementedError( f"Dimension size {dim_size} is not divisible " f"by number of ranks {num_ranks} for shard " - f"placement on dim {dim}.") + f"placement on dim {dim}. (shape: {target.shape})") shard_size = dim_size // num_ranks @@ -64,7 +64,8 @@ def get_slices_of_dtensor( return tuple(slices) -_ranks_to_dist_cache: dict[tuple[int, ...], tuple[DeviceMesh, ProcessGroup]] = dict() +_ranks_to_dist_cache: dict[tuple[int, ...], tuple[DeviceMesh, + ProcessGroup]] = dict() def construct_shard_mesh( diff --git a/torch-ext/optimizer/muon.py b/torch-ext/optimizer/muon.py index cfbcca71741be70048bfd290c62148b2aceda631..dbf25575f185ff379789482068e4ecf55b9455a9 100644 --- a/torch-ext/optimizer/muon.py +++ b/torch-ext/optimizer/muon.py @@ -583,6 +583,7 @@ class Muon(torch.optim.Optimizer): Use shard ranks * DEFAULT_CHUNK_SIZE_RATIO when -1 is specified. use_distributed_muon: Use distributed muon by Liu et al. (2024). For testing purpose only. 
+ small_param_numel_threshold: Threshold for classifying parameters as small and falling back to distributed Muon """ def __init__(self, @@ -604,7 +605,8 @@ class Muon(torch.optim.Optimizer): }, warmup_step=5, chunk_size=-1, - use_distributed_muon=False): + use_distributed_muon=False, + small_param_numel_threshold=65536): defaults = dict( lr=lr, weight_decay=weight_decay, @@ -637,6 +639,7 @@ class Muon(torch.optim.Optimizer): self.warmup_step = warmup_step self.chunk_size = chunk_size self.use_distributed_muon = use_distributed_muon + self.small_param_numel_threshold = small_param_numel_threshold def _calc_flops(self, G, steps): assert len(G.shape) == 2 @@ -745,16 +748,7 @@ class Muon(torch.optim.Optimizer): g = g.view(g.size(0), -1) assert g is not None - # calc update - state = self.state[p] - if "momentum_buffer" not in state: - state["momentum_buffer"] = torch.zeros_like(g) - buf = state["momentum_buffer"] - buf.mul_(momentum).add_(g) - if group["nesterov"]: - g = g.add(buf, alpha=momentum) - else: - g = buf + g = self._update_g(p, g, group, momentum) u = _zeropower_via_newtonschulz5(g.to(COMM_DTYPE), steps=group["ns_steps"]) @@ -780,14 +774,6 @@ class Muon(torch.optim.Optimizer): qk_logits: list[torch.Tensor | DTensor] | None, ): """ Implementation of Distributed Muon by Liu et al. 
""" - if qk_logits is not None: - raise NotImplementedError("QK clipping is not supported yet") - - if isinstance(params[0], DTensor): - shard_mesh, _, shard_placements = construct_shard_mesh( - placements=params[0].placements, - mesh=params[0].device_mesh, - ) for n, p in zip(names, params): g = p.grad @@ -797,39 +783,44 @@ class Muon(torch.optim.Optimizer): g = g.view(g.size(0), -1) assert g is not None - # calc update - state = self.state[p] - if "momentum_buffer" not in state: - state["momentum_buffer"] = torch.zeros_like(g) - buf = state["momentum_buffer"] - buf.mul_(momentum).add_(g) - if group["nesterov"]: - g = g.add(buf, alpha=momentum) - else: - g = buf + g = self._update_g(p, g, group, momentum) # Gather G if isinstance(p.data, DTensor): - g = g.full_tensor() - u = _zeropower_via_newtonschulz5(g.to(COMM_DTYPE), - steps=group["ns_steps"]) + g_full = g.full_tensor() + p_full = p.data.full_tensor() + else: + g_full = g + p_full = p + + u_full = _zeropower_via_newtonschulz5(g_full.to(COMM_DTYPE), + steps=group["ns_steps"]) + + adjusted_lr = self.adjust_lr_for_muon(lr, p_full.shape) + Muon._update_p(p_full, u_full, lr, adjusted_lr, weight_decay) + + qk_clip_state = self.get_qk_clip_info(n, qk_logits) + + scales_full = self._compute_scales( + p_full, qk_clip_state) if qk_clip_state is not None else None + + if scales_full is not None: + Muon._qk_clip(p_full, scales_full, qk_clip_state.head_dim) if isinstance(p.data, DTensor): - slices = get_slices_of_dtensor( - target=p, - local_rank=dist.get_rank(), - shard_mesh=shard_mesh, - shard_placements=shard_placements, + ndims = len(p.device_mesh.mesh.shape) + p_replicate = DTensor.from_local( + p_full, + device_mesh=p.device_mesh, + placements=[Replicate() for _ in range(ndims)], ) - u_shard = u[slices] - u = DTensor.from_local( - u_shard, + + p_sharded = p_replicate.redistribute( device_mesh=p.device_mesh, placements=p.placements, ) - adjusted_lr = self.adjust_lr_for_muon(lr, p.shape) - Muon._update_p(p, u, lr, 
adjusted_lr, weight_decay) + p.copy_(p_sharded) def _update_g(self, p, g, group, momentum): # calc update @@ -843,10 +834,14 @@ class Muon(torch.optim.Optimizer): @staticmethod def _update_p(p, u, lr, adjusted_lr, weight_decay): - # apply weight decay - p.data.mul_(1 - lr * weight_decay) - # apply update - p.data.add_(u, alpha=-adjusted_lr) + if isinstance(p, torch.nn.Parameter): + # apply weight decay + p.data.mul_(1 - lr * weight_decay) + # apply update + p.data.add_(u, alpha=-adjusted_lr) + else: + p.mul_(1 - lr * weight_decay) + p.add_(u, alpha=-adjusted_lr) def get_qk_clip_info(self, n, qk_logits): if self.clip_config is None: @@ -903,8 +898,12 @@ class Muon(torch.optim.Optimizer): @staticmethod def _qk_clip(p, scales, head_dim): - W = p.data.view(-1, head_dim, p.data.shape[1]) - W.mul_(scales.view(-1, 1, 1)) + if isinstance(p, torch.nn.Parameter): + W = p.data.view(-1, head_dim, p.data.shape[1]) + W.mul_(scales.view(-1, 1, 1)) + else: + W = p.view(-1, head_dim, p.shape[1]) + W.mul_(scales.view(-1, 1, 1)) def parallel(self, names, params, group, lr, weight_decay, momentum, qk_logits): @@ -1070,10 +1069,14 @@ class Muon(torch.optim.Optimizer): names = group["names"] param_dtensors = [] - param_tensors = [] name_dtensors = [] + + param_tensors = [] name_tensors = [] + param_dtensors_small = [] + name_dtensors_small = [] + if self.use_distributed_muon: self.distributed_muon(names=names, params=params, @@ -1084,6 +1087,8 @@ class Muon(torch.optim.Optimizer): qk_logits=qk_logits) return + # For simplicity, we use distributed Muon for small parameters + # whose number of elements is below a threshold. 
for n, p in zip(names, params): if p is None or p.grad is None: continue @@ -1093,6 +1098,9 @@ class Muon(torch.optim.Optimizer): for placement in p.placements): param_tensors.append(p) name_tensors.append(n) + elif p.data.numel() <= self.small_param_numel_threshold: + param_dtensors_small.append(p) + name_dtensors_small.append(n) else: param_dtensors.append(p) name_dtensors.append(n) @@ -1103,29 +1111,48 @@ class Muon(torch.optim.Optimizer): raise TypeError(f"Unsupported parameter type: {type(p.data)}") logger.debug( - f"[Muon] {len(param_dtensors)} DTensors, {len(param_tensors)} Tensors" - ) - - if len(param_dtensors) > 0: - if not dist.is_initialized(): - raise RuntimeError( - "Parallel Muon requires torch.distributed to be initialized." - ) + f"[Muon] {len(param_dtensors)} DTensors, {len(param_tensors)} Tensors, " + f"{len(param_dtensors_small)} Small DTensors") + def group_dtensors(dtensors, names): # To support different placements, we group parameters by placements # and run parallel Muon on each group. placement_to_params = defaultdict(lambda: ([], [])) # type: dict[tuple[Placement, DeviceMesh], tuple[list[str], list[DTensor]]] - assert len(name_dtensors) == len(param_dtensors) - for n, p in zip(name_dtensors, param_dtensors): + assert len(dtensors) == len(names) + for p, n in zip(dtensors, names): placement_to_params[tuple([p.placements, p.device_mesh])][0].append(n) placement_to_params[tuple([p.placements, p.device_mesh])][1].append(p) + return placement_to_params + + if len(param_dtensors_small) > 0: + if not dist.is_initialized(): + raise RuntimeError( + "Parallel Muon requires torch.distributed to be initialized." + ) + + self.distributed_muon( + params=param_dtensors_small, + names=name_dtensors_small, + group=group, + lr=lr, + weight_decay=weight_decay, + momentum=momentum, + qk_logits=qk_logits, + ) + + if len(param_dtensors) > 0: + if not dist.is_initialized(): + raise RuntimeError( + "Parallel Muon requires torch.distributed to be initialized." 
+ ) - for _, (names, params) in placement_to_params.items(): + dtensor_group = group_dtensors(param_dtensors, name_dtensors) + for _, (names, params) in dtensor_group.items(): self.parallel( names, params, @@ -1215,6 +1242,7 @@ class Muon(torch.optim.Optimizer): for params in placement_to_params.values(): self._step_adamw_params(params, group) + @torch.no_grad def step(self, closure=None, qk_logits=None): """Perform a single optimization step.