# Source: base_IIXIV / fla/modules/activations.py
# Duplicated from silx-ai/Quasar-Preview (revision 41865df, 21 kB).
# Copyright (c) 2023-2025, Tri Dao, Yu Zhang, Songlin Yang.
import torch
import torch.nn.functional as F
import triton
import triton.language as tl
from fla.ops.utils.op import exp, log
from fla.utils import IS_AMD, autocast_custom_bwd, autocast_custom_fwd, autotune_cache_kwargs, input_guard
# Candidate `num_warps` values used to build the autotune configs of the
# Triton kernels below; the AMD list stops at 16, the other one also tries 32.
NUM_WARPS_AUTOTUNE = [1, 2, 4, 8, 16] if IS_AMD else [1, 2, 4, 8, 16, 32]
def _get_stride(x: torch.Tensor) -> int:
"""Get the row stride for viewing a tensor as 2D (num_rows, D) where D = shape[-1].
Returns stride(-2) if the tensor is at least 2D, or 0 for 1D tensors.
The caller must ensure the tensor is "inner-contiguous" (stride(-1) == 1 and
higher dims are contiguous relative to dim -2) before using this value.
"""
if x.ndim < 2:
return 0
return x.stride(-2)
def _is_inner_contiguous(x: torch.Tensor) -> bool:
"""Check if a tensor can be safely viewed as 2D (num_rows, D) with row stride = stride(-2).
This holds when stride(-1) == 1 and all dimensions above -2 are contiguous
with respect to the dimension below them.
"""
ndim = x.ndim
if ndim < 2:
return True
if x.stride(-1) != 1:
return False
if ndim == 2:
# 2D: any layout with stride(-1)==1 is valid (can view as (T, D))
return True
if ndim == 3:
# 3D (B, T, D): stride should be (T*D, D, 1)
return x.stride(0) == x.stride(-2) * x.shape[-2]
if ndim == 4:
# 4D (B, H, T, D): stride should be (H*T*D, T*D, D, 1)
if x.stride(1) != x.stride(-2) * x.shape[-2]:
return False
return x.stride(0) == x.stride(1) * x.shape[1]
# 5D+ fallback to loop
expected = x.stride(-2) * x.shape[-2]
for d in range(ndim - 3, -1, -1):
if x.stride(d) != expected:
return False
expected *= x.shape[d]
return True
def _ensure_inner_contiguous(x: torch.Tensor) -> torch.Tensor:
    """Return ``x`` unchanged if it is inner-contiguous, else ``x.contiguous()``."""
    return x if _is_inner_contiguous(x) else x.contiguous()
def _alloc_output(x: torch.Tensor, contiguous: bool = False) -> torch.Tensor:
"""Allocate output tensor: contiguous buffer or same layout as input."""
if contiguous:
return x.new_empty(x.shape)
return torch.empty_like(x)
@triton.autotune(
    configs=[
        triton.Config({'B': block_size}, num_warps=warps)
        for block_size in [512, 1024, 2048, 4096, 8192]
        for warps in NUM_WARPS_AUTOTUNE
    ],
    key=['D'],
    **autotune_cache_kwargs,
)
@triton.jit(do_not_specialize=['T'])
def sigmoid_fwd_kernel(
    x, y,
    T,
    D: tl.constexpr,
    stride_x_row,
    stride_y_row,
    B: tl.constexpr,
):
    """Elementwise sigmoid ``y = 1 / (1 + exp(-x))``, evaluated in float32.

    Each program covers ``B`` consecutive flat element indices out of ``T``.
    A flat index is split into ``(row, col)`` with ``D`` columns per row and
    addressed as ``row * stride_*_row + col``.
    """
    # NOTE(review): offsets are built from tl.program_id / tl.arange without a
    # widening cast -- confirm inputs stay below 2**31 elements.
    o_i = tl.program_id(0) * B + tl.arange(0, B)
    m_i = o_i < T
    i_row = o_i // D
    i_col = o_i % D
    o_x = i_row * stride_x_row + i_col
    o_y = i_row * stride_y_row + i_col
    b_x = tl.load(x + o_x, mask=m_i, other=0.).to(tl.float32)
    b_y = 1.0 / (1.0 + exp(-b_x))
    tl.store(y + o_y, b_y.to(y.dtype.element_ty), mask=m_i)
@triton.autotune(
    configs=[
        triton.Config({'B': block_size}, num_warps=warps)
        for block_size in [512, 1024, 2048, 4096, 8192]
        for warps in NUM_WARPS_AUTOTUNE
    ],
    key=['D'],
    **autotune_cache_kwargs,
)
@triton.jit(do_not_specialize=['T'])
def sigmoid_bwd_kernel(
    x, dy, dx,
    T,
    D: tl.constexpr,
    stride_x_row,
    stride_dy_row,
    stride_dx_row,
    B: tl.constexpr,
):
    """Sigmoid backward: ``dx = dy * s * (1 - s)`` with ``s = sigmoid(x)``, in float32.

    Each program covers ``B`` consecutive flat element indices out of ``T``;
    a flat index maps to ``row * stride_*_row + col`` with ``D`` columns per row.
    """
    # NOTE(review): offsets are built from tl.program_id / tl.arange without a
    # widening cast -- confirm inputs stay below 2**31 elements.
    o_i = tl.program_id(0) * B + tl.arange(0, B)
    m_i = o_i < T
    i_row = o_i // D
    i_col = o_i % D
    o_x = i_row * stride_x_row + i_col
    o_dy = i_row * stride_dy_row + i_col
    o_dx = i_row * stride_dx_row + i_col
    b_x = tl.load(x + o_x, mask=m_i, other=0.).to(tl.float32)
    b_dy = tl.load(dy + o_dy, mask=m_i, other=0.).to(tl.float32)
    # The sigmoid is recomputed from x rather than read back from the forward output.
    b_s = 1.0 / (1.0 + exp(-b_x))
    b_dx = b_dy * b_s * (1.0 - b_s)
    tl.store(dx + o_dx, b_dx.to(dx.dtype.element_ty), mask=m_i)
@torch.compiler.disable
def sigmoid_fwd(x: torch.Tensor, output_contiguous: bool = False) -> torch.Tensor:
    """Launch ``sigmoid_fwd_kernel`` on ``x`` and return the result.

    ``x`` is first made inner-contiguous (copied only if needed). The output
    is allocated by ``_alloc_output``; ``output_contiguous`` is forwarded to it.
    The launch grid is 1D with ``cdiv(x.numel(), B)`` programs.
    """
    x = _ensure_inner_contiguous(x)
    y = _alloc_output(x, output_contiguous)
    n_elem = x.numel()

    def grid(meta):
        return (triton.cdiv(n_elem, meta['B']),)

    sigmoid_fwd_kernel[grid](
        x, y,
        T=n_elem,
        D=x.shape[-1],
        stride_x_row=_get_stride(x),
        stride_y_row=_get_stride(y),
    )
    return y
@torch.compiler.disable
def sigmoid_bwd(x: torch.Tensor, dy: torch.Tensor, output_contiguous: bool = False) -> torch.Tensor:
    """Launch ``sigmoid_bwd_kernel`` and return the gradient w.r.t. ``x``.

    ``x`` and ``dy`` are each made inner-contiguous (copied only if needed).
    The gradient buffer is allocated by ``_alloc_output`` from ``x``;
    ``output_contiguous`` is forwarded to it. The launch grid is 1D with
    ``cdiv(x.numel(), B)`` programs.
    """
    x = _ensure_inner_contiguous(x)
    dy = _ensure_inner_contiguous(dy)
    dx = _alloc_output(x, output_contiguous)
    n_elem = x.numel()

    def grid(meta):
        return (triton.cdiv(n_elem, meta['B']),)

    sigmoid_bwd_kernel[grid](
        x, dy, dx,
        T=n_elem,
        D=x.shape[-1],
        stride_x_row=_get_stride(x),
        stride_dy_row=_get_stride(dy),
        stride_dx_row=_get_stride(dx),
    )
    return dx
class SigmoidFunction(torch.autograd.Function):
    """Autograd function pairing ``sigmoid_fwd`` with ``sigmoid_bwd``.

    The input is saved in ``forward`` and the backward pass is computed from it.
    """

    @staticmethod
    @input_guard(no_guard_contiguous=True)
    def forward(ctx, x):
        ctx.save_for_backward(x)
        y = sigmoid_fwd(x)
        return y

    @staticmethod
    @input_guard(no_guard_contiguous=True)
    def backward(ctx, dout):
        (x,) = ctx.saved_tensors
        dx = sigmoid_bwd(x, dout)
        return dx


# Functional entry point.
sigmoid = SigmoidFunction.apply
@triton.autotune(
    configs=[
        triton.Config({'B': block_size}, num_warps=warps)
        for block_size in [512, 1024, 2048, 4096, 8192]
        for warps in NUM_WARPS_AUTOTUNE
    ],
    key=['D'],
    **autotune_cache_kwargs,
)
@triton.jit(do_not_specialize=['T'])
def logsigmoid_fwd_kernel(
    x,
    y,
    temperature,
    T,
    D: tl.constexpr,
    stride_x_row,
    stride_y_row,
    B: tl.constexpr,
):
    """Elementwise ``y = logsigmoid(x) / temperature``, evaluated in float32.

    The log-sigmoid is computed as ``min(0, x) - log(1 + exp(-|x|))``, so
    ``exp`` only ever sees a non-positive argument. Each program covers ``B``
    consecutive flat element indices out of ``T``; a flat index maps to
    ``row * stride_*_row + col`` with ``D`` columns per row.
    """
    # NOTE(review): offsets are built from tl.program_id / tl.arange without a
    # widening cast -- confirm inputs stay below 2**31 elements.
    pid = tl.program_id(0)
    offs = pid * B + tl.arange(0, B)
    mask = offs < T
    row = offs // D
    col = offs % D
    x_off = row * stride_x_row + col
    y_off = row * stride_y_row + col
    x_val = tl.load(x + x_off, mask=mask, other=0.).to(tl.float32)
    neg_part = tl.minimum(0., x_val)
    denom = 1. + exp(-tl.abs(x_val))
    y_val = (neg_part - log(denom)) / temperature
    tl.store(y + y_off, y_val.to(y.dtype.element_ty), mask=mask)
@triton.autotune(
    configs=[
        triton.Config({'B': block_size}, num_warps=warps)
        for block_size in [512, 1024, 2048, 4096, 8192]
        for warps in NUM_WARPS_AUTOTUNE
    ],
    key=['D'],
    **autotune_cache_kwargs,
)
@triton.jit(do_not_specialize=['T'])
def logsigmoid_bwd_kernel(
    x,
    dx,
    dy,
    temperature,
    T,
    D: tl.constexpr,
    stride_x_row,
    stride_dx_row,
    stride_dy_row,
    B: tl.constexpr,
):
    """Backward of ``logsigmoid(x) / temperature``: ``dx = dy * (1 - sigmoid(x)) / temperature``.

    Arithmetic is done in float32. Each program covers ``B`` consecutive flat
    element indices out of ``T``; a flat index maps to
    ``row * stride_*_row + col`` with ``D`` columns per row.
    """
    # NOTE(review): offsets are built from tl.program_id / tl.arange without a
    # widening cast -- confirm inputs stay below 2**31 elements.
    pid = tl.program_id(0)
    offs = pid * B + tl.arange(0, B)
    mask = offs < T
    row = offs // D
    col = offs % D
    x_off = row * stride_x_row + col
    dx_off = row * stride_dx_row + col
    dy_off = row * stride_dy_row + col
    x_val = tl.load(x + x_off, mask=mask, other=0.).to(tl.float32)
    g_val = tl.load(dy + dy_off, mask=mask, other=0.).to(tl.float32)
    local_grad = (1. - tl.sigmoid(x_val)) / temperature
    dx_val = g_val * local_grad
    tl.store(dx + dx_off, dx_val.to(dx.dtype.element_ty), mask=mask)
@torch.compiler.disable
def logsigmoid_fwd(x: torch.Tensor, temperature: float = 1., output_contiguous: bool = False) -> torch.Tensor:
    """Launch ``logsigmoid_fwd_kernel`` on ``x`` and return the result.

    ``x`` is first made inner-contiguous (copied only if needed). The output
    is allocated by ``_alloc_output``; ``output_contiguous`` is forwarded to
    it. ``temperature`` is passed straight to the kernel, which divides the
    log-sigmoid by it. The launch grid is 1D with ``cdiv(x.numel(), B)`` programs.
    """
    x = _ensure_inner_contiguous(x)
    y = _alloc_output(x, output_contiguous)
    n_elem = x.numel()

    def grid(meta):
        return (triton.cdiv(n_elem, meta['B']),)

    logsigmoid_fwd_kernel[grid](
        x=x,
        y=y,
        temperature=temperature,
        T=n_elem,
        D=x.shape[-1],
        stride_x_row=_get_stride(x),
        stride_y_row=_get_stride(y),
    )
    return y
@torch.compiler.disable
def logsigmoid_bwd(x: torch.Tensor, dy: torch.Tensor, temperature: float = 1., output_contiguous: bool = False) -> torch.Tensor:
    """Launch ``logsigmoid_bwd_kernel`` and return the gradient w.r.t. ``x``.

    ``x`` and ``dy`` are each made inner-contiguous (copied only if needed).
    The gradient buffer is allocated by ``_alloc_output`` from ``x``;
    ``output_contiguous`` is forwarded to it. ``temperature`` is passed
    straight to the kernel. The launch grid is 1D with
    ``cdiv(x.numel(), B)`` programs.
    """
    x = _ensure_inner_contiguous(x)
    dy = _ensure_inner_contiguous(dy)
    dx = _alloc_output(x, output_contiguous)
    n_elem = x.numel()

    def grid(meta):
        return (triton.cdiv(n_elem, meta['B']),)

    logsigmoid_bwd_kernel[grid](
        x=x,
        dx=dx,
        dy=dy,
        temperature=temperature,
        T=n_elem,
        D=x.shape[-1],
        stride_x_row=_get_stride(x),
        stride_dx_row=_get_stride(dx),
        stride_dy_row=_get_stride(dy),
    )
    return dx
