Kernels
wyldecat Claude Opus 4.6 commited on
Commit
5a99e12
·
1 Parent(s): e615b1c

Support multi-component expert_keys (e.g. "experts.w1")

Browse files

Single-component keys match any single FQN component as before.
Multi-component keys (containing dots) now match as a contiguous
subsequence, so "experts.w1" matches "moe.experts.w1.weight" but
not "moe.shared_experts.w1.weight".

[skip-build]

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>

test/test_normalize_fqn.py CHANGED
@@ -1,7 +1,7 @@
1
  """Unit tests for FQN normalization (no GPU / distributed required)."""
2
 
3
  import pytest
4
- from optimizer.core import default_is_muon, normalize_fqn
5
  from optimizer.qk_clip import parse_qk_layer
6
 
7
 
@@ -84,3 +84,16 @@ class TestExpertKeyMatching:
84
  assert default_is_muon(name,
85
  self.FakeParam(2),
86
  expert_keys=["experts"])
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
  """Unit tests for FQN normalization (no GPU / distributed required)."""
2
 
3
  import pytest
4
+ from optimizer.core import default_is_muon, is_expert_param, normalize_fqn
5
  from optimizer.qk_clip import parse_qk_layer
6
 
7
 
 
84
  assert default_is_muon(name,
85
  self.FakeParam(2),
86
  expert_keys=["experts"])
87
+
88
+ def test_multi_component_key_matches(self):
89
+ name = "model.layers.0.moe.experts.w1.weight"
90
+ assert is_expert_param(name, expert_keys=["experts.w1"])
91
+
92
+ def test_multi_component_key_no_false_positive(self):
93
+ # "experts.w2" should not match "experts.w1"
94
+ name = "model.layers.0.moe.experts.w1.weight"
95
+ assert not is_expert_param(name, expert_keys=["experts.w2"])
96
+
97
+ def test_multi_component_key_shared_experts(self):
98
+ name = "model.layers.0.moe.shared_experts.w1.weight"
99
+ assert not is_expert_param(name, expert_keys=["experts.w1"])
torch-ext/optimizer/core.py CHANGED
@@ -91,12 +91,26 @@ def adjust_lr_for_muon(lr, param_shape):
91
  return adjusted_lr
92
 
93
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
94
  def is_expert_param(name, expert_keys):
95
  """Check if a parameter name matches any expert key (component-level)."""
96
  if not expert_keys:
97
  return False
98
  parts = normalize_fqn(name).split(".")
99
- return any(key in parts for key in expert_keys)
100
 
101
 
102
  def default_is_muon(name, x, expert_keys=None):
 
91
  return adjusted_lr
92
 
93
 
94
+ def _match_key(parts, key):
95
+ """Check if key matches as contiguous components in parts.
96
+
97
+ Single-component keys (e.g. "experts") match any single component.
98
+ Multi-component keys (e.g. "experts.w1") match as a contiguous subsequence.
99
+ """
100
+ key_parts = key.split(".")
101
+ key_len = len(key_parts)
102
+ if key_len == 1:
103
+ return key in parts
104
+ return any(parts[i:i + key_len] == key_parts
105
+ for i in range(len(parts) - key_len + 1))
106
+
107
+
108
  def is_expert_param(name, expert_keys):
109
  """Check if a parameter name matches any expert key (component-level)."""
110
  if not expert_keys:
111
  return False
112
  parts = normalize_fqn(name).split(".")
113
+ return any(_match_key(parts, key) for key in expert_keys)
114
 
115
 
116
  def default_is_muon(name, x, expert_keys=None):