| """ |
| This file incorporates code from Unsloth licensed under the Apache License, Version 2.0. |
| See the original Unsloth repository at https://github.com/unslothai/unsloth. |
| |
| The following line |
| https://github.com/linkedin/Liger-Kernel/blob/7382a8761f9af679482b968f9348013d933947c7/src/liger_kernel/ops/utils.py#L23 |
| is based on code from Unsloth, located at: |
| https://github.com/unslothai/unsloth/blob/fd753fed99ed5f10ef8a9b7139588d9de9ddecfb/unsloth/kernels/utils.py#L43 |
| |
| Modifications made by Yanning Chen, 2024. |
| """ |
|
|
| import functools |
| import importlib |
| import operator |
|
|
| from typing import Callable |
|
|
| import torch |
| import triton |
| import triton.language as tl |
|
|
| from packaging.version import Version |
|
|
|
|
def is_npu_available() -> bool:
    """Report whether an Ascend NPU backend is usable.

    The check is delegated to ``transformers.utils.is_torch_npu_available``.
    Any failure while importing or calling it (for example ``transformers``
    not being installed) is treated as "no NPU available".

    Returns:
        The result of ``is_torch_npu_available()``, or ``False`` if the probe
        raised.
    """
    try:
        from transformers.utils import is_torch_npu_available

        npu_found = is_torch_npu_available()
    except Exception:
        # Best-effort probe: device detection must never raise.
        npu_found = False
    return npu_found
|
|
|
|
def infer_device():
    """Return the name of the preferred available device.

    Backends are probed in priority order: CUDA, Ascend NPU, Intel XPU.
    ``"cpu"`` is returned when none of them reports itself as available.

    Returns:
        str: One of ``"cuda"``, ``"npu"``, ``"xpu"`` or ``"cpu"``.
    """
    if torch.cuda.is_available():
        return "cuda"
    if is_npu_available():
        return "npu"
    # `torch.xpu` is absent on older PyTorch builds; treat a missing module as
    # "no XPU" instead of letting an AttributeError escape.
    xpu = getattr(torch, "xpu", None)
    if xpu is not None and xpu.is_available():
        return "xpu"
    return "cpu"
|
|
|
|
def is_hip() -> bool:
    """Return ``True`` when this PyTorch build reports a HIP version."""
    hip_version = torch.version.hip
    return hip_version is not None
|
|
|
|
def ensure_contiguous(fn):
    """Decorator that makes every tensor argument contiguous before calling ``fn``.

    The first positional argument (``ctx``) is forwarded untouched. Each
    remaining positional or keyword argument that is a ``torch.Tensor`` is
    replaced by the result of ``.contiguous()``; all other values are passed
    through unchanged.

    Args:
        fn: Callable taking ``ctx`` followed by arbitrary arguments.

    Returns:
        The wrapped callable, carrying ``fn``'s metadata via ``functools.wraps``.
    """

    def _as_contiguous(value):
        if isinstance(value, torch.Tensor):
            return value.contiguous()
        return value

    @functools.wraps(fn)
    def wrapper(ctx, *args, **kwargs):
        contiguous_args = tuple(_as_contiguous(item) for item in args)
        contiguous_kwargs = {name: _as_contiguous(item) for name, item in kwargs.items()}
        return fn(ctx, *contiguous_args, **contiguous_kwargs)

    return wrapper
|
|
|
|
def calculate_settings(n):
    """Pick a Triton block size and warp count for ``n`` elements.

    The block size is ``n`` rounded up to the next power of two. The warp
    count grows with the block size, and on HIP builds it is capped at 16.

    Args:
        n: Number of elements the kernel has to cover in one block.

    Returns:
        tuple: ``(BLOCK_SIZE, num_warps)``.

    Raises:
        RuntimeError: If the rounded-up block size exceeds 65536.
    """
    MAX_FUSED_SIZE = 65536
    BLOCK_SIZE = triton.next_power_of_2(n)
    if BLOCK_SIZE > MAX_FUSED_SIZE:
        raise RuntimeError(
            f"Cannot launch Triton kernel since n = {n} exceeds the recommended Triton blocksize = {MAX_FUSED_SIZE}."
        )

    # Walk the thresholds from smallest to largest block size.
    if BLOCK_SIZE < 2048:
        num_warps = 4
    elif BLOCK_SIZE < 8192:
        num_warps = 8
    elif BLOCK_SIZE < 32768:
        num_warps = 16
    else:
        # The largest blocks use 32 warps, except on HIP where 16 is used.
        num_warps = 16 if is_hip() else 32
    return BLOCK_SIZE, num_warps
|
|
|
|
def compare_version(package: str, operator: Callable, target: str):
    """Compare the installed version of ``package`` against ``target``.

    The version is read from the imported module's ``__version__`` attribute.
    If the name cannot be imported as a module (for example because it is a
    distribution name such as ``"pytorch-triton-xpu"``), or the module has no
    ``__version__``, the installed distribution metadata is consulted instead.

    Args:
        package: Importable module name or installed distribution name.
        operator: Binary callable applied as ``operator(installed, target)``,
            e.g. ``operator.ge``. Note that this parameter shadows the
            ``operator`` module inside this function.
        target: Version string to compare against.

    Returns:
        The result of ``operator(installed_version, target_version)``, or
        ``False`` when no version can be determined for ``package``.
    """
    # Imported under an alias: a plain `import importlib.metadata` here would
    # make `importlib` a local name and break the `import_module` call below.
    from importlib import metadata as importlib_metadata

    version_string = None
    try:
        pkg = importlib.import_module(package)
    except ImportError:
        pass
    else:
        version_string = getattr(pkg, "__version__", None)

    if version_string is None:
        # Not importable under this name, or no `__version__`: fall back to the
        # installed distribution's metadata.
        try:
            version_string = importlib_metadata.version(package)
        except importlib_metadata.PackageNotFoundError:
            return False

    return operator(Version(version_string), Version(target))
