Use component-level matching for expert_keys to avoid shared_experts collision
Substring matching (`key in name`) causes "experts" to match
"shared_experts". Switch to dot-split component exact matching
(`key in name.split(".")`) in default_is_muon() and
_expand_expert_params(). Also applies to skip_keys for consistency.
[skip-build]
Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
- test/test_normalize_fqn.py +37 -3
- torch-ext/optimizer/core.py +3 -4
- torch-ext/optimizer/muon.py +2 -2
test/test_normalize_fqn.py
CHANGED
|
@@ -1,7 +1,7 @@
|
|
| 1 |
"""Unit tests for FQN normalization (no GPU / distributed required)."""
|
| 2 |
|
| 3 |
import pytest
|
| 4 |
-
from optimizer.core import normalize_fqn
|
| 5 |
from optimizer.qk_clip import parse_qk_layer
|
| 6 |
|
| 7 |
|
|
@@ -34,10 +34,12 @@ class TestNormalizeFqn:
|
|
| 34 |
class TestParseQkLayerWithWrappers:
|
| 35 |
|
| 36 |
def test_plain_name(self):
|
| 37 |
-
assert parse_qk_layer("model.layers.3.attn.q_proj.weight") == (
|
|
|
|
| 38 |
|
| 39 |
def test_orig_mod(self):
|
| 40 |
-
assert parse_qk_layer("model._orig_mod.layers.3.attn.wq.weight") == (
|
|
|
|
| 41 |
|
| 42 |
def test_checkpoint_wrapped(self):
|
| 43 |
name = "model.layers.5._checkpoint_wrapped_module.self_attn.k_proj.weight"
|
|
@@ -50,3 +52,35 @@ class TestParseQkLayerWithWrappers:
|
|
| 50 |
def test_non_qk_still_none(self):
|
| 51 |
name = "model._orig_mod.layers.2.attn.v_proj.weight"
|
| 52 |
assert parse_qk_layer(name) == (None, -1)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
"""Unit tests for FQN normalization (no GPU / distributed required)."""
|
| 2 |
|
| 3 |
import pytest
|
| 4 |
+
from optimizer.core import default_is_muon, normalize_fqn
|
| 5 |
from optimizer.qk_clip import parse_qk_layer
|
| 6 |
|
| 7 |
|
|
|
|
| 34 |
class TestParseQkLayerWithWrappers:
|
| 35 |
|
| 36 |
def test_plain_name(self):
    """An unwrapped FQN yields the projection name and layer index."""
    fqn = "model.layers.3.attn.q_proj.weight"
    assert parse_qk_layer(fqn) == ("q_proj", 3)
|
| 39 |
|
| 40 |
def test_orig_mod(self):
    """The torch.compile `_orig_mod` wrapper is stripped before parsing."""
    fqn = "model._orig_mod.layers.3.attn.wq.weight"
    assert parse_qk_layer(fqn) == ("wq", 3)
|
| 43 |
|
| 44 |
def test_checkpoint_wrapped(self):
|
| 45 |
name = "model.layers.5._checkpoint_wrapped_module.self_attn.k_proj.weight"
|
|
|
|
| 52 |
def test_non_qk_still_none(self):
    """A non-q/k projection stays unrecognized even under wrappers."""
    fqn = "model._orig_mod.layers.2.attn.v_proj.weight"
    assert parse_qk_layer(fqn) == (None, -1)
|
| 55 |
+
|
| 56 |
+
|
| 57 |
+
class TestExpertKeyMatching:
|
| 58 |
+
"""Verify expert_keys uses component-level matching, not substring."""
|
| 59 |
+
|
| 60 |
+
class FakeParam:
|
| 61 |
+
|
| 62 |
+
def __init__(self, ndim):
|
| 63 |
+
self.ndim = ndim
|
| 64 |
+
|
| 65 |
+
def test_experts_matches(self):
|
| 66 |
+
name = "model.layers.0.moe.experts.w1.weight"
|
| 67 |
+
assert default_is_muon(name,
|
| 68 |
+
self.FakeParam(3),
|
| 69 |
+
expert_keys=["experts"])
|
| 70 |
+
|
| 71 |
+
def test_shared_experts_does_not_match(self):
|
| 72 |
+
name = "model.layers.0.moe.shared_experts.w1.weight"
|
| 73 |
+
# shared_experts has ndim=2, which is muon-eligible on its own.
|
| 74 |
+
# But it must NOT be recognized as expert (ndim-1 would make it 1D → False).
|
| 75 |
+
assert default_is_muon(name,
|
| 76 |
+
self.FakeParam(2),
|
| 77 |
+
expert_keys=["experts"])
|
| 78 |
+
|
| 79 |
+
def test_shared_experts_3d_not_treated_as_expert(self):
|
| 80 |
+
# 3D shared_experts: if wrongly matched as expert, ndim-1=2 → True (same result).
|
| 81 |
+
# Verify by checking that a 2D shared_experts is NOT downgraded to 1D.
|
| 82 |
+
name = "model.layers.0.moe.shared_experts.gate_proj.weight"
|
| 83 |
+
# 2D param: if expert-matched → ndim-1=1 → False. Must stay True.
|
| 84 |
+
assert default_is_muon(name,
|
| 85 |
+
self.FakeParam(2),
|
| 86 |
+
expert_keys=["experts"])
|
torch-ext/optimizer/core.py
CHANGED
|
@@ -6,7 +6,6 @@ import torch.distributed as dist
|
|
| 6 |
from torch.distributed import ProcessGroup
|
| 7 |
from torch.distributed.tensor import DTensor
|
| 8 |
|
| 9 |
-
|
| 10 |
# torch.compile wraps modules as OptimizedModule, inserting "_orig_mod" into
|
| 11 |
# parameter FQNs. Activation checkpointing similarly inserts
|
| 12 |
# "_checkpoint_wrapped_module". Strip these so name-based matching (skip_keys,
|
|
@@ -90,12 +89,12 @@ def adjust_lr_for_muon(lr, param_shape):
|
|
| 90 |
|
| 91 |
|
| 92 |
def default_is_muon(name, x, expert_keys=None):
|
| 93 |
-
|
| 94 |
skip_keys = ["embed_tokens", "lm_head", "tok_embeddings", "output"]
|
| 95 |
-
if any(key in
|
| 96 |
return False
|
| 97 |
effective_ndim = x.ndim
|
| 98 |
-
if expert_keys and any(key in
|
| 99 |
effective_ndim -= 1
|
| 100 |
return effective_ndim >= 2
|
| 101 |
|
|
|
|
| 6 |
from torch.distributed import ProcessGroup
|
| 7 |
from torch.distributed.tensor import DTensor
|
| 8 |
|
|
|
|
| 9 |
# torch.compile wraps modules as OptimizedModule, inserting "_orig_mod" into
|
| 10 |
# parameter FQNs. Activation checkpointing similarly inserts
|
| 11 |
# "_checkpoint_wrapped_module". Strip these so name-based matching (skip_keys,
|
|
|
|
| 89 |
|
| 90 |
|
| 91 |
def default_is_muon(name, x, expert_keys=None):
    """Return True when parameter ``x`` (with FQN ``name``) should use Muon.

    Embedding/output parameters are never muon-eligible. Stacked expert
    weights (matched via ``expert_keys``) carry a leading expert dimension,
    so one dim is discounted before the >= 2-D eligibility check. Matching
    is done on exact dot-separated FQN components (after wrapper-prefix
    normalization), so "experts" does not match "shared_experts".
    """
    components = normalize_fqn(name).split(".")
    skip_keys = ["embed_tokens", "lm_head", "tok_embeddings", "output"]
    for key in skip_keys:
        if key in components:
            return False
    ndim = x.ndim
    if expert_keys:
        if any(key in components for key in expert_keys):
            # Leading dim indexes experts, not matrix rows.
            ndim -= 1
    return ndim >= 2
|
| 100 |
|
torch-ext/optimizer/muon.py
CHANGED
|
@@ -46,8 +46,8 @@ def _expand_expert_params(names, params, expert_keys):
|
|
| 46 |
expanded_params = []
|
| 47 |
|
| 48 |
for n, p in zip(names, params):
|
| 49 |
-
is_expert = expert_keys and any(
|
| 50 |
-
|
| 51 |
is_dtensor = isinstance(p.data, DTensor)
|
| 52 |
|
| 53 |
if not is_expert:
|
|
|
|
| 46 |
expanded_params = []
|
| 47 |
|
| 48 |
for n, p in zip(names, params):
|
| 49 |
+
is_expert = expert_keys and any(key in normalize_fqn(n).split(".")
|
| 50 |
+
for key in expert_keys)
|
| 51 |
is_dtensor = isinstance(p.data, DTensor)
|
| 52 |
|
| 53 |
if not is_expert:
|