ColabWan / postprocessing /flashvsr /wgp_bridge.py
1ripon1's picture
Upload folder using huggingface_hub
7344bef verified
Raw
History Blame Contribute Delete
10.4 kB
from __future__ import annotations
import os
from typing import Any, Callable
from postprocessing.flashvsr.sparse_backend_config import (
SPARSE_BACKEND_AUTO,
SPARSE_BACKEND_SPARGE,
SPARSE_BACKEND_TRITON_SPARSE,
normalize_sparse_backend,
)
class FlashVSRBridge:
    """Adapter between the host application's configuration and the FlashVSR runtime.

    The bridge reads FlashVSR settings from a server configuration mapping,
    normalizes them, resolves model file locations through ``files_locator``
    and forwards upscaling requests to ``postprocessing.flashvsr.runtime``.
    """

    # Accepted values for the "flashvsr_mode" configuration key.
    MODE_OFF = 0
    MODE_TINY = 1
    MODE_FULL = 2
    MODE_TINY_LONG = 3

    # Accepted values for the "flashvsr_persistence" configuration key.
    PERSIST_UNLOAD = 1
    PERSIST_RAM = 2

    # Sparse attention backend identifiers, re-exported from sparse_backend_config.
    BACKEND_AUTO = SPARSE_BACKEND_AUTO
    BACKEND_TRITON_SPARSE = SPARSE_BACKEND_TRITON_SPARSE
    BACKEND_SPARGE = SPARSE_BACKEND_SPARGE

    # Bounds used by normalize_topk_ratio(); values are clamped to [0.0, TOPK_RATIO_MAX].
    TOPK_RATIO_DEFAULT = 0.0
    TOPK_RATIO_MAX = 4.0

    # Prefixes of the "spatial_upsampling" option values handled by this bridge.
    UPSAMPLING_VALUE_PREFIX = "flashvsr"
    UPSAMPLING_TWO_PASS_VALUE_PREFIX = "flashvsr2pass"
    UPSAMPLING_RATIOS = (1.0, 1.5, 2.0, 2.5, 3.0, 3.5, 4.0)

    # Model file names; all but the VAE are looked up inside the "FlashVSR" folder.
    TRANSFORMER_FILENAME = "FlashVSR_v1.1_transformer_bf16.safetensors"
    LQ_PROJ_FILENAME = "FlashVSR_v1.1_lq_proj_bf16.safetensors"
    TCDECODER_FILENAME = "FlashVSR_v1.1_tcdecoder_bf16.safetensors"
    POSI_PROMPT_FILENAME = "FlashVSR_v1.1_posi_prompt_bf16.safetensors"
    VAE_FILENAME = "Wan2.1_VAE.safetensors"

    # Maps every enabled mode to the variant name handed to the runtime.
    _VARIANTS = {
        MODE_TINY: "tiny",
        MODE_FULL: "full",
        MODE_TINY_LONG: "tiny-long",
    }

    def __init__(self, server_config: dict[str, Any], files_locator):
        """Store the configuration mapping and the file locator used by the bridge.

        Args:
            server_config: Mutable configuration mapping; ``normalize_config``
                writes normalized values back into it.
            files_locator: Object exposing ``locate_file(path, ...)``.
        """
        self.server_config = server_config
        self.files_locator = files_locator

    @classmethod
    def normalize_topk_ratio(cls, value: Any) -> float:
        """Convert ``value`` to a float clamped to ``[0.0, TOPK_RATIO_MAX]``.

        Values that cannot be converted, overflow during conversion, or are NaN
        are replaced by ``TOPK_RATIO_DEFAULT`` before clamping.
        """
        try:
            ratio = float(value)
        except (TypeError, ValueError, OverflowError):
            # OverflowError: float() of an int too large to represent.
            ratio = cls.TOPK_RATIO_DEFAULT
        if ratio != ratio:
            # NaN is the only value unequal to itself. Without this guard,
            # min(TOPK_RATIO_MAX, nan) yields TOPK_RATIO_MAX, so an invalid
            # setting would silently select the maximum ratio.
            ratio = cls.TOPK_RATIO_DEFAULT
        return max(0.0, min(cls.TOPK_RATIO_MAX, ratio))

    @classmethod
    def normalize_backend(cls, value: Any) -> str:
        """Return ``value`` normalized by ``normalize_sparse_backend``."""
        return normalize_sparse_backend(value)

    def normalize_config(self, config: dict[str, Any] | None = None) -> tuple[int, int]:
        """Normalize the FlashVSR keys of ``config`` in place.

        Args:
            config: Mapping to normalize; ``self.server_config`` when ``None``.

        Returns:
            The normalized ``(mode, persistence)`` pair. Unknown or unparsable
            values fall back to ``MODE_OFF`` and ``PERSIST_UNLOAD``.

        Side effects:
            Rewrites ``flashvsr_mode``, ``flashvsr_persistence``,
            ``flashvsr_backend`` and ``flashvsr_topk_ratio`` in ``config``.
        """
        config = self.server_config if config is None else config
        mode = config.get("flashvsr_mode", self.MODE_OFF)
        persistence = config.get("flashvsr_persistence", self.PERSIST_UNLOAD)

        try:
            mode = int(mode)
        except (TypeError, ValueError, OverflowError):
            # OverflowError: int() of an infinite float.
            mode = self.MODE_OFF
        try:
            persistence = int(persistence)
        except (TypeError, ValueError, OverflowError):
            persistence = self.PERSIST_UNLOAD

        if mode not in self._VARIANTS and mode != self.MODE_OFF:
            mode = self.MODE_OFF
        if persistence not in (self.PERSIST_UNLOAD, self.PERSIST_RAM):
            persistence = self.PERSIST_UNLOAD

        config["flashvsr_mode"] = mode
        config["flashvsr_persistence"] = persistence
        config["flashvsr_backend"] = self.normalize_backend(config.get("flashvsr_backend", self.BACKEND_AUTO))
        config["flashvsr_topk_ratio"] = self.normalize_topk_ratio(config.get("flashvsr_topk_ratio", self.TOPK_RATIO_DEFAULT))
        return mode, persistence

    def settings(self, config: dict[str, Any] | None = None) -> tuple[bool, str | None, int]:
        """Return ``(enabled, variant, persistence)`` after normalizing ``config``.

        ``variant`` is ``None`` when the mode is ``MODE_OFF``.
        """
        mode, persistence = self.normalize_config(config)
        return mode != self.MODE_OFF, self._VARIANTS.get(mode), persistence

    def topk_ratio(self) -> float:
        """Return the normalized ``flashvsr_topk_ratio`` from the server configuration."""
        return self.normalize_topk_ratio(self.server_config.get("flashvsr_topk_ratio", self.TOPK_RATIO_DEFAULT))

    def backend(self) -> str:
        """Return the normalized ``flashvsr_backend`` from the server configuration."""
        return self.normalize_backend(self.server_config.get("flashvsr_backend", self.BACKEND_AUTO))

    def enabled(self) -> bool:
        """Return ``True`` when the configured mode is anything other than ``MODE_OFF``."""
        return self.settings()[0]

    @classmethod
    def format_ratio(cls, scale: float) -> str:
        """Format ``scale`` compactly: ``2.0`` becomes ``"2"``, ``1.5`` stays ``"1.5"``."""
        scale = float(scale)
        return str(int(scale)) if scale.is_integer() else f"{scale:g}"

    @classmethod
    def format_ratio_label(cls, scale: float) -> str:
        """Format ``scale`` with exactly one decimal for display labels."""
        return f"{float(scale):.1f}"

    @classmethod
    def upsampling_value(cls, scale: float) -> str:
        """Return the single-pass option value for ``scale`` (e.g. ``"flashvsr2"``)."""
        return f"{cls.UPSAMPLING_VALUE_PREFIX}{cls.format_ratio(scale)}"

    @classmethod
    def upsampling_two_pass_value(cls, scale: float) -> str:
        """Return the two-pass option value for ``scale`` (e.g. ``"flashvsr2pass2"``)."""
        return f"{cls.UPSAMPLING_TWO_PASS_VALUE_PREFIX}{cls.format_ratio(scale)}"

    @classmethod
    def upsampling_choices(cls, include_name: bool = True, include_two_pass: bool = False) -> list[tuple[str, str]]:
        """Build ``(label, value)`` pairs for every ratio in ``UPSAMPLING_RATIOS``.

        Args:
            include_name: Prefix each label with ``"FlashVSR "``.
            include_two_pass: Append the two-pass choices after the single-pass ones.
        """
        prefix = "FlashVSR " if include_name else ""
        choices = [
            (f"{prefix}x{cls.format_ratio_label(scale)}", cls.upsampling_value(scale))
            for scale in cls.UPSAMPLING_RATIOS
        ]
        if include_two_pass:
            choices += [
                (f"{prefix}Two Pass x{cls.format_ratio_label(scale)}", cls.upsampling_two_pass_value(scale))
                for scale in cls.UPSAMPLING_RATIOS
            ]
        return choices

    @classmethod
    def scale_for_upsampling(cls, spatial_upsampling) -> float | None:
        """Parse the scale out of a FlashVSR option value.

        Returns:
            The scale when ``spatial_upsampling`` starts with one of the FlashVSR
            prefixes and the remainder is a ratio listed in
            ``UPSAMPLING_RATIOS``; otherwise ``None``.
        """
        text = str(spatial_upsampling or "").strip().lower()
        # The two-pass prefix must be tested first: it starts with the
        # single-pass prefix, so the shorter one would match it as well.
        prefix = cls.UPSAMPLING_TWO_PASS_VALUE_PREFIX if text.startswith(cls.UPSAMPLING_TWO_PASS_VALUE_PREFIX) else cls.UPSAMPLING_VALUE_PREFIX
        if not text.startswith(prefix):
            return None
        try:
            scale = float(text[len(prefix):])
        except ValueError:
            return None
        return scale if scale in cls.UPSAMPLING_RATIOS else None

    @classmethod
    def is_two_pass_upsampling(cls, spatial_upsampling) -> bool:
        """Return ``True`` when the option value starts with the two-pass prefix."""
        return str(spatial_upsampling or "").strip().lower().startswith(cls.UPSAMPLING_TWO_PASS_VALUE_PREFIX)

    @classmethod
    def query_edit_mode_def(cls, include_name: bool = True) -> dict[str, Any]:
        """Return the edit-mode descriptor: name, upsampling choices and default value."""
        return {
            "name": "FlashVSR",
            "spatial_upsampling_choices": cls.upsampling_choices(include_name=include_name, include_two_pass=True),
            "default_spatial_upsampling": cls.upsampling_value(2.0),
        }

    def is_upsampling(self, spatial_upsampling) -> bool:
        """Return ``True`` when ``spatial_upsampling`` is a valid FlashVSR option value."""
        return self.scale_for_upsampling(spatial_upsampling) is not None

    def validate_upsampling(self, spatial_upsampling, image_mode: int) -> str:
        """Return an error message when FlashVSR is requested but disabled, else ``""``.

        ``image_mode`` is accepted but not used by this check.
        """
        if not self.is_upsampling(spatial_upsampling):
            return ""
        if not self.enabled():
            return "FlashVSR Spatial Upsampling is disabled in Configuration > Extensions"
        return ""

    def query_download_def(self, enabled_only: bool = True) -> dict[str, Any] | None:
        """Return the download descriptor for the FlashVSR model files.

        Returns ``None`` when ``enabled_only`` is set and FlashVSR is disabled.
        """
        if enabled_only and not self.enabled():
            return None
        return {
            "repoId": "DeepBeepMeep/Wan2.1",
            "sourceFolderList": ["FlashVSR", ""],
            "fileList": [[self.TRANSFORMER_FILENAME, self.LQ_PROJ_FILENAME, self.TCDECODER_FILENAME, self.POSI_PROMPT_FILENAME], [self.VAE_FILENAME]],
        }

    def _locate_flashvsr_file(self, filename: str) -> str:
        """Locate ``filename`` inside the ``FlashVSR`` folder via the file locator."""
        return self.files_locator.locate_file(os.path.join("FlashVSR", filename))

    def paths(self, variant: str):
        """Build the ``FlashVSRPaths`` record for ``variant``.

        The ``"full"`` variant gets the VAE and no tcdecoder; every other
        variant gets the tcdecoder and no VAE.
        """
        # Imported lazily so that importing this module does not load the runtime.
        from postprocessing.flashvsr.runtime import FlashVSRPaths

        return FlashVSRPaths(
            transformer=self._locate_flashvsr_file(self.TRANSFORMER_FILENAME),
            lq_proj=self._locate_flashvsr_file(self.LQ_PROJ_FILENAME),
            posi_prompt=self._locate_flashvsr_file(self.POSI_PROMPT_FILENAME),
            tcdecoder=None if variant == "full" else self._locate_flashvsr_file(self.TCDECODER_FILENAME),
            vae=self.files_locator.locate_file(self.VAE_FILENAME) if variant == "full" else None,
        )

    def vae_tile_size(self, vae_config: int, output_height: int | None = None, output_width: int | None = None) -> int:
        """Delegate the tile size choice to ``WanVAE.get_VAE_tile_size``.

        The device capacity passed on is the total memory of CUDA device 0
        divided by 1048576, or ``0`` when CUDA is unavailable.
        """
        import torch
        from models.wan.modules.vae import WanVAE

        device_mem_capacity = torch.cuda.get_device_properties(0).total_memory / 1048576 if torch.cuda.is_available() else 0
        # Mixed precision is requested only when "vae_precision" is the string "32".
        mixed_precision = self.server_config.get("vae_precision", "16") == "32"
        return WanVAE.get_VAE_tile_size(vae_config, device_mem_capacity, mixed_precision, output_height=output_height, output_width=output_width)

    def download(self, process_files: Callable[..., Any], send_cmd=None, status_text: str | None = None) -> bool:
        """Download the FlashVSR model files when any required one is missing.

        Args:
            process_files: Callable invoked with the download descriptor as
                keyword arguments.
            send_cmd: Forwarded to ``send_download_status``.
            status_text: Forwarded to ``send_download_status``.

        Returns:
            ``True`` when a download was triggered; ``False`` when FlashVSR is
            disabled or every required file is already present.
        """
        flashvsr_def = self.query_download_def()
        if flashvsr_def is None:
            return False

        _, variant, _ = self.settings()
        required = [
            os.path.join("FlashVSR", self.TRANSFORMER_FILENAME),
            os.path.join("FlashVSR", self.LQ_PROJ_FILENAME),
            os.path.join("FlashVSR", self.POSI_PROMPT_FILENAME),
        ]
        # Mirrors paths(): "full" needs the VAE, the other variants need the tcdecoder.
        required.append(self.VAE_FILENAME if variant == "full" else os.path.join("FlashVSR", self.TCDECODER_FILENAME))
        if all(self.files_locator.locate_file(path, error_if_none=False) is not None for path in required):
            return False

        from shared.utils.download import send_download_status

        send_download_status(send_cmd, status_text)
        process_files(**flashvsr_def)
        return True

    def upscale(self, sample, spatial_upsampling, *, seed=0, continue_cache=None, return_continue_cache=False, vae_tile_size=None, process_files: Callable[..., Any], vae_config: int, init_pipe: Callable[..., int], profile, still_image=False, abort_callback=None, progress_callback=None):
        """Upscale ``sample`` with FlashVSR and return the runtime's result.

        Args:
            sample: Input whose last two ``shape`` entries are scaled to obtain
                the output height and width.
            spatial_upsampling: FlashVSR option value selecting scale and pass mode.
            vae_tile_size: Accepted but not read; the tile size is always
                recomputed from ``vae_config`` and the output size.
                NOTE(review): confirm that ignoring this argument is intended.
            process_files: Passed to ``download`` to fetch missing model files.

        The remaining keyword arguments are forwarded to ``upscale_video`` unchanged.

        Raises:
            ValueError: ``spatial_upsampling`` is not a valid FlashVSR option value.
            RuntimeError: FlashVSR is disabled in the configuration.
        """
        scale = self.scale_for_upsampling(spatial_upsampling)
        if scale is None:
            raise ValueError(f"Unknown FlashVSR upsampling mode: {spatial_upsampling}")
        enabled, variant, persistence = self.settings()
        if not enabled:
            raise RuntimeError("FlashVSR spatial upsampling is disabled in Configuration > Extensions.")
        self.download(process_files)

        from postprocessing.flashvsr.attention_backend import set_sparse_backend

        set_sparse_backend(self.backend())

        from postprocessing.flashvsr.runtime import upscale_video

        output_height = int(sample.shape[-2] * scale)
        output_width = int(sample.shape[-1] * scale)
        flashvsr_tile_size = self.vae_tile_size(vae_config, output_height, output_width)
        return upscale_video(
            sample,
            scale,
            self.paths(variant),
            variant=variant,
            seed=seed,
            continue_cache=continue_cache,
            return_continue_cache=return_continue_cache,
            persistent_models=persistence == self.PERSIST_RAM,
            vae_tile_size=flashvsr_tile_size,
            topk_ratio=self.topk_ratio(),
            init_pipe=init_pipe,
            profile=profile,
            still_image=still_image,
            two_pass=self.is_two_pass_upsampling(spatial_upsampling),
            abort_callback=abort_callback,
            progress_callback=progress_callback,
        )

    def release_vram(self) -> None:
        """Ask the FlashVSR runtime to release its models."""
        from postprocessing.flashvsr.runtime import release_models

        release_models()