Kernels
wyldecat Claude Opus 4.6 committed on
Commit
f008017
·
1 Parent(s): 95a620f

Use component-level matching for expert_keys to avoid shared_experts collision

Browse files

Substring matching (`key in name`) causes "experts" to match
"shared_experts". Switch to dot-split component exact matching
(`key in name.split(".")`) in default_is_muon() and
_expand_expert_params(). Also applies to skip_keys for consistency.

[skip-build]

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>

test/test_normalize_fqn.py CHANGED
@@ -1,7 +1,7 @@
1
  """Unit tests for FQN normalization (no GPU / distributed required)."""
2
 
3
  import pytest
4
- from optimizer.core import normalize_fqn
5
  from optimizer.qk_clip import parse_qk_layer
6
 
7
 
@@ -34,10 +34,12 @@ class TestNormalizeFqn:
34
  class TestParseQkLayerWithWrappers:
35
 
36
  def test_plain_name(self):
37
- assert parse_qk_layer("model.layers.3.attn.q_proj.weight") == ("q_proj", 3)
 
38
 
39
  def test_orig_mod(self):
40
- assert parse_qk_layer("model._orig_mod.layers.3.attn.wq.weight") == ("wq", 3)
 
41
 
42
  def test_checkpoint_wrapped(self):
43
  name = "model.layers.5._checkpoint_wrapped_module.self_attn.k_proj.weight"
@@ -50,3 +52,35 @@ class TestParseQkLayerWithWrappers:
50
  def test_non_qk_still_none(self):
51
  name = "model._orig_mod.layers.2.attn.v_proj.weight"
52
  assert parse_qk_layer(name) == (None, -1)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
  """Unit tests for FQN normalization (no GPU / distributed required)."""
2
 
3
  import pytest
4
+ from optimizer.core import default_is_muon, normalize_fqn
5
  from optimizer.qk_clip import parse_qk_layer
6
 
7
 
 
34
  class TestParseQkLayerWithWrappers:
35
 
36
  def test_plain_name(self):
37
+ assert parse_qk_layer("model.layers.3.attn.q_proj.weight") == (
38
+ "q_proj", 3)
39
 
40
  def test_orig_mod(self):
41
+ assert parse_qk_layer("model._orig_mod.layers.3.attn.wq.weight") == (
42
+ "wq", 3)
43
 
44
  def test_checkpoint_wrapped(self):
45
  name = "model.layers.5._checkpoint_wrapped_module.self_attn.k_proj.weight"
 
52
  def test_non_qk_still_none(self):
53
  name = "model._orig_mod.layers.2.attn.v_proj.weight"
54
  assert parse_qk_layer(name) == (None, -1)
55
+
56
+
57
+ class TestExpertKeyMatching:
58
+ """Verify expert_keys uses component-level matching, not substring."""
59
+
60
+ class FakeParam:
61
+
62
+ def __init__(self, ndim):
63
+ self.ndim = ndim
64
+
65
+ def test_experts_matches(self):
66
+ name = "model.layers.0.moe.experts.w1.weight"
67
+ assert default_is_muon(name,
68
+ self.FakeParam(3),
69
+ expert_keys=["experts"])
70
+
71
+ def test_shared_experts_does_not_match(self):
72
+ name = "model.layers.0.moe.shared_experts.w1.weight"
73
+ # shared_experts has ndim=2, which is muon-eligible on its own.
74
+ # But it must NOT be recognized as expert (ndim-1 would make it 1D → False).
75
+ assert default_is_muon(name,
76
+ self.FakeParam(2),
77
+ expert_keys=["experts"])
78
+
79
+ def test_shared_experts_3d_not_treated_as_expert(self):
80
+ # 3D shared_experts: if wrongly matched as expert, ndim-1=2 → True (same result).
81
+ # Verify by checking that a 2D shared_experts is NOT downgraded to 1D.
82
+ name = "model.layers.0.moe.shared_experts.gate_proj.weight"
83
+ # 2D param: if expert-matched → ndim-1=1 → False. Must stay True.
84
+ assert default_is_muon(name,
85
+ self.FakeParam(2),
86
+ expert_keys=["experts"])
torch-ext/optimizer/core.py CHANGED
@@ -6,7 +6,6 @@ import torch.distributed as dist
6
  from torch.distributed import ProcessGroup
7
  from torch.distributed.tensor import DTensor
8
 
9
-
10
  # torch.compile wraps modules as OptimizedModule, inserting "_orig_mod" into
11
  # parameter FQNs. Activation checkpointing similarly inserts
12
  # "_checkpoint_wrapped_module". Strip these so name-based matching (skip_keys,
@@ -90,12 +89,12 @@ def adjust_lr_for_muon(lr, param_shape):
90
 
91
 
92
  def default_is_muon(name, x, expert_keys=None):
93
- name = normalize_fqn(name)
94
  skip_keys = ["embed_tokens", "lm_head", "tok_embeddings", "output"]
95
- if any(key in name for key in skip_keys):
96
  return False
97
  effective_ndim = x.ndim
98
- if expert_keys and any(key in name for key in expert_keys):
99
  effective_ndim -= 1
100
  return effective_ndim >= 2
101
 
 
6
  from torch.distributed import ProcessGroup
7
  from torch.distributed.tensor import DTensor
8
 
 
9
  # torch.compile wraps modules as OptimizedModule, inserting "_orig_mod" into
10
  # parameter FQNs. Activation checkpointing similarly inserts
11
  # "_checkpoint_wrapped_module". Strip these so name-based matching (skip_keys,
 
89
 
90
 
91
  def default_is_muon(name, x, expert_keys=None):
92
+ parts = normalize_fqn(name).split(".")
93
  skip_keys = ["embed_tokens", "lm_head", "tok_embeddings", "output"]
94
+ if any(key in parts for key in skip_keys):
95
  return False
96
  effective_ndim = x.ndim
97
+ if expert_keys and any(key in parts for key in expert_keys):
98
  effective_ndim -= 1
99
  return effective_ndim >= 2
100
 
torch-ext/optimizer/muon.py CHANGED
@@ -46,8 +46,8 @@ def _expand_expert_params(names, params, expert_keys):
46
  expanded_params = []
47
 
48
  for n, p in zip(names, params):
49
- is_expert = expert_keys and any(
50
- key in normalize_fqn(n) for key in expert_keys)
51
  is_dtensor = isinstance(p.data, DTensor)
52
 
53
  if not is_expert:
 
46
  expanded_params = []
47
 
48
  for n, p in zip(names, params):
49
+ is_expert = expert_keys and any(key in normalize_fqn(n).split(".")
50
+ for key in expert_keys)
51
  is_dtensor = isinstance(p.data, DTensor)
52
 
53
  if not is_expert: