"""Selection and lazy loading of the sparse-attention backend used by FlashVSR.

The module keeps one cached backend callable in module-level state. The
backend is chosen with :func:`set_sparse_backend` and resolved on first use by
:func:`sparse_attention` / :func:`get_sparse_backend_name`. Two concrete
backends are known here (SpargeAttn and the bundled Triton sparse attention),
plus an "auto" mode that tries SpargeAttn first and then falls back.
"""
from __future__ import annotations

import importlib
import importlib.util
import traceback
from typing import Callable

import torch

from .sparse_backend_config import (
    SPARSE_BACKEND_AUTO as _AUTO_BACKEND,
    SPARSE_BACKEND_LABELS as _BACKEND_LABELS,
    SPARSE_BACKEND_SPARGE as _SPARGE_BACKEND,
    SPARSE_BACKEND_TRITON_SPARSE as _TRITON_SPARSE_BACKEND,
    normalize_sparse_backend,
)

# Cached backend callable and its display name; filled by get_sparse_backend_name()
# and reset by set_sparse_backend() when the selected backend changes.
_SPARSE_ATTENTION: Callable | None = None
_BACKEND_NAME: str | None = None
# Last load failure message recorded by _load_backend(). It is only written in
# this module (never read here) and is not cleared by a later successful load.
_BACKEND_ERROR: str | None = None
# True once log_sparse_backend() has printed the backend name.
_PRINTED_BACKEND = False
# De-duplication sets so the same import error / fallback notice is printed once.
_PRINTED_IMPORT_ERRORS: set[str] = set()
_PRINTED_AUTO_FALLBACKS: set[str] = set()
# Currently selected backend identifier (already normalized).
_SPARSE_BACKEND = _AUTO_BACKEND
_REQUIREMENTS_MESSAGE = "FlashVSR sparse attention requirements are not satisfied."
_INSTALL_MESSAGE = "Install them from docs/INSTALLATION.md and restart WanGP."
# backend id -> tuple of (importable module name, human-readable name) requirements.
_BACKEND_DEPENDENCIES = {
    _SPARGE_BACKEND: (("triton", "Triton"), ("spas_sage_attn", "SpargeAttn")),
    _TRITON_SPARSE_BACKEND: (("triton", "Triton"),),
}
_BUNDLED_SPARSE_BACKEND_NAME = "bundled Triton Sparse Attention"
# CUDA arch string (as built by _current_cuda_arch) ->
# (flag attribute looked up on the Sparge core module,
#  "compile" module name, direct kernel module name).
# Note that sm100/sm120/sm121 reuse the sm89 entries.
_ARCH_KERNELS = {
    "sm80": ("SM80_ENABLED", "spas_sage_attn.sm80_compile", "spas_sage_attn._qattn_sm80"),
    "sm86": ("SM80_ENABLED", "spas_sage_attn.sm80_compile", "spas_sage_attn._qattn_sm80"),
    "sm87": ("SM80_ENABLED", "spas_sage_attn.sm80_compile", "spas_sage_attn._qattn_sm80"),
    "sm89": ("SM89_ENABLED", "spas_sage_attn.sm89_compile", "spas_sage_attn._qattn_sm89"),
    "sm90": ("SM90_ENABLED", "spas_sage_attn.sm90_compile", "spas_sage_attn._qattn_sm90"),
    "sm100": ("SM89_ENABLED", "spas_sage_attn.sm89_compile", "spas_sage_attn._qattn_sm89"),
    "sm120": ("SM89_ENABLED", "spas_sage_attn.sm89_compile", "spas_sage_attn._qattn_sm89"),
    "sm121": ("SM89_ENABLED", "spas_sage_attn.sm89_compile", "spas_sage_attn._qattn_sm89"),
}


def _print_import_error(module_name: str, exc: BaseException) -> None:
    """Print the traceback of a failed import, once per distinct failure.

    Failures are de-duplicated on module name, exception type and message.
    """
    key = f"{module_name}:{type(exc).__name__}:{exc}"
    if key in _PRINTED_IMPORT_ERRORS:
        return
    _PRINTED_IMPORT_ERRORS.add(key)
    print(f"[FlashVSR] Importing {module_name} failed:")
    traceback.print_exception(type(exc), exc, exc.__traceback__)


def set_sparse_backend(backend: object) -> str:
    """Select the sparse-attention backend and return its normalized identifier.

    When the normalized value differs from the current selection, the cached
    callable, name, error and "already logged" flag are reset so the next use
    reloads and re-logs the backend.
    """
    global _SPARSE_BACKEND, _SPARSE_ATTENTION, _BACKEND_NAME, _BACKEND_ERROR, _PRINTED_BACKEND
    backend = normalize_sparse_backend(backend)
    if backend != _SPARSE_BACKEND:
        _SPARSE_ATTENTION = None
        _BACKEND_NAME = None
        _BACKEND_ERROR = None
        _PRINTED_BACKEND = False
    _SPARSE_BACKEND = backend
    return _SPARSE_BACKEND


def _selected_sparse_backend(backend: object | None = None) -> str:
    """Return the module-level selection, or the normalized ``backend`` override if given."""
    return _SPARSE_BACKEND if backend is None else normalize_sparse_backend(backend)


def _print_auto_fallback(message: str) -> None:
    """Print, once per distinct ``message``, why auto mode is skipping the Sparge backend."""
    if message in _PRINTED_AUTO_FALLBACKS:
        return
    _PRINTED_AUTO_FALLBACKS.add(message)
    print(f"[FlashVSR] Auto backend cannot use {_BACKEND_LABELS[_SPARGE_BACKEND]}: {message} Install SpargeAttn for better FlashVSR quality.")
    print(f"[FlashVSR] Auto backend trying {_BACKEND_LABELS[_TRITON_SPARSE_BACKEND]}.")


def _missing_sparse_attention_dependencies(backend: str) -> list[str]:
    """Return display names of the backend's required modules that cannot be located.

    Uses ``importlib.util.find_spec`` so nothing is imported at this stage.
    ``backend`` must be a key of ``_BACKEND_DEPENDENCIES``.
    """
    missing = []
    for module_name, display_name in _BACKEND_DEPENDENCIES[backend]:
        if importlib.util.find_spec(module_name) is None:
            missing.append(display_name)
    return missing


def _missing_dependencies_message(backend: str, missing: list[str]) -> str:
    """Format the user-facing message listing the missing requirements of ``backend``."""
    return f"{_REQUIREMENTS_MESSAGE} Backend: {_BACKEND_LABELS[backend]}. Missing: {', '.join(missing)}. {_INSTALL_MESSAGE}"


def _dependency_import_message(display_name: str, module_name: str, exc: BaseException) -> str:
    """Format the message for a dependency that is present but raised on import."""
    return f"{_REQUIREMENTS_MESSAGE} {display_name} is installed, but importing {module_name} failed. Check the console for the import error, then reinstall from docs/INSTALLATION.md and restart WanGP. Import failed: {type(exc).__name__}: {exc}"


def _kernel_load_message(sparge_error: str | None) -> str:
    """Format the message for SpargeAttn kernels that could not be loaded.

    ``sparge_error`` is embedded verbatim; a falsy value is shown as 'not installed'.
    """
    return f"{_REQUIREMENTS_MESSAGE} SpargeAttn is installed, but its kernels could not be loaded. Reinstall SpargeAttn from docs/INSTALLATION.md and restart WanGP. SpargeAttn import failed: {sparge_error or 'not installed'}"


def _arch_kernel_load_message(arch: str, module_name: str, exc: BaseException | None) -> str:
    """Format the message for an unusable architecture-specific SpargeAttn kernel.

    With ``exc`` the text reports the import failure; without it the text names
    ``module_name`` as the missing kernel module.
    """
    if exc is not None:
        return f"{_REQUIREMENTS_MESSAGE} SpargeAttn is installed, but importing its {arch} kernel failed. Check the console for the import error, then reinstall SpargeAttn from docs/INSTALLATION.md and restart WanGP. Import failed: {type(exc).__name__}: {exc}"
    return f"{_REQUIREMENTS_MESSAGE} SpargeAttn is installed, but its {arch} kernel is unavailable. Reinstall SpargeAttn from docs/INSTALLATION.md and restart WanGP. Missing kernel module: {module_name}"


def _dependency_import_error(backend: str) -> str | None:
    """Import each requirement of ``backend``; return a message for the first failure.

    The failing import's traceback is printed (de-duplicated). Returns ``None``
    when every requirement imports cleanly.
    """
    for module_name, display_name in _BACKEND_DEPENDENCIES[backend]:
        try:
            importlib.import_module(module_name)
        except Exception as exc:
            _print_import_error(module_name, exc)
            return _dependency_import_message(display_name, module_name, exc)
    return None


def _import_sparge_core():
    """Import ``shared.spas_sage_attn_core``.

    Returns ``(module, None)`` on success, or ``(None, "<ExcType>: <message>")``
    after printing the traceback on failure.
    """
    try:
        return importlib.import_module("shared.spas_sage_attn_core"), None
    except Exception as exc:
        _print_import_error("shared.spas_sage_attn_core", exc)
        return None, f"{type(exc).__name__}: {exc}"


