""" |
|
|
This file incorporates code from Unsloth licensed under the Apache License, Version 2.0. |
|
|
See the original Unsloth repository at https://github.com/unslothai/unsloth. |
|
|
|
|
|
The following line |
|
|
https://github.com/linkedin/Liger-Kernel/blob/7382a8761f9af679482b968f9348013d933947c7/src/liger_kernel/ops/utils.py#L23 |
|
|
is based on code from Unsloth, located at: |
|
|
https://github.com/unslothai/unsloth/blob/fd753fed99ed5f10ef8a9b7139588d9de9ddecfb/unsloth/kernels/utils.py#L43 |
|
|
|
|
|
Modifications made by Yanning Chen, 2024. |
|
|
""" |
|
|
|
|
|
import functools
import importlib
import operator

from typing import Callable

import torch
import triton
import triton.language as tl

from packaging.version import Version


def infer_device():
    """
    Return the current device type ("cuda", "xpu", or "cpu") based on which
    backend is available.
    """
    if torch.cuda.is_available():
        return "cuda"
    elif torch.xpu.is_available():
        return "xpu"
    else:
        return "cpu"


def is_hip() -> bool:
    return torch.version.hip is not None


def ensure_contiguous(fn):
    """
    Decorator that makes every ``torch.Tensor`` argument contiguous before
    calling ``fn``. Intended for autograd ``forward``/``backward`` methods,
    whose first argument is ``ctx``.
    """

    @functools.wraps(fn)
    def wrapper(ctx, *args, **kwargs):
        def maybe_to_contiguous(x):
            return x.contiguous() if isinstance(x, torch.Tensor) else x

        args = [maybe_to_contiguous(arg) for arg in args]
        kwargs = {k: maybe_to_contiguous(v) for k, v in kwargs.items()}
        return fn(ctx, *args, **kwargs)

    return wrapper
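# Illustrative sketch (not part of the original file): a typical use is to
# decorate the forward/backward of a custom autograd Function so Triton
# kernels can assume contiguous memory. `MyFunction` is a hypothetical name.
#
#   class MyFunction(torch.autograd.Function):
#       @staticmethod
#       @ensure_contiguous
#       def forward(ctx, x):
#           ...  # kernels here may safely assume x.is_contiguous()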


def calculate_settings(n):
    """
    Pick a Triton block size and warp count for a row of length ``n``: the
    block size is the next power of two >= ``n`` (capped at 65536), and the
    number of warps grows with the block size.
    """
    MAX_FUSED_SIZE = 65536
    BLOCK_SIZE = triton.next_power_of_2(n)
    if BLOCK_SIZE > MAX_FUSED_SIZE:
        raise RuntimeError(
            f"Cannot launch Triton kernel since n = {n} exceeds the recommended Triton block size = {MAX_FUSED_SIZE}."
        )

    num_warps = 4
    if BLOCK_SIZE >= 32768:
        num_warps = 32 if not is_hip() else 16
    elif BLOCK_SIZE >= 8192:
        num_warps = 16
    elif BLOCK_SIZE >= 2048:
        num_warps = 8
    return BLOCK_SIZE, num_warps
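# Illustrative example (not part of the original file): a row of 4096 elements
# is already a power of two and falls in the >= 2048 bucket, so 8 warps are used.
#   >>> calculate_settings(4096)
#   (4096, 8)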


def compare_version(package: str, operator: Callable, target: str):
    """
    Compare the installed version of ``package`` against ``target`` using
    ``operator`` (e.g. ``operator.ge``). Returns False if the package is not
    installed.
    """
    try:
        pkg = importlib.import_module(package)
    except ImportError:
        return False
    pkg_version = Version(pkg.__version__)
    return operator(pkg_version, Version(target))
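# Illustrative example (not part of the original file):
#   >>> compare_version("torch", operator.ge, "2.4.0")  # True iff torch >= 2.4.0 is installed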


def get_amp_custom_fwd_bwd() -> tuple[Callable, Callable]:
    """
    Return the (custom_fwd, custom_bwd) AMP decorators, using the device-aware
    ``torch.amp`` API on torch >= 2.4.0 and falling back to the deprecated
    ``torch.cuda.amp`` API on older versions.
    """
    device = infer_device()
    if compare_version("torch", operator.ge, "2.4.0"):
        return (
            functools.partial(torch.amp.custom_fwd, device_type=device),
            functools.partial(torch.amp.custom_bwd, device_type=device),
        )
    return torch.cuda.amp.custom_fwd, torch.cuda.amp.custom_bwd


amp_custom_fwd, amp_custom_bwd = get_amp_custom_fwd_bwd()
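# Illustrative sketch (not part of the original file): the resolved decorators
# are applied to custom autograd Functions so autocast handles mixed precision
# correctly on either torch API. `MyOp` is a hypothetical name.
#
#   class MyOp(torch.autograd.Function):
#       @staticmethod
#       @amp_custom_fwd
#       def forward(ctx, x):
#           ...
#
#       @staticmethod
#       @amp_custom_bwd
#       def backward(ctx, grad_output):
#           ...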


torch_to_triton_dtype = {
    torch.float32: tl.float32,
    torch.float16: tl.float16,
    torch.bfloat16: tl.bfloat16,
}
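# Illustrative example (not part of the original file): look up the Triton
# dtype matching a tensor's dtype before passing it to a kernel, where `x` is
# a hypothetical tensor.
#   tl_dtype = torch_to_triton_dtype[x.dtype]  # e.g. tl.bfloat16 for a bf16 tensor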


@triton.jit
def element_mul_kernel(
    X_ptr,
    X_stride,
    grad_output_ptr,
    n_cols,
    BLOCK_SIZE: tl.constexpr,
):
    """
    Multiply each element of the tensor pointed to by X_ptr by the scalar
    value pointed to by grad_output_ptr. The multiplication is performed
    in place on the tensor pointed to by X_ptr.

    Parameters:
    X_ptr: Pointer to the input tensor.
    X_stride (int): The stride of the input tensor.
    grad_output_ptr: Pointer to the gradient output value.
    n_cols (int): The number of columns in the input tensor.
    BLOCK_SIZE (int): The block size for Triton operations.
    """

    # Each program instance handles one row of X.
    program_id = tl.program_id(0).to(tl.int64)

    # Locate the start of this program's row.
    X_ptr += program_id * X_stride

    # Load the scalar gradient once; it multiplies every element of the row.
    grad_output = tl.load(grad_output_ptr)

    # Sweep the row in BLOCK_SIZE chunks, masking the tail past n_cols.
    for i in range(0, n_cols, BLOCK_SIZE):
        X_offsets = i + tl.arange(0, BLOCK_SIZE)
        X_block = tl.load(X_ptr + X_offsets, mask=X_offsets < n_cols)
        tl.store(X_ptr + X_offsets, X_block * grad_output, mask=X_offsets < n_cols)
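# Illustrative launch sketch (not part of the original file): scale each row of
# a 2-D tensor `X` in place by the scalar held in `grad_output`, one Triton
# program per row. `X` and `grad_output` are hypothetical tensors.
#
#   n_rows, n_cols = X.shape
#   BLOCK_SIZE, num_warps = calculate_settings(n_cols)
#   element_mul_kernel[(n_rows,)](
#       X,
#       X.stride(0),
#       grad_output,
#       n_cols,
#       BLOCK_SIZE=BLOCK_SIZE,
#       num_warps=num_warps,
#   )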