class LogSigmoidFunction(torch.autograd.Function):
    """Autograd function pairing ``logsigmoid_fwd`` with ``logsigmoid_bwd``.

    ``temperature`` is stored on ``ctx`` and receives no gradient.
    """

    @staticmethod
    @input_guard(no_guard_contiguous=True)
    def forward(ctx, x, temperature):
        ctx.save_for_backward(x)
        ctx.temperature = temperature
        y = logsigmoid_fwd(x, temperature)
        return y

    @staticmethod
    @input_guard(no_guard_contiguous=True)
    def backward(ctx, dy):
        (x,) = ctx.saved_tensors
        dx = logsigmoid_bwd(x, dy, ctx.temperature)
        # One entry per forward argument: (x, temperature).
        return dx, None
def logsigmoid(x: torch.Tensor, temperature: float = 1.) -> torch.Tensor:
    """Apply ``LogSigmoidFunction`` to ``x`` with the given ``temperature`` (default ``1.``)."""
    out = LogSigmoidFunction.apply(x, temperature)
    return out
@triton.autotune(
    configs=[
        triton.Config({'B': block_size}, num_warps=warps)
        for block_size in [512, 1024, 2048, 4096, 8192]
        for warps in NUM_WARPS_AUTOTUNE
    ],
    key=['D'],
    **autotune_cache_kwargs,
)
@triton.jit(do_not_specialize=['T'])
def swish_fwd_kernel(
    x, y,
    T,
    D: tl.constexpr,
    stride_x_row,
    stride_y_row,
    B: tl.constexpr,
):
    """Elementwise swish ``y = x * sigmoid(x)``, evaluated in float32.

    Each program covers ``B`` consecutive flat element indices out of ``T``;
    a flat index maps to ``row * stride_*_row + col`` with ``D`` columns per row.
    """
    # NOTE(review): offsets are built from tl.program_id / tl.arange without a
    # widening cast -- confirm inputs stay below 2**31 elements.
    o_i = tl.program_id(0) * B + tl.arange(0, B)
    m_i = o_i < T
    i_row = o_i // D
    i_col = o_i % D
    o_x = i_row * stride_x_row + i_col
    o_y = i_row * stride_y_row + i_col
    b_x = tl.load(x + o_x, mask=m_i, other=0.).to(tl.float32)
    b_s = 1.0 / (1.0 + exp(-b_x))
    b_y = b_x * b_s
    tl.store(y + o_y, b_y.to(y.dtype.element_ty), mask=m_i)
@triton.autotune(
    configs=[
        triton.Config({'B': block_size}, num_warps=warps)
        for block_size in [512, 1024, 2048, 4096, 8192]
        for warps in NUM_WARPS_AUTOTUNE
    ],
    key=['D'],
    **autotune_cache_kwargs,
)
@triton.jit(do_not_specialize=['T'])
def swish_bwd_kernel(
    x, dy, dx,
    T,
    D: tl.constexpr,
    stride_x_row,
    stride_dy_row,
    stride_dx_row,
    B: tl.constexpr,
):
    """Swish backward: ``dx = dy * s * (1 + x * (1 - s))`` with ``s = sigmoid(x)``, in float32.

    Each program covers ``B`` consecutive flat element indices out of ``T``;
    a flat index maps to ``row * stride_*_row + col`` with ``D`` columns per row.
    """
    # NOTE(review): offsets are built from tl.program_id / tl.arange without a
    # widening cast -- confirm inputs stay below 2**31 elements.
    o_i = tl.program_id(0) * B + tl.arange(0, B)
    m_i = o_i < T
    i_row = o_i // D
    i_col = o_i % D
    o_x = i_row * stride_x_row + i_col
    o_dy = i_row * stride_dy_row + i_col
    o_dx = i_row * stride_dx_row + i_col
    b_x = tl.load(x + o_x, mask=m_i, other=0.).to(tl.float32)
    b_dy = tl.load(dy + o_dy, mask=m_i, other=0.).to(tl.float32)
    b_s = 1.0 / (1.0 + exp(-b_x))
    b_dx = b_dy * b_s * (1.0 + b_x * (1.0 - b_s))
    tl.store(dx + o_dx, b_dx.to(dx.dtype.element_ty), mask=m_i)
@torch.compiler.disable
def swish_fwd(x: torch.Tensor, output_contiguous: bool = False) -> torch.Tensor:
    """Launch ``swish_fwd_kernel`` on ``x`` and return the result.

    ``x`` is first made inner-contiguous (copied only if needed). The output
    is allocated by ``_alloc_output``; ``output_contiguous`` is forwarded to it.
    The launch grid is 1D with ``cdiv(x.numel(), B)`` programs.
    """
    x = _ensure_inner_contiguous(x)
    y = _alloc_output(x, output_contiguous)
    n_elem = x.numel()

    def grid(meta):
        return (triton.cdiv(n_elem, meta['B']),)

    swish_fwd_kernel[grid](
        x, y,
        T=n_elem,
        D=x.shape[-1],
        stride_x_row=_get_stride(x),
        stride_y_row=_get_stride(y),
    )
    return y
@torch.compiler.disable
def swish_bwd(x: torch.Tensor, dy: torch.Tensor, output_contiguous: bool = False) -> torch.Tensor:
    """Launch ``swish_bwd_kernel`` and return the gradient w.r.t. ``x``.

    ``x`` and ``dy`` are each made inner-contiguous (copied only if needed).
    The gradient buffer is allocated by ``_alloc_output`` from ``x``;
    ``output_contiguous`` is forwarded to it. The launch grid is 1D with
    ``cdiv(x.numel(), B)`` programs.
    """
    x = _ensure_inner_contiguous(x)
    dy = _ensure_inner_contiguous(dy)
    dx = _alloc_output(x, output_contiguous)
    n_elem = x.numel()

    def grid(meta):
        return (triton.cdiv(n_elem, meta['B']),)

    swish_bwd_kernel[grid](
        x, dy, dx,
        T=n_elem,
        D=x.shape[-1],
        stride_x_row=_get_stride(x),
        stride_dy_row=_get_stride(dy),
        stride_dx_row=_get_stride(dx),
    )
    return dx
class SwishFunction(torch.autograd.Function):
    """Autograd function pairing ``swish_fwd`` with ``swish_bwd``.

    The input is saved in ``forward`` and the backward pass is computed from it.
    """

    @staticmethod
    @input_guard(no_guard_contiguous=True)
    def forward(ctx, x):
        ctx.save_for_backward(x)
        y = swish_fwd(x)
        return y

    @staticmethod
    @input_guard(no_guard_contiguous=True)
    def backward(ctx, dout):
        (x,) = ctx.saved_tensors
        dx = swish_bwd(x, dout)
        return dx


# Functional entry point.
swish = SwishFunction.apply
# 1/sqrt(2*pi)-> 0.3989423
# 1/sqrt(2) -> 0.70710678
# sqrt(2/pi) -> 0.79788456
# this function is tanh approximation of gelu
# actual gelu is:
# x * 0.5 * (1.0 + torch.erf(x * 0.70710678))
@torch.compile
def bias_gelu(y, bias):
x = bias + y
return (x * 0.5 * (1.0 + torch.tanh(0.79788456 * x * (1 + 0.044715 * x * x)))).to(dtype=y.dtype)
# gradient of tanh approximation of gelu
# gradient of actual gelu is:
# 0.5 * (1. + torch.erf(x * 0.70710678)) + 0.3989423 * x * torch.exp(-0.5 * x * x)
@torch.compile
def bias_gelu_bwd(g, y, bias):
"""Assume that y has shape (B, D=D) and bias has shape (D)"""
x = bias + y
tanh_out = torch.tanh(0.79788456 * x * (1 + 0.044715 * x * x))
# sqrt(2/pi) * 3 * 0.044715 -> 0.1070322243
ff = 0.5 * x * ((1 - tanh_out * tanh_out) * (0.79788456 + 0.1070322243 * x * x)) + 0.5 * (
1 + tanh_out
)
grad_y = ff * g
return grad_y.to(dtype=y.dtype), grad_y.sum(dim=(0), dtype=bias.dtype)
class GeLUFunction(torch.autograd.Function):
    """Autograd function for bias-add followed by tanh-approximated GELU.

    ``forward`` evaluates ``bias_gelu(input, bias)``; ``backward`` delegates
    to ``bias_gelu_bwd`` and returns one gradient per forward argument.
    """

    @staticmethod
    def forward(ctx, input, bias):
        # NOTE(review): an earlier comment called ``bias`` optional, but
        # ``bias_gelu`` adds it unconditionally -- pass a tensor.
        ctx.save_for_backward(input, bias)
        return bias_gelu(input, bias)

    @staticmethod
    def backward(ctx, grad_output):
        input, bias = ctx.saved_tensors
        # ``bias_gelu_bwd`` already returns the pair (grad_input, grad_bias).
        # Returning that tuple twice would hand autograd two tuples instead
        # of two tensors, so unpack it and return each gradient once.
        grad_input, grad_bias = bias_gelu_bwd(grad_output, input, bias)
        return grad_input, grad_bias


