Kernels
wyldecat Claude Opus 4.6 committed on
Commit
e615b1c
·
1 Parent(s): 135fc66

Extract is_expert_param() helper to consolidate expert key matching

Browse files

The same normalize_fqn + component-level matching logic existed in both
default_is_muon() and _expand_expert_params(). Extract into a single
is_expert_param() function in core.py so the logic lives in one place.

[skip-build]

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>

torch-ext/optimizer/core.py CHANGED
@@ -91,16 +91,25 @@ def adjust_lr_for_muon(lr, param_shape):
91
  return adjusted_lr
92
 
93
 
 
 
 
 
 
 
 
 
94
  def default_is_muon(name, x, expert_keys=None):
95
  normalized = normalize_fqn(name)
96
  parts = normalized.split(".")
97
  skip_keys = ["embed_tokens", "lm_head", "tok_embeddings", "output"]
98
  if any(key in parts for key in skip_keys):
99
- logger.info("[is_muon] %s (orig: %s): skip (matched skip_key), ndim=%d",
100
- normalized, name, x.ndim)
 
101
  return False
102
  effective_ndim = x.ndim
103
- is_expert = expert_keys and any(key in parts for key in expert_keys)
104
  if is_expert:
105
  effective_ndim -= 1
106
  result = effective_ndim >= 2
 
91
  return adjusted_lr
92
 
93
 
94
+ def is_expert_param(name, expert_keys):
95
+ """Check if a parameter name matches any expert key (component-level)."""
96
+ if not expert_keys:
97
+ return False
98
+ parts = normalize_fqn(name).split(".")
99
+ return any(key in parts for key in expert_keys)
100
+
101
+
102
  def default_is_muon(name, x, expert_keys=None):
103
  normalized = normalize_fqn(name)
104
  parts = normalized.split(".")
105
  skip_keys = ["embed_tokens", "lm_head", "tok_embeddings", "output"]
106
  if any(key in parts for key in skip_keys):
107
+ logger.info(
108
+ "[is_muon] %s (orig: %s): skip (matched skip_key), ndim=%d",
109
+ normalized, name, x.ndim)
110
  return False
111
  effective_ndim = x.ndim
112
+ is_expert = is_expert_param(name, expert_keys)
113
  if is_expert:
114
  effective_ndim -= 1
115
  result = effective_ndim >= 2
torch-ext/optimizer/muon.py CHANGED
@@ -11,7 +11,7 @@ from torch.profiler import record_function
11
  from .adamw import step_adamw
12
  from .async_utils import run_pipeline
13
  from .core import (_muon_state, adjust_lr_for_muon,
14
- get_default_muon_param_groups, normalize_fqn, update_g,
15
  update_p)
16
  from .distributed.utils import (_is_shard, construct_shard_mesh,
17
  get_slices_of_dtensor)
@@ -46,8 +46,7 @@ def _expand_expert_params(names, params, expert_keys):
46
  expanded_params = []
47
 
48
  for n, p in zip(names, params):
49
- is_expert = expert_keys and any(key in normalize_fqn(n).split(".")
50
- for key in expert_keys)
51
  is_dtensor = isinstance(p.data, DTensor)
52
 
53
  if not is_expert:
 
11
  from .adamw import step_adamw
12
  from .async_utils import run_pipeline
13
  from .core import (_muon_state, adjust_lr_for_muon,
14
+ get_default_muon_param_groups, is_expert_param, update_g,
15
  update_p)
16
  from .distributed.utils import (_is_shard, construct_shard_mesh,
17
  get_slices_of_dtensor)
 
46
  expanded_params = []
47
 
48
  for n, p in zip(names, params):
49
+ is_expert = is_expert_param(n, expert_keys)
 
50
  is_dtensor = isinstance(p.data, DTensor)
51
 
52
  if not is_expert: