---
library_name: kernels
license: apache-2.0
tags:
- cuda
- cutlass
- cute-dsl
- rl
- distillation
- trl
- grpo
- bnpo
- kl-divergence
---
# Geometric-AI Kernels
Fused **CuteDSL** kernels for the loss functions that dominate post-training
workloads: PPO-family policy losses (BNPO, GRPO) and reverse-KL
self-distillation.
Each kernel ships a **single-launch fused forward +
backward** path that returns `(loss, grad_logprobs)` directly. No `torch.autograd.Function` wrapper, no extra `grad_output * dpolicy` backward
kernel, and no host-side syncs in the hot path.
Background and benchmarks: see the
[release post](https://geometric.so/blog/2026/05/08/hf-kernel-hub).
- **Backend**: CUDA (NVIDIA CUTLASS DSL).
- **Min GPU**: SM80 (Ampere) - required by `nvidia-cutlass-dsl`. Tested on H100 (SM90). Should work on SM80 (Ampere), SM86 (RTX 3090, A40), SM89 (RTX 4090, L40S), SM90a (H100 SXM), and SM100 (Blackwell B200/GB200).
- **Min CUDA**: 12.8.
- **Dtypes**: `float32`, `float16`, `bfloat16`.
- **Dynamic shapes**: a single compile handles arbitrary batch size and
sequence length, with no recompiles when shapes change between calls (common
in post-training rollouts).
## Kernels
| Kernel family | Direct (no autograd) | Autograd-aware | Forward-only |
| --- | --- | --- | --- |
| BNPO loss | `bnpo_loss` | `bnpo_loss_autograd` | `bnpo_loss_fwd` |
| GRPO loss | `grpo_loss` | `grpo_loss_autograd` | `grpo_loss_fwd` |
| Reverse KL | `reverse_kl` | `reverse_kl_autograd` | `reverse_kl_fwd` |
### Entry points
Each kernel family exposes three entry points backed by the same underlying CuteDSL kernel:
- **`<name>(...)`** - fused fwd+bwd, returns `(loss, grad)` from one `@cute.jit`
dispatch. Lowest-overhead path; the caller chains the gradient into the upstream
model with `policy_logprobs.backward(grad)`. Use this in custom training loops
where you control gradient flow.
- **`<name>_autograd(...)`** - same kernel, registered via
`torch.library.custom_op` + `register_autograd`. `loss.backward()` works
and composes with `torch.compile(fullgraph=True)`. There is a noticeable
per-call dispatcher overhead vs. the direct path.
- **`<name>_fwd(...)`** - forward-only, returns scalar `loss` and skips
the gradient buffer entirely. Use for inference / validation /
reward-model scoring.
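The gradient-chaining pattern used by the direct path can be sketched on CPU without the kernels themselves. Here `toy_direct_loss` is a hypothetical stand-in for a direct entry point such as `bnpo_loss`: it computes a loss outside autograd and returns the gradient with respect to its input. The `weight` and `inputs` tensors are likewise illustrative.

```python
import torch

def toy_direct_loss(logprobs: torch.Tensor):
    """Hypothetical stand-in for a direct-path kernel: returns (loss, dloss/dlogprobs)."""
    detached = logprobs.detach()                     # computed outside autograd
    loss = (detached ** 2).mean()
    grad = 2.0 * detached / detached.numel()         # analytic gradient of mean(x^2)
    return loss, grad

torch.manual_seed(0)
weight = torch.randn(8, 4, requires_grad=True)       # stand-in for model parameters
inputs = torch.randn(3, 8)

# Direct-path pattern: take (loss, grad) from the loss function, then push
# the gradient into the upstream graph by hand.
logprobs = inputs @ weight
loss, grad = toy_direct_loss(logprobs)
logprobs.backward(grad)
direct_grad = weight.grad.clone()

# Autograd pattern for comparison: the same math recorded by autograd.
weight.grad = None
((inputs @ weight) ** 2).mean().backward()
autograd_grad = weight.grad.clone()

grads_match = torch.allclose(direct_grad, autograd_grad, atol=1e-6)
```

Both routes leave the same gradient on `weight`; the direct route simply skips building an autograd node for the loss itself.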
## Loading the kernels
```
pip install apache-tvm-ffi nvidia-cutlass-dsl
```
```python
from kernels import get_kernel
km = get_kernel("Geometric-AI/geometric-ai-kernels", version=0)
```
---
## BNPO Loss
**Batch-Normalized Policy Optimization** sums per-token policy and KL terms
across the **entire batch** and divides by the global valid-token count:
```
loss = ((per_token_loss + β·kl) · mask).sum() / max(mask.sum(), 1)
```
where `per_token_loss` is the PPO-clipped ratio loss and `kl` is the per-token KL term:
```
ratio          = exp(policy_logprobs − old_policy_logprobs)
clipped        = clip(ratio, 1−ε, 1+ε_high)
per_token_loss = −advantages · min(ratio, clipped)
kl             = exp(ref_logprobs − policy_logprobs) − (ref_logprobs − policy_logprobs) − 1
```
The global denominator is computed entirely on-GPU via cross-CTA atomics, with
no host-side `mask.sum()` sync. When `beta=0` the KL branch is eliminated as
dead code at compile time.
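For checking outputs against plain PyTorch, an eager reference of the BNPO formula might look like the sketch below. `bnpo_loss_reference` is a hypothetical name, not part of this package. It assumes the standard PPO pairing used by TRL, `−min(ratio · A, clipped · A)`, which coincides with the expression above whenever the advantage is non-negative. Treat that choice, and the float32 accumulation, as assumptions about the kernel rather than a description of it.

```python
import torch

def bnpo_loss_reference(policy_logprobs, old_policy_logprobs, ref_logprobs,
                        advantages, completions_mask,
                        epsilon=0.2, epsilon_high=0.2, beta=0.1):
    """Eager reference for the BNPO formula (not the fused kernel)."""
    mask = completions_mask.to(torch.float32)
    policy = policy_logprobs.float()
    adv = advantages.float().unsqueeze(-1)            # (bs, 1), broadcast over tokens

    ratio = torch.exp(policy - old_policy_logprobs.float())
    clipped = torch.clamp(ratio, 1.0 - epsilon, 1.0 + epsilon_high)
    # Assumption: advantage is applied to both terms before the min (TRL-style).
    per_token_loss = -torch.min(ratio * adv, clipped * adv)

    if beta != 0.0:                                   # mirrors the beta=0 fast path
        delta = ref_logprobs.float() - policy
        kl = torch.exp(delta) - delta - 1.0
        per_token_loss = per_token_loss + beta * kl

    # Global normalization: one denominator for the whole batch.
    return (per_token_loss * mask).sum() / mask.sum().clamp(min=1.0)

torch.manual_seed(0)
lp = torch.randn(4, 6)
example_loss = bnpo_loss_reference(
    lp, lp + 0.1 * torch.randn(4, 6), lp + 0.1 * torch.randn(4, 6),
    torch.randn(4), torch.rand(4, 6) > 0.2,
)
```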
**Inputs**:
- `policy_logprobs`, `old_policy_logprobs`, `ref_logprobs`: `(bs, seq_len)`, fp32/fp16/bf16
- `advantages`: `(bs,)`
- `completions_mask`: `(bs, seq_len)`, bool or int8
**Returns**: `(loss, grad_policy_logprobs)` from `bnpo_loss`; scalar `loss` from `bnpo_loss_fwd`.
```python
import torch
from kernels import get_kernel

km = get_kernel("Geometric-AI/geometric-ai-kernels", version=0)

device = torch.device("cuda")
bs, seq_len = 16, 1024
policy_logprobs = torch.randn(bs, seq_len, dtype=torch.bfloat16, device=device, requires_grad=True)
old_policy_logprobs = torch.randn(bs, seq_len, dtype=torch.bfloat16, device=device)
ref_logprobs = torch.randn(bs, seq_len, dtype=torch.bfloat16, device=device)
advantages = torch.randn(bs, dtype=torch.bfloat16, device=device)
completions_mask = (torch.rand(bs, seq_len, device=device) > 0.2).to(torch.int8)

# 1) Direct (loss, grad) - lowest overhead training path
loss, grad = km.bnpo_loss(
    policy_logprobs, old_policy_logprobs, ref_logprobs,
    advantages, completions_mask,
    epsilon=0.2, epsilon_high=0.2, beta=0.1,
)
policy_logprobs.backward(grad)

# 2) Autograd-aware - works with loss.backward() and torch.compile
loss = km.bnpo_loss_autograd(
    policy_logprobs.requires_grad_(),
    old_policy_logprobs, ref_logprobs,
    advantages, completions_mask,
    epsilon=0.2, epsilon_high=0.2, beta=0.1,
)
loss.backward()

# 3) Forward-only - inference / reward scoring, no gradient buffer
loss = km.bnpo_loss_fwd(
    policy_logprobs, old_policy_logprobs, ref_logprobs,
    advantages, completions_mask,
    epsilon=0.2, epsilon_high=0.2, beta=0.1,
)
```
---
## GRPO Loss
**Group Relative Policy Optimization** implements TRL's default
**per-response normalization** variant - each response is normalized by its
own valid-token count before averaging across the batch:
```
loss = mean_r( ((per_token_loss + β·kl) · mask).sum(-1) / max(mask.sum(-1), 1) )
```
`per_token_loss` and `kl` are the same clipped-ratio and KL expressions as BNPO.
`completions_mask` is **required** because the per-response denominator is
mask-derived. The kernel uses one CTA per row so the per-row mask sum is
reduced inside the block - no cross-CTA atomics on the scaling pass.
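The difference between the two normalizations is easiest to see on a tiny input. In the sketch below, `per_token` stands for the already-combined `per_token_loss + β·kl` values. The helper names `bnpo_reduce` and `grpo_reduce` are illustrative only and do not exist in the package.

```python
import torch

def bnpo_reduce(per_token, mask):
    """Global normalization: one denominator for the whole batch."""
    mask = mask.float()
    return (per_token * mask).sum() / mask.sum().clamp(min=1.0)

def grpo_reduce(per_token, mask):
    """Per-response normalization: each row divided by its own valid-token count."""
    mask = mask.float()
    per_row = (per_token * mask).sum(-1) / mask.sum(-1).clamp(min=1.0)
    return per_row.mean()

per_token = torch.tensor([[1.0, 1.0, 1.0, 1.0],
                          [4.0, 0.0, 0.0, 0.0]])
mask = torch.tensor([[1, 1, 1, 1],
                     [1, 0, 0, 0]], dtype=torch.int8)

global_loss = bnpo_reduce(per_token, mask)        # (4 + 4) / 5 valid tokens = 1.6
per_response_loss = grpo_reduce(per_token, mask)  # mean(4/4, 4/1) = 2.5
```

With per-response normalization the short second response carries as much weight as the long first one, whereas the global denominator weights every valid token equally.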
**Inputs**:
- `policy_logprobs`, `old_policy_logprobs`, `ref_logprobs`: `(bs, seq_len)`, fp32/fp16/bf16
- `advantages`: `(bs,)`
- `completions_mask`: `(bs, seq_len)`, bool or int8 - **required**
**Returns**: `(loss, grad_policy_logprobs)` from `grpo_loss`; scalar `loss` from `grpo_loss_fwd`.
```python
import torch
from kernels import get_kernel

km = get_kernel("Geometric-AI/geometric-ai-kernels", version=0)

device = torch.device("cuda")
bs, seq_len = 16, 1024
policy_logprobs = torch.randn(bs, seq_len, dtype=torch.bfloat16, device=device, requires_grad=True)
old_policy_logprobs = torch.randn(bs, seq_len, dtype=torch.bfloat16, device=device)
ref_logprobs = torch.randn(bs, seq_len, dtype=torch.bfloat16, device=device)
advantages = torch.randn(bs, dtype=torch.bfloat16, device=device)
completions_mask = (torch.rand(bs, seq_len, device=device) > 0.2).to(torch.int8)

# 1) Direct (loss, grad) - lowest overhead training path
loss, grad = km.grpo_loss(
    policy_logprobs, old_policy_logprobs, ref_logprobs,
    advantages, completions_mask,
    epsilon=0.2, epsilon_high=0.2, beta=0.1,
)
policy_logprobs.backward(grad)

# 2) Autograd-aware - works with loss.backward() and torch.compile
loss = km.grpo_loss_autograd(
    policy_logprobs.requires_grad_(),
    old_policy_logprobs, ref_logprobs,
    advantages, completions_mask,
    epsilon=0.2, epsilon_high=0.2, beta=0.1,
)
loss.backward()

# 3) Forward-only - inference / reward scoring, no gradient buffer
loss = km.grpo_loss_fwd(
    policy_logprobs, old_policy_logprobs, ref_logprobs,
    advantages, completions_mask,
    epsilon=0.2, epsilon_high=0.2, beta=0.1,
)
```
---
## Reverse KL
**Reverse-KL self-distillation** computes `KL(student ‖ teacher)` over a
`(num_tokens, vocab)` slab using an online normalization algorithm that reads
each logit row exactly once on the forward-only path:
```
p = softmax(student_logits)
q = softmax(teacher_logits)
kl_per_row = Σ_v p_v · (log p_v − log q_v)
loss = (mask · kl_per_row).sum() / mask.sum()
```
The gradient through the softmax Jacobian is analytical:
```
grad_student_v = scale · p_v · (log p_v − log q_v − kl_per_row)
```
where `scale = mask[r] · inv_n_valid` for token row `r`, and `inv_n_valid = 1 / mask.sum()`.
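The analytic gradient can be checked against autograd with a small eager reference. `reverse_kl_reference` below is a hypothetical helper written for this check, not the fused kernel. It materializes full softmax rows instead of using the online algorithm, and its float32 accumulation is an assumption.

```python
import torch
import torch.nn.functional as F

def reverse_kl_reference(student_logits, teacher_logits, completions_mask):
    """Eager reference returning (loss, grad_student_logits); not the fused kernel."""
    log_p = F.log_softmax(student_logits.float(), dim=-1)
    log_q = F.log_softmax(teacher_logits.float(), dim=-1)
    p = log_p.exp()
    mask = completions_mask.float()

    kl_per_row = (p * (log_p - log_q)).sum(-1)            # one value per token row
    inv_n_valid = 1.0 / mask.sum()                        # unclamped, as in the formula
    loss = (mask * kl_per_row).sum() * inv_n_valid

    scale = (mask * inv_n_valid).unsqueeze(-1)            # mask[r] * inv_n_valid
    grad = scale * p * (log_p - log_q - kl_per_row.unsqueeze(-1))
    return loss, grad

torch.manual_seed(0)
student = torch.randn(2, 5, 11, requires_grad=True)
teacher = torch.randn(2, 5, 11)
mask = torch.tensor([[1, 1, 1, 0, 0],
                     [1, 1, 1, 1, 0]], dtype=torch.bool)

loss, analytic_grad = reverse_kl_reference(student, teacher, mask)
(autograd_grad,) = torch.autograd.grad(loss, student)
max_abs_err = (analytic_grad.detach() - autograd_grad).abs().max().item()
```

Masked rows receive an exactly zero gradient because `scale` is zero there, and each unmasked row's gradient sums to zero across the vocabulary, as expected for a function of softmax outputs.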
**Inputs**:
- `student_logits`, `teacher_logits`: `(*, V)` - arbitrary leading dims (typically `(bs, seq_len, vocab)`); both must share shape and dtype
- `completions_mask`: shape matching `student_logits.shape[:-1]`
> ⚠️ **Fully-masked batches**: `inv_n_valid = 1 / mask.sum()` is not clamped, so a batch where every token is masked produces inf/NaN. Guard upstream if that case is reachable.
**Returns**: `(loss, grad_student_logits)` from `reverse_kl`; scalar `loss` from `reverse_kl_fwd`.
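One way to guard the fully-masked case flagged above is a thin wrapper that skips the call when no token is valid. `guarded_loss_and_grad` and the stand-in `unclamped_mean` are hypothetical names used only for this sketch. Returning a zero loss and zero gradient is one possible policy, not something the kernels do themselves. The check also forces a host sync, so it trades away part of the sync-free hot path.

```python
import torch

def guarded_loss_and_grad(loss_fn, student_logits, teacher_logits, completions_mask):
    """Hypothetical wrapper: skip the call when every token is masked."""
    # Converting .any() to a Python bool forces a device-to-host sync on CUDA.
    if not bool(completions_mask.any()):
        zero_loss = torch.zeros((), dtype=torch.float32, device=student_logits.device)
        return zero_loss, torch.zeros_like(student_logits)
    return loss_fn(student_logits, teacher_logits, completions_mask)

def unclamped_mean(student_logits, teacher_logits, completions_mask):
    # Stand-in with the same unclamped denominator; not the real kernel.
    mask = completions_mask.float()
    loss = (mask * (student_logits - teacher_logits).sum(-1)).sum() / mask.sum()
    return loss, torch.zeros_like(student_logits)

logits = torch.zeros(2, 3, 5)
empty_mask = torch.zeros(2, 3, dtype=torch.bool)

raw_loss, _ = unclamped_mean(logits, logits, empty_mask)   # 0 / 0 -> NaN
safe_loss, safe_grad = guarded_loss_and_grad(unclamped_mean, logits, logits, empty_mask)
```

In a real loop, `loss_fn` would be something like `km.reverse_kl`. If the sync is unacceptable, the alternative is to ensure upstream that every batch contains at least one valid token.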
```python
import torch
from kernels import get_kernel

km = get_kernel("Geometric-AI/geometric-ai-kernels", version=0)

device = torch.device("cuda")
# Qwen3.5-style vocab; arbitrary leading dims supported
bs, seq_len, vocab = 4, 256, 248320
student_logits = torch.randn(bs, seq_len, vocab, dtype=torch.bfloat16, device=device, requires_grad=True)
teacher_logits = torch.randn(bs, seq_len, vocab, dtype=torch.bfloat16, device=device)
completions_mask = torch.rand(bs, seq_len, device=device) > 0.2

# 1) Direct (loss, grad) - lowest overhead training path
loss, grad = km.reverse_kl(student_logits, teacher_logits, completions_mask)
student_logits.backward(grad)

# 2) Autograd-aware - works with loss.backward() and torch.compile
loss = km.reverse_kl_autograd(
    student_logits.requires_grad_(), teacher_logits, completions_mask
)
loss.backward()

# 3) Forward-only - inference / KL monitoring, no gradient buffer
loss = km.reverse_kl_fwd(student_logits, teacher_logits, completions_mask)
```
---
## Performance
All numbers are geometric-mean speedups measured on an H100 SXM (SM90a). Full methodology
and per-shape plots are in the [release post](https://geometric.so/blog/2026/05/08/hf-kernel-hub).
### `kernels` CLI benchmark
Timed with `time.perf_counter` + `torch.cuda.synchronize()`, mean over 100 iterations.
| Kernel | vs eager | vs `torch.compile` |
| --- | --- | --- |
| `grpo_loss_fwd` | 5.68× | 2.45× |
| `grpo_loss` | 20.79× | 1.98× |
| `bnpo_loss_fwd` | 5.29× | 2.52× |
| `bnpo_loss` | 16.81× | 2.27× |
| `reverse_kl_fwd` | 6.88× | 2.45× |
| `reverse_kl` | 7.03× | 2.61× |
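A timing loop in the spirit of the method described above could look like the following sketch. `mean_latency_ms` is a hypothetical helper. The warm-up count and the choice to synchronize once around the whole loop, rather than after every call, are assumptions and not necessarily what the `kernels` CLI benchmark does. The CUDA syncs are skipped on machines without a GPU so the snippet still runs.

```python
import time
import torch

def mean_latency_ms(fn, iters=100, warmup=10):
    """Mean wall-clock latency of fn() in milliseconds."""
    use_cuda = torch.cuda.is_available()
    for _ in range(warmup):           # warm-up so one-time compilation is not timed
        fn()
    if use_cuda:
        torch.cuda.synchronize()      # drain pending GPU work before starting the clock
    start = time.perf_counter()
    for _ in range(iters):
        fn()
    if use_cuda:
        torch.cuda.synchronize()      # wait for all timed launches to finish
    elapsed = time.perf_counter() - start
    return elapsed * 1e3 / iters

x = torch.randn(256, 256)
latency = mean_latency_ms(lambda: x @ x)
```

A speedup figure is then the ratio of the baseline's latency to the kernel's latency for the same inputs.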
---
## Benchmark animations
### BNPO Loss vs eager PyTorch
<picture>
<source media="(prefers-color-scheme: dark)" srcset="benchmark_results/bnpo_loss_eager/bnpo_loss_eager_dark_animation.svg">
<img width="90%" src="benchmark_results/bnpo_loss_eager/bnpo_loss_eager_light_animation.svg" alt="BNPO loss latency vs eager PyTorch">
</picture>
### BNPO Loss vs torch.compile
<picture>
<source media="(prefers-color-scheme: dark)" srcset="benchmark_results/bnpo_loss_compiled/bnpo_loss_compiled_dark_animation.svg">
<img width="90%" src="benchmark_results/bnpo_loss_compiled/bnpo_loss_compiled_light_animation.svg" alt="BNPO loss latency vs torch.compile">
</picture>
### GRPO Loss vs eager PyTorch
<picture>
<source media="(prefers-color-scheme: dark)" srcset="benchmark_results/grpo_loss_eager/grpo_loss_eager_dark_animation.svg">
<img width="90%" src="benchmark_results/grpo_loss_eager/grpo_loss_eager_light_animation.svg" alt="GRPO loss latency vs eager PyTorch">
</picture>
### GRPO Loss vs torch.compile
<picture>
<source media="(prefers-color-scheme: dark)" srcset="benchmark_results/grpo_loss_compiled/grpo_loss_compiled_dark_animation.svg">
<img width="90%" src="benchmark_results/grpo_loss_compiled/grpo_loss_compiled_light_animation.svg" alt="GRPO loss latency vs torch.compile">
</picture>
### Reverse KL vs eager PyTorch
<picture>
<source media="(prefers-color-scheme: dark)" srcset="benchmark_results/reverse_kl_eager/reverse_kl_eager_dark_animation.svg">
<img width="90%" src="benchmark_results/reverse_kl_eager/reverse_kl_eager_light_animation.svg" alt="Reverse KL latency vs eager PyTorch">
</picture>
### Reverse KL vs torch.compile
<picture>
<source media="(prefers-color-scheme: dark)" srcset="benchmark_results/reverse_kl_compiled/reverse_kl_compiled_dark_animation.svg">
<img width="90%" src="benchmark_results/reverse_kl_compiled/reverse_kl_compiled_light_animation.svg" alt="Reverse KL latency vs torch.compile">
</picture>