Normalize parameter FQNs to handle torch.compile / checkpoint wrappers
torch.compile wraps modules as OptimizedModule, inserting _orig_mod into
parameter FQNs. Activation checkpointing similarly inserts
_checkpoint_wrapped_module. These wrapper components break name-based
matching for skip_keys, expert_keys, and QK layer parsing.
Add normalize_fqn() that strips these wrapper components, and apply it
in default_is_muon(), _expand_expert_params(), and parse_qk_layer().
[skip-build]
Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
- test/test_normalize_fqn.py +52 -0
- torch-ext/optimizer/core.py +13 -0
- torch-ext/optimizer/muon.py +4 -2
- torch-ext/optimizer/qk_clip.py +3 -1
test/test_normalize_fqn.py
ADDED
|
@@ -0,0 +1,52 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Unit tests for FQN normalization (no GPU / distributed required)."""
|
| 2 |
+
|
| 3 |
+
import pytest
|
| 4 |
+
from optimizer.core import normalize_fqn
|
| 5 |
+
from optimizer.qk_clip import parse_qk_layer
|
| 6 |
+
|
| 7 |
+
|
| 8 |
+
class TestNormalizeFqn:
    """normalize_fqn() must strip wrapper path components and leave real FQNs intact."""

    def test_passthrough(self):
        # A name with no wrapper components comes back unchanged.
        fqn = "model.layers.3.attn.q_proj.weight"
        assert normalize_fqn(fqn) == fqn

    def test_strip_orig_mod(self):
        # torch.compile inserts "_orig_mod"; it must be removed.
        result = normalize_fqn("model._orig_mod.layers.3.attn.q_proj.weight")
        assert result == "model.layers.3.attn.q_proj.weight"

    def test_strip_checkpoint_wrapped(self):
        # Activation checkpointing inserts "_checkpoint_wrapped_module".
        wrapped = "model.layers.0._checkpoint_wrapped_module.moe.experts.w1.weight"
        assert normalize_fqn(wrapped) == "model.layers.0.moe.experts.w1.weight"

    def test_strip_both(self):
        # Both wrappers may appear in one FQN; both are stripped.
        wrapped = "model._orig_mod.layers.0._checkpoint_wrapped_module.attn.q_proj.weight"
        assert normalize_fqn(wrapped) == "model.layers.0.attn.q_proj.weight"

    def test_strip_nested_orig_mod(self):
        # Repeated compile wrapping nests "_orig_mod"; every occurrence goes.
        wrapped = "_orig_mod._orig_mod.layers.0.mlp.gate_proj.weight"
        assert normalize_fqn(wrapped) == "layers.0.mlp.gate_proj.weight"
| 34 |
+
class TestParseQkLayerWithWrappers:
    """parse_qk_layer() must parse (proj_name, layer_idx) regardless of wrapper nesting."""

    def test_plain_name(self):
        # Unwrapped FQN: baseline behavior.
        assert parse_qk_layer("model.layers.3.attn.q_proj.weight") == ("q_proj", 3)

    def test_orig_mod(self):
        # "_orig_mod" from torch.compile must not shift the parsed layer index.
        assert parse_qk_layer("model._orig_mod.layers.3.attn.wq.weight") == ("wq", 3)

    def test_checkpoint_wrapped(self):
        # Checkpoint wrapper component between layer index and projection.
        fqn = "model.layers.5._checkpoint_wrapped_module.self_attn.k_proj.weight"
        assert parse_qk_layer(fqn) == ("k_proj", 5)

    def test_both_wrappers(self):
        # Both wrapper kinds present at once.
        fqn = "_orig_mod.model._checkpoint_wrapped_module.layers.7.attn.wk.weight"
        assert parse_qk_layer(fqn) == ("wk", 7)

    def test_non_qk_still_none(self):
        # Non-Q/K projections still yield the (None, -1) sentinel.
        fqn = "model._orig_mod.layers.2.attn.v_proj.weight"
        assert parse_qk_layer(fqn) == (None, -1)
|
torch-ext/optimizer/core.py
CHANGED
|
@@ -7,6 +7,18 @@ from torch.distributed import ProcessGroup
|
|
| 7 |
from torch.distributed.tensor import DTensor
|
| 8 |
|
| 9 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 10 |
@dataclass
|
| 11 |
class _muon_state:
|
| 12 |
worker_rank: int
|
|
@@ -78,6 +90,7 @@ def adjust_lr_for_muon(lr, param_shape):
|
|
| 78 |
|
| 79 |
|
| 80 |
def default_is_muon(name, x, expert_keys=None):
|
|
|
|
| 81 |
skip_keys = ["embed_tokens", "lm_head", "tok_embeddings", "output"]
|
| 82 |
if any(key in name for key in skip_keys):
|
| 83 |
return False
|
|
|
|
| 7 |
from torch.distributed.tensor import DTensor
|
| 8 |
|
| 9 |
|
| 10 |
+
# torch.compile wraps modules as OptimizedModule, inserting "_orig_mod" into
# parameter FQNs. Activation checkpointing similarly inserts
# "_checkpoint_wrapped_module". Strip these so name-based matching (skip_keys,
# expert_keys, QK layer parsing) works regardless of wrapper nesting.
_WRAPPER_PARTS = frozenset({"_orig_mod", "_checkpoint_wrapped_module"})


def normalize_fqn(name: str) -> str:
    """Strip torch.compile / checkpoint wrapper components from a parameter FQN."""
    kept = [part for part in name.split(".") if part not in _WRAPPER_PARTS]
    return ".".join(kept)
|
| 20 |
+
|
| 21 |
+
|
| 22 |
@dataclass
|
| 23 |
class _muon_state:
|
| 24 |
worker_rank: int
|
|
|
|
| 90 |
|
| 91 |
|
| 92 |
def default_is_muon(name, x, expert_keys=None):
|
| 93 |
+
name = normalize_fqn(name)
|
| 94 |
skip_keys = ["embed_tokens", "lm_head", "tok_embeddings", "output"]
|
| 95 |
if any(key in name for key in skip_keys):
|
| 96 |
return False
|
torch-ext/optimizer/muon.py
CHANGED
|
@@ -11,7 +11,8 @@ from torch.profiler import record_function
|
|
| 11 |
from .adamw import step_adamw
|
| 12 |
from .async_utils import run_pipeline
|
| 13 |
from .core import (_muon_state, adjust_lr_for_muon,
|
| 14 |
-
get_default_muon_param_groups,
|
|
|
|
| 15 |
from .distributed.utils import (_is_shard, construct_shard_mesh,
|
| 16 |
get_slices_of_dtensor)
|
| 17 |
from .newton_schulz import (COMM_DTYPE, DEFAULT_CHUNK_SIZE_RATIO,
|
|
@@ -45,7 +46,8 @@ def _expand_expert_params(names, params, expert_keys):
|
|
| 45 |
expanded_params = []
|
| 46 |
|
| 47 |
for n, p in zip(names, params):
|
| 48 |
-
is_expert = expert_keys and any(
|
|
|
|
| 49 |
is_dtensor = isinstance(p.data, DTensor)
|
| 50 |
|
| 51 |
if not is_expert:
|
|
|
|
| 11 |
from .adamw import step_adamw
|
| 12 |
from .async_utils import run_pipeline
|
| 13 |
from .core import (_muon_state, adjust_lr_for_muon,
|
| 14 |
+
get_default_muon_param_groups, normalize_fqn, update_g,
|
| 15 |
+
update_p)
|
| 16 |
from .distributed.utils import (_is_shard, construct_shard_mesh,
|
| 17 |
get_slices_of_dtensor)
|
| 18 |
from .newton_schulz import (COMM_DTYPE, DEFAULT_CHUNK_SIZE_RATIO,
|
|
|
|
| 46 |
expanded_params = []
|
| 47 |
|
| 48 |
for n, p in zip(names, params):
|
| 49 |
+
is_expert = expert_keys and any(
|
| 50 |
+
key in normalize_fqn(n) for key in expert_keys)
|
| 51 |
is_dtensor = isinstance(p.data, DTensor)
|
| 52 |
|
| 53 |
if not is_expert:
|
torch-ext/optimizer/qk_clip.py
CHANGED
|
@@ -5,6 +5,8 @@ from dataclasses import dataclass
|
|
| 5 |
import torch
|
| 6 |
from torch.distributed.tensor import DTensor
|
| 7 |
|
|
|
|
|
|
|
| 8 |
logger = logging.getLogger(__name__)
|
| 9 |
|
| 10 |
|
|
@@ -23,7 +25,7 @@ def parse_qk_layer(name: str) -> tuple[str | None, int]:
|
|
| 23 |
'model.7.attn.k_proj.weight' -> ('k_proj', 7)
|
| 24 |
'model.4.attn.v_proj.weight' -> (None, -1)
|
| 25 |
"""
|
| 26 |
-
parts = name.split('.')
|
| 27 |
if len(parts) < 3:
|
| 28 |
return None, -1
|
| 29 |
|
|
|
|
| 5 |
import torch
|
| 6 |
from torch.distributed.tensor import DTensor
|
| 7 |
|
| 8 |
+
from .core import normalize_fqn
|
| 9 |
+
|
| 10 |
logger = logging.getLogger(__name__)
|
| 11 |
|
| 12 |
|
|
|
|
| 25 |
'model.7.attn.k_proj.weight' -> ('k_proj', 7)
|
| 26 |
'model.4.attn.v_proj.weight' -> (None, -1)
|
| 27 |
"""
|
| 28 |
+
parts = normalize_fqn(name).split('.')
|
| 29 |
if len(parts) < 3:
|
| 30 |
return None, -1
|
| 31 |
|