# QK-Clip for MuonClip Optimizer (MLA)
> Reference: [Kimi K2 Technical Report](https://arxiv.org/pdf/2507.20534), Section 2.1, Algorithm 1

## Overview
QK-Clip is a **weight rescaling** technique that prevents the attention logit explosion that can occur when training with the Muon optimizer.
It does not intervene in the forward or backward pass. Instead, it rescales the weights **after** the optimizer step, which stops logit growth at its source.
## Algorithm 1: MuonClip
```
for each training step t:
    // 1. Muon optimizer step
    for each weight W:
        M_t = µ·M_{t-1} + G_t
        O_t = Newton-Schulz(M_t) · √max(n,m) · 0.2
        W_t = W_{t-1} - η·(O_t + λ·W_{t-1})
    // 2. QK-Clip
    for each attention head h:
        S^h_max ← max pre-softmax logit of head h recorded during forward
        if S^h_max > τ:
            γ ← τ / S^h_max
            W^h_qc ← W^h_qc · √γ   (query compressed, q_nope)
            W^h_kc ← W^h_kc · √γ   (key compressed, k_nope)
            W^h_qr ← W^h_qr · γ    (query rotary, q_pe)
            // k_R (shared rotary, k_pe): left untouched
```
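The three factors in step 2 combine to exactly γ. In MLA the pre-softmax logit of a head is the sum of a no-position term (q_nope · k_nope) and a rotary term (q_pe · k_pe); scaling both nope projections by √γ and only the query rotary projection by γ multiplies each term, and therefore the whole logit, by γ. The minimal sketch below checks this numerically. The vectors and the helper names `dot` and `mla_logit` are made up for illustration, the constant softmax scale is omitted, and the factors are applied to activations, which is equivalent to scaling the matching weight rows because the projections are linear.

```python
import math

def dot(a, b):
    return sum(x * y for x, y in zip(a, b))

def mla_logit(q_nope, k_nope, q_pe, k_pe):
    # Pre-softmax logit of one (query, key) pair for a single head.
    return dot(q_nope, k_nope) + dot(q_pe, k_pe)

tau = 100.0
q_nope, k_nope = [8.0, 6.0], [10.0, 5.0]  # nope term: 80 + 30 = 110
q_pe, k_pe = [4.0], [10.0]                # rope term: 40

s_max = mla_logit(q_nope, k_nope, q_pe, k_pe)  # 150.0, above tau
gamma = tau / s_max
sqrt_gamma = math.sqrt(gamma)

clipped = mla_logit(
    [sqrt_gamma * x for x in q_nope],  # effect of W_qc * sqrt(gamma)
    [sqrt_gamma * x for x in k_nope],  # effect of W_kc * sqrt(gamma)
    [gamma * x for x in q_pe],         # effect of W_qr * gamma
    k_pe,                              # shared k_pe is left untouched
)
assert math.isclose(clipped, tau)
```

Because the shared k_pe feeds every head, leaving it alone and putting the full γ on the per-head query side is what keeps the clip local to head h.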
## Existing code → MLA pseudocode
### Current code structure (MHA/GQA)
```
parse_qk_layer(name)          → decides whether the parameter is wq/wk and extracts the layer index
get_qk_clip_info(config, n)   → QKClipInfo (kind, indices, head_dim, threshold, logit)
compute_scales(p, info)       → returns a tensor of per-head √γ scales
qk_clip(p, scales, head_dim)  → W.view(-1, head_dim, in_dim).mul_(scales)
```
The current code assumes a uniform head_dim and applies the same √γ to the entire Q/K weight.
### What changes with MLA

| Item | MHA/GQA (current) | MLA |
|---|---|---|
| Q weight | `wq` / `q_proj` | `wq_b` (up-proj from LoRA) |
| K weight | `wk` / `k_proj` | `wkv_b` (k_nope and v fused together) |
| Q head stride | `qk_head_dim` (uniform) | `qk_head_dim` = `qk_nope_head_dim + qk_rope_head_dim` |
| K head stride | `qk_head_dim` (uniform) | `kv_stride` = `qk_nope_head_dim + v_head_dim` |
| Q scaling | √γ everywhere | nope → √γ, rope → γ (they differ) |
| K scaling | √γ everywhere | k_nope → √γ, v → 1.0 (partial only) |
| shared k_pe | none | tail of `wkv_a`, left untouched |

### Pseudocode: parse_qk_layer (MLA extension)
```python
def parse_qk_layer(name: str) -> tuple[str | None, int]:
    # normalize_fqn is not defined in this document; it is assumed to be
    # the helper already used by the existing MHA/GQA code.
    parts = normalize_fqn(name).split('.')
    if len(parts) < 2:
        return None, -1
    kind = parts[-2]  # module name just before the trailing 'weight'
    layer_idx = -1
    # The last all-digit component of the name is taken as the layer index.
    for part in reversed(parts):
        if part.isdigit():
            layer_idx = int(part)
            break
    # MHA/GQA: wq, wk, q_proj, k_proj
    # MLA: wq_b (Q up-proj), wkv_b (KV up-proj)
    if kind in ('wq', 'wk', 'q_proj', 'k_proj', 'wq_b', 'wkv_b'):
        return kind, layer_idx
    return None, -1
```
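To make the matching rule concrete, the sketch below runs the same parsing logic on a few parameter names. The names are hypothetical examples, and `normalize_fqn` is replaced by an identity stand-in because the real helper is not shown in this document.

```python
def normalize_fqn(name: str) -> str:
    # Hypothetical stand-in for the helper in the existing code base.
    return name

def parse_qk_layer(name):
    parts = normalize_fqn(name).split('.')
    if len(parts) < 2:
        return None, -1
    kind = parts[-2]
    layer_idx = -1
    for part in reversed(parts):
        if part.isdigit():
            layer_idx = int(part)
            break
    if kind in ('wq', 'wk', 'q_proj', 'k_proj', 'wq_b', 'wkv_b'):
        return kind, layer_idx
    return None, -1

names = [
    'layers.3.attention.wq_b.weight',       # MLA Q up-projection
    'layers.11.attention.wkv_b.weight',     # MLA KV up-projection
    'layers.3.attention.wkv_a.weight',      # holds shared k_pe: not clipped
    'layers.3.attention.wq_b_gate.weight',  # output gate: not clipped
]
results = {n: parse_qk_layer(n) for n in names}
for n, r in results.items():
    print(n, '->', r)
```

Because the match is on the exact module name, `wq_b_gate` and `wkv_a` fall through to `(None, -1)` and are skipped by the clip loop.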
### Pseudocode: QKClipInfo (MLA extension)
```python
from dataclasses import dataclass

import torch

@dataclass
class QKClipInfo:
    kind: str | None  # 'wq_b' or 'wkv_b' (MLA) / 'wq', 'wk' (MHA)
    indices: list[int]  # head indices targeted by clipping
    head_dim: int  # for the existing MHA path (uniform stride)
    threshold: float
    logit: torch.Tensor | None
    # MLA-only fields
    is_mla: bool = False
    qk_nope_head_dim: int = 0
    qk_rope_head_dim: int = 0
    v_head_dim: int = 0
```
### Pseudocode: get_qk_clip_info (MLA extension)
```python
from torch.distributed.tensor import DTensor  # torch.distributed._tensor on older PyTorch

def get_qk_clip_info(clip_config, n, qk_logits):
    if clip_config is None:
        return None
    threshold = clip_config['threshold']
    kind, layer_idx = parse_qk_layer(n)
    is_mla = clip_config.get('is_mla', False)
    logit, indices = None, []
    if qk_logits is not None and kind is not None:
        logit = qk_logits[layer_idx]
        if isinstance(logit, DTensor):
            logit = logit.full_tensor()
        if kind in ('wq_b', 'wq', 'q_proj'):
            indices = clip_config.get('q_indices', []) or []
        elif kind in ('wkv_b', 'wk', 'k_proj'):
            indices = clip_config.get('k_indices', []) or []
    if is_mla:
        qk_nope = clip_config['qk_nope_head_dim']
        v_dim = clip_config['v_head_dim']
        # compute_scales derives the head count as p.shape[0] // head_dim, so
        # head_dim must be the row stride of the tensor being clipped:
        # qk_head_dim for wq_b, kv_stride (= qk_nope + v) for wkv_b.
        head_dim = qk_nope + v_dim if kind == 'wkv_b' else clip_config['head_dim']
        return QKClipInfo(
            kind=kind,
            indices=indices,
            head_dim=head_dim,
            threshold=threshold,
            logit=logit,
            is_mla=True,
            qk_nope_head_dim=qk_nope,
            qk_rope_head_dim=clip_config['qk_rope_head_dim'],
            v_head_dim=v_dim,
        )
    # Existing MHA/GQA path
    return QKClipInfo(
        kind=kind, indices=indices,
        head_dim=clip_config['head_dim'],
        threshold=threshold, logit=logit,
    )
```
### Pseudocode: compute_scales (MLA extension)
The per-head γ is computed exactly as before (γ is determined the same way as in MHA).
What changes is that, when `qk_clip` is applied, each head is split into sub-regions that receive different transforms.
```python
import math

import torch

def compute_scales(p, qk_clip_state):
    """Same as the existing code. Returns per-head √γ, or None if no head exceeds the threshold."""
    kind = qk_clip_state.kind
    indices = qk_clip_state.indices
    threshold = qk_clip_state.threshold
    logit = qk_clip_state.logit
    head_scales = {}
    for logit_idx, head_idx in enumerate(indices):
        v_ele = float(logit[logit_idx])
        if v_ele > threshold:
            new_scale = math.sqrt(threshold / v_ele)  # √γ
            # If a head index appears more than once, keep the smallest scale.
            if head_idx not in head_scales or new_scale < head_scales[head_idx]:
                head_scales[head_idx] = new_scale
    if not head_scales:
        return None
    H_global = p.shape[0] // qk_clip_state.head_dim  # MLA: head_dim = qk_head_dim or kv_stride
    scales_full = torch.ones(H_global, device=p.data.device)
    for head_idx, scale in head_scales.items():
        scales_full[head_idx] = scale  # √γ_h
    return scales_full
```
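As a worked example of the per-head rule, the sketch below reproduces the scale computation with plain Python lists. The function name `per_head_sqrt_gamma` and the logit values are illustrative only; unlike `compute_scales`, it always returns a full list instead of `None` when nothing is clipped.

```python
import math

def per_head_sqrt_gamma(logits, indices, threshold, n_heads):
    # Heads at or below the threshold keep a scale of 1.0 (no clipping).
    scales = [1.0] * n_heads
    for logit, head_idx in zip(logits, indices):
        if logit > threshold:
            # sqrt(gamma) with gamma = threshold / logit; keep the smallest
            # value if the same head index shows up more than once.
            scales[head_idx] = min(scales[head_idx], math.sqrt(threshold / logit))
    return scales

# Heads 0..2 report max logits 50, 200 and 400; head 3 reports nothing.
scales = per_head_sqrt_gamma([50.0, 200.0, 400.0], [0, 1, 2],
                             threshold=100.0, n_heads=4)
print(scales)
```

Head 1 gets √(100/200) ≈ 0.707 and head 2 gets √(100/400) = 0.5, while heads 0 and 3 stay at 1.0.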
### Pseudocode: qk_clip (MLA extension)
The function receives the same per-head scales (√γ), but applies a different factor to each sub-region inside a head.
```python
import torch

def qk_clip(p, scales, head_dim, is_mla=False, kind=None, info=None):
    """
    scales: tensor of shape [n_heads], each element = √γ_h
    is_mla=False: existing MHA/GQA path (uniform √γ within a head)
    is_mla=True: MLA path (different transform per sub-region within a head)
    """
    W = p.data if isinstance(p, torch.nn.Parameter) else p
    if not is_mla:
        # Existing path: apply √γ uniformly to every row of the head
        W.view(-1, head_dim, W.shape[1]).mul_(scales.view(-1, 1, 1))
        return
    # MLA: apply the scale per sub-region of each head
    if kind == 'wq_b':
        qk_nope = info.qk_nope_head_dim
        qk_rope = info.qk_rope_head_dim
        qk_head_dim = qk_nope + qk_rope
        for h in range(len(scales)):
            sqrt_gamma = scales[h].item()
            if sqrt_gamma >= 1.0:
                continue  # head not clipped
            gamma = sqrt_gamma * sqrt_gamma  # √γ → γ
            s = h * qk_head_dim
            W[s : s + qk_nope] *= sqrt_gamma  # q_nope → √γ
            W[s + qk_nope : s + qk_head_dim] *= gamma  # q_pe → γ
    elif kind == 'wkv_b':
        qk_nope = info.qk_nope_head_dim
        kv_stride = qk_nope + info.v_head_dim
        for h in range(len(scales)):
            sqrt_gamma = scales[h].item()
            if sqrt_gamma >= 1.0:
                continue  # head not clipped
            s = h * kv_stride
            W[s : s + qk_nope] *= sqrt_gamma  # k_nope → √γ
            # v rows: left untouched
```
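The per-head loop above calls `.item()` once per head, which forces a host synchronization when the scales live on a GPU. One possible alternative, sketched here in plain Python, is to expand the per-head √γ values into one factor per weight row and multiply once. `build_row_scales` is a hypothetical helper, not part of the existing code; in a tensor implementation the resulting vector would be reshaped to a column and multiplied into `W` in place.

```python
def build_row_scales(sqrt_gammas, kind, qk_nope, qk_rope, v_dim):
    """Expand per-head sqrt(gamma) values into one factor per weight row."""
    rows = []
    for sg in sqrt_gammas:
        if kind == 'wq_b':
            # q_nope rows get sqrt(gamma), q_pe rows get gamma.
            rows += [sg] * qk_nope + [sg * sg] * qk_rope
        elif kind == 'wkv_b':
            # k_nope rows get sqrt(gamma), v rows stay at 1.0.
            rows += [sg] * qk_nope + [1.0] * v_dim
        else:
            raise ValueError(f'unsupported kind: {kind}')
    return rows

# Two heads, tiny dims for readability; only head 1 is clipped (sqrt(gamma) = 0.5).
q_rows = build_row_scales([1.0, 0.5], 'wq_b', qk_nope=2, qk_rope=1, v_dim=2)
kv_rows = build_row_scales([1.0, 0.5], 'wkv_b', qk_nope=2, qk_rope=1, v_dim=2)
print(q_rows)   # [1.0, 1.0, 1.0, 0.5, 0.5, 0.25]
print(kv_rows)  # [1.0, 1.0, 1.0, 1.0, 0.5, 0.5, 1.0, 1.0]
```

Unclipped heads simply contribute factors of 1.0, so no per-head branching is needed at multiply time.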
### Pseudocode: handling wkv_b indices under GQA
A Q head → KV head mapping is required.
Because several Q heads share the same KV head, the scaling must be applied only once per KV head, using the **minimum gamma within the group**.
```python
def build_k_indices_for_mla(clip_config, n_heads, n_kv_heads):
    """
    Builds KV head indices from the Q-head-based logits.
    q_indices are Q head indices, so k_indices must hold the
    corresponding KV head index for each logit entry.
    Note: among the Q heads that map to the same KV head, the largest
    logit (= the smallest gamma) must be used.
    """
    heads_per_kv = n_heads // n_kv_heads
    q_indices = clip_config.get('q_indices', list(range(n_heads)))
    # Q head → KV head mapping.
    # compute_scales pairs indices[i] with logit[i], so one entry is kept per
    # Q head and a KV head index may repeat. Taking the max logit over the Q
    # heads of a group is then handled inside compute_scales by its
    # min(scale) logic, and each KV head still ends up scaled exactly once.
    return [q_idx // heads_per_kv for q_idx in q_indices]
```
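The group rule can be illustrated with a small numeric example. The sketch below computes one γ per KV head from Q-head logits; `kv_head_gammas` and the logit values are hypothetical and only serve to show that the largest logit in a group determines the factor.

```python
def kv_head_gammas(q_logits, heads_per_kv, threshold):
    """Return {kv_head: gamma} for KV heads whose group exceeds the threshold."""
    gammas = {}
    for q_idx, logit in enumerate(q_logits):
        if logit <= threshold:
            continue
        kv_idx = q_idx // heads_per_kv
        gamma = threshold / logit
        # Smallest gamma in the group wins (= largest logit).
        gammas[kv_idx] = min(gamma, gammas.get(kv_idx, 1.0))
    return gammas

# 4 Q heads, 2 per KV head. Group 0 = logits (120, 80), group 1 = (300, 150).
g = kv_head_gammas([120.0, 80.0, 300.0, 150.0], heads_per_kv=2, threshold=100.0)
print(g)
```

KV head 0 is scaled once with γ = 100/120 and KV head 1 once with γ = 100/300, even though both Q heads in group 1 exceed the threshold.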
### Pseudocode: call flow (integrated)
```python
# Called after the optimizer step (keeps the existing code structure)
for name, param in model.named_parameters():
    info = get_qk_clip_info(clip_config, name, qk_logits)
    if info is None or info.kind is None:
        continue
    scales = compute_scales(param, info)  # per-head √γ (shared by MHA and MLA)
    if scales is not None:
        qk_clip(param, scales, info.head_dim,
                is_mla=info.is_mla, kind=info.kind, info=info)
```
### Pseudocode: clip_config example
```python
# MHA/GQA (existing)
clip_config = {
    'head_dim': 128,
    'threshold': 100.0,
    'q_indices': list(range(n_heads)),
    'k_indices': list(range(n_kv_heads)),
}
# MLA (extended)
clip_config = {
    'is_mla': True,
    'head_dim': 192,  # qk_head_dim (= qk_nope + qk_rope)
    'qk_nope_head_dim': 128,
    'qk_rope_head_dim': 64,
    'v_head_dim': 128,
    'threshold': 100.0,
    'q_indices': list(range(n_heads)),
    'k_indices': list(range(n_kv_heads)),  # generated by build_k_indices_for_mla
}
```
## Row index mapping table

| Algorithm symbol | Tensor | Row range | Scale |
|---|---|---|---|
| W^h_qc | `wq_b.weight` | `[h*qk_head_dim : h*qk_head_dim + qk_nope_head_dim]` | √γ |
| W^h_qr | `wq_b.weight` | `[h*qk_head_dim + qk_nope_head_dim : (h+1)*qk_head_dim]` | γ |
| W^h_kc | `wkv_b.weight` | `[kv_h*kv_stride : kv_h*kv_stride + qk_nope_head_dim]` | √γ |
| k_R | tail of the `wkv_a` output | - | left untouched |

- `kv_stride = qk_nope_head_dim + v_head_dim`
- `kv_h = h // (n_heads // n_kv_heads)` (GQA head mapping)
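The row ranges in the table can be computed directly from the head index. The sketch below does so for one head; `mla_row_ranges` is a hypothetical helper, and the dimensions match the MLA example config (128 / 64 / 128) with an assumed head count of 4 and no KV sharing.

```python
def mla_row_ranges(h, n_heads, n_kv_heads, qk_nope, qk_rope, v_dim):
    """Return half-open (start, stop) row ranges for Q head h."""
    qk_head_dim = qk_nope + qk_rope
    kv_stride = qk_nope + v_dim
    kv_h = h // (n_heads // n_kv_heads)  # GQA head mapping
    q0 = h * qk_head_dim
    k0 = kv_h * kv_stride
    return {
        'W_qc': (q0, q0 + qk_nope),                # wq_b rows, scaled by sqrt(gamma)
        'W_qr': (q0 + qk_nope, q0 + qk_head_dim),  # wq_b rows, scaled by gamma
        'W_kc': (k0, k0 + qk_nope),                # wkv_b rows, scaled by sqrt(gamma)
    }

r = mla_row_ranges(h=2, n_heads=4, n_kv_heads=4,
                   qk_nope=128, qk_rope=64, v_dim=128)
print(r)
```

For head 2 this yields rows 384 to 512 and 512 to 576 in `wq_b.weight`, and rows 512 to 640 in `wkv_b.weight`; the two tensors use different strides (192 vs 256), which is why the ranges do not line up.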
## Hyperparameters

| Parameter | Value | Notes |
|---|---|---|
| τ (threshold) | 100 | K2 full-scale training |
| τ (aggressive) | 30 | Small-scale ablation; no performance degradation observed |

## Notes
- **Self-deactivation**: In K2, only 12.7% of heads triggered the clip during the first 70k steps. After that, S_max of every head fell below τ and QK-Clip deactivated on its own.
- **DP/TP environments**: S^h_max must be gathered as a max across all ranks via all-reduce.
- **Avoiding double application under GQA**: Within a group of Q heads that share a KV head, scale the KV weight only once, using the smallest gamma (= the largest logit). This is handled by the `min(gamma)` logic in `compute_scales`.
- **wq_b_gate**: It only affects the output gate, not the attention logits, so it is not a QK-Clip target.
- **Existing logit soft-cap**: Keep it as a forward-level safety net; adding optimizer-level QK-Clip on top follows the paper's approach.
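The DP/TP note can be pictured as an element-wise max over the per-head values each rank recorded. In PyTorch this reduction would typically be done with `torch.distributed.all_reduce` using `ReduceOp.MAX`; the sketch below only simulates it in plain Python, assuming every rank holds one value per head. `reduce_max_across_ranks` and the numbers are illustrative.

```python
def reduce_max_across_ranks(per_rank_logits):
    """Element-wise max over ranks: one result per attention head."""
    return [max(vals) for vals in zip(*per_rank_logits)]

# Two ranks, three heads. Each rank saw a different shard of the batch.
rank_logits = [
    [90.0, 40.0, 110.0],   # rank 0
    [95.0, 130.0, 60.0],   # rank 1
]
global_max = reduce_max_across_ranks(rank_logits)
print(global_max)  # [95.0, 130.0, 110.0]

tau = 100.0
clipped_heads = [h for h, s in enumerate(global_max) if s > tau]
print(clipped_heads)  # [1, 2]
```

Without the reduction, rank 0 would clip only head 2 and rank 1 only head 1, and the replicated weights would drift apart.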