# REVISED FROM
# https://github.com/shawntan/stickbreaking-attention/blob/main/stickbreaking_attention/sb_varlen/softplus.py

import triton
from triton import language as tl

from fla.utils import IS_NVIDIA

# PTX body for the natural-log softplus, one instruction per fragment.
# `$` is a literal character and `{in_reg}` / `{out_reg}` are str.format
# fields, so after formatting each `${...}` becomes a `$<operand index>`
# placeholder for the inline-asm operand list.
#   x > 20  -> out = x
#   else    -> out = lg2(ex2(x * log2(e)) + 1) * ln(2)  ==  log(1 + exp(x))
_SOFTPLUS_TEMPLATE = (
    " .reg .pred p;"
    " setp.gt.f32 p, ${in_reg}, 20.;"
    " @p mov.f32 ${out_reg}, ${in_reg};"
    " @!p mul.f32 ${out_reg}, ${in_reg}, 1.4426950408889634;"
    " @!p ex2.approx.ftz.f32 ${out_reg}, ${out_reg};"
    " @!p add.f32 ${out_reg}, ${out_reg}, 1.0;"
    " @!p lg2.approx.ftz.f32 ${out_reg}, ${out_reg};"
    " @!p mul.f32 ${out_reg}, ${out_reg}, 0.6931471805599453;"
    " "
)

# PTX body for the base-2 softplus: no rescaling by log2(e) / ln(2) is needed.
#   x > 15  -> out = x
#   else    -> out = lg2(ex2(x) + 1)  ==  log2(1 + exp2(x))
_SOFTPLUS2_TEMPLATE = (
    " .reg .pred p;"
    " setp.gt.f32 p, ${in_reg}, 15.;"
    " @p mov.f32 ${out_reg}, ${in_reg};"
    " @!p ex2.approx.ftz.f32 ${out_reg}, ${in_reg};"
    " @!p add.f32 ${out_reg}, ${out_reg}, 1.0;"
    " @!p lg2.approx.ftz.f32 ${out_reg}, ${out_reg};"
    " "
)


def _render_packed_asm(template, num_pack):
    """Instantiate ``template`` once per packed element and return one line.

    Copy ``i`` writes operand ``$i`` and reads operand ``$(i + num_pack)``,
    i.e. all outputs are numbered before all inputs, matching the order of
    the constraint string built by ``_generate_constraints``. Each copy is
    wrapped in ``{ }`` (presumably so that every copy's ``.reg .pred p``
    declaration lives in its own PTX scope -- confirm against the PTX ISA).
    """
    scoped_copies = (
        "{" + template.format(out_reg=slot, in_reg=slot + num_pack) + "}\n"
        for slot in range(num_pack)
    )
    # flatten out because torch.compile doesn't like newlines
    return "".join(scoped_copies).replace("\n", " ")


def _generate_softplus(num_pack):
    """Return the inline PTX for ``num_pack`` natural-log softplus elements."""
    return _render_packed_asm(_SOFTPLUS_TEMPLATE, num_pack)


def _generate_softplus2(num_pack):
    """Return the inline PTX for ``num_pack`` base-2 softplus elements."""
    return _render_packed_asm(_SOFTPLUS2_TEMPLATE, num_pack)


def _generate_constraints(num_pack):
    """Return the asm constraint string: ``num_pack`` outputs, then inputs.

    For ``num_pack == 1`` this is ``"=r,r"``.
    """
    outputs = ",".join(["=r"] * num_pack)
    inputs = ",".join(["r"] * num_pack)
    return outputs + "," + inputs


# Number of elements handled per inline-asm invocation; it drives both the
# generated PTX and the `pack` argument below, so they always agree.
_NUM_REG = 1
s_softplus: tl.constexpr = tl.constexpr(_generate_softplus(_NUM_REG))
s_softplus2: tl.constexpr = tl.constexpr(_generate_softplus2(_NUM_REG))
s_constraints: tl.constexpr = tl.constexpr(_generate_constraints(_NUM_REG))
NUM_REG: tl.constexpr = tl.constexpr(_NUM_REG)


# Natural-log softplus implemented with inline PTX (selected when IS_NVIDIA).
@triton.jit
def softplus_nv(x):
    # equivalent to:
    # return tl.where(x < 20.0, tl.math.log(1 + tl.math.exp(x)), x)
    return tl.inline_asm_elementwise(
        asm=s_softplus,
        constraints=s_constraints,
        pack=NUM_REG,
        args=[
            x,
        ],
        dtype=tl.float32,
        is_pure=True,
    )


# Natural-log softplus in plain Triton ops (fallback when not IS_NVIDIA).
@triton.jit
def softplus_triton(x):
    return tl.where(x < 20.0, tl.math.log(1 + tl.math.exp(x)), x)


# Base-2 softplus implemented with inline PTX (selected when IS_NVIDIA).
@triton.jit
def softplus2_nv(x):
    # equivalent to:
    # return tl.where(x < 15.0, tl.math.log2(1 + tl.math.exp2(x)), x)
    return tl.inline_asm_elementwise(
        asm=s_softplus2,
        constraints=s_constraints,
        pack=NUM_REG,
        args=[
            x,
        ],
        dtype=tl.float32,
        is_pure=True,
    )


# Base-2 softplus in plain Triton ops (fallback when not IS_NVIDIA).
@triton.jit
def softplus2_triton(x):
    return tl.where(x < 15.0, tl.math.log2(1 + tl.math.exp2(x)), x)


# Public entry points: pick the PTX version on NVIDIA, the portable one elsewhere.
if IS_NVIDIA:
    softplus = softplus_nv
    softplus2 = softplus2_nv
else:
    softplus = softplus_triton
    softplus2 = softplus2_triton