# Transformer_500M / attn_mods.py
import torch
from torch import Tensor
from torch.nn.attention.flex_attention import _score_mod_signature
from torch._inductor.lowering import make_pointwise, register_lowering
# Some internal torch.compile details
from torch._inductor.virtualized import ops
from functools import partial


@torch.library.custom_op("approx::tanh", mutates_args=())
def _tanh_approx(inp: Tensor) -> Tensor:
    return torch.tanh(inp)


@_tanh_approx.register_fake
def _(inp: torch.Tensor) -> torch.Tensor:
    return torch.tanh(inp)


def _tanh_approx_lowering(inp):
    # Map the op to a single PTX tanh.approx.f32 instruction when Inductor
    # lowers it.
    fn = partial(ops.inline_asm_elementwise, asm="tanh.approx.f32 $0, $1;")
    return make_pointwise(fn)(inp)


register_lowering(torch.ops.approx.tanh)(_tanh_approx_lowering)
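
# A hedged usage note (assumes a CUDA device): called eagerly, the op falls
# back to the exact torch.tanh; the inline-asm lowering above only kicks in
# when the op is traced by torch.compile/Inductor. For example:
#
#   x = torch.randn(8, device="cuda")
#   y_eager = torch.ops.approx.tanh(x)                             # exact tanh
#   y_fast = torch.compile(lambda t: torch.ops.approx.tanh(t))(x)  # PTX approx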


# Wrap the op in an autograd.Function so it has a gradient: the exact tanh
# derivative is computed from the saved forward output.
class _TanhApprox(torch.autograd.Function):
    @staticmethod
    def forward(x):
        return torch.ops.approx.tanh(x)

    @staticmethod
    def setup_context(ctx, inputs, output):
        (x,) = inputs
        ctx.save_for_backward(output)

    @staticmethod
    def backward(ctx, grad_output):
        # d/dx tanh(x) = 1 - tanh(x)^2, recovered from the saved output
        (result,) = ctx.saved_tensors
        return grad_output * (1 - result * result)

    @staticmethod
    def vmap(info, in_dims, x):
        return torch.tanh(x), 0


_tanh_approx = _TanhApprox.apply
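
# A small sketch (not in the original file) showing the gradient path; this
# runs eagerly on CPU since the op falls back to torch.tanh:
#
#   x = torch.randn(4, requires_grad=True)
#   _tanh_approx(x).sum().backward()
#   # x.grad == 1 - torch.tanh(x) ** 2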


def generate_tanh_softcap(soft_cap: int, approx: bool = False) -> _score_mod_signature:
    """Returns a tanh softcapping score_mod given the soft cap value.

    Args:
        soft_cap: the soft cap value used to rescale the logits
        approx: whether to use the `tanh.approx` PTX instruction

    Returns:
        tanh_softcap: score_mod
    """
    tanh = _tanh_approx if approx else torch.tanh

    def tanh_softcap(score, b, h, q_idx, kv_idx):
        return soft_cap * tanh(score / soft_cap)

    prefix = "tanh_softcap_approx" if approx else "tanh_softcap"
    tanh_softcap.__name__ = f"{prefix}_{soft_cap}"

    return tanh_softcap
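
# A minimal usage sketch (illustrative shapes; assumes PyTorch >= 2.5 with
# flex_attention and, for approx=True, a CUDA device):
#
#   from torch.nn.attention.flex_attention import flex_attention
#   q = k = v = torch.randn(1, 8, 128, 64, device="cuda")
#   out = torch.compile(flex_attention)(
#       q, k, v, score_mod=generate_tanh_softcap(30, approx=True)
#   )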


def generate_alibi_bias(H: int) -> _score_mod_signature:
    """Returns an ALiBi bias score_mod given the number of heads H.

    Args:
        H: number of heads

    Returns:
        alibi_bias: ALiBi bias score_mod
    """

    def alibi_mod(score, b, h, q_idx, kv_idx):
        scale = torch.exp2(-((h + 1) * 8.0 / H))
        bias = (kv_idx - q_idx) * scale
        return score + bias

    return alibi_mod
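
# Worked check (an illustration, not from the source): with H = 8 heads the
# per-head slopes exp2(-(h + 1) * 8 / H) reduce to 2**-(h + 1), i.e.
# 1/2, 1/4, ..., 1/256, the geometric slope schedule from the ALiBi paper.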


def generate_tanh_softcap_alibi(H: int, soft_cap: float, approx: bool = False) -> _score_mod_signature:
    """Returns a combined ALiBi and tanh softcapping score_mod.

    Args:
        H (int): number of heads for ALiBi scaling
        soft_cap (float): the soft cap value for normalizing/logit clipping
        approx (bool): whether to use the `tanh.approx` PTX-based approximation

    Returns:
        A combined score_mod function that first applies ALiBi,
        then performs softcap + tanh (optionally approximate).
    """
    tanh_func = _tanh_approx if approx else torch.tanh

    def alibi_tanh_softcap(score, b, h, q_idx, kv_idx):
        # Compute the ALiBi bias and add it to the raw score
        scale = torch.exp2(-((h + 1) * 8.0 / H))
        bias = (kv_idx - q_idx) * scale
        score = score + bias
        # Softcap: divide by soft_cap, tanh, then rescale by soft_cap
        return soft_cap * tanh_func(score / soft_cap)

    # Give the score_mod a unique name:
    prefix = "tanh_softcap_alibi_approx" if approx else "tanh_softcap_alibi"
    alibi_tanh_softcap.__name__ = f"{prefix}_{soft_cap}"

    return alibi_tanh_softcap
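

if __name__ == "__main__":
    # A minimal smoke-test sketch, not part of the original module: assumes a
    # CUDA device and PyTorch >= 2.5 (flex_attention); the shapes and the soft
    # cap of 30.0 are illustrative only.
    from torch.nn.attention.flex_attention import flex_attention

    B, H, S, D = 1, 8, 128, 64
    q, k, v = (torch.randn(B, H, S, D, device="cuda") for _ in range(3))
    score_mod = generate_tanh_softcap_alibi(H, soft_cap=30.0, approx=True)
    out = torch.compile(flex_attention)(q, k, v, score_mod=score_mod)
    print(out.shape)  # torch.Size([1, 8, 128, 64])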