diff --git a/build/torch-cuda/__init__.py b/build/torch-cuda/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..271a47c2646379ff9ec8cf8c3c3a0e973f187652 --- /dev/null +++ b/build/torch-cuda/__init__.py @@ -0,0 +1,3 @@ +from . import layers + +__all__ = ["layers"] \ No newline at end of file diff --git a/build/torch-cuda/_ops.py b/build/torch-cuda/_ops.py new file mode 100644 index 0000000000000000000000000000000000000000..45b30870c17a6c248a2a9b9922667f3afc95a973 --- /dev/null +++ b/build/torch-cuda/_ops.py @@ -0,0 +1,8 @@ +import torch +ops = torch.ops._liger_kernels_ab5ef3f + +def add_op_namespace_prefix(op_name: str): + """ + Prefix op by namespace. + """ + return f"_liger_kernels_ab5ef3f::{op_name}" \ No newline at end of file diff --git a/build/torch-cuda/cross_entropy.py b/build/torch-cuda/cross_entropy.py new file mode 100644 index 0000000000000000000000000000000000000000..590d94868b34455f93c98f6c2fb46872d6b6956e --- /dev/null +++ b/build/torch-cuda/cross_entropy.py @@ -0,0 +1,460 @@ +import operator + +from typing import Optional + +import torch +import triton +import triton.language as tl + +from .utils import compare_version +from .utils import element_mul_kernel +from .utils import is_hip +from .utils import infer_device + +if compare_version("triton", operator.ge, "3.0.0"): + try: + # typical import path with dispatch available + from triton.language.extra.libdevice import tanh + except ModuleNotFoundError: + # for working with NGC containers + from triton.language.extra.cuda.libdevice import tanh +else: + from triton.language.math import tanh + + +@triton.jit +def liger_cross_entropy_kernel( + X_ptr, + X_stride, + Y_ptr, + Y_stride, + weight_ptr, + loss_ptr, + z_loss_ptr, + loss_stride, + n_cols, + n_non_ignore, + sum_non_ignore_weight, + weight_sum, + ignore_index, + lse_square_scale: tl.constexpr, + label_smoothing: tl.constexpr, + reduction: tl.constexpr, # set it as constexpr since reduction is always known at compile 
time + softcap, + RETURN_Z_LOSS: tl.constexpr, + BLOCK_SIZE: tl.constexpr, + HAS_WEIGHT: tl.constexpr, + HAS_SOFTCAPPING: tl.constexpr, +): + """ + This kernel computes both cross entropy loss and the gradient of the input. + We only consider hard label + mean reduction for now. Please refer to https://pytorch.org/docs/stable/generated/torch.nn.CrossEntropyLoss.html for the math. + + Parameters: + X_ptr: Pointer to input tensor. + X_stride (int): The stride of the input tensor. + Y_ptr: Pointer to target tensor. + Y_stride (int): The stride of the target tensor. + weight_ptr: Pointer to weight tensor. + loss_ptr: Pointer to tensor to store the loss. + z_loss_ptr: Pointer to tensor to store the z loss. No operation if RETURN_Z_LOSS is 0. + loss_stride (int): The stride of the loss tensor. + n_cols (int): The number of columns in the input tensor. + n_non_ignore (float): The number of non-ignored elements in the batch. + sum_non_ignore_weight (float): The sum of non-ignored target's weights in the batch. + weight_sum (float): The sum of weight tensor. + ignore_index (int): The index to ignore in the target. + label_smoothing (float): The amount of smoothing when computing the loss, where 0.0 means no smoothing. + lse_square_scale (float): The scaler of (logsumexp(_input)) ^ 2 adding to the loss for the stability of training. + reduction (str): The string for the reduction to apply + softcap (float): The upper threshold for scaling logits to the range (-softcap, +softcap). + RETURN_Z_LOSS (int): The boolean value to decide whether storing z loss to z_loss_ptr or not. It must be 0 or 1. + BLOCK_SIZE (int): The block size for Triton operations. + HAS_WEIGHT (bool): The boolean value to determine whether assigning weight to each of the classes. + HAS_SOFTCAPPING (bool): The boolean value to determine whether applying soft-capping or not. 
+ """ + + # https://github.com/triton-lang/triton/issues/1058 + # If B*T*V is too large, program_id * stride will overflow out of int32, so we convert to int64 + program_id = tl.program_id(0).to(tl.int64) + + # 1. Load Y_ptr first because if the target is ignore_index, we can return right away + Y_ptr += program_id * Y_stride + y = tl.load(Y_ptr) + + # 2. locate the start index + X_ptr += program_id * X_stride + + if y == ignore_index: + # set all X_ptr as 0 + for i in range(0, n_cols, BLOCK_SIZE): + X_offsets = i + tl.arange(0, BLOCK_SIZE) + tl.store(X_ptr + X_offsets, 0.0, mask=X_offsets < n_cols) + return + + loss_ptr += program_id * loss_stride + if RETURN_Z_LOSS: + z_loss_ptr += program_id * loss_stride + + if HAS_WEIGHT: + weight_y = tl.load(weight_ptr + y).cast(tl.float32) + + # Online softmax: 2 loads + 1 store (compared with 3 loads + 1 store for the safe softmax) + # Refer to Algorithm 3 in the paper: https://arxiv.org/pdf/1805.02867 + + # 3. [Online softmax] first pass: find max + sum + m = float("-inf") # m is the max value. use the notation from the paper + d = 0.0 # d is the sum. 
use the notation from the paper + ori_X_y = tl.load(X_ptr + y).cast(tl.float32) # we need to store the original value of X_y for the loss calculation + if HAS_SOFTCAPPING: + ori_X_y = softcap * tanh(ori_X_y / softcap) + + # Label smoothing is a general case of normal cross entropy + # See the full derivation at https://github.com/linkedin/Liger-Kernel/pull/198#issue-2503665310 + scaled_x_sum = 0.0 + eps = label_smoothing / n_cols + + for i in range(0, n_cols, BLOCK_SIZE): + X_offsets = i + tl.arange(0, BLOCK_SIZE) + X_block = tl.load( + X_ptr + X_offsets, + mask=X_offsets < n_cols, + other=float("-inf"), + # Ensure float32 precision for softmax calculation + ).cast(tl.float32) + if HAS_SOFTCAPPING: + X_block = softcap * tanh(X_block / softcap) + block_max = tl.max(X_block) + if label_smoothing > 0: + # scale X beforehand to avoid overflow + if HAS_WEIGHT: + weight_block = tl.load(weight_ptr + X_offsets, mask=X_offsets < n_cols) + scaled_x_sum += tl.sum(tl.where(X_offsets < n_cols, -eps * X_block * weight_block, 0.0)) + else: + scaled_x_sum += tl.sum(tl.where(X_offsets < n_cols, -eps * X_block, 0.0)) + m_new = tl.maximum(m, block_max) + d = d * tl.exp(m - m_new) + tl.sum(tl.exp(X_block - m_new)) + m = m_new + + # log (sum(e^(X_i))) = log (sum(e ^ (max(X) * e ^ (X_i - max(X))))) + # = log (e^(max(X)) * sum(e ^ (X_i - max(X)))) + # = max(X) + log (sum(e ^ (X_i - max(X)))) = m + log d + lse = m + tl.log(d) + + # 4. 
[Online Softmax] Second pass: compute gradients + # For 'mean' reduction, gradients are normalized by number of non-ignored elements (N) + # dx_y = (softmax(x_y) - 1) / N + # dx_i = softmax(x_i) / N, i != y + # For label smoothing: + # dx_i = (softmax(x_i) - label_smoothing / V) / N, V = n_cols, i != y + # dx_y = (softmax(x_y) - label_smoothing / V - (1 - label_smoothing)) / N + # = dx_i - (1 - label_smoothing) / N + # With Z loss: + # dx_i = ((1 + 2 * lse_square_scale * lse) * softmax(x_i) - label_smoothing / V) / N, i != y + # dx_y = dx_i - (1 - label_smoothing) / N + # For 'sum' reduction, no normalization is applied: + # dx_y = softmax(x_y) - 1 + # dx_i = softmax(x_i), for i ≠ y + + for i in range(0, n_cols, BLOCK_SIZE): + X_offsets = i + tl.arange(0, BLOCK_SIZE) + X_block = tl.load( + X_ptr + X_offsets, + mask=X_offsets < n_cols, + other=float("-inf"), + # Ensure float32 precision for softmax calculation + ).cast(tl.float32) + if HAS_SOFTCAPPING: + intermediate = tanh(X_block / softcap) + X_block = softcap * intermediate + + if not HAS_WEIGHT: + # softmax(x_i) + X_block = tl.exp(X_block - m) / d + # derivative of z-loss: 2 * lse_square_scale * lse * softmax(x_i) + X_block += 2 * lse_square_scale * lse * X_block + # smoothing term + X_block += -eps + # special handle dx_y + X_block = tl.where(X_offsets != y, X_block, X_block - (1 - label_smoothing)) + # reduction scale + if reduction == "mean": + X_block = X_block / n_non_ignore + else: + weight_block = tl.load(weight_ptr + X_offsets, mask=X_offsets < n_cols) + softmax_X = tl.exp(X_block - m) / d + # derivative of original_loss + dloss_ori = (1 - label_smoothing) * softmax_X + # specially handle dx_y + dloss_ori = tl.where(X_offsets != y, dloss_ori, dloss_ori - (1 - label_smoothing)) + dloss_ori = dloss_ori * weight_y + # derivative of smooth_loss + dloss_smooth = eps * (-weight_block + softmax_X * weight_sum) + # derivative of z-loss + dz_loss = 2 * lse_square_scale * lse * softmax_X + # reduction scale + if 
reduction == "mean": + dloss_ori = dloss_ori / sum_non_ignore_weight + dloss_smooth = dloss_smooth / sum_non_ignore_weight + # TODO: Implement weighted z_loss. Currently, z_loss is not scaled by weight. + dz_loss = dz_loss / n_non_ignore + # derivative of total_loss + X_block = dloss_ori + dloss_smooth + dz_loss + + # chain rule softcapping + # d(softcap * tanh(x / softcap)) = (1 - tanh^2(x / softcap)) + if HAS_SOFTCAPPING: + X_block = X_block * (1 - intermediate * intermediate) + + tl.store(X_ptr + X_offsets, X_block, mask=X_offsets < n_cols) + + # We need tl.debug_barrier() to ensure the new result of X_ptr is written as mentioned in + # https://github.com/triton-lang/triton/blob/ba42a5c68fd0505f8c42f4202d53be0f8d9a5fe0/python/triton/ops/cross_entropy.py#L34 + tl.debug_barrier() + + # 5. Calculate the loss + + # loss = log (softmax(X_y)) = log ((e ^ (X_y - max(X)) / sum(e ^ (X - max(X)))) + # = (X_y - max(X)) - log(sum(e ^ (X - max(X)))) + # = X_y - m - log d = X_y - lse + # sum(e ^ (X - max(X))) must >= 1 because the max term is e ^ 0 = 1 + # So we can safely calculate log (softmax(X_y)) without overflow + loss = lse - ori_X_y + if HAS_WEIGHT: + loss = weight_y * loss + + # Original loss = H(q, p), with label smoothing regularization = H(q', p) and (label_smoothing / V) = eps + # H(q', p) = (1 - label_smoothing) * H(q, p) + label_smoothing * H(u, p) + # = (1 - label_smoothing) * H(q, p) + eps * sum(logsoftmax(x_i)) + # By using m (global max of xi) and d (sum of e^(xi-m)), we can simplify as: + # = (1 - label_smoothing) * H(q, p) + (sum(-eps * x_i) + label_smoothing * (m + logd)) + # Refer to H(q', p) in section 7 of the paper: https://arxiv.org/pdf/1512.00567 + # pytorch: https://github.com/pytorch/pytorch/blob/2981534f54d49fa3a9755c9b0855e7929c2527f0/aten/src/ATen/native/LossNLL.cpp#L516 + # See full derivation at https://github.com/linkedin/Liger-Kernel/pull/198#issuecomment-2333753087 + if label_smoothing > 0: + if HAS_WEIGHT: + smooth_loss = scaled_x_sum + 
eps * lse * weight_sum + else: + smooth_loss = scaled_x_sum + label_smoothing * lse + loss = loss * (1 - label_smoothing) + smooth_loss + + # An auxiliary loss, z_loss + # Refer to Page14 Loss function section in the paper PaLM: https://www.jmlr.org/papers/v24/22-1144.html + z_loss = lse_square_scale * lse * lse + # Normalize the loss by the number of non-ignored elements if reduction is "mean" + if reduction == "mean": + if HAS_WEIGHT: + loss = loss / sum_non_ignore_weight + else: + loss = loss / n_non_ignore + # TODO: Implement weighted z_loss. Currently, z_loss is not scaled by weight. + z_loss = z_loss / n_non_ignore + loss += z_loss + + tl.store(loss_ptr, loss) + if RETURN_Z_LOSS: + tl.store(z_loss_ptr, z_loss) + + +# The hard limit of TRITON_MAX_TENSOR_NUMEL is 1048576 https://github.com/triton-lang/triton/blob/ba42a5c68fd0505f8c42f4202d53be0f8d9a5fe0/python/triton/language/core.py#L19 +# However, setting limit as 65536 as in LayerNorm tutorial is faster because of less register spilling +# The optimal maximum block size depends on your hardware, your kernel, and your dtype +MAX_FUSED_SIZE = 4096 if infer_device() == "xpu" else 65536 // 2 # the best size we found by manually tuning + + +def cross_entropy_forward( + _input, + target, + weight, + ignore_index, + lse_square_scale, + label_smoothing, + reduction, + softcap, + return_z_loss, +): + assert isinstance(return_z_loss, bool), f"return_z_loss must be True or False. Got: {return_z_loss}" + + BT, V = _input.shape + n_rows = BT + + BLOCK_SIZE = min(MAX_FUSED_SIZE, triton.next_power_of_2(V)) + + # unreduced loss + loss_1d = torch.zeros(n_rows, dtype=_input.dtype, device=_input.device) + z_loss_1d = torch.zeros(n_rows, dtype=_input.dtype, device=_input.device) if return_z_loss else None + + target_mask = target != ignore_index + n_non_ignore = target_mask.sum().item() + assert (target * target_mask).max() < _input.shape[-1], ( + f"Target {target.max()} is out of bounds. 
Expected < {_input.shape[-1]}" + ) + assert (target * target_mask).min() >= 0, f"Target {target.min()} is out of bounds. Expected >= 0" + sum_non_ignore_weight = n_non_ignore + weight_sum = 0.0 + if weight is not None: + assert weight.shape[0] == V, f"If given, weight has to be a Tensor of size V. Got: {weight.shape}" + assert torch.is_floating_point(weight), ( + f"If given, weight has to be a Tensor of floating point dtype. Got: {weight.dtype}" + ) + sum_non_ignore_weight = torch.gather(weight, dim=0, index=target.masked_select(target_mask)).sum().item() + weight_sum = weight.sum().item() + # ensure weight is contiguous + if weight.stride(-1) != 1: + weight = weight.contiguous() + + # ensure _input and target are contiguous in the last dimension + if _input.stride(-1) != 1: + _input = _input.contiguous() + if target.stride(-1) != 1: + target = target.contiguous() + + # Here we use a trick to store X_ptr gradient in X_ptr so we can save memory + liger_cross_entropy_kernel[(n_rows,)]( + X_ptr=_input, + X_stride=_input.stride(-2), + Y_ptr=target, + Y_stride=target.stride(-1), # always 1 + weight_ptr=weight, # dummy if None + loss_ptr=loss_1d, + z_loss_ptr=z_loss_1d, + loss_stride=loss_1d.stride(-1), # always 1 + n_cols=V, + n_non_ignore=n_non_ignore, + sum_non_ignore_weight=sum_non_ignore_weight, + ignore_index=ignore_index, + weight_sum=weight_sum, + lse_square_scale=lse_square_scale, + label_smoothing=label_smoothing, + reduction=reduction, + softcap=softcap, + RETURN_Z_LOSS=return_z_loss, + BLOCK_SIZE=BLOCK_SIZE, + HAS_WEIGHT=True if weight is not None else False, + HAS_SOFTCAPPING=True if softcap is not None else False, + # TODO: 32 seems to give the best performance + # Performance is quite sensitive to num_warps + num_warps=32 if not is_hip() else 16, + ) + + if reduction == "none": + loss = loss_1d + z_loss = z_loss_1d if return_z_loss else None + else: + loss = torch.sum(loss_1d) + z_loss = torch.sum(z_loss_1d) if return_z_loss else None + + return loss, 
z_loss, _input + + +def cross_entropy_backward(_input, grad_output): + # If cross entropy is the last layer, grad_output is 1.0. Skip the mul to save time + if torch.equal(grad_output, torch.tensor(1.0, device=grad_output.device)): + pass + # If reduction is 'none' + elif grad_output.ndim > 0: + _input = _input * grad_output.unsqueeze(dim=1) + # If reduction is ['mean', 'sum'], grad_output is just a scalar + # We use a Triton kernel instead of a PyTorch operation because modifying inputs in-place + # for gradient storage and backward multiple times causes anomalies with PyTorch but not with Triton. + else: + BT, V = _input.shape + n_rows = BT + BLOCK_SIZE = min(MAX_FUSED_SIZE, triton.next_power_of_2(V)) + + element_mul_kernel[(n_rows,)]( + _input, + _input.stride(-2), + grad_output, + V, + BLOCK_SIZE=BLOCK_SIZE, + num_warps=32 if not is_hip() else 16, + ) + + return _input + + +class LigerCrossEntropyFunction(torch.autograd.Function): + """ + This class implements a custom autograd function for the Liger Cross Entropy loss. + It overrides the forward and backward methods of the torch.autograd.Function class. + """ + + @staticmethod + def forward( + ctx, + _input: torch.Tensor, + target: torch.Tensor, + weight: Optional[torch.FloatTensor], + ignore_index: int = -100, + lse_square_scale: float = 0.0, + label_smoothing: float = 0.0, + reduction: str = "mean", + softcap: Optional[float] = None, + return_z_loss: bool = False, + ): + """ + The forward pass of the Liger Cross Entropy loss. + + Parameters: + ctx : The context object. + _input (tensor): The input tensor of shape (BT, V) where B is batch size, T is sequence length, V is vocab size. + target (tensor): The target tensor of shape (BT) where each value is in [0, V-1]. + weight(Tensor, optional): a manual rescaling weight given to each class. If given, has to be a Tensor of size V and floating point dtype + ignore_index (int): The index to ignore in the target. 
+ lse_square_scale (float): The scaler of (logsumexp(_input)) ^ 2 adding to the loss for the stability of training. + label_smoothing (float): The amount of smoothing when computing the loss, where 0.0 means no smoothing. + reduction (str): The reduction to apply to the output: "none" | "mean | "sum". + softcap (Optional[float]): The upper threshold for scaling logits to the range (-softcap, +softcap). + return_z_loss (bool): When `return_z_loss` is `True`, returns (loss, z_loss) instead of (loss, None). Default: `False` + + Returns: + tuple: A tuple with the compouted losses with respect to loss and z loss. The elements are tensors or None. + """ + loss, z_loss, _input = cross_entropy_forward( + _input, + target, + weight, + ignore_index, + lse_square_scale, + label_smoothing, + reduction, + softcap, + return_z_loss, + ) + # TODO: investigation + # If we don't detach the _input tensor, the memory will double + # Not sure why but seems that there will be a time both grad and value exist but in different location + ctx.save_for_backward(_input.detach()) + ctx.return_z_loss = return_z_loss + + return loss, z_loss + + @staticmethod + def backward(ctx, grad_output, grad_ouput2): + """ + The backward pass of the Liger Cross Entropy loss. + + Parameters: + ctx : The context object with saved tensors. + grad_output (tensor): The tensor containing the gradient of the loss with respect to the output. + grad_output2 (tenosr): No use. + Returns: + tuple: A tuple with the gradients with respect to the inputs. The elements are tensors or None. 
+ """ + if ctx.return_z_loss: + del grad_ouput2 # z_loss is only for logging + + (_input,) = ctx.saved_tensors + _input = cross_entropy_backward(_input, grad_output) + return ( + _input, + None, + None, + None, + None, + None, + None, + None, + None, + ) \ No newline at end of file diff --git a/build/torch-cuda/dyt.py b/build/torch-cuda/dyt.py new file mode 100644 index 0000000000000000000000000000000000000000..3cc42d14ca31b33e8d4b570f9b6b6d12126f656d --- /dev/null +++ b/build/torch-cuda/dyt.py @@ -0,0 +1,225 @@ +import operator + +import torch +import triton +import triton.language as tl + +from .utils import calculate_settings +from .utils import compare_version +from .utils import ensure_contiguous +from .utils import infer_device + +if compare_version("triton", operator.ge, "3.0.0"): + try: + # typical import path with dispatch available + from triton.language.extra.libdevice import tanh + except ModuleNotFoundError: + # for working with NGC containers + from triton.language.extra.cuda.libdevice import tanh +else: + from triton.language.math import tanh + + +@triton.jit +def _dyt_fwd_kernel( + x_ptr, + x_row_stride, + alpha_ptr, + gamma_ptr, + beta_ptr, + y_ptr, + y_row_stride, + n_cols, + BLOCK_SIZE: tl.constexpr, +): + """ + Reference: + https://arxiv.org/abs/2503.10622 + + Shapes: + - x: (BT, C) + - alpha: (1) + - gamma: (C) + - beta: (C) + """ + row_idx = tl.program_id(0) + offsets = tl.arange(0, BLOCK_SIZE) + mask = offsets < n_cols + + x_ptr += row_idx * x_row_stride + y_ptr += row_idx * y_row_stride + + alpha = tl.load(alpha_ptr) + gamma = tl.load(gamma_ptr + offsets, mask=mask) + beta = tl.load(beta_ptr + offsets, mask=mask) + x = tl.load(x_ptr + offsets, mask=mask) + y = gamma * tanh((alpha * x).cast(tl.float32)) + beta + tl.store(y_ptr + offsets, y, mask=mask) + + +@triton.jit +def _dyt_bwd_kernel( + x_ptr, + x_row_stride, + dy_ptr, + dy_row_stride, + dx_ptr, + dx_row_stride, + alpha_ptr, + dalpha_ptr, + gamma_ptr, + dgamma_ptr, + dgamma_row_stride, + 
n_cols, + n_rows, + ROWS_PER_PROGRAM: tl.constexpr, + BLOCK_SIZE: tl.constexpr, +): + """ + Reference: + https://arxiv.org/abs/2503.10622 + + Shapes: + - x: (BT, C) + - alpha: (1) + - gamma: (C) + - dx: (BT, C) + - dy: (BT, C) + - dgamma: (sm_count, C) + - dalpha: (sm_count,) + """ + # d(gamma * tanh(alpha * x) + beta) / dx + # = gamma * (1 - tanh^2(alpha * x)) * alpha + # d(gamma * tanh(alpha * x) + beta) / dalpha + # = gamma * (1 - tanh^2(alpha * x)) * x + # d(gamma * tanh(alpha * x) + beta) / dgamma + # = tanh(alpha * x) + # d(gamma * tanh(alpha * x)) / dbeta = 1 + pid = tl.program_id(0) + + row_start = pid * ROWS_PER_PROGRAM + row_end = min((pid + 1) * ROWS_PER_PROGRAM, n_rows) + offsets = tl.arange(0, BLOCK_SIZE) + mask = offsets < n_cols + + dalpha = 0.0 + dgamma = tl.zeros((BLOCK_SIZE,), dtype=tl.float32) + + x_ptr += row_start * x_row_stride + dx_ptr += row_start * dx_row_stride + dy_ptr += row_start * dy_row_stride + alpha = tl.load(alpha_ptr) + gamma = tl.load(gamma_ptr + offsets, mask=mask, other=0.0) + + for _ in tl.range(row_start, row_end): + dy = tl.load(dy_ptr + offsets, mask=mask, other=0.0) + x = tl.load(x_ptr + offsets, mask=mask, other=0.0) + tanh_ax = tanh((alpha * x).cast(tl.float32)) + sech2_ax = 1 - tanh_ax * tanh_ax + + dx = dy * gamma * sech2_ax * alpha + dalpha += tl.sum(dy * gamma * sech2_ax * x) + dgamma += dy * tanh_ax + tl.store(dx_ptr + offsets, dx, mask=mask) + + dy_ptr += dy_row_stride + x_ptr += x_row_stride + dx_ptr += dx_row_stride + + tl.store(dgamma_ptr + pid * dgamma_row_stride + offsets, dgamma, mask=mask) + tl.store(dalpha_ptr + pid, dalpha) + + pass + + +def liger_dyt_fwd(x, alpha, gamma, beta): + shape = x.shape + dim = shape[-1] + x = x.view(-1, dim) + n_rows, n_cols = x.shape + y = torch.empty_like(x) + BLOCK_SIZE, num_warps = calculate_settings(n_cols) + _dyt_fwd_kernel[(n_rows,)]( + x_ptr=x, + alpha_ptr=alpha, + gamma_ptr=gamma, + beta_ptr=beta, + y_ptr=y, + x_row_stride=x.stride(0), + y_row_stride=y.stride(0), + 
n_cols=n_cols, + BLOCK_SIZE=BLOCK_SIZE, + num_warps=num_warps, + ) + return y.view(*shape) + + +def liger_dyt_bwd(dy, x, alpha, gamma): + shape = dy.shape + dtype = x.dtype + dim = shape[-1] + dy = dy.view(-1, dim) + x = x.view(-1, dim) + n_rows, n_cols = dy.shape + BLOCK_SIZE, num_warps = calculate_settings(n_cols) + sm_count = 1 + device = infer_device() + if device == "cuda": + sm_count = torch.cuda.get_device_properties(x.device).multi_processor_count + elif device == "xpu": + sm_count = torch.xpu.get_device_properties(x.device).gpu_subslice_count + if n_cols > BLOCK_SIZE: + raise RuntimeError( + f"Feature dimension {dim} exceeds maximum supported size of {BLOCK_SIZE}. Consider using a smaller feature dimension." + ) + + dx = torch.empty_like(x, dtype=torch.float32) + _dalpha = torch.empty((sm_count,), dtype=torch.float32, device=x.device) + _dgamma = torch.empty((sm_count, n_cols), dtype=torch.float32, device=x.device) + + grid = (sm_count,) + rows_per_program = triton.cdiv(n_rows, sm_count) + _dyt_bwd_kernel[grid]( + x_ptr=x, + x_row_stride=x.stride(0), + dy_ptr=dy, + dy_row_stride=dy.stride(0), + dx_ptr=dx, + dx_row_stride=dx.stride(0), + alpha_ptr=alpha, + dalpha_ptr=_dalpha, + gamma_ptr=gamma, + dgamma_ptr=_dgamma, + dgamma_row_stride=_dgamma.stride(0), + n_cols=n_cols, + n_rows=n_rows, + ROWS_PER_PROGRAM=rows_per_program, + BLOCK_SIZE=BLOCK_SIZE, + num_warps=num_warps, + ) + dalpha = _dalpha.sum(dim=0, keepdim=True).to(dtype) + dgamma = _dgamma.sum(dim=0).to(dtype) + dbeta = dy.sum(dim=0).to(dtype) + return dx.view(*shape), dalpha, dgamma, dbeta + + +class LigerDyTFunction(torch.autograd.Function): + @staticmethod + @ensure_contiguous + def forward(ctx, x, alpha, gamma, beta): + y = liger_dyt_fwd(x, alpha, gamma, beta) + ctx.save_for_backward(x, alpha, gamma) + return y + + @staticmethod + @ensure_contiguous + def backward(ctx, grad_output): + x, alpha, gamma = ctx.saved_tensors + dx, dalpha, dgamma, dbeta = liger_dyt_bwd( + grad_output, + x, + alpha, + 
gamma, + ) + + return (dx, dalpha, dgamma, dbeta) \ No newline at end of file diff --git a/build/torch-cuda/fused_linear_cross_entropy.py b/build/torch-cuda/fused_linear_cross_entropy.py new file mode 100644 index 0000000000000000000000000000000000000000..a923c05e268872a75ae473f23c7e64df491c866b --- /dev/null +++ b/build/torch-cuda/fused_linear_cross_entropy.py @@ -0,0 +1,283 @@ +import torch +import triton + +from .cross_entropy import liger_cross_entropy_kernel +from .utils import amp_custom_bwd +from .utils import amp_custom_fwd +from .utils import element_mul_kernel +from .utils import is_hip + +# The hard limit of TRITON_MAX_TENSOR_NUMEL is 1048576 https://github.com/triton-lang/triton/blob/ba42a5c68fd0505f8c42f4202d53be0f8d9a5fe0/python/triton/language/core.py#L19 +# However, setting limit as 65536 as in LayerNorm tutorial is faster because of less register spilling +# The optimal maximum block size depends on your hardware, your kernel, and your dtype +MAX_FUSED_SIZE = 65536 // 2 + + +def fused_linear_cross_entropy_forward( + _input, + weight, + target, + ce_weight=None, + bias=None, + ignore_index=-100, + lse_square_scale=0.0, + label_smoothing=0.0, + reduction="mean", + softcap=None, + return_z_loss=False, +): + assert isinstance(return_z_loss, bool), f"return_z_loss must be True or False. Got: {return_z_loss}" + device = _input.device + + # inputs have shape: BT x H + # materialized activations will have shape: BT x V + # the increase in memory = BT x V + # reduction can be achieved by partitioning the number of tokens BT into smaller chunks. 
+ # for ex: if we were to achieve the same memory consumption as BT x H, then the chunk size should be: + # inc_factor = (V+H-1)//H, chunk_size = (BT + inc_factor - 1)//inc_factor + # for ex: BT = 4096*4, V = 32000, H = 4096 ==> inc_factor = 8, chunk_size = 2048 + BT, H = _input.shape + V = weight.shape[0] + BLOCK_SIZE = min(MAX_FUSED_SIZE, triton.next_power_of_2(V)) + + inc_factor = triton.cdiv(V, H) # (V + H - 1) // H + chunk_size = triton.next_power_of_2(triton.cdiv(BT, inc_factor)) # (BT + inc_factor - 1) // inc_factor + num_chunks = triton.cdiv(BT, chunk_size) # (BT + chunk_size - 1) // chunk_size + + grad_weight = torch.zeros_like(weight, device=device) if weight.requires_grad else None + grad_input = torch.zeros_like(_input, device=device) + grad_bias = torch.zeros_like(bias, device=device) if bias is not None else None + # we use fp32 for loss accumulator + loss_1d = torch.zeros(BT, dtype=torch.float32, device=device) + z_loss_1d = torch.zeros(BT, dtype=_input.dtype, device=_input.device) if return_z_loss else None + + # TODO: evaluate how CUDA synchronization caused by .item() affects the speed + target_mask = target != ignore_index + total_n_non_ignore = target_mask.sum().item() + total_sum_non_ignore_ce_weight = total_n_non_ignore + ce_weight_sum = 0.0 + if ce_weight is not None: + assert ce_weight.shape[0] == V, f"If given, weight has to be a Tensor of size V. Got: {ce_weight.shape}" + assert torch.is_floating_point(ce_weight), ( + f"If given, weight has to be a Tensor of floating point dtype. 
Got: {ce_weight.dtype}" + ) + total_sum_non_ignore_ce_weight = ( + torch.gather(ce_weight, dim=0, index=target.masked_select(target_mask)).sum().item() + ) + ce_weight_sum = ce_weight.sum().item() + if ce_weight.stride(-1) != 1: + ce_weight = ce_weight.contiguous() + + for chunk_id in range(num_chunks): + start_idx = chunk_id * chunk_size + end_idx = min((chunk_id + 1) * chunk_size, BT) + _input_chunk = _input[start_idx:end_idx] # chunk_size x H + + # when doing matmul, use the original precision + logits_chunk = _input_chunk @ weight.t() # chunk_size x V + if bias is not None: + logits_chunk = logits_chunk + bias + + target_chunk = target[start_idx:end_idx] # chunk_size, + + n_rows = logits_chunk.shape[0] + + # unreduced loss + loss_1d_slice = loss_1d[start_idx:end_idx] # chunk_size, + z_loss_1d_slice = z_loss_1d[start_idx:end_idx] if return_z_loss else None + + # ensure _input and target are contiguous + logits_chunk = logits_chunk.contiguous() + target_chunk = target_chunk.contiguous() + + # Here we calculate the gradient of logits_chunk in place so we can save memory. 
+ liger_cross_entropy_kernel[(n_rows,)]( + X_ptr=logits_chunk, + X_stride=logits_chunk.stride(-2), + Y_ptr=target_chunk, + Y_stride=target_chunk.stride(-1), # always 1 + weight_ptr=ce_weight, + loss_ptr=loss_1d_slice, + z_loss_ptr=z_loss_1d_slice, + loss_stride=loss_1d_slice.stride(-1), # always 1 + n_cols=V, + n_non_ignore=total_n_non_ignore, + sum_non_ignore_weight=total_sum_non_ignore_ce_weight, + weight_sum=ce_weight_sum, + ignore_index=ignore_index, + lse_square_scale=lse_square_scale, + label_smoothing=label_smoothing, + reduction=reduction, + softcap=softcap, + RETURN_Z_LOSS=return_z_loss, + HAS_WEIGHT=True if ce_weight is not None else False, + HAS_SOFTCAPPING=True if softcap is not None else False, + BLOCK_SIZE=BLOCK_SIZE, + num_warps=32 if not is_hip() else 16, + ) + + loss_1d[start_idx:end_idx] = loss_1d_slice + if return_z_loss: + z_loss_1d[start_idx:end_idx] = z_loss_1d_slice + grad_logits_chunk = logits_chunk # chunk_size x V + + grad_input[start_idx:end_idx] = grad_logits_chunk @ weight + + if grad_weight is not None: + torch.addmm( + input=grad_weight, + mat1=logits_chunk.t().to( + _input_chunk.dtype + ), # In an autocast scenario without bias, differing logits_chunk data types will cause an addmm operation error. + mat2=_input_chunk, + out=grad_weight, + alpha=1.0, + beta=1.0, + ) + + if bias is not None: + torch.add( + input=grad_bias, + other=logits_chunk.sum(dim=0), + out=grad_bias, + alpha=1.0, + ) + + # Need extra calculations for backward if reduction=='none'. Not supporting reduction='none' now. + # if reduction == "none": + # loss = loss_1d + # z_loss = z_loss_1d if return_z_loss else None + + else: + loss = torch.sum(loss_1d) + z_loss = torch.sum(z_loss_1d) if return_z_loss else None + return loss, z_loss, grad_input, grad_weight, grad_bias + + +def fused_linear_cross_entropy_backward(grad_output, grad_input, grad_weight, grad_bias): + # If cross entropy is the last layer, grad_output is 1.0. 
    # If cross entropy is the last layer, grad_output is 1.0. Skip the mul to save time.
    if not torch.equal(grad_output, torch.tensor(1.0, device=grad_output.device)):
        # We use a Triton kernel instead of a PyTorch operation because modifying inputs in-place
        # for gradient storage and backward multiple times causes anomalies with PyTorch but not with Triton.
        BT, H = grad_input.shape
        n_rows = BT
        BLOCK_SIZE = min(MAX_FUSED_SIZE, triton.next_power_of_2(H))

        # Scale dL/dX in place by grad_output, one Triton program per row.
        element_mul_kernel[(n_rows,)](
            grad_input,
            grad_input.stride(-2),
            grad_output,
            H,
            BLOCK_SIZE=BLOCK_SIZE,
            num_warps=32 if not is_hip() else 16,
        )

        # handle grad_weight
        if grad_weight is not None:
            V, H = grad_weight.shape
            n_rows = V

            element_mul_kernel[(n_rows,)](
                grad_weight,
                grad_weight.stride(-2),
                grad_output,
                H,
                BLOCK_SIZE=BLOCK_SIZE,
                num_warps=32 if not is_hip() else 16,
            )

        if grad_bias is not None:
            V = grad_bias.shape[0]
            n_rows = V

            # grad_bias is 1-D, so each program scales a single element (row length 1).
            element_mul_kernel[(n_rows,)](
                grad_bias,
                grad_bias.stride(-1),
                grad_output,
                1,
                BLOCK_SIZE=BLOCK_SIZE,
                num_warps=32 if not is_hip() else 16,
            )
    return grad_input, grad_weight, grad_bias


class LigerFusedLinearCrossEntropyFunction(torch.autograd.Function):
    @staticmethod
    @amp_custom_fwd
    def forward(
        ctx,
        _input,
        weight,
        target,
        bias=None,
        ce_weight=None,
        ignore_index=-100,
        lse_square_scale=0.0,
        label_smoothing=0.0,
        reduction="mean",
        softcap=None,
        return_z_loss: bool = False,
    ):
        """
        Fusing the last linear layer with cross-entropy loss
        Reference: https://github.com/mgmalek/efficient_cross_entropy

        Handle the forward and backward pass of the final linear layer via cross-entropy loss by avoiding
        the materialization of the large logits tensor. Since Cross Entropy Loss is the last layer, we can
        compute the gradient at the forward pass. By doing so, we don't have to store the _input and target
        for the backward pass.

        _input: (B*T, H) where B is batch size, T is sequence length, H is hidden dimension.
        target: (B*T) where each value is in [0, V-1]
        weight: (V, H) where V is the number of classes
        bias: (V) where V is the number of classes
        ce_weight: a manual rescaling weight given to each class. If given, has to be a Tensor of size V and floating point dtype
        ignore_index: the index to ignore in the target
        label_smoothing (float): The amount of smoothing when computing the loss, where 0.0 means no smoothing.
        reduction: reduction to apply
        """

        loss, z_loss, grad_input, grad_weight, grad_bias = fused_linear_cross_entropy_forward(
            _input=_input,
            weight=weight,
            target=target,
            bias=bias,
            ce_weight=ce_weight,
            ignore_index=ignore_index,
            lse_square_scale=lse_square_scale,
            label_smoothing=label_smoothing,
            reduction=reduction,
            softcap=softcap,
            return_z_loss=return_z_loss,
        )
        # downcast to dtype and store for backward
        ctx.save_for_backward(
            grad_input.detach(),
            grad_weight.detach() if grad_weight is not None else None,
            grad_bias.detach() if bias is not None else None,
        )
        ctx.return_z_loss = return_z_loss
        return loss, z_loss

    @staticmethod
    @amp_custom_bwd
    def backward(ctx, grad_output, grad_output2):
        # Gradients were already computed in forward; here we only rescale them by grad_output.
        if ctx.return_z_loss:
            del grad_output2  # z_loss is only for logging
        (grad_input, grad_weight, grad_bias) = ctx.saved_tensors
        grad_input, grad_weight, grad_bias = fused_linear_cross_entropy_backward(
            grad_output, grad_input, grad_weight, grad_bias
        )
        # One gradient slot per forward argument (11 args after ctx); non-differentiable args get None.
        return (
            grad_input,
            grad_weight,
            None,
            grad_bias,
            None,
            None,
            None,
            None,
            None,
            None,
            None,
        )


# ==== build/torch-cuda/geglu.py ====
import operator

import torch
import triton
import triton.language as tl

from .utils import calculate_settings
from .utils import compare_version
from .utils import ensure_contiguous
from .utils import ensure_contiguous

if compare_version("triton", operator.ge, "3.0.0"):
    try:
        # typical import path with dispatch available
        from triton.language.extra.libdevice import tanh
    except ModuleNotFoundError:
        # for working with NGC containers
        from triton.language.extra.cuda.libdevice import tanh
else:
    from triton.language.math import tanh


@triton.jit
def _geglu_tanh_forward_kernel(a, b, c, stride, n_cols: tl.constexpr, BLOCK_SIZE: tl.constexpr):
    """Compute one row of GEGLU: c = gelu_tanh(a) * b. One program per row."""
    program_id = tl.program_id(0).to(tl.int64)

    # locate start index
    a += program_id * stride
    b += program_id * stride
    c += program_id * stride

    col_offsets = tl.arange(0, BLOCK_SIZE)
    mask = col_offsets < n_cols
    # gate activations are computed in fp32 for numerical stability
    a_row = tl.load(a + col_offsets, mask=mask, other=0).to(tl.float32)
    b_row = tl.load(b + col_offsets, mask=mask, other=0)

    # tanh approximation form of GELU is computed with:
    # 0.5 * a * (1 + tanh(sqrt(2 / pi) * (a + 0.044715 * a^3)))
    sqrt_2_over_pi = 0.7978845608028654  # sqrt(2 / pi)
    a_cubed = a_row * a_row * a_row
    tanh_arg = sqrt_2_over_pi * (a_row + 0.044715 * a_cubed)
    tanh_result = tanh(tanh_arg)
    geglu_a = 0.5 * a_row * (1 + tanh_result)
    c_row = geglu_a * b_row
    tl.store(c + col_offsets, c_row, mask=mask)


@triton.jit
def _geglu_tanh_backward_kernel(dc, a, b, stride, n_cols: tl.constexpr, BLOCK_SIZE: tl.constexpr):
    """Compute da and db for one row, writing gradients IN PLACE over a and b."""
    program_id = tl.program_id(0).to(tl.int64)

    # locate start index
    dc += program_id * stride
    a += program_id * stride
    b += program_id * stride

    col_offsets = tl.arange(0, BLOCK_SIZE)
    mask = col_offsets < n_cols

    dc_row = tl.load(dc + col_offsets, mask=mask, other=0)
    a_row = tl.load(a + col_offsets, mask=mask, other=0).to(tl.float32)
    b_row = tl.load(b + col_offsets, mask=mask, other=0)

    # recomputation to save memory
    sqrt_2_over_pi = 0.7978845608028654  # sqrt(2 / pi)
    a_cubed = a_row * a_row * a_row
    tanh_arg = sqrt_2_over_pi * (a_row + 0.044715 * a_cubed)
    tanh_result = tanh(tanh_arg)
    geglu_a = 0.5 * a_row * (1 + tanh_result)

    db_row = dc_row * geglu_a

    # Gradient w.r.t. a can be computed with:
    # b * (0.5 * (1 + tanh(z)) + 0.5 * a * (1 - tanh(z)^2) * (sqrt(2/pi) * (1 + 3 * 0.044715 * a^2)))
    # where z = sqrt(2/pi) * (a + 0.044715 * a^3)
    term1 = 0.5 * (1 + tanh_result)
    tanh_sq = tanh_result * tanh_result
    term2 = 0.5 * a_row * (1 - tanh_sq) * (sqrt_2_over_pi * (1 + 3 * 0.044715 * a_row * a_row))
    da_row = dc_row * b_row * (term1 + term2)

    # gradients overwrite the saved activations (in-place) — callers rely on this
    tl.store(a + col_offsets, da_row, mask=mask)
    tl.store(b + col_offsets, db_row, mask=mask)


def geglu_forward(a, b):
    """Launch the GEGLU forward kernel on (…, n_cols) inputs; returns flattened a, b and shaped c."""
    ori_shape = a.shape

    n_cols = ori_shape[-1]
    a = a.view(-1, n_cols)
    b = b.view(-1, n_cols)
    c = torch.empty_like(a)
    n_rows = a.shape[0]

    BLOCK_SIZE, num_warps = calculate_settings(n_cols)

    _geglu_tanh_forward_kernel[(n_rows,)](
        a,
        b,
        c,
        c.stride(-2),
        n_cols=n_cols,
        BLOCK_SIZE=BLOCK_SIZE,
        num_warps=num_warps,
    )
    # a and b are returned flattened so backward can reuse them without reshaping
    return a, b, c.view(*ori_shape)


def geglu_backward(a, b, dc):
    """Launch the GEGLU backward kernel; a and b are overwritten with da and db in place."""
    ori_shape = dc.shape
    n_cols = ori_shape[-1]
    dc = dc.view(-1, n_cols)
    n_rows = dc.shape[0]

    BLOCK_SIZE, num_warps = calculate_settings(n_cols)

    _geglu_tanh_backward_kernel[(n_rows,)](
        dc,
        a,
        b,
        dc.stride(-2),
        n_cols=n_cols,
        BLOCK_SIZE=BLOCK_SIZE,
        num_warps=num_warps,
    )

    return a.view(*ori_shape), b.view(*ori_shape)


class LigerGELUMulFunction(torch.autograd.Function):
    """Autograd wrapper for fused GEGLU (tanh-approximated GELU(a) * b)."""

    @staticmethod
    @ensure_contiguous
    def forward(ctx, a, b):
        a, b, c = geglu_forward(a, b)
        ctx.save_for_backward(a, b)
        return c

    @staticmethod
    @ensure_contiguous
    def backward(ctx, dc):
        a, b = ctx.saved_tensors
        a, b = geglu_backward(a, b, dc)
        return a, b


# ==== build/torch-cuda/group_norm.py ====
import operator

import torch
import triton
import triton
import triton.language as tl

from .utils import compare_version
from .utils import ensure_contiguous

if compare_version("triton", operator.ge, "3.0.0"):
    try:
        # typical import path with dispatch available
        from triton.language.extra.libdevice import rsqrt
    except ModuleNotFoundError:
        # for working with NGC containers
        from triton.language.extra.cuda.libdevice import rsqrt
else:
    from triton.language.math import rsqrt

MAX_FUSED_SIZE = 65536


@triton.jit
def _group_norm_forward_kernel(
    Y_ptr,  # pointer to output, shape (n_rows, n_groups, hidden_size)
    Y_row_stride,  # stride of each row in output
    Y_col_stride,  # stride of each column in output
    X_ptr,  # pointer to input, shape (n_rows, n_groups, hidden_size)
    X_row_stride,  # stride of each row in input
    X_col_stride,  # stride of each column in input
    Mean_ptr,  # pointer to mean, shape (n_rows, n_groups)
    Mean_row_stride,  # stride of each row in mean
    Mean_col_stride,  # stride of each column in mean
    RSTD_ptr,  # pointer to rstd, shape (n_rows, n_groups)
    RSTD_row_stride,  # stride of each row in rstd
    RSTD_col_stride,  # stride of each column in rstd
    W_ptr,  # pointer to W
    B_ptr,  # pointer to B
    hidden_size,  # hidden size of X
    channels_per_group,  # the number of channels per group
    eps,
    BLOCK_SIZE: tl.constexpr,
):
    """
    References:
    https://nn.labml.ai/normalization/group_norm/index.html
    """
    # One program per (batch row, group): grid is (batch_size, n_groups).
    batch_idx = tl.program_id(0)
    group_idx = tl.program_id(1)

    X_ptr += batch_idx * X_row_stride + group_idx * X_col_stride
    Y_ptr += batch_idx * Y_row_stride + group_idx * Y_col_stride

    block_range = tl.arange(0, BLOCK_SIZE)

    # Compute mean and variance using the online algorithm
    s = 0.0
    squared_sum = 0.0
    for i in tl.range(0, hidden_size, BLOCK_SIZE):
        hidden_size_offsets = i + block_range
        mask = hidden_size_offsets < hidden_size
        X = tl.load(X_ptr + hidden_size_offsets, mask=mask, other=0.0)
        s += tl.sum(X)
        # X**2
        squared_sum += tl.sum(X * X)

    m = s / hidden_size

    # variance = E[X**2] - E[X]**2
    variance = (squared_sum / hidden_size) - (m * m)

    # 1/std
    rstd = rsqrt(variance + eps)

    # Normalize
    hidden_size_per_channel = hidden_size // channels_per_group
    for channel_idx in tl.range(group_idx * channels_per_group, (group_idx + 1) * channels_per_group):
        # per-channel affine parameters
        W = tl.load(W_ptr + channel_idx)
        B = tl.load(B_ptr + channel_idx)
        for i in range(0, hidden_size_per_channel, BLOCK_SIZE):
            hidden_size_offsets = i + block_range
            mask = hidden_size_offsets < hidden_size_per_channel
            # other=m makes masked lanes normalize to exactly 0
            X = tl.load(X_ptr + hidden_size_offsets, mask=mask, other=m)
            Y = (X - m) * rstd * W + B
            tl.store(Y_ptr + hidden_size_offsets, Y, mask=mask)

        X_ptr += hidden_size_per_channel
        Y_ptr += hidden_size_per_channel

    # stash statistics for the backward pass
    tl.store(Mean_ptr + batch_idx * Mean_row_stride + group_idx * Mean_col_stride, m)
    tl.store(RSTD_ptr + batch_idx * RSTD_row_stride + group_idx * RSTD_col_stride, rstd)


@triton.jit
def _group_norm_backward_kernel(
    X_ptr,  # pointer to input, shape (n_rows, n_channels, hidden_size)
    X_row_stride,  # stride of each row in input
    X_col_stride,  # stride of each column in input
    W_ptr,  # pointer to weights, shape (n_channels)
    Mean_ptr,  # pointer to mean, shape (n_rows, n_groups)
    Mean_ptr_row_stride,  # stride of each column in mean
    Mean_ptr_col_stride,  # stride of each column in mean
    RSTD_ptr,  # pointer to rstd, shape (n_rows, n_groups)
    DX_ptr,  # pointer to input grad, shape (n_rows, n_groups, hidden_size)
    DW_ptr,  # pointer to weights grad, shape (n_channels)
    DB_ptr,  # pointer to bias grad, shape (n_channels)
    UPSTREAM_ptr,  # pointer to output grad, shape (n_rows, n_channels, hidden_size)
    hidden_size: tl.constexpr,  # hidden size
    channels_per_group: tl.constexpr,  # number of groups in group norm
    BLOCK_SIZE: tl.constexpr,
    dtype: tl.constexpr,
):
    """
    References:
    https://nn.labml.ai/normalization/group_norm/index.html
    https://github.com/karpathy/llm.c/blob/master/doc/layernorm/layernorm.md

    The backprop equations are the same for group_norm and layer_norm
    the only difference here is that we load the Mean, Rstd corresponding to the
    group we're computing gradients for and the mean and rstd are computed over n-channels
    so the total number of elements we compute the mean over is num_channels_per_group * hidden_size

    We also need to load the Weights corresponding to the current channel to compute the gradients.
    """
    batch_idx = tl.program_id(0)
    group_idx = tl.program_id(1)

    # Move the pointers to the correct batch
    X_ptr += batch_idx * X_row_stride
    DX_ptr += batch_idx * X_row_stride
    UPSTREAM_ptr += batch_idx * X_row_stride

    # Mean and rstd are the same shape so have the same strides
    mean = tl.load(Mean_ptr + batch_idx * Mean_ptr_row_stride + group_idx * Mean_ptr_col_stride)
    rstd = tl.load(RSTD_ptr + batch_idx * Mean_ptr_row_stride + group_idx * Mean_ptr_col_stride)

    c1 = 0.0
    c2 = 0.0
    block_range = tl.arange(0, BLOCK_SIZE)

    # We need to compute the sum terms of the backprop equations across all channels in the group
    for channel_idx in range(group_idx * channels_per_group, (group_idx + 1) * channels_per_group):
        dW = 0.0
        dB = 0.0
        # Move the pointers to the correct channel
        W = tl.load(W_ptr + channel_idx)
        for i in tl.range(0, hidden_size, BLOCK_SIZE):
            hidden_size_offsets = i + block_range
            mask = hidden_size_offsets < hidden_size
            X = tl.load(
                X_ptr + channel_idx * X_col_stride + hidden_size_offsets,
                mask=mask,
                other=0.0,
            )
            UPSTREAM_grad = tl.load(
                UPSTREAM_ptr + channel_idx * X_col_stride + hidden_size_offsets,
                mask=mask,
                other=0.0,
            )

            x_hat = (X - mean) * rstd
            dW += tl.sum(UPSTREAM_grad * x_hat)
            dB += tl.sum(UPSTREAM_grad)

            wdy = W * UPSTREAM_grad
            c1 += tl.sum(x_hat * wdy)
            c2 += tl.sum(wdy)

        # Need to ensure additions to the same channel are atomic
        tl.atomic_add(DW_ptr + channel_idx, dW.to(dtype))
        tl.atomic_add(DB_ptr + channel_idx, dB.to(dtype))

    N = hidden_size * channels_per_group
    c1 = c1 / N
    c2 = c2 / N

    # Second pass: apply the layer-norm-style dx formula using the group-wide sums
    for channel_idx in tl.range(group_idx * channels_per_group, (group_idx + 1) * channels_per_group):
        # Move the pointers to the correct channel
        W = tl.load(W_ptr + channel_idx)
        for i in range(0, hidden_size, BLOCK_SIZE):
            hidden_size_offsets = i + block_range
            mask = hidden_size_offsets < hidden_size
            X = tl.load(
                X_ptr + channel_idx * X_col_stride + hidden_size_offsets,
                mask=mask,
                other=0.0,
            )
            UPSTREAM_grad = tl.load(
                UPSTREAM_ptr + channel_idx * X_col_stride + hidden_size_offsets,
                mask=mask,
                other=0.0,
            )

            x_hat = (X - mean) * rstd
            wdy = W * UPSTREAM_grad
            dx = (wdy - (x_hat * c1 + c2)) * rstd
            tl.store(DX_ptr + channel_idx * X_col_stride + hidden_size_offsets, dx, mask=mask)


def group_norm_forward(X, num_channels, num_groups, W, B, eps):
    """Launch the group-norm forward kernel; returns (Y, grouped X, Mean, RSTD, BLOCK_SIZE)."""
    shape = X.shape
    batch_size = shape[0]
    channels_per_group = num_channels // num_groups
    # Reshape X so that the mean and std are computed across the groups
    X = X.view(batch_size, num_groups, -1).contiguous()
    hidden_size = X.shape[-1]
    BLOCK_SIZE = min(MAX_FUSED_SIZE, triton.next_power_of_2(hidden_size))
    Y = torch.empty((batch_size, num_groups, hidden_size), dtype=X.dtype, device=X.device)
    Mean = torch.zeros((batch_size, num_groups), dtype=X.dtype, device=X.device)
    RSTD = torch.zeros((batch_size, num_groups), dtype=X.dtype, device=X.device)

    _group_norm_forward_kernel[(batch_size, num_groups)](
        Y,
        Y.stride(0),
        Y.stride(1),
        X,
        X.stride(0),
        X.stride(1),
        Mean,
        Mean.stride(0),
        Mean.stride(1),
        RSTD,
        RSTD.stride(0),
        RSTD.stride(1),
        W,
        B,
        hidden_size,
        channels_per_group,
        eps,
        BLOCK_SIZE=BLOCK_SIZE,
    )
    # Return tensors in the original shape
    return Y.view(*shape), X.view(*shape), Mean, RSTD, BLOCK_SIZE


def group_norm_backward(dY, X, W, B, Mean, RSTD, num_channels, num_groups):
    """Launch the group-norm backward kernel; returns (DX, DW, DB)."""
    shape = dY.shape
    batch_size = shape[0]
    hidden_size = dY.shape[-1]
    channels_per_group = num_channels // num_groups
    dY = dY.view(batch_size, num_groups, -1)
    DX = torch.empty(
        (batch_size, num_groups, hidden_size * channels_per_group),
        dtype=X.dtype,
        device=X.device,
    )
    # DW/DB are accumulated with atomic adds, so they must start zeroed
    DW = torch.zeros((num_channels), dtype=W.dtype, device=W.device)
    DB = torch.zeros((num_channels), dtype=B.dtype, device=B.device)
    # NOTE(review): assumes X is fp32 or bf16 — fp16 inputs would silently be treated as bf16; confirm upstream invariant
    triton_dtype = tl.float32 if X.dtype == torch.float32 else tl.bfloat16

    BLOCK_SIZE = min(MAX_FUSED_SIZE, triton.next_power_of_2(hidden_size))
    _group_norm_backward_kernel[(batch_size, num_groups)](
        X,
        X.stride(0),
        X.stride(1),
        W,
        Mean,
        Mean.stride(0),
        Mean.stride(1),
        RSTD,
        DX,
        DW,
        DB,
        dY,
        hidden_size,
        channels_per_group,
        BLOCK_SIZE=BLOCK_SIZE,
        dtype=triton_dtype,
    )

    # Return tensors in the original shape
    return DX.view(*shape), DW, DB


class LigerGroupNormFunction(torch.autograd.Function):
    """Autograd wrapper around the fused Triton group-norm forward/backward."""

    @staticmethod
    @ensure_contiguous
    def forward(
        ctx,
        X,
        affine_scaling_weight,
        affine_shifting_bias,
        num_channels,
        num_groups,
        eps,
    ):
        Y, X, Mean, RSTD, BLOCK_SIZE = group_norm_forward(
            X,
            num_channels,
            num_groups,
            affine_scaling_weight,
            affine_shifting_bias,
            eps,
        )
        ctx.num_channels = num_channels
        ctx.num_groups = num_groups
        ctx.save_for_backward(X, affine_scaling_weight, affine_shifting_bias, Mean, RSTD)
        return Y

    @staticmethod
    @ensure_contiguous
    def backward(ctx, dY):
        X, W, B, Mean, RSTD = ctx.saved_tensors
        DX, DW, DB = group_norm_backward(dY, X, W, B, Mean, RSTD, ctx.num_channels, ctx.num_groups)
        # None for the non-tensor forward args (num_channels, num_groups, eps)
        return DX, DW, DB, None, None, None


# ==== build/torch-cuda/jsd.py ====
from typing import Optional

import torch
import triton
import triton.language as tl
from .utils import ensure_contiguous
from .utils import infer_device


@triton.jit
def _jsd_kernel(
    X_ptr,  # input in logspace, X = log Q
    X_stride,
    Y_ptr,  # ground truth in logspace, Y = log P
    Y_stride,
    loss_ptr,
    loss_stride,
    dX_ptr,
    dX_stride,
    label_ptr,
    beta: tl.constexpr,
    n_non_ignore: int,
    ignore_index: tl.constexpr,
    n_cols,
    BLOCK_SIZE: tl.constexpr,
    HAS_LABEL: tl.constexpr,
):
    # JSD(P || Q) = (KL(P || M) + KL(Q || M)) / 2, M = (1/2) * (P + Q) = (1/2) * (e ^ Y + e ^ X)
    #             = sum(P * log P + Q * log Q - 2 * M * log M) / 2
    #             = sum(e ^ Y * Y + e ^ X * X - 2 * M * log M) / 2
    # grad_x_i = 0.5 * Q * (X - log_M)
    pid = tl.program_id(0).to(tl.int64)
    X_ptr += pid * X_stride
    dX_ptr += pid * dX_stride
    Y_ptr += pid * Y_stride
    loss_ptr += pid * loss_stride
    label_ptr += pid

    if HAS_LABEL:
        label = tl.load(label_ptr)
        if label == ignore_index:
            # ignored row: zero its gradient and contribute nothing to the loss
            for i in range(0, n_cols, BLOCK_SIZE):
                offsets = i + tl.arange(0, BLOCK_SIZE)
                tl.store(dX_ptr + offsets, 0.0, mask=offsets < n_cols)
            return

    for i in range(0, n_cols, BLOCK_SIZE):
        offsets = i + tl.arange(0, BLOCK_SIZE)
        mask = offsets < n_cols
        X = tl.load(X_ptr + offsets, mask=mask, other=float("-inf")).to(tl.float32)
        Y = tl.load(Y_ptr + offsets, mask=mask, other=float("-inf")).to(tl.float32)

        # beta is constexpr, so exactly one branch is compiled per specialization
        if beta == 0.0:  # forward KL
            Y_max = tl.max(Y, axis=0)
            Y_shifted = Y - Y_max
            Y_prob = tl.exp(Y_shifted) * tl.exp(Y_max)  # Compensate for the shift
            loss = Y_prob * (Y - X)
            dX = -Y_prob
        elif beta == 1.0:  # reverse KL
            X_max = tl.max(X, axis=0)
            X_shifted = X - X_max
            X_prob = tl.exp(X_shifted) * tl.exp(X_max)  # Compensate for the shift
            loss = X_prob * (X - Y)
            dX = loss + X_prob
        else:
            max_val = tl.maximum(tl.max(X, axis=0), tl.max(Y, axis=0))
            X_shifted = X - max_val
            Y_shifted = Y - max_val

            # Pre-compute exp(max_val) since it's used twice
            exp_max = tl.exp(max_val)

            # Compute exp terms with compensation
            Q = tl.exp(X_shifted) * exp_max  # = exp(X)
            P = tl.exp(Y_shifted) * exp_max  # = exp(Y)

            # Pre-compute common terms
            beta_P = beta * P
            one_minus_beta_Q = (1 - beta) * Q
            M = beta_P + one_minus_beta_Q
            log_M = tl.log(M)  # No need to compensate as M is already in original scale

            loss = beta_P * Y + one_minus_beta_Q * X - M * log_M
            dX = one_minus_beta_Q * (X - log_M)

        # Pre-compute scaling factor
        scale = 1.0 / n_non_ignore
        loss = loss * scale
        dX = dX * scale

        tl.store(loss_ptr + offsets, loss, mask=mask)
        tl.store(dX_ptr + offsets, dX, mask=mask)


MAX_FUSED_SIZE = 4096 if infer_device() == "xpu" else 65536


def jsd_forward(_input, target, shift_labels, beta, ignore_index, has_label):
    """Launch the JSD kernel over (BT, V) log-space inputs; returns (scalar loss, per-element dX)."""
    BT, V = _input.shape
    n_rows = BT
    BLOCK_SIZE = min(MAX_FUSED_SIZE, triton.next_power_of_2(V))
    # non reduction loss
    loss = torch.zeros(_input.shape, dtype=torch.float32, device=_input.device)
    dX = torch.empty_like(_input)

    if has_label:
        n_non_ignore = (shift_labels != ignore_index).sum().item()
    else:
        n_non_ignore = BT

    _jsd_kernel[(n_rows,)](
        X_ptr=_input,  # input in logspace, X = log Q
        X_stride=_input.stride(-2),
        Y_ptr=target,  # ground truth in logspace, Y = log P
        Y_stride=target.stride(-2),
        loss_ptr=loss,
        loss_stride=loss.stride(-2),
        dX_ptr=dX,
        dX_stride=dX.stride(-2),
        label_ptr=(shift_labels if has_label else torch.empty(1, device=_input.device)),  # dummy ptr if no label
        beta=beta,
        n_non_ignore=n_non_ignore,
        ignore_index=ignore_index,
        n_cols=V,
        BLOCK_SIZE=BLOCK_SIZE,
        HAS_LABEL=has_label,
    )

    loss = torch.sum(loss)
    return loss.to(_input.dtype), dX


def jsd_backward(dX, grad_output):
    # If jsd is the last layer, grad_output is 1.0. Skip the mul to save time
    if torch.equal(grad_output, torch.tensor(1.0, device=grad_output.device)):
        return dX
    else:
        return grad_output * dX


class LigerJSDFunction(torch.autograd.Function):
    r"""
    This class implements the forward and backward pass for the generalized Jensen-Shannon Divergence.
    .. math::
        JSD(\beta)(P || Q)
            = \beta * KLDiv(P || (\beta * P + (1 - \beta) * Q)) + (1 - \beta) * KLDiv(Q || (\beta * P + (1 - \beta) * Q))

    .. note::
        As all the other losses in PyTorch, this function expects the first argument,
        :attr:`_input`, to be the predictions, the output of the student model, in log-space
        and the second, :attr:`target`, to be the observations, the output of the teacher model, in log-space.
        This differs from the standard mathematical notation :math:`JSD(P || Q)` where
        :math:`P` denotes the teacher model and :math:`Q` denotes the student model.
    """

    @staticmethod
    @ensure_contiguous
    def forward(
        ctx,
        _input: torch.Tensor,
        target: torch.Tensor,
        shift_labels: Optional[torch.Tensor] = None,
        beta: float = 0.5,
        ignore_index: int = -100,
    ) -> torch.Tensor:
        """
        Args:
            _input (torch.Tensor): predict values with shape (BT, V) in logspace
            target (torch.Tensor): ground truth values with shape (BT, V) in logspace
            shift_labels (Optional[torch.LongTensor]): indicator of next predicted vocab with shape (BT) where each value is in [0, V-1].
            beta (float): coefficient beta of generalized JSD in the interval [0, 1]. It implements forward/reverse KL when beta equals 0 and 1 respectively. Default: `0.5`
            ignore_index (int): the index to ignore. Default: -100

        Returns:
            loss (torch.Tensor): generalized JSD
        """
        has_label = False
        if shift_labels is not None:
            assert shift_labels.shape == (_input.shape[0],), (
                f"the shape of shift_labels must be (BT,). Got: {shift_labels.shape}"
            )
            shift_labels = shift_labels.contiguous()
            has_label = True

        loss, dX = jsd_forward(_input, target, shift_labels, beta, ignore_index, has_label)
        ctx.save_for_backward(dX)
        return loss

    @staticmethod
    @ensure_contiguous
    def backward(ctx, grad_output: torch.Tensor) -> torch.Tensor:
        (dX,) = ctx.saved_tensors
        dX = jsd_backward(dX, grad_output)
        # None for the non-tensor forward args (target, shift_labels, beta, ignore_index)
        return (
            dX,
            None,
            None,
            None,
            None,
        )


# ==== build/torch-cuda/kl_div.py ====
from typing import Literal

import torch
import triton
import triton.language as tl

from .utils import ensure_contiguous
from .utils import is_hip
from .utils import infer_device


def get_num_warps(BLOCK_SIZE):
    """Pick a warp count that scales with the block size (halved on HIP for the largest blocks)."""
    num_warps = 4
    if BLOCK_SIZE >= 32768:
        num_warps = 32 if not is_hip() else 16
    elif BLOCK_SIZE >= 8192:
        num_warps = 16
    elif BLOCK_SIZE >= 2048:
        num_warps = 8

    return num_warps


MAX_FUSED_SIZE = 65536 // 4  # 65536 // 4 or 8 works the best

REDUCTION_LITERAL = Literal["none", "sum", "mean", "batchmean"]

# reduction modes are passed to the kernel as integer constexprs
_REDUCTION_MODE_NONE: tl.constexpr = tl.constexpr(0)
_REDUCTION_MODE_SUM: tl.constexpr = tl.constexpr(1)
_REDUCTION_MODE_MEAN: tl.constexpr = tl.constexpr(2)
_REDUCTION_MODE_BATCHMEAN: tl.constexpr = tl.constexpr(3)

_str_to_reduction_mode = {
    "none": _REDUCTION_MODE_NONE.value,
    "sum": _REDUCTION_MODE_SUM.value,
    "mean": _REDUCTION_MODE_MEAN.value,
    "batchmean": _REDUCTION_MODE_BATCHMEAN.value,
}
@triton.jit
def _kldiv_kernel_forward(
    y_ptr,  # [B, S], prediction ptr, the kernel expects the prediction in log-space
    y_stride,  # int, prediction stride
    gt_ptr,  # [B, S], ground truth ptr
    gt_stride,  # int, ground truth stride
    loss_ptr,  # [B] or [B, S] if reduction == _REDUCTION_MODE_NONE, output ptr
    loss_stride,  # int, output stride
    n_cols,  # int, number of columns in the input tensor
    eps,
    BLOCK_SIZE: tl.constexpr,
    log_target: tl.constexpr = False,
    reduction: tl.constexpr = _REDUCTION_MODE_BATCHMEAN,
):
    """Compute the KL divergence of one row; one Triton program per row.

    Stores per-element losses when reduction == NONE, otherwise a per-row sum.
    """
    pid = tl.program_id(0).to(tl.int64)
    y_ptr += pid * y_stride
    gt_ptr += pid * gt_stride
    loss_ptr += pid * loss_stride

    base_offsets = tl.arange(0, BLOCK_SIZE)

    loss_sum = 0.0
    for i in range(0, n_cols, BLOCK_SIZE):
        offsets = i + base_offsets
        mask = offsets < n_cols
        y = tl.load(y_ptr + offsets, mask=mask, other=0.0)
        y_true = tl.load(gt_ptr + offsets, mask=mask, other=0.0)

        # KL(y_true || y) = y_true * (log(y_true) - log(y))
        # We compute KL(y_true || y) with y in the log-space
        if not log_target:
            # eps guards log(0) for zero-probability targets
            loss = y_true * (tl.log(tl.maximum(y_true, eps)) - y)
        else:
            loss = tl.exp(y_true) * (y_true - y)

        if reduction == _REDUCTION_MODE_NONE:
            tl.store(loss_ptr + offsets, loss, mask=mask)
        else:
            loss_sum += tl.sum(loss, axis=0)

    if reduction != _REDUCTION_MODE_NONE:
        tl.store(loss_ptr, loss_sum)


@triton.jit
def _kldiv_kernel_backward(
    target_ptr,
    target_stride,
    new_grads_ptr,
    new_grads_stride,
    n_cols,
    BLOCK_SIZE: tl.constexpr,
    log_target: tl.constexpr = False,
):
    """Write d(loss)/d(y_pred) for one row: -target (or -exp(target) in log-target mode)."""
    pid = tl.program_id(0).to(tl.int64)

    target_ptr += pid * target_stride
    new_grads_ptr += pid * new_grads_stride

    # FIX: removed dead pre-loop `offsets`/`mask` assignments — both were
    # unconditionally recomputed on the first loop iteration below.
    for i in range(0, n_cols, BLOCK_SIZE):
        offsets = i + tl.arange(0, BLOCK_SIZE)
        mask = offsets < n_cols

        target = tl.load(target_ptr + offsets, mask=mask, other=0.0)

        if not log_target:
            res = target * -1
        else:
            res = -tl.exp(target)

        tl.store(new_grads_ptr + offsets, res, mask=mask)


def kldiv_forward_triton(y_pred, y_true, log_target, reduction, eps):  # [BT, V]
    """Launch the forward kernel and apply the requested PyTorch-style reduction."""
    BT, V = y_pred.shape
    BLOCK_SIZE = (
        min(8192, triton.next_power_of_2(V))
        if infer_device() == "xpu"
        else min(MAX_FUSED_SIZE, triton.next_power_of_2(V))
    )
    num_warps = 32 if infer_device() == "xpu" else get_num_warps(BLOCK_SIZE)

    grid = (BT,)
    reduction = _str_to_reduction_mode[reduction]

    out_size = (BT, V) if reduction == _REDUCTION_MODE_NONE.value else (BT,)
    output_tensor = torch.zeros(out_size, device=y_pred.device, dtype=torch.float32)

    _kldiv_kernel_forward[grid](
        y_pred,
        y_pred.stride(0),
        y_true,
        y_true.stride(0),
        output_tensor,
        output_tensor.stride(0),
        V,
        eps=eps,
        BLOCK_SIZE=BLOCK_SIZE,
        num_warps=num_warps,
        log_target=log_target,
        reduction=reduction,
    )

    # calculated according to the reduction mode same as in Pytorch. In the later versions, `mean` will be changed to the same behavior as `batchmean`
    # https://pytorch.org/docs/stable/generated/torch.nn.KLDivLoss.html
    # https://github.com/pytorch/pytorch/blob/d7b57c4d63edb42e1deeeba9497fcb5f1f748ff2/torch/nn/functional.py#L3372
    if reduction == _REDUCTION_MODE_BATCHMEAN.value:
        return output_tensor.sum() / BT
    elif reduction == _REDUCTION_MODE_SUM.value:
        return output_tensor.sum(dim=0)
    elif reduction == _REDUCTION_MODE_MEAN.value:
        return output_tensor.sum() / (BT * V)
    else:
        return output_tensor


def kldiv_backward_triton(target, grad_output, new_grads, log_target):
    """Launch the backward kernel (writes into new_grads) and scale by grad_output."""
    BT, V = target.shape
    BLOCK_SIZE = (
        min(8192, triton.next_power_of_2(V))
        if infer_device() == "xpu"
        else min(MAX_FUSED_SIZE, triton.next_power_of_2(V))
    )
    num_warps = 32 if infer_device() == "xpu" else get_num_warps(BLOCK_SIZE)

    grid = (BT,)

    # We store the gradients in-place in the input tensor
    _kldiv_kernel_backward[grid](
        target,
        target.stride(0),
        new_grads,
        new_grads.stride(0),
        V,
        BLOCK_SIZE=BLOCK_SIZE,
        num_warps=num_warps,
        log_target=log_target,
    )

    # If cross entropy is the last layer, grad_output is 1.0. Skip the mul then.
    if torch.equal(grad_output, torch.tensor(1.0, device=grad_output.device)):
        return new_grads

    return new_grads * grad_output


class LigerKLDivLossFunction(torch.autograd.Function):
    """
    Class implementing the forward and backward pass for the KL Divergence Loss using Triton, as defined by the following formula:
    ```python
    if log_target:
        loss = target.exp() * (target - input)
    else:
        loss = target * (target.log() - input)
    ```,
    then the loss is reduced according to the `reduction` parameter.
    as defined in the PyTorch documentation: https://pytorch.org/docs/stable/generated/torch.nn.KLDivLoss.html
    """

    @staticmethod
    @ensure_contiguous
    def forward(
        ctx,
        y_pred: torch.Tensor,
        y_true: torch.Tensor,
        reduction: REDUCTION_LITERAL = "batchmean",
        log_target: bool = False,
        eps: float = 1e-10,
    ) -> torch.Tensor:
        """A forward pass for the KL Divergence Loss.

        Args:
            ctx: Torch autograd context
            y_pred (torch.Tensor): A tensor of shape (BT, V) containing the predicted values, expected to be log-probabilities.
            y_true (torch.Tensor): A tensor of shape (BT, V) containing the target values, expected to be either probabilities or log-probabilities, depending on the value of `log_target`.
            reduction (REDUCTION_LITERAL, optional): Reduction to be used. Defaults to "batchmean".
            log_target (bool, optional): If set to true, expects the ground truth to already be log-probabilities. Defaults to False.
            eps: (float, optional): A small value to avoid division by zero. Defaults to 1e-10.

        Returns:
            torch.Tensor: The computed KL Divergence Loss, with shape (BT, V) if `reduction` is "none", else a scalar.
        """
        ctx.save_for_backward(y_true)
        ctx.reduction = reduction
        ctx.log_target = log_target
        return kldiv_forward_triton(y_pred, y_true, log_target=log_target, reduction=reduction, eps=eps)

    @staticmethod
    @ensure_contiguous
    def backward(ctx, grad_output: torch.Tensor) -> torch.Tensor:
        """A backward pass for the KL Divergence Loss.

        Args:
            ctx: Torch autograd context
            grad_output (torch.Tensor): The gradient of the loss with respect to the output.

        Returns:
            tuple[torch.Tensor, None, None, None, None]: The gradient of the loss with respect to the inputs and None for the other arguments of the forward method.
        """
        (y_true,) = ctx.saved_tensors

        new_grads = torch.empty_like(y_true)

        derivative = kldiv_backward_triton(y_true, grad_output, new_grads, ctx.log_target)

        # undo/apply the same normalization the forward reduction used
        if ctx.reduction == "batchmean":
            derivative = derivative / y_true.shape[0]
        elif ctx.reduction == "sum" or ctx.reduction == "none":
            pass
        elif ctx.reduction == "mean":
            derivative = derivative / (y_true.shape[0] * y_true.shape[1])

        return (
            derivative,
            None,
            None,
            None,
            None,
        )


# ==== build/torch-cuda/layer_norm.py ====
import math
import operator

import torch
import triton
import triton.language as tl

from .utils import calculate_settings
from .utils import compare_version
from .utils import ensure_contiguous

if compare_version("triton", operator.ge, "3.0.0"):
    try:
        # typical import path with dispatch available
        from triton.language.extra.libdevice import rsqrt
    except ModuleNotFoundError:
        # for working with NGC containers
        from triton.language.extra.cuda.libdevice import rsqrt
else:
    from triton.language.math import rsqrt
@triton.jit
def _layer_norm_forward_kernel(
    Y_ptr,  # pointer to output, shape (n_rows, n_cols)
    Y_row_stride,  # stride of each row in output
    X_ptr,  # pointer to input, shape (n_rows, n_cols)
    X_row_stride,  # stride of each row in input
    W_ptr,  # pointer to weights, shape (n_cols,)
    W_row_stride,  # stride of each row in weights
    B_ptr,  # pointer to bias, shape (n_cols,)
    B_row_stride,  # stride of each row in bias
    Mean_ptr,  # pointer to mean, shape (n_rows,)
    Mean_row_stride,  # stride of each row in mean
    RSTD_ptr,  # pointer to rstd, shape (n_rows,)
    RSTD_row_stride,  # stride of each row in rstd
    n_cols,
    eps,
    BLOCK_SIZE: tl.constexpr,
):
    """
    References:
    https://arxiv.org/abs/1607.06450
    https://github.com/karpathy/llm.c/blob/master/doc/layernorm/layernorm.md
    """
    # One program normalizes one row; the whole row fits in a single block.
    row_idx = tl.program_id(0)
    col_offsets = tl.arange(0, BLOCK_SIZE)
    mask = col_offsets < n_cols

    Y_ptr += row_idx * Y_row_stride
    X_ptr += row_idx * X_row_stride
    Mean_ptr += row_idx * Mean_row_stride
    RSTD_ptr += row_idx * RSTD_row_stride

    X_row = tl.load(X_ptr + col_offsets, mask=mask, other=0)
    W_row = tl.load(W_ptr + col_offsets, mask=mask, other=0)
    B_row = tl.load(B_ptr + col_offsets, mask=mask, other=0)

    mean = tl.sum(X_row, axis=0) / n_cols
    # zero the masked tail so it doesn't pollute the variance sum
    Xmm = tl.where(mask, X_row - mean, 0)
    var = tl.sum(Xmm * Xmm, axis=0) / n_cols
    rstd = rsqrt(var + eps)

    # stash statistics for the backward pass
    tl.store(Mean_ptr, mean)
    tl.store(RSTD_ptr, rstd)

    Y_row = Xmm * rstd * W_row + B_row

    tl.store(Y_ptr + col_offsets, Y_row, mask=mask)


@triton.jit
def _layer_norm_backward_kernel(
    X_ptr,  # pointer to input, shape (n_rows, n_cols)
    W_ptr,  # pointer to weights, shape (n_cols,)
    Mean_ptr,  # pointer to mean, shape (n_rows,)
    RSTD_ptr,  # pointer to rstd, shape (n_rows,)
    DX_ptr,  # pointer to input grad, shape (n_rows, n_cols)
    DW_ptr,  # pointer to weights grad, shape (n_cols,)
    DB_ptr,  # pointer to bias grad, shape (n_cols,)
    DY_ptr,  # pointer to output grad, shape (n_rows, n_cols)
    stride_x,  # stride of each row in input
    stride_dx,  # stride of each row in input grad
    stride_dw,  # stride of each row in weights grad
    stride_db,  # stride of each row in bias grad
    stride_dy,  # stride of each row in output grad
    n_rows,
    n_cols,
    rows_per_program: tl.constexpr,
    BLOCK_SIZE: tl.constexpr,
    dtype: tl.constexpr,
):
    """
    References:
    https://arxiv.org/abs/1607.06450
    https://github.com/karpathy/llm.c/blob/master/doc/layernorm/layernorm.md
    https://triton-lang.org/main/getting-started/tutorials/05-layer-norm.html
    https://github.com/Dao-AILab/flash-attention/blob/main/flash_attn/ops/triton/layer_norm.py
    """
    # Each program handles a contiguous strip of rows and accumulates partial
    # dW/dB for that strip; partials are written per-program (reduced by the caller).
    row_block_id = tl.program_id(0)
    row_start = row_block_id * rows_per_program
    row_end = min((row_block_id + 1) * rows_per_program, n_rows)
    cols = tl.arange(0, BLOCK_SIZE)
    mask = cols < n_cols

    dw_row = tl.zeros((BLOCK_SIZE,), dtype=tl.float32)
    db_row = tl.zeros((BLOCK_SIZE,), dtype=tl.float32)

    X_ptr += row_start * stride_x
    Mean_ptr += row_start
    RSTD_ptr += row_start
    DX_ptr += row_start * stride_dx
    DY_ptr += row_start * stride_dy

    for _ in range(row_start, row_end):
        x = tl.load(X_ptr + cols, mask=mask, other=0.0)
        w = tl.load(W_ptr + cols, mask=mask, other=0.0)
        dy = tl.load(DY_ptr + cols, mask=mask, other=0.0)
        mean = tl.load(Mean_ptr)
        rstd = tl.load(RSTD_ptr)

        x_hat = (x - mean) * rstd
        wdy = w * dy
        c1 = tl.sum(x_hat * wdy, axis=0) / n_cols
        c2 = tl.sum(wdy, axis=0) / n_cols
        dx = (wdy - (x_hat * c1 + c2)) * rstd
        tl.store(DX_ptr + cols, dx.to(dtype), mask=mask)

        dw_row += dy * x_hat
        db_row += dy

        # advance to the next row of the strip
        X_ptr += stride_x
        Mean_ptr += 1
        RSTD_ptr += 1
        DX_ptr += stride_dx
        DY_ptr += stride_dy

    tl.store(DW_ptr + row_block_id * stride_dw + cols, dw_row.to(dtype), mask=mask)
    tl.store(DB_ptr + row_block_id * stride_db + cols, db_row.to(dtype), mask=mask)


def layer_norm_forward(X, W, B, eps):
    shape = X.shape
    dim = shape[-1]
    X = X.view(-1, dim)
    n_rows, n_cols = X.shape
    BLOCK_SIZE, num_warps = calculate_settings(n_cols)
Y = torch.empty((n_rows, n_cols), dtype=X.dtype, device=X.device) + Mean = torch.empty(n_rows, dtype=X.dtype, device=X.device) + RSTD = torch.empty(n_rows, dtype=X.dtype, device=X.device) + if X.shape[1] != W.shape[0]: + raise ValueError( + f"Incompatible dimensions: input feature size (X.shape[1]={X.shape[1]}) " + f"must match weight size (W.shape[0]={W.shape[0]})" + ) + + # XPU-specific optimization + kernel_args = {} + if X.device.type == "xpu": + kernel_args["grf_mode"] = "large" + + _layer_norm_forward_kernel[(n_rows,)]( + Y, + Y.stride(0), + X, + X.stride(0), + W, + W.stride(0), + B, + B.stride(0), + Mean, + Mean.stride(0), + RSTD, + RSTD.stride(0), + n_cols, + eps, + BLOCK_SIZE=BLOCK_SIZE, + num_warps=num_warps, + **kernel_args, # XPU-specific optimization + ) + return Y.view(*shape), X, Mean, RSTD, BLOCK_SIZE, num_warps + + +def layer_norm_backward(dY, X, W, B, Mean, RSTD): + shape = dY.shape + dim = shape[-1] + dY = dY.view(-1, dim) + n_rows, n_cols = dY.shape + + sm_count = 1 + if X.device.type == "cuda": + sm_count = torch.cuda.get_device_properties(X.device).multi_processor_count + elif X.device.type == "xpu": + sm_count = torch.xpu.get_device_properties(X.device).gpu_eu_count + + DX = torch.empty((n_rows, n_cols), dtype=X.dtype, device=X.device) + _DW = torch.empty((sm_count, n_cols), dtype=W.dtype, device=W.device) + _DB = torch.empty((sm_count, n_cols), dtype=W.dtype, device=W.device) + + BLOCK_SIZE, num_warps = calculate_settings(n_cols) + if n_cols > BLOCK_SIZE: + raise RuntimeError( + f"Feature dimension {n_cols} exceeds maximum supported size of {BLOCK_SIZE}. Consider using a smaller feature dimension." 
+ ) + + rows_per_program = math.ceil(n_rows / sm_count) + grid = (sm_count,) + triton_dtype = ( + tl.float32 + if X.dtype == torch.float32 + else tl.bfloat16 + if X.dtype == torch.bfloat16 + else tl.float16 + if X.dtype == torch.float16 + else tl.float32 # fallback to float32 for other types + ) + + # XPU-specific optimization + kernel_args = {} + if X.device.type == "xpu": + kernel_args.update({"grf_mode": "large", "num_warps": 32, "num_stages": 4}) + + _layer_norm_backward_kernel[grid]( + X, + W, + Mean, + RSTD, + DX, + _DW, + _DB, + dY, + X.stride(0), + DX.stride(0), + _DW.stride(0), + _DB.stride(0), + dY.stride(0), + n_rows, + n_cols, + rows_per_program, + BLOCK_SIZE=BLOCK_SIZE, + dtype=triton_dtype, + **kernel_args, # XPU-specific optimization + ) + + DW = _DW.sum(dim=0).to(W.dtype) + DB = _DB.sum(dim=0).to(W.dtype) + + DX = DX.view(*shape) + return DX, DW, DB + + +class LigerLayerNormFunction(torch.autograd.Function): + @staticmethod + @ensure_contiguous + def forward(ctx, X, W, B, eps): + Y, X, Mean, RSTD, BLOCK_SIZE, num_warps = layer_norm_forward(X, W, B, eps) + ctx.save_for_backward(X, W, B, Mean, RSTD) + return Y + + @staticmethod + @ensure_contiguous + def backward(ctx, dY): + X, W, B, Mean, RSTD = ctx.saved_tensors + DX, DW, DB = layer_norm_backward(dY, X, W, B, Mean, RSTD) + return DX, DW, DB, None \ No newline at end of file diff --git a/build/torch-cuda/layers.py b/build/torch-cuda/layers.py new file mode 100644 index 0000000000000000000000000000000000000000..798de28de6bcb77a98416812dcae89cb9e5014df --- /dev/null +++ b/build/torch-cuda/layers.py @@ -0,0 +1,39 @@ +import torch +from .rms_norm import LigerRMSNormFunction + +class LigerRMSNorm(torch.nn.Module): + """ + RMSNorm module that uses the optimized LigerRMSNormFunction. + + Args: + hidden_size (int): The size of the hidden dimension. + eps (float, optional): The epsilon value for numerical stability. Defaults to 1e-6. + offset (float, optional): Offset value to shift the weight tensor. 
Defaults to 0.0. + casting_mode (str, optional): The casting mode to use. Defaults to "llama". + in_place (bool, optional): Whether to modify dY in-place to store dX during backward. Defaults to True. + """ + + + weight: torch.Tensor + variance_epsilon: float + + def forward(self, hidden_states): + """ + Apply RMS normalization to the input tensor. + + Args: + hidden_states (torch.Tensor): Input tensor of shape (B, T, H) or (BxT, H) + + Returns: + torch.Tensor: Normalized tensor of the same shape as input + """ + return LigerRMSNormFunction.apply( + hidden_states, + self.weight, + self.variance_epsilon, + 0, + "llama", + True + ) + +__all__ = ["LigerRMSNorm"] \ No newline at end of file diff --git a/build/torch-cuda/liger_kernels/__init__.py b/build/torch-cuda/liger_kernels/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..03dbc1afe1cf156661a2b1b22003cd5f599a0309 --- /dev/null +++ b/build/torch-cuda/liger_kernels/__init__.py @@ -0,0 +1,26 @@ +import ctypes +import sys + +import importlib +from pathlib import Path +from types import ModuleType + +def _import_from_path(file_path: Path) -> ModuleType: + # We cannot use the module name as-is, after adding it to `sys.modules`, + # it would also be used for other imports. So, we make a module name that + # depends on the path for it to be unique using the hex-encoded hash of + # the path. 
+ path_hash = "{:x}".format(ctypes.c_size_t(hash(file_path.absolute())).value) + module_name = path_hash + spec = importlib.util.spec_from_file_location(module_name, file_path) + if spec is None: + raise ImportError(f"Cannot load spec for {module_name} from {file_path}") + module = importlib.util.module_from_spec(spec) + if module is None: + raise ImportError(f"Cannot load module {module_name} from spec") + sys.modules[module_name] = module + spec.loader.exec_module(module) # type: ignore + return module + + +globals().update(vars(_import_from_path(Path(__file__).parent.parent / "__init__.py"))) diff --git a/build/torch-cuda/metadata.json b/build/torch-cuda/metadata.json new file mode 100644 index 0000000000000000000000000000000000000000..9cf5deed9898dce769f4cc73913d3530b92a0bd8 --- /dev/null +++ b/build/torch-cuda/metadata.json @@ -0,0 +1,4 @@ +{ + "version": 1, + "python-depends": [] +} \ No newline at end of file diff --git a/build/torch-cuda/qwen2vl_mrope.py b/build/torch-cuda/qwen2vl_mrope.py new file mode 100644 index 0000000000000000000000000000000000000000..33d02e7bcfb95402ccd3c1b0ad0ed8780ef422e4 --- /dev/null +++ b/build/torch-cuda/qwen2vl_mrope.py @@ -0,0 +1,222 @@ +import torch +import triton +import triton.language as tl + + +@triton.jit +def _triton_qwen2vl_mrope( + q_ptr, + k_ptr, + cos, + sin, + sl, + bs: tl.constexpr, + n_qh: tl.constexpr, + n_kh: tl.constexpr, + hd: tl.constexpr, + pad_n_qh: tl.constexpr, + pad_n_kh: tl.constexpr, + pad_hd: tl.constexpr, + mrope_section_t: tl.constexpr, + mrope_section_h: tl.constexpr, + BLOCK_SIZE: tl.constexpr, + BACKWARD_PASS: tl.constexpr = False, +): + pid = tl.program_id(0) + + # locate start address + q_ptr = q_ptr + pid * (n_qh * hd) + k_ptr = k_ptr + pid * (n_kh * hd) + + # #################################################################### + # get the cos(mθ_{i...d/2}) and sin(mθ_{i...d/2}) for token position + # m of this program instance + # 
#################################################################### + + # 1. program instances are laid out in a 1D vector of size bsz * seq_len, which + # effectively represents a 2D grid of size [bsz, seq_len] with seq_len dimension + # being the fastest changing dimension. Thus we can simply do pid // sl to get the batch index + # and pid % sl to get the sequence index. + # 2. We only need the left half of cos and sin matrix because the right half is just + # a clone of the left half. + t_end = mrope_section_t + h_end = t_end + mrope_section_h + + t_cos = cos + pid * hd + h_cos = t_cos + bs * sl * hd + w_cos = h_cos + bs * sl * hd + t_sin = sin + pid * hd + h_sin = t_sin + bs * sl * hd + w_sin = h_sin + bs * sl * hd + + cos_offsets = tl.arange(0, pad_hd // 2) + t_mask = cos_offsets < t_end + h_mask = (t_end <= cos_offsets) & (cos_offsets < h_end) + w_mask = (h_end <= cos_offsets) & (cos_offsets < hd // 2) + t_cos_row = tl.load(t_cos + cos_offsets, mask=t_mask, other=0) + h_cos_row = tl.load(h_cos + cos_offsets, mask=h_mask, other=0) + w_cos_row = tl.load(w_cos + cos_offsets, mask=w_mask, other=0) + t_sin_row = tl.load(t_sin + cos_offsets, mask=t_mask, other=0) + h_sin_row = tl.load(h_sin + cos_offsets, mask=h_mask, other=0) + w_sin_row = tl.load(w_sin + cos_offsets, mask=w_mask, other=0) + cos_row = t_cos_row + h_cos_row + w_cos_row + sin_row = t_sin_row + h_sin_row + w_sin_row + + # #################################################################### + # Load the left and right half of q and k for the current + # program instance (i.e. 
for the current token) separately + # #################################################################### + # left half of the head + first_half_q_offsets = tl.arange(0, pad_n_qh)[:, None] * hd + tl.arange(0, pad_hd // 2)[None, :] + first_half_k_offsets = tl.arange(0, pad_n_kh)[:, None] * hd + tl.arange(0, pad_hd // 2)[None, :] + first_q_mask = (tl.arange(0, pad_n_qh)[:, None] < n_qh) & (tl.arange(0, pad_hd // 2)[None, :] < hd // 2) + first_k_mask = (tl.arange(0, pad_n_kh)[:, None] < n_kh) & (tl.arange(0, pad_hd // 2)[None, :] < hd // 2) + q_tile_1 = tl.load(q_ptr + first_half_q_offsets, mask=first_q_mask, other=0).to(sin_row.dtype) + k_tile_1 = tl.load(k_ptr + first_half_k_offsets, mask=first_k_mask, other=0).to(sin_row.dtype) + + # right half of the head + second_half_q_offsets = first_half_q_offsets + (hd // 2) + second_half_k_offsets = first_half_k_offsets + (hd // 2) + second_q_mask = first_q_mask + second_k_mask = first_k_mask + q_tile_2 = tl.load(q_ptr + second_half_q_offsets, mask=second_q_mask, other=0).to(sin_row.dtype) + k_tile_2 = tl.load(k_ptr + second_half_k_offsets, mask=second_k_mask, other=0).to(sin_row.dtype) + + if not BACKWARD_PASS: + # y = [x1, x2] * [cos, cos] + [-x2, x1] * [sin, sin] + new_q_tile_1 = q_tile_1 * cos_row - q_tile_2 * sin_row + tl.store(q_ptr + first_half_q_offsets, new_q_tile_1, mask=first_q_mask) + new_q_tile_2 = q_tile_2 * cos_row + q_tile_1 * sin_row + tl.store(q_ptr + second_half_q_offsets, new_q_tile_2, mask=second_q_mask) + + new_k_tile_1 = k_tile_1 * cos_row - k_tile_2 * sin_row + tl.store(k_ptr + first_half_k_offsets, new_k_tile_1, mask=first_k_mask) + new_k_tile_2 = k_tile_2 * cos_row + k_tile_1 * sin_row + tl.store(k_ptr + second_half_k_offsets, new_k_tile_2, mask=second_k_mask) + else: + # with some math, we can get: + # dy = [dx1, dx2] * [cos, cos] + [-dx2, dx1] * [-sin, -sin] + new_q_tile_1 = q_tile_1 * cos_row + q_tile_2 * sin_row + tl.store(q_ptr + first_half_q_offsets, new_q_tile_1, mask=first_q_mask) + 
new_q_tile_2 = q_tile_2 * cos_row - q_tile_1 * sin_row + tl.store(q_ptr + second_half_q_offsets, new_q_tile_2, mask=second_q_mask) + + new_k_tile_1 = k_tile_1 * cos_row + k_tile_2 * sin_row + tl.store(k_ptr + first_half_k_offsets, new_k_tile_1, mask=first_k_mask) + new_k_tile_2 = k_tile_2 * cos_row - k_tile_1 * sin_row + tl.store(k_ptr + second_half_k_offsets, new_k_tile_2, mask=second_k_mask) + + +def qwen2vl_mrope_forward(q, k, cos, sin, mrope_section): + # transpose it back to the physical shape because Triton looks at the physical storage + # note: q and k are incontiguous before the transformation and will become contiguous after transpose + q = q.transpose(1, 2) + k = k.transpose(1, 2) + + batch_size, seq_len, n_q_head, head_dim = q.shape + n_kv_head = k.shape[2] + pad_hd = triton.next_power_of_2(head_dim) + pad_n_q_head = triton.next_power_of_2(n_q_head) + pad_n_kv_head = triton.next_power_of_2(n_kv_head) + BLOCK_SIZE = max(pad_n_q_head, pad_n_kv_head) + + n_row = batch_size * seq_len + + # ensure tensors passed into the kernel are contiguous. 
It will be no-op if they are already contiguous + q = q.contiguous() + k = k.contiguous() + cos = cos.contiguous() + sin = sin.contiguous() + + _triton_qwen2vl_mrope[(n_row,)]( + q, + k, + cos, + sin, + seq_len, + batch_size, + n_q_head, + n_kv_head, + head_dim, + pad_n_q_head, + pad_n_kv_head, + pad_hd, + mrope_section[0], + mrope_section[1], + BLOCK_SIZE=BLOCK_SIZE, + BACKWARD_PASS=False, + ) + return q.transpose(1, 2), k.transpose(1, 2), cos, sin + + +def qwen2vl_mrope_backward(dq, dk, cos, sin, mrope_section): + dq = dq.transpose(1, 2) + dk = dk.transpose(1, 2) + + batch_size, seq_len, n_q_head, head_dim = dq.shape + n_kv_head = dk.shape[2] + pad_hd = triton.next_power_of_2(head_dim) + pad_n_q_head = triton.next_power_of_2(n_q_head) + pad_n_kv_head = triton.next_power_of_2(n_kv_head) + BLOCK_SIZE = max(pad_n_q_head, pad_n_kv_head) + + n_row = batch_size * seq_len + + # ensure dq and dk are contiguous + dq = dq.contiguous() + dk = dk.contiguous() + + # backward is similar to forward except swapping few ops + _triton_qwen2vl_mrope[(n_row,)]( + dq, + dk, + cos, + sin, + seq_len, + batch_size, + n_q_head, + n_kv_head, + head_dim, + pad_n_q_head, + pad_n_kv_head, + pad_hd, + mrope_section[0], + mrope_section[1], + BLOCK_SIZE=BLOCK_SIZE, + BACKWARD_PASS=True, + ) + return dq.transpose(1, 2), dk.transpose(1, 2) + + +class LigerQwen2VLMRopeFunction(torch.autograd.Function): + """ + Triton implementation of the Qwen2VL Multimodal Rotary Positional Embedding (M-RoPE) operation. 
+ + Please find the corresponding HuggingFace implementation here: + https://github.com/huggingface/transformers/blob/main/src/transformers/models/qwen2_vl/modeling_qwen2_vl.py + """ + + @staticmethod + def forward(ctx, q, k, cos, sin, mrope_section, unsqueeze_dim=1): + """ + q size: (bsz, n_q_head, seq_len, head_dim) + k size: (bsz, n_kv_head, seq_len, head_dim) + cos size: (3, bsz, seq_len, head_dim) + sin size: (3, bsz, seq_len, head_dim) + """ + q, k, cos, sin = qwen2vl_mrope_forward(q, k, cos, sin, mrope_section) + ctx.save_for_backward(cos, sin) + ctx.mrope_section = mrope_section + return q, k + + def backward(ctx, dq, dk): + """ + dq size: (bsz, n_q_head, seq_len, head_dim) + dk size: (bsz, n_kv_head, seq_len, head_dim) + cos size: (3, bsz, seq_len, head_dim) + sin size: (3, bsz, seq_len, head_dim) + """ + cos, sin = ctx.saved_tensors + mrope_section = ctx.mrope_section + dq, dk = qwen2vl_mrope_backward(dq, dk, cos, sin, mrope_section) + return dq, dk, None, None, None, None \ No newline at end of file diff --git a/build/torch-cuda/rms_norm.py b/build/torch-cuda/rms_norm.py new file mode 100644 index 0000000000000000000000000000000000000000..bbf908671842d9e002774416c6a9ccc55acd2684 --- /dev/null +++ b/build/torch-cuda/rms_norm.py @@ -0,0 +1,365 @@ +""" +This file incorporates code from Unsloth licensed under the Apache License, Version 2.0. +See the original Unsloth repository at https://github.com/unslothai/unsloth. + +The following line +https://github.com/linkedin/Liger-Kernel/blob/7382a8761f9af679482b968f9348013d933947c7/src/liger_kernel/ops/rms_norm.py#L30 +is based on code from Unsloth, located at: +https://github.com/unslothai/unsloth/blob/fd753fed99ed5f10ef8a9b7139588d9de9ddecfb/unsloth/kernels/rms_layernorm.py#L22 + +Modifications made by Yanning Chen, 2024. 
+""" + +import math +import operator + +import torch +import triton +import triton.language as tl + +from .utils import calculate_settings +from .utils import compare_version +from .utils import ensure_contiguous +from .utils import torch_to_triton_dtype + +if compare_version("triton", operator.ge, "3.0.0"): + try: + # typical import path with dispatch available + from triton.language.extra.libdevice import rsqrt + except ModuleNotFoundError: + # for working with NGC containers + from triton.language.extra.cuda.libdevice import rsqrt +else: + from triton.language.math import rsqrt + + +_CASTING_MODE_NONE: tl.constexpr = tl.constexpr(-1) +_CASTING_MODE_LLAMA: tl.constexpr = tl.constexpr(0) +_CASTING_MODE_GEMMA: tl.constexpr = tl.constexpr(1) + + +@triton.jit +def _rms_norm_forward_kernel( + Y_ptr, + Y_row_stride, + X_ptr, + X_row_stride, + W_ptr, + W_row_stride, + RSTD_ptr, + RSTD_row_stride, + n_cols, + eps, + offset, + casting_mode: tl.constexpr, # constexpr so the `if` blocks can be optimized out + BLOCK_SIZE: tl.constexpr, +): + """ + y_i = (x_i / (RMS)) * (offset + wi), RMS = sqrt(sum(x_i^2) / N) + + Reference: + 1. https://triton-lang.org/main/getting-started/tutorials/05-layer-norm.html + 2. https://github.com/unslothai/unsloth/blob/fd753fed99ed5f10ef8a9b7139588d9de9ddecfb/unsloth/kernels/rms_layernorm.py#L22 + 3. 
https://arxiv.org/pdf/1910.07467 + """ + + row_idx = tl.program_id(0) + col_offsets = tl.arange(0, BLOCK_SIZE) + mask = col_offsets < n_cols + + Y_ptr += row_idx * Y_row_stride + X_ptr += row_idx * X_row_stride + RSTD_ptr += row_idx * RSTD_row_stride + + X_row = tl.load(X_ptr + col_offsets, mask=mask, other=0) + X_row_dtype = X_row.dtype + W_row = tl.load(W_ptr + col_offsets, mask=mask, other=0) + + # On Llama, only rstd is computed on fp32 + if casting_mode == _CASTING_MODE_LLAMA: + X_row = X_row.to(tl.float32) + + # Gemma computes everything on fp32, and then casts back the output to the original dtype + if casting_mode == _CASTING_MODE_GEMMA: + W_row = W_row.to(tl.float32) + X_row = X_row.to(tl.float32) + + if casting_mode == _CASTING_MODE_NONE: + eps = eps.to(X_row_dtype) + offset = offset.to(X_row_dtype) + + mean_square = tl.sum(X_row * X_row, axis=0) / n_cols + rstd = rsqrt(mean_square + eps) + + # We can save time by caching rms with minimal memory overhead + # because rms is much smaller compared to X_row, as rms is for each row. + # However, on the computation side, it can save 4 operations (*, sum, /, sqrt). + tl.store(RSTD_ptr, rstd) + + X_row = X_row * rstd + + # On Llama, the multiplication with the weight is done on the original dtype + if casting_mode == _CASTING_MODE_LLAMA: + X_row = X_row.to(X_row_dtype) + + Y_row = X_row * (offset + W_row) + + if casting_mode == _CASTING_MODE_GEMMA: + Y_row = Y_row.to(X_row_dtype) + + tl.store(Y_ptr + col_offsets, Y_row, mask=mask) + + +@triton.jit +def _rms_norm_backward_kernel( + dY_ptr, + dY_row_stride, + dX_ptr, + dX_row_stride, + X_ptr, + X_row_stride, + X_dtype: tl.constexpr, + W_ptr, + W_row_stride, + RSTD_ptr, + RSTD_row_stride, + dW_ptr, + dW_row_stride, + n_rows, + n_cols, + offset, + rows_per_program: tl.constexpr, + casting_mode: tl.constexpr, + BLOCK_SIZE: tl.constexpr, +): + """ + dx = (1 / RMS) * [dy * (w + offset - (1 / N) * (1 / RMS^2) * ((dy * (w + offset)) dot x) * x]. 
* means element-wise multiplication, whileas dot means dot product + dw = sum(dy * (x / RMS)). summation over BxT dimension + """ + + row_block_id = tl.program_id(0) + row_start = row_block_id * rows_per_program + row_end = min((row_block_id + 1) * rows_per_program, n_rows) + col_offsets = tl.arange(0, BLOCK_SIZE) + mask = col_offsets < n_cols + + dW_row = tl.zeros((BLOCK_SIZE,), dtype=tl.float32) + + dY_ptr += row_start * dY_row_stride + dX_ptr += row_start * dX_row_stride + + X_ptr += row_start * X_row_stride + RSTD_ptr += row_start + + W_row = tl.load(W_ptr + col_offsets, mask=mask, other=0.0) + W_row = W_row + offset + + for _ in range(row_start, row_end): + dY_row = tl.load(dY_ptr + col_offsets, mask=mask, other=0.0) + X_row = tl.load(X_ptr + col_offsets, mask=mask, other=0.0) + + # Get cached rms + rstd_row = tl.load(RSTD_ptr) + + X_row = X_row.to(tl.float32) + + # Different bacward graphs for different casting modes + if casting_mode == _CASTING_MODE_LLAMA: + m = (dY_row * W_row).to(tl.float32) + + elif casting_mode == _CASTING_MODE_GEMMA: + dY_row = dY_row.to(tl.float32) + m = dY_row * W_row + else: + m = dY_row * W_row + + dX_row = rstd_row * m + + dX_row += (rstd_row) * (-(1 / n_cols) * rstd_row * rstd_row * tl.sum(m * X_row, axis=0) * X_row) + + # calculate the gradient of W + if casting_mode == _CASTING_MODE_LLAMA: + dW_row += dY_row * (X_row * rstd_row).to(X_dtype) + else: + # here X_row is already in fp32 (see previous if block) + dW_row += dY_row * (X_row * rstd_row) + + tl.store(dX_ptr + col_offsets, dX_row.to(X_dtype), mask=mask) + + dY_ptr += dY_row_stride + dX_ptr += dX_row_stride + X_ptr += X_row_stride + RSTD_ptr += RSTD_row_stride + + tl.store(dW_ptr + row_block_id * dW_row_stride + col_offsets, dW_row, mask=mask) + + +_str_to_casting_mode = { + "llama": _CASTING_MODE_LLAMA.value, + "gemma": _CASTING_MODE_GEMMA.value, + "none": _CASTING_MODE_NONE.value, +} + + +def rms_norm_forward(X, W, eps, offset, casting_mode): + if not 
isinstance(casting_mode, int): + assert casting_mode in _str_to_casting_mode, f"Invalid casting mode: {casting_mode}" + casting_mode = _str_to_casting_mode[casting_mode] + else: + assert casting_mode in _str_to_casting_mode.values(), f"Invalid casting mode: {casting_mode}" + + shape = X.shape + dim = shape[-1] + X = X.view(-1, dim) + n_rows, n_cols = X.shape + BLOCK_SIZE, num_warps = calculate_settings(n_cols) + + Y = torch.empty((n_rows, n_cols), dtype=X.dtype, device=X.device) + # RSTD is to cache rstd for each row + # RSTD is always computed/stored in fp32 if we are using Llama or Gemma casting mode + rstd_dtype = torch.float32 if casting_mode in (_CASTING_MODE_LLAMA.value, _CASTING_MODE_GEMMA.value) else X.dtype + RSTD = torch.empty(n_rows, dtype=rstd_dtype, device=X.device) + + # Check constraints. + assert X.shape[1] == W.shape[0], "Incompatible hidden size dimension between tensor1.shape[1] and tensor2.shape[0]" + + # XPU-specific optimization + kernel_args = {} + if X.device.type == "xpu": + kernel_args["grf_mode"] = "large" + _rms_norm_forward_kernel[(n_rows,)]( + Y, + Y.stride(0), + X, + X.stride(0), + W, + W.stride(0), + RSTD, + RSTD.stride(0), + n_cols, + eps, + offset, + casting_mode, + BLOCK_SIZE=BLOCK_SIZE, + num_warps=num_warps, + **kernel_args, # XPU-specific optimization + ) + return Y.view(*shape), X, RSTD, BLOCK_SIZE, num_warps, casting_mode + + +def rms_norm_backward(dY, X, W, RSTD, offset, casting_mode, BLOCK_SIZE, num_warps, in_place): + shape = dY.shape + dim = shape[-1] + dY = dY.view(-1, dim) + n_rows, n_cols = dY.shape + + sm_count = 1 + if X.device.type == "cuda": + sm_count = torch.cuda.get_device_properties(X.device).multi_processor_count + elif X.device.type == "xpu": + sm_count = torch.xpu.get_device_properties(X.device).gpu_eu_count + + # fp32 for numerical stability especially. 
+ _dW = torch.empty((sm_count, n_cols), dtype=torch.float32, device=W.device) + + if n_cols > BLOCK_SIZE: + raise RuntimeError("This layer norm doesn't support feature dim >= 64KB.") + rows_per_program = math.ceil(n_rows / sm_count) + grid = (sm_count,) + + if in_place is True: + dX = dY + else: + dX = torch.zeros_like(dY) + + # XPU-specific optimization + kernel_args = {} + if X.device.type == "xpu": + kernel_args["grf_mode"] = "large" + + _rms_norm_backward_kernel[grid]( + dY, + dY.stride(0), + dX, + dX.stride(0), + X, + X.stride(0), + torch_to_triton_dtype[X.dtype], + W, + W.stride(0), + RSTD, + RSTD.stride(0), + _dW, + _dW.stride(0), + n_rows, + n_cols, + offset, + rows_per_program, + casting_mode, + BLOCK_SIZE=BLOCK_SIZE, + num_warps=num_warps, + **kernel_args, # XPU-specific optimization + ) + dX = dX.view(*shape) + dW = _dW.sum(dim=0).to(W.dtype) + + return dX, dW + + +class LigerRMSNormFunction(torch.autograd.Function): + """ + Performs RMSNorm (Root Mean Square Normalization), which normalizes the input tensor `X` using the + weight tensor `W`, with an optional offset and casting mode. + + Some models use an 'offset' to shift the weight tensor `W` by a constant value. For example, Gemma + uses an offset of 1.0, so the computation becomes `(X / RMS(X)) * (W + 1.0)` instead of the usual + `(X / RMS(X)) * W`. You can pass the offset value as an argument to the forward function. + + In addition, different models cast their inputs at different places during RMSNorm computation. For + example, Gemma casts everything to fp32 nefore starting the computation, while Llama casts only the + inverse RMS to fp32. You can specify the casting mode using the `casting_mode` argument. We currently + support the following casting modes (they match HuggingFace Transformers' implementations): + - 'llama': matches the Llama implementation, where only the inverse RMS is computed on fp32. 
+ - 'gemma': matches the Gemma implementation, where everything is cast to fp32, then computed, then cast back to the original dtype. + - 'none': no casting is done. The computation is done in the original dtype. This saves memory and is slightly faster, but has more error w.r.t. the original implementation. + + `in_place` option means whether to in_place modify dY to store dX. This is default to `True` to save memory. However, under certain cases, it can produce incorrect inputs. + For example, gemma2 uses two rmsnorm sequentially with residual in between. The resesidual part needs dY so it cannot be modified in-place. + Therefore, for the patching of RMSNorm in gemma2, we set `in_place` to `False` + """ + + @staticmethod + @ensure_contiguous + def forward(ctx, X, W, eps, offset=0.0, casting_mode="llama", in_place=True): + """ + X: (B, T, H) or (BxT, H) + W: (H,) + """ + Y, X, RSTD, BLOCK_SIZE, num_warps, casting_mode = rms_norm_forward(X, W, eps, offset, casting_mode) + ctx.offset = offset + ctx.casting_mode = casting_mode + ctx.in_place = in_place + ctx.BLOCK_SIZE = BLOCK_SIZE + ctx.num_warps = num_warps + ctx.save_for_backward(X, W, RSTD) + return Y + + @staticmethod + @ensure_contiguous + def backward(ctx, dY): + """ + Y: (B, T, H) or (BxT, H) + """ + X, W, RSTD = ctx.saved_tensors + dX, dW = rms_norm_backward( + dY, + X, + W, + RSTD, + ctx.offset, + ctx.casting_mode, + ctx.BLOCK_SIZE, + ctx.num_warps, + ctx.in_place, + ) + return dX, dW, None, None, None, None \ No newline at end of file diff --git a/build/torch-cuda/rope.py b/build/torch-cuda/rope.py new file mode 100644 index 0000000000000000000000000000000000000000..8c4801bf72c82d718ebd5fa63f278fdcbe0f23b3 --- /dev/null +++ b/build/torch-cuda/rope.py @@ -0,0 +1,239 @@ +import torch +import triton +import triton.language as tl + + +@triton.jit +def _triton_rope( + q_ptr, + q_row_stride, + k_ptr, + k_row_stride, + cos, + cos_row_stride, + sin, + sin_row_stride, + sl, + bs: tl.constexpr, + cos_bs: 
tl.constexpr, + n_qh: tl.constexpr, + n_kh: tl.constexpr, + hd: tl.constexpr, + pad_n_qh: tl.constexpr, + pad_n_kh: tl.constexpr, + pad_hd: tl.constexpr, + BLOCK_SIZE: tl.constexpr, + BACKWARD_PASS: tl.constexpr = False, +): + # q size: (bsz, seq_len, num_q_heads, head_dim) + # q stride: (seq_len * num_q_heads * head_dim, num_q_heads * head_dim, head_dim, 1) + # k size: (bsz, seq_len, num_kv_heads, head_dim) + # k stride: (seq_len * num_kv_heads * head_dim, num_kv_heads * head_dim, head_dim, 1) + + # cos size: (1, seq_len, head_dim) or (bsz, seq_len, head_dim) + # stride: (seq_len * head_dim, head_dim, 1) + pid = tl.program_id(0) + + # locate start address + q_ptr = q_ptr + pid * q_row_stride + k_ptr = k_ptr + pid * k_row_stride + + # #################################################################### + # get the cos(mθ_{i...d/2}) and sin(mθ_{i...d/2}) for token position + # m of this program instance + # #################################################################### + + # 1. program instances are laid out in a 1D vector of size bsz * seq_len, which + # effectively represents a 2D grid of size [bsz, seq_len] with seq_len dimension + # being the fastest changing dimension. Thus we can simply do pid // sl to get the batch index + # and pid % sl to get the sequence index. + # 2. We only need the left half of cos and sin matrix because the right half is just + # a clone of the left half. 
+ batch_idx = pid // sl + cos_row_idx = pid % sl + cos = cos + tl.where( + cos_bs == 1, + cos_row_idx * cos_row_stride, + batch_idx * (sl * cos_row_stride) + cos_row_idx * cos_row_stride, + ) + sin = sin + tl.where( + cos_bs == 1, + cos_row_idx * sin_row_stride, + batch_idx * (sl * sin_row_stride) + cos_row_idx * sin_row_stride, + ) + + cos_offsets = tl.arange(0, pad_hd // 2) + cos_mask = cos_offsets < hd // 2 + cos_row = tl.load(cos + cos_offsets, mask=cos_mask, other=0) + sin_row = tl.load(sin + cos_offsets, mask=cos_mask, other=0) + + # #################################################################### + # Load the left and right half of q and k for the current + # program instance (i.e. for the current token) separately + # #################################################################### + # left half of the head + first_half_q_offsets = tl.arange(0, pad_n_qh)[:, None] * hd + tl.arange(0, pad_hd // 2)[None, :] + first_half_k_offsets = tl.arange(0, pad_n_kh)[:, None] * hd + tl.arange(0, pad_hd // 2)[None, :] + first_q_mask = (tl.arange(0, pad_n_qh)[:, None] < n_qh) & (tl.arange(0, pad_hd // 2)[None, :] < hd // 2) + first_k_mask = (tl.arange(0, pad_n_kh)[:, None] < n_kh) & (tl.arange(0, pad_hd // 2)[None, :] < hd // 2) + q_tile_1 = tl.load(q_ptr + first_half_q_offsets, mask=first_q_mask, other=0).to(sin_row.dtype) + k_tile_1 = tl.load(k_ptr + first_half_k_offsets, mask=first_k_mask, other=0).to(sin_row.dtype) + + # right half of the head + second_half_q_offsets = first_half_q_offsets + (hd // 2) + second_half_k_offsets = first_half_k_offsets + (hd // 2) + second_q_mask = first_q_mask + second_k_mask = first_k_mask + q_tile_2 = tl.load(q_ptr + second_half_q_offsets, mask=second_q_mask, other=0).to(sin_row.dtype) + k_tile_2 = tl.load(k_ptr + second_half_k_offsets, mask=second_k_mask, other=0).to(sin_row.dtype) + + if not BACKWARD_PASS: + # y = [x1, x2] * [cos, cos] + [-x2, x1] * [sin, sin] + new_q_tile_1 = q_tile_1 * cos_row - q_tile_2 * sin_row + 
tl.store(q_ptr + first_half_q_offsets, new_q_tile_1, mask=first_q_mask) + new_q_tile_2 = q_tile_2 * cos_row + q_tile_1 * sin_row + tl.store(q_ptr + second_half_q_offsets, new_q_tile_2, mask=second_q_mask) + + new_k_tile_1 = k_tile_1 * cos_row - k_tile_2 * sin_row + tl.store(k_ptr + first_half_k_offsets, new_k_tile_1, mask=first_k_mask) + new_k_tile_2 = k_tile_2 * cos_row + k_tile_1 * sin_row + tl.store(k_ptr + second_half_k_offsets, new_k_tile_2, mask=second_k_mask) + else: + # with some math, we can get: + # dy = [dx1, dx2] * [cos, cos] + [-dx2, dx1] * [-sin, -sin] + new_q_tile_1 = q_tile_1 * cos_row + q_tile_2 * sin_row + tl.store(q_ptr + first_half_q_offsets, new_q_tile_1, mask=first_q_mask) + new_q_tile_2 = q_tile_2 * cos_row - q_tile_1 * sin_row + tl.store(q_ptr + second_half_q_offsets, new_q_tile_2, mask=second_q_mask) + + new_k_tile_1 = k_tile_1 * cos_row + k_tile_2 * sin_row + tl.store(k_ptr + first_half_k_offsets, new_k_tile_1, mask=first_k_mask) + new_k_tile_2 = k_tile_2 * cos_row - k_tile_1 * sin_row + tl.store(k_ptr + second_half_k_offsets, new_k_tile_2, mask=second_k_mask) + + +def rope_forward(q, k, cos, sin): + # transpose it back to the physical shape because Triton looks at the physical storage + # note: q and k are incontiguous before the transformation and will become contiguous after transpose + q = q.transpose(1, 2) + k = k.transpose(1, 2) + + batch_size, seq_len, n_q_head, head_dim = q.shape + n_kv_head = k.shape[2] + pad_hd = triton.next_power_of_2(head_dim) + pad_n_q_head = triton.next_power_of_2(n_q_head) + pad_n_kv_head = triton.next_power_of_2(n_kv_head) + BLOCK_SIZE = max(pad_n_q_head, pad_n_kv_head) + + n_row = batch_size * seq_len + + # ensure tensors passed into the kernel are contiguous. 
It will be no-op if they are already contiguous + q = q.contiguous() + k = k.contiguous() + cos = cos.contiguous() + sin = sin.contiguous() + cos_batch_size = cos.shape[0] + + _triton_rope[(n_row,)]( + q, + q.stride(1), + k, + k.stride(1), + cos, + cos.stride(-2), + sin, + sin.stride(-2), + seq_len, + batch_size, + cos_batch_size, + n_q_head, + n_kv_head, + head_dim, + pad_n_q_head, + pad_n_kv_head, + pad_hd, + BLOCK_SIZE=BLOCK_SIZE, + BACKWARD_PASS=False, + ) + return q.transpose(1, 2), k.transpose(1, 2), cos, sin + + +def rope_backward(dq, dk, cos, sin): + dq = dq.transpose(1, 2) + dk = dk.transpose(1, 2) + + batch_size, seq_len, n_q_head, head_dim = dq.shape + cos_batch_size = cos.shape[0] + n_kv_head = dk.shape[2] + pad_hd = triton.next_power_of_2(head_dim) + pad_n_q_head = triton.next_power_of_2(n_q_head) + pad_n_kv_head = triton.next_power_of_2(n_kv_head) + BLOCK_SIZE = max(pad_n_q_head, pad_n_kv_head) + + n_row = batch_size * seq_len + + # ensure dq and dk are contiguous + dq = dq.contiguous() + dk = dk.contiguous() + + # backward is similar to forward except swapping few ops + _triton_rope[(n_row,)]( + dq, + dq.stride(1), + dk, + dk.stride(1), + cos, + cos.stride(-2), + sin, + sin.stride(-2), + seq_len, + batch_size, + cos_batch_size, + n_q_head, + n_kv_head, + head_dim, + pad_n_q_head, + pad_n_kv_head, + pad_hd, + BLOCK_SIZE=BLOCK_SIZE, + BACKWARD_PASS=True, + ) + return dq.transpose(1, 2), dk.transpose(1, 2) + + +class LigerRopeFunction(torch.autograd.Function): + """ + Triton implementation of the Rotary Positional Embedding (RoPE) operation. Please note that + this implements the HuggingFace Llama & Mistral version, whose rotation matrix is slightly different + than the original RoPE paper. 
+ + Please find the corresponding HuggingFace implementation here: + https://github.com/huggingface/transformers/blob/v4.40.2/src/transformers/models/llama/modeling_llama.py#L184 + + For more details about the rotation matrix used here, please refer to: + https://discuss.huggingface.co/t/is-llama-rotary-embedding-implementation-correct/44509/2 + """ + + @staticmethod + def forward(ctx, q, k, cos, sin, position_ids=None, unsqueeze_dim=1): + """ + q size: (bsz, n_q_head, seq_len, head_dim) + k size: (bsz, n_kv_head, seq_len, head_dim) + cos size: (1, seq_len, head_dim) or (bsz, seq_len, head_dim) + sin size: (1, seq_len, head_dim) or (bsz, seq_len, head_dim) + """ + q, k, cos, sin = rope_forward(q, k, cos, sin) + ctx.save_for_backward(cos, sin) + return q, k + + def backward(ctx, dq, dk): + """ + dq size: (bsz, n_q_head, seq_len, head_dim) + dk size: (bsz, n_kv_head, seq_len, head_dim) + cos size: (1, seq_len, head_dim) or (bsz, seq_len, head_dim) + sin size: (1, seq_len, head_dim) or (bsz, seq_len, head_dim) + """ + + cos, sin = ctx.saved_tensors + dq, dk = rope_backward(dq, dk, cos, sin) + return dq, dk, None, None, None, None \ No newline at end of file diff --git a/build/torch-cuda/swiglu.py b/build/torch-cuda/swiglu.py new file mode 100644 index 0000000000000000000000000000000000000000..f4ed39745bb5966a9d5eaee4643d487e66a4e620 --- /dev/null +++ b/build/torch-cuda/swiglu.py @@ -0,0 +1,116 @@ +import torch +import triton +import triton.language as tl + +from .utils import calculate_settings +from .utils import ensure_contiguous + + +@triton.jit +def silu(x): + return x * tl.sigmoid(x) + + +@triton.jit +def _swiglu_forward_kernel(a_ptr, b_ptr, c_ptr, stride, n_cols: tl.constexpr, BLOCK_SIZE: tl.constexpr): + program_id = tl.program_id(0).to(tl.int64) + + # locate start index + a_ptr += program_id * stride + b_ptr += program_id * stride + c_ptr += program_id * stride + + col_offsets = tl.arange(0, BLOCK_SIZE) + mask = col_offsets < n_cols + + # sigmoid requires type 
float32 + a_row = tl.load(a_ptr + col_offsets, mask=mask, other=0).to(tl.float32) + b_row = tl.load(b_ptr + col_offsets, mask=mask, other=0) + c_row = silu(a_row) * b_row + tl.store(c_ptr + col_offsets, c_row, mask=mask) + + +@triton.jit +def _swiglu_backward_kernel(dc_ptr, a_ptr, b_ptr, stride, n_cols: tl.constexpr, BLOCK_SIZE: tl.constexpr): + program_id = tl.program_id(0).to(tl.int64) + + # locate start index + dc_ptr += program_id * stride + a_ptr += program_id * stride + b_ptr += program_id * stride + + col_offsets = tl.arange(0, BLOCK_SIZE) + mask = col_offsets < n_cols + + dc_row = tl.load(dc_ptr + col_offsets, mask=mask, other=0) + # sigmoid requires type float32 + a_row = tl.load(a_ptr + col_offsets, mask=mask, other=0).to(tl.float32) + b_row = tl.load(b_ptr + col_offsets, mask=mask, other=0) + + # recomputation to save memory + sig_a = tl.sigmoid(a_row) + silu_a = a_row * sig_a + db_row = dc_row * silu_a + da_row = dc_row * (silu_a * (1 - sig_a) + sig_a) * b_row + + tl.store(a_ptr + col_offsets, da_row, mask=mask) + tl.store(b_ptr + col_offsets, db_row, mask=mask) + + +def swiglu_forward(a, b): + ori_shape = a.shape + + n_cols = ori_shape[-1] + a = a.view(-1, n_cols) + b = b.view(-1, n_cols) + c = torch.empty_like(a) + n_rows = a.shape[0] + + BLOCK_SIZE, num_warps = calculate_settings(n_cols) + + _swiglu_forward_kernel[(n_rows,)]( + a, + b, + c, + c.stride(-2), + n_cols=n_cols, + BLOCK_SIZE=BLOCK_SIZE, + num_warps=num_warps, + ) + return a, b, c.view(*ori_shape) + + +def swiglu_backward(a, b, dc): + ori_shape = dc.shape + n_cols = ori_shape[-1] + dc = dc.view(-1, n_cols) + n_rows = dc.shape[0] + + BLOCK_SIZE, num_warps = calculate_settings(n_cols) + + _swiglu_backward_kernel[(n_rows,)]( + dc, + a, + b, + dc.stride(-2), + n_cols=n_cols, + BLOCK_SIZE=BLOCK_SIZE, + num_warps=num_warps, + ) + return a.view(*ori_shape), b.view(*ori_shape) + + +class LigerSiLUMulFunction(torch.autograd.Function): + @staticmethod + @ensure_contiguous + def forward(ctx, a, b): + 
a, b, c = swiglu_forward(a, b) + ctx.save_for_backward(a, b) + return c + + @staticmethod + @ensure_contiguous + def backward(ctx, dc): + a, b = ctx.saved_tensors + a, b = swiglu_backward(a, b, dc) + return a, b \ No newline at end of file diff --git a/build/torch-cuda/tvd.py b/build/torch-cuda/tvd.py new file mode 100644 index 0000000000000000000000000000000000000000..c072e4118c6fa5980de5c96a7acdaa05b6acb468 --- /dev/null +++ b/build/torch-cuda/tvd.py @@ -0,0 +1,207 @@ +from typing import Literal +from typing import Optional + +import torch +import triton +import triton.language as tl + +from .utils import ensure_contiguous + +MAX_FUSED_SIZE = 65536 // 4 + +REDUCTION_LITERAL = Literal["none", "sum", "mean", "batchmean"] + +_REDUCTION_MODE_NONE = tl.constexpr(0) +_REDUCTION_MODE_SUM = tl.constexpr(1) +_REDUCTION_MODE_MEAN = tl.constexpr(2) +_REDUCTION_MODE_BATCHMEAN = tl.constexpr(3) + +_str_to_reduction_mode = { + "none": _REDUCTION_MODE_NONE.value, + "sum": _REDUCTION_MODE_SUM.value, + "mean": _REDUCTION_MODE_MEAN.value, + "batchmean": _REDUCTION_MODE_BATCHMEAN.value, +} + + +def get_num_warps(BLOCK_SIZE): + num_warps = 4 + if BLOCK_SIZE >= 32768: + num_warps = 32 + elif BLOCK_SIZE >= 8192: + num_warps = 16 + elif BLOCK_SIZE >= 2048: + num_warps = 8 + + return num_warps + + +@triton.jit +def _tv_distance_kernel( + p_ptr, + p_stride, + q_ptr, + q_stride, + loss_ptr, + loss_stride, + grads_ptr, + grads_stride, + label_ptr, + ignore_index: tl.constexpr, + n_cols, + BLOCK_SIZE: tl.constexpr, + HAS_LABEL: tl.constexpr, + reduction: tl.constexpr = _REDUCTION_MODE_BATCHMEAN, +): + pid = tl.program_id(0).to(tl.int64) + p_ptr += pid * p_stride + q_ptr += pid * q_stride + loss_ptr += pid * loss_stride + grads_ptr += pid * grads_stride + label_ptr += pid + + base_offsets = tl.arange(0, BLOCK_SIZE) + + if HAS_LABEL: + label = tl.load(label_ptr) + if label == ignore_index: + for i in range(0, n_cols, BLOCK_SIZE): + offsets = i + base_offsets + mask = offsets < n_cols + 
tl.store(grads_ptr + offsets, 0.0, mask=mask) + if reduction == _REDUCTION_MODE_NONE: + tl.store(loss_ptr + offsets, 0.0, mask=mask) + return + + loss_sum = 0.0 + for i in range(0, n_cols, BLOCK_SIZE): + offsets = i + base_offsets + mask = offsets < n_cols + + p = tl.load(p_ptr + offsets, mask=mask, other=0.0) + q = tl.load(q_ptr + offsets, mask=mask, other=0.0) + + # TVD(P || Q) = 0.5 * |P - Q| + tv_loss = 0.5 * tl.abs(p - q) + + grad_res = tl.where(p > q, 0.5, -0.5) + + tl.store(grads_ptr + offsets, grad_res, mask=mask) + + if reduction == _REDUCTION_MODE_NONE: + tl.store(loss_ptr + offsets, tv_loss, mask=mask) + else: + loss_sum += tl.sum(tv_loss, axis=0) + + if reduction != _REDUCTION_MODE_NONE: + tl.store(loss_ptr, loss_sum) + + +def tv_distance_forward_triton(p, q, shift_labels, reduction, ignore_index, has_label): + BT, V = p.shape + + BLOCK_SIZE = min(MAX_FUSED_SIZE, triton.next_power_of_2(V)) + num_warps = get_num_warps(BLOCK_SIZE) + + grid = (BT,) + + reduction = _str_to_reduction_mode[reduction] + + out_size = (BT, V) if reduction == _REDUCTION_MODE_NONE.value else (BT,) + output_tensor = torch.zeros(out_size, device=p.device, dtype=torch.float32) + grads = torch.empty_like(p) + + n_non_ignore = (shift_labels != ignore_index).sum().item() if has_label else BT + + _tv_distance_kernel[grid]( + p, + p.stride(0), + q, + q.stride(0), + output_tensor, + output_tensor.stride(0), + grads, + grads.stride(0), + shift_labels if has_label else torch.empty(1, device=p.device), + ignore_index, + V, + BLOCK_SIZE=BLOCK_SIZE, + HAS_LABEL=has_label, + num_warps=num_warps, + reduction=reduction, + ) + + if reduction == _REDUCTION_MODE_BATCHMEAN.value: + return output_tensor.sum() / n_non_ignore, grads / n_non_ignore + elif reduction == _REDUCTION_MODE_SUM.value: + return output_tensor.sum(dim=0), grads + elif reduction == _REDUCTION_MODE_MEAN.value: + return output_tensor.sum() / (n_non_ignore * V), grads / (n_non_ignore * V) + else: + return output_tensor, grads + + +def 
tvd_backward_triton(grad_output, grads): + # If cross entropy is the last layer, grad_output is 1.0. Skip the mul then. + if torch.equal(grad_output, torch.tensor(1.0, device=grad_output.device)): + return grads + + return grads * grad_output + + +class LigerTVDLossFunction(torch.autograd.Function): + """ + Class implementing the forward and backward pass for the Total Variation Distance Loss using Triton. + """ + + @staticmethod + @ensure_contiguous + def forward( + ctx, + p: torch.Tensor, + q: torch.Tensor, + shift_labels: Optional[torch.Tensor] = None, + reduction: REDUCTION_LITERAL = "batchmean", + ignore_index: int = -100, + ) -> torch.Tensor: + """A forward pass for the Total Variation Distance Loss. + + Args: + ctx: Torch autograd context + p (torch.Tensor): A tensor of shape (BT, V) containing the first distribution. + q (torch.Tensor): A tensor of shape (BT, V) containing the second distribution. + shift_labels (Optional[torch.Tensor]): A tensor of shape (BT,) containing the labels. + reduction (REDUCTION_LITERAL, optional): The reduction method to be applied. Defaults to "batchmean". + ignore_index (int, optional): The index to ignore during loss calculation. Defaults to -100. + + Returns: + torch.Tensor: The computed Total Variation Distance Loss. + """ + has_label = False + if shift_labels is not None: + assert shift_labels.shape == (p.shape[0],), ( + f"the shape of shift_labels must be (BT,). Got: {shift_labels.shape}" + ) + shift_labels = shift_labels.contiguous() + has_label = True + + loss, grads = tv_distance_forward_triton(p, q, shift_labels, reduction, ignore_index, has_label) + ctx.save_for_backward(grads) + return loss + + @staticmethod + @ensure_contiguous + def backward(ctx, grad_output: torch.Tensor) -> torch.Tensor: + """A backward pass for the Total Variation Distance Loss. + + Args: + ctx: Torch autograd context + grad_output (torch.Tensor): The gradient of the loss with respect to the output. 
+ + Returns: + tuple[torch.Tensor, None, None, None, None]: The gradient of the loss with respect to the inputs. + """ + (grads,) = ctx.saved_tensors + grads = tvd_backward_triton(grad_output, grads) + + return grads, None, None, None, None \ No newline at end of file diff --git a/build/torch-cuda/utils.py b/build/torch-cuda/utils.py new file mode 100644 index 0000000000000000000000000000000000000000..718b0bb9c6391cf291bae91e671fc44a2b5ebf36 --- /dev/null +++ b/build/torch-cuda/utils.py @@ -0,0 +1,135 @@ +""" +This file incorporates code from Unsloth licensed under the Apache License, Version 2.0. +See the original Unsloth repository at https://github.com/unslothai/unsloth. + +The following line +https://github.com/linkedin/Liger-Kernel/blob/7382a8761f9af679482b968f9348013d933947c7/src/liger_kernel/ops/utils.py#L23 +is based on code from Unsloth, located at: +https://github.com/unslothai/unsloth/blob/fd753fed99ed5f10ef8a9b7139588d9de9ddecfb/unsloth/kernels/utils.py#L43 + +Modifications made by Yanning Chen, 2024. 
+""" + +import functools +import importlib +import operator + +from typing import Callable + +import torch +import triton +import triton.language as tl + +from packaging.version import Version + +def infer_device(): + """ + Get current device name based on available devices + """ + if torch.cuda.is_available(): # Works for both Nvidia and AMD + return "cuda" + elif torch.xpu.is_available(): + return "xpu" + else: + return "cpu" + +def is_hip() -> bool: + return torch.version.hip is not None + + +def ensure_contiguous(fn): + @functools.wraps(fn) + def wrapper(ctx, *args, **kwargs): + def maybe_to_contiguous(x): + return x.contiguous() if isinstance(x, torch.Tensor) else x + + args = [maybe_to_contiguous(arg) for arg in args] + kwargs = {k: maybe_to_contiguous(v) for k, v in kwargs.items()} + return fn(ctx, *args, **kwargs) + + return wrapper + + +def calculate_settings(n): + # reference: https://github.com/unslothai/unsloth/blob/fd753fed99ed5f10ef8a9b7139588d9de9ddecfb/unsloth/kernels/utils.py#L43 + + MAX_FUSED_SIZE = 65536 + BLOCK_SIZE = triton.next_power_of_2(n) + if BLOCK_SIZE > MAX_FUSED_SIZE: + raise RuntimeError( + f"Cannot launch Triton kernel since n = {n} exceeds the recommended Triton blocksize = {MAX_FUSED_SIZE}." 
+ ) + + num_warps = 4 + if BLOCK_SIZE >= 32768: + num_warps = 32 if not is_hip() else 16 + elif BLOCK_SIZE >= 8192: + num_warps = 16 + elif BLOCK_SIZE >= 2048: + num_warps = 8 + return BLOCK_SIZE, num_warps + + +def compare_version(package: str, operator: Callable, target: str): + try: + pkg = importlib.import_module(package) + except ImportError: + return False + pkg_version = Version(pkg.__version__) + return operator(pkg_version, Version(target)) + + +def get_amp_custom_fwd_bwd() -> Callable: + device = infer_device() + if compare_version("torch", operator.ge, "2.4.0"): + return ( + functools.partial(torch.amp.custom_fwd, device_type=device), + functools.partial(torch.amp.custom_bwd, device_type=device), + ) + return torch.cuda.amp.custom_fwd, torch.cuda.amp.custom_bwd + + +amp_custom_fwd, amp_custom_bwd = get_amp_custom_fwd_bwd() + + +torch_to_triton_dtype = { + torch.float32: tl.float32, + torch.float16: tl.float16, + torch.bfloat16: tl.bfloat16, +} + + +@triton.jit +def element_mul_kernel( + X_ptr, + X_stride, + grad_output_ptr, + n_cols, + BLOCK_SIZE: tl.constexpr, +): + """ + This function multiplies each element of the tensor pointed by X_ptr with the value pointed by grad_output_ptr. + The multiplication is performed in-place on the tensor pointed by X_ptr. + + Parameters: + X_ptr: Pointer to the input tensor. + X_stride (int): The stride of the input tensor. + grad_output_ptr: Pointer to the gradient output value. + n_cols (int): The number of columns in the input tensor. + BLOCK_SIZE (int): The block size for Triton operations. 
+ """ + + # Get the program ID and convert it to int64 to avoid overflow + program_id = tl.program_id(0).to(tl.int64) + + # Locate the start index + X_ptr += program_id * X_stride + + # Load the gradient output value + grad_output = tl.load(grad_output_ptr) + + # Perform the element-wise multiplication + for i in range(0, n_cols, BLOCK_SIZE): + X_offsets = i + tl.arange(0, BLOCK_SIZE) + X_block = tl.load(X_ptr + X_offsets, mask=X_offsets < n_cols) + tl.store(X_ptr + X_offsets, X_block * grad_output, mask=X_offsets < n_cols) \ No newline at end of file diff --git a/build/torch-rocm/__init__.py b/build/torch-rocm/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..271a47c2646379ff9ec8cf8c3c3a0e973f187652 --- /dev/null +++ b/build/torch-rocm/__init__.py @@ -0,0 +1,3 @@ +from . import layers + +__all__ = ["layers"] \ No newline at end of file diff --git a/build/torch-rocm/_ops.py b/build/torch-rocm/_ops.py new file mode 100644 index 0000000000000000000000000000000000000000..45b30870c17a6c248a2a9b9922667f3afc95a973 --- /dev/null +++ b/build/torch-rocm/_ops.py @@ -0,0 +1,8 @@ +import torch +ops = torch.ops._liger_kernels_ab5ef3f + +def add_op_namespace_prefix(op_name: str): + """ + Prefix op by namespace. 
+ """ + return f"_liger_kernels_ab5ef3f::{op_name}" \ No newline at end of file diff --git a/build/torch-rocm/cross_entropy.py b/build/torch-rocm/cross_entropy.py new file mode 100644 index 0000000000000000000000000000000000000000..590d94868b34455f93c98f6c2fb46872d6b6956e --- /dev/null +++ b/build/torch-rocm/cross_entropy.py @@ -0,0 +1,460 @@ +import operator + +from typing import Optional + +import torch +import triton +import triton.language as tl + +from .utils import compare_version +from .utils import element_mul_kernel +from .utils import is_hip +from .utils import infer_device + +if compare_version("triton", operator.ge, "3.0.0"): + try: + # typical import path with dispatch available + from triton.language.extra.libdevice import tanh + except ModuleNotFoundError: + # for working with NGC containers + from triton.language.extra.cuda.libdevice import tanh +else: + from triton.language.math import tanh + + +@triton.jit +def liger_cross_entropy_kernel( + X_ptr, + X_stride, + Y_ptr, + Y_stride, + weight_ptr, + loss_ptr, + z_loss_ptr, + loss_stride, + n_cols, + n_non_ignore, + sum_non_ignore_weight, + weight_sum, + ignore_index, + lse_square_scale: tl.constexpr, + label_smoothing: tl.constexpr, + reduction: tl.constexpr, # set it as constexpr since reduction is always known at compile time + softcap, + RETURN_Z_LOSS: tl.constexpr, + BLOCK_SIZE: tl.constexpr, + HAS_WEIGHT: tl.constexpr, + HAS_SOFTCAPPING: tl.constexpr, +): + """ + This kernel computes both cross entropy loss and the gradient of the input. + We only consider hard label + mean reduction for now. Please refer to https://pytorch.org/docs/stable/generated/torch.nn.CrossEntropyLoss.html for the math. + + Parameters: + X_ptr: Pointer to input tensor. + X_stride (int): The stride of the input tensor. + Y_ptr: Pointer to target tensor. + Y_stride (int): The stride of the target tensor. + weight_ptr: Pointer to weight tensor. + loss_ptr: Pointer to tensor to store the loss. 
+ z_loss_ptr: Pointer to tensor to store the z loss. No operation if RETURN_Z_LOSS is 0. + loss_stride (int): The stride of the loss tensor. + n_cols (int): The number of columns in the input tensor. + n_non_ignore (float): The number of non-ignored elements in the batch. + sum_non_ignore_weight (float): The sum of non-ignored target's weights in the batch. + weight_sum (float): The sum of weight tensor. + ignore_index (int): The index to ignore in the target. + label_smoothing (float): The amount of smoothing when computing the loss, where 0.0 means no smoothing. + lse_square_scale (float): The scaler of (logsumexp(_input)) ^ 2 adding to the loss for the stability of training. + reduction (str): The string for the reduction to apply + softcap (float): The upper threshold for scaling logits to the range (-softcap, +softcap). + RETURN_Z_LOSS (int): The boolean value to decide whether storing z loss to z_loss_ptr or not. It must be 0 or 1. + BLOCK_SIZE (int): The block size for Triton operations. + HAS_WEIGHT (bool): The boolean value to determine whether assigning weight to each of the classes. + HAS_SOFTCAPPING (bool): The boolean value to determine whether applying soft-capping or not. + """ + + # https://github.com/triton-lang/triton/issues/1058 + # If B*T*V is too large, program_id * stride will overflow out of int32, so we convert to int64 + program_id = tl.program_id(0).to(tl.int64) + + # 1. Load Y_ptr first because if the target is ignore_index, we can return right away + Y_ptr += program_id * Y_stride + y = tl.load(Y_ptr) + + # 2. 
locate the start index + X_ptr += program_id * X_stride + + if y == ignore_index: + # set all X_ptr as 0 + for i in range(0, n_cols, BLOCK_SIZE): + X_offsets = i + tl.arange(0, BLOCK_SIZE) + tl.store(X_ptr + X_offsets, 0.0, mask=X_offsets < n_cols) + return + + loss_ptr += program_id * loss_stride + if RETURN_Z_LOSS: + z_loss_ptr += program_id * loss_stride + + if HAS_WEIGHT: + weight_y = tl.load(weight_ptr + y).cast(tl.float32) + + # Online softmax: 2 loads + 1 store (compared with 3 loads + 1 store for the safe softmax) + # Refer to Algorithm 3 in the paper: https://arxiv.org/pdf/1805.02867 + + # 3. [Online softmax] first pass: find max + sum + m = float("-inf") # m is the max value. use the notation from the paper + d = 0.0 # d is the sum. use the notation from the paper + ori_X_y = tl.load(X_ptr + y).cast(tl.float32) # we need to store the original value of X_y for the loss calculation + if HAS_SOFTCAPPING: + ori_X_y = softcap * tanh(ori_X_y / softcap) + + # Label smoothing is a general case of normal cross entropy + # See the full derivation at https://github.com/linkedin/Liger-Kernel/pull/198#issue-2503665310 + scaled_x_sum = 0.0 + eps = label_smoothing / n_cols + + for i in range(0, n_cols, BLOCK_SIZE): + X_offsets = i + tl.arange(0, BLOCK_SIZE) + X_block = tl.load( + X_ptr + X_offsets, + mask=X_offsets < n_cols, + other=float("-inf"), + # Ensure float32 precision for softmax calculation + ).cast(tl.float32) + if HAS_SOFTCAPPING: + X_block = softcap * tanh(X_block / softcap) + block_max = tl.max(X_block) + if label_smoothing > 0: + # scale X beforehand to avoid overflow + if HAS_WEIGHT: + weight_block = tl.load(weight_ptr + X_offsets, mask=X_offsets < n_cols) + scaled_x_sum += tl.sum(tl.where(X_offsets < n_cols, -eps * X_block * weight_block, 0.0)) + else: + scaled_x_sum += tl.sum(tl.where(X_offsets < n_cols, -eps * X_block, 0.0)) + m_new = tl.maximum(m, block_max) + d = d * tl.exp(m - m_new) + tl.sum(tl.exp(X_block - m_new)) + m = m_new + + # log 
(sum(e^(X_i))) = log (sum(e ^ (max(X) * e ^ (X_i - max(X))))) + # = log (e^(max(X)) * sum(e ^ (X_i - max(X)))) + # = max(X) + log (sum(e ^ (X_i - max(X)))) = m + log d + lse = m + tl.log(d) + + # 4. [Online Softmax] Second pass: compute gradients + # For 'mean' reduction, gradients are normalized by number of non-ignored elements (N) + # dx_y = (softmax(x_y) - 1) / N + # dx_i = softmax(x_i) / N, i != y + # For label smoothing: + # dx_i = (softmax(x_i) - label_smoothing / V) / N, V = n_cols, i != y + # dx_y = (softmax(x_y) - label_smoothing / V - (1 - label_smoothing)) / N + # = dx_i - (1 - label_smoothing) / N + # With Z loss: + # dx_i = ((1 + 2 * lse_square_scale * lse) * softmax(x_i) - label_smoothing / V) / N, i != y + # dx_y = dx_i - (1 - label_smoothing) / N + # For 'sum' reduction, no normalization is applied: + # dx_y = softmax(x_y) - 1 + # dx_i = softmax(x_i), for i ≠ y + + for i in range(0, n_cols, BLOCK_SIZE): + X_offsets = i + tl.arange(0, BLOCK_SIZE) + X_block = tl.load( + X_ptr + X_offsets, + mask=X_offsets < n_cols, + other=float("-inf"), + # Ensure float32 precision for softmax calculation + ).cast(tl.float32) + if HAS_SOFTCAPPING: + intermediate = tanh(X_block / softcap) + X_block = softcap * intermediate + + if not HAS_WEIGHT: + # softmax(x_i) + X_block = tl.exp(X_block - m) / d + # derivative of z-loss: 2 * lse_square_scale * lse * softmax(x_i) + X_block += 2 * lse_square_scale * lse * X_block + # smoothing term + X_block += -eps + # special handle dx_y + X_block = tl.where(X_offsets != y, X_block, X_block - (1 - label_smoothing)) + # reduction scale + if reduction == "mean": + X_block = X_block / n_non_ignore + else: + weight_block = tl.load(weight_ptr + X_offsets, mask=X_offsets < n_cols) + softmax_X = tl.exp(X_block - m) / d + # derivative of original_loss + dloss_ori = (1 - label_smoothing) * softmax_X + # specially handle dx_y + dloss_ori = tl.where(X_offsets != y, dloss_ori, dloss_ori - (1 - label_smoothing)) + dloss_ori = dloss_ori * 
weight_y + # derivative of smooth_loss + dloss_smooth = eps * (-weight_block + softmax_X * weight_sum) + # derivative of z-loss + dz_loss = 2 * lse_square_scale * lse * softmax_X + # reduction scale + if reduction == "mean": + dloss_ori = dloss_ori / sum_non_ignore_weight + dloss_smooth = dloss_smooth / sum_non_ignore_weight + # TODO: Implement weighted z_loss. Currently, z_loss is not scaled by weight. + dz_loss = dz_loss / n_non_ignore + # derivative of total_loss + X_block = dloss_ori + dloss_smooth + dz_loss + + # chain rule softcapping + # d(softcap * tanh(x / softcap)) = (1 - tanh^2(x / softcap)) + if HAS_SOFTCAPPING: + X_block = X_block * (1 - intermediate * intermediate) + + tl.store(X_ptr + X_offsets, X_block, mask=X_offsets < n_cols) + + # We need tl.debug_barrier() to ensure the new result of X_ptr is written as mentioned in + # https://github.com/triton-lang/triton/blob/ba42a5c68fd0505f8c42f4202d53be0f8d9a5fe0/python/triton/ops/cross_entropy.py#L34 + tl.debug_barrier() + + # 5. 
Calculate the loss + + # loss = log (softmax(X_y)) = log ((e ^ (X_y - max(X)) / sum(e ^ (X - max(X)))) + # = (X_y - max(X)) - log(sum(e ^ (X - max(X)))) + # = X_y - m - log d = X_y - lse + # sum(e ^ (X - max(X))) must >= 1 because the max term is e ^ 0 = 1 + # So we can safely calculate log (softmax(X_y)) without overflow + loss = lse - ori_X_y + if HAS_WEIGHT: + loss = weight_y * loss + + # Original loss = H(q, p), with label smoothing regularization = H(q', p) and (label_smoothing / V) = eps + # H(q', p) = (1 - label_smoothing) * H(q, p) + label_smoothing * H(u, p) + # = (1 - label_smoothing) * H(q, p) + eps * sum(logsoftmax(x_i)) + # By using m (global max of xi) and d (sum of e^(xi-m)), we can simplify as: + # = (1 - label_smoothing) * H(q, p) + (sum(-eps * x_i) + label_smoothing * (m + logd)) + # Refer to H(q', p) in section 7 of the paper: https://arxiv.org/pdf/1512.00567 + # pytorch: https://github.com/pytorch/pytorch/blob/2981534f54d49fa3a9755c9b0855e7929c2527f0/aten/src/ATen/native/LossNLL.cpp#L516 + # See full derivation at https://github.com/linkedin/Liger-Kernel/pull/198#issuecomment-2333753087 + if label_smoothing > 0: + if HAS_WEIGHT: + smooth_loss = scaled_x_sum + eps * lse * weight_sum + else: + smooth_loss = scaled_x_sum + label_smoothing * lse + loss = loss * (1 - label_smoothing) + smooth_loss + + # An auxiliary loss, z_loss + # Refer to Page14 Loss function section in the paper PaLM: https://www.jmlr.org/papers/v24/22-1144.html + z_loss = lse_square_scale * lse * lse + # Normalize the loss by the number of non-ignored elements if reduction is "mean" + if reduction == "mean": + if HAS_WEIGHT: + loss = loss / sum_non_ignore_weight + else: + loss = loss / n_non_ignore + # TODO: Implement weighted z_loss. Currently, z_loss is not scaled by weight. 
+ z_loss = z_loss / n_non_ignore + loss += z_loss + + tl.store(loss_ptr, loss) + if RETURN_Z_LOSS: + tl.store(z_loss_ptr, z_loss) + + +# The hard limit of TRITON_MAX_TENSOR_NUMEL is 1048576 https://github.com/triton-lang/triton/blob/ba42a5c68fd0505f8c42f4202d53be0f8d9a5fe0/python/triton/language/core.py#L19 +# However, setting limit as 65536 as in LayerNorm tutorial is faster because of less register spilling +# The optimal maximum block size depends on your hardware, your kernel, and your dtype +MAX_FUSED_SIZE = 4096 if infer_device() == "xpu" else 65536 // 2 # the best size we found by manually tuning + + +def cross_entropy_forward( + _input, + target, + weight, + ignore_index, + lse_square_scale, + label_smoothing, + reduction, + softcap, + return_z_loss, +): + assert isinstance(return_z_loss, bool), f"return_z_loss must be True or False. Got: {return_z_loss}" + + BT, V = _input.shape + n_rows = BT + + BLOCK_SIZE = min(MAX_FUSED_SIZE, triton.next_power_of_2(V)) + + # unreduced loss + loss_1d = torch.zeros(n_rows, dtype=_input.dtype, device=_input.device) + z_loss_1d = torch.zeros(n_rows, dtype=_input.dtype, device=_input.device) if return_z_loss else None + + target_mask = target != ignore_index + n_non_ignore = target_mask.sum().item() + assert (target * target_mask).max() < _input.shape[-1], ( + f"Target {target.max()} is out of bounds. Expected < {_input.shape[-1]}" + ) + assert (target * target_mask).min() >= 0, f"Target {target.min()} is out of bounds. Expected >= 0" + sum_non_ignore_weight = n_non_ignore + weight_sum = 0.0 + if weight is not None: + assert weight.shape[0] == V, f"If given, weight has to be a Tensor of size V. Got: {weight.shape}" + assert torch.is_floating_point(weight), ( + f"If given, weight has to be a Tensor of floating point dtype. 
def cross_entropy_backward(_input, grad_output):
    """Finish the cross-entropy backward pass by scaling the stored gradient.

    The forward kernel overwrites ``_input`` with d(loss)/d(logits) (the
    memory-saving trick noted in ``cross_entropy_forward``), so backward only
    needs to scale that buffer by the upstream gradient.

    Parameters:
        _input (tensor): (BT, V) buffer holding the unscaled gradient.
        grad_output (tensor): upstream gradient; a scalar tensor for
            'mean'/'sum' reduction, a (BT,) tensor for reduction='none'.

    Returns:
        tensor: the scaled gradient (same storage as ``_input``, except the
        reduction='none' path which allocates a new tensor).
    """
    # If cross entropy is the last layer, grad_output is 1.0. Skip the mul to save time
    if torch.equal(grad_output, torch.tensor(1.0, device=grad_output.device)):
        pass
    # If reduction is 'none'
    elif grad_output.ndim > 0:
        _input = _input * grad_output.unsqueeze(dim=1)
    # If reduction is ['mean', 'sum'], grad_output is just a scalar
    # We use a Triton kernel instead of a PyTorch operation because modifying inputs in-place
    # for gradient storage and backward multiple times causes anomalies with PyTorch but not with Triton.
    else:
        BT, V = _input.shape
        n_rows = BT
        BLOCK_SIZE = min(MAX_FUSED_SIZE, triton.next_power_of_2(V))

        element_mul_kernel[(n_rows,)](
            _input,
            _input.stride(-2),
            grad_output,
            V,
            BLOCK_SIZE=BLOCK_SIZE,
            num_warps=32 if not is_hip() else 16,
        )

    return _input


class LigerCrossEntropyFunction(torch.autograd.Function):
    """
    This class implements a custom autograd function for the Liger Cross Entropy loss.
    It overrides the forward and backward methods of the torch.autograd.Function class.
    """

    @staticmethod
    def forward(
        ctx,
        _input: torch.Tensor,
        target: torch.Tensor,
        weight: Optional[torch.FloatTensor],
        ignore_index: int = -100,
        lse_square_scale: float = 0.0,
        label_smoothing: float = 0.0,
        reduction: str = "mean",
        softcap: Optional[float] = None,
        return_z_loss: bool = False,
    ):
        """
        The forward pass of the Liger Cross Entropy loss.

        Parameters:
        ctx : The context object.
        _input (tensor): The input tensor of shape (BT, V) where B is batch size, T is sequence length, V is vocab size.
        target (tensor): The target tensor of shape (BT) where each value is in [0, V-1].
        weight (Tensor, optional): a manual rescaling weight given to each class. If given, has to be a Tensor of size V and floating point dtype.
        ignore_index (int): The index to ignore in the target.
        lse_square_scale (float): The scaler of (logsumexp(_input)) ^ 2 adding to the loss for the stability of training.
        label_smoothing (float): The amount of smoothing when computing the loss, where 0.0 means no smoothing.
        reduction (str): The reduction to apply to the output: "none" | "mean" | "sum".
        softcap (Optional[float]): The upper threshold for scaling logits to the range (-softcap, +softcap).
        return_z_loss (bool): When `return_z_loss` is `True`, returns (loss, z_loss) instead of (loss, None). Default: `False`

        Returns:
        tuple: A tuple with the computed losses with respect to loss and z loss. The elements are tensors or None.
        """
        loss, z_loss, _input = cross_entropy_forward(
            _input,
            target,
            weight,
            ignore_index,
            lse_square_scale,
            label_smoothing,
            reduction,
            softcap,
            return_z_loss,
        )
        # TODO: investigation
        # If we don't detach the _input tensor, the memory will double
        # Not sure why but seems that there will be a time both grad and value exist but in different location
        ctx.save_for_backward(_input.detach())
        ctx.return_z_loss = return_z_loss

        return loss, z_loss

    @staticmethod
    def backward(ctx, grad_output, grad_output2):
        """
        The backward pass of the Liger Cross Entropy loss.

        Parameters:
        ctx : The context object with saved tensors.
        grad_output (tensor): The tensor containing the gradient of the loss with respect to the output.
        grad_output2 (tensor): Gradient w.r.t. the z_loss output; unused because z_loss is only for logging.
        Returns:
        tuple: A tuple with the gradients with respect to the inputs. The elements are tensors or None.
        """
        # Fixed typo: parameter was previously misspelled `grad_ouput2` while
        # the docstring said `grad_output2`. autograd passes gradients
        # positionally, so the rename is safe for callers.
        if ctx.return_z_loss:
            del grad_output2  # z_loss is only for logging

        (_input,) = ctx.saved_tensors
        _input = cross_entropy_backward(_input, grad_output)
        return (
            _input,
            None,
            None,
            None,
            None,
            None,
            None,
            None,
            None,
        )
    n_cols,
    n_rows,
    ROWS_PER_PROGRAM: tl.constexpr,
    BLOCK_SIZE: tl.constexpr,
):
    """
    Reference:
    https://arxiv.org/abs/2503.10622

    Shapes:
        - x: (BT, C)
        - alpha: (1)
        - gamma: (C)
        - dx: (BT, C)
        - dy: (BT, C)
        - dgamma: (sm_count, C)
        - dalpha: (sm_count,)
    """
    # d(gamma * tanh(alpha * x) + beta) / dx
    # = gamma * (1 - tanh^2(alpha * x)) * alpha
    # d(gamma * tanh(alpha * x) + beta) / dalpha
    # = gamma * (1 - tanh^2(alpha * x)) * x
    # d(gamma * tanh(alpha * x) + beta) / dgamma
    # = tanh(alpha * x)
    # d(gamma * tanh(alpha * x)) / dbeta = 1
    pid = tl.program_id(0)

    # Each program handles a contiguous strip of ROWS_PER_PROGRAM rows and
    # accumulates its partial dalpha/dgamma; the host sums over programs.
    row_start = pid * ROWS_PER_PROGRAM
    row_end = min((pid + 1) * ROWS_PER_PROGRAM, n_rows)
    offsets = tl.arange(0, BLOCK_SIZE)
    mask = offsets < n_cols

    dalpha = 0.0
    dgamma = tl.zeros((BLOCK_SIZE,), dtype=tl.float32)

    x_ptr += row_start * x_row_stride
    dx_ptr += row_start * dx_row_stride
    dy_ptr += row_start * dy_row_stride
    alpha = tl.load(alpha_ptr)
    gamma = tl.load(gamma_ptr + offsets, mask=mask, other=0.0)

    for _ in tl.range(row_start, row_end):
        dy = tl.load(dy_ptr + offsets, mask=mask, other=0.0)
        x = tl.load(x_ptr + offsets, mask=mask, other=0.0)
        # tanh is recomputed here instead of saved in forward (recompute-to-save-memory).
        tanh_ax = tanh((alpha * x).cast(tl.float32))
        sech2_ax = 1 - tanh_ax * tanh_ax

        dx = dy * gamma * sech2_ax * alpha
        dalpha += tl.sum(dy * gamma * sech2_ax * x)
        dgamma += dy * tanh_ax
        tl.store(dx_ptr + offsets, dx, mask=mask)

        # advance all row pointers to the next row of the strip
        dy_ptr += dy_row_stride
        x_ptr += x_row_stride
        dx_ptr += dx_row_stride

    # per-program partial results, reduced on the host in liger_dyt_bwd
    tl.store(dgamma_ptr + pid * dgamma_row_stride + offsets, dgamma, mask=mask)
    tl.store(dalpha_ptr + pid, dalpha)

    pass  # NOTE(review): no-op statement, kept for byte-fidelity; could be removed


def liger_dyt_fwd(x, alpha, gamma, beta):
    """Host-side launcher for the DyT forward kernel: y = gamma * tanh(alpha * x) + beta.

    Flattens x to (BT, C), launches one program per row, and restores the
    original shape on return.
    """
    shape = x.shape
    dim = shape[-1]
    x = x.view(-1, dim)
    n_rows, n_cols = x.shape
    y = torch.empty_like(x)
    BLOCK_SIZE, num_warps = calculate_settings(n_cols)
    _dyt_fwd_kernel[(n_rows,)](
        x_ptr=x,
        alpha_ptr=alpha,
        gamma_ptr=gamma,
        beta_ptr=beta,
        y_ptr=y,
        x_row_stride=x.stride(0),
        y_row_stride=y.stride(0),
        n_cols=n_cols,
        BLOCK_SIZE=BLOCK_SIZE,
        num_warps=num_warps,
    )
    return y.view(*shape)


def liger_dyt_bwd(dy, x, alpha, gamma):
    """Host-side launcher for the DyT backward kernel.

    Launches sm_count programs (one per streaming multiprocessor / subslice),
    each accumulating partial dalpha/dgamma over a strip of rows; the partials
    are summed here. Returns (dx, dalpha, dgamma, dbeta) reshaped/cast to the
    input's shape and dtype.
    """
    shape = dy.shape
    dtype = x.dtype
    dim = shape[-1]
    dy = dy.view(-1, dim)
    x = x.view(-1, dim)
    n_rows, n_cols = dy.shape
    BLOCK_SIZE, num_warps = calculate_settings(n_cols)
    sm_count = 1
    device = infer_device()
    if device == "cuda":
        sm_count = torch.cuda.get_device_properties(x.device).multi_processor_count
    elif device == "xpu":
        sm_count = torch.xpu.get_device_properties(x.device).gpu_subslice_count
    # the kernel processes a whole row per program, so the feature dim must fit in one block
    if n_cols > BLOCK_SIZE:
        raise RuntimeError(
            f"Feature dimension {dim} exceeds maximum supported size of {BLOCK_SIZE}. Consider using a smaller feature dimension."
        )

    dx = torch.empty_like(x, dtype=torch.float32)
    # per-program partial accumulators, reduced below
    _dalpha = torch.empty((sm_count,), dtype=torch.float32, device=x.device)
    _dgamma = torch.empty((sm_count, n_cols), dtype=torch.float32, device=x.device)

    grid = (sm_count,)
    rows_per_program = triton.cdiv(n_rows, sm_count)
    _dyt_bwd_kernel[grid](
        x_ptr=x,
        x_row_stride=x.stride(0),
        dy_ptr=dy,
        dy_row_stride=dy.stride(0),
        dx_ptr=dx,
        dx_row_stride=dx.stride(0),
        alpha_ptr=alpha,
        dalpha_ptr=_dalpha,
        gamma_ptr=gamma,
        dgamma_ptr=_dgamma,
        dgamma_row_stride=_dgamma.stride(0),
        n_cols=n_cols,
        n_rows=n_rows,
        ROWS_PER_PROGRAM=rows_per_program,
        BLOCK_SIZE=BLOCK_SIZE,
        num_warps=num_warps,
    )
    dalpha = _dalpha.sum(dim=0, keepdim=True).to(dtype)  # keepdim matches alpha's (1,) shape
    dgamma = _dgamma.sum(dim=0).to(dtype)
    dbeta = dy.sum(dim=0).to(dtype)  # dbeta needs no kernel: d/dbeta is identity
    return dx.view(*shape), dalpha, dgamma, dbeta


class LigerDyTFunction(torch.autograd.Function):
    # Autograd wrapper around the DyT (Dynamic Tanh, arXiv:2503.10622) kernels.
    @staticmethod
    @ensure_contiguous
    def forward(ctx, x, alpha, gamma, beta):
        y = liger_dyt_fwd(x, alpha, gamma, beta)
        # beta is not saved: its gradient is just dy.sum(dim=0)
        ctx.save_for_backward(x, alpha, gamma)
        return y

    @staticmethod
    @ensure_contiguous
    def backward(ctx, grad_output):
        x, alpha, gamma = ctx.saved_tensors
        dx, dalpha, dgamma, dbeta = liger_dyt_bwd(
            grad_output,
            x,
            alpha,
gamma, + ) + + return (dx, dalpha, dgamma, dbeta) \ No newline at end of file diff --git a/build/torch-rocm/fused_linear_cross_entropy.py b/build/torch-rocm/fused_linear_cross_entropy.py new file mode 100644 index 0000000000000000000000000000000000000000..a923c05e268872a75ae473f23c7e64df491c866b --- /dev/null +++ b/build/torch-rocm/fused_linear_cross_entropy.py @@ -0,0 +1,283 @@ +import torch +import triton + +from .cross_entropy import liger_cross_entropy_kernel +from .utils import amp_custom_bwd +from .utils import amp_custom_fwd +from .utils import element_mul_kernel +from .utils import is_hip + +# The hard limit of TRITON_MAX_TENSOR_NUMEL is 1048576 https://github.com/triton-lang/triton/blob/ba42a5c68fd0505f8c42f4202d53be0f8d9a5fe0/python/triton/language/core.py#L19 +# However, setting limit as 65536 as in LayerNorm tutorial is faster because of less register spilling +# The optimal maximum block size depends on your hardware, your kernel, and your dtype +MAX_FUSED_SIZE = 65536 // 2 + + +def fused_linear_cross_entropy_forward( + _input, + weight, + target, + ce_weight=None, + bias=None, + ignore_index=-100, + lse_square_scale=0.0, + label_smoothing=0.0, + reduction="mean", + softcap=None, + return_z_loss=False, +): + assert isinstance(return_z_loss, bool), f"return_z_loss must be True or False. Got: {return_z_loss}" + device = _input.device + + # inputs have shape: BT x H + # materialized activations will have shape: BT x V + # the increase in memory = BT x V + # reduction can be achieved by partitioning the number of tokens BT into smaller chunks. 
+ # for ex: if we were to achieve the same memory consumption as BT x H, then the chunk size should be: + # inc_factor = (V+H-1)//H, chunk_size = (BT + inc_factor - 1)//inc_factor + # for ex: BT = 4096*4, V = 32000, H = 4096 ==> inc_factor = 8, chunk_size = 2048 + BT, H = _input.shape + V = weight.shape[0] + BLOCK_SIZE = min(MAX_FUSED_SIZE, triton.next_power_of_2(V)) + + inc_factor = triton.cdiv(V, H) # (V + H - 1) // H + chunk_size = triton.next_power_of_2(triton.cdiv(BT, inc_factor)) # (BT + inc_factor - 1) // inc_factor + num_chunks = triton.cdiv(BT, chunk_size) # (BT + chunk_size - 1) // chunk_size + + grad_weight = torch.zeros_like(weight, device=device) if weight.requires_grad else None + grad_input = torch.zeros_like(_input, device=device) + grad_bias = torch.zeros_like(bias, device=device) if bias is not None else None + # we use fp32 for loss accumulator + loss_1d = torch.zeros(BT, dtype=torch.float32, device=device) + z_loss_1d = torch.zeros(BT, dtype=_input.dtype, device=_input.device) if return_z_loss else None + + # TODO: evaluate how CUDA synchronization caused by .item() affects the speed + target_mask = target != ignore_index + total_n_non_ignore = target_mask.sum().item() + total_sum_non_ignore_ce_weight = total_n_non_ignore + ce_weight_sum = 0.0 + if ce_weight is not None: + assert ce_weight.shape[0] == V, f"If given, weight has to be a Tensor of size V. Got: {ce_weight.shape}" + assert torch.is_floating_point(ce_weight), ( + f"If given, weight has to be a Tensor of floating point dtype. 
Got: {ce_weight.dtype}" + ) + total_sum_non_ignore_ce_weight = ( + torch.gather(ce_weight, dim=0, index=target.masked_select(target_mask)).sum().item() + ) + ce_weight_sum = ce_weight.sum().item() + if ce_weight.stride(-1) != 1: + ce_weight = ce_weight.contiguous() + + for chunk_id in range(num_chunks): + start_idx = chunk_id * chunk_size + end_idx = min((chunk_id + 1) * chunk_size, BT) + _input_chunk = _input[start_idx:end_idx] # chunk_size x H + + # when doing matmul, use the original precision + logits_chunk = _input_chunk @ weight.t() # chunk_size x V + if bias is not None: + logits_chunk = logits_chunk + bias + + target_chunk = target[start_idx:end_idx] # chunk_size, + + n_rows = logits_chunk.shape[0] + + # unreduced loss + loss_1d_slice = loss_1d[start_idx:end_idx] # chunk_size, + z_loss_1d_slice = z_loss_1d[start_idx:end_idx] if return_z_loss else None + + # ensure _input and target are contiguous + logits_chunk = logits_chunk.contiguous() + target_chunk = target_chunk.contiguous() + + # Here we calculate the gradient of logits_chunk in place so we can save memory. 
+ liger_cross_entropy_kernel[(n_rows,)]( + X_ptr=logits_chunk, + X_stride=logits_chunk.stride(-2), + Y_ptr=target_chunk, + Y_stride=target_chunk.stride(-1), # always 1 + weight_ptr=ce_weight, + loss_ptr=loss_1d_slice, + z_loss_ptr=z_loss_1d_slice, + loss_stride=loss_1d_slice.stride(-1), # always 1 + n_cols=V, + n_non_ignore=total_n_non_ignore, + sum_non_ignore_weight=total_sum_non_ignore_ce_weight, + weight_sum=ce_weight_sum, + ignore_index=ignore_index, + lse_square_scale=lse_square_scale, + label_smoothing=label_smoothing, + reduction=reduction, + softcap=softcap, + RETURN_Z_LOSS=return_z_loss, + HAS_WEIGHT=True if ce_weight is not None else False, + HAS_SOFTCAPPING=True if softcap is not None else False, + BLOCK_SIZE=BLOCK_SIZE, + num_warps=32 if not is_hip() else 16, + ) + + loss_1d[start_idx:end_idx] = loss_1d_slice + if return_z_loss: + z_loss_1d[start_idx:end_idx] = z_loss_1d_slice + grad_logits_chunk = logits_chunk # chunk_size x V + + grad_input[start_idx:end_idx] = grad_logits_chunk @ weight + + if grad_weight is not None: + torch.addmm( + input=grad_weight, + mat1=logits_chunk.t().to( + _input_chunk.dtype + ), # In an autocast scenario without bias, differing logits_chunk data types will cause an addmm operation error. + mat2=_input_chunk, + out=grad_weight, + alpha=1.0, + beta=1.0, + ) + + if bias is not None: + torch.add( + input=grad_bias, + other=logits_chunk.sum(dim=0), + out=grad_bias, + alpha=1.0, + ) + + # Need extra calculations for backward if reduction=='none'. Not supporting reduction='none' now. + # if reduction == "none": + # loss = loss_1d + # z_loss = z_loss_1d if return_z_loss else None + + else: + loss = torch.sum(loss_1d) + z_loss = torch.sum(z_loss_1d) if return_z_loss else None + return loss, z_loss, grad_input, grad_weight, grad_bias + + +def fused_linear_cross_entropy_backward(grad_output, grad_input, grad_weight, grad_bias): + # If cross entropy is the last layer, grad_output is 1.0. 
Skip the mul to save time + if not torch.equal(grad_output, torch.tensor(1.0, device=grad_output.device)): + # We use a Triton kernel instead of a PyTorch operation because modifying inputs in-place + # for gradient storage and backward multiple times causes anomalies with PyTorch but not with Triton. + BT, H = grad_input.shape + n_rows = BT + BLOCK_SIZE = min(MAX_FUSED_SIZE, triton.next_power_of_2(H)) + + element_mul_kernel[(n_rows,)]( + grad_input, + grad_input.stride(-2), + grad_output, + H, + BLOCK_SIZE=BLOCK_SIZE, + num_warps=32 if not is_hip() else 16, + ) + + # handle grad_weight + if grad_weight is not None: + V, H = grad_weight.shape + n_rows = V + + element_mul_kernel[(n_rows,)]( + grad_weight, + grad_weight.stride(-2), + grad_output, + H, + BLOCK_SIZE=BLOCK_SIZE, + num_warps=32 if not is_hip() else 16, + ) + + if grad_bias is not None: + V = grad_bias.shape[0] + n_rows = V + + element_mul_kernel[(n_rows,)]( + grad_bias, + grad_bias.stride(-1), + grad_output, + 1, + BLOCK_SIZE=BLOCK_SIZE, + num_warps=32 if not is_hip() else 16, + ) + return grad_input, grad_weight, grad_bias + + +class LigerFusedLinearCrossEntropyFunction(torch.autograd.Function): + @staticmethod + @amp_custom_fwd + def forward( + ctx, + _input, + weight, + target, + bias=None, + ce_weight=None, + ignore_index=-100, + lse_square_scale=0.0, + label_smoothing=0.0, + reduction="mean", + softcap=None, + return_z_loss: bool = False, + ): + """ + Fusing the last linear layer with cross-entropy loss + Reference: https://github.com/mgmalek/efficient_cross_entropy + + Handle the forward and backward pass of the final linear layer via cross-entropy loss by avoiding + the materialization of the large logits tensor. Since Cross Entropy Loss is the last layer, we can + compute the gradient at the forward pass. By doing so, we don't have to store the _input and target + for the backward pass. + + _input: (B*T, H) where B is batch size, T is sequence length, H is hidden dimension. 
+ target: (B*T) where each value is in [0, V-1] + weight: (V, H) where V is the number of classes + bias: (V) where V is the number of classes + ce_weight: a manual rescaling weight given to each class. If given, has to be a Tensor of size V and floating point dtype + ignore_index: the index to ignore in the target + label_smoothing (float): The amount of smoothing when computing the loss, where 0.0 means no smoothing. + reduction: reduction to apply + """ + + loss, z_loss, grad_input, grad_weight, grad_bias = fused_linear_cross_entropy_forward( + _input=_input, + weight=weight, + target=target, + bias=bias, + ce_weight=ce_weight, + ignore_index=ignore_index, + lse_square_scale=lse_square_scale, + label_smoothing=label_smoothing, + reduction=reduction, + softcap=softcap, + return_z_loss=return_z_loss, + ) + # downcast to dtype and store for backward + ctx.save_for_backward( + grad_input.detach(), + grad_weight.detach() if grad_weight is not None else None, + grad_bias.detach() if bias is not None else None, + ) + ctx.return_z_loss = return_z_loss + return loss, z_loss + + @staticmethod + @amp_custom_bwd + def backward(ctx, grad_output, grad_output2): + if ctx.return_z_loss: + del grad_output2 # z_loss is only for logging + (grad_input, grad_weight, grad_bias) = ctx.saved_tensors + grad_input, grad_weight, grad_bias = fused_linear_cross_entropy_backward( + grad_output, grad_input, grad_weight, grad_bias + ) + return ( + grad_input, + grad_weight, + None, + grad_bias, + None, + None, + None, + None, + None, + None, + None, + ) \ No newline at end of file diff --git a/build/torch-rocm/geglu.py b/build/torch-rocm/geglu.py new file mode 100644 index 0000000000000000000000000000000000000000..fdc06723b2852b6f5a6043ad45968af154f9afeb --- /dev/null +++ b/build/torch-rocm/geglu.py @@ -0,0 +1,141 @@ +import operator + +import torch +import triton +import triton.language as tl + +from .utils import calculate_settings +from .utils import compare_version +from .utils import 
from .utils import ensure_contiguous

if compare_version("triton", operator.ge, "3.0.0"):
    try:
        # typical import path with dispatch available
        from triton.language.extra.libdevice import tanh
    except ModuleNotFoundError:
        # for working with NGC containers
        from triton.language.extra.cuda.libdevice import tanh
else:
    from triton.language.math import tanh


@triton.jit
def _geglu_tanh_forward_kernel(a, b, c, stride, n_cols: tl.constexpr, BLOCK_SIZE: tl.constexpr):
    """Row-wise GeGLU forward: c = GELU_tanh(a) * b, one program per row."""
    program_id = tl.program_id(0).to(tl.int64)

    # locate start index
    a += program_id * stride
    b += program_id * stride
    c += program_id * stride

    col_offsets = tl.arange(0, BLOCK_SIZE)
    mask = col_offsets < n_cols
    a_row = tl.load(a + col_offsets, mask=mask, other=0).to(tl.float32)
    b_row = tl.load(b + col_offsets, mask=mask, other=0)

    # tanh approximation form of GELU is computed with:
    # 0.5 * a * (1 + tanh(sqrt(2 / pi) * (a + 0.044715 * a^3)))
    sqrt_2_over_pi = 0.7978845608028654  # sqrt(2 / pi)
    a_cubed = a_row * a_row * a_row
    tanh_arg = sqrt_2_over_pi * (a_row + 0.044715 * a_cubed)
    tanh_result = tanh(tanh_arg)
    geglu_a = 0.5 * a_row * (1 + tanh_result)
    c_row = geglu_a * b_row
    tl.store(c + col_offsets, c_row, mask=mask)


@triton.jit
def _geglu_tanh_backward_kernel(dc, a, b, stride, n_cols: tl.constexpr, BLOCK_SIZE: tl.constexpr):
    """Row-wise GeGLU backward.

    The gradients are written back IN PLACE into the activation buffers:
    a receives da and b receives db (memory-saving reuse; see the tl.store
    calls at the bottom).
    """
    program_id = tl.program_id(0).to(tl.int64)

    # locate start index
    dc += program_id * stride
    a += program_id * stride
    b += program_id * stride

    col_offsets = tl.arange(0, BLOCK_SIZE)
    mask = col_offsets < n_cols

    dc_row = tl.load(dc + col_offsets, mask=mask, other=0)
    a_row = tl.load(a + col_offsets, mask=mask, other=0).to(tl.float32)
    b_row = tl.load(b + col_offsets, mask=mask, other=0)

    # recomputation to save memory
    sqrt_2_over_pi = 0.7978845608028654  # sqrt(2 / pi)
    a_cubed = a_row * a_row * a_row
    tanh_arg = sqrt_2_over_pi * (a_row + 0.044715 * a_cubed)
    tanh_result = tanh(tanh_arg)
    geglu_a = 0.5 * a_row * (1 + tanh_result)

    db_row = dc_row * geglu_a

    # Gradient w.r.t. a can be computed with:
    # b * (0.5 * (1 + tanh(z)) + 0.5 * a * (1 - tanh(z)^2) * (sqrt(2/pi) * (1 + 3 * 0.044715 * a^2)))
    # where z = sqrt(2/pi) * (a + 0.044715 * a^3)
    term1 = 0.5 * (1 + tanh_result)
    tanh_sq = tanh_result * tanh_result
    term2 = 0.5 * a_row * (1 - tanh_sq) * (sqrt_2_over_pi * (1 + 3 * 0.044715 * a_row * a_row))
    da_row = dc_row * b_row * (term1 + term2)

    # in-place: a <- da, b <- db
    tl.store(a + col_offsets, da_row, mask=mask)
    tl.store(b + col_offsets, db_row, mask=mask)


def geglu_forward(a, b):
    """Launch the GeGLU forward kernel.

    Returns (a, b, c) where c = GELU_tanh(a) * b reshaped back to the input
    shape; a and b are returned flattened so the caller can save them for
    backward.
    """
    ori_shape = a.shape

    n_cols = ori_shape[-1]
    a = a.view(-1, n_cols)
    b = b.view(-1, n_cols)
    c = torch.empty_like(a)
    n_rows = a.shape[0]

    BLOCK_SIZE, num_warps = calculate_settings(n_cols)

    _geglu_tanh_forward_kernel[(n_rows,)](
        a,
        b,
        c,
        c.stride(-2),
        n_cols=n_cols,
        BLOCK_SIZE=BLOCK_SIZE,
        num_warps=num_warps,
    )
    return a, b, c.view(*ori_shape)


def geglu_backward(a, b, dc):
    """Launch the GeGLU backward kernel.

    a and b are overwritten by the kernel with da and db respectively and
    returned reshaped to dc's original shape.
    """
    ori_shape = dc.shape
    n_cols = ori_shape[-1]
    dc = dc.view(-1, n_cols)
    n_rows = dc.shape[0]

    BLOCK_SIZE, num_warps = calculate_settings(n_cols)

    _geglu_tanh_backward_kernel[(n_rows,)](
        dc,
        a,
        b,
        dc.stride(-2),
        n_cols=n_cols,
        BLOCK_SIZE=BLOCK_SIZE,
        num_warps=num_warps,
    )

    return a.view(*ori_shape), b.view(*ori_shape)


class LigerGELUMulFunction(torch.autograd.Function):
    # Autograd wrapper for GeGLU: forward saves the (flattened) activations,
    # backward reuses their storage for the gradients.
    @staticmethod
    @ensure_contiguous
    def forward(ctx, a, b):
        a, b, c = geglu_forward(a, b)
        ctx.save_for_backward(a, b)
        return c

    @staticmethod
    @ensure_contiguous
    def backward(ctx, dc):
        a, b = ctx.saved_tensors
        a, b = geglu_backward(a, b, dc)
        return a, b


# --- group_norm.py ---
import operator

import torch
import triton
import triton.language as tl

from .utils import compare_version
from .utils import ensure_contiguous

if compare_version("triton", operator.ge, "3.0.0"):
    try:
        # typical import path with dispatch available
        from triton.language.extra.libdevice import rsqrt
    except ModuleNotFoundError:
        # for working with NGC containers
        from triton.language.extra.cuda.libdevice import rsqrt
else:
    from triton.language.math import rsqrt

MAX_FUSED_SIZE = 65536


@triton.jit
def _group_norm_forward_kernel(
    Y_ptr,  # pointer to output, shape (n_rows, n_groups, hidden_size)
    Y_row_stride,  # stride of each row in output
    Y_col_stride,  # stride of each column in output
    X_ptr,  # pointer to input, shape (n_rows, n_groups, hidden_size)
    X_row_stride,  # stride of each row in input
    X_col_stride,  # stride of each column in input
    Mean_ptr,  # pointer to mean, shape (n_rows, n_groups)
    Mean_row_stride,  # stride of each row in mean
    Mean_col_stride,  # stride of each column in mean
    RSTD_ptr,  # pointer to rstd, shape (n_rows, n_groups)
    RSTD_row_stride,  # stride of each row in rstd
    RSTD_col_stride,  # stride of each column in rstd
    W_ptr,  # pointer to W
    B_ptr,  # pointer to B
    hidden_size,  # hidden size of X
    channels_per_group,  # the number of channels per group
    eps,
    BLOCK_SIZE: tl.constexpr,
):
    """
    One program per (batch, group): compute the group's mean/rstd, then apply
    the per-channel affine transform.

    References:
    https://nn.labml.ai/normalization/group_norm/index.html
    """
    batch_idx = tl.program_id(0)
    group_idx = tl.program_id(1)

    X_ptr += batch_idx * X_row_stride + group_idx * X_col_stride
    Y_ptr += batch_idx * Y_row_stride + group_idx * Y_col_stride

    block_range = tl.arange(0, BLOCK_SIZE)

    # Compute mean and variance using the online algorithm
    s = 0.0
    squared_sum = 0.0
    for i in tl.range(0, hidden_size, BLOCK_SIZE):
        hidden_size_offsets = i + block_range
        mask = hidden_size_offsets < hidden_size
        X = tl.load(X_ptr + hidden_size_offsets, mask=mask, other=0.0)
        s += tl.sum(X)
        # X**2
        squared_sum += tl.sum(X * X)

    m = s / hidden_size

    # variance = E[X**2] - E[X]**2
    variance = (squared_sum / hidden_size) - (m * m)

    # 1/std
    rstd = rsqrt(variance + eps)

    # Normalize: walk the group channel by channel, each with its own W/B
    hidden_size_per_channel = hidden_size // channels_per_group
    for channel_idx in tl.range(group_idx * channels_per_group, (group_idx + 1) * channels_per_group):
        W = tl.load(W_ptr + channel_idx)
        B = tl.load(B_ptr + channel_idx)
        for i in range(0, hidden_size_per_channel, BLOCK_SIZE):
            hidden_size_offsets = i + block_range
            mask = hidden_size_offsets < hidden_size_per_channel
            # other=m makes masked lanes normalize to exactly 0 before W/B
            X = tl.load(X_ptr + hidden_size_offsets, mask=mask, other=m)
            Y = (X - m) * rstd * W + B
            tl.store(Y_ptr + hidden_size_offsets, Y, mask=mask)

        X_ptr += hidden_size_per_channel
        Y_ptr += hidden_size_per_channel

    # stash statistics for the backward pass
    tl.store(Mean_ptr + batch_idx * Mean_row_stride + group_idx * Mean_col_stride, m)
    tl.store(RSTD_ptr + batch_idx * RSTD_row_stride + group_idx * RSTD_col_stride, rstd)


@triton.jit
def _group_norm_backward_kernel(
    X_ptr,  # pointer to input, shape (n_rows, n_channels, hidden_size)
    X_row_stride,  # stride of each row in input
    X_col_stride,  # stride of each column in input
    W_ptr,  # pointer to weights, shape (n_channels)
    Mean_ptr,  # pointer to mean, shape (n_rows, n_groups)
    Mean_ptr_row_stride,  # stride of each column in mean
    Mean_ptr_col_stride,  # stride of each column in mean
    RSTD_ptr,  # pointer to rstd, shape (n_rows, n_groups)
    DX_ptr,  # pointer to input grad, shape (n_rows, n_groups, hidden_size)
    DW_ptr,  # pointer to weights grad, shape (n_channels)
    DB_ptr,  # pointer to bias grad, shape (n_channels)
    UPSTREAM_ptr,  # pointer to output grad, shape (n_rows, n_channels, hidden_size)
    hidden_size: tl.constexpr,  # hidden size
    channels_per_group: tl.constexpr,  # number of groups in group norm
    BLOCK_SIZE: tl.constexpr,
    dtype: tl.constexpr,
):
    """
    References:
    https://nn.labml.ai/normalization/group_norm/index.html
    https://github.com/karpathy/llm.c/blob/master/doc/layernorm/layernorm.md

    The backprop equations are the same for group_norm and layer_norm
    the only difference here is that we load the Mean, Rstd corresponding to the
    group we're computing gradients for and the mean and rstd are computed over n-channels
    so the total number of elements we compute the mean over is num_channels_per_group * hidden_size

    We also need to load the Weights corresponding to the current channel to compute the gradients.
    """
    batch_idx = tl.program_id(0)
    group_idx = tl.program_id(1)

    # Move the pointers to the correct batch
    X_ptr += batch_idx * X_row_stride
    DX_ptr += batch_idx * X_row_stride
    UPSTREAM_ptr += batch_idx * X_row_stride

    # Mean and rstd are the same shape so have the same strides
    mean = tl.load(Mean_ptr + batch_idx * Mean_ptr_row_stride + group_idx * Mean_ptr_col_stride)
    rstd = tl.load(RSTD_ptr + batch_idx * Mean_ptr_row_stride + group_idx * Mean_ptr_col_stride)

    # c1/c2 are the two reduction terms of the layer-norm backward formula,
    # accumulated over every channel in this group (first pass below)
    c1 = 0.0
    c2 = 0.0
    block_range = tl.arange(0, BLOCK_SIZE)

    # We need to compute the sum terms of the backprop equations across all channels in the group
    for channel_idx in range(group_idx * channels_per_group, (group_idx + 1) * channels_per_group):
        dW = 0.0
        dB = 0.0
        # Move the pointers to the correct channel
        W = tl.load(W_ptr + channel_idx)
        for i in tl.range(0, hidden_size, BLOCK_SIZE):
            hidden_size_offsets = i + block_range
            mask = hidden_size_offsets < hidden_size
            X = tl.load(
                X_ptr + channel_idx * X_col_stride + hidden_size_offsets,
                mask=mask,
                other=0.0,
            )
            UPSTREAM_grad = tl.load(
                UPSTREAM_ptr + channel_idx * X_col_stride + hidden_size_offsets,
                mask=mask,
                other=0.0,
            )

            x_hat = (X - mean) * rstd
            dW += tl.sum(UPSTREAM_grad * x_hat)
            dB += tl.sum(UPSTREAM_grad)

            wdy = W * UPSTREAM_grad
            c1 += tl.sum(x_hat * wdy)
            c2 += tl.sum(wdy)

        # Need to ensure additions to the same channel are atomic
        # (programs from different batches write the same DW/DB slots)
        tl.atomic_add(DW_ptr + channel_idx, dW.to(dtype))
        tl.atomic_add(DB_ptr + channel_idx, dB.to(dtype))

    # normalize the reduction terms by the number of elements in the group
    N = hidden_size * channels_per_group
    c1 = c1 / N
    c2 = c2 / N

    # second pass: now that c1/c2 are known, compute dx per channel
    for channel_idx in tl.range(group_idx * channels_per_group, (group_idx + 1) * channels_per_group):
        # Move the pointers to the correct channel
        W = tl.load(W_ptr + channel_idx)
        for i in range(0, hidden_size, BLOCK_SIZE):
            hidden_size_offsets = i + block_range
            mask = hidden_size_offsets < hidden_size
            X = tl.load(
                X_ptr + channel_idx * X_col_stride + hidden_size_offsets,
                mask=mask,
                other=0.0,
            )
            UPSTREAM_grad = tl.load(
                UPSTREAM_ptr + channel_idx * X_col_stride + hidden_size_offsets,
                mask=mask,
                other=0.0,
            )

            x_hat = (X - mean) * rstd
            wdy = W * UPSTREAM_grad
            dx = (wdy - (x_hat * c1 + c2)) * rstd
            tl.store(DX_ptr + channel_idx * X_col_stride + hidden_size_offsets, dx, mask=mask)


def group_norm_forward(X, num_channels, num_groups, W, B, eps):
    """Launch the forward kernel over a (batch, group) grid.

    Returns (Y, X_viewed, Mean, RSTD, BLOCK_SIZE); Mean/RSTD are saved for
    backward.
    """
    shape = X.shape
    batch_size = shape[0]
    channels_per_group = num_channels // num_groups
    # Reshape X so that the mean and std are computed across the groups
    X = X.view(batch_size, num_groups, -1).contiguous()
    hidden_size = X.shape[-1]
    BLOCK_SIZE = min(MAX_FUSED_SIZE, triton.next_power_of_2(hidden_size))
    Y = torch.empty((batch_size, num_groups, hidden_size), dtype=X.dtype, device=X.device)
    Mean = torch.zeros((batch_size, num_groups), dtype=X.dtype, device=X.device)
    RSTD = torch.zeros((batch_size, num_groups), dtype=X.dtype, device=X.device)

    _group_norm_forward_kernel[(batch_size, num_groups)](
        Y,
        Y.stride(0),
        Y.stride(1),
        X,
        X.stride(0),
        X.stride(1),
        Mean,
        Mean.stride(0),
        Mean.stride(1),
        RSTD,
        RSTD.stride(0),
        RSTD.stride(1),
        W,
        B,
        hidden_size,
        channels_per_group,
        eps,
        BLOCK_SIZE=BLOCK_SIZE,
    )
    # Return tensors in the original shape
    return Y.view(*shape), X.view(*shape), Mean, RSTD, BLOCK_SIZE


def group_norm_backward(dY, X, W, B, Mean, RSTD, num_channels, num_groups):
    """Launch the backward kernel; returns (DX, DW, DB) in dY's original shape.

    NOTE(review): hidden_size here is dY.shape[-1] taken BEFORE the group view,
    i.e. the per-channel extent — this matches the backward kernel's per-channel
    iteration (the forward kernel's hidden_size is per-group instead).
    """
    shape = dY.shape
    batch_size = shape[0]
    hidden_size = dY.shape[-1]
    channels_per_group = num_channels // num_groups
    dY = dY.view(batch_size, num_groups, -1)
    DX = torch.empty(
        (batch_size, num_groups, hidden_size * channels_per_group),
        dtype=X.dtype,
        device=X.device,
    )
    # zero-init because the kernel accumulates into DW/DB with atomic adds
    DW = torch.zeros((num_channels), dtype=W.dtype, device=W.device)
    DB = torch.zeros((num_channels), dtype=B.dtype, device=B.device)
    triton_dtype = tl.float32 if X.dtype == torch.float32 else tl.bfloat16

    BLOCK_SIZE = min(MAX_FUSED_SIZE, triton.next_power_of_2(hidden_size))
    _group_norm_backward_kernel[(batch_size, num_groups)](
        X,
        X.stride(0),
        X.stride(1),
        W,
        Mean,
        Mean.stride(0),
        Mean.stride(1),
        RSTD,
        DX,
        DW,
        DB,
        dY,
        hidden_size,
        channels_per_group,
        BLOCK_SIZE=BLOCK_SIZE,
        dtype=triton_dtype,
    )

    # Return tensors in the original shape
    return DX.view(*shape), DW, DB


class LigerGroupNormFunction(torch.autograd.Function):
    # Autograd wrapper for the fused group-norm kernels.
    @staticmethod
    @ensure_contiguous
    def forward(
        ctx,
        X,
        affine_scaling_weight,
        affine_shifting_bias,
        num_channels,
        num_groups,
        eps,
    ):
        Y, X, Mean, RSTD, BLOCK_SIZE = group_norm_forward(
            X,
            num_channels,
            num_groups,
            affine_scaling_weight,
            affine_shifting_bias,
            eps,
        )
        ctx.num_channels = num_channels
        ctx.num_groups = num_groups
        ctx.save_for_backward(X, affine_scaling_weight, affine_shifting_bias, Mean, RSTD)
        return Y

    @staticmethod
    @ensure_contiguous
    def backward(ctx, dY):
        X, W, B, Mean, RSTD = ctx.saved_tensors
        DX, DW, DB = group_norm_backward(dY, X, W, B, Mean, RSTD, ctx.num_channels, ctx.num_groups)
        # None for the three non-tensor forward args (num_channels, num_groups, eps)
        return DX, DW, DB, None, None, None


# --- jsd.py ---
from typing import Optional

import torch
import triton
import triton.language as tl
from .utils import ensure_contiguous
from .utils import infer_device


@triton.jit
def _jsd_kernel(
    X_ptr,  # input in logspace, X = log Q
    X_stride,
    Y_ptr,  # ground truth in logspace, Y = log P
    Y_stride,
    loss_ptr,
    loss_stride,
    dX_ptr,
    dX_stride,
    label_ptr,
    beta: tl.constexpr,
    n_non_ignore: int,
    ignore_index: tl.constexpr,
    n_cols,
    BLOCK_SIZE: tl.constexpr,
    HAS_LABEL: tl.constexpr,
):
    """Per-row generalized JSD loss and gradient; one program per row."""
    # JSD(P || Q) = (KL(P || M) + KL(Q || M)) / 2, M = (1/2) * (P + Q) = (1/2) * (e ^ Y + e ^ X)
    #             = sum(P * log P + Q * log Q - 2 * M * log M) / 2
    #             = sum(e ^ Y * Y + e ^ X * X - 2 * M * log M) / 2
    # grad_x_i = 0.5 * Q * (X - log_M)
    pid = tl.program_id(0).to(tl.int64)
    X_ptr += pid * X_stride
    dX_ptr += pid * dX_stride
    Y_ptr += pid * Y_stride
    loss_ptr += pid * loss_stride
    label_ptr += pid

    if HAS_LABEL:
        label = tl.load(label_ptr)
        if label == ignore_index:
            # ignored rows: zero the gradient and contribute no loss
            for i in range(0, n_cols, BLOCK_SIZE):
                offsets = i + tl.arange(0, BLOCK_SIZE)
                tl.store(dX_ptr + offsets, 0.0, mask=offsets < n_cols)
            return

    for i in range(0, n_cols, BLOCK_SIZE):
        offsets = i + tl.arange(0, BLOCK_SIZE)
        mask = offsets < n_cols
        X = tl.load(X_ptr + offsets, mask=mask, other=float("-inf")).to(tl.float32)
        Y = tl.load(Y_ptr + offsets, mask=mask, other=float("-inf")).to(tl.float32)

        if beta == 0.0:  # forward KL
            Y_max = tl.max(Y, axis=0)
            Y_shifted = Y - Y_max
            Y_prob = tl.exp(Y_shifted) * tl.exp(Y_max)  # Compensate for the shift
            loss = Y_prob * (Y - X)
            dX = -Y_prob
        elif beta == 1.0:  # reverse KL
            X_max = tl.max(X, axis=0)
            X_shifted = X - X_max
            X_prob = tl.exp(X_shifted) * tl.exp(X_max)  # Compensate for the shift
            loss = X_prob * (X - Y)
            dX = loss + X_prob
        else:
            # general case: shift both by the shared max for numerical stability
            max_val = tl.maximum(tl.max(X, axis=0), tl.max(Y, axis=0))
            X_shifted = X - max_val
            Y_shifted = Y - max_val

            # Pre-compute exp(max_val) since it's used twice
            exp_max = tl.exp(max_val)

            # Compute exp terms with compensation
            Q = tl.exp(X_shifted) * exp_max  # = exp(X)
            P = tl.exp(Y_shifted) * exp_max  # = exp(Y)

            # Pre-compute common terms
            beta_P = beta * P
            one_minus_beta_Q = (1 - beta) * Q
            M = beta_P + one_minus_beta_Q
            log_M = tl.log(M)  # No need to compensate as M is already in original scale

            loss = beta_P * Y + one_minus_beta_Q * X - M * log_M
            dX = one_minus_beta_Q * (X - log_M)

        # Pre-compute scaling factor (mean over non-ignored rows)
        scale = 1.0 / n_non_ignore
        loss = loss * scale
        dX = dX * scale

        tl.store(loss_ptr + offsets, loss, mask=mask)
        tl.store(dX_ptr + offsets, dX, mask=mask)


MAX_FUSED_SIZE = 4096 if infer_device() == "xpu" else 65536


def jsd_forward(_input, target, shift_labels, beta, ignore_index, has_label):
    """Launch the JSD kernel; returns (summed loss, per-element gradient dX)."""
    BT, V = _input.shape
    n_rows = BT
    BLOCK_SIZE = min(MAX_FUSED_SIZE, triton.next_power_of_2(V))
    # non reduction loss
    loss = torch.zeros(_input.shape, dtype=torch.float32, device=_input.device)
    dX = torch.empty_like(_input)

    if has_label:
        n_non_ignore = (shift_labels != ignore_index).sum().item()
    else:
        n_non_ignore = BT

    _jsd_kernel[(n_rows,)](
        X_ptr=_input,  # input in logspace, X = log Q
        X_stride=_input.stride(-2),
        Y_ptr=target,  # ground truth in logspace, Y = log P
        Y_stride=target.stride(-2),
        loss_ptr=loss,
        loss_stride=loss.stride(-2),
        dX_ptr=dX,
        dX_stride=dX.stride(-2),
        label_ptr=(shift_labels if has_label else torch.empty(1, device=_input.device)),  # dummy ptr if no label
        beta=beta,
        n_non_ignore=n_non_ignore,
        ignore_index=ignore_index,
        n_cols=V,
        BLOCK_SIZE=BLOCK_SIZE,
        HAS_LABEL=has_label,
    )

    loss = torch.sum(loss)
    return loss.to(_input.dtype), dX


def jsd_backward(dX, grad_output):
    """Scale the gradient computed during forward by the upstream grad_output."""
    # If jsd is the last layer, grad_output is 1.0. Skip the mul to save time
    if torch.equal(grad_output, torch.tensor(1.0, device=grad_output.device)):
        return dX
    else:
        return grad_output * dX


class LigerJSDFunction(torch.autograd.Function):
    r"""
    This class implements the forward and backward pass for the generalized Jensen-Shannon Divergence.
    .. math::
        JSD(\beta)(P || Q)
            = \beta * KLDiv(P || (\beta * P + (1 - \beta) * Q)) + (1 - \beta) * KLDiv(Q || (\beta * P + (1 - \beta) * Q))

    .. note::
        As all the other losses in PyTorch, this function expects the first argument,
        :attr:`_input`, to be the predictions, the output of the student model, in log-space
        and the second, :attr:`target`, to be the observations, the output of the teacher model, in log-space.
        This differs from the standard mathematical notation :math:`JSD(P || Q)` where
        :math:`P` denotes the teacher model and :math:`Q` denotes the student model.
    """

    @staticmethod
    @ensure_contiguous
    def forward(
        ctx,
        _input: torch.Tensor,
        target: torch.Tensor,
        shift_labels: Optional[torch.Tensor] = None,
        beta: float = 0.5,
        ignore_index: int = -100,
    ) -> torch.Tensor:
        """
        Args:
            _input (torch.Tensor): predict values with shape (BT, V) in logspace
            target (torch.Tensor): ground truth values with shape (BT, V) in logspace
            shift_labels (Optional[torch.LongTensor]): indicator of next predicted vocab with shape (BT) where each value is in [0, V-1].
            beta (float): coefficient beta of generalized JSD in the interval [0, 1]. It implements forward/reverse KL when beta equals 0 and 1 respectively. Default: `0.5`
            ignore_index (int): the index to ignore. Default: -100

        Returns:
            loss (torch.Tensor): generalized JSD
        """
        has_label = False
        if shift_labels is not None:
            assert shift_labels.shape == (_input.shape[0],), (
                f"the shape of shift_labels must be (BT,).
Got: {shift_labels.shape}" + ) + shift_labels = shift_labels.contiguous() + has_label = True + + loss, dX = jsd_forward(_input, target, shift_labels, beta, ignore_index, has_label) + ctx.save_for_backward(dX) + return loss + + @staticmethod + @ensure_contiguous + def backward(ctx, grad_output: torch.Tensor) -> torch.Tensor: + (dX,) = ctx.saved_tensors + dX = jsd_backward(dX, grad_output) + return ( + dX, + None, + None, + None, + None, + ) \ No newline at end of file diff --git a/build/torch-rocm/kl_div.py b/build/torch-rocm/kl_div.py new file mode 100644 index 0000000000000000000000000000000000000000..2d563a7eae13aac671e3cf85283758b50a3fdb93 --- /dev/null +++ b/build/torch-rocm/kl_div.py @@ -0,0 +1,262 @@ +from typing import Literal + +import torch +import triton +import triton.language as tl + +from .utils import ensure_contiguous +from .utils import is_hip +from .utils import infer_device + + +def get_num_warps(BLOCK_SIZE): + num_warps = 4 + if BLOCK_SIZE >= 32768: + num_warps = 32 if not is_hip() else 16 + elif BLOCK_SIZE >= 8192: + num_warps = 16 + elif BLOCK_SIZE >= 2048: + num_warps = 8 + + return num_warps + + +MAX_FUSED_SIZE = 65536 // 4 # 65536 // 4 or 8 works the best + +REDUCTION_LITERAL = Literal["none", "sum", "mean", "batchmean"] + +_REDUCTION_MODE_NONE: tl.constexpr = tl.constexpr(0) +_REDUCTION_MODE_SUM: tl.constexpr = tl.constexpr(1) +_REDUCTION_MODE_MEAN: tl.constexpr = tl.constexpr(2) +_REDUCTION_MODE_BATCHMEAN: tl.constexpr = tl.constexpr(3) + +_str_to_reduction_mode = { + "none": _REDUCTION_MODE_NONE.value, + "sum": _REDUCTION_MODE_SUM.value, + "mean": _REDUCTION_MODE_MEAN.value, + "batchmean": _REDUCTION_MODE_BATCHMEAN.value, +} + + +@triton.jit +def _kldiv_kernel_forward( + y_ptr, # [B, S], prediction ptr, the kernel expects the prediction in log-space + y_stride, # int, prediction stride + gt_ptr, # [B, S], ground truth ptr + gt_stride, # int, ground truth stride + loss_ptr, # [B] or [B, S] if reduction == _REDUCTION_MODE_NONE, output ptr + 
loss_stride, # int, output stride + n_cols, # int, number of columns in the input tensor + eps, + BLOCK_SIZE: tl.constexpr, + log_target: tl.constexpr = False, + reduction: tl.constexpr = _REDUCTION_MODE_BATCHMEAN, +): + pid = tl.program_id(0).to(tl.int64) + y_ptr += pid * y_stride + gt_ptr += pid * gt_stride + loss_ptr += pid * loss_stride + + base_offsets = tl.arange(0, BLOCK_SIZE) + + loss_sum = 0.0 + for i in range(0, n_cols, BLOCK_SIZE): + offsets = i + base_offsets + mask = offsets < n_cols + y = tl.load(y_ptr + offsets, mask=mask, other=0.0) + y_true = tl.load(gt_ptr + offsets, mask=mask, other=0.0) + + # KL(y_true || y) = y_true * (log(y_true) - log(y)) + # We compute KL(y_true || y) with y in the log-space + if not log_target: + loss = y_true * (tl.log(tl.maximum(y_true, eps)) - y) + else: + loss = tl.exp(y_true) * (y_true - y) + + if reduction == _REDUCTION_MODE_NONE: + tl.store(loss_ptr + offsets, loss, mask=mask) + else: + loss_sum += tl.sum(loss, axis=0) + + if reduction != _REDUCTION_MODE_NONE: + tl.store(loss_ptr, loss_sum) + + +@triton.jit +def _kldiv_kernel_backward( + target_ptr, + target_stride, + new_grads_ptr, + new_grads_stride, + n_cols, + BLOCK_SIZE: tl.constexpr, + log_target: tl.constexpr = False, +): + pid = tl.program_id(0).to(tl.int64) + + target_ptr += pid * target_stride + new_grads_ptr += pid * new_grads_stride + + offsets = tl.arange(0, BLOCK_SIZE) + mask = offsets < n_cols + + for i in range(0, n_cols, BLOCK_SIZE): + offsets = i + tl.arange(0, BLOCK_SIZE) + mask = offsets < n_cols + + target = tl.load(target_ptr + offsets, mask=mask, other=0.0) + + if not log_target: + res = target * -1 + else: + res = -tl.exp(target) + + tl.store(new_grads_ptr + offsets, res, mask=mask) + + +def kldiv_forward_triton(y_pred, y_true, log_target, reduction, eps): # [BT, V] + BT, V = y_pred.shape + BLOCK_SIZE = ( + min(8192, triton.next_power_of_2(V)) + if infer_device() == "xpu" + else min(MAX_FUSED_SIZE, triton.next_power_of_2(V)) + ) + num_warps = 
32 if infer_device() == "xpu" else get_num_warps(BLOCK_SIZE) + + grid = (BT,) + reduction = _str_to_reduction_mode[reduction] + + out_size = (BT, V) if reduction == _REDUCTION_MODE_NONE.value else (BT,) + output_tensor = torch.zeros(out_size, device=y_pred.device, dtype=torch.float32) + + _kldiv_kernel_forward[grid]( + y_pred, + y_pred.stride(0), + y_true, + y_true.stride(0), + output_tensor, + output_tensor.stride(0), + V, + eps=eps, + BLOCK_SIZE=BLOCK_SIZE, + num_warps=num_warps, + log_target=log_target, + reduction=reduction, + ) + + # calculated according to the reduction mode same as in Pytorch. In the later versions, `mean` will be changed to the same behavior as `batchmean` + # https://pytorch.org/docs/stable/generated/torch.nn.KLDivLoss.html + # https://github.com/pytorch/pytorch/blob/d7b57c4d63edb42e1deeeba9497fcb5f1f748ff2/torch/nn/functional.py#L3372 + if reduction == _REDUCTION_MODE_BATCHMEAN.value: + return output_tensor.sum() / BT + elif reduction == _REDUCTION_MODE_SUM.value: + return output_tensor.sum(dim=0) + elif reduction == _REDUCTION_MODE_MEAN.value: + return output_tensor.sum() / (BT * V) + else: + return output_tensor + + +def kldiv_backward_triton(target, grad_output, new_grads, log_target): + BT, V = target.shape + BLOCK_SIZE = ( + min(8192, triton.next_power_of_2(V)) + if infer_device() == "xpu" + else min(MAX_FUSED_SIZE, triton.next_power_of_2(V)) + ) + num_warps = 32 if infer_device() == "xpu" else get_num_warps(BLOCK_SIZE) + + grid = (BT,) + + # We store the gradients in-place in the input tensor + _kldiv_kernel_backward[grid]( + target, + target.stride(0), + new_grads, + new_grads.stride(0), + V, + BLOCK_SIZE=BLOCK_SIZE, + num_warps=num_warps, + log_target=log_target, + ) + + # If cross entropy is the last layer, grad_output is 1.0. Skip the mul then. 
+ if torch.equal(grad_output, torch.tensor(1.0, device=grad_output.device)): + return new_grads + + return new_grads * grad_output + + +class LigerKLDivLossFunction(torch.autograd.Function): + """ + Class implementing the forward and backward pass for the KL Divergence Loss using Triton, as defined by the following formula: + ```python + if log_target: + loss = target.exp() * (target - input) + else: + loss = target * (target.log() - input) + ```, + then the loss is reduced according to the `reduction` parameter. + as defined in the PyTorch documentation: https://pytorch.org/docs/stable/generated/torch.nn.KLDivLoss.html + """ + + @staticmethod + @ensure_contiguous + def forward( + ctx, + y_pred: torch.Tensor, + y_true: torch.Tensor, + reduction: REDUCTION_LITERAL = "batchmean", + log_target: bool = False, + eps: float = 1e-10, + ) -> torch.Tensor: + """A forward pass for the KL Divergence Loss. + + Args: + ctx: Torch autograd context + y_pred (torch.Tensor): A tensor of shape (BT, V) containing the predicted values, expected to be log-probabilities. + y_true (torch.Tensor): A tensor of shape (BT, V) containing the target values, expected to be either probabilities or log-probabilities, depending on the value of `log_target`. + reduction (REDUCTION_LITERAL, optional): Reduction to be used. Defaults to "batchmean". + log_target (bool, optional): If set to true, expects the ground truth to already be log-probabilities. Defaults to False. + eps: (float, optional): A small value to avoid division by zero. Defaults to 1e-10. + + Returns: + torch.Tensor: The computed KL Divergence Loss, with shape (BT, V) if `reduction` is "none", else a scalar. 
+ """ + ctx.save_for_backward(y_true) + ctx.reduction = reduction + ctx.log_target = log_target + return kldiv_forward_triton(y_pred, y_true, log_target=log_target, reduction=reduction, eps=eps) + + @staticmethod + @ensure_contiguous + def backward(ctx, grad_output: torch.Tensor) -> torch.Tensor: + """A backward pass for the KL Divergence Loss. + + Args: + ctx: Torch autograd context + grad_output (torch.Tensor): The gradient of the loss with respect to the output. + + Returns: + tuple[torch.Tensor, None, None, None, None]: The gradient of the loss with respect to the inputs and None for the other arguments of the forward method. + """ + (y_true,) = ctx.saved_tensors + + new_grads = torch.empty_like(y_true) + + derivative = kldiv_backward_triton(y_true, grad_output, new_grads, ctx.log_target) + + if ctx.reduction == "batchmean": + derivative = derivative / y_true.shape[0] + elif ctx.reduction == "sum" or ctx.reduction == "none": + pass + elif ctx.reduction == "mean": + derivative = derivative / (y_true.shape[0] * y_true.shape[1]) + + return ( + derivative, + None, + None, + None, + None, + ) \ No newline at end of file diff --git a/build/torch-rocm/layer_norm.py b/build/torch-rocm/layer_norm.py new file mode 100644 index 0000000000000000000000000000000000000000..00088223b1c58a5483c4227a2dc6be7f26cc1c3c --- /dev/null +++ b/build/torch-rocm/layer_norm.py @@ -0,0 +1,265 @@ +import math +import operator + +import torch +import triton +import triton.language as tl + +from .utils import calculate_settings +from .utils import compare_version +from .utils import ensure_contiguous + +if compare_version("triton", operator.ge, "3.0.0"): + try: + # typical import path with dispatch available + from triton.language.extra.libdevice import rsqrt + except ModuleNotFoundError: + # for working with NGC containers + from triton.language.extra.cuda.libdevice import rsqrt +else: + from triton.language.math import rsqrt + + +@triton.jit +def _layer_norm_forward_kernel( + Y_ptr, # 
pointer to output, shape (n_rows, n_cols) + Y_row_stride, # stride of each row in output + X_ptr, # pointer to input, shape (n_rows, n_cols) + X_row_stride, # stride of each row in input + W_ptr, # pointer to weights, shape (n_cols,) + W_row_stride, # stride of each row in weights + B_ptr, # pointer to bias, shape (n_cols,) + B_row_stride, # stride of each row in bias + Mean_ptr, # pointer to mean, shape (n_rows,) + Mean_row_stride, # stride of each row in mean + RSTD_ptr, # pointer to rstd, shape (n_rows,) + RSTD_row_stride, # stride of each row in rstd + n_cols, + eps, + BLOCK_SIZE: tl.constexpr, +): + """ + References: + https://arxiv.org/abs/1607.06450 + https://github.com/karpathy/llm.c/blob/master/doc/layernorm/layernorm.md + """ + row_idx = tl.program_id(0) + col_offsets = tl.arange(0, BLOCK_SIZE) + mask = col_offsets < n_cols + + Y_ptr += row_idx * Y_row_stride + X_ptr += row_idx * X_row_stride + Mean_ptr += row_idx * Mean_row_stride + RSTD_ptr += row_idx * RSTD_row_stride + + X_row = tl.load(X_ptr + col_offsets, mask=mask, other=0) + W_row = tl.load(W_ptr + col_offsets, mask=mask, other=0) + B_row = tl.load(B_ptr + col_offsets, mask=mask, other=0) + + mean = tl.sum(X_row, axis=0) / n_cols + Xmm = tl.where(mask, X_row - mean, 0) + var = tl.sum(Xmm * Xmm, axis=0) / n_cols + rstd = rsqrt(var + eps) + + tl.store(Mean_ptr, mean) + tl.store(RSTD_ptr, rstd) + + Y_row = Xmm * rstd * W_row + B_row + + tl.store(Y_ptr + col_offsets, Y_row, mask=mask) + + +@triton.jit +def _layer_norm_backward_kernel( + X_ptr, # pointer to input, shape (n_rows, n_cols) + W_ptr, # pointer to weights, shape (n_cols,) + Mean_ptr, # pointer to mean, shape (n_rows,) + RSTD_ptr, # pointer to rstd, shape (n_rows,) + DX_ptr, # pointer to input grad, shape (n_rows, n_cols) + DW_ptr, # pointer to weights grad, shape (n_cols,) + DB_ptr, # pointer to bias grad, shape (n_cols,) + DY_ptr, # pointer to output grad, shape (n_rows, n_cols) + stride_x, # stride of each row in input + stride_dx, # 
stride of each row in input grad + stride_dw, # stride of each row in weights grad + stride_db, # stride of each row in bias grad + stride_dy, # stride of each row in output grad + n_rows, + n_cols, + rows_per_program: tl.constexpr, + BLOCK_SIZE: tl.constexpr, + dtype: tl.constexpr, +): + """ + References: + https://arxiv.org/abs/1607.06450 + https://github.com/karpathy/llm.c/blob/master/doc/layernorm/layernorm.md + https://triton-lang.org/main/getting-started/tutorials/05-layer-norm.html + https://github.com/Dao-AILab/flash-attention/blob/main/flash_attn/ops/triton/layer_norm.py + """ + row_block_id = tl.program_id(0) + row_start = row_block_id * rows_per_program + row_end = min((row_block_id + 1) * rows_per_program, n_rows) + cols = tl.arange(0, BLOCK_SIZE) + mask = cols < n_cols + + dw_row = tl.zeros((BLOCK_SIZE,), dtype=tl.float32) + db_row = tl.zeros((BLOCK_SIZE,), dtype=tl.float32) + + X_ptr += row_start * stride_x + Mean_ptr += row_start + RSTD_ptr += row_start + DX_ptr += row_start * stride_dx + DY_ptr += row_start * stride_dy + + for _ in range(row_start, row_end): + x = tl.load(X_ptr + cols, mask=mask, other=0.0) + w = tl.load(W_ptr + cols, mask=mask, other=0.0) + dy = tl.load(DY_ptr + cols, mask=mask, other=0.0) + mean = tl.load(Mean_ptr) + rstd = tl.load(RSTD_ptr) + + x_hat = (x - mean) * rstd + wdy = w * dy + c1 = tl.sum(x_hat * wdy, axis=0) / n_cols + c2 = tl.sum(wdy, axis=0) / n_cols + dx = (wdy - (x_hat * c1 + c2)) * rstd + tl.store(DX_ptr + cols, dx.to(dtype), mask=mask) + + dw_row += dy * x_hat + db_row += dy + + X_ptr += stride_x + Mean_ptr += 1 + RSTD_ptr += 1 + DX_ptr += stride_dx + DY_ptr += stride_dy + + tl.store(DW_ptr + row_block_id * stride_dw + cols, dw_row.to(dtype), mask=mask) + tl.store(DB_ptr + row_block_id * stride_db + cols, db_row.to(dtype), mask=mask) + + +def layer_norm_forward(X, W, B, eps): + shape = X.shape + dim = shape[-1] + X = X.view(-1, dim) + n_rows, n_cols = X.shape + BLOCK_SIZE, num_warps = calculate_settings(n_cols) + 
Y = torch.empty((n_rows, n_cols), dtype=X.dtype, device=X.device) + Mean = torch.empty(n_rows, dtype=X.dtype, device=X.device) + RSTD = torch.empty(n_rows, dtype=X.dtype, device=X.device) + if X.shape[1] != W.shape[0]: + raise ValueError( + f"Incompatible dimensions: input feature size (X.shape[1]={X.shape[1]}) " + f"must match weight size (W.shape[0]={W.shape[0]})" + ) + + # XPU-specific optimization + kernel_args = {} + if X.device.type == "xpu": + kernel_args["grf_mode"] = "large" + + _layer_norm_forward_kernel[(n_rows,)]( + Y, + Y.stride(0), + X, + X.stride(0), + W, + W.stride(0), + B, + B.stride(0), + Mean, + Mean.stride(0), + RSTD, + RSTD.stride(0), + n_cols, + eps, + BLOCK_SIZE=BLOCK_SIZE, + num_warps=num_warps, + **kernel_args, # XPU-specific optimization + ) + return Y.view(*shape), X, Mean, RSTD, BLOCK_SIZE, num_warps + + +def layer_norm_backward(dY, X, W, B, Mean, RSTD): + shape = dY.shape + dim = shape[-1] + dY = dY.view(-1, dim) + n_rows, n_cols = dY.shape + + sm_count = 1 + if X.device.type == "cuda": + sm_count = torch.cuda.get_device_properties(X.device).multi_processor_count + elif X.device.type == "xpu": + sm_count = torch.xpu.get_device_properties(X.device).gpu_eu_count + + DX = torch.empty((n_rows, n_cols), dtype=X.dtype, device=X.device) + _DW = torch.empty((sm_count, n_cols), dtype=W.dtype, device=W.device) + _DB = torch.empty((sm_count, n_cols), dtype=W.dtype, device=W.device) + + BLOCK_SIZE, num_warps = calculate_settings(n_cols) + if n_cols > BLOCK_SIZE: + raise RuntimeError( + f"Feature dimension {n_cols} exceeds maximum supported size of {BLOCK_SIZE}. Consider using a smaller feature dimension." 
+ ) + + rows_per_program = math.ceil(n_rows / sm_count) + grid = (sm_count,) + triton_dtype = ( + tl.float32 + if X.dtype == torch.float32 + else tl.bfloat16 + if X.dtype == torch.bfloat16 + else tl.float16 + if X.dtype == torch.float16 + else tl.float32 # fallback to float32 for other types + ) + + # XPU-specific optimization + kernel_args = {} + if X.device.type == "xpu": + kernel_args.update({"grf_mode": "large", "num_warps": 32, "num_stages": 4}) + + _layer_norm_backward_kernel[grid]( + X, + W, + Mean, + RSTD, + DX, + _DW, + _DB, + dY, + X.stride(0), + DX.stride(0), + _DW.stride(0), + _DB.stride(0), + dY.stride(0), + n_rows, + n_cols, + rows_per_program, + BLOCK_SIZE=BLOCK_SIZE, + dtype=triton_dtype, + **kernel_args, # XPU-specific optimization + ) + + DW = _DW.sum(dim=0).to(W.dtype) + DB = _DB.sum(dim=0).to(W.dtype) + + DX = DX.view(*shape) + return DX, DW, DB + + +class LigerLayerNormFunction(torch.autograd.Function): + @staticmethod + @ensure_contiguous + def forward(ctx, X, W, B, eps): + Y, X, Mean, RSTD, BLOCK_SIZE, num_warps = layer_norm_forward(X, W, B, eps) + ctx.save_for_backward(X, W, B, Mean, RSTD) + return Y + + @staticmethod + @ensure_contiguous + def backward(ctx, dY): + X, W, B, Mean, RSTD = ctx.saved_tensors + DX, DW, DB = layer_norm_backward(dY, X, W, B, Mean, RSTD) + return DX, DW, DB, None \ No newline at end of file diff --git a/build/torch-rocm/layers.py b/build/torch-rocm/layers.py new file mode 100644 index 0000000000000000000000000000000000000000..798de28de6bcb77a98416812dcae89cb9e5014df --- /dev/null +++ b/build/torch-rocm/layers.py @@ -0,0 +1,39 @@ +import torch +from .rms_norm import LigerRMSNormFunction + +class LigerRMSNorm(torch.nn.Module): + """ + RMSNorm module that uses the optimized LigerRMSNormFunction. + + Args: + hidden_size (int): The size of the hidden dimension. + eps (float, optional): The epsilon value for numerical stability. Defaults to 1e-6. + offset (float, optional): Offset value to shift the weight tensor. 
Defaults to 0.0. + casting_mode (str, optional): The casting mode to use. Defaults to "llama". + in_place (bool, optional): Whether to modify dY in-place to store dX during backward. Defaults to True. + """ + + + weight: torch.Tensor + variance_epsilon: float + + def forward(self, hidden_states): + """ + Apply RMS normalization to the input tensor. + + Args: + hidden_states (torch.Tensor): Input tensor of shape (B, T, H) or (BxT, H) + + Returns: + torch.Tensor: Normalized tensor of the same shape as input + """ + return LigerRMSNormFunction.apply( + hidden_states, + self.weight, + self.variance_epsilon, + 0, + "llama", + True + ) + +__all__ = ["LigerRMSNorm"] \ No newline at end of file diff --git a/build/torch-rocm/liger_kernels/__init__.py b/build/torch-rocm/liger_kernels/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..03dbc1afe1cf156661a2b1b22003cd5f599a0309 --- /dev/null +++ b/build/torch-rocm/liger_kernels/__init__.py @@ -0,0 +1,26 @@ +import ctypes +import sys + +import importlib +from pathlib import Path +from types import ModuleType + +def _import_from_path(file_path: Path) -> ModuleType: + # We cannot use the module name as-is, after adding it to `sys.modules`, + # it would also be used for other imports. So, we make a module name that + # depends on the path for it to be unique using the hex-encoded hash of + # the path. 
+ path_hash = "{:x}".format(ctypes.c_size_t(hash(file_path.absolute())).value) + module_name = path_hash + spec = importlib.util.spec_from_file_location(module_name, file_path) + if spec is None: + raise ImportError(f"Cannot load spec for {module_name} from {file_path}") + module = importlib.util.module_from_spec(spec) + if module is None: + raise ImportError(f"Cannot load module {module_name} from spec") + sys.modules[module_name] = module + spec.loader.exec_module(module) # type: ignore + return module + + +globals().update(vars(_import_from_path(Path(__file__).parent.parent / "__init__.py"))) diff --git a/build/torch-rocm/metadata.json b/build/torch-rocm/metadata.json new file mode 100644 index 0000000000000000000000000000000000000000..9cf5deed9898dce769f4cc73913d3530b92a0bd8 --- /dev/null +++ b/build/torch-rocm/metadata.json @@ -0,0 +1,4 @@ +{ + "version": 1, + "python-depends": [] +} \ No newline at end of file diff --git a/build/torch-rocm/qwen2vl_mrope.py b/build/torch-rocm/qwen2vl_mrope.py new file mode 100644 index 0000000000000000000000000000000000000000..33d02e7bcfb95402ccd3c1b0ad0ed8780ef422e4 --- /dev/null +++ b/build/torch-rocm/qwen2vl_mrope.py @@ -0,0 +1,222 @@ +import torch +import triton +import triton.language as tl + + +@triton.jit +def _triton_qwen2vl_mrope( + q_ptr, + k_ptr, + cos, + sin, + sl, + bs: tl.constexpr, + n_qh: tl.constexpr, + n_kh: tl.constexpr, + hd: tl.constexpr, + pad_n_qh: tl.constexpr, + pad_n_kh: tl.constexpr, + pad_hd: tl.constexpr, + mrope_section_t: tl.constexpr, + mrope_section_h: tl.constexpr, + BLOCK_SIZE: tl.constexpr, + BACKWARD_PASS: tl.constexpr = False, +): + pid = tl.program_id(0) + + # locate start address + q_ptr = q_ptr + pid * (n_qh * hd) + k_ptr = k_ptr + pid * (n_kh * hd) + + # #################################################################### + # get the cos(mθ_{i...d/2}) and sin(mθ_{i...d/2}) for token position + # m of this program instance + # 
#################################################################### + + # 1. program instances are laid out in a 1D vector of size bsz * seq_len, which + # effectively represents a 2D grid of size [bsz, seq_len] with seq_len dimension + # being the fastest changing dimension. Thus we can simply do pid // sl to get the batch index + # and pid % sl to get the sequence index. + # 2. We only need the left half of cos and sin matrix because the right half is just + # a clone of the left half. + t_end = mrope_section_t + h_end = t_end + mrope_section_h + + t_cos = cos + pid * hd + h_cos = t_cos + bs * sl * hd + w_cos = h_cos + bs * sl * hd + t_sin = sin + pid * hd + h_sin = t_sin + bs * sl * hd + w_sin = h_sin + bs * sl * hd + + cos_offsets = tl.arange(0, pad_hd // 2) + t_mask = cos_offsets < t_end + h_mask = (t_end <= cos_offsets) & (cos_offsets < h_end) + w_mask = (h_end <= cos_offsets) & (cos_offsets < hd // 2) + t_cos_row = tl.load(t_cos + cos_offsets, mask=t_mask, other=0) + h_cos_row = tl.load(h_cos + cos_offsets, mask=h_mask, other=0) + w_cos_row = tl.load(w_cos + cos_offsets, mask=w_mask, other=0) + t_sin_row = tl.load(t_sin + cos_offsets, mask=t_mask, other=0) + h_sin_row = tl.load(h_sin + cos_offsets, mask=h_mask, other=0) + w_sin_row = tl.load(w_sin + cos_offsets, mask=w_mask, other=0) + cos_row = t_cos_row + h_cos_row + w_cos_row + sin_row = t_sin_row + h_sin_row + w_sin_row + + # #################################################################### + # Load the left and right half of q and k for the current + # program instance (i.e. 
for the current token) separately + # #################################################################### + # left half of the head + first_half_q_offsets = tl.arange(0, pad_n_qh)[:, None] * hd + tl.arange(0, pad_hd // 2)[None, :] + first_half_k_offsets = tl.arange(0, pad_n_kh)[:, None] * hd + tl.arange(0, pad_hd // 2)[None, :] + first_q_mask = (tl.arange(0, pad_n_qh)[:, None] < n_qh) & (tl.arange(0, pad_hd // 2)[None, :] < hd // 2) + first_k_mask = (tl.arange(0, pad_n_kh)[:, None] < n_kh) & (tl.arange(0, pad_hd // 2)[None, :] < hd // 2) + q_tile_1 = tl.load(q_ptr + first_half_q_offsets, mask=first_q_mask, other=0).to(sin_row.dtype) + k_tile_1 = tl.load(k_ptr + first_half_k_offsets, mask=first_k_mask, other=0).to(sin_row.dtype) + + # right half of the head + second_half_q_offsets = first_half_q_offsets + (hd // 2) + second_half_k_offsets = first_half_k_offsets + (hd // 2) + second_q_mask = first_q_mask + second_k_mask = first_k_mask + q_tile_2 = tl.load(q_ptr + second_half_q_offsets, mask=second_q_mask, other=0).to(sin_row.dtype) + k_tile_2 = tl.load(k_ptr + second_half_k_offsets, mask=second_k_mask, other=0).to(sin_row.dtype) + + if not BACKWARD_PASS: + # y = [x1, x2] * [cos, cos] + [-x2, x1] * [sin, sin] + new_q_tile_1 = q_tile_1 * cos_row - q_tile_2 * sin_row + tl.store(q_ptr + first_half_q_offsets, new_q_tile_1, mask=first_q_mask) + new_q_tile_2 = q_tile_2 * cos_row + q_tile_1 * sin_row + tl.store(q_ptr + second_half_q_offsets, new_q_tile_2, mask=second_q_mask) + + new_k_tile_1 = k_tile_1 * cos_row - k_tile_2 * sin_row + tl.store(k_ptr + first_half_k_offsets, new_k_tile_1, mask=first_k_mask) + new_k_tile_2 = k_tile_2 * cos_row + k_tile_1 * sin_row + tl.store(k_ptr + second_half_k_offsets, new_k_tile_2, mask=second_k_mask) + else: + # with some math, we can get: + # dy = [dx1, dx2] * [cos, cos] + [-dx2, dx1] * [-sin, -sin] + new_q_tile_1 = q_tile_1 * cos_row + q_tile_2 * sin_row + tl.store(q_ptr + first_half_q_offsets, new_q_tile_1, mask=first_q_mask) + 
new_q_tile_2 = q_tile_2 * cos_row - q_tile_1 * sin_row + tl.store(q_ptr + second_half_q_offsets, new_q_tile_2, mask=second_q_mask) + + new_k_tile_1 = k_tile_1 * cos_row + k_tile_2 * sin_row + tl.store(k_ptr + first_half_k_offsets, new_k_tile_1, mask=first_k_mask) + new_k_tile_2 = k_tile_2 * cos_row - k_tile_1 * sin_row + tl.store(k_ptr + second_half_k_offsets, new_k_tile_2, mask=second_k_mask) + + +def qwen2vl_mrope_forward(q, k, cos, sin, mrope_section): + # transpose it back to the physical shape because Triton looks at the physical storage + # note: q and k are incontiguous before the transformation and will become contiguous after transpose + q = q.transpose(1, 2) + k = k.transpose(1, 2) + + batch_size, seq_len, n_q_head, head_dim = q.shape + n_kv_head = k.shape[2] + pad_hd = triton.next_power_of_2(head_dim) + pad_n_q_head = triton.next_power_of_2(n_q_head) + pad_n_kv_head = triton.next_power_of_2(n_kv_head) + BLOCK_SIZE = max(pad_n_q_head, pad_n_kv_head) + + n_row = batch_size * seq_len + + # ensure tensors passed into the kernel are contiguous. 
It will be no-op if they are already contiguous + q = q.contiguous() + k = k.contiguous() + cos = cos.contiguous() + sin = sin.contiguous() + + _triton_qwen2vl_mrope[(n_row,)]( + q, + k, + cos, + sin, + seq_len, + batch_size, + n_q_head, + n_kv_head, + head_dim, + pad_n_q_head, + pad_n_kv_head, + pad_hd, + mrope_section[0], + mrope_section[1], + BLOCK_SIZE=BLOCK_SIZE, + BACKWARD_PASS=False, + ) + return q.transpose(1, 2), k.transpose(1, 2), cos, sin + + +def qwen2vl_mrope_backward(dq, dk, cos, sin, mrope_section): + dq = dq.transpose(1, 2) + dk = dk.transpose(1, 2) + + batch_size, seq_len, n_q_head, head_dim = dq.shape + n_kv_head = dk.shape[2] + pad_hd = triton.next_power_of_2(head_dim) + pad_n_q_head = triton.next_power_of_2(n_q_head) + pad_n_kv_head = triton.next_power_of_2(n_kv_head) + BLOCK_SIZE = max(pad_n_q_head, pad_n_kv_head) + + n_row = batch_size * seq_len + + # ensure dq and dk are contiguous + dq = dq.contiguous() + dk = dk.contiguous() + + # backward is similar to forward except swapping few ops + _triton_qwen2vl_mrope[(n_row,)]( + dq, + dk, + cos, + sin, + seq_len, + batch_size, + n_q_head, + n_kv_head, + head_dim, + pad_n_q_head, + pad_n_kv_head, + pad_hd, + mrope_section[0], + mrope_section[1], + BLOCK_SIZE=BLOCK_SIZE, + BACKWARD_PASS=True, + ) + return dq.transpose(1, 2), dk.transpose(1, 2) + + +class LigerQwen2VLMRopeFunction(torch.autograd.Function): + """ + Triton implementation of the Qwen2VL Multimodal Rotary Positional Embedding (M-RoPE) operation. 
+ + Please find the corresponding HuggingFace implementation here: + https://github.com/huggingface/transformers/blob/main/src/transformers/models/qwen2_vl/modeling_qwen2_vl.py + """ + + @staticmethod + def forward(ctx, q, k, cos, sin, mrope_section, unsqueeze_dim=1): + """ + q size: (bsz, n_q_head, seq_len, head_dim) + k size: (bsz, n_kv_head, seq_len, head_dim) + cos size: (3, bsz, seq_len, head_dim) + sin size: (3, bsz, seq_len, head_dim) + """ + q, k, cos, sin = qwen2vl_mrope_forward(q, k, cos, sin, mrope_section) + ctx.save_for_backward(cos, sin) + ctx.mrope_section = mrope_section + return q, k + + def backward(ctx, dq, dk): + """ + dq size: (bsz, n_q_head, seq_len, head_dim) + dk size: (bsz, n_kv_head, seq_len, head_dim) + cos size: (3, bsz, seq_len, head_dim) + sin size: (3, bsz, seq_len, head_dim) + """ + cos, sin = ctx.saved_tensors + mrope_section = ctx.mrope_section + dq, dk = qwen2vl_mrope_backward(dq, dk, cos, sin, mrope_section) + return dq, dk, None, None, None, None \ No newline at end of file diff --git a/build/torch-rocm/rms_norm.py b/build/torch-rocm/rms_norm.py new file mode 100644 index 0000000000000000000000000000000000000000..bbf908671842d9e002774416c6a9ccc55acd2684 --- /dev/null +++ b/build/torch-rocm/rms_norm.py @@ -0,0 +1,365 @@ +""" +This file incorporates code from Unsloth licensed under the Apache License, Version 2.0. +See the original Unsloth repository at https://github.com/unslothai/unsloth. + +The following line +https://github.com/linkedin/Liger-Kernel/blob/7382a8761f9af679482b968f9348013d933947c7/src/liger_kernel/ops/rms_norm.py#L30 +is based on code from Unsloth, located at: +https://github.com/unslothai/unsloth/blob/fd753fed99ed5f10ef8a9b7139588d9de9ddecfb/unsloth/kernels/rms_layernorm.py#L22 + +Modifications made by Yanning Chen, 2024. 
+""" + +import math +import operator + +import torch +import triton +import triton.language as tl + +from .utils import calculate_settings +from .utils import compare_version +from .utils import ensure_contiguous +from .utils import torch_to_triton_dtype + +if compare_version("triton", operator.ge, "3.0.0"): + try: + # typical import path with dispatch available + from triton.language.extra.libdevice import rsqrt + except ModuleNotFoundError: + # for working with NGC containers + from triton.language.extra.cuda.libdevice import rsqrt +else: + from triton.language.math import rsqrt + + +_CASTING_MODE_NONE: tl.constexpr = tl.constexpr(-1) +_CASTING_MODE_LLAMA: tl.constexpr = tl.constexpr(0) +_CASTING_MODE_GEMMA: tl.constexpr = tl.constexpr(1) + + +@triton.jit +def _rms_norm_forward_kernel( + Y_ptr, + Y_row_stride, + X_ptr, + X_row_stride, + W_ptr, + W_row_stride, + RSTD_ptr, + RSTD_row_stride, + n_cols, + eps, + offset, + casting_mode: tl.constexpr, # constexpr so the `if` blocks can be optimized out + BLOCK_SIZE: tl.constexpr, +): + """ + y_i = (x_i / (RMS)) * (offset + wi), RMS = sqrt(sum(x_i^2) / N) + + Reference: + 1. https://triton-lang.org/main/getting-started/tutorials/05-layer-norm.html + 2. https://github.com/unslothai/unsloth/blob/fd753fed99ed5f10ef8a9b7139588d9de9ddecfb/unsloth/kernels/rms_layernorm.py#L22 + 3. 
https://arxiv.org/pdf/1910.07467 + """ + + row_idx = tl.program_id(0) + col_offsets = tl.arange(0, BLOCK_SIZE) + mask = col_offsets < n_cols + + Y_ptr += row_idx * Y_row_stride + X_ptr += row_idx * X_row_stride + RSTD_ptr += row_idx * RSTD_row_stride + + X_row = tl.load(X_ptr + col_offsets, mask=mask, other=0) + X_row_dtype = X_row.dtype + W_row = tl.load(W_ptr + col_offsets, mask=mask, other=0) + + # On Llama, only rstd is computed on fp32 + if casting_mode == _CASTING_MODE_LLAMA: + X_row = X_row.to(tl.float32) + + # Gemma computes everything on fp32, and then casts back the output to the original dtype + if casting_mode == _CASTING_MODE_GEMMA: + W_row = W_row.to(tl.float32) + X_row = X_row.to(tl.float32) + + if casting_mode == _CASTING_MODE_NONE: + eps = eps.to(X_row_dtype) + offset = offset.to(X_row_dtype) + + mean_square = tl.sum(X_row * X_row, axis=0) / n_cols + rstd = rsqrt(mean_square + eps) + + # We can save time by caching rms with minimal memory overhead + # because rms is much smaller compared to X_row, as rms is for each row. + # However, on the computation side, it can save 4 operations (*, sum, /, sqrt). + tl.store(RSTD_ptr, rstd) + + X_row = X_row * rstd + + # On Llama, the multiplication with the weight is done on the original dtype + if casting_mode == _CASTING_MODE_LLAMA: + X_row = X_row.to(X_row_dtype) + + Y_row = X_row * (offset + W_row) + + if casting_mode == _CASTING_MODE_GEMMA: + Y_row = Y_row.to(X_row_dtype) + + tl.store(Y_ptr + col_offsets, Y_row, mask=mask) + + +@triton.jit +def _rms_norm_backward_kernel( + dY_ptr, + dY_row_stride, + dX_ptr, + dX_row_stride, + X_ptr, + X_row_stride, + X_dtype: tl.constexpr, + W_ptr, + W_row_stride, + RSTD_ptr, + RSTD_row_stride, + dW_ptr, + dW_row_stride, + n_rows, + n_cols, + offset, + rows_per_program: tl.constexpr, + casting_mode: tl.constexpr, + BLOCK_SIZE: tl.constexpr, +): + """ + dx = (1 / RMS) * [dy * (w + offset - (1 / N) * (1 / RMS^2) * ((dy * (w + offset)) dot x) * x]. 
* means element-wise multiplication, whileas dot means dot product + dw = sum(dy * (x / RMS)). summation over BxT dimension + """ + + row_block_id = tl.program_id(0) + row_start = row_block_id * rows_per_program + row_end = min((row_block_id + 1) * rows_per_program, n_rows) + col_offsets = tl.arange(0, BLOCK_SIZE) + mask = col_offsets < n_cols + + dW_row = tl.zeros((BLOCK_SIZE,), dtype=tl.float32) + + dY_ptr += row_start * dY_row_stride + dX_ptr += row_start * dX_row_stride + + X_ptr += row_start * X_row_stride + RSTD_ptr += row_start + + W_row = tl.load(W_ptr + col_offsets, mask=mask, other=0.0) + W_row = W_row + offset + + for _ in range(row_start, row_end): + dY_row = tl.load(dY_ptr + col_offsets, mask=mask, other=0.0) + X_row = tl.load(X_ptr + col_offsets, mask=mask, other=0.0) + + # Get cached rms + rstd_row = tl.load(RSTD_ptr) + + X_row = X_row.to(tl.float32) + + # Different bacward graphs for different casting modes + if casting_mode == _CASTING_MODE_LLAMA: + m = (dY_row * W_row).to(tl.float32) + + elif casting_mode == _CASTING_MODE_GEMMA: + dY_row = dY_row.to(tl.float32) + m = dY_row * W_row + else: + m = dY_row * W_row + + dX_row = rstd_row * m + + dX_row += (rstd_row) * (-(1 / n_cols) * rstd_row * rstd_row * tl.sum(m * X_row, axis=0) * X_row) + + # calculate the gradient of W + if casting_mode == _CASTING_MODE_LLAMA: + dW_row += dY_row * (X_row * rstd_row).to(X_dtype) + else: + # here X_row is already in fp32 (see previous if block) + dW_row += dY_row * (X_row * rstd_row) + + tl.store(dX_ptr + col_offsets, dX_row.to(X_dtype), mask=mask) + + dY_ptr += dY_row_stride + dX_ptr += dX_row_stride + X_ptr += X_row_stride + RSTD_ptr += RSTD_row_stride + + tl.store(dW_ptr + row_block_id * dW_row_stride + col_offsets, dW_row, mask=mask) + + +_str_to_casting_mode = { + "llama": _CASTING_MODE_LLAMA.value, + "gemma": _CASTING_MODE_GEMMA.value, + "none": _CASTING_MODE_NONE.value, +} + + +def rms_norm_forward(X, W, eps, offset, casting_mode): + if not 
isinstance(casting_mode, int): + assert casting_mode in _str_to_casting_mode, f"Invalid casting mode: {casting_mode}" + casting_mode = _str_to_casting_mode[casting_mode] + else: + assert casting_mode in _str_to_casting_mode.values(), f"Invalid casting mode: {casting_mode}" + + shape = X.shape + dim = shape[-1] + X = X.view(-1, dim) + n_rows, n_cols = X.shape + BLOCK_SIZE, num_warps = calculate_settings(n_cols) + + Y = torch.empty((n_rows, n_cols), dtype=X.dtype, device=X.device) + # RSTD is to cache rstd for each row + # RSTD is always computed/stored in fp32 if we are using Llama or Gemma casting mode + rstd_dtype = torch.float32 if casting_mode in (_CASTING_MODE_LLAMA.value, _CASTING_MODE_GEMMA.value) else X.dtype + RSTD = torch.empty(n_rows, dtype=rstd_dtype, device=X.device) + + # Check constraints. + assert X.shape[1] == W.shape[0], "Incompatible hidden size dimension between tensor1.shape[1] and tensor2.shape[0]" + + # XPU-specific optimization + kernel_args = {} + if X.device.type == "xpu": + kernel_args["grf_mode"] = "large" + _rms_norm_forward_kernel[(n_rows,)]( + Y, + Y.stride(0), + X, + X.stride(0), + W, + W.stride(0), + RSTD, + RSTD.stride(0), + n_cols, + eps, + offset, + casting_mode, + BLOCK_SIZE=BLOCK_SIZE, + num_warps=num_warps, + **kernel_args, # XPU-specific optimization + ) + return Y.view(*shape), X, RSTD, BLOCK_SIZE, num_warps, casting_mode + + +def rms_norm_backward(dY, X, W, RSTD, offset, casting_mode, BLOCK_SIZE, num_warps, in_place): + shape = dY.shape + dim = shape[-1] + dY = dY.view(-1, dim) + n_rows, n_cols = dY.shape + + sm_count = 1 + if X.device.type == "cuda": + sm_count = torch.cuda.get_device_properties(X.device).multi_processor_count + elif X.device.type == "xpu": + sm_count = torch.xpu.get_device_properties(X.device).gpu_eu_count + + # fp32 for numerical stability especially. 
+ _dW = torch.empty((sm_count, n_cols), dtype=torch.float32, device=W.device) + + if n_cols > BLOCK_SIZE: + raise RuntimeError("This layer norm doesn't support feature dim >= 64KB.") + rows_per_program = math.ceil(n_rows / sm_count) + grid = (sm_count,) + + if in_place is True: + dX = dY + else: + dX = torch.zeros_like(dY) + + # XPU-specific optimization + kernel_args = {} + if X.device.type == "xpu": + kernel_args["grf_mode"] = "large" + + _rms_norm_backward_kernel[grid]( + dY, + dY.stride(0), + dX, + dX.stride(0), + X, + X.stride(0), + torch_to_triton_dtype[X.dtype], + W, + W.stride(0), + RSTD, + RSTD.stride(0), + _dW, + _dW.stride(0), + n_rows, + n_cols, + offset, + rows_per_program, + casting_mode, + BLOCK_SIZE=BLOCK_SIZE, + num_warps=num_warps, + **kernel_args, # XPU-specific optimization + ) + dX = dX.view(*shape) + dW = _dW.sum(dim=0).to(W.dtype) + + return dX, dW + + +class LigerRMSNormFunction(torch.autograd.Function): + """ + Performs RMSNorm (Root Mean Square Normalization), which normalizes the input tensor `X` using the + weight tensor `W`, with an optional offset and casting mode. + + Some models use an 'offset' to shift the weight tensor `W` by a constant value. For example, Gemma + uses an offset of 1.0, so the computation becomes `(X / RMS(X)) * (W + 1.0)` instead of the usual + `(X / RMS(X)) * W`. You can pass the offset value as an argument to the forward function. + + In addition, different models cast their inputs at different places during RMSNorm computation. For + example, Gemma casts everything to fp32 nefore starting the computation, while Llama casts only the + inverse RMS to fp32. You can specify the casting mode using the `casting_mode` argument. We currently + support the following casting modes (they match HuggingFace Transformers' implementations): + - 'llama': matches the Llama implementation, where only the inverse RMS is computed on fp32. 
+ - 'gemma': matches the Gemma implementation, where everything is cast to fp32, then computed, then cast back to the original dtype. + - 'none': no casting is done. The computation is done in the original dtype. This saves memory and is slightly faster, but has more error w.r.t. the original implementation. + + `in_place` option means whether to in_place modify dY to store dX. This is default to `True` to save memory. However, under certain cases, it can produce incorrect inputs. + For example, gemma2 uses two rmsnorm sequentially with residual in between. The resesidual part needs dY so it cannot be modified in-place. + Therefore, for the patching of RMSNorm in gemma2, we set `in_place` to `False` + """ + + @staticmethod + @ensure_contiguous + def forward(ctx, X, W, eps, offset=0.0, casting_mode="llama", in_place=True): + """ + X: (B, T, H) or (BxT, H) + W: (H,) + """ + Y, X, RSTD, BLOCK_SIZE, num_warps, casting_mode = rms_norm_forward(X, W, eps, offset, casting_mode) + ctx.offset = offset + ctx.casting_mode = casting_mode + ctx.in_place = in_place + ctx.BLOCK_SIZE = BLOCK_SIZE + ctx.num_warps = num_warps + ctx.save_for_backward(X, W, RSTD) + return Y + + @staticmethod + @ensure_contiguous + def backward(ctx, dY): + """ + Y: (B, T, H) or (BxT, H) + """ + X, W, RSTD = ctx.saved_tensors + dX, dW = rms_norm_backward( + dY, + X, + W, + RSTD, + ctx.offset, + ctx.casting_mode, + ctx.BLOCK_SIZE, + ctx.num_warps, + ctx.in_place, + ) + return dX, dW, None, None, None, None \ No newline at end of file diff --git a/build/torch-rocm/rope.py b/build/torch-rocm/rope.py new file mode 100644 index 0000000000000000000000000000000000000000..8c4801bf72c82d718ebd5fa63f278fdcbe0f23b3 --- /dev/null +++ b/build/torch-rocm/rope.py @@ -0,0 +1,239 @@ +import torch +import triton +import triton.language as tl + + +@triton.jit +def _triton_rope( + q_ptr, + q_row_stride, + k_ptr, + k_row_stride, + cos, + cos_row_stride, + sin, + sin_row_stride, + sl, + bs: tl.constexpr, + cos_bs: 
tl.constexpr, + n_qh: tl.constexpr, + n_kh: tl.constexpr, + hd: tl.constexpr, + pad_n_qh: tl.constexpr, + pad_n_kh: tl.constexpr, + pad_hd: tl.constexpr, + BLOCK_SIZE: tl.constexpr, + BACKWARD_PASS: tl.constexpr = False, +): + # q size: (bsz, seq_len, num_q_heads, head_dim) + # q stride: (seq_len * num_q_heads * head_dim, num_q_heads * head_dim, head_dim, 1) + # k size: (bsz, seq_len, num_kv_heads, head_dim) + # k stride: (seq_len * num_kv_heads * head_dim, num_kv_heads * head_dim, head_dim, 1) + + # cos size: (1, seq_len, head_dim) or (bsz, seq_len, head_dim) + # stride: (seq_len * head_dim, head_dim, 1) + pid = tl.program_id(0) + + # locate start address + q_ptr = q_ptr + pid * q_row_stride + k_ptr = k_ptr + pid * k_row_stride + + # #################################################################### + # get the cos(mθ_{i...d/2}) and sin(mθ_{i...d/2}) for token position + # m of this program instance + # #################################################################### + + # 1. program instances are laid out in a 1D vector of size bsz * seq_len, which + # effectively represents a 2D grid of size [bsz, seq_len] with seq_len dimension + # being the fastest changing dimension. Thus we can simply do pid // sl to get the batch index + # and pid % sl to get the sequence index. + # 2. We only need the left half of cos and sin matrix because the right half is just + # a clone of the left half. 
+ batch_idx = pid // sl + cos_row_idx = pid % sl + cos = cos + tl.where( + cos_bs == 1, + cos_row_idx * cos_row_stride, + batch_idx * (sl * cos_row_stride) + cos_row_idx * cos_row_stride, + ) + sin = sin + tl.where( + cos_bs == 1, + cos_row_idx * sin_row_stride, + batch_idx * (sl * sin_row_stride) + cos_row_idx * sin_row_stride, + ) + + cos_offsets = tl.arange(0, pad_hd // 2) + cos_mask = cos_offsets < hd // 2 + cos_row = tl.load(cos + cos_offsets, mask=cos_mask, other=0) + sin_row = tl.load(sin + cos_offsets, mask=cos_mask, other=0) + + # #################################################################### + # Load the left and right half of q and k for the current + # program instance (i.e. for the current token) separately + # #################################################################### + # left half of the head + first_half_q_offsets = tl.arange(0, pad_n_qh)[:, None] * hd + tl.arange(0, pad_hd // 2)[None, :] + first_half_k_offsets = tl.arange(0, pad_n_kh)[:, None] * hd + tl.arange(0, pad_hd // 2)[None, :] + first_q_mask = (tl.arange(0, pad_n_qh)[:, None] < n_qh) & (tl.arange(0, pad_hd // 2)[None, :] < hd // 2) + first_k_mask = (tl.arange(0, pad_n_kh)[:, None] < n_kh) & (tl.arange(0, pad_hd // 2)[None, :] < hd // 2) + q_tile_1 = tl.load(q_ptr + first_half_q_offsets, mask=first_q_mask, other=0).to(sin_row.dtype) + k_tile_1 = tl.load(k_ptr + first_half_k_offsets, mask=first_k_mask, other=0).to(sin_row.dtype) + + # right half of the head + second_half_q_offsets = first_half_q_offsets + (hd // 2) + second_half_k_offsets = first_half_k_offsets + (hd // 2) + second_q_mask = first_q_mask + second_k_mask = first_k_mask + q_tile_2 = tl.load(q_ptr + second_half_q_offsets, mask=second_q_mask, other=0).to(sin_row.dtype) + k_tile_2 = tl.load(k_ptr + second_half_k_offsets, mask=second_k_mask, other=0).to(sin_row.dtype) + + if not BACKWARD_PASS: + # y = [x1, x2] * [cos, cos] + [-x2, x1] * [sin, sin] + new_q_tile_1 = q_tile_1 * cos_row - q_tile_2 * sin_row + 
tl.store(q_ptr + first_half_q_offsets, new_q_tile_1, mask=first_q_mask) + new_q_tile_2 = q_tile_2 * cos_row + q_tile_1 * sin_row + tl.store(q_ptr + second_half_q_offsets, new_q_tile_2, mask=second_q_mask) + + new_k_tile_1 = k_tile_1 * cos_row - k_tile_2 * sin_row + tl.store(k_ptr + first_half_k_offsets, new_k_tile_1, mask=first_k_mask) + new_k_tile_2 = k_tile_2 * cos_row + k_tile_1 * sin_row + tl.store(k_ptr + second_half_k_offsets, new_k_tile_2, mask=second_k_mask) + else: + # with some math, we can get: + # dy = [dx1, dx2] * [cos, cos] + [-dx2, dx1] * [-sin, -sin] + new_q_tile_1 = q_tile_1 * cos_row + q_tile_2 * sin_row + tl.store(q_ptr + first_half_q_offsets, new_q_tile_1, mask=first_q_mask) + new_q_tile_2 = q_tile_2 * cos_row - q_tile_1 * sin_row + tl.store(q_ptr + second_half_q_offsets, new_q_tile_2, mask=second_q_mask) + + new_k_tile_1 = k_tile_1 * cos_row + k_tile_2 * sin_row + tl.store(k_ptr + first_half_k_offsets, new_k_tile_1, mask=first_k_mask) + new_k_tile_2 = k_tile_2 * cos_row - k_tile_1 * sin_row + tl.store(k_ptr + second_half_k_offsets, new_k_tile_2, mask=second_k_mask) + + +def rope_forward(q, k, cos, sin): + # transpose it back to the physical shape because Triton looks at the physical storage + # note: q and k are incontiguous before the transformation and will become contiguous after transpose + q = q.transpose(1, 2) + k = k.transpose(1, 2) + + batch_size, seq_len, n_q_head, head_dim = q.shape + n_kv_head = k.shape[2] + pad_hd = triton.next_power_of_2(head_dim) + pad_n_q_head = triton.next_power_of_2(n_q_head) + pad_n_kv_head = triton.next_power_of_2(n_kv_head) + BLOCK_SIZE = max(pad_n_q_head, pad_n_kv_head) + + n_row = batch_size * seq_len + + # ensure tensors passed into the kernel are contiguous. 
It will be no-op if they are already contiguous + q = q.contiguous() + k = k.contiguous() + cos = cos.contiguous() + sin = sin.contiguous() + cos_batch_size = cos.shape[0] + + _triton_rope[(n_row,)]( + q, + q.stride(1), + k, + k.stride(1), + cos, + cos.stride(-2), + sin, + sin.stride(-2), + seq_len, + batch_size, + cos_batch_size, + n_q_head, + n_kv_head, + head_dim, + pad_n_q_head, + pad_n_kv_head, + pad_hd, + BLOCK_SIZE=BLOCK_SIZE, + BACKWARD_PASS=False, + ) + return q.transpose(1, 2), k.transpose(1, 2), cos, sin + + +def rope_backward(dq, dk, cos, sin): + dq = dq.transpose(1, 2) + dk = dk.transpose(1, 2) + + batch_size, seq_len, n_q_head, head_dim = dq.shape + cos_batch_size = cos.shape[0] + n_kv_head = dk.shape[2] + pad_hd = triton.next_power_of_2(head_dim) + pad_n_q_head = triton.next_power_of_2(n_q_head) + pad_n_kv_head = triton.next_power_of_2(n_kv_head) + BLOCK_SIZE = max(pad_n_q_head, pad_n_kv_head) + + n_row = batch_size * seq_len + + # ensure dq and dk are contiguous + dq = dq.contiguous() + dk = dk.contiguous() + + # backward is similar to forward except swapping few ops + _triton_rope[(n_row,)]( + dq, + dq.stride(1), + dk, + dk.stride(1), + cos, + cos.stride(-2), + sin, + sin.stride(-2), + seq_len, + batch_size, + cos_batch_size, + n_q_head, + n_kv_head, + head_dim, + pad_n_q_head, + pad_n_kv_head, + pad_hd, + BLOCK_SIZE=BLOCK_SIZE, + BACKWARD_PASS=True, + ) + return dq.transpose(1, 2), dk.transpose(1, 2) + + +class LigerRopeFunction(torch.autograd.Function): + """ + Triton implementation of the Rotary Positional Embedding (RoPE) operation. Please note that + this implements the HuggingFace Llama & Mistral version, whose rotation matrix is slightly different + than the original RoPE paper. 
+ + Please find the corresponding HuggingFace implementation here: + https://github.com/huggingface/transformers/blob/v4.40.2/src/transformers/models/llama/modeling_llama.py#L184 + + For more details about the rotation matrix used here, please refer to: + https://discuss.huggingface.co/t/is-llama-rotary-embedding-implementation-correct/44509/2 + """ + + @staticmethod + def forward(ctx, q, k, cos, sin, position_ids=None, unsqueeze_dim=1): + """ + q size: (bsz, n_q_head, seq_len, head_dim) + k size: (bsz, n_kv_head, seq_len, head_dim) + cos size: (1, seq_len, head_dim) or (bsz, seq_len, head_dim) + sin size: (1, seq_len, head_dim) or (bsz, seq_len, head_dim) + """ + q, k, cos, sin = rope_forward(q, k, cos, sin) + ctx.save_for_backward(cos, sin) + return q, k + + def backward(ctx, dq, dk): + """ + dq size: (bsz, n_q_head, seq_len, head_dim) + dk size: (bsz, n_kv_head, seq_len, head_dim) + cos size: (1, seq_len, head_dim) or (bsz, seq_len, head_dim) + sin size: (1, seq_len, head_dim) or (bsz, seq_len, head_dim) + """ + + cos, sin = ctx.saved_tensors + dq, dk = rope_backward(dq, dk, cos, sin) + return dq, dk, None, None, None, None \ No newline at end of file diff --git a/build/torch-rocm/swiglu.py b/build/torch-rocm/swiglu.py new file mode 100644 index 0000000000000000000000000000000000000000..f4ed39745bb5966a9d5eaee4643d487e66a4e620 --- /dev/null +++ b/build/torch-rocm/swiglu.py @@ -0,0 +1,116 @@ +import torch +import triton +import triton.language as tl + +from .utils import calculate_settings +from .utils import ensure_contiguous + + +@triton.jit +def silu(x): + return x * tl.sigmoid(x) + + +@triton.jit +def _swiglu_forward_kernel(a_ptr, b_ptr, c_ptr, stride, n_cols: tl.constexpr, BLOCK_SIZE: tl.constexpr): + program_id = tl.program_id(0).to(tl.int64) + + # locate start index + a_ptr += program_id * stride + b_ptr += program_id * stride + c_ptr += program_id * stride + + col_offsets = tl.arange(0, BLOCK_SIZE) + mask = col_offsets < n_cols + + # sigmoid requires type 
float32 + a_row = tl.load(a_ptr + col_offsets, mask=mask, other=0).to(tl.float32) + b_row = tl.load(b_ptr + col_offsets, mask=mask, other=0) + c_row = silu(a_row) * b_row + tl.store(c_ptr + col_offsets, c_row, mask=mask) + + +@triton.jit +def _swiglu_backward_kernel(dc_ptr, a_ptr, b_ptr, stride, n_cols: tl.constexpr, BLOCK_SIZE: tl.constexpr): + program_id = tl.program_id(0).to(tl.int64) + + # locate start index + dc_ptr += program_id * stride + a_ptr += program_id * stride + b_ptr += program_id * stride + + col_offsets = tl.arange(0, BLOCK_SIZE) + mask = col_offsets < n_cols + + dc_row = tl.load(dc_ptr + col_offsets, mask=mask, other=0) + # sigmoid requires type float32 + a_row = tl.load(a_ptr + col_offsets, mask=mask, other=0).to(tl.float32) + b_row = tl.load(b_ptr + col_offsets, mask=mask, other=0) + + # recomputation to save memory + sig_a = tl.sigmoid(a_row) + silu_a = a_row * sig_a + db_row = dc_row * silu_a + da_row = dc_row * (silu_a * (1 - sig_a) + sig_a) * b_row + + tl.store(a_ptr + col_offsets, da_row, mask=mask) + tl.store(b_ptr + col_offsets, db_row, mask=mask) + + +def swiglu_forward(a, b): + ori_shape = a.shape + + n_cols = ori_shape[-1] + a = a.view(-1, n_cols) + b = b.view(-1, n_cols) + c = torch.empty_like(a) + n_rows = a.shape[0] + + BLOCK_SIZE, num_warps = calculate_settings(n_cols) + + _swiglu_forward_kernel[(n_rows,)]( + a, + b, + c, + c.stride(-2), + n_cols=n_cols, + BLOCK_SIZE=BLOCK_SIZE, + num_warps=num_warps, + ) + return a, b, c.view(*ori_shape) + + +def swiglu_backward(a, b, dc): + ori_shape = dc.shape + n_cols = ori_shape[-1] + dc = dc.view(-1, n_cols) + n_rows = dc.shape[0] + + BLOCK_SIZE, num_warps = calculate_settings(n_cols) + + _swiglu_backward_kernel[(n_rows,)]( + dc, + a, + b, + dc.stride(-2), + n_cols=n_cols, + BLOCK_SIZE=BLOCK_SIZE, + num_warps=num_warps, + ) + return a.view(*ori_shape), b.view(*ori_shape) + + +class LigerSiLUMulFunction(torch.autograd.Function): + @staticmethod + @ensure_contiguous + def forward(ctx, a, b): + 
a, b, c = swiglu_forward(a, b) + ctx.save_for_backward(a, b) + return c + + @staticmethod + @ensure_contiguous + def backward(ctx, dc): + a, b = ctx.saved_tensors + a, b = swiglu_backward(a, b, dc) + return a, b \ No newline at end of file diff --git a/build/torch-rocm/tvd.py b/build/torch-rocm/tvd.py new file mode 100644 index 0000000000000000000000000000000000000000..c072e4118c6fa5980de5c96a7acdaa05b6acb468 --- /dev/null +++ b/build/torch-rocm/tvd.py @@ -0,0 +1,207 @@ +from typing import Literal +from typing import Optional + +import torch +import triton +import triton.language as tl + +from .utils import ensure_contiguous + +MAX_FUSED_SIZE = 65536 // 4 + +REDUCTION_LITERAL = Literal["none", "sum", "mean", "batchmean"] + +_REDUCTION_MODE_NONE = tl.constexpr(0) +_REDUCTION_MODE_SUM = tl.constexpr(1) +_REDUCTION_MODE_MEAN = tl.constexpr(2) +_REDUCTION_MODE_BATCHMEAN = tl.constexpr(3) + +_str_to_reduction_mode = { + "none": _REDUCTION_MODE_NONE.value, + "sum": _REDUCTION_MODE_SUM.value, + "mean": _REDUCTION_MODE_MEAN.value, + "batchmean": _REDUCTION_MODE_BATCHMEAN.value, +} + + +def get_num_warps(BLOCK_SIZE): + num_warps = 4 + if BLOCK_SIZE >= 32768: + num_warps = 32 + elif BLOCK_SIZE >= 8192: + num_warps = 16 + elif BLOCK_SIZE >= 2048: + num_warps = 8 + + return num_warps + + +@triton.jit +def _tv_distance_kernel( + p_ptr, + p_stride, + q_ptr, + q_stride, + loss_ptr, + loss_stride, + grads_ptr, + grads_stride, + label_ptr, + ignore_index: tl.constexpr, + n_cols, + BLOCK_SIZE: tl.constexpr, + HAS_LABEL: tl.constexpr, + reduction: tl.constexpr = _REDUCTION_MODE_BATCHMEAN, +): + pid = tl.program_id(0).to(tl.int64) + p_ptr += pid * p_stride + q_ptr += pid * q_stride + loss_ptr += pid * loss_stride + grads_ptr += pid * grads_stride + label_ptr += pid + + base_offsets = tl.arange(0, BLOCK_SIZE) + + if HAS_LABEL: + label = tl.load(label_ptr) + if label == ignore_index: + for i in range(0, n_cols, BLOCK_SIZE): + offsets = i + base_offsets + mask = offsets < n_cols + 
tl.store(grads_ptr + offsets, 0.0, mask=mask) + if reduction == _REDUCTION_MODE_NONE: + tl.store(loss_ptr + offsets, 0.0, mask=mask) + return + + loss_sum = 0.0 + for i in range(0, n_cols, BLOCK_SIZE): + offsets = i + base_offsets + mask = offsets < n_cols + + p = tl.load(p_ptr + offsets, mask=mask, other=0.0) + q = tl.load(q_ptr + offsets, mask=mask, other=0.0) + + # TVD(P || Q) = 0.5 * |P - Q| + tv_loss = 0.5 * tl.abs(p - q) + + grad_res = tl.where(p > q, 0.5, -0.5) + + tl.store(grads_ptr + offsets, grad_res, mask=mask) + + if reduction == _REDUCTION_MODE_NONE: + tl.store(loss_ptr + offsets, tv_loss, mask=mask) + else: + loss_sum += tl.sum(tv_loss, axis=0) + + if reduction != _REDUCTION_MODE_NONE: + tl.store(loss_ptr, loss_sum) + + +def tv_distance_forward_triton(p, q, shift_labels, reduction, ignore_index, has_label): + BT, V = p.shape + + BLOCK_SIZE = min(MAX_FUSED_SIZE, triton.next_power_of_2(V)) + num_warps = get_num_warps(BLOCK_SIZE) + + grid = (BT,) + + reduction = _str_to_reduction_mode[reduction] + + out_size = (BT, V) if reduction == _REDUCTION_MODE_NONE.value else (BT,) + output_tensor = torch.zeros(out_size, device=p.device, dtype=torch.float32) + grads = torch.empty_like(p) + + n_non_ignore = (shift_labels != ignore_index).sum().item() if has_label else BT + + _tv_distance_kernel[grid]( + p, + p.stride(0), + q, + q.stride(0), + output_tensor, + output_tensor.stride(0), + grads, + grads.stride(0), + shift_labels if has_label else torch.empty(1, device=p.device), + ignore_index, + V, + BLOCK_SIZE=BLOCK_SIZE, + HAS_LABEL=has_label, + num_warps=num_warps, + reduction=reduction, + ) + + if reduction == _REDUCTION_MODE_BATCHMEAN.value: + return output_tensor.sum() / n_non_ignore, grads / n_non_ignore + elif reduction == _REDUCTION_MODE_SUM.value: + return output_tensor.sum(dim=0), grads + elif reduction == _REDUCTION_MODE_MEAN.value: + return output_tensor.sum() / (n_non_ignore * V), grads / (n_non_ignore * V) + else: + return output_tensor, grads + + +def 
tvd_backward_triton(grad_output, grads): + # If cross entropy is the last layer, grad_output is 1.0. Skip the mul then. + if torch.equal(grad_output, torch.tensor(1.0, device=grad_output.device)): + return grads + + return grads * grad_output + + +class LigerTVDLossFunction(torch.autograd.Function): + """ + Class implementing the forward and backward pass for the Total Variation Distance Loss using Triton. + """ + + @staticmethod + @ensure_contiguous + def forward( + ctx, + p: torch.Tensor, + q: torch.Tensor, + shift_labels: Optional[torch.Tensor] = None, + reduction: REDUCTION_LITERAL = "batchmean", + ignore_index: int = -100, + ) -> torch.Tensor: + """A forward pass for the Total Variation Distance Loss. + + Args: + ctx: Torch autograd context + p (torch.Tensor): A tensor of shape (BT, V) containing the first distribution. + q (torch.Tensor): A tensor of shape (BT, V) containing the second distribution. + shift_labels (Optional[torch.Tensor]): A tensor of shape (BT,) containing the labels. + reduction (REDUCTION_LITERAL, optional): The reduction method to be applied. Defaults to "batchmean". + ignore_index (int, optional): The index to ignore during loss calculation. Defaults to -100. + + Returns: + torch.Tensor: The computed Total Variation Distance Loss. + """ + has_label = False + if shift_labels is not None: + assert shift_labels.shape == (p.shape[0],), ( + f"the shape of shift_labels must be (BT,). Got: {shift_labels.shape}" + ) + shift_labels = shift_labels.contiguous() + has_label = True + + loss, grads = tv_distance_forward_triton(p, q, shift_labels, reduction, ignore_index, has_label) + ctx.save_for_backward(grads) + return loss + + @staticmethod + @ensure_contiguous + def backward(ctx, grad_output: torch.Tensor) -> torch.Tensor: + """A backward pass for the Total Variation Distance Loss. + + Args: + ctx: Torch autograd context + grad_output (torch.Tensor): The gradient of the loss with respect to the output. 
+ + Returns: + tuple[torch.Tensor, None, None, None, None]: The gradient of the loss with respect to the inputs. + """ + (grads,) = ctx.saved_tensors + grads = tvd_backward_triton(grad_output, grads) + + return grads, None, None, None, None \ No newline at end of file diff --git a/build/torch-rocm/utils.py b/build/torch-rocm/utils.py new file mode 100644 index 0000000000000000000000000000000000000000..718b0bb9c6391cf291bae91e671fc44a2b5ebf36 --- /dev/null +++ b/build/torch-rocm/utils.py @@ -0,0 +1,135 @@ +""" +This file incorporates code from Unsloth licensed under the Apache License, Version 2.0. +See the original Unsloth repository at https://github.com/unslothai/unsloth. + +The following line +https://github.com/linkedin/Liger-Kernel/blob/7382a8761f9af679482b968f9348013d933947c7/src/liger_kernel/ops/utils.py#L23 +is based on code from Unsloth, located at: +https://github.com/unslothai/unsloth/blob/fd753fed99ed5f10ef8a9b7139588d9de9ddecfb/unsloth/kernels/utils.py#L43 + +Modifications made by Yanning Chen, 2024. 
+""" + +import functools +import importlib +import operator + +from typing import Callable + +import torch +import triton +import triton.language as tl + +from packaging.version import Version + +def infer_device(): + """ + Get current device name based on available devices + """ + if torch.cuda.is_available(): # Works for both Nvidia and AMD + return "cuda" + elif torch.xpu.is_available(): + return "xpu" + else: + return "cpu" + +def is_hip() -> bool: + return torch.version.hip is not None + + +def ensure_contiguous(fn): + @functools.wraps(fn) + def wrapper(ctx, *args, **kwargs): + def maybe_to_contiguous(x): + return x.contiguous() if isinstance(x, torch.Tensor) else x + + args = [maybe_to_contiguous(arg) for arg in args] + kwargs = {k: maybe_to_contiguous(v) for k, v in kwargs.items()} + return fn(ctx, *args, **kwargs) + + return wrapper + + +def calculate_settings(n): + # reference: https://github.com/unslothai/unsloth/blob/fd753fed99ed5f10ef8a9b7139588d9de9ddecfb/unsloth/kernels/utils.py#L43 + + MAX_FUSED_SIZE = 65536 + BLOCK_SIZE = triton.next_power_of_2(n) + if BLOCK_SIZE > MAX_FUSED_SIZE: + raise RuntimeError( + f"Cannot launch Triton kernel since n = {n} exceeds the recommended Triton blocksize = {MAX_FUSED_SIZE}." 
+ ) + + num_warps = 4 + if BLOCK_SIZE >= 32768: + num_warps = 32 if not is_hip() else 16 + elif BLOCK_SIZE >= 8192: + num_warps = 16 + elif BLOCK_SIZE >= 2048: + num_warps = 8 + return BLOCK_SIZE, num_warps + + +def compare_version(package: str, operator: Callable, target: str): + try: + pkg = importlib.import_module(package) + except ImportError: + return False + pkg_version = Version(pkg.__version__) + return operator(pkg_version, Version(target)) + + +def get_amp_custom_fwd_bwd() -> Callable: + device = infer_device() + if compare_version("torch", operator.ge, "2.4.0"): + return ( + functools.partial(torch.amp.custom_fwd, device_type=device), + functools.partial(torch.amp.custom_bwd, device_type=device), + ) + return torch.cuda.amp.custom_fwd, torch.cuda.amp.custom_bwd + + +amp_custom_fwd, amp_custom_bwd = get_amp_custom_fwd_bwd() + + +torch_to_triton_dtype = { + torch.float32: tl.float32, + torch.float16: tl.float16, + torch.bfloat16: tl.bfloat16, +} + + +@triton.jit +def element_mul_kernel( + X_ptr, + X_stride, + grad_output_ptr, + n_cols, + BLOCK_SIZE: tl.constexpr, +): + """ + This function multiplies each element of the tensor pointed by X_ptr with the value pointed by grad_output_ptr. + The multiplication is performed in-place on the tensor pointed by X_ptr. + + Parameters: + X_ptr: Pointer to the input tensor. + X_stride (int): The stride of the input tensor. + grad_output_ptr: Pointer to the gradient output value. + n_cols (int): The number of columns in the input tensor. + BLOCK_SIZE (int): The block size for Triton operations. 
+ """ + + # Get the program ID and convert it to int64 to avoid overflow + program_id = tl.program_id(0).to(tl.int64) + + # Locate the start index + X_ptr += program_id * X_stride + + # Load the gradient output value + grad_output = tl.load(grad_output_ptr) + + # Perform the element-wise multiplication + for i in range(0, n_cols, BLOCK_SIZE): + X_offsets = i + tl.arange(0, BLOCK_SIZE) + X_block = tl.load(X_ptr + X_offsets, mask=X_offsets < n_cols) + tl.store(X_ptr + X_offsets, X_block * grad_output, mask=X_offsets < n_cols) \ No newline at end of file diff --git a/build/torch-xpu/__init__.py b/build/torch-xpu/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..271a47c2646379ff9ec8cf8c3c3a0e973f187652 --- /dev/null +++ b/build/torch-xpu/__init__.py @@ -0,0 +1,3 @@ +from . import layers + +__all__ = ["layers"] \ No newline at end of file diff --git a/build/torch-xpu/_ops.py b/build/torch-xpu/_ops.py new file mode 100644 index 0000000000000000000000000000000000000000..45b30870c17a6c248a2a9b9922667f3afc95a973 --- /dev/null +++ b/build/torch-xpu/_ops.py @@ -0,0 +1,8 @@ +import torch +ops = torch.ops._liger_kernels_ab5ef3f + +def add_op_namespace_prefix(op_name: str): + """ + Prefix op by namespace. 
+ """ + return f"_liger_kernels_ab5ef3f::{op_name}" \ No newline at end of file diff --git a/build/torch-xpu/cross_entropy.py b/build/torch-xpu/cross_entropy.py new file mode 100644 index 0000000000000000000000000000000000000000..590d94868b34455f93c98f6c2fb46872d6b6956e --- /dev/null +++ b/build/torch-xpu/cross_entropy.py @@ -0,0 +1,460 @@ +import operator + +from typing import Optional + +import torch +import triton +import triton.language as tl + +from .utils import compare_version +from .utils import element_mul_kernel +from .utils import is_hip +from .utils import infer_device + +if compare_version("triton", operator.ge, "3.0.0"): + try: + # typical import path with dispatch available + from triton.language.extra.libdevice import tanh + except ModuleNotFoundError: + # for working with NGC containers + from triton.language.extra.cuda.libdevice import tanh +else: + from triton.language.math import tanh + + +@triton.jit +def liger_cross_entropy_kernel( + X_ptr, + X_stride, + Y_ptr, + Y_stride, + weight_ptr, + loss_ptr, + z_loss_ptr, + loss_stride, + n_cols, + n_non_ignore, + sum_non_ignore_weight, + weight_sum, + ignore_index, + lse_square_scale: tl.constexpr, + label_smoothing: tl.constexpr, + reduction: tl.constexpr, # set it as constexpr since reduction is always known at compile time + softcap, + RETURN_Z_LOSS: tl.constexpr, + BLOCK_SIZE: tl.constexpr, + HAS_WEIGHT: tl.constexpr, + HAS_SOFTCAPPING: tl.constexpr, +): + """ + This kernel computes both cross entropy loss and the gradient of the input. + We only consider hard label + mean reduction for now. Please refer to https://pytorch.org/docs/stable/generated/torch.nn.CrossEntropyLoss.html for the math. + + Parameters: + X_ptr: Pointer to input tensor. + X_stride (int): The stride of the input tensor. + Y_ptr: Pointer to target tensor. + Y_stride (int): The stride of the target tensor. + weight_ptr: Pointer to weight tensor. + loss_ptr: Pointer to tensor to store the loss. 
+ z_loss_ptr: Pointer to tensor to store the z loss. No operation if RETURN_Z_LOSS is 0. + loss_stride (int): The stride of the loss tensor. + n_cols (int): The number of columns in the input tensor. + n_non_ignore (float): The number of non-ignored elements in the batch. + sum_non_ignore_weight (float): The sum of non-ignored target's weights in the batch. + weight_sum (float): The sum of weight tensor. + ignore_index (int): The index to ignore in the target. + label_smoothing (float): The amount of smoothing when computing the loss, where 0.0 means no smoothing. + lse_square_scale (float): The scaler of (logsumexp(_input)) ^ 2 adding to the loss for the stability of training. + reduction (str): The string for the reduction to apply + softcap (float): The upper threshold for scaling logits to the range (-softcap, +softcap). + RETURN_Z_LOSS (int): The boolean value to decide whether storing z loss to z_loss_ptr or not. It must be 0 or 1. + BLOCK_SIZE (int): The block size for Triton operations. + HAS_WEIGHT (bool): The boolean value to determine whether assigning weight to each of the classes. + HAS_SOFTCAPPING (bool): The boolean value to determine whether applying soft-capping or not. + """ + + # https://github.com/triton-lang/triton/issues/1058 + # If B*T*V is too large, program_id * stride will overflow out of int32, so we convert to int64 + program_id = tl.program_id(0).to(tl.int64) + + # 1. Load Y_ptr first because if the target is ignore_index, we can return right away + Y_ptr += program_id * Y_stride + y = tl.load(Y_ptr) + + # 2. 
locate the start index + X_ptr += program_id * X_stride + + if y == ignore_index: + # set all X_ptr as 0 + for i in range(0, n_cols, BLOCK_SIZE): + X_offsets = i + tl.arange(0, BLOCK_SIZE) + tl.store(X_ptr + X_offsets, 0.0, mask=X_offsets < n_cols) + return + + loss_ptr += program_id * loss_stride + if RETURN_Z_LOSS: + z_loss_ptr += program_id * loss_stride + + if HAS_WEIGHT: + weight_y = tl.load(weight_ptr + y).cast(tl.float32) + + # Online softmax: 2 loads + 1 store (compared with 3 loads + 1 store for the safe softmax) + # Refer to Algorithm 3 in the paper: https://arxiv.org/pdf/1805.02867 + + # 3. [Online softmax] first pass: find max + sum + m = float("-inf") # m is the max value. use the notation from the paper + d = 0.0 # d is the sum. use the notation from the paper + ori_X_y = tl.load(X_ptr + y).cast(tl.float32) # we need to store the original value of X_y for the loss calculation + if HAS_SOFTCAPPING: + ori_X_y = softcap * tanh(ori_X_y / softcap) + + # Label smoothing is a general case of normal cross entropy + # See the full derivation at https://github.com/linkedin/Liger-Kernel/pull/198#issue-2503665310 + scaled_x_sum = 0.0 + eps = label_smoothing / n_cols + + for i in range(0, n_cols, BLOCK_SIZE): + X_offsets = i + tl.arange(0, BLOCK_SIZE) + X_block = tl.load( + X_ptr + X_offsets, + mask=X_offsets < n_cols, + other=float("-inf"), + # Ensure float32 precision for softmax calculation + ).cast(tl.float32) + if HAS_SOFTCAPPING: + X_block = softcap * tanh(X_block / softcap) + block_max = tl.max(X_block) + if label_smoothing > 0: + # scale X beforehand to avoid overflow + if HAS_WEIGHT: + weight_block = tl.load(weight_ptr + X_offsets, mask=X_offsets < n_cols) + scaled_x_sum += tl.sum(tl.where(X_offsets < n_cols, -eps * X_block * weight_block, 0.0)) + else: + scaled_x_sum += tl.sum(tl.where(X_offsets < n_cols, -eps * X_block, 0.0)) + m_new = tl.maximum(m, block_max) + d = d * tl.exp(m - m_new) + tl.sum(tl.exp(X_block - m_new)) + m = m_new + + # log 
(sum(e^(X_i))) = log (sum(e ^ (max(X) * e ^ (X_i - max(X))))) + # = log (e^(max(X)) * sum(e ^ (X_i - max(X)))) + # = max(X) + log (sum(e ^ (X_i - max(X)))) = m + log d + lse = m + tl.log(d) + + # 4. [Online Softmax] Second pass: compute gradients + # For 'mean' reduction, gradients are normalized by number of non-ignored elements (N) + # dx_y = (softmax(x_y) - 1) / N + # dx_i = softmax(x_i) / N, i != y + # For label smoothing: + # dx_i = (softmax(x_i) - label_smoothing / V) / N, V = n_cols, i != y + # dx_y = (softmax(x_y) - label_smoothing / V - (1 - label_smoothing)) / N + # = dx_i - (1 - label_smoothing) / N + # With Z loss: + # dx_i = ((1 + 2 * lse_square_scale * lse) * softmax(x_i) - label_smoothing / V) / N, i != y + # dx_y = dx_i - (1 - label_smoothing) / N + # For 'sum' reduction, no normalization is applied: + # dx_y = softmax(x_y) - 1 + # dx_i = softmax(x_i), for i ≠ y + + for i in range(0, n_cols, BLOCK_SIZE): + X_offsets = i + tl.arange(0, BLOCK_SIZE) + X_block = tl.load( + X_ptr + X_offsets, + mask=X_offsets < n_cols, + other=float("-inf"), + # Ensure float32 precision for softmax calculation + ).cast(tl.float32) + if HAS_SOFTCAPPING: + intermediate = tanh(X_block / softcap) + X_block = softcap * intermediate + + if not HAS_WEIGHT: + # softmax(x_i) + X_block = tl.exp(X_block - m) / d + # derivative of z-loss: 2 * lse_square_scale * lse * softmax(x_i) + X_block += 2 * lse_square_scale * lse * X_block + # smoothing term + X_block += -eps + # special handle dx_y + X_block = tl.where(X_offsets != y, X_block, X_block - (1 - label_smoothing)) + # reduction scale + if reduction == "mean": + X_block = X_block / n_non_ignore + else: + weight_block = tl.load(weight_ptr + X_offsets, mask=X_offsets < n_cols) + softmax_X = tl.exp(X_block - m) / d + # derivative of original_loss + dloss_ori = (1 - label_smoothing) * softmax_X + # specially handle dx_y + dloss_ori = tl.where(X_offsets != y, dloss_ori, dloss_ori - (1 - label_smoothing)) + dloss_ori = dloss_ori * 
weight_y + # derivative of smooth_loss + dloss_smooth = eps * (-weight_block + softmax_X * weight_sum) + # derivative of z-loss + dz_loss = 2 * lse_square_scale * lse * softmax_X + # reduction scale + if reduction == "mean": + dloss_ori = dloss_ori / sum_non_ignore_weight + dloss_smooth = dloss_smooth / sum_non_ignore_weight + # TODO: Implement weighted z_loss. Currently, z_loss is not scaled by weight. + dz_loss = dz_loss / n_non_ignore + # derivative of total_loss + X_block = dloss_ori + dloss_smooth + dz_loss + + # chain rule softcapping + # d(softcap * tanh(x / softcap)) = (1 - tanh^2(x / softcap)) + if HAS_SOFTCAPPING: + X_block = X_block * (1 - intermediate * intermediate) + + tl.store(X_ptr + X_offsets, X_block, mask=X_offsets < n_cols) + + # We need tl.debug_barrier() to ensure the new result of X_ptr is written as mentioned in + # https://github.com/triton-lang/triton/blob/ba42a5c68fd0505f8c42f4202d53be0f8d9a5fe0/python/triton/ops/cross_entropy.py#L34 + tl.debug_barrier() + + # 5. 
Calculate the loss + + # loss = log (softmax(X_y)) = log ((e ^ (X_y - max(X)) / sum(e ^ (X - max(X)))) + # = (X_y - max(X)) - log(sum(e ^ (X - max(X)))) + # = X_y - m - log d = X_y - lse + # sum(e ^ (X - max(X))) must >= 1 because the max term is e ^ 0 = 1 + # So we can safely calculate log (softmax(X_y)) without overflow + loss = lse - ori_X_y + if HAS_WEIGHT: + loss = weight_y * loss + + # Original loss = H(q, p), with label smoothing regularization = H(q', p) and (label_smoothing / V) = eps + # H(q', p) = (1 - label_smoothing) * H(q, p) + label_smoothing * H(u, p) + # = (1 - label_smoothing) * H(q, p) + eps * sum(logsoftmax(x_i)) + # By using m (global max of xi) and d (sum of e^(xi-m)), we can simplify as: + # = (1 - label_smoothing) * H(q, p) + (sum(-eps * x_i) + label_smoothing * (m + logd)) + # Refer to H(q', p) in section 7 of the paper: https://arxiv.org/pdf/1512.00567 + # pytorch: https://github.com/pytorch/pytorch/blob/2981534f54d49fa3a9755c9b0855e7929c2527f0/aten/src/ATen/native/LossNLL.cpp#L516 + # See full derivation at https://github.com/linkedin/Liger-Kernel/pull/198#issuecomment-2333753087 + if label_smoothing > 0: + if HAS_WEIGHT: + smooth_loss = scaled_x_sum + eps * lse * weight_sum + else: + smooth_loss = scaled_x_sum + label_smoothing * lse + loss = loss * (1 - label_smoothing) + smooth_loss + + # An auxiliary loss, z_loss + # Refer to Page14 Loss function section in the paper PaLM: https://www.jmlr.org/papers/v24/22-1144.html + z_loss = lse_square_scale * lse * lse + # Normalize the loss by the number of non-ignored elements if reduction is "mean" + if reduction == "mean": + if HAS_WEIGHT: + loss = loss / sum_non_ignore_weight + else: + loss = loss / n_non_ignore + # TODO: Implement weighted z_loss. Currently, z_loss is not scaled by weight. 
+ z_loss = z_loss / n_non_ignore + loss += z_loss + + tl.store(loss_ptr, loss) + if RETURN_Z_LOSS: + tl.store(z_loss_ptr, z_loss) + + +# The hard limit of TRITON_MAX_TENSOR_NUMEL is 1048576 https://github.com/triton-lang/triton/blob/ba42a5c68fd0505f8c42f4202d53be0f8d9a5fe0/python/triton/language/core.py#L19 +# However, setting limit as 65536 as in LayerNorm tutorial is faster because of less register spilling +# The optimal maximum block size depends on your hardware, your kernel, and your dtype +MAX_FUSED_SIZE = 4096 if infer_device() == "xpu" else 65536 // 2 # the best size we found by manually tuning + + +def cross_entropy_forward( + _input, + target, + weight, + ignore_index, + lse_square_scale, + label_smoothing, + reduction, + softcap, + return_z_loss, +): + assert isinstance(return_z_loss, bool), f"return_z_loss must be True or False. Got: {return_z_loss}" + + BT, V = _input.shape + n_rows = BT + + BLOCK_SIZE = min(MAX_FUSED_SIZE, triton.next_power_of_2(V)) + + # unreduced loss + loss_1d = torch.zeros(n_rows, dtype=_input.dtype, device=_input.device) + z_loss_1d = torch.zeros(n_rows, dtype=_input.dtype, device=_input.device) if return_z_loss else None + + target_mask = target != ignore_index + n_non_ignore = target_mask.sum().item() + assert (target * target_mask).max() < _input.shape[-1], ( + f"Target {target.max()} is out of bounds. Expected < {_input.shape[-1]}" + ) + assert (target * target_mask).min() >= 0, f"Target {target.min()} is out of bounds. Expected >= 0" + sum_non_ignore_weight = n_non_ignore + weight_sum = 0.0 + if weight is not None: + assert weight.shape[0] == V, f"If given, weight has to be a Tensor of size V. Got: {weight.shape}" + assert torch.is_floating_point(weight), ( + f"If given, weight has to be a Tensor of floating point dtype. 
Got: {weight.dtype}" + ) + sum_non_ignore_weight = torch.gather(weight, dim=0, index=target.masked_select(target_mask)).sum().item() + weight_sum = weight.sum().item() + # ensure weight is contiguous + if weight.stride(-1) != 1: + weight = weight.contiguous() + + # ensure _input and target are contiguous in the last dimension + if _input.stride(-1) != 1: + _input = _input.contiguous() + if target.stride(-1) != 1: + target = target.contiguous() + + # Here we use a trick to store X_ptr gradient in X_ptr so we can save memory + liger_cross_entropy_kernel[(n_rows,)]( + X_ptr=_input, + X_stride=_input.stride(-2), + Y_ptr=target, + Y_stride=target.stride(-1), # always 1 + weight_ptr=weight, # dummy if None + loss_ptr=loss_1d, + z_loss_ptr=z_loss_1d, + loss_stride=loss_1d.stride(-1), # always 1 + n_cols=V, + n_non_ignore=n_non_ignore, + sum_non_ignore_weight=sum_non_ignore_weight, + ignore_index=ignore_index, + weight_sum=weight_sum, + lse_square_scale=lse_square_scale, + label_smoothing=label_smoothing, + reduction=reduction, + softcap=softcap, + RETURN_Z_LOSS=return_z_loss, + BLOCK_SIZE=BLOCK_SIZE, + HAS_WEIGHT=True if weight is not None else False, + HAS_SOFTCAPPING=True if softcap is not None else False, + # TODO: 32 seems to give the best performance + # Performance is quite sensitive to num_warps + num_warps=32 if not is_hip() else 16, + ) + + if reduction == "none": + loss = loss_1d + z_loss = z_loss_1d if return_z_loss else None + else: + loss = torch.sum(loss_1d) + z_loss = torch.sum(z_loss_1d) if return_z_loss else None + + return loss, z_loss, _input + + +def cross_entropy_backward(_input, grad_output): + # If cross entropy is the last layer, grad_output is 1.0. 
def cross_entropy_backward(_input, grad_output):
    """Apply the incoming chain-rule factor ``grad_output`` to the gradient
    already stored in ``_input`` by the forward kernel, and return it."""
    one = torch.tensor(1.0, device=grad_output.device)
    if torch.equal(grad_output, one):
        # Cross entropy as the last layer: grad_output is exactly 1.0, so the
        # stored gradient is already final. Skip the multiply to save time.
        pass
    elif grad_output.ndim > 0:
        # reduction == 'none': one incoming gradient per row; broadcast it
        # across the vocab dimension.
        _input = _input * grad_output.unsqueeze(dim=1)
    else:
        # Scalar grad_output ('mean'/'sum'): scale in place with a Triton
        # kernel. Modifying inputs in-place for gradient storage and backward
        # multiple times causes anomalies with PyTorch but not with Triton.
        num_rows, vocab = _input.shape
        block = min(MAX_FUSED_SIZE, triton.next_power_of_2(vocab))

        element_mul_kernel[(num_rows,)](
            _input,
            _input.stride(-2),
            grad_output,
            vocab,
            BLOCK_SIZE=block,
            num_warps=32 if not is_hip() else 16,
        )

    return _input
class LigerCrossEntropyFunction(torch.autograd.Function):
    """
    This class implements a custom autograd function for the Liger Cross Entropy loss.
    It overrides the forward and backward methods of the torch.autograd.Function class.
    """

    @staticmethod
    def forward(
        ctx,
        _input: torch.Tensor,
        target: torch.Tensor,
        weight: Optional[torch.FloatTensor],
        ignore_index: int = -100,
        lse_square_scale: float = 0.0,
        label_smoothing: float = 0.0,
        reduction: str = "mean",
        softcap: Optional[float] = None,
        return_z_loss: bool = False,
    ):
        """
        The forward pass of the Liger Cross Entropy loss.

        Parameters:
        ctx : The context object.
        _input (tensor): The input tensor of shape (BT, V) where B is batch size, T is sequence length, V is vocab size.
        target (tensor): The target tensor of shape (BT) where each value is in [0, V-1].
        weight (Tensor, optional): a manual rescaling weight given to each class. If given, has to be a Tensor of size V and floating point dtype
        ignore_index (int): The index to ignore in the target.
        lse_square_scale (float): The scaler of (logsumexp(_input)) ^ 2 adding to the loss for the stability of training.
        label_smoothing (float): The amount of smoothing when computing the loss, where 0.0 means no smoothing.
        reduction (str): The reduction to apply to the output: "none" | "mean" | "sum".
        softcap (Optional[float]): The upper threshold for scaling logits to the range (-softcap, +softcap).
        return_z_loss (bool): When `return_z_loss` is `True`, returns (loss, z_loss) instead of (loss, None). Default: `False`

        Returns:
        tuple: A tuple with the computed losses with respect to loss and z loss. The elements are tensors or None.
        """
        loss, z_loss, _input = cross_entropy_forward(
            _input,
            target,
            weight,
            ignore_index,
            lse_square_scale,
            label_smoothing,
            reduction,
            softcap,
            return_z_loss,
        )
        # TODO: investigation
        # If we don't detach the _input tensor, the memory will double
        # Not sure why but seems that there will be a time both grad and value exist but in different location
        ctx.save_for_backward(_input.detach())
        ctx.return_z_loss = return_z_loss

        return loss, z_loss

    @staticmethod
    def backward(ctx, grad_output, grad_output2):
        """
        The backward pass of the Liger Cross Entropy loss.

        Parameters:
        ctx : The context object with saved tensors.
        grad_output (tensor): The tensor containing the gradient of the loss with respect to the output.
        grad_output2 (tensor): No use.
        Returns:
        tuple: A tuple with the gradients with respect to the inputs. The elements are tensors or None.
        """
        if ctx.return_z_loss:
            del grad_output2  # z_loss is only for logging

        (_input,) = ctx.saved_tensors
        _input = cross_entropy_backward(_input, grad_output)
        # One gradient slot per forward argument (ctx excluded); only _input
        # receives a gradient.
        return (
            _input,
            None,
            None,
            None,
            None,
            None,
            None,
            None,
            None,
        )
@triton.jit
def _dyt_bwd_kernel(
    x_ptr,
    x_row_stride,
    dy_ptr,
    dy_row_stride,
    dx_ptr,
    dx_row_stride,
    alpha_ptr,
    dalpha_ptr,
    gamma_ptr,
    dgamma_ptr,
    dgamma_row_stride,
    n_cols,
    n_rows,
    ROWS_PER_PROGRAM: tl.constexpr,
    BLOCK_SIZE: tl.constexpr,
):
    """
    Backward kernel for DyT (y = gamma * tanh(alpha * x) + beta).

    Reference:
    https://arxiv.org/abs/2503.10622

    Shapes:
        - x: (BT, C)
        - alpha: (1)
        - gamma: (C)
        - dx: (BT, C)
        - dy: (BT, C)
        - dgamma: (sm_count, C)
        - dalpha: (sm_count,)

    Each program processes ROWS_PER_PROGRAM consecutive rows, writing per-row
    dx and accumulating partial dalpha/dgamma reductions which the host sums.
    """
    # d(gamma * tanh(alpha * x) + beta) / dx
    #   = gamma * (1 - tanh^2(alpha * x)) * alpha
    # d(gamma * tanh(alpha * x) + beta) / dalpha
    #   = gamma * (1 - tanh^2(alpha * x)) * x
    # d(gamma * tanh(alpha * x) + beta) / dgamma
    #   = tanh(alpha * x)
    # d(gamma * tanh(alpha * x)) / dbeta = 1
    pid = tl.program_id(0)

    row_start = pid * ROWS_PER_PROGRAM
    row_end = min((pid + 1) * ROWS_PER_PROGRAM, n_rows)
    offsets = tl.arange(0, BLOCK_SIZE)
    mask = offsets < n_cols

    # per-program partial accumulators for the scalar/row-shared parameters
    dalpha = 0.0
    dgamma = tl.zeros((BLOCK_SIZE,), dtype=tl.float32)

    x_ptr += row_start * x_row_stride
    dx_ptr += row_start * dx_row_stride
    dy_ptr += row_start * dy_row_stride
    alpha = tl.load(alpha_ptr)
    gamma = tl.load(gamma_ptr + offsets, mask=mask, other=0.0)

    for _ in tl.range(row_start, row_end):
        dy = tl.load(dy_ptr + offsets, mask=mask, other=0.0)
        x = tl.load(x_ptr + offsets, mask=mask, other=0.0)
        tanh_ax = tanh((alpha * x).cast(tl.float32))
        sech2_ax = 1 - tanh_ax * tanh_ax

        dx = dy * gamma * sech2_ax * alpha
        dalpha += tl.sum(dy * gamma * sech2_ax * x)
        dgamma += dy * tanh_ax
        tl.store(dx_ptr + offsets, dx, mask=mask)

        dy_ptr += dy_row_stride
        x_ptr += x_row_stride
        dx_ptr += dx_row_stride

    # one partial-reduction row/slot per program; host sums over dim 0
    tl.store(dgamma_ptr + pid * dgamma_row_stride + offsets, dgamma, mask=mask)
    tl.store(dalpha_ptr + pid, dalpha)
def liger_dyt_bwd(dy, x, alpha, gamma):
    """Backward pass for DyT: returns (dx, dalpha, dgamma, dbeta).

    Launches one program per SM (or XPU subslice), each covering a strip of
    rows; the per-program partial dalpha/dgamma reductions are summed here.
    """
    shape = dy.shape
    dtype = x.dtype
    dim = shape[-1]
    dy = dy.view(-1, dim)
    x = x.view(-1, dim)
    n_rows, n_cols = dy.shape
    BLOCK_SIZE, num_warps = calculate_settings(n_cols)
    sm_count = 1
    device = infer_device()
    if device == "cuda":
        sm_count = torch.cuda.get_device_properties(x.device).multi_processor_count
    elif device == "xpu":
        sm_count = torch.xpu.get_device_properties(x.device).gpu_subslice_count
    # the kernel processes a whole row in a single block
    if n_cols > BLOCK_SIZE:
        raise RuntimeError(
            f"Feature dimension {dim} exceeds maximum supported size of {BLOCK_SIZE}. Consider using a smaller feature dimension."
        )

    # fp32 accumulation buffers; cast back to the input dtype at the end
    dx = torch.empty_like(x, dtype=torch.float32)
    _dalpha = torch.empty((sm_count,), dtype=torch.float32, device=x.device)
    _dgamma = torch.empty((sm_count, n_cols), dtype=torch.float32, device=x.device)

    grid = (sm_count,)
    rows_per_program = triton.cdiv(n_rows, sm_count)
    _dyt_bwd_kernel[grid](
        x_ptr=x,
        x_row_stride=x.stride(0),
        dy_ptr=dy,
        dy_row_stride=dy.stride(0),
        dx_ptr=dx,
        dx_row_stride=dx.stride(0),
        alpha_ptr=alpha,
        dalpha_ptr=_dalpha,
        gamma_ptr=gamma,
        dgamma_ptr=_dgamma,
        dgamma_row_stride=_dgamma.stride(0),
        n_cols=n_cols,
        n_rows=n_rows,
        ROWS_PER_PROGRAM=rows_per_program,
        BLOCK_SIZE=BLOCK_SIZE,
        num_warps=num_warps,
    )
    # sum the per-program partials into the final parameter gradients
    dalpha = _dalpha.sum(dim=0, keepdim=True).to(dtype)
    dgamma = _dgamma.sum(dim=0).to(dtype)
    dbeta = dy.sum(dim=0).to(dtype)
    return dx.view(*shape), dalpha, dgamma, dbeta
def fused_linear_cross_entropy_forward(
    _input,
    weight,
    target,
    ce_weight=None,
    bias=None,
    ignore_index=-100,
    lse_square_scale=0.0,
    label_smoothing=0.0,
    reduction="mean",
    softcap=None,
    return_z_loss=False,
):
    """Chunked fused linear + cross-entropy forward.

    Computes the CE loss of ``_input @ weight.t() (+ bias)`` against ``target``
    without ever materializing the full (BT, V) logits tensor, and accumulates
    the gradients w.r.t. ``_input``, ``weight`` and ``bias`` during the forward
    pass (the CE kernel overwrites each logits chunk with its gradient).

    Returns:
        tuple: ``(loss, z_loss, grad_input, grad_weight, grad_bias)``.
    """
    assert isinstance(return_z_loss, bool), f"return_z_loss must be True or False. Got: {return_z_loss}"
    device = _input.device

    # inputs have shape: BT x H
    # materialized activations will have shape: BT x V
    # the increase in memory = BT x V
    # reduction can be achieved by partitioning the number of tokens BT into smaller chunks.
    # for ex: if we were to achieve the same memory consumption as BT x H, then the chunk size should be:
    # inc_factor = (V+H-1)//H, chunk_size = (BT + inc_factor - 1)//inc_factor
    # for ex: BT = 4096*4, V = 32000, H = 4096 ==> inc_factor = 8, chunk_size = 2048
    BT, H = _input.shape
    V = weight.shape[0]
    BLOCK_SIZE = min(MAX_FUSED_SIZE, triton.next_power_of_2(V))

    inc_factor = triton.cdiv(V, H)  # (V + H - 1) // H
    chunk_size = triton.next_power_of_2(triton.cdiv(BT, inc_factor))  # (BT + inc_factor - 1) // inc_factor
    num_chunks = triton.cdiv(BT, chunk_size)  # (BT + chunk_size - 1) // chunk_size

    grad_weight = torch.zeros_like(weight, device=device) if weight.requires_grad else None
    grad_input = torch.zeros_like(_input, device=device)
    grad_bias = torch.zeros_like(bias, device=device) if bias is not None else None
    # we use fp32 for loss accumulator
    loss_1d = torch.zeros(BT, dtype=torch.float32, device=device)
    z_loss_1d = torch.zeros(BT, dtype=_input.dtype, device=_input.device) if return_z_loss else None

    # TODO: evaluate how CUDA synchronization caused by .item() affects the speed
    target_mask = target != ignore_index
    total_n_non_ignore = target_mask.sum().item()
    total_sum_non_ignore_ce_weight = total_n_non_ignore
    ce_weight_sum = 0.0
    if ce_weight is not None:
        assert ce_weight.shape[0] == V, f"If given, weight has to be a Tensor of size V. Got: {ce_weight.shape}"
        assert torch.is_floating_point(ce_weight), (
            f"If given, weight has to be a Tensor of floating point dtype. Got: {ce_weight.dtype}"
        )
        total_sum_non_ignore_ce_weight = (
            torch.gather(ce_weight, dim=0, index=target.masked_select(target_mask)).sum().item()
        )
        ce_weight_sum = ce_weight.sum().item()
        # kernel assumes unit stride over the class dimension
        if ce_weight.stride(-1) != 1:
            ce_weight = ce_weight.contiguous()

    for chunk_id in range(num_chunks):
        start_idx = chunk_id * chunk_size
        end_idx = min((chunk_id + 1) * chunk_size, BT)
        _input_chunk = _input[start_idx:end_idx]  # chunk_size x H

        # when doing matmul, use the original precision
        logits_chunk = _input_chunk @ weight.t()  # chunk_size x V
        if bias is not None:
            logits_chunk = logits_chunk + bias

        target_chunk = target[start_idx:end_idx]  # chunk_size,

        n_rows = logits_chunk.shape[0]

        # unreduced loss
        loss_1d_slice = loss_1d[start_idx:end_idx]  # chunk_size,
        z_loss_1d_slice = z_loss_1d[start_idx:end_idx] if return_z_loss else None

        # ensure _input and target are contiguous
        logits_chunk = logits_chunk.contiguous()
        target_chunk = target_chunk.contiguous()

        # Here we calculate the gradient of logits_chunk in place so we can save memory.
        liger_cross_entropy_kernel[(n_rows,)](
            X_ptr=logits_chunk,
            X_stride=logits_chunk.stride(-2),
            Y_ptr=target_chunk,
            Y_stride=target_chunk.stride(-1),  # always 1
            weight_ptr=ce_weight,
            loss_ptr=loss_1d_slice,
            z_loss_ptr=z_loss_1d_slice,
            loss_stride=loss_1d_slice.stride(-1),  # always 1
            n_cols=V,
            n_non_ignore=total_n_non_ignore,
            sum_non_ignore_weight=total_sum_non_ignore_ce_weight,
            weight_sum=ce_weight_sum,
            ignore_index=ignore_index,
            lse_square_scale=lse_square_scale,
            label_smoothing=label_smoothing,
            reduction=reduction,
            softcap=softcap,
            RETURN_Z_LOSS=return_z_loss,
            HAS_WEIGHT=True if ce_weight is not None else False,
            HAS_SOFTCAPPING=True if softcap is not None else False,
            BLOCK_SIZE=BLOCK_SIZE,
            num_warps=32 if not is_hip() else 16,
        )

        loss_1d[start_idx:end_idx] = loss_1d_slice
        if return_z_loss:
            z_loss_1d[start_idx:end_idx] = z_loss_1d_slice
        # after the kernel, logits_chunk holds d(loss)/d(logits) for this chunk
        grad_logits_chunk = logits_chunk  # chunk_size x V

        grad_input[start_idx:end_idx] = grad_logits_chunk @ weight

        if grad_weight is not None:
            torch.addmm(
                input=grad_weight,
                mat1=logits_chunk.t().to(
                    _input_chunk.dtype
                ),  # In an autocast scenario without bias, differing logits_chunk data types will cause an addmm operation error.
                mat2=_input_chunk,
                out=grad_weight,
                alpha=1.0,
                beta=1.0,
            )

        if bias is not None:
            torch.add(
                input=grad_bias,
                other=logits_chunk.sum(dim=0),
                out=grad_bias,
                alpha=1.0,
            )

        # Need extra calculations for backward if reduction=='none'. Not supporting reduction='none' now.
        # if reduction == "none":
        #     loss = loss_1d
        #     z_loss = z_loss_1d if return_z_loss else None

    # NOTE(review): this `else` binds to the `for` loop above (for/else). The
    # loop has no `break`, so the block always executes; it reads like a
    # leftover from the commented-out `if reduction == "none"` — confirm
    # against upstream before restructuring.
    else:
        loss = torch.sum(loss_1d)
        z_loss = torch.sum(z_loss_1d) if return_z_loss else None
    return loss, z_loss, grad_input, grad_weight, grad_bias
def fused_linear_cross_entropy_backward(grad_output, grad_input, grad_weight, grad_bias):
    """Scale the gradients cached at forward time by the incoming ``grad_output``."""
    unit = torch.tensor(1.0, device=grad_output.device)
    # If cross entropy is the last layer, grad_output is exactly 1.0: the
    # cached gradients are already final, so skip the multiplies entirely.
    if torch.equal(grad_output, unit):
        return grad_input, grad_weight, grad_bias

    # We use a Triton kernel instead of a PyTorch operation because modifying
    # inputs in-place for gradient storage and backward multiple times causes
    # anomalies with PyTorch but not with Triton.
    num_tokens, hidden = grad_input.shape
    block = min(MAX_FUSED_SIZE, triton.next_power_of_2(hidden))

    element_mul_kernel[(num_tokens,)](
        grad_input,
        grad_input.stride(-2),
        grad_output,
        hidden,
        BLOCK_SIZE=block,
        num_warps=32 if not is_hip() else 16,
    )

    # handle grad_weight
    if grad_weight is not None:
        vocab, hidden = grad_weight.shape

        element_mul_kernel[(vocab,)](
            grad_weight,
            grad_weight.stride(-2),
            grad_output,
            hidden,
            BLOCK_SIZE=block,
            num_warps=32 if not is_hip() else 16,
        )

    if grad_bias is not None:
        vocab = grad_bias.shape[0]

        element_mul_kernel[(vocab,)](
            grad_bias,
            grad_bias.stride(-1),
            grad_output,
            1,
            BLOCK_SIZE=block,
            num_warps=32 if not is_hip() else 16,
        )

    return grad_input, grad_weight, grad_bias
class LigerFusedLinearCrossEntropyFunction(torch.autograd.Function):
    @staticmethod
    @amp_custom_fwd
    def forward(
        ctx,
        _input,
        weight,
        target,
        bias=None,
        ce_weight=None,
        ignore_index=-100,
        lse_square_scale=0.0,
        label_smoothing=0.0,
        reduction="mean",
        softcap=None,
        return_z_loss: bool = False,
    ):
        """
        Fusing the last linear layer with cross-entropy loss
        Reference: https://github.com/mgmalek/efficient_cross_entropy

        Handle the forward and backward pass of the final linear layer via cross-entropy loss by avoiding
        the materialization of the large logits tensor. Since Cross Entropy Loss is the last layer, we can
        compute the gradient at the forward pass. By doing so, we don't have to store the _input and target
        for the backward pass.

        _input: (B*T, H) where B is batch size, T is sequence length, H is hidden dimension.
        target: (B*T) where each value is in [0, V-1]
        weight: (V, H) where V is the number of classes
        bias: (V) where V is the number of classes
        ce_weight: a manual rescaling weight given to each class. If given, has to be a Tensor of size V and floating point dtype
        ignore_index: the index to ignore in the target
        label_smoothing (float): The amount of smoothing when computing the loss, where 0.0 means no smoothing.
        reduction: reduction to apply

        Returns (loss, z_loss); z_loss is None unless return_z_loss is True.
        """

        loss, z_loss, grad_input, grad_weight, grad_bias = fused_linear_cross_entropy_forward(
            _input=_input,
            weight=weight,
            target=target,
            bias=bias,
            ce_weight=ce_weight,
            ignore_index=ignore_index,
            lse_square_scale=lse_square_scale,
            label_smoothing=label_smoothing,
            reduction=reduction,
            softcap=softcap,
            return_z_loss=return_z_loss,
        )
        # downcast to dtype and store for backward
        ctx.save_for_backward(
            grad_input.detach(),
            grad_weight.detach() if grad_weight is not None else None,
            grad_bias.detach() if bias is not None else None,
        )
        ctx.return_z_loss = return_z_loss
        return loss, z_loss

    @staticmethod
    @amp_custom_bwd
    def backward(ctx, grad_output, grad_output2):
        # grad_output2 corresponds to z_loss, which is only for logging
        if ctx.return_z_loss:
            del grad_output2  # z_loss is only for logging
        (grad_input, grad_weight, grad_bias) = ctx.saved_tensors
        grad_input, grad_weight, grad_bias = fused_linear_cross_entropy_backward(
            grad_output, grad_input, grad_weight, grad_bias
        )
        # One gradient slot per forward argument (ctx excluded).
        return (
            grad_input,
            grad_weight,
            None,
            grad_bias,
            None,
            None,
            None,
            None,
            None,
            None,
            None,
        )
@triton.jit
def _geglu_tanh_forward_kernel(a, b, c, stride, n_cols: tl.constexpr, BLOCK_SIZE: tl.constexpr):
    """Row-wise GEGLU forward: c = GELU_tanh(a) * b, one program per row."""
    program_id = tl.program_id(0).to(tl.int64)

    # locate start index
    a += program_id * stride
    b += program_id * stride
    c += program_id * stride

    col_offsets = tl.arange(0, BLOCK_SIZE)
    mask = col_offsets < n_cols
    # compute the GELU branch in float32 for precision
    a_row = tl.load(a + col_offsets, mask=mask, other=0).to(tl.float32)
    b_row = tl.load(b + col_offsets, mask=mask, other=0)

    # tanh approximation form of GELU is computed with:
    # 0.5 * a * (1 + tanh(sqrt(2 / pi) * (a + 0.044715 * a^3)))
    sqrt_2_over_pi = 0.7978845608028654  # sqrt(2 / pi)
    a_cubed = a_row * a_row * a_row
    tanh_arg = sqrt_2_over_pi * (a_row + 0.044715 * a_cubed)
    tanh_result = tanh(tanh_arg)
    geglu_a = 0.5 * a_row * (1 + tanh_result)
    c_row = geglu_a * b_row
    tl.store(c + col_offsets, c_row, mask=mask)
tanh_result) + + db_row = dc_row * geglu_a + + # Gradient w.r.t. a can be computed with: + # b * (0.5 * (1 + tanh(z)) + 0.5 * a * (1 - tanh(z)^2) * (sqrt(2/pi) * (1 + 3 * 0.044715 * a^2))) + # where z = sqrt(2/pi) * (a + 0.044715 * a^3) + term1 = 0.5 * (1 + tanh_result) + tanh_sq = tanh_result * tanh_result + term2 = 0.5 * a_row * (1 - tanh_sq) * (sqrt_2_over_pi * (1 + 3 * 0.044715 * a_row * a_row)) + da_row = dc_row * b_row * (term1 + term2) + + tl.store(a + col_offsets, da_row, mask=mask) + tl.store(b + col_offsets, db_row, mask=mask) + + +def geglu_forward(a, b): + ori_shape = a.shape + + n_cols = ori_shape[-1] + a = a.view(-1, n_cols) + b = b.view(-1, n_cols) + c = torch.empty_like(a) + n_rows = a.shape[0] + + BLOCK_SIZE, num_warps = calculate_settings(n_cols) + + _geglu_tanh_forward_kernel[(n_rows,)]( + a, + b, + c, + c.stride(-2), + n_cols=n_cols, + BLOCK_SIZE=BLOCK_SIZE, + num_warps=num_warps, + ) + return a, b, c.view(*ori_shape) + + +def geglu_backward(a, b, dc): + ori_shape = dc.shape + n_cols = ori_shape[-1] + dc = dc.view(-1, n_cols) + n_rows = dc.shape[0] + + BLOCK_SIZE, num_warps = calculate_settings(n_cols) + + _geglu_tanh_backward_kernel[(n_rows,)]( + dc, + a, + b, + dc.stride(-2), + n_cols=n_cols, + BLOCK_SIZE=BLOCK_SIZE, + num_warps=num_warps, + ) + + return a.view(*ori_shape), b.view(*ori_shape) + + +class LigerGELUMulFunction(torch.autograd.Function): + @staticmethod + @ensure_contiguous + def forward(ctx, a, b): + a, b, c = geglu_forward(a, b) + ctx.save_for_backward(a, b) + return c + + @staticmethod + @ensure_contiguous + def backward(ctx, dc): + a, b = ctx.saved_tensors + a, b = geglu_backward(a, b, dc) + return a, b \ No newline at end of file diff --git a/build/torch-xpu/group_norm.py b/build/torch-xpu/group_norm.py new file mode 100644 index 0000000000000000000000000000000000000000..74b460c656e46d4544703345f9892763271588c2 --- /dev/null +++ b/build/torch-xpu/group_norm.py @@ -0,0 +1,305 @@ +import operator + +import torch +import triton 
+import triton.language as tl + +from .utils import compare_version +from .utils import ensure_contiguous + +if compare_version("triton", operator.ge, "3.0.0"): + try: + # typical import path with dispatch available + from triton.language.extra.libdevice import rsqrt + except ModuleNotFoundError: + # for working with NGC containers + from triton.language.extra.cuda.libdevice import rsqrt +else: + from triton.language.math import rsqrt + +MAX_FUSED_SIZE = 65536 + + +@triton.jit +def _group_norm_forward_kernel( + Y_ptr, # pointer to output, shape (n_rows, n_groups, hidden_size) + Y_row_stride, # stride of each row in output + Y_col_stride, # stride of each column in output + X_ptr, # pointer to input, shape (n_rows, n_groups, hidden_size) + X_row_stride, # stride of each row in input + X_col_stride, # stride of each column in input + Mean_ptr, # pointer to mean, shape (n_rows, n_groups) + Mean_row_stride, # stride of each row in mean + Mean_col_stride, # stride of each column in mean + RSTD_ptr, # pointer to rstd, shape (n_rows, n_groups) + RSTD_row_stride, # stride of each row in rstd + RSTD_col_stride, # stride of each column in rstd + W_ptr, # pointer to W + B_ptr, # pointer to B + hidden_size, # hidden size of X + channels_per_group, # the number of channels per group + eps, + BLOCK_SIZE: tl.constexpr, +): + """ + References: + https://nn.labml.ai/normalization/group_norm/index.html + """ + batch_idx = tl.program_id(0) + group_idx = tl.program_id(1) + + X_ptr += batch_idx * X_row_stride + group_idx * X_col_stride + Y_ptr += batch_idx * Y_row_stride + group_idx * Y_col_stride + + block_range = tl.arange(0, BLOCK_SIZE) + + # Compute mean and variance using the online algorithm + s = 0.0 + squared_sum = 0.0 + for i in tl.range(0, hidden_size, BLOCK_SIZE): + hidden_size_offsets = i + block_range + mask = hidden_size_offsets < hidden_size + X = tl.load(X_ptr + hidden_size_offsets, mask=mask, other=0.0) + s += tl.sum(X) + # X**2 + squared_sum += tl.sum(X * X) + + m = s 
/ hidden_size + + # variance = E[X**2] - E[X]**2 + variance = (squared_sum / hidden_size) - (m * m) + + # 1/std + rstd = rsqrt(variance + eps) + + # Normalize + hidden_size_per_channel = hidden_size // channels_per_group + for channel_idx in tl.range(group_idx * channels_per_group, (group_idx + 1) * channels_per_group): + W = tl.load(W_ptr + channel_idx) + B = tl.load(B_ptr + channel_idx) + for i in range(0, hidden_size_per_channel, BLOCK_SIZE): + hidden_size_offsets = i + block_range + mask = hidden_size_offsets < hidden_size_per_channel + X = tl.load(X_ptr + hidden_size_offsets, mask=mask, other=m) + Y = (X - m) * rstd * W + B + tl.store(Y_ptr + hidden_size_offsets, Y, mask=mask) + + X_ptr += hidden_size_per_channel + Y_ptr += hidden_size_per_channel + + tl.store(Mean_ptr + batch_idx * Mean_row_stride + group_idx * Mean_col_stride, m) + tl.store(RSTD_ptr + batch_idx * RSTD_row_stride + group_idx * RSTD_col_stride, rstd) + + +@triton.jit +def _group_norm_backward_kernel( + X_ptr, # pointer to input, shape (n_rows, n_channels, hidden_size) + X_row_stride, # stride of each row in input + X_col_stride, # stride of each column in input + W_ptr, # pointer to weights, shape (n_channels) + Mean_ptr, # pointer to mean, shape (n_rows, n_groups) + Mean_ptr_row_stride, # stride of each column in mean + Mean_ptr_col_stride, # stride of each column in mean + RSTD_ptr, # pointer to rstd, shape (n_rows, n_groups) + DX_ptr, # pointer to input grad, shape (n_rows, n_groups, hidden_size) + DW_ptr, # pointer to weights grad, shape (n_channels) + DB_ptr, # pointer to bias grad, shape (n_channels) + UPSTREAM_ptr, # pointer to output grad, shape (n_rows, n_channels, hidden_size) + hidden_size: tl.constexpr, # hidden size + channels_per_group: tl.constexpr, # number of groups in group norm + BLOCK_SIZE: tl.constexpr, + dtype: tl.constexpr, +): + """ + References: + https://nn.labml.ai/normalization/group_norm/index.html + 
https://github.com/karpathy/llm.c/blob/master/doc/layernorm/layernorm.md + + The backprop equations are the same for group_norm and layer_norm + the only difference here is that we load the Mean, Rstd corresponding to the + group we're computing gradients for and the mean and rstd are computed over n-channels + so the total number of elements we compute the mean over is num_channels_per_group * hidden_size + + We also need to load the Weights corresponding to the current channel to compute the gradients. + """ + batch_idx = tl.program_id(0) + group_idx = tl.program_id(1) + + # Move the pointers to the correct batch + X_ptr += batch_idx * X_row_stride + DX_ptr += batch_idx * X_row_stride + UPSTREAM_ptr += batch_idx * X_row_stride + + # Mean and rstd are the same shape so have the same strides + mean = tl.load(Mean_ptr + batch_idx * Mean_ptr_row_stride + group_idx * Mean_ptr_col_stride) + rstd = tl.load(RSTD_ptr + batch_idx * Mean_ptr_row_stride + group_idx * Mean_ptr_col_stride) + + c1 = 0.0 + c2 = 0.0 + block_range = tl.arange(0, BLOCK_SIZE) + + # We need to compute the sum terms of the backprop equations across all channels in the group + for channel_idx in range(group_idx * channels_per_group, (group_idx + 1) * channels_per_group): + dW = 0.0 + dB = 0.0 + # Move the pointers to the correct channel + W = tl.load(W_ptr + channel_idx) + for i in tl.range(0, hidden_size, BLOCK_SIZE): + hidden_size_offsets = i + block_range + mask = hidden_size_offsets < hidden_size + X = tl.load( + X_ptr + channel_idx * X_col_stride + hidden_size_offsets, + mask=mask, + other=0.0, + ) + UPSTREAM_grad = tl.load( + UPSTREAM_ptr + channel_idx * X_col_stride + hidden_size_offsets, + mask=mask, + other=0.0, + ) + + x_hat = (X - mean) * rstd + dW += tl.sum(UPSTREAM_grad * x_hat) + dB += tl.sum(UPSTREAM_grad) + + wdy = W * UPSTREAM_grad + c1 += tl.sum(x_hat * wdy) + c2 += tl.sum(wdy) + + # Need to ensure additions to the same channel are atomic + tl.atomic_add(DW_ptr + channel_idx, 
dW.to(dtype)) + tl.atomic_add(DB_ptr + channel_idx, dB.to(dtype)) + + N = hidden_size * channels_per_group + c1 = c1 / N + c2 = c2 / N + + for channel_idx in tl.range(group_idx * channels_per_group, (group_idx + 1) * channels_per_group): + # Move the pointers to the correct channel + W = tl.load(W_ptr + channel_idx) + for i in range(0, hidden_size, BLOCK_SIZE): + hidden_size_offsets = i + block_range + mask = hidden_size_offsets < hidden_size + X = tl.load( + X_ptr + channel_idx * X_col_stride + hidden_size_offsets, + mask=mask, + other=0.0, + ) + UPSTREAM_grad = tl.load( + UPSTREAM_ptr + channel_idx * X_col_stride + hidden_size_offsets, + mask=mask, + other=0.0, + ) + + x_hat = (X - mean) * rstd + wdy = W * UPSTREAM_grad + dx = (wdy - (x_hat * c1 + c2)) * rstd + tl.store(DX_ptr + channel_idx * X_col_stride + hidden_size_offsets, dx, mask=mask) + + +def group_norm_forward(X, num_channels, num_groups, W, B, eps): + shape = X.shape + batch_size = shape[0] + channels_per_group = num_channels // num_groups + # Reshape X so that the mean and std are computed across the groups + X = X.view(batch_size, num_groups, -1).contiguous() + hidden_size = X.shape[-1] + BLOCK_SIZE = min(MAX_FUSED_SIZE, triton.next_power_of_2(hidden_size)) + Y = torch.empty((batch_size, num_groups, hidden_size), dtype=X.dtype, device=X.device) + Mean = torch.zeros((batch_size, num_groups), dtype=X.dtype, device=X.device) + RSTD = torch.zeros((batch_size, num_groups), dtype=X.dtype, device=X.device) + + _group_norm_forward_kernel[(batch_size, num_groups)]( + Y, + Y.stride(0), + Y.stride(1), + X, + X.stride(0), + X.stride(1), + Mean, + Mean.stride(0), + Mean.stride(1), + RSTD, + RSTD.stride(0), + RSTD.stride(1), + W, + B, + hidden_size, + channels_per_group, + eps, + BLOCK_SIZE=BLOCK_SIZE, + ) + # Return tensors in the original shape + return Y.view(*shape), X.view(*shape), Mean, RSTD, BLOCK_SIZE + + +def group_norm_backward(dY, X, W, B, Mean, RSTD, num_channels, num_groups): + shape = dY.shape + 
batch_size = shape[0] + hidden_size = dY.shape[-1] + channels_per_group = num_channels // num_groups + dY = dY.view(batch_size, num_groups, -1) + DX = torch.empty( + (batch_size, num_groups, hidden_size * channels_per_group), + dtype=X.dtype, + device=X.device, + ) + DW = torch.zeros((num_channels), dtype=W.dtype, device=W.device) + DB = torch.zeros((num_channels), dtype=B.dtype, device=B.device) + triton_dtype = tl.float32 if X.dtype == torch.float32 else tl.bfloat16 + + BLOCK_SIZE = min(MAX_FUSED_SIZE, triton.next_power_of_2(hidden_size)) + _group_norm_backward_kernel[(batch_size, num_groups)]( + X, + X.stride(0), + X.stride(1), + W, + Mean, + Mean.stride(0), + Mean.stride(1), + RSTD, + DX, + DW, + DB, + dY, + hidden_size, + channels_per_group, + BLOCK_SIZE=BLOCK_SIZE, + dtype=triton_dtype, + ) + + # Return tensors in the original shape + return DX.view(*shape), DW, DB + + +class LigerGroupNormFunction(torch.autograd.Function): + @staticmethod + @ensure_contiguous + def forward( + ctx, + X, + affine_scaling_weight, + affine_shifting_bias, + num_channels, + num_groups, + eps, + ): + Y, X, Mean, RSTD, BLOCK_SIZE = group_norm_forward( + X, + num_channels, + num_groups, + affine_scaling_weight, + affine_shifting_bias, + eps, + ) + ctx.num_channels = num_channels + ctx.num_groups = num_groups + ctx.save_for_backward(X, affine_scaling_weight, affine_shifting_bias, Mean, RSTD) + return Y + + @staticmethod + @ensure_contiguous + def backward(ctx, dY): + X, W, B, Mean, RSTD = ctx.saved_tensors + DX, DW, DB = group_norm_backward(dY, X, W, B, Mean, RSTD, ctx.num_channels, ctx.num_groups) + return DX, DW, DB, None, None, None \ No newline at end of file diff --git a/build/torch-xpu/jsd.py b/build/torch-xpu/jsd.py new file mode 100644 index 0000000000000000000000000000000000000000..b879e0674471983e789f9b05d3111b51fb172671 --- /dev/null +++ b/build/torch-xpu/jsd.py @@ -0,0 +1,201 @@ +from typing import Optional + +import torch +import triton +import triton.language as tl + 
+from .utils import ensure_contiguous +from .utils import infer_device + + +@triton.jit +def _jsd_kernel( + X_ptr, # input in logspace, X = log Q + X_stride, + Y_ptr, # ground truth in logspace, Y = log P + Y_stride, + loss_ptr, + loss_stride, + dX_ptr, + dX_stride, + label_ptr, + beta: tl.constexpr, + n_non_ignore: int, + ignore_index: tl.constexpr, + n_cols, + BLOCK_SIZE: tl.constexpr, + HAS_LABEL: tl.constexpr, +): + # JSD(P || Q) = (KL(P || M) + KL(Q || M)) / 2, M = (1/2) * (P + Q) = (1/2) * (e ^ Y + e ^ X) + # = sum(P * log P + Q * log Q - 2 * M * log M) / 2 + # = sum(e ^ Y * Y + e ^ X * X - 2 * M * log M) / 2 + # grad_x_i = 0.5 * Q * (X - log_M) + pid = tl.program_id(0).to(tl.int64) + X_ptr += pid * X_stride + dX_ptr += pid * dX_stride + Y_ptr += pid * Y_stride + loss_ptr += pid * loss_stride + label_ptr += pid + + if HAS_LABEL: + label = tl.load(label_ptr) + if label == ignore_index: + for i in range(0, n_cols, BLOCK_SIZE): + offsets = i + tl.arange(0, BLOCK_SIZE) + tl.store(dX_ptr + offsets, 0.0, mask=offsets < n_cols) + return + + for i in range(0, n_cols, BLOCK_SIZE): + offsets = i + tl.arange(0, BLOCK_SIZE) + mask = offsets < n_cols + X = tl.load(X_ptr + offsets, mask=mask, other=float("-inf")).to(tl.float32) + Y = tl.load(Y_ptr + offsets, mask=mask, other=float("-inf")).to(tl.float32) + + if beta == 0.0: # forward KL + Y_max = tl.max(Y, axis=0) + Y_shifted = Y - Y_max + Y_prob = tl.exp(Y_shifted) * tl.exp(Y_max) # Compensate for the shift + loss = Y_prob * (Y - X) + dX = -Y_prob + elif beta == 1.0: # reverse KL + X_max = tl.max(X, axis=0) + X_shifted = X - X_max + X_prob = tl.exp(X_shifted) * tl.exp(X_max) # Compensate for the shift + loss = X_prob * (X - Y) + dX = loss + X_prob + else: + max_val = tl.maximum(tl.max(X, axis=0), tl.max(Y, axis=0)) + X_shifted = X - max_val + Y_shifted = Y - max_val + + # Pre-compute exp(max_val) since it's used twice + exp_max = tl.exp(max_val) + + # Compute exp terms with compensation + Q = tl.exp(X_shifted) * exp_max # 
= exp(X) + P = tl.exp(Y_shifted) * exp_max # = exp(Y) + + # Pre-compute common terms + beta_P = beta * P + one_minus_beta_Q = (1 - beta) * Q + M = beta_P + one_minus_beta_Q + log_M = tl.log(M) # No need to compensate as M is already in original scale + + loss = beta_P * Y + one_minus_beta_Q * X - M * log_M + dX = one_minus_beta_Q * (X - log_M) + + # Pre-compute scaling factor + scale = 1.0 / n_non_ignore + loss = loss * scale + dX = dX * scale + + tl.store(loss_ptr + offsets, loss, mask=mask) + tl.store(dX_ptr + offsets, dX, mask=mask) + + +MAX_FUSED_SIZE = 4096 if infer_device() == "xpu" else 65536 + + +def jsd_forward(_input, target, shift_labels, beta, ignore_index, has_label): + BT, V = _input.shape + n_rows = BT + BLOCK_SIZE = min(MAX_FUSED_SIZE, triton.next_power_of_2(V)) + # non reduction loss + loss = torch.zeros(_input.shape, dtype=torch.float32, device=_input.device) + dX = torch.empty_like(_input) + + if has_label: + n_non_ignore = (shift_labels != ignore_index).sum().item() + else: + n_non_ignore = BT + + _jsd_kernel[(n_rows,)]( + X_ptr=_input, # input in logspace, X = log Q + X_stride=_input.stride(-2), + Y_ptr=target, # ground truth in logspace, Y = log P + Y_stride=target.stride(-2), + loss_ptr=loss, + loss_stride=loss.stride(-2), + dX_ptr=dX, + dX_stride=dX.stride(-2), + label_ptr=(shift_labels if has_label else torch.empty(1, device=_input.device)), # dummy ptr if no label + beta=beta, + n_non_ignore=n_non_ignore, + ignore_index=ignore_index, + n_cols=V, + BLOCK_SIZE=BLOCK_SIZE, + HAS_LABEL=has_label, + ) + + loss = torch.sum(loss) + return loss.to(_input.dtype), dX + + +def jsd_backward(dX, grad_output): + # If jsd is the last layer, grad_output is 1.0. 
Skip the mul to save time + if torch.equal(grad_output, torch.tensor(1.0, device=grad_output.device)): + return dX + else: + return grad_output * dX + + +class LigerJSDFunction(torch.autograd.Function): + r""" + This class implements the forward and backward pass for the generalized Jensen-Shannon Divergence. + .. math:: + JSD(\beta)(P || Q) + = \beta * KLDiv(P || (\beta * P + (1 - \beta) * Q)) + (1 - \beta) * KLDiv(Q || (\beta * P + (1 - \beta) * Q)) + + .. note:: + As all the other losses in PyTorch, this function expects the first argument, + :attr:`_input`, to be the predictions, the output of the student model, in log-space + and the second, :attr:`target`, to be the observations, the output of the teacher model, in log-space. + This differs from the standard mathematical notation :math:`JSD(P || Q)` where + :math:`P` denotes the teacher model and :math:`Q` denotes the student model. + """ + + @staticmethod + @ensure_contiguous + def forward( + ctx, + _input: torch.Tensor, + target: torch.Tensor, + shift_labels: Optional[torch.Tensor] = None, + beta: float = 0.5, + ignore_index: int = -100, + ) -> torch.Tensor: + """ + Args: + _input (torch.Tensor): predict values with shape (BT, V) in logspace + target (torch.Tensor): ground truth values with shape (BT, V) in logspace + shift_labels (Optional[torch.LongTensor]): indicator of next predicted vocab with shape (BT) where each value is in [0, V-1]. + beta (float): coefficient beta of generalized JSD in the interval [0, 1]. It implements forward/reverse KL when beta equals 0 and 1 respectively. Default: `0.5` + ignore_index (int): the index to ignore. Default: -100 + + Returns: + loss (torch.Tensor): generalized JSD + """ + has_label = False + if shift_labels is not None: + assert shift_labels.shape == (_input.shape[0],), ( + f"the shape of shift_labels must be (BT,). 
Got: {shift_labels.shape}" + ) + shift_labels = shift_labels.contiguous() + has_label = True + + loss, dX = jsd_forward(_input, target, shift_labels, beta, ignore_index, has_label) + ctx.save_for_backward(dX) + return loss + + @staticmethod + @ensure_contiguous + def backward(ctx, grad_output: torch.Tensor) -> torch.Tensor: + (dX,) = ctx.saved_tensors + dX = jsd_backward(dX, grad_output) + return ( + dX, + None, + None, + None, + None, + ) \ No newline at end of file diff --git a/build/torch-xpu/kl_div.py b/build/torch-xpu/kl_div.py new file mode 100644 index 0000000000000000000000000000000000000000..2d563a7eae13aac671e3cf85283758b50a3fdb93 --- /dev/null +++ b/build/torch-xpu/kl_div.py @@ -0,0 +1,262 @@ +from typing import Literal + +import torch +import triton +import triton.language as tl + +from .utils import ensure_contiguous +from .utils import is_hip +from .utils import infer_device + + +def get_num_warps(BLOCK_SIZE): + num_warps = 4 + if BLOCK_SIZE >= 32768: + num_warps = 32 if not is_hip() else 16 + elif BLOCK_SIZE >= 8192: + num_warps = 16 + elif BLOCK_SIZE >= 2048: + num_warps = 8 + + return num_warps + + +MAX_FUSED_SIZE = 65536 // 4 # 65536 // 4 or 8 works the best + +REDUCTION_LITERAL = Literal["none", "sum", "mean", "batchmean"] + +_REDUCTION_MODE_NONE: tl.constexpr = tl.constexpr(0) +_REDUCTION_MODE_SUM: tl.constexpr = tl.constexpr(1) +_REDUCTION_MODE_MEAN: tl.constexpr = tl.constexpr(2) +_REDUCTION_MODE_BATCHMEAN: tl.constexpr = tl.constexpr(3) + +_str_to_reduction_mode = { + "none": _REDUCTION_MODE_NONE.value, + "sum": _REDUCTION_MODE_SUM.value, + "mean": _REDUCTION_MODE_MEAN.value, + "batchmean": _REDUCTION_MODE_BATCHMEAN.value, +} + + +@triton.jit +def _kldiv_kernel_forward( + y_ptr, # [B, S], prediction ptr, the kernel expects the prediction in log-space + y_stride, # int, prediction stride + gt_ptr, # [B, S], ground truth ptr + gt_stride, # int, ground truth stride + loss_ptr, # [B] or [B, S] if reduction == _REDUCTION_MODE_NONE, output ptr + 
loss_stride, # int, output stride + n_cols, # int, number of columns in the input tensor + eps, + BLOCK_SIZE: tl.constexpr, + log_target: tl.constexpr = False, + reduction: tl.constexpr = _REDUCTION_MODE_BATCHMEAN, +): + pid = tl.program_id(0).to(tl.int64) + y_ptr += pid * y_stride + gt_ptr += pid * gt_stride + loss_ptr += pid * loss_stride + + base_offsets = tl.arange(0, BLOCK_SIZE) + + loss_sum = 0.0 + for i in range(0, n_cols, BLOCK_SIZE): + offsets = i + base_offsets + mask = offsets < n_cols + y = tl.load(y_ptr + offsets, mask=mask, other=0.0) + y_true = tl.load(gt_ptr + offsets, mask=mask, other=0.0) + + # KL(y_true || y) = y_true * (log(y_true) - log(y)) + # We compute KL(y_true || y) with y in the log-space + if not log_target: + loss = y_true * (tl.log(tl.maximum(y_true, eps)) - y) + else: + loss = tl.exp(y_true) * (y_true - y) + + if reduction == _REDUCTION_MODE_NONE: + tl.store(loss_ptr + offsets, loss, mask=mask) + else: + loss_sum += tl.sum(loss, axis=0) + + if reduction != _REDUCTION_MODE_NONE: + tl.store(loss_ptr, loss_sum) + + +@triton.jit +def _kldiv_kernel_backward( + target_ptr, + target_stride, + new_grads_ptr, + new_grads_stride, + n_cols, + BLOCK_SIZE: tl.constexpr, + log_target: tl.constexpr = False, +): + pid = tl.program_id(0).to(tl.int64) + + target_ptr += pid * target_stride + new_grads_ptr += pid * new_grads_stride + + offsets = tl.arange(0, BLOCK_SIZE) + mask = offsets < n_cols + + for i in range(0, n_cols, BLOCK_SIZE): + offsets = i + tl.arange(0, BLOCK_SIZE) + mask = offsets < n_cols + + target = tl.load(target_ptr + offsets, mask=mask, other=0.0) + + if not log_target: + res = target * -1 + else: + res = -tl.exp(target) + + tl.store(new_grads_ptr + offsets, res, mask=mask) + + +def kldiv_forward_triton(y_pred, y_true, log_target, reduction, eps): # [BT, V] + BT, V = y_pred.shape + BLOCK_SIZE = ( + min(8192, triton.next_power_of_2(V)) + if infer_device() == "xpu" + else min(MAX_FUSED_SIZE, triton.next_power_of_2(V)) + ) + num_warps = 
32 if infer_device() == "xpu" else get_num_warps(BLOCK_SIZE) + + grid = (BT,) + reduction = _str_to_reduction_mode[reduction] + + out_size = (BT, V) if reduction == _REDUCTION_MODE_NONE.value else (BT,) + output_tensor = torch.zeros(out_size, device=y_pred.device, dtype=torch.float32) + + _kldiv_kernel_forward[grid]( + y_pred, + y_pred.stride(0), + y_true, + y_true.stride(0), + output_tensor, + output_tensor.stride(0), + V, + eps=eps, + BLOCK_SIZE=BLOCK_SIZE, + num_warps=num_warps, + log_target=log_target, + reduction=reduction, + ) + + # calculated according to the reduction mode same as in Pytorch. In the later versions, `mean` will be changed to the same behavior as `batchmean` + # https://pytorch.org/docs/stable/generated/torch.nn.KLDivLoss.html + # https://github.com/pytorch/pytorch/blob/d7b57c4d63edb42e1deeeba9497fcb5f1f748ff2/torch/nn/functional.py#L3372 + if reduction == _REDUCTION_MODE_BATCHMEAN.value: + return output_tensor.sum() / BT + elif reduction == _REDUCTION_MODE_SUM.value: + return output_tensor.sum(dim=0) + elif reduction == _REDUCTION_MODE_MEAN.value: + return output_tensor.sum() / (BT * V) + else: + return output_tensor + + +def kldiv_backward_triton(target, grad_output, new_grads, log_target): + BT, V = target.shape + BLOCK_SIZE = ( + min(8192, triton.next_power_of_2(V)) + if infer_device() == "xpu" + else min(MAX_FUSED_SIZE, triton.next_power_of_2(V)) + ) + num_warps = 32 if infer_device() == "xpu" else get_num_warps(BLOCK_SIZE) + + grid = (BT,) + + # We store the gradients in-place in the input tensor + _kldiv_kernel_backward[grid]( + target, + target.stride(0), + new_grads, + new_grads.stride(0), + V, + BLOCK_SIZE=BLOCK_SIZE, + num_warps=num_warps, + log_target=log_target, + ) + + # If cross entropy is the last layer, grad_output is 1.0. Skip the mul then. 
+ if torch.equal(grad_output, torch.tensor(1.0, device=grad_output.device)): + return new_grads + + return new_grads * grad_output + + +class LigerKLDivLossFunction(torch.autograd.Function): + """ + Class implementing the forward and backward pass for the KL Divergence Loss using Triton, as defined by the following formula: + ```python + if log_target: + loss = target.exp() * (target - input) + else: + loss = target * (target.log() - input) + ```, + then the loss is reduced according to the `reduction` parameter. + as defined in the PyTorch documentation: https://pytorch.org/docs/stable/generated/torch.nn.KLDivLoss.html + """ + + @staticmethod + @ensure_contiguous + def forward( + ctx, + y_pred: torch.Tensor, + y_true: torch.Tensor, + reduction: REDUCTION_LITERAL = "batchmean", + log_target: bool = False, + eps: float = 1e-10, + ) -> torch.Tensor: + """A forward pass for the KL Divergence Loss. + + Args: + ctx: Torch autograd context + y_pred (torch.Tensor): A tensor of shape (BT, V) containing the predicted values, expected to be log-probabilities. + y_true (torch.Tensor): A tensor of shape (BT, V) containing the target values, expected to be either probabilities or log-probabilities, depending on the value of `log_target`. + reduction (REDUCTION_LITERAL, optional): Reduction to be used. Defaults to "batchmean". + log_target (bool, optional): If set to true, expects the ground truth to already be log-probabilities. Defaults to False. + eps: (float, optional): A small value to avoid division by zero. Defaults to 1e-10. + + Returns: + torch.Tensor: The computed KL Divergence Loss, with shape (BT, V) if `reduction` is "none", else a scalar. 
+ """ + ctx.save_for_backward(y_true) + ctx.reduction = reduction + ctx.log_target = log_target + return kldiv_forward_triton(y_pred, y_true, log_target=log_target, reduction=reduction, eps=eps) + + @staticmethod + @ensure_contiguous + def backward(ctx, grad_output: torch.Tensor) -> torch.Tensor: + """A backward pass for the KL Divergence Loss. + + Args: + ctx: Torch autograd context + grad_output (torch.Tensor): The gradient of the loss with respect to the output. + + Returns: + tuple[torch.Tensor, None, None, None, None]: The gradient of the loss with respect to the inputs and None for the other arguments of the forward method. + """ + (y_true,) = ctx.saved_tensors + + new_grads = torch.empty_like(y_true) + + derivative = kldiv_backward_triton(y_true, grad_output, new_grads, ctx.log_target) + + if ctx.reduction == "batchmean": + derivative = derivative / y_true.shape[0] + elif ctx.reduction == "sum" or ctx.reduction == "none": + pass + elif ctx.reduction == "mean": + derivative = derivative / (y_true.shape[0] * y_true.shape[1]) + + return ( + derivative, + None, + None, + None, + None, + ) \ No newline at end of file diff --git a/build/torch-xpu/layer_norm.py b/build/torch-xpu/layer_norm.py new file mode 100644 index 0000000000000000000000000000000000000000..00088223b1c58a5483c4227a2dc6be7f26cc1c3c --- /dev/null +++ b/build/torch-xpu/layer_norm.py @@ -0,0 +1,265 @@ +import math +import operator + +import torch +import triton +import triton.language as tl + +from .utils import calculate_settings +from .utils import compare_version +from .utils import ensure_contiguous + +if compare_version("triton", operator.ge, "3.0.0"): + try: + # typical import path with dispatch available + from triton.language.extra.libdevice import rsqrt + except ModuleNotFoundError: + # for working with NGC containers + from triton.language.extra.cuda.libdevice import rsqrt +else: + from triton.language.math import rsqrt + + +@triton.jit +def _layer_norm_forward_kernel( + Y_ptr, # pointer 
to output, shape (n_rows, n_cols) + Y_row_stride, # stride of each row in output + X_ptr, # pointer to input, shape (n_rows, n_cols) + X_row_stride, # stride of each row in input + W_ptr, # pointer to weights, shape (n_cols,) + W_row_stride, # stride of each row in weights + B_ptr, # pointer to bias, shape (n_cols,) + B_row_stride, # stride of each row in bias + Mean_ptr, # pointer to mean, shape (n_rows,) + Mean_row_stride, # stride of each row in mean + RSTD_ptr, # pointer to rstd, shape (n_rows,) + RSTD_row_stride, # stride of each row in rstd + n_cols, + eps, + BLOCK_SIZE: tl.constexpr, +): + """ + References: + https://arxiv.org/abs/1607.06450 + https://github.com/karpathy/llm.c/blob/master/doc/layernorm/layernorm.md + """ + row_idx = tl.program_id(0) + col_offsets = tl.arange(0, BLOCK_SIZE) + mask = col_offsets < n_cols + + Y_ptr += row_idx * Y_row_stride + X_ptr += row_idx * X_row_stride + Mean_ptr += row_idx * Mean_row_stride + RSTD_ptr += row_idx * RSTD_row_stride + + X_row = tl.load(X_ptr + col_offsets, mask=mask, other=0) + W_row = tl.load(W_ptr + col_offsets, mask=mask, other=0) + B_row = tl.load(B_ptr + col_offsets, mask=mask, other=0) + + mean = tl.sum(X_row, axis=0) / n_cols + Xmm = tl.where(mask, X_row - mean, 0) + var = tl.sum(Xmm * Xmm, axis=0) / n_cols + rstd = rsqrt(var + eps) + + tl.store(Mean_ptr, mean) + tl.store(RSTD_ptr, rstd) + + Y_row = Xmm * rstd * W_row + B_row + + tl.store(Y_ptr + col_offsets, Y_row, mask=mask) + + +@triton.jit +def _layer_norm_backward_kernel( + X_ptr, # pointer to input, shape (n_rows, n_cols) + W_ptr, # pointer to weights, shape (n_cols,) + Mean_ptr, # pointer to mean, shape (n_rows,) + RSTD_ptr, # pointer to rstd, shape (n_rows,) + DX_ptr, # pointer to input grad, shape (n_rows, n_cols) + DW_ptr, # pointer to weights grad, shape (n_cols,) + DB_ptr, # pointer to bias grad, shape (n_cols,) + DY_ptr, # pointer to output grad, shape (n_rows, n_cols) + stride_x, # stride of each row in input + stride_dx, # stride of 
each row in input grad + stride_dw, # stride of each row in weights grad + stride_db, # stride of each row in bias grad + stride_dy, # stride of each row in output grad + n_rows, + n_cols, + rows_per_program: tl.constexpr, + BLOCK_SIZE: tl.constexpr, + dtype: tl.constexpr, +): + """ + References: + https://arxiv.org/abs/1607.06450 + https://github.com/karpathy/llm.c/blob/master/doc/layernorm/layernorm.md + https://triton-lang.org/main/getting-started/tutorials/05-layer-norm.html + https://github.com/Dao-AILab/flash-attention/blob/main/flash_attn/ops/triton/layer_norm.py + """ + row_block_id = tl.program_id(0) + row_start = row_block_id * rows_per_program + row_end = min((row_block_id + 1) * rows_per_program, n_rows) + cols = tl.arange(0, BLOCK_SIZE) + mask = cols < n_cols + + dw_row = tl.zeros((BLOCK_SIZE,), dtype=tl.float32) + db_row = tl.zeros((BLOCK_SIZE,), dtype=tl.float32) + + X_ptr += row_start * stride_x + Mean_ptr += row_start + RSTD_ptr += row_start + DX_ptr += row_start * stride_dx + DY_ptr += row_start * stride_dy + + for _ in range(row_start, row_end): + x = tl.load(X_ptr + cols, mask=mask, other=0.0) + w = tl.load(W_ptr + cols, mask=mask, other=0.0) + dy = tl.load(DY_ptr + cols, mask=mask, other=0.0) + mean = tl.load(Mean_ptr) + rstd = tl.load(RSTD_ptr) + + x_hat = (x - mean) * rstd + wdy = w * dy + c1 = tl.sum(x_hat * wdy, axis=0) / n_cols + c2 = tl.sum(wdy, axis=0) / n_cols + dx = (wdy - (x_hat * c1 + c2)) * rstd + tl.store(DX_ptr + cols, dx.to(dtype), mask=mask) + + dw_row += dy * x_hat + db_row += dy + + X_ptr += stride_x + Mean_ptr += 1 + RSTD_ptr += 1 + DX_ptr += stride_dx + DY_ptr += stride_dy + + tl.store(DW_ptr + row_block_id * stride_dw + cols, dw_row.to(dtype), mask=mask) + tl.store(DB_ptr + row_block_id * stride_db + cols, db_row.to(dtype), mask=mask) + + +def layer_norm_forward(X, W, B, eps): + shape = X.shape + dim = shape[-1] + X = X.view(-1, dim) + n_rows, n_cols = X.shape + BLOCK_SIZE, num_warps = calculate_settings(n_cols) + Y = 
torch.empty((n_rows, n_cols), dtype=X.dtype, device=X.device) + Mean = torch.empty(n_rows, dtype=X.dtype, device=X.device) + RSTD = torch.empty(n_rows, dtype=X.dtype, device=X.device) + if X.shape[1] != W.shape[0]: + raise ValueError( + f"Incompatible dimensions: input feature size (X.shape[1]={X.shape[1]}) " + f"must match weight size (W.shape[0]={W.shape[0]})" + ) + + # XPU-specific optimization + kernel_args = {} + if X.device.type == "xpu": + kernel_args["grf_mode"] = "large" + + _layer_norm_forward_kernel[(n_rows,)]( + Y, + Y.stride(0), + X, + X.stride(0), + W, + W.stride(0), + B, + B.stride(0), + Mean, + Mean.stride(0), + RSTD, + RSTD.stride(0), + n_cols, + eps, + BLOCK_SIZE=BLOCK_SIZE, + num_warps=num_warps, + **kernel_args, # XPU-specific optimization + ) + return Y.view(*shape), X, Mean, RSTD, BLOCK_SIZE, num_warps + + +def layer_norm_backward(dY, X, W, B, Mean, RSTD): + shape = dY.shape + dim = shape[-1] + dY = dY.view(-1, dim) + n_rows, n_cols = dY.shape + + sm_count = 1 + if X.device.type == "cuda": + sm_count = torch.cuda.get_device_properties(X.device).multi_processor_count + elif X.device.type == "xpu": + sm_count = torch.xpu.get_device_properties(X.device).gpu_eu_count + + DX = torch.empty((n_rows, n_cols), dtype=X.dtype, device=X.device) + _DW = torch.empty((sm_count, n_cols), dtype=W.dtype, device=W.device) + _DB = torch.empty((sm_count, n_cols), dtype=W.dtype, device=W.device) + + BLOCK_SIZE, num_warps = calculate_settings(n_cols) + if n_cols > BLOCK_SIZE: + raise RuntimeError( + f"Feature dimension {n_cols} exceeds maximum supported size of {BLOCK_SIZE}. Consider using a smaller feature dimension." 
+ ) + + rows_per_program = math.ceil(n_rows / sm_count) + grid = (sm_count,) + triton_dtype = ( + tl.float32 + if X.dtype == torch.float32 + else tl.bfloat16 + if X.dtype == torch.bfloat16 + else tl.float16 + if X.dtype == torch.float16 + else tl.float32 # fallback to float32 for other types + ) + + # XPU-specific optimization + kernel_args = {} + if X.device.type == "xpu": + kernel_args.update({"grf_mode": "large", "num_warps": 32, "num_stages": 4}) + + _layer_norm_backward_kernel[grid]( + X, + W, + Mean, + RSTD, + DX, + _DW, + _DB, + dY, + X.stride(0), + DX.stride(0), + _DW.stride(0), + _DB.stride(0), + dY.stride(0), + n_rows, + n_cols, + rows_per_program, + BLOCK_SIZE=BLOCK_SIZE, + dtype=triton_dtype, + **kernel_args, # XPU-specific optimization + ) + + DW = _DW.sum(dim=0).to(W.dtype) + DB = _DB.sum(dim=0).to(W.dtype) + + DX = DX.view(*shape) + return DX, DW, DB + + +class LigerLayerNormFunction(torch.autograd.Function): + @staticmethod + @ensure_contiguous + def forward(ctx, X, W, B, eps): + Y, X, Mean, RSTD, BLOCK_SIZE, num_warps = layer_norm_forward(X, W, B, eps) + ctx.save_for_backward(X, W, B, Mean, RSTD) + return Y + + @staticmethod + @ensure_contiguous + def backward(ctx, dY): + X, W, B, Mean, RSTD = ctx.saved_tensors + DX, DW, DB = layer_norm_backward(dY, X, W, B, Mean, RSTD) + return DX, DW, DB, None \ No newline at end of file diff --git a/build/torch-xpu/layers.py b/build/torch-xpu/layers.py new file mode 100644 index 0000000000000000000000000000000000000000..798de28de6bcb77a98416812dcae89cb9e5014df --- /dev/null +++ b/build/torch-xpu/layers.py @@ -0,0 +1,39 @@ +import torch +from .rms_norm import LigerRMSNormFunction + +class LigerRMSNorm(torch.nn.Module): + """ + RMSNorm module that uses the optimized LigerRMSNormFunction. + + Args: + hidden_size (int): The size of the hidden dimension. + eps (float, optional): The epsilon value for numerical stability. Defaults to 1e-6. + offset (float, optional): Offset value to shift the weight tensor. 
Defaults to 0.0. + casting_mode (str, optional): The casting mode to use. Defaults to "llama". + in_place (bool, optional): Whether to modify dY in-place to store dX during backward. Defaults to True. + """ + + + weight: torch.Tensor + variance_epsilon: float + + def forward(self, hidden_states): + """ + Apply RMS normalization to the input tensor. + + Args: + hidden_states (torch.Tensor): Input tensor of shape (B, T, H) or (BxT, H) + + Returns: + torch.Tensor: Normalized tensor of the same shape as input + """ + return LigerRMSNormFunction.apply( + hidden_states, + self.weight, + self.variance_epsilon, + 0, + "llama", + True + ) + +__all__ = ["LigerRMSNorm"] \ No newline at end of file diff --git a/build/torch-xpu/liger_kernels/__init__.py b/build/torch-xpu/liger_kernels/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..03dbc1afe1cf156661a2b1b22003cd5f599a0309 --- /dev/null +++ b/build/torch-xpu/liger_kernels/__init__.py @@ -0,0 +1,26 @@ +import ctypes +import sys + +import importlib +from pathlib import Path +from types import ModuleType + +def _import_from_path(file_path: Path) -> ModuleType: + # We cannot use the module name as-is, after adding it to `sys.modules`, + # it would also be used for other imports. So, we make a module name that + # depends on the path for it to be unique using the hex-encoded hash of + # the path. 
+ path_hash = "{:x}".format(ctypes.c_size_t(hash(file_path.absolute())).value) + module_name = path_hash + spec = importlib.util.spec_from_file_location(module_name, file_path) + if spec is None: + raise ImportError(f"Cannot load spec for {module_name} from {file_path}") + module = importlib.util.module_from_spec(spec) + if module is None: + raise ImportError(f"Cannot load module {module_name} from spec") + sys.modules[module_name] = module + spec.loader.exec_module(module) # type: ignore + return module + + +globals().update(vars(_import_from_path(Path(__file__).parent.parent / "__init__.py"))) diff --git a/build/torch-xpu/metadata.json b/build/torch-xpu/metadata.json new file mode 100644 index 0000000000000000000000000000000000000000..9cf5deed9898dce769f4cc73913d3530b92a0bd8 --- /dev/null +++ b/build/torch-xpu/metadata.json @@ -0,0 +1,4 @@ +{ + "version": 1, + "python-depends": [] +} \ No newline at end of file diff --git a/build/torch-xpu/qwen2vl_mrope.py b/build/torch-xpu/qwen2vl_mrope.py new file mode 100644 index 0000000000000000000000000000000000000000..33d02e7bcfb95402ccd3c1b0ad0ed8780ef422e4 --- /dev/null +++ b/build/torch-xpu/qwen2vl_mrope.py @@ -0,0 +1,222 @@ +import torch +import triton +import triton.language as tl + + +@triton.jit +def _triton_qwen2vl_mrope( + q_ptr, + k_ptr, + cos, + sin, + sl, + bs: tl.constexpr, + n_qh: tl.constexpr, + n_kh: tl.constexpr, + hd: tl.constexpr, + pad_n_qh: tl.constexpr, + pad_n_kh: tl.constexpr, + pad_hd: tl.constexpr, + mrope_section_t: tl.constexpr, + mrope_section_h: tl.constexpr, + BLOCK_SIZE: tl.constexpr, + BACKWARD_PASS: tl.constexpr = False, +): + pid = tl.program_id(0) + + # locate start address + q_ptr = q_ptr + pid * (n_qh * hd) + k_ptr = k_ptr + pid * (n_kh * hd) + + # #################################################################### + # get the cos(mθ_{i...d/2}) and sin(mθ_{i...d/2}) for token position + # m of this program instance + # 
#################################################################### + + # 1. program instances are laid out in a 1D vector of size bsz * seq_len, which + # effectively represents a 2D grid of size [bsz, seq_len] with seq_len dimension + # being the fastest changing dimension. Thus we can simply do pid // sl to get the batch index + # and pid % sl to get the sequence index. + # 2. We only need the left half of cos and sin matrix because the right half is just + # a clone of the left half. + t_end = mrope_section_t + h_end = t_end + mrope_section_h + + t_cos = cos + pid * hd + h_cos = t_cos + bs * sl * hd + w_cos = h_cos + bs * sl * hd + t_sin = sin + pid * hd + h_sin = t_sin + bs * sl * hd + w_sin = h_sin + bs * sl * hd + + cos_offsets = tl.arange(0, pad_hd // 2) + t_mask = cos_offsets < t_end + h_mask = (t_end <= cos_offsets) & (cos_offsets < h_end) + w_mask = (h_end <= cos_offsets) & (cos_offsets < hd // 2) + t_cos_row = tl.load(t_cos + cos_offsets, mask=t_mask, other=0) + h_cos_row = tl.load(h_cos + cos_offsets, mask=h_mask, other=0) + w_cos_row = tl.load(w_cos + cos_offsets, mask=w_mask, other=0) + t_sin_row = tl.load(t_sin + cos_offsets, mask=t_mask, other=0) + h_sin_row = tl.load(h_sin + cos_offsets, mask=h_mask, other=0) + w_sin_row = tl.load(w_sin + cos_offsets, mask=w_mask, other=0) + cos_row = t_cos_row + h_cos_row + w_cos_row + sin_row = t_sin_row + h_sin_row + w_sin_row + + # #################################################################### + # Load the left and right half of q and k for the current + # program instance (i.e. 
for the current token) separately + # #################################################################### + # left half of the head + first_half_q_offsets = tl.arange(0, pad_n_qh)[:, None] * hd + tl.arange(0, pad_hd // 2)[None, :] + first_half_k_offsets = tl.arange(0, pad_n_kh)[:, None] * hd + tl.arange(0, pad_hd // 2)[None, :] + first_q_mask = (tl.arange(0, pad_n_qh)[:, None] < n_qh) & (tl.arange(0, pad_hd // 2)[None, :] < hd // 2) + first_k_mask = (tl.arange(0, pad_n_kh)[:, None] < n_kh) & (tl.arange(0, pad_hd // 2)[None, :] < hd // 2) + q_tile_1 = tl.load(q_ptr + first_half_q_offsets, mask=first_q_mask, other=0).to(sin_row.dtype) + k_tile_1 = tl.load(k_ptr + first_half_k_offsets, mask=first_k_mask, other=0).to(sin_row.dtype) + + # right half of the head + second_half_q_offsets = first_half_q_offsets + (hd // 2) + second_half_k_offsets = first_half_k_offsets + (hd // 2) + second_q_mask = first_q_mask + second_k_mask = first_k_mask + q_tile_2 = tl.load(q_ptr + second_half_q_offsets, mask=second_q_mask, other=0).to(sin_row.dtype) + k_tile_2 = tl.load(k_ptr + second_half_k_offsets, mask=second_k_mask, other=0).to(sin_row.dtype) + + if not BACKWARD_PASS: + # y = [x1, x2] * [cos, cos] + [-x2, x1] * [sin, sin] + new_q_tile_1 = q_tile_1 * cos_row - q_tile_2 * sin_row + tl.store(q_ptr + first_half_q_offsets, new_q_tile_1, mask=first_q_mask) + new_q_tile_2 = q_tile_2 * cos_row + q_tile_1 * sin_row + tl.store(q_ptr + second_half_q_offsets, new_q_tile_2, mask=second_q_mask) + + new_k_tile_1 = k_tile_1 * cos_row - k_tile_2 * sin_row + tl.store(k_ptr + first_half_k_offsets, new_k_tile_1, mask=first_k_mask) + new_k_tile_2 = k_tile_2 * cos_row + k_tile_1 * sin_row + tl.store(k_ptr + second_half_k_offsets, new_k_tile_2, mask=second_k_mask) + else: + # with some math, we can get: + # dy = [dx1, dx2] * [cos, cos] + [-dx2, dx1] * [-sin, -sin] + new_q_tile_1 = q_tile_1 * cos_row + q_tile_2 * sin_row + tl.store(q_ptr + first_half_q_offsets, new_q_tile_1, mask=first_q_mask) + 
new_q_tile_2 = q_tile_2 * cos_row - q_tile_1 * sin_row + tl.store(q_ptr + second_half_q_offsets, new_q_tile_2, mask=second_q_mask) + + new_k_tile_1 = k_tile_1 * cos_row + k_tile_2 * sin_row + tl.store(k_ptr + first_half_k_offsets, new_k_tile_1, mask=first_k_mask) + new_k_tile_2 = k_tile_2 * cos_row - k_tile_1 * sin_row + tl.store(k_ptr + second_half_k_offsets, new_k_tile_2, mask=second_k_mask) + + +def qwen2vl_mrope_forward(q, k, cos, sin, mrope_section): + # transpose it back to the physical shape because Triton looks at the physical storage + # note: q and k are incontiguous before the transformation and will become contiguous after transpose + q = q.transpose(1, 2) + k = k.transpose(1, 2) + + batch_size, seq_len, n_q_head, head_dim = q.shape + n_kv_head = k.shape[2] + pad_hd = triton.next_power_of_2(head_dim) + pad_n_q_head = triton.next_power_of_2(n_q_head) + pad_n_kv_head = triton.next_power_of_2(n_kv_head) + BLOCK_SIZE = max(pad_n_q_head, pad_n_kv_head) + + n_row = batch_size * seq_len + + # ensure tensors passed into the kernel are contiguous. 
It will be no-op if they are already contiguous + q = q.contiguous() + k = k.contiguous() + cos = cos.contiguous() + sin = sin.contiguous() + + _triton_qwen2vl_mrope[(n_row,)]( + q, + k, + cos, + sin, + seq_len, + batch_size, + n_q_head, + n_kv_head, + head_dim, + pad_n_q_head, + pad_n_kv_head, + pad_hd, + mrope_section[0], + mrope_section[1], + BLOCK_SIZE=BLOCK_SIZE, + BACKWARD_PASS=False, + ) + return q.transpose(1, 2), k.transpose(1, 2), cos, sin + + +def qwen2vl_mrope_backward(dq, dk, cos, sin, mrope_section): + dq = dq.transpose(1, 2) + dk = dk.transpose(1, 2) + + batch_size, seq_len, n_q_head, head_dim = dq.shape + n_kv_head = dk.shape[2] + pad_hd = triton.next_power_of_2(head_dim) + pad_n_q_head = triton.next_power_of_2(n_q_head) + pad_n_kv_head = triton.next_power_of_2(n_kv_head) + BLOCK_SIZE = max(pad_n_q_head, pad_n_kv_head) + + n_row = batch_size * seq_len + + # ensure dq and dk are contiguous + dq = dq.contiguous() + dk = dk.contiguous() + + # backward is similar to forward except swapping few ops + _triton_qwen2vl_mrope[(n_row,)]( + dq, + dk, + cos, + sin, + seq_len, + batch_size, + n_q_head, + n_kv_head, + head_dim, + pad_n_q_head, + pad_n_kv_head, + pad_hd, + mrope_section[0], + mrope_section[1], + BLOCK_SIZE=BLOCK_SIZE, + BACKWARD_PASS=True, + ) + return dq.transpose(1, 2), dk.transpose(1, 2) + + +class LigerQwen2VLMRopeFunction(torch.autograd.Function): + """ + Triton implementation of the Qwen2VL Multimodal Rotary Positional Embedding (M-RoPE) operation. 
+ + Please find the corresponding HuggingFace implementation here: + https://github.com/huggingface/transformers/blob/main/src/transformers/models/qwen2_vl/modeling_qwen2_vl.py + """ + + @staticmethod + def forward(ctx, q, k, cos, sin, mrope_section, unsqueeze_dim=1): + """ + q size: (bsz, n_q_head, seq_len, head_dim) + k size: (bsz, n_kv_head, seq_len, head_dim) + cos size: (3, bsz, seq_len, head_dim) + sin size: (3, bsz, seq_len, head_dim) + """ + q, k, cos, sin = qwen2vl_mrope_forward(q, k, cos, sin, mrope_section) + ctx.save_for_backward(cos, sin) + ctx.mrope_section = mrope_section + return q, k + + def backward(ctx, dq, dk): + """ + dq size: (bsz, n_q_head, seq_len, head_dim) + dk size: (bsz, n_kv_head, seq_len, head_dim) + cos size: (3, bsz, seq_len, head_dim) + sin size: (3, bsz, seq_len, head_dim) + """ + cos, sin = ctx.saved_tensors + mrope_section = ctx.mrope_section + dq, dk = qwen2vl_mrope_backward(dq, dk, cos, sin, mrope_section) + return dq, dk, None, None, None, None \ No newline at end of file diff --git a/build/torch-xpu/rms_norm.py b/build/torch-xpu/rms_norm.py new file mode 100644 index 0000000000000000000000000000000000000000..bbf908671842d9e002774416c6a9ccc55acd2684 --- /dev/null +++ b/build/torch-xpu/rms_norm.py @@ -0,0 +1,365 @@ +""" +This file incorporates code from Unsloth licensed under the Apache License, Version 2.0. +See the original Unsloth repository at https://github.com/unslothai/unsloth. + +The following line +https://github.com/linkedin/Liger-Kernel/blob/7382a8761f9af679482b968f9348013d933947c7/src/liger_kernel/ops/rms_norm.py#L30 +is based on code from Unsloth, located at: +https://github.com/unslothai/unsloth/blob/fd753fed99ed5f10ef8a9b7139588d9de9ddecfb/unsloth/kernels/rms_layernorm.py#L22 + +Modifications made by Yanning Chen, 2024. 
+""" + +import math +import operator + +import torch +import triton +import triton.language as tl + +from .utils import calculate_settings +from .utils import compare_version +from .utils import ensure_contiguous +from .utils import torch_to_triton_dtype + +if compare_version("triton", operator.ge, "3.0.0"): + try: + # typical import path with dispatch available + from triton.language.extra.libdevice import rsqrt + except ModuleNotFoundError: + # for working with NGC containers + from triton.language.extra.cuda.libdevice import rsqrt +else: + from triton.language.math import rsqrt + + +_CASTING_MODE_NONE: tl.constexpr = tl.constexpr(-1) +_CASTING_MODE_LLAMA: tl.constexpr = tl.constexpr(0) +_CASTING_MODE_GEMMA: tl.constexpr = tl.constexpr(1) + + +@triton.jit +def _rms_norm_forward_kernel( + Y_ptr, + Y_row_stride, + X_ptr, + X_row_stride, + W_ptr, + W_row_stride, + RSTD_ptr, + RSTD_row_stride, + n_cols, + eps, + offset, + casting_mode: tl.constexpr, # constexpr so the `if` blocks can be optimized out + BLOCK_SIZE: tl.constexpr, +): + """ + y_i = (x_i / (RMS)) * (offset + wi), RMS = sqrt(sum(x_i^2) / N) + + Reference: + 1. https://triton-lang.org/main/getting-started/tutorials/05-layer-norm.html + 2. https://github.com/unslothai/unsloth/blob/fd753fed99ed5f10ef8a9b7139588d9de9ddecfb/unsloth/kernels/rms_layernorm.py#L22 + 3. 
https://arxiv.org/pdf/1910.07467 + """ + + row_idx = tl.program_id(0) + col_offsets = tl.arange(0, BLOCK_SIZE) + mask = col_offsets < n_cols + + Y_ptr += row_idx * Y_row_stride + X_ptr += row_idx * X_row_stride + RSTD_ptr += row_idx * RSTD_row_stride + + X_row = tl.load(X_ptr + col_offsets, mask=mask, other=0) + X_row_dtype = X_row.dtype + W_row = tl.load(W_ptr + col_offsets, mask=mask, other=0) + + # On Llama, only rstd is computed on fp32 + if casting_mode == _CASTING_MODE_LLAMA: + X_row = X_row.to(tl.float32) + + # Gemma computes everything on fp32, and then casts back the output to the original dtype + if casting_mode == _CASTING_MODE_GEMMA: + W_row = W_row.to(tl.float32) + X_row = X_row.to(tl.float32) + + if casting_mode == _CASTING_MODE_NONE: + eps = eps.to(X_row_dtype) + offset = offset.to(X_row_dtype) + + mean_square = tl.sum(X_row * X_row, axis=0) / n_cols + rstd = rsqrt(mean_square + eps) + + # We can save time by caching rms with minimal memory overhead + # because rms is much smaller compared to X_row, as rms is for each row. + # However, on the computation side, it can save 4 operations (*, sum, /, sqrt). + tl.store(RSTD_ptr, rstd) + + X_row = X_row * rstd + + # On Llama, the multiplication with the weight is done on the original dtype + if casting_mode == _CASTING_MODE_LLAMA: + X_row = X_row.to(X_row_dtype) + + Y_row = X_row * (offset + W_row) + + if casting_mode == _CASTING_MODE_GEMMA: + Y_row = Y_row.to(X_row_dtype) + + tl.store(Y_ptr + col_offsets, Y_row, mask=mask) + + +@triton.jit +def _rms_norm_backward_kernel( + dY_ptr, + dY_row_stride, + dX_ptr, + dX_row_stride, + X_ptr, + X_row_stride, + X_dtype: tl.constexpr, + W_ptr, + W_row_stride, + RSTD_ptr, + RSTD_row_stride, + dW_ptr, + dW_row_stride, + n_rows, + n_cols, + offset, + rows_per_program: tl.constexpr, + casting_mode: tl.constexpr, + BLOCK_SIZE: tl.constexpr, +): + """ + dx = (1 / RMS) * [dy * (w + offset - (1 / N) * (1 / RMS^2) * ((dy * (w + offset)) dot x) * x]. 
* means element-wise multiplication, whereas dot means dot product + dw = sum(dy * (x / RMS)). summation over BxT dimension + """ + + row_block_id = tl.program_id(0) + row_start = row_block_id * rows_per_program + row_end = min((row_block_id + 1) * rows_per_program, n_rows) + col_offsets = tl.arange(0, BLOCK_SIZE) + mask = col_offsets < n_cols + + dW_row = tl.zeros((BLOCK_SIZE,), dtype=tl.float32) + + dY_ptr += row_start * dY_row_stride + dX_ptr += row_start * dX_row_stride + + X_ptr += row_start * X_row_stride + RSTD_ptr += row_start + + W_row = tl.load(W_ptr + col_offsets, mask=mask, other=0.0) + W_row = W_row + offset + + for _ in range(row_start, row_end): + dY_row = tl.load(dY_ptr + col_offsets, mask=mask, other=0.0) + X_row = tl.load(X_ptr + col_offsets, mask=mask, other=0.0) + + # Get cached rms + rstd_row = tl.load(RSTD_ptr) + + X_row = X_row.to(tl.float32) + + # Different backward graphs for different casting modes + if casting_mode == _CASTING_MODE_LLAMA: + m = (dY_row * W_row).to(tl.float32) + + elif casting_mode == _CASTING_MODE_GEMMA: + dY_row = dY_row.to(tl.float32) + m = dY_row * W_row + else: + m = dY_row * W_row + + dX_row = rstd_row * m + + dX_row += (rstd_row) * (-(1 / n_cols) * rstd_row * rstd_row * tl.sum(m * X_row, axis=0) * X_row) + + # calculate the gradient of W + if casting_mode == _CASTING_MODE_LLAMA: + dW_row += dY_row * (X_row * rstd_row).to(X_dtype) + else: + # here X_row is already in fp32 (see previous if block) + dW_row += dY_row * (X_row * rstd_row) + + tl.store(dX_ptr + col_offsets, dX_row.to(X_dtype), mask=mask) + + dY_ptr += dY_row_stride + dX_ptr += dX_row_stride + X_ptr += X_row_stride + RSTD_ptr += RSTD_row_stride + + tl.store(dW_ptr + row_block_id * dW_row_stride + col_offsets, dW_row, mask=mask) + + +_str_to_casting_mode = { + "llama": _CASTING_MODE_LLAMA.value, + "gemma": _CASTING_MODE_GEMMA.value, + "none": _CASTING_MODE_NONE.value, +} + + +def rms_norm_forward(X, W, eps, offset, casting_mode): + if not 
isinstance(casting_mode, int): + assert casting_mode in _str_to_casting_mode, f"Invalid casting mode: {casting_mode}" + casting_mode = _str_to_casting_mode[casting_mode] + else: + assert casting_mode in _str_to_casting_mode.values(), f"Invalid casting mode: {casting_mode}" + + shape = X.shape + dim = shape[-1] + X = X.view(-1, dim) + n_rows, n_cols = X.shape + BLOCK_SIZE, num_warps = calculate_settings(n_cols) + + Y = torch.empty((n_rows, n_cols), dtype=X.dtype, device=X.device) + # RSTD is to cache rstd for each row + # RSTD is always computed/stored in fp32 if we are using Llama or Gemma casting mode + rstd_dtype = torch.float32 if casting_mode in (_CASTING_MODE_LLAMA.value, _CASTING_MODE_GEMMA.value) else X.dtype + RSTD = torch.empty(n_rows, dtype=rstd_dtype, device=X.device) + + # Check constraints. + assert X.shape[1] == W.shape[0], "Incompatible hidden size dimension between tensor1.shape[1] and tensor2.shape[0]" + + # XPU-specific optimization + kernel_args = {} + if X.device.type == "xpu": + kernel_args["grf_mode"] = "large" + _rms_norm_forward_kernel[(n_rows,)]( + Y, + Y.stride(0), + X, + X.stride(0), + W, + W.stride(0), + RSTD, + RSTD.stride(0), + n_cols, + eps, + offset, + casting_mode, + BLOCK_SIZE=BLOCK_SIZE, + num_warps=num_warps, + **kernel_args, # XPU-specific optimization + ) + return Y.view(*shape), X, RSTD, BLOCK_SIZE, num_warps, casting_mode + + +def rms_norm_backward(dY, X, W, RSTD, offset, casting_mode, BLOCK_SIZE, num_warps, in_place): + shape = dY.shape + dim = shape[-1] + dY = dY.view(-1, dim) + n_rows, n_cols = dY.shape + + sm_count = 1 + if X.device.type == "cuda": + sm_count = torch.cuda.get_device_properties(X.device).multi_processor_count + elif X.device.type == "xpu": + sm_count = torch.xpu.get_device_properties(X.device).gpu_eu_count + + # fp32 for numerical stability especially. 
+ _dW = torch.empty((sm_count, n_cols), dtype=torch.float32, device=W.device) + + if n_cols > BLOCK_SIZE: + raise RuntimeError("This layer norm doesn't support feature dim >= 64KB.") + rows_per_program = math.ceil(n_rows / sm_count) + grid = (sm_count,) + + if in_place is True: + dX = dY + else: + dX = torch.zeros_like(dY) + + # XPU-specific optimization + kernel_args = {} + if X.device.type == "xpu": + kernel_args["grf_mode"] = "large" + + _rms_norm_backward_kernel[grid]( + dY, + dY.stride(0), + dX, + dX.stride(0), + X, + X.stride(0), + torch_to_triton_dtype[X.dtype], + W, + W.stride(0), + RSTD, + RSTD.stride(0), + _dW, + _dW.stride(0), + n_rows, + n_cols, + offset, + rows_per_program, + casting_mode, + BLOCK_SIZE=BLOCK_SIZE, + num_warps=num_warps, + **kernel_args, # XPU-specific optimization + ) + dX = dX.view(*shape) + dW = _dW.sum(dim=0).to(W.dtype) + + return dX, dW + + +class LigerRMSNormFunction(torch.autograd.Function): + """ + Performs RMSNorm (Root Mean Square Normalization), which normalizes the input tensor `X` using the + weight tensor `W`, with an optional offset and casting mode. + + Some models use an 'offset' to shift the weight tensor `W` by a constant value. For example, Gemma + uses an offset of 1.0, so the computation becomes `(X / RMS(X)) * (W + 1.0)` instead of the usual + `(X / RMS(X)) * W`. You can pass the offset value as an argument to the forward function. + + In addition, different models cast their inputs at different places during RMSNorm computation. For + example, Gemma casts everything to fp32 before starting the computation, while Llama casts only the + inverse RMS to fp32. You can specify the casting mode using the `casting_mode` argument. We currently + support the following casting modes (they match HuggingFace Transformers' implementations): + - 'llama': matches the Llama implementation, where only the inverse RMS is computed on fp32. 
+ - 'gemma': matches the Gemma implementation, where everything is cast to fp32, then computed, then cast back to the original dtype. + - 'none': no casting is done. The computation is done in the original dtype. This saves memory and is slightly faster, but has more error w.r.t. the original implementation. + + `in_place` option means whether to in_place modify dY to store dX. This defaults to `True` to save memory. However, under certain cases, it can produce incorrect inputs. + For example, gemma2 uses two rmsnorm sequentially with residual in between. The residual part needs dY so it cannot be modified in-place. + Therefore, for the patching of RMSNorm in gemma2, we set `in_place` to `False` + """ + + @staticmethod + @ensure_contiguous + def forward(ctx, X, W, eps, offset=0.0, casting_mode="llama", in_place=True): + """ + X: (B, T, H) or (BxT, H) + W: (H,) + """ + Y, X, RSTD, BLOCK_SIZE, num_warps, casting_mode = rms_norm_forward(X, W, eps, offset, casting_mode) + ctx.offset = offset + ctx.casting_mode = casting_mode + ctx.in_place = in_place + ctx.BLOCK_SIZE = BLOCK_SIZE + ctx.num_warps = num_warps + ctx.save_for_backward(X, W, RSTD) + return Y + + @staticmethod + @ensure_contiguous + def backward(ctx, dY): + """ + Y: (B, T, H) or (BxT, H) + """ + X, W, RSTD = ctx.saved_tensors + dX, dW = rms_norm_backward( + dY, + X, + W, + RSTD, + ctx.offset, + ctx.casting_mode, + ctx.BLOCK_SIZE, + ctx.num_warps, + ctx.in_place, + ) + return dX, dW, None, None, None, None \ No newline at end of file diff --git a/build/torch-xpu/rope.py b/build/torch-xpu/rope.py new file mode 100644 index 0000000000000000000000000000000000000000..8c4801bf72c82d718ebd5fa63f278fdcbe0f23b3 --- /dev/null +++ b/build/torch-xpu/rope.py @@ -0,0 +1,239 @@ +import torch +import triton +import triton.language as tl + + +@triton.jit +def _triton_rope( + q_ptr, + q_row_stride, + k_ptr, + k_row_stride, + cos, + cos_row_stride, + sin, + sin_row_stride, + sl, + bs: tl.constexpr, + cos_bs: tl.constexpr, 
+ n_qh: tl.constexpr, + n_kh: tl.constexpr, + hd: tl.constexpr, + pad_n_qh: tl.constexpr, + pad_n_kh: tl.constexpr, + pad_hd: tl.constexpr, + BLOCK_SIZE: tl.constexpr, + BACKWARD_PASS: tl.constexpr = False, +): + # q size: (bsz, seq_len, num_q_heads, head_dim) + # q stride: (seq_len * num_q_heads * head_dim, num_q_heads * head_dim, head_dim, 1) + # k size: (bsz, seq_len, num_kv_heads, head_dim) + # k stride: (seq_len * num_kv_heads * head_dim, num_kv_heads * head_dim, head_dim, 1) + + # cos size: (1, seq_len, head_dim) or (bsz, seq_len, head_dim) + # stride: (seq_len * head_dim, head_dim, 1) + pid = tl.program_id(0) + + # locate start address + q_ptr = q_ptr + pid * q_row_stride + k_ptr = k_ptr + pid * k_row_stride + + # #################################################################### + # get the cos(mθ_{i...d/2}) and sin(mθ_{i...d/2}) for token position + # m of this program instance + # #################################################################### + + # 1. program instances are laid out in a 1D vector of size bsz * seq_len, which + # effectively represents a 2D grid of size [bsz, seq_len] with seq_len dimension + # being the fastest changing dimension. Thus we can simply do pid // sl to get the batch index + # and pid % sl to get the sequence index. + # 2. We only need the left half of cos and sin matrix because the right half is just + # a clone of the left half. 
+ batch_idx = pid // sl + cos_row_idx = pid % sl + cos = cos + tl.where( + cos_bs == 1, + cos_row_idx * cos_row_stride, + batch_idx * (sl * cos_row_stride) + cos_row_idx * cos_row_stride, + ) + sin = sin + tl.where( + cos_bs == 1, + cos_row_idx * sin_row_stride, + batch_idx * (sl * sin_row_stride) + cos_row_idx * sin_row_stride, + ) + + cos_offsets = tl.arange(0, pad_hd // 2) + cos_mask = cos_offsets < hd // 2 + cos_row = tl.load(cos + cos_offsets, mask=cos_mask, other=0) + sin_row = tl.load(sin + cos_offsets, mask=cos_mask, other=0) + + # #################################################################### + # Load the left and right half of q and k for the current + # program instance (i.e. for the current token) separately + # #################################################################### + # left half of the head + first_half_q_offsets = tl.arange(0, pad_n_qh)[:, None] * hd + tl.arange(0, pad_hd // 2)[None, :] + first_half_k_offsets = tl.arange(0, pad_n_kh)[:, None] * hd + tl.arange(0, pad_hd // 2)[None, :] + first_q_mask = (tl.arange(0, pad_n_qh)[:, None] < n_qh) & (tl.arange(0, pad_hd // 2)[None, :] < hd // 2) + first_k_mask = (tl.arange(0, pad_n_kh)[:, None] < n_kh) & (tl.arange(0, pad_hd // 2)[None, :] < hd // 2) + q_tile_1 = tl.load(q_ptr + first_half_q_offsets, mask=first_q_mask, other=0).to(sin_row.dtype) + k_tile_1 = tl.load(k_ptr + first_half_k_offsets, mask=first_k_mask, other=0).to(sin_row.dtype) + + # right half of the head + second_half_q_offsets = first_half_q_offsets + (hd // 2) + second_half_k_offsets = first_half_k_offsets + (hd // 2) + second_q_mask = first_q_mask + second_k_mask = first_k_mask + q_tile_2 = tl.load(q_ptr + second_half_q_offsets, mask=second_q_mask, other=0).to(sin_row.dtype) + k_tile_2 = tl.load(k_ptr + second_half_k_offsets, mask=second_k_mask, other=0).to(sin_row.dtype) + + if not BACKWARD_PASS: + # y = [x1, x2] * [cos, cos] + [-x2, x1] * [sin, sin] + new_q_tile_1 = q_tile_1 * cos_row - q_tile_2 * sin_row + 
tl.store(q_ptr + first_half_q_offsets, new_q_tile_1, mask=first_q_mask) + new_q_tile_2 = q_tile_2 * cos_row + q_tile_1 * sin_row + tl.store(q_ptr + second_half_q_offsets, new_q_tile_2, mask=second_q_mask) + + new_k_tile_1 = k_tile_1 * cos_row - k_tile_2 * sin_row + tl.store(k_ptr + first_half_k_offsets, new_k_tile_1, mask=first_k_mask) + new_k_tile_2 = k_tile_2 * cos_row + k_tile_1 * sin_row + tl.store(k_ptr + second_half_k_offsets, new_k_tile_2, mask=second_k_mask) + else: + # with some math, we can get: + # dy = [dx1, dx2] * [cos, cos] + [-dx2, dx1] * [-sin, -sin] + new_q_tile_1 = q_tile_1 * cos_row + q_tile_2 * sin_row + tl.store(q_ptr + first_half_q_offsets, new_q_tile_1, mask=first_q_mask) + new_q_tile_2 = q_tile_2 * cos_row - q_tile_1 * sin_row + tl.store(q_ptr + second_half_q_offsets, new_q_tile_2, mask=second_q_mask) + + new_k_tile_1 = k_tile_1 * cos_row + k_tile_2 * sin_row + tl.store(k_ptr + first_half_k_offsets, new_k_tile_1, mask=first_k_mask) + new_k_tile_2 = k_tile_2 * cos_row - k_tile_1 * sin_row + tl.store(k_ptr + second_half_k_offsets, new_k_tile_2, mask=second_k_mask) + + +def rope_forward(q, k, cos, sin): + # transpose it back to the physical shape because Triton looks at the physical storage + # note: q and k are incontiguous before the transformation and will become contiguous after transpose + q = q.transpose(1, 2) + k = k.transpose(1, 2) + + batch_size, seq_len, n_q_head, head_dim = q.shape + n_kv_head = k.shape[2] + pad_hd = triton.next_power_of_2(head_dim) + pad_n_q_head = triton.next_power_of_2(n_q_head) + pad_n_kv_head = triton.next_power_of_2(n_kv_head) + BLOCK_SIZE = max(pad_n_q_head, pad_n_kv_head) + + n_row = batch_size * seq_len + + # ensure tensors passed into the kernel are contiguous. 
It will be no-op if they are already contiguous + q = q.contiguous() + k = k.contiguous() + cos = cos.contiguous() + sin = sin.contiguous() + cos_batch_size = cos.shape[0] + + _triton_rope[(n_row,)]( + q, + q.stride(1), + k, + k.stride(1), + cos, + cos.stride(-2), + sin, + sin.stride(-2), + seq_len, + batch_size, + cos_batch_size, + n_q_head, + n_kv_head, + head_dim, + pad_n_q_head, + pad_n_kv_head, + pad_hd, + BLOCK_SIZE=BLOCK_SIZE, + BACKWARD_PASS=False, + ) + return q.transpose(1, 2), k.transpose(1, 2), cos, sin + + +def rope_backward(dq, dk, cos, sin): + dq = dq.transpose(1, 2) + dk = dk.transpose(1, 2) + + batch_size, seq_len, n_q_head, head_dim = dq.shape + cos_batch_size = cos.shape[0] + n_kv_head = dk.shape[2] + pad_hd = triton.next_power_of_2(head_dim) + pad_n_q_head = triton.next_power_of_2(n_q_head) + pad_n_kv_head = triton.next_power_of_2(n_kv_head) + BLOCK_SIZE = max(pad_n_q_head, pad_n_kv_head) + + n_row = batch_size * seq_len + + # ensure dq and dk are contiguous + dq = dq.contiguous() + dk = dk.contiguous() + + # backward is similar to forward except swapping few ops + _triton_rope[(n_row,)]( + dq, + dq.stride(1), + dk, + dk.stride(1), + cos, + cos.stride(-2), + sin, + sin.stride(-2), + seq_len, + batch_size, + cos_batch_size, + n_q_head, + n_kv_head, + head_dim, + pad_n_q_head, + pad_n_kv_head, + pad_hd, + BLOCK_SIZE=BLOCK_SIZE, + BACKWARD_PASS=True, + ) + return dq.transpose(1, 2), dk.transpose(1, 2) + + +class LigerRopeFunction(torch.autograd.Function): + """ + Triton implementation of the Rotary Positional Embedding (RoPE) operation. Please note that + this implements the HuggingFace Llama & Mistral version, whose rotation matrix is slightly different + than the original RoPE paper. 
+ + Please find the corresponding HuggingFace implementation here: + https://github.com/huggingface/transformers/blob/v4.40.2/src/transformers/models/llama/modeling_llama.py#L184 + + For more details about the rotation matrix used here, please refer to: + https://discuss.huggingface.co/t/is-llama-rotary-embedding-implementation-correct/44509/2 + """ + + @staticmethod + def forward(ctx, q, k, cos, sin, position_ids=None, unsqueeze_dim=1): + """ + q size: (bsz, n_q_head, seq_len, head_dim) + k size: (bsz, n_kv_head, seq_len, head_dim) + cos size: (1, seq_len, head_dim) or (bsz, seq_len, head_dim) + sin size: (1, seq_len, head_dim) or (bsz, seq_len, head_dim) + """ + q, k, cos, sin = rope_forward(q, k, cos, sin) + ctx.save_for_backward(cos, sin) + return q, k + + def backward(ctx, dq, dk): + """ + dq size: (bsz, n_q_head, seq_len, head_dim) + dk size: (bsz, n_kv_head, seq_len, head_dim) + cos size: (1, seq_len, head_dim) or (bsz, seq_len, head_dim) + sin size: (1, seq_len, head_dim) or (bsz, seq_len, head_dim) + """ + + cos, sin = ctx.saved_tensors + dq, dk = rope_backward(dq, dk, cos, sin) + return dq, dk, None, None, None, None \ No newline at end of file diff --git a/build/torch-xpu/swiglu.py b/build/torch-xpu/swiglu.py new file mode 100644 index 0000000000000000000000000000000000000000..f4ed39745bb5966a9d5eaee4643d487e66a4e620 --- /dev/null +++ b/build/torch-xpu/swiglu.py @@ -0,0 +1,116 @@ +import torch +import triton +import triton.language as tl + +from .utils import calculate_settings +from .utils import ensure_contiguous + + +@triton.jit +def silu(x): + return x * tl.sigmoid(x) + + +@triton.jit +def _swiglu_forward_kernel(a_ptr, b_ptr, c_ptr, stride, n_cols: tl.constexpr, BLOCK_SIZE: tl.constexpr): + program_id = tl.program_id(0).to(tl.int64) + + # locate start index + a_ptr += program_id * stride + b_ptr += program_id * stride + c_ptr += program_id * stride + + col_offsets = tl.arange(0, BLOCK_SIZE) + mask = col_offsets < n_cols + + # sigmoid requires type 
@triton.jit
def _swiglu_backward_kernel(dc_ptr, a_ptr, b_ptr, stride, n_cols: tl.constexpr, BLOCK_SIZE: tl.constexpr):
    """In-place SwiGLU backward: overwrites a with dL/da and b with dL/db."""
    row = tl.program_id(0).to(tl.int64)

    # Jump to this program's row.
    dc_ptr += row * stride
    a_ptr += row * stride
    b_ptr += row * stride

    cols = tl.arange(0, BLOCK_SIZE)
    valid = cols < n_cols

    grad_c = tl.load(dc_ptr + cols, mask=valid, other=0)
    # sigmoid requires type float32
    act_in = tl.load(a_ptr + cols, mask=valid, other=0).to(tl.float32)
    gate = tl.load(b_ptr + cols, mask=valid, other=0)

    # Recompute silu(a) here instead of caching it in forward (saves memory).
    sig = tl.sigmoid(act_in)
    silu_val = act_in * sig
    grad_b = grad_c * silu_val
    grad_a = grad_c * (silu_val * (1 - sig) + sig) * gate

    tl.store(a_ptr + cols, grad_a, mask=valid)
    tl.store(b_ptr + cols, grad_b, mask=valid)


def swiglu_forward(a, b):
    """
    Compute silu(a) * b row-wise with a Triton kernel.

    Returns the flattened (a, b) — kept for the backward recomputation — and
    the output reshaped back to a's original shape.
    """
    orig_shape = a.shape
    hidden = orig_shape[-1]

    # Collapse all leading dims into rows; one kernel program per row.
    a = a.view(-1, hidden)
    b = b.view(-1, hidden)
    c = torch.empty_like(a)

    block_size, warps = calculate_settings(hidden)
    _swiglu_forward_kernel[(a.shape[0],)](
        a,
        b,
        c,
        c.stride(-2),
        n_cols=hidden,
        BLOCK_SIZE=block_size,
        num_warps=warps,
    )
    return a, b, c.view(*orig_shape)


def swiglu_backward(a, b, dc):
    """
    Backward for swiglu_forward.

    a and b are the flattened activations saved in forward; dc is the upstream
    gradient. The kernel overwrites a/b in place with the input gradients,
    which are then reshaped to dc's original shape and returned.
    """
    orig_shape = dc.shape
    hidden = orig_shape[-1]

    dc = dc.view(-1, hidden)

    block_size, warps = calculate_settings(hidden)
    _swiglu_backward_kernel[(dc.shape[0],)](
        dc,
        a,
        b,
        dc.stride(-2),
        n_cols=hidden,
        BLOCK_SIZE=block_size,
        num_warps=warps,
    )
    return a.view(*orig_shape), b.view(*orig_shape)
def get_num_warps(BLOCK_SIZE):
    """Map a Triton block size to a warp count — larger blocks get more warps."""
    # Thresholds checked from largest to smallest; first match wins.
    for threshold, warps in ((32768, 32), (8192, 16), (2048, 8)):
        if BLOCK_SIZE >= threshold:
            return warps
    return 4
def tv_distance_forward_triton(p, q, shift_labels, reduction, ignore_index, has_label):
    """
    Launch the TVD kernel over (BT, V) distributions p and q and apply the
    requested reduction.

    Returns (loss, grads): loss is per-element for "none" or a reduced scalar /
    per-row tensor otherwise; grads is d(loss)/d(p) before the upstream scale.
    """
    BT, V = p.shape

    block_size = min(MAX_FUSED_SIZE, triton.next_power_of_2(V))
    warps = get_num_warps(block_size)
    mode = _str_to_reduction_mode[reduction]

    # "none" keeps a per-element loss; every other mode reduces each row in-kernel.
    loss_shape = (BT, V) if mode == _REDUCTION_MODE_NONE.value else (BT,)
    loss_out = torch.zeros(loss_shape, device=p.device, dtype=torch.float32)
    grads = torch.empty_like(p)

    # Rows whose label equals ignore_index contribute nothing to the mean.
    n_non_ignore = (shift_labels != ignore_index).sum().item() if has_label else BT

    _tv_distance_kernel[(BT,)](
        p,
        p.stride(0),
        q,
        q.stride(0),
        loss_out,
        loss_out.stride(0),
        grads,
        grads.stride(0),
        # Dummy buffer when there are no labels; kernel skips it (HAS_LABEL=False).
        shift_labels if has_label else torch.empty(1, device=p.device),
        ignore_index,
        V,
        BLOCK_SIZE=block_size,
        HAS_LABEL=has_label,
        num_warps=warps,
        reduction=mode,
    )

    if mode == _REDUCTION_MODE_BATCHMEAN.value:
        return loss_out.sum() / n_non_ignore, grads / n_non_ignore
    if mode == _REDUCTION_MODE_SUM.value:
        return loss_out.sum(dim=0), grads
    if mode == _REDUCTION_MODE_MEAN.value:
        scale = n_non_ignore * V
        return loss_out.sum() / scale, grads / scale
    return loss_out, grads
def tvd_backward_triton(grad_output, grads):
    """
    Scale the cached TVD gradients by the upstream gradient.

    When TVD is the final layer, autograd hands us exactly 1.0; the multiply is
    then a no-op and the cached grads are returned unchanged.
    """
    is_unit = torch.equal(grad_output, torch.tensor(1.0, device=grad_output.device))
    return grads if is_unit else grads * grad_output
+ + Returns: + tuple[torch.Tensor, None, None, None, None]: The gradient of the loss with respect to the inputs. + """ + (grads,) = ctx.saved_tensors + grads = tvd_backward_triton(grad_output, grads) + + return grads, None, None, None, None \ No newline at end of file diff --git a/build/torch-xpu/utils.py b/build/torch-xpu/utils.py new file mode 100644 index 0000000000000000000000000000000000000000..718b0bb9c6391cf291bae91e671fc44a2b5ebf36 --- /dev/null +++ b/build/torch-xpu/utils.py @@ -0,0 +1,135 @@ +""" +This file incorporates code from Unsloth licensed under the Apache License, Version 2.0. +See the original Unsloth repository at https://github.com/unslothai/unsloth. + +The following line +https://github.com/linkedin/Liger-Kernel/blob/7382a8761f9af679482b968f9348013d933947c7/src/liger_kernel/ops/utils.py#L23 +is based on code from Unsloth, located at: +https://github.com/unslothai/unsloth/blob/fd753fed99ed5f10ef8a9b7139588d9de9ddecfb/unsloth/kernels/utils.py#L43 + +Modifications made by Yanning Chen, 2024. 
+""" + +import functools +import importlib +import operator + +from typing import Callable + +import torch +import triton +import triton.language as tl + +from packaging.version import Version + +def infer_device(): + """ + Get current device name based on available devices + """ + if torch.cuda.is_available(): # Works for both Nvidia and AMD + return "cuda" + elif torch.xpu.is_available(): + return "xpu" + else: + return "cpu" + +def is_hip() -> bool: + return torch.version.hip is not None + + +def ensure_contiguous(fn): + @functools.wraps(fn) + def wrapper(ctx, *args, **kwargs): + def maybe_to_contiguous(x): + return x.contiguous() if isinstance(x, torch.Tensor) else x + + args = [maybe_to_contiguous(arg) for arg in args] + kwargs = {k: maybe_to_contiguous(v) for k, v in kwargs.items()} + return fn(ctx, *args, **kwargs) + + return wrapper + + +def calculate_settings(n): + # reference: https://github.com/unslothai/unsloth/blob/fd753fed99ed5f10ef8a9b7139588d9de9ddecfb/unsloth/kernels/utils.py#L43 + + MAX_FUSED_SIZE = 65536 + BLOCK_SIZE = triton.next_power_of_2(n) + if BLOCK_SIZE > MAX_FUSED_SIZE: + raise RuntimeError( + f"Cannot launch Triton kernel since n = {n} exceeds the recommended Triton blocksize = {MAX_FUSED_SIZE}." 
def compare_version(package: str, operator: Callable, target: str):
    """
    Compare the installed version of *package* against *target* using
    *operator* (e.g. operator.ge). Returns False when the package cannot be
    imported at all.
    """
    try:
        module = importlib.import_module(package)
    except ImportError:
        return False
    installed = Version(module.__version__)
    return operator(installed, Version(target))
+ """ + + # Get the program ID and convert it to int64 to avoid overflow + program_id = tl.program_id(0).to(tl.int64) + + # Locate the start index + X_ptr += program_id * X_stride + + # Load the gradient output value + grad_output = tl.load(grad_output_ptr) + + # Perform the element-wise multiplication + for i in range(0, n_cols, BLOCK_SIZE): + X_offsets = i + tl.arange(0, BLOCK_SIZE) + X_block = tl.load(X_ptr + X_offsets, mask=X_offsets < n_cols) + tl.store(X_ptr + X_offsets, X_block * grad_output, mask=X_offsets < n_cols) \ No newline at end of file