Pramodith commited on
Commit
c92a555
·
verified ·
1 Parent(s): 55142ca

Uploaded using `kernel-builder`.

Browse files
Files changed (1) hide show
  1. README.md +43 -57
README.md CHANGED
@@ -17,20 +17,21 @@ tags:
17
 
18
  Fused **CuteDSL** kernels for the loss functions that dominate post-training
19
  workloads: PPO-family policy losses (BNPO, GRPO) and reverse-KL
20
- self-distillation. Each kernel ships a **single-launch fused forward +
21
- backward** path that returns `(loss, grad_logprobs)` directly — no
22
- `torch.autograd.Function` wrapper, no extra `grad_output * dpolicy` backward
 
23
  kernel, and no host-side syncs in the hot path.
24
 
25
  Background and benchmarks: see the
26
  [release post](https://geometric.so/blog/2026/05/08/hf-kernel-hub).
27
 
28
  - **Backend**: CUDA (NVIDIA CUTLASS DSL).
29
- - **Min GPU**: SM80 (Ampere) required by `nvidia-cutlass-dsl`. Tested on A100 (SM80) and H100 (SM90). Works on SM86 (RTX 3090, A40), SM89 (RTX 4090, L40S), SM90a (H100 SXM), and SM100 (Blackwell B200/GB200).
30
  - **Min CUDA**: 12.8.
31
  - **Dtypes**: `float32`, `float16`, `bfloat16`.
32
  - **Dynamic shapes**: a single compile handles arbitrary batch size and
33
- sequence length no recompiles when shapes change between calls (common
34
  in post-training rollouts).
35
 
36
  ## Kernels
@@ -45,15 +46,15 @@ Background and benchmarks: see the
45
 
46
  Each kernel family exposes three entry points with the same underlying CuteDSL kernel:
47
 
48
- - **`<name>(...)`** fused fwd+bwd, returns `(loss, grad)` from one `@cute.jit`
49
  dispatch. Lowest-overhead path; the caller chains the gradient into the upstream
50
  model with `policy_logprobs.backward(grad)`. Use this in custom training loops
51
  where you control gradient flow.
52
- - **`<name>_autograd(...)`** same kernel, registered via
53
  `torch.library.custom_op` + `register_autograd`. `loss.backward()` works
54
  and composes with `torch.compile(fullgraph=True)`. There is a noticeable
55
  per-call dispatcher overhead vs. the direct path.
56
- - **`<name>_fwd(...)`** forward-only, returns scalar `loss` and skips
57
  the gradient buffer entirely. Use for inference / validation /
58
  reward-model scoring.
59
 
@@ -85,7 +86,7 @@ per_token = −advantages · min(ratio, clipped)
85
  kl = exp(ref_logprobs − policy_logprobs) − (ref_logprobs − policy_logprobs) − 1
86
  ```
87
 
88
- The global denominator is computed entirely on-GPU via cross-CTA atomics
89
  no host-side `mask.sum()` sync. When `beta=0` the KL branch is dead-coded
90
  at compile time.
91
 
@@ -110,7 +111,7 @@ ref_logprobs = torch.randn(bs, seq_len, dtype=torch.bfloat16, device=devi
110
  advantages = torch.randn(bs, dtype=torch.bfloat16, device=device)
111
  completions_mask = (torch.rand(bs, seq_len, device=device) > 0.2).to(torch.int8)
112
 
113
- # 1) Direct (loss, grad) lowest overhead training path
114
  loss, grad = km.bnpo_loss(
115
  policy_logprobs, old_policy_logprobs, ref_logprobs,
116
  advantages, completions_mask,
@@ -118,7 +119,7 @@ loss, grad = km.bnpo_loss(
118
  )
119
  policy_logprobs.backward(grad)
120
 
121
- # 2) Autograd-aware works with loss.backward() and torch.compile
122
  loss = km.bnpo_loss_autograd(
123
  policy_logprobs.requires_grad_(),
124
  old_policy_logprobs, ref_logprobs,
@@ -127,7 +128,7 @@ loss = km.bnpo_loss_autograd(
127
  )
128
  loss.backward()
129
 
130
- # 3) Forward-only inference / reward scoring, no gradient buffer
131
  loss = km.bnpo_loss_fwd(
132
  policy_logprobs, old_policy_logprobs, ref_logprobs,
133
  advantages, completions_mask,
@@ -140,7 +141,7 @@ loss = km.bnpo_loss_fwd(
140
  ## GRPO Loss
141
 
142
  **Group Relative Policy Optimization** implements TRL's default
143
- **per-response normalization** variant each response is normalized by its
144
  own valid-token count before averaging across the batch:
145
 
146
  ```
@@ -150,12 +151,12 @@ loss = mean_r( ((per_token_loss + β·kl) · mask).sum(-1) / max(mask.sum(-1), 1
150
  `per_token_loss` and `kl` are the same clipped-ratio and KL expressions as BNPO.
151
  `completions_mask` is **required** because the per-response denominator is
152
  mask-derived. The kernel uses one CTA per row so the per-row mask sum is
153
- reduced inside the block no cross-CTA atomics on the scaling pass.
154
 
155
  **Inputs**:
156
  - `policy_logprobs`, `old_policy_logprobs`, `ref_logprobs`: `(bs, seq_len)`, fp32/fp16/bf16
157
  - `advantages`: `(bs,)`
158
- - `completions_mask`: `(bs, seq_len)`, bool or int8 **required**
159
 
160
  **Returns**: `(loss, grad_policy_logprobs)` from `grpo_loss`; scalar `loss` from `grpo_loss_fwd`.
161
 
@@ -173,7 +174,7 @@ ref_logprobs = torch.randn(bs, seq_len, dtype=torch.bfloat16, device=devi
173
  advantages = torch.randn(bs, dtype=torch.bfloat16, device=device)
174
  completions_mask = (torch.rand(bs, seq_len, device=device) > 0.2).to(torch.int8)
175
 
176
- # 1) Direct (loss, grad) lowest overhead training path
177
  loss, grad = km.grpo_loss(
178
  policy_logprobs, old_policy_logprobs, ref_logprobs,
179
  advantages, completions_mask,
@@ -181,7 +182,7 @@ loss, grad = km.grpo_loss(
181
  )
182
  policy_logprobs.backward(grad)
183
 
184
- # 2) Autograd-aware works with loss.backward() and torch.compile
185
  loss = km.grpo_loss_autograd(
186
  policy_logprobs.requires_grad_(),
187
  old_policy_logprobs, ref_logprobs,
@@ -190,7 +191,7 @@ loss = km.grpo_loss_autograd(
190
  )
191
  loss.backward()
192
 
193
- # 3) Forward-only inference / reward scoring, no gradient buffer
194
  loss = km.grpo_loss_fwd(
195
  policy_logprobs, old_policy_logprobs, ref_logprobs,
196
  advantages, completions_mask,
@@ -222,7 +223,7 @@ grad_student_v = scale · p_v · (log p_v − log q_v − kl_per_row)
222
  where `scale = mask[r] · inv_n_valid`.
223
 
224
  **Inputs**:
225
- - `student_logits`, `teacher_logits`: `(*, V)` arbitrary leading dims (typically `(bs, seq_len, vocab)`); both must share shape and dtype
226
  - `completions_mask`: shape matching `student_logits.shape[:-1]`
227
 
228
  > ⚠️ **Fully-masked batches**: `inv_n_valid = 1 / mask.sum()` is not clamped, so a batch where every token is masked produces inf/NaN. Guard upstream if that case is reachable.
@@ -242,17 +243,17 @@ student_logits = torch.randn(bs, seq_len, vocab, dtype=torch.bfloat16, device=d
242
  teacher_logits = torch.randn(bs, seq_len, vocab, dtype=torch.bfloat16, device=device)
243
  completions_mask = (torch.rand(bs, seq_len, device=device) > 0.2)
244
 
245
- # 1) Direct (loss, grad) lowest overhead training path
246
  loss, grad = km.reverse_kl(student_logits, teacher_logits, completions_mask)
247
  student_logits.backward(grad)
248
 
249
- # 2) Autograd-aware works with loss.backward() and torch.compile
250
  loss = km.reverse_kl_autograd(
251
  student_logits.requires_grad_(), teacher_logits, completions_mask
252
  )
253
  loss.backward()
254
 
255
- # 3) Forward-only inference / KL monitoring, no gradient buffer
256
  loss = km.reverse_kl_fwd(student_logits, teacher_logits, completions_mask)
257
  ```
258
 
@@ -260,37 +261,22 @@ loss = km.reverse_kl_fwd(student_logits, teacher_logits, completions_mask)
260
 
261
  ## Performance
262
 
263
- Numbers below are geometric-mean speedups from our in-house benchmark
264
- (`triton.testing.do_bench`, fresh subprocess per shape). Baselines are
265
- **eager PyTorch** and **`torch.compile(mode="max-autotune-no-cudagraphs",
266
- fullgraph=True)`** with `torch._dynamo.config.trace_autograd_ops = True` so
267
- the compiled baseline is a real Inductor-fused fwd+bwd graph.
268
 
269
- | Kernel | β | vs eager | vs `torch.compile` |
270
- | --- | --- | --- | --- |
271
- | `bnpo_loss_fwd` | 0 | 1. | 1.3× |
272
- | `bnpo_loss` | 0 | 1. | 1.2× |
273
- | `bnpo_loss_fwd` | ≠ 0 | 1.4× | 1.1× |
274
- | `bnpo_loss` | ≠ 0 | 1.3× | 1.0× |
275
- | `grpo_loss_fwd` | ≠ 0 | 1.5× | 1.2× |
276
- | `grpo_loss` | ≠ 0 | 1.4× | 1.1× |
277
- | `reverse_kl_fwd`| | 1.3× | 1.1× |
278
- | `reverse_kl` | | 1.2× | 1.0× |
279
-
280
- Profiled on H100 SXM (SM90a). BNPO and GRPO benchmarked separately for `β = 0`
281
- (KL term dead-coded at compile time) and `β ≠ 0`. Shapes:
282
-
283
- - **BNPO / GRPO**: `(16, 1024)`, `(32, 2048)`, `(64, 4096)`, `(128, 8192)`,
284
- `(128, 8193)` — the last entry exercises the predicated tail-tile path.
285
- - **Reverse KL** (vocab = 248320, matching Qwen3.5): `(1, 64)`, `(2, 128)`,
286
- `(4, 256)`, `(8, 512)`, `(16, 1024)`, `(8, 1981)`.
287
-
288
- Reproduce locally:
289
-
290
- ```bash
291
- make bench-kernel KERNEL=grpo_loss # or bnpo_loss, reverse_kl
292
- ```
293
 
 
 
 
 
 
 
 
 
294
  ---
295
 
296
  ## Benchmark animations
@@ -299,40 +285,40 @@ make bench-kernel KERNEL=grpo_loss # or bnpo_loss, reverse_kl
299
 
300
  <picture>
301
  <source media="(prefers-color-scheme: dark)" srcset="benchmark_results/bnpo_loss_eager/bnpo_loss_eager_dark_animation.svg">
302
- <img src="benchmark_results/bnpo_loss_eager/bnpo_loss_eager_light_animation.svg" alt="BNPO loss latency vs eager PyTorch">
303
  </picture>
304
 
305
  ### BNPO Loss vs torch.compile
306
 
307
  <picture>
308
  <source media="(prefers-color-scheme: dark)" srcset="benchmark_results/bnpo_loss_compiled/bnpo_loss_compiled_dark_animation.svg">
309
- <img src="benchmark_results/bnpo_loss_compiled/bnpo_loss_compiled_light_animation.svg" alt="BNPO loss latency vs torch.compile">
310
  </picture>
311
 
312
  ### GRPO Loss vs eager PyTorch
313
 
314
  <picture>
315
  <source media="(prefers-color-scheme: dark)" srcset="benchmark_results/grpo_loss_eager/grpo_loss_eager_dark_animation.svg">
316
- <img src="benchmark_results/grpo_loss_eager/grpo_loss_eager_light_animation.svg" alt="GRPO loss latency vs eager PyTorch">
317
  </picture>
318
 
319
  ### GRPO Loss vs torch.compile
320
 
321
  <picture>
322
  <source media="(prefers-color-scheme: dark)" srcset="benchmark_results/grpo_loss_compiled/grpo_loss_compiled_dark_animation.svg">
323
- <img src="benchmark_results/grpo_loss_compiled/grpo_loss_compiled_light_animation.svg" alt="GRPO loss latency vs torch.compile">
324
  </picture>
325
 
326
  ### Reverse KL vs eager PyTorch
327
 
328
  <picture>
329
  <source media="(prefers-color-scheme: dark)" srcset="benchmark_results/reverse_kl_eager/reverse_kl_eager_dark_animation.svg">
330
- <img src="benchmark_results/reverse_kl_eager/reverse_kl_eager_light_animation.svg" alt="Reverse KL latency vs eager PyTorch">
331
  </picture>
332
 
333
  ### Reverse KL vs torch.compile
334
 
335
  <picture>
336
  <source media="(prefers-color-scheme: dark)" srcset="benchmark_results/reverse_kl_compiled/reverse_kl_compiled_dark_animation.svg">
337
- <img src="benchmark_results/reverse_kl_compiled/reverse_kl_compiled_light_animation.svg" alt="Reverse KL latency vs torch.compile">
338
  </picture>
 
17
 
18
  Fused **CuteDSL** kernels for the loss functions that dominate post-training
19
  workloads: PPO-family policy losses (BNPO, GRPO) and reverse-KL
20
+ self-distillation.
21
+
22
+ Each kernel ships a **single-launch fused forward +
23
+ backward** path that returns `(loss, grad_logprobs)` directly. No `torch.autograd.Function` wrapper, no extra `grad_output * dpolicy` backward
24
  kernel, and no host-side syncs in the hot path.
25
 
26
  Background and benchmarks: see the
27
  [release post](https://geometric.so/blog/2026/05/08/hf-kernel-hub).
28
 
29
  - **Backend**: CUDA (NVIDIA CUTLASS DSL).
30
+ - **Min GPU**: SM80 (Ampere) - required by `nvidia-cutlass-dsl`. Tested on H100 (SM90). Should work on SM80 (Ampere), SM86 (RTX 3090, A40), SM89 (RTX 4090, L40S), SM90a (H100 SXM), and SM100 (Blackwell B200/GB200).
31
  - **Min CUDA**: 12.8.
32
  - **Dtypes**: `float32`, `float16`, `bfloat16`.
33
  - **Dynamic shapes**: a single compile handles arbitrary batch size and
34
+ sequence length, no recompiles when shapes change between calls (common
35
  in post-training rollouts).
36
 
37
  ## Kernels
 
46
 
47
  Each kernel family exposes three entry points with the same underlying CuteDSL kernel:
48
 
49
+ - **`<name>(...)`** - fused fwd+bwd, returns `(loss, grad)` from one `@cute.jit`
50
  dispatch. Lowest-overhead path; the caller chains the gradient into the upstream
51
  model with `policy_logprobs.backward(grad)`. Use this in custom training loops
52
  where you control gradient flow.
53
+ - **`<name>_autograd(...)`** - same kernel, registered via
54
  `torch.library.custom_op` + `register_autograd`. `loss.backward()` works
55
  and composes with `torch.compile(fullgraph=True)`. There is a noticeable
56
  per-call dispatcher overhead vs. the direct path.
57
+ - **`<name>_fwd(...)`** - forward-only, returns scalar `loss` and skips
58
  the gradient buffer entirely. Use for inference / validation /
59
  reward-model scoring.
60
 
 
86
  kl = exp(ref_logprobs − policy_logprobs) − (ref_logprobs − policy_logprobs) − 1
87
  ```
88
 
89
+ The global denominator is computed entirely on-GPU via cross-CTA atomics -
90
  no host-side `mask.sum()` sync. When `beta=0` the KL branch is dead-coded
91
  at compile time.
92
 
 
111
  advantages = torch.randn(bs, dtype=torch.bfloat16, device=device)
112
  completions_mask = (torch.rand(bs, seq_len, device=device) > 0.2).to(torch.int8)
113
 
114
+ # 1) Direct (loss, grad) - lowest overhead training path
115
  loss, grad = km.bnpo_loss(
116
  policy_logprobs, old_policy_logprobs, ref_logprobs,
117
  advantages, completions_mask,
 
119
  )
120
  policy_logprobs.backward(grad)
121
 
122
+ # 2) Autograd-aware - works with loss.backward() and torch.compile
123
  loss = km.bnpo_loss_autograd(
124
  policy_logprobs.requires_grad_(),
125
  old_policy_logprobs, ref_logprobs,
 
128
  )
129
  loss.backward()
130
 
131
+ # 3) Forward-only - inference / reward scoring, no gradient buffer
132
  loss = km.bnpo_loss_fwd(
133
  policy_logprobs, old_policy_logprobs, ref_logprobs,
134
  advantages, completions_mask,
 
141
  ## GRPO Loss
142
 
143
  **Group Relative Policy Optimization** implements TRL's default
144
+ **per-response normalization** variant - each response is normalized by its
145
  own valid-token count before averaging across the batch:
146
 
147
  ```
 
151
  `per_token_loss` and `kl` are the same clipped-ratio and KL expressions as BNPO.
152
  `completions_mask` is **required** because the per-response denominator is
153
  mask-derived. The kernel uses one CTA per row so the per-row mask sum is
154
+ reduced inside the block - no cross-CTA atomics on the scaling pass.
155
 
156
  **Inputs**:
157
  - `policy_logprobs`, `old_policy_logprobs`, `ref_logprobs`: `(bs, seq_len)`, fp32/fp16/bf16
158
  - `advantages`: `(bs,)`
159
+ - `completions_mask`: `(bs, seq_len)`, bool or int8 - **required**
160
 
161
  **Returns**: `(loss, grad_policy_logprobs)` from `grpo_loss`; scalar `loss` from `grpo_loss_fwd`.
162
 
 
174
  advantages = torch.randn(bs, dtype=torch.bfloat16, device=device)
175
  completions_mask = (torch.rand(bs, seq_len, device=device) > 0.2).to(torch.int8)
176
 
177
+ # 1) Direct (loss, grad) - lowest overhead training path
178
  loss, grad = km.grpo_loss(
179
  policy_logprobs, old_policy_logprobs, ref_logprobs,
180
  advantages, completions_mask,
 
182
  )
183
  policy_logprobs.backward(grad)
184
 
185
+ # 2) Autograd-aware - works with loss.backward() and torch.compile
186
  loss = km.grpo_loss_autograd(
187
  policy_logprobs.requires_grad_(),
188
  old_policy_logprobs, ref_logprobs,
 
191
  )
192
  loss.backward()
193
 
194
+ # 3) Forward-only - inference / reward scoring, no gradient buffer
195
  loss = km.grpo_loss_fwd(
196
  policy_logprobs, old_policy_logprobs, ref_logprobs,
197
  advantages, completions_mask,
 
223
  where `scale = mask[r] · inv_n_valid`.
224
 
225
  **Inputs**:
226
+ - `student_logits`, `teacher_logits`: `(*, V)` - arbitrary leading dims (typically `(bs, seq_len, vocab)`); both must share shape and dtype
227
  - `completions_mask`: shape matching `student_logits.shape[:-1]`
228
 
229
  > ⚠️ **Fully-masked batches**: `inv_n_valid = 1 / mask.sum()` is not clamped, so a batch where every token is masked produces inf/NaN. Guard upstream if that case is reachable.
 
243
  teacher_logits = torch.randn(bs, seq_len, vocab, dtype=torch.bfloat16, device=device)
244
  completions_mask = (torch.rand(bs, seq_len, device=device) > 0.2)
245
 
246
+ # 1) Direct (loss, grad) - lowest overhead training path
247
  loss, grad = km.reverse_kl(student_logits, teacher_logits, completions_mask)
248
  student_logits.backward(grad)
249
 
250
+ # 2) Autograd-aware - works with loss.backward() and torch.compile
251
  loss = km.reverse_kl_autograd(
252
  student_logits.requires_grad_(), teacher_logits, completions_mask
253
  )
254
  loss.backward()
255
 
256
+ # 3) Forward-only - inference / KL monitoring, no gradient buffer
257
  loss = km.reverse_kl_fwd(student_logits, teacher_logits, completions_mask)
258
  ```
259
 
 
261
 
262
  ## Performance
263
 
264
+ All numbers are geometric-mean speedups over H100 SXM (SM90a). Full methodology
265
+ and per-shape plots in the [release post](https://geometric.so/blog/2026/05/08/hf-kernel-hub).
 
 
 
266
 
267
+ ### `kernels` CLI benchmark
268
+
269
+ Timed with `time.perf_counter` + `cuda.synchronize()`, single iteration per
270
+ shape, mean over 100 iterations. Baseline runs once after warmup.
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
271
 
272
+ | Kernel | vs eager | vs `torch.compile` |
273
+ | --- | --- | --- |
274
+ | `grpo_loss_fwd` | 5.60× | 2.52× |
275
+ | `grpo_loss` | 19.92× | 2.28× |
276
+ | `bnpo_loss_fwd` | 5.58× | 2.54× |
277
+ | `bnpo_loss` | 17.46× | 2.15× |
278
+ | `reverse_kl_fwd`| 6.95× | 2.44× |
279
+ | `reverse_kl` | 7.06× | 2.59× |
280
  ---
281
 
282
  ## Benchmark animations
 
285
 
286
  <picture>
287
  <source media="(prefers-color-scheme: dark)" srcset="benchmark_results/bnpo_loss_eager/bnpo_loss_eager_dark_animation.svg">
288
+ <img width="90%" src="benchmark_results/bnpo_loss_eager/bnpo_loss_eager_light_animation.svg" alt="BNPO loss latency vs eager PyTorch">
289
  </picture>
290
 
291
  ### BNPO Loss vs torch.compile
292
 
293
  <picture>
294
  <source media="(prefers-color-scheme: dark)" srcset="benchmark_results/bnpo_loss_compiled/bnpo_loss_compiled_dark_animation.svg">
295
+ <img width="90%" src="benchmark_results/bnpo_loss_compiled/bnpo_loss_compiled_light_animation.svg" alt="BNPO loss latency vs torch.compile">
296
  </picture>
297
 
298
  ### GRPO Loss vs eager PyTorch
299
 
300
  <picture>
301
  <source media="(prefers-color-scheme: dark)" srcset="benchmark_results/grpo_loss_eager/grpo_loss_eager_dark_animation.svg">
302
+ <img width="90%" src="benchmark_results/grpo_loss_eager/grpo_loss_eager_light_animation.svg" alt="GRPO loss latency vs eager PyTorch">
303
  </picture>
304
 
305
  ### GRPO Loss vs torch.compile
306
 
307
  <picture>
308
  <source media="(prefers-color-scheme: dark)" srcset="benchmark_results/grpo_loss_compiled/grpo_loss_compiled_dark_animation.svg">
309
+ <img width="90%" src="benchmark_results/grpo_loss_compiled/grpo_loss_compiled_light_animation.svg" alt="GRPO loss latency vs torch.compile">
310
  </picture>
311
 
312
  ### Reverse KL vs eager PyTorch
313
 
314
  <picture>
315
  <source media="(prefers-color-scheme: dark)" srcset="benchmark_results/reverse_kl_eager/reverse_kl_eager_dark_animation.svg">
316
+ <img width="90%" src="benchmark_results/reverse_kl_eager/reverse_kl_eager_light_animation.svg" alt="Reverse KL latency vs eager PyTorch">
317
  </picture>
318
 
319
  ### Reverse KL vs torch.compile
320
 
321
  <picture>
322
  <source media="(prefers-color-scheme: dark)" srcset="benchmark_results/reverse_kl_compiled/reverse_kl_compiled_dark_animation.svg">
323
+ <img width="90%" src="benchmark_results/reverse_kl_compiled/reverse_kl_compiled_light_animation.svg" alt="Reverse KL latency vs torch.compile">
324
  </picture>