def _current_cuda_arch() -> str | None:
    """Return the current CUDA device's capability as ``"sm<major><minor>"``, or ``None`` without CUDA."""
    if not torch.cuda.is_available():
        return None
    major, minor = torch.cuda.get_device_capability(torch.cuda.current_device())
    return f"sm{major}{minor}"


def _arch_kernel_error(module, arch: str | None) -> str | None:
    """Check that the Sparge core module has its kernel for ``arch`` enabled.

    Returns ``None`` (no error) when ``arch`` is ``None`` or not listed in
    ``_ARCH_KERNELS``, or when the arch's flag attribute on ``module`` is truthy.
    Otherwise the kernel modules are imported purely to produce the most
    specific diagnostic, and an error message is always returned.
    """
    if arch is None or arch not in _ARCH_KERNELS:
        return None
    flag_name, compile_module_name, direct_module_name = _ARCH_KERNELS[arch]
    if getattr(module, flag_name, False):
        return None
    # From here on the flag is off: every path below returns an error message.
    try:
        importlib.import_module(compile_module_name)
    except ModuleNotFoundError as exc:
        # A different missing module name means something the compile module
        # needs is absent, not the compile module itself.
        if exc.name != compile_module_name:
            _print_import_error(compile_module_name, exc)
            return _arch_kernel_load_message(arch, compile_module_name, exc)
        # The compile module itself is absent: probe the direct kernel module.
        try:
            importlib.import_module(direct_module_name)
        except Exception as direct_exc:
            _print_import_error(direct_module_name, direct_exc)
            return _arch_kernel_load_message(arch, direct_module_name, direct_exc)
        # NOTE(review): this path is reached when the direct module imported
        # successfully, yet the message names it as missing — confirm intent.
        return _arch_kernel_load_message(arch, direct_module_name, None)
    except Exception as exc:
        _print_import_error(compile_module_name, exc)
        return _arch_kernel_load_message(arch, compile_module_name, exc)
    # The compile module imported, but the flag on the core module is still off.
    return _arch_kernel_load_message(arch, compile_module_name, None)


def _load_triton_sparse_backend() -> tuple[Callable | None, str | None, str | None]:
    """Load the bundled Triton sparse attention implementation.

    Returns ``(callable, name, None)`` on success or ``(None, None, message)``
    when importing ``.sparse_sage.core`` fails.
    """
    try:
        from .sparse_sage.core import sparse_sageattn
    except Exception as exc:
        _print_import_error("postprocessing.flashvsr.sparse_sage.core", exc)
        return None, None, _dependency_import_message(_BUNDLED_SPARSE_BACKEND_NAME, "postprocessing.flashvsr.sparse_sage.core", exc)

    def bundled_sparse_sage(qkv_list: list[torch.Tensor], mask_id: torch.Tensor | list[torch.Tensor], recycle_q: bool = False) -> torch.Tensor:
        # ``recycle_q`` is accepted for signature parity with the Sparge adapter
        # but is not forwarded to the bundled kernel.
        mask_id = _int8_mask(mask_id)
        # _take_mask empties a list-typed ``mask_id`` before the kernel runs.
        return sparse_sageattn(qkv_list, mask_id=_take_mask(mask_id), is_causal=False, tensor_layout="HND")

    return bundled_sparse_sage, _BUNDLED_SPARSE_BACKEND_NAME, None


def _backend_requirement_status(backend: str) -> tuple[Callable | None, str | None, str | None]:
    """Resolve one concrete backend (not auto) to ``(callable, name, error_message)``.

    Exactly one side is populated: ``(fn, name, None)`` on success, or
    ``(None, None, message)`` on failure. Checks run in order: requirements can
    be located, requirements import, then the backend-specific load.
    """
    missing = _missing_sparse_attention_dependencies(backend)
    if missing:
        return None, None, _missing_dependencies_message(backend, missing)
    dependency_import_error = _dependency_import_error(backend)
    if dependency_import_error is not None:
        return None, None, dependency_import_error
    if backend == _TRITON_SPARSE_BACKEND:
        return _load_triton_sparse_backend()
    # Any other backend value is treated as Sparge from here on.
    module, sparge_error = _import_sparge_core()
    if module is None:
        return None, None, _kernel_load_message(sparge_error)
    arch_kernel_error = _arch_kernel_error(module, _current_cuda_arch())
    if arch_kernel_error is not None:
        return None, None, arch_kernel_error
    fn = getattr(module, "block_sparse_attn_cuda", None)
    if not callable(fn):
        return None, None, _kernel_load_message("WanGP SpargeAttn block sparse CUDA function not found")
    return fn, "WanGP SpargeAttn block sparse CUDA", None


