"""Geometric-AI CuteDSL kernels for RL / distillation training. Public surface: * ``bnpo_loss`` / ``bnpo_loss_autograd`` / ``bnpo_loss_fwd`` — fused fwd+bwd BNPO loss with three entry points (direct ``(loss, grad)``, autograd-wrapped, forward-only). * ``grpo_loss`` / ``grpo_loss_autograd`` / ``grpo_loss_fwd`` — fused fwd+bwd GRPO loss (TRL's per-response normalization variant). Same three-entry-point shape as BNPO. Requires ``completions_mask``. * ``reverse_kl`` / ``reverse_kl_autograd`` / ``reverse_kl_fwd`` — fused fwd+bwd reverse-KL self-distillation loss with the same three-entry-point shape. HF Kernels integration: :mod:`geometric_ai_kernels.layers` exposes ``nn.Module`` adapters per kernel (``bnpoLoss`` / ``bnpoLossInference``, ``grpoLoss`` / ``grpoLossInference``, ``ReverseKL`` / ``ReverseKLInference``) for use with the ``kernels`` library's ``kernelize()`` flow. """ from __future__ import annotations import torch._dynamo from .bnpo_loss import bnpo_loss, bnpo_loss_autograd, bnpo_loss_fwd from .grpo_loss import grpo_loss, grpo_loss_autograd, grpo_loss_fwd from .layers import ( ReverseKL, ReverseKLInference, bnpoLoss, bnpoLossInference, grpoLoss, grpoLossInference, ) from .reverse_kl import ( reverse_kl, reverse_kl_autograd, reverse_kl_fwd, ) # Required so ``torch.compile(fullgraph=True)`` can trace through # ``torch.autograd.grad`` calls — without it Dynamo graph-breaks at the # autograd.grad call site even when AOTAutograd has already derived the # joint fwd+bwd graph. Set at package import so any consumer (benches, # user training loops, ``kernelize`` flows) gets it for free. Guarded # because ``trace_autograd_ops`` was added in torch 2.10 and the # Nix-pinned build environment may be on an older torch (2.9 today); # the underlying ``Config.__setattr__`` raises on unknown keys. 
# Opt in to Dynamo tracing through ``torch.autograd.grad`` only when the
# running torch exposes the flag: assigning an unknown key on the Dynamo
# config object raises, so probe for the attribute before setting it.
if hasattr(torch._dynamo.config, "trace_autograd_ops"):
    torch._dynamo.config.trace_autograd_ops = True  # ty: ignore[invalid-assignment]

# Re-exported public surface, kept in case-sensitive sorted order.
__all__ = [
    # Reverse-KL ``nn.Module`` adapters (from ``.layers``).
    "ReverseKL",
    "ReverseKLInference",
    # BNPO: module adapters, then the functional entry points.
    "bnpoLoss",
    "bnpoLossInference",
    "bnpo_loss",
    "bnpo_loss_autograd",
    "bnpo_loss_fwd",
    # GRPO: module adapters, then the functional entry points.
    "grpoLoss",
    "grpoLossInference",
    "grpo_loss",
    "grpo_loss_autograd",
    "grpo_loss_fwd",
    # Reverse-KL functional entry points.
    "reverse_kl",
    "reverse_kl_autograd",
    "reverse_kl_fwd",
]