Uploaded using `kernel-builder`.
README.md
CHANGED
|
@@ -1,3 +1,338 @@
|
|
| 1 |
-
---
|
| 2 |
-
|
| 3 |
-
-
| 1 |
+
---
|
| 2 |
+
library_name: kernels
|
| 3 |
+
license: apache-2.0
|
| 4 |
+
tags:
|
| 5 |
+
- cuda
|
| 6 |
+
- cutlass
|
| 7 |
+
- cute-dsl
|
| 8 |
+
- rl
|
| 9 |
+
- distillation
|
| 10 |
+
- trl
|
| 11 |
+
- grpo
|
| 12 |
+
- bnpo
|
| 13 |
+
- kl-divergence
|
| 14 |
+
---
|
| 15 |
+
|
| 16 |
+
# Geometric-AI Kernels
|
| 17 |
+
|
| 18 |
+
Fused **CuteDSL** kernels for the loss functions that dominate post-training
|
| 19 |
+
workloads: PPO-family policy losses (BNPO, GRPO) and reverse-KL
|
| 20 |
+
self-distillation. Each kernel ships a **single-launch fused forward +
|
| 21 |
+
backward** path that returns `(loss, grad_logprobs)` directly — no
|
| 22 |
+
`torch.autograd.Function` wrapper, no extra `grad_output * dpolicy` backward
|
| 23 |
+
kernel, and no host-side syncs in the hot path.
|
| 24 |
+
|
| 25 |
+
Background and benchmarks: see the
|
| 26 |
+
[release post](https://geometric.so/blog/2026/05/08/hf-kernel-hub).
|
| 27 |
+
|
| 28 |
+
- **Backend**: CUDA (NVIDIA CUTLASS DSL).
|
| 29 |
+
- **Min GPU**: SM80 (Ampere) — required by `nvidia-cutlass-dsl`. Tested on A100 (SM80) and H100 (SM90). Works on SM86 (RTX 3090, A40), SM89 (RTX 4090, L40S), SM90a (H100 SXM), and SM100 (Blackwell B200/GB200).
|
| 30 |
+
- **Min CUDA**: 12.8.
|
| 31 |
+
- **Dtypes**: `float32`, `float16`, `bfloat16`.
|
| 32 |
+
- **Dynamic shapes**: a single compile handles arbitrary batch size and
|
| 33 |
+
sequence length — no recompiles when shapes change between calls (common
|
| 34 |
+
in post-training rollouts).
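
The minimum-architecture requirement above can be checked before loading the kernels. The following is a minimal sketch, not part of the library: `meets_min_capability` and `current_gpu_supported` are hypothetical helper names, and the only PyTorch calls used are `torch.cuda.is_available()` and `torch.cuda.get_device_capability()`.

```python
MIN_CAPABILITY = (8, 0)  # SM80 (Ampere), the minimum stated in the list above


def meets_min_capability(capability, minimum=MIN_CAPABILITY):
    """Return True if a (major, minor) compute capability is at least `minimum`."""
    return tuple(capability) >= tuple(minimum)


def current_gpu_supported():
    """Check the active CUDA device; returns False when no GPU is visible."""
    import torch  # imported lazily so the pure helper above works without torch

    if not torch.cuda.is_available():
        return False
    return meets_min_capability(torch.cuda.get_device_capability())
```

Tuple comparison handles the minor version, so `(8, 6)` and `(9, 0)` pass while `(7, 5)` does not.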
|
| 35 |
+
|
| 36 |
+
## Kernels
|
| 37 |
+
|
| 38 |
+
| Kernel family | Direct (no autograd) | Autograd-aware | Forward-only |
|
| 39 |
+
| --- | --- | --- | --- |
|
| 40 |
+
| BNPO loss | `bnpo_loss` | `bnpo_loss_autograd` | `bnpo_loss_fwd` |
|
| 41 |
+
| GRPO loss | `grpo_loss` | `grpo_loss_autograd` | `grpo_loss_fwd` |
|
| 42 |
+
| Reverse KL | `reverse_kl` | `reverse_kl_autograd` | `reverse_kl_fwd` |
|
| 43 |
+
|
| 44 |
+
### Entry points
|
| 45 |
+
|
| 46 |
+
Each kernel family exposes three entry points backed by the same underlying CuteDSL kernel:
|
| 47 |
+
|
| 48 |
+
- **`<name>(...)`** — fused fwd+bwd, returns `(loss, grad)` from one `@cute.jit`
|
| 49 |
+
dispatch. Lowest-overhead path; the caller chains the gradient into the upstream
|
| 50 |
+
model with `policy_logprobs.backward(grad)`. Use this in custom training loops
|
| 51 |
+
where you control gradient flow.
|
| 52 |
+
- **`<name>_autograd(...)`** — same kernel, registered via
|
| 53 |
+
`torch.library.custom_op` + `register_autograd`. `loss.backward()` works
|
| 54 |
+
and composes with `torch.compile(fullgraph=True)`. The custom-op dispatcher adds
|
| 55 |
+
per-call overhead relative to the direct path.
|
| 56 |
+
- **`<name>_fwd(...)`** — forward-only, returns scalar `loss` and skips
|
| 57 |
+
the gradient buffer entirely. Use for inference / validation /
|
| 58 |
+
reward-model scoring.
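
To illustrate how the direct path chains a returned gradient into the upstream model, here is a small CPU-only sketch. `toy_loss_and_grad` is a hypothetical stand-in for a fused kernel (its loss is just a mean of squares, not one of the losses in this repository). The point is only that `logprobs.backward(grad)` yields the same parameter gradients as differentiating the loss with autograd.

```python
import torch


def toy_loss_and_grad(logprobs):
    """Hypothetical stand-in for a fused kernel: returns (loss, dloss/dlogprobs).

    The loss is mean(logprobs ** 2); it is not one of the library's losses.
    """
    detached = logprobs.detach()
    loss = (detached ** 2).mean()
    grad = 2.0 * detached / detached.numel()
    return loss, grad


torch.manual_seed(0)
weights = torch.randn(4, 8, requires_grad=True)  # stands in for model parameters
inputs = torch.randn(4, 8)

# Direct path: the "kernel" hands back the gradient and we push it upstream.
logprobs = weights * inputs  # non-leaf tensor with a grad_fn
loss, grad = toy_loss_and_grad(logprobs)
logprobs.backward(grad)
direct_grad = weights.grad.clone()

# Autograd path for comparison: let PyTorch differentiate the same loss.
weights.grad = None
((weights * inputs) ** 2).mean().backward()
autograd_grad = weights.grad.clone()
```

Both paths leave the same values in `weights.grad`; the direct path simply skips building an autograd node for the loss itself.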
|
| 59 |
+
|
| 60 |
+
## Loading the kernels
|
| 61 |
+
|
| 62 |
+
```python
|
| 63 |
+
from kernels import get_kernel
|
| 64 |
+
|
| 65 |
+
km = get_kernel("Geometric-AI/geometric-ai-kernels", version=0)
|
| 66 |
+
```
|
| 67 |
+
|
| 68 |
+
---
|
| 69 |
+
|
| 70 |
+
## BNPO Loss
|
| 71 |
+
|
| 72 |
+
**Batch-Normalized Policy Optimization** sums per-token policy and KL terms
|
| 73 |
+
across the **entire batch** and divides by the global valid-token count:
|
| 74 |
+
|
| 75 |
+
```
|
| 76 |
+
loss = ((per_token_loss + β·kl) · mask).sum() / max(mask.sum(), 1)
|
| 77 |
+
```
|
| 78 |
+
|
| 79 |
+
where `per_token_loss` is the PPO-clipped ratio loss:
|
| 80 |
+
|
| 81 |
+
```
|
| 82 |
+
ratio = exp(policy_logprobs - old_policy_logprobs)
|
| 83 |
+
clipped = clip(ratio, 1−ε, 1+ε_high)
|
| 84 |
+
per_token_loss = −min(ratio · advantages, clipped · advantages)
|
| 85 |
+
kl = exp(ref_logprobs − policy_logprobs) − (ref_logprobs − policy_logprobs) − 1
|
| 86 |
+
```
|
| 87 |
+
|
| 88 |
+
The global denominator is computed entirely on-GPU via cross-CTA atomics —
|
| 89 |
+
no host-side `mask.sum()` sync. When `beta=0` the KL branch is dead-coded
|
| 90 |
+
at compile time.
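
For checking numerics, an unfused PyTorch reference of the BNPO loss can be sketched as follows. This is an assumption-laden illustration rather than the kernel's implementation: `bnpo_loss_reference` is a hypothetical name, the surrogate is written in the standard PPO form `min(ratio * A, clipped * A)`, and all math is done in fp32.

```python
import torch


def bnpo_loss_reference(policy_logprobs, old_policy_logprobs, ref_logprobs,
                        advantages, completions_mask,
                        epsilon=0.2, epsilon_high=0.2, beta=0.0):
    """Unfused PyTorch reference for the BNPO loss (illustrative, not the kernel)."""
    policy = policy_logprobs.float()
    mask = completions_mask.float()
    adv = advantages.float().unsqueeze(-1)  # (bs, 1), broadcast over tokens

    ratio = torch.exp(policy - old_policy_logprobs.float())
    clipped = torch.clamp(ratio, 1.0 - epsilon, 1.0 + epsilon_high)
    per_token_loss = -torch.min(ratio * adv, clipped * adv)

    if beta != 0.0:
        delta = ref_logprobs.float() - policy
        kl = torch.exp(delta) - delta - 1.0
        per_token_loss = per_token_loss + beta * kl

    # One global denominator: the number of valid tokens in the whole batch.
    return (per_token_loss * mask).sum() / mask.sum().clamp(min=1.0)
```

A reference like this is useful for `torch.testing.assert_close` comparisons against the fused output on small shapes.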
|
| 91 |
+
|
| 92 |
+
**Inputs**:
|
| 93 |
+
- `policy_logprobs`, `old_policy_logprobs`, `ref_logprobs`: `(bs, seq_len)`, fp32/fp16/bf16
|
| 94 |
+
- `advantages`: `(bs,)`
|
| 95 |
+
- `completions_mask`: `(bs, seq_len)`, bool or int8
|
| 96 |
+
|
| 97 |
+
**Returns**: `(loss, grad_policy_logprobs)` from `bnpo_loss`; scalar `loss` from `bnpo_loss_fwd`.
|
| 98 |
+
|
| 99 |
+
```python
|
| 100 |
+
import torch
|
| 101 |
+
from kernels import get_kernel
|
| 102 |
+
|
| 103 |
+
km = get_kernel("Geometric-AI/geometric-ai-kernels", version=0)
|
| 104 |
+
device = torch.device("cuda")
|
| 105 |
+
|
| 106 |
+
bs, seq_len = 16, 1024
|
| 107 |
+
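# In training this tensor comes from the model and already has a grad_fn;
# requires_grad=True stands in for that so .backward(grad) works on random data.
policy_logprobs = torch.randn(bs, seq_len, dtype=torch.bfloat16, device=device, requires_grad=True)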
|
| 108 |
+
old_policy_logprobs = torch.randn(bs, seq_len, dtype=torch.bfloat16, device=device)
|
| 109 |
+
ref_logprobs = torch.randn(bs, seq_len, dtype=torch.bfloat16, device=device)
|
| 110 |
+
advantages = torch.randn(bs, dtype=torch.bfloat16, device=device)
|
| 111 |
+
completions_mask = (torch.rand(bs, seq_len, device=device) > 0.2).to(torch.int8)
|
| 112 |
+
|
| 113 |
+
# 1) Direct (loss, grad) — lowest overhead training path
|
| 114 |
+
loss, grad = km.bnpo_loss(
|
| 115 |
+
policy_logprobs, old_policy_logprobs, ref_logprobs,
|
| 116 |
+
advantages, completions_mask,
|
| 117 |
+
epsilon=0.2, epsilon_high=0.2, beta=0.1,
|
| 118 |
+
)
|
| 119 |
+
policy_logprobs.backward(grad)
|
| 120 |
+
|
| 121 |
+
# 2) Autograd-aware — works with loss.backward() and torch.compile
|
| 122 |
+
loss = km.bnpo_loss_autograd(
|
| 123 |
+
policy_logprobs.requires_grad_(),
|
| 124 |
+
old_policy_logprobs, ref_logprobs,
|
| 125 |
+
advantages, completions_mask,
|
| 126 |
+
epsilon=0.2, epsilon_high=0.2, beta=0.1,
|
| 127 |
+
)
|
| 128 |
+
loss.backward()
|
| 129 |
+
|
| 130 |
+
# 3) Forward-only — inference / reward scoring, no gradient buffer
|
| 131 |
+
loss = km.bnpo_loss_fwd(
|
| 132 |
+
policy_logprobs, old_policy_logprobs, ref_logprobs,
|
| 133 |
+
advantages, completions_mask,
|
| 134 |
+
epsilon=0.2, epsilon_high=0.2, beta=0.1,
|
| 135 |
+
)
|
| 136 |
+
```
|
| 137 |
+
|
| 138 |
+
---
|
| 139 |
+
|
| 140 |
+
## GRPO Loss
|
| 141 |
+
|
| 142 |
+
**Group Relative Policy Optimization** implements TRL's default
|
| 143 |
+
**per-response normalization** variant — each response is normalized by its
|
| 144 |
+
own valid-token count before averaging across the batch:
|
| 145 |
+
|
| 146 |
+
```
|
| 147 |
+
loss = mean_r( ((per_token_loss + β·kl) · mask).sum(-1) / max(mask.sum(-1), 1) )
|
| 148 |
+
```
|
| 149 |
+
|
| 150 |
+
`per_token_loss` and `kl` are the same clipped-ratio and KL expressions as BNPO.
|
| 151 |
+
`completions_mask` is **required** because the per-response denominator is
|
| 152 |
+
mask-derived. The kernel uses one CTA per row so the per-row mask sum is
|
| 153 |
+
reduced inside the block — no cross-CTA atomics on the scaling pass.
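
The per-response normalization can be sketched in plain PyTorch as below. As with any unfused reference, this is an illustration under assumptions: `grpo_loss_reference` is a hypothetical name, the surrogate uses the standard PPO form `min(ratio * A, clipped * A)`, and computation is in fp32.

```python
import torch


def grpo_loss_reference(policy_logprobs, old_policy_logprobs, ref_logprobs,
                        advantages, completions_mask,
                        epsilon=0.2, epsilon_high=0.2, beta=0.0):
    """Unfused PyTorch reference for the GRPO loss (illustrative, not the kernel)."""
    policy = policy_logprobs.float()
    mask = completions_mask.float()
    adv = advantages.float().unsqueeze(-1)  # (bs, 1), broadcast over tokens

    ratio = torch.exp(policy - old_policy_logprobs.float())
    clipped = torch.clamp(ratio, 1.0 - epsilon, 1.0 + epsilon_high)
    per_token_loss = -torch.min(ratio * adv, clipped * adv)

    if beta != 0.0:
        delta = ref_logprobs.float() - policy
        kl = torch.exp(delta) - delta - 1.0
        per_token_loss = per_token_loss + beta * kl

    # Each response is divided by its own valid-token count, then rows are averaged.
    per_row = (per_token_loss * mask).sum(-1) / mask.sum(-1).clamp(min=1.0)
    return per_row.mean()
```

With two responses that have 2 and 1 valid tokens and advantages 1 and 3 (ratio equal to 1), this gives `(-1 + -3) / 2 = -2`, whereas a single global token count would give `-5 / 3`; short responses carry more weight per token here.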
|
| 154 |
+
|
| 155 |
+
**Inputs**:
|
| 156 |
+
- `policy_logprobs`, `old_policy_logprobs`, `ref_logprobs`: `(bs, seq_len)`, fp32/fp16/bf16
|
| 157 |
+
- `advantages`: `(bs,)`
|
| 158 |
+
- `completions_mask`: `(bs, seq_len)`, bool or int8 — **required**
|
| 159 |
+
|
| 160 |
+
**Returns**: `(loss, grad_policy_logprobs)` from `grpo_loss`; scalar `loss` from `grpo_loss_fwd`.
|
| 161 |
+
|
| 162 |
+
```python
|
| 163 |
+
import torch
|
| 164 |
+
from kernels import get_kernel
|
| 165 |
+
|
| 166 |
+
km = get_kernel("Geometric-AI/geometric-ai-kernels", version=0)
|
| 167 |
+
device = torch.device("cuda")
|
| 168 |
+
|
| 169 |
+
bs, seq_len = 16, 1024
|
| 170 |
+
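# In training this tensor comes from the model and already has a grad_fn;
# requires_grad=True stands in for that so .backward(grad) works on random data.
policy_logprobs = torch.randn(bs, seq_len, dtype=torch.bfloat16, device=device, requires_grad=True)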
|
| 171 |
+
old_policy_logprobs = torch.randn(bs, seq_len, dtype=torch.bfloat16, device=device)
|
| 172 |
+
ref_logprobs = torch.randn(bs, seq_len, dtype=torch.bfloat16, device=device)
|
| 173 |
+
advantages = torch.randn(bs, dtype=torch.bfloat16, device=device)
|
| 174 |
+
completions_mask = (torch.rand(bs, seq_len, device=device) > 0.2).to(torch.int8)
|
| 175 |
+
|
| 176 |
+
# 1) Direct (loss, grad) — lowest overhead training path
|
| 177 |
+
loss, grad = km.grpo_loss(
|
| 178 |
+
policy_logprobs, old_policy_logprobs, ref_logprobs,
|
| 179 |
+
advantages, completions_mask,
|
| 180 |
+
epsilon=0.2, epsilon_high=0.2, beta=0.1,
|
| 181 |
+
)
|
| 182 |
+
policy_logprobs.backward(grad)
|
| 183 |
+
|
| 184 |
+
# 2) Autograd-aware — works with loss.backward() and torch.compile
|
| 185 |
+
loss = km.grpo_loss_autograd(
|
| 186 |
+
policy_logprobs.requires_grad_(),
|
| 187 |
+
old_policy_logprobs, ref_logprobs,
|
| 188 |
+
advantages, completions_mask,
|
| 189 |
+
epsilon=0.2, epsilon_high=0.2, beta=0.1,
|
| 190 |
+
)
|
| 191 |
+
loss.backward()
|
| 192 |
+
|
| 193 |
+
# 3) Forward-only — inference / reward scoring, no gradient buffer
|
| 194 |
+
loss = km.grpo_loss_fwd(
|
| 195 |
+
policy_logprobs, old_policy_logprobs, ref_logprobs,
|
| 196 |
+
advantages, completions_mask,
|
| 197 |
+
epsilon=0.2, epsilon_high=0.2, beta=0.1,
|
| 198 |
+
)
|
| 199 |
+
```
|
| 200 |
+
|
| 201 |
+
---
|
| 202 |
+
|
| 203 |
+
## Reverse KL
|
| 204 |
+
|
| 205 |
+
**Reverse-KL self-distillation** computes `KL(student ‖ teacher)` over a
|
| 206 |
+
`(num_tokens, vocab)` slab using an online normalization algorithm that reads
|
| 207 |
+
each logit row exactly once on the forward-only path:
|
| 208 |
+
|
| 209 |
+
```
|
| 210 |
+
p = softmax(student_logits)
|
| 211 |
+
q = softmax(teacher_logits)
|
| 212 |
+
kl_per_row = Σ_v p_v · (log p_v − log q_v)
|
| 213 |
+
loss = (mask · kl_per_row).sum() / mask.sum()
|
| 214 |
+
```
|
| 215 |
+
|
| 216 |
+
The gradient through the softmax Jacobian is analytical:
|
| 217 |
+
|
| 218 |
+
```
|
| 219 |
+
grad_student_v = scale · p_v · (log p_v − log q_v − kl_per_row)
|
| 220 |
+
```
|
| 221 |
+
|
| 222 |
+
where `scale = mask[r] · inv_n_valid`.
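
The closed-form gradient can be checked against autograd on a small problem. The sketch below is an unfused reference under stated assumptions (`reverse_kl_reference` is a hypothetical name; math is done in fp32; the denominator is left unclamped); it is not the kernel's online algorithm.

```python
import torch
import torch.nn.functional as F


def reverse_kl_reference(student_logits, teacher_logits, completions_mask):
    """Unfused reference: returns (loss, closed-form grad w.r.t. student_logits)."""
    log_p = F.log_softmax(student_logits.float(), dim=-1)
    log_q = F.log_softmax(teacher_logits.float(), dim=-1)
    p = log_p.exp()
    mask = completions_mask.float()

    kl_per_row = (p * (log_p - log_q)).sum(-1)  # one value per token position
    inv_n_valid = 1.0 / mask.sum()              # unclamped denominator
    loss = (mask * kl_per_row).sum() * inv_n_valid

    scale = (mask * inv_n_valid).unsqueeze(-1)
    grad = scale * p * (log_p - log_q - kl_per_row.unsqueeze(-1))
    return loss, grad


torch.manual_seed(0)
student = torch.randn(2, 5, 11, requires_grad=True)
teacher = torch.randn(2, 5, 11)
mask = torch.rand(2, 5) > 0.3
mask[0, 0] = True  # make sure at least one token is valid

loss, grad = reverse_kl_reference(student, teacher, mask)
(autograd_grad,) = torch.autograd.grad(loss, student)
```

`grad.detach()` and `autograd_grad` agree to floating-point tolerance, and rows where the mask is zero receive an all-zero gradient.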
|
| 223 |
+
|
| 224 |
+
**Inputs**:
|
| 225 |
+
- `student_logits`, `teacher_logits`: `(*, V)` — arbitrary leading dims (typically `(bs, seq_len, vocab)`); both must share shape and dtype
|
| 226 |
+
- `completions_mask`: shape matching `student_logits.shape[:-1]`
|
| 227 |
+
|
| 228 |
+
> ⚠️ **Fully-masked batches**: `inv_n_valid = 1 / mask.sum()` is not clamped, so a batch where every token is masked produces inf/NaN. Guard upstream if that case is reachable.
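
One possible upstream guard is sketched below. `has_valid_tokens` is a hypothetical helper, not part of the library, and where to place the check depends on the training loop.

```python
import torch


def has_valid_tokens(completions_mask):
    """Return True if at least one position is unmasked.

    bool() on a CUDA tensor forces a device-to-host sync, so in practice this
    check fits best where the mask is built (for example in the data collator).
    """
    return bool(completions_mask.any())


all_masked = torch.zeros(2, 4, dtype=torch.bool)
some_valid = all_masked.clone()
some_valid[1, 2] = True

# What an unclamped division by mask.sum() does on a fully masked batch: 0 / 0 is NaN.
unguarded = (all_masked.float() * 0.0).sum() / all_masked.float().sum()
```

Skipping the step (or substituting a zero loss) when `has_valid_tokens` returns `False` keeps the NaN out of the optimizer state.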
|
| 229 |
+
|
| 230 |
+
**Returns**: `(loss, grad_student_logits)` from `reverse_kl`; scalar `loss` from `reverse_kl_fwd`.
|
| 231 |
+
|
| 232 |
+
```python
|
| 233 |
+
import torch
|
| 234 |
+
from kernels import get_kernel
|
| 235 |
+
|
| 236 |
+
km = get_kernel("Geometric-AI/geometric-ai-kernels", version=0)
|
| 237 |
+
device = torch.device("cuda")
|
| 238 |
+
|
| 239 |
+
# Qwen3.5-style vocab; arbitrary leading dims supported
|
| 240 |
+
bs, seq_len, vocab = 4, 256, 248320
|
| 241 |
+
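# In training this tensor comes from the model and already has a grad_fn;
# requires_grad=True stands in for that so .backward(grad) works on random data.
student_logits = torch.randn(bs, seq_len, vocab, dtype=torch.bfloat16, device=device, requires_grad=True)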
|
| 242 |
+
teacher_logits = torch.randn(bs, seq_len, vocab, dtype=torch.bfloat16, device=device)
|
| 243 |
+
completions_mask = (torch.rand(bs, seq_len, device=device) > 0.2)
|
| 244 |
+
|
| 245 |
+
# 1) Direct (loss, grad) — lowest overhead training path
|
| 246 |
+
loss, grad = km.reverse_kl(student_logits, teacher_logits, completions_mask)
|
| 247 |
+
student_logits.backward(grad)
|
| 248 |
+
|
| 249 |
+
# 2) Autograd-aware — works with loss.backward() and torch.compile
|
| 250 |
+
loss = km.reverse_kl_autograd(
|
| 251 |
+
student_logits.requires_grad_(), teacher_logits, completions_mask
|
| 252 |
+
)
|
| 253 |
+
loss.backward()
|
| 254 |
+
|
| 255 |
+
# 3) Forward-only — inference / KL monitoring, no gradient buffer
|
| 256 |
+
loss = km.reverse_kl_fwd(student_logits, teacher_logits, completions_mask)
|
| 257 |
+
```
|
| 258 |
+
|
| 259 |
+
---
|
| 260 |
+
|
| 261 |
+
## Performance
|
| 262 |
+
|
| 263 |
+
Numbers below are geometric-mean speedups from our in-house benchmark
|
| 264 |
+
(`triton.testing.do_bench`, fresh subprocess per shape). Baselines are
|
| 265 |
+
**eager PyTorch** and **`torch.compile(mode="max-autotune-no-cudagraphs",
|
| 266 |
+
fullgraph=True)`** with `torch._dynamo.config.trace_autograd_ops = True` so
|
| 267 |
+
the compiled baseline is a real Inductor-fused fwd+bwd graph.
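
For readers unfamiliar with the aggregation, a geometric-mean speedup over several shapes can be computed as in this sketch. The latencies are made-up placeholders, not measurements from this benchmark, and `geomean_speedup` is a hypothetical helper name.

```python
import math


def geomean_speedup(baseline_ms, kernel_ms):
    """Geometric mean of per-shape speedups (baseline latency / kernel latency)."""
    if len(baseline_ms) != len(kernel_ms) or not baseline_ms:
        raise ValueError("need two equally sized, non-empty latency lists")
    log_sum = sum(math.log(b / k) for b, k in zip(baseline_ms, kernel_ms))
    return math.exp(log_sum / len(baseline_ms))


# Made-up latencies in milliseconds, one entry per benchmarked shape.
baseline = [0.10, 0.40, 1.60]
kernel = [0.05, 0.40, 0.80]
speedup = geomean_speedup(baseline, kernel)
```

Here the per-shape speedups are 2, 1 and 2, so the result is the cube root of 4 (about 1.59). Unlike an arithmetic mean, the geometric mean is not dominated by a single shape with an unusually large ratio.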
|
| 268 |
+
|
| 269 |
+
| Kernel | β | vs eager | vs `torch.compile` |
|
| 270 |
+
| --- | --- | --- | --- |
|
| 271 |
+
| `bnpo_loss_fwd` | 0 | 1.6× | 1.3× |
|
| 272 |
+
| `bnpo_loss` | 0 | 1.5× | 1.2× |
|
| 273 |
+
| `bnpo_loss_fwd` | ≠ 0 | 1.4× | 1.1× |
|
| 274 |
+
| `bnpo_loss` | ≠ 0 | 1.3× | 1.0× |
|
| 275 |
+
| `grpo_loss_fwd` | ≠ 0 | 1.5× | 1.2× |
|
| 276 |
+
| `grpo_loss` | ≠ 0 | 1.4× | 1.1× |
|
| 277 |
+
| `reverse_kl_fwd` | n/a | 1.3× | 1.1× |
|
| 278 |
+
| `reverse_kl` | n/a | 1.2× | 1.0× |
|
| 279 |
+
|
| 280 |
+
Profiled on H100 SXM (SM90a). BNPO and GRPO are benchmarked separately for `β = 0`
|
| 281 |
+
(KL term dead-coded at compile time) and `β ≠ 0`; the table above lists GRPO at `β ≠ 0` only. Shapes:
|
| 282 |
+
|
| 283 |
+
- **BNPO / GRPO**: `(16, 1024)`, `(32, 2048)`, `(64, 4096)`, `(128, 8192)`,
|
| 284 |
+
`(128, 8193)` — the last entry exercises the predicated tail-tile path.
|
| 285 |
+
- **Reverse KL** (vocab = 248320, matching Qwen3.5): `(1, 64)`, `(2, 128)`,
|
| 286 |
+
`(4, 256)`, `(8, 512)`, `(16, 1024)`, `(8, 1981)`.
|
| 287 |
+
|
| 288 |
+
Reproduce locally:
|
| 289 |
+
|
| 290 |
+
```bash
|
| 291 |
+
make bench-kernel KERNEL=grpo_loss # or bnpo_loss, reverse_kl
|
| 292 |
+
```
|
| 293 |
+
|
| 294 |
+
---
|
| 295 |
+
|
| 296 |
+
## Benchmark animations
|
| 297 |
+
|
| 298 |
+
### BNPO Loss vs eager PyTorch
|
| 299 |
+
|
| 300 |
+
<picture>
|
| 301 |
+
<source media="(prefers-color-scheme: dark)" srcset="benchmark_results/bnpo_loss_eager/bnpo_loss_eager_dark_animation.svg">
|
| 302 |
+
<img src="benchmark_results/bnpo_loss_eager/bnpo_loss_eager_light_animation.svg" alt="BNPO loss latency vs eager PyTorch">
|
| 303 |
+
</picture>
|
| 304 |
+
|
| 305 |
+
### BNPO Loss vs torch.compile
|
| 306 |
+
|
| 307 |
+
<picture>
|
| 308 |
+
<source media="(prefers-color-scheme: dark)" srcset="benchmark_results/bnpo_loss_compiled/bnpo_loss_compiled_dark_animation.svg">
|
| 309 |
+
<img src="benchmark_results/bnpo_loss_compiled/bnpo_loss_compiled_light_animation.svg" alt="BNPO loss latency vs torch.compile">
|
| 310 |
+
</picture>
|
| 311 |
+
|
| 312 |
+
### GRPO Loss vs eager PyTorch
|
| 313 |
+
|
| 314 |
+
<picture>
|
| 315 |
+
<source media="(prefers-color-scheme: dark)" srcset="benchmark_results/grpo_loss_eager/grpo_loss_eager_dark_animation.svg">
|
| 316 |
+
<img src="benchmark_results/grpo_loss_eager/grpo_loss_eager_light_animation.svg" alt="GRPO loss latency vs eager PyTorch">
|
| 317 |
+
</picture>
|
| 318 |
+
|
| 319 |
+
### GRPO Loss vs torch.compile
|
| 320 |
+
|
| 321 |
+
<picture>
|
| 322 |
+
<source media="(prefers-color-scheme: dark)" srcset="benchmark_results/grpo_loss_compiled/grpo_loss_compiled_dark_animation.svg">
|
| 323 |
+
<img src="benchmark_results/grpo_loss_compiled/grpo_loss_compiled_light_animation.svg" alt="GRPO loss latency vs torch.compile">
|
| 324 |
+
</picture>
|
| 325 |
+
|
| 326 |
+
### Reverse KL vs eager PyTorch
|
| 327 |
+
|
| 328 |
+
<picture>
|
| 329 |
+
<source media="(prefers-color-scheme: dark)" srcset="benchmark_results/reverse_kl_eager/reverse_kl_eager_dark_animation.svg">
|
| 330 |
+
<img src="benchmark_results/reverse_kl_eager/reverse_kl_eager_light_animation.svg" alt="Reverse KL latency vs eager PyTorch">
|
| 331 |
+
</picture>
|
| 332 |
+
|
| 333 |
+
### Reverse KL vs torch.compile
|
| 334 |
+
|
| 335 |
+
<picture>
|
| 336 |
+
<source media="(prefers-color-scheme: dark)" srcset="benchmark_results/reverse_kl_compiled/reverse_kl_compiled_dark_animation.svg">
|
| 337 |
+
<img src="benchmark_results/reverse_kl_compiled/reverse_kl_compiled_light_animation.svg" alt="Reverse KL latency vs torch.compile">
|
| 338 |
+
</picture>
|