File size: 3,840 Bytes
d979fab
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
# Copyright (c) 2026, NVIDIA CORPORATION.  All rights reserved.
#
# NVIDIA CORPORATION and its licensors retain all intellectual property
# and proprietary rights in and to this software, related documentation
# and any modifications thereto.  Any use, reproduction, disclosure or
# distribution of this software and related documentation without an express
# license agreement from NVIDIA CORPORATION is strictly prohibited.

import torch

# attn_type_map codes, following the MagiAttention convention.
FULL = 0    # unmasked: every query in the q_range sees every key in the k_range
CAUSAL = 1  # causal mask applied within the (q_range, k_range) pair


def build_magi_ranges(
    kv_len: int,
    q_len: int,
    block_size: int,
    ar_decode: bool = False,
    device: str = "cpu",
):
    """
    Build MagiAttention range descriptors for a block-wise decode step.

    The attention mask is expressed as an ordered list of
    (q_range, k_range, attn_type) triples.  These are returned as three int32
    tensors that share the same leading index.

    Layout (fixed strategy):
      - Keys/values cover global positions [0, kv_len).
      - The query holds the *last* q_len tokens, i.e. local query row i sits
        at global position (kv_len - q_len) + i.
      - The final block_size query rows form the "window".  The first
        q_len - block_size rows are "recomputed" rows.
      - A recomputed row at global position g attends to keys [0, g].
      - Window rows attend FULL (bidirectionally) to the window keys
        [kv_len - block_size, kv_len).  They also see every earlier key except
        the single column kv_len - block_size - 1, which is masked out.

    When q_len == kv_len, a coarse form is emitted: one CAUSAL pair covers
    the whole prefix.  Otherwise each recomputed row gets its own single-row
    FULL pair, so the visible key prefix grows by exactly one key per row.

    Args:
        kv_len: Total number of key/value positions.
        q_len: Number of query rows; requires 0 < q_len <= kv_len.
        block_size: Window size; requires 0 < block_size <= q_len.  This is
            not checked when ar_decode is True.
        ar_decode: If True, ignore block_size and return a single CAUSAL pair
            joining all q_len query rows with all kv_len keys.
            NOTE(review): when q_len < kv_len this depends on how
            MagiAttention aligns a causal mask between ranges of different
            lengths — confirm.
        device: Device on which the returned tensors are allocated.

    Returns:
        A dict with three entries:
          - "q_ranges": int32, shape [n, 2].
          - "k_ranges": int32, shape [n, 2].
          - "attn_type_map": int32, shape [n].
        Row j of each tensor describes the j-th (q_range, k_range, type)
        triple.

    Raises:
        AssertionError: If the length constraints above are violated.
            Asserts are skipped under ``python -O``.
    """

    def _pack(segments):
        # Unzip (q_range, k_range, attn_type) triples into contiguous int32 tensors.
        def _as_int32(values):
            return torch.tensor(values, dtype=torch.int32, device=device).contiguous()

        return {
            "q_ranges": _as_int32([q_rng for q_rng, _, _ in segments]),
            "k_ranges": _as_int32([k_rng for _, k_rng, _ in segments]),
            "attn_type_map": _as_int32([kind for _, _, kind in segments]),
        }

    assert 0 < q_len <= kv_len

    if ar_decode:
        # Autoregressive decode: a single causal pair over all queries and keys.
        return _pack([([0, q_len], [0, kv_len], CAUSAL)])

    assert 0 < block_size <= q_len <= kv_len

    window_lo = kv_len - block_size          # first key column of the trailing window
    masked_col = window_lo - 1               # key column hidden from the window rows
    window_q = [q_len - block_size, q_len]   # local query rows that form the window

    if q_len == kv_len:
        # Queries and keys are aligned (local row == global position), so the
        # whole prefix collapses into one CAUSAL pair; skipped when it is empty.
        segments = [([0, window_lo], [0, window_lo], CAUSAL)] if window_lo > 0 else []
    else:
        # One single-row FULL pair per recomputed row; row `row` sees keys
        # [0, global_position(row)] inclusive.
        q_offset = kv_len - q_len
        segments = [
            ([row, row + 1], [0, q_offset + row + 1], FULL)
            for row in range(q_len - block_size)
        ]

    # Window rows see every prefix key before the masked column...
    # (masked_col > 0 implies window_lo > 0, so the coarse branch needs no extra check.)
    if masked_col > 0:
        segments.append((window_q, [0, masked_col], FULL))
    # ...and the whole window bidirectionally.
    segments.append((window_q, [window_lo, kv_len], FULL))

    return _pack(segments)