File size: 4,171 Bytes
176b11a
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
"""
sparse_attn.py
--------------
Triton sparse attention kernel for SparseVLM.

Computes attention scores ONLY for kept visual tokens against text,
skipping pruned tokens entirely instead of masking after dense compute.

For K=80 visual tokens kept out of N_vis=196, scored against N_text=77 text tokens:
  Dense:  196 * 77 = 15,092 attention pairs
  Sparse:  80 * 77 =  6,160 attention pairs  (59% fewer FLOPs)

Falls back to pure PyTorch automatically when Triton is unavailable or the inputs
are not on a CUDA device (e.g. CPU testing).
"""

import torch

try:
    import triton
    import triton.language as tl
    TRITON_AVAILABLE = True
except ImportError:
    TRITON_AVAILABLE = False


if TRITON_AVAILABLE:

    # BLOCK_D tiles the feature dimension. tl.arange only accepts power-of-two
    # ranges, so D (e.g. 768) is walked in BLOCK_D chunks with masked loads
    # instead of being loaded in one tl.arange(0, D). tl.dot needs every tile
    # dimension >= 16, which all BLOCK_D values below satisfy.
    @triton.autotune(
        configs=[
            triton.Config({"BLOCK_M": 64,  "BLOCK_N": 64,  "BLOCK_D": 32}, num_warps=4, num_stages=2),
            triton.Config({"BLOCK_M": 128, "BLOCK_N": 64,  "BLOCK_D": 32}, num_warps=4, num_stages=3),
            triton.Config({"BLOCK_M": 64,  "BLOCK_N": 128, "BLOCK_D": 32}, num_warps=8, num_stages=2),
            triton.Config({"BLOCK_M": 64,  "BLOCK_N": 64,  "BLOCK_D": 64}, num_warps=4, num_stages=3),
            triton.Config({"BLOCK_M": 128, "BLOCK_N": 128, "BLOCK_D": 64}, num_warps=8, num_stages=3),
        ],
        key=["K", "N_text", "D"],
    )
    @triton.jit
    def _sparse_attn_kernel(
        Q_ptr, K_ptr, Out_ptr,
        stride_qb, stride_qk, stride_qd,
        stride_kb, stride_kn, stride_kd,
        stride_ob, stride_ok, stride_on,
        B,  # unused in the body; runtime int so a new batch size does not recompile
        K: tl.constexpr,
        N_text: tl.constexpr,
        D: tl.constexpr,
        scale,
        BLOCK_M: tl.constexpr,
        BLOCK_N: tl.constexpr,
        BLOCK_D: tl.constexpr,
    ):
        # Computes one BLOCK_M x BLOCK_N tile of Out[pid_b] = (Q[pid_b] @ K[pid_b]^T) * scale.
        # Grid axes: 0 -> rows of Q (kept tokens), 1 -> rows of K (text tokens), 2 -> batch.
        pid_m = tl.program_id(0)
        pid_n = tl.program_id(1)
        pid_b = tl.program_id(2)

        offs_m = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)
        offs_n = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)
        m_valid = offs_m < K
        n_valid = offs_n < N_text

        Q_base = Q_ptr + pid_b * stride_qb
        K_base = K_ptr + pid_b * stride_kb

        # Accumulate partial dot products over the feature dim in float32.
        # NOTE(review): for float32 inputs Triton's tl.dot may default to TF32,
        # so results can differ slightly from torch.bmm; confirm if that matters.
        acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32)
        for d_start in range(0, D, BLOCK_D):
            offs_d = d_start + tl.arange(0, BLOCK_D)
            d_valid = offs_d < D  # masks the tail chunk when D % BLOCK_D != 0
            q = tl.load(
                Q_base + offs_m[:, None] * stride_qk + offs_d[None, :] * stride_qd,
                mask=m_valid[:, None] & d_valid[None, :], other=0.0,
            )
            k = tl.load(
                K_base + offs_n[:, None] * stride_kn + offs_d[None, :] * stride_kd,
                mask=n_valid[:, None] & d_valid[None, :], other=0.0,
            )
            acc += tl.dot(q, tl.trans(k))

        scores = acc * scale

        Out_base = Out_ptr + pid_b * stride_ob
        tl.store(
            Out_base + offs_m[:, None] * stride_ok + offs_n[None, :] * stride_on,
            scores, mask=m_valid[:, None] & n_valid[None, :],
        )


