# Kernels — build uploaded by danieldk (HF Staff) using `kernels`; commit e5b00e0 (verified).
from typing import Optional
import torch
import triton
import triton.language as tl
from .utils import ensure_contiguous
from .utils import infer_device
@triton.jit
def _jsd_kernel(
    X_ptr,  # input in logspace, X = log Q (student distribution)
    X_stride,  # row stride of X
    Y_ptr,  # ground truth in logspace, Y = log P (teacher distribution)
    Y_stride,  # row stride of Y
    loss_ptr,  # per-element loss output, same shape as X
    loss_stride,  # row stride of the loss buffer
    dX_ptr,  # gradient of the loss w.r.t. X, same shape as X
    dX_stride,  # row stride of dX
    label_ptr,  # per-row labels; only read when HAS_LABEL
    beta: tl.constexpr,  # JSD coefficient in [0, 1]; 0 = forward KL, 1 = reverse KL
    n_non_ignore: int,  # number of rows not ignored; loss/grad are divided by it
    ignore_index: tl.constexpr,  # label value marking a row to skip
    n_cols,  # row length (vocabulary size V)
    BLOCK_SIZE: tl.constexpr,  # tile width per inner-loop iteration
    HAS_LABEL: tl.constexpr,  # whether label_ptr holds real labels
):
    """Compute the per-row generalized JSD loss and its gradient w.r.t. X.

    One program instance handles one row of length ``n_cols``. For an ignored
    row (HAS_LABEL and label == ignore_index) the gradient is zeroed and the
    loss entries are NOT written — the caller must pre-zero the loss buffer.
    Both loss and dX are scaled by 1 / n_non_ignore (mean over kept rows).
    """
    # JSD(P || Q) = (KL(P || M) + KL(Q || M)) / 2, M = (1/2) * (P + Q) = (1/2) * (e ^ Y + e ^ X)
    # = sum(P * log P + Q * log Q - 2 * M * log M) / 2
    # = sum(e ^ Y * Y + e ^ X * X - 2 * M * log M) / 2
    # grad_x_i = 0.5 * Q * (X - log_M)
    pid = tl.program_id(0).to(tl.int64)
    # Advance all pointers to this program's row.
    X_ptr += pid * X_stride
    dX_ptr += pid * dX_stride
    Y_ptr += pid * Y_stride
    loss_ptr += pid * loss_stride
    label_ptr += pid
    if HAS_LABEL:
        label = tl.load(label_ptr)
        if label == ignore_index:
            # Ignored row: zero out the gradient and exit early.
            for i in range(0, n_cols, BLOCK_SIZE):
                offsets = i + tl.arange(0, BLOCK_SIZE)
                tl.store(dX_ptr + offsets, 0.0, mask=offsets < n_cols)
            return
    # Walk the row one BLOCK_SIZE tile at a time.
    for i in range(0, n_cols, BLOCK_SIZE):
        offsets = i + tl.arange(0, BLOCK_SIZE)
        mask = offsets < n_cols
        # -inf padding: exp(-inf) = 0, so masked lanes contribute nothing.
        X = tl.load(X_ptr + offsets, mask=mask, other=float("-inf")).to(tl.float32)
        Y = tl.load(Y_ptr + offsets, mask=mask, other=float("-inf")).to(tl.float32)
        if beta == 0.0:  # forward KL: KL(P || Q)
            Y_max = tl.max(Y, axis=0)
            Y_shifted = Y - Y_max
            Y_prob = tl.exp(Y_shifted) * tl.exp(Y_max)  # Compensate for the shift
            loss = Y_prob * (Y - X)
            dX = -Y_prob
        elif beta == 1.0:  # reverse KL: KL(Q || P)
            X_max = tl.max(X, axis=0)
            X_shifted = X - X_max
            X_prob = tl.exp(X_shifted) * tl.exp(X_max)  # Compensate for the shift
            loss = X_prob * (X - Y)
            dX = loss + X_prob
        else:  # generalized JSD with mixture M = beta * P + (1 - beta) * Q
            max_val = tl.maximum(tl.max(X, axis=0), tl.max(Y, axis=0))
            X_shifted = X - max_val
            Y_shifted = Y - max_val
            # Pre-compute exp(max_val) since it's used twice
            exp_max = tl.exp(max_val)
            # Compute exp terms with compensation
            Q = tl.exp(X_shifted) * exp_max  # = exp(X)
            P = tl.exp(Y_shifted) * exp_max  # = exp(Y)
            # Pre-compute common terms
            beta_P = beta * P
            one_minus_beta_Q = (1 - beta) * Q
            M = beta_P + one_minus_beta_Q
            log_M = tl.log(M)  # No need to compensate as M is already in original scale
            loss = beta_P * Y + one_minus_beta_Q * X - M * log_M
            dX = one_minus_beta_Q * (X - log_M)
        # Pre-compute scaling factor (mean over non-ignored rows)
        scale = 1.0 / n_non_ignore
        loss = loss * scale
        dX = dX * scale
        tl.store(loss_ptr + offsets, loss, mask=mask)
        tl.store(dX_ptr + offsets, dX, mask=mask)
