"""
sparse_attn.py
--------------
Triton sparse attention kernel for SparseVLM.

Computes attention scores ONLY for kept visual tokens against text,
skipping pruned tokens entirely instead of masking after dense compute.

For K=80 kept from N_vis=196 (with N_text=77):
    Dense:  196 * 77 = 15,092 attention pairs
    Sparse:  80 * 77 =  6,160 attention pairs (59% fewer FLOPs)

Falls back to pure PyTorch automatically when Triton is unavailable (CPU testing).
"""
import torch
try:
import triton
import triton.language as tl
TRITON_AVAILABLE = True
except ImportError:
TRITON_AVAILABLE = False
if TRITON_AVAILABLE:

    @triton.autotune(
        configs=[
            triton.Config(
                {"BLOCK_M": 64, "BLOCK_N": 64, "BLOCK_D": 32},
                num_warps=4, num_stages=2,
            ),
            triton.Config(
                {"BLOCK_M": 128, "BLOCK_N": 64, "BLOCK_D": 32},
                num_warps=4, num_stages=3,
            ),
            triton.Config(
                {"BLOCK_M": 64, "BLOCK_N": 128, "BLOCK_D": 32},
                num_warps=8, num_stages=2,
            ),
            triton.Config(
                {"BLOCK_M": 128, "BLOCK_N": 128, "BLOCK_D": 32},
                num_warps=8, num_stages=3,
            ),
        ],
        key=["K", "N_text", "D"],
    )
    @triton.jit
    def _sparse_attn_kernel(
        Q_ptr, K_ptr, Out_ptr,
        stride_qb, stride_qk, stride_qd,
        stride_kb, stride_kn, stride_kd,
        stride_ob, stride_ok, stride_on,
        B: tl.constexpr,
        K: tl.constexpr,
        N_text: tl.constexpr,
        D: tl.constexpr,
        scale,
        BLOCK_M: tl.constexpr,
        BLOCK_N: tl.constexpr,
        BLOCK_D: tl.constexpr,
    ):
        """Write one [BLOCK_M, BLOCK_N] tile of ``scale * Q @ K^T`` per program.

        Grid axes: 0 -> tile index along the K kept-token rows,
        1 -> tile index along the N_text columns, 2 -> batch index.

        Q_ptr / K_ptr / Out_ptr are addressed through the explicit strides as
        [batch, row, feature], [batch, row, feature] and [batch, row, col]
        respectively. ``B`` is accepted for launcher compatibility but is not
        read inside the kernel.

        The feature axis is walked in BLOCK_D-sized chunks because
        ``tl.arange`` only accepts power-of-two extents: building the index
        vector directly from ``D`` fails to compile for sizes such as 768,
        and loading all of D as a single tile scales on-chip memory with D.
        BLOCK_D comes from the autotune configs, so the launcher does not
        pass it.
        """
        pid_m = tl.program_id(0)
        pid_n = tl.program_id(1)
        pid_b = tl.program_id(2)

        offs_m = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)
        offs_n = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)
        m_in = offs_m < K
        n_in = offs_n < N_text

        q_base = Q_ptr + pid_b * stride_qb
        k_base = K_ptr + pid_b * stride_kb

        # Accumulate in fp32 regardless of the input dtype.
        acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32)
        for d_start in range(0, D, BLOCK_D):
            offs_d = d_start + tl.arange(0, BLOCK_D)
            # Real mask for the ragged last chunk when D % BLOCK_D != 0;
            # masked lanes load 0.0 and so contribute nothing to the dot.
            d_in = offs_d < D
            q = tl.load(
                q_base + offs_m[:, None] * stride_qk + offs_d[None, :] * stride_qd,
                mask=m_in[:, None] & d_in[None, :], other=0.0,
            )
            # Load the key chunk already transposed ([BLOCK_D, BLOCK_N]) via
            # pointer arithmetic, so no explicit tl.trans is needed.
            k_t = tl.load(
                k_base + offs_d[:, None] * stride_kd + offs_n[None, :] * stride_kn,
                mask=d_in[:, None] & n_in[None, :], other=0.0,
            )
            acc += tl.dot(q, k_t)

        scores = acc * scale

        out_base = Out_ptr + pid_b * stride_ob
        # tl.store casts the fp32 tile to the element type of Out_ptr.
        tl.store(
            out_base + offs_m[:, None] * stride_ok + offs_n[None, :] * stride_on,
            scores, mask=m_in[:, None] & n_in[None, :],
        )
