# NOTE(review): the lines "Spaces: / Running on Zero / Running on Zero" were
# Hugging Face Spaces page-banner residue from the scrape, not code; kept here
# as a comment so the module parses.
import torch
import torch.nn.functional as F

# dtype the flash-attention kernels are fed; inputs are down-cast via half().
ATTN_PRECISION = torch.float16

# Probe optional attention backends once at import time.
# Preference order used by attn_processor: FlashAttention-3 (flash_attn_interface),
# then flash_attn, then xformers, then PyTorch SDPA.
try:
    import flash_attn_interface
    FLASH_ATTN_3_AVAILABLE = True
    FLASH_ATTN_AVAILABLE = False
except ModuleNotFoundError:
    FLASH_ATTN_3_AVAILABLE = False
    try:
        import flash_attn
        FLASH_ATTN_AVAILABLE = True
    except ModuleNotFoundError:
        FLASH_ATTN_AVAILABLE = False

try:
    import xformers.ops
    XFORMERS_IS_AVAILBLE = True  # (sic) misspelled name kept — callers may import it
# Narrowed from a bare `except:` (which also swallowed KeyboardInterrupt/SystemExit).
# Kept broader than ModuleNotFoundError on purpose: xformers can raise other
# errors at import time (e.g. binary/ABI mismatches) and this is a best-effort probe.
except Exception:
    XFORMERS_IS_AVAILBLE = False
def half(x):
    """Return x cast to ATTN_PRECISION unless it is already half precision.

    float16 and bfloat16 inputs pass through untouched; everything else
    (typically float32) is converted so the flash-attention kernels accept it.
    """
    already_half = x.dtype in (torch.float16, torch.bfloat16)
    return x if already_half else x.to(ATTN_PRECISION)
def attn_processor(q, k, v, attn_mask=None, *args, **kwargs):
    """Run scaled-dot-product attention on the fastest available backend.

    Backend preference (set by the import-time probes): with a mask, xformers
    then PyTorch SDPA; without a mask, FlashAttention-3, flash_attn, xformers,
    then SDPA. Flash paths down-cast q/k/v via half() and cast the output back
    to v's original dtype.

    NOTE(review): q/k/v appear to use (batch, seq, heads, dim) layout — the
    SDPA fallback transposes to (batch, heads, seq, dim) and back; confirm
    against callers.
    """
    if attn_mask is not None:
        # Flash attention takes no dense bias here, so masked calls go to
        # xformers or SDPA only.
        if XFORMERS_IS_AVAILBLE:
            return xformers.ops.memory_efficient_attention(
                q, k, v, *args, attn_bias=attn_mask, **kwargs
            )
        q, k, v = (t.transpose(1, 2) for t in (q, k, v))
        return F.scaled_dot_product_attention(
            q, k, v, *args, attn_mask=attn_mask, **kwargs
        ).transpose(1, 2)

    if FLASH_ATTN_3_AVAILABLE:
        orig_dtype = v.dtype
        q, k, v = (half(t) for t in (q, k, v))
        # FA3 returns a tuple; element 0 is the attention output.
        return flash_attn_interface.flash_attn_func(q, k, v, *args, **kwargs)[0].to(orig_dtype)

    if FLASH_ATTN_AVAILABLE:
        orig_dtype = v.dtype
        q, k, v = (half(t) for t in (q, k, v))
        return flash_attn.flash_attn_func(q, k, v, *args, **kwargs).to(orig_dtype)

    if XFORMERS_IS_AVAILBLE:
        return xformers.ops.memory_efficient_attention(q, k, v, *args, **kwargs)

    q, k, v = (t.transpose(1, 2) for t in (q, k, v))
    return F.scaled_dot_product_attention(q, k, v, *args, **kwargs).transpose(1, 2)
