from functools import partial

import torch
from torch import Tensor
from torch._inductor.lowering import make_pointwise, register_lowering
from torch._inductor.virtualized import ops
from torch.nn.attention.flex_attention import _score_mod_signature

# Approximate tanh: in eager mode the custom op simply calls torch.tanh, while the
# Inductor lowering below maps it to the faster `tanh.approx.f32` PTX instruction
# when compiled for CUDA.
@torch.library.custom_op("approx::tanh", mutates_args=())
def _tanh_approx(inp: Tensor) -> Tensor:
    return torch.tanh(inp)


@_tanh_approx.register_fake
def _(inp: torch.Tensor) -> torch.Tensor:
    return torch.tanh(inp)


def _tanh_approx_lowering(inp):
    fn = partial(ops.inline_asm_elementwise, asm="tanh.approx.f32 $0, $1;")
    return make_pointwise(fn)(inp)


register_lowering(torch.ops.approx.tanh)(_tanh_approx_lowering)
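

# A hedged sanity-check sketch (not from the original module): in eager mode
# torch.ops.approx.tanh falls back to torch.tanh, while torch.compile routes it
# through the Inductor lowering registered above and emits `tanh.approx.f32`.
# The tolerance below is an assumption about the approximate instruction's accuracy.
def _check_tanh_approx_lowering() -> None:
    if not torch.cuda.is_available():
        return  # the inline-asm lowering only applies on CUDA
    x = torch.randn(4096, device="cuda")
    eager_out = torch.ops.approx.tanh(x)
    compiled_out = torch.compile(lambda t: torch.ops.approx.tanh(t))(x)
    torch.testing.assert_close(compiled_out, eager_out, atol=1e-3, rtol=1e-3)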


class _TanhApprox(torch.autograd.Function):
    @staticmethod
    def forward(x):
        return torch.ops.approx.tanh(x)

    @staticmethod
    def setup_context(ctx, inputs, output):
        (x,) = inputs
        result = output
        ctx.save_for_backward(result)

    @staticmethod
    def backward(ctx, grad_output):
        (result,) = ctx.saved_tensors
        # d/dx tanh(x) = 1 - tanh(x)^2, computed from the saved forward output
        return grad_output * (1 - result * result)

    @staticmethod
    def vmap(info, in_dims, x):
        return torch.tanh(x), 0


_tanh_approx = _TanhApprox.apply
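

# Hedged sketch (illustrative, not part of the original module): a quick eager-mode
# check that the autograd wrapper's backward matches the analytic tanh gradient.
def _check_tanh_approx_grad() -> None:
    x = torch.randn(8, requires_grad=True)
    (grad,) = torch.autograd.grad(_tanh_approx(x).sum(), x)
    torch.testing.assert_close(grad, 1 - torch.tanh(x.detach()) ** 2)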


def generate_tanh_softcap(soft_cap: int, approx: bool = False) -> _score_mod_signature:
    """Returns a tanh softcapping score_mod for the given soft cap value.

    Args:
        soft_cap: The soft cap value used to normalize the logits
        approx: Whether to use the `tanh.approx.f32` PTX instruction

    Returns:
        tanh_softcap: score_mod
    """
    tanh = _tanh_approx if approx else torch.tanh

    def tanh_softcap(score, b, h, q_idx, kv_idx):
        return soft_cap * tanh(score / soft_cap)

    prefix = "tanh_softcap_approx" if approx else "tanh_softcap"
    tanh_softcap.__name__ = f"{prefix}_{soft_cap}"

    return tanh_softcap
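

# Hedged usage sketch (not part of the original module): a score_mod produced by
# generate_tanh_softcap plugs straight into flex_attention. The shapes and the
# soft cap of 30 below are arbitrary assumptions for illustration.
def _demo_tanh_softcap() -> Tensor:
    from torch.nn.attention.flex_attention import flex_attention

    device = "cuda" if torch.cuda.is_available() else "cpu"
    q, k, v = (torch.randn(2, 8, 128, 64, device=device) for _ in range(3))
    softcap_mod = generate_tanh_softcap(30, approx=False)
    return flex_attention(q, k, v, score_mod=softcap_mod)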


def generate_alibi_bias(H: int) -> _score_mod_signature:
    """Returns an ALiBi bias score_mod given the number of heads H

    Args:
        H: number of heads

    Returns:
        alibi_bias: alibi bias score_mod
    """

    def alibi_mod(score, b, h, q_idx, kv_idx):
        scale = torch.exp2(-((h + 1) * 8.0 / H))
        bias = (kv_idx - q_idx) * scale
        return score + bias

    return alibi_mod
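

# Hedged sketch (illustrative, not from the original module): a score_mod can be
# probed directly with scalar tensors, which makes it easy to inspect the per-head
# ALiBi slopes 2^(-8 * (h + 1) / H); the bias grows more negative with distance.
def _demo_alibi_slopes(H: int = 8) -> Tensor:
    alibi = generate_alibi_bias(H)
    zero, q_idx, kv_idx = torch.tensor(0.0), torch.tensor(16), torch.tensor(0)
    biases = [alibi(zero, 0, torch.tensor(h), q_idx, kv_idx) for h in range(H)]
    return torch.stack(biases)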


def generate_tanh_softcap_alibi(H: int, soft_cap: float, approx: bool = False) -> _score_mod_signature:
    """Returns a combined ALiBi and tanh softcapping score_mod.

    Args:
        H (int): number of heads for the ALiBi scaling
        soft_cap (float): the soft cap value used for logit clipping
        approx (bool): whether to use the `tanh.approx.f32` PTX-based approximation

    Returns:
        A combined score_mod function that first applies ALiBi,
        then performs tanh softcapping (optionally approximate).
    """
    tanh_func = _tanh_approx if approx else torch.tanh

    def alibi_tanh_softcap(score, b, h, q_idx, kv_idx):
        # ALiBi: add a per-head linear positional bias
        scale = torch.exp2(-((h + 1) * 8.0 / H))
        bias = (kv_idx - q_idx) * scale
        score = score + bias

        # Softcapping: soft_cap * tanh(score / soft_cap)
        score = score / soft_cap
        score = tanh_func(score)
        score = score * soft_cap
        return score

    if approx:
        alibi_tanh_softcap.__name__ = f"tanh_softcap_alibi_approx_{soft_cap}"
    else:
        alibi_tanh_softcap.__name__ = f"tanh_softcap_alibi_{soft_cap}"

    return alibi_tanh_softcap
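

# Hedged end-to-end sketch (assumptions not in the original module: head count 8,
# soft cap 20, and compilation only when a CUDA device is available, since the
# approximate-tanh lowering is CUDA-specific).
if __name__ == "__main__":
    from torch.nn.attention.flex_attention import flex_attention

    device = "cuda" if torch.cuda.is_available() else "cpu"
    use_cuda = device == "cuda"
    q, k, v = (torch.randn(2, 8, 256, 64, device=device) for _ in range(3))
    score_mod = generate_tanh_softcap_alibi(H=8, soft_cap=20.0, approx=use_cuda)
    attn = torch.compile(flex_attention) if use_cuda else flex_attention
    out = attn(q, k, v, score_mod=score_mod)
    print(score_mod.__name__, tuple(out.shape))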