|
|
|
|
def get_amp_custom_fwd_bwd() -> "tuple[Callable, Callable]":
    """Return the ``(custom_fwd, custom_bwd)`` AMP helpers for this PyTorch.

    On PyTorch >= 2.4.0 the ``torch.amp`` helpers are returned, pre-bound via
    ``functools.partial`` to the device reported by ``infer_device()``. On
    older versions the ``torch.npu.amp`` helpers are returned when that module
    exists, otherwise the ``torch.cuda.amp`` ones.

    Returns:
        A 2-tuple ``(custom_fwd, custom_bwd)``.
    """
    if compare_version("torch", operator.ge, "2.4.0"):
        # Only this branch needs a device name, so it is inferred here rather
        # than up front on the legacy path where it would go unused.
        device = infer_device()
        return (
            functools.partial(torch.amp.custom_fwd, device_type=device),
            functools.partial(torch.amp.custom_bwd, device_type=device),
        )
    if hasattr(torch, "npu") and getattr(torch.npu, "amp", None) is not None:
        return torch.npu.amp.custom_fwd, torch.npu.amp.custom_bwd
    return torch.cuda.amp.custom_fwd, torch.cuda.amp.custom_bwd
|
|
|
|
# Resolve the AMP forward/backward helpers once, at import time, and expose
# them as module-level names.
(amp_custom_fwd, amp_custom_bwd) = get_amp_custom_fwd_bwd()
|
|
|
|
# Lookup table from a PyTorch floating-point dtype to the Triton dtype of the
# same name. Only float32, float16 and bfloat16 are mapped.
torch_to_triton_dtype = dict(
    (
        (torch.float32, tl.float32),
        (torch.float16, tl.float16),
        (torch.bfloat16, tl.bfloat16),
    )
)
|
|
|
|
@triton.jit
def element_mul_kernel(
    X_ptr,
    X_stride,
    grad_output_ptr,
    n_cols,
    BLOCK_SIZE: tl.constexpr,
):
    """
    Scale one row of the tensor at X_ptr, in place, by the value at grad_output_ptr.

    Each program instance works on the row selected by its program id along
    axis 0 (row start = X_ptr + program_id * X_stride) and walks that row in
    chunks of BLOCK_SIZE elements, writing the products back to the same
    locations.

    Parameters:
    X_ptr: Pointer to the input tensor; overwritten with the result.
    X_stride (int): The stride between consecutive rows of the input tensor.
    grad_output_ptr: Pointer to the gradient output value used as the multiplier.
    n_cols (int): The number of columns in the input tensor.
    BLOCK_SIZE (int): The number of elements processed per loop iteration.
    """

    # Widen the program id to int64 before it is multiplied by the stride.
    row_idx = tl.program_id(0).to(tl.int64)

    # Pointer to the first element of the row handled by this program.
    row_ptr = X_ptr + row_idx * X_stride

    # The multiplier is read once and reused for every chunk.
    scale = tl.load(grad_output_ptr)

    for col_start in range(0, n_cols, BLOCK_SIZE):
        col_offsets = col_start + tl.arange(0, BLOCK_SIZE)
        # Mask off lanes past the end of the row in the final chunk.
        in_bounds = col_offsets < n_cols
        values = tl.load(row_ptr + col_offsets, mask=in_bounds)
        tl.store(row_ptr + col_offsets, values * scale, mask=in_bounds)
|
|
|
|
def get_npu_core_count(default: int = 20) -> int:
    """Return the NPU vector core count reported by the Triton driver.

    The ``"num_vectorcore"`` entry of the properties of device 0 is read from
    the active Triton driver. If that lookup fails for any reason, or the
    entry is missing, ``default`` is used.

    Args:
        default: Value to fall back to.

    Returns:
        The reported core count converted with ``int``, or ``default`` as given
        when the driver query raised.
    """
    try:
        device_props = triton.runtime.driver.active.utils.get_device_properties(0)
        core_count = int(device_props.get("num_vectorcore", default))
    except Exception:
        # Best-effort query: any failure falls back to the caller's default.
        core_count = default
    return core_count
|
|
|
|
def set_large_grf_mode(kernel_args: dict):
    """Set large GRF mode for XPU devices.

    Writes the ``"grf_mode"`` entry of ``kernel_args`` in place. The value is
    ``"256"`` when either the ``pytorch-triton-xpu`` or the ``triton`` package
    is at version 3.6.0 or newer, and ``"large"`` otherwise.

    Args:
        kernel_args: Kernel launch arguments; mutated by this call.
    """
    # Triton may be installed under either package name, so both are probed.
    # `any` stops at the first match, like the equivalent `or` chain.
    candidate_packages = ("pytorch-triton-xpu", "triton")
    has_new_option = any(compare_version(name, operator.ge, "3.6.0") for name in candidate_packages)
    # NOTE(review): presumably the option value was renamed in 3.6.0 - confirm.
    kernel_args["grf_mode"] = "256" if has_new_option else "large"
|
|