bias_gelu_impl = GeLUFunction.apply
# this function is tanh approximation of gelu
# actual gelu is:
# x * 0.5 * (1.0 + torch.erf(x * 0.70710678))
@torch.compile
def gelu_fwd(x):
return (x * 0.5 * (1.0 + torch.tanh(0.79788456 * x * (1 + 0.044715 * x * x)))).to(dtype=x.dtype)
# gradient of tanh approximation of gelu
# gradient of actual gelu is:
# 0.5 * (1. + torch.erf(x * 0.70710678)) + 0.3989423 * x * torch.exp(-0.5 * x * x)
@torch.compile
def gelu_bwd(g, x):
tanh_out = torch.tanh(0.79788456 * x * (1 + 0.044715 * x * x))
# sqrt(2/pi) * 3 * 0.044715 -> 0.1070322243
ff = 0.5 * x * ((1 - tanh_out * tanh_out) * (0.79788456 + 0.1070322243 * x * x)) + 0.5 * (
1 + tanh_out
)
return (ff * g).to(dtype=x.dtype)
class FastGeLUFunction(torch.autograd.Function):
    """Autograd function pairing ``gelu_fwd`` with ``gelu_bwd`` (no bias term)."""

    @staticmethod
    def forward(ctx, input):
        ctx.save_for_backward(input)
        out = gelu_fwd(input)
        return out

    @staticmethod
    def backward(ctx, grad_output):
        saved = ctx.saved_tensors
        input = saved[0]
        return gelu_bwd(grad_output, input)


# Functional entry point.
fast_gelu_impl = FastGeLUFunction.apply
@torch.compile
def relu_bwd(g, x):
return torch.where(x >= 0, g, 0.0).to(dtype=x.dtype)
@torch.compile
def sqrelu_fwd(x):
r = F.relu(x.float())
return (r * r).to(dtype=x.dtype)
@torch.compile
def sqrelu_bwd(g, x):
return (2.0 * g * F.relu(x.float())).to(dtype=x.dtype)
class SquaredReLUFunction(torch.autograd.Function):
    """Autograd function pairing ``sqrelu_fwd`` with ``sqrelu_bwd``."""

    @staticmethod
    def forward(ctx, input):
        ctx.save_for_backward(input)
        out = sqrelu_fwd(input)
        return out

    @staticmethod
    def backward(ctx, grad_output):
        (input,) = ctx.saved_tensors
        grad_input = sqrelu_bwd(grad_output, input)
        return grad_input


# Functional entry point.
sqrelu = SquaredReLUFunction.apply
@triton.autotune(
    configs=[
        triton.Config({'B': block_size}, num_warps=warps)
        for block_size in [512, 1024, 2048, 4096, 8192]
        for warps in NUM_WARPS_AUTOTUNE
    ],
    key=['D'],
    **autotune_cache_kwargs,
)
@triton.jit(do_not_specialize=['T'])
def swiglu_fwd_kernel(
    x, y, z,
    T,
    D: tl.constexpr,
    stride_x_row,
    stride_y_row,
    stride_z_row,
    B: tl.constexpr,
):
    """Elementwise SwiGLU ``z = x * sigmoid(x) * y``, evaluated in float32.

    Each program covers ``B`` consecutive flat element indices out of ``T``;
    a flat index maps to ``row * stride_*_row + col`` with ``D`` columns per row.
    """
    # NOTE(review): offsets are built from tl.program_id / tl.arange without a
    # widening cast -- confirm inputs stay below 2**31 elements.
    o_i = tl.program_id(0) * B + tl.arange(0, B)
    m_i = o_i < T
    i_row = o_i // D
    i_col = o_i % D
    o_x = i_row * stride_x_row + i_col
    o_y = i_row * stride_y_row + i_col
    o_z = i_row * stride_z_row + i_col
    b_x = tl.load(x + o_x, mask=m_i, other=0.).to(tl.float32)
    b_y = tl.load(y + o_y, mask=m_i, other=0.).to(tl.float32)
    b_s = 1.0 / (1.0 + exp(-b_x))
    b_z = b_x * b_s * b_y
    tl.store(z + o_z, b_z.to(z.dtype.element_ty), mask=m_i)
