base_IIXIV / fla /ops /utils /softplus.py
mainline777's picture
Duplicate from silx-ai/Quasar-Preview
41865df
Raw
History Blame Contribute Delete
3.05 kB
# REVISED FROM
# https://github.com/shawntan/stickbreaking-attention/blob/main/stickbreaking_attention/sb_varlen/softplus.py
import triton
from triton import language as tl
from fla.utils import IS_NVIDIA
def _generate_softplus(num_pack):
template = """
.reg .pred p;
setp.gt.f32 p, ${in_reg}, 20.;
@p mov.f32 ${out_reg}, ${in_reg};
@!p mul.f32 ${out_reg}, ${in_reg}, 1.4426950408889634;
@!p ex2.approx.ftz.f32 ${out_reg}, ${out_reg};
@!p add.f32 ${out_reg}, ${out_reg}, 1.0;
@!p lg2.approx.ftz.f32 ${out_reg}, ${out_reg};
@!p mul.f32 ${out_reg}, ${out_reg}, 0.6931471805599453;
"""
out_str = ""
for i in range(num_pack):
inner_str = template.format(out_reg=i, in_reg=i + num_pack)
out_str += "{" + inner_str + "}\n"
# flatten out because torch.compile doesn't like newlines
out_str = " ".join(out_str.split("\n"))
return out_str
def _generate_softplus2(num_pack):
template = """
.reg .pred p;
setp.gt.f32 p, ${in_reg}, 15.;
@p mov.f32 ${out_reg}, ${in_reg};
@!p ex2.approx.ftz.f32 ${out_reg}, ${in_reg};
@!p add.f32 ${out_reg}, ${out_reg}, 1.0;
@!p lg2.approx.ftz.f32 ${out_reg}, ${out_reg};
"""
out_str = ""
for i in range(num_pack):
inner_str = template.format(out_reg=i, in_reg=i + num_pack)
out_str += "{" + inner_str + "}\n"
# flatten out because torch.compile doesn't like newlines
out_str = " ".join(out_str.split("\n"))
return out_str
def _generate_constraints(num_pack):
return ",".join("=r" for i in range(num_pack)) + "," + ",".join("r" for i in range(num_pack))
# Pack size used both to generate the PTX / constraint strings below and as
# the `pack=` argument of tl.inline_asm_elementwise (through NUM_REG).
_NUM_REG = 1
# PTX source for the natural-log softplus, built once at import time.
s_softplus: tl.constexpr = tl.constexpr(_generate_softplus(_NUM_REG))
# PTX source for the base-2 softplus, built once at import time.
s_softplus2: tl.constexpr = tl.constexpr(_generate_softplus2(_NUM_REG))
# Operand constraint string shared by both PTX snippets.
s_constraints: tl.constexpr = tl.constexpr(_generate_constraints(_NUM_REG))
# tl.constexpr copy of _NUM_REG, referenced inside the @triton.jit kernels.
NUM_REG: tl.constexpr = tl.constexpr(_NUM_REG)
@triton.jit
def softplus_nv(x):
    # Natural-log softplus evaluated through the inline PTX in `s_softplus`.
    # Reference semantics:
    #     tl.where(x < 20.0, tl.math.log(1 + tl.math.exp(x)), x)
    # The PTX passes x through on a strict `x > 20` test and uses the
    # `.approx.ftz` exp2/log2 instructions for the other branch.
    result = tl.inline_asm_elementwise(
        asm=s_softplus,
        constraints=s_constraints,
        args=[x],
        dtype=tl.float32,
        is_pure=True,
        pack=NUM_REG,
    )
    return result
@triton.jit
def softplus_triton(x):
    # Natural-log softplus in plain Triton: log(1 + exp(x)) where x < 20,
    # and x itself everywhere else.
    exp_x = tl.math.exp(x)
    soft = tl.math.log(1 + exp_x)
    return tl.where(x < 20.0, soft, x)
@triton.jit
def softplus2_nv(x):
    # Base-2 softplus evaluated through the inline PTX in `s_softplus2`.
    # Reference semantics:
    #     tl.where(x < 15.0, tl.math.log2(1 + tl.math.exp2(x)), x)
    # The PTX passes x through on a strict `x > 15` test and uses the
    # `.approx.ftz` exp2/log2 instructions for the other branch.
    result = tl.inline_asm_elementwise(
        asm=s_softplus2,
        constraints=s_constraints,
        args=[x],
        dtype=tl.float32,
        is_pure=True,
        pack=NUM_REG,
    )
    return result
@triton.jit
def softplus2_triton(x):
    # Base-2 softplus in plain Triton: log2(1 + exp2(x)) where x < 15,
    # and x itself everywhere else.
    pow2_x = tl.math.exp2(x)
    soft = tl.math.log2(1 + pow2_x)
    return tl.where(x < 15.0, soft, x)
# Choose the implementations once at import time: the inline-PTX kernels when
# IS_NVIDIA is truthy, the plain Triton kernels otherwise.
softplus = softplus_nv if IS_NVIDIA else softplus_triton
softplus2 = softplus2_nv if IS_NVIDIA else softplus2_triton