def _sparse_attn_triton(Q: torch.Tensor, K: torch.Tensor) -> torch.Tensor:
B, Kk, D = Q.shape
_, N_text, _ = K.shape
scale = D ** -0.5
Out = torch.empty(B, Kk, N_text, device=Q.device, dtype=Q.dtype)
def grid(meta):
return (
triton.cdiv(Kk, meta["BLOCK_M"]),
triton.cdiv(N_text, meta["BLOCK_N"]),
B,
)
_sparse_attn_kernel[grid](
Q, K, Out,
Q.stride(0), Q.stride(1), Q.stride(2),
K.stride(0), K.stride(1), K.stride(2),
Out.stride(0), Out.stride(1), Out.stride(2),
B=B, K=Kk, N_text=N_text, D=D, scale=scale,
)
return Out
def _sparse_attn_pytorch(Q: torch.Tensor, K: torch.Tensor) -> torch.Tensor:
scale = Q.shape[-1] ** -0.5
return torch.bmm(Q, K.transpose(1, 2)) * scale
def sparse_vision_attn(
    patch_tokens: torch.Tensor,  # [B, N_vis, D]
    text_embeds: torch.Tensor,  # [B, N_text, D]
    kept_indices: torch.Tensor,  # [B, K] integer indices into the N_vis axis
    use_triton: bool = True,
) -> torch.Tensor:  # [B, K, N_text]
    """Compute scaled attention scores only for the kept visual tokens.

    Replaces:
        torch.matmul(patch_tokens, text_embeds.transpose(1, 2))
    with a version that first gathers the kept rows of ``patch_tokens`` and
    multiplies only those against ``text_embeds``. The result is scaled by
    ``D ** -0.5``.

    Args:
        patch_tokens: visual tokens, shape [B, N_vis, D].
        text_embeds: text embeddings, shape [B, N_text, D].
        kept_indices: per-batch indices of the tokens to keep, shape [B, K].
            Narrower integer dtypes are widened to int64 for ``torch.gather``.
        use_triton: use the Triton kernel when Triton is importable and the
            gathered tensor is on a CUDA device; otherwise use PyTorch.

    Returns:
        Scores of shape [B, K, N_text].

    Raises:
        ValueError: if the inputs do not have the ranks above, or their
            batch / feature dimensions disagree.
    """
    if patch_tokens.dim() != 3 or text_embeds.dim() != 3:
        raise ValueError(
            "patch_tokens and text_embeds must be 3-D, got shapes "
            f"{tuple(patch_tokens.shape)} and {tuple(text_embeds.shape)}"
        )
    if kept_indices.dim() != 2:
        raise ValueError(
            f"kept_indices must be 2-D [B, K], got shape {tuple(kept_indices.shape)}"
        )

    B, _, D = patch_tokens.shape
    if text_embeds.shape[0] != B or text_embeds.shape[2] != D:
        raise ValueError(
            f"text_embeds {tuple(text_embeds.shape)} must share batch and feature "
            f"dimensions with patch_tokens {tuple(patch_tokens.shape)}"
        )
    if kept_indices.shape[0] != B:
        raise ValueError(
            f"kept_indices batch size {kept_indices.shape[0]} does not match "
            f"patch_tokens batch size {B}"
        )
    K = kept_indices.shape[1]

    # torch.gather only accepts int64 indices; widen narrower integer types.
    # Any other dtype is passed through so gather reports it as before.
    if kept_indices.dtype in (torch.uint8, torch.int8, torch.int16, torch.int32):
        kept_indices = kept_indices.long()

    # Broadcast each kept index across the feature axis so gather pulls
    # whole token vectors.
    idx = kept_indices.unsqueeze(-1).expand(B, K, D)
    Q = torch.gather(patch_tokens, dim=1, index=idx).contiguous()
    K_mat = text_embeds.contiguous()

    if use_triton and TRITON_AVAILABLE and Q.is_cuda:
        return _sparse_attn_triton(Q, K_mat)
    return _sparse_attn_pytorch(Q, K_mat)
|