# QK-Clip for MuonClip Optimizer (MLA)
> Reference: [Kimi K2 Technical Report](https://arxiv.org/pdf/2507.20534), Section 2.1, Algorithm 1
## Overview
QK-Clip is a **weight rescaling** technique that prevents the attention logit explosion observed when training with the Muon optimizer.
It does not touch the forward or backward pass. Instead, it rescales the weights **after** the optimizer step, which stops logit growth at its source.
## Algorithm 1: MuonClip
```
for each training step t:
    // 1. Muon optimizer step
    for each weight W:
        M_t = µ·M_{t-1} + G_t
        O_t = Newton-Schulz(M_t) · √max(n,m) · 0.2
        W_t = W_{t-1} - η·(O_t + λ·W_{t-1})
    // 2. QK-Clip
    for each attention head h:
        S^h_max ← max pre-softmax logit of head h, recorded during the forward pass
        if S^h_max > τ:
            γ ← τ / S^h_max
            W^h_qc ← W^h_qc · √γ    (query compressed, q_nope)
            W^h_kc ← W^h_kc · √γ    (key compressed, k_nope)
            W^h_qr ← W^h_qr · γ     (query rotary, q_pe)
            // k_R (shared rotary, k_pe): left untouched
```
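
The choice of exponents (√γ on both compressed parts, γ on the rotary query, nothing on the shared rotary key) can be checked numerically: together they scale the head's logit by exactly γ. The sketch below is a toy check under stated assumptions: the sizes are hypothetical, the softmax scale factor is omitted, and the RoPE rotation is left out because it is linear and therefore commutes with scalar rescaling.

```python
import math
import torch

torch.manual_seed(0)
d_in, d_nope, d_rope = 16, 8, 4   # hypothetical toy sizes
gamma = 0.25                      # gamma = tau / S_max for a head above the threshold

# Per-head projection weights (rows = output features), named as in Algorithm 1.
W_qc = torch.randn(d_nope, d_in, dtype=torch.float64)
W_kc = torch.randn(d_nope, d_in, dtype=torch.float64)
W_qr = torch.randn(d_rope, d_in, dtype=torch.float64)
k_pe = torch.randn(d_rope, dtype=torch.float64)  # shared rotary key, never rescaled
x_q = torch.randn(d_in, dtype=torch.float64)     # query-side input
x_k = torch.randn(d_in, dtype=torch.float64)     # key-side input


def logit(W_qc, W_kc, W_qr):
    """Unscaled attention logit: nope dot product plus rope dot product."""
    nope = (W_qc @ x_q) @ (W_kc @ x_k)
    rope = (W_qr @ x_q) @ k_pe
    return nope + rope


before = logit(W_qc, W_kc, W_qr)
after = logit(W_qc * math.sqrt(gamma), W_kc * math.sqrt(gamma), W_qr * gamma)

# nope term: sqrt(gamma) * sqrt(gamma) = gamma; rope term: gamma * 1 = gamma.
assert torch.allclose(after, gamma * before)
```

Applying √γ to the rotary query instead would leave the rope term scaled by only √γ, since `k_pe` is shared across heads and cannot absorb the other half.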
## Existing code → MLA pseudocode
### Current code structure (MHA/GQA)
```
parse_qk_layer(name)          → detects wq/wk weights and extracts the layer index
get_qk_clip_info(config, n)   → QKClipInfo (kind, indices, head_dim, threshold, logit)
compute_scales(p, info)       → returns a tensor of per-head √γ scales
qk_clip(p, scales, head_dim)  → W.view(-1, head_dim, in_dim).mul_(scales)
```
The current code assumes a uniform head_dim and applies the same √γ to the entire Q/K weight.
### What changes with MLA
| Item | MHA/GQA (current) | MLA |
|---|---|---|
| Q weight | `wq` / `q_proj` | `wq_b` (up-proj from LoRA) |
| K weight | `wk` / `k_proj` | `wkv_b` (k_nope and v fused) |
| Q head stride | `qk_head_dim` (uniform) | `qk_head_dim` = `qk_nope_head_dim + qk_rope_head_dim` |
| K head stride | `qk_head_dim` (uniform) | `kv_stride` = `qk_nope_head_dim + v_head_dim` |
| Q scaling | √γ on the whole head | nope → √γ, rope → γ (differs per sub-region) |
| K scaling | √γ on the whole head | k_nope → √γ, v → 1.0 (partial) |
| shared k_pe | none | trailing part of `wkv_a`, left untouched |
### ์ˆ˜๋„์ฝ”๋“œ: parse_qk_layer (MLA ํ™•์žฅ)
```python
def parse_qk_layer(name: str) -> tuple[str | None, int]:
parts = normalize_fqn(name).split('.')
kind = parts[-2]
layer_idx = -1
for part in reversed(parts):
if part.isdigit():
layer_idx = int(part)
break
# MHA/GQA: wq, wk, q_proj, k_proj
# MLA: wq_b (Q up-proj), wkv_b (KV up-proj)
if kind in ('wq', 'wk', 'q_proj', 'k_proj', 'wq_b', 'wkv_b'):
return kind, layer_idx
return None, -1
```
### ์ˆ˜๋„์ฝ”๋“œ: QKClipInfo (MLA ํ™•์žฅ)
```python
@dataclass
class QKClipInfo:
kind: str | None # 'wq_b' or 'wkv_b' (MLA) / 'wq','wk' (MHA)
indices: list[int] # clipping ๋Œ€์ƒ head indices
head_dim: int # ๊ธฐ์กด MHA์šฉ (uniform stride)
threshold: float
logit: torch.Tensor | None
# MLA ์ „์šฉ ํ•„๋“œ
is_mla: bool = False
qk_nope_head_dim: int = 0
qk_rope_head_dim: int = 0
v_head_dim: int = 0
```
### ์ˆ˜๋„์ฝ”๋“œ: get_qk_clip_info (MLA ํ™•์žฅ)
```python
def get_qk_clip_info(clip_config, n, qk_logits):
if clip_config is None:
return None
threshold = clip_config['threshold']
kind, layer_idx = parse_qk_layer(n)
is_mla = clip_config.get('is_mla', False)
logit, indices = None, []
if qk_logits is not None and kind is not None:
logit = qk_logits[layer_idx]
if isinstance(logit, DTensor):
logit = logit.full_tensor()
if kind in ('wq_b', 'wq', 'q_proj'):
indices = clip_config.get('q_indices', []) or []
elif kind in ('wkv_b', 'wk', 'k_proj'):
indices = clip_config.get('k_indices', []) or []
if is_mla:
return QKClipInfo(
kind=kind,
indices=indices,
head_dim=clip_config['head_dim'], # qk_head_dim (for wq_b)
threshold=threshold,
logit=logit,
is_mla=True,
qk_nope_head_dim=clip_config['qk_nope_head_dim'],
qk_rope_head_dim=clip_config['qk_rope_head_dim'],
v_head_dim=clip_config['v_head_dim'],
)
else:
# ๊ธฐ์กด MHA/GQA ๊ฒฝ๋กœ
return QKClipInfo(
kind=kind, indices=indices,
head_dim=clip_config['head_dim'],
threshold=threshold, logit=logit,
)
```
### ์ˆ˜๋„์ฝ”๋“œ: compute_scales (MLA ํ™•์žฅ)
๊ธฐ์กด๊ณผ ๋™์ผํ•˜๊ฒŒ per-head ฮณ๋ฅผ ๊ณ„์‚ฐํ•œ๋‹ค. (ฮณ ๊ฒฐ์ •์€ MHA์™€ ๋™์ผ)
๋‹ฌ๋ผ์ง€๋Š” ๊ฑด `qk_clip` ์ ์šฉ ์‹œ head ๋‚ด๋ถ€๋ฅผ sub-region๋ณ„๋กœ ๋‚˜๋ˆ ์„œ ๋‹ค๋ฅธ ๋ณ€ํ™˜์„ ์“ฐ๋Š” ๊ฒƒ์ด๋‹ค.
```python
def compute_scales(p, qk_clip_state):
"""๊ธฐ์กด ์ฝ”๋“œ์™€ ๋™์ผ. per-head โˆšฮณ ๋ฐ˜ํ™˜."""
kind = qk_clip_state.kind
indices = qk_clip_state.indices
threshold = qk_clip_state.threshold
logit = qk_clip_state.logit
head_scales = {}
for logit_idx, head_idx in enumerate(indices):
v_ele = float(logit[logit_idx])
if v_ele > threshold:
new_scale = math.sqrt(threshold / v_ele) # โˆšฮณ
if head_idx not in head_scales or new_scale < head_scales[head_idx]:
head_scales[head_idx] = new_scale
if not head_scales:
return None
H_global = p.shape[0] // qk_clip_state.head_dim # MLA: head_dim = qk_head_dim or kv_stride
scales_full = torch.ones(H_global, device=p.data.device)
for head_idx, scale in head_scales.items():
scales_full[head_idx] = scale # โˆšฮณ_h
return scales_full
```
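
As a worked example of the scale computation, the per-head √γ can also be written without the Python loop. This is a minimal sketch, not the repository's implementation: `sqrt_gamma_per_head` is a hypothetical helper, and it assumes one max logit per head, already ordered by head index.

```python
import torch


def sqrt_gamma_per_head(max_logits: torch.Tensor, threshold: float) -> torch.Tensor:
    """Return sqrt(gamma_h) per head, where gamma_h = threshold / S_max_h
    for heads above the threshold and 1 for all other heads."""
    gamma = torch.where(
        max_logits > threshold,
        threshold / max_logits,
        torch.ones_like(max_logits),
    )
    return torch.sqrt(gamma)


# Heads at or below tau = 100 keep scale 1; the head at 400 gets sqrt(100 / 400).
scales = sqrt_gamma_per_head(torch.tensor([50.0, 100.0, 400.0]), threshold=100.0)
print(scales)  # tensor([1.0000, 1.0000, 0.5000])
```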
### ์ˆ˜๋„์ฝ”๋“œ: qk_clip (MLA ํ™•์žฅ)
per-head scales(โˆšฮณ)๋Š” ๋™์ผํ•˜๊ฒŒ ๋ฐ›๋˜, head ๋‚ด๋ถ€ sub-region์— ๋‹ค๋ฅธ ํ•จ์ˆ˜๋ฅผ ์ ์šฉํ•œ๋‹ค.
```python
def qk_clip(p, scales, head_dim, is_mla=False, kind=None, info=None):
"""
scales: [n_heads] ํ…์„œ, ๊ฐ ์›์†Œ = โˆšฮณ_h
is_mla=False: ๊ธฐ์กด MHA/GQA (head ๋‚ด uniform โˆšฮณ)
is_mla=True: MLA (head ๋‚ด sub-region๋ณ„ ๋‹ค๋ฅธ ๋ณ€ํ™˜)
"""
W = p.data if isinstance(p, torch.nn.Parameter) else p
if not is_mla:
# ๊ธฐ์กด: ๋ชจ๋“  ํ–‰์— โˆšฮณ ๊ท ์ผ ์ ์šฉ
W.view(-1, head_dim, W.shape[1]).mul_(scales.view(-1, 1, 1))
return
# MLA: head๋ณ„๋กœ sub-region ๋ถ„๋ฆฌ ์ ์šฉ
if kind == 'wq_b':
qk_nope = info.qk_nope_head_dim
qk_rope = info.qk_rope_head_dim
qk_head_dim = qk_nope + qk_rope
for h in range(len(scales)):
sqrt_gamma = scales[h].item()
if sqrt_gamma >= 1.0:
continue
gamma = sqrt_gamma * sqrt_gamma # โˆšฮณ โ†’ ฮณ
s = h * qk_head_dim
W[s : s + qk_nope] *= sqrt_gamma # q_nope โ†’ โˆšฮณ
W[s + qk_nope : s + qk_head_dim] *= gamma # q_pe โ†’ ฮณ
elif kind == 'wkv_b':
qk_nope = info.qk_nope_head_dim
kv_stride = qk_nope + info.v_head_dim
for h in range(len(scales)):
sqrt_gamma = scales[h].item()
if sqrt_gamma >= 1.0:
continue
s = h * kv_stride
W[s : s + qk_nope] *= sqrt_gamma # k_nope โ†’ โˆšฮณ
# v ํ–‰: ์•ˆ ๊ฑด๋“œ๋ฆผ
```
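
The per-head loop in the `wq_b` branch calls `.item()` once per head, which forces a host sync on GPU. One possible alternative is to reshape the weight to `[n_heads, qk_head_dim, in_dim]` and scale the two sub-regions with broadcasting. The sketch below assumes a contiguous weight whose rows are laid out head by head, nope rows first; `clip_wq_b_vectorized` is a hypothetical name, not a function from the codebase.

```python
import torch


def clip_wq_b_vectorized(W, sqrt_gamma, qk_nope, qk_rope):
    """In-place MLA Q clip: nope rows *= sqrt(gamma_h), rope rows *= gamma_h.

    W:          [n_heads * (qk_nope + qk_rope), in_dim], contiguous
    sqrt_gamma: [n_heads], 1.0 for heads that are not clipped
    """
    n_heads = sqrt_gamma.numel()
    Wv = W.view(n_heads, qk_nope + qk_rope, W.shape[1])  # shares storage with W
    s = sqrt_gamma.to(W.dtype).view(-1, 1, 1)
    Wv[:, :qk_nope].mul_(s)      # q_nope rows -> sqrt(gamma)
    Wv[:, qk_nope:].mul_(s * s)  # q_pe rows   -> gamma
    return W


# Two heads, 2 nope rows + 1 rope row each; only head 1 is clipped.
W = torch.ones(6, 2)
clip_wq_b_vectorized(W, torch.tensor([1.0, 0.5]), qk_nope=2, qk_rope=1)
print(W[:, 0])  # tensor([1.0000, 1.0000, 1.0000, 0.5000, 0.5000, 0.2500])
```

Heads with a scale of 1.0 are multiplied by 1 rather than skipped, so the result matches the loop version. For an `nn.Parameter`, pass `p.data` as in `qk_clip`.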
### ์ˆ˜๋„์ฝ”๋“œ: GQA์—์„œ wkv_b indices ์ฒ˜๋ฆฌ
Q head โ†’ KV head ๋งคํ•‘์ด ํ•„์š”ํ•˜๋‹ค.
์—ฌ๋Ÿฌ Q head๊ฐ€ ๊ฐ™์€ KV head๋ฅผ ๊ณต์œ ํ•˜๋ฏ€๋กœ, **group ๋‚ด ์ตœ์†Œ gamma** ๊ธฐ์ค€์œผ๋กœ ํ•œ ๋ฒˆ๋งŒ ์ ์šฉํ•ด์•ผ ํ•œ๋‹ค.
```python
def build_k_indices_for_mla(clip_config, n_heads, n_kv_heads):
"""
Q head ๊ธฐ์ค€ logit์œผ๋กœ๋ถ€ํ„ฐ KV head indices๋ฅผ ์ƒ์„ฑํ•œ๋‹ค.
q_indices๊ฐ€ Q head index ๊ธฐ์ค€์ด๋ผ๋ฉด,
k_indices๋Š” ๋Œ€์‘๋˜๋Š” KV head index๋กœ ๋ณ€ํ™˜ํ•ด์•ผ ํ•œ๋‹ค.
์ฃผ์˜: ๊ฐ™์€ KV head์— ๋งคํ•‘๋˜๋Š” ์—ฌ๋Ÿฌ Q head ์ค‘
๊ฐ€์žฅ ํฐ logit (= ๊ฐ€์žฅ ์ž‘์€ gamma)์„ ์‚ฌ์šฉํ•ด์•ผ ํ•œ๋‹ค.
"""
heads_per_kv = n_heads // n_kv_heads
q_indices = clip_config.get('q_indices', list(range(n_heads)))
# Q head โ†’ KV head ๋งคํ•‘
# logit ํ…์„œ์—์„œ ๊ฐ™์€ kv_head์— ๋Œ€์‘๋˜๋Š” Q head๋“ค ์ค‘ max๋ฅผ ์ทจํ•˜๋Š” ๊ฒƒ์€
# compute_scales_mla ๋‚ด๋ถ€์—์„œ min(gamma) ๋กœ ์ฒ˜๋ฆฌ๋จ
k_indices = []
seen = set()
for q_idx in q_indices:
kv_idx = q_idx // heads_per_kv
if kv_idx not in seen:
k_indices.append(kv_idx)
seen.add(kv_idx)
return k_indices
```
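
The "largest logit in the group" rule can be illustrated on the logit tensor directly. The sketch below assumes the same contiguous grouping as `kv_idx = q_idx // heads_per_kv` and a logit tensor with one entry per Q head; `kv_head_max_logits` is a hypothetical helper used only for illustration.

```python
import torch


def kv_head_max_logits(q_head_logits: torch.Tensor, n_kv_heads: int) -> torch.Tensor:
    """Reduce per-Q-head max logits to per-KV-head max logits.

    Q heads [g * heads_per_kv, (g + 1) * heads_per_kv) share KV head g, so the
    largest logit in each group (the smallest gamma) decides the KV scale.
    """
    n_heads = q_head_logits.numel()
    assert n_heads % n_kv_heads == 0, "n_heads must be a multiple of n_kv_heads"
    return q_head_logits.view(n_kv_heads, n_heads // n_kv_heads).amax(dim=1)


# 4 Q heads, 2 KV heads: groups are (0, 1) and (2, 3).
kv_logits = kv_head_max_logits(torch.tensor([10.0, 250.0, 120.0, 90.0]), n_kv_heads=2)
print(kv_logits)  # tensor([250., 120.])
```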
### ์ˆ˜๋„์ฝ”๋“œ: ํ˜ธ์ถœ ํ๋ฆ„ (ํ†ตํ•ฉ)
```python
# optimizer step ์ดํ›„ ํ˜ธ์ถœ๋˜๋Š” ๋ถ€๋ถ„ (๊ธฐ์กด ์ฝ”๋“œ ๊ตฌ์กฐ ์œ ์ง€)
for name, param in model.named_parameters():
info = get_qk_clip_info(clip_config, name, qk_logits)
if info is None or info.kind is None:
continue
scales = compute_scales(param, info) # per-head โˆšฮณ (MHA/MLA ๊ณตํ†ต)
if scales is not None:
qk_clip(param, scales, info.head_dim,
is_mla=info.is_mla, kind=info.kind, info=info)
```
### ์ˆ˜๋„์ฝ”๋“œ: clip_config ์˜ˆ์‹œ
```python
# MHA/GQA (๊ธฐ์กด)
clip_config = {
'head_dim': 128,
'threshold': 100.0,
'q_indices': list(range(n_heads)),
'k_indices': list(range(n_kv_heads)),
}
# MLA (ํ™•์žฅ)
clip_config = {
'is_mla': True,
'head_dim': 192, # qk_head_dim (= qk_nope + qk_rope)
'qk_nope_head_dim': 128,
'qk_rope_head_dim': 64,
'v_head_dim': 128,
'threshold': 100.0,
'q_indices': list(range(n_heads)),
'k_indices': list(range(n_kv_heads)), # build_k_indices_for_mla๋กœ ์ƒ์„ฑ
}
```
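
Because the MLA config states `head_dim` and its two components separately, a small consistency check can catch mismatches before any weight is rescaled. This is only a sketch; `check_mla_clip_config` is a hypothetical helper and is not part of the existing code.

```python
def check_mla_clip_config(cfg: dict) -> int:
    """Validate an MLA clip_config and return kv_stride (row stride of wkv_b)."""
    expected = cfg["qk_nope_head_dim"] + cfg["qk_rope_head_dim"]
    if cfg["head_dim"] != expected:
        raise ValueError(
            f"head_dim={cfg['head_dim']} but qk_nope + qk_rope = {expected}"
        )
    return cfg["qk_nope_head_dim"] + cfg["v_head_dim"]


kv_stride = check_mla_clip_config({
    "is_mla": True,
    "head_dim": 192,
    "qk_nope_head_dim": 128,
    "qk_rope_head_dim": 64,
    "v_head_dim": 128,
    "threshold": 100.0,
})
print(kv_stride)  # 256
```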
## Row index mapping table
| Algorithm symbol | Tensor | Row range | Scale |
|---|---|---|---|
| W^h_qc | `wq_b.weight` | `[h*qk_head_dim : h*qk_head_dim + qk_nope_head_dim]` | √γ |
| W^h_qr | `wq_b.weight` | `[h*qk_head_dim + qk_nope_head_dim : (h+1)*qk_head_dim]` | γ |
| W^h_kc | `wkv_b.weight` | `[kv_h*kv_stride : kv_h*kv_stride + qk_nope_head_dim]` | √γ |
| k_R | trailing part of the `wkv_a` output | - | left untouched |
- `kv_stride = qk_nope_head_dim + v_head_dim`
- `kv_h = h // (n_heads // n_kv_heads)` (GQA head mapping)
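
The row ranges in the table can be computed with a few lines of arithmetic. The sketch below uses the dimensions from the MLA config example (128 / 64 / 128); `mla_row_ranges` is a hypothetical helper that returns half-open `(start, stop)` row ranges.

```python
def mla_row_ranges(h, n_heads, n_kv_heads,
                   qk_nope_head_dim, qk_rope_head_dim, v_head_dim):
    """Return half-open (start, stop) row ranges for head h, per the table above."""
    qk_head_dim = qk_nope_head_dim + qk_rope_head_dim
    kv_stride = qk_nope_head_dim + v_head_dim
    kv_h = h // (n_heads // n_kv_heads)  # GQA head mapping
    q0 = h * qk_head_dim
    k0 = kv_h * kv_stride
    return {
        "W_qc": (q0, q0 + qk_nope_head_dim),                # wq_b.weight, scale sqrt(gamma)
        "W_qr": (q0 + qk_nope_head_dim, q0 + qk_head_dim),  # wq_b.weight, scale gamma
        "W_kc": (k0, k0 + qk_nope_head_dim),                # wkv_b.weight, scale sqrt(gamma)
    }


ranges = mla_row_ranges(h=1, n_heads=4, n_kv_heads=4,
                        qk_nope_head_dim=128, qk_rope_head_dim=64, v_head_dim=128)
print(ranges)  # {'W_qc': (192, 320), 'W_qr': (320, 384), 'W_kc': (256, 384)}
```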
## Hyperparameters
| Parameter | Value | Notes |
|---|---|---|
| τ (threshold) | 100 | K2 full-scale training |
| τ (aggressive) | 30 | small-scale ablation; no performance degradation observed |
## Notes
- **Self-deactivation**: In K2, only 12.7% of heads triggered QK-Clip during the first 70k steps. After that, the S_max of every head fell below τ and the clip deactivated on its own.
- **DP/TP environments**: S^h_max must be collected as a max over all ranks with an all-reduce.
- **Avoiding duplicate application under GQA**: For a group of Q heads sharing one KV head, scale the KV weight only once, using the smallest gamma (= the largest logit) in the group. This is handled by the `min(gamma)` logic in `compute_scales_mla`.
- **wq_b_gate**: Not a QK-Clip target, because it only feeds the output gate and does not contribute to the attention logits.
- **Existing logit soft-cap**: The paper's approach is to keep it as a forward-level safety net and add optimizer-level QK-Clip on top.
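
The per-head max logit and its cross-rank reduction could be sketched as follows. This is an illustration under assumptions, not the repository's code: `per_head_max_logit` and `sync_max_logit` are hypothetical helpers, the `[B, H, T, D]` layout and the `scale` argument are assumed, and a fused attention kernel would track the maximum internally rather than materialize the full logit matrix. Which process group to reduce over depends on how heads are sharded and is not covered here.

```python
import torch
import torch.distributed as dist


def per_head_max_logit(q: torch.Tensor, k: torch.Tensor, scale: float) -> torch.Tensor:
    """q: [B, H, Tq, D], k: [B, H, Tk, D] -> [H] max pre-softmax logit per head."""
    logits = torch.matmul(q, k.transpose(-1, -2)) * scale  # [B, H, Tq, Tk]
    return logits.amax(dim=(0, 2, 3))


def sync_max_logit(local_max: torch.Tensor) -> torch.Tensor:
    """Element-wise MAX all-reduce across ranks (in place).

    Falls back to a no-op when no process group is initialized,
    for example in a single-process run.
    """
    if dist.is_available() and dist.is_initialized():
        dist.all_reduce(local_max, op=dist.ReduceOp.MAX)
    return local_max


# One batch element, two heads, one token each; head 0 -> 1*1, head 1 -> 2*2.
q = torch.tensor([[[[1.0, 0.0]], [[0.0, 2.0]]]])  # [1, 2, 1, 2]
s_max = sync_max_logit(per_head_max_logit(q, q, scale=1.0))
print(s_max)  # tensor([1., 4.])
```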