from typing import Optional

import torch
import triton
import triton.language as tl

from .utils import ensure_contiguous
from .utils import infer_device


@triton.jit
def _jsd_kernel(
    X_ptr,  # student log-probabilities, X = log Q, shape (BT, V)
    X_stride,
    Y_ptr,  # teacher log-probabilities, Y = log P, shape (BT, V)
    Y_stride,
    loss_ptr,
    loss_stride,
    dX_ptr,
    dX_stride,
    label_ptr,
    beta: tl.constexpr,
    n_non_ignore: int,
    ignore_index: tl.constexpr,
    n_cols,
    BLOCK_SIZE: tl.constexpr,
    HAS_LABEL: tl.constexpr,
):
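    # Each program handles one row. With teacher P = exp(Y), student Q = exp(X),
    # and mixture M = beta * P + (1 - beta) * Q, the per-element terms are:
    #   loss = beta * P * Y + (1 - beta) * Q * X - M * log(M)
    #        = beta * KL(P || M) + (1 - beta) * KL(Q || M)   (once summed over V)
    #   dloss/dX = (1 - beta) * Q * (X - log(M)),
    # since dM/dX = (1 - beta) * Q and the (1 - beta) * Q product-rule terms cancel.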
    pid = tl.program_id(0).to(tl.int64)
    X_ptr += pid * X_stride
    dX_ptr += pid * dX_stride
    Y_ptr += pid * Y_stride
    loss_ptr += pid * loss_stride
    label_ptr += pid

    if HAS_LABEL:
        label = tl.load(label_ptr)
        if label == ignore_index:
            # Ignored rows contribute no loss; zero their gradient and exit early.
            for i in range(0, n_cols, BLOCK_SIZE):
                offsets = i + tl.arange(0, BLOCK_SIZE)
                tl.store(dX_ptr + offsets, 0.0, mask=offsets < n_cols)
            return

    for i in range(0, n_cols, BLOCK_SIZE):
        offsets = i + tl.arange(0, BLOCK_SIZE)
        mask = offsets < n_cols
        X = tl.load(X_ptr + offsets, mask=mask, other=float("-inf")).to(tl.float32)
        Y = tl.load(Y_ptr + offsets, mask=mask, other=float("-inf")).to(tl.float32)

        if beta == 0.0:  # forward KL: KL(P || Q)
            Y_max = tl.max(Y, axis=0)
            Y_shifted = Y - Y_max
            Y_prob = tl.exp(Y_shifted) * tl.exp(Y_max)  # computed in fp32 to avoid a vanishing gradient
            loss = Y_prob * (Y - X)
            dX = -Y_prob
        elif beta == 1.0:  # reverse KL: KL(Q || P)
            X_max = tl.max(X, axis=0)
            X_shifted = X - X_max
            X_prob = tl.exp(X_shifted) * tl.exp(X_max)
            loss = X_prob * (X - Y)
            dX = loss + X_prob
        else:
            # Shift both inputs by their joint max before exponentiating for stability.
            max_val = tl.maximum(tl.max(X, axis=0), tl.max(Y, axis=0))
            X_shifted = X - max_val
            Y_shifted = Y - max_val
            exp_max = tl.exp(max_val)

            Q = tl.exp(X_shifted) * exp_max  # = exp(X), student probabilities
            P = tl.exp(Y_shifted) * exp_max  # = exp(Y), teacher probabilities

            beta_P = beta * P
            one_minus_beta_Q = (1 - beta) * Q
            M = beta_P + one_minus_beta_Q
            # M is a linear mixture of P and Q, so log(M) cannot be formed directly in log-space.
            log_M = tl.log(M)

            loss = beta_P * Y + one_minus_beta_Q * X - M * log_M
            dX = one_minus_beta_Q * (X - log_M)

        # Scale by 1 / n_non_ignore so the final sum is a mean over non-ignored rows.
        scale = 1.0 / n_non_ignore
        loss = loss * scale
        dX = dX * scale

        tl.store(loss_ptr + offsets, loss, mask=mask)
        tl.store(dX_ptr + offsets, dX, mask=mask)


MAX_FUSED_SIZE = 4096 if infer_device() == "xpu" else 65536


def jsd_forward(_input, target, shift_labels, beta, ignore_index, has_label):
    BT, V = _input.shape
    n_rows = BT
    BLOCK_SIZE = min(MAX_FUSED_SIZE, triton.next_power_of_2(V))

    # Per-element losses are accumulated here in fp32 and summed on the host side.
    loss = torch.zeros(_input.shape, dtype=torch.float32, device=_input.device)
    dX = torch.empty_like(_input)

    if has_label:
        n_non_ignore = (shift_labels != ignore_index).sum().item()
    else:
        n_non_ignore = BT

    # Launch one program per row of the (BT, V) inputs.
    _jsd_kernel[(n_rows,)](
        X_ptr=_input,
        X_stride=_input.stride(-2),
        Y_ptr=target,
        Y_stride=target.stride(-2),
        loss_ptr=loss,
        loss_stride=loss.stride(-2),
        dX_ptr=dX,
        dX_stride=dX.stride(-2),
        label_ptr=(shift_labels if has_label else torch.empty(1, device=_input.device)),
        beta=beta,
        n_non_ignore=n_non_ignore,
        ignore_index=ignore_index,
        n_cols=V,
        BLOCK_SIZE=BLOCK_SIZE,
        HAS_LABEL=has_label,
    )

    loss = torch.sum(loss)
    return loss.to(_input.dtype), dX
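

# Shape contract for jsd_forward (illustrative summary, inferred from the code above):
#   _input, target : (BT, V) log-probabilities
#   returns        : (scalar loss cast back to _input.dtype, dX of shape (BT, V))
# dX is already scaled by 1 / n_non_ignore inside the kernel, so the backward
# pass below only has to apply the upstream grad_output.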


def jsd_backward(dX, grad_output):
    # Fast path: for a scalar loss, `loss.backward()` passes grad_output == 1.0,
    # so the gradient precomputed in the forward pass can be returned as-is.
    if torch.equal(grad_output, torch.tensor(1.0, device=grad_output.device)):
        return dX
    else:
        return grad_output * dX


class LigerJSDFunction(torch.autograd.Function):
    r"""
    This class implements the forward and backward pass for the generalized Jensen-Shannon Divergence.

    .. math::
        JSD(\beta)(P || Q)
            = \beta * KLDiv(P || (\beta * P + (1 - \beta) * Q)) + (1 - \beta) * KLDiv(Q || (\beta * P + (1 - \beta) * Q))

    .. note::
        As with all other losses in PyTorch, this function expects the first argument,
        :attr:`_input`, to be the predictions (the output of the student model) in log-space,
        and the second, :attr:`target`, to be the observations (the output of the teacher model) in log-space.
        This differs from the standard mathematical notation :math:`JSD(P || Q)`, where
        :math:`P` denotes the teacher model and :math:`Q` denotes the student model.
    """

    @staticmethod
    @ensure_contiguous
    def forward(
        ctx,
        _input: torch.Tensor,
        target: torch.Tensor,
        shift_labels: Optional[torch.Tensor] = None,
        beta: float = 0.5,
        ignore_index: int = -100,
    ) -> torch.Tensor:
        """
        Args:
            _input (torch.Tensor): predicted values with shape (BT, V), in log-space
            target (torch.Tensor): ground-truth values with shape (BT, V), in log-space
            shift_labels (Optional[torch.LongTensor]): indicator of the next predicted vocab token, with shape (BT,), where each value is in [0, V-1]
            beta (float): coefficient beta of the generalized JSD, in the interval [0, 1]. Implements forward KL when beta equals 0 and reverse KL when beta equals 1. Default: `0.5`
            ignore_index (int): the index to ignore. Default: -100

        Returns:
            loss (torch.Tensor): generalized JSD
        """
        has_label = False
        if shift_labels is not None:
            assert shift_labels.shape == (_input.shape[0],), (
                f"the shape of shift_labels must be (BT,). Got: {shift_labels.shape}"
            )
            shift_labels = shift_labels.contiguous()
            has_label = True

        loss, dX = jsd_forward(_input, target, shift_labels, beta, ignore_index, has_label)
        ctx.save_for_backward(dX)
        return loss

    @staticmethod
    @ensure_contiguous
    def backward(ctx, grad_output: torch.Tensor) -> torch.Tensor:
        (dX,) = ctx.saved_tensors
        dX = jsd_backward(dX, grad_output)
        return (
            dX,
            None,
            None,
            None,
            None,
        )
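

# Minimal usage sketch (illustrative only: the shapes, device, and tensor names
# below are assumptions, not part of this module's API):
#
#     BT, V = 4, 8
#     device = "cuda"  # any Triton-supported device
#     student_logp = torch.randn(BT, V, device=device).log_softmax(dim=-1).requires_grad_(True)
#     teacher_logp = torch.randn(BT, V, device=device).log_softmax(dim=-1)
#     loss = LigerJSDFunction.apply(student_logp, teacher_logp)  # beta defaults to 0.5
#     loss.backward()  # gradients flow only to student_logp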