| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
|
|
| from typing import Tuple, Optional |
|
|
| import torch |
| import torch.nn.functional as F |
|
|
| import triton |
| import triton.language as tl |
|
|
| |
| |
| |
| |
# Compatibility shim: when this torch.distributed build has no public
# `all_gather_into_tensor`, expose the private `_all_gather_base` under that name
# so the rest of the module can call a single entry point.
if not hasattr(torch.distributed, "all_gather_into_tensor"):
    torch.distributed.all_gather_into_tensor = torch.distributed._all_gather_base
|
|
|
|
# Forward cross-entropy kernel; each program handles the row given by program_id(0).
# For that row it either computes the log-sum-exp (LSE) of the scaled logits with a
# blockwise running-max / running-sum recurrence, or loads a precomputed LSE, and then
# writes the row's loss and (when SPLIT is false) the row's z-loss.
@triton.heuristics(
    {
        # Compile-time flag derived from the runtime `smoothing` argument.
        "HAS_SMOOTHING": lambda args: args["smoothing"] > 0.0,
    }
)
@triton.jit
def cross_entropy_fwd_kernel(
    loss_ptr,  # written: loss of this row
    lse_ptr,  # written with the row's LSE, or read from when PRECOMPUTED_LSE
    z_loss_ptr,  # written only when SPLIT is false
    logits_ptr,
    labels_ptr,
    smoothing,
    logit_scale,  # every loaded logit is multiplied by this factor
    lse_square_scale,  # z_loss = lse_square_scale * lse * lse
    ignore_index,  # rows whose label equals this value get loss = z_loss = 0
    total_classes,  # divisor of the summed logits in the smoothing term
    class_start_idx,  # subtracted from the label to obtain a column index in this row
    n_cols,
    logits_row_stride,
    BLOCK_SIZE: tl.constexpr,
    HAS_SMOOTHING: tl.constexpr,
    # When SPLIT is true the LSE term is left out of the loss and no z-loss is stored.
    SPLIT: tl.constexpr,
    PRECOMPUTED_LSE: tl.constexpr,  # load the LSE from lse_ptr instead of computing it
):
    row_idx = tl.program_id(0)
    # The stride is cast to int64 before the multiply -- presumably to keep the row
    # offset from overflowing 32 bits on large inputs (TODO confirm).
    logits_ptr = logits_ptr + row_idx * logits_row_stride.to(tl.int64)
    # Sum of the scaled logits of this row; only accumulated when HAS_SMOOTHING and
    # only in the branch that computes the LSE.
    sum_logits = 0.0
    if not PRECOMPUTED_LSE:
        # Running maximum (m_i) and running sum of exp(logit - m_i) (l_i).
        m_i = -float("inf")
        l_i = 0.0
        for col_offset in range(0, n_cols, BLOCK_SIZE):
            cols = col_offset + tl.arange(0, BLOCK_SIZE)
            # Out-of-range lanes load -inf so they contribute exp(-inf) = 0 to the sum.
            logits = tl.load(logits_ptr + cols, mask=cols < n_cols, other=-float("inf")).to(
                tl.float32
            ) * logit_scale
            if HAS_SMOOTHING:
                # Replace the -inf padding by 0 before summing the logits.
                sum_logits += tl.sum(tl.where(cols < n_cols, logits, 0.0))
            m_i_new = tl.maximum(m_i, tl.max(logits))
            # Rescale the previous partial sum to the new maximum, then add this block.
            l_i = tl.exp(m_i - m_i_new) * l_i + tl.sum(tl.exp(logits - m_i_new))
            m_i = m_i_new
        lse = tl.log(l_i) + m_i
        tl.store(lse_ptr + row_idx, lse)
    else:
        lse = tl.load(lse_ptr + row_idx)
    label_idx = tl.load(labels_ptr + row_idx)
    if label_idx == ignore_index:
        loss = 0.0
        z_loss = 0.0
    else:
        # Convert the label into a column index relative to this block of classes.
        label_idx -= class_start_idx
        if label_idx >= 0 and label_idx < n_cols:
            logits_label = tl.load(logits_ptr + label_idx) * logit_scale
            if HAS_SMOOTHING:
                loss = (
                    (lse if not SPLIT else 0.0)
                    - smoothing * sum_logits / total_classes
                    - (1 - smoothing) * logits_label
                )
            else:
                loss = (lse if not SPLIT else 0.0) - logits_label
        else:
            # The label does not fall inside this row's columns: there is no label-logit
            # term, only the smoothing contribution (if any).
            if HAS_SMOOTHING:
                loss = smoothing * ((lse if not SPLIT else 0.0) - sum_logits / total_classes)
            else:
                loss = 0.0
        if not SPLIT:
            z_loss = lse_square_scale * lse * lse
            loss += z_loss
        else:
            z_loss = 0.0
    tl.store(loss_ptr + row_idx, loss)
    if not SPLIT:
        tl.store(z_loss_ptr + row_idx, z_loss)
