Pramodith committed · Commit 43849eb · verified · 1 Parent(s): 4a2470a

Uploaded using `kernel-builder`.

  1. README.md +338 -3
README.md CHANGED
@@ -1,3 +1,338 @@
- ---
- license: apache-2.0
- ---
+ ---
+ library_name: kernels
+ license: apache-2.0
+ tags:
+ - cuda
+ - cutlass
+ - cute-dsl
+ - rl
+ - distillation
+ - trl
+ - grpo
+ - bnpo
+ - kl-divergence
+ ---
+
+ # Geometric-AI Kernels
+
+ Fused **CuteDSL** kernels for the loss functions that dominate post-training
+ workloads: PPO-family policy losses (BNPO, GRPO) and reverse-KL
+ self-distillation. Each kernel ships a **single-launch fused forward +
+ backward** path that returns `(loss, grad_logprobs)` directly: no
+ `torch.autograd.Function` wrapper, no extra `grad_output * dpolicy` backward
+ kernel, and no host-side syncs in the hot path.
+
+ Background and benchmarks: see the
+ [release post](https://geometric.so/blog/2026/05/08/hf-kernel-hub).
+
+ - **Backend**: CUDA (NVIDIA CUTLASS DSL).
+ - **Min GPU**: SM80 (Ampere), required by `nvidia-cutlass-dsl`. Tested on A100 (SM80) and H100 (SM90). Works on SM86 (RTX 3090, A40), SM89 (RTX 4090, L40S), SM90a (H100 SXM), and SM100 (Blackwell B200/GB200).
+ - **Min CUDA**: 12.8.
+ - **Dtypes**: `float32`, `float16`, `bfloat16`.
+ - **Dynamic shapes**: a single compile handles arbitrary batch size and
+   sequence length, so there are no recompiles when shapes change between
+   calls (common in post-training rollouts).
+
+ ## Kernels
+
+ | Kernel family | Direct (no autograd) | Autograd-aware | Forward-only |
+ | --- | --- | --- | --- |
+ | BNPO loss | `bnpo_loss` | `bnpo_loss_autograd` | `bnpo_loss_fwd` |
+ | GRPO loss | `grpo_loss` | `grpo_loss_autograd` | `grpo_loss_fwd` |
+ | Reverse KL | `reverse_kl` | `reverse_kl_autograd` | `reverse_kl_fwd` |
+
+ ### Entry points
+
+ Each kernel family exposes three entry points backed by the same underlying CuteDSL kernel:
+
+ - **`<name>(...)`**: fused fwd+bwd, returns `(loss, grad)` from one `@cute.jit`
+   dispatch. Lowest-overhead path; the caller chains the gradient into the upstream
+   model with `policy_logprobs.backward(grad)`. Use this in custom training loops
+   where you control gradient flow.
+ - **`<name>_autograd(...)`**: same kernel, registered via
+   `torch.library.custom_op` + `register_autograd`. `loss.backward()` works
+   and composes with `torch.compile(fullgraph=True)`. It adds per-call
+   dispatcher overhead relative to the direct path.
+ - **`<name>_fwd(...)`**: forward-only, returns scalar `loss` and skips
+   the gradient buffer entirely. Use for inference / validation /
+   reward-model scoring.
+
+ ## Loading the kernels
+
+ ```python
+ from kernels import get_kernel
+
+ km = get_kernel("Geometric-AI/geometric-ai-kernels", version=0)
+ ```
+
+ ---
+
+ ## BNPO Loss
+
+ **Batch-Normalized Policy Optimization** sums per-token policy and KL terms
+ across the **entire batch** and divides by the global valid-token count:
+
+ ```
+ loss = ((per_token_loss + β·kl) · mask).sum() / max(mask.sum(), 1)
+ ```
+
+ where `per_token_loss` is the PPO-clipped ratio loss:
+
+ ```
+ ratio          = exp(policy_logprobs − old_policy_logprobs)
+ clipped        = clip(ratio, 1−ε, 1+ε_high)
+ per_token_loss = −min(ratio · advantages, clipped · advantages)
+ kl             = exp(ref_logprobs − policy_logprobs) − (ref_logprobs − policy_logprobs) − 1
+ ```
+
+ The global denominator is computed entirely on-GPU via cross-CTA atomics,
+ with no host-side `mask.sum()` sync. When `beta=0` the KL branch is
+ eliminated as dead code at compile time.
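
The BNPO formulas above can be sketched as a plain-Python reference. This is a minimal sketch for checking the arithmetic on tiny inputs, not the kernel's implementation. It assumes the standard PPO clipped surrogate `−min(ratio · A, clipped · A)` and that the per-row advantage is broadcast over the sequence dimension. The names `per_token_terms` and `bnpo_loss_reference` are hypothetical and not part of the package.

```python
import math


def per_token_terms(logp, old_logp, ref_logp, adv, eps, eps_high, beta):
    """Clipped-ratio policy term plus beta-weighted KL penalty for one token."""
    ratio = math.exp(logp - old_logp)
    clipped = min(max(ratio, 1.0 - eps), 1.0 + eps_high)
    # Assumption: standard PPO surrogate, pessimistic over the two candidates.
    policy = -min(ratio * adv, clipped * adv)
    delta = ref_logp - logp
    kl = math.exp(delta) - delta - 1.0
    return policy + beta * kl


def bnpo_loss_reference(policy, old_policy, ref, advantages, mask,
                        eps=0.2, eps_high=0.2, beta=0.1):
    """Sum masked per-token terms over the whole batch, divide by the global count."""
    total, n_valid = 0.0, 0
    for row in range(len(policy)):
        for col in range(len(policy[row])):
            if mask[row][col]:
                total += per_token_terms(
                    policy[row][col], old_policy[row][col], ref[row][col],
                    advantages[row], eps, eps_high, beta,
                )
                n_valid += 1
    # max(..., 1) mirrors the clamp in the loss formula for fully masked batches.
    return total / max(n_valid, 1)
```

With identical policy, old-policy, and reference log-probabilities the ratio is 1 and the KL term is 0, so each valid token contributes `−advantage`, which makes small cases easy to verify by hand.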
+
+ **Inputs**:
+ - `policy_logprobs`, `old_policy_logprobs`, `ref_logprobs`: `(bs, seq_len)`, fp32/fp16/bf16
+ - `advantages`: `(bs,)`
+ - `completions_mask`: `(bs, seq_len)`, bool or int8
+
+ **Returns**: `(loss, grad_policy_logprobs)` from `bnpo_loss`; scalar `loss` from `bnpo_loss_fwd`.
+
+ ```python
+ import torch
+ from kernels import get_kernel
+
+ km = get_kernel("Geometric-AI/geometric-ai-kernels", version=0)
+ device = torch.device("cuda")
+
+ bs, seq_len = 16, 1024
+ policy_logprobs = torch.randn(bs, seq_len, dtype=torch.bfloat16, device=device)
+ old_policy_logprobs = torch.randn(bs, seq_len, dtype=torch.bfloat16, device=device)
+ ref_logprobs = torch.randn(bs, seq_len, dtype=torch.bfloat16, device=device)
+ advantages = torch.randn(bs, dtype=torch.bfloat16, device=device)
+ completions_mask = (torch.rand(bs, seq_len, device=device) > 0.2).to(torch.int8)
+
+ # 1) Direct (loss, grad): lowest-overhead training path
+ loss, grad = km.bnpo_loss(
+     policy_logprobs, old_policy_logprobs, ref_logprobs,
+     advantages, completions_mask,
+     epsilon=0.2, epsilon_high=0.2, beta=0.1,
+ )
+ # In a real loop policy_logprobs is produced by the model and carries a grad_fn.
+ # The random stand-in above does not, so this call needs a graph-attached tensor.
+ policy_logprobs.backward(grad)
+
+ # 2) Autograd-aware: works with loss.backward() and torch.compile
+ loss = km.bnpo_loss_autograd(
+     policy_logprobs.requires_grad_(),
+     old_policy_logprobs, ref_logprobs,
+     advantages, completions_mask,
+     epsilon=0.2, epsilon_high=0.2, beta=0.1,
+ )
+ loss.backward()
+
+ # 3) Forward-only: inference / reward scoring, no gradient buffer
+ loss = km.bnpo_loss_fwd(
+     policy_logprobs, old_policy_logprobs, ref_logprobs,
+     advantages, completions_mask,
+     epsilon=0.2, epsilon_high=0.2, beta=0.1,
+ )
+ ```
+
+ ---
+
+ ## GRPO Loss
+
+ **Group Relative Policy Optimization** implements TRL's default
+ **per-response normalization** variant, in which each response is normalized
+ by its own valid-token count before averaging across the batch:
+
+ ```
+ loss = mean_r( ((per_token_loss + β·kl) · mask).sum(-1) / max(mask.sum(-1), 1) )
+ ```
+
+ `per_token_loss` and `kl` are the same clipped-ratio and KL expressions as BNPO.
+ `completions_mask` is **required** because the per-response denominator is
+ mask-derived. The kernel uses one CTA per row so the per-row mask sum is
+ reduced inside the block, with no cross-CTA atomics on the scaling pass.
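
To make the difference between the two normalizations concrete, here is a small plain-Python sketch of the GRPO formula above next to a batch-pooled (BNPO-style) average. It is illustrative only and assumes the standard PPO clipped surrogate and a per-row advantage. `token_loss`, `grpo_loss_reference`, and `pooled_loss_reference` are hypothetical helper names.

```python
import math


def token_loss(logp, old_logp, ref_logp, adv, eps=0.2, eps_high=0.2, beta=0.1):
    """Per-token clipped policy term plus beta-weighted KL penalty."""
    ratio = math.exp(logp - old_logp)
    clipped = min(max(ratio, 1.0 - eps), 1.0 + eps_high)
    delta = ref_logp - logp
    kl = math.exp(delta) - delta - 1.0
    return -min(ratio * adv, clipped * adv) + beta * kl


def row_sums(policy, old_policy, ref, advantages, mask):
    """Return (masked loss sum, valid-token count) for every response."""
    out = []
    for r in range(len(policy)):
        valid = [c for c in range(len(policy[r])) if mask[r][c]]
        total = sum(
            token_loss(policy[r][c], old_policy[r][c], ref[r][c], advantages[r])
            for c in valid
        )
        out.append((total, len(valid)))
    return out


def grpo_loss_reference(policy, old_policy, ref, advantages, mask):
    """Normalize each response by its own valid-token count, then average rows."""
    sums = row_sums(policy, old_policy, ref, advantages, mask)
    return sum(total / max(count, 1) for total, count in sums) / len(sums)


def pooled_loss_reference(policy, old_policy, ref, advantages, mask):
    """BNPO-style contrast: one global denominator over all valid tokens."""
    sums = row_sums(policy, old_policy, ref, advantages, mask)
    return sum(t for t, _ in sums) / max(sum(c for _, c in sums), 1)
```

Under per-response normalization a one-token response carries the same weight as a four-token response, whereas the pooled average weights every valid token equally, so long responses dominate.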
+
+ **Inputs**:
+ - `policy_logprobs`, `old_policy_logprobs`, `ref_logprobs`: `(bs, seq_len)`, fp32/fp16/bf16
+ - `advantages`: `(bs,)`
+ - `completions_mask`: `(bs, seq_len)`, bool or int8 (**required**)
+
+ **Returns**: `(loss, grad_policy_logprobs)` from `grpo_loss`; scalar `loss` from `grpo_loss_fwd`.
+
+ ```python
+ import torch
+ from kernels import get_kernel
+
+ km = get_kernel("Geometric-AI/geometric-ai-kernels", version=0)
+ device = torch.device("cuda")
+
+ bs, seq_len = 16, 1024
+ policy_logprobs = torch.randn(bs, seq_len, dtype=torch.bfloat16, device=device)
+ old_policy_logprobs = torch.randn(bs, seq_len, dtype=torch.bfloat16, device=device)
+ ref_logprobs = torch.randn(bs, seq_len, dtype=torch.bfloat16, device=device)
+ advantages = torch.randn(bs, dtype=torch.bfloat16, device=device)
+ completions_mask = (torch.rand(bs, seq_len, device=device) > 0.2).to(torch.int8)
+
+ # 1) Direct (loss, grad): lowest-overhead training path
+ loss, grad = km.grpo_loss(
+     policy_logprobs, old_policy_logprobs, ref_logprobs,
+     advantages, completions_mask,
+     epsilon=0.2, epsilon_high=0.2, beta=0.1,
+ )
+ # In a real loop policy_logprobs is produced by the model and carries a grad_fn.
+ # The random stand-in above does not, so this call needs a graph-attached tensor.
+ policy_logprobs.backward(grad)
+
+ # 2) Autograd-aware: works with loss.backward() and torch.compile
+ loss = km.grpo_loss_autograd(
+     policy_logprobs.requires_grad_(),
+     old_policy_logprobs, ref_logprobs,
+     advantages, completions_mask,
+     epsilon=0.2, epsilon_high=0.2, beta=0.1,
+ )
+ loss.backward()
+
+ # 3) Forward-only: inference / reward scoring, no gradient buffer
+ loss = km.grpo_loss_fwd(
+     policy_logprobs, old_policy_logprobs, ref_logprobs,
+     advantages, completions_mask,
+     epsilon=0.2, epsilon_high=0.2, beta=0.1,
+ )
+ ```
+
+ ---
+
+ ## Reverse KL
+
+ **Reverse-KL self-distillation** computes `KL(student ‖ teacher)` over a
+ `(num_tokens, vocab)` slab using an online normalization algorithm that reads
+ each logit row exactly once on the forward-only path:
+
+ ```
+ p = softmax(student_logits)
+ q = softmax(teacher_logits)
+ kl_per_row = Σ_v p_v · (log p_v − log q_v)
+ loss = (mask · kl_per_row).sum() / mask.sum()
+ ```
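
One way a row can be reduced in a single pass is to keep running maxima and rescaled running sums, in the style of online softmax. The sketch below is an assumption about how such a pass could look, not a description of the kernel's actual algorithm. It uses the identity `KL = Σ_v p_v·(s_v − t_v) − logsumexp(s) + logsumexp(t)`. Both function names are hypothetical.

```python
import math


def reverse_kl_two_pass(s, t):
    """Textbook version: build both log-softmaxes, then sum p * (log p - log q)."""
    def log_softmax(x):
        m = max(x)
        lse = m + math.log(sum(math.exp(v - m) for v in x))
        return [v - lse for v in x]

    lp, lq = log_softmax(s), log_softmax(t)
    return sum(math.exp(a) * (a - b) for a, b in zip(lp, lq))


def reverse_kl_one_pass(s, t):
    """Single sweep over the row, rescaling accumulators when the max changes."""
    m_s = m_t = -math.inf
    z_s = z_t = 0.0  # running sums of exp(x - running max)
    w = 0.0          # running sum of exp(s - m_s) * (s - t)
    for sv, tv in zip(s, t):
        new_m_s = max(m_s, sv)
        rescale = math.exp(m_s - new_m_s)  # exp(-inf) is 0.0 on the first element
        e = math.exp(sv - new_m_s)
        z_s = z_s * rescale + e
        w = w * rescale + e * (sv - tv)
        m_s = new_m_s

        new_m_t = max(m_t, tv)
        z_t = z_t * math.exp(m_t - new_m_t) + math.exp(tv - new_m_t)
        m_t = new_m_t
    # w / z_s is the student-weighted mean of (s - t); the rest are the log-partitions.
    return w / z_s - (m_s + math.log(z_s)) + (m_t + math.log(z_t))
```

Both functions should agree to floating-point tolerance on any pair of equal-length rows, which gives a cheap sanity check for the single-pass form.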
+
+ The gradient through the softmax Jacobian has a closed form:
+
+ ```
+ grad_student_v = scale · p_v · (log p_v − log q_v − kl_per_row)
+ ```
+
+ where `scale = mask[r] · inv_n_valid` for row `r`, and `inv_n_valid = 1 / mask.sum()`.
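
The closed-form gradient can be checked numerically on a single row. The following plain-Python sketch compares it against central finite differences. It is a verification aid under the formulas stated above, with `scale` passed in directly. All helper names are hypothetical.

```python
import math


def log_softmax(x):
    m = max(x)
    lse = m + math.log(sum(math.exp(v - m) for v in x))
    return [v - lse for v in x]


def reverse_kl_row(s, t):
    """KL(softmax(s) || softmax(t)) for one row of logits."""
    lp, lq = log_softmax(s), log_softmax(t)
    return sum(math.exp(a) * (a - b) for a, b in zip(lp, lq))


def reverse_kl_grad_row(s, t, scale=1.0):
    """Analytic gradient w.r.t. student logits: scale * p * (log p - log q - KL)."""
    lp, lq = log_softmax(s), log_softmax(t)
    kl = sum(math.exp(a) * (a - b) for a, b in zip(lp, lq))
    return [scale * math.exp(a) * (a - b - kl) for a, b in zip(lp, lq)]


def finite_difference_grad(s, t, h=1e-6):
    """Central differences on each student logit, used only as a cross-check."""
    grads = []
    for i in range(len(s)):
        up, down = list(s), list(s)
        up[i] += h
        down[i] -= h
        grads.append((reverse_kl_row(up, t) - reverse_kl_row(down, t)) / (2 * h))
    return grads
```

Because softmax is invariant to adding a constant to every logit, the gradient entries for a row sum to zero, which is a second quick check on any implementation.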
+
+ **Inputs**:
+ - `student_logits`, `teacher_logits`: `(*, V)` with arbitrary leading dims (typically `(bs, seq_len, vocab)`); both must share shape and dtype
+ - `completions_mask`: shape matching `student_logits.shape[:-1]`
+
+ > ⚠️ **Fully-masked batches**: `inv_n_valid = 1 / mask.sum()` is not clamped, so a batch where every token is masked produces inf/NaN. Guard upstream if that case is reachable.
+
+ **Returns**: `(loss, grad_student_logits)` from `reverse_kl`; scalar `loss` from `reverse_kl_fwd`.
+
+ ```python
+ import torch
+ from kernels import get_kernel
+
+ km = get_kernel("Geometric-AI/geometric-ai-kernels", version=0)
+ device = torch.device("cuda")
+
+ # Qwen3.5-style vocab; arbitrary leading dims supported
+ bs, seq_len, vocab = 4, 256, 248320
+ student_logits = torch.randn(bs, seq_len, vocab, dtype=torch.bfloat16, device=device)
+ teacher_logits = torch.randn(bs, seq_len, vocab, dtype=torch.bfloat16, device=device)
+ completions_mask = (torch.rand(bs, seq_len, device=device) > 0.2)
+
+ # 1) Direct (loss, grad): lowest-overhead training path
+ loss, grad = km.reverse_kl(student_logits, teacher_logits, completions_mask)
+ # In a real loop student_logits is produced by the model and carries a grad_fn.
+ # The random stand-in above does not, so this call needs a graph-attached tensor.
+ student_logits.backward(grad)
+
+ # 2) Autograd-aware: works with loss.backward() and torch.compile
+ loss = km.reverse_kl_autograd(
+     student_logits.requires_grad_(), teacher_logits, completions_mask
+ )
+ loss.backward()
+
+ # 3) Forward-only: inference / KL monitoring, no gradient buffer
+ loss = km.reverse_kl_fwd(student_logits, teacher_logits, completions_mask)
+ ```
+
+ ---
+
+ ## Performance
+
+ Numbers below are geometric-mean speedups from our in-house benchmark
+ (`triton.testing.do_bench`, fresh subprocess per shape). Baselines are
+ **eager PyTorch** and **`torch.compile(mode="max-autotune-no-cudagraphs",
+ fullgraph=True)`** with `torch._dynamo.config.trace_autograd_ops = True` so
+ the compiled baseline is a real Inductor-fused fwd+bwd graph.
+
+ | Kernel | β | vs eager | vs `torch.compile` |
+ | --- | --- | --- | --- |
+ | `bnpo_loss_fwd` | 0 | 1.6× | 1.3× |
+ | `bnpo_loss` | 0 | 1.5× | 1.2× |
+ | `bnpo_loss_fwd` | ≠ 0 | 1.4× | 1.1× |
+ | `bnpo_loss` | ≠ 0 | 1.3× | 1.0× |
+ | `grpo_loss_fwd` | ≠ 0 | 1.5× | 1.2× |
+ | `grpo_loss` | ≠ 0 | 1.4× | 1.1× |
+ | `reverse_kl_fwd` | n/a | 1.3× | 1.1× |
+ | `reverse_kl` | n/a | 1.2× | 1.0× |
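
For readers unfamiliar with the summary statistic, a geometric-mean speedup can be computed as below. This sketch assumes the aggregation runs over per-shape latency ratios, which the text does not spell out, and the latency values are made up for illustration rather than taken from the benchmark. `geomean_speedup` is a hypothetical helper.

```python
from statistics import geometric_mean


def geomean_speedup(baseline_ms, kernel_ms):
    """Geometric mean of per-shape speedups (baseline latency / kernel latency)."""
    speedups = [b / k for b, k in zip(baseline_ms, kernel_ms)]
    return geometric_mean(speedups)


# Hypothetical latencies in milliseconds, one entry per benchmarked shape.
baseline_ms = [0.2, 0.8, 3.2]
kernel_ms = [0.2, 0.4, 0.8]
summary = geomean_speedup(baseline_ms, kernel_ms)
```

The geometric mean treats a 2× gain on a small shape and a 2× gain on a large shape identically, so one very fast or very slow shape does not dominate the summary the way it would in an arithmetic mean of ratios.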
+
+ Profiled on H100 SXM (SM90a). BNPO and GRPO were benchmarked separately for
+ `β = 0` (KL term eliminated as dead code at compile time) and `β ≠ 0`. Shapes:
+
+ - **BNPO / GRPO**: `(16, 1024)`, `(32, 2048)`, `(64, 4096)`, `(128, 8192)`,
+   `(128, 8193)`. The last entry exercises the predicated tail-tile path.
+ - **Reverse KL** (vocab = 248320, matching Qwen3.5): `(1, 64)`, `(2, 128)`,
+   `(4, 256)`, `(8, 512)`, `(16, 1024)`, `(8, 1981)`.
+
+ Reproduce locally:
+
+ ```bash
+ make bench-kernel KERNEL=grpo_loss # or bnpo_loss, reverse_kl
+ ```
+
+ ---
+
+ ## Benchmark animations
+
+ ### BNPO Loss vs eager PyTorch
+
+ <picture>
+ <source media="(prefers-color-scheme: dark)" srcset="benchmark_results/bnpo_loss_eager/bnpo_loss_eager_dark_animation.svg">
+ <img src="benchmark_results/bnpo_loss_eager/bnpo_loss_eager_light_animation.svg" alt="BNPO loss latency vs eager PyTorch">
+ </picture>
+
+ ### BNPO Loss vs torch.compile
+
+ <picture>
+ <source media="(prefers-color-scheme: dark)" srcset="benchmark_results/bnpo_loss_compiled/bnpo_loss_compiled_dark_animation.svg">
+ <img src="benchmark_results/bnpo_loss_compiled/bnpo_loss_compiled_light_animation.svg" alt="BNPO loss latency vs torch.compile">
+ </picture>
+
+ ### GRPO Loss vs eager PyTorch
+
+ <picture>
+ <source media="(prefers-color-scheme: dark)" srcset="benchmark_results/grpo_loss_eager/grpo_loss_eager_dark_animation.svg">
+ <img src="benchmark_results/grpo_loss_eager/grpo_loss_eager_light_animation.svg" alt="GRPO loss latency vs eager PyTorch">
+ </picture>
+
+ ### GRPO Loss vs torch.compile
+
+ <picture>
+ <source media="(prefers-color-scheme: dark)" srcset="benchmark_results/grpo_loss_compiled/grpo_loss_compiled_dark_animation.svg">
+ <img src="benchmark_results/grpo_loss_compiled/grpo_loss_compiled_light_animation.svg" alt="GRPO loss latency vs torch.compile">
+ </picture>
+
+ ### Reverse KL vs eager PyTorch
+
+ <picture>
+ <source media="(prefers-color-scheme: dark)" srcset="benchmark_results/reverse_kl_eager/reverse_kl_eager_dark_animation.svg">
+ <img src="benchmark_results/reverse_kl_eager/reverse_kl_eager_light_animation.svg" alt="Reverse KL latency vs eager PyTorch">
+ </picture>
+
+ ### Reverse KL vs torch.compile
+
+ <picture>
+ <source media="(prefers-color-scheme: dark)" srcset="benchmark_results/reverse_kl_compiled/reverse_kl_compiled_dark_animation.svg">
+ <img src="benchmark_results/reverse_kl_compiled/reverse_kl_compiled_light_animation.svg" alt="Reverse KL latency vs torch.compile">
+ </picture>