Kernels
wyldecat Claude Opus 4.6 committed on
Commit
95a620f
·
1 Parent(s): b220459

Normalize parameter FQNs to handle torch.compile / checkpoint wrappers

Browse files

torch.compile wraps modules as OptimizedModule, inserting _orig_mod into
parameter FQNs. Activation checkpointing similarly inserts
_checkpoint_wrapped_module. These wrapper components break name-based
matching for skip_keys, expert_keys, and QK layer parsing.

Add normalize_fqn() that strips these wrapper components, and apply it
in default_is_muon(), _expand_expert_params(), and parse_qk_layer().

[skip-build]

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>

test/test_normalize_fqn.py ADDED
@@ -0,0 +1,52 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Unit tests for FQN normalization (no GPU / distributed required)."""
2
+
3
+ import pytest
4
+ from optimizer.core import normalize_fqn
5
+ from optimizer.qk_clip import parse_qk_layer
6
+
7
+
8
+ class TestNormalizeFqn:
9
+
10
+ def test_passthrough(self):
11
+ assert normalize_fqn("model.layers.3.attn.q_proj.weight") == \
12
+ "model.layers.3.attn.q_proj.weight"
13
+
14
+ def test_strip_orig_mod(self):
15
+ assert normalize_fqn("model._orig_mod.layers.3.attn.q_proj.weight") == \
16
+ "model.layers.3.attn.q_proj.weight"
17
+
18
+ def test_strip_checkpoint_wrapped(self):
19
+ name = "model.layers.0._checkpoint_wrapped_module.moe.experts.w1.weight"
20
+ assert normalize_fqn(name) == \
21
+ "model.layers.0.moe.experts.w1.weight"
22
+
23
+ def test_strip_both(self):
24
+ name = "model._orig_mod.layers.0._checkpoint_wrapped_module.attn.q_proj.weight"
25
+ assert normalize_fqn(name) == \
26
+ "model.layers.0.attn.q_proj.weight"
27
+
28
+ def test_strip_nested_orig_mod(self):
29
+ name = "_orig_mod._orig_mod.layers.0.mlp.gate_proj.weight"
30
+ assert normalize_fqn(name) == \
31
+ "layers.0.mlp.gate_proj.weight"
32
+
33
+
34
+ class TestParseQkLayerWithWrappers:
35
+
36
+ def test_plain_name(self):
37
+ assert parse_qk_layer("model.layers.3.attn.q_proj.weight") == ("q_proj", 3)
38
+
39
+ def test_orig_mod(self):
40
+ assert parse_qk_layer("model._orig_mod.layers.3.attn.wq.weight") == ("wq", 3)
41
+
42
+ def test_checkpoint_wrapped(self):
43
+ name = "model.layers.5._checkpoint_wrapped_module.self_attn.k_proj.weight"
44
+ assert parse_qk_layer(name) == ("k_proj", 5)
45
+
46
+ def test_both_wrappers(self):
47
+ name = "_orig_mod.model._checkpoint_wrapped_module.layers.7.attn.wk.weight"
48
+ assert parse_qk_layer(name) == ("wk", 7)
49
+
50
+ def test_non_qk_still_none(self):
51
+ name = "model._orig_mod.layers.2.attn.v_proj.weight"
52
+ assert parse_qk_layer(name) == (None, -1)
torch-ext/optimizer/core.py CHANGED
@@ -7,6 +7,18 @@ from torch.distributed import ProcessGroup
7
  from torch.distributed.tensor import DTensor
8
 
9
 
 
 
 
 
 
 
 
 
 
 
 
 
10
  @dataclass
11
  class _muon_state:
12
  worker_rank: int
@@ -78,6 +90,7 @@ def adjust_lr_for_muon(lr, param_shape):
78
 
79
 
80
  def default_is_muon(name, x, expert_keys=None):
 
81
  skip_keys = ["embed_tokens", "lm_head", "tok_embeddings", "output"]
82
  if any(key in name for key in skip_keys):
83
  return False
 
7
  from torch.distributed.tensor import DTensor
8
 
9
 
10
+ # torch.compile wraps modules as OptimizedModule, inserting "_orig_mod" into
11
+ # parameter FQNs. Activation checkpointing similarly inserts
12
+ # "_checkpoint_wrapped_module". Strip these so name-based matching (skip_keys,
13
+ # expert_keys, QK layer parsing) works regardless of wrapper nesting.
14
+ _WRAPPER_PARTS = frozenset({"_orig_mod", "_checkpoint_wrapped_module"})
15
+
16
+
17
+ def normalize_fqn(name: str) -> str:
18
+ """Strip torch.compile / checkpoint wrapper components from a parameter FQN."""
19
+ return ".".join(p for p in name.split(".") if p not in _WRAPPER_PARTS)
20
+
21
+
22
  @dataclass
23
  class _muon_state:
24
  worker_rank: int
 
90
 
91
 
92
  def default_is_muon(name, x, expert_keys=None):
93
+ name = normalize_fqn(name)
94
  skip_keys = ["embed_tokens", "lm_head", "tok_embeddings", "output"]
95
  if any(key in name for key in skip_keys):
96
  return False
torch-ext/optimizer/muon.py CHANGED
@@ -11,7 +11,8 @@ from torch.profiler import record_function
11
  from .adamw import step_adamw
12
  from .async_utils import run_pipeline
13
  from .core import (_muon_state, adjust_lr_for_muon,
14
- get_default_muon_param_groups, update_g, update_p)
 
15
  from .distributed.utils import (_is_shard, construct_shard_mesh,
16
  get_slices_of_dtensor)
17
  from .newton_schulz import (COMM_DTYPE, DEFAULT_CHUNK_SIZE_RATIO,
@@ -45,7 +46,8 @@ def _expand_expert_params(names, params, expert_keys):
45
  expanded_params = []
46
 
47
  for n, p in zip(names, params):
48
- is_expert = expert_keys and any(key in n for key in expert_keys)
 
49
  is_dtensor = isinstance(p.data, DTensor)
50
 
51
  if not is_expert:
 
11
  from .adamw import step_adamw
12
  from .async_utils import run_pipeline
13
  from .core import (_muon_state, adjust_lr_for_muon,
14
+ get_default_muon_param_groups, normalize_fqn, update_g,
15
+ update_p)
16
  from .distributed.utils import (_is_shard, construct_shard_mesh,
17
  get_slices_of_dtensor)
18
  from .newton_schulz import (COMM_DTYPE, DEFAULT_CHUNK_SIZE_RATIO,
 
46
  expanded_params = []
47
 
48
  for n, p in zip(names, params):
49
+ is_expert = expert_keys and any(
50
+ key in normalize_fqn(n) for key in expert_keys)
51
  is_dtensor = isinstance(p.data, DTensor)
52
 
53
  if not is_expert:
torch-ext/optimizer/qk_clip.py CHANGED
@@ -5,6 +5,8 @@ from dataclasses import dataclass
5
  import torch
6
  from torch.distributed.tensor import DTensor
7
 
 
 
8
  logger = logging.getLogger(__name__)
9
 
10
 
@@ -23,7 +25,7 @@ def parse_qk_layer(name: str) -> tuple[str | None, int]:
23
  'model.7.attn.k_proj.weight' -> ('k_proj', 7)
24
  'model.4.attn.v_proj.weight' -> (None, -1)
25
  """
26
- parts = name.split('.')
27
  if len(parts) < 3:
28
  return None, -1
29
 
 
5
  import torch
6
  from torch.distributed.tensor import DTensor
7
 
8
+ from .core import normalize_fqn
9
+
10
  logger = logging.getLogger(__name__)
11
 
12
 
 
25
  'model.7.attn.k_proj.weight' -> ('k_proj', 7)
26
  'model.4.attn.v_proj.weight' -> (None, -1)
27
  """
28
+ parts = normalize_fqn(name).split('.')
29
  if len(parts) < 3:
30
  return None, -1
31