ColabWan / postprocessing /flashvsr /attention_backend.py
1ripon1's picture
Upload folder using huggingface_hub
7344bef verified
Raw
History Blame Contribute Delete
13.6 kB
from __future__ import annotations
import importlib
import importlib.util
import traceback
from typing import Callable
import torch
from .sparse_backend_config import (
SPARSE_BACKEND_AUTO as _AUTO_BACKEND,
SPARSE_BACKEND_LABELS as _BACKEND_LABELS,
SPARSE_BACKEND_SPARGE as _SPARGE_BACKEND,
SPARSE_BACKEND_TRITON_SPARSE as _TRITON_SPARSE_BACKEND,
normalize_sparse_backend,
)
# Mutable module state for the lazily loaded backend. set_sparse_backend() resets the
# cached callable, its name, the recorded error and the "printed" flag whenever the
# selected backend changes.
_SPARSE_BACKEND = _AUTO_BACKEND  # currently selected backend identifier
_SPARSE_ATTENTION: Callable | None = None  # cached attention callable; None until loaded
_BACKEND_NAME: str | None = None  # display name stored alongside the cached callable
_BACKEND_ERROR: str | None = None  # last failure message recorded by _load_backend()
_PRINTED_BACKEND = False  # True once log_sparse_backend() has printed the backend name

# De-duplication sets so each diagnostic is printed only once per process.
_PRINTED_IMPORT_ERRORS: set[str] = set()
_PRINTED_AUTO_FALLBACKS: set[str] = set()

# Shared fragments of the user-facing requirement messages.
_REQUIREMENTS_MESSAGE = "FlashVSR sparse attention requirements are not satisfied."
_INSTALL_MESSAGE = "Install them from docs/INSTALLATION.md and restart WanGP."

# Display name reported for the Triton sparse attention implementation bundled with this package.
_BUNDLED_SPARSE_BACKEND_NAME = "bundled Triton Sparse Attention"

# For each backend: (importable module name, display name) pairs that must be present.
_BACKEND_DEPENDENCIES = {
    _SPARGE_BACKEND: (
        ("triton", "Triton"),
        ("spas_sage_attn", "SpargeAttn"),
    ),
    _TRITON_SPARSE_BACKEND: (
        ("triton", "Triton"),
    ),
}

# Kernel descriptors: (flag attribute looked up on the SpargeAttn core module,
# compile module name, direct kernel module name).
_SM80_KERNEL = ("SM80_ENABLED", "spas_sage_attn.sm80_compile", "spas_sage_attn._qattn_sm80")
_SM89_KERNEL = ("SM89_ENABLED", "spas_sage_attn.sm89_compile", "spas_sage_attn._qattn_sm89")
_SM90_KERNEL = ("SM90_ENABLED", "spas_sage_attn.sm90_compile", "spas_sage_attn._qattn_sm90")

# CUDA architecture string (as produced by _current_cuda_arch) -> kernel descriptor.
_ARCH_KERNELS = {
    "sm80": _SM80_KERNEL,
    "sm86": _SM80_KERNEL,
    "sm87": _SM80_KERNEL,
    "sm89": _SM89_KERNEL,
    "sm90": _SM90_KERNEL,
    "sm100": _SM89_KERNEL,
    "sm120": _SM89_KERNEL,
    "sm121": _SM89_KERNEL,
}
def _print_import_error(module_name: str, exc: BaseException) -> None:
    """Print an import failure and its traceback, once per distinct error.

    Args:
        module_name: Name of the module whose import failed.
        exc: The exception raised by the import.
    """
    # The key combines module, exception type and message, so a repeated
    # identical failure stays silent while a different failure is still shown.
    dedupe_key = f"{module_name}:{type(exc).__name__}:{exc}"
    if dedupe_key not in _PRINTED_IMPORT_ERRORS:
        _PRINTED_IMPORT_ERRORS.add(dedupe_key)
        print(f"[FlashVSR] Importing {module_name} failed:")
        traceback.print_exception(type(exc), exc, exc.__traceback__)