def flash_attn_varlen_func(q, k, v, **kwargs):
    """Variable-length flash attention, preferring FlashAttention-3.

    Args:
        q, k, v: Packed query/key/value tensors in flash-attn's varlen layout.
        **kwargs: Forwarded verbatim (cu_seqlens, max_seqlen, etc.).

    Returns:
        The attention output tensor.

    Raises:
        RuntimeError: If neither flash-attn backend is installed. (The original
        code fell through and died with a confusing ``NameError: flash_attn``.)
    """
    if FLASH_ATTN_3_AVAILABLE:
        # FA3 returns a tuple; element 0 is the attention output.
        return flash_attn_interface.flash_attn_varlen_func(q, k, v, **kwargs)[0]
    if FLASH_ATTN_AVAILABLE:
        return flash_attn.flash_attn_varlen_func(q, k, v, **kwargs)
    raise RuntimeError(
        "flash_attn_varlen_func requires flash_attn or flash_attn_interface to be installed"
    )
def split_tensor_by_mask(tensor: torch.Tensor, mask: torch.Tensor, threshold: float = 0.5):
    """Reorder each sequence so foreground (masked) positions come first.

    Foreground rows (mask > threshold) are packed at the front of the output
    and background rows after them; each half is zero-padded to the batch-wide
    maximum count, so the output sequence length is max_fg_len + max_bg_len.

    Args:
        tensor: Input of shape (batch, seq_len, *dims).
        mask: Binary mask of shape (batch, seq_len) or (batch, seq_len, 1).
        threshold: Values strictly above this count as foreground.

    Returns:
        Tuple (split_tensor, fg_indices, bg_indices, original_shape) — the
        index lists and shape are exactly what restore_tensor_from_split needs.
    """
    batch_size = tensor.shape[0]
    trailing = tuple(tensor.shape[2:])

    # Normalize the mask to boolean (batch, seq_len).
    if mask.dim() == 2:
        mask = mask.unsqueeze(-1)
    keep = (mask > threshold).squeeze(-1)

    # Per-batch position indices, kept so the reordering can be undone later.
    fg_indices = [keep[b].nonzero(as_tuple=True)[0] for b in range(batch_size)]
    bg_indices = [(~keep[b]).nonzero(as_tuple=True)[0] for b in range(batch_size)]

    max_fg_len = int(keep.sum(dim=1).max())
    max_bg_len = int((~keep).sum(dim=1).max())

    # Degenerate case: seq_len == 0, nothing to place.
    if max_fg_len == 0 and max_bg_len == 0:
        empty = tensor.new_zeros((batch_size, 0) + trailing)
        return empty, fg_indices, bg_indices, tensor.shape

    # Zero-init so the per-batch padding slots stay zero.
    out = tensor.new_zeros((batch_size, max_fg_len + max_bg_len) + trailing)
    for b in range(batch_size):
        fg, bg = fg_indices[b], bg_indices[b]
        if fg.numel():
            out[b, :fg.numel()] = tensor[b, fg]
        if bg.numel():
            out[b, max_fg_len:max_fg_len + bg.numel()] = tensor[b, bg]
    return out, fg_indices, bg_indices, tensor.shape
def restore_tensor_from_split(split_tensor: torch.Tensor, fg_indices: list, bg_indices: list,
                              original_shape: torch.Size):
    """Invert split_tensor_by_mask, scattering rows back to original positions.

    Args:
        split_tensor: Output of split_tensor_by_mask (fg block, then bg block).
        fg_indices: Per-batch foreground position indices.
        bg_indices: Per-batch background position indices.
        original_shape: Shape of the tensor before splitting.

    Returns:
        Tensor of original_shape with every row back in place; positions not
        covered by either index list remain zero.
    """
    batch_size, seq_len = original_shape[:2]
    restored = split_tensor.new_zeros((batch_size, seq_len) + tuple(original_shape[2:]))

    # Nothing was packed (seq_len was 0 at split time).
    if split_tensor.shape[1] == 0:
        return restored

    # The fg block occupies the first max_fg_len slots of every sequence.
    max_fg_len = max((len(fg) for fg in fg_indices), default=0)
    has_bg = split_tensor.shape[1] > max_fg_len

    for b in range(batch_size):
        fg = fg_indices[b]
        if max_fg_len > 0 and len(fg) > 0:
            restored[b, fg] = split_tensor[b, :len(fg)]
        bg = bg_indices[b]
        if has_bg and len(bg) > 0:
            restored[b, bg] = split_tensor[b, max_fg_len:max_fg_len + len(bg)]
    return restored