| from __future__ import annotations |
|
|
| import os |
| from typing import Tuple |
|
|
| import torch |
|
|
| |
| |
# Module-wide QAT switch. It is changed through ``set_qat_backward`` and read
# at call time by the attention wrappers in this file (forwarded to the Triton
# kernels as ``is_qat`` and consulted when choosing the kernel path).
_QAT_BACKWARD_ENABLED: bool = False
|
|
|
|
def set_qat_backward(enabled: bool) -> None:
    """Turn the module-wide QAT flag on or off.

    The value is stored in ``_QAT_BACKWARD_ENABLED`` exactly as given (no
    coercion). Other functions in this module read the flag when they are
    called, so the change affects subsequent calls only.

    Args:
        enabled: New value for the flag.
    """
    global _QAT_BACKWARD_ENABLED
    _QAT_BACKWARD_ENABLED = enabled
|
|
|
|
def _use_high_prec_output_for_backward() -> bool:
    """Report whether the high-precision output should be kept for backward.

    Reads ``FASTVIDEO_SPARSE_FP4_USE_HIGH_PREC_O`` (default ``"1"``). Only the
    case-insensitive spellings ``0``, ``false``, ``no`` and ``off`` turn the
    option off; every other value, including an empty string, leaves it on.
    """
    disabled_spellings = ("0", "false", "no", "off")
    raw = os.getenv("FASTVIDEO_SPARSE_FP4_USE_HIGH_PREC_O", "1")
    return raw.lower() not in disabled_spellings
|
|
|
|
def _get_sm90_ops():
    """Look up the compiled forward/backward block-sparse kernels.

    Returns:
        ``(block_sparse_fwd, block_sparse_bwd)`` read from
        ``fastvideo_kernel._C.fastvideo_kernel_ops``. An entry is ``None`` when
        that attribute is missing; both are ``None`` when the extension cannot
        be imported.
    """
    try:
        from fastvideo_kernel._C import fastvideo_kernel_ops as compiled_ops
    except Exception:
        # Best effort: any import-time failure simply means "no compiled ops".
        return None, None
    fwd = getattr(compiled_ops, "block_sparse_fwd", None)
    bwd = getattr(compiled_ops, "block_sparse_bwd", None)
    return fwd, bwd
|
|
|
|
def _is_sm90(device: torch.device | int | str | None = None) -> bool:
    """Return True when the queried CUDA device has compute capability 9.0.

    Args:
        device: Device to query. ``None`` (the default) selects the current
            CUDA device, so the answer follows ``torch.cuda.set_device``
            rather than always describing device index 0.

    Returns:
        ``False`` when CUDA is unavailable; otherwise whether the device
        reports capability exactly ``(9, 0)``.
    """
    if not torch.cuda.is_available():
        return False
    major, minor = torch.cuda.get_device_capability(device)
    return (major, minor) == (9, 0)
|
|
|
|
def _force_triton() -> bool:
    """Return True only when ``FASTVIDEO_KERNEL_VSA_FORCE_TRITON`` is exactly ``"1"``.

    An unset variable is treated like ``"0"``.
    """
    flag = os.getenv("FASTVIDEO_KERNEL_VSA_FORCE_TRITON", "0")
    return flag == "1"
