File size: 13,589 Bytes
7344bef
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
from __future__ import annotations

import importlib
import importlib.util
import traceback
from typing import Callable

import torch

from .sparse_backend_config import (
    SPARSE_BACKEND_AUTO as _AUTO_BACKEND,
    SPARSE_BACKEND_LABELS as _BACKEND_LABELS,
    SPARSE_BACKEND_SPARGE as _SPARGE_BACKEND,
    SPARSE_BACKEND_TRITON_SPARSE as _TRITON_SPARSE_BACKEND,
    normalize_sparse_backend,
)

# --- Module-level backend cache / state -------------------------------------
# set_sparse_backend() resets the first four of these whenever the selected
# backend changes, so the next call reloads and re-announces the backend.
# Cached attention callable; filled lazily by get_sparse_backend_name().
_SPARSE_ATTENTION: Callable | None = None
# Display name of the cached backend (paired with _SPARSE_ATTENTION).
_BACKEND_NAME: str | None = None
# Message of the last failed _load_backend(); only cleared by set_sparse_backend().
_BACKEND_ERROR: str | None = None
# Whether log_sparse_backend() has already printed the backend name.
_PRINTED_BACKEND = False
# Dedup keys ("module:ExcType:message") of import tracebacks already printed.
_PRINTED_IMPORT_ERRORS: set[str] = set()
# Auto-mode fallback messages already printed by _print_auto_fallback().
_PRINTED_AUTO_FALLBACKS: set[str] = set()
# Currently selected backend; starts in Auto mode.
_SPARSE_BACKEND = _AUTO_BACKEND
# Shared prefix / suffix of the user-facing requirement error messages.
_REQUIREMENTS_MESSAGE = "FlashVSR sparse attention requirements are not satisfied."
_INSTALL_MESSAGE = "Install them from docs/INSTALLATION.md and restart WanGP."
# Per concrete backend: (importable module name, display name) pairs that must be
# findable and importable before the backend is considered usable.
_BACKEND_DEPENDENCIES = {
    _SPARGE_BACKEND: (("triton", "Triton"), ("spas_sage_attn", "SpargeAttn")),
    _TRITON_SPARSE_BACKEND: (("triton", "Triton"),),
}
# Display name reported when the bundled Triton sparse kernel is in use.
_BUNDLED_SPARSE_BACKEND_NAME = "bundled Triton Sparse Attention"
# Compute capability -> (flag on the SpargeAttn core module that marks the kernel
# family enabled, compile module name, direct kernel module name). Used by
# _arch_kernel_error() to explain a missing kernel. sm86/sm87 reuse the sm80
# entries; sm100/sm120/sm121 reuse the sm89 entries.
_ARCH_KERNELS = {
    "sm80": ("SM80_ENABLED", "spas_sage_attn.sm80_compile", "spas_sage_attn._qattn_sm80"),
    "sm86": ("SM80_ENABLED", "spas_sage_attn.sm80_compile", "spas_sage_attn._qattn_sm80"),
    "sm87": ("SM80_ENABLED", "spas_sage_attn.sm80_compile", "spas_sage_attn._qattn_sm80"),
    "sm89": ("SM89_ENABLED", "spas_sage_attn.sm89_compile", "spas_sage_attn._qattn_sm89"),
    "sm90": ("SM90_ENABLED", "spas_sage_attn.sm90_compile", "spas_sage_attn._qattn_sm90"),
    "sm100": ("SM89_ENABLED", "spas_sage_attn.sm89_compile", "spas_sage_attn._qattn_sm89"),
    "sm120": ("SM89_ENABLED", "spas_sage_attn.sm89_compile", "spas_sage_attn._qattn_sm89"),
    "sm121": ("SM89_ENABLED", "spas_sage_attn.sm89_compile", "spas_sage_attn._qattn_sm89"),
}