# Upper bound on the kernel's BLOCK_SIZE; XPU devices get a smaller cap.
MAX_FUSED_SIZE = 4096 if infer_device() == "xpu" else 65536
def jsd_forward(_input, target, shift_labels, beta, ignore_index, has_label):
    """Launch the JSD kernel over every row and reduce the loss to a scalar.

    Args:
        _input: (BT, V) log-probabilities of the student, X = log Q.
        target: (BT, V) log-probabilities of the teacher, Y = log P.
        shift_labels: (BT,) labels, or None when ``has_label`` is False.
        beta: generalized-JSD coefficient in [0, 1].
        ignore_index: label value whose rows are skipped.
        has_label: whether ``shift_labels`` should be consulted.

    Returns:
        (scalar loss cast to the input dtype, per-element gradient dX).
    """
    n_rows, n_cols = _input.shape
    block_size = min(MAX_FUSED_SIZE, triton.next_power_of_2(n_cols))
    # The loss buffer must be zero-initialized: rows whose label equals
    # ignore_index never write their loss entries in the kernel.
    loss_buf = torch.zeros(_input.shape, dtype=torch.float32, device=_input.device)
    grad_in = torch.empty_like(_input)
    # Mean-normalization denominator: count of kept rows (all rows if no labels).
    n_non_ignore = (shift_labels != ignore_index).sum().item() if has_label else n_rows
    # A dummy tensor keeps the kernel signature uniform when labels are absent.
    labels = shift_labels if has_label else torch.empty(1, device=_input.device)
    _jsd_kernel[(n_rows,)](
        X_ptr=_input,  # input in logspace, X = log Q
        X_stride=_input.stride(-2),
        Y_ptr=target,  # ground truth in logspace, Y = log P
        Y_stride=target.stride(-2),
        loss_ptr=loss_buf,
        loss_stride=loss_buf.stride(-2),
        dX_ptr=grad_in,
        dX_stride=grad_in.stride(-2),
        label_ptr=labels,
        beta=beta,
        n_non_ignore=n_non_ignore,
        ignore_index=ignore_index,
        n_cols=n_cols,
        BLOCK_SIZE=block_size,
        HAS_LABEL=has_label,
    )
    total = loss_buf.sum()
    return total.to(_input.dtype), grad_in
def jsd_backward(dX, grad_output):
    """Chain the cached gradient with the upstream gradient.

    Args:
        dX: gradient of the JSD loss w.r.t. the input, cached during forward.
        grad_output: upstream gradient flowing into the loss.

    Returns:
        ``dX`` unchanged when the upstream gradient is exactly the scalar 1.0
        (the common case when JSD is the final layer — skips a multiply),
        otherwise ``grad_output * dX``.
    """
    is_unit_grad = torch.equal(grad_output, torch.tensor(1.0, device=grad_output.device))
    return dX if is_unit_grad else grad_output * dX
class LigerJSDFunction(torch.autograd.Function):
    r"""Autograd wrapper for the generalized Jensen-Shannon Divergence.

    .. math::
        JSD(\beta)(P || Q)
        = \beta * KLDiv(P || (\beta * P + (1 - \beta) * Q)) + (1 - \beta) * KLDiv(Q || (\beta * P + (1 - \beta) * Q))

    .. note::
        Like other PyTorch losses, the first argument :attr:`_input` is the
        prediction (student model output) in log-space, and the second,
        :attr:`target`, is the observation (teacher model output) in log-space.
        This is the reverse of the mathematical notation :math:`JSD(P || Q)`,
        where :math:`P` is the teacher and :math:`Q` is the student.
    """

    @staticmethod
    @ensure_contiguous
    def forward(
        ctx,
        _input: torch.Tensor,
        target: torch.Tensor,
        shift_labels: Optional[torch.Tensor] = None,
        beta: float = 0.5,
        ignore_index: int = -100,
    ) -> torch.Tensor:
        """
        Args:
            _input (torch.Tensor): predict values with shape (BT, V) in logspace
            target (torch.Tensor): ground truth values with shape (BT, V) in logspace
            shift_labels (Optional[torch.LongTensor]): indicator of next predicted vocab with shape (BT) where each value is in [0, V-1].
            beta (float): coefficient beta of generalized JSD in the interval [0, 1]. It implements forward/reverse KL when beta equals 0 and 1 respectively. Default: `0.5`
            ignore_index (int): the index to ignore. Default: -100
        Returns:
            loss (torch.Tensor): generalized JSD
        """
        has_label = shift_labels is not None
        if has_label:
            assert shift_labels.shape == (_input.shape[0],), (
                f"the shape of shift_labels must be (BT,). Got: {shift_labels.shape}"
            )
            shift_labels = shift_labels.contiguous()
        loss, dX = jsd_forward(_input, target, shift_labels, beta, ignore_index, has_label)
        # Only dX is needed for backward; the loss graph ends here.
        ctx.save_for_backward(dX)
        return loss

    @staticmethod
    @ensure_contiguous
    def backward(ctx, grad_output: torch.Tensor) -> torch.Tensor:
        (dX,) = ctx.saved_tensors
        # One gradient per forward argument; only _input receives one.
        return jsd_backward(dX, grad_output), None, None, None, None