def _sparse_attention_requirement_status(backend: object | None = None) -> tuple[Callable | None, str | None, str | None]:
    """Resolve ``backend`` (or the current selection) to ``(callable, name, error_message)``.

    For the auto backend, Sparge is tried first; if it fails, a one-time
    fallback notice is printed and the Triton sparse backend is tried. If both
    fail, the combined message contains both individual errors.
    """
    backend = _selected_sparse_backend(backend)
    if backend != _AUTO_BACKEND:
        return _backend_requirement_status(backend)
    sparge_fn, sparge_name, sparge_message = _backend_requirement_status(_SPARGE_BACKEND)
    if sparge_message is None:
        return sparge_fn, sparge_name, None
    _print_auto_fallback(sparge_message)
    triton_sparse_fn, triton_sparse_name, triton_sparse_message = _backend_requirement_status(_TRITON_SPARSE_BACKEND)
    if triton_sparse_message is None:
        return triton_sparse_fn, triton_sparse_name, None
    return None, None, f"FlashVSR Auto backend could not load any sparse attention backend. Sparge: {sparge_message} {_BACKEND_LABELS[_TRITON_SPARSE_BACKEND]}: {triton_sparse_message}"


def sparse_attention_requirement_message(backend: object | None = None) -> str | None:
    """Return the requirement error for ``backend`` (default: current selection), or ``None`` if usable."""
    _, _, message = _sparse_attention_requirement_status(backend)
    return message


def sparge_attention_available() -> bool:
    """Return ``True`` when the Sparge backend specifically resolves without error."""
    return sparse_attention_requirement_message(_SPARGE_BACKEND) is None


def require_sparge_attention() -> None:
    """Raise ``RuntimeError`` if the currently selected backend cannot be loaded.

    NOTE(review): despite the name, this checks the selected backend (which may
    be auto or Triton sparse), not Sparge specifically.
    """
    _, _, message = _sparse_attention_requirement_status()
    if message is not None:
        raise RuntimeError(message)


def _mask_topk(mask_id: torch.Tensor | None, q: torch.Tensor) -> torch.Tensor | float:
    """Derive a ``topk`` density value from a block mask.

    A list argument is also accepted (its first element is used; an empty list
    counts as no mask). Without a tensor mask the result is the constant 0.5.
    Otherwise the mask is averaged over dims 0, 2 and 3 on ``q``'s device and
    clamped to ``[1 / mask.shape[-1], 1.0]``.
    Indexing dims (0, 2, 3) implies the mask has at least 4 dimensions.
    """
    if isinstance(mask_id, list):
        mask_id = mask_id[0] if len(mask_id) > 0 else None
    if mask_id is None or not torch.is_tensor(mask_id):
        return 0.5
    density = mask_id.to(device=q.device, dtype=torch.float32).mean(dim=(0, 2, 3))
    # The lower bound is one over the last mask dimension, guarded against zero.
    return density.clamp(1.0 / max(int(mask_id.shape[-1]), 1), 1.0)


def _int8_mask(mask_id: torch.Tensor | list[torch.Tensor]) -> torch.Tensor | list[torch.Tensor]:
    """Ensure the mask is ``torch.int8``, preserving the tensor-or-list container.

    For a list, element 0 is converted in place and the same list object is
    returned; for a tensor, the (possibly converted) tensor is returned.
    An empty list raises ``IndexError``.
    """
    mask = mask_id[0] if isinstance(mask_id, list) else mask_id
    if mask.dtype != torch.int8:
        mask = mask.to(torch.int8)
        if isinstance(mask_id, list):
            mask_id[0] = mask
    return mask_id if isinstance(mask_id, list) else mask


def _take_mask(mask_id: torch.Tensor | list[torch.Tensor]) -> torch.Tensor:
    """Return the mask tensor; a list argument is emptied after taking element 0.

    An empty list raises ``IndexError``.
    """
    if isinstance(mask_id, list):
        mask = mask_id[0]
        # Clearing drops the caller's reference held through the list —
        # presumably so the tensor can be released earlier; confirm with callers.
        mask_id.clear()
        return mask
    return mask_id


def _load_backend() -> tuple[Callable, str]:
    """Resolve the selected backend to ``(callable, display_name)``.

    The returned callable always has the signature
    ``(qkv_list, mask_id, recycle_q=False)``. The bundled Triton backend is
    returned as is; a Sparge function is wrapped in an adapter that picks the
    call convention from the function's module and code-object variable names.

    Raises:
        RuntimeError: if no backend can be loaded; the message is also stored
            in ``_BACKEND_ERROR``.
    """
    global _BACKEND_ERROR
    backend = _selected_sparse_backend()
    sparge_fn, sparge_name_or_error, message = _sparse_attention_requirement_status()
    if message is not None:
        _BACKEND_ERROR = message
        raise RuntimeError(message)
    if sparge_fn is None:
        _BACKEND_ERROR = _kernel_load_message(sparge_name_or_error)
        raise RuntimeError(_BACKEND_ERROR)
    # The bundled Triton callable already has the adapter signature. In auto
    # mode the fallback is recognised by its name.
    if backend == _TRITON_SPARSE_BACKEND or sparge_name_or_error == _BUNDLED_SPARSE_BACKEND_NAME:
        return sparge_fn, sparge_name_or_error or _BUNDLED_SPARSE_BACKEND_NAME
    # Functions from shared.spas_sage_attn_core are called with the q/k/v list;
    # anything else is called with q, k, v unpacked.
    use_qkv_list = sparge_fn.__module__ == "shared.spas_sage_attn_core"

    def sparge_attention(qkv_list: list[torch.Tensor], mask_id: torch.Tensor | list[torch.Tensor], recycle_q: bool = False) -> torch.Tensor:
        # NOTE(review): co_varnames lists local variables as well as parameters,
        # and requires ``sparge_fn`` to be a plain Python function.
        if "mask_id" in sparge_fn.__code__.co_varnames:
            mask_id = _int8_mask(mask_id)
            if use_qkv_list:
                # Lists are passed through untouched; any clearing is left to the callee.
                return sparge_fn(qkv_list, mask_id=mask_id, tensor_layout="HND", output_dtype=qkv_list[0].dtype, recycle_q=recycle_q)
            q, k, v = qkv_list
            # Empty the caller's list once q/k/v are held locally.
            qkv_list.clear()
            return sparge_fn(q, k, v, mask_id=_take_mask(mask_id), tensor_layout="HND", output_dtype=q.dtype)
        if "topk" in sparge_fn.__code__.co_varnames:
            # No mask parameter: pass a density derived from the mask instead.
            if use_qkv_list:
                # qkv_list[0] must be read before the callee gets the list.
                topk = _mask_topk(mask_id, qkv_list[0])
                if isinstance(mask_id, list):
                    mask_id.clear()
                return sparge_fn(qkv_list, is_causal=False, tensor_layout="HND", output_dtype=qkv_list[0].dtype, topk=topk, recycle_q=recycle_q)
            q, k, v = qkv_list
            qkv_list.clear()
            topk = _mask_topk(mask_id, q)
            if isinstance(mask_id, list):
                mask_id.clear()
            return sparge_fn(q, k, v, is_causal=False, tensor_layout="HND", output_dtype=q.dtype, topk=topk)
        # Neither name present: call without any mask information.
        # ``mask_id`` is left untouched on this path.
        q, k, v = qkv_list
        qkv_list.clear()
        return sparge_fn(q, k, v, is_causal=False, tensor_layout="HND", output_dtype=q.dtype)

    return sparge_attention, sparge_name_or_error or "SpargeAttn"


def get_sparse_backend_name() -> str:
    """Return the active backend's display name, loading the backend on first use.

    Raises:
        RuntimeError: propagated from ``_load_backend`` when loading fails.
    """
    global _SPARSE_ATTENTION, _BACKEND_NAME
    if _SPARSE_ATTENTION is None:
        _SPARSE_ATTENTION, _BACKEND_NAME = _load_backend()
    return _BACKEND_NAME or "unknown"


def log_sparse_backend() -> None:
    """Print the active backend name once per selection, loading the backend if needed."""
    global _PRINTED_BACKEND
    backend_name = get_sparse_backend_name()
    if not _PRINTED_BACKEND:
        print(f"[FlashVSR] Sparse attention backend: {backend_name}")
        _PRINTED_BACKEND = True


def sparse_attention(qkv_list: list[torch.Tensor], mask_id: torch.Tensor | list[torch.Tensor], recycle_q: bool = False) -> torch.Tensor:
    """Run sparse attention through the selected backend.

    The backend is loaded and logged on first call. Depending on the backend,
    ``qkv_list`` and a list-typed ``mask_id`` may be emptied in place.

    Raises:
        RuntimeError: if the backend cannot be loaded.
    """
    # log_sparse_backend() guarantees _SPARSE_ATTENTION is set, or raises.
    log_sparse_backend()
    return _SPARSE_ATTENTION(qkv_list, mask_id, recycle_q=recycle_q)