Add info-level logging for param group classification (Muon vs AdamW)
Log each parameter's skip/expert status and effective ndim in
default_is_muon(), and summarize Muon/AdamW param lists in
get_default_muon_param_groups().
[skip-build]
Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
- torch-ext/optimizer/core.py +20 -3
torch-ext/optimizer/core.py
CHANGED
|
@@ -1,3 +1,4 @@
|
|
|
|
|
| 1 |
import math
|
| 2 |
from dataclasses import dataclass
|
| 3 |
|
|
@@ -12,6 +13,8 @@ from torch.distributed.tensor import DTensor
|
|
| 12 |
# expert_keys, QK layer parsing) works regardless of wrapper nesting.
|
| 13 |
_WRAPPER_PARTS = frozenset({"_orig_mod", "_checkpoint_wrapped_module"})
|
| 14 |
|
|
|
|
|
|
|
| 15 |
|
| 16 |
def normalize_fqn(name: str) -> str:
|
| 17 |
"""Strip torch.compile / checkpoint wrapper components from a parameter FQN."""
|
|
@@ -92,11 +95,18 @@ def default_is_muon(name, x, expert_keys=None):
|
|
| 92 |
parts = normalize_fqn(name).split(".")
|
| 93 |
skip_keys = ["embed_tokens", "lm_head", "tok_embeddings", "output"]
|
| 94 |
if any(key in parts for key in skip_keys):
|
|
|
|
|
|
|
| 95 |
return False
|
| 96 |
effective_ndim = x.ndim
|
| 97 |
-
|
|
|
|
| 98 |
effective_ndim -= 1
|
| 99 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
| 100 |
|
| 101 |
|
| 102 |
def get_default_muon_param_groups(model, is_muon_func=None, expert_keys=None):
|
|
@@ -104,7 +114,7 @@ def get_default_muon_param_groups(model, is_muon_func=None, expert_keys=None):
|
|
| 104 |
is_muon_func = lambda n, x: default_is_muon(n, x, expert_keys)
|
| 105 |
|
| 106 |
muon_params, muon_names = [], []
|
| 107 |
-
non_muon_params = []
|
| 108 |
|
| 109 |
for n, p in model.named_parameters():
|
| 110 |
if not p.requires_grad:
|
|
@@ -114,6 +124,13 @@ def get_default_muon_param_groups(model, is_muon_func=None, expert_keys=None):
|
|
| 114 |
muon_names.append(n)
|
| 115 |
else:
|
| 116 |
non_muon_params.append(p)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 117 |
|
| 118 |
return [
|
| 119 |
{
|
|
|
|
| 1 |
+
import logging
|
| 2 |
import math
|
| 3 |
from dataclasses import dataclass
|
| 4 |
|
|
|
|
| 13 |
# expert_keys, QK layer parsing) works regardless of wrapper nesting.
|
| 14 |
_WRAPPER_PARTS = frozenset({"_orig_mod", "_checkpoint_wrapped_module"})
|
| 15 |
|
| 16 |
+
logger = logging.getLogger(__name__)
|
| 17 |
+
|
| 18 |
|
| 19 |
def normalize_fqn(name: str) -> str:
|
| 20 |
"""Strip torch.compile / checkpoint wrapper components from a parameter FQN."""
|
|
|
|
| 95 |
parts = normalize_fqn(name).split(".")
|
| 96 |
skip_keys = ["embed_tokens", "lm_head", "tok_embeddings", "output"]
|
| 97 |
if any(key in parts for key in skip_keys):
|
| 98 |
+
logger.info("[is_muon] %s: skip (matched skip_key), ndim=%d", name,
|
| 99 |
+
x.ndim)
|
| 100 |
return False
|
| 101 |
effective_ndim = x.ndim
|
| 102 |
+
is_expert = expert_keys and any(key in parts for key in expert_keys)
|
| 103 |
+
if is_expert:
|
| 104 |
effective_ndim -= 1
|
| 105 |
+
result = effective_ndim >= 2
|
| 106 |
+
logger.info("[is_muon] %s: ndim=%d, expert=%s, effective_ndim=%d → %s",
|
| 107 |
+
name, x.ndim, is_expert, effective_ndim,
|
| 108 |
+
"Muon" if result else "AdamW")
|
| 109 |
+
return result
|
| 110 |
|
| 111 |
|
| 112 |
def get_default_muon_param_groups(model, is_muon_func=None, expert_keys=None):
|
|
|
|
| 114 |
is_muon_func = lambda n, x: default_is_muon(n, x, expert_keys)
|
| 115 |
|
| 116 |
muon_params, muon_names = [], []
|
| 117 |
+
non_muon_params, non_muon_names = [], []
|
| 118 |
|
| 119 |
for n, p in model.named_parameters():
|
| 120 |
if not p.requires_grad:
|
|
|
|
| 124 |
muon_names.append(n)
|
| 125 |
else:
|
| 126 |
non_muon_params.append(p)
|
| 127 |
+
non_muon_names.append(n)
|
| 128 |
+
|
| 129 |
+
logger.info("[param_groups] expert_keys=%s", expert_keys)
|
| 130 |
+
logger.info("[param_groups] Muon params (%d): %s", len(muon_names),
|
| 131 |
+
muon_names)
|
| 132 |
+
logger.info("[param_groups] AdamW params (%d): %s", len(non_muon_names),
|
| 133 |
+
non_muon_names)
|
| 134 |
|
| 135 |
return [
|
| 136 |
{
|