|
|
|
|
# Backward cross-entropy kernel; each program handles one BLOCK_SIZE-wide column block
# (program_id(1)) of one row (program_id(0)). It recomputes the softmax probabilities
# from the logits and the stored per-row LSE and writes the gradient w.r.t. the logits.
@triton.heuristics(
    {
        # Compile-time flag derived from the runtime `smoothing` argument.
        "HAS_SMOOTHING": lambda args: args["smoothing"] > 0.0,
    }
)
@triton.jit
def cross_entropy_bwd_kernel(
    dlogits_ptr,  # written: gradient for this row / column block
    dloss_ptr,  # read: upstream gradient of the row's loss
    logits_ptr,
    lse_ptr,  # read: LSE of the row
    labels_ptr,
    smoothing,
    logit_scale,  # applied to the loaded logits and to the outgoing gradient
    lse_square_scale,
    ignore_index,  # rows with this label use dloss = 0
    total_classes,
    class_start_idx,  # subtracted from the label to obtain a column index in this row
    n_cols,
    logits_row_stride,
    dlogits_row_stride,
    dloss_row_stride,
    BLOCK_SIZE: tl.constexpr,
    HAS_SMOOTHING: tl.constexpr,
):
    row_idx = tl.program_id(0)
    col_block_idx = tl.program_id(1)
    # Row strides are cast to int64 before the multiply -- presumably to avoid 32-bit
    # overflow of the row offset (TODO confirm).
    logits_ptr = logits_ptr + row_idx * logits_row_stride.to(tl.int64)
    dlogits_ptr = dlogits_ptr + row_idx * dlogits_row_stride.to(tl.int64)
    col_offsets = col_block_idx * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    label_idx = tl.load(labels_ptr + row_idx)
    if label_idx != ignore_index:
        dloss = tl.load(dloss_ptr + row_idx * dloss_row_stride)
    else:
        # With dloss = 0 the value stored at the end is zero for this row.
        dloss = 0.0
    # Out-of-range lanes load -inf; they are masked out again at the final store.
    logits = tl.load(logits_ptr + col_offsets, mask=col_offsets < n_cols, other=-float("inf")).to(
        tl.float32
    ) * logit_scale
    lse = tl.load(lse_ptr + row_idx)
    probs = tl.exp(logits - lse)
    # Contribution of the lse_square_scale * lse^2 term: adds 2 * scale * lse * probs.
    probs += 2.0 * lse_square_scale * lse * probs
    # Convert the label into a column index relative to this block of classes.
    label_idx -= class_start_idx
    if HAS_SMOOTHING:
        smooth_positive = 1.0 - smoothing
        smooth_negative = smoothing / total_classes
        probs = tl.where(col_offsets == label_idx, probs - smooth_positive, probs) - smooth_negative
    else:
        probs = tl.where(col_offsets == label_idx, probs - 1.0, probs)
    tl.store(dlogits_ptr + col_offsets, (dloss * logit_scale) * probs, mask=col_offsets < n_cols)
