# Author: tellurion
# Commit d066167: "initialize huggingface space demo"
import torch
import torch.nn.functional as F
# Dtype that `half` converts attention inputs to when they are not already
# float16/bfloat16 (used on the flash-attn code paths).
ATTN_PRECISION = torch.float16

# Optional attention backends, probed in order of preference.
# FlashAttention-3 (`flash_attn_interface`) wins; when it imports, FlashAttention-2
# is not probed and FLASH_ATTN_AVAILABLE is forced to False.
# ImportError (the parent class of ModuleNotFoundError) is caught so that a package
# which is installed but fails to load (e.g. a compiled extension built against a
# different CUDA/torch version) is treated as unavailable instead of breaking the
# import of this whole module.
try:
    import flash_attn_interface
    FLASH_ATTN_3_AVAILABLE = True
    FLASH_ATTN_AVAILABLE = False
except ImportError:
    FLASH_ATTN_3_AVAILABLE = False
    try:
        import flash_attn
        FLASH_ATTN_AVAILABLE = True
    except ImportError:
        FLASH_ATTN_AVAILABLE = False

# xformers is probed independently of flash-attn. `except Exception` (instead of a
# bare `except:`) keeps the best-effort fallback without swallowing
# KeyboardInterrupt/SystemExit.
# NOTE: the misspelled flag name is kept because other modules may import it.
try:
    import xformers.ops
    XFORMERS_IS_AVAILBLE = True
except Exception:
    XFORMERS_IS_AVAILBLE = False
def half(x):
    """Return ``x`` in a 16-bit floating-point dtype.

    Tensors that are already float16 or bfloat16 are returned as-is (the same
    object); any other tensor is converted to the module-level ``ATTN_PRECISION``.
    """
    already_half = x.dtype in (torch.float16, torch.bfloat16)
    return x if already_half else x.to(ATTN_PRECISION)
def attn_processor(q, k, v, attn_mask=None, *args, **kwargs):
    """
    Compute attention on the first available backend.

    Backend selection:
      * with ``attn_mask``: xformers ``memory_efficient_attention`` (the mask is
        passed as ``attn_bias``), otherwise PyTorch ``scaled_dot_product_attention``
        (the mask is passed as ``attn_mask``);
      * without a mask: FlashAttention-3, then FlashAttention-2, then xformers,
        then PyTorch SDPA.

    On the flash-attn paths the inputs are converted with ``half`` and the output
    is cast back to ``v``'s original dtype; the xformers and SDPA paths do no casting.
    The SDPA paths swap dims 1 and 2 before and after the call, so q/k/v are
    presumably laid out as (batch, seq_len, heads, head_dim) — TODO confirm with callers.

    Extra ``*args``/``**kwargs`` are forwarded unchanged to whichever backend runs;
    argument names may differ between backends, so callers must pass ones the
    selected backend accepts.
    """
    def _sdpa(*extra, **extra_kw):
        # SDPA takes heads before sequence: swap dims 1 and 2 going in and coming out.
        qt, kt, vt = (t.transpose(1, 2) for t in (q, k, v))
        return F.scaled_dot_product_attention(qt, kt, vt, *extra, **extra_kw).transpose(1, 2)

    # Masked attention: flash-attn is never used on this path.
    if attn_mask is not None:
        if XFORMERS_IS_AVAILBLE:
            return xformers.ops.memory_efficient_attention(q, k, v, *args, attn_bias=attn_mask, **kwargs)
        return _sdpa(*args, attn_mask=attn_mask, **kwargs)

    if FLASH_ATTN_3_AVAILABLE or FLASH_ATTN_AVAILABLE:
        out_dtype = v.dtype  # remembered before the half-precision conversion
        hq, hk, hv = half(q), half(k), half(v)
        if FLASH_ATTN_3_AVAILABLE:
            # The FA3 interface result is indexed with [0] to take the attention output.
            result = flash_attn_interface.flash_attn_func(hq, hk, hv, *args, **kwargs)[0]
        else:
            result = flash_attn.flash_attn_func(hq, hk, hv, *args, **kwargs)
        return result.to(out_dtype)

    if XFORMERS_IS_AVAILBLE:
        return xformers.ops.memory_efficient_attention(q, k, v, *args, **kwargs)
    return _sdpa(*args, **kwargs)
def flash_attn_varlen_func(q, k, v, **kwargs):
    """
    Variable-length FlashAttention: FlashAttention-3 when available, else FlashAttention-2.

    All keyword arguments are forwarded unchanged to the selected backend's
    ``flash_attn_varlen_func``. Unlike ``attn_processor``, inputs are not converted
    to half precision here — presumably callers already pass fp16/bf16 tensors
    (TODO confirm).

    Returns:
        The attention output. For FA3 the first element (``[0]``) of the interface's
        return value is taken; for FA2 the return value is passed through as-is.

    Raises:
        ImportError: If neither FlashAttention backend could be imported at module load.
    """
    if FLASH_ATTN_3_AVAILABLE:
        return flash_attn_interface.flash_attn_varlen_func(q, k, v, **kwargs)[0]
    if FLASH_ATTN_AVAILABLE:
        return flash_attn.flash_attn_varlen_func(q, k, v, **kwargs)
    # Without this guard the fallthrough would touch the never-bound name
    # `flash_attn` and fail with an opaque NameError.
    raise ImportError(
        "flash_attn_varlen_func requires FlashAttention (flash_attn_interface or "
        "flash_attn), but neither could be imported."
    )
def split_tensor_by_mask(tensor: torch.Tensor, mask: torch.Tensor, threshold: float = 0.5):
    """
    Split input tensor into foreground and background based on mask, then concatenate them.

    Positions whose mask value is strictly greater than ``threshold`` are foreground.
    In every batch row the foreground elements are packed, in their original order,
    into columns ``[0, n_fg)``; the background elements are packed into columns
    ``[max_fg_len, max_fg_len + n_bg)``, where ``max_fg_len`` is the largest
    foreground count in the batch. All remaining slots are zero padding.

    Args:
        tensor: Input tensor of shape (batch, seq_len, *dims).
        mask: Mask of shape (batch, seq_len, 1) or (batch, seq_len).
        threshold: Threshold for mask binarization (``mask > threshold``).

    Returns:
        split_tensor: Tensor of shape (batch, max_fg_len + max_bg_len, *dims), foreground
            first, then background; same device and dtype as ``tensor``.
        fg_indices: List with one 1-D index tensor of foreground positions per batch row.
        bg_indices: List with one 1-D index tensor of background positions per batch row.
        original_shape: ``tensor.shape``, needed by ``restore_tensor_from_split``.

    Raises:
        ValueError: If the binarized mask does not have shape (batch, seq_len).
    """
    batch_size, seq_len, *dims = tensor.shape
    device, dtype = tensor.device, tensor.dtype
    mask_shape = tuple(mask.shape)

    # Accept (batch, seq_len) masks by giving them the trailing singleton dim first.
    if mask.dim() == 2:
        mask = mask.unsqueeze(-1)
    binary_mask = (mask > threshold).squeeze(-1)  # bool, (batch, seq_len)
    # A mismatched mask would otherwise index out of range or silently drop tokens.
    if binary_mask.shape != (batch_size, seq_len):
        raise ValueError(
            f"mask must have shape ({batch_size}, {seq_len}) or ({batch_size}, {seq_len}, 1) "
            f"to match tensor of shape {tuple(tensor.shape)}; got {mask_shape}"
        )

    # Per-row index tensors; rows can hold different counts, so these stay lists.
    fg_indices = [torch.where(binary_mask[b])[0] for b in range(batch_size)]
    bg_indices = [torch.where(~binary_mask[b])[0] for b in range(batch_size)]

    # Section widths. `default=0` keeps an empty batch from failing on max() of nothing;
    # this is also how restore_tensor_from_split recovers the foreground width.
    max_fg_len = max((len(idx) for idx in fg_indices), default=0)
    max_bg_len = max((len(idx) for idx in bg_indices), default=0)

    split_tensor = torch.zeros(batch_size, max_fg_len + max_bg_len, *dims, device=device, dtype=dtype)

    # Row-by-row copy (a Python loop over the batch): foreground goes to the front,
    # background starts at the shared offset max_fg_len so every row splits at the same column.
    for b in range(batch_size):
        n_fg, n_bg = len(fg_indices[b]), len(bg_indices[b])
        if n_fg > 0:
            split_tensor[b, :n_fg] = tensor[b][fg_indices[b]]
        if n_bg > 0:
            split_tensor[b, max_fg_len:max_fg_len + n_bg] = tensor[b][bg_indices[b]]

    return split_tensor, fg_indices, bg_indices, tensor.shape
def restore_tensor_from_split(split_tensor: torch.Tensor, fg_indices: list, bg_indices: list,
                              original_shape: torch.Size):
    """
    Invert ``split_tensor_by_mask``: scatter packed rows back to their original positions.

    Columns ``[0, max_fg_len)`` of ``split_tensor`` hold foreground values and the
    columns from ``max_fg_len`` onward hold background values, where ``max_fg_len``
    is the length of the longest entry in ``fg_indices``. Row ``b``'s foreground
    values are written to positions ``fg_indices[b]`` and its background values to
    ``bg_indices[b]``; positions listed in neither remain zero.

    Args:
        split_tensor: Packed tensor of shape (batch, width, *dims).
        fg_indices: Per-row foreground positions, as returned by split_tensor_by_mask.
        bg_indices: Per-row background positions, as returned by split_tensor_by_mask.
        original_shape: Shape (batch, seq_len, *dims) of the tensor to rebuild.

    Returns:
        Tensor of shape ``original_shape`` on ``split_tensor``'s device and dtype.
    """
    n_batch, n_seq = original_shape[:2]
    trailing = original_shape[2:]
    # Column where the background section starts.
    split_point = max((len(idx) for idx in fg_indices), default=0)
    restored = split_tensor.new_zeros((n_batch, n_seq, *trailing))

    width = split_tensor.shape[1]
    if width == 0:
        return restored

    has_fg = split_point > 0
    has_bg = width > split_point
    for row in range(n_batch):
        if has_fg:
            fg = fg_indices[row]
            if len(fg) > 0:
                restored[row, fg] = split_tensor[row, :len(fg)]
        if has_bg:
            bg = bg_indices[row]
            if len(bg) > 0:
                restored[row, bg] = split_tensor[row, split_point:split_point + len(bg)]
    return restored