"""Attention implementations supporting PyTorch, XFormers, and SageAttention."""
try:
    import sageattention
except ImportError:
    sageattention = None

try:
    import spas_sage_attn
except ImportError:
    spas_sage_attn = None

try:
    import xformers
    BROKEN_XFORMERS = xformers.__version__.startswith("0.0.2") and not xformers.__version__.startswith("0.0.20")
except ImportError:
    xformers = None
    BROKEN_XFORMERS = False

import torch
import torch.nn.functional as F

# Pre-computed padding targets for SageAttention supported dimensions
# Maps dimension -> (target_dim, padding_amount) or None if no padding needed
_SAGE_PAD_CACHE: dict[int, tuple[int, int] | None] = {}


def _get_sage_padding(dim: int) -> tuple[int, int] | None:
    """Get pre-computed padding target for a given dimension.
    
    Returns (target_dim, pad_amount) or None if no padding needed.
    """
    if dim not in _SAGE_PAD_CACHE:
        if dim in (64, 96, 128):
            _SAGE_PAD_CACHE[dim] = None  # No padding needed
        elif dim < 64:
            _SAGE_PAD_CACHE[dim] = (64, 64 - dim)
        elif dim < 128:
            _SAGE_PAD_CACHE[dim] = (128, 128 - dim)
        else:
            _SAGE_PAD_CACHE[dim] = None  # Unsupported, no padding
    return _SAGE_PAD_CACHE[dim]
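

# Example mapping (a sketch of the cache behavior implied by the rules above;
# purely illustrative, not used by the module):
#   _get_sage_padding(64)  -> None        # already a supported dimension
#   _get_sage_padding(48)  -> (64, 16)    # pad 48 up to 64
#   _get_sage_padding(72)  -> (128, 56)   # pad 72 up to 128
#   _get_sage_padding(160) -> None        # above 128: unsupported, left unpadded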


def _pad_for_sage(q, k, v, dim):
    """Pad the head dimension to a SageAttention-supported size (64, 96, 128).

    Returns the (possibly padded) tensors together with the resulting head
    dimension; dimensions above 128 are unsupported and returned unchanged.
    """
    padding = _get_sage_padding(dim)
    if padding is None:
        return q, k, v, dim
    target, pad = padding
    return F.pad(q, (0, pad)), F.pad(k, (0, pad)), F.pad(v, (0, pad)), target


def _reshape_for_heads(q, k, v, heads, flux=False, skip_reshape=False):
    """Reshape tensors for multi-head attention."""
    if flux and skip_reshape:
        return q, k, v, q.shape[-1]
    b = q.shape[0]
    dim_head = q.shape[-1] // heads
    reshape_fn = lambda t: t.view(b, -1, heads, dim_head).transpose(1, 2).contiguous()
    return reshape_fn(q), reshape_fn(k), reshape_fn(v), dim_head


def _reshape_output(out, b, heads, dim_head, flux=False, skip_reshape=False):
    """Reshape attention output from (b, heads, seq, dim_head) back to (b, seq, heads * dim_head).

    The same reshape applies on every path; the flux/skip_reshape flags are
    accepted only for interface symmetry with _reshape_for_heads.
    """
    return out.transpose(1, 2).reshape(b, -1, heads * dim_head)


def attention_pytorch(q, k, v, heads, mask=None, skip_reshape=False, flux=False):
    """Multi-head attention using PyTorch SDPA."""
    b = q.shape[0]
    if not flux:
        seq_q, seq_kv = q.shape[1], k.shape[1]
        dim_head = q.shape[-1] // heads
        q = q.view(b, seq_q, heads, dim_head).transpose(1, 2)
        k = k.view(b, seq_kv, heads, dim_head).transpose(1, 2)
        v = v.view(b, seq_kv, heads, dim_head).transpose(1, 2)
        out = F.scaled_dot_product_attention(q, k, v, attn_mask=mask, dropout_p=0.0, is_causal=False)
        return out.transpose(1, 2).reshape(b, seq_q, heads * dim_head)
    
    dim_head = q.shape[-1] if skip_reshape else q.shape[-1] // heads
    if not skip_reshape:
        q, k, v = [t.view(b, -1, heads, dim_head).transpose(1, 2) for t in (q, k, v)]
    out = F.scaled_dot_product_attention(q, k, v, attn_mask=mask, dropout_p=0.0, is_causal=False)
    return out.transpose(1, 2).reshape(b, -1, heads * dim_head)
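

# A minimal shape sketch (illustrative only, not used by the module): in the
# default non-flux path q, k, v are (batch, seq, heads * dim_head) and the
# output keeps that layout:
#   q = k = v = torch.randn(2, 77, 8 * 64)
#   attention_pytorch(q, k, v, heads=8).shape  # -> torch.Size([2, 77, 512])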


def attention_xformers(q, k, v, heads, mask=None, skip_reshape=False, flux=False):
    """Multi-head attention using XFormers."""
    b = q.shape[0]
    if not flux:
        dim_head = q.shape[-1] // heads
        q, k, v = [t.view(b, -1, heads, dim_head).permute(0, 2, 1, 3).reshape(b * heads, -1, dim_head).contiguous()
                   for t in (q, k, v)]
        try:
            out = xformers.ops.memory_efficient_attention(q, k, v, attn_bias=mask)
        except (NotImplementedError, RuntimeError):
            out = F.scaled_dot_product_attention(
                q.view(b, heads, -1, dim_head), k.view(b, heads, -1, dim_head), v.view(b, heads, -1, dim_head),
                attn_mask=mask, dropout_p=0.0, is_causal=False).reshape(b * heads, -1, dim_head)
        return out.view(b, heads, -1, dim_head).permute(0, 2, 1, 3).reshape(b, -1, heads * dim_head)
    
    dim_head = q.shape[-1] if skip_reshape else q.shape[-1] // heads
    if BROKEN_XFORMERS and b * heads > 65535:
        return attention_pytorch(q, k, v, heads, mask, skip_reshape, flux)
    
    if skip_reshape:
        q, k, v = [t.reshape(b * heads, -1, dim_head) for t in (q, k, v)]
    else:
        q, k, v = [t.reshape(b, -1, heads, dim_head) for t in (q, k, v)]
    
    try:
        out = xformers.ops.memory_efficient_attention(q, k, v, attn_bias=mask)
    except (NotImplementedError, RuntimeError):
        # The SDPA fallback must match the layout chosen above:
        # (b*heads, seq, dim) when skip_reshape, otherwise (b, seq, heads, dim).
        if skip_reshape:
            out = F.scaled_dot_product_attention(
                q.view(b, heads, -1, dim_head), k.view(b, heads, -1, dim_head), v.view(b, heads, -1, dim_head),
                attn_mask=mask, dropout_p=0.0, is_causal=False).reshape(b * heads, -1, dim_head)
        else:
            out = F.scaled_dot_product_attention(
                q.transpose(1, 2), k.transpose(1, 2), v.transpose(1, 2),
                attn_mask=mask, dropout_p=0.0, is_causal=False).transpose(1, 2)
    
    if skip_reshape:
        return out.view(b, heads, -1, dim_head).permute(0, 2, 1, 3).reshape(b, -1, heads * dim_head)
    return out.reshape(b, -1, heads * dim_head)


def attention_sage(q, k, v, heads, mask=None, skip_reshape=False, flux=False):
    """Multi-head attention using SageAttention."""
    if mask is not None and mask.device != q.device:
        mask = mask.to(q.device)
    
    b = q.shape[0]
    dim_head = q.shape[-1] if (flux and skip_reshape) else q.shape[-1] // heads
    
    if not (flux and skip_reshape):
        if not flux:
            q, k, v = [t.unsqueeze(3).reshape(b, -1, heads, dim_head).permute(0, 2, 1, 3).contiguous()
                       for t in (q, k, v)]
        else:
            q, k, v = [t.reshape(b, -1, heads, dim_head).transpose(1, 2) for t in (q, k, v)]
    
    # Pad to a supported head dim, run SageAttention, then slice the padding back off
    qp, kp, vp, padded_dim = _pad_for_sage(q, k, v, dim_head)
    if dim_head > 128:
        # Head dims above 128 are not handled by the padding helper; fall back to XFormers, then SDPA
        try:
            out = xformers.ops.memory_efficient_attention(
                q.reshape(b * heads, -1, dim_head), k.reshape(b * heads, -1, dim_head),
                v.reshape(b * heads, -1, dim_head), attn_bias=mask)
            out = out.reshape(b, heads, -1, dim_head)
        except Exception:
            out = F.scaled_dot_product_attention(q, k, v, attn_mask=mask, dropout_p=0.0, is_causal=False)
    else:
        out = sageattention.sageattn(qp, kp, vp, tensor_layout="HND", attn_mask=mask, is_causal=False)
        if padded_dim != dim_head:
            out = out[..., :dim_head]
    
    if not flux:
        return out.reshape(b, heads, -1, dim_head).permute(0, 2, 1, 3).reshape(b, -1, heads * dim_head)
    return out.transpose(1, 2).reshape(b, -1, heads * dim_head)
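

# Illustrative padding path (a sketch; assumes a CUDA device and fp16/bf16
# inputs, which SageAttention expects). With a head dim of 72, q/k/v are padded
# to 128 channels before sageattn and the output is sliced back to 72:
#   q = k = v = torch.randn(2, 256, 8 * 72, device="cuda", dtype=torch.float16)
#   attention_sage(q, k, v, heads=8).shape  # -> torch.Size([2, 256, 576])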


def attention_sparge(q, k, v, heads, mask=None, skip_reshape=False, flux=False):
    """Multi-head attention using SpargeAttn (Sparse + SageAttention)."""
    b = q.shape[0]
    dim_head = q.shape[-1] if (flux and skip_reshape) else q.shape[-1] // heads
    
    if not (flux and skip_reshape):
        if not flux:
            q, k, v = [t.unsqueeze(3).reshape(b, -1, heads, dim_head).permute(0, 2, 1, 3).contiguous()
                       for t in (q, k, v)]
        else:
            q, k, v = [t.reshape(b, -1, heads, dim_head).transpose(1, 2) for t in (q, k, v)]
    
    # Pad to a supported head dim, run SpargeAttn, then slice the padding back off
    qp, kp, vp, padded_dim = _pad_for_sage(q, k, v, dim_head)
    sparge_kwargs = dict(simthreshd1=0.6, cdfthreshd=0.97, pvthreshd=15, is_causal=False)

    if dim_head > 128:
        # Head dims above 128 are not handled by the padding helper; fall back to plain SageAttention
        out = sageattention.sageattn(q, k, v, tensor_layout="HND", attn_mask=mask, is_causal=False)
    else:
        out = spas_sage_attn.spas_sage2_attn_meansim_cuda(qp, kp, vp, **sparge_kwargs)
        if padded_dim != dim_head:
            out = out[..., :dim_head]
    
    if not flux:
        return out.reshape(b, heads, -1, dim_head).permute(0, 2, 1, 3).reshape(b, -1, heads * dim_head)
    return out.transpose(1, 2).reshape(b, -1, heads * dim_head)


# Simple 4D attention variants (B, C, H, W format)
def sage_attention(q, k, v):
    """SageAttention for 4D tensors (B, C, H, W)."""
    B, C, H, W = q.shape
    q, k, v = [t.view(B, 1, C, -1).transpose(2, 3).contiguous() for t in (q, k, v)]
    qp, kp, vp, padded_dim = _pad_for_sage(q, k, v, C)
    if C > 128:
        # Channel dims above 128 are not handled by the padding helper; fall back to SDPA
        out = F.scaled_dot_product_attention(q, k, v, dropout_p=0.0, is_causal=False)
    else:
        out = sageattention.sageattn(qp, kp, vp, tensor_layout="HND", is_causal=False)
        if padded_dim != C:
            out = out[..., :C]
    return out.transpose(2, 3).reshape(B, C, H, W)


def sparge_attention(q, k, v):
    """SpargeAttn for 4D tensors (B, C, H, W)."""
    B, C, H, W = q.shape
    q, k, v = [t.view(B, 1, C, -1).transpose(2, 3).contiguous() for t in (q, k, v)]
    qp, kp, vp, padded_dim = _pad_for_sage(q, k, v, C)
    sparge_kwargs = dict(simthreshd1=0.6, cdfthreshd=0.97, pvthreshd=15, is_causal=False)
    if C > 128:
        # Channel dims above 128 are not handled by the padding helper; fall back to SageAttention
        out = sageattention.sageattn(q, k, v, tensor_layout="HND", is_causal=False)
    else:
        out = spas_sage_attn.spas_sage2_attn_meansim_cuda(qp, kp, vp, **sparge_kwargs)
        if padded_dim != C:
            out = out[..., :C]
    return out.transpose(2, 3).reshape(B, C, H, W)


def xformers_attention(q, k, v):
    """XFormers attention for 4D tensors (B, C, H, W)."""
    B, C, H, W = q.shape
    q, k, v = [t.view(B, C, -1).transpose(1, 2).contiguous() for t in (q, k, v)]
    try:
        out = xformers.ops.memory_efficient_attention(q, k, v)
    except (NotImplementedError, RuntimeError):
        out = F.scaled_dot_product_attention(q.unsqueeze(1), k.unsqueeze(1), v.unsqueeze(1), dropout_p=0.0, is_causal=False).squeeze(1)
    return out.transpose(1, 2).reshape(B, C, H, W)


def pytorch_attention(q, k, v):
    """PyTorch attention for 4D tensors (B, C, H, W)."""
    B, C, H, W = q.shape
    q, k, v = [t.view(B, 1, C, -1).transpose(2, 3).contiguous() for t in (q, k, v)]
    out = F.scaled_dot_product_attention(q, k, v, dropout_p=0.0, is_causal=False)
    return out.transpose(2, 3).reshape(B, C, H, W)
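

if __name__ == "__main__":
    # Minimal CPU smoke test (a sketch, not part of the original module). The
    # SageAttention / SpargeAttn paths are skipped because they require CUDA.
    q = k = v = torch.randn(1, 64, 8 * 64)
    print(attention_pytorch(q, k, v, heads=8).shape)    # torch.Size([1, 64, 512])
    if xformers is not None:
        # Falls back to SDPA internally if no xformers kernel is available.
        print(attention_xformers(q, k, v, heads=8).shape)
    x = torch.randn(1, 32, 16, 16)
    print(pytorch_attention(x, x, x).shape)             # torch.Size([1, 32, 16, 16])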