"""Unit tests for FQN normalization (no GPU / distributed required)."""

from optimizer.core import default_is_muon, is_expert_param, normalize_fqn
from optimizer.qk_clip import parse_qk_layer

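# "_orig_mod" is the attribute torch.compile's OptimizedModule gives the
# wrapped model, and "_checkpoint_wrapped_module" comes from PyTorch's
# activation-checkpoint wrapper; both leak into parameter FQNs, which is
# why normalize_fqn must strip them before any name-based matching.
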

class TestNormalizeFqn:

    def test_passthrough(self):
        assert normalize_fqn("model.layers.3.attn.q_proj.weight") == \
            "model.layers.3.attn.q_proj.weight"

    def test_strip_orig_mod(self):
        assert normalize_fqn("model._orig_mod.layers.3.attn.q_proj.weight") == \
            "model.layers.3.attn.q_proj.weight"

    def test_strip_checkpoint_wrapped(self):
        name = "model.layers.0._checkpoint_wrapped_module.moe.experts.w1.weight"
        assert normalize_fqn(name) == \
            "model.layers.0.moe.experts.w1.weight"

    def test_strip_both(self):
        name = "model._orig_mod.layers.0._checkpoint_wrapped_module.attn.q_proj.weight"
        assert normalize_fqn(name) == \
            "model.layers.0.attn.q_proj.weight"

    def test_strip_nested_orig_mod(self):
        name = "_orig_mod._orig_mod.layers.0.mlp.gate_proj.weight"
        assert normalize_fqn(name) == \
            "layers.0.mlp.gate_proj.weight"

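
# The tests above pin down normalize_fqn's contract: drop every wrapper
# component from the dotted name, wherever and however often it appears.
# A minimal sketch of that contract (hypothetical helper for illustration,
# not the optimizer.core implementation):
_WRAPPER_COMPONENTS = {"_orig_mod", "_checkpoint_wrapped_module"}


def _normalize_fqn_sketch(name):
    """Drop wrapper components; mirrors what TestNormalizeFqn expects."""
    return ".".join(p for p in name.split(".") if p not in _WRAPPER_COMPONENTS)
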

class TestParseQkLayerWithWrappers:

    def test_plain_name(self):
        assert parse_qk_layer("model.layers.3.attn.q_proj.weight") == (
            "q_proj", 3)

    def test_orig_mod(self):
        assert parse_qk_layer("model._orig_mod.layers.3.attn.wq.weight") == (
            "wq", 3)

    def test_checkpoint_wrapped(self):
        name = "model.layers.5._checkpoint_wrapped_module.self_attn.k_proj.weight"
        assert parse_qk_layer(name) == ("k_proj", 5)

    def test_both_wrappers(self):
        name = "_orig_mod.model._checkpoint_wrapped_module.layers.7.attn.wk.weight"
        assert parse_qk_layer(name) == ("wk", 7)

    def test_non_qk_still_none(self):
        name = "model._orig_mod.layers.2.attn.v_proj.weight"
        assert parse_qk_layer(name) == (None, -1)

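
# From the cases above, parse_qk_layer pairs a q/k projection weight with its
# "layers.<i>" index and returns (None, -1) for anything else. A rough sketch
# under those assumptions (the set of recognized projection names below is
# inferred from the tests; the real optimizer.qk_clip matcher may accept more):
_QK_COMPONENTS = {"q_proj", "wq", "k_proj", "wk"}


def _parse_qk_layer_sketch(name):
    parts = normalize_fqn(name).split(".")
    for i, part in enumerate(parts):
        if part == "layers" and i + 1 < len(parts) and parts[i + 1].isdigit():
            layer = int(parts[i + 1])
            # Look for a recognized q/k projection after the layer index.
            for comp in parts[i + 2:]:
                if comp in _QK_COMPONENTS:
                    return (comp, layer)
    return (None, -1)
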

class TestExpertKeyMatching:
    """Verify expert_keys uses component-level matching, not substring."""

    class FakeParam:

        def __init__(self, ndim):
            self.ndim = ndim

    def test_experts_matches(self):
        name = "model.layers.0.moe.experts.w1.weight"
        assert default_is_muon(name,
                               self.FakeParam(3),
                               expert_keys=["experts"])

    def test_shared_experts_does_not_match(self):
        name = "model.layers.0.moe.shared_experts.w1.weight"
        # shared_experts has ndim=2, which is Muon-eligible on its own, but it
        # must NOT be matched as an expert (expert params use ndim-1, and the
        # resulting 1D shape would make this False).
        assert default_is_muon(name,
                               self.FakeParam(2),
                               expert_keys=["experts"])

    def test_shared_experts_2d_not_downgraded_to_1d(self):
        # A 3D shared_experts param would be inconclusive: wrongly matching it
        # as an expert gives ndim-1 = 2 → True, the same outcome as not
        # matching. A 2D param is decisive: if it were expert-matched,
        # ndim-1 = 1 → False, so a True result proves it was not matched.
        name = "model.layers.0.moe.shared_experts.gate_proj.weight"
        # 2D param: if expert-matched → ndim-1=1 → False. Must stay True.
        assert default_is_muon(name,
                               self.FakeParam(2),
                               expert_keys=["experts"])

    def test_multi_component_key_matches(self):
        name = "model.layers.0.moe.experts.w1.weight"
        assert is_expert_param(name, expert_keys=["experts.w1"])

    def test_multi_component_key_no_false_positive(self):
        # "experts.w2" should not match "experts.w1"
        name = "model.layers.0.moe.experts.w1.weight"
        assert not is_expert_param(name, expert_keys=["experts.w2"])

    def test_multi_component_key_shared_experts(self):
        name = "model.layers.0.moe.shared_experts.w1.weight"
        assert not is_expert_param(name, expert_keys=["experts.w1"])
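

# Component-level matching, as exercised above, aligns each expert key with
# whole dotted components, so "experts" matches "moe.experts.w1" but never
# "moe.shared_experts.w1". A hypothetical sketch of the two predicates under
# test (ignoring the embedding/norm exclusions a real is-Muon check would
# also carry):


def _is_expert_param_sketch(name, expert_keys):
    parts = normalize_fqn(name).split(".")
    for key in expert_keys:
        kp = key.split(".")
        # A key matches only if its components appear consecutively.
        if any(parts[i:i + len(kp)] == kp
               for i in range(len(parts) - len(kp) + 1)):
            return True
    return False


def _default_is_muon_sketch(name, param, expert_keys=()):
    # Expert weights are stacked along a leading expert dim, so the
    # per-expert matrix has one fewer dimension than the stored tensor.
    ndim = param.ndim - (1 if _is_expert_param_sketch(name, expert_keys) else 0)
    return ndim == 2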