# alibi_pile_4layer/ops/stickbreaking_attention_std.py
"""
Stick-breaking Attention - 官方Triton实现
"""
import math
from typing import Optional

import torch
from einops import rearrange

from stickbreaking_attention.sb_attn import sb_attn


def stickbreaking_attention_std(
q: torch.Tensor,
k: torch.Tensor,
v: torch.Tensor,
*,
head_first: bool = False,
seq_start: Optional[torch.Tensor] = None,
sm_scale: Optional[float] = None,
normalize: bool = True,
attend_current: bool = False,
) -> torch.Tensor:
"""Stick-breaking attention using official Triton implementation"""
    if not head_first:
        # The kernel expects (B, H, T, D); convert from (B, T, H, D).
        q = rearrange(q, "b t h d -> b h t d")
        k = rearrange(k, "b t h d -> b h t d")
        v = rearrange(v, "b t h d -> b h t d")
B, H, T_q, D = q.shape
if sm_scale is None:
sm_scale = 1.0 / math.sqrt(D)
    # Official Triton kernel; returns (output, remainder). The remainder is
    # the attention mass left unallocated by stick-breaking; it is discarded here.
    out, _rem = sb_attn(
        q, k, v,
        inv_temp=sm_scale,
        attend_current=attend_current,
    )
if not head_first:
out = rearrange(out, "b h t d -> b t h d")
return out
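

# ---------------------------------------------------------------------------
# Reference sketch (NOT part of the official package). This is a hedged,
# pure-PyTorch guess at what `sb_attn` computes, based on the stick-breaking
# formulation: with z = inv_temp * (q . k) and beta = sigmoid(z), the weight
# on key j for query i is A[i, j] = beta[i, j] * prod_{j < k < i} (1 - beta[i, k])
# (the product runs over k <= i when attend_current=True). Rows of A sum to
# at most 1; the leftover mass is the "remainder". Intended for CPU sanity
# checks, not as a drop-in replacement for the Triton kernel.
# ---------------------------------------------------------------------------
import torch.nn.functional as F


def sb_attn_reference(
    q: torch.Tensor,  # (B, H, T, D)
    k: torch.Tensor,  # (B, H, T, D)
    v: torch.Tensor,  # (B, H, T, D)
    inv_temp: float,
    attend_current: bool = False,
):
    T = q.shape[-2]
    z = torch.einsum("bhid,bhjd->bhij", q.float(), k.float()) * inv_temp
    i = torch.arange(T, device=q.device)[:, None]
    j = torch.arange(T, device=q.device)[None, :]
    valid = (j <= i) if attend_current else (j < i)
    log_beta = F.logsigmoid(z).masked_fill(~valid, float("-inf"))
    # log(1 - beta); zeroed where invalid so it drops out of the suffix sum.
    log_one_minus = F.logsigmoid(-z).masked_fill(~valid, 0.0)
    # suffix[i, j] = sum over valid k > j of log(1 - beta[i, k]),
    # computed as a reversed cumulative sum along the key axis.
    suffix = log_one_minus.flip(-1).cumsum(-1).flip(-1) - log_one_minus
    attn = torch.exp(log_beta + suffix)
    out = torch.einsum("bhij,bhjd->bhid", attn, v.float())
    remainder = 1.0 - attn.sum(-1)
    return out.to(v.dtype), remainder


# Minimal usage example. The Triton path in `stickbreaking_attention_std`
# requires a CUDA device; the reference sketch above also runs on CPU.
if __name__ == "__main__":
    torch.manual_seed(0)
    B, T, H, D = 2, 8, 4, 16
    q = torch.randn(B, T, H, D)
    k = torch.randn(B, T, H, D)
    v = torch.randn(B, T, H, D)
    out, rem = sb_attn_reference(
        rearrange(q, "b t h d -> b h t d"),
        rearrange(k, "b t h d -> b h t d"),
        rearrange(v, "b t h d -> b h t d"),
        inv_temp=1.0 / math.sqrt(D),
    )
    print(out.shape, rem.shape)  # (2, 4, 8, 16) (2, 4, 8)
    if torch.cuda.is_available():
        out_triton = stickbreaking_attention_std(q.cuda(), k.cuda(), v.cuda())
        print(out_triton.shape)  # (2, 8, 4, 16)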