def set_sparse_backend(backend: object) -> str:
    """Select the sparse attention backend and return its normalized identifier.

    When the normalized value differs from the current selection, the cached
    attention callable, its name, the recorded error and the "already printed"
    flag are reset so the next use loads and reports the new backend.

    Args:
        backend: Requested backend; passed through ``normalize_sparse_backend``.

    Returns:
        The backend identifier now stored as the module-wide selection.
    """
    global _SPARSE_BACKEND, _SPARSE_ATTENTION, _BACKEND_NAME, _BACKEND_ERROR, _PRINTED_BACKEND
    requested = normalize_sparse_backend(backend)
    selection_changed = requested != _SPARSE_BACKEND
    _SPARSE_BACKEND = requested
    if selection_changed:
        # Drop everything derived from the previous selection.
        _SPARSE_ATTENTION = _BACKEND_NAME = _BACKEND_ERROR = None
        _PRINTED_BACKEND = False
    return _SPARSE_BACKEND
def _selected_sparse_backend(backend: object | None = None) -> str:
    """Resolve the backend to use: the explicit argument, or the module-wide selection.

    Args:
        backend: Optional override; normalized when given.

    Returns:
        The normalized override, or the current module-wide backend when ``backend`` is None.
    """
    if backend is None:
        return _SPARSE_BACKEND
    return normalize_sparse_backend(backend)
def _print_auto_fallback(message: str) -> None:
    """Announce, once per distinct reason, that Auto is falling back from SpargeAttn.

    Args:
        message: The reason the SpargeAttn backend could not be used.
    """
    already_reported = message in _PRINTED_AUTO_FALLBACKS
    if not already_reported:
        _PRINTED_AUTO_FALLBACKS.add(message)
        print(f"[FlashVSR] Auto backend cannot use {_BACKEND_LABELS[_SPARGE_BACKEND]}: {message} Install SpargeAttn for better FlashVSR quality.")
        print(f"[FlashVSR] Auto backend trying {_BACKEND_LABELS[_TRITON_SPARSE_BACKEND]}.")
def _missing_sparse_attention_dependencies(backend: str) -> list[str]:
    """Return the display names of the backend's dependencies that cannot be located.

    Args:
        backend: Key into ``_BACKEND_DEPENDENCIES``.

    Returns:
        Display names, in table order, of every dependency for which
        ``importlib.util.find_spec`` finds no module. Empty when all are present.

    Raises:
        KeyError: If ``backend`` has no entry in ``_BACKEND_DEPENDENCIES``.
    """
    missing = []
    for module_name, display_name in _BACKEND_DEPENDENCIES[backend]:
        try:
            spec = importlib.util.find_spec(module_name)
        except ValueError:
            # find_spec raises ValueError when the module is already in
            # sys.modules but its __spec__ is None (e.g. a stub module).
            # Such a module is importable, so it is not "missing"; the
            # import check that follows in the caller decides whether it
            # actually works.
            continue
        if spec is None:
            missing.append(display_name)
    return missing
def _missing_dependencies_message(backend: str, missing: list[str]) -> str:
    """Build the user-facing message listing a backend's missing dependencies.

    Args:
        backend: Backend identifier; its label is looked up in ``_BACKEND_LABELS``.
        missing: Display names of the missing dependencies.

    Returns:
        The requirements message naming the backend and the comma-separated
        missing dependencies, followed by the install hint.
    """
    message = f"{_REQUIREMENTS_MESSAGE} Backend: {_BACKEND_LABELS[backend]}. Missing: {', '.join(missing)}. {_INSTALL_MESSAGE}"
    return message
def _dependency_import_message(display_name: str, module_name: str, exc: BaseException) -> str:
    """Build the message for a dependency that is installed but fails to import.

    Args:
        display_name: Human-readable name of the dependency.
        module_name: Module whose import raised.
        exc: The exception raised by the import.

    Returns:
        The requirements message including the exception type and text.
    """
    message = f"{_REQUIREMENTS_MESSAGE} {display_name} is installed, but importing {module_name} failed. Check the console for the import error, then reinstall from docs/INSTALLATION.md and restart WanGP. Import failed: {type(exc).__name__}: {exc}"
    return message
