from __future__ import annotations
import os
from typing import Tuple  # noqa: F401 (kept for backward compatibility)
import torch
# Global flag that enables FP4 P-quantization in the Triton forward/backward
# path (set by the sparse FP4 backend).
_QAT_BACKWARD_ENABLED = False
def set_qat_backward(enabled: bool):
global _QAT_BACKWARD_ENABLED
_QAT_BACKWARD_ENABLED = enabled
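# Example (hypothetical call site): the sparse FP4 backend would flip the flag
# around the training step, e.g.
#   set_qat_backward(True)
#   ...  # forward/backward through block_sparse_attn
#   set_qat_backward(False)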
def _use_high_prec_output_for_backward() -> bool:
value = os.environ.get("FASTVIDEO_SPARSE_FP4_USE_HIGH_PREC_O", "1")
return value.lower() not in ("0", "false", "no", "off")
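# Example: running with FASTVIDEO_SPARSE_FP4_USE_HIGH_PREC_O=0 makes the QAT
# backward consume the (fake-)quantized output instead of the high-precision one.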
def _get_sm90_ops():
try:
from fastvideo_kernel._C import fastvideo_kernel_ops # type: ignore
except Exception:
return None, None
return (
getattr(fastvideo_kernel_ops, "block_sparse_fwd", None),
getattr(fastvideo_kernel_ops, "block_sparse_bwd", None),
)
def _is_sm90() -> bool:
    # The compiled extension targets Hopper (compute capability 9.0, e.g. H100).
    if not torch.cuda.is_available():
        return False
    major, minor = torch.cuda.get_device_capability(0)
    return major == 9 and minor == 0
def _force_triton() -> bool:
# Force Triton even on SM90 and even if the compiled extension is available.
# Useful for CI / debugging / parity testing.
return os.environ.get("FASTVIDEO_KERNEL_VSA_FORCE_TRITON", "0") == "1"
def _map_to_index(block_map: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
"""
Preferred map->index conversion used by the wrapper.
This wrapper **requires** the Triton implementation.
If Triton (or the Triton map_to_index module) is not available, it raises.
"""
if block_map.dim() == 3:
block_map = block_map.unsqueeze(0)
if block_map.dim() != 4:
raise ValueError(f"block_map must be [B,H,Q,KV] (or [H,Q,KV]), got shape={tuple(block_map.shape)}")
if block_map.dtype != torch.bool:
block_map = block_map.to(torch.bool)
if not block_map.is_cuda:
raise RuntimeError("block_map must be a CUDA tensor (Triton map_to_index required).")
try:
from fastvideo_kernel.triton_kernels.index import map_to_index as triton_map_to_index # local import
except Exception as e:
raise ImportError(
"Triton map_to_index is required but not available. "
"Ensure Triton is installed and fastvideo_kernel.triton_kernels.index is importable."
) from e
return triton_map_to_index(block_map)
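# Pure-PyTorch sketch of the map->index contract assumed above. Illustrative
# only and not used by the kernels: the Triton implementation is authoritative,
# and entries of q2k_idx past q2k_num are treated as don't-cares.
def _map_to_index_reference(block_map: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
    # Per [B, H, q_block] row: the count of active kv blocks ...
    q2k_num = block_map.sum(dim=-1, dtype=torch.int32)
    # ... and their indices packed to the front of the last dim. A stable
    # descending sort of the 0/1 map keeps active indices in ascending order.
    q2k_idx = torch.argsort(block_map.int(), dim=-1, descending=True, stable=True).to(torch.int32)
    return q2k_idx, q2k_num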
@torch.library.custom_op(
"fastvideo_kernel::block_sparse_attn_triton",
mutates_args=(),
device_types="cuda",
)
def block_sparse_attn_triton(
q: torch.Tensor,
k: torch.Tensor,
v: torch.Tensor,
block_map: torch.Tensor,
variable_block_sizes: torch.Tensor,
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
q = q.contiguous()
k = k.contiguous()
v = v.contiguous()
block_map = block_map.to(torch.bool)
q2k_idx, q2k_num = _map_to_index(block_map)
from fastvideo_kernel.triton_kernels.block_sparse_attn_triton import ( # local import
triton_block_sparse_attn_forward,
)
o, M, high_prec_o = triton_block_sparse_attn_forward(
q, k, v, q2k_idx, q2k_num, variable_block_sizes,
is_qat=_QAT_BACKWARD_ENABLED,
)
return o, M, high_prec_o
@torch.library.register_fake("fastvideo_kernel::block_sparse_attn_triton")
def _block_sparse_attn_triton_fake(
q: torch.Tensor,
k: torch.Tensor,
v: torch.Tensor,
block_map: torch.Tensor,
variable_block_sizes: torch.Tensor,
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
o = torch.empty_like(q)
high_prec_o = torch.empty_like(q)
M = torch.empty((q.shape[0], q.shape[1], q.shape[2]), device=q.device, dtype=torch.float32)
return o, M, high_prec_o
@torch.library.custom_op(
"fastvideo_kernel::block_sparse_attn_backward_triton",
mutates_args=(),
device_types="cuda",
)
def block_sparse_attn_backward_triton(
grad_output: torch.Tensor,
q: torch.Tensor,
k: torch.Tensor,
v: torch.Tensor,
o: torch.Tensor,
M: torch.Tensor,
block_map: torch.Tensor,
variable_block_sizes: torch.Tensor,
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
grad_output = grad_output.contiguous()
block_map = block_map.to(torch.bool)
q2k_idx, q2k_num = _map_to_index(block_map)
k2q_idx, k2q_num = _map_to_index(block_map.transpose(-1, -2).contiguous())
from fastvideo_kernel.triton_kernels.block_sparse_attn_triton import ( # local import
triton_block_sparse_attn_backward,
)
dq, dk, dv = triton_block_sparse_attn_backward(
grad_output, q, k, v, o, M, q2k_idx, q2k_num, k2q_idx, k2q_num, variable_block_sizes,
is_qat=_QAT_BACKWARD_ENABLED,
)
return dq, dk, dv
@torch.library.register_fake("fastvideo_kernel::block_sparse_attn_backward_triton")
def _block_sparse_attn_backward_triton_fake(
grad_output: torch.Tensor,
q: torch.Tensor,
k: torch.Tensor,
v: torch.Tensor,
o: torch.Tensor,
M: torch.Tensor,
block_map: torch.Tensor,
variable_block_sizes: torch.Tensor,
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
dq = torch.empty_like(q)
dk = torch.empty_like(k)
dv = torch.empty_like(v)
return dq, dk, dv
def _backward_triton(ctx, grad_o, grad_M, grad_high_prec_o):
q, k, v, o_for_bwd, M, block_map, variable_block_sizes = ctx.saved_tensors
dq, dk, dv = block_sparse_attn_backward_triton(grad_o, q, k, v, o_for_bwd, M, block_map, variable_block_sizes)
return dq, dk, dv, None, None
def _setup_context_triton(ctx, inputs, output):
q, k, v, block_map, variable_block_sizes = inputs
o, M, high_prec_o = output
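    # Under QAT, optionally save the high-precision output so the backward
    # recomputes from an unquantized O (see _use_high_prec_output_for_backward).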
o_for_bwd = (
high_prec_o
if _QAT_BACKWARD_ENABLED and _use_high_prec_output_for_backward()
else o
)
ctx.save_for_backward(q, k, v, o_for_bwd, M, block_map,
variable_block_sizes)
block_sparse_attn_triton.register_autograd(_backward_triton, setup_context=_setup_context_triton)
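# Minimal usage sketch for the op registered above (shapes are illustrative
# assumptions: q/k/v are [B, H, S, D] with S padded to a multiple of the
# 64-token block, block_map is [B, H, S//64, S//64] bool, and
# variable_block_sizes holds per-block token counts):
#   o, M, _ = block_sparse_attn_triton(q, k, v, block_map, variable_block_sizes)
#   o.sum().backward()  # dispatches to _backward_triton via register_autograd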
class _BlockSparseAttnTileComp(torch.autograd.Function):
@staticmethod
def forward(ctx, q, k, v, q_mean, k_mean, v_mean, block_map,
variable_block_sizes):
q = q.contiguous()
k = k.contiguous()
v = v.contiguous()
q_mean = q_mean.contiguous()
k_mean = k_mean.contiguous()
v_mean = v_mean.contiguous()
block_map = block_map.to(torch.bool)
dropped_block_map = torch.logical_not(block_map)
q2k_idx, q2k_num = _map_to_index(block_map)
dropped_q2k_idx, dropped_q2k_num = _map_to_index(dropped_block_map)
from fastvideo_kernel.triton_kernels.block_sparse_attn_triton import ( # local import
triton_block_sparse_attn_forward,
)
o, M, high_prec_o = triton_block_sparse_attn_forward(
q,
k,
v,
q2k_idx,
q2k_num,
variable_block_sizes,
is_qat=_QAT_BACKWARD_ENABLED,
q_mean=q_mean,
k_mean=k_mean,
v_mean=v_mean,
dropped_q2k_index=dropped_q2k_idx,
dropped_q2k_num=dropped_q2k_num,
)
o_for_bwd = (
high_prec_o
if _QAT_BACKWARD_ENABLED and _use_high_prec_output_for_backward()
else o
)
ctx.save_for_backward(q, k, v, q_mean, k_mean, v_mean, o_for_bwd, M,
block_map, dropped_block_map,
variable_block_sizes)
return o, M
@staticmethod
def backward(ctx, grad_o, grad_M):
q, k, v, q_mean, k_mean, v_mean, o_for_bwd, M, block_map, dropped_block_map, variable_block_sizes = ctx.saved_tensors
q2k_idx, q2k_num = _map_to_index(block_map)
k2q_idx, k2q_num = _map_to_index(block_map.transpose(-1, -2).contiguous())
dropped_q2k_idx, dropped_q2k_num = _map_to_index(dropped_block_map)
dropped_k2q_idx, dropped_k2q_num = _map_to_index(
dropped_block_map.transpose(-1, -2).contiguous())
from fastvideo_kernel.triton_kernels.block_sparse_attn_triton import ( # local import
triton_block_sparse_attn_backward,
)
dq, dk, dv = triton_block_sparse_attn_backward(
grad_o.contiguous(),
q,
k,
v,
o_for_bwd,
M,
q2k_idx,
q2k_num,
k2q_idx,
k2q_num,
variable_block_sizes,
is_qat=_QAT_BACKWARD_ENABLED,
q_mean=q_mean,
k_mean=k_mean,
v_mean=v_mean,
dropped_q2k_index=dropped_q2k_idx,
dropped_q2k_num=dropped_q2k_num,
dropped_k2q_index=dropped_k2q_idx,
dropped_k2q_num=dropped_k2q_num,
)
return dq, dk, dv, None, None, None, None, None
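# Sketch of the tile-compensation idea (an assumption based on the argument
# names; the Triton kernel is authoritative): instead of skipping a kv block
# dropped by block_map entirely, approximate it by its per-block mean, roughly
#   p_j  ~= exp(q_i . k_mean_j - m_i) * n_j    # n_j = tokens in block j
#   o_i  += p_j * v_mean_j
# so dropped blocks still contribute to the softmax normalizer and the output.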
@torch.library.custom_op(
"fastvideo_kernel::block_sparse_attn_sm90",
mutates_args=(),
device_types="cuda",
)
def block_sparse_attn_sm90(
q_padded: torch.Tensor,
k_padded: torch.Tensor,
v_padded: torch.Tensor,
block_map: torch.Tensor,
variable_block_sizes: torch.Tensor,
) -> tuple[torch.Tensor, torch.Tensor]:
block_sparse_fwd, _ = _get_sm90_ops()
if block_sparse_fwd is None:
raise ImportError("fastvideo_kernel_ops.block_sparse_fwd is not available")
q_padded = q_padded.contiguous()
k_padded = k_padded.contiguous()
v_padded = v_padded.contiguous()
block_map = block_map.to(torch.bool)
q2k_idx, q2k_num = _map_to_index(block_map)
o_padded, lse_padded = block_sparse_fwd(
q_padded, k_padded, v_padded, q2k_idx, q2k_num, variable_block_sizes.int()
)
return o_padded, lse_padded
@torch.library.register_fake("fastvideo_kernel::block_sparse_attn_sm90")
def _block_sparse_attn_sm90_fake(
q_padded: torch.Tensor,
k_padded: torch.Tensor,
v_padded: torch.Tensor,
block_map: torch.Tensor,
variable_block_sizes: torch.Tensor,
) -> tuple[torch.Tensor, torch.Tensor]:
o = torch.empty_like(q_padded)
lse = torch.empty((q_padded.shape[0], q_padded.shape[1], q_padded.shape[2], 1), device=q_padded.device, dtype=torch.float32)
return o, lse
@torch.library.custom_op(
"fastvideo_kernel::block_sparse_attn_backward_sm90",
mutates_args=(),
device_types="cuda",
)
def block_sparse_attn_backward_sm90(
grad_output_padded: torch.Tensor,
q_padded: torch.Tensor,
k_padded: torch.Tensor,
v_padded: torch.Tensor,
o_padded: torch.Tensor,
lse_padded: torch.Tensor,
block_map: torch.Tensor,
variable_block_sizes: torch.Tensor,
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
_, block_sparse_bwd = _get_sm90_ops()
if block_sparse_bwd is None:
raise ImportError("fastvideo_kernel_ops.block_sparse_bwd is not available")
grad_output_padded = grad_output_padded.contiguous()
block_map = block_map.to(torch.bool)
k2q_idx, k2q_num = _map_to_index(block_map.transpose(-1, -2).contiguous())
dq, dk, dv = block_sparse_bwd(
q_padded,
k_padded,
v_padded,
o_padded,
lse_padded,
grad_output_padded,
k2q_idx,
k2q_num,
variable_block_sizes.int(),
)
    # The C++ kernel returns fp32 gradients; cast them back to the incoming
    # activation dtype so autograd sees gradients that match q/k/v.
    return dq.to(grad_output_padded.dtype), dk.to(grad_output_padded.dtype), dv.to(grad_output_padded.dtype)
@torch.library.register_fake("fastvideo_kernel::block_sparse_attn_backward_sm90")
def _block_sparse_attn_backward_sm90_fake(
grad_output_padded: torch.Tensor,
q_padded: torch.Tensor,
k_padded: torch.Tensor,
v_padded: torch.Tensor,
o_padded: torch.Tensor,
lse_padded: torch.Tensor,
block_map: torch.Tensor,
variable_block_sizes: torch.Tensor,
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
dq = torch.empty_like(q_padded)
dk = torch.empty_like(k_padded)
dv = torch.empty_like(v_padded)
return dq, dk, dv
def _backward_sm90(ctx, grad_o, grad_lse):
q, k, v, o, lse, block_map, variable_block_sizes = ctx.saved_tensors
dq, dk, dv = block_sparse_attn_backward_sm90(
grad_o, q, k, v, o, lse, block_map, variable_block_sizes
)
return dq, dk, dv, None, None
def _setup_context_sm90(ctx, inputs, output):
q, k, v, block_map, variable_block_sizes = inputs
o, lse = output
ctx.save_for_backward(q, k, v, o, lse, block_map, variable_block_sizes)
block_sparse_attn_sm90.register_autograd(_backward_sm90, setup_context=_setup_context_sm90)
def block_sparse_attn(
q: torch.Tensor,
k: torch.Tensor,
v: torch.Tensor,
block_map: torch.Tensor,
variable_block_sizes: torch.Tensor,
q_mean: torch.Tensor | None = None,
k_mean: torch.Tensor | None = None,
v_mean: torch.Tensor | None = None,
) -> tuple[torch.Tensor, torch.Tensor]:
"""
Unified block-sparse attention op with autograd support.
- On SM90 with compiled extension present: uses fastvideo_kernel_ops.block_sparse_fwd/bwd.
- Otherwise: uses Triton implementation (requires q/k/v to have same padded length today).
- P-quant QAT currently requires the Triton path.
- Passing q_mean/k_mean/v_mean enables tile-level compensation for
blocks omitted by block_map.
"""
if (q_mean is not None) or (k_mean is not None) or (v_mean is not None):
if q_mean is None or k_mean is None or v_mean is None:
raise ValueError("q_mean, k_mean, and v_mean must be provided together")
return _BlockSparseAttnTileComp.apply(q, k, v, q_mean, k_mean,
v_mean, block_map,
variable_block_sizes)
block_sparse_fwd, block_sparse_bwd = _get_sm90_ops()
if (
not _QAT_BACKWARD_ENABLED
and (not _force_triton())
and _is_sm90()
and (block_sparse_fwd is not None)
and (block_sparse_bwd is not None)
):
return block_sparse_attn_sm90(q, k, v, block_map, variable_block_sizes)
# Triton path: supports q_seq_len != kv_seq_len as long as both are padded
# to a multiple of the block size (64 tokens).
o, M, _ = block_sparse_attn_triton(q, k, v, block_map,
variable_block_sizes)
return o, M
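# Usage sketch for the unified wrapper (shapes as in the note above
# block_sparse_attn_triton; all tensors are illustrative assumptions):
#   o, lse = block_sparse_attn(q, k, v, block_map, variable_block_sizes)
#   # With tile-level compensation for dropped blocks:
#   o, lse = block_sparse_attn(q, k, v, block_map, variable_block_sizes,
#                              q_mean=qm, k_mean=km, v_mean=vm)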
def block_sparse_attn_qat(
q: torch.Tensor,
k: torch.Tensor,
v: torch.Tensor,
block_map: torch.Tensor,
variable_block_sizes: torch.Tensor,
) -> tuple[torch.Tensor, torch.Tensor]:
"""Block-sparse attention with FP4 fake quantization on P (QAT mode)."""
q = q.contiguous()
k = k.contiguous()
v = v.contiguous()
block_map = block_map.to(torch.bool)
q2k_idx, q2k_num = _map_to_index(block_map)
    from fastvideo_kernel.triton_kernels.block_sparse_attn_triton import (  # local import
        triton_block_sparse_attn_forward,
    )
    o, M, _ = triton_block_sparse_attn_forward(
        q, k, v, q2k_idx, q2k_num, variable_block_sizes, is_qat=True,
    )
    # Note: the QAT backward is handled through the autograd registered on
    # block_sparse_attn_triton; this helper is forward-only (inference/eval),
    # so the backward kernel is not imported here.
return o, M
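# Usage sketch (forward-only QAT evaluation; same illustrative shapes as above):
#   o, M = block_sparse_attn_qat(q, k, v, block_map, variable_block_sizes)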