---
library_name: kernels
license: apache-2.0
tags:
- cuda
- cutlass
- cute-dsl
- rl
- distillation
- trl
- grpo
- bnpo
- kl-divergence
---

# Geometric-AI Kernels

Fused **CuteDSL** kernels for the loss functions that dominate post-training workloads: PPO-family policy losses (BNPO, GRPO) and reverse-KL self-distillation.

Each kernel ships a **single-launch fused forward + backward** path that returns `(loss, grad_logprobs)` directly. No `torch.autograd.Function` wrapper, no extra `grad_output * dpolicy` backward kernel, and no host-side syncs in the hot path.

Background and benchmarks: see the [release post](https://geometric.so/blog/2026/05/08/hf-kernel-hub).

- **Backend**: CUDA (NVIDIA CUTLASS DSL).
- **Min GPU**: SM80 (Ampere) - required by `nvidia-cutlass-dsl`. Tested on H100 (SM90). Should work on SM80 (Ampere), SM86 (RTX 3090, A40), SM89 (RTX 4090, L40S), SM90a (H100 SXM), and SM100 (Blackwell B200/GB200).
- **Min CUDA**: 12.8.
- **Dtypes**: `float32`, `float16`, `bfloat16`.
- **Dynamic shapes**: a single compile handles arbitrary batch size and sequence length - no recompiles when shapes change between calls (common in post-training rollouts).

## Kernels

| Kernel family | Direct (no autograd) | Autograd-aware | Forward-only |
| --- | --- | --- | --- |
| BNPO loss | `bnpo_loss` | `bnpo_loss_autograd` | `bnpo_loss_fwd` |
| GRPO loss | `grpo_loss` | `grpo_loss_autograd` | `grpo_loss_fwd` |
| Reverse KL | `reverse_kl` | `reverse_kl_autograd` | `reverse_kl_fwd` |

### Entry points

Each kernel family exposes three entry points backed by the same underlying CuteDSL kernel:

- **`(...)`** - fused fwd+bwd, returns `(loss, grad)` from one `@cute.jit` dispatch. Lowest-overhead path; the caller chains the gradient into the upstream model with `policy_logprobs.backward(grad)`. Use this in custom training loops where you control gradient flow.
- **`_autograd(...)`** - same kernel, registered via `torch.library.custom_op` + `register_autograd`. `loss.backward()` works and composes with `torch.compile(fullgraph=True)`. There is a noticeable per-call dispatcher overhead vs. the direct path.
- **`_fwd(...)`** - forward-only, returns the scalar `loss` and skips the gradient buffer entirely. Use it for inference, validation, and reward-model scoring.

## Loading the kernels

```
pip install apache-tvm-ffi nvidia-cutlass-dsl
```

```python
from kernels import get_kernel

km = get_kernel("Geometric-AI/geometric-ai-kernels", version=0)
```

---

## BNPO Loss

**Batch-Normalized Policy Optimization** sums per-token policy and KL terms across the **entire batch** and divides by the global valid-token count:

```
loss = ((per_token_loss + β·kl) · mask).sum() / max(mask.sum(), 1)
```

where `per_token_loss` is the PPO-clipped ratio loss:

```
ratio          = exp(policy_logprobs − old_policy_logprobs)
clipped        = clip(ratio, 1−ε, 1+ε_high)
per_token_loss = −advantages · min(ratio, clipped)
kl             = exp(ref_logprobs − policy_logprobs) − (ref_logprobs − policy_logprobs) − 1
```

The global denominator is computed entirely on-GPU via cross-CTA atomics - no host-side `mask.sum()` sync. When `beta=0` the KL branch is dead-coded at compile time.

**Inputs**:

- `policy_logprobs`, `old_policy_logprobs`, `ref_logprobs`: `(bs, seq_len)`, fp32/fp16/bf16
- `advantages`: `(bs,)`
- `completions_mask`: `(bs, seq_len)`, bool or int8

**Returns**: `(loss, grad_policy_logprobs)` from `bnpo_loss`; scalar `loss` from `bnpo_loss_fwd`.
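For sanity-checking an integration, the formulas above can be transcribed into eager PyTorch. A minimal sketch - `bnpo_loss_reference` is a hypothetical name, not part of this package - that should agree with `bnpo_loss_fwd` up to floating-point rounding:

```python
import torch

def bnpo_loss_reference(policy_lp, old_lp, ref_lp, advantages, mask,
                        epsilon=0.2, epsilon_high=0.2, beta=0.1):
    """Eager transcription of the BNPO formulas above - verification only."""
    ratio = torch.exp(policy_lp - old_lp)                          # (bs, seq_len)
    clipped = torch.clamp(ratio, 1 - epsilon, 1 + epsilon_high)
    per_token_loss = -advantages[:, None] * torch.minimum(ratio, clipped)
    delta = ref_lp - policy_lp
    kl = torch.exp(delta) - delta - 1                              # k3 KL estimator
    mask = mask.to(per_token_loss.dtype)
    n_valid = mask.sum().clamp(min=1)                              # global valid-token count
    return ((per_token_loss + beta * kl) * mask).sum() / n_valid
```

Usage of the kernel entry points: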
```python
import torch
from kernels import get_kernel

km = get_kernel("Geometric-AI/geometric-ai-kernels", version=0)
device = torch.device("cuda")

bs, seq_len = 16, 1024
policy_logprobs = torch.randn(bs, seq_len, dtype=torch.bfloat16, device=device, requires_grad=True)
old_policy_logprobs = torch.randn(bs, seq_len, dtype=torch.bfloat16, device=device)
ref_logprobs = torch.randn(bs, seq_len, dtype=torch.bfloat16, device=device)
advantages = torch.randn(bs, dtype=torch.bfloat16, device=device)
completions_mask = (torch.rand(bs, seq_len, device=device) > 0.2).to(torch.int8)

# 1) Direct (loss, grad) - lowest-overhead training path
loss, grad = km.bnpo_loss(
    policy_logprobs, old_policy_logprobs, ref_logprobs,
    advantages, completions_mask,
    epsilon=0.2, epsilon_high=0.2, beta=0.1,
)
policy_logprobs.backward(grad)

# 2) Autograd-aware - works with loss.backward() and torch.compile
loss = km.bnpo_loss_autograd(
    policy_logprobs.requires_grad_(), old_policy_logprobs, ref_logprobs,
    advantages, completions_mask,
    epsilon=0.2, epsilon_high=0.2, beta=0.1,
)
loss.backward()

# 3) Forward-only - inference / reward scoring, no gradient buffer
loss = km.bnpo_loss_fwd(
    policy_logprobs, old_policy_logprobs, ref_logprobs,
    advantages, completions_mask,
    epsilon=0.2, epsilon_high=0.2, beta=0.1,
)
```

---

## GRPO Loss

**Group Relative Policy Optimization** implements TRL's default **per-response normalization** variant - each response is normalized by its own valid-token count before averaging across the batch:

```
loss = mean_r( ((per_token_loss + β·kl) · mask).sum(-1) / max(mask.sum(-1), 1) )
```

`per_token_loss` and `kl` are the same clipped-ratio and KL expressions as in BNPO.

`completions_mask` is **required** because the per-response denominator is mask-derived. The kernel uses one CTA per row, so the per-row mask sum is reduced inside the block - no cross-CTA atomics on the scaling pass.

**Inputs**:

- `policy_logprobs`, `old_policy_logprobs`, `ref_logprobs`: `(bs, seq_len)`, fp32/fp16/bf16
- `advantages`: `(bs,)`
- `completions_mask`: `(bs, seq_len)`, bool or int8 - **required**

**Returns**: `(loss, grad_policy_logprobs)` from `grpo_loss`; scalar `loss` from `grpo_loss_fwd`.
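As with BNPO, an eager reference sketch can serve as a parity check; only the denominator differs from the BNPO sketch above. `grpo_loss_reference` is again a hypothetical name:

```python
import torch

def grpo_loss_reference(policy_lp, old_lp, ref_lp, advantages, mask,
                        epsilon=0.2, epsilon_high=0.2, beta=0.1):
    """Eager transcription of the GRPO formulas above - verification only."""
    # Per-token math is identical to the BNPO reference.
    ratio = torch.exp(policy_lp - old_lp)
    clipped = torch.clamp(ratio, 1 - epsilon, 1 + epsilon_high)
    per_token_loss = -advantages[:, None] * torch.minimum(ratio, clipped)
    delta = ref_lp - policy_lp
    kl = torch.exp(delta) - delta - 1
    mask = mask.to(per_token_loss.dtype)
    per_row = ((per_token_loss + beta * kl) * mask).sum(-1)        # (bs,)
    n_valid = mask.sum(-1).clamp(min=1)                            # per-response denominator
    return (per_row / n_valid).mean()
```

Usage of the kernel entry points: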
```python
import torch
from kernels import get_kernel

km = get_kernel("Geometric-AI/geometric-ai-kernels", version=0)
device = torch.device("cuda")

bs, seq_len = 16, 1024
policy_logprobs = torch.randn(bs, seq_len, dtype=torch.bfloat16, device=device, requires_grad=True)
old_policy_logprobs = torch.randn(bs, seq_len, dtype=torch.bfloat16, device=device)
ref_logprobs = torch.randn(bs, seq_len, dtype=torch.bfloat16, device=device)
advantages = torch.randn(bs, dtype=torch.bfloat16, device=device)
completions_mask = (torch.rand(bs, seq_len, device=device) > 0.2).to(torch.int8)

# 1) Direct (loss, grad) - lowest-overhead training path
loss, grad = km.grpo_loss(
    policy_logprobs, old_policy_logprobs, ref_logprobs,
    advantages, completions_mask,
    epsilon=0.2, epsilon_high=0.2, beta=0.1,
)
policy_logprobs.backward(grad)

# 2) Autograd-aware - works with loss.backward() and torch.compile
loss = km.grpo_loss_autograd(
    policy_logprobs.requires_grad_(), old_policy_logprobs, ref_logprobs,
    advantages, completions_mask,
    epsilon=0.2, epsilon_high=0.2, beta=0.1,
)
loss.backward()

# 3) Forward-only - inference / reward scoring, no gradient buffer
loss = km.grpo_loss_fwd(
    policy_logprobs, old_policy_logprobs, ref_logprobs,
    advantages, completions_mask,
    epsilon=0.2, epsilon_high=0.2, beta=0.1,
)
```

---

## Reverse KL

**Reverse-KL self-distillation** computes `KL(student ‖ teacher)` over a `(num_tokens, vocab)` slab using an online normalization algorithm that reads each logit row exactly once on the forward-only path:

```
p          = softmax(student_logits)
q          = softmax(teacher_logits)
kl_per_row = Σ_v p_v · (log p_v − log q_v)
loss       = (mask · kl_per_row).sum() / mask.sum()
```

The gradient through the softmax Jacobian is analytical:

```
grad_student_v = scale · p_v · (log p_v − log q_v − kl_per_row)
```

where `scale = mask[r] · inv_n_valid`.

**Inputs**:

- `student_logits`, `teacher_logits`: `(*, V)` - arbitrary leading dims (typically `(bs, seq_len, vocab)`); both must share shape and dtype
- `completions_mask`: shape matching `student_logits.shape[:-1]`

> ⚠️ **Fully-masked batches**: `inv_n_valid = 1 / mask.sum()` is not clamped, so a batch where every token is masked produces inf/NaN. Guard upstream if that case is reachable.

**Returns**: `(loss, grad_student_logits)` from `reverse_kl`; scalar `loss` from `reverse_kl_fwd`.

```python
import torch
from kernels import get_kernel

km = get_kernel("Geometric-AI/geometric-ai-kernels", version=0)
device = torch.device("cuda")

# Qwen3.5-style vocab; arbitrary leading dims supported
bs, seq_len, vocab = 4, 256, 248320
student_logits = torch.randn(bs, seq_len, vocab, dtype=torch.bfloat16, device=device, requires_grad=True)
teacher_logits = torch.randn(bs, seq_len, vocab, dtype=torch.bfloat16, device=device)
completions_mask = (torch.rand(bs, seq_len, device=device) > 0.2)

# 1) Direct (loss, grad) - lowest-overhead training path
loss, grad = km.reverse_kl(student_logits, teacher_logits, completions_mask)
student_logits.backward(grad)

# 2) Autograd-aware - works with loss.backward() and torch.compile
loss = km.reverse_kl_autograd(
    student_logits.requires_grad_(), teacher_logits, completions_mask
)
loss.backward()

# 3) Forward-only - inference / KL monitoring, no gradient buffer
loss = km.reverse_kl_fwd(student_logits, teacher_logits, completions_mask)
```
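If fully-masked batches are reachable, the warning above suggests guarding upstream. A minimal sketch, continuing the example; the zero-loss fallback is a choice made here, not kernel behavior, and `mask.any()` forces a host sync, so keep it off the hot path where possible:

```python
# Guard fully-masked batches before calling the kernel (see warning above).
if completions_mask.any():  # note: .any() on a CUDA tensor syncs with the host
    loss, grad = km.reverse_kl(student_logits, teacher_logits, completions_mask)
else:
    # No valid tokens: define the loss as zero and skip the kernel entirely.
    loss = student_logits.new_zeros(())
    grad = torch.zeros_like(student_logits)
```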
---

## Performance

All numbers are geometric-mean speedups across shapes, measured on an H100 SXM (SM90a). Full methodology and per-shape plots are in the [release post](https://geometric.so/blog/2026/05/08/hf-kernel-hub).

### `kernels` CLI benchmark

Timed with `time.perf_counter` + `cuda.synchronize()`, mean over 100 iterations.

| Kernel | vs eager | vs `torch.compile` |
| --- | --- | --- |
| `grpo_loss_fwd` | 5.68× | 2.45× |
| `grpo_loss` | 20.79× | 1.98× |
| `bnpo_loss_fwd` | 5.29× | 2.52× |
| `bnpo_loss` | 16.81× | 2.27× |
| `reverse_kl_fwd` | 6.88× | 2.45× |
| `reverse_kl` | 7.03× | 2.61× |

---

## Benchmark animations

### BNPO Loss vs eager PyTorch

*(animation: BNPO loss latency vs eager PyTorch)*

### BNPO Loss vs torch.compile

*(animation: BNPO loss latency vs torch.compile)*

### GRPO Loss vs eager PyTorch

*(animation: GRPO loss latency vs eager PyTorch)*

### GRPO Loss vs torch.compile

*(animation: GRPO loss latency vs torch.compile)*

### Reverse KL vs eager PyTorch

*(animation: Reverse KL latency vs eager PyTorch)*

### Reverse KL vs torch.compile

*(animation: Reverse KL latency vs torch.compile)*
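---

## Reproducing the timings

A minimal harness matching the methodology stated above - `time.perf_counter` around explicit `torch.cuda.synchronize()` fences, mean over 100 iterations. This is a sketch, not the actual `kernels` CLI harness; `mean_latency` and the warmup count are assumptions:

```python
import time

import torch
from kernels import get_kernel

km = get_kernel("Geometric-AI/geometric-ai-kernels", version=0)

def mean_latency(fn, *args, iters=100, warmup=10, **kwargs):
    """Mean wall-clock latency over `iters` calls, bracketed by device syncs."""
    for _ in range(warmup):  # warmup count is an assumption of this sketch
        fn(*args, **kwargs)
    torch.cuda.synchronize()
    start = time.perf_counter()
    for _ in range(iters):
        fn(*args, **kwargs)
    torch.cuda.synchronize()
    return (time.perf_counter() - start) / iters
```

For example, `mean_latency(km.reverse_kl_fwd, student_logits, teacher_logits, completions_mask)` with the tensors from the Reverse KL example above.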