def _kernel_load_message(sparge_error: str | None) -> str:
    """Build the message for SpargeAttn kernels that could not be loaded.

    Args:
        sparge_error: Description of the failure; a falsy value is reported
            as ``not installed``.

    Returns:
        The requirements message with the failure description appended.
    """
    message = f"{_REQUIREMENTS_MESSAGE} SpargeAttn is installed, but its kernels could not be loaded. Reinstall SpargeAttn from docs/INSTALLATION.md and restart WanGP. SpargeAttn import failed: {sparge_error or 'not installed'}"
    return message
def _arch_kernel_load_message(arch: str, module_name: str, exc: BaseException | None) -> str:
    """Build the message for an architecture-specific SpargeAttn kernel problem.

    Args:
        arch: CUDA architecture string (for example ``sm80``).
        module_name: Kernel module involved; only reported when ``exc`` is None.
        exc: Import exception, or None when the kernel is simply unavailable.

    Returns:
        An "unavailable kernel" message when ``exc`` is None, otherwise an
        "import failed" message carrying the exception type and text.
    """
    if exc is None:
        message = f"{_REQUIREMENTS_MESSAGE} SpargeAttn is installed, but its {arch} kernel is unavailable. Reinstall SpargeAttn from docs/INSTALLATION.md and restart WanGP. Missing kernel module: {module_name}"
    else:
        message = f"{_REQUIREMENTS_MESSAGE} SpargeAttn is installed, but importing its {arch} kernel failed. Check the console for the import error, then reinstall SpargeAttn from docs/INSTALLATION.md and restart WanGP. Import failed: {type(exc).__name__}: {exc}"
    return message
def _dependency_import_error(backend: str) -> str | None:
    """Import each dependency of ``backend`` and report the first failure.

    Args:
        backend: Key into ``_BACKEND_DEPENDENCIES``.

    Returns:
        A user-facing message for the first dependency whose import raises,
        or None when every dependency imports cleanly.
    """
    for dependency, label in _BACKEND_DEPENDENCIES[backend]:
        try:
            importlib.import_module(dependency)
        except Exception as import_exc:
            # Show the full traceback on the console, then stop at this dependency.
            _print_import_error(dependency, import_exc)
            return _dependency_import_message(label, dependency, import_exc)
    return None
def _import_sparge_core():
    """Import ``shared.spas_sage_attn_core``.

    Returns:
        ``(module, None)`` on success, or ``(None, "<ExceptionType>: <text>")``
        when the import raises. The failure is also printed via
        ``_print_import_error``.
    """
    core_module_name = "shared.spas_sage_attn_core"
    try:
        core_module = importlib.import_module(core_module_name)
    except Exception as exc:
        _print_import_error(core_module_name, exc)
        return None, f"{type(exc).__name__}: {exc}"
    return core_module, None
def _current_cuda_arch() -> str | None:
if not torch.cuda.is_available():
return None
major, minor = torch.cuda.get_device_capability(torch.cuda.current_device())
return f"sm{major}{minor}"
def _arch_kernel_error(module, arch: str | None) -> str | None:
    """Explain why the kernel for ``arch`` is not enabled in ``module``, if it is not.

    Args:
        module: The imported SpargeAttn core module; its per-architecture
            flag attribute is read with ``getattr``.
        arch: CUDA architecture string, or None.

    Returns:
        None when ``arch`` is None or not listed in ``_ARCH_KERNELS``, or when
        the module's flag for that architecture is truthy. Otherwise a
        user-facing message describing which kernel module failed to import or
        is unavailable.
    """
    kernel_entry = _ARCH_KERNELS.get(arch) if arch is not None else None
    if kernel_entry is None:
        return None
    flag_attr, compile_name, direct_name = kernel_entry
    if getattr(module, flag_attr, False):
        return None

    # The flag is off: import the kernel modules only to find out why.
    try:
        importlib.import_module(compile_name)
    except ModuleNotFoundError as exc:
        if exc.name == compile_name:
            # The compile module itself is absent; probe the direct kernel module instead.
            try:
                importlib.import_module(direct_name)
            except Exception as direct_exc:
                _print_import_error(direct_name, direct_exc)
                return _arch_kernel_load_message(arch, direct_name, direct_exc)
            return _arch_kernel_load_message(arch, direct_name, None)
        # The compile module exists but one of its own imports is missing.
        _print_import_error(compile_name, exc)
        return _arch_kernel_load_message(arch, compile_name, exc)
    except Exception as exc:
        _print_import_error(compile_name, exc)
        return _arch_kernel_load_message(arch, compile_name, exc)
    # The import succeeded, yet the flag is still off: report the kernel as unavailable.
    return _arch_kernel_load_message(arch, compile_name, None)
def _load_triton_sparse_backend() -> tuple[Callable | None, str | None, str | None]:
    """Load the bundled Triton sparse attention implementation.

    Returns:
        ``(attention_fn, backend_name, None)`` on success, or
        ``(None, None, message)`` when importing ``.sparse_sage.core`` fails.
    """
    core_module_name = "postprocessing.flashvsr.sparse_sage.core"
    try:
        from .sparse_sage.core import sparse_sageattn
    except Exception as exc:
        _print_import_error(core_module_name, exc)
        failure = _dependency_import_message(_BUNDLED_SPARSE_BACKEND_NAME, core_module_name, exc)
        return None, None, failure

    def bundled_sparse_sage(qkv_list: list[torch.Tensor], mask_id: torch.Tensor | list[torch.Tensor], recycle_q: bool = False) -> torch.Tensor:
        # recycle_q is accepted for signature parity with the other backend but is not forwarded.
        int8_mask = _int8_mask(mask_id)
        return sparse_sageattn(qkv_list, mask_id=_take_mask(int8_mask), is_causal=False, tensor_layout="HND")

    return bundled_sparse_sage, _BUNDLED_SPARSE_BACKEND_NAME, None
def _backend_requirement_status(backend: str) -> tuple[Callable | None, str | None, str | None]:
    """Check one concrete backend and return its callable or the reason it is unusable.

    Args:
        backend: A concrete backend identifier (a key of ``_BACKEND_DEPENDENCIES``).

    Returns:
        ``(attention_fn, backend_name, None)`` when the backend is usable,
        otherwise ``(None, None, message)``.
    """
    # Step 1: the dependencies must be locatable.
    not_installed = _missing_sparse_attention_dependencies(backend)
    if not_installed:
        return None, None, _missing_dependencies_message(backend, not_installed)

    # Step 2: they must also import without raising.
    import_failure = _dependency_import_error(backend)
    if import_failure is not None:
        return None, None, import_failure

    if backend == _TRITON_SPARSE_BACKEND:
        return _load_triton_sparse_backend()

    # Remaining case: the SpargeAttn path.
    core_module, core_error = _import_sparge_core()
    if core_module is None:
        return None, None, _kernel_load_message(core_error)

    kernel_failure = _arch_kernel_error(core_module, _current_cuda_arch())
    if kernel_failure is not None:
        return None, None, kernel_failure

    attention_fn = getattr(core_module, "block_sparse_attn_cuda", None)
    if callable(attention_fn):
        return attention_fn, "WanGP SpargeAttn block sparse CUDA", None
    return None, None, _kernel_load_message("WanGP SpargeAttn block sparse CUDA function not found")
def _sparse_attention_requirement_status(backend: object | None = None) -> tuple[Callable | None, str | None, str | None]:
    """Resolve the requested (or selected) backend, with Auto falling back to Triton sparse.

    Args:
        backend: Optional backend override; None uses the module-wide selection.

    Returns:
        ``(attention_fn, backend_name, None)`` for the first usable backend,
        otherwise ``(None, None, message)``. In Auto mode SpargeAttn is tried
        first and the Triton sparse backend second; when both fail the message
        contains both reasons.
    """
    resolved = _selected_sparse_backend(backend)
    if resolved != _AUTO_BACKEND:
        return _backend_requirement_status(resolved)

    preferred_fn, preferred_name, sparge_message = _backend_requirement_status(_SPARGE_BACKEND)
    if sparge_message is None:
        return preferred_fn, preferred_name, None

    # SpargeAttn is unusable: say so once, then try the Triton sparse backend.
    _print_auto_fallback(sparge_message)
    fallback_fn, fallback_name, triton_sparse_message = _backend_requirement_status(_TRITON_SPARSE_BACKEND)
    if triton_sparse_message is None:
        return fallback_fn, fallback_name, None

    return None, None, f"FlashVSR Auto backend could not load any sparse attention backend. Sparge: {sparge_message} {_BACKEND_LABELS[_TRITON_SPARSE_BACKEND]}: {triton_sparse_message}"
def sparse_attention_requirement_message(backend: object | None = None) -> str | None:
    """Return why sparse attention cannot be used, or None when it can.

    Args:
        backend: Optional backend override; None uses the module-wide selection.

    Returns:
        The failure message from the requirement check, or None on success.
    """
    _fn, _name, failure = _sparse_attention_requirement_status(backend)
    return failure
def sparge_attention_available() -> bool:
    """Return True when the SpargeAttn backend passes its requirement check."""
    failure = sparse_attention_requirement_message(_SPARGE_BACKEND)
    return failure is None
def require_sparge_attention() -> None:
    """Raise if the currently selected sparse attention backend is unusable.

    Note that the check runs against the module-wide selected backend, not
    specifically SpargeAttn.

    Raises:
        RuntimeError: With the requirement message when the check fails.
    """
    _fn, _name, failure = _sparse_attention_requirement_status()
    if failure is None:
        return
    raise RuntimeError(failure)
def _mask_topk(mask_id: torch.Tensor | None, q: torch.Tensor) -> torch.Tensor | float:
if isinstance(mask_id, list):
mask_id = mask_id[0] if len(mask_id) > 0 else None
if mask_id is None or not torch.is_tensor(mask_id):
return 0.5
density = mask_id.to(device=q.device, dtype=torch.float32).mean(dim=(0, 2, 3))
return density.clamp(1.0 / max(int(mask_id.shape[-1]), 1), 1.0)
def _int8_mask(mask_id: torch.Tensor | list[torch.Tensor]) -> torch.Tensor | list[torch.Tensor]:
mask = mask_id[0] if isinstance(mask_id, list) else mask_id
if mask.dtype != torch.int8:
mask = mask.to(torch.int8)
if isinstance(mask_id, list):
mask_id[0] = mask
return mask_id if isinstance(mask_id, list) else mask
def _take_mask(mask_id: torch.Tensor | list[torch.Tensor]) -> torch.Tensor:
if isinstance(mask_id, list):
mask = mask_id[0]
mask_id.clear()
return mask
return mask_id
def _callable_parameter_names(fn: Callable) -> frozenset[str]:
    """Return the names of the parameters that ``fn`` declares."""
    import inspect  # local import: only needed once, when the backend is loaded

    try:
        # inspect.signature follows functools.wraps-style wrappers, so a
        # decorated kernel still reports the parameters of the wrapped function.
        return frozenset(inspect.signature(fn).parameters)
    except (TypeError, ValueError):
        # No retrievable signature: fall back to the code object, restricted to
        # declared arguments (co_varnames also lists plain local variables).
        code = fn.__code__
        return frozenset(code.co_varnames[: code.co_argcount + code.co_kwonlyargcount])