@triton.heuristics({
    'HAS_WEIGHT': lambda args: args['z'] is not None,
})
@triton.autotune(
    configs=[
        triton.Config({'B': block_size}, num_warps=warps)
        for block_size in [512, 1024, 2048, 4096, 8192]
        for warps in NUM_WARPS_AUTOTUNE
    ],
    key=['D'],
    **autotune_cache_kwargs,
)
@triton.jit(do_not_specialize=['T'])
def swiglu_fwdbwd_kernel(
    x, y, g, dx, dy, z,
    T,
    D: tl.constexpr,
    stride_x_row,
    stride_y_row,
    stride_g_row,
    stride_dx_row,
    stride_dy_row,
    stride_z_row,
    B: tl.constexpr,
    HAS_WEIGHT: tl.constexpr,
):
    """SwiGLU backward, optionally recomputing the forward output.

    With ``s = sigmoid(x)`` and incoming gradient ``g`` this writes
    ``dx = g * s * (1 + x * (1 - s)) * y`` and ``dy = g * x * s``. When
    ``HAS_WEIGHT`` is set (the heuristic sets it whenever ``z`` is not
    ``None``) it also writes ``z = x * s * y``. Arithmetic is done in float32.

    Each program covers ``B`` consecutive flat element indices out of ``T``;
    a flat index maps to ``row * stride_*_row + col`` with ``D`` columns per row.
    """
    # NOTE(review): offsets are built from tl.program_id / tl.arange without a
    # widening cast -- confirm inputs stay below 2**31 elements.
    o_i = tl.program_id(0) * B + tl.arange(0, B)
    m_i = o_i < T
    i_row = o_i // D
    i_col = o_i % D
    o_x = i_row * stride_x_row + i_col
    o_y = i_row * stride_y_row + i_col
    o_g = i_row * stride_g_row + i_col
    o_dx = i_row * stride_dx_row + i_col
    o_dy = i_row * stride_dy_row + i_col
    b_x = tl.load(x + o_x, mask=m_i, other=0.).to(tl.float32)
    b_y = tl.load(y + o_y, mask=m_i, other=0.).to(tl.float32)
    b_g = tl.load(g + o_g, mask=m_i, other=0.).to(tl.float32)
    b_s = 1.0 / (1.0 + exp(-b_x))
    b_xs = b_x * b_s
    b_dx = b_g * b_s * (1.0 + b_x * (1.0 - b_s)) * b_y
    b_dy = b_g * b_xs
    tl.store(dx + o_dx, b_dx.to(dx.dtype.element_ty), mask=m_i)
    tl.store(dy + o_dy, b_dy.to(dy.dtype.element_ty), mask=m_i)
    if HAS_WEIGHT:
        o_z = i_row * stride_z_row + i_col
        b_z = b_xs * b_y
        tl.store(z + o_z, b_z.to(z.dtype.element_ty), mask=m_i)
@torch.compiler.disable
def swiglu_fwd(x: torch.Tensor, y: torch.Tensor, output_contiguous: bool = False) -> torch.Tensor:
    """Launch ``swiglu_fwd_kernel`` on ``x`` and ``y`` and return the result.

    ``x`` and ``y`` must have identical shapes (asserted). Both are made
    inner-contiguous (copied only if needed). The output is allocated by
    ``_alloc_output`` from ``x``; ``output_contiguous`` is forwarded to it.
    The launch grid is 1D with ``cdiv(x.numel(), B)`` programs.
    """
    assert x.shape == y.shape, f"swiglu_fwd: shape mismatch x={x.shape} y={y.shape}"
    x = _ensure_inner_contiguous(x)
    y = _ensure_inner_contiguous(y)
    z = _alloc_output(x, output_contiguous)
    n_elem = x.numel()

    def grid(meta):
        return (triton.cdiv(n_elem, meta['B']),)

    swiglu_fwd_kernel[grid](
        x, y, z,
        T=n_elem,
        D=x.shape[-1],
        stride_x_row=_get_stride(x),
        stride_y_row=_get_stride(y),
        stride_z_row=_get_stride(z),
    )
    return z
