Instructions to use yitongl/5090_test with libraries, inference providers, notebooks, and local apps. Follow these links to get started.
- Libraries
- Diffusers
How to use yitongl/5090_test with Diffusers:
pip install -U diffusers transformers accelerate
```python
import torch
from diffusers import DiffusionPipeline

# switch to "mps" for apple devices
pipe = DiffusionPipeline.from_pretrained("yitongl/5090_test", dtype=torch.bfloat16, device_map="cuda")

prompt = "Astronaut in a jungle, cold color palette, muted colors, detailed, 8k"
image = pipe(prompt).images[0]
```
- Notebooks
- Google Colab
- Kaggle
| from __future__ import annotations | |
| import os | |
| from typing import Tuple # noqa: F401 — kept for backward compat | |
| import torch | |
# Module-wide switch that turns on FP4 P-quantization in the Triton
# forward/backward path. The sparse FP4 backend is expected to flip it
# through ``set_qat_backward``.
_QAT_BACKWARD_ENABLED = False


def set_qat_backward(enabled: bool) -> None:
    """Store ``enabled`` in the module-wide ``_QAT_BACKWARD_ENABLED`` flag.

    Args:
        enabled: New value of the flag; it is read by the Triton wrappers
            and by the dispatch logic in this module.
    """
    global _QAT_BACKWARD_ENABLED
    _QAT_BACKWARD_ENABLED = enabled
def _use_high_prec_output_for_backward() -> bool:
    """Tell whether the high-precision forward output should feed backward.

    The decision comes from the ``FASTVIDEO_SPARSE_FP4_USE_HIGH_PREC_O``
    environment variable (default ``"1"``). Only the values ``0``, ``false``,
    ``no`` and ``off`` (compared case-insensitively) disable it; every other
    value counts as enabled.
    """
    disabling_values = ("0", "false", "no", "off")
    setting = os.environ.get("FASTVIDEO_SPARSE_FP4_USE_HIGH_PREC_O", "1")
    if setting.lower() in disabling_values:
        return False
    return True
def _get_sm90_ops():
    """Look up the compiled block-sparse forward/backward ops.

    Returns:
        A pair ``(block_sparse_fwd, block_sparse_bwd)`` read from
        ``fastvideo_kernel._C.fastvideo_kernel_ops``. An entry is ``None``
        when that attribute is missing; both entries are ``None`` when the
        extension cannot be imported at all.
    """
    try:
        from fastvideo_kernel._C import fastvideo_kernel_ops  # type: ignore
    except Exception:
        # Deliberately broad: any failure while importing the extension just
        # means the compiled path is unavailable.
        return None, None

    fwd_op = getattr(fastvideo_kernel_ops, "block_sparse_fwd", None)
    bwd_op = getattr(fastvideo_kernel_ops, "block_sparse_bwd", None)
    return fwd_op, bwd_op
def _is_sm90() -> bool:
    """Return ``True`` only when CUDA is available and device 0 reports compute capability 9.0."""
    if torch.cuda.is_available():
        # Only device index 0 is inspected.
        return torch.cuda.get_device_capability(0) == (9, 0)
    return False
def _force_triton() -> bool:
    """Return ``True`` when ``FASTVIDEO_KERNEL_VSA_FORCE_TRITON`` is exactly ``"1"``.

    The override forces the Triton path even on SM90 and even when the
    compiled extension is available, which is handy for CI, debugging and
    parity testing.
    """
    override = os.environ.get("FASTVIDEO_KERNEL_VSA_FORCE_TRITON", "0")
    return override == "1"
