Include original (pre-normalize) FQN in is_muon logging
Browse filesShow both normalized and original parameter names so wrapper-injected
components (_orig_mod, _checkpoint_wrapped_module) are visible in logs.
[skip-build]
Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
torch-ext/optimizer/core.py
CHANGED
|
@@ -92,20 +92,22 @@ def adjust_lr_for_muon(lr, param_shape):
|
|
| 92 |
|
| 93 |
|
| 94 |
def default_is_muon(name, x, expert_keys=None):
|
| 95 |
-
|
|
|
|
| 96 |
skip_keys = ["embed_tokens", "lm_head", "tok_embeddings", "output"]
|
| 97 |
if any(key in parts for key in skip_keys):
|
| 98 |
-
logger.info("[is_muon] %s: skip (matched skip_key), ndim=%d",
|
| 99 |
-
x.ndim)
|
| 100 |
return False
|
| 101 |
effective_ndim = x.ndim
|
| 102 |
is_expert = expert_keys and any(key in parts for key in expert_keys)
|
| 103 |
if is_expert:
|
| 104 |
effective_ndim -= 1
|
| 105 |
result = effective_ndim >= 2
|
| 106 |
-
logger.info(
|
| 107 |
-
|
| 108 |
-
|
|
|
|
| 109 |
return result
|
| 110 |
|
| 111 |
|
|
|
|
| 92 |
|
| 93 |
|
| 94 |
def default_is_muon(name, x, expert_keys=None):
|
| 95 |
+
normalized = normalize_fqn(name)
|
| 96 |
+
parts = normalized.split(".")
|
| 97 |
skip_keys = ["embed_tokens", "lm_head", "tok_embeddings", "output"]
|
| 98 |
if any(key in parts for key in skip_keys):
|
| 99 |
+
logger.info("[is_muon] %s (orig: %s): skip (matched skip_key), ndim=%d",
|
| 100 |
+
normalized, name, x.ndim)
|
| 101 |
return False
|
| 102 |
effective_ndim = x.ndim
|
| 103 |
is_expert = expert_keys and any(key in parts for key in expert_keys)
|
| 104 |
if is_expert:
|
| 105 |
effective_ndim -= 1
|
| 106 |
result = effective_ndim >= 2
|
| 107 |
+
logger.info(
|
| 108 |
+
"[is_muon] %s (orig: %s): ndim=%d, expert=%s, effective_ndim=%d → %s",
|
| 109 |
+
normalized, name, x.ndim, is_expert, effective_ndim,
|
| 110 |
+
"Muon" if result else "AdamW")
|
| 111 |
return result
|
| 112 |
|
| 113 |
|