This view is limited to 50 files because it contains too many changes.
Files changed (50)
  1. README.md +3 -326
  2. benchmark_results/bnpo_loss_compiled/bnpo_loss_compiled_dark_animation.svg +0 -123
  3. benchmark_results/bnpo_loss_compiled/bnpo_loss_compiled_dark_latency.svg +0 -0
  4. benchmark_results/bnpo_loss_compiled/bnpo_loss_compiled_dark_throughput.svg +0 -0
  5. benchmark_results/bnpo_loss_compiled/bnpo_loss_compiled_light_animation.svg +0 -123
  6. benchmark_results/bnpo_loss_compiled/bnpo_loss_compiled_light_latency.svg +0 -0
  7. benchmark_results/bnpo_loss_compiled/bnpo_loss_compiled_light_throughput.svg +0 -0
  8. benchmark_results/bnpo_loss_compiled/results.json +0 -206
  9. benchmark_results/bnpo_loss_eager/bnpo_loss_eager_dark_animation.svg +0 -123
  10. benchmark_results/bnpo_loss_eager/bnpo_loss_eager_dark_latency.svg +0 -0
  11. benchmark_results/bnpo_loss_eager/bnpo_loss_eager_dark_throughput.svg +0 -0
  12. benchmark_results/bnpo_loss_eager/bnpo_loss_eager_light_animation.svg +0 -123
  13. benchmark_results/bnpo_loss_eager/bnpo_loss_eager_light_latency.svg +0 -0
  14. benchmark_results/bnpo_loss_eager/bnpo_loss_eager_light_throughput.svg +0 -0
  15. benchmark_results/bnpo_loss_eager/results.json +0 -206
  16. benchmark_results/grpo_loss_compiled/grpo_loss_compiled_dark_animation.svg +0 -105
  17. benchmark_results/grpo_loss_compiled/grpo_loss_compiled_dark_latency.svg +0 -0
  18. benchmark_results/grpo_loss_compiled/grpo_loss_compiled_dark_throughput.svg +0 -0
  19. benchmark_results/grpo_loss_compiled/grpo_loss_compiled_light_animation.svg +0 -105
  20. benchmark_results/grpo_loss_compiled/grpo_loss_compiled_light_latency.svg +0 -0
  21. benchmark_results/grpo_loss_compiled/grpo_loss_compiled_light_throughput.svg +0 -0
  22. benchmark_results/grpo_loss_compiled/results.json +0 -174
  23. benchmark_results/grpo_loss_eager/grpo_loss_eager_dark_animation.svg +0 -105
  24. benchmark_results/grpo_loss_eager/grpo_loss_eager_dark_latency.svg +0 -0
  25. benchmark_results/grpo_loss_eager/grpo_loss_eager_dark_throughput.svg +0 -0
  26. benchmark_results/grpo_loss_eager/grpo_loss_eager_light_animation.svg +0 -105
  27. benchmark_results/grpo_loss_eager/grpo_loss_eager_light_latency.svg +0 -0
  28. benchmark_results/grpo_loss_eager/grpo_loss_eager_light_throughput.svg +0 -0
  29. benchmark_results/grpo_loss_eager/results.json +0 -174
  30. benchmark_results/reverse_kl_compiled/results.json +0 -206
  31. benchmark_results/reverse_kl_compiled/reverse_kl_compiled_dark_animation.svg +0 -123
  32. benchmark_results/reverse_kl_compiled/reverse_kl_compiled_dark_latency.svg +0 -0
  33. benchmark_results/reverse_kl_compiled/reverse_kl_compiled_dark_throughput.svg +0 -0
  34. benchmark_results/reverse_kl_compiled/reverse_kl_compiled_light_animation.svg +0 -123
  35. benchmark_results/reverse_kl_compiled/reverse_kl_compiled_light_latency.svg +0 -0
  36. benchmark_results/reverse_kl_compiled/reverse_kl_compiled_light_throughput.svg +0 -0
  37. benchmark_results/reverse_kl_eager/results.json +0 -206
  38. benchmark_results/reverse_kl_eager/reverse_kl_eager_dark_animation.svg +0 -123
  39. benchmark_results/reverse_kl_eager/reverse_kl_eager_dark_latency.svg +0 -0
  40. benchmark_results/reverse_kl_eager/reverse_kl_eager_dark_throughput.svg +0 -0
  41. benchmark_results/reverse_kl_eager/reverse_kl_eager_light_animation.svg +0 -123
  42. benchmark_results/reverse_kl_eager/reverse_kl_eager_light_latency.svg +0 -0
  43. benchmark_results/reverse_kl_eager/reverse_kl_eager_light_throughput.svg +0 -0
  44. build/torch-cuda/__init__.py +0 -69
  45. build/torch-cuda/_ops.py +0 -38
  46. build/torch-cuda/bnpo_loss/__init__.py +0 -196
  47. build/torch-cuda/bnpo_loss/_torch_ref.py +0 -56
  48. build/torch-cuda/bnpo_loss/autograd.py +0 -149
  49. build/torch-cuda/bnpo_loss/cute_bnpo_loss.py +0 -1081
  50. build/torch-cuda/geometric_ai_kernels/__init__.py +0 -26
README.md CHANGED
@@ -1,326 +1,3 @@
- ---
- library_name: kernels
- license: apache-2.0
- tags:
- - cuda
- - cutlass
- - cute-dsl
- - rl
- - distillation
- - trl
- - grpo
- - bnpo
- - kl-divergence
- ---
-
- # Geometric-AI Kernels
-
- Fused **CuteDSL** kernels for the loss functions that dominate post-training
- workloads: PPO-family policy losses (BNPO, GRPO) and reverse-KL
- self-distillation.
-
- Each kernel ships a **single-launch fused forward +
- backward** path that returns `(loss, grad_logprobs)` directly. No `torch.autograd.Function` wrapper, no extra `grad_output * dpolicy` backward
- kernel, and no host-side syncs in the hot path.
-
- Background and benchmarks: see the
- [release post](https://geometric.so/blog/2026/05/08/hf-kernel-hub).
-
- - **Backend**: CUDA (NVIDIA CUTLASS DSL).
- - **Min GPU**: SM80 (Ampere) - required by `nvidia-cutlass-dsl`. Tested on H100 (SM90). Should work on SM80 (Ampere), SM86 (RTX 3090, A40), SM89 (RTX 4090, L40S), SM90a (H100 SXM), and SM100 (Blackwell B200/GB200).
- - **Min CUDA**: 12.8.
- - **Dtypes**: `float32`, `float16`, `bfloat16`.
- - **Dynamic shapes**: a single compile handles arbitrary batch size and
- sequence length, no recompiles when shapes change between calls (common
- in post-training rollouts).
-
- ## Kernels
-
- | Kernel family | Direct (no autograd) | Autograd-aware | Forward-only |
- | --- | --- | --- | --- |
- | BNPO loss | `bnpo_loss` | `bnpo_loss_autograd` | `bnpo_loss_fwd` |
- | GRPO loss | `grpo_loss` | `grpo_loss_autograd` | `grpo_loss_fwd` |
- | Reverse KL | `reverse_kl` | `reverse_kl_autograd` | `reverse_kl_fwd` |
-
- ### Entry points
-
- Each kernel family exposes three entry points with the same underlying CuteDSL kernel:
-
- - **`<name>(...)`** - fused fwd+bwd, returns `(loss, grad)` from one `@cute.jit`
- dispatch. Lowest-overhead path; the caller chains the gradient into the upstream
- model with `policy_logprobs.backward(grad)`. Use this in custom training loops
- where you control gradient flow.
- - **`<name>_autograd(...)`** - same kernel, registered via
- `torch.library.custom_op` + `register_autograd`. `loss.backward()` works
- and composes with `torch.compile(fullgraph=True)`. There is a noticeable
- per-call dispatcher overhead vs. the direct path.
- - **`<name>_fwd(...)`** - forward-only, returns scalar `loss` and skips
- the gradient buffer entirely. Use for inference / validation /
- reward-model scoring.
-
- ## Loading the kernels
- ```
- pip install apache-tvm-ffi nvidia-cutlass-dsl
- ```
-
- ```python
- from kernels import get_kernel
-
- km = get_kernel("Geometric-AI/geometric-ai-kernels", version=0)
- ```
-
- ---
-
- ## BNPO Loss
-
- **Batch-Normalized Policy Optimization** sums per-token policy and KL terms
- across the **entire batch** and divides by the global valid-token count:
-
- ```
- loss = ((per_token_loss + β·kl) · mask).sum() / max(mask.sum(), 1)
- ```
-
- where `per_token_loss` is the PPO-clipped ratio loss:
-
- ```
- ratio = exp(policy_logprobs - old_policy_logprobs)
- clipped = clip(ratio, 1−ε, 1+ε_high)
- per_token_loss = −advantages · min(ratio, clipped)
- kl = exp(ref_logprobs − policy_logprobs) − (ref_logprobs − policy_logprobs) − 1
- ```
-
- The global denominator is computed entirely on-GPU via cross-CTA atomics -
- no host-side `mask.sum()` sync. When `beta=0` the KL branch is dead-coded
- at compile time.
-
- **Inputs**:
- - `policy_logprobs`, `old_policy_logprobs`, `ref_logprobs`: `(bs, seq_len)`, fp32/fp16/bf16
- - `advantages`: `(bs,)`
- - `completions_mask`: `(bs, seq_len)`, bool or int8
-
- **Returns**: `(loss, grad_policy_logprobs)` from `bnpo_loss`; scalar `loss` from `bnpo_loss_fwd`.
-
- ```python
- import torch
- from kernels import get_kernel
-
- km = get_kernel("Geometric-AI/geometric-ai-kernels", version=0)
- device = torch.device("cuda")
-
- bs, seq_len = 16, 1024
- policy_logprobs = torch.randn(bs, seq_len, dtype=torch.bfloat16, device=device, requires_grad=True)
- old_policy_logprobs = torch.randn(bs, seq_len, dtype=torch.bfloat16, device=device)
- ref_logprobs = torch.randn(bs, seq_len, dtype=torch.bfloat16, device=device)
- advantages = torch.randn(bs, dtype=torch.bfloat16, device=device)
- completions_mask = (torch.rand(bs, seq_len, device=device) > 0.2).to(torch.int8)
-
- # 1) Direct (loss, grad) - lowest overhead training path
- loss, grad = km.bnpo_loss(
-     policy_logprobs, old_policy_logprobs, ref_logprobs,
-     advantages, completions_mask,
-     epsilon=0.2, epsilon_high=0.2, beta=0.1,
- )
- policy_logprobs.backward(grad)
-
- # 2) Autograd-aware - works with loss.backward() and torch.compile
- loss = km.bnpo_loss_autograd(
-     policy_logprobs.requires_grad_(),
-     old_policy_logprobs, ref_logprobs,
-     advantages, completions_mask,
-     epsilon=0.2, epsilon_high=0.2, beta=0.1,
- )
- loss.backward()
-
- # 3) Forward-only - inference / reward scoring, no gradient buffer
- loss = km.bnpo_loss_fwd(
-     policy_logprobs, old_policy_logprobs, ref_logprobs,
-     advantages, completions_mask,
-     epsilon=0.2, epsilon_high=0.2, beta=0.1,
- )
- ```
-
- ---
-
- ## GRPO Loss
-
- **Group Relative Policy Optimization** implements TRL's default
- **per-response normalization** variant - each response is normalized by its
- own valid-token count before averaging across the batch:
-
- ```
- loss = mean_r( ((per_token_loss + β·kl) · mask).sum(-1) / max(mask.sum(-1), 1) )
- ```
-
- `per_token_loss` and `kl` are the same clipped-ratio and KL expressions as BNPO.
- `completions_mask` is **required** because the per-response denominator is
- mask-derived. The kernel uses one CTA per row so the per-row mask sum is
- reduced inside the block - no cross-CTA atomics on the scaling pass.
-
- **Inputs**:
- - `policy_logprobs`, `old_policy_logprobs`, `ref_logprobs`: `(bs, seq_len)`, fp32/fp16/bf16
- - `advantages`: `(bs,)`
- - `completions_mask`: `(bs, seq_len)`, bool or int8 - **required**
-
- **Returns**: `(loss, grad_policy_logprobs)` from `grpo_loss`; scalar `loss` from `grpo_loss_fwd`.
-
- ```python
- import torch
- from kernels import get_kernel
-
- km = get_kernel("Geometric-AI/geometric-ai-kernels", version=0)
- device = torch.device("cuda")
-
- bs, seq_len = 16, 1024
- policy_logprobs = torch.randn(bs, seq_len, dtype=torch.bfloat16, device=device, requires_grad=True)
- old_policy_logprobs = torch.randn(bs, seq_len, dtype=torch.bfloat16, device=device)
- ref_logprobs = torch.randn(bs, seq_len, dtype=torch.bfloat16, device=device)
- advantages = torch.randn(bs, dtype=torch.bfloat16, device=device)
- completions_mask = (torch.rand(bs, seq_len, device=device) > 0.2).to(torch.int8)
-
- # 1) Direct (loss, grad) - lowest overhead training path
- loss, grad = km.grpo_loss(
-     policy_logprobs, old_policy_logprobs, ref_logprobs,
-     advantages, completions_mask,
-     epsilon=0.2, epsilon_high=0.2, beta=0.1,
- )
- policy_logprobs.backward(grad)
-
- # 2) Autograd-aware - works with loss.backward() and torch.compile
- loss = km.grpo_loss_autograd(
-     policy_logprobs.requires_grad_(),
-     old_policy_logprobs, ref_logprobs,
-     advantages, completions_mask,
-     epsilon=0.2, epsilon_high=0.2, beta=0.1,
- )
- loss.backward()
-
- # 3) Forward-only - inference / reward scoring, no gradient buffer
- loss = km.grpo_loss_fwd(
-     policy_logprobs, old_policy_logprobs, ref_logprobs,
-     advantages, completions_mask,
-     epsilon=0.2, epsilon_high=0.2, beta=0.1,
- )
- ```
-
- ---
-
- ## Reverse KL
-
- **Reverse-KL self-distillation** computes `KL(student ‖ teacher)` over a
- `(num_tokens, vocab)` slab using an online normalization algorithm that reads
- each logit row exactly once on the forward-only path:
-
- ```
- p = softmax(student_logits)
- q = softmax(teacher_logits)
- kl_per_row = Σ_v p_v · (log p_v − log q_v)
- loss = (mask · kl_per_row).sum() / mask.sum()
- ```
-
- The gradient through the softmax Jacobian is analytical:
-
- ```
- grad_student_v = scale · p_v · (log p_v − log q_v − kl_per_row)
- ```
-
- where `scale = mask[r] · inv_n_valid`.
-
- **Inputs**:
- - `student_logits`, `teacher_logits`: `(*, V)` - arbitrary leading dims (typically `(bs, seq_len, vocab)`); both must share shape and dtype
- - `completions_mask`: shape matching `student_logits.shape[:-1]`
-
- > ⚠️ **Fully-masked batches**: `inv_n_valid = 1 / mask.sum()` is not clamped, so a batch where every token is masked produces inf/NaN. Guard upstream if that case is reachable.
-
- **Returns**: `(loss, grad_student_logits)` from `reverse_kl`; scalar `loss` from `reverse_kl_fwd`.
-
- ```python
- import torch
- from kernels import get_kernel
-
- km = get_kernel("Geometric-AI/geometric-ai-kernels", version=0)
- device = torch.device("cuda")
-
- # Qwen3.5-style vocab; arbitrary leading dims supported
- bs, seq_len, vocab = 4, 256, 248320
- student_logits = torch.randn(bs, seq_len, vocab, dtype=torch.bfloat16, device=device, requires_grad=True)
- teacher_logits = torch.randn(bs, seq_len, vocab, dtype=torch.bfloat16, device=device)
- completions_mask = (torch.rand(bs, seq_len, device=device) > 0.2)
-
- # 1) Direct (loss, grad) - lowest overhead training path
- loss, grad = km.reverse_kl(student_logits, teacher_logits, completions_mask)
- student_logits.backward(grad)
-
- # 2) Autograd-aware - works with loss.backward() and torch.compile
- loss = km.reverse_kl_autograd(
-     student_logits.requires_grad_(), teacher_logits, completions_mask
- )
- loss.backward()
-
- # 3) Forward-only - inference / KL monitoring, no gradient buffer
- loss = km.reverse_kl_fwd(student_logits, teacher_logits, completions_mask)
- ```
-
- ---
-
- ## Performance
-
- All numbers are geometric-mean speedups measured on H100 SXM (SM90a). Full methodology
- and per-shape plots are in the [release post](https://geometric.so/blog/2026/05/08/hf-kernel-hub).
-
- ### `kernels` CLI benchmark
-
- Timed with `time.perf_counter` + `cuda.synchronize()`, mean over 100 iterations.
-
- | Kernel | vs eager | vs `torch.compile` |
- | --- | --- | --- |
- | `grpo_loss_fwd` | 5.68× | 2.45× |
- | `grpo_loss` | 20.79× | 1.98× |
- | `bnpo_loss_fwd` | 5.29× | 2.52× |
- | `bnpo_loss` | 16.81× | 2.27× |
- | `reverse_kl_fwd` | 6.88× | 2.45× |
- | `reverse_kl` | 7.03× | 2.61× |
- ---
-
- ## Benchmark animations
-
- ### BNPO Loss vs eager PyTorch
-
- <picture>
- <source media="(prefers-color-scheme: dark)" srcset="benchmark_results/bnpo_loss_eager/bnpo_loss_eager_dark_animation.svg">
- <img width="90%" src="benchmark_results/bnpo_loss_eager/bnpo_loss_eager_light_animation.svg" alt="BNPO loss latency vs eager PyTorch">
- </picture>
-
- ### BNPO Loss vs torch.compile
-
- <picture>
- <source media="(prefers-color-scheme: dark)" srcset="benchmark_results/bnpo_loss_compiled/bnpo_loss_compiled_dark_animation.svg">
- <img width="90%" src="benchmark_results/bnpo_loss_compiled/bnpo_loss_compiled_light_animation.svg" alt="BNPO loss latency vs torch.compile">
- </picture>
-
- ### GRPO Loss vs eager PyTorch
-
- <picture>
- <source media="(prefers-color-scheme: dark)" srcset="benchmark_results/grpo_loss_eager/grpo_loss_eager_dark_animation.svg">
- <img width="90%" src="benchmark_results/grpo_loss_eager/grpo_loss_eager_light_animation.svg" alt="GRPO loss latency vs eager PyTorch">
- </picture>
-
- ### GRPO Loss vs torch.compile
-
- <picture>
- <source media="(prefers-color-scheme: dark)" srcset="benchmark_results/grpo_loss_compiled/grpo_loss_compiled_dark_animation.svg">
- <img width="90%" src="benchmark_results/grpo_loss_compiled/grpo_loss_compiled_light_animation.svg" alt="GRPO loss latency vs torch.compile">
- </picture>
-
- ### Reverse KL vs eager PyTorch
-
- <picture>
- <source media="(prefers-color-scheme: dark)" srcset="benchmark_results/reverse_kl_eager/reverse_kl_eager_dark_animation.svg">
- <img width="90%" src="benchmark_results/reverse_kl_eager/reverse_kl_eager_light_animation.svg" alt="Reverse KL latency vs eager PyTorch">
- </picture>
-
- ### Reverse KL vs torch.compile
-
- <picture>
- <source media="(prefers-color-scheme: dark)" srcset="benchmark_results/reverse_kl_compiled/reverse_kl_compiled_dark_animation.svg">
- <img width="90%" src="benchmark_results/reverse_kl_compiled/reverse_kl_compiled_light_animation.svg" alt="Reverse KL latency vs torch.compile">
- </picture>
+ ---
+ license: apache-2.0
+ ---
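
Reviewer note on the removed README: the BNPO and GRPO sections differ only in how the masked per-token terms are normalized. The following is a minimal pure-Python sketch of that difference, written from the formulas as stated in the removed text. It is not the CuteDSL kernel, and the names `per_token_term`, `bnpo_reduce`, and `grpo_reduce` are hypothetical. The policy term follows the README literally (`−advantages · min(ratio, clipped)`); the usual PPO clipped objective multiplies by the advantage before taking the minimum, which gives a different value when the advantage is negative.

```python
import math


def per_token_term(logp, old_logp, ref_logp, advantage,
                   epsilon=0.2, epsilon_high=0.2, beta=0.1):
    """Per-token policy + beta * KL term, as written in the removed README."""
    ratio = math.exp(logp - old_logp)
    clipped = min(max(ratio, 1.0 - epsilon), 1.0 + epsilon_high)
    policy = -advantage * min(ratio, clipped)  # README form, not the standard PPO min
    delta = ref_logp - logp
    kl = math.exp(delta) - delta - 1.0
    return policy + beta * kl


def bnpo_reduce(terms, mask):
    """Sum over the whole batch, divide by the global valid-token count."""
    total = sum(t * m for row_t, row_m in zip(terms, mask)
                for t, m in zip(row_t, row_m))
    n_valid = sum(m for row_m in mask for m in row_m)
    return total / max(n_valid, 1)


def grpo_reduce(terms, mask):
    """Normalize each response by its own valid-token count, then average rows."""
    per_row = [
        sum(t * m for t, m in zip(row_t, row_m)) / max(sum(row_m), 1)
        for row_t, row_m in zip(terms, mask)
    ]
    return sum(per_row) / len(per_row)


# Two responses: four valid tokens with term 1.0, and one valid token with term 3.0.
terms = [[1.0, 1.0, 1.0, 1.0], [3.0, 0.0, 0.0, 0.0]]
mask = [[1, 1, 1, 1], [1, 0, 0, 0]]
bnpo_value = bnpo_reduce(terms, mask)  # (4 * 1.0 + 3.0) / 5 valid tokens
grpo_value = grpo_reduce(terms, mask)  # mean of the row averages 1.0 and 3.0
```

With responses of unequal length the two reductions disagree: the short response carries one fifth of the weight under the batch-level normalization and one half under the per-response one.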
 
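
Reviewer note on the removed Reverse KL section: the README mentions an online normalization algorithm that reads each logit row once. One way such a single pass could work is sketched below in pure Python, assuming the standard running-max rescaling trick; the actual kernel may be organized differently. `reverse_kl_row_online` and `reverse_kl_row_reference` are hypothetical names.

```python
import math


def reverse_kl_row_online(student, teacher):
    """KL(softmax(student) || softmax(teacher)) for one row, in a single pass."""
    m_s = m_t = float("-inf")  # running maxima
    z_s = z_t = 0.0            # running sums of exp(x - max)
    a = 0.0                    # running sum of exp(s - m_s) * (s - t)
    for s, t in zip(student, teacher):
        new_m_s = max(m_s, s)
        scale = math.exp(m_s - new_m_s)  # rescale old sums to the new max
        w = math.exp(s - new_m_s)
        z_s = z_s * scale + w
        a = a * scale + w * (s - t)
        m_s = new_m_s

        new_m_t = max(m_t, t)
        z_t = z_t * math.exp(m_t - new_m_t) + math.exp(t - new_m_t)
        m_t = new_m_t
    log_z_s = m_s + math.log(z_s)
    log_z_t = m_t + math.log(z_t)
    # sum_v p_v * (s_v - t_v) - logZ_s + logZ_t
    return a / z_s - log_z_s + log_z_t


def reverse_kl_row_reference(student, teacher):
    """Multi-pass reference: explicit log-softmax, then the KL sum."""
    ms, mt = max(student), max(teacher)
    log_z_s = ms + math.log(sum(math.exp(s - ms) for s in student))
    log_z_t = mt + math.log(sum(math.exp(t - mt) for t in teacher))
    return sum(
        math.exp(s - log_z_s) * ((s - log_z_s) - (t - log_z_t))
        for s, t in zip(student, teacher)
    )


student_row = [2.0, -1.0, 0.5, 3.0]
teacher_row = [1.5, 0.0, 0.5, 2.0]
kl_online = reverse_kl_row_online(student_row, teacher_row)
kl_reference = reverse_kl_row_reference(student_row, teacher_row)
```

The identity used is `log p_v − log q_v = (s_v − t_v) − logZ_s + logZ_t`, so only three running quantities per row are needed besides the two maxima.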
benchmark_results/bnpo_loss_compiled/bnpo_loss_compiled_dark_animation.svg DELETED
benchmark_results/bnpo_loss_compiled/bnpo_loss_compiled_dark_latency.svg DELETED
benchmark_results/bnpo_loss_compiled/bnpo_loss_compiled_dark_throughput.svg DELETED
benchmark_results/bnpo_loss_compiled/bnpo_loss_compiled_light_animation.svg DELETED
benchmark_results/bnpo_loss_compiled/bnpo_loss_compiled_light_latency.svg DELETED
benchmark_results/bnpo_loss_compiled/bnpo_loss_compiled_light_throughput.svg DELETED
benchmark_results/bnpo_loss_compiled/results.json DELETED
@@ -1,206 +0,0 @@
- {
- "results": [
- {
- "workload": "bnpoLossBenchmark.bnpo_loss_batch128_seqlen02781_compiled",
- "timingResults": {
- "mean_ms": 0.0359,
- "std_ms": 0.0038,
- "min_ms": 0.0332,
- "max_ms": 0.0701,
- "q1_ms": 0.0344,
- "q3_ms": 0.0357,
- "iqr_ms": 0.0013,
- "outliers": 20,
- "iterations": 200,
- "refMeanMs": 0.0771
- },
- "verified": true
- },
- {
- "workload": "bnpoLossBenchmark.bnpo_loss_batch128_seqlen08192_compiled",
- "timingResults": {
- "mean_ms": 0.0351,
- "std_ms": 0.0033,
- "min_ms": 0.0327,
- "max_ms": 0.0557,
- "q1_ms": 0.0336,
- "q3_ms": 0.035,
- "iqr_ms": 0.0014,
- "outliers": 14,
- "iterations": 200,
- "refMeanMs": 0.0771
- },
- "verified": true
- },
- {
- "workload": "bnpoLossBenchmark.bnpo_loss_batch16_seqlen01024_compiled",
- "timingResults": {
- "mean_ms": 0.0355,
- "std_ms": 0.0042,
- "min_ms": 0.0331,
- "max_ms": 0.0706,
- "q1_ms": 0.034,
- "q3_ms": 0.0351,
- "iqr_ms": 0.0011,
- "outliers": 21,
- "iterations": 200,
- "refMeanMs": 0.0811
- },
- "verified": true
- },
- {
- "workload": "bnpoLossBenchmark.bnpo_loss_batch16_seqlen02781_compiled",
- "timingResults": {
- "mean_ms": 0.0355,
- "std_ms": 0.004,
- "min_ms": 0.0319,
- "max_ms": 0.0591,
- "q1_ms": 0.0338,
- "q3_ms": 0.0352,
- "iqr_ms": 0.0014,
- "outliers": 24,
- "iterations": 200,
- "refMeanMs": 0.0709
- },
- "verified": true
- },
- {
- "workload": "bnpoLossBenchmark.bnpo_loss_batch32_seqlen02048_compiled",
- "timingResults": {
- "mean_ms": 0.0358,
- "std_ms": 0.0042,
- "min_ms": 0.032,
- "max_ms": 0.0569,
- "q1_ms": 0.0338,
- "q3_ms": 0.0355,
- "iqr_ms": 0.0017,
- "outliers": 27,
- "iterations": 200,
- "refMeanMs": 0.0763
- },
- "verified": true
- },
- {
- "workload": "bnpoLossBenchmark.bnpo_loss_batch64_seqlen04096_compiled",
- "timingResults": {
- "mean_ms": 0.0344,
- "std_ms": 0.0031,
- "min_ms": 0.032,
- "max_ms": 0.0557,
- "q1_ms": 0.0331,
- "q3_ms": 0.0341,
- "iqr_ms": 0.001,
- "outliers": 32,
- "iterations": 200,
- "refMeanMs": 0.0739
- },
- "verified": true
- },
- {
- "workload": "bnpoLossBenchmark.bnpo_loss_fwd_batch128_seqlen02781_compiled",
- "timingResults": {
- "mean_ms": 0.0323,
- "std_ms": 0.0034,
- "min_ms": 0.03,
- "max_ms": 0.053,
- "q1_ms": 0.0311,
- "q3_ms": 0.0318,
- "iqr_ms": 0.0007,
- "outliers": 25,
- "iterations": 200,
- "refMeanMs": 0.0808
- },
- "verified": true
- },
- {
- "workload": "bnpoLossBenchmark.bnpo_loss_fwd_batch128_seqlen08192_compiled",
- "timingResults": {
- "mean_ms": 0.0318,
- "std_ms": 0.0032,
- "min_ms": 0.0293,
- "max_ms": 0.0502,
- "q1_ms": 0.0304,
- "q3_ms": 0.0317,
- "iqr_ms": 0.0013,
- "outliers": 17,
- "iterations": 200,
- "refMeanMs": 0.0845
- },
- "verified": true
- },
- {
- "workload": "bnpoLossBenchmark.bnpo_loss_fwd_batch16_seqlen01024_compiled",
- "timingResults": {
- "mean_ms": 0.0317,
- "std_ms": 0.0031,
- "min_ms": 0.0293,
- "max_ms": 0.0593,
- "q1_ms": 0.0304,
- "q3_ms": 0.0317,
- "iqr_ms": 0.0013,
- "outliers": 17,
- "iterations": 200,
- "refMeanMs": 0.079
- },
- "verified": true
- },
- {
- "workload": "bnpoLossBenchmark.bnpo_loss_fwd_batch16_seqlen02781_compiled",
- "timingResults": {
- "mean_ms": 0.0306,
- "std_ms": 0.0035,
- "min_ms": 0.0279,
- "max_ms": 0.0534,
- "q1_ms": 0.0289,
- "q3_ms": 0.0306,
- "iqr_ms": 0.0017,
- "outliers": 20,
- "iterations": 200,
- "refMeanMs": 0.084
- },
- "verified": true
- },
- {
- "workload": "bnpoLossBenchmark.bnpo_loss_fwd_batch32_seqlen02048_compiled",
- "timingResults": {
- "mean_ms": 0.0305,
- "std_ms": 0.0035,
- "min_ms": 0.0279,
- "max_ms": 0.051,
- "q1_ms": 0.0288,
- "q3_ms": 0.0308,
- "iqr_ms": 0.002,
- "outliers": 15,
- "iterations": 200,
- "refMeanMs": 0.0764
- },
- "verified": true
- },
- {
- "workload": "bnpoLossBenchmark.bnpo_loss_fwd_batch64_seqlen04096_compiled",
- "timingResults": {
- "mean_ms": 0.0315,
- "std_ms": 0.0033,
- "min_ms": 0.0293,
- "max_ms": 0.0543,
- "q1_ms": 0.0302,
- "q3_ms": 0.0311,
- "iqr_ms": 0.0009,
- "outliers": 21,
- "iterations": 200,
- "refMeanMs": 0.0739
- },
- "verified": true
- }
- ],
- "machineInfo": {
- "gpu": "NVIDIA H100 80GB HBM3",
- "backend": "CUDA 13.0",
- "pytorchVersion": "2.11.0+cu130",
- "os": "Linux 6.11.0-1016-nvidia",
- "cpu": "x86_64"
- },
- "kernelCommitSha": "7972ab0e834be24d",
- "benchmarkScriptPath": "benchmarks",
- "benchmarkScriptSha": "68426064f76adff2066ad365f6c97be3fe279bd6b20d025b3dc5614f9b2da449"
- }
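
Reviewer note on the deleted results file: the removed README reports geometric-mean speedups, and a figure of that kind could be derived from entries like the ones above roughly as follows. This sketch assumes `refMeanMs` is the mean latency of the reference implementation (here the `torch.compile` baseline) and `mean_ms` is the kernel's mean latency; the JSON itself does not define these fields. `geomean_speedup` is a hypothetical helper, and the result need not match the README table, which may come from a different run.

```python
import math


def geomean_speedup(results, include):
    """Geometric mean of refMeanMs / mean_ms over the selected workloads."""
    ratios = [
        r["timingResults"]["refMeanMs"] / r["timingResults"]["mean_ms"]
        for r in results
        if include(r["workload"])
    ]
    return math.exp(sum(math.log(x) for x in ratios) / len(ratios))


def entry(workload, mean_ms, ref_mean_ms):
    """Build a record with only the fields this sketch reads."""
    return {
        "workload": workload,
        "timingResults": {"mean_ms": mean_ms, "refMeanMs": ref_mean_ms},
    }


# The six fused fwd+bwd rows copied from the deleted bnpo_loss_compiled/results.json.
prefix = "bnpoLossBenchmark.bnpo_loss_"
results = [
    entry(prefix + "batch128_seqlen02781_compiled", 0.0359, 0.0771),
    entry(prefix + "batch128_seqlen08192_compiled", 0.0351, 0.0771),
    entry(prefix + "batch16_seqlen01024_compiled", 0.0355, 0.0811),
    entry(prefix + "batch16_seqlen02781_compiled", 0.0355, 0.0709),
    entry(prefix + "batch32_seqlen02048_compiled", 0.0358, 0.0763),
    entry(prefix + "batch64_seqlen04096_compiled", 0.0344, 0.0739),
]

# Forward-only workloads carry "_fwd_" in their names; none are in this sample.
speedup = geomean_speedup(results, lambda name: "_fwd_" not in name)
print(f"bnpo_loss vs torch.compile, geometric mean: {speedup:.2f}x")
```

A geometric mean is the natural choice for ratios because it treats a 2x speedup on one shape and a 0.5x slowdown on another as cancelling out.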
benchmark_results/bnpo_loss_eager/bnpo_loss_eager_dark_animation.svg DELETED
benchmark_results/bnpo_loss_eager/bnpo_loss_eager_dark_latency.svg DELETED
benchmark_results/bnpo_loss_eager/bnpo_loss_eager_dark_throughput.svg DELETED
benchmark_results/bnpo_loss_eager/bnpo_loss_eager_light_animation.svg DELETED
benchmark_results/bnpo_loss_eager/bnpo_loss_eager_light_latency.svg DELETED
benchmark_results/bnpo_loss_eager/bnpo_loss_eager_light_throughput.svg DELETED
benchmark_results/bnpo_loss_eager/results.json DELETED
@@ -1,206 +0,0 @@
- {
- "results": [
- {
- "workload": "bnpoLossBenchmark.bnpo_loss_batch128_seqlen02781_eager",
- "timingResults": {
- "mean_ms": 0.0358,
- "std_ms": 0.0035,
- "min_ms": 0.0323,
- "max_ms": 0.0536,
- "q1_ms": 0.0342,
- "q3_ms": 0.0358,
- "iqr_ms": 0.0017,
- "outliers": 17,
- "iterations": 200,
- "refMeanMs": 0.5552
- },
- "verified": true
- },
- {
- "workload": "bnpoLossBenchmark.bnpo_loss_batch128_seqlen08192_eager",
- "timingResults": {
- "mean_ms": 0.0344,
- "std_ms": 0.0031,
- "min_ms": 0.0314,
- "max_ms": 0.0537,
- "q1_ms": 0.0329,
- "q3_ms": 0.0345,
- "iqr_ms": 0.0015,
- "outliers": 20,
- "iterations": 200,
- "refMeanMs": 0.6466
- },
- "verified": true
- },
- {
- "workload": "bnpoLossBenchmark.bnpo_loss_batch16_seqlen01024_eager",
- "timingResults": {
- "mean_ms": 0.0345,
- "std_ms": 0.0171,
- "min_ms": 0.0305,
- "max_ms": 0.2718,
- "q1_ms": 0.0319,
- "q3_ms": 0.033,
- "iqr_ms": 0.0011,
- "outliers": 23,
- "iterations": 200,
- "refMeanMs": 0.5868
- },
- "verified": true
- },
- {
- "workload": "bnpoLossBenchmark.bnpo_loss_batch16_seqlen02781_eager",
- "timingResults": {
- "mean_ms": 0.0324,
- "std_ms": 0.0027,
- "min_ms": 0.0301,
- "max_ms": 0.0508,
- "q1_ms": 0.0312,
- "q3_ms": 0.0324,
- "iqr_ms": 0.0012,
- "outliers": 17,
- "iterations": 200,
- "refMeanMs": 0.5832
- },
- "verified": true
- },
- {
- "workload": "bnpoLossBenchmark.bnpo_loss_batch32_seqlen02048_eager",
- "timingResults": {
- "mean_ms": 0.0343,
- "std_ms": 0.0033,
- "min_ms": 0.031,
- "max_ms": 0.0513,
- "q1_ms": 0.0325,
- "q3_ms": 0.0346,
- "iqr_ms": 0.0021,
- "outliers": 19,
- "iterations": 200,
- "refMeanMs": 0.6265
- },
- "verified": true
- },
- {
- "workload": "bnpoLossBenchmark.bnpo_loss_batch64_seqlen04096_eager",
- "timingResults": {
- "mean_ms": 0.0328,
- "std_ms": 0.0029,
- "min_ms": 0.0306,
- "max_ms": 0.0499,
- "q1_ms": 0.0317,
- "q3_ms": 0.0326,
- "iqr_ms": 0.0009,
- "outliers": 20,
- "iterations": 200,
- "refMeanMs": 0.5698
- },
- "verified": true
- },
- {
- "workload": "bnpoLossBenchmark.bnpo_loss_fwd_batch128_seqlen02781_eager",
- "timingResults": {
- "mean_ms": 0.0317,
- "std_ms": 0.0034,
- "min_ms": 0.0285,
- "max_ms": 0.052,
- "q1_ms": 0.0305,
- "q3_ms": 0.0314,
- "iqr_ms": 0.0009,
- "outliers": 22,
- "iterations": 200,
- "refMeanMs": 0.1858
- },
- "verified": true
- },
- {
- "workload": "bnpoLossBenchmark.bnpo_loss_fwd_batch128_seqlen08192_eager",
- "timingResults": {
- "mean_ms": 0.0292,
- "std_ms": 0.0028,
- "min_ms": 0.0273,
- "max_ms": 0.0455,
- "q1_ms": 0.0281,
- "q3_ms": 0.0289,
- "iqr_ms": 0.0008,
- "outliers": 23,
- "iterations": 200,
- "refMeanMs": 0.1633
- },
- "verified": true
- },
- {
- "workload": "bnpoLossBenchmark.bnpo_loss_fwd_batch16_seqlen01024_eager",
- "timingResults": {
- "mean_ms": 0.0311,
- "std_ms": 0.0267,
- "min_ms": 0.0256,
- "max_ms": 0.4049,
- "q1_ms": 0.0276,
- "q3_ms": 0.0295,
- "iqr_ms": 0.0018,
- "outliers": 18,
- "iterations": 200,
- "refMeanMs": 0.1761
- },
- "verified": true
- },
- {
- "workload": "bnpoLossBenchmark.bnpo_loss_fwd_batch16_seqlen02781_eager",
- "timingResults": {
- "mean_ms": 0.0288,
- "std_ms": 0.003,
- "min_ms": 0.027,
- "max_ms": 0.0554,
- "q1_ms": 0.0278,
- "q3_ms": 0.0284,
- "iqr_ms": 0.0006,
- "outliers": 22,
- "iterations": 200,
- "refMeanMs": 0.1755
- },
- "verified": true
- },
- {
- "workload": "bnpoLossBenchmark.bnpo_loss_fwd_batch32_seqlen02048_eager",
- "timingResults": {
- "mean_ms": 0.031,
- "std_ms": 0.0034,
- "min_ms": 0.0281,
- "max_ms": 0.0484,
- "q1_ms": 0.0296,
- "q3_ms": 0.0306,
- "iqr_ms": 0.0009,
- "outliers": 27,
- "iterations": 200,
- "refMeanMs": 0.1533
- },
- "verified": true
- },
- {
- "workload": "bnpoLossBenchmark.bnpo_loss_fwd_batch64_seqlen04096_eager",
- "timingResults": {
- "mean_ms": 0.031,
- "std_ms": 0.0041,
- "min_ms": 0.0286,
- "max_ms": 0.0625,
- "q1_ms": 0.0294,
- "q3_ms": 0.0305,
- "iqr_ms": 0.0011,
- "outliers": 22,
- "iterations": 200,
- "refMeanMs": 0.1678
- },
- "verified": true
- }
- ],
- "machineInfo": {
- "gpu": "NVIDIA H100 80GB HBM3",
- "backend": "CUDA 13.0",
- "pytorchVersion": "2.11.0+cu130",
- "os": "Linux 6.11.0-1016-nvidia",
- "cpu": "x86_64"
- },
- "kernelCommitSha": "84e79b2f3ee3088a",
- "benchmarkScriptPath": "benchmarks",
- "benchmarkScriptSha": "68426064f76adff2066ad365f6c97be3fe279bd6b20d025b3dc5614f9b2da449"
- }
benchmark_results/grpo_loss_compiled/grpo_loss_compiled_dark_animation.svg DELETED
benchmark_results/grpo_loss_compiled/grpo_loss_compiled_dark_latency.svg DELETED
benchmark_results/grpo_loss_compiled/grpo_loss_compiled_dark_throughput.svg DELETED
benchmark_results/grpo_loss_compiled/grpo_loss_compiled_light_animation.svg DELETED
benchmark_results/grpo_loss_compiled/grpo_loss_compiled_light_latency.svg DELETED
benchmark_results/grpo_loss_compiled/grpo_loss_compiled_light_throughput.svg DELETED
benchmark_results/grpo_loss_compiled/results.json DELETED
@@ -1,174 +0,0 @@
- {
- "results": [
- {
- "workload": "GrpoLossBenchmark.grpo_loss_batch128_seqlen02781_compiled",
- "timingResults": {
- "mean_ms": 0.0329,
- "std_ms": 0.0042,
- "min_ms": 0.0301,
- "max_ms": 0.0632,
- "q1_ms": 0.031,
- "q3_ms": 0.0326,
- "iqr_ms": 0.0016,
- "outliers": 22,
- "iterations": 200,
- "refMeanMs": 0.0874
- },
- "verified": true
- },
- {
- "workload": "GrpoLossBenchmark.grpo_loss_batch128_seqlen08192_compiled",
- "timingResults": {
- "mean_ms": 0.0337,
- "std_ms": 0.0045,
- "min_ms": 0.0305,
- "max_ms": 0.065,
- "q1_ms": 0.0318,
- "q3_ms": 0.0333,
- "iqr_ms": 0.0015,
- "outliers": 23,
- "iterations": 200,
- "refMeanMs": 0.0824
- },
- "verified": true
- },
- {
- "workload": "GrpoLossBenchmark.grpo_loss_batch16_seqlen01024_compiled",
- "timingResults": {
- "mean_ms": 0.0323,
- "std_ms": 0.0045,
- "min_ms": 0.0286,
- "max_ms": 0.0621,
- "q1_ms": 0.0306,
- "q3_ms": 0.0321,
- "iqr_ms": 0.0015,
- "outliers": 24,
- "iterations": 200,
- "refMeanMs": 0.0626
- },
- "verified": true
- },
- {
- "workload": "GrpoLossBenchmark.grpo_loss_batch32_seqlen02048_compiled",
- "timingResults": {
- "mean_ms": 0.0324,
- "std_ms": 0.0046,
- "min_ms": 0.0286,
- "max_ms": 0.0688,
- "q1_ms": 0.0305,
- "q3_ms": 0.0321,
- "iqr_ms": 0.0016,
- "outliers": 22,
- "iterations": 200,
- "refMeanMs": 0.0633
- },
- "verified": true
- },
- {
- "workload": "GrpoLossBenchmark.grpo_loss_batch64_seqlen04096_compiled",
- "timingResults": {
- "mean_ms": 0.0349,
- "std_ms": 0.0058,
- "min_ms": 0.0315,
- "max_ms": 0.0814,
- "q1_ms": 0.0325,
- "q3_ms": 0.0341,
- "iqr_ms": 0.0016,
- "outliers": 26,
- "iterations": 200,
- "refMeanMs": 0.0869
- },
- "verified": true
- },
- {
- "workload": "GrpoLossBenchmark.grpo_loss_fwd_batch128_seqlen02781_compiled",
- "timingResults": {
- "mean_ms": 0.033,
- "std_ms": 0.0038,
- "min_ms": 0.0295,
- "max_ms": 0.0543,
- "q1_ms": 0.0313,
- "q3_ms": 0.0333,
- "iqr_ms": 0.0019,
- "outliers": 16,
- "iterations": 200,
- "refMeanMs": 0.0772
- },
- "verified": true
- },
- {
- "workload": "GrpoLossBenchmark.grpo_loss_fwd_batch128_seqlen08192_compiled",
- "timingResults": {
- "mean_ms": 0.0331,
- "std_ms": 0.0032,
- "min_ms": 0.0295,
- "max_ms": 0.0535,
- "q1_ms": 0.0316,
- "q3_ms": 0.0331,
- "iqr_ms": 0.0015,
- "outliers": 19,
- "iterations": 200,
- "refMeanMs": 0.0767
- },
- "verified": true
- },
- {
- "workload": "GrpoLossBenchmark.grpo_loss_fwd_batch16_seqlen01024_compiled",
- "timingResults": {
- "mean_ms": 0.033,
- "std_ms": 0.0032,
- "min_ms": 0.029,
- "max_ms": 0.051,
- "q1_ms": 0.0315,
- "q3_ms": 0.0332,
- "iqr_ms": 0.0016,
- "outliers": 17,
- "iterations": 200,
- "refMeanMs": 0.0845
- },
- "verified": true
- },
- {
- "workload": "GrpoLossBenchmark.grpo_loss_fwd_batch32_seqlen02048_compiled",
- "timingResults": {
- "mean_ms": 0.0339,
- "std_ms": 0.006,
- "min_ms": 0.03,
- "max_ms": 0.0674,
- "q1_ms": 0.0314,
- "q3_ms": 0.0331,
- "iqr_ms": 0.0017,
- "outliers": 23,
- "iterations": 200,
- "refMeanMs": 0.1052
- },
- "verified": true
- },
- {
148
- "workload": "GrpoLossBenchmark.grpo_loss_fwd_batch64_seqlen04096_compiled",
149
- "timingResults": {
150
- "mean_ms": 0.034,
151
- "std_ms": 0.004,
152
- "min_ms": 0.031,
153
- "max_ms": 0.0623,
154
- "q1_ms": 0.0323,
155
- "q3_ms": 0.0339,
156
- "iqr_ms": 0.0016,
157
- "outliers": 20,
158
- "iterations": 200,
159
- "refMeanMs": 0.0796
160
- },
161
- "verified": true
162
- }
163
- ],
164
- "machineInfo": {
165
- "gpu": "NVIDIA H100 80GB HBM3",
166
- "backend": "CUDA 13.0",
167
- "pytorchVersion": "2.11.0+cu130",
168
- "os": "Linux 6.11.0-1016-nvidia",
169
- "cpu": "x86_64"
170
- },
171
- "kernelCommitSha": "ad285d68b8c8c0ff",
172
- "benchmarkScriptPath": "benchmarks",
173
- "benchmarkScriptSha": "ff35d63fbca37cfcbf5c94f067c930adc2bd0043ce6788f286dbad5a4f9b9d4a"
174
- }
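Each record in the deleted results file pairs the kernel timing (`mean_ms`) with a `refMeanMs` field. Assuming `refMeanMs` is the mean latency of the reference implementation for the same workload (the file does not say so explicitly), a per-workload speedup can be read off as the ratio of the two. A minimal sketch over two records copied from the file above; the helper name `speedups` is hypothetical:

```python
import json

# Two records copied from the deleted grpo_loss_compiled/results.json
# (only the two fields used here are kept).
RESULTS_JSON = """
{
  "results": [
    {"workload": "GrpoLossBenchmark.grpo_loss_batch128_seqlen02781_compiled",
     "timingResults": {"mean_ms": 0.0329, "refMeanMs": 0.0874}},
    {"workload": "GrpoLossBenchmark.grpo_loss_batch16_seqlen01024_compiled",
     "timingResults": {"mean_ms": 0.0323, "refMeanMs": 0.0626}}
  ]
}
"""


def speedups(raw: str) -> dict[str, float]:
    """Map workload name -> refMeanMs / mean_ms.

    Assumption: refMeanMs is the reference implementation's mean latency.
    """
    out = {}
    for rec in json.loads(raw)["results"]:
        timing = rec["timingResults"]
        out[rec["workload"]] = timing["refMeanMs"] / timing["mean_ms"]
    return out


for name, ratio in speedups(RESULTS_JSON).items():
    print(f"{name}: {ratio:.2f}x")
```

With means of roughly 0.03 ms per call, these ratios are probably dominated by launch overhead rather than by the amount of work per workload, so they are best read as rough indicators.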
benchmark_results/grpo_loss_eager/grpo_loss_eager_dark_animation.svg DELETED
benchmark_results/grpo_loss_eager/grpo_loss_eager_dark_latency.svg DELETED
benchmark_results/grpo_loss_eager/grpo_loss_eager_dark_throughput.svg DELETED
benchmark_results/grpo_loss_eager/grpo_loss_eager_light_animation.svg DELETED
benchmark_results/grpo_loss_eager/grpo_loss_eager_light_latency.svg DELETED
benchmark_results/grpo_loss_eager/grpo_loss_eager_light_throughput.svg DELETED
benchmark_results/grpo_loss_eager/results.json DELETED
@@ -1,174 +0,0 @@
1
- {
2
- "results": [
3
- {
4
- "workload": "GrpoLossBenchmark.grpo_loss_batch128_seqlen02781_eager",
5
- "timingResults": {
6
- "mean_ms": 0.0313,
7
- "std_ms": 0.0029,
8
- "min_ms": 0.0281,
9
- "max_ms": 0.0482,
10
- "q1_ms": 0.03,
11
- "q3_ms": 0.0314,
12
- "iqr_ms": 0.0013,
13
- "outliers": 16,
14
- "iterations": 200,
15
- "refMeanMs": 0.6643
16
- },
17
- "verified": true
18
- },
19
- {
20
- "workload": "GrpoLossBenchmark.grpo_loss_batch128_seqlen08192_eager",
21
- "timingResults": {
22
- "mean_ms": 0.0309,
23
- "std_ms": 0.0031,
24
- "min_ms": 0.0285,
25
- "max_ms": 0.0477,
26
- "q1_ms": 0.0298,
27
- "q3_ms": 0.0306,
28
- "iqr_ms": 0.0008,
29
- "outliers": 19,
30
- "iterations": 200,
31
- "refMeanMs": 0.5961
32
- },
33
- "verified": true
34
- },
35
- {
36
- "workload": "GrpoLossBenchmark.grpo_loss_batch16_seqlen01024_eager",
37
- "timingResults": {
38
- "mean_ms": 0.0315,
39
- "std_ms": 0.0033,
40
- "min_ms": 0.0293,
41
- "max_ms": 0.0507,
42
- "q1_ms": 0.0302,
43
- "q3_ms": 0.0311,
44
- "iqr_ms": 0.0009,
45
- "outliers": 23,
46
- "iterations": 200,
47
- "refMeanMs": 0.6132
48
- },
49
- "verified": true
50
- },
51
- {
52
- "workload": "GrpoLossBenchmark.grpo_loss_batch32_seqlen02048_eager",
53
- "timingResults": {
54
- "mean_ms": 0.0302,
55
- "std_ms": 0.0029,
56
- "min_ms": 0.028,
57
- "max_ms": 0.0467,
58
- "q1_ms": 0.029,
59
- "q3_ms": 0.0299,
60
- "iqr_ms": 0.0008,
61
- "outliers": 20,
62
- "iterations": 200,
63
- "refMeanMs": 0.6043
64
- },
65
- "verified": true
66
- },
67
- {
68
- "workload": "GrpoLossBenchmark.grpo_loss_batch64_seqlen04096_eager",
69
- "timingResults": {
70
- "mean_ms": 0.0295,
71
- "std_ms": 0.003,
72
- "min_ms": 0.0268,
73
- "max_ms": 0.0465,
74
- "q1_ms": 0.0279,
75
- "q3_ms": 0.03,
76
- "iqr_ms": 0.002,
77
- "outliers": 12,
78
- "iterations": 200,
79
- "refMeanMs": 0.5798
80
- },
81
- "verified": true
82
- },
83
- {
84
- "workload": "GrpoLossBenchmark.grpo_loss_fwd_batch128_seqlen02781_eager",
85
- "timingResults": {
86
- "mean_ms": 0.0306,
87
- "std_ms": 0.0032,
88
- "min_ms": 0.0281,
89
- "max_ms": 0.0513,
90
- "q1_ms": 0.0293,
91
- "q3_ms": 0.0302,
92
- "iqr_ms": 0.0009,
93
- "outliers": 24,
94
- "iterations": 200,
95
- "refMeanMs": 0.1716
96
- },
97
- "verified": true
98
- },
99
- {
100
- "workload": "GrpoLossBenchmark.grpo_loss_fwd_batch128_seqlen08192_eager",
101
- "timingResults": {
102
- "mean_ms": 0.0302,
103
- "std_ms": 0.0031,
104
- "min_ms": 0.0284,
105
- "max_ms": 0.0594,
106
- "q1_ms": 0.0291,
107
- "q3_ms": 0.0299,
108
- "iqr_ms": 0.0008,
109
- "outliers": 21,
110
- "iterations": 200,
111
- "refMeanMs": 0.1701
112
- },
113
- "verified": true
114
- },
115
- {
116
- "workload": "GrpoLossBenchmark.grpo_loss_fwd_batch16_seqlen01024_eager",
117
- "timingResults": {
118
- "mean_ms": 0.0306,
119
- "std_ms": 0.0027,
120
- "min_ms": 0.0286,
121
- "max_ms": 0.0455,
122
- "q1_ms": 0.0294,
123
- "q3_ms": 0.0304,
124
- "iqr_ms": 0.001,
125
- "outliers": 16,
126
- "iterations": 200,
127
- "refMeanMs": 0.1741
128
- },
129
- "verified": true
130
- },
131
- {
132
- "workload": "GrpoLossBenchmark.grpo_loss_fwd_batch32_seqlen02048_eager",
133
- "timingResults": {
134
- "mean_ms": 0.0299,
135
- "std_ms": 0.0029,
136
- "min_ms": 0.0269,
137
- "max_ms": 0.0488,
138
- "q1_ms": 0.0287,
139
- "q3_ms": 0.0301,
140
- "iqr_ms": 0.0015,
141
- "outliers": 14,
142
- "iterations": 200,
143
- "refMeanMs": 0.1647
144
- },
145
- "verified": true
146
- },
147
- {
148
- "workload": "GrpoLossBenchmark.grpo_loss_fwd_batch64_seqlen04096_eager",
149
- "timingResults": {
150
- "mean_ms": 0.0314,
151
- "std_ms": 0.0028,
152
- "min_ms": 0.0289,
153
- "max_ms": 0.0465,
154
- "q1_ms": 0.0301,
155
- "q3_ms": 0.0312,
156
- "iqr_ms": 0.0011,
157
- "outliers": 22,
158
- "iterations": 200,
159
- "refMeanMs": 0.1751
160
- },
161
- "verified": true
162
- }
163
- ],
164
- "machineInfo": {
165
- "gpu": "NVIDIA H100 80GB HBM3",
166
- "backend": "CUDA 13.0",
167
- "pytorchVersion": "2.11.0+cu130",
168
- "os": "Linux 6.11.0-1016-nvidia",
169
- "cpu": "x86_64"
170
- },
171
- "kernelCommitSha": "87ec9b61421d0121",
172
- "benchmarkScriptPath": "benchmarks",
173
- "benchmarkScriptSha": "ff35d63fbca37cfcbf5c94f067c930adc2bd0043ce6788f286dbad5a4f9b9d4a"
174
- }
benchmark_results/reverse_kl_compiled/results.json DELETED
@@ -1,206 +0,0 @@
1
- {
2
- "results": [
3
- {
4
- "workload": "ReverseKLBenchmark.reverse_kl_batch01_seqlen064_vocab248320_compiled",
5
- "timingResults": {
6
- "mean_ms": 0.1039,
7
- "std_ms": 0.0035,
8
- "min_ms": 0.1,
9
- "max_ms": 0.1229,
10
- "q1_ms": 0.1018,
11
- "q3_ms": 0.104,
12
- "iqr_ms": 0.0022,
13
- "outliers": 28,
14
- "iterations": 200,
15
- "refMeanMs": 0.2322
16
- },
17
- "verified": true
18
- },
19
- {
20
- "workload": "ReverseKLBenchmark.reverse_kl_batch02_seqlen128_vocab248320_compiled",
21
- "timingResults": {
22
- "mean_ms": 0.2483,
23
- "std_ms": 0.0035,
24
- "min_ms": 0.2418,
25
- "max_ms": 0.2612,
26
- "q1_ms": 0.2457,
27
- "q3_ms": 0.2513,
28
- "iqr_ms": 0.0057,
29
- "outliers": 2,
30
- "iterations": 200,
31
- "refMeanMs": 0.6455
32
- },
33
- "verified": true
34
- },
35
- {
36
- "workload": "ReverseKLBenchmark.reverse_kl_batch04_seqlen256_vocab248320_compiled",
37
- "timingResults": {
38
- "mean_ms": 0.8322,
39
- "std_ms": 0.0044,
40
- "min_ms": 0.8232,
41
- "max_ms": 0.8623,
42
- "q1_ms": 0.8303,
43
- "q3_ms": 0.8335,
44
- "iqr_ms": 0.0032,
45
- "outliers": 18,
46
- "iterations": 200,
47
- "refMeanMs": 2.2082
48
- },
49
- "verified": true
50
- },
51
- {
52
- "workload": "ReverseKLBenchmark.reverse_kl_batch08_seqlen1024_vocab248320_compiled",
53
- "timingResults": {
54
- "mean_ms": 6.1083,
55
- "std_ms": 0.0054,
56
- "min_ms": 6.097,
57
- "max_ms": 6.1513,
58
- "q1_ms": 6.1054,
59
- "q3_ms": 6.11,
60
- "iqr_ms": 0.0046,
61
- "outliers": 13,
62
- "iterations": 200,
63
- "refMeanMs": 16.4779
64
- },
65
- "verified": true
66
- },
67
- {
68
- "workload": "ReverseKLBenchmark.reverse_kl_batch08_seqlen512_vocab248320_compiled",
69
- "timingResults": {
70
- "mean_ms": 3.0861,
71
- "std_ms": 0.0045,
72
- "min_ms": 3.0769,
73
- "max_ms": 3.1155,
74
- "q1_ms": 3.0832,
75
- "q3_ms": 3.0883,
76
- "iqr_ms": 0.0051,
77
- "outliers": 5,
78
- "iterations": 200,
79
- "refMeanMs": 8.3849
80
- },
81
- "verified": true
82
- },
83
- {
84
- "workload": "ReverseKLBenchmark.reverse_kl_batch08_seqlen981_vocab248320_compiled",
85
- "timingResults": {
86
- "mean_ms": 5.8622,
87
- "std_ms": 0.0044,
88
- "min_ms": 5.8544,
89
- "max_ms": 5.8821,
90
- "q1_ms": 5.859,
91
- "q3_ms": 5.8646,
92
- "iqr_ms": 0.0056,
93
- "outliers": 6,
94
- "iterations": 200,
95
- "refMeanMs": 15.8101
96
- },
97
- "verified": true
98
- },
99
- {
100
- "workload": "ReverseKLBenchmark.reverse_kl_fwd_batch01_seqlen064_vocab248320_compiled",
101
- "timingResults": {
102
- "mean_ms": 0.0657,
103
- "std_ms": 0.0041,
104
- "min_ms": 0.0619,
105
- "max_ms": 0.093,
106
- "q1_ms": 0.0635,
107
- "q3_ms": 0.0656,
108
- "iqr_ms": 0.0021,
109
- "outliers": 24,
110
- "iterations": 200,
111
- "refMeanMs": 0.1434
112
- },
113
- "verified": true
114
- },
115
- {
116
- "workload": "ReverseKLBenchmark.reverse_kl_fwd_batch02_seqlen128_vocab248320_compiled",
117
- "timingResults": {
118
- "mean_ms": 0.1234,
119
- "std_ms": 0.0041,
120
- "min_ms": 0.1187,
121
- "max_ms": 0.1464,
122
- "q1_ms": 0.1208,
123
- "q3_ms": 0.1244,
124
- "iqr_ms": 0.0036,
125
- "outliers": 16,
126
- "iterations": 200,
127
- "refMeanMs": 0.3277
128
- },
129
- "verified": true
130
- },
131
- {
132
- "workload": "ReverseKLBenchmark.reverse_kl_fwd_batch04_seqlen256_vocab248320_compiled",
133
- "timingResults": {
134
- "mean_ms": 0.3764,
135
- "std_ms": 0.0037,
136
- "min_ms": 0.3699,
137
- "max_ms": 0.3926,
138
- "q1_ms": 0.3733,
139
- "q3_ms": 0.3787,
140
- "iqr_ms": 0.0054,
141
- "outliers": 2,
142
- "iterations": 200,
143
- "refMeanMs": 0.9228
144
- },
145
- "verified": true
146
- },
147
- {
148
- "workload": "ReverseKLBenchmark.reverse_kl_fwd_batch08_seqlen1024_vocab248320_compiled",
149
- "timingResults": {
150
- "mean_ms": 2.658,
151
- "std_ms": 0.0089,
152
- "min_ms": 2.6359,
153
- "max_ms": 2.6859,
154
- "q1_ms": 2.6524,
155
- "q3_ms": 2.663,
156
- "iqr_ms": 0.0106,
157
- "outliers": 4,
158
- "iterations": 200,
159
- "refMeanMs": 6.6033
160
- },
161
- "verified": true
162
- },
163
- {
164
- "workload": "ReverseKLBenchmark.reverse_kl_fwd_batch08_seqlen512_vocab248320_compiled",
165
- "timingResults": {
166
- "mean_ms": 1.38,
167
- "std_ms": 0.0035,
168
- "min_ms": 1.37,
169
- "max_ms": 1.3924,
170
- "q1_ms": 1.3776,
171
- "q3_ms": 1.3818,
172
- "iqr_ms": 0.0042,
173
- "outliers": 6,
174
- "iterations": 200,
175
- "refMeanMs": 3.3854
176
- },
177
- "verified": true
178
- },
179
- {
180
- "workload": "ReverseKLBenchmark.reverse_kl_fwd_batch08_seqlen981_vocab248320_compiled",
181
- "timingResults": {
182
- "mean_ms": 2.5422,
183
- "std_ms": 0.0091,
184
- "min_ms": 2.5286,
185
- "max_ms": 2.5773,
186
- "q1_ms": 2.5356,
187
- "q3_ms": 2.5455,
188
- "iqr_ms": 0.0099,
189
- "outliers": 9,
190
- "iterations": 200,
191
- "refMeanMs": 6.2191
192
- },
193
- "verified": true
194
- }
195
- ],
196
- "machineInfo": {
197
- "gpu": "NVIDIA H100 80GB HBM3",
198
- "backend": "CUDA 13.0",
199
- "pytorchVersion": "2.11.0+cu130",
200
- "os": "Linux 6.11.0-1016-nvidia",
201
- "cpu": "x86_64"
202
- },
203
- "kernelCommitSha": "ca5cbc20b4d2c7d8",
204
- "benchmarkScriptPath": "benchmarks",
205
- "benchmarkScriptSha": "690eea1f54f31bef1ad248380201005fd667d4b9c535f92f06eb6a5a33380d22"
206
- }
benchmark_results/reverse_kl_compiled/reverse_kl_compiled_dark_animation.svg DELETED
benchmark_results/reverse_kl_compiled/reverse_kl_compiled_dark_latency.svg DELETED
benchmark_results/reverse_kl_compiled/reverse_kl_compiled_dark_throughput.svg DELETED
benchmark_results/reverse_kl_compiled/reverse_kl_compiled_light_animation.svg DELETED
benchmark_results/reverse_kl_compiled/reverse_kl_compiled_light_latency.svg DELETED
benchmark_results/reverse_kl_compiled/reverse_kl_compiled_light_throughput.svg DELETED
benchmark_results/reverse_kl_eager/results.json DELETED
@@ -1,206 +0,0 @@
1
- {
2
- "results": [
3
- {
4
- "workload": "ReverseKLBenchmark.reverse_kl_batch01_seqlen064_vocab248320_eager",
5
- "timingResults": {
6
- "mean_ms": 0.1029,
7
- "std_ms": 0.0032,
8
- "min_ms": 0.0982,
9
- "max_ms": 0.1129,
10
- "q1_ms": 0.101,
11
- "q3_ms": 0.1036,
12
- "iqr_ms": 0.0026,
13
- "outliers": 27,
14
- "iterations": 200,
15
- "refMeanMs": 0.5293
16
- },
17
- "verified": true
18
- },
19
- {
20
- "workload": "ReverseKLBenchmark.reverse_kl_batch02_seqlen128_vocab248320_eager",
21
- "timingResults": {
22
- "mean_ms": 0.248,
23
- "std_ms": 0.0037,
24
- "min_ms": 0.2417,
25
- "max_ms": 0.2592,
26
- "q1_ms": 0.2451,
27
- "q3_ms": 0.251,
28
- "iqr_ms": 0.0058,
29
- "outliers": 0,
30
- "iterations": 200,
31
- "refMeanMs": 1.624
32
- },
33
- "verified": true
34
- },
35
- {
36
- "workload": "ReverseKLBenchmark.reverse_kl_batch04_seqlen256_vocab248320_eager",
37
- "timingResults": {
38
- "mean_ms": 0.8321,
39
- "std_ms": 0.0035,
40
- "min_ms": 0.8234,
41
- "max_ms": 0.854,
42
- "q1_ms": 0.8306,
43
- "q3_ms": 0.8335,
44
- "iqr_ms": 0.003,
45
- "outliers": 20,
46
- "iterations": 200,
47
- "refMeanMs": 6.174
48
- },
49
- "verified": true
50
- },
51
- {
52
- "workload": "ReverseKLBenchmark.reverse_kl_batch08_seqlen1024_vocab248320_eager",
53
- "timingResults": {
54
- "mean_ms": 6.1046,
55
- "std_ms": 0.0041,
56
- "min_ms": 6.0961,
57
- "max_ms": 6.1376,
58
- "q1_ms": 6.1023,
59
- "q3_ms": 6.106,
60
- "iqr_ms": 0.0037,
61
- "outliers": 9,
62
- "iterations": 200,
63
- "refMeanMs": 48.4051
64
- },
65
- "verified": true
66
- },
67
- {
68
- "workload": "ReverseKLBenchmark.reverse_kl_batch08_seqlen512_vocab248320_eager",
69
- "timingResults": {
70
- "mean_ms": 3.0816,
71
- "std_ms": 0.0035,
72
- "min_ms": 3.0743,
73
- "max_ms": 3.0939,
74
- "q1_ms": 3.0794,
75
- "q3_ms": 3.0832,
76
- "iqr_ms": 0.0038,
77
- "outliers": 8,
78
- "iterations": 200,
79
- "refMeanMs": 24.3385
80
- },
81
- "verified": true
82
- },
83
- {
84
- "workload": "ReverseKLBenchmark.reverse_kl_batch08_seqlen981_vocab248320_eager",
85
- "timingResults": {
86
- "mean_ms": 5.8549,
87
- "std_ms": 0.0045,
88
- "min_ms": 5.8459,
89
- "max_ms": 5.8819,
90
- "q1_ms": 5.8524,
91
- "q3_ms": 5.8561,
92
- "iqr_ms": 0.0037,
93
- "outliers": 14,
94
- "iterations": 200,
95
- "refMeanMs": 46.4274
96
- },
97
- "verified": true
98
- },
99
- {
100
- "workload": "ReverseKLBenchmark.reverse_kl_fwd_batch01_seqlen064_vocab248320_eager",
101
- "timingResults": {
102
- "mean_ms": 0.0638,
103
- "std_ms": 0.0027,
104
- "min_ms": 0.0604,
105
- "max_ms": 0.0787,
106
- "q1_ms": 0.0624,
107
- "q3_ms": 0.064,
108
- "iqr_ms": 0.0016,
109
- "outliers": 20,
110
- "iterations": 200,
111
- "refMeanMs": 0.2532
112
- },
113
- "verified": true
114
- },
115
- {
116
- "workload": "ReverseKLBenchmark.reverse_kl_fwd_batch02_seqlen128_vocab248320_eager",
117
- "timingResults": {
118
- "mean_ms": 0.1217,
119
- "std_ms": 0.0038,
120
- "min_ms": 0.1166,
121
- "max_ms": 0.1428,
122
- "q1_ms": 0.1193,
123
- "q3_ms": 0.1227,
124
- "iqr_ms": 0.0034,
125
- "outliers": 19,
126
- "iterations": 200,
127
- "refMeanMs": 0.7671
128
- },
129
- "verified": true
130
- },
131
- {
132
- "workload": "ReverseKLBenchmark.reverse_kl_fwd_batch04_seqlen256_vocab248320_eager",
133
- "timingResults": {
134
- "mean_ms": 0.3753,
135
- "std_ms": 0.0033,
136
- "min_ms": 0.3695,
137
- "max_ms": 0.3843,
138
- "q1_ms": 0.3726,
139
- "q3_ms": 0.3779,
140
- "iqr_ms": 0.0053,
141
- "outliers": 0,
142
- "iterations": 200,
143
- "refMeanMs": 2.869
144
- },
145
- "verified": true
146
- },
147
- {
148
- "workload": "ReverseKLBenchmark.reverse_kl_fwd_batch08_seqlen1024_vocab248320_eager",
149
- "timingResults": {
150
- "mean_ms": 2.6484,
151
- "std_ms": 0.0065,
152
- "min_ms": 2.6364,
153
- "max_ms": 2.7044,
154
- "q1_ms": 2.6449,
155
- "q3_ms": 2.6515,
156
- "iqr_ms": 0.0067,
157
- "outliers": 3,
158
- "iterations": 200,
159
- "refMeanMs": 22.3336
160
- },
161
- "verified": true
162
- },
163
- {
164
- "workload": "ReverseKLBenchmark.reverse_kl_fwd_batch08_seqlen512_vocab248320_eager",
165
- "timingResults": {
166
- "mean_ms": 1.365,
167
- "std_ms": 0.0046,
168
- "min_ms": 1.3548,
169
- "max_ms": 1.3865,
170
- "q1_ms": 1.3618,
171
- "q3_ms": 1.3675,
172
- "iqr_ms": 0.0057,
173
- "outliers": 4,
174
- "iterations": 200,
175
- "refMeanMs": 11.2401
176
- },
177
- "verified": true
178
- },
179
- {
180
- "workload": "ReverseKLBenchmark.reverse_kl_fwd_batch08_seqlen981_vocab248320_eager",
181
- "timingResults": {
182
- "mean_ms": 2.5316,
183
- "std_ms": 0.0059,
184
- "min_ms": 2.5203,
185
- "max_ms": 2.5523,
186
- "q1_ms": 2.5272,
187
- "q3_ms": 2.5355,
188
- "iqr_ms": 0.0083,
189
- "outliers": 3,
190
- "iterations": 200,
191
- "refMeanMs": 21.4099
192
- },
193
- "verified": true
194
- }
195
- ],
196
- "machineInfo": {
197
- "gpu": "NVIDIA H100 80GB HBM3",
198
- "backend": "CUDA 13.0",
199
- "pytorchVersion": "2.11.0+cu130",
200
- "os": "Linux 6.11.0-1016-nvidia",
201
- "cpu": "x86_64"
202
- },
203
- "kernelCommitSha": "3e023eb5121761b8",
204
- "benchmarkScriptPath": "benchmarks",
205
- "benchmarkScriptSha": "690eea1f54f31bef1ad248380201005fd667d4b9c535f92f06eb6a5a33380d22"
206
- }
benchmark_results/reverse_kl_eager/reverse_kl_eager_dark_animation.svg DELETED
benchmark_results/reverse_kl_eager/reverse_kl_eager_dark_latency.svg DELETED
benchmark_results/reverse_kl_eager/reverse_kl_eager_dark_throughput.svg DELETED
benchmark_results/reverse_kl_eager/reverse_kl_eager_light_animation.svg DELETED
benchmark_results/reverse_kl_eager/reverse_kl_eager_light_latency.svg DELETED
benchmark_results/reverse_kl_eager/reverse_kl_eager_light_throughput.svg DELETED
build/torch-cuda/__init__.py DELETED
@@ -1,69 +0,0 @@
- """Geometric-AI CuteDSL kernels for RL / distillation training.
-
- Public surface:
- * ``bnpo_loss`` / ``bnpo_loss_autograd`` / ``bnpo_loss_fwd`` —
-   fused fwd+bwd BNPO loss with three entry points (direct
-   ``(loss, grad)``, autograd-wrapped, forward-only).
- * ``grpo_loss`` / ``grpo_loss_autograd`` / ``grpo_loss_fwd`` —
-   fused fwd+bwd GRPO loss (TRL's per-response normalization
-   variant). Same three-entry-point shape as BNPO. Requires
-   ``completions_mask``.
- * ``reverse_kl`` / ``reverse_kl_autograd`` /
-   ``reverse_kl_fwd`` — fused fwd+bwd reverse-KL
-   self-distillation loss with the same three-entry-point shape.
-
- HF Kernels integration: :mod:`geometric_ai_kernels.layers` exposes
- ``nn.Module`` adapters per kernel (``bnpoLoss`` / ``bnpoLossInference``,
- ``grpoLoss`` / ``grpoLossInference``, ``ReverseKL`` /
- ``ReverseKLInference``) for use with the ``kernels``
- library's ``kernelize()`` flow.
- """
-
- from __future__ import annotations
-
- import torch._dynamo
-
- from .bnpo_loss import bnpo_loss, bnpo_loss_autograd, bnpo_loss_fwd
- from .grpo_loss import grpo_loss, grpo_loss_autograd, grpo_loss_fwd
- from .layers import (
-     ReverseKL,
-     ReverseKLInference,
-     bnpoLoss,
-     bnpoLossInference,
-     grpoLoss,
-     grpoLossInference,
- )
- from .reverse_kl import (
-     reverse_kl,
-     reverse_kl_autograd,
-     reverse_kl_fwd,
- )
-
- # Required so ``torch.compile(fullgraph=True)`` can trace through
- # ``torch.autograd.grad`` calls — without it Dynamo graph-breaks at the
- # autograd.grad call site even when AOTAutograd has already derived the
- # joint fwd+bwd graph. Set at package import so any consumer (benches,
- # user training loops, ``kernelize`` flows) gets it for free. Guarded
- # because ``trace_autograd_ops`` was added in torch 2.10 and the
- # Nix-pinned build environment may be on an older torch (2.9 today);
- # the underlying ``Config.__setattr__`` raises on unknown keys.
- if hasattr(torch._dynamo.config, "trace_autograd_ops"):
-     torch._dynamo.config.trace_autograd_ops = True  # ty: ignore[invalid-assignment]
-
- __all__ = [
-     "ReverseKL",
-     "ReverseKLInference",
-     "bnpoLoss",
-     "bnpoLossInference",
-     "bnpo_loss",
-     "bnpo_loss_autograd",
-     "bnpo_loss_fwd",
-     "grpoLoss",
-     "grpoLossInference",
-     "grpo_loss",
-     "grpo_loss_autograd",
-     "grpo_loss_fwd",
-     "reverse_kl",
-     "reverse_kl_autograd",
-     "reverse_kl_fwd",
- ]
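The comment in the deleted `__init__.py` explains that the flag is set behind a `hasattr` check because the config object raises on unknown keys. The pattern can be sketched without torch; `StrictConfig`, `NewerConfig`, and `enable_if_supported` below are hypothetical stand-ins for illustration, not torch APIs:

```python
class StrictConfig:
    """Hypothetical config object that rejects assignment to unknown keys."""

    _known = {"existing_flag"}

    def __init__(self) -> None:
        object.__setattr__(self, "existing_flag", False)

    def __setattr__(self, name: str, value: object) -> None:
        if name not in self._known:
            raise AttributeError(f"unknown config key: {name}")
        object.__setattr__(self, name, value)


class NewerConfig(StrictConfig):
    """Models a newer release that also knows ``trace_autograd_ops``."""

    _known = {"existing_flag", "trace_autograd_ops"}

    def __init__(self) -> None:
        super().__init__()
        object.__setattr__(self, "trace_autograd_ops", False)


def enable_if_supported(config: object, key: str) -> bool:
    """Set ``config.<key> = True`` only if the key exists; report whether it did."""
    if hasattr(config, key):
        setattr(config, key, True)
        return True
    return False


print(enable_if_supported(StrictConfig(), "trace_autograd_ops"))
print(enable_if_supported(NewerConfig(), "trace_autograd_ops"))
```

On the older stand-in the guard skips the assignment instead of raising, which is the behaviour the deleted comment describes for the Nix-pinned build environment.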
build/torch-cuda/_ops.py DELETED
@@ -1,38 +0,0 @@
- import torch
-
- def get_backend() -> str:
-     """Detect the backend by inspecting torch."""
-     import torch
-
-     if hasattr(torch, "neuron"):
-         # Needs to be sorted before specific Torch builds, since Neuron
-         # extension can be loaded into e.g. CUDA Torch builds.
-         return "neuron"
-     elif torch.version.cuda is not None:
-         return "cuda"
-     elif torch.version.hip is not None:
-         return "rocm"
-     elif torch.backends.mps.is_available():
-         return "metal"
-     elif hasattr(torch.version, "xpu") and torch.version.xpu is not None:
-         return "xpu"
-     else:
-         return "cpu"
-
-
- def _find_ops_name() -> str:
-     kernel_name = "geometric_ai_kernels"
-     unique_id = "a766fbd_dirty"
-     backend = get_backend()
-     return f"_{kernel_name}_{backend}_{unique_id}"
-
-
- _OPS_NAME = _find_ops_name()
-
- ops = getattr(torch.ops, _OPS_NAME)
-
- def add_op_namespace_prefix(op_name: str) -> str:
-     """
-     Prefix op by namespace.
-     """
-     return f"{_OPS_NAME}::{op_name}"
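How the deleted `_ops.py` assembles its op namespace can be shown without importing torch. The sketch below passes the backend in as a plain argument instead of detecting it; `ops_namespace`, `qualified_op`, and the op name `example_op` are hypothetical, while the kernel name and build id are the values from the file above:

```python
def ops_namespace(kernel_name: str, backend: str, unique_id: str) -> str:
    """Build the per-build namespace string, mirroring ``_find_ops_name``."""
    return f"_{kernel_name}_{backend}_{unique_id}"


def qualified_op(namespace: str, op_name: str) -> str:
    """Prefix an op name with its namespace, mirroring ``add_op_namespace_prefix``."""
    return f"{namespace}::{op_name}"


# Backend is hard-coded here; the deleted module derives it from the torch build.
ns = ops_namespace("geometric_ai_kernels", "cuda", "a766fbd_dirty")
print(ns)                              # _geometric_ai_kernels_cuda_a766fbd_dirty
print(qualified_op(ns, "example_op"))  # _geometric_ai_kernels_cuda_a766fbd_dirty::example_op
```

Because the build id is part of the namespace, two builds of the same kernel presumably register under different names and do not collide when loaded into one process.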
build/torch-cuda/bnpo_loss/__init__.py DELETED
@@ -1,196 +0,0 @@
- """bnpo loss with CuteDSL fused fwd+bwd.
-
- Two public APIs route to two compiled kernels:
-
- * :func:`bnpo_loss` — primary training entry point. Returns
-   ``(loss, grad_policy_logprobs)`` from a single fused fwd+bwd kernel
-   launch. Inputs do **not** need ``requires_grad=True`` and there is no
-   ``torch.autograd.Function`` wrapper — chain the gradient into the
-   upstream model with ``policy_logprobs.backward(grad)`` (or, more
-   commonly, by passing ``grad`` to whatever step does the next leg of
-   backprop).
- * :func:`bnpo_loss_fwd` — inference / validation path. Returns the
-   scalar ``loss`` from a forward-only kernel that computes the masked
-   mean denominator on-GPU via a last-block trick (no host
-   ``completions_mask.sum()``).
-
- The two share the same compiled-kernel cache; per-call output and
- gradient buffers are allocated inside the runner, and cross-CTA scratch
- (atomic accumulators + counters) is owned by the compiled-kernel
- closure and self-resets each launch — callers don't manage scratch.
-
- Why no autograd wrapper here? bnpo's gradient is closed-form — the
- kernel already writes ``dL/d(policy_logprobs)`` in the same launch as
- the loss. Wrapping in ``torch.autograd.Function`` would cost an extra
- ``grad_output * dpolicy`` kernel launch on backward (typically a
- no-op multiply by ``1.0``), plus per-call autograd graph bookkeeping.
- The autograd-aware sibling :func:`bnpo_loss_autograd` uses
- ``torch.library.custom_op`` instead, which composes with
- ``torch.compile``.
- """
-
- from __future__ import annotations
-
- from functools import lru_cache
- from typing import TYPE_CHECKING, cast
-
- import torch
-
- from .cute_bnpo_loss import (
-     create_compiled_bnpo_loss,
-     create_compiled_bnpo_loss_with_backward,
- )
-
- if TYPE_CHECKING:
-     from collections.abc import Callable
-
-
- __all__ = ["bnpo_loss", "bnpo_loss_autograd", "bnpo_loss_fwd"]
-
-
- @lru_cache(maxsize=32)
- def _get_compiled_fwd(
-     dtype: torch.dtype,
-     epsilon: float,
-     epsilon_high: float,
-     beta: float,
- ) -> Callable[..., torch.Tensor]:
-     return cast(
-         "Callable[..., torch.Tensor]",
-         create_compiled_bnpo_loss(
-             policy_dtype=dtype,
-             epsilon=epsilon,
-             epsilon_high=epsilon_high,
-             beta=beta,
-             compute_backward=False,
-         ),
-     )
-
-
- @lru_cache(maxsize=32)
- def _get_compiled_fwd_bwd(
-     dtype: torch.dtype,
-     epsilon: float,
-     epsilon_high: float,
-     beta: float,
- ) -> Callable[..., tuple[torch.Tensor, torch.Tensor]]:
-     return create_compiled_bnpo_loss_with_backward(
-         policy_dtype=dtype,
-         epsilon=epsilon,
-         epsilon_high=epsilon_high,
-         beta=beta,
-     )
-
-
- def bnpo_loss_fwd(
-     policy_logprobs: torch.Tensor,
-     old_policy_logprobs: torch.Tensor,
-     ref_logprobs: torch.Tensor,
-     advantages: torch.Tensor,
-     completions_mask: torch.Tensor,
-     epsilon: float = 0.2,
-     epsilon_high: float = 0.2,
-     beta: float = 0.1,
- ) -> torch.Tensor:
-     """Forward-only bnpo loss. Returns the scalar ``loss``.
-
-     Use for inference / validation. The masked mean denominator is
-     computed on-GPU by an atomic accumulator + last-block trick — no
-     host ``completions_mask.sum()`` syncs.
-
-     Args:
-         policy_logprobs, old_policy_logprobs, ref_logprobs: ``(bs, seq_len)``.
-         advantages: ``(bs,)``.
-         completions_mask: bool/int8 mask ``(bs, seq_len)``; truthy = valid token.
-         epsilon, epsilon_high: PPO-style clipping bounds.
-         beta: KL-penalty coefficient. ``0.0`` compiles away the KL branch.
-
-     Returns:
-         Scalar tensor (0-dim) with the same dtype as ``policy_logprobs``.
-     """
-     run = _get_compiled_fwd(
-         policy_logprobs.dtype,
-         float(epsilon),
-         float(epsilon_high),
-         float(beta),
-     )
-     mask_arg = (
-         completions_mask
-         if completions_mask.dtype == torch.int8
-         else completions_mask.to(torch.int8)
-     )
-     return run(
-         policy_logprobs,
-         old_policy_logprobs,
-         ref_logprobs,
-         advantages,
-         mask_arg,
-     )
-
-
- def bnpo_loss(
-     policy_logprobs: torch.Tensor,
-     old_policy_logprobs: torch.Tensor,
-     ref_logprobs: torch.Tensor,
-     advantages: torch.Tensor,
-     completions_mask: torch.Tensor,
-     epsilon: float = 0.2,
-     epsilon_high: float = 0.2,
-     beta: float = 0.1,
- ) -> tuple[torch.Tensor, torch.Tensor]:
-     """Fused fwd+bwd bnpo loss. Returns ``(loss, grad_policy_logprobs)``.
-
-     Single-launch training entry point. The kernel writes both the
-     scalar loss and the scaled ``dL/d(policy_logprobs)`` tensor in one
-     ``@cute.jit`` dispatch — a bundled mask-sum kernel runs inside the
-     same launch so ``inv_total`` is populated on-GPU without a host-side
-     ``torch.sum`` round trip.
-
-     Inputs do **not** need ``requires_grad=True``. To chain ``grad``
-     into the upstream model that produced ``policy_logprobs``::
-
-         loss, grad = bnpo_loss(policy_logprobs, ..., completions_mask=mask)
-         policy_logprobs.backward(grad)
-         optimizer.step()
-
-     Args:
-         policy_logprobs, old_policy_logprobs, ref_logprobs: ``(bs, seq_len)``.
-         advantages: ``(bs,)``.
-         completions_mask: bool/int8 mask ``(bs, seq_len)``.
-         epsilon, epsilon_high: PPO-style clipping bounds.
-         beta: KL-penalty coefficient. ``0.0`` compiles away the KL branch.
-
-     Returns:
-         ``(loss, grad_policy_logprobs)`` — ``loss`` is a 0-dim tensor in
-         ``policy_logprobs.dtype``; ``grad_policy_logprobs`` has shape
-         ``(bs, seq_len)`` and is already scaled by ``1 / n_valid``. The
-         gradient tensor is freshly allocated per call (no shared cache),
-         so callers may keep it around freely.
-
-     For inference / validation where you only need the loss, use
-     :func:`bnpo_loss_fwd` — it skips the dpolicy write entirely and
-     computes the mean denominator with the on-GPU last-block trick.
-     """
-     run = _get_compiled_fwd_bwd(
-         policy_logprobs.dtype,
-         float(epsilon),
-         float(epsilon_high),
-         float(beta),
-     )
-     mask_arg = (
-         completions_mask
-         if completions_mask.dtype == torch.int8
-         else completions_mask.to(torch.int8)
-     )
-     return run(
-         policy_logprobs,
-         old_policy_logprobs,
-         ref_logprobs,
-         advantages,
-         mask_arg,
-     )
-
-
- # Imported at the bottom: ``autograd.py`` imports ``bnpo_loss`` from this
- # module, so the function must be fully defined before its import runs.
- from .autograd import bnpo_loss_autograd  # noqa: E402
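The deleted module keys its compiled kernels on `(dtype, epsilon, epsilon_high, beta)` through `functools.lru_cache`, so each distinct hyperparameter combination is compiled once and then reused. A minimal sketch of that caching behaviour, with a counter standing in for the expensive compile step; `get_runner` and `compile_calls` are hypothetical, and the dtype is a plain string so the sketch needs no torch:

```python
from functools import lru_cache

compile_calls = []  # one entry per simulated compilation


@lru_cache(maxsize=32)
def get_runner(dtype: str, epsilon: float, epsilon_high: float, beta: float):
    """Stand-in for an expensive kernel compile; runs once per distinct key."""
    compile_calls.append((dtype, epsilon, epsilon_high, beta))

    def run(x: float) -> float:
        # Placeholder body; a real runner would launch the compiled kernel.
        return x * beta

    return run


a = get_runner("float32", 0.2, 0.2, 0.1)
b = get_runner("float32", 0.2, 0.2, 0.1)  # same key: served from the cache
c = get_runner("float32", 0.2, 0.2, 0.0)  # different beta: new compilation

print(len(compile_calls), a is b, a is c)
```

One consequence of this design is that sweeping a hyperparameter such as `beta` over many distinct values triggers one compilation per value, and with `maxsize=32` the least recently used entries are evicted once the cache is full.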
build/torch-cuda/bnpo_loss/_torch_ref.py DELETED
@@ -1,56 +0,0 @@
-"""Plain-PyTorch bnpo reference shared between the bench and the tests.
-
-This module is intentionally minimal: every op is a vanilla torch op so
-``AOTAutograd`` can derive the joint fwd+bwd graph and Inductor can fuse
-both passes (used by ``benchmarks/benchmark_bnpo_loss.py``'s compiled
-baseline). The same function is imported by ``tests/test_bnpo_loss.py``
-as the correctness reference, so both paths agree on what "the eager
-torch implementation of bnpo loss" means.
-
-Underscore-prefixed module name signals "shared internal", not a public
-API surface; there's no re-export from the package's top-level
-``__init__.py``.
-"""
-
-from __future__ import annotations
-
-import torch
-
-
-def torch_bnpo_loss(
-    policy_logprobs: torch.Tensor,
-    old_policy_logprobs: torch.Tensor,
-    ref_logprobs: torch.Tensor,
-    advantages: torch.Tensor,
-    completions_mask: torch.Tensor,
-    epsilon: float = 0.2,
-    epsilon_high: float = 0.2,
-    beta: float = 0.1,
-) -> torch.Tensor:
-    """Plain-PyTorch bnpo reference traceable by AOTAutograd / Inductor.
-
-    The element-wise math runs in the input dtype (no internal fp32
-    cast), which is what real torch users would write, and what
-    ``torch.compile`` competes against in the bench. Only the mean
-    denominator ``n_valid`` is promoted to fp32 (see below).
-    """
-    ratio = torch.exp(policy_logprobs - old_policy_logprobs)
-    adv = advantages.unsqueeze(1)
-
-    surrogate = ratio * adv
-    surrogate_clipped = torch.clamp(ratio, 1.0 - epsilon, 1.0 + epsilon_high) * adv
-    policy_loss = -torch.min(surrogate, surrogate_clipped)
-
-    log_ratio_ref = ref_logprobs - policy_logprobs
-    kl = torch.exp(log_ratio_ref) - log_ratio_ref - 1.0
-
-    # Cast n_valid to fp32: int64 → fp16 overflows when n_valid > 65504.
-    # ``clamp(min=1.0)`` matches TRL's ``mask.sum().clamp(min=1)``: a
-    # fully-masked batch produces ``loss=0`` instead of inf/NaN. Mirrors
-    # the cute kernel's ``cute.arch.fmax(..., 1.0)`` before ``rcp_approx``
-    # in ``cute_bnpo_loss.py``.
-    n_valid = completions_mask.sum().to(torch.float32).clamp(min=1.0)
-    policy_loss = (policy_loss * completions_mask).sum() / n_valid
-    kl = (kl * completions_mask).sum() / n_valid
-
-    loss = policy_loss + beta * kl
-    return loss.to(policy_logprobs.dtype)
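
To make the reference's arithmetic easy to check by hand, here is a hedged pure-Python mirror of the same formula. It is a sketch for illustration only: `bnpo_loss_scalar` is a hypothetical helper that takes nested lists of shape `(bs, seq_len)` plus a per-row `advantages` list, works in double precision, and ignores the dtype behaviour of the torch version.

```python
import math


def bnpo_loss_scalar(policy, old, ref, advantages, mask,
                     epsilon=0.2, epsilon_high=0.2, beta=0.1):
    """Masked-mean clipped surrogate plus beta * KL, one token at a time."""
    policy_sum = 0.0
    kl_sum = 0.0
    n_valid = 0
    for row, adv in enumerate(advantages):
        for col, m in enumerate(mask[row]):
            ratio = math.exp(policy[row][col] - old[row][col])
            clipped_ratio = min(max(ratio, 1.0 - epsilon), 1.0 + epsilon_high)
            token_loss = -min(ratio * adv, clipped_ratio * adv)

            log_ratio_ref = ref[row][col] - policy[row][col]
            token_kl = math.exp(log_ratio_ref) - log_ratio_ref - 1.0

            policy_sum += token_loss * m
            kl_sum += token_kl * m
            n_valid += m
    # Same clamp as the reference: a fully masked batch divides by 1.
    denom = max(float(n_valid), 1.0)
    return policy_sum / denom + beta * kl_sum / denom


zeros = [[0.0, 0.0], [0.0, 0.0]]
# Identical policies give ratio == 1 and KL == 0, so the loss reduces to
# minus the mean advantage over the unmasked tokens.
print(bnpo_loss_scalar(zeros, zeros, zeros, [1.0, -2.0], [[1, 0], [1, 1]]))
# A fully masked batch hits the clamp and returns zero instead of NaN.
print(bnpo_loss_scalar(zeros, zeros, zeros, [1.0, -2.0], [[0, 0], [0, 0]]))
```

With a log-ratio of `log(2)` and a positive advantage, the clipped branch wins and the per-token loss is `-(1 + epsilon_high) * adv`, which is a quick way to confirm the clipping direction.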
 
build/torch-cuda/bnpo_loss/autograd.py DELETED
@@ -1,149 +0,0 @@
-"""Autograd-aware wrapper for bnpo loss via ``torch.library.custom_op``.
-
-The fused cute kernel writes both the scalar loss and the closed-form
-``dL/d(policy_logprobs)`` in one launch. This module wraps that into an
-autograd-compatible op so callers can write::
-
-    loss = bnpo_loss_autograd(policy, old, ref, adv, completions_mask)
-    loss.backward()  # propagates through to the upstream model
-
-instead of the manual ``policy.backward(grad)`` chain. The cost is
-~12µs of autograd dispatcher overhead per call (vs the direct
-``bnpo_loss`` ``(loss, grad)`` tuple); for ergonomic / kernelize() flows
-that's cheap, but for tight microbenches use the direct path.
-
-Implementation notes:
-
-- The registered op returns ``(loss, dpolicy)`` so ``setup_context`` can
-  ``save_for_backward(dpolicy)``. The public ``bnpo_loss_autograd``
-  wrapper hides the second output.
-- ``dpolicy`` is allocated fresh by the runner on every call (no shared
-  cache), so ``ctx.save_for_backward(dpolicy)`` keeps a stable reference
-  across subsequent calls without any extra copy.
-- Backward returns ``grad_loss * dpolicy``. Under ``torch.compile``,
-  when ``loss`` is consumed by ``.backward()`` directly, ``grad_loss``
-  is the constant 1.0 and Inductor can fold the multiply away; that's
-  the main reason this path uses ``custom_op`` instead of a plain
-  ``autograd.Function``.
-- ``register_fake`` provides the meta kernel for ``torch.compile`` shape
-  propagation; the real cute kernel never runs under ``FakeTensorMode``.
-"""
-
-from __future__ import annotations
-
-import torch
-
-from . import bnpo_loss as _bnpo_loss_fwd_bwd
-
-__all__ = ["bnpo_loss_autograd"]
-
-
-@torch.library.custom_op(
-    "geometric_ai_kernels::_bnpo_loss_with_grad",
-    mutates_args=(),
-)
-def _bnpo_loss_with_grad(
-    policy_logprobs: torch.Tensor,
-    old_policy_logprobs: torch.Tensor,
-    ref_logprobs: torch.Tensor,
-    advantages: torch.Tensor,
-    completions_mask: torch.Tensor,
-    epsilon: float,
-    epsilon_high: float,
-    beta: float,
-) -> tuple[torch.Tensor, torch.Tensor]:
-    loss, dpolicy = _bnpo_loss_fwd_bwd(
-        policy_logprobs,
-        old_policy_logprobs,
-        ref_logprobs,
-        advantages,
-        completions_mask,
-        epsilon=epsilon,
-        epsilon_high=epsilon_high,
-        beta=beta,
-    )
-    return loss, dpolicy
-
-
-@_bnpo_loss_with_grad.register_fake
-def _(
-    policy_logprobs: torch.Tensor,
-    old_policy_logprobs: torch.Tensor,
-    ref_logprobs: torch.Tensor,
-    advantages: torch.Tensor,
-    completions_mask: torch.Tensor,
-    epsilon: float,
-    epsilon_high: float,
-    beta: float,
-) -> tuple[torch.Tensor, torch.Tensor]:
-    # Signature must mirror the op; only ``policy_logprobs`` shapes the outputs.
-    del old_policy_logprobs, ref_logprobs, advantages, completions_mask
-    del epsilon, epsilon_high, beta
-    loss = policy_logprobs.new_empty(())
-    dpolicy = torch.empty_like(policy_logprobs)
-    return loss, dpolicy
-
-
-def _setup_context(ctx, inputs, output) -> None:  # type: ignore[no-untyped-def]
-    del inputs  # only ``output`` carries what we need to save.
-    _, dpolicy = output
-    ctx.save_for_backward(dpolicy)
-
-
-def _backward(ctx, grad_loss, grad_dpolicy):  # type: ignore[no-untyped-def]
-    # ``grad_dpolicy`` is unused: ``dpolicy`` is an internal intermediate
-    # exposed only so ``setup_context`` can save it. Under typical usage
-    # (``loss.backward()``) it arrives as ``None`` or a zero tensor.
-    del grad_dpolicy
-    (dpolicy,) = ctx.saved_tensors
-    grad_policy = grad_loss * dpolicy
-    # One return per input to the op (8): policy_logprobs gets the grad,
-    # everything else gets None (no autograd flow).
-    return grad_policy, None, None, None, None, None, None, None
-
-
-torch.library.register_autograd(
-    "geometric_ai_kernels::_bnpo_loss_with_grad",
-    _backward,
-    setup_context=_setup_context,
-)
-
-
-def bnpo_loss_autograd(
-    policy_logprobs: torch.Tensor,
-    old_policy_logprobs: torch.Tensor,
-    ref_logprobs: torch.Tensor,
-    advantages: torch.Tensor,
-    completions_mask: torch.Tensor,
-    epsilon: float = 0.2,
-    epsilon_high: float = 0.2,
-    beta: float = 0.1,
-) -> torch.Tensor:
-    """Autograd-aware bnpo loss. Returns scalar ``loss``.
-
-    Same numerics as :func:`bnpo_loss` but registered as a
-    ``torch.library`` custom op with autograd, so::
-
-        loss = bnpo_loss_autograd(policy, ..., completions_mask)
-        loss.backward()
-
-    propagates through to whatever produced ``policy_logprobs``. For
-    direct ``(loss, grad)`` access without the autograd dispatcher
-    overhead, use :func:`bnpo_loss` and chain the gradient manually
-    via ``policy_logprobs.backward(grad)``.
-
-    Composes with ``torch.compile``: the op is opaque to Inductor but
-    has a fake/meta kernel registered, so models that call this loss
-    can be compiled end-to-end without graph breaks.
-    """
-    loss, _ = _bnpo_loss_with_grad(
-        policy_logprobs,
-        old_policy_logprobs,
-        ref_logprobs,
-        advantages,
-        completions_mask,
-        float(epsilon),
-        float(epsilon_high),
-        float(beta),
-    )
-    return loss
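
The wrapper's backward only rescales a gradient computed in the forward launch, so it is correct only if the closed-form `dL/d(policy_logprobs)` is right. The sketch below checks that per-token formula against a central finite difference in pure Python. It assumes the gradient stated in the kernel comments of this diff (`-surrogate` where the unclipped branch is selected, otherwise 0, plus `beta * (1 - ratio_ref)`). It leaves out the mask and the `1 / n_valid` scaling, and the helper names are hypothetical.

```python
import math


def token_loss(p, old, ref, adv, eps=0.2, eps_high=0.2, beta=0.1):
    """Per-token bnpo loss as a function of the policy logprob ``p``."""
    ratio = math.exp(p - old)
    clipped = min(max(ratio, 1.0 - eps), 1.0 + eps_high) * adv
    kl = math.exp(ref - p) - (ref - p) - 1.0
    return -min(ratio * adv, clipped) + beta * kl


def token_grad(p, old, ref, adv, eps=0.2, eps_high=0.2, beta=0.1):
    """Closed-form d(token_loss)/dp, following the kernel comments."""
    ratio = math.exp(p - old)
    surrogate = ratio * adv
    clipped = min(max(ratio, 1.0 - eps), 1.0 + eps_high) * adv
    # The clamp has zero gradient, so only the unclipped branch contributes.
    policy_grad = -surrogate if surrogate <= clipped else 0.0
    kl_grad = beta * (1.0 - math.exp(ref - p))
    return policy_grad + kl_grad


def finite_difference(p, old, ref, adv, h=1e-6):
    """Central difference approximation of d(token_loss)/dp."""
    return (token_loss(p + h, old, ref, adv)
            - token_loss(p - h, old, ref, adv)) / (2.0 * h)


# (p, old, ref, adv): inside the clip range, clipped high with a positive
# advantage, and outside the range with a negative advantage.
CASES = [(0.05, 0.0, -0.1, 1.5), (0.5, 0.0, -0.1, 1.5), (0.5, 0.0, 0.2, -1.0)]
max_err = max(abs(token_grad(*c) - finite_difference(*c)) for c in CASES)
print(f"max |closed form - finite difference| = {max_err:.2e}")
```

The check deliberately avoids points where the ratio sits exactly on a clip boundary, since the loss is not differentiable there and the two methods can legitimately disagree.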
 
build/torch-cuda/bnpo_loss/cute_bnpo_loss.py DELETED
@@ -1,1081 +0,0 @@
1
- """CuteDSL kernel for bnpo loss.
2
-
3
- Computes (element-wise over ``(bs, seq_len)`` logprob tensors, reduced to a
4
- scalar):
5
-
6
- ratio = exp(policy - old_policy)
7
- surrogate = ratio * adv
8
- clipped = clip(ratio, 1 - eps, 1 + eps_high) * adv
9
- policy_loss = -min(surrogate, clipped)
10
- log_ratio_ref = ref - policy
11
- kl = exp(log_ratio_ref) - log_ratio_ref - 1
12
- L_bnpo = (policy_loss * mask).sum() / n_valid
13
- + beta * (kl * mask).sum() / n_valid
14
-
15
- where ``n_valid = max(completions_mask.sum(), 1)``. The mean denominator is
16
- computed entirely on-GPU — the forward-only path uses an atomic accumulator
17
- + last-block trick on ``valid_acc``; the fused fwd+bwd path bundles a small
18
- companion mask-sum kernel into the same ``@cute.jit`` launch that writes
19
- ``1 / n_valid`` into the ``inv_total`` GMEM scalar before the
20
- main kernel reads it. Every block needs ``inv_total`` mid-loop to scale its
21
- ``dpolicy`` slab, so the fwd-only last-block trick doesn't compose with
22
- backward; bundling the mask-sum keeps both paths host-sync-free and CUDA-graph
23
- compatible.
24
-
25
- When ``beta=0`` the KL term is skipped at compile time (no ``ref`` tensor
26
- access, no ``kl_acc`` atomic add).
27
-
28
- Sequence lengths that are **not** a multiple of ``TILE_N`` are handled
29
- natively: the grid launches ``ceil(seq_len / TILE_N)`` column tiles; full tiles
30
- use the vectorized ``LDG.128`` path and the tail tile uses predicated vector
31
- loads with neutral prefill.
32
-
33
- Two compiled-kernel flavors are exposed:
34
-
35
- * :func:`create_compiled_bnpo_loss` — forward-only.
36
- * :func:`create_compiled_bnpo_loss_with_backward` — fused fwd+bwd. Returns
37
- ``(loss, dpolicy)`` directly — no ``torch.autograd.Function`` wrapper. The
38
- autograd-aware sibling lives in ``autograd.py`` and uses
39
- ``torch.library.custom_op`` instead.
40
-
41
- Per-call output (``loss``, ``dpolicy``, ``inv_total``) is allocated inside the
42
- runner. Cross-CTA scratch (atomic accumulators + counters) is allocated lazily
43
- on first call inside the compiled-kernel closure and self-resets each launch
44
- via the kernel's last-block epilogue + ``atom.inc.u32`` wrap-around — callers
45
- don't manage scratch state.
46
- """
47
-
48
- from __future__ import annotations
49
-
50
- import math
51
- import operator
52
- from typing import TYPE_CHECKING, Any
53
- from typing import cast as _typing_cast
54
-
55
- import cutlass
56
- import cutlass.utils
57
- import torch
58
- from cutlass import cute
59
- from cutlass._mlir.dialects import llvm
60
- from cutlass.base_dsl.typing import cast
61
- from cutlass.cutlass_dsl import T, dsl_user_op
62
-
63
- if TYPE_CHECKING:
64
- from collections.abc import Callable
65
-
66
-
67
- TILE_N: int = 512
68
- NUM_WARPS: int = 4
69
- # ``VEC=4`` (fp32) emits 128-bit ``LDG.128``. Pairs with ``NUM_WARPS=4`` so
70
- # each block processes ``block_size * VEC = 512 = TILE_N`` elements per iter.
71
- VEC: int = 4
72
- # Large-tile variant: at very long ``seq_len`` the small-TILE_N grid
73
- # explodes (e.g. 8192/512 = 16 col-tiles per row → thousands of CTAs),
74
- # inflating last-block-detection latency and atomic contention. A second
75
- # compiled variant with this larger tile is dispatched when
76
- # ``seq_len >= TILE_N_LARGE_THRESHOLD``.
77
- TILE_N_LARGE: int = 4096
78
- TILE_N_LARGE_THRESHOLD: int = 2048
79
-
80
- _LOG2_E: float = math.log2(math.e)
81
-
82
- _TORCH_TO_CUTLASS_DTYPE: dict[torch.dtype, Any] = {
83
- torch.float32: cutlass.Float32,
84
- torch.float16: cutlass.Float16,
85
- torch.bfloat16: cutlass.BFloat16,
86
- }
87
-
88
-
89
- @dsl_user_op
90
- def _atomic_add_f32_gmem(
91
- ptr_i64: Any,
92
- val: cutlass.Float32,
93
- *,
94
- loc: Any = None,
95
- ip: Any = None,
96
- ) -> None:
- """Emit ``atom.global.add.f32`` to a 64-bit GMEM address."""
97
- llvm.inline_asm(
98
- T.f32(),
99
- [ptr_i64, cutlass.Float32(val).ir_value(loc=loc, ip=ip)],
100
- "atom.global.add.f32 $0, [$1], $2;",
101
- "=f,l,f",
102
- has_side_effects=True,
103
- is_align_stack=False,
104
- asm_dialect=llvm.AsmDialect.AD_ATT,
105
- )
106
-
107
-
108
- @dsl_user_op
109
- def _atomic_add_s32_gmem(
110
- ptr_i64: Any,
111
- val: cutlass.Int32,
112
- *,
113
- loc: Any = None,
114
- ip: Any = None,
115
- ) -> None:
116
- """Emit ``atom.global.add.s32`` to a 64-bit GMEM address."""
117
- llvm.inline_asm(
118
- T.i32(),
119
- [ptr_i64, cutlass.Int32(val).ir_value(loc=loc, ip=ip)],
120
- "atom.global.add.s32 $0, [$1], $2;",
121
- "=r,l,r",
122
- has_side_effects=True,
123
- is_align_stack=False,
124
- asm_dialect=llvm.AsmDialect.AD_ATT,
125
- )
126
-
127
-
128
- @dsl_user_op
129
- def _dp4a_u32_acc_s32(
130
- packed_a: cutlass.Uint32,
131
- packed_b: cutlass.Uint32,
132
- acc: cutlass.Int32,
133
- *,
134
- loc: Any = None,
135
- ip: Any = None,
136
- ) -> cutlass.Int32:
137
- """``dp4a.u32.u32`` — sum 4 packed u8 products into an s32 acc.
138
-
139
- Computes ``a[0]*b[0] + a[1]*b[1] + a[2]*b[2] + a[3]*b[3] + acc`` in
140
- one ``IDP4A.U8.S32`` instruction (full-rate on Hopper/Blackwell).
141
- For mask summation, pass ``packed_b = 0x01010101`` so the products
142
- reduce to ``sum(a_bytes) + acc`` — 4× fewer ALU ops than 4 separate
143
- int8→int32 widens + adds.
144
- """
145
- return cutlass.Int32(
146
- llvm.inline_asm(
147
- T.i32(),
148
- [
149
- cutlass.Uint32(packed_a).ir_value(loc=loc, ip=ip),
150
- cutlass.Uint32(packed_b).ir_value(loc=loc, ip=ip),
151
- cutlass.Int32(acc).ir_value(loc=loc, ip=ip),
152
- ],
153
- "dp4a.u32.u32 $0, $1, $2, $3;",
154
- "=r,r,r,r",
155
- has_side_effects=False,
156
- is_align_stack=False,
157
- asm_dialect=llvm.AsmDialect.AD_ATT,
158
- )
159
- )
160
-
161
-
162
- @dsl_user_op
163
- def _atomic_inc_u32_gmem(
164
- ptr_i64: Any,
165
- threshold: cutlass.Int32,
166
- *,
167
- loc: Any = None,
168
- ip: Any = None,
169
- ) -> cutlass.Int32:
170
- """``atom.global.inc.u32`` — returns old value; wraps to 0 at threshold."""
171
- return cutlass.Int32(
172
- llvm.inline_asm(
173
- T.i32(),
174
- [ptr_i64, cutlass.Int32(threshold).ir_value(loc=loc, ip=ip)],
175
- "atom.global.inc.u32 $0, [$1], $2;",
176
- "=r,l,r",
177
- has_side_effects=True,
178
- is_align_stack=False,
179
- asm_dialect=llvm.AsmDialect.AD_ATT,
180
- )
181
- )
182
-
183
-
184
- # ---------------------------------------------------------------------------
185
- # Mask-sum kernel — replaces ``torch.sum(completions_mask)`` on the fwd+bwd
186
- # path. Bundled into the same ``@cute.jit`` launch as the main kernel so the
187
- # whole step is one tvm-ffi dispatch (no extra Python/torch dispatcher round
188
- # trip). The kernel writes ``1 / max(completions_mask.sum(), 1)`` directly into
189
- # ``inv_total_tensor`` so the main kernel reads it as a pre-inverted scalar.
190
- # ---------------------------------------------------------------------------
191
-
192
-
193
- def _make_mask_sum_kernel(tile_n: int) -> Callable[..., None]:
194
- """Return a ``@cute.kernel`` that reduces ``completions_mask`` and writes 1/sum.
195
-
196
- Grid mirrors the main kernel — ``(bs, num_col_tiles)`` — so the mask is
197
- read once with the same vectorised LDG pattern as the main compute.
198
- Each block:
199
-
200
- 1. Loads its ``tile_n`` int8 slab of ``completions_mask`` (predicated tail).
201
- 2. Reduces to a per-block ``int32`` scalar (bit-exact, no per-element
202
- i8→f32 cast — IADD throughput equals FADD on Hopper/Blackwell).
203
- 3. Atomically adds it to ``valid_acc`` (global int32 accumulator).
204
- 4. Increments ``mask_counter``; the last block reads ``valid_acc``,
205
- casts to fp32, clamps to ``>= 1.0``, computes ``rcp_approx`` and writes
206
- ``inv_total_tensor[0]``, then resets ``valid_acc`` to ``0`` so
207
- the next call starts fresh. The counter self-resets via
208
- ``atom.inc.u32`` wrap-around.
209
-
210
- A separate ``mask_counter`` tensor (not the main kernel's ``counter``)
211
- is required because the two kernels run in series within the same
212
- ``@cute.jit`` and both rely on a wrap-around for self-reset; sharing
213
- one counter would race.
214
- """
215
-
216
- @cute.kernel
217
- def _mask_sum_kernel(
218
- completions_mask: cute.Tensor, # (bs, seq_len) int8
219
- inv_total_tensor: cute.Tensor, # (1,) fp32 — output
220
- valid_acc: cute.Tensor, # (1,) int32 — accumulator
221
- mask_counter: cute.Tensor, # (1,) i32 — last-block detection
222
- total_blocks: cutlass.Int32,
223
- num_full_tiles: cutlass.Int32,
224
- tail_len: cutlass.Int32,
225
- ) -> None:
226
- block_size = NUM_WARPS * 32
227
- iters = tile_n // (block_size * VEC)
228
-
229
- _no_alloc = cute.nvgpu.CacheEvictionPriority.NO_ALLOCATE
230
- g2r_op = cute.nvgpu.CopyUniversalOp()
231
- g2r_mask_atom = cute.make_copy_atom(
232
- g2r_op,
233
- completions_mask.element_type,
234
- num_bits_per_copy=0,
235
- l1c_evict_priority=_no_alloc,
236
- )
237
-
238
- row = cute.arch.block_idx()[0]
239
- col_block = cute.arch.block_idx()[1]
240
- tid = cute.arch.thread_idx()[0]
241
-
242
- local_valid_sum = cutlass.Int32(0)
243
- mask_row = cute.slice_(completions_mask, (row, None))
244
-
245
- # ``dp4a.u32.u32`` consumes a packed-u8x4 register. With VEC=4 each
246
- # thread loads 4 contiguous int8 bytes per iteration, so we recast
247
- # the fragment as a single ``Uint32`` view and feed it directly
248
- # into dp4a — one instruction sums all 4 bytes, vs the previous
249
- # cast+reduce which emitted 4 widens + 3 adds per iteration.
250
- ones_packed = cutlass.Uint32(0x01010101)
251
-
252
- if col_block < num_full_tiles:
253
- mask_slab = cute.local_tile(mask_row, (tile_n,), (col_block,))
254
- for k in cutlass.range(iters, unroll_full=True):
255
- sub_idx = tid + k * block_size
256
- mask_src = cute.local_tile(mask_slab, (VEC,), (sub_idx,))
257
- mask_frag = cute.make_fragment_like(mask_src)
258
- cute.copy(g2r_mask_atom, mask_src, mask_frag)
259
- packed = cute.recast_tensor(mask_frag, cutlass.Uint32)[0]
260
- local_valid_sum = _dp4a_u32_acc_s32(packed, ones_packed, local_valid_sum)
261
- else:
262
- mask_slab = cute.local_tile(mask_row, (tile_n,), (col_block,))
263
- for k in cutlass.range(iters, unroll_full=True):
264
- sub_idx = tid + k * block_size
265
- chunk_base = sub_idx * VEC
266
- if chunk_base < tail_len:
267
- mask_src = cute.local_tile(mask_slab, (VEC,), (sub_idx,))
268
- pred = cute.make_rmem_tensor(mask_src.shape, cutlass.Boolean)
269
- for v in cutlass.range(VEC, unroll_full=True):
270
- pred[v] = cute.elem_less(chunk_base + v, tail_len)
271
- mask_frag = cute.make_fragment_like(mask_src)
272
- mask_frag.fill(0)
273
- cute.copy(g2r_mask_atom, mask_src, mask_frag, pred=pred)
274
- packed = cute.recast_tensor(mask_frag, cutlass.Uint32)[0]
275
- local_valid_sum = _dp4a_u32_acc_s32(packed, ones_packed, local_valid_sum)
276
-
277
- # Warp + cross-warp reduction (same pattern as main kernel).
278
- warp_valid = cute.arch.warp_reduction(local_valid_sum, operator.add)
279
- smem = cutlass.utils.SmemAllocator()
280
- buf_valid = smem.allocate_tensor(cutlass.Int32, cute.make_layout(NUM_WARPS))
281
-
282
- lane_idx = cute.arch.lane_idx()
283
- warp_idx = cute.arch.warp_idx()
284
-
285
- if lane_idx == 0:
286
- buf_valid[warp_idx] = warp_valid
287
- cute.arch.barrier()
288
-
289
- if warp_idx == 0:
290
- val_v = cutlass.Int32(0)
291
- if lane_idx < NUM_WARPS:
292
- val_v = buf_valid[lane_idx]
293
- block_valid = cute.arch.warp_reduction(val_v, operator.add, threads_in_group=NUM_WARPS)
294
-
295
- if lane_idx == 0:
296
- valid_ptr = valid_acc.iterator.toint().ir_value() # ty: ignore[unresolved-attribute]
297
- counter_ptr = mask_counter.iterator.toint().ir_value() # ty: ignore[unresolved-attribute]
298
-
299
- _atomic_add_s32_gmem(valid_ptr, block_valid)
300
- cute.arch.fence_acq_rel_gpu()
301
- old = _atomic_inc_u32_gmem(counter_ptr, total_blocks - 1)
302
-
303
- if old == total_blocks - 1:
304
- # Clamp to >=1.0 so a fully-masked batch (n_valid=0)
305
- # produces ``loss=0`` instead of inf/NaN — matches
306
- # TRL's ``mask.sum().clamp(min=1)`` semantics.
307
- n_valid = cute.arch.fmax(cutlass.Float32(valid_acc[0]), cutlass.Float32(1.0))
308
- inv_total_tensor[0] = cute.arch.rcp_approx(n_valid)
309
- valid_acc[0] = cutlass.Int32(0)
310
-
311
- return _mask_sum_kernel
312
-
313
-
314
- def _make_bnpo_kernel(
315
- compute_kl: bool,
316
- compute_backward: bool,
317
- tile_n: int,
318
- ) -> Callable[..., None]:
319
- """Return a ``@cute.kernel`` specialised on compile-time flags.
320
-
321
- The returned kernel captures *compute_kl*, *compute_backward*, and
322
- *tile_n* in its closure. ``cutlass.const_expr`` evaluates the booleans
323
- at trace time so dead branches are eliminated from the compiled PTX.
324
- ``tile_n`` is a Python ``int`` captured at trace time, so the same
325
- factory can emit two specialised kernels (small / large tile) — see
326
- :func:`create_compiled_bnpo_loss` for dispatch.
327
-
328
- When *compute_backward* is True the kernel additionally writes
329
- ``dpolicy = dL/d(policy_logprobs)`` to GMEM in the same inner loop —
330
- no extra HBM reads of the inputs. Because every block must scale
331
- ``dpolicy`` by ``inv_total`` mid-loop, the on-GPU last-block computation
332
- of ``inv_total`` from the masked accumulator does **not** compose with
333
- backward; the bundled mask-sum kernel populates ``inv_total_tensor``
334
- before the main kernel runs.
335
-
336
- When *compute_backward* is False the kernel accumulates the
337
- mask-element count via ``valid_acc`` and computes
338
- ``inv_total = 1 / n_valid`` on-GPU in the last-block path — no
339
- host-side ``completions_mask.sum()`` required.
340
- """
341
-
342
- @cute.kernel
343
- def _bnpo_loss_kernel(
344
- policy: cute.Tensor,
345
- old_policy: cute.Tensor,
346
- ref: cute.Tensor,
347
- advantages: cute.Tensor,
348
- completions_mask: cute.Tensor,
349
- dpolicy: cute.Tensor, # (bs, seq_len) when compute_backward; (bs, 1) dummy otherwise
350
- inv_total_tensor: cute.Tensor, # (1,) fp32 — caller-populated 1/n_valid
351
- policy_acc: cute.Tensor,
352
- kl_acc: cute.Tensor,
353
- valid_acc: cute.Tensor, # (1,) int32 — mask-element count accumulator
354
- counter: cute.Tensor,
355
- output: cute.Tensor,
356
- epsilon: cutlass.Float32,
357
- epsilon_high: cutlass.Float32,
358
- beta: cutlass.Float32,
359
- total_blocks: cutlass.Int32,
360
- num_full_tiles: cutlass.Int32,
361
- tail_len: cutlass.Int32,
362
- ) -> None:
363
- block_size = NUM_WARPS * 32
364
- iters = tile_n // (block_size * VEC)
365
-
366
- # Read inv_total from GMEM once per block (hoisted, single load).
367
- # Skipped on the fwd-only path which uses an on-GPU last-block
368
- # computation from the valid_acc accumulator instead. On the
369
- # compute_backward path the bundled mask-sum kernel writes
370
- # ``1 / max(completions_mask.sum(), 1)`` into ``inv_total_tensor`` before
371
- # this kernel runs, so the load returns the pre-inverted scalar.
372
- accumulate_valid = not compute_backward
373
- if cutlass.const_expr(not accumulate_valid):
374
- inv_total = cast(inv_total_tensor[0], cutlass.Float32)
375
-
376
- _no_alloc = cute.nvgpu.CacheEvictionPriority.NO_ALLOCATE
377
- g2r_op = cute.nvgpu.CopyUniversalOp()
378
- g2r_atom = cute.make_copy_atom(
379
- g2r_op,
380
- policy.element_type,
381
- num_bits_per_copy=0,
382
- l1c_evict_priority=_no_alloc,
383
- )
384
- g2r_mask_atom = cute.make_copy_atom(
385
- g2r_op,
386
- completions_mask.element_type,
387
- num_bits_per_copy=0,
388
- l1c_evict_priority=_no_alloc,
389
- )
390
- if cutlass.const_expr(compute_backward):
391
- r2g_atom = cute.make_copy_atom(
392
- g2r_op,
393
- dpolicy.element_type,
394
- num_bits_per_copy=0,
395
- )
396
-
397
- row = cute.arch.block_idx()[0]
398
- col_block = cute.arch.block_idx()[1]
399
- tid = cute.arch.thread_idx()[0]
400
-
401
- adv = cast(advantages[row], cutlass.Float32)
402
- lo = cutlass.Float32(1.0) - epsilon
403
- hi = cutlass.Float32(1.0) + epsilon_high
404
-
405
- local_policy_sum = cutlass.Float32(0.0)
406
- local_kl_sum = cutlass.Float32(0.0)
407
- # mask_vec is already cast to fp32 for loss/kl multiplications, so
408
- # accumulate valid in fp32 too (avoids a separate i8→i32 reduction).
409
- # Cast to int32 only at the atomic boundary so the shared
410
- # ``valid_acc`` global can remain int32 — see ``_atomic_add_s32_gmem``.
411
- local_valid_sum = cutlass.Float32(0.0)
412
-
413
- pol_row = cute.slice_(policy, (row, None))
414
- old_row = cute.slice_(old_policy, (row, None))
415
-
416
- if cutlass.const_expr(compute_kl):
417
- ref_row = cute.slice_(ref, (row, None))
418
-
419
- mask_row = cute.slice_(completions_mask, (row, None))
420
-
421
- if cutlass.const_expr(compute_backward):
422
- dp_row = cute.slice_(dpolicy, (row, None))
423
-
424
- # ---- Full-tile vectorised path (LDG.128) ----
425
- if col_block < num_full_tiles:
426
- pol_slab = cute.local_tile(pol_row, (tile_n,), (col_block,))
427
- old_slab = cute.local_tile(old_row, (tile_n,), (col_block,))
428
-
429
- if cutlass.const_expr(compute_kl):
430
- ref_slab = cute.local_tile(ref_row, (tile_n,), (col_block,))
431
-
432
- mask_slab = cute.local_tile(mask_row, (tile_n,), (col_block,))
433
-
434
- if cutlass.const_expr(compute_backward):
435
- dp_slab = cute.local_tile(dp_row, (tile_n,), (col_block,))
436
-
437
- for k in cutlass.range(iters, unroll_full=True):
438
- sub_idx = tid + k * block_size
439
-
440
- pol_src = cute.local_tile(pol_slab, (VEC,), (sub_idx,))
441
- old_src = cute.local_tile(old_slab, (VEC,), (sub_idx,))
442
- pol_frag = cute.make_fragment_like(pol_src)
443
- old_frag = cute.make_fragment_like(old_src)
444
- cute.copy(g2r_atom, pol_src, pol_frag)
445
- cute.copy(g2r_atom, old_src, old_frag)
446
-
447
- if cutlass.const_expr(compute_kl):
448
- ref_src = cute.local_tile(ref_slab, (VEC,), (sub_idx,))
449
- ref_frag = cute.make_fragment_like(ref_src)
450
- cute.copy(g2r_atom, ref_src, ref_frag)
451
-
452
- mask_src = cute.local_tile(mask_slab, (VEC,), (sub_idx,))
453
- mask_frag = cute.make_fragment_like(mask_src)
454
- cute.copy(g2r_mask_atom, mask_src, mask_frag)
455
-
456
- pol_vec = pol_frag.load().to(cutlass.Float32)
457
- old_vec = old_frag.load().to(cutlass.Float32)
458
-
459
- log_ratio = pol_vec - old_vec
460
- ratio = cute.math.exp2(log_ratio * _LOG2_E, fastmath=True)
461
- surrogate = ratio * adv
462
- clipped_ratio = cute.where(
463
- ratio < lo,
464
- lo,
465
- cute.where(ratio > hi, hi, ratio),
466
- )
467
- clipped = clipped_ratio * adv
468
- policy_loss = -cute.where(surrogate < clipped, surrogate, clipped)
469
-
470
- if cutlass.const_expr(compute_kl):
471
- ref_vec = ref_frag.load().to(cutlass.Float32)
472
- log_ratio_ref = ref_vec - pol_vec
473
- ratio_ref = cute.math.exp2(log_ratio_ref * _LOG2_E, fastmath=True)
474
- # FFMA-friendly rearrangement: ``(ratio_ref - 1) - log_ratio_ref``
475
- # exposes a ``ratio_ref + (-1)`` pair that ptxas folds with
476
- # the subsequent subtract — same arithmetic, fewer FADDs
477
- # surviving SASS than the original 3-term ``a - b - c``.
478
- kl_val = (ratio_ref - cutlass.Float32(1.0)) - log_ratio_ref
479
-
480
- mask_vec = mask_frag.load().to(cutlass.Float32)
481
- local_policy_sum += (policy_loss * mask_vec).reduce(
482
- cute.ReductionOp.ADD,
483
- cutlass.Float32(0.0),
484
- reduction_profile=0,
485
- )
486
- if cutlass.const_expr(not compute_backward):
487
- local_valid_sum += mask_vec.reduce(
488
- cute.ReductionOp.ADD,
489
- cutlass.Float32(0.0),
490
- reduction_profile=0,
491
- )
492
- if cutlass.const_expr(compute_kl):
493
- local_kl_sum += (kl_val * mask_vec).reduce(
494
- cute.ReductionOp.ADD,
495
- cutlass.Float32(0.0),
496
- reduction_profile=0,
497
- )
498
-
499
- # ---- Backward: write scaled dpolicy slab in same loop ----
500
- # use_unclipped = (surrogate <= clipped) — matches torch's
501
- # convention. d/d(policy) of -min(surrogate, clipped) is
502
- # -adv*ratio when use_unclipped, else 0 (clamp grad = 0).
503
- # ``-(adv * ratio)`` is just ``-surrogate`` (already in
504
- # scope) — saves one FMUL per element.
505
- # KL term: d/d(policy) of (ratio_ref - log_ratio_ref - 1)
506
- # = -(ratio_ref - 1) = 1 - ratio_ref.
507
- if cutlass.const_expr(compute_backward):
508
- neg_surrogate_grad = cute.where(
509
- surrogate <= clipped,
510
- -surrogate,
511
- cutlass.Float32(0.0),
512
- )
513
- if cutlass.const_expr(compute_kl):
514
- # ``beta - beta*ratio_ref`` instead of ``beta*(1 - ratio_ref)``
515
- # gives ptxas an obvious FFMA pattern (``FFMA -beta,
516
- # ratio_ref, beta``) — saves one FMUL per element vs
517
- # the (1 - ratio_ref) intermediate.
518
- kl_grad = beta - beta * ratio_ref
519
- dpolicy_vec = neg_surrogate_grad + kl_grad
520
- else:
521
- dpolicy_vec = neg_surrogate_grad
522
- dpolicy_vec = dpolicy_vec * mask_vec
523
- dpolicy_vec = dpolicy_vec * inv_total
524
-
525
- dp_dst = cute.local_tile(dp_slab, (VEC,), (sub_idx,))
526
- dp_frag = cute.make_fragment_like(dp_dst)
527
- dp_frag.store(dpolicy_vec.to(dpolicy.element_type))
528
- cute.copy(r2g_atom, dp_frag, dp_dst)
529
-
530
- else:
531
- # ---- Predicated vector tail path (< tile_n valid elements) ----
532
- pol_slab = cute.local_tile(pol_row, (tile_n,), (col_block,))
533
- old_slab = cute.local_tile(old_row, (tile_n,), (col_block,))
534
-
535
- if cutlass.const_expr(compute_kl):
536
- ref_slab = cute.local_tile(ref_row, (tile_n,), (col_block,))
537
-
538
- mask_slab = cute.local_tile(mask_row, (tile_n,), (col_block,))
539
-
540
- if cutlass.const_expr(compute_backward):
541
- dp_slab = cute.local_tile(dp_row, (tile_n,), (col_block,))
542
-
543
- for k in cutlass.range(iters, unroll_full=True):
544
- sub_idx = tid + k * block_size
545
- chunk_base = sub_idx * VEC
546
-
547
- if chunk_base < tail_len:
548
- pol_src = cute.local_tile(pol_slab, (VEC,), (sub_idx,))
549
- old_src = cute.local_tile(old_slab, (VEC,), (sub_idx,))
550
- pred = cute.make_rmem_tensor(pol_src.shape, cutlass.Boolean)
551
-                 for v in cutlass.range(VEC, unroll_full=True):
-                     pred[v] = cute.elem_less(chunk_base + v, tail_len)
-
-                 pol_frag = cute.make_fragment_like(pol_src)
-                 old_frag = cute.make_fragment_like(old_src)
-                 pol_frag.fill(0.0)
-                 old_frag.fill(0.0)
-                 cute.copy(g2r_atom, pol_src, pol_frag, pred=pred)
-                 cute.copy(g2r_atom, old_src, old_frag, pred=pred)
-
-                 if cutlass.const_expr(compute_kl):
-                     ref_src = cute.local_tile(ref_slab, (VEC,), (sub_idx,))
-                     ref_frag = cute.make_fragment_like(ref_src)
-                     ref_frag.fill(0.0)
-                     cute.copy(g2r_atom, ref_src, ref_frag, pred=pred)
-
-                 mask_src = cute.local_tile(mask_slab, (VEC,), (sub_idx,))
-                 mask_frag = cute.make_fragment_like(mask_src)
-                 mask_frag.fill(0)
-                 cute.copy(g2r_mask_atom, mask_src, mask_frag, pred=pred)
-
-                 pol_vec = pol_frag.load().to(cutlass.Float32)
-                 old_vec = old_frag.load().to(cutlass.Float32)
-                 valid_vec = cute.where(
-                     pred.load(),
-                     cute.full_like(pol_vec, cutlass.Float32(1.0)),
-                     cute.zeros_like(pol_vec, dtype=cutlass.Float32),
-                 )
-
-                 log_ratio = pol_vec - old_vec
-                 ratio = cute.math.exp2(log_ratio * _LOG2_E, fastmath=True)
-                 surrogate = ratio * adv
-                 clipped_ratio = cute.where(
-                     ratio < lo,
-                     lo,
-                     cute.where(ratio > hi, hi, ratio),
-                 )
-                 clipped = clipped_ratio * adv
-                 policy_loss = -cute.where(surrogate < clipped, surrogate, clipped)
-
-                 if cutlass.const_expr(compute_kl):
-                     ref_vec = ref_frag.load().to(cutlass.Float32)
-                     log_ratio_ref = ref_vec - pol_vec
-                     ratio_ref = cute.math.exp2(log_ratio_ref * _LOG2_E, fastmath=True)
-                     # FFMA-friendly rearrangement; see full-tile path.
-                     kl_val = (ratio_ref - cutlass.Float32(1.0)) - log_ratio_ref
-
-                 mask_vec = mask_frag.load().to(cutlass.Float32) * valid_vec
-                 local_policy_sum += (policy_loss * mask_vec).reduce(
-                     cute.ReductionOp.ADD,
-                     cutlass.Float32(0.0),
-                     reduction_profile=0,
-                 )
-                 if cutlass.const_expr(not compute_backward):
-                     local_valid_sum += mask_vec.reduce(
-                         cute.ReductionOp.ADD,
-                         cutlass.Float32(0.0),
-                         reduction_profile=0,
-                     )
-                 if cutlass.const_expr(compute_kl):
-                     local_kl_sum += (kl_val * mask_vec).reduce(
-                         cute.ReductionOp.ADD,
-                         cutlass.Float32(0.0),
-                         reduction_profile=0,
-                     )
-
-                 # ---- Backward: predicated dpolicy slab write ----
-                 # Same gradient math as the full-tile path. ``valid_vec``
-                 # already encodes the in-bounds predicate (1.0 inside,
-                 # 0.0 outside) and is folded into ``mask_vec``, so
-                 # multiplying by ``mask_vec`` zeros out the padded positions.
-                 if cutlass.const_expr(compute_backward):
-                     neg_surrogate_grad = cute.where(
-                         surrogate <= clipped,
-                         -surrogate,
-                         cutlass.Float32(0.0),
-                     )
-                     if cutlass.const_expr(compute_kl):
-                         kl_grad = beta - beta * ratio_ref
-                         dpolicy_vec = neg_surrogate_grad + kl_grad
-                     else:
-                         dpolicy_vec = neg_surrogate_grad
-                     dpolicy_vec = dpolicy_vec * mask_vec
-                     dpolicy_vec = dpolicy_vec * inv_total
-
-                     dp_dst = cute.local_tile(dp_slab, (VEC,), (sub_idx,))
-                     dp_frag = cute.make_fragment_like(dp_dst)
-                     dp_frag.store(dpolicy_vec.to(dpolicy.element_type))
-                     cute.copy(r2g_atom, dp_frag, dp_dst, pred=pred)
-
-         # ---- Stage 1: Intra-warp reduction (butterfly XOR shuffles) ----
-         warp_policy = cute.arch.warp_reduction(local_policy_sum, operator.add)
-         if cutlass.const_expr(compute_kl):
-             warp_kl = cute.arch.warp_reduction(local_kl_sum, operator.add)
-
-         smem = cutlass.utils.SmemAllocator()
-         buf_policy = smem.allocate_tensor(cutlass.Float32, cute.make_layout(NUM_WARPS))
-         if cutlass.const_expr(compute_kl):
-             buf_kl = smem.allocate_tensor(cutlass.Float32, cute.make_layout(NUM_WARPS))
-
-         lane_idx = cute.arch.lane_idx()
-         warp_idx = cute.arch.warp_idx()
-
-         # When compute_backward is True, the bundled mask-sum kernel populates
-         # inv_total_tensor before this kernel runs, so the valid-count
-         # accumulation in this kernel would be dead code and is compiled out.
-         if cutlass.const_expr(accumulate_valid):
-             warp_valid = cute.arch.warp_reduction(local_valid_sum, operator.add)
-             buf_valid = smem.allocate_tensor(cutlass.Float32, cute.make_layout(NUM_WARPS))
-
-         # ---- Stage 2: Cross-warp reduction via SMEM ----
-         if lane_idx == 0:
-             buf_policy[warp_idx] = warp_policy
-             if cutlass.const_expr(compute_kl):
-                 buf_kl[warp_idx] = warp_kl
-             if cutlass.const_expr(accumulate_valid):
-                 buf_valid[warp_idx] = warp_valid
-         cute.arch.barrier()
-
-         if warp_idx == 0:
-             val_p = cutlass.Float32(0.0)
-             if lane_idx < NUM_WARPS:
-                 val_p = buf_policy[lane_idx]
-
-             block_policy = cute.arch.warp_reduction(val_p, operator.add, threads_in_group=NUM_WARPS)
-
-             if cutlass.const_expr(compute_kl):
-                 val_k = cutlass.Float32(0.0)
-                 if lane_idx < NUM_WARPS:
-                     val_k = buf_kl[lane_idx]
-                 block_kl = cute.arch.warp_reduction(val_k, operator.add, threads_in_group=NUM_WARPS)
-
-             if cutlass.const_expr(accumulate_valid):
-                 val_v = cutlass.Float32(0.0)
-                 if lane_idx < NUM_WARPS:
-                     val_v = buf_valid[lane_idx]
-                 block_valid = cute.arch.warp_reduction(
-                     val_v, operator.add, threads_in_group=NUM_WARPS
-                 )
-
-             # ---- Stage 3: Cross-CTA atomic accumulation ----
-             if lane_idx == 0:
-                 policy_ptr = policy_acc.iterator.toint().ir_value()  # ty: ignore[unresolved-attribute]
-                 counter_ptr = counter.iterator.toint().ir_value()  # ty: ignore[unresolved-attribute]
-
-                 _atomic_add_f32_gmem(policy_ptr, block_policy)
-
-                 if cutlass.const_expr(compute_kl):
-                     kl_ptr = kl_acc.iterator.toint().ir_value()  # ty: ignore[unresolved-attribute]
-                     _atomic_add_f32_gmem(kl_ptr, block_kl)
-
-                 if cutlass.const_expr(accumulate_valid):
-                     valid_ptr = valid_acc.iterator.toint().ir_value()  # ty: ignore[unresolved-attribute]
-                     # valid_acc is int32. Per-block sums of int8 0/1 values
-                     # fit exactly in fp32 (≤ tile_n ≤ 4096 ≪ 2²⁴) so the
-                     # fp32-to-int32 cast is bit-exact.
-                     _atomic_add_s32_gmem(valid_ptr, cutlass.Int32(block_valid))
-
-                 cute.arch.fence_acq_rel_gpu()
-
-                 old = _atomic_inc_u32_gmem(counter_ptr, total_blocks - 1)
-
-                 if old == total_blocks - 1:
-                     pol_sum = policy_acc[0]
-
-                     if cutlass.const_expr(accumulate_valid):
-                         # Clamp to >=1.0 so a fully-masked batch (n_valid=0)
-                         # produces ``loss=0`` instead of inf/NaN; this matches
-                         # TRL's ``mask.sum().clamp(min=1)`` semantics.
-                         n_valid = cute.arch.fmax(
-                             cutlass.Float32(valid_acc[0]), cutlass.Float32(1.0)
-                         )
-                         inv_total_computed = cute.arch.rcp_approx(n_valid)
-                     else:
-                         # compute_backward path: bundled mask-sum kernel
-                         # already wrote the inverse so forward and backward
-                         # share the same scalar.
-                         inv_total_computed = inv_total
-
-                     if cutlass.const_expr(compute_kl):
-                         kl_sum = kl_acc[0]
-                         loss = (pol_sum + beta * kl_sum) * inv_total_computed
-                     else:
-                         loss = pol_sum * inv_total_computed
-                     output[0] = cast(loss, output.element_type)  # ty: ignore[invalid-argument-type]
-
-                     # Reset accumulators for the next invocation.
-                     # Counter self-resets via atom.inc wrap-around.
-                     policy_acc[0] = cutlass.Float32(0.0)
-                     if cutlass.const_expr(compute_kl):
-                         kl_acc[0] = cutlass.Float32(0.0)
-                     if cutlass.const_expr(accumulate_valid):
-                         valid_acc[0] = cutlass.Int32(0)
-
-     return _bnpo_loss_kernel
-
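The per-token math in the deleted kernel can be checked against a scalar reference. The following is a minimal pure-Python sketch, not the kernel itself: `bnpo_token_terms` and `finite_difference` are hypothetical helpers, and the sketch assumes `lo = 1 - epsilon` and `hi = 1 + epsilon_high`, which this part of the diff does not show.

```python
import math


def bnpo_token_terms(pol, old, ref, adv, epsilon, epsilon_high, beta):
    """Scalar reference for one unmasked token (hypothetical helper).

    Assumption: lo = 1 - epsilon and hi = 1 + epsilon_high. The diff does not
    show how the kernel derives its clip bounds.
    """
    lo, hi = 1.0 - epsilon, 1.0 + epsilon_high
    ratio = math.exp(pol - old)
    surrogate = ratio * adv
    clipped = min(max(ratio, lo), hi) * adv
    policy_loss = -min(surrogate, clipped)

    log_ratio_ref = ref - pol
    ratio_ref = math.exp(log_ratio_ref)
    kl = (ratio_ref - 1.0) - log_ratio_ref

    # Gradient of (policy_loss + beta * kl) with respect to pol, mirroring
    # neg_surrogate_grad + (beta - beta * ratio_ref) in the kernel.
    grad = (-surrogate if surrogate <= clipped else 0.0) + beta * (1.0 - ratio_ref)
    return policy_loss + beta * kl, grad


def finite_difference(pol, h=1e-6, **kw):
    """Central-difference estimate of d(loss)/d(pol)."""
    up, _ = bnpo_token_terms(pol + h, **kw)
    down, _ = bnpo_token_terms(pol - h, **kw)
    return (up - down) / (2.0 * h)
```

Comparing `grad` with `finite_difference` away from the clip boundaries is a quick way to confirm that the analytic gradient and the loss agree.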
-
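The last-block detection in Stage 3 relies on the wrap-around behaviour of `atom.inc.u32`: the instruction stores 0 when the old value is greater than or equal to the limit, otherwise old + 1, and returns the old value. A host-side simulation of that rule, with hypothetical helper names and sequential arrival standing in for the arbitrary order on a GPU, might look like this:

```python
def atomic_inc_u32(cell, limit):
    """Emulate atom.inc.u32 on a one-element list.

    Stores 0 if the old value is >= limit, else old + 1, and returns old.
    """
    old = cell[0]
    cell[0] = 0 if old >= limit else old + 1
    return old


def run_launch(counter, total_blocks):
    """Return the arrival indices that observe themselves as the last block."""
    last = []
    for arrival in range(total_blocks):
        if atomic_inc_u32(counter, total_blocks - 1) == total_blocks - 1:
            last.append(arrival)
    return last
```

Exactly one arrival per launch sees `old == total_blocks - 1`, and the counter is back at 0 afterwards, which is why no separate reset is needed between launches.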
- def create_compiled_bnpo_loss(
-     policy_dtype: torch.dtype,
-     epsilon: float,
-     epsilon_high: float,
-     beta: float,
-     compute_backward: bool = False,
- ) -> Callable[..., torch.Tensor | tuple[torch.Tensor, torch.Tensor]]:
-     """Compile the bnpo loss kernel for a given dtype/KL/backward configuration.
-
-     The runner allocates per-call scratch (``output``, ``inv_total``, and on
-     the fwd+bwd path ``dpolicy``) inside ``_run`` itself; cross-CTA scratch
-     (atomic accumulators + counters) is allocated lazily on first call from
-     the input device and self-resets each launch via the kernel's last-block
-     epilogue + ``atom.inc.u32`` wrap-around.
-     """
-     compute_kl = beta != 0.0
-
-     if policy_dtype not in _TORCH_TO_CUTLASS_DTYPE:
-         raise ValueError(f"Unsupported dtype for bnpo kernel: {policy_dtype}")
-
-     tile_n_small = TILE_N
-     tile_n_large = TILE_N_LARGE
-     seq_len_threshold = TILE_N_LARGE_THRESHOLD
-     block_size = NUM_WARPS * 32
-     if tile_n_small % (block_size * VEC) != 0:
-         raise ValueError(
-             f"TILE_N={tile_n_small} must be a multiple of BLOCK_SIZE*VEC={block_size * VEC}"
-         )
-     if tile_n_large % (block_size * VEC) != 0:
-         raise ValueError(
-             f"TILE_N_LARGE={tile_n_large} must be a multiple of BLOCK_SIZE*VEC={block_size * VEC}"
-         )
-
-     bs_sym = cute.sym_int()
-     seq_len_sym = cute.sym_int()
-     cute_dtype = _TORCH_TO_CUTLASS_DTYPE[policy_dtype]
-
-     def _fake2d(dt: Any, cols: Any) -> Any:
-         return cute.runtime.make_fake_compact_tensor(
-             dt,
-             (bs_sym, cols),
-             stride_order=(1, 0),
-             assumed_align=16,
-         )
-
-     fake_pol = _fake2d(cute_dtype, seq_len_sym)
-     fake_old = _fake2d(cute_dtype, seq_len_sym)
-     fake_ref = _fake2d(cute_dtype, seq_len_sym)
-     fake_adv = cute.runtime.make_fake_compact_tensor(
-         cute_dtype,
-         (bs_sym,),
-         assumed_align=16,
-     )
-     fake_mask = cute.runtime.make_fake_compact_tensor(
-         cutlass.Int8,
-         (bs_sym, seq_len_sym),
-         stride_order=(1, 0),
-         assumed_align=16,
-     )
-     dpolicy_cols = seq_len_sym if compute_backward else 1
-     fake_dpolicy = cute.runtime.make_fake_compact_tensor(
-         cute_dtype,
-         (bs_sym, dpolicy_cols),
-         stride_order=(1, 0),
-         assumed_align=16,
-     )
-     fake_scalar_f32 = cute.runtime.make_fake_compact_tensor(
-         cutlass.Float32,
-         (1,),
-         assumed_align=16,
-     )
-     fake_valid_acc = cute.runtime.make_fake_compact_tensor(
-         cutlass.Int32,
-         (1,),
-         assumed_align=16,
-     )
-     fake_counter = cute.runtime.make_fake_compact_tensor(
-         cutlass.Int32,
-         (1,),
-         assumed_align=16,
-     )
-     fake_mask_counter = cute.runtime.make_fake_compact_tensor(
-         cutlass.Int32,
-         (1,),
-         assumed_align=16,
-     )
-     fake_output = cute.runtime.make_fake_compact_tensor(
-         cute_dtype,
-         (1,),
-         assumed_align=16,
-     )
-
-     def _build_launch(tile_n_v: int) -> Callable[..., None]:
-         """Build a ``@cute.jit`` ``_launch`` for a given ``tile_n``.
-
-         Captures *tile_n_v* via closure; both the main kernel and the
-         (optional) mask-sum kernel are specialised to this tile size.
-         One ``_launch`` per tier; the runner dispatches at call time.
-         """
-         specialized_kernel = _make_bnpo_kernel(compute_kl, compute_backward, tile_n_v)
-         if compute_backward:
-             mask_sum_kernel = _make_mask_sum_kernel(tile_n_v)
-
-         @cute.jit
-         def _launch(
-             pol_ct: cute.Tensor,
-             old_ct: cute.Tensor,
-             ref_ct: cute.Tensor,
-             adv_ct: cute.Tensor,
-             mask_ct: cute.Tensor,
-             dpolicy_ct: cute.Tensor,
-             inv_total_ct: cute.Tensor,
-             policy_acc_ct: cute.Tensor,
-             kl_acc_ct: cute.Tensor,
-             valid_acc_ct: cute.Tensor,
-             counter_ct: cute.Tensor,
-             mask_counter_ct: cute.Tensor,
-             output_ct: cute.Tensor,
-             epsilon_v: cutlass.Float32,
-             epsilon_high_v: cutlass.Float32,
-             beta_v: cutlass.Float32,
-             total_blocks_v: cutlass.Int32,
-             num_full_tiles_v: cutlass.Int32,
-             tail_len_v: cutlass.Int32,
-             num_col_tiles_v: cutlass.Int32,
-         ) -> None:
-             bs_v = pol_ct.shape[0]  # ty: ignore[not-subscriptable]
-             # Bundled mask-sum (compute_backward only): writes
-             # ``1 / completions_mask.sum()`` into ``inv_total_ct`` before the
-             # main kernel reads it. Launching both kernels in one tvm-ffi
-             # dispatch eliminates the per-call ``torch.sum`` + reciprocal
-             # round trip.
-             if cutlass.const_expr(compute_backward):
-                 mask_sum_kernel(  # ty: ignore[unresolved-attribute]
-                     mask_ct,
-                     inv_total_ct,
-                     valid_acc_ct,
-                     mask_counter_ct,
-                     total_blocks_v,
-                     num_full_tiles_v,
-                     tail_len_v,
-                 ).launch(
-                     grid=(bs_v, num_col_tiles_v, 1),
-                     block=(NUM_WARPS * 32, 1, 1),
-                 )
-             specialized_kernel(  # ty: ignore[unresolved-attribute]
-                 pol_ct,
-                 old_ct,
-                 ref_ct,
-                 adv_ct,
-                 mask_ct,
-                 dpolicy_ct,
-                 inv_total_ct,
-                 policy_acc_ct,
-                 kl_acc_ct,
-                 valid_acc_ct,
-                 counter_ct,
-                 output_ct,
-                 epsilon_v,
-                 epsilon_high_v,
-                 beta_v,
-                 total_blocks_v,
-                 num_full_tiles_v,
-                 tail_len_v,
-             ).launch(
-                 grid=(bs_v, num_col_tiles_v, 1),
-                 block=(NUM_WARPS * 32, 1, 1),
-             )
-
-         return _launch
-
-     def _compile_launch(launch_fn: Callable[..., None]) -> Callable[..., None]:
-         return cute.compile(
-             launch_fn,
-             fake_pol,
-             fake_old,
-             fake_ref,
-             fake_adv,
-             fake_mask,
-             fake_dpolicy,
-             fake_scalar_f32,
-             fake_scalar_f32,
-             fake_scalar_f32,
-             fake_valid_acc,
-             fake_counter,
-             fake_mask_counter,
-             fake_output,
-             cutlass.Float32(epsilon),
-             cutlass.Float32(epsilon_high),
-             cutlass.Float32(beta),
-             cutlass.Int32(1),
-             cutlass.Int32(1),
-             cutlass.Int32(0),
-             cutlass.Int32(1),
-             options="--enable-tvm-ffi",
-         )
-
-     compiled_small = _compile_launch(_build_launch(tile_n_small))
-     if tile_n_large == tile_n_small:
-         compiled_large = compiled_small
-     else:
-         compiled_large = _compile_launch(_build_launch(tile_n_large))
-
-     eps_const = cutlass.Float32(epsilon)
-     eps_high_const = cutlass.Float32(epsilon_high)
-     beta_const = cutlass.Float32(beta)
-
-     # Cross-CTA scratch slab: one int32 buffer with stride-4 (16-byte) slices
-     # so each slot is individually 16-byte aligned (``assumed_align=16`` at
-     # compile time). The bit pattern of int32 0 equals that of fp32 0.0, so a
-     # single ``zeros`` factory legitimately initialises both the int32
-     # counters and the fp32 accumulators. The kernel's last block resets the
-     # accumulators in its epilogue and the counters self-reset via
-     # ``atom.inc.u32`` wrap-around, so the up-front ``torch.zeros`` only
-     # matters for the very first call.
-     _scratch: list[torch.Tensor | None] = [None]
-
-     def _ensure_scratch(device: torch.device) -> tuple[torch.Tensor, ...]:
-         s = _scratch[0]
-         if s is None or s.device != device:
-             s = torch.zeros(20, dtype=torch.int32, device=device)
-             _scratch[0] = s
-         return (
-             s[0:1],  # counter (int32)
-             s[4:5],  # mask_counter (int32)
-             s[8:9],  # valid_acc (int32)
-             s[12:13].view(torch.float32),  # policy_acc (fp32)
-             s[16:17].view(torch.float32),  # kl_acc (fp32)
-         )
-
-     def _run(
-         policy_logprobs_r: torch.Tensor,
-         old_policy_logprobs_r: torch.Tensor,
-         ref_logprobs_r: torch.Tensor,
-         advantages_r: torch.Tensor,
-         completions_mask_r: torch.Tensor,
-     ) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
-         bs, seq_len = policy_logprobs_r.shape
-         device = policy_logprobs_r.device
-         dtype = policy_logprobs_r.dtype
-
-         # Tier dispatch: long sequences pay too much last-block-detection
-         # latency under the small-tile grid, so swap to the large-tile
-         # compiled variant.
-         if seq_len >= seq_len_threshold:
-             tile_n_active = tile_n_large
-             compiled_active = compiled_large
-         else:
-             tile_n_active = tile_n_small
-             compiled_active = compiled_small
-         num_full_tiles = seq_len // tile_n_active
-         tail_len = seq_len % tile_n_active
-         num_col_tiles = num_full_tiles + (1 if tail_len > 0 else 0)
-         total_blocks = bs * num_col_tiles
-
-         # Per-call write-only buffers: ``empty`` is enough (Liger / TE
-         # pattern). ``inv_total`` is populated by the bundled mask-sum
-         # kernel (compute_backward path) or by the main kernel's last-block
-         # trick (fwd-only path); the runner never reads it.
-         output_r = torch.empty(1, dtype=dtype, device=device)
-         inv_total_r = torch.empty(1, dtype=torch.float32, device=device)
-         if compute_backward:
-             dpolicy_r = torch.empty_like(policy_logprobs_r)
-         else:
-             dpolicy_r = torch.empty(bs, 1, dtype=dtype, device=device)
-
-         counter_r, mask_counter_r, valid_acc_r, policy_acc_r, kl_acc_r = _ensure_scratch(device)
-
-         compiled_active(
-             policy_logprobs_r,
-             old_policy_logprobs_r,
-             ref_logprobs_r,
-             advantages_r,
-             completions_mask_r,
-             dpolicy_r,
-             inv_total_r,
-             policy_acc_r,
-             kl_acc_r,
-             valid_acc_r,
-             counter_r,
-             mask_counter_r,
-             output_r,
-             eps_const,
-             eps_high_const,
-             beta_const,
-             total_blocks,
-             num_full_tiles,
-             tail_len,
-             num_col_tiles,
-         )
-         out_view = output_r.view(())
-         if compute_backward:
-             return out_view, dpolicy_r
-         return out_view
-
-     return _run
-
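The scratch-slab comment in the deleted runner makes two claims that are easy to check on the host: every slot starts on a 16-byte boundary, and an all-zero int32 word reads back as fp32 0.0. A small sketch using only the standard library follows. The slot names are copied from `_ensure_scratch`; the sketch assumes the base allocation is itself at least 16-byte aligned, which the diff does not state.

```python
import struct

SLOT_STRIDE = 4  # int32 elements between consecutive slots
INT32_BYTES = 4
SLOTS = ("counter", "mask_counter", "valid_acc", "policy_acc", "kl_acc")

# Byte offset of each slot from the start of the slab.
offsets = {name: i * SLOT_STRIDE * INT32_BYTES for i, name in enumerate(SLOTS)}

# Total int32 elements needed (the runner allocates 20).
slab_elems = len(SLOTS) * SLOT_STRIDE

# Reinterpret the bytes of int32 0 as a little-endian fp32 value.
zero_as_f32 = struct.unpack("<f", struct.pack("<i", 0))[0]
```

Three of every four int32 words are padding; the layout trades a few bytes for per-slot alignment.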
-
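The tier dispatch and grid arithmetic in `_run` can be reproduced without a GPU. In the sketch below, `tile_plan` is a hypothetical helper. The tile sizes and threshold used in the usage note are illustrative values only, since the diff does not show `TILE_N`, `TILE_N_LARGE`, or `TILE_N_LARGE_THRESHOLD`.

```python
def tile_plan(bs, seq_len, tile_n_small, tile_n_large, seq_len_threshold):
    """Mirror the runner's tier selection and launch-grid arithmetic."""
    tile_n = tile_n_large if seq_len >= seq_len_threshold else tile_n_small
    num_full_tiles, tail_len = divmod(seq_len, tile_n)
    num_col_tiles = num_full_tiles + (1 if tail_len > 0 else 0)
    return {
        "tile_n": tile_n,
        "num_full_tiles": num_full_tiles,
        "tail_len": tail_len,
        "num_col_tiles": num_col_tiles,
        "total_blocks": bs * num_col_tiles,
    }
```

With illustrative values `tile_n_small=1024`, `tile_n_large=4096`, and `seq_len_threshold=8192`, a batch of 4 sequences of length 2500 uses the small tile with two full tiles, a 452-element tail, and 12 blocks in total.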
- # ---------------------------------------------------------------------------
- # Fused forward + backward: direct (loss, grad) runner, no autograd
- # ---------------------------------------------------------------------------
-
-
- def create_compiled_bnpo_loss_with_backward(
-     policy_dtype: torch.dtype,
-     epsilon: float,
-     epsilon_high: float,
-     beta: float,
- ) -> Callable[..., tuple[torch.Tensor, torch.Tensor]]:
-     """Compile the fused fwd+bwd bnpo kernel and return a tuple-returning runner.
-
-     The returned callable runs one training step's worth of work: a single
-     ``@cute.jit`` dispatch produces both the scalar loss and the scaled
-     ``dL/d(policy_logprobs)`` tensor. It returns ``(loss, dpolicy)`` directly,
-     with no ``torch.autograd.Function`` wrapper and no extra
-     ``grad_output * dpolicy`` backward kernel. Callers that need autograd
-     integration (so ``loss.backward()`` works) wrap this themselves at the
-     public-API layer; callers that control gradient flow manually
-     (benchmarks, custom training loops) can use it as-is for zero overhead.
-
-     ``inv_total`` is computed entirely on-GPU by a bundled mask-sum kernel
-     that runs in series with the main kernel inside the same ``@cute.jit``
-     launch: no host sync, no extra ``torch.sum`` dispatch, CUDA-graph
-     compatible.
-     """
-     return _typing_cast(
-         "Callable[..., tuple[torch.Tensor, torch.Tensor]]",
-         create_compiled_bnpo_loss(
-             policy_dtype=policy_dtype,
-             epsilon=epsilon,
-             epsilon_high=epsilon_high,
-             beta=beta,
-             compute_backward=True,
-         ),
-     )
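The docstring above leaves autograd integration to the caller. One way such a wrapper could look is sketched below. This is an assumption about the public-API layer, not code from this repository: `standin_runner` is a hypothetical CPU stand-in that only mimics the `(loss, dpolicy)` return contract (unclipped surrogate, no KL term), and `FusedLossFn` is a hypothetical name.

```python
import torch


def standin_runner(policy_logprobs, old_logprobs, ref_logprobs, advantages, mask):
    """Hypothetical CPU stand-in for the compiled (loss, dpolicy) runner."""
    m = mask.to(policy_logprobs.dtype)
    inv_total = 1.0 / m.sum().clamp(min=1)
    per_token = -torch.exp(policy_logprobs - old_logprobs) * advantages[:, None]
    loss = (per_token * m).sum() * inv_total
    dpolicy = per_token * m * inv_total  # d(loss)/d(policy_logprobs)
    return loss, dpolicy


class FusedLossFn(torch.autograd.Function):
    """Expose a (loss, dpolicy) runner to autograd (sketch)."""

    @staticmethod
    def forward(ctx, policy_logprobs, old_logprobs, ref_logprobs, advantages, mask):
        loss, dpolicy = standin_runner(
            policy_logprobs, old_logprobs, ref_logprobs, advantages, mask
        )
        ctx.save_for_backward(dpolicy)
        return loss

    @staticmethod
    def backward(ctx, grad_output):
        (dpolicy,) = ctx.saved_tensors
        # Only policy_logprobs receives a gradient; the other inputs get None.
        return grad_output * dpolicy, None, None, None, None
```

The multiply in `backward` is the extra `grad_output * dpolicy` step that the direct runner avoids, so the wrapper is only worth using when `loss.backward()` has to work.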
build/torch-cuda/geometric_ai_kernels/__init__.py DELETED
@@ -1,26 +0,0 @@
- import ctypes
- import importlib.util
- import sys
- from pathlib import Path
- from types import ModuleType
-
-
- def _import_from_path(file_path: Path) -> ModuleType:
-     # We cannot use the module name as-is: once it is added to `sys.modules`,
-     # it would also be used for other imports. Instead, we build a unique
-     # module name from the hex-encoded hash of the path.
-     path_hash = "{:x}".format(ctypes.c_size_t(hash(file_path.absolute())).value)
-     module_name = path_hash
-     spec = importlib.util.spec_from_file_location(module_name, file_path)
-     if spec is None:
-         raise ImportError(f"Cannot load spec for {module_name} from {file_path}")
-     module = importlib.util.module_from_spec(spec)
-     if module is None:
-         raise ImportError(f"Cannot load module {module_name} from spec")
-     sys.modules[module_name] = module
-     spec.loader.exec_module(module)  # type: ignore
-     return module
-
-
- globals().update(vars(_import_from_path(Path(__file__).parent.parent / "__init__.py")))
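The deleted shim loads a module from an explicit file path under a name derived from that path. The standalone sketch below shows the same `importlib` pattern against a temporary file. `load_module_from_path` and `shim_target.py` are hypothetical names, and the sketch masks the hash to 64 bits instead of going through `ctypes.c_size_t`.

```python
import importlib.util
import sys
import tempfile
from pathlib import Path
from types import ModuleType


def load_module_from_path(file_path: Path) -> ModuleType:
    """Load file_path as a module registered under a path-derived name."""
    # Simplification: mask to 64 bits rather than using ctypes.c_size_t.
    module_name = f"{hash(file_path.absolute()) & 0xFFFFFFFFFFFFFFFF:x}"
    spec = importlib.util.spec_from_file_location(module_name, file_path)
    if spec is None or spec.loader is None:
        raise ImportError(f"Cannot load spec for {module_name} from {file_path}")
    module = importlib.util.module_from_spec(spec)
    sys.modules[module_name] = module
    spec.loader.exec_module(module)
    return module


with tempfile.TemporaryDirectory() as tmp:
    target = Path(tmp) / "shim_target.py"
    target.write_text("ANSWER = 42\n")
    loaded = load_module_from_path(target)
```

Registering the module under a path-derived name keeps it from shadowing an import of the package's real name elsewhere in the process.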