def _map_to_index(block_map: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
    """Turn a block map into index form using the Triton ``map_to_index`` helper.

    A 3-D map ``[H,Q,KV]`` gets a leading batch dimension; the map is cast to
    ``torch.bool`` when it has another dtype. There is no non-Triton fallback.

    Args:
        block_map: Block map of shape ``[B,H,Q,KV]`` or ``[H,Q,KV]`` on a
            CUDA device.

    Returns:
        Whatever ``fastvideo_kernel.triton_kernels.index.map_to_index``
        returns for the normalised map (annotated as a pair of tensors).

    Raises:
        ValueError: If the map is not 4-D after the optional batch expansion.
        RuntimeError: If the map is not a CUDA tensor.
        ImportError: If the Triton helper cannot be imported.
    """
    if block_map.dim() == 3:
        block_map = block_map.unsqueeze(0)

    rank = block_map.dim()
    if rank != 4:
        raise ValueError(f"block_map must be [B,H,Q,KV] (or [H,Q,KV]), got shape={tuple(block_map.shape)}")

    needs_cast = block_map.dtype != torch.bool
    if needs_cast:
        block_map = block_map.to(torch.bool)

    if not block_map.is_cuda:
        raise RuntimeError("block_map must be a CUDA tensor (Triton map_to_index required).")

    try:
        # Imported lazily so that merely importing this module does not
        # require Triton.
        from fastvideo_kernel.triton_kernels.index import map_to_index as triton_map_to_index  # local import
    except Exception as e:
        raise ImportError(
            "Triton map_to_index is required but not available. "
            "Ensure Triton is installed and fastvideo_kernel.triton_kernels.index is importable."
        ) from e

    return triton_map_to_index(block_map)
# NOTE(review): ``block_sparse_attn_triton.register_autograd(...)`` is called
# later in this file, which suggests this function is wrapped by a custom-op
# decorator that is not visible in this view -- confirm it is present.
def block_sparse_attn_triton(
    q: torch.Tensor,
    k: torch.Tensor,
    v: torch.Tensor,
    block_map: torch.Tensor,
    variable_block_sizes: torch.Tensor,
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
    """Run the Triton block-sparse attention forward kernel.

    Args:
        q, k, v: Attention inputs; each is made contiguous before the kernel
            is launched.
        block_map: Block map; cast to ``torch.bool`` and converted to index
            form with ``_map_to_index``.
        variable_block_sizes: Passed through to the kernel unchanged.

    Returns:
        ``(o, M, high_prec_o)`` exactly as returned by
        ``triton_block_sparse_attn_forward``. The module-wide
        ``_QAT_BACKWARD_ENABLED`` flag is forwarded as ``is_qat``.
    """
    q_c, k_c, v_c = q.contiguous(), k.contiguous(), v.contiguous()
    bool_map = block_map.to(torch.bool)
    q2k_idx, q2k_num = _map_to_index(bool_map)

    from fastvideo_kernel.triton_kernels.block_sparse_attn_triton import (  # local import
        triton_block_sparse_attn_forward,
    )

    out, row_stats, out_high_prec = triton_block_sparse_attn_forward(
        q_c,
        k_c,
        v_c,
        q2k_idx,
        q2k_num,
        variable_block_sizes,
        is_qat=_QAT_BACKWARD_ENABLED,
    )
    return out, row_stats, out_high_prec
def _block_sparse_attn_triton_fake(
    q: torch.Tensor,
    k: torch.Tensor,
    v: torch.Tensor,
    block_map: torch.Tensor,
    variable_block_sizes: torch.Tensor,
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
    """Allocate uninitialised outputs shaped like the Triton forward results.

    Presumably registered as the fake (shape-only) implementation of
    ``block_sparse_attn_triton`` -- the registration is not visible here.
    Only ``q`` is inspected; the remaining arguments are ignored.

    Returns:
        ``(o, M, high_prec_o)`` where ``o`` and ``high_prec_o`` match ``q`` in
        shape/dtype/device and ``M`` is float32 with ``q``'s first three
        dimensions.
    """
    batch, heads, q_len = q.shape[0], q.shape[1], q.shape[2]
    out = torch.empty_like(q)
    out_high_prec = torch.empty_like(q)
    row_stats = torch.empty((batch, heads, q_len), device=q.device, dtype=torch.float32)
    return out, row_stats, out_high_prec
def block_sparse_attn_backward_triton(
    grad_output: torch.Tensor,
    q: torch.Tensor,
    k: torch.Tensor,
    v: torch.Tensor,
    o: torch.Tensor,
    M: torch.Tensor,
    block_map: torch.Tensor,
    variable_block_sizes: torch.Tensor,
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
    """Compute ``dq, dk, dv`` with the Triton block-sparse backward kernel.

    Args:
        grad_output: Gradient with respect to the attention output.
        q, k, v: The forward inputs.
        o: Forward output handed to the kernel.
        M: The ``M`` tensor returned by the forward op.
        block_map: Block map; cast to ``torch.bool`` and converted to both
            q->k and k->q index form.
        variable_block_sizes: Passed through to the kernel unchanged.

    Returns:
        ``(dq, dk, dv)`` as returned by ``triton_block_sparse_attn_backward``.
        The module-wide ``_QAT_BACKWARD_ENABLED`` flag is forwarded as
        ``is_qat``.
    """
    grad_output = grad_output.contiguous()
    # The forward wrapper launches its kernel on contiguous q/k/v, but the
    # tensors saved for backward are the caller's originals, so normalise
    # them here as well. ``contiguous()`` is a no-op for contiguous tensors.
    q = q.contiguous()
    k = k.contiguous()
    v = v.contiguous()

    block_map = block_map.to(torch.bool)
    q2k_idx, q2k_num = _map_to_index(block_map)
    # The k->q view is the same map with its last two axes swapped.
    k2q_idx, k2q_num = _map_to_index(block_map.transpose(-1, -2).contiguous())

    from fastvideo_kernel.triton_kernels.block_sparse_attn_triton import (  # local import
        triton_block_sparse_attn_backward,
    )

    dq, dk, dv = triton_block_sparse_attn_backward(
        grad_output, q, k, v, o, M, q2k_idx, q2k_num, k2q_idx, k2q_num, variable_block_sizes,
        is_qat=_QAT_BACKWARD_ENABLED,
    )
    return dq, dk, dv
def _block_sparse_attn_backward_triton_fake(
    grad_output: torch.Tensor,
    q: torch.Tensor,
    k: torch.Tensor,
    v: torch.Tensor,
    o: torch.Tensor,
    M: torch.Tensor,
    block_map: torch.Tensor,
    variable_block_sizes: torch.Tensor,
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
    """Allocate uninitialised gradients shaped like ``q``, ``k`` and ``v``.

    Presumably the fake (shape-only) counterpart of
    ``block_sparse_attn_backward_triton`` -- the registration is not visible
    here. Every argument other than ``q``, ``k`` and ``v`` is ignored.
    """
    grad_q, grad_k, grad_v = (torch.empty_like(t) for t in (q, k, v))
    return grad_q, grad_k, grad_v
def _backward_triton(ctx, grad_o, grad_M, grad_high_prec_o):
    """Autograd backward for ``block_sparse_attn_triton``.

    Only ``grad_o`` is used; the incoming gradients for ``M`` and
    ``high_prec_o`` are ignored.

    Returns:
        ``(dq, dk, dv, None, None)`` -- one entry per forward input, with no
        gradient for ``block_map`` or ``variable_block_sizes``.
    """
    (
        q,
        k,
        v,
        o_for_bwd,
        M,
        block_map,
        variable_block_sizes,
    ) = ctx.saved_tensors
    grad_q, grad_k, grad_v = block_sparse_attn_backward_triton(
        grad_o, q, k, v, o_for_bwd, M, block_map, variable_block_sizes
    )
    return grad_q, grad_k, grad_v, None, None
def _setup_context_triton(ctx, inputs, output):
    """Save the tensors that ``_backward_triton`` needs.

    Args:
        ctx: Autograd context receiving the saved tensors.
        inputs: ``(q, k, v, block_map, variable_block_sizes)``.
        output: ``(o, M, high_prec_o)`` from the forward op.
    """
    q, k, v, block_map, variable_block_sizes = inputs
    o, M, high_prec_o = output

    # Backward consumes the high-precision output only when QAT is active
    # and the environment switch allows it; otherwise the regular output.
    if _QAT_BACKWARD_ENABLED and _use_high_prec_output_for_backward():
        o_for_bwd = high_prec_o
    else:
        o_for_bwd = o

    ctx.save_for_backward(
        q, k, v, o_for_bwd, M, block_map, variable_block_sizes
    )
# Attach the Triton backward and its context-setup hook to the forward op.
block_sparse_attn_triton.register_autograd(
    _backward_triton,
    setup_context=_setup_context_triton,
)
class _BlockSparseAttnTileComp(torch.autograd.Function):
    """Triton block-sparse attention with tile-level compensation.

    Besides ``q``/``k``/``v`` the kernels receive ``q_mean``/``k_mean``/
    ``v_mean`` and the index form of the *dropped* blocks (the logical
    negation of ``block_map``).
    """

    @staticmethod
    def forward(ctx, q, k, v, q_mean, k_mean, v_mean, block_map,
                variable_block_sizes):
        """Run the forward kernel and save what ``backward`` needs.

        Returns:
            ``(o, M)`` -- the first two outputs of
            ``triton_block_sparse_attn_forward``.
        """
        q = q.contiguous()
        k = k.contiguous()
        v = v.contiguous()
        q_mean = q_mean.contiguous()
        k_mean = k_mean.contiguous()
        v_mean = v_mean.contiguous()
        block_map = block_map.to(torch.bool)
        # Blocks not selected by block_map.
        dropped_block_map = torch.logical_not(block_map)
        q2k_idx, q2k_num = _map_to_index(block_map)
        dropped_q2k_idx, dropped_q2k_num = _map_to_index(dropped_block_map)

        from fastvideo_kernel.triton_kernels.block_sparse_attn_triton import (  # local import
            triton_block_sparse_attn_forward,
        )

        o, M, high_prec_o = triton_block_sparse_attn_forward(
            q,
            k,
            v,
            q2k_idx,
            q2k_num,
            variable_block_sizes,
            is_qat=_QAT_BACKWARD_ENABLED,
            q_mean=q_mean,
            k_mean=k_mean,
            v_mean=v_mean,
            dropped_q2k_index=dropped_q2k_idx,
            dropped_q2k_num=dropped_q2k_num,
        )
        # Backward uses the high-precision output only when QAT is active
        # and the environment switch allows it.
        o_for_bwd = (
            high_prec_o
            if _QAT_BACKWARD_ENABLED and _use_high_prec_output_for_backward()
            else o
        )
        # The contiguous/bool-normalised tensors are saved, so backward can
        # pass them to the kernel directly.
        ctx.save_for_backward(q, k, v, q_mean, k_mean, v_mean, o_for_bwd, M,
                              block_map, dropped_block_map,
                              variable_block_sizes)
        return o, M

    @staticmethod
    def backward(ctx, grad_o, grad_M):
        """Compute gradients for ``q``, ``k`` and ``v``.

        ``grad_M`` is ignored.

        Returns:
            ``(dq, dk, dv, None, None, None, None, None)`` -- one entry per
            ``forward`` input; the means, the block map and the block sizes
            get no gradient.
        """
        (q, k, v, q_mean, k_mean, v_mean, o_for_bwd, M, block_map,
         dropped_block_map, variable_block_sizes) = ctx.saved_tensors
        # Index forms are rebuilt from the saved maps; the k->q views are the
        # maps with their last two axes swapped.
        q2k_idx, q2k_num = _map_to_index(block_map)
        k2q_idx, k2q_num = _map_to_index(block_map.transpose(-1, -2).contiguous())
        dropped_q2k_idx, dropped_q2k_num = _map_to_index(dropped_block_map)
        dropped_k2q_idx, dropped_k2q_num = _map_to_index(
            dropped_block_map.transpose(-1, -2).contiguous())

        from fastvideo_kernel.triton_kernels.block_sparse_attn_triton import (  # local import
            triton_block_sparse_attn_backward,
        )

        dq, dk, dv = triton_block_sparse_attn_backward(
            grad_o.contiguous(),
            q,
            k,
            v,
            o_for_bwd,
            M,
            q2k_idx,
            q2k_num,
            k2q_idx,
            k2q_num,
            variable_block_sizes,
            is_qat=_QAT_BACKWARD_ENABLED,
            q_mean=q_mean,
            k_mean=k_mean,
            v_mean=v_mean,
            dropped_q2k_index=dropped_q2k_idx,
            dropped_q2k_num=dropped_q2k_num,
            dropped_k2q_index=dropped_k2q_idx,
            dropped_k2q_num=dropped_k2q_num,
        )
        return dq, dk, dv, None, None, None, None, None
# NOTE(review): ``block_sparse_attn_sm90.register_autograd(...)`` is called
# later in this file, which suggests this function is wrapped by a custom-op
# decorator that is not visible in this view -- confirm it is present.
def block_sparse_attn_sm90(
    q_padded: torch.Tensor,
    k_padded: torch.Tensor,
    v_padded: torch.Tensor,
    block_map: torch.Tensor,
    variable_block_sizes: torch.Tensor,
) -> tuple[torch.Tensor, torch.Tensor]:
    """Run the compiled ``block_sparse_fwd`` extension op.

    Args:
        q_padded, k_padded, v_padded: Attention inputs; each is made
            contiguous before the op is called.
        block_map: Block map; cast to ``torch.bool`` and converted to index
            form with ``_map_to_index``.
        variable_block_sizes: Converted with ``.int()`` before being passed
            to the op.

    Returns:
        ``(o_padded, lse_padded)`` as returned by the extension op.

    Raises:
        ImportError: If ``block_sparse_fwd`` is not available.
    """
    fwd_op, _ = _get_sm90_ops()
    if fwd_op is None:
        raise ImportError("fastvideo_kernel_ops.block_sparse_fwd is not available")

    q_c = q_padded.contiguous()
    k_c = k_padded.contiguous()
    v_c = v_padded.contiguous()

    bool_map = block_map.to(torch.bool)
    q2k_idx, q2k_num = _map_to_index(bool_map)

    block_sizes = variable_block_sizes.int()
    out, lse = fwd_op(q_c, k_c, v_c, q2k_idx, q2k_num, block_sizes)
    return out, lse
def _block_sparse_attn_sm90_fake(
    q_padded: torch.Tensor,
    k_padded: torch.Tensor,
    v_padded: torch.Tensor,
    block_map: torch.Tensor,
    variable_block_sizes: torch.Tensor,
) -> tuple[torch.Tensor, torch.Tensor]:
    """Allocate uninitialised outputs shaped like the SM90 forward results.

    Presumably registered as the fake (shape-only) implementation of
    ``block_sparse_attn_sm90`` -- the registration is not visible here.
    Only ``q_padded`` is inspected.

    Returns:
        ``(o, lse)`` where ``o`` matches ``q_padded`` and ``lse`` is float32
        with ``q_padded``'s first three dimensions plus a trailing axis of
        size 1.
    """
    lse_shape = (*q_padded.shape[:3], 1)
    out = torch.empty_like(q_padded)
    lse = torch.empty(lse_shape, device=q_padded.device, dtype=torch.float32)
    return out, lse
def block_sparse_attn_backward_sm90(
    grad_output_padded: torch.Tensor,
    q_padded: torch.Tensor,
    k_padded: torch.Tensor,
    v_padded: torch.Tensor,
    o_padded: torch.Tensor,
    lse_padded: torch.Tensor,
    block_map: torch.Tensor,
    variable_block_sizes: torch.Tensor,
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
    """Compute ``dq, dk, dv`` with the compiled ``block_sparse_bwd`` extension op.

    Args:
        grad_output_padded: Gradient with respect to the attention output.
        q_padded, k_padded, v_padded: The forward inputs.
        o_padded, lse_padded: The two outputs of the forward op.
        block_map: Block map; cast to ``torch.bool`` and converted to k->q
            index form.
        variable_block_sizes: Converted with ``.int()`` before being passed
            to the op.

    Returns:
        ``(dq, dk, dv)`` cast to the dtype of ``grad_output_padded``.

    Raises:
        ImportError: If ``block_sparse_bwd`` is not available.
    """
    _, block_sparse_bwd = _get_sm90_ops()
    if block_sparse_bwd is None:
        raise ImportError("fastvideo_kernel_ops.block_sparse_bwd is not available")

    grad_output_padded = grad_output_padded.contiguous()
    # The forward wrapper calls its op with contiguous q/k/v, but the tensors
    # saved for backward are the caller's originals, so normalise them here
    # as well. ``contiguous()`` is a no-op for contiguous tensors.
    q_padded = q_padded.contiguous()
    k_padded = k_padded.contiguous()
    v_padded = v_padded.contiguous()

    block_map = block_map.to(torch.bool)
    # The k->q view is the map with its last two axes swapped.
    k2q_idx, k2q_num = _map_to_index(block_map.transpose(-1, -2).contiguous())

    dq, dk, dv = block_sparse_bwd(
        q_padded,
        k_padded,
        v_padded,
        o_padded,
        lse_padded,
        grad_output_padded,
        k2q_idx,
        k2q_num,
        variable_block_sizes.int(),
    )

    # C++ kernel returns fp32 grads; cast back to match PyTorch convention if needed
    grad_dtype = grad_output_padded.dtype
    return dq.to(grad_dtype), dk.to(grad_dtype), dv.to(grad_dtype)
def _block_sparse_attn_backward_sm90_fake(
    grad_output_padded: torch.Tensor,
    q_padded: torch.Tensor,
    k_padded: torch.Tensor,
    v_padded: torch.Tensor,
    o_padded: torch.Tensor,
    lse_padded: torch.Tensor,
    block_map: torch.Tensor,
    variable_block_sizes: torch.Tensor,
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
    """Allocate uninitialised gradients shaped like the padded ``q``, ``k`` and ``v``.

    Presumably the fake (shape-only) counterpart of
    ``block_sparse_attn_backward_sm90`` -- the registration is not visible
    here. Every argument other than the three padded inputs is ignored.
    """
    grad_q, grad_k, grad_v = (
        torch.empty_like(t) for t in (q_padded, k_padded, v_padded)
    )
    return grad_q, grad_k, grad_v
def _backward_sm90(ctx, grad_o, grad_lse):
    """Autograd backward for ``block_sparse_attn_sm90``.

    Only ``grad_o`` is used; the incoming gradient for ``lse`` is ignored.

    Returns:
        ``(dq, dk, dv, None, None)`` -- one entry per forward input, with no
        gradient for ``block_map`` or ``variable_block_sizes``.
    """
    (
        q,
        k,
        v,
        o,
        lse,
        block_map,
        variable_block_sizes,
    ) = ctx.saved_tensors
    grad_q, grad_k, grad_v = block_sparse_attn_backward_sm90(
        grad_o, q, k, v, o, lse, block_map, variable_block_sizes
    )
    return grad_q, grad_k, grad_v, None, None
def _setup_context_sm90(ctx, inputs, output):
    """Save the tensors that ``_backward_sm90`` needs.

    Args:
        ctx: Autograd context receiving the saved tensors.
        inputs: ``(q, k, v, block_map, variable_block_sizes)``.
        output: ``(o, lse)`` from the forward op.
    """
    q, k, v, block_map, variable_block_sizes = inputs
    attn_out, lse = output
    # Order matters: ``_backward_sm90`` unpacks the saved tensors positionally.
    to_save = (q, k, v, attn_out, lse, block_map, variable_block_sizes)
    ctx.save_for_backward(*to_save)
# Attach the SM90 backward and its context-setup hook to the forward op.
block_sparse_attn_sm90.register_autograd(
    _backward_sm90,
    setup_context=_setup_context_sm90,
)
def block_sparse_attn(
    q: torch.Tensor,
    k: torch.Tensor,
    v: torch.Tensor,
    block_map: torch.Tensor,
    variable_block_sizes: torch.Tensor,
    q_mean: torch.Tensor | None = None,
    k_mean: torch.Tensor | None = None,
    v_mean: torch.Tensor | None = None,
) -> tuple[torch.Tensor, torch.Tensor]:
    """
    Unified block-sparse attention op with autograd support.

    Dispatch order:
    - If q_mean/k_mean/v_mean are passed: Triton path with tile-level
      compensation for blocks omitted by block_map.
    - On SM90 with the compiled extension present (and neither QAT nor the
      force-Triton override active): fastvideo_kernel_ops.block_sparse_fwd/bwd.
    - Otherwise: the Triton implementation.
    - P-quant QAT currently requires the Triton path.

    NOTE(review): an earlier version of this docstring said the Triton path
    requires q/k/v to share the same padded length, while the inline comment
    on the Triton branch says differing q/kv lengths are supported -- confirm
    which statement is current.

    Args:
        q, k, v: Attention inputs.
        block_map: Block map selecting which blocks are attended.
        variable_block_sizes: Passed through to the selected backend.
        q_mean, k_mean, v_mean: Optional means; must be given all together
            or not at all.

    Returns:
        The first two outputs of the selected backend (``(o, M)`` on the
        Triton paths, ``(o, lse)`` on the SM90 path).

    Raises:
        ValueError: If only some of q_mean/k_mean/v_mean are provided.
    """
    means = (q_mean, k_mean, v_mean)
    if any(m is not None for m in means):
        if any(m is None for m in means):
            raise ValueError("q_mean, k_mean, and v_mean must be provided together")
        return _BlockSparseAttnTileComp.apply(q, k, v, q_mean, k_mean,
                                              v_mean, block_map,
                                              variable_block_sizes)

    # Evaluate the cheap checks first and only probe the compiled extension
    # (an import attempt) when the SM90 path is still a candidate.
    if not _QAT_BACKWARD_ENABLED and not _force_triton() and _is_sm90():
        block_sparse_fwd, block_sparse_bwd = _get_sm90_ops()
        if block_sparse_fwd is not None and block_sparse_bwd is not None:
            return block_sparse_attn_sm90(q, k, v, block_map, variable_block_sizes)

    # Triton path: supports q_seq_len != kv_seq_len as long as both are padded
    # to a multiple of the block size (64 tokens).
    o, M, _ = block_sparse_attn_triton(q, k, v, block_map,
                                       variable_block_sizes)
    return o, M
def block_sparse_attn_qat(
    q: torch.Tensor,
    k: torch.Tensor,
    v: torch.Tensor,
    block_map: torch.Tensor,
    variable_block_sizes: torch.Tensor,
) -> tuple[torch.Tensor, torch.Tensor]:
    """Block-sparse attention with FP4 fake quantization on P (QAT mode).

    Calls the Triton forward kernel directly with ``is_qat=True``, regardless
    of the module-wide QAT flag.

    Args:
        q, k, v: Attention inputs; each is made contiguous before the kernel
            is launched.
        block_map: Block map; cast to ``torch.bool`` and converted to index
            form with ``_map_to_index``.
        variable_block_sizes: Passed through to the kernel unchanged.

    Returns:
        ``(o, M)`` -- the first two outputs of
        ``triton_block_sparse_attn_forward``; the third output is discarded.
    """
    q = q.contiguous()
    k = k.contiguous()
    v = v.contiguous()
    block_map = block_map.to(torch.bool)
    q2k_idx, q2k_num = _map_to_index(block_map)

    # Only the forward kernel is needed here, so only it is imported.
    from fastvideo_kernel.triton_kernels.block_sparse_attn_triton import (  # local import
        triton_block_sparse_attn_forward,
    )

    o, M, _ = triton_block_sparse_attn_forward(
        q, k, v, q2k_idx, q2k_num, variable_block_sizes, is_qat=True,
    )
    # Note: this wrapper registers no backward of its own; QAT backward is
    # handled through the autograd of block_sparse_attn_triton. For now, this
    # entry point is forward-only for inference.
    return o, M