def _print_import_error(module_name: str, exc: BaseException) -> None:
    """Print the traceback of a failed import, at most once per distinct failure.

    Failures are deduplicated on (module name, exception type, exception text),
    so repeatedly probing the same broken dependency prints it only once.
    """
    dedup_key = f"{module_name}:{type(exc).__name__}:{exc}"
    if dedup_key not in _PRINTED_IMPORT_ERRORS:
        _PRINTED_IMPORT_ERRORS.add(dedup_key)
        print(f"[FlashVSR] Importing {module_name} failed:")
        traceback.print_exception(type(exc), exc, exc.__traceback__)


def set_sparse_backend(backend: object) -> str:
    """Select the sparse attention backend and return its normalized name.

    Choosing a backend different from the current one discards the cached
    attention callable, its display name, the recorded load error and the
    "already printed" flag, so the next use reloads and re-announces it.
    Re-selecting the current backend keeps all cached state.
    """
    global _SPARSE_BACKEND, _SPARSE_ATTENTION, _BACKEND_NAME, _BACKEND_ERROR, _PRINTED_BACKEND
    normalized = normalize_sparse_backend(backend)
    if normalized != _SPARSE_BACKEND:
        # Backend changed: invalidate everything cached for the previous one.
        _SPARSE_ATTENTION = _BACKEND_NAME = _BACKEND_ERROR = None
        _PRINTED_BACKEND = False
    _SPARSE_BACKEND = normalized
    return normalized


def _selected_sparse_backend(backend: object | None = None) -> str:
    """Return *backend* normalized, or the module-wide selection when *backend* is None."""
    if backend is not None:
        return normalize_sparse_backend(backend)
    return _SPARSE_BACKEND


def _print_auto_fallback(message: str) -> None:
    """Announce, once per distinct *message*, that Auto mode is falling back from Sparge to Triton Sparse."""
    if message not in _PRINTED_AUTO_FALLBACKS:
        _PRINTED_AUTO_FALLBACKS.add(message)
        sparge_label = _BACKEND_LABELS[_SPARGE_BACKEND]
        print(f"[FlashVSR] Auto backend cannot use {sparge_label}: {message} Install SpargeAttn for better FlashVSR quality.")
        triton_label = _BACKEND_LABELS[_TRITON_SPARSE_BACKEND]
        print(f"[FlashVSR] Auto backend trying {triton_label}.")


def _missing_sparse_attention_dependencies(backend: str) -> list[str]:
    """Return display names of *backend*'s dependencies for which no module spec can be found.

    Only checks that each module is findable (``importlib.util.find_spec``);
    it does not import anything.
    """
    return [
        display_name
        for module_name, display_name in _BACKEND_DEPENDENCIES[backend]
        if importlib.util.find_spec(module_name) is None
    ]


def _missing_dependencies_message(backend: str, missing: list[str]) -> str:
    """Build the user-facing error for *backend* when the *missing* dependencies are not installed."""
    backend_label = _BACKEND_LABELS[backend]
    missing_list = ", ".join(missing)
    return f"{_REQUIREMENTS_MESSAGE} Backend: {backend_label}. Missing: {missing_list}. {_INSTALL_MESSAGE}"


def _dependency_import_message(display_name: str, module_name: str, exc: BaseException) -> str:
    """Build the error for a dependency that is installed but raised *exc* when *module_name* was imported."""
    failure = f"{type(exc).__name__}: {exc}"
    return (
        f"{_REQUIREMENTS_MESSAGE} {display_name} is installed, but importing {module_name} failed. "
        "Check the console for the import error, then reinstall from docs/INSTALLATION.md and restart WanGP. "
        f"Import failed: {failure}"
    )


def _kernel_load_message(sparge_error: str | None) -> str:
    """Build the error for a SpargeAttn install whose kernels could not be loaded.

    An empty or None *sparge_error* is reported as "not installed".
    """
    reason = sparge_error if sparge_error else "not installed"
    return (
        f"{_REQUIREMENTS_MESSAGE} SpargeAttn is installed, but its kernels could not be loaded. "
        "Reinstall SpargeAttn from docs/INSTALLATION.md and restart WanGP. "
        f"SpargeAttn import failed: {reason}"
    )


def _arch_kernel_load_message(arch: str, module_name: str, exc: BaseException | None) -> str:
    """Build the error for a SpargeAttn kernel that is unusable on GPU architecture *arch*.

    With *exc* the message reports the import failure. Without it, the message
    names *module_name* as the missing kernel module.
    """
    if exc is None:
        return (
            f"{_REQUIREMENTS_MESSAGE} SpargeAttn is installed, but its {arch} kernel is unavailable. "
            "Reinstall SpargeAttn from docs/INSTALLATION.md and restart WanGP. "
            f"Missing kernel module: {module_name}"
        )
    return (
        f"{_REQUIREMENTS_MESSAGE} SpargeAttn is installed, but importing its {arch} kernel failed. "
        "Check the console for the import error, then reinstall SpargeAttn from docs/INSTALLATION.md and restart WanGP. "
        f"Import failed: {type(exc).__name__}: {exc}"
    )


def _dependency_import_error(backend: str) -> str | None:
    """Import *backend*'s dependencies in order and describe the first one that fails.

    Returns None when every dependency imports cleanly. The failing import's
    traceback is printed once.
    """
    dependencies = _BACKEND_DEPENDENCIES[backend]
    for dep_module, dep_label in dependencies:
        try:
            importlib.import_module(dep_module)
        except Exception as import_exc:
            _print_import_error(dep_module, import_exc)
            return _dependency_import_message(dep_label, dep_module, import_exc)
    return None


def _import_sparge_core():
    """Import WanGP's SpargeAttn core module.

    Returns ``(module, None)`` on success. On failure it prints the traceback
    once and returns ``(None, "<ExceptionType>: <message>")``.
    """
    core_name = "shared.spas_sage_attn_core"
    try:
        core = importlib.import_module(core_name)
    except Exception as exc:
        _print_import_error(core_name, exc)
        return None, f"{type(exc).__name__}: {exc}"
    return core, None


def _current_cuda_arch() -> str | None:
    if not torch.cuda.is_available():
        return None
    major, minor = torch.cuda.get_device_capability(torch.cuda.current_device())
    return f"sm{major}{minor}"


def _arch_kernel_error(module, arch: str | None) -> str | None:
    """Explain why the SpargeAttn kernel for GPU architecture *arch* is unusable.

    Returns None when no check applies: *arch* is None or not listed in
    ``_ARCH_KERNELS``, or *module* has the kernel family's ``*_ENABLED`` flag set.
    Otherwise it ALWAYS returns an error message. The imports below only choose
    the most specific explanation and print the import traceback once.

    Args:
        module: the imported ``shared.spas_sage_attn_core`` module whose
            ``*_ENABLED`` attribute is inspected.
        arch: compute-capability string such as ``"sm89"``, or None without CUDA.

    Returns:
        None if the kernel is considered available (or unchecked), else a
        user-facing error message.
    """
    if arch is None or arch not in _ARCH_KERNELS:
        return None
    flag_name, compile_module_name, direct_module_name = _ARCH_KERNELS[arch]
    if getattr(module, flag_name, False):
        return None
    try:
        importlib.import_module(compile_module_name)
    except ModuleNotFoundError as exc:
        if exc.name != compile_module_name:
            # A *different* module was not found (e.g. a parent package or
            # something the compile module imports), so report that failure.
            _print_import_error(compile_module_name, exc)
            return _arch_kernel_load_message(arch, compile_module_name, exc)
        # The compile module itself does not exist: probe the direct kernel module.
        try:
            importlib.import_module(direct_module_name)
        except Exception as direct_exc:
            _print_import_error(direct_module_name, direct_exc)
            return _arch_kernel_load_message(arch, direct_module_name, direct_exc)
        # NOTE(review): the direct module imported fine but the flag is still
        # off; the message nonetheless names it as the missing kernel module —
        # confirm this wording is intended.
        return _arch_kernel_load_message(arch, direct_module_name, None)
    except Exception as exc:
        # The compile module exists but raised something other than ModuleNotFoundError.
        _print_import_error(compile_module_name, exc)
        return _arch_kernel_load_message(arch, compile_module_name, exc)
    # The compile module imported without error, yet the core module's flag is unset.
    return _arch_kernel_load_message(arch, compile_module_name, None)


