Extract is_expert_param() helper to consolidate expert key matching
The same normalize_fqn + component-level matching logic existed in both
default_is_muon() and _expand_expert_params(). Extract into a single
is_expert_param() function in core.py so the logic lives in one place.
[skip-build]
Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
- torch-ext/optimizer/core.py +12 -3
- torch-ext/optimizer/muon.py +2 -3
torch-ext/optimizer/core.py  (CHANGED)

@@ -91,16 +91,25 @@ def adjust_lr_for_muon(lr, param_shape):

Before (old lines 91-106; "-" marks removed lines; some removed lines were
truncated in the original rendering and are marked with "..."):

     return adjusted_lr


 def default_is_muon(name, x, expert_keys=None):
     normalized = normalize_fqn(name)
     parts = normalized.split(".")
     skip_keys = ["embed_tokens", "lm_head", "tok_embeddings", "output"]
     if any(key in parts for key in skip_keys):
-        logger.info(
-            ...  # remainder of the old logger.info call (truncated in the original rendering)
         return False
     effective_ndim = x.ndim
-    is_expert = ...  # old inline expert-key matching expression (truncated in the original rendering)
     if is_expert:
         effective_ndim -= 1
     result = effective_ndim >= 2
After (new lines 91-115; "+" marks added lines):

     return adjusted_lr


+def is_expert_param(name, expert_keys):
+    """Check if a parameter name matches any expert key (component-level)."""
+    if not expert_keys:
+        return False
+    parts = normalize_fqn(name).split(".")
+    return any(key in parts for key in expert_keys)
+
+
 def default_is_muon(name, x, expert_keys=None):
     normalized = normalize_fqn(name)
     parts = normalized.split(".")
     skip_keys = ["embed_tokens", "lm_head", "tok_embeddings", "output"]
     if any(key in parts for key in skip_keys):
+        logger.info(
+            "[is_muon] %s (orig: %s): skip (matched skip_key), ndim=%d",
+            normalized, name, x.ndim)
         return False
     effective_ndim = x.ndim
+    is_expert = is_expert_param(name, expert_keys)
     if is_expert:
         effective_ndim -= 1
     result = effective_ndim >= 2
torch-ext/optimizer/muon.py  (CHANGED)

@@ -11,7 +11,7 @@ from torch.profiler import record_function

Before (old lines 11-17; "-" marks removed lines; truncated content marked "..."):

 from .adamw import step_adamw
 from .async_utils import run_pipeline
 from .core import (_muon_state, adjust_lr_for_muon,
-                   get_default_muon_param_groups, ...  # rest of old import line (truncated in the original rendering)
                    update_p)
 from .distributed.utils import (_is_shard, construct_shard_mesh,
                                 get_slices_of_dtensor)

@@ -46,8 +46,7 @@ def _expand_expert_params(names, params, expert_keys):

Before (old lines 46-53):

     expanded_params = []

     for n, p in zip(names, params):
-        is_expert = ...  # old inline match expression (start truncated in the original rendering)
-            for key in expert_keys)
         is_dtensor = isinstance(p.data, DTensor)

         if not is_expert:
After (new lines 11-17; "+" marks added lines):

 from .adamw import step_adamw
 from .async_utils import run_pipeline
 from .core import (_muon_state, adjust_lr_for_muon,
+                   get_default_muon_param_groups, is_expert_param, update_g,
                    update_p)
 from .distributed.utils import (_is_shard, construct_shard_mesh,
                                 get_slices_of_dtensor)

After (new lines 46-52):

     expanded_params = []

     for n, p in zip(names, params):
+        is_expert = is_expert_param(n, expert_keys)
         is_dtensor = isinstance(p.data, DTensor)

         if not is_expert: