import os

import torch
from einops import rearrange
|
|
|
|
# Optional attention backends are probed once at import time; each flag
# records whether the corresponding package is installed.
try:
    import flash_attn_interface  # FlashAttention-3
    FLASH_ATTN_3_AVAILABLE = True
except ModuleNotFoundError:
    FLASH_ATTN_3_AVAILABLE = False
|
|
| try: |
    import flash_attn  # FlashAttention-2
| FLASH_ATTN_2_AVAILABLE = True |
| except ModuleNotFoundError: |
| FLASH_ATTN_2_AVAILABLE = False |
|
|
| try: |
    from sageattention import sageattn  # SageAttention
| SAGE_ATTN_AVAILABLE = True |
| except ModuleNotFoundError: |
| SAGE_ATTN_AVAILABLE = False |
|
|
| try: |
    import xformers.ops as xops  # xFormers memory-efficient attention
| XFORMERS_AVAILABLE = True |
| except ModuleNotFoundError: |
| XFORMERS_AVAILABLE = False |
|
|
|
|
def initialize_attention_priority():
    """Pick the attention backend: an explicit environment override wins;
    otherwise fall through the available implementations, fastest first."""
    override = os.environ.get("DIFFSYNTH_ATTENTION_IMPLEMENTATION")
    if override is not None:
        return override.lower()
| elif FLASH_ATTN_3_AVAILABLE: |
| return "flash_attention_3" |
| elif FLASH_ATTN_2_AVAILABLE: |
| return "flash_attention_2" |
| elif SAGE_ATTN_AVAILABLE: |
| return "sage_attention" |
| elif XFORMERS_AVAILABLE: |
| return "xformers" |
| else: |
| return "torch" |
|
|
|
|
# Resolved once at import time; attention_forward dispatches on this value.
ATTENTION_IMPLEMENTATION = initialize_attention_priority()
|
|
|
|
def rearrange_qkv(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, q_pattern="b n s d", k_pattern="b n s d", v_pattern="b n s d", required_in_pattern="b n s d", dims=None):
    """Rearrange q, k, v from their declared einops patterns into the layout
    a backend requires; tensors whose pattern already matches are untouched."""
| dims = {} if dims is None else dims |
| if q_pattern != required_in_pattern: |
| q = rearrange(q, f"{q_pattern} -> {required_in_pattern}", **dims) |
| if k_pattern != required_in_pattern: |
| k = rearrange(k, f"{k_pattern} -> {required_in_pattern}", **dims) |
| if v_pattern != required_in_pattern: |
| v = rearrange(v, f"{v_pattern} -> {required_in_pattern}", **dims) |
| return q, k, v |
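
# Usage sketch (hypothetical shapes): decomposing a packed "(b n) s d" layout
# into "b n s d" requires the head count via `dims`, e.g.
#
#     q = torch.randn(2 * 8, 77, 64)   # "(b n) s d": batch 2, 8 heads packed
#     k, v = torch.randn_like(q), torch.randn_like(q)
#     q, k, v = rearrange_qkv(q, k, v, q_pattern="(b n) s d",
#                             k_pattern="(b n) s d", v_pattern="(b n) s d",
#                             dims={"n": 8})
#     # each is now shape (2, 8, 77, 64), i.e. "b n s d"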
|
|
|
|
def rearrange_out(out: torch.Tensor, out_pattern="b n s d", required_out_pattern="b n s d", dims=None):
    """Rearrange the backend's output layout back into the caller's requested pattern."""
| dims = {} if dims is None else dims |
| if out_pattern != required_out_pattern: |
| out = rearrange(out, f"{required_out_pattern} -> {out_pattern}", **dims) |
| return out |
|
|
|
|
def torch_sdpa(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, q_pattern="b n s d", k_pattern="b n s d", v_pattern="b n s d", out_pattern="b n s d", dims=None, attn_mask=None, scale=None):
    """PyTorch scaled_dot_product_attention; the only backend here that accepts
    an attention mask. Works in "b n s d" layout."""
    required_in_pattern, required_out_pattern = "b n s d", "b n s d"
    q, k, v = rearrange_qkv(q, k, v, q_pattern, k_pattern, v_pattern, required_in_pattern, dims)
    out = torch.nn.functional.scaled_dot_product_attention(q, k, v, attn_mask=attn_mask, scale=scale)
| out = rearrange_out(out, out_pattern, required_out_pattern, dims) |
| return out |
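
# Usage sketch (hypothetical mask): torch_sdpa is the only wrapper that takes
# attn_mask, which is why attention_forward below routes masked calls here, e.g.
#
#     mask = torch.ones(2, 8, 77, 77, dtype=torch.bool)  # allow all positions
#     out = torch_sdpa(q, k, v, attn_mask=mask)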
|
|
|
|
def flash_attention_3(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, q_pattern="b n s d", k_pattern="b n s d", v_pattern="b n s d", out_pattern="b n s d", dims=None, scale=None):
    """FlashAttention-3; works in "b s n d" layout. Some versions of
    flash_attn_interface return (out, softmax_lse), hence the tuple unpack."""
    required_in_pattern, required_out_pattern = "b s n d", "b s n d"
| q, k, v = rearrange_qkv(q, k, v, q_pattern, k_pattern, v_pattern, required_in_pattern, dims) |
| out = flash_attn_interface.flash_attn_func(q, k, v, softmax_scale=scale) |
| if isinstance(out, tuple): |
| out = out[0] |
| out = rearrange_out(out, out_pattern, required_out_pattern, dims) |
| return out |
|
|
|
|
def flash_attention_2(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, q_pattern="b n s d", k_pattern="b n s d", v_pattern="b n s d", out_pattern="b n s d", dims=None, scale=None):
    """FlashAttention-2; works in "b s n d" layout."""
    required_in_pattern, required_out_pattern = "b s n d", "b s n d"
| q, k, v = rearrange_qkv(q, k, v, q_pattern, k_pattern, v_pattern, required_in_pattern, dims) |
| out = flash_attn.flash_attn_func(q, k, v, softmax_scale=scale) |
| out = rearrange_out(out, out_pattern, required_out_pattern, dims) |
| return out |
|
|
|
|
def sage_attention(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, q_pattern="b n s d", k_pattern="b n s d", v_pattern="b n s d", out_pattern="b n s d", dims=None, scale=None):
    """SageAttention; works in "b n s d" layout (sageattn's default "HND" tensor layout)."""
    required_in_pattern, required_out_pattern = "b n s d", "b n s d"
| q, k, v = rearrange_qkv(q, k, v, q_pattern, k_pattern, v_pattern, required_in_pattern, dims) |
| out = sageattn(q, k, v, sm_scale=scale) |
| out = rearrange_out(out, out_pattern, required_out_pattern, dims) |
| return out |
|
|
|
|
def xformers_attention(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, q_pattern="b n s d", k_pattern="b n s d", v_pattern="b n s d", out_pattern="b n s d", dims=None, scale=None):
    """xFormers memory-efficient attention; works in "b s n d" layout."""
    required_in_pattern, required_out_pattern = "b s n d", "b s n d"
| q, k, v = rearrange_qkv(q, k, v, q_pattern, k_pattern, v_pattern, required_in_pattern, dims) |
| out = xops.memory_efficient_attention(q, k, v, scale=scale) |
| out = rearrange_out(out, out_pattern, required_out_pattern, dims) |
| return out |
|
|
|
|
def attention_forward(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, q_pattern="b n s d", k_pattern="b n s d", v_pattern="b n s d", out_pattern="b n s d", dims=None, attn_mask=None, scale=None, compatibility_mode=False):
    """Dispatch to the selected attention backend. Falls back to torch SDPA
    when compatibility_mode is set or a mask is supplied, since none of the
    optimized backends here accept an attention mask."""
| if compatibility_mode or (attn_mask is not None): |
| return torch_sdpa(q, k, v, q_pattern, k_pattern, v_pattern, out_pattern, dims, attn_mask=attn_mask, scale=scale) |
| else: |
| if ATTENTION_IMPLEMENTATION == "flash_attention_3": |
| return flash_attention_3(q, k, v, q_pattern, k_pattern, v_pattern, out_pattern, dims, scale=scale) |
| elif ATTENTION_IMPLEMENTATION == "flash_attention_2": |
| return flash_attention_2(q, k, v, q_pattern, k_pattern, v_pattern, out_pattern, dims, scale=scale) |
| elif ATTENTION_IMPLEMENTATION == "sage_attention": |
| return sage_attention(q, k, v, q_pattern, k_pattern, v_pattern, out_pattern, dims, scale=scale) |
| elif ATTENTION_IMPLEMENTATION == "xformers": |
| return xformers_attention(q, k, v, q_pattern, k_pattern, v_pattern, out_pattern, dims, scale=scale) |
| else: |
| return torch_sdpa(q, k, v, q_pattern, k_pattern, v_pattern, out_pattern, dims, scale=scale) |
|
|
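if __name__ == "__main__":
    # Minimal smoke test (a sketch, not part of the library): run whichever
    # backend was selected on random tensors and check the output shape.
    # The fp16 choice is an assumption -- the flash/sage backends require
    # half precision on CUDA, while the torch fallback also runs on CPU.
    device = "cuda" if torch.cuda.is_available() else "cpu"
    dtype = torch.float16 if device == "cuda" else torch.float32
    q = torch.randn(2, 8, 77, 64, device=device, dtype=dtype)  # "b n s d"
    k = torch.randn_like(q)
    v = torch.randn_like(q)
    out = attention_forward(q, k, v)
    print(ATTENTION_IMPLEMENTATION, tuple(out.shape))  # expect (2, 8, 77, 64)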