| from __future__ import annotations |
|
|
| import os |
| from typing import Any, Callable |
|
|
| from postprocessing.flashvsr.sparse_backend_config import ( |
| SPARSE_BACKEND_AUTO, |
| SPARSE_BACKEND_SPARGE, |
| SPARSE_BACKEND_TRITON_SPARSE, |
| normalize_sparse_backend, |
| ) |
|
|
|
|
class FlashVSRBridge:
    """Adapter between the server configuration and the FlashVSR upscaler.

    The bridge normalizes the ``flashvsr_*`` configuration keys, builds and
    parses the ``spatial_upsampling`` option strings, describes and locates
    the model files, and forwards upscale requests to
    ``postprocessing.flashvsr.runtime``. Heavy modules (torch, the runtime)
    are imported lazily inside the methods that need them.
    """

    # Accepted values of the "flashvsr_mode" configuration key.
    MODE_OFF = 0
    MODE_TINY = 1
    MODE_FULL = 2
    MODE_TINY_LONG = 3
    # Accepted values of the "flashvsr_persistence" configuration key.
    PERSIST_UNLOAD = 1
    PERSIST_RAM = 2
    # Sparse-attention backend identifiers, re-exported from sparse_backend_config.
    BACKEND_AUTO = SPARSE_BACKEND_AUTO
    BACKEND_TRITON_SPARSE = SPARSE_BACKEND_TRITON_SPARSE
    BACKEND_SPARGE = SPARSE_BACKEND_SPARGE
    # Fallback and upper clamp for the "flashvsr_topk_ratio" configuration key.
    TOPK_RATIO_DEFAULT = 0.0
    TOPK_RATIO_MAX = 4.0
    # Prefixes of the ``spatial_upsampling`` option strings. The two-pass
    # prefix starts with the single-pass prefix, so it must be tested first.
    UPSAMPLING_VALUE_PREFIX = "flashvsr"
    UPSAMPLING_TWO_PASS_VALUE_PREFIX = "flashvsr2pass"
    # The only scale factors accepted by ``scale_for_upsampling``.
    UPSAMPLING_RATIOS = (1.0, 1.5, 2.0, 2.5, 3.0, 3.5, 4.0)

    # Model files; all but the VAE are looked up under the "FlashVSR" folder.
    TRANSFORMER_FILENAME = "FlashVSR_v1.1_transformer_bf16.safetensors"
    LQ_PROJ_FILENAME = "FlashVSR_v1.1_lq_proj_bf16.safetensors"
    TCDECODER_FILENAME = "FlashVSR_v1.1_tcdecoder_bf16.safetensors"
    POSI_PROMPT_FILENAME = "FlashVSR_v1.1_posi_prompt_bf16.safetensors"
    VAE_FILENAME = "Wan2.1_VAE.safetensors"

    # Enabled mode -> variant name handed to the runtime. MODE_OFF is
    # deliberately absent so that ``_VARIANTS.get(MODE_OFF)`` yields None.
    _VARIANTS = {
        MODE_TINY: "tiny",
        MODE_FULL: "full",
        MODE_TINY_LONG: "tiny-long",
    }

    def __init__(self, server_config: dict[str, Any], files_locator: Any):
        """Store the configuration mapping and the file locator.

        Args:
            server_config: Mapping holding the ``flashvsr_*`` keys; it is
                normalized in place by ``normalize_config``.
            files_locator: Object exposing ``locate_file(path, ...)``.
        """
        self.server_config = server_config
        self.files_locator = files_locator

    @staticmethod
    def _coerce_int(value: Any, fallback: int) -> int:
        """Return ``int(value)``, or ``fallback`` when the conversion fails."""
        try:
            return int(value)
        except (TypeError, ValueError, OverflowError):
            # OverflowError covers int(float("inf")); ValueError covers
            # int(float("nan")) and non-numeric text.
            return fallback

    @classmethod
    def normalize_topk_ratio(cls, value: Any) -> float:
        """Coerce ``value`` to a float clamped to ``[0.0, TOPK_RATIO_MAX]``.

        Values that cannot be converted (wrong type, non-numeric text, an
        int too large for a float) and NaN fall back to
        ``TOPK_RATIO_DEFAULT``. Infinities are clamped like any other
        out-of-range number.
        """
        import math

        try:
            ratio = float(value)
        except (TypeError, ValueError, OverflowError):
            ratio = cls.TOPK_RATIO_DEFAULT
        # min()/max() compare with "<", which is always False for NaN, so a
        # NaN would otherwise come out of the clamp as TOPK_RATIO_MAX.
        if math.isnan(ratio):
            ratio = cls.TOPK_RATIO_DEFAULT
        return max(0.0, min(cls.TOPK_RATIO_MAX, ratio))

    @classmethod
    def normalize_backend(cls, value: Any) -> str:
        """Return ``value`` normalized by ``normalize_sparse_backend``."""
        return normalize_sparse_backend(value)

    def normalize_config(self, config: dict[str, Any] | None = None) -> tuple[int, int]:
        """Sanitize the FlashVSR keys of ``config`` in place.

        Args:
            config: Mapping to normalize; defaults to ``self.server_config``.

        Returns:
            ``(mode, persistence)`` after normalization. Unknown or
            unparsable modes become ``MODE_OFF``; unknown or unparsable
            persistence values become ``PERSIST_UNLOAD``.

        Side effects:
            Writes the normalized ``flashvsr_mode``, ``flashvsr_persistence``,
            ``flashvsr_backend`` and ``flashvsr_topk_ratio`` back to ``config``.
        """
        if config is None:
            config = self.server_config

        mode = self._coerce_int(config.get("flashvsr_mode", self.MODE_OFF), self.MODE_OFF)
        persistence = self._coerce_int(
            config.get("flashvsr_persistence", self.PERSIST_UNLOAD),
            self.PERSIST_UNLOAD,
        )

        # Anything that is not an enabled variant collapses to MODE_OFF
        # (MODE_OFF itself is not a key of _VARIANTS, so it stays MODE_OFF).
        if mode not in self._VARIANTS:
            mode = self.MODE_OFF
        if persistence not in (self.PERSIST_UNLOAD, self.PERSIST_RAM):
            persistence = self.PERSIST_UNLOAD

        config["flashvsr_mode"] = mode
        config["flashvsr_persistence"] = persistence
        config["flashvsr_backend"] = self.normalize_backend(
            config.get("flashvsr_backend", self.BACKEND_AUTO)
        )
        config["flashvsr_topk_ratio"] = self.normalize_topk_ratio(
            config.get("flashvsr_topk_ratio", self.TOPK_RATIO_DEFAULT)
        )
        return mode, persistence

    def settings(self, config: dict[str, Any] | None = None) -> tuple[bool, str | None, int]:
        """Return ``(enabled, variant, persistence)`` for ``config``.

        ``variant`` is None when FlashVSR is off. The configuration is
        normalized in place as a side effect (see ``normalize_config``).
        """
        mode, persistence = self.normalize_config(config)
        return mode != self.MODE_OFF, self._VARIANTS.get(mode), persistence

    def topk_ratio(self) -> float:
        """Return the normalized ``flashvsr_topk_ratio`` of the server config."""
        raw = self.server_config.get("flashvsr_topk_ratio", self.TOPK_RATIO_DEFAULT)
        return self.normalize_topk_ratio(raw)

    def backend(self) -> str:
        """Return the normalized ``flashvsr_backend`` of the server config."""
        raw = self.server_config.get("flashvsr_backend", self.BACKEND_AUTO)
        return self.normalize_backend(raw)

    def enabled(self) -> bool:
        """Return True when the server config selects a FlashVSR variant."""
        is_enabled, _, _ = self.settings()
        return is_enabled

    @classmethod
    def format_ratio(cls, scale: float) -> str:
        """Format ``scale`` compactly for option values (2.0 -> "2", 1.5 -> "1.5")."""
        scale = float(scale)
        if scale.is_integer():
            return str(int(scale))
        return f"{scale:g}"

    @classmethod
    def format_ratio_label(cls, scale: float) -> str:
        """Format ``scale`` with one decimal for display labels (2 -> "2.0")."""
        return f"{float(scale):.1f}"

    @classmethod
    def upsampling_value(cls, scale: float) -> str:
        """Return the single-pass option value for ``scale`` (e.g. "flashvsr2")."""
        return f"{cls.UPSAMPLING_VALUE_PREFIX}{cls.format_ratio(scale)}"

    @classmethod
    def upsampling_two_pass_value(cls, scale: float) -> str:
        """Return the two-pass option value for ``scale`` (e.g. "flashvsr2pass2")."""
        return f"{cls.UPSAMPLING_TWO_PASS_VALUE_PREFIX}{cls.format_ratio(scale)}"

    @classmethod
    def upsampling_choices(cls, include_name: bool = True, include_two_pass: bool = False) -> list[tuple[str, str]]:
        """Build ``(label, value)`` pairs for every supported ratio.

        Args:
            include_name: Prefix each label with "FlashVSR ".
            include_two_pass: Append the two-pass choices after the
                single-pass ones.
        """
        prefix = "FlashVSR " if include_name else ""
        choices = [
            (f"{prefix}x{cls.format_ratio_label(scale)}", cls.upsampling_value(scale))
            for scale in cls.UPSAMPLING_RATIOS
        ]
        if include_two_pass:
            choices.extend(
                (f"{prefix}Two Pass x{cls.format_ratio_label(scale)}", cls.upsampling_two_pass_value(scale))
                for scale in cls.UPSAMPLING_RATIOS
            )
        return choices

    @classmethod
    def scale_for_upsampling(cls, spatial_upsampling) -> float | None:
        """Parse the scale factor out of a ``spatial_upsampling`` option value.

        Matching is case-insensitive and ignores surrounding whitespace.

        Returns:
            The scale as a float, or None when the value is empty, is not a
            FlashVSR option, or names a ratio outside ``UPSAMPLING_RATIOS``.
        """
        text = str(spatial_upsampling or "").strip().lower()
        # The two-pass prefix itself starts with the single-pass prefix, so
        # the longer one has to be recognised first.
        if text.startswith(cls.UPSAMPLING_TWO_PASS_VALUE_PREFIX):
            prefix = cls.UPSAMPLING_TWO_PASS_VALUE_PREFIX
        elif text.startswith(cls.UPSAMPLING_VALUE_PREFIX):
            prefix = cls.UPSAMPLING_VALUE_PREFIX
        else:
            return None
        try:
            scale = float(text[len(prefix):])
        except ValueError:
            return None
        return scale if scale in cls.UPSAMPLING_RATIOS else None

    @classmethod
    def is_two_pass_upsampling(cls, spatial_upsampling) -> bool:
        """Return True when the option value starts with the two-pass prefix."""
        text = str(spatial_upsampling or "").strip().lower()
        return text.startswith(cls.UPSAMPLING_TWO_PASS_VALUE_PREFIX)

    @classmethod
    def query_edit_mode_def(cls, include_name: bool = True) -> dict[str, Any]:
        """Describe FlashVSR for edit mode: name, choices (incl. two-pass), default x2."""
        return {
            "name": "FlashVSR",
            "spatial_upsampling_choices": cls.upsampling_choices(include_name=include_name, include_two_pass=True),
            "default_spatial_upsampling": cls.upsampling_value(2.0),
        }

    def is_upsampling(self, spatial_upsampling) -> bool:
        """Return True when ``spatial_upsampling`` is a valid FlashVSR option value."""
        return self.scale_for_upsampling(spatial_upsampling) is not None

    def validate_upsampling(self, spatial_upsampling, image_mode: int) -> str:
        """Return an error message for an unusable FlashVSR request, else "".

        Non-FlashVSR values are not this bridge's concern and yield "".
        ``image_mode`` is accepted but not consulted by this method.
        """
        if self.is_upsampling(spatial_upsampling) and not self.enabled():
            return "FlashVSR Spatial Upsampling is disabled in Configuration > Extensions"
        return ""

    def query_download_def(self, enabled_only: bool = True) -> dict[str, Any] | None:
        """Return the download descriptor for the FlashVSR model files.

        Args:
            enabled_only: When True, return None if FlashVSR is disabled.

        Returns:
            A dict with ``repoId``, ``sourceFolderList`` and ``fileList``
            (one file list per source folder), or None.
        """
        if enabled_only and not self.enabled():
            return None
        flashvsr_files = [
            self.TRANSFORMER_FILENAME,
            self.LQ_PROJ_FILENAME,
            self.TCDECODER_FILENAME,
            self.POSI_PROMPT_FILENAME,
        ]
        return {
            "repoId": "DeepBeepMeep/Wan2.1",
            "sourceFolderList": ["FlashVSR", ""],
            "fileList": [flashvsr_files, [self.VAE_FILENAME]],
        }

    def _locate_flashvsr_file(self, filename: str) -> str:
        """Locate ``filename`` inside the "FlashVSR" folder via the file locator."""
        return self.files_locator.locate_file(os.path.join("FlashVSR", filename))

    def paths(self, variant: str):
        """Build the ``FlashVSRPaths`` bundle for ``variant``.

        The "full" variant uses the VAE and no TC decoder; every other
        variant uses the TC decoder and no VAE.
        """
        from postprocessing.flashvsr.runtime import FlashVSRPaths

        return FlashVSRPaths(
            transformer=self._locate_flashvsr_file(self.TRANSFORMER_FILENAME),
            lq_proj=self._locate_flashvsr_file(self.LQ_PROJ_FILENAME),
            posi_prompt=self._locate_flashvsr_file(self.POSI_PROMPT_FILENAME),
            tcdecoder=None if variant == "full" else self._locate_flashvsr_file(self.TCDECODER_FILENAME),
            vae=self.files_locator.locate_file(self.VAE_FILENAME) if variant == "full" else None,
        )

    def vae_tile_size(self, vae_config: int, output_height: int | None = None, output_width: int | None = None) -> int:
        """Ask ``WanVAE.get_VAE_tile_size`` for a tile size for this device.

        The memory capacity passed on is the total memory of CUDA device 0
        divided by 1048576 (0 when CUDA is unavailable); mixed precision is
        on when the ``vae_precision`` server setting equals "32".
        """
        import torch
        from models.wan.modules.vae import WanVAE

        if torch.cuda.is_available():
            device_mem_capacity = torch.cuda.get_device_properties(0).total_memory / 1048576
        else:
            device_mem_capacity = 0
        mixed_precision = self.server_config.get("vae_precision", "16") == "32"
        return WanVAE.get_VAE_tile_size(vae_config, device_mem_capacity, mixed_precision, output_height=output_height, output_width=output_width)

    def download(self, process_files: Callable[..., Any], send_cmd=None, status_text: str | None = None) -> bool:
        """Download the model files if FlashVSR is enabled and any is missing.

        Args:
            process_files: Called with the download descriptor as keyword
                arguments to perform the download.
            send_cmd: Forwarded to ``send_download_status``.
            status_text: Forwarded to ``send_download_status``.

        Returns:
            True when a download was triggered; False when FlashVSR is
            disabled or every file required by the active variant is
            already present.
        """
        flashvsr_def = self.query_download_def()
        if flashvsr_def is None:
            return False
        _, variant, _ = self.settings()

        # Only the files the active variant actually loads are checked,
        # even though the descriptor lists all of them.
        required = [
            os.path.join("FlashVSR", self.TRANSFORMER_FILENAME),
            os.path.join("FlashVSR", self.LQ_PROJ_FILENAME),
            os.path.join("FlashVSR", self.POSI_PROMPT_FILENAME),
        ]
        if variant == "full":
            required.append(self.VAE_FILENAME)
        else:
            required.append(os.path.join("FlashVSR", self.TCDECODER_FILENAME))
        if all(self.files_locator.locate_file(path, error_if_none=False) is not None for path in required):
            return False

        from shared.utils.download import send_download_status

        send_download_status(send_cmd, status_text)
        process_files(**flashvsr_def)
        return True

    def upscale(
        self,
        sample,
        spatial_upsampling,
        *,
        seed=0,
        continue_cache=None,
        return_continue_cache=False,
        vae_tile_size=None,
        process_files: Callable[..., Any],
        vae_config: int,
        init_pipe: Callable[..., int],
        profile,
        still_image=False,
        abort_callback=None,
        progress_callback=None,
    ):
        """Upscale ``sample`` with FlashVSR according to ``spatial_upsampling``.

        Missing model files are downloaded first, the configured sparse
        backend is selected, and the request is forwarded to
        ``upscale_video`` whose result is returned unchanged.

        Args:
            sample: Input whose ``shape[-2]`` and ``shape[-1]`` are read as
                height and width to size the output.
            spatial_upsampling: FlashVSR option value (single or two-pass).
            vae_tile_size: NOTE(review): accepted but not used; the tile
                size is always recomputed from the output size. Confirm
                whether an override was intended.
            process_files: Download callback, see ``download``.
            vae_config: Forwarded to ``vae_tile_size``.
            seed, continue_cache, return_continue_cache, init_pipe, profile,
            still_image, abort_callback, progress_callback: Forwarded to
                ``upscale_video`` as-is.

        Raises:
            ValueError: ``spatial_upsampling`` is not a valid FlashVSR value.
            RuntimeError: FlashVSR is disabled in the configuration.
        """
        scale = self.scale_for_upsampling(spatial_upsampling)
        if scale is None:
            raise ValueError(f"Unknown FlashVSR upsampling mode: {spatial_upsampling}")
        enabled, variant, persistence = self.settings()
        if not enabled:
            raise RuntimeError("FlashVSR spatial upsampling is disabled in Configuration > Extensions.")
        self.download(process_files)

        from postprocessing.flashvsr.attention_backend import set_sparse_backend

        set_sparse_backend(self.backend())

        from postprocessing.flashvsr.runtime import upscale_video

        output_height = int(sample.shape[-2] * scale)
        output_width = int(sample.shape[-1] * scale)
        flashvsr_tile_size = self.vae_tile_size(vae_config, output_height, output_width)
        return upscale_video(
            sample,
            scale,
            self.paths(variant),
            variant=variant,
            seed=seed,
            continue_cache=continue_cache,
            return_continue_cache=return_continue_cache,
            persistent_models=persistence == self.PERSIST_RAM,
            vae_tile_size=flashvsr_tile_size,
            topk_ratio=self.topk_ratio(),
            init_pipe=init_pipe,
            profile=profile,
            still_image=still_image,
            two_pass=self.is_two_pass_upsampling(spatial_upsampling),
            abort_callback=abort_callback,
            progress_callback=progress_callback,
        )

    def release_vram(self) -> None:
        """Ask the FlashVSR runtime to release its models."""
        from postprocessing.flashvsr.runtime import release_models

        release_models()
|
|