File size: 3,954 Bytes
fcfea15
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
# Adapted from https://github.com/hao-ai-lab/FastVideo/tree/main/fastvideo/attention

import os
import sys
import torch
from einops import rearrange

# ImportError captured while trying to import flash attention, or None when
# the import succeeded. describe_attention_backend() includes it in its label.
_FLASH_ATTN_IMPORT_ERROR = None

# Optional dependency: bind `flash_attn_varlen_func` to the flash attention
# varlen kernel when it can be imported, otherwise to None.
try:
    # Check for Flash Attention 3 installation path
    flash_attn3_path = os.getenv("FLASH_ATTN3_PATH")
    if flash_attn3_path:
        print(f"Using Flash Attention 3 from: {flash_attn3_path}")
        # Put the user-supplied directory first on sys.path so the top-level
        # `flash_attn_interface` module is resolved from there.
        sys.path.insert(0, flash_attn3_path)
        from flash_attn_interface import flash_attn_varlen_func
    else:
        # No override set: import from the installed `flash_attn` package.
        from flash_attn.flash_attn_interface import flash_attn_varlen_func
except ImportError as exc:
    # Keep the module importable without flash attention; remember the
    # failure so it can be reported later.
    flash_attn_varlen_func = None
    _FLASH_ATTN_IMPORT_ERROR = exc



def is_flash_attn_available() -> bool:
    """Tell whether the flash attention varlen kernel was imported.

    Returns:
        bool: True when ``flash_attn_varlen_func`` is bound to something
        other than None, False otherwise.
    """
    kernel = flash_attn_varlen_func
    return kernel is not None


def get_preferred_attention_backend() -> str:
    """Name the attention backend to use by default.

    Returns:
        str: ``"flash_attn"`` when flash attention is available,
        ``"torch_spda"`` otherwise.
    """
    if is_flash_attn_available():
        return "flash_attn"
    return "torch_spda"


def describe_attention_backend() -> str:
    """Build a label for the preferred attention backend.

    Returns:
        str: ``"flash_attn"`` when that backend is preferred. Otherwise
        ``"torch_spda"``, with the recorded flash attention import error
        appended in parentheses when one was captured.
    """
    chosen = get_preferred_attention_backend()
    if chosen != "flash_attn":
        if _FLASH_ATTN_IMPORT_ERROR is not None:
            # Surface why flash attention could not be used.
            return f"torch_spda (flash_attn unavailable: {_FLASH_ATTN_IMPORT_ERROR})"
        return "torch_spda"
    return "flash_attn"


def get_cu_seqlens(text_mask, img_len, device=None):
    """Calculate cu_seqlens_q, cu_seqlens_kv using text_mask and img_len

    Every sample owns a window of ``max_len = text_mask.shape[1] + img_len``
    positions in the flattened batch. For sample ``i`` two boundaries are
    emitted: ``i * max_len + text_len[i] + img_len`` and ``(i + 1) * max_len``,
    where ``text_len = text_mask.sum(dim=1)``. Each window is therefore split
    into a segment of ``text_len[i] + img_len`` positions followed by the rest
    of the window (presumably padding -- confirm against the caller).

    Args:
        text_mask (torch.Tensor): the mask of text, indexed as
            [batch_size, text_seq_len]; it is summed along dim 1.
        img_len (int): the length of image
        device (torch.device | str | None): device of the returned tensor.
            When None, the device of ``text_mask`` is used if the mask is on
            CUDA or if CUDA is unavailable; otherwise ``"cuda"`` is used, which
            matches the previous hard-coded behaviour.

    Returns:
        torch.Tensor: the calculated cu_seqlens for flash attention, an int32
        tensor of length ``2 * batch_size + 1`` whose first entry is 0.
    """
    batch_size = text_mask.shape[0]
    max_len = text_mask.shape[1] + img_len

    if device is None:
        if text_mask.is_cuda or not torch.cuda.is_available():
            # Follow the mask: avoids a crash on hosts without CUDA and keeps
            # the result on the same GPU as the mask on multi-GPU hosts.
            device = text_mask.device
        else:
            # CPU mask on a CUDA host: keep the historical default.
            device = "cuda"

    text_len = text_mask.sum(dim=1).to(device)

    cu_seqlens = torch.zeros([2 * batch_size + 1],
                             dtype=torch.int32, device=device)

    # Offset of each sample's window in the flattened batch.
    window_start = torch.arange(batch_size, device=device) * max_len
    # Odd slots: end of the (text + image) segment of each window.
    cu_seqlens[1::2] = window_start + text_len + img_len
    # Even slots (from index 2): end of each window.
    cu_seqlens[2::2] = window_start + max_len

    return cu_seqlens


def attention(
    q: torch.Tensor,
    k: torch.Tensor,
    v: torch.Tensor,
    backend: str = "flash_attn",
    *,
    causal: bool = False,
    softmax_scale: float = None,
    attn_kwargs: dict = None,
):
    """
    Args:
        q (torch.Tensor): Query tensor of shape [batch_size, seq_len, num_heads, head_dim]
        k (torch.Tensor): Key tensor of shape [batch_size, seq_len, num_heads, head_dim]
        v (torch.Tensor): Value tensor of shape [batch_size, seq_len, num_heads
    """
    if backend == "auto":
        backend = get_preferred_attention_backend()
    # Fall back to torch_spda when flash_attn was requested but unavailable
    if backend == "flash_attn" and flash_attn_varlen_func is None:
        backend = "torch_spda"
    assert backend in [
        "torch_spda", "flash_attn"], f"Unsupported attention backend: {backend}"
    assert q.dim() == 4 and k.dim() == 4 and v.dim() == 4, "Input tensors must be 4D"
    batch_size = q.shape[0]
    if backend == "torch_spda":
        q = rearrange(q, "b l h c -> b h l c")
        k = rearrange(k, "b l h c -> b h l c")
        v = rearrange(v, "b l h c -> b h l c")
        output = torch.nn.functional.scaled_dot_product_attention(
            q, k, v, is_causal=causal, scale=softmax_scale)
        output = rearrange(output, "b h l c -> b l h c")
    elif backend == "flash_attn":
        cu_seqlens_q = attn_kwargs['cu_seqlens_q']
        cu_seqlens_kv = attn_kwargs['cu_seqlens_kv']
        max_seqlen_q = attn_kwargs['max_seqlen_q']
        max_seqlen_kv = attn_kwargs['max_seqlen_kv']
        x = flash_attn_varlen_func(
            q.view(q.shape[0] * q.shape[1], *q.shape[2:]),
            k.view(k.shape[0] * k.shape[1], *k.shape[2:]),
            v.view(v.shape[0] * v.shape[1], *v.shape[2:]),
            cu_seqlens_q,
            cu_seqlens_kv,
            max_seqlen_q,
            max_seqlen_kv,
        )
        output = x.view(
            batch_size, max_seqlen_q, x.shape[-2], x.shape[-1]
        )

    return output