def _load_triton_sparse_backend() -> tuple[Callable | None, str | None, str | None]:
    """Load the bundled Triton sparse attention kernel.

    Returns ``(attention_fn, backend_name, None)`` on success, or
    ``(None, None, error_message)`` when the bundled module cannot be imported.
    """
    core_path = "postprocessing.flashvsr.sparse_sage.core"
    try:
        from .sparse_sage.core import sparse_sageattn
    except Exception as exc:
        _print_import_error(core_path, exc)
        return None, None, _dependency_import_message(_BUNDLED_SPARSE_BACKEND_NAME, core_path, exc)

    def bundled_sparse_sage(qkv_list: list[torch.Tensor], mask_id: torch.Tensor | list[torch.Tensor], recycle_q: bool = False) -> torch.Tensor:
        # recycle_q is accepted so the signature matches what sparse_attention()
        # passes, but this kernel does not use it.
        int8_mask = _take_mask(_int8_mask(mask_id))
        return sparse_sageattn(qkv_list, mask_id=int8_mask, is_causal=False, tensor_layout="HND")

    return bundled_sparse_sage, _BUNDLED_SPARSE_BACKEND_NAME, None


def _backend_requirement_status(backend: str) -> tuple[Callable | None, str | None, str | None]:
    """Check and load one concrete backend (a key of ``_BACKEND_DEPENDENCIES``).

    Checks run in this order: dependency specs exist, dependencies import,
    then (Sparge only) the core module, the arch kernel and the CUDA entry point.

    Returns:
        ``(fn, display_name, None)`` when usable, else ``(None, None, message)``.
    """
    missing = _missing_sparse_attention_dependencies(backend)
    if missing:
        return None, None, _missing_dependencies_message(backend, missing)
    import_failure = _dependency_import_error(backend)
    if import_failure is not None:
        return None, None, import_failure
    if backend == _TRITON_SPARSE_BACKEND:
        return _load_triton_sparse_backend()

    core, core_error = _import_sparge_core()
    if core is None:
        return None, None, _kernel_load_message(core_error)
    kernel_failure = _arch_kernel_error(core, _current_cuda_arch())
    if kernel_failure is not None:
        return None, None, kernel_failure
    block_sparse = getattr(core, "block_sparse_attn_cuda", None)
    if callable(block_sparse):
        return block_sparse, "WanGP SpargeAttn block sparse CUDA", None
    return None, None, _kernel_load_message("WanGP SpargeAttn block sparse CUDA function not found")


def _sparse_attention_requirement_status(backend: object | None = None) -> tuple[Callable | None, str | None, str | None]:
    """Resolve *backend* (default: the module-wide selection) to a usable attention function.

    A concrete backend is checked directly. Auto tries Sparge first. If Sparge
    is unusable, Auto announces the fallback and tries Triton Sparse. When both
    fail, the message combines both reasons.

    Returns:
        ``(fn, display_name, None)`` on success, else ``(None, None, message)``.
    """
    resolved = _selected_sparse_backend(backend)
    if resolved != _AUTO_BACKEND:
        return _backend_requirement_status(resolved)

    fn, name, sparge_problem = _backend_requirement_status(_SPARGE_BACKEND)
    if sparge_problem is None:
        return fn, name, None
    _print_auto_fallback(sparge_problem)

    fn, name, triton_problem = _backend_requirement_status(_TRITON_SPARSE_BACKEND)
    if triton_problem is None:
        return fn, name, None
    triton_label = _BACKEND_LABELS[_TRITON_SPARSE_BACKEND]
    return None, None, f"FlashVSR Auto backend could not load any sparse attention backend. Sparge: {sparge_problem} {triton_label}: {triton_problem}"


