wyldecat and Claude Opus 4.6 committed
Commit 1118752 · 1 Parent(s): f008017

Add info-level logging for param group classification (Muon vs AdamW)


Log each parameter's skip/expert status and effective ndim in
default_is_muon(), and summarize Muon/AdamW param lists in
get_default_muon_param_groups().

[skip-build]

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
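To surface these messages, configure Python logging at INFO level in the process that builds the param groups. A minimal sketch, assuming the file at torch-ext/optimizer/core.py is importable as optimizer.core (the installed module name may differ):

    import logging
    import torch.nn as nn

    # Route logger.info(...) calls from optimizer.core to the console.
    logging.basicConfig(level=logging.INFO)

    # Hypothetical import path for torch-ext/optimizer/core.py; adjust to your install.
    from optimizer.core import get_default_muon_param_groups

    # Toy model: 2-D Linear weights are classified as Muon, 1-D biases as AdamW.
    model = nn.Sequential(nn.Linear(8, 8), nn.Linear(8, 2))
    param_groups = get_default_muon_param_groups(model)
    # Log lines will look roughly like (format strings from the diff below):
    #   [is_muon] 0.weight: ndim=2, expert=None, effective_ndim=2 → Muon
    #   [param_groups] Muon params (2): ['0.weight', '1.weight']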

Files changed (1)
  1. torch-ext/optimizer/core.py +20 -3
torch-ext/optimizer/core.py CHANGED
@@ -1,3 +1,4 @@
+import logging
 import math
 from dataclasses import dataclass
 
@@ -12,6 +13,8 @@ from torch.distributed.tensor import DTensor
 # expert_keys, QK layer parsing) works regardless of wrapper nesting.
 _WRAPPER_PARTS = frozenset({"_orig_mod", "_checkpoint_wrapped_module"})
 
+logger = logging.getLogger(__name__)
+
 
 def normalize_fqn(name: str) -> str:
     """Strip torch.compile / checkpoint wrapper components from a parameter FQN."""
@@ -92,11 +95,18 @@ def default_is_muon(name, x, expert_keys=None):
     parts = normalize_fqn(name).split(".")
     skip_keys = ["embed_tokens", "lm_head", "tok_embeddings", "output"]
     if any(key in parts for key in skip_keys):
+        logger.info("[is_muon] %s: skip (matched skip_key), ndim=%d", name,
+                    x.ndim)
         return False
     effective_ndim = x.ndim
-    if expert_keys and any(key in parts for key in expert_keys):
+    is_expert = expert_keys and any(key in parts for key in expert_keys)
+    if is_expert:
         effective_ndim -= 1
-    return effective_ndim >= 2
+    result = effective_ndim >= 2
+    logger.info("[is_muon] %s: ndim=%d, expert=%s, effective_ndim=%d → %s",
+                name, x.ndim, is_expert, effective_ndim,
+                "Muon" if result else "AdamW")
+    return result
 
 
 def get_default_muon_param_groups(model, is_muon_func=None, expert_keys=None):
@@ -104,7 +114,7 @@ def get_default_muon_param_groups(model, is_muon_func=None, expert_keys=None):
         is_muon_func = lambda n, x: default_is_muon(n, x, expert_keys)
 
     muon_params, muon_names = [], []
-    non_muon_params = []
+    non_muon_params, non_muon_names = [], []
 
     for n, p in model.named_parameters():
         if not p.requires_grad:
@@ -114,6 +124,13 @@ def get_default_muon_param_groups(model, is_muon_func=None, expert_keys=None):
             muon_names.append(n)
         else:
             non_muon_params.append(p)
+            non_muon_names.append(n)
+
+    logger.info("[param_groups] expert_keys=%s", expert_keys)
+    logger.info("[param_groups] Muon params (%d): %s", len(muon_names),
+                muon_names)
+    logger.info("[param_groups] AdamW params (%d): %s", len(non_muon_names),
+                non_muon_names)
 
     return [
         {