def _load_backend() -> tuple[Callable, str]:
    """Resolve the selected backend into a uniform attention callable and its name.

    The returned callable always has the signature
    ``(qkv_list, mask_id, recycle_q=False)``. For the bundled Triton sparse
    backend the callable from the requirement check is returned unchanged. For
    SpargeAttn an adapter is built that calls the kernel with ``mask_id``,
    with ``topk``, or with neither, depending on which parameters the kernel
    declares.

    Returns:
        ``(attention_fn, backend_name)``.

    Raises:
        RuntimeError: When no backend can be loaded; the message is also
            stored in ``_BACKEND_ERROR``.
    """
    global _BACKEND_ERROR
    backend = _selected_sparse_backend()
    sparge_fn, sparge_name_or_error, message = _sparse_attention_requirement_status()
    if message is not None:
        _BACKEND_ERROR = message
        raise RuntimeError(message)
    if sparge_fn is None:
        # Defensive: a successful status is expected to always carry a callable.
        _BACKEND_ERROR = _kernel_load_message(sparge_name_or_error)
        raise RuntimeError(_BACKEND_ERROR)
    if backend == _TRITON_SPARSE_BACKEND or sparge_name_or_error == _BUNDLED_SPARSE_BACKEND_NAME:
        return sparge_fn, sparge_name_or_error or _BUNDLED_SPARSE_BACKEND_NAME

    # Kernels defined in shared.spas_sage_attn_core are called with the q/k/v list
    # itself; any other kernel receives q, k and v as separate positional arguments.
    use_qkv_list = sparge_fn.__module__ == "shared.spas_sage_attn_core"

    # Inspect the kernel once here instead of on every attention call.
    parameter_names = _callable_parameter_names(sparge_fn)
    accepts_mask = "mask_id" in parameter_names
    accepts_topk = "topk" in parameter_names

    def sparge_attention(qkv_list: list[torch.Tensor], mask_id: torch.Tensor | list[torch.Tensor], recycle_q: bool = False) -> torch.Tensor:
        if accepts_mask:
            mask_id = _int8_mask(mask_id)
            if use_qkv_list:
                return sparge_fn(qkv_list, mask_id=mask_id, tensor_layout="HND", output_dtype=qkv_list[0].dtype, recycle_q=recycle_q)
            q, k, v = qkv_list
            # The caller's list is emptied before the kernel runs - presumably so it
            # no longer keeps the tensors alive; confirm against the callers.
            qkv_list.clear()
            return sparge_fn(q, k, v, mask_id=_take_mask(mask_id), tensor_layout="HND", output_dtype=q.dtype)
        if accepts_topk:
            # The kernel takes no mask: pass the mask's density as topk instead.
            if use_qkv_list:
                topk = _mask_topk(mask_id, qkv_list[0])
                if isinstance(mask_id, list):
                    mask_id.clear()
                return sparge_fn(qkv_list, is_causal=False, tensor_layout="HND", output_dtype=qkv_list[0].dtype, topk=topk, recycle_q=recycle_q)
            q, k, v = qkv_list
            qkv_list.clear()
            topk = _mask_topk(mask_id, q)
            if isinstance(mask_id, list):
                mask_id.clear()
            return sparge_fn(q, k, v, is_causal=False, tensor_layout="HND", output_dtype=q.dtype, topk=topk)
        # The kernel accepts neither a mask nor topk: call it without sparsity hints.
        q, k, v = qkv_list
        qkv_list.clear()
        return sparge_fn(q, k, v, is_causal=False, tensor_layout="HND", output_dtype=q.dtype)

    return sparge_attention, sparge_name_or_error or "SpargeAttn"
def get_sparse_backend_name() -> str:
    """Return the active backend's display name, loading the backend on first use.

    Returns:
        The cached backend name, or ``"unknown"`` when that name is empty or None.

    Raises:
        RuntimeError: Propagated from ``_load_backend`` when no backend can be loaded.
    """
    global _SPARSE_ATTENTION, _BACKEND_NAME
    if _SPARSE_ATTENTION is None:
        loaded_fn, loaded_name = _load_backend()
        _SPARSE_ATTENTION, _BACKEND_NAME = loaded_fn, loaded_name
    return _BACKEND_NAME if _BACKEND_NAME else "unknown"
def log_sparse_backend() -> None:
    """Print the active sparse attention backend once, loading it if necessary."""
    global _PRINTED_BACKEND
    # Resolve the name first: this is what triggers the lazy backend load,
    # and it must happen even when the line has already been printed.
    backend_name = get_sparse_backend_name()
    if _PRINTED_BACKEND:
        return
    print(f"[FlashVSR] Sparse attention backend: {backend_name}")
    _PRINTED_BACKEND = True
def sparse_attention(qkv_list: list[torch.Tensor], mask_id: torch.Tensor | list[torch.Tensor], recycle_q: bool = False) -> torch.Tensor:
    """Run sparse attention through the selected backend.

    Args:
        qkv_list: List holding the query, key and value tensors.
        mask_id: Mask tensor, or a list whose first element is the mask.
        recycle_q: Forwarded to the backend callable as a keyword argument.

    Returns:
        The tensor returned by the backend callable.

    Raises:
        RuntimeError: When the backend cannot be loaded.
    """
    # Loads the backend on first use (populating _SPARSE_ATTENTION) and logs it once.
    log_sparse_backend()
    attention_fn = _SPARSE_ATTENTION
    return attention_fn(qkv_list, mask_id, recycle_q=recycle_q)