def sparse_attention_requirement_message(backend: object | None = None) -> str | None:
    """Return why *backend* (default: the current selection) cannot be used, or None if it can."""
    return _sparse_attention_requirement_status(backend)[2]


def sparge_attention_available() -> bool:
    """Report whether the SpargeAttn backend specifically (not Auto or Triton Sparse) is usable."""
    problem = sparse_attention_requirement_message(_SPARGE_BACKEND)
    return problem is None


def require_sparge_attention() -> None:
    """Raise RuntimeError with the diagnostic message if the selected backend is unusable.

    Despite its name, this checks the currently selected backend (Auto, Sparge
    or Triton Sparse), not SpargeAttn alone.
    """
    *_, problem = _sparse_attention_requirement_status()
    if problem is None:
        return
    raise RuntimeError(problem)


def _mask_topk(mask_id: torch.Tensor | None, q: torch.Tensor) -> torch.Tensor | float:
    if isinstance(mask_id, list):
        mask_id = mask_id[0] if len(mask_id) > 0 else None
    if mask_id is None or not torch.is_tensor(mask_id):
        return 0.5
    density = mask_id.to(device=q.device, dtype=torch.float32).mean(dim=(0, 2, 3))
    return density.clamp(1.0 / max(int(mask_id.shape[-1]), 1), 1.0)


def _int8_mask(mask_id: torch.Tensor | list[torch.Tensor]) -> torch.Tensor | list[torch.Tensor]:
    mask = mask_id[0] if isinstance(mask_id, list) else mask_id
    if mask.dtype != torch.int8:
        mask = mask.to(torch.int8)
        if isinstance(mask_id, list):
            mask_id[0] = mask
    return mask_id if isinstance(mask_id, list) else mask


def _take_mask(mask_id: torch.Tensor | list[torch.Tensor]) -> torch.Tensor:
    if isinstance(mask_id, list):
        mask = mask_id[0]
        mask_id.clear()
        return mask
    return mask_id


