File size: 1,105 Bytes
15063d0 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 |
"""
Stick-breaking Attention - 官方Triton实现
"""
import math
import warnings
from typing import Optional

import torch
from einops import rearrange

from stickbreaking_attention.sb_attn import sb_attn
def stickbreaking_attention_std(
    q: torch.Tensor,
    k: torch.Tensor,
    v: torch.Tensor,
    *,
    head_first: bool = False,
    seq_start: Optional[torch.Tensor] = None,
    sm_scale: Optional[float] = None,
    normalize: bool = True,
    attend_current: bool = False,
) -> torch.Tensor:
    """Stick-breaking attention using the official Triton implementation.

    Args:
        q, k, v: Query/key/value tensors. Layout is ``(B, H, T, D)`` when
            ``head_first`` is True, otherwise ``(B, T, H, D)``.
        head_first: If False (default), inputs are rearranged to the
            head-first layout the kernel expects and the output is
            rearranged back.
        seq_start: Accepted for interface compatibility but NOT forwarded
            to ``sb_attn``; a warning is emitted when supplied.
        sm_scale: Softmax scale passed to the kernel as the inverse
            temperature. Defaults to ``1 / sqrt(D)``.
        normalize: Accepted for interface compatibility but NOT forwarded
            to ``sb_attn``; a warning is emitted when set to False.
        attend_current: Forwarded to ``sb_attn`` unchanged.

    Returns:
        Attention output tensor in the same layout as the inputs.
    """
    # These options are part of the public signature but the underlying
    # kernel call does not use them.  Warn instead of silently ignoring a
    # caller's non-default request.
    if seq_start is not None:
        warnings.warn(
            "seq_start is not supported by the official sb_attn kernel "
            "and will be ignored",
            stacklevel=2,
        )
    if not normalize:
        warnings.warn(
            "normalize=False is not supported by the official sb_attn "
            "kernel and will be ignored",
            stacklevel=2,
        )
    if not head_first:
        # Kernel expects (B, H, T, D); convert from (B, T, H, D).
        q = rearrange(q, "b t h d -> b h t d")
        k = rearrange(k, "b t h d -> b h t d")
        v = rearrange(v, "b t h d -> b h t d")
    B, H, T_q, D = q.shape
    if sm_scale is None:
        sm_scale = 1.0 / math.sqrt(D)
    # Official Triton kernel returns (output, remainder).  The remainder
    # (leftover stick-breaking probability mass) is intentionally discarded
    # here — only the attention output is exposed.
    out, _rem = sb_attn(
        q, k, v,
        inv_temp=sm_scale,
        attend_current=attend_current,
    )
    if not head_first:
        # Restore the caller's (B, T, H, D) layout.
        out = rearrange(out, "b h t d -> b t h d")
    return out
|