diff --git a/README.md b/README.md index 2dd4e77f19c661fc5f2354d8ebcdfd0222a9813c..7b95401dc46245ac339fc25059d4a56d90b4cde5 100644 --- a/README.md +++ b/README.md @@ -1,326 +1,3 @@ ---- -library_name: kernels -license: apache-2.0 -tags: -- cuda -- cutlass -- cute-dsl -- rl -- distillation -- trl -- grpo -- bnpo -- kl-divergence ---- - -# Geometric-AI Kernels - -Fused **CuteDSL** kernels for the loss functions that dominate post-training -workloads: PPO-family policy losses (BNPO, GRPO) and reverse-KL -self-distillation. - -Each kernel ships a **single-launch fused forward + -backward** path that returns `(loss, grad_logprobs)` directly. No `torch.autograd.Function` wrapper, no extra `grad_output * dpolicy` backward -kernel, and no host-side syncs in the hot path. - -Background and benchmarks: see the -[release post](https://geometric.so/blog/2026/05/08/hf-kernel-hub). - -- **Backend**: CUDA (NVIDIA CUTLASS DSL). -- **Min GPU**: SM80 (Ampere) - required by `nvidia-cutlass-dsl`. Tested on H100 (SM90). Should work on SM80 (Ampere), SM86 (RTX 3090, A40), SM89 (RTX 4090, L40S), SM90a (H100 SXM), and SM100 (Blackwell B200/GB200). -- **Min CUDA**: 12.8. -- **Dtypes**: `float32`, `float16`, `bfloat16`. -- **Dynamic shapes**: a single compile handles arbitrary batch size and - sequence length, no recompiles when shapes change between calls (common - in post-training rollouts). - -## Kernels - -| Kernel family | Direct (no autograd) | Autograd-aware | Forward-only | -| --- | --- | --- | --- | -| BNPO loss | `bnpo_loss` | `bnpo_loss_autograd` | `bnpo_loss_fwd` | -| GRPO loss | `grpo_loss` | `grpo_loss_autograd` | `grpo_loss_fwd` | -| Reverse KL | `reverse_kl` | `reverse_kl_autograd` | `reverse_kl_fwd` | - -### Entry points - -Each kernel family exposes three entry points with the same underlying CuteDSL kernel: - -- **`(...)`** - fused fwd+bwd, returns `(loss, grad)` from one `@cute.jit` - dispatch. 
Lowest-overhead path; the caller chains the gradient into the upstream - model with `policy_logprobs.backward(grad)`. Use this in custom training loops - where you control gradient flow. -- **`_autograd(...)`** - same kernel, registered via - `torch.library.custom_op` + `register_autograd`. `loss.backward()` works - and composes with `torch.compile(fullgraph=True)`. There is a noticeable - per-call dispatcher overhead vs. the direct path. -- **`_fwd(...)`** - forward-only, returns scalar `loss` and skips - the gradient buffer entirely. Use for inference / validation / - reward-model scoring. - -## Loading the kernels -``` -pip install apache-tvm-ffi nvidia-cutlass-dsl -``` - -```python -from kernels import get_kernel - -km = get_kernel("Geometric-AI/geometric-ai-kernels", version=0) -``` - ---- - -## BNPO Loss - -**Batch-Normalized Policy Optimization** sums per-token policy and KL terms -across the **entire batch** and divides by the global valid-token count: - -``` -loss = ((per_token_loss + β·kl) · mask).sum() / max(mask.sum(), 1) -``` - -where `per_token_loss` is the PPO-clipped ratio loss: - -``` -ratio = exp(policy_logprobs - old_policy_logprobs) -clipped = clip(ratio, 1−ε, 1+ε_high) -per_token = −advantages · min(ratio, clipped) -kl = exp(ref_logprobs − policy_logprobs) − (ref_logprobs − policy_logprobs) − 1 -``` - -The global denominator is computed entirely on-GPU via cross-CTA atomics - -no host-side `mask.sum()` sync. When `beta=0` the KL branch is dead-coded -at compile time. - -**Inputs**: -- `policy_logprobs`, `old_policy_logprobs`, `ref_logprobs`: `(bs, seq_len)`, fp32/fp16/bf16 -- `advantages`: `(bs,)` -- `completions_mask`: `(bs, seq_len)`, bool or int8 - -**Returns**: `(loss, grad_policy_logprobs)` from `bnpo_loss`; scalar `loss` from `bnpo_loss_fwd`. 
- -```python -import torch -from kernels import get_kernel - -km = get_kernel("Geometric-AI/geometric-ai-kernels", version=0) -device = torch.device("cuda") - -bs, seq_len = 16, 1024 -policy_logprobs = torch.randn(bs, seq_len, dtype=torch.bfloat16, device=device, requires_grad=True) -old_policy_logprobs = torch.randn(bs, seq_len, dtype=torch.bfloat16, device=device) -ref_logprobs = torch.randn(bs, seq_len, dtype=torch.bfloat16, device=device) -advantages = torch.randn(bs, dtype=torch.bfloat16, device=device) -completions_mask = (torch.rand(bs, seq_len, device=device) > 0.2).to(torch.int8) - -# 1) Direct (loss, grad) - lowest overhead training path -loss, grad = km.bnpo_loss( - policy_logprobs, old_policy_logprobs, ref_logprobs, - advantages, completions_mask, - epsilon=0.2, epsilon_high=0.2, beta=0.1, -) -policy_logprobs.backward(grad) - -# 2) Autograd-aware - works with loss.backward() and torch.compile -loss = km.bnpo_loss_autograd( - policy_logprobs.requires_grad_(), - old_policy_logprobs, ref_logprobs, - advantages, completions_mask, - epsilon=0.2, epsilon_high=0.2, beta=0.1, -) -loss.backward() - -# 3) Forward-only - inference / reward scoring, no gradient buffer -loss = km.bnpo_loss_fwd( - policy_logprobs, old_policy_logprobs, ref_logprobs, - advantages, completions_mask, - epsilon=0.2, epsilon_high=0.2, beta=0.1, -) -``` - ---- - -## GRPO Loss - -**Group Relative Policy Optimization** implements TRL's default -**per-response normalization** variant - each response is normalized by its -own valid-token count before averaging across the batch: - -``` -loss = mean_r( ((per_token_loss + β·kl) · mask).sum(-1) / max(mask.sum(-1), 1) ) -``` - -`per_token_loss` and `kl` are the same clipped-ratio and KL expressions as BNPO. -`completions_mask` is **required** because the per-response denominator is -mask-derived. The kernel uses one CTA per row so the per-row mask sum is -reduced inside the block - no cross-CTA atomics on the scaling pass. 
- -**Inputs**: -- `policy_logprobs`, `old_policy_logprobs`, `ref_logprobs`: `(bs, seq_len)`, fp32/fp16/bf16 -- `advantages`: `(bs,)` -- `completions_mask`: `(bs, seq_len)`, bool or int8 - **required** - -**Returns**: `(loss, grad_policy_logprobs)` from `grpo_loss`; scalar `loss` from `grpo_loss_fwd`. - -```python -import torch -from kernels import get_kernel - -km = get_kernel("Geometric-AI/geometric-ai-kernels", version=0) -device = torch.device("cuda") - -bs, seq_len = 16, 1024 -policy_logprobs = torch.randn(bs, seq_len, dtype=torch.bfloat16, device=device, requires_grad=True) -old_policy_logprobs = torch.randn(bs, seq_len, dtype=torch.bfloat16, device=device) -ref_logprobs = torch.randn(bs, seq_len, dtype=torch.bfloat16, device=device) -advantages = torch.randn(bs, dtype=torch.bfloat16, device=device) -completions_mask = (torch.rand(bs, seq_len, device=device) > 0.2).to(torch.int8) - -# 1) Direct (loss, grad) - lowest overhead training path -loss, grad = km.grpo_loss( - policy_logprobs, old_policy_logprobs, ref_logprobs, - advantages, completions_mask, - epsilon=0.2, epsilon_high=0.2, beta=0.1, -) -policy_logprobs.backward(grad) - -# 2) Autograd-aware - works with loss.backward() and torch.compile -loss = km.grpo_loss_autograd( - policy_logprobs.requires_grad_(), - old_policy_logprobs, ref_logprobs, - advantages, completions_mask, - epsilon=0.2, epsilon_high=0.2, beta=0.1, -) -loss.backward() - -# 3) Forward-only - inference / reward scoring, no gradient buffer -loss = km.grpo_loss_fwd( - policy_logprobs, old_policy_logprobs, ref_logprobs, - advantages, completions_mask, - epsilon=0.2, epsilon_high=0.2, beta=0.1, -) -``` - ---- - -## Reverse KL - -**Reverse-KL self-distillation** computes `KL(student ‖ teacher)` over a -`(num_tokens, vocab)` slab using an online normalization algorithm that reads -each logit row exactly once on the forward-only path: - -``` -p = softmax(student_logits) -q = softmax(teacher_logits) -kl_per_row = Σ_v p_v · (log p_v − log q_v) 
-loss = (mask · kl_per_row).sum() / mask.sum() -``` - -The gradient through the softmax Jacobian is analytical: - -``` -grad_student_v = scale · p_v · (log p_v − log q_v − kl_per_row) -``` - -where `scale = mask[r] · inv_n_valid`. - -**Inputs**: -- `student_logits`, `teacher_logits`: `(*, V)` - arbitrary leading dims (typically `(bs, seq_len, vocab)`); both must share shape and dtype -- `completions_mask`: shape matching `student_logits.shape[:-1]` - -> ⚠️ **Fully-masked batches**: `inv_n_valid = 1 / mask.sum()` is not clamped, so a batch where every token is masked produces inf/NaN. Guard upstream if that case is reachable. - -**Returns**: `(loss, grad_student_logits)` from `reverse_kl`; scalar `loss` from `reverse_kl_fwd`. - -```python -import torch -from kernels import get_kernel - -km = get_kernel("Geometric-AI/geometric-ai-kernels", version=0) -device = torch.device("cuda") - -# Qwen3.5-style vocab; arbitrary leading dims supported -bs, seq_len, vocab = 4, 256, 248320 -student_logits = torch.randn(bs, seq_len, vocab, dtype=torch.bfloat16, device=device, requires_grad=True) -teacher_logits = torch.randn(bs, seq_len, vocab, dtype=torch.bfloat16, device=device) -completions_mask = (torch.rand(bs, seq_len, device=device) > 0.2) - -# 1) Direct (loss, grad) - lowest overhead training path -loss, grad = km.reverse_kl(student_logits, teacher_logits, completions_mask) -student_logits.backward(grad) - -# 2) Autograd-aware - works with loss.backward() and torch.compile -loss = km.reverse_kl_autograd( - student_logits.requires_grad_(), teacher_logits, completions_mask -) -loss.backward() - -# 3) Forward-only - inference / KL monitoring, no gradient buffer -loss = km.reverse_kl_fwd(student_logits, teacher_logits, completions_mask) -``` - ---- - -## Performance - -All numbers are geometric-mean speedups over H100 SXM (SM90a). Full methodology -and per-shape plots in the [release post](https://geometric.so/blog/2026/05/08/hf-kernel-hub). 
- -### `kernels` CLI benchmark - -Timed with `time.perf_counter` + `cuda.synchronize()`, mean over 100 iterations. - -| Kernel | vs eager | vs `torch.compile` | -| --- | --- | --- | -| `grpo_loss_fwd` | 5.68× | 2.45× | -| `grpo_loss` | 20.79× | 1.98x | -| `bnpo_loss_fwd` | 5.29× | 2.52× | -| `bnpo_loss` | 16.81× | 2.27× | -| `reverse_kl_fwd`| 6.88× | 2.45× | -| `reverse_kl` | 7.03× | 2.61× | ---- - -## Benchmark animations - -### BNPO Loss vs eager PyTorch - - - - BNPO loss latency vs eager PyTorch - - -### BNPO Loss vs torch.compile - - - - BNPO loss latency vs torch.compile - - -### GRPO Loss vs eager PyTorch - - - - GRPO loss latency vs eager PyTorch - - -### GRPO Loss vs torch.compile - - - - GRPO loss latency vs torch.compile - - -### Reverse KL vs eager PyTorch - - - - Reverse KL latency vs eager PyTorch - - -### Reverse KL vs torch.compile - - - - Reverse KL latency vs torch.compile - \ No newline at end of file +--- +license: apache-2.0 +--- diff --git a/benchmark_results/bnpo_loss_compiled/bnpo_loss_compiled_dark_animation.svg b/benchmark_results/bnpo_loss_compiled/bnpo_loss_compiled_dark_animation.svg deleted file mode 100644 index 9fd19986f8083838ec9fb3fa02d22b5a23d6aeb2..0000000000000000000000000000000000000000 --- a/benchmark_results/bnpo_loss_compiled/bnpo_loss_compiled_dark_animation.svg +++ /dev/null @@ -1,123 +0,0 @@ - -./ vs Torch - Relative Speed -PyTorch 2.11.0+cu130 · CUDA 13.0 - -bnpoLossBenchmark.bnpo_loss_batch128_seqlen02781_compiled -2.15x - - - - - - - -bnpoLossBenchmark.bnpo_loss_batch128_seqlen08192_compiled -2.20x - - - - - - - -bnpoLossBenchmark.bnpo_loss_batch16_seqlen01024_compiled -2.28x - - - - - - - -bnpoLossBenchmark.bnpo_loss_batch16_seqlen02781_compiled -2.00x - - - - - - - -bnpoLossBenchmark.bnpo_loss_batch32_seqlen02048_compiled -2.13x - - - - - - - -bnpoLossBenchmark.bnpo_loss_batch64_seqlen04096_compiled -2.15x - - - - - - - -bnpoLossBenchmark.bnpo_loss_fwd_batch128_seqlen02781_compiled -2.50x - - - - - - - 
-bnpoLossBenchmark.bnpo_loss_fwd_batch128_seqlen08192_compiled
-2.66x
-bnpoLossBenchmark.bnpo_loss_fwd_batch16_seqlen01024_compiled
-2.49x
-bnpoLossBenchmark.bnpo_loss_fwd_batch16_seqlen02781_compiled
-2.75x
-bnpoLossBenchmark.bnpo_loss_fwd_batch32_seqlen02048_compiled
-2.50x
-bnpoLossBenchmark.bnpo_loss_fwd_batch64_seqlen04096_compiled
-2.35x
-Kernel
-Torch (ref)
\ No newline at end of file
diff --git a/benchmark_results/bnpo_loss_compiled/bnpo_loss_compiled_dark_latency.svg b/benchmark_results/bnpo_loss_compiled/bnpo_loss_compiled_dark_latency.svg
deleted file mode 100644
index e96e216f6dc598a6c692adcc540e2d6dbad0b029..0000000000000000000000000000000000000000
--- a/benchmark_results/bnpo_loss_compiled/bnpo_loss_compiled_dark_latency.svg
+++ /dev/null
@@ -1,3365 +0,0 @@
-2026-05-08T15:08:44.964442
-image/svg+xml
-Matplotlib v3.10.9, https://matplotlib.org/
diff --git a/benchmark_results/bnpo_loss_compiled/bnpo_loss_compiled_dark_throughput.svg b/benchmark_results/bnpo_loss_compiled/bnpo_loss_compiled_dark_throughput.svg
deleted file mode 100644
index e6626ec19c6c29f4918494563c4e7bedaa8eb8a7..0000000000000000000000000000000000000000
--- a/benchmark_results/bnpo_loss_compiled/bnpo_loss_compiled_dark_throughput.svg
+++ /dev/null
@@ -1,3639 +0,0 @@
-2026-05-08T15:08:45.251021
-image/svg+xml
-Matplotlib v3.10.9, https://matplotlib.org/
diff --git a/benchmark_results/bnpo_loss_compiled/bnpo_loss_compiled_light_animation.svg b/benchmark_results/bnpo_loss_compiled/bnpo_loss_compiled_light_animation.svg
deleted file mode 100644
index 74286a83de43bb3215dcf10c24acb6c622dbf015..0000000000000000000000000000000000000000
--- a/benchmark_results/bnpo_loss_compiled/bnpo_loss_compiled_light_animation.svg
+++ /dev/null
@@ -1,123 +0,0 @@
-./ vs Torch - Relative Speed
-PyTorch 2.11.0+cu130 · CUDA 13.0
-bnpoLossBenchmark.bnpo_loss_batch128_seqlen02781_compiled
-2.15x
-bnpoLossBenchmark.bnpo_loss_batch128_seqlen08192_compiled
-2.20x
-bnpoLossBenchmark.bnpo_loss_batch16_seqlen01024_compiled
-2.28x
-bnpoLossBenchmark.bnpo_loss_batch16_seqlen02781_compiled
-2.00x
-bnpoLossBenchmark.bnpo_loss_batch32_seqlen02048_compiled
-2.13x
-bnpoLossBenchmark.bnpo_loss_batch64_seqlen04096_compiled
-2.15x
-bnpoLossBenchmark.bnpo_loss_fwd_batch128_seqlen02781_compiled
-2.50x
-bnpoLossBenchmark.bnpo_loss_fwd_batch128_seqlen08192_compiled
-2.66x
-bnpoLossBenchmark.bnpo_loss_fwd_batch16_seqlen01024_compiled
-2.49x
-bnpoLossBenchmark.bnpo_loss_fwd_batch16_seqlen02781_compiled
-2.75x
-bnpoLossBenchmark.bnpo_loss_fwd_batch32_seqlen02048_compiled
-2.50x
-bnpoLossBenchmark.bnpo_loss_fwd_batch64_seqlen04096_compiled
-2.35x
-Kernel
-Torch (ref)
\ No newline at end of file
diff --git a/benchmark_results/bnpo_loss_compiled/bnpo_loss_compiled_light_latency.svg b/benchmark_results/bnpo_loss_compiled/bnpo_loss_compiled_light_latency.svg
deleted file mode 100644
index 24c97156ed50d58aa6fe883bf4d548afa29f9ad4..0000000000000000000000000000000000000000
--- a/benchmark_results/bnpo_loss_compiled/bnpo_loss_compiled_light_latency.svg
+++ /dev/null
@@ -1,3365 +0,0 @@
-2026-05-08T15:08:44.222545
-image/svg+xml
-Matplotlib v3.10.9, https://matplotlib.org/
diff --git a/benchmark_results/bnpo_loss_compiled/bnpo_loss_compiled_light_throughput.svg b/benchmark_results/bnpo_loss_compiled/bnpo_loss_compiled_light_throughput.svg
deleted file mode 100644
index 34814c47fbbbdfc9ff99d6db484384a9a4155eaa..0000000000000000000000000000000000000000
--- a/benchmark_results/bnpo_loss_compiled/bnpo_loss_compiled_light_throughput.svg
+++ /dev/null
@@ -1,3639 +0,0 @@
-2026-05-08T15:08:44.531746
-image/svg+xml
-Matplotlib v3.10.9, https://matplotlib.org/
diff --git a/benchmark_results/bnpo_loss_compiled/results.json b/benchmark_results/bnpo_loss_compiled/results.json
deleted file mode 100644
index 03468956f6669953f78ec9ef277469fb38d380f0..0000000000000000000000000000000000000000
--- a/benchmark_results/bnpo_loss_compiled/results.json
+++ /dev/null
@@ -1,206 +0,0 @@
-{ - "results": [ - { - "workload": "bnpoLossBenchmark.bnpo_loss_batch128_seqlen02781_compiled", - "timingResults": { - "mean_ms": 0.0359, - "std_ms": 0.0038, - "min_ms": 0.0332, - "max_ms": 0.0701, - "q1_ms": 0.0344, - "q3_ms": 0.0357, - "iqr_ms": 0.0013, - "outliers": 20, - "iterations": 200, - "refMeanMs": 0.0771 - }, - "verified": true - }, - { - "workload": "bnpoLossBenchmark.bnpo_loss_batch128_seqlen08192_compiled", - "timingResults": { - "mean_ms": 0.0351, - "std_ms": 0.0033, - "min_ms": 0.0327, - "max_ms": 0.0557, - "q1_ms": 0.0336, - "q3_ms": 0.035, - "iqr_ms": 0.0014, - "outliers": 14, - "iterations": 200, - "refMeanMs": 0.0771 - }, - "verified": true - }, - { - "workload": "bnpoLossBenchmark.bnpo_loss_batch16_seqlen01024_compiled", - "timingResults": { - "mean_ms": 0.0355, - "std_ms": 0.0042, - "min_ms": 0.0331, - "max_ms": 0.0706, - "q1_ms": 0.034, - "q3_ms": 0.0351, - "iqr_ms": 0.0011, - "outliers": 21, - "iterations": 200, - "refMeanMs": 0.0811 - }, - "verified": true - }, - { - "workload": "bnpoLossBenchmark.bnpo_loss_batch16_seqlen02781_compiled", - "timingResults": { - "mean_ms": 0.0355, - "std_ms": 0.004, - "min_ms": 0.0319, - "max_ms": 0.0591, - "q1_ms": 0.0338, - "q3_ms": 0.0352, - "iqr_ms": 0.0014, - "outliers": 24, - 
"iterations": 200, - "refMeanMs": 0.0709 - }, - "verified": true - }, - { - "workload": "bnpoLossBenchmark.bnpo_loss_batch32_seqlen02048_compiled", - "timingResults": { - "mean_ms": 0.0358, - "std_ms": 0.0042, - "min_ms": 0.032, - "max_ms": 0.0569, - "q1_ms": 0.0338, - "q3_ms": 0.0355, - "iqr_ms": 0.0017, - "outliers": 27, - "iterations": 200, - "refMeanMs": 0.0763 - }, - "verified": true - }, - { - "workload": "bnpoLossBenchmark.bnpo_loss_batch64_seqlen04096_compiled", - "timingResults": { - "mean_ms": 0.0344, - "std_ms": 0.0031, - "min_ms": 0.032, - "max_ms": 0.0557, - "q1_ms": 0.0331, - "q3_ms": 0.0341, - "iqr_ms": 0.001, - "outliers": 32, - "iterations": 200, - "refMeanMs": 0.0739 - }, - "verified": true - }, - { - "workload": "bnpoLossBenchmark.bnpo_loss_fwd_batch128_seqlen02781_compiled", - "timingResults": { - "mean_ms": 0.0323, - "std_ms": 0.0034, - "min_ms": 0.03, - "max_ms": 0.053, - "q1_ms": 0.0311, - "q3_ms": 0.0318, - "iqr_ms": 0.0007, - "outliers": 25, - "iterations": 200, - "refMeanMs": 0.0808 - }, - "verified": true - }, - { - "workload": "bnpoLossBenchmark.bnpo_loss_fwd_batch128_seqlen08192_compiled", - "timingResults": { - "mean_ms": 0.0318, - "std_ms": 0.0032, - "min_ms": 0.0293, - "max_ms": 0.0502, - "q1_ms": 0.0304, - "q3_ms": 0.0317, - "iqr_ms": 0.0013, - "outliers": 17, - "iterations": 200, - "refMeanMs": 0.0845 - }, - "verified": true - }, - { - "workload": "bnpoLossBenchmark.bnpo_loss_fwd_batch16_seqlen01024_compiled", - "timingResults": { - "mean_ms": 0.0317, - "std_ms": 0.0031, - "min_ms": 0.0293, - "max_ms": 0.0593, - "q1_ms": 0.0304, - "q3_ms": 0.0317, - "iqr_ms": 0.0013, - "outliers": 17, - "iterations": 200, - "refMeanMs": 0.079 - }, - "verified": true - }, - { - "workload": "bnpoLossBenchmark.bnpo_loss_fwd_batch16_seqlen02781_compiled", - "timingResults": { - "mean_ms": 0.0306, - "std_ms": 0.0035, - "min_ms": 0.0279, - "max_ms": 0.0534, - "q1_ms": 0.0289, - "q3_ms": 0.0306, - "iqr_ms": 0.0017, - "outliers": 20, - "iterations": 200, - 
"refMeanMs": 0.084 - }, - "verified": true - }, - { - "workload": "bnpoLossBenchmark.bnpo_loss_fwd_batch32_seqlen02048_compiled", - "timingResults": { - "mean_ms": 0.0305, - "std_ms": 0.0035, - "min_ms": 0.0279, - "max_ms": 0.051, - "q1_ms": 0.0288, - "q3_ms": 0.0308, - "iqr_ms": 0.002, - "outliers": 15, - "iterations": 200, - "refMeanMs": 0.0764 - }, - "verified": true - }, - { - "workload": "bnpoLossBenchmark.bnpo_loss_fwd_batch64_seqlen04096_compiled", - "timingResults": { - "mean_ms": 0.0315, - "std_ms": 0.0033, - "min_ms": 0.0293, - "max_ms": 0.0543, - "q1_ms": 0.0302, - "q3_ms": 0.0311, - "iqr_ms": 0.0009, - "outliers": 21, - "iterations": 200, - "refMeanMs": 0.0739 - }, - "verified": true - } - ], - "machineInfo": { - "gpu": "NVIDIA H100 80GB HBM3", - "backend": "CUDA 13.0", - "pytorchVersion": "2.11.0+cu130", - "os": "Linux 6.11.0-1016-nvidia", - "cpu": "x86_64" - }, - "kernelCommitSha": "7972ab0e834be24d", - "benchmarkScriptPath": "benchmarks", - "benchmarkScriptSha": "68426064f76adff2066ad365f6c97be3fe279bd6b20d025b3dc5614f9b2da449" -} \ No newline at end of file diff --git a/benchmark_results/bnpo_loss_eager/bnpo_loss_eager_dark_animation.svg b/benchmark_results/bnpo_loss_eager/bnpo_loss_eager_dark_animation.svg deleted file mode 100644 index f3ddd7fa226b6778f3de0b9a950640939c174d9d..0000000000000000000000000000000000000000 --- a/benchmark_results/bnpo_loss_eager/bnpo_loss_eager_dark_animation.svg +++ /dev/null @@ -1,123 +0,0 @@ - -./ vs Torch - Relative Speed -PyTorch 2.11.0+cu130 · CUDA 13.0 - -bnpoLossBenchmark.bnpo_loss_batch128_seqlen02781_eager -15.51x - - - - - - - -bnpoLossBenchmark.bnpo_loss_batch128_seqlen08192_eager -18.80x - - - - - - - -bnpoLossBenchmark.bnpo_loss_batch16_seqlen01024_eager -17.01x - - - - - - - -bnpoLossBenchmark.bnpo_loss_batch16_seqlen02781_eager -18.00x - - - - - - - -bnpoLossBenchmark.bnpo_loss_batch32_seqlen02048_eager -18.27x - - - - - - - -bnpoLossBenchmark.bnpo_loss_batch64_seqlen04096_eager -17.37x - - - - - - - 
-bnpoLossBenchmark.bnpo_loss_fwd_batch128_seqlen02781_eager -5.86x - - - - - - - -bnpoLossBenchmark.bnpo_loss_fwd_batch128_seqlen08192_eager -5.59x - - - - - - - -bnpoLossBenchmark.bnpo_loss_fwd_batch16_seqlen01024_eager -5.66x - - - - - - - -bnpoLossBenchmark.bnpo_loss_fwd_batch16_seqlen02781_eager -6.09x - - - - - - - -bnpoLossBenchmark.bnpo_loss_fwd_batch32_seqlen02048_eager -4.95x - - - - - - - -bnpoLossBenchmark.bnpo_loss_fwd_batch64_seqlen04096_eager -5.41x - - - - - - - -Kernel - -Torch (ref) - - - - - - - - \ No newline at end of file diff --git a/benchmark_results/bnpo_loss_eager/bnpo_loss_eager_dark_latency.svg b/benchmark_results/bnpo_loss_eager/bnpo_loss_eager_dark_latency.svg deleted file mode 100644 index 7cd868f3e038b5676054610c4d4aadc406606b9b..0000000000000000000000000000000000000000 --- a/benchmark_results/bnpo_loss_eager/bnpo_loss_eager_dark_latency.svg +++ /dev/null @@ -1,3372 +0,0 @@ - - - - - - - - 2026-05-08T15:08:24.649226 - image/svg+xml - - - Matplotlib v3.10.9, https://matplotlib.org/ - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - 
diff --git a/benchmark_results/bnpo_loss_eager/bnpo_loss_eager_dark_throughput.svg b/benchmark_results/bnpo_loss_eager/bnpo_loss_eager_dark_throughput.svg
deleted file mode 100644
index 59eaef9197015aad63aecb1194c37a3a89c2e8f1..0000000000000000000000000000000000000000
--- a/benchmark_results/bnpo_loss_eager/bnpo_loss_eager_dark_throughput.svg
+++ /dev/null
@@ -1,3629 +0,0 @@
-[deleted SVG throughput plot; metadata: 2026-05-08T15:08:24.930877, image/svg+xml, Matplotlib v3.10.9, https://matplotlib.org/; markup not recoverable]
- - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - diff --git a/benchmark_results/bnpo_loss_eager/bnpo_loss_eager_light_animation.svg b/benchmark_results/bnpo_loss_eager/bnpo_loss_eager_light_animation.svg deleted file mode 100644 index e00ef7c0051c8f193fb202433d43432aa4bf2a53..0000000000000000000000000000000000000000 --- a/benchmark_results/bnpo_loss_eager/bnpo_loss_eager_light_animation.svg +++ /dev/null @@ -1,123 +0,0 @@ - -./ vs Torch - Relative Speed -PyTorch 2.11.0+cu130 · CUDA 13.0 - -bnpoLossBenchmark.bnpo_loss_batch128_seqlen02781_eager -15.51x - - - - - - - -bnpoLossBenchmark.bnpo_loss_batch128_seqlen08192_eager -18.80x - - - - - - - -bnpoLossBenchmark.bnpo_loss_batch16_seqlen01024_eager -17.01x - - - - - - - -bnpoLossBenchmark.bnpo_loss_batch16_seqlen02781_eager -18.00x - - - - - - - -bnpoLossBenchmark.bnpo_loss_batch32_seqlen02048_eager -18.27x - - - - - - - -bnpoLossBenchmark.bnpo_loss_batch64_seqlen04096_eager -17.37x - - - - - - - -bnpoLossBenchmark.bnpo_loss_fwd_batch128_seqlen02781_eager -5.86x - - - - - - - -bnpoLossBenchmark.bnpo_loss_fwd_batch128_seqlen08192_eager -5.59x - - - - - - - -bnpoLossBenchmark.bnpo_loss_fwd_batch16_seqlen01024_eager -5.66x - - - - - - - -bnpoLossBenchmark.bnpo_loss_fwd_batch16_seqlen02781_eager -6.09x - - - - - - - -bnpoLossBenchmark.bnpo_loss_fwd_batch32_seqlen02048_eager -4.95x - - - - - - - -bnpoLossBenchmark.bnpo_loss_fwd_batch64_seqlen04096_eager -5.41x - - - - - - - -Kernel - -Torch (ref) - - - - - - - - \ No newline at end of file diff 
--git a/benchmark_results/bnpo_loss_eager/bnpo_loss_eager_light_latency.svg b/benchmark_results/bnpo_loss_eager/bnpo_loss_eager_light_latency.svg deleted file mode 100644 index 1b377421327ac6e7899e4c7369a99727b26ededc..0000000000000000000000000000000000000000 --- a/benchmark_results/bnpo_loss_eager/bnpo_loss_eager_light_latency.svg +++ /dev/null @@ -1,3372 +0,0 @@ - - - - - - - - 2026-05-08T15:08:23.918541 - image/svg+xml - - - Matplotlib v3.10.9, https://matplotlib.org/ - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - 
diff --git a/benchmark_results/bnpo_loss_eager/bnpo_loss_eager_light_throughput.svg
b/benchmark_results/bnpo_loss_eager/bnpo_loss_eager_light_throughput.svg deleted file mode 100644 index 996be9aaa9e30148c59769acd042616af4e3b120..0000000000000000000000000000000000000000 --- a/benchmark_results/bnpo_loss_eager/bnpo_loss_eager_light_throughput.svg +++ /dev/null @@ -1,3629 +0,0 @@ - - - - - - - - 2026-05-08T15:08:24.214872 - image/svg+xml - - - Matplotlib v3.10.9, https://matplotlib.org/ - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - 
- - - - - - - - - - - - - - - - - - - - diff --git a/benchmark_results/bnpo_loss_eager/results.json b/benchmark_results/bnpo_loss_eager/results.json deleted file mode 100644 index 80426aac977d8c901b8c1f954ac5fb1de39f7ef0..0000000000000000000000000000000000000000 --- a/benchmark_results/bnpo_loss_eager/results.json +++ /dev/null @@ -1,206 +0,0 @@ -{ - "results": [ - { - "workload": "bnpoLossBenchmark.bnpo_loss_batch128_seqlen02781_eager", - "timingResults": { - "mean_ms": 0.0358, - "std_ms": 0.0035, - "min_ms": 0.0323, - "max_ms": 0.0536, - "q1_ms": 0.0342, - "q3_ms": 0.0358, - "iqr_ms": 0.0017, - "outliers": 17, - "iterations": 200, - "refMeanMs": 0.5552 - }, - "verified": true - }, - { - "workload": "bnpoLossBenchmark.bnpo_loss_batch128_seqlen08192_eager", - "timingResults": { - "mean_ms": 0.0344, - "std_ms": 0.0031, - "min_ms": 0.0314, - "max_ms": 0.0537, - "q1_ms": 0.0329, - "q3_ms": 0.0345, - "iqr_ms": 0.0015, - "outliers": 20, - "iterations": 200, - "refMeanMs": 0.6466 - }, - "verified": true - }, - { - "workload": "bnpoLossBenchmark.bnpo_loss_batch16_seqlen01024_eager", - "timingResults": { - "mean_ms": 0.0345, - "std_ms": 0.0171, - "min_ms": 0.0305, - "max_ms": 0.2718, - "q1_ms": 0.0319, - "q3_ms": 0.033, - "iqr_ms": 0.0011, - "outliers": 23, - "iterations": 200, - "refMeanMs": 0.5868 - }, - "verified": true - }, - { - "workload": "bnpoLossBenchmark.bnpo_loss_batch16_seqlen02781_eager", - "timingResults": { - "mean_ms": 0.0324, - "std_ms": 0.0027, - "min_ms": 0.0301, - "max_ms": 0.0508, - "q1_ms": 0.0312, - "q3_ms": 0.0324, - "iqr_ms": 0.0012, - "outliers": 17, - "iterations": 200, - "refMeanMs": 0.5832 - }, - "verified": true - }, - { - "workload": "bnpoLossBenchmark.bnpo_loss_batch32_seqlen02048_eager", - "timingResults": { - "mean_ms": 0.0343, - "std_ms": 0.0033, - "min_ms": 0.031, - "max_ms": 0.0513, - "q1_ms": 0.0325, - "q3_ms": 0.0346, - "iqr_ms": 0.0021, - "outliers": 19, - "iterations": 200, - "refMeanMs": 0.6265 - }, - "verified": true - }, - { - 
"workload": "bnpoLossBenchmark.bnpo_loss_batch64_seqlen04096_eager", - "timingResults": { - "mean_ms": 0.0328, - "std_ms": 0.0029, - "min_ms": 0.0306, - "max_ms": 0.0499, - "q1_ms": 0.0317, - "q3_ms": 0.0326, - "iqr_ms": 0.0009, - "outliers": 20, - "iterations": 200, - "refMeanMs": 0.5698 - }, - "verified": true - }, - { - "workload": "bnpoLossBenchmark.bnpo_loss_fwd_batch128_seqlen02781_eager", - "timingResults": { - "mean_ms": 0.0317, - "std_ms": 0.0034, - "min_ms": 0.0285, - "max_ms": 0.052, - "q1_ms": 0.0305, - "q3_ms": 0.0314, - "iqr_ms": 0.0009, - "outliers": 22, - "iterations": 200, - "refMeanMs": 0.1858 - }, - "verified": true - }, - { - "workload": "bnpoLossBenchmark.bnpo_loss_fwd_batch128_seqlen08192_eager", - "timingResults": { - "mean_ms": 0.0292, - "std_ms": 0.0028, - "min_ms": 0.0273, - "max_ms": 0.0455, - "q1_ms": 0.0281, - "q3_ms": 0.0289, - "iqr_ms": 0.0008, - "outliers": 23, - "iterations": 200, - "refMeanMs": 0.1633 - }, - "verified": true - }, - { - "workload": "bnpoLossBenchmark.bnpo_loss_fwd_batch16_seqlen01024_eager", - "timingResults": { - "mean_ms": 0.0311, - "std_ms": 0.0267, - "min_ms": 0.0256, - "max_ms": 0.4049, - "q1_ms": 0.0276, - "q3_ms": 0.0295, - "iqr_ms": 0.0018, - "outliers": 18, - "iterations": 200, - "refMeanMs": 0.1761 - }, - "verified": true - }, - { - "workload": "bnpoLossBenchmark.bnpo_loss_fwd_batch16_seqlen02781_eager", - "timingResults": { - "mean_ms": 0.0288, - "std_ms": 0.003, - "min_ms": 0.027, - "max_ms": 0.0554, - "q1_ms": 0.0278, - "q3_ms": 0.0284, - "iqr_ms": 0.0006, - "outliers": 22, - "iterations": 200, - "refMeanMs": 0.1755 - }, - "verified": true - }, - { - "workload": "bnpoLossBenchmark.bnpo_loss_fwd_batch32_seqlen02048_eager", - "timingResults": { - "mean_ms": 0.031, - "std_ms": 0.0034, - "min_ms": 0.0281, - "max_ms": 0.0484, - "q1_ms": 0.0296, - "q3_ms": 0.0306, - "iqr_ms": 0.0009, - "outliers": 27, - "iterations": 200, - "refMeanMs": 0.1533 - }, - "verified": true - }, - { - "workload": 
"bnpoLossBenchmark.bnpo_loss_fwd_batch64_seqlen04096_eager", - "timingResults": { - "mean_ms": 0.031, - "std_ms": 0.0041, - "min_ms": 0.0286, - "max_ms": 0.0625, - "q1_ms": 0.0294, - "q3_ms": 0.0305, - "iqr_ms": 0.0011, - "outliers": 22, - "iterations": 200, - "refMeanMs": 0.1678 - }, - "verified": true - } - ], - "machineInfo": { - "gpu": "NVIDIA H100 80GB HBM3", - "backend": "CUDA 13.0", - "pytorchVersion": "2.11.0+cu130", - "os": "Linux 6.11.0-1016-nvidia", - "cpu": "x86_64" - }, - "kernelCommitSha": "84e79b2f3ee3088a", - "benchmarkScriptPath": "benchmarks", - "benchmarkScriptSha": "68426064f76adff2066ad365f6c97be3fe279bd6b20d025b3dc5614f9b2da449" -} \ No newline at end of file diff --git a/benchmark_results/grpo_loss_compiled/grpo_loss_compiled_dark_animation.svg b/benchmark_results/grpo_loss_compiled/grpo_loss_compiled_dark_animation.svg deleted file mode 100644 index 845ad6df6f97393888222fffcafd36381c88a3a6..0000000000000000000000000000000000000000 --- a/benchmark_results/grpo_loss_compiled/grpo_loss_compiled_dark_animation.svg +++ /dev/null @@ -1,105 +0,0 @@ - -./ vs Torch - Relative Speed -PyTorch 2.11.0+cu130 · CUDA 13.0 - -GrpoLossBenchmark.grpo_loss_batch128_seqlen02781_compiled -2.66x - - - - - - - -GrpoLossBenchmark.grpo_loss_batch128_seqlen08192_compiled -2.45x - - - - - - - -GrpoLossBenchmark.grpo_loss_batch16_seqlen01024_compiled -1.94x - - - - - - - -GrpoLossBenchmark.grpo_loss_batch32_seqlen02048_compiled -1.95x - - - - - - - -GrpoLossBenchmark.grpo_loss_batch64_seqlen04096_compiled -2.49x - - - - - - - -GrpoLossBenchmark.grpo_loss_fwd_batch128_seqlen02781_compiled -2.34x - - - - - - - -GrpoLossBenchmark.grpo_loss_fwd_batch128_seqlen08192_compiled -2.32x - - - - - - - -GrpoLossBenchmark.grpo_loss_fwd_batch16_seqlen01024_compiled -2.56x - - - - - - - -GrpoLossBenchmark.grpo_loss_fwd_batch32_seqlen02048_compiled -3.10x - - - - - - - -GrpoLossBenchmark.grpo_loss_fwd_batch64_seqlen04096_compiled -2.34x - - - - - - - -Kernel - -Torch (ref) - - - - - - 
- - \ No newline at end of file diff --git a/benchmark_results/grpo_loss_compiled/grpo_loss_compiled_dark_latency.svg b/benchmark_results/grpo_loss_compiled/grpo_loss_compiled_dark_latency.svg deleted file mode 100644 index e2e169c032fc182c38bd0e62b3616c90dc9b1d6a..0000000000000000000000000000000000000000 --- a/benchmark_results/grpo_loss_compiled/grpo_loss_compiled_dark_latency.svg +++ /dev/null @@ -1,3216 +0,0 @@ - - - - - - - - 2026-05-08T15:09:18.021986 - image/svg+xml - - - Matplotlib v3.10.9, https://matplotlib.org/ - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - 
diff --git a/benchmark_results/grpo_loss_compiled/grpo_loss_compiled_dark_throughput.svg b/benchmark_results/grpo_loss_compiled/grpo_loss_compiled_dark_throughput.svg
deleted file mode 100644
index e9c68f1b3396727cd26e2fbb4a49227006ae1a99..0000000000000000000000000000000000000000
--- a/benchmark_results/grpo_loss_compiled/grpo_loss_compiled_dark_throughput.svg
+++ /dev/null
@@ -1,3437 +0,0 @@ - - - - - - - - 2026-05-08T15:09:18.292420 - image/svg+xml - - - Matplotlib v3.10.9, https://matplotlib.org/ - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - 
diff --git a/benchmark_results/grpo_loss_compiled/grpo_loss_compiled_light_animation.svg b/benchmark_results/grpo_loss_compiled/grpo_loss_compiled_light_animation.svg
deleted file mode 100644
index eb3569b7f28e8ed04655eae5056a4b3ba0f8c062..0000000000000000000000000000000000000000
--- a/benchmark_results/grpo_loss_compiled/grpo_loss_compiled_light_animation.svg
+++ /dev/null
@@ -1,105 +0,0 @@
-./ vs Torch - Relative Speed
-PyTorch 2.11.0+cu130 · CUDA 13.0
-GrpoLossBenchmark.grpo_loss_batch128_seqlen02781_compiled
-2.66x
-GrpoLossBenchmark.grpo_loss_batch128_seqlen08192_compiled
-2.45x
-GrpoLossBenchmark.grpo_loss_batch16_seqlen01024_compiled
-1.94x
-GrpoLossBenchmark.grpo_loss_batch32_seqlen02048_compiled
-1.95x
-GrpoLossBenchmark.grpo_loss_batch64_seqlen04096_compiled
-2.49x
-GrpoLossBenchmark.grpo_loss_fwd_batch128_seqlen02781_compiled
-2.34x
-GrpoLossBenchmark.grpo_loss_fwd_batch128_seqlen08192_compiled
-2.32x
-GrpoLossBenchmark.grpo_loss_fwd_batch16_seqlen01024_compiled
-2.56x
-GrpoLossBenchmark.grpo_loss_fwd_batch32_seqlen02048_compiled
-3.10x
-GrpoLossBenchmark.grpo_loss_fwd_batch64_seqlen04096_compiled
-2.34x
-Kernel
-Torch (ref)
\ No newline at end of file
diff --git a/benchmark_results/grpo_loss_compiled/grpo_loss_compiled_light_latency.svg b/benchmark_results/grpo_loss_compiled/grpo_loss_compiled_light_latency.svg
deleted file mode 100644
index c2cac9ba02c975ebe08ec5c5adf8b330e753853f..0000000000000000000000000000000000000000
--- a/benchmark_results/grpo_loss_compiled/grpo_loss_compiled_light_latency.svg
+++ /dev/null
@@ -1,3216 +0,0 @@
-2026-05-08T15:09:17.298491
-image/svg+xml
-Matplotlib v3.10.9, https://matplotlib.org/
[SVG markup and path data not preserved]
diff --git a/benchmark_results/grpo_loss_compiled/grpo_loss_compiled_light_throughput.svg b/benchmark_results/grpo_loss_compiled/grpo_loss_compiled_light_throughput.svg
deleted file mode 100644
index 91a9951c6d9a2c76106fb8c868d198c16003a4f9..0000000000000000000000000000000000000000
--- a/benchmark_results/grpo_loss_compiled/grpo_loss_compiled_light_throughput.svg
+++ /dev/null
@@ -1,3437 +0,0 @@
-2026-05-08T15:09:17.593907
-image/svg+xml
-Matplotlib v3.10.9, https://matplotlib.org/
[SVG markup and path data not preserved]
diff --git a/benchmark_results/grpo_loss_compiled/results.json b/benchmark_results/grpo_loss_compiled/results.json
deleted file mode 100644
index ba1fbfe1bf9be219c4963f3f9018e2d63481dac0..0000000000000000000000000000000000000000
--- a/benchmark_results/grpo_loss_compiled/results.json
+++ /dev/null
@@ -1,174 +0,0 @@
-{
-  "results": [
-    {
-      "workload": "GrpoLossBenchmark.grpo_loss_batch128_seqlen02781_compiled",
-      "timingResults": {
-        "mean_ms": 0.0329,
-        "std_ms": 0.0042,
-        "min_ms": 0.0301,
-        "max_ms": 0.0632,
-        "q1_ms": 0.031,
-        "q3_ms": 0.0326,
-        "iqr_ms": 0.0016,
-        "outliers": 22,
-        "iterations": 200,
-        "refMeanMs": 0.0874
-      },
-      "verified": true
-    },
-    {
-      "workload": "GrpoLossBenchmark.grpo_loss_batch128_seqlen08192_compiled",
-      "timingResults": {
-        "mean_ms": 0.0337,
-        "std_ms": 0.0045,
-        "min_ms": 0.0305,
-        "max_ms": 0.065,
-        "q1_ms": 0.0318,
-        "q3_ms": 0.0333,
-        "iqr_ms": 0.0015,
-        "outliers": 23,
-        "iterations": 200,
-        "refMeanMs": 0.0824
-      },
-      "verified": true
-    },
-    {
-      "workload": "GrpoLossBenchmark.grpo_loss_batch16_seqlen01024_compiled",
-      "timingResults": {
-        "mean_ms": 0.0323,
-        "std_ms": 0.0045,
-        "min_ms": 0.0286,
-        "max_ms": 0.0621,
-        "q1_ms": 0.0306,
-        "q3_ms": 0.0321,
-        "iqr_ms": 0.0015,
-        "outliers": 24,
-        "iterations": 200,
-        "refMeanMs": 0.0626
-      },
-      "verified": true
-    },
-    {
-      "workload": "GrpoLossBenchmark.grpo_loss_batch32_seqlen02048_compiled",
-      "timingResults": {
-        "mean_ms": 0.0324,
-        "std_ms": 0.0046,
-        "min_ms": 0.0286,
-        "max_ms": 0.0688,
-        "q1_ms": 0.0305,
-        "q3_ms": 0.0321,
-        "iqr_ms": 0.0016,
-        "outliers": 22,
-        "iterations": 200,
-        "refMeanMs": 0.0633
-      },
-      "verified": true
-    },
-    {
-      "workload": "GrpoLossBenchmark.grpo_loss_batch64_seqlen04096_compiled",
-      "timingResults": {
-        "mean_ms": 0.0349,
-        "std_ms": 0.0058,
-        "min_ms": 0.0315,
-        "max_ms": 0.0814,
-        "q1_ms": 0.0325,
-        "q3_ms": 0.0341,
-        "iqr_ms": 0.0016,
-        "outliers": 26,
-        "iterations": 200,
-        "refMeanMs": 0.0869
-      },
-      "verified": true
-    },
-    {
-      "workload": "GrpoLossBenchmark.grpo_loss_fwd_batch128_seqlen02781_compiled",
-      "timingResults": {
-        "mean_ms": 0.033,
-        "std_ms": 0.0038,
-        "min_ms": 0.0295,
-        "max_ms": 0.0543,
-        "q1_ms": 0.0313,
-        "q3_ms": 0.0333,
-        "iqr_ms": 0.0019,
-        "outliers": 16,
-        "iterations": 200,
-        "refMeanMs": 0.0772
-      },
-      "verified": true
-    },
-    {
-      "workload": "GrpoLossBenchmark.grpo_loss_fwd_batch128_seqlen08192_compiled",
-      "timingResults": {
-        "mean_ms": 0.0331,
-        "std_ms": 0.0032,
-        "min_ms": 0.0295,
-        "max_ms": 0.0535,
-        "q1_ms": 0.0316,
-        "q3_ms": 0.0331,
-        "iqr_ms": 0.0015,
-        "outliers": 19,
-        "iterations": 200,
-        "refMeanMs": 0.0767
-      },
-      "verified": true
-    },
-    {
-      "workload": "GrpoLossBenchmark.grpo_loss_fwd_batch16_seqlen01024_compiled",
-      "timingResults": {
-        "mean_ms": 0.033,
-        "std_ms": 0.0032,
-        "min_ms": 0.029,
-        "max_ms": 0.051,
-        "q1_ms": 0.0315,
-        "q3_ms": 0.0332,
-        "iqr_ms": 0.0016,
-        "outliers": 17,
-        "iterations": 200,
-        "refMeanMs": 0.0845
-      },
-      "verified": true
-    },
-    {
-      "workload": "GrpoLossBenchmark.grpo_loss_fwd_batch32_seqlen02048_compiled",
-      "timingResults": {
-        "mean_ms": 0.0339,
-        "std_ms": 0.006,
-        "min_ms": 0.03,
-        "max_ms": 0.0674,
-        "q1_ms": 0.0314,
-        "q3_ms": 0.0331,
-        "iqr_ms": 0.0017,
-        "outliers": 23,
-        "iterations": 200,
-        "refMeanMs": 0.1052
-      },
-      "verified": true
-    },
-    {
-      "workload": "GrpoLossBenchmark.grpo_loss_fwd_batch64_seqlen04096_compiled",
-      "timingResults": {
-        "mean_ms": 0.034,
-        "std_ms": 0.004,
-        "min_ms": 0.031,
-        "max_ms": 0.0623,
-        "q1_ms": 0.0323,
-        "q3_ms": 0.0339,
-        "iqr_ms": 0.0016,
-        "outliers": 20,
-        "iterations": 200,
-        "refMeanMs": 0.0796
-      },
-      "verified": true
-    }
-  ],
-  "machineInfo": {
-    "gpu": "NVIDIA H100 80GB HBM3",
-    "backend": "CUDA 13.0",
-    "pytorchVersion": "2.11.0+cu130",
-    "os": "Linux 6.11.0-1016-nvidia",
-    "cpu": "x86_64"
-  },
-  "kernelCommitSha": "ad285d68b8c8c0ff",
-  "benchmarkScriptPath": "benchmarks",
-  "benchmarkScriptSha": "ff35d63fbca37cfcbf5c94f067c930adc2bd0043ce6788f286dbad5a4f9b9d4a"
-}
\ No newline at end of file
diff --git a/benchmark_results/grpo_loss_eager/grpo_loss_eager_dark_animation.svg b/benchmark_results/grpo_loss_eager/grpo_loss_eager_dark_animation.svg
deleted file mode 100644
index a77b4f3436b46390be2bee024aff963184dd2d5a..0000000000000000000000000000000000000000
--- a/benchmark_results/grpo_loss_eager/grpo_loss_eager_dark_animation.svg
+++ /dev/null
@@ -1,105 +0,0 @@
-./ vs Torch - Relative Speed
-PyTorch 2.11.0+cu130 · CUDA 13.0
-GrpoLossBenchmark.grpo_loss_batch128_seqlen02781_eager
-21.22x
-GrpoLossBenchmark.grpo_loss_batch128_seqlen08192_eager
-19.29x
-GrpoLossBenchmark.grpo_loss_batch16_seqlen01024_eager
-19.47x
-GrpoLossBenchmark.grpo_loss_batch32_seqlen02048_eager
-20.01x
-GrpoLossBenchmark.grpo_loss_batch64_seqlen04096_eager
-19.65x
-GrpoLossBenchmark.grpo_loss_fwd_batch128_seqlen02781_eager
-5.61x
-GrpoLossBenchmark.grpo_loss_fwd_batch128_seqlen08192_eager
-5.63x
-GrpoLossBenchmark.grpo_loss_fwd_batch16_seqlen01024_eager
-5.69x
-GrpoLossBenchmark.grpo_loss_fwd_batch32_seqlen02048_eager
-5.51x
-GrpoLossBenchmark.grpo_loss_fwd_batch64_seqlen04096_eager
-5.58x
-Kernel
-Torch (ref)
\ No newline at end of file
diff --git a/benchmark_results/grpo_loss_eager/grpo_loss_eager_dark_latency.svg b/benchmark_results/grpo_loss_eager/grpo_loss_eager_dark_latency.svg
deleted file mode 100644
index 7643747cde724df3556c8a5dda0b820bf97e3238..0000000000000000000000000000000000000000
--- a/benchmark_results/grpo_loss_eager/grpo_loss_eager_dark_latency.svg
+++ /dev/null
@@ -1,3128 +0,0 @@
-2026-05-08T15:08:53.746514
-image/svg+xml
-Matplotlib v3.10.9, https://matplotlib.org/
[SVG markup and path data not preserved]
diff --git a/benchmark_results/grpo_loss_eager/grpo_loss_eager_dark_throughput.svg b/benchmark_results/grpo_loss_eager/grpo_loss_eager_dark_throughput.svg
deleted file mode 100644
index 3be72f7a0083379128f957fa2fb1d26bad9f7738..0000000000000000000000000000000000000000
--- a/benchmark_results/grpo_loss_eager/grpo_loss_eager_dark_throughput.svg
+++ /dev/null
@@ -1,3405 +0,0 @@
-2026-05-08T15:08:54.010839
-image/svg+xml
-Matplotlib v3.10.9, https://matplotlib.org/
[SVG markup and path data not preserved]
diff --git a/benchmark_results/grpo_loss_eager/grpo_loss_eager_light_animation.svg b/benchmark_results/grpo_loss_eager/grpo_loss_eager_light_animation.svg
deleted file mode 100644
index d5b47b4483e8c7c41ae8b16dbc19016c0f9760e6..0000000000000000000000000000000000000000
--- a/benchmark_results/grpo_loss_eager/grpo_loss_eager_light_animation.svg
+++ /dev/null
@@ -1,105 +0,0 @@
-./ vs Torch - Relative Speed
-PyTorch 2.11.0+cu130 · CUDA 13.0
-GrpoLossBenchmark.grpo_loss_batch128_seqlen02781_eager
-21.22x
-GrpoLossBenchmark.grpo_loss_batch128_seqlen08192_eager
-19.29x
-GrpoLossBenchmark.grpo_loss_batch16_seqlen01024_eager
-19.47x
-GrpoLossBenchmark.grpo_loss_batch32_seqlen02048_eager
-20.01x
-GrpoLossBenchmark.grpo_loss_batch64_seqlen04096_eager
-19.65x
-GrpoLossBenchmark.grpo_loss_fwd_batch128_seqlen02781_eager
-5.61x
-GrpoLossBenchmark.grpo_loss_fwd_batch128_seqlen08192_eager
-5.63x
-GrpoLossBenchmark.grpo_loss_fwd_batch16_seqlen01024_eager
-5.69x
-GrpoLossBenchmark.grpo_loss_fwd_batch32_seqlen02048_eager
-5.51x
-GrpoLossBenchmark.grpo_loss_fwd_batch64_seqlen04096_eager
-5.58x
-Kernel
-Torch (ref)
\ No newline at end of file
diff --git a/benchmark_results/grpo_loss_eager/grpo_loss_eager_light_latency.svg b/benchmark_results/grpo_loss_eager/grpo_loss_eager_light_latency.svg
deleted file mode 100644
index fc341b05bfcdd787fa9b9c2175fcb80604181dbf..0000000000000000000000000000000000000000
--- a/benchmark_results/grpo_loss_eager/grpo_loss_eager_light_latency.svg
+++ /dev/null
@@ -1,3128 +0,0 @@
-2026-05-08T15:08:53.053345
-image/svg+xml
-Matplotlib v3.10.9, https://matplotlib.org/
[SVG markup and path data not preserved]
diff --git a/benchmark_results/grpo_loss_eager/grpo_loss_eager_light_throughput.svg b/benchmark_results/grpo_loss_eager/grpo_loss_eager_light_throughput.svg
deleted file mode 100644
index 53d76e5dc1cc81bc52aa5cf6ac6176e79c6fa07a..0000000000000000000000000000000000000000
--- a/benchmark_results/grpo_loss_eager/grpo_loss_eager_light_throughput.svg
+++ /dev/null
@@ -1,3405 +0,0 @@
-2026-05-08T15:08:53.335016
-image/svg+xml
-Matplotlib v3.10.9, https://matplotlib.org/
[SVG markup and path data not preserved]
diff --git a/benchmark_results/grpo_loss_eager/results.json b/benchmark_results/grpo_loss_eager/results.json
deleted file mode 100644
index 093556c65e7da9ae4773a1ca31fdb143af7557ab..0000000000000000000000000000000000000000
--- a/benchmark_results/grpo_loss_eager/results.json
+++ /dev/null
@@ -1,174 +0,0 @@
-{
-  "results": [
-    {
-      "workload": "GrpoLossBenchmark.grpo_loss_batch128_seqlen02781_eager",
-      "timingResults": {
-        "mean_ms": 0.0313,
-        "std_ms": 0.0029,
-        "min_ms": 0.0281,
-        "max_ms": 0.0482,
-        "q1_ms": 0.03,
-        "q3_ms": 0.0314,
-        "iqr_ms": 0.0013,
-        "outliers": 16,
-        "iterations": 200,
-        "refMeanMs": 0.6643
-      },
-      "verified": true
-    },
-    {
-      "workload": "GrpoLossBenchmark.grpo_loss_batch128_seqlen08192_eager",
-      "timingResults": {
-        "mean_ms": 0.0309,
-        "std_ms": 0.0031,
-        "min_ms": 0.0285,
-        "max_ms": 0.0477,
-        "q1_ms": 0.0298,
-        "q3_ms": 0.0306,
-        "iqr_ms": 0.0008,
-        "outliers": 19,
-        "iterations": 200,
-        "refMeanMs": 0.5961
-      },
-      "verified": true
-    },
-    {
-      "workload": "GrpoLossBenchmark.grpo_loss_batch16_seqlen01024_eager",
-      "timingResults": {
-        "mean_ms": 0.0315,
-        "std_ms": 0.0033,
-        "min_ms": 0.0293,
-        "max_ms": 0.0507,
-        "q1_ms": 0.0302,
-        "q3_ms": 0.0311,
-        "iqr_ms": 0.0009,
-        "outliers": 23,
-        "iterations": 200,
-        "refMeanMs": 0.6132
-      },
-      "verified": true
-    },
-    {
-      "workload": "GrpoLossBenchmark.grpo_loss_batch32_seqlen02048_eager",
-      "timingResults": {
-        "mean_ms": 0.0302,
-        "std_ms": 0.0029,
-        "min_ms": 0.028,
-        "max_ms": 0.0467,
-        "q1_ms": 0.029,
-        "q3_ms": 0.0299,
-        "iqr_ms": 0.0008,
-        "outliers": 20,
-        "iterations": 200,
-        "refMeanMs": 0.6043
-      },
-      "verified": true
-    },
-    {
-      "workload": "GrpoLossBenchmark.grpo_loss_batch64_seqlen04096_eager",
-      "timingResults": {
-        "mean_ms": 0.0295,
-        "std_ms": 0.003,
-        "min_ms": 0.0268,
-        "max_ms": 0.0465,
-        "q1_ms": 0.0279,
-        "q3_ms": 0.03,
-        "iqr_ms": 0.002,
-        "outliers": 12,
-        "iterations": 200,
-        "refMeanMs": 0.5798
-      },
-      "verified": true
-    },
-    {
-      "workload": "GrpoLossBenchmark.grpo_loss_fwd_batch128_seqlen02781_eager",
-      "timingResults": {
-        "mean_ms": 0.0306,
-        "std_ms": 0.0032,
-        "min_ms": 0.0281,
-        "max_ms": 0.0513,
-        "q1_ms": 0.0293,
-        "q3_ms": 0.0302,
-        "iqr_ms": 0.0009,
-        "outliers": 24,
-        "iterations": 200,
-        "refMeanMs": 0.1716
-      },
-      "verified": true
-    },
-    {
-      "workload": "GrpoLossBenchmark.grpo_loss_fwd_batch128_seqlen08192_eager",
-      "timingResults": {
-        "mean_ms": 0.0302,
-        "std_ms": 0.0031,
-        "min_ms": 0.0284,
-        "max_ms": 0.0594,
-        "q1_ms": 0.0291,
-        "q3_ms": 0.0299,
-        "iqr_ms": 0.0008,
-        "outliers": 21,
-        "iterations": 200,
-        "refMeanMs": 0.1701
-      },
-      "verified": true
-    },
-    {
-      "workload": "GrpoLossBenchmark.grpo_loss_fwd_batch16_seqlen01024_eager",
-      "timingResults": {
-        "mean_ms": 0.0306,
-        "std_ms": 0.0027,
-        "min_ms": 0.0286,
-        "max_ms": 0.0455,
-        "q1_ms": 0.0294,
-        "q3_ms": 0.0304,
-        "iqr_ms": 0.001,
-        "outliers": 16,
-        "iterations": 200,
-        "refMeanMs": 0.1741
-      },
-      "verified": true
-    },
-    {
-      "workload": "GrpoLossBenchmark.grpo_loss_fwd_batch32_seqlen02048_eager",
-      "timingResults": {
-        "mean_ms": 0.0299,
-        "std_ms": 0.0029,
-        "min_ms": 0.0269,
-        "max_ms": 0.0488,
-        "q1_ms": 0.0287,
-        "q3_ms": 0.0301,
-        "iqr_ms": 0.0015,
-        "outliers": 14,
-        "iterations": 200,
-        "refMeanMs": 0.1647
-      },
-      "verified": true
-    },
-    {
-      "workload": "GrpoLossBenchmark.grpo_loss_fwd_batch64_seqlen04096_eager",
-      "timingResults": {
-        "mean_ms": 0.0314,
-        "std_ms": 0.0028,
-        "min_ms": 0.0289,
-        "max_ms": 0.0465,
-        "q1_ms": 0.0301,
-        "q3_ms": 0.0312,
-        "iqr_ms": 0.0011,
-        "outliers": 22,
-        "iterations": 200,
-        "refMeanMs": 0.1751
-      },
-      "verified": true
-    }
-  ],
-  "machineInfo": {
-    "gpu": "NVIDIA H100 80GB HBM3",
-    "backend": "CUDA 13.0",
-    "pytorchVersion": "2.11.0+cu130",
-    "os": "Linux 6.11.0-1016-nvidia",
-    "cpu": "x86_64"
-  },
-  "kernelCommitSha": "87ec9b61421d0121",
-  "benchmarkScriptPath": "benchmarks",
-  "benchmarkScriptSha": "ff35d63fbca37cfcbf5c94f067c930adc2bd0043ce6788f286dbad5a4f9b9d4a"
-}
\ No newline at end of file
diff --git a/benchmark_results/reverse_kl_compiled/results.json b/benchmark_results/reverse_kl_compiled/results.json
deleted file mode 100644
index 556cbea44f5220777dadd1de56ba32ac91cbf1ed..0000000000000000000000000000000000000000
--- a/benchmark_results/reverse_kl_compiled/results.json
+++ /dev/null
@@ -1,206 +0,0 @@
-{
-  "results": [
-    {
-      "workload": "ReverseKLBenchmark.reverse_kl_batch01_seqlen064_vocab248320_compiled",
-      "timingResults": {
-        "mean_ms": 0.1039,
-        "std_ms": 0.0035,
-        "min_ms": 0.1,
-        "max_ms": 0.1229,
-        "q1_ms": 0.1018,
-        "q3_ms": 0.104,
-        "iqr_ms": 0.0022,
-        "outliers": 28,
-        "iterations": 200,
-        "refMeanMs": 0.2322
-      },
-      "verified": true
-    },
-    {
-      "workload": "ReverseKLBenchmark.reverse_kl_batch02_seqlen128_vocab248320_compiled",
-      "timingResults": {
-        "mean_ms": 0.2483,
-        "std_ms": 0.0035,
-        "min_ms": 0.2418,
-        "max_ms": 0.2612,
-        "q1_ms": 0.2457,
-        "q3_ms": 0.2513,
-        "iqr_ms": 0.0057,
-        "outliers": 2,
-        "iterations": 200,
-        "refMeanMs": 0.6455
-      },
-      "verified": true
-    },
-    {
-      "workload": "ReverseKLBenchmark.reverse_kl_batch04_seqlen256_vocab248320_compiled",
-      "timingResults": {
-        "mean_ms": 0.8322,
-        "std_ms": 0.0044,
-        "min_ms": 0.8232,
-        "max_ms": 0.8623,
-        "q1_ms": 0.8303,
-        "q3_ms": 0.8335,
-        "iqr_ms": 0.0032,
-        "outliers": 18,
-        "iterations": 200,
-        "refMeanMs": 2.2082
-      },
-      "verified": true
-    },
-    {
-      "workload": "ReverseKLBenchmark.reverse_kl_batch08_seqlen1024_vocab248320_compiled",
-      "timingResults": {
-        "mean_ms": 6.1083,
-        "std_ms": 0.0054,
-        "min_ms": 6.097,
-        "max_ms": 6.1513,
-        "q1_ms": 6.1054,
-        "q3_ms": 6.11,
-        "iqr_ms": 0.0046,
-        "outliers": 13,
-        "iterations": 200,
-        "refMeanMs": 16.4779
-      },
-      "verified": true
-    },
-    {
-      "workload": "ReverseKLBenchmark.reverse_kl_batch08_seqlen512_vocab248320_compiled",
-      "timingResults": {
-        "mean_ms": 3.0861,
-        "std_ms": 0.0045,
-        "min_ms": 3.0769,
-        "max_ms": 3.1155,
-        "q1_ms": 3.0832,
-        "q3_ms": 3.0883,
-        "iqr_ms": 0.0051,
-        "outliers": 5,
-        "iterations": 200,
-        "refMeanMs": 8.3849
-      },
-      "verified": true
-    },
-    {
-      "workload": "ReverseKLBenchmark.reverse_kl_batch08_seqlen981_vocab248320_compiled",
-      "timingResults": {
-        "mean_ms": 5.8622,
-        "std_ms": 0.0044,
-        "min_ms": 5.8544,
-        "max_ms": 5.8821,
-        "q1_ms": 5.859,
-        "q3_ms": 5.8646,
-        "iqr_ms": 0.0056,
-        "outliers": 6,
-        "iterations": 200,
-        "refMeanMs": 15.8101
-      },
-      "verified": true
-    },
-    {
-      "workload": "ReverseKLBenchmark.reverse_kl_fwd_batch01_seqlen064_vocab248320_compiled",
-      "timingResults": {
-        "mean_ms": 0.0657,
-        "std_ms": 0.0041,
-        "min_ms": 0.0619,
-        "max_ms": 0.093,
-        "q1_ms": 0.0635,
-        "q3_ms": 0.0656,
-        "iqr_ms": 0.0021,
-        "outliers": 24,
-        "iterations": 200,
-        "refMeanMs": 0.1434
-      },
-      "verified": true
-    },
-    {
-      "workload": "ReverseKLBenchmark.reverse_kl_fwd_batch02_seqlen128_vocab248320_compiled",
-      "timingResults": {
-        "mean_ms": 0.1234,
-        "std_ms": 0.0041,
-        "min_ms": 0.1187,
-        "max_ms": 0.1464,
-        "q1_ms": 0.1208,
-        "q3_ms": 0.1244,
-        "iqr_ms": 0.0036,
-        "outliers": 16,
-        "iterations": 200,
-        "refMeanMs": 0.3277
-      },
-      "verified": true
-    },
-    {
-      "workload": "ReverseKLBenchmark.reverse_kl_fwd_batch04_seqlen256_vocab248320_compiled",
-      "timingResults": {
-        "mean_ms": 0.3764,
-        "std_ms": 0.0037,
-        "min_ms": 0.3699,
-        "max_ms": 0.3926,
-        "q1_ms": 0.3733,
-        "q3_ms": 0.3787,
-        "iqr_ms": 0.0054,
-        "outliers": 2,
-        "iterations": 200,
-        "refMeanMs": 0.9228
-      },
-      "verified": true
-    },
-    {
-      "workload": "ReverseKLBenchmark.reverse_kl_fwd_batch08_seqlen1024_vocab248320_compiled",
-      "timingResults": {
-        "mean_ms": 2.658,
-        "std_ms": 0.0089,
-        "min_ms": 2.6359,
-        "max_ms": 2.6859,
-        "q1_ms": 2.6524,
-        "q3_ms": 2.663,
-        "iqr_ms": 0.0106,
-        "outliers": 4,
-        "iterations": 200,
-        "refMeanMs": 6.6033
-      },
-      "verified": true
-    },
-    {
-      "workload": "ReverseKLBenchmark.reverse_kl_fwd_batch08_seqlen512_vocab248320_compiled",
-      "timingResults": {
-        "mean_ms": 1.38,
-        "std_ms": 0.0035,
-        "min_ms": 1.37,
-        "max_ms": 1.3924,
-        "q1_ms": 1.3776,
-        "q3_ms": 1.3818,
-        "iqr_ms": 0.0042,
-        "outliers": 6,
-        "iterations": 200,
-        "refMeanMs": 3.3854
-      },
-      "verified": true
-    },
-    {
-      "workload": "ReverseKLBenchmark.reverse_kl_fwd_batch08_seqlen981_vocab248320_compiled",
-      "timingResults": {
-        "mean_ms": 2.5422,
-        "std_ms": 0.0091,
-        "min_ms": 2.5286,
-        "max_ms": 2.5773,
-        "q1_ms": 2.5356,
-        "q3_ms": 2.5455,
-        "iqr_ms": 0.0099,
-        "outliers": 9,
-        "iterations": 200,
-        "refMeanMs": 6.2191
-      },
-      "verified": true
-    }
-  ],
-  "machineInfo": {
-    "gpu": "NVIDIA H100 80GB HBM3",
-    "backend": "CUDA 13.0",
-    "pytorchVersion": "2.11.0+cu130",
-    "os": "Linux 6.11.0-1016-nvidia",
-    "cpu": "x86_64"
-  },
-  "kernelCommitSha": "ca5cbc20b4d2c7d8",
-  "benchmarkScriptPath": "benchmarks",
-  "benchmarkScriptSha": "690eea1f54f31bef1ad248380201005fd667d4b9c535f92f06eb6a5a33380d22"
-}
\ No newline at end of file
diff --git a/benchmark_results/reverse_kl_compiled/reverse_kl_compiled_dark_animation.svg b/benchmark_results/reverse_kl_compiled/reverse_kl_compiled_dark_animation.svg
deleted file mode 100644
index
d65ac5ab4d492584ff4ac7015449d3cc639a522d..0000000000000000000000000000000000000000
--- a/benchmark_results/reverse_kl_compiled/reverse_kl_compiled_dark_animation.svg
+++ /dev/null
@@ -1,123 +0,0 @@
-./ vs Torch - Relative Speed
-PyTorch 2.11.0+cu130 · CUDA 13.0
-ReverseKLBenchmark.reverse_kl_batch01_seqlen064_vocab248320_compiled
-2.23x
-ReverseKLBenchmark.reverse_kl_batch02_seqlen128_vocab248320_compiled
-2.60x
-ReverseKLBenchmark.reverse_kl_batch04_seqlen256_vocab248320_compiled
-2.65x
-ReverseKLBenchmark.reverse_kl_batch08_seqlen1024_vocab248320_compiled
-2.70x
-ReverseKLBenchmark.reverse_kl_batch08_seqlen512_vocab248320_compiled
-2.72x
-ReverseKLBenchmark.reverse_kl_batch08_seqlen981_vocab248320_compiled
-2.70x
-ReverseKLBenchmark.reverse_kl_fwd_batch01_seqlen064_vocab248320_compiled
-2.18x
-ReverseKLBenchmark.reverse_kl_fwd_batch02_seqlen128_vocab248320_compiled
-2.66x
-ReverseKLBenchmark.reverse_kl_fwd_batch04_seqlen256_vocab248320_compiled
-2.45x
-ReverseKLBenchmark.reverse_kl_fwd_batch08_seqlen1024_vocab248320_compiled
-2.48x
-ReverseKLBenchmark.reverse_kl_fwd_batch08_seqlen512_vocab248320_compiled
-2.45x
-ReverseKLBenchmark.reverse_kl_fwd_batch08_seqlen981_vocab248320_compiled
-2.45x
-Kernel
-Torch (ref)
\ No newline at end of file
diff --git a/benchmark_results/reverse_kl_compiled/reverse_kl_compiled_dark_latency.svg b/benchmark_results/reverse_kl_compiled/reverse_kl_compiled_dark_latency.svg
deleted file mode 100644
index 34d7ae2b2290044ed31d918bdc732d377b9a8b6c..0000000000000000000000000000000000000000
--- a/benchmark_results/reverse_kl_compiled/reverse_kl_compiled_dark_latency.svg
+++ /dev/null
@@ -1,3569 +0,0 @@
-[SVG figure markup omitted; surviving metadata: created 2026-05-08T15:09:55.712221, image/svg+xml, Matplotlib v3.10.9, https://matplotlib.org/]
diff --git a/benchmark_results/reverse_kl_compiled/reverse_kl_compiled_dark_throughput.svg b/benchmark_results/reverse_kl_compiled/reverse_kl_compiled_dark_throughput.svg
deleted file mode 100644
index e376f192677f10a8813fc761b20c26b623169dd1..0000000000000000000000000000000000000000
---
a/benchmark_results/reverse_kl_compiled/reverse_kl_compiled_dark_throughput.svg
+++ /dev/null
@@ -1,3817 +0,0 @@
-[SVG figure markup omitted; surviving metadata: created 2026-05-08T15:09:55.996697, image/svg+xml, Matplotlib v3.10.9, https://matplotlib.org/]
diff --git a/benchmark_results/reverse_kl_compiled/reverse_kl_compiled_light_animation.svg b/benchmark_results/reverse_kl_compiled/reverse_kl_compiled_light_animation.svg
deleted file mode 100644
index fa0b7de5e869d37115d38d2b6672a84674881b4d..0000000000000000000000000000000000000000
--- a/benchmark_results/reverse_kl_compiled/reverse_kl_compiled_light_animation.svg
+++ /dev/null
@@ -1,123 +0,0 @@
-./ vs Torch - Relative Speed
-PyTorch 2.11.0+cu130 · CUDA 13.0
-ReverseKLBenchmark.reverse_kl_batch01_seqlen064_vocab248320_compiled
-2.23x
-ReverseKLBenchmark.reverse_kl_batch02_seqlen128_vocab248320_compiled
-2.60x
-ReverseKLBenchmark.reverse_kl_batch04_seqlen256_vocab248320_compiled
-2.65x
-ReverseKLBenchmark.reverse_kl_batch08_seqlen1024_vocab248320_compiled
-2.70x
-ReverseKLBenchmark.reverse_kl_batch08_seqlen512_vocab248320_compiled
-2.72x
-ReverseKLBenchmark.reverse_kl_batch08_seqlen981_vocab248320_compiled
-2.70x
-ReverseKLBenchmark.reverse_kl_fwd_batch01_seqlen064_vocab248320_compiled
-2.18x
-ReverseKLBenchmark.reverse_kl_fwd_batch02_seqlen128_vocab248320_compiled
-2.66x
-ReverseKLBenchmark.reverse_kl_fwd_batch04_seqlen256_vocab248320_compiled
-2.45x
-ReverseKLBenchmark.reverse_kl_fwd_batch08_seqlen1024_vocab248320_compiled
-2.48x
-ReverseKLBenchmark.reverse_kl_fwd_batch08_seqlen512_vocab248320_compiled
-2.45x
-ReverseKLBenchmark.reverse_kl_fwd_batch08_seqlen981_vocab248320_compiled
-2.45x
-Kernel
-Torch (ref)
\ No newline at end of file
diff --git a/benchmark_results/reverse_kl_compiled/reverse_kl_compiled_light_latency.svg b/benchmark_results/reverse_kl_compiled/reverse_kl_compiled_light_latency.svg
deleted file mode 100644
index
4f42106b2b489db13d7ad8a53b7c46e1daa3a5b8..0000000000000000000000000000000000000000
--- a/benchmark_results/reverse_kl_compiled/reverse_kl_compiled_light_latency.svg
+++ /dev/null
@@ -1,3569 +0,0 @@
-[SVG figure markup omitted; surviving metadata: created 2026-05-08T15:09:54.975890, image/svg+xml, Matplotlib v3.10.9, https://matplotlib.org/]
diff --git a/benchmark_results/reverse_kl_compiled/reverse_kl_compiled_light_throughput.svg b/benchmark_results/reverse_kl_compiled/reverse_kl_compiled_light_throughput.svg
deleted file mode 100644
index c0302e3e9da84b2191c5b601392ef9d171cc294b..0000000000000000000000000000000000000000
--- a/benchmark_results/reverse_kl_compiled/reverse_kl_compiled_light_throughput.svg
+++ /dev/null
@@ -1,3817 +0,0 @@
-[SVG figure markup omitted; surviving metadata: created 2026-05-08T15:09:55.275490, image/svg+xml, Matplotlib v3.10.9, https://matplotlib.org/]
diff --git a/benchmark_results/reverse_kl_eager/results.json b/benchmark_results/reverse_kl_eager/results.json
deleted file mode 100644
index d08977eff8fa485bbb8358835b325bac45de56f5..0000000000000000000000000000000000000000
--- a/benchmark_results/reverse_kl_eager/results.json
+++ /dev/null
@@ -1,206 +0,0 @@
-{ - "results": [ - { - "workload": "ReverseKLBenchmark.reverse_kl_batch01_seqlen064_vocab248320_eager", - "timingResults": { - "mean_ms": 0.1029, - "std_ms": 0.0032, - "min_ms": 0.0982, - "max_ms": 0.1129, - "q1_ms": 0.101, - "q3_ms": 0.1036, - "iqr_ms": 0.0026, - "outliers": 27, - "iterations": 200, - "refMeanMs": 0.5293 - }, - "verified": true - }, - { - "workload": "ReverseKLBenchmark.reverse_kl_batch02_seqlen128_vocab248320_eager", - "timingResults": { - "mean_ms": 0.248, - "std_ms": 0.0037, - "min_ms": 0.2417, - "max_ms": 0.2592, - "q1_ms": 0.2451, - "q3_ms": 0.251, - "iqr_ms": 0.0058, - "outliers": 0, - "iterations": 200, - "refMeanMs": 1.624 - }, - "verified": true - }, - { - "workload": "ReverseKLBenchmark.reverse_kl_batch04_seqlen256_vocab248320_eager", - "timingResults": { - "mean_ms": 0.8321, - "std_ms": 0.0035, - "min_ms": 0.8234, - "max_ms": 0.854, - "q1_ms": 0.8306, - "q3_ms": 0.8335, - "iqr_ms": 0.003, - "outliers": 20, - "iterations": 200, - "refMeanMs": 6.174 - }, - "verified": true - }, - { - "workload": "ReverseKLBenchmark.reverse_kl_batch08_seqlen1024_vocab248320_eager", - "timingResults": { - "mean_ms": 6.1046, - "std_ms": 0.0041, - "min_ms": 6.0961, - "max_ms": 6.1376, - "q1_ms": 6.1023, - "q3_ms": 6.106,
- "iqr_ms": 0.0037, - "outliers": 9, - "iterations": 200, - "refMeanMs": 48.4051 - }, - "verified": true - }, - { - "workload": "ReverseKLBenchmark.reverse_kl_batch08_seqlen512_vocab248320_eager", - "timingResults": { - "mean_ms": 3.0816, - "std_ms": 0.0035, - "min_ms": 3.0743, - "max_ms": 3.0939, - "q1_ms": 3.0794, - "q3_ms": 3.0832, - "iqr_ms": 0.0038, - "outliers": 8, - "iterations": 200, - "refMeanMs": 24.3385 - }, - "verified": true - }, - { - "workload": "ReverseKLBenchmark.reverse_kl_batch08_seqlen981_vocab248320_eager", - "timingResults": { - "mean_ms": 5.8549, - "std_ms": 0.0045, - "min_ms": 5.8459, - "max_ms": 5.8819, - "q1_ms": 5.8524, - "q3_ms": 5.8561, - "iqr_ms": 0.0037, - "outliers": 14, - "iterations": 200, - "refMeanMs": 46.4274 - }, - "verified": true - }, - { - "workload": "ReverseKLBenchmark.reverse_kl_fwd_batch01_seqlen064_vocab248320_eager", - "timingResults": { - "mean_ms": 0.0638, - "std_ms": 0.0027, - "min_ms": 0.0604, - "max_ms": 0.0787, - "q1_ms": 0.0624, - "q3_ms": 0.064, - "iqr_ms": 0.0016, - "outliers": 20, - "iterations": 200, - "refMeanMs": 0.2532 - }, - "verified": true - }, - { - "workload": "ReverseKLBenchmark.reverse_kl_fwd_batch02_seqlen128_vocab248320_eager", - "timingResults": { - "mean_ms": 0.1217, - "std_ms": 0.0038, - "min_ms": 0.1166, - "max_ms": 0.1428, - "q1_ms": 0.1193, - "q3_ms": 0.1227, - "iqr_ms": 0.0034, - "outliers": 19, - "iterations": 200, - "refMeanMs": 0.7671 - }, - "verified": true - }, - { - "workload": "ReverseKLBenchmark.reverse_kl_fwd_batch04_seqlen256_vocab248320_eager", - "timingResults": { - "mean_ms": 0.3753, - "std_ms": 0.0033, - "min_ms": 0.3695, - "max_ms": 0.3843, - "q1_ms": 0.3726, - "q3_ms": 0.3779, - "iqr_ms": 0.0053, - "outliers": 0, - "iterations": 200, - "refMeanMs": 2.869 - }, - "verified": true - }, - { - "workload": "ReverseKLBenchmark.reverse_kl_fwd_batch08_seqlen1024_vocab248320_eager", - "timingResults": { - "mean_ms": 2.6484, - "std_ms": 0.0065, - "min_ms": 2.6364, - "max_ms": 2.7044, 
 - "q1_ms": 2.6449, - "q3_ms": 2.6515, - "iqr_ms": 0.0067, - "outliers": 3, - "iterations": 200, - "refMeanMs": 22.3336 - }, - "verified": true - }, - { - "workload": "ReverseKLBenchmark.reverse_kl_fwd_batch08_seqlen512_vocab248320_eager", - "timingResults": { - "mean_ms": 1.365, - "std_ms": 0.0046, - "min_ms": 1.3548, - "max_ms": 1.3865, - "q1_ms": 1.3618, - "q3_ms": 1.3675, - "iqr_ms": 0.0057, - "outliers": 4, - "iterations": 200, - "refMeanMs": 11.2401 - }, - "verified": true - }, - { - "workload": "ReverseKLBenchmark.reverse_kl_fwd_batch08_seqlen981_vocab248320_eager", - "timingResults": { - "mean_ms": 2.5316, - "std_ms": 0.0059, - "min_ms": 2.5203, - "max_ms": 2.5523, - "q1_ms": 2.5272, - "q3_ms": 2.5355, - "iqr_ms": 0.0083, - "outliers": 3, - "iterations": 200, - "refMeanMs": 21.4099 - }, - "verified": true - } - ], - "machineInfo": { - "gpu": "NVIDIA H100 80GB HBM3", - "backend": "CUDA 13.0", - "pytorchVersion": "2.11.0+cu130", - "os": "Linux 6.11.0-1016-nvidia", - "cpu": "x86_64" - }, - "kernelCommitSha": "3e023eb5121761b8", - "benchmarkScriptPath": "benchmarks", - "benchmarkScriptSha": "690eea1f54f31bef1ad248380201005fd667d4b9c535f92f06eb6a5a33380d22" -}
\ No newline at end of file
diff --git a/benchmark_results/reverse_kl_eager/reverse_kl_eager_dark_animation.svg b/benchmark_results/reverse_kl_eager/reverse_kl_eager_dark_animation.svg
deleted file mode 100644
index 6bd29cf1f71f41e09af0cdd8b3f42dfbc1c31250..0000000000000000000000000000000000000000
--- a/benchmark_results/reverse_kl_eager/reverse_kl_eager_dark_animation.svg
+++ /dev/null
@@ -1,123 +0,0 @@
-./ vs Torch - Relative Speed
-PyTorch 2.11.0+cu130 · CUDA 13.0
-ReverseKLBenchmark.reverse_kl_batch01_seqlen064_vocab248320_eager
-5.14x
-ReverseKLBenchmark.reverse_kl_batch02_seqlen128_vocab248320_eager
-6.55x
-ReverseKLBenchmark.reverse_kl_batch04_seqlen256_vocab248320_eager
-7.42x
-ReverseKLBenchmark.reverse_kl_batch08_seqlen1024_vocab248320_eager
-7.93x
-ReverseKLBenchmark.reverse_kl_batch08_seqlen512_vocab248320_eager
-7.90x
-ReverseKLBenchmark.reverse_kl_batch08_seqlen981_vocab248320_eager
-7.93x
-ReverseKLBenchmark.reverse_kl_fwd_batch01_seqlen064_vocab248320_eager
-3.97x
-ReverseKLBenchmark.reverse_kl_fwd_batch02_seqlen128_vocab248320_eager
-6.30x
-ReverseKLBenchmark.reverse_kl_fwd_batch04_seqlen256_vocab248320_eager
-7.64x
-ReverseKLBenchmark.reverse_kl_fwd_batch08_seqlen1024_vocab248320_eager
-8.43x
-ReverseKLBenchmark.reverse_kl_fwd_batch08_seqlen512_vocab248320_eager
-8.23x
-ReverseKLBenchmark.reverse_kl_fwd_batch08_seqlen981_vocab248320_eager
-8.46x
-Kernel
-Torch (ref)
\ No newline at end of file
diff --git a/benchmark_results/reverse_kl_eager/reverse_kl_eager_dark_latency.svg b/benchmark_results/reverse_kl_eager/reverse_kl_eager_dark_latency.svg
deleted file mode 100644
index 56883331249a4721c53426ab926e542e7526cf8c..0000000000000000000000000000000000000000
--- a/benchmark_results/reverse_kl_eager/reverse_kl_eager_dark_latency.svg
+++ /dev/null
@@ -1,3535 +0,0 @@
-[SVG figure markup omitted; surviving metadata: created 2026-05-08T15:09:34.973115, image/svg+xml, Matplotlib v3.10.9, https://matplotlib.org/]
diff --git a/benchmark_results/reverse_kl_eager/reverse_kl_eager_dark_throughput.svg b/benchmark_results/reverse_kl_eager/reverse_kl_eager_dark_throughput.svg
deleted file mode 100644
index 04347f2fc035bb7321ced383977dc7b80dff842a..0000000000000000000000000000000000000000
--- a/benchmark_results/reverse_kl_eager/reverse_kl_eager_dark_throughput.svg
+++ /dev/null
@@ -1,3756 +0,0 @@
-[SVG figure: image/svg+xml, created 2026-05-08T15:09:35.253165 with Matplotlib v3.10.9, https://matplotlib.org/]
diff --git a/benchmark_results/reverse_kl_eager/reverse_kl_eager_light_animation.svg b/benchmark_results/reverse_kl_eager/reverse_kl_eager_light_animation.svg
deleted file mode 100644
index b81a1d952bf78a479461336118617968a7f8723a..0000000000000000000000000000000000000000
--- a/benchmark_results/reverse_kl_eager/reverse_kl_eager_light_animation.svg
+++ /dev/null
@@ -1,123 +0,0 @@
-./ vs Torch - Relative Speed
-PyTorch 2.11.0+cu130 · CUDA 13.0
-
-ReverseKLBenchmark.reverse_kl_batch01_seqlen064_vocab248320_eager
-5.14x
-ReverseKLBenchmark.reverse_kl_batch02_seqlen128_vocab248320_eager
-6.55x
-ReverseKLBenchmark.reverse_kl_batch04_seqlen256_vocab248320_eager
-7.42x
-ReverseKLBenchmark.reverse_kl_batch08_seqlen1024_vocab248320_eager
-7.93x
-ReverseKLBenchmark.reverse_kl_batch08_seqlen512_vocab248320_eager
-7.90x
-ReverseKLBenchmark.reverse_kl_batch08_seqlen981_vocab248320_eager
-7.93x
-ReverseKLBenchmark.reverse_kl_fwd_batch01_seqlen064_vocab248320_eager
-3.97x
-ReverseKLBenchmark.reverse_kl_fwd_batch02_seqlen128_vocab248320_eager
-6.30x
-ReverseKLBenchmark.reverse_kl_fwd_batch04_seqlen256_vocab248320_eager
-7.64x
-ReverseKLBenchmark.reverse_kl_fwd_batch08_seqlen1024_vocab248320_eager
-8.43x
-ReverseKLBenchmark.reverse_kl_fwd_batch08_seqlen512_vocab248320_eager
-8.23x
-ReverseKLBenchmark.reverse_kl_fwd_batch08_seqlen981_vocab248320_eager
-8.46x
-
-Kernel
-Torch (ref)
\ No newline at end of file
diff --git a/benchmark_results/reverse_kl_eager/reverse_kl_eager_light_latency.svg b/benchmark_results/reverse_kl_eager/reverse_kl_eager_light_latency.svg
deleted file mode 100644
index 51f2860419b90e817620e29dee79656364d07552..0000000000000000000000000000000000000000
--- a/benchmark_results/reverse_kl_eager/reverse_kl_eager_light_latency.svg
+++ /dev/null
@@ -1,3535 +0,0 @@
-[SVG figure: image/svg+xml, created 2026-05-08T15:09:34.228994 with Matplotlib v3.10.9, https://matplotlib.org/]
diff --git a/benchmark_results/reverse_kl_eager/reverse_kl_eager_light_throughput.svg b/benchmark_results/reverse_kl_eager/reverse_kl_eager_light_throughput.svg
deleted file mode 100644
index c81589239f3ffdf23b690be86a1804e0f49ff3a7..0000000000000000000000000000000000000000
--- a/benchmark_results/reverse_kl_eager/reverse_kl_eager_light_throughput.svg
+++ /dev/null
@@ -1,3756 +0,0 @@
-[SVG figure: image/svg+xml, created 2026-05-08T15:09:34.536765 with Matplotlib v3.10.9, https://matplotlib.org/]
diff --git a/build/torch-cuda/__init__.py b/build/torch-cuda/__init__.py
deleted file mode 100644
index f710f3dd6ed6a972ee3cd40a6824d6ead33094f9..0000000000000000000000000000000000000000
--- a/build/torch-cuda/__init__.py
+++ /dev/null
@@ -1,69 +0,0 @@
-"""Geometric-AI CuteDSL kernels for RL / distillation training.
-
-Public surface:
-  * ``bnpo_loss`` / ``bnpo_loss_autograd`` / ``bnpo_loss_fwd`` -
-    fused fwd+bwd BNPO loss with three entry points (direct
-    ``(loss, grad)``, autograd-wrapped, forward-only).
-  * ``grpo_loss`` / ``grpo_loss_autograd`` / ``grpo_loss_fwd`` -
-    fused fwd+bwd GRPO loss (TRL's per-response normalization
-    variant). Same three-entry-point shape as BNPO. Requires
-    ``completions_mask``.
-  * ``reverse_kl`` / ``reverse_kl_autograd`` /
-    ``reverse_kl_fwd`` - fused fwd+bwd reverse-KL
-    self-distillation loss with the same three-entry-point shape.
-
-HF Kernels integration: :mod:`geometric_ai_kernels.layers` exposes
-``nn.Module`` adapters per kernel (``bnpoLoss`` / ``bnpoLossInference``,
-``grpoLoss`` / ``grpoLossInference``, ``ReverseKL`` /
-``ReverseKLInference``) for use with the ``kernels``
-library's ``kernelize()`` flow.
-"""
-
-from __future__ import annotations
-
-import torch._dynamo
-
-from .bnpo_loss import bnpo_loss, bnpo_loss_autograd, bnpo_loss_fwd
-from .grpo_loss import grpo_loss, grpo_loss_autograd, grpo_loss_fwd
-from .layers import (
-    ReverseKL,
-    ReverseKLInference,
-    bnpoLoss,
-    bnpoLossInference,
-    grpoLoss,
-    grpoLossInference,
-)
-from .reverse_kl import (
-    reverse_kl,
-    reverse_kl_autograd,
-    reverse_kl_fwd,
-)
-
-# Required so ``torch.compile(fullgraph=True)`` can trace through
-# ``torch.autograd.grad`` calls - without it Dynamo graph-breaks at the
-# autograd.grad call site even when AOTAutograd has already derived the
-# joint fwd+bwd graph. Set at package import so any consumer (benches,
-# user training loops, ``kernelize`` flows) gets it for free. Guarded
-# because ``trace_autograd_ops`` was added in torch 2.10 and the
-# Nix-pinned build environment may be on an older torch (2.9 today);
-# the underlying ``Config.__setattr__`` raises on unknown keys.
-if hasattr(torch._dynamo.config, "trace_autograd_ops"):
-    torch._dynamo.config.trace_autograd_ops = True  # ty: ignore[invalid-assignment]
-
-__all__ = [
-    "ReverseKL",
-    "ReverseKLInference",
-    "bnpoLoss",
-    "bnpoLossInference",
-    "bnpo_loss",
-    "bnpo_loss_autograd",
-    "bnpo_loss_fwd",
-    "grpoLoss",
-    "grpoLossInference",
-    "grpo_loss",
-    "grpo_loss_autograd",
-    "grpo_loss_fwd",
-    "reverse_kl",
-    "reverse_kl_autograd",
-    "reverse_kl_fwd",
-]
diff --git a/build/torch-cuda/_ops.py b/build/torch-cuda/_ops.py
deleted file mode 100644
index 80dc07c3b89ddf9661020bed5993df9e5bad586c..0000000000000000000000000000000000000000
--- a/build/torch-cuda/_ops.py
+++ /dev/null
@@ -1,38 +0,0 @@
-import torch
-
-def get_backend() -> str:
-    """Detect the backend by inspecting torch."""
-    import torch
-
-    if hasattr(torch, "neuron"):
-        # Needs to be sorted before specific Torch builds, since Neuron
-        # extension can be loaded into e.g. CUDA Torch builds.
-        return "neuron"
-    elif torch.version.cuda is not None:
-        return "cuda"
-    elif torch.version.hip is not None:
-        return "rocm"
-    elif torch.backends.mps.is_available():
-        return "metal"
-    elif hasattr(torch.version, "xpu") and torch.version.xpu is not None:
-        return "xpu"
-    else:
-        return "cpu"
-
-
-def _find_ops_name() -> str:
-    kernel_name = "geometric_ai_kernels"
-    unique_id = "a766fbd_dirty"
-    backend = get_backend()
-    return f"_{kernel_name}_{backend}_{unique_id}"
-
-
-_OPS_NAME = _find_ops_name()
-
-ops = getattr(torch.ops, _OPS_NAME)
-
-def add_op_namespace_prefix(op_name: str) -> str:
-    """
-    Prefix op by namespace.
-    """
-    return f"{_OPS_NAME}::{op_name}"
\ No newline at end of file
diff --git a/build/torch-cuda/bnpo_loss/__init__.py b/build/torch-cuda/bnpo_loss/__init__.py
deleted file mode 100644
index b7c0aff70dc2bd0373e13522f1158b378f811600..0000000000000000000000000000000000000000
--- a/build/torch-cuda/bnpo_loss/__init__.py
+++ /dev/null
@@ -1,196 +0,0 @@
-"""bnpo loss with CuteDSL fused fwd+bwd.
-
-Two public APIs route to two compiled kernels:
-
-* :func:`bnpo_loss` - primary training entry point. Returns
-  ``(loss, grad_policy_logprobs)`` from a single fused fwd+bwd kernel
-  launch. Inputs do **not** need ``requires_grad=True`` and there is no
-  ``torch.autograd.Function`` wrapper - chain the gradient into the
-  upstream model with ``policy_logprobs.backward(grad)`` (or, more
-  commonly, by passing ``grad`` to whatever step does the next leg of
-  backprop).
-* :func:`bnpo_loss_fwd` - inference / validation path. Returns the
-  scalar ``loss`` from a forward-only kernel that computes the masked
-  mean denominator on-GPU via a last-block trick (no host
-  ``completions_mask.sum()``).
-
-The two share the same compiled-kernel cache; per-call output and
-gradient buffers are allocated inside the runner, and cross-CTA scratch
-(atomic accumulators + counters) is owned by the compiled-kernel
-closure and self-resets each launch - callers don't manage scratch.
-
-Why no autograd wrapper here? bnpo's gradient is closed-form - the
-kernel already writes ``dL/d(policy_logprobs)`` in the same launch as
-the loss. Wrapping in ``torch.autograd.Function`` would cost an extra
-``grad_output * dpolicy`` kernel launch on backward (typically a
-no-op multiply by ``1.0``), plus per-call autograd graph bookkeeping.
-The autograd-aware sibling :func:`bnpo_loss_autograd` uses
-``torch.library.custom_op`` instead, which composes with
-``torch.compile``.
-"""
-
-from __future__ import annotations
-
-from functools import lru_cache
-from typing import TYPE_CHECKING, cast
-
-import torch
-
-from .cute_bnpo_loss import (
-    create_compiled_bnpo_loss,
-    create_compiled_bnpo_loss_with_backward,
-)
-
-if TYPE_CHECKING:
-    from collections.abc import Callable
-
-
-__all__ = ["bnpo_loss", "bnpo_loss_autograd", "bnpo_loss_fwd"]
-
-
-@lru_cache(maxsize=32)
-def _get_compiled_fwd(
-    dtype: torch.dtype,
-    epsilon: float,
-    epsilon_high: float,
-    beta: float,
-) -> Callable[..., torch.Tensor]:
-    return cast(
-        "Callable[..., torch.Tensor]",
-        create_compiled_bnpo_loss(
-            policy_dtype=dtype,
-            epsilon=epsilon,
-            epsilon_high=epsilon_high,
-            beta=beta,
-            compute_backward=False,
-        ),
-    )
-
-
-@lru_cache(maxsize=32)
-def _get_compiled_fwd_bwd(
-    dtype: torch.dtype,
-    epsilon: float,
-    epsilon_high: float,
-    beta: float,
-) -> Callable[..., tuple[torch.Tensor, torch.Tensor]]:
-    return create_compiled_bnpo_loss_with_backward(
-        policy_dtype=dtype,
-        epsilon=epsilon,
-        epsilon_high=epsilon_high,
-        beta=beta,
-    )
-
-
-def bnpo_loss_fwd(
-    policy_logprobs: torch.Tensor,
-    old_policy_logprobs: torch.Tensor,
-    ref_logprobs: torch.Tensor,
-    advantages: torch.Tensor,
-    completions_mask: torch.Tensor,
-    epsilon: float = 0.2,
-    epsilon_high: float = 0.2,
-    beta: float = 0.1,
-) -> torch.Tensor:
-    """Forward-only bnpo loss. Returns the scalar ``loss``.
-
-    Use for inference / validation. The masked mean denominator is
-    computed on-GPU by an atomic accumulator + last-block trick - no
-    host ``completions_mask.sum()`` syncs.
-
-    Args:
-        policy_logprobs, old_policy_logprobs, ref_logprobs: ``(bs, seq_len)``.
-        advantages: ``(bs,)``.
-        completions_mask: bool/int8 mask ``(bs, seq_len)``; truthy = valid token.
-        epsilon, epsilon_high: PPO-style clipping bounds.
-        beta: KL-penalty coefficient. ``0.0`` compiles away the KL branch.
-
-    Returns:
-        Scalar tensor (0-dim) with the same dtype as ``policy_logprobs``.
-    """
-    run = _get_compiled_fwd(
-        policy_logprobs.dtype,
-        float(epsilon),
-        float(epsilon_high),
-        float(beta),
-    )
-    mask_arg = (
-        completions_mask
-        if completions_mask.dtype == torch.int8
-        else completions_mask.to(torch.int8)
-    )
-    return run(
-        policy_logprobs,
-        old_policy_logprobs,
-        ref_logprobs,
-        advantages,
-        mask_arg,
-    )
-
-
-def bnpo_loss(
-    policy_logprobs: torch.Tensor,
-    old_policy_logprobs: torch.Tensor,
-    ref_logprobs: torch.Tensor,
-    advantages: torch.Tensor,
-    completions_mask: torch.Tensor,
-    epsilon: float = 0.2,
-    epsilon_high: float = 0.2,
-    beta: float = 0.1,
-) -> tuple[torch.Tensor, torch.Tensor]:
-    """Fused fwd+bwd bnpo loss. Returns ``(loss, grad_policy_logprobs)``.
-
-    Single-launch training entry point. The kernel writes both the
-    scalar loss and the scaled ``dL/d(policy_logprobs)`` tensor in one
-    ``@cute.jit`` dispatch - a bundled mask-sum kernel runs inside the
-    same launch so ``inv_total`` is populated on-GPU without a host-side
-    ``torch.sum`` round trip.
-
-    Inputs do **not** need ``requires_grad=True``. To chain ``grad``
-    into the upstream model that produced ``policy_logprobs``::
-
-        loss, grad = bnpo_loss(policy_logprobs, ..., completions_mask=mask)
-        policy_logprobs.backward(grad)
-        optimizer.step()
-
-    Args:
-        policy_logprobs, old_policy_logprobs, ref_logprobs: ``(bs, seq_len)``.
-        advantages: ``(bs,)``.
-        completions_mask: bool/int8 mask ``(bs, seq_len)``.
-        epsilon, epsilon_high: PPO-style clipping bounds.
-        beta: KL-penalty coefficient. ``0.0`` compiles away the KL branch.
-
-    Returns:
-        ``(loss, grad_policy_logprobs)`` - ``loss`` is a 0-dim tensor in
-        ``policy_logprobs.dtype``; ``grad_policy_logprobs`` has shape
-        ``(bs, seq_len)`` and is already scaled by ``1 / n_valid``. The
-        gradient tensor is freshly allocated per call (no shared cache),
-        so callers may keep it around freely.
-
-    For inference / validation where you only need the loss, use
-    :func:`bnpo_loss_fwd` - it skips the dpolicy write entirely and
-    computes the mean denominator with the on-GPU last-block trick.
-    """
-    run = _get_compiled_fwd_bwd(
-        policy_logprobs.dtype,
-        float(epsilon),
-        float(epsilon_high),
-        float(beta),
-    )
-    mask_arg = (
-        completions_mask
-        if completions_mask.dtype == torch.int8
-        else completions_mask.to(torch.int8)
-    )
-    return run(
-        policy_logprobs,
-        old_policy_logprobs,
-        ref_logprobs,
-        advantages,
-        mask_arg,
-    )
-
-
-# Imported at the bottom: ``autograd.py`` imports ``bnpo_loss`` from this
-# module, so the function must be fully defined before its import runs.
-from .autograd import bnpo_loss_autograd  # noqa: E402
diff --git a/build/torch-cuda/bnpo_loss/_torch_ref.py b/build/torch-cuda/bnpo_loss/_torch_ref.py
deleted file mode 100644
index ce1ef06ba7ac33e4404b1484836d0b1dc8687328..0000000000000000000000000000000000000000
--- a/build/torch-cuda/bnpo_loss/_torch_ref.py
+++ /dev/null
@@ -1,56 +0,0 @@
-"""Plain-PyTorch bnpo reference shared between the bench and the tests.
-
-This module is intentionally minimal - every op is a vanilla torch op so
-``AOTAutograd`` can derive the joint fwd+bwd graph and Inductor can fuse
-both passes (used by ``benchmarks/benchmark_bnpo_loss.py``'s compiled
-baseline). The same function is imported by ``tests/test_bnpo_loss.py``
-as the correctness reference, so both paths agree on what "the eager
-torch implementation of bnpo loss" means.
-
-Underscore-prefixed module name signals "shared internal", not a public
-API surface - there's no re-export from the package's top-level
-``__init__.py``.
-"""
-
-from __future__ import annotations
-
-import torch
-
-
-def torch_bnpo_loss(
-    policy_logprobs: torch.Tensor,
-    old_policy_logprobs: torch.Tensor,
-    ref_logprobs: torch.Tensor,
-    advantages: torch.Tensor,
-    completions_mask: torch.Tensor,
-    epsilon: float = 0.2,
-    epsilon_high: float = 0.2,
-    beta: float = 0.1,
-) -> torch.Tensor:
-    """Plain-Python bnpo reference traceable by AOTAutograd / Inductor.
-
-    Operates in the input dtype throughout (no internal fp32 cast),
-    which is what real torch users would write - and what
-    ``torch.compile`` competes against in the bench.
-    """
-    ratio = torch.exp(policy_logprobs - old_policy_logprobs)
-    adv = advantages.unsqueeze(1)
-
-    surrogate = ratio * adv
-    surrogate_clipped = torch.clamp(ratio, 1.0 - epsilon, 1.0 + epsilon_high) * adv
-    policy_loss = -torch.min(surrogate, surrogate_clipped)
-
-    log_ratio_ref = ref_logprobs - policy_logprobs
-    kl = torch.exp(log_ratio_ref) - log_ratio_ref - 1.0
-
-    # Cast n_valid to fp32: int64 → fp16 overflows when n_valid > 65504.
-    # ``clamp(min=1.0)`` matches TRL's ``mask.sum().clamp(min=1)``: a
-    # fully-masked batch produces ``loss=0`` instead of inf/NaN. Mirrors
-    # the cute kernel's ``cute.arch.fmax(..., 1.0)`` before ``rcp_approx``
-    # in ``cute_bnpo_loss.py``.
-    n_valid = completions_mask.sum().to(torch.float32).clamp(min=1.0)
-    policy_loss = (policy_loss * completions_mask).sum() / n_valid
-    kl = (kl * completions_mask).sum() / n_valid
-
-    loss = policy_loss + beta * kl
-    return loss.to(policy_logprobs.dtype)
diff --git a/build/torch-cuda/bnpo_loss/autograd.py b/build/torch-cuda/bnpo_loss/autograd.py
deleted file mode 100644
index f4cdf9de6808a4d66419c41ab58eb1170267bc0a..0000000000000000000000000000000000000000
--- a/build/torch-cuda/bnpo_loss/autograd.py
+++ /dev/null
@@ -1,149 +0,0 @@
-"""Autograd-aware wrapper for bnpo loss via ``torch.library.custom_op``.
-
-The fused cute kernel writes both the scalar loss and the closed-form
-``dL/d(policy_logprobs)`` in one launch. This module wraps that into an
-autograd-compatible op so callers can write::
-
-    loss = bnpo_loss_autograd(policy, old, ref, adv, completions_mask)
-    loss.backward()  # propagates through to the upstream model
-
-instead of the manual ``policy.backward(grad)`` chain. The cost is
-~12µs of autograd dispatcher overhead per call (vs the direct
-``bnpo_loss`` ``(loss, grad)`` tuple); for ergonomic / kernelize() flows
-that's cheap, but for tight microbenches use the direct path.
-
-Implementation notes:
-
-- The registered op returns ``(loss, dpolicy)`` so ``setup_context`` can
-  ``save_for_backward(dpolicy)``. The public ``bnpo_loss_autograd``
-  wrapper hides the second output.
-- ``dpolicy`` is allocated fresh by the runner on every call (no shared
-  cache), so ``ctx.save_for_backward(dpolicy)`` keeps a stable reference
-  across subsequent calls without any extra copy.
-- Backward returns ``grad_loss * dpolicy``. Under ``torch.compile``,
-  when ``loss`` is consumed by ``.backward()`` directly, ``grad_loss``
-  is the constant 1.0 and Inductor can fold the multiply away - that's
-  the main reason this path uses ``custom_op`` instead of a plain
-  ``autograd.Function``.
-- ``register_fake`` provides the meta kernel for ``torch.compile`` shape
-  propagation; the real cute kernel never runs under ``FakeTensorMode``.
-"""
-
-from __future__ import annotations
-
-import torch
-
-from . import bnpo_loss as _bnpo_loss_fwd_bwd
-
-__all__ = ["bnpo_loss_autograd"]
-
-
-@torch.library.custom_op(
-    "geometric_ai_kernels::_bnpo_loss_with_grad",
-    mutates_args=(),
-)
-def _bnpo_loss_with_grad(
-    policy_logprobs: torch.Tensor,
-    old_policy_logprobs: torch.Tensor,
-    ref_logprobs: torch.Tensor,
-    advantages: torch.Tensor,
-    completions_mask: torch.Tensor,
-    epsilon: float,
-    epsilon_high: float,
-    beta: float,
-) -> tuple[torch.Tensor, torch.Tensor]:
-    loss, dpolicy = _bnpo_loss_fwd_bwd(
-        policy_logprobs,
-        old_policy_logprobs,
-        ref_logprobs,
-        advantages,
-        completions_mask,
-        epsilon=epsilon,
-        epsilon_high=epsilon_high,
-        beta=beta,
-    )
-    return loss, dpolicy
-
-
-@_bnpo_loss_with_grad.register_fake
-def _(
-    policy_logprobs: torch.Tensor,
-    old_policy_logprobs: torch.Tensor,
-    ref_logprobs: torch.Tensor,
-    advantages: torch.Tensor,
-    completions_mask: torch.Tensor,
-    epsilon: float,
-    epsilon_high: float,
-    beta: float,
-) -> tuple[torch.Tensor, torch.Tensor]:
-    # Signature must mirror the op; only ``policy_logprobs`` shapes the outputs.
-    del old_policy_logprobs, ref_logprobs, advantages, completions_mask
-    del epsilon, epsilon_high, beta
-    loss = policy_logprobs.new_empty(())
-    dpolicy = torch.empty_like(policy_logprobs)
-    return loss, dpolicy
-
-
-def _setup_context(ctx, inputs, output) -> None:  # type: ignore[no-untyped-def]
-    del inputs  # only ``output`` carries what we need to save.
-    _, dpolicy = output
-    ctx.save_for_backward(dpolicy)
-
-
-def _backward(ctx, grad_loss, grad_dpolicy):  # type: ignore[no-untyped-def]
-    # ``grad_dpolicy`` is unused - ``dpolicy`` is an internal intermediate
-    # exposed only so ``setup_context`` can save it. Under typical usage
-    # (``loss.backward()``) it arrives as ``None`` or a zero tensor.
-    del grad_dpolicy
-    (dpolicy,) = ctx.saved_tensors
-    grad_policy = grad_loss * dpolicy
-    # One return per input to the op (8): policy_logprobs gets the grad,
-    # everything else gets None (no autograd flow).
-    return grad_policy, None, None, None, None, None, None, None
-
-
-torch.library.register_autograd(
-    "geometric_ai_kernels::_bnpo_loss_with_grad",
-    _backward,
-    setup_context=_setup_context,
-)
-
-
-def bnpo_loss_autograd(
-    policy_logprobs: torch.Tensor,
-    old_policy_logprobs: torch.Tensor,
-    ref_logprobs: torch.Tensor,
-    advantages: torch.Tensor,
-    completions_mask: torch.Tensor,
-    epsilon: float = 0.2,
-    epsilon_high: float = 0.2,
-    beta: float = 0.1,
-) -> torch.Tensor:
-    """Autograd-aware bnpo loss. Returns scalar ``loss``.
-
-    Same numerics as :func:`bnpo_loss` but registered as a
-    ``torch.library`` custom op with autograd, so::
-
-        loss = bnpo_loss_autograd(policy, ..., completions_mask)
-        loss.backward()
-
-    propagates through to whatever produced ``policy_logprobs``. For
-    direct ``(loss, grad)`` access without the autograd dispatcher
-    overhead, use :func:`bnpo_loss` and chain the gradient manually
-    via ``policy_logprobs.backward(grad)``.
-
-    Composes with ``torch.compile``: the op is opaque to Inductor but
-    has a fake/meta kernel registered, so models containing this layer
-    can be compiled end-to-end without graph breaks.
-    """
-    loss, _ = _bnpo_loss_with_grad(
-        policy_logprobs,
-        old_policy_logprobs,
-        ref_logprobs,
-        advantages,
-        completions_mask,
-        float(epsilon),
-        float(epsilon_high),
-        float(beta),
-    )
-    return loss
diff --git a/build/torch-cuda/bnpo_loss/cute_bnpo_loss.py b/build/torch-cuda/bnpo_loss/cute_bnpo_loss.py
deleted file mode 100644
index d9849b921283a762635255be9e68220935bff7d0..0000000000000000000000000000000000000000
--- a/build/torch-cuda/bnpo_loss/cute_bnpo_loss.py
+++ /dev/null
@@ -1,1081 +0,0 @@
-"""CuteDSL kernel for bnpo loss.
- -Computes (element-wise over ``(bs, seq_len)`` logprob tensors, reduced to a -scalar): - - ratio = exp(policy - old_policy) - surrogate = ratio * adv - clipped = clip(ratio, 1 - eps, 1 + eps_high) * adv - policy_loss = -min(surrogate, clipped) - log_ratio_ref = ref - policy - kl = exp(log_ratio_ref) - log_ratio_ref - 1 - L_bnpo = (policy_loss * mask).sum() / n_valid - + beta * (kl * mask).sum() / n_valid - -where ``n_valid = max(completions_mask.sum(), 1)``. The mean denominator is -computed entirely on-GPU — the forward-only path uses an atomic accumulator -+ last-block trick on ``valid_acc``; the fused fwd+bwd path bundles a small -companion mask-sum kernel into the same ``@cute.jit`` launch that writes -``1 / completions_mask.sum()`` into the ``inv_total`` GMEM scalar before the -main kernel reads it. Every block needs ``inv_total`` mid-loop to scale its -``dpolicy`` slab, so the fwd-only last-block trick doesn't compose with -backward; bundling the mask-sum keeps both paths host-sync-free and CUDA-graph -compatible. - -When ``beta=0`` the KL term is skipped at compile time (no ``ref`` tensor -access, no ``kl_acc`` atomic add). - -Sequence lengths that are **not** a multiple of ``TILE_N`` are handled -natively: the grid launches ``ceil(seq_len / TILE_N)`` column tiles; full tiles -use the vectorized ``LDG.128`` path and the tail tile uses predicated vector -loads with neutral prefill. - -Two compiled-kernel flavors are exposed: - -* :func:`create_compiled_bnpo_loss` — forward-only. -* :func:`create_compiled_bnpo_loss_with_backward` — fused fwd+bwd. Returns - ``(loss, dpolicy)`` directly — no ``torch.autograd.Function`` wrapper. The - autograd-aware sibling lives in ``autograd.py`` and uses - ``torch.library.custom_op`` instead. - -Per-call output (``loss``, ``dpolicy``, ``inv_total``) is allocated inside the -runner. 
Cross-CTA scratch (atomic accumulators + counters) is allocated lazily -on first call inside the compiled-kernel closure and self-resets each launch -via the kernel's last-block epilogue + ``atom.inc.u32`` wrap-around — callers -don't manage scratch state. -""" - -from __future__ import annotations - -import math -import operator -from typing import TYPE_CHECKING, Any -from typing import cast as _typing_cast - -import cutlass -import cutlass.utils -import torch -from cutlass import cute -from cutlass._mlir.dialects import llvm -from cutlass.base_dsl.typing import cast -from cutlass.cutlass_dsl import T, dsl_user_op - -if TYPE_CHECKING: - from collections.abc import Callable - - -TILE_N: int = 512 -NUM_WARPS: int = 4 -# ``VEC=4`` (fp32) emits 128-bit ``LDG.128``. Pairs with ``NUM_WARPS=4`` so -# each block processes ``block_size * VEC = 512 = TILE_N`` elements per iter. -VEC: int = 4 -# Large-tile variant: at very long ``seq_len`` the small-TILE_N grid -# explodes (e.g. 8192/512 = 16 col-tiles per row → thousands of CTAs), -# inflating last-block-detection latency and atomic contention. A second -# compiled variant with this larger tile is dispatched when -# ``seq_len >= TILE_N_LARGE_THRESHOLD``. 
-TILE_N_LARGE: int = 4096 -TILE_N_LARGE_THRESHOLD: int = 2048 - -_LOG2_E: float = math.log2(math.e) - -_TORCH_TO_CUTLASS_DTYPE: dict[torch.dtype, Any] = { - torch.float32: cutlass.Float32, - torch.float16: cutlass.Float16, - torch.bfloat16: cutlass.BFloat16, -} - - -@dsl_user_op -def _atomic_add_f32_gmem( - ptr_i64: Any, - val: cutlass.Float32, - *, - loc: Any = None, - ip: Any = None, -) -> None: - llvm.inline_asm( - T.f32(), - [ptr_i64, cutlass.Float32(val).ir_value(loc=loc, ip=ip)], - "atom.global.add.f32 $0, [$1], $2;", - "=f,l,f", - has_side_effects=True, - is_align_stack=False, - asm_dialect=llvm.AsmDialect.AD_ATT, - ) - - -@dsl_user_op -def _atomic_add_s32_gmem( - ptr_i64: Any, - val: cutlass.Int32, - *, - loc: Any = None, - ip: Any = None, -) -> None: - """Emit ``atom.global.add.s32`` to a 64-bit GMEM address.""" - llvm.inline_asm( - T.i32(), - [ptr_i64, cutlass.Int32(val).ir_value(loc=loc, ip=ip)], - "atom.global.add.s32 $0, [$1], $2;", - "=r,l,r", - has_side_effects=True, - is_align_stack=False, - asm_dialect=llvm.AsmDialect.AD_ATT, - ) - - -@dsl_user_op -def _dp4a_u32_acc_s32( - packed_a: cutlass.Uint32, - packed_b: cutlass.Uint32, - acc: cutlass.Int32, - *, - loc: Any = None, - ip: Any = None, -) -> cutlass.Int32: - """``dp4a.u32.u32`` — sum 4 packed u8 products into an s32 acc. - - Computes ``a[0]*b[0] + a[1]*b[1] + a[2]*b[2] + a[3]*b[3] + acc`` in - one ``IDP4A.U8.S32`` instruction (full-rate on Hopper/Blackwell). - For mask summation, pass ``packed_b = 0x01010101`` so the products - reduce to ``sum(a_bytes) + acc`` — 4× fewer ALU ops than 4 separate - int8→int32 widens + adds. 
- """ - return cutlass.Int32( - llvm.inline_asm( - T.i32(), - [ - cutlass.Uint32(packed_a).ir_value(loc=loc, ip=ip), - cutlass.Uint32(packed_b).ir_value(loc=loc, ip=ip), - cutlass.Int32(acc).ir_value(loc=loc, ip=ip), - ], - "dp4a.u32.u32 $0, $1, $2, $3;", - "=r,r,r,r", - has_side_effects=False, - is_align_stack=False, - asm_dialect=llvm.AsmDialect.AD_ATT, - ) - ) - - -@dsl_user_op -def _atomic_inc_u32_gmem( - ptr_i64: Any, - threshold: cutlass.Int32, - *, - loc: Any = None, - ip: Any = None, -) -> cutlass.Int32: - """``atom.global.inc.u32`` — returns old value; wraps to 0 at threshold.""" - return cutlass.Int32( - llvm.inline_asm( - T.i32(), - [ptr_i64, cutlass.Int32(threshold).ir_value(loc=loc, ip=ip)], - "atom.global.inc.u32 $0, [$1], $2;", - "=r,l,r", - has_side_effects=True, - is_align_stack=False, - asm_dialect=llvm.AsmDialect.AD_ATT, - ) - ) - - -# --------------------------------------------------------------------------- -# Mask-sum kernel — replaces ``torch.sum(completions_mask)`` on the fwd+bwd -# path. Bundled into the same ``@cute.jit`` launch as the main kernel so the -# whole step is one tvm-ffi dispatch (no extra Python/torch dispatcher round -# trip). The kernel writes ``1 / completions_mask.sum()`` directly into -# ``inv_total_tensor`` so the main kernel reads it as a pre-inverted scalar. -# --------------------------------------------------------------------------- - - -def _make_mask_sum_kernel(tile_n: int) -> Callable[..., None]: - """Return a ``@cute.kernel`` that reduces ``completions_mask`` and writes 1/sum. - - Grid mirrors the main kernel — ``(bs, num_col_tiles)`` — so the mask is - read once with the same vectorised LDG pattern as the main compute. - Each block: - - 1. Loads its ``tile_n`` int8 slab of ``completions_mask`` (predicated tail). - 2. Reduces to a per-block ``int32`` scalar (bit-exact, no per-element - i8→f32 cast — IADD throughput equals FADD on Hopper/Blackwell). - 3. 
Atomically adds it to ``valid_acc`` (global int32 accumulator). - 4. Increments ``mask_counter``; the last block reads ``valid_acc``, - casts to fp32, computes ``rcp_approx`` and writes - ``inv_total_tensor[0]``, then resets ``valid_acc`` to ``0`` so - the next call starts fresh. The counter self-resets via - ``atom.inc.u32`` wrap-around. - - A separate ``mask_counter`` tensor (not the main kernel's ``counter``) - is required because the two kernels run in series within the same - ``@cute.jit`` and both rely on a wrap-around for self-reset; sharing - one counter would race. - """ - - @cute.kernel - def _mask_sum_kernel( - completions_mask: cute.Tensor, # (bs, seq_len) int8 - inv_total_tensor: cute.Tensor, # (1,) fp32 — output - valid_acc: cute.Tensor, # (1,) int32 — accumulator - mask_counter: cute.Tensor, # (1,) i32 — last-block detection - total_blocks: cutlass.Int32, - num_full_tiles: cutlass.Int32, - tail_len: cutlass.Int32, - ) -> None: - block_size = NUM_WARPS * 32 - iters = tile_n // (block_size * VEC) - - _no_alloc = cute.nvgpu.CacheEvictionPriority.NO_ALLOCATE - g2r_op = cute.nvgpu.CopyUniversalOp() - g2r_mask_atom = cute.make_copy_atom( - g2r_op, - completions_mask.element_type, - num_bits_per_copy=0, - l1c_evict_priority=_no_alloc, - ) - - row = cute.arch.block_idx()[0] - col_block = cute.arch.block_idx()[1] - tid = cute.arch.thread_idx()[0] - - local_valid_sum = cutlass.Int32(0) - mask_row = cute.slice_(completions_mask, (row, None)) - - # ``dp4a.u32.u32`` consumes a packed-u8x4 register. With VEC=4 each - # thread loads 4 contiguous int8 bytes per iteration, so we recast - # the fragment as a single ``Uint32`` view and feed it directly - # into dp4a — one instruction sums all 4 bytes, vs the previous - # cast+reduce which emitted 4 widens + 3 adds per iteration. 
- ones_packed = cutlass.Uint32(0x01010101) - - if col_block < num_full_tiles: - mask_slab = cute.local_tile(mask_row, (tile_n,), (col_block,)) - for k in cutlass.range(iters, unroll_full=True): - sub_idx = tid + k * block_size - mask_src = cute.local_tile(mask_slab, (VEC,), (sub_idx,)) - mask_frag = cute.make_fragment_like(mask_src) - cute.copy(g2r_mask_atom, mask_src, mask_frag) - packed = cute.recast_tensor(mask_frag, cutlass.Uint32)[0] - local_valid_sum = _dp4a_u32_acc_s32(packed, ones_packed, local_valid_sum) - else: - mask_slab = cute.local_tile(mask_row, (tile_n,), (col_block,)) - for k in cutlass.range(iters, unroll_full=True): - sub_idx = tid + k * block_size - chunk_base = sub_idx * VEC - if chunk_base < tail_len: - mask_src = cute.local_tile(mask_slab, (VEC,), (sub_idx,)) - pred = cute.make_rmem_tensor(mask_src.shape, cutlass.Boolean) - for v in cutlass.range(VEC, unroll_full=True): - pred[v] = cute.elem_less(chunk_base + v, tail_len) - mask_frag = cute.make_fragment_like(mask_src) - mask_frag.fill(0) - cute.copy(g2r_mask_atom, mask_src, mask_frag, pred=pred) - packed = cute.recast_tensor(mask_frag, cutlass.Uint32)[0] - local_valid_sum = _dp4a_u32_acc_s32(packed, ones_packed, local_valid_sum) - - # Warp + cross-warp reduction (same pattern as main kernel). 
- warp_valid = cute.arch.warp_reduction(local_valid_sum, operator.add) - smem = cutlass.utils.SmemAllocator() - buf_valid = smem.allocate_tensor(cutlass.Int32, cute.make_layout(NUM_WARPS)) - - lane_idx = cute.arch.lane_idx() - warp_idx = cute.arch.warp_idx() - - if lane_idx == 0: - buf_valid[warp_idx] = warp_valid - cute.arch.barrier() - - if warp_idx == 0: - val_v = cutlass.Int32(0) - if lane_idx < NUM_WARPS: - val_v = buf_valid[lane_idx] - block_valid = cute.arch.warp_reduction(val_v, operator.add, threads_in_group=NUM_WARPS) - - if lane_idx == 0: - valid_ptr = valid_acc.iterator.toint().ir_value() # ty: ignore[unresolved-attribute] - counter_ptr = mask_counter.iterator.toint().ir_value() # ty: ignore[unresolved-attribute] - - _atomic_add_s32_gmem(valid_ptr, block_valid) - cute.arch.fence_acq_rel_gpu() - old = _atomic_inc_u32_gmem(counter_ptr, total_blocks - 1) - - if old == total_blocks - 1: - # Clamp to >=1.0 so a fully-masked batch (n_valid=0) - # produces ``loss=0`` instead of inf/NaN — matches - # TRL's ``mask.sum().clamp(min=1)`` semantics. - n_valid = cute.arch.fmax(cutlass.Float32(valid_acc[0]), cutlass.Float32(1.0)) - inv_total_tensor[0] = cute.arch.rcp_approx(n_valid) - valid_acc[0] = cutlass.Int32(0) - - return _mask_sum_kernel - - -def _make_bnpo_kernel( - compute_kl: bool, - compute_backward: bool, - tile_n: int, -) -> Callable[..., None]: - """Return a ``@cute.kernel`` specialised on compile-time flags. - - The returned kernel captures *compute_kl*, *compute_backward*, and - *tile_n* in its closure. ``cutlass.const_expr`` evaluates the booleans - at trace time so dead branches are eliminated from the compiled PTX. - ``tile_n`` is a Python ``int`` captured at trace time, so the same - factory can emit two specialised kernels (small / large tile) — see - :func:`create_compiled_bnpo_loss` for dispatch. 
- - When *compute_backward* is True the kernel additionally writes - ``dpolicy = dL/d(policy_logprobs)`` to GMEM in the same inner loop — - no extra HBM reads of the inputs. Because every block must scale - ``dpolicy`` by ``inv_total`` mid-loop, the on-GPU last-block computation - of ``inv_total`` from the masked accumulator does **not** compose with - backward; the bundled mask-sum kernel populates ``inv_total_tensor`` - before the main kernel runs. - - When *compute_backward* is False the kernel accumulates the - mask-element count via ``valid_acc`` and computes - ``inv_total = 1 / n_valid`` on-GPU in the last-block path — no - host-side ``completions_mask.sum()`` required. - """ - - @cute.kernel - def _bnpo_loss_kernel( - policy: cute.Tensor, - old_policy: cute.Tensor, - ref: cute.Tensor, - advantages: cute.Tensor, - completions_mask: cute.Tensor, - dpolicy: cute.Tensor, # (bs, seq_len) when compute_backward; (bs, 1) dummy otherwise - inv_total_tensor: cute.Tensor, # (1,) fp32 — caller-populated 1/n_valid - policy_acc: cute.Tensor, - kl_acc: cute.Tensor, - valid_acc: cute.Tensor, # (1,) int32 — mask-element count accumulator - counter: cute.Tensor, - output: cute.Tensor, - epsilon: cutlass.Float32, - epsilon_high: cutlass.Float32, - beta: cutlass.Float32, - total_blocks: cutlass.Int32, - num_full_tiles: cutlass.Int32, - tail_len: cutlass.Int32, - ) -> None: - block_size = NUM_WARPS * 32 - iters = tile_n // (block_size * VEC) - - # Read inv_total from GMEM once per block (hoisted, single load). - # Skipped on the fwd-only path which uses an on-GPU last-block - # computation from the valid_acc accumulator instead. On the - # compute_backward path the bundled mask-sum kernel writes - # ``1 / completions_mask.sum()`` into ``inv_total_tensor`` before - # this kernel runs, so the load returns the pre-inverted scalar. 
- accumulate_valid = not compute_backward - if cutlass.const_expr(not accumulate_valid): - inv_total = cast(inv_total_tensor[0], cutlass.Float32) - - _no_alloc = cute.nvgpu.CacheEvictionPriority.NO_ALLOCATE - g2r_op = cute.nvgpu.CopyUniversalOp() - g2r_atom = cute.make_copy_atom( - g2r_op, - policy.element_type, - num_bits_per_copy=0, - l1c_evict_priority=_no_alloc, - ) - g2r_mask_atom = cute.make_copy_atom( - g2r_op, - completions_mask.element_type, - num_bits_per_copy=0, - l1c_evict_priority=_no_alloc, - ) - if cutlass.const_expr(compute_backward): - r2g_atom = cute.make_copy_atom( - g2r_op, - dpolicy.element_type, - num_bits_per_copy=0, - ) - - row = cute.arch.block_idx()[0] - col_block = cute.arch.block_idx()[1] - tid = cute.arch.thread_idx()[0] - - adv = cast(advantages[row], cutlass.Float32) - lo = cutlass.Float32(1.0) - epsilon - hi = cutlass.Float32(1.0) + epsilon_high - - local_policy_sum = cutlass.Float32(0.0) - local_kl_sum = cutlass.Float32(0.0) - # mask_vec is already cast to fp32 for loss/kl multiplications, so - # accumulate valid in fp32 too (avoids a separate i8→i32 reduction). - # Cast to int32 only at the atomic boundary so the shared - # ``valid_acc`` global can remain int32 — see ``_atomic_add_s32_gmem``. 
- local_valid_sum = cutlass.Float32(0.0) - - pol_row = cute.slice_(policy, (row, None)) - old_row = cute.slice_(old_policy, (row, None)) - - if cutlass.const_expr(compute_kl): - ref_row = cute.slice_(ref, (row, None)) - - mask_row = cute.slice_(completions_mask, (row, None)) - - if cutlass.const_expr(compute_backward): - dp_row = cute.slice_(dpolicy, (row, None)) - - # ---- Full-tile vectorised path (LDG.128) ---- - if col_block < num_full_tiles: - pol_slab = cute.local_tile(pol_row, (tile_n,), (col_block,)) - old_slab = cute.local_tile(old_row, (tile_n,), (col_block,)) - - if cutlass.const_expr(compute_kl): - ref_slab = cute.local_tile(ref_row, (tile_n,), (col_block,)) - - mask_slab = cute.local_tile(mask_row, (tile_n,), (col_block,)) - - if cutlass.const_expr(compute_backward): - dp_slab = cute.local_tile(dp_row, (tile_n,), (col_block,)) - - for k in cutlass.range(iters, unroll_full=True): - sub_idx = tid + k * block_size - - pol_src = cute.local_tile(pol_slab, (VEC,), (sub_idx,)) - old_src = cute.local_tile(old_slab, (VEC,), (sub_idx,)) - pol_frag = cute.make_fragment_like(pol_src) - old_frag = cute.make_fragment_like(old_src) - cute.copy(g2r_atom, pol_src, pol_frag) - cute.copy(g2r_atom, old_src, old_frag) - - if cutlass.const_expr(compute_kl): - ref_src = cute.local_tile(ref_slab, (VEC,), (sub_idx,)) - ref_frag = cute.make_fragment_like(ref_src) - cute.copy(g2r_atom, ref_src, ref_frag) - - mask_src = cute.local_tile(mask_slab, (VEC,), (sub_idx,)) - mask_frag = cute.make_fragment_like(mask_src) - cute.copy(g2r_mask_atom, mask_src, mask_frag) - - pol_vec = pol_frag.load().to(cutlass.Float32) - old_vec = old_frag.load().to(cutlass.Float32) - - log_ratio = pol_vec - old_vec - ratio = cute.math.exp2(log_ratio * _LOG2_E, fastmath=True) - surrogate = ratio * adv - clipped_ratio = cute.where( - ratio < lo, - lo, - cute.where(ratio > hi, hi, ratio), - ) - clipped = clipped_ratio * adv - policy_loss = -cute.where(surrogate < clipped, surrogate, clipped) - - if 
cutlass.const_expr(compute_kl): - ref_vec = ref_frag.load().to(cutlass.Float32) - log_ratio_ref = ref_vec - pol_vec - ratio_ref = cute.math.exp2(log_ratio_ref * _LOG2_E, fastmath=True) - # FFMA-friendly rearrangement: ``(ratio_ref - 1) - log_ratio_ref`` - # exposes a ``ratio_ref + (-1)`` pair that ptxas folds with - # the subsequent subtract — same arithmetic, fewer FADDs - # surviving SASS than the original 3-term ``a - b - c``. - kl_val = (ratio_ref - cutlass.Float32(1.0)) - log_ratio_ref - - mask_vec = mask_frag.load().to(cutlass.Float32) - local_policy_sum += (policy_loss * mask_vec).reduce( - cute.ReductionOp.ADD, - cutlass.Float32(0.0), - reduction_profile=0, - ) - if cutlass.const_expr(not compute_backward): - local_valid_sum += mask_vec.reduce( - cute.ReductionOp.ADD, - cutlass.Float32(0.0), - reduction_profile=0, - ) - if cutlass.const_expr(compute_kl): - local_kl_sum += (kl_val * mask_vec).reduce( - cute.ReductionOp.ADD, - cutlass.Float32(0.0), - reduction_profile=0, - ) - - # ---- Backward: write scaled dpolicy slab in same loop ---- - # use_unclipped = (surrogate <= clipped) — matches torch's - # convention. d/d(policy) of -min(surrogate, clipped) is - # -adv*ratio when use_unclipped, else 0 (clamp grad = 0). - # ``-(adv * ratio)`` is just ``-surrogate`` (already in - # scope) — saves one FMUL per element. - # KL term: d/d(policy) of (ratio_ref - log_ratio_ref - 1) - # = -(ratio_ref - 1) = 1 - ratio_ref. - if cutlass.const_expr(compute_backward): - neg_surrogate_grad = cute.where( - surrogate <= clipped, - -surrogate, - cutlass.Float32(0.0), - ) - if cutlass.const_expr(compute_kl): - # ``beta - beta*ratio_ref`` instead of ``beta*(1 - ratio_ref)`` - # gives ptxas an obvious FFMA pattern (``FFMA -beta, - # ratio_ref, beta``) — saves one FMUL per element vs - # the (1 - ratio_ref) intermediate. 
- kl_grad = beta - beta * ratio_ref - dpolicy_vec = neg_surrogate_grad + kl_grad - else: - dpolicy_vec = neg_surrogate_grad - dpolicy_vec = dpolicy_vec * mask_vec - dpolicy_vec = dpolicy_vec * inv_total - - dp_dst = cute.local_tile(dp_slab, (VEC,), (sub_idx,)) - dp_frag = cute.make_fragment_like(dp_dst) - dp_frag.store(dpolicy_vec.to(dpolicy.element_type)) - cute.copy(r2g_atom, dp_frag, dp_dst) - - else: - # ---- Predicated vector tail path (< tile_n valid elements) ---- - pol_slab = cute.local_tile(pol_row, (tile_n,), (col_block,)) - old_slab = cute.local_tile(old_row, (tile_n,), (col_block,)) - - if cutlass.const_expr(compute_kl): - ref_slab = cute.local_tile(ref_row, (tile_n,), (col_block,)) - - mask_slab = cute.local_tile(mask_row, (tile_n,), (col_block,)) - - if cutlass.const_expr(compute_backward): - dp_slab = cute.local_tile(dp_row, (tile_n,), (col_block,)) - - for k in cutlass.range(iters, unroll_full=True): - sub_idx = tid + k * block_size - chunk_base = sub_idx * VEC - - if chunk_base < tail_len: - pol_src = cute.local_tile(pol_slab, (VEC,), (sub_idx,)) - old_src = cute.local_tile(old_slab, (VEC,), (sub_idx,)) - pred = cute.make_rmem_tensor(pol_src.shape, cutlass.Boolean) - for v in cutlass.range(VEC, unroll_full=True): - pred[v] = cute.elem_less(chunk_base + v, tail_len) - - pol_frag = cute.make_fragment_like(pol_src) - old_frag = cute.make_fragment_like(old_src) - pol_frag.fill(0.0) - old_frag.fill(0.0) - cute.copy(g2r_atom, pol_src, pol_frag, pred=pred) - cute.copy(g2r_atom, old_src, old_frag, pred=pred) - - if cutlass.const_expr(compute_kl): - ref_src = cute.local_tile(ref_slab, (VEC,), (sub_idx,)) - ref_frag = cute.make_fragment_like(ref_src) - ref_frag.fill(0.0) - cute.copy(g2r_atom, ref_src, ref_frag, pred=pred) - - mask_src = cute.local_tile(mask_slab, (VEC,), (sub_idx,)) - mask_frag = cute.make_fragment_like(mask_src) - mask_frag.fill(0) - cute.copy(g2r_mask_atom, mask_src, mask_frag, pred=pred) - - pol_vec = pol_frag.load().to(cutlass.Float32) - 
old_vec = old_frag.load().to(cutlass.Float32) - valid_vec = cute.where( - pred.load(), - cute.full_like(pol_vec, cutlass.Float32(1.0)), - cute.zeros_like(pol_vec, dtype=cutlass.Float32), - ) - - log_ratio = pol_vec - old_vec - ratio = cute.math.exp2(log_ratio * _LOG2_E, fastmath=True) - surrogate = ratio * adv - clipped_ratio = cute.where( - ratio < lo, - lo, - cute.where(ratio > hi, hi, ratio), - ) - clipped = clipped_ratio * adv - policy_loss = -cute.where(surrogate < clipped, surrogate, clipped) - - if cutlass.const_expr(compute_kl): - ref_vec = ref_frag.load().to(cutlass.Float32) - log_ratio_ref = ref_vec - pol_vec - ratio_ref = cute.math.exp2(log_ratio_ref * _LOG2_E, fastmath=True) - # FFMA-friendly rearrangement — see full-tile path. - kl_val = (ratio_ref - cutlass.Float32(1.0)) - log_ratio_ref - - mask_vec = mask_frag.load().to(cutlass.Float32) * valid_vec - local_policy_sum += (policy_loss * mask_vec).reduce( - cute.ReductionOp.ADD, - cutlass.Float32(0.0), - reduction_profile=0, - ) - if cutlass.const_expr(not compute_backward): - local_valid_sum += mask_vec.reduce( - cute.ReductionOp.ADD, - cutlass.Float32(0.0), - reduction_profile=0, - ) - if cutlass.const_expr(compute_kl): - local_kl_sum += (kl_val * mask_vec).reduce( - cute.ReductionOp.ADD, - cutlass.Float32(0.0), - reduction_profile=0, - ) - - # ---- Backward: predicated dpolicy slab write ---- - # Same gradient math as the full-tile path. ``valid_vec`` - # already encodes the in-bounds predicate (1.0 inside, - # 0.0 outside) and is folded into ``mask_vec``, so - # multiplying by it zeros out the padded positions. 
- if cutlass.const_expr(compute_backward): - neg_surrogate_grad = cute.where( - surrogate <= clipped, - -surrogate, - cutlass.Float32(0.0), - ) - if cutlass.const_expr(compute_kl): - kl_grad = beta - beta * ratio_ref - dpolicy_vec = neg_surrogate_grad + kl_grad - else: - dpolicy_vec = neg_surrogate_grad - dpolicy_vec = dpolicy_vec * mask_vec - dpolicy_vec = dpolicy_vec * inv_total - - dp_dst = cute.local_tile(dp_slab, (VEC,), (sub_idx,)) - dp_frag = cute.make_fragment_like(dp_dst) - dp_frag.store(dpolicy_vec.to(dpolicy.element_type)) - cute.copy(r2g_atom, dp_frag, dp_dst, pred=pred) - - # ---- Stage 1: Intra-warp reduction (butterfly XOR shuffles) ---- - warp_policy = cute.arch.warp_reduction(local_policy_sum, operator.add) - if cutlass.const_expr(compute_kl): - warp_kl = cute.arch.warp_reduction(local_kl_sum, operator.add) - - smem = cutlass.utils.SmemAllocator() - buf_policy = smem.allocate_tensor(cutlass.Float32, cute.make_layout(NUM_WARPS)) - if cutlass.const_expr(compute_kl): - buf_kl = smem.allocate_tensor(cutlass.Float32, cute.make_layout(NUM_WARPS)) - - lane_idx = cute.arch.lane_idx() - warp_idx = cute.arch.warp_idx() - - # When compute_backward is True the bundled mask-sum kernel populates - # inv_total_tensor before this kernel runs, so on-GPU mask-element - # accumulation is dead code. 
- if cutlass.const_expr(accumulate_valid): - warp_valid = cute.arch.warp_reduction(local_valid_sum, operator.add) - buf_valid = smem.allocate_tensor(cutlass.Float32, cute.make_layout(NUM_WARPS)) - - # ---- Stage 2: Cross-warp reduction via SMEM ---- - if lane_idx == 0: - buf_policy[warp_idx] = warp_policy - if cutlass.const_expr(compute_kl): - buf_kl[warp_idx] = warp_kl - if cutlass.const_expr(accumulate_valid): - buf_valid[warp_idx] = warp_valid - cute.arch.barrier() - - if warp_idx == 0: - val_p = cutlass.Float32(0.0) - if lane_idx < NUM_WARPS: - val_p = buf_policy[lane_idx] - - block_policy = cute.arch.warp_reduction(val_p, operator.add, threads_in_group=NUM_WARPS) - - if cutlass.const_expr(compute_kl): - val_k = cutlass.Float32(0.0) - if lane_idx < NUM_WARPS: - val_k = buf_kl[lane_idx] - block_kl = cute.arch.warp_reduction(val_k, operator.add, threads_in_group=NUM_WARPS) - - if cutlass.const_expr(accumulate_valid): - val_v = cutlass.Float32(0.0) - if lane_idx < NUM_WARPS: - val_v = buf_valid[lane_idx] - block_valid = cute.arch.warp_reduction( - val_v, operator.add, threads_in_group=NUM_WARPS - ) - - # ---- Stage 3: Cross-CTA atomic accumulation ---- - if lane_idx == 0: - policy_ptr = policy_acc.iterator.toint().ir_value() # ty: ignore[unresolved-attribute] - counter_ptr = counter.iterator.toint().ir_value() # ty: ignore[unresolved-attribute] - - _atomic_add_f32_gmem(policy_ptr, block_policy) - - if cutlass.const_expr(compute_kl): - kl_ptr = kl_acc.iterator.toint().ir_value() # ty: ignore[unresolved-attribute] - _atomic_add_f32_gmem(kl_ptr, block_kl) - - if cutlass.const_expr(accumulate_valid): - valid_ptr = valid_acc.iterator.toint().ir_value() # ty: ignore[unresolved-attribute] - # valid_acc is int32. Per-block sums of int8 0/1 values - # fit exactly in fp32 (≤ tile_n ≤ 4096 ≪ 2²⁴) so the - # cast is bit-exact. 
- _atomic_add_s32_gmem(valid_ptr, cutlass.Int32(block_valid)) - - cute.arch.fence_acq_rel_gpu() - - old = _atomic_inc_u32_gmem(counter_ptr, total_blocks - 1) - - if old == total_blocks - 1: - pol_sum = policy_acc[0] - - if cutlass.const_expr(accumulate_valid): - # Clamp to >=1.0 so a fully-masked batch (n_valid=0) - # produces ``loss=0`` instead of inf/NaN — matches - # TRL's ``mask.sum().clamp(min=1)`` semantics. - n_valid = cute.arch.fmax( - cutlass.Float32(valid_acc[0]), cutlass.Float32(1.0) - ) - inv_total_computed = cute.arch.rcp_approx(n_valid) - else: - # compute_backward path: bundled mask-sum kernel - # already wrote the inverse so forward and backward - # share the same scalar. - inv_total_computed = inv_total - - if cutlass.const_expr(compute_kl): - kl_sum = kl_acc[0] - loss = (pol_sum + beta * kl_sum) * inv_total_computed - else: - loss = pol_sum * inv_total_computed - output[0] = cast(loss, output.element_type) # ty: ignore[invalid-argument-type] - - # Reset accumulators for the next invocation. - # Counter self-resets via atom.inc wrap-around. - policy_acc[0] = cutlass.Float32(0.0) - if cutlass.const_expr(compute_kl): - kl_acc[0] = cutlass.Float32(0.0) - if cutlass.const_expr(accumulate_valid): - valid_acc[0] = cutlass.Int32(0) - - return _bnpo_loss_kernel - - -def create_compiled_bnpo_loss( - policy_dtype: torch.dtype, - epsilon: float, - epsilon_high: float, - beta: float, - compute_backward: bool = False, -) -> Callable[..., torch.Tensor | tuple[torch.Tensor, torch.Tensor]]: - """Compile the bnpo loss kernel for a given dtype/KL/backward configuration. - - The runner allocates per-call scratch (``output``, ``inv_total``, and on - the fwd+bwd path ``dpolicy``) inside ``_run`` itself; cross-CTA scratch - (atomic accumulators + counters) is allocated lazily on first call from - the input device and self-resets each launch via the kernel's last-block - epilogue + ``atom.inc.u32`` wrap-around. 
- """ - compute_kl = beta != 0.0 - - if policy_dtype not in _TORCH_TO_CUTLASS_DTYPE: - raise ValueError(f"Unsupported dtype for bnpo kernel: {policy_dtype}") - - tile_n_small = TILE_N - tile_n_large = TILE_N_LARGE - seq_len_threshold = TILE_N_LARGE_THRESHOLD - block_size = NUM_WARPS * 32 - if tile_n_small % (block_size * VEC) != 0: - raise ValueError( - f"TILE_N={tile_n_small} must be a multiple of BLOCK_SIZE*VEC={block_size * VEC}" - ) - if tile_n_large % (block_size * VEC) != 0: - raise ValueError( - f"TILE_N_LARGE={tile_n_large} must be a multiple of BLOCK_SIZE*VEC={block_size * VEC}" - ) - - bs_sym = cute.sym_int() - seq_len_sym = cute.sym_int() - cute_dtype = _TORCH_TO_CUTLASS_DTYPE[policy_dtype] - - def _fake2d(dt: Any, cols: Any) -> Any: - return cute.runtime.make_fake_compact_tensor( - dt, - (bs_sym, cols), - stride_order=(1, 0), - assumed_align=16, - ) - - fake_pol = _fake2d(cute_dtype, seq_len_sym) - fake_old = _fake2d(cute_dtype, seq_len_sym) - fake_ref = _fake2d(cute_dtype, seq_len_sym) - fake_adv = cute.runtime.make_fake_compact_tensor( - cute_dtype, - (bs_sym,), - assumed_align=16, - ) - fake_mask = cute.runtime.make_fake_compact_tensor( - cutlass.Int8, - (bs_sym, seq_len_sym), - stride_order=(1, 0), - assumed_align=16, - ) - dpolicy_cols = seq_len_sym if compute_backward else 1 - fake_dpolicy = cute.runtime.make_fake_compact_tensor( - cute_dtype, - (bs_sym, dpolicy_cols), - stride_order=(1, 0), - assumed_align=16, - ) - fake_scalar_f32 = cute.runtime.make_fake_compact_tensor( - cutlass.Float32, - (1,), - assumed_align=16, - ) - fake_valid_acc = cute.runtime.make_fake_compact_tensor( - cutlass.Int32, - (1,), - assumed_align=16, - ) - fake_counter = cute.runtime.make_fake_compact_tensor( - cutlass.Int32, - (1,), - assumed_align=16, - ) - fake_mask_counter = cute.runtime.make_fake_compact_tensor( - cutlass.Int32, - (1,), - assumed_align=16, - ) - fake_output = cute.runtime.make_fake_compact_tensor( - cute_dtype, - (1,), - assumed_align=16, - ) - - def 
_build_launch(tile_n_v: int) -> Callable[..., None]: - """Build a ``@cute.jit`` ``_launch`` for a given ``tile_n``. - - Captures *tile_n_v* via closure; both the main kernel and the - (optional) mask-sum kernel are specialised to this tile size. - One ``_launch`` per tier; the runner dispatches at call time. - """ - specialized_kernel = _make_bnpo_kernel(compute_kl, compute_backward, tile_n_v) - if compute_backward: - mask_sum_kernel = _make_mask_sum_kernel(tile_n_v) - - @cute.jit - def _launch( - pol_ct: cute.Tensor, - old_ct: cute.Tensor, - ref_ct: cute.Tensor, - adv_ct: cute.Tensor, - mask_ct: cute.Tensor, - dpolicy_ct: cute.Tensor, - inv_total_ct: cute.Tensor, - policy_acc_ct: cute.Tensor, - kl_acc_ct: cute.Tensor, - valid_acc_ct: cute.Tensor, - counter_ct: cute.Tensor, - mask_counter_ct: cute.Tensor, - output_ct: cute.Tensor, - epsilon_v: cutlass.Float32, - epsilon_high_v: cutlass.Float32, - beta_v: cutlass.Float32, - total_blocks_v: cutlass.Int32, - num_full_tiles_v: cutlass.Int32, - tail_len_v: cutlass.Int32, - num_col_tiles_v: cutlass.Int32, - ) -> None: - bs_v = pol_ct.shape[0] # ty: ignore[not-subscriptable] - # Bundled mask-sum (compute_backward only) — writes - # ``1 / completions_mask.sum()`` into ``inv_total_ct`` before the - # main kernel reads it. Both kernels in one tvm-ffi dispatch - # eliminates the per-call ``torch.sum`` + reciprocal round trip. 
- if cutlass.const_expr(compute_backward): - mask_sum_kernel( # ty: ignore[unresolved-attribute] - mask_ct, - inv_total_ct, - valid_acc_ct, - mask_counter_ct, - total_blocks_v, - num_full_tiles_v, - tail_len_v, - ).launch( - grid=(bs_v, num_col_tiles_v, 1), - block=(NUM_WARPS * 32, 1, 1), - ) - specialized_kernel( # ty: ignore[unresolved-attribute] - pol_ct, - old_ct, - ref_ct, - adv_ct, - mask_ct, - dpolicy_ct, - inv_total_ct, - policy_acc_ct, - kl_acc_ct, - valid_acc_ct, - counter_ct, - output_ct, - epsilon_v, - epsilon_high_v, - beta_v, - total_blocks_v, - num_full_tiles_v, - tail_len_v, - ).launch( - grid=(bs_v, num_col_tiles_v, 1), - block=(NUM_WARPS * 32, 1, 1), - ) - - return _launch - - def _compile_launch(launch_fn: Callable[..., None]) -> Callable[..., None]: - return cute.compile( - launch_fn, - fake_pol, - fake_old, - fake_ref, - fake_adv, - fake_mask, - fake_dpolicy, - fake_scalar_f32, - fake_scalar_f32, - fake_scalar_f32, - fake_valid_acc, - fake_counter, - fake_mask_counter, - fake_output, - cutlass.Float32(epsilon), - cutlass.Float32(epsilon_high), - cutlass.Float32(beta), - cutlass.Int32(1), - cutlass.Int32(1), - cutlass.Int32(0), - cutlass.Int32(1), - options="--enable-tvm-ffi", - ) - - compiled_small = _compile_launch(_build_launch(tile_n_small)) - if tile_n_large == tile_n_small: - compiled_large = compiled_small - else: - compiled_large = _compile_launch(_build_launch(tile_n_large)) - - eps_const = cutlass.Float32(epsilon) - eps_high_const = cutlass.Float32(epsilon_high) - beta_const = cutlass.Float32(beta) - - # Cross-CTA scratch slab — one int32 buffer with stride-4 (16-byte) slices - # so each slot is individually 16-byte aligned (``assumed_align=16`` at - # compile time). Bit-pattern of int32 0 equals fp32 0.0, so a single - # ``zeros`` factory legitimately initialises both the int32 counters and - # the fp32 accumulators. 
The kernel's last block self-resets accumulators - # in its epilogue and the counters self-reset via ``atom.inc.u32`` - # wrap-around, so the up-front ``torch.zeros`` only matters for the very - # first call. - _scratch: list[torch.Tensor | None] = [None] - - def _ensure_scratch(device: torch.device) -> tuple[torch.Tensor, ...]: - s = _scratch[0] - if s is None or s.device != device: - s = torch.zeros(20, dtype=torch.int32, device=device) - _scratch[0] = s - return ( - s[0:1], # counter (int32) - s[4:5], # mask_counter (int32) - s[8:9], # valid_acc (int32) - s[12:13].view(torch.float32), # policy_acc (fp32) - s[16:17].view(torch.float32), # kl_acc (fp32) - ) - - def _run( - policy_logprobs_r: torch.Tensor, - old_policy_logprobs_r: torch.Tensor, - ref_logprobs_r: torch.Tensor, - advantages_r: torch.Tensor, - completions_mask_r: torch.Tensor, - ) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]: - bs, seq_len = policy_logprobs_r.shape - device = policy_logprobs_r.device - dtype = policy_logprobs_r.dtype - - # Tier dispatch: long sequences pay too much last-block-detection - # latency under the small-tile grid, so swap to the large-tile - # compiled variant. - if seq_len >= seq_len_threshold: - tile_n_active = tile_n_large - compiled_active = compiled_large - else: - tile_n_active = tile_n_small - compiled_active = compiled_small - num_full_tiles = seq_len // tile_n_active - tail_len = seq_len % tile_n_active - num_col_tiles = num_full_tiles + (1 if tail_len > 0 else 0) - total_blocks = bs * num_col_tiles - - # Per-call write-only buffers — ``empty`` is enough (Liger / TE - # pattern). ``inv_total`` is populated by the bundled mask-sum - # kernel (compute_backward path) or by the main kernel's last-block - # trick (fwd-only path); the runner never reads it. 
- output_r = torch.empty(1, dtype=dtype, device=device) - inv_total_r = torch.empty(1, dtype=torch.float32, device=device) - if compute_backward: - dpolicy_r = torch.empty_like(policy_logprobs_r) - else: - dpolicy_r = torch.empty(bs, 1, dtype=dtype, device=device) - - counter_r, mask_counter_r, valid_acc_r, policy_acc_r, kl_acc_r = _ensure_scratch(device) - - compiled_active( - policy_logprobs_r, - old_policy_logprobs_r, - ref_logprobs_r, - advantages_r, - completions_mask_r, - dpolicy_r, - inv_total_r, - policy_acc_r, - kl_acc_r, - valid_acc_r, - counter_r, - mask_counter_r, - output_r, - eps_const, - eps_high_const, - beta_const, - total_blocks, - num_full_tiles, - tail_len, - num_col_tiles, - ) - out_view = output_r.view(()) - if compute_backward: - return out_view, dpolicy_r - return out_view - - return _run - - -# --------------------------------------------------------------------------- -# Fused forward + backward — direct (loss, grad) runner, no autograd -# --------------------------------------------------------------------------- - - -def create_compiled_bnpo_loss_with_backward( - policy_dtype: torch.dtype, - epsilon: float, - epsilon_high: float, - beta: float, -) -> Callable[..., tuple[torch.Tensor, torch.Tensor]]: - """Compile the fused fwd+bwd bnpo kernel and return a tuple-returning runner. - - The returned callable runs one training-step worth of work: a single - ``@cute.jit`` dispatch produces both the scalar loss and the scaled - ``dL/d(policy_logprobs)`` tensor. It returns ``(loss, dpolicy)`` directly - — no ``torch.autograd.Function`` wrapper, no extra ``grad_output * dpolicy`` - backward kernel. Callers that need autograd integration (so - ``loss.backward()`` works) wrap this themselves at the public-API layer; - callers that control gradient flow manually (benchmarks, custom training - loops) can use it as-is for zero overhead. 
- - ``inv_total`` is computed entirely on-GPU by a bundled mask-sum kernel - that runs in series with the main kernel inside the same ``@cute.jit`` - launch — no host sync, no extra ``torch.sum`` dispatch, CUDA-graph - compatible. - """ - return _typing_cast( - "Callable[..., tuple[torch.Tensor, torch.Tensor]]", - create_compiled_bnpo_loss( - policy_dtype=policy_dtype, - epsilon=epsilon, - epsilon_high=epsilon_high, - beta=beta, - compute_backward=True, - ), - ) diff --git a/build/torch-cuda/geometric_ai_kernels/__init__.py b/build/torch-cuda/geometric_ai_kernels/__init__.py deleted file mode 100644 index a9b2672c1cd85b74c1b3ded0fc0b2100e1aeac23..0000000000000000000000000000000000000000 --- a/build/torch-cuda/geometric_ai_kernels/__init__.py +++ /dev/null @@ -1,26 +0,0 @@ -import ctypes -import importlib.util -import sys -from pathlib import Path -from types import ModuleType - - -def _import_from_path(file_path: Path) -> ModuleType: - # We cannot use the module name as-is, after adding it to `sys.modules`, - # it would also be used for other imports. So, we make a module name that - # depends on the path for it to be unique using the hex-encoded hash of - # the path. 
- path_hash = "{:x}".format(ctypes.c_size_t(hash(file_path.absolute())).value) - module_name = path_hash - spec = importlib.util.spec_from_file_location(module_name, file_path) - if spec is None: - raise ImportError(f"Cannot load spec for {module_name} from {file_path}") - module = importlib.util.module_from_spec(spec) - if module is None: - raise ImportError(f"Cannot load module {module_name} from spec") - sys.modules[module_name] = module - spec.loader.exec_module(module) # type: ignore - return module - - -globals().update(vars(_import_from_path(Path(__file__).parent.parent / "__init__.py"))) diff --git a/build/torch-cuda/grpo_loss/__init__.py b/build/torch-cuda/grpo_loss/__init__.py deleted file mode 100644 index 24ef3038feeefd2ac304ae1ae25105c8bbb48fff..0000000000000000000000000000000000000000 --- a/build/torch-cuda/grpo_loss/__init__.py +++ /dev/null @@ -1,169 +0,0 @@ -"""GRPO loss with CuteDSL fused fwd+bwd. - -Three public APIs: - -* :func:`grpo_loss` — fused fwd+bwd. Returns - ``(loss, grad_policy_logprobs)`` from a single ``@cute.jit`` dispatch. - Caller chains via ``policy_logprobs.backward(grad)``. -* :func:`grpo_loss_fwd` — forward-only (inference / validation). - Returns scalar ``loss`` and skips the dpolicy buffer entirely. -* :func:`grpo_loss_autograd` — autograd-aware via - ``torch.library.custom_op``. ``loss.backward()`` works and composes - with ``torch.compile``. ~12µs of dispatcher overhead vs. - :func:`grpo_loss`. - -GRPO requires ``completions_mask``; the per-response normalization -formula is mask-derived. The cute kernel uses one CTA per row so the -per-row mask sum is reduced inside the block — no cross-CTA atomics -or last-block detection on the per-row scaling pass. - -Per-call output and gradient buffers are allocated inside the runner; -cross-CTA scratch (the ``policy_acc`` accumulator + last-block counter) -is owned by the compiled-kernel closure and self-resets each launch. 
-""" - -from __future__ import annotations - -from functools import lru_cache -from typing import TYPE_CHECKING, cast - -import torch - -from .cute_grpo_loss import create_compiled_grpo_loss - -if TYPE_CHECKING: - from collections.abc import Callable - - -__all__ = ["grpo_loss", "grpo_loss_autograd", "grpo_loss_fwd"] - - -@lru_cache(maxsize=32) -def _get_compiled_fwd( - dtype: torch.dtype, - epsilon: float, - epsilon_high: float, - beta: float, -) -> Callable[..., torch.Tensor]: - return cast( - "Callable[..., torch.Tensor]", - create_compiled_grpo_loss( - policy_dtype=dtype, - epsilon=epsilon, - epsilon_high=epsilon_high, - beta=beta, - compute_backward=False, - ), - ) - - -@lru_cache(maxsize=32) -def _get_compiled_fwd_bwd( - dtype: torch.dtype, - epsilon: float, - epsilon_high: float, - beta: float, -) -> Callable[..., tuple[torch.Tensor, torch.Tensor]]: - return cast( - "Callable[..., tuple[torch.Tensor, torch.Tensor]]", - create_compiled_grpo_loss( - policy_dtype=dtype, - epsilon=epsilon, - epsilon_high=epsilon_high, - beta=beta, - compute_backward=True, - ), - ) - - -def _mask_to_int8(completions_mask: torch.Tensor) -> torch.Tensor: - return ( - completions_mask - if completions_mask.dtype == torch.int8 - else completions_mask.to(torch.int8) - ) - - -def grpo_loss_fwd( - policy_logprobs: torch.Tensor, - old_policy_logprobs: torch.Tensor, - ref_logprobs: torch.Tensor, - advantages: torch.Tensor, - completions_mask: torch.Tensor, - epsilon: float = 0.2, - epsilon_high: float = 0.2, - beta: float = 0.1, -) -> torch.Tensor: - """Forward-only GRPO loss. Returns the scalar ``loss``. - - Args: - policy_logprobs, old_policy_logprobs, ref_logprobs: ``(bs, seq_len)``. - advantages: ``(bs,)``. - completions_mask: Bool / int8 mask ``(bs, seq_len)``. Required. - epsilon, epsilon_high: PPO-style clipping bounds. - beta: KL-penalty coefficient. ``0.0`` compiles away the KL branch. - - Returns: - Scalar tensor (0-dim) with the same dtype as ``policy_logprobs``. 
- """ - run = _get_compiled_fwd( - policy_logprobs.dtype, - float(epsilon), - float(epsilon_high), - float(beta), - ) - return run( - policy_logprobs, - old_policy_logprobs, - ref_logprobs, - advantages, - _mask_to_int8(completions_mask), - ) - - -def grpo_loss( - policy_logprobs: torch.Tensor, - old_policy_logprobs: torch.Tensor, - ref_logprobs: torch.Tensor, - advantages: torch.Tensor, - completions_mask: torch.Tensor, - epsilon: float = 0.2, - epsilon_high: float = 0.2, - beta: float = 0.1, -) -> tuple[torch.Tensor, torch.Tensor]: - """Fused fwd+bwd GRPO loss. Returns ``(loss, grad_policy_logprobs)``. - - Inputs do **not** need ``requires_grad=True``. Chain ``grad`` into - the upstream model via ``policy_logprobs.backward(grad)``. - - Args: - policy_logprobs, old_policy_logprobs, ref_logprobs: ``(bs, seq_len)``. - advantages: ``(bs,)``. - completions_mask: Bool / int8 mask ``(bs, seq_len)``. Required. - epsilon, epsilon_high: PPO-style clipping bounds. - beta: KL-penalty coefficient. ``0.0`` compiles away the KL branch. - - Returns: - ``(loss, grad_policy_logprobs)``. The grad already has the - per-row ``1 / mask.sum(-1).clamp(min=1)`` and across-row - ``1/n_rows`` scalings folded in. The grad tensor is freshly - allocated per call (no shared cache). - """ - run = _get_compiled_fwd_bwd( - policy_logprobs.dtype, - float(epsilon), - float(epsilon_high), - float(beta), - ) - return run( - policy_logprobs, - old_policy_logprobs, - ref_logprobs, - advantages, - _mask_to_int8(completions_mask), - ) - - -# ``autograd.py`` imports ``grpo_loss`` from this module, so the function -# must be fully defined before its import runs. 
-from .autograd import grpo_loss_autograd # noqa: E402 diff --git a/build/torch-cuda/grpo_loss/_torch_ref.py b/build/torch-cuda/grpo_loss/_torch_ref.py deleted file mode 100644 index c5f7d023554763a9ef8cbcae7390890e40f70aea..0000000000000000000000000000000000000000 --- a/build/torch-cuda/grpo_loss/_torch_ref.py +++ /dev/null @@ -1,87 +0,0 @@ -"""Plain-PyTorch GRPO reference shared between bench and tests. - -Mirrors TRL's default per-response normalization variant: per-row mask -sum acts as the divisor for that row's loss before averaging across -rows. Every op is a vanilla torch op so AOTAutograd can derive the -joint fwd+bwd graph and Inductor can fuse both passes. -""" - -from __future__ import annotations - -import torch - - -def torch_grpo_loss( - policy_logprobs: torch.Tensor, - old_policy_logprobs: torch.Tensor, - ref_logprobs: torch.Tensor, - advantages: torch.Tensor, - completions_mask: torch.Tensor, - epsilon: float = 0.2, - epsilon_high: float = 0.2, - beta: float = 0.1, -) -> torch.Tensor: - """Compute the GRPO (Group Relative Policy Optimization) loss. - - Implements TRL's default per-response normalization variant: - - L_GRPO = mean_r( ((policy_loss + beta*kl) * mask).sum(-1) - / mask.sum(-1).clamp(min=1) ) - - Each row (response) is independently normalized by its own valid-token - count, then the per-row losses are averaged across rows. This differs - from the BNPO variant in :func:`torch_bnpo_loss`, which sums numerators - *and* denominators globally before dividing — under variable response - lengths BNPO weights longer responses more heavily, while GRPO weights - every response equally. 
- - **Probability ratio:** - - r_t(theta) = exp(log pi_theta - log pi_{theta_old}) - - **Clipped surrogate (per token):** - - L_CLIP_t = -min( r_t * A, clip(r_t, 1 - eps, 1 + eps_high) * A ) - - **KL divergence (Schulman approximation, per token):** - - kl_t ~= exp(log pi_ref - log pi_theta) - (log pi_ref - log pi_theta) - 1 - - Args: - policy_logprobs: Log-probabilities of the current policy, shape (N, C). - old_policy_logprobs: Log-probabilities of the behaviour policy used to - collect the rollout, shape (N, C). - ref_logprobs: Log-probabilities of the frozen reference policy, shape - (N, C). - advantages: Per-sequence advantage estimates, shape (N,). - completions_mask: Boolean mask of shape (N, C) where True marks valid tokens. - Required — GRPO's per-response normalization is mask-derived. - epsilon: Lower asymmetric clipping bound (1 - epsilon). Default: 0.2. - epsilon_high: Upper asymmetric clipping bound (1 + ε_high). Default: - 0.2 (symmetric with ``epsilon``, matching TRL's GRPOConfig). - beta: Coefficient for the KL-divergence penalty. Default: 0.1. - - Returns: - Scalar tensor representing the GRPO loss. - """ - ratio = torch.exp(policy_logprobs - old_policy_logprobs) - adv = advantages.unsqueeze(1) # (N, 1) for broadcasting over tokens - - surrogate = ratio * adv - surrogate_clipped = torch.clamp(ratio, 1.0 - epsilon, 1.0 + epsilon_high) * adv - policy_loss_per_tok = -torch.min(surrogate, surrogate_clipped) - - # Per-row valid-token count, fp32 + clamp to avoid div-by-zero on - # fully-masked rows (matches TRL's ``mask.sum(-1).clamp(min=1)``). 
- mask_sum = completions_mask.sum(-1).to(torch.float32).clamp_min(1.0) # (N,) - policy_per_row = (policy_loss_per_tok * completions_mask).sum(-1) / mask_sum - - if beta != 0.0: - log_ratio_ref = ref_logprobs - policy_logprobs - kl_per_tok = torch.exp(log_ratio_ref) - log_ratio_ref - 1.0 - kl_per_row = (kl_per_tok * completions_mask).sum(-1) / mask_sum - loss = (policy_per_row + beta * kl_per_row).mean() - else: - loss = policy_per_row.mean() - - return loss.to(policy_logprobs.dtype) diff --git a/build/torch-cuda/grpo_loss/autograd.py b/build/torch-cuda/grpo_loss/autograd.py deleted file mode 100644 index def81c869c70c8796c720d7f0572c5e52b6fb65c..0000000000000000000000000000000000000000 --- a/build/torch-cuda/grpo_loss/autograd.py +++ /dev/null @@ -1,120 +0,0 @@ -"""Autograd-aware wrapper for GRPO loss via ``torch.library.custom_op``. - -The fused cute kernel writes both the scalar loss and the closed-form -``dL/d(policy_logprobs)`` in one launch. This module wraps that into an -autograd-compatible op so callers can write:: - - loss = grpo_loss_autograd(policy, old, ref, adv, completions_mask) - loss.backward() - -Implementation mirrors the BNPO autograd binding: ``custom_op`` with a -registered ``setup_context`` / backward, plus a ``register_fake`` for -shape propagation under ``torch.compile``. The runner allocates -``dpolicy`` fresh on every call (no shared cache), so -``ctx.save_for_backward(dpolicy)`` keeps a stable reference for free. -""" - -from __future__ import annotations - -import torch - -from . 
import grpo_loss as _grpo_loss_fwd_bwd - -__all__ = ["grpo_loss_autograd"] - - -@torch.library.custom_op( - "geometric_ai_kernels::_grpo_loss_with_grad", - mutates_args=(), -) -def _grpo_loss_with_grad( - policy_logprobs: torch.Tensor, - old_policy_logprobs: torch.Tensor, - ref_logprobs: torch.Tensor, - advantages: torch.Tensor, - completions_mask: torch.Tensor, - epsilon: float, - epsilon_high: float, - beta: float, -) -> tuple[torch.Tensor, torch.Tensor]: - loss, dpolicy = _grpo_loss_fwd_bwd( - policy_logprobs, - old_policy_logprobs, - ref_logprobs, - advantages, - completions_mask, - epsilon=epsilon, - epsilon_high=epsilon_high, - beta=beta, - ) - return loss, dpolicy - - -@_grpo_loss_with_grad.register_fake -def _( - policy_logprobs: torch.Tensor, - old_policy_logprobs: torch.Tensor, - ref_logprobs: torch.Tensor, - advantages: torch.Tensor, - completions_mask: torch.Tensor, - epsilon: float, - epsilon_high: float, - beta: float, -) -> tuple[torch.Tensor, torch.Tensor]: - del old_policy_logprobs, ref_logprobs, advantages, completions_mask - del epsilon, epsilon_high, beta - loss = policy_logprobs.new_empty(()) - dpolicy = torch.empty_like(policy_logprobs) - return loss, dpolicy - - -def _setup_context(ctx, inputs, output) -> None: # type: ignore[no-untyped-def] - del inputs - _, dpolicy = output - ctx.save_for_backward(dpolicy) - - -def _backward(ctx, grad_loss, grad_dpolicy): # type: ignore[no-untyped-def] - del grad_dpolicy - (dpolicy,) = ctx.saved_tensors - grad_policy = grad_loss * dpolicy - # One return per input (8): policy_logprobs gets the grad, the rest get None. 
- return grad_policy, None, None, None, None, None, None, None - - -torch.library.register_autograd( - "geometric_ai_kernels::_grpo_loss_with_grad", - _backward, - setup_context=_setup_context, -) - - -def grpo_loss_autograd( - policy_logprobs: torch.Tensor, - old_policy_logprobs: torch.Tensor, - ref_logprobs: torch.Tensor, - advantages: torch.Tensor, - completions_mask: torch.Tensor, - epsilon: float = 0.2, - epsilon_high: float = 0.2, - beta: float = 0.1, -) -> torch.Tensor: - """Autograd-aware GRPO loss. Returns scalar ``loss``. - - Same numerics as :func:`grpo_loss` but registered as a - ``torch.library`` custom op with autograd, so ``loss.backward()`` - Just Works. For direct ``(loss, grad)`` access without the - autograd dispatcher overhead, use :func:`grpo_loss` and chain via - ``policy_logprobs.backward(grad)``. - """ - loss, _ = _grpo_loss_with_grad( - policy_logprobs, - old_policy_logprobs, - ref_logprobs, - advantages, - completions_mask, - float(epsilon), - float(epsilon_high), - float(beta), - ) - return loss diff --git a/build/torch-cuda/grpo_loss/cute_grpo_loss.py b/build/torch-cuda/grpo_loss/cute_grpo_loss.py deleted file mode 100644 index 3e08941a3b191b22c87c7677023378bb3f98fdf0..0000000000000000000000000000000000000000 --- a/build/torch-cuda/grpo_loss/cute_grpo_loss.py +++ /dev/null @@ -1,805 +0,0 @@ -"""CuteDSL kernel for GRPO (Group Relative Policy Optimization) loss. - -Implements TRL's default per-response normalization variant: - - loss = mean_r( ((per_token_loss + beta * kl) * mask).sum(-1) - / mask.sum(-1).clamp(min=1) ) - -Element-wise over ``(N, C)`` logprob tensors: - - ratio = exp(policy - old_policy) - surrogate = ratio * adv - clipped = clip(ratio, 1 - eps, 1 + eps_high) * adv - policy_loss = -min(surrogate, clipped) - log_ratio_ref = ref - policy - kl = exp(log_ratio_ref) - log_ratio_ref - 1 - -``completions_mask`` is **required** for GRPO -- the per-response normalization -formula is mask-derived. 
- - **One CTA per row.** Each row is owned by exactly one block, so the -per-row mask sum is computed locally (warp + cross-warp reduction) with -no cross-CTA atomics, fences, or last-block detection on the per-row -scaling pass. - -**Mask prescan on the fwd+bwd path.** Before the main compute pass each -CTA runs a cheap mask-only sweep over its row to derive ``row_scale`` -(``1 / max(mask.sum(), 1) * inv_n_rows``). With ``row_scale`` known up -front the main pass reads ``policy / old / ref`` exactly **once**, -computes loss and gradient together, and writes the **scaled** -``dpolicy`` directly — no second logprob read, no unscaled GMEM round -trip. The prescan touches only the int8 mask (1 byte / element) and is -much cheaper than the logprob traffic it eliminates (2 B in -bf16/fp16, 4 B in fp32, ×2 or ×3 depending on the KL term). - -The fwd-only path skips the prescan and accumulates ``valid`` directly -in a single sweep over the row. - -When ``beta=0`` the KL term is skipped at compile time (no ``ref`` -tensor access, no ``kl`` accumulator). - -Sequence lengths that are **not** a multiple of ``TILE_N`` are handled -natively: the in-block tile loop runs ``ceil(C / TILE_N)`` iterations; -full tiles use the vectorized ``LDG.128`` path and the tail tile uses -predicated vector loads with neutral prefill. - -Each CTA reduces its row's policy / kl sums in registers, finishes the -cross-warp reduction in SMEM, computes its scaled per-row contribution -``(block_policy + beta * block_kl) * row_scale`` locally, and -``atomicAdd``s the scalar result into ``policy_acc[0]``. The last CTA -— detected via a single grid-scope ``atomic_inc`` on ``counter`` — -reads ``policy_acc[0]``, casts to the output dtype, writes -``output[0]``, and resets the accumulator to ``0``.
-""" - -from __future__ import annotations - -import math -import operator -from typing import TYPE_CHECKING, Any - -import cutlass -import cutlass.utils -import torch -from cutlass import cute -from cutlass.base_dsl.typing import cast - -from ..bnpo_loss.cute_bnpo_loss import ( - _atomic_add_f32_gmem, - _atomic_inc_u32_gmem, -) - -if TYPE_CHECKING: - from collections.abc import Callable - - -TILE_N: int = 2048 -NUM_WARPS: int = 16 -VEC: int = 4 -# Long-context variant: a wider tile keeps the per-row block reduction -# cheap when ``seq_len`` blows past the small tier (where the inner -# col-tile loop iterates many times). Two compiled variants — small + -# large — are dispatched at runtime by sequence length. -TILE_N_LARGE: int = 8192 -TILE_N_LARGE_THRESHOLD: int = 8192 - -_LOG2_E: float = math.log2(math.e) - -_TORCH_TO_CUTLASS_DTYPE: dict[torch.dtype, Any] = { - torch.float32: cutlass.Float32, - torch.float16: cutlass.Float16, - torch.bfloat16: cutlass.BFloat16, -} - - -# --------------------------------------------------------------------------- -# Main GRPO kernel — single CTA per row. -# --------------------------------------------------------------------------- - - -def _make_grpo_kernel( - compute_kl: bool, - compute_backward: bool, - tile_n: int, -) -> Callable[..., None]: - """Build the GRPO kernel specialized on compile-time flags. - - On the fwd+bwd path each CTA does a mask-only prescan to derive - ``row_scale`` before the main pass; this lets the main pass fold - loss accumulation and gradient compute into a single read of the - logprob tensors. On the fwd-only path the prescan is skipped — the - main pass accumulates ``valid`` directly. 
- """ - - @cute.kernel - def _grpo_loss_kernel( - policy: cute.Tensor, - old_policy: cute.Tensor, - ref: cute.Tensor, - advantages: cute.Tensor, - completions_mask: cute.Tensor, - dpolicy: cute.Tensor, # (N, C) when compute_backward; (N, 1) dummy otherwise - policy_acc: cute.Tensor, # (1,) fp32 — global scalar loss accumulator - counter: cute.Tensor, # (1,) int32 — global last-block detection - output: cute.Tensor, - epsilon: cutlass.Float32, - epsilon_high: cutlass.Float32, - beta: cutlass.Float32, - inv_n_rows: cutlass.Float32, - n_rows: cutlass.Int32, - num_full_tiles: cutlass.Int32, - tail_len: cutlass.Int32, - num_col_tiles: cutlass.Int32, - ) -> None: - block_size = NUM_WARPS * 32 - iters = tile_n // (block_size * VEC) - - _no_alloc = cute.nvgpu.CacheEvictionPriority.NO_ALLOCATE - g2r_op = cute.nvgpu.CopyUniversalOp() - # Logprob loads stream once -- ``NO_ALLOCATE`` keeps them out of L1 - # so they don't evict the mask data we *do* re-read. The current - # single-pass main loop never re-reads pol/old/ref; only the mask - # is touched twice (fwd+bwd prescan + main pass). - g2r_atom = cute.make_copy_atom( - g2r_op, - policy.element_type, - num_bits_per_copy=0, - l1c_evict_priority=_no_alloc, - ) - # Mask is read twice on the fwd+bwd path (prescan + main), so bias - # L1 to keep it. On fwd-only the prescan does not run, so the hint - # is unused and falls through to the streaming default. 
- mask_evict = cute.nvgpu.CacheEvictionPriority.EVICT_LAST if compute_backward else _no_alloc - g2r_mask_atom = cute.make_copy_atom( - g2r_op, - completions_mask.element_type, - num_bits_per_copy=0, - l1c_evict_priority=mask_evict, - ) - if cutlass.const_expr(compute_backward): - r2g_atom = cute.make_copy_atom( - g2r_op, - dpolicy.element_type, - num_bits_per_copy=0, - ) - - row = cute.arch.block_idx()[0] - tid = cute.arch.thread_idx()[0] - lane_idx = cute.arch.lane_idx() - warp_idx = cute.arch.warp_idx() - - adv = cast(advantages[row], cutlass.Float32) - lo = cutlass.Float32(1.0) - epsilon - hi = cutlass.Float32(1.0) + epsilon_high - - pol_row = cute.slice_(policy, (row, None)) - old_row = cute.slice_(old_policy, (row, None)) - mask_row = cute.slice_(completions_mask, (row, None)) - if cutlass.const_expr(compute_kl): - ref_row = cute.slice_(ref, (row, None)) - if cutlass.const_expr(compute_backward): - dp_row = cute.slice_(dpolicy, (row, None)) - - smem = cutlass.utils.SmemAllocator() - buf_policy = smem.allocate_tensor(cutlass.Float32, cute.make_layout(NUM_WARPS)) - if cutlass.const_expr(compute_kl): - buf_kl = smem.allocate_tensor(cutlass.Float32, cute.make_layout(NUM_WARPS)) - # Always need a valid buffer in fwd-only path; on fwd+bwd path the - # prescan also reuses cross-warp SMEM reduction. 
- buf_valid = smem.allocate_tensor(cutlass.Float32, cute.make_layout(NUM_WARPS)) - if cutlass.const_expr(compute_backward): - row_scale_smem = smem.allocate_tensor(cutlass.Float32, cute.make_layout(1)) - - # ---- Stage A (fwd+bwd only): mask-only prescan → row_scale ---- - if cutlass.const_expr(compute_backward): - local_valid_pre = cutlass.Float32(0.0) - for col_block in cutlass.range(num_col_tiles, unroll=1): - mask_slab = cute.local_tile(mask_row, (tile_n,), (col_block,)) - if col_block < num_full_tiles: - for k in cutlass.range(iters, unroll_full=True): - sub_idx = tid + k * block_size - mask_src = cute.local_tile(mask_slab, (VEC,), (sub_idx,)) - mask_frag = cute.make_fragment_like(mask_src) - cute.copy(g2r_mask_atom, mask_src, mask_frag) - local_valid_pre += ( - mask_frag.load() - .to(cutlass.Float32) - .reduce( - cute.ReductionOp.ADD, - cutlass.Float32(0.0), - reduction_profile=0, - ) - ) - else: - for k in cutlass.range(iters, unroll_full=True): - sub_idx = tid + k * block_size - chunk_base = sub_idx * VEC - if chunk_base < tail_len: - mask_src = cute.local_tile(mask_slab, (VEC,), (sub_idx,)) - pred = cute.make_rmem_tensor(mask_src.shape, cutlass.Boolean) - for v in cutlass.range(VEC, unroll_full=True): - pred[v] = cute.elem_less(chunk_base + v, tail_len) - mask_frag = cute.make_fragment_like(mask_src) - mask_frag.fill(0) - cute.copy(g2r_mask_atom, mask_src, mask_frag, pred=pred) - local_valid_pre += ( - mask_frag.load() - .to(cutlass.Float32) - .reduce( - cute.ReductionOp.ADD, - cutlass.Float32(0.0), - reduction_profile=0, - ) - ) - - warp_valid_pre = cute.arch.warp_reduction(local_valid_pre, operator.add) - if lane_idx == 0: - buf_valid[warp_idx] = warp_valid_pre - cute.arch.barrier() - - if warp_idx == 0: - lane_in_warp_range = lane_idx < NUM_WARPS - val_v = cutlass.Float32(0.0) - if lane_in_warp_range: - val_v = buf_valid[lane_idx] - block_valid_pre = cute.arch.warp_reduction( - val_v, operator.add, threads_in_group=NUM_WARPS - ) - if lane_idx == 0: - 
n_v_row = cute.arch.fmax(block_valid_pre, cutlass.Float32(1.0)) - row_scale_smem[0] = cute.arch.rcp_approx(n_v_row) * inv_n_rows - cute.arch.barrier() - row_scale_v = row_scale_smem[0] - - # ---- Stage B: main pass — loss accumulation + (fused) gradient ---- - local_policy_sum = cutlass.Float32(0.0) - local_kl_sum = cutlass.Float32(0.0) - # Only the fwd-only path needs to accumulate valid here; the - # fwd+bwd path already produced ``row_scale_v`` from the prescan. - local_valid_sum = cutlass.Float32(0.0) - - for col_block in cutlass.range(num_col_tiles, unroll=1): - if col_block < num_full_tiles: - pol_slab = cute.local_tile(pol_row, (tile_n,), (col_block,)) - old_slab = cute.local_tile(old_row, (tile_n,), (col_block,)) - mask_slab = cute.local_tile(mask_row, (tile_n,), (col_block,)) - if cutlass.const_expr(compute_kl): - ref_slab = cute.local_tile(ref_row, (tile_n,), (col_block,)) - if cutlass.const_expr(compute_backward): - dp_slab = cute.local_tile(dp_row, (tile_n,), (col_block,)) - - for k in cutlass.range(iters, unroll_full=True): - sub_idx = tid + k * block_size - - pol_src = cute.local_tile(pol_slab, (VEC,), (sub_idx,)) - old_src = cute.local_tile(old_slab, (VEC,), (sub_idx,)) - pol_frag = cute.make_fragment_like(pol_src) - old_frag = cute.make_fragment_like(old_src) - cute.copy(g2r_atom, pol_src, pol_frag) - cute.copy(g2r_atom, old_src, old_frag) - - mask_src = cute.local_tile(mask_slab, (VEC,), (sub_idx,)) - mask_frag = cute.make_fragment_like(mask_src) - cute.copy(g2r_mask_atom, mask_src, mask_frag) - - pol_vec = pol_frag.load().to(cutlass.Float32) - old_vec = old_frag.load().to(cutlass.Float32) - mask_vec = mask_frag.load().to(cutlass.Float32) - - log_ratio = pol_vec - old_vec - ratio = cute.math.exp2(log_ratio * _LOG2_E, fastmath=True) - surrogate = ratio * adv - clipped_ratio = cute.where( - ratio < lo, - lo, - cute.where(ratio > hi, hi, ratio), - ) - clipped = clipped_ratio * adv - policy_loss = -cute.where(surrogate < clipped, surrogate, clipped) - 
- # Compute the (negated) surrogate gradient before the KL - # block so ``surrogate``/``clipped`` can die before the - # ref load. ``-(adv * ratio)`` is just ``-surrogate`` — - # saves one FMUL per element. - if cutlass.const_expr(compute_backward): - neg_surrogate_grad = cute.where( - surrogate <= clipped, - -surrogate, - cutlass.Float32(0.0), - ) - - # Reduce ``policy_loss`` immediately so it can be freed - # before any further work. - local_policy_sum += (policy_loss * mask_vec).reduce( - cute.ReductionOp.ADD, - cutlass.Float32(0.0), - reduction_profile=0, - ) - if cutlass.const_expr(not compute_backward): - local_valid_sum += mask_vec.reduce( - cute.ReductionOp.ADD, - cutlass.Float32(0.0), - reduction_profile=0, - ) - - if cutlass.const_expr(compute_kl): - ref_src = cute.local_tile(ref_slab, (VEC,), (sub_idx,)) - ref_frag = cute.make_fragment_like(ref_src) - cute.copy(g2r_atom, ref_src, ref_frag) - ref_vec = ref_frag.load().to(cutlass.Float32) - log_ratio_ref = ref_vec - pol_vec - ratio_ref = cute.math.exp2(log_ratio_ref * _LOG2_E, fastmath=True) - # FFMA-friendly rearrangement: ``(ratio_ref - 1) - log_ratio_ref`` - # exposes a ``ratio_ref + (-1)`` pair that ptxas folds - # with the subsequent subtract — fewer FADDs surviving - # SASS than the original 3-term ``a - b - c``. - kl_val = (ratio_ref - cutlass.Float32(1.0)) - log_ratio_ref - local_kl_sum += (kl_val * mask_vec).reduce( - cute.ReductionOp.ADD, - cutlass.Float32(0.0), - reduction_profile=0, - ) - - if cutlass.const_expr(compute_backward): - if cutlass.const_expr(compute_kl): - # ``beta - beta*ratio_ref`` instead of ``beta*(1 - ratio_ref)`` - # gives ptxas an obvious FFMA pattern (``FFMA -beta, - # ratio_ref, beta``) — saves one FMUL per element. 
- kl_grad = beta - beta * ratio_ref - grad_vec = (neg_surrogate_grad + kl_grad) * mask_vec - else: - grad_vec = neg_surrogate_grad * mask_vec - scaled = grad_vec * row_scale_v - - dp_dst = cute.local_tile(dp_slab, (VEC,), (sub_idx,)) - dp_frag = cute.make_fragment_like(dp_dst) - dp_frag.store(scaled.to(dpolicy.element_type)) - cute.copy(r2g_atom, dp_frag, dp_dst) - else: - pol_slab = cute.local_tile(pol_row, (tile_n,), (col_block,)) - old_slab = cute.local_tile(old_row, (tile_n,), (col_block,)) - mask_slab = cute.local_tile(mask_row, (tile_n,), (col_block,)) - if cutlass.const_expr(compute_kl): - ref_slab = cute.local_tile(ref_row, (tile_n,), (col_block,)) - if cutlass.const_expr(compute_backward): - dp_slab = cute.local_tile(dp_row, (tile_n,), (col_block,)) - - for k in cutlass.range(iters, unroll_full=True): - sub_idx = tid + k * block_size - chunk_base = sub_idx * VEC - - if chunk_base < tail_len: - pol_src = cute.local_tile(pol_slab, (VEC,), (sub_idx,)) - old_src = cute.local_tile(old_slab, (VEC,), (sub_idx,)) - pred = cute.make_rmem_tensor(pol_src.shape, cutlass.Boolean) - for v in cutlass.range(VEC, unroll_full=True): - pred[v] = cute.elem_less(chunk_base + v, tail_len) - - pol_frag = cute.make_fragment_like(pol_src) - old_frag = cute.make_fragment_like(old_src) - pol_frag.fill(0.0) - old_frag.fill(0.0) - cute.copy(g2r_atom, pol_src, pol_frag, pred=pred) - cute.copy(g2r_atom, old_src, old_frag, pred=pred) - - mask_src = cute.local_tile(mask_slab, (VEC,), (sub_idx,)) - mask_frag = cute.make_fragment_like(mask_src) - mask_frag.fill(0) - cute.copy(g2r_mask_atom, mask_src, mask_frag, pred=pred) - - pol_vec = pol_frag.load().to(cutlass.Float32) - old_vec = old_frag.load().to(cutlass.Float32) - valid_vec = cute.where( - pred.load(), - cute.full_like(pol_vec, cutlass.Float32(1.0)), - cute.zeros_like(pol_vec, dtype=cutlass.Float32), - ) - # ``mask_vec * valid_vec`` zeros out-of-bounds lanes. 
- mask_vec = mask_frag.load().to(cutlass.Float32) * valid_vec - - log_ratio = pol_vec - old_vec - ratio = cute.math.exp2(log_ratio * _LOG2_E, fastmath=True) - surrogate = ratio * adv - clipped_ratio = cute.where( - ratio < lo, - lo, - cute.where(ratio > hi, hi, ratio), - ) - clipped = clipped_ratio * adv - policy_loss = -cute.where(surrogate < clipped, surrogate, clipped) - - # Same live-range narrowing as the full-tile path: - # fold ``surrogate``/``clipped`` into the gradient - # term and reduce ``policy_loss`` before the KL - # block so they can be freed. ``-(adv * ratio)`` is - # ``-surrogate`` — saves one FMUL per element. - if cutlass.const_expr(compute_backward): - neg_surrogate_grad = cute.where( - surrogate <= clipped, - -surrogate, - cutlass.Float32(0.0), - ) - - local_policy_sum += (policy_loss * mask_vec).reduce( - cute.ReductionOp.ADD, - cutlass.Float32(0.0), - reduction_profile=0, - ) - if cutlass.const_expr(not compute_backward): - local_valid_sum += mask_vec.reduce( - cute.ReductionOp.ADD, - cutlass.Float32(0.0), - reduction_profile=0, - ) - - if cutlass.const_expr(compute_kl): - ref_src = cute.local_tile(ref_slab, (VEC,), (sub_idx,)) - ref_frag = cute.make_fragment_like(ref_src) - ref_frag.fill(0.0) - cute.copy(g2r_atom, ref_src, ref_frag, pred=pred) - ref_vec = ref_frag.load().to(cutlass.Float32) - log_ratio_ref = ref_vec - pol_vec - ratio_ref = cute.math.exp2(log_ratio_ref * _LOG2_E, fastmath=True) - # See full-tile path: FFMA-friendly rearrangement. - kl_val = (ratio_ref - cutlass.Float32(1.0)) - log_ratio_ref - local_kl_sum += (kl_val * mask_vec).reduce( - cute.ReductionOp.ADD, - cutlass.Float32(0.0), - reduction_profile=0, - ) - - if cutlass.const_expr(compute_backward): - if cutlass.const_expr(compute_kl): - # See full-tile path: ``beta - beta*ratio_ref`` - # is FFMA-friendly. 
- kl_grad = beta - beta * ratio_ref - grad_vec = (neg_surrogate_grad + kl_grad) * mask_vec - else: - grad_vec = neg_surrogate_grad * mask_vec - scaled = grad_vec * row_scale_v - - dp_dst = cute.local_tile(dp_slab, (VEC,), (sub_idx,)) - dp_frag = cute.make_fragment_like(dp_dst) - dp_frag.store(scaled.to(dpolicy.element_type)) - cute.copy(r2g_atom, dp_frag, dp_dst, pred=pred) - - # ---- Stage C: warp + cross-warp reduction → atomic-add row_loss ---- - warp_policy = cute.arch.warp_reduction(local_policy_sum, operator.add) - if cutlass.const_expr(compute_kl): - warp_kl = cute.arch.warp_reduction(local_kl_sum, operator.add) - if cutlass.const_expr(not compute_backward): - warp_valid = cute.arch.warp_reduction(local_valid_sum, operator.add) - - # The fwd+bwd path used ``buf_valid`` for the prescan reduction; - # ensure all threads have observed ``row_scale_smem`` (Stage A's - # final barrier) before we reuse ``buf_policy`` / ``buf_valid``. - # Stage A never runs on the fwd-only path, so the barrier is - # only needed when ``compute_backward`` is set. 
- if cutlass.const_expr(compute_backward): - cute.arch.barrier() - - if lane_idx == 0: - buf_policy[warp_idx] = warp_policy - if cutlass.const_expr(compute_kl): - buf_kl[warp_idx] = warp_kl - if cutlass.const_expr(not compute_backward): - buf_valid[warp_idx] = warp_valid - cute.arch.barrier() - - if warp_idx == 0: - lane_in_warp_range = lane_idx < NUM_WARPS - - val_p = cutlass.Float32(0.0) - if lane_in_warp_range: - val_p = buf_policy[lane_idx] - block_policy = cute.arch.warp_reduction(val_p, operator.add, threads_in_group=NUM_WARPS) - - block_kl = cutlass.Float32(0.0) - if cutlass.const_expr(compute_kl): - val_k = cutlass.Float32(0.0) - if lane_in_warp_range: - val_k = buf_kl[lane_idx] - block_kl = cute.arch.warp_reduction(val_k, operator.add, threads_in_group=NUM_WARPS) - - if cutlass.const_expr(not compute_backward): - val_v = cutlass.Float32(0.0) - if lane_in_warp_range: - val_v = buf_valid[lane_idx] - block_valid = cute.arch.warp_reduction( - val_v, operator.add, threads_in_group=NUM_WARPS - ) - - if lane_idx == 0: - if cutlass.const_expr(compute_backward): - row_scale = row_scale_smem[0] - else: - n_v_row = cute.arch.fmax(block_valid, cutlass.Float32(1.0)) - row_scale = cute.arch.rcp_approx(n_v_row) * inv_n_rows - - if cutlass.const_expr(compute_kl): - row_loss = (block_policy + beta * block_kl) * row_scale - else: - row_loss = block_policy * row_scale - - loss_ptr = policy_acc.iterator.toint().ir_value() # ty: ignore[unresolved-attribute] - _atomic_add_f32_gmem(loss_ptr, row_loss) - - # ---- Stage D: last-block detection → write final loss ---- - if warp_idx == 0: - is_last_lane0 = cutlass.Int32(0) - if lane_idx == 0: - counter_ptr = counter.iterator.toint().ir_value() # ty: ignore[unresolved-attribute] - cute.arch.fence_acq_rel_gpu() - old = _atomic_inc_u32_gmem(counter_ptr, n_rows - 1) - if old == n_rows - 1: - is_last_lane0 = cutlass.Int32(1) - - is_last = cute.arch.shuffle_sync(is_last_lane0, 0) - - if is_last == cutlass.Int32(1) and lane_idx == 0: - # 
Each CTA already atomic-added its scaled ``row_loss`` - # (including ``inv_n_rows``); the accumulator now holds - # the final loss. - total = policy_acc[0] - output[0] = cast(total, output.element_type) # ty: ignore[invalid-argument-type] - # Reset for the next call. - policy_acc[0] = cutlass.Float32(0.0) - - return _grpo_loss_kernel - - -# --------------------------------------------------------------------------- -# Compile-and-run factory -# --------------------------------------------------------------------------- - - -def create_compiled_grpo_loss( - policy_dtype: torch.dtype, - epsilon: float, - epsilon_high: float, - beta: float, - compute_backward: bool = False, -) -> Callable[..., torch.Tensor | tuple[torch.Tensor, torch.Tensor]]: - """Compile the GRPO loss kernel and return a runtime closure. - - Two compiled variants — small-tile (``TILE_N``) and large-tile - (``TILE_N_LARGE``) — are produced; the runner selects one per call - based on ``seq_len`` vs. ``TILE_N_LARGE_THRESHOLD``. - - When ``compute_backward=True`` the kernel additionally writes the - scaled gradient ``dL/d(policy_logprobs)`` to a caller-provided - ``dpolicy`` tensor in the same launch. 
- """ - compute_kl = beta != 0.0 - - if policy_dtype not in _TORCH_TO_CUTLASS_DTYPE: - raise ValueError(f"Unsupported dtype for GRPO kernel: {policy_dtype}") - - tile_n_small = TILE_N - tile_n_large = TILE_N_LARGE - seq_len_threshold = TILE_N_LARGE_THRESHOLD - block_size = NUM_WARPS * 32 - if tile_n_small % (block_size * VEC) != 0: - raise ValueError( - f"TILE_N={tile_n_small} must be a multiple of BLOCK_SIZE*VEC={block_size * VEC}" - ) - if tile_n_large % (block_size * VEC) != 0: - raise ValueError( - f"TILE_N_LARGE={tile_n_large} must be a multiple of BLOCK_SIZE*VEC={block_size * VEC}" - ) - - n_rows_sym = cute.sym_int() - seq_len_sym = cute.sym_int() - cute_dtype = _TORCH_TO_CUTLASS_DTYPE[policy_dtype] - - def _fake2d(dt: Any, cols: Any) -> Any: - return cute.runtime.make_fake_compact_tensor( - dt, - (n_rows_sym, cols), - stride_order=(1, 0), - assumed_align=16, - ) - - fake_pol = _fake2d(cute_dtype, seq_len_sym) - fake_old = _fake2d(cute_dtype, seq_len_sym) - fake_ref = _fake2d(cute_dtype, seq_len_sym) - fake_adv = cute.runtime.make_fake_compact_tensor( - cute_dtype, - (n_rows_sym,), - assumed_align=16, - ) - fake_mask = cute.runtime.make_fake_compact_tensor( - cutlass.Int8, - (n_rows_sym, seq_len_sym), - stride_order=(1, 0), - assumed_align=16, - ) - dpolicy_cols = seq_len_sym if compute_backward else 1 - fake_dpolicy = cute.runtime.make_fake_compact_tensor( - cute_dtype, - (n_rows_sym, dpolicy_cols), - stride_order=(1, 0), - assumed_align=16, - ) - fake_policy_acc = cute.runtime.make_fake_compact_tensor( - cutlass.Float32, - (1,), - assumed_align=16, - ) - fake_counter = cute.runtime.make_fake_compact_tensor( - cutlass.Int32, - (1,), - assumed_align=16, - ) - fake_output = cute.runtime.make_fake_compact_tensor( - cute_dtype, - (1,), - assumed_align=16, - ) - - def _build_launch(tile_n_v: int) -> Callable[..., None]: - """Build a JIT-compiled launcher specialized for ``tile_n_v``. 
- - Bakes ``compute_kl``, ``compute_backward``, and the column-tile - width into the kernel as compile-time constants and returns a - ``@cute.jit`` wrapper that forwards runtime tensors/scalars to - ``.launch()`` with ``grid=(n_rows, 1, 1)`` and - ``block=(NUM_WARPS*32, 1, 1)``. - """ - specialized_kernel = _make_grpo_kernel(compute_kl, compute_backward, tile_n_v) - - @cute.jit - def _launch( - pol_ct: cute.Tensor, - old_ct: cute.Tensor, - ref_ct: cute.Tensor, - adv_ct: cute.Tensor, - mask_ct: cute.Tensor, - dpolicy_ct: cute.Tensor, - policy_acc_ct: cute.Tensor, - counter_ct: cute.Tensor, - output_ct: cute.Tensor, - epsilon_v: cutlass.Float32, - epsilon_high_v: cutlass.Float32, - beta_v: cutlass.Float32, - inv_n_rows_v: cutlass.Float32, - n_rows_v: cutlass.Int32, - num_full_tiles_v: cutlass.Int32, - tail_len_v: cutlass.Int32, - num_col_tiles_v: cutlass.Int32, - ) -> None: - specialized_kernel( # ty: ignore[unresolved-attribute] - pol_ct, - old_ct, - ref_ct, - adv_ct, - mask_ct, - dpolicy_ct, - policy_acc_ct, - counter_ct, - output_ct, - epsilon_v, - epsilon_high_v, - beta_v, - inv_n_rows_v, - n_rows_v, - num_full_tiles_v, - tail_len_v, - num_col_tiles_v, - ).launch( - grid=(n_rows_v, 1, 1), - block=(NUM_WARPS * 32, 1, 1), - ) - - return _launch - - def _compile_launch(launch_fn: Callable[..., None]) -> Callable[..., None]: - return cute.compile( - launch_fn, - fake_pol, - fake_old, - fake_ref, - fake_adv, - fake_mask, - fake_dpolicy, - fake_policy_acc, - fake_counter, - fake_output, - cutlass.Float32(epsilon), - cutlass.Float32(epsilon_high), - cutlass.Float32(beta), - cutlass.Float32(1.0), - cutlass.Int32(1), - cutlass.Int32(1), - cutlass.Int32(0), - cutlass.Int32(1), - options="--enable-tvm-ffi", - ) - - compiled_small = _compile_launch(_build_launch(tile_n_small)) - if tile_n_large == tile_n_small: - compiled_large = compiled_small - else: - compiled_large = _compile_launch(_build_launch(tile_n_large)) - - eps_const = cutlass.Float32(epsilon) - eps_high_const = 
cutlass.Float32(epsilon_high) - beta_const = cutlass.Float32(beta) - - # Cross-CTA scratch slab — one int32 buffer with stride-4 (16-byte) slices - # so each slot is individually 16-byte aligned (``assumed_align=16`` at - # compile time). Bit-pattern of int32 0 equals fp32 0.0, so a single - # ``zeros`` factory legitimately initialises both the int32 counter and - # the fp32 ``policy_acc``. The kernel's last block self-resets - # ``policy_acc`` in its epilogue and the counter self-resets via - # ``atom.inc.u32`` wrap-around, so the up-front ``torch.zeros`` only - # matters for the very first call. Allocated lazily on first ``_run`` - # call when the device is known. - _scratch: list[torch.Tensor | None] = [None] - - def _ensure_scratch(device: torch.device) -> tuple[torch.Tensor, torch.Tensor]: - s = _scratch[0] - if s is None or s.device != device: - s = torch.zeros(8, dtype=torch.int32, device=device) - _scratch[0] = s - return ( - s[0:1], # counter (int32) - s[4:5].view(torch.float32), # policy_acc (fp32) - ) - - def _run( - policy_logprobs_r: torch.Tensor, - old_policy_logprobs_r: torch.Tensor, - ref_logprobs_r: torch.Tensor, - advantages_r: torch.Tensor, - completions_mask_r: torch.Tensor, - ) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]: - n_rows, seq_len = policy_logprobs_r.shape - device = policy_logprobs_r.device - dtype = policy_logprobs_r.dtype - - if seq_len >= seq_len_threshold: - tile_n_active = tile_n_large - compiled_active = compiled_large - else: - tile_n_active = tile_n_small - compiled_active = compiled_small - num_full_tiles = seq_len // tile_n_active - tail_len = seq_len % tile_n_active - num_col_tiles = num_full_tiles + (1 if tail_len > 0 else 0) - inv_n_rows = cutlass.Float32(1.0 / float(n_rows)) - - # Per-call write-only buffers — ``empty`` is enough (Liger / TE pattern). 
- output_r = torch.empty(1, dtype=dtype, device=device) - if compute_backward: - dpolicy_r = torch.empty_like(policy_logprobs_r) - else: - dpolicy_r = torch.empty(n_rows, 1, dtype=dtype, device=device) - - counter_r, policy_acc_r = _ensure_scratch(device) - - compiled_active( - policy_logprobs_r, - old_policy_logprobs_r, - ref_logprobs_r, - advantages_r, - completions_mask_r, - dpolicy_r, - policy_acc_r, - counter_r, - output_r, - eps_const, - eps_high_const, - beta_const, - inv_n_rows, - cutlass.Int32(n_rows), - cutlass.Int32(num_full_tiles), - cutlass.Int32(tail_len), - cutlass.Int32(num_col_tiles), - ) - out_view = output_r.view(()) - if compute_backward: - return out_view, dpolicy_r - return out_view - - return _run diff --git a/build/torch-cuda/layers.py b/build/torch-cuda/layers.py deleted file mode 100644 index 348ad87334c57a7a7fffd0013b03a77ae03443b5..0000000000000000000000000000000000000000 --- a/build/torch-cuda/layers.py +++ /dev/null @@ -1,258 +0,0 @@ -"""HF Kernels layer adapters for ``kernelize()``. - -These ``nn.Module`` classes are the entry points for users who want to -plug our cute kernels into a model via the ``kernels`` library's -``kernelize`` flow. Two classes per kernel, one per supported mode: - -* :class:`bnpoLoss` / :class:`grpoLoss` / :class:`ReverseKL` - — autograd-aware (``loss.backward()`` works). Register against - ``Mode.TRAINING`` and/or ``Mode.TRAINING | Mode.TORCH_COMPILE``. -* :class:`bnpoLossInference` / :class:`grpoLossInference` / - :class:`ReverseKLInference` — forward-only, no autograd - dispatcher. Register against ``Mode.INFERENCE`` for inference / - validation. - -All are stateless (no ``__init__``, no member tensors) as required by -``kernelize`` — it validates that layer classes don't add constructor -state. The ``has_backward`` and ``can_torch_compile`` attributes are -the only allowed extras and let ``kernelize`` choose the right layer -for the requested mode. 
- -Forward-signature contract: a downstream user wraps the loss in their -own ``nn.Module`` (decorated with ``@use_kernel_forward_from_hub(...)``) -whose ``forward`` matches the signature here. ``kernelize`` swaps the -``forward`` method by class identity, so the signature MUST line up -positionally with the user's module. - -Typical user-side wiring:: - - from kernels import ( - Mode, LayerRepository, kernelize, use_kernel_forward_from_hub, use_kernel_mapping, - ) - - @use_kernel_forward_from_hub("bnpoLoss") - class bnpoLoss(nn.Module): - def forward(self, policy, old, ref, adv, completions_mask, - epsilon=0.2, epsilon_high=0.2, beta=0.1): - ... # eager fallback - - mapping = { - "bnpoLoss": { - "cuda": { - Mode.INFERENCE: LayerRepository( - repo_id="Geometric-AI/geometric-ai-kernels", - layer_name="bnpoLossInference", - ), - Mode.TRAINING: LayerRepository( - repo_id="Geometric-AI/geometric-ai-kernels", - layer_name="bnpoLoss", - ), - } - } - } - - with use_kernel_mapping(mapping): - model = kernelize(model, mode=Mode.TRAINING | Mode.TORCH_COMPILE) -""" - -from __future__ import annotations - -import torch -from torch import nn - -from .bnpo_loss import bnpo_loss_fwd -from .bnpo_loss.autograd import bnpo_loss_autograd -from .grpo_loss import grpo_loss_fwd -from .grpo_loss.autograd import grpo_loss_autograd -from .reverse_kl import ( - reverse_kl_autograd, - reverse_kl_fwd, -) - -__all__ = [ - "ReverseKL", - "ReverseKLInference", - "bnpoLoss", - "bnpoLossInference", - "grpoLoss", - "grpoLossInference", -] - - -class bnpoLoss(nn.Module): - """Training-mode bnpo loss layer. ``loss.backward()`` works. - - Routes through :func:`bnpo_loss_autograd`, which wraps the fused - cute kernel in a ``torch.library.custom_op`` with a registered - backward. Compatible with ``torch.compile`` (the op has a fake - kernel for shape propagation).
- """ - - has_backward = True - can_torch_compile = True - - def forward( - self, - policy_logprobs: torch.Tensor, - old_policy_logprobs: torch.Tensor, - ref_logprobs: torch.Tensor, - advantages: torch.Tensor, - completions_mask: torch.Tensor, - epsilon: float = 0.2, - epsilon_high: float = 0.2, - beta: float = 0.1, - ) -> torch.Tensor: - return bnpo_loss_autograd( - policy_logprobs, - old_policy_logprobs, - ref_logprobs, - advantages, - completions_mask, - epsilon=epsilon, - epsilon_high=epsilon_high, - beta=beta, - ) - - -class bnpoLossInference(nn.Module): - """Inference / validation bnpo loss layer. No autograd dispatcher. - - Routes through :func:`bnpo_loss_fwd` — the forward-only kernel that - computes the masked mean denominator on-GPU via the last-block - trick, skipping the dpolicy buffer entirely. - """ - - has_backward = False - can_torch_compile = True - - def forward( - self, - policy_logprobs: torch.Tensor, - old_policy_logprobs: torch.Tensor, - ref_logprobs: torch.Tensor, - advantages: torch.Tensor, - completions_mask: torch.Tensor, - epsilon: float = 0.2, - epsilon_high: float = 0.2, - beta: float = 0.1, - ) -> torch.Tensor: - return bnpo_loss_fwd( - policy_logprobs, - old_policy_logprobs, - ref_logprobs, - advantages, - completions_mask, - epsilon=epsilon, - epsilon_high=epsilon_high, - beta=beta, - ) - - -class grpoLoss(nn.Module): - """Training-mode GRPO loss layer. ``loss.backward()`` works. - - Routes through :func:`grpo_loss_autograd`, which wraps the fused - cute kernel in a ``torch.library.custom_op`` with a registered - backward. ``completions_mask`` is required (GRPO's per-response - normalization is mask-derived). 
- """ - - has_backward = True - can_torch_compile = True - - def forward( - self, - policy_logprobs: torch.Tensor, - old_policy_logprobs: torch.Tensor, - ref_logprobs: torch.Tensor, - advantages: torch.Tensor, - completions_mask: torch.Tensor, - epsilon: float = 0.2, - epsilon_high: float = 0.2, - beta: float = 0.1, - ) -> torch.Tensor: - return grpo_loss_autograd( - policy_logprobs, - old_policy_logprobs, - ref_logprobs, - advantages, - completions_mask, - epsilon=epsilon, - epsilon_high=epsilon_high, - beta=beta, - ) - - -class grpoLossInference(nn.Module): - """Inference / validation GRPO loss layer. No autograd dispatcher. - - Routes through :func:`grpo_loss_fwd` — the forward-only kernel, - skipping the dpolicy buffer entirely. - """ - - has_backward = False - can_torch_compile = True - - def forward( - self, - policy_logprobs: torch.Tensor, - old_policy_logprobs: torch.Tensor, - ref_logprobs: torch.Tensor, - advantages: torch.Tensor, - completions_mask: torch.Tensor, - epsilon: float = 0.2, - epsilon_high: float = 0.2, - beta: float = 0.1, - ) -> torch.Tensor: - return grpo_loss_fwd( - policy_logprobs, - old_policy_logprobs, - ref_logprobs, - advantages, - completions_mask, - epsilon=epsilon, - epsilon_high=epsilon_high, - beta=beta, - ) - - -class ReverseKL(nn.Module): - """Training-mode reverse-KL self-distillation loss layer. - - Routes through :func:`reverse_kl_autograd`, which wraps - the fused cute kernel in a ``torch.library.custom_op`` with a - registered backward. ``loss.backward()`` propagates to whatever - produced ``student_logits``. Compatible with ``torch.compile`` (the - op has a fake kernel for shape propagation). 
- """ - - has_backward = True - can_torch_compile = True - - def forward( - self, - student_logits: torch.Tensor, - teacher_logits: torch.Tensor, - completions_mask: torch.Tensor, - ) -> torch.Tensor: - return reverse_kl_autograd(student_logits, teacher_logits, completions_mask) - - -class ReverseKLInference(nn.Module): - """Inference / validation reverse-KL self-distillation loss layer. - - Routes through :func:`reverse_kl_fwd` — the forward-only - kernel that dead-code-eliminates the gradient pass and skips the - grad-student buffer entirely. - """ - - has_backward = False - can_torch_compile = True - - def forward( - self, - student_logits: torch.Tensor, - teacher_logits: torch.Tensor, - completions_mask: torch.Tensor, - ) -> torch.Tensor: - return reverse_kl_fwd(student_logits, teacher_logits, completions_mask) diff --git a/build/torch-cuda/metadata.json b/build/torch-cuda/metadata.json deleted file mode 100644 index 2777190a5ffd618cb96cc58216d0b495839db4a1..0000000000000000000000000000000000000000 --- a/build/torch-cuda/metadata.json +++ /dev/null @@ -1,12 +0,0 @@ -{ - "id": "_geometric_ai_kernels_cuda_a766fbd_dirty", - "version": 0, - "license": "Apache-2.0", - "python-depends": [ - "tvm-ffi", - "nvidia-cutlass-dsl" - ], - "backend": { - "type": "cuda" - } -} \ No newline at end of file diff --git a/build/torch-cuda/reverse_kl/__init__.py b/build/torch-cuda/reverse_kl/__init__.py deleted file mode 100644 index 047b38ce2ccc1d438da7a8ffc5d07e40957a950c..0000000000000000000000000000000000000000 --- a/build/torch-cuda/reverse_kl/__init__.py +++ /dev/null @@ -1,196 +0,0 @@ -"""Reverse-KL self-distillation loss with CuteDSL fused fwd+bwd. - -Three public APIs route to two compiled kernels: - -* :func:`reverse_kl` — primary training entry point. - Returns ``(loss, grad_student_logits)`` from a single fused fwd+bwd - kernel launch. 
Inputs do **not** need ``requires_grad=True`` and there - is no ``torch.autograd.Function`` wrapper — chain the gradient into - the upstream model with ``student_logits.backward(grad)``. -* :func:`reverse_kl_fwd` — inference / validation path. - Returns the scalar ``loss`` from a forward-only kernel that - dead-code-eliminates the gradient pass. -* :func:`reverse_kl_autograd` — autograd-aware variant via - ``torch.library.custom_op``. Returns scalar ``loss``; - ``loss.backward()`` Just Works. Re-exported from - :mod:`reverse_kl.autograd`. - -Per-call output and gradient buffers are allocated inside the runner; -cross-CTA scratch (atomic accumulators + counters + the constant -``grad_output=1.0`` scalar) is owned by the compiled-kernel closure and -self-resets each launch — callers don't manage scratch state. - -Why no autograd wrapper for :func:`reverse_kl`? -The reverse-KL gradient is closed-form: ``dL/d(student) = mask * -inv_n_valid * p * (log_p - log_q - kl_per_row)``. The fused kernel -already writes that analytically in the same launch as the loss. -Wrapping in ``torch.autograd.Function`` would cost an extra -``grad_output * dpolicy`` kernel on backward (~2× per-call overhead in -practice) and is opaque to ``torch.compile``. Use the autograd-aware -variant when you need ``loss.backward()`` ergonomics; pay only the -``custom_op`` dispatcher cost (also Inductor-traceable). 
-""" - -from __future__ import annotations - -from functools import lru_cache -from typing import TYPE_CHECKING - -import torch - -from .cute_reverse_kl import create_compiled_reverse_kl - -if TYPE_CHECKING: - from collections.abc import Callable - - -__all__ = [ - "reverse_kl", - "reverse_kl_autograd", - "reverse_kl_fwd", -] - - -@lru_cache(maxsize=32) -def _get_compiled_fwd( - dtype: torch.dtype, - vocab: int, -) -> Callable[..., torch.Tensor]: - return create_compiled_reverse_kl( # ty: ignore[invalid-return-type] - policy_dtype=dtype, - vocab=vocab, - compute_backward=False, - ) - - -@lru_cache(maxsize=32) -def _get_compiled_fwd_bwd( - dtype: torch.dtype, - vocab: int, -) -> Callable[..., tuple[torch.Tensor, torch.Tensor]]: - return create_compiled_reverse_kl( # ty: ignore[invalid-return-type] - policy_dtype=dtype, - vocab=vocab, - compute_backward=True, - ) - - -def _flatten_inputs( - student_logits: torch.Tensor, - teacher_logits: torch.Tensor, - completions_mask: torch.Tensor, -) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, tuple[int, ...], int, int]: - """Reshape ``(*, V)`` inputs to ``(num_rows, V)`` for the kernel. - - The kernel works on a flat ``(num_rows, V)`` slab; the public API - accepts arbitrary leading dims (typically ``(N, C, V)``) and we - flatten here so the wrapper signature stays user-friendly. We - assume contiguous-on-V inputs — ``view`` becomes a no-copy reshape. 
- """ - if student_logits.shape != teacher_logits.shape: - raise ValueError( - "student_logits and teacher_logits must have the same shape; got " - f"{tuple(student_logits.shape)} vs {tuple(teacher_logits.shape)}" - ) - if student_logits.dtype != teacher_logits.dtype: - raise ValueError( - "student_logits and teacher_logits must have the same dtype; got " - f"{student_logits.dtype} vs {teacher_logits.dtype}" - ) - - leading_shape = tuple(student_logits.shape[:-1]) - vocab = int(student_logits.shape[-1]) - - student_2d = student_logits.view(-1, vocab) - teacher_2d = teacher_logits.view(-1, vocab) - num_rows = student_2d.shape[0] - - if tuple(completions_mask.shape) != leading_shape: - raise ValueError( - f"completions_mask must have shape {leading_shape} (logits' leading " - f"dims); got {tuple(completions_mask.shape)}" - ) - flat_mask = completions_mask.view(-1) - if flat_mask.dtype != torch.float32: - flat_mask = flat_mask.to(torch.float32) - - return student_2d, teacher_2d, flat_mask, leading_shape, num_rows, vocab - - -def reverse_kl_fwd( - student_logits: torch.Tensor, - teacher_logits: torch.Tensor, - completions_mask: torch.Tensor, -) -> torch.Tensor: - """Forward-only reverse-KL self-distillation loss. Returns scalar ``loss``. - - Use for inference / validation. The masked mean denominator is - computed on-GPU by the bundled mask-sum kernel — no host - ``mask.sum()`` syncs. - - Args: - student_logits, teacher_logits: ``(*, V)`` logit tensors with - arbitrary leading dims (typically ``(N, C, V)``); both must - share shape and dtype. - completions_mask: Bool / int / float mask with shape matching - ``student_logits.shape[:-1]``; truthy = valid token. - - Returns: - Scalar tensor (0-dim) with the same dtype as ``student_logits``. 
- """ - student_2d, teacher_2d, flat_mask, _, _, vocab = _flatten_inputs( - student_logits, teacher_logits, completions_mask - ) - run = _get_compiled_fwd(student_logits.dtype, vocab) - return run(student_2d, teacher_2d, flat_mask) - - -def reverse_kl( - student_logits: torch.Tensor, - teacher_logits: torch.Tensor, - completions_mask: torch.Tensor, -) -> tuple[torch.Tensor, torch.Tensor]: - """Fused fwd+bwd reverse-KL self-distillation. Returns ``(loss, grad_student)``. - - Single-launch training entry point. The kernel writes both the - scalar loss and the analytical ``dL/d(student_logits)`` in one - ``@cute.jit`` dispatch — the bundled mask-sum kernel populates - ``inv_n_valid`` on-GPU before the main kernel reads it, so there's - no host-side ``mask.sum()`` round trip. - - Inputs do **not** need ``requires_grad=True``. To chain ``grad`` - into the upstream model that produced ``student_logits``:: - - loss, grad = reverse_kl(student_logits, teacher_logits, mask) - student_logits.backward(grad) - optimizer.step() - - Args: - student_logits, teacher_logits: ``(*, V)`` logit tensors with - arbitrary leading dims; both must share shape and dtype. - completions_mask: Mask with shape matching - ``student_logits.shape[:-1]``. - - Returns: - ``(loss, grad_student_logits)`` — ``loss`` is a 0-dim tensor in - ``student_logits.dtype``; ``grad_student_logits`` matches - ``student_logits.shape`` and is already scaled by - ``1 / n_valid`` (undefined when ``n_valid == 0`` — fully-masked - batches produce inf/NaN; callers must guard upstream). The grad - tensor is freshly allocated per call (no shared cache). - - For inference / validation where you only need the loss, use - :func:`reverse_kl_fwd` — it skips the gradient slab entirely. 
- """ - student_2d, teacher_2d, flat_mask, leading_shape, _, vocab = _flatten_inputs( - student_logits, teacher_logits, completions_mask - ) - run = _get_compiled_fwd_bwd(student_logits.dtype, vocab) - loss, grad_2d = run(student_2d, teacher_2d, flat_mask) - grad_student = grad_2d.view((*leading_shape, vocab)) - return loss, grad_student - - -# Imported at the bottom: ``autograd.py`` imports ``reverse_kl`` from -# this module, so the function must be fully defined before its import runs. -from .autograd import reverse_kl_autograd # noqa: E402 diff --git a/build/torch-cuda/reverse_kl/_torch_ref.py b/build/torch-cuda/reverse_kl/_torch_ref.py deleted file mode 100644 index 73b0db92c145f18bb8030cc224acc6d21a91a7bc..0000000000000000000000000000000000000000 --- a/build/torch-cuda/reverse_kl/_torch_ref.py +++ /dev/null @@ -1,65 +0,0 @@ -"""Plain-PyTorch reverse-KL reference shared between bench and tests. - -Every op is a vanilla torch op so AOTAutograd can derive the joint -fwd+bwd graph and Inductor can fuse both passes (used by -``benchmarks/benchmark_reverse_kl.py``'s compiled baseline). -The same function is imported by ``tests/test_reverse_kl.py`` -as the correctness reference, so both paths agree on what "the eager -torch implementation of reverse-KL self-distillation" means. - -Reverse-KL definition (KL(student || teacher)): - - p = softmax(student) - q = softmax(teacher) - kl_per_row = sum_v p_v * (log p_v - log q_v) - loss = sum_r mask_r * kl_per_row[r] / max(sum_r mask_r, 1) - -The ``clamp(min=1)`` matches TRL's masked-mean convention so a -fully-masked batch yields ``loss=0`` instead of NaN, mirroring the -cute kernel's ``cute.arch.fmax(n_valid, 1.0)`` clamp. - -Underscore-prefixed module name signals "shared internal", not a public -API surface — there's no re-export from the package's top-level -``__init__.py``. 
-""" - -from __future__ import annotations - -import torch -from torch.nn.functional import kl_div, log_softmax - - -def torch_reverse_kl( - student_logits: torch.Tensor, - teacher_logits: torch.Tensor, - completions_mask: torch.Tensor, -) -> torch.Tensor: - """Compute reverse-KL divergence loss. - - Computes per-token reverse KL divergence: - - KL(student || teacher) = sum_v p(v) [log p(v) - log q(v)] - - where p is the student distribution and q is the teacher distribution, - both obtained by softmax over the vocabulary dimension. - - Args: - student_logits: Student logits, shape (N, C, V). - teacher_logits: Teacher logits, shape (N, C, V). - completions_mask: Boolean mask of shape (N, C) where True marks - valid tokens. Pass an all-ones mask if no tokens are padded. - - Returns: - Scalar tensor representing the loss. - """ - log_p = log_softmax(student_logits, dim=-1) - log_q = log_softmax(teacher_logits, dim=-1) - - # kl_div(input, target, log_target=True) computes KL(target || input) - # so input=log_q, target=log_p gives KL(student || teacher) - kl = kl_div(log_q, log_p, log_target=True, reduction="none").sum(dim=-1) - - n_valid = completions_mask.sum().to(torch.float32).clamp(min=1) - kl = (kl * completions_mask).sum() / n_valid - - return kl.to(student_logits.dtype) diff --git a/build/torch-cuda/reverse_kl/autograd.py b/build/torch-cuda/reverse_kl/autograd.py deleted file mode 100644 index 9d2d4e602a9e8efd2a75cfeb60df6ac8b68e6573..0000000000000000000000000000000000000000 --- a/build/torch-cuda/reverse_kl/autograd.py +++ /dev/null @@ -1,122 +0,0 @@ -"""Autograd-aware wrapper for reverse-KL self-distillation via ``torch.library.custom_op``. - -The fused cute kernel writes both the scalar loss and the closed-form -``dL/d(student_logits)`` in one launch.
This module wraps that into an -autograd-compatible op so callers can write:: - - loss = reverse_kl_autograd(student, teacher, completions_mask) - loss.backward() # propagates through to whatever produced student_logits - -instead of the manual ``student.backward(grad)`` chain. The cost is -~12µs of autograd dispatcher overhead per call (vs the direct -``reverse_kl`` (loss, grad) tuple); for ergonomic / -``kernelize()`` flows that's cheap, but for tight microbenches use the -direct path. - -Implementation notes: - -- The registered op returns ``(loss, grad_student)`` so - ``setup_context`` can ``save_for_backward(grad_student)``. The public - :func:`reverse_kl_autograd` wrapper hides the second output. -- The runner allocates ``grad_student`` fresh on every call (no shared - cache), so ``ctx.save_for_backward(grad_student)`` keeps a stable - reference for free. -- Backward returns ``grad_loss * grad_student``. Under - ``torch.compile``, when ``loss`` is consumed by ``.backward()`` - directly, ``grad_loss`` is the constant 1.0 and Inductor can fold the - multiply away — the main reason this path uses ``custom_op`` instead - of a plain ``autograd.Function``. -- ``register_fake`` provides the meta kernel for ``torch.compile`` - shape propagation; the real cute kernel never runs under - ``FakeTensorMode``. -""" - -from __future__ import annotations - -import torch - -from . 
import reverse_kl as _reverse_kl_fwd_bwd - -__all__ = ["reverse_kl_autograd"] - - -@torch.library.custom_op( - "geometric_ai_kernels::_reverse_kl_with_grad", - mutates_args=(), -) -def _reverse_kl_with_grad( - student_logits: torch.Tensor, - teacher_logits: torch.Tensor, - completions_mask: torch.Tensor, -) -> tuple[torch.Tensor, torch.Tensor]: - loss, grad_student = _reverse_kl_fwd_bwd( - student_logits, - teacher_logits, - completions_mask, - ) - return loss, grad_student - - -@_reverse_kl_with_grad.register_fake -def _( - student_logits: torch.Tensor, - teacher_logits: torch.Tensor, - completions_mask: torch.Tensor, -) -> tuple[torch.Tensor, torch.Tensor]: - # Signature must mirror the op; only ``student_logits`` shapes the outputs. - del teacher_logits, completions_mask - loss = student_logits.new_empty(()) - grad_student = torch.empty_like(student_logits) - return loss, grad_student - - -def _setup_context(ctx, inputs, output) -> None: # type: ignore[no-untyped-def] - del inputs # only ``output`` carries what we need to save. - _, grad_student = output - ctx.save_for_backward(grad_student) - - -def _backward(ctx, grad_loss, grad_grad_student): # type: ignore[no-untyped-def] - # ``grad_grad_student`` is unused — ``grad_student`` is an internal - # intermediate exposed only so ``setup_context`` can save it. Under - # typical usage (``loss.backward()``) it arrives as ``None`` or a - # zero tensor. - del grad_grad_student - (grad_student,) = ctx.saved_tensors - grad_input = grad_loss * grad_student - # One return per input to the op (3): student_logits gets the grad, - # teacher_logits and completions_mask have no autograd flow. 
- return grad_input, None, None - - -torch.library.register_autograd( - "geometric_ai_kernels::_reverse_kl_with_grad", - _backward, - setup_context=_setup_context, -) - - -def reverse_kl_autograd( - student_logits: torch.Tensor, - teacher_logits: torch.Tensor, - completions_mask: torch.Tensor, -) -> torch.Tensor: - """Autograd-aware reverse-KL self-distillation. Returns scalar ``loss``. - - Same numerics as :func:`reverse_kl` but registered as a - ``torch.library`` custom op with autograd, so:: - - loss = reverse_kl_autograd(student, teacher, completions_mask) - loss.backward() - - propagates through to whatever produced ``student_logits``. For - direct ``(loss, grad)`` access without the autograd dispatcher - overhead, use :func:`reverse_kl` and chain the gradient - manually via ``student_logits.backward(grad)``. - - Composes with ``torch.compile``: the op is opaque to Inductor but - has a fake/meta kernel registered, so models containing this layer - can be compiled end-to-end without graph breaks. - """ - loss, _ = _reverse_kl_with_grad(student_logits, teacher_logits, completions_mask) - return loss diff --git a/build/torch-cuda/reverse_kl/cute_reverse_kl.py b/build/torch-cuda/reverse_kl/cute_reverse_kl.py deleted file mode 100644 index 3f9a7a16c50135af790146306dd5ccc5ddd6e744..0000000000000000000000000000000000000000 --- a/build/torch-cuda/reverse_kl/cute_reverse_kl.py +++ /dev/null @@ -1,881 +0,0 @@ -"""Fused single-pass reverse-KL self-distillation loss kernel (CuteDSL, SM90). - -Computes ``KL(student || teacher)`` over a ``(num_tokens, vocab)`` slab using -an online normalisation algorithm that reads each logit row exactly once. -The two log-softmax passes, the element-wise product ``p * (log_p - log_q)``, -the sum-reduction over ``V``, the mask application, and the final mean are -all fused into one launch: - -1. **Mask-sum kernel** — reduces ``mask_flat`` and writes - ``inv_n_valid = 1 / sum(mask)``. 
Undefined when ``sum(mask) == 0`` - (fully-masked batches produce inf/NaN in the final loss); callers - must guard upstream if that case is reachable. -2. **Per-row main kernel** — one CTA per token; computes per-row KL and - (when ``compute_backward=True``) writes the analytical gradient - through the softmax Jacobian: - - ``grad_student_v = scale * p_v * (log_p_v - log_q_v - KL_per_row)`` - - where ``scale = mask[r] * grad_output * inv_n_valid``. The two passes - per row (online stats then gradient write-out) execute within a single - CTA so the Pass-1 ``KL_per_row`` is broadcast through SMEM with no - DSMEM/cluster overhead. The cross-row loss reduction piggybacks on the - atomic-add + last-block-detect pattern used by ``cute_bnpo_loss``. - -When ``compute_backward=False`` Pass 2, the broadcast SMEM, the -``grad_output`` read, and the ``* grad_output`` factor in the final scalar -are all dead-code-eliminated at trace time — the kernel becomes a pure -forward path identical in PTX to a hand-written fwd-only kernel. - -A single public entry point :func:`create_compiled_reverse_kl` -JIT-compiles either the fwd-only or fused fwd+bwd path depending on its -``compute_backward`` flag. ``vocab`` is captured as a compile-time -constant (the tile + tail layout closes over it); the number of token -rows is symbolic and may vary across calls. 
-""" - -from __future__ import annotations - -import math -import operator -from typing import TYPE_CHECKING, Any - -import cutlass -import torch -from cutlass import cute -from cutlass._mlir.dialects import llvm -from cutlass.base_dsl.typing import cast -from cutlass.cute.nvgpu import CacheEvictionPriority, CopyUniversalOp -from cutlass.cutlass_dsl import T, dsl_user_op -from cutlass.utils import SmemAllocator - -if TYPE_CHECKING: - from collections.abc import Callable - - -# --------------------------------------------------------------------------- -# Tunable constants — shared by fwd-only and fwd+bwd kernels -# --------------------------------------------------------------------------- -# Wider tile, more threads, and 128-bit loads so the two-pass (fwd+bwd) and -# single-pass (fwd-only) variants both stay bandwidth-bound from the same -# specialisation. ``FB_VEC_SIZE`` is recomputed at compile time from -# ``LOAD_BITS`` and the dtype width: -# FP16/BF16 → VEC=8, TILE_V=8192 -# FP32 → VEC=4, TILE_V=4096 -FB_NUM_THREADS = 1024 -FB_NUM_WARPS = FB_NUM_THREADS // 32 # 32 -FB_LOAD_BITS = 128 - -_LOG2E = math.log2(math.e) # 1.4426950408889634 - -_TORCH_TO_CUTLASS_DTYPE: dict[torch.dtype, Any] = { - torch.float32: cutlass.Float32, - torch.float16: cutlass.Float16, - torch.bfloat16: cutlass.BFloat16, -} - - -# --------------------------------------------------------------------------- -# Vendored atomic helpers (mirrors cute_bnpo_loss._atomic_*). Copied locally -# so the reverse_kl subpackage stays independent of bnpo_loss — -# the two ship together today, but kernel packages should be free-standing -# so a single one can be peeled off for separate publishing. 
-# --------------------------------------------------------------------------- - - -@dsl_user_op -def _atomic_add_f32_gmem( - ptr_i64: Any, - val: cutlass.Float32, - *, - loc: Any = None, - ip: Any = None, -) -> None: - llvm.inline_asm( - T.f32(), - [ptr_i64, cutlass.Float32(val).ir_value(loc=loc, ip=ip)], - "atom.global.add.f32 $0, [$1], $2;", - "=f,l,f", - has_side_effects=True, - is_align_stack=False, - asm_dialect=llvm.AsmDialect.AD_ATT, - ) - - -@dsl_user_op -def _atomic_inc_u32_gmem( - ptr_i64: Any, - threshold: cutlass.Int32, - *, - loc: Any = None, - ip: Any = None, -) -> cutlass.Int32: - """``atom.global.inc.u32`` — returns old value; wraps to 0 at threshold.""" - return cutlass.Int32( - llvm.inline_asm( - T.i32(), - [ptr_i64, cutlass.Int32(threshold).ir_value(loc=loc, ip=ip)], - "atom.global.inc.u32 $0, [$1], $2;", - "=r,l,r", - has_side_effects=True, - is_align_stack=False, - asm_dialect=llvm.AsmDialect.AD_ATT, - ) - ) - - -# --------------------------------------------------------------------------- -# Mask-sum kernel — reduces ``mask_flat`` (fp32, length N) to a per-block -# partial via warp + SMEM reduction, then atomically accumulates into -# ``valid_acc``; the last block writes ``rcp_approx(n_valid)`` into -# ``inv_n_valid`` and resets ``valid_acc`` to 0. Counter self-resets via -# ``atom.inc.u32`` wrap-around. -# --------------------------------------------------------------------------- - - -def _make_mask_sum_kernel( - fb_num_threads: int, - fb_num_warps: int, -) -> Callable[..., None]: - """Return a ``@cute.kernel`` that reduces ``mask_flat`` and writes 1/sum. - - Grid: ``(num_blocks, 1, 1)`` where each block processes - ``fb_num_threads`` mask elements (one element per thread, no - vectorisation — sufficient for the small N relative to vocab work). - A separate ``mask_counter`` (not shared with the main kernel's - ``counter``) is required because both rely on ``atom.inc.u32`` - wrap-around for self-reset. 
- """ - - @cute.kernel - def _mask_sum_kernel( - mask_flat: cute.Tensor, - inv_n_valid: cute.Tensor, - valid_acc: cute.Tensor, - mask_counter: cute.Tensor, - num_rows: cutlass.Int32, - total_blocks: cutlass.Int32, - ) -> None: - block_size = fb_num_warps * 32 - bidx = cute.arch.block_idx()[0] - tidx = cute.arch.thread_idx()[0] - - global_idx = bidx * block_size + tidx - local_val = cutlass.Float32(0.0) - if global_idx < num_rows: - local_val = mask_flat[global_idx] - - warp_val = cute.arch.warp_reduction(local_val, operator.add) - - smem = SmemAllocator() - buf = smem.allocate_tensor(cutlass.Float32, cute.make_layout(fb_num_warps)) - - lane_idx = cute.arch.lane_idx() - warp_idx = cute.arch.warp_idx() - - if lane_idx == 0: - buf[warp_idx] = warp_val - cute.arch.barrier() - - if warp_idx == 0: - v = cutlass.Float32(0.0) - if lane_idx < fb_num_warps: - v = buf[lane_idx] - block_sum = cute.arch.warp_reduction(v, operator.add, threads_in_group=fb_num_warps) - - if lane_idx == 0: - valid_ptr = valid_acc.iterator.toint().ir_value() # ty: ignore[unresolved-attribute] - counter_ptr = mask_counter.iterator.toint().ir_value() # ty: ignore[unresolved-attribute] - _atomic_add_f32_gmem(valid_ptr, block_sum) - cute.arch.fence_acq_rel_gpu() - old = _atomic_inc_u32_gmem(counter_ptr, total_blocks - 1) - if old == total_blocks - 1: - n_valid = valid_acc[0] - inv_n_valid[0] = cute.arch.rcp_approx(n_valid) - valid_acc[0] = cutlass.Float32(0.0) - - return _mask_sum_kernel - - -# --------------------------------------------------------------------------- -# Main per-row reverse-KL kernel (fused fwd[+bwd]). -# --------------------------------------------------------------------------- - - -def _make_reverse_kl_kernel( - compute_backward: bool, - fb_num_threads: int, - fb_num_warps: int, - fb_vec_size: int, - fb_tile_v: int, -) -> Callable[..., None]: - """Return a ``@cute.kernel`` that fuses fwd loss + (optional) bwd grad. - - One CTA processes one row. 
Pass 1 computes the online softmax stats - ``(max_s, D_s, W = sum exp(s - max_s) (s - t), max_t, D_t)`` and the - per-row ``KL = W/D_s + (max_t - max_s) + log(D_t) - log(D_s)``. - - When ``compute_backward=True`` the block-reduced stats are broadcast - through SMEM and Pass 2 re-reads student/teacher to write the - analytical gradient - ``grad_v = scale * p_v * (log_p_v - log_q_v - KL)`` where - ``p_v = exp(s_v - max_s) / D_s``, - ``log_p_v - log_q_v = (s_v - t_v) + (max_t - max_s) + log(D_t) - log(D_s)``, - and ``scale = mask[r] * grad_output * inv_n_valid``. - - When ``compute_backward=False`` Pass 2, the bcast SMEM, the - ``grad_output`` read, and the ``* grad_output`` factor in the final - scalar are eliminated at trace time. - - Cross-row loss accumulation rides on an atomic ``loss_acc`` plus a - wrap-around ``counter`` for last-block detection. The kernel always - reads ``mask_flat[bidx]`` (callers pass an all-ones mask in the - unmasked case) and, in the bwd path, short-circuits Pass 2 — writes - zeros and skips loss accumulation — for rows whose ``scale`` is 0. - """ - - @cute.kernel - def _kernel( - student: cute.Tensor, - teacher: cute.Tensor, - mask_flat: cute.Tensor, - grad_output: cute.Tensor, - inv_n_valid: cute.Tensor, - grad_student: cute.Tensor, - loss_acc: cute.Tensor, - counter: cute.Tensor, - output: cute.Tensor, - num_full_tiles: int, - tail_len: int, - total_rows: cutlass.Int32, - tiled_copy_p1: cute.TiledCopy, - copy_atom_p1: cute.CopyAtom, - tiled_copy_p2: cute.TiledCopy, - copy_atom_p2: cute.CopyAtom, - copy_atom_store: cute.CopyAtom, - ) -> None: - tidx = cute.arch.thread_idx()[0] - bidx = cute.arch.block_idx()[0] - - # SMEM: - # red_buf: NUM_WARPS x 5 — per-warp partials for cross-warp reduce - # bcast: 6 floats — broadcast (kl, max_s, log_d_s, max_t, - # log_d_t, scale) to all threads in pass 2 - # The bcast buffer is only needed when ``compute_backward`` is True. 
- smem = SmemAllocator() - red_buf = smem.allocate_tensor(cutlass.Float32, cute.make_layout((fb_num_warps, 5))) - if cutlass.const_expr(compute_backward): - bcast = smem.allocate_tensor(cutlass.Float32, cute.make_layout(6)) - - log2e = cutlass.Float32(_LOG2E) - neg_inf = cutlass.Float32(-cutlass.Float32.inf) - - # Per-row scale = mask[bidx] * grad_output * inv_n_valid. The - # ``grad_output`` read is dead-code-eliminated for the fwd-only - # path so no spurious GMEM load is emitted. - mask_val = cast(mask_flat[bidx], cutlass.Float32) - inv_n = cast(inv_n_valid[0], cutlass.Float32) - if cutlass.const_expr(compute_backward): - grad_val = cast(grad_output[0], cutlass.Float32) - scale_val = mask_val * grad_val * inv_n - - s_row = cute.slice_(student, (cutlass.Int64(bidx), None)) - t_row = cute.slice_(teacher, (cutlass.Int64(bidx), None)) - if cutlass.const_expr(compute_backward): - g_row = cute.slice_(grad_student, (cutlass.Int64(bidx), None)) - out_dtype = grad_student.element_type - - max_s = neg_inf - d_s = cutlass.Float32(0.0) - w_acc = cutlass.Float32(0.0) - max_t = neg_inf - d_t = cutlass.Float32(0.0) - - thr_copy_p1 = tiled_copy_p1.get_slice(tidx) - - # ---- Pass 1: online stats + W ---- - for k in cutlass.range(num_full_tiles, unroll=2): - s_slab = cute.local_tile(s_row, (fb_tile_v,), (k,)) - t_slab = cute.local_tile(t_row, (fb_tile_v,), (k,)) - - src_s = thr_copy_p1.partition_S(s_slab) - frag_s = cute.make_fragment_like(src_s) - cute.copy(copy_atom_p1, src_s, frag_s) - - src_t = thr_copy_p1.partition_S(t_slab) - frag_t = cute.make_fragment_like(src_t) - cute.copy(copy_atom_p1, src_t, frag_t) - - s_f32 = frag_s.load().to(cutlass.Float32) - t_f32 = frag_t.load().to(cutlass.Float32) - - tile_max_s = s_f32.reduce(cute.ReductionOp.MAX, neg_inf, reduction_profile=0) - exp_s = cute.math.exp2((s_f32 - tile_max_s) * log2e, fastmath=True) - tile_d_s = exp_s.reduce( # ty: ignore[unresolved-attribute] - cute.ReductionOp.ADD, cutlass.Float32(0.0), reduction_profile=0 - ) - 
diff = s_f32 - t_f32 - tile_w = (exp_s * diff).reduce( - cute.ReductionOp.ADD, cutlass.Float32(0.0), reduction_profile=0 - ) - - new_max_s = cute.arch.fmax(max_s, tile_max_s) - corr_s = cute.math.exp2((max_s - new_max_s) * log2e, fastmath=True) - tile_corr_s = cute.math.exp2((tile_max_s - new_max_s) * log2e, fastmath=True) - d_s = d_s * corr_s + tile_d_s * tile_corr_s - w_acc = w_acc * corr_s + tile_w * tile_corr_s - max_s = new_max_s - - tile_max_t = t_f32.reduce(cute.ReductionOp.MAX, neg_inf, reduction_profile=0) - exp_t = cute.math.exp2((t_f32 - tile_max_t) * log2e, fastmath=True) - tile_d_t = exp_t.reduce( # ty: ignore[unresolved-attribute] - cute.ReductionOp.ADD, cutlass.Float32(0.0), reduction_profile=0 - ) - - new_max_t = cute.arch.fmax(max_t, tile_max_t) - corr_t = cute.math.exp2((max_t - new_max_t) * log2e, fastmath=True) - tile_corr_t = cute.math.exp2((tile_max_t - new_max_t) * log2e, fastmath=True) - d_t = d_t * corr_t + tile_d_t * tile_corr_t - max_t = new_max_t - - if tail_len > 0: - tail_base = num_full_tiles * fb_tile_v - - thr_max_s = neg_inf - thr_max_t = neg_inf - for i in cutlass.range(fb_vec_size): - e = tidx + i * fb_num_threads - if e < tail_len: - s_val = cast(s_row[tail_base + e], cutlass.Float32) - t_val = cast(t_row[tail_base + e], cutlass.Float32) - thr_max_s = cute.arch.fmax(thr_max_s, s_val) - thr_max_t = cute.arch.fmax(thr_max_t, t_val) - - thr_d_s = cutlass.Float32(0.0) - thr_w = cutlass.Float32(0.0) - thr_d_t = cutlass.Float32(0.0) - for i in cutlass.range(fb_vec_size): - e = tidx + i * fb_num_threads - if e < tail_len: - s_val = cast(s_row[tail_base + e], cutlass.Float32) - t_val = cast(t_row[tail_base + e], cutlass.Float32) - exp_sv = cute.math.exp2((s_val - thr_max_s) * log2e, fastmath=True) - thr_d_s = thr_d_s + exp_sv - thr_w = thr_w + exp_sv * (s_val - t_val) - exp_tv = cute.math.exp2((t_val - thr_max_t) * log2e, fastmath=True) - thr_d_t = thr_d_t + exp_tv - - new_max_s = cute.arch.fmax(max_s, thr_max_s) - corr_s = 
cute.math.exp2((max_s - new_max_s) * log2e, fastmath=True) - tail_corr_s = cute.math.exp2((thr_max_s - new_max_s) * log2e, fastmath=True) - d_s = d_s * corr_s + thr_d_s * tail_corr_s - w_acc = w_acc * corr_s + thr_w * tail_corr_s - max_s = new_max_s - - new_max_t = cute.arch.fmax(max_t, thr_max_t) - corr_t = cute.math.exp2((max_t - new_max_t) * log2e, fastmath=True) - tail_corr_t = cute.math.exp2((thr_max_t - new_max_t) * log2e, fastmath=True) - d_t = d_t * corr_t + thr_d_t * tail_corr_t - max_t = new_max_t - - # ---- Cross-warp reduction for student/teacher stats ---- - warp_max_s = cute.arch.warp_reduction(max_s, cute.arch.fmax) - corr_w_s = cute.math.exp2((max_s - warp_max_s) * log2e, fastmath=True) - d_s = d_s * corr_w_s - w_acc = w_acc * corr_w_s - warp_d_s = cute.arch.warp_reduction(d_s, operator.add) - warp_w = cute.arch.warp_reduction(w_acc, operator.add) - - warp_max_t = cute.arch.warp_reduction(max_t, cute.arch.fmax) - corr_w_t = cute.math.exp2((max_t - warp_max_t) * log2e, fastmath=True) - d_t = d_t * corr_w_t - warp_d_t = cute.arch.warp_reduction(d_t, operator.add) - - lane_idx = cute.arch.lane_idx() - warp_idx = cute.arch.warp_idx() - - if lane_idx == 0: - red_buf[warp_idx, 0] = warp_max_s - red_buf[warp_idx, 1] = warp_d_s - red_buf[warp_idx, 2] = warp_w - red_buf[warp_idx, 3] = warp_max_t - red_buf[warp_idx, 4] = warp_d_t - cute.arch.sync_threads() - - # Warp 0 finishes the cross-warp reduction. 
- if warp_idx == 0: - r_max_s = neg_inf - r_d_s = cutlass.Float32(0.0) - r_w = cutlass.Float32(0.0) - r_max_t = neg_inf - r_d_t = cutlass.Float32(0.0) - - if lane_idx < fb_num_warps: - r_max_s = red_buf[lane_idx, 0] - r_d_s = red_buf[lane_idx, 1] - r_w = red_buf[lane_idx, 2] - r_max_t = red_buf[lane_idx, 3] - r_d_t = red_buf[lane_idx, 4] - - final_max_s = cute.arch.warp_reduction( - r_max_s, cute.arch.fmax, threads_in_group=fb_num_warps - ) - fcorr_s = cute.math.exp2((r_max_s - final_max_s) * log2e, fastmath=True) - r_d_s = r_d_s * fcorr_s - r_w = r_w * fcorr_s - final_d_s = cute.arch.warp_reduction(r_d_s, operator.add, threads_in_group=fb_num_warps) - final_w = cute.arch.warp_reduction(r_w, operator.add, threads_in_group=fb_num_warps) - - final_max_t = cute.arch.warp_reduction( - r_max_t, cute.arch.fmax, threads_in_group=fb_num_warps - ) - fcorr_t = cute.math.exp2((r_max_t - final_max_t) * log2e, fastmath=True) - r_d_t = r_d_t * fcorr_t - final_d_t = cute.arch.warp_reduction(r_d_t, operator.add, threads_in_group=fb_num_warps) - - if lane_idx == 0: - rcp_d_s = cute.arch.rcp_approx(final_d_s) - log_d_s = cute.math.log(final_d_s) - log_d_t = cute.math.log(final_d_t) - kl = final_w * rcp_d_s + log_d_t + final_max_t - log_d_s - final_max_s - - # Cross-row loss accumulation: atomic-add (kl * mask) into - # ``loss_acc``; the last block scales by ``inv_n_valid`` - # (and ``grad_output`` on the bwd path) and writes the - # scalar output. 
- loss_ptr = loss_acc.iterator.toint().ir_value() # ty: ignore[unresolved-attribute] - counter_ptr = counter.iterator.toint().ir_value() # ty: ignore[unresolved-attribute] - contribution = kl * mask_val - _atomic_add_f32_gmem(loss_ptr, contribution) - cute.arch.fence_acq_rel_gpu() - old = _atomic_inc_u32_gmem(counter_ptr, total_rows - 1) - if old == total_rows - 1: - if cutlass.const_expr(compute_backward): - final_loss = loss_acc[0] * inv_n * grad_val - else: - final_loss = loss_acc[0] * inv_n - output[0] = cast(final_loss, output.element_type) # ty: ignore[invalid-argument-type] - loss_acc[0] = cutlass.Float32(0.0) - - if cutlass.const_expr(compute_backward): - bcast[0] = kl - bcast[1] = final_max_s - bcast[2] = log_d_s - bcast[3] = final_max_t - bcast[4] = log_d_t - bcast[5] = scale_val - - # ---- Pass 2: re-read logits and write gradient ---- - # Entire pass is dead-code-eliminated when ``compute_backward`` is - # False — the kernel becomes a single-pass forward identical in - # PTX to a hand-written fwd-only kernel. - if cutlass.const_expr(compute_backward): - cute.arch.sync_threads() - - kl_b = bcast[0] - max_s_b = bcast[1] - log_d_s_b = bcast[2] - max_t_b = bcast[3] - log_d_t_b = bcast[4] - scale_b = bcast[5] - - log_offset = max_t_b - max_s_b + log_d_t_b - log_d_s_b - - # Skip Pass 2 entirely for rows with scale=0 (masked-out rows). 
- if scale_b != cutlass.Float32(0.0): - thr_copy_p2 = tiled_copy_p2.get_slice(tidx) - - for k in cutlass.range(num_full_tiles, unroll=2): - s_slab = cute.local_tile(s_row, (fb_tile_v,), (k,)) - t_slab = cute.local_tile(t_row, (fb_tile_v,), (k,)) - g_slab = cute.local_tile(g_row, (fb_tile_v,), (k,)) - - src_s = thr_copy_p2.partition_S(s_slab) - frag_s = cute.make_fragment_like(src_s) - cute.copy(copy_atom_p2, src_s, frag_s) - - src_t = thr_copy_p2.partition_S(t_slab) - frag_t = cute.make_fragment_like(src_t) - cute.copy(copy_atom_p2, src_t, frag_t) - - s_f32 = frag_s.load().to(cutlass.Float32) - t_f32 = frag_t.load().to(cutlass.Float32) - - # p_v = exp((s_v - max_s) * log2e) / D_s = exp(s_v - logZ_s) - p_v = cute.math.exp2((s_f32 - max_s_b - log_d_s_b) * log2e, fastmath=True) - log_diff = s_f32 - t_f32 + log_offset - grad = scale_b * p_v * (log_diff - kl_b) - - dst_g = thr_copy_p2.partition_D(g_slab) - out_frag = cute.make_fragment_like(dst_g) - out_frag.store(grad.to(out_dtype)) - cute.copy(copy_atom_store, out_frag, dst_g) - - if tail_len > 0: - tail_base = num_full_tiles * fb_tile_v - for i in cutlass.range(fb_vec_size): - e = tidx + i * fb_num_threads - if e < tail_len: - s_val = cast(s_row[tail_base + e], cutlass.Float32) - t_val = cast(t_row[tail_base + e], cutlass.Float32) - p_v = cute.math.exp2( - (s_val - max_s_b - log_d_s_b) * log2e, fastmath=True - ) - log_diff = s_val - t_val + log_offset - grad = scale_b * p_v * (log_diff - kl_b) - g_row[tail_base + e] = cast(grad, out_dtype) # ty: ignore[invalid-argument-type] - else: - # Masked row: write zeros so callers see a clean grad slab. 
- zero_elem = cast(cutlass.Float32(0.0), out_dtype) # ty: ignore[invalid-argument-type] - thr_copy_z = tiled_copy_p2.get_slice(tidx) - for k in cutlass.range(num_full_tiles, unroll=2): - g_slab = cute.local_tile(g_row, (fb_tile_v,), (k,)) - dst_g = thr_copy_z.partition_D(g_slab) - zero_frag = cute.make_fragment_like(dst_g) - zero_frag.fill(zero_elem) - cute.copy(copy_atom_store, zero_frag, dst_g) - - if tail_len > 0: - tail_base = num_full_tiles * fb_tile_v - for i in cutlass.range(fb_vec_size): - e = tidx + i * fb_num_threads - if e < tail_len: - g_row[tail_base + e] = zero_elem - - return _kernel - - -def _make_fwd_bwd_launcher( - compute_backward: bool, - fb_num_threads: int, - fb_num_warps: int, - fb_vec_size: int, - fb_tile_v: int, -) -> Callable[..., None]: - """Return a ``@cute.jit`` launcher that runs mask-sum + main kernel. - - When ``compute_backward=False`` Pass 2 + bcast SMEM are dead-code - eliminated inside the kernel and Pass 1 loads default to NO_ALLOCATE - (no benefit from L2 pinning since there is no re-read). - """ - main_kernel = _make_reverse_kl_kernel( - compute_backward, fb_num_threads, fb_num_warps, fb_vec_size, fb_tile_v - ) - mask_sum_kernel = _make_mask_sum_kernel(fb_num_threads, fb_num_warps) - - @cute.jit - def _launch( - student_2d: cute.Tensor, - teacher_2d: cute.Tensor, - mask_flat: cute.Tensor, - grad_output: cute.Tensor, - inv_n_valid: cute.Tensor, - grad_student_2d: cute.Tensor, - loss_acc: cute.Tensor, - valid_acc: cute.Tensor, - counter: cute.Tensor, - mask_counter: cute.Tensor, - output: cute.Tensor, - num_full_tiles: cutlass.Int32, - tail_len: cutlass.Int32, - mask_sum_blocks: cutlass.Int32, - ) -> None: - num_rows = student_2d.shape[0] # ty: ignore[not-subscriptable] - dtype = student_2d.element_type - out_dtype = grad_student_2d.element_type - - # Pass 1 loads: EVICT_LAST when we need the data pinned in L2 for - # Pass 2 re-reads; NO_ALLOCATE for fwd-only since the data is - # never re-read. 
- if cutlass.const_expr(compute_backward): - p1_evict = CacheEvictionPriority.EVICT_LAST - else: - p1_evict = CacheEvictionPriority.NO_ALLOCATE - copy_atom_p1 = cute.make_copy_atom( - CopyUniversalOp(), - dtype, - num_bits_per_copy=fb_vec_size * dtype.width, # ty: ignore[unresolved-attribute] - l1c_evict_priority=p1_evict, - ) - thr_layout = cute.make_layout((fb_num_threads,)) - val_layout = cute.make_layout((fb_vec_size,)) - tiler_v_p1, layout_tv_p1 = cute.make_layout_tv(thr_layout, val_layout) - tiled_copy_p1 = cute.make_tiled_copy(copy_atom_p1, layout_tv_p1, tiler_v_p1) - - # Pass 2 loads: NO_ALLOCATE — streaming, never re-read. - copy_atom_p2 = cute.make_copy_atom( - CopyUniversalOp(), - dtype, - num_bits_per_copy=fb_vec_size * dtype.width, # ty: ignore[unresolved-attribute] - l1c_evict_priority=CacheEvictionPriority.NO_ALLOCATE, - ) - tiler_v_p2, layout_tv_p2 = cute.make_layout_tv(thr_layout, val_layout) - tiled_copy_p2 = cute.make_tiled_copy(copy_atom_p2, layout_tv_p2, tiler_v_p2) - - # Stores: NO_ALLOCATE — write-once, never re-read. 
- copy_atom_store = cute.make_copy_atom( - CopyUniversalOp(), - out_dtype, - num_bits_per_copy=fb_vec_size * out_dtype.width, # ty: ignore[unresolved-attribute] - l1c_evict_priority=CacheEvictionPriority.NO_ALLOCATE, - ) - - mask_sum_kernel( # ty: ignore[unresolved-attribute] - mask_flat, - inv_n_valid, - valid_acc, - mask_counter, - num_rows, - mask_sum_blocks, - ).launch( - grid=(mask_sum_blocks, 1, 1), - block=(fb_num_threads, 1, 1), - ) - - main_kernel( # ty: ignore[unresolved-attribute] - student_2d, - teacher_2d, - mask_flat, - grad_output, - inv_n_valid, - grad_student_2d, - loss_acc, - counter, - output, - num_full_tiles, - tail_len, - num_rows, - tiled_copy_p1, - copy_atom_p1, - tiled_copy_p2, - copy_atom_p2, - copy_atom_store, - ).launch( - grid=(num_rows, 1, 1), - block=(fb_num_threads, 1, 1), - ) - - return _launch - - -# --------------------------------------------------------------------------- -# Public factory -# --------------------------------------------------------------------------- - - -def create_compiled_reverse_kl( - policy_dtype: torch.dtype, - vocab: int, - compute_backward: bool = False, -) -> Callable[..., torch.Tensor | tuple[torch.Tensor, torch.Tensor]]: - """JIT-compile the fused reverse-KL kernel (forward or fused fwd+bwd). - - Bundles a mask-sum kernel (which computes ``inv_n_valid``) and the - main per-row kernel into a single ``@cute.jit`` launch so each call - is one tvm-ffi dispatch with no host syncs. The unmasked case is - handled by passing an all-ones ``mask_flat`` (callers are responsible). - - The returned runner always takes ``(student_2d, teacher_2d, mask_flat)`` - and allocates its scratch and output buffers itself. When - ``compute_backward=False`` it returns the scalar loss. When - ``compute_backward=True`` it also allocates a full ``grad_student_2d`` - slab and returns ``(loss_scalar, grad_student_2d)``. - - Args: - policy_dtype: Element dtype of student/teacher logits and the - output gradient slab. Both tensors must share this dtype. - vocab: ``V`` dimension of the ``(num_tokens, V)`` slab.
Captured - as a compile-time constant; the number of token rows stays - symbolic across calls. - compute_backward: When ``True`` the kernel additionally writes - the analytical gradient through the softmax Jacobian into - the caller's ``grad_student_2d`` slab. - """ - if policy_dtype not in _TORCH_TO_CUTLASS_DTYPE: - raise ValueError(f"Unsupported dtype for self-distillation loss: {policy_dtype}") - - student_dtype = _TORCH_TO_CUTLASS_DTYPE[policy_dtype] - fb_vec_size = FB_LOAD_BITS // student_dtype.width - fb_tile_v = FB_NUM_THREADS * fb_vec_size - - num_full_tiles = vocab // fb_tile_v - tail_len = vocab % fb_tile_v - block_size = FB_NUM_WARPS * 32 - - num_tokens_sym = cute.sym_int() - - fake_s = cute.runtime.make_fake_compact_tensor( - student_dtype, - (num_tokens_sym, vocab), - stride_order=(1, 0), - assumed_align=16, - ) - fake_t = cute.runtime.make_fake_compact_tensor( - student_dtype, - (num_tokens_sym, vocab), - stride_order=(1, 0), - assumed_align=16, - ) - fake_mask = cute.runtime.make_fake_compact_tensor( - cutlass.Float32, - (num_tokens_sym,), - assumed_align=16, - ) - fake_grad_out = cute.runtime.make_fake_compact_tensor( - cutlass.Float32, - (1,), - assumed_align=16, - ) - fake_inv_n = cute.runtime.make_fake_compact_tensor( - cutlass.Float32, - (1,), - assumed_align=16, - ) - # Full ``(N, V)`` slab when computing the backward; 1-column dummy - # slab for fwd-only since Pass 2 is dead-code-eliminated and only - # the tensor signature matters. 
- grad_v_dim = vocab if compute_backward else 1 - fake_grad_student = cute.runtime.make_fake_compact_tensor( - student_dtype, - (num_tokens_sym, grad_v_dim), - stride_order=(1, 0), - assumed_align=16, - ) - fake_loss = cute.runtime.make_fake_compact_tensor( - cutlass.Float32, - (1,), - assumed_align=16, - ) - fake_valid = cute.runtime.make_fake_compact_tensor( - cutlass.Float32, - (1,), - assumed_align=16, - ) - fake_counter = cute.runtime.make_fake_compact_tensor( - cutlass.Int32, - (1,), - assumed_align=16, - ) - fake_mask_counter = cute.runtime.make_fake_compact_tensor( - cutlass.Int32, - (1,), - assumed_align=16, - ) - fake_output = cute.runtime.make_fake_compact_tensor( - student_dtype, - (1,), - assumed_align=16, - ) - - launcher = _make_fwd_bwd_launcher( - compute_backward=compute_backward, - fb_num_threads=FB_NUM_THREADS, - fb_num_warps=FB_NUM_WARPS, - fb_vec_size=fb_vec_size, - fb_tile_v=fb_tile_v, - ) - compiled = cute.compile( - launcher, - fake_s, - fake_t, - fake_mask, - fake_grad_out, - fake_inv_n, - fake_grad_student, - fake_loss, - fake_valid, - fake_counter, - fake_mask_counter, - fake_output, - cutlass.Int32(num_full_tiles), - cutlass.Int32(tail_len), - cutlass.Int32(1), - options="--enable-tvm-ffi", - ) - - nft_const = cutlass.Int32(num_full_tiles) - tl_const = cutlass.Int32(tail_len) - grad_v_dim_runtime = vocab if compute_backward else 1 - - # ---- Closure-scoped scratch ---- - # ``grad_output`` is the upstream gradient feeding Pass 2 (1.0 for - # backward, irrelevant for forward — Pass 2 is dead-code-eliminated). - # Constant across calls; allocated lazily on first call when the - # device is known, then reused. - # - # ``scratch_z`` coalesces the 4 atomic-accumulator scalars (counter, - # mask_counter, loss_acc.fp32, valid_acc.fp32) into one int32 slab - # with stride-4 (16-byte) slices so each slot is individually 16-byte - # aligned (``assumed_align=16`` at compile time). 
Bit-pattern of int32 0 - # equals fp32 0.0, so a single ``zeros`` factory legitimately - # initialises both the int32 counters and the fp32 accumulators. - # Both kernels' last blocks self-reset their fp32 accumulators in - # their epilogues, and counters self-reset via ``atom.inc.u32`` - # wrap-around — so the up-front ``torch.zeros`` only matters for the - # first call. - _scratch: list[tuple[torch.Tensor, torch.Tensor] | None] = [None] - - def _ensure_scratch( - device: torch.device, - ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: - s = _scratch[0] - if s is None or s[0].device != device: - slab = torch.zeros(16, dtype=torch.int32, device=device) - if compute_backward: - grad_output = torch.ones(1, dtype=torch.float32, device=device) - else: - grad_output = torch.empty(1, dtype=torch.float32, device=device) - _scratch[0] = (slab, grad_output) - s = _scratch[0] - slab, grad_output = s - return ( - slab[0:1], # counter (int32) - slab[4:5], # mask_counter (int32) - slab[8:9].view(torch.float32), # loss_acc (fp32) - slab[12:13].view(torch.float32), # valid_acc (fp32) - grad_output, - ) - - def _run( - student_2d: torch.Tensor, - teacher_2d: torch.Tensor, - mask_flat: torch.Tensor, - ) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]: - num_rows = student_2d.shape[0] - device = student_2d.device - dtype = student_2d.dtype - - counter_r, mask_counter_r, loss_acc_r, valid_acc_r, grad_output_r = _ensure_scratch(device) - - # Per-call write-only buffers — ``empty`` is enough. - # ``inv_n_valid`` is populated by the bundled mask-sum kernel - # before the main kernel reads it; the runner never reads it. - inv_n_valid_r = torch.empty(1, dtype=torch.float32, device=device) - output_r = torch.empty(1, dtype=dtype, device=device) - # ``grad_v_dim`` is ``vocab`` for backward and ``1`` for forward - # (1-column dummy slab — Pass 2 is dead-code-eliminated, only the - # tensor-parameter signature matters). 
- grad_buffer = torch.empty(num_rows, grad_v_dim_runtime, dtype=dtype, device=device) - - mask_sum_blocks = (num_rows + block_size - 1) // block_size - compiled( - student_2d, - teacher_2d, - mask_flat, - grad_output_r, - inv_n_valid_r, - grad_buffer, - loss_acc_r, - valid_acc_r, - counter_r, - mask_counter_r, - output_r, - nft_const, - tl_const, - cutlass.Int32(mask_sum_blocks), - ) - out_view = output_r.view(()) - if compute_backward: - return out_view, grad_buffer - return out_view - - return _run
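
Reviewer note on the removed kernel's math: the Pass 1 recurrence and the Pass 2 gradient described in the docstrings above can be checked on the host with a few lines of plain Python. The sketch below is not the CuteDSL code. The helper names `log_z`, `direct_kl`, `online_kl`, and `row_grad` are hypothetical, the tile size is arbitrary, and masking, `inv_n_valid`, and `grad_output` scaling are left out. It only illustrates that carrying `(max_s, D_s, W, max_t, D_t)` across tiles reproduces the two-log-softmax definition of `KL(student || teacher)` for one row.

```python
import math


def log_z(row):
    """Numerically stable log-sum-exp of one logit row."""
    m = max(row)
    return m + math.log(sum(math.exp(x - m) for x in row))


def direct_kl(s, t):
    """Reference: KL(softmax(s) || softmax(t)) from two explicit log-softmaxes."""
    lz_s, lz_t = log_z(s), log_z(t)
    return sum(math.exp(sv - lz_s) * ((sv - lz_s) - (tv - lz_t)) for sv, tv in zip(s, t))


def online_kl(s, t, tile=4):
    """Single pass over the row in tiles, carrying (max_s, D_s, W, max_t, D_t)."""
    max_s, d_s, w = -math.inf, 0.0, 0.0
    max_t, d_t = -math.inf, 0.0
    for start in range(0, len(s), tile):
        s_tile = s[start:start + tile]
        t_tile = t[start:start + tile]
        # Student stats: rescale the running sums when the running max grows.
        new_max_s = max(max_s, max(s_tile))
        corr_s = math.exp(max_s - new_max_s)  # exp(-inf) is 0.0 on the first tile
        d_s = d_s * corr_s + sum(math.exp(x - new_max_s) for x in s_tile)
        w = w * corr_s + sum(math.exp(x - new_max_s) * (x - y) for x, y in zip(s_tile, t_tile))
        max_s = new_max_s
        # Teacher stats: same recurrence, without the W accumulator.
        new_max_t = max(max_t, max(t_tile))
        d_t = d_t * math.exp(max_t - new_max_t) + sum(math.exp(y - new_max_t) for y in t_tile)
        max_t = new_max_t
    # KL = W / D_s + (max_t - max_s) + log(D_t) - log(D_s)
    return w / d_s + (max_t - max_s) + math.log(d_t) - math.log(d_s)


def row_grad(s, t):
    """Unscaled per-row gradient: p_v * (log_p_v - log_q_v - KL)."""
    kl = online_kl(s, t)
    lz_s, lz_t = log_z(s), log_z(t)
    return [math.exp(sv - lz_s) * ((sv - lz_s) - (tv - lz_t) - kl) for sv, tv in zip(s, t)]
```

Because the softmax probabilities sum to one, the entries of `row_grad` sum to zero for any row, which gives a cheap sanity check alongside a finite-difference comparison against `direct_kl`.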