File size: 10,433 Bytes
7344bef
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
from __future__ import annotations

import math
import os
from typing import Any, Callable

from postprocessing.flashvsr.sparse_backend_config import (
    SPARSE_BACKEND_AUTO,
    SPARSE_BACKEND_SPARGE,
    SPARSE_BACKEND_TRITON_SPARSE,
    normalize_sparse_backend,
)


class FlashVSRBridge:
    """Glue between the server configuration and the FlashVSR upscaling runtime.

    The bridge reads the ``flashvsr_*`` keys of a configuration dict, maps
    spatial-upsampling values such as ``"flashvsr2"`` or ``"flashvsr2pass1.5"``
    to scale factors, locates/downloads the model files and forwards upscaling
    requests to ``postprocessing.flashvsr.runtime``.
    """

    # Values accepted for the "flashvsr_mode" config key.
    MODE_OFF = 0
    MODE_TINY = 1
    MODE_FULL = 2
    MODE_TINY_LONG = 3
    # Values accepted for the "flashvsr_persistence" config key.
    # PERSIST_RAM makes upscale() pass persistent_models=True to the runtime.
    PERSIST_UNLOAD = 1
    PERSIST_RAM = 2
    # Aliases of the sparse attention backend identifiers from sparse_backend_config.
    BACKEND_AUTO = SPARSE_BACKEND_AUTO
    BACKEND_TRITON_SPARSE = SPARSE_BACKEND_TRITON_SPARSE
    BACKEND_SPARGE = SPARSE_BACKEND_SPARGE
    # Fallback and upper clamp used by normalize_topk_ratio().
    TOPK_RATIO_DEFAULT = 0.0
    TOPK_RATIO_MAX = 4.0
    # Prefixes of the spatial-upsampling values; the scale factor is appended.
    UPSAMPLING_VALUE_PREFIX = "flashvsr"
    UPSAMPLING_TWO_PASS_VALUE_PREFIX = "flashvsr2pass"
    # The only scale factors scale_for_upsampling() accepts.
    UPSAMPLING_RATIOS = (1.0, 1.5, 2.0, 2.5, 3.0, 3.5, 4.0)

    # Model files; all but the VAE are looked up under the "FlashVSR" folder.
    TRANSFORMER_FILENAME = "FlashVSR_v1.1_transformer_bf16.safetensors"
    LQ_PROJ_FILENAME = "FlashVSR_v1.1_lq_proj_bf16.safetensors"
    TCDECODER_FILENAME = "FlashVSR_v1.1_tcdecoder_bf16.safetensors"
    POSI_PROMPT_FILENAME = "FlashVSR_v1.1_posi_prompt_bf16.safetensors"
    VAE_FILENAME = "Wan2.1_VAE.safetensors"

    # Maps every enabled mode to the variant name handed to the runtime.
    _VARIANTS = {
        MODE_TINY: "tiny",
        MODE_FULL: "full",
        MODE_TINY_LONG: "tiny-long",
    }

    def __init__(self, server_config: dict[str, Any], files_locator):
        """Store the configuration dict and the file locator used by the bridge.

        Args:
            server_config: Dict holding the ``flashvsr_*`` settings; it is
                normalized in place by :meth:`normalize_config`.
            files_locator: Object exposing ``locate_file(path, ...)``.
        """
        self.server_config = server_config
        self.files_locator = files_locator

    @classmethod
    def normalize_topk_ratio(cls, value: Any) -> float:
        """Coerce *value* to a float clamped to ``[0.0, TOPK_RATIO_MAX]``.

        Values that cannot be converted (wrong type, unparsable text, integers
        too large for a float) and NaN are replaced by ``TOPK_RATIO_DEFAULT``
        before clamping.
        """
        try:
            value = float(value)
        except (TypeError, ValueError, OverflowError):
            value = cls.TOPK_RATIO_DEFAULT
        if math.isnan(value):
            # Comparisons with NaN are always false, so min()/max() would
            # otherwise silently turn NaN into TOPK_RATIO_MAX.
            value = cls.TOPK_RATIO_DEFAULT
        return max(0.0, min(cls.TOPK_RATIO_MAX, value))

    @classmethod
    def normalize_backend(cls, value: Any) -> str:
        """Return *value* normalized by ``normalize_sparse_backend``."""
        return normalize_sparse_backend(value)

    @staticmethod
    def _coerce_int(value: Any, fallback: int) -> int:
        """Return ``int(value)``, or *fallback* when the conversion fails."""
        try:
            return int(value)
        except (TypeError, ValueError, OverflowError):
            # OverflowError covers infinite floats, e.g. int(float("inf")).
            return fallback

    def normalize_config(self, config: dict[str, Any] | None = None) -> tuple[int, int]:
        """Sanitize the ``flashvsr_*`` entries of *config* in place.

        Args:
            config: Dict to normalize; ``self.server_config`` when ``None``.

        Returns:
            The normalized ``(mode, persistence)`` pair. Unknown or
            unconvertible values become ``MODE_OFF`` / ``PERSIST_UNLOAD``.
        """
        config = self.server_config if config is None else config
        mode = self._coerce_int(config.get("flashvsr_mode", self.MODE_OFF), self.MODE_OFF)
        persistence = self._coerce_int(config.get("flashvsr_persistence", self.PERSIST_UNLOAD), self.PERSIST_UNLOAD)
        if mode not in self._VARIANTS:
            mode = self.MODE_OFF
        if persistence not in (self.PERSIST_UNLOAD, self.PERSIST_RAM):
            persistence = self.PERSIST_UNLOAD
        # Write the sanitized values back so later readers of the dict see them.
        config["flashvsr_mode"] = mode
        config["flashvsr_persistence"] = persistence
        config["flashvsr_backend"] = self.normalize_backend(config.get("flashvsr_backend", self.BACKEND_AUTO))
        config["flashvsr_topk_ratio"] = self.normalize_topk_ratio(config.get("flashvsr_topk_ratio", self.TOPK_RATIO_DEFAULT))
        return mode, persistence

    def settings(self, config: dict[str, Any] | None = None) -> tuple[bool, str | None, int]:
        """Return ``(enabled, variant, persistence)`` for *config*.

        ``variant`` is ``None`` when the mode is ``MODE_OFF``. The config is
        normalized in place as a side effect (see :meth:`normalize_config`).
        """
        mode, persistence = self.normalize_config(config)
        return mode != self.MODE_OFF, self._VARIANTS.get(mode), persistence

    def topk_ratio(self) -> float:
        """Return the normalized ``flashvsr_topk_ratio`` of the server config."""
        return self.normalize_topk_ratio(self.server_config.get("flashvsr_topk_ratio", self.TOPK_RATIO_DEFAULT))

    def backend(self) -> str:
        """Return the normalized ``flashvsr_backend`` of the server config."""
        return self.normalize_backend(self.server_config.get("flashvsr_backend", self.BACKEND_AUTO))

    def enabled(self) -> bool:
        """Return ``True`` when the configured mode is not ``MODE_OFF``."""
        return self.settings()[0]

    @classmethod
    def format_ratio(cls, scale: float) -> str:
        """Format *scale* compactly: ``2.0`` -> ``"2"``, ``1.5`` -> ``"1.5"``."""
        scale = float(scale)
        return str(int(scale)) if scale.is_integer() else f"{scale:g}"

    @classmethod
    def format_ratio_label(cls, scale: float) -> str:
        """Format *scale* with exactly one decimal, e.g. ``2`` -> ``"2.0"``."""
        return f"{float(scale):.1f}"

    @classmethod
    def upsampling_value(cls, scale: float) -> str:
        """Return the single-pass upsampling value for *scale*, e.g. ``"flashvsr2"``."""
        return f"{cls.UPSAMPLING_VALUE_PREFIX}{cls.format_ratio(scale)}"

    @classmethod
    def upsampling_two_pass_value(cls, scale: float) -> str:
        """Return the two-pass upsampling value for *scale*, e.g. ``"flashvsr2pass2"``."""
        return f"{cls.UPSAMPLING_TWO_PASS_VALUE_PREFIX}{cls.format_ratio(scale)}"

    @classmethod
    def upsampling_choices(cls, include_name: bool = True, include_two_pass: bool = False) -> list[tuple[str, str]]:
        """Build ``(label, value)`` pairs for every ratio in ``UPSAMPLING_RATIOS``.

        Args:
            include_name: Prefix each label with ``"FlashVSR "``.
            include_two_pass: Append the two-pass entries after the single-pass ones.
        """
        prefix = "FlashVSR " if include_name else ""
        choices = [
            (f"{prefix}x{cls.format_ratio_label(scale)}", cls.upsampling_value(scale))
            for scale in cls.UPSAMPLING_RATIOS
        ]
        if include_two_pass:
            choices.extend(
                (f"{prefix}Two Pass x{cls.format_ratio_label(scale)}", cls.upsampling_two_pass_value(scale))
                for scale in cls.UPSAMPLING_RATIOS
            )
        return choices

    @staticmethod
    def _upsampling_text(spatial_upsampling) -> str:
        """Return *spatial_upsampling* as stripped, lower-cased text ("" if falsy)."""
        return str(spatial_upsampling or "").strip().lower()

    @classmethod
    def scale_for_upsampling(cls, spatial_upsampling) -> float | None:
        """Extract the scale factor from a FlashVSR upsampling value.

        Returns:
            The scale as a float when *spatial_upsampling* starts with one of
            the FlashVSR prefixes and the remainder parses to a member of
            ``UPSAMPLING_RATIOS``; ``None`` otherwise.
        """
        text = cls._upsampling_text(spatial_upsampling)
        # The two-pass prefix itself starts with the single-pass prefix, so it
        # has to be tested first.
        if text.startswith(cls.UPSAMPLING_TWO_PASS_VALUE_PREFIX):
            prefix = cls.UPSAMPLING_TWO_PASS_VALUE_PREFIX
        elif text.startswith(cls.UPSAMPLING_VALUE_PREFIX):
            prefix = cls.UPSAMPLING_VALUE_PREFIX
        else:
            return None
        try:
            scale = float(text[len(prefix):])
        except ValueError:
            return None
        return scale if scale in cls.UPSAMPLING_RATIOS else None

    @classmethod
    def is_two_pass_upsampling(cls, spatial_upsampling) -> bool:
        """Return ``True`` when the value starts with the two-pass prefix.

        Only the prefix is checked; the scale suffix is not validated here.
        """
        return cls._upsampling_text(spatial_upsampling).startswith(cls.UPSAMPLING_TWO_PASS_VALUE_PREFIX)

    @classmethod
    def query_edit_mode_def(cls, include_name: bool = True) -> dict[str, Any]:
        """Return the edit-mode definition dict (name, choices incl. two-pass, default x2)."""
        return {
            "name": "FlashVSR",
            "spatial_upsampling_choices": cls.upsampling_choices(include_name=include_name, include_two_pass=True),
            "default_spatial_upsampling": cls.upsampling_value(2.0),
        }

    def is_upsampling(self, spatial_upsampling) -> bool:
        """Return ``True`` when *spatial_upsampling* is a valid FlashVSR value."""
        return self.scale_for_upsampling(spatial_upsampling) is not None

    def validate_upsampling(self, spatial_upsampling, image_mode: int) -> str:
        """Return an error message for an unusable FlashVSR request, else ``""``.

        A message is produced only when *spatial_upsampling* is a FlashVSR
        value while FlashVSR is disabled. ``image_mode`` is currently unused.
        """
        if self.is_upsampling(spatial_upsampling) and not self.enabled():
            return "FlashVSR Spatial Upsampling is disabled in Configuration > Extensions"
        return ""

    def query_download_def(self, enabled_only: bool = True) -> dict[str, Any] | None:
        """Return the download definition for the FlashVSR model files.

        Args:
            enabled_only: When true, return ``None`` if FlashVSR is disabled.
        """
        if enabled_only and not self.enabled():
            return None
        return {
            "repoId": "DeepBeepMeep/Wan2.1",
            "sourceFolderList": ["FlashVSR", ""],
            "fileList": [[self.TRANSFORMER_FILENAME, self.LQ_PROJ_FILENAME, self.TCDECODER_FILENAME, self.POSI_PROMPT_FILENAME], [self.VAE_FILENAME]],
        }

    def _locate_flashvsr_file(self, filename: str) -> str:
        """Locate *filename* inside the ``FlashVSR`` folder via the file locator."""
        return self.files_locator.locate_file(os.path.join("FlashVSR", filename))

    def paths(self, variant: str):
        """Build the ``FlashVSRPaths`` for *variant*.

        The ``"full"`` variant uses the VAE and no tcdecoder; every other
        variant uses the tcdecoder and no VAE.
        """
        from postprocessing.flashvsr.runtime import FlashVSRPaths
        return FlashVSRPaths(
            transformer=self._locate_flashvsr_file(self.TRANSFORMER_FILENAME),
            lq_proj=self._locate_flashvsr_file(self.LQ_PROJ_FILENAME),
            posi_prompt=self._locate_flashvsr_file(self.POSI_PROMPT_FILENAME),
            tcdecoder=None if variant == "full" else self._locate_flashvsr_file(self.TCDECODER_FILENAME),
            vae=self.files_locator.locate_file(self.VAE_FILENAME) if variant == "full" else None,
        )

    def vae_tile_size(self, vae_config: int, output_height: int | None = None, output_width: int | None = None) -> int:
        """Return the tile size computed by ``WanVAE.get_VAE_tile_size``.

        The memory figure passed on is the total memory of CUDA device 0 in
        MiB (0 when CUDA is unavailable); mixed precision is requested when
        the ``vae_precision`` server setting equals ``"32"``.
        """
        import torch
        from models.wan.modules.vae import WanVAE

        device_mem_capacity = torch.cuda.get_device_properties(0).total_memory / 1048576 if torch.cuda.is_available() else 0
        mixed_precision = self.server_config.get("vae_precision", "16") == "32"
        return WanVAE.get_VAE_tile_size(vae_config, device_mem_capacity, mixed_precision, output_height=output_height, output_width=output_width)

    def download(self, process_files: Callable[..., Any], send_cmd=None, status_text: str | None = None) -> bool:
        """Download the model files of the active variant when any is missing.

        Args:
            process_files: Callable invoked with the download definition as
                keyword arguments.
            send_cmd: Forwarded to ``send_download_status``.
            status_text: Forwarded to ``send_download_status``.

        Returns:
            ``True`` when ``process_files`` was invoked, ``False`` when
            FlashVSR is disabled or every required file is already present.
        """
        flashvsr_def = self.query_download_def()
        if flashvsr_def is None:
            return False
        _, variant, _ = self.settings()
        required = [
            os.path.join("FlashVSR", filename)
            for filename in (self.TRANSFORMER_FILENAME, self.LQ_PROJ_FILENAME, self.POSI_PROMPT_FILENAME)
        ]
        # Mirrors paths(): the "full" variant needs the VAE, the others the tcdecoder.
        required.append(self.VAE_FILENAME if variant == "full" else os.path.join("FlashVSR", self.TCDECODER_FILENAME))
        if all(self.files_locator.locate_file(path, error_if_none=False) is not None for path in required):
            return False
        from shared.utils.download import send_download_status

        send_download_status(send_cmd, status_text)
        process_files(**flashvsr_def)
        return True

    def upscale(
        self,
        sample,
        spatial_upsampling,
        *,
        seed=0,
        continue_cache=None,
        return_continue_cache=False,
        vae_tile_size=None,
        process_files: Callable[..., Any],
        vae_config: int,
        init_pipe: Callable[..., int],
        profile,
        still_image=False,
        abort_callback=None,
        progress_callback=None,
    ):
        """Upscale *sample* with FlashVSR and return the runtime's result.

        Missing model files are downloaded first and the configured sparse
        backend is selected before ``upscale_video`` is called.

        Args:
            sample: Object with a ``shape`` whose last two entries are scaled
                to derive the output height and width.
            spatial_upsampling: FlashVSR upsampling value (see
                :meth:`scale_for_upsampling`).
            vae_tile_size: Accepted but not used; the tile size is always
                recomputed from ``vae_config`` and the output size.
            process_files: Download callable handed to :meth:`download`.
            vae_config: Forwarded to :meth:`vae_tile_size`.

        The remaining keyword arguments are forwarded unchanged to
        ``upscale_video``.

        Raises:
            ValueError: *spatial_upsampling* is not a valid FlashVSR value.
            RuntimeError: FlashVSR is disabled in the configuration.
        """
        scale = self.scale_for_upsampling(spatial_upsampling)
        if scale is None:
            raise ValueError(f"Unknown FlashVSR upsampling mode: {spatial_upsampling}")
        enabled, variant, persistence = self.settings()
        if not enabled:
            raise RuntimeError("FlashVSR spatial upsampling is disabled in Configuration > Extensions.")
        self.download(process_files)
        from postprocessing.flashvsr.attention_backend import set_sparse_backend
        set_sparse_backend(self.backend())
        from postprocessing.flashvsr.runtime import upscale_video

        output_height = int(sample.shape[-2] * scale)
        output_width = int(sample.shape[-1] * scale)
        # NOTE(review): the caller-supplied vae_tile_size is ignored here —
        # confirm this is intentional before relying on that argument.
        flashvsr_tile_size = self.vae_tile_size(vae_config, output_height, output_width)
        return upscale_video(
            sample,
            scale,
            self.paths(variant),
            variant=variant,
            seed=seed,
            continue_cache=continue_cache,
            return_continue_cache=return_continue_cache,
            persistent_models=persistence == self.PERSIST_RAM,
            vae_tile_size=flashvsr_tile_size,
            topk_ratio=self.topk_ratio(),
            init_pipe=init_pipe,
            profile=profile,
            still_image=still_image,
            two_pass=self.is_two_pass_upsampling(spatial_upsampling),
            abort_callback=abort_callback,
            progress_callback=progress_callback,
        )

    def release_vram(self) -> None:
        """Ask the FlashVSR runtime to release its models."""
        from postprocessing.flashvsr.runtime import release_models
        release_models()