|
|
|
|
def _map_to_index(block_map: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
    """Convert a block map to index form with the Triton helper.

    A 3-D map gets a leading dimension of size one. The resulting 4-D map is
    cast to ``torch.bool`` when it has another dtype and is then passed to
    ``fastvideo_kernel.triton_kernels.index.map_to_index``; that function's
    result is returned unchanged (callers in this file unpack it as an
    ``idx`` / ``num`` pair).

    Raises:
        ValueError: If ``block_map`` is neither 3-D nor 4-D.
        RuntimeError: If ``block_map`` is not a CUDA tensor.
        ImportError: If the Triton ``map_to_index`` helper cannot be imported.
    """
    if block_map.dim() == 3:
        # [H,Q,KV] -> [1,H,Q,KV]
        block_map = block_map.unsqueeze(0)
    elif block_map.dim() != 4:
        raise ValueError(f"block_map must be [B,H,Q,KV] (or [H,Q,KV]), got shape={tuple(block_map.shape)}")

    already_bool = block_map.dtype == torch.bool
    if not already_bool:
        block_map = block_map.to(torch.bool)

    # There is no fallback implementation: the Triton helper needs device data.
    if not block_map.is_cuda:
        raise RuntimeError("block_map must be a CUDA tensor (Triton map_to_index required).")

    try:
        from fastvideo_kernel.triton_kernels.index import map_to_index as triton_map_to_index
    except Exception as import_failure:
        raise ImportError(
            "Triton map_to_index is required but not available. "
            "Ensure Triton is installed and fastvideo_kernel.triton_kernels.index is importable."
        ) from import_failure

    return triton_map_to_index(block_map)
|
|
|
|
@torch.library.custom_op(
    "fastvideo_kernel::block_sparse_attn_triton",
    mutates_args=(),
    device_types="cuda",
)
def block_sparse_attn_triton(
    q: torch.Tensor,
    k: torch.Tensor,
    v: torch.Tensor,
    block_map: torch.Tensor,
    variable_block_sizes: torch.Tensor,
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
    """Forward custom op for the Triton block-sparse attention kernel.

    ``q``/``k``/``v`` are made contiguous, ``block_map`` is cast to bool and
    converted to index form, and everything is handed to
    ``triton_block_sparse_attn_forward`` with ``is_qat`` taken from the
    module-wide QAT flag at call time.

    Returns:
        The ``(o, M, high_prec_o)`` triple produced by the Triton forward.
    """
    q_c, k_c, v_c = q.contiguous(), k.contiguous(), v.contiguous()
    q2k_index, q2k_count = _map_to_index(block_map.to(torch.bool))

    # Imported lazily so the module can be loaded without the Triton kernels.
    from fastvideo_kernel.triton_kernels.block_sparse_attn_triton import (
        triton_block_sparse_attn_forward as forward_kernel,
    )

    out, m_out, out_high_prec = forward_kernel(
        q_c,
        k_c,
        v_c,
        q2k_index,
        q2k_count,
        variable_block_sizes,
        is_qat=_QAT_BACKWARD_ENABLED,
    )
    return out, m_out, out_high_prec
|
|
|
|
|
|
@torch.library.register_fake("fastvideo_kernel::block_sparse_attn_triton")
def _block_sparse_attn_triton_fake(
    q: torch.Tensor,
    k: torch.Tensor,
    v: torch.Tensor,
    block_map: torch.Tensor,
    variable_block_sizes: torch.Tensor,
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
    """Fake (meta) implementation of ``block_sparse_attn_triton``.

    Allocates uninitialized outputs only: ``o`` and ``high_prec_o`` mirror
    ``q``; ``M`` is float32 on ``q``'s device with ``q``'s first three sizes.
    """
    dim0, dim1, dim2 = q.shape[0], q.shape[1], q.shape[2]
    out = torch.empty_like(q)
    out_high_prec = torch.empty_like(q)
    m_out = torch.empty((dim0, dim1, dim2), device=q.device, dtype=torch.float32)
    return out, m_out, out_high_prec
|
|
|
|
@torch.library.custom_op(
    "fastvideo_kernel::block_sparse_attn_backward_triton",
    mutates_args=(),
    device_types="cuda",
)
def block_sparse_attn_backward_triton(
    grad_output: torch.Tensor,
    q: torch.Tensor,
    k: torch.Tensor,
    v: torch.Tensor,
    o: torch.Tensor,
    M: torch.Tensor,
    block_map: torch.Tensor,
    variable_block_sizes: torch.Tensor,
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
    """Backward custom op for the Triton block-sparse attention kernel.

    Builds both the q->k index (from ``block_map``) and the k->q index (from
    its transpose over the last two dims) and calls
    ``triton_block_sparse_attn_backward`` with ``is_qat`` taken from the
    module-wide QAT flag at call time.

    Args:
        grad_output: Gradient with respect to the attention output.
        q, k, v: Tensors that were passed to the forward op.
        o: Forward output to use for the backward computation.
        M: Second output of the forward op.
        block_map: Block map; cast to bool here.
        variable_block_sizes: Forwarded to the kernel unchanged.

    Returns:
        ``(dq, dk, dv)`` as produced by the Triton backward.
    """
    grad_output = grad_output.contiguous()
    # The forward op runs the kernel on contiguous copies of q/k/v, while the
    # autograd context stores the caller's original tensors. Apply the same
    # normalization here so forward and backward see the same layout (a no-op
    # for tensors that are already contiguous).
    q = q.contiguous()
    k = k.contiguous()
    v = v.contiguous()
    block_map = block_map.to(torch.bool)
    q2k_idx, q2k_num = _map_to_index(block_map)
    k2q_idx, k2q_num = _map_to_index(block_map.transpose(-1, -2).contiguous())

    # Imported lazily so the module can be loaded without the Triton kernels.
    from fastvideo_kernel.triton_kernels.block_sparse_attn_triton import (
        triton_block_sparse_attn_backward,
    )

    dq, dk, dv = triton_block_sparse_attn_backward(
        grad_output,
        q,
        k,
        v,
        o,
        M,
        q2k_idx,
        q2k_num,
        k2q_idx,
        k2q_num,
        variable_block_sizes,
        is_qat=_QAT_BACKWARD_ENABLED,
    )
    return dq, dk, dv
|
|
|
|
@torch.library.register_fake("fastvideo_kernel::block_sparse_attn_backward_triton")
def _block_sparse_attn_backward_triton_fake(
    grad_output: torch.Tensor,
    q: torch.Tensor,
    k: torch.Tensor,
    v: torch.Tensor,
    o: torch.Tensor,
    M: torch.Tensor,
    block_map: torch.Tensor,
    variable_block_sizes: torch.Tensor,
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
    """Fake (meta) implementation of ``block_sparse_attn_backward_triton``.

    Returns uninitialized gradients that mirror ``q``, ``k`` and ``v``.
    """
    grad_q = torch.empty_like(q)
    grad_k = torch.empty_like(k)
    grad_v = torch.empty_like(v)
    return grad_q, grad_k, grad_v
|
|
|
|
def _backward_triton(ctx, grad_o, grad_M, grad_high_prec_o):
    """Autograd backward for ``block_sparse_attn_triton``.

    Only ``grad_o`` is used; ``grad_M`` and ``grad_high_prec_o`` are ignored.
    Returns gradients for ``q``, ``k`` and ``v`` followed by ``None`` for
    ``block_map`` and ``variable_block_sizes``.
    """
    (saved_q, saved_k, saved_v, saved_o, saved_m,
     saved_map, saved_sizes) = ctx.saved_tensors
    dq, dk, dv = block_sparse_attn_backward_triton(
        grad_o,
        saved_q,
        saved_k,
        saved_v,
        saved_o,
        saved_m,
        saved_map,
        saved_sizes,
    )
    return dq, dk, dv, None, None
|
|
|
|
def _setup_context_triton(ctx, inputs, output):
    """Save what ``_backward_triton`` needs after the Triton forward op ran.

    The output kept for backward is ``high_prec_o`` when the QAT flag is set
    and the high-precision environment option is enabled; otherwise it is the
    regular output ``o``.
    """
    q, k, v, block_map, variable_block_sizes = inputs
    o, M, high_prec_o = output

    keep_high_prec = _QAT_BACKWARD_ENABLED and _use_high_prec_output_for_backward()
    if keep_high_prec:
        o_for_bwd = high_prec_o
    else:
        o_for_bwd = o

    ctx.save_for_backward(
        q, k, v, o_for_bwd, M, block_map, variable_block_sizes,
    )
|
|
|
|
# Attach the autograd formula to the Triton forward op: the context is filled
# by ``_setup_context_triton`` and gradients come from ``_backward_triton``.
block_sparse_attn_triton.register_autograd(
    _backward_triton,
    setup_context=_setup_context_triton,
)
|
|
|
|
class _BlockSparseAttnTileComp(torch.autograd.Function):
    """Triton block-sparse attention with tile-level compensation.

    Besides the index built from ``block_map``, the kernels receive
    ``q_mean``/``k_mean``/``v_mean`` and an index built from the logical
    negation of ``block_map``. Gradients are returned for ``q``, ``k`` and
    ``v`` only; every other input gets ``None``.
    """

    @staticmethod
    def forward(ctx, q, k, v, q_mean, k_mean, v_mean, block_map,
                variable_block_sizes):
        """Run the Triton forward with compensation inputs; return ``(o, M)``."""
        q, k, v = q.contiguous(), k.contiguous(), v.contiguous()
        q_mean = q_mean.contiguous()
        k_mean = k_mean.contiguous()
        v_mean = v_mean.contiguous()

        block_map = block_map.to(torch.bool)
        # Complement of the block map: entries that are False in block_map.
        dropped_block_map = torch.logical_not(block_map)

        kept_index, kept_count = _map_to_index(block_map)
        dropped_index, dropped_count = _map_to_index(dropped_block_map)

        from fastvideo_kernel.triton_kernels.block_sparse_attn_triton import (
            triton_block_sparse_attn_forward as forward_kernel,
        )

        o, M, high_prec_o = forward_kernel(
            q, k, v,
            kept_index, kept_count,
            variable_block_sizes,
            is_qat=_QAT_BACKWARD_ENABLED,
            q_mean=q_mean,
            k_mean=k_mean,
            v_mean=v_mean,
            dropped_q2k_index=dropped_index,
            dropped_q2k_num=dropped_count,
        )

        # Under QAT the high-precision output may replace ``o`` for backward.
        if _QAT_BACKWARD_ENABLED and _use_high_prec_output_for_backward():
            o_for_bwd = high_prec_o
        else:
            o_for_bwd = o

        ctx.save_for_backward(
            q, k, v,
            q_mean, k_mean, v_mean,
            o_for_bwd, M,
            block_map, dropped_block_map,
            variable_block_sizes,
        )
        return o, M

    @staticmethod
    def backward(ctx, grad_o, grad_M):
        """Compute ``dq, dk, dv`` from ``grad_o``; ``grad_M`` is not used."""
        (q, k, v,
         q_mean, k_mean, v_mean,
         o_for_bwd, M,
         block_map, dropped_block_map,
         variable_block_sizes) = ctx.saved_tensors

        # Indices are rebuilt here in both directions (q->k and k->q) for the
        # kept map and for its complement.
        kept_q2k_index, kept_q2k_count = _map_to_index(block_map)
        kept_k2q_index, kept_k2q_count = _map_to_index(
            block_map.transpose(-1, -2).contiguous())
        dropped_q2k_index, dropped_q2k_count = _map_to_index(dropped_block_map)
        dropped_k2q_index, dropped_k2q_count = _map_to_index(
            dropped_block_map.transpose(-1, -2).contiguous())

        from fastvideo_kernel.triton_kernels.block_sparse_attn_triton import (
            triton_block_sparse_attn_backward as backward_kernel,
        )

        dq, dk, dv = backward_kernel(
            grad_o.contiguous(),
            q, k, v,
            o_for_bwd, M,
            kept_q2k_index, kept_q2k_count,
            kept_k2q_index, kept_k2q_count,
            variable_block_sizes,
            is_qat=_QAT_BACKWARD_ENABLED,
            q_mean=q_mean,
            k_mean=k_mean,
            v_mean=v_mean,
            dropped_q2k_index=dropped_q2k_index,
            dropped_q2k_num=dropped_q2k_count,
            dropped_k2q_index=dropped_k2q_index,
            dropped_k2q_num=dropped_k2q_count,
        )
        # One slot per forward input: q, k, v, then five inputs without grads.
        return dq, dk, dv, None, None, None, None, None
|
|
|
|
@torch.library.custom_op(
    "fastvideo_kernel::block_sparse_attn_sm90",
    mutates_args=(),
    device_types="cuda",
)
def block_sparse_attn_sm90(
    q_padded: torch.Tensor,
    k_padded: torch.Tensor,
    v_padded: torch.Tensor,
    block_map: torch.Tensor,
    variable_block_sizes: torch.Tensor,
) -> tuple[torch.Tensor, torch.Tensor]:
    """Forward custom op backed by the compiled ``block_sparse_fwd`` kernel.

    Inputs are made contiguous, ``block_map`` is cast to bool and converted to
    a q->k index, and ``variable_block_sizes`` is passed through ``.int()``.

    Returns:
        The ``(o_padded, lse_padded)`` pair returned by the compiled kernel.

    Raises:
        ImportError: If ``block_sparse_fwd`` is not available.
    """
    fwd_kernel = _get_sm90_ops()[0]
    if fwd_kernel is None:
        raise ImportError("fastvideo_kernel_ops.block_sparse_fwd is not available")

    q_c = q_padded.contiguous()
    k_c = k_padded.contiguous()
    v_c = v_padded.contiguous()
    q2k_index, q2k_count = _map_to_index(block_map.to(torch.bool))

    o_padded, lse_padded = fwd_kernel(
        q_c,
        k_c,
        v_c,
        q2k_index,
        q2k_count,
        variable_block_sizes.int(),
    )
    return o_padded, lse_padded
|
|
|
|
@torch.library.register_fake("fastvideo_kernel::block_sparse_attn_sm90")
def _block_sparse_attn_sm90_fake(
    q_padded: torch.Tensor,
    k_padded: torch.Tensor,
    v_padded: torch.Tensor,
    block_map: torch.Tensor,
    variable_block_sizes: torch.Tensor,
) -> tuple[torch.Tensor, torch.Tensor]:
    """Fake (meta) implementation of ``block_sparse_attn_sm90``.

    The output mirrors ``q_padded``; the second tensor is float32 on the same
    device with ``q_padded``'s first three sizes plus a trailing size of 1.
    """
    dim0, dim1, dim2 = q_padded.shape[0], q_padded.shape[1], q_padded.shape[2]
    out = torch.empty_like(q_padded)
    lse = torch.empty(
        (dim0, dim1, dim2, 1),
        device=q_padded.device,
        dtype=torch.float32,
    )
    return out, lse
|
|
|
|
@torch.library.custom_op(
    "fastvideo_kernel::block_sparse_attn_backward_sm90",
    mutates_args=(),
    device_types="cuda",
)
def block_sparse_attn_backward_sm90(
    grad_output_padded: torch.Tensor,
    q_padded: torch.Tensor,
    k_padded: torch.Tensor,
    v_padded: torch.Tensor,
    o_padded: torch.Tensor,
    lse_padded: torch.Tensor,
    block_map: torch.Tensor,
    variable_block_sizes: torch.Tensor,
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
    """Backward custom op backed by the compiled ``block_sparse_bwd`` kernel.

    Only the k->q index (built from the transposed ``block_map``) is passed to
    the kernel, and ``variable_block_sizes`` goes through ``.int()``.

    Args:
        grad_output_padded: Gradient with respect to the attention output.
        q_padded, k_padded, v_padded: Tensors that were passed to the forward op.
        o_padded, lse_padded: Outputs of the forward op.
        block_map: Block map; cast to bool here.
        variable_block_sizes: Converted with ``.int()`` before the kernel call.

    Returns:
        ``(dq, dk, dv)`` cast to the dtype of ``grad_output_padded``.

    Raises:
        ImportError: If ``block_sparse_bwd`` is not available.
    """
    _, block_sparse_bwd = _get_sm90_ops()
    if block_sparse_bwd is None:
        raise ImportError("fastvideo_kernel_ops.block_sparse_bwd is not available")

    grad_output_padded = grad_output_padded.contiguous()
    # The forward op hands contiguous copies of q/k/v to the kernel, whereas
    # the autograd context stores the caller's original tensors. Normalize
    # them the same way here (a no-op for already-contiguous tensors).
    q_padded = q_padded.contiguous()
    k_padded = k_padded.contiguous()
    v_padded = v_padded.contiguous()
    block_map = block_map.to(torch.bool)
    k2q_idx, k2q_num = _map_to_index(block_map.transpose(-1, -2).contiguous())

    dq, dk, dv = block_sparse_bwd(
        q_padded,
        k_padded,
        v_padded,
        o_padded,
        lse_padded,
        grad_output_padded,
        k2q_idx,
        k2q_num,
        variable_block_sizes.int(),
    )

    # Hand gradients back in the dtype of the incoming gradient.
    grad_dtype = grad_output_padded.dtype
    return dq.to(grad_dtype), dk.to(grad_dtype), dv.to(grad_dtype)
|
|
|
|
@torch.library.register_fake("fastvideo_kernel::block_sparse_attn_backward_sm90")
def _block_sparse_attn_backward_sm90_fake(
    grad_output_padded: torch.Tensor,
    q_padded: torch.Tensor,
    k_padded: torch.Tensor,
    v_padded: torch.Tensor,
    o_padded: torch.Tensor,
    lse_padded: torch.Tensor,
    block_map: torch.Tensor,
    variable_block_sizes: torch.Tensor,
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
    """Fake (meta) implementation of ``block_sparse_attn_backward_sm90``.

    Returns uninitialized gradients mirroring ``q_padded``, ``k_padded`` and
    ``v_padded``.
    """
    grad_q = torch.empty_like(q_padded)
    grad_k = torch.empty_like(k_padded)
    grad_v = torch.empty_like(v_padded)
    return grad_q, grad_k, grad_v
|
|
|
|
def _backward_sm90(ctx, grad_o, grad_lse):
    """Autograd backward for ``block_sparse_attn_sm90``.

    ``grad_lse`` is ignored. Returns gradients for ``q``, ``k`` and ``v``
    followed by ``None`` for ``block_map`` and ``variable_block_sizes``.
    """
    (saved_q, saved_k, saved_v, saved_o, saved_lse,
     saved_map, saved_sizes) = ctx.saved_tensors
    dq, dk, dv = block_sparse_attn_backward_sm90(
        grad_o,
        saved_q,
        saved_k,
        saved_v,
        saved_o,
        saved_lse,
        saved_map,
        saved_sizes,
    )
    return dq, dk, dv, None, None
|
|
|
|
def _setup_context_sm90(ctx, inputs, output):
    """Save the forward inputs and both outputs for ``_backward_sm90``."""
    q, k, v, block_map, variable_block_sizes = inputs
    o, lse = output
    tensors_for_backward = (q, k, v, o, lse, block_map, variable_block_sizes)
    ctx.save_for_backward(*tensors_for_backward)
|
|
|
|
# Attach the autograd formula to the SM90 forward op: the context is filled by
# ``_setup_context_sm90`` and gradients come from ``_backward_sm90``.
block_sparse_attn_sm90.register_autograd(
    _backward_sm90,
    setup_context=_setup_context_sm90,
)
|
|
|
|
def block_sparse_attn(
    q: torch.Tensor,
    k: torch.Tensor,
    v: torch.Tensor,
    block_map: torch.Tensor,
    variable_block_sizes: torch.Tensor,
    q_mean: torch.Tensor | None = None,
    k_mean: torch.Tensor | None = None,
    v_mean: torch.Tensor | None = None,
) -> tuple[torch.Tensor, torch.Tensor]:
    """
    Unified block-sparse attention op with autograd support.

    Dispatch order:
    - If any of q_mean/k_mean/v_mean is given, all three are required and the
      Triton tile-compensation path (for blocks omitted by block_map) is used.
    - Otherwise, on SM90 with both compiled kernels present, with QAT off and
      FASTVIDEO_KERNEL_VSA_FORCE_TRITON not set to "1": uses
      fastvideo_kernel_ops.block_sparse_fwd/bwd.
    - Otherwise: uses the Triton implementation (requires q/k/v to have same
      padded length today). P-quant QAT currently requires this path.

    Returns:
        A pair of tensors: the attention output and the second output of the
        selected kernel (``M`` on the Triton paths, ``lse`` on SM90).

    Raises:
        ValueError: If only some of q_mean/k_mean/v_mean are provided.
    """
    means = (q_mean, k_mean, v_mean)
    if any(mean is not None for mean in means):
        if any(mean is None for mean in means):
            raise ValueError("q_mean, k_mean, and v_mean must be provided together")
        return _BlockSparseAttnTileComp.apply(
            q, k, v, q_mean, k_mean, v_mean, block_map, variable_block_sizes
        )

    if not _QAT_BACKWARD_ENABLED and not _force_triton() and _is_sm90():
        # Probe the compiled extension only after the cheap checks pass, so
        # calls that cannot take the SM90 path skip the import attempt.
        block_sparse_fwd, block_sparse_bwd = _get_sm90_ops()
        if block_sparse_fwd is not None and block_sparse_bwd is not None:
            return block_sparse_attn_sm90(q, k, v, block_map, variable_block_sizes)

    o, M, _ = block_sparse_attn_triton(q, k, v, block_map, variable_block_sizes)
    return o, M
|
|
|
|
def block_sparse_attn_qat(
    q: torch.Tensor,
    k: torch.Tensor,
    v: torch.Tensor,
    block_map: torch.Tensor,
    variable_block_sizes: torch.Tensor,
) -> tuple[torch.Tensor, torch.Tensor]:
    """Block-sparse attention with FP4 fake quantization on P (QAT mode).

    Calls the Triton forward kernel directly with ``is_qat=True``, regardless
    of the module-wide QAT flag. The third value returned by the kernel is
    discarded.

    NOTE(review): this function does not register a backward of its own, so
    gradients presumably do not flow through its outputs unless the Triton
    wrapper is itself autograd-aware — confirm before using it for training.

    Returns:
        ``(o, M)`` from ``triton_block_sparse_attn_forward``.
    """
    q = q.contiguous()
    k = k.contiguous()
    v = v.contiguous()
    block_map = block_map.to(torch.bool)
    q2k_idx, q2k_num = _map_to_index(block_map)

    # Only the forward kernel is used here, so only that symbol is imported.
    from fastvideo_kernel.triton_kernels.block_sparse_attn_triton import (
        triton_block_sparse_attn_forward,
    )

    o, M, _ = triton_block_sparse_attn_forward(
        q,
        k,
        v,
        q2k_idx,
        q2k_num,
        variable_block_sizes,
        is_qat=True,
    )
    return o, M
|
|