Instructions to use Geometric-AI/geometric-ai-kernels with libraries, inference providers, notebooks, and local apps. Follow these links to get started.
- Libraries
- Kernels
How to use Geometric-AI/geometric-ai-kernels with Kernels:
# !pip install kernels from kernels import get_kernel kernel = get_kernel("Geometric-AI/geometric-ai-kernels") - Notebooks
- Google Colab
- Kaggle
Uploaded using `kernel-builder`.
Browse files
README.md
CHANGED
|
@@ -17,20 +17,21 @@ tags:
|
|
| 17 |
|
| 18 |
Fused **CuteDSL** kernels for the loss functions that dominate post-training
|
| 19 |
workloads: PPO-family policy losses (BNPO, GRPO) and reverse-KL
|
| 20 |
-
self-distillation.
|
| 21 |
-
|
| 22 |
-
|
|
|
|
| 23 |
kernel, and no host-side syncs in the hot path.
|
| 24 |
|
| 25 |
Background and benchmarks: see the
|
| 26 |
[release post](https://geometric.so/blog/2026/05/08/hf-kernel-hub).
|
| 27 |
|
| 28 |
- **Backend**: CUDA (NVIDIA CUTLASS DSL).
|
| 29 |
-
- **Min GPU**: SM80 (Ampere)
|
| 30 |
- **Min CUDA**: 12.8.
|
| 31 |
- **Dtypes**: `float32`, `float16`, `bfloat16`.
|
| 32 |
- **Dynamic shapes**: a single compile handles arbitrary batch size and
|
| 33 |
-
sequence length
|
| 34 |
in post-training rollouts).
|
| 35 |
|
| 36 |
## Kernels
|
|
@@ -45,15 +46,15 @@ Background and benchmarks: see the
|
|
| 45 |
|
| 46 |
Each kernel family exposes three entry points with the same underlying CuteDSL kernel:
|
| 47 |
|
| 48 |
-
- **`<name>(...)`**
|
| 49 |
dispatch. Lowest-overhead path; the caller chains the gradient into the upstream
|
| 50 |
model with `policy_logprobs.backward(grad)`. Use this in custom training loops
|
| 51 |
where you control gradient flow.
|
| 52 |
-
- **`<name>_autograd(...)`**
|
| 53 |
`torch.library.custom_op` + `register_autograd`. `loss.backward()` works
|
| 54 |
and composes with `torch.compile(fullgraph=True)`. There is a noticeable
|
| 55 |
per-call dispatcher overhead vs. the direct path.
|
| 56 |
-
- **`<name>_fwd(...)`**
|
| 57 |
the gradient buffer entirely. Use for inference / validation /
|
| 58 |
reward-model scoring.
|
| 59 |
|
|
@@ -85,7 +86,7 @@ per_token = −advantages · min(ratio, clipped)
|
|
| 85 |
kl = exp(ref_logprobs − policy_logprobs) − (ref_logprobs − policy_logprobs) − 1
|
| 86 |
```
|
| 87 |
|
| 88 |
-
The global denominator is computed entirely on-GPU via cross-CTA atomics
|
| 89 |
no host-side `mask.sum()` sync. When `beta=0` the KL branch is dead-coded
|
| 90 |
at compile time.
|
| 91 |
|
|
@@ -110,7 +111,7 @@ ref_logprobs = torch.randn(bs, seq_len, dtype=torch.bfloat16, device=devi
|
|
| 110 |
advantages = torch.randn(bs, dtype=torch.bfloat16, device=device)
|
| 111 |
completions_mask = (torch.rand(bs, seq_len, device=device) > 0.2).to(torch.int8)
|
| 112 |
|
| 113 |
-
# 1) Direct (loss, grad)
|
| 114 |
loss, grad = km.bnpo_loss(
|
| 115 |
policy_logprobs, old_policy_logprobs, ref_logprobs,
|
| 116 |
advantages, completions_mask,
|
|
@@ -118,7 +119,7 @@ loss, grad = km.bnpo_loss(
|
|
| 118 |
)
|
| 119 |
policy_logprobs.backward(grad)
|
| 120 |
|
| 121 |
-
# 2) Autograd-aware
|
| 122 |
loss = km.bnpo_loss_autograd(
|
| 123 |
policy_logprobs.requires_grad_(),
|
| 124 |
old_policy_logprobs, ref_logprobs,
|
|
@@ -127,7 +128,7 @@ loss = km.bnpo_loss_autograd(
|
|
| 127 |
)
|
| 128 |
loss.backward()
|
| 129 |
|
| 130 |
-
# 3) Forward-only
|
| 131 |
loss = km.bnpo_loss_fwd(
|
| 132 |
policy_logprobs, old_policy_logprobs, ref_logprobs,
|
| 133 |
advantages, completions_mask,
|
|
@@ -140,7 +141,7 @@ loss = km.bnpo_loss_fwd(
|
|
| 140 |
## GRPO Loss
|
| 141 |
|
| 142 |
**Group Relative Policy Optimization** implements TRL's default
|
| 143 |
-
**per-response normalization** variant
|
| 144 |
own valid-token count before averaging across the batch:
|
| 145 |
|
| 146 |
```
|
|
@@ -150,12 +151,12 @@ loss = mean_r( ((per_token_loss + β·kl) · mask).sum(-1) / max(mask.sum(-1), 1
|
|
| 150 |
`per_token_loss` and `kl` are the same clipped-ratio and KL expressions as BNPO.
|
| 151 |
`completions_mask` is **required** because the per-response denominator is
|
| 152 |
mask-derived. The kernel uses one CTA per row so the per-row mask sum is
|
| 153 |
-
reduced inside the block
|
| 154 |
|
| 155 |
**Inputs**:
|
| 156 |
- `policy_logprobs`, `old_policy_logprobs`, `ref_logprobs`: `(bs, seq_len)`, fp32/fp16/bf16
|
| 157 |
- `advantages`: `(bs,)`
|
| 158 |
-
- `completions_mask`: `(bs, seq_len)`, bool or int8
|
| 159 |
|
| 160 |
**Returns**: `(loss, grad_policy_logprobs)` from `grpo_loss`; scalar `loss` from `grpo_loss_fwd`.
|
| 161 |
|
|
@@ -173,7 +174,7 @@ ref_logprobs = torch.randn(bs, seq_len, dtype=torch.bfloat16, device=devi
|
|
| 173 |
advantages = torch.randn(bs, dtype=torch.bfloat16, device=device)
|
| 174 |
completions_mask = (torch.rand(bs, seq_len, device=device) > 0.2).to(torch.int8)
|
| 175 |
|
| 176 |
-
# 1) Direct (loss, grad)
|
| 177 |
loss, grad = km.grpo_loss(
|
| 178 |
policy_logprobs, old_policy_logprobs, ref_logprobs,
|
| 179 |
advantages, completions_mask,
|
|
@@ -181,7 +182,7 @@ loss, grad = km.grpo_loss(
|
|
| 181 |
)
|
| 182 |
policy_logprobs.backward(grad)
|
| 183 |
|
| 184 |
-
# 2) Autograd-aware
|
| 185 |
loss = km.grpo_loss_autograd(
|
| 186 |
policy_logprobs.requires_grad_(),
|
| 187 |
old_policy_logprobs, ref_logprobs,
|
|
@@ -190,7 +191,7 @@ loss = km.grpo_loss_autograd(
|
|
| 190 |
)
|
| 191 |
loss.backward()
|
| 192 |
|
| 193 |
-
# 3) Forward-only
|
| 194 |
loss = km.grpo_loss_fwd(
|
| 195 |
policy_logprobs, old_policy_logprobs, ref_logprobs,
|
| 196 |
advantages, completions_mask,
|
|
@@ -222,7 +223,7 @@ grad_student_v = scale · p_v · (log p_v − log q_v − kl_per_row)
|
|
| 222 |
where `scale = mask[r] · inv_n_valid`.
|
| 223 |
|
| 224 |
**Inputs**:
|
| 225 |
-
- `student_logits`, `teacher_logits`: `(*, V)`
|
| 226 |
- `completions_mask`: shape matching `student_logits.shape[:-1]`
|
| 227 |
|
| 228 |
> ⚠️ **Fully-masked batches**: `inv_n_valid = 1 / mask.sum()` is not clamped, so a batch where every token is masked produces inf/NaN. Guard upstream if that case is reachable.
|
|
@@ -242,17 +243,17 @@ student_logits = torch.randn(bs, seq_len, vocab, dtype=torch.bfloat16, device=d
|
|
| 242 |
teacher_logits = torch.randn(bs, seq_len, vocab, dtype=torch.bfloat16, device=device)
|
| 243 |
completions_mask = (torch.rand(bs, seq_len, device=device) > 0.2)
|
| 244 |
|
| 245 |
-
# 1) Direct (loss, grad)
|
| 246 |
loss, grad = km.reverse_kl(student_logits, teacher_logits, completions_mask)
|
| 247 |
student_logits.backward(grad)
|
| 248 |
|
| 249 |
-
# 2) Autograd-aware
|
| 250 |
loss = km.reverse_kl_autograd(
|
| 251 |
student_logits.requires_grad_(), teacher_logits, completions_mask
|
| 252 |
)
|
| 253 |
loss.backward()
|
| 254 |
|
| 255 |
-
# 3) Forward-only
|
| 256 |
loss = km.reverse_kl_fwd(student_logits, teacher_logits, completions_mask)
|
| 257 |
```
|
| 258 |
|
|
@@ -260,37 +261,22 @@ loss = km.reverse_kl_fwd(student_logits, teacher_logits, completions_mask)
|
|
| 260 |
|
| 261 |
## Performance
|
| 262 |
|
| 263 |
-
|
| 264 |
-
|
| 265 |
-
**eager PyTorch** and **`torch.compile(mode="max-autotune-no-cudagraphs",
|
| 266 |
-
fullgraph=True)`** with `torch._dynamo.config.trace_autograd_ops = True` so
|
| 267 |
-
the compiled baseline is a real Inductor-fused fwd+bwd graph.
|
| 268 |
|
| 269 |
-
|
| 270 |
-
|
| 271 |
-
|
| 272 |
-
|
| 273 |
-
| `bnpo_loss_fwd` | ≠ 0 | 1.4× | 1.1× |
|
| 274 |
-
| `bnpo_loss` | ≠ 0 | 1.3× | 1.0× |
|
| 275 |
-
| `grpo_loss_fwd` | ≠ 0 | 1.5× | 1.2× |
|
| 276 |
-
| `grpo_loss` | ≠ 0 | 1.4× | 1.1× |
|
| 277 |
-
| `reverse_kl_fwd`| | 1.3× | 1.1× |
|
| 278 |
-
| `reverse_kl` | | 1.2× | 1.0× |
|
| 279 |
-
|
| 280 |
-
Profiled on H100 SXM (SM90a). BNPO and GRPO benchmarked separately for `β = 0`
|
| 281 |
-
(KL term dead-coded at compile time) and `β ≠ 0`. Shapes:
|
| 282 |
-
|
| 283 |
-
- **BNPO / GRPO**: `(16, 1024)`, `(32, 2048)`, `(64, 4096)`, `(128, 8192)`,
|
| 284 |
-
`(128, 8193)` — the last entry exercises the predicated tail-tile path.
|
| 285 |
-
- **Reverse KL** (vocab = 248320, matching Qwen3.5): `(1, 64)`, `(2, 128)`,
|
| 286 |
-
`(4, 256)`, `(8, 512)`, `(16, 1024)`, `(8, 1981)`.
|
| 287 |
-
|
| 288 |
-
Reproduce locally:
|
| 289 |
-
|
| 290 |
-
```bash
|
| 291 |
-
make bench-kernel KERNEL=grpo_loss # or bnpo_loss, reverse_kl
|
| 292 |
-
```
|
| 293 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 294 |
---
|
| 295 |
|
| 296 |
## Benchmark animations
|
|
@@ -299,40 +285,40 @@ make bench-kernel KERNEL=grpo_loss # or bnpo_loss, reverse_kl
|
|
| 299 |
|
| 300 |
<picture>
|
| 301 |
<source media="(prefers-color-scheme: dark)" srcset="benchmark_results/bnpo_loss_eager/bnpo_loss_eager_dark_animation.svg">
|
| 302 |
-
<img src="benchmark_results/bnpo_loss_eager/bnpo_loss_eager_light_animation.svg" alt="BNPO loss latency vs eager PyTorch">
|
| 303 |
</picture>
|
| 304 |
|
| 305 |
### BNPO Loss vs torch.compile
|
| 306 |
|
| 307 |
<picture>
|
| 308 |
<source media="(prefers-color-scheme: dark)" srcset="benchmark_results/bnpo_loss_compiled/bnpo_loss_compiled_dark_animation.svg">
|
| 309 |
-
<img src="benchmark_results/bnpo_loss_compiled/bnpo_loss_compiled_light_animation.svg" alt="BNPO loss latency vs torch.compile">
|
| 310 |
</picture>
|
| 311 |
|
| 312 |
### GRPO Loss vs eager PyTorch
|
| 313 |
|
| 314 |
<picture>
|
| 315 |
<source media="(prefers-color-scheme: dark)" srcset="benchmark_results/grpo_loss_eager/grpo_loss_eager_dark_animation.svg">
|
| 316 |
-
<img src="benchmark_results/grpo_loss_eager/grpo_loss_eager_light_animation.svg" alt="GRPO loss latency vs eager PyTorch">
|
| 317 |
</picture>
|
| 318 |
|
| 319 |
### GRPO Loss vs torch.compile
|
| 320 |
|
| 321 |
<picture>
|
| 322 |
<source media="(prefers-color-scheme: dark)" srcset="benchmark_results/grpo_loss_compiled/grpo_loss_compiled_dark_animation.svg">
|
| 323 |
-
<img src="benchmark_results/grpo_loss_compiled/grpo_loss_compiled_light_animation.svg" alt="GRPO loss latency vs torch.compile">
|
| 324 |
</picture>
|
| 325 |
|
| 326 |
### Reverse KL vs eager PyTorch
|
| 327 |
|
| 328 |
<picture>
|
| 329 |
<source media="(prefers-color-scheme: dark)" srcset="benchmark_results/reverse_kl_eager/reverse_kl_eager_dark_animation.svg">
|
| 330 |
-
<img src="benchmark_results/reverse_kl_eager/reverse_kl_eager_light_animation.svg" alt="Reverse KL latency vs eager PyTorch">
|
| 331 |
</picture>
|
| 332 |
|
| 333 |
### Reverse KL vs torch.compile
|
| 334 |
|
| 335 |
<picture>
|
| 336 |
<source media="(prefers-color-scheme: dark)" srcset="benchmark_results/reverse_kl_compiled/reverse_kl_compiled_dark_animation.svg">
|
| 337 |
-
<img src="benchmark_results/reverse_kl_compiled/reverse_kl_compiled_light_animation.svg" alt="Reverse KL latency vs torch.compile">
|
| 338 |
</picture>
|
|
|
|
| 17 |
|
| 18 |
Fused **CuteDSL** kernels for the loss functions that dominate post-training
|
| 19 |
workloads: PPO-family policy losses (BNPO, GRPO) and reverse-KL
|
| 20 |
+
self-distillation.
|
| 21 |
+
|
| 22 |
+
Each kernel ships a **single-launch fused forward +
|
| 23 |
+
backward** path that returns `(loss, grad_logprobs)` directly. No `torch.autograd.Function` wrapper, no extra `grad_output * dpolicy` backward
|
| 24 |
kernel, and no host-side syncs in the hot path.
|
| 25 |
|
| 26 |
Background and benchmarks: see the
|
| 27 |
[release post](https://geometric.so/blog/2026/05/08/hf-kernel-hub).
|
| 28 |
|
| 29 |
- **Backend**: CUDA (NVIDIA CUTLASS DSL).
|
| 30 |
+
- **Min GPU**: SM80 (Ampere) - required by `nvidia-cutlass-dsl`. Tested on H100 (SM90). Should work on SM80 (Ampere), SM86 (RTX 3090, A40), SM89 (RTX 4090, L40S), SM90a (H100 SXM), and SM100 (Blackwell B200/GB200).
|
| 31 |
- **Min CUDA**: 12.8.
|
| 32 |
- **Dtypes**: `float32`, `float16`, `bfloat16`.
|
| 33 |
- **Dynamic shapes**: a single compile handles arbitrary batch size and
|
| 34 |
+
sequence length, no recompiles when shapes change between calls (common
|
| 35 |
in post-training rollouts).
|
| 36 |
|
| 37 |
## Kernels
|
|
|
|
| 46 |
|
| 47 |
Each kernel family exposes three entry points with the same underlying CuteDSL kernel:
|
| 48 |
|
| 49 |
+
- **`<name>(...)`** - fused fwd+bwd, returns `(loss, grad)` from one `@cute.jit`
|
| 50 |
dispatch. Lowest-overhead path; the caller chains the gradient into the upstream
|
| 51 |
model with `policy_logprobs.backward(grad)`. Use this in custom training loops
|
| 52 |
where you control gradient flow.
|
| 53 |
+
- **`<name>_autograd(...)`** - same kernel, registered via
|
| 54 |
`torch.library.custom_op` + `register_autograd`. `loss.backward()` works
|
| 55 |
and composes with `torch.compile(fullgraph=True)`. There is a noticeable
|
| 56 |
per-call dispatcher overhead vs. the direct path.
|
| 57 |
+
- **`<name>_fwd(...)`** - forward-only, returns scalar `loss` and skips
|
| 58 |
the gradient buffer entirely. Use for inference / validation /
|
| 59 |
reward-model scoring.
|
| 60 |
|
|
|
|
| 86 |
kl = exp(ref_logprobs − policy_logprobs) − (ref_logprobs − policy_logprobs) − 1
|
| 87 |
```
|
| 88 |
|
| 89 |
+
The global denominator is computed entirely on-GPU via cross-CTA atomics -
|
| 90 |
no host-side `mask.sum()` sync. When `beta=0` the KL branch is dead-coded
|
| 91 |
at compile time.
|
| 92 |
|
|
|
|
| 111 |
advantages = torch.randn(bs, dtype=torch.bfloat16, device=device)
|
| 112 |
completions_mask = (torch.rand(bs, seq_len, device=device) > 0.2).to(torch.int8)
|
| 113 |
|
| 114 |
+
# 1) Direct (loss, grad) - lowest overhead training path
|
| 115 |
loss, grad = km.bnpo_loss(
|
| 116 |
policy_logprobs, old_policy_logprobs, ref_logprobs,
|
| 117 |
advantages, completions_mask,
|
|
|
|
| 119 |
)
|
| 120 |
policy_logprobs.backward(grad)
|
| 121 |
|
| 122 |
+
# 2) Autograd-aware - works with loss.backward() and torch.compile
|
| 123 |
loss = km.bnpo_loss_autograd(
|
| 124 |
policy_logprobs.requires_grad_(),
|
| 125 |
old_policy_logprobs, ref_logprobs,
|
|
|
|
| 128 |
)
|
| 129 |
loss.backward()
|
| 130 |
|
| 131 |
+
# 3) Forward-only - inference / reward scoring, no gradient buffer
|
| 132 |
loss = km.bnpo_loss_fwd(
|
| 133 |
policy_logprobs, old_policy_logprobs, ref_logprobs,
|
| 134 |
advantages, completions_mask,
|
|
|
|
| 141 |
## GRPO Loss
|
| 142 |
|
| 143 |
**Group Relative Policy Optimization** implements TRL's default
|
| 144 |
+
**per-response normalization** variant - each response is normalized by its
|
| 145 |
own valid-token count before averaging across the batch:
|
| 146 |
|
| 147 |
```
|
|
|
|
| 151 |
`per_token_loss` and `kl` are the same clipped-ratio and KL expressions as BNPO.
|
| 152 |
`completions_mask` is **required** because the per-response denominator is
|
| 153 |
mask-derived. The kernel uses one CTA per row so the per-row mask sum is
|
| 154 |
+
reduced inside the block - no cross-CTA atomics on the scaling pass.
|
| 155 |
|
| 156 |
**Inputs**:
|
| 157 |
- `policy_logprobs`, `old_policy_logprobs`, `ref_logprobs`: `(bs, seq_len)`, fp32/fp16/bf16
|
| 158 |
- `advantages`: `(bs,)`
|
| 159 |
+
- `completions_mask`: `(bs, seq_len)`, bool or int8 - **required**
|
| 160 |
|
| 161 |
**Returns**: `(loss, grad_policy_logprobs)` from `grpo_loss`; scalar `loss` from `grpo_loss_fwd`.
|
| 162 |
|
|
|
|
| 174 |
advantages = torch.randn(bs, dtype=torch.bfloat16, device=device)
|
| 175 |
completions_mask = (torch.rand(bs, seq_len, device=device) > 0.2).to(torch.int8)
|
| 176 |
|
| 177 |
+
# 1) Direct (loss, grad) - lowest overhead training path
|
| 178 |
loss, grad = km.grpo_loss(
|
| 179 |
policy_logprobs, old_policy_logprobs, ref_logprobs,
|
| 180 |
advantages, completions_mask,
|
|
|
|
| 182 |
)
|
| 183 |
policy_logprobs.backward(grad)
|
| 184 |
|
| 185 |
+
# 2) Autograd-aware - works with loss.backward() and torch.compile
|
| 186 |
loss = km.grpo_loss_autograd(
|
| 187 |
policy_logprobs.requires_grad_(),
|
| 188 |
old_policy_logprobs, ref_logprobs,
|
|
|
|
| 191 |
)
|
| 192 |
loss.backward()
|
| 193 |
|
| 194 |
+
# 3) Forward-only - inference / reward scoring, no gradient buffer
|
| 195 |
loss = km.grpo_loss_fwd(
|
| 196 |
policy_logprobs, old_policy_logprobs, ref_logprobs,
|
| 197 |
advantages, completions_mask,
|
|
|
|
| 223 |
where `scale = mask[r] · inv_n_valid`.
|
| 224 |
|
| 225 |
**Inputs**:
|
| 226 |
+
- `student_logits`, `teacher_logits`: `(*, V)` - arbitrary leading dims (typically `(bs, seq_len, vocab)`); both must share shape and dtype
|
| 227 |
- `completions_mask`: shape matching `student_logits.shape[:-1]`
|
| 228 |
|
| 229 |
> ⚠️ **Fully-masked batches**: `inv_n_valid = 1 / mask.sum()` is not clamped, so a batch where every token is masked produces inf/NaN. Guard upstream if that case is reachable.
|
|
|
|
| 243 |
teacher_logits = torch.randn(bs, seq_len, vocab, dtype=torch.bfloat16, device=device)
|
| 244 |
completions_mask = (torch.rand(bs, seq_len, device=device) > 0.2)
|
| 245 |
|
| 246 |
+
# 1) Direct (loss, grad) - lowest overhead training path
|
| 247 |
loss, grad = km.reverse_kl(student_logits, teacher_logits, completions_mask)
|
| 248 |
student_logits.backward(grad)
|
| 249 |
|
| 250 |
+
# 2) Autograd-aware - works with loss.backward() and torch.compile
|
| 251 |
loss = km.reverse_kl_autograd(
|
| 252 |
student_logits.requires_grad_(), teacher_logits, completions_mask
|
| 253 |
)
|
| 254 |
loss.backward()
|
| 255 |
|
| 256 |
+
# 3) Forward-only - inference / KL monitoring, no gradient buffer
|
| 257 |
loss = km.reverse_kl_fwd(student_logits, teacher_logits, completions_mask)
|
| 258 |
```
|
| 259 |
|
|
|
|
| 261 |
|
| 262 |
## Performance
|
| 263 |
|
| 264 |
+
All numbers are geometric-mean speedups over H100 SXM (SM90a). Full methodology
|
| 265 |
+
and per-shape plots in the [release post](https://geometric.so/blog/2026/05/08/hf-kernel-hub).
|
|
|
|
|
|
|
|
|
|
| 266 |
|
| 267 |
+
### `kernels` CLI benchmark
|
| 268 |
+
|
| 269 |
+
Timed with `time.perf_counter` + `cuda.synchronize()`, single iteration per
|
| 270 |
+
shape, mean over 100 iterations. Baseline runs once after warmup.
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 271 |
|
| 272 |
+
| Kernel | vs eager | vs `torch.compile` |
|
| 273 |
+
| --- | --- | --- |
|
| 274 |
+
| `grpo_loss_fwd` | 5.60× | 2.52× |
|
| 275 |
+
| `grpo_loss` | 19.92× | 2.28× |
|
| 276 |
+
| `bnpo_loss_fwd` | 5.58× | 2.54× |
|
| 277 |
+
| `bnpo_loss` | 17.46× | 2.15× |
|
| 278 |
+
| `reverse_kl_fwd`| 6.95× | 2.44× |
|
| 279 |
+
| `reverse_kl` | 7.06× | 2.59× |
|
| 280 |
---
|
| 281 |
|
| 282 |
## Benchmark animations
|
|
|
|
| 285 |
|
| 286 |
<picture>
|
| 287 |
<source media="(prefers-color-scheme: dark)" srcset="benchmark_results/bnpo_loss_eager/bnpo_loss_eager_dark_animation.svg">
|
| 288 |
+
<img width="90%" src="benchmark_results/bnpo_loss_eager/bnpo_loss_eager_light_animation.svg" alt="BNPO loss latency vs eager PyTorch">
|
| 289 |
</picture>
|
| 290 |
|
| 291 |
### BNPO Loss vs torch.compile
|
| 292 |
|
| 293 |
<picture>
|
| 294 |
<source media="(prefers-color-scheme: dark)" srcset="benchmark_results/bnpo_loss_compiled/bnpo_loss_compiled_dark_animation.svg">
|
| 295 |
+
<img width="90%" src="benchmark_results/bnpo_loss_compiled/bnpo_loss_compiled_light_animation.svg" alt="BNPO loss latency vs torch.compile">
|
| 296 |
</picture>
|
| 297 |
|
| 298 |
### GRPO Loss vs eager PyTorch
|
| 299 |
|
| 300 |
<picture>
|
| 301 |
<source media="(prefers-color-scheme: dark)" srcset="benchmark_results/grpo_loss_eager/grpo_loss_eager_dark_animation.svg">
|
| 302 |
+
<img width="90%" src="benchmark_results/grpo_loss_eager/grpo_loss_eager_light_animation.svg" alt="GRPO loss latency vs eager PyTorch">
|
| 303 |
</picture>
|
| 304 |
|
| 305 |
### GRPO Loss vs torch.compile
|
| 306 |
|
| 307 |
<picture>
|
| 308 |
<source media="(prefers-color-scheme: dark)" srcset="benchmark_results/grpo_loss_compiled/grpo_loss_compiled_dark_animation.svg">
|
| 309 |
+
<img width="90%" src="benchmark_results/grpo_loss_compiled/grpo_loss_compiled_light_animation.svg" alt="GRPO loss latency vs torch.compile">
|
| 310 |
</picture>
|
| 311 |
|
| 312 |
### Reverse KL vs eager PyTorch
|
| 313 |
|
| 314 |
<picture>
|
| 315 |
<source media="(prefers-color-scheme: dark)" srcset="benchmark_results/reverse_kl_eager/reverse_kl_eager_dark_animation.svg">
|
| 316 |
+
<img width="90%" src="benchmark_results/reverse_kl_eager/reverse_kl_eager_light_animation.svg" alt="Reverse KL latency vs eager PyTorch">
|
| 317 |
</picture>
|
| 318 |
|
| 319 |
### Reverse KL vs torch.compile
|
| 320 |
|
| 321 |
<picture>
|
| 322 |
<source media="(prefers-color-scheme: dark)" srcset="benchmark_results/reverse_kl_compiled/reverse_kl_compiled_dark_animation.svg">
|
| 323 |
+
<img width="90%" src="benchmark_results/reverse_kl_compiled/reverse_kl_compiled_light_animation.svg" alt="Reverse KL latency vs torch.compile">
|
| 324 |
</picture>
|