---
library_name: kernels
license: apache-2.0
tags:
- cuda
- cutlass
- cute-dsl
- rl
- distillation
- trl
- grpo
- bnpo
- kl-divergence
---
# Geometric-AI Kernels
Fused **CuteDSL** kernels for the loss functions that dominate post-training
workloads: PPO-family policy losses (BNPO, GRPO) and reverse-KL
self-distillation.
Each kernel ships a **single-launch fused forward +
backward** path that returns `(loss, grad_logprobs)` directly. No `torch.autograd.Function` wrapper, no extra `grad_output * dpolicy` backward
kernel, and no host-side syncs in the hot path.
Background and benchmarks: see the
[release post](https://geometric.so/blog/2026/05/08/hf-kernel-hub).
- **Backend**: CUDA (NVIDIA CUTLASS DSL).
- **Min GPU**: SM80 (Ampere), required by `nvidia-cutlass-dsl`. Tested on H100 (SM90). Expected to work on SM80 (Ampere A100), SM86 (RTX 3090, A40), SM89 (RTX 4090, L40S), SM90a (H100 SXM), and SM100 (Blackwell B200/GB200).
- **Min CUDA**: 12.8.
- **Dtypes**: `float32`, `float16`, `bfloat16`.
- **Dynamic shapes**: a single compile handles arbitrary batch size and
sequence length, no recompiles when shapes change between calls (common
in post-training rollouts).
## Kernels
| Kernel family | Direct (no autograd) | Autograd-aware | Forward-only |
| --- | --- | --- | --- |
| BNPO loss | `bnpo_loss` | `bnpo_loss_autograd` | `bnpo_loss_fwd` |
| GRPO loss | `grpo_loss` | `grpo_loss_autograd` | `grpo_loss_fwd` |
| Reverse KL | `reverse_kl` | `reverse_kl_autograd` | `reverse_kl_fwd` |
### Entry points
Each kernel family exposes three entry points with the same underlying CuteDSL kernel:
- **`<name>(...)`** - fused fwd+bwd, returns `(loss, grad)` from one `@cute.jit`
dispatch. Lowest-overhead path; the caller chains the gradient into the upstream
model with `policy_logprobs.backward(grad)`. Use this in custom training loops
where you control gradient flow.
- **`<name>_autograd(...)`** - same kernel, registered via
`torch.library.custom_op` + `register_autograd`. `loss.backward()` works
and composes with `torch.compile(fullgraph=True)`. There is a noticeable
per-call dispatcher overhead vs. the direct path.
- **`<name>_fwd(...)`** - forward-only, returns scalar `loss` and skips
the gradient buffer entirely. Use for inference / validation /
reward-model scoring.
## Loading the kernels
Install the runtime dependencies, then fetch the kernel bundle through the `kernels` package:
```
pip install apache-tvm-ffi nvidia-cutlass-dsl
```
```python
from kernels import get_kernel
km = get_kernel("Geometric-AI/geometric-ai-kernels", version=0)
```
---
## BNPO Loss
**Batch-Normalized Policy Optimization** sums per-token policy and KL terms
across the **entire batch** and divides by the global valid-token count:
```
loss = ((per_token_loss + β·kl) · mask).sum() / max(mask.sum(), 1)
```
where `per_token_loss` is the PPO-clipped ratio loss:
```
ratio = exp(policy_logprobs - old_policy_logprobs)
clipped = clip(ratio, 1−ε, 1+ε_high)
per_token = −advantages · min(ratio, clipped)
kl = exp(ref_logprobs − policy_logprobs) − (ref_logprobs − policy_logprobs) − 1
```
The global denominator is computed entirely on-GPU via cross-CTA atomics -
no host-side `mask.sum()` sync. When `beta=0` the KL branch is dead-coded
at compile time.
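The two formula blocks above can be collapsed into a short eager reference. This is a framework-agnostic NumPy sketch of the math only, not the fused kernel; the function name and argument order are illustrative:

```python
import numpy as np

def bnpo_loss_ref(policy_lp, old_lp, ref_lp, advantages, mask,
                  epsilon=0.2, epsilon_high=0.2, beta=0.1):
    """Eager reference for the BNPO loss: PPO-clipped ratio term plus
    optional KL penalty, normalized by the global valid-token count."""
    ratio = np.exp(policy_lp - old_lp)                       # (bs, seq_len)
    clipped = np.clip(ratio, 1.0 - epsilon, 1.0 + epsilon_high)
    per_token = -advantages[:, None] * np.minimum(ratio, clipped)
    kl = np.exp(ref_lp - policy_lp) - (ref_lp - policy_lp) - 1.0
    masked = (per_token + beta * kl) * mask
    return masked.sum() / max(mask.sum(), 1)                 # scalar loss
```

Useful as a correctness oracle when integrating the fused kernel into a training loop.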
**Inputs**:
- `policy_logprobs`, `old_policy_logprobs`, `ref_logprobs`: `(bs, seq_len)`, fp32/fp16/bf16
- `advantages`: `(bs,)`
- `completions_mask`: `(bs, seq_len)`, bool or int8
**Returns**: `(loss, grad_policy_logprobs)` from `bnpo_loss`; scalar `loss` from `bnpo_loss_fwd`.
```python
import torch
from kernels import get_kernel
km = get_kernel("Geometric-AI/geometric-ai-kernels", version=0)
device = torch.device("cuda")
bs, seq_len = 16, 1024
policy_logprobs = torch.randn(bs, seq_len, dtype=torch.bfloat16, device=device, requires_grad=True)
old_policy_logprobs = torch.randn(bs, seq_len, dtype=torch.bfloat16, device=device)
ref_logprobs = torch.randn(bs, seq_len, dtype=torch.bfloat16, device=device)
advantages = torch.randn(bs, dtype=torch.bfloat16, device=device)
completions_mask = (torch.rand(bs, seq_len, device=device) > 0.2).to(torch.int8)
# 1) Direct (loss, grad) - lowest overhead training path
loss, grad = km.bnpo_loss(
policy_logprobs, old_policy_logprobs, ref_logprobs,
advantages, completions_mask,
epsilon=0.2, epsilon_high=0.2, beta=0.1,
)
policy_logprobs.backward(grad)
# 2) Autograd-aware - works with loss.backward() and torch.compile
loss = km.bnpo_loss_autograd(
policy_logprobs.requires_grad_(),
old_policy_logprobs, ref_logprobs,
advantages, completions_mask,
epsilon=0.2, epsilon_high=0.2, beta=0.1,
)
loss.backward()
# 3) Forward-only - inference / reward scoring, no gradient buffer
loss = km.bnpo_loss_fwd(
policy_logprobs, old_policy_logprobs, ref_logprobs,
advantages, completions_mask,
epsilon=0.2, epsilon_high=0.2, beta=0.1,
)
```
---
## GRPO Loss
**Group Relative Policy Optimization** implements TRL's default
**per-response normalization** variant - each response is normalized by its
own valid-token count before averaging across the batch:
```
loss = mean_r( ((per_token_loss + β·kl) · mask).sum(-1) / max(mask.sum(-1), 1) )
```
`per_token_loss` and `kl` are the same clipped-ratio and KL expressions as BNPO.
`completions_mask` is **required** because the per-response denominator is
mask-derived. The kernel uses one CTA per row so the per-row mask sum is
reduced inside the block - no cross-CTA atomics on the scaling pass.
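The per-response normalization differs from BNPO only in where the denominator sits. A NumPy sketch of the reference math (illustrative names, not the fused kernel):

```python
import numpy as np

def grpo_loss_ref(policy_lp, old_lp, ref_lp, advantages, mask,
                  epsilon=0.2, epsilon_high=0.2, beta=0.1):
    """Eager reference for the GRPO loss: each response is normalized by
    its own valid-token count, then averaged across the batch."""
    ratio = np.exp(policy_lp - old_lp)
    clipped = np.clip(ratio, 1.0 - epsilon, 1.0 + epsilon_high)
    per_token = -advantages[:, None] * np.minimum(ratio, clipped)
    kl = np.exp(ref_lp - policy_lp) - (ref_lp - policy_lp) - 1.0
    masked = (per_token + beta * kl) * mask
    row_sums = masked.sum(axis=-1)                 # (bs,)
    row_counts = np.maximum(mask.sum(axis=-1), 1)  # per-response denominators
    return (row_sums / row_counts).mean()          # scalar loss
```

When every row has the same number of valid tokens the two normalizations coincide; they diverge as soon as response lengths differ.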
**Inputs**:
- `policy_logprobs`, `old_policy_logprobs`, `ref_logprobs`: `(bs, seq_len)`, fp32/fp16/bf16
- `advantages`: `(bs,)`
- `completions_mask`: `(bs, seq_len)`, bool or int8 - **required**
**Returns**: `(loss, grad_policy_logprobs)` from `grpo_loss`; scalar `loss` from `grpo_loss_fwd`.
```python
import torch
from kernels import get_kernel
km = get_kernel("Geometric-AI/geometric-ai-kernels", version=0)
device = torch.device("cuda")
bs, seq_len = 16, 1024
policy_logprobs = torch.randn(bs, seq_len, dtype=torch.bfloat16, device=device, requires_grad=True)
old_policy_logprobs = torch.randn(bs, seq_len, dtype=torch.bfloat16, device=device)
ref_logprobs = torch.randn(bs, seq_len, dtype=torch.bfloat16, device=device)
advantages = torch.randn(bs, dtype=torch.bfloat16, device=device)
completions_mask = (torch.rand(bs, seq_len, device=device) > 0.2).to(torch.int8)
# 1) Direct (loss, grad) - lowest overhead training path
loss, grad = km.grpo_loss(
policy_logprobs, old_policy_logprobs, ref_logprobs,
advantages, completions_mask,
epsilon=0.2, epsilon_high=0.2, beta=0.1,
)
policy_logprobs.backward(grad)
# 2) Autograd-aware - works with loss.backward() and torch.compile
loss = km.grpo_loss_autograd(
policy_logprobs.requires_grad_(),
old_policy_logprobs, ref_logprobs,
advantages, completions_mask,
epsilon=0.2, epsilon_high=0.2, beta=0.1,
)
loss.backward()
# 3) Forward-only - inference / reward scoring, no gradient buffer
loss = km.grpo_loss_fwd(
policy_logprobs, old_policy_logprobs, ref_logprobs,
advantages, completions_mask,
epsilon=0.2, epsilon_high=0.2, beta=0.1,
)
```
---
## Reverse KL
**Reverse-KL self-distillation** computes `KL(student ‖ teacher)` over a
`(num_tokens, vocab)` slab using an online normalization algorithm that reads
each logit row exactly once on the forward-only path:
```
p = softmax(student_logits)
q = softmax(teacher_logits)
kl_per_row = Σ_v p_v · (log p_v − log q_v)
loss = (mask · kl_per_row).sum() / mask.sum()
```
The gradient through the softmax Jacobian is analytical:
```
grad_student_v = scale · p_v · (log p_v − log q_v − kl_per_row)
```
where `scale = mask[r] · inv_n_valid`.
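The forward formula and the analytical gradient above can be checked against each other in a few lines. A NumPy sketch (reference math only, not the online single-pass kernel):

```python
import numpy as np

def reverse_kl_ref(student_logits, teacher_logits, mask):
    """Eager reference for reverse KL(student ‖ teacher) plus the
    analytical gradient through the student softmax Jacobian."""
    def log_softmax(x):
        x = x - x.max(axis=-1, keepdims=True)  # numerically stable shift
        return x - np.log(np.exp(x).sum(axis=-1, keepdims=True))
    log_p = log_softmax(student_logits)        # student log-probs
    log_q = log_softmax(teacher_logits)        # teacher log-probs
    p = np.exp(log_p)
    kl_per_row = (p * (log_p - log_q)).sum(axis=-1)
    inv_n_valid = 1.0 / mask.sum()             # NOT clamped, as in the kernel
    loss = (mask * kl_per_row).sum() * inv_n_valid
    scale = (mask * inv_n_valid)[..., None]
    grad = scale * p * (log_p - log_q - kl_per_row[..., None])
    return loss, grad
```

The gradient expression follows from `dp_i/ds_v = p_i(δ_iv − p_v)`; the `Σ p_i · d log p_i/ds_v` term vanishes, leaving exactly `p_v · (log p_v − log q_v − kl_per_row)` per row.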
**Inputs**:
- `student_logits`, `teacher_logits`: `(*, V)` - arbitrary leading dims (typically `(bs, seq_len, vocab)`); both must share shape and dtype
- `completions_mask`: shape matching `student_logits.shape[:-1]`
> ⚠️ **Fully-masked batches**: `inv_n_valid = 1 / mask.sum()` is not clamped, so a batch where every token is masked produces inf/NaN. Guard upstream if that case is reachable.
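One possible upstream guard, sketched framework-agnostically (`kl_fn` stands in for the kernel call; the helper name is hypothetical, not part of this package):

```python
import numpy as np

def reverse_kl_or_zero(kl_fn, student, teacher, mask):
    """Call kl_fn only when the mask has at least one valid token;
    otherwise return a zero loss and zero gradient, sidestepping the
    unclamped 1 / mask.sum() in the kernel."""
    if mask.sum() == 0:  # fully-masked batch: skip the kernel entirely
        return 0.0, np.zeros_like(student)
    return kl_fn(student, teacher, mask)
```

Note the `mask.sum()` check forces a host-side sync when `mask` lives on the GPU, so only add it if fully-masked batches are actually reachable in your data pipeline.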
**Returns**: `(loss, grad_student_logits)` from `reverse_kl`; scalar `loss` from `reverse_kl_fwd`.
```python
import torch
from kernels import get_kernel
km = get_kernel("Geometric-AI/geometric-ai-kernels", version=0)
device = torch.device("cuda")
# Qwen3.5-style vocab; arbitrary leading dims supported
bs, seq_len, vocab = 4, 256, 248320
student_logits = torch.randn(bs, seq_len, vocab, dtype=torch.bfloat16, device=device, requires_grad=True)
teacher_logits = torch.randn(bs, seq_len, vocab, dtype=torch.bfloat16, device=device)
completions_mask = (torch.rand(bs, seq_len, device=device) > 0.2)
# 1) Direct (loss, grad) - lowest overhead training path
loss, grad = km.reverse_kl(student_logits, teacher_logits, completions_mask)
student_logits.backward(grad)
# 2) Autograd-aware - works with loss.backward() and torch.compile
loss = km.reverse_kl_autograd(
student_logits.requires_grad_(), teacher_logits, completions_mask
)
loss.backward()
# 3) Forward-only - inference / KL monitoring, no gradient buffer
loss = km.reverse_kl_fwd(student_logits, teacher_logits, completions_mask)
```
---
## Performance
All numbers are geometric-mean speedups measured on an H100 SXM (SM90a). Full methodology
and per-shape plots are in the [release post](https://geometric.so/blog/2026/05/08/hf-kernel-hub).
### `kernels` CLI benchmark
Timed with `time.perf_counter` + `torch.cuda.synchronize()`, mean over 100 iterations.
| Kernel | vs eager | vs `torch.compile` |
| --- | --- | --- |
| `grpo_loss_fwd` | 5.68× | 2.45× |
| `grpo_loss` | 20.79× | 1.98× |
| `bnpo_loss_fwd` | 5.29× | 2.52× |
| `bnpo_loss` | 16.81× | 2.27× |
| `reverse_kl_fwd`| 6.88× | 2.45× |
| `reverse_kl` | 7.03× | 2.61× |
---
## Benchmark animations
### BNPO Loss vs eager PyTorch
<picture>
<source media="(prefers-color-scheme: dark)" srcset="benchmark_results/bnpo_loss_eager/bnpo_loss_eager_dark_animation.svg">
<img width="90%" src="benchmark_results/bnpo_loss_eager/bnpo_loss_eager_light_animation.svg" alt="BNPO loss latency vs eager PyTorch">
</picture>
### BNPO Loss vs torch.compile
<picture>
<source media="(prefers-color-scheme: dark)" srcset="benchmark_results/bnpo_loss_compiled/bnpo_loss_compiled_dark_animation.svg">
<img width="90%" src="benchmark_results/bnpo_loss_compiled/bnpo_loss_compiled_light_animation.svg" alt="BNPO loss latency vs torch.compile">
</picture>
### GRPO Loss vs eager PyTorch
<picture>
<source media="(prefers-color-scheme: dark)" srcset="benchmark_results/grpo_loss_eager/grpo_loss_eager_dark_animation.svg">
<img width="90%" src="benchmark_results/grpo_loss_eager/grpo_loss_eager_light_animation.svg" alt="GRPO loss latency vs eager PyTorch">
</picture>
### GRPO Loss vs torch.compile
<picture>
<source media="(prefers-color-scheme: dark)" srcset="benchmark_results/grpo_loss_compiled/grpo_loss_compiled_dark_animation.svg">
<img width="90%" src="benchmark_results/grpo_loss_compiled/grpo_loss_compiled_light_animation.svg" alt="GRPO loss latency vs torch.compile">
</picture>
### Reverse KL vs eager PyTorch
<picture>
<source media="(prefers-color-scheme: dark)" srcset="benchmark_results/reverse_kl_eager/reverse_kl_eager_dark_animation.svg">
<img width="90%" src="benchmark_results/reverse_kl_eager/reverse_kl_eager_light_animation.svg" alt="Reverse KL latency vs eager PyTorch">
</picture>
### Reverse KL vs torch.compile
<picture>
<source media="(prefers-color-scheme: dark)" srcset="benchmark_results/reverse_kl_compiled/reverse_kl_compiled_dark_animation.svg">
<img width="90%" src="benchmark_results/reverse_kl_compiled/reverse_kl_compiled_light_animation.svg" alt="Reverse KL latency vs torch.compile">
</picture> |