Build uploaded using `kernels`.
- build/torch-cuda/__init__.py +69 -0
- build/torch-cuda/_ops.py +38 -0
- build/torch-cuda/bnpo_loss/__init__.py +196 -0
- build/torch-cuda/bnpo_loss/_torch_ref.py +56 -0
- build/torch-cuda/bnpo_loss/autograd.py +149 -0
- build/torch-cuda/bnpo_loss/cute_bnpo_loss.py +1081 -0
- build/torch-cuda/geometric_ai_kernels/__init__.py +26 -0
- build/torch-cuda/grpo_loss/__init__.py +169 -0
- build/torch-cuda/grpo_loss/_torch_ref.py +87 -0
- build/torch-cuda/grpo_loss/autograd.py +120 -0
- build/torch-cuda/grpo_loss/cute_grpo_loss.py +805 -0
- build/torch-cuda/layers.py +258 -0
- build/torch-cuda/metadata.json +12 -0
- build/torch-cuda/reverse_kl/__init__.py +196 -0
- build/torch-cuda/reverse_kl/_torch_ref.py +65 -0
- build/torch-cuda/reverse_kl/autograd.py +122 -0
- build/torch-cuda/reverse_kl/cute_reverse_kl.py +881 -0
build/torch-cuda/__init__.py
ADDED
@@ -0,0 +1,69 @@
"""Geometric-AI CuteDSL kernels for RL / distillation training.

Public surface:
* ``bnpo_loss`` / ``bnpo_loss_autograd`` / ``bnpo_loss_fwd`` —
  fused fwd+bwd BNPO loss with three entry points (direct
  ``(loss, grad)``, autograd-wrapped, forward-only).
* ``grpo_loss`` / ``grpo_loss_autograd`` / ``grpo_loss_fwd`` —
  fused fwd+bwd GRPO loss (TRL's per-response normalization
  variant). Same three-entry-point shape as BNPO. Requires
  ``completions_mask``.
* ``reverse_kl`` / ``reverse_kl_autograd`` /
  ``reverse_kl_fwd`` — fused fwd+bwd reverse-KL
  self-distillation loss with the same three-entry-point shape.

HF Kernels integration: :mod:`geometric_ai_kernels.layers` exposes
``nn.Module`` adapters per kernel (``bnpoLoss`` / ``bnpoLossInference``,
``grpoLoss`` / ``grpoLossInference``, ``ReverseKL`` /
``ReverseKLInference``) for use with the ``kernels``
library's ``kernelize()`` flow.
"""

from __future__ import annotations

import torch._dynamo

from .bnpo_loss import bnpo_loss, bnpo_loss_autograd, bnpo_loss_fwd
from .grpo_loss import grpo_loss, grpo_loss_autograd, grpo_loss_fwd
from .layers import (
    ReverseKL,
    ReverseKLInference,
    bnpoLoss,
    bnpoLossInference,
    grpoLoss,
    grpoLossInference,
)
from .reverse_kl import (
    reverse_kl,
    reverse_kl_autograd,
    reverse_kl_fwd,
)

# Required so ``torch.compile(fullgraph=True)`` can trace through
# ``torch.autograd.grad`` calls — without it Dynamo graph-breaks at the
# autograd.grad call site even when AOTAutograd has already derived the
# joint fwd+bwd graph. Set at package import so any consumer (benches,
# user training loops, ``kernelize`` flows) gets it for free. Guarded
# because ``trace_autograd_ops`` was added in torch 2.10 and the
# Nix-pinned build environment may be on an older torch (2.9 today);
# the underlying ``Config.__setattr__`` raises on unknown keys.
if hasattr(torch._dynamo.config, "trace_autograd_ops"):
    torch._dynamo.config.trace_autograd_ops = True  # ty: ignore[invalid-assignment]

__all__ = [
    "ReverseKL",
    "ReverseKLInference",
    "bnpoLoss",
    "bnpoLossInference",
    "bnpo_loss",
    "bnpo_loss_autograd",
    "bnpo_loss_fwd",
    "grpoLoss",
    "grpoLossInference",
    "grpo_loss",
    "grpo_loss_autograd",
    "grpo_loss_fwd",
    "reverse_kl",
    "reverse_kl_autograd",
    "reverse_kl_fwd",
]
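A minimal usage sketch of the three BNPO entry points (illustrative only: the device, shapes, and random inputs below are assumptions, not part of the upload):

```python
# Sketch, assuming a CUDA device and hypothetical (bs=4, seq_len=1024) inputs.
import torch

import geometric_ai_kernels as gak

bs, seq_len = 4, 1024
policy = torch.randn(bs, seq_len, device="cuda", requires_grad=True)
old = torch.randn(bs, seq_len, device="cuda")
ref = torch.randn(bs, seq_len, device="cuda")
adv = torch.randn(bs, device="cuda")
mask = torch.ones(bs, seq_len, dtype=torch.int8, device="cuda")

# 1) Direct path: one fused launch returns (loss, grad); chain manually.
loss, grad = gak.bnpo_loss(policy, old, ref, adv, mask)
policy.backward(grad)

# 2) Autograd path: a scalar loss that backprops like any torch op.
loss = gak.bnpo_loss_autograd(policy, old, ref, adv, mask)
loss.backward()

# 3) Forward-only path for eval / validation: just the scalar loss.
with torch.no_grad():
    val_loss = gak.bnpo_loss_fwd(policy, old, ref, adv, mask)
```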
build/torch-cuda/_ops.py
ADDED
@@ -0,0 +1,38 @@
import torch


def get_backend() -> str:
    """Detect the backend by inspecting torch."""
    import torch

    if hasattr(torch, "neuron"):
        # Needs to be checked before the specific Torch builds below, since
        # the Neuron extension can be loaded into e.g. CUDA Torch builds.
        return "neuron"
    elif torch.version.cuda is not None:
        return "cuda"
    elif torch.version.hip is not None:
        return "rocm"
    elif torch.backends.mps.is_available():
        return "metal"
    elif hasattr(torch.version, "xpu") and torch.version.xpu is not None:
        return "xpu"
    else:
        return "cpu"


def _find_ops_name() -> str:
    kernel_name = "geometric_ai_kernels"
    unique_id = "a766fbd_dirty"
    backend = get_backend()
    return f"_{kernel_name}_{backend}_{unique_id}"


_OPS_NAME = _find_ops_name()

ops = getattr(torch.ops, _OPS_NAME)


def add_op_namespace_prefix(op_name: str) -> str:
    """Prefix op by namespace."""
    return f"{_OPS_NAME}::{op_name}"
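For orientation, the resolved names on a CUDA torch build of this artifact would look as follows (illustrative sketch; the import path and the `"bnpo_fwd"` op name are hypothetical, while the suffix comes from the `unique_id` pinned above):

```python
# Hypothetical REPL session against this module on a CUDA torch build.
from build.torch_cuda._ops import _OPS_NAME, add_op_namespace_prefix  # path is illustrative

print(_OPS_NAME)
# _geometric_ai_kernels_cuda_a766fbd_dirty
print(add_op_namespace_prefix("bnpo_fwd"))
# _geometric_ai_kernels_cuda_a766fbd_dirty::bnpo_fwd
```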
build/torch-cuda/bnpo_loss/__init__.py
ADDED
@@ -0,0 +1,196 @@
"""BNPO loss with CuteDSL fused fwd+bwd.

Two public APIs route to two compiled kernels:

* :func:`bnpo_loss` — primary training entry point. Returns
  ``(loss, grad_policy_logprobs)`` from a single fused fwd+bwd kernel
  launch. Inputs do **not** need ``requires_grad=True`` and there is no
  ``torch.autograd.Function`` wrapper — chain the gradient into the
  upstream model with ``policy_logprobs.backward(grad)`` (or, more
  commonly, by passing ``grad`` to whatever step does the next leg of
  backprop).
* :func:`bnpo_loss_fwd` — inference / validation path. Returns the
  scalar ``loss`` from a forward-only kernel that computes the masked
  mean denominator on-GPU via a last-block trick (no host
  ``completions_mask.sum()``).

The two share the same compiled-kernel cache; per-call output and
gradient buffers are allocated inside the runner, and cross-CTA scratch
(atomic accumulators + counters) is owned by the compiled-kernel
closure and self-resets each launch — callers don't manage scratch.

Why no autograd wrapper here? BNPO's gradient is closed-form — the
kernel already writes ``dL/d(policy_logprobs)`` in the same launch as
the loss. Wrapping in ``torch.autograd.Function`` would cost an extra
``grad_output * dpolicy`` kernel launch on backward (typically a
no-op multiply by ``1.0``), plus per-call autograd graph bookkeeping.
The autograd-aware sibling :func:`bnpo_loss_autograd` uses
``torch.library.custom_op`` instead, which composes with
``torch.compile``.
"""

from __future__ import annotations

from functools import lru_cache
from typing import TYPE_CHECKING, cast

import torch

from .cute_bnpo_loss import (
    create_compiled_bnpo_loss,
    create_compiled_bnpo_loss_with_backward,
)

if TYPE_CHECKING:
    from collections.abc import Callable


__all__ = ["bnpo_loss", "bnpo_loss_autograd", "bnpo_loss_fwd"]


@lru_cache(maxsize=32)
def _get_compiled_fwd(
    dtype: torch.dtype,
    epsilon: float,
    epsilon_high: float,
    beta: float,
) -> Callable[..., torch.Tensor]:
    return cast(
        "Callable[..., torch.Tensor]",
        create_compiled_bnpo_loss(
            policy_dtype=dtype,
            epsilon=epsilon,
            epsilon_high=epsilon_high,
            beta=beta,
            compute_backward=False,
        ),
    )


@lru_cache(maxsize=32)
def _get_compiled_fwd_bwd(
    dtype: torch.dtype,
    epsilon: float,
    epsilon_high: float,
    beta: float,
) -> Callable[..., tuple[torch.Tensor, torch.Tensor]]:
    return create_compiled_bnpo_loss_with_backward(
        policy_dtype=dtype,
        epsilon=epsilon,
        epsilon_high=epsilon_high,
        beta=beta,
    )


def bnpo_loss_fwd(
    policy_logprobs: torch.Tensor,
    old_policy_logprobs: torch.Tensor,
    ref_logprobs: torch.Tensor,
    advantages: torch.Tensor,
    completions_mask: torch.Tensor,
    epsilon: float = 0.2,
    epsilon_high: float = 0.2,
    beta: float = 0.1,
) -> torch.Tensor:
    """Forward-only BNPO loss. Returns the scalar ``loss``.

    Use for inference / validation. The masked mean denominator is
    computed on-GPU by an atomic accumulator + last-block trick — no
    host ``completions_mask.sum()`` syncs.

    Args:
        policy_logprobs, old_policy_logprobs, ref_logprobs: ``(bs, seq_len)``.
        advantages: ``(bs,)``.
        completions_mask: bool/int8 mask ``(bs, seq_len)``; truthy = valid token.
        epsilon, epsilon_high: PPO-style clipping bounds.
        beta: KL-penalty coefficient. ``0.0`` compiles away the KL branch.

    Returns:
        Scalar tensor (0-dim) with the same dtype as ``policy_logprobs``.
    """
    run = _get_compiled_fwd(
        policy_logprobs.dtype,
        float(epsilon),
        float(epsilon_high),
        float(beta),
    )
    mask_arg = (
        completions_mask
        if completions_mask.dtype == torch.int8
        else completions_mask.to(torch.int8)
    )
    return run(
        policy_logprobs,
        old_policy_logprobs,
        ref_logprobs,
        advantages,
        mask_arg,
    )


def bnpo_loss(
    policy_logprobs: torch.Tensor,
    old_policy_logprobs: torch.Tensor,
    ref_logprobs: torch.Tensor,
    advantages: torch.Tensor,
    completions_mask: torch.Tensor,
    epsilon: float = 0.2,
    epsilon_high: float = 0.2,
    beta: float = 0.1,
) -> tuple[torch.Tensor, torch.Tensor]:
    """Fused fwd+bwd BNPO loss. Returns ``(loss, grad_policy_logprobs)``.

    Single-launch training entry point. The kernel writes both the
    scalar loss and the scaled ``dL/d(policy_logprobs)`` tensor in one
    ``@cute.jit`` dispatch — a bundled mask-sum kernel runs inside the
    same launch so ``inv_total`` is populated on-GPU without a host-side
    ``torch.sum`` round trip.

    Inputs do **not** need ``requires_grad=True``. To chain ``grad``
    into the upstream model that produced ``policy_logprobs``::

        loss, grad = bnpo_loss(policy_logprobs, ..., completions_mask=mask)
        policy_logprobs.backward(grad)
        optimizer.step()

    Args:
        policy_logprobs, old_policy_logprobs, ref_logprobs: ``(bs, seq_len)``.
        advantages: ``(bs,)``.
        completions_mask: bool/int8 mask ``(bs, seq_len)``.
        epsilon, epsilon_high: PPO-style clipping bounds.
        beta: KL-penalty coefficient. ``0.0`` compiles away the KL branch.

    Returns:
        ``(loss, grad_policy_logprobs)`` — ``loss`` is a 0-dim tensor in
        ``policy_logprobs.dtype``; ``grad_policy_logprobs`` has shape
        ``(bs, seq_len)`` and is already scaled by ``1 / n_valid``. The
        gradient tensor is freshly allocated per call (no shared cache),
        so callers may keep it around freely.

    For inference / validation where you only need the loss, use
    :func:`bnpo_loss_fwd` — it skips the dpolicy write entirely and
    computes the mean denominator with the on-GPU last-block trick.
    """
    run = _get_compiled_fwd_bwd(
        policy_logprobs.dtype,
        float(epsilon),
        float(epsilon_high),
        float(beta),
    )
    mask_arg = (
        completions_mask
        if completions_mask.dtype == torch.int8
        else completions_mask.to(torch.int8)
    )
    return run(
        policy_logprobs,
        old_policy_logprobs,
        ref_logprobs,
        advantages,
        mask_arg,
    )


# Imported at the bottom: ``autograd.py`` imports ``bnpo_loss`` from this
# module, so the function must be fully defined before its import runs.
from .autograd import bnpo_loss_autograd  # noqa: E402
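One practical consequence of the `lru_cache` keying above, sketched with the inputs from the package-level example earlier (hypothetical): each distinct `(dtype, epsilon, epsilon_high, beta)` tuple compiles its own kernel once, and `beta=0.0` selects a variant with the KL branch eliminated at trace time.

```python
# Sketch: the first call per (dtype, epsilon, epsilon_high, beta) pays the
# CuteDSL compile; later calls with the same tuple hit the lru_cache.
loss, grad = bnpo_loss(policy, old, ref, adv, mask, beta=0.1)  # compiles
loss, grad = bnpo_loss(policy, old, ref, adv, mask, beta=0.1)  # cached
loss, grad = bnpo_loss(policy, old, ref, adv, mask, beta=0.0)  # new compile,
# and this beta=0.0 variant has the KL branch removed at trace time
```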
build/torch-cuda/bnpo_loss/_torch_ref.py
ADDED
@@ -0,0 +1,56 @@
"""Plain-PyTorch BNPO reference shared between the bench and the tests.

This module is intentionally minimal — every op is a vanilla torch op so
``AOTAutograd`` can derive the joint fwd+bwd graph and Inductor can fuse
both passes (used by ``benchmarks/benchmark_bnpo_loss.py``'s compiled
baseline). The same function is imported by ``tests/test_bnpo_loss.py``
as the correctness reference, so both paths agree on what "the eager
torch implementation of BNPO loss" means.

The underscore-prefixed module name signals "shared internal", not a
public API surface — there's no re-export from the package's top-level
``__init__.py``.
"""

from __future__ import annotations

import torch


def torch_bnpo_loss(
    policy_logprobs: torch.Tensor,
    old_policy_logprobs: torch.Tensor,
    ref_logprobs: torch.Tensor,
    advantages: torch.Tensor,
    completions_mask: torch.Tensor,
    epsilon: float = 0.2,
    epsilon_high: float = 0.2,
    beta: float = 0.1,
) -> torch.Tensor:
    """Plain-Python BNPO reference traceable by AOTAutograd / Inductor.

    Operates in the input dtype throughout (no internal fp32 cast),
    which is what real torch users would write — and what
    ``torch.compile`` competes against in the bench.
    """
    ratio = torch.exp(policy_logprobs - old_policy_logprobs)
    adv = advantages.unsqueeze(1)

    surrogate = ratio * adv
    surrogate_clipped = torch.clamp(ratio, 1.0 - epsilon, 1.0 + epsilon_high) * adv
    policy_loss = -torch.min(surrogate, surrogate_clipped)

    log_ratio_ref = ref_logprobs - policy_logprobs
    kl = torch.exp(log_ratio_ref) - log_ratio_ref - 1.0

    # Cast n_valid to fp32: int64 → fp16 overflows when n_valid > 65504.
    # ``clamp(min=1.0)`` matches TRL's ``mask.sum().clamp(min=1)``: a
    # fully-masked batch produces ``loss=0`` instead of inf/NaN. Mirrors
    # the cute kernel's ``cute.arch.fmax(..., 1.0)`` before ``rcp_approx``
    # in ``cute_bnpo_loss.py``.
    n_valid = completions_mask.sum().to(torch.float32).clamp(min=1.0)
    policy_loss = (policy_loss * completions_mask).sum() / n_valid
    kl = (kl * completions_mask).sum() / n_valid

    loss = policy_loss + beta * kl
    return loss.to(policy_logprobs.dtype)
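A parity-check sketch against this reference (illustrative test code; the CUDA device, shapes, and tolerances are assumptions, tolerances especially, since the kernel uses fast-math `exp2` and `rcp_approx`):

```python
import torch

from geometric_ai_kernels import bnpo_loss
from geometric_ai_kernels.bnpo_loss._torch_ref import torch_bnpo_loss  # internal, test-only

policy = torch.randn(4, 1024, device="cuda", requires_grad=True)
old = torch.randn(4, 1024, device="cuda")
ref = torch.randn(4, 1024, device="cuda")
adv = torch.randn(4, device="cuda")
mask = (torch.rand(4, 1024, device="cuda") > 0.1).to(torch.int8)

# Eager reference: loss plus gradient via autograd.
ref_loss = torch_bnpo_loss(policy, old, ref, adv, mask)
(ref_grad,) = torch.autograd.grad(ref_loss, policy)

# Fused kernel: loss and gradient from one launch.
kernel_loss, kernel_grad = bnpo_loss(policy.detach(), old, ref, adv, mask)
torch.testing.assert_close(kernel_loss, ref_loss, rtol=1e-4, atol=1e-5)
torch.testing.assert_close(kernel_grad, ref_grad, rtol=1e-4, atol=1e-5)
```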
build/torch-cuda/bnpo_loss/autograd.py
ADDED
@@ -0,0 +1,149 @@
"""Autograd-aware wrapper for BNPO loss via ``torch.library.custom_op``.

The fused cute kernel writes both the scalar loss and the closed-form
``dL/d(policy_logprobs)`` in one launch. This module wraps that into an
autograd-compatible op so callers can write::

    loss = bnpo_loss_autograd(policy, old, ref, adv, completions_mask)
    loss.backward()  # propagates through to the upstream model

instead of the manual ``policy.backward(grad)`` chain. The cost is
~12µs of autograd dispatcher overhead per call (vs the direct
``bnpo_loss`` ``(loss, grad)`` tuple); for ergonomic / kernelize() flows
that's cheap, but for tight microbenches use the direct path.

Implementation notes:

- The registered op returns ``(loss, dpolicy)`` so ``setup_context`` can
  ``save_for_backward(dpolicy)``. The public ``bnpo_loss_autograd``
  wrapper hides the second output.
- ``dpolicy`` is allocated fresh by the runner on every call (no shared
  cache), so ``ctx.save_for_backward(dpolicy)`` keeps a stable reference
  across subsequent calls without any extra copy.
- Backward returns ``grad_loss * dpolicy``. Under ``torch.compile``,
  when ``loss`` is consumed by ``.backward()`` directly, ``grad_loss``
  is the constant 1.0 and Inductor can fold the multiply away — that's
  the main reason this path uses ``custom_op`` instead of a plain
  ``autograd.Function``.
- ``register_fake`` provides the meta kernel for ``torch.compile`` shape
  propagation; the real cute kernel never runs under ``FakeTensorMode``.
"""

from __future__ import annotations

import torch

from . import bnpo_loss as _bnpo_loss_fwd_bwd

__all__ = ["bnpo_loss_autograd"]


@torch.library.custom_op(
    "geometric_ai_kernels::_bnpo_loss_with_grad",
    mutates_args=(),
)
def _bnpo_loss_with_grad(
    policy_logprobs: torch.Tensor,
    old_policy_logprobs: torch.Tensor,
    ref_logprobs: torch.Tensor,
    advantages: torch.Tensor,
    completions_mask: torch.Tensor,
    epsilon: float,
    epsilon_high: float,
    beta: float,
) -> tuple[torch.Tensor, torch.Tensor]:
    loss, dpolicy = _bnpo_loss_fwd_bwd(
        policy_logprobs,
        old_policy_logprobs,
        ref_logprobs,
        advantages,
        completions_mask,
        epsilon=epsilon,
        epsilon_high=epsilon_high,
        beta=beta,
    )
    return loss, dpolicy


@_bnpo_loss_with_grad.register_fake
def _(
    policy_logprobs: torch.Tensor,
    old_policy_logprobs: torch.Tensor,
    ref_logprobs: torch.Tensor,
    advantages: torch.Tensor,
    completions_mask: torch.Tensor,
    epsilon: float,
    epsilon_high: float,
    beta: float,
) -> tuple[torch.Tensor, torch.Tensor]:
    # Signature must mirror the op; only ``policy_logprobs`` shapes the outputs.
    del old_policy_logprobs, ref_logprobs, advantages, completions_mask
    del epsilon, epsilon_high, beta
    loss = policy_logprobs.new_empty(())
    dpolicy = torch.empty_like(policy_logprobs)
    return loss, dpolicy


def _setup_context(ctx, inputs, output) -> None:  # type: ignore[no-untyped-def]
    del inputs  # only ``output`` carries what we need to save.
    _, dpolicy = output
    ctx.save_for_backward(dpolicy)


def _backward(ctx, grad_loss, grad_dpolicy):  # type: ignore[no-untyped-def]
    # ``grad_dpolicy`` is unused — ``dpolicy`` is an internal intermediate
    # exposed only so ``setup_context`` can save it. Under typical usage
    # (``loss.backward()``) it arrives as ``None`` or a zero tensor.
    del grad_dpolicy
    (dpolicy,) = ctx.saved_tensors
    grad_policy = grad_loss * dpolicy
    # One return per input to the op (8): policy_logprobs gets the grad,
    # everything else gets None (no autograd flow).
    return grad_policy, None, None, None, None, None, None, None


torch.library.register_autograd(
    "geometric_ai_kernels::_bnpo_loss_with_grad",
    _backward,
    setup_context=_setup_context,
)


def bnpo_loss_autograd(
    policy_logprobs: torch.Tensor,
    old_policy_logprobs: torch.Tensor,
    ref_logprobs: torch.Tensor,
    advantages: torch.Tensor,
    completions_mask: torch.Tensor,
    epsilon: float = 0.2,
    epsilon_high: float = 0.2,
    beta: float = 0.1,
) -> torch.Tensor:
    """Autograd-aware BNPO loss. Returns scalar ``loss``.

    Same numerics as :func:`bnpo_loss` but registered as a
    ``torch.library`` custom op with autograd, so::

        loss = bnpo_loss_autograd(policy, ..., completions_mask)
        loss.backward()

    propagates through to whatever produced ``policy_logprobs``. For
    direct ``(loss, grad)`` access without the autograd dispatcher
    overhead, use :func:`bnpo_loss` and chain the gradient manually
    via ``policy_logprobs.backward(grad)``.

    Composes with ``torch.compile``: the op is opaque to Inductor but
    has a fake/meta kernel registered, so models containing this layer
    can be compiled end-to-end without graph breaks.
    """
    loss, _ = _bnpo_loss_with_grad(
        policy_logprobs,
        old_policy_logprobs,
        ref_logprobs,
        advantages,
        completions_mask,
        float(epsilon),
        float(epsilon_high),
        float(beta),
    )
    return loss
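A sketch of the `torch.compile` composition the docstring claims (illustrative: the toy head, shapes, and the `fullgraph=True` call are assumptions):

```python
import torch

from geometric_ai_kernels import bnpo_loss_autograd

# Toy head producing per-token logprobs; only its parameters receive grads.
head = torch.nn.Linear(64, 1, device="cuda")

def step(x, old, ref, adv, mask):
    policy_logprobs = head(x).squeeze(-1)  # (bs, seq_len)
    return bnpo_loss_autograd(policy_logprobs, old, ref, adv, mask)

# fullgraph=True relies on the registered fake kernel: the custom op is
# opaque to Inductor but traceable, so no graph break at the call site.
compiled_step = torch.compile(step, fullgraph=True)
loss = compiled_step(
    torch.randn(4, 1024, 64, device="cuda"),
    torch.randn(4, 1024, device="cuda"),
    torch.randn(4, 1024, device="cuda"),
    torch.randn(4, device="cuda"),
    torch.ones(4, 1024, dtype=torch.int8, device="cuda"),
)
loss.backward()
```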
build/torch-cuda/bnpo_loss/cute_bnpo_loss.py
ADDED
@@ -0,0 +1,1081 @@
"""CuteDSL kernel for BNPO loss.

Computes (element-wise over ``(bs, seq_len)`` logprob tensors, reduced to a
scalar):

    ratio = exp(policy - old_policy)
    surrogate = ratio * adv
    clipped = clip(ratio, 1 - eps, 1 + eps_high) * adv
    policy_loss = -min(surrogate, clipped)
    log_ratio_ref = ref - policy
    kl = exp(log_ratio_ref) - log_ratio_ref - 1
    L_bnpo = (policy_loss * mask).sum() / n_valid
             + beta * (kl * mask).sum() / n_valid

where ``n_valid = max(completions_mask.sum(), 1)``. The mean denominator is
computed entirely on-GPU — the forward-only path uses an atomic accumulator
+ last-block trick on ``valid_acc``; the fused fwd+bwd path bundles a small
companion mask-sum kernel into the same ``@cute.jit`` launch that writes
``1 / completions_mask.sum()`` into the ``inv_total`` GMEM scalar before the
main kernel reads it. Every block needs ``inv_total`` mid-loop to scale its
``dpolicy`` slab, so the fwd-only last-block trick doesn't compose with
backward; bundling the mask-sum keeps both paths host-sync-free and
CUDA-graph compatible.

When ``beta=0`` the KL term is skipped at compile time (no ``ref`` tensor
access, no ``kl_acc`` atomic add).

Sequence lengths that are **not** a multiple of ``TILE_N`` are handled
natively: the grid launches ``ceil(seq_len / TILE_N)`` column tiles; full
tiles use the vectorized ``LDG.128`` path and the tail tile uses predicated
vector loads with neutral prefill.

Two compiled-kernel flavors are exposed:

* :func:`create_compiled_bnpo_loss` — forward-only.
* :func:`create_compiled_bnpo_loss_with_backward` — fused fwd+bwd. Returns
  ``(loss, dpolicy)`` directly — no ``torch.autograd.Function`` wrapper. The
  autograd-aware sibling lives in ``autograd.py`` and uses
  ``torch.library.custom_op`` instead.

Per-call output (``loss``, ``dpolicy``, ``inv_total``) is allocated inside
the runner. Cross-CTA scratch (atomic accumulators + counters) is allocated
lazily on first call inside the compiled-kernel closure and self-resets each
launch via the kernel's last-block epilogue + ``atom.inc.u32`` wrap-around —
callers don't manage scratch state.
"""

from __future__ import annotations

import math
import operator
from typing import TYPE_CHECKING, Any
from typing import cast as _typing_cast

import cutlass
import cutlass.utils
import torch
from cutlass import cute
from cutlass._mlir.dialects import llvm
from cutlass.base_dsl.typing import cast
from cutlass.cutlass_dsl import T, dsl_user_op

if TYPE_CHECKING:
    from collections.abc import Callable


TILE_N: int = 512
NUM_WARPS: int = 4
# ``VEC=4`` (fp32) emits 128-bit ``LDG.128``. Pairs with ``NUM_WARPS=4`` so
# each block processes ``block_size * VEC = 512 = TILE_N`` elements per iter.
VEC: int = 4
# Large-tile variant: at very long ``seq_len`` the small-TILE_N grid
# explodes (e.g. 8192/512 = 16 col-tiles per row → thousands of CTAs),
# inflating last-block-detection latency and atomic contention. A second
# compiled variant with this larger tile is dispatched when
# ``seq_len >= TILE_N_LARGE_THRESHOLD``.
TILE_N_LARGE: int = 4096
TILE_N_LARGE_THRESHOLD: int = 2048

_LOG2_E: float = math.log2(math.e)

_TORCH_TO_CUTLASS_DTYPE: dict[torch.dtype, Any] = {
    torch.float32: cutlass.Float32,
    torch.float16: cutlass.Float16,
    torch.bfloat16: cutlass.BFloat16,
}
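Aside: to make the tiling constants concrete, the launch arithmetic in plain Python (illustrative numbers; the variable names mirror the kernel parameters):

```python
# Example: bs=8, seq_len=1000 with the small tile.
import math

bs, seq_len, tile_n = 8, 1000, 512            # TILE_N = 512
block_size = 4 * 32                            # NUM_WARPS * 32 = 128 threads
elems_per_iter = block_size * 4                # * VEC = 512 = tile_n -> 1 iter

num_col_tiles = math.ceil(seq_len / tile_n)    # 2 column tiles per row
num_full_tiles = seq_len // tile_n             # 1 full tile (LDG.128 path)
tail_len = seq_len - num_full_tiles * tile_n   # 488 -> predicated tail tile
total_blocks = bs * num_col_tiles              # 16 CTAs in the grid
print(num_full_tiles, tail_len, total_blocks)  # 1 488 16
```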
@dsl_user_op
def _atomic_add_f32_gmem(
    ptr_i64: Any,
    val: cutlass.Float32,
    *,
    loc: Any = None,
    ip: Any = None,
) -> None:
    llvm.inline_asm(
        T.f32(),
        [ptr_i64, cutlass.Float32(val).ir_value(loc=loc, ip=ip)],
        "atom.global.add.f32 $0, [$1], $2;",
        "=f,l,f",
        has_side_effects=True,
        is_align_stack=False,
        asm_dialect=llvm.AsmDialect.AD_ATT,
    )


@dsl_user_op
def _atomic_add_s32_gmem(
    ptr_i64: Any,
    val: cutlass.Int32,
    *,
    loc: Any = None,
    ip: Any = None,
) -> None:
    """Emit ``atom.global.add.s32`` to a 64-bit GMEM address."""
    llvm.inline_asm(
        T.i32(),
        [ptr_i64, cutlass.Int32(val).ir_value(loc=loc, ip=ip)],
        "atom.global.add.s32 $0, [$1], $2;",
        "=r,l,r",
        has_side_effects=True,
        is_align_stack=False,
        asm_dialect=llvm.AsmDialect.AD_ATT,
    )


@dsl_user_op
def _dp4a_u32_acc_s32(
    packed_a: cutlass.Uint32,
    packed_b: cutlass.Uint32,
    acc: cutlass.Int32,
    *,
    loc: Any = None,
    ip: Any = None,
) -> cutlass.Int32:
    """``dp4a.u32.u32`` — sum 4 packed u8 products into an s32 acc.

    Computes ``a[0]*b[0] + a[1]*b[1] + a[2]*b[2] + a[3]*b[3] + acc`` in
    one ``IDP4A.U8.S32`` instruction (full-rate on Hopper/Blackwell).
    For mask summation, pass ``packed_b = 0x01010101`` so the products
    reduce to ``sum(a_bytes) + acc`` — 4× fewer ALU ops than 4 separate
    int8→int32 widens + adds.
    """
    return cutlass.Int32(
        llvm.inline_asm(
            T.i32(),
            [
                cutlass.Uint32(packed_a).ir_value(loc=loc, ip=ip),
                cutlass.Uint32(packed_b).ir_value(loc=loc, ip=ip),
                cutlass.Int32(acc).ir_value(loc=loc, ip=ip),
            ],
            "dp4a.u32.u32 $0, $1, $2, $3;",
            "=r,r,r,r",
            has_side_effects=False,
            is_align_stack=False,
            asm_dialect=llvm.AsmDialect.AD_ATT,
        )
    )


@dsl_user_op
def _atomic_inc_u32_gmem(
    ptr_i64: Any,
    threshold: cutlass.Int32,
    *,
    loc: Any = None,
    ip: Any = None,
) -> cutlass.Int32:
    """``atom.global.inc.u32`` — returns old value; wraps to 0 at threshold."""
    return cutlass.Int32(
        llvm.inline_asm(
            T.i32(),
            [ptr_i64, cutlass.Int32(threshold).ir_value(loc=loc, ip=ip)],
            "atom.global.inc.u32 $0, [$1], $2;",
            "=r,l,r",
            has_side_effects=True,
            is_align_stack=False,
            asm_dialect=llvm.AsmDialect.AD_ATT,
        )
    )
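Aside: a plain-Python simulation of the `dp4a` mask-sum trick described above (illustrative only, not part of the upload):

```python
# dp4a(a, b, acc): treat each u32 as four u8 lanes, dot them, add acc.
def dp4a_u32(a: int, b: int, acc: int) -> int:
    lanes_a = [(a >> (8 * i)) & 0xFF for i in range(4)]
    lanes_b = [(b >> (8 * i)) & 0xFF for i in range(4)]
    return acc + sum(x * y for x, y in zip(lanes_a, lanes_b))

# Four int8 mask bytes {0,1} packed little-endian, e.g. [1, 0, 1, 1]:
packed_mask = 1 | (0 << 8) | (1 << 16) | (1 << 24)
assert dp4a_u32(packed_mask, 0x01010101, acc=0) == 3  # == sum of the bytes
```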
# ---------------------------------------------------------------------------
# Mask-sum kernel — replaces ``torch.sum(completions_mask)`` on the fwd+bwd
# path. Bundled into the same ``@cute.jit`` launch as the main kernel so the
# whole step is one tvm-ffi dispatch (no extra Python/torch dispatcher round
# trip). The kernel writes ``1 / completions_mask.sum()`` directly into
# ``inv_total_tensor`` so the main kernel reads it as a pre-inverted scalar.
# ---------------------------------------------------------------------------


def _make_mask_sum_kernel(tile_n: int) -> Callable[..., None]:
    """Return a ``@cute.kernel`` that reduces ``completions_mask`` and writes 1/sum.

    Grid mirrors the main kernel — ``(bs, num_col_tiles)`` — so the mask is
    read once with the same vectorised LDG pattern as the main compute.
    Each block:

    1. Loads its ``tile_n`` int8 slab of ``completions_mask`` (predicated tail).
    2. Reduces to a per-block ``int32`` scalar (bit-exact, no per-element
       i8→f32 cast — IADD throughput equals FADD on Hopper/Blackwell).
    3. Atomically adds it to ``valid_acc`` (global int32 accumulator).
    4. Increments ``mask_counter``; the last block reads ``valid_acc``,
       casts to fp32, computes ``rcp_approx`` and writes
       ``inv_total_tensor[0]``, then resets ``valid_acc`` to ``0`` so
       the next call starts fresh. The counter self-resets via
       ``atom.inc.u32`` wrap-around.

    A separate ``mask_counter`` tensor (not the main kernel's ``counter``)
    is required because the two kernels run in series within the same
    ``@cute.jit`` and both rely on a wrap-around for self-reset; sharing
    one counter would race.
    """

    @cute.kernel
    def _mask_sum_kernel(
        completions_mask: cute.Tensor,  # (bs, seq_len) int8
        inv_total_tensor: cute.Tensor,  # (1,) fp32 — output
        valid_acc: cute.Tensor,  # (1,) int32 — accumulator
        mask_counter: cute.Tensor,  # (1,) i32 — last-block detection
        total_blocks: cutlass.Int32,
        num_full_tiles: cutlass.Int32,
        tail_len: cutlass.Int32,
    ) -> None:
        block_size = NUM_WARPS * 32
        iters = tile_n // (block_size * VEC)

        _no_alloc = cute.nvgpu.CacheEvictionPriority.NO_ALLOCATE
        g2r_op = cute.nvgpu.CopyUniversalOp()
        g2r_mask_atom = cute.make_copy_atom(
            g2r_op,
            completions_mask.element_type,
            num_bits_per_copy=0,
            l1c_evict_priority=_no_alloc,
        )

        row = cute.arch.block_idx()[0]
        col_block = cute.arch.block_idx()[1]
        tid = cute.arch.thread_idx()[0]

        local_valid_sum = cutlass.Int32(0)
        mask_row = cute.slice_(completions_mask, (row, None))

        # ``dp4a.u32.u32`` consumes a packed-u8x4 register. With VEC=4 each
        # thread loads 4 contiguous int8 bytes per iteration, so we recast
        # the fragment as a single ``Uint32`` view and feed it directly
        # into dp4a — one instruction sums all 4 bytes, vs the previous
        # cast+reduce which emitted 4 widens + 3 adds per iteration.
        ones_packed = cutlass.Uint32(0x01010101)

        if col_block < num_full_tiles:
            mask_slab = cute.local_tile(mask_row, (tile_n,), (col_block,))
            for k in cutlass.range(iters, unroll_full=True):
                sub_idx = tid + k * block_size
                mask_src = cute.local_tile(mask_slab, (VEC,), (sub_idx,))
                mask_frag = cute.make_fragment_like(mask_src)
                cute.copy(g2r_mask_atom, mask_src, mask_frag)
                packed = cute.recast_tensor(mask_frag, cutlass.Uint32)[0]
                local_valid_sum = _dp4a_u32_acc_s32(packed, ones_packed, local_valid_sum)
        else:
            mask_slab = cute.local_tile(mask_row, (tile_n,), (col_block,))
            for k in cutlass.range(iters, unroll_full=True):
                sub_idx = tid + k * block_size
                chunk_base = sub_idx * VEC
                if chunk_base < tail_len:
                    mask_src = cute.local_tile(mask_slab, (VEC,), (sub_idx,))
                    pred = cute.make_rmem_tensor(mask_src.shape, cutlass.Boolean)
                    for v in cutlass.range(VEC, unroll_full=True):
                        pred[v] = cute.elem_less(chunk_base + v, tail_len)
                    mask_frag = cute.make_fragment_like(mask_src)
                    mask_frag.fill(0)
                    cute.copy(g2r_mask_atom, mask_src, mask_frag, pred=pred)
                    packed = cute.recast_tensor(mask_frag, cutlass.Uint32)[0]
                    local_valid_sum = _dp4a_u32_acc_s32(packed, ones_packed, local_valid_sum)

        # Warp + cross-warp reduction (same pattern as main kernel).
        warp_valid = cute.arch.warp_reduction(local_valid_sum, operator.add)
        smem = cutlass.utils.SmemAllocator()
        buf_valid = smem.allocate_tensor(cutlass.Int32, cute.make_layout(NUM_WARPS))

        lane_idx = cute.arch.lane_idx()
        warp_idx = cute.arch.warp_idx()

        if lane_idx == 0:
            buf_valid[warp_idx] = warp_valid
        cute.arch.barrier()

        if warp_idx == 0:
            val_v = cutlass.Int32(0)
            if lane_idx < NUM_WARPS:
                val_v = buf_valid[lane_idx]
            block_valid = cute.arch.warp_reduction(val_v, operator.add, threads_in_group=NUM_WARPS)

            if lane_idx == 0:
                valid_ptr = valid_acc.iterator.toint().ir_value()  # ty: ignore[unresolved-attribute]
                counter_ptr = mask_counter.iterator.toint().ir_value()  # ty: ignore[unresolved-attribute]

                _atomic_add_s32_gmem(valid_ptr, block_valid)
                cute.arch.fence_acq_rel_gpu()
                old = _atomic_inc_u32_gmem(counter_ptr, total_blocks - 1)

                if old == total_blocks - 1:
                    # Clamp to >=1.0 so a fully-masked batch (n_valid=0)
                    # produces ``loss=0`` instead of inf/NaN — matches
                    # TRL's ``mask.sum().clamp(min=1)`` semantics.
                    n_valid = cute.arch.fmax(cutlass.Float32(valid_acc[0]), cutlass.Float32(1.0))
                    inv_total_tensor[0] = cute.arch.rcp_approx(n_valid)
                    valid_acc[0] = cutlass.Int32(0)

    return _mask_sum_kernel
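Aside: the self-reset in step 4 leans on PTX `atom.inc.u32` semantics (`old = *ptr; *ptr = (old >= threshold) ? 0 : old + 1; return old`). A plain-Python simulation of why exactly one block sees "last" and the counter is ready for the next launch (illustrative):

```python
def atom_inc_u32(state: dict, key: str, threshold: int) -> int:
    old = state[key]
    state[key] = 0 if old >= threshold else old + 1
    return old

state = {"mask_counter": 0}
total_blocks = 16
arrivals = [atom_inc_u32(state, "mask_counter", total_blocks - 1)
            for _ in range(total_blocks)]
assert arrivals[-1] == total_blocks - 1  # exactly one block sees "last"
assert state["mask_counter"] == 0        # wrapped: ready for the next launch
```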
def _make_bnpo_kernel(
    compute_kl: bool,
    compute_backward: bool,
    tile_n: int,
) -> Callable[..., None]:
    """Return a ``@cute.kernel`` specialised on compile-time flags.

    The returned kernel captures *compute_kl*, *compute_backward*, and
    *tile_n* in its closure. ``cutlass.const_expr`` evaluates the booleans
    at trace time so dead branches are eliminated from the compiled PTX.
    ``tile_n`` is a Python ``int`` captured at trace time, so the same
    factory can emit two specialised kernels (small / large tile) — see
    :func:`create_compiled_bnpo_loss` for dispatch.

    When *compute_backward* is True the kernel additionally writes
    ``dpolicy = dL/d(policy_logprobs)`` to GMEM in the same inner loop —
    no extra HBM reads of the inputs. Because every block must scale
    ``dpolicy`` by ``inv_total`` mid-loop, the on-GPU last-block computation
    of ``inv_total`` from the masked accumulator does **not** compose with
    backward; the bundled mask-sum kernel populates ``inv_total_tensor``
    before the main kernel runs.

    When *compute_backward* is False the kernel accumulates the
    mask-element count via ``valid_acc`` and computes
    ``inv_total = 1 / n_valid`` on-GPU in the last-block path — no
    host-side ``completions_mask.sum()`` required.
    """

    @cute.kernel
    def _bnpo_loss_kernel(
        policy: cute.Tensor,
        old_policy: cute.Tensor,
        ref: cute.Tensor,
        advantages: cute.Tensor,
        completions_mask: cute.Tensor,
        dpolicy: cute.Tensor,  # (bs, seq_len) when compute_backward; (bs, 1) dummy otherwise
        inv_total_tensor: cute.Tensor,  # (1,) fp32 — caller-populated 1/n_valid
        policy_acc: cute.Tensor,
        kl_acc: cute.Tensor,
        valid_acc: cute.Tensor,  # (1,) int32 — mask-element count accumulator
        counter: cute.Tensor,
        output: cute.Tensor,
        epsilon: cutlass.Float32,
        epsilon_high: cutlass.Float32,
        beta: cutlass.Float32,
        total_blocks: cutlass.Int32,
        num_full_tiles: cutlass.Int32,
        tail_len: cutlass.Int32,
    ) -> None:
        block_size = NUM_WARPS * 32
        iters = tile_n // (block_size * VEC)

        # Read inv_total from GMEM once per block (hoisted, single load).
        # Skipped on the fwd-only path, which uses an on-GPU last-block
        # computation from the valid_acc accumulator instead. On the
        # compute_backward path the bundled mask-sum kernel writes
        # ``1 / completions_mask.sum()`` into ``inv_total_tensor`` before
        # this kernel runs, so the load returns the pre-inverted scalar.
        accumulate_valid = not compute_backward
        if cutlass.const_expr(not accumulate_valid):
            inv_total = cast(inv_total_tensor[0], cutlass.Float32)

        _no_alloc = cute.nvgpu.CacheEvictionPriority.NO_ALLOCATE
        g2r_op = cute.nvgpu.CopyUniversalOp()
        g2r_atom = cute.make_copy_atom(
            g2r_op,
            policy.element_type,
            num_bits_per_copy=0,
            l1c_evict_priority=_no_alloc,
        )
        g2r_mask_atom = cute.make_copy_atom(
            g2r_op,
            completions_mask.element_type,
            num_bits_per_copy=0,
            l1c_evict_priority=_no_alloc,
        )
        if cutlass.const_expr(compute_backward):
            r2g_atom = cute.make_copy_atom(
                g2r_op,
                dpolicy.element_type,
                num_bits_per_copy=0,
            )

        row = cute.arch.block_idx()[0]
        col_block = cute.arch.block_idx()[1]
        tid = cute.arch.thread_idx()[0]

        adv = cast(advantages[row], cutlass.Float32)
        lo = cutlass.Float32(1.0) - epsilon
        hi = cutlass.Float32(1.0) + epsilon_high

        local_policy_sum = cutlass.Float32(0.0)
        local_kl_sum = cutlass.Float32(0.0)
        # mask_vec is already cast to fp32 for loss/kl multiplications, so
        # accumulate valid in fp32 too (avoids a separate i8→i32 reduction).
        # Cast to int32 only at the atomic boundary so the shared
        # ``valid_acc`` global can remain int32 — see ``_atomic_add_s32_gmem``.
        local_valid_sum = cutlass.Float32(0.0)

        pol_row = cute.slice_(policy, (row, None))
        old_row = cute.slice_(old_policy, (row, None))

        if cutlass.const_expr(compute_kl):
            ref_row = cute.slice_(ref, (row, None))

        mask_row = cute.slice_(completions_mask, (row, None))

        if cutlass.const_expr(compute_backward):
            dp_row = cute.slice_(dpolicy, (row, None))

        # ---- Full-tile vectorised path (LDG.128) ----
        if col_block < num_full_tiles:
            pol_slab = cute.local_tile(pol_row, (tile_n,), (col_block,))
            old_slab = cute.local_tile(old_row, (tile_n,), (col_block,))

            if cutlass.const_expr(compute_kl):
                ref_slab = cute.local_tile(ref_row, (tile_n,), (col_block,))

            mask_slab = cute.local_tile(mask_row, (tile_n,), (col_block,))

            if cutlass.const_expr(compute_backward):
                dp_slab = cute.local_tile(dp_row, (tile_n,), (col_block,))

            for k in cutlass.range(iters, unroll_full=True):
                sub_idx = tid + k * block_size

                pol_src = cute.local_tile(pol_slab, (VEC,), (sub_idx,))
                old_src = cute.local_tile(old_slab, (VEC,), (sub_idx,))
                pol_frag = cute.make_fragment_like(pol_src)
                old_frag = cute.make_fragment_like(old_src)
                cute.copy(g2r_atom, pol_src, pol_frag)
                cute.copy(g2r_atom, old_src, old_frag)

                if cutlass.const_expr(compute_kl):
                    ref_src = cute.local_tile(ref_slab, (VEC,), (sub_idx,))
                    ref_frag = cute.make_fragment_like(ref_src)
                    cute.copy(g2r_atom, ref_src, ref_frag)

                mask_src = cute.local_tile(mask_slab, (VEC,), (sub_idx,))
                mask_frag = cute.make_fragment_like(mask_src)
                cute.copy(g2r_mask_atom, mask_src, mask_frag)

                pol_vec = pol_frag.load().to(cutlass.Float32)
                old_vec = old_frag.load().to(cutlass.Float32)

                log_ratio = pol_vec - old_vec
                ratio = cute.math.exp2(log_ratio * _LOG2_E, fastmath=True)
                surrogate = ratio * adv
                clipped_ratio = cute.where(
                    ratio < lo,
                    lo,
                    cute.where(ratio > hi, hi, ratio),
                )
                clipped = clipped_ratio * adv
                policy_loss = -cute.where(surrogate < clipped, surrogate, clipped)

                if cutlass.const_expr(compute_kl):
                    ref_vec = ref_frag.load().to(cutlass.Float32)
                    log_ratio_ref = ref_vec - pol_vec
                    ratio_ref = cute.math.exp2(log_ratio_ref * _LOG2_E, fastmath=True)
                    # FFMA-friendly rearrangement: ``(ratio_ref - 1) - log_ratio_ref``
                    # exposes a ``ratio_ref + (-1)`` pair that ptxas folds with
                    # the subsequent subtract — same arithmetic, fewer FADDs
                    # surviving in SASS than the original 3-term ``a - b - c``.
                    kl_val = (ratio_ref - cutlass.Float32(1.0)) - log_ratio_ref

                mask_vec = mask_frag.load().to(cutlass.Float32)
                local_policy_sum += (policy_loss * mask_vec).reduce(
                    cute.ReductionOp.ADD,
                    cutlass.Float32(0.0),
                    reduction_profile=0,
                )
                if cutlass.const_expr(not compute_backward):
                    local_valid_sum += mask_vec.reduce(
                        cute.ReductionOp.ADD,
                        cutlass.Float32(0.0),
                        reduction_profile=0,
                    )
                if cutlass.const_expr(compute_kl):
                    local_kl_sum += (kl_val * mask_vec).reduce(
                        cute.ReductionOp.ADD,
                        cutlass.Float32(0.0),
                        reduction_profile=0,
                    )

                # ---- Backward: write scaled dpolicy slab in same loop ----
                # use_unclipped = (surrogate <= clipped) — matches torch's
                # convention. d/d(policy) of -min(surrogate, clipped) is
                # -adv*ratio when use_unclipped, else 0 (clamp grad = 0).
                # ``-(adv * ratio)`` is just ``-surrogate`` (already in
                # scope) — saves one FMUL per element.
                # KL term: d/d(policy) of (ratio_ref - log_ratio_ref - 1)
                # = -(ratio_ref - 1) = 1 - ratio_ref.
                if cutlass.const_expr(compute_backward):
                    neg_surrogate_grad = cute.where(
                        surrogate <= clipped,
                        -surrogate,
                        cutlass.Float32(0.0),
                    )
                    if cutlass.const_expr(compute_kl):
                        # ``beta - beta*ratio_ref`` instead of ``beta*(1 - ratio_ref)``
                        # gives ptxas an obvious FFMA pattern (``FFMA -beta,
                        # ratio_ref, beta``) — saves one FMUL per element vs
                        # the (1 - ratio_ref) intermediate.
                        kl_grad = beta - beta * ratio_ref
                        dpolicy_vec = neg_surrogate_grad + kl_grad
                    else:
                        dpolicy_vec = neg_surrogate_grad
                    dpolicy_vec = dpolicy_vec * mask_vec
                    dpolicy_vec = dpolicy_vec * inv_total

                    dp_dst = cute.local_tile(dp_slab, (VEC,), (sub_idx,))
                    dp_frag = cute.make_fragment_like(dp_dst)
                    dp_frag.store(dpolicy_vec.to(dpolicy.element_type))
                    cute.copy(r2g_atom, dp_frag, dp_dst)

        else:
            # ---- Predicated vector tail path (< tile_n valid elements) ----
            pol_slab = cute.local_tile(pol_row, (tile_n,), (col_block,))
            old_slab = cute.local_tile(old_row, (tile_n,), (col_block,))

            if cutlass.const_expr(compute_kl):
                ref_slab = cute.local_tile(ref_row, (tile_n,), (col_block,))

            mask_slab = cute.local_tile(mask_row, (tile_n,), (col_block,))

            if cutlass.const_expr(compute_backward):
                dp_slab = cute.local_tile(dp_row, (tile_n,), (col_block,))

            for k in cutlass.range(iters, unroll_full=True):
                sub_idx = tid + k * block_size
                chunk_base = sub_idx * VEC

                if chunk_base < tail_len:
                    pol_src = cute.local_tile(pol_slab, (VEC,), (sub_idx,))
                    old_src = cute.local_tile(old_slab, (VEC,), (sub_idx,))
                    pred = cute.make_rmem_tensor(pol_src.shape, cutlass.Boolean)
                    for v in cutlass.range(VEC, unroll_full=True):
                        pred[v] = cute.elem_less(chunk_base + v, tail_len)

                    pol_frag = cute.make_fragment_like(pol_src)
                    old_frag = cute.make_fragment_like(old_src)
                    pol_frag.fill(0.0)
                    old_frag.fill(0.0)
                    cute.copy(g2r_atom, pol_src, pol_frag, pred=pred)
                    cute.copy(g2r_atom, old_src, old_frag, pred=pred)

                    if cutlass.const_expr(compute_kl):
                        ref_src = cute.local_tile(ref_slab, (VEC,), (sub_idx,))
                        ref_frag = cute.make_fragment_like(ref_src)
                        ref_frag.fill(0.0)
                        cute.copy(g2r_atom, ref_src, ref_frag, pred=pred)

                    mask_src = cute.local_tile(mask_slab, (VEC,), (sub_idx,))
                    mask_frag = cute.make_fragment_like(mask_src)
                    mask_frag.fill(0)
                    cute.copy(g2r_mask_atom, mask_src, mask_frag, pred=pred)

                    pol_vec = pol_frag.load().to(cutlass.Float32)
                    old_vec = old_frag.load().to(cutlass.Float32)
                    valid_vec = cute.where(
                        pred.load(),
                        cute.full_like(pol_vec, cutlass.Float32(1.0)),
                        cute.zeros_like(pol_vec, dtype=cutlass.Float32),
                    )

                    log_ratio = pol_vec - old_vec
                    ratio = cute.math.exp2(log_ratio * _LOG2_E, fastmath=True)
                    surrogate = ratio * adv
                    clipped_ratio = cute.where(
                        ratio < lo,
                        lo,
                        cute.where(ratio > hi, hi, ratio),
                    )
                    clipped = clipped_ratio * adv
                    policy_loss = -cute.where(surrogate < clipped, surrogate, clipped)

                    if cutlass.const_expr(compute_kl):
                        ref_vec = ref_frag.load().to(cutlass.Float32)
                        log_ratio_ref = ref_vec - pol_vec
                        ratio_ref = cute.math.exp2(log_ratio_ref * _LOG2_E, fastmath=True)
                        # FFMA-friendly rearrangement — see full-tile path.
                        kl_val = (ratio_ref - cutlass.Float32(1.0)) - log_ratio_ref

                    mask_vec = mask_frag.load().to(cutlass.Float32) * valid_vec
                    local_policy_sum += (policy_loss * mask_vec).reduce(
                        cute.ReductionOp.ADD,
                        cutlass.Float32(0.0),
                        reduction_profile=0,
                    )
                    if cutlass.const_expr(not compute_backward):
                        local_valid_sum += mask_vec.reduce(
                            cute.ReductionOp.ADD,
                            cutlass.Float32(0.0),
                            reduction_profile=0,
                        )
                    if cutlass.const_expr(compute_kl):
                        local_kl_sum += (kl_val * mask_vec).reduce(
                            cute.ReductionOp.ADD,
                            cutlass.Float32(0.0),
                            reduction_profile=0,
                        )

                    # ---- Backward: predicated dpolicy slab write ----
                    # Same gradient math as the full-tile path. ``valid_vec``
                    # already encodes the in-bounds predicate (1.0 inside,
                    # 0.0 outside) and is folded into ``mask_vec``, so
                    # multiplying by it zeros out the padded positions.
                    if cutlass.const_expr(compute_backward):
                        neg_surrogate_grad = cute.where(

[diff truncated: lines 624-1081 of cute_bnpo_loss.py are not shown in this view]
+
surrogate <= clipped,
|
| 625 |
+
-surrogate,
|
| 626 |
+
cutlass.Float32(0.0),
|
| 627 |
+
)
|
| 628 |
+
if cutlass.const_expr(compute_kl):
|
| 629 |
+
kl_grad = beta - beta * ratio_ref
|
| 630 |
+
dpolicy_vec = neg_surrogate_grad + kl_grad
|
| 631 |
+
else:
|
| 632 |
+
dpolicy_vec = neg_surrogate_grad
|
| 633 |
+
dpolicy_vec = dpolicy_vec * mask_vec
|
| 634 |
+
dpolicy_vec = dpolicy_vec * inv_total
|
| 635 |
+
|
| 636 |
+
dp_dst = cute.local_tile(dp_slab, (VEC,), (sub_idx,))
|
| 637 |
+
dp_frag = cute.make_fragment_like(dp_dst)
|
| 638 |
+
dp_frag.store(dpolicy_vec.to(dpolicy.element_type))
|
| 639 |
+
cute.copy(r2g_atom, dp_frag, dp_dst, pred=pred)
|
| 640 |
+
|
| 641 |
+
# ---- Stage 1: Intra-warp reduction (butterfly XOR shuffles) ----
|
| 642 |
+
warp_policy = cute.arch.warp_reduction(local_policy_sum, operator.add)
|
| 643 |
+
if cutlass.const_expr(compute_kl):
|
| 644 |
+
warp_kl = cute.arch.warp_reduction(local_kl_sum, operator.add)
|
| 645 |
+
|
| 646 |
+
smem = cutlass.utils.SmemAllocator()
|
| 647 |
+
buf_policy = smem.allocate_tensor(cutlass.Float32, cute.make_layout(NUM_WARPS))
|
| 648 |
+
if cutlass.const_expr(compute_kl):
|
| 649 |
+
buf_kl = smem.allocate_tensor(cutlass.Float32, cute.make_layout(NUM_WARPS))
|
| 650 |
+
|
| 651 |
+
lane_idx = cute.arch.lane_idx()
|
| 652 |
+
warp_idx = cute.arch.warp_idx()
|
| 653 |
+
|
| 654 |
+
# When compute_backward is True the bundled mask-sum kernel populates
|
| 655 |
+
# inv_total_tensor before this kernel runs, so on-GPU mask-element
|
| 656 |
+
# accumulation is dead code.
|
| 657 |
+
if cutlass.const_expr(accumulate_valid):
|
| 658 |
+
warp_valid = cute.arch.warp_reduction(local_valid_sum, operator.add)
|
| 659 |
+
buf_valid = smem.allocate_tensor(cutlass.Float32, cute.make_layout(NUM_WARPS))
|
| 660 |
+
|
| 661 |
+
# ---- Stage 2: Cross-warp reduction via SMEM ----
|
| 662 |
+
if lane_idx == 0:
|
| 663 |
+
buf_policy[warp_idx] = warp_policy
|
| 664 |
+
if cutlass.const_expr(compute_kl):
|
| 665 |
+
buf_kl[warp_idx] = warp_kl
|
| 666 |
+
if cutlass.const_expr(accumulate_valid):
|
| 667 |
+
buf_valid[warp_idx] = warp_valid
|
| 668 |
+
cute.arch.barrier()
|
| 669 |
+
|
| 670 |
+
if warp_idx == 0:
|
| 671 |
+
val_p = cutlass.Float32(0.0)
|
| 672 |
+
if lane_idx < NUM_WARPS:
|
| 673 |
+
val_p = buf_policy[lane_idx]
|
| 674 |
+
|
| 675 |
+
block_policy = cute.arch.warp_reduction(val_p, operator.add, threads_in_group=NUM_WARPS)
|
| 676 |
+
|
| 677 |
+
if cutlass.const_expr(compute_kl):
|
| 678 |
+
val_k = cutlass.Float32(0.0)
|
| 679 |
+
if lane_idx < NUM_WARPS:
|
| 680 |
+
val_k = buf_kl[lane_idx]
|
| 681 |
+
block_kl = cute.arch.warp_reduction(val_k, operator.add, threads_in_group=NUM_WARPS)
|
| 682 |
+
|
| 683 |
+
if cutlass.const_expr(accumulate_valid):
|
| 684 |
+
val_v = cutlass.Float32(0.0)
|
| 685 |
+
if lane_idx < NUM_WARPS:
|
| 686 |
+
val_v = buf_valid[lane_idx]
|
| 687 |
+
block_valid = cute.arch.warp_reduction(
|
| 688 |
+
val_v, operator.add, threads_in_group=NUM_WARPS
|
| 689 |
+
)
|
| 690 |
+
|
| 691 |
+
# ---- Stage 3: Cross-CTA atomic accumulation ----
|
| 692 |
+
if lane_idx == 0:
|
| 693 |
+
policy_ptr = policy_acc.iterator.toint().ir_value() # ty: ignore[unresolved-attribute]
|
| 694 |
+
counter_ptr = counter.iterator.toint().ir_value() # ty: ignore[unresolved-attribute]
|
| 695 |
+
|
| 696 |
+
_atomic_add_f32_gmem(policy_ptr, block_policy)
|
| 697 |
+
|
| 698 |
+
if cutlass.const_expr(compute_kl):
|
| 699 |
+
kl_ptr = kl_acc.iterator.toint().ir_value() # ty: ignore[unresolved-attribute]
|
| 700 |
+
_atomic_add_f32_gmem(kl_ptr, block_kl)
|
| 701 |
+
|
| 702 |
+
if cutlass.const_expr(accumulate_valid):
|
| 703 |
+
valid_ptr = valid_acc.iterator.toint().ir_value() # ty: ignore[unresolved-attribute]
|
| 704 |
+
# valid_acc is int32. Per-block sums of int8 0/1 values
|
| 705 |
+
# fit exactly in fp32 (≤ tile_n ≤ 4096 ≪ 2²⁴) so the
|
| 706 |
+
# cast is bit-exact.
|
| 707 |
+
_atomic_add_s32_gmem(valid_ptr, cutlass.Int32(block_valid))
|
| 708 |
+
|
| 709 |
+
cute.arch.fence_acq_rel_gpu()
|
| 710 |
+
|
| 711 |
+
old = _atomic_inc_u32_gmem(counter_ptr, total_blocks - 1)
|
| 712 |
+
|
| 713 |
+
if old == total_blocks - 1:
|
| 714 |
+
pol_sum = policy_acc[0]
|
| 715 |
+
|
| 716 |
+
if cutlass.const_expr(accumulate_valid):
|
| 717 |
+
# Clamp to >=1.0 so a fully-masked batch (n_valid=0)
|
| 718 |
+
# produces ``loss=0`` instead of inf/NaN — matches
|
| 719 |
+
# TRL's ``mask.sum().clamp(min=1)`` semantics.
|
| 720 |
+
n_valid = cute.arch.fmax(
|
| 721 |
+
cutlass.Float32(valid_acc[0]), cutlass.Float32(1.0)
|
| 722 |
+
)
|
| 723 |
+
inv_total_computed = cute.arch.rcp_approx(n_valid)
|
| 724 |
+
else:
|
| 725 |
+
# compute_backward path: bundled mask-sum kernel
|
| 726 |
+
# already wrote the inverse so forward and backward
|
| 727 |
+
# share the same scalar.
|
| 728 |
+
inv_total_computed = inv_total
|
| 729 |
+
|
| 730 |
+
if cutlass.const_expr(compute_kl):
|
| 731 |
+
kl_sum = kl_acc[0]
|
| 732 |
+
loss = (pol_sum + beta * kl_sum) * inv_total_computed
|
| 733 |
+
else:
|
| 734 |
+
loss = pol_sum * inv_total_computed
|
| 735 |
+
output[0] = cast(loss, output.element_type) # ty: ignore[invalid-argument-type]
|
| 736 |
+
|
| 737 |
+
# Reset accumulators for the next invocation.
|
| 738 |
+
# Counter self-resets via atom.inc wrap-around.
|
| 739 |
+
policy_acc[0] = cutlass.Float32(0.0)
|
| 740 |
+
if cutlass.const_expr(compute_kl):
|
| 741 |
+
kl_acc[0] = cutlass.Float32(0.0)
|
| 742 |
+
if cutlass.const_expr(accumulate_valid):
|
| 743 |
+
valid_acc[0] = cutlass.Int32(0)
|
| 744 |
+
|
| 745 |
+
return _bnpo_loss_kernel
|
| 746 |
+
|
| 747 |
+
|
| 748 |
+
def create_compiled_bnpo_loss(
|
| 749 |
+
policy_dtype: torch.dtype,
|
| 750 |
+
epsilon: float,
|
| 751 |
+
epsilon_high: float,
|
| 752 |
+
beta: float,
|
| 753 |
+
compute_backward: bool = False,
|
| 754 |
+
) -> Callable[..., torch.Tensor | tuple[torch.Tensor, torch.Tensor]]:
|
| 755 |
+
"""Compile the bnpo loss kernel for a given dtype/KL/backward configuration.
|
| 756 |
+
|
| 757 |
+
The runner allocates per-call scratch (``output``, ``inv_total``, and on
|
| 758 |
+
the fwd+bwd path ``dpolicy``) inside ``_run`` itself; cross-CTA scratch
|
| 759 |
+
(atomic accumulators + counters) is allocated lazily on first call from
|
| 760 |
+
the input device and self-resets each launch via the kernel's last-block
|
| 761 |
+
epilogue + ``atom.inc.u32`` wrap-around.
|
| 762 |
+
"""
|
| 763 |
+
compute_kl = beta != 0.0
|
| 764 |
+
|
| 765 |
+
if policy_dtype not in _TORCH_TO_CUTLASS_DTYPE:
|
| 766 |
+
raise ValueError(f"Unsupported dtype for bnpo kernel: {policy_dtype}")
|
| 767 |
+
|
| 768 |
+
tile_n_small = TILE_N
|
| 769 |
+
tile_n_large = TILE_N_LARGE
|
| 770 |
+
seq_len_threshold = TILE_N_LARGE_THRESHOLD
|
| 771 |
+
block_size = NUM_WARPS * 32
|
| 772 |
+
if tile_n_small % (block_size * VEC) != 0:
|
| 773 |
+
raise ValueError(
|
| 774 |
+
f"TILE_N={tile_n_small} must be a multiple of BLOCK_SIZE*VEC={block_size * VEC}"
|
| 775 |
+
)
|
| 776 |
+
if tile_n_large % (block_size * VEC) != 0:
|
| 777 |
+
raise ValueError(
|
| 778 |
+
f"TILE_N_LARGE={tile_n_large} must be a multiple of BLOCK_SIZE*VEC={block_size * VEC}"
|
| 779 |
+
)
|
| 780 |
+
|
| 781 |
+
bs_sym = cute.sym_int()
|
| 782 |
+
seq_len_sym = cute.sym_int()
|
| 783 |
+
cute_dtype = _TORCH_TO_CUTLASS_DTYPE[policy_dtype]
|
| 784 |
+
|
| 785 |
+
def _fake2d(dt: Any, cols: Any) -> Any:
|
| 786 |
+
return cute.runtime.make_fake_compact_tensor(
|
| 787 |
+
dt,
|
| 788 |
+
(bs_sym, cols),
|
| 789 |
+
stride_order=(1, 0),
|
| 790 |
+
assumed_align=16,
|
| 791 |
+
)
|
| 792 |
+
|
| 793 |
+
fake_pol = _fake2d(cute_dtype, seq_len_sym)
|
| 794 |
+
fake_old = _fake2d(cute_dtype, seq_len_sym)
|
| 795 |
+
fake_ref = _fake2d(cute_dtype, seq_len_sym)
|
| 796 |
+
fake_adv = cute.runtime.make_fake_compact_tensor(
|
| 797 |
+
cute_dtype,
|
| 798 |
+
(bs_sym,),
|
| 799 |
+
assumed_align=16,
|
| 800 |
+
)
|
| 801 |
+
fake_mask = cute.runtime.make_fake_compact_tensor(
|
| 802 |
+
cutlass.Int8,
|
| 803 |
+
(bs_sym, seq_len_sym),
|
| 804 |
+
stride_order=(1, 0),
|
| 805 |
+
assumed_align=16,
|
| 806 |
+
)
|
| 807 |
+
dpolicy_cols = seq_len_sym if compute_backward else 1
|
| 808 |
+
fake_dpolicy = cute.runtime.make_fake_compact_tensor(
|
| 809 |
+
cute_dtype,
|
| 810 |
+
(bs_sym, dpolicy_cols),
|
| 811 |
+
stride_order=(1, 0),
|
| 812 |
+
assumed_align=16,
|
| 813 |
+
)
|
| 814 |
+
fake_scalar_f32 = cute.runtime.make_fake_compact_tensor(
|
| 815 |
+
cutlass.Float32,
|
| 816 |
+
(1,),
|
| 817 |
+
assumed_align=16,
|
| 818 |
+
)
|
| 819 |
+
fake_valid_acc = cute.runtime.make_fake_compact_tensor(
|
| 820 |
+
cutlass.Int32,
|
| 821 |
+
(1,),
|
| 822 |
+
assumed_align=16,
|
| 823 |
+
)
|
| 824 |
+
fake_counter = cute.runtime.make_fake_compact_tensor(
|
| 825 |
+
cutlass.Int32,
|
| 826 |
+
(1,),
|
| 827 |
+
assumed_align=16,
|
| 828 |
+
)
|
| 829 |
+
fake_mask_counter = cute.runtime.make_fake_compact_tensor(
|
| 830 |
+
cutlass.Int32,
|
| 831 |
+
(1,),
|
| 832 |
+
assumed_align=16,
|
| 833 |
+
)
|
| 834 |
+
fake_output = cute.runtime.make_fake_compact_tensor(
|
| 835 |
+
cute_dtype,
|
| 836 |
+
(1,),
|
| 837 |
+
assumed_align=16,
|
| 838 |
+
)
|
| 839 |
+
|
| 840 |
+
def _build_launch(tile_n_v: int) -> Callable[..., None]:
|
| 841 |
+
"""Build a ``@cute.jit`` ``_launch`` for a given ``tile_n``.
|
| 842 |
+
|
| 843 |
+
Captures *tile_n_v* via closure; both the main kernel and the
|
| 844 |
+
(optional) mask-sum kernel are specialised to this tile size.
|
| 845 |
+
One ``_launch`` per tier; the runner dispatches at call time.
|
| 846 |
+
"""
|
| 847 |
+
specialized_kernel = _make_bnpo_kernel(compute_kl, compute_backward, tile_n_v)
|
| 848 |
+
if compute_backward:
|
| 849 |
+
mask_sum_kernel = _make_mask_sum_kernel(tile_n_v)
|
| 850 |
+
|
| 851 |
+
@cute.jit
|
| 852 |
+
def _launch(
|
| 853 |
+
pol_ct: cute.Tensor,
|
| 854 |
+
old_ct: cute.Tensor,
|
| 855 |
+
ref_ct: cute.Tensor,
|
| 856 |
+
adv_ct: cute.Tensor,
|
| 857 |
+
mask_ct: cute.Tensor,
|
| 858 |
+
dpolicy_ct: cute.Tensor,
|
| 859 |
+
inv_total_ct: cute.Tensor,
|
| 860 |
+
policy_acc_ct: cute.Tensor,
|
| 861 |
+
kl_acc_ct: cute.Tensor,
|
| 862 |
+
valid_acc_ct: cute.Tensor,
|
| 863 |
+
counter_ct: cute.Tensor,
|
| 864 |
+
mask_counter_ct: cute.Tensor,
|
| 865 |
+
output_ct: cute.Tensor,
|
| 866 |
+
epsilon_v: cutlass.Float32,
|
| 867 |
+
epsilon_high_v: cutlass.Float32,
|
| 868 |
+
beta_v: cutlass.Float32,
|
| 869 |
+
total_blocks_v: cutlass.Int32,
|
| 870 |
+
num_full_tiles_v: cutlass.Int32,
|
| 871 |
+
tail_len_v: cutlass.Int32,
|
| 872 |
+
num_col_tiles_v: cutlass.Int32,
|
| 873 |
+
) -> None:
|
| 874 |
+
bs_v = pol_ct.shape[0] # ty: ignore[not-subscriptable]
|
| 875 |
+
# Bundled mask-sum (compute_backward only) — writes
|
| 876 |
+
# ``1 / completions_mask.sum()`` into ``inv_total_ct`` before the
|
| 877 |
+
# main kernel reads it. Both kernels in one tvm-ffi dispatch
|
| 878 |
+
# eliminates the per-call ``torch.sum`` + reciprocal round trip.
|
| 879 |
+
if cutlass.const_expr(compute_backward):
|
| 880 |
+
mask_sum_kernel( # ty: ignore[unresolved-attribute]
|
| 881 |
+
mask_ct,
|
| 882 |
+
inv_total_ct,
|
| 883 |
+
valid_acc_ct,
|
| 884 |
+
mask_counter_ct,
|
| 885 |
+
total_blocks_v,
|
| 886 |
+
num_full_tiles_v,
|
| 887 |
+
tail_len_v,
|
| 888 |
+
).launch(
|
| 889 |
+
grid=(bs_v, num_col_tiles_v, 1),
|
| 890 |
+
block=(NUM_WARPS * 32, 1, 1),
|
| 891 |
+
)
|
| 892 |
+
specialized_kernel( # ty: ignore[unresolved-attribute]
|
| 893 |
+
pol_ct,
|
| 894 |
+
old_ct,
|
| 895 |
+
ref_ct,
|
| 896 |
+
adv_ct,
|
| 897 |
+
mask_ct,
|
| 898 |
+
dpolicy_ct,
|
| 899 |
+
inv_total_ct,
|
| 900 |
+
policy_acc_ct,
|
| 901 |
+
kl_acc_ct,
|
| 902 |
+
valid_acc_ct,
|
| 903 |
+
counter_ct,
|
| 904 |
+
output_ct,
|
| 905 |
+
epsilon_v,
|
| 906 |
+
epsilon_high_v,
|
| 907 |
+
beta_v,
|
| 908 |
+
total_blocks_v,
|
| 909 |
+
num_full_tiles_v,
|
| 910 |
+
tail_len_v,
|
| 911 |
+
).launch(
|
| 912 |
+
grid=(bs_v, num_col_tiles_v, 1),
|
| 913 |
+
block=(NUM_WARPS * 32, 1, 1),
|
| 914 |
+
)
|
| 915 |
+
|
| 916 |
+
return _launch
|
| 917 |
+
|
| 918 |
+
def _compile_launch(launch_fn: Callable[..., None]) -> Callable[..., None]:
|
| 919 |
+
return cute.compile(
|
| 920 |
+
launch_fn,
|
| 921 |
+
fake_pol,
|
| 922 |
+
fake_old,
|
| 923 |
+
fake_ref,
|
| 924 |
+
fake_adv,
|
| 925 |
+
fake_mask,
|
| 926 |
+
fake_dpolicy,
|
| 927 |
+
fake_scalar_f32,
|
| 928 |
+
fake_scalar_f32,
|
| 929 |
+
fake_scalar_f32,
|
| 930 |
+
fake_valid_acc,
|
| 931 |
+
fake_counter,
|
| 932 |
+
fake_mask_counter,
|
| 933 |
+
fake_output,
|
| 934 |
+
cutlass.Float32(epsilon),
|
| 935 |
+
cutlass.Float32(epsilon_high),
|
| 936 |
+
cutlass.Float32(beta),
|
| 937 |
+
cutlass.Int32(1),
|
| 938 |
+
cutlass.Int32(1),
|
| 939 |
+
cutlass.Int32(0),
|
| 940 |
+
cutlass.Int32(1),
|
| 941 |
+
options="--enable-tvm-ffi",
|
| 942 |
+
)
|
| 943 |
+
|
| 944 |
+
compiled_small = _compile_launch(_build_launch(tile_n_small))
|
| 945 |
+
if tile_n_large == tile_n_small:
|
| 946 |
+
compiled_large = compiled_small
|
| 947 |
+
else:
|
| 948 |
+
compiled_large = _compile_launch(_build_launch(tile_n_large))
|
| 949 |
+
|
| 950 |
+
eps_const = cutlass.Float32(epsilon)
|
| 951 |
+
eps_high_const = cutlass.Float32(epsilon_high)
|
| 952 |
+
beta_const = cutlass.Float32(beta)
|
| 953 |
+
|
| 954 |
+
# Cross-CTA scratch slab — one int32 buffer with stride-4 (16-byte) slices
|
| 955 |
+
# so each slot is individually 16-byte aligned (``assumed_align=16`` at
|
| 956 |
+
# compile time). Bit-pattern of int32 0 equals fp32 0.0, so a single
|
| 957 |
+
# ``zeros`` factory legitimately initialises both the int32 counters and
|
| 958 |
+
# the fp32 accumulators. The kernel's last block self-resets accumulators
|
| 959 |
+
# in its epilogue and the counters self-reset via ``atom.inc.u32``
|
| 960 |
+
# wrap-around, so the up-front ``torch.zeros`` only matters for the very
|
| 961 |
+
# first call.
|
| 962 |
+
_scratch: list[torch.Tensor | None] = [None]
|
| 963 |
+
|
| 964 |
+
def _ensure_scratch(device: torch.device) -> tuple[torch.Tensor, ...]:
|
| 965 |
+
s = _scratch[0]
|
| 966 |
+
if s is None or s.device != device:
|
| 967 |
+
s = torch.zeros(20, dtype=torch.int32, device=device)
|
| 968 |
+
_scratch[0] = s
|
| 969 |
+
return (
|
| 970 |
+
s[0:1], # counter (int32)
|
| 971 |
+
s[4:5], # mask_counter (int32)
|
| 972 |
+
s[8:9], # valid_acc (int32)
|
| 973 |
+
s[12:13].view(torch.float32), # policy_acc (fp32)
|
| 974 |
+
s[16:17].view(torch.float32), # kl_acc (fp32)
|
| 975 |
+
)
|
| 976 |
+
|
| 977 |
+
def _run(
|
| 978 |
+
policy_logprobs_r: torch.Tensor,
|
| 979 |
+
old_policy_logprobs_r: torch.Tensor,
|
| 980 |
+
ref_logprobs_r: torch.Tensor,
|
| 981 |
+
advantages_r: torch.Tensor,
|
| 982 |
+
completions_mask_r: torch.Tensor,
|
| 983 |
+
) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
|
| 984 |
+
bs, seq_len = policy_logprobs_r.shape
|
| 985 |
+
device = policy_logprobs_r.device
|
| 986 |
+
dtype = policy_logprobs_r.dtype
|
| 987 |
+
|
| 988 |
+
# Tier dispatch: long sequences pay too much last-block-detection
|
| 989 |
+
# latency under the small-tile grid, so swap to the large-tile
|
| 990 |
+
# compiled variant.
|
| 991 |
+
if seq_len >= seq_len_threshold:
|
| 992 |
+
tile_n_active = tile_n_large
|
| 993 |
+
compiled_active = compiled_large
|
| 994 |
+
else:
|
| 995 |
+
tile_n_active = tile_n_small
|
| 996 |
+
compiled_active = compiled_small
|
| 997 |
+
num_full_tiles = seq_len // tile_n_active
|
| 998 |
+
tail_len = seq_len % tile_n_active
|
| 999 |
+
num_col_tiles = num_full_tiles + (1 if tail_len > 0 else 0)
|
| 1000 |
+
total_blocks = bs * num_col_tiles
|
| 1001 |
+
|
| 1002 |
+
# Per-call write-only buffers — ``empty`` is enough (Liger / TE
|
| 1003 |
+
# pattern). ``inv_total`` is populated by the bundled mask-sum
|
| 1004 |
+
# kernel (compute_backward path) or by the main kernel's last-block
|
| 1005 |
+
# trick (fwd-only path); the runner never reads it.
|
| 1006 |
+
output_r = torch.empty(1, dtype=dtype, device=device)
|
| 1007 |
+
inv_total_r = torch.empty(1, dtype=torch.float32, device=device)
|
| 1008 |
+
if compute_backward:
|
| 1009 |
+
dpolicy_r = torch.empty_like(policy_logprobs_r)
|
| 1010 |
+
else:
|
| 1011 |
+
dpolicy_r = torch.empty(bs, 1, dtype=dtype, device=device)
|
| 1012 |
+
|
| 1013 |
+
counter_r, mask_counter_r, valid_acc_r, policy_acc_r, kl_acc_r = _ensure_scratch(device)
|
| 1014 |
+
|
| 1015 |
+
compiled_active(
|
| 1016 |
+
policy_logprobs_r,
|
| 1017 |
+
old_policy_logprobs_r,
|
| 1018 |
+
ref_logprobs_r,
|
| 1019 |
+
advantages_r,
|
| 1020 |
+
completions_mask_r,
|
| 1021 |
+
dpolicy_r,
|
| 1022 |
+
inv_total_r,
|
| 1023 |
+
policy_acc_r,
|
| 1024 |
+
kl_acc_r,
|
| 1025 |
+
valid_acc_r,
|
| 1026 |
+
counter_r,
|
| 1027 |
+
mask_counter_r,
|
| 1028 |
+
output_r,
|
| 1029 |
+
eps_const,
|
| 1030 |
+
eps_high_const,
|
| 1031 |
+
beta_const,
|
| 1032 |
+
total_blocks,
|
| 1033 |
+
num_full_tiles,
|
| 1034 |
+
tail_len,
|
| 1035 |
+
num_col_tiles,
|
| 1036 |
+
)
|
| 1037 |
+
out_view = output_r.view(())
|
| 1038 |
+
if compute_backward:
|
| 1039 |
+
return out_view, dpolicy_r
|
| 1040 |
+
return out_view
|
| 1041 |
+
|
| 1042 |
+
return _run
|
| 1043 |
+
|
| 1044 |
+
|
| 1045 |
+
# ---------------------------------------------------------------------------
|
| 1046 |
+
# Fused forward + backward — direct (loss, grad) runner, no autograd
|
| 1047 |
+
# ---------------------------------------------------------------------------
|
| 1048 |
+
|
| 1049 |
+
|
| 1050 |
+
def create_compiled_bnpo_loss_with_backward(
|
| 1051 |
+
policy_dtype: torch.dtype,
|
| 1052 |
+
epsilon: float,
|
| 1053 |
+
epsilon_high: float,
|
| 1054 |
+
beta: float,
|
| 1055 |
+
) -> Callable[..., tuple[torch.Tensor, torch.Tensor]]:
|
| 1056 |
+
"""Compile the fused fwd+bwd bnpo kernel and return a tuple-returning runner.
|
| 1057 |
+
|
| 1058 |
+
The returned callable runs one training-step worth of work: a single
|
| 1059 |
+
``@cute.jit`` dispatch produces both the scalar loss and the scaled
|
| 1060 |
+
``dL/d(policy_logprobs)`` tensor. It returns ``(loss, dpolicy)`` directly
|
| 1061 |
+
— no ``torch.autograd.Function`` wrapper, no extra ``grad_output * dpolicy``
|
| 1062 |
+
backward kernel. Callers that need autograd integration (so
|
| 1063 |
+
``loss.backward()`` works) wrap this themselves at the public-API layer;
|
| 1064 |
+
callers that control gradient flow manually (benchmarks, custom training
|
| 1065 |
+
loops) can use it as-is for zero overhead.
|
| 1066 |
+
|
| 1067 |
+
``inv_total`` is computed entirely on-GPU by a bundled mask-sum kernel
|
| 1068 |
+
that runs in series with the main kernel inside the same ``@cute.jit``
|
| 1069 |
+
launch — no host sync, no extra ``torch.sum`` dispatch, CUDA-graph
|
| 1070 |
+
compatible.
|
| 1071 |
+
"""
|
| 1072 |
+
return _typing_cast(
|
| 1073 |
+
"Callable[..., tuple[torch.Tensor, torch.Tensor]]",
|
| 1074 |
+
create_compiled_bnpo_loss(
|
| 1075 |
+
policy_dtype=policy_dtype,
|
| 1076 |
+
epsilon=epsilon,
|
| 1077 |
+
epsilon_high=epsilon_high,
|
| 1078 |
+
beta=beta,
|
| 1079 |
+
compute_backward=True,
|
| 1080 |
+
),
|
| 1081 |
+
)
|
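A minimal usage sketch of the direct (loss, grad) runner above. This is not part of the build output; the import path, shapes, and values are assumptions for illustration (the factory's fake tensors declare an int8 mask layout):

# Hypothetical smoke test for the fused fwd+bwd runner (illustrative only).
import torch
from bnpo_loss.cute_bnpo_loss import create_compiled_bnpo_loss_with_backward  # assumed import path

run = create_compiled_bnpo_loss_with_backward(
    policy_dtype=torch.bfloat16, epsilon=0.2, epsilon_high=0.2, beta=0.1
)
bs, seq_len = 4, 1024
policy = torch.randn(bs, seq_len, dtype=torch.bfloat16, device="cuda")
old = torch.randn_like(policy)
ref = torch.randn_like(policy)
adv = torch.randn(bs, dtype=torch.bfloat16, device="cuda")
mask = torch.ones(bs, seq_len, dtype=torch.int8, device="cuda")

loss, dpolicy = run(policy, old, ref, adv, mask)  # single @cute.jit dispatch
# dpolicy already carries the 1/n_valid scaling; a training loop would chain
# it manually via policy_logprobs.backward(dpolicy) on a graph-bearing tensor.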
build/torch-cuda/geometric_ai_kernels/__init__.py
ADDED
@@ -0,0 +1,26 @@
import ctypes
import importlib.util
import sys
from pathlib import Path
from types import ModuleType


def _import_from_path(file_path: Path) -> ModuleType:
    # We cannot use the module name as-is: once added to `sys.modules`, it
    # would also be picked up by other imports. So we derive a module name
    # that depends on the path, using the hex-encoded hash of the absolute
    # path, to keep it unique.
    path_hash = "{:x}".format(ctypes.c_size_t(hash(file_path.absolute())).value)
    module_name = path_hash
    spec = importlib.util.spec_from_file_location(module_name, file_path)
    if spec is None:
        raise ImportError(f"Cannot load spec for {module_name} from {file_path}")
    module = importlib.util.module_from_spec(spec)
    if module is None:
        raise ImportError(f"Cannot load module {module_name} from spec")
    sys.modules[module_name] = module
    spec.loader.exec_module(module)  # type: ignore
    return module


globals().update(vars(_import_from_path(Path(__file__).parent.parent / "__init__.py")))
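The uniqueness trick the comment above relies on can be sanity-checked with a tiny sketch; the paths below are hypothetical:

# Two distinct files get two distinct sys.modules keys, so the shim can be
# applied to several build variants side by side (illustrative only).
from pathlib import Path

mod_a = _import_from_path(Path("/tmp/variant_a/__init__.py"))
mod_b = _import_from_path(Path("/tmp/variant_b/__init__.py"))
assert mod_a.__name__ != mod_b.__name__  # hex-encoded hash differs per path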
build/torch-cuda/grpo_loss/__init__.py
ADDED
@@ -0,0 +1,169 @@
"""GRPO loss with CuteDSL fused fwd+bwd.

Three public APIs:

* :func:`grpo_loss` — fused fwd+bwd. Returns
  ``(loss, grad_policy_logprobs)`` from a single ``@cute.jit`` dispatch.
  Caller chains via ``policy_logprobs.backward(grad)``.
* :func:`grpo_loss_fwd` — forward-only (inference / validation).
  Returns scalar ``loss`` and skips the dpolicy buffer entirely.
* :func:`grpo_loss_autograd` — autograd-aware via
  ``torch.library.custom_op``. ``loss.backward()`` works and composes
  with ``torch.compile``. ~12µs of dispatcher overhead vs.
  :func:`grpo_loss`.

GRPO requires ``completions_mask``; the per-response normalization
formula is mask-derived. The cute kernel uses one CTA per row so the
per-row mask sum is reduced inside the block — no cross-CTA atomics
or last-block detection on the per-row scaling pass.

Per-call output and gradient buffers are allocated inside the runner;
cross-CTA scratch (the ``policy_acc`` accumulator + last-block counter)
is owned by the compiled-kernel closure and self-resets each launch.
"""

from __future__ import annotations

from functools import lru_cache
from typing import TYPE_CHECKING, cast

import torch

from .cute_grpo_loss import create_compiled_grpo_loss

if TYPE_CHECKING:
    from collections.abc import Callable


__all__ = ["grpo_loss", "grpo_loss_autograd", "grpo_loss_fwd"]


@lru_cache(maxsize=32)
def _get_compiled_fwd(
    dtype: torch.dtype,
    epsilon: float,
    epsilon_high: float,
    beta: float,
) -> Callable[..., torch.Tensor]:
    return cast(
        "Callable[..., torch.Tensor]",
        create_compiled_grpo_loss(
            policy_dtype=dtype,
            epsilon=epsilon,
            epsilon_high=epsilon_high,
            beta=beta,
            compute_backward=False,
        ),
    )


@lru_cache(maxsize=32)
def _get_compiled_fwd_bwd(
    dtype: torch.dtype,
    epsilon: float,
    epsilon_high: float,
    beta: float,
) -> Callable[..., tuple[torch.Tensor, torch.Tensor]]:
    return cast(
        "Callable[..., tuple[torch.Tensor, torch.Tensor]]",
        create_compiled_grpo_loss(
            policy_dtype=dtype,
            epsilon=epsilon,
            epsilon_high=epsilon_high,
            beta=beta,
            compute_backward=True,
        ),
    )


def _mask_to_int8(completions_mask: torch.Tensor) -> torch.Tensor:
    return (
        completions_mask
        if completions_mask.dtype == torch.int8
        else completions_mask.to(torch.int8)
    )


def grpo_loss_fwd(
    policy_logprobs: torch.Tensor,
    old_policy_logprobs: torch.Tensor,
    ref_logprobs: torch.Tensor,
    advantages: torch.Tensor,
    completions_mask: torch.Tensor,
    epsilon: float = 0.2,
    epsilon_high: float = 0.2,
    beta: float = 0.1,
) -> torch.Tensor:
    """Forward-only GRPO loss. Returns the scalar ``loss``.

    Args:
        policy_logprobs, old_policy_logprobs, ref_logprobs: ``(bs, seq_len)``.
        advantages: ``(bs,)``.
        completions_mask: Bool / int8 mask ``(bs, seq_len)``. Required.
        epsilon, epsilon_high: PPO-style clipping bounds.
        beta: KL-penalty coefficient. ``0.0`` compiles away the KL branch.

    Returns:
        Scalar tensor (0-dim) with the same dtype as ``policy_logprobs``.
    """
    run = _get_compiled_fwd(
        policy_logprobs.dtype,
        float(epsilon),
        float(epsilon_high),
        float(beta),
    )
    return run(
        policy_logprobs,
        old_policy_logprobs,
        ref_logprobs,
        advantages,
        _mask_to_int8(completions_mask),
    )


def grpo_loss(
    policy_logprobs: torch.Tensor,
    old_policy_logprobs: torch.Tensor,
    ref_logprobs: torch.Tensor,
    advantages: torch.Tensor,
    completions_mask: torch.Tensor,
    epsilon: float = 0.2,
    epsilon_high: float = 0.2,
    beta: float = 0.1,
) -> tuple[torch.Tensor, torch.Tensor]:
    """Fused fwd+bwd GRPO loss. Returns ``(loss, grad_policy_logprobs)``.

    Inputs do **not** need ``requires_grad=True``. Chain ``grad`` into
    the upstream model via ``policy_logprobs.backward(grad)``.

    Args:
        policy_logprobs, old_policy_logprobs, ref_logprobs: ``(bs, seq_len)``.
        advantages: ``(bs,)``.
        completions_mask: Bool / int8 mask ``(bs, seq_len)``. Required.
        epsilon, epsilon_high: PPO-style clipping bounds.
        beta: KL-penalty coefficient. ``0.0`` compiles away the KL branch.

    Returns:
        ``(loss, grad_policy_logprobs)``. The grad already has the
        per-row ``1 / mask.sum(-1).clamp(min=1)`` and across-row
        ``1/n_rows`` scalings folded in. The grad tensor is freshly
        allocated per call (no shared cache).
    """
    run = _get_compiled_fwd_bwd(
        policy_logprobs.dtype,
        float(epsilon),
        float(epsilon_high),
        float(beta),
    )
    return run(
        policy_logprobs,
        old_policy_logprobs,
        ref_logprobs,
        advantages,
        _mask_to_int8(completions_mask),
    )


# ``autograd.py`` imports ``grpo_loss`` from this module, so the function
# must be fully defined before its import runs.
from .autograd import grpo_loss_autograd  # noqa: E402
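A minimal sketch of the manual chaining pattern the ``grpo_loss`` docstring above describes. ``model_logprobs`` stands in for a graph-bearing logprob tensor produced by the model; the import path and hyperparameters are illustrative:

import torch
from grpo_loss import grpo_loss  # assumed import path for the built package

loss, grad = grpo_loss(
    model_logprobs.detach(),   # kernel inputs need no requires_grad
    old_policy_logprobs,
    ref_logprobs,
    advantages,                # (bs,)
    completions_mask,          # bool / int8, (bs, seq_len)
    epsilon=0.2,
    epsilon_high=0.2,
    beta=0.1,
)
model_logprobs.backward(grad)  # chain the fused grad into the model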
build/torch-cuda/grpo_loss/_torch_ref.py
ADDED
@@ -0,0 +1,87 @@
"""Plain-PyTorch GRPO reference shared between bench and tests.

Mirrors TRL's default per-response normalization variant: the per-row mask
sum acts as the divisor for that row's loss before averaging across
rows. Every op is a vanilla torch op so AOTAutograd can derive the
joint fwd+bwd graph and Inductor can fuse both passes.
"""

from __future__ import annotations

import torch


def torch_grpo_loss(
    policy_logprobs: torch.Tensor,
    old_policy_logprobs: torch.Tensor,
    ref_logprobs: torch.Tensor,
    advantages: torch.Tensor,
    completions_mask: torch.Tensor,
    epsilon: float = 0.2,
    epsilon_high: float = 0.2,
    beta: float = 0.1,
) -> torch.Tensor:
    """Compute the GRPO (Group Relative Policy Optimization) loss.

    Implements TRL's default per-response normalization variant:

        L_GRPO = mean_r( ((policy_loss + beta*kl) * mask).sum(-1)
                         / mask.sum(-1).clamp(min=1) )

    Each row (response) is independently normalized by its own valid-token
    count, then the per-row losses are averaged across rows. This differs
    from the BNPO variant in :func:`torch_bnpo_loss`, which sums numerators
    *and* denominators globally before dividing — under variable response
    lengths BNPO weights longer responses more heavily, while GRPO weights
    every response equally.

    **Probability ratio:**

        r_t(theta) = exp(log pi_theta - log pi_{theta_old})

    **Clipped surrogate (per token):**

        L_CLIP_t = -min( r_t * A, clip(r_t, 1 - eps, 1 + eps_high) * A )

    **KL divergence (Schulman approximation, per token):**

        kl_t ~= exp(log pi_ref - log pi_theta) - (log pi_ref - log pi_theta) - 1

    Args:
        policy_logprobs: Log-probabilities of the current policy, shape (N, C).
        old_policy_logprobs: Log-probabilities of the behaviour policy used to
            collect the rollout, shape (N, C).
        ref_logprobs: Log-probabilities of the frozen reference policy, shape
            (N, C).
        advantages: Per-sequence advantage estimates, shape (N,).
        completions_mask: Boolean mask of shape (N, C) where True marks valid
            tokens. Required — GRPO's per-response normalization is mask-derived.
        epsilon: Lower asymmetric clipping bound (1 - epsilon). Default: 0.2.
        epsilon_high: Upper asymmetric clipping bound (1 + epsilon_high).
            Default: 0.2 (symmetric with ``epsilon``, matching TRL's GRPOConfig).
        beta: Coefficient for the KL-divergence penalty. Default: 0.1.

    Returns:
        Scalar tensor representing the GRPO loss.
    """
    ratio = torch.exp(policy_logprobs - old_policy_logprobs)
    adv = advantages.unsqueeze(1)  # (N, 1) for broadcasting over tokens

    surrogate = ratio * adv
    surrogate_clipped = torch.clamp(ratio, 1.0 - epsilon, 1.0 + epsilon_high) * adv
    policy_loss_per_tok = -torch.min(surrogate, surrogate_clipped)

    # Per-row valid-token count, fp32 + clamp to avoid div-by-zero on
    # fully-masked rows (matches TRL's ``mask.sum(-1).clamp(min=1)``).
    mask_sum = completions_mask.sum(-1).to(torch.float32).clamp_min(1.0)  # (N,)
    policy_per_row = (policy_loss_per_tok * completions_mask).sum(-1) / mask_sum

    if beta != 0.0:
        log_ratio_ref = ref_logprobs - policy_logprobs
        kl_per_tok = torch.exp(log_ratio_ref) - log_ratio_ref - 1.0
        kl_per_row = (kl_per_tok * completions_mask).sum(-1) / mask_sum
        loss = (policy_per_row + beta * kl_per_row).mean()
    else:
        loss = policy_per_row.mean()

    return loss.to(policy_logprobs.dtype)
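A sketch of the parity check this reference enables against the forward-only kernel entry point defined earlier; the tolerances are illustrative, not the project's test thresholds:

import torch
from grpo_loss import grpo_loss_fwd
from grpo_loss._torch_ref import torch_grpo_loss

torch.manual_seed(0)
bs, seq_len = 4, 1024
pol = torch.randn(bs, seq_len, device="cuda")
old = torch.randn_like(pol)
ref = torch.randn_like(pol)
adv = torch.randn(bs, device="cuda")
mask = torch.rand(bs, seq_len, device="cuda") > 0.1  # ~10% of tokens masked out

expected = torch_grpo_loss(pol, old, ref, adv, mask)
actual = grpo_loss_fwd(pol, old, ref, adv, mask)
torch.testing.assert_close(actual, expected, rtol=1e-3, atol=1e-3)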
build/torch-cuda/grpo_loss/autograd.py
ADDED
@@ -0,0 +1,120 @@
"""Autograd-aware wrapper for GRPO loss via ``torch.library.custom_op``.

The fused cute kernel writes both the scalar loss and the closed-form
``dL/d(policy_logprobs)`` in one launch. This module wraps that into an
autograd-compatible op so callers can write::

    loss = grpo_loss_autograd(policy, old, ref, adv, completions_mask)
    loss.backward()

Implementation mirrors the BNPO autograd binding: ``custom_op`` with a
registered ``setup_context`` / backward, plus a ``register_fake`` for
shape propagation under ``torch.compile``. The runner allocates
``dpolicy`` fresh on every call (no shared cache), so
``ctx.save_for_backward(dpolicy)`` keeps a stable reference for free.
"""

from __future__ import annotations

import torch

from . import grpo_loss as _grpo_loss_fwd_bwd

__all__ = ["grpo_loss_autograd"]


@torch.library.custom_op(
    "geometric_ai_kernels::_grpo_loss_with_grad",
    mutates_args=(),
)
def _grpo_loss_with_grad(
    policy_logprobs: torch.Tensor,
    old_policy_logprobs: torch.Tensor,
    ref_logprobs: torch.Tensor,
    advantages: torch.Tensor,
    completions_mask: torch.Tensor,
    epsilon: float,
    epsilon_high: float,
    beta: float,
) -> tuple[torch.Tensor, torch.Tensor]:
    loss, dpolicy = _grpo_loss_fwd_bwd(
        policy_logprobs,
        old_policy_logprobs,
        ref_logprobs,
        advantages,
        completions_mask,
        epsilon=epsilon,
        epsilon_high=epsilon_high,
        beta=beta,
    )
    return loss, dpolicy


@_grpo_loss_with_grad.register_fake
def _(
    policy_logprobs: torch.Tensor,
    old_policy_logprobs: torch.Tensor,
    ref_logprobs: torch.Tensor,
    advantages: torch.Tensor,
    completions_mask: torch.Tensor,
    epsilon: float,
    epsilon_high: float,
    beta: float,
) -> tuple[torch.Tensor, torch.Tensor]:
    del old_policy_logprobs, ref_logprobs, advantages, completions_mask
    del epsilon, epsilon_high, beta
    loss = policy_logprobs.new_empty(())
    dpolicy = torch.empty_like(policy_logprobs)
    return loss, dpolicy


def _setup_context(ctx, inputs, output) -> None:  # type: ignore[no-untyped-def]
    del inputs
    _, dpolicy = output
    ctx.save_for_backward(dpolicy)


def _backward(ctx, grad_loss, grad_dpolicy):  # type: ignore[no-untyped-def]
    del grad_dpolicy
    (dpolicy,) = ctx.saved_tensors
    grad_policy = grad_loss * dpolicy
    # One return per input (8): policy_logprobs gets the grad, the rest get None.
    return grad_policy, None, None, None, None, None, None, None


torch.library.register_autograd(
    "geometric_ai_kernels::_grpo_loss_with_grad",
    _backward,
    setup_context=_setup_context,
)


def grpo_loss_autograd(
    policy_logprobs: torch.Tensor,
    old_policy_logprobs: torch.Tensor,
    ref_logprobs: torch.Tensor,
    advantages: torch.Tensor,
    completions_mask: torch.Tensor,
    epsilon: float = 0.2,
    epsilon_high: float = 0.2,
    beta: float = 0.1,
) -> torch.Tensor:
    """Autograd-aware GRPO loss. Returns scalar ``loss``.

    Same numerics as :func:`grpo_loss` but registered as a
    ``torch.library`` custom op with autograd, so ``loss.backward()``
    Just Works. For direct ``(loss, grad)`` access without the
    autograd dispatcher overhead, use :func:`grpo_loss` and chain via
    ``policy_logprobs.backward(grad)``.
    """
    loss, _ = _grpo_loss_with_grad(
        policy_logprobs,
        old_policy_logprobs,
        ref_logprobs,
        advantages,
        completions_mask,
        float(epsilon),
        float(epsilon_high),
        float(beta),
    )
    return loss
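A short sketch of the autograd entry point in use; ``loss.backward()`` populates ``policy.grad`` via the saved ``dpolicy`` (import path and values illustrative):

import torch
from grpo_loss.autograd import grpo_loss_autograd  # assumed import path

policy = torch.randn(4, 512, device="cuda", requires_grad=True)
old = torch.randn(4, 512, device="cuda")
ref = torch.randn(4, 512, device="cuda")
adv = torch.randn(4, device="cuda")
mask = torch.ones(4, 512, device="cuda", dtype=torch.int8)

loss = grpo_loss_autograd(policy, old, ref, adv, mask, beta=0.1)
loss.backward()                       # grad_policy = grad_loss * saved dpolicy
assert policy.grad.shape == (4, 512)  # torch.Size compares equal to the tuple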
build/torch-cuda/grpo_loss/cute_grpo_loss.py
ADDED
@@ -0,0 +1,805 @@
"""CuteDSL kernel for GRPO (Group Relative Policy Optimization) loss.

Implements TRL's default per-response normalization variant:

    loss = mean_r( ((per_token_loss + beta * kl) * mask).sum(-1)
                   / mask.sum(-1).clamp(min=1) )

Element-wise over ``(N, C)`` logprob tensors:

    ratio = exp(policy - old_policy)
    surrogate = ratio * adv
    clipped = clip(ratio, 1 - eps, 1 + eps_high) * adv
    policy_loss = -min(surrogate, clipped)
    log_ratio_ref = ref - policy
    kl = exp(log_ratio_ref) - log_ratio_ref - 1

``completions_mask`` is **required** for GRPO -- the per-response
normalization formula is mask-derived.

**One CTA per row.** Each row is owned by exactly one block, so the
per-row mask sum is computed locally (warp + cross-warp reduction) with
no cross-CTA atomics, fences, or last-block detection on the per-row
scaling pass.

**Mask prescan on the fwd+bwd path.** Before the main compute pass each
CTA runs a cheap mask-only sweep over its row to derive ``row_scale``
(``1 / max(mask.sum(), 1) * inv_n_rows``). With ``row_scale`` known up
front the main pass reads ``policy / old / ref`` exactly **once**,
computes loss and gradient together, and writes the **scaled**
``dpolicy`` directly — no second logprob read, no unscaled GMEM round
trip. The prescan touches only the int8 mask (1 byte / element) and is
much cheaper than the logprob traffic it eliminates (2 B / element in
bf16/fp16, 4 B in fp32, ×2 or ×3 depending on the KL term).

The fwd-only path skips the prescan and accumulates ``valid`` directly
in a single sweep over the row.

When ``beta=0`` the KL term is skipped at compile time (no ``ref``
tensor access, no ``kl`` accumulator).

Sequence lengths that are **not** a multiple of ``TILE_N`` are handled
natively: the in-block tile loop runs ``ceil(C / TILE_N)`` iterations;
full tiles use the vectorized ``LDG.128`` path and the tail tile uses
predicated vector loads with neutral prefill.

Each CTA reduces its row's policy / kl sums in registers, finishes the
cross-warp reduction in SMEM, computes its scaled per-row contribution
``(block_policy + beta * block_kl) * row_scale`` locally, and
``atomicAdd``s the scalar result into ``policy_acc[0]``. The last CTA
— detected via a single grid-scope ``atomic_inc`` on ``counter`` —
reads ``policy_acc[0]``, casts to the output dtype, writes
``output[0]``, and resets the accumulator to ``0``.
"""

from __future__ import annotations

import math
import operator
from typing import TYPE_CHECKING, Any

import cutlass
import cutlass.utils
import torch
from cutlass import cute
from cutlass.base_dsl.typing import cast

from ..bnpo_loss.cute_bnpo_loss import (
    _atomic_add_f32_gmem,
    _atomic_inc_u32_gmem,
)

if TYPE_CHECKING:
    from collections.abc import Callable


TILE_N: int = 2048
NUM_WARPS: int = 16
VEC: int = 4
# Long-context variant: a wider tile keeps the per-row block reduction
# cheap when ``seq_len`` blows past the small tier (where the inner
# col-tile loop iterates many times). Two compiled variants — small +
# large — are dispatched at runtime by sequence length.
TILE_N_LARGE: int = 8192
TILE_N_LARGE_THRESHOLD: int = 8192

_LOG2_E: float = math.log2(math.e)

_TORCH_TO_CUTLASS_DTYPE: dict[torch.dtype, Any] = {
    torch.float32: cutlass.Float32,
    torch.float16: cutlass.Float16,
    torch.bfloat16: cutlass.BFloat16,
}


# ---------------------------------------------------------------------------
# Main GRPO kernel — single CTA per row.
# ---------------------------------------------------------------------------


def _make_grpo_kernel(
    compute_kl: bool,
    compute_backward: bool,
    tile_n: int,
) -> Callable[..., None]:
    """Build the GRPO kernel specialized on compile-time flags.

    On the fwd+bwd path each CTA does a mask-only prescan to derive
    ``row_scale`` before the main pass; this lets the main pass fold
    loss accumulation and gradient compute into a single read of the
    logprob tensors. On the fwd-only path the prescan is skipped — the
    main pass accumulates ``valid`` directly.
    """

    @cute.kernel
    def _grpo_loss_kernel(
        policy: cute.Tensor,
        old_policy: cute.Tensor,
        ref: cute.Tensor,
        advantages: cute.Tensor,
        completions_mask: cute.Tensor,
        dpolicy: cute.Tensor,  # (N, C) when compute_backward; (N, 1) dummy otherwise
        policy_acc: cute.Tensor,  # (1,) fp32 — global scalar loss accumulator
        counter: cute.Tensor,  # (1,) int32 — global last-block detection
        output: cute.Tensor,
        epsilon: cutlass.Float32,
        epsilon_high: cutlass.Float32,
        beta: cutlass.Float32,
        inv_n_rows: cutlass.Float32,
        n_rows: cutlass.Int32,
        num_full_tiles: cutlass.Int32,
        tail_len: cutlass.Int32,
        num_col_tiles: cutlass.Int32,
    ) -> None:
        block_size = NUM_WARPS * 32
        iters = tile_n // (block_size * VEC)

        _no_alloc = cute.nvgpu.CacheEvictionPriority.NO_ALLOCATE
        g2r_op = cute.nvgpu.CopyUniversalOp()
        # Logprob loads stream once -- ``NO_ALLOCATE`` keeps them out of L1
        # so they don't evict the mask data we *do* re-read. The current
        # single-pass main loop never re-reads pol/old/ref; only the mask
        # is touched twice (fwd+bwd prescan + main pass).
        g2r_atom = cute.make_copy_atom(
            g2r_op,
            policy.element_type,
            num_bits_per_copy=0,
            l1c_evict_priority=_no_alloc,
        )
        # Mask is read twice on the fwd+bwd path (prescan + main), so bias
        # L1 to keep it. On fwd-only the prescan does not run, so the hint
        # is unused and falls through to the streaming default.
        mask_evict = cute.nvgpu.CacheEvictionPriority.EVICT_LAST if compute_backward else _no_alloc
        g2r_mask_atom = cute.make_copy_atom(
            g2r_op,
            completions_mask.element_type,
            num_bits_per_copy=0,
            l1c_evict_priority=mask_evict,
        )
        if cutlass.const_expr(compute_backward):
            r2g_atom = cute.make_copy_atom(
                g2r_op,
                dpolicy.element_type,
                num_bits_per_copy=0,
            )

        row = cute.arch.block_idx()[0]
        tid = cute.arch.thread_idx()[0]
        lane_idx = cute.arch.lane_idx()
        warp_idx = cute.arch.warp_idx()

        adv = cast(advantages[row], cutlass.Float32)
        lo = cutlass.Float32(1.0) - epsilon
        hi = cutlass.Float32(1.0) + epsilon_high

        pol_row = cute.slice_(policy, (row, None))
        old_row = cute.slice_(old_policy, (row, None))
        mask_row = cute.slice_(completions_mask, (row, None))
        if cutlass.const_expr(compute_kl):
            ref_row = cute.slice_(ref, (row, None))
        if cutlass.const_expr(compute_backward):
            dp_row = cute.slice_(dpolicy, (row, None))

        smem = cutlass.utils.SmemAllocator()
        buf_policy = smem.allocate_tensor(cutlass.Float32, cute.make_layout(NUM_WARPS))
        if cutlass.const_expr(compute_kl):
            buf_kl = smem.allocate_tensor(cutlass.Float32, cute.make_layout(NUM_WARPS))
        # The fwd-only path always needs a valid buffer; on the fwd+bwd path
        # the prescan reuses it for its cross-warp SMEM reduction.
        buf_valid = smem.allocate_tensor(cutlass.Float32, cute.make_layout(NUM_WARPS))
        if cutlass.const_expr(compute_backward):
            row_scale_smem = smem.allocate_tensor(cutlass.Float32, cute.make_layout(1))

        # ---- Stage A (fwd+bwd only): mask-only prescan → row_scale ----
        if cutlass.const_expr(compute_backward):
            local_valid_pre = cutlass.Float32(0.0)
            for col_block in cutlass.range(num_col_tiles, unroll=1):
                mask_slab = cute.local_tile(mask_row, (tile_n,), (col_block,))
                if col_block < num_full_tiles:
                    for k in cutlass.range(iters, unroll_full=True):
                        sub_idx = tid + k * block_size
                        mask_src = cute.local_tile(mask_slab, (VEC,), (sub_idx,))
                        mask_frag = cute.make_fragment_like(mask_src)
                        cute.copy(g2r_mask_atom, mask_src, mask_frag)
                        local_valid_pre += (
                            mask_frag.load()
                            .to(cutlass.Float32)
                            .reduce(
                                cute.ReductionOp.ADD,
                                cutlass.Float32(0.0),
                                reduction_profile=0,
                            )
                        )
                else:
                    for k in cutlass.range(iters, unroll_full=True):
                        sub_idx = tid + k * block_size
                        chunk_base = sub_idx * VEC
                        if chunk_base < tail_len:
                            mask_src = cute.local_tile(mask_slab, (VEC,), (sub_idx,))
                            pred = cute.make_rmem_tensor(mask_src.shape, cutlass.Boolean)
                            for v in cutlass.range(VEC, unroll_full=True):
                                pred[v] = cute.elem_less(chunk_base + v, tail_len)
                            mask_frag = cute.make_fragment_like(mask_src)
                            mask_frag.fill(0)
                            cute.copy(g2r_mask_atom, mask_src, mask_frag, pred=pred)
                            local_valid_pre += (
                                mask_frag.load()
                                .to(cutlass.Float32)
                                .reduce(
                                    cute.ReductionOp.ADD,
                                    cutlass.Float32(0.0),
                                    reduction_profile=0,
                                )
                            )

            warp_valid_pre = cute.arch.warp_reduction(local_valid_pre, operator.add)
            if lane_idx == 0:
                buf_valid[warp_idx] = warp_valid_pre
            cute.arch.barrier()

            if warp_idx == 0:
                lane_in_warp_range = lane_idx < NUM_WARPS
                val_v = cutlass.Float32(0.0)
                if lane_in_warp_range:
                    val_v = buf_valid[lane_idx]
                block_valid_pre = cute.arch.warp_reduction(
                    val_v, operator.add, threads_in_group=NUM_WARPS
                )
                if lane_idx == 0:
                    n_v_row = cute.arch.fmax(block_valid_pre, cutlass.Float32(1.0))
                    row_scale_smem[0] = cute.arch.rcp_approx(n_v_row) * inv_n_rows
            cute.arch.barrier()
            row_scale_v = row_scale_smem[0]

        # ---- Stage B: main pass — loss accumulation + (fused) gradient ----
        local_policy_sum = cutlass.Float32(0.0)
        local_kl_sum = cutlass.Float32(0.0)
        # Only the fwd-only path needs to accumulate valid here; the
        # fwd+bwd path already produced ``row_scale_v`` from the prescan.
        local_valid_sum = cutlass.Float32(0.0)

        for col_block in cutlass.range(num_col_tiles, unroll=1):
            if col_block < num_full_tiles:
                pol_slab = cute.local_tile(pol_row, (tile_n,), (col_block,))
                old_slab = cute.local_tile(old_row, (tile_n,), (col_block,))
                mask_slab = cute.local_tile(mask_row, (tile_n,), (col_block,))
                if cutlass.const_expr(compute_kl):
                    ref_slab = cute.local_tile(ref_row, (tile_n,), (col_block,))
                if cutlass.const_expr(compute_backward):
                    dp_slab = cute.local_tile(dp_row, (tile_n,), (col_block,))

                for k in cutlass.range(iters, unroll_full=True):
                    sub_idx = tid + k * block_size

                    pol_src = cute.local_tile(pol_slab, (VEC,), (sub_idx,))
                    old_src = cute.local_tile(old_slab, (VEC,), (sub_idx,))
                    pol_frag = cute.make_fragment_like(pol_src)
                    old_frag = cute.make_fragment_like(old_src)
                    cute.copy(g2r_atom, pol_src, pol_frag)
                    cute.copy(g2r_atom, old_src, old_frag)

                    mask_src = cute.local_tile(mask_slab, (VEC,), (sub_idx,))
                    mask_frag = cute.make_fragment_like(mask_src)
                    cute.copy(g2r_mask_atom, mask_src, mask_frag)

                    pol_vec = pol_frag.load().to(cutlass.Float32)
                    old_vec = old_frag.load().to(cutlass.Float32)
                    mask_vec = mask_frag.load().to(cutlass.Float32)

                    log_ratio = pol_vec - old_vec
                    ratio = cute.math.exp2(log_ratio * _LOG2_E, fastmath=True)
                    surrogate = ratio * adv
                    clipped_ratio = cute.where(
                        ratio < lo,
                        lo,
                        cute.where(ratio > hi, hi, ratio),
                    )
                    clipped = clipped_ratio * adv
                    policy_loss = -cute.where(surrogate < clipped, surrogate, clipped)

                    # Compute the (negated) surrogate gradient before the KL
                    # block so ``surrogate``/``clipped`` can die before the
                    # ref load. ``-(adv * ratio)`` is just ``-surrogate`` —
|
| 303 |
+
# saves one FMUL per element.
|
| 304 |
+
if cutlass.const_expr(compute_backward):
|
| 305 |
+
neg_surrogate_grad = cute.where(
|
| 306 |
+
surrogate <= clipped,
|
| 307 |
+
-surrogate,
|
| 308 |
+
cutlass.Float32(0.0),
|
| 309 |
+
)
|
| 310 |
+
|
| 311 |
+
# Reduce ``policy_loss`` immediately so it can be freed
|
| 312 |
+
# before any further work.
|
| 313 |
+
local_policy_sum += (policy_loss * mask_vec).reduce(
|
| 314 |
+
cute.ReductionOp.ADD,
|
| 315 |
+
cutlass.Float32(0.0),
|
| 316 |
+
reduction_profile=0,
|
| 317 |
+
)
|
| 318 |
+
if cutlass.const_expr(not compute_backward):
|
| 319 |
+
local_valid_sum += mask_vec.reduce(
|
| 320 |
+
cute.ReductionOp.ADD,
|
| 321 |
+
cutlass.Float32(0.0),
|
| 322 |
+
reduction_profile=0,
|
| 323 |
+
)
|
| 324 |
+
|
| 325 |
+
if cutlass.const_expr(compute_kl):
|
| 326 |
+
ref_src = cute.local_tile(ref_slab, (VEC,), (sub_idx,))
|
| 327 |
+
ref_frag = cute.make_fragment_like(ref_src)
|
| 328 |
+
cute.copy(g2r_atom, ref_src, ref_frag)
|
| 329 |
+
ref_vec = ref_frag.load().to(cutlass.Float32)
|
| 330 |
+
log_ratio_ref = ref_vec - pol_vec
|
| 331 |
+
ratio_ref = cute.math.exp2(log_ratio_ref * _LOG2_E, fastmath=True)
|
| 332 |
+
# FFMA-friendly rearrangement: ``(ratio_ref - 1) - log_ratio_ref``
|
| 333 |
+
# exposes a ``ratio_ref + (-1)`` pair that ptxas folds
|
| 334 |
+
# with the subsequent subtract — fewer FADDs surviving
|
| 335 |
+
# SASS than the original 3-term ``a - b - c``.
|
| 336 |
+
kl_val = (ratio_ref - cutlass.Float32(1.0)) - log_ratio_ref
|
| 337 |
+
local_kl_sum += (kl_val * mask_vec).reduce(
|
| 338 |
+
cute.ReductionOp.ADD,
|
| 339 |
+
cutlass.Float32(0.0),
|
| 340 |
+
reduction_profile=0,
|
| 341 |
+
)
|
| 342 |
+
|
| 343 |
+
if cutlass.const_expr(compute_backward):
|
| 344 |
+
if cutlass.const_expr(compute_kl):
|
| 345 |
+
# ``beta - beta*ratio_ref`` instead of ``beta*(1 - ratio_ref)``
|
| 346 |
+
# gives ptxas an obvious FFMA pattern (``FFMA -beta,
|
| 347 |
+
# ratio_ref, beta``) — saves one FMUL per element.
|
| 348 |
+
kl_grad = beta - beta * ratio_ref
|
| 349 |
+
grad_vec = (neg_surrogate_grad + kl_grad) * mask_vec
|
| 350 |
+
else:
|
| 351 |
+
grad_vec = neg_surrogate_grad * mask_vec
|
| 352 |
+
scaled = grad_vec * row_scale_v
|
| 353 |
+
|
| 354 |
+
dp_dst = cute.local_tile(dp_slab, (VEC,), (sub_idx,))
|
| 355 |
+
dp_frag = cute.make_fragment_like(dp_dst)
|
| 356 |
+
dp_frag.store(scaled.to(dpolicy.element_type))
|
| 357 |
+
cute.copy(r2g_atom, dp_frag, dp_dst)
|
| 358 |
+
else:
|
| 359 |
+
pol_slab = cute.local_tile(pol_row, (tile_n,), (col_block,))
|
| 360 |
+
old_slab = cute.local_tile(old_row, (tile_n,), (col_block,))
|
| 361 |
+
mask_slab = cute.local_tile(mask_row, (tile_n,), (col_block,))
|
| 362 |
+
if cutlass.const_expr(compute_kl):
|
| 363 |
+
ref_slab = cute.local_tile(ref_row, (tile_n,), (col_block,))
|
| 364 |
+
if cutlass.const_expr(compute_backward):
|
| 365 |
+
dp_slab = cute.local_tile(dp_row, (tile_n,), (col_block,))
|
| 366 |
+
|
| 367 |
+
for k in cutlass.range(iters, unroll_full=True):
|
| 368 |
+
sub_idx = tid + k * block_size
|
| 369 |
+
chunk_base = sub_idx * VEC
|
| 370 |
+
|
| 371 |
+
if chunk_base < tail_len:
|
| 372 |
+
pol_src = cute.local_tile(pol_slab, (VEC,), (sub_idx,))
|
| 373 |
+
old_src = cute.local_tile(old_slab, (VEC,), (sub_idx,))
|
| 374 |
+
pred = cute.make_rmem_tensor(pol_src.shape, cutlass.Boolean)
|
| 375 |
+
for v in cutlass.range(VEC, unroll_full=True):
|
| 376 |
+
pred[v] = cute.elem_less(chunk_base + v, tail_len)
|
| 377 |
+
|
| 378 |
+
pol_frag = cute.make_fragment_like(pol_src)
|
| 379 |
+
old_frag = cute.make_fragment_like(old_src)
|
| 380 |
+
pol_frag.fill(0.0)
|
| 381 |
+
old_frag.fill(0.0)
|
| 382 |
+
cute.copy(g2r_atom, pol_src, pol_frag, pred=pred)
|
| 383 |
+
cute.copy(g2r_atom, old_src, old_frag, pred=pred)
|
| 384 |
+
|
| 385 |
+
mask_src = cute.local_tile(mask_slab, (VEC,), (sub_idx,))
|
| 386 |
+
mask_frag = cute.make_fragment_like(mask_src)
|
| 387 |
+
mask_frag.fill(0)
|
| 388 |
+
cute.copy(g2r_mask_atom, mask_src, mask_frag, pred=pred)
|
| 389 |
+
|
| 390 |
+
pol_vec = pol_frag.load().to(cutlass.Float32)
|
| 391 |
+
old_vec = old_frag.load().to(cutlass.Float32)
|
| 392 |
+
valid_vec = cute.where(
|
| 393 |
+
pred.load(),
|
| 394 |
+
cute.full_like(pol_vec, cutlass.Float32(1.0)),
|
| 395 |
+
cute.zeros_like(pol_vec, dtype=cutlass.Float32),
|
| 396 |
+
)
|
| 397 |
+
# ``mask_vec * valid_vec`` zeros out-of-bounds lanes.
|
| 398 |
+
mask_vec = mask_frag.load().to(cutlass.Float32) * valid_vec
|
| 399 |
+
|
| 400 |
+
log_ratio = pol_vec - old_vec
|
| 401 |
+
ratio = cute.math.exp2(log_ratio * _LOG2_E, fastmath=True)
|
| 402 |
+
surrogate = ratio * adv
|
| 403 |
+
clipped_ratio = cute.where(
|
| 404 |
+
ratio < lo,
|
| 405 |
+
lo,
|
| 406 |
+
cute.where(ratio > hi, hi, ratio),
|
| 407 |
+
)
|
| 408 |
+
clipped = clipped_ratio * adv
|
| 409 |
+
policy_loss = -cute.where(surrogate < clipped, surrogate, clipped)
|
| 410 |
+
|
| 411 |
+
# Same live-range narrowing as the full-tile path:
|
| 412 |
+
# fold ``surrogate``/``clipped`` into the gradient
|
| 413 |
+
# term and reduce ``policy_loss`` before the KL
|
| 414 |
+
# block so they can be freed. ``-(adv * ratio)`` is
|
| 415 |
+
# ``-surrogate`` — saves one FMUL per element.
|
| 416 |
+
if cutlass.const_expr(compute_backward):
|
| 417 |
+
neg_surrogate_grad = cute.where(
|
| 418 |
+
surrogate <= clipped,
|
| 419 |
+
-surrogate,
|
| 420 |
+
cutlass.Float32(0.0),
|
| 421 |
+
)
|
| 422 |
+
|
| 423 |
+
local_policy_sum += (policy_loss * mask_vec).reduce(
|
| 424 |
+
cute.ReductionOp.ADD,
|
| 425 |
+
cutlass.Float32(0.0),
|
| 426 |
+
reduction_profile=0,
|
| 427 |
+
)
|
| 428 |
+
if cutlass.const_expr(not compute_backward):
|
| 429 |
+
local_valid_sum += mask_vec.reduce(
|
| 430 |
+
cute.ReductionOp.ADD,
|
| 431 |
+
cutlass.Float32(0.0),
|
| 432 |
+
reduction_profile=0,
|
| 433 |
+
)
|
| 434 |
+
|
| 435 |
+
if cutlass.const_expr(compute_kl):
|
| 436 |
+
ref_src = cute.local_tile(ref_slab, (VEC,), (sub_idx,))
|
| 437 |
+
ref_frag = cute.make_fragment_like(ref_src)
|
| 438 |
+
ref_frag.fill(0.0)
|
| 439 |
+
cute.copy(g2r_atom, ref_src, ref_frag, pred=pred)
|
| 440 |
+
ref_vec = ref_frag.load().to(cutlass.Float32)
|
| 441 |
+
log_ratio_ref = ref_vec - pol_vec
|
| 442 |
+
ratio_ref = cute.math.exp2(log_ratio_ref * _LOG2_E, fastmath=True)
|
| 443 |
+
# See full-tile path: FFMA-friendly rearrangement.
|
| 444 |
+
kl_val = (ratio_ref - cutlass.Float32(1.0)) - log_ratio_ref
|
| 445 |
+
local_kl_sum += (kl_val * mask_vec).reduce(
|
| 446 |
+
cute.ReductionOp.ADD,
|
| 447 |
+
cutlass.Float32(0.0),
|
| 448 |
+
reduction_profile=0,
|
| 449 |
+
)
|
| 450 |
+
|
| 451 |
+
if cutlass.const_expr(compute_backward):
|
| 452 |
+
if cutlass.const_expr(compute_kl):
|
| 453 |
+
# See full-tile path: ``beta - beta*ratio_ref``
|
| 454 |
+
# is FFMA-friendly.
|
| 455 |
+
kl_grad = beta - beta * ratio_ref
|
| 456 |
+
grad_vec = (neg_surrogate_grad + kl_grad) * mask_vec
|
| 457 |
+
else:
|
| 458 |
+
grad_vec = neg_surrogate_grad * mask_vec
|
| 459 |
+
scaled = grad_vec * row_scale_v
|
| 460 |
+
|
| 461 |
+
dp_dst = cute.local_tile(dp_slab, (VEC,), (sub_idx,))
|
| 462 |
+
dp_frag = cute.make_fragment_like(dp_dst)
|
| 463 |
+
dp_frag.store(scaled.to(dpolicy.element_type))
|
| 464 |
+
cute.copy(r2g_atom, dp_frag, dp_dst, pred=pred)
|
| 465 |
+
|
| 466 |
+
# ---- Stage C: warp + cross-warp reduction → atomic-add row_loss ----
|
| 467 |
+
warp_policy = cute.arch.warp_reduction(local_policy_sum, operator.add)
|
| 468 |
+
if cutlass.const_expr(compute_kl):
|
| 469 |
+
warp_kl = cute.arch.warp_reduction(local_kl_sum, operator.add)
|
| 470 |
+
if cutlass.const_expr(not compute_backward):
|
| 471 |
+
warp_valid = cute.arch.warp_reduction(local_valid_sum, operator.add)
|
| 472 |
+
|
| 473 |
+
# The fwd+bwd path used ``buf_valid`` for the prescan reduction;
|
| 474 |
+
# ensure all threads have observed ``row_scale_smem`` (Stage A's
|
| 475 |
+
# final barrier) before we reuse ``buf_policy`` / ``buf_valid``.
|
| 476 |
+
# Stage A never runs on the fwd-only path, so the barrier is
|
| 477 |
+
# only needed when ``compute_backward`` is set.
|
| 478 |
+
if cutlass.const_expr(compute_backward):
|
| 479 |
+
cute.arch.barrier()
|
| 480 |
+
|
| 481 |
+
if lane_idx == 0:
|
| 482 |
+
buf_policy[warp_idx] = warp_policy
|
| 483 |
+
if cutlass.const_expr(compute_kl):
|
| 484 |
+
buf_kl[warp_idx] = warp_kl
|
| 485 |
+
if cutlass.const_expr(not compute_backward):
|
| 486 |
+
buf_valid[warp_idx] = warp_valid
|
| 487 |
+
cute.arch.barrier()
|
| 488 |
+
|
| 489 |
+
if warp_idx == 0:
|
| 490 |
+
lane_in_warp_range = lane_idx < NUM_WARPS
|
| 491 |
+
|
| 492 |
+
val_p = cutlass.Float32(0.0)
|
| 493 |
+
if lane_in_warp_range:
|
| 494 |
+
val_p = buf_policy[lane_idx]
|
| 495 |
+
block_policy = cute.arch.warp_reduction(val_p, operator.add, threads_in_group=NUM_WARPS)
|
| 496 |
+
|
| 497 |
+
block_kl = cutlass.Float32(0.0)
|
| 498 |
+
if cutlass.const_expr(compute_kl):
|
| 499 |
+
val_k = cutlass.Float32(0.0)
|
| 500 |
+
if lane_in_warp_range:
|
| 501 |
+
val_k = buf_kl[lane_idx]
|
| 502 |
+
block_kl = cute.arch.warp_reduction(val_k, operator.add, threads_in_group=NUM_WARPS)
|
| 503 |
+
|
| 504 |
+
if cutlass.const_expr(not compute_backward):
|
| 505 |
+
val_v = cutlass.Float32(0.0)
|
| 506 |
+
if lane_in_warp_range:
|
| 507 |
+
val_v = buf_valid[lane_idx]
|
| 508 |
+
block_valid = cute.arch.warp_reduction(
|
| 509 |
+
val_v, operator.add, threads_in_group=NUM_WARPS
|
| 510 |
+
)
|
| 511 |
+
|
| 512 |
+
if lane_idx == 0:
|
| 513 |
+
if cutlass.const_expr(compute_backward):
|
| 514 |
+
row_scale = row_scale_smem[0]
|
| 515 |
+
else:
|
| 516 |
+
n_v_row = cute.arch.fmax(block_valid, cutlass.Float32(1.0))
|
| 517 |
+
row_scale = cute.arch.rcp_approx(n_v_row) * inv_n_rows
|
| 518 |
+
|
| 519 |
+
if cutlass.const_expr(compute_kl):
|
| 520 |
+
row_loss = (block_policy + beta * block_kl) * row_scale
|
| 521 |
+
else:
|
| 522 |
+
row_loss = block_policy * row_scale
|
| 523 |
+
|
| 524 |
+
loss_ptr = policy_acc.iterator.toint().ir_value() # ty: ignore[unresolved-attribute]
|
| 525 |
+
_atomic_add_f32_gmem(loss_ptr, row_loss)
|
| 526 |
+
|
| 527 |
+
# ---- Stage D: last-block detection → write final loss ----
|
| 528 |
+
if warp_idx == 0:
|
| 529 |
+
is_last_lane0 = cutlass.Int32(0)
|
| 530 |
+
if lane_idx == 0:
|
| 531 |
+
counter_ptr = counter.iterator.toint().ir_value() # ty: ignore[unresolved-attribute]
|
| 532 |
+
cute.arch.fence_acq_rel_gpu()
|
| 533 |
+
old = _atomic_inc_u32_gmem(counter_ptr, n_rows - 1)
|
| 534 |
+
if old == n_rows - 1:
|
| 535 |
+
is_last_lane0 = cutlass.Int32(1)
|
| 536 |
+
|
| 537 |
+
is_last = cute.arch.shuffle_sync(is_last_lane0, 0)
|
| 538 |
+
|
| 539 |
+
if is_last == cutlass.Int32(1) and lane_idx == 0:
|
| 540 |
+
# Each CTA already atomic-added its scaled ``row_loss``
|
| 541 |
+
# (including ``inv_n_rows``); the accumulator now holds
|
| 542 |
+
# the final loss.
|
| 543 |
+
total = policy_acc[0]
|
| 544 |
+
output[0] = cast(total, output.element_type) # ty: ignore[invalid-argument-type]
|
| 545 |
+
# Reset for the next call.
|
| 546 |
+
policy_acc[0] = cutlass.Float32(0.0)
|
| 547 |
+
|
| 548 |
+
return _grpo_loss_kernel
|
| 549 |
+
|
| 550 |
+
|
| 551 |
+
# ---------------------------------------------------------------------------
|
| 552 |
+
# Compile-and-run factory
|
| 553 |
+
# ---------------------------------------------------------------------------
|
| 554 |
+
|
| 555 |
+
|
| 556 |
+
def create_compiled_grpo_loss(
|
| 557 |
+
policy_dtype: torch.dtype,
|
| 558 |
+
epsilon: float,
|
| 559 |
+
epsilon_high: float,
|
| 560 |
+
beta: float,
|
| 561 |
+
compute_backward: bool = False,
|
| 562 |
+
) -> Callable[..., torch.Tensor | tuple[torch.Tensor, torch.Tensor]]:
|
| 563 |
+
"""Compile the GRPO loss kernel and return a runtime closure.
|
| 564 |
+
|
| 565 |
+
Two compiled variants — small-tile (``TILE_N``) and large-tile
|
| 566 |
+
(``TILE_N_LARGE``) — are produced; the runner selects one per call
|
| 567 |
+
based on ``seq_len`` vs. ``TILE_N_LARGE_THRESHOLD``.
|
| 568 |
+
|
| 569 |
+
When ``compute_backward=True`` the kernel additionally writes the
|
| 570 |
+
scaled gradient ``dL/d(policy_logprobs)`` to a caller-provided
|
| 571 |
+
``dpolicy`` tensor in the same launch.
|
| 572 |
+
"""
|
| 573 |
+
compute_kl = beta != 0.0
|
| 574 |
+
|
| 575 |
+
if policy_dtype not in _TORCH_TO_CUTLASS_DTYPE:
|
| 576 |
+
raise ValueError(f"Unsupported dtype for GRPO kernel: {policy_dtype}")
|
| 577 |
+
|
| 578 |
+
tile_n_small = TILE_N
|
| 579 |
+
tile_n_large = TILE_N_LARGE
|
| 580 |
+
seq_len_threshold = TILE_N_LARGE_THRESHOLD
|
| 581 |
+
block_size = NUM_WARPS * 32
|
| 582 |
+
if tile_n_small % (block_size * VEC) != 0:
|
| 583 |
+
raise ValueError(
|
| 584 |
+
f"TILE_N={tile_n_small} must be a multiple of BLOCK_SIZE*VEC={block_size * VEC}"
|
| 585 |
+
)
|
| 586 |
+
if tile_n_large % (block_size * VEC) != 0:
|
| 587 |
+
raise ValueError(
|
| 588 |
+
f"TILE_N_LARGE={tile_n_large} must be a multiple of BLOCK_SIZE*VEC={block_size * VEC}"
|
| 589 |
+
)
|
| 590 |
+
|
| 591 |
+
n_rows_sym = cute.sym_int()
|
| 592 |
+
seq_len_sym = cute.sym_int()
|
| 593 |
+
cute_dtype = _TORCH_TO_CUTLASS_DTYPE[policy_dtype]
|
| 594 |
+
|
| 595 |
+
def _fake2d(dt: Any, cols: Any) -> Any:
|
| 596 |
+
return cute.runtime.make_fake_compact_tensor(
|
| 597 |
+
dt,
|
| 598 |
+
(n_rows_sym, cols),
|
| 599 |
+
stride_order=(1, 0),
|
| 600 |
+
assumed_align=16,
|
| 601 |
+
)
|
| 602 |
+
|
| 603 |
+
fake_pol = _fake2d(cute_dtype, seq_len_sym)
|
| 604 |
+
fake_old = _fake2d(cute_dtype, seq_len_sym)
|
| 605 |
+
fake_ref = _fake2d(cute_dtype, seq_len_sym)
|
| 606 |
+
fake_adv = cute.runtime.make_fake_compact_tensor(
|
| 607 |
+
cute_dtype,
|
| 608 |
+
(n_rows_sym,),
|
| 609 |
+
assumed_align=16,
|
| 610 |
+
)
|
| 611 |
+
fake_mask = cute.runtime.make_fake_compact_tensor(
|
| 612 |
+
cutlass.Int8,
|
| 613 |
+
(n_rows_sym, seq_len_sym),
|
| 614 |
+
stride_order=(1, 0),
|
| 615 |
+
assumed_align=16,
|
| 616 |
+
)
|
| 617 |
+
dpolicy_cols = seq_len_sym if compute_backward else 1
|
| 618 |
+
fake_dpolicy = cute.runtime.make_fake_compact_tensor(
|
| 619 |
+
cute_dtype,
|
| 620 |
+
(n_rows_sym, dpolicy_cols),
|
| 621 |
+
stride_order=(1, 0),
|
| 622 |
+
assumed_align=16,
|
| 623 |
+
)
|
| 624 |
+
fake_policy_acc = cute.runtime.make_fake_compact_tensor(
|
| 625 |
+
cutlass.Float32,
|
| 626 |
+
(1,),
|
| 627 |
+
assumed_align=16,
|
| 628 |
+
)
|
| 629 |
+
fake_counter = cute.runtime.make_fake_compact_tensor(
|
| 630 |
+
cutlass.Int32,
|
| 631 |
+
(1,),
|
| 632 |
+
assumed_align=16,
|
| 633 |
+
)
|
| 634 |
+
fake_output = cute.runtime.make_fake_compact_tensor(
|
| 635 |
+
cute_dtype,
|
| 636 |
+
(1,),
|
| 637 |
+
assumed_align=16,
|
| 638 |
+
)
|
| 639 |
+
|
| 640 |
+
def _build_launch(tile_n_v: int) -> Callable[..., None]:
|
| 641 |
+
"""Build a JIT-compiled launcher specialized for ``tile_n_v``.
|
| 642 |
+
|
| 643 |
+
Bakes ``compute_kl``, ``compute_backward``, and the column-tile
|
| 644 |
+
width into the kernel as compile-time constants and returns a
|
| 645 |
+
``@cute.jit`` wrapper that forwards runtime tensors/scalars to
|
| 646 |
+
``.launch()`` with ``grid=(n_rows, 1, 1)`` and
|
| 647 |
+
``block=(NUM_WARPS*32, 1, 1)``.
|
| 648 |
+
"""
|
| 649 |
+
specialized_kernel = _make_grpo_kernel(compute_kl, compute_backward, tile_n_v)
|
| 650 |
+
|
| 651 |
+
@cute.jit
|
| 652 |
+
def _launch(
|
| 653 |
+
pol_ct: cute.Tensor,
|
| 654 |
+
old_ct: cute.Tensor,
|
| 655 |
+
ref_ct: cute.Tensor,
|
| 656 |
+
adv_ct: cute.Tensor,
|
| 657 |
+
mask_ct: cute.Tensor,
|
| 658 |
+
dpolicy_ct: cute.Tensor,
|
| 659 |
+
policy_acc_ct: cute.Tensor,
|
| 660 |
+
counter_ct: cute.Tensor,
|
| 661 |
+
output_ct: cute.Tensor,
|
| 662 |
+
epsilon_v: cutlass.Float32,
|
| 663 |
+
epsilon_high_v: cutlass.Float32,
|
| 664 |
+
beta_v: cutlass.Float32,
|
| 665 |
+
inv_n_rows_v: cutlass.Float32,
|
| 666 |
+
n_rows_v: cutlass.Int32,
|
| 667 |
+
num_full_tiles_v: cutlass.Int32,
|
| 668 |
+
tail_len_v: cutlass.Int32,
|
| 669 |
+
num_col_tiles_v: cutlass.Int32,
|
| 670 |
+
) -> None:
|
| 671 |
+
specialized_kernel( # ty: ignore[unresolved-attribute]
|
| 672 |
+
pol_ct,
|
| 673 |
+
old_ct,
|
| 674 |
+
ref_ct,
|
| 675 |
+
adv_ct,
|
| 676 |
+
mask_ct,
|
| 677 |
+
dpolicy_ct,
|
| 678 |
+
policy_acc_ct,
|
| 679 |
+
counter_ct,
|
| 680 |
+
output_ct,
|
| 681 |
+
epsilon_v,
|
| 682 |
+
epsilon_high_v,
|
| 683 |
+
beta_v,
|
| 684 |
+
inv_n_rows_v,
|
| 685 |
+
n_rows_v,
|
| 686 |
+
num_full_tiles_v,
|
| 687 |
+
tail_len_v,
|
| 688 |
+
num_col_tiles_v,
|
| 689 |
+
).launch(
|
| 690 |
+
grid=(n_rows_v, 1, 1),
|
| 691 |
+
block=(NUM_WARPS * 32, 1, 1),
|
| 692 |
+
)
|
| 693 |
+
|
| 694 |
+
return _launch
|
| 695 |
+
|
| 696 |
+
def _compile_launch(launch_fn: Callable[..., None]) -> Callable[..., None]:
|
| 697 |
+
return cute.compile(
|
| 698 |
+
launch_fn,
|
| 699 |
+
fake_pol,
|
| 700 |
+
fake_old,
|
| 701 |
+
fake_ref,
|
| 702 |
+
fake_adv,
|
| 703 |
+
fake_mask,
|
| 704 |
+
fake_dpolicy,
|
| 705 |
+
fake_policy_acc,
|
| 706 |
+
fake_counter,
|
| 707 |
+
fake_output,
|
| 708 |
+
cutlass.Float32(epsilon),
|
| 709 |
+
cutlass.Float32(epsilon_high),
|
| 710 |
+
cutlass.Float32(beta),
|
| 711 |
+
cutlass.Float32(1.0),
|
| 712 |
+
cutlass.Int32(1),
|
| 713 |
+
cutlass.Int32(1),
|
| 714 |
+
cutlass.Int32(0),
|
| 715 |
+
cutlass.Int32(1),
|
| 716 |
+
options="--enable-tvm-ffi",
|
| 717 |
+
)
|
| 718 |
+
|
| 719 |
+
compiled_small = _compile_launch(_build_launch(tile_n_small))
|
| 720 |
+
if tile_n_large == tile_n_small:
|
| 721 |
+
compiled_large = compiled_small
|
| 722 |
+
else:
|
| 723 |
+
compiled_large = _compile_launch(_build_launch(tile_n_large))
|
| 724 |
+
|
| 725 |
+
eps_const = cutlass.Float32(epsilon)
|
| 726 |
+
eps_high_const = cutlass.Float32(epsilon_high)
|
| 727 |
+
beta_const = cutlass.Float32(beta)
|
| 728 |
+
|
| 729 |
+
# Cross-CTA scratch slab — one int32 buffer with stride-4 (16-byte) slices
|
| 730 |
+
# so each slot is individually 16-byte aligned (``assumed_align=16`` at
|
| 731 |
+
# compile time). Bit-pattern of int32 0 equals fp32 0.0, so a single
|
| 732 |
+
# ``zeros`` factory legitimately initialises both the int32 counter and
|
| 733 |
+
# the fp32 ``policy_acc``. The kernel's last block self-resets
|
| 734 |
+
# ``policy_acc`` in its epilogue and the counter self-resets via
|
| 735 |
+
# ``atom.inc.u32`` wrap-around, so the up-front ``torch.zeros`` only
|
| 736 |
+
# matters for the very first call. Allocated lazily on first ``_run``
|
| 737 |
+
# call when the device is known.
|
| 738 |
+
_scratch: list[torch.Tensor | None] = [None]
|
| 739 |
+
|
| 740 |
+
def _ensure_scratch(device: torch.device) -> tuple[torch.Tensor, torch.Tensor]:
|
| 741 |
+
s = _scratch[0]
|
| 742 |
+
if s is None or s.device != device:
|
| 743 |
+
s = torch.zeros(8, dtype=torch.int32, device=device)
|
| 744 |
+
_scratch[0] = s
|
| 745 |
+
return (
|
| 746 |
+
s[0:1], # counter (int32)
|
| 747 |
+
s[4:5].view(torch.float32), # policy_acc (fp32)
|
| 748 |
+
)
|
| 749 |
+
|
| 750 |
+
def _run(
|
| 751 |
+
policy_logprobs_r: torch.Tensor,
|
| 752 |
+
old_policy_logprobs_r: torch.Tensor,
|
| 753 |
+
ref_logprobs_r: torch.Tensor,
|
| 754 |
+
advantages_r: torch.Tensor,
|
| 755 |
+
completions_mask_r: torch.Tensor,
|
| 756 |
+
) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
|
| 757 |
+
n_rows, seq_len = policy_logprobs_r.shape
|
| 758 |
+
device = policy_logprobs_r.device
|
| 759 |
+
dtype = policy_logprobs_r.dtype
|
| 760 |
+
|
| 761 |
+
if seq_len >= seq_len_threshold:
|
| 762 |
+
tile_n_active = tile_n_large
|
| 763 |
+
compiled_active = compiled_large
|
| 764 |
+
else:
|
| 765 |
+
tile_n_active = tile_n_small
|
| 766 |
+
compiled_active = compiled_small
|
| 767 |
+
num_full_tiles = seq_len // tile_n_active
|
| 768 |
+
tail_len = seq_len % tile_n_active
|
| 769 |
+
num_col_tiles = num_full_tiles + (1 if tail_len > 0 else 0)
|
| 770 |
+
inv_n_rows = cutlass.Float32(1.0 / float(n_rows))
|
| 771 |
+
|
| 772 |
+
# Per-call write-only buffers — ``empty`` is enough (Liger / TE pattern).
|
| 773 |
+
output_r = torch.empty(1, dtype=dtype, device=device)
|
| 774 |
+
if compute_backward:
|
| 775 |
+
dpolicy_r = torch.empty_like(policy_logprobs_r)
|
| 776 |
+
else:
|
| 777 |
+
dpolicy_r = torch.empty(n_rows, 1, dtype=dtype, device=device)
|
| 778 |
+
|
| 779 |
+
counter_r, policy_acc_r = _ensure_scratch(device)
|
| 780 |
+
|
| 781 |
+
compiled_active(
|
| 782 |
+
policy_logprobs_r,
|
| 783 |
+
old_policy_logprobs_r,
|
| 784 |
+
ref_logprobs_r,
|
| 785 |
+
advantages_r,
|
| 786 |
+
completions_mask_r,
|
| 787 |
+
dpolicy_r,
|
| 788 |
+
policy_acc_r,
|
| 789 |
+
counter_r,
|
| 790 |
+
output_r,
|
| 791 |
+
eps_const,
|
| 792 |
+
eps_high_const,
|
| 793 |
+
beta_const,
|
| 794 |
+
inv_n_rows,
|
| 795 |
+
cutlass.Int32(n_rows),
|
| 796 |
+
cutlass.Int32(num_full_tiles),
|
| 797 |
+
cutlass.Int32(tail_len),
|
| 798 |
+
cutlass.Int32(num_col_tiles),
|
| 799 |
+
)
|
| 800 |
+
out_view = output_r.view(())
|
| 801 |
+
if compute_backward:
|
| 802 |
+
return out_view, dpolicy_r
|
| 803 |
+
return out_view
|
| 804 |
+
|
| 805 |
+
return _run
|
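
For orientation, here is a minimal plain-torch sketch of what the file above does: an eager restatement of the Stage B per-element math, and a direct call into the compiled runner. This is an editor-added illustration, not part of the uploaded file; the helper name `clipped_surrogate_ref`, the import path, and all shapes/hyperparameters are assumptions, and normal callers would go through `grpo_loss` / `grpo_loss_autograd` rather than the factory.

    import torch

    # Assumed import path, following this build's layout.
    from grpo_loss.cute_grpo_loss import create_compiled_grpo_loss


    def clipped_surrogate_ref(pol, old, adv, lo=0.8, hi=1.2):
        # Eager restatement of Stage B: ratio = exp(logpi - logpi_old),
        # clip to [1 - eps, 1 + eps_high], keep the pessimistic branch, negate.
        ratio = torch.exp(pol - old)            # kernel computes exp2(x * log2 e)
        surrogate = ratio * adv[:, None]        # advantage broadcast per row
        clipped = ratio.clamp(lo, hi) * adv[:, None]
        return -torch.minimum(surrogate, clipped)


    run = create_compiled_grpo_loss(
        policy_dtype=torch.float32, epsilon=0.2, epsilon_high=0.2,
        beta=0.1, compute_backward=True,
    )
    pol = torch.randn(8, 1024, device="cuda")
    old = torch.randn(8, 1024, device="cuda")
    ref = torch.randn(8, 1024, device="cuda")
    adv = torch.randn(8, device="cuda")
    mask = torch.ones(8, 1024, dtype=torch.int8, device="cuda")
    loss, dpolicy = run(pol, old, ref, adv, mask)  # one fused fwd+bwd launch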
build/torch-cuda/layers.py
ADDED
@@ -0,0 +1,258 @@
"""HF Kernels layer adapters for ``kernelize()``.
|
| 2 |
+
|
| 3 |
+
These ``nn.Module`` classes are the entry points for users who want to
|
| 4 |
+
plug our cute kernels into a model via the ``kernels`` library's
|
| 5 |
+
``kernelize`` flow. Two classes per kernel, one per supported mode:
|
| 6 |
+
|
| 7 |
+
* :class:`bnpoLoss` / :class:`grpoLoss` / :class:`ReverseKL`
|
| 8 |
+
— autograd-aware (``loss.backward()`` works). Register against
|
| 9 |
+
``Mode.TRAINING`` and/or ``Mode.TRAINING | Mode.TORCH_COMPILE``.
|
| 10 |
+
* :class:`bnpoLossInference` / :class:`grpoLossInference` /
|
| 11 |
+
:class:`ReverseKLInference` — forward-only, no autograd
|
| 12 |
+
dispatcher. Register against ``Mode.INFERENCE`` for inference /
|
| 13 |
+
validation.
|
| 14 |
+
|
| 15 |
+
All are stateless (no ``__init__``, no member tensors) as required by
|
| 16 |
+
``kernelize`` — it validates that layer classes don't add constructor
|
| 17 |
+
state. The ``has_backward`` and ``can_torch_compile`` attributes are
|
| 18 |
+
the only allowed extras and let ``kernelize`` choose the right layer
|
| 19 |
+
for the requested mode.
|
| 20 |
+
|
| 21 |
+
Forward-signature contract: a downstream user wraps the loss in their
|
| 22 |
+
own ``nn.Module`` (decorated with ``@use_kernel_forward_from_hub(...)``)
|
| 23 |
+
whose ``forward`` matches the signature here. ``kernelize`` swaps the
|
| 24 |
+
``forward`` method by class identity, so the signature MUST line up
|
| 25 |
+
positionally with the user's module.
|
| 26 |
+
|
| 27 |
+
Typical user-side wiring::
|
| 28 |
+
|
| 29 |
+
from kernels import (
|
| 30 |
+
Mode, LayerRepository, kernelize, use_kernel_forward_from_hub,
|
| 31 |
+
)
|
| 32 |
+
|
| 33 |
+
@use_kernel_forward_from_hub("bnpoLoss")
|
| 34 |
+
class bnpoLoss(nn.Module):
|
| 35 |
+
def forward(self, policy, old, ref, adv, completions_mask,
|
| 36 |
+
epsilon=0.2, epsilon_high=0.2, beta=0.1):
|
| 37 |
+
... # eager fallback
|
| 38 |
+
|
| 39 |
+
mapping = {
|
| 40 |
+
"bnpoLoss": {
|
| 41 |
+
"cuda": {
|
| 42 |
+
Mode.INFERENCE: LayerRepository(
|
| 43 |
+
repo_id="Geometric-AI/geometric-ai-kernels",
|
| 44 |
+
layer_name="bnpoLossInference",
|
| 45 |
+
),
|
| 46 |
+
Mode.TRAINING: LayerRepository(
|
| 47 |
+
repo_id="Geometric-AI/geometric-ai-kernels",
|
| 48 |
+
layer_name="bnpoLoss",
|
| 49 |
+
),
|
| 50 |
+
}
|
| 51 |
+
}
|
| 52 |
+
}
|
| 53 |
+
|
| 54 |
+
with use_kernel_mapping(mapping):
|
| 55 |
+
model = kernelize(model, mode=Mode.TRAINING | Mode.TORCH_COMPILE)
|
| 56 |
+
"""
|
| 57 |
+
|
| 58 |
+
from __future__ import annotations
|
| 59 |
+
|
| 60 |
+
import torch
|
| 61 |
+
from torch import nn
|
| 62 |
+
|
| 63 |
+
from .bnpo_loss import bnpo_loss_fwd
|
| 64 |
+
from .bnpo_loss.autograd import bnpo_loss_autograd
|
| 65 |
+
from .grpo_loss import grpo_loss_fwd
|
| 66 |
+
from .grpo_loss.autograd import grpo_loss_autograd
|
| 67 |
+
from .reverse_kl import (
|
| 68 |
+
reverse_kl_autograd,
|
| 69 |
+
reverse_kl_fwd,
|
| 70 |
+
)
|
| 71 |
+
|
| 72 |
+
__all__ = [
|
| 73 |
+
"ReverseKL",
|
| 74 |
+
"ReverseKLInference",
|
| 75 |
+
"bnpoLoss",
|
| 76 |
+
"bnpoLossInference",
|
| 77 |
+
"grpoLoss",
|
| 78 |
+
"grpoLossInference",
|
| 79 |
+
]
|
| 80 |
+
|
| 81 |
+
|
| 82 |
+
class bnpoLoss(nn.Module):
|
| 83 |
+
"""Training-mode bnpo loss layer. ``loss.backward()`` works.
|
| 84 |
+
|
| 85 |
+
Routes through :func:`bnpo_loss_autograd`, which wraps the fused
|
| 86 |
+
cute kernel in a ``torch.library.custom_op`` with a registered
|
| 87 |
+
backward. Compatible with ``torch.compile`` (the op has a fake
|
| 88 |
+
kernel for shape propagation).
|
| 89 |
+
"""
|
| 90 |
+
|
| 91 |
+
has_backward = True
|
| 92 |
+
can_torch_compile = True
|
| 93 |
+
|
| 94 |
+
def forward(
|
| 95 |
+
self,
|
| 96 |
+
policy_logprobs: torch.Tensor,
|
| 97 |
+
old_policy_logprobs: torch.Tensor,
|
| 98 |
+
ref_logprobs: torch.Tensor,
|
| 99 |
+
advantages: torch.Tensor,
|
| 100 |
+
completions_mask: torch.Tensor,
|
| 101 |
+
epsilon: float = 0.2,
|
| 102 |
+
epsilon_high: float = 0.2,
|
| 103 |
+
beta: float = 0.1,
|
| 104 |
+
) -> torch.Tensor:
|
| 105 |
+
return bnpo_loss_autograd(
|
| 106 |
+
policy_logprobs,
|
| 107 |
+
old_policy_logprobs,
|
| 108 |
+
ref_logprobs,
|
| 109 |
+
advantages,
|
| 110 |
+
completions_mask,
|
| 111 |
+
epsilon=epsilon,
|
| 112 |
+
epsilon_high=epsilon_high,
|
| 113 |
+
beta=beta,
|
| 114 |
+
)
|
| 115 |
+
|
| 116 |
+
|
| 117 |
+
class bnpoLossInference(nn.Module):
|
| 118 |
+
"""Inference / validation bnpo loss layer. No autograd dispatcher.
|
| 119 |
+
|
| 120 |
+
Routes through :func:`bnpo_loss_fwd` — the forward-only kernel that
|
| 121 |
+
computes the masked mean denominator on-GPU via the last-block
|
| 122 |
+
trick, skipping the dpolicy buffer entirely.
|
| 123 |
+
"""
|
| 124 |
+
|
| 125 |
+
has_backward = False
|
| 126 |
+
can_torch_compile = True
|
| 127 |
+
|
| 128 |
+
def forward(
|
| 129 |
+
self,
|
| 130 |
+
policy_logprobs: torch.Tensor,
|
| 131 |
+
old_policy_logprobs: torch.Tensor,
|
| 132 |
+
ref_logprobs: torch.Tensor,
|
| 133 |
+
advantages: torch.Tensor,
|
| 134 |
+
completions_mask: torch.Tensor,
|
| 135 |
+
epsilon: float = 0.2,
|
| 136 |
+
epsilon_high: float = 0.2,
|
| 137 |
+
beta: float = 0.1,
|
| 138 |
+
) -> torch.Tensor:
|
| 139 |
+
return bnpo_loss_fwd(
|
| 140 |
+
policy_logprobs,
|
| 141 |
+
old_policy_logprobs,
|
| 142 |
+
ref_logprobs,
|
| 143 |
+
advantages,
|
| 144 |
+
completions_mask,
|
| 145 |
+
epsilon=epsilon,
|
| 146 |
+
epsilon_high=epsilon_high,
|
| 147 |
+
beta=beta,
|
| 148 |
+
)
|
| 149 |
+
|
| 150 |
+
|
| 151 |
+
class grpoLoss(nn.Module):
|
| 152 |
+
"""Training-mode GRPO loss layer. ``loss.backward()`` works.
|
| 153 |
+
|
| 154 |
+
Routes through :func:`grpo_loss_autograd`, which wraps the fused
|
| 155 |
+
cute kernel in a ``torch.library.custom_op`` with a registered
|
| 156 |
+
backward. ``completions_mask`` is required (GRPO's per-response
|
| 157 |
+
normalization is mask-derived).
|
| 158 |
+
"""
|
| 159 |
+
|
| 160 |
+
has_backward = True
|
| 161 |
+
can_torch_compile = True
|
| 162 |
+
|
| 163 |
+
def forward(
|
| 164 |
+
self,
|
| 165 |
+
policy_logprobs: torch.Tensor,
|
| 166 |
+
old_policy_logprobs: torch.Tensor,
|
| 167 |
+
ref_logprobs: torch.Tensor,
|
| 168 |
+
advantages: torch.Tensor,
|
| 169 |
+
completions_mask: torch.Tensor,
|
| 170 |
+
epsilon: float = 0.2,
|
| 171 |
+
epsilon_high: float = 0.2,
|
| 172 |
+
beta: float = 0.1,
|
| 173 |
+
) -> torch.Tensor:
|
| 174 |
+
return grpo_loss_autograd(
|
| 175 |
+
policy_logprobs,
|
| 176 |
+
old_policy_logprobs,
|
| 177 |
+
ref_logprobs,
|
| 178 |
+
advantages,
|
| 179 |
+
completions_mask,
|
| 180 |
+
epsilon=epsilon,
|
| 181 |
+
epsilon_high=epsilon_high,
|
| 182 |
+
beta=beta,
|
| 183 |
+
)
|
| 184 |
+
|
| 185 |
+
|
| 186 |
+
class grpoLossInference(nn.Module):
|
| 187 |
+
"""Inference / validation GRPO loss layer. No autograd dispatcher.
|
| 188 |
+
|
| 189 |
+
Routes through :func:`grpo_loss_fwd` — the forward-only kernel,
|
| 190 |
+
skipping the dpolicy buffer entirely.
|
| 191 |
+
"""
|
| 192 |
+
|
| 193 |
+
has_backward = False
|
| 194 |
+
can_torch_compile = True
|
| 195 |
+
|
| 196 |
+
def forward(
|
| 197 |
+
self,
|
| 198 |
+
policy_logprobs: torch.Tensor,
|
| 199 |
+
old_policy_logprobs: torch.Tensor,
|
| 200 |
+
ref_logprobs: torch.Tensor,
|
| 201 |
+
advantages: torch.Tensor,
|
| 202 |
+
completions_mask: torch.Tensor,
|
| 203 |
+
epsilon: float = 0.2,
|
| 204 |
+
epsilon_high: float = 0.2,
|
| 205 |
+
beta: float = 0.1,
|
| 206 |
+
) -> torch.Tensor:
|
| 207 |
+
return grpo_loss_fwd(
|
| 208 |
+
policy_logprobs,
|
| 209 |
+
old_policy_logprobs,
|
| 210 |
+
ref_logprobs,
|
| 211 |
+
advantages,
|
| 212 |
+
completions_mask,
|
| 213 |
+
epsilon=epsilon,
|
| 214 |
+
epsilon_high=epsilon_high,
|
| 215 |
+
beta=beta,
|
| 216 |
+
)
|
| 217 |
+
|
| 218 |
+
|
| 219 |
+
class ReverseKL(nn.Module):
|
| 220 |
+
"""Training-mode reverse-KL self-distillation loss layer.
|
| 221 |
+
|
| 222 |
+
Routes through :func:`reverse_kl_autograd`, which wraps
|
| 223 |
+
the fused cute kernel in a ``torch.library.custom_op`` with a
|
| 224 |
+
registered backward. ``loss.backward()`` propagates to whatever
|
| 225 |
+
produced ``student_logits``. Compatible with ``torch.compile`` (the
|
| 226 |
+
op has a fake kernel for shape propagation).
|
| 227 |
+
"""
|
| 228 |
+
|
| 229 |
+
has_backward = True
|
| 230 |
+
can_torch_compile = True
|
| 231 |
+
|
| 232 |
+
def forward(
|
| 233 |
+
self,
|
| 234 |
+
student_logits: torch.Tensor,
|
| 235 |
+
teacher_logits: torch.Tensor,
|
| 236 |
+
completions_mask: torch.Tensor,
|
| 237 |
+
) -> torch.Tensor:
|
| 238 |
+
return reverse_kl_autograd(student_logits, teacher_logits, completions_mask)
|
| 239 |
+
|
| 240 |
+
|
| 241 |
+
class ReverseKLInference(nn.Module):
|
| 242 |
+
"""Inference / validation reverse-KL self-distillation loss layer.
|
| 243 |
+
|
| 244 |
+
Routes through :func:`reverse_kl_fwd` — the forward-only
|
| 245 |
+
kernel that dead-code-eliminates the gradient pass and skips the
|
| 246 |
+
grad-student buffer entirely.
|
| 247 |
+
"""
|
| 248 |
+
|
| 249 |
+
has_backward = False
|
| 250 |
+
can_torch_compile = True
|
| 251 |
+
|
| 252 |
+
def forward(
|
| 253 |
+
self,
|
| 254 |
+
student_logits: torch.Tensor,
|
| 255 |
+
teacher_logits: torch.Tensor,
|
| 256 |
+
completions_mask: torch.Tensor,
|
| 257 |
+
) -> torch.Tensor:
|
| 258 |
+
return reverse_kl_fwd(student_logits, teacher_logits, completions_mask)
|
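
The module docstring above wires ``bnpoLoss``; the corresponding wiring for the reverse-KL layers is sketched below. This is an editor-added example following the docstring's own pattern: ``Mode``, ``LayerRepository``, ``use_kernel_mapping``, and ``kernelize`` come from the ``kernels`` library exactly as in that example, and ``model`` is assumed to be a module containing a layer decorated with ``@use_kernel_forward_from_hub("ReverseKL")``.

    from kernels import Mode, LayerRepository, kernelize, use_kernel_mapping

    reverse_kl_mapping = {
        "ReverseKL": {
            "cuda": {
                Mode.INFERENCE: LayerRepository(
                    repo_id="Geometric-AI/geometric-ai-kernels",
                    layer_name="ReverseKLInference",
                ),
                Mode.TRAINING: LayerRepository(
                    repo_id="Geometric-AI/geometric-ai-kernels",
                    layer_name="ReverseKL",
                ),
            }
        }
    }

    # `model` is assumed to already use the decorated ReverseKL wrapper.
    with use_kernel_mapping(reverse_kl_mapping):
        model = kernelize(model, mode=Mode.TRAINING)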
build/torch-cuda/metadata.json
ADDED
@@ -0,0 +1,12 @@
{
  "id": "_geometric_ai_kernels_cuda_a766fbd_dirty",
  "version": 0,
  "license": "Apache-2.0",
  "python-depends": [
    "tvm-ffi",
    "nvidia-cutlass-dsl"
  ],
  "backend": {
    "type": "cuda"
  }
}
build/torch-cuda/reverse_kl/__init__.py
ADDED
@@ -0,0 +1,196 @@
"""Reverse-KL self-distillation loss with CuteDSL fused fwd+bwd.
|
| 2 |
+
|
| 3 |
+
Three public APIs route to two compiled kernels:
|
| 4 |
+
|
| 5 |
+
* :func:`reverse_kl` — primary training entry point.
|
| 6 |
+
Returns ``(loss, grad_student_logits)`` from a single fused fwd+bwd
|
| 7 |
+
kernel launch. Inputs do **not** need ``requires_grad=True`` and there
|
| 8 |
+
is no ``torch.autograd.Function`` wrapper — chain the gradient into
|
| 9 |
+
the upstream model with ``student_logits.backward(grad)``.
|
| 10 |
+
* :func:`reverse_kl_fwd` — inference / validation path.
|
| 11 |
+
Returns the scalar ``loss`` from a forward-only kernel that
|
| 12 |
+
dead-code-eliminates the gradient pass.
|
| 13 |
+
* :func:`reverse_kl_autograd` — autograd-aware variant via
|
| 14 |
+
``torch.library.custom_op``. Returns scalar ``loss``;
|
| 15 |
+
``loss.backward()`` Just Works. Re-exported from
|
| 16 |
+
:mod:`reverse_kl.autograd`.
|
| 17 |
+
|
| 18 |
+
Per-call output and gradient buffers are allocated inside the runner;
|
| 19 |
+
cross-CTA scratch (atomic accumulators + counters + the constant
|
| 20 |
+
``grad_output=1.0`` scalar) is owned by the compiled-kernel closure and
|
| 21 |
+
self-resets each launch — callers don't manage scratch state.
|
| 22 |
+
|
| 23 |
+
Why no autograd wrapper for :func:`reverse_kl`?
|
| 24 |
+
The reverse-KL gradient is closed-form: ``dL/d(student) = mask *
|
| 25 |
+
inv_n_valid * p * (log_p - log_q - kl_per_row)``. The fused kernel
|
| 26 |
+
already writes that analytically in the same launch as the loss.
|
| 27 |
+
Wrapping in ``torch.autograd.Function`` would cost an extra
|
| 28 |
+
``grad_output * dpolicy`` kernel on backward (~2× per-call overhead in
|
| 29 |
+
practice) and is opaque to ``torch.compile``. Use the autograd-aware
|
| 30 |
+
variant when you need ``loss.backward()`` ergonomics; pay only the
|
| 31 |
+
``custom_op`` dispatcher cost (also Inductor-traceable).
|
| 32 |
+
"""
|
| 33 |
+
|
| 34 |
+
from __future__ import annotations
|
| 35 |
+
|
| 36 |
+
from functools import lru_cache
|
| 37 |
+
from typing import TYPE_CHECKING
|
| 38 |
+
|
| 39 |
+
import torch
|
| 40 |
+
|
| 41 |
+
from .cute_reverse_kl import create_compiled_reverse_kl
|
| 42 |
+
|
| 43 |
+
if TYPE_CHECKING:
|
| 44 |
+
from collections.abc import Callable
|
| 45 |
+
|
| 46 |
+
|
| 47 |
+
__all__ = [
|
| 48 |
+
"reverse_kl",
|
| 49 |
+
"reverse_kl_autograd",
|
| 50 |
+
"reverse_kl_fwd",
|
| 51 |
+
]
|
| 52 |
+
|
| 53 |
+
|
| 54 |
+
@lru_cache(maxsize=32)
|
| 55 |
+
def _get_compiled_fwd(
|
| 56 |
+
dtype: torch.dtype,
|
| 57 |
+
vocab: int,
|
| 58 |
+
) -> Callable[..., torch.Tensor]:
|
| 59 |
+
return create_compiled_reverse_kl( # ty: ignore[invalid-return-type]
|
| 60 |
+
policy_dtype=dtype,
|
| 61 |
+
vocab=vocab,
|
| 62 |
+
compute_backward=False,
|
| 63 |
+
)
|
| 64 |
+
|
| 65 |
+
|
| 66 |
+
@lru_cache(maxsize=32)
|
| 67 |
+
def _get_compiled_fwd_bwd(
|
| 68 |
+
dtype: torch.dtype,
|
| 69 |
+
vocab: int,
|
| 70 |
+
) -> Callable[..., tuple[torch.Tensor, torch.Tensor]]:
|
| 71 |
+
return create_compiled_reverse_kl( # ty: ignore[invalid-return-type]
|
| 72 |
+
policy_dtype=dtype,
|
| 73 |
+
vocab=vocab,
|
| 74 |
+
compute_backward=True,
|
| 75 |
+
)
|
| 76 |
+
|
| 77 |
+
|
| 78 |
+
def _flatten_inputs(
|
| 79 |
+
student_logits: torch.Tensor,
|
| 80 |
+
teacher_logits: torch.Tensor,
|
| 81 |
+
completions_mask: torch.Tensor,
|
| 82 |
+
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, tuple[int, ...], int, int]:
|
| 83 |
+
"""Reshape ``(*, V)`` inputs to ``(num_rows, V)`` for the kernel.
|
| 84 |
+
|
| 85 |
+
The kernel works on a flat ``(num_rows, V)`` slab; the public API
|
| 86 |
+
accepts arbitrary leading dims (typically ``(N, C, V)``) and we
|
| 87 |
+
flatten here so the wrapper signature stays user-friendly. We
|
| 88 |
+
assume contiguous-on-V inputs — ``view`` becomes a no-copy reshape.
|
| 89 |
+
"""
|
| 90 |
+
if student_logits.shape != teacher_logits.shape:
|
| 91 |
+
raise ValueError(
|
| 92 |
+
"student_logits and teacher_logits must have the same shape; got "
|
| 93 |
+
f"{tuple(student_logits.shape)} vs {tuple(teacher_logits.shape)}"
|
| 94 |
+
)
|
| 95 |
+
if student_logits.dtype != teacher_logits.dtype:
|
| 96 |
+
raise ValueError(
|
| 97 |
+
"student_logits and teacher_logits must have the same dtype; got "
|
| 98 |
+
f"{student_logits.dtype} vs {teacher_logits.dtype}"
|
| 99 |
+
)
|
| 100 |
+
|
| 101 |
+
leading_shape = tuple(student_logits.shape[:-1])
|
| 102 |
+
vocab = int(student_logits.shape[-1])
|
| 103 |
+
|
| 104 |
+
student_2d = student_logits.view(-1, vocab)
|
| 105 |
+
teacher_2d = teacher_logits.view(-1, vocab)
|
| 106 |
+
num_rows = student_2d.shape[0]
|
| 107 |
+
|
| 108 |
+
if tuple(completions_mask.shape) != leading_shape:
|
| 109 |
+
raise ValueError(
|
| 110 |
+
f"completions_mask must have shape {leading_shape} (logits' leading "
|
| 111 |
+
f"dims); got {tuple(completions_mask.shape)}"
|
| 112 |
+
)
|
| 113 |
+
flat_mask = completions_mask.view(-1)
|
| 114 |
+
if flat_mask.dtype != torch.float32:
|
| 115 |
+
flat_mask = flat_mask.to(torch.float32)
|
| 116 |
+
|
| 117 |
+
return student_2d, teacher_2d, flat_mask, leading_shape, num_rows, vocab
|
| 118 |
+
|
| 119 |
+
|
| 120 |
+
def reverse_kl_fwd(
|
| 121 |
+
student_logits: torch.Tensor,
|
| 122 |
+
teacher_logits: torch.Tensor,
|
| 123 |
+
completions_mask: torch.Tensor,
|
| 124 |
+
) -> torch.Tensor:
|
| 125 |
+
"""Forward-only reverse-KL self-distillation loss. Returns scalar ``loss``.
|
| 126 |
+
|
| 127 |
+
Use for inference / validation. The masked mean denominator is
|
| 128 |
+
computed on-GPU by the bundled mask-sum kernel — no host
|
| 129 |
+
``mask.sum()`` syncs.
|
| 130 |
+
|
| 131 |
+
Args:
|
| 132 |
+
student_logits, teacher_logits: ``(*, V)`` logit tensors with
|
| 133 |
+
arbitrary leading dims (typically ``(N, C, V)``); both must
|
| 134 |
+
share shape and dtype.
|
| 135 |
+
completions_mask: Bool / int / float mask with shape matching
|
| 136 |
+
``student_logits.shape[:-1]``; truthy = valid token.
|
| 137 |
+
|
| 138 |
+
Returns:
|
| 139 |
+
Scalar tensor (0-dim) with the same dtype as ``student_logits``.
|
| 140 |
+
"""
|
| 141 |
+
student_2d, teacher_2d, flat_mask, _, _, vocab = _flatten_inputs(
|
| 142 |
+
student_logits, teacher_logits, completions_mask
|
| 143 |
+
)
|
| 144 |
+
run = _get_compiled_fwd(student_logits.dtype, vocab)
|
| 145 |
+
return run(student_2d, teacher_2d, flat_mask)
|
| 146 |
+
|
| 147 |
+
|
| 148 |
+
def reverse_kl(
|
| 149 |
+
student_logits: torch.Tensor,
|
| 150 |
+
teacher_logits: torch.Tensor,
|
| 151 |
+
completions_mask: torch.Tensor,
|
| 152 |
+
) -> tuple[torch.Tensor, torch.Tensor]:
|
| 153 |
+
"""Fused fwd+bwd reverse-KL self-distillation. Returns ``(loss, grad_student)``.
|
| 154 |
+
|
| 155 |
+
Single-launch training entry point. The kernel writes both the
|
| 156 |
+
scalar loss and the analytical ``dL/d(student_logits)`` in one
|
| 157 |
+
``@cute.jit`` dispatch — the bundled mask-sum kernel populates
|
| 158 |
+
``inv_n_valid`` on-GPU before the main kernel reads it, so there's
|
| 159 |
+
no host-side ``mask.sum()`` round trip.
|
| 160 |
+
|
| 161 |
+
Inputs do **not** need ``requires_grad=True``. To chain ``grad``
|
| 162 |
+
into the upstream model that produced ``student_logits``::
|
| 163 |
+
|
| 164 |
+
loss, grad = reverse_kl(student_logits, teacher_logits, mask)
|
| 165 |
+
student_logits.backward(grad)
|
| 166 |
+
optimizer.step()
|
| 167 |
+
|
| 168 |
+
Args:
|
| 169 |
+
student_logits, teacher_logits: ``(*, V)`` logit tensors with
|
| 170 |
+
arbitrary leading dims; both must share shape and dtype.
|
| 171 |
+
completions_mask: Mask with shape matching
|
| 172 |
+
``student_logits.shape[:-1]``.
|
| 173 |
+
|
| 174 |
+
Returns:
|
| 175 |
+
``(loss, grad_student_logits)`` — ``loss`` is a 0-dim tensor in
|
| 176 |
+
``student_logits.dtype``; ``grad_student_logits`` matches
|
| 177 |
+
``student_logits.shape`` and is already scaled by
|
| 178 |
+
``1 / n_valid`` (undefined when ``n_valid == 0`` — fully-masked
|
| 179 |
+
batches produce inf/NaN; callers must guard upstream). The grad
|
| 180 |
+
tensor is freshly allocated per call (no shared cache).
|
| 181 |
+
|
| 182 |
+
For inference / validation where you only need the loss, use
|
| 183 |
+
:func:`reverse_kl_fwd` — it skips the gradient slab entirely.
|
| 184 |
+
"""
|
| 185 |
+
student_2d, teacher_2d, flat_mask, leading_shape, _, vocab = _flatten_inputs(
|
| 186 |
+
student_logits, teacher_logits, completions_mask
|
| 187 |
+
)
|
| 188 |
+
run = _get_compiled_fwd_bwd(student_logits.dtype, vocab)
|
| 189 |
+
loss, grad_2d = run(student_2d, teacher_2d, flat_mask)
|
| 190 |
+
grad_student = grad_2d.view((*leading_shape, vocab))
|
| 191 |
+
return loss, grad_student
|
| 192 |
+
|
| 193 |
+
|
| 194 |
+
# Imported at the bottom: ``autograd.py`` imports ``reverse_kl`` from
|
| 195 |
+
# this module, so the function must be fully defined before its import runs.
|
| 196 |
+
from .autograd import reverse_kl_autograd # noqa: E402
|
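
The closed-form gradient quoted in the module docstring can be restated in plain torch. The sketch below is editor-added for cross-checking, not part of the uploaded file: the helper name is hypothetical, and the `clamp(min=1)` on the denominator follows the masked-mean convention documented in `_torch_ref.py` (the docstring above instead leaves the fully-masked case undefined).

    import torch
    from torch.nn.functional import log_softmax


    def reverse_kl_grad_sketch(student, teacher, mask):
        # p = softmax(student); per-row KL(student || teacher).
        log_p = log_softmax(student, dim=-1)
        log_q = log_softmax(teacher, dim=-1)
        p = log_p.exp()
        kl_per_row = (p * (log_p - log_q)).sum(dim=-1, keepdim=True)
        m = mask.to(student.dtype)
        inv_n_valid = 1.0 / m.sum().clamp(min=1.0)
        # mask * inv_n_valid * p * (log_p - log_q - kl_per_row), per docstring.
        return m[..., None] * inv_n_valid * p * (log_p - log_q - kl_per_row)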
build/torch-cuda/reverse_kl/_torch_ref.py
ADDED
@@ -0,0 +1,65 @@
"""Plain-PyTorch reverse-KL reference shared between bench and tests.
|
| 2 |
+
|
| 3 |
+
Every op is a vanilla torch op so AOTAutograd can derive the joint
|
| 4 |
+
fwd+bwd graph and Inductor can fuse both passes (used by
|
| 5 |
+
``benchmarks/benchmark_reverse_kl.py``'s compiled baseline).
|
| 6 |
+
The same function is imported by ``tests/test_reverse_kl.py``
|
| 7 |
+
as the correctness reference, so both paths agree on what "the eager
|
| 8 |
+
torch implementation of reverse-KL self-distillation" means.
|
| 9 |
+
|
| 10 |
+
Reverse-KL definition (KL(student || teacher)):
|
| 11 |
+
|
| 12 |
+
p = softmax(student)
|
| 13 |
+
q = softmax(teacher)
|
| 14 |
+
kl_per_row = sum_v p_v * (log p_v - log q_v)
|
| 15 |
+
loss = sum_r mask_r * kl_per_row[r] / max(sum_r mask_r, 1)
|
| 16 |
+
|
| 17 |
+
The ``clamp(min=1)`` matches TRL's masked-mean convention so a
|
| 18 |
+
fully-masked batch yields ``loss=0`` instead of NaN, mirroring the
|
| 19 |
+
cute kernel's ``cute.arch.fmax(n_valid, 1.0)`` clamp.
|
| 20 |
+
|
| 21 |
+
Underscore-prefixed module name signals "shared internal", not a public
|
| 22 |
+
API surface — there's no re-export from the package's top-level
|
| 23 |
+
``__init__.py``.
|
| 24 |
+
"""
|
| 25 |
+
|
| 26 |
+
from __future__ import annotations
|
| 27 |
+
|
| 28 |
+
import torch
|
| 29 |
+
from torch.nn.functional import kl_div, log_softmax
|
| 30 |
+
|
| 31 |
+
|
| 32 |
+
def torch_reverse_kl(
|
| 33 |
+
student_logits: torch.Tensor,
|
| 34 |
+
teacher_logits: torch.Tensor,
|
| 35 |
+
completions_mask: torch.Tensor,
|
| 36 |
+
) -> torch.Tensor:
|
| 37 |
+
"""Compute reverse-KL divergence loss.
|
| 38 |
+
|
| 39 |
+
Computes per-token reverse KL divergence:
|
| 40 |
+
|
| 41 |
+
KL(student || teacher) = sum_v p(v) [log p(v) - log q(v)]
|
| 42 |
+
|
| 43 |
+
where p is the student distribution and q is the teacher distribution,
|
| 44 |
+
both obtained by softmax over the vocabulary dimension.
|
| 45 |
+
|
| 46 |
+
Args:
|
| 47 |
+
student_logits: Student logits, shape (N, C, V).
|
| 48 |
+
teacher_logits: Teacher logits, shape (N, C, V).
|
| 49 |
+
completions_mask: Boolean mask of shape (N, C) where True marks
|
| 50 |
+
valid tokens. Pass an all-ones mask if no tokens are padded.
|
| 51 |
+
|
| 52 |
+
Returns:
|
| 53 |
+
Scalar tensor representing the loss.
|
| 54 |
+
"""
|
| 55 |
+
log_p = log_softmax(student_logits, dim=-1)
|
| 56 |
+
log_q = log_softmax(teacher_logits, dim=-1)
|
| 57 |
+
|
| 58 |
+
# kl_div(input, target, log_target=True) computes KL(target || input)
|
| 59 |
+
# so input=log_q, target=log_p gives KL(student || teacher)
|
| 60 |
+
kl = kl_div(log_q, log_p, log_target=True, reduction="none").sum(dim=-1)
|
| 61 |
+
|
| 62 |
+
n_valid = completions_mask.sum().to(torch.float32)
|
| 63 |
+
kl = (kl * completions_mask).sum() / n_valid
|
| 64 |
+
|
| 65 |
+
return kl.to(student_logits.dtype)
|
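
A minimal editor-added usage sketch for the reference above (import path assumed from this build's layout; shapes and the padded last column are illustrative):

    import torch
    from reverse_kl._torch_ref import torch_reverse_kl

    student = torch.randn(2, 5, 32)     # (N, C, V)
    teacher = torch.randn(2, 5, 32)
    mask = torch.ones(2, 5, dtype=torch.bool)
    mask[:, -1] = False                 # treat the last token of each row as padding

    loss = torch_reverse_kl(student, teacher, mask)
    print(loss)                         # 0-dim tensor in student's dtype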
build/torch-cuda/reverse_kl/autograd.py
ADDED
@@ -0,0 +1,122 @@
"""Autograd-aware wrapper for reverse-KL self-distillation via ``torch.library.custom_op``.
|
| 2 |
+
|
| 3 |
+
The fused cute kernel writes both the scalar loss and the closed-form
|
| 4 |
+
``dL/d(student_logits)`` in one launch. This module wraps that into an
|
| 5 |
+
autograd-compatible op so callers can write::
|
| 6 |
+
|
| 7 |
+
loss = reverse_kl_autograd(student, teacher, completions_mask)
|
| 8 |
+
loss.backward() # propagates through to whatever produced student_logits
|
| 9 |
+
|
| 10 |
+
instead of the manual ``student.backward(grad)`` chain. The cost is
|
| 11 |
+
~12µs of autograd dispatcher overhead per call (vs the direct
|
| 12 |
+
``reverse_kl`` (loss, grad) tuple); for ergonomic /
|
| 13 |
+
``kernelize()`` flows that's cheap, but for tight microbenches use the
|
| 14 |
+
direct path.
|
| 15 |
+
|
| 16 |
+
Implementation notes:
|
| 17 |
+
|
| 18 |
+
- The registered op returns ``(loss, grad_student)`` so
|
| 19 |
+
``setup_context`` can ``save_for_backward(grad_student)``. The public
|
| 20 |
+
:func:`reverse_kl_autograd` wrapper hides the second output.
|
| 21 |
+
- The runner allocates ``grad_student`` fresh on every call (no shared
|
| 22 |
+
cache), so ``ctx.save_for_backward(grad_student)`` keeps a stable
|
| 23 |
+
reference for free.
|
| 24 |
+
- Backward returns ``grad_loss * grad_student``. Under
|
| 25 |
+
``torch.compile``, when ``loss`` is consumed by ``.backward()``
|
| 26 |
+
directly, ``grad_loss`` is the constant 1.0 and Inductor can fold the
|
| 27 |
+
multiply away — the main reason this path uses ``custom_op`` instead
|
| 28 |
+
of a plain ``autograd.Function``.
|
| 29 |
+
- ``register_fake`` provides the meta kernel for ``torch.compile``
|
| 30 |
+
shape propagation; the real cute kernel never runs under
|
| 31 |
+
``FakeTensorMode``.
|
| 32 |
+
"""
|
| 33 |
+
|
| 34 |
+
from __future__ import annotations
|
| 35 |
+
|
| 36 |
+
import torch
|
| 37 |
+
|
| 38 |
+
from . import reverse_kl as _reverse_kl_fwd_bwd
|
| 39 |
+
|
| 40 |
+
__all__ = ["reverse_kl_autograd"]
|
| 41 |
+
|
| 42 |
+
|
| 43 |
+
@torch.library.custom_op(
|
| 44 |
+
"geometric_ai_kernels::_reverse_kl_with_grad",
|
| 45 |
+
mutates_args=(),
|
| 46 |
+
)
|
| 47 |
+
def _reverse_kl_with_grad(
|
| 48 |
+
student_logits: torch.Tensor,
|
| 49 |
+
teacher_logits: torch.Tensor,
|
| 50 |
+
completions_mask: torch.Tensor,
|
| 51 |
+
) -> tuple[torch.Tensor, torch.Tensor]:
|
| 52 |
+
loss, grad_student = _reverse_kl_fwd_bwd(
|
| 53 |
+
student_logits,
|
| 54 |
+
teacher_logits,
|
| 55 |
+
completions_mask,
|
| 56 |
+
)
|
| 57 |
+
return loss, grad_student
|
| 58 |
+
|
| 59 |
+
|
| 60 |
+
@_reverse_kl_with_grad.register_fake
|
| 61 |
+
def _(
|
| 62 |
+
student_logits: torch.Tensor,
|
| 63 |
+
teacher_logits: torch.Tensor,
|
| 64 |
+
completions_mask: torch.Tensor,
|
| 65 |
+
) -> tuple[torch.Tensor, torch.Tensor]:
|
| 66 |
+
# Signature must mirror the op; only ``student_logits`` shapes the outputs.
|
| 67 |
+
del teacher_logits, completions_mask
|
| 68 |
+
loss = student_logits.new_empty(())
|
| 69 |
+
grad_student = torch.empty_like(student_logits)
|
| 70 |
+
return loss, grad_student
|
| 71 |
+
|
| 72 |
+
|
| 73 |
+
def _setup_context(ctx, inputs, output) -> None: # type: ignore[no-untyped-def]
|
| 74 |
+
del inputs # only ``output`` carries what we need to save.
|
| 75 |
+
_, grad_student = output
|
| 76 |
+
ctx.save_for_backward(grad_student)
|
| 77 |
+
|
| 78 |
+
|
| 79 |
+
def _backward(ctx, grad_loss, grad_grad_student): # type: ignore[no-untyped-def]
|
| 80 |
+
# ``grad_grad_student`` is unused — ``grad_student`` is an internal
|
| 81 |
+
# intermediate exposed only so ``setup_context`` can save it. Under
|
| 82 |
+
# typical usage (``loss.backward()``) it arrives as ``None`` or a
|
| 83 |
+
# zero tensor.
|
| 84 |
+
del grad_grad_student
|
| 85 |
+
(grad_student,) = ctx.saved_tensors
|
| 86 |
+
grad_input = grad_loss * grad_student
|
| 87 |
+
# One return per input to the op (3): student_logits gets the grad,
|
| 88 |
+
# teacher_logits and completions_mask have no autograd flow.
|
| 89 |
+
return grad_input, None, None
|
| 90 |
+
|
| 91 |
+
|
| 92 |
+
torch.library.register_autograd(
|
| 93 |
+
"geometric_ai_kernels::_reverse_kl_with_grad",
|
| 94 |
+
_backward,
|
| 95 |
+
setup_context=_setup_context,
|
| 96 |
+
)
|
| 97 |
+
|
| 98 |
+
|
| 99 |
+
def reverse_kl_autograd(
|
| 100 |
+
student_logits: torch.Tensor,
|
| 101 |
+
teacher_logits: torch.Tensor,
|
| 102 |
+
completions_mask: torch.Tensor,
|
| 103 |
+
) -> torch.Tensor:
|
| 104 |
+
"""Autograd-aware reverse-KL self-distillation. Returns scalar ``loss``.
|
| 105 |
+
|
| 106 |
+
Same numerics as :func:`reverse_kl` but registered as a
|
| 107 |
+
``torch.library`` custom op with autograd, so::
|
| 108 |
+
|
| 109 |
+
loss = reverse_kl_autograd(student, teacher, completions_mask)
|
| 110 |
+
loss.backward()
|
| 111 |
+
|
| 112 |
+
propagates through to whatever produced ``student_logits``. For
|
| 113 |
+
direct ``(loss, grad)`` access without the autograd dispatcher
|
| 114 |
+
overhead, use :func:`reverse_kl` and chain the gradient
|
| 115 |
+
manually via ``student_logits.backward(grad)``.
|
| 116 |
+
|
| 117 |
+
Composes with ``torch.compile``: the op is opaque to Inductor but
|
| 118 |
+
has a fake/meta kernel registered, so models containing this layer
|
| 119 |
+
can be compiled end-to-end without graph breaks.
|
| 120 |
+
"""
|
| 121 |
+
loss, _ = _reverse_kl_with_grad(student_logits, teacher_logits, completions_mask)
|
| 122 |
+
return loss
|
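
An editor-added end-to-end sketch of the ``loss.backward()`` path described above (the toy ``nn.Linear`` head, shapes, and CUDA device are assumptions; running it requires the compiled reverse-KL kernel to be available):

    import torch
    from torch import nn
    from reverse_kl.autograd import reverse_kl_autograd  # assumed import path

    head = nn.Linear(16, 32, device="cuda")              # hypothetical student head
    x = torch.randn(4, 8, 16, device="cuda")
    student = head(x)                                     # (N, C, V), grads flow via head
    teacher = torch.randn(4, 8, 32, device="cuda")
    mask = torch.ones(4, 8, device="cuda")

    loss = reverse_kl_autograd(student, teacher, mask)
    loss.backward()                                       # populates head.weight.grad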
build/torch-cuda/reverse_kl/cute_reverse_kl.py
ADDED
|
@@ -0,0 +1,881 @@
"""Fused single-pass reverse-KL self-distillation loss kernel (CuteDSL, SM90).

Computes ``KL(student || teacher)`` over a ``(num_tokens, vocab)`` slab using
an online normalisation algorithm that reads each logit row exactly once.
The two log-softmax passes, the element-wise product ``p * (log_p - log_q)``,
the sum-reduction over ``V``, the mask application, and the final mean are
all fused into one launch:

1. **Mask-sum kernel** — reduces ``mask_flat`` and writes
   ``inv_n_valid = 1 / sum(mask)``. Undefined when ``sum(mask) == 0``
   (fully-masked batches produce inf/NaN in the final loss); callers
   must guard upstream if that case is reachable.
2. **Per-row main kernel** — one CTA per token; computes per-row KL and
   (when ``compute_backward=True``) writes the analytical gradient
   through the softmax Jacobian:

   ``grad_student_v = scale * p_v * (log_p_v - log_q_v - KL_per_row)``

   where ``scale = mask[r] * grad_output * inv_n_valid``. The two passes
   per row (online stats then gradient write-out) execute within a single
   CTA so the Pass-1 ``KL_per_row`` is broadcast through SMEM with no
   DSMEM/cluster overhead. The cross-row loss reduction piggybacks on the
   atomic-add + last-block-detect pattern used by ``cute_bnpo_loss``.

When ``compute_backward=False`` Pass 2, the broadcast SMEM, the
``grad_output`` read, and the ``* grad_output`` factor in the final scalar
are all dead-code-eliminated at trace time — the kernel becomes a pure
forward path identical in PTX to a hand-written fwd-only kernel.

A single public entry point :func:`create_compiled_reverse_kl`
JIT-compiles either the fwd-only or fused fwd+bwd path depending on its
``compute_backward`` flag. ``vocab`` is captured as a compile-time
constant (the tile + tail layout closes over it); the number of token
rows is symbolic and may vary across calls.
"""

from __future__ import annotations

import math
import operator
from typing import TYPE_CHECKING, Any

import cutlass
import torch
from cutlass import cute
from cutlass._mlir.dialects import llvm
from cutlass.base_dsl.typing import cast
from cutlass.cute.nvgpu import CacheEvictionPriority, CopyUniversalOp
from cutlass.cutlass_dsl import T, dsl_user_op
from cutlass.utils import SmemAllocator

if TYPE_CHECKING:
    from collections.abc import Callable


# ---------------------------------------------------------------------------
# Tunable constants — shared by fwd-only and fwd+bwd kernels
# ---------------------------------------------------------------------------
# Wider tile, more threads, and 128-bit loads so the two-pass (fwd+bwd) and
# single-pass (fwd-only) variants both stay bandwidth-bound from the same
# specialisation. ``fb_vec_size`` is recomputed at compile time from
# ``FB_LOAD_BITS`` and the dtype width:
#   FP16/BF16 → VEC=8, TILE_V=8192
#   FP32      → VEC=4, TILE_V=4096
FB_NUM_THREADS = 1024
FB_NUM_WARPS = FB_NUM_THREADS // 32  # 32
FB_LOAD_BITS = 128

_LOG2E = math.log2(math.e)  # 1.4426950408889634

_TORCH_TO_CUTLASS_DTYPE: dict[torch.dtype, Any] = {
    torch.float32: cutlass.Float32,
    torch.float16: cutlass.Float16,
    torch.bfloat16: cutlass.BFloat16,
}
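

# Illustrative dense-PyTorch sketch (added for exposition, unused by the
# kernel) of the math the fused launch reproduces; the shipped
# ``reverse_kl/_torch_ref.py`` is the authoritative reference. Assumes fp32
# accumulation and a mask with a nonzero sum.
def _reverse_kl_dense_sketch(
    student: torch.Tensor,  # (N, V) logits
    teacher: torch.Tensor,  # (N, V) logits
    mask: torch.Tensor,  # (N,) validity per token row
) -> torch.Tensor:
    log_p = torch.log_softmax(student.float(), dim=-1)
    log_q = torch.log_softmax(teacher.float(), dim=-1)
    kl_per_row = (log_p.exp() * (log_p - log_q)).sum(dim=-1)  # KL(p || q) per row
    return (kl_per_row * mask.float()).sum() / mask.float().sum()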


# ---------------------------------------------------------------------------
# Vendored atomic helpers (mirrors cute_bnpo_loss._atomic_*). Copied locally
# so the reverse_kl subpackage stays independent of bnpo_loss —
# the two ship together today, but kernel packages should be free-standing
# so a single one can be peeled off for separate publishing.
# ---------------------------------------------------------------------------


@dsl_user_op
def _atomic_add_f32_gmem(
    ptr_i64: Any,
    val: cutlass.Float32,
    *,
    loc: Any = None,
    ip: Any = None,
) -> None:
    llvm.inline_asm(
        T.f32(),
        [ptr_i64, cutlass.Float32(val).ir_value(loc=loc, ip=ip)],
        "atom.global.add.f32 $0, [$1], $2;",
        "=f,l,f",
        has_side_effects=True,
        is_align_stack=False,
        asm_dialect=llvm.AsmDialect.AD_ATT,
    )


@dsl_user_op
def _atomic_inc_u32_gmem(
    ptr_i64: Any,
    threshold: cutlass.Int32,
    *,
    loc: Any = None,
    ip: Any = None,
) -> cutlass.Int32:
    """``atom.global.inc.u32`` — returns old value; wraps to 0 at threshold."""
    return cutlass.Int32(
        llvm.inline_asm(
            T.i32(),
            [ptr_i64, cutlass.Int32(threshold).ir_value(loc=loc, ip=ip)],
            "atom.global.inc.u32 $0, [$1], $2;",
            "=r,l,r",
            has_side_effects=True,
            is_align_stack=False,
            asm_dialect=llvm.AsmDialect.AD_ATT,
        )
    )
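

# Semantics note (added for exposition, unused): PTX ``atom.inc.u32`` stores
# ``0 if old >= threshold else old + 1`` and returns ``old``, so passing
# ``total_blocks - 1`` makes the counter wrap to 0 exactly when the last
# block arrives, leaving the slot re-armed for the next launch. A host-side
# model of one step:
def _inc_u32_step(old: int, threshold: int) -> int:
    return 0 if old >= threshold else old + 1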


# ---------------------------------------------------------------------------
# Mask-sum kernel — reduces ``mask_flat`` (fp32, length N) to a per-block
# partial via warp + SMEM reduction, then atomically accumulates into
# ``valid_acc``; the last block writes ``rcp_approx(n_valid)`` into
# ``inv_n_valid`` and resets ``valid_acc`` to 0. Counter self-resets via
# ``atom.inc.u32`` wrap-around.
# ---------------------------------------------------------------------------


def _make_mask_sum_kernel(
    fb_num_threads: int,
    fb_num_warps: int,
) -> Callable[..., None]:
    """Return a ``@cute.kernel`` that reduces ``mask_flat`` and writes 1/sum.

    Grid: ``(num_blocks, 1, 1)`` where each block processes
    ``fb_num_threads`` mask elements (one element per thread, no
    vectorisation — sufficient for the small N relative to vocab work).
    A separate ``mask_counter`` (not shared with the main kernel's
    ``counter``) is required because both rely on ``atom.inc.u32``
    wrap-around for self-reset.
    """

    @cute.kernel
    def _mask_sum_kernel(
        mask_flat: cute.Tensor,
        inv_n_valid: cute.Tensor,
        valid_acc: cute.Tensor,
        mask_counter: cute.Tensor,
        num_rows: cutlass.Int32,
        total_blocks: cutlass.Int32,
    ) -> None:
        block_size = fb_num_warps * 32
        bidx = cute.arch.block_idx()[0]
        tidx = cute.arch.thread_idx()[0]

        global_idx = bidx * block_size + tidx
        local_val = cutlass.Float32(0.0)
        if global_idx < num_rows:
            local_val = mask_flat[global_idx]

        warp_val = cute.arch.warp_reduction(local_val, operator.add)

        smem = SmemAllocator()
        buf = smem.allocate_tensor(cutlass.Float32, cute.make_layout(fb_num_warps))

        lane_idx = cute.arch.lane_idx()
        warp_idx = cute.arch.warp_idx()

        if lane_idx == 0:
            buf[warp_idx] = warp_val
        cute.arch.barrier()

        if warp_idx == 0:
            v = cutlass.Float32(0.0)
            if lane_idx < fb_num_warps:
                v = buf[lane_idx]
            block_sum = cute.arch.warp_reduction(v, operator.add, threads_in_group=fb_num_warps)

            if lane_idx == 0:
                valid_ptr = valid_acc.iterator.toint().ir_value()  # ty: ignore[unresolved-attribute]
                counter_ptr = mask_counter.iterator.toint().ir_value()  # ty: ignore[unresolved-attribute]
                _atomic_add_f32_gmem(valid_ptr, block_sum)
                cute.arch.fence_acq_rel_gpu()
                old = _atomic_inc_u32_gmem(counter_ptr, total_blocks - 1)
                if old == total_blocks - 1:
                    n_valid = valid_acc[0]
                    inv_n_valid[0] = cute.arch.rcp_approx(n_valid)
                    valid_acc[0] = cutlass.Float32(0.0)

    return _mask_sum_kernel
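

# Host-side equivalent of the stage above (added for exposition, unused):
# the fused device version exists to avoid a separate dispatch and host
# sync. Note the kernel uses ``rcp_approx``, so the device value is a fast
# approximate reciprocal rather than a correctly rounded one.
def _mask_sum_reference(mask_flat: torch.Tensor) -> torch.Tensor:
    return torch.reciprocal(mask_flat.float().sum())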


# ---------------------------------------------------------------------------
# Main per-row reverse-KL kernel (fused fwd[+bwd]).
# ---------------------------------------------------------------------------


def _make_reverse_kl_kernel(
    compute_backward: bool,
    fb_num_threads: int,
    fb_num_warps: int,
    fb_vec_size: int,
    fb_tile_v: int,
) -> Callable[..., None]:
    """Return a ``@cute.kernel`` that fuses fwd loss + (optional) bwd grad.

    One CTA processes one row. Pass 1 computes the online softmax stats
    ``(max_s, D_s, W = sum exp(s - max_s) (s - t), max_t, D_t)`` and the
    per-row ``KL = W/D_s + (max_t - max_s) + log(D_t) - log(D_s)``.

    When ``compute_backward=True`` the block-reduced stats are broadcast
    through SMEM and Pass 2 re-reads student/teacher to write the
    analytical gradient
    ``grad_v = scale * p_v * (log_p_v - log_q_v - KL)`` where
    ``p_v = exp(s_v - max_s) / D_s``,
    ``log_p_v - log_q_v = (s_v - t_v) + (max_t - max_s) + log(D_t) - log(D_s)``,
    and ``scale = mask[r] * grad_output * inv_n_valid``.

    When ``compute_backward=False`` Pass 2, the bcast SMEM, the
    ``grad_output`` read, and the ``* grad_output`` factor in the final
    scalar are eliminated at trace time.

    Cross-row loss accumulation rides on an atomic ``loss_acc`` plus a
    wrap-around ``counter`` for last-block detection. The kernel always
    reads ``mask_flat[bidx]`` (callers pass an all-ones mask in the
    unmasked case) and, in the bwd path, short-circuits Pass 2 — writes
    zeros and skips loss accumulation — for rows whose ``scale`` is 0.
    """

    @cute.kernel
    def _kernel(
        student: cute.Tensor,
        teacher: cute.Tensor,
        mask_flat: cute.Tensor,
        grad_output: cute.Tensor,
        inv_n_valid: cute.Tensor,
        grad_student: cute.Tensor,
        loss_acc: cute.Tensor,
        counter: cute.Tensor,
        output: cute.Tensor,
        num_full_tiles: int,
        tail_len: int,
        total_rows: cutlass.Int32,
        tiled_copy_p1: cute.TiledCopy,
        copy_atom_p1: cute.CopyAtom,
        tiled_copy_p2: cute.TiledCopy,
        copy_atom_p2: cute.CopyAtom,
        copy_atom_store: cute.CopyAtom,
    ) -> None:
        tidx = cute.arch.thread_idx()[0]
        bidx = cute.arch.block_idx()[0]

        # SMEM:
        #   red_buf: NUM_WARPS x 5 — per-warp partials for cross-warp reduce
        #   bcast:   6 floats — broadcast (kl, max_s, log_d_s, max_t,
        #            log_d_t, scale) to all threads in pass 2
        # The bcast buffer is only needed when ``compute_backward`` is True.
        smem = SmemAllocator()
        red_buf = smem.allocate_tensor(cutlass.Float32, cute.make_layout((fb_num_warps, 5)))
        if cutlass.const_expr(compute_backward):
            bcast = smem.allocate_tensor(cutlass.Float32, cute.make_layout(6))

        log2e = cutlass.Float32(_LOG2E)
        neg_inf = cutlass.Float32(-cutlass.Float32.inf)

        # Per-row scale = mask[bidx] * grad_output * inv_n_valid. The
        # ``grad_output`` read is dead-code-eliminated for the fwd-only
        # path so no spurious GMEM load is emitted.
        mask_val = cast(mask_flat[bidx], cutlass.Float32)
        inv_n = cast(inv_n_valid[0], cutlass.Float32)
        if cutlass.const_expr(compute_backward):
            grad_val = cast(grad_output[0], cutlass.Float32)
            scale_val = mask_val * grad_val * inv_n

        s_row = cute.slice_(student, (cutlass.Int64(bidx), None))
        t_row = cute.slice_(teacher, (cutlass.Int64(bidx), None))
        if cutlass.const_expr(compute_backward):
            g_row = cute.slice_(grad_student, (cutlass.Int64(bidx), None))
            out_dtype = grad_student.element_type

        max_s = neg_inf
        d_s = cutlass.Float32(0.0)
        w_acc = cutlass.Float32(0.0)
        max_t = neg_inf
        d_t = cutlass.Float32(0.0)

        thr_copy_p1 = tiled_copy_p1.get_slice(tidx)

        # ---- Pass 1: online stats + W ----
        for k in cutlass.range(num_full_tiles, unroll=2):
            s_slab = cute.local_tile(s_row, (fb_tile_v,), (k,))
            t_slab = cute.local_tile(t_row, (fb_tile_v,), (k,))

            src_s = thr_copy_p1.partition_S(s_slab)
            frag_s = cute.make_fragment_like(src_s)
            cute.copy(copy_atom_p1, src_s, frag_s)

            src_t = thr_copy_p1.partition_S(t_slab)
            frag_t = cute.make_fragment_like(src_t)
            cute.copy(copy_atom_p1, src_t, frag_t)

            s_f32 = frag_s.load().to(cutlass.Float32)
            t_f32 = frag_t.load().to(cutlass.Float32)

            tile_max_s = s_f32.reduce(cute.ReductionOp.MAX, neg_inf, reduction_profile=0)
            exp_s = cute.math.exp2((s_f32 - tile_max_s) * log2e, fastmath=True)
            tile_d_s = exp_s.reduce(  # ty: ignore[unresolved-attribute]
                cute.ReductionOp.ADD, cutlass.Float32(0.0), reduction_profile=0
            )
            diff = s_f32 - t_f32
            tile_w = (exp_s * diff).reduce(
                cute.ReductionOp.ADD, cutlass.Float32(0.0), reduction_profile=0
            )

            new_max_s = cute.arch.fmax(max_s, tile_max_s)
            corr_s = cute.math.exp2((max_s - new_max_s) * log2e, fastmath=True)
            tile_corr_s = cute.math.exp2((tile_max_s - new_max_s) * log2e, fastmath=True)
            d_s = d_s * corr_s + tile_d_s * tile_corr_s
            w_acc = w_acc * corr_s + tile_w * tile_corr_s
            max_s = new_max_s

            tile_max_t = t_f32.reduce(cute.ReductionOp.MAX, neg_inf, reduction_profile=0)
            exp_t = cute.math.exp2((t_f32 - tile_max_t) * log2e, fastmath=True)
            tile_d_t = exp_t.reduce(  # ty: ignore[unresolved-attribute]
                cute.ReductionOp.ADD, cutlass.Float32(0.0), reduction_profile=0
            )

            new_max_t = cute.arch.fmax(max_t, tile_max_t)
            corr_t = cute.math.exp2((max_t - new_max_t) * log2e, fastmath=True)
            tile_corr_t = cute.math.exp2((tile_max_t - new_max_t) * log2e, fastmath=True)
            d_t = d_t * corr_t + tile_d_t * tile_corr_t
            max_t = new_max_t

        if tail_len > 0:
            tail_base = num_full_tiles * fb_tile_v

            thr_max_s = neg_inf
            thr_max_t = neg_inf
            for i in cutlass.range(fb_vec_size):
                e = tidx + i * fb_num_threads
                if e < tail_len:
                    s_val = cast(s_row[tail_base + e], cutlass.Float32)
                    t_val = cast(t_row[tail_base + e], cutlass.Float32)
                    thr_max_s = cute.arch.fmax(thr_max_s, s_val)
                    thr_max_t = cute.arch.fmax(thr_max_t, t_val)

            thr_d_s = cutlass.Float32(0.0)
            thr_w = cutlass.Float32(0.0)
            thr_d_t = cutlass.Float32(0.0)
            for i in cutlass.range(fb_vec_size):
                e = tidx + i * fb_num_threads
                if e < tail_len:
                    s_val = cast(s_row[tail_base + e], cutlass.Float32)
                    t_val = cast(t_row[tail_base + e], cutlass.Float32)
                    exp_sv = cute.math.exp2((s_val - thr_max_s) * log2e, fastmath=True)
                    thr_d_s = thr_d_s + exp_sv
                    thr_w = thr_w + exp_sv * (s_val - t_val)
                    exp_tv = cute.math.exp2((t_val - thr_max_t) * log2e, fastmath=True)
                    thr_d_t = thr_d_t + exp_tv

            new_max_s = cute.arch.fmax(max_s, thr_max_s)
            corr_s = cute.math.exp2((max_s - new_max_s) * log2e, fastmath=True)
            tail_corr_s = cute.math.exp2((thr_max_s - new_max_s) * log2e, fastmath=True)
            d_s = d_s * corr_s + thr_d_s * tail_corr_s
            w_acc = w_acc * corr_s + thr_w * tail_corr_s
            max_s = new_max_s

            new_max_t = cute.arch.fmax(max_t, thr_max_t)
            corr_t = cute.math.exp2((max_t - new_max_t) * log2e, fastmath=True)
            tail_corr_t = cute.math.exp2((thr_max_t - new_max_t) * log2e, fastmath=True)
            d_t = d_t * corr_t + thr_d_t * tail_corr_t
            max_t = new_max_t

        # ---- Cross-warp reduction for student/teacher stats ----
        warp_max_s = cute.arch.warp_reduction(max_s, cute.arch.fmax)
        corr_w_s = cute.math.exp2((max_s - warp_max_s) * log2e, fastmath=True)
        d_s = d_s * corr_w_s
        w_acc = w_acc * corr_w_s
        warp_d_s = cute.arch.warp_reduction(d_s, operator.add)
        warp_w = cute.arch.warp_reduction(w_acc, operator.add)

        warp_max_t = cute.arch.warp_reduction(max_t, cute.arch.fmax)
        corr_w_t = cute.math.exp2((max_t - warp_max_t) * log2e, fastmath=True)
        d_t = d_t * corr_w_t
        warp_d_t = cute.arch.warp_reduction(d_t, operator.add)

        lane_idx = cute.arch.lane_idx()
        warp_idx = cute.arch.warp_idx()

        if lane_idx == 0:
            red_buf[warp_idx, 0] = warp_max_s
            red_buf[warp_idx, 1] = warp_d_s
            red_buf[warp_idx, 2] = warp_w
            red_buf[warp_idx, 3] = warp_max_t
            red_buf[warp_idx, 4] = warp_d_t
        cute.arch.sync_threads()

        # Warp 0 finishes the cross-warp reduction.
        if warp_idx == 0:
            r_max_s = neg_inf
            r_d_s = cutlass.Float32(0.0)
            r_w = cutlass.Float32(0.0)
            r_max_t = neg_inf
            r_d_t = cutlass.Float32(0.0)

            if lane_idx < fb_num_warps:
                r_max_s = red_buf[lane_idx, 0]
                r_d_s = red_buf[lane_idx, 1]
                r_w = red_buf[lane_idx, 2]
                r_max_t = red_buf[lane_idx, 3]
                r_d_t = red_buf[lane_idx, 4]

            final_max_s = cute.arch.warp_reduction(
                r_max_s, cute.arch.fmax, threads_in_group=fb_num_warps
            )
            fcorr_s = cute.math.exp2((r_max_s - final_max_s) * log2e, fastmath=True)
            r_d_s = r_d_s * fcorr_s
            r_w = r_w * fcorr_s
            final_d_s = cute.arch.warp_reduction(r_d_s, operator.add, threads_in_group=fb_num_warps)
            final_w = cute.arch.warp_reduction(r_w, operator.add, threads_in_group=fb_num_warps)

            final_max_t = cute.arch.warp_reduction(
                r_max_t, cute.arch.fmax, threads_in_group=fb_num_warps
            )
            fcorr_t = cute.math.exp2((r_max_t - final_max_t) * log2e, fastmath=True)
            r_d_t = r_d_t * fcorr_t
            final_d_t = cute.arch.warp_reduction(r_d_t, operator.add, threads_in_group=fb_num_warps)

            if lane_idx == 0:
                rcp_d_s = cute.arch.rcp_approx(final_d_s)
                log_d_s = cute.math.log(final_d_s)
                log_d_t = cute.math.log(final_d_t)
                kl = final_w * rcp_d_s + log_d_t + final_max_t - log_d_s - final_max_s

                # Cross-row loss accumulation: atomic-add (kl * mask) into
                # ``loss_acc``; the last block scales by ``inv_n_valid``
                # (and ``grad_output`` on the bwd path) and writes the
                # scalar output.
                loss_ptr = loss_acc.iterator.toint().ir_value()  # ty: ignore[unresolved-attribute]
                counter_ptr = counter.iterator.toint().ir_value()  # ty: ignore[unresolved-attribute]
                contribution = kl * mask_val
                _atomic_add_f32_gmem(loss_ptr, contribution)
                cute.arch.fence_acq_rel_gpu()
                old = _atomic_inc_u32_gmem(counter_ptr, total_rows - 1)
                if old == total_rows - 1:
                    if cutlass.const_expr(compute_backward):
                        final_loss = loss_acc[0] * inv_n * grad_val
                    else:
                        final_loss = loss_acc[0] * inv_n
                    output[0] = cast(final_loss, output.element_type)  # ty: ignore[invalid-argument-type]
                    loss_acc[0] = cutlass.Float32(0.0)

                if cutlass.const_expr(compute_backward):
                    bcast[0] = kl
                    bcast[1] = final_max_s
                    bcast[2] = log_d_s
                    bcast[3] = final_max_t
                    bcast[4] = log_d_t
                    bcast[5] = scale_val

        # ---- Pass 2: re-read logits and write gradient ----
        # Entire pass is dead-code-eliminated when ``compute_backward`` is
        # False — the kernel becomes a single-pass forward identical in
        # PTX to a hand-written fwd-only kernel.
        if cutlass.const_expr(compute_backward):
            cute.arch.sync_threads()

            kl_b = bcast[0]
            max_s_b = bcast[1]
            log_d_s_b = bcast[2]
            max_t_b = bcast[3]
            log_d_t_b = bcast[4]
            scale_b = bcast[5]

            log_offset = max_t_b - max_s_b + log_d_t_b - log_d_s_b

            # Skip Pass 2 entirely for rows with scale=0 (masked-out rows).
            if scale_b != cutlass.Float32(0.0):
                thr_copy_p2 = tiled_copy_p2.get_slice(tidx)

                for k in cutlass.range(num_full_tiles, unroll=2):
                    s_slab = cute.local_tile(s_row, (fb_tile_v,), (k,))
                    t_slab = cute.local_tile(t_row, (fb_tile_v,), (k,))
                    g_slab = cute.local_tile(g_row, (fb_tile_v,), (k,))

                    src_s = thr_copy_p2.partition_S(s_slab)
                    frag_s = cute.make_fragment_like(src_s)
                    cute.copy(copy_atom_p2, src_s, frag_s)

                    src_t = thr_copy_p2.partition_S(t_slab)
                    frag_t = cute.make_fragment_like(src_t)
                    cute.copy(copy_atom_p2, src_t, frag_t)

                    s_f32 = frag_s.load().to(cutlass.Float32)
                    t_f32 = frag_t.load().to(cutlass.Float32)

                    # p_v = exp((s_v - max_s) * log2e) / D_s = exp(s_v - logZ_s)
                    p_v = cute.math.exp2((s_f32 - max_s_b - log_d_s_b) * log2e, fastmath=True)
                    log_diff = s_f32 - t_f32 + log_offset
                    grad = scale_b * p_v * (log_diff - kl_b)

                    dst_g = thr_copy_p2.partition_D(g_slab)
                    out_frag = cute.make_fragment_like(dst_g)
                    out_frag.store(grad.to(out_dtype))
                    cute.copy(copy_atom_store, out_frag, dst_g)

                if tail_len > 0:
                    tail_base = num_full_tiles * fb_tile_v
                    for i in cutlass.range(fb_vec_size):
                        e = tidx + i * fb_num_threads
                        if e < tail_len:
                            s_val = cast(s_row[tail_base + e], cutlass.Float32)
                            t_val = cast(t_row[tail_base + e], cutlass.Float32)
                            p_v = cute.math.exp2(
                                (s_val - max_s_b - log_d_s_b) * log2e, fastmath=True
                            )
                            log_diff = s_val - t_val + log_offset
                            grad = scale_b * p_v * (log_diff - kl_b)
                            g_row[tail_base + e] = cast(grad, out_dtype)  # ty: ignore[invalid-argument-type]
            else:
                # Masked row: write zeros so callers see a clean grad slab.
                zero_elem = cast(cutlass.Float32(0.0), out_dtype)  # ty: ignore[invalid-argument-type]
                thr_copy_z = tiled_copy_p2.get_slice(tidx)
                for k in cutlass.range(num_full_tiles, unroll=2):
                    g_slab = cute.local_tile(g_row, (fb_tile_v,), (k,))
                    dst_g = thr_copy_z.partition_D(g_slab)
                    zero_frag = cute.make_fragment_like(dst_g)
                    zero_frag.fill(zero_elem)
                    cute.copy(copy_atom_store, zero_frag, dst_g)

                if tail_len > 0:
                    tail_base = num_full_tiles * fb_tile_v
                    for i in cutlass.range(fb_vec_size):
                        e = tidx + i * fb_num_threads
                        if e < tail_len:
                            g_row[tail_base + e] = zero_elem

    return _kernel
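

# Scalar model of the Pass-1 merge rule (added for exposition, unused): both
# partials are rescaled to the shared running max so that
# ``D = sum exp(s - max)`` and ``W = sum exp(s - max) * (s - t)`` stay
# consistent under arbitrary tiling, which is what makes the single-read
# online pass equivalent to two full log-softmax passes.
def _merge_online_stats(
    max_a: float, d_a: float, w_a: float, max_b: float, d_b: float, w_b: float
) -> tuple[float, float, float]:
    m = max(max_a, max_b)
    ca, cb = math.exp(max_a - m), math.exp(max_b - m)
    return m, d_a * ca + d_b * cb, w_a * ca + w_b * cb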


def _make_fwd_bwd_launcher(
    compute_backward: bool,
    fb_num_threads: int,
    fb_num_warps: int,
    fb_vec_size: int,
    fb_tile_v: int,
) -> Callable[..., None]:
    """Return a ``@cute.jit`` launcher that runs mask-sum + main kernel.

    When ``compute_backward=False`` Pass 2 + bcast SMEM are dead-code
    eliminated inside the kernel and Pass 1 loads default to NO_ALLOCATE
    (no benefit from L2 pinning since there is no re-read).
    """
    main_kernel = _make_reverse_kl_kernel(
        compute_backward, fb_num_threads, fb_num_warps, fb_vec_size, fb_tile_v
    )
    mask_sum_kernel = _make_mask_sum_kernel(fb_num_threads, fb_num_warps)

    @cute.jit
    def _launch(
        student_2d: cute.Tensor,
        teacher_2d: cute.Tensor,
        mask_flat: cute.Tensor,
        grad_output: cute.Tensor,
        inv_n_valid: cute.Tensor,
        grad_student_2d: cute.Tensor,
        loss_acc: cute.Tensor,
        valid_acc: cute.Tensor,
        counter: cute.Tensor,
        mask_counter: cute.Tensor,
        output: cute.Tensor,
        num_full_tiles: cutlass.Int32,
        tail_len: cutlass.Int32,
        mask_sum_blocks: cutlass.Int32,
    ) -> None:
        num_rows = student_2d.shape[0]  # ty: ignore[not-subscriptable]
        dtype = student_2d.element_type
        out_dtype = grad_student_2d.element_type

        # Pass 1 loads: EVICT_LAST when we need the data pinned in L2 for
        # Pass 2 re-reads; NO_ALLOCATE for fwd-only since the data is
        # never re-read.
        if cutlass.const_expr(compute_backward):
            p1_evict = CacheEvictionPriority.EVICT_LAST
        else:
            p1_evict = CacheEvictionPriority.NO_ALLOCATE
        copy_atom_p1 = cute.make_copy_atom(
            CopyUniversalOp(),
            dtype,
            num_bits_per_copy=fb_vec_size * dtype.width,  # ty: ignore[unresolved-attribute]
            l1c_evict_priority=p1_evict,
        )
        thr_layout = cute.make_layout((fb_num_threads,))
        val_layout = cute.make_layout((fb_vec_size,))
        tiler_v_p1, layout_tv_p1 = cute.make_layout_tv(thr_layout, val_layout)
        tiled_copy_p1 = cute.make_tiled_copy(copy_atom_p1, layout_tv_p1, tiler_v_p1)

        # Pass 2 loads: NO_ALLOCATE — streaming, never re-read.
        copy_atom_p2 = cute.make_copy_atom(
            CopyUniversalOp(),
            dtype,
            num_bits_per_copy=fb_vec_size * dtype.width,  # ty: ignore[unresolved-attribute]
            l1c_evict_priority=CacheEvictionPriority.NO_ALLOCATE,
        )
        tiler_v_p2, layout_tv_p2 = cute.make_layout_tv(thr_layout, val_layout)
        tiled_copy_p2 = cute.make_tiled_copy(copy_atom_p2, layout_tv_p2, tiler_v_p2)

        # Stores: NO_ALLOCATE — write-once, never re-read.
        copy_atom_store = cute.make_copy_atom(
            CopyUniversalOp(),
            out_dtype,
            num_bits_per_copy=fb_vec_size * out_dtype.width,  # ty: ignore[unresolved-attribute]
            l1c_evict_priority=CacheEvictionPriority.NO_ALLOCATE,
        )

        mask_sum_kernel(  # ty: ignore[unresolved-attribute]
            mask_flat,
            inv_n_valid,
            valid_acc,
            mask_counter,
            num_rows,
            mask_sum_blocks,
        ).launch(
            grid=(mask_sum_blocks, 1, 1),
            block=(fb_num_threads, 1, 1),
        )

        main_kernel(  # ty: ignore[unresolved-attribute]
            student_2d,
            teacher_2d,
            mask_flat,
            grad_output,
            inv_n_valid,
            grad_student_2d,
            loss_acc,
            counter,
            output,
            num_full_tiles,
            tail_len,
            num_rows,
            tiled_copy_p1,
            copy_atom_p1,
            tiled_copy_p2,
            copy_atom_p2,
            copy_atom_store,
        ).launch(
            grid=(num_rows, 1, 1),
            block=(fb_num_threads, 1, 1),
        )

    return _launch
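

# Autograd cross-check for the closed-form gradient the kernel writes (added
# for exposition, unused): differentiating the dense masked-mean reverse KL
# with respect to the student logits should match
# ``mask[r] / n_valid * p_v * (log_p_v - log_q_v - KL_r)`` up to fp tolerance.
def _grad_reference(
    student: torch.Tensor, teacher: torch.Tensor, mask: torch.Tensor
) -> torch.Tensor:
    s = student.detach().float().requires_grad_(True)
    log_p = torch.log_softmax(s, dim=-1)
    log_q = torch.log_softmax(teacher.float(), dim=-1)
    kl = (log_p.exp() * (log_p - log_q)).sum(dim=-1)
    loss = (kl * mask.float()).sum() / mask.float().sum()
    (grad,) = torch.autograd.grad(loss, s)
    return grad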


# ---------------------------------------------------------------------------
# Public factory
# ---------------------------------------------------------------------------


def create_compiled_reverse_kl(
    policy_dtype: torch.dtype,
    vocab: int,
    compute_backward: bool = False,
) -> Callable[..., torch.Tensor | tuple[torch.Tensor, torch.Tensor]]:
    """JIT-compile the fused reverse-KL kernel (forward or fused fwd+bwd).

    Bundles a mask-sum kernel (which computes ``inv_n_valid``) and the
    main per-row kernel into a single ``@cute.jit`` launch so each call
    is one tvm-ffi dispatch with no host syncs. The unmasked case is
    handled by passing an all-ones ``mask_flat`` (callers are responsible).

    When ``compute_backward=False`` the runner takes 10 args and returns
    the scalar loss. When ``compute_backward=True`` it takes 11 args
    (with a real ``grad_student_2d`` slab) and returns
    ``(loss_scalar, grad_student_2d)``.

    Args:
        policy_dtype: Element dtype of student/teacher logits and the
            output gradient slab. Both tensors must share this dtype.
        vocab: ``V`` dimension of the ``(num_tokens, V)`` slab. Captured
            as a compile-time constant; the number of token rows stays
            symbolic across calls.
        compute_backward: When ``True`` the kernel additionally writes
            the analytical gradient through the softmax Jacobian into
            the caller's ``grad_student_2d`` slab.
    """
    if policy_dtype not in _TORCH_TO_CUTLASS_DTYPE:
        raise ValueError(f"Unsupported dtype for self-distillation loss: {policy_dtype}")

    student_dtype = _TORCH_TO_CUTLASS_DTYPE[policy_dtype]
    fb_vec_size = FB_LOAD_BITS // student_dtype.width
    fb_tile_v = FB_NUM_THREADS * fb_vec_size

    num_full_tiles = vocab // fb_tile_v
    tail_len = vocab % fb_tile_v
    block_size = FB_NUM_WARPS * 32

    num_tokens_sym = cute.sym_int()

    fake_s = cute.runtime.make_fake_compact_tensor(
        student_dtype,
        (num_tokens_sym, vocab),
        stride_order=(1, 0),
        assumed_align=16,
    )
    fake_t = cute.runtime.make_fake_compact_tensor(
        student_dtype,
        (num_tokens_sym, vocab),
        stride_order=(1, 0),
        assumed_align=16,
    )
    fake_mask = cute.runtime.make_fake_compact_tensor(
        cutlass.Float32,
        (num_tokens_sym,),
        assumed_align=16,
    )
    fake_grad_out = cute.runtime.make_fake_compact_tensor(
        cutlass.Float32,
        (1,),
        assumed_align=16,
    )
    fake_inv_n = cute.runtime.make_fake_compact_tensor(
        cutlass.Float32,
        (1,),
        assumed_align=16,
    )
    # Full ``(N, V)`` slab when computing the backward; 1-column dummy
    # slab for fwd-only since Pass 2 is dead-code-eliminated and only
    # the tensor signature matters.
    grad_v_dim = vocab if compute_backward else 1
    fake_grad_student = cute.runtime.make_fake_compact_tensor(
        student_dtype,
        (num_tokens_sym, grad_v_dim),
        stride_order=(1, 0),
        assumed_align=16,
    )
    fake_loss = cute.runtime.make_fake_compact_tensor(
        cutlass.Float32,
        (1,),
        assumed_align=16,
    )
    fake_valid = cute.runtime.make_fake_compact_tensor(
        cutlass.Float32,
        (1,),
        assumed_align=16,
    )
    fake_counter = cute.runtime.make_fake_compact_tensor(
        cutlass.Int32,
        (1,),
        assumed_align=16,
    )
    fake_mask_counter = cute.runtime.make_fake_compact_tensor(
        cutlass.Int32,
        (1,),
        assumed_align=16,
    )
    fake_output = cute.runtime.make_fake_compact_tensor(
        student_dtype,
        (1,),
        assumed_align=16,
    )

    launcher = _make_fwd_bwd_launcher(
        compute_backward=compute_backward,
        fb_num_threads=FB_NUM_THREADS,
        fb_num_warps=FB_NUM_WARPS,
        fb_vec_size=fb_vec_size,
        fb_tile_v=fb_tile_v,
    )
    compiled = cute.compile(
        launcher,
        fake_s,
        fake_t,
        fake_mask,
        fake_grad_out,
        fake_inv_n,
        fake_grad_student,
        fake_loss,
        fake_valid,
        fake_counter,
        fake_mask_counter,
        fake_output,
        cutlass.Int32(num_full_tiles),
        cutlass.Int32(tail_len),
        cutlass.Int32(1),
        options="--enable-tvm-ffi",
    )

    nft_const = cutlass.Int32(num_full_tiles)
    tl_const = cutlass.Int32(tail_len)
    grad_v_dim_runtime = vocab if compute_backward else 1

    # ---- Closure-scoped scratch ----
    # ``grad_output`` is the upstream gradient feeding Pass 2 (1.0 for
    # backward, irrelevant for forward — Pass 2 is dead-code-eliminated).
    # Constant across calls; allocated lazily on first call when the
    # device is known, then reused.
    #
    # ``scratch_z`` coalesces the 4 atomic-accumulator scalars (counter,
    # mask_counter, loss_acc.fp32, valid_acc.fp32) into one int32 slab
    # with stride-4 (16-byte) slices so each slot is individually 16-byte
    # aligned (``assumed_align=16`` at compile time). Bit-pattern of int32 0
    # equals fp32 0.0, so a single ``zeros`` factory legitimately
    # initialises both the int32 counters and the fp32 accumulators.
    # Both kernels' last blocks self-reset their fp32 accumulators in
    # their epilogues, and counters self-reset via ``atom.inc.u32``
    # wrap-around — so the up-front ``torch.zeros`` only matters for the
    # first call.
    _scratch: list[tuple[torch.Tensor, torch.Tensor] | None] = [None]

    def _ensure_scratch(
        device: torch.device,
    ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
        s = _scratch[0]
        if s is None or s[0].device != device:
            slab = torch.zeros(16, dtype=torch.int32, device=device)
            if compute_backward:
                grad_output = torch.ones(1, dtype=torch.float32, device=device)
            else:
                grad_output = torch.empty(1, dtype=torch.float32, device=device)
            _scratch[0] = (slab, grad_output)
            s = _scratch[0]
        slab, grad_output = s
        return (
            slab[0:1],  # counter (int32)
            slab[4:5],  # mask_counter (int32)
            slab[8:9].view(torch.float32),  # loss_acc (fp32)
            slab[12:13].view(torch.float32),  # valid_acc (fp32)
            grad_output,
        )

    def _run(
        student_2d: torch.Tensor,
        teacher_2d: torch.Tensor,
        mask_flat: torch.Tensor,
    ) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
        num_rows = student_2d.shape[0]
        device = student_2d.device
        dtype = student_2d.dtype

        counter_r, mask_counter_r, loss_acc_r, valid_acc_r, grad_output_r = _ensure_scratch(device)

        # Per-call write-only buffers — ``empty`` is enough.
        # ``inv_n_valid`` is populated by the bundled mask-sum kernel
        # before the main kernel reads it; the runner never reads it.
        inv_n_valid_r = torch.empty(1, dtype=torch.float32, device=device)
        output_r = torch.empty(1, dtype=dtype, device=device)
        # ``grad_v_dim`` is ``vocab`` for backward and ``1`` for forward
        # (1-column dummy slab — Pass 2 is dead-code-eliminated, only the
        # tensor-parameter signature matters).
        grad_buffer = torch.empty(num_rows, grad_v_dim_runtime, dtype=dtype, device=device)

        mask_sum_blocks = (num_rows + block_size - 1) // block_size
        compiled(
            student_2d,
            teacher_2d,
            mask_flat,
            grad_output_r,
            inv_n_valid_r,
            grad_buffer,
            loss_acc_r,
            valid_acc_r,
            counter_r,
            mask_counter_r,
            output_r,
            nft_const,
            tl_const,
            cutlass.Int32(mask_sum_blocks),
        )
        out_view = output_r.view(())
        if compute_backward:
            return out_view, grad_buffer
        return out_view

    return _run
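A minimal sketch of driving the compiled runner directly (the shapes, vocab size, and the all-ones unmasked convention below are illustrative assumptions drawn from the docstrings above, not part of the upload):

    import torch
    from geometric_ai_kernels.reverse_kl.cute_reverse_kl import create_compiled_reverse_kl

    runner = create_compiled_reverse_kl(torch.bfloat16, vocab=32768, compute_backward=True)

    student_2d = torch.randn(512, 32768, device="cuda", dtype=torch.bfloat16)
    teacher_2d = torch.randn(512, 32768, device="cuda", dtype=torch.bfloat16)
    mask_flat = torch.ones(512, device="cuda", dtype=torch.float32)  # unmasked case

    # One fused tvm-ffi dispatch: scalar loss plus the (N, V) gradient slab.
    loss, grad_student_2d = runner(student_2d, teacher_2d, mask_flat)

Since ``vocab`` is a compile-time constant, a new runner must be created per vocabulary size, while the number of token rows may vary freely across calls.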