import torch
import triton
import triton.language as tl
from .utils import calculate_settings
from .utils import ensure_contiguous
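
# Assumed helper contracts (the .utils module is not shown here):
# `calculate_settings` is expected to return a power-of-two BLOCK_SIZE that
# covers n_cols together with a num_warps heuristic, and `ensure_contiguous`
# is expected to make the autograd inputs contiguous before kernel launch.
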
@triton.jit
def silu(x):
    return x * tl.sigmoid(x)

@triton.jit
def _swiglu_forward_kernel(a_ptr, b_ptr, c_ptr, stride, n_cols: tl.constexpr, BLOCK_SIZE: tl.constexpr):
    program_id = tl.program_id(0).to(tl.int64)

    # locate start index of this program's row
    a_ptr += program_id * stride
    b_ptr += program_id * stride
    c_ptr += program_id * stride

    col_offsets = tl.arange(0, BLOCK_SIZE)
    mask = col_offsets < n_cols

    # sigmoid requires type float32
    a_row = tl.load(a_ptr + col_offsets, mask=mask, other=0).to(tl.float32)
    b_row = tl.load(b_ptr + col_offsets, mask=mask, other=0)
    c_row = silu(a_row) * b_row
    tl.store(c_ptr + col_offsets, c_row, mask=mask)
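
# For reference, the gradients the backward kernel computes below. With
# c = silu(a) * b and silu(a) = a * sigmoid(a):
#   dc/db = silu(a)
#   dc/da = b * (sigmoid(a) + a * sigmoid(a) * (1 - sigmoid(a)))
#         = b * (silu(a) * (1 - sigmoid(a)) + sigmoid(a))
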
@triton.jit
def _swiglu_backward_kernel(dc_ptr, a_ptr, b_ptr, stride, n_cols: tl.constexpr, BLOCK_SIZE: tl.constexpr):
    program_id = tl.program_id(0).to(tl.int64)

    # locate start index of this program's row
    dc_ptr += program_id * stride
    a_ptr += program_id * stride
    b_ptr += program_id * stride

    col_offsets = tl.arange(0, BLOCK_SIZE)
    mask = col_offsets < n_cols

    dc_row = tl.load(dc_ptr + col_offsets, mask=mask, other=0)
    # sigmoid requires type float32
    a_row = tl.load(a_ptr + col_offsets, mask=mask, other=0).to(tl.float32)
    b_row = tl.load(b_ptr + col_offsets, mask=mask, other=0)

    # recompute the activations instead of saving them, to save memory
    sig_a = tl.sigmoid(a_row)
    silu_a = a_row * sig_a
    db_row = dc_row * silu_a
    da_row = dc_row * (silu_a * (1 - sig_a) + sig_a) * b_row

    # write the gradients in place over the saved inputs
    tl.store(a_ptr + col_offsets, da_row, mask=mask)
    tl.store(b_ptr + col_offsets, db_row, mask=mask)

def swiglu_forward(a, b):
    ori_shape = a.shape
    n_cols = ori_shape[-1]
    a = a.view(-1, n_cols)
    b = b.view(-1, n_cols)
    c = torch.empty_like(a)
    n_rows = a.shape[0]

    BLOCK_SIZE, num_warps = calculate_settings(n_cols)
    _swiglu_forward_kernel[(n_rows,)](
        a,
        b,
        c,
        c.stride(-2),
        n_cols=n_cols,
        BLOCK_SIZE=BLOCK_SIZE,
        num_warps=num_warps,
    )
    return a, b, c.view(*ori_shape)
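
# Note: swiglu_forward returns the flattened `a` and `b` views alongside `c`
# so that the backward pass can overwrite their storage with the gradients
# rather than allocating new tensors.
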
def swiglu_backward(a, b, dc):
    ori_shape = dc.shape
    n_cols = ori_shape[-1]
    dc = dc.view(-1, n_cols)
    n_rows = dc.shape[0]

    BLOCK_SIZE, num_warps = calculate_settings(n_cols)
    _swiglu_backward_kernel[(n_rows,)](
        dc,
        a,
        b,
        dc.stride(-2),
        n_cols=n_cols,
        BLOCK_SIZE=BLOCK_SIZE,
        num_warps=num_warps,
    )
    return a.view(*ori_shape), b.view(*ori_shape)

class LigerSiLUMulFunction(torch.autograd.Function):
    @staticmethod
    @ensure_contiguous
    def forward(ctx, a, b):
        a, b, c = swiglu_forward(a, b)
        ctx.save_for_backward(a, b)
        return c

    @staticmethod
    @ensure_contiguous
    def backward(ctx, dc):
        a, b = ctx.saved_tensors
        a, b = swiglu_backward(a, b, dc)
        return a, b
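

# Minimal usage sketch (illustrative only, not part of the kernel build).
# Assumes a CUDA device; shapes and tolerances are arbitrary examples.
if __name__ == "__main__":
    a = torch.randn(2, 128, 1024, device="cuda", requires_grad=True)
    b = torch.randn_like(a, requires_grad=True)

    # Compute the eager reference before running backward, since the fused
    # backward overwrites the saved inputs with their gradients.
    ref = torch.nn.functional.silu(a.detach()) * b.detach()

    c = LigerSiLUMulFunction.apply(a, b)
    print("forward matches:", torch.allclose(c, ref, atol=1e-5, rtol=1e-5))

    c.sum().backward()
    print("grad shapes:", a.grad.shape, b.grad.shape)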