def _sparse_attn_triton(Q: torch.Tensor, K: torch.Tensor) -> torch.Tensor:
    """Launch the Triton kernel computing ``(Q @ K^T) * D ** -0.5`` per batch.

    Args:
        Q: 3-D tensor [B, K, D] of query rows (the gathered kept visual tokens).
        K: 3-D tensor [B, N_text, D] of key rows.

    Returns:
        A new tensor [B, K, N_text] with ``Q``'s dtype and device.
    """
    batch, n_keep, dim = Q.shape
    _, n_text, _ = K.shape
    result = torch.empty((batch, n_keep, n_text), device=Q.device, dtype=Q.dtype)

    def launch_grid(meta):
        # One program per (query tile, key tile, batch element); tile sizes
        # are taken from whichever autotune config is being run.
        rows = triton.cdiv(n_keep, meta["BLOCK_M"])
        cols = triton.cdiv(n_text, meta["BLOCK_N"])
        return (rows, cols, batch)

    _sparse_attn_kernel[launch_grid](
        Q, K, result,
        *Q.stride(),
        *K.stride(),
        *result.stride(),
        B=batch, K=n_keep, N_text=n_text, D=dim, scale=dim ** -0.5,
    )
    return result


def _sparse_attn_pytorch(Q: torch.Tensor, K: torch.Tensor) -> torch.Tensor:
    scale = Q.shape[-1] ** -0.5
    return torch.bmm(Q, K.transpose(1, 2)) * scale


def sparse_vision_attn(
    patch_tokens: torch.Tensor,     # [B, N_vis, D]
    text_embeds: torch.Tensor,      # [B, N_text, D]
    kept_indices: torch.Tensor,     # [B, K] int64
    use_triton: bool = True,
) -> torch.Tensor:                  # [B, K, N_text]
    """
    Compute scaled attention scores only for kept visual tokens.

    Equivalent to selecting the kept rows of
        torch.matmul(patch_tokens, text_embeds.transpose(1, 2)) * D ** -0.5
    but the pruned rows are never computed.

    Args:
        patch_tokens: visual token features, [B, N_vis, D].
        text_embeds: text token features, [B, N_text, D].
        kept_indices: int64 indices into the N_vis axis of the tokens to keep, [B, K].
        use_triton: allow the Triton kernel when it is usable; otherwise the
            PyTorch path is used.

    Returns:
        Scores of shape [B, K, N_text], scaled by D ** -0.5.

    The Triton kernel has no backward pass, so it is skipped whenever autograd
    is recording through either input; otherwise gradients would silently stop
    here. It is also skipped unless both operands share one CUDA device and
    dtype, so mismatched inputs get PyTorch's error instead of a bad launch.
    """
    B, _, D = patch_tokens.shape
    _, K = kept_indices.shape

    # Broadcast each kept index across the feature dim so gather copies whole rows.
    idx = kept_indices.unsqueeze(-1).expand(B, K, D)
    Q = torch.gather(patch_tokens, dim=1, index=idx).contiguous()
    K_mat = text_embeds.contiguous()

    needs_grad = torch.is_grad_enabled() and (Q.requires_grad or K_mat.requires_grad)
    triton_ok = (
        use_triton
        and TRITON_AVAILABLE
        and Q.is_cuda
        and K_mat.device == Q.device
        and K_mat.dtype == Q.dtype
        and not needs_grad
    )
    if triton_ok:
        return _sparse_attn_triton(Q, K_mat)
    return _sparse_attn_pytorch(Q, K_mat)