# Adapted from https://github.com/hao-ai-lab/FastVideo/tree/main/fastvideo/attention
import os
import sys
from typing import Optional

import torch
from einops import rearrange

_FLASH_ATTN_IMPORT_ERROR = None
try:
    # Check for a Flash Attention 3 installation path supplied via environment variable
    flash_attn3_path = os.getenv("FLASH_ATTN3_PATH")
    if flash_attn3_path:
        print(f"Using Flash Attention 3 from: {flash_attn3_path}")
        sys.path.insert(0, flash_attn3_path)
        from flash_attn_interface import flash_attn_varlen_func
    else:
        from flash_attn.flash_attn_interface import flash_attn_varlen_func
except ImportError as exc:
    flash_attn_varlen_func = None
    _FLASH_ATTN_IMPORT_ERROR = exc
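
# Usage sketch: to use a local Flash Attention 3 build, point FLASH_ATTN3_PATH
# at the directory that contains flash_attn_interface.py (the path below is a
# hypothetical example, not a required location):
#
#   export FLASH_ATTN3_PATH=/opt/flash-attention/hopper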

def is_flash_attn_available() -> bool:
    return flash_attn_varlen_func is not None


def get_preferred_attention_backend() -> str:
    return "flash_attn" if is_flash_attn_available() else "torch_sdpa"


def describe_attention_backend() -> str:
    backend = get_preferred_attention_backend()
    if backend == "flash_attn":
        return "flash_attn"
    if _FLASH_ATTN_IMPORT_ERROR is None:
        return "torch_sdpa"
    return f"torch_sdpa (flash_attn unavailable: {_FLASH_ATTN_IMPORT_ERROR})"

def get_cu_seqlens(text_mask, img_len):
    """Calculate cu_seqlens_q and cu_seqlens_kv for flash attention from text_mask and img_len.

    Args:
        text_mask (torch.Tensor): mask of valid text tokens, shape [batch_size, text_max_len]
        img_len (int): the number of image tokens per sample

    Returns:
        torch.Tensor: cumulative sequence lengths for flash attention,
            shape [2 * batch_size + 1], dtype int32
    """
    batch_size = text_mask.shape[0]
    text_len = text_mask.sum(dim=1)
    max_len = text_mask.shape[1] + img_len
    # The flash attention kernels expect CUDA tensors
    cu_seqlens = torch.zeros([2 * batch_size + 1],
                             dtype=torch.int32, device="cuda")
    for i in range(batch_size):
        s = text_len[i] + img_len
        s1 = i * max_len + s    # end of the valid (text + image) tokens
        s2 = (i + 1) * max_len  # end of the padded slot for this sample
        cu_seqlens[2 * i + 1] = s1
        cu_seqlens[2 * i + 2] = s2
    return cu_seqlens
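
# A worked example of the segment layout get_cu_seqlens encodes. The numbers
# below assume a hypothetical batch of 2 with text_mask of shape [2, 6] and
# img_len=4 (so max_len = 6 + 4 = 10):
#
#   text_mask = torch.tensor([[1, 1, 1, 0, 0, 0],
#                             [1, 1, 1, 1, 1, 0]], device="cuda")
#   get_cu_seqlens(text_mask, img_len=4)
#   # text_len = [3, 5]
#   # -> tensor([ 0,  7, 10, 19, 20], dtype=torch.int32)
#   # Sample 0 occupies [0, 7) (3 text + 4 image tokens) and [7, 10) (padding);
#   # sample 1 occupies [10, 19) and [19, 20). Each sample thus contributes two
#   # segments: one of valid tokens and one of padding.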

def attention(
    q: torch.Tensor,
    k: torch.Tensor,
    v: torch.Tensor,
    backend: str = "flash_attn",
    *,
    causal: bool = False,
    softmax_scale: Optional[float] = None,
    attn_kwargs: Optional[dict] = None,
):
    """
    Args:
        q (torch.Tensor): Query tensor of shape [batch_size, seq_len, num_heads, head_dim]
        k (torch.Tensor): Key tensor of shape [batch_size, seq_len, num_heads, head_dim]
        v (torch.Tensor): Value tensor of shape [batch_size, seq_len, num_heads, head_dim]
        backend (str): "flash_attn", "torch_sdpa", or "auto"
        causal (bool): whether to apply a causal attention mask
        softmax_scale (Optional[float]): scale applied to Q @ K^T before the softmax;
            defaults to 1 / sqrt(head_dim)
        attn_kwargs (Optional[dict]): required for "flash_attn"; must contain
            cu_seqlens_q, cu_seqlens_kv, max_seqlen_q, and max_seqlen_kv

    Returns:
        torch.Tensor: attention output of shape [batch_size, seq_len, num_heads, head_dim]
    """
    if backend == "auto":
        backend = get_preferred_attention_backend()
    # Fall back to torch_sdpa when flash_attn was requested but is unavailable
    if backend == "flash_attn" and flash_attn_varlen_func is None:
        backend = "torch_sdpa"
    assert backend in [
        "torch_sdpa", "flash_attn"], f"Unsupported attention backend: {backend}"
    assert q.dim() == 4 and k.dim() == 4 and v.dim() == 4, "Input tensors must be 4D"
    batch_size = q.shape[0]
    if backend == "torch_sdpa":
        # SDPA expects [batch, heads, seq, dim]
        q = rearrange(q, "b l h c -> b h l c")
        k = rearrange(k, "b l h c -> b h l c")
        v = rearrange(v, "b l h c -> b h l c")
        output = torch.nn.functional.scaled_dot_product_attention(
            q, k, v, is_causal=causal, scale=softmax_scale)
        output = rearrange(output, "b h l c -> b l h c")
    elif backend == "flash_attn":
        assert attn_kwargs is not None, \
            "flash_attn backend requires attn_kwargs with cu_seqlens and max_seqlen entries"
        cu_seqlens_q = attn_kwargs["cu_seqlens_q"]
        cu_seqlens_kv = attn_kwargs["cu_seqlens_kv"]
        max_seqlen_q = attn_kwargs["max_seqlen_q"]
        max_seqlen_kv = attn_kwargs["max_seqlen_kv"]
        # flash_attn_varlen_func consumes packed [total_tokens, num_heads, head_dim] inputs
        x = flash_attn_varlen_func(
            q.view(q.shape[0] * q.shape[1], *q.shape[2:]),
            k.view(k.shape[0] * k.shape[1], *k.shape[2:]),
            v.view(v.shape[0] * v.shape[1], *v.shape[2:]),
            cu_seqlens_q,
            cu_seqlens_kv,
            max_seqlen_q,
            max_seqlen_kv,
            softmax_scale=softmax_scale,
            causal=causal,
        )
        # Some Flash Attention 3 builds return (out, softmax_lse)
        if isinstance(x, tuple):
            x = x[0]
        output = x.view(
            batch_size, max_seqlen_q, x.shape[-2], x.shape[-1]
        )
    return output
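

# A minimal usage sketch for both backends. The shapes, the text mask, and the
# __main__ guard below are illustrative assumptions, not part of the adapted
# FastVideo module.
if __name__ == "__main__" and torch.cuda.is_available():
    print(describe_attention_backend())
    b, text_max_len, img_len, num_heads, head_dim = 2, 6, 4, 8, 64
    seq_len = text_max_len + img_len
    q = torch.randn(b, seq_len, num_heads, head_dim,
                    device="cuda", dtype=torch.bfloat16)
    k, v = torch.randn_like(q), torch.randn_like(q)

    # torch SDPA needs no extra metadata
    out_sdpa = attention(q, k, v, backend="torch_sdpa")

    # flash_attn needs variable-length metadata packed into attn_kwargs
    if is_flash_attn_available():
        text_mask = torch.tensor([[1, 1, 1, 0, 0, 0],
                                  [1, 1, 1, 1, 1, 0]], device="cuda")
        attn_kwargs = {
            "cu_seqlens_q": get_cu_seqlens(text_mask, img_len),
            "cu_seqlens_kv": get_cu_seqlens(text_mask, img_len),
            "max_seqlen_q": seq_len,
            "max_seqlen_kv": seq_len,
        }
        out_flash = attention(q, k, v, backend="flash_attn",
                              attn_kwargs=attn_kwargs)
        print(out_sdpa.shape, out_flash.shape)  # both [2, 10, 8, 64]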