"""Stick-breaking attention — wrapper around the official Triton implementation."""

import math
from typing import Optional

import torch
from einops import rearrange

from stickbreaking_attention.sb_attn import sb_attn


def stickbreaking_attention_std(
    q: torch.Tensor,
    k: torch.Tensor,
    v: torch.Tensor,
    *,
    head_first: bool = False,
    seq_start: Optional[torch.Tensor] = None,
    sm_scale: Optional[float] = None,
    normalize: bool = True,
    attend_current: bool = False,
) -> torch.Tensor:
    """Compute stick-breaking attention using the official Triton kernel.

    Args:
        q, k, v: query/key/value tensors. Layout is ``(b, t, h, d)`` when
            ``head_first`` is False (the default), otherwise ``(b, h, t, d)``.
        head_first: if True, inputs are already in ``(b, h, t, d)`` layout and
            no transposition is performed.
        seq_start: accepted for interface compatibility.
            NOTE(review): currently ignored — not forwarded to ``sb_attn``;
            confirm whether variable-length sequences should be supported.
        sm_scale: softmax inverse temperature; defaults to ``1 / sqrt(head_dim)``.
        normalize: accepted for interface compatibility.
            NOTE(review): currently ignored — not forwarded to ``sb_attn``;
            confirm intended remainder-normalization behavior.
        attend_current: forwarded to ``sb_attn`` (whether a position may
            attend to itself — TODO confirm against the library docs).

    Returns:
        Attention output tensor in the same layout as the inputs.
    """
    # Kernel expects head-first layout; transpose on the way in if needed.
    if not head_first:
        q = rearrange(q, "b t h d -> b h t d")
        k = rearrange(k, "b t h d -> b h t d")
        v = rearrange(v, "b t h d -> b h t d")

    # Default scale: 1/sqrt(head_dim), the conventional attention temperature.
    if sm_scale is None:
        head_dim = q.shape[-1]
        sm_scale = 1.0 / math.sqrt(head_dim)

    # Official Triton implementation; returns (output, remainder). The
    # remainder (left-over stick mass) is discarded here.
    out, _rem = sb_attn(q, k, v, inv_temp=sm_scale, attend_current=attend_current)

    # Restore the caller's layout on the way out.
    if not head_first:
        out = rearrange(out, "b h t d -> b t h d")
    return out