# QK-Clip for MuonClip Optimizer (MLA)

> Reference: [Kimi K2 Technical Report](https://arxiv.org/pdf/2507.20534), Section 2.1, Algorithm 1

## Overview

QK-Clip is a **weight rescaling** technique that prevents the attention logit explosion observed when training with the Muon optimizer.
It does not touch the forward or backward pass. Instead, it rescales weights **after** each optimizer step, which stops logit growth at its source.
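This is a minimal pure-Python sketch of why rescaling caps the logit. It assumes a single MHA-style head whose logit is the dot product (W_q x)·(W_k y), and treats that one logit as the head's S_max. The values are toy numbers chosen for illustration. Scaling both projections by √γ multiplies the logit by γ, so γ = τ / S_max lands it exactly on τ.

```python
import math

def matvec(W, x):
    """Multiply a matrix (list of rows) by a vector."""
    return [sum(w * xi for w, xi in zip(row, x)) for row in W]

def dot(a, b):
    """Dot product of two equal-length vectors."""
    return sum(ai * bi for ai, bi in zip(a, b))

def scaled(W, s):
    """Return a copy of W with every entry multiplied by s."""
    return [[w * s for w in row] for row in W]

# Toy projection weights and inputs (illustrative values only).
W_q = [[4.0, 1.0], [2.0, 3.0]]
W_k = [[5.0, 0.0], [1.0, 2.0]]
x, y = [3.0, 1.0], [2.0, 2.0]

tau = 30.0
s_max = dot(matvec(W_q, x), matvec(W_k, y))   # logit before clipping (184.0)

gamma = tau / s_max if s_max > tau else 1.0
sqrt_gamma = math.sqrt(gamma)

# QK-Clip: rescale both projection weights by sqrt(gamma) after the optimizer step.
s_after = dot(matvec(scaled(W_q, sqrt_gamma), x), matvec(scaled(W_k, sqrt_gamma), y))
print(s_max, s_after)
```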

## Algorithm 1: MuonClip

```
for each training step t:
    // 1. Muon optimizer step
    for each weight W ∈ R^{n×m}:
        M_t = μ·M_{t-1} + G_t                          // G_t: gradient of W, μ: momentum
        O_t = Newton-Schulz(M_t) · √max(n, m) · 0.2    // match Adam update RMS
        W_t = W_{t-1} - η·(O_t + λ·W_{t-1})            // η: learning rate, λ: weight decay

    // 2. QK-Clip
    for each attention head h:
        S^h_max ← max pre-softmax logit of head h, recorded during the forward pass
        if S^h_max > τ:
            γ ← τ / S^h_max
            W^h_qc ← W^h_qc · √γ      (query compressed, q_nope)
            W^h_kc ← W^h_kc · √γ      (key compressed, k_nope)
            W^h_qr ← W^h_qr · γ       (query rotary, q_pe)
            // k_R (shared rotary, k_pe): left untouched
```
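The split between √γ and γ follows from how an MLA logit decomposes. Write q^h_C, k^h_C and q^h_R for the outputs of W^h_qc, W^h_kc and W^h_qr (notation assumed here), and k_R for the shared rotary key. RoPE is a rotation, so it commutes with scalar scaling. Ignoring the constant softmax scale, the per-head logit transforms as shown below. Because k_R is shared across heads and left untouched, q^h_R must absorb the full factor γ:

$$
\begin{aligned}
S^h &= \mathbf{q}^h_C \cdot \mathbf{k}^h_C + \mathbf{q}^h_R \cdot \mathbf{k}_R \\
S'^h &= \big(\sqrt{\gamma}\,\mathbf{q}^h_C\big) \cdot \big(\sqrt{\gamma}\,\mathbf{k}^h_C\big) + \big(\gamma\,\mathbf{q}^h_R\big) \cdot \mathbf{k}_R = \gamma\,S^h \\
S'^h_{\max} &= \gamma\,S^h_{\max} = \frac{\tau}{S^h_{\max}}\,S^h_{\max} = \tau
\end{aligned}
$$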

## From existing code to MLA pseudocode

### Current code structure (MHA/GQA)

```
parse_qk_layer(name)          → decide whether the param is wq/wk, extract the layer index
get_qk_clip_info(config, n)   → QKClipInfo (kind, indices, head_dim, threshold, logit)
compute_scales(p, info)       → tensor of per-head √γ scales
qk_clip(p, scales, head_dim)  → W.view(-1, head_dim, in_dim).mul_(scales.view(-1, 1, 1))
```

The current code assumes a uniform head_dim and applies the same √γ to the entire Q/K weight of a head.

### What changes in MLA

| Item | MHA/GQA (current) | MLA |
|---|---|---|
| Q weight | `wq` / `q_proj` | `wq_b` (up-proj from the LoRA latent) |
| K weight | `wk` / `k_proj` | `wkv_b` (k_nope and v fused in one weight) |
| Q head stride | `qk_head_dim` (uniform) | `qk_head_dim` = `qk_nope_head_dim + qk_rope_head_dim` |
| K head stride | `qk_head_dim` (uniform) | `kv_stride` = `qk_nope_head_dim + v_head_dim` |
| Q scaling | √γ on all rows | nope → √γ, rope → γ (different) |
| K scaling | √γ on all rows | k_nope → √γ, v → 1.0 (partial) |
| shared k_pe | none | trailing part of `wkv_a` output, left untouched |
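This pure-Python sketch makes the sub-region layout concrete. The helper name `mla_row_scales` is hypothetical. It expands per-head √γ values into one multiplier per output row of `wq_b` or `wkv_b`. It assumes each head occupies a contiguous block of rows in the order given in the table. Multiplying row r of the weight by the r-th entry reproduces the Q/K scaling rows above.

```python
def mla_row_scales(kind, sqrt_gammas, qk_nope_head_dim, qk_rope_head_dim, v_head_dim):
    """Expand per-head sqrt(gamma) values into one multiplier per output row.

    kind == 'wq_b':  each head block is [q_nope | q_pe] -> sqrt(g) on nope rows, g on rope rows
    kind == 'wkv_b': each head block is [k_nope | v]    -> sqrt(g) on nope rows, 1.0 on v rows
    """
    rows = []
    for sg in sqrt_gammas:
        if kind == 'wq_b':
            rows += [sg] * qk_nope_head_dim + [sg * sg] * qk_rope_head_dim
        elif kind == 'wkv_b':
            rows += [sg] * qk_nope_head_dim + [1.0] * v_head_dim
        else:
            raise ValueError(f"not an MLA Q/K weight: {kind}")
    return rows

# Two heads with tiny illustrative dims: head 0 unclipped, head 1 clipped with gamma = 0.25.
q_rows = mla_row_scales('wq_b', [1.0, 0.5], qk_nope_head_dim=2, qk_rope_head_dim=1, v_head_dim=2)
kv_rows = mla_row_scales('wkv_b', [1.0, 0.5], qk_nope_head_dim=2, qk_rope_head_dim=1, v_head_dim=2)
print(q_rows)   # [1.0, 1.0, 1.0, 0.5, 0.5, 0.25]
print(kv_rows)  # [1.0, 1.0, 1.0, 1.0, 0.5, 0.5, 1.0, 1.0]
```

Every nope term of the logit picks up √γ·√γ = γ. The rope term picks up γ·1 = γ, because k_pe is untouched.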

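### Pseudocode: parse_qk_layer (MLA extension)

```python
def parse_qk_layer(name: str) -> tuple[str | None, int]:
    # normalize_fqn is the existing helper from the current code (not shown here).
    parts = normalize_fqn(name).split('.')
    if len(parts) < 2:
        return None, -1
    kind = parts[-2]  # module name, assuming the FQN ends in '.weight'

    # Layer index = last purely numeric component of the FQN.
    layer_idx = -1
    for part in reversed(parts):
        if part.isdigit():
            layer_idx = int(part)
            break

    # MHA/GQA: wq, wk, q_proj, k_proj
    # MLA:     wq_b (Q up-proj), wkv_b (KV up-proj)
    # Deliberately not matched: wkv_a (holds the shared k_pe), wq_b_gate
    if kind in ('wq', 'wk', 'q_proj', 'k_proj', 'wq_b', 'wkv_b'):
        return kind, layer_idx

    return None, -1
```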

### Pseudocode: QKClipInfo (MLA extension)

```python
from dataclasses import dataclass

import torch

@dataclass
class QKClipInfo:
    kind: str | None          # 'wq_b' or 'wkv_b' (MLA) / 'wq', 'wk' (MHA)
    indices: list[int]        # heads to clip: indices[i] is the head that logit[i] belongs to
    head_dim: int             # row stride per head (MHA: head_dim; MLA: qk_head_dim or kv_stride)
    threshold: float
    logit: torch.Tensor | None

    # MLA-only fields
    is_mla: bool = False
    qk_nope_head_dim: int = 0
    qk_rope_head_dim: int = 0
    v_head_dim: int = 0
```

### Pseudocode: get_qk_clip_info (MLA extension)

```python
from torch.distributed.tensor import DTensor  # older PyTorch: torch.distributed._tensor

def get_qk_clip_info(clip_config, n, qk_logits):
    if clip_config is None:
        return None

    threshold = clip_config['threshold']
    kind, layer_idx = parse_qk_layer(n)
    is_mla = clip_config.get('is_mla', False)

    logit, indices = None, []
    if qk_logits is not None and kind is not None:
        logit = qk_logits[layer_idx]
        if isinstance(logit, DTensor):
            logit = logit.full_tensor()

        if kind in ('wq_b', 'wq', 'q_proj'):
            indices = clip_config.get('q_indices', []) or []
        elif kind in ('wkv_b', 'wk', 'k_proj'):
            indices = clip_config.get('k_indices', []) or []

    if is_mla:
        qk_nope = clip_config['qk_nope_head_dim']
        v_dim = clip_config['v_head_dim']
        # Row stride per head: qk_head_dim for wq_b, kv_stride = qk_nope + v for wkv_b.
        # (Using qk_head_dim for wkv_b would give compute_scales the wrong head count.)
        head_dim = qk_nope + v_dim if kind == 'wkv_b' else clip_config['head_dim']
        return QKClipInfo(
            kind=kind,
            indices=indices,
            head_dim=head_dim,
            threshold=threshold,
            logit=logit,
            is_mla=True,
            qk_nope_head_dim=qk_nope,
            qk_rope_head_dim=clip_config['qk_rope_head_dim'],
            v_head_dim=v_dim,
        )

    # Existing MHA/GQA path
    return QKClipInfo(
        kind=kind, indices=indices,
        head_dim=clip_config['head_dim'],
        threshold=threshold, logit=logit,
    )
```

### Pseudocode: compute_scales (MLA extension)

Per-head scales are computed exactly as before, because γ is chosen the same way as in MHA. Note that the function returns √γ, not γ.
The difference lies in `qk_clip`, which splits each head into sub-regions and applies a different transform to each one.

```python
import math

import torch

def compute_scales(p, qk_clip_state):
    """Same as the existing code. Returns per-head √γ, or None if no head needs clipping."""
    indices = qk_clip_state.indices
    threshold = qk_clip_state.threshold
    logit = qk_clip_state.logit

    head_scales = {}
    for logit_idx, head_idx in enumerate(indices):
        v_ele = float(logit[logit_idx])
        if v_ele > threshold:
            new_scale = math.sqrt(threshold / v_ele)      # √γ
            # Several logits may map to one head (GQA): keep the smallest scale.
            if head_idx not in head_scales or new_scale < head_scales[head_idx]:
                head_scales[head_idx] = new_scale

    if not head_scales:
        return None

    H_global = p.shape[0] // qk_clip_state.head_dim      # MLA: head_dim = qk_head_dim or kv_stride
    scales_full = torch.ones(H_global, device=p.device)
    for head_idx, scale in head_scales.items():
        scales_full[head_idx] = scale                     # √γ_h

    return scales_full
```
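Here is a worked example of the scale rule as a pure-Python mirror of `compute_scales`. The name `per_head_sqrt_gamma` is hypothetical. Unlike the original, it returns a plain list with 1.0 for unclipped heads instead of `None`. The logits are illustrative. Note the strict `>`: a logit exactly equal to τ is not clipped.

```python
import math

def per_head_sqrt_gamma(logits, indices, threshold, n_heads):
    """Pure-Python mirror of compute_scales: one sqrt(gamma) per head, 1.0 if unclipped."""
    scales = [1.0] * n_heads
    for logit_idx, head_idx in enumerate(indices):
        s = logits[logit_idx]
        if s > threshold:
            # Keep the smallest scale if several logits map to the same head.
            scales[head_idx] = min(scales[head_idx], math.sqrt(threshold / s))
    return scales

# tau = 100: head 1 -> gamma = 0.25 -> 0.5, head 3 -> gamma = 0.8 -> ~0.894,
# heads 0 and 2 untouched (80 and 100 are not above the threshold).
scales = per_head_sqrt_gamma([80.0, 400.0, 100.0, 125.0], indices=[0, 1, 2, 3],
                             threshold=100.0, n_heads=4)
print(scales)
```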

### Pseudocode: qk_clip (MLA extension)

It takes the same per-head scales (√γ), but applies a different function to each sub-region inside a head.

```python
import torch

@torch.no_grad()
def qk_clip(p, scales, head_dim, is_mla=False, kind=None, info=None):
    """
    scales: [n_heads] tensor, each element = √γ_h (1.0 for heads that are not clipped)

    is_mla=False: existing MHA/GQA (uniform √γ within a head)
    is_mla=True:  MLA (a different transform per sub-region within a head)
    """
    W = p.data if isinstance(p, torch.nn.Parameter) else p
    n_heads = scales.numel()
    sqrt_gamma = scales.view(-1, 1, 1)                    # broadcast over (rows, in_dim)

    if not is_mla:
        # Existing: apply √γ uniformly to every row of each head
        W.view(n_heads, head_dim, -1).mul_(sqrt_gamma)
        return

    # MLA: view rows as [head, row-in-head, in_dim] and scale each sub-region.
    # Vectorized instead of a per-head .item() loop; unclipped heads are multiplied by 1.0.
    qk_nope = info.qk_nope_head_dim
    if kind == 'wq_b':
        Wh = W.view(n_heads, qk_nope + info.qk_rope_head_dim, -1)
        Wh[:, :qk_nope].mul_(sqrt_gamma)                  # q_nope → √γ
        Wh[:, qk_nope:].mul_(sqrt_gamma * sqrt_gamma)     # q_pe   → γ

    elif kind == 'wkv_b':
        Wh = W.view(n_heads, qk_nope + info.v_head_dim, -1)
        Wh[:, :qk_nope].mul_(sqrt_gamma)                  # k_nope → √γ
        # v rows: left untouched
```

### Pseudocode: handling wkv_b indices under GQA

This requires a mapping from Q heads to KV heads.
Several Q heads share one KV head, so the KV weight must be scaled **only once**, using the **smallest gamma in the group**.

```python
def build_k_indices_for_mla(clip_config, n_heads, n_kv_heads):
    """
    Build k_indices from logits that are indexed by Q head.
    q_indices are Q-head indices, so each entry must be converted
    to the KV head it maps to.

    Note: among the Q heads that share a KV head, the largest logit
          (= the smallest gamma) must be used.
    """
    assert n_heads % n_kv_heads == 0, "n_heads must be a multiple of n_kv_heads"
    heads_per_kv = n_heads // n_kv_heads
    q_indices = clip_config.get('q_indices', list(range(n_heads)))

    # Keep one entry per logit (no deduplication): compute_scales pairs logit[i]
    # with k_indices[i] and keeps the smallest √γ per KV head, which is exactly
    # the "largest logit in the group" rule. Deduplicating would misalign them.
    return [q_idx // heads_per_kv for q_idx in q_indices]
```
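Here is a worked GQA example in pure Python. It assumes 4 Q heads sharing 2 KV heads, τ = 100, and illustrative logits. Each Q-head logit is paired with its KV head, and each KV head keeps the smallest √γ in its group.

```python
import math

n_heads, n_kv_heads, tau = 4, 2, 100.0
heads_per_kv = n_heads // n_kv_heads

q_logits = [400.0, 150.0, 90.0, 100.0]                   # S_max per Q head (illustrative)
k_indices = [q // heads_per_kv for q in range(n_heads)]  # Q head -> KV head

kv_scales = [1.0] * n_kv_heads
for logit, kv in zip(q_logits, k_indices):
    if logit > tau:
        kv_scales[kv] = min(kv_scales[kv], math.sqrt(tau / logit))

print(k_indices, kv_scales)  # [0, 0, 1, 1] [0.5, 1.0]
```

The group minimum is conservative. Q head 1 keeps its own √γ ≈ 0.816 on the Q side but shares the KV scale 0.5. Its logit therefore drops to about 150 × 0.816 × 0.5 ≈ 61, well below τ.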

### Pseudocode: call flow (integrated)

```python
# Runs after the optimizer step (keeps the existing code structure)
with torch.no_grad():
    for name, param in model.named_parameters():
        info = get_qk_clip_info(clip_config, name, qk_logits)
        if info is None or info.kind is None:
            continue

        scales = compute_scales(param, info)    # per-head √γ (shared by MHA and MLA)
        if scales is not None:
            qk_clip(param, scales, info.head_dim,
                    is_mla=info.is_mla, kind=info.kind, info=info)
```

### Pseudocode: clip_config examples

```python
# MHA/GQA (existing)
clip_config = {
    'head_dim': 128,
    'threshold': 100.0,
    'q_indices': list(range(n_heads)),
    'k_indices': list(range(n_kv_heads)),
}

# MLA (extension)
clip_config = {
    'is_mla': True,
    'head_dim': 192,                  # qk_head_dim (= qk_nope + qk_rope)
    'qk_nope_head_dim': 128,
    'qk_rope_head_dim': 64,
    'v_head_dim': 128,
    'threshold': 100.0,
    'q_indices': list(range(n_heads)),
    'k_indices': list(range(n_kv_heads)),  # generate with build_k_indices_for_mla (equals this when n_kv_heads == n_heads)
}
```
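This sketch fills the MLA `clip_config` from a model config dict. The field names (`num_attention_heads`, `num_key_value_heads`, `qk_nope_head_dim`, `qk_rope_head_dim`, `v_head_dim`) are an assumption modeled on DeepSeek-V3-style configs. `make_mla_clip_config` is a hypothetical helper. It emits one KV-head index per Q-head logit, so the per-head minimum in `compute_scales` covers each GQA group.

```python
def make_mla_clip_config(model_cfg: dict, threshold: float = 100.0) -> dict:
    """Build an MLA clip_config from a model config dict (field names are assumptions)."""
    n_heads = model_cfg['num_attention_heads']
    n_kv_heads = model_cfg.get('num_key_value_heads', n_heads)
    if n_heads % n_kv_heads != 0:
        raise ValueError("num_attention_heads must be a multiple of num_key_value_heads")
    heads_per_kv = n_heads // n_kv_heads
    nope, rope = model_cfg['qk_nope_head_dim'], model_cfg['qk_rope_head_dim']
    return {
        'is_mla': True,
        'head_dim': nope + rope,          # qk_head_dim, the wq_b row stride
        'qk_nope_head_dim': nope,
        'qk_rope_head_dim': rope,
        'v_head_dim': model_cfg['v_head_dim'],
        'threshold': threshold,
        'q_indices': list(range(n_heads)),
        # One KV-head index per Q-head logit (Q head -> KV head).
        'k_indices': [h // heads_per_kv for h in range(n_heads)],
    }

cfg = make_mla_clip_config({
    'num_attention_heads': 4, 'num_key_value_heads': 4,
    'qk_nope_head_dim': 128, 'qk_rope_head_dim': 64, 'v_head_dim': 128,
})
print(cfg['head_dim'], cfg['k_indices'])  # 192 [0, 1, 2, 3]
```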

## Row index mapping table

| Algorithm symbol | Tensor | Row range | Scale |
|---|---|---|---|
| W^h_qc | `wq_b.weight` | `[h*qk_head_dim : h*qk_head_dim + qk_nope_head_dim]` | √γ |
| W^h_qr | `wq_b.weight` | `[h*qk_head_dim + qk_nope_head_dim : (h+1)*qk_head_dim]` | γ |
| W^h_kc | `wkv_b.weight` | `[kv_h*kv_stride : kv_h*kv_stride + qk_nope_head_dim]` | √γ |
| k_R | trailing part of `wkv_a` output | - | untouched |

- `kv_stride = qk_nope_head_dim + v_head_dim`
- `kv_h = h // (n_heads // n_kv_heads)` (GQA head mapping)
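The table's row ranges can be checked with a small pure-Python helper. Its name, `mla_row_ranges`, is hypothetical. It returns half-open `(start, stop)` ranges for Q head `h`, using the dimensions from the MLA `clip_config` example (128 / 64 / 128).

```python
def mla_row_ranges(h, n_heads, n_kv_heads, qk_nope_head_dim, qk_rope_head_dim, v_head_dim):
    """Half-open (start, stop) row ranges touched by QK-Clip for Q head h."""
    qk_head_dim = qk_nope_head_dim + qk_rope_head_dim
    kv_stride = qk_nope_head_dim + v_head_dim
    kv_h = h // (n_heads // n_kv_heads)          # GQA head mapping
    q0, k0 = h * qk_head_dim, kv_h * kv_stride
    return {
        'qc': (q0, q0 + qk_nope_head_dim),                      # W^h_qc in wq_b  -> sqrt(gamma)
        'qr': (q0 + qk_nope_head_dim, (h + 1) * qk_head_dim),   # W^h_qr in wq_b  -> gamma
        'kc': (k0, k0 + qk_nope_head_dim),                      # W^h_kc in wkv_b -> sqrt(gamma)
    }

print(mla_row_ranges(2, n_heads=128, n_kv_heads=128,
                     qk_nope_head_dim=128, qk_rope_head_dim=64, v_head_dim=128))
# {'qc': (384, 512), 'qr': (512, 576), 'kc': (512, 640)}
```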

## Hyperparameters

| Parameter | Value | Notes |
|---|---|---|
| τ (threshold) | 100 | K2 full-scale training |
| τ (aggressive) | 30 | small-scale ablation; no performance degradation observed |

## Notes

- **Self-deactivation**: In K2, only 12.7% of heads triggered QK-Clip during the first 70k steps. After that, every head's S_max fell below τ and QK-Clip naturally became inactive.
- **DP/TP setups**: S^h_max must be combined with a max all-reduce so that every rank sees the global maximum.
- **Avoiding double application under GQA**: For a group of Q heads sharing one KV head, scale the KV weight only once, using the smallest gamma in the group (= the largest logit). `compute_scales` handles this by keeping the minimum scale per head index.
- **wq_b_gate**: It only feeds the output gate, not the attention logits, so it is not a QK-Clip target.
- **Existing logit soft-cap**: Keep it as a forward-level safety net. The paper's approach is to add QK-Clip at the optimizer level.
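This is a minimal sketch of the max all-reduce from the DP/TP note, using `torch.distributed.all_reduce` with `ReduceOp.MAX`. The helper name `reduce_max_logits` is hypothetical. It assumes every rank holds a full-length `[n_heads]` tensor of local maxima. Under TP, heads a rank does not own could be filled with `-inf` so they do not affect the maximum. It would run after the forward pass and before QK-Clip.

```python
import torch
import torch.distributed as dist

def reduce_max_logits(local_max: torch.Tensor, group=None) -> torch.Tensor:
    """In-place MAX all-reduce of per-head S_max so every rank holds the global value.

    local_max: [n_heads] tensor of this rank's largest pre-softmax logit per head.
    Without an initialized process group (single process), it is returned unchanged.
    """
    if dist.is_available() and dist.is_initialized():
        dist.all_reduce(local_max, op=dist.ReduceOp.MAX, group=group)
    return local_max

# Single-process usage: no process group, so the tensor comes back as-is.
s_max = reduce_max_logits(torch.tensor([12.0, 140.0, 95.0]))
```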