---
library_name: kernels
license: apache-2.0
tags:
- cuda
- cutlass
- cute-dsl
- rl
- distillation
- trl
- grpo
- bnpo
- kl-divergence
---

# Geometric-AI Kernels


Fused **CuteDSL** kernels for the loss functions that dominate post-training
workloads: PPO-family policy losses (BNPO, GRPO) and reverse-KL
self-distillation.


Each kernel ships a **single-launch fused forward + backward** path that
returns `(loss, grad_logprobs)` directly. No `torch.autograd.Function`
wrapper, no extra `grad_output * dpolicy` backward kernel, and no host-side
syncs in the hot path.


Background and benchmarks: see the
[release post](https://geometric.so/blog/2026/05/08/hf-kernel-hub).


- **Backend**: CUDA (NVIDIA CUTLASS DSL).
- **Min GPU**: SM80 (Ampere), required by `nvidia-cutlass-dsl`. Tested on H100 (SM90). Expected to also work on SM86 (RTX 3090, A40), SM89 (RTX 4090, L40S), SM90a (H100 SXM), and SM100 (Blackwell B200/GB200).
- **Min CUDA**: 12.8.
- **Dtypes**: `float32`, `float16`, `bfloat16`.
- **Dynamic shapes**: a single compile handles arbitrary batch sizes and sequence lengths; no recompiles when shapes change between calls (common in post-training rollouts). See the sketch below.
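
A minimal sketch of the no-recompile behavior, calling one entry point at two
different rollout shapes (loading is covered in detail below; argument names
follow the BNPO example later in this README):

```python
import torch
from kernels import get_kernel

km = get_kernel("Geometric-AI/geometric-ai-kernels", version=0)
device = torch.device("cuda")

# Two different (bs, seq_len) shapes hit the same compiled kernel - no
# recompile between the calls.
for bs, seq_len in [(8, 512), (32, 2048)]:
    lp = torch.randn(bs, seq_len, dtype=torch.bfloat16, device=device)
    adv = torch.randn(bs, dtype=torch.bfloat16, device=device)
    mask = torch.ones(bs, seq_len, dtype=torch.int8, device=device)
    loss = km.bnpo_loss_fwd(
        lp, torch.randn_like(lp), torch.randn_like(lp), adv, mask,
        epsilon=0.2, epsilon_high=0.2, beta=0.1,
    )
```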


## Kernels


| | Kernel family | Direct (no autograd) | Autograd-aware | Forward-only | |
| | --- | --- | --- | --- | |
| | BNPO loss | `bnpo_loss` | `bnpo_loss_autograd` | `bnpo_loss_fwd` | |
| | GRPO loss | `grpo_loss` | `grpo_loss_autograd` | `grpo_loss_fwd` | |
| | Reverse KL | `reverse_kl` | `reverse_kl_autograd` | `reverse_kl_fwd` | |


### Entry points


Each kernel family exposes three entry points backed by the same underlying CuteDSL kernel:


- **`<name>(...)`** - fused fwd+bwd, returns `(loss, grad)` from one `@cute.jit`
dispatch. Lowest-overhead path; the caller chains the gradient into the upstream
model with `policy_logprobs.backward(grad)`. Use this in custom training loops
where you control gradient flow.
- **`<name>_autograd(...)`** - same kernel, registered via
`torch.library.custom_op` + `register_autograd`. `loss.backward()` works
and composes with `torch.compile(fullgraph=True)` (see the sketch after the
loading section below). There is noticeable per-call dispatcher overhead
vs. the direct path.
- **`<name>_fwd(...)`** - forward-only, returns scalar `loss` and skips
the gradient buffer entirely. Use for inference / validation /
reward-model scoring.


## Loading the kernels

```bash
pip install kernels apache-tvm-ffi nvidia-cutlass-dsl
```


```python
from kernels import get_kernel

km = get_kernel("Geometric-AI/geometric-ai-kernels", version=0)
```
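
Because the `_autograd` entry points are registered through
`torch.library.custom_op`, they can sit inside a compiled graph. A minimal
sketch (shapes and hyperparameters are illustrative, using the BNPO signature
from the examples below):

```python
import torch
from kernels import get_kernel

km = get_kernel("Geometric-AI/geometric-ai-kernels", version=0)
device = torch.device("cuda")

@torch.compile(fullgraph=True)
def bnpo_step(policy_lp, old_lp, ref_lp, adv, mask):
    # The custom_op registration lets the compiler treat the kernel as an
    # opaque node with a known backward, so fullgraph=True does not break.
    return km.bnpo_loss_autograd(policy_lp, old_lp, ref_lp, adv, mask,
                                 epsilon=0.2, epsilon_high=0.2, beta=0.1)

bs, seq_len = 8, 512
policy_lp = torch.randn(bs, seq_len, device=device, requires_grad=True)
old_lp, ref_lp = torch.randn_like(policy_lp), torch.randn_like(policy_lp)
adv = torch.randn(bs, device=device)
mask = torch.ones(bs, seq_len, dtype=torch.int8, device=device)

loss = bnpo_step(policy_lp, old_lp, ref_lp, adv, mask)
loss.backward()
```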


---


## BNPO Loss


**Batch-Normalized Policy Optimization** sums per-token policy and KL terms
across the **entire batch** and divides by the global valid-token count:


```
loss = ((per_token_loss + β·kl) · mask).sum() / max(mask.sum(), 1)
```


where `per_token_loss` is the PPO-clipped ratio loss:


```
ratio = exp(policy_logprobs - old_policy_logprobs)
clipped = clip(ratio, 1−ε, 1+ε_high)
per_token = −advantages · min(ratio, clipped)
kl = exp(ref_logprobs − policy_logprobs) − (ref_logprobs − policy_logprobs) − 1
```


The global denominator is computed entirely on-GPU via cross-CTA atomics -
no host-side `mask.sum()` sync. When `beta=0` the KL branch is eliminated
as dead code at compile time.
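
For intuition and for testing, the same math can be written directly in eager
PyTorch. A reference sketch of the formulas above (not the kernel's code path,
which fuses all of this into a single launch):

```python
import torch

def bnpo_loss_ref(policy_lp, old_lp, ref_lp, advantages, mask,
                  epsilon=0.2, epsilon_high=0.2, beta=0.1):
    """Eager reference for the BNPO loss with its global token normalizer."""
    ratio = torch.exp(policy_lp.float() - old_lp.float())
    clipped = torch.clamp(ratio, 1 - epsilon, 1 + epsilon_high)
    per_token = -advantages.float().unsqueeze(-1) * torch.minimum(ratio, clipped)
    delta = ref_lp.float() - policy_lp.float()
    kl = torch.exp(delta) - delta - 1
    m = mask.float()
    # Global denominator: one valid-token count for the whole batch.
    return ((per_token + beta * kl) * m).sum() / m.sum().clamp(min=1)
```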


**Inputs**:
- `policy_logprobs`, `old_policy_logprobs`, `ref_logprobs`: `(bs, seq_len)`, fp32/fp16/bf16
- `advantages`: `(bs,)`
- `completions_mask`: `(bs, seq_len)`, bool or int8


**Returns**: `(loss, grad_policy_logprobs)` from `bnpo_loss`; scalar `loss` from `bnpo_loss_fwd`.


```python
import torch
from kernels import get_kernel

km = get_kernel("Geometric-AI/geometric-ai-kernels", version=0)
device = torch.device("cuda")

bs, seq_len = 16, 1024
policy_logprobs = torch.randn(bs, seq_len, dtype=torch.bfloat16, device=device, requires_grad=True)
old_policy_logprobs = torch.randn(bs, seq_len, dtype=torch.bfloat16, device=device)
ref_logprobs = torch.randn(bs, seq_len, dtype=torch.bfloat16, device=device)
advantages = torch.randn(bs, dtype=torch.bfloat16, device=device)
completions_mask = (torch.rand(bs, seq_len, device=device) > 0.2).to(torch.int8)

# 1) Direct (loss, grad) - lowest overhead training path
loss, grad = km.bnpo_loss(
    policy_logprobs, old_policy_logprobs, ref_logprobs,
    advantages, completions_mask,
    epsilon=0.2, epsilon_high=0.2, beta=0.1,
)
policy_logprobs.backward(grad)

# 2) Autograd-aware - works with loss.backward() and torch.compile
loss = km.bnpo_loss_autograd(
    policy_logprobs.requires_grad_(),
    old_policy_logprobs, ref_logprobs,
    advantages, completions_mask,
    epsilon=0.2, epsilon_high=0.2, beta=0.1,
)
loss.backward()

# 3) Forward-only - inference / reward scoring, no gradient buffer
loss = km.bnpo_loss_fwd(
    policy_logprobs, old_policy_logprobs, ref_logprobs,
    advantages, completions_mask,
    epsilon=0.2, epsilon_high=0.2, beta=0.1,
)
```


---


## GRPO Loss


**Group Relative Policy Optimization** implements TRL's default
**per-response normalization** variant - each response is normalized by its
own valid-token count before averaging across the batch:


```
loss = mean_r( ((per_token_loss + β·kl) · mask).sum(-1) / max(mask.sum(-1), 1) )
```


`per_token_loss` and `kl` are the same clipped-ratio and KL expressions as in BNPO.
`completions_mask` is **required** because the per-response denominator is
mask-derived. The kernel uses one CTA per row, so the per-row mask sum is
reduced inside the block - no cross-CTA atomics on the scaling pass.
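
The only change from the BNPO reference sketch above is the normalizer. An
eager mirror of the formula, for testing:

```python
import torch

def grpo_loss_ref(policy_lp, old_lp, ref_lp, advantages, mask,
                  epsilon=0.2, epsilon_high=0.2, beta=0.1):
    """Eager reference for the per-response-normalized GRPO loss."""
    ratio = torch.exp(policy_lp.float() - old_lp.float())
    clipped = torch.clamp(ratio, 1 - epsilon, 1 + epsilon_high)
    per_token = -advantages.float().unsqueeze(-1) * torch.minimum(ratio, clipped)
    delta = ref_lp.float() - policy_lp.float()
    kl = torch.exp(delta) - delta - 1
    m = mask.float()
    # Per-response denominator, then a plain mean over responses.
    per_row = ((per_token + beta * kl) * m).sum(-1) / m.sum(-1).clamp(min=1)
    return per_row.mean()
```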


**Inputs**:
- `policy_logprobs`, `old_policy_logprobs`, `ref_logprobs`: `(bs, seq_len)`, fp32/fp16/bf16
- `advantages`: `(bs,)`
- `completions_mask`: `(bs, seq_len)`, bool or int8 - **required**


**Returns**: `(loss, grad_policy_logprobs)` from `grpo_loss`; scalar `loss` from `grpo_loss_fwd`.


```python
import torch
from kernels import get_kernel

km = get_kernel("Geometric-AI/geometric-ai-kernels", version=0)
device = torch.device("cuda")

bs, seq_len = 16, 1024
policy_logprobs = torch.randn(bs, seq_len, dtype=torch.bfloat16, device=device, requires_grad=True)
old_policy_logprobs = torch.randn(bs, seq_len, dtype=torch.bfloat16, device=device)
ref_logprobs = torch.randn(bs, seq_len, dtype=torch.bfloat16, device=device)
advantages = torch.randn(bs, dtype=torch.bfloat16, device=device)
completions_mask = (torch.rand(bs, seq_len, device=device) > 0.2).to(torch.int8)

# 1) Direct (loss, grad) - lowest overhead training path
loss, grad = km.grpo_loss(
    policy_logprobs, old_policy_logprobs, ref_logprobs,
    advantages, completions_mask,
    epsilon=0.2, epsilon_high=0.2, beta=0.1,
)
policy_logprobs.backward(grad)

# 2) Autograd-aware - works with loss.backward() and torch.compile
loss = km.grpo_loss_autograd(
    policy_logprobs.requires_grad_(),
    old_policy_logprobs, ref_logprobs,
    advantages, completions_mask,
    epsilon=0.2, epsilon_high=0.2, beta=0.1,
)
loss.backward()

# 3) Forward-only - inference / reward scoring, no gradient buffer
loss = km.grpo_loss_fwd(
    policy_logprobs, old_policy_logprobs, ref_logprobs,
    advantages, completions_mask,
    epsilon=0.2, epsilon_high=0.2, beta=0.1,
)
```


---


## Reverse KL


**Reverse-KL self-distillation** computes `KL(student ‖ teacher)` over a
`(num_tokens, vocab)` slab using an online normalization algorithm that reads
each logit row exactly once on the forward-only path:


```
p = softmax(student_logits)
q = softmax(teacher_logits)
kl_per_row = Σ_v p_v · (log p_v − log q_v)
loss = (mask · kl_per_row).sum() / mask.sum()
```


The gradient through the softmax Jacobian is analytical:


```
grad_student_v = scale · p_v · (log p_v − log q_v − kl_per_row)
```


where `scale = mask[r] · inv_n_valid`.
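
Both expressions are easy to mirror in eager PyTorch for testing. A reference
sketch (float32 accumulation; unlike the kernel's online algorithm, it
materializes full log-probability tensors):

```python
import torch
import torch.nn.functional as F

def reverse_kl_ref(student_logits, teacher_logits, mask):
    """Eager reference for the masked reverse-KL loss and its gradient."""
    log_p = F.log_softmax(student_logits.float(), dim=-1)
    log_q = F.log_softmax(teacher_logits.float(), dim=-1)
    kl_per_row = (log_p.exp() * (log_p - log_q)).sum(-1)  # KL(student ‖ teacher)
    m = mask.float()
    loss = (m * kl_per_row).sum() / m.sum()  # unclamped - see the warning below
    # Analytical gradient through the softmax Jacobian.
    scale = (m / m.sum()).unsqueeze(-1)
    grad = scale * log_p.exp() * (log_p - log_q - kl_per_row.unsqueeze(-1))
    return loss, grad
```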


**Inputs**:
- `student_logits`, `teacher_logits`: `(*, V)` - arbitrary leading dims (typically `(bs, seq_len, vocab)`); both must share shape and dtype
- `completions_mask`: shape matching `student_logits.shape[:-1]`


> ⚠️ **Fully-masked batches**: `inv_n_valid = 1 / mask.sum()` is not clamped, so a batch where every token is masked produces inf/NaN. Guard upstream if that case is reachable.
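
One possible upstream guard, as a sketch (the `.any()` check forces a
host-device sync, so only pay for it if fully-masked microbatches are actually
reachable; the zero-loss/zero-grad fallback is a hypothetical convention):

```python
import torch

def safe_reverse_kl(km, student_logits, teacher_logits, completions_mask):
    """Wrapper that returns a zero loss/grad when every token is masked."""
    if not bool(completions_mask.any()):
        return student_logits.new_zeros(()), torch.zeros_like(student_logits)
    return km.reverse_kl(student_logits, teacher_logits, completions_mask)
```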


**Returns**: `(loss, grad_student_logits)` from `reverse_kl`; scalar `loss` from `reverse_kl_fwd`.


```python
import torch
from kernels import get_kernel

km = get_kernel("Geometric-AI/geometric-ai-kernels", version=0)
device = torch.device("cuda")

# Qwen3.5-style vocab; arbitrary leading dims supported
bs, seq_len, vocab = 4, 256, 248320
student_logits = torch.randn(bs, seq_len, vocab, dtype=torch.bfloat16, device=device, requires_grad=True)
teacher_logits = torch.randn(bs, seq_len, vocab, dtype=torch.bfloat16, device=device)
completions_mask = (torch.rand(bs, seq_len, device=device) > 0.2)

# 1) Direct (loss, grad) - lowest overhead training path
loss, grad = km.reverse_kl(student_logits, teacher_logits, completions_mask)
student_logits.backward(grad)

# 2) Autograd-aware - works with loss.backward() and torch.compile
loss = km.reverse_kl_autograd(
    student_logits.requires_grad_(), teacher_logits, completions_mask
)
loss.backward()

# 3) Forward-only - inference / KL monitoring, no gradient buffer
loss = km.reverse_kl_fwd(student_logits, teacher_logits, completions_mask)
```


---


## Performance


All numbers are geometric-mean speedups measured on an H100 SXM (SM90a). Full
methodology and per-shape plots are in the
[release post](https://geometric.so/blog/2026/05/08/hf-kernel-hub).


### `kernels` CLI benchmark


Timed with `time.perf_counter` plus `torch.cuda.synchronize()`; mean over 100 iterations.
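
For reference, a sketch of that timing loop (not the CLI's exact code; the
warmup count is an assumption):

```python
import time
import torch

def bench(fn, *args, iters=100, warmup=10, **kwargs):
    """Mean wall-clock latency of fn(*args, **kwargs) with explicit syncs."""
    for _ in range(warmup):
        fn(*args, **kwargs)
    torch.cuda.synchronize()
    t0 = time.perf_counter()
    for _ in range(iters):
        fn(*args, **kwargs)
    torch.cuda.synchronize()
    return (time.perf_counter() - t0) / iters
```

For example, `bench(km.reverse_kl_fwd, student_logits, teacher_logits, completions_mask)` times the forward-only reverse-KL path.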


| Kernel | vs eager | vs `torch.compile` |
| --- | --- | --- |
| `grpo_loss_fwd` | 5.68× | 2.45× |
| `grpo_loss` | 20.79× | 1.98× |
| `bnpo_loss_fwd` | 5.29× | 2.52× |
| `bnpo_loss` | 16.81× | 2.27× |
| `reverse_kl_fwd` | 6.88× | 2.45× |
| `reverse_kl` | 7.03× | 2.61× |

---

## Benchmark animations

### BNPO Loss vs eager PyTorch

<picture>
  <source media="(prefers-color-scheme: dark)" srcset="benchmark_results/bnpo_loss_eager/bnpo_loss_eager_dark_animation.svg">
  <img width="90%" src="benchmark_results/bnpo_loss_eager/bnpo_loss_eager_light_animation.svg" alt="BNPO loss latency vs eager PyTorch">
</picture>

### BNPO Loss vs torch.compile

<picture>
  <source media="(prefers-color-scheme: dark)" srcset="benchmark_results/bnpo_loss_compiled/bnpo_loss_compiled_dark_animation.svg">
  <img width="90%" src="benchmark_results/bnpo_loss_compiled/bnpo_loss_compiled_light_animation.svg" alt="BNPO loss latency vs torch.compile">
</picture>

### GRPO Loss vs eager PyTorch

<picture>
  <source media="(prefers-color-scheme: dark)" srcset="benchmark_results/grpo_loss_eager/grpo_loss_eager_dark_animation.svg">
  <img width="90%" src="benchmark_results/grpo_loss_eager/grpo_loss_eager_light_animation.svg" alt="GRPO loss latency vs eager PyTorch">
</picture>

### GRPO Loss vs torch.compile

<picture>
  <source media="(prefers-color-scheme: dark)" srcset="benchmark_results/grpo_loss_compiled/grpo_loss_compiled_dark_animation.svg">
  <img width="90%" src="benchmark_results/grpo_loss_compiled/grpo_loss_compiled_light_animation.svg" alt="GRPO loss latency vs torch.compile">
</picture>

### Reverse KL vs eager PyTorch

<picture>
  <source media="(prefers-color-scheme: dark)" srcset="benchmark_results/reverse_kl_eager/reverse_kl_eager_dark_animation.svg">
  <img width="90%" src="benchmark_results/reverse_kl_eager/reverse_kl_eager_light_animation.svg" alt="Reverse KL latency vs eager PyTorch">
</picture>

### Reverse KL vs torch.compile

<picture>
  <source media="(prefers-color-scheme: dark)" srcset="benchmark_results/reverse_kl_compiled/reverse_kl_compiled_dark_animation.svg">
  <img width="90%" src="benchmark_results/reverse_kl_compiled/reverse_kl_compiled_light_animation.svg" alt="Reverse KL latency vs torch.compile">
</picture>