Pramodith's picture
Build uploaded using `kernels`.
5285eac verified
"""Geometric-AI CuteDSL kernels for RL / distillation training.
Public surface:
* ``bnpo_loss`` / ``bnpo_loss_autograd`` / ``bnpo_loss_fwd`` —
fused fwd+bwd BNPO loss with three entry points (direct
``(loss, grad)``, autograd-wrapped, forward-only).
* ``grpo_loss`` / ``grpo_loss_autograd`` / ``grpo_loss_fwd`` —
fused fwd+bwd GRPO loss (TRL's per-response normalization
variant). Same three-entry-point shape as BNPO. Requires
``completions_mask``.
* ``reverse_kl`` / ``reverse_kl_autograd`` /
``reverse_kl_fwd`` — fused fwd+bwd reverse-KL
self-distillation loss with the same three-entry-point shape.
HF Kernels integration: :mod:`geometric_ai_kernels.layers` exposes
``nn.Module`` adapters per kernel (``bnpoLoss`` / ``bnpoLossInference``,
``grpoLoss`` / ``grpoLossInference``, ``ReverseKL`` /
``ReverseKLInference``) for use with the ``kernels``
library's ``kernelize()`` flow.
"""
from __future__ import annotations
import torch._dynamo
from .bnpo_loss import bnpo_loss, bnpo_loss_autograd, bnpo_loss_fwd
from .grpo_loss import grpo_loss, grpo_loss_autograd, grpo_loss_fwd
from .layers import (
ReverseKL,
ReverseKLInference,
bnpoLoss,
bnpoLossInference,
grpoLoss,
grpoLossInference,
)
from .reverse_kl import (
reverse_kl,
reverse_kl_autograd,
reverse_kl_fwd,
)
def _enable_dynamo_autograd_tracing() -> None:
    """Switch on Dynamo's ``trace_autograd_ops`` flag if this torch has it.

    With the flag on, ``torch.compile(fullgraph=True)`` can trace through
    ``torch.autograd.grad`` call sites. Without it, Dynamo graph-breaks
    there even when AOTAutograd has already derived the joint fwd+bwd
    graph. Running this at package import means every consumer (benches,
    user training loops, ``kernelize`` flows) gets the setting without
    opting in.

    ``trace_autograd_ops`` only exists in torch 2.10+, and the Nix-pinned
    build environment may still be on an older torch (2.9 at time of
    writing). The underlying ``Config.__setattr__`` raises on unknown keys,
    so the write is skipped when the attribute is absent.
    """
    dynamo_cfg = torch._dynamo.config
    if not hasattr(dynamo_cfg, "trace_autograd_ops"):
        # Older torch: flag unknown, assigning it would raise.
        return
    dynamo_cfg.trace_autograd_ops = True  # ty: ignore[invalid-assignment]


_enable_dynamo_autograd_tracing()
# Public API of the package: the 15 names imported above. That is three
# entry points (direct, ``_autograd``, ``_fwd``) for each of the BNPO, GRPO
# and reverse-KL losses, plus the ``nn.Module`` adapters from ``.layers``.
# Entries are kept in plain ``sorted()`` order (uppercase before lowercase,
# "L" before "_"); keep that order when adding names.
__all__ = [
    "ReverseKL",
    "ReverseKLInference",
    "bnpoLoss",
    "bnpoLossInference",
    "bnpo_loss",
    "bnpo_loss_autograd",
    "bnpo_loss_fwd",
    "grpoLoss",
    "grpoLossInference",
    "grpo_loss",
    "grpo_loss_autograd",
    "grpo_loss_fwd",
    "reverse_kl",
    "reverse_kl_autograd",
    "reverse_kl_fwd",
]