---
library_name: kernels
license: apache-2.0
tags:
- cuda
- cutlass
- cute-dsl
- rl
- distillation
- trl
- grpo
- bnpo
- kl-divergence
---

# Geometric-AI Kernels

Fused **CuteDSL** kernels for the loss functions that dominate post-training
workloads: PPO-family policy losses (BNPO, GRPO) and reverse-KL
self-distillation. 

Each kernel ships a **single-launch fused forward +
backward** path that returns `(loss, grad_logprobs)` directly. No `torch.autograd.Function` wrapper, no extra `grad_output * dpolicy` backward
kernel, and no host-side syncs in the hot path.

Background and benchmarks: see the
[release post](https://geometric.so/blog/2026/05/08/hf-kernel-hub).

- **Backend**: CUDA (NVIDIA CUTLASS DSL).
- **Min GPU**: SM80 (Ampere), required by `nvidia-cutlass-dsl`. Tested on H100 (SM90); expected to work on SM80 (A100), SM86 (RTX 3090, A40), SM89 (RTX 4090, L40S), SM90a (H100 SXM), and SM100 (Blackwell B200/GB200).
- **Min CUDA**: 12.8.
- **Dtypes**: `float32`, `float16`, `bfloat16`.
- **Dynamic shapes**: a single compile handles arbitrary batch size and
  sequence length, no recompiles when shapes change between calls (common
  in post-training rollouts).

## Kernels

| Kernel family | Direct (no autograd) | Autograd-aware | Forward-only |
| --- | --- | --- | --- |
| BNPO loss | `bnpo_loss` | `bnpo_loss_autograd` | `bnpo_loss_fwd` |
| GRPO loss | `grpo_loss` | `grpo_loss_autograd` | `grpo_loss_fwd` |
| Reverse KL | `reverse_kl` | `reverse_kl_autograd` | `reverse_kl_fwd` |

### Entry points

Each kernel family exposes three entry points backed by the same underlying CuteDSL kernel:

- **`<name>(...)`** - fused fwd+bwd, returns `(loss, grad)` from one `@cute.jit`
  dispatch. Lowest-overhead path; the caller chains the gradient into the upstream
  model with `policy_logprobs.backward(grad)`. Use this in custom training loops
  where you control gradient flow.
- **`<name>_autograd(...)`** - same kernel, registered via
  `torch.library.custom_op` + `register_autograd`. `loss.backward()` works
  and it composes with `torch.compile(fullgraph=True)` (see the sketch after
  this list). There is noticeable per-call dispatcher overhead vs. the direct path.
- **`<name>_fwd(...)`** - forward-only, returns scalar `loss` and skips
  the gradient buffer entirely. Use for inference / validation /
  reward-model scoring.
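
For example, the autograd-aware entry points can be captured whole by
`torch.compile` (a sketch, assuming the BNPO arguments shown later in this
README):

```python
import torch
from kernels import get_kernel

km = get_kernel("Geometric-AI/geometric-ai-kernels", version=0)

# The custom_op registration is what lets torch.compile trace through
# the kernel call without graph breaks.
@torch.compile(fullgraph=True)
def compiled_bnpo(policy_logprobs, old_policy_logprobs, ref_logprobs,
                  advantages, completions_mask):
    return km.bnpo_loss_autograd(
        policy_logprobs, old_policy_logprobs, ref_logprobs,
        advantages, completions_mask,
        epsilon=0.2, epsilon_high=0.2, beta=0.1,
    )
```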

## Loading the kernels
```bash
pip install kernels apache-tvm-ffi nvidia-cutlass-dsl
```

```python
from kernels import get_kernel

km = get_kernel("Geometric-AI/geometric-ai-kernels", version=0)
```

---

## BNPO Loss

**Batch-Normalized Policy Optimization** sums per-token policy and KL terms
across the **entire batch** and divides by the global valid-token count:

```
loss = ((per_token_loss + β·kl) · mask).sum() / max(mask.sum(), 1)
```

where `per_token_loss` is the PPO-clipped ratio loss:

```
ratio      = exp(policy_logprobs - old_policy_logprobs)
clipped    = clip(ratio, 1−ε, 1+ε_high)
per_token  = −min(ratio · advantages, clipped · advantages)
kl         = exp(ref_logprobs − policy_logprobs) − (ref_logprobs − policy_logprobs) − 1
```

The global denominator is computed entirely on-GPU via cross-CTA atomics -
no host-side `mask.sum()` sync. When `beta=0` the KL branch is eliminated
as dead code at compile time.
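
For reference, here is a minimal eager-PyTorch sketch of the same math (an
illustration of the formulas above, not the kernel's implementation; argument
names match the **Inputs** list below):

```python
import torch

def bnpo_loss_reference(policy_logprobs, old_policy_logprobs, ref_logprobs,
                        advantages, completions_mask,
                        epsilon=0.2, epsilon_high=0.2, beta=0.1):
    # PPO clipped-ratio policy term; advantages broadcast over seq_len.
    ratio = torch.exp(policy_logprobs - old_policy_logprobs)
    clipped = ratio.clamp(1 - epsilon, 1 + epsilon_high)
    adv = advantages.unsqueeze(-1)
    per_token = -torch.min(ratio * adv, clipped * adv)
    # k3 estimator of KL(policy || ref).
    delta = ref_logprobs - policy_logprobs
    kl = torch.exp(delta) - delta - 1
    mask = completions_mask.to(per_token.dtype)
    # One global denominator across all valid tokens in the batch.
    return ((per_token + beta * kl) * mask).sum() / mask.sum().clamp(min=1)
```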

**Inputs**:
- `policy_logprobs`, `old_policy_logprobs`, `ref_logprobs`: `(bs, seq_len)`, fp32/fp16/bf16
- `advantages`: `(bs,)`
- `completions_mask`: `(bs, seq_len)`, bool or int8

**Returns**: `(loss, grad_policy_logprobs)` from `bnpo_loss`; scalar `loss` from `bnpo_loss_fwd`.

```python
import torch
from kernels import get_kernel

km = get_kernel("Geometric-AI/geometric-ai-kernels", version=0)
device = torch.device("cuda")

bs, seq_len = 16, 1024
policy_logprobs     = torch.randn(bs, seq_len, dtype=torch.bfloat16, device=device, requires_grad=True)
old_policy_logprobs = torch.randn(bs, seq_len, dtype=torch.bfloat16, device=device)
ref_logprobs        = torch.randn(bs, seq_len, dtype=torch.bfloat16, device=device)
advantages          = torch.randn(bs, dtype=torch.bfloat16, device=device)
completions_mask    = (torch.rand(bs, seq_len, device=device) > 0.2).to(torch.int8)

# 1) Direct (loss, grad) - lowest overhead training path
loss, grad = km.bnpo_loss(
    policy_logprobs, old_policy_logprobs, ref_logprobs,
    advantages, completions_mask,
    epsilon=0.2, epsilon_high=0.2, beta=0.1,
)
policy_logprobs.backward(grad)

# 2) Autograd-aware - works with loss.backward() and torch.compile
loss = km.bnpo_loss_autograd(
    policy_logprobs.requires_grad_(),
    old_policy_logprobs, ref_logprobs,
    advantages, completions_mask,
    epsilon=0.2, epsilon_high=0.2, beta=0.1,
)
loss.backward()

# 3) Forward-only - inference / reward scoring, no gradient buffer
loss = km.bnpo_loss_fwd(
    policy_logprobs, old_policy_logprobs, ref_logprobs,
    advantages, completions_mask,
    epsilon=0.2, epsilon_high=0.2, beta=0.1,
)
```

---

## GRPO Loss

**Group Relative Policy Optimization** implements TRL's default
**per-response normalization** variant - each response is normalized by its
own valid-token count before averaging across the batch:

```
loss = mean_r( ((per_token_loss + β·kl) · mask).sum(-1) / max(mask.sum(-1), 1) )
```

`per_token_loss` and `kl` are the same clipped-ratio and KL expressions as BNPO.
`completions_mask` is **required** because the per-response denominator is
mask-derived. The kernel uses one CTA per row so the per-row mask sum is
reduced inside the block - no cross-CTA atomics on the scaling pass.
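
A matching eager sketch of the per-response normalization (illustrative only;
`per_token` and `kl` are the same expressions as in the BNPO sketch above):

```python
import torch

def grpo_loss_reference(policy_logprobs, old_policy_logprobs, ref_logprobs,
                        advantages, completions_mask,
                        epsilon=0.2, epsilon_high=0.2, beta=0.1):
    ratio = torch.exp(policy_logprobs - old_policy_logprobs)
    clipped = ratio.clamp(1 - epsilon, 1 + epsilon_high)
    adv = advantages.unsqueeze(-1)
    per_token = -torch.min(ratio * adv, clipped * adv)
    delta = ref_logprobs - policy_logprobs
    kl = torch.exp(delta) - delta - 1
    mask = completions_mask.to(per_token.dtype)
    # Per-response denominator: each row is normalized by its own valid-token
    # count, then rows are averaged across the batch.
    per_row = ((per_token + beta * kl) * mask).sum(-1) / mask.sum(-1).clamp(min=1)
    return per_row.mean()
```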

**Inputs**:
- `policy_logprobs`, `old_policy_logprobs`, `ref_logprobs`: `(bs, seq_len)`, fp32/fp16/bf16
- `advantages`: `(bs,)`
- `completions_mask`: `(bs, seq_len)`, bool or int8 - **required**

**Returns**: `(loss, grad_policy_logprobs)` from `grpo_loss`; scalar `loss` from `grpo_loss_fwd`.

```python
import torch
from kernels import get_kernel

km = get_kernel("Geometric-AI/geometric-ai-kernels", version=0)
device = torch.device("cuda")

bs, seq_len = 16, 1024
policy_logprobs     = torch.randn(bs, seq_len, dtype=torch.bfloat16, device=device, requires_grad=True)
old_policy_logprobs = torch.randn(bs, seq_len, dtype=torch.bfloat16, device=device)
ref_logprobs        = torch.randn(bs, seq_len, dtype=torch.bfloat16, device=device)
advantages          = torch.randn(bs, dtype=torch.bfloat16, device=device)
completions_mask    = (torch.rand(bs, seq_len, device=device) > 0.2).to(torch.int8)

# 1) Direct (loss, grad) - lowest overhead training path
loss, grad = km.grpo_loss(
    policy_logprobs, old_policy_logprobs, ref_logprobs,
    advantages, completions_mask,
    epsilon=0.2, epsilon_high=0.2, beta=0.1,
)
policy_logprobs.backward(grad)

# 2) Autograd-aware - works with loss.backward() and torch.compile
loss = km.grpo_loss_autograd(
    policy_logprobs.requires_grad_(),
    old_policy_logprobs, ref_logprobs,
    advantages, completions_mask,
    epsilon=0.2, epsilon_high=0.2, beta=0.1,
)
loss.backward()

# 3) Forward-only - inference / reward scoring, no gradient buffer
loss = km.grpo_loss_fwd(
    policy_logprobs, old_policy_logprobs, ref_logprobs,
    advantages, completions_mask,
    epsilon=0.2, epsilon_high=0.2, beta=0.1,
)
```

---

## Reverse KL

**Reverse-KL self-distillation** computes `KL(student ‖ teacher)` over a
`(num_tokens, vocab)` slab using an online normalization algorithm that reads
each logit row exactly once on the forward-only path:

```
p = softmax(student_logits)
q = softmax(teacher_logits)
kl_per_row = Σ_v  p_v · (log p_v − log q_v)
loss = (mask · kl_per_row).sum() / mask.sum()
```

The gradient through the softmax Jacobian is analytical:

```
grad_student_v = scale · p_v · (log p_v − log q_v − kl_per_row)
```

where `scale = mask[r] · inv_n_valid`.
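
A minimal eager sketch of the forward value and this analytical gradient
(useful for cross-checking; it materializes full softmaxes rather than using
the kernel's single-pass online normalization):

```python
import torch
import torch.nn.functional as F

def reverse_kl_reference(student_logits, teacher_logits, completions_mask):
    # log-softmax in fp32 for numerical stability.
    log_p = F.log_softmax(student_logits.float(), dim=-1)
    log_q = F.log_softmax(teacher_logits.float(), dim=-1)
    p = log_p.exp()
    kl_per_row = (p * (log_p - log_q)).sum(-1)      # shape (*,)
    mask = completions_mask.to(kl_per_row.dtype)
    inv_n_valid = 1.0 / mask.sum()                  # unclamped, as in the kernel
    loss = (mask * kl_per_row).sum() * inv_n_valid
    # Gradient through the softmax Jacobian, per the formula above.
    scale = (mask * inv_n_valid).unsqueeze(-1)
    grad = scale * p * (log_p - log_q - kl_per_row.unsqueeze(-1))
    return loss, grad.to(student_logits.dtype)
```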

**Inputs**:
- `student_logits`, `teacher_logits`: `(*, V)` - arbitrary leading dims (typically `(bs, seq_len, vocab)`); both must share shape and dtype
- `completions_mask`: shape matching `student_logits.shape[:-1]`

> ⚠️ **Fully-masked batches**: `inv_n_valid = 1 / mask.sum()` is not clamped, so a batch where every token is masked produces inf/NaN. Guard upstream if that case is reachable.
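
One possible upstream guard (a sketch, using the argument names from the
example below):

```python
# Skip fully-masked batches: a zero loss that keeps the graph alive.
if completions_mask.any():
    loss = km.reverse_kl_autograd(student_logits, teacher_logits, completions_mask)
else:
    loss = student_logits.sum() * 0.0
```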

**Returns**: `(loss, grad_student_logits)` from `reverse_kl`; scalar `loss` from `reverse_kl_fwd`.

```python
import torch
from kernels import get_kernel

km = get_kernel("Geometric-AI/geometric-ai-kernels", version=0)
device = torch.device("cuda")

# Qwen3.5-style vocab; arbitrary leading dims supported
bs, seq_len, vocab = 4, 256, 248320
student_logits  = torch.randn(bs, seq_len, vocab, dtype=torch.bfloat16, device=device, requires_grad=True)
teacher_logits  = torch.randn(bs, seq_len, vocab, dtype=torch.bfloat16, device=device)
completions_mask = (torch.rand(bs, seq_len, device=device) > 0.2)

# 1) Direct (loss, grad) - lowest overhead training path
loss, grad = km.reverse_kl(student_logits, teacher_logits, completions_mask)
student_logits.backward(grad)

# 2) Autograd-aware - works with loss.backward() and torch.compile
loss = km.reverse_kl_autograd(
    student_logits.requires_grad_(), teacher_logits, completions_mask
)
loss.backward()

# 3) Forward-only - inference / KL monitoring, no gradient buffer
loss = km.reverse_kl_fwd(student_logits, teacher_logits, completions_mask)
```

---

## Performance

All numbers are geometric-mean speedups measured on an H100 SXM (SM90a). Full
methodology and per-shape plots are in the
[release post](https://geometric.so/blog/2026/05/08/hf-kernel-hub).

### `kernels` CLI benchmark

Timed with `time.perf_counter` + `torch.cuda.synchronize()`, mean over 100 iterations.
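
A sketch of that timing pattern (illustrative; not the CLI's exact harness):

```python
import time
import torch

def time_kernel(fn, *args, iters=100, warmup=10):
    for _ in range(warmup):
        fn(*args)
    torch.cuda.synchronize()
    t0 = time.perf_counter()
    for _ in range(iters):
        fn(*args)
    torch.cuda.synchronize()
    return (time.perf_counter() - t0) / iters  # mean seconds per call
```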

| Kernel | vs eager | vs `torch.compile` |
| --- | --- | --- |
| `grpo_loss_fwd` | 5.68×  | 2.45× |
| `grpo_loss`     | 20.79× | 1.98× |
| `bnpo_loss_fwd` | 5.29×  | 2.52× |
| `bnpo_loss`     | 16.81× | 2.27× |
| `reverse_kl_fwd`| 6.88×  | 2.45× |
| `reverse_kl`    | 7.03×  | 2.61× |

---

## Benchmark animations

### BNPO Loss vs eager PyTorch

<picture>
  <source media="(prefers-color-scheme: dark)" srcset="benchmark_results/bnpo_loss_eager/bnpo_loss_eager_dark_animation.svg">
  <img width="90%" src="benchmark_results/bnpo_loss_eager/bnpo_loss_eager_light_animation.svg" alt="BNPO loss latency vs eager PyTorch">
</picture>

### BNPO Loss vs torch.compile

<picture>
  <source media="(prefers-color-scheme: dark)" srcset="benchmark_results/bnpo_loss_compiled/bnpo_loss_compiled_dark_animation.svg">
  <img width="90%" src="benchmark_results/bnpo_loss_compiled/bnpo_loss_compiled_light_animation.svg" alt="BNPO loss latency vs torch.compile">
</picture>

### GRPO Loss vs eager PyTorch

<picture>
  <source media="(prefers-color-scheme: dark)" srcset="benchmark_results/grpo_loss_eager/grpo_loss_eager_dark_animation.svg">
  <img width="90%" src="benchmark_results/grpo_loss_eager/grpo_loss_eager_light_animation.svg" alt="GRPO loss latency vs eager PyTorch">
</picture>

### GRPO Loss vs torch.compile

<picture>
  <source media="(prefers-color-scheme: dark)" srcset="benchmark_results/grpo_loss_compiled/grpo_loss_compiled_dark_animation.svg">
  <img width="90%" src="benchmark_results/grpo_loss_compiled/grpo_loss_compiled_light_animation.svg" alt="GRPO loss latency vs torch.compile">
</picture>

### Reverse KL vs eager PyTorch

<picture>
  <source media="(prefers-color-scheme: dark)" srcset="benchmark_results/reverse_kl_eager/reverse_kl_eager_dark_animation.svg">
  <img width="90%" src="benchmark_results/reverse_kl_eager/reverse_kl_eager_light_animation.svg" alt="Reverse KL latency vs eager PyTorch">
</picture>

### Reverse KL vs torch.compile

<picture>
  <source media="(prefers-color-scheme: dark)" srcset="benchmark_results/reverse_kl_compiled/reverse_kl_compiled_dark_animation.svg">
  <img width="90%" src="benchmark_results/reverse_kl_compiled/reverse_kl_compiled_light_animation.svg" alt="Reverse KL latency vs torch.compile">
</picture>