def _load_backend() -> tuple[Callable, str]:
    """Load the selected sparse attention backend behind one common call signature.

    Returns:
        ``(attention_fn, display_name)``, where ``attention_fn`` is called as
        ``attention_fn(qkv_list, mask_id, recycle_q=False)``. Triton Sparse
        callables already have that signature and are returned unchanged.
        Sparge callables are wrapped so the call matches whichever keyword
        (``mask_id`` or ``topk``) the kernel accepts.

    Raises:
        RuntimeError: if the backend requirements are not satisfied. The message
            is also stored in ``_BACKEND_ERROR``.
    """
    global _BACKEND_ERROR
    import inspect

    backend = _selected_sparse_backend()
    sparge_fn, sparge_name_or_error, message = _sparse_attention_requirement_status()
    if message is not None:
        _BACKEND_ERROR = message
        raise RuntimeError(message)
    if sparge_fn is None:
        # Defensive: the status helpers always pair a None function with a message.
        _BACKEND_ERROR = _kernel_load_message(sparge_name_or_error)
        raise RuntimeError(_BACKEND_ERROR)
    if backend == _TRITON_SPARSE_BACKEND or sparge_name_or_error == _BUNDLED_SPARSE_BACKEND_NAME:
        # The bundled Triton kernel is already wrapped with the common signature.
        return sparge_fn, sparge_name_or_error or _BUNDLED_SPARSE_BACKEND_NAME

    # Inspect the kernel's parameters once, at load time, rather than on every call.
    # inspect.signature follows functools.wraps (__wrapped__). Unlike
    # __code__.co_varnames, it lists only parameters, not locals, so decorated
    # kernels and kernels with a local named "mask_id"/"topk" are classified correctly.
    try:
        param_names = frozenset(inspect.signature(sparge_fn).parameters)
    except (TypeError, ValueError):
        # No introspectable signature (e.g. some builtins): use the code object if any.
        code = getattr(sparge_fn, "__code__", None)
        param_names = frozenset(code.co_varnames[: code.co_argcount + code.co_kwonlyargcount]) if code is not None else frozenset()
    accepts_mask_id = "mask_id" in param_names
    accepts_topk = "topk" in param_names
    # WanGP's own core takes the qkv list itself (plus recycle_q); other kernels take q, k, v.
    use_qkv_list = getattr(sparge_fn, "__module__", None) == "shared.spas_sage_attn_core"

    def sparge_attention(qkv_list: list[torch.Tensor], mask_id: torch.Tensor | list[torch.Tensor], recycle_q: bool = False) -> torch.Tensor:
        if accepts_mask_id:
            mask_id = _int8_mask(mask_id)
            if use_qkv_list:
                return sparge_fn(qkv_list, mask_id=mask_id, tensor_layout="HND", output_dtype=qkv_list[0].dtype, recycle_q=recycle_q)
            q, k, v = qkv_list
            qkv_list.clear()
            return sparge_fn(q, k, v, mask_id=_take_mask(mask_id), tensor_layout="HND", output_dtype=q.dtype)
        if accepts_topk:
            if use_qkv_list:
                topk = _mask_topk(mask_id, qkv_list[0])
                if isinstance(mask_id, list):
                    mask_id.clear()
                return sparge_fn(qkv_list, is_causal=False, tensor_layout="HND", output_dtype=qkv_list[0].dtype, topk=topk, recycle_q=recycle_q)
            q, k, v = qkv_list
            qkv_list.clear()
            topk = _mask_topk(mask_id, q)
            if isinstance(mask_id, list):
                mask_id.clear()
            return sparge_fn(q, k, v, is_causal=False, tensor_layout="HND", output_dtype=q.dtype, topk=topk)
        q, k, v = qkv_list
        qkv_list.clear()
        # The kernel takes no mask, but release a list-wrapped mask like every
        # other branch does, so the caller's list never outlives the call's needs.
        if isinstance(mask_id, list):
            mask_id.clear()
        return sparge_fn(q, k, v, is_causal=False, tensor_layout="HND", output_dtype=q.dtype)

    return sparge_attention, sparge_name_or_error or "SpargeAttn"


def get_sparse_backend_name() -> str:
    """Return the active backend's display name, loading the backend on first use.

    Raises:
        RuntimeError: propagated from loading when the selected backend is unusable.
    """
    global _SPARSE_ATTENTION, _BACKEND_NAME
    if _SPARSE_ATTENTION is None:
        loaded_fn, loaded_name = _load_backend()
        _SPARSE_ATTENTION, _BACKEND_NAME = loaded_fn, loaded_name
    name = _BACKEND_NAME
    return name if name else "unknown"


def log_sparse_backend() -> None:
    """Ensure the backend is loaded, and print its name the first time only."""
    global _PRINTED_BACKEND
    name = get_sparse_backend_name()
    if _PRINTED_BACKEND:
        return
    print(f"[FlashVSR] Sparse attention backend: {name}")
    _PRINTED_BACKEND = True


def sparse_attention(qkv_list: list[torch.Tensor], mask_id: torch.Tensor | list[torch.Tensor], recycle_q: bool = False) -> torch.Tensor:
    """Run sparse attention with the selected backend.

    The backend is loaded (and announced) on the first call. The backend may
    empty *qkv_list* and a list-wrapped *mask_id*.
    """
    log_sparse_backend()
    attention_fn = _SPARSE_ATTENTION
    return attention_fn(qkv_list, mask_id, recycle_q=recycle_q)