@torch.compiler.disable
def swiglu_fwdbwd(
    x: torch.Tensor,
    y: torch.Tensor,
    g: torch.Tensor,
    use_weight: bool = False,
    output_contiguous: bool = False,
):
    """Launch ``swiglu_fwdbwd_kernel`` and return the SwiGLU gradients.

    ``x``, ``y`` and ``g`` must have identical shapes (asserted); each is made
    inner-contiguous (copied only if needed). Output buffers are allocated by
    ``_alloc_output``; ``output_contiguous`` is forwarded to it.

    Returns ``(dx, dy)``, or ``(dx, dy, z)`` when ``use_weight`` is true, in
    which case a ``z`` buffer is also allocated and passed to the kernel.
    """
    assert x.shape == y.shape == g.shape, f"swiglu_fwdbwd: shape mismatch x={x.shape} y={y.shape} g={g.shape}"
    x = _ensure_inner_contiguous(x)
    y = _ensure_inner_contiguous(y)
    g = _ensure_inner_contiguous(g)
    dx = _alloc_output(x, output_contiguous)
    dy = _alloc_output(y, output_contiguous)
    # Without ``use_weight`` the kernel receives ``None`` for ``z``.
    z = _alloc_output(x, output_contiguous) if use_weight else None
    n_elem = x.numel()

    def grid(meta):
        return (triton.cdiv(n_elem, meta['B']),)

    swiglu_fwdbwd_kernel[grid](
        x, y, g, dx, dy, z,
        T=n_elem,
        D=x.shape[-1],
        stride_x_row=_get_stride(x),
        stride_y_row=_get_stride(y),
        stride_g_row=_get_stride(g),
        stride_dx_row=_get_stride(dx),
        stride_dy_row=_get_stride(dy),
        stride_z_row=0 if z is None else _get_stride(z),
    )
    return (dx, dy, z) if use_weight else (dx, dy)
class SwiGLUFunction(torch.autograd.Function):
    r"""
    Swish-Gated Linear Unit (SwiGLU) function.

    .. math::
        \text{SwiGLU}(x, y) = swish(x) * y = \frac{x}{1 + \exp(-x)} * y

    Both inputs are saved in ``forward``; ``backward`` computes the two input
    gradients from them with ``swiglu_fwdbwd``.
    """

    @staticmethod
    @input_guard(no_guard_contiguous=True)
    def forward(ctx, x, y):
        ctx.save_for_backward(x, y)
        z = swiglu_fwd(x, y)
        return z

    @staticmethod
    @input_guard(no_guard_contiguous=True)
    def backward(ctx, dout):
        x, y = ctx.saved_tensors
        # Called without ``use_weight``, so the result is the pair (dx, dy).
        grads = swiglu_fwdbwd(x, y, dout)
        return grads
class SwiGLULinearFunction(torch.autograd.Function):
    r"""
    Swish-Gated Linear Unit (SwiGLU) function followed by a linear transformation.

    .. math::
        \text{SwiGLULinear}(x, y, W, b) = (swish(x) * y) W + b

    This simple wrap discards the intermediate results of SwiGLU(x, y) to save memory:
    only ``x``, ``y`` and ``weight`` are saved, and the SwiGLU output is
    recomputed in ``backward`` to form the weight gradient.
    """

    @staticmethod
    @input_guard(no_guard_contiguous=True)
    @autocast_custom_fwd
    def forward(ctx, x, y, weight, bias):
        hidden = swiglu_fwd(x, y, output_contiguous=True)
        projected = F.linear(hidden, weight, bias)
        ctx.save_for_backward(x, y, weight)
        # Only whether a bias exists is needed to shape the backward outputs.
        ctx.linear_bias_is_none = bias is None
        return projected

    @staticmethod
    @input_guard(no_guard_contiguous=True)
    @autocast_custom_bwd
    def backward(ctx, dout, *args):
        x, y, weight = ctx.saved_tensors
        # Flatten all leading dimensions so the linear layer sees a 2D gradient.
        dout = dout.reshape(-1, dout.shape[-1])
        dz = F.linear(dout, weight.t()).view_as(x)
        dx, dy, z = swiglu_fwdbwd(x, y, dz, use_weight=True, output_contiguous=True)
        z_2d = z.reshape(-1, z.shape[-1])
        dlinear_weight = torch.einsum("bo,bi->oi", dout, z_2d)
        if ctx.linear_bias_is_none:
            dlinear_bias = None
        else:
            dlinear_bias = dout.sum(0)
        return dx, dy, dlinear_weight, dlinear_bias
# Functional entry points for the SwiGLU autograd functions.
swiglu = SwiGLUFunction.apply
swiglu_linear = SwiGLULinearFunction.apply
# Lookup table from activation name to callable. 'silu' and 'swish' map to
# the same function; 'gelu' and 'bias_gelu' are the tanh-approximated
# variants, and 'bias_gelu' takes an extra bias argument.
ACT2FN = {
'relu': F.relu,
'sigmoid': sigmoid,
'logsigmoid': logsigmoid,
'silu': swish,
'swish': swish,
'sqrelu': sqrelu,
'gelu': fast_gelu_impl,
'bias_gelu': bias_gelu_impl,
}