|
|
@torch.library.custom_op("flasht5::cross_entropy_triton_fwd", mutates_args=(), device_types="cuda")
def cross_entropy_triton_fwd(
    logits: torch.Tensor,
    labels: torch.Tensor,
    precomputed_lse: Optional[torch.Tensor],
    use_precomputed_lse: bool,
    split: bool,
    smoothing: float,
    logit_scale: float,
    lse_square_scale: float,
    ignore_index: int,
    total_classes: int,
    class_start_idx: int,
    n_cols: int,
    n_rows: int,
    BLOCK_SIZE: int,
    num_warps: int
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
    """Launch ``cross_entropy_fwd_kernel`` with one program per row.

    Arguments:
        logits: 2-D tensor whose ``stride(0)`` is passed as the row stride; it is made
            contiguous first if its last dimension is not unit-stride.
        labels: per-row labels, passed to the kernel unchanged.
        precomputed_lse: optional per-row log-sum-exp of shape ``(n_rows,)``. Only read
            when ``use_precomputed_lse`` is True; may be None otherwise.
        use_precomputed_lse: if True the kernel loads the LSE instead of computing it.
        split: forwarded as the kernel's ``SPLIT`` flag.
        smoothing, logit_scale, lse_square_scale, ignore_index, total_classes,
        class_start_idx, n_cols: forwarded to the kernel unchanged.
        n_rows: number of rows; sizes the launch grid and the output tensors.
        BLOCK_SIZE, num_warps: Triton launch parameters.

    Returns:
        ``(losses, z_losses, lse)``, each of shape ``(n_rows,)``. ``losses`` and
        ``z_losses`` are float32. The kernel does not write ``z_losses`` when ``split``
        is True, so its contents are uninitialized in that case.
    """
    # The kernel indexes columns with unit stride.
    if logits.stride(-1) != 1:
        logits = logits.contiguous()

    losses = torch.empty(n_rows, dtype=torch.float, device=logits.device)
    z_losses = torch.empty(n_rows, dtype=torch.float, device=logits.device)

    if use_precomputed_lse:
        assert precomputed_lse is not None and precomputed_lse.shape == (n_rows,)
        # A custom op must not return a tensor that aliases one of its inputs, and
        # `.contiguous()` hands back the input itself when it is already contiguous,
        # so always materialize a fresh contiguous copy.
        lse = precomputed_lse.clone(memory_format=torch.contiguous_format)
    else:
        lse = torch.empty(n_rows, dtype=torch.float, device=logits.device)

    # Launch on the device that owns `logits`, which may differ from the current one.
    with torch.cuda.device(logits.device.index):
        cross_entropy_fwd_kernel[(n_rows,)](
            losses,
            lse,
            z_losses,
            logits,
            labels,
            smoothing,
            logit_scale,
            lse_square_scale,
            ignore_index,
            total_classes,
            class_start_idx,
            n_cols,
            logits.stride(0),
            BLOCK_SIZE=BLOCK_SIZE,
            SPLIT=split,
            PRECOMPUTED_LSE=use_precomputed_lse,
            num_warps=num_warps,
        )

    return losses, z_losses, lse
|
|
|
|
@torch.library.register_fake("flasht5::cross_entropy_triton_fwd")
def cross_entropy_triton_fwd_abstract(logits, labels, precomputed_lse, use_precomputed_lse, split, smoothing, logit_scale, lse_square_scale, ignore_index, total_classes, class_start_idx, n_cols, n_rows, BLOCK_SIZE, num_warps):
    """Fake implementation of the forward op: describes the outputs without computing them.

    Returns three uninitialized float32 tensors of shape ``(n_rows,)`` on the device of
    ``logits``, in the order ``(losses, z_losses, logsumexp)``.
    """

    def _per_row_buffer():
        # Every output holds one float32 value per row.
        return torch.empty(n_rows, dtype=torch.float32, device=logits.device)

    return _per_row_buffer(), _per_row_buffer(), _per_row_buffer()
|
|
@torch.library.custom_op("flasht5::cross_entropy_triton_bwd", mutates_args={"logits"}, device_types="cuda")
def cross_entropy_triton_bwd(
    dlosses: torch.Tensor,
    logits: torch.Tensor,
    lse: torch.Tensor,
    labels: torch.Tensor,
    inplace_backward: bool,
    smoothing: float,
    logit_scale: float,
    lse_square_scale: float,
    ignore_index: int,
    total_classes: int,
    class_start_idx: int,
    n_cols: int,
    n_rows: int,
    BLOCK_SIZE: int,
    num_warps: int
) -> torch.Tensor:
    """Launch ``cross_entropy_bwd_kernel`` over a (row, column-block) grid.

    When ``inplace_backward`` is True the gradient is written into ``logits`` itself
    (hence ``mutates_args``) and None is returned; otherwise a new tensor allocated
    with ``torch.empty_like(logits)`` receives the gradient and is returned.
    """
    if inplace_backward:
        grad_buffer = logits
    else:
        grad_buffer = torch.empty_like(logits)

    # One program per row and per BLOCK_SIZE-wide slice of the columns.
    launch_grid = (n_rows, triton.cdiv(n_cols, BLOCK_SIZE))

    # Launch on the device that owns `logits`, which may differ from the current one.
    with torch.cuda.device(logits.device.index):
        cross_entropy_bwd_kernel[launch_grid](
            grad_buffer,
            dlosses,
            logits,
            lse,
            labels,
            smoothing,
            logit_scale,
            lse_square_scale,
            ignore_index,
            total_classes,
            class_start_idx,
            n_cols,
            logits.stride(0),
            grad_buffer.stride(0),
            dlosses.stride(0),
            BLOCK_SIZE=BLOCK_SIZE,
            num_warps=num_warps,
        )

    if inplace_backward:
        # The result already lives in `logits`; nothing new is returned.
        return None
    return grad_buffer
|
|
@torch.library.register_fake("flasht5::cross_entropy_triton_bwd")
def cross_entropy_triton_bwd_abstract(dlosses, logits, lse, labels, inplace_backward, smoothing, logit_scale, lse_square_scale, ignore_index, total_classes, class_start_idx, n_cols, n_rows, BLOCK_SIZE, num_warps):
    """Fake implementation of the backward op: an uninitialized tensor laid out like ``logits``."""
    dlogits_like = torch.empty_like(logits)
    return dlogits_like
|
|
class CrossEntropyLoss(torch.autograd.Function):
    """Autograd function pairing the ``flasht5`` Triton cross-entropy forward and backward ops.

    ``forward`` returns ``(losses, z_losses)``; ``z_losses`` is marked non-differentiable.
    When ``process_group`` is given, each rank is treated as owning ``n_cols`` classes
    starting at ``rank * n_cols`` and the per-rank results are combined with collectives.
    """

    @staticmethod
    def forward(
        ctx,
        logits,
        labels,
        precomputed_lse=None,
        smoothing=0.0,
        logit_scale=1.0,
        lse_square_scale=0.0,
        ignore_index=-100,
        inplace_backward=False,
        process_group=None,
    ):
        """Compute per-row losses and z-losses.

        Arguments:
            logits: (n_rows, n_cols) tensor.
            labels: (n_rows,) tensor.
            precomputed_lse: optional (n_rows,) log-sum-exp; only used when
                ``logit_scale == 1.0`` and ``smoothing == 0.0``.
            smoothing, logit_scale, lse_square_scale, ignore_index: forwarded to the kernels.
            inplace_backward: if True, backward writes the gradient into the saved logits.
            process_group: optional process group over which the classes are split.
        Returns:
            (losses, z_losses), each of shape (n_rows,).
        """
        # Pad-then-slice re-allocates int64 labels whose storage is not 16-byte aligned.
        # NOTE(review): presumably a workaround for a Triton alignment issue -- confirm.
        if labels.dtype == torch.long and labels.data_ptr() % 16 != 0:
            labels = F.pad(labels, (0, 1))[..., :-1]
            assert labels.data_ptr() % 16 == 0

        # Both kernels index columns with unit stride. Normalize here, before saving for
        # backward, so the backward kernel sees the same layout the forward kernel used
        # (the forward op otherwise only normalizes its own private copy).
        if logits.stride(-1) != 1:
            logits = logits.contiguous()

        n_rows, n_cols = logits.shape
        assert labels.shape == (n_rows,)
        if process_group is None:
            world_size, rank = 1, 0
        else:
            world_size = torch.distributed.get_world_size(process_group)
            rank = torch.distributed.get_rank(process_group)
        split = world_size > 1
        total_classes = world_size * n_cols
        class_start_idx = rank * n_cols
        use_precomputed_lse = precomputed_lse is not None and logit_scale == 1.0 and smoothing == 0.0

        MAX_BLOCK_SIZE = 16 * 1024
        BLOCK_SIZE = min(triton.next_power_of_2(n_cols), MAX_BLOCK_SIZE)
        num_warps = (
            4
            if BLOCK_SIZE < 2048
            else (8 if BLOCK_SIZE < 8192 else (16 if BLOCK_SIZE < 128 * 1024 else 32))
        )

        losses, z_losses, lse = torch.ops.flasht5.cross_entropy_triton_fwd(
            logits, labels, precomputed_lse, use_precomputed_lse,
            split, smoothing, logit_scale, lse_square_scale,
            ignore_index, total_classes, class_start_idx,
            n_cols, n_rows, BLOCK_SIZE, num_warps,
        )

        if split:
            # With SPLIT set, the kernel leaves the LSE term out of `losses`, so the
            # per-rank losses can be summed and the global LSE added afterwards.
            lse_allgather = torch.empty(world_size, n_rows, dtype=lse.dtype, device=lse.device)
            torch.distributed.all_gather_into_tensor(lse_allgather, lse, group=process_group)
            # Overlap the loss all-reduce with the local logsumexp over the gathered LSEs.
            handle_losses = torch.distributed.all_reduce(
                losses, op=torch.distributed.ReduceOp.SUM, group=process_group, async_op=True
            )
            lse = torch.logsumexp(lse_allgather, dim=0)
            handle_losses.wait()

            losses += lse
            if lse_square_scale != 0.0:
                z_losses = lse_square_scale * lse.square()
                z_losses.masked_fill_(labels == ignore_index, 0.0)
                losses += z_losses
            else:
                z_losses = torch.zeros_like(losses)
            # Ignored rows picked up the global LSE above; reset them to zero.
            losses.masked_fill_(labels == ignore_index, 0.0)

        ctx.save_for_backward(logits, lse, labels)
        ctx.mark_non_differentiable(z_losses)
        ctx.smoothing = smoothing
        ctx.logit_scale = logit_scale
        ctx.lse_square_scale = lse_square_scale
        ctx.ignore_index = ignore_index
        ctx.total_classes = total_classes
        ctx.class_start_idx = class_start_idx
        ctx.inplace_backward = inplace_backward

        return losses, z_losses

    @staticmethod
    def backward(ctx, grad_losses, grad_z_losses):
        """Return the gradient w.r.t. ``logits`` followed by None for every other input."""
        del grad_z_losses  # z_losses is non-differentiable; its gradient is unused.

        logits, lse, labels = ctx.saved_tensors

        n_rows, n_cols = logits.shape
        BLOCK_SIZE = min(triton.next_power_of_2(n_cols), 4 * 1024)
        num_warps = 4 if BLOCK_SIZE < 2048 else (8 if BLOCK_SIZE < 8192 else 16)

        dlogits = torch.ops.flasht5.cross_entropy_triton_bwd(
            grad_losses, logits, lse, labels,
            ctx.inplace_backward, ctx.smoothing, ctx.logit_scale,
            ctx.lse_square_scale, ctx.ignore_index, ctx.total_classes,
            ctx.class_start_idx, n_cols, n_rows, BLOCK_SIZE, num_warps,
        )

        # In in-place mode the op wrote the gradient into the saved logits and
        # returned nothing.
        if ctx.inplace_backward:
            dlogits = logits

        # One gradient per forward input (9 in total); only `logits` is differentiable.
        return (dlogits,) + (None,) * 8
|
|
|
|
def cross_entropy_loss(
    logits: torch.Tensor,
    labels: torch.Tensor,
    precomputed_lse: Optional[torch.Tensor] = None,
    label_smoothing: float = 0.0,
    logit_scale: float = 1.0,
    lse_square_scale: float = 0.0,
    ignore_index: int = -100,
    inplace_backward: bool = False,
    process_group=None,
) -> Tuple[torch.Tensor, torch.Tensor]:
    """
    Flatten ``logits`` and ``labels`` and evaluate ``CrossEntropyLoss`` on them.

    Arguments:
        logits: (..., vocab_size); all leading dimensions are flattened into one batch dimension.
        labels: same leading shape as logits; flattened to (batch,).
        precomputed_lse: optional (batch,) log-sum-exp of the logits. It is only used when
            logit_scale == 1.0 and label_smoothing == 0.0.
        label_smoothing: float
        logit_scale: float. Multiply logits by this scale before calculating the loss.
        lse_square_scale: float. If > 0, we add lse_square_scale * lse(logits) ^ 2 to the loss.
            This is also referred to as "z-loss".
        ignore_index: int. If labels == ignore_index, the loss is set to 0.0.
        inplace_backward: bool. If True, we do the backward pass in-place by modifying the logits.
            This saves memory.
        process_group: if not None, we're doing Tensor Parallel: each process is responsible for
            one part of the vocab. The loss will be aggregated across processes.
    Returns:
        losses: (batch,), float
        z_losses: (batch,), float
    """
    # reshape() behaves like view() whenever a view is possible and copies otherwise,
    # so inputs that cannot be viewed (e.g. non-contiguous labels) no longer raise.
    flat_logits = logits.reshape(-1, logits.shape[-1])
    flat_labels = labels.reshape(-1)
    return CrossEntropyLoss.apply(
        flat_logits,
        flat_labels,
        precomputed_lse,
        label_smoothing,
        logit_scale,
        lse_square_scale,
        ignore_index,
        inplace_backward,
        process_group,
    )
|
|