Pramodith commited on
Commit
5285eac
·
verified ·
1 Parent(s): 314505a

Build uploaded using `kernels`.

Browse files
build/torch-cuda/__init__.py ADDED
@@ -0,0 +1,69 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Geometric-AI CuteDSL kernels for RL / distillation training.
2
+
3
+ Public surface:
4
+ * ``bnpo_loss`` / ``bnpo_loss_autograd`` / ``bnpo_loss_fwd`` —
5
+ fused fwd+bwd BNPO loss with three entry points (direct
6
+ ``(loss, grad)``, autograd-wrapped, forward-only).
7
+ * ``grpo_loss`` / ``grpo_loss_autograd`` / ``grpo_loss_fwd`` —
8
+ fused fwd+bwd GRPO loss (TRL's per-response normalization
9
+ variant). Same three-entry-point shape as BNPO. Requires
10
+ ``completions_mask``.
11
+ * ``reverse_kl`` / ``reverse_kl_autograd`` /
12
+ ``reverse_kl_fwd`` — fused fwd+bwd reverse-KL
13
+ self-distillation loss with the same three-entry-point shape.
14
+
15
+ HF Kernels integration: :mod:`geometric_ai_kernels.layers` exposes
16
+ ``nn.Module`` adapters per kernel (``bnpoLoss`` / ``bnpoLossInference``,
17
+ ``grpoLoss`` / ``grpoLossInference``, ``ReverseKL`` /
18
+ ``ReverseKLInference``) for use with the ``kernels``
19
+ library's ``kernelize()`` flow.
20
+ """
21
+
22
+ from __future__ import annotations
23
+
24
+ import torch._dynamo
25
+
26
+ from .bnpo_loss import bnpo_loss, bnpo_loss_autograd, bnpo_loss_fwd
27
+ from .grpo_loss import grpo_loss, grpo_loss_autograd, grpo_loss_fwd
28
+ from .layers import (
29
+ ReverseKL,
30
+ ReverseKLInference,
31
+ bnpoLoss,
32
+ bnpoLossInference,
33
+ grpoLoss,
34
+ grpoLossInference,
35
+ )
36
+ from .reverse_kl import (
37
+ reverse_kl,
38
+ reverse_kl_autograd,
39
+ reverse_kl_fwd,
40
+ )
41
+
42
+ # Required so ``torch.compile(fullgraph=True)`` can trace through
43
+ # ``torch.autograd.grad`` calls — without it Dynamo graph-breaks at the
44
+ # autograd.grad call site even when AOTAutograd has already derived the
45
+ # joint fwd+bwd graph. Set at package import so any consumer (benches,
46
+ # user training loops, ``kernelize`` flows) gets it for free. Guarded
47
+ # because ``trace_autograd_ops`` was added in torch 2.10 and the
48
+ # Nix-pinned build environment may be on an older torch (2.9 today);
49
+ # the underlying ``Config.__setattr__`` raises on unknown keys.
50
+ if hasattr(torch._dynamo.config, "trace_autograd_ops"):
51
+ torch._dynamo.config.trace_autograd_ops = True # ty: ignore[invalid-assignment]
52
+
53
+ __all__ = [
54
+ "ReverseKL",
55
+ "ReverseKLInference",
56
+ "bnpoLoss",
57
+ "bnpoLossInference",
58
+ "bnpo_loss",
59
+ "bnpo_loss_autograd",
60
+ "bnpo_loss_fwd",
61
+ "grpoLoss",
62
+ "grpoLossInference",
63
+ "grpo_loss",
64
+ "grpo_loss_autograd",
65
+ "grpo_loss_fwd",
66
+ "reverse_kl",
67
+ "reverse_kl_autograd",
68
+ "reverse_kl_fwd",
69
+ ]
build/torch-cuda/_ops.py ADDED
@@ -0,0 +1,38 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+
3
+ def get_backend() -> str:
4
+ """Detect the backend by inspecting torch."""
5
+ import torch
6
+
7
+ if hasattr(torch, "neuron"):
8
+ # Needs to be sorted before specific Torch builds, since Neuron
9
+ # extension can be loaded into e.g. CUDA Torch builds.
10
+ return "neuron"
11
+ elif torch.version.cuda is not None:
12
+ return "cuda"
13
+ elif torch.version.hip is not None:
14
+ return "rocm"
15
+ elif torch.backends.mps.is_available():
16
+ return "metal"
17
+ elif hasattr(torch.version, "xpu") and torch.version.xpu is not None:
18
+ return "xpu"
19
+ else:
20
+ return "cpu"
21
+
22
+
23
+ def _find_ops_name() -> str:
24
+ kernel_name = "geometric_ai_kernels"
25
+ unique_id = "a766fbd_dirty"
26
+ backend = get_backend()
27
+ return f"_{kernel_name}_{backend}_{unique_id}"
28
+
29
+
30
+ _OPS_NAME = _find_ops_name()
31
+
32
+ ops = getattr(torch.ops, _OPS_NAME)
33
+
34
+ def add_op_namespace_prefix(op_name: str) -> str:
35
+ """
36
+ Prefix op by namespace.
37
+ """
38
+ return f"{_OPS_NAME}::{op_name}"
build/torch-cuda/bnpo_loss/__init__.py ADDED
@@ -0,0 +1,196 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """bnpo loss with CuteDSL fused fwd+bwd.
2
+
3
+ Two public APIs route to two compiled kernels:
4
+
5
+ * :func:`bnpo_loss` — primary training entry point. Returns
6
+ ``(loss, grad_policy_logprobs)`` from a single fused fwd+bwd kernel
7
+ launch. Inputs do **not** need ``requires_grad=True`` and there is no
8
+ ``torch.autograd.Function`` wrapper — chain the gradient into the
9
+ upstream model with ``policy_logprobs.backward(grad)`` (or, more
10
+ commonly, by passing ``grad`` to whatever step does the next leg of
11
+ backprop).
12
+ * :func:`bnpo_loss_fwd` — inference / validation path. Returns the
13
+ scalar ``loss`` from a forward-only kernel that computes the masked
14
+ mean denominator on-GPU via a last-block trick (no host
15
+ ``completions_mask.sum()``).
16
+
17
+ The two share the same compiled-kernel cache; per-call output and
18
+ gradient buffers are allocated inside the runner, and cross-CTA scratch
19
+ (atomic accumulators + counters) is owned by the compiled-kernel
20
+ closure and self-resets each launch — callers don't manage scratch.
21
+
22
+ Why no autograd wrapper here? bnpo's gradient is closed-form — the
23
+ kernel already writes ``dL/d(policy_logprobs)`` in the same launch as
24
+ the loss. Wrapping in ``torch.autograd.Function`` would cost an extra
25
+ ``grad_output * dpolicy`` kernel launch on backward (typically a
26
+ no-op multiply by ``1.0``), plus per-call autograd graph bookkeeping.
27
+ The autograd-aware sibling :func:`bnpo_loss_autograd` uses
28
+ ``torch.library.custom_op`` instead, which composes with
29
+ ``torch.compile``.
30
+ """
31
+
32
+ from __future__ import annotations
33
+
34
+ from functools import lru_cache
35
+ from typing import TYPE_CHECKING, cast
36
+
37
+ import torch
38
+
39
+ from .cute_bnpo_loss import (
40
+ create_compiled_bnpo_loss,
41
+ create_compiled_bnpo_loss_with_backward,
42
+ )
43
+
44
+ if TYPE_CHECKING:
45
+ from collections.abc import Callable
46
+
47
+
48
+ __all__ = ["bnpo_loss", "bnpo_loss_autograd", "bnpo_loss_fwd"]
49
+
50
+
51
+ @lru_cache(maxsize=32)
52
+ def _get_compiled_fwd(
53
+ dtype: torch.dtype,
54
+ epsilon: float,
55
+ epsilon_high: float,
56
+ beta: float,
57
+ ) -> Callable[..., torch.Tensor]:
58
+ return cast(
59
+ "Callable[..., torch.Tensor]",
60
+ create_compiled_bnpo_loss(
61
+ policy_dtype=dtype,
62
+ epsilon=epsilon,
63
+ epsilon_high=epsilon_high,
64
+ beta=beta,
65
+ compute_backward=False,
66
+ ),
67
+ )
68
+
69
+
70
+ @lru_cache(maxsize=32)
71
+ def _get_compiled_fwd_bwd(
72
+ dtype: torch.dtype,
73
+ epsilon: float,
74
+ epsilon_high: float,
75
+ beta: float,
76
+ ) -> Callable[..., tuple[torch.Tensor, torch.Tensor]]:
77
+ return create_compiled_bnpo_loss_with_backward(
78
+ policy_dtype=dtype,
79
+ epsilon=epsilon,
80
+ epsilon_high=epsilon_high,
81
+ beta=beta,
82
+ )
83
+
84
+
85
+ def bnpo_loss_fwd(
86
+ policy_logprobs: torch.Tensor,
87
+ old_policy_logprobs: torch.Tensor,
88
+ ref_logprobs: torch.Tensor,
89
+ advantages: torch.Tensor,
90
+ completions_mask: torch.Tensor,
91
+ epsilon: float = 0.2,
92
+ epsilon_high: float = 0.2,
93
+ beta: float = 0.1,
94
+ ) -> torch.Tensor:
95
+ """Forward-only bnpo loss. Returns the scalar ``loss``.
96
+
97
+ Use for inference / validation. The masked mean denominator is
98
+ computed on-GPU by an atomic accumulator + last-block trick — no
99
+ host ``completions_mask.sum()`` syncs.
100
+
101
+ Args:
102
+ policy_logprobs, old_policy_logprobs, ref_logprobs: ``(bs, seq_len)``.
103
+ advantages: ``(bs,)``.
104
+ completions_mask: bool/int8 mask ``(bs, seq_len)``; truthy = valid token.
105
+ epsilon, epsilon_high: PPO-style clipping bounds.
106
+ beta: KL-penalty coefficient. ``0.0`` compiles away the KL branch.
107
+
108
+ Returns:
109
+ Scalar tensor (0-dim) with the same dtype as ``policy_logprobs``.
110
+ """
111
+ run = _get_compiled_fwd(
112
+ policy_logprobs.dtype,
113
+ float(epsilon),
114
+ float(epsilon_high),
115
+ float(beta),
116
+ )
117
+ mask_arg = (
118
+ completions_mask
119
+ if completions_mask.dtype == torch.int8
120
+ else completions_mask.to(torch.int8)
121
+ )
122
+ return run(
123
+ policy_logprobs,
124
+ old_policy_logprobs,
125
+ ref_logprobs,
126
+ advantages,
127
+ mask_arg,
128
+ )
129
+
130
+
131
+ def bnpo_loss(
132
+ policy_logprobs: torch.Tensor,
133
+ old_policy_logprobs: torch.Tensor,
134
+ ref_logprobs: torch.Tensor,
135
+ advantages: torch.Tensor,
136
+ completions_mask: torch.Tensor,
137
+ epsilon: float = 0.2,
138
+ epsilon_high: float = 0.2,
139
+ beta: float = 0.1,
140
+ ) -> tuple[torch.Tensor, torch.Tensor]:
141
+ """Fused fwd+bwd bnpo loss. Returns ``(loss, grad_policy_logprobs)``.
142
+
143
+ Single-launch training entry point. The kernel writes both the
144
+ scalar loss and the scaled ``dL/d(policy_logprobs)`` tensor in one
145
+ ``@cute.jit`` dispatch — a bundled mask-sum kernel runs inside the
146
+ same launch so ``inv_total`` is populated on-GPU without a host-side
147
+ ``torch.sum`` round trip.
148
+
149
+ Inputs do **not** need ``requires_grad=True``. To chain ``grad``
150
+ into the upstream model that produced ``policy_logprobs``::
151
+
152
+ loss, grad = bnpo_loss(policy_logprobs, ..., completions_mask=mask)
153
+ policy_logprobs.backward(grad)
154
+ optimizer.step()
155
+
156
+ Args:
157
+ policy_logprobs, old_policy_logprobs, ref_logprobs: ``(bs, seq_len)``.
158
+ advantages: ``(bs,)``.
159
+ completions_mask: bool/int8 mask ``(bs, seq_len)``.
160
+ epsilon, epsilon_high: PPO-style clipping bounds.
161
+ beta: KL-penalty coefficient. ``0.0`` compiles away the KL branch.
162
+
163
+ Returns:
164
+ ``(loss, grad_policy_logprobs)`` — ``loss`` is a 0-dim tensor in
165
+ ``policy_logprobs.dtype``; ``grad_policy_logprobs`` has shape
166
+ ``(bs, seq_len)`` and is already scaled by ``1 / n_valid``. The
167
+ gradient tensor is freshly allocated per call (no shared cache),
168
+ so callers may keep it around freely.
169
+
170
+ For inference / validation where you only need the loss, use
171
+ :func:`bnpo_loss_fwd` — it skips the dpolicy write entirely and
172
+ computes the mean denominator with the on-GPU last-block trick.
173
+ """
174
+ run = _get_compiled_fwd_bwd(
175
+ policy_logprobs.dtype,
176
+ float(epsilon),
177
+ float(epsilon_high),
178
+ float(beta),
179
+ )
180
+ mask_arg = (
181
+ completions_mask
182
+ if completions_mask.dtype == torch.int8
183
+ else completions_mask.to(torch.int8)
184
+ )
185
+ return run(
186
+ policy_logprobs,
187
+ old_policy_logprobs,
188
+ ref_logprobs,
189
+ advantages,
190
+ mask_arg,
191
+ )
192
+
193
+
194
+ # Imported at the bottom: ``autograd.py`` imports ``bnpo_loss`` from this
195
+ # module, so the function must be fully defined before its import runs.
196
+ from .autograd import bnpo_loss_autograd # noqa: E402
build/torch-cuda/bnpo_loss/_torch_ref.py ADDED
@@ -0,0 +1,56 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Plain-PyTorch bnpo reference shared between the bench and the tests.
2
+
3
+ This module is intentionally minimal — every op is a vanilla torch op so
4
+ ``AOTAutograd`` can derive the joint fwd+bwd graph and Inductor can fuse
5
+ both passes (used by ``benchmarks/benchmark_bnpo_loss.py``'s compiled
6
+ baseline). The same function is imported by ``tests/test_bnpo_loss.py``
7
+ as the correctness reference, so both paths agree on what "the eager
8
+ torch implementation of bnpo loss" means.
9
+
10
+ Underscore-prefixed module name signals "shared internal", not a public
11
+ API surface — there's no re-export from the package's top-level
12
+ ``__init__.py``.
13
+ """
14
+
15
+ from __future__ import annotations
16
+
17
+ import torch
18
+
19
+
20
+ def torch_bnpo_loss(
21
+ policy_logprobs: torch.Tensor,
22
+ old_policy_logprobs: torch.Tensor,
23
+ ref_logprobs: torch.Tensor,
24
+ advantages: torch.Tensor,
25
+ completions_mask: torch.Tensor,
26
+ epsilon: float = 0.2,
27
+ epsilon_high: float = 0.2,
28
+ beta: float = 0.1,
29
+ ) -> torch.Tensor:
30
+ """Plain-Python bnpo reference traceable by AOTAutograd / Inductor.
31
+
32
+ Operates in the input dtype throughout (no internal fp32 cast),
33
+ which is what real torch users would write — and what
34
+ ``torch.compile`` competes against in the bench.
35
+ """
36
+ ratio = torch.exp(policy_logprobs - old_policy_logprobs)
37
+ adv = advantages.unsqueeze(1)
38
+
39
+ surrogate = ratio * adv
40
+ surrogate_clipped = torch.clamp(ratio, 1.0 - epsilon, 1.0 + epsilon_high) * adv
41
+ policy_loss = -torch.min(surrogate, surrogate_clipped)
42
+
43
+ log_ratio_ref = ref_logprobs - policy_logprobs
44
+ kl = torch.exp(log_ratio_ref) - log_ratio_ref - 1.0
45
+
46
+ # Cast n_valid to fp32: int64 → fp16 overflows when n_valid > 65504.
47
+ # ``clamp(min=1.0)`` matches TRL's ``mask.sum().clamp(min=1)``: a
48
+ # fully-masked batch produces ``loss=0`` instead of inf/NaN. Mirrors
49
+ # the cute kernel's ``cute.arch.fmax(..., 1.0)`` before ``rcp_approx``
50
+ # in ``cute_bnpo_loss.py``.
51
+ n_valid = completions_mask.sum().to(torch.float32).clamp(min=1.0)
52
+ policy_loss = (policy_loss * completions_mask).sum() / n_valid
53
+ kl = (kl * completions_mask).sum() / n_valid
54
+
55
+ loss = policy_loss + beta * kl
56
+ return loss.to(policy_logprobs.dtype)
build/torch-cuda/bnpo_loss/autograd.py ADDED
@@ -0,0 +1,149 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Autograd-aware wrapper for bnpo loss via ``torch.library.custom_op``.
2
+
3
+ The fused cute kernel writes both the scalar loss and the closed-form
4
+ ``dL/d(policy_logprobs)`` in one launch. This module wraps that into an
5
+ autograd-compatible op so callers can write::
6
+
7
+ loss = bnpo_loss_autograd(policy, old, ref, adv, completions_mask)
8
+ loss.backward() # propagates through to the upstream model
9
+
10
+ instead of the manual ``policy.backward(grad)`` chain. The cost is
11
+ ~12µs of autograd dispatcher overhead per call (vs the direct
12
+ ``bnpo_loss`` ``(loss, grad)`` tuple); for ergonomic / kernelize() flows
13
+ that's cheap, but for tight microbenches use the direct path.
14
+
15
+ Implementation notes:
16
+
17
+ - The registered op returns ``(loss, dpolicy)`` so ``setup_context`` can
18
+ ``save_for_backward(dpolicy)``. The public ``bnpo_loss_autograd``
19
+ wrapper hides the second output.
20
+ - ``dpolicy`` is allocated fresh by the runner on every call (no shared
21
+ cache), so ``ctx.save_for_backward(dpolicy)`` keeps a stable reference
22
+ across subsequent calls without any extra copy.
23
+ - Backward returns ``grad_loss * dpolicy``. Under ``torch.compile``,
24
+ when ``loss`` is consumed by ``.backward()`` directly, ``grad_loss``
25
+ is the constant 1.0 and Inductor can fold the multiply away — that's
26
+ the main reason this path uses ``custom_op`` instead of a plain
27
+ ``autograd.Function``.
28
+ - ``register_fake`` provides the meta kernel for ``torch.compile`` shape
29
+ propagation; the real cute kernel never runs under ``FakeTensorMode``.
30
+ """
31
+
32
+ from __future__ import annotations
33
+
34
+ import torch
35
+
36
+ from . import bnpo_loss as _bnpo_loss_fwd_bwd
37
+
38
+ __all__ = ["bnpo_loss_autograd"]
39
+
40
+
41
+ @torch.library.custom_op(
42
+ "geometric_ai_kernels::_bnpo_loss_with_grad",
43
+ mutates_args=(),
44
+ )
45
+ def _bnpo_loss_with_grad(
46
+ policy_logprobs: torch.Tensor,
47
+ old_policy_logprobs: torch.Tensor,
48
+ ref_logprobs: torch.Tensor,
49
+ advantages: torch.Tensor,
50
+ completions_mask: torch.Tensor,
51
+ epsilon: float,
52
+ epsilon_high: float,
53
+ beta: float,
54
+ ) -> tuple[torch.Tensor, torch.Tensor]:
55
+ loss, dpolicy = _bnpo_loss_fwd_bwd(
56
+ policy_logprobs,
57
+ old_policy_logprobs,
58
+ ref_logprobs,
59
+ advantages,
60
+ completions_mask,
61
+ epsilon=epsilon,
62
+ epsilon_high=epsilon_high,
63
+ beta=beta,
64
+ )
65
+ return loss, dpolicy
66
+
67
+
68
+ @_bnpo_loss_with_grad.register_fake
69
+ def _(
70
+ policy_logprobs: torch.Tensor,
71
+ old_policy_logprobs: torch.Tensor,
72
+ ref_logprobs: torch.Tensor,
73
+ advantages: torch.Tensor,
74
+ completions_mask: torch.Tensor,
75
+ epsilon: float,
76
+ epsilon_high: float,
77
+ beta: float,
78
+ ) -> tuple[torch.Tensor, torch.Tensor]:
79
+ # Signature must mirror the op; only ``policy_logprobs`` shapes the outputs.
80
+ del old_policy_logprobs, ref_logprobs, advantages, completions_mask
81
+ del epsilon, epsilon_high, beta
82
+ loss = policy_logprobs.new_empty(())
83
+ dpolicy = torch.empty_like(policy_logprobs)
84
+ return loss, dpolicy
85
+
86
+
87
+ def _setup_context(ctx, inputs, output) -> None: # type: ignore[no-untyped-def]
88
+ del inputs # only ``output`` carries what we need to save.
89
+ _, dpolicy = output
90
+ ctx.save_for_backward(dpolicy)
91
+
92
+
93
+ def _backward(ctx, grad_loss, grad_dpolicy): # type: ignore[no-untyped-def]
94
+ # ``grad_dpolicy`` is unused — ``dpolicy`` is an internal intermediate
95
+ # exposed only so ``setup_context`` can save it. Under typical usage
96
+ # (``loss.backward()``) it arrives as ``None`` or a zero tensor.
97
+ del grad_dpolicy
98
+ (dpolicy,) = ctx.saved_tensors
99
+ grad_policy = grad_loss * dpolicy
100
+ # One return per input to the op (8): policy_logprobs gets the grad,
101
+ # everything else gets None (no autograd flow).
102
+ return grad_policy, None, None, None, None, None, None, None
103
+
104
+
105
+ torch.library.register_autograd(
106
+ "geometric_ai_kernels::_bnpo_loss_with_grad",
107
+ _backward,
108
+ setup_context=_setup_context,
109
+ )
110
+
111
+
112
+ def bnpo_loss_autograd(
113
+ policy_logprobs: torch.Tensor,
114
+ old_policy_logprobs: torch.Tensor,
115
+ ref_logprobs: torch.Tensor,
116
+ advantages: torch.Tensor,
117
+ completions_mask: torch.Tensor,
118
+ epsilon: float = 0.2,
119
+ epsilon_high: float = 0.2,
120
+ beta: float = 0.1,
121
+ ) -> torch.Tensor:
122
+ """Autograd-aware bnpo loss. Returns scalar ``loss``.
123
+
124
+ Same numerics as :func:`bnpo_loss` but registered as a
125
+ ``torch.library`` custom op with autograd, so::
126
+
127
+ loss = bnpo_loss_autograd(policy, ..., completions_mask)
128
+ loss.backward()
129
+
130
+ propagates through to whatever produced ``policy_logprobs``. For
131
+ direct ``(loss, grad)`` access without the autograd dispatcher
132
+ overhead, use :func:`bnpo_loss` and chain the gradient manually
133
+ via ``policy_logprobs.backward(grad)``.
134
+
135
+ Composes with ``torch.compile``: the op is opaque to Inductor but
136
+ has a fake/meta kernel registered, so models containing this layer
137
+ can be compiled end-to-end without graph breaks.
138
+ """
139
+ loss, _ = _bnpo_loss_with_grad(
140
+ policy_logprobs,
141
+ old_policy_logprobs,
142
+ ref_logprobs,
143
+ advantages,
144
+ completions_mask,
145
+ float(epsilon),
146
+ float(epsilon_high),
147
+ float(beta),
148
+ )
149
+ return loss
build/torch-cuda/bnpo_loss/cute_bnpo_loss.py ADDED
@@ -0,0 +1,1081 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """CuteDSL kernel for bnpo loss.
2
+
3
+ Computes (element-wise over ``(bs, seq_len)`` logprob tensors, reduced to a
4
+ scalar):
5
+
6
+ ratio = exp(policy - old_policy)
7
+ surrogate = ratio * adv
8
+ clipped = clip(ratio, 1 - eps, 1 + eps_high) * adv
9
+ policy_loss = -min(surrogate, clipped)
10
+ log_ratio_ref = ref - policy
11
+ kl = exp(log_ratio_ref) - log_ratio_ref - 1
12
+ L_bnpo = (policy_loss * mask).sum() / n_valid
13
+ + beta * (kl * mask).sum() / n_valid
14
+
15
+ where ``n_valid = max(completions_mask.sum(), 1)``. The mean denominator is
16
+ computed entirely on-GPU — the forward-only path uses an atomic accumulator
17
+ + last-block trick on ``valid_acc``; the fused fwd+bwd path bundles a small
18
+ companion mask-sum kernel into the same ``@cute.jit`` launch that writes
19
+ ``1 / completions_mask.sum()`` into the ``inv_total`` GMEM scalar before the
20
+ main kernel reads it. Every block needs ``inv_total`` mid-loop to scale its
21
+ ``dpolicy`` slab, so the fwd-only last-block trick doesn't compose with
22
+ backward; bundling the mask-sum keeps both paths host-sync-free and CUDA-graph
23
+ compatible.
24
+
25
+ When ``beta=0`` the KL term is skipped at compile time (no ``ref`` tensor
26
+ access, no ``kl_acc`` atomic add).
27
+
28
+ Sequence lengths that are **not** a multiple of ``TILE_N`` are handled
29
+ natively: the grid launches ``ceil(seq_len / TILE_N)`` column tiles; full tiles
30
+ use the vectorized ``LDG.128`` path and the tail tile uses predicated vector
31
+ loads with neutral prefill.
32
+
33
+ Two compiled-kernel flavors are exposed:
34
+
35
+ * :func:`create_compiled_bnpo_loss` — forward-only.
36
+ * :func:`create_compiled_bnpo_loss_with_backward` — fused fwd+bwd. Returns
37
+ ``(loss, dpolicy)`` directly — no ``torch.autograd.Function`` wrapper. The
38
+ autograd-aware sibling lives in ``autograd.py`` and uses
39
+ ``torch.library.custom_op`` instead.
40
+
41
+ Per-call output (``loss``, ``dpolicy``, ``inv_total``) is allocated inside the
42
+ runner. Cross-CTA scratch (atomic accumulators + counters) is allocated lazily
43
+ on first call inside the compiled-kernel closure and self-resets each launch
44
+ via the kernel's last-block epilogue + ``atom.inc.u32`` wrap-around — callers
45
+ don't manage scratch state.
46
+ """
47
+
48
+ from __future__ import annotations
49
+
50
+ import math
51
+ import operator
52
+ from typing import TYPE_CHECKING, Any
53
+ from typing import cast as _typing_cast
54
+
55
+ import cutlass
56
+ import cutlass.utils
57
+ import torch
58
+ from cutlass import cute
59
+ from cutlass._mlir.dialects import llvm
60
+ from cutlass.base_dsl.typing import cast
61
+ from cutlass.cutlass_dsl import T, dsl_user_op
62
+
63
+ if TYPE_CHECKING:
64
+ from collections.abc import Callable
65
+
66
+
67
+ TILE_N: int = 512
68
+ NUM_WARPS: int = 4
69
+ # ``VEC=4`` (fp32) emits 128-bit ``LDG.128``. Pairs with ``NUM_WARPS=4`` so
70
+ # each block processes ``block_size * VEC = 512 = TILE_N`` elements per iter.
71
+ VEC: int = 4
72
+ # Large-tile variant: at very long ``seq_len`` the small-TILE_N grid
73
+ # explodes (e.g. 8192/512 = 16 col-tiles per row → thousands of CTAs),
74
+ # inflating last-block-detection latency and atomic contention. A second
75
+ # compiled variant with this larger tile is dispatched when
76
+ # ``seq_len >= TILE_N_LARGE_THRESHOLD``.
77
+ TILE_N_LARGE: int = 4096
78
+ TILE_N_LARGE_THRESHOLD: int = 2048
79
+
80
+ _LOG2_E: float = math.log2(math.e)
81
+
82
+ _TORCH_TO_CUTLASS_DTYPE: dict[torch.dtype, Any] = {
83
+ torch.float32: cutlass.Float32,
84
+ torch.float16: cutlass.Float16,
85
+ torch.bfloat16: cutlass.BFloat16,
86
+ }
87
+
88
+
89
+ @dsl_user_op
90
+ def _atomic_add_f32_gmem(
91
+ ptr_i64: Any,
92
+ val: cutlass.Float32,
93
+ *,
94
+ loc: Any = None,
95
+ ip: Any = None,
96
+ ) -> None:
97
+ llvm.inline_asm(
98
+ T.f32(),
99
+ [ptr_i64, cutlass.Float32(val).ir_value(loc=loc, ip=ip)],
100
+ "atom.global.add.f32 $0, [$1], $2;",
101
+ "=f,l,f",
102
+ has_side_effects=True,
103
+ is_align_stack=False,
104
+ asm_dialect=llvm.AsmDialect.AD_ATT,
105
+ )
106
+
107
+
108
+ @dsl_user_op
109
+ def _atomic_add_s32_gmem(
110
+ ptr_i64: Any,
111
+ val: cutlass.Int32,
112
+ *,
113
+ loc: Any = None,
114
+ ip: Any = None,
115
+ ) -> None:
116
+ """Emit ``atom.global.add.s32`` to a 64-bit GMEM address."""
117
+ llvm.inline_asm(
118
+ T.i32(),
119
+ [ptr_i64, cutlass.Int32(val).ir_value(loc=loc, ip=ip)],
120
+ "atom.global.add.s32 $0, [$1], $2;",
121
+ "=r,l,r",
122
+ has_side_effects=True,
123
+ is_align_stack=False,
124
+ asm_dialect=llvm.AsmDialect.AD_ATT,
125
+ )
126
+
127
+
128
+ @dsl_user_op
129
+ def _dp4a_u32_acc_s32(
130
+ packed_a: cutlass.Uint32,
131
+ packed_b: cutlass.Uint32,
132
+ acc: cutlass.Int32,
133
+ *,
134
+ loc: Any = None,
135
+ ip: Any = None,
136
+ ) -> cutlass.Int32:
137
+ """``dp4a.u32.u32`` — sum 4 packed u8 products into an s32 acc.
138
+
139
+ Computes ``a[0]*b[0] + a[1]*b[1] + a[2]*b[2] + a[3]*b[3] + acc`` in
140
+ one ``IDP4A.U8.S32`` instruction (full-rate on Hopper/Blackwell).
141
+ For mask summation, pass ``packed_b = 0x01010101`` so the products
142
+ reduce to ``sum(a_bytes) + acc`` — 4× fewer ALU ops than 4 separate
143
+ int8→int32 widens + adds.
144
+ """
145
+ return cutlass.Int32(
146
+ llvm.inline_asm(
147
+ T.i32(),
148
+ [
149
+ cutlass.Uint32(packed_a).ir_value(loc=loc, ip=ip),
150
+ cutlass.Uint32(packed_b).ir_value(loc=loc, ip=ip),
151
+ cutlass.Int32(acc).ir_value(loc=loc, ip=ip),
152
+ ],
153
+ "dp4a.u32.u32 $0, $1, $2, $3;",
154
+ "=r,r,r,r",
155
+ has_side_effects=False,
156
+ is_align_stack=False,
157
+ asm_dialect=llvm.AsmDialect.AD_ATT,
158
+ )
159
+ )
160
+
161
+
162
+ @dsl_user_op
163
+ def _atomic_inc_u32_gmem(
164
+ ptr_i64: Any,
165
+ threshold: cutlass.Int32,
166
+ *,
167
+ loc: Any = None,
168
+ ip: Any = None,
169
+ ) -> cutlass.Int32:
170
+ """``atom.global.inc.u32`` — returns old value; wraps to 0 at threshold."""
171
+ return cutlass.Int32(
172
+ llvm.inline_asm(
173
+ T.i32(),
174
+ [ptr_i64, cutlass.Int32(threshold).ir_value(loc=loc, ip=ip)],
175
+ "atom.global.inc.u32 $0, [$1], $2;",
176
+ "=r,l,r",
177
+ has_side_effects=True,
178
+ is_align_stack=False,
179
+ asm_dialect=llvm.AsmDialect.AD_ATT,
180
+ )
181
+ )
182
+
183
+
184
+ # ---------------------------------------------------------------------------
185
+ # Mask-sum kernel — replaces ``torch.sum(completions_mask)`` on the fwd+bwd
186
+ # path. Bundled into the same ``@cute.jit`` launch as the main kernel so the
187
+ # whole step is one tvm-ffi dispatch (no extra Python/torch dispatcher round
188
+ # trip). The kernel writes ``1 / completions_mask.sum()`` directly into
189
+ # ``inv_total_tensor`` so the main kernel reads it as a pre-inverted scalar.
190
+ # ---------------------------------------------------------------------------
191
+
192
+
193
+ def _make_mask_sum_kernel(tile_n: int) -> Callable[..., None]:
194
+ """Return a ``@cute.kernel`` that reduces ``completions_mask`` and writes 1/sum.
195
+
196
+ Grid mirrors the main kernel — ``(bs, num_col_tiles)`` — so the mask is
197
+ read once with the same vectorised LDG pattern as the main compute.
198
+ Each block:
199
+
200
+ 1. Loads its ``tile_n`` int8 slab of ``completions_mask`` (predicated tail).
201
+ 2. Reduces to a per-block ``int32`` scalar (bit-exact, no per-element
202
+ i8→f32 cast — IADD throughput equals FADD on Hopper/Blackwell).
203
+ 3. Atomically adds it to ``valid_acc`` (global int32 accumulator).
204
+ 4. Increments ``mask_counter``; the last block reads ``valid_acc``,
205
+ casts to fp32, computes ``rcp_approx`` and writes
206
+ ``inv_total_tensor[0]``, then resets ``valid_acc`` to ``0`` so
207
+ the next call starts fresh. The counter self-resets via
208
+ ``atom.inc.u32`` wrap-around.
209
+
210
+ A separate ``mask_counter`` tensor (not the main kernel's ``counter``)
211
+ is required because the two kernels run in series within the same
212
+ ``@cute.jit`` and both rely on a wrap-around for self-reset; sharing
213
+ one counter would race.
214
+ """
215
+
216
+ @cute.kernel
217
+ def _mask_sum_kernel(
218
+ completions_mask: cute.Tensor, # (bs, seq_len) int8
219
+ inv_total_tensor: cute.Tensor, # (1,) fp32 — output
220
+ valid_acc: cute.Tensor, # (1,) int32 — accumulator
221
+ mask_counter: cute.Tensor, # (1,) i32 — last-block detection
222
+ total_blocks: cutlass.Int32,
223
+ num_full_tiles: cutlass.Int32,
224
+ tail_len: cutlass.Int32,
225
+ ) -> None:
226
+ block_size = NUM_WARPS * 32
227
+ iters = tile_n // (block_size * VEC)
228
+
229
+ _no_alloc = cute.nvgpu.CacheEvictionPriority.NO_ALLOCATE
230
+ g2r_op = cute.nvgpu.CopyUniversalOp()
231
+ g2r_mask_atom = cute.make_copy_atom(
232
+ g2r_op,
233
+ completions_mask.element_type,
234
+ num_bits_per_copy=0,
235
+ l1c_evict_priority=_no_alloc,
236
+ )
237
+
238
+ row = cute.arch.block_idx()[0]
239
+ col_block = cute.arch.block_idx()[1]
240
+ tid = cute.arch.thread_idx()[0]
241
+
242
+ local_valid_sum = cutlass.Int32(0)
243
+ mask_row = cute.slice_(completions_mask, (row, None))
244
+
245
+ # ``dp4a.u32.u32`` consumes a packed-u8x4 register. With VEC=4 each
246
+ # thread loads 4 contiguous int8 bytes per iteration, so we recast
247
+ # the fragment as a single ``Uint32`` view and feed it directly
248
+ # into dp4a — one instruction sums all 4 bytes, vs the previous
249
+ # cast+reduce which emitted 4 widens + 3 adds per iteration.
250
+ ones_packed = cutlass.Uint32(0x01010101)
251
+
252
+ if col_block < num_full_tiles:
253
+ mask_slab = cute.local_tile(mask_row, (tile_n,), (col_block,))
254
+ for k in cutlass.range(iters, unroll_full=True):
255
+ sub_idx = tid + k * block_size
256
+ mask_src = cute.local_tile(mask_slab, (VEC,), (sub_idx,))
257
+ mask_frag = cute.make_fragment_like(mask_src)
258
+ cute.copy(g2r_mask_atom, mask_src, mask_frag)
259
+ packed = cute.recast_tensor(mask_frag, cutlass.Uint32)[0]
260
+ local_valid_sum = _dp4a_u32_acc_s32(packed, ones_packed, local_valid_sum)
261
+ else:
262
+ mask_slab = cute.local_tile(mask_row, (tile_n,), (col_block,))
263
+ for k in cutlass.range(iters, unroll_full=True):
264
+ sub_idx = tid + k * block_size
265
+ chunk_base = sub_idx * VEC
266
+ if chunk_base < tail_len:
267
+ mask_src = cute.local_tile(mask_slab, (VEC,), (sub_idx,))
268
+ pred = cute.make_rmem_tensor(mask_src.shape, cutlass.Boolean)
269
+ for v in cutlass.range(VEC, unroll_full=True):
270
+ pred[v] = cute.elem_less(chunk_base + v, tail_len)
271
+ mask_frag = cute.make_fragment_like(mask_src)
272
+ mask_frag.fill(0)
273
+ cute.copy(g2r_mask_atom, mask_src, mask_frag, pred=pred)
274
+ packed = cute.recast_tensor(mask_frag, cutlass.Uint32)[0]
275
+ local_valid_sum = _dp4a_u32_acc_s32(packed, ones_packed, local_valid_sum)
276
+
277
+ # Warp + cross-warp reduction (same pattern as main kernel).
278
+ warp_valid = cute.arch.warp_reduction(local_valid_sum, operator.add)
279
+ smem = cutlass.utils.SmemAllocator()
280
+ buf_valid = smem.allocate_tensor(cutlass.Int32, cute.make_layout(NUM_WARPS))
281
+
282
+ lane_idx = cute.arch.lane_idx()
283
+ warp_idx = cute.arch.warp_idx()
284
+
285
+ if lane_idx == 0:
286
+ buf_valid[warp_idx] = warp_valid
287
+ cute.arch.barrier()
288
+
289
+ if warp_idx == 0:
290
+ val_v = cutlass.Int32(0)
291
+ if lane_idx < NUM_WARPS:
292
+ val_v = buf_valid[lane_idx]
293
+ block_valid = cute.arch.warp_reduction(val_v, operator.add, threads_in_group=NUM_WARPS)
294
+
295
+ if lane_idx == 0:
296
+ valid_ptr = valid_acc.iterator.toint().ir_value() # ty: ignore[unresolved-attribute]
297
+ counter_ptr = mask_counter.iterator.toint().ir_value() # ty: ignore[unresolved-attribute]
298
+
299
+ _atomic_add_s32_gmem(valid_ptr, block_valid)
300
+ cute.arch.fence_acq_rel_gpu()
301
+ old = _atomic_inc_u32_gmem(counter_ptr, total_blocks - 1)
302
+
303
+ if old == total_blocks - 1:
304
+ # Clamp to >=1.0 so a fully-masked batch (n_valid=0)
305
+ # produces ``loss=0`` instead of inf/NaN — matches
306
+ # TRL's ``mask.sum().clamp(min=1)`` semantics.
307
+ n_valid = cute.arch.fmax(cutlass.Float32(valid_acc[0]), cutlass.Float32(1.0))
308
+ inv_total_tensor[0] = cute.arch.rcp_approx(n_valid)
309
+ valid_acc[0] = cutlass.Int32(0)
310
+
311
+ return _mask_sum_kernel
312
+
313
+
314
+ def _make_bnpo_kernel(
315
+ compute_kl: bool,
316
+ compute_backward: bool,
317
+ tile_n: int,
318
+ ) -> Callable[..., None]:
319
+ """Return a ``@cute.kernel`` specialised on compile-time flags.
320
+
321
+ The returned kernel captures *compute_kl*, *compute_backward*, and
322
+ *tile_n* in its closure. ``cutlass.const_expr`` evaluates the booleans
323
+ at trace time so dead branches are eliminated from the compiled PTX.
324
+ ``tile_n`` is a Python ``int`` captured at trace time, so the same
325
+ factory can emit two specialised kernels (small / large tile) — see
326
+ :func:`create_compiled_bnpo_loss` for dispatch.
327
+
328
+ When *compute_backward* is True the kernel additionally writes
329
+ ``dpolicy = dL/d(policy_logprobs)`` to GMEM in the same inner loop —
330
+ no extra HBM reads of the inputs. Because every block must scale
331
+ ``dpolicy`` by ``inv_total`` mid-loop, the on-GPU last-block computation
332
+ of ``inv_total`` from the masked accumulator does **not** compose with
333
+ backward; the bundled mask-sum kernel populates ``inv_total_tensor``
334
+ before the main kernel runs.
335
+
336
+ When *compute_backward* is False the kernel accumulates the
337
+ mask-element count via ``valid_acc`` and computes
338
+ ``inv_total = 1 / n_valid`` on-GPU in the last-block path — no
339
+ host-side ``completions_mask.sum()`` required.
340
+ """
341
+
342
+ @cute.kernel
343
+ def _bnpo_loss_kernel(
344
+ policy: cute.Tensor,
345
+ old_policy: cute.Tensor,
346
+ ref: cute.Tensor,
347
+ advantages: cute.Tensor,
348
+ completions_mask: cute.Tensor,
349
+ dpolicy: cute.Tensor, # (bs, seq_len) when compute_backward; (bs, 1) dummy otherwise
350
+ inv_total_tensor: cute.Tensor, # (1,) fp32 — caller-populated 1/n_valid
351
+ policy_acc: cute.Tensor,
352
+ kl_acc: cute.Tensor,
353
+ valid_acc: cute.Tensor, # (1,) int32 — mask-element count accumulator
354
+ counter: cute.Tensor,
355
+ output: cute.Tensor,
356
+ epsilon: cutlass.Float32,
357
+ epsilon_high: cutlass.Float32,
358
+ beta: cutlass.Float32,
359
+ total_blocks: cutlass.Int32,
360
+ num_full_tiles: cutlass.Int32,
361
+ tail_len: cutlass.Int32,
362
+ ) -> None:
363
+ block_size = NUM_WARPS * 32
364
+ iters = tile_n // (block_size * VEC)
365
+
366
+ # Read inv_total from GMEM once per block (hoisted, single load).
367
+ # Skipped on the fwd-only path which uses an on-GPU last-block
368
+ # computation from the valid_acc accumulator instead. On the
369
+ # compute_backward path the bundled mask-sum kernel writes
370
+ # ``1 / completions_mask.sum()`` into ``inv_total_tensor`` before
371
+ # this kernel runs, so the load returns the pre-inverted scalar.
372
+ accumulate_valid = not compute_backward
373
+ if cutlass.const_expr(not accumulate_valid):
374
+ inv_total = cast(inv_total_tensor[0], cutlass.Float32)
375
+
376
+ _no_alloc = cute.nvgpu.CacheEvictionPriority.NO_ALLOCATE
377
+ g2r_op = cute.nvgpu.CopyUniversalOp()
378
+ g2r_atom = cute.make_copy_atom(
379
+ g2r_op,
380
+ policy.element_type,
381
+ num_bits_per_copy=0,
382
+ l1c_evict_priority=_no_alloc,
383
+ )
384
+ g2r_mask_atom = cute.make_copy_atom(
385
+ g2r_op,
386
+ completions_mask.element_type,
387
+ num_bits_per_copy=0,
388
+ l1c_evict_priority=_no_alloc,
389
+ )
390
+ if cutlass.const_expr(compute_backward):
391
+ r2g_atom = cute.make_copy_atom(
392
+ g2r_op,
393
+ dpolicy.element_type,
394
+ num_bits_per_copy=0,
395
+ )
396
+
397
+ row = cute.arch.block_idx()[0]
398
+ col_block = cute.arch.block_idx()[1]
399
+ tid = cute.arch.thread_idx()[0]
400
+
401
+ adv = cast(advantages[row], cutlass.Float32)
402
+ lo = cutlass.Float32(1.0) - epsilon
403
+ hi = cutlass.Float32(1.0) + epsilon_high
404
+
405
+ local_policy_sum = cutlass.Float32(0.0)
406
+ local_kl_sum = cutlass.Float32(0.0)
407
+ # mask_vec is already cast to fp32 for loss/kl multiplications, so
408
+ # accumulate valid in fp32 too (avoids a separate i8→i32 reduction).
409
+ # Cast to int32 only at the atomic boundary so the shared
410
+ # ``valid_acc`` global can remain int32 — see ``_atomic_add_s32_gmem``.
411
+ local_valid_sum = cutlass.Float32(0.0)
412
+
413
+ pol_row = cute.slice_(policy, (row, None))
414
+ old_row = cute.slice_(old_policy, (row, None))
415
+
416
+ if cutlass.const_expr(compute_kl):
417
+ ref_row = cute.slice_(ref, (row, None))
418
+
419
+ mask_row = cute.slice_(completions_mask, (row, None))
420
+
421
+ if cutlass.const_expr(compute_backward):
422
+ dp_row = cute.slice_(dpolicy, (row, None))
423
+
424
+ # ---- Full-tile vectorised path (LDG.128) ----
425
+ if col_block < num_full_tiles:
426
+ pol_slab = cute.local_tile(pol_row, (tile_n,), (col_block,))
427
+ old_slab = cute.local_tile(old_row, (tile_n,), (col_block,))
428
+
429
+ if cutlass.const_expr(compute_kl):
430
+ ref_slab = cute.local_tile(ref_row, (tile_n,), (col_block,))
431
+
432
+ mask_slab = cute.local_tile(mask_row, (tile_n,), (col_block,))
433
+
434
+ if cutlass.const_expr(compute_backward):
435
+ dp_slab = cute.local_tile(dp_row, (tile_n,), (col_block,))
436
+
437
+ for k in cutlass.range(iters, unroll_full=True):
438
+ sub_idx = tid + k * block_size
439
+
440
+ pol_src = cute.local_tile(pol_slab, (VEC,), (sub_idx,))
441
+ old_src = cute.local_tile(old_slab, (VEC,), (sub_idx,))
442
+ pol_frag = cute.make_fragment_like(pol_src)
443
+ old_frag = cute.make_fragment_like(old_src)
444
+ cute.copy(g2r_atom, pol_src, pol_frag)
445
+ cute.copy(g2r_atom, old_src, old_frag)
446
+
447
+ if cutlass.const_expr(compute_kl):
448
+ ref_src = cute.local_tile(ref_slab, (VEC,), (sub_idx,))
449
+ ref_frag = cute.make_fragment_like(ref_src)
450
+ cute.copy(g2r_atom, ref_src, ref_frag)
451
+
452
+ mask_src = cute.local_tile(mask_slab, (VEC,), (sub_idx,))
453
+ mask_frag = cute.make_fragment_like(mask_src)
454
+ cute.copy(g2r_mask_atom, mask_src, mask_frag)
455
+
456
+ pol_vec = pol_frag.load().to(cutlass.Float32)
457
+ old_vec = old_frag.load().to(cutlass.Float32)
458
+
459
+ log_ratio = pol_vec - old_vec
460
+ ratio = cute.math.exp2(log_ratio * _LOG2_E, fastmath=True)
461
+ surrogate = ratio * adv
462
+ clipped_ratio = cute.where(
463
+ ratio < lo,
464
+ lo,
465
+ cute.where(ratio > hi, hi, ratio),
466
+ )
467
+ clipped = clipped_ratio * adv
468
+ policy_loss = -cute.where(surrogate < clipped, surrogate, clipped)
469
+
470
+ if cutlass.const_expr(compute_kl):
471
+ ref_vec = ref_frag.load().to(cutlass.Float32)
472
+ log_ratio_ref = ref_vec - pol_vec
473
+ ratio_ref = cute.math.exp2(log_ratio_ref * _LOG2_E, fastmath=True)
474
+ # FFMA-friendly rearrangement: ``(ratio_ref - 1) - log_ratio_ref``
475
+ # exposes a ``ratio_ref + (-1)`` pair that ptxas folds with
476
+ # the subsequent subtract — same arithmetic, fewer FADDs
477
+ # surviving SASS than the original 3-term ``a - b - c``.
478
+ kl_val = (ratio_ref - cutlass.Float32(1.0)) - log_ratio_ref
479
+
480
+ mask_vec = mask_frag.load().to(cutlass.Float32)
481
+ local_policy_sum += (policy_loss * mask_vec).reduce(
482
+ cute.ReductionOp.ADD,
483
+ cutlass.Float32(0.0),
484
+ reduction_profile=0,
485
+ )
486
+ if cutlass.const_expr(not compute_backward):
487
+ local_valid_sum += mask_vec.reduce(
488
+ cute.ReductionOp.ADD,
489
+ cutlass.Float32(0.0),
490
+ reduction_profile=0,
491
+ )
492
+ if cutlass.const_expr(compute_kl):
493
+ local_kl_sum += (kl_val * mask_vec).reduce(
494
+ cute.ReductionOp.ADD,
495
+ cutlass.Float32(0.0),
496
+ reduction_profile=0,
497
+ )
498
+
499
+ # ---- Backward: write scaled dpolicy slab in same loop ----
500
+ # use_unclipped = (surrogate <= clipped) — matches torch's
501
+ # convention. d/d(policy) of -min(surrogate, clipped) is
502
+ # -adv*ratio when use_unclipped, else 0 (clamp grad = 0).
503
+ # ``-(adv * ratio)`` is just ``-surrogate`` (already in
504
+ # scope) — saves one FMUL per element.
505
+ # KL term: d/d(policy) of (ratio_ref - log_ratio_ref - 1)
506
+ # = -(ratio_ref - 1) = 1 - ratio_ref.
507
+ if cutlass.const_expr(compute_backward):
508
+ neg_surrogate_grad = cute.where(
509
+ surrogate <= clipped,
510
+ -surrogate,
511
+ cutlass.Float32(0.0),
512
+ )
513
+ if cutlass.const_expr(compute_kl):
514
+ # ``beta - beta*ratio_ref`` instead of ``beta*(1 - ratio_ref)``
515
+ # gives ptxas an obvious FFMA pattern (``FFMA -beta,
516
+ # ratio_ref, beta``) — saves one FMUL per element vs
517
+ # the (1 - ratio_ref) intermediate.
518
+ kl_grad = beta - beta * ratio_ref
519
+ dpolicy_vec = neg_surrogate_grad + kl_grad
520
+ else:
521
+ dpolicy_vec = neg_surrogate_grad
522
+ dpolicy_vec = dpolicy_vec * mask_vec
523
+ dpolicy_vec = dpolicy_vec * inv_total
524
+
525
+ dp_dst = cute.local_tile(dp_slab, (VEC,), (sub_idx,))
526
+ dp_frag = cute.make_fragment_like(dp_dst)
527
+ dp_frag.store(dpolicy_vec.to(dpolicy.element_type))
528
+ cute.copy(r2g_atom, dp_frag, dp_dst)
529
+
530
+ else:
531
+ # ---- Predicated vector tail path (< tile_n valid elements) ----
532
+ pol_slab = cute.local_tile(pol_row, (tile_n,), (col_block,))
533
+ old_slab = cute.local_tile(old_row, (tile_n,), (col_block,))
534
+
535
+ if cutlass.const_expr(compute_kl):
536
+ ref_slab = cute.local_tile(ref_row, (tile_n,), (col_block,))
537
+
538
+ mask_slab = cute.local_tile(mask_row, (tile_n,), (col_block,))
539
+
540
+ if cutlass.const_expr(compute_backward):
541
+ dp_slab = cute.local_tile(dp_row, (tile_n,), (col_block,))
542
+
543
+ for k in cutlass.range(iters, unroll_full=True):
544
+ sub_idx = tid + k * block_size
545
+ chunk_base = sub_idx * VEC
546
+
547
+ if chunk_base < tail_len:
548
+ pol_src = cute.local_tile(pol_slab, (VEC,), (sub_idx,))
549
+ old_src = cute.local_tile(old_slab, (VEC,), (sub_idx,))
550
+ pred = cute.make_rmem_tensor(pol_src.shape, cutlass.Boolean)
551
+ for v in cutlass.range(VEC, unroll_full=True):
552
+ pred[v] = cute.elem_less(chunk_base + v, tail_len)
553
+
554
+ pol_frag = cute.make_fragment_like(pol_src)
555
+ old_frag = cute.make_fragment_like(old_src)
556
+ pol_frag.fill(0.0)
557
+ old_frag.fill(0.0)
558
+ cute.copy(g2r_atom, pol_src, pol_frag, pred=pred)
559
+ cute.copy(g2r_atom, old_src, old_frag, pred=pred)
560
+
561
+ if cutlass.const_expr(compute_kl):
562
+ ref_src = cute.local_tile(ref_slab, (VEC,), (sub_idx,))
563
+ ref_frag = cute.make_fragment_like(ref_src)
564
+ ref_frag.fill(0.0)
565
+ cute.copy(g2r_atom, ref_src, ref_frag, pred=pred)
566
+
567
+ mask_src = cute.local_tile(mask_slab, (VEC,), (sub_idx,))
568
+ mask_frag = cute.make_fragment_like(mask_src)
569
+ mask_frag.fill(0)
570
+ cute.copy(g2r_mask_atom, mask_src, mask_frag, pred=pred)
571
+
572
+ pol_vec = pol_frag.load().to(cutlass.Float32)
573
+ old_vec = old_frag.load().to(cutlass.Float32)
574
+ valid_vec = cute.where(
575
+ pred.load(),
576
+ cute.full_like(pol_vec, cutlass.Float32(1.0)),
577
+ cute.zeros_like(pol_vec, dtype=cutlass.Float32),
578
+ )
579
+
580
+ log_ratio = pol_vec - old_vec
581
+ ratio = cute.math.exp2(log_ratio * _LOG2_E, fastmath=True)
582
+ surrogate = ratio * adv
583
+ clipped_ratio = cute.where(
584
+ ratio < lo,
585
+ lo,
586
+ cute.where(ratio > hi, hi, ratio),
587
+ )
588
+ clipped = clipped_ratio * adv
589
+ policy_loss = -cute.where(surrogate < clipped, surrogate, clipped)
590
+
591
+ if cutlass.const_expr(compute_kl):
592
+ ref_vec = ref_frag.load().to(cutlass.Float32)
593
+ log_ratio_ref = ref_vec - pol_vec
594
+ ratio_ref = cute.math.exp2(log_ratio_ref * _LOG2_E, fastmath=True)
595
+ # FFMA-friendly rearrangement — see full-tile path.
596
+ kl_val = (ratio_ref - cutlass.Float32(1.0)) - log_ratio_ref
597
+
598
+ mask_vec = mask_frag.load().to(cutlass.Float32) * valid_vec
599
+ local_policy_sum += (policy_loss * mask_vec).reduce(
600
+ cute.ReductionOp.ADD,
601
+ cutlass.Float32(0.0),
602
+ reduction_profile=0,
603
+ )
604
+ if cutlass.const_expr(not compute_backward):
605
+ local_valid_sum += mask_vec.reduce(
606
+ cute.ReductionOp.ADD,
607
+ cutlass.Float32(0.0),
608
+ reduction_profile=0,
609
+ )
610
+ if cutlass.const_expr(compute_kl):
611
+ local_kl_sum += (kl_val * mask_vec).reduce(
612
+ cute.ReductionOp.ADD,
613
+ cutlass.Float32(0.0),
614
+ reduction_profile=0,
615
+ )
616
+
617
+ # ---- Backward: predicated dpolicy slab write ----
618
+ # Same gradient math as the full-tile path. ``valid_vec``
619
+ # already encodes the in-bounds predicate (1.0 inside,
620
+ # 0.0 outside) and is folded into ``mask_vec``, so
621
+ # multiplying by it zeros out the padded positions.
622
+ if cutlass.const_expr(compute_backward):
623
+ neg_surrogate_grad = cute.where(
624
+ surrogate <= clipped,
625
+ -surrogate,
626
+ cutlass.Float32(0.0),
627
+ )
628
+ if cutlass.const_expr(compute_kl):
629
+ kl_grad = beta - beta * ratio_ref
630
+ dpolicy_vec = neg_surrogate_grad + kl_grad
631
+ else:
632
+ dpolicy_vec = neg_surrogate_grad
633
+ dpolicy_vec = dpolicy_vec * mask_vec
634
+ dpolicy_vec = dpolicy_vec * inv_total
635
+
636
+ dp_dst = cute.local_tile(dp_slab, (VEC,), (sub_idx,))
637
+ dp_frag = cute.make_fragment_like(dp_dst)
638
+ dp_frag.store(dpolicy_vec.to(dpolicy.element_type))
639
+ cute.copy(r2g_atom, dp_frag, dp_dst, pred=pred)
640
+
641
+ # ---- Stage 1: Intra-warp reduction (butterfly XOR shuffles) ----
642
+ warp_policy = cute.arch.warp_reduction(local_policy_sum, operator.add)
643
+ if cutlass.const_expr(compute_kl):
644
+ warp_kl = cute.arch.warp_reduction(local_kl_sum, operator.add)
645
+
646
+ smem = cutlass.utils.SmemAllocator()
647
+ buf_policy = smem.allocate_tensor(cutlass.Float32, cute.make_layout(NUM_WARPS))
648
+ if cutlass.const_expr(compute_kl):
649
+ buf_kl = smem.allocate_tensor(cutlass.Float32, cute.make_layout(NUM_WARPS))
650
+
651
+ lane_idx = cute.arch.lane_idx()
652
+ warp_idx = cute.arch.warp_idx()
653
+
654
+ # When compute_backward is True the bundled mask-sum kernel populates
655
+ # inv_total_tensor before this kernel runs, so on-GPU mask-element
656
+ # accumulation is dead code.
657
+ if cutlass.const_expr(accumulate_valid):
658
+ warp_valid = cute.arch.warp_reduction(local_valid_sum, operator.add)
659
+ buf_valid = smem.allocate_tensor(cutlass.Float32, cute.make_layout(NUM_WARPS))
660
+
661
+ # ---- Stage 2: Cross-warp reduction via SMEM ----
662
+ if lane_idx == 0:
663
+ buf_policy[warp_idx] = warp_policy
664
+ if cutlass.const_expr(compute_kl):
665
+ buf_kl[warp_idx] = warp_kl
666
+ if cutlass.const_expr(accumulate_valid):
667
+ buf_valid[warp_idx] = warp_valid
668
+ cute.arch.barrier()
669
+
670
+ if warp_idx == 0:
671
+ val_p = cutlass.Float32(0.0)
672
+ if lane_idx < NUM_WARPS:
673
+ val_p = buf_policy[lane_idx]
674
+
675
+ block_policy = cute.arch.warp_reduction(val_p, operator.add, threads_in_group=NUM_WARPS)
676
+
677
+ if cutlass.const_expr(compute_kl):
678
+ val_k = cutlass.Float32(0.0)
679
+ if lane_idx < NUM_WARPS:
680
+ val_k = buf_kl[lane_idx]
681
+ block_kl = cute.arch.warp_reduction(val_k, operator.add, threads_in_group=NUM_WARPS)
682
+
683
+ if cutlass.const_expr(accumulate_valid):
684
+ val_v = cutlass.Float32(0.0)
685
+ if lane_idx < NUM_WARPS:
686
+ val_v = buf_valid[lane_idx]
687
+ block_valid = cute.arch.warp_reduction(
688
+ val_v, operator.add, threads_in_group=NUM_WARPS
689
+ )
690
+
691
+ # ---- Stage 3: Cross-CTA atomic accumulation ----
692
+ if lane_idx == 0:
693
+ policy_ptr = policy_acc.iterator.toint().ir_value() # ty: ignore[unresolved-attribute]
694
+ counter_ptr = counter.iterator.toint().ir_value() # ty: ignore[unresolved-attribute]
695
+
696
+ _atomic_add_f32_gmem(policy_ptr, block_policy)
697
+
698
+ if cutlass.const_expr(compute_kl):
699
+ kl_ptr = kl_acc.iterator.toint().ir_value() # ty: ignore[unresolved-attribute]
700
+ _atomic_add_f32_gmem(kl_ptr, block_kl)
701
+
702
+ if cutlass.const_expr(accumulate_valid):
703
+ valid_ptr = valid_acc.iterator.toint().ir_value() # ty: ignore[unresolved-attribute]
704
+ # valid_acc is int32. Per-block sums of int8 0/1 values
705
+ # fit exactly in fp32 (≤ tile_n ≤ 4096 ≪ 2²⁴) so the
706
+ # cast is bit-exact.
707
+ _atomic_add_s32_gmem(valid_ptr, cutlass.Int32(block_valid))
708
+
709
+ cute.arch.fence_acq_rel_gpu()
710
+
711
+ old = _atomic_inc_u32_gmem(counter_ptr, total_blocks - 1)
712
+
713
+ if old == total_blocks - 1:
714
+ pol_sum = policy_acc[0]
715
+
716
+ if cutlass.const_expr(accumulate_valid):
717
+ # Clamp to >=1.0 so a fully-masked batch (n_valid=0)
718
+ # produces ``loss=0`` instead of inf/NaN — matches
719
+ # TRL's ``mask.sum().clamp(min=1)`` semantics.
720
+ n_valid = cute.arch.fmax(
721
+ cutlass.Float32(valid_acc[0]), cutlass.Float32(1.0)
722
+ )
723
+ inv_total_computed = cute.arch.rcp_approx(n_valid)
724
+ else:
725
+ # compute_backward path: bundled mask-sum kernel
726
+ # already wrote the inverse so forward and backward
727
+ # share the same scalar.
728
+ inv_total_computed = inv_total
729
+
730
+ if cutlass.const_expr(compute_kl):
731
+ kl_sum = kl_acc[0]
732
+ loss = (pol_sum + beta * kl_sum) * inv_total_computed
733
+ else:
734
+ loss = pol_sum * inv_total_computed
735
+ output[0] = cast(loss, output.element_type) # ty: ignore[invalid-argument-type]
736
+
737
+ # Reset accumulators for the next invocation.
738
+ # Counter self-resets via atom.inc wrap-around.
739
+ policy_acc[0] = cutlass.Float32(0.0)
740
+ if cutlass.const_expr(compute_kl):
741
+ kl_acc[0] = cutlass.Float32(0.0)
742
+ if cutlass.const_expr(accumulate_valid):
743
+ valid_acc[0] = cutlass.Int32(0)
744
+
745
+ return _bnpo_loss_kernel
746
+
747
+
748
+ def create_compiled_bnpo_loss(
749
+ policy_dtype: torch.dtype,
750
+ epsilon: float,
751
+ epsilon_high: float,
752
+ beta: float,
753
+ compute_backward: bool = False,
754
+ ) -> Callable[..., torch.Tensor | tuple[torch.Tensor, torch.Tensor]]:
755
+ """Compile the bnpo loss kernel for a given dtype/KL/backward configuration.
756
+
757
+ The runner allocates per-call scratch (``output``, ``inv_total``, and on
758
+ the fwd+bwd path ``dpolicy``) inside ``_run`` itself; cross-CTA scratch
759
+ (atomic accumulators + counters) is allocated lazily on first call from
760
+ the input device and self-resets each launch via the kernel's last-block
761
+ epilogue + ``atom.inc.u32`` wrap-around.
762
+ """
763
+ compute_kl = beta != 0.0
764
+
765
+ if policy_dtype not in _TORCH_TO_CUTLASS_DTYPE:
766
+ raise ValueError(f"Unsupported dtype for bnpo kernel: {policy_dtype}")
767
+
768
+ tile_n_small = TILE_N
769
+ tile_n_large = TILE_N_LARGE
770
+ seq_len_threshold = TILE_N_LARGE_THRESHOLD
771
+ block_size = NUM_WARPS * 32
772
+ if tile_n_small % (block_size * VEC) != 0:
773
+ raise ValueError(
774
+ f"TILE_N={tile_n_small} must be a multiple of BLOCK_SIZE*VEC={block_size * VEC}"
775
+ )
776
+ if tile_n_large % (block_size * VEC) != 0:
777
+ raise ValueError(
778
+ f"TILE_N_LARGE={tile_n_large} must be a multiple of BLOCK_SIZE*VEC={block_size * VEC}"
779
+ )
780
+
781
+ bs_sym = cute.sym_int()
782
+ seq_len_sym = cute.sym_int()
783
+ cute_dtype = _TORCH_TO_CUTLASS_DTYPE[policy_dtype]
784
+
785
+ def _fake2d(dt: Any, cols: Any) -> Any:
786
+ return cute.runtime.make_fake_compact_tensor(
787
+ dt,
788
+ (bs_sym, cols),
789
+ stride_order=(1, 0),
790
+ assumed_align=16,
791
+ )
792
+
793
+ fake_pol = _fake2d(cute_dtype, seq_len_sym)
794
+ fake_old = _fake2d(cute_dtype, seq_len_sym)
795
+ fake_ref = _fake2d(cute_dtype, seq_len_sym)
796
+ fake_adv = cute.runtime.make_fake_compact_tensor(
797
+ cute_dtype,
798
+ (bs_sym,),
799
+ assumed_align=16,
800
+ )
801
+ fake_mask = cute.runtime.make_fake_compact_tensor(
802
+ cutlass.Int8,
803
+ (bs_sym, seq_len_sym),
804
+ stride_order=(1, 0),
805
+ assumed_align=16,
806
+ )
807
+ dpolicy_cols = seq_len_sym if compute_backward else 1
808
+ fake_dpolicy = cute.runtime.make_fake_compact_tensor(
809
+ cute_dtype,
810
+ (bs_sym, dpolicy_cols),
811
+ stride_order=(1, 0),
812
+ assumed_align=16,
813
+ )
814
+ fake_scalar_f32 = cute.runtime.make_fake_compact_tensor(
815
+ cutlass.Float32,
816
+ (1,),
817
+ assumed_align=16,
818
+ )
819
+ fake_valid_acc = cute.runtime.make_fake_compact_tensor(
820
+ cutlass.Int32,
821
+ (1,),
822
+ assumed_align=16,
823
+ )
824
+ fake_counter = cute.runtime.make_fake_compact_tensor(
825
+ cutlass.Int32,
826
+ (1,),
827
+ assumed_align=16,
828
+ )
829
+ fake_mask_counter = cute.runtime.make_fake_compact_tensor(
830
+ cutlass.Int32,
831
+ (1,),
832
+ assumed_align=16,
833
+ )
834
+ fake_output = cute.runtime.make_fake_compact_tensor(
835
+ cute_dtype,
836
+ (1,),
837
+ assumed_align=16,
838
+ )
839
+
840
+ def _build_launch(tile_n_v: int) -> Callable[..., None]:
841
+ """Build a ``@cute.jit`` ``_launch`` for a given ``tile_n``.
842
+
843
+ Captures *tile_n_v* via closure; both the main kernel and the
844
+ (optional) mask-sum kernel are specialised to this tile size.
845
+ One ``_launch`` per tier; the runner dispatches at call time.
846
+ """
847
+ specialized_kernel = _make_bnpo_kernel(compute_kl, compute_backward, tile_n_v)
848
+ if compute_backward:
849
+ mask_sum_kernel = _make_mask_sum_kernel(tile_n_v)
850
+
851
+ @cute.jit
852
+ def _launch(
853
+ pol_ct: cute.Tensor,
854
+ old_ct: cute.Tensor,
855
+ ref_ct: cute.Tensor,
856
+ adv_ct: cute.Tensor,
857
+ mask_ct: cute.Tensor,
858
+ dpolicy_ct: cute.Tensor,
859
+ inv_total_ct: cute.Tensor,
860
+ policy_acc_ct: cute.Tensor,
861
+ kl_acc_ct: cute.Tensor,
862
+ valid_acc_ct: cute.Tensor,
863
+ counter_ct: cute.Tensor,
864
+ mask_counter_ct: cute.Tensor,
865
+ output_ct: cute.Tensor,
866
+ epsilon_v: cutlass.Float32,
867
+ epsilon_high_v: cutlass.Float32,
868
+ beta_v: cutlass.Float32,
869
+ total_blocks_v: cutlass.Int32,
870
+ num_full_tiles_v: cutlass.Int32,
871
+ tail_len_v: cutlass.Int32,
872
+ num_col_tiles_v: cutlass.Int32,
873
+ ) -> None:
874
+ bs_v = pol_ct.shape[0] # ty: ignore[not-subscriptable]
875
+ # Bundled mask-sum (compute_backward only) — writes
876
+ # ``1 / completions_mask.sum()`` into ``inv_total_ct`` before the
877
+ # main kernel reads it. Both kernels in one tvm-ffi dispatch
878
+ # eliminates the per-call ``torch.sum`` + reciprocal round trip.
879
+ if cutlass.const_expr(compute_backward):
880
+ mask_sum_kernel( # ty: ignore[unresolved-attribute]
881
+ mask_ct,
882
+ inv_total_ct,
883
+ valid_acc_ct,
884
+ mask_counter_ct,
885
+ total_blocks_v,
886
+ num_full_tiles_v,
887
+ tail_len_v,
888
+ ).launch(
889
+ grid=(bs_v, num_col_tiles_v, 1),
890
+ block=(NUM_WARPS * 32, 1, 1),
891
+ )
892
+ specialized_kernel( # ty: ignore[unresolved-attribute]
893
+ pol_ct,
894
+ old_ct,
895
+ ref_ct,
896
+ adv_ct,
897
+ mask_ct,
898
+ dpolicy_ct,
899
+ inv_total_ct,
900
+ policy_acc_ct,
901
+ kl_acc_ct,
902
+ valid_acc_ct,
903
+ counter_ct,
904
+ output_ct,
905
+ epsilon_v,
906
+ epsilon_high_v,
907
+ beta_v,
908
+ total_blocks_v,
909
+ num_full_tiles_v,
910
+ tail_len_v,
911
+ ).launch(
912
+ grid=(bs_v, num_col_tiles_v, 1),
913
+ block=(NUM_WARPS * 32, 1, 1),
914
+ )
915
+
916
+ return _launch
917
+
918
+ def _compile_launch(launch_fn: Callable[..., None]) -> Callable[..., None]:
919
+ return cute.compile(
920
+ launch_fn,
921
+ fake_pol,
922
+ fake_old,
923
+ fake_ref,
924
+ fake_adv,
925
+ fake_mask,
926
+ fake_dpolicy,
927
+ fake_scalar_f32,
928
+ fake_scalar_f32,
929
+ fake_scalar_f32,
930
+ fake_valid_acc,
931
+ fake_counter,
932
+ fake_mask_counter,
933
+ fake_output,
934
+ cutlass.Float32(epsilon),
935
+ cutlass.Float32(epsilon_high),
936
+ cutlass.Float32(beta),
937
+ cutlass.Int32(1),
938
+ cutlass.Int32(1),
939
+ cutlass.Int32(0),
940
+ cutlass.Int32(1),
941
+ options="--enable-tvm-ffi",
942
+ )
943
+
944
+ compiled_small = _compile_launch(_build_launch(tile_n_small))
945
+ if tile_n_large == tile_n_small:
946
+ compiled_large = compiled_small
947
+ else:
948
+ compiled_large = _compile_launch(_build_launch(tile_n_large))
949
+
950
+ eps_const = cutlass.Float32(epsilon)
951
+ eps_high_const = cutlass.Float32(epsilon_high)
952
+ beta_const = cutlass.Float32(beta)
953
+
954
+ # Cross-CTA scratch slab — one int32 buffer with stride-4 (16-byte) slices
955
+ # so each slot is individually 16-byte aligned (``assumed_align=16`` at
956
+ # compile time). Bit-pattern of int32 0 equals fp32 0.0, so a single
957
+ # ``zeros`` factory legitimately initialises both the int32 counters and
958
+ # the fp32 accumulators. The kernel's last block self-resets accumulators
959
+ # in its epilogue and the counters self-reset via ``atom.inc.u32``
960
+ # wrap-around, so the up-front ``torch.zeros`` only matters for the very
961
+ # first call.
962
+ _scratch: list[torch.Tensor | None] = [None]
963
+
964
+ def _ensure_scratch(device: torch.device) -> tuple[torch.Tensor, ...]:
965
+ s = _scratch[0]
966
+ if s is None or s.device != device:
967
+ s = torch.zeros(20, dtype=torch.int32, device=device)
968
+ _scratch[0] = s
969
+ return (
970
+ s[0:1], # counter (int32)
971
+ s[4:5], # mask_counter (int32)
972
+ s[8:9], # valid_acc (int32)
973
+ s[12:13].view(torch.float32), # policy_acc (fp32)
974
+ s[16:17].view(torch.float32), # kl_acc (fp32)
975
+ )
976
+
977
+ def _run(
978
+ policy_logprobs_r: torch.Tensor,
979
+ old_policy_logprobs_r: torch.Tensor,
980
+ ref_logprobs_r: torch.Tensor,
981
+ advantages_r: torch.Tensor,
982
+ completions_mask_r: torch.Tensor,
983
+ ) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
984
+ bs, seq_len = policy_logprobs_r.shape
985
+ device = policy_logprobs_r.device
986
+ dtype = policy_logprobs_r.dtype
987
+
988
+ # Tier dispatch: long sequences pay too much last-block-detection
989
+ # latency under the small-tile grid, so swap to the large-tile
990
+ # compiled variant.
991
+ if seq_len >= seq_len_threshold:
992
+ tile_n_active = tile_n_large
993
+ compiled_active = compiled_large
994
+ else:
995
+ tile_n_active = tile_n_small
996
+ compiled_active = compiled_small
997
+ num_full_tiles = seq_len // tile_n_active
998
+ tail_len = seq_len % tile_n_active
999
+ num_col_tiles = num_full_tiles + (1 if tail_len > 0 else 0)
1000
+ total_blocks = bs * num_col_tiles
1001
+
1002
+ # Per-call write-only buffers — ``empty`` is enough (Liger / TE
1003
+ # pattern). ``inv_total`` is populated by the bundled mask-sum
1004
+ # kernel (compute_backward path) or by the main kernel's last-block
1005
+ # trick (fwd-only path); the runner never reads it.
1006
+ output_r = torch.empty(1, dtype=dtype, device=device)
1007
+ inv_total_r = torch.empty(1, dtype=torch.float32, device=device)
1008
+ if compute_backward:
1009
+ dpolicy_r = torch.empty_like(policy_logprobs_r)
1010
+ else:
1011
+ dpolicy_r = torch.empty(bs, 1, dtype=dtype, device=device)
1012
+
1013
+ counter_r, mask_counter_r, valid_acc_r, policy_acc_r, kl_acc_r = _ensure_scratch(device)
1014
+
1015
+ compiled_active(
1016
+ policy_logprobs_r,
1017
+ old_policy_logprobs_r,
1018
+ ref_logprobs_r,
1019
+ advantages_r,
1020
+ completions_mask_r,
1021
+ dpolicy_r,
1022
+ inv_total_r,
1023
+ policy_acc_r,
1024
+ kl_acc_r,
1025
+ valid_acc_r,
1026
+ counter_r,
1027
+ mask_counter_r,
1028
+ output_r,
1029
+ eps_const,
1030
+ eps_high_const,
1031
+ beta_const,
1032
+ total_blocks,
1033
+ num_full_tiles,
1034
+ tail_len,
1035
+ num_col_tiles,
1036
+ )
1037
+ out_view = output_r.view(())
1038
+ if compute_backward:
1039
+ return out_view, dpolicy_r
1040
+ return out_view
1041
+
1042
+ return _run
1043
+
1044
+
1045
+ # ---------------------------------------------------------------------------
1046
+ # Fused forward + backward — direct (loss, grad) runner, no autograd
1047
+ # ---------------------------------------------------------------------------
1048
+
1049
+
1050
+ def create_compiled_bnpo_loss_with_backward(
1051
+ policy_dtype: torch.dtype,
1052
+ epsilon: float,
1053
+ epsilon_high: float,
1054
+ beta: float,
1055
+ ) -> Callable[..., tuple[torch.Tensor, torch.Tensor]]:
1056
+ """Compile the fused fwd+bwd bnpo kernel and return a tuple-returning runner.
1057
+
1058
+ The returned callable runs one training step's worth of work: a single
1059
+ ``@cute.jit`` dispatch produces both the scalar loss and the scaled
1060
+ ``dL/d(policy_logprobs)`` tensor. It returns ``(loss, dpolicy)`` directly
1061
+ — no ``torch.autograd.Function`` wrapper, no extra ``grad_output * dpolicy``
1062
+ backward kernel. Callers that need autograd integration (so
1063
+ ``loss.backward()`` works) wrap this themselves at the public-API layer;
1064
+ callers that control gradient flow manually (benchmarks, custom training
1065
+ loops) can use it as-is for zero overhead.
1066
+
1067
+ ``inv_total`` is computed entirely on-GPU by a bundled mask-sum kernel
1068
+ that runs in series with the main kernel inside the same ``@cute.jit``
1069
+ launch — no host sync, no extra ``torch.sum`` dispatch, CUDA-graph
1070
+ compatible.
1071
+ """
1072
+ return _typing_cast(
1073
+ "Callable[..., tuple[torch.Tensor, torch.Tensor]]",
1074
+ create_compiled_bnpo_loss(
1075
+ policy_dtype=policy_dtype,
1076
+ epsilon=epsilon,
1077
+ epsilon_high=epsilon_high,
1078
+ beta=beta,
1079
+ compute_backward=True,
1080
+ ),
1081
+ )
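A minimal usage sketch for the direct runner described in the docstring above: driving one training step from a custom loop and chaining the precomputed gradient manually. The import path, device, shapes, and the int8 mask are illustrative assumptions, not part of the uploaded files::

    import torch
    from geometric_ai_kernels.bnpo_loss.cute_bnpo_loss import (
        create_compiled_bnpo_loss_with_backward,
    )

    # Compile once per (dtype, epsilon, epsilon_high, beta) combination.
    runner = create_compiled_bnpo_loss_with_backward(
        policy_dtype=torch.bfloat16, epsilon=0.2, epsilon_high=0.2, beta=0.1,
    )

    bs, seq_len = 4, 4096
    policy = torch.randn(bs, seq_len, dtype=torch.bfloat16, device="cuda",
                         requires_grad=True)
    old = policy.detach().clone()
    ref = torch.randn_like(old)
    adv = torch.randn(bs, dtype=torch.bfloat16, device="cuda")
    mask = torch.ones(bs, seq_len, dtype=torch.int8, device="cuda")

    # One dispatch yields both the scalar loss and dL/d(policy_logprobs).
    loss, dpolicy = runner(policy, old, ref, adv, mask)
    # No autograd.Function in the path: chain the gradient manually.
    policy.backward(dpolicy)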
build/torch-cuda/geometric_ai_kernels/__init__.py ADDED
@@ -0,0 +1,26 @@
1
+ import ctypes
2
+ import importlib.util
3
+ import sys
4
+ from pathlib import Path
5
+ from types import ModuleType
6
+
7
+
8
+ def _import_from_path(file_path: Path) -> ModuleType:
9
+ # We cannot use the module name as-is: after adding it to `sys.modules`,
10
+ # it would also be used for other imports. So, we make a module name that
11
+ # depends on the path for it to be unique using the hex-encoded hash of
12
+ # the path.
13
+ path_hash = "{:x}".format(ctypes.c_size_t(hash(file_path.absolute())).value)
14
+ module_name = path_hash
15
+ spec = importlib.util.spec_from_file_location(module_name, file_path)
16
+ if spec is None:
17
+ raise ImportError(f"Cannot load spec for {module_name} from {file_path}")
18
+ module = importlib.util.module_from_spec(spec)
19
+ if module is None:
20
+ raise ImportError(f"Cannot load module {module_name} from spec")
21
+ sys.modules[module_name] = module
22
+ spec.loader.exec_module(module) # type: ignore
23
+ return module
24
+
25
+
26
+ globals().update(vars(_import_from_path(Path(__file__).parent.parent / "__init__.py")))
build/torch-cuda/grpo_loss/__init__.py ADDED
@@ -0,0 +1,169 @@
1
+ """GRPO loss with CuteDSL fused fwd+bwd.
2
+
3
+ Three public APIs:
4
+
5
+ * :func:`grpo_loss` — fused fwd+bwd. Returns
6
+ ``(loss, grad_policy_logprobs)`` from a single ``@cute.jit`` dispatch.
7
+ Caller chains via ``policy_logprobs.backward(grad)``.
8
+ * :func:`grpo_loss_fwd` — forward-only (inference / validation).
9
+ Returns scalar ``loss`` and skips the dpolicy buffer entirely.
10
+ * :func:`grpo_loss_autograd` — autograd-aware via
11
+ ``torch.library.custom_op``. ``loss.backward()`` works and composes
12
+ with ``torch.compile``. ~12µs of dispatcher overhead vs.
13
+ :func:`grpo_loss`.
14
+
15
+ GRPO requires ``completions_mask``; the per-response normalization
16
+ formula is mask-derived. The cute kernel uses one CTA per row so the
17
+ per-row mask sum is reduced inside the block — no cross-CTA atomics
18
+ or last-block detection on the per-row scaling pass.
19
+
20
+ Per-call output and gradient buffers are allocated inside the runner;
21
+ cross-CTA scratch (the ``policy_acc`` accumulator + last-block counter)
22
+ is owned by the compiled-kernel closure and self-resets each launch.
23
+ """
24
+
25
+ from __future__ import annotations
26
+
27
+ from functools import lru_cache
28
+ from typing import TYPE_CHECKING, cast
29
+
30
+ import torch
31
+
32
+ from .cute_grpo_loss import create_compiled_grpo_loss
33
+
34
+ if TYPE_CHECKING:
35
+ from collections.abc import Callable
36
+
37
+
38
+ __all__ = ["grpo_loss", "grpo_loss_autograd", "grpo_loss_fwd"]
39
+
40
+
41
+ @lru_cache(maxsize=32)
42
+ def _get_compiled_fwd(
43
+ dtype: torch.dtype,
44
+ epsilon: float,
45
+ epsilon_high: float,
46
+ beta: float,
47
+ ) -> Callable[..., torch.Tensor]:
48
+ return cast(
49
+ "Callable[..., torch.Tensor]",
50
+ create_compiled_grpo_loss(
51
+ policy_dtype=dtype,
52
+ epsilon=epsilon,
53
+ epsilon_high=epsilon_high,
54
+ beta=beta,
55
+ compute_backward=False,
56
+ ),
57
+ )
58
+
59
+
60
+ @lru_cache(maxsize=32)
61
+ def _get_compiled_fwd_bwd(
62
+ dtype: torch.dtype,
63
+ epsilon: float,
64
+ epsilon_high: float,
65
+ beta: float,
66
+ ) -> Callable[..., tuple[torch.Tensor, torch.Tensor]]:
67
+ return cast(
68
+ "Callable[..., tuple[torch.Tensor, torch.Tensor]]",
69
+ create_compiled_grpo_loss(
70
+ policy_dtype=dtype,
71
+ epsilon=epsilon,
72
+ epsilon_high=epsilon_high,
73
+ beta=beta,
74
+ compute_backward=True,
75
+ ),
76
+ )
77
+
78
+
79
+ def _mask_to_int8(completions_mask: torch.Tensor) -> torch.Tensor:
80
+ return (
81
+ completions_mask
82
+ if completions_mask.dtype == torch.int8
83
+ else completions_mask.to(torch.int8)
84
+ )
85
+
86
+
87
+ def grpo_loss_fwd(
88
+ policy_logprobs: torch.Tensor,
89
+ old_policy_logprobs: torch.Tensor,
90
+ ref_logprobs: torch.Tensor,
91
+ advantages: torch.Tensor,
92
+ completions_mask: torch.Tensor,
93
+ epsilon: float = 0.2,
94
+ epsilon_high: float = 0.2,
95
+ beta: float = 0.1,
96
+ ) -> torch.Tensor:
97
+ """Forward-only GRPO loss. Returns the scalar ``loss``.
98
+
99
+ Args:
100
+ policy_logprobs, old_policy_logprobs, ref_logprobs: ``(bs, seq_len)``.
101
+ advantages: ``(bs,)``.
102
+ completions_mask: Bool / int8 mask ``(bs, seq_len)``. Required.
103
+ epsilon, epsilon_high: PPO-style clipping bounds.
104
+ beta: KL-penalty coefficient. ``0.0`` compiles away the KL branch.
105
+
106
+ Returns:
107
+ Scalar tensor (0-dim) with the same dtype as ``policy_logprobs``.
108
+ """
109
+ run = _get_compiled_fwd(
110
+ policy_logprobs.dtype,
111
+ float(epsilon),
112
+ float(epsilon_high),
113
+ float(beta),
114
+ )
115
+ return run(
116
+ policy_logprobs,
117
+ old_policy_logprobs,
118
+ ref_logprobs,
119
+ advantages,
120
+ _mask_to_int8(completions_mask),
121
+ )
122
+
123
+
124
+ def grpo_loss(
125
+ policy_logprobs: torch.Tensor,
126
+ old_policy_logprobs: torch.Tensor,
127
+ ref_logprobs: torch.Tensor,
128
+ advantages: torch.Tensor,
129
+ completions_mask: torch.Tensor,
130
+ epsilon: float = 0.2,
131
+ epsilon_high: float = 0.2,
132
+ beta: float = 0.1,
133
+ ) -> tuple[torch.Tensor, torch.Tensor]:
134
+ """Fused fwd+bwd GRPO loss. Returns ``(loss, grad_policy_logprobs)``.
135
+
136
+ Inputs do **not** need ``requires_grad=True``. Chain ``grad`` into
137
+ the upstream model via ``policy_logprobs.backward(grad)``.
138
+
139
+ Args:
140
+ policy_logprobs, old_policy_logprobs, ref_logprobs: ``(bs, seq_len)``.
141
+ advantages: ``(bs,)``.
142
+ completions_mask: Bool / int8 mask ``(bs, seq_len)``. Required.
143
+ epsilon, epsilon_high: PPO-style clipping bounds.
144
+ beta: KL-penalty coefficient. ``0.0`` compiles away the KL branch.
145
+
146
+ Returns:
147
+ ``(loss, grad_policy_logprobs)``. The grad already has the
148
+ per-row ``1 / mask.sum(-1).clamp(min=1)`` and across-row
149
+ ``1/n_rows`` scalings folded in. The grad tensor is freshly
150
+ allocated per call (no shared cache).
151
+ """
152
+ run = _get_compiled_fwd_bwd(
153
+ policy_logprobs.dtype,
154
+ float(epsilon),
155
+ float(epsilon_high),
156
+ float(beta),
157
+ )
158
+ return run(
159
+ policy_logprobs,
160
+ old_policy_logprobs,
161
+ ref_logprobs,
162
+ advantages,
163
+ _mask_to_int8(completions_mask),
164
+ )
165
+
166
+
167
+ # ``autograd.py`` imports ``grpo_loss`` from this module, so the function
168
+ # must be fully defined before its import runs.
169
+ from .autograd import grpo_loss_autograd # noqa: E402
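A minimal sketch of the two training-side entry points listed in the module docstring, assuming the package-root exports and illustrative CUDA tensors::

    import torch
    from geometric_ai_kernels import grpo_loss, grpo_loss_autograd

    policy = torch.randn(8, 2048, device="cuda", requires_grad=True)
    old = policy.detach().clone()
    ref = torch.randn(8, 2048, device="cuda")
    adv = torch.randn(8, device="cuda")
    mask = torch.ones(8, 2048, dtype=torch.int8, device="cuda")

    # Direct path: one dispatch returns (loss, grad); chain the grad manually.
    loss, grad = grpo_loss(policy, old, ref, adv, mask, beta=0.1)
    policy.backward(grad)
    policy.grad = None

    # Autograd path: loss.backward() works, at a small dispatcher cost.
    loss2 = grpo_loss_autograd(policy, old, ref, adv, mask, beta=0.1)
    loss2.backward()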
build/torch-cuda/grpo_loss/_torch_ref.py ADDED
@@ -0,0 +1,87 @@
1
+ """Plain-PyTorch GRPO reference shared between bench and tests.
2
+
3
+ Mirrors TRL's default per-response normalization variant: per-row mask
4
+ sum acts as the divisor for that row's loss before averaging across
5
+ rows. Every op is a vanilla torch op so AOTAutograd can derive the
6
+ joint fwd+bwd graph and Inductor can fuse both passes.
7
+ """
8
+
9
+ from __future__ import annotations
10
+
11
+ import torch
12
+
13
+
14
+ def torch_grpo_loss(
15
+ policy_logprobs: torch.Tensor,
16
+ old_policy_logprobs: torch.Tensor,
17
+ ref_logprobs: torch.Tensor,
18
+ advantages: torch.Tensor,
19
+ completions_mask: torch.Tensor,
20
+ epsilon: float = 0.2,
21
+ epsilon_high: float = 0.2,
22
+ beta: float = 0.1,
23
+ ) -> torch.Tensor:
24
+ """Compute the GRPO (Group Relative Policy Optimization) loss.
25
+
26
+ Implements TRL's default per-response normalization variant:
27
+
28
+ L_GRPO = mean_r( ((policy_loss + beta*kl) * mask).sum(-1)
29
+ / mask.sum(-1).clamp(min=1) )
30
+
31
+ Each row (response) is independently normalized by its own valid-token
32
+ count, then the per-row losses are averaged across rows. This differs
33
+ from the BNPO variant in :func:`torch_bnpo_loss`, which sums numerators
34
+ *and* denominators globally before dividing — under variable response
35
+ lengths BNPO weights longer responses more heavily, while GRPO weights
36
+ every response equally.
37
+
38
+ **Probability ratio:**
39
+
40
+ r_t(theta) = exp(log pi_theta - log pi_{theta_old})
41
+
42
+ **Clipped surrogate (per token):**
43
+
44
+ L_CLIP_t = -min( r_t * A, clip(r_t, 1 - eps, 1 + eps_high) * A )
45
+
46
+ **KL divergence (Schulman approximation, per token):**
47
+
48
+ kl_t ~= exp(log pi_ref - log pi_theta) - (log pi_ref - log pi_theta) - 1
49
+
50
+ Args:
51
+ policy_logprobs: Log-probabilities of the current policy, shape (N, C).
52
+ old_policy_logprobs: Log-probabilities of the behaviour policy used to
53
+ collect the rollout, shape (N, C).
54
+ ref_logprobs: Log-probabilities of the frozen reference policy, shape
55
+ (N, C).
56
+ advantages: Per-sequence advantage estimates, shape (N,).
57
+ completions_mask: Boolean mask of shape (N, C) where True marks valid tokens.
58
+ Required — GRPO's per-response normalization is mask-derived.
59
+ epsilon: Lower asymmetric clipping bound (1 - epsilon). Default: 0.2.
60
+ epsilon_high: Upper asymmetric clipping bound (1 + epsilon_high). Default:
61
+ 0.2 (symmetric with ``epsilon``, matching TRL's GRPOConfig).
62
+ beta: Coefficient for the KL-divergence penalty. Default: 0.1.
63
+
64
+ Returns:
65
+ Scalar tensor representing the GRPO loss.
66
+ """
67
+ ratio = torch.exp(policy_logprobs - old_policy_logprobs)
68
+ adv = advantages.unsqueeze(1) # (N, 1) for broadcasting over tokens
69
+
70
+ surrogate = ratio * adv
71
+ surrogate_clipped = torch.clamp(ratio, 1.0 - epsilon, 1.0 + epsilon_high) * adv
72
+ policy_loss_per_tok = -torch.min(surrogate, surrogate_clipped)
73
+
74
+ # Per-row valid-token count, fp32 + clamp to avoid div-by-zero on
75
+ # fully-masked rows (matches TRL's ``mask.sum(-1).clamp(min=1)``).
76
+ mask_sum = completions_mask.sum(-1).to(torch.float32).clamp_min(1.0) # (N,)
77
+ policy_per_row = (policy_loss_per_tok * completions_mask).sum(-1) / mask_sum
78
+
79
+ if beta != 0.0:
80
+ log_ratio_ref = ref_logprobs - policy_logprobs
81
+ kl_per_tok = torch.exp(log_ratio_ref) - log_ratio_ref - 1.0
82
+ kl_per_row = (kl_per_tok * completions_mask).sum(-1) / mask_sum
83
+ loss = (policy_per_row + beta * kl_per_row).mean()
84
+ else:
85
+ loss = policy_per_row.mean()
86
+
87
+ return loss.to(policy_logprobs.dtype)
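A tiny worked example of the normalization difference described in the docstring above (per-row GRPO vs. a global BNPO-style sum/sum); the numbers are made up for illustration::

    import torch

    per_tok = torch.tensor([[1.0, 1.0, 1.0, 1.0],   # long response, 4 valid tokens
                            [3.0, 0.0, 0.0, 0.0]])  # short response, 1 valid token
    mask = torch.tensor([[1, 1, 1, 1],
                         [1, 0, 0, 0]], dtype=torch.float32)

    # GRPO: each row normalized by its own token count, then averaged.
    grpo = ((per_tok * mask).sum(-1) / mask.sum(-1).clamp(min=1)).mean()  # (1.0 + 3.0) / 2 = 2.0

    # BNPO-style: global numerator over global denominator; long rows dominate.
    bnpo = (per_tok * mask).sum() / mask.sum()  # 7.0 / 5 = 1.4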
build/torch-cuda/grpo_loss/autograd.py ADDED
@@ -0,0 +1,120 @@
1
+ """Autograd-aware wrapper for GRPO loss via ``torch.library.custom_op``.
2
+
3
+ The fused cute kernel writes both the scalar loss and the closed-form
4
+ ``dL/d(policy_logprobs)`` in one launch. This module wraps that into an
5
+ autograd-compatible op so callers can write::
6
+
7
+ loss = grpo_loss_autograd(policy, old, ref, adv, completions_mask)
8
+ loss.backward()
9
+
10
+ Implementation mirrors the BNPO autograd binding: ``custom_op`` with a
11
+ registered ``setup_context`` / backward, plus a ``register_fake`` for
12
+ shape propagation under ``torch.compile``. The runner allocates
13
+ ``dpolicy`` fresh on every call (no shared cache), so
14
+ ``ctx.save_for_backward(dpolicy)`` keeps a stable reference for free.
15
+ """
16
+
17
+ from __future__ import annotations
18
+
19
+ import torch
20
+
21
+ from . import grpo_loss as _grpo_loss_fwd_bwd
22
+
23
+ __all__ = ["grpo_loss_autograd"]
24
+
25
+
26
+ @torch.library.custom_op(
27
+ "geometric_ai_kernels::_grpo_loss_with_grad",
28
+ mutates_args=(),
29
+ )
30
+ def _grpo_loss_with_grad(
31
+ policy_logprobs: torch.Tensor,
32
+ old_policy_logprobs: torch.Tensor,
33
+ ref_logprobs: torch.Tensor,
34
+ advantages: torch.Tensor,
35
+ completions_mask: torch.Tensor,
36
+ epsilon: float,
37
+ epsilon_high: float,
38
+ beta: float,
39
+ ) -> tuple[torch.Tensor, torch.Tensor]:
40
+ loss, dpolicy = _grpo_loss_fwd_bwd(
41
+ policy_logprobs,
42
+ old_policy_logprobs,
43
+ ref_logprobs,
44
+ advantages,
45
+ completions_mask,
46
+ epsilon=epsilon,
47
+ epsilon_high=epsilon_high,
48
+ beta=beta,
49
+ )
50
+ return loss, dpolicy
51
+
52
+
53
+ @_grpo_loss_with_grad.register_fake
54
+ def _(
55
+ policy_logprobs: torch.Tensor,
56
+ old_policy_logprobs: torch.Tensor,
57
+ ref_logprobs: torch.Tensor,
58
+ advantages: torch.Tensor,
59
+ completions_mask: torch.Tensor,
60
+ epsilon: float,
61
+ epsilon_high: float,
62
+ beta: float,
63
+ ) -> tuple[torch.Tensor, torch.Tensor]:
64
+ del old_policy_logprobs, ref_logprobs, advantages, completions_mask
65
+ del epsilon, epsilon_high, beta
66
+ loss = policy_logprobs.new_empty(())
67
+ dpolicy = torch.empty_like(policy_logprobs)
68
+ return loss, dpolicy
69
+
70
+
71
+ def _setup_context(ctx, inputs, output) -> None: # type: ignore[no-untyped-def]
72
+ del inputs
73
+ _, dpolicy = output
74
+ ctx.save_for_backward(dpolicy)
75
+
76
+
77
+ def _backward(ctx, grad_loss, grad_dpolicy): # type: ignore[no-untyped-def]
78
+ del grad_dpolicy
79
+ (dpolicy,) = ctx.saved_tensors
80
+ grad_policy = grad_loss * dpolicy
81
+ # One return per input (8): policy_logprobs gets the grad, the rest get None.
82
+ return grad_policy, None, None, None, None, None, None, None
83
+
84
+
85
+ torch.library.register_autograd(
86
+ "geometric_ai_kernels::_grpo_loss_with_grad",
87
+ _backward,
88
+ setup_context=_setup_context,
89
+ )
90
+
91
+
92
+ def grpo_loss_autograd(
93
+ policy_logprobs: torch.Tensor,
94
+ old_policy_logprobs: torch.Tensor,
95
+ ref_logprobs: torch.Tensor,
96
+ advantages: torch.Tensor,
97
+ completions_mask: torch.Tensor,
98
+ epsilon: float = 0.2,
99
+ epsilon_high: float = 0.2,
100
+ beta: float = 0.1,
101
+ ) -> torch.Tensor:
102
+ """Autograd-aware GRPO loss. Returns scalar ``loss``.
103
+
104
+ Same numerics as :func:`grpo_loss` but registered as a
105
+ ``torch.library`` custom op with autograd, so ``loss.backward()``
106
+ Just Works. For direct ``(loss, grad)`` access without the
107
+ autograd dispatcher overhead, use :func:`grpo_loss` and chain via
108
+ ``policy_logprobs.backward(grad)``.
109
+ """
110
+ loss, _ = _grpo_loss_with_grad(
111
+ policy_logprobs,
112
+ old_policy_logprobs,
113
+ ref_logprobs,
114
+ advantages,
115
+ completions_mask,
116
+ float(epsilon),
117
+ float(epsilon_high),
118
+ float(beta),
119
+ )
120
+ return loss
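A hedged sanity-check sketch for the backward chaining described above, comparing the custom op's gradient against the plain-PyTorch reference; the import path and tolerances are assumptions::

    import torch
    from geometric_ai_kernels import grpo_loss_autograd
    from geometric_ai_kernels.grpo_loss._torch_ref import torch_grpo_loss

    pol = torch.randn(4, 1024, device="cuda", requires_grad=True)
    old = pol.detach().clone()
    ref = torch.randn(4, 1024, device="cuda")
    adv = torch.randn(4, device="cuda")
    mask = torch.rand(4, 1024, device="cuda") > 0.1

    grpo_loss_autograd(pol, old, ref, adv, mask).backward()
    grad_kernel = pol.grad.clone()
    pol.grad = None

    torch_grpo_loss(pol, old, ref, adv, mask).backward()
    torch.testing.assert_close(grad_kernel, pol.grad, rtol=1e-4, atol=1e-4)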
build/torch-cuda/grpo_loss/cute_grpo_loss.py ADDED
@@ -0,0 +1,805 @@
1
+ """CuteDSL kernel for GRPO (Group Relative Policy Optimization) loss.
2
+
3
+ Implements TRL's default per-response normalization variant:
4
+
5
+ loss = mean_r( ((per_token_loss + beta * kl) * mask).sum(-1)
6
+ / mask.sum(-1).clamp(min=1) )
7
+
8
+ Element-wise over ``(N, C)`` logprob tensors:
9
+
10
+ ratio = exp(policy - old_policy)
11
+ surrogate = ratio * adv
12
+ clipped = clip(ratio, 1 - eps, 1 + eps_high) * adv
13
+ policy_loss = -min(surrogate, clipped)
14
+ log_ratio_ref = ref - policy
15
+ kl = exp(log_ratio_ref) - log_ratio_ref - 1
16
+
17
+ ``completions_mask`` is **required** for GRPO -- the per-response normalization
18
+ formula is mask-derived.
19
+
20
+ **One CTA per row.** Each row is owned by exactly one block, so the
21
+ per-row mask sum is computed locally (warp + cross-warp reduction) with
22
+ no cross-CTA atomics, fences, or last-block detection on the per-row
23
+ scaling pass.
24
+
25
+ **Mask prescan on the fwd+bwd path.** Before the main compute pass each
26
+ CTA runs a cheap mask-only sweep over its row to derive ``row_scale``
27
+ (``1 / max(mask.sum(), 1) * inv_n_rows``). With ``row_scale`` known up
28
+ front the main pass reads ``policy / old / ref`` exactly **once**,
29
+ computes loss and gradient together, and writes the **scaled**
30
+ ``dpolicy`` directly — no second logprob read, no unscaled GMEM round
31
+ trip. The prescan touches only the int8 mask (1 byte / element) and is
32
+ much cheaper than the byte of logprob traffic it eliminates (2 B in
33
+ bf16/fp16, 4 B in fp32, ×2 or ×3 depending on the KL term).
34
+
35
+ The fwd-only path skips the prescan and accumulates ``valid`` directly
36
+ in a single sweep over the row.
37
+
38
+ When ``beta=0`` the KL term is skipped at compile time (no ``ref``
39
+ tensor access, no ``kl`` accumulator).
40
+
41
+ Sequence lengths that are **not** a multiple of ``TILE_N`` are handled
42
+ natively: the in-block tile loop runs ``ceil(C / TILE_N)`` iterations;
43
+ full tiles use the vectorized ``LDG.128`` path and the tail tile uses
44
+ predicated vector loads with neutral prefill.
45
+
46
+ Each CTA reduces its row's policy / kl sums in registers, finishes the
47
+ cross-warp reduction in SMEM, computes its scaled per-row contribution
48
+ ``(block_policy + beta * block_kl) * row_scale`` locally, and
49
+ ``atomicAdd``s the scalar result into ``policy_acc[0]``. The last CTA
50
+ — detected via a single grid-scope ``atomic_inc`` on ``counter`` —
51
+ reads ``policy_acc[0]``, casts to the output dtype, writes
52
+ ``output[0]``, and resets the accumulator to ``0``.
53
+ """
54
+
55
+ from __future__ import annotations
56
+
57
+ import math
58
+ import operator
59
+ from typing import TYPE_CHECKING, Any
60
+
61
+ import cutlass
62
+ import cutlass.utils
63
+ import torch
64
+ from cutlass import cute
65
+ from cutlass.base_dsl.typing import cast
66
+
67
+ from ..bnpo_loss.cute_bnpo_loss import (
68
+ _atomic_add_f32_gmem,
69
+ _atomic_inc_u32_gmem,
70
+ )
71
+
72
+ if TYPE_CHECKING:
73
+ from collections.abc import Callable
74
+
75
+
76
+ TILE_N: int = 2048
77
+ NUM_WARPS: int = 16
78
+ VEC: int = 4
79
+ # Long-context variant: a wider tile keeps the per-row block reduction
80
+ # cheap when ``seq_len`` blows past the small tier (where the inner
81
+ # col-tile loop iterates many times). Two compiled variants — small +
82
+ # large — are dispatched at runtime by sequence length.
83
+ TILE_N_LARGE: int = 8192
84
+ TILE_N_LARGE_THRESHOLD: int = 8192
85
+
86
+ _LOG2_E: float = math.log2(math.e)
87
+
88
+ _TORCH_TO_CUTLASS_DTYPE: dict[torch.dtype, Any] = {
89
+ torch.float32: cutlass.Float32,
90
+ torch.float16: cutlass.Float16,
91
+ torch.bfloat16: cutlass.BFloat16,
92
+ }
93
+
94
+
95
+ # ---------------------------------------------------------------------------
96
+ # Main GRPO kernel — single CTA per row.
97
+ # ---------------------------------------------------------------------------
98
+
99
+
100
+ def _make_grpo_kernel(
101
+ compute_kl: bool,
102
+ compute_backward: bool,
103
+ tile_n: int,
104
+ ) -> Callable[..., None]:
105
+ """Build the GRPO kernel specialized on compile-time flags.
106
+
107
+ On the fwd+bwd path each CTA does a mask-only prescan to derive
108
+ ``row_scale`` before the main pass; this lets the main pass fold
109
+ loss accumulation and gradient compute into a single read of the
110
+ logprob tensors. On the fwd-only path the prescan is skipped — the
111
+ main pass accumulates ``valid`` directly.
112
+ """
113
+
114
+ @cute.kernel
115
+ def _grpo_loss_kernel(
116
+ policy: cute.Tensor,
117
+ old_policy: cute.Tensor,
118
+ ref: cute.Tensor,
119
+ advantages: cute.Tensor,
120
+ completions_mask: cute.Tensor,
121
+ dpolicy: cute.Tensor, # (N, C) when compute_backward; (N, 1) dummy otherwise
122
+ policy_acc: cute.Tensor, # (1,) fp32 — global scalar loss accumulator
123
+ counter: cute.Tensor, # (1,) int32 — global last-block detection
124
+ output: cute.Tensor,
125
+ epsilon: cutlass.Float32,
126
+ epsilon_high: cutlass.Float32,
127
+ beta: cutlass.Float32,
128
+ inv_n_rows: cutlass.Float32,
129
+ n_rows: cutlass.Int32,
130
+ num_full_tiles: cutlass.Int32,
131
+ tail_len: cutlass.Int32,
132
+ num_col_tiles: cutlass.Int32,
133
+ ) -> None:
134
+ block_size = NUM_WARPS * 32
135
+ iters = tile_n // (block_size * VEC)
136
+
137
+ _no_alloc = cute.nvgpu.CacheEvictionPriority.NO_ALLOCATE
138
+ g2r_op = cute.nvgpu.CopyUniversalOp()
139
+ # Logprob loads stream once -- ``NO_ALLOCATE`` keeps them out of L1
140
+ # so they don't evict the mask data we *do* re-read. The current
141
+ # single-pass main loop never re-reads pol/old/ref; only the mask
142
+ # is touched twice (fwd+bwd prescan + main pass).
143
+ g2r_atom = cute.make_copy_atom(
144
+ g2r_op,
145
+ policy.element_type,
146
+ num_bits_per_copy=0,
147
+ l1c_evict_priority=_no_alloc,
148
+ )
149
+ # Mask is read twice on the fwd+bwd path (prescan + main), so bias
150
+ # L1 to keep it. On fwd-only the prescan does not run, so the hint
151
+ # is unused and falls through to the streaming default.
152
+ mask_evict = cute.nvgpu.CacheEvictionPriority.EVICT_LAST if compute_backward else _no_alloc
153
+ g2r_mask_atom = cute.make_copy_atom(
154
+ g2r_op,
155
+ completions_mask.element_type,
156
+ num_bits_per_copy=0,
157
+ l1c_evict_priority=mask_evict,
158
+ )
159
+ if cutlass.const_expr(compute_backward):
160
+ r2g_atom = cute.make_copy_atom(
161
+ g2r_op,
162
+ dpolicy.element_type,
163
+ num_bits_per_copy=0,
164
+ )
165
+
166
+ row = cute.arch.block_idx()[0]
167
+ tid = cute.arch.thread_idx()[0]
168
+ lane_idx = cute.arch.lane_idx()
169
+ warp_idx = cute.arch.warp_idx()
170
+
171
+ adv = cast(advantages[row], cutlass.Float32)
172
+ lo = cutlass.Float32(1.0) - epsilon
173
+ hi = cutlass.Float32(1.0) + epsilon_high
174
+
175
+ pol_row = cute.slice_(policy, (row, None))
176
+ old_row = cute.slice_(old_policy, (row, None))
177
+ mask_row = cute.slice_(completions_mask, (row, None))
178
+ if cutlass.const_expr(compute_kl):
179
+ ref_row = cute.slice_(ref, (row, None))
180
+ if cutlass.const_expr(compute_backward):
181
+ dp_row = cute.slice_(dpolicy, (row, None))
182
+
183
+ smem = cutlass.utils.SmemAllocator()
184
+ buf_policy = smem.allocate_tensor(cutlass.Float32, cute.make_layout(NUM_WARPS))
185
+ if cutlass.const_expr(compute_kl):
186
+ buf_kl = smem.allocate_tensor(cutlass.Float32, cute.make_layout(NUM_WARPS))
187
+ # Always need a valid buffer in fwd-only path; on fwd+bwd path the
188
+ # prescan also reuses cross-warp SMEM reduction.
189
+ buf_valid = smem.allocate_tensor(cutlass.Float32, cute.make_layout(NUM_WARPS))
190
+ if cutlass.const_expr(compute_backward):
191
+ row_scale_smem = smem.allocate_tensor(cutlass.Float32, cute.make_layout(1))
192
+
193
+ # ---- Stage A (fwd+bwd only): mask-only prescan → row_scale ----
194
+ if cutlass.const_expr(compute_backward):
195
+ local_valid_pre = cutlass.Float32(0.0)
196
+ for col_block in cutlass.range(num_col_tiles, unroll=1):
197
+ mask_slab = cute.local_tile(mask_row, (tile_n,), (col_block,))
198
+ if col_block < num_full_tiles:
199
+ for k in cutlass.range(iters, unroll_full=True):
200
+ sub_idx = tid + k * block_size
201
+ mask_src = cute.local_tile(mask_slab, (VEC,), (sub_idx,))
202
+ mask_frag = cute.make_fragment_like(mask_src)
203
+ cute.copy(g2r_mask_atom, mask_src, mask_frag)
204
+ local_valid_pre += (
205
+ mask_frag.load()
206
+ .to(cutlass.Float32)
207
+ .reduce(
208
+ cute.ReductionOp.ADD,
209
+ cutlass.Float32(0.0),
210
+ reduction_profile=0,
211
+ )
212
+ )
213
+ else:
214
+ for k in cutlass.range(iters, unroll_full=True):
215
+ sub_idx = tid + k * block_size
216
+ chunk_base = sub_idx * VEC
217
+ if chunk_base < tail_len:
218
+ mask_src = cute.local_tile(mask_slab, (VEC,), (sub_idx,))
219
+ pred = cute.make_rmem_tensor(mask_src.shape, cutlass.Boolean)
220
+ for v in cutlass.range(VEC, unroll_full=True):
221
+ pred[v] = cute.elem_less(chunk_base + v, tail_len)
222
+ mask_frag = cute.make_fragment_like(mask_src)
223
+ mask_frag.fill(0)
224
+ cute.copy(g2r_mask_atom, mask_src, mask_frag, pred=pred)
225
+ local_valid_pre += (
226
+ mask_frag.load()
227
+ .to(cutlass.Float32)
228
+ .reduce(
229
+ cute.ReductionOp.ADD,
230
+ cutlass.Float32(0.0),
231
+ reduction_profile=0,
232
+ )
233
+ )
234
+
235
+ warp_valid_pre = cute.arch.warp_reduction(local_valid_pre, operator.add)
236
+ if lane_idx == 0:
237
+ buf_valid[warp_idx] = warp_valid_pre
238
+ cute.arch.barrier()
239
+
240
+ if warp_idx == 0:
241
+ lane_in_warp_range = lane_idx < NUM_WARPS
242
+ val_v = cutlass.Float32(0.0)
243
+ if lane_in_warp_range:
244
+ val_v = buf_valid[lane_idx]
245
+ block_valid_pre = cute.arch.warp_reduction(
246
+ val_v, operator.add, threads_in_group=NUM_WARPS
247
+ )
248
+ if lane_idx == 0:
249
+ n_v_row = cute.arch.fmax(block_valid_pre, cutlass.Float32(1.0))
250
+ row_scale_smem[0] = cute.arch.rcp_approx(n_v_row) * inv_n_rows
251
+ cute.arch.barrier()
252
+ row_scale_v = row_scale_smem[0]
253
+
254
+ # ---- Stage B: main pass — loss accumulation + (fused) gradient ----
255
+ local_policy_sum = cutlass.Float32(0.0)
256
+ local_kl_sum = cutlass.Float32(0.0)
257
+ # Only the fwd-only path needs to accumulate valid here; the
258
+ # fwd+bwd path already produced ``row_scale_v`` from the prescan.
259
+ local_valid_sum = cutlass.Float32(0.0)
260
+
261
+ for col_block in cutlass.range(num_col_tiles, unroll=1):
262
+ if col_block < num_full_tiles:
263
+ pol_slab = cute.local_tile(pol_row, (tile_n,), (col_block,))
264
+ old_slab = cute.local_tile(old_row, (tile_n,), (col_block,))
265
+ mask_slab = cute.local_tile(mask_row, (tile_n,), (col_block,))
266
+ if cutlass.const_expr(compute_kl):
267
+ ref_slab = cute.local_tile(ref_row, (tile_n,), (col_block,))
268
+ if cutlass.const_expr(compute_backward):
269
+ dp_slab = cute.local_tile(dp_row, (tile_n,), (col_block,))
270
+
271
+ for k in cutlass.range(iters, unroll_full=True):
272
+ sub_idx = tid + k * block_size
273
+
274
+ pol_src = cute.local_tile(pol_slab, (VEC,), (sub_idx,))
275
+ old_src = cute.local_tile(old_slab, (VEC,), (sub_idx,))
276
+ pol_frag = cute.make_fragment_like(pol_src)
277
+ old_frag = cute.make_fragment_like(old_src)
278
+ cute.copy(g2r_atom, pol_src, pol_frag)
279
+ cute.copy(g2r_atom, old_src, old_frag)
280
+
281
+ mask_src = cute.local_tile(mask_slab, (VEC,), (sub_idx,))
282
+ mask_frag = cute.make_fragment_like(mask_src)
283
+ cute.copy(g2r_mask_atom, mask_src, mask_frag)
284
+
285
+ pol_vec = pol_frag.load().to(cutlass.Float32)
286
+ old_vec = old_frag.load().to(cutlass.Float32)
287
+ mask_vec = mask_frag.load().to(cutlass.Float32)
288
+
289
+ log_ratio = pol_vec - old_vec
290
+ ratio = cute.math.exp2(log_ratio * _LOG2_E, fastmath=True)
291
+ surrogate = ratio * adv
292
+ clipped_ratio = cute.where(
293
+ ratio < lo,
294
+ lo,
295
+ cute.where(ratio > hi, hi, ratio),
296
+ )
297
+ clipped = clipped_ratio * adv
298
+ policy_loss = -cute.where(surrogate < clipped, surrogate, clipped)
299
+
300
+ # Compute the (negated) surrogate gradient before the KL
301
+ # block so ``surrogate``/``clipped`` can die before the
302
+ # ref load. ``-(adv * ratio)`` is just ``-surrogate`` —
303
+ # saves one FMUL per element.
304
+ if cutlass.const_expr(compute_backward):
305
+ neg_surrogate_grad = cute.where(
306
+ surrogate <= clipped,
307
+ -surrogate,
308
+ cutlass.Float32(0.0),
309
+ )
310
+
311
+ # Reduce ``policy_loss`` immediately so it can be freed
312
+ # before any further work.
313
+ local_policy_sum += (policy_loss * mask_vec).reduce(
314
+ cute.ReductionOp.ADD,
315
+ cutlass.Float32(0.0),
316
+ reduction_profile=0,
317
+ )
318
+ if cutlass.const_expr(not compute_backward):
319
+ local_valid_sum += mask_vec.reduce(
320
+ cute.ReductionOp.ADD,
321
+ cutlass.Float32(0.0),
322
+ reduction_profile=0,
323
+ )
324
+
325
+ if cutlass.const_expr(compute_kl):
326
+ ref_src = cute.local_tile(ref_slab, (VEC,), (sub_idx,))
327
+ ref_frag = cute.make_fragment_like(ref_src)
328
+ cute.copy(g2r_atom, ref_src, ref_frag)
329
+ ref_vec = ref_frag.load().to(cutlass.Float32)
330
+ log_ratio_ref = ref_vec - pol_vec
331
+ ratio_ref = cute.math.exp2(log_ratio_ref * _LOG2_E, fastmath=True)
332
+ # FFMA-friendly rearrangement: ``(ratio_ref - 1) - log_ratio_ref``
333
+ # exposes a ``ratio_ref + (-1)`` pair that ptxas folds
334
+ # with the subsequent subtract — fewer FADDs surviving
335
+ # SASS than the original 3-term ``a - b - c``.
336
+ kl_val = (ratio_ref - cutlass.Float32(1.0)) - log_ratio_ref
337
+ local_kl_sum += (kl_val * mask_vec).reduce(
338
+ cute.ReductionOp.ADD,
339
+ cutlass.Float32(0.0),
340
+ reduction_profile=0,
341
+ )
342
+
343
+ if cutlass.const_expr(compute_backward):
344
+ if cutlass.const_expr(compute_kl):
345
+ # ``beta - beta*ratio_ref`` instead of ``beta*(1 - ratio_ref)``
346
+ # gives ptxas an obvious FFMA pattern (``FFMA -beta,
347
+ # ratio_ref, beta``) — saves one FMUL per element.
348
+ kl_grad = beta - beta * ratio_ref
349
+ grad_vec = (neg_surrogate_grad + kl_grad) * mask_vec
350
+ else:
351
+ grad_vec = neg_surrogate_grad * mask_vec
352
+ scaled = grad_vec * row_scale_v
353
+
354
+ dp_dst = cute.local_tile(dp_slab, (VEC,), (sub_idx,))
355
+ dp_frag = cute.make_fragment_like(dp_dst)
356
+ dp_frag.store(scaled.to(dpolicy.element_type))
357
+ cute.copy(r2g_atom, dp_frag, dp_dst)
358
+ else:
359
+ pol_slab = cute.local_tile(pol_row, (tile_n,), (col_block,))
360
+ old_slab = cute.local_tile(old_row, (tile_n,), (col_block,))
361
+ mask_slab = cute.local_tile(mask_row, (tile_n,), (col_block,))
362
+ if cutlass.const_expr(compute_kl):
363
+ ref_slab = cute.local_tile(ref_row, (tile_n,), (col_block,))
364
+ if cutlass.const_expr(compute_backward):
365
+ dp_slab = cute.local_tile(dp_row, (tile_n,), (col_block,))
366
+
367
+ for k in cutlass.range(iters, unroll_full=True):
368
+ sub_idx = tid + k * block_size
369
+ chunk_base = sub_idx * VEC
370
+
371
+ if chunk_base < tail_len:
372
+ pol_src = cute.local_tile(pol_slab, (VEC,), (sub_idx,))
373
+ old_src = cute.local_tile(old_slab, (VEC,), (sub_idx,))
374
+ pred = cute.make_rmem_tensor(pol_src.shape, cutlass.Boolean)
375
+ for v in cutlass.range(VEC, unroll_full=True):
376
+ pred[v] = cute.elem_less(chunk_base + v, tail_len)
377
+
378
+ pol_frag = cute.make_fragment_like(pol_src)
379
+ old_frag = cute.make_fragment_like(old_src)
380
+ pol_frag.fill(0.0)
381
+ old_frag.fill(0.0)
382
+ cute.copy(g2r_atom, pol_src, pol_frag, pred=pred)
383
+ cute.copy(g2r_atom, old_src, old_frag, pred=pred)
384
+
385
+ mask_src = cute.local_tile(mask_slab, (VEC,), (sub_idx,))
386
+ mask_frag = cute.make_fragment_like(mask_src)
387
+ mask_frag.fill(0)
388
+ cute.copy(g2r_mask_atom, mask_src, mask_frag, pred=pred)
389
+
390
+ pol_vec = pol_frag.load().to(cutlass.Float32)
391
+ old_vec = old_frag.load().to(cutlass.Float32)
392
+ valid_vec = cute.where(
393
+ pred.load(),
394
+ cute.full_like(pol_vec, cutlass.Float32(1.0)),
395
+ cute.zeros_like(pol_vec, dtype=cutlass.Float32),
396
+ )
397
+ # ``mask_vec * valid_vec`` zeros out-of-bounds lanes.
398
+ mask_vec = mask_frag.load().to(cutlass.Float32) * valid_vec
399
+
400
+ log_ratio = pol_vec - old_vec
401
+ ratio = cute.math.exp2(log_ratio * _LOG2_E, fastmath=True)
402
+ surrogate = ratio * adv
403
+ clipped_ratio = cute.where(
404
+ ratio < lo,
405
+ lo,
406
+ cute.where(ratio > hi, hi, ratio),
407
+ )
408
+ clipped = clipped_ratio * adv
409
+ policy_loss = -cute.where(surrogate < clipped, surrogate, clipped)
410
+
411
+ # Same live-range narrowing as the full-tile path:
412
+ # fold ``surrogate``/``clipped`` into the gradient
413
+ # term and reduce ``policy_loss`` before the KL
414
+ # block so they can be freed. ``-(adv * ratio)`` is
415
+ # ``-surrogate`` — saves one FMUL per element.
416
+ if cutlass.const_expr(compute_backward):
417
+ neg_surrogate_grad = cute.where(
418
+ surrogate <= clipped,
419
+ -surrogate,
420
+ cutlass.Float32(0.0),
421
+ )
422
+
423
+ local_policy_sum += (policy_loss * mask_vec).reduce(
424
+ cute.ReductionOp.ADD,
425
+ cutlass.Float32(0.0),
426
+ reduction_profile=0,
427
+ )
428
+ if cutlass.const_expr(not compute_backward):
429
+ local_valid_sum += mask_vec.reduce(
430
+ cute.ReductionOp.ADD,
431
+ cutlass.Float32(0.0),
432
+ reduction_profile=0,
433
+ )
434
+
435
+ if cutlass.const_expr(compute_kl):
436
+ ref_src = cute.local_tile(ref_slab, (VEC,), (sub_idx,))
437
+ ref_frag = cute.make_fragment_like(ref_src)
438
+ ref_frag.fill(0.0)
439
+ cute.copy(g2r_atom, ref_src, ref_frag, pred=pred)
440
+ ref_vec = ref_frag.load().to(cutlass.Float32)
441
+ log_ratio_ref = ref_vec - pol_vec
442
+ ratio_ref = cute.math.exp2(log_ratio_ref * _LOG2_E, fastmath=True)
443
+ # See full-tile path: FFMA-friendly rearrangement.
444
+ kl_val = (ratio_ref - cutlass.Float32(1.0)) - log_ratio_ref
445
+ local_kl_sum += (kl_val * mask_vec).reduce(
446
+ cute.ReductionOp.ADD,
447
+ cutlass.Float32(0.0),
448
+ reduction_profile=0,
449
+ )
450
+
451
+ if cutlass.const_expr(compute_backward):
452
+ if cutlass.const_expr(compute_kl):
453
+ # See full-tile path: ``beta - beta*ratio_ref``
454
+ # is FFMA-friendly.
455
+ kl_grad = beta - beta * ratio_ref
456
+ grad_vec = (neg_surrogate_grad + kl_grad) * mask_vec
457
+ else:
458
+ grad_vec = neg_surrogate_grad * mask_vec
459
+ scaled = grad_vec * row_scale_v
460
+
461
+ dp_dst = cute.local_tile(dp_slab, (VEC,), (sub_idx,))
462
+ dp_frag = cute.make_fragment_like(dp_dst)
463
+ dp_frag.store(scaled.to(dpolicy.element_type))
464
+ cute.copy(r2g_atom, dp_frag, dp_dst, pred=pred)
465
+
466
+ # ---- Stage C: warp + cross-warp reduction → atomic-add row_loss ----
467
+ warp_policy = cute.arch.warp_reduction(local_policy_sum, operator.add)
468
+ if cutlass.const_expr(compute_kl):
469
+ warp_kl = cute.arch.warp_reduction(local_kl_sum, operator.add)
470
+ if cutlass.const_expr(not compute_backward):
471
+ warp_valid = cute.arch.warp_reduction(local_valid_sum, operator.add)
472
+
473
+ # The fwd+bwd path used ``buf_valid`` for the prescan reduction;
474
+ # ensure all threads have observed ``row_scale_smem`` (Stage A's
475
+ # final barrier) before we reuse ``buf_policy`` / ``buf_valid``.
476
+ # Stage A never runs on the fwd-only path, so the barrier is
477
+ # only needed when ``compute_backward`` is set.
478
+ if cutlass.const_expr(compute_backward):
479
+ cute.arch.barrier()
480
+
481
+ if lane_idx == 0:
482
+ buf_policy[warp_idx] = warp_policy
483
+ if cutlass.const_expr(compute_kl):
484
+ buf_kl[warp_idx] = warp_kl
485
+ if cutlass.const_expr(not compute_backward):
486
+ buf_valid[warp_idx] = warp_valid
487
+ cute.arch.barrier()
488
+
489
+ if warp_idx == 0:
490
+ lane_in_warp_range = lane_idx < NUM_WARPS
491
+
492
+ val_p = cutlass.Float32(0.0)
493
+ if lane_in_warp_range:
494
+ val_p = buf_policy[lane_idx]
495
+ block_policy = cute.arch.warp_reduction(val_p, operator.add, threads_in_group=NUM_WARPS)
496
+
497
+ block_kl = cutlass.Float32(0.0)
498
+ if cutlass.const_expr(compute_kl):
499
+ val_k = cutlass.Float32(0.0)
500
+ if lane_in_warp_range:
501
+ val_k = buf_kl[lane_idx]
502
+ block_kl = cute.arch.warp_reduction(val_k, operator.add, threads_in_group=NUM_WARPS)
503
+
504
+ if cutlass.const_expr(not compute_backward):
505
+ val_v = cutlass.Float32(0.0)
506
+ if lane_in_warp_range:
507
+ val_v = buf_valid[lane_idx]
508
+ block_valid = cute.arch.warp_reduction(
509
+ val_v, operator.add, threads_in_group=NUM_WARPS
510
+ )
511
+
512
+ if lane_idx == 0:
513
+ if cutlass.const_expr(compute_backward):
514
+ row_scale = row_scale_smem[0]
515
+ else:
516
+ n_v_row = cute.arch.fmax(block_valid, cutlass.Float32(1.0))
517
+ row_scale = cute.arch.rcp_approx(n_v_row) * inv_n_rows
518
+
519
+ if cutlass.const_expr(compute_kl):
520
+ row_loss = (block_policy + beta * block_kl) * row_scale
521
+ else:
522
+ row_loss = block_policy * row_scale
523
+
524
+ loss_ptr = policy_acc.iterator.toint().ir_value() # ty: ignore[unresolved-attribute]
525
+ _atomic_add_f32_gmem(loss_ptr, row_loss)
526
+
527
+ # ---- Stage D: last-block detection → write final loss ----
528
+ if warp_idx == 0:
529
+ is_last_lane0 = cutlass.Int32(0)
530
+ if lane_idx == 0:
531
+ counter_ptr = counter.iterator.toint().ir_value() # ty: ignore[unresolved-attribute]
532
+ cute.arch.fence_acq_rel_gpu()
533
+ old = _atomic_inc_u32_gmem(counter_ptr, n_rows - 1)
534
+ if old == n_rows - 1:
535
+ is_last_lane0 = cutlass.Int32(1)
536
+
537
+ is_last = cute.arch.shuffle_sync(is_last_lane0, 0)
538
+
539
+ if is_last == cutlass.Int32(1) and lane_idx == 0:
540
+ # Each CTA already atomic-added its scaled ``row_loss``
541
+ # (including ``inv_n_rows``); the accumulator now holds
542
+ # the final loss.
543
+ total = policy_acc[0]
544
+ output[0] = cast(total, output.element_type) # ty: ignore[invalid-argument-type]
545
+ # Reset for the next call.
546
+ policy_acc[0] = cutlass.Float32(0.0)
547
+
548
+ return _grpo_loss_kernel
549
+
550
+
551
+ # ---------------------------------------------------------------------------
552
+ # Compile-and-run factory
553
+ # ---------------------------------------------------------------------------
554
+
555
+
556
+ def create_compiled_grpo_loss(
557
+ policy_dtype: torch.dtype,
558
+ epsilon: float,
559
+ epsilon_high: float,
560
+ beta: float,
561
+ compute_backward: bool = False,
562
+ ) -> Callable[..., torch.Tensor | tuple[torch.Tensor, torch.Tensor]]:
563
+ """Compile the GRPO loss kernel and return a runtime closure.
564
+
565
+ Two compiled variants — small-tile (``TILE_N``) and large-tile
566
+ (``TILE_N_LARGE``) — are produced; the runner selects one per call
567
+ based on ``seq_len`` vs. ``TILE_N_LARGE_THRESHOLD``.
568
+
569
+ When ``compute_backward=True`` the kernel additionally writes the
570
+ scaled gradient ``dL/d(policy_logprobs)`` to a caller-provided
571
+ ``dpolicy`` tensor in the same launch.
572
+ """
573
+ compute_kl = beta != 0.0
574
+
575
+ if policy_dtype not in _TORCH_TO_CUTLASS_DTYPE:
576
+ raise ValueError(f"Unsupported dtype for GRPO kernel: {policy_dtype}")
577
+
578
+ tile_n_small = TILE_N
579
+ tile_n_large = TILE_N_LARGE
580
+ seq_len_threshold = TILE_N_LARGE_THRESHOLD
581
+ block_size = NUM_WARPS * 32
582
+ if tile_n_small % (block_size * VEC) != 0:
583
+ raise ValueError(
584
+ f"TILE_N={tile_n_small} must be a multiple of BLOCK_SIZE*VEC={block_size * VEC}"
585
+ )
586
+ if tile_n_large % (block_size * VEC) != 0:
587
+ raise ValueError(
588
+ f"TILE_N_LARGE={tile_n_large} must be a multiple of BLOCK_SIZE*VEC={block_size * VEC}"
589
+ )
590
+
591
+ n_rows_sym = cute.sym_int()
592
+ seq_len_sym = cute.sym_int()
593
+ cute_dtype = _TORCH_TO_CUTLASS_DTYPE[policy_dtype]
594
+
595
+ def _fake2d(dt: Any, cols: Any) -> Any:
596
+ return cute.runtime.make_fake_compact_tensor(
597
+ dt,
598
+ (n_rows_sym, cols),
599
+ stride_order=(1, 0),
600
+ assumed_align=16,
601
+ )
602
+
603
+ fake_pol = _fake2d(cute_dtype, seq_len_sym)
604
+ fake_old = _fake2d(cute_dtype, seq_len_sym)
605
+ fake_ref = _fake2d(cute_dtype, seq_len_sym)
606
+ fake_adv = cute.runtime.make_fake_compact_tensor(
607
+ cute_dtype,
608
+ (n_rows_sym,),
609
+ assumed_align=16,
610
+ )
611
+ fake_mask = cute.runtime.make_fake_compact_tensor(
612
+ cutlass.Int8,
613
+ (n_rows_sym, seq_len_sym),
614
+ stride_order=(1, 0),
615
+ assumed_align=16,
616
+ )
617
+ dpolicy_cols = seq_len_sym if compute_backward else 1
618
+ fake_dpolicy = cute.runtime.make_fake_compact_tensor(
619
+ cute_dtype,
620
+ (n_rows_sym, dpolicy_cols),
621
+ stride_order=(1, 0),
622
+ assumed_align=16,
623
+ )
624
+ fake_policy_acc = cute.runtime.make_fake_compact_tensor(
625
+ cutlass.Float32,
626
+ (1,),
627
+ assumed_align=16,
628
+ )
629
+ fake_counter = cute.runtime.make_fake_compact_tensor(
630
+ cutlass.Int32,
631
+ (1,),
632
+ assumed_align=16,
633
+ )
634
+ fake_output = cute.runtime.make_fake_compact_tensor(
635
+ cute_dtype,
636
+ (1,),
637
+ assumed_align=16,
638
+ )
639
+
640
+ def _build_launch(tile_n_v: int) -> Callable[..., None]:
641
+ """Build a JIT-compiled launcher specialized for ``tile_n_v``.
642
+
643
+ Bakes ``compute_kl``, ``compute_backward``, and the column-tile
644
+ width into the kernel as compile-time constants and returns a
645
+ ``@cute.jit`` wrapper that forwards runtime tensors/scalars to
646
+ ``.launch()`` with ``grid=(n_rows, 1, 1)`` and
647
+ ``block=(NUM_WARPS*32, 1, 1)``.
648
+ """
649
+ specialized_kernel = _make_grpo_kernel(compute_kl, compute_backward, tile_n_v)
650
+
651
+ @cute.jit
652
+ def _launch(
653
+ pol_ct: cute.Tensor,
654
+ old_ct: cute.Tensor,
655
+ ref_ct: cute.Tensor,
656
+ adv_ct: cute.Tensor,
657
+ mask_ct: cute.Tensor,
658
+ dpolicy_ct: cute.Tensor,
659
+ policy_acc_ct: cute.Tensor,
660
+ counter_ct: cute.Tensor,
661
+ output_ct: cute.Tensor,
662
+ epsilon_v: cutlass.Float32,
663
+ epsilon_high_v: cutlass.Float32,
664
+ beta_v: cutlass.Float32,
665
+ inv_n_rows_v: cutlass.Float32,
666
+ n_rows_v: cutlass.Int32,
667
+ num_full_tiles_v: cutlass.Int32,
668
+ tail_len_v: cutlass.Int32,
669
+ num_col_tiles_v: cutlass.Int32,
670
+ ) -> None:
671
+ specialized_kernel( # ty: ignore[unresolved-attribute]
672
+ pol_ct,
673
+ old_ct,
674
+ ref_ct,
675
+ adv_ct,
676
+ mask_ct,
677
+ dpolicy_ct,
678
+ policy_acc_ct,
679
+ counter_ct,
680
+ output_ct,
681
+ epsilon_v,
682
+ epsilon_high_v,
683
+ beta_v,
684
+ inv_n_rows_v,
685
+ n_rows_v,
686
+ num_full_tiles_v,
687
+ tail_len_v,
688
+ num_col_tiles_v,
689
+ ).launch(
690
+ grid=(n_rows_v, 1, 1),
691
+ block=(NUM_WARPS * 32, 1, 1),
692
+ )
693
+
694
+ return _launch
695
+
696
+ def _compile_launch(launch_fn: Callable[..., None]) -> Callable[..., None]:
697
+ return cute.compile(
698
+ launch_fn,
699
+ fake_pol,
700
+ fake_old,
701
+ fake_ref,
702
+ fake_adv,
703
+ fake_mask,
704
+ fake_dpolicy,
705
+ fake_policy_acc,
706
+ fake_counter,
707
+ fake_output,
708
+ cutlass.Float32(epsilon),
709
+ cutlass.Float32(epsilon_high),
710
+ cutlass.Float32(beta),
711
+ cutlass.Float32(1.0),
712
+ cutlass.Int32(1),
713
+ cutlass.Int32(1),
714
+ cutlass.Int32(0),
715
+ cutlass.Int32(1),
716
+ options="--enable-tvm-ffi",
717
+ )
718
+
719
+ compiled_small = _compile_launch(_build_launch(tile_n_small))
720
+ if tile_n_large == tile_n_small:
721
+ compiled_large = compiled_small
722
+ else:
723
+ compiled_large = _compile_launch(_build_launch(tile_n_large))
724
+
725
+ eps_const = cutlass.Float32(epsilon)
726
+ eps_high_const = cutlass.Float32(epsilon_high)
727
+ beta_const = cutlass.Float32(beta)
728
+
729
+ # Cross-CTA scratch slab — one int32 buffer with stride-4 (16-byte) slices
730
+ # so each slot is individually 16-byte aligned (``assumed_align=16`` at
731
+ # compile time). Bit-pattern of int32 0 equals fp32 0.0, so a single
732
+ # ``zeros`` factory legitimately initialises both the int32 counter and
733
+ # the fp32 ``policy_acc``. The kernel's last block self-resets
734
+ # ``policy_acc`` in its epilogue and the counter self-resets via
735
+ # ``atom.inc.u32`` wrap-around, so the up-front ``torch.zeros`` only
736
+ # matters for the very first call. Allocated lazily on first ``_run``
737
+ # call when the device is known.
738
+ _scratch: list[torch.Tensor | None] = [None]
739
+
740
+ def _ensure_scratch(device: torch.device) -> tuple[torch.Tensor, torch.Tensor]:
741
+ s = _scratch[0]
742
+ if s is None or s.device != device:
743
+ s = torch.zeros(8, dtype=torch.int32, device=device)
744
+ _scratch[0] = s
745
+ return (
746
+ s[0:1], # counter (int32)
747
+ s[4:5].view(torch.float32), # policy_acc (fp32)
748
+ )
749
+
750
+ def _run(
751
+ policy_logprobs_r: torch.Tensor,
752
+ old_policy_logprobs_r: torch.Tensor,
753
+ ref_logprobs_r: torch.Tensor,
754
+ advantages_r: torch.Tensor,
755
+ completions_mask_r: torch.Tensor,
756
+ ) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
757
+ n_rows, seq_len = policy_logprobs_r.shape
758
+ device = policy_logprobs_r.device
759
+ dtype = policy_logprobs_r.dtype
760
+
761
+ if seq_len >= seq_len_threshold:
762
+ tile_n_active = tile_n_large
763
+ compiled_active = compiled_large
764
+ else:
765
+ tile_n_active = tile_n_small
766
+ compiled_active = compiled_small
767
+ num_full_tiles = seq_len // tile_n_active
768
+ tail_len = seq_len % tile_n_active
769
+ num_col_tiles = num_full_tiles + (1 if tail_len > 0 else 0)
770
+ inv_n_rows = cutlass.Float32(1.0 / float(n_rows))
771
+
772
+ # Per-call write-only buffers — ``empty`` is enough (Liger / TE pattern).
773
+ output_r = torch.empty(1, dtype=dtype, device=device)
774
+ if compute_backward:
775
+ dpolicy_r = torch.empty_like(policy_logprobs_r)
776
+ else:
777
+ dpolicy_r = torch.empty(n_rows, 1, dtype=dtype, device=device)
778
+
779
+ counter_r, policy_acc_r = _ensure_scratch(device)
780
+
781
+ compiled_active(
782
+ policy_logprobs_r,
783
+ old_policy_logprobs_r,
784
+ ref_logprobs_r,
785
+ advantages_r,
786
+ completions_mask_r,
787
+ dpolicy_r,
788
+ policy_acc_r,
789
+ counter_r,
790
+ output_r,
791
+ eps_const,
792
+ eps_high_const,
793
+ beta_const,
794
+ inv_n_rows,
795
+ cutlass.Int32(n_rows),
796
+ cutlass.Int32(num_full_tiles),
797
+ cutlass.Int32(tail_len),
798
+ cutlass.Int32(num_col_tiles),
799
+ )
800
+ out_view = output_r.view(())
801
+ if compute_backward:
802
+ return out_view, dpolicy_r
803
+ return out_view
804
+
805
+ return _run
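An illustrative sketch of the runner's tile-tier arithmetic (small vs. large tile, tail handling) using the constants defined in this file; the helper name is hypothetical::

    TILE_N, TILE_N_LARGE, THRESHOLD = 2048, 8192, 8192

    def tier(seq_len: int) -> tuple[int, int, int, int]:
        tile = TILE_N_LARGE if seq_len >= THRESHOLD else TILE_N
        num_full_tiles = seq_len // tile
        tail_len = seq_len % tile
        num_col_tiles = num_full_tiles + (1 if tail_len else 0)
        return tile, num_full_tiles, tail_len, num_col_tiles

    print(tier(5000))   # (2048, 2, 904, 3): small tile, one predicated tail tile
    print(tier(16384))  # (8192, 2, 0, 2): large tile, full tiles only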
build/torch-cuda/layers.py ADDED
@@ -0,0 +1,258 @@
1
+ """HF Kernels layer adapters for ``kernelize()``.
2
+
3
+ These ``nn.Module`` classes are the entry points for users who want to
4
+ plug our cute kernels into a model via the ``kernels`` library's
5
+ ``kernelize`` flow. Two classes per kernel, one per supported mode:
6
+
7
+ * :class:`bnpoLoss` / :class:`grpoLoss` / :class:`ReverseKL`
8
+ — autograd-aware (``loss.backward()`` works). Register against
9
+ ``Mode.TRAINING`` and/or ``Mode.TRAINING | Mode.TORCH_COMPILE``.
10
+ * :class:`bnpoLossInference` / :class:`grpoLossInference` /
11
+ :class:`ReverseKLInference` — forward-only, no autograd
12
+ dispatcher. Register against ``Mode.INFERENCE`` for inference /
13
+ validation.
14
+
15
+ All are stateless (no ``__init__``, no member tensors) as required by
16
+ ``kernelize`` — it validates that layer classes don't add constructor
17
+ state. The ``has_backward`` and ``can_torch_compile`` attributes are
18
+ the only allowed extras and let ``kernelize`` choose the right layer
19
+ for the requested mode.
20
+
21
+ Forward-signature contract: a downstream user wraps the loss in their
22
+ own ``nn.Module`` (decorated with ``@use_kernel_forward_from_hub(...)``)
23
+ whose ``forward`` matches the signature here. ``kernelize`` swaps the
24
+ ``forward`` method by class identity, so the signature MUST line up
25
+ positionally with the user's module.
26
+
27
+ Typical user-side wiring::
28
+
29
+ from kernels import (
30
+ Mode, LayerRepository, kernelize, use_kernel_forward_from_hub,
31
+ )
32
+
33
+ @use_kernel_forward_from_hub("bnpoLoss")
34
+ class bnpoLoss(nn.Module):
35
+ def forward(self, policy, old, ref, adv, completions_mask,
36
+ epsilon=0.2, epsilon_high=0.2, beta=0.1):
37
+ ... # eager fallback
38
+
39
+ mapping = {
40
+ "bnpoLoss": {
41
+ "cuda": {
42
+ Mode.INFERENCE: LayerRepository(
43
+ repo_id="Geometric-AI/geometric-ai-kernels",
44
+ layer_name="bnpoLossInference",
45
+ ),
46
+ Mode.TRAINING: LayerRepository(
47
+ repo_id="Geometric-AI/geometric-ai-kernels",
48
+ layer_name="bnpoLoss",
49
+ ),
50
+ }
51
+ }
52
+ }
53
+
54
+ with use_kernel_mapping(mapping):
55
+ model = kernelize(model, mode=Mode.TRAINING | Mode.TORCH_COMPILE)
56
+ """
57
+
58
+ from __future__ import annotations
59
+
60
+ import torch
61
+ from torch import nn
62
+
63
+ from .bnpo_loss import bnpo_loss_fwd
64
+ from .bnpo_loss.autograd import bnpo_loss_autograd
65
+ from .grpo_loss import grpo_loss_fwd
66
+ from .grpo_loss.autograd import grpo_loss_autograd
67
+ from .reverse_kl import (
68
+ reverse_kl_autograd,
69
+ reverse_kl_fwd,
70
+ )
71
+
72
+ __all__ = [
73
+ "ReverseKL",
74
+ "ReverseKLInference",
75
+ "bnpoLoss",
76
+ "bnpoLossInference",
77
+ "grpoLoss",
78
+ "grpoLossInference",
79
+ ]
80
+
81
+
82
+ class bnpoLoss(nn.Module):
83
+ """Training-mode bnpo loss layer. ``loss.backward()`` works.
84
+
85
+ Routes through :func:`bnpo_loss_autograd`, which wraps the fused
86
+ cute kernel in a ``torch.library.custom_op`` with a registered
87
+ backward. Compatible with ``torch.compile`` (the op has a fake
88
+ kernel for shape propagation).
89
+ """
90
+
91
+ has_backward = True
92
+ can_torch_compile = True
93
+
94
+ def forward(
95
+ self,
96
+ policy_logprobs: torch.Tensor,
97
+ old_policy_logprobs: torch.Tensor,
98
+ ref_logprobs: torch.Tensor,
99
+ advantages: torch.Tensor,
100
+ completions_mask: torch.Tensor,
101
+ epsilon: float = 0.2,
102
+ epsilon_high: float = 0.2,
103
+ beta: float = 0.1,
104
+ ) -> torch.Tensor:
105
+ return bnpo_loss_autograd(
106
+ policy_logprobs,
107
+ old_policy_logprobs,
108
+ ref_logprobs,
109
+ advantages,
110
+ completions_mask,
111
+ epsilon=epsilon,
112
+ epsilon_high=epsilon_high,
113
+ beta=beta,
114
+ )
115
+
116
+
117
+ class bnpoLossInference(nn.Module):
118
+ """Inference / validation bnpo loss layer. No autograd dispatcher.
119
+
120
+ Routes through :func:`bnpo_loss_fwd` — the forward-only kernel that
121
+ computes the masked mean denominator on-GPU via the last-block
122
+ trick, skipping the dpolicy buffer entirely.
123
+ """
124
+
125
+ has_backward = False
126
+ can_torch_compile = True
127
+
128
+ def forward(
129
+ self,
130
+ policy_logprobs: torch.Tensor,
131
+ old_policy_logprobs: torch.Tensor,
132
+ ref_logprobs: torch.Tensor,
133
+ advantages: torch.Tensor,
134
+ completions_mask: torch.Tensor,
135
+ epsilon: float = 0.2,
136
+ epsilon_high: float = 0.2,
137
+ beta: float = 0.1,
138
+ ) -> torch.Tensor:
139
+ return bnpo_loss_fwd(
140
+ policy_logprobs,
141
+ old_policy_logprobs,
142
+ ref_logprobs,
143
+ advantages,
144
+ completions_mask,
145
+ epsilon=epsilon,
146
+ epsilon_high=epsilon_high,
147
+ beta=beta,
148
+ )
149
+
150
+
151
+ class grpoLoss(nn.Module):
152
+ """Training-mode GRPO loss layer. ``loss.backward()`` works.
153
+
154
+ Routes through :func:`grpo_loss_autograd`, which wraps the fused
155
+ cute kernel in a ``torch.library.custom_op`` with a registered
156
+ backward. ``completions_mask`` is required (GRPO's per-response
157
+ normalization is mask-derived).
158
+ """
159
+
160
+ has_backward = True
161
+ can_torch_compile = True
162
+
163
+ def forward(
164
+ self,
165
+ policy_logprobs: torch.Tensor,
166
+ old_policy_logprobs: torch.Tensor,
167
+ ref_logprobs: torch.Tensor,
168
+ advantages: torch.Tensor,
169
+ completions_mask: torch.Tensor,
170
+ epsilon: float = 0.2,
171
+ epsilon_high: float = 0.2,
172
+ beta: float = 0.1,
173
+ ) -> torch.Tensor:
174
+ return grpo_loss_autograd(
175
+ policy_logprobs,
176
+ old_policy_logprobs,
177
+ ref_logprobs,
178
+ advantages,
179
+ completions_mask,
180
+ epsilon=epsilon,
181
+ epsilon_high=epsilon_high,
182
+ beta=beta,
183
+ )
184
+
185
+
186
+ class grpoLossInference(nn.Module):
187
+ """Inference / validation GRPO loss layer. No autograd dispatcher.
188
+
189
+ Routes through :func:`grpo_loss_fwd` — the forward-only kernel,
190
+ skipping the dpolicy buffer entirely.
191
+ """
192
+
193
+ has_backward = False
194
+ can_torch_compile = True
195
+
196
+ def forward(
197
+ self,
198
+ policy_logprobs: torch.Tensor,
199
+ old_policy_logprobs: torch.Tensor,
200
+ ref_logprobs: torch.Tensor,
201
+ advantages: torch.Tensor,
202
+ completions_mask: torch.Tensor,
203
+ epsilon: float = 0.2,
204
+ epsilon_high: float = 0.2,
205
+ beta: float = 0.1,
206
+ ) -> torch.Tensor:
207
+ return grpo_loss_fwd(
208
+ policy_logprobs,
209
+ old_policy_logprobs,
210
+ ref_logprobs,
211
+ advantages,
212
+ completions_mask,
213
+ epsilon=epsilon,
214
+ epsilon_high=epsilon_high,
215
+ beta=beta,
216
+ )
217
+
218
+
219
+ class ReverseKL(nn.Module):
220
+ """Training-mode reverse-KL self-distillation loss layer.
221
+
222
+ Routes through :func:`reverse_kl_autograd`, which wraps
223
+ the fused cute kernel in a ``torch.library.custom_op`` with a
224
+ registered backward. ``loss.backward()`` propagates to whatever
225
+ produced ``student_logits``. Compatible with ``torch.compile`` (the
226
+ op has a fake kernel for shape propagation).
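+
+ Minimal call sketch (sizes and the CUDA device are illustrative)::
+
+ student = torch.randn(2, 16, 1024, device="cuda", requires_grad=True)
+ teacher = torch.randn(2, 16, 1024, device="cuda")
+ mask = torch.ones(2, 16, device="cuda")
+ loss = ReverseKL()(student, teacher, mask)
+ loss.backward()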
227
+ """
228
+
229
+ has_backward = True
230
+ can_torch_compile = True
231
+
232
+ def forward(
233
+ self,
234
+ student_logits: torch.Tensor,
235
+ teacher_logits: torch.Tensor,
236
+ completions_mask: torch.Tensor,
237
+ ) -> torch.Tensor:
238
+ return reverse_kl_autograd(student_logits, teacher_logits, completions_mask)
239
+
240
+
241
+ class ReverseKLInference(nn.Module):
242
+ """Inference / validation reverse-KL self-distillation loss layer.
243
+
244
+ Routes through :func:`reverse_kl_fwd` — the forward-only
245
+ kernel that dead-code-eliminates the gradient pass and skips the
246
+ grad-student buffer entirely.
247
+ """
248
+
249
+ has_backward = False
250
+ can_torch_compile = True
251
+
252
+ def forward(
253
+ self,
254
+ student_logits: torch.Tensor,
255
+ teacher_logits: torch.Tensor,
256
+ completions_mask: torch.Tensor,
257
+ ) -> torch.Tensor:
258
+ return reverse_kl_fwd(student_logits, teacher_logits, completions_mask)
build/torch-cuda/metadata.json ADDED
@@ -0,0 +1,12 @@
1
+ {
2
+ "id": "_geometric_ai_kernels_cuda_a766fbd_dirty",
3
+ "version": 0,
4
+ "license": "Apache-2.0",
5
+ "python-depends": [
6
+ "tvm-ffi",
7
+ "nvidia-cutlass-dsl"
8
+ ],
9
+ "backend": {
10
+ "type": "cuda"
11
+ }
12
+ }
build/torch-cuda/reverse_kl/__init__.py ADDED
@@ -0,0 +1,196 @@
1
+ """Reverse-KL self-distillation loss with CuteDSL fused fwd+bwd.
2
+
3
+ Three public APIs route to two compiled kernels:
4
+
5
+ * :func:`reverse_kl` — primary training entry point.
6
+ Returns ``(loss, grad_student_logits)`` from a single fused fwd+bwd
7
+ kernel launch. Inputs do **not** need ``requires_grad=True`` and there
8
+ is no ``torch.autograd.Function`` wrapper — chain the gradient into
9
+ the upstream model with ``student_logits.backward(grad)``.
10
+ * :func:`reverse_kl_fwd` — inference / validation path.
11
+ Returns the scalar ``loss`` from a forward-only kernel that
12
+ dead-code-eliminates the gradient pass.
13
+ * :func:`reverse_kl_autograd` — autograd-aware variant via
14
+ ``torch.library.custom_op``. Returns scalar ``loss``;
15
+ ``loss.backward()`` Just Works. Re-exported from
16
+ :mod:`reverse_kl.autograd`.
17
+
18
+ Per-call output and gradient buffers are allocated inside the runner;
19
+ cross-CTA scratch (atomic accumulators + counters + the constant
20
+ ``grad_output=1.0`` scalar) is owned by the compiled-kernel closure and
21
+ self-resets each launch — callers don't manage scratch state.
22
+
23
+ Why no autograd wrapper for :func:`reverse_kl`?
24
+ The reverse-KL gradient is closed-form: ``dL/d(student) = mask *
25
+ inv_n_valid * p * (log_p - log_q - kl_per_row)``. The fused kernel
26
+ already writes that analytically in the same launch as the loss.
27
+ Wrapping in ``torch.autograd.Function`` would cost an extra
28
+ ``grad_output * grad_student`` kernel on backward (~2× per-call overhead in
29
+ practice) and is opaque to ``torch.compile``. Use the autograd-aware
30
+ variant when you need ``loss.backward()`` ergonomics; pay only the
31
+ ``custom_op`` dispatcher cost (also Inductor-traceable).
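+
+ A plain-torch sanity check of that closed form (illustrative sizes,
+ float64 on CPU; not how the kernel computes it)::
+
+ s = torch.randn(8, 512, dtype=torch.float64, requires_grad=True)
+ t = torch.randn(8, 512, dtype=torch.float64)
+ mask = torch.ones(8, dtype=torch.float64)
+ log_p, log_q = torch.log_softmax(s, -1), torch.log_softmax(t, -1)
+ kl_row = (log_p.exp() * (log_p - log_q)).sum(-1)
+ ((mask * kl_row).sum() / mask.sum()).backward()
+ p, kl_row = log_p.detach().exp(), kl_row.detach()
+ closed = (mask / mask.sum())[:, None] * p * (log_p.detach() - log_q - kl_row[:, None])
+ assert torch.allclose(s.grad, closed)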
32
+ """
33
+
34
+ from __future__ import annotations
35
+
36
+ from functools import lru_cache
37
+ from typing import TYPE_CHECKING
38
+
39
+ import torch
40
+
41
+ from .cute_reverse_kl import create_compiled_reverse_kl
42
+
43
+ if TYPE_CHECKING:
44
+ from collections.abc import Callable
45
+
46
+
47
+ __all__ = [
48
+ "reverse_kl",
49
+ "reverse_kl_autograd",
50
+ "reverse_kl_fwd",
51
+ ]
52
+
53
+
54
+ @lru_cache(maxsize=32)
55
+ def _get_compiled_fwd(
56
+ dtype: torch.dtype,
57
+ vocab: int,
58
+ ) -> Callable[..., torch.Tensor]:
59
+ return create_compiled_reverse_kl( # ty: ignore[invalid-return-type]
60
+ policy_dtype=dtype,
61
+ vocab=vocab,
62
+ compute_backward=False,
63
+ )
64
+
65
+
66
+ @lru_cache(maxsize=32)
67
+ def _get_compiled_fwd_bwd(
68
+ dtype: torch.dtype,
69
+ vocab: int,
70
+ ) -> Callable[..., tuple[torch.Tensor, torch.Tensor]]:
71
+ return create_compiled_reverse_kl( # ty: ignore[invalid-return-type]
72
+ policy_dtype=dtype,
73
+ vocab=vocab,
74
+ compute_backward=True,
75
+ )
76
+
77
+
78
+ def _flatten_inputs(
79
+ student_logits: torch.Tensor,
80
+ teacher_logits: torch.Tensor,
81
+ completions_mask: torch.Tensor,
82
+ ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, tuple[int, ...], int, int]:
83
+ """Reshape ``(*, V)`` inputs to ``(num_rows, V)`` for the kernel.
84
+
85
+ The kernel works on a flat ``(num_rows, V)`` slab; the public API
86
+ accepts arbitrary leading dims (typically ``(N, C, V)``) and we
87
+ flatten here so the wrapper signature stays user-friendly. We
88
+ assume contiguous-on-V inputs — ``view`` becomes a no-copy reshape.
89
+ """
90
+ if student_logits.shape != teacher_logits.shape:
91
+ raise ValueError(
92
+ "student_logits and teacher_logits must have the same shape; got "
93
+ f"{tuple(student_logits.shape)} vs {tuple(teacher_logits.shape)}"
94
+ )
95
+ if student_logits.dtype != teacher_logits.dtype:
96
+ raise ValueError(
97
+ "student_logits and teacher_logits must have the same dtype; got "
98
+ f"{student_logits.dtype} vs {teacher_logits.dtype}"
99
+ )
100
+
101
+ leading_shape = tuple(student_logits.shape[:-1])
102
+ vocab = int(student_logits.shape[-1])
103
+
104
+ student_2d = student_logits.view(-1, vocab)
105
+ teacher_2d = teacher_logits.view(-1, vocab)
106
+ num_rows = student_2d.shape[0]
107
+
108
+ if tuple(completions_mask.shape) != leading_shape:
109
+ raise ValueError(
110
+ f"completions_mask must have shape {leading_shape} (logits' leading "
111
+ f"dims); got {tuple(completions_mask.shape)}"
112
+ )
113
+ flat_mask = completions_mask.view(-1)
114
+ if flat_mask.dtype != torch.float32:
115
+ flat_mask = flat_mask.to(torch.float32)
116
+
117
+ return student_2d, teacher_2d, flat_mask, leading_shape, num_rows, vocab
118
+
119
+
120
+ def reverse_kl_fwd(
121
+ student_logits: torch.Tensor,
122
+ teacher_logits: torch.Tensor,
123
+ completions_mask: torch.Tensor,
124
+ ) -> torch.Tensor:
125
+ """Forward-only reverse-KL self-distillation loss. Returns scalar ``loss``.
126
+
127
+ Use for inference / validation. The masked mean denominator is
128
+ computed on-GPU by the bundled mask-sum kernel — no host
129
+ ``mask.sum()`` syncs.
130
+
131
+ Args:
132
+ student_logits, teacher_logits: ``(*, V)`` logit tensors with
133
+ arbitrary leading dims (typically ``(N, C, V)``); both must
134
+ share shape and dtype.
135
+ completions_mask: Bool / int / float mask with shape matching
136
+ ``student_logits.shape[:-1]``; truthy = valid token.
137
+
138
+ Returns:
139
+ Scalar tensor (0-dim) with the same dtype as ``student_logits``.
140
+ """
141
+ student_2d, teacher_2d, flat_mask, _, _, vocab = _flatten_inputs(
142
+ student_logits, teacher_logits, completions_mask
143
+ )
144
+ run = _get_compiled_fwd(student_logits.dtype, vocab)
145
+ return run(student_2d, teacher_2d, flat_mask)
146
+
147
+
148
+ def reverse_kl(
149
+ student_logits: torch.Tensor,
150
+ teacher_logits: torch.Tensor,
151
+ completions_mask: torch.Tensor,
152
+ ) -> tuple[torch.Tensor, torch.Tensor]:
153
+ """Fused fwd+bwd reverse-KL self-distillation. Returns ``(loss, grad_student)``.
154
+
155
+ Single-launch training entry point. The kernel writes both the
156
+ scalar loss and the analytical ``dL/d(student_logits)`` in one
157
+ ``@cute.jit`` dispatch — the bundled mask-sum kernel populates
158
+ ``inv_n_valid`` on-GPU before the main kernel reads it, so there's
159
+ no host-side ``mask.sum()`` round trip.
160
+
161
+ Inputs do **not** need ``requires_grad=True``. To chain ``grad``
162
+ into the upstream model that produced ``student_logits``::
163
+
164
+ loss, grad = reverse_kl(student_logits, teacher_logits, mask)
165
+ student_logits.backward(grad)
166
+ optimizer.step()
167
+
168
+ Args:
169
+ student_logits, teacher_logits: ``(*, V)`` logit tensors with
170
+ arbitrary leading dims; both must share shape and dtype.
171
+ completions_mask: Mask with shape matching
172
+ ``student_logits.shape[:-1]``.
173
+
174
+ Returns:
175
+ ``(loss, grad_student_logits)`` — ``loss`` is a 0-dim tensor in
176
+ ``student_logits.dtype``; ``grad_student_logits`` matches
177
+ ``student_logits.shape`` and is already scaled by
178
+ ``1 / n_valid`` (undefined when ``n_valid == 0`` — fully-masked
179
+ batches produce inf/NaN; callers must guard upstream). The grad
180
+ tensor is freshly allocated per call (no shared cache).
181
+
182
+ For inference / validation where you only need the loss, use
183
+ :func:`reverse_kl_fwd` — it skips the gradient slab entirely.
184
+ """
185
+ student_2d, teacher_2d, flat_mask, leading_shape, _, vocab = _flatten_inputs(
186
+ student_logits, teacher_logits, completions_mask
187
+ )
188
+ run = _get_compiled_fwd_bwd(student_logits.dtype, vocab)
189
+ loss, grad_2d = run(student_2d, teacher_2d, flat_mask)
190
+ grad_student = grad_2d.view((*leading_shape, vocab))
191
+ return loss, grad_student
192
+
193
+
194
+ # Imported at the bottom: ``autograd.py`` imports ``reverse_kl`` from
195
+ # this module, so the function must be fully defined before its import runs.
196
+ from .autograd import reverse_kl_autograd # noqa: E402
build/torch-cuda/reverse_kl/_torch_ref.py ADDED
@@ -0,0 +1,65 @@
1
+ """Plain-PyTorch reverse-KL reference shared between bench and tests.
2
+
3
+ Every op is a vanilla torch op so AOTAutograd can derive the joint
4
+ fwd+bwd graph and Inductor can fuse both passes (used by
5
+ ``benchmarks/benchmark_reverse_kl.py``'s compiled baseline).
6
+ The same function is imported by ``tests/test_reverse_kl.py``
7
+ as the correctness reference, so both paths agree on what "the eager
8
+ torch implementation of reverse-KL self-distillation" means.
9
+
10
+ Reverse-KL definition (KL(student || teacher)):
11
+
12
+ p = softmax(student)
13
+ q = softmax(teacher)
14
+ kl_per_row = sum_v p_v * (log p_v - log q_v)
15
+ loss = sum_r mask_r * kl_per_row[r] / sum_r mask_r
16
+
17
+ There is no clamp on the denominator: a fully-masked batch yields
18
+ inf/NaN rather than 0, mirroring the cute kernel's unguarded
19
+ ``rcp_approx(n_valid)``; callers must guard that case upstream.
20
+
21
+ Underscore-prefixed module name signals "shared internal", not a public
22
+ API surface — there's no re-export from the package's top-level
23
+ ``__init__.py``.
24
+ """
25
+
26
+ from __future__ import annotations
27
+
28
+ import torch
29
+ from torch.nn.functional import kl_div, log_softmax
30
+
31
+
32
+ def torch_reverse_kl(
33
+ student_logits: torch.Tensor,
34
+ teacher_logits: torch.Tensor,
35
+ completions_mask: torch.Tensor,
36
+ ) -> torch.Tensor:
37
+ """Compute reverse-KL divergence loss.
38
+
39
+ Computes per-token reverse KL divergence:
40
+
41
+ KL(student || teacher) = sum_v p(v) [log p(v) - log q(v)]
42
+
43
+ where p is the student distribution and q is the teacher distribution,
44
+ both obtained by softmax over the vocabulary dimension.
45
+
46
+ Args:
47
+ student_logits: Student logits, shape (N, C, V).
48
+ teacher_logits: Teacher logits, shape (N, C, V).
49
+ completions_mask: Boolean mask of shape (N, C) where True marks
50
+ valid tokens. Pass an all-ones mask if no tokens are padded.
51
+
52
+ Returns:
53
+ Scalar tensor representing the loss.
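+
+ Example (illustrative sizes, CPU)::
+
+ student = torch.randn(2, 8, 512)
+ teacher = torch.randn(2, 8, 512)
+ mask = torch.ones(2, 8)
+ ref_loss = torch_reverse_kl(student, teacher, mask)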
54
+ """
55
+ log_p = log_softmax(student_logits, dim=-1)
56
+ log_q = log_softmax(teacher_logits, dim=-1)
57
+
58
+ # kl_div(input, target, log_target=True) computes KL(target || input)
59
+ # so input=log_q, target=log_p gives KL(student || teacher)
60
+ kl = kl_div(log_q, log_p, log_target=True, reduction="none").sum(dim=-1)
61
+
62
+ n_valid = completions_mask.sum().to(torch.float32)
63
+ kl = (kl * completions_mask).sum() / n_valid
64
+
65
+ return kl.to(student_logits.dtype)
build/torch-cuda/reverse_kl/autograd.py ADDED
@@ -0,0 +1,122 @@
1
+ """Autograd-aware wrapper for reverse-KL self-distillation via ``torch.library.custom_op``.
2
+
3
+ The fused cute kernel writes both the scalar loss and the closed-form
4
+ ``dL/d(student_logits)`` in one launch. This module wraps that into an
5
+ autograd-compatible op so callers can write::
6
+
7
+ loss = reverse_kl_autograd(student, teacher, completions_mask)
8
+ loss.backward() # propagates through to whatever produced student_logits
9
+
10
+ instead of the manual ``student.backward(grad)`` chain. The cost is
11
+ ~12µs of autograd dispatcher overhead per call (vs the direct
12
+ ``reverse_kl`` ``(loss, grad)`` tuple); for typical training loops and
13
+ ``kernelize()`` flows that's cheap, but for tight microbenches use the
14
+ direct path.
15
+
16
+ Implementation notes:
17
+
18
+ - The registered op returns ``(loss, grad_student)`` so
19
+ ``setup_context`` can ``save_for_backward(grad_student)``. The public
20
+ :func:`reverse_kl_autograd` wrapper hides the second output.
21
+ - The runner allocates ``grad_student`` fresh on every call (no shared
22
+ cache), so ``ctx.save_for_backward(grad_student)`` keeps a stable
23
+ reference for free.
24
+ - Backward returns ``grad_loss * grad_student``. Under
25
+ ``torch.compile``, when ``loss`` is consumed by ``.backward()``
26
+ directly, ``grad_loss`` is the constant 1.0 and Inductor can fold the
27
+ multiply away — the main reason this path uses ``custom_op`` instead
28
+ of a plain ``autograd.Function``.
29
+ - ``register_fake`` provides the meta kernel for ``torch.compile``
30
+ shape propagation; the real cute kernel never runs under
31
+ ``FakeTensorMode``.
32
+ """
33
+
34
+ from __future__ import annotations
35
+
36
+ import torch
37
+
38
+ from . import reverse_kl as _reverse_kl_fwd_bwd
39
+
40
+ __all__ = ["reverse_kl_autograd"]
41
+
42
+
43
+ @torch.library.custom_op(
44
+ "geometric_ai_kernels::_reverse_kl_with_grad",
45
+ mutates_args=(),
46
+ )
47
+ def _reverse_kl_with_grad(
48
+ student_logits: torch.Tensor,
49
+ teacher_logits: torch.Tensor,
50
+ completions_mask: torch.Tensor,
51
+ ) -> tuple[torch.Tensor, torch.Tensor]:
52
+ loss, grad_student = _reverse_kl_fwd_bwd(
53
+ student_logits,
54
+ teacher_logits,
55
+ completions_mask,
56
+ )
57
+ return loss, grad_student
58
+
59
+
60
+ @_reverse_kl_with_grad.register_fake
61
+ def _(
62
+ student_logits: torch.Tensor,
63
+ teacher_logits: torch.Tensor,
64
+ completions_mask: torch.Tensor,
65
+ ) -> tuple[torch.Tensor, torch.Tensor]:
66
+ # Signature must mirror the op; only ``student_logits`` shapes the outputs.
67
+ del teacher_logits, completions_mask
68
+ loss = student_logits.new_empty(())
69
+ grad_student = torch.empty_like(student_logits)
70
+ return loss, grad_student
71
+
72
+
73
+ def _setup_context(ctx, inputs, output) -> None: # type: ignore[no-untyped-def]
74
+ del inputs # only ``output`` carries what we need to save.
75
+ _, grad_student = output
76
+ ctx.save_for_backward(grad_student)
77
+
78
+
79
+ def _backward(ctx, grad_loss, grad_grad_student): # type: ignore[no-untyped-def]
80
+ # ``grad_grad_student`` is unused — ``grad_student`` is an internal
81
+ # intermediate exposed only so ``setup_context`` can save it. Under
82
+ # typical usage (``loss.backward()``) it arrives as ``None`` or a
83
+ # zero tensor.
84
+ del grad_grad_student
85
+ (grad_student,) = ctx.saved_tensors
86
+ grad_input = grad_loss * grad_student
87
+ # One return per input to the op (3): student_logits gets the grad,
88
+ # teacher_logits and completions_mask have no autograd flow.
89
+ return grad_input, None, None
90
+
91
+
92
+ torch.library.register_autograd(
93
+ "geometric_ai_kernels::_reverse_kl_with_grad",
94
+ _backward,
95
+ setup_context=_setup_context,
96
+ )
97
+
98
+
99
+ def reverse_kl_autograd(
100
+ student_logits: torch.Tensor,
101
+ teacher_logits: torch.Tensor,
102
+ completions_mask: torch.Tensor,
103
+ ) -> torch.Tensor:
104
+ """Autograd-aware reverse-KL self-distillation. Returns scalar ``loss``.
105
+
106
+ Same numerics as :func:`reverse_kl` but registered as a
107
+ ``torch.library`` custom op with autograd, so::
108
+
109
+ loss = reverse_kl_autograd(student, teacher, completions_mask)
110
+ loss.backward()
111
+
112
+ propagates through to whatever produced ``student_logits``. For
113
+ direct ``(loss, grad)`` access without the autograd dispatcher
114
+ overhead, use :func:`reverse_kl` and chain the gradient
115
+ manually via ``student_logits.backward(grad)``.
116
+
117
+ Composes with ``torch.compile``: the op is opaque to Inductor but
118
+ has a fake/meta kernel registered, so models containing this layer
119
+ can be compiled end-to-end without graph breaks.
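+
+ A compile-mode sketch (tensor sizes and the CUDA device are illustrative)::
+
+ student = torch.randn(4, 32, 2048, device="cuda", requires_grad=True)
+ teacher = torch.randn(4, 32, 2048, device="cuda")
+ mask = torch.ones(4, 32, device="cuda")
+ compiled_loss = torch.compile(reverse_kl_autograd, fullgraph=True)
+ compiled_loss(student, teacher, mask).backward()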
120
+ """
121
+ loss, _ = _reverse_kl_with_grad(student_logits, teacher_logits, completions_mask)
122
+ return loss
build/torch-cuda/reverse_kl/cute_reverse_kl.py ADDED
@@ -0,0 +1,881 @@
1
+ """Fused single-pass reverse-KL self-distillation loss kernel (CuteDSL, SM90).
2
+
3
+ Computes ``KL(student || teacher)`` over a ``(num_tokens, vocab)`` slab using
4
+ an online normalisation algorithm that sweeps each logit row once per pass.
5
+ The two log-softmax passes, the element-wise product ``p * (log_p - log_q)``,
6
+ the sum-reduction over ``V``, the mask application, and the final mean are
7
+ all fused into one host-side dispatch:
8
+
9
+ 1. **Mask-sum kernel** — reduces ``mask_flat`` and writes
10
+ ``inv_n_valid = 1 / sum(mask)``. Undefined when ``sum(mask) == 0``
11
+ (fully-masked batches produce inf/NaN in the final loss); callers
12
+ must guard upstream if that case is reachable.
13
+ 2. **Per-row main kernel** — one CTA per token; computes per-row KL and
14
+ (when ``compute_backward=True``) writes the analytical gradient
15
+ through the softmax Jacobian:
16
+
17
+ ``grad_student_v = scale * p_v * (log_p_v - log_q_v - KL_per_row)``
18
+
19
+ where ``scale = mask[r] * grad_output * inv_n_valid``. The two passes
20
+ per row (online stats then gradient write-out) execute within a single
21
+ CTA so the Pass-1 ``KL_per_row`` is broadcast through SMEM with no
22
+ DSMEM/cluster overhead. The cross-row loss reduction piggybacks on the
23
+ atomic-add + last-block-detect pattern used by ``cute_bnpo_loss``.
24
+
25
+ When ``compute_backward=False`` Pass 2, the broadcast SMEM, the
26
+ ``grad_output`` read, and the ``* grad_output`` factor in the final scalar
27
+ are all dead-code-eliminated at trace time — the kernel becomes a pure
28
+ forward path identical in PTX to a hand-written fwd-only kernel.
29
+
30
+ A single public entry point :func:`create_compiled_reverse_kl`
31
+ JIT-compiles either the fwd-only or fused fwd+bwd path depending on its
32
+ ``compute_backward`` flag. ``vocab`` is captured as a compile-time
33
+ constant (the tile + tail layout closes over it); the number of token
34
+ rows is symbolic and may vary across calls.
35
+ """
36
+
37
+ from __future__ import annotations
38
+
39
+ import math
40
+ import operator
41
+ from typing import TYPE_CHECKING, Any
42
+
43
+ import cutlass
44
+ import torch
45
+ from cutlass import cute
46
+ from cutlass._mlir.dialects import llvm
47
+ from cutlass.base_dsl.typing import cast
48
+ from cutlass.cute.nvgpu import CacheEvictionPriority, CopyUniversalOp
49
+ from cutlass.cutlass_dsl import T, dsl_user_op
50
+ from cutlass.utils import SmemAllocator
51
+
52
+ if TYPE_CHECKING:
53
+ from collections.abc import Callable
54
+
55
+
56
+ # ---------------------------------------------------------------------------
57
+ # Tunable constants — shared by fwd-only and fwd+bwd kernels
58
+ # ---------------------------------------------------------------------------
59
+ # Wider tile, more threads, and 128-bit loads so the two-pass (fwd+bwd) and
60
+ # single-pass (fwd-only) variants both stay bandwidth-bound from the same
61
+ # specialisation. ``FB_VEC_SIZE`` is recomputed at compile time from
62
+ # ``LOAD_BITS`` and the dtype width:
63
+ # FP16/BF16 → VEC=8, TILE_V=8192
64
+ # FP32 → VEC=4, TILE_V=4096
65
+ FB_NUM_THREADS = 1024
66
+ FB_NUM_WARPS = FB_NUM_THREADS // 32 # 32
67
+ FB_LOAD_BITS = 128
68
+
69
+ _LOG2E = math.log2(math.e) # 1.4426950408889634
70
+
71
+ _TORCH_TO_CUTLASS_DTYPE: dict[torch.dtype, Any] = {
72
+ torch.float32: cutlass.Float32,
73
+ torch.float16: cutlass.Float16,
74
+ torch.bfloat16: cutlass.BFloat16,
75
+ }
76
+
77
+
78
+ # ---------------------------------------------------------------------------
79
+ # Vendored atomic helpers (mirrors cute_bnpo_loss._atomic_*). Copied locally
80
+ # so the reverse_kl subpackage stays independent of bnpo_loss —
81
+ # the two ship together today, but kernel packages should be free-standing
82
+ # so a single one can be peeled off for separate publishing.
83
+ # ---------------------------------------------------------------------------
84
+
85
+
86
+ @dsl_user_op
87
+ def _atomic_add_f32_gmem(
88
+ ptr_i64: Any,
89
+ val: cutlass.Float32,
90
+ *,
91
+ loc: Any = None,
92
+ ip: Any = None,
93
+ ) -> None:
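+ """``atom.global.add.f32`` — fire-and-forget atomic add of ``val`` into the fp32 at ``ptr_i64``."""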
94
+ llvm.inline_asm(
95
+ T.f32(),
96
+ [ptr_i64, cutlass.Float32(val).ir_value(loc=loc, ip=ip)],
97
+ "atom.global.add.f32 $0, [$1], $2;",
98
+ "=f,l,f",
99
+ has_side_effects=True,
100
+ is_align_stack=False,
101
+ asm_dialect=llvm.AsmDialect.AD_ATT,
102
+ )
103
+
104
+
105
+ @dsl_user_op
106
+ def _atomic_inc_u32_gmem(
107
+ ptr_i64: Any,
108
+ threshold: cutlass.Int32,
109
+ *,
110
+ loc: Any = None,
111
+ ip: Any = None,
112
+ ) -> cutlass.Int32:
113
+ """``atom.global.inc.u32`` — returns old value; wraps to 0 at threshold."""
114
+ return cutlass.Int32(
115
+ llvm.inline_asm(
116
+ T.i32(),
117
+ [ptr_i64, cutlass.Int32(threshold).ir_value(loc=loc, ip=ip)],
118
+ "atom.global.inc.u32 $0, [$1], $2;",
119
+ "=r,l,r",
120
+ has_side_effects=True,
121
+ is_align_stack=False,
122
+ asm_dialect=llvm.AsmDialect.AD_ATT,
123
+ )
124
+ )
125
+
126
+
127
+ # ---------------------------------------------------------------------------
128
+ # Mask-sum kernel — reduces ``mask_flat`` (fp32, length N) to a per-block
129
+ # partial via warp + SMEM reduction, then atomically accumulates into
130
+ # ``valid_acc``; the last block writes ``rcp_approx(n_valid)`` into
131
+ # ``inv_n_valid`` and resets ``valid_acc`` to 0. Counter self-resets via
132
+ # ``atom.inc.u32`` wrap-around.
133
+ # ---------------------------------------------------------------------------
134
+
135
+
136
+ def _make_mask_sum_kernel(
137
+ fb_num_threads: int,
138
+ fb_num_warps: int,
139
+ ) -> Callable[..., None]:
140
+ """Return a ``@cute.kernel`` that reduces ``mask_flat`` and writes 1/sum.
141
+
142
+ Grid: ``(num_blocks, 1, 1)`` where each block processes
143
+ ``fb_num_threads`` mask elements (one element per thread, no
144
+ vectorisation — sufficient for the small N relative to vocab work).
145
+ A separate ``mask_counter`` (not shared with the main kernel's
146
+ ``counter``) is required because both rely on ``atom.inc.u32``
147
+ wrap-around for self-reset.
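+
+ Functionally this matches the plain-torch expression below (modulo the
+ ``rcp_approx`` rounding)::
+
+ inv_n_valid = 1.0 / mask_flat.sum()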
148
+ """
149
+
150
+ @cute.kernel
151
+ def _mask_sum_kernel(
152
+ mask_flat: cute.Tensor,
153
+ inv_n_valid: cute.Tensor,
154
+ valid_acc: cute.Tensor,
155
+ mask_counter: cute.Tensor,
156
+ num_rows: cutlass.Int32,
157
+ total_blocks: cutlass.Int32,
158
+ ) -> None:
159
+ block_size = fb_num_warps * 32
160
+ bidx = cute.arch.block_idx()[0]
161
+ tidx = cute.arch.thread_idx()[0]
162
+
163
+ global_idx = bidx * block_size + tidx
164
+ local_val = cutlass.Float32(0.0)
165
+ if global_idx < num_rows:
166
+ local_val = mask_flat[global_idx]
167
+
168
+ warp_val = cute.arch.warp_reduction(local_val, operator.add)
169
+
170
+ smem = SmemAllocator()
171
+ buf = smem.allocate_tensor(cutlass.Float32, cute.make_layout(fb_num_warps))
172
+
173
+ lane_idx = cute.arch.lane_idx()
174
+ warp_idx = cute.arch.warp_idx()
175
+
176
+ if lane_idx == 0:
177
+ buf[warp_idx] = warp_val
178
+ cute.arch.barrier()
179
+
180
+ if warp_idx == 0:
181
+ v = cutlass.Float32(0.0)
182
+ if lane_idx < fb_num_warps:
183
+ v = buf[lane_idx]
184
+ block_sum = cute.arch.warp_reduction(v, operator.add, threads_in_group=fb_num_warps)
185
+
186
+ if lane_idx == 0:
187
+ valid_ptr = valid_acc.iterator.toint().ir_value() # ty: ignore[unresolved-attribute]
188
+ counter_ptr = mask_counter.iterator.toint().ir_value() # ty: ignore[unresolved-attribute]
189
+ _atomic_add_f32_gmem(valid_ptr, block_sum)
190
+ cute.arch.fence_acq_rel_gpu()
191
+ old = _atomic_inc_u32_gmem(counter_ptr, total_blocks - 1)
192
+ if old == total_blocks - 1:
193
+ n_valid = valid_acc[0]
194
+ inv_n_valid[0] = cute.arch.rcp_approx(n_valid)
195
+ valid_acc[0] = cutlass.Float32(0.0)
196
+
197
+ return _mask_sum_kernel
198
+
199
+
200
+ # ---------------------------------------------------------------------------
201
+ # Main per-row reverse-KL kernel (fused fwd[+bwd]).
202
+ # ---------------------------------------------------------------------------
203
+
204
+
205
+ def _make_reverse_kl_kernel(
206
+ compute_backward: bool,
207
+ fb_num_threads: int,
208
+ fb_num_warps: int,
209
+ fb_vec_size: int,
210
+ fb_tile_v: int,
211
+ ) -> Callable[..., None]:
212
+ """Return a ``@cute.kernel`` that fuses fwd loss + (optional) bwd grad.
213
+
214
+ One CTA processes one row. Pass 1 computes the online softmax stats
215
+ ``(max_s, D_s, W = sum exp(s - max_s) (s - t), max_t, D_t)`` and the
216
+ per-row ``KL = W/D_s + (max_t - max_s) + log(D_t) - log(D_s)``.
217
+
218
+ When ``compute_backward=True`` the block-reduced stats are broadcast
219
+ through SMEM and Pass 2 re-reads student/teacher to write the
220
+ analytical gradient
221
+ ``grad_v = scale * p_v * (log_p_v - log_q_v - KL)`` where
222
+ ``p_v = exp(s_v - max_s) / D_s``,
223
+ ``log_p_v - log_q_v = (s_v - t_v) + (max_t - max_s) + log(D_t) - log(D_s)``,
224
+ and ``scale = mask[r] * grad_output * inv_n_valid``.
225
+
226
+ When ``compute_backward=False`` Pass 2, the bcast SMEM, the
227
+ ``grad_output`` read, and the ``* grad_output`` factor in the final
228
+ scalar are eliminated at trace time.
229
+
230
+ Cross-row loss accumulation rides on an atomic ``loss_acc`` plus a
231
+ wrap-around ``counter`` for last-block detection. The kernel always
232
+ reads ``mask_flat[bidx]`` (callers pass an all-ones mask in the
233
+ unmasked case) and, in the bwd path, short-circuits Pass 2 — writes
234
+ zeros and skips loss accumulation — for rows whose ``scale`` is 0.
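+
+ Plain-torch sanity check of the Pass-1 identity (illustrative, float64)::
+
+ s, t = torch.randn(4096, dtype=torch.float64), torch.randn(4096, dtype=torch.float64)
+ max_s, max_t = s.max(), t.max()
+ d_s, d_t = (s - max_s).exp().sum(), (t - max_t).exp().sum()
+ w = ((s - max_s).exp() * (s - t)).sum()
+ kl_online = w / d_s + (max_t - max_s) + d_t.log() - d_s.log()
+ p = torch.softmax(s, -1)
+ kl_direct = (p * (torch.log_softmax(s, -1) - torch.log_softmax(t, -1))).sum()
+ assert torch.allclose(kl_online, kl_direct)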
235
+ """
236
+
237
+ @cute.kernel
238
+ def _kernel(
239
+ student: cute.Tensor,
240
+ teacher: cute.Tensor,
241
+ mask_flat: cute.Tensor,
242
+ grad_output: cute.Tensor,
243
+ inv_n_valid: cute.Tensor,
244
+ grad_student: cute.Tensor,
245
+ loss_acc: cute.Tensor,
246
+ counter: cute.Tensor,
247
+ output: cute.Tensor,
248
+ num_full_tiles: int,
249
+ tail_len: int,
250
+ total_rows: cutlass.Int32,
251
+ tiled_copy_p1: cute.TiledCopy,
252
+ copy_atom_p1: cute.CopyAtom,
253
+ tiled_copy_p2: cute.TiledCopy,
254
+ copy_atom_p2: cute.CopyAtom,
255
+ copy_atom_store: cute.CopyAtom,
256
+ ) -> None:
257
+ tidx = cute.arch.thread_idx()[0]
258
+ bidx = cute.arch.block_idx()[0]
259
+
260
+ # SMEM:
261
+ # red_buf: NUM_WARPS x 5 — per-warp partials for cross-warp reduce
262
+ # bcast: 6 floats — broadcast (kl, max_s, log_d_s, max_t,
263
+ # log_d_t, scale) to all threads in pass 2
264
+ # The bcast buffer is only needed when ``compute_backward`` is True.
265
+ smem = SmemAllocator()
266
+ red_buf = smem.allocate_tensor(cutlass.Float32, cute.make_layout((fb_num_warps, 5)))
267
+ if cutlass.const_expr(compute_backward):
268
+ bcast = smem.allocate_tensor(cutlass.Float32, cute.make_layout(6))
269
+
270
+ log2e = cutlass.Float32(_LOG2E)
271
+ neg_inf = cutlass.Float32(-cutlass.Float32.inf)
272
+
273
+ # Per-row scale = mask[bidx] * grad_output * inv_n_valid. The
274
+ # ``grad_output`` read is dead-code-eliminated for the fwd-only
275
+ # path so no spurious GMEM load is emitted.
276
+ mask_val = cast(mask_flat[bidx], cutlass.Float32)
277
+ inv_n = cast(inv_n_valid[0], cutlass.Float32)
278
+ if cutlass.const_expr(compute_backward):
279
+ grad_val = cast(grad_output[0], cutlass.Float32)
280
+ scale_val = mask_val * grad_val * inv_n
281
+
282
+ s_row = cute.slice_(student, (cutlass.Int64(bidx), None))
283
+ t_row = cute.slice_(teacher, (cutlass.Int64(bidx), None))
284
+ if cutlass.const_expr(compute_backward):
285
+ g_row = cute.slice_(grad_student, (cutlass.Int64(bidx), None))
286
+ out_dtype = grad_student.element_type
287
+
288
+ max_s = neg_inf
289
+ d_s = cutlass.Float32(0.0)
290
+ w_acc = cutlass.Float32(0.0)
291
+ max_t = neg_inf
292
+ d_t = cutlass.Float32(0.0)
293
+
294
+ thr_copy_p1 = tiled_copy_p1.get_slice(tidx)
295
+
296
+ # ---- Pass 1: online stats + W ----
297
+ for k in cutlass.range(num_full_tiles, unroll=2):
298
+ s_slab = cute.local_tile(s_row, (fb_tile_v,), (k,))
299
+ t_slab = cute.local_tile(t_row, (fb_tile_v,), (k,))
300
+
301
+ src_s = thr_copy_p1.partition_S(s_slab)
302
+ frag_s = cute.make_fragment_like(src_s)
303
+ cute.copy(copy_atom_p1, src_s, frag_s)
304
+
305
+ src_t = thr_copy_p1.partition_S(t_slab)
306
+ frag_t = cute.make_fragment_like(src_t)
307
+ cute.copy(copy_atom_p1, src_t, frag_t)
308
+
309
+ s_f32 = frag_s.load().to(cutlass.Float32)
310
+ t_f32 = frag_t.load().to(cutlass.Float32)
311
+
312
+ tile_max_s = s_f32.reduce(cute.ReductionOp.MAX, neg_inf, reduction_profile=0)
313
+ exp_s = cute.math.exp2((s_f32 - tile_max_s) * log2e, fastmath=True)
314
+ tile_d_s = exp_s.reduce( # ty: ignore[unresolved-attribute]
315
+ cute.ReductionOp.ADD, cutlass.Float32(0.0), reduction_profile=0
316
+ )
317
+ diff = s_f32 - t_f32
318
+ tile_w = (exp_s * diff).reduce(
319
+ cute.ReductionOp.ADD, cutlass.Float32(0.0), reduction_profile=0
320
+ )
321
+
322
+ new_max_s = cute.arch.fmax(max_s, tile_max_s)
323
+ corr_s = cute.math.exp2((max_s - new_max_s) * log2e, fastmath=True)
324
+ tile_corr_s = cute.math.exp2((tile_max_s - new_max_s) * log2e, fastmath=True)
325
+ d_s = d_s * corr_s + tile_d_s * tile_corr_s
326
+ w_acc = w_acc * corr_s + tile_w * tile_corr_s
327
+ max_s = new_max_s
328
+
329
+ tile_max_t = t_f32.reduce(cute.ReductionOp.MAX, neg_inf, reduction_profile=0)
330
+ exp_t = cute.math.exp2((t_f32 - tile_max_t) * log2e, fastmath=True)
331
+ tile_d_t = exp_t.reduce( # ty: ignore[unresolved-attribute]
332
+ cute.ReductionOp.ADD, cutlass.Float32(0.0), reduction_profile=0
333
+ )
334
+
335
+ new_max_t = cute.arch.fmax(max_t, tile_max_t)
336
+ corr_t = cute.math.exp2((max_t - new_max_t) * log2e, fastmath=True)
337
+ tile_corr_t = cute.math.exp2((tile_max_t - new_max_t) * log2e, fastmath=True)
338
+ d_t = d_t * corr_t + tile_d_t * tile_corr_t
339
+ max_t = new_max_t
340
+
341
+ if tail_len > 0:
342
+ tail_base = num_full_tiles * fb_tile_v
343
+
344
+ thr_max_s = neg_inf
345
+ thr_max_t = neg_inf
346
+ for i in cutlass.range(fb_vec_size):
347
+ e = tidx + i * fb_num_threads
348
+ if e < tail_len:
349
+ s_val = cast(s_row[tail_base + e], cutlass.Float32)
350
+ t_val = cast(t_row[tail_base + e], cutlass.Float32)
351
+ thr_max_s = cute.arch.fmax(thr_max_s, s_val)
352
+ thr_max_t = cute.arch.fmax(thr_max_t, t_val)
353
+
354
+ thr_d_s = cutlass.Float32(0.0)
355
+ thr_w = cutlass.Float32(0.0)
356
+ thr_d_t = cutlass.Float32(0.0)
357
+ for i in cutlass.range(fb_vec_size):
358
+ e = tidx + i * fb_num_threads
359
+ if e < tail_len:
360
+ s_val = cast(s_row[tail_base + e], cutlass.Float32)
361
+ t_val = cast(t_row[tail_base + e], cutlass.Float32)
362
+ exp_sv = cute.math.exp2((s_val - thr_max_s) * log2e, fastmath=True)
363
+ thr_d_s = thr_d_s + exp_sv
364
+ thr_w = thr_w + exp_sv * (s_val - t_val)
365
+ exp_tv = cute.math.exp2((t_val - thr_max_t) * log2e, fastmath=True)
366
+ thr_d_t = thr_d_t + exp_tv
367
+
368
+ new_max_s = cute.arch.fmax(max_s, thr_max_s)
369
+ corr_s = cute.math.exp2((max_s - new_max_s) * log2e, fastmath=True)
370
+ tail_corr_s = cute.math.exp2((thr_max_s - new_max_s) * log2e, fastmath=True)
371
+ d_s = d_s * corr_s + thr_d_s * tail_corr_s
372
+ w_acc = w_acc * corr_s + thr_w * tail_corr_s
373
+ max_s = new_max_s
374
+
375
+ new_max_t = cute.arch.fmax(max_t, thr_max_t)
376
+ corr_t = cute.math.exp2((max_t - new_max_t) * log2e, fastmath=True)
377
+ tail_corr_t = cute.math.exp2((thr_max_t - new_max_t) * log2e, fastmath=True)
378
+ d_t = d_t * corr_t + thr_d_t * tail_corr_t
379
+ max_t = new_max_t
380
+
381
+ # ---- Cross-warp reduction for student/teacher stats ----
382
+ warp_max_s = cute.arch.warp_reduction(max_s, cute.arch.fmax)
383
+ corr_w_s = cute.math.exp2((max_s - warp_max_s) * log2e, fastmath=True)
384
+ d_s = d_s * corr_w_s
385
+ w_acc = w_acc * corr_w_s
386
+ warp_d_s = cute.arch.warp_reduction(d_s, operator.add)
387
+ warp_w = cute.arch.warp_reduction(w_acc, operator.add)
388
+
389
+ warp_max_t = cute.arch.warp_reduction(max_t, cute.arch.fmax)
390
+ corr_w_t = cute.math.exp2((max_t - warp_max_t) * log2e, fastmath=True)
391
+ d_t = d_t * corr_w_t
392
+ warp_d_t = cute.arch.warp_reduction(d_t, operator.add)
393
+
394
+ lane_idx = cute.arch.lane_idx()
395
+ warp_idx = cute.arch.warp_idx()
396
+
397
+ if lane_idx == 0:
398
+ red_buf[warp_idx, 0] = warp_max_s
399
+ red_buf[warp_idx, 1] = warp_d_s
400
+ red_buf[warp_idx, 2] = warp_w
401
+ red_buf[warp_idx, 3] = warp_max_t
402
+ red_buf[warp_idx, 4] = warp_d_t
403
+ cute.arch.sync_threads()
404
+
405
+ # Warp 0 finishes the cross-warp reduction.
406
+ if warp_idx == 0:
407
+ r_max_s = neg_inf
408
+ r_d_s = cutlass.Float32(0.0)
409
+ r_w = cutlass.Float32(0.0)
410
+ r_max_t = neg_inf
411
+ r_d_t = cutlass.Float32(0.0)
412
+
413
+ if lane_idx < fb_num_warps:
414
+ r_max_s = red_buf[lane_idx, 0]
415
+ r_d_s = red_buf[lane_idx, 1]
416
+ r_w = red_buf[lane_idx, 2]
417
+ r_max_t = red_buf[lane_idx, 3]
418
+ r_d_t = red_buf[lane_idx, 4]
419
+
420
+ final_max_s = cute.arch.warp_reduction(
421
+ r_max_s, cute.arch.fmax, threads_in_group=fb_num_warps
422
+ )
423
+ fcorr_s = cute.math.exp2((r_max_s - final_max_s) * log2e, fastmath=True)
424
+ r_d_s = r_d_s * fcorr_s
425
+ r_w = r_w * fcorr_s
426
+ final_d_s = cute.arch.warp_reduction(r_d_s, operator.add, threads_in_group=fb_num_warps)
427
+ final_w = cute.arch.warp_reduction(r_w, operator.add, threads_in_group=fb_num_warps)
428
+
429
+ final_max_t = cute.arch.warp_reduction(
430
+ r_max_t, cute.arch.fmax, threads_in_group=fb_num_warps
431
+ )
432
+ fcorr_t = cute.math.exp2((r_max_t - final_max_t) * log2e, fastmath=True)
433
+ r_d_t = r_d_t * fcorr_t
434
+ final_d_t = cute.arch.warp_reduction(r_d_t, operator.add, threads_in_group=fb_num_warps)
435
+
436
+ if lane_idx == 0:
437
+ rcp_d_s = cute.arch.rcp_approx(final_d_s)
438
+ log_d_s = cute.math.log(final_d_s)
439
+ log_d_t = cute.math.log(final_d_t)
440
+ kl = final_w * rcp_d_s + log_d_t + final_max_t - log_d_s - final_max_s
441
+
442
+ # Cross-row loss accumulation: atomic-add (kl * mask) into
443
+ # ``loss_acc``; the last block scales by ``inv_n_valid``
444
+ # (and ``grad_output`` on the bwd path) and writes the
445
+ # scalar output.
446
+ loss_ptr = loss_acc.iterator.toint().ir_value() # ty: ignore[unresolved-attribute]
447
+ counter_ptr = counter.iterator.toint().ir_value() # ty: ignore[unresolved-attribute]
448
+ contribution = kl * mask_val
449
+ _atomic_add_f32_gmem(loss_ptr, contribution)
450
+ cute.arch.fence_acq_rel_gpu()
451
+ old = _atomic_inc_u32_gmem(counter_ptr, total_rows - 1)
452
+ if old == total_rows - 1:
453
+ if cutlass.const_expr(compute_backward):
454
+ final_loss = loss_acc[0] * inv_n * grad_val
455
+ else:
456
+ final_loss = loss_acc[0] * inv_n
457
+ output[0] = cast(final_loss, output.element_type) # ty: ignore[invalid-argument-type]
458
+ loss_acc[0] = cutlass.Float32(0.0)
459
+
460
+ if cutlass.const_expr(compute_backward):
461
+ bcast[0] = kl
462
+ bcast[1] = final_max_s
463
+ bcast[2] = log_d_s
464
+ bcast[3] = final_max_t
465
+ bcast[4] = log_d_t
466
+ bcast[5] = scale_val
467
+
468
+ # ---- Pass 2: re-read logits and write gradient ----
469
+ # Entire pass is dead-code-eliminated when ``compute_backward`` is
470
+ # False — the kernel becomes a single-pass forward identical in
471
+ # PTX to a hand-written fwd-only kernel.
472
+ if cutlass.const_expr(compute_backward):
473
+ cute.arch.sync_threads()
474
+
475
+ kl_b = bcast[0]
476
+ max_s_b = bcast[1]
477
+ log_d_s_b = bcast[2]
478
+ max_t_b = bcast[3]
479
+ log_d_t_b = bcast[4]
480
+ scale_b = bcast[5]
481
+
482
+ log_offset = max_t_b - max_s_b + log_d_t_b - log_d_s_b
483
+
484
+ # Skip Pass 2 entirely for rows with scale=0 (masked-out rows).
485
+ if scale_b != cutlass.Float32(0.0):
486
+ thr_copy_p2 = tiled_copy_p2.get_slice(tidx)
487
+
488
+ for k in cutlass.range(num_full_tiles, unroll=2):
489
+ s_slab = cute.local_tile(s_row, (fb_tile_v,), (k,))
490
+ t_slab = cute.local_tile(t_row, (fb_tile_v,), (k,))
491
+ g_slab = cute.local_tile(g_row, (fb_tile_v,), (k,))
492
+
493
+ src_s = thr_copy_p2.partition_S(s_slab)
494
+ frag_s = cute.make_fragment_like(src_s)
495
+ cute.copy(copy_atom_p2, src_s, frag_s)
496
+
497
+ src_t = thr_copy_p2.partition_S(t_slab)
498
+ frag_t = cute.make_fragment_like(src_t)
499
+ cute.copy(copy_atom_p2, src_t, frag_t)
500
+
501
+ s_f32 = frag_s.load().to(cutlass.Float32)
502
+ t_f32 = frag_t.load().to(cutlass.Float32)
503
+
504
+ # p_v = exp((s_v - max_s) * log2e) / D_s = exp(s_v - logZ_s)
505
+ p_v = cute.math.exp2((s_f32 - max_s_b - log_d_s_b) * log2e, fastmath=True)
506
+ log_diff = s_f32 - t_f32 + log_offset
507
+ grad = scale_b * p_v * (log_diff - kl_b)
508
+
509
+ dst_g = thr_copy_p2.partition_D(g_slab)
510
+ out_frag = cute.make_fragment_like(dst_g)
511
+ out_frag.store(grad.to(out_dtype))
512
+ cute.copy(copy_atom_store, out_frag, dst_g)
513
+
514
+ if tail_len > 0:
515
+ tail_base = num_full_tiles * fb_tile_v
516
+ for i in cutlass.range(fb_vec_size):
517
+ e = tidx + i * fb_num_threads
518
+ if e < tail_len:
519
+ s_val = cast(s_row[tail_base + e], cutlass.Float32)
520
+ t_val = cast(t_row[tail_base + e], cutlass.Float32)
521
+ p_v = cute.math.exp2(
522
+ (s_val - max_s_b - log_d_s_b) * log2e, fastmath=True
523
+ )
524
+ log_diff = s_val - t_val + log_offset
525
+ grad = scale_b * p_v * (log_diff - kl_b)
526
+ g_row[tail_base + e] = cast(grad, out_dtype) # ty: ignore[invalid-argument-type]
527
+ else:
528
+ # Masked row: write zeros so callers see a clean grad slab.
529
+ zero_elem = cast(cutlass.Float32(0.0), out_dtype) # ty: ignore[invalid-argument-type]
530
+ thr_copy_z = tiled_copy_p2.get_slice(tidx)
531
+ for k in cutlass.range(num_full_tiles, unroll=2):
532
+ g_slab = cute.local_tile(g_row, (fb_tile_v,), (k,))
533
+ dst_g = thr_copy_z.partition_D(g_slab)
534
+ zero_frag = cute.make_fragment_like(dst_g)
535
+ zero_frag.fill(zero_elem)
536
+ cute.copy(copy_atom_store, zero_frag, dst_g)
537
+
538
+ if tail_len > 0:
539
+ tail_base = num_full_tiles * fb_tile_v
540
+ for i in cutlass.range(fb_vec_size):
541
+ e = tidx + i * fb_num_threads
542
+ if e < tail_len:
543
+ g_row[tail_base + e] = zero_elem
544
+
545
+ return _kernel
546
+
547
+
548
+ def _make_fwd_bwd_launcher(
549
+ compute_backward: bool,
550
+ fb_num_threads: int,
551
+ fb_num_warps: int,
552
+ fb_vec_size: int,
553
+ fb_tile_v: int,
554
+ ) -> Callable[..., None]:
555
+ """Return a ``@cute.jit`` launcher that runs mask-sum + main kernel.
556
+
557
+ When ``compute_backward=False`` Pass 2 + bcast SMEM are dead-code
558
+ eliminated inside the kernel and Pass 1 loads default to NO_ALLOCATE
559
+ (no benefit from L2 pinning since there is no re-read).
560
+ """
561
+ main_kernel = _make_reverse_kl_kernel(
562
+ compute_backward, fb_num_threads, fb_num_warps, fb_vec_size, fb_tile_v
563
+ )
564
+ mask_sum_kernel = _make_mask_sum_kernel(fb_num_threads, fb_num_warps)
565
+
566
+ @cute.jit
567
+ def _launch(
568
+ student_2d: cute.Tensor,
569
+ teacher_2d: cute.Tensor,
570
+ mask_flat: cute.Tensor,
571
+ grad_output: cute.Tensor,
572
+ inv_n_valid: cute.Tensor,
573
+ grad_student_2d: cute.Tensor,
574
+ loss_acc: cute.Tensor,
575
+ valid_acc: cute.Tensor,
576
+ counter: cute.Tensor,
577
+ mask_counter: cute.Tensor,
578
+ output: cute.Tensor,
579
+ num_full_tiles: cutlass.Int32,
580
+ tail_len: cutlass.Int32,
581
+ mask_sum_blocks: cutlass.Int32,
582
+ ) -> None:
583
+ num_rows = student_2d.shape[0] # ty: ignore[not-subscriptable]
584
+ dtype = student_2d.element_type
585
+ out_dtype = grad_student_2d.element_type
586
+
587
+ # Pass 1 loads: EVICT_LAST when we need the data pinned in L2 for
588
+ # Pass 2 re-reads; NO_ALLOCATE for fwd-only since the data is
589
+ # never re-read.
590
+ if cutlass.const_expr(compute_backward):
591
+ p1_evict = CacheEvictionPriority.EVICT_LAST
592
+ else:
593
+ p1_evict = CacheEvictionPriority.NO_ALLOCATE
594
+ copy_atom_p1 = cute.make_copy_atom(
595
+ CopyUniversalOp(),
596
+ dtype,
597
+ num_bits_per_copy=fb_vec_size * dtype.width, # ty: ignore[unresolved-attribute]
598
+ l1c_evict_priority=p1_evict,
599
+ )
600
+ thr_layout = cute.make_layout((fb_num_threads,))
601
+ val_layout = cute.make_layout((fb_vec_size,))
602
+ tiler_v_p1, layout_tv_p1 = cute.make_layout_tv(thr_layout, val_layout)
603
+ tiled_copy_p1 = cute.make_tiled_copy(copy_atom_p1, layout_tv_p1, tiler_v_p1)
604
+
605
+ # Pass 2 loads: NO_ALLOCATE — streaming, never re-read.
606
+ copy_atom_p2 = cute.make_copy_atom(
607
+ CopyUniversalOp(),
608
+ dtype,
609
+ num_bits_per_copy=fb_vec_size * dtype.width, # ty: ignore[unresolved-attribute]
610
+ l1c_evict_priority=CacheEvictionPriority.NO_ALLOCATE,
611
+ )
612
+ tiler_v_p2, layout_tv_p2 = cute.make_layout_tv(thr_layout, val_layout)
613
+ tiled_copy_p2 = cute.make_tiled_copy(copy_atom_p2, layout_tv_p2, tiler_v_p2)
614
+
615
+ # Stores: NO_ALLOCATE — write-once, never re-read.
616
+ copy_atom_store = cute.make_copy_atom(
617
+ CopyUniversalOp(),
618
+ out_dtype,
619
+ num_bits_per_copy=fb_vec_size * out_dtype.width, # ty: ignore[unresolved-attribute]
620
+ l1c_evict_priority=CacheEvictionPriority.NO_ALLOCATE,
621
+ )
622
+
623
+ mask_sum_kernel( # ty: ignore[unresolved-attribute]
624
+ mask_flat,
625
+ inv_n_valid,
626
+ valid_acc,
627
+ mask_counter,
628
+ num_rows,
629
+ mask_sum_blocks,
630
+ ).launch(
631
+ grid=(mask_sum_blocks, 1, 1),
632
+ block=(fb_num_threads, 1, 1),
633
+ )
634
+
635
+ main_kernel( # ty: ignore[unresolved-attribute]
636
+ student_2d,
637
+ teacher_2d,
638
+ mask_flat,
639
+ grad_output,
640
+ inv_n_valid,
641
+ grad_student_2d,
642
+ loss_acc,
643
+ counter,
644
+ output,
645
+ num_full_tiles,
646
+ tail_len,
647
+ num_rows,
648
+ tiled_copy_p1,
649
+ copy_atom_p1,
650
+ tiled_copy_p2,
651
+ copy_atom_p2,
652
+ copy_atom_store,
653
+ ).launch(
654
+ grid=(num_rows, 1, 1),
655
+ block=(fb_num_threads, 1, 1),
656
+ )
657
+
658
+ return _launch
659
+
660
+
661
+ # ---------------------------------------------------------------------------
662
+ # Public factory
663
+ # ---------------------------------------------------------------------------
664
+
665
+
666
+ def create_compiled_reverse_kl(
667
+ policy_dtype: torch.dtype,
668
+ vocab: int,
669
+ compute_backward: bool = False,
670
+ ) -> Callable[..., torch.Tensor | tuple[torch.Tensor, torch.Tensor]]:
671
+ """JIT-compile the fused reverse-KL kernel (forward or fused fwd+bwd).
672
+
673
+ Bundles a mask-sum kernel (which computes ``inv_n_valid``) and the
674
+ main per-row kernel into a single ``@cute.jit`` launch so each call
675
+ is one tvm-ffi dispatch with no host syncs. The unmasked case is
676
+ handled by passing an all-ones ``mask_flat`` (callers are responsible).
677
+
678
+ The returned runner takes ``(student_2d, teacher_2d, mask_flat)``. When
679
+ ``compute_backward=False`` it returns the scalar loss (the gradient slab
680
+ is a 1-column dummy); when ``compute_backward=True`` it allocates a real
681
+ ``grad_student_2d`` slab and returns ``(loss_scalar, grad_student_2d)``.
682
+
683
+ Args:
684
+ policy_dtype: Element dtype of student/teacher logits and the
685
+ output gradient slab. Both tensors must share this dtype.
686
+ vocab: ``V`` dimension of the ``(num_tokens, V)`` slab. Captured
687
+ as a compile-time constant; the number of token rows stays
688
+ symbolic across calls.
689
+ compute_backward: When ``True`` the kernel additionally writes
690
+ the analytical gradient through the softmax Jacobian into
691
+ the caller's ``grad_student_2d`` slab.
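+
+ Usage sketch (sizes, dtype, and the CUDA device are illustrative; the
+ package-level ``reverse_kl`` wrapper drives the factory the same way,
+ caching one runner per ``(dtype, vocab)``)::
+
+ run = create_compiled_reverse_kl(torch.bfloat16, vocab=32000, compute_backward=True)
+ s = torch.randn(256, 32000, device="cuda", dtype=torch.bfloat16)
+ t = torch.randn(256, 32000, device="cuda", dtype=torch.bfloat16)
+ m = torch.ones(256, device="cuda", dtype=torch.float32)
+ loss, grad_2d = run(s, t, m)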
692
+ """
693
+ if policy_dtype not in _TORCH_TO_CUTLASS_DTYPE:
694
+ raise ValueError(f"Unsupported dtype for self-distillation loss: {policy_dtype}")
695
+
696
+ student_dtype = _TORCH_TO_CUTLASS_DTYPE[policy_dtype]
697
+ fb_vec_size = FB_LOAD_BITS // student_dtype.width
698
+ fb_tile_v = FB_NUM_THREADS * fb_vec_size
699
+
700
+ num_full_tiles = vocab // fb_tile_v
701
+ tail_len = vocab % fb_tile_v
702
+ block_size = FB_NUM_WARPS * 32
703
+
704
+ num_tokens_sym = cute.sym_int()
705
+
706
+ fake_s = cute.runtime.make_fake_compact_tensor(
707
+ student_dtype,
708
+ (num_tokens_sym, vocab),
709
+ stride_order=(1, 0),
710
+ assumed_align=16,
711
+ )
712
+ fake_t = cute.runtime.make_fake_compact_tensor(
713
+ student_dtype,
714
+ (num_tokens_sym, vocab),
715
+ stride_order=(1, 0),
716
+ assumed_align=16,
717
+ )
718
+ fake_mask = cute.runtime.make_fake_compact_tensor(
719
+ cutlass.Float32,
720
+ (num_tokens_sym,),
721
+ assumed_align=16,
722
+ )
723
+ fake_grad_out = cute.runtime.make_fake_compact_tensor(
724
+ cutlass.Float32,
725
+ (1,),
726
+ assumed_align=16,
727
+ )
728
+ fake_inv_n = cute.runtime.make_fake_compact_tensor(
729
+ cutlass.Float32,
730
+ (1,),
731
+ assumed_align=16,
732
+ )
733
+ # Full ``(N, V)`` slab when computing the backward; 1-column dummy
734
+ # slab for fwd-only since Pass 2 is dead-code-eliminated and only
735
+ # the tensor signature matters.
736
+ grad_v_dim = vocab if compute_backward else 1
737
+ fake_grad_student = cute.runtime.make_fake_compact_tensor(
738
+ student_dtype,
739
+ (num_tokens_sym, grad_v_dim),
740
+ stride_order=(1, 0),
741
+ assumed_align=16,
742
+ )
743
+ fake_loss = cute.runtime.make_fake_compact_tensor(
744
+ cutlass.Float32,
745
+ (1,),
746
+ assumed_align=16,
747
+ )
748
+ fake_valid = cute.runtime.make_fake_compact_tensor(
749
+ cutlass.Float32,
750
+ (1,),
751
+ assumed_align=16,
752
+ )
753
+ fake_counter = cute.runtime.make_fake_compact_tensor(
754
+ cutlass.Int32,
755
+ (1,),
756
+ assumed_align=16,
757
+ )
758
+ fake_mask_counter = cute.runtime.make_fake_compact_tensor(
759
+ cutlass.Int32,
760
+ (1,),
761
+ assumed_align=16,
762
+ )
763
+ fake_output = cute.runtime.make_fake_compact_tensor(
764
+ student_dtype,
765
+ (1,),
766
+ assumed_align=16,
767
+ )
768
+
769
+ launcher = _make_fwd_bwd_launcher(
770
+ compute_backward=compute_backward,
771
+ fb_num_threads=FB_NUM_THREADS,
772
+ fb_num_warps=FB_NUM_WARPS,
773
+ fb_vec_size=fb_vec_size,
774
+ fb_tile_v=fb_tile_v,
775
+ )
776
+ compiled = cute.compile(
777
+ launcher,
778
+ fake_s,
779
+ fake_t,
780
+ fake_mask,
781
+ fake_grad_out,
782
+ fake_inv_n,
783
+ fake_grad_student,
784
+ fake_loss,
785
+ fake_valid,
786
+ fake_counter,
787
+ fake_mask_counter,
788
+ fake_output,
789
+ cutlass.Int32(num_full_tiles),
790
+ cutlass.Int32(tail_len),
791
+ cutlass.Int32(1),
792
+ options="--enable-tvm-ffi",
793
+ )
794
+
795
+ nft_const = cutlass.Int32(num_full_tiles)
796
+ tl_const = cutlass.Int32(tail_len)
797
+ grad_v_dim_runtime = vocab if compute_backward else 1
798
+
799
+ # ---- Closure-scoped scratch ----
800
+ # ``grad_output`` is the upstream gradient feeding Pass 2 (1.0 for
801
+ # backward, irrelevant for forward — Pass 2 is dead-code-eliminated).
802
+ # Constant across calls; allocated lazily on first call when the
803
+ # device is known, then reused.
804
+ #
805
+ # ``scratch_z`` coalesces the 4 atomic-accumulator scalars (counter,
806
+ # mask_counter, loss_acc.fp32, valid_acc.fp32) into one int32 slab
807
+ # with stride-4 (16-byte) slices so each slot is individually 16-byte
808
+ # aligned (``assumed_align=16`` at compile time). Bit-pattern of int32 0
809
+ # equals fp32 0.0, so a single ``zeros`` factory legitimately
810
+ # initialises both the int32 counters and the fp32 accumulators.
811
+ # Both kernels' last blocks self-reset their fp32 accumulators in
812
+ # their epilogues, and counters self-reset via ``atom.inc.u32``
813
+ # wrap-around — so the up-front ``torch.zeros`` only matters for the
814
+ # first call.
815
+ _scratch: list[tuple[torch.Tensor, torch.Tensor] | None] = [None]
816
+
817
+ def _ensure_scratch(
818
+ device: torch.device,
819
+ ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
820
+ s = _scratch[0]
821
+ if s is None or s[0].device != device:
822
+ slab = torch.zeros(16, dtype=torch.int32, device=device)
823
+ if compute_backward:
824
+ grad_output = torch.ones(1, dtype=torch.float32, device=device)
825
+ else:
826
+ grad_output = torch.empty(1, dtype=torch.float32, device=device)
827
+ _scratch[0] = (slab, grad_output)
828
+ s = _scratch[0]
829
+ slab, grad_output = s
830
+ return (
831
+ slab[0:1], # counter (int32)
832
+ slab[4:5], # mask_counter (int32)
833
+ slab[8:9].view(torch.float32), # loss_acc (fp32)
834
+ slab[12:13].view(torch.float32), # valid_acc (fp32)
835
+ grad_output,
836
+ )
837
+
838
+ def _run(
839
+ student_2d: torch.Tensor,
840
+ teacher_2d: torch.Tensor,
841
+ mask_flat: torch.Tensor,
842
+ ) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
843
+ num_rows = student_2d.shape[0]
844
+ device = student_2d.device
845
+ dtype = student_2d.dtype
846
+
847
+ counter_r, mask_counter_r, loss_acc_r, valid_acc_r, grad_output_r = _ensure_scratch(device)
848
+
849
+ # Per-call write-only buffers — ``empty`` is enough.
850
+ # ``inv_n_valid`` is populated by the bundled mask-sum kernel
851
+ # before the main kernel reads it; the runner never reads it.
852
+ inv_n_valid_r = torch.empty(1, dtype=torch.float32, device=device)
853
+ output_r = torch.empty(1, dtype=dtype, device=device)
854
+ # ``grad_v_dim`` is ``vocab`` for backward and ``1`` for forward
855
+ # (1-column dummy slab — Pass 2 is dead-code-eliminated, only the
856
+ # tensor-parameter signature matters).
857
+ grad_buffer = torch.empty(num_rows, grad_v_dim_runtime, dtype=dtype, device=device)
858
+
859
+ mask_sum_blocks = (num_rows + block_size - 1) // block_size
860
+ compiled(
861
+ student_2d,
862
+ teacher_2d,
863
+ mask_flat,
864
+ grad_output_r,
865
+ inv_n_valid_r,
866
+ grad_buffer,
867
+ loss_acc_r,
868
+ valid_acc_r,
869
+ counter_r,
870
+ mask_counter_r,
871
+ output_r,
872
+ nft_const,
873
+ tl_const,
874
+ cutlass.Int32(mask_sum_blocks),
875
+ )
876
+ out_view = output_r.view(())
877
+ if compute_backward:
878
+ return out_view, grad_buffer
879
+ return out_view
880
+
881
+ return _run