Instructions to use mhnakif/comfy2 with libraries, inference providers, notebooks, and local apps. Follow these links to get started.
- Libraries
- Diffusers
How to use mhnakif/comfy2 with Diffusers:
pip install -U diffusers transformers accelerate
import torch from diffusers import DiffusionPipeline # switch to "mps" for apple devices pipe = DiffusionPipeline.from_pretrained("mhnakif/comfy2", dtype=torch.bfloat16, device_map="cuda") prompt = "Astronaut in a jungle, cold color palette, muted colors, detailed, 8k" image = pipe(prompt).images[0] - Notebooks
- Google Colab
- Kaggle
Upload folder using huggingface_hub
Browse filesThis view is limited to 50 files because it contains too many changes. See raw diff
- .gitattributes +6 -0
- custom_nodes/ComfyUI-FlashVSR_Ultra_Fast/.gitignore +210 -0
- custom_nodes/ComfyUI-FlashVSR_Ultra_Fast/LICENSE +674 -0
- custom_nodes/ComfyUI-FlashVSR_Ultra_Fast/README.md +68 -0
- custom_nodes/ComfyUI-FlashVSR_Ultra_Fast/README_zh.md +66 -0
- custom_nodes/ComfyUI-FlashVSR_Ultra_Fast/__init__.py +3 -0
- custom_nodes/ComfyUI-FlashVSR_Ultra_Fast/img/preview.jpg +3 -0
- custom_nodes/ComfyUI-FlashVSR_Ultra_Fast/nodes.py +553 -0
- custom_nodes/ComfyUI-FlashVSR_Ultra_Fast/posi_prompt.pth +3 -0
- custom_nodes/ComfyUI-FlashVSR_Ultra_Fast/requirements.txt +10 -0
- custom_nodes/ComfyUI-FlashVSR_Ultra_Fast/src/LICENSE.txt +201 -0
- custom_nodes/ComfyUI-FlashVSR_Ultra_Fast/src/__init__.py +3 -0
- custom_nodes/ComfyUI-FlashVSR_Ultra_Fast/src/configs/__init__.py +0 -0
- custom_nodes/ComfyUI-FlashVSR_Ultra_Fast/src/configs/model_config.py +29 -0
- custom_nodes/ComfyUI-FlashVSR_Ultra_Fast/src/models/TCDecoder.py +320 -0
- custom_nodes/ComfyUI-FlashVSR_Ultra_Fast/src/models/__init__.py +1 -0
- custom_nodes/ComfyUI-FlashVSR_Ultra_Fast/src/models/model_manager.py +402 -0
- custom_nodes/ComfyUI-FlashVSR_Ultra_Fast/src/models/sparse_sage/LICENSE.txt +201 -0
- custom_nodes/ComfyUI-FlashVSR_Ultra_Fast/src/models/sparse_sage/core.py +45 -0
- custom_nodes/ComfyUI-FlashVSR_Ultra_Fast/src/models/sparse_sage/quant_per_block.py +101 -0
- custom_nodes/ComfyUI-FlashVSR_Ultra_Fast/src/models/sparse_sage/sparse_int8_attn.py +162 -0
- custom_nodes/ComfyUI-FlashVSR_Ultra_Fast/src/models/utils.py +462 -0
- custom_nodes/ComfyUI-FlashVSR_Ultra_Fast/src/models/wan_video_dit.py +864 -0
- custom_nodes/ComfyUI-FlashVSR_Ultra_Fast/src/models/wan_video_vae.py +847 -0
- custom_nodes/ComfyUI-FlashVSR_Ultra_Fast/src/pipelines/__init__.py +3 -0
- custom_nodes/ComfyUI-FlashVSR_Ultra_Fast/src/pipelines/base.py +130 -0
- custom_nodes/ComfyUI-FlashVSR_Ultra_Fast/src/pipelines/flashvsr_full.py +618 -0
- custom_nodes/ComfyUI-FlashVSR_Ultra_Fast/src/pipelines/flashvsr_tiny.py +615 -0
- custom_nodes/ComfyUI-FlashVSR_Ultra_Fast/src/pipelines/flashvsr_tiny_long.py +620 -0
- custom_nodes/ComfyUI-FlashVSR_Ultra_Fast/src/schedulers/__init__.py +1 -0
- custom_nodes/ComfyUI-FlashVSR_Ultra_Fast/src/schedulers/flow_match.py +79 -0
- custom_nodes/ComfyUI-FlashVSR_Ultra_Fast/src/vram_management/__init__.py +1 -0
- custom_nodes/ComfyUI-FlashVSR_Ultra_Fast/src/vram_management/layers.py +95 -0
- custom_nodes/ComfyUI-LCS/.gitignore +2 -0
- custom_nodes/ComfyUI-LCS/README.md +344 -0
- custom_nodes/ComfyUI-LCS/README_zh.md +343 -0
- custom_nodes/ComfyUI-LCS/__init__.py +53 -0
- custom_nodes/ComfyUI-LCS/core/__init__.py +10 -0
- custom_nodes/ComfyUI-LCS/core/adaptive.py +109 -0
- custom_nodes/ComfyUI-LCS/core/bilateral.py +79 -0
- custom_nodes/ComfyUI-LCS/core/calibration.py +214 -0
- custom_nodes/ComfyUI-LCS/core/color_space.py +380 -0
- custom_nodes/ComfyUI-LCS/core/defaults.py +65 -0
- custom_nodes/ComfyUI-LCS/core/diagnostics.py +246 -0
- custom_nodes/ComfyUI-LCS/core/lcs_data.py +28 -0
- custom_nodes/ComfyUI-LCS/core/patchify.py +93 -0
- custom_nodes/ComfyUI-LCS/core/relationships.py +117 -0
- custom_nodes/ComfyUI-LCS/core/sampling.py +105 -0
- custom_nodes/ComfyUI-LCS/core/sharpness.py +213 -0
- custom_nodes/ComfyUI-LCS/core/timestep.py +75 -0
.gitattributes
CHANGED
|
@@ -142,6 +142,8 @@ models/unet/zit_beyond_reality.safetensors filter=lfs diff=lfs merge=lfs -text
|
|
| 142 |
models/vae/ae.safetensors filter=lfs diff=lfs merge=lfs -text
|
| 143 |
models/vae/flux2-vae.safetensors filter=lfs diff=lfs merge=lfs -text
|
| 144 |
models/vae/wan_2.1_vae.safetensors filter=lfs diff=lfs merge=lfs -text
|
|
|
|
|
|
|
| 145 |
models/FlashVSR/FlashVSR1_1.safetensors filter=lfs diff=lfs merge=lfs -text
|
| 146 |
models/FlashVSR/LQ_proj_in.safetensors filter=lfs diff=lfs merge=lfs -text
|
| 147 |
models/FlashVSR/Prompt.safetensors filter=lfs diff=lfs merge=lfs -text
|
|
@@ -149,3 +151,7 @@ models/FlashVSR/TCDecoder.safetensors filter=lfs diff=lfs merge=lfs -text
|
|
| 149 |
models/FlashVSR/Wan2.1_VAE.safetensors filter=lfs diff=lfs merge=lfs -text
|
| 150 |
models/FlashVSR/Wan2_1-T2V-1_3B_FlashVSR_fp32.safetensors filter=lfs diff=lfs merge=lfs -text
|
| 151 |
models/FlashVSR/Wan2_1_FlashVSR_LQ_proj_model_bf16.safetensors filter=lfs diff=lfs merge=lfs -text
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 142 |
models/vae/ae.safetensors filter=lfs diff=lfs merge=lfs -text
|
| 143 |
models/vae/flux2-vae.safetensors filter=lfs diff=lfs merge=lfs -text
|
| 144 |
models/vae/wan_2.1_vae.safetensors filter=lfs diff=lfs merge=lfs -text
|
| 145 |
+
custom_nodes/ComfyUI-FlashVSR_Ultra_Fast/img/preview.jpg filter=lfs diff=lfs merge=lfs -text
|
| 146 |
+
custom_nodes/ComfyUI-FlashVSR_Ultra_Fast/posi_prompt.pth filter=lfs diff=lfs merge=lfs -text
|
| 147 |
models/FlashVSR/FlashVSR1_1.safetensors filter=lfs diff=lfs merge=lfs -text
|
| 148 |
models/FlashVSR/LQ_proj_in.safetensors filter=lfs diff=lfs merge=lfs -text
|
| 149 |
models/FlashVSR/Prompt.safetensors filter=lfs diff=lfs merge=lfs -text
|
|
|
|
| 151 |
models/FlashVSR/Wan2.1_VAE.safetensors filter=lfs diff=lfs merge=lfs -text
|
| 152 |
models/FlashVSR/Wan2_1-T2V-1_3B_FlashVSR_fp32.safetensors filter=lfs diff=lfs merge=lfs -text
|
| 153 |
models/FlashVSR/Wan2_1_FlashVSR_LQ_proj_model_bf16.safetensors filter=lfs diff=lfs merge=lfs -text
|
| 154 |
+
models/FlashVSR-v1.1/LQ_proj_in.ckpt filter=lfs diff=lfs merge=lfs -text
|
| 155 |
+
models/FlashVSR-v1.1/TCDecoder.ckpt filter=lfs diff=lfs merge=lfs -text
|
| 156 |
+
models/FlashVSR-v1.1/Wan2.1_VAE.pth filter=lfs diff=lfs merge=lfs -text
|
| 157 |
+
models/FlashVSR-v1.1/diffusion_pytorch_model_streaming_dmd.safetensors filter=lfs diff=lfs merge=lfs -text
|
custom_nodes/ComfyUI-FlashVSR_Ultra_Fast/.gitignore
ADDED
|
@@ -0,0 +1,210 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Byte-compiled / optimized / DLL files
|
| 2 |
+
__pycache__/
|
| 3 |
+
*.py[codz]
|
| 4 |
+
*$py.class
|
| 5 |
+
|
| 6 |
+
# C extensions
|
| 7 |
+
*.so
|
| 8 |
+
|
| 9 |
+
# Distribution / packaging
|
| 10 |
+
.Python
|
| 11 |
+
build/
|
| 12 |
+
develop-eggs/
|
| 13 |
+
dist/
|
| 14 |
+
downloads/
|
| 15 |
+
eggs/
|
| 16 |
+
.eggs/
|
| 17 |
+
lib/
|
| 18 |
+
lib64/
|
| 19 |
+
parts/
|
| 20 |
+
sdist/
|
| 21 |
+
var/
|
| 22 |
+
wheels/
|
| 23 |
+
share/python-wheels/
|
| 24 |
+
*.egg-info/
|
| 25 |
+
.installed.cfg
|
| 26 |
+
*.egg
|
| 27 |
+
MANIFEST
|
| 28 |
+
|
| 29 |
+
# PyInstaller
|
| 30 |
+
# Usually these files are written by a python script from a template
|
| 31 |
+
# before PyInstaller builds the exe, so as to inject date/other infos into it.
|
| 32 |
+
*.manifest
|
| 33 |
+
*.spec
|
| 34 |
+
|
| 35 |
+
# Installer logs
|
| 36 |
+
pip-log.txt
|
| 37 |
+
pip-delete-this-directory.txt
|
| 38 |
+
|
| 39 |
+
# Unit test / coverage reports
|
| 40 |
+
htmlcov/
|
| 41 |
+
.tox/
|
| 42 |
+
.nox/
|
| 43 |
+
.coverage
|
| 44 |
+
.coverage.*
|
| 45 |
+
.cache
|
| 46 |
+
nosetests.xml
|
| 47 |
+
coverage.xml
|
| 48 |
+
*.cover
|
| 49 |
+
*.py.cover
|
| 50 |
+
.hypothesis/
|
| 51 |
+
.pytest_cache/
|
| 52 |
+
cover/
|
| 53 |
+
|
| 54 |
+
# Translations
|
| 55 |
+
*.mo
|
| 56 |
+
*.pot
|
| 57 |
+
|
| 58 |
+
# Django stuff:
|
| 59 |
+
*.log
|
| 60 |
+
local_settings.py
|
| 61 |
+
db.sqlite3
|
| 62 |
+
db.sqlite3-journal
|
| 63 |
+
|
| 64 |
+
# Flask stuff:
|
| 65 |
+
instance/
|
| 66 |
+
.webassets-cache
|
| 67 |
+
|
| 68 |
+
# Scrapy stuff:
|
| 69 |
+
.scrapy
|
| 70 |
+
|
| 71 |
+
# Sphinx documentation
|
| 72 |
+
docs/_build/
|
| 73 |
+
|
| 74 |
+
# PyBuilder
|
| 75 |
+
.pybuilder/
|
| 76 |
+
target/
|
| 77 |
+
|
| 78 |
+
# Jupyter Notebook
|
| 79 |
+
.ipynb_checkpoints
|
| 80 |
+
|
| 81 |
+
# IPython
|
| 82 |
+
profile_default/
|
| 83 |
+
ipython_config.py
|
| 84 |
+
|
| 85 |
+
# pyenv
|
| 86 |
+
# For a library or package, you might want to ignore these files since the code is
|
| 87 |
+
# intended to run in multiple environments; otherwise, check them in:
|
| 88 |
+
# .python-version
|
| 89 |
+
|
| 90 |
+
# pipenv
|
| 91 |
+
# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
|
| 92 |
+
# However, in case of collaboration, if having platform-specific dependencies or dependencies
|
| 93 |
+
# having no cross-platform support, pipenv may install dependencies that don't work, or not
|
| 94 |
+
# install all needed dependencies.
|
| 95 |
+
#Pipfile.lock
|
| 96 |
+
|
| 97 |
+
# UV
|
| 98 |
+
# Similar to Pipfile.lock, it is generally recommended to include uv.lock in version control.
|
| 99 |
+
# This is especially recommended for binary packages to ensure reproducibility, and is more
|
| 100 |
+
# commonly ignored for libraries.
|
| 101 |
+
#uv.lock
|
| 102 |
+
|
| 103 |
+
# poetry
|
| 104 |
+
# Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
|
| 105 |
+
# This is especially recommended for binary packages to ensure reproducibility, and is more
|
| 106 |
+
# commonly ignored for libraries.
|
| 107 |
+
# https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
|
| 108 |
+
#poetry.lock
|
| 109 |
+
#poetry.toml
|
| 110 |
+
|
| 111 |
+
# pdm
|
| 112 |
+
# Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
|
| 113 |
+
# pdm recommends including project-wide configuration in pdm.toml, but excluding .pdm-python.
|
| 114 |
+
# https://pdm-project.org/en/latest/usage/project/#working-with-version-control
|
| 115 |
+
#pdm.lock
|
| 116 |
+
#pdm.toml
|
| 117 |
+
.pdm-python
|
| 118 |
+
.pdm-build/
|
| 119 |
+
|
| 120 |
+
# pixi
|
| 121 |
+
# Similar to Pipfile.lock, it is generally recommended to include pixi.lock in version control.
|
| 122 |
+
#pixi.lock
|
| 123 |
+
# Pixi creates a virtual environment in the .pixi directory, just like venv module creates one
|
| 124 |
+
# in the .venv directory. It is recommended not to include this directory in version control.
|
| 125 |
+
.pixi
|
| 126 |
+
|
| 127 |
+
# PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
|
| 128 |
+
__pypackages__/
|
| 129 |
+
|
| 130 |
+
# Celery stuff
|
| 131 |
+
celerybeat-schedule
|
| 132 |
+
celerybeat.pid
|
| 133 |
+
|
| 134 |
+
# SageMath parsed files
|
| 135 |
+
*.sage.py
|
| 136 |
+
|
| 137 |
+
# Environments
|
| 138 |
+
.env
|
| 139 |
+
.envrc
|
| 140 |
+
.venv
|
| 141 |
+
env/
|
| 142 |
+
venv/
|
| 143 |
+
ENV/
|
| 144 |
+
env.bak/
|
| 145 |
+
venv.bak/
|
| 146 |
+
|
| 147 |
+
# Spyder project settings
|
| 148 |
+
.spyderproject
|
| 149 |
+
.spyproject
|
| 150 |
+
|
| 151 |
+
# Rope project settings
|
| 152 |
+
.ropeproject
|
| 153 |
+
|
| 154 |
+
# mkdocs documentation
|
| 155 |
+
/site
|
| 156 |
+
|
| 157 |
+
# mypy
|
| 158 |
+
.mypy_cache/
|
| 159 |
+
.dmypy.json
|
| 160 |
+
dmypy.json
|
| 161 |
+
|
| 162 |
+
# Pyre type checker
|
| 163 |
+
.pyre/
|
| 164 |
+
|
| 165 |
+
# pytype static type analyzer
|
| 166 |
+
.pytype/
|
| 167 |
+
|
| 168 |
+
# Cython debug symbols
|
| 169 |
+
cython_debug/
|
| 170 |
+
|
| 171 |
+
# PyCharm
|
| 172 |
+
# JetBrains specific template is maintained in a separate JetBrains.gitignore that can
|
| 173 |
+
# be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
|
| 174 |
+
# and can be added to the global gitignore or merged into this file. For a more nuclear
|
| 175 |
+
# option (not recommended) you can uncomment the following to ignore the entire idea folder.
|
| 176 |
+
#.idea/
|
| 177 |
+
|
| 178 |
+
# Abstra
|
| 179 |
+
# Abstra is an AI-powered process automation framework.
|
| 180 |
+
# Ignore directories containing user credentials, local state, and settings.
|
| 181 |
+
# Learn more at https://abstra.io/docs
|
| 182 |
+
.abstra/
|
| 183 |
+
|
| 184 |
+
# Visual Studio Code
|
| 185 |
+
# Visual Studio Code specific template is maintained in a separate VisualStudioCode.gitignore
|
| 186 |
+
# that can be found at https://github.com/github/gitignore/blob/main/Global/VisualStudioCode.gitignore
|
| 187 |
+
# and can be added to the global gitignore or merged into this file. However, if you prefer,
|
| 188 |
+
# you could uncomment the following to ignore the entire vscode folder
|
| 189 |
+
# .vscode/
|
| 190 |
+
|
| 191 |
+
# Ruff stuff:
|
| 192 |
+
.ruff_cache/
|
| 193 |
+
|
| 194 |
+
# PyPI configuration file
|
| 195 |
+
.pypirc
|
| 196 |
+
|
| 197 |
+
# Cursor
|
| 198 |
+
# Cursor is an AI-powered code editor. `.cursorignore` specifies files/directories to
|
| 199 |
+
# exclude from AI features like autocomplete and code analysis. Recommended for sensitive data
|
| 200 |
+
# refer to https://docs.cursor.com/context/ignore-files
|
| 201 |
+
.cursorignore
|
| 202 |
+
.cursorindexingignore
|
| 203 |
+
|
| 204 |
+
# Marimo
|
| 205 |
+
marimo/_static/
|
| 206 |
+
marimo/_lsp/
|
| 207 |
+
__marimo__/
|
| 208 |
+
|
| 209 |
+
# macOS
|
| 210 |
+
.DS_Store
|
custom_nodes/ComfyUI-FlashVSR_Ultra_Fast/LICENSE
ADDED
|
@@ -0,0 +1,674 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
GNU GENERAL PUBLIC LICENSE
|
| 2 |
+
Version 3, 29 June 2007
|
| 3 |
+
|
| 4 |
+
Copyright (C) 2007 Free Software Foundation, Inc. <https://fsf.org/>
|
| 5 |
+
Everyone is permitted to copy and distribute verbatim copies
|
| 6 |
+
of this license document, but changing it is not allowed.
|
| 7 |
+
|
| 8 |
+
Preamble
|
| 9 |
+
|
| 10 |
+
The GNU General Public License is a free, copyleft license for
|
| 11 |
+
software and other kinds of works.
|
| 12 |
+
|
| 13 |
+
The licenses for most software and other practical works are designed
|
| 14 |
+
to take away your freedom to share and change the works. By contrast,
|
| 15 |
+
the GNU General Public License is intended to guarantee your freedom to
|
| 16 |
+
share and change all versions of a program--to make sure it remains free
|
| 17 |
+
software for all its users. We, the Free Software Foundation, use the
|
| 18 |
+
GNU General Public License for most of our software; it applies also to
|
| 19 |
+
any other work released this way by its authors. You can apply it to
|
| 20 |
+
your programs, too.
|
| 21 |
+
|
| 22 |
+
When we speak of free software, we are referring to freedom, not
|
| 23 |
+
price. Our General Public Licenses are designed to make sure that you
|
| 24 |
+
have the freedom to distribute copies of free software (and charge for
|
| 25 |
+
them if you wish), that you receive source code or can get it if you
|
| 26 |
+
want it, that you can change the software or use pieces of it in new
|
| 27 |
+
free programs, and that you know you can do these things.
|
| 28 |
+
|
| 29 |
+
To protect your rights, we need to prevent others from denying you
|
| 30 |
+
these rights or asking you to surrender the rights. Therefore, you have
|
| 31 |
+
certain responsibilities if you distribute copies of the software, or if
|
| 32 |
+
you modify it: responsibilities to respect the freedom of others.
|
| 33 |
+
|
| 34 |
+
For example, if you distribute copies of such a program, whether
|
| 35 |
+
gratis or for a fee, you must pass on to the recipients the same
|
| 36 |
+
freedoms that you received. You must make sure that they, too, receive
|
| 37 |
+
or can get the source code. And you must show them these terms so they
|
| 38 |
+
know their rights.
|
| 39 |
+
|
| 40 |
+
Developers that use the GNU GPL protect your rights with two steps:
|
| 41 |
+
(1) assert copyright on the software, and (2) offer you this License
|
| 42 |
+
giving you legal permission to copy, distribute and/or modify it.
|
| 43 |
+
|
| 44 |
+
For the developers' and authors' protection, the GPL clearly explains
|
| 45 |
+
that there is no warranty for this free software. For both users' and
|
| 46 |
+
authors' sake, the GPL requires that modified versions be marked as
|
| 47 |
+
changed, so that their problems will not be attributed erroneously to
|
| 48 |
+
authors of previous versions.
|
| 49 |
+
|
| 50 |
+
Some devices are designed to deny users access to install or run
|
| 51 |
+
modified versions of the software inside them, although the manufacturer
|
| 52 |
+
can do so. This is fundamentally incompatible with the aim of
|
| 53 |
+
protecting users' freedom to change the software. The systematic
|
| 54 |
+
pattern of such abuse occurs in the area of products for individuals to
|
| 55 |
+
use, which is precisely where it is most unacceptable. Therefore, we
|
| 56 |
+
have designed this version of the GPL to prohibit the practice for those
|
| 57 |
+
products. If such problems arise substantially in other domains, we
|
| 58 |
+
stand ready to extend this provision to those domains in future versions
|
| 59 |
+
of the GPL, as needed to protect the freedom of users.
|
| 60 |
+
|
| 61 |
+
Finally, every program is threatened constantly by software patents.
|
| 62 |
+
States should not allow patents to restrict development and use of
|
| 63 |
+
software on general-purpose computers, but in those that do, we wish to
|
| 64 |
+
avoid the special danger that patents applied to a free program could
|
| 65 |
+
make it effectively proprietary. To prevent this, the GPL assures that
|
| 66 |
+
patents cannot be used to render the program non-free.
|
| 67 |
+
|
| 68 |
+
The precise terms and conditions for copying, distribution and
|
| 69 |
+
modification follow.
|
| 70 |
+
|
| 71 |
+
TERMS AND CONDITIONS
|
| 72 |
+
|
| 73 |
+
0. Definitions.
|
| 74 |
+
|
| 75 |
+
"This License" refers to version 3 of the GNU General Public License.
|
| 76 |
+
|
| 77 |
+
"Copyright" also means copyright-like laws that apply to other kinds of
|
| 78 |
+
works, such as semiconductor masks.
|
| 79 |
+
|
| 80 |
+
"The Program" refers to any copyrightable work licensed under this
|
| 81 |
+
License. Each licensee is addressed as "you". "Licensees" and
|
| 82 |
+
"recipients" may be individuals or organizations.
|
| 83 |
+
|
| 84 |
+
To "modify" a work means to copy from or adapt all or part of the work
|
| 85 |
+
in a fashion requiring copyright permission, other than the making of an
|
| 86 |
+
exact copy. The resulting work is called a "modified version" of the
|
| 87 |
+
earlier work or a work "based on" the earlier work.
|
| 88 |
+
|
| 89 |
+
A "covered work" means either the unmodified Program or a work based
|
| 90 |
+
on the Program.
|
| 91 |
+
|
| 92 |
+
To "propagate" a work means to do anything with it that, without
|
| 93 |
+
permission, would make you directly or secondarily liable for
|
| 94 |
+
infringement under applicable copyright law, except executing it on a
|
| 95 |
+
computer or modifying a private copy. Propagation includes copying,
|
| 96 |
+
distribution (with or without modification), making available to the
|
| 97 |
+
public, and in some countries other activities as well.
|
| 98 |
+
|
| 99 |
+
To "convey" a work means any kind of propagation that enables other
|
| 100 |
+
parties to make or receive copies. Mere interaction with a user through
|
| 101 |
+
a computer network, with no transfer of a copy, is not conveying.
|
| 102 |
+
|
| 103 |
+
An interactive user interface displays "Appropriate Legal Notices"
|
| 104 |
+
to the extent that it includes a convenient and prominently visible
|
| 105 |
+
feature that (1) displays an appropriate copyright notice, and (2)
|
| 106 |
+
tells the user that there is no warranty for the work (except to the
|
| 107 |
+
extent that warranties are provided), that licensees may convey the
|
| 108 |
+
work under this License, and how to view a copy of this License. If
|
| 109 |
+
the interface presents a list of user commands or options, such as a
|
| 110 |
+
menu, a prominent item in the list meets this criterion.
|
| 111 |
+
|
| 112 |
+
1. Source Code.
|
| 113 |
+
|
| 114 |
+
The "source code" for a work means the preferred form of the work
|
| 115 |
+
for making modifications to it. "Object code" means any non-source
|
| 116 |
+
form of a work.
|
| 117 |
+
|
| 118 |
+
A "Standard Interface" means an interface that either is an official
|
| 119 |
+
standard defined by a recognized standards body, or, in the case of
|
| 120 |
+
interfaces specified for a particular programming language, one that
|
| 121 |
+
is widely used among developers working in that language.
|
| 122 |
+
|
| 123 |
+
The "System Libraries" of an executable work include anything, other
|
| 124 |
+
than the work as a whole, that (a) is included in the normal form of
|
| 125 |
+
packaging a Major Component, but which is not part of that Major
|
| 126 |
+
Component, and (b) serves only to enable use of the work with that
|
| 127 |
+
Major Component, or to implement a Standard Interface for which an
|
| 128 |
+
implementation is available to the public in source code form. A
|
| 129 |
+
"Major Component", in this context, means a major essential component
|
| 130 |
+
(kernel, window system, and so on) of the specific operating system
|
| 131 |
+
(if any) on which the executable work runs, or a compiler used to
|
| 132 |
+
produce the work, or an object code interpreter used to run it.
|
| 133 |
+
|
| 134 |
+
The "Corresponding Source" for a work in object code form means all
|
| 135 |
+
the source code needed to generate, install, and (for an executable
|
| 136 |
+
work) run the object code and to modify the work, including scripts to
|
| 137 |
+
control those activities. However, it does not include the work's
|
| 138 |
+
System Libraries, or general-purpose tools or generally available free
|
| 139 |
+
programs which are used unmodified in performing those activities but
|
| 140 |
+
which are not part of the work. For example, Corresponding Source
|
| 141 |
+
includes interface definition files associated with source files for
|
| 142 |
+
the work, and the source code for shared libraries and dynamically
|
| 143 |
+
linked subprograms that the work is specifically designed to require,
|
| 144 |
+
such as by intimate data communication or control flow between those
|
| 145 |
+
subprograms and other parts of the work.
|
| 146 |
+
|
| 147 |
+
The Corresponding Source need not include anything that users
|
| 148 |
+
can regenerate automatically from other parts of the Corresponding
|
| 149 |
+
Source.
|
| 150 |
+
|
| 151 |
+
The Corresponding Source for a work in source code form is that
|
| 152 |
+
same work.
|
| 153 |
+
|
| 154 |
+
2. Basic Permissions.
|
| 155 |
+
|
| 156 |
+
All rights granted under this License are granted for the term of
|
| 157 |
+
copyright on the Program, and are irrevocable provided the stated
|
| 158 |
+
conditions are met. This License explicitly affirms your unlimited
|
| 159 |
+
permission to run the unmodified Program. The output from running a
|
| 160 |
+
covered work is covered by this License only if the output, given its
|
| 161 |
+
content, constitutes a covered work. This License acknowledges your
|
| 162 |
+
rights of fair use or other equivalent, as provided by copyright law.
|
| 163 |
+
|
| 164 |
+
You may make, run and propagate covered works that you do not
|
| 165 |
+
convey, without conditions so long as your license otherwise remains
|
| 166 |
+
in force. You may convey covered works to others for the sole purpose
|
| 167 |
+
of having them make modifications exclusively for you, or provide you
|
| 168 |
+
with facilities for running those works, provided that you comply with
|
| 169 |
+
the terms of this License in conveying all material for which you do
|
| 170 |
+
not control copyright. Those thus making or running the covered works
|
| 171 |
+
for you must do so exclusively on your behalf, under your direction
|
| 172 |
+
and control, on terms that prohibit them from making any copies of
|
| 173 |
+
your copyrighted material outside their relationship with you.
|
| 174 |
+
|
| 175 |
+
Conveying under any other circumstances is permitted solely under
|
| 176 |
+
the conditions stated below. Sublicensing is not allowed; section 10
|
| 177 |
+
makes it unnecessary.
|
| 178 |
+
|
| 179 |
+
3. Protecting Users' Legal Rights From Anti-Circumvention Law.
|
| 180 |
+
|
| 181 |
+
No covered work shall be deemed part of an effective technological
|
| 182 |
+
measure under any applicable law fulfilling obligations under article
|
| 183 |
+
11 of the WIPO copyright treaty adopted on 20 December 1996, or
|
| 184 |
+
similar laws prohibiting or restricting circumvention of such
|
| 185 |
+
measures.
|
| 186 |
+
|
| 187 |
+
When you convey a covered work, you waive any legal power to forbid
|
| 188 |
+
circumvention of technological measures to the extent such circumvention
|
| 189 |
+
is effected by exercising rights under this License with respect to
|
| 190 |
+
the covered work, and you disclaim any intention to limit operation or
|
| 191 |
+
modification of the work as a means of enforcing, against the work's
|
| 192 |
+
users, your or third parties' legal rights to forbid circumvention of
|
| 193 |
+
technological measures.
|
| 194 |
+
|
| 195 |
+
4. Conveying Verbatim Copies.
|
| 196 |
+
|
| 197 |
+
You may convey verbatim copies of the Program's source code as you
|
| 198 |
+
receive it, in any medium, provided that you conspicuously and
|
| 199 |
+
appropriately publish on each copy an appropriate copyright notice;
|
| 200 |
+
keep intact all notices stating that this License and any
|
| 201 |
+
non-permissive terms added in accord with section 7 apply to the code;
|
| 202 |
+
keep intact all notices of the absence of any warranty; and give all
|
| 203 |
+
recipients a copy of this License along with the Program.
|
| 204 |
+
|
| 205 |
+
You may charge any price or no price for each copy that you convey,
|
| 206 |
+
and you may offer support or warranty protection for a fee.
|
| 207 |
+
|
| 208 |
+
5. Conveying Modified Source Versions.
|
| 209 |
+
|
| 210 |
+
You may convey a work based on the Program, or the modifications to
|
| 211 |
+
produce it from the Program, in the form of source code under the
|
| 212 |
+
terms of section 4, provided that you also meet all of these conditions:
|
| 213 |
+
|
| 214 |
+
a) The work must carry prominent notices stating that you modified
|
| 215 |
+
it, and giving a relevant date.
|
| 216 |
+
|
| 217 |
+
b) The work must carry prominent notices stating that it is
|
| 218 |
+
released under this License and any conditions added under section
|
| 219 |
+
7. This requirement modifies the requirement in section 4 to
|
| 220 |
+
"keep intact all notices".
|
| 221 |
+
|
| 222 |
+
c) You must license the entire work, as a whole, under this
|
| 223 |
+
License to anyone who comes into possession of a copy. This
|
| 224 |
+
License will therefore apply, along with any applicable section 7
|
| 225 |
+
additional terms, to the whole of the work, and all its parts,
|
| 226 |
+
regardless of how they are packaged. This License gives no
|
| 227 |
+
permission to license the work in any other way, but it does not
|
| 228 |
+
invalidate such permission if you have separately received it.
|
| 229 |
+
|
| 230 |
+
d) If the work has interactive user interfaces, each must display
|
| 231 |
+
Appropriate Legal Notices; however, if the Program has interactive
|
| 232 |
+
interfaces that do not display Appropriate Legal Notices, your
|
| 233 |
+
work need not make them do so.
|
| 234 |
+
|
| 235 |
+
A compilation of a covered work with other separate and independent
|
| 236 |
+
works, which are not by their nature extensions of the covered work,
|
| 237 |
+
and which are not combined with it such as to form a larger program,
|
| 238 |
+
in or on a volume of a storage or distribution medium, is called an
|
| 239 |
+
"aggregate" if the compilation and its resulting copyright are not
|
| 240 |
+
used to limit the access or legal rights of the compilation's users
|
| 241 |
+
beyond what the individual works permit. Inclusion of a covered work
|
| 242 |
+
in an aggregate does not cause this License to apply to the other
|
| 243 |
+
parts of the aggregate.
|
| 244 |
+
|
| 245 |
+
6. Conveying Non-Source Forms.
|
| 246 |
+
|
| 247 |
+
You may convey a covered work in object code form under the terms
|
| 248 |
+
of sections 4 and 5, provided that you also convey the
|
| 249 |
+
machine-readable Corresponding Source under the terms of this License,
|
| 250 |
+
in one of these ways:
|
| 251 |
+
|
| 252 |
+
a) Convey the object code in, or embodied in, a physical product
|
| 253 |
+
(including a physical distribution medium), accompanied by the
|
| 254 |
+
Corresponding Source fixed on a durable physical medium
|
| 255 |
+
customarily used for software interchange.
|
| 256 |
+
|
| 257 |
+
b) Convey the object code in, or embodied in, a physical product
|
| 258 |
+
(including a physical distribution medium), accompanied by a
|
| 259 |
+
written offer, valid for at least three years and valid for as
|
| 260 |
+
long as you offer spare parts or customer support for that product
|
| 261 |
+
model, to give anyone who possesses the object code either (1) a
|
| 262 |
+
copy of the Corresponding Source for all the software in the
|
| 263 |
+
product that is covered by this License, on a durable physical
|
| 264 |
+
medium customarily used for software interchange, for a price no
|
| 265 |
+
more than your reasonable cost of physically performing this
|
| 266 |
+
conveying of source, or (2) access to copy the
|
| 267 |
+
Corresponding Source from a network server at no charge.
|
| 268 |
+
|
| 269 |
+
c) Convey individual copies of the object code with a copy of the
|
| 270 |
+
written offer to provide the Corresponding Source. This
|
| 271 |
+
alternative is allowed only occasionally and noncommercially, and
|
| 272 |
+
only if you received the object code with such an offer, in accord
|
| 273 |
+
with subsection 6b.
|
| 274 |
+
|
| 275 |
+
d) Convey the object code by offering access from a designated
|
| 276 |
+
place (gratis or for a charge), and offer equivalent access to the
|
| 277 |
+
Corresponding Source in the same way through the same place at no
|
| 278 |
+
further charge. You need not require recipients to copy the
|
| 279 |
+
Corresponding Source along with the object code. If the place to
|
| 280 |
+
copy the object code is a network server, the Corresponding Source
|
| 281 |
+
may be on a different server (operated by you or a third party)
|
| 282 |
+
that supports equivalent copying facilities, provided you maintain
|
| 283 |
+
clear directions next to the object code saying where to find the
|
| 284 |
+
Corresponding Source. Regardless of what server hosts the
|
| 285 |
+
Corresponding Source, you remain obligated to ensure that it is
|
| 286 |
+
available for as long as needed to satisfy these requirements.
|
| 287 |
+
|
| 288 |
+
e) Convey the object code using peer-to-peer transmission, provided
|
| 289 |
+
you inform other peers where the object code and Corresponding
|
| 290 |
+
Source of the work are being offered to the general public at no
|
| 291 |
+
charge under subsection 6d.
|
| 292 |
+
|
| 293 |
+
A separable portion of the object code, whose source code is excluded
|
| 294 |
+
from the Corresponding Source as a System Library, need not be
|
| 295 |
+
included in conveying the object code work.
|
| 296 |
+
|
| 297 |
+
A "User Product" is either (1) a "consumer product", which means any
|
| 298 |
+
tangible personal property which is normally used for personal, family,
|
| 299 |
+
or household purposes, or (2) anything designed or sold for incorporation
|
| 300 |
+
into a dwelling. In determining whether a product is a consumer product,
|
| 301 |
+
doubtful cases shall be resolved in favor of coverage. For a particular
|
| 302 |
+
product received by a particular user, "normally used" refers to a
|
| 303 |
+
typical or common use of that class of product, regardless of the status
|
| 304 |
+
of the particular user or of the way in which the particular user
|
| 305 |
+
actually uses, or expects or is expected to use, the product. A product
|
| 306 |
+
is a consumer product regardless of whether the product has substantial
|
| 307 |
+
commercial, industrial or non-consumer uses, unless such uses represent
|
| 308 |
+
the only significant mode of use of the product.
|
| 309 |
+
|
| 310 |
+
"Installation Information" for a User Product means any methods,
|
| 311 |
+
procedures, authorization keys, or other information required to install
|
| 312 |
+
and execute modified versions of a covered work in that User Product from
|
| 313 |
+
a modified version of its Corresponding Source. The information must
|
| 314 |
+
suffice to ensure that the continued functioning of the modified object
|
| 315 |
+
code is in no case prevented or interfered with solely because
|
| 316 |
+
modification has been made.
|
| 317 |
+
|
| 318 |
+
If you convey an object code work under this section in, or with, or
|
| 319 |
+
specifically for use in, a User Product, and the conveying occurs as
|
| 320 |
+
part of a transaction in which the right of possession and use of the
|
| 321 |
+
User Product is transferred to the recipient in perpetuity or for a
|
| 322 |
+
fixed term (regardless of how the transaction is characterized), the
|
| 323 |
+
Corresponding Source conveyed under this section must be accompanied
|
| 324 |
+
by the Installation Information. But this requirement does not apply
|
| 325 |
+
if neither you nor any third party retains the ability to install
|
| 326 |
+
modified object code on the User Product (for example, the work has
|
| 327 |
+
been installed in ROM).
|
| 328 |
+
|
| 329 |
+
The requirement to provide Installation Information does not include a
|
| 330 |
+
requirement to continue to provide support service, warranty, or updates
|
| 331 |
+
for a work that has been modified or installed by the recipient, or for
|
| 332 |
+
the User Product in which it has been modified or installed. Access to a
|
| 333 |
+
network may be denied when the modification itself materially and
|
| 334 |
+
adversely affects the operation of the network or violates the rules and
|
| 335 |
+
protocols for communication across the network.
|
| 336 |
+
|
| 337 |
+
Corresponding Source conveyed, and Installation Information provided,
|
| 338 |
+
in accord with this section must be in a format that is publicly
|
| 339 |
+
documented (and with an implementation available to the public in
|
| 340 |
+
source code form), and must require no special password or key for
|
| 341 |
+
unpacking, reading or copying.
|
| 342 |
+
|
| 343 |
+
7. Additional Terms.
|
| 344 |
+
|
| 345 |
+
"Additional permissions" are terms that supplement the terms of this
|
| 346 |
+
License by making exceptions from one or more of its conditions.
|
| 347 |
+
Additional permissions that are applicable to the entire Program shall
|
| 348 |
+
be treated as though they were included in this License, to the extent
|
| 349 |
+
that they are valid under applicable law. If additional permissions
|
| 350 |
+
apply only to part of the Program, that part may be used separately
|
| 351 |
+
under those permissions, but the entire Program remains governed by
|
| 352 |
+
this License without regard to the additional permissions.
|
| 353 |
+
|
| 354 |
+
When you convey a copy of a covered work, you may at your option
|
| 355 |
+
remove any additional permissions from that copy, or from any part of
|
| 356 |
+
it. (Additional permissions may be written to require their own
|
| 357 |
+
removal in certain cases when you modify the work.) You may place
|
| 358 |
+
additional permissions on material, added by you to a covered work,
|
| 359 |
+
for which you have or can give appropriate copyright permission.
|
| 360 |
+
|
| 361 |
+
Notwithstanding any other provision of this License, for material you
|
| 362 |
+
add to a covered work, you may (if authorized by the copyright holders of
|
| 363 |
+
that material) supplement the terms of this License with terms:
|
| 364 |
+
|
| 365 |
+
a) Disclaiming warranty or limiting liability differently from the
|
| 366 |
+
terms of sections 15 and 16 of this License; or
|
| 367 |
+
|
| 368 |
+
b) Requiring preservation of specified reasonable legal notices or
|
| 369 |
+
author attributions in that material or in the Appropriate Legal
|
| 370 |
+
Notices displayed by works containing it; or
|
| 371 |
+
|
| 372 |
+
c) Prohibiting misrepresentation of the origin of that material, or
|
| 373 |
+
requiring that modified versions of such material be marked in
|
| 374 |
+
reasonable ways as different from the original version; or
|
| 375 |
+
|
| 376 |
+
d) Limiting the use for publicity purposes of names of licensors or
|
| 377 |
+
authors of the material; or
|
| 378 |
+
|
| 379 |
+
e) Declining to grant rights under trademark law for use of some
|
| 380 |
+
trade names, trademarks, or service marks; or
|
| 381 |
+
|
| 382 |
+
f) Requiring indemnification of licensors and authors of that
|
| 383 |
+
material by anyone who conveys the material (or modified versions of
|
| 384 |
+
it) with contractual assumptions of liability to the recipient, for
|
| 385 |
+
any liability that these contractual assumptions directly impose on
|
| 386 |
+
those licensors and authors.
|
| 387 |
+
|
| 388 |
+
All other non-permissive additional terms are considered "further
|
| 389 |
+
restrictions" within the meaning of section 10. If the Program as you
|
| 390 |
+
received it, or any part of it, contains a notice stating that it is
|
| 391 |
+
governed by this License along with a term that is a further
|
| 392 |
+
restriction, you may remove that term. If a license document contains
|
| 393 |
+
a further restriction but permits relicensing or conveying under this
|
| 394 |
+
License, you may add to a covered work material governed by the terms
|
| 395 |
+
of that license document, provided that the further restriction does
|
| 396 |
+
not survive such relicensing or conveying.
|
| 397 |
+
|
| 398 |
+
If you add terms to a covered work in accord with this section, you
|
| 399 |
+
must place, in the relevant source files, a statement of the
|
| 400 |
+
additional terms that apply to those files, or a notice indicating
|
| 401 |
+
where to find the applicable terms.
|
| 402 |
+
|
| 403 |
+
Additional terms, permissive or non-permissive, may be stated in the
|
| 404 |
+
form of a separately written license, or stated as exceptions;
|
| 405 |
+
the above requirements apply either way.
|
| 406 |
+
|
| 407 |
+
8. Termination.
|
| 408 |
+
|
| 409 |
+
You may not propagate or modify a covered work except as expressly
|
| 410 |
+
provided under this License. Any attempt otherwise to propagate or
|
| 411 |
+
modify it is void, and will automatically terminate your rights under
|
| 412 |
+
this License (including any patent licenses granted under the third
|
| 413 |
+
paragraph of section 11).
|
| 414 |
+
|
| 415 |
+
However, if you cease all violation of this License, then your
|
| 416 |
+
license from a particular copyright holder is reinstated (a)
|
| 417 |
+
provisionally, unless and until the copyright holder explicitly and
|
| 418 |
+
finally terminates your license, and (b) permanently, if the copyright
|
| 419 |
+
holder fails to notify you of the violation by some reasonable means
|
| 420 |
+
prior to 60 days after the cessation.
|
| 421 |
+
|
| 422 |
+
Moreover, your license from a particular copyright holder is
|
| 423 |
+
reinstated permanently if the copyright holder notifies you of the
|
| 424 |
+
violation by some reasonable means, this is the first time you have
|
| 425 |
+
received notice of violation of this License (for any work) from that
|
| 426 |
+
copyright holder, and you cure the violation prior to 30 days after
|
| 427 |
+
your receipt of the notice.
|
| 428 |
+
|
| 429 |
+
Termination of your rights under this section does not terminate the
|
| 430 |
+
licenses of parties who have received copies or rights from you under
|
| 431 |
+
this License. If your rights have been terminated and not permanently
|
| 432 |
+
reinstated, you do not qualify to receive new licenses for the same
|
| 433 |
+
material under section 10.
|
| 434 |
+
|
| 435 |
+
9. Acceptance Not Required for Having Copies.
|
| 436 |
+
|
| 437 |
+
You are not required to accept this License in order to receive or
|
| 438 |
+
run a copy of the Program. Ancillary propagation of a covered work
|
| 439 |
+
occurring solely as a consequence of using peer-to-peer transmission
|
| 440 |
+
to receive a copy likewise does not require acceptance. However,
|
| 441 |
+
nothing other than this License grants you permission to propagate or
|
| 442 |
+
modify any covered work. These actions infringe copyright if you do
|
| 443 |
+
not accept this License. Therefore, by modifying or propagating a
|
| 444 |
+
covered work, you indicate your acceptance of this License to do so.
|
| 445 |
+
|
| 446 |
+
10. Automatic Licensing of Downstream Recipients.
|
| 447 |
+
|
| 448 |
+
Each time you convey a covered work, the recipient automatically
|
| 449 |
+
receives a license from the original licensors, to run, modify and
|
| 450 |
+
propagate that work, subject to this License. You are not responsible
|
| 451 |
+
for enforcing compliance by third parties with this License.
|
| 452 |
+
|
| 453 |
+
An "entity transaction" is a transaction transferring control of an
|
| 454 |
+
organization, or substantially all assets of one, or subdividing an
|
| 455 |
+
organization, or merging organizations. If propagation of a covered
|
| 456 |
+
work results from an entity transaction, each party to that
|
| 457 |
+
transaction who receives a copy of the work also receives whatever
|
| 458 |
+
licenses to the work the party's predecessor in interest had or could
|
| 459 |
+
give under the previous paragraph, plus a right to possession of the
|
| 460 |
+
Corresponding Source of the work from the predecessor in interest, if
|
| 461 |
+
the predecessor has it or can get it with reasonable efforts.
|
| 462 |
+
|
| 463 |
+
You may not impose any further restrictions on the exercise of the
|
| 464 |
+
rights granted or affirmed under this License. For example, you may
|
| 465 |
+
not impose a license fee, royalty, or other charge for exercise of
|
| 466 |
+
rights granted under this License, and you may not initiate litigation
|
| 467 |
+
(including a cross-claim or counterclaim in a lawsuit) alleging that
|
| 468 |
+
any patent claim is infringed by making, using, selling, offering for
|
| 469 |
+
sale, or importing the Program or any portion of it.
|
| 470 |
+
|
| 471 |
+
11. Patents.
|
| 472 |
+
|
| 473 |
+
A "contributor" is a copyright holder who authorizes use under this
|
| 474 |
+
License of the Program or a work on which the Program is based. The
|
| 475 |
+
work thus licensed is called the contributor's "contributor version".
|
| 476 |
+
|
| 477 |
+
A contributor's "essential patent claims" are all patent claims
|
| 478 |
+
owned or controlled by the contributor, whether already acquired or
|
| 479 |
+
hereafter acquired, that would be infringed by some manner, permitted
|
| 480 |
+
by this License, of making, using, or selling its contributor version,
|
| 481 |
+
but do not include claims that would be infringed only as a
|
| 482 |
+
consequence of further modification of the contributor version. For
|
| 483 |
+
purposes of this definition, "control" includes the right to grant
|
| 484 |
+
patent sublicenses in a manner consistent with the requirements of
|
| 485 |
+
this License.
|
| 486 |
+
|
| 487 |
+
Each contributor grants you a non-exclusive, worldwide, royalty-free
|
| 488 |
+
patent license under the contributor's essential patent claims, to
|
| 489 |
+
make, use, sell, offer for sale, import and otherwise run, modify and
|
| 490 |
+
propagate the contents of its contributor version.
|
| 491 |
+
|
| 492 |
+
In the following three paragraphs, a "patent license" is any express
|
| 493 |
+
agreement or commitment, however denominated, not to enforce a patent
|
| 494 |
+
(such as an express permission to practice a patent or covenant not to
|
| 495 |
+
sue for patent infringement). To "grant" such a patent license to a
|
| 496 |
+
party means to make such an agreement or commitment not to enforce a
|
| 497 |
+
patent against the party.
|
| 498 |
+
|
| 499 |
+
If you convey a covered work, knowingly relying on a patent license,
|
| 500 |
+
and the Corresponding Source of the work is not available for anyone
|
| 501 |
+
to copy, free of charge and under the terms of this License, through a
|
| 502 |
+
publicly available network server or other readily accessible means,
|
| 503 |
+
then you must either (1) cause the Corresponding Source to be so
|
| 504 |
+
available, or (2) arrange to deprive yourself of the benefit of the
|
| 505 |
+
patent license for this particular work, or (3) arrange, in a manner
|
| 506 |
+
consistent with the requirements of this License, to extend the patent
|
| 507 |
+
license to downstream recipients. "Knowingly relying" means you have
|
| 508 |
+
actual knowledge that, but for the patent license, your conveying the
|
| 509 |
+
covered work in a country, or your recipient's use of the covered work
|
| 510 |
+
in a country, would infringe one or more identifiable patents in that
|
| 511 |
+
country that you have reason to believe are valid.
|
| 512 |
+
|
| 513 |
+
If, pursuant to or in connection with a single transaction or
|
| 514 |
+
arrangement, you convey, or propagate by procuring conveyance of, a
|
| 515 |
+
covered work, and grant a patent license to some of the parties
|
| 516 |
+
receiving the covered work authorizing them to use, propagate, modify
|
| 517 |
+
or convey a specific copy of the covered work, then the patent license
|
| 518 |
+
you grant is automatically extended to all recipients of the covered
|
| 519 |
+
work and works based on it.
|
| 520 |
+
|
| 521 |
+
A patent license is "discriminatory" if it does not include within
|
| 522 |
+
the scope of its coverage, prohibits the exercise of, or is
|
| 523 |
+
conditioned on the non-exercise of one or more of the rights that are
|
| 524 |
+
specifically granted under this License. You may not convey a covered
|
| 525 |
+
work if you are a party to an arrangement with a third party that is
|
| 526 |
+
in the business of distributing software, under which you make payment
|
| 527 |
+
to the third party based on the extent of your activity of conveying
|
| 528 |
+
the work, and under which the third party grants, to any of the
|
| 529 |
+
parties who would receive the covered work from you, a discriminatory
|
| 530 |
+
patent license (a) in connection with copies of the covered work
|
| 531 |
+
conveyed by you (or copies made from those copies), or (b) primarily
|
| 532 |
+
for and in connection with specific products or compilations that
|
| 533 |
+
contain the covered work, unless you entered into that arrangement,
|
| 534 |
+
or that patent license was granted, prior to 28 March 2007.
|
| 535 |
+
|
| 536 |
+
Nothing in this License shall be construed as excluding or limiting
|
| 537 |
+
any implied license or other defenses to infringement that may
|
| 538 |
+
otherwise be available to you under applicable patent law.
|
| 539 |
+
|
| 540 |
+
12. No Surrender of Others' Freedom.
|
| 541 |
+
|
| 542 |
+
If conditions are imposed on you (whether by court order, agreement or
|
| 543 |
+
otherwise) that contradict the conditions of this License, they do not
|
| 544 |
+
excuse you from the conditions of this License. If you cannot convey a
|
| 545 |
+
covered work so as to satisfy simultaneously your obligations under this
|
| 546 |
+
License and any other pertinent obligations, then as a consequence you may
|
| 547 |
+
not convey it at all. For example, if you agree to terms that obligate you
|
| 548 |
+
to collect a royalty for further conveying from those to whom you convey
|
| 549 |
+
the Program, the only way you could satisfy both those terms and this
|
| 550 |
+
License would be to refrain entirely from conveying the Program.
|
| 551 |
+
|
| 552 |
+
13. Use with the GNU Affero General Public License.
|
| 553 |
+
|
| 554 |
+
Notwithstanding any other provision of this License, you have
|
| 555 |
+
permission to link or combine any covered work with a work licensed
|
| 556 |
+
under version 3 of the GNU Affero General Public License into a single
|
| 557 |
+
combined work, and to convey the resulting work. The terms of this
|
| 558 |
+
License will continue to apply to the part which is the covered work,
|
| 559 |
+
but the special requirements of the GNU Affero General Public License,
|
| 560 |
+
section 13, concerning interaction through a network will apply to the
|
| 561 |
+
combination as such.
|
| 562 |
+
|
| 563 |
+
14. Revised Versions of this License.
|
| 564 |
+
|
| 565 |
+
The Free Software Foundation may publish revised and/or new versions of
|
| 566 |
+
the GNU General Public License from time to time. Such new versions will
|
| 567 |
+
be similar in spirit to the present version, but may differ in detail to
|
| 568 |
+
address new problems or concerns.
|
| 569 |
+
|
| 570 |
+
Each version is given a distinguishing version number. If the
|
| 571 |
+
Program specifies that a certain numbered version of the GNU General
|
| 572 |
+
Public License "or any later version" applies to it, you have the
|
| 573 |
+
option of following the terms and conditions either of that numbered
|
| 574 |
+
version or of any later version published by the Free Software
|
| 575 |
+
Foundation. If the Program does not specify a version number of the
|
| 576 |
+
GNU General Public License, you may choose any version ever published
|
| 577 |
+
by the Free Software Foundation.
|
| 578 |
+
|
| 579 |
+
If the Program specifies that a proxy can decide which future
|
| 580 |
+
versions of the GNU General Public License can be used, that proxy's
|
| 581 |
+
public statement of acceptance of a version permanently authorizes you
|
| 582 |
+
to choose that version for the Program.
|
| 583 |
+
|
| 584 |
+
Later license versions may give you additional or different
|
| 585 |
+
permissions. However, no additional obligations are imposed on any
|
| 586 |
+
author or copyright holder as a result of your choosing to follow a
|
| 587 |
+
later version.
|
| 588 |
+
|
| 589 |
+
15. Disclaimer of Warranty.
|
| 590 |
+
|
| 591 |
+
THERE IS NO WARRANTY FOR THE PROGRAM, TO THE EXTENT PERMITTED BY
|
| 592 |
+
APPLICABLE LAW. EXCEPT WHEN OTHERWISE STATED IN WRITING THE COPYRIGHT
|
| 593 |
+
HOLDERS AND/OR OTHER PARTIES PROVIDE THE PROGRAM "AS IS" WITHOUT WARRANTY
|
| 594 |
+
OF ANY KIND, EITHER EXPRESSED OR IMPLIED, INCLUDING, BUT NOT LIMITED TO,
|
| 595 |
+
THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR
|
| 596 |
+
PURPOSE. THE ENTIRE RISK AS TO THE QUALITY AND PERFORMANCE OF THE PROGRAM
|
| 597 |
+
IS WITH YOU. SHOULD THE PROGRAM PROVE DEFECTIVE, YOU ASSUME THE COST OF
|
| 598 |
+
ALL NECESSARY SERVICING, REPAIR OR CORRECTION.
|
| 599 |
+
|
| 600 |
+
16. Limitation of Liability.
|
| 601 |
+
|
| 602 |
+
IN NO EVENT UNLESS REQUIRED BY APPLICABLE LAW OR AGREED TO IN WRITING
|
| 603 |
+
WILL ANY COPYRIGHT HOLDER, OR ANY OTHER PARTY WHO MODIFIES AND/OR CONVEYS
|
| 604 |
+
THE PROGRAM AS PERMITTED ABOVE, BE LIABLE TO YOU FOR DAMAGES, INCLUDING ANY
|
| 605 |
+
GENERAL, SPECIAL, INCIDENTAL OR CONSEQUENTIAL DAMAGES ARISING OUT OF THE
|
| 606 |
+
USE OR INABILITY TO USE THE PROGRAM (INCLUDING BUT NOT LIMITED TO LOSS OF
|
| 607 |
+
DATA OR DATA BEING RENDERED INACCURATE OR LOSSES SUSTAINED BY YOU OR THIRD
|
| 608 |
+
PARTIES OR A FAILURE OF THE PROGRAM TO OPERATE WITH ANY OTHER PROGRAMS),
|
| 609 |
+
EVEN IF SUCH HOLDER OR OTHER PARTY HAS BEEN ADVISED OF THE POSSIBILITY OF
|
| 610 |
+
SUCH DAMAGES.
|
| 611 |
+
|
| 612 |
+
17. Interpretation of Sections 15 and 16.
|
| 613 |
+
|
| 614 |
+
If the disclaimer of warranty and limitation of liability provided
|
| 615 |
+
above cannot be given local legal effect according to their terms,
|
| 616 |
+
reviewing courts shall apply local law that most closely approximates
|
| 617 |
+
an absolute waiver of all civil liability in connection with the
|
| 618 |
+
Program, unless a warranty or assumption of liability accompanies a
|
| 619 |
+
copy of the Program in return for a fee.
|
| 620 |
+
|
| 621 |
+
END OF TERMS AND CONDITIONS
|
| 622 |
+
|
| 623 |
+
How to Apply These Terms to Your New Programs
|
| 624 |
+
|
| 625 |
+
If you develop a new program, and you want it to be of the greatest
|
| 626 |
+
possible use to the public, the best way to achieve this is to make it
|
| 627 |
+
free software which everyone can redistribute and change under these terms.
|
| 628 |
+
|
| 629 |
+
To do so, attach the following notices to the program. It is safest
|
| 630 |
+
to attach them to the start of each source file to most effectively
|
| 631 |
+
state the exclusion of warranty; and each file should have at least
|
| 632 |
+
the "copyright" line and a pointer to where the full notice is found.
|
| 633 |
+
|
| 634 |
+
<one line to give the program's name and a brief idea of what it does.>
|
| 635 |
+
Copyright (C) <year> <name of author>
|
| 636 |
+
|
| 637 |
+
This program is free software: you can redistribute it and/or modify
|
| 638 |
+
it under the terms of the GNU General Public License as published by
|
| 639 |
+
the Free Software Foundation, either version 3 of the License, or
|
| 640 |
+
(at your option) any later version.
|
| 641 |
+
|
| 642 |
+
This program is distributed in the hope that it will be useful,
|
| 643 |
+
but WITHOUT ANY WARRANTY; without even the implied warranty of
|
| 644 |
+
MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
|
| 645 |
+
GNU General Public License for more details.
|
| 646 |
+
|
| 647 |
+
You should have received a copy of the GNU General Public License
|
| 648 |
+
along with this program. If not, see <https://www.gnu.org/licenses/>.
|
| 649 |
+
|
| 650 |
+
Also add information on how to contact you by electronic and paper mail.
|
| 651 |
+
|
| 652 |
+
If the program does terminal interaction, make it output a short
|
| 653 |
+
notice like this when it starts in an interactive mode:
|
| 654 |
+
|
| 655 |
+
<program> Copyright (C) <year> <name of author>
|
| 656 |
+
This program comes with ABSOLUTELY NO WARRANTY; for details type `show w'.
|
| 657 |
+
This is free software, and you are welcome to redistribute it
|
| 658 |
+
under certain conditions; type `show c' for details.
|
| 659 |
+
|
| 660 |
+
The hypothetical commands `show w' and `show c' should show the appropriate
|
| 661 |
+
parts of the General Public License. Of course, your program's commands
|
| 662 |
+
might be different; for a GUI interface, you would use an "about box".
|
| 663 |
+
|
| 664 |
+
You should also get your employer (if you work as a programmer) or school,
|
| 665 |
+
if any, to sign a "copyright disclaimer" for the program, if necessary.
|
| 666 |
+
For more information on this, and how to apply and follow the GNU GPL, see
|
| 667 |
+
<https://www.gnu.org/licenses/>.
|
| 668 |
+
|
| 669 |
+
The GNU General Public License does not permit incorporating your program
|
| 670 |
+
into proprietary programs. If your program is a subroutine library, you
|
| 671 |
+
may consider it more useful to permit linking proprietary applications with
|
| 672 |
+
the library. If this is what you want to do, use the GNU Lesser General
|
| 673 |
+
Public License instead of this License. But first, please read
|
| 674 |
+
<https://www.gnu.org/licenses/why-not-lgpl.html>.
|
custom_nodes/ComfyUI-FlashVSR_Ultra_Fast/README.md
ADDED
|
@@ -0,0 +1,68 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# ComfyUI-FlashVSR_Ultra_Fast
|
| 2 |
+
Running FlashVSR on lower VRAM without any artifacts.
|
| 3 |
+
**[[📃中文版本](./README_zh.md)]**
|
| 4 |
+
|
| 5 |
+
## Changelog
|
| 6 |
+
#### 2025-10-24
|
| 7 |
+
- Added long video pipeline that significantly reduces VRAM usage when upscaling long videos.
|
| 8 |
+
|
| 9 |
+
#### 2025-10-21
|
| 10 |
+
- Initial this project, introducing features such as `tile_dit` to significantly reducing VRAM usage.
|
| 11 |
+
|
| 12 |
+
#### 2025-10-22
|
| 13 |
+
- Replaced `Block-Sparse-Attention` with `Sparse_Sage`, removing the need to compile any custom kernels.
|
| 14 |
+
- Added support for running on RTX 50 series GPUs.
|
| 15 |
+
|
| 16 |
+
## Preview
|
| 17 |
+

|
| 18 |
+
|
| 19 |
+
## Usage
|
| 20 |
+
- **mode:**
|
| 21 |
+
`tiny` -> faster (default); `full` -> higher quality
|
| 22 |
+
- **scale:**
|
| 23 |
+
`4` is always better, unless you are low on VRAM then use `2`
|
| 24 |
+
- **color_fix:**
|
| 25 |
+
Use wavelet transform to correct the color of output video.
|
| 26 |
+
- **tiled_vae:**
|
| 27 |
+
Set to True for lower VRAM consumption during decoding at the cost of speed.
|
| 28 |
+
- **tiled_dit:**
|
| 29 |
+
Significantly reduces VRAM usage at the cost of speed.
|
| 30 |
+
- **tile\_size, tile\_overlap**:
|
| 31 |
+
How to split the input video.
|
| 32 |
+
- **unload_dit:**
|
| 33 |
+
Unload DiT before decoding to reduce VRAM peak at the cost of speed.
|
| 34 |
+
|
| 35 |
+
## Installation
|
| 36 |
+
|
| 37 |
+
#### nodes:
|
| 38 |
+
|
| 39 |
+
```bash
|
| 40 |
+
cd ComfyUI/custom_nodes
|
| 41 |
+
git clone https://github.com/lihaoyun6/ComfyUI-FlashVSR_Ultra_Fast.git
|
| 42 |
+
python -m pip install -r ComfyUI-FlashVSR_Ultra_Fast/requirements.txt
|
| 43 |
+
```
|
| 44 |
+
📢: For Turing or older GPU, please install `triton<3.3.0`:
|
| 45 |
+
|
| 46 |
+
```bash
|
| 47 |
+
# Windows
|
| 48 |
+
python -m pip install -U triton-windows<3.3.0
|
| 49 |
+
# Linux
|
| 50 |
+
python -m pip install -U triton<3.3.0
|
| 51 |
+
```
|
| 52 |
+
|
| 53 |
+
#### models:
|
| 54 |
+
|
| 55 |
+
- Download the entire `FlashVSR` folder with all the files inside it from [here](https://huggingface.co/JunhaoZhuang/FlashVSR) and put it in the `ComfyUI/models`
|
| 56 |
+
|
| 57 |
+
```
|
| 58 |
+
├── ComfyUI/models/FlashVSR
|
| 59 |
+
| ├── LQ_proj_in.ckpt
|
| 60 |
+
| ├── TCDecoder.ckpt
|
| 61 |
+
| ├── diffusion_pytorch_model_streaming_dmd.safetensors
|
| 62 |
+
| ├── Wan2.1_VAE.pth
|
| 63 |
+
```
|
| 64 |
+
|
| 65 |
+
## Acknowledgments
|
| 66 |
+
- [FlashVSR](https://github.com/OpenImagingLab/FlashVSR) @OpenImagingLab
|
| 67 |
+
- [Sparse_SageAttention](https://github.com/jt-zhang/Sparse_SageAttention_API) @jt-zhang
|
| 68 |
+
- [ComfyUI](https://github.com/comfyanonymous/ComfyUI) @comfyanonymous
|
custom_nodes/ComfyUI-FlashVSR_Ultra_Fast/README_zh.md
ADDED
|
@@ -0,0 +1,66 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# ComfyUI-FlashVSR_Ultra_Fast
|
| 2 |
+
在低显存环境下运行 FlashVSR,同时保持无伪影高质量输出。
|
| 3 |
+
**[[📃English](./readme.md)]**
|
| 4 |
+
|
| 5 |
+
## 更新日志
|
| 6 |
+
#### 2025-10-24
|
| 7 |
+
- 新增长视频管道, 可显著降低长视频放大的显存用量
|
| 8 |
+
|
| 9 |
+
#### 2025-10-21
|
| 10 |
+
- 项目首次发布, 引入了`tile_dit`等功能, 大幅度降低显存需求
|
| 11 |
+
|
| 12 |
+
#### 2025-10-22
|
| 13 |
+
- 使用`Sparse_SageAttention`替换了`Block-Sparse-Attention`, 无需编译安装任何自定义内核, 开箱即用.
|
| 14 |
+
- 支持在 RTX50 系列显卡上运行.
|
| 15 |
+
|
| 16 |
+
## 预览
|
| 17 |
+

|
| 18 |
+
|
| 19 |
+
## 使用说明
|
| 20 |
+
- **mode(模式):**
|
| 21 |
+
`tiny` → 更快(默认);`full` → 更高质量
|
| 22 |
+
- **scale(放大倍数):**
|
| 23 |
+
通常使用 `4` 效果更好;如果显存不足,可使用 `2`
|
| 24 |
+
- **color_fix(颜色修正):**
|
| 25 |
+
使用小波变换方法修正输出视频的颜色偏差。
|
| 26 |
+
- **tiled_vae(VAE分块解码):**
|
| 27 |
+
启用后可显著降低显存占用,但会降低解码速度。
|
| 28 |
+
- **tiled_dit(DiT分块计算):**
|
| 29 |
+
大幅减少显存占用,但会降低推理速度。
|
| 30 |
+
- **tile_size / tile_overlap(分块大小与重叠):**
|
| 31 |
+
控制输入视频在推理时的分块方式。
|
| 32 |
+
- **unload_dit(卸载DiT模型):**
|
| 33 |
+
解码前卸载 DiT 模型以降低显存峰值,但会略微降低速度。
|
| 34 |
+
|
| 35 |
+
## 安装步骤
|
| 36 |
+
|
| 37 |
+
#### 安装节点:
|
| 38 |
+
```bash
|
| 39 |
+
cd ComfyUI/custom_nodes
|
| 40 |
+
git clone https://github.com/lihaoyun6/ComfyUI-FlashVSR_Ultra_Fast.git
|
| 41 |
+
python -m pip install -r ComfyUI-FlashVSR_Ultra_Fast/requirements.txt
|
| 42 |
+
```
|
| 43 |
+
📢: 要在RTX20系或更早的GPU上运行, 请安装`triton<3.3.0`:
|
| 44 |
+
|
| 45 |
+
```bash
|
| 46 |
+
# Windows
|
| 47 |
+
python -m pip install -U triton-windows<3.3.0
|
| 48 |
+
# Linux
|
| 49 |
+
python -m pip install -U triton<3.3.0
|
| 50 |
+
```
|
| 51 |
+
|
| 52 |
+
#### 模型下载:
|
| 53 |
+
- 从[这里](https://huggingface.co/JunhaoZhuang/FlashVSR)下载整个`FlashVSR`文件夹和它里面的所有文件, 并将其放到`ComfyUI/models`目录中。
|
| 54 |
+
|
| 55 |
+
```
|
| 56 |
+
├── ComfyUI/models/FlashVSR
|
| 57 |
+
| ├── LQ_proj_in.ckpt
|
| 58 |
+
| ├── TCDecoder.ckpt
|
| 59 |
+
| ├── diffusion_pytorch_model_streaming_dmd.safetensors
|
| 60 |
+
| ├── Wan2.1_VAE.pth
|
| 61 |
+
```
|
| 62 |
+
|
| 63 |
+
## 致谢
|
| 64 |
+
- [FlashVSR](https://github.com/OpenImagingLab/FlashVSR) @OpenImagingLab
|
| 65 |
+
- [Sparse_SageAttention](https://github.com/jt-zhang/Sparse_SageAttention_API) @jt-zhang
|
| 66 |
+
- [ComfyUI](https://github.com/comfyanonymous/ComfyUI) @comfyanonymous
|
custom_nodes/ComfyUI-FlashVSR_Ultra_Fast/__init__.py
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from .nodes import NODE_CLASS_MAPPINGS, NODE_DISPLAY_NAME_MAPPINGS
|
| 2 |
+
|
| 3 |
+
__all__ = ["NODE_CLASS_MAPPINGS", "NODE_DISPLAY_NAME_MAPPINGS"]
|
custom_nodes/ComfyUI-FlashVSR_Ultra_Fast/img/preview.jpg
ADDED
|
Git LFS Details
|
custom_nodes/ComfyUI-FlashVSR_Ultra_Fast/nodes.py
ADDED
|
@@ -0,0 +1,553 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/usr/bin/env python3
|
| 2 |
+
# -*- coding: utf-8 -*-
|
| 3 |
+
|
| 4 |
+
import os,gc
|
| 5 |
+
import math
|
| 6 |
+
import torch
|
| 7 |
+
import folder_paths
|
| 8 |
+
import comfy.utils
|
| 9 |
+
|
| 10 |
+
import numpy as np
|
| 11 |
+
import torch.nn.functional as F
|
| 12 |
+
|
| 13 |
+
from einops import rearrange
|
| 14 |
+
from huggingface_hub import snapshot_download
|
| 15 |
+
from .src import ModelManager, FlashVSRFullPipeline, FlashVSRTinyPipeline, FlashVSRTinyLongPipeline
|
| 16 |
+
from .src.models.TCDecoder import build_tcdecoder
|
| 17 |
+
from .src.models.utils import clean_vram, get_device_list, Buffer_LQ4x_Proj, Causal_LQ4x_Proj
|
| 18 |
+
from .src.models import wan_video_dit
|
| 19 |
+
|
| 20 |
+
device_choices = get_device_list()
|
| 21 |
+
|
| 22 |
+
def log(message:str, message_type:str='normal'):
|
| 23 |
+
if message_type == 'error':
|
| 24 |
+
message = '\033[1;41m' + message + '\033[m'
|
| 25 |
+
elif message_type == 'warning':
|
| 26 |
+
message = '\033[1;31m' + message + '\033[m'
|
| 27 |
+
elif message_type == 'finish':
|
| 28 |
+
message = '\033[1;32m' + message + '\033[m'
|
| 29 |
+
elif message_type == 'info':
|
| 30 |
+
message = '\033[1;33m' + message + '\033[m'
|
| 31 |
+
else:
|
| 32 |
+
message = message
|
| 33 |
+
print(f"{message}")
|
| 34 |
+
|
| 35 |
+
def model_downlod(model_name="JunhaoZhuang/FlashVSR"):
|
| 36 |
+
model_dir = os.path.join(folder_paths.models_dir, model_name.split("/")[-1])
|
| 37 |
+
if not os.path.exists(model_dir):
|
| 38 |
+
log(f"Downloading model '{model_name}' from huggingface...", message_type='info')
|
| 39 |
+
snapshot_download(repo_id=model_name, local_dir=model_dir, local_dir_use_symlinks=False, resume_download=True)
|
| 40 |
+
|
| 41 |
+
def tensor2video(frames: torch.Tensor):
|
| 42 |
+
video_squeezed = frames.squeeze(0)
|
| 43 |
+
video_permuted = rearrange(video_squeezed, "C F H W -> F H W C")
|
| 44 |
+
video_final = (video_permuted.float() + 1.0) / 2.0
|
| 45 |
+
return video_final
|
| 46 |
+
|
| 47 |
+
def largest_8n1_leq(n): # 8n+1
|
| 48 |
+
return 0 if n < 1 else ((n - 1)//8)*8 + 1
|
| 49 |
+
|
| 50 |
+
def next_8n5(n): # next 8n+5
|
| 51 |
+
return 21 if n < 21 else ((n - 5 + 7) // 8) * 8 + 5
|
| 52 |
+
|
| 53 |
+
def compute_scaled_and_target_dims(w0: int, h0: int, scale: int = 4, multiple: int = 128):
|
| 54 |
+
if w0 <= 0 or h0 <= 0:
|
| 55 |
+
raise ValueError("invalid original size")
|
| 56 |
+
|
| 57 |
+
sW, sH = w0 * scale, h0 * scale
|
| 58 |
+
tW = max(multiple, (sW // multiple) * multiple)
|
| 59 |
+
tH = max(multiple, (sH // multiple) * multiple)
|
| 60 |
+
return sW, sH, tW, tH
|
| 61 |
+
|
| 62 |
+
def tensor_upscale_then_center_crop(frame_tensor: torch.Tensor, scale: int, tW: int, tH: int) -> torch.Tensor:
|
| 63 |
+
h0, w0, c = frame_tensor.shape
|
| 64 |
+
tensor_bchw = frame_tensor.permute(2, 0, 1).unsqueeze(0) # HWC -> CHW -> BCHW
|
| 65 |
+
|
| 66 |
+
sW, sH = w0 * scale, h0 * scale
|
| 67 |
+
upscaled_tensor = F.interpolate(tensor_bchw, size=(sH, sW), mode='bicubic', align_corners=False)
|
| 68 |
+
|
| 69 |
+
l = max(0, (sW - tW) // 2)
|
| 70 |
+
t = max(0, (sH - tH) // 2)
|
| 71 |
+
cropped_tensor = upscaled_tensor[:, :, t:t + tH, l:l + tW]
|
| 72 |
+
|
| 73 |
+
return cropped_tensor.squeeze(0)
|
| 74 |
+
|
| 75 |
+
def prepare_input_tensor(image_tensor: torch.Tensor, device, scale: int = 4, dtype=torch.bfloat16):
|
| 76 |
+
N0, h0, w0, _ = image_tensor.shape
|
| 77 |
+
|
| 78 |
+
multiple = 128
|
| 79 |
+
sW, sH, tW, tH = compute_scaled_and_target_dims(w0, h0, scale=scale, multiple=multiple)
|
| 80 |
+
num_frames_with_padding = N0 + 4
|
| 81 |
+
F = largest_8n1_leq(num_frames_with_padding)
|
| 82 |
+
|
| 83 |
+
if F == 0:
|
| 84 |
+
raise RuntimeError(f"Not enough frames after padding. Got {num_frames_with_padding}.")
|
| 85 |
+
|
| 86 |
+
frames = []
|
| 87 |
+
for i in range(F):
|
| 88 |
+
frame_idx = min(i, N0 - 1)
|
| 89 |
+
frame_slice = image_tensor[frame_idx].to(device)
|
| 90 |
+
tensor_chw = tensor_upscale_then_center_crop(frame_slice, scale=scale, tW=tW, tH=tH).to('cpu').to(dtype) * 2.0 - 1.0
|
| 91 |
+
frames.append(tensor_chw)
|
| 92 |
+
del frame_slice
|
| 93 |
+
|
| 94 |
+
vid_stacked = torch.stack(frames, 0)
|
| 95 |
+
vid_final = vid_stacked.permute(1, 0, 2, 3).unsqueeze(0)
|
| 96 |
+
|
| 97 |
+
del vid_stacked
|
| 98 |
+
clean_vram()
|
| 99 |
+
|
| 100 |
+
return vid_final, tH, tW, F
|
| 101 |
+
|
| 102 |
+
def calculate_tile_coords(height, width, tile_size, overlap):
|
| 103 |
+
coords = []
|
| 104 |
+
|
| 105 |
+
stride = tile_size - overlap
|
| 106 |
+
num_rows = math.ceil((height - overlap) / stride)
|
| 107 |
+
num_cols = math.ceil((width - overlap) / stride)
|
| 108 |
+
|
| 109 |
+
for r in range(num_rows):
|
| 110 |
+
for c in range(num_cols):
|
| 111 |
+
y1 = r * stride
|
| 112 |
+
x1 = c * stride
|
| 113 |
+
|
| 114 |
+
y2 = min(y1 + tile_size, height)
|
| 115 |
+
x2 = min(x1 + tile_size, width)
|
| 116 |
+
|
| 117 |
+
if y2 - y1 < tile_size:
|
| 118 |
+
y1 = max(0, y2 - tile_size)
|
| 119 |
+
if x2 - x1 < tile_size:
|
| 120 |
+
x1 = max(0, x2 - tile_size)
|
| 121 |
+
|
| 122 |
+
coords.append((x1, y1, x2, y2))
|
| 123 |
+
|
| 124 |
+
return coords
|
| 125 |
+
|
| 126 |
+
def create_feather_mask(size, overlap):
|
| 127 |
+
H, W = size
|
| 128 |
+
mask = torch.ones(1, 1, H, W)
|
| 129 |
+
ramp = torch.linspace(0, 1, overlap)
|
| 130 |
+
|
| 131 |
+
mask[:, :, :, :overlap] = torch.minimum(mask[:, :, :, :overlap], ramp.view(1, 1, 1, -1))
|
| 132 |
+
mask[:, :, :, -overlap:] = torch.minimum(mask[:, :, :, -overlap:], ramp.flip(0).view(1, 1, 1, -1))
|
| 133 |
+
|
| 134 |
+
mask[:, :, :overlap, :] = torch.minimum(mask[:, :, :overlap, :], ramp.view(1, 1, -1, 1))
|
| 135 |
+
mask[:, :, -overlap:, :] = torch.minimum(mask[:, :, -overlap:, :], ramp.flip(0).view(1, 1, -1, 1))
|
| 136 |
+
|
| 137 |
+
return mask
|
| 138 |
+
|
| 139 |
+
def init_pipeline(model, mode, device, dtype, alt_vae="none"):
|
| 140 |
+
model_downlod(model_name="JunhaoZhuang/"+model)
|
| 141 |
+
model_path = os.path.join(folder_paths.models_dir, model)
|
| 142 |
+
if not os.path.exists(model_path):
|
| 143 |
+
raise RuntimeError(f'Model directory does not exist!\nPlease save all weights to "{model_path}"')
|
| 144 |
+
ckpt_path = os.path.join(model_path, "diffusion_pytorch_model_streaming_dmd.safetensors")
|
| 145 |
+
if not os.path.exists(ckpt_path):
|
| 146 |
+
raise RuntimeError(f'"diffusion_pytorch_model_streaming_dmd.safetensors" does not exist!\nPlease save it to "{model_path}"')
|
| 147 |
+
if alt_vae != "none":
|
| 148 |
+
vae_path = folder_paths.get_full_path_or_raise("vae", alt_vae)
|
| 149 |
+
if not os.path.exists(vae_path):
|
| 150 |
+
raise RuntimeError(f'"{alt_vae}" does not exist!')
|
| 151 |
+
else:
|
| 152 |
+
vae_path = os.path.join(model_path, "Wan2.1_VAE.pth")
|
| 153 |
+
if not os.path.exists(vae_path):
|
| 154 |
+
raise RuntimeError(f'"Wan2.1_VAE.pth" does not exist!\nPlease save it to "{model_path}"')
|
| 155 |
+
lq_path = os.path.join(model_path, "LQ_proj_in.ckpt")
|
| 156 |
+
if not os.path.exists(lq_path):
|
| 157 |
+
raise RuntimeError(f'"LQ_proj_in.ckpt" does not exist!\nPlease save it to "{model_path}"')
|
| 158 |
+
tcd_path = os.path.join(model_path, "TCDecoder.ckpt")
|
| 159 |
+
if not os.path.exists(tcd_path):
|
| 160 |
+
raise RuntimeError(f'"TCDecoder.ckpt" does not exist!\nPlease save it to "{model_path}"')
|
| 161 |
+
current_dir = os.path.dirname(os.path.abspath(__file__))
|
| 162 |
+
prompt_path = os.path.join(current_dir, "posi_prompt.pth")
|
| 163 |
+
|
| 164 |
+
mm = ModelManager(torch_dtype=dtype, device="cpu")
|
| 165 |
+
if mode == "full":
|
| 166 |
+
mm.load_models([ckpt_path, vae_path])
|
| 167 |
+
pipe = FlashVSRFullPipeline.from_model_manager(mm, device=device)
|
| 168 |
+
pipe.vae.model.encoder = None
|
| 169 |
+
pipe.vae.model.conv1 = None
|
| 170 |
+
else:
|
| 171 |
+
mm.load_models([ckpt_path])
|
| 172 |
+
if mode == "tiny":
|
| 173 |
+
pipe = FlashVSRTinyPipeline.from_model_manager(mm, device=device)
|
| 174 |
+
else:
|
| 175 |
+
pipe = FlashVSRTinyLongPipeline.from_model_manager(mm, device=device)
|
| 176 |
+
multi_scale_channels = [512, 256, 128, 128]
|
| 177 |
+
pipe.TCDecoder = build_tcdecoder(new_channels=multi_scale_channels, device=device, dtype=dtype, new_latent_channels=16+768)
|
| 178 |
+
mis = pipe.TCDecoder.load_state_dict(torch.load(tcd_path, map_location=device), strict=False)
|
| 179 |
+
pipe.TCDecoder.clean_mem()
|
| 180 |
+
|
| 181 |
+
if model == "FlashVSR":
|
| 182 |
+
pipe.denoising_model().LQ_proj_in = Buffer_LQ4x_Proj(in_dim=3, out_dim=1536, layer_num=1).to(device, dtype=dtype)
|
| 183 |
+
else:
|
| 184 |
+
pipe.denoising_model().LQ_proj_in = Causal_LQ4x_Proj(in_dim=3, out_dim=1536, layer_num=1).to(device, dtype=dtype)
|
| 185 |
+
pipe.denoising_model().LQ_proj_in.load_state_dict(torch.load(lq_path, map_location="cpu"), strict=True)
|
| 186 |
+
pipe.denoising_model().LQ_proj_in.to(device)
|
| 187 |
+
pipe.to(device, dtype=dtype)
|
| 188 |
+
pipe.enable_vram_management(num_persistent_param_in_dit=None)
|
| 189 |
+
pipe.init_cross_kv(prompt_path=prompt_path)
|
| 190 |
+
pipe.load_models_to_device(["dit","vae"])
|
| 191 |
+
pipe.offload_model()
|
| 192 |
+
|
| 193 |
+
return pipe
|
| 194 |
+
|
| 195 |
+
class cqdm:
|
| 196 |
+
def __init__(self, iterable=None, total=None, desc="Processing"):
|
| 197 |
+
self.desc = desc
|
| 198 |
+
self.pbar = None
|
| 199 |
+
self.iterable = None
|
| 200 |
+
self.total = total
|
| 201 |
+
|
| 202 |
+
if iterable is not None:
|
| 203 |
+
try:
|
| 204 |
+
self.total = len(iterable)
|
| 205 |
+
self.iterable = iter(iterable)
|
| 206 |
+
except TypeError:
|
| 207 |
+
if self.total is None:
|
| 208 |
+
raise ValueError("Total must be provided for iterables with no length.")
|
| 209 |
+
|
| 210 |
+
elif self.total is not None:
|
| 211 |
+
pass
|
| 212 |
+
|
| 213 |
+
else:
|
| 214 |
+
raise ValueError("Either iterable or total must be provided.")
|
| 215 |
+
|
| 216 |
+
def __iter__(self):
|
| 217 |
+
if self.iterable is None:
|
| 218 |
+
raise TypeError(f"'{type(self).__name__}' object is not iterable. Did you mean to use it with a 'with' statement?")
|
| 219 |
+
if self.pbar is None:
|
| 220 |
+
self.pbar = comfy.utils.ProgressBar(self.total)
|
| 221 |
+
return self
|
| 222 |
+
|
| 223 |
+
def __next__(self):
|
| 224 |
+
if self.iterable is None:
|
| 225 |
+
raise TypeError("Cannot call __next__ on a non-iterable cqdm object.")
|
| 226 |
+
try:
|
| 227 |
+
val = next(self.iterable)
|
| 228 |
+
if self.pbar:
|
| 229 |
+
self.pbar.update(1)
|
| 230 |
+
return val
|
| 231 |
+
except StopIteration:
|
| 232 |
+
raise
|
| 233 |
+
|
| 234 |
+
def __enter__(self):
|
| 235 |
+
if self.pbar is None:
|
| 236 |
+
self.pbar = comfy.utils.ProgressBar(self.total)
|
| 237 |
+
return self.pbar
|
| 238 |
+
|
| 239 |
+
def __exit__(self, exc_type, exc_val, exc_tb):
|
| 240 |
+
pass
|
| 241 |
+
|
| 242 |
+
def __len__(self):
|
| 243 |
+
return self.total
|
| 244 |
+
|
| 245 |
+
def flashvsr(pipe, frames, scale, color_fix, tiled_vae, tiled_dit, tile_size, tile_overlap, unload_dit, sparse_ratio, kv_ratio, local_range, seed, force_offload):
|
| 246 |
+
_frames = frames
|
| 247 |
+
_device = pipe.device
|
| 248 |
+
dtype = pipe.torch_dtype
|
| 249 |
+
|
| 250 |
+
add = next_8n5(frames.shape[0]) - frames.shape[0]
|
| 251 |
+
padding_frames = frames[-1:, :, :, :].repeat(add, 1, 1, 1)
|
| 252 |
+
_frames = torch.cat([frames, padding_frames], dim=0)
|
| 253 |
+
|
| 254 |
+
if tiled_dit:
|
| 255 |
+
N, H, W, C = _frames.shape
|
| 256 |
+
|
| 257 |
+
final_output_canvas = torch.zeros(
|
| 258 |
+
(N, H * scale, W * scale, C),
|
| 259 |
+
dtype=torch.float16,
|
| 260 |
+
device="cpu"
|
| 261 |
+
)
|
| 262 |
+
weight_sum_canvas = torch.zeros_like(final_output_canvas)
|
| 263 |
+
tile_coords = calculate_tile_coords(H, W, tile_size, tile_overlap)
|
| 264 |
+
latent_tiles_cpu = []
|
| 265 |
+
|
| 266 |
+
for i, (x1, y1, x2, y2) in enumerate(cqdm(tile_coords, desc="Processing Tiles")):
|
| 267 |
+
log(f"[FlashVSR] Processing tile {i+1}/{len(tile_coords)}: coords ({x1},{y1}) to ({x2},{y2})", message_type='info')
|
| 268 |
+
input_tile = _frames[:, y1:y2, x1:x2, :]
|
| 269 |
+
|
| 270 |
+
LQ_tile, th, tw, F = prepare_input_tensor(input_tile, _device, scale=scale, dtype=dtype)
|
| 271 |
+
if not isinstance(pipe, FlashVSRTinyLongPipeline):
|
| 272 |
+
LQ_tile = LQ_tile.to(_device)
|
| 273 |
+
|
| 274 |
+
output_tile_gpu = pipe(
|
| 275 |
+
prompt="", negative_prompt="", cfg_scale=1.0, num_inference_steps=1, seed=seed, tiled=tiled_vae,
|
| 276 |
+
LQ_video=LQ_tile, num_frames=F, height=th, width=tw, is_full_block=False, if_buffer=True,
|
| 277 |
+
topk_ratio=sparse_ratio*768*1280/(th*tw), kv_ratio=kv_ratio, local_range=local_range,
|
| 278 |
+
color_fix=color_fix, unload_dit=unload_dit, force_offload=force_offload
|
| 279 |
+
)
|
| 280 |
+
|
| 281 |
+
processed_tile_cpu = tensor2video(output_tile_gpu).to("cpu")
|
| 282 |
+
|
| 283 |
+
mask_nchw = create_feather_mask(
|
| 284 |
+
(processed_tile_cpu.shape[1], processed_tile_cpu.shape[2]),
|
| 285 |
+
tile_overlap * scale
|
| 286 |
+
).to("cpu")
|
| 287 |
+
mask_nhwc = mask_nchw.permute(0, 2, 3, 1)
|
| 288 |
+
out_x1, out_y1 = x1 * scale, y1 * scale
|
| 289 |
+
|
| 290 |
+
tile_H_scaled = processed_tile_cpu.shape[1]
|
| 291 |
+
tile_W_scaled = processed_tile_cpu.shape[2]
|
| 292 |
+
out_x2, out_y2 = out_x1 + tile_W_scaled, out_y1 + tile_H_scaled
|
| 293 |
+
final_output_canvas[:, out_y1:out_y2, out_x1:out_x2, :] += processed_tile_cpu * mask_nhwc
|
| 294 |
+
weight_sum_canvas[:, out_y1:out_y2, out_x1:out_x2, :] += mask_nhwc
|
| 295 |
+
|
| 296 |
+
del LQ_tile, output_tile_gpu, processed_tile_cpu, input_tile
|
| 297 |
+
clean_vram()
|
| 298 |
+
|
| 299 |
+
weight_sum_canvas[weight_sum_canvas == 0] = 1.0
|
| 300 |
+
final_output = final_output_canvas / weight_sum_canvas
|
| 301 |
+
else:
|
| 302 |
+
log("[FlashVSR] Preparing frames...")
|
| 303 |
+
LQ, th, tw, F = prepare_input_tensor(_frames, _device, scale=scale, dtype=dtype)
|
| 304 |
+
if not isinstance(pipe, FlashVSRTinyLongPipeline):
|
| 305 |
+
LQ = LQ.to(_device)
|
| 306 |
+
log(f"[FlashVSR] Processing {frames.shape[0]} frames...", message_type='info')
|
| 307 |
+
|
| 308 |
+
video = pipe(
|
| 309 |
+
prompt="", negative_prompt="", cfg_scale=1.0, num_inference_steps=1, seed=seed, tiled=tiled_vae,
|
| 310 |
+
progress_bar_cmd=cqdm, LQ_video=LQ, num_frames=F, height=th, width=tw, is_full_block=False, if_buffer=True,
|
| 311 |
+
topk_ratio=sparse_ratio*768*1280/(th*tw), kv_ratio=kv_ratio, local_range=local_range,
|
| 312 |
+
color_fix = color_fix, unload_dit=unload_dit, force_offload=force_offload
|
| 313 |
+
)
|
| 314 |
+
|
| 315 |
+
final_output = tensor2video(video).to('cpu')
|
| 316 |
+
|
| 317 |
+
del video, LQ
|
| 318 |
+
clean_vram()
|
| 319 |
+
|
| 320 |
+
log("[FlashVSR] Done.", message_type='info')
|
| 321 |
+
if frames.shape[0] == 1:
|
| 322 |
+
final_output = final_output.to(_device)
|
| 323 |
+
stacked_image_tensor = torch.median(final_output, dim=0).values.unsqueeze(0).float().to('cpu')
|
| 324 |
+
del final_output
|
| 325 |
+
clean_vram()
|
| 326 |
+
return stacked_image_tensor
|
| 327 |
+
|
| 328 |
+
return final_output[:frames.shape[0], :, :, :]
|
| 329 |
+
|
| 330 |
+
class FlashVSRNodeInitPipe:
|
| 331 |
+
@classmethod
|
| 332 |
+
def INPUT_TYPES(cls):
|
| 333 |
+
return {
|
| 334 |
+
"required": {
|
| 335 |
+
"model": (["FlashVSR", "FlashVSR-v1.1"], {
|
| 336 |
+
"default": "FlashVSR-v1.1",
|
| 337 |
+
"tooltip": "Model version."
|
| 338 |
+
}),
|
| 339 |
+
"mode": (["tiny", "tiny-long", "full"], {
|
| 340 |
+
"default": "tiny",
|
| 341 |
+
"tooltip": 'Using "tiny-long" mode can significantly reduce VRAM used with long video input.'
|
| 342 |
+
}),
|
| 343 |
+
"alt_vae": (["none"] + folder_paths.get_filename_list("vae"), {
|
| 344 |
+
"default": "none",
|
| 345 |
+
"tooltip": 'Replaces the built-in VAE, only available in "full" mode.'
|
| 346 |
+
}),
|
| 347 |
+
"force_offload": ("BOOLEAN", {
|
| 348 |
+
"default": True,
|
| 349 |
+
"tooltip": "Offload all weights to CPU after running a workflow to free up VRAM."
|
| 350 |
+
}),
|
| 351 |
+
"precision": (["fp16", "bf16"], {
|
| 352 |
+
"default": "bf16",
|
| 353 |
+
"tooltip": "Data and inference precision."
|
| 354 |
+
}),
|
| 355 |
+
"device": (device_choices, {
|
| 356 |
+
"default": device_choices[0],
|
| 357 |
+
"tooltip": "Device to load the weights, default: auto (CUDA if available, else CPU)"
|
| 358 |
+
}),
|
| 359 |
+
"attention_mode": (["sparse_sage_attention", "block_sparse_attention"], {
|
| 360 |
+
"default": "sparse_sage_attention",
|
| 361 |
+
"tooltip": '"sparse_sage_attention" is available for sm_75 to sm_120\n"block_sparse_attention" is available for sm_80 to sm_100'
|
| 362 |
+
}),
|
| 363 |
+
}
|
| 364 |
+
}
|
| 365 |
+
|
| 366 |
+
RETURN_TYPES = ("PIPE",)
|
| 367 |
+
RETURN_NAMES = ("pipe",)
|
| 368 |
+
FUNCTION = "main"
|
| 369 |
+
CATEGORY = "FlashVSR"
|
| 370 |
+
DESCRIPTION = 'Download the entire "FlashVSR" folder with all the files inside it from "https://huggingface.co/JunhaoZhuang/FlashVSR" and put it in the "ComfyUI/models"'
|
| 371 |
+
|
| 372 |
+
def main(self, model, mode, alt_vae, force_offload, precision, device, attention_mode):
|
| 373 |
+
_device = device
|
| 374 |
+
if device == "auto":
|
| 375 |
+
_device = "cuda:0" if torch.cuda.is_available() else "mps" if torch.backends.mps.is_available() else device
|
| 376 |
+
if _device == "auto" or _device not in device_choices:
|
| 377 |
+
raise RuntimeError("No devices found to run FlashVSR!")
|
| 378 |
+
|
| 379 |
+
if _device.startswith("cuda"):
|
| 380 |
+
torch.cuda.set_device(_device)
|
| 381 |
+
|
| 382 |
+
if attention_mode == "sparse_sage_attention":
|
| 383 |
+
wan_video_dit.USE_BLOCK_ATTN = False
|
| 384 |
+
else:
|
| 385 |
+
wan_video_dit.USE_BLOCK_ATTN = True
|
| 386 |
+
|
| 387 |
+
dtype_map = {
|
| 388 |
+
"fp32": torch.float32,
|
| 389 |
+
"fp16": torch.float16,
|
| 390 |
+
"bf16": torch.bfloat16,
|
| 391 |
+
}
|
| 392 |
+
try:
|
| 393 |
+
dtype = dtype_map[precision]
|
| 394 |
+
except:
|
| 395 |
+
dtype = torch.bfloat16
|
| 396 |
+
|
| 397 |
+
pipe = init_pipeline(model, mode, _device, dtype, alt_vae=alt_vae)
|
| 398 |
+
return((pipe, force_offload),)
|
| 399 |
+
|
| 400 |
+
class FlashVSRNodeAdv:
|
| 401 |
+
@classmethod
|
| 402 |
+
def INPUT_TYPES(cls):
|
| 403 |
+
return {
|
| 404 |
+
"required": {
|
| 405 |
+
"pipe": ("PIPE", {
|
| 406 |
+
"tooltip": "FlashVSR pipeline"
|
| 407 |
+
}),
|
| 408 |
+
"frames": ("IMAGE", {
|
| 409 |
+
"tooltip": "Sequential video frames as IMAGE tensor batch"
|
| 410 |
+
}),
|
| 411 |
+
"scale": ("INT", {
|
| 412 |
+
"default": 2,
|
| 413 |
+
"min": 2,
|
| 414 |
+
"max": 4,
|
| 415 |
+
}),
|
| 416 |
+
"color_fix": ("BOOLEAN", {
|
| 417 |
+
"default": True,
|
| 418 |
+
"tooltip": "Use wavelet transform to correct output video color."
|
| 419 |
+
}),
|
| 420 |
+
"tiled_vae": ("BOOLEAN", {
|
| 421 |
+
"default": True,
|
| 422 |
+
"tooltip": "Disable tiling: faster decode but higher VRAM usage.\nSet to True for lower memory consumption at the cost of speed."
|
| 423 |
+
}),
|
| 424 |
+
"tiled_dit": ("BOOLEAN", {
|
| 425 |
+
"default": True,
|
| 426 |
+
"tooltip": "Significantly reduces VRAM usage at the cost of speed."
|
| 427 |
+
}),
|
| 428 |
+
"tile_size": ("INT", {
|
| 429 |
+
"default": 256,
|
| 430 |
+
"min": 32,
|
| 431 |
+
"max": 1024,
|
| 432 |
+
"step": 32,
|
| 433 |
+
}),
|
| 434 |
+
"tile_overlap": ("INT", {
|
| 435 |
+
"default": 24,
|
| 436 |
+
"min": 8,
|
| 437 |
+
"max": 512,
|
| 438 |
+
"step": 8,
|
| 439 |
+
}),
|
| 440 |
+
"unload_dit": ("BOOLEAN", {
|
| 441 |
+
"default": False,
|
| 442 |
+
"tooltip": "Unload DiT before decoding to reduce VRAM peak at the cost of speed."
|
| 443 |
+
}),
|
| 444 |
+
"sparse_ratio": ("FLOAT", {
|
| 445 |
+
"default": 2.0,
|
| 446 |
+
"min": 1.5,
|
| 447 |
+
"max": 2.0,
|
| 448 |
+
"step": 0.1,
|
| 449 |
+
"display": "slider",
|
| 450 |
+
"tooltip": "Recommended: 1.5 or 2.0\n1.5 → faster; 2.0 → more stable"
|
| 451 |
+
}),
|
| 452 |
+
"kv_ratio": ("FLOAT", {
|
| 453 |
+
"default": 3.0,
|
| 454 |
+
"min": 1.0,
|
| 455 |
+
"max": 3.0,
|
| 456 |
+
"step": 0.1,
|
| 457 |
+
"display": "slider",
|
| 458 |
+
"tooltip": "Recommended: 1.0 to 3.0\n1.0 → less vram; 3.0 → high quality"
|
| 459 |
+
}),
|
| 460 |
+
"local_range": ("INT", {
|
| 461 |
+
"default": 11,
|
| 462 |
+
"min": 9,
|
| 463 |
+
"max": 11,
|
| 464 |
+
"step": 2,
|
| 465 |
+
"tooltip": "Recommended: 9 or 11\nlocal_range=9 → sharper details; 11 → more stable results"
|
| 466 |
+
}),
|
| 467 |
+
"seed": ("INT", {
|
| 468 |
+
"default": 0,
|
| 469 |
+
"min": 0,
|
| 470 |
+
"max": 1125899906842624
|
| 471 |
+
}),
|
| 472 |
+
}
|
| 473 |
+
}
|
| 474 |
+
|
| 475 |
+
RETURN_TYPES = ("IMAGE",)
|
| 476 |
+
RETURN_NAMES = ("image",)
|
| 477 |
+
FUNCTION = "main"
|
| 478 |
+
CATEGORY = "FlashVSR"
|
| 479 |
+
#DESCRIPTION = ""
|
| 480 |
+
|
| 481 |
+
def main(self, pipe, frames, scale, color_fix, tiled_vae, tiled_dit, tile_size, tile_overlap, unload_dit, sparse_ratio, kv_ratio, local_range, seed):
|
| 482 |
+
_pipe, force_offload = pipe
|
| 483 |
+
output = flashvsr(_pipe, frames, scale, color_fix, tiled_vae, tiled_dit, tile_size, tile_overlap, unload_dit, sparse_ratio, kv_ratio, local_range, seed, force_offload)
|
| 484 |
+
return(output,)
|
| 485 |
+
|
| 486 |
+
class FlashVSRNode:
|
| 487 |
+
@classmethod
|
| 488 |
+
def INPUT_TYPES(cls):
|
| 489 |
+
return {
|
| 490 |
+
"required": {
|
| 491 |
+
"frames": ("IMAGE", {
|
| 492 |
+
"tooltip": "Sequential video frames as IMAGE tensor batch"
|
| 493 |
+
}),
|
| 494 |
+
"model": (["FlashVSR", "FlashVSR-v1.1"], {
|
| 495 |
+
"default": "FlashVSR-v1.1",
|
| 496 |
+
"tooltip": "Model version."
|
| 497 |
+
}),
|
| 498 |
+
"mode": (["tiny", "tiny-long", "full"], {
|
| 499 |
+
"default": "tiny",
|
| 500 |
+
"tooltip": 'Using "tiny-long" mode can significantly reduce VRAM used with long video input.'
|
| 501 |
+
}),
|
| 502 |
+
"scale": ("INT", {
|
| 503 |
+
"default": 2,
|
| 504 |
+
"min": 2,
|
| 505 |
+
"max": 4,
|
| 506 |
+
}),
|
| 507 |
+
"tiled_vae": ("BOOLEAN", {
|
| 508 |
+
"default": True,
|
| 509 |
+
"tooltip": "Disable tiling: faster decode but higher VRAM usage.\nSet to True for lower memory consumption at the cost of speed."
|
| 510 |
+
}),
|
| 511 |
+
"tiled_dit": ("BOOLEAN", {
|
| 512 |
+
"default": True,
|
| 513 |
+
"tooltip": "Significantly reduces VRAM usage at the cost of speed."
|
| 514 |
+
}),
|
| 515 |
+
"unload_dit": ("BOOLEAN", {
|
| 516 |
+
"default": False,
|
| 517 |
+
"tooltip": "Unload DiT before decoding to reduce VRAM peak at the cost of speed."
|
| 518 |
+
}),
|
| 519 |
+
"seed": ("INT", {
|
| 520 |
+
"default": 0,
|
| 521 |
+
"min": 0,
|
| 522 |
+
"max": 1125899906842624
|
| 523 |
+
}),
|
| 524 |
+
}
|
| 525 |
+
}
|
| 526 |
+
|
| 527 |
+
RETURN_TYPES = ("IMAGE",)
|
| 528 |
+
RETURN_NAMES = ("image",)
|
| 529 |
+
FUNCTION = "main"
|
| 530 |
+
CATEGORY = "FlashVSR"
|
| 531 |
+
DESCRIPTION = 'Download the entire "FlashVSR" folder with all the files inside it from "https://huggingface.co/JunhaoZhuang/FlashVSR" and put it in the "ComfyUI/models"'
|
| 532 |
+
|
| 533 |
+
def main(self, model, frames, mode, scale, tiled_vae, tiled_dit, unload_dit, seed):
|
| 534 |
+
wan_video_dit.USE_BLOCK_ATTN = False
|
| 535 |
+
_device = "cuda:0" if torch.cuda.is_available() else "mps" if torch.backends.mps.is_available() else "auto"
|
| 536 |
+
if _device == "auto" or _device not in device_choices:
|
| 537 |
+
raise RuntimeError("No devices found to run FlashVSR!")
|
| 538 |
+
|
| 539 |
+
pipe = init_pipeline(model, mode, _device, torch.float16)
|
| 540 |
+
output = flashvsr(pipe, frames, scale, True, tiled_vae, tiled_dit, 256, 24, unload_dit, 2.0, 3.0, 11, seed, True)
|
| 541 |
+
return(output,)
|
| 542 |
+
|
| 543 |
+
NODE_CLASS_MAPPINGS = {
|
| 544 |
+
"FlashVSRNode": FlashVSRNode,
|
| 545 |
+
"FlashVSRNodeAdv": FlashVSRNodeAdv,
|
| 546 |
+
"FlashVSRInitPipe": FlashVSRNodeInitPipe,
|
| 547 |
+
}
|
| 548 |
+
|
| 549 |
+
NODE_DISPLAY_NAME_MAPPINGS = {
|
| 550 |
+
"FlashVSRNode": "FlashVSR Ultra-Fast",
|
| 551 |
+
"FlashVSRNodeAdv": "FlashVSR Ultra-Fast (Advanced)",
|
| 552 |
+
"FlashVSRInitPipe": "FlashVSR Init Pipeline",
|
| 553 |
+
}
|
custom_nodes/ComfyUI-FlashVSR_Ultra_Fast/posi_prompt.pth
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:4601107a11e4e11a936a6b79df579e54dbc99872132bf542151f0ffd65b4b1ef
|
| 3 |
+
size 4195504
|
custom_nodes/ComfyUI-FlashVSR_Ultra_Fast/requirements.txt
ADDED
|
@@ -0,0 +1,10 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
torch
|
| 2 |
+
torchvision
|
| 3 |
+
numpy
|
| 4 |
+
einops
|
| 5 |
+
safetensors
|
| 6 |
+
tqdm
|
| 7 |
+
pillow
|
| 8 |
+
huggingface_hub
|
| 9 |
+
triton; platform_system!="Windows"
|
| 10 |
+
triton-windows; platform_system=="Windows"
|
custom_nodes/ComfyUI-FlashVSR_Ultra_Fast/src/LICENSE.txt
ADDED
|
@@ -0,0 +1,201 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
Apache License
|
| 2 |
+
Version 2.0, January 2004
|
| 3 |
+
http://www.apache.org/licenses/
|
| 4 |
+
|
| 5 |
+
TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
|
| 6 |
+
|
| 7 |
+
1. Definitions.
|
| 8 |
+
|
| 9 |
+
"License" shall mean the terms and conditions for use, reproduction,
|
| 10 |
+
and distribution as defined by Sections 1 through 9 of this document.
|
| 11 |
+
|
| 12 |
+
"Licensor" shall mean the copyright owner or entity authorized by
|
| 13 |
+
the copyright owner that is granting the License.
|
| 14 |
+
|
| 15 |
+
"Legal Entity" shall mean the union of the acting entity and all
|
| 16 |
+
other entities that control, are controlled by, or are under common
|
| 17 |
+
control with that entity. For the purposes of this definition,
|
| 18 |
+
"control" means (i) the power, direct or indirect, to cause the
|
| 19 |
+
direction or management of such entity, whether by contract or
|
| 20 |
+
otherwise, or (ii) ownership of fifty percent (50%) or more of the
|
| 21 |
+
outstanding shares, or (iii) beneficial ownership of such entity.
|
| 22 |
+
|
| 23 |
+
"You" (or "Your") shall mean an individual or Legal Entity
|
| 24 |
+
exercising permissions granted by this License.
|
| 25 |
+
|
| 26 |
+
"Source" form shall mean the preferred form for making modifications,
|
| 27 |
+
including but not limited to software source code, documentation
|
| 28 |
+
source, and configuration files.
|
| 29 |
+
|
| 30 |
+
"Object" form shall mean any form resulting from mechanical
|
| 31 |
+
transformation or translation of a Source form, including but
|
| 32 |
+
not limited to compiled object code, generated documentation,
|
| 33 |
+
and conversions to other media types.
|
| 34 |
+
|
| 35 |
+
"Work" shall mean the work of authorship, whether in Source or
|
| 36 |
+
Object form, made available under the License, as indicated by a
|
| 37 |
+
copyright notice that is included in or attached to the work
|
| 38 |
+
(an example is provided in the Appendix below).
|
| 39 |
+
|
| 40 |
+
"Derivative Works" shall mean any work, whether in Source or Object
|
| 41 |
+
form, that is based on (or derived from) the Work and for which the
|
| 42 |
+
editorial revisions, annotations, elaborations, or other modifications
|
| 43 |
+
represent, as a whole, an original work of authorship. For the purposes
|
| 44 |
+
of this License, Derivative Works shall not include works that remain
|
| 45 |
+
separable from, or merely link (or bind by name) to the interfaces of,
|
| 46 |
+
the Work and Derivative Works thereof.
|
| 47 |
+
|
| 48 |
+
"Contribution" shall mean any work of authorship, including
|
| 49 |
+
the original version of the Work and any modifications or additions
|
| 50 |
+
to that Work or Derivative Works thereof, that is intentionally
|
| 51 |
+
submitted to Licensor for inclusion in the Work by the copyright owner
|
| 52 |
+
or by an individual or Legal Entity authorized to submit on behalf of
|
| 53 |
+
the copyright owner. For the purposes of this definition, "submitted"
|
| 54 |
+
means any form of electronic, verbal, or written communication sent
|
| 55 |
+
to the Licensor or its representatives, including but not limited to
|
| 56 |
+
communication on electronic mailing lists, source code control systems,
|
| 57 |
+
and issue tracking systems that are managed by, or on behalf of, the
|
| 58 |
+
Licensor for the purpose of discussing and improving the Work, but
|
| 59 |
+
excluding communication that is conspicuously marked or otherwise
|
| 60 |
+
designated in writing by the copyright owner as "Not a Contribution."
|
| 61 |
+
|
| 62 |
+
"Contributor" shall mean Licensor and any individual or Legal Entity
|
| 63 |
+
on behalf of whom a Contribution has been received by Licensor and
|
| 64 |
+
subsequently incorporated within the Work.
|
| 65 |
+
|
| 66 |
+
2. Grant of Copyright License. Subject to the terms and conditions of
|
| 67 |
+
this License, each Contributor hereby grants to You a perpetual,
|
| 68 |
+
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
|
| 69 |
+
copyright license to reproduce, prepare Derivative Works of,
|
| 70 |
+
publicly display, publicly perform, sublicense, and distribute the
|
| 71 |
+
Work and such Derivative Works in Source or Object form.
|
| 72 |
+
|
| 73 |
+
3. Grant of Patent License. Subject to the terms and conditions of
|
| 74 |
+
this License, each Contributor hereby grants to You a perpetual,
|
| 75 |
+
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
|
| 76 |
+
(except as stated in this section) patent license to make, have made,
|
| 77 |
+
use, offer to sell, sell, import, and otherwise transfer the Work,
|
| 78 |
+
where such license applies only to those patent claims licensable
|
| 79 |
+
by such Contributor that are necessarily infringed by their
|
| 80 |
+
Contribution(s) alone or by combination of their Contribution(s)
|
| 81 |
+
with the Work to which such Contribution(s) was submitted. If You
|
| 82 |
+
institute patent litigation against any entity (including a
|
| 83 |
+
cross-claim or counterclaim in a lawsuit) alleging that the Work
|
| 84 |
+
or a Contribution incorporated within the Work constitutes direct
|
| 85 |
+
or contributory patent infringement, then any patent licenses
|
| 86 |
+
granted to You under this License for that Work shall terminate
|
| 87 |
+
as of the date such litigation is filed.
|
| 88 |
+
|
| 89 |
+
4. Redistribution. You may reproduce and distribute copies of the
|
| 90 |
+
Work or Derivative Works thereof in any medium, with or without
|
| 91 |
+
modifications, and in Source or Object form, provided that You
|
| 92 |
+
meet the following conditions:
|
| 93 |
+
|
| 94 |
+
(a) You must give any other recipients of the Work or
|
| 95 |
+
Derivative Works a copy of this License; and
|
| 96 |
+
|
| 97 |
+
(b) You must cause any modified files to carry prominent notices
|
| 98 |
+
stating that You changed the files; and
|
| 99 |
+
|
| 100 |
+
(c) You must retain, in the Source form of any Derivative Works
|
| 101 |
+
that You distribute, all copyright, patent, trademark, and
|
| 102 |
+
attribution notices from the Source form of the Work,
|
| 103 |
+
excluding those notices that do not pertain to any part of
|
| 104 |
+
the Derivative Works; and
|
| 105 |
+
|
| 106 |
+
(d) If the Work includes a "NOTICE" text file as part of its
|
| 107 |
+
distribution, then any Derivative Works that You distribute must
|
| 108 |
+
include a readable copy of the attribution notices contained
|
| 109 |
+
within such NOTICE file, excluding those notices that do not
|
| 110 |
+
pertain to any part of the Derivative Works, in at least one
|
| 111 |
+
of the following places: within a NOTICE text file distributed
|
| 112 |
+
as part of the Derivative Works; within the Source form or
|
| 113 |
+
documentation, if provided along with the Derivative Works; or,
|
| 114 |
+
within a display generated by the Derivative Works, if and
|
| 115 |
+
wherever such third-party notices normally appear. The contents
|
| 116 |
+
of the NOTICE file are for informational purposes only and
|
| 117 |
+
do not modify the License. You may add Your own attribution
|
| 118 |
+
notices within Derivative Works that You distribute, alongside
|
| 119 |
+
or as an addendum to the NOTICE text from the Work, provided
|
| 120 |
+
that such additional attribution notices cannot be construed
|
| 121 |
+
as modifying the License.
|
| 122 |
+
|
| 123 |
+
You may add Your own copyright statement to Your modifications and
|
| 124 |
+
may provide additional or different license terms and conditions
|
| 125 |
+
for use, reproduction, or distribution of Your modifications, or
|
| 126 |
+
for any such Derivative Works as a whole, provided Your use,
|
| 127 |
+
reproduction, and distribution of the Work otherwise complies with
|
| 128 |
+
the conditions stated in this License.
|
| 129 |
+
|
| 130 |
+
5. Submission of Contributions. Unless You explicitly state otherwise,
|
| 131 |
+
any Contribution intentionally submitted for inclusion in the Work
|
| 132 |
+
by You to the Licensor shall be under the terms and conditions of
|
| 133 |
+
this License, without any additional terms or conditions.
|
| 134 |
+
Notwithstanding the above, nothing herein shall supersede or modify
|
| 135 |
+
the terms of any separate license agreement you may have executed
|
| 136 |
+
with Licensor regarding such Contributions.
|
| 137 |
+
|
| 138 |
+
6. Trademarks. This License does not grant permission to use the trade
|
| 139 |
+
names, trademarks, service marks, or product names of the Licensor,
|
| 140 |
+
except as required for reasonable and customary use in describing the
|
| 141 |
+
origin of the Work and reproducing the content of the NOTICE file.
|
| 142 |
+
|
| 143 |
+
7. Disclaimer of Warranty. Unless required by applicable law or
|
| 144 |
+
agreed to in writing, Licensor provides the Work (and each
|
| 145 |
+
Contributor provides its Contributions) on an "AS IS" BASIS,
|
| 146 |
+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
|
| 147 |
+
implied, including, without limitation, any warranties or conditions
|
| 148 |
+
of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
|
| 149 |
+
PARTICULAR PURPOSE. You are solely responsible for determining the
|
| 150 |
+
appropriateness of using or redistributing the Work and assume any
|
| 151 |
+
risks associated with Your exercise of permissions under this License.
|
| 152 |
+
|
| 153 |
+
8. Limitation of Liability. In no event and under no legal theory,
|
| 154 |
+
whether in tort (including negligence), contract, or otherwise,
|
| 155 |
+
unless required by applicable law (such as deliberate and grossly
|
| 156 |
+
negligent acts) or agreed to in writing, shall any Contributor be
|
| 157 |
+
liable to You for damages, including any direct, indirect, special,
|
| 158 |
+
incidental, or consequential damages of any character arising as a
|
| 159 |
+
result of this License or out of the use or inability to use the
|
| 160 |
+
Work (including but not limited to damages for loss of goodwill,
|
| 161 |
+
work stoppage, computer failure or malfunction, or any and all
|
| 162 |
+
other commercial damages or losses), even if such Contributor
|
| 163 |
+
has been advised of the possibility of such damages.
|
| 164 |
+
|
| 165 |
+
9. Accepting Warranty or Additional Liability. While redistributing
|
| 166 |
+
the Work or Derivative Works thereof, You may choose to offer,
|
| 167 |
+
and charge a fee for, acceptance of support, warranty, indemnity,
|
| 168 |
+
or other liability obligations and/or rights consistent with this
|
| 169 |
+
License. However, in accepting such obligations, You may act only
|
| 170 |
+
on Your own behalf and on Your sole responsibility, not on behalf
|
| 171 |
+
of any other Contributor, and only if You agree to indemnify,
|
| 172 |
+
defend, and hold each Contributor harmless for any liability
|
| 173 |
+
incurred by, or claims asserted against, such Contributor by reason
|
| 174 |
+
of your accepting any such warranty or additional liability.
|
| 175 |
+
|
| 176 |
+
END OF TERMS AND CONDITIONS
|
| 177 |
+
|
| 178 |
+
APPENDIX: How to apply the Apache License to your work.
|
| 179 |
+
|
| 180 |
+
To apply the Apache License to your work, attach the following
|
| 181 |
+
boilerplate notice, with the fields enclosed by brackets "[]"
|
| 182 |
+
replaced with your own identifying information. (Don't include
|
| 183 |
+
the brackets!) The text should be enclosed in the appropriate
|
| 184 |
+
comment syntax for the file format. We also recommend that a
|
| 185 |
+
file or class name and description of purpose be included on the
|
| 186 |
+
same "printed page" as the copyright notice for easier
|
| 187 |
+
identification within third-party archives.
|
| 188 |
+
|
| 189 |
+
Copyright [yyyy] [name of copyright owner]
|
| 190 |
+
|
| 191 |
+
Licensed under the Apache License, Version 2.0 (the "License");
|
| 192 |
+
you may not use this file except in compliance with the License.
|
| 193 |
+
You may obtain a copy of the License at
|
| 194 |
+
|
| 195 |
+
http://www.apache.org/licenses/LICENSE-2.0
|
| 196 |
+
|
| 197 |
+
Unless required by applicable law or agreed to in writing, software
|
| 198 |
+
distributed under the License is distributed on an "AS IS" BASIS,
|
| 199 |
+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 200 |
+
See the License for the specific language governing permissions and
|
| 201 |
+
limitations under the License.
|
custom_nodes/ComfyUI-FlashVSR_Ultra_Fast/src/__init__.py
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from .models import *
|
| 2 |
+
from .pipelines import *
|
| 3 |
+
from .schedulers import *
|
custom_nodes/ComfyUI-FlashVSR_Ultra_Fast/src/configs/__init__.py
ADDED
|
File without changes
|
custom_nodes/ComfyUI-FlashVSR_Ultra_Fast/src/configs/model_config.py
ADDED
|
@@ -0,0 +1,29 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from typing_extensions import Literal, TypeAlias
|
| 2 |
+
|
| 3 |
+
from ..models.wan_video_dit import WanModel
|
| 4 |
+
from ..models.wan_video_vae import WanVideoVAE
|
| 5 |
+
|
| 6 |
+
|
| 7 |
+
model_loader_configs = [
|
| 8 |
+
# These configs are provided for detecting model type automatically.
|
| 9 |
+
# The format is (state_dict_keys_hash, state_dict_keys_hash_with_shape, model_names, model_classes, model_resource)
|
| 10 |
+
(None, "9269f8db9040a9d860eaca435be61814", ["wan_video_dit"], [WanModel], "civitai"),
|
| 11 |
+
(None, "aafcfd9672c3a2456dc46e1cb6e52c70", ["wan_video_dit"], [WanModel], "civitai"),
|
| 12 |
+
(None, "6bfcfb3b342cb286ce886889d519a77e", ["wan_video_dit"], [WanModel], "civitai"),
|
| 13 |
+
(None, "6d6ccde6845b95ad9114ab993d917893", ["wan_video_dit"], [WanModel], "civitai"),
|
| 14 |
+
(None, "6bfcfb3b342cb286ce886889d519a77e", ["wan_video_dit"], [WanModel], "civitai"),
|
| 15 |
+
(None, "349723183fc063b2bfc10bb2835cf677", ["wan_video_dit"], [WanModel], "civitai"),
|
| 16 |
+
(None, "efa44cddf936c70abd0ea28b6cbe946c", ["wan_video_dit"], [WanModel], "civitai"),
|
| 17 |
+
(None, "3ef3b1f8e1dab83d5b71fd7b617f859f", ["wan_video_dit"], [WanModel], "civitai"),
|
| 18 |
+
(None, "cb104773c6c2cb6df4f9529ad5c60d0b", ["wan_video_dit"], [WanModel], "diffusers"),
|
| 19 |
+
(None, "1378ea763357eea97acdef78e65d6d96", ["wan_video_vae"], [WanVideoVAE], "civitai"),
|
| 20 |
+
(None, "ccc42284ea13e1ad04693284c7a09be6", ["wan_video_vae"], [WanVideoVAE], "civitai"),
|
| 21 |
+
]
|
| 22 |
+
huggingface_model_loader_configs = [
|
| 23 |
+
# These configs are provided for detecting model type automatically.
|
| 24 |
+
# The format is (architecture_in_huggingface_config, huggingface_lib, model_name, redirected_architecture)
|
| 25 |
+
]
|
| 26 |
+
patch_model_loader_configs = [
|
| 27 |
+
# These configs are provided for detecting model type automatically.
|
| 28 |
+
# The format is (state_dict_keys_hash_with_shape, model_name, model_class, extra_kwargs)
|
| 29 |
+
]
|
custom_nodes/ComfyUI-FlashVSR_Ultra_Fast/src/models/TCDecoder.py
ADDED
|
@@ -0,0 +1,320 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/usr/bin/env python3
|
| 2 |
+
"""
|
| 3 |
+
Tiny AutoEncoder for Hunyuan Video (Decoder-only, pruned)
|
| 4 |
+
- Encoder removed
|
| 5 |
+
- Transplant/widening helpers removed
|
| 6 |
+
- Deepening (IdentityConv2d+ReLU) is now built into the decoder structure itself
|
| 7 |
+
"""
|
| 8 |
+
|
| 9 |
+
import torch
|
| 10 |
+
import torch.nn as nn
|
| 11 |
+
import torch.nn.functional as F
|
| 12 |
+
from tqdm.auto import tqdm
|
| 13 |
+
from collections import namedtuple
|
| 14 |
+
from einops import rearrange
|
| 15 |
+
import torch.nn.init as init
|
| 16 |
+
|
| 17 |
+
DecoderResult = namedtuple("DecoderResult", ("frame", "memory"))
|
| 18 |
+
TWorkItem = namedtuple("TWorkItem", ("input_tensor", "block_index"))
|
| 19 |
+
|
| 20 |
+
# ----------------------------
|
| 21 |
+
# Utility / building blocks
|
| 22 |
+
# ----------------------------
|
| 23 |
+
|
| 24 |
+
class IdentityConv2d(nn.Conv2d):
|
| 25 |
+
"""Same-shape Conv2d initialized to identity (Dirac)."""
|
| 26 |
+
def __init__(self, C, kernel_size=3, bias=False):
|
| 27 |
+
pad = kernel_size // 2
|
| 28 |
+
super().__init__(C, C, kernel_size, padding=pad, bias=bias)
|
| 29 |
+
with torch.no_grad():
|
| 30 |
+
init.dirac_(self.weight)
|
| 31 |
+
if self.bias is not None:
|
| 32 |
+
self.bias.zero_()
|
| 33 |
+
|
| 34 |
+
def conv(n_in, n_out, **kwargs):
|
| 35 |
+
return nn.Conv2d(n_in, n_out, 3, padding=1, **kwargs)
|
| 36 |
+
|
| 37 |
+
class Clamp(nn.Module):
|
| 38 |
+
def forward(self, x):
|
| 39 |
+
return torch.tanh(x / 3) * 3
|
| 40 |
+
|
| 41 |
+
class MemBlock(nn.Module):
|
| 42 |
+
def __init__(self, n_in, n_out):
|
| 43 |
+
super().__init__()
|
| 44 |
+
self.conv = nn.Sequential(
|
| 45 |
+
conv(n_in * 2, n_out), nn.ReLU(inplace=True),
|
| 46 |
+
conv(n_out, n_out), nn.ReLU(inplace=True),
|
| 47 |
+
conv(n_out, n_out)
|
| 48 |
+
)
|
| 49 |
+
self.skip = nn.Conv2d(n_in, n_out, 1, bias=False) if n_in != n_out else nn.Identity()
|
| 50 |
+
self.act = nn.ReLU(inplace=True)
|
| 51 |
+
def forward(self, x, past):
|
| 52 |
+
return self.act(self.conv(torch.cat([x, past], 1)) + self.skip(x))
|
| 53 |
+
|
| 54 |
+
class TPool(nn.Module):
|
| 55 |
+
def __init__(self, n_f, stride):
|
| 56 |
+
super().__init__()
|
| 57 |
+
self.stride = stride
|
| 58 |
+
self.conv = nn.Conv2d(n_f*stride, n_f, 1, bias=False)
|
| 59 |
+
def forward(self, x):
|
| 60 |
+
_NT, C, H, W = x.shape
|
| 61 |
+
return self.conv(x.reshape(-1, self.stride * C, H, W))
|
| 62 |
+
|
| 63 |
+
class TGrow(nn.Module):
|
| 64 |
+
def __init__(self, n_f, stride):
|
| 65 |
+
super().__init__()
|
| 66 |
+
self.stride = stride
|
| 67 |
+
self.conv = nn.Conv2d(n_f, n_f*stride, 1, bias=False)
|
| 68 |
+
def forward(self, x):
|
| 69 |
+
_NT, C, H, W = x.shape
|
| 70 |
+
x = self.conv(x)
|
| 71 |
+
return x.reshape(-1, C, H, W)
|
| 72 |
+
|
| 73 |
+
class PixelShuffle3d(nn.Module):
|
| 74 |
+
def __init__(self, ff, hh, ww):
|
| 75 |
+
super().__init__()
|
| 76 |
+
self.ff = ff
|
| 77 |
+
self.hh = hh
|
| 78 |
+
self.ww = ww
|
| 79 |
+
def forward(self, x):
|
| 80 |
+
# x: (B, C, F, H, W)
|
| 81 |
+
B, C, F, H, W = x.shape
|
| 82 |
+
if F % self.ff != 0:
|
| 83 |
+
first_frame = x[:, :, 0:1, :, :].repeat(1, 1, self.ff - F % self.ff, 1, 1)
|
| 84 |
+
x = torch.cat([first_frame, x], dim=2)
|
| 85 |
+
return rearrange(
|
| 86 |
+
x,
|
| 87 |
+
'b c (f ff) (h hh) (w ww) -> b (c ff hh ww) f h w',
|
| 88 |
+
ff=self.ff, hh=self.hh, ww=self.ww
|
| 89 |
+
).transpose(1, 2)
|
| 90 |
+
|
| 91 |
+
# ----------------------------
|
| 92 |
+
# Generic NTCHW graph executor (kept; used by decoder)
|
| 93 |
+
# ----------------------------
|
| 94 |
+
|
| 95 |
+
def apply_model_with_memblocks(model, x, parallel, show_progress_bar, mem=None):
|
| 96 |
+
"""
|
| 97 |
+
Apply a sequential model with memblocks to the given input.
|
| 98 |
+
Args:
|
| 99 |
+
- model: nn.Sequential of blocks to apply
|
| 100 |
+
- x: input data, of dimensions NTCHW
|
| 101 |
+
- parallel: if True, parallelize over timesteps (fast but uses O(T) memory)
|
| 102 |
+
if False, each timestep will be processed sequentially (slow but uses O(1) memory)
|
| 103 |
+
- show_progress_bar: if True, enables tqdm progressbar display
|
| 104 |
+
|
| 105 |
+
Returns NTCHW tensor of output data.
|
| 106 |
+
"""
|
| 107 |
+
assert x.ndim == 5, f"TAEHV operates on NTCHW tensors, but got {x.ndim}-dim tensor"
|
| 108 |
+
N, T, C, H, W = x.shape
|
| 109 |
+
if parallel:
|
| 110 |
+
x = x.reshape(N*T, C, H, W)
|
| 111 |
+
for b in tqdm(model, disable=not show_progress_bar):
|
| 112 |
+
if isinstance(b, MemBlock):
|
| 113 |
+
NT, C, H, W = x.shape
|
| 114 |
+
T = NT // N
|
| 115 |
+
_x = x.reshape(N, T, C, H, W)
|
| 116 |
+
mem = F.pad(_x, (0,0,0,0,0,0,1,0), value=0)[:,:T].reshape(x.shape)
|
| 117 |
+
x = b(x, mem)
|
| 118 |
+
else:
|
| 119 |
+
x = b(x)
|
| 120 |
+
NT, C, H, W = x.shape
|
| 121 |
+
T = NT // N
|
| 122 |
+
x = x.view(N, T, C, H, W)
|
| 123 |
+
else:
|
| 124 |
+
out = []
|
| 125 |
+
work_queue = [TWorkItem(xt, 0) for t, xt in enumerate(x.reshape(N, T * C, H, W).chunk(T, dim=1))]
|
| 126 |
+
progress_bar = tqdm(range(T), disable=not show_progress_bar)
|
| 127 |
+
while work_queue:
|
| 128 |
+
xt, i = work_queue.pop(0)
|
| 129 |
+
if i == 0:
|
| 130 |
+
progress_bar.update(1)
|
| 131 |
+
if i == len(model):
|
| 132 |
+
out.append(xt)
|
| 133 |
+
else:
|
| 134 |
+
b = model[i]
|
| 135 |
+
if isinstance(b, MemBlock):
|
| 136 |
+
if mem[i] is None:
|
| 137 |
+
xt_new = b(xt, xt * 0)
|
| 138 |
+
mem[i] = xt
|
| 139 |
+
else:
|
| 140 |
+
xt_new = b(xt, mem[i])
|
| 141 |
+
mem[i].copy_(xt)
|
| 142 |
+
work_queue.insert(0, TWorkItem(xt_new, i+1))
|
| 143 |
+
elif isinstance(b, TPool):
|
| 144 |
+
if mem[i] is None:
|
| 145 |
+
mem[i] = []
|
| 146 |
+
mem[i].append(xt)
|
| 147 |
+
if len(mem[i]) > b.stride:
|
| 148 |
+
raise ValueError("TPool internal state invalid.")
|
| 149 |
+
elif len(mem[i]) == b.stride:
|
| 150 |
+
N_, C_, H_, W_ = xt.shape
|
| 151 |
+
xt = b(torch.cat(mem[i], 1).view(N_*b.stride, C_, H_, W_))
|
| 152 |
+
mem[i] = []
|
| 153 |
+
work_queue.insert(0, TWorkItem(xt, i+1))
|
| 154 |
+
elif isinstance(b, TGrow):
|
| 155 |
+
xt = b(xt)
|
| 156 |
+
NT, C_, H_, W_ = xt.shape
|
| 157 |
+
for xt_next in reversed(xt.view(N, b.stride*C_, H_, W_).chunk(b.stride, 1)):
|
| 158 |
+
work_queue.insert(0, TWorkItem(xt_next, i+1))
|
| 159 |
+
else:
|
| 160 |
+
xt = b(xt)
|
| 161 |
+
work_queue.insert(0, TWorkItem(xt, i+1))
|
| 162 |
+
progress_bar.close()
|
| 163 |
+
x = torch.stack(out, 1)
|
| 164 |
+
return x, mem
|
| 165 |
+
|
| 166 |
+
# ----------------------------
|
| 167 |
+
# Decoder-only TAEHV
|
| 168 |
+
# ----------------------------
|
| 169 |
+
|
| 170 |
+
class TAEHV(nn.Module):
|
| 171 |
+
image_channels = 3
|
| 172 |
+
def __init__(
|
| 173 |
+
self,
|
| 174 |
+
checkpoint_path="taehv.pth",
|
| 175 |
+
decoder_time_upscale=(True, True),
|
| 176 |
+
decoder_space_upscale=(True, True, True),
|
| 177 |
+
channels = [256, 128, 64, 64],
|
| 178 |
+
latent_channels = 16
|
| 179 |
+
):
|
| 180 |
+
"""Initialize TAEHV (decoder-only) with built-in deepening after every ReLU.
|
| 181 |
+
Deepening config: how_many_each=1, k=3 (fixed as requested).
|
| 182 |
+
"""
|
| 183 |
+
super().__init__()
|
| 184 |
+
self.latent_channels = latent_channels
|
| 185 |
+
n_f = channels
|
| 186 |
+
self.frames_to_trim = 2**sum(decoder_time_upscale) - 1
|
| 187 |
+
|
| 188 |
+
# Build the decoder "skeleton"
|
| 189 |
+
base_decoder = nn.Sequential(
|
| 190 |
+
Clamp(), conv(self.latent_channels, n_f[0]), nn.ReLU(inplace=True),
|
| 191 |
+
|
| 192 |
+
MemBlock(n_f[0], n_f[0]), MemBlock(n_f[0], n_f[0]), MemBlock(n_f[0], n_f[0]),
|
| 193 |
+
nn.Upsample(scale_factor=2 if decoder_space_upscale[0] else 1),
|
| 194 |
+
TGrow(n_f[0], 1),
|
| 195 |
+
conv(n_f[0], n_f[1], bias=False),
|
| 196 |
+
|
| 197 |
+
MemBlock(n_f[1], n_f[1]), MemBlock(n_f[1], n_f[1]), MemBlock(n_f[1], n_f[1]),
|
| 198 |
+
nn.Upsample(scale_factor=2 if decoder_space_upscale[1] else 1),
|
| 199 |
+
TGrow(n_f[1], 2 if decoder_time_upscale[0] else 1),
|
| 200 |
+
conv(n_f[1], n_f[2], bias=False),
|
| 201 |
+
|
| 202 |
+
MemBlock(n_f[2], n_f[2]), MemBlock(n_f[2], n_f[2]), MemBlock(n_f[2], n_f[2]),
|
| 203 |
+
nn.Upsample(scale_factor=2 if decoder_space_upscale[2] else 1),
|
| 204 |
+
TGrow(n_f[2], 2 if decoder_time_upscale[1] else 1),
|
| 205 |
+
conv(n_f[2], n_f[3], bias=False),
|
| 206 |
+
|
| 207 |
+
nn.ReLU(inplace=True), conv(n_f[3], TAEHV.image_channels),
|
| 208 |
+
)
|
| 209 |
+
|
| 210 |
+
# Inline deepening: insert (IdentityConv2d(k=3) + ReLU) after every ReLU
|
| 211 |
+
self.decoder = self._apply_identity_deepen(base_decoder, how_many_each=1, k=3)
|
| 212 |
+
|
| 213 |
+
self.pixel_shuffle = PixelShuffle3d(4, 8, 8)
|
| 214 |
+
|
| 215 |
+
if checkpoint_path is not None:
|
| 216 |
+
missing_keys = self.load_state_dict(
|
| 217 |
+
self.patch_tgrow_layers(torch.load(checkpoint_path, map_location="cpu", weights_only=True)),
|
| 218 |
+
strict=False
|
| 219 |
+
)
|
| 220 |
+
print('missing_keys', missing_keys)
|
| 221 |
+
|
| 222 |
+
# Initialize decoder mem state
|
| 223 |
+
self.mem = [None] * len(self.decoder)
|
| 224 |
+
|
| 225 |
+
@staticmethod
|
| 226 |
+
def _apply_identity_deepen(decoder: nn.Sequential, how_many_each=1, k=3) -> nn.Sequential:
|
| 227 |
+
"""Return a new Sequential where every nn.ReLU is followed by how_many_each*(IdentityConv2d(k)+ReLU)."""
|
| 228 |
+
new_layers = []
|
| 229 |
+
for b in decoder:
|
| 230 |
+
new_layers.append(b)
|
| 231 |
+
if isinstance(b, nn.ReLU):
|
| 232 |
+
# Deduce channel count from preceding layer
|
| 233 |
+
C = None
|
| 234 |
+
if len(new_layers) >= 2 and isinstance(new_layers[-2], nn.Conv2d):
|
| 235 |
+
C = new_layers[-2].out_channels
|
| 236 |
+
elif len(new_layers) >= 2 and isinstance(new_layers[-2], MemBlock):
|
| 237 |
+
C = new_layers[-2].conv[-1].out_channels
|
| 238 |
+
if C is not None:
|
| 239 |
+
for _ in range(how_many_each):
|
| 240 |
+
new_layers.append(IdentityConv2d(C, kernel_size=k, bias=False))
|
| 241 |
+
new_layers.append(nn.ReLU(inplace=True))
|
| 242 |
+
return nn.Sequential(*new_layers)
|
| 243 |
+
|
| 244 |
+
def patch_tgrow_layers(self, sd):
|
| 245 |
+
"""Patch TGrow layers to use a smaller kernel if needed (decoder-only)."""
|
| 246 |
+
new_sd = self.state_dict()
|
| 247 |
+
for i, layer in enumerate(self.decoder):
|
| 248 |
+
if isinstance(layer, TGrow):
|
| 249 |
+
key = f"decoder.{i}.conv.weight"
|
| 250 |
+
if key in sd and sd[key].shape[0] > new_sd[key].shape[0]:
|
| 251 |
+
sd[key] = sd[key][-new_sd[key].shape[0]:]
|
| 252 |
+
return sd
|
| 253 |
+
|
| 254 |
+
def decode_video(self, x, parallel=True, show_progress_bar=False, cond=None):
|
| 255 |
+
"""Decode a sequence of frames from latents.
|
| 256 |
+
x: NTCHW latent tensor; returns NTCHW RGB in ~[0, 1].
|
| 257 |
+
"""
|
| 258 |
+
trim_flag = self.mem[-8] is None # keeps original relative check
|
| 259 |
+
|
| 260 |
+
if cond is not None:
|
| 261 |
+
x = torch.cat([self.pixel_shuffle(cond), x], dim=2)
|
| 262 |
+
|
| 263 |
+
x, self.mem = apply_model_with_memblocks(self.decoder, x, parallel, show_progress_bar, mem=self.mem)
|
| 264 |
+
|
| 265 |
+
if trim_flag:
|
| 266 |
+
return x[:, self.frames_to_trim:]
|
| 267 |
+
return x
|
| 268 |
+
|
| 269 |
+
def forward(self, *args, **kwargs):
|
| 270 |
+
raise NotImplementedError("Decoder-only model: call decode_video(...) instead.")
|
| 271 |
+
|
| 272 |
+
def clean_mem(self):
|
| 273 |
+
self.mem = [None] * len(self.decoder)
|
| 274 |
+
|
| 275 |
+
class DotDict(dict):
|
| 276 |
+
__getattr__ = dict.__getitem__
|
| 277 |
+
__setattr__ = dict.__setitem__
|
| 278 |
+
|
| 279 |
+
class TAEW2_1DiffusersWrapper(nn.Module):
|
| 280 |
+
def __init__(self, pretrained_path=None, channels = [256, 128, 64, 64]):
|
| 281 |
+
super().__init__()
|
| 282 |
+
self.dtype = torch.bfloat16
|
| 283 |
+
self.device = "cuda"
|
| 284 |
+
self.taehv = TAEHV(pretrained_path, channels = channels).to(self.dtype)
|
| 285 |
+
self.temperal_downsample = [True, True, False] # [sic]
|
| 286 |
+
self.config = DotDict(scaling_factor=1.0, latents_mean=torch.zeros(16), z_dim=16, latents_std=torch.ones(16))
|
| 287 |
+
|
| 288 |
+
def decode(self, latents, return_dict=None):
|
| 289 |
+
n, c, t, h, w = latents.shape
|
| 290 |
+
return (self.taehv.decode_video(latents.transpose(1, 2), parallel=False).transpose(1, 2).mul_(2).sub_(1),)
|
| 291 |
+
|
| 292 |
+
def stream_decode_with_cond(self, latents, tiled=False, cond=None):
|
| 293 |
+
n, c, t, h, w = latents.shape
|
| 294 |
+
return self.taehv.decode_video(latents.transpose(1, 2), parallel=False, cond=cond).transpose(1, 2).mul_(2).sub_(1)
|
| 295 |
+
|
| 296 |
+
def clean_mem(self):
|
| 297 |
+
self.taehv.clean_mem()
|
| 298 |
+
|
| 299 |
+
# ----------------------------
|
| 300 |
+
# Simplified builder (no small, no transplant, no post-hoc deepening)
|
| 301 |
+
# ----------------------------
|
| 302 |
+
|
| 303 |
+
def build_tcdecoder(new_channels = [512, 256, 128, 128],
|
| 304 |
+
device="cuda",
|
| 305 |
+
dtype=torch.bfloat16,
|
| 306 |
+
new_latent_channels=None):
|
| 307 |
+
"""
|
| 308 |
+
构建“更宽”的 decoder;深度增强(IdentityConv2d+ReLU)已在 TAEHV 内部完成。
|
| 309 |
+
- 不创建 small / 不做移植
|
| 310 |
+
- base_ckpt_path 参数保留但不使用(接口兼容)
|
| 311 |
+
|
| 312 |
+
返回:big (单个模型)
|
| 313 |
+
"""
|
| 314 |
+
if new_latent_channels is not None:
|
| 315 |
+
big = TAEHV(checkpoint_path=None, channels=new_channels, latent_channels=new_latent_channels).to(device).to(dtype).train()
|
| 316 |
+
else:
|
| 317 |
+
big = TAEHV(checkpoint_path=None, channels=new_channels).to(device).to(dtype).train()
|
| 318 |
+
|
| 319 |
+
big.clean_mem()
|
| 320 |
+
return big
|
custom_nodes/ComfyUI-FlashVSR_Ultra_Fast/src/models/__init__.py
ADDED
|
@@ -0,0 +1 @@
|
|
|
|
|
|
|
| 1 |
+
from .model_manager import *
|
custom_nodes/ComfyUI-FlashVSR_Ultra_Fast/src/models/model_manager.py
ADDED
|
@@ -0,0 +1,402 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os, torch, json, importlib
|
| 2 |
+
from typing import List
|
| 3 |
+
|
| 4 |
+
from ..configs.model_config import model_loader_configs, huggingface_model_loader_configs, patch_model_loader_configs
|
| 5 |
+
from .utils import load_state_dict, init_weights_on_device, hash_state_dict_keys, split_state_dict_with_prefix
|
| 6 |
+
|
| 7 |
+
def load_model_from_single_file(state_dict, model_names, model_classes, model_resource, torch_dtype, device):
|
| 8 |
+
loaded_model_names, loaded_models = [], []
|
| 9 |
+
for model_name, model_class in zip(model_names, model_classes):
|
| 10 |
+
#print(f" model_name: {model_name} model_class: {model_class.__name__}")
|
| 11 |
+
state_dict_converter = model_class.state_dict_converter()
|
| 12 |
+
if model_resource == "civitai":
|
| 13 |
+
state_dict_results = state_dict_converter.from_civitai(state_dict)
|
| 14 |
+
elif model_resource == "diffusers":
|
| 15 |
+
state_dict_results = state_dict_converter.from_diffusers(state_dict)
|
| 16 |
+
if isinstance(state_dict_results, tuple):
|
| 17 |
+
model_state_dict, extra_kwargs = state_dict_results
|
| 18 |
+
#print(f" This model is initialized with extra kwargs: {extra_kwargs}")
|
| 19 |
+
else:
|
| 20 |
+
model_state_dict, extra_kwargs = state_dict_results, {}
|
| 21 |
+
torch_dtype = torch.float32 if extra_kwargs.get("upcast_to_float32", False) else torch_dtype
|
| 22 |
+
with init_weights_on_device():
|
| 23 |
+
model = model_class(**extra_kwargs)
|
| 24 |
+
if hasattr(model, "eval"):
|
| 25 |
+
model = model.eval()
|
| 26 |
+
model.load_state_dict(model_state_dict, assign=True)
|
| 27 |
+
model = model.to(dtype=torch_dtype, device=device)
|
| 28 |
+
loaded_model_names.append(model_name)
|
| 29 |
+
loaded_models.append(model)
|
| 30 |
+
return loaded_model_names, loaded_models
|
| 31 |
+
|
| 32 |
+
|
| 33 |
+
def load_model_from_huggingface_folder(file_path, model_names, model_classes, torch_dtype, device):
|
| 34 |
+
loaded_model_names, loaded_models = [], []
|
| 35 |
+
for model_name, model_class in zip(model_names, model_classes):
|
| 36 |
+
if torch_dtype in [torch.float32, torch.float16, torch.bfloat16]:
|
| 37 |
+
model = model_class.from_pretrained(file_path, torch_dtype=torch_dtype).eval()
|
| 38 |
+
else:
|
| 39 |
+
model = model_class.from_pretrained(file_path).eval().to(dtype=torch_dtype)
|
| 40 |
+
if torch_dtype == torch.float16 and hasattr(model, "half"):
|
| 41 |
+
model = model.half()
|
| 42 |
+
try:
|
| 43 |
+
model = model.to(device=device)
|
| 44 |
+
except:
|
| 45 |
+
pass
|
| 46 |
+
loaded_model_names.append(model_name)
|
| 47 |
+
loaded_models.append(model)
|
| 48 |
+
return loaded_model_names, loaded_models
|
| 49 |
+
|
| 50 |
+
|
| 51 |
+
def load_single_patch_model_from_single_file(state_dict, model_name, model_class, base_model, extra_kwargs, torch_dtype, device):
|
| 52 |
+
#print(f" model_name: {model_name} model_class: {model_class.__name__} extra_kwargs: {extra_kwargs}")
|
| 53 |
+
base_state_dict = base_model.state_dict()
|
| 54 |
+
base_model.to("cpu")
|
| 55 |
+
del base_model
|
| 56 |
+
model = model_class(**extra_kwargs)
|
| 57 |
+
model.load_state_dict(base_state_dict, strict=False)
|
| 58 |
+
model.load_state_dict(state_dict, strict=False)
|
| 59 |
+
model.to(dtype=torch_dtype, device=device)
|
| 60 |
+
return model
|
| 61 |
+
|
| 62 |
+
|
| 63 |
+
def load_patch_model_from_single_file(state_dict, model_names, model_classes, extra_kwargs, model_manager, torch_dtype, device):
|
| 64 |
+
loaded_model_names, loaded_models = [], []
|
| 65 |
+
for model_name, model_class in zip(model_names, model_classes):
|
| 66 |
+
while True:
|
| 67 |
+
for model_id in range(len(model_manager.model)):
|
| 68 |
+
base_model_name = model_manager.model_name[model_id]
|
| 69 |
+
if base_model_name == model_name:
|
| 70 |
+
base_model_path = model_manager.model_path[model_id]
|
| 71 |
+
base_model = model_manager.model[model_id]
|
| 72 |
+
print(f" Adding patch model to {base_model_name} ({base_model_path})")
|
| 73 |
+
patched_model = load_single_patch_model_from_single_file(
|
| 74 |
+
state_dict, model_name, model_class, base_model, extra_kwargs, torch_dtype, device)
|
| 75 |
+
loaded_model_names.append(base_model_name)
|
| 76 |
+
loaded_models.append(patched_model)
|
| 77 |
+
model_manager.model.pop(model_id)
|
| 78 |
+
model_manager.model_path.pop(model_id)
|
| 79 |
+
model_manager.model_name.pop(model_id)
|
| 80 |
+
break
|
| 81 |
+
else:
|
| 82 |
+
break
|
| 83 |
+
return loaded_model_names, loaded_models
|
| 84 |
+
|
| 85 |
+
|
| 86 |
+
|
| 87 |
+
class ModelDetectorTemplate:
|
| 88 |
+
def __init__(self):
|
| 89 |
+
pass
|
| 90 |
+
|
| 91 |
+
def match(self, file_path="", state_dict={}):
|
| 92 |
+
return False
|
| 93 |
+
|
| 94 |
+
def load(self, file_path="", state_dict={}, device="cuda", torch_dtype=torch.float16, **kwargs):
|
| 95 |
+
return [], []
|
| 96 |
+
|
| 97 |
+
|
| 98 |
+
|
| 99 |
+
class ModelDetectorFromSingleFile:
|
| 100 |
+
def __init__(self, model_loader_configs=[]):
|
| 101 |
+
self.keys_hash_with_shape_dict = {}
|
| 102 |
+
self.keys_hash_dict = {}
|
| 103 |
+
for metadata in model_loader_configs:
|
| 104 |
+
self.add_model_metadata(*metadata)
|
| 105 |
+
|
| 106 |
+
|
| 107 |
+
def add_model_metadata(self, keys_hash, keys_hash_with_shape, model_names, model_classes, model_resource):
|
| 108 |
+
self.keys_hash_with_shape_dict[keys_hash_with_shape] = (model_names, model_classes, model_resource)
|
| 109 |
+
if keys_hash is not None:
|
| 110 |
+
self.keys_hash_dict[keys_hash] = (model_names, model_classes, model_resource)
|
| 111 |
+
|
| 112 |
+
|
| 113 |
+
def match(self, file_path="", state_dict={}):
|
| 114 |
+
if isinstance(file_path, str) and os.path.isdir(file_path):
|
| 115 |
+
return False
|
| 116 |
+
if len(state_dict) == 0:
|
| 117 |
+
state_dict = load_state_dict(file_path)
|
| 118 |
+
keys_hash_with_shape = hash_state_dict_keys(state_dict, with_shape=True)
|
| 119 |
+
if keys_hash_with_shape in self.keys_hash_with_shape_dict:
|
| 120 |
+
return True
|
| 121 |
+
keys_hash = hash_state_dict_keys(state_dict, with_shape=False)
|
| 122 |
+
if keys_hash in self.keys_hash_dict:
|
| 123 |
+
return True
|
| 124 |
+
return False
|
| 125 |
+
|
| 126 |
+
|
| 127 |
+
def load(self, file_path="", state_dict={}, device="cuda", torch_dtype=torch.float16, **kwargs):
|
| 128 |
+
if len(state_dict) == 0:
|
| 129 |
+
state_dict = load_state_dict(file_path)
|
| 130 |
+
|
| 131 |
+
# Load models with strict matching
|
| 132 |
+
keys_hash_with_shape = hash_state_dict_keys(state_dict, with_shape=True)
|
| 133 |
+
if keys_hash_with_shape in self.keys_hash_with_shape_dict:
|
| 134 |
+
model_names, model_classes, model_resource = self.keys_hash_with_shape_dict[keys_hash_with_shape]
|
| 135 |
+
loaded_model_names, loaded_models = load_model_from_single_file(state_dict, model_names, model_classes, model_resource, torch_dtype, device)
|
| 136 |
+
return loaded_model_names, loaded_models
|
| 137 |
+
|
| 138 |
+
# Load models without strict matching
|
| 139 |
+
# (the shape of parameters may be inconsistent, and the state_dict_converter will modify the model architecture)
|
| 140 |
+
keys_hash = hash_state_dict_keys(state_dict, with_shape=False)
|
| 141 |
+
if keys_hash in self.keys_hash_dict:
|
| 142 |
+
model_names, model_classes, model_resource = self.keys_hash_dict[keys_hash]
|
| 143 |
+
loaded_model_names, loaded_models = load_model_from_single_file(state_dict, model_names, model_classes, model_resource, torch_dtype, device)
|
| 144 |
+
return loaded_model_names, loaded_models
|
| 145 |
+
|
| 146 |
+
return loaded_model_names, loaded_models
|
| 147 |
+
|
| 148 |
+
|
| 149 |
+
|
| 150 |
+
class ModelDetectorFromSplitedSingleFile(ModelDetectorFromSingleFile):
|
| 151 |
+
def __init__(self, model_loader_configs=[]):
|
| 152 |
+
super().__init__(model_loader_configs)
|
| 153 |
+
|
| 154 |
+
|
| 155 |
+
def match(self, file_path="", state_dict={}):
|
| 156 |
+
if isinstance(file_path, str) and os.path.isdir(file_path):
|
| 157 |
+
return False
|
| 158 |
+
if len(state_dict) == 0:
|
| 159 |
+
state_dict = load_state_dict(file_path)
|
| 160 |
+
splited_state_dict = split_state_dict_with_prefix(state_dict)
|
| 161 |
+
for sub_state_dict in splited_state_dict:
|
| 162 |
+
if super().match(file_path, sub_state_dict):
|
| 163 |
+
return True
|
| 164 |
+
return False
|
| 165 |
+
|
| 166 |
+
|
| 167 |
+
def load(self, file_path="", state_dict={}, device="cuda", torch_dtype=torch.float16, **kwargs):
|
| 168 |
+
# Split the state_dict and load from each component
|
| 169 |
+
splited_state_dict = split_state_dict_with_prefix(state_dict)
|
| 170 |
+
valid_state_dict = {}
|
| 171 |
+
for sub_state_dict in splited_state_dict:
|
| 172 |
+
if super().match(file_path, sub_state_dict):
|
| 173 |
+
valid_state_dict.update(sub_state_dict)
|
| 174 |
+
if super().match(file_path, valid_state_dict):
|
| 175 |
+
loaded_model_names, loaded_models = super().load(file_path, valid_state_dict, device, torch_dtype)
|
| 176 |
+
else:
|
| 177 |
+
loaded_model_names, loaded_models = [], []
|
| 178 |
+
for sub_state_dict in splited_state_dict:
|
| 179 |
+
if super().match(file_path, sub_state_dict):
|
| 180 |
+
loaded_model_names_, loaded_models_ = super().load(file_path, valid_state_dict, device, torch_dtype)
|
| 181 |
+
loaded_model_names += loaded_model_names_
|
| 182 |
+
loaded_models += loaded_models_
|
| 183 |
+
return loaded_model_names, loaded_models
|
| 184 |
+
|
| 185 |
+
|
| 186 |
+
|
| 187 |
+
class ModelDetectorFromHuggingfaceFolder:
|
| 188 |
+
def __init__(self, model_loader_configs=[]):
|
| 189 |
+
self.architecture_dict = {}
|
| 190 |
+
for metadata in model_loader_configs:
|
| 191 |
+
self.add_model_metadata(*metadata)
|
| 192 |
+
|
| 193 |
+
|
| 194 |
+
def add_model_metadata(self, architecture, huggingface_lib, model_name, redirected_architecture):
|
| 195 |
+
self.architecture_dict[architecture] = (huggingface_lib, model_name, redirected_architecture)
|
| 196 |
+
|
| 197 |
+
|
| 198 |
+
def match(self, file_path="", state_dict={}):
|
| 199 |
+
if not isinstance(file_path, str) or os.path.isfile(file_path):
|
| 200 |
+
return False
|
| 201 |
+
file_list = os.listdir(file_path)
|
| 202 |
+
if "config.json" not in file_list:
|
| 203 |
+
return False
|
| 204 |
+
with open(os.path.join(file_path, "config.json"), "r") as f:
|
| 205 |
+
config = json.load(f)
|
| 206 |
+
if "architectures" not in config and "_class_name" not in config:
|
| 207 |
+
return False
|
| 208 |
+
return True
|
| 209 |
+
|
| 210 |
+
|
| 211 |
+
def load(self, file_path="", state_dict={}, device="cuda", torch_dtype=torch.float16, **kwargs):
|
| 212 |
+
with open(os.path.join(file_path, "config.json"), "r") as f:
|
| 213 |
+
config = json.load(f)
|
| 214 |
+
loaded_model_names, loaded_models = [], []
|
| 215 |
+
architectures = config["architectures"] if "architectures" in config else [config["_class_name"]]
|
| 216 |
+
for architecture in architectures:
|
| 217 |
+
huggingface_lib, model_name, redirected_architecture = self.architecture_dict[architecture]
|
| 218 |
+
if redirected_architecture is not None:
|
| 219 |
+
architecture = redirected_architecture
|
| 220 |
+
model_class = importlib.import_module(huggingface_lib).__getattribute__(architecture)
|
| 221 |
+
loaded_model_names_, loaded_models_ = load_model_from_huggingface_folder(file_path, [model_name], [model_class], torch_dtype, device)
|
| 222 |
+
loaded_model_names += loaded_model_names_
|
| 223 |
+
loaded_models += loaded_models_
|
| 224 |
+
return loaded_model_names, loaded_models
|
| 225 |
+
|
| 226 |
+
|
| 227 |
+
|
| 228 |
+
class ModelDetectorFromPatchedSingleFile:
|
| 229 |
+
def __init__(self, model_loader_configs=[]):
|
| 230 |
+
self.keys_hash_with_shape_dict = {}
|
| 231 |
+
for metadata in model_loader_configs:
|
| 232 |
+
self.add_model_metadata(*metadata)
|
| 233 |
+
|
| 234 |
+
|
| 235 |
+
def add_model_metadata(self, keys_hash_with_shape, model_name, model_class, extra_kwargs):
|
| 236 |
+
self.keys_hash_with_shape_dict[keys_hash_with_shape] = (model_name, model_class, extra_kwargs)
|
| 237 |
+
|
| 238 |
+
|
| 239 |
+
def match(self, file_path="", state_dict={}):
|
| 240 |
+
if not isinstance(file_path, str) or os.path.isdir(file_path):
|
| 241 |
+
return False
|
| 242 |
+
if len(state_dict) == 0:
|
| 243 |
+
state_dict = load_state_dict(file_path)
|
| 244 |
+
keys_hash_with_shape = hash_state_dict_keys(state_dict, with_shape=True)
|
| 245 |
+
if keys_hash_with_shape in self.keys_hash_with_shape_dict:
|
| 246 |
+
return True
|
| 247 |
+
return False
|
| 248 |
+
|
| 249 |
+
|
| 250 |
+
def load(self, file_path="", state_dict={}, device="cuda", torch_dtype=torch.float16, model_manager=None, **kwargs):
|
| 251 |
+
if len(state_dict) == 0:
|
| 252 |
+
state_dict = load_state_dict(file_path)
|
| 253 |
+
|
| 254 |
+
# Load models with strict matching
|
| 255 |
+
loaded_model_names, loaded_models = [], []
|
| 256 |
+
keys_hash_with_shape = hash_state_dict_keys(state_dict, with_shape=True)
|
| 257 |
+
if keys_hash_with_shape in self.keys_hash_with_shape_dict:
|
| 258 |
+
model_names, model_classes, extra_kwargs = self.keys_hash_with_shape_dict[keys_hash_with_shape]
|
| 259 |
+
loaded_model_names_, loaded_models_ = load_patch_model_from_single_file(
|
| 260 |
+
state_dict, model_names, model_classes, extra_kwargs, model_manager, torch_dtype, device)
|
| 261 |
+
loaded_model_names += loaded_model_names_
|
| 262 |
+
loaded_models += loaded_models_
|
| 263 |
+
return loaded_model_names, loaded_models
|
| 264 |
+
|
| 265 |
+
|
| 266 |
+
|
| 267 |
+
class ModelManager:
|
| 268 |
+
def __init__(
|
| 269 |
+
self,
|
| 270 |
+
torch_dtype=torch.float16,
|
| 271 |
+
device="cuda",
|
| 272 |
+
file_path_list: List[str] = [],
|
| 273 |
+
):
|
| 274 |
+
self.torch_dtype = torch_dtype
|
| 275 |
+
self.device = device
|
| 276 |
+
self.model = []
|
| 277 |
+
self.model_path = []
|
| 278 |
+
self.model_name = []
|
| 279 |
+
self.model_detector = [
|
| 280 |
+
ModelDetectorFromSingleFile(model_loader_configs),
|
| 281 |
+
ModelDetectorFromSplitedSingleFile(model_loader_configs),
|
| 282 |
+
ModelDetectorFromHuggingfaceFolder(huggingface_model_loader_configs),
|
| 283 |
+
ModelDetectorFromPatchedSingleFile(patch_model_loader_configs),
|
| 284 |
+
]
|
| 285 |
+
self.load_models(file_path_list)
|
| 286 |
+
|
| 287 |
+
|
| 288 |
+
def load_model_from_single_file(self, file_path="", state_dict={}, model_names=[], model_classes=[], model_resource=None):
|
| 289 |
+
print(f"Loading models from file: {file_path}")
|
| 290 |
+
if len(state_dict) == 0:
|
| 291 |
+
state_dict = load_state_dict(file_path)
|
| 292 |
+
model_names, models = load_model_from_single_file(state_dict, model_names, model_classes, model_resource, self.torch_dtype, self.device)
|
| 293 |
+
for model_name, model in zip(model_names, models):
|
| 294 |
+
self.model.append(model)
|
| 295 |
+
self.model_path.append(file_path)
|
| 296 |
+
self.model_name.append(model_name)
|
| 297 |
+
#print(f" The following models are loaded: {model_names}.")
|
| 298 |
+
|
| 299 |
+
|
| 300 |
+
def load_model_from_huggingface_folder(self, file_path="", model_names=[], model_classes=[]):
|
| 301 |
+
print(f"Loading models from folder: {file_path}")
|
| 302 |
+
model_names, models = load_model_from_huggingface_folder(file_path, model_names, model_classes, self.torch_dtype, self.device)
|
| 303 |
+
for model_name, model in zip(model_names, models):
|
| 304 |
+
self.model.append(model)
|
| 305 |
+
self.model_path.append(file_path)
|
| 306 |
+
self.model_name.append(model_name)
|
| 307 |
+
#print(f" The following models are loaded: {model_names}.")
|
| 308 |
+
|
| 309 |
+
|
| 310 |
+
def load_patch_model_from_single_file(self, file_path="", state_dict={}, model_names=[], model_classes=[], extra_kwargs={}):
|
| 311 |
+
print(f"Loading patch models from file: {file_path}")
|
| 312 |
+
model_names, models = load_patch_model_from_single_file(
|
| 313 |
+
state_dict, model_names, model_classes, extra_kwargs, self, self.torch_dtype, self.device)
|
| 314 |
+
for model_name, model in zip(model_names, models):
|
| 315 |
+
self.model.append(model)
|
| 316 |
+
self.model_path.append(file_path)
|
| 317 |
+
self.model_name.append(model_name)
|
| 318 |
+
print(f" The following patched models are loaded: {model_names}.")
|
| 319 |
+
|
| 320 |
+
|
| 321 |
+
def load_lora(self, file_path="", state_dict={}, lora_alpha=1.0):
|
| 322 |
+
if isinstance(file_path, list):
|
| 323 |
+
for file_path_ in file_path:
|
| 324 |
+
self.load_lora(file_path_, state_dict=state_dict, lora_alpha=lora_alpha)
|
| 325 |
+
else:
|
| 326 |
+
print(f"Loading LoRA models from file: {file_path}")
|
| 327 |
+
is_loaded = False
|
| 328 |
+
if len(state_dict) == 0:
|
| 329 |
+
state_dict = load_state_dict(file_path)
|
| 330 |
+
for model_name, model, model_path in zip(self.model_name, self.model, self.model_path):
|
| 331 |
+
for lora in get_lora_loaders():
|
| 332 |
+
match_results = lora.match(model, state_dict)
|
| 333 |
+
if match_results is not None:
|
| 334 |
+
print(f" Adding LoRA to {model_name} ({model_path}).")
|
| 335 |
+
lora_prefix, model_resource = match_results
|
| 336 |
+
lora.load(model, state_dict, lora_prefix, alpha=lora_alpha, model_resource=model_resource)
|
| 337 |
+
is_loaded = True
|
| 338 |
+
break
|
| 339 |
+
if not is_loaded:
|
| 340 |
+
print(f" Cannot load LoRA: {file_path}")
|
| 341 |
+
|
| 342 |
+
|
| 343 |
+
def load_model(self, file_path, model_names=None, device=None, torch_dtype=None):
|
| 344 |
+
#print(f"Loading models from: {file_path}")
|
| 345 |
+
if device is None: device = self.device
|
| 346 |
+
if torch_dtype is None: torch_dtype = self.torch_dtype
|
| 347 |
+
if isinstance(file_path, list):
|
| 348 |
+
state_dict = {}
|
| 349 |
+
for path in file_path:
|
| 350 |
+
state_dict.update(load_state_dict(path))
|
| 351 |
+
elif os.path.isfile(file_path):
|
| 352 |
+
state_dict = load_state_dict(file_path)
|
| 353 |
+
else:
|
| 354 |
+
state_dict = None
|
| 355 |
+
for model_detector in self.model_detector:
|
| 356 |
+
if model_detector.match(file_path, state_dict):
|
| 357 |
+
model_names, models = model_detector.load(
|
| 358 |
+
file_path, state_dict,
|
| 359 |
+
device=device, torch_dtype=torch_dtype,
|
| 360 |
+
allowed_model_names=model_names, model_manager=self
|
| 361 |
+
)
|
| 362 |
+
for model_name, model in zip(model_names, models):
|
| 363 |
+
self.model.append(model)
|
| 364 |
+
self.model_path.append(file_path)
|
| 365 |
+
self.model_name.append(model_name)
|
| 366 |
+
#print(f" The following models are loaded: {model_names}.")
|
| 367 |
+
break
|
| 368 |
+
else:
|
| 369 |
+
print(f" We cannot detect the model type. No models are loaded.")
|
| 370 |
+
|
| 371 |
+
|
| 372 |
+
def load_models(self, file_path_list, model_names=None, device=None, torch_dtype=None):
|
| 373 |
+
for file_path in file_path_list:
|
| 374 |
+
self.load_model(file_path, model_names, device=device, torch_dtype=torch_dtype)
|
| 375 |
+
|
| 376 |
+
|
| 377 |
+
def fetch_model(self, model_name, file_path=None, require_model_path=False):
|
| 378 |
+
fetched_models = []
|
| 379 |
+
fetched_model_paths = []
|
| 380 |
+
for model, model_path, model_name_ in zip(self.model, self.model_path, self.model_name):
|
| 381 |
+
if file_path is not None and file_path != model_path:
|
| 382 |
+
continue
|
| 383 |
+
if model_name == model_name_:
|
| 384 |
+
fetched_models.append(model)
|
| 385 |
+
fetched_model_paths.append(model_path)
|
| 386 |
+
if len(fetched_models) == 0:
|
| 387 |
+
#print(f"No {model_name} models available.")
|
| 388 |
+
return None
|
| 389 |
+
if len(fetched_models) == 1:
|
| 390 |
+
print(f"Using {model_name} from {fetched_model_paths[0]}")
|
| 391 |
+
else:
|
| 392 |
+
print(f"More than one {model_name} models are loaded in model manager: {fetched_model_paths}. Using {model_name} from {fetched_model_paths[0]}")
|
| 393 |
+
if require_model_path:
|
| 394 |
+
return fetched_models[0], fetched_model_paths[0]
|
| 395 |
+
else:
|
| 396 |
+
return fetched_models[0]
|
| 397 |
+
|
| 398 |
+
|
| 399 |
+
def to(self, device):
|
| 400 |
+
for model in self.model:
|
| 401 |
+
model.to(device)
|
| 402 |
+
|
custom_nodes/ComfyUI-FlashVSR_Ultra_Fast/src/models/sparse_sage/LICENSE.txt
ADDED
|
@@ -0,0 +1,201 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
Apache License
|
| 2 |
+
Version 2.0, January 2004
|
| 3 |
+
http://www.apache.org/licenses/
|
| 4 |
+
|
| 5 |
+
TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
|
| 6 |
+
|
| 7 |
+
1. Definitions.
|
| 8 |
+
|
| 9 |
+
"License" shall mean the terms and conditions for use, reproduction,
|
| 10 |
+
and distribution as defined by Sections 1 through 9 of this document.
|
| 11 |
+
|
| 12 |
+
"Licensor" shall mean the copyright owner or entity authorized by
|
| 13 |
+
the copyright owner that is granting the License.
|
| 14 |
+
|
| 15 |
+
"Legal Entity" shall mean the union of the acting entity and all
|
| 16 |
+
other entities that control, are controlled by, or are under common
|
| 17 |
+
control with that entity. For the purposes of this definition,
|
| 18 |
+
"control" means (i) the power, direct or indirect, to cause the
|
| 19 |
+
direction or management of such entity, whether by contract or
|
| 20 |
+
otherwise, or (ii) ownership of fifty percent (50%) or more of the
|
| 21 |
+
outstanding shares, or (iii) beneficial ownership of such entity.
|
| 22 |
+
|
| 23 |
+
"You" (or "Your") shall mean an individual or Legal Entity
|
| 24 |
+
exercising permissions granted by this License.
|
| 25 |
+
|
| 26 |
+
"Source" form shall mean the preferred form for making modifications,
|
| 27 |
+
including but not limited to software source code, documentation
|
| 28 |
+
source, and configuration files.
|
| 29 |
+
|
| 30 |
+
"Object" form shall mean any form resulting from mechanical
|
| 31 |
+
transformation or translation of a Source form, including but
|
| 32 |
+
not limited to compiled object code, generated documentation,
|
| 33 |
+
and conversions to other media types.
|
| 34 |
+
|
| 35 |
+
"Work" shall mean the work of authorship, whether in Source or
|
| 36 |
+
Object form, made available under the License, as indicated by a
|
| 37 |
+
copyright notice that is included in or attached to the work
|
| 38 |
+
(an example is provided in the Appendix below).
|
| 39 |
+
|
| 40 |
+
"Derivative Works" shall mean any work, whether in Source or Object
|
| 41 |
+
form, that is based on (or derived from) the Work and for which the
|
| 42 |
+
editorial revisions, annotations, elaborations, or other modifications
|
| 43 |
+
represent, as a whole, an original work of authorship. For the purposes
|
| 44 |
+
of this License, Derivative Works shall not include works that remain
|
| 45 |
+
separable from, or merely link (or bind by name) to the interfaces of,
|
| 46 |
+
the Work and Derivative Works thereof.
|
| 47 |
+
|
| 48 |
+
"Contribution" shall mean any work of authorship, including
|
| 49 |
+
the original version of the Work and any modifications or additions
|
| 50 |
+
to that Work or Derivative Works thereof, that is intentionally
|
| 51 |
+
submitted to Licensor for inclusion in the Work by the copyright owner
|
| 52 |
+
or by an individual or Legal Entity authorized to submit on behalf of
|
| 53 |
+
the copyright owner. For the purposes of this definition, "submitted"
|
| 54 |
+
means any form of electronic, verbal, or written communication sent
|
| 55 |
+
to the Licensor or its representatives, including but not limited to
|
| 56 |
+
communication on electronic mailing lists, source code control systems,
|
| 57 |
+
and issue tracking systems that are managed by, or on behalf of, the
|
| 58 |
+
Licensor for the purpose of discussing and improving the Work, but
|
| 59 |
+
excluding communication that is conspicuously marked or otherwise
|
| 60 |
+
designated in writing by the copyright owner as "Not a Contribution."
|
| 61 |
+
|
| 62 |
+
"Contributor" shall mean Licensor and any individual or Legal Entity
|
| 63 |
+
on behalf of whom a Contribution has been received by Licensor and
|
| 64 |
+
subsequently incorporated within the Work.
|
| 65 |
+
|
| 66 |
+
2. Grant of Copyright License. Subject to the terms and conditions of
|
| 67 |
+
this License, each Contributor hereby grants to You a perpetual,
|
| 68 |
+
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
|
| 69 |
+
copyright license to reproduce, prepare Derivative Works of,
|
| 70 |
+
publicly display, publicly perform, sublicense, and distribute the
|
| 71 |
+
Work and such Derivative Works in Source or Object form.
|
| 72 |
+
|
| 73 |
+
3. Grant of Patent License. Subject to the terms and conditions of
|
| 74 |
+
this License, each Contributor hereby grants to You a perpetual,
|
| 75 |
+
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
|
| 76 |
+
(except as stated in this section) patent license to make, have made,
|
| 77 |
+
use, offer to sell, sell, import, and otherwise transfer the Work,
|
| 78 |
+
where such license applies only to those patent claims licensable
|
| 79 |
+
by such Contributor that are necessarily infringed by their
|
| 80 |
+
Contribution(s) alone or by combination of their Contribution(s)
|
| 81 |
+
with the Work to which such Contribution(s) was submitted. If You
|
| 82 |
+
institute patent litigation against any entity (including a
|
| 83 |
+
cross-claim or counterclaim in a lawsuit) alleging that the Work
|
| 84 |
+
or a Contribution incorporated within the Work constitutes direct
|
| 85 |
+
or contributory patent infringement, then any patent licenses
|
| 86 |
+
granted to You under this License for that Work shall terminate
|
| 87 |
+
as of the date such litigation is filed.
|
| 88 |
+
|
| 89 |
+
4. Redistribution. You may reproduce and distribute copies of the
|
| 90 |
+
Work or Derivative Works thereof in any medium, with or without
|
| 91 |
+
modifications, and in Source or Object form, provided that You
|
| 92 |
+
meet the following conditions:
|
| 93 |
+
|
| 94 |
+
(a) You must give any other recipients of the Work or
|
| 95 |
+
Derivative Works a copy of this License; and
|
| 96 |
+
|
| 97 |
+
(b) You must cause any modified files to carry prominent notices
|
| 98 |
+
stating that You changed the files; and
|
| 99 |
+
|
| 100 |
+
(c) You must retain, in the Source form of any Derivative Works
|
| 101 |
+
that You distribute, all copyright, patent, trademark, and
|
| 102 |
+
attribution notices from the Source form of the Work,
|
| 103 |
+
excluding those notices that do not pertain to any part of
|
| 104 |
+
the Derivative Works; and
|
| 105 |
+
|
| 106 |
+
(d) If the Work includes a "NOTICE" text file as part of its
|
| 107 |
+
distribution, then any Derivative Works that You distribute must
|
| 108 |
+
include a readable copy of the attribution notices contained
|
| 109 |
+
within such NOTICE file, excluding those notices that do not
|
| 110 |
+
pertain to any part of the Derivative Works, in at least one
|
| 111 |
+
of the following places: within a NOTICE text file distributed
|
| 112 |
+
as part of the Derivative Works; within the Source form or
|
| 113 |
+
documentation, if provided along with the Derivative Works; or,
|
| 114 |
+
within a display generated by the Derivative Works, if and
|
| 115 |
+
wherever such third-party notices normally appear. The contents
|
| 116 |
+
of the NOTICE file are for informational purposes only and
|
| 117 |
+
do not modify the License. You may add Your own attribution
|
| 118 |
+
notices within Derivative Works that You distribute, alongside
|
| 119 |
+
or as an addendum to the NOTICE text from the Work, provided
|
| 120 |
+
that such additional attribution notices cannot be construed
|
| 121 |
+
as modifying the License.
|
| 122 |
+
|
| 123 |
+
You may add Your own copyright statement to Your modifications and
|
| 124 |
+
may provide additional or different license terms and conditions
|
| 125 |
+
for use, reproduction, or distribution of Your modifications, or
|
| 126 |
+
for any such Derivative Works as a whole, provided Your use,
|
| 127 |
+
reproduction, and distribution of the Work otherwise complies with
|
| 128 |
+
the conditions stated in this License.
|
| 129 |
+
|
| 130 |
+
5. Submission of Contributions. Unless You explicitly state otherwise,
|
| 131 |
+
any Contribution intentionally submitted for inclusion in the Work
|
| 132 |
+
by You to the Licensor shall be under the terms and conditions of
|
| 133 |
+
this License, without any additional terms or conditions.
|
| 134 |
+
Notwithstanding the above, nothing herein shall supersede or modify
|
| 135 |
+
the terms of any separate license agreement you may have executed
|
| 136 |
+
with Licensor regarding such Contributions.
|
| 137 |
+
|
| 138 |
+
6. Trademarks. This License does not grant permission to use the trade
|
| 139 |
+
names, trademarks, service marks, or product names of the Licensor,
|
| 140 |
+
except as required for reasonable and customary use in describing the
|
| 141 |
+
origin of the Work and reproducing the content of the NOTICE file.
|
| 142 |
+
|
| 143 |
+
7. Disclaimer of Warranty. Unless required by applicable law or
|
| 144 |
+
agreed to in writing, Licensor provides the Work (and each
|
| 145 |
+
Contributor provides its Contributions) on an "AS IS" BASIS,
|
| 146 |
+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
|
| 147 |
+
implied, including, without limitation, any warranties or conditions
|
| 148 |
+
of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
|
| 149 |
+
PARTICULAR PURPOSE. You are solely responsible for determining the
|
| 150 |
+
appropriateness of using or redistributing the Work and assume any
|
| 151 |
+
risks associated with Your exercise of permissions under this License.
|
| 152 |
+
|
| 153 |
+
8. Limitation of Liability. In no event and under no legal theory,
|
| 154 |
+
whether in tort (including negligence), contract, or otherwise,
|
| 155 |
+
unless required by applicable law (such as deliberate and grossly
|
| 156 |
+
negligent acts) or agreed to in writing, shall any Contributor be
|
| 157 |
+
liable to You for damages, including any direct, indirect, special,
|
| 158 |
+
incidental, or consequential damages of any character arising as a
|
| 159 |
+
result of this License or out of the use or inability to use the
|
| 160 |
+
Work (including but not limited to damages for loss of goodwill,
|
| 161 |
+
work stoppage, computer failure or malfunction, or any and all
|
| 162 |
+
other commercial damages or losses), even if such Contributor
|
| 163 |
+
has been advised of the possibility of such damages.
|
| 164 |
+
|
| 165 |
+
9. Accepting Warranty or Additional Liability. While redistributing
|
| 166 |
+
the Work or Derivative Works thereof, You may choose to offer,
|
| 167 |
+
and charge a fee for, acceptance of support, warranty, indemnity,
|
| 168 |
+
or other liability obligations and/or rights consistent with this
|
| 169 |
+
License. However, in accepting such obligations, You may act only
|
| 170 |
+
on Your own behalf and on Your sole responsibility, not on behalf
|
| 171 |
+
of any other Contributor, and only if You agree to indemnify,
|
| 172 |
+
defend, and hold each Contributor harmless for any liability
|
| 173 |
+
incurred by, or claims asserted against, such Contributor by reason
|
| 174 |
+
of your accepting any such warranty or additional liability.
|
| 175 |
+
|
| 176 |
+
END OF TERMS AND CONDITIONS
|
| 177 |
+
|
| 178 |
+
APPENDIX: How to apply the Apache License to your work.
|
| 179 |
+
|
| 180 |
+
To apply the Apache License to your work, attach the following
|
| 181 |
+
boilerplate notice, with the fields enclosed by brackets "[]"
|
| 182 |
+
replaced with your own identifying information. (Don't include
|
| 183 |
+
the brackets!) The text should be enclosed in the appropriate
|
| 184 |
+
comment syntax for the file format. We also recommend that a
|
| 185 |
+
file or class name and description of purpose be included on the
|
| 186 |
+
same "printed page" as the copyright notice for easier
|
| 187 |
+
identification within third-party archives.
|
| 188 |
+
|
| 189 |
+
Copyright 2025 Jintao Zhang
|
| 190 |
+
|
| 191 |
+
Licensed under the Apache License, Version 2.0 (the "License");
|
| 192 |
+
you may not use this file except in compliance with the License.
|
| 193 |
+
You may obtain a copy of the License at
|
| 194 |
+
|
| 195 |
+
http://www.apache.org/licenses/LICENSE-2.0
|
| 196 |
+
|
| 197 |
+
Unless required by applicable law or agreed to in writing, software
|
| 198 |
+
distributed under the License is distributed on an "AS IS" BASIS,
|
| 199 |
+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 200 |
+
See the License for the specific language governing permissions and
|
| 201 |
+
limitations under the License.
|
custom_nodes/ComfyUI-FlashVSR_Ultra_Fast/src/models/sparse_sage/core.py
ADDED
|
@@ -0,0 +1,45 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
https://github.com/jt-zhang/Sparse_SageAttention_API
|
| 3 |
+
|
| 4 |
+
Copyright (c) 2024 by SageAttention team.
|
| 5 |
+
|
| 6 |
+
Licensed under the Apache License, Version 2.0 (the "License");
|
| 7 |
+
you may not use this file except in compliance with the License.
|
| 8 |
+
You may obtain a copy of the License at
|
| 9 |
+
|
| 10 |
+
http://www.apache.org/licenses/LICENSE-2.0
|
| 11 |
+
|
| 12 |
+
Unless required by applicable law or agreed to in writing, software
|
| 13 |
+
distributed under the License is distributed on an "AS IS" BASIS,
|
| 14 |
+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 15 |
+
See the License for the specific language governing permissions and
|
| 16 |
+
limitations under the License.
|
| 17 |
+
"""
|
| 18 |
+
|
| 19 |
+
from .quant_per_block import per_block_int8
|
| 20 |
+
from .sparse_int8_attn import forward as sparse_sageattn_fwd
|
| 21 |
+
import torch
|
| 22 |
+
|
| 23 |
+
|
| 24 |
+
def sparse_sageattn(q, k, v, mask_id = None, is_causal=False, tensor_layout="HND"):
|
| 25 |
+
if mask_id is None:
|
| 26 |
+
mask_id = torch.ones((q.shape[0], q.shape[1], (q.shape[2] + 128 - 1)//128, (q.shape[3] + 64 - 1)//64), dtype=torch.int8, device=q.device) # TODO
|
| 27 |
+
|
| 28 |
+
output_dtype = q.dtype
|
| 29 |
+
if output_dtype == torch.bfloat16 or output_dtype == torch.float32:
|
| 30 |
+
v = v.to(torch.float16)
|
| 31 |
+
|
| 32 |
+
seq_dim = 1 if tensor_layout == "NHD" else 2
|
| 33 |
+
km = k.mean(dim=seq_dim, keepdim=True)
|
| 34 |
+
# km = torch.zeros((k.size(0), k.size(1), 1, k.size(3)), dtype=torch.float16, device=k.device) # Placeholder for mean, not used in quantization
|
| 35 |
+
|
| 36 |
+
q_int8, q_scale, k_int8, k_scale = per_block_int8(q, k, km=km, tensor_layout=tensor_layout)
|
| 37 |
+
|
| 38 |
+
o = sparse_sageattn_fwd(
|
| 39 |
+
q_int8, k_int8, mask_id, v, q_scale, k_scale,
|
| 40 |
+
is_causal=is_causal, tensor_layout=tensor_layout, output_dtype=output_dtype
|
| 41 |
+
)
|
| 42 |
+
return o
|
| 43 |
+
|
| 44 |
+
|
| 45 |
+
# flops = 4 * q.size(0) * q.size(1) * q.size(2)**2 * q.size(3) / (2 if is_causal else 1)
|
custom_nodes/ComfyUI-FlashVSR_Ultra_Fast/src/models/sparse_sage/quant_per_block.py
ADDED
|
@@ -0,0 +1,101 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Copyright (c) 2024 by SageAttention team.
|
| 3 |
+
|
| 4 |
+
Licensed under the Apache License, Version 2.0 (the "License");
|
| 5 |
+
you may not use this file except in compliance with the License.
|
| 6 |
+
You may obtain a copy of the License at
|
| 7 |
+
|
| 8 |
+
http://www.apache.org/licenses/LICENSE-2.0
|
| 9 |
+
|
| 10 |
+
Unless required by applicable law or agreed to in writing, software
|
| 11 |
+
distributed under the License is distributed on an "AS IS" BASIS,
|
| 12 |
+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 13 |
+
See the License for the specific language governing permissions and
|
| 14 |
+
limitations under the License.
|
| 15 |
+
"""
|
| 16 |
+
|
| 17 |
+
import torch
|
| 18 |
+
import triton
|
| 19 |
+
import triton.language as tl
|
| 20 |
+
|
| 21 |
+
@triton.jit
|
| 22 |
+
def quant_per_block_int8_kernel(Input, Output, Scale, L,
|
| 23 |
+
stride_iz, stride_ih, stride_in,
|
| 24 |
+
stride_oz, stride_oh, stride_on,
|
| 25 |
+
stride_sz, stride_sh,
|
| 26 |
+
sm_scale,
|
| 27 |
+
C: tl.constexpr, BLK: tl.constexpr):
|
| 28 |
+
off_blk = tl.program_id(0)
|
| 29 |
+
off_h = tl.program_id(1)
|
| 30 |
+
off_b = tl.program_id(2)
|
| 31 |
+
|
| 32 |
+
offs_n = off_blk * BLK + tl.arange(0, BLK)
|
| 33 |
+
offs_k = tl.arange(0, C)
|
| 34 |
+
|
| 35 |
+
input_ptrs = Input + off_b * stride_iz + off_h * stride_ih + offs_n[:, None] * stride_in + offs_k[None, :]
|
| 36 |
+
output_ptrs = Output + off_b * stride_oz + off_h * stride_oh + offs_n[:, None] * stride_on + offs_k[None, :]
|
| 37 |
+
scale_ptrs = Scale + off_b * stride_sz + off_h * stride_sh + off_blk
|
| 38 |
+
|
| 39 |
+
x = tl.load(input_ptrs, mask=offs_n[:, None] < L)
|
| 40 |
+
x = x.to(tl.float32)
|
| 41 |
+
x *= sm_scale
|
| 42 |
+
scale = tl.max(tl.abs(x)) / 127.
|
| 43 |
+
x_int8 = x / scale
|
| 44 |
+
x_int8 += 0.5 * tl.where(x_int8 >= 0, 1, -1)
|
| 45 |
+
x_int8 = x_int8.to(tl.int8)
|
| 46 |
+
tl.store(output_ptrs, x_int8, mask=offs_n[:, None] < L)
|
| 47 |
+
tl.store(scale_ptrs, scale)
|
| 48 |
+
|
| 49 |
+
def per_block_int8(q, k, km=None, BLKQ=128, BLKK=64, sm_scale=None, tensor_layout="HND"):
|
| 50 |
+
q_int8 = torch.empty(q.shape, dtype=torch.int8, device=q.device)
|
| 51 |
+
k_int8 = torch.empty(k.shape, dtype=torch.int8, device=k.device)
|
| 52 |
+
|
| 53 |
+
if km is not None:
|
| 54 |
+
k = k - km
|
| 55 |
+
|
| 56 |
+
if tensor_layout == "HND":
|
| 57 |
+
b, h_qo, qo_len, head_dim = q.shape
|
| 58 |
+
_, h_kv, kv_len, _ = k.shape
|
| 59 |
+
|
| 60 |
+
stride_bz_q, stride_h_q, stride_seq_q = q.stride(0), q.stride(1), q.stride(2)
|
| 61 |
+
stride_bz_qo, stride_h_qo, stride_seq_qo = q_int8.stride(0), q_int8.stride(1), q_int8.stride(2)
|
| 62 |
+
stride_bz_k, stride_h_k, stride_seq_k = k.stride(0), k.stride(1), k.stride(2)
|
| 63 |
+
stride_bz_ko, stride_h_ko, stride_seq_ko = k_int8.stride(0), k_int8.stride(1), k_int8.stride(2)
|
| 64 |
+
elif tensor_layout == "NHD":
|
| 65 |
+
b, qo_len, h_qo, head_dim = q.shape
|
| 66 |
+
_, kv_len, h_kv, _ = k.shape
|
| 67 |
+
|
| 68 |
+
stride_bz_q, stride_h_q, stride_seq_q = q.stride(0), q.stride(2), q.stride(1)
|
| 69 |
+
stride_bz_qo, stride_h_qo, stride_seq_qo = q_int8.stride(0), q_int8.stride(2), q_int8.stride(1)
|
| 70 |
+
stride_bz_k, stride_h_k, stride_seq_k = k.stride(0), k.stride(2), k.stride(1)
|
| 71 |
+
stride_bz_ko, stride_h_ko, stride_seq_ko = k_int8.stride(0), k_int8.stride(2), k_int8.stride(1)
|
| 72 |
+
else:
|
| 73 |
+
raise ValueError(f"Unknown tensor layout: {tensor_layout}")
|
| 74 |
+
|
| 75 |
+
q_scale = torch.empty((b, h_qo, (qo_len + BLKQ - 1) // BLKQ), device=q.device, dtype=torch.float32)
|
| 76 |
+
k_scale = torch.empty((b, h_kv, (kv_len + BLKK - 1) // BLKK), device=q.device, dtype=torch.float32)
|
| 77 |
+
|
| 78 |
+
if sm_scale is None:
|
| 79 |
+
sm_scale = head_dim**-0.5
|
| 80 |
+
|
| 81 |
+
grid = ((qo_len + BLKQ - 1) // BLKQ, h_qo, b)
|
| 82 |
+
quant_per_block_int8_kernel[grid](
|
| 83 |
+
q, q_int8, q_scale, qo_len,
|
| 84 |
+
stride_bz_q, stride_h_q, stride_seq_q,
|
| 85 |
+
stride_bz_qo, stride_h_qo, stride_seq_qo,
|
| 86 |
+
q_scale.stride(0), q_scale.stride(1),
|
| 87 |
+
sm_scale=(sm_scale * 1.44269504),
|
| 88 |
+
C=head_dim, BLK=BLKQ
|
| 89 |
+
)
|
| 90 |
+
|
| 91 |
+
grid = ((kv_len + BLKK - 1) // BLKK, h_kv, b)
|
| 92 |
+
quant_per_block_int8_kernel[grid](
|
| 93 |
+
k, k_int8, k_scale, kv_len,
|
| 94 |
+
stride_bz_k, stride_h_k, stride_seq_k,
|
| 95 |
+
stride_bz_ko, stride_h_ko, stride_seq_ko,
|
| 96 |
+
k_scale.stride(0), k_scale.stride(1),
|
| 97 |
+
sm_scale=1.0,
|
| 98 |
+
C=head_dim, BLK=BLKK
|
| 99 |
+
)
|
| 100 |
+
|
| 101 |
+
return q_int8, q_scale, k_int8, k_scale
|
custom_nodes/ComfyUI-FlashVSR_Ultra_Fast/src/models/sparse_sage/sparse_int8_attn.py
ADDED
|
@@ -0,0 +1,162 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Copyright (c) 2024 by SageAttention team.
|
| 3 |
+
|
| 4 |
+
Licensed under the Apache License, Version 2.0 (the "License");
|
| 5 |
+
you may not use this file except in compliance with the License.
|
| 6 |
+
You may obtain a copy of the License at
|
| 7 |
+
|
| 8 |
+
http://www.apache.org/licenses/LICENSE-2.0
|
| 9 |
+
|
| 10 |
+
Unless required by applicable law or agreed to in writing, software
|
| 11 |
+
distributed under the License is distributed on an "AS IS" BASIS,
|
| 12 |
+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 13 |
+
See the License for the specific language governing permissions and
|
| 14 |
+
limitations under the License.
|
| 15 |
+
"""
|
| 16 |
+
|
| 17 |
+
import torch, math
|
| 18 |
+
import triton
|
| 19 |
+
import triton.language as tl
|
| 20 |
+
import torch.nn.functional as F
|
| 21 |
+
|
| 22 |
+
@triton.jit
|
| 23 |
+
def _attn_fwd_inner(acc, l_i, old_m, q, q_scale, kv_len,
|
| 24 |
+
K_ptrs, K_bid_ptr, K_scale_ptr, V_ptrs, stride_kn, stride_vn, start_m,
|
| 25 |
+
BLOCK_M: tl.constexpr, HEAD_DIM: tl.constexpr, BLOCK_N: tl.constexpr,
|
| 26 |
+
STAGE: tl.constexpr, offs_m: tl.constexpr, offs_n: tl.constexpr,
|
| 27 |
+
):
|
| 28 |
+
if STAGE == 1:
|
| 29 |
+
lo, hi = 0, start_m * BLOCK_M
|
| 30 |
+
elif STAGE == 2:
|
| 31 |
+
lo, hi = start_m * BLOCK_M, (start_m + 1) * BLOCK_M
|
| 32 |
+
lo = tl.multiple_of(lo, BLOCK_M)
|
| 33 |
+
K_scale_ptr += lo // BLOCK_N
|
| 34 |
+
K_ptrs += stride_kn * lo
|
| 35 |
+
V_ptrs += stride_vn * lo
|
| 36 |
+
elif STAGE == 3:
|
| 37 |
+
lo, hi = 0, kv_len
|
| 38 |
+
for start_n in range(lo, hi, BLOCK_N):
|
| 39 |
+
kbid = tl.load(K_bid_ptr + start_n//BLOCK_N)
|
| 40 |
+
if kbid:
|
| 41 |
+
k_mask = offs_n[None, :] < (kv_len - start_n)
|
| 42 |
+
k = tl.load(K_ptrs, mask = k_mask)
|
| 43 |
+
k_scale = tl.load(K_scale_ptr)
|
| 44 |
+
qk = tl.dot(q, k).to(tl.float32) * q_scale * k_scale
|
| 45 |
+
if STAGE == 2:
|
| 46 |
+
mask = offs_m[:, None] >= (start_n + offs_n[None, :])
|
| 47 |
+
qk = qk + tl.where(mask, 0, -1.0e6)
|
| 48 |
+
local_m = tl.max(qk, 1)
|
| 49 |
+
new_m = tl.maximum(old_m, local_m)
|
| 50 |
+
qk -= new_m[:, None]
|
| 51 |
+
else:
|
| 52 |
+
local_m = tl.max(qk, 1)
|
| 53 |
+
new_m = tl.maximum(old_m, local_m)
|
| 54 |
+
qk = qk - new_m[:, None]
|
| 55 |
+
|
| 56 |
+
p = tl.math.exp2(qk)
|
| 57 |
+
l_ij = tl.sum(p, 1)
|
| 58 |
+
alpha = tl.math.exp2(old_m - new_m)
|
| 59 |
+
l_i = l_i * alpha + l_ij
|
| 60 |
+
acc = acc * alpha[:, None]
|
| 61 |
+
v = tl.load(V_ptrs, mask = offs_n[:, None] < (kv_len - start_n))
|
| 62 |
+
p = p.to(tl.float16)
|
| 63 |
+
acc += tl.dot(p, v, out_dtype=tl.float16)
|
| 64 |
+
old_m = new_m
|
| 65 |
+
K_ptrs += BLOCK_N * stride_kn
|
| 66 |
+
K_scale_ptr += 1
|
| 67 |
+
V_ptrs += BLOCK_N * stride_vn
|
| 68 |
+
return acc, l_i, old_m
|
| 69 |
+
|
| 70 |
+
@triton.jit
|
| 71 |
+
def _attn_fwd(Q, K, K_blkid, V, Q_scale, K_scale, Out,
|
| 72 |
+
stride_qz, stride_qh, stride_qn,
|
| 73 |
+
stride_kz, stride_kh, stride_kn,
|
| 74 |
+
stride_vz, stride_vh, stride_vn,
|
| 75 |
+
stride_oz, stride_oh, stride_on,
|
| 76 |
+
stride_kbidq, stride_kbidk,
|
| 77 |
+
qo_len, kv_len, H:tl.constexpr, num_kv_groups:tl.constexpr,
|
| 78 |
+
HEAD_DIM: tl.constexpr,
|
| 79 |
+
BLOCK_M: tl.constexpr,
|
| 80 |
+
BLOCK_N: tl.constexpr,
|
| 81 |
+
STAGE: tl.constexpr
|
| 82 |
+
):
|
| 83 |
+
start_m = tl.program_id(0)
|
| 84 |
+
off_z = tl.program_id(2).to(tl.int64)
|
| 85 |
+
off_h = tl.program_id(1).to(tl.int64)
|
| 86 |
+
q_scale_offset = (off_z * H + off_h) * tl.cdiv(qo_len, BLOCK_M)
|
| 87 |
+
k_scale_offset = (off_z * (H // num_kv_groups) + off_h // num_kv_groups) * tl.cdiv(kv_len, BLOCK_N)
|
| 88 |
+
k_bid_offset = (off_z * (H // num_kv_groups) + off_h // num_kv_groups) * stride_kbidq
|
| 89 |
+
offs_m = start_m * BLOCK_M + tl.arange(0, BLOCK_M)
|
| 90 |
+
offs_n = tl.arange(0, BLOCK_N)
|
| 91 |
+
offs_k = tl.arange(0, HEAD_DIM)
|
| 92 |
+
Q_ptrs = Q + (off_z * stride_qz + off_h * stride_qh) + offs_m[:, None] * stride_qn + offs_k[None, :]
|
| 93 |
+
Q_scale_ptr = Q_scale + q_scale_offset + start_m
|
| 94 |
+
K_ptrs = K + (off_z * stride_kz + (off_h // num_kv_groups) * stride_kh) + offs_n[None, :] * stride_kn + offs_k[:, None]
|
| 95 |
+
K_scale_ptr = K_scale + k_scale_offset
|
| 96 |
+
K_bid_ptr = K_blkid + k_bid_offset + start_m * stride_kbidk
|
| 97 |
+
V_ptrs = V + (off_z * stride_vz + (off_h // num_kv_groups) * stride_vh) + offs_n[:, None] * stride_vn + offs_k[None, :]
|
| 98 |
+
O_block_ptr = Out + (off_z * stride_oz + off_h * stride_oh) + offs_m[:, None] * stride_on + offs_k[None, :]
|
| 99 |
+
m_i = tl.zeros([BLOCK_M], dtype=tl.float32) - float("inf")
|
| 100 |
+
l_i = tl.zeros([BLOCK_M], dtype=tl.float32) + 1.0
|
| 101 |
+
acc = tl.zeros([BLOCK_M, HEAD_DIM], dtype=tl.float32)
|
| 102 |
+
q = tl.load(Q_ptrs, mask = offs_m[:, None] < qo_len)
|
| 103 |
+
q_scale = tl.load(Q_scale_ptr)
|
| 104 |
+
acc, l_i, m_i = _attn_fwd_inner(acc, l_i, m_i, q, q_scale, kv_len, K_ptrs, K_bid_ptr, K_scale_ptr, V_ptrs, stride_kn, stride_vn,
|
| 105 |
+
start_m,
|
| 106 |
+
BLOCK_M, HEAD_DIM, BLOCK_N,
|
| 107 |
+
4 - STAGE, offs_m, offs_n
|
| 108 |
+
)
|
| 109 |
+
if STAGE != 1:
|
| 110 |
+
acc, l_i, _ = _attn_fwd_inner(acc, l_i, m_i, q, q_scale, kv_len, K_ptrs, K_bid_ptr, K_scale_ptr, V_ptrs, stride_kn, stride_vn,
|
| 111 |
+
start_m,
|
| 112 |
+
BLOCK_M, HEAD_DIM, BLOCK_N,
|
| 113 |
+
2, offs_m, offs_n
|
| 114 |
+
)
|
| 115 |
+
acc = acc / l_i[:, None]
|
| 116 |
+
tl.store(O_block_ptr, acc.to(Out.type.element_ty), mask = (offs_m[:, None] < qo_len))
|
| 117 |
+
|
| 118 |
+
|
| 119 |
+
def forward(q, k, k_block_id, v, q_scale, k_scale, is_causal=False, tensor_layout="HND", output_dtype=torch.float16):
|
| 120 |
+
BLOCK_M = 128
|
| 121 |
+
BLOCK_N = 64
|
| 122 |
+
stage = 3 if is_causal else 1
|
| 123 |
+
o = torch.empty(q.shape, dtype=output_dtype, device=q.device)
|
| 124 |
+
|
| 125 |
+
if tensor_layout == "HND":
|
| 126 |
+
b, h_qo, qo_len, head_dim = q.shape
|
| 127 |
+
_, h_kv, kv_len, _ = k.shape
|
| 128 |
+
stride_bz_q, stride_h_q, stride_seq_q = q.stride(0), q.stride(1), q.stride(2)
|
| 129 |
+
stride_bz_k, stride_h_k, stride_seq_k = k.stride(0), k.stride(1), k.stride(2)
|
| 130 |
+
stride_bz_v, stride_h_v, stride_seq_v = v.stride(0), v.stride(1), v.stride(2)
|
| 131 |
+
stride_bz_o, stride_h_o, stride_seq_o = o.stride(0), o.stride(1), o.stride(2)
|
| 132 |
+
elif tensor_layout == "NHD":
|
| 133 |
+
b, qo_len, h_qo, head_dim = q.shape
|
| 134 |
+
_, kv_len, h_kv, _ = k.shape
|
| 135 |
+
stride_bz_q, stride_h_q, stride_seq_q = q.stride(0), q.stride(2), q.stride(1)
|
| 136 |
+
stride_bz_k, stride_h_k, stride_seq_k = k.stride(0), k.stride(2), k.stride(1)
|
| 137 |
+
stride_bz_v, stride_h_v, stride_seq_v = v.stride(0), v.stride(2), v.stride(1)
|
| 138 |
+
stride_bz_o, stride_h_o, stride_seq_o = o.stride(0), o.stride(2), o.stride(1)
|
| 139 |
+
else:
|
| 140 |
+
raise ValueError(f"tensor_layout {tensor_layout} not supported")
|
| 141 |
+
|
| 142 |
+
if is_causal:
|
| 143 |
+
assert qo_len == kv_len, "qo_len and kv_len must be equal for causal attention"
|
| 144 |
+
|
| 145 |
+
HEAD_DIM_K = head_dim
|
| 146 |
+
num_kv_groups = h_qo // h_kv
|
| 147 |
+
|
| 148 |
+
grid = (triton.cdiv(qo_len, BLOCK_M), h_qo, b )
|
| 149 |
+
_attn_fwd[grid](
|
| 150 |
+
q, k, k_block_id, v, q_scale, k_scale, o,
|
| 151 |
+
stride_bz_q, stride_h_q, stride_seq_q,
|
| 152 |
+
stride_bz_k, stride_h_k, stride_seq_k,
|
| 153 |
+
stride_bz_v, stride_h_v, stride_seq_v,
|
| 154 |
+
stride_bz_o, stride_h_o, stride_seq_o,
|
| 155 |
+
k_block_id.stride(1), k_block_id.stride(2),
|
| 156 |
+
qo_len, kv_len,
|
| 157 |
+
h_qo, num_kv_groups,
|
| 158 |
+
BLOCK_M=BLOCK_M, BLOCK_N=BLOCK_N, HEAD_DIM=HEAD_DIM_K,
|
| 159 |
+
STAGE=stage,
|
| 160 |
+
num_warps=4 if head_dim == 64 else 8,
|
| 161 |
+
num_stages=4)
|
| 162 |
+
return o
|
custom_nodes/ComfyUI-FlashVSR_Ultra_Fast/src/models/utils.py
ADDED
|
@@ -0,0 +1,462 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch, os, gc
|
| 2 |
+
from safetensors import safe_open
|
| 3 |
+
from contextlib import contextmanager
|
| 4 |
+
from einops import rearrange, repeat
|
| 5 |
+
import torch.nn as nn
|
| 6 |
+
import torch.nn.functional as F
|
| 7 |
+
from tqdm import tqdm
|
| 8 |
+
import time
|
| 9 |
+
import hashlib
|
| 10 |
+
|
| 11 |
+
CACHE_T = 2
|
| 12 |
+
|
| 13 |
+
@contextmanager
|
| 14 |
+
def init_weights_on_device(device = torch.device("meta"), include_buffers :bool = False):
|
| 15 |
+
|
| 16 |
+
old_register_parameter = torch.nn.Module.register_parameter
|
| 17 |
+
if include_buffers:
|
| 18 |
+
old_register_buffer = torch.nn.Module.register_buffer
|
| 19 |
+
|
| 20 |
+
def register_empty_parameter(module, name, param):
|
| 21 |
+
old_register_parameter(module, name, param)
|
| 22 |
+
if param is not None:
|
| 23 |
+
param_cls = type(module._parameters[name])
|
| 24 |
+
kwargs = module._parameters[name].__dict__
|
| 25 |
+
kwargs["requires_grad"] = param.requires_grad
|
| 26 |
+
module._parameters[name] = param_cls(module._parameters[name].to(device), **kwargs)
|
| 27 |
+
|
| 28 |
+
def register_empty_buffer(module, name, buffer, persistent=True):
|
| 29 |
+
old_register_buffer(module, name, buffer, persistent=persistent)
|
| 30 |
+
if buffer is not None:
|
| 31 |
+
module._buffers[name] = module._buffers[name].to(device)
|
| 32 |
+
|
| 33 |
+
def patch_tensor_constructor(fn):
|
| 34 |
+
def wrapper(*args, **kwargs):
|
| 35 |
+
kwargs["device"] = device
|
| 36 |
+
return fn(*args, **kwargs)
|
| 37 |
+
|
| 38 |
+
return wrapper
|
| 39 |
+
|
| 40 |
+
if include_buffers:
|
| 41 |
+
tensor_constructors_to_patch = {
|
| 42 |
+
torch_function_name: getattr(torch, torch_function_name)
|
| 43 |
+
for torch_function_name in ["empty", "zeros", "ones", "full"]
|
| 44 |
+
}
|
| 45 |
+
else:
|
| 46 |
+
tensor_constructors_to_patch = {}
|
| 47 |
+
|
| 48 |
+
try:
|
| 49 |
+
torch.nn.Module.register_parameter = register_empty_parameter
|
| 50 |
+
if include_buffers:
|
| 51 |
+
torch.nn.Module.register_buffer = register_empty_buffer
|
| 52 |
+
for torch_function_name in tensor_constructors_to_patch.keys():
|
| 53 |
+
setattr(torch, torch_function_name, patch_tensor_constructor(getattr(torch, torch_function_name)))
|
| 54 |
+
yield
|
| 55 |
+
finally:
|
| 56 |
+
torch.nn.Module.register_parameter = old_register_parameter
|
| 57 |
+
if include_buffers:
|
| 58 |
+
torch.nn.Module.register_buffer = old_register_buffer
|
| 59 |
+
for torch_function_name, old_torch_function in tensor_constructors_to_patch.items():
|
| 60 |
+
setattr(torch, torch_function_name, old_torch_function)
|
| 61 |
+
|
| 62 |
+
def load_state_dict_from_folder(file_path, torch_dtype=None):
|
| 63 |
+
state_dict = {}
|
| 64 |
+
for file_name in os.listdir(file_path):
|
| 65 |
+
if "." in file_name and file_name.split(".")[-1] in [
|
| 66 |
+
"safetensors", "bin", "ckpt", "pth", "pt"
|
| 67 |
+
]:
|
| 68 |
+
state_dict.update(load_state_dict(os.path.join(file_path, file_name), torch_dtype=torch_dtype))
|
| 69 |
+
return state_dict
|
| 70 |
+
|
| 71 |
+
|
| 72 |
+
def load_state_dict(file_path, torch_dtype=None):
|
| 73 |
+
if file_path.endswith(".safetensors"):
|
| 74 |
+
return load_state_dict_from_safetensors(file_path, torch_dtype=torch_dtype)
|
| 75 |
+
else:
|
| 76 |
+
return load_state_dict_from_bin(file_path, torch_dtype=torch_dtype)
|
| 77 |
+
|
| 78 |
+
|
| 79 |
+
def load_state_dict_from_safetensors(file_path, torch_dtype=None):
|
| 80 |
+
state_dict = {}
|
| 81 |
+
with safe_open(file_path, framework="pt", device="cpu") as f:
|
| 82 |
+
for k in f.keys():
|
| 83 |
+
state_dict[k] = f.get_tensor(k)
|
| 84 |
+
if torch_dtype is not None:
|
| 85 |
+
state_dict[k] = state_dict[k].to(torch_dtype)
|
| 86 |
+
return state_dict
|
| 87 |
+
|
| 88 |
+
|
| 89 |
+
def load_state_dict_from_bin(file_path, torch_dtype=None):
|
| 90 |
+
state_dict = torch.load(file_path, map_location="cpu", weights_only=True)
|
| 91 |
+
if torch_dtype is not None:
|
| 92 |
+
for i in state_dict:
|
| 93 |
+
if isinstance(state_dict[i], torch.Tensor):
|
| 94 |
+
state_dict[i] = state_dict[i].to(torch_dtype)
|
| 95 |
+
return state_dict
|
| 96 |
+
|
| 97 |
+
|
| 98 |
+
def search_for_embeddings(state_dict):
|
| 99 |
+
embeddings = []
|
| 100 |
+
for k in state_dict:
|
| 101 |
+
if isinstance(state_dict[k], torch.Tensor):
|
| 102 |
+
embeddings.append(state_dict[k])
|
| 103 |
+
elif isinstance(state_dict[k], dict):
|
| 104 |
+
embeddings += search_for_embeddings(state_dict[k])
|
| 105 |
+
return embeddings
|
| 106 |
+
|
| 107 |
+
|
| 108 |
+
def search_parameter(param, state_dict):
|
| 109 |
+
for name, param_ in state_dict.items():
|
| 110 |
+
if param.numel() == param_.numel():
|
| 111 |
+
if param.shape == param_.shape:
|
| 112 |
+
if torch.dist(param, param_) < 1e-3:
|
| 113 |
+
return name
|
| 114 |
+
else:
|
| 115 |
+
if torch.dist(param.flatten(), param_.flatten()) < 1e-3:
|
| 116 |
+
return name
|
| 117 |
+
return None
|
| 118 |
+
|
| 119 |
+
|
| 120 |
+
def build_rename_dict(source_state_dict, target_state_dict, split_qkv=False):
|
| 121 |
+
matched_keys = set()
|
| 122 |
+
with torch.no_grad():
|
| 123 |
+
for name in source_state_dict:
|
| 124 |
+
rename = search_parameter(source_state_dict[name], target_state_dict)
|
| 125 |
+
if rename is not None:
|
| 126 |
+
print(f'"{name}": "{rename}",')
|
| 127 |
+
matched_keys.add(rename)
|
| 128 |
+
elif split_qkv and len(source_state_dict[name].shape)>=1 and source_state_dict[name].shape[0]%3==0:
|
| 129 |
+
length = source_state_dict[name].shape[0] // 3
|
| 130 |
+
rename = []
|
| 131 |
+
for i in range(3):
|
| 132 |
+
rename.append(search_parameter(source_state_dict[name][i*length: i*length+length], target_state_dict))
|
| 133 |
+
if None not in rename:
|
| 134 |
+
print(f'"{name}": {rename},')
|
| 135 |
+
for rename_ in rename:
|
| 136 |
+
matched_keys.add(rename_)
|
| 137 |
+
for name in target_state_dict:
|
| 138 |
+
if name not in matched_keys:
|
| 139 |
+
print("Cannot find", name, target_state_dict[name].shape)
|
| 140 |
+
|
| 141 |
+
|
| 142 |
+
def search_for_files(folder, extensions):
|
| 143 |
+
files = []
|
| 144 |
+
if os.path.isdir(folder):
|
| 145 |
+
for file in sorted(os.listdir(folder)):
|
| 146 |
+
files += search_for_files(os.path.join(folder, file), extensions)
|
| 147 |
+
elif os.path.isfile(folder):
|
| 148 |
+
for extension in extensions:
|
| 149 |
+
if folder.endswith(extension):
|
| 150 |
+
files.append(folder)
|
| 151 |
+
break
|
| 152 |
+
return files
|
| 153 |
+
|
| 154 |
+
|
| 155 |
+
def convert_state_dict_keys_to_single_str(state_dict, with_shape=True):
|
| 156 |
+
keys = []
|
| 157 |
+
for key, value in state_dict.items():
|
| 158 |
+
if isinstance(key, str):
|
| 159 |
+
if isinstance(value, torch.Tensor):
|
| 160 |
+
if with_shape:
|
| 161 |
+
shape = "_".join(map(str, list(value.shape)))
|
| 162 |
+
keys.append(key + ":" + shape)
|
| 163 |
+
keys.append(key)
|
| 164 |
+
elif isinstance(value, dict):
|
| 165 |
+
keys.append(key + "|" + convert_state_dict_keys_to_single_str(value, with_shape=with_shape))
|
| 166 |
+
keys.sort()
|
| 167 |
+
keys_str = ",".join(keys)
|
| 168 |
+
return keys_str
|
| 169 |
+
|
| 170 |
+
|
| 171 |
+
def split_state_dict_with_prefix(state_dict):
|
| 172 |
+
keys = sorted([key for key in state_dict if isinstance(key, str)])
|
| 173 |
+
prefix_dict = {}
|
| 174 |
+
for key in keys:
|
| 175 |
+
prefix = key if "." not in key else key.split(".")[0]
|
| 176 |
+
if prefix not in prefix_dict:
|
| 177 |
+
prefix_dict[prefix] = []
|
| 178 |
+
prefix_dict[prefix].append(key)
|
| 179 |
+
state_dicts = []
|
| 180 |
+
for prefix, keys in prefix_dict.items():
|
| 181 |
+
sub_state_dict = {key: state_dict[key] for key in keys}
|
| 182 |
+
state_dicts.append(sub_state_dict)
|
| 183 |
+
return state_dicts
|
| 184 |
+
|
| 185 |
+
def hash_state_dict_keys(state_dict, with_shape=True):
|
| 186 |
+
keys_str = convert_state_dict_keys_to_single_str(state_dict, with_shape=with_shape)
|
| 187 |
+
keys_str = keys_str.encode(encoding="UTF-8")
|
| 188 |
+
return hashlib.md5(keys_str).hexdigest()
|
| 189 |
+
|
| 190 |
+
def clean_vram():
|
| 191 |
+
gc.collect()
|
| 192 |
+
if torch.cuda.is_available():
|
| 193 |
+
torch.cuda.empty_cache()
|
| 194 |
+
torch.cuda.ipc_collect()
|
| 195 |
+
if torch.backends.mps.is_available():
|
| 196 |
+
torch.mps.empty_cache()
|
| 197 |
+
|
| 198 |
+
def get_device_list():
|
| 199 |
+
devs = ["auto"]
|
| 200 |
+
try:
|
| 201 |
+
if hasattr(torch, "cuda") and hasattr(torch.cuda, "is_available") and torch.cuda.is_available():
|
| 202 |
+
devs += [f"cuda:{i}" for i in range(torch.cuda.device_count())]
|
| 203 |
+
except Exception:
|
| 204 |
+
pass
|
| 205 |
+
try:
|
| 206 |
+
if hasattr(torch, "mps") and hasattr(torch.mps, "is_available") and torch.backends.mps.is_available():
|
| 207 |
+
devs += [f"mps:{i}" for i in range(torch.mps.device_count())]
|
| 208 |
+
except Exception:
|
| 209 |
+
pass
|
| 210 |
+
return devs
|
| 211 |
+
|
| 212 |
+
class RMS_norm(nn.Module):
|
| 213 |
+
|
| 214 |
+
def __init__(self, dim, channel_first=True, images=True, bias=False):
|
| 215 |
+
super().__init__()
|
| 216 |
+
broadcastable_dims = (1, 1, 1) if not images else (1, 1)
|
| 217 |
+
shape = (dim, *broadcastable_dims) if channel_first else (dim,)
|
| 218 |
+
|
| 219 |
+
self.channel_first = channel_first
|
| 220 |
+
self.scale = dim**0.5
|
| 221 |
+
self.gamma = nn.Parameter(torch.ones(shape))
|
| 222 |
+
self.bias = nn.Parameter(torch.zeros(shape)) if bias else 0.
|
| 223 |
+
|
| 224 |
+
def forward(self, x):
|
| 225 |
+
return F.normalize(
|
| 226 |
+
x, dim=(1 if self.channel_first else
|
| 227 |
+
-1)) * self.scale * self.gamma + self.bias
|
| 228 |
+
|
| 229 |
+
class CausalConv3d(nn.Conv3d):
|
| 230 |
+
"""
|
| 231 |
+
Causal 3d convolusion.
|
| 232 |
+
"""
|
| 233 |
+
|
| 234 |
+
def __init__(self, *args, **kwargs):
|
| 235 |
+
super().__init__(*args, **kwargs)
|
| 236 |
+
self._padding = (self.padding[2], self.padding[2], self.padding[1],
|
| 237 |
+
self.padding[1], 2 * self.padding[0], 0)
|
| 238 |
+
self.padding = (0, 0, 0)
|
| 239 |
+
|
| 240 |
+
def forward(self, x, cache_x=None):
|
| 241 |
+
padding = list(self._padding)
|
| 242 |
+
if cache_x is not None and self._padding[4] > 0:
|
| 243 |
+
cache_x = cache_x.to(x.device)
|
| 244 |
+
# print(cache_x.shape, x.shape)
|
| 245 |
+
x = torch.cat([cache_x, x], dim=2)
|
| 246 |
+
padding[4] -= cache_x.shape[2]
|
| 247 |
+
# print('cache!')
|
| 248 |
+
x = F.pad(x, padding, mode='replicate') # mode='replicate'
|
| 249 |
+
# print(x[0,0,:,0,0])
|
| 250 |
+
|
| 251 |
+
return super().forward(x)
|
| 252 |
+
|
| 253 |
+
class PixelShuffle3d(nn.Module):
|
| 254 |
+
def __init__(self, ff, hh, ww):
|
| 255 |
+
super().__init__()
|
| 256 |
+
self.ff = ff
|
| 257 |
+
self.hh = hh
|
| 258 |
+
self.ww = ww
|
| 259 |
+
|
| 260 |
+
def forward(self, x):
|
| 261 |
+
# x: (B, C, F, H, W)
|
| 262 |
+
return rearrange(x,
|
| 263 |
+
'b c (f ff) (h hh) (w ww) -> b (c ff hh ww) f h w',
|
| 264 |
+
ff=self.ff, hh=self.hh, ww=self.ww)
|
| 265 |
+
|
| 266 |
+
class Buffer_LQ4x_Proj(nn.Module):
|
| 267 |
+
|
| 268 |
+
def __init__(self, in_dim, out_dim, layer_num=30):
|
| 269 |
+
super().__init__()
|
| 270 |
+
self.ff = 1
|
| 271 |
+
self.hh = 16
|
| 272 |
+
self.ww = 16
|
| 273 |
+
self.hidden_dim1 = 2048
|
| 274 |
+
self.hidden_dim2 = 3072
|
| 275 |
+
self.layer_num = layer_num
|
| 276 |
+
|
| 277 |
+
self.pixel_shuffle = PixelShuffle3d(self.ff, self.hh, self.ww)
|
| 278 |
+
|
| 279 |
+
self.conv1 = CausalConv3d(in_dim*self.ff*self.hh*self.ww, self.hidden_dim1, (4, 3, 3), stride=(2, 1, 1), padding=(1, 1, 1)) # f -> f/2 h -> h w -> w
|
| 280 |
+
self.norm1 = RMS_norm(self.hidden_dim1, images=False)
|
| 281 |
+
self.act1 = nn.SiLU()
|
| 282 |
+
|
| 283 |
+
self.conv2 = CausalConv3d(self.hidden_dim1, self.hidden_dim2, (4, 3, 3), stride=(2, 1, 1), padding=(1, 1, 1)) # f -> f/2 h -> h w -> w
|
| 284 |
+
self.norm2 = RMS_norm(self.hidden_dim2, images=False)
|
| 285 |
+
self.act2 = nn.SiLU()
|
| 286 |
+
|
| 287 |
+
self.linear_layers = nn.ModuleList([nn.Linear(self.hidden_dim2, out_dim) for _ in range(layer_num)])
|
| 288 |
+
|
| 289 |
+
self.clip_idx = 0
|
| 290 |
+
|
| 291 |
+
def forward(self, video):
|
| 292 |
+
self.clear_cache()
|
| 293 |
+
# x: (B, C, F, H, W)
|
| 294 |
+
|
| 295 |
+
t = video.shape[2]
|
| 296 |
+
iter_ = 1 + (t - 1) // 4
|
| 297 |
+
first_frame = video[:, :, :1, :, :].repeat(1, 1, 3, 1, 1)
|
| 298 |
+
video = torch.cat([first_frame, video], dim=2)
|
| 299 |
+
# print(video.shape)
|
| 300 |
+
|
| 301 |
+
out_x = []
|
| 302 |
+
for i in range(iter_):
|
| 303 |
+
x = self.pixel_shuffle(video[:,:,i*4:(i+1)*4,:,:])
|
| 304 |
+
cache1_x = x[:, :, -CACHE_T:, :, :].clone()
|
| 305 |
+
self.cache['conv1'] = cache1_x
|
| 306 |
+
x = self.conv1(x, self.cache['conv1'])
|
| 307 |
+
x = self.norm1(x)
|
| 308 |
+
x = self.act1(x)
|
| 309 |
+
cache2_x = x[:, :, -CACHE_T:, :, :].clone()
|
| 310 |
+
self.cache['conv2'] = cache2_x
|
| 311 |
+
if i == 0:
|
| 312 |
+
continue
|
| 313 |
+
x = self.conv2(x, self.cache['conv2'])
|
| 314 |
+
x = self.norm2(x)
|
| 315 |
+
x = self.act2(x)
|
| 316 |
+
out_x.append(x)
|
| 317 |
+
out_x = torch.cat(out_x, dim = 2)
|
| 318 |
+
# print(out_x.shape)
|
| 319 |
+
out_x = rearrange(out_x, 'b c f h w -> b (f h w) c')
|
| 320 |
+
outputs = []
|
| 321 |
+
for i in range(self.layer_num):
|
| 322 |
+
outputs.append(self.linear_layers[i](out_x))
|
| 323 |
+
return outputs
|
| 324 |
+
|
| 325 |
+
def clear_cache(self):
|
| 326 |
+
self.cache = {}
|
| 327 |
+
self.cache['conv1'] = None
|
| 328 |
+
self.cache['conv2'] = None
|
| 329 |
+
self.clip_idx = 0
|
| 330 |
+
|
| 331 |
+
def stream_forward(self, video_clip):
|
| 332 |
+
if self.clip_idx == 0:
|
| 333 |
+
# self.clear_cache()
|
| 334 |
+
first_frame = video_clip[:, :, :1, :, :].repeat(1, 1, 3, 1, 1)
|
| 335 |
+
video_clip = torch.cat([first_frame, video_clip], dim=2)
|
| 336 |
+
x = self.pixel_shuffle(video_clip)
|
| 337 |
+
cache1_x = x[:, :, -CACHE_T:, :, :].clone()
|
| 338 |
+
self.cache['conv1'] = cache1_x
|
| 339 |
+
x = self.conv1(x, self.cache['conv1'])
|
| 340 |
+
x = self.norm1(x)
|
| 341 |
+
x = self.act1(x)
|
| 342 |
+
cache2_x = x[:, :, -CACHE_T:, :, :].clone()
|
| 343 |
+
self.cache['conv2'] = cache2_x
|
| 344 |
+
self.clip_idx += 1
|
| 345 |
+
return None
|
| 346 |
+
else:
|
| 347 |
+
x = self.pixel_shuffle(video_clip)
|
| 348 |
+
cache1_x = x[:, :, -CACHE_T:, :, :].clone()
|
| 349 |
+
self.cache['conv1'] = cache1_x
|
| 350 |
+
x = self.conv1(x, self.cache['conv1'])
|
| 351 |
+
x = self.norm1(x)
|
| 352 |
+
x = self.act1(x)
|
| 353 |
+
cache2_x = x[:, :, -CACHE_T:, :, :].clone()
|
| 354 |
+
self.cache['conv2'] = cache2_x
|
| 355 |
+
x = self.conv2(x, self.cache['conv2'])
|
| 356 |
+
x = self.norm2(x)
|
| 357 |
+
x = self.act2(x)
|
| 358 |
+
out_x = rearrange(x, 'b c f h w -> b (f h w) c')
|
| 359 |
+
outputs = []
|
| 360 |
+
for i in range(self.layer_num):
|
| 361 |
+
outputs.append(self.linear_layers[i](out_x))
|
| 362 |
+
self.clip_idx += 1
|
| 363 |
+
return outputs
|
| 364 |
+
|
| 365 |
+
class Causal_LQ4x_Proj(nn.Module):
|
| 366 |
+
|
| 367 |
+
def __init__(self, in_dim, out_dim, layer_num=30):
|
| 368 |
+
super().__init__()
|
| 369 |
+
self.ff = 1
|
| 370 |
+
self.hh = 16
|
| 371 |
+
self.ww = 16
|
| 372 |
+
self.hidden_dim1 = 2048
|
| 373 |
+
self.hidden_dim2 = 3072
|
| 374 |
+
self.layer_num = layer_num
|
| 375 |
+
|
| 376 |
+
self.pixel_shuffle = PixelShuffle3d(self.ff, self.hh, self.ww)
|
| 377 |
+
|
| 378 |
+
self.conv1 = CausalConv3d(in_dim*self.ff*self.hh*self.ww, self.hidden_dim1, (4, 3, 3), stride=(2, 1, 1), padding=(1, 1, 1)) # f -> f/2 h -> h w -> w
|
| 379 |
+
self.norm1 = RMS_norm(self.hidden_dim1, images=False)
|
| 380 |
+
self.act1 = nn.SiLU()
|
| 381 |
+
|
| 382 |
+
self.conv2 = CausalConv3d(self.hidden_dim1, self.hidden_dim2, (4, 3, 3), stride=(2, 1, 1), padding=(1, 1, 1)) # f -> f/2 h -> h w -> w
|
| 383 |
+
self.norm2 = RMS_norm(self.hidden_dim2, images=False)
|
| 384 |
+
self.act2 = nn.SiLU()
|
| 385 |
+
|
| 386 |
+
self.linear_layers = nn.ModuleList([nn.Linear(self.hidden_dim2, out_dim) for _ in range(layer_num)])
|
| 387 |
+
|
| 388 |
+
self.clip_idx = 0
|
| 389 |
+
|
| 390 |
+
def forward(self, video):
|
| 391 |
+
self.clear_cache()
|
| 392 |
+
# x: (B, C, F, H, W)
|
| 393 |
+
|
| 394 |
+
t = video.shape[2]
|
| 395 |
+
iter_ = 1 + (t - 1) // 4
|
| 396 |
+
first_frame = video[:, :, :1, :, :].repeat(1, 1, 3, 1, 1)
|
| 397 |
+
video = torch.cat([first_frame, video], dim=2)
|
| 398 |
+
# print(video.shape)
|
| 399 |
+
|
| 400 |
+
out_x = []
|
| 401 |
+
for i in range(iter_):
|
| 402 |
+
x = self.pixel_shuffle(video[:,:,i*4:(i+1)*4,:,:])
|
| 403 |
+
cache1_x = x[:, :, -CACHE_T:, :, :].clone()
|
| 404 |
+
x = self.conv1(x, self.cache['conv1'])
|
| 405 |
+
self.cache['conv1'] = cache1_x
|
| 406 |
+
x = self.norm1(x)
|
| 407 |
+
x = self.act1(x)
|
| 408 |
+
cache2_x = x[:, :, -CACHE_T:, :, :].clone()
|
| 409 |
+
if i == 0:
|
| 410 |
+
self.cache['conv2'] = cache2_x
|
| 411 |
+
continue
|
| 412 |
+
x = self.conv2(x, self.cache['conv2'])
|
| 413 |
+
self.cache['conv2'] = cache2_x
|
| 414 |
+
x = self.norm2(x)
|
| 415 |
+
x = self.act2(x)
|
| 416 |
+
out_x.append(x)
|
| 417 |
+
out_x = torch.cat(out_x, dim = 2)
|
| 418 |
+
out_x = rearrange(out_x, 'b c f h w -> b (f h w) c')
|
| 419 |
+
outputs = []
|
| 420 |
+
for i in range(self.layer_num):
|
| 421 |
+
outputs.append(self.linear_layers[i](out_x))
|
| 422 |
+
return outputs
|
| 423 |
+
|
| 424 |
+
def clear_cache(self):
|
| 425 |
+
self.cache = {}
|
| 426 |
+
self.cache['conv1'] = None
|
| 427 |
+
self.cache['conv2'] = None
|
| 428 |
+
self.clip_idx = 0
|
| 429 |
+
|
| 430 |
+
def stream_forward(self, video_clip):
|
| 431 |
+
if self.clip_idx == 0:
|
| 432 |
+
# self.clear_cache()
|
| 433 |
+
first_frame = video_clip[:, :, :1, :, :].repeat(1, 1, 3, 1, 1)
|
| 434 |
+
video_clip = torch.cat([first_frame, video_clip], dim=2)
|
| 435 |
+
x = self.pixel_shuffle(video_clip)
|
| 436 |
+
cache1_x = x[:, :, -CACHE_T:, :, :].clone()
|
| 437 |
+
x = self.conv1(x, self.cache['conv1'])
|
| 438 |
+
self.cache['conv1'] = cache1_x
|
| 439 |
+
x = self.norm1(x)
|
| 440 |
+
x = self.act1(x)
|
| 441 |
+
cache2_x = x[:, :, -CACHE_T:, :, :].clone()
|
| 442 |
+
self.cache['conv2'] = cache2_x
|
| 443 |
+
self.clip_idx += 1
|
| 444 |
+
return None
|
| 445 |
+
else:
|
| 446 |
+
x = self.pixel_shuffle(video_clip)
|
| 447 |
+
cache1_x = x[:, :, -CACHE_T:, :, :].clone()
|
| 448 |
+
x = self.conv1(x, self.cache['conv1'])
|
| 449 |
+
self.cache['conv1'] = cache1_x
|
| 450 |
+
x = self.norm1(x)
|
| 451 |
+
x = self.act1(x)
|
| 452 |
+
cache2_x = x[:, :, -CACHE_T:, :, :].clone()
|
| 453 |
+
x = self.conv2(x, self.cache['conv2'])
|
| 454 |
+
self.cache['conv2'] = cache2_x
|
| 455 |
+
x = self.norm2(x)
|
| 456 |
+
x = self.act2(x)
|
| 457 |
+
out_x = rearrange(x, 'b c f h w -> b (f h w) c')
|
| 458 |
+
outputs = []
|
| 459 |
+
for i in range(self.layer_num):
|
| 460 |
+
outputs.append(self.linear_layers[i](out_x))
|
| 461 |
+
self.clip_idx += 1
|
| 462 |
+
return outputs
|
custom_nodes/ComfyUI-FlashVSR_Ultra_Fast/src/models/wan_video_dit.py
ADDED
|
@@ -0,0 +1,864 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
|
| 2 |
+
import torch.nn as nn
|
| 3 |
+
import torch.nn.functional as F
|
| 4 |
+
import math
|
| 5 |
+
import random
|
| 6 |
+
import os
|
| 7 |
+
import time
|
| 8 |
+
from typing import Tuple, Optional, List
|
| 9 |
+
from einops import rearrange
|
| 10 |
+
from .utils import hash_state_dict_keys
|
| 11 |
+
|
| 12 |
+
try:
|
| 13 |
+
import flash_attn_interface
|
| 14 |
+
FLASH_ATTN_3_AVAILABLE = True
|
| 15 |
+
except ModuleNotFoundError:
|
| 16 |
+
FLASH_ATTN_3_AVAILABLE = False
|
| 17 |
+
|
| 18 |
+
try:
|
| 19 |
+
import flash_attn
|
| 20 |
+
FLASH_ATTN_2_AVAILABLE = True
|
| 21 |
+
except ModuleNotFoundError:
|
| 22 |
+
FLASH_ATTN_2_AVAILABLE = False
|
| 23 |
+
|
| 24 |
+
try:
|
| 25 |
+
from sageattention import sageattn
|
| 26 |
+
SAGE_ATTN_AVAILABLE = True
|
| 27 |
+
except ModuleNotFoundError:
|
| 28 |
+
SAGE_ATTN_AVAILABLE = False
|
| 29 |
+
|
| 30 |
+
try:
|
| 31 |
+
from block_sparse_attn import block_sparse_attn_func
|
| 32 |
+
BLOCK_ATTN_AVAILABLE = True
|
| 33 |
+
except:
|
| 34 |
+
BLOCK_ATTN_AVAILABLE = False
|
| 35 |
+
|
| 36 |
+
from .sparse_sage.core import sparse_sageattn
|
| 37 |
+
from PIL import Image
|
| 38 |
+
import numpy as np
|
| 39 |
+
|
| 40 |
+
USE_BLOCK_ATTN = False
|
| 41 |
+
|
| 42 |
+
# ----------------------------
|
| 43 |
+
# Local / window masks
|
| 44 |
+
# ----------------------------
|
| 45 |
+
@torch.no_grad()
|
| 46 |
+
def build_local_block_mask_shifted_vec(block_h: int,
|
| 47 |
+
block_w: int,
|
| 48 |
+
win_h: int = 6,
|
| 49 |
+
win_w: int = 6,
|
| 50 |
+
include_self: bool = True,
|
| 51 |
+
device=None) -> torch.Tensor:
|
| 52 |
+
device = device or torch.device("cpu")
|
| 53 |
+
H, W = block_h, block_w
|
| 54 |
+
r = torch.arange(H, device=device)
|
| 55 |
+
c = torch.arange(W, device=device)
|
| 56 |
+
YY, XX = torch.meshgrid(r, c, indexing="ij")
|
| 57 |
+
r_all = YY.reshape(-1)
|
| 58 |
+
c_all = XX.reshape(-1)
|
| 59 |
+
r_half = win_h // 2
|
| 60 |
+
c_half = win_w // 2
|
| 61 |
+
start_r = torch.clamp(r_all - r_half, 0, H - win_h)
|
| 62 |
+
end_r = start_r + win_h - 1
|
| 63 |
+
start_c = torch.clamp(c_all - c_half, 0, W - win_w)
|
| 64 |
+
end_c = start_c + win_w - 1
|
| 65 |
+
in_row = (r_all[None, :] >= start_r[:, None]) & (r_all[None, :] <= end_r[:, None])
|
| 66 |
+
in_col = (c_all[None, :] >= start_c[:, None]) & (c_all[None, :] <= end_c[:, None])
|
| 67 |
+
mask = in_row & in_col
|
| 68 |
+
if not include_self:
|
| 69 |
+
mask.fill_diagonal_(False)
|
| 70 |
+
return mask
|
| 71 |
+
|
| 72 |
+
@torch.no_grad()
|
| 73 |
+
def build_local_block_mask_shifted_vec_normal_slide(block_h: int,
|
| 74 |
+
block_w: int,
|
| 75 |
+
win_h: int = 6,
|
| 76 |
+
win_w: int = 6,
|
| 77 |
+
include_self: bool = True,
|
| 78 |
+
device=None) -> torch.Tensor:
|
| 79 |
+
device = device or torch.device("cpu")
|
| 80 |
+
H, W = block_h, block_w
|
| 81 |
+
r = torch.arange(H, device=device)
|
| 82 |
+
c = torch.arange(W, device=device)
|
| 83 |
+
YY, XX = torch.meshgrid(r, c, indexing="ij")
|
| 84 |
+
r_all = YY.reshape(-1)
|
| 85 |
+
c_all = XX.reshape(-1)
|
| 86 |
+
r_half = win_h // 2
|
| 87 |
+
c_half = win_w // 2
|
| 88 |
+
start_r = r_all - r_half
|
| 89 |
+
end_r = start_r + win_h - 1
|
| 90 |
+
start_c = c_all - c_half
|
| 91 |
+
end_c = start_c + win_w - 1
|
| 92 |
+
in_row = (r_all[None, :] >= start_r[:, None]) & (r_all[None, :] <= end_r[:, None])
|
| 93 |
+
in_col = (c_all[None, :] >= start_c[:, None]) & (c_all[None, :] <= end_c[:, None])
|
| 94 |
+
mask = in_row & in_col
|
| 95 |
+
if not include_self:
|
| 96 |
+
mask.fill_diagonal_(False)
|
| 97 |
+
return mask
|
| 98 |
+
|
| 99 |
+
|
| 100 |
+
class WindowPartition3D:
|
| 101 |
+
"""Partition / reverse-partition helpers for 5-D tensors (B,F,H,W,C)."""
|
| 102 |
+
@staticmethod
|
| 103 |
+
def partition(x: torch.Tensor, win: Tuple[int, int, int]):
|
| 104 |
+
B, F, H, W, C = x.shape
|
| 105 |
+
wf, wh, ww = win
|
| 106 |
+
assert F % wf == 0 and H % wh == 0 and W % ww == 0, "Dims must divide by window size."
|
| 107 |
+
x = x.view(B, F // wf, wf, H // wh, wh, W // ww, ww, C)
|
| 108 |
+
x = x.permute(0, 1, 3, 5, 2, 4, 6, 7).contiguous()
|
| 109 |
+
return x.view(-1, wf * wh * ww, C)
|
| 110 |
+
|
| 111 |
+
@staticmethod
|
| 112 |
+
def reverse(windows: torch.Tensor, win: Tuple[int, int, int], orig: Tuple[int, int, int]):
|
| 113 |
+
F, H, W = orig
|
| 114 |
+
wf, wh, ww = win
|
| 115 |
+
nf, nh, nw = F // wf, H // wh, W // ww
|
| 116 |
+
B = windows.size(0) // (nf * nh * nw)
|
| 117 |
+
x = windows.view(B, nf, nh, nw, wf, wh, ww, -1)
|
| 118 |
+
x = x.permute(0, 1, 4, 2, 5, 3, 6, 7).contiguous()
|
| 119 |
+
return x.view(B, F, H, W, -1)
|
| 120 |
+
|
| 121 |
+
|
| 122 |
+
@torch.no_grad()
|
| 123 |
+
def generate_draft_block_mask(batch_size, nheads, seqlen,
|
| 124 |
+
q_w, k_w, topk=10, local_attn_mask=None):
|
| 125 |
+
assert batch_size == 1, "Only batch_size=1 supported for now"
|
| 126 |
+
assert local_attn_mask is not None, "local_attn_mask must be provided"
|
| 127 |
+
avgpool_q = torch.mean(q_w, dim=1)
|
| 128 |
+
avgpool_k = torch.mean(k_w, dim=1)
|
| 129 |
+
avgpool_q = rearrange(avgpool_q, 's (h d) -> s h d', h=nheads)
|
| 130 |
+
avgpool_k = rearrange(avgpool_k, 's (h d) -> s h d', h=nheads)
|
| 131 |
+
q_heads = avgpool_q.permute(1, 0, 2)
|
| 132 |
+
k_heads = avgpool_k.permute(1, 0, 2)
|
| 133 |
+
D = avgpool_q.shape[-1]
|
| 134 |
+
scores = torch.einsum("hld,hmd->hlm", q_heads, k_heads) / math.sqrt(D)
|
| 135 |
+
|
| 136 |
+
repeat_head = scores.shape[0]
|
| 137 |
+
repeat_len = scores.shape[1] // local_attn_mask.shape[0]
|
| 138 |
+
repeat_num = scores.shape[2] // local_attn_mask.shape[1]
|
| 139 |
+
local_attn_mask = local_attn_mask.unsqueeze(1).unsqueeze(0).repeat(repeat_len, 1, repeat_num, 1)
|
| 140 |
+
local_attn_mask = rearrange(local_attn_mask, 'x a y b -> (x a) (y b)')
|
| 141 |
+
local_attn_mask = local_attn_mask.unsqueeze(0).repeat(repeat_head, 1, 1)
|
| 142 |
+
local_attn_mask = local_attn_mask.to(torch.float32)
|
| 143 |
+
local_attn_mask = local_attn_mask.masked_fill(local_attn_mask == False, -float('inf'))
|
| 144 |
+
local_attn_mask = local_attn_mask.masked_fill(local_attn_mask == True, 0)
|
| 145 |
+
scores = scores + local_attn_mask
|
| 146 |
+
|
| 147 |
+
attn_map = torch.softmax(scores, dim=-1)
|
| 148 |
+
attn_map = rearrange(attn_map, 'h (it s1) s2 -> (h it) s1 s2', it=seqlen)
|
| 149 |
+
loop_num, s1, s2 = attn_map.shape
|
| 150 |
+
flat = attn_map.reshape(loop_num, -1)
|
| 151 |
+
n = flat.shape[1]
|
| 152 |
+
apply_topk = min(flat.shape[1]-1, topk)
|
| 153 |
+
thresholds = torch.topk(flat, k=apply_topk + 1, dim=1, largest=True).values[:, -1]
|
| 154 |
+
thresholds = thresholds.unsqueeze(1)
|
| 155 |
+
mask_new = (flat > thresholds).reshape(loop_num, s1, s2)
|
| 156 |
+
mask_new = rearrange(mask_new, '(h it) s1 s2 -> h (it s1) s2', it=seqlen) # keep shape note
|
| 157 |
+
# 修正:上行变量名统一
|
| 158 |
+
# mask_new = rearrange(attn_map, 'h (it s1) s2 -> h (it s1) s2', it=seqlen) * 0 + mask_new
|
| 159 |
+
mask = mask_new.unsqueeze(0).repeat(batch_size, 1, 1, 1)
|
| 160 |
+
return mask
|
| 161 |
+
|
| 162 |
+
|
| 163 |
+
@torch.no_grad()
|
| 164 |
+
def generate_draft_block_mask_sage(batch_size, nheads, seqlen,
|
| 165 |
+
q_w, k_w, topk=10, local_attn_mask=None):
|
| 166 |
+
assert batch_size == 1, "Only batch_size=1 supported for now"
|
| 167 |
+
assert local_attn_mask is not None, "local_attn_mask must be provided"
|
| 168 |
+
|
| 169 |
+
avgpool_q = torch.mean(q_w, dim=1)
|
| 170 |
+
avgpool_q = rearrange(avgpool_q, 's (h d) -> s h d', h=nheads)
|
| 171 |
+
q_heads = avgpool_q.permute(1, 0, 2)
|
| 172 |
+
D = avgpool_q.shape[-1]
|
| 173 |
+
|
| 174 |
+
k_w_split = k_w.view(k_w.shape[0], 2, 64, k_w.shape[2])
|
| 175 |
+
avgpool_k_split = torch.mean(k_w_split, dim=2)
|
| 176 |
+
avgpool_k_refined = rearrange(avgpool_k_split, 's two d -> (s two) d', two=2) # shape: (s*2, C)
|
| 177 |
+
avgpool_k_refined = rearrange(avgpool_k_refined, 's (h d) -> s h d', h=nheads) # shape: (s*2, h, d)
|
| 178 |
+
k_heads_doubled = avgpool_k_refined.permute(1, 0, 2) # shape: (h, s*2, d)
|
| 179 |
+
|
| 180 |
+
k_heads_1, k_heads_2 = torch.chunk(k_heads_doubled, 2, dim=1)
|
| 181 |
+
scores_1 = torch.einsum("hld,hmd->hlm", q_heads, k_heads_1) / math.sqrt(D)
|
| 182 |
+
scores_2 = torch.einsum("hld,hmd->hlm", q_heads, k_heads_2) / math.sqrt(D)
|
| 183 |
+
scores = torch.cat([scores_1, scores_2], dim=-1)
|
| 184 |
+
|
| 185 |
+
repeat_head = scores.shape[0]
|
| 186 |
+
repeat_len = scores.shape[1] // local_attn_mask.shape[0]
|
| 187 |
+
repeat_num = (scores.shape[2] // 2) // local_attn_mask.shape[1]
|
| 188 |
+
|
| 189 |
+
local_attn_mask = local_attn_mask.unsqueeze(1).unsqueeze(0).repeat(repeat_len, 1, repeat_num, 1)
|
| 190 |
+
local_attn_mask = rearrange(local_attn_mask, 'x a y b -> (x a) (y b)')
|
| 191 |
+
local_attn_mask = local_attn_mask.repeat_interleave(2, dim=1)
|
| 192 |
+
local_attn_mask = local_attn_mask.unsqueeze(0).repeat(repeat_head, 1, 1)
|
| 193 |
+
|
| 194 |
+
assert scores.shape == local_attn_mask.shape, \
|
| 195 |
+
f"Scores shape {scores.shape} != Mask shape {local_attn_mask.shape}"
|
| 196 |
+
|
| 197 |
+
local_attn_mask = local_attn_mask.to(torch.float32)
|
| 198 |
+
local_attn_mask = local_attn_mask.masked_fill(local_attn_mask == False, -float('inf'))
|
| 199 |
+
local_attn_mask = local_attn_mask.masked_fill(local_attn_mask == True, 0)
|
| 200 |
+
scores = scores + local_attn_mask
|
| 201 |
+
|
| 202 |
+
attn_map = torch.softmax(scores, dim=-1)
|
| 203 |
+
attn_map = rearrange(attn_map, 'h (it s1) s2 -> (h it) s1 s2', it=seqlen)
|
| 204 |
+
loop_num, s1, s2 = attn_map.shape
|
| 205 |
+
flat = attn_map.reshape(loop_num, -1)
|
| 206 |
+
apply_topk = min(flat.shape[1]-1, topk)
|
| 207 |
+
|
| 208 |
+
if apply_topk <= 0:
|
| 209 |
+
mask_new = torch.zeros_like(flat, dtype=torch.bool).reshape(loop_num, s1, s2)
|
| 210 |
+
else:
|
| 211 |
+
thresholds = torch.topk(flat, k=apply_topk + 1, dim=1, largest=True).values[:, -1]
|
| 212 |
+
thresholds = thresholds.unsqueeze(1)
|
| 213 |
+
mask_new = (flat > thresholds).reshape(loop_num, s1, s2)
|
| 214 |
+
|
| 215 |
+
mask_new = rearrange(mask_new, '(h it) s1 s2 -> h (it s1) s2', it=seqlen)
|
| 216 |
+
mask = mask_new.unsqueeze(0).repeat(batch_size, 1, 1, 1)
|
| 217 |
+
return mask
|
| 218 |
+
|
| 219 |
+
|
| 220 |
+
# ----------------------------
|
| 221 |
+
# Attention kernels
|
| 222 |
+
# ----------------------------
|
| 223 |
+
def flash_attention(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, num_heads: int, compatibility_mode=False, attention_mask=None, return_KV=False):
|
| 224 |
+
if attention_mask is not None:
|
| 225 |
+
seqlen = q.shape[1]
|
| 226 |
+
seqlen_kv = k.shape[1]
|
| 227 |
+
if USE_BLOCK_ATTN and BLOCK_ATTN_AVAILABLE:
|
| 228 |
+
q = rearrange(q, "b s (n d) -> (b s) n d", n=num_heads)
|
| 229 |
+
k = rearrange(k, "b s (n d) -> (b s) n d", n=num_heads)
|
| 230 |
+
v = rearrange(v, "b s (n d) -> (b s) n d", n=num_heads)
|
| 231 |
+
else:
|
| 232 |
+
q = rearrange(q, "b s (n d) -> b n s d", n=num_heads)
|
| 233 |
+
k = rearrange(k, "b s (n d) -> b n s d", n=num_heads)
|
| 234 |
+
v = rearrange(v, "b s (n d) -> b n s d", n=num_heads)
|
| 235 |
+
cu_seqlens_q = torch.tensor([0, seqlen], device=q.device, dtype=torch.int32)
|
| 236 |
+
cu_seqlens_k = torch.tensor([0, seqlen_kv], device=q.device, dtype=torch.int32)
|
| 237 |
+
head_mask_type = torch.tensor([1]*num_heads, device=q.device, dtype=torch.int32)
|
| 238 |
+
streaming_info = None
|
| 239 |
+
base_blockmask = attention_mask
|
| 240 |
+
max_seqlen_q_ = seqlen
|
| 241 |
+
max_seqlen_k_ = seqlen_kv
|
| 242 |
+
p_dropout = 0.0
|
| 243 |
+
if USE_BLOCK_ATTN and BLOCK_ATTN_AVAILABLE:
|
| 244 |
+
x = block_sparse_attn_func(
|
| 245 |
+
q, k, v,
|
| 246 |
+
cu_seqlens_q, cu_seqlens_k,
|
| 247 |
+
head_mask_type,
|
| 248 |
+
streaming_info,
|
| 249 |
+
base_blockmask,
|
| 250 |
+
max_seqlen_q_, max_seqlen_k_,
|
| 251 |
+
p_dropout,
|
| 252 |
+
deterministic=False,
|
| 253 |
+
softmax_scale=None,
|
| 254 |
+
is_causal=False,
|
| 255 |
+
exact_streaming=False,
|
| 256 |
+
return_attn_probs=False,
|
| 257 |
+
).unsqueeze(0)
|
| 258 |
+
x = rearrange(x, "b s n d -> b s (n d)", n=num_heads)
|
| 259 |
+
else:
|
| 260 |
+
x = sparse_sageattn(
|
| 261 |
+
q, k, v,
|
| 262 |
+
mask_id=base_blockmask.to(torch.int8),
|
| 263 |
+
is_causal=False,
|
| 264 |
+
tensor_layout="HND"
|
| 265 |
+
)
|
| 266 |
+
x = rearrange(x, "b n s d -> b s (n d)", n=num_heads)
|
| 267 |
+
elif compatibility_mode:
|
| 268 |
+
q = rearrange(q, "b s (n d) -> b n s d", n=num_heads)
|
| 269 |
+
k = rearrange(k, "b s (n d) -> b n s d", n=num_heads)
|
| 270 |
+
v = rearrange(v, "b s (n d) -> b n s d", n=num_heads)
|
| 271 |
+
x = F.scaled_dot_product_attention(q, k, v)
|
| 272 |
+
x = rearrange(x, "b n s d -> b s (n d)", n=num_heads)
|
| 273 |
+
elif FLASH_ATTN_3_AVAILABLE:
|
| 274 |
+
q = rearrange(q, "b s (n d) -> b s n d", n=num_heads)
|
| 275 |
+
k = rearrange(k, "b s (n d) -> b s n d", n=num_heads)
|
| 276 |
+
v = rearrange(v, "b s (n d) -> b s n d", n=num_heads)
|
| 277 |
+
x = flash_attn_interface.flash_attn_func(q, k, v)
|
| 278 |
+
if isinstance(x, tuple):
|
| 279 |
+
x = x[0]
|
| 280 |
+
x = rearrange(x, "b s n d -> b s (n d)", n=num_heads)
|
| 281 |
+
elif FLASH_ATTN_2_AVAILABLE:
|
| 282 |
+
q = rearrange(q, "b s (n d) -> b s n d", n=num_heads)
|
| 283 |
+
k = rearrange(k, "b s (n d) -> b s n d", n=num_heads)
|
| 284 |
+
v = rearrange(v, "b s (n d) -> b s n d", n=num_heads)
|
| 285 |
+
x = flash_attn.flash_attn_func(q, k, v)
|
| 286 |
+
x = rearrange(x, "b s n d -> b s (n d)", n=num_heads)
|
| 287 |
+
elif SAGE_ATTN_AVAILABLE:
|
| 288 |
+
q = rearrange(q, "b s (n d) -> b n s d", n=num_heads)
|
| 289 |
+
k = rearrange(k, "b s (n d) -> b n s d", n=num_heads)
|
| 290 |
+
v = rearrange(v, "b s (n d) -> b n s d", n=num_heads)
|
| 291 |
+
x = sageattn(q, k, v)
|
| 292 |
+
x = rearrange(x, "b n s d -> b s (n d)", n=num_heads)
|
| 293 |
+
else:
|
| 294 |
+
q = rearrange(q, "b s (n d) -> b n s d", n=num_heads)
|
| 295 |
+
k = rearrange(k, "b s (n d) -> b n s d", n=num_heads)
|
| 296 |
+
v = rearrange(v, "b s (n d) -> b n s d", n=num_heads)
|
| 297 |
+
x = F.scaled_dot_product_attention(q, k, v)
|
| 298 |
+
x = rearrange(x, "b n s d -> b s (n d)", n=num_heads)
|
| 299 |
+
return x
|
| 300 |
+
|
| 301 |
+
|
| 302 |
+
def modulate(x: torch.Tensor, shift: torch.Tensor, scale: torch.Tensor):
|
| 303 |
+
return (x * (1 + scale) + shift)
|
| 304 |
+
|
| 305 |
+
|
| 306 |
+
def sinusoidal_embedding_1d(dim, position):
|
| 307 |
+
sinusoid = torch.outer(position.type(torch.float64), torch.pow(
|
| 308 |
+
10000, -torch.arange(dim//2, dtype=torch.float64, device=position.device).div(dim//2)))
|
| 309 |
+
x = torch.cat([torch.cos(sinusoid), torch.sin(sinusoid)], dim=1)
|
| 310 |
+
return x.to(position.dtype)
|
| 311 |
+
|
| 312 |
+
|
| 313 |
+
def precompute_freqs_cis_3d(dim: int, end: int = 1024, theta: float = 10000.0):
|
| 314 |
+
f_freqs_cis = precompute_freqs_cis(dim - 2 * (dim // 3), end, theta)
|
| 315 |
+
h_freqs_cis = precompute_freqs_cis(dim // 3, end, theta)
|
| 316 |
+
w_freqs_cis = precompute_freqs_cis(dim // 3, end, theta)
|
| 317 |
+
return f_freqs_cis, h_freqs_cis, w_freqs_cis
|
| 318 |
+
|
| 319 |
+
|
| 320 |
+
def precompute_freqs_cis(dim: int, end: int = 1024, theta: float = 10000.0):
|
| 321 |
+
freqs = 1.0 / (theta ** (torch.arange(0, dim, 2)
|
| 322 |
+
[: (dim // 2)].double() / dim))
|
| 323 |
+
freqs = torch.outer(torch.arange(end, device=freqs.device), freqs)
|
| 324 |
+
freqs_cis = torch.polar(torch.ones_like(freqs), freqs)
|
| 325 |
+
return freqs_cis
|
| 326 |
+
|
| 327 |
+
|
| 328 |
+
def rope_apply(x, freqs, num_heads):
|
| 329 |
+
x = rearrange(x, "b s (n d) -> b s n d", n=num_heads)
|
| 330 |
+
x_out = torch.view_as_complex(x.to(torch.float64).reshape(
|
| 331 |
+
x.shape[0], x.shape[1], x.shape[2], -1, 2))
|
| 332 |
+
x_out = torch.view_as_real(x_out * freqs).flatten(2)
|
| 333 |
+
return x_out.to(x.dtype)
|
| 334 |
+
|
| 335 |
+
|
| 336 |
+
# ----------------------------
|
| 337 |
+
# Norms & Blocks
|
| 338 |
+
# ----------------------------
|
| 339 |
+
class RMSNorm(nn.Module):
|
| 340 |
+
def __init__(self, dim, eps=1e-5):
|
| 341 |
+
super().__init__()
|
| 342 |
+
self.eps = eps
|
| 343 |
+
self.weight = nn.Parameter(torch.ones(dim))
|
| 344 |
+
|
| 345 |
+
def norm(self, x):
|
| 346 |
+
return x * torch.rsqrt(x.pow(2).mean(dim=-1, keepdim=True) + self.eps)
|
| 347 |
+
|
| 348 |
+
def forward(self, x):
|
| 349 |
+
dtype = x.dtype
|
| 350 |
+
return self.norm(x.float()).to(dtype) * self.weight
|
| 351 |
+
|
| 352 |
+
|
| 353 |
+
class AttentionModule(nn.Module):
|
| 354 |
+
def __init__(self, num_heads):
|
| 355 |
+
super().__init__()
|
| 356 |
+
self.num_heads = num_heads
|
| 357 |
+
|
| 358 |
+
def forward(self, q, k, v, attention_mask=None):
|
| 359 |
+
x = flash_attention(q=q, k=k, v=v, num_heads=self.num_heads, attention_mask=attention_mask)
|
| 360 |
+
return x
|
| 361 |
+
|
| 362 |
+
|
| 363 |
+
class SelfAttention(nn.Module):
|
| 364 |
+
def __init__(self, dim: int, num_heads: int, eps: float = 1e-6):
|
| 365 |
+
super().__init__()
|
| 366 |
+
self.dim = dim
|
| 367 |
+
self.num_heads = num_heads
|
| 368 |
+
self.head_dim = dim // num_heads
|
| 369 |
+
|
| 370 |
+
self.q = nn.Linear(dim, dim)
|
| 371 |
+
self.k = nn.Linear(dim, dim)
|
| 372 |
+
self.v = nn.Linear(dim, dim)
|
| 373 |
+
self.o = nn.Linear(dim, dim)
|
| 374 |
+
self.norm_q = RMSNorm(dim, eps=eps)
|
| 375 |
+
self.norm_k = RMSNorm(dim, eps=eps)
|
| 376 |
+
|
| 377 |
+
self.attn = AttentionModule(self.num_heads)
|
| 378 |
+
self.local_attn_mask = None
|
| 379 |
+
|
| 380 |
+
def forward(self, x, freqs, f=None, h=None, w=None, local_num=None, topk=None,
|
| 381 |
+
train_img=False, block_id=None, kv_len=None, is_full_block=False,
|
| 382 |
+
is_stream=False, pre_cache_k=None, pre_cache_v=None, local_range = 9):
|
| 383 |
+
B, L, D = x.shape
|
| 384 |
+
if is_stream and pre_cache_k is not None and pre_cache_v is not None:
|
| 385 |
+
assert f==2, "f must be 2"
|
| 386 |
+
if is_stream and (pre_cache_k is None or pre_cache_v is None):
|
| 387 |
+
assert f==6, " start f must be 6"
|
| 388 |
+
assert L == f * h * w, "Sequence length mismatch with provided (f,h,w)."
|
| 389 |
+
|
| 390 |
+
q = self.norm_q(self.q(x))
|
| 391 |
+
k = self.norm_k(self.k(x))
|
| 392 |
+
v = self.v(x)
|
| 393 |
+
q = rope_apply(q, freqs, self.num_heads)
|
| 394 |
+
k = rope_apply(k, freqs, self.num_heads)
|
| 395 |
+
|
| 396 |
+
win = (2, 8, 8)
|
| 397 |
+
q = q.view(B, f, h, w, D)
|
| 398 |
+
k = k.view(B, f, h, w, D)
|
| 399 |
+
v = v.view(B, f, h, w, D)
|
| 400 |
+
|
| 401 |
+
q_w = WindowPartition3D.partition(q, win)
|
| 402 |
+
k_w = WindowPartition3D.partition(k, win)
|
| 403 |
+
v_w = WindowPartition3D.partition(v, win)
|
| 404 |
+
|
| 405 |
+
seqlen = f//win[0]
|
| 406 |
+
one_len = k_w.shape[0] // B // seqlen
|
| 407 |
+
if pre_cache_k is not None and pre_cache_v is not None:
|
| 408 |
+
k_w = torch.cat([pre_cache_k, k_w], dim=0)
|
| 409 |
+
v_w = torch.cat([pre_cache_v, v_w], dim=0)
|
| 410 |
+
|
| 411 |
+
block_n = q_w.shape[0] // B
|
| 412 |
+
block_s = q_w.shape[1]
|
| 413 |
+
block_n_kv = k_w.shape[0] // B
|
| 414 |
+
|
| 415 |
+
reorder_q = rearrange(q_w, '(b block_n) (block_s) d -> b (block_n block_s) d', block_n=block_n, block_s=block_s)
|
| 416 |
+
reorder_k = rearrange(k_w, '(b block_n) (block_s) d -> b (block_n block_s) d', block_n=block_n_kv, block_s=block_s)
|
| 417 |
+
reorder_v = rearrange(v_w, '(b block_n) (block_s) d -> b (block_n block_s) d', block_n=block_n_kv, block_s=block_s)
|
| 418 |
+
|
| 419 |
+
window_size = win[0]*h*w//128
|
| 420 |
+
|
| 421 |
+
if self.local_attn_mask is None or self.local_attn_mask_h!=h//8 or self.local_attn_mask_w!=w//8 or self.local_range!=local_range:
|
| 422 |
+
self.local_attn_mask = build_local_block_mask_shifted_vec_normal_slide(h//8, w//8, local_range, local_range, include_self=True, device=k_w.device)
|
| 423 |
+
self.local_attn_mask_h = h//8
|
| 424 |
+
self.local_attn_mask_w = w//8
|
| 425 |
+
self.local_range = local_range
|
| 426 |
+
if USE_BLOCK_ATTN and BLOCK_ATTN_AVAILABLE:
|
| 427 |
+
attention_mask = generate_draft_block_mask(B, self.num_heads, seqlen, q_w, k_w, topk=topk, local_attn_mask=self.local_attn_mask)
|
| 428 |
+
else:
|
| 429 |
+
attention_mask = generate_draft_block_mask_sage(B, self.num_heads, seqlen, q_w, k_w, topk=topk, local_attn_mask=self.local_attn_mask)
|
| 430 |
+
|
| 431 |
+
x = self.attn(reorder_q, reorder_k, reorder_v, attention_mask)
|
| 432 |
+
|
| 433 |
+
cur_block_n, cur_block_s, _ = k_w.shape
|
| 434 |
+
cache_num = cur_block_n // one_len
|
| 435 |
+
if cache_num > kv_len:
|
| 436 |
+
cache_k = k_w[one_len:, :, :]
|
| 437 |
+
cache_v = v_w[one_len:, :, :]
|
| 438 |
+
else:
|
| 439 |
+
cache_k = k_w
|
| 440 |
+
cache_v = v_w
|
| 441 |
+
|
| 442 |
+
x = rearrange(x, 'b (block_n block_s) d -> (b block_n) (block_s) d', block_n=block_n, block_s=block_s)
|
| 443 |
+
x = WindowPartition3D.reverse(x, win, (f, h, w))
|
| 444 |
+
x = x.view(B, f*h*w, D)
|
| 445 |
+
|
| 446 |
+
if is_stream:
|
| 447 |
+
return self.o(x), cache_k, cache_v
|
| 448 |
+
return self.o(x)
|
| 449 |
+
|
| 450 |
+
|
| 451 |
+
class CrossAttention(nn.Module):
|
| 452 |
+
"""
|
| 453 |
+
仅考虑文本 context;提供持久 KV 缓存。
|
| 454 |
+
"""
|
| 455 |
+
def __init__(self, dim: int, num_heads: int, eps: float = 1e-6):
|
| 456 |
+
super().__init__()
|
| 457 |
+
self.dim = dim
|
| 458 |
+
self.num_heads = num_heads
|
| 459 |
+
self.head_dim = dim // num_heads
|
| 460 |
+
|
| 461 |
+
self.q = nn.Linear(dim, dim)
|
| 462 |
+
self.k = nn.Linear(dim, dim)
|
| 463 |
+
self.v = nn.Linear(dim, dim)
|
| 464 |
+
self.o = nn.Linear(dim, dim)
|
| 465 |
+
|
| 466 |
+
self.norm_q = RMSNorm(dim, eps=eps)
|
| 467 |
+
self.norm_k = RMSNorm(dim, eps=eps)
|
| 468 |
+
|
| 469 |
+
self.attn = AttentionModule(self.num_heads)
|
| 470 |
+
|
| 471 |
+
# 持久缓存
|
| 472 |
+
self.cache_k = None
|
| 473 |
+
self.cache_v = None
|
| 474 |
+
|
| 475 |
+
@torch.no_grad()
|
| 476 |
+
def init_cache(self, ctx: torch.Tensor):
|
| 477 |
+
"""ctx: [B, S_ctx, dim] —— 经过 text_embedding 之后的上下文"""
|
| 478 |
+
self.cache_k = self.norm_k(self.k(ctx))
|
| 479 |
+
self.cache_v = self.v(ctx)
|
| 480 |
+
|
| 481 |
+
def clear_cache(self):
|
| 482 |
+
self.cache_k = None
|
| 483 |
+
self.cache_v = None
|
| 484 |
+
|
| 485 |
+
def forward(self, x: torch.Tensor, y: torch.Tensor, is_stream: bool = False):
|
| 486 |
+
"""
|
| 487 |
+
y 即文本上下文(未做其他分支)。
|
| 488 |
+
"""
|
| 489 |
+
q = self.norm_q(self.q(x))
|
| 490 |
+
assert self.cache_k is not None and self.cache_v is not None
|
| 491 |
+
k = self.cache_k
|
| 492 |
+
v = self.cache_v
|
| 493 |
+
|
| 494 |
+
x = self.attn(q, k, v)
|
| 495 |
+
return self.o(x)
|
| 496 |
+
|
| 497 |
+
|
| 498 |
+
class GateModule(nn.Module):
|
| 499 |
+
def __init__(self,):
|
| 500 |
+
super().__init__()
|
| 501 |
+
|
| 502 |
+
def forward(self, x, gate, residual):
|
| 503 |
+
return x + gate * residual
|
| 504 |
+
|
| 505 |
+
|
| 506 |
+
class DiTBlock(nn.Module):
|
| 507 |
+
def __init__(self, dim: int, num_heads: int, ffn_dim: int, eps: float = 1e-6):
|
| 508 |
+
super().__init__()
|
| 509 |
+
self.dim = dim
|
| 510 |
+
self.num_heads = num_heads
|
| 511 |
+
self.ffn_dim = ffn_dim
|
| 512 |
+
|
| 513 |
+
self.self_attn = SelfAttention(dim, num_heads, eps)
|
| 514 |
+
self.cross_attn = CrossAttention(dim, num_heads, eps)
|
| 515 |
+
|
| 516 |
+
self.norm1 = nn.LayerNorm(dim, eps=eps, elementwise_affine=False)
|
| 517 |
+
self.norm2 = nn.LayerNorm(dim, eps=eps, elementwise_affine=False)
|
| 518 |
+
self.norm3 = nn.LayerNorm(dim, eps=eps)
|
| 519 |
+
self.ffn = nn.Sequential(nn.Linear(dim, ffn_dim), nn.GELU(
|
| 520 |
+
approximate='tanh'), nn.Linear(ffn_dim, dim))
|
| 521 |
+
self.modulation = nn.Parameter(torch.randn(1, 6, dim) / dim**0.5)
|
| 522 |
+
self.gate = GateModule()
|
| 523 |
+
|
| 524 |
+
def forward(self, x, context, t_mod, freqs, f, h, w, local_num=None, topk=None,
|
| 525 |
+
train_img=False, block_id=None, kv_len=None, is_full_block=False,
|
| 526 |
+
is_stream=False, pre_cache_k=None, pre_cache_v=None, local_range = 9):
|
| 527 |
+
shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = (
|
| 528 |
+
self.modulation.to(dtype=t_mod.dtype, device=t_mod.device) + t_mod).chunk(6, dim=1)
|
| 529 |
+
input_x = modulate(self.norm1(x), shift_msa, scale_msa)
|
| 530 |
+
self_attn_output, self_attn_cache_k, self_attn_cache_v = self.self_attn(
|
| 531 |
+
input_x, freqs, f, h, w, local_num, topk, train_img, block_id,
|
| 532 |
+
kv_len=kv_len, is_full_block=is_full_block, is_stream=is_stream,
|
| 533 |
+
pre_cache_k=pre_cache_k, pre_cache_v=pre_cache_v, local_range = local_range)
|
| 534 |
+
|
| 535 |
+
x = self.gate(x, gate_msa, self_attn_output)
|
| 536 |
+
x = x + self.cross_attn(self.norm3(x), context, is_stream=is_stream)
|
| 537 |
+
input_x = modulate(self.norm2(x), shift_mlp, scale_mlp)
|
| 538 |
+
x = self.gate(x, gate_mlp, self.ffn(input_x))
|
| 539 |
+
if is_stream:
|
| 540 |
+
return x, self_attn_cache_k, self_attn_cache_v
|
| 541 |
+
return x
|
| 542 |
+
|
| 543 |
+
|
| 544 |
+
class MLP(torch.nn.Module):
|
| 545 |
+
def __init__(self, in_dim, out_dim, has_pos_emb=False):
|
| 546 |
+
super().__init__()
|
| 547 |
+
self.proj = torch.nn.Sequential(
|
| 548 |
+
nn.LayerNorm(in_dim),
|
| 549 |
+
nn.Linear(in_dim, in_dim),
|
| 550 |
+
nn.GELU(),
|
| 551 |
+
nn.Linear(in_dim, out_dim),
|
| 552 |
+
nn.LayerNorm(out_dim)
|
| 553 |
+
)
|
| 554 |
+
self.has_pos_emb = has_pos_emb
|
| 555 |
+
if has_pos_emb:
|
| 556 |
+
self.emb_pos = torch.nn.Parameter(torch.zeros((1, 514, 1280)))
|
| 557 |
+
|
| 558 |
+
def forward(self, x):
|
| 559 |
+
if self.has_pos_emb:
|
| 560 |
+
x = x + self.emb_pos.to(dtype=x.dtype, device=x.device)
|
| 561 |
+
return self.proj(x)
|
| 562 |
+
|
| 563 |
+
|
| 564 |
+
class Head(nn.Module):
|
| 565 |
+
def __init__(self, dim: int, out_dim: int, patch_size: Tuple[int, int, int], eps: float):
|
| 566 |
+
super().__init__()
|
| 567 |
+
self.dim = dim
|
| 568 |
+
self.patch_size = patch_size
|
| 569 |
+
self.norm = nn.LayerNorm(dim, eps=eps, elementwise_affine=False)
|
| 570 |
+
self.head = nn.Linear(dim, out_dim * math.prod(patch_size))
|
| 571 |
+
self.modulation = nn.Parameter(torch.randn(1, 2, dim) / dim**0.5)
|
| 572 |
+
|
| 573 |
+
def forward(self, x, t_mod):
|
| 574 |
+
shift, scale = (self.modulation.to(dtype=t_mod.dtype, device=t_mod.device) + t_mod).chunk(2, dim=1)
|
| 575 |
+
x = (self.head(self.norm(x) * (1 + scale) + shift))
|
| 576 |
+
return x
|
| 577 |
+
|
| 578 |
+
|
| 579 |
+
# ----------------------------
|
| 580 |
+
# WanModel (no image branch) — init 时即产生 KV 缓存
|
| 581 |
+
# ----------------------------
|
| 582 |
+
class WanModel(torch.nn.Module):
|
| 583 |
+
def __init__(
|
| 584 |
+
self,
|
| 585 |
+
dim: int,
|
| 586 |
+
in_dim: int,
|
| 587 |
+
ffn_dim: int,
|
| 588 |
+
out_dim: int,
|
| 589 |
+
text_dim: int,
|
| 590 |
+
freq_dim: int,
|
| 591 |
+
eps: float,
|
| 592 |
+
patch_size: Tuple[int, int, int],
|
| 593 |
+
num_heads: int,
|
| 594 |
+
num_layers: int,
|
| 595 |
+
# init_context: torch.Tensor, # <<<< 必填:在 __init__ 里用它生成 cross-attn KV 缓存
|
| 596 |
+
has_image_input: bool = False,
|
| 597 |
+
):
|
| 598 |
+
super().__init__()
|
| 599 |
+
self.dim = dim
|
| 600 |
+
self.freq_dim = freq_dim
|
| 601 |
+
self.patch_size = patch_size
|
| 602 |
+
|
| 603 |
+
# patch embed
|
| 604 |
+
self.patch_embedding = nn.Conv3d(
|
| 605 |
+
in_dim, dim, kernel_size=patch_size, stride=patch_size)
|
| 606 |
+
|
| 607 |
+
# text / time embed
|
| 608 |
+
self.text_embedding = nn.Sequential(
|
| 609 |
+
nn.Linear(text_dim, dim),
|
| 610 |
+
nn.GELU(approximate='tanh'),
|
| 611 |
+
nn.Linear(dim, dim)
|
| 612 |
+
)
|
| 613 |
+
self.time_embedding = nn.Sequential(
|
| 614 |
+
nn.Linear(freq_dim, dim),
|
| 615 |
+
nn.SiLU(),
|
| 616 |
+
nn.Linear(dim, dim)
|
| 617 |
+
)
|
| 618 |
+
self.time_projection = nn.Sequential(
|
| 619 |
+
nn.SiLU(), nn.Linear(dim, dim * 6))
|
| 620 |
+
|
| 621 |
+
# blocks
|
| 622 |
+
self.blocks = nn.ModuleList([
|
| 623 |
+
DiTBlock(dim, num_heads, ffn_dim, eps)
|
| 624 |
+
for _ in range(num_layers)
|
| 625 |
+
])
|
| 626 |
+
self.head = Head(dim, out_dim, patch_size, eps)
|
| 627 |
+
|
| 628 |
+
head_dim = dim // num_heads
|
| 629 |
+
self.freqs = precompute_freqs_cis_3d(head_dim)
|
| 630 |
+
|
| 631 |
+
self._cross_kv_initialized = False
|
| 632 |
+
|
| 633 |
+
# 可选:手动清空 / 重新初始化
|
| 634 |
+
def clear_cross_kv(self):
|
| 635 |
+
for blk in self.blocks:
|
| 636 |
+
blk.cross_attn.clear_cache()
|
| 637 |
+
self._cross_kv_initialized = False
|
| 638 |
+
|
| 639 |
+
@torch.no_grad()
|
| 640 |
+
def reinit_cross_kv(self, new_context: torch.Tensor):
|
| 641 |
+
ctx_txt = self.text_embedding(new_context)
|
| 642 |
+
for blk in self.blocks:
|
| 643 |
+
blk.cross_attn.init_cache(ctx_txt)
|
| 644 |
+
self._cross_kv_initialized = True
|
| 645 |
+
|
| 646 |
+
def patchify(self, x: torch.Tensor):
|
| 647 |
+
x = self.patch_embedding(x)
|
| 648 |
+
grid_size = x.shape[2:]
|
| 649 |
+
x = rearrange(x, 'b c f h w -> b (f h w) c').contiguous()
|
| 650 |
+
return x, grid_size # x, grid_size: (f, h, w)
|
| 651 |
+
|
| 652 |
+
def unpatchify(self, x: torch.Tensor, grid_size: torch.Tensor):
|
| 653 |
+
return rearrange(
|
| 654 |
+
x, 'b (f h w) (x y z c) -> b c (f x) (h y) (w z)',
|
| 655 |
+
f=grid_size[0], h=grid_size[1], w=grid_size[2],
|
| 656 |
+
x=self.patch_size[0], y=self.patch_size[1], z=self.patch_size[2]
|
| 657 |
+
)
|
| 658 |
+
|
| 659 |
+
def forward(self,
|
| 660 |
+
x: torch.Tensor,
|
| 661 |
+
timestep: torch.Tensor,
|
| 662 |
+
context: torch.Tensor,
|
| 663 |
+
use_gradient_checkpointing: bool = False,
|
| 664 |
+
use_gradient_checkpointing_offload: bool = False,
|
| 665 |
+
LQ_latents: Optional[List[torch.Tensor]] = None,
|
| 666 |
+
train_img: bool = False,
|
| 667 |
+
topk_ratio: Optional[float] = None,
|
| 668 |
+
kv_ratio: Optional[float] = None,
|
| 669 |
+
local_num: Optional[int] = None,
|
| 670 |
+
is_full_block: bool = False,
|
| 671 |
+
causal_idx: Optional[int] = None,
|
| 672 |
+
**kwargs,
|
| 673 |
+
):
|
| 674 |
+
# time / text embeds
|
| 675 |
+
t = self.time_embedding(
|
| 676 |
+
sinusoidal_embedding_1d(self.freq_dim, timestep))
|
| 677 |
+
t_mod = self.time_projection(t).unflatten(1, (6, self.dim))
|
| 678 |
+
|
| 679 |
+
# 这里仍会嵌入 text(CrossAttention 若已有缓存会忽略它)
|
| 680 |
+
# context = self.text_embedding(context)
|
| 681 |
+
|
| 682 |
+
# 输入打补丁
|
| 683 |
+
x, (f, h, w) = self.patchify(x)
|
| 684 |
+
B = x.shape[0]
|
| 685 |
+
|
| 686 |
+
# window / masks 超参
|
| 687 |
+
win = (2, 8, 8)
|
| 688 |
+
seqlen = f//win[0]
|
| 689 |
+
if local_num is None:
|
| 690 |
+
local_random = random.random()
|
| 691 |
+
if local_random < 0.3:
|
| 692 |
+
local_num = seqlen - 3
|
| 693 |
+
elif local_random < 0.4:
|
| 694 |
+
local_num = seqlen - 4
|
| 695 |
+
elif local_random < 0.5:
|
| 696 |
+
local_num = seqlen - 2
|
| 697 |
+
else:
|
| 698 |
+
local_num = seqlen
|
| 699 |
+
|
| 700 |
+
window_size = win[0]*h*w//128
|
| 701 |
+
square_num = window_size*window_size
|
| 702 |
+
topk_ratio = 2.0
|
| 703 |
+
topk = min(max(int(square_num*topk_ratio), 1), int(square_num*seqlen)-1)
|
| 704 |
+
|
| 705 |
+
if kv_ratio is None:
|
| 706 |
+
kv_ratio = (random.uniform(0., 1.0)**2)*(local_num-2-2)+2
|
| 707 |
+
kv_len = min(max(int(window_size*kv_ratio), 1), int(window_size*seqlen)-1)
|
| 708 |
+
|
| 709 |
+
decay_ratio = random.uniform(0.7, 1.0)
|
| 710 |
+
|
| 711 |
+
# RoPE 3D
|
| 712 |
+
freqs = torch.cat([
|
| 713 |
+
self.freqs[0][:f].view(f, 1, 1, -1).expand(f, h, w, -1),
|
| 714 |
+
self.freqs[1][:h].view(1, h, 1, -1).expand(f, h, w, -1),
|
| 715 |
+
self.freqs[2][:w].view(1, 1, w, -1).expand(f, h, w, -1)
|
| 716 |
+
], dim=-1).reshape(f * h * w, 1, -1).to(x.device)
|
| 717 |
+
|
| 718 |
+
def create_custom_forward(module):
|
| 719 |
+
def custom_forward(*inputs):
|
| 720 |
+
return module(*inputs)
|
| 721 |
+
return custom_forward
|
| 722 |
+
|
| 723 |
+
# blocks
|
| 724 |
+
for block_id, block in enumerate(self.blocks):
|
| 725 |
+
if LQ_latents is not None and block_id < len(LQ_latents):
|
| 726 |
+
x += LQ_latents[block_id]
|
| 727 |
+
|
| 728 |
+
if self.training and use_gradient_checkpointing:
|
| 729 |
+
if use_gradient_checkpointing_offload:
|
| 730 |
+
with torch.autograd.graph.save_on_cpu():
|
| 731 |
+
x = torch.utils.checkpoint.checkpoint(
|
| 732 |
+
create_custom_forward(block),
|
| 733 |
+
x, context, t_mod, freqs, f, h, w, local_num, topk,
|
| 734 |
+
train_img, block_id, kv_len, is_full_block, False,
|
| 735 |
+
None, None,
|
| 736 |
+
use_reentrant=False,
|
| 737 |
+
)
|
| 738 |
+
else:
|
| 739 |
+
x = torch.utils.checkpoint.checkpoint(
|
| 740 |
+
create_custom_forward(block),
|
| 741 |
+
x, context, t_mod, freqs, f, h, w, local_num, topk,
|
| 742 |
+
train_img, block_id, kv_len, is_full_block, False,
|
| 743 |
+
None, None,
|
| 744 |
+
use_reentrant=False,
|
| 745 |
+
)
|
| 746 |
+
else:
|
| 747 |
+
x = block(x, context, t_mod, freqs, f, h, w, local_num, topk,
|
| 748 |
+
train_img, block_id, kv_len, is_full_block, False,
|
| 749 |
+
None, None)
|
| 750 |
+
|
| 751 |
+
x = self.head(x, t)
|
| 752 |
+
x = self.unpatchify(x, (f, h, w))
|
| 753 |
+
return x
|
| 754 |
+
|
| 755 |
+
@staticmethod
|
| 756 |
+
def state_dict_converter():
|
| 757 |
+
return WanModelStateDictConverter()
|
| 758 |
+
|
| 759 |
+
|
| 760 |
+
# ----------------------------
|
| 761 |
+
# State dict converter(保持原映射;已忽略 has_image_input 使用)
|
| 762 |
+
# ----------------------------
|
| 763 |
+
class WanModelStateDictConverter:
|
| 764 |
+
def __init__(self):
|
| 765 |
+
pass
|
| 766 |
+
|
| 767 |
+
def from_diffusers(self, state_dict):
|
| 768 |
+
rename_dict = {
|
| 769 |
+
"blocks.0.attn1.norm_k.weight": "blocks.0.self_attn.norm_k.weight",
|
| 770 |
+
"blocks.0.attn1.norm_q.weight": "blocks.0.self_attn.norm_q.weight",
|
| 771 |
+
"blocks.0.attn1.to_k.bias": "blocks.0.self_attn.k.bias",
|
| 772 |
+
"blocks.0.attn1.to_k.weight": "blocks.0.self_attn.k.weight",
|
| 773 |
+
"blocks.0.attn1.to_out.0.bias": "blocks.0.self_attn.o.bias",
|
| 774 |
+
"blocks.0.attn1.to_out.0.weight": "blocks.0.self_attn.o.weight",
|
| 775 |
+
"blocks.0.attn1.to_q.bias": "blocks.0.self_attn.q.bias",
|
| 776 |
+
"blocks.0.attn1.to_q.weight": "blocks.0.self_attn.q.weight",
|
| 777 |
+
"blocks.0.attn1.to_v.bias": "blocks.0.self_attn.v.bias",
|
| 778 |
+
"blocks.0.attn1.to_v.weight": "blocks.0.self_attn.v.weight",
|
| 779 |
+
"blocks.0.attn2.norm_k.weight": "blocks.0.cross_attn.norm_k.weight",
|
| 780 |
+
"blocks.0.attn2.norm_q.weight": "blocks.0.cross_attn.norm_q.weight",
|
| 781 |
+
"blocks.0.attn2.to_k.bias": "blocks.0.cross_attn.k.bias",
|
| 782 |
+
"blocks.0.attn2.to_k.weight": "blocks.0.cross_attn.k.weight",
|
| 783 |
+
"blocks.0.attn2.to_out.0.bias": "blocks.0.cross_attn.o.bias",
|
| 784 |
+
"blocks.0.attn2.to_out.0.weight": "blocks.0.cross_attn.o.weight",
|
| 785 |
+
"blocks.0.attn2.to_q.bias": "blocks.0.cross_attn.q.bias",
|
| 786 |
+
"blocks.0.attn2.to_q.weight": "blocks.0.cross_attn.q.weight",
|
| 787 |
+
"blocks.0.attn2.to_v.bias": "blocks.0.cross_attn.v.bias",
|
| 788 |
+
"blocks.0.attn2.to_v.weight": "blocks.0.cross_attn.v.weight",
|
| 789 |
+
"blocks.0.ffn.net.0.proj.bias": "blocks.0.ffn.0.bias",
|
| 790 |
+
"blocks.0.ffn.net.0.proj.weight": "blocks.0.ffn.0.weight",
|
| 791 |
+
"blocks.0.ffn.net.2.bias": "blocks.0.ffn.2.bias",
|
| 792 |
+
"blocks.0.ffn.net.2.weight": "blocks.0.ffn.2.weight",
|
| 793 |
+
"blocks.0.norm2.bias": "blocks.0.norm3.bias",
|
| 794 |
+
"blocks.0.norm2.weight": "blocks.0.norm3.weight",
|
| 795 |
+
"blocks.0.scale_shift_table": "blocks.0.modulation",
|
| 796 |
+
"condition_embedder.text_embedder.linear_1.bias": "text_embedding.0.bias",
|
| 797 |
+
"condition_embedder.text_embedder.linear_1.weight": "text_embedding.0.weight",
|
| 798 |
+
"condition_embedder.text_embedder.linear_2.bias": "text_embedding.2.bias",
|
| 799 |
+
"condition_embedder.text_embedder.linear_2.weight": "text_embedding.2.weight",
|
| 800 |
+
"condition_embedder.time_embedder.linear_1.bias": "time_embedding.0.bias",
|
| 801 |
+
"condition_embedder.time_embedder.linear_1.weight": "time_embedding.0.weight",
|
| 802 |
+
"condition_embedder.time_embedder.linear_2.bias": "time_embedding.2.bias",
|
| 803 |
+
"condition_embedder.time_embedder.linear_2.weight": "time_embedding.2.weight",
|
| 804 |
+
"condition_embedder.time_proj.bias": "time_projection.1.bias",
|
| 805 |
+
"condition_embedder.time_proj.weight": "time_projection.1.weight",
|
| 806 |
+
"patch_embedding.bias": "patch_embedding.bias",
|
| 807 |
+
"patch_embedding.weight": "patch_embedding.weight",
|
| 808 |
+
"scale_shift_table": "head.modulation",
|
| 809 |
+
"proj_out.bias": "head.head.bias",
|
| 810 |
+
"proj_out.weight": "head.head.weight",
|
| 811 |
+
}
|
| 812 |
+
state_dict_ = {}
|
| 813 |
+
for name, param in state_dict.items():
|
| 814 |
+
if name in rename_dict:
|
| 815 |
+
state_dict_[rename_dict[name]] = param
|
| 816 |
+
else:
|
| 817 |
+
name_ = ".".join(name.split(".")[:1] + ["0"] + name.split(".")[2:])
|
| 818 |
+
if name_ in rename_dict:
|
| 819 |
+
name_ = rename_dict[name_]
|
| 820 |
+
name_ = ".".join(name_.split(".")[:1] + [name.split(".")[1]] + name_.split(".")[2:])
|
| 821 |
+
state_dict_[name_] = param
|
| 822 |
+
if hash_state_dict_keys(state_dict) == "cb104773c6c2cb6df4f9529ad5c60d0b":
|
| 823 |
+
config = {
|
| 824 |
+
"model_type": "t2v",
|
| 825 |
+
"patch_size": (1, 2, 2),
|
| 826 |
+
"text_len": 512,
|
| 827 |
+
"in_dim": 16,
|
| 828 |
+
"dim": 5120,
|
| 829 |
+
"ffn_dim": 13824,
|
| 830 |
+
"freq_dim": 256,
|
| 831 |
+
"text_dim": 4096,
|
| 832 |
+
"out_dim": 16,
|
| 833 |
+
"num_heads": 40,
|
| 834 |
+
"num_layers": 40,
|
| 835 |
+
"window_size": (-1, -1),
|
| 836 |
+
"qk_norm": True,
|
| 837 |
+
"cross_attn_norm": True,
|
| 838 |
+
"eps": 1e-6,
|
| 839 |
+
}
|
| 840 |
+
else:
|
| 841 |
+
config = {}
|
| 842 |
+
return state_dict_, config
|
| 843 |
+
|
| 844 |
+
def from_civitai(self, state_dict):
|
| 845 |
+
state_dict = {name: param for name, param in state_dict.items() if not name.startswith("vace")}
|
| 846 |
+
# 保留原有哈希匹配返回的 config;实现本身不使用 has_image_input 分支
|
| 847 |
+
if hash_state_dict_keys(state_dict) == "9269f8db9040a9d860eaca435be61814":
|
| 848 |
+
config = {"has_image_input": False,"patch_size": [1, 2, 2],"in_dim": 16,"dim": 1536,"ffn_dim": 8960,"freq_dim": 256,"text_dim": 4096,"out_dim": 16,"num_heads": 12,"num_layers": 30,"eps": 1e-6}
|
| 849 |
+
elif hash_state_dict_keys(state_dict) == "aafcfd9672c3a2456dc46e1cb6e52c70":
|
| 850 |
+
config = {"has_image_input": False,"patch_size": [1, 2, 2],"in_dim": 16,"dim": 5120,"ffn_dim": 13824,"freq_dim": 256,"text_dim": 4096,"out_dim": 16,"num_heads": 40,"num_layers": 40,"eps": 1e-6}
|
| 851 |
+
elif hash_state_dict_keys(state_dict) == "6bfcfb3b342cb286ce886889d519a77e":
|
| 852 |
+
config = {"has_image_input": False,"patch_size": [1, 2, 2],"in_dim": 36,"dim": 5120,"ffn_dim": 13824,"freq_dim": 256,"text_dim": 4096,"out_dim": 16,"num_heads": 40,"num_layers": 40,"eps": 1e-6}
|
| 853 |
+
elif hash_state_dict_keys(state_dict) == "6d6ccde6845b95ad9114ab993d917893":
|
| 854 |
+
config = {"has_image_input": False,"patch_size": [1, 2, 2],"in_dim": 36,"dim": 1536,"ffn_dim": 8960,"freq_dim": 256,"text_dim": 4096,"out_dim": 16,"num_heads": 12,"num_layers": 30,"eps": 1e-6}
|
| 855 |
+
elif hash_state_dict_keys(state_dict) == "349723183fc063b2bfc10bb2835cf677":
|
| 856 |
+
config = {"has_image_input": False,"patch_size": [1, 2, 2],"in_dim": 48,"dim": 1536,"ffn_dim": 8960,"freq_dim": 256,"text_dim": 4096,"out_dim": 16,"num_heads": 12,"num_layers": 30,"eps": 1e-6}
|
| 857 |
+
elif hash_state_dict_keys(state_dict) == "efa44cddf936c70abd0ea28b6cbe946c":
|
| 858 |
+
config = {"has_image_input": False,"patch_size": [1, 2, 2],"in_dim": 48,"dim": 5120,"ffn_dim": 13824,"freq_dim": 256,"text_dim": 4096,"out_dim": 16,"num_heads": 40,"num_layers": 40,"eps": 1e-6}
|
| 859 |
+
elif hash_state_dict_keys(state_dict) == "3ef3b1f8e1dab83d5b71fd7b617f859f":
|
| 860 |
+
config = {"has_image_input": False,"patch_size": [1, 2, 2],"in_dim": 36,"dim": 5120,"ffn_dim": 13824,"freq_dim": 256,"text_dim": 4096,"out_dim": 16,"num_heads": 40,"num_layers": 40,"eps": 1e-6,"has_image_pos_emb": False}
|
| 861 |
+
else:
|
| 862 |
+
config = {}
|
| 863 |
+
return state_dict, config
|
| 864 |
+
|
custom_nodes/ComfyUI-FlashVSR_Ultra_Fast/src/models/wan_video_vae.py
ADDED
|
@@ -0,0 +1,847 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from einops import rearrange, repeat
|
| 2 |
+
|
| 3 |
+
import torch
|
| 4 |
+
import torch.nn as nn
|
| 5 |
+
import torch.nn.functional as F
|
| 6 |
+
from tqdm import tqdm
|
| 7 |
+
|
| 8 |
+
CACHE_T = 2
|
| 9 |
+
|
| 10 |
+
|
| 11 |
+
def check_is_instance(model, module_class):
|
| 12 |
+
if isinstance(model, module_class):
|
| 13 |
+
return True
|
| 14 |
+
if hasattr(model, "module") and isinstance(model.module, module_class):
|
| 15 |
+
return True
|
| 16 |
+
return False
|
| 17 |
+
|
| 18 |
+
|
| 19 |
+
def block_causal_mask(x, block_size):
|
| 20 |
+
# params
|
| 21 |
+
b, n, s, _, device = *x.size(), x.device
|
| 22 |
+
assert s % block_size == 0
|
| 23 |
+
num_blocks = s // block_size
|
| 24 |
+
|
| 25 |
+
# build mask
|
| 26 |
+
mask = torch.zeros(b, n, s, s, dtype=torch.bool, device=device)
|
| 27 |
+
for i in range(num_blocks):
|
| 28 |
+
mask[:, :,
|
| 29 |
+
i * block_size:(i + 1) * block_size, :(i + 1) * block_size] = 1
|
| 30 |
+
return mask
|
| 31 |
+
|
| 32 |
+
|
| 33 |
+
class CausalConv3d(nn.Conv3d):
|
| 34 |
+
"""
|
| 35 |
+
Causal 3d convolusion.
|
| 36 |
+
"""
|
| 37 |
+
|
| 38 |
+
def __init__(self, *args, **kwargs):
|
| 39 |
+
super().__init__(*args, **kwargs)
|
| 40 |
+
self._padding = (self.padding[2], self.padding[2], self.padding[1],
|
| 41 |
+
self.padding[1], 2 * self.padding[0], 0)
|
| 42 |
+
self.padding = (0, 0, 0)
|
| 43 |
+
|
| 44 |
+
def forward(self, x, cache_x=None):
|
| 45 |
+
padding = list(self._padding)
|
| 46 |
+
if cache_x is not None and self._padding[4] > 0:
|
| 47 |
+
cache_x = cache_x.to(x.device)
|
| 48 |
+
# print('cache_x.shape', cache_x.shape, 'x.shape', x.shape)
|
| 49 |
+
x = torch.cat([cache_x, x], dim=2)
|
| 50 |
+
padding[4] -= cache_x.shape[2]
|
| 51 |
+
x = F.pad(x, padding)
|
| 52 |
+
|
| 53 |
+
return super().forward(x)
|
| 54 |
+
|
| 55 |
+
|
| 56 |
+
class RMS_norm(nn.Module):
|
| 57 |
+
|
| 58 |
+
def __init__(self, dim, channel_first=True, images=True, bias=False):
|
| 59 |
+
super().__init__()
|
| 60 |
+
broadcastable_dims = (1, 1, 1) if not images else (1, 1)
|
| 61 |
+
shape = (dim, *broadcastable_dims) if channel_first else (dim,)
|
| 62 |
+
|
| 63 |
+
self.channel_first = channel_first
|
| 64 |
+
self.scale = dim**0.5
|
| 65 |
+
self.gamma = nn.Parameter(torch.ones(shape))
|
| 66 |
+
self.bias = nn.Parameter(torch.zeros(shape)) if bias else 0.
|
| 67 |
+
|
| 68 |
+
def forward(self, x):
|
| 69 |
+
return F.normalize(
|
| 70 |
+
x, dim=(1 if self.channel_first else
|
| 71 |
+
-1)) * self.scale * self.gamma + self.bias
|
| 72 |
+
|
| 73 |
+
|
| 74 |
+
class Upsample(nn.Upsample):
|
| 75 |
+
|
| 76 |
+
def forward(self, x):
|
| 77 |
+
"""
|
| 78 |
+
Fix bfloat16 support for nearest neighbor interpolation.
|
| 79 |
+
"""
|
| 80 |
+
return super().forward(x.float()).type_as(x)
|
| 81 |
+
|
| 82 |
+
|
| 83 |
+
class Resample(nn.Module):
|
| 84 |
+
|
| 85 |
+
def __init__(self, dim, mode):
|
| 86 |
+
assert mode in ('none', 'upsample2d', 'upsample3d', 'downsample2d',
|
| 87 |
+
'downsample3d')
|
| 88 |
+
super().__init__()
|
| 89 |
+
self.dim = dim
|
| 90 |
+
self.mode = mode
|
| 91 |
+
|
| 92 |
+
# layers
|
| 93 |
+
if mode == 'upsample2d':
|
| 94 |
+
self.resample = nn.Sequential(
|
| 95 |
+
Upsample(scale_factor=(2., 2.), mode='nearest-exact'),
|
| 96 |
+
nn.Conv2d(dim, dim // 2, 3, padding=1))
|
| 97 |
+
elif mode == 'upsample3d':
|
| 98 |
+
self.resample = nn.Sequential(
|
| 99 |
+
Upsample(scale_factor=(2., 2.), mode='nearest-exact'),
|
| 100 |
+
nn.Conv2d(dim, dim // 2, 3, padding=1))
|
| 101 |
+
self.time_conv = CausalConv3d(dim,
|
| 102 |
+
dim * 2, (3, 1, 1),
|
| 103 |
+
padding=(1, 0, 0))
|
| 104 |
+
|
| 105 |
+
elif mode == 'downsample2d':
|
| 106 |
+
self.resample = nn.Sequential(
|
| 107 |
+
nn.ZeroPad2d((0, 1, 0, 1)),
|
| 108 |
+
nn.Conv2d(dim, dim, 3, stride=(2, 2)))
|
| 109 |
+
elif mode == 'downsample3d':
|
| 110 |
+
self.resample = nn.Sequential(
|
| 111 |
+
nn.ZeroPad2d((0, 1, 0, 1)),
|
| 112 |
+
nn.Conv2d(dim, dim, 3, stride=(2, 2)))
|
| 113 |
+
self.time_conv = CausalConv3d(dim,
|
| 114 |
+
dim, (3, 1, 1),
|
| 115 |
+
stride=(2, 1, 1),
|
| 116 |
+
padding=(0, 0, 0))
|
| 117 |
+
|
| 118 |
+
else:
|
| 119 |
+
self.resample = nn.Identity()
|
| 120 |
+
|
| 121 |
+
def forward(self, x, feat_cache=None, feat_idx=[0]):
|
| 122 |
+
b, c, t, h, w = x.size()
|
| 123 |
+
if self.mode == 'upsample3d':
|
| 124 |
+
if feat_cache is not None:
|
| 125 |
+
idx = feat_idx[0]
|
| 126 |
+
if feat_cache[idx] is None:
|
| 127 |
+
feat_cache[idx] = 'Rep'
|
| 128 |
+
feat_idx[0] += 1
|
| 129 |
+
else:
|
| 130 |
+
|
| 131 |
+
cache_x = x[:, :, -CACHE_T:, :, :].clone()
|
| 132 |
+
if cache_x.shape[2] < 2 and feat_cache[
|
| 133 |
+
idx] is not None and feat_cache[idx] != 'Rep':
|
| 134 |
+
# cache last frame of last two chunk
|
| 135 |
+
cache_x = torch.cat([
|
| 136 |
+
feat_cache[idx][:, :, -1, :, :].unsqueeze(2).to(
|
| 137 |
+
cache_x.device), cache_x
|
| 138 |
+
],
|
| 139 |
+
dim=2)
|
| 140 |
+
if cache_x.shape[2] < 2 and feat_cache[
|
| 141 |
+
idx] is not None and feat_cache[idx] == 'Rep':
|
| 142 |
+
cache_x = torch.cat([
|
| 143 |
+
torch.zeros_like(cache_x).to(cache_x.device),
|
| 144 |
+
cache_x
|
| 145 |
+
],
|
| 146 |
+
dim=2)
|
| 147 |
+
if feat_cache[idx] == 'Rep':
|
| 148 |
+
x = self.time_conv(x)
|
| 149 |
+
else:
|
| 150 |
+
x = self.time_conv(x, feat_cache[idx])
|
| 151 |
+
feat_cache[idx] = cache_x
|
| 152 |
+
feat_idx[0] += 1
|
| 153 |
+
|
| 154 |
+
x = x.reshape(b, 2, c, t, h, w)
|
| 155 |
+
x = torch.stack((x[:, 0, :, :, :, :], x[:, 1, :, :, :, :]),
|
| 156 |
+
3)
|
| 157 |
+
x = x.reshape(b, c, t * 2, h, w)
|
| 158 |
+
t = x.shape[2]
|
| 159 |
+
x = rearrange(x, 'b c t h w -> (b t) c h w')
|
| 160 |
+
x = self.resample(x)
|
| 161 |
+
x = rearrange(x, '(b t) c h w -> b c t h w', t=t)
|
| 162 |
+
|
| 163 |
+
if self.mode == 'downsample3d':
|
| 164 |
+
if feat_cache is not None:
|
| 165 |
+
idx = feat_idx[0]
|
| 166 |
+
if feat_cache[idx] is None:
|
| 167 |
+
feat_cache[idx] = x.clone()
|
| 168 |
+
feat_idx[0] += 1
|
| 169 |
+
else:
|
| 170 |
+
cache_x = x[:, :, -1:, :, :].clone()
|
| 171 |
+
x = self.time_conv(
|
| 172 |
+
torch.cat([feat_cache[idx][:, :, -1:, :, :], x], 2))
|
| 173 |
+
feat_cache[idx] = cache_x
|
| 174 |
+
feat_idx[0] += 1
|
| 175 |
+
return x
|
| 176 |
+
|
| 177 |
+
def init_weight(self, conv):
|
| 178 |
+
conv_weight = conv.weight
|
| 179 |
+
nn.init.zeros_(conv_weight)
|
| 180 |
+
c1, c2, t, h, w = conv_weight.size()
|
| 181 |
+
one_matrix = torch.eye(c1, c2)
|
| 182 |
+
init_matrix = one_matrix
|
| 183 |
+
nn.init.zeros_(conv_weight)
|
| 184 |
+
conv_weight.data[:, :, 1, 0, 0] = init_matrix
|
| 185 |
+
conv.weight.data.copy_(conv_weight)
|
| 186 |
+
nn.init.zeros_(conv.bias.data)
|
| 187 |
+
|
| 188 |
+
def init_weight2(self, conv):
|
| 189 |
+
conv_weight = conv.weight.data
|
| 190 |
+
nn.init.zeros_(conv_weight)
|
| 191 |
+
c1, c2, t, h, w = conv_weight.size()
|
| 192 |
+
init_matrix = torch.eye(c1 // 2, c2)
|
| 193 |
+
conv_weight[:c1 // 2, :, -1, 0, 0] = init_matrix
|
| 194 |
+
conv_weight[c1 // 2:, :, -1, 0, 0] = init_matrix
|
| 195 |
+
conv.weight.data.copy_(conv_weight)
|
| 196 |
+
nn.init.zeros_(conv.bias.data)
|
| 197 |
+
|
| 198 |
+
|
| 199 |
+
class ResidualBlock(nn.Module):
|
| 200 |
+
|
| 201 |
+
def __init__(self, in_dim, out_dim, dropout=0.0):
|
| 202 |
+
super().__init__()
|
| 203 |
+
self.in_dim = in_dim
|
| 204 |
+
self.out_dim = out_dim
|
| 205 |
+
|
| 206 |
+
# layers
|
| 207 |
+
self.residual = nn.Sequential(
|
| 208 |
+
RMS_norm(in_dim, images=False), nn.SiLU(),
|
| 209 |
+
CausalConv3d(in_dim, out_dim, 3, padding=1),
|
| 210 |
+
RMS_norm(out_dim, images=False), nn.SiLU(), nn.Dropout(dropout),
|
| 211 |
+
CausalConv3d(out_dim, out_dim, 3, padding=1))
|
| 212 |
+
self.shortcut = CausalConv3d(in_dim, out_dim, 1) \
|
| 213 |
+
if in_dim != out_dim else nn.Identity()
|
| 214 |
+
|
| 215 |
+
def forward(self, x, feat_cache=None, feat_idx=[0]):
|
| 216 |
+
h = self.shortcut(x)
|
| 217 |
+
for layer in self.residual:
|
| 218 |
+
if check_is_instance(layer, CausalConv3d) and feat_cache is not None:
|
| 219 |
+
idx = feat_idx[0]
|
| 220 |
+
cache_x = x[:, :, -CACHE_T:, :, :].clone()
|
| 221 |
+
if cache_x.shape[2] < 2 and feat_cache[idx] is not None:
|
| 222 |
+
# cache last frame of last two chunk
|
| 223 |
+
cache_x = torch.cat([
|
| 224 |
+
feat_cache[idx][:, :, -1, :, :].unsqueeze(2).to(
|
| 225 |
+
cache_x.device), cache_x
|
| 226 |
+
],
|
| 227 |
+
dim=2)
|
| 228 |
+
x = layer(x, feat_cache[idx])
|
| 229 |
+
feat_cache[idx] = cache_x
|
| 230 |
+
feat_idx[0] += 1
|
| 231 |
+
else:
|
| 232 |
+
x = layer(x)
|
| 233 |
+
return x + h
|
| 234 |
+
|
| 235 |
+
|
| 236 |
+
class AttentionBlock(nn.Module):
|
| 237 |
+
"""
|
| 238 |
+
Causal self-attention with a single head.
|
| 239 |
+
"""
|
| 240 |
+
|
| 241 |
+
def __init__(self, dim):
|
| 242 |
+
super().__init__()
|
| 243 |
+
self.dim = dim
|
| 244 |
+
|
| 245 |
+
# layers
|
| 246 |
+
self.norm = RMS_norm(dim)
|
| 247 |
+
self.to_qkv = nn.Conv2d(dim, dim * 3, 1)
|
| 248 |
+
self.proj = nn.Conv2d(dim, dim, 1)
|
| 249 |
+
|
| 250 |
+
# zero out the last layer params
|
| 251 |
+
nn.init.zeros_(self.proj.weight)
|
| 252 |
+
|
| 253 |
+
def forward(self, x):
|
| 254 |
+
identity = x
|
| 255 |
+
b, c, t, h, w = x.size()
|
| 256 |
+
x = rearrange(x, 'b c t h w -> (b t) c h w')
|
| 257 |
+
x = self.norm(x)
|
| 258 |
+
# compute query, key, value
|
| 259 |
+
q, k, v = self.to_qkv(x).reshape(b * t, 1, c * 3, -1).permute(
|
| 260 |
+
0, 1, 3, 2).contiguous().chunk(3, dim=-1)
|
| 261 |
+
|
| 262 |
+
# apply attention
|
| 263 |
+
x = F.scaled_dot_product_attention(
|
| 264 |
+
q,
|
| 265 |
+
k,
|
| 266 |
+
v,
|
| 267 |
+
#attn_mask=block_causal_mask(q, block_size=h * w)
|
| 268 |
+
)
|
| 269 |
+
x = x.squeeze(1).permute(0, 2, 1).reshape(b * t, c, h, w)
|
| 270 |
+
|
| 271 |
+
# output
|
| 272 |
+
x = self.proj(x)
|
| 273 |
+
x = rearrange(x, '(b t) c h w-> b c t h w', t=t)
|
| 274 |
+
return x + identity
|
| 275 |
+
|
| 276 |
+
|
| 277 |
+
class Encoder3d(nn.Module):
|
| 278 |
+
|
| 279 |
+
def __init__(self,
|
| 280 |
+
dim=128,
|
| 281 |
+
z_dim=4,
|
| 282 |
+
dim_mult=[1, 2, 4, 4],
|
| 283 |
+
num_res_blocks=2,
|
| 284 |
+
attn_scales=[],
|
| 285 |
+
temperal_downsample=[True, True, False],
|
| 286 |
+
dropout=0.0):
|
| 287 |
+
super().__init__()
|
| 288 |
+
self.dim = dim
|
| 289 |
+
self.z_dim = z_dim
|
| 290 |
+
self.dim_mult = dim_mult
|
| 291 |
+
self.num_res_blocks = num_res_blocks
|
| 292 |
+
self.attn_scales = attn_scales
|
| 293 |
+
self.temperal_downsample = temperal_downsample
|
| 294 |
+
|
| 295 |
+
# dimensions
|
| 296 |
+
dims = [dim * u for u in [1] + dim_mult]
|
| 297 |
+
scale = 1.0
|
| 298 |
+
|
| 299 |
+
# init block
|
| 300 |
+
self.conv1 = CausalConv3d(3, dims[0], 3, padding=1)
|
| 301 |
+
|
| 302 |
+
# downsample blocks
|
| 303 |
+
downsamples = []
|
| 304 |
+
for i, (in_dim, out_dim) in enumerate(zip(dims[:-1], dims[1:])):
|
| 305 |
+
# residual (+attention) blocks
|
| 306 |
+
for _ in range(num_res_blocks):
|
| 307 |
+
downsamples.append(ResidualBlock(in_dim, out_dim, dropout))
|
| 308 |
+
if scale in attn_scales:
|
| 309 |
+
downsamples.append(AttentionBlock(out_dim))
|
| 310 |
+
in_dim = out_dim
|
| 311 |
+
|
| 312 |
+
# downsample block
|
| 313 |
+
if i != len(dim_mult) - 1:
|
| 314 |
+
mode = 'downsample3d' if temperal_downsample[
|
| 315 |
+
i] else 'downsample2d'
|
| 316 |
+
downsamples.append(Resample(out_dim, mode=mode))
|
| 317 |
+
scale /= 2.0
|
| 318 |
+
self.downsamples = nn.Sequential(*downsamples)
|
| 319 |
+
|
| 320 |
+
# middle blocks
|
| 321 |
+
self.middle = nn.Sequential(ResidualBlock(out_dim, out_dim, dropout),
|
| 322 |
+
AttentionBlock(out_dim),
|
| 323 |
+
ResidualBlock(out_dim, out_dim, dropout))
|
| 324 |
+
|
| 325 |
+
# output blocks
|
| 326 |
+
self.head = nn.Sequential(RMS_norm(out_dim, images=False), nn.SiLU(),
|
| 327 |
+
CausalConv3d(out_dim, z_dim, 3, padding=1))
|
| 328 |
+
|
| 329 |
+
def forward(self, x, feat_cache=None, feat_idx=[0]):
|
| 330 |
+
if feat_cache is not None:
|
| 331 |
+
idx = feat_idx[0]
|
| 332 |
+
cache_x = x[:, :, -CACHE_T:, :, :].clone()
|
| 333 |
+
if cache_x.shape[2] < 2 and feat_cache[idx] is not None:
|
| 334 |
+
# cache last frame of last two chunk
|
| 335 |
+
cache_x = torch.cat([
|
| 336 |
+
feat_cache[idx][:, :, -1, :, :].unsqueeze(2).to(
|
| 337 |
+
cache_x.device), cache_x
|
| 338 |
+
],
|
| 339 |
+
dim=2)
|
| 340 |
+
x = self.conv1(x, feat_cache[idx])
|
| 341 |
+
feat_cache[idx] = cache_x
|
| 342 |
+
feat_idx[0] += 1
|
| 343 |
+
else:
|
| 344 |
+
x = self.conv1(x)
|
| 345 |
+
|
| 346 |
+
## downsamples
|
| 347 |
+
for layer in self.downsamples:
|
| 348 |
+
if feat_cache is not None:
|
| 349 |
+
x = layer(x, feat_cache, feat_idx)
|
| 350 |
+
else:
|
| 351 |
+
x = layer(x)
|
| 352 |
+
|
| 353 |
+
## middle
|
| 354 |
+
for layer in self.middle:
|
| 355 |
+
if check_is_instance(layer, ResidualBlock) and feat_cache is not None:
|
| 356 |
+
x = layer(x, feat_cache, feat_idx)
|
| 357 |
+
else:
|
| 358 |
+
x = layer(x)
|
| 359 |
+
|
| 360 |
+
## head
|
| 361 |
+
for layer in self.head:
|
| 362 |
+
if check_is_instance(layer, CausalConv3d) and feat_cache is not None:
|
| 363 |
+
idx = feat_idx[0]
|
| 364 |
+
cache_x = x[:, :, -CACHE_T:, :, :].clone()
|
| 365 |
+
if cache_x.shape[2] < 2 and feat_cache[idx] is not None:
|
| 366 |
+
# cache last frame of last two chunk
|
| 367 |
+
cache_x = torch.cat([
|
| 368 |
+
feat_cache[idx][:, :, -1, :, :].unsqueeze(2).to(
|
| 369 |
+
cache_x.device), cache_x
|
| 370 |
+
],
|
| 371 |
+
dim=2)
|
| 372 |
+
x = layer(x, feat_cache[idx])
|
| 373 |
+
feat_cache[idx] = cache_x
|
| 374 |
+
feat_idx[0] += 1
|
| 375 |
+
else:
|
| 376 |
+
x = layer(x)
|
| 377 |
+
return x
|
| 378 |
+
|
| 379 |
+
|
| 380 |
+
class Decoder3d(nn.Module):
|
| 381 |
+
|
| 382 |
+
def __init__(self,
|
| 383 |
+
dim=128,
|
| 384 |
+
z_dim=4,
|
| 385 |
+
dim_mult=[1, 2, 4, 4],
|
| 386 |
+
num_res_blocks=2,
|
| 387 |
+
attn_scales=[],
|
| 388 |
+
temperal_upsample=[False, True, True],
|
| 389 |
+
dropout=0.0):
|
| 390 |
+
super().__init__()
|
| 391 |
+
self.dim = dim
|
| 392 |
+
self.z_dim = z_dim
|
| 393 |
+
self.dim_mult = dim_mult
|
| 394 |
+
self.num_res_blocks = num_res_blocks
|
| 395 |
+
self.attn_scales = attn_scales
|
| 396 |
+
self.temperal_upsample = temperal_upsample
|
| 397 |
+
|
| 398 |
+
# dimensions
|
| 399 |
+
dims = [dim * u for u in [dim_mult[-1]] + dim_mult[::-1]]
|
| 400 |
+
scale = 1.0 / 2**(len(dim_mult) - 2)
|
| 401 |
+
|
| 402 |
+
# init block
|
| 403 |
+
self.conv1 = CausalConv3d(z_dim, dims[0], 3, padding=1)
|
| 404 |
+
|
| 405 |
+
# middle blocks
|
| 406 |
+
self.middle = nn.Sequential(ResidualBlock(dims[0], dims[0], dropout),
|
| 407 |
+
AttentionBlock(dims[0]),
|
| 408 |
+
ResidualBlock(dims[0], dims[0], dropout))
|
| 409 |
+
|
| 410 |
+
# upsample blocks
|
| 411 |
+
upsamples = []
|
| 412 |
+
for i, (in_dim, out_dim) in enumerate(zip(dims[:-1], dims[1:])):
|
| 413 |
+
# residual (+attention) blocks
|
| 414 |
+
if i == 1 or i == 2 or i == 3:
|
| 415 |
+
in_dim = in_dim // 2
|
| 416 |
+
for _ in range(num_res_blocks + 1):
|
| 417 |
+
upsamples.append(ResidualBlock(in_dim, out_dim, dropout))
|
| 418 |
+
if scale in attn_scales:
|
| 419 |
+
upsamples.append(AttentionBlock(out_dim))
|
| 420 |
+
in_dim = out_dim
|
| 421 |
+
|
| 422 |
+
# upsample block
|
| 423 |
+
if i != len(dim_mult) - 1:
|
| 424 |
+
mode = 'upsample3d' if temperal_upsample[i] else 'upsample2d'
|
| 425 |
+
upsamples.append(Resample(out_dim, mode=mode))
|
| 426 |
+
scale *= 2.0
|
| 427 |
+
self.upsamples = nn.Sequential(*upsamples)
|
| 428 |
+
|
| 429 |
+
# output blocks
|
| 430 |
+
self.head = nn.Sequential(RMS_norm(out_dim, images=False), nn.SiLU(),
|
| 431 |
+
CausalConv3d(out_dim, 3, 3, padding=1))
|
| 432 |
+
|
| 433 |
+
def forward(self, x, feat_cache=None, feat_idx=[0]):
|
| 434 |
+
## conv1
|
| 435 |
+
if feat_cache is not None:
|
| 436 |
+
idx = feat_idx[0]
|
| 437 |
+
cache_x = x[:, :, -CACHE_T:, :, :].clone()
|
| 438 |
+
if cache_x.shape[2] < 2 and feat_cache[idx] is not None:
|
| 439 |
+
# cache last frame of last two chunk
|
| 440 |
+
cache_x = torch.cat([
|
| 441 |
+
feat_cache[idx][:, :, -1, :, :].unsqueeze(2).to(
|
| 442 |
+
cache_x.device), cache_x
|
| 443 |
+
],
|
| 444 |
+
dim=2)
|
| 445 |
+
x = self.conv1(x, feat_cache[idx])
|
| 446 |
+
feat_cache[idx] = cache_x
|
| 447 |
+
feat_idx[0] += 1
|
| 448 |
+
else:
|
| 449 |
+
x = self.conv1(x)
|
| 450 |
+
|
| 451 |
+
## middle
|
| 452 |
+
for layer in self.middle:
|
| 453 |
+
if check_is_instance(layer, ResidualBlock) and feat_cache is not None:
|
| 454 |
+
x = layer(x, feat_cache, feat_idx)
|
| 455 |
+
else:
|
| 456 |
+
x = layer(x)
|
| 457 |
+
|
| 458 |
+
## upsamples
|
| 459 |
+
for layer in self.upsamples:
|
| 460 |
+
if feat_cache is not None:
|
| 461 |
+
x = layer(x, feat_cache, feat_idx)
|
| 462 |
+
else:
|
| 463 |
+
x = layer(x)
|
| 464 |
+
|
| 465 |
+
## head
|
| 466 |
+
for layer in self.head:
|
| 467 |
+
if check_is_instance(layer, CausalConv3d) and feat_cache is not None:
|
| 468 |
+
idx = feat_idx[0]
|
| 469 |
+
cache_x = x[:, :, -CACHE_T:, :, :].clone()
|
| 470 |
+
if cache_x.shape[2] < 2 and feat_cache[idx] is not None:
|
| 471 |
+
# cache last frame of last two chunk
|
| 472 |
+
cache_x = torch.cat([
|
| 473 |
+
feat_cache[idx][:, :, -1, :, :].unsqueeze(2).to(
|
| 474 |
+
cache_x.device), cache_x
|
| 475 |
+
],
|
| 476 |
+
dim=2)
|
| 477 |
+
x = layer(x, feat_cache[idx])
|
| 478 |
+
feat_cache[idx] = cache_x
|
| 479 |
+
feat_idx[0] += 1
|
| 480 |
+
else:
|
| 481 |
+
x = layer(x)
|
| 482 |
+
return x
|
| 483 |
+
|
| 484 |
+
|
| 485 |
+
def count_conv3d(model):
|
| 486 |
+
count = 0
|
| 487 |
+
for m in model.modules():
|
| 488 |
+
if check_is_instance(m, CausalConv3d):
|
| 489 |
+
count += 1
|
| 490 |
+
return count
|
| 491 |
+
|
| 492 |
+
|
| 493 |
+
class VideoVAE_(nn.Module):
|
| 494 |
+
|
| 495 |
+
def __init__(self,
|
| 496 |
+
dim=96,
|
| 497 |
+
z_dim=16,
|
| 498 |
+
dim_mult=[1, 2, 4, 4],
|
| 499 |
+
num_res_blocks=2,
|
| 500 |
+
attn_scales=[],
|
| 501 |
+
temperal_downsample=[False, True, True],
|
| 502 |
+
dropout=0.0):
|
| 503 |
+
super().__init__()
|
| 504 |
+
self.dim = dim
|
| 505 |
+
self.z_dim = z_dim
|
| 506 |
+
self.dim_mult = dim_mult
|
| 507 |
+
self.num_res_blocks = num_res_blocks
|
| 508 |
+
self.attn_scales = attn_scales
|
| 509 |
+
self.temperal_downsample = temperal_downsample
|
| 510 |
+
self.temperal_upsample = temperal_downsample[::-1]
|
| 511 |
+
|
| 512 |
+
# modules
|
| 513 |
+
self.encoder = Encoder3d(dim, z_dim * 2, dim_mult, num_res_blocks,
|
| 514 |
+
attn_scales, self.temperal_downsample, dropout)
|
| 515 |
+
self.conv1 = CausalConv3d(z_dim * 2, z_dim * 2, 1)
|
| 516 |
+
self.conv2 = CausalConv3d(z_dim, z_dim, 1)
|
| 517 |
+
self.decoder = Decoder3d(dim, z_dim, dim_mult, num_res_blocks,
|
| 518 |
+
attn_scales, self.temperal_upsample, dropout)
|
| 519 |
+
|
| 520 |
+
def forward(self, x):
|
| 521 |
+
mu, log_var = self.encode(x)
|
| 522 |
+
z = self.reparameterize(mu, log_var)
|
| 523 |
+
x_recon = self.decode(z)
|
| 524 |
+
return x_recon, mu, log_var
|
| 525 |
+
|
| 526 |
+
def encode(self, x, scale):
|
| 527 |
+
self.clear_cache()
|
| 528 |
+
## cache
|
| 529 |
+
t = x.shape[2]
|
| 530 |
+
iter_ = 1 + (t - 1) // 4
|
| 531 |
+
|
| 532 |
+
for i in range(iter_):
|
| 533 |
+
self._enc_conv_idx = [0]
|
| 534 |
+
if i == 0:
|
| 535 |
+
out = self.encoder(x[:, :, :1, :, :],
|
| 536 |
+
feat_cache=self._enc_feat_map,
|
| 537 |
+
feat_idx=self._enc_conv_idx)
|
| 538 |
+
else:
|
| 539 |
+
out_ = self.encoder(x[:, :, 1 + 4 * (i - 1):1 + 4 * i, :, :],
|
| 540 |
+
feat_cache=self._enc_feat_map,
|
| 541 |
+
feat_idx=self._enc_conv_idx)
|
| 542 |
+
out = torch.cat([out, out_], 2)
|
| 543 |
+
mu, log_var = self.conv1(out).chunk(2, dim=1)
|
| 544 |
+
if isinstance(scale[0], torch.Tensor):
|
| 545 |
+
scale = [s.to(dtype=mu.dtype, device=mu.device) for s in scale]
|
| 546 |
+
mu = (mu - scale[0].view(1, self.z_dim, 1, 1, 1)) * scale[1].view(
|
| 547 |
+
1, self.z_dim, 1, 1, 1)
|
| 548 |
+
else:
|
| 549 |
+
scale = scale.to(dtype=mu.dtype, device=mu.device)
|
| 550 |
+
mu = (mu - scale[0]) * scale[1]
|
| 551 |
+
return mu
|
| 552 |
+
|
| 553 |
+
def decode(self, z, scale):
|
| 554 |
+
self.clear_cache()
|
| 555 |
+
# z: [b,c,t,h,w]
|
| 556 |
+
if isinstance(scale[0], torch.Tensor):
|
| 557 |
+
scale = [s.to(dtype=z.dtype, device=z.device) for s in scale]
|
| 558 |
+
z = z / scale[1].view(1, self.z_dim, 1, 1, 1) + scale[0].view(
|
| 559 |
+
1, self.z_dim, 1, 1, 1)
|
| 560 |
+
else:
|
| 561 |
+
scale = scale.to(dtype=z.dtype, device=z.device)
|
| 562 |
+
z = z / scale[1] + scale[0]
|
| 563 |
+
iter_ = z.shape[2]
|
| 564 |
+
x = self.conv2(z)
|
| 565 |
+
for i in range(iter_):
|
| 566 |
+
self._conv_idx = [0]
|
| 567 |
+
if i == 0:
|
| 568 |
+
out = self.decoder(x[:, :, i:i + 1, :, :],
|
| 569 |
+
feat_cache=self._feat_map,
|
| 570 |
+
feat_idx=self._conv_idx)
|
| 571 |
+
else:
|
| 572 |
+
out_ = self.decoder(x[:, :, i:i + 1, :, :],
|
| 573 |
+
feat_cache=self._feat_map,
|
| 574 |
+
feat_idx=self._conv_idx)
|
| 575 |
+
out = torch.cat([out, out_], 2) # may add tensor offload
|
| 576 |
+
return out
|
| 577 |
+
|
| 578 |
+
|
| 579 |
+
def stream_decode(self, z, scale):
|
| 580 |
+
# self.clear_cache()
|
| 581 |
+
# z: [b,c,t,h,w]
|
| 582 |
+
if isinstance(scale[0], torch.Tensor):
|
| 583 |
+
scale = [s.to(dtype=z.dtype, device=z.device) for s in scale]
|
| 584 |
+
z = z / scale[1].view(1, self.z_dim, 1, 1, 1) + scale[0].view(
|
| 585 |
+
1, self.z_dim, 1, 1, 1)
|
| 586 |
+
else:
|
| 587 |
+
scale = scale.to(dtype=z.dtype, device=z.device)
|
| 588 |
+
z = z / scale[1] + scale[0]
|
| 589 |
+
iter_ = z.shape[2]
|
| 590 |
+
x = self.conv2(z)
|
| 591 |
+
for i in range(iter_):
|
| 592 |
+
self._conv_idx = [0]
|
| 593 |
+
if i == 0:
|
| 594 |
+
out = self.decoder(x[:, :, i:i + 1, :, :],
|
| 595 |
+
feat_cache=self._feat_map,
|
| 596 |
+
feat_idx=self._conv_idx)
|
| 597 |
+
else:
|
| 598 |
+
out_ = self.decoder(x[:, :, i:i + 1, :, :],
|
| 599 |
+
feat_cache=self._feat_map,
|
| 600 |
+
feat_idx=self._conv_idx)
|
| 601 |
+
out = torch.cat([out, out_], 2) # may add tensor offload
|
| 602 |
+
return out
|
| 603 |
+
|
| 604 |
+
def reparameterize(self, mu, log_var):
|
| 605 |
+
std = torch.exp(0.5 * log_var)
|
| 606 |
+
eps = torch.randn_like(std)
|
| 607 |
+
return eps * std + mu
|
| 608 |
+
|
| 609 |
+
def sample(self, imgs, deterministic=False):
|
| 610 |
+
mu, log_var = self.encode(imgs)
|
| 611 |
+
if deterministic:
|
| 612 |
+
return mu
|
| 613 |
+
std = torch.exp(0.5 * log_var.clamp(-30.0, 20.0))
|
| 614 |
+
return mu + std * torch.randn_like(std)
|
| 615 |
+
|
| 616 |
+
def clear_cache(self):
|
| 617 |
+
self._conv_num = count_conv3d(self.decoder)
|
| 618 |
+
self._conv_idx = [0]
|
| 619 |
+
self._feat_map = [None] * self._conv_num
|
| 620 |
+
# print('self._feat_map', len(self._feat_map))
|
| 621 |
+
# cache encode
|
| 622 |
+
if self.encoder is not None:
|
| 623 |
+
self._enc_conv_num = count_conv3d(self.encoder)
|
| 624 |
+
self._enc_conv_idx = [0]
|
| 625 |
+
self._enc_feat_map = [None] * self._enc_conv_num
|
| 626 |
+
# print('self._enc_feat_map', len(self._enc_feat_map))
|
| 627 |
+
|
| 628 |
+
|
| 629 |
+
class WanVideoVAE(nn.Module):
|
| 630 |
+
|
| 631 |
+
def __init__(self, z_dim=16, dim=96):
|
| 632 |
+
super().__init__()
|
| 633 |
+
|
| 634 |
+
mean = [
|
| 635 |
+
-0.7571, -0.7089, -0.9113, 0.1075, -0.1745, 0.9653, -0.1517, 1.5508,
|
| 636 |
+
0.4134, -0.0715, 0.5517, -0.3632, -0.1922, -0.9497, 0.2503, -0.2921
|
| 637 |
+
]
|
| 638 |
+
std = [
|
| 639 |
+
2.8184, 1.4541, 2.3275, 2.6558, 1.2196, 1.7708, 2.6052, 2.0743,
|
| 640 |
+
3.2687, 2.1526, 2.8652, 1.5579, 1.6382, 1.1253, 2.8251, 1.9160
|
| 641 |
+
]
|
| 642 |
+
self.mean = torch.tensor(mean)
|
| 643 |
+
self.std = torch.tensor(std)
|
| 644 |
+
self.scale = [self.mean, 1.0 / self.std]
|
| 645 |
+
|
| 646 |
+
# init model
|
| 647 |
+
self.model = VideoVAE_(z_dim=z_dim, dim = dim).eval().requires_grad_(False)
|
| 648 |
+
self.upsampling_factor = 8
|
| 649 |
+
|
| 650 |
+
|
| 651 |
+
def build_1d_mask(self, length, left_bound, right_bound, border_width):
|
| 652 |
+
x = torch.ones((length,))
|
| 653 |
+
if not left_bound:
|
| 654 |
+
x[:border_width] = (torch.arange(border_width) + 1) / border_width
|
| 655 |
+
if not right_bound:
|
| 656 |
+
x[-border_width:] = torch.flip((torch.arange(border_width) + 1) / border_width, dims=(0,))
|
| 657 |
+
return x
|
| 658 |
+
|
| 659 |
+
|
| 660 |
+
def build_mask(self, data, is_bound, border_width):
|
| 661 |
+
_, _, _, H, W = data.shape
|
| 662 |
+
h = self.build_1d_mask(H, is_bound[0], is_bound[1], border_width[0])
|
| 663 |
+
w = self.build_1d_mask(W, is_bound[2], is_bound[3], border_width[1])
|
| 664 |
+
|
| 665 |
+
h = repeat(h, "H -> H W", H=H, W=W)
|
| 666 |
+
w = repeat(w, "W -> H W", H=H, W=W)
|
| 667 |
+
|
| 668 |
+
mask = torch.stack([h, w]).min(dim=0).values
|
| 669 |
+
mask = rearrange(mask, "H W -> 1 1 1 H W")
|
| 670 |
+
return mask
|
| 671 |
+
|
| 672 |
+
|
| 673 |
+
def tiled_decode(self, hidden_states, device, tile_size, tile_stride):
|
| 674 |
+
_, _, T, H, W = hidden_states.shape
|
| 675 |
+
size_h, size_w = tile_size
|
| 676 |
+
stride_h, stride_w = tile_stride
|
| 677 |
+
|
| 678 |
+
# Split tasks
|
| 679 |
+
tasks = []
|
| 680 |
+
for h in range(0, H, stride_h):
|
| 681 |
+
if (h-stride_h >= 0 and h-stride_h+size_h >= H): continue
|
| 682 |
+
for w in range(0, W, stride_w):
|
| 683 |
+
if (w-stride_w >= 0 and w-stride_w+size_w >= W): continue
|
| 684 |
+
h_, w_ = h + size_h, w + size_w
|
| 685 |
+
tasks.append((h, h_, w, w_))
|
| 686 |
+
|
| 687 |
+
data_device = "cpu"
|
| 688 |
+
computation_device = device
|
| 689 |
+
|
| 690 |
+
out_T = T * 4 - 3
|
| 691 |
+
weight = torch.zeros((1, 1, out_T, H * self.upsampling_factor, W * self.upsampling_factor), dtype=hidden_states.dtype, device=data_device)
|
| 692 |
+
values = torch.zeros((1, 3, out_T, H * self.upsampling_factor, W * self.upsampling_factor), dtype=hidden_states.dtype, device=data_device)
|
| 693 |
+
|
| 694 |
+
for h, h_, w, w_ in tqdm(tasks, desc="VAE decoding"):
|
| 695 |
+
hidden_states_batch = hidden_states[:, :, :, h:h_, w:w_].to(computation_device)
|
| 696 |
+
hidden_states_batch = self.model.decode(hidden_states_batch, self.scale).to(data_device)
|
| 697 |
+
|
| 698 |
+
mask = self.build_mask(
|
| 699 |
+
hidden_states_batch,
|
| 700 |
+
is_bound=(h==0, h_>=H, w==0, w_>=W),
|
| 701 |
+
border_width=((size_h - stride_h) * self.upsampling_factor, (size_w - stride_w) * self.upsampling_factor)
|
| 702 |
+
).to(dtype=hidden_states.dtype, device=data_device)
|
| 703 |
+
|
| 704 |
+
target_h = h * self.upsampling_factor
|
| 705 |
+
target_w = w * self.upsampling_factor
|
| 706 |
+
values[
|
| 707 |
+
:,
|
| 708 |
+
:,
|
| 709 |
+
:,
|
| 710 |
+
target_h:target_h + hidden_states_batch.shape[3],
|
| 711 |
+
target_w:target_w + hidden_states_batch.shape[4],
|
| 712 |
+
] += hidden_states_batch * mask
|
| 713 |
+
weight[
|
| 714 |
+
:,
|
| 715 |
+
:,
|
| 716 |
+
:,
|
| 717 |
+
target_h: target_h + hidden_states_batch.shape[3],
|
| 718 |
+
target_w: target_w + hidden_states_batch.shape[4],
|
| 719 |
+
] += mask
|
| 720 |
+
values = values / weight
|
| 721 |
+
values = values.clamp_(-1, 1)
|
| 722 |
+
return values
|
| 723 |
+
|
| 724 |
+
|
| 725 |
+
def tiled_encode(self, video, device, tile_size, tile_stride):
|
| 726 |
+
_, _, T, H, W = video.shape
|
| 727 |
+
size_h, size_w = tile_size
|
| 728 |
+
stride_h, stride_w = tile_stride
|
| 729 |
+
|
| 730 |
+
# Split tasks
|
| 731 |
+
tasks = []
|
| 732 |
+
for h in range(0, H, stride_h):
|
| 733 |
+
if (h-stride_h >= 0 and h-stride_h+size_h >= H): continue
|
| 734 |
+
for w in range(0, W, stride_w):
|
| 735 |
+
if (w-stride_w >= 0 and w-stride_w+size_w >= W): continue
|
| 736 |
+
h_, w_ = h + size_h, w + size_w
|
| 737 |
+
tasks.append((h, h_, w, w_))
|
| 738 |
+
|
| 739 |
+
data_device = "cpu"
|
| 740 |
+
computation_device = device
|
| 741 |
+
|
| 742 |
+
out_T = (T + 3) // 4
|
| 743 |
+
weight = torch.zeros((1, 1, out_T, H // self.upsampling_factor, W // self.upsampling_factor), dtype=video.dtype, device=data_device)
|
| 744 |
+
values = torch.zeros((1, 16, out_T, H // self.upsampling_factor, W // self.upsampling_factor), dtype=video.dtype, device=data_device)
|
| 745 |
+
|
| 746 |
+
for h, h_, w, w_ in tqdm(tasks, desc="VAE encoding"):
|
| 747 |
+
hidden_states_batch = video[:, :, :, h:h_, w:w_].to(computation_device)
|
| 748 |
+
hidden_states_batch = self.model.encode(hidden_states_batch, self.scale).to(data_device)
|
| 749 |
+
|
| 750 |
+
mask = self.build_mask(
|
| 751 |
+
hidden_states_batch,
|
| 752 |
+
is_bound=(h==0, h_>=H, w==0, w_>=W),
|
| 753 |
+
border_width=((size_h - stride_h) // self.upsampling_factor, (size_w - stride_w) // self.upsampling_factor)
|
| 754 |
+
).to(dtype=video.dtype, device=data_device)
|
| 755 |
+
|
| 756 |
+
target_h = h // self.upsampling_factor
|
| 757 |
+
target_w = w // self.upsampling_factor
|
| 758 |
+
values[
|
| 759 |
+
:,
|
| 760 |
+
:,
|
| 761 |
+
:,
|
| 762 |
+
target_h:target_h + hidden_states_batch.shape[3],
|
| 763 |
+
target_w:target_w + hidden_states_batch.shape[4],
|
| 764 |
+
] += hidden_states_batch * mask
|
| 765 |
+
weight[
|
| 766 |
+
:,
|
| 767 |
+
:,
|
| 768 |
+
:,
|
| 769 |
+
target_h: target_h + hidden_states_batch.shape[3],
|
| 770 |
+
target_w: target_w + hidden_states_batch.shape[4],
|
| 771 |
+
] += mask
|
| 772 |
+
values = values / weight
|
| 773 |
+
return values
|
| 774 |
+
|
| 775 |
+
|
| 776 |
+
def single_encode(self, video, device):
|
| 777 |
+
video = video.to(device)
|
| 778 |
+
x = self.model.encode(video, self.scale)
|
| 779 |
+
return x
|
| 780 |
+
|
| 781 |
+
|
| 782 |
+
def single_decode(self, hidden_state, device):
|
| 783 |
+
hidden_state = hidden_state.to(device)
|
| 784 |
+
video = self.model.decode(hidden_state, self.scale)
|
| 785 |
+
return video.clamp_(-1, 1)
|
| 786 |
+
|
| 787 |
+
|
| 788 |
+
def encode(self, videos, device, tiled=False, tile_size=(34, 34), tile_stride=(18, 16)):
|
| 789 |
+
|
| 790 |
+
videos = [video.to("cpu") for video in videos]
|
| 791 |
+
hidden_states = []
|
| 792 |
+
for video in videos:
|
| 793 |
+
video = video.unsqueeze(0)
|
| 794 |
+
if tiled:
|
| 795 |
+
tile_size = (tile_size[0] * 8, tile_size[1] * 8)
|
| 796 |
+
tile_stride = (tile_stride[0] * 8, tile_stride[1] * 8)
|
| 797 |
+
hidden_state = self.tiled_encode(video, device, tile_size, tile_stride)
|
| 798 |
+
else:
|
| 799 |
+
hidden_state = self.single_encode(video, device)
|
| 800 |
+
hidden_state = hidden_state.squeeze(0)
|
| 801 |
+
hidden_states.append(hidden_state)
|
| 802 |
+
hidden_states = torch.stack(hidden_states)
|
| 803 |
+
return hidden_states
|
| 804 |
+
|
| 805 |
+
|
| 806 |
+
def decode(self, hidden_states, device, tiled=False, tile_size=(34, 34), tile_stride=(18, 16)):
|
| 807 |
+
hidden_states = [hidden_state.to("cpu") for hidden_state in hidden_states]
|
| 808 |
+
videos = []
|
| 809 |
+
for hidden_state in hidden_states:
|
| 810 |
+
hidden_state = hidden_state.unsqueeze(0)
|
| 811 |
+
if tiled:
|
| 812 |
+
video = self.tiled_decode(hidden_state, device, tile_size, tile_stride)
|
| 813 |
+
else:
|
| 814 |
+
video = self.single_decode(hidden_state, device)
|
| 815 |
+
video = video.squeeze(0)
|
| 816 |
+
videos.append(video)
|
| 817 |
+
videos = torch.stack(videos)
|
| 818 |
+
return videos
|
| 819 |
+
|
| 820 |
+
def clear_cache(self):
|
| 821 |
+
self.model.clear_cache()
|
| 822 |
+
|
| 823 |
+
def stream_decode(self, hidden_states, tiled=False, tile_size=(34, 34), tile_stride=(18, 16)):
|
| 824 |
+
hidden_states = [hidden_state for hidden_state in hidden_states]
|
| 825 |
+
assert len(hidden_states) == 1
|
| 826 |
+
hidden_state = hidden_states[0]
|
| 827 |
+
video = self.model.stream_decode(hidden_state, self.scale)
|
| 828 |
+
return video
|
| 829 |
+
|
| 830 |
+
|
| 831 |
+
@staticmethod
|
| 832 |
+
def state_dict_converter():
|
| 833 |
+
return WanVideoVAEStateDictConverter()
|
| 834 |
+
|
| 835 |
+
|
| 836 |
+
class WanVideoVAEStateDictConverter:
|
| 837 |
+
|
| 838 |
+
def __init__(self):
|
| 839 |
+
pass
|
| 840 |
+
|
| 841 |
+
def from_civitai(self, state_dict):
|
| 842 |
+
state_dict_ = {}
|
| 843 |
+
if 'model_state' in state_dict:
|
| 844 |
+
state_dict = state_dict['model_state']
|
| 845 |
+
for name in state_dict:
|
| 846 |
+
state_dict_['model.' + name] = state_dict[name]
|
| 847 |
+
return state_dict_
|
custom_nodes/ComfyUI-FlashVSR_Ultra_Fast/src/pipelines/__init__.py
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from .flashvsr_full import FlashVSRFullPipeline
|
| 2 |
+
from .flashvsr_tiny import FlashVSRTinyPipeline
|
| 3 |
+
from .flashvsr_tiny_long import FlashVSRTinyLongPipeline
|
custom_nodes/ComfyUI-FlashVSR_Ultra_Fast/src/pipelines/base.py
ADDED
|
@@ -0,0 +1,130 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
|
| 2 |
+
import gc
|
| 3 |
+
import numpy as np
|
| 4 |
+
from PIL import Image
|
| 5 |
+
from torchvision.transforms import GaussianBlur
|
| 6 |
+
|
| 7 |
+
class BasePipeline(torch.nn.Module):
|
| 8 |
+
|
| 9 |
+
def __init__(self, device="cuda", torch_dtype=torch.float16, height_division_factor=64, width_division_factor=64):
|
| 10 |
+
super().__init__()
|
| 11 |
+
self.device = device
|
| 12 |
+
self.torch_dtype = torch_dtype
|
| 13 |
+
self.height_division_factor = height_division_factor
|
| 14 |
+
self.width_division_factor = width_division_factor
|
| 15 |
+
self.cpu_offload = False
|
| 16 |
+
self.model_names = []
|
| 17 |
+
|
| 18 |
+
|
| 19 |
+
def check_resize_height_width(self, height, width):
|
| 20 |
+
if height % self.height_division_factor != 0:
|
| 21 |
+
height = (height + self.height_division_factor - 1) // self.height_division_factor * self.height_division_factor
|
| 22 |
+
print(f"The height cannot be evenly divided by {self.height_division_factor}. We round it up to {height}.")
|
| 23 |
+
if width % self.width_division_factor != 0:
|
| 24 |
+
width = (width + self.width_division_factor - 1) // self.width_division_factor * self.width_division_factor
|
| 25 |
+
print(f"The width cannot be evenly divided by {self.width_division_factor}. We round it up to {width}.")
|
| 26 |
+
return height, width
|
| 27 |
+
|
| 28 |
+
|
| 29 |
+
def preprocess_image(self, image):
|
| 30 |
+
image = torch.Tensor(np.array(image, dtype=np.float32) * (2 / 255) - 1).permute(2, 0, 1).unsqueeze(0)
|
| 31 |
+
return image
|
| 32 |
+
|
| 33 |
+
|
| 34 |
+
def preprocess_images(self, images):
|
| 35 |
+
return [self.preprocess_image(image) for image in images]
|
| 36 |
+
|
| 37 |
+
|
| 38 |
+
def vae_output_to_image(self, vae_output):
|
| 39 |
+
image = vae_output[0].cpu().float().permute(1, 2, 0).numpy()
|
| 40 |
+
image = Image.fromarray(((image / 2 + 0.5).clip(0, 1) * 255).astype("uint8"))
|
| 41 |
+
return image
|
| 42 |
+
|
| 43 |
+
|
| 44 |
+
def vae_output_to_video(self, vae_output):
|
| 45 |
+
video = vae_output.cpu().permute(1, 2, 0).numpy()
|
| 46 |
+
video = [Image.fromarray(((image / 2 + 0.5).clip(0, 1) * 255).astype("uint8")) for image in video]
|
| 47 |
+
return video
|
| 48 |
+
|
| 49 |
+
|
| 50 |
+
def merge_latents(self, value, latents, masks, scales, blur_kernel_size=33, blur_sigma=10.0):
|
| 51 |
+
if len(latents) > 0:
|
| 52 |
+
blur = GaussianBlur(kernel_size=blur_kernel_size, sigma=blur_sigma)
|
| 53 |
+
height, width = value.shape[-2:]
|
| 54 |
+
weight = torch.ones_like(value)
|
| 55 |
+
for latent, mask, scale in zip(latents, masks, scales):
|
| 56 |
+
mask = self.preprocess_image(mask.resize((width, height))).mean(dim=1, keepdim=True) > 0
|
| 57 |
+
mask = mask.repeat(1, latent.shape[1], 1, 1).to(dtype=latent.dtype, device=latent.device)
|
| 58 |
+
mask = blur(mask)
|
| 59 |
+
value += latent * mask * scale
|
| 60 |
+
weight += mask * scale
|
| 61 |
+
value /= weight
|
| 62 |
+
return value
|
| 63 |
+
|
| 64 |
+
|
| 65 |
+
def control_noise_via_local_prompts(self, prompt_emb_global, prompt_emb_locals, masks, mask_scales, inference_callback, special_kwargs=None, special_local_kwargs_list=None):
|
| 66 |
+
if special_kwargs is None:
|
| 67 |
+
noise_pred_global = inference_callback(prompt_emb_global)
|
| 68 |
+
else:
|
| 69 |
+
noise_pred_global = inference_callback(prompt_emb_global, special_kwargs)
|
| 70 |
+
if special_local_kwargs_list is None:
|
| 71 |
+
noise_pred_locals = [inference_callback(prompt_emb_local) for prompt_emb_local in prompt_emb_locals]
|
| 72 |
+
else:
|
| 73 |
+
noise_pred_locals = [inference_callback(prompt_emb_local, special_kwargs) for prompt_emb_local, special_kwargs in zip(prompt_emb_locals, special_local_kwargs_list)]
|
| 74 |
+
noise_pred = self.merge_latents(noise_pred_global, noise_pred_locals, masks, mask_scales)
|
| 75 |
+
return noise_pred
|
| 76 |
+
|
| 77 |
+
|
| 78 |
+
def extend_prompt(self, prompt, local_prompts, masks, mask_scales):
|
| 79 |
+
local_prompts = local_prompts or []
|
| 80 |
+
masks = masks or []
|
| 81 |
+
mask_scales = mask_scales or []
|
| 82 |
+
extended_prompt_dict = self.prompter.extend_prompt(prompt)
|
| 83 |
+
prompt = extended_prompt_dict.get("prompt", prompt)
|
| 84 |
+
local_prompts += extended_prompt_dict.get("prompts", [])
|
| 85 |
+
masks += extended_prompt_dict.get("masks", [])
|
| 86 |
+
mask_scales += [100.0] * len(extended_prompt_dict.get("masks", []))
|
| 87 |
+
return prompt, local_prompts, masks, mask_scales
|
| 88 |
+
|
| 89 |
+
|
| 90 |
+
def enable_cpu_offload(self):
|
| 91 |
+
self.cpu_offload = True
|
| 92 |
+
|
| 93 |
+
|
| 94 |
+
def load_models_to_device(self, loadmodel_names=[]):
|
| 95 |
+
# only load models to device if cpu_offload is enabled
|
| 96 |
+
if not self.cpu_offload:
|
| 97 |
+
return
|
| 98 |
+
# offload the unneeded models to cpu
|
| 99 |
+
for model_name in self.model_names:
|
| 100 |
+
if model_name not in loadmodel_names:
|
| 101 |
+
model = getattr(self, model_name)
|
| 102 |
+
if model is not None:
|
| 103 |
+
if hasattr(model, "vram_management_enabled") and model.vram_management_enabled:
|
| 104 |
+
for module in model.modules():
|
| 105 |
+
if hasattr(module, "offload"):
|
| 106 |
+
module.offload()
|
| 107 |
+
else:
|
| 108 |
+
model.cpu()
|
| 109 |
+
# load the needed models to device
|
| 110 |
+
for model_name in loadmodel_names:
|
| 111 |
+
model = getattr(self, model_name)
|
| 112 |
+
if model is not None:
|
| 113 |
+
if hasattr(model, "vram_management_enabled") and model.vram_management_enabled:
|
| 114 |
+
for module in model.modules():
|
| 115 |
+
if hasattr(module, "onload"):
|
| 116 |
+
module.onload()
|
| 117 |
+
else:
|
| 118 |
+
model.to(self.device)
|
| 119 |
+
# fresh the cuda cache
|
| 120 |
+
if torch.cuda.is_available():
|
| 121 |
+
torch.cuda.empty_cache()
|
| 122 |
+
if torch.backends.mps.is_available():
|
| 123 |
+
torch.mps.empty_cache()
|
| 124 |
+
|
| 125 |
+
|
| 126 |
+
def generate_noise(self, shape, seed=None, device="cpu", dtype=torch.float16):
|
| 127 |
+
generator = None if seed is None else torch.Generator(device).manual_seed(seed)
|
| 128 |
+
noise = torch.randn(shape, generator=generator, device=device, dtype=dtype)
|
| 129 |
+
return noise
|
| 130 |
+
|
custom_nodes/ComfyUI-FlashVSR_Ultra_Fast/src/pipelines/flashvsr_full.py
ADDED
|
@@ -0,0 +1,618 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import types
|
| 2 |
+
import os
|
| 3 |
+
import time
|
| 4 |
+
from typing import Optional, Tuple, Literal
|
| 5 |
+
|
| 6 |
+
import torch
|
| 7 |
+
import torch.nn as nn
|
| 8 |
+
import torch.nn.functional as F
|
| 9 |
+
import numpy as np
|
| 10 |
+
from einops import rearrange
|
| 11 |
+
from PIL import Image
|
| 12 |
+
from tqdm import tqdm
|
| 13 |
+
# import pyfiglet
|
| 14 |
+
|
| 15 |
+
from ..models import ModelManager
|
| 16 |
+
from ..models.utils import clean_vram
|
| 17 |
+
from ..models.wan_video_dit import WanModel, RMSNorm, sinusoidal_embedding_1d
|
| 18 |
+
from ..models.wan_video_vae import WanVideoVAE, RMS_norm, CausalConv3d, Upsample
|
| 19 |
+
from ..schedulers.flow_match import FlowMatchScheduler
|
| 20 |
+
from .base import BasePipeline
|
| 21 |
+
|
| 22 |
+
|
| 23 |
+
# -----------------------------
|
| 24 |
+
# 基础工具:ADAIN 所需的统计量(保留以备需要;管线默认用 wavelet)
|
| 25 |
+
# -----------------------------
|
| 26 |
+
def _calc_mean_std(feat: torch.Tensor, eps: float = 1e-5) -> Tuple[torch.Tensor, torch.Tensor]:
|
| 27 |
+
assert feat.dim() == 4, 'feat 必须是 (N, C, H, W)'
|
| 28 |
+
N, C = feat.shape[:2]
|
| 29 |
+
var = feat.view(N, C, -1).var(dim=2, unbiased=False) + eps
|
| 30 |
+
std = var.sqrt().view(N, C, 1, 1)
|
| 31 |
+
mean = feat.view(N, C, -1).mean(dim=2).view(N, C, 1, 1)
|
| 32 |
+
return mean, std
|
| 33 |
+
|
| 34 |
+
|
| 35 |
+
def _adain(content_feat: torch.Tensor, style_feat: torch.Tensor) -> torch.Tensor:
|
| 36 |
+
assert content_feat.shape[:2] == style_feat.shape[:2], "ADAIN: N、C 必须匹配"
|
| 37 |
+
size = content_feat.size()
|
| 38 |
+
style_mean, style_std = _calc_mean_std(style_feat)
|
| 39 |
+
content_mean, content_std = _calc_mean_std(content_feat)
|
| 40 |
+
normalized = (content_feat - content_mean.expand(size)) / content_std.expand(size)
|
| 41 |
+
return normalized * style_std.expand(size) + style_mean.expand(size)
|
| 42 |
+
|
| 43 |
+
|
| 44 |
+
# -----------------------------
|
| 45 |
+
# 小波式模糊与分解/重构(ColorCorrector 用)
|
| 46 |
+
# -----------------------------
|
| 47 |
+
def _make_gaussian3x3_kernel(dtype, device) -> torch.Tensor:
|
| 48 |
+
vals = [
|
| 49 |
+
[0.0625, 0.125, 0.0625],
|
| 50 |
+
[0.125, 0.25, 0.125 ],
|
| 51 |
+
[0.0625, 0.125, 0.0625],
|
| 52 |
+
]
|
| 53 |
+
return torch.tensor(vals, dtype=dtype, device=device)
|
| 54 |
+
|
| 55 |
+
|
| 56 |
+
def _wavelet_blur(x: torch.Tensor, radius: int) -> torch.Tensor:
|
| 57 |
+
assert x.dim() == 4, 'x 必须是 (N, C, H, W)'
|
| 58 |
+
N, C, H, W = x.shape
|
| 59 |
+
base = _make_gaussian3x3_kernel(x.dtype, x.device)
|
| 60 |
+
weight = base.view(1, 1, 3, 3).repeat(C, 1, 1, 1)
|
| 61 |
+
pad = radius
|
| 62 |
+
x_pad = F.pad(x, (pad, pad, pad, pad), mode='replicate')
|
| 63 |
+
out = F.conv2d(x_pad, weight, bias=None, stride=1, padding=0, dilation=radius, groups=C)
|
| 64 |
+
return out
|
| 65 |
+
|
| 66 |
+
|
| 67 |
+
def _wavelet_decompose(x: torch.Tensor, levels: int = 5) -> Tuple[torch.Tensor, torch.Tensor]:
|
| 68 |
+
assert x.dim() == 4, 'x 必须是 (N, C, H, W)'
|
| 69 |
+
high = torch.zeros_like(x)
|
| 70 |
+
low = x
|
| 71 |
+
for i in range(levels):
|
| 72 |
+
radius = 2 ** i
|
| 73 |
+
blurred = _wavelet_blur(low, radius)
|
| 74 |
+
high = high + (low - blurred)
|
| 75 |
+
low = blurred
|
| 76 |
+
return high, low
|
| 77 |
+
|
| 78 |
+
|
| 79 |
+
def _wavelet_reconstruct(content: torch.Tensor, style: torch.Tensor, levels: int = 5) -> torch.Tensor:
|
| 80 |
+
c_high, _ = _wavelet_decompose(content, levels=levels)
|
| 81 |
+
_, s_low = _wavelet_decompose(style, levels=levels)
|
| 82 |
+
return c_high + s_low
|
| 83 |
+
|
| 84 |
+
|
| 85 |
+
# -----------------------------
|
| 86 |
+
# 无状态颜色矫正模块(视频友好,默认 wavelet)
|
| 87 |
+
# -----------------------------
|
| 88 |
+
class TorchColorCorrectorWavelet(nn.Module):
|
| 89 |
+
def __init__(self, levels: int = 5):
|
| 90 |
+
super().__init__()
|
| 91 |
+
self.levels = levels
|
| 92 |
+
|
| 93 |
+
@staticmethod
|
| 94 |
+
def _flatten_time(x: torch.Tensor) -> Tuple[torch.Tensor, int, int]:
|
| 95 |
+
assert x.dim() == 5, '输入必须是 (B, C, f, H, W)'
|
| 96 |
+
B, C, f, H, W = x.shape
|
| 97 |
+
y = x.permute(0, 2, 1, 3, 4).reshape(B * f, C, H, W)
|
| 98 |
+
return y, B, f
|
| 99 |
+
|
| 100 |
+
@staticmethod
|
| 101 |
+
def _unflatten_time(y: torch.Tensor, B: int, f: int) -> torch.Tensor:
|
| 102 |
+
BF, C, H, W = y.shape
|
| 103 |
+
assert BF == B * f
|
| 104 |
+
return y.reshape(B, f, C, H, W).permute(0, 2, 1, 3, 4)
|
| 105 |
+
|
| 106 |
+
def forward(
|
| 107 |
+
self,
|
| 108 |
+
hq_image: torch.Tensor, # (B, C, f, H, W)
|
| 109 |
+
lq_image: torch.Tensor, # (B, C, f, H, W)
|
| 110 |
+
clip_range: Tuple[float, float] = (-1.0, 1.0),
|
| 111 |
+
method: Literal['wavelet', 'adain'] = 'wavelet',
|
| 112 |
+
chunk_size: Optional[int] = None,
|
| 113 |
+
) -> torch.Tensor:
|
| 114 |
+
assert hq_image.shape == lq_image.shape, "HQ 与 LQ 的形状必须一致"
|
| 115 |
+
assert hq_image.dim() == 5 and hq_image.shape[1] == 3, "输入必须是 (B, 3, f, H, W)"
|
| 116 |
+
|
| 117 |
+
B, C, f, H, W = hq_image.shape
|
| 118 |
+
if chunk_size is None or chunk_size >= f:
|
| 119 |
+
hq4, B, f = self._flatten_time(hq_image)
|
| 120 |
+
lq4, _, _ = self._flatten_time(lq_image)
|
| 121 |
+
if method == 'wavelet':
|
| 122 |
+
out4 = _wavelet_reconstruct(hq4, lq4, levels=self.levels)
|
| 123 |
+
elif method == 'adain':
|
| 124 |
+
out4 = _adain(hq4, lq4)
|
| 125 |
+
else:
|
| 126 |
+
raise ValueError(f"未知 method: {method}")
|
| 127 |
+
out4 = torch.clamp(out4, *clip_range)
|
| 128 |
+
out = self._unflatten_time(out4, B, f)
|
| 129 |
+
return out
|
| 130 |
+
|
| 131 |
+
outs = []
|
| 132 |
+
for start in range(0, f, chunk_size):
|
| 133 |
+
end = min(start + chunk_size, f)
|
| 134 |
+
hq_chunk = hq_image[:, :, start:end]
|
| 135 |
+
lq_chunk = lq_image[:, :, start:end]
|
| 136 |
+
hq4, B_, f_ = self._flatten_time(hq_chunk)
|
| 137 |
+
lq4, _, _ = self._flatten_time(lq_chunk)
|
| 138 |
+
if method == 'wavelet':
|
| 139 |
+
out4 = _wavelet_reconstruct(hq4, lq4, levels=self.levels)
|
| 140 |
+
elif method == 'adain':
|
| 141 |
+
out4 = _adain(hq4, lq4)
|
| 142 |
+
else:
|
| 143 |
+
raise ValueError(f"未知 method: {method}")
|
| 144 |
+
out4 = torch.clamp(out4, *clip_range)
|
| 145 |
+
out_chunk = self._unflatten_time(out4, B_, f_)
|
| 146 |
+
outs.append(out_chunk)
|
| 147 |
+
out = torch.cat(outs, dim=2)
|
| 148 |
+
return out
|
| 149 |
+
|
| 150 |
+
|
| 151 |
+
# -----------------------------
|
| 152 |
+
# 简化版 Pipeline(仅 dit + vae)
|
| 153 |
+
# -----------------------------
|
| 154 |
+
class FlashVSRFullPipeline(BasePipeline):
|
| 155 |
+
|
| 156 |
+
def __init__(self, device="cuda", torch_dtype=torch.float16):
|
| 157 |
+
super().__init__(device=device, torch_dtype=torch_dtype)
|
| 158 |
+
self.scheduler = FlowMatchScheduler(shift=5, sigma_min=0.0, extra_one_step=True)
|
| 159 |
+
self.dit: WanModel = None
|
| 160 |
+
self.vae: WanVideoVAE = None
|
| 161 |
+
self.model_names = ['dit', 'vae']
|
| 162 |
+
self.height_division_factor = 16
|
| 163 |
+
self.width_division_factor = 16
|
| 164 |
+
self.use_unified_sequence_parallel = False
|
| 165 |
+
self.prompt_emb_posi = None
|
| 166 |
+
self.ColorCorrector = TorchColorCorrectorWavelet(levels=5)
|
| 167 |
+
|
| 168 |
+
print(r"""
|
| 169 |
+
███████╗██╗ █████╗ ███████╗██╗ ██╗██╗ ██╗███████╗█████╗
|
| 170 |
+
██╔════╝██║ ██╔══██╗██╔════╝██║ ██║██║ ██║██╔════╝██╔══██╗ ██╗
|
| 171 |
+
█████╗ ██║ ███████║███████╗███████║╚██╗ ██╔╝███████╗███████║ ██████╗
|
| 172 |
+
██╔══╝ ██║ ██╔══██║╚════██║██╔══██║ ╚████╔╝ ╚════██║██╔═██║ ██╔═╝
|
| 173 |
+
██║ ███████╗██║ ██║███████║██║ ██║ ╚██╔╝ ███████║██║ ██║ ╚═╝
|
| 174 |
+
╚═╝ ╚══════╝╚═╝ ╚═╝╚══════╝╚═╝ ╚═╝ ╚═╝ ╚══════╝╚═╝ ╚═╝
|
| 175 |
+
""")
|
| 176 |
+
|
| 177 |
+
def enable_vram_management(self, num_persistent_param_in_dit=None):
|
| 178 |
+
# 仅管理 dit / vae
|
| 179 |
+
dtype = next(iter(self.dit.parameters())).dtype
|
| 180 |
+
from ..vram_management import enable_vram_management, AutoWrappedModule, AutoWrappedLinear
|
| 181 |
+
enable_vram_management(
|
| 182 |
+
self.dit,
|
| 183 |
+
module_map={
|
| 184 |
+
torch.nn.Linear: AutoWrappedLinear,
|
| 185 |
+
torch.nn.Conv3d: AutoWrappedModule,
|
| 186 |
+
torch.nn.LayerNorm: AutoWrappedModule,
|
| 187 |
+
RMSNorm: AutoWrappedModule,
|
| 188 |
+
},
|
| 189 |
+
module_config=dict(
|
| 190 |
+
offload_dtype=dtype,
|
| 191 |
+
offload_device="cpu",
|
| 192 |
+
onload_dtype=dtype,
|
| 193 |
+
onload_device=self.device,
|
| 194 |
+
computation_dtype=self.torch_dtype,
|
| 195 |
+
computation_device=self.device,
|
| 196 |
+
),
|
| 197 |
+
max_num_param=num_persistent_param_in_dit,
|
| 198 |
+
overflow_module_config=dict(
|
| 199 |
+
offload_dtype=dtype,
|
| 200 |
+
offload_device="cpu",
|
| 201 |
+
onload_dtype=dtype,
|
| 202 |
+
onload_device="cpu",
|
| 203 |
+
computation_dtype=self.torch_dtype,
|
| 204 |
+
computation_device=self.device,
|
| 205 |
+
),
|
| 206 |
+
)
|
| 207 |
+
dtype = next(iter(self.vae.parameters())).dtype
|
| 208 |
+
enable_vram_management(
|
| 209 |
+
self.vae,
|
| 210 |
+
module_map={
|
| 211 |
+
torch.nn.Linear: AutoWrappedLinear,
|
| 212 |
+
torch.nn.Conv2d: AutoWrappedModule,
|
| 213 |
+
RMS_norm: AutoWrappedModule,
|
| 214 |
+
CausalConv3d: AutoWrappedModule,
|
| 215 |
+
Upsample: AutoWrappedModule,
|
| 216 |
+
torch.nn.SiLU: AutoWrappedModule,
|
| 217 |
+
torch.nn.Dropout: AutoWrappedModule,
|
| 218 |
+
},
|
| 219 |
+
module_config=dict(
|
| 220 |
+
offload_dtype=dtype,
|
| 221 |
+
offload_device="cpu",
|
| 222 |
+
onload_dtype=dtype,
|
| 223 |
+
onload_device=self.device,
|
| 224 |
+
computation_dtype=self.torch_dtype,
|
| 225 |
+
computation_device=self.device,
|
| 226 |
+
),
|
| 227 |
+
)
|
| 228 |
+
self.enable_cpu_offload()
|
| 229 |
+
|
| 230 |
+
def fetch_models(self, model_manager: ModelManager):
|
| 231 |
+
self.dit = model_manager.fetch_model("wan_video_dit")
|
| 232 |
+
self.vae = model_manager.fetch_model("wan_video_vae")
|
| 233 |
+
|
| 234 |
+
@staticmethod
|
| 235 |
+
def from_model_manager(model_manager: ModelManager, torch_dtype=None, device=None, use_usp=False):
|
| 236 |
+
if device is None: device = model_manager.device
|
| 237 |
+
if torch_dtype is None: torch_dtype = model_manager.torch_dtype
|
| 238 |
+
pipe = FlashVSRFullPipeline(device=device, torch_dtype=torch_dtype)
|
| 239 |
+
pipe.fetch_models(model_manager)
|
| 240 |
+
# 可选:统一序列并行入口(此处默认关闭)
|
| 241 |
+
pipe.use_unified_sequence_parallel = False
|
| 242 |
+
return pipe
|
| 243 |
+
|
| 244 |
+
def denoising_model(self):
|
| 245 |
+
return self.dit
|
| 246 |
+
|
| 247 |
+
# -------------------------
|
| 248 |
+
# 新增:显式 KV 预初始化函数
|
| 249 |
+
# -------------------------
|
| 250 |
+
def init_cross_kv(
|
| 251 |
+
self,
|
| 252 |
+
context_tensor: Optional[torch.Tensor] = None,
|
| 253 |
+
prompt_path = None
|
| 254 |
+
):
|
| 255 |
+
self.load_models_to_device(["dit"])
|
| 256 |
+
"""
|
| 257 |
+
使用固定 prompt 生成文本 context,并在 WanModel 中初始化所有 CrossAttention 的 KV 缓存。
|
| 258 |
+
必须在 __call__ 前显式调用一次。
|
| 259 |
+
"""
|
| 260 |
+
#prompt_path = "../../examples/WanVSR/prompt_tensor/posi_prompt.pth"
|
| 261 |
+
if self.dit is None:
|
| 262 |
+
raise RuntimeError("请先通过 fetch_models / from_model_manager 初始化 self.dit")
|
| 263 |
+
|
| 264 |
+
if context_tensor is None:
|
| 265 |
+
if prompt_path is None:
|
| 266 |
+
raise ValueError("init_cross_kv: 需要提供 prompt_path 或 context_tensor 其一")
|
| 267 |
+
ctx = torch.load(prompt_path, map_location=self.device)
|
| 268 |
+
else:
|
| 269 |
+
ctx = context_tensor
|
| 270 |
+
|
| 271 |
+
ctx = ctx.to(dtype=self.torch_dtype, device=self.device)
|
| 272 |
+
|
| 273 |
+
if self.prompt_emb_posi is None:
|
| 274 |
+
self.prompt_emb_posi = {}
|
| 275 |
+
self.prompt_emb_posi['context'] = ctx
|
| 276 |
+
self.prompt_emb_posi['stats'] = "load"
|
| 277 |
+
|
| 278 |
+
if hasattr(self.dit, "reinit_cross_kv"):
|
| 279 |
+
self.dit.reinit_cross_kv(ctx)
|
| 280 |
+
else:
|
| 281 |
+
raise AttributeError("WanModel 缺少 reinit_cross_kv(ctx) 方法,请在模型实现中加入该能力。")
|
| 282 |
+
self.timestep = torch.tensor([1000.], device=self.device, dtype=self.torch_dtype)
|
| 283 |
+
self.t = self.dit.time_embedding(sinusoidal_embedding_1d(self.dit.freq_dim, self.timestep))
|
| 284 |
+
self.t_mod = self.dit.time_projection(self.t).unflatten(1, (6, self.dit.dim))
|
| 285 |
+
# Scheduler
|
| 286 |
+
self.scheduler.set_timesteps(1, denoising_strength=1.0, shift=5.0)
|
| 287 |
+
self.load_models_to_device([])
|
| 288 |
+
|
| 289 |
+
def prepare_unified_sequence_parallel(self):
|
| 290 |
+
return {"use_unified_sequence_parallel": self.use_unified_sequence_parallel}
|
| 291 |
+
|
| 292 |
+
def prepare_extra_input(self, latents=None):
|
| 293 |
+
return {}
|
| 294 |
+
|
| 295 |
+
def encode_video(self, input_video, tiled=True, tile_size=(34, 34), tile_stride=(18, 16)):
|
| 296 |
+
latents = self.vae.encode(input_video, device=self.device, tiled=tiled, tile_size=tile_size, tile_stride=tile_stride)
|
| 297 |
+
return latents
|
| 298 |
+
|
| 299 |
+
def decode_video(self, latents, tiled=True, tile_size=(34, 34), tile_stride=(18, 16)):
|
| 300 |
+
frames = self.vae.decode(latents, device=self.device, tiled=tiled, tile_size=tile_size, tile_stride=tile_stride)
|
| 301 |
+
return frames
|
| 302 |
+
|
| 303 |
+
def offload_model(self, keep_vae=False):
|
| 304 |
+
self.dit.clear_cross_kv()
|
| 305 |
+
self.prompt_emb_posi['stats'] = "offload"
|
| 306 |
+
if hasattr(self.dit, "LQ_proj_in"):
|
| 307 |
+
self.dit.LQ_proj_in.to('cpu')
|
| 308 |
+
if keep_vae:
|
| 309 |
+
self.load_models_to_device(["vae"])
|
| 310 |
+
else:
|
| 311 |
+
self.load_models_to_device([])
|
| 312 |
+
|
| 313 |
+
@torch.no_grad()
|
| 314 |
+
def __call__(
|
| 315 |
+
self,
|
| 316 |
+
prompt=None,
|
| 317 |
+
negative_prompt="",
|
| 318 |
+
denoising_strength=1.0,
|
| 319 |
+
seed=None,
|
| 320 |
+
rand_device="gpu",
|
| 321 |
+
height=480,
|
| 322 |
+
width=832,
|
| 323 |
+
num_frames=81,
|
| 324 |
+
cfg_scale=5.0,
|
| 325 |
+
num_inference_steps=50,
|
| 326 |
+
sigma_shift=5.0,
|
| 327 |
+
tiled=True,
|
| 328 |
+
tile_size=(60, 104),
|
| 329 |
+
tile_stride=(30, 52),
|
| 330 |
+
tea_cache_l1_thresh=None,
|
| 331 |
+
tea_cache_model_id="Wan2.1-T2V-1.3B",
|
| 332 |
+
progress_bar_cmd=tqdm,
|
| 333 |
+
progress_bar_st=None,
|
| 334 |
+
LQ_video=None,
|
| 335 |
+
is_full_block=False,
|
| 336 |
+
if_buffer=False,
|
| 337 |
+
topk_ratio=2.0,
|
| 338 |
+
kv_ratio=3.0,
|
| 339 |
+
local_range = 9,
|
| 340 |
+
color_fix = True,
|
| 341 |
+
unload_dit = False,
|
| 342 |
+
force_offload = False,
|
| 343 |
+
):
|
| 344 |
+
# 只接受 cfg=1.0(与原代码一致)
|
| 345 |
+
assert cfg_scale == 1.0, "cfg_scale must be 1.0"
|
| 346 |
+
|
| 347 |
+
# 要求:必须先 init_cross_kv()
|
| 348 |
+
if self.prompt_emb_posi is None or 'context' not in self.prompt_emb_posi:
|
| 349 |
+
raise RuntimeError(
|
| 350 |
+
"Cross-Attn KV 未初始化。请在调用 __call__ 前先执行:\n"
|
| 351 |
+
" pipe.init_cross_kv()\n"
|
| 352 |
+
"或传入自定义 context:\n"
|
| 353 |
+
" pipe.init_cross_kv(context_tensor=your_context_tensor)"
|
| 354 |
+
)
|
| 355 |
+
|
| 356 |
+
if num_frames % 4 != 1:
|
| 357 |
+
num_frames = (num_frames + 2) // 4 * 4 + 1
|
| 358 |
+
print(f"Only `num_frames % 4 != 1` is acceptable. We round it up to {num_frames}.")
|
| 359 |
+
|
| 360 |
+
# Tiler 参数
|
| 361 |
+
tiler_kwargs = {"tiled": tiled, "tile_size": tile_size, "tile_stride": tile_stride}
|
| 362 |
+
|
| 363 |
+
# 初始化噪声
|
| 364 |
+
if if_buffer:
|
| 365 |
+
noise = self.generate_noise((1, 16, (num_frames - 1) // 4, height//8, width//8), seed=seed, device=self.device, dtype=self.torch_dtype)
|
| 366 |
+
else:
|
| 367 |
+
noise = self.generate_noise((1, 16, (num_frames - 1) // 4 + 1, height//8, width//8), seed=seed, device=self.device, dtype=self.torch_dtype)
|
| 368 |
+
# noise = noise.to(dtype=self.torch_dtype, device=self.device)
|
| 369 |
+
latents = noise
|
| 370 |
+
|
| 371 |
+
process_total_num = (num_frames - 1) // 8 - 2
|
| 372 |
+
is_stream = True
|
| 373 |
+
|
| 374 |
+
if self.prompt_emb_posi['stats'] == "offload":
|
| 375 |
+
self.init_cross_kv(context_tensor=self.prompt_emb_posi['context'])
|
| 376 |
+
self.load_models_to_device(["dit", "vae"])
|
| 377 |
+
self.dit.LQ_proj_in.to(self.device)
|
| 378 |
+
|
| 379 |
+
# 清理可能存在的 LQ_proj_in cache
|
| 380 |
+
if hasattr(self.dit, "LQ_proj_in"):
|
| 381 |
+
self.dit.LQ_proj_in.clear_cache()
|
| 382 |
+
|
| 383 |
+
latents_total = []
|
| 384 |
+
self.vae.clear_cache()
|
| 385 |
+
|
| 386 |
+
with torch.no_grad():
|
| 387 |
+
for cur_process_idx in progress_bar_cmd(range(process_total_num)):
|
| 388 |
+
if cur_process_idx == 0:
|
| 389 |
+
pre_cache_k = [None] * len(self.dit.blocks)
|
| 390 |
+
pre_cache_v = [None] * len(self.dit.blocks)
|
| 391 |
+
LQ_latents = None
|
| 392 |
+
inner_loop_num = 7
|
| 393 |
+
for inner_idx in range(inner_loop_num):
|
| 394 |
+
cur = self.denoising_model().LQ_proj_in.stream_forward(
|
| 395 |
+
LQ_video[:, :, max(0, inner_idx*4-3):(inner_idx+1)*4-3, :, :]
|
| 396 |
+
) if LQ_video is not None else None
|
| 397 |
+
if cur is None:
|
| 398 |
+
continue
|
| 399 |
+
if LQ_latents is None:
|
| 400 |
+
LQ_latents = cur
|
| 401 |
+
else:
|
| 402 |
+
for layer_idx in range(len(LQ_latents)):
|
| 403 |
+
LQ_latents[layer_idx] = torch.cat([LQ_latents[layer_idx], cur[layer_idx]], dim=1)
|
| 404 |
+
cur_latents = latents[:, :, :6, :, :]
|
| 405 |
+
else:
|
| 406 |
+
LQ_latents = None
|
| 407 |
+
inner_loop_num = 2
|
| 408 |
+
for inner_idx in range(inner_loop_num):
|
| 409 |
+
cur = self.denoising_model().LQ_proj_in.stream_forward(
|
| 410 |
+
LQ_video[:, :, cur_process_idx*8+17+inner_idx*4:cur_process_idx*8+21+inner_idx*4, :, :]
|
| 411 |
+
) if LQ_video is not None else None
|
| 412 |
+
if cur is None:
|
| 413 |
+
continue
|
| 414 |
+
if LQ_latents is None:
|
| 415 |
+
LQ_latents = cur
|
| 416 |
+
else:
|
| 417 |
+
for layer_idx in range(len(LQ_latents)):
|
| 418 |
+
LQ_latents[layer_idx] = torch.cat([LQ_latents[layer_idx], cur[layer_idx]], dim=1)
|
| 419 |
+
cur_latents = latents[:, :, 4+cur_process_idx*2:6+cur_process_idx*2, :, :]
|
| 420 |
+
|
| 421 |
+
# 推理(无 motion_controller / vace)
|
| 422 |
+
noise_pred_posi, pre_cache_k, pre_cache_v = model_fn_wan_video(
|
| 423 |
+
self.dit,
|
| 424 |
+
x=cur_latents,
|
| 425 |
+
timestep=self.timestep,
|
| 426 |
+
context=None,
|
| 427 |
+
tea_cache=None,
|
| 428 |
+
use_unified_sequence_parallel=False,
|
| 429 |
+
LQ_latents=LQ_latents,
|
| 430 |
+
is_full_block=is_full_block,
|
| 431 |
+
is_stream=is_stream,
|
| 432 |
+
pre_cache_k=pre_cache_k,
|
| 433 |
+
pre_cache_v=pre_cache_v,
|
| 434 |
+
topk_ratio=topk_ratio,
|
| 435 |
+
kv_ratio=kv_ratio,
|
| 436 |
+
cur_process_idx=cur_process_idx,
|
| 437 |
+
t_mod=self.t_mod,
|
| 438 |
+
t=self.t,
|
| 439 |
+
local_range = local_range,
|
| 440 |
+
)
|
| 441 |
+
|
| 442 |
+
# 更新 latent
|
| 443 |
+
cur_latents = cur_latents - noise_pred_posi
|
| 444 |
+
latents_total.append(cur_latents)
|
| 445 |
+
|
| 446 |
+
if hasattr(self.dit, "LQ_proj_in"):
|
| 447 |
+
self.dit.LQ_proj_in.clear_cache()
|
| 448 |
+
|
| 449 |
+
if unload_dit and hasattr(self, 'dit') and not next(self.dit.parameters()).is_cpu:
|
| 450 |
+
print("[FlashVSR] Offloading DiT to the CPU to free up VRAM...")
|
| 451 |
+
self.offload_model(keep_vae=True)
|
| 452 |
+
|
| 453 |
+
latents = torch.cat(latents_total, dim=2)
|
| 454 |
+
|
| 455 |
+
# Decode
|
| 456 |
+
print("[FlashVSR] Starting VAE decoding...")
|
| 457 |
+
frames = self.decode_video(latents, **tiler_kwargs)
|
| 458 |
+
|
| 459 |
+
self.vae.clear_cache()
|
| 460 |
+
if force_offload:
|
| 461 |
+
self.offload_model()
|
| 462 |
+
|
| 463 |
+
# 颜色校正(wavelet)
|
| 464 |
+
try:
|
| 465 |
+
if color_fix:
|
| 466 |
+
frames = self.ColorCorrector(
|
| 467 |
+
frames.to(device=LQ_video.device),
|
| 468 |
+
LQ_video[:, :, :frames.shape[2], :, :],
|
| 469 |
+
clip_range=(-1, 1),
|
| 470 |
+
chunk_size=16,
|
| 471 |
+
method='adain'
|
| 472 |
+
)
|
| 473 |
+
except:
|
| 474 |
+
pass
|
| 475 |
+
|
| 476 |
+
return frames[0]
|
| 477 |
+
|
| 478 |
+
|
| 479 |
+
# -----------------------------
|
| 480 |
+
# TeaCache(保留原逻���;此处默认不启用)
|
| 481 |
+
# -----------------------------
|
| 482 |
+
class TeaCache:
|
| 483 |
+
def __init__(self, num_inference_steps, rel_l1_thresh, model_id):
|
| 484 |
+
self.num_inference_steps = num_inference_steps
|
| 485 |
+
self.step = 0
|
| 486 |
+
self.accumulated_rel_l1_distance = 0
|
| 487 |
+
self.previous_modulated_input = None
|
| 488 |
+
self.rel_l1_thresh = rel_l1_thresh
|
| 489 |
+
self.previous_residual = None
|
| 490 |
+
self.previous_hidden_states = None
|
| 491 |
+
|
| 492 |
+
self.coefficients_dict = {
|
| 493 |
+
"Wan2.1-T2V-1.3B": [-5.21862437e+04, 9.23041404e+03, -5.28275948e+02, 1.36987616e+01, -4.99875664e-02],
|
| 494 |
+
"Wan2.1-T2V-14B": [-3.03318725e+05, 4.90537029e+04, -2.65530556e+03, 5.87365115e+01, -3.15583525e-01],
|
| 495 |
+
"Wan2.1-I2V-14B-480P": [2.57151496e+05, -3.54229917e+04, 1.40286849e+03, -1.35890334e+01, 1.32517977e-01],
|
| 496 |
+
"Wan2.1-I2V-14B-720P": [8.10705460e+03, 2.13393892e+03, -3.72934672e+02, 1.66203073e+01, -4.17769401e-02],
|
| 497 |
+
}
|
| 498 |
+
if model_id not in self.coefficients_dict:
|
| 499 |
+
supported_model_ids = ", ".join([i for i in self.coefficients_dict])
|
| 500 |
+
raise ValueError(f"{model_id} is not a supported TeaCache model id. Please choose a valid model id in ({supported_model_ids}).")
|
| 501 |
+
self.coefficients = self.coefficients_dict[model_id]
|
| 502 |
+
|
| 503 |
+
def check(self, dit: WanModel, x, t_mod):
|
| 504 |
+
modulated_inp = t_mod.clone()
|
| 505 |
+
if self.step == 0 or self.step == self.num_inference_steps - 1:
|
| 506 |
+
should_calc = True
|
| 507 |
+
self.accumulated_rel_l1_distance = 0
|
| 508 |
+
else:
|
| 509 |
+
coefficients = self.coefficients
|
| 510 |
+
rescale_func = np.poly1d(coefficients)
|
| 511 |
+
self.accumulated_rel_l1_distance += rescale_func(((modulated_inp-self.previous_modulated_input).abs().mean() / self.previous_modulated_input.abs().mean()).cpu().item())
|
| 512 |
+
should_calc = not (self.accumulated_rel_l1_distance < self.rel_l1_thresh)
|
| 513 |
+
if should_calc:
|
| 514 |
+
self.accumulated_rel_l1_distance = 0
|
| 515 |
+
self.previous_modulated_input = modulated_inp
|
| 516 |
+
self.step = (self.step + 1) % self.num_inference_steps
|
| 517 |
+
if should_calc:
|
| 518 |
+
self.previous_hidden_states = x.clone()
|
| 519 |
+
return not should_calc
|
| 520 |
+
|
| 521 |
+
def store(self, hidden_states):
|
| 522 |
+
self.previous_residual = hidden_states - self.previous_hidden_states
|
| 523 |
+
self.previous_hidden_states = None
|
| 524 |
+
|
| 525 |
+
def update(self, hidden_states):
|
| 526 |
+
hidden_states = hidden_states + self.previous_residual
|
| 527 |
+
return hidden_states
|
| 528 |
+
|
| 529 |
+
|
| 530 |
+
# -----------------------------
|
| 531 |
+
# 简化版模型前向封装(无 vace / 无 motion_controller)
|
| 532 |
+
# -----------------------------
|
| 533 |
+
def model_fn_wan_video(
|
| 534 |
+
dit: WanModel,
|
| 535 |
+
x: torch.Tensor,
|
| 536 |
+
timestep: torch.Tensor,
|
| 537 |
+
context: torch.Tensor,
|
| 538 |
+
tea_cache: Optional[TeaCache] = None,
|
| 539 |
+
use_unified_sequence_parallel: bool = False,
|
| 540 |
+
LQ_latents: Optional[torch.Tensor] = None,
|
| 541 |
+
is_full_block: bool = False,
|
| 542 |
+
is_stream: bool = False,
|
| 543 |
+
pre_cache_k: Optional[list[torch.Tensor]] = None,
|
| 544 |
+
pre_cache_v: Optional[list[torch.Tensor]] = None,
|
| 545 |
+
topk_ratio: float = 2.0,
|
| 546 |
+
kv_ratio: float = 3.0,
|
| 547 |
+
cur_process_idx: int = 0,
|
| 548 |
+
t_mod : torch.Tensor = None,
|
| 549 |
+
t : torch.Tensor = None,
|
| 550 |
+
local_range: int = 9,
|
| 551 |
+
**kwargs,
|
| 552 |
+
):
|
| 553 |
+
# patchify
|
| 554 |
+
x, (f, h, w) = dit.patchify(x)
|
| 555 |
+
|
| 556 |
+
win = (2, 8, 8)
|
| 557 |
+
seqlen = f // win[0]
|
| 558 |
+
local_num = seqlen
|
| 559 |
+
window_size = win[0] * h * w // 128
|
| 560 |
+
square_num = window_size * window_size
|
| 561 |
+
topk = int(square_num * topk_ratio) - 1
|
| 562 |
+
kv_len = int(kv_ratio)
|
| 563 |
+
|
| 564 |
+
# RoPE 位置(分段)
|
| 565 |
+
if cur_process_idx == 0:
|
| 566 |
+
freqs = torch.cat([
|
| 567 |
+
dit.freqs[0][:f].view(f, 1, 1, -1).expand(f, h, w, -1),
|
| 568 |
+
dit.freqs[1][:h].view(1, h, 1, -1).expand(f, h, w, -1),
|
| 569 |
+
dit.freqs[2][:w].view(1, 1, w, -1).expand(f, h, w, -1)
|
| 570 |
+
], dim=-1).reshape(f * h * w, 1, -1).to(x.device)
|
| 571 |
+
else:
|
| 572 |
+
freqs = torch.cat([
|
| 573 |
+
dit.freqs[0][4 + cur_process_idx*2:4 + cur_process_idx*2 + f].view(f, 1, 1, -1).expand(f, h, w, -1),
|
| 574 |
+
dit.freqs[1][:h].view(1, h, 1, -1).expand(f, h, w, -1),
|
| 575 |
+
dit.freqs[2][:w].view(1, 1, w, -1).expand(f, h, w, -1)
|
| 576 |
+
], dim=-1).reshape(f * h * w, 1, -1).to(x.device)
|
| 577 |
+
|
| 578 |
+
# TeaCache(默认不启用)
|
| 579 |
+
tea_cache_update = tea_cache.check(dit, x, t_mod) if tea_cache is not None else False
|
| 580 |
+
|
| 581 |
+
# 统一序列并行(此处默认关闭)
|
| 582 |
+
if use_unified_sequence_parallel:
|
| 583 |
+
import torch.distributed as dist
|
| 584 |
+
from xfuser.core.distributed import (get_sequence_parallel_rank,
|
| 585 |
+
get_sequence_parallel_world_size,
|
| 586 |
+
get_sp_group)
|
| 587 |
+
if dist.is_initialized() and dist.get_world_size() > 1:
|
| 588 |
+
x = torch.chunk(x, get_sequence_parallel_world_size(), dim=1)[get_sequence_parallel_rank()]
|
| 589 |
+
|
| 590 |
+
# Block 堆叠
|
| 591 |
+
if tea_cache_update:
|
| 592 |
+
x = tea_cache.update(x)
|
| 593 |
+
else:
|
| 594 |
+
for block_id, block in enumerate(dit.blocks):
|
| 595 |
+
if LQ_latents is not None and block_id < len(LQ_latents):
|
| 596 |
+
x = x + LQ_latents[block_id]
|
| 597 |
+
x, last_pre_cache_k, last_pre_cache_v = block(
|
| 598 |
+
x, context, t_mod, freqs, f, h, w,
|
| 599 |
+
local_num, topk,
|
| 600 |
+
block_id=block_id,
|
| 601 |
+
kv_len=kv_len,
|
| 602 |
+
is_full_block=is_full_block,
|
| 603 |
+
is_stream=is_stream,
|
| 604 |
+
pre_cache_k=pre_cache_k[block_id] if pre_cache_k is not None else None,
|
| 605 |
+
pre_cache_v=pre_cache_v[block_id] if pre_cache_v is not None else None,
|
| 606 |
+
local_range = local_range,
|
| 607 |
+
)
|
| 608 |
+
if pre_cache_k is not None: pre_cache_k[block_id] = last_pre_cache_k
|
| 609 |
+
if pre_cache_v is not None: pre_cache_v[block_id] = last_pre_cache_v
|
| 610 |
+
|
| 611 |
+
x = dit.head(x, t)
|
| 612 |
+
if use_unified_sequence_parallel:
|
| 613 |
+
import torch.distributed as dist
|
| 614 |
+
from xfuser.core.distributed import get_sp_group
|
| 615 |
+
if dist.is_initialized() and dist.get_world_size() > 1:
|
| 616 |
+
x = get_sp_group().all_gather(x, dim=1)
|
| 617 |
+
x = dit.unpatchify(x, (f, h, w))
|
| 618 |
+
return x, pre_cache_k, pre_cache_v
|
custom_nodes/ComfyUI-FlashVSR_Ultra_Fast/src/pipelines/flashvsr_tiny.py
ADDED
|
@@ -0,0 +1,615 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import types
|
| 2 |
+
import os
|
| 3 |
+
import time
|
| 4 |
+
from typing import Optional, Tuple, Literal
|
| 5 |
+
|
| 6 |
+
import torch
|
| 7 |
+
import torch.nn as nn
|
| 8 |
+
import torch.nn.functional as F
|
| 9 |
+
import numpy as np
|
| 10 |
+
from einops import rearrange
|
| 11 |
+
from PIL import Image
|
| 12 |
+
from tqdm import tqdm
|
| 13 |
+
# import pyfiglet
|
| 14 |
+
|
| 15 |
+
from ..models import ModelManager
|
| 16 |
+
from ..models.utils import clean_vram
|
| 17 |
+
from ..models.wan_video_dit import WanModel, RMSNorm, sinusoidal_embedding_1d
|
| 18 |
+
from ..models.wan_video_vae import WanVideoVAE, RMS_norm, CausalConv3d, Upsample
|
| 19 |
+
from ..schedulers.flow_match import FlowMatchScheduler
|
| 20 |
+
from .base import BasePipeline
|
| 21 |
+
|
| 22 |
+
|
| 23 |
+
# -----------------------------
|
| 24 |
+
# 基础工具:ADAIN 所需的统计量(保留以备需要;管线默认用 wavelet)
|
| 25 |
+
# -----------------------------
|
| 26 |
+
def _calc_mean_std(feat: torch.Tensor, eps: float = 1e-5) -> Tuple[torch.Tensor, torch.Tensor]:
|
| 27 |
+
assert feat.dim() == 4, 'feat 必须是 (N, C, H, W)'
|
| 28 |
+
N, C = feat.shape[:2]
|
| 29 |
+
var = feat.view(N, C, -1).var(dim=2, unbiased=False) + eps
|
| 30 |
+
std = var.sqrt().view(N, C, 1, 1)
|
| 31 |
+
mean = feat.view(N, C, -1).mean(dim=2).view(N, C, 1, 1)
|
| 32 |
+
return mean, std
|
| 33 |
+
|
| 34 |
+
|
| 35 |
+
def _adain(content_feat: torch.Tensor, style_feat: torch.Tensor) -> torch.Tensor:
|
| 36 |
+
assert content_feat.shape[:2] == style_feat.shape[:2], "ADAIN: N、C 必须匹配"
|
| 37 |
+
size = content_feat.size()
|
| 38 |
+
style_mean, style_std = _calc_mean_std(style_feat)
|
| 39 |
+
content_mean, content_std = _calc_mean_std(content_feat)
|
| 40 |
+
normalized = (content_feat - content_mean.expand(size)) / content_std.expand(size)
|
| 41 |
+
return normalized * style_std.expand(size) + style_mean.expand(size)
|
| 42 |
+
|
| 43 |
+
|
| 44 |
+
# -----------------------------
|
| 45 |
+
# 小波式模糊与分解/重构(ColorCorrector 用)
|
| 46 |
+
# -----------------------------
|
| 47 |
+
def _make_gaussian3x3_kernel(dtype, device) -> torch.Tensor:
|
| 48 |
+
vals = [
|
| 49 |
+
[0.0625, 0.125, 0.0625],
|
| 50 |
+
[0.125, 0.25, 0.125 ],
|
| 51 |
+
[0.0625, 0.125, 0.0625],
|
| 52 |
+
]
|
| 53 |
+
return torch.tensor(vals, dtype=dtype, device=device)
|
| 54 |
+
|
| 55 |
+
|
| 56 |
+
def _wavelet_blur(x: torch.Tensor, radius: int) -> torch.Tensor:
|
| 57 |
+
assert x.dim() == 4, 'x 必须是 (N, C, H, W)'
|
| 58 |
+
N, C, H, W = x.shape
|
| 59 |
+
base = _make_gaussian3x3_kernel(x.dtype, x.device)
|
| 60 |
+
weight = base.view(1, 1, 3, 3).repeat(C, 1, 1, 1)
|
| 61 |
+
pad = radius
|
| 62 |
+
x_pad = F.pad(x, (pad, pad, pad, pad), mode='replicate')
|
| 63 |
+
out = F.conv2d(x_pad, weight, bias=None, stride=1, padding=0, dilation=radius, groups=C)
|
| 64 |
+
return out
|
| 65 |
+
|
| 66 |
+
|
| 67 |
+
def _wavelet_decompose(x: torch.Tensor, levels: int = 5) -> Tuple[torch.Tensor, torch.Tensor]:
|
| 68 |
+
assert x.dim() == 4, 'x 必须是 (N, C, H, W)'
|
| 69 |
+
high = torch.zeros_like(x)
|
| 70 |
+
low = x
|
| 71 |
+
for i in range(levels):
|
| 72 |
+
radius = 2 ** i
|
| 73 |
+
blurred = _wavelet_blur(low, radius)
|
| 74 |
+
high = high + (low - blurred)
|
| 75 |
+
low = blurred
|
| 76 |
+
return high, low
|
| 77 |
+
|
| 78 |
+
|
| 79 |
+
def _wavelet_reconstruct(content: torch.Tensor, style: torch.Tensor, levels: int = 5) -> torch.Tensor:
|
| 80 |
+
c_high, _ = _wavelet_decompose(content, levels=levels)
|
| 81 |
+
_, s_low = _wavelet_decompose(style, levels=levels)
|
| 82 |
+
return c_high + s_low
|
| 83 |
+
|
| 84 |
+
|
| 85 |
+
# -----------------------------
|
| 86 |
+
# 无状态颜色矫正模块(视频友好,默认 wavelet)
|
| 87 |
+
# -----------------------------
|
| 88 |
+
class TorchColorCorrectorWavelet(nn.Module):
|
| 89 |
+
def __init__(self, levels: int = 5):
|
| 90 |
+
super().__init__()
|
| 91 |
+
self.levels = levels
|
| 92 |
+
|
| 93 |
+
@staticmethod
|
| 94 |
+
def _flatten_time(x: torch.Tensor) -> Tuple[torch.Tensor, int, int]:
|
| 95 |
+
assert x.dim() == 5, '输入必须是 (B, C, f, H, W)'
|
| 96 |
+
B, C, f, H, W = x.shape
|
| 97 |
+
y = x.permute(0, 2, 1, 3, 4).reshape(B * f, C, H, W)
|
| 98 |
+
return y, B, f
|
| 99 |
+
|
| 100 |
+
@staticmethod
|
| 101 |
+
def _unflatten_time(y: torch.Tensor, B: int, f: int) -> torch.Tensor:
|
| 102 |
+
BF, C, H, W = y.shape
|
| 103 |
+
assert BF == B * f
|
| 104 |
+
return y.reshape(B, f, C, H, W).permute(0, 2, 1, 3, 4)
|
| 105 |
+
|
| 106 |
+
def forward(
|
| 107 |
+
self,
|
| 108 |
+
hq_image: torch.Tensor, # (B, C, f, H, W)
|
| 109 |
+
lq_image: torch.Tensor, # (B, C, f, H, W)
|
| 110 |
+
clip_range: Tuple[float, float] = (-1.0, 1.0),
|
| 111 |
+
method: Literal['wavelet', 'adain'] = 'wavelet',
|
| 112 |
+
chunk_size: Optional[int] = None,
|
| 113 |
+
) -> torch.Tensor:
|
| 114 |
+
assert hq_image.shape == lq_image.shape, "HQ 与 LQ 的形状必须一致"
|
| 115 |
+
assert hq_image.dim() == 5 and hq_image.shape[1] == 3, "输入必须是 (B, 3, f, H, W)"
|
| 116 |
+
|
| 117 |
+
B, C, f, H, W = hq_image.shape
|
| 118 |
+
if chunk_size is None or chunk_size >= f:
|
| 119 |
+
hq4, B, f = self._flatten_time(hq_image)
|
| 120 |
+
lq4, _, _ = self._flatten_time(lq_image)
|
| 121 |
+
if method == 'wavelet':
|
| 122 |
+
out4 = _wavelet_reconstruct(hq4, lq4, levels=self.levels)
|
| 123 |
+
elif method == 'adain':
|
| 124 |
+
out4 = _adain(hq4, lq4)
|
| 125 |
+
else:
|
| 126 |
+
raise ValueError(f"未知 method: {method}")
|
| 127 |
+
out4 = torch.clamp(out4, *clip_range)
|
| 128 |
+
out = self._unflatten_time(out4, B, f)
|
| 129 |
+
return out
|
| 130 |
+
|
| 131 |
+
outs = []
|
| 132 |
+
for start in range(0, f, chunk_size):
|
| 133 |
+
end = min(start + chunk_size, f)
|
| 134 |
+
hq_chunk = hq_image[:, :, start:end]
|
| 135 |
+
lq_chunk = lq_image[:, :, start:end]
|
| 136 |
+
hq4, B_, f_ = self._flatten_time(hq_chunk)
|
| 137 |
+
lq4, _, _ = self._flatten_time(lq_chunk)
|
| 138 |
+
if method == 'wavelet':
|
| 139 |
+
out4 = _wavelet_reconstruct(hq4, lq4, levels=self.levels)
|
| 140 |
+
elif method == 'adain':
|
| 141 |
+
out4 = _adain(hq4, lq4)
|
| 142 |
+
else:
|
| 143 |
+
raise ValueError(f"未知 method: {method}")
|
| 144 |
+
out4 = torch.clamp(out4, *clip_range)
|
| 145 |
+
out_chunk = self._unflatten_time(out4, B_, f_)
|
| 146 |
+
outs.append(out_chunk)
|
| 147 |
+
out = torch.cat(outs, dim=2)
|
| 148 |
+
return out
|
| 149 |
+
|
| 150 |
+
|
| 151 |
+
# -----------------------------
|
| 152 |
+
# 简化版 Pipeline(仅 dit + vae)
|
| 153 |
+
# -----------------------------
|
| 154 |
+
class FlashVSRTinyPipeline(BasePipeline):
|
| 155 |
+
|
| 156 |
+
def __init__(self, device="cuda", torch_dtype=torch.float16):
|
| 157 |
+
super().__init__(device=device, torch_dtype=torch_dtype)
|
| 158 |
+
self.scheduler = FlowMatchScheduler(shift=5, sigma_min=0.0, extra_one_step=True)
|
| 159 |
+
self.dit: WanModel = None
|
| 160 |
+
self.vae: WanVideoVAE = None
|
| 161 |
+
self.model_names = ['dit', 'vae']
|
| 162 |
+
self.height_division_factor = 16
|
| 163 |
+
self.width_division_factor = 16
|
| 164 |
+
self.use_unified_sequence_parallel = False
|
| 165 |
+
self.prompt_emb_posi = None
|
| 166 |
+
self.ColorCorrector = TorchColorCorrectorWavelet(levels=5)
|
| 167 |
+
|
| 168 |
+
print(r"""
|
| 169 |
+
███████╗██╗ █████╗ ███████╗██╗ ██╗██╗ ██╗███████╗█████╗
|
| 170 |
+
██╔════╝██║ ██╔══██╗██╔════╝██║ ██║██║ ██║██╔════╝██╔══██╗ ██╗
|
| 171 |
+
█████╗ ██║ ███████║███████╗███████║╚██╗ ██╔╝███████╗███████║ ██████╗
|
| 172 |
+
██╔══╝ ██║ ██╔══██║╚════██║██╔══██║ ╚████╔╝ ╚════██║██╔═██║ ██╔═╝
|
| 173 |
+
██║ ███████╗██║ ██║███████║██║ ██║ ╚██╔╝ ███████║██║ ██║ ╚═╝
|
| 174 |
+
╚═╝ ╚══════╝╚═╝ ╚═╝╚══════╝╚═╝ ╚═╝ ╚═╝ ╚══════╝╚═╝ ╚═╝
|
| 175 |
+
""")
|
| 176 |
+
|
| 177 |
+
def enable_vram_management(self, num_persistent_param_in_dit=None):
|
| 178 |
+
# 仅管理 dit / vae
|
| 179 |
+
dtype = next(iter(self.dit.parameters())).dtype
|
| 180 |
+
from ..vram_management import enable_vram_management, AutoWrappedModule, AutoWrappedLinear
|
| 181 |
+
enable_vram_management(
|
| 182 |
+
self.dit,
|
| 183 |
+
module_map={
|
| 184 |
+
torch.nn.Linear: AutoWrappedLinear,
|
| 185 |
+
torch.nn.Conv3d: AutoWrappedModule,
|
| 186 |
+
torch.nn.LayerNorm: AutoWrappedModule,
|
| 187 |
+
RMSNorm: AutoWrappedModule,
|
| 188 |
+
},
|
| 189 |
+
module_config=dict(
|
| 190 |
+
offload_dtype=dtype,
|
| 191 |
+
offload_device="cpu",
|
| 192 |
+
onload_dtype=dtype,
|
| 193 |
+
onload_device=self.device,
|
| 194 |
+
computation_dtype=self.torch_dtype,
|
| 195 |
+
computation_device=self.device,
|
| 196 |
+
),
|
| 197 |
+
max_num_param=num_persistent_param_in_dit,
|
| 198 |
+
overflow_module_config=dict(
|
| 199 |
+
offload_dtype=dtype,
|
| 200 |
+
offload_device="cpu",
|
| 201 |
+
onload_dtype=dtype,
|
| 202 |
+
onload_device="cpu",
|
| 203 |
+
computation_dtype=self.torch_dtype,
|
| 204 |
+
computation_device=self.device,
|
| 205 |
+
),
|
| 206 |
+
)
|
| 207 |
+
self.enable_cpu_offload()
|
| 208 |
+
|
| 209 |
+
def fetch_models(self, model_manager: ModelManager):
|
| 210 |
+
self.dit = model_manager.fetch_model("wan_video_dit")
|
| 211 |
+
self.vae = model_manager.fetch_model("wan_video_vae")
|
| 212 |
+
|
| 213 |
+
@staticmethod
|
| 214 |
+
def from_model_manager(model_manager: ModelManager, torch_dtype=None, device=None, use_usp=False):
|
| 215 |
+
if device is None: device = model_manager.device
|
| 216 |
+
if torch_dtype is None: torch_dtype = model_manager.torch_dtype
|
| 217 |
+
pipe = FlashVSRTinyPipeline(device=device, torch_dtype=torch_dtype)
|
| 218 |
+
pipe.fetch_models(model_manager)
|
| 219 |
+
# 可选:统一序列并行入口(此处默认关闭)
|
| 220 |
+
pipe.use_unified_sequence_parallel = False
|
| 221 |
+
return pipe
|
| 222 |
+
|
| 223 |
+
def denoising_model(self):
|
| 224 |
+
return self.dit
|
| 225 |
+
|
| 226 |
+
# -------------------------
|
| 227 |
+
# 新增:显式 KV 预初始化函数
|
| 228 |
+
# -------------------------
|
| 229 |
+
def init_cross_kv(
|
| 230 |
+
self,
|
| 231 |
+
context_tensor: Optional[torch.Tensor] = None,
|
| 232 |
+
prompt_path = None,
|
| 233 |
+
):
|
| 234 |
+
self.load_models_to_device(["dit"])
|
| 235 |
+
"""
|
| 236 |
+
使用固定 prompt 生成文本 context,并在 WanModel 中初始化所有 CrossAttention 的 KV 缓存。
|
| 237 |
+
必须在 __call__ 前显式调用一次。
|
| 238 |
+
"""
|
| 239 |
+
#prompt_path = "../../examples/WanVSR/prompt_tensor/posi_prompt.pth"
|
| 240 |
+
|
| 241 |
+
if self.dit is None:
|
| 242 |
+
raise RuntimeError("请先通过 fetch_models / from_model_manager 初始化 self.dit")
|
| 243 |
+
|
| 244 |
+
if context_tensor is None:
|
| 245 |
+
if prompt_path is None:
|
| 246 |
+
raise ValueError("init_cross_kv: 需要提供 prompt_path 或 context_tensor 其一")
|
| 247 |
+
ctx = torch.load(prompt_path, map_location=self.device)
|
| 248 |
+
else:
|
| 249 |
+
ctx = context_tensor
|
| 250 |
+
|
| 251 |
+
ctx = ctx.to(dtype=self.torch_dtype, device=self.device)
|
| 252 |
+
|
| 253 |
+
if self.prompt_emb_posi is None:
|
| 254 |
+
self.prompt_emb_posi = {}
|
| 255 |
+
self.prompt_emb_posi['context'] = ctx
|
| 256 |
+
self.prompt_emb_posi['stats'] = "load"
|
| 257 |
+
|
| 258 |
+
if hasattr(self.dit, "reinit_cross_kv"):
|
| 259 |
+
self.dit.reinit_cross_kv(ctx)
|
| 260 |
+
else:
|
| 261 |
+
raise AttributeError("WanModel 缺少 reinit_cross_kv(ctx) 方法,请在模型实现中加入该能力。")
|
| 262 |
+
self.timestep = torch.tensor([1000.], device=self.device, dtype=self.torch_dtype)
|
| 263 |
+
self.t = self.dit.time_embedding(sinusoidal_embedding_1d(self.dit.freq_dim, self.timestep))
|
| 264 |
+
self.t_mod = self.dit.time_projection(self.t).unflatten(1, (6, self.dit.dim))
|
| 265 |
+
# Scheduler
|
| 266 |
+
self.scheduler.set_timesteps(1, denoising_strength=1.0, shift=5.0)
|
| 267 |
+
self.load_models_to_device([])
|
| 268 |
+
|
| 269 |
+
def prepare_unified_sequence_parallel(self):
|
| 270 |
+
return {"use_unified_sequence_parallel": self.use_unified_sequence_parallel}
|
| 271 |
+
|
| 272 |
+
def prepare_extra_input(self, latents=None):
|
| 273 |
+
return {}
|
| 274 |
+
|
| 275 |
+
def encode_video(self, input_video, tiled=True, tile_size=(34, 34), tile_stride=(18, 16)):
|
| 276 |
+
latents = self.vae.encode(input_video, device=self.device, tiled=tiled, tile_size=tile_size, tile_stride=tile_stride)
|
| 277 |
+
return latents
|
| 278 |
+
|
| 279 |
+
def _decode_video(self, latents, tiled=True, tile_size=(34, 34), tile_stride=(18, 16)):
|
| 280 |
+
frames = self.vae.decode(latents, device=self.device, tiled=tiled, tile_size=tile_size, tile_stride=tile_stride)
|
| 281 |
+
return frames
|
| 282 |
+
|
| 283 |
+
def decode_video(self, latents, cond=None, **kwargs):
|
| 284 |
+
frames = self.TCDecoder.decode_video(
|
| 285 |
+
latents.transpose(1, 2), # TCDecoder 需要 (B, F, C, H, W)
|
| 286 |
+
parallel=False,
|
| 287 |
+
show_progress_bar=False,
|
| 288 |
+
cond=cond
|
| 289 |
+
).transpose(1, 2).mul_(2).sub_(1) # 转回 (B, C, F, H, W) 格式,范围 -1 to 1
|
| 290 |
+
|
| 291 |
+
return frames
|
| 292 |
+
|
| 293 |
+
def offload_model(self, keep_vae=False):
|
| 294 |
+
self.dit.clear_cross_kv()
|
| 295 |
+
self.prompt_emb_posi['stats'] = "offload"
|
| 296 |
+
self.load_models_to_device([])
|
| 297 |
+
if hasattr(self.dit, "LQ_proj_in"):
|
| 298 |
+
self.dit.LQ_proj_in.to('cpu')
|
| 299 |
+
if not keep_vae:
|
| 300 |
+
self.TCDecoder.to('cpu')
|
| 301 |
+
|
| 302 |
+
@torch.no_grad()
|
| 303 |
+
def __call__(
|
| 304 |
+
self,
|
| 305 |
+
prompt=None,
|
| 306 |
+
negative_prompt="",
|
| 307 |
+
denoising_strength=1.0,
|
| 308 |
+
seed=None,
|
| 309 |
+
rand_device="gpu",
|
| 310 |
+
height=480,
|
| 311 |
+
width=832,
|
| 312 |
+
num_frames=81,
|
| 313 |
+
cfg_scale=5.0,
|
| 314 |
+
num_inference_steps=50,
|
| 315 |
+
sigma_shift=5.0,
|
| 316 |
+
tiled=True,
|
| 317 |
+
tile_size=(60, 104),
|
| 318 |
+
tile_stride=(30, 52),
|
| 319 |
+
tea_cache_l1_thresh=None,
|
| 320 |
+
tea_cache_model_id="Wan2.1-T2V-1.3B",
|
| 321 |
+
progress_bar_cmd=tqdm,
|
| 322 |
+
progress_bar_st=None,
|
| 323 |
+
LQ_video=None,
|
| 324 |
+
is_full_block=False,
|
| 325 |
+
if_buffer=False,
|
| 326 |
+
topk_ratio=2.0,
|
| 327 |
+
kv_ratio=3.0,
|
| 328 |
+
local_range = 9,
|
| 329 |
+
color_fix = True,
|
| 330 |
+
unload_dit = False,
|
| 331 |
+
force_offload = False,
|
| 332 |
+
):
|
| 333 |
+
# 只接受 cfg=1.0(与原代码一致)
|
| 334 |
+
assert cfg_scale == 1.0, "cfg_scale must be 1.0"
|
| 335 |
+
|
| 336 |
+
# 要求:必须先 init_cross_kv()
|
| 337 |
+
if self.prompt_emb_posi is None or 'context' not in self.prompt_emb_posi:
|
| 338 |
+
raise RuntimeError(
|
| 339 |
+
"Cross-Attn KV 未初始化。请在调用 __call__ 前先执行:\n"
|
| 340 |
+
" pipe.init_cross_kv()\n"
|
| 341 |
+
"或传入自定义 context:\n"
|
| 342 |
+
" pipe.init_cross_kv(context_tensor=your_context_tensor)"
|
| 343 |
+
)
|
| 344 |
+
|
| 345 |
+
# 尺寸修正
|
| 346 |
+
height, width = self.check_resize_height_width(height, width)
|
| 347 |
+
if num_frames % 4 != 1:
|
| 348 |
+
num_frames = (num_frames + 2) // 4 * 4 + 1
|
| 349 |
+
print(f"Only `num_frames % 4 != 1` is acceptable. We round it up to {num_frames}.")
|
| 350 |
+
|
| 351 |
+
# Tiler 参数
|
| 352 |
+
tiler_kwargs = {"tiled": tiled, "tile_size": tile_size, "tile_stride": tile_stride}
|
| 353 |
+
|
| 354 |
+
# 初始化噪声
|
| 355 |
+
if if_buffer:
|
| 356 |
+
noise = self.generate_noise((1, 16, (num_frames - 1) // 4, height//8, width//8), seed=seed, device=self.device, dtype=self.torch_dtype)
|
| 357 |
+
else:
|
| 358 |
+
noise = self.generate_noise((1, 16, (num_frames - 1) // 4 + 1, height//8, width//8), seed=seed, device=self.device, dtype=self.torch_dtype)
|
| 359 |
+
# noise = noise.to(dtype=self.torch_dtype, device=self.device)
|
| 360 |
+
latents = noise
|
| 361 |
+
|
| 362 |
+
process_total_num = (num_frames - 1) // 8 - 2
|
| 363 |
+
is_stream = True
|
| 364 |
+
|
| 365 |
+
if self.prompt_emb_posi['stats'] == "offload":
|
| 366 |
+
self.init_cross_kv(context_tensor=self.prompt_emb_posi['context'])
|
| 367 |
+
self.load_models_to_device(["dit"])
|
| 368 |
+
self.dit.LQ_proj_in.to(self.device)
|
| 369 |
+
self.TCDecoder.to(self.device)
|
| 370 |
+
|
| 371 |
+
# 清理可能存在的 LQ_proj_in cache
|
| 372 |
+
if hasattr(self.dit, "LQ_proj_in"):
|
| 373 |
+
self.dit.LQ_proj_in.clear_cache()
|
| 374 |
+
|
| 375 |
+
latents_total = []
|
| 376 |
+
self.TCDecoder.clean_mem()
|
| 377 |
+
LQ_pre_idx = 0
|
| 378 |
+
LQ_cur_idx = 0
|
| 379 |
+
|
| 380 |
+
with torch.no_grad():
|
| 381 |
+
for cur_process_idx in progress_bar_cmd(range(process_total_num)):
|
| 382 |
+
if cur_process_idx == 0:
|
| 383 |
+
pre_cache_k = [None] * len(self.dit.blocks)
|
| 384 |
+
pre_cache_v = [None] * len(self.dit.blocks)
|
| 385 |
+
LQ_latents = None
|
| 386 |
+
inner_loop_num = 7
|
| 387 |
+
for inner_idx in range(inner_loop_num):
|
| 388 |
+
cur = self.denoising_model().LQ_proj_in.stream_forward(
|
| 389 |
+
LQ_video[:, :, max(0, inner_idx*4-3):(inner_idx+1)*4-3, :, :]
|
| 390 |
+
) if LQ_video is not None else None
|
| 391 |
+
if cur is None:
|
| 392 |
+
continue
|
| 393 |
+
if LQ_latents is None:
|
| 394 |
+
LQ_latents = cur
|
| 395 |
+
else:
|
| 396 |
+
for layer_idx in range(len(LQ_latents)):
|
| 397 |
+
LQ_latents[layer_idx] = torch.cat([LQ_latents[layer_idx], cur[layer_idx]], dim=1)
|
| 398 |
+
LQ_cur_idx = (inner_loop_num-1)*4-3
|
| 399 |
+
cur_latents = latents[:, :, :6, :, :]
|
| 400 |
+
else:
|
| 401 |
+
LQ_latents = None
|
| 402 |
+
inner_loop_num = 2
|
| 403 |
+
for inner_idx in range(inner_loop_num):
|
| 404 |
+
cur = self.denoising_model().LQ_proj_in.stream_forward(
|
| 405 |
+
LQ_video[:, :, cur_process_idx*8+17+inner_idx*4:cur_process_idx*8+21+inner_idx*4, :, :]
|
| 406 |
+
) if LQ_video is not None else None
|
| 407 |
+
if cur is None:
|
| 408 |
+
continue
|
| 409 |
+
if LQ_latents is None:
|
| 410 |
+
LQ_latents = cur
|
| 411 |
+
else:
|
| 412 |
+
for layer_idx in range(len(LQ_latents)):
|
| 413 |
+
LQ_latents[layer_idx] = torch.cat([LQ_latents[layer_idx], cur[layer_idx]], dim=1)
|
| 414 |
+
LQ_cur_idx = cur_process_idx*8+21+(inner_loop_num-2)*4
|
| 415 |
+
cur_latents = latents[:, :, 4+cur_process_idx*2:6+cur_process_idx*2, :, :]
|
| 416 |
+
|
| 417 |
+
# 推理(无 motion_controller / vace)
|
| 418 |
+
noise_pred_posi, pre_cache_k, pre_cache_v = model_fn_wan_video(
|
| 419 |
+
self.dit,
|
| 420 |
+
x=cur_latents,
|
| 421 |
+
timestep=self.timestep,
|
| 422 |
+
context=None,
|
| 423 |
+
tea_cache=None,
|
| 424 |
+
use_unified_sequence_parallel=False,
|
| 425 |
+
LQ_latents=LQ_latents,
|
| 426 |
+
is_full_block=is_full_block,
|
| 427 |
+
is_stream=is_stream,
|
| 428 |
+
pre_cache_k=pre_cache_k,
|
| 429 |
+
pre_cache_v=pre_cache_v,
|
| 430 |
+
topk_ratio=topk_ratio,
|
| 431 |
+
kv_ratio=kv_ratio,
|
| 432 |
+
cur_process_idx=cur_process_idx,
|
| 433 |
+
t_mod=self.t_mod,
|
| 434 |
+
t=self.t,
|
| 435 |
+
local_range = local_range,
|
| 436 |
+
)
|
| 437 |
+
|
| 438 |
+
# 更新 latent
|
| 439 |
+
cur_latents = cur_latents - noise_pred_posi
|
| 440 |
+
latents_total.append(cur_latents)
|
| 441 |
+
LQ_pre_idx = LQ_cur_idx
|
| 442 |
+
|
| 443 |
+
if hasattr(self.dit, "LQ_proj_in"):
|
| 444 |
+
self.dit.LQ_proj_in.clear_cache()
|
| 445 |
+
|
| 446 |
+
if unload_dit and hasattr(self, 'dit') and not next(self.dit.parameters()).is_cpu:
|
| 447 |
+
print("[FlashVSR] Offloading DiT to the CPU to free up VRAM...")
|
| 448 |
+
self.offload_model(keep_vae=True)
|
| 449 |
+
|
| 450 |
+
latents = torch.cat(latents_total, dim=2)
|
| 451 |
+
|
| 452 |
+
# Decode
|
| 453 |
+
print("[FlashVSR] Starting VAE decoding...")
|
| 454 |
+
frames = self.TCDecoder.decode_video(latents.transpose(1, 2),parallel=False, show_progress_bar=False, cond=LQ_video[:,:,:LQ_cur_idx,:,:]).transpose(1, 2).mul_(2).sub_(1)
|
| 455 |
+
|
| 456 |
+
self.TCDecoder.clean_mem()
|
| 457 |
+
if force_offload:
|
| 458 |
+
self.offload_model()
|
| 459 |
+
|
| 460 |
+
# 颜色校正(wavelet)
|
| 461 |
+
try:
|
| 462 |
+
if color_fix:
|
| 463 |
+
frames = self.ColorCorrector(
|
| 464 |
+
frames.to(device=LQ_video.device),
|
| 465 |
+
LQ_video[:, :, :frames.shape[2], :, :],
|
| 466 |
+
clip_range=(-1, 1),
|
| 467 |
+
chunk_size=16,
|
| 468 |
+
method='adain'
|
| 469 |
+
)
|
| 470 |
+
except:
|
| 471 |
+
pass
|
| 472 |
+
|
| 473 |
+
return frames[0]
|
| 474 |
+
|
| 475 |
+
|
| 476 |
+
# -----------------------------
|
| 477 |
+
# TeaCache(保留原逻���;此处默认不启用)
|
| 478 |
+
# -----------------------------
|
| 479 |
+
class TeaCache:
|
| 480 |
+
def __init__(self, num_inference_steps, rel_l1_thresh, model_id):
|
| 481 |
+
self.num_inference_steps = num_inference_steps
|
| 482 |
+
self.step = 0
|
| 483 |
+
self.accumulated_rel_l1_distance = 0
|
| 484 |
+
self.previous_modulated_input = None
|
| 485 |
+
self.rel_l1_thresh = rel_l1_thresh
|
| 486 |
+
self.previous_residual = None
|
| 487 |
+
self.previous_hidden_states = None
|
| 488 |
+
|
| 489 |
+
self.coefficients_dict = {
|
| 490 |
+
"Wan2.1-T2V-1.3B": [-5.21862437e+04, 9.23041404e+03, -5.28275948e+02, 1.36987616e+01, -4.99875664e-02],
|
| 491 |
+
"Wan2.1-T2V-14B": [-3.03318725e+05, 4.90537029e+04, -2.65530556e+03, 5.87365115e+01, -3.15583525e-01],
|
| 492 |
+
"Wan2.1-I2V-14B-480P": [2.57151496e+05, -3.54229917e+04, 1.40286849e+03, -1.35890334e+01, 1.32517977e-01],
|
| 493 |
+
"Wan2.1-I2V-14B-720P": [8.10705460e+03, 2.13393892e+03, -3.72934672e+02, 1.66203073e+01, -4.17769401e-02],
|
| 494 |
+
}
|
| 495 |
+
if model_id not in self.coefficients_dict:
|
| 496 |
+
supported_model_ids = ", ".join([i for i in self.coefficients_dict])
|
| 497 |
+
raise ValueError(f"{model_id} is not a supported TeaCache model id. Please choose a valid model id in ({supported_model_ids}).")
|
| 498 |
+
self.coefficients = self.coefficients_dict[model_id]
|
| 499 |
+
|
| 500 |
+
def check(self, dit: WanModel, x, t_mod):
|
| 501 |
+
modulated_inp = t_mod.clone()
|
| 502 |
+
if self.step == 0 or self.step == self.num_inference_steps - 1:
|
| 503 |
+
should_calc = True
|
| 504 |
+
self.accumulated_rel_l1_distance = 0
|
| 505 |
+
else:
|
| 506 |
+
coefficients = self.coefficients
|
| 507 |
+
rescale_func = np.poly1d(coefficients)
|
| 508 |
+
self.accumulated_rel_l1_distance += rescale_func(((modulated_inp-self.previous_modulated_input).abs().mean() / self.previous_modulated_input.abs().mean()).cpu().item())
|
| 509 |
+
should_calc = not (self.accumulated_rel_l1_distance < self.rel_l1_thresh)
|
| 510 |
+
if should_calc:
|
| 511 |
+
self.accumulated_rel_l1_distance = 0
|
| 512 |
+
self.previous_modulated_input = modulated_inp
|
| 513 |
+
self.step = (self.step + 1) % self.num_inference_steps
|
| 514 |
+
if should_calc:
|
| 515 |
+
self.previous_hidden_states = x.clone()
|
| 516 |
+
return not should_calc
|
| 517 |
+
|
| 518 |
+
def store(self, hidden_states):
|
| 519 |
+
self.previous_residual = hidden_states - self.previous_hidden_states
|
| 520 |
+
self.previous_hidden_states = None
|
| 521 |
+
|
| 522 |
+
def update(self, hidden_states):
|
| 523 |
+
hidden_states = hidden_states + self.previous_residual
|
| 524 |
+
return hidden_states
|
| 525 |
+
|
| 526 |
+
|
| 527 |
+
# -----------------------------
|
| 528 |
+
# 简化版模型前向封装(无 vace / 无 motion_controller)
|
| 529 |
+
# -----------------------------
|
| 530 |
+
def model_fn_wan_video(
|
| 531 |
+
dit: WanModel,
|
| 532 |
+
x: torch.Tensor,
|
| 533 |
+
timestep: torch.Tensor,
|
| 534 |
+
context: torch.Tensor,
|
| 535 |
+
tea_cache: Optional[TeaCache] = None,
|
| 536 |
+
use_unified_sequence_parallel: bool = False,
|
| 537 |
+
LQ_latents: Optional[torch.Tensor] = None,
|
| 538 |
+
is_full_block: bool = False,
|
| 539 |
+
is_stream: bool = False,
|
| 540 |
+
pre_cache_k: Optional[list[torch.Tensor]] = None,
|
| 541 |
+
pre_cache_v: Optional[list[torch.Tensor]] = None,
|
| 542 |
+
topk_ratio: float = 2.0,
|
| 543 |
+
kv_ratio: float = 3.0,
|
| 544 |
+
cur_process_idx: int = 0,
|
| 545 |
+
t_mod : torch.Tensor = None,
|
| 546 |
+
t : torch.Tensor = None,
|
| 547 |
+
local_range: int = 9,
|
| 548 |
+
**kwargs,
|
| 549 |
+
):
|
| 550 |
+
# patchify
|
| 551 |
+
x, (f, h, w) = dit.patchify(x)
|
| 552 |
+
|
| 553 |
+
win = (2, 8, 8)
|
| 554 |
+
seqlen = f // win[0]
|
| 555 |
+
local_num = seqlen
|
| 556 |
+
window_size = win[0] * h * w // 128
|
| 557 |
+
square_num = window_size * window_size
|
| 558 |
+
topk = int(square_num * topk_ratio) - 1
|
| 559 |
+
kv_len = int(kv_ratio)
|
| 560 |
+
|
| 561 |
+
# RoPE 位置(分段)
|
| 562 |
+
if cur_process_idx == 0:
|
| 563 |
+
freqs = torch.cat([
|
| 564 |
+
dit.freqs[0][:f].view(f, 1, 1, -1).expand(f, h, w, -1),
|
| 565 |
+
dit.freqs[1][:h].view(1, h, 1, -1).expand(f, h, w, -1),
|
| 566 |
+
dit.freqs[2][:w].view(1, 1, w, -1).expand(f, h, w, -1)
|
| 567 |
+
], dim=-1).reshape(f * h * w, 1, -1).to(x.device)
|
| 568 |
+
else:
|
| 569 |
+
freqs = torch.cat([
|
| 570 |
+
dit.freqs[0][4 + cur_process_idx*2:4 + cur_process_idx*2 + f].view(f, 1, 1, -1).expand(f, h, w, -1),
|
| 571 |
+
dit.freqs[1][:h].view(1, h, 1, -1).expand(f, h, w, -1),
|
| 572 |
+
dit.freqs[2][:w].view(1, 1, w, -1).expand(f, h, w, -1)
|
| 573 |
+
], dim=-1).reshape(f * h * w, 1, -1).to(x.device)
|
| 574 |
+
|
| 575 |
+
# TeaCache(默认不启用)
|
| 576 |
+
tea_cache_update = tea_cache.check(dit, x, t_mod) if tea_cache is not None else False
|
| 577 |
+
|
| 578 |
+
# 统一序列并行(此处默认关闭)
|
| 579 |
+
if use_unified_sequence_parallel:
|
| 580 |
+
import torch.distributed as dist
|
| 581 |
+
from xfuser.core.distributed import (get_sequence_parallel_rank,
|
| 582 |
+
get_sequence_parallel_world_size,
|
| 583 |
+
get_sp_group)
|
| 584 |
+
if dist.is_initialized() and dist.get_world_size() > 1:
|
| 585 |
+
x = torch.chunk(x, get_sequence_parallel_world_size(), dim=1)[get_sequence_parallel_rank()]
|
| 586 |
+
|
| 587 |
+
# Block 堆叠
|
| 588 |
+
if tea_cache_update:
|
| 589 |
+
x = tea_cache.update(x)
|
| 590 |
+
else:
|
| 591 |
+
for block_id, block in enumerate(dit.blocks):
|
| 592 |
+
if LQ_latents is not None and block_id < len(LQ_latents):
|
| 593 |
+
x = x + LQ_latents[block_id]
|
| 594 |
+
x, last_pre_cache_k, last_pre_cache_v = block(
|
| 595 |
+
x, context, t_mod, freqs, f, h, w,
|
| 596 |
+
local_num, topk,
|
| 597 |
+
block_id=block_id,
|
| 598 |
+
kv_len=kv_len,
|
| 599 |
+
is_full_block=is_full_block,
|
| 600 |
+
is_stream=is_stream,
|
| 601 |
+
pre_cache_k=pre_cache_k[block_id] if pre_cache_k is not None else None,
|
| 602 |
+
pre_cache_v=pre_cache_v[block_id] if pre_cache_v is not None else None,
|
| 603 |
+
local_range = local_range,
|
| 604 |
+
)
|
| 605 |
+
if pre_cache_k is not None: pre_cache_k[block_id] = last_pre_cache_k
|
| 606 |
+
if pre_cache_v is not None: pre_cache_v[block_id] = last_pre_cache_v
|
| 607 |
+
|
| 608 |
+
x = dit.head(x, t)
|
| 609 |
+
if use_unified_sequence_parallel:
|
| 610 |
+
import torch.distributed as dist
|
| 611 |
+
from xfuser.core.distributed import get_sp_group
|
| 612 |
+
if dist.is_initialized() and dist.get_world_size() > 1:
|
| 613 |
+
x = get_sp_group().all_gather(x, dim=1)
|
| 614 |
+
x = dit.unpatchify(x, (f, h, w))
|
| 615 |
+
return x, pre_cache_k, pre_cache_v
|
custom_nodes/ComfyUI-FlashVSR_Ultra_Fast/src/pipelines/flashvsr_tiny_long.py
ADDED
|
@@ -0,0 +1,620 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import types
|
| 2 |
+
import os
|
| 3 |
+
import time
|
| 4 |
+
from typing import Optional, Tuple, Literal
|
| 5 |
+
|
| 6 |
+
import torch
|
| 7 |
+
import torch.nn as nn
|
| 8 |
+
import torch.nn.functional as F
|
| 9 |
+
import numpy as np
|
| 10 |
+
from einops import rearrange
|
| 11 |
+
from PIL import Image
|
| 12 |
+
from tqdm import tqdm
|
| 13 |
+
# import pyfiglet
|
| 14 |
+
|
| 15 |
+
from ..models import ModelManager
|
| 16 |
+
from ..models.utils import clean_vram
|
| 17 |
+
from ..models.wan_video_dit import WanModel, RMSNorm, sinusoidal_embedding_1d
|
| 18 |
+
from ..models.wan_video_vae import WanVideoVAE, RMS_norm, CausalConv3d, Upsample
|
| 19 |
+
from ..schedulers.flow_match import FlowMatchScheduler
|
| 20 |
+
from .base import BasePipeline
|
| 21 |
+
|
| 22 |
+
|
| 23 |
+
# -----------------------------
|
| 24 |
+
# 基础工具:ADAIN 所需的统计量(保留以备需要;管线默认用 wavelet)
|
| 25 |
+
# -----------------------------
|
| 26 |
+
def _calc_mean_std(feat: torch.Tensor, eps: float = 1e-5) -> Tuple[torch.Tensor, torch.Tensor]:
|
| 27 |
+
assert feat.dim() == 4, 'feat 必须是 (N, C, H, W)'
|
| 28 |
+
N, C = feat.shape[:2]
|
| 29 |
+
var = feat.view(N, C, -1).var(dim=2, unbiased=False) + eps
|
| 30 |
+
std = var.sqrt().view(N, C, 1, 1)
|
| 31 |
+
mean = feat.view(N, C, -1).mean(dim=2).view(N, C, 1, 1)
|
| 32 |
+
return mean, std
|
| 33 |
+
|
| 34 |
+
|
| 35 |
+
def _adain(content_feat: torch.Tensor, style_feat: torch.Tensor) -> torch.Tensor:
|
| 36 |
+
assert content_feat.shape[:2] == style_feat.shape[:2], "ADAIN: N、C 必须匹配"
|
| 37 |
+
size = content_feat.size()
|
| 38 |
+
style_mean, style_std = _calc_mean_std(style_feat)
|
| 39 |
+
content_mean, content_std = _calc_mean_std(content_feat)
|
| 40 |
+
normalized = (content_feat - content_mean.expand(size)) / content_std.expand(size)
|
| 41 |
+
return normalized * style_std.expand(size) + style_mean.expand(size)
|
| 42 |
+
|
| 43 |
+
|
| 44 |
+
# -----------------------------
|
| 45 |
+
# 小波式模糊与分解/重构(ColorCorrector 用)
|
| 46 |
+
# -----------------------------
|
| 47 |
+
def _make_gaussian3x3_kernel(dtype, device) -> torch.Tensor:
|
| 48 |
+
vals = [
|
| 49 |
+
[0.0625, 0.125, 0.0625],
|
| 50 |
+
[0.125, 0.25, 0.125 ],
|
| 51 |
+
[0.0625, 0.125, 0.0625],
|
| 52 |
+
]
|
| 53 |
+
return torch.tensor(vals, dtype=dtype, device=device)
|
| 54 |
+
|
| 55 |
+
|
| 56 |
+
def _wavelet_blur(x: torch.Tensor, radius: int) -> torch.Tensor:
|
| 57 |
+
assert x.dim() == 4, 'x 必须是 (N, C, H, W)'
|
| 58 |
+
N, C, H, W = x.shape
|
| 59 |
+
base = _make_gaussian3x3_kernel(x.dtype, x.device)
|
| 60 |
+
weight = base.view(1, 1, 3, 3).repeat(C, 1, 1, 1)
|
| 61 |
+
pad = radius
|
| 62 |
+
x_pad = F.pad(x, (pad, pad, pad, pad), mode='replicate')
|
| 63 |
+
out = F.conv2d(x_pad, weight, bias=None, stride=1, padding=0, dilation=radius, groups=C)
|
| 64 |
+
return out
|
| 65 |
+
|
| 66 |
+
|
| 67 |
+
def _wavelet_decompose(x: torch.Tensor, levels: int = 5) -> Tuple[torch.Tensor, torch.Tensor]:
|
| 68 |
+
assert x.dim() == 4, 'x 必须是 (N, C, H, W)'
|
| 69 |
+
high = torch.zeros_like(x)
|
| 70 |
+
low = x
|
| 71 |
+
for i in range(levels):
|
| 72 |
+
radius = 2 ** i
|
| 73 |
+
blurred = _wavelet_blur(low, radius)
|
| 74 |
+
high = high + (low - blurred)
|
| 75 |
+
low = blurred
|
| 76 |
+
return high, low
|
| 77 |
+
|
| 78 |
+
|
| 79 |
+
def _wavelet_reconstruct(content: torch.Tensor, style: torch.Tensor, levels: int = 5) -> torch.Tensor:
|
| 80 |
+
c_high, _ = _wavelet_decompose(content, levels=levels)
|
| 81 |
+
_, s_low = _wavelet_decompose(style, levels=levels)
|
| 82 |
+
return c_high + s_low
|
| 83 |
+
|
| 84 |
+
|
| 85 |
+
# -----------------------------
|
| 86 |
+
# 无状态颜色矫正模块(视频友好,默认 wavelet)
|
| 87 |
+
# -----------------------------
|
| 88 |
+
class TorchColorCorrectorWavelet(nn.Module):
|
| 89 |
+
def __init__(self, levels: int = 5):
|
| 90 |
+
super().__init__()
|
| 91 |
+
self.levels = levels
|
| 92 |
+
|
| 93 |
+
@staticmethod
|
| 94 |
+
def _flatten_time(x: torch.Tensor) -> Tuple[torch.Tensor, int, int]:
|
| 95 |
+
assert x.dim() == 5, '输入必须是 (B, C, f, H, W)'
|
| 96 |
+
B, C, f, H, W = x.shape
|
| 97 |
+
y = x.permute(0, 2, 1, 3, 4).reshape(B * f, C, H, W)
|
| 98 |
+
return y, B, f
|
| 99 |
+
|
| 100 |
+
@staticmethod
|
| 101 |
+
def _unflatten_time(y: torch.Tensor, B: int, f: int) -> torch.Tensor:
|
| 102 |
+
BF, C, H, W = y.shape
|
| 103 |
+
assert BF == B * f
|
| 104 |
+
return y.reshape(B, f, C, H, W).permute(0, 2, 1, 3, 4)
|
| 105 |
+
|
| 106 |
+
def forward(
|
| 107 |
+
self,
|
| 108 |
+
hq_image: torch.Tensor, # (B, C, f, H, W)
|
| 109 |
+
lq_image: torch.Tensor, # (B, C, f, H, W)
|
| 110 |
+
clip_range: Tuple[float, float] = (-1.0, 1.0),
|
| 111 |
+
method: Literal['wavelet', 'adain'] = 'wavelet',
|
| 112 |
+
chunk_size: Optional[int] = None,
|
| 113 |
+
) -> torch.Tensor:
|
| 114 |
+
assert hq_image.shape == lq_image.shape, "HQ 与 LQ 的形状必须一致"
|
| 115 |
+
assert hq_image.dim() == 5 and hq_image.shape[1] == 3, "输入必须是 (B, 3, f, H, W)"
|
| 116 |
+
|
| 117 |
+
B, C, f, H, W = hq_image.shape
|
| 118 |
+
if chunk_size is None or chunk_size >= f:
|
| 119 |
+
hq4, B, f = self._flatten_time(hq_image)
|
| 120 |
+
lq4, _, _ = self._flatten_time(lq_image)
|
| 121 |
+
if method == 'wavelet':
|
| 122 |
+
out4 = _wavelet_reconstruct(hq4, lq4, levels=self.levels)
|
| 123 |
+
elif method == 'adain':
|
| 124 |
+
out4 = _adain(hq4, lq4)
|
| 125 |
+
else:
|
| 126 |
+
raise ValueError(f"未知 method: {method}")
|
| 127 |
+
out4 = torch.clamp(out4, *clip_range)
|
| 128 |
+
out = self._unflatten_time(out4, B, f)
|
| 129 |
+
return out
|
| 130 |
+
|
| 131 |
+
outs = []
|
| 132 |
+
for start in range(0, f, chunk_size):
|
| 133 |
+
end = min(start + chunk_size, f)
|
| 134 |
+
hq_chunk = hq_image[:, :, start:end]
|
| 135 |
+
lq_chunk = lq_image[:, :, start:end]
|
| 136 |
+
hq4, B_, f_ = self._flatten_time(hq_chunk)
|
| 137 |
+
lq4, _, _ = self._flatten_time(lq_chunk)
|
| 138 |
+
if method == 'wavelet':
|
| 139 |
+
out4 = _wavelet_reconstruct(hq4, lq4, levels=self.levels)
|
| 140 |
+
elif method == 'adain':
|
| 141 |
+
out4 = _adain(hq4, lq4)
|
| 142 |
+
else:
|
| 143 |
+
raise ValueError(f"未知 method: {method}")
|
| 144 |
+
out4 = torch.clamp(out4, *clip_range)
|
| 145 |
+
out_chunk = self._unflatten_time(out4, B_, f_)
|
| 146 |
+
outs.append(out_chunk)
|
| 147 |
+
out = torch.cat(outs, dim=2)
|
| 148 |
+
return out
|
| 149 |
+
|
| 150 |
+
|
| 151 |
+
# -----------------------------
|
| 152 |
+
# 简化版 Pipeline(仅 dit + vae)
|
| 153 |
+
# -----------------------------
|
| 154 |
+
class FlashVSRTinyLongPipeline(BasePipeline):
|
| 155 |
+
|
| 156 |
+
def __init__(self, device="cuda", torch_dtype=torch.float16):
|
| 157 |
+
super().__init__(device=device, torch_dtype=torch_dtype)
|
| 158 |
+
self.scheduler = FlowMatchScheduler(shift=5, sigma_min=0.0, extra_one_step=True)
|
| 159 |
+
self.dit: WanModel = None
|
| 160 |
+
self.vae: WanVideoVAE = None
|
| 161 |
+
self.model_names = ['dit', 'vae']
|
| 162 |
+
self.height_division_factor = 16
|
| 163 |
+
self.width_division_factor = 16
|
| 164 |
+
self.use_unified_sequence_parallel = False
|
| 165 |
+
self.prompt_emb_posi = None
|
| 166 |
+
self.ColorCorrector = TorchColorCorrectorWavelet(levels=5)
|
| 167 |
+
|
| 168 |
+
print(r"""
|
| 169 |
+
███████╗██╗ █████╗ ███████╗██╗ ██╗██╗ ██╗███████╗█████╗
|
| 170 |
+
██╔════╝██║ ██╔══██╗██╔════╝██║ ██║██║ ██║██╔════╝██╔══██╗ ██╗
|
| 171 |
+
█████╗ ██║ ███████║███████╗███████║╚██╗ ██╔╝███████╗███████║ ██████╗
|
| 172 |
+
██╔══╝ ██║ ██╔══██║╚════██║██╔══██║ ╚████╔╝ ╚════██║██╔═██║ ██╔═╝
|
| 173 |
+
██║ ███████╗██║ ██║███████║██║ ██║ ╚██╔╝ ███████║██║ ██║ ╚═╝
|
| 174 |
+
╚═╝ ╚══════╝╚═╝ ╚═╝╚══════╝╚═╝ ╚═╝ ╚═╝ ╚══════╝╚═╝ ╚═╝
|
| 175 |
+
""")
|
| 176 |
+
|
| 177 |
+
def enable_vram_management(self, num_persistent_param_in_dit=None):
|
| 178 |
+
# 仅管理 dit / vae
|
| 179 |
+
dtype = next(iter(self.dit.parameters())).dtype
|
| 180 |
+
from ..vram_management import enable_vram_management, AutoWrappedModule, AutoWrappedLinear
|
| 181 |
+
enable_vram_management(
|
| 182 |
+
self.dit,
|
| 183 |
+
module_map={
|
| 184 |
+
torch.nn.Linear: AutoWrappedLinear,
|
| 185 |
+
torch.nn.Conv3d: AutoWrappedModule,
|
| 186 |
+
torch.nn.LayerNorm: AutoWrappedModule,
|
| 187 |
+
RMSNorm: AutoWrappedModule,
|
| 188 |
+
},
|
| 189 |
+
module_config=dict(
|
| 190 |
+
offload_dtype=dtype,
|
| 191 |
+
offload_device="cpu",
|
| 192 |
+
onload_dtype=dtype,
|
| 193 |
+
onload_device=self.device,
|
| 194 |
+
computation_dtype=self.torch_dtype,
|
| 195 |
+
computation_device=self.device,
|
| 196 |
+
),
|
| 197 |
+
max_num_param=num_persistent_param_in_dit,
|
| 198 |
+
overflow_module_config=dict(
|
| 199 |
+
offload_dtype=dtype,
|
| 200 |
+
offload_device="cpu",
|
| 201 |
+
onload_dtype=dtype,
|
| 202 |
+
onload_device="cpu",
|
| 203 |
+
computation_dtype=self.torch_dtype,
|
| 204 |
+
computation_device=self.device,
|
| 205 |
+
),
|
| 206 |
+
)
|
| 207 |
+
self.enable_cpu_offload()
|
| 208 |
+
|
| 209 |
+
def fetch_models(self, model_manager: ModelManager):
|
| 210 |
+
self.dit = model_manager.fetch_model("wan_video_dit")
|
| 211 |
+
self.vae = model_manager.fetch_model("wan_video_vae")
|
| 212 |
+
|
| 213 |
+
@staticmethod
|
| 214 |
+
def from_model_manager(model_manager: ModelManager, torch_dtype=None, device=None, use_usp=False):
|
| 215 |
+
if device is None: device = model_manager.device
|
| 216 |
+
if torch_dtype is None: torch_dtype = model_manager.torch_dtype
|
| 217 |
+
pipe = FlashVSRTinyLongPipeline(device=device, torch_dtype=torch_dtype)
|
| 218 |
+
pipe.fetch_models(model_manager)
|
| 219 |
+
# 可选:统一序列并行入口(此处默认关闭)
|
| 220 |
+
pipe.use_unified_sequence_parallel = False
|
| 221 |
+
return pipe
|
| 222 |
+
|
| 223 |
+
def denoising_model(self):
|
| 224 |
+
return self.dit
|
| 225 |
+
|
| 226 |
+
# -------------------------
|
| 227 |
+
# 新增:显式 KV 预初始化函数
|
| 228 |
+
# -------------------------
|
| 229 |
+
def init_cross_kv(
|
| 230 |
+
self,
|
| 231 |
+
context_tensor: Optional[torch.Tensor] = None,
|
| 232 |
+
prompt_path = None,
|
| 233 |
+
):
|
| 234 |
+
self.load_models_to_device(["dit"])
|
| 235 |
+
"""
|
| 236 |
+
使用固定 prompt 生成文本 context,并在 WanModel 中初始化所有 CrossAttention 的 KV 缓存。
|
| 237 |
+
必须在 __call__ 前显式调用一次。
|
| 238 |
+
"""
|
| 239 |
+
#prompt_path = "../../examples/WanVSR/prompt_tensor/posi_prompt.pth"
|
| 240 |
+
|
| 241 |
+
if self.dit is None:
|
| 242 |
+
raise RuntimeError("请先通过 fetch_models / from_model_manager 初始化 self.dit")
|
| 243 |
+
|
| 244 |
+
if context_tensor is None:
|
| 245 |
+
if prompt_path is None:
|
| 246 |
+
raise ValueError("init_cross_kv: 需要提供 prompt_path 或 context_tensor 其一")
|
| 247 |
+
ctx = torch.load(prompt_path, map_location=self.device)
|
| 248 |
+
else:
|
| 249 |
+
ctx = context_tensor
|
| 250 |
+
|
| 251 |
+
ctx = ctx.to(dtype=self.torch_dtype, device=self.device)
|
| 252 |
+
|
| 253 |
+
if self.prompt_emb_posi is None:
|
| 254 |
+
self.prompt_emb_posi = {}
|
| 255 |
+
self.prompt_emb_posi['context'] = ctx
|
| 256 |
+
self.prompt_emb_posi['stats'] = "load"
|
| 257 |
+
|
| 258 |
+
if hasattr(self.dit, "reinit_cross_kv"):
|
| 259 |
+
self.dit.reinit_cross_kv(ctx)
|
| 260 |
+
else:
|
| 261 |
+
raise AttributeError("WanModel 缺少 reinit_cross_kv(ctx) 方法,请在模型实现中加入该能力。")
|
| 262 |
+
self.timestep = torch.tensor([1000.], device=self.device, dtype=self.torch_dtype)
|
| 263 |
+
self.t = self.dit.time_embedding(sinusoidal_embedding_1d(self.dit.freq_dim, self.timestep))
|
| 264 |
+
self.t_mod = self.dit.time_projection(self.t).unflatten(1, (6, self.dit.dim))
|
| 265 |
+
# Scheduler
|
| 266 |
+
self.scheduler.set_timesteps(1, denoising_strength=1.0, shift=5.0)
|
| 267 |
+
self.load_models_to_device([])
|
| 268 |
+
|
| 269 |
+
def prepare_unified_sequence_parallel(self):
|
| 270 |
+
return {"use_unified_sequence_parallel": self.use_unified_sequence_parallel}
|
| 271 |
+
|
| 272 |
+
def prepare_extra_input(self, latents=None):
|
| 273 |
+
return {}
|
| 274 |
+
|
| 275 |
+
def encode_video(self, input_video, tiled=True, tile_size=(34, 34), tile_stride=(18, 16)):
|
| 276 |
+
latents = self.vae.encode(input_video, device=self.device, tiled=tiled, tile_size=tile_size, tile_stride=tile_stride)
|
| 277 |
+
return latents
|
| 278 |
+
|
| 279 |
+
def _decode_video(self, latents, tiled=True, tile_size=(34, 34), tile_stride=(18, 16)):
|
| 280 |
+
frames = self.vae.decode(latents, device=self.device, tiled=tiled, tile_size=tile_size, tile_stride=tile_stride)
|
| 281 |
+
return frames
|
| 282 |
+
|
| 283 |
+
def decode_video(self, latents, cond=None, **kwargs):
|
| 284 |
+
frames = self.TCDecoder.decode_video(
|
| 285 |
+
latents.transpose(1, 2), # TCDecoder 需要 (B, F, C, H, W)
|
| 286 |
+
parallel=False,
|
| 287 |
+
show_progress_bar=False,
|
| 288 |
+
cond=cond
|
| 289 |
+
).transpose(1, 2).mul_(2).sub_(1) # 转回 (B, C, F, H, W) 格式,范围 -1 to 1
|
| 290 |
+
|
| 291 |
+
return frames
|
| 292 |
+
|
| 293 |
+
def offload_model(self, keep_vae=False):
|
| 294 |
+
self.dit.clear_cross_kv()
|
| 295 |
+
self.prompt_emb_posi['stats'] = "offload"
|
| 296 |
+
self.load_models_to_device([])
|
| 297 |
+
if hasattr(self.dit, "LQ_proj_in"):
|
| 298 |
+
self.dit.LQ_proj_in.to('cpu')
|
| 299 |
+
if not keep_vae:
|
| 300 |
+
self.TCDecoder.to('cpu')
|
| 301 |
+
|
| 302 |
+
@torch.no_grad()
|
| 303 |
+
def __call__(
|
| 304 |
+
self,
|
| 305 |
+
prompt=None,
|
| 306 |
+
negative_prompt="",
|
| 307 |
+
denoising_strength=1.0,
|
| 308 |
+
seed=None,
|
| 309 |
+
rand_device="gpu",
|
| 310 |
+
height=480,
|
| 311 |
+
width=832,
|
| 312 |
+
num_frames=81,
|
| 313 |
+
cfg_scale=5.0,
|
| 314 |
+
num_inference_steps=50,
|
| 315 |
+
sigma_shift=5.0,
|
| 316 |
+
tiled=True,
|
| 317 |
+
tile_size=(60, 104),
|
| 318 |
+
tile_stride=(30, 52),
|
| 319 |
+
tea_cache_l1_thresh=None,
|
| 320 |
+
tea_cache_model_id="Wan2.1-T2V-1.3B",
|
| 321 |
+
progress_bar_cmd=tqdm,
|
| 322 |
+
progress_bar_st=None,
|
| 323 |
+
LQ_video=None,
|
| 324 |
+
is_full_block=False,
|
| 325 |
+
if_buffer=False,
|
| 326 |
+
topk_ratio=2.0,
|
| 327 |
+
kv_ratio=3.0,
|
| 328 |
+
local_range = 9,
|
| 329 |
+
color_fix = True,
|
| 330 |
+
unload_dit = False,
|
| 331 |
+
force_offload = False,
|
| 332 |
+
):
|
| 333 |
+
# 只接受 cfg=1.0(与原代码一致)
|
| 334 |
+
assert cfg_scale == 1.0, "cfg_scale must be 1.0"
|
| 335 |
+
|
| 336 |
+
# 要求:必须先 init_cross_kv()
|
| 337 |
+
if self.prompt_emb_posi is None or 'context' not in self.prompt_emb_posi:
|
| 338 |
+
raise RuntimeError(
|
| 339 |
+
"Cross-Attn KV 未初始化。请在调用 __call__ 前先执行:\n"
|
| 340 |
+
" pipe.init_cross_kv()\n"
|
| 341 |
+
"或传入自定义 context:\n"
|
| 342 |
+
" pipe.init_cross_kv(context_tensor=your_context_tensor)"
|
| 343 |
+
)
|
| 344 |
+
|
| 345 |
+
# 尺寸修正
|
| 346 |
+
height, width = self.check_resize_height_width(height, width)
|
| 347 |
+
if num_frames % 4 != 1:
|
| 348 |
+
num_frames = (num_frames + 2) // 4 * 4 + 1
|
| 349 |
+
print(f"Only `num_frames % 4 != 1` is acceptable. We round it up to {num_frames}.")
|
| 350 |
+
|
| 351 |
+
# Tiler 参数
|
| 352 |
+
tiler_kwargs = {"tiled": tiled, "tile_size": tile_size, "tile_stride": tile_stride}
|
| 353 |
+
|
| 354 |
+
# 初始化噪声
|
| 355 |
+
if if_buffer:
|
| 356 |
+
noise = self.generate_noise((1, 16, (num_frames - 1) // 4, height//8, width//8), seed=seed, device=self.device, dtype=self.torch_dtype)
|
| 357 |
+
else:
|
| 358 |
+
noise = self.generate_noise((1, 16, (num_frames - 1) // 4 + 1, height//8, width//8), seed=seed, device=self.device, dtype=self.torch_dtype)
|
| 359 |
+
# noise = noise.to(dtype=self.torch_dtype, device=self.device)
|
| 360 |
+
latents = noise
|
| 361 |
+
|
| 362 |
+
process_total_num = (num_frames - 1) // 8 - 2
|
| 363 |
+
is_stream = True
|
| 364 |
+
|
| 365 |
+
if self.prompt_emb_posi['stats'] == "offload":
|
| 366 |
+
self.init_cross_kv(context_tensor=self.prompt_emb_posi['context'])
|
| 367 |
+
self.load_models_to_device(["dit"])
|
| 368 |
+
self.dit.LQ_proj_in.to(self.device)
|
| 369 |
+
self.TCDecoder.to(self.device)
|
| 370 |
+
|
| 371 |
+
# 清理可能存在的 LQ_proj_in cache
|
| 372 |
+
if hasattr(self.dit, "LQ_proj_in"):
|
| 373 |
+
self.dit.LQ_proj_in.clear_cache()
|
| 374 |
+
|
| 375 |
+
frames_total = []
|
| 376 |
+
LQ_pre_idx = 0
|
| 377 |
+
LQ_cur_idx = 0
|
| 378 |
+
self.TCDecoder.clean_mem()
|
| 379 |
+
|
| 380 |
+
with torch.no_grad():
|
| 381 |
+
for cur_process_idx in progress_bar_cmd(range(process_total_num)):
|
| 382 |
+
if cur_process_idx == 0:
|
| 383 |
+
pre_cache_k = [None] * len(self.dit.blocks)
|
| 384 |
+
pre_cache_v = [None] * len(self.dit.blocks)
|
| 385 |
+
LQ_latents = None
|
| 386 |
+
inner_loop_num = 7
|
| 387 |
+
for inner_idx in range(inner_loop_num):
|
| 388 |
+
cur = self.denoising_model().LQ_proj_in.stream_forward(
|
| 389 |
+
LQ_video[:, :, max(0, inner_idx*4-3):(inner_idx+1)*4-3, :, :].to(self.device)
|
| 390 |
+
) if LQ_video is not None else None
|
| 391 |
+
if cur is None:
|
| 392 |
+
continue
|
| 393 |
+
if LQ_latents is None:
|
| 394 |
+
LQ_latents = cur
|
| 395 |
+
else:
|
| 396 |
+
for layer_idx in range(len(LQ_latents)):
|
| 397 |
+
LQ_latents[layer_idx] = torch.cat([LQ_latents[layer_idx], cur[layer_idx]], dim=1)
|
| 398 |
+
LQ_cur_idx = (inner_loop_num-1)*4-3
|
| 399 |
+
cur_latents = latents[:, :, :6, :, :]
|
| 400 |
+
else:
|
| 401 |
+
LQ_latents = None
|
| 402 |
+
inner_loop_num = 2
|
| 403 |
+
for inner_idx in range(inner_loop_num):
|
| 404 |
+
cur = self.denoising_model().LQ_proj_in.stream_forward(
|
| 405 |
+
LQ_video[:, :, cur_process_idx*8+17+inner_idx*4:cur_process_idx*8+21+inner_idx*4, :, :].to(self.device)
|
| 406 |
+
) if LQ_video is not None else None
|
| 407 |
+
if cur is None:
|
| 408 |
+
continue
|
| 409 |
+
if LQ_latents is None:
|
| 410 |
+
LQ_latents = cur
|
| 411 |
+
else:
|
| 412 |
+
for layer_idx in range(len(LQ_latents)):
|
| 413 |
+
LQ_latents[layer_idx] = torch.cat([LQ_latents[layer_idx], cur[layer_idx]], dim=1)
|
| 414 |
+
LQ_cur_idx = cur_process_idx*8+21+(inner_loop_num-2)*4
|
| 415 |
+
cur_latents = latents[:, :, 4+cur_process_idx*2:6+cur_process_idx*2, :, :]
|
| 416 |
+
|
| 417 |
+
# 推理(无 motion_controller / vace)
|
| 418 |
+
noise_pred_posi, pre_cache_k, pre_cache_v = model_fn_wan_video(
|
| 419 |
+
self.dit,
|
| 420 |
+
x=cur_latents,
|
| 421 |
+
timestep=self.timestep,
|
| 422 |
+
context=None,
|
| 423 |
+
tea_cache=None,
|
| 424 |
+
use_unified_sequence_parallel=False,
|
| 425 |
+
LQ_latents=LQ_latents,
|
| 426 |
+
is_full_block=is_full_block,
|
| 427 |
+
is_stream=is_stream,
|
| 428 |
+
pre_cache_k=pre_cache_k,
|
| 429 |
+
pre_cache_v=pre_cache_v,
|
| 430 |
+
topk_ratio=topk_ratio,
|
| 431 |
+
kv_ratio=kv_ratio,
|
| 432 |
+
cur_process_idx=cur_process_idx,
|
| 433 |
+
t_mod=self.t_mod,
|
| 434 |
+
t=self.t,
|
| 435 |
+
local_range = local_range,
|
| 436 |
+
)
|
| 437 |
+
|
| 438 |
+
# 更新 latent
|
| 439 |
+
cur_latents = cur_latents - noise_pred_posi
|
| 440 |
+
|
| 441 |
+
# Decode
|
| 442 |
+
cur_LQ_frame = LQ_video[:,:,LQ_pre_idx:LQ_cur_idx,:,:].to(self.device)
|
| 443 |
+
cur_frames = self.TCDecoder.decode_video(
|
| 444 |
+
cur_latents.transpose(1, 2),
|
| 445 |
+
parallel=False,
|
| 446 |
+
show_progress_bar=False,
|
| 447 |
+
cond=cur_LQ_frame).transpose(1, 2).mul_(2).sub_(1)
|
| 448 |
+
|
| 449 |
+
# 颜色校正(wavelet)
|
| 450 |
+
try:
|
| 451 |
+
if color_fix:
|
| 452 |
+
cur_frames = self.ColorCorrector(
|
| 453 |
+
cur_frames.to(device=self.device),
|
| 454 |
+
cur_LQ_frame,
|
| 455 |
+
clip_range=(-1, 1),
|
| 456 |
+
chunk_size=None,
|
| 457 |
+
method='adain'
|
| 458 |
+
)
|
| 459 |
+
except:
|
| 460 |
+
pass
|
| 461 |
+
|
| 462 |
+
frames_total.append(cur_frames.to('cpu'))
|
| 463 |
+
LQ_pre_idx = LQ_cur_idx
|
| 464 |
+
|
| 465 |
+
if unload_dit:
|
| 466 |
+
del noise_pred_posi, cur_frames, cur_latents, cur_LQ_frame
|
| 467 |
+
clean_vram()
|
| 468 |
+
|
| 469 |
+
if hasattr(self.dit, "LQ_proj_in"):
|
| 470 |
+
self.dit.LQ_proj_in.clear_cache()
|
| 471 |
+
|
| 472 |
+
self.TCDecoder.clean_mem()
|
| 473 |
+
if force_offload:
|
| 474 |
+
self.offload_model()
|
| 475 |
+
|
| 476 |
+
frames = torch.cat(frames_total, dim=2)
|
| 477 |
+
|
| 478 |
+
return frames[0]
|
| 479 |
+
|
| 480 |
+
|
| 481 |
+
# -----------------------------
|
| 482 |
+
# TeaCache(保留原逻辑;此处默认不启用)
|
| 483 |
+
# -----------------------------
|
| 484 |
+
class TeaCache:
|
| 485 |
+
def __init__(self, num_inference_steps, rel_l1_thresh, model_id):
|
| 486 |
+
self.num_inference_steps = num_inference_steps
|
| 487 |
+
self.step = 0
|
| 488 |
+
self.accumulated_rel_l1_distance = 0
|
| 489 |
+
self.previous_modulated_input = None
|
| 490 |
+
self.rel_l1_thresh = rel_l1_thresh
|
| 491 |
+
self.previous_residual = None
|
| 492 |
+
self.previous_hidden_states = None
|
| 493 |
+
|
| 494 |
+
self.coefficients_dict = {
|
| 495 |
+
"Wan2.1-T2V-1.3B": [-5.21862437e+04, 9.23041404e+03, -5.28275948e+02, 1.36987616e+01, -4.99875664e-02],
|
| 496 |
+
"Wan2.1-T2V-14B": [-3.03318725e+05, 4.90537029e+04, -2.65530556e+03, 5.87365115e+01, -3.15583525e-01],
|
| 497 |
+
"Wan2.1-I2V-14B-480P": [2.57151496e+05, -3.54229917e+04, 1.40286849e+03, -1.35890334e+01, 1.32517977e-01],
|
| 498 |
+
"Wan2.1-I2V-14B-720P": [8.10705460e+03, 2.13393892e+03, -3.72934672e+02, 1.66203073e+01, -4.17769401e-02],
|
| 499 |
+
}
|
| 500 |
+
if model_id not in self.coefficients_dict:
|
| 501 |
+
supported_model_ids = ", ".join([i for i in self.coefficients_dict])
|
| 502 |
+
raise ValueError(f"{model_id} is not a supported TeaCache model id. Please choose a valid model id in ({supported_model_ids}).")
|
| 503 |
+
self.coefficients = self.coefficients_dict[model_id]
|
| 504 |
+
|
| 505 |
+
def check(self, dit: WanModel, x, t_mod):
|
| 506 |
+
modulated_inp = t_mod.clone()
|
| 507 |
+
if self.step == 0 or self.step == self.num_inference_steps - 1:
|
| 508 |
+
should_calc = True
|
| 509 |
+
self.accumulated_rel_l1_distance = 0
|
| 510 |
+
else:
|
| 511 |
+
coefficients = self.coefficients
|
| 512 |
+
rescale_func = np.poly1d(coefficients)
|
| 513 |
+
self.accumulated_rel_l1_distance += rescale_func(((modulated_inp-self.previous_modulated_input).abs().mean() / self.previous_modulated_input.abs().mean()).cpu().item())
|
| 514 |
+
should_calc = not (self.accumulated_rel_l1_distance < self.rel_l1_thresh)
|
| 515 |
+
if should_calc:
|
| 516 |
+
self.accumulated_rel_l1_distance = 0
|
| 517 |
+
self.previous_modulated_input = modulated_inp
|
| 518 |
+
self.step = (self.step + 1) % self.num_inference_steps
|
| 519 |
+
if should_calc:
|
| 520 |
+
self.previous_hidden_states = x.clone()
|
| 521 |
+
return not should_calc
|
| 522 |
+
|
| 523 |
+
def store(self, hidden_states):
|
| 524 |
+
self.previous_residual = hidden_states - self.previous_hidden_states
|
| 525 |
+
self.previous_hidden_states = None
|
| 526 |
+
|
| 527 |
+
def update(self, hidden_states):
|
| 528 |
+
hidden_states = hidden_states + self.previous_residual
|
| 529 |
+
return hidden_states
|
| 530 |
+
|
| 531 |
+
|
| 532 |
+
# -----------------------------
|
| 533 |
+
# 简化版模型前向封装(无 vace / 无 motion_controller)
|
| 534 |
+
# -----------------------------
|
| 535 |
+
def model_fn_wan_video(
|
| 536 |
+
dit: WanModel,
|
| 537 |
+
x: torch.Tensor,
|
| 538 |
+
timestep: torch.Tensor,
|
| 539 |
+
context: torch.Tensor,
|
| 540 |
+
tea_cache: Optional[TeaCache] = None,
|
| 541 |
+
use_unified_sequence_parallel: bool = False,
|
| 542 |
+
LQ_latents: Optional[torch.Tensor] = None,
|
| 543 |
+
is_full_block: bool = False,
|
| 544 |
+
is_stream: bool = False,
|
| 545 |
+
pre_cache_k: Optional[list[torch.Tensor]] = None,
|
| 546 |
+
pre_cache_v: Optional[list[torch.Tensor]] = None,
|
| 547 |
+
topk_ratio: float = 2.0,
|
| 548 |
+
kv_ratio: float = 3.0,
|
| 549 |
+
cur_process_idx: int = 0,
|
| 550 |
+
t_mod : torch.Tensor = None,
|
| 551 |
+
t : torch.Tensor = None,
|
| 552 |
+
local_range: int = 9,
|
| 553 |
+
**kwargs,
|
| 554 |
+
):
|
| 555 |
+
# patchify
|
| 556 |
+
x, (f, h, w) = dit.patchify(x)
|
| 557 |
+
|
| 558 |
+
win = (2, 8, 8)
|
| 559 |
+
seqlen = f // win[0]
|
| 560 |
+
local_num = seqlen
|
| 561 |
+
window_size = win[0] * h * w // 128
|
| 562 |
+
square_num = window_size * window_size
|
| 563 |
+
topk = int(square_num * topk_ratio) - 1
|
| 564 |
+
kv_len = int(kv_ratio)
|
| 565 |
+
|
| 566 |
+
# RoPE 位置(分段)
|
| 567 |
+
if cur_process_idx == 0:
|
| 568 |
+
freqs = torch.cat([
|
| 569 |
+
dit.freqs[0][:f].view(f, 1, 1, -1).expand(f, h, w, -1),
|
| 570 |
+
dit.freqs[1][:h].view(1, h, 1, -1).expand(f, h, w, -1),
|
| 571 |
+
dit.freqs[2][:w].view(1, 1, w, -1).expand(f, h, w, -1)
|
| 572 |
+
], dim=-1).reshape(f * h * w, 1, -1).to(x.device)
|
| 573 |
+
else:
|
| 574 |
+
freqs = torch.cat([
|
| 575 |
+
dit.freqs[0][4 + cur_process_idx*2:4 + cur_process_idx*2 + f].view(f, 1, 1, -1).expand(f, h, w, -1),
|
| 576 |
+
dit.freqs[1][:h].view(1, h, 1, -1).expand(f, h, w, -1),
|
| 577 |
+
dit.freqs[2][:w].view(1, 1, w, -1).expand(f, h, w, -1)
|
| 578 |
+
], dim=-1).reshape(f * h * w, 1, -1).to(x.device)
|
| 579 |
+
|
| 580 |
+
# TeaCache(默认不启用)
|
| 581 |
+
tea_cache_update = tea_cache.check(dit, x, t_mod) if tea_cache is not None else False
|
| 582 |
+
|
| 583 |
+
# 统一序列并行(此处默认关闭)
|
| 584 |
+
if use_unified_sequence_parallel:
|
| 585 |
+
import torch.distributed as dist
|
| 586 |
+
from xfuser.core.distributed import (get_sequence_parallel_rank,
|
| 587 |
+
get_sequence_parallel_world_size,
|
| 588 |
+
get_sp_group)
|
| 589 |
+
if dist.is_initialized() and dist.get_world_size() > 1:
|
| 590 |
+
x = torch.chunk(x, get_sequence_parallel_world_size(), dim=1)[get_sequence_parallel_rank()]
|
| 591 |
+
|
| 592 |
+
# Block 堆叠
|
| 593 |
+
if tea_cache_update:
|
| 594 |
+
x = tea_cache.update(x)
|
| 595 |
+
else:
|
| 596 |
+
for block_id, block in enumerate(dit.blocks):
|
| 597 |
+
if LQ_latents is not None and block_id < len(LQ_latents):
|
| 598 |
+
x = x + LQ_latents[block_id]
|
| 599 |
+
x, last_pre_cache_k, last_pre_cache_v = block(
|
| 600 |
+
x, context, t_mod, freqs, f, h, w,
|
| 601 |
+
local_num, topk,
|
| 602 |
+
block_id=block_id,
|
| 603 |
+
kv_len=kv_len,
|
| 604 |
+
is_full_block=is_full_block,
|
| 605 |
+
is_stream=is_stream,
|
| 606 |
+
pre_cache_k=pre_cache_k[block_id] if pre_cache_k is not None else None,
|
| 607 |
+
pre_cache_v=pre_cache_v[block_id] if pre_cache_v is not None else None,
|
| 608 |
+
local_range = local_range,
|
| 609 |
+
)
|
| 610 |
+
if pre_cache_k is not None: pre_cache_k[block_id] = last_pre_cache_k
|
| 611 |
+
if pre_cache_v is not None: pre_cache_v[block_id] = last_pre_cache_v
|
| 612 |
+
|
| 613 |
+
x = dit.head(x, t)
|
| 614 |
+
if use_unified_sequence_parallel:
|
| 615 |
+
import torch.distributed as dist
|
| 616 |
+
from xfuser.core.distributed import get_sp_group
|
| 617 |
+
if dist.is_initialized() and dist.get_world_size() > 1:
|
| 618 |
+
x = get_sp_group().all_gather(x, dim=1)
|
| 619 |
+
x = dit.unpatchify(x, (f, h, w))
|
| 620 |
+
return x, pre_cache_k, pre_cache_v
|
custom_nodes/ComfyUI-FlashVSR_Ultra_Fast/src/schedulers/__init__.py
ADDED
|
@@ -0,0 +1 @@
|
|
|
|
|
|
|
| 1 |
+
from .flow_match import FlowMatchScheduler
|
custom_nodes/ComfyUI-FlashVSR_Ultra_Fast/src/schedulers/flow_match.py
ADDED
|
@@ -0,0 +1,79 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
|
| 2 |
+
|
| 3 |
+
|
| 4 |
+
|
| 5 |
+
class FlowMatchScheduler():
|
| 6 |
+
|
| 7 |
+
def __init__(self, num_inference_steps=100, num_train_timesteps=1000, shift=3.0, sigma_max=1.0, sigma_min=0.003/1.002, inverse_timesteps=False, extra_one_step=False, reverse_sigmas=False):
|
| 8 |
+
self.num_train_timesteps = num_train_timesteps
|
| 9 |
+
self.shift = shift
|
| 10 |
+
self.sigma_max = sigma_max
|
| 11 |
+
self.sigma_min = sigma_min
|
| 12 |
+
self.inverse_timesteps = inverse_timesteps
|
| 13 |
+
self.extra_one_step = extra_one_step
|
| 14 |
+
self.reverse_sigmas = reverse_sigmas
|
| 15 |
+
self.set_timesteps(num_inference_steps)
|
| 16 |
+
|
| 17 |
+
|
| 18 |
+
def set_timesteps(self, num_inference_steps=100, denoising_strength=1.0, training=False, shift=None):
|
| 19 |
+
if shift is not None:
|
| 20 |
+
self.shift = shift
|
| 21 |
+
sigma_start = self.sigma_min + (self.sigma_max - self.sigma_min) * denoising_strength
|
| 22 |
+
if self.extra_one_step:
|
| 23 |
+
self.sigmas = torch.linspace(sigma_start, self.sigma_min, num_inference_steps + 1)[:-1]
|
| 24 |
+
else:
|
| 25 |
+
self.sigmas = torch.linspace(sigma_start, self.sigma_min, num_inference_steps)
|
| 26 |
+
if self.inverse_timesteps:
|
| 27 |
+
self.sigmas = torch.flip(self.sigmas, dims=[0])
|
| 28 |
+
self.sigmas = self.shift * self.sigmas / (1 + (self.shift - 1) * self.sigmas)
|
| 29 |
+
if self.reverse_sigmas:
|
| 30 |
+
self.sigmas = 1 - self.sigmas
|
| 31 |
+
self.timesteps = self.sigmas * self.num_train_timesteps
|
| 32 |
+
if training:
|
| 33 |
+
x = self.timesteps
|
| 34 |
+
y = torch.exp(-2 * ((x - num_inference_steps / 2) / num_inference_steps) ** 2)
|
| 35 |
+
y_shifted = y - y.min()
|
| 36 |
+
bsmntw_weighing = y_shifted * (num_inference_steps / y_shifted.sum())
|
| 37 |
+
self.linear_timesteps_weights = bsmntw_weighing
|
| 38 |
+
|
| 39 |
+
|
| 40 |
+
def step(self, model_output, timestep, sample, to_final=False, **kwargs):
|
| 41 |
+
if isinstance(timestep, torch.Tensor):
|
| 42 |
+
timestep = timestep.cpu()
|
| 43 |
+
timestep_id = torch.argmin((self.timesteps - timestep).abs())
|
| 44 |
+
sigma = self.sigmas[timestep_id]
|
| 45 |
+
if to_final or timestep_id + 1 >= len(self.timesteps):
|
| 46 |
+
sigma_ = 1 if (self.inverse_timesteps or self.reverse_sigmas) else 0
|
| 47 |
+
else:
|
| 48 |
+
sigma_ = self.sigmas[timestep_id + 1]
|
| 49 |
+
prev_sample = sample + model_output * (sigma_ - sigma)
|
| 50 |
+
return prev_sample
|
| 51 |
+
|
| 52 |
+
|
| 53 |
+
def return_to_timestep(self, timestep, sample, sample_stablized):
|
| 54 |
+
if isinstance(timestep, torch.Tensor):
|
| 55 |
+
timestep = timestep.cpu()
|
| 56 |
+
timestep_id = torch.argmin((self.timesteps - timestep).abs())
|
| 57 |
+
sigma = self.sigmas[timestep_id]
|
| 58 |
+
model_output = (sample - sample_stablized) / sigma
|
| 59 |
+
return model_output
|
| 60 |
+
|
| 61 |
+
|
| 62 |
+
def add_noise(self, original_samples, noise, timestep):
|
| 63 |
+
if isinstance(timestep, torch.Tensor):
|
| 64 |
+
timestep = timestep.cpu()
|
| 65 |
+
timestep_id = torch.argmin((self.timesteps - timestep).abs())
|
| 66 |
+
sigma = self.sigmas[timestep_id]
|
| 67 |
+
sample = (1 - sigma) * original_samples + sigma * noise
|
| 68 |
+
return sample
|
| 69 |
+
|
| 70 |
+
|
| 71 |
+
def training_target(self, sample, noise, timestep):
|
| 72 |
+
target = noise - sample
|
| 73 |
+
return target
|
| 74 |
+
|
| 75 |
+
|
| 76 |
+
def training_weight(self, timestep):
|
| 77 |
+
timestep_id = torch.argmin((self.timesteps - timestep.to(self.timesteps.device)).abs())
|
| 78 |
+
weights = self.linear_timesteps_weights[timestep_id]
|
| 79 |
+
return weights
|
custom_nodes/ComfyUI-FlashVSR_Ultra_Fast/src/vram_management/__init__.py
ADDED
|
@@ -0,0 +1 @@
|
|
|
|
|
|
|
| 1 |
+
from .layers import *
|
custom_nodes/ComfyUI-FlashVSR_Ultra_Fast/src/vram_management/layers.py
ADDED
|
@@ -0,0 +1,95 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch, copy
|
| 2 |
+
from ..models.utils import init_weights_on_device
|
| 3 |
+
|
| 4 |
+
|
| 5 |
+
def cast_to(weight, dtype, device):
|
| 6 |
+
r = torch.empty_like(weight, dtype=dtype, device=device)
|
| 7 |
+
r.copy_(weight)
|
| 8 |
+
return r
|
| 9 |
+
|
| 10 |
+
|
| 11 |
+
class AutoWrappedModule(torch.nn.Module):
|
| 12 |
+
def __init__(self, module: torch.nn.Module, offload_dtype, offload_device, onload_dtype, onload_device, computation_dtype, computation_device):
|
| 13 |
+
super().__init__()
|
| 14 |
+
self.module = module.to(dtype=offload_dtype, device=offload_device)
|
| 15 |
+
self.offload_dtype = offload_dtype
|
| 16 |
+
self.offload_device = offload_device
|
| 17 |
+
self.onload_dtype = onload_dtype
|
| 18 |
+
self.onload_device = onload_device
|
| 19 |
+
self.computation_dtype = computation_dtype
|
| 20 |
+
self.computation_device = computation_device
|
| 21 |
+
self.state = 0
|
| 22 |
+
|
| 23 |
+
def offload(self):
|
| 24 |
+
if self.state == 1 and (self.offload_dtype != self.onload_dtype or self.offload_device != self.onload_device):
|
| 25 |
+
self.module.to(dtype=self.offload_dtype, device=self.offload_device)
|
| 26 |
+
self.state = 0
|
| 27 |
+
|
| 28 |
+
def onload(self):
|
| 29 |
+
if self.state == 0 and (self.offload_dtype != self.onload_dtype or self.offload_device != self.onload_device):
|
| 30 |
+
self.module.to(dtype=self.onload_dtype, device=self.onload_device)
|
| 31 |
+
self.state = 1
|
| 32 |
+
|
| 33 |
+
def forward(self, *args, **kwargs):
|
| 34 |
+
if self.onload_dtype == self.computation_dtype and self.onload_device == self.computation_device:
|
| 35 |
+
module = self.module
|
| 36 |
+
else:
|
| 37 |
+
module = copy.deepcopy(self.module).to(dtype=self.computation_dtype, device=self.computation_device)
|
| 38 |
+
return module(*args, **kwargs)
|
| 39 |
+
|
| 40 |
+
|
| 41 |
+
class AutoWrappedLinear(torch.nn.Linear):
|
| 42 |
+
def __init__(self, module: torch.nn.Linear, offload_dtype, offload_device, onload_dtype, onload_device, computation_dtype, computation_device):
|
| 43 |
+
with init_weights_on_device(device=torch.device("meta")):
|
| 44 |
+
super().__init__(in_features=module.in_features, out_features=module.out_features, bias=module.bias is not None, dtype=offload_dtype, device=offload_device)
|
| 45 |
+
self.weight = module.weight
|
| 46 |
+
self.bias = module.bias
|
| 47 |
+
self.offload_dtype = offload_dtype
|
| 48 |
+
self.offload_device = offload_device
|
| 49 |
+
self.onload_dtype = onload_dtype
|
| 50 |
+
self.onload_device = onload_device
|
| 51 |
+
self.computation_dtype = computation_dtype
|
| 52 |
+
self.computation_device = computation_device
|
| 53 |
+
self.state = 0
|
| 54 |
+
|
| 55 |
+
def offload(self):
|
| 56 |
+
if self.state == 1 and (self.offload_dtype != self.onload_dtype or self.offload_device != self.onload_device):
|
| 57 |
+
self.to(dtype=self.offload_dtype, device=self.offload_device)
|
| 58 |
+
self.state = 0
|
| 59 |
+
|
| 60 |
+
def onload(self):
|
| 61 |
+
if self.state == 0 and (self.offload_dtype != self.onload_dtype or self.offload_device != self.onload_device):
|
| 62 |
+
self.to(dtype=self.onload_dtype, device=self.onload_device)
|
| 63 |
+
self.state = 1
|
| 64 |
+
|
| 65 |
+
def forward(self, x, *args, **kwargs):
|
| 66 |
+
if self.onload_dtype == self.computation_dtype and self.onload_device == self.computation_device:
|
| 67 |
+
weight, bias = self.weight, self.bias
|
| 68 |
+
else:
|
| 69 |
+
weight = cast_to(self.weight, self.computation_dtype, self.computation_device)
|
| 70 |
+
bias = None if self.bias is None else cast_to(self.bias, self.computation_dtype, self.computation_device)
|
| 71 |
+
return torch.nn.functional.linear(x, weight, bias)
|
| 72 |
+
|
| 73 |
+
|
| 74 |
+
def enable_vram_management_recursively(model: torch.nn.Module, module_map: dict, module_config: dict, max_num_param=None, overflow_module_config: dict = None, total_num_param=0):
|
| 75 |
+
for name, module in model.named_children():
|
| 76 |
+
for source_module, target_module in module_map.items():
|
| 77 |
+
if isinstance(module, source_module):
|
| 78 |
+
num_param = sum(p.numel() for p in module.parameters())
|
| 79 |
+
if max_num_param is not None and total_num_param + num_param > max_num_param:
|
| 80 |
+
module_config_ = overflow_module_config
|
| 81 |
+
else:
|
| 82 |
+
module_config_ = module_config
|
| 83 |
+
module_ = target_module(module, **module_config_)
|
| 84 |
+
setattr(model, name, module_)
|
| 85 |
+
total_num_param += num_param
|
| 86 |
+
break
|
| 87 |
+
else:
|
| 88 |
+
total_num_param = enable_vram_management_recursively(module, module_map, module_config, max_num_param, overflow_module_config, total_num_param)
|
| 89 |
+
return total_num_param
|
| 90 |
+
|
| 91 |
+
|
| 92 |
+
def enable_vram_management(model: torch.nn.Module, module_map: dict, module_config: dict, max_num_param=None, overflow_module_config: dict = None):
|
| 93 |
+
enable_vram_management_recursively(model, module_map, module_config, max_num_param, overflow_module_config, total_num_param=0)
|
| 94 |
+
model.vram_management_enabled = True
|
| 95 |
+
|
custom_nodes/ComfyUI-LCS/.gitignore
ADDED
|
@@ -0,0 +1,2 @@
|
|
|
|
|
|
|
|
|
|
| 1 |
+
.claude/
|
| 2 |
+
__pycache__/
|
custom_nodes/ComfyUI-LCS/README.md
ADDED
|
@@ -0,0 +1,344 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# ComfyUI-LCS
|
| 2 |
+
|
| 3 |
+
Training-free color control via the **Latent Color Subspace**, plus **sharpness control** via a discovered sharpness subspace.
|
| 4 |
+
|
| 5 |
+
> **Note:** This is an unofficial community implementation. For the official code, see [ExplainableML/LCS](https://github.com/ExplainableML/LCS).
|
| 6 |
+
|
| 7 |
+
Based on ["The Latent Color Subspace"](https://arxiv.org/abs/2603.12261v1) (ICML 2026): color in diffusion model latent patch spaces lives in a **3D subspace** (PCA captures 100% color variance), while the remaining 61 dimensions encode structure and detail orthogonally.
|
| 8 |
+
|
| 9 |
+
This plugin steers colors directly in the 3D LCS during diffusion sampling — no training, no LoRA, no post-processing.
|
| 10 |
+
|
| 11 |
+
> [中文版 README](README_zh.md)
|
| 12 |
+
|
| 13 |
+
## LCS vs Traditional Post-Processing
|
| 14 |
+
|
| 15 |
+
LCS operates **during** diffusion sampling, not after — this is the key difference from traditional color grading (Photoshop, filters, etc.).
|
| 16 |
+
|
| 17 |
+
| | Traditional Post-Processing | LCS |
|
| 18 |
+
|---|---|---|
|
| 19 |
+
| **When** | After VAE decode, in pixel space | During sampling, in latent space |
|
| 20 |
+
| **Mechanism** | Color filter on the final image | Modifies 3D color subspace mid-generation |
|
| 21 |
+
| **Model awareness** | None — structure already locked | Model adapts to color shifts in subsequent steps |
|
| 22 |
+
| **Result** | Colors can look "painted on" | Colors look naturally intended by the model |
|
| 23 |
+
|
| 24 |
+
For example: to get a warm orange sunset, post-processing tints everything orange (muddying shadows and skin tones), while LCS nudges the color subspace early in sampling so clouds, lighting, and reflections are *coherently* warm.
|
| 25 |
+
|
| 26 |
+
The core insight: color and structure are **orthogonal** in the latent patch space — you can steer one without disturbing the other.
|
| 27 |
+
|
| 28 |
+
## Tested Models
|
| 29 |
+
|
| 30 |
+
| Model | Status |
|
| 31 |
+
|-------|--------|
|
| 32 |
+
| FLUX | Tested |
|
| 33 |
+
| FLUX2.klein | Tested |
|
| 34 |
+
| z-image | Tested |
|
| 35 |
+
| z-image-turbo | Tested |
|
| 36 |
+
| Wan (qwen-image) | Tested |
|
| 37 |
+
| LTX2.3 | Tested |
|
| 38 |
+
|
| 39 |
+
|
| 40 |
+
LCS calibrates per-VAE, so it should work with any model using a compatible VAE. Feel free to report results with other models.
|
| 41 |
+
|
| 42 |
+
## Features
|
| 43 |
+
|
| 44 |
+
- **Color Steering** — Push colors toward any target color
|
| 45 |
+
- **Batch Multi-Color** — Different colors per batch item
|
| 46 |
+
- **Tone Adjustment** — Contrast, brightness, saturation, temperature with one-click presets
|
| 47 |
+
- **Color Anchor** — Zero-config color drift correction: self-anchor, reference-based, or spatial smoothing with auto mode
|
| 48 |
+
- **Sharpness Control** — Sharpen or blur during generation via a discovered sharpness subspace (PC1 explains ~97% variance)
|
| 49 |
+
- **Localized Control** — Optional mask for region-specific changes
|
| 50 |
+
- **Latent Color Preview** — Visualize color structure without VAE decoding
|
| 51 |
+
- **Step Observer** — Per-step color previews for debugging
|
| 52 |
+
|
| 53 |
+
## Installation
|
| 54 |
+
|
| 55 |
+
```bash
|
| 56 |
+
cd ComfyUI/custom_nodes
|
| 57 |
+
git clone https://github.com/facok/ComfyUI-LCS.git
|
| 58 |
+
```
|
| 59 |
+
|
| 60 |
+
Dependencies (usually already present in ComfyUI):
|
| 61 |
+
|
| 62 |
+
```bash
|
| 63 |
+
pip install einops safetensors
|
| 64 |
+
```
|
| 65 |
+
|
| 66 |
+
## Quick Start
|
| 67 |
+
|
| 68 |
+
### Basic Color Control
|
| 69 |
+
|
| 70 |
+
```
|
| 71 |
+
LCS Load Data → LCS Color Intervene → KSampler
|
| 72 |
+
↑
|
| 73 |
+
(pick a color)
|
| 74 |
+
```
|
| 75 |
+
|
| 76 |
+
1. **LCS Load Data** — connect your VAE (auto-calibrates on first run)
|
| 77 |
+
2. **LCS Color Intervene** — connect MODEL and LCS_DATA, pick a target color
|
| 78 |
+
3. Connect the output MODEL to KSampler
|
| 79 |
+
|
| 80 |
+
### Tone Adjustment
|
| 81 |
+
|
| 82 |
+
```
|
| 83 |
+
LCS Load Data → LCS Tone Adjust → KSampler
|
| 84 |
+
```
|
| 85 |
+
|
| 86 |
+
1. **LCS Load Data** → **LCS Tone Adjust**
|
| 87 |
+
2. Select a preset (e.g., "Cinematic") or adjust sliders manually
|
| 88 |
+
|
| 89 |
+

|
| 90 |
+

|
| 91 |
+
|
| 92 |
+
### Sharpness Control
|
| 93 |
+
|
| 94 |
+
```
|
| 95 |
+
LCS Load Data ──→ LCS Sharpness Calibrate → LCS Sharpness Intervene → KSampler
|
| 96 |
+
↑ lcs_data
|
| 97 |
+
```
|
| 98 |
+
|
| 99 |
+
1. **LCS Sharpness Calibrate** — connect VAE (auto-calibrates and caches). Optionally connect `lcs_data` from LCS Load Data to ensure sharpness edits don't affect color.
|
| 100 |
+
2. **LCS Sharpness Intervene** — connect MODEL and SHARPNESS_DATA, set strength
|
| 101 |
+
- Positive strength → sharper
|
| 102 |
+
- Negative strength → blurrier
|
| 103 |
+
- 0 → no change
|
| 104 |
+

|
| 105 |
+
### Multi-Color Batch
|
| 106 |
+
|
| 107 |
+
```
|
| 108 |
+
LCS Load Data → LCS Color Batch → KSampler
|
| 109 |
+
↓
|
| 110 |
+
batch_size → EmptyLatentImage
|
| 111 |
+
```
|
| 112 |
+
|
| 113 |
+
Enter comma-separated hex colors (e.g., `#FF0000,#00FF00,#0000FF`). Each color applies to one batch item.
|
| 114 |
+
|
| 115 |
+
### Color Anchor (Zero-Config Drift Correction)
|
| 116 |
+
|
| 117 |
+
```
|
| 118 |
+
LCS Load Data → LCS Color Anchor → KSampler
|
| 119 |
+
```
|
| 120 |
+
|
| 121 |
+
1. **LCS Load Data** → **LCS Color Anchor** — connect MODEL and LCS_DATA
|
| 122 |
+
2. Set mode to **auto** (default) and leave intensity at default
|
| 123 |
+
3. Connect the output MODEL to KSampler
|
| 124 |
+
|
| 125 |
+
That's it. In `auto` mode, the node automatically selects the correction strategy based on which optional inputs are connected:
|
| 126 |
+
|
| 127 |
+
| Connected Inputs | Resolved Mode | Behavior |
|
| 128 |
+
|---|---|---|
|
| 129 |
+
| Nothing | self_anchor | Learns the image's color patterns early on, then prevents sudden color shifts |
|
| 130 |
+
| reference_image + vae | reference | Keeps generated colors close to your reference image |
|
| 131 |
+
| mask (no reference) | smooth | Smooths out color seams (great for inpainting) |
|
| 132 |
+
|
| 133 |
+
Intensity is also derived automatically from measured drift — no manual tuning needed.
|
| 134 |
+
|
| 135 |
+
> **When to use manual mode:** If you want full control, set mode to `smooth`, `reference`, or `self_anchor` explicitly and adjust the `intensity` slider (0–1). Auto mode is designed for zero-config "just works" usage.
|
| 136 |
+
|
| 137 |
+
## Nodes
|
| 138 |
+
|
| 139 |
+
### Calibration
|
| 140 |
+
|
| 141 |
+
| Node | Description |
|
| 142 |
+
|------|-------------|
|
| 143 |
+
| **LCS Load Data** | Auto-calibrate and cache LCS color data per-VAE. Fingerprints VAE weights for automatic cache management. |
|
| 144 |
+
| **LCS Sharpness Calibrate** | Discover sharpness subspace via PCA on blur stimuli. Optionally connect `lcs_data` for color-orthogonal sharpness. |
|
| 145 |
+
|
| 146 |
+
Calibration runs once per VAE and caches automatically. Subsequent runs load instantly.
|
| 147 |
+
|
| 148 |
+
### Intervention
|
| 149 |
+
|
| 150 |
+
| Node | Description |
|
| 151 |
+
|------|-------------|
|
| 152 |
+
| **LCS Color Intervene** | Steer colors toward a target. Supports Type I (LCS shift), Type II (HSL shift), or interpolated mode. |
|
| 153 |
+
| **LCS Color Batch** | Different target colors per batch item. Outputs `batch_size` for EmptyLatentImage. |
|
| 154 |
+
| **LCS Tone Adjust** | Contrast, brightness, saturation, temperature. Preset dropdown with real-time slider sync. |
|
| 155 |
+
| **LCS Color Anchor** | Correct color drift during sampling. Auto mode infers strategy and intensity from connected inputs. |
|
| 156 |
+
| **LCS Sharpness Intervene** | Control sharpness during generation. Positive = sharper, negative = blurrier. |
|
| 157 |
+
|
| 158 |
+
### Observation
|
| 159 |
+
|
| 160 |
+
| Node | Description |
|
| 161 |
+
|------|-------------|
|
| 162 |
+
| **LCS Preview Colors** | Decode latent colors to RGB preview without VAE decoding. |
|
| 163 |
+
| **LCS Step Observer** | Save per-step color preview PNGs to ComfyUI temp directory. |
|
| 164 |
+
|
| 165 |
+
## Intervention Modes
|
| 166 |
+
|
| 167 |
+
| Mode | Description | Best For |
|
| 168 |
+
|------|-------------|----------|
|
| 169 |
+
| **interpolated** (default) | Blends Type I and Type II using sigma | General use |
|
| 170 |
+
| **type_i** | Direct translation in 3D LCS space | Strong global color shifts |
|
| 171 |
+
| **type_ii** | Per-patch HSL interpolation via bicone geometry | Precise local color control |
|
| 172 |
+
|
| 173 |
+
## Key Parameters
|
| 174 |
+
|
| 175 |
+
### Color Intervention
|
| 176 |
+
- **strength** (0.0–2.0): Intervention intensity. 1.0 = full, 0.0 = none.
|
| 177 |
+
- **start_step / end_step**: Step range for intervention. Paper optimal: steps 8–10 of 50.
|
| 178 |
+
- **mask**: Optional. Downsampled to patch grid for localized control.
|
| 179 |
+
|
| 180 |
+
### Sharpness Intervention
|
| 181 |
+
- **strength** (-5.0–5.0): Positive = sharper, negative = blurrier, 0 = no change.
|
| 182 |
+
- **start_step / end_step**: Step range (default 5–15).
|
| 183 |
+
- **mask**: Optional. Localized sharpness control.
|
| 184 |
+
|
| 185 |
+
> **Tip for distilled models**: Step-distilled models (e.g., z-image-turbo) use far fewer steps, so intervention should start earlier — even from step 0.
|
| 186 |
+
|
| 187 |
+
### Color Anchor
|
| 188 |
+
|
| 189 |
+
Sometimes diffusion models produce unexpected color shifts during sampling — a blue sky suddenly turns purple, or inpainting leaves visible color seams. The Color Anchor node fixes these problems by monitoring and correcting colors as the image is being generated.
|
| 190 |
+
|
| 191 |
+
**Modes:**
|
| 192 |
+
|
| 193 |
+
| Mode | What it does | When to use |
|
| 194 |
+
|------|-------------|----------|
|
| 195 |
+
| **auto** (default) | Looks at what you connected and picks the best strategy for you | Just want it to work, no config needed |
|
| 196 |
+
| **self_anchor** | Watches how colors evolve in early steps, then prevents sudden color jumps in later steps | General color stability, no reference needed |
|
| 197 |
+
| **reference** | Keeps the generated image's colors close to a reference image you provide | "Make it look like this photo's color palette" |
|
| 198 |
+
| **smooth** | Smooths out abrupt color boundaries between regions | Fixing visible seams after inpainting |
|
| 199 |
+
|
| 200 |
+
**How auto mode picks for you:**
|
| 201 |
+
|
| 202 |
+
1. **Which strategy?** Based on what you plugged in:
|
| 203 |
+
- Connected a reference image + VAE → uses `reference`
|
| 204 |
+
- Connected a mask (but no reference) → uses `smooth`
|
| 205 |
+
- Connected nothing extra → uses `self_anchor`
|
| 206 |
+
2. **How strong?** The node measures how much color drift is actually happening, then sets the correction strength accordingly. Big drift → stronger fix. Small drift → gentle touch. The range is 0.15–0.6, so it never over-corrects or does nothing.
|
| 207 |
+
|
| 208 |
+
**What happens during sampling:**
|
| 209 |
+
|
| 210 |
+
The node runs at every sampling step but doesn't always intervene. It automatically figures out which steps are safe to correct:
|
| 211 |
+
|
| 212 |
+
1. **Early steps** (image is mostly noise) — Too early to fix colors without creating artifacts. Skipped. In self_anchor mode, the node uses these steps to *learn* the image's color patterns.
|
| 213 |
+
2. **Middle steps** (image is taking shape) — The sweet spot. The node applies corrections here, ramping smoothly in and out to avoid sudden changes.
|
| 214 |
+
3. **Late steps** (fine details) — Corrections would disturb fine detail. Skipped.
|
| 215 |
+
|
| 216 |
+
Only colors are modified — structure, texture, and detail are never touched.
|
| 217 |
+
|
| 218 |
+
**Parameters:**
|
| 219 |
+
|
| 220 |
+
- **mode**: `auto`, `smooth`, `reference`, or `self_anchor`
|
| 221 |
+
- **intensity** (0.0–1.0): How strong the correction is. In `auto` mode this is determined automatically. Set to 0 to disable the node entirely.
|
| 222 |
+
- **vae** (optional): Needed for `reference` mode to encode the reference image
|
| 223 |
+
- **reference_image** (optional): The image whose colors you want to match
|
| 224 |
+
- **mask** (optional): Only correct colors inside the masked area
|
| 225 |
+
|
| 226 |
+
## Tone Presets
|
| 227 |
+
|
| 228 |
+
Select a preset — sliders update in real-time. Tweak after selecting for fine-tuning. Select **Custom** to set values manually.
|
| 229 |
+
|
| 230 |
+
| Preset | Contrast | Brightness | Saturation | Temperature |
|
| 231 |
+
|--------|----------|------------|------------|-------------|
|
| 232 |
+
| Base | 1.0 | 0.0 | 1.0 | 0.0 |
|
| 233 |
+
| Cinematic | 1.20 | -0.05 | 0.90 | 0.05 |
|
| 234 |
+
| HDR | 1.40 | 0.0 | 1.20 | 0.0 |
|
| 235 |
+
| Vivid | 1.10 | 0.0 | 1.50 | 0.0 |
|
| 236 |
+
| Dramatic | 1.50 | -0.10 | 0.85 | 0.0 |
|
| 237 |
+
| Low Key | 1.30 | -0.20 | 0.80 | 0.0 |
|
| 238 |
+
| High Key | 0.80 | 0.20 | 0.90 | 0.0 |
|
| 239 |
+
| Warm | 1.05 | 0.03 | 1.10 | 0.30 |
|
| 240 |
+
| Cool | 1.05 | 0.0 | 1.05 | -0.30 |
|
| 241 |
+
| Desaturated | 1.0 | 0.0 | 0.40 | 0.0 |
|
| 242 |
+
|
| 243 |
+
## How It Works
|
| 244 |
+
|
| 245 |
+
### Color (LCS)
|
| 246 |
+
|
| 247 |
+
1. **Project** — Convert denoised prediction to 64D patch space, project onto 3D LCS basis
|
| 248 |
+
2. **Decompose** — Separate 3D color coordinates from the 61D structural residual
|
| 249 |
+
3. **Normalize** — Transform to reference timestep (t=50) using learned alpha/beta statistics
|
| 250 |
+
4. **Manipulate** — Shift colors, adjust tone, or apply other transformations in 3D LCS
|
| 251 |
+
5. **Reconstruct** — Denormalize, add back the preserved 61D residual, convert to latent space
|
| 252 |
+
|
| 253 |
+
The 61D residual (structure, texture, detail) is never modified — only the 3D color subspace is touched.
|
| 254 |
+
|
| 255 |
+
### Sharpness
|
| 256 |
+
|
| 257 |
+
Sharpness lives in a separate subspace orthogonal to color:
|
| 258 |
+
|
| 259 |
+
1. **Calibrate** — Generate grayscale noise images at multiple blur levels, VAE-encode, PCA on color-removed patch vectors. PC1 captures ~97% of sharpness variance.
|
| 260 |
+
2. **Intervene** — Add `strength * pc1_direction` to each patch. Since pc1_direction is orthogonal to color (calibrated with LCS removal) and DC-free (per-vector zero-mean before PCA), this modifies only spatial frequency content without affecting color or brightness.
|
| 261 |
+
|
| 262 |
+
### Color Anchor
|
| 263 |
+
|
| 264 |
+
The Color Anchor stabilizes colors without pushing them toward a specific target — it prevents drift from what the model is already generating:
|
| 265 |
+
|
| 266 |
+
1. **Decide when to act** — The node checks each sampling step: is the image still mostly noise (too early), taking shape (good time to correct), or nearly finished (too late)? It only corrects during the safe middle window.
|
| 267 |
+
2. **Learn the color pattern** (self_anchor) — During early noisy steps, the node watches how colors relate to their neighbors and builds a running average of these relationships. This is more reliable than tracking absolute colors, which shift naturally as the image forms.
|
| 268 |
+
3. **Measure drift** — On the first correction step, the node measures how much the colors have actually drifted (varies by mode: step-to-step jumps, distance from reference, or spatial roughness). This sets the correction strength in auto mode.
|
| 269 |
+
4. **Apply gentle corrections** — Corrections ramp smoothly in and out (no sudden jumps). Each mode corrects differently: self_anchor fixes patches that deviate from learned patterns, reference pulls toward the reference image's colors, smooth blurs out sharp color boundaries.
|
| 270 |
+
5. **Preserve everything else** — As with all LCS operations, only the 3D color coordinates change. Structure, texture, and detail are untouched.
|
| 271 |
+
|
| 272 |
+
## File Structure
|
| 273 |
+
|
| 274 |
+
```
|
| 275 |
+
ComfyUI-LCS/
|
| 276 |
+
├── __init__.py # Entry point (V3 + V2 compat)
|
| 277 |
+
├── requirements.txt
|
| 278 |
+
├── core/
|
| 279 |
+
│ ├── adaptive.py # Adaptive scheduling (phases, envelopes, drift estimation)
|
| 280 |
+
│ ├── bilateral.py # Bilateral filter for LCS color smoothing
|
| 281 |
+
│ ├── calibration.py # PCA calibration pipeline (color)
|
| 282 |
+
│ ├── color_space.py # Bicone LCS ↔ HSL mapping
|
| 283 |
+
│ ├── defaults.py # Alpha/beta tables from paper
|
| 284 |
+
│ ├── lcs_data.py # LCSData dataclass
|
| 285 |
+
│ ├── patchify.py # Patch ↔ latent conversion
|
| 286 |
+
│ ├── relationships.py # Local color relationship analysis & anomaly detection
|
| 287 |
+
│ ├── sampling.py # Shared constants & step utilities
|
| 288 |
+
│ ├── sharpness.py # Sharpness subspace calibration
|
| 289 |
+
│ └── timestep.py # Sigma/timestep utilities
|
| 290 |
+
├── nodes/
|
| 291 |
+
│ ├── anchor.py # LCSColorAnchor (adaptive color drift correction)
|
| 292 |
+
│ ├── calibrate.py # LCSLoadData (auto-calibrate + cache)
|
| 293 |
+
│ ├── intervene.py # LCSColorIntervene, LCSColorBatch, LCSToneAdjust
|
| 294 |
+
│ ├── observe.py # LCSPreviewColors, LCSStepObserver
|
| 295 |
+
│ └── sharpen.py # LCSSharpnessCalibrate, LCSSharpnessIntervene
|
| 296 |
+
├── data/ # Cached calibration files
|
| 297 |
+
└── web/js/
|
| 298 |
+
└── tone_preset.js # Frontend preset sync
|
| 299 |
+
```
|
| 300 |
+
|
| 301 |
+
## Changelog
|
| 302 |
+
|
| 303 |
+
### 2026-03-21
|
| 304 |
+
- **Color Anchor: auto mode** — New `auto` mode that infers correction strategy (self_anchor / reference / smooth) from connected inputs and derives intensity from measured drift. Zero-config usage.
|
| 305 |
+
- **Color Anchor: adaptive scheduling** — Phase assignment (observe/correct/skip) and strength envelope are derived from the sigma schedule at runtime.
|
| 306 |
+
|
| 307 |
+
### 2026-03-20
|
| 308 |
+
- **Sharpness Control** — New sharpness subspace discovered via PCA on blur stimuli. `LCS Sharpness Calibrate` + `LCS Sharpness Intervene` nodes. PC1 explains ~97% variance, orthogonal to color.
|
| 309 |
+
- **Color-orthogonal sharpness** — Optional `lcs_data` input removes color component during sharpness calibration, preventing color shift.
|
| 310 |
+
|
| 311 |
+
### 2026-03-19
|
| 312 |
+
- **Video VAE support (Wan)** — Handle 5D video latents in patchify/unpatchify. Per-image VAE encoding fallback for video VAEs.
|
| 313 |
+
- **LTXV compatibility** — Pad odd spatial dims in patchify, handle 3D tensors, skip gracefully for incompatible formats.
|
| 314 |
+
- **FLUX2 support** — Auto-detect 128-channel latents in unpatchify.
|
| 315 |
+
- **Universal latent format** — Use model's `latent_format` for space conversion instead of hardcoded FLUX constants.
|
| 316 |
+
|
| 317 |
+
### 2026-03-18
|
| 318 |
+
- **Tone Adjust** — `LCS Tone Adjust` node with contrast, brightness, saturation, temperature sliders. 10 presets with frontend real-time sync.
|
| 319 |
+
- **Color temperature** — Warm/cool shift along LCS blue-yellow axis.
|
| 320 |
+
- **Bicone HSL geometry** — Correct Type II intervention via bicone LCS-to-HSL mapping.
|
| 321 |
+
|
| 322 |
+
### 2026-03-17
|
| 323 |
+
- **Initial release** — Color steering (Type I + Type II + interpolated), batch multi-color, localized mask control, latent color preview, step observer. Per-VAE auto-calibration with caching.
|
| 324 |
+
|
| 325 |
+
## Citation
|
| 326 |
+
|
| 327 |
+
Official repository: [ExplainableML/LCS](https://github.com/ExplainableML/LCS)
|
| 328 |
+
|
| 329 |
+
```bibtex
|
| 330 |
+
@article{pach2026latentcolorsubspace,
|
| 331 |
+
title={The Latent Color Subspace: Emergent Order in High-Dimensional Chaos},
|
| 332 |
+
author={Mateusz Pach and Jessica Bader and Quentin Bouniot and Serge Belongie and Zeynep Akata},
|
| 333 |
+
journal={arxiv},
|
| 334 |
+
year={2026}
|
| 335 |
+
}
|
| 336 |
+
```
|
| 337 |
+
|
| 338 |
+
## Acknowledgments
|
| 339 |
+
|
| 340 |
+
Thanks to Mateusz Pach, Jessica Bader, Quentin Bouniot, Serge Belongie, and Zeynep Akata for their research making training-free color control possible.
|
| 341 |
+
|
| 342 |
+
## License
|
| 343 |
+
|
| 344 |
+
MIT
|
custom_nodes/ComfyUI-LCS/README_zh.md
ADDED
|
@@ -0,0 +1,343 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# ComfyUI-LCS
|
| 2 |
+
|
| 3 |
+
基于**潜在颜色子空间**(Latent Color Subspace)的免训练颜色控制,以及基于发现的**锐度子空间**的锐度控制。
|
| 4 |
+
|
| 5 |
+
> **注意:** 本项目为非官方社区实现。官方代码见 [ExplainableML/LCS](https://github.com/ExplainableML/LCS)。
|
| 6 |
+
|
| 7 |
+
基于论文 ["The Latent Color Subspace"](https://arxiv.org/abs/2603.12261v1)(ICML 2026):扩散模型潜在 patch 空间中的颜色完全存在于一个 **3 维子空间**(PCA 捕获 100% 颜色方差),剩余 61 维编码结构与细节,与颜色正交。
|
| 8 |
+
|
| 9 |
+
本插件在扩散采样过程中直接操作 3D LCS 控制颜色——无需训练、无需 LoRA、无需后处理。
|
| 10 |
+
|
| 11 |
+
> [English README](README.md)
|
| 12 |
+
|
| 13 |
+
## LCS 与传统后处理调色的区别
|
| 14 |
+
|
| 15 |
+
LCS 在扩散采样**过程中**操作,而非生成之后——这是与传统调色(Photoshop、滤镜等)的根本区别。
|
| 16 |
+
|
| 17 |
+
| | 传统后处理 | LCS |
|
| 18 |
+
|---|---|---|
|
| 19 |
+
| **时机** | VAE 解码后,像素空间 | 采样过程中,潜在空间 |
|
| 20 |
+
| **机制** | 对成品图像施加颜色滤镜 | 在生成中途修改 3D 颜色子空间 |
|
| 21 |
+
| **模型感知** | 无——结构已定型 | 模型在后续步骤中自适应颜色偏移 |
|
| 22 |
+
| **效果** | 颜色容易显得"涂上去的" | 颜色与内容自然融合 |
|
| 23 |
+
|
| 24 |
+
例:想要暖橙色日落,后处理会给全图叠橙色(阴影和肤色变脏),而 LCS 在采样早期推动颜色子空间,模型生成的云层、光照、反射与暖色调**内在一致**。
|
| 25 |
+
|
| 26 |
+
核心发现:颜色与结构在潜在 patch 空间中**正交**——可以单独控制颜色而不干扰结构。
|
| 27 |
+
|
| 28 |
+
## 已测试模型
|
| 29 |
+
|
| 30 |
+
| 模型 | 状态 |
|
| 31 |
+
|------|------|
|
| 32 |
+
| FLUX | 已测试 |
|
| 33 |
+
| FLUX2.klein | 已测试 |
|
| 34 |
+
| z-image | 已测试 |
|
| 35 |
+
| z-image-turbo | 已测试 |
|
| 36 |
+
| Wan (qwen-image) | 已测试 |
|
| 37 |
+
| LTX2.3 | 已测试 |
|
| 38 |
+
|
| 39 |
+
LCS 按 VAE 校准,理论上适用于任何使用兼容 VAE 架构的模型。欢迎反馈其他模型的测试结果。
|
| 40 |
+
|
| 41 |
+
## 功能
|
| 42 |
+
|
| 43 |
+
- **颜色引导** — 将颜色推向任意目标色
|
| 44 |
+
- **批量多色** — 为批次中每张图像指定不同颜色
|
| 45 |
+
- **色调调整** — 对比度、亮度、饱和度、色温,支持一键预设
|
| 46 |
+
- **颜色锚定** — 零配置颜色漂移校正:自锚定、参考图锚定、空间平滑,支持全自动模式
|
| 47 |
+
- **锐度控制** — 在生成过程中增强或减弱锐度,基于发现的锐度子空间(PC1 解释 ~97% 方差)
|
| 48 |
+
- **局部控制** — 可选遮罩,实现区域性变化
|
| 49 |
+
- **潜在颜色预览** — 无需 VAE 解码即可可视化颜色结构
|
| 50 |
+
- **步骤观察器** — 保存每步颜色预览,用于调试
|
| 51 |
+
|
| 52 |
+
## 安装
|
| 53 |
+
|
| 54 |
+
```bash
|
| 55 |
+
cd ComfyUI/custom_nodes
|
| 56 |
+
git clone https://github.com/facok/ComfyUI-LCS.git
|
| 57 |
+
```
|
| 58 |
+
|
| 59 |
+
依赖(通常 ComfyUI 已自带):
|
| 60 |
+
|
| 61 |
+
```bash
|
| 62 |
+
pip install einops safetensors
|
| 63 |
+
```
|
| 64 |
+
|
| 65 |
+
## 快速开始
|
| 66 |
+
|
| 67 |
+
### 基本颜色控制
|
| 68 |
+
|
| 69 |
+
```
|
| 70 |
+
LCS Load Data → LCS Color Intervene → KSampler
|
| 71 |
+
↑
|
| 72 |
+
(选择颜色)
|
| 73 |
+
```
|
| 74 |
+
|
| 75 |
+
1. **LCS Load Data** — 连接 VAE(首次运行自动校准)
|
| 76 |
+
2. **LCS Color Intervene** — 连接 MODEL 和 LCS_DATA,选择目标颜色
|
| 77 |
+
3. 将输出 MODEL 连接到 KSampler
|
| 78 |
+
|
| 79 |
+
### 色调调整
|
| 80 |
+
|
| 81 |
+
```
|
| 82 |
+
LCS Load Data → LCS Tone Adjust → KSampler
|
| 83 |
+
```
|
| 84 |
+
|
| 85 |
+
1. **LCS Load Data** → **LCS Tone Adjust**
|
| 86 |
+
2. 选择预设(如 "Cinematic")或手动调整滑条
|
| 87 |
+
|
| 88 |
+

|
| 89 |
+

|
| 90 |
+
### 锐度控制
|
| 91 |
+
|
| 92 |
+
```
|
| 93 |
+
LCS Load Data ──→ LCS Sharpness Calibrate → LCS Sharpness Intervene → KSampler
|
| 94 |
+
↑ lcs_data
|
| 95 |
+
```
|
| 96 |
+
|
| 97 |
+
1. **LCS Sharpness Calibrate** — 连接 VAE(首次运行自动校准并缓存)。可选连接 `lcs_data`(来自 LCS Load Data),确保锐度编辑不影响颜色。
|
| 98 |
+
2. **LCS Sharpness Intervene** — 连接 MODEL 和 SHARPNESS_DATA,设置强度
|
| 99 |
+
- 正值 → 更锐利
|
| 100 |
+
- 负值 → 更模糊
|
| 101 |
+
- 0 → 无变化
|
| 102 |
+

|
| 103 |
+
|
| 104 |
+
### 批量多色生成
|
| 105 |
+
|
| 106 |
+
```
|
| 107 |
+
LCS Load Data → LCS Color Batch → KSampler
|
| 108 |
+
↓
|
| 109 |
+
batch_size → EmptyLatentImage
|
| 110 |
+
```
|
| 111 |
+
|
| 112 |
+
输入逗号分隔的十六进制颜色(如 `#FF0000,#00FF00,#0000FF`),每个颜色对应一个批次项。
|
| 113 |
+
|
| 114 |
+
### 颜色锚定(零配置漂移校正)
|
| 115 |
+
|
| 116 |
+
```
|
| 117 |
+
LCS Load Data → LCS Color Anchor → KSampler
|
| 118 |
+
```
|
| 119 |
+
|
| 120 |
+
1. **LCS Load Data** → **LCS Color Anchor** — 连接 MODEL 和 LCS_DATA
|
| 121 |
+
2. 模式设为 **auto**(默认),intensity 保持默认值
|
| 122 |
+
3. 将输出 MODEL 连接到 KSampler
|
| 123 |
+
|
| 124 |
+
完成。在 `auto` 模式下,节点根据连接的可选输入自动选择校正策略:
|
| 125 |
+
|
| 126 |
+
| 已连接输入 | 解析模式 | 行为 |
|
| 127 |
+
|---|---|---|
|
| 128 |
+
| 无 | self_anchor | 在早期学习图像的颜色规律,然后防止突然的颜色偏移 |
|
| 129 |
+
| reference_image + vae | reference | 让生成的颜色贴近你的参考图 |
|
| 130 |
+
| mask(无参考图) | smooth | 平滑颜色接缝(很适合修复/补绘) |
|
| 131 |
+
|
| 132 |
+
intensity 也会根据实测漂移自动推导——无需手动调参。
|
| 133 |
+
|
| 134 |
+
> **手动模式:** 如果需要完全控制,可以将模式设为 `smooth`、`reference` 或 `self_anchor`,并手动调节 `intensity` 滑条(0–1)。auto 模式适合零配置「开箱即用」场景。
|
| 135 |
+
|
| 136 |
+
## 节点一览
|
| 137 |
+
|
| 138 |
+
### 校准
|
| 139 |
+
|
| 140 |
+
| 节点 | 说明 |
|
| 141 |
+
|------|------|
|
| 142 |
+
| **LCS Load Data** | 自动校准并按 VAE 缓存 LCS 颜色数据。通过 VAE 权重指纹自动管理缓存。 |
|
| 143 |
+
| **LCS Sharpness Calibrate** | 通过模糊刺激 PCA 发现锐度子空间。可选连接 `lcs_data` 使锐度正交于颜色。 |
|
| 144 |
+
|
| 145 |
+
每个 VAE 只需校准一次,结果自动缓存,后续运行瞬时加载。
|
| 146 |
+
|
| 147 |
+
### 干预
|
| 148 |
+
|
| 149 |
+
| 节点 | 说明 |
|
| 150 |
+
|------|------|
|
| 151 |
+
| **LCS Color Intervene** | 将颜色引导至目标色。支持 Type I(LCS 平移)、Type II(HSL 偏移)或插值模式。 |
|
| 152 |
+
| **LCS Color Batch** | 每个批次项施加不同目标颜色。输出 `batch_size` 可连接 EmptyLatentImage。 |
|
| 153 |
+
| **LCS Tone Adjust** | 对比度、亮度、饱和度、色温调整。预设下拉菜单,滑条实时同步。 |
|
| 154 |
+
| **LCS Color Anchor** | 采样过程中校正颜色漂移。auto 模式根据连接输入自动推断策略和强度。 |
|
| 155 |
+
| **LCS Sharpness Intervene** | 在生成过程中控制锐度。正值 = 更锐利,负值 = 更模糊。 |
|
| 156 |
+
|
| 157 |
+
### 观察
|
| 158 |
+
|
| 159 |
+
| 节点 | 说明 |
|
| 160 |
+
|------|------|
|
| 161 |
+
| **LCS Preview Colors** | 将潜在颜色解码为 RGB 预览图,无需 VAE 解码。 |
|
| 162 |
+
| **LCS Step Observer** | 将每步颜色预览 PNG 保存至 ComfyUI 临时目录。 |
|
| 163 |
+
|
| 164 |
+
## 干预模式
|
| 165 |
+
|
| 166 |
+
| 模式 | 说明 | 适用场景 |
|
| 167 |
+
|------|------|----------|
|
| 168 |
+
| **interpolated**(默认) | 以 sigma 为权重混合 Type I 和 Type II | 通用场景 |
|
| 169 |
+
| **type_i** | 3D LCS 空间中的直接平移 | 强烈的全局颜色偏移 |
|
| 170 |
+
| **type_ii** | 通过双锥几何进行逐 patch HSL 插值 | 精确的局部颜色控制 |
|
| 171 |
+
|
| 172 |
+
## 关键参数
|
| 173 |
+
|
| 174 |
+
### 颜色干预
|
| 175 |
+
- **strength**(0.0–2.0):干预强度。1.0 = 完整干预,0.0 = 无干预。
|
| 176 |
+
- **start_step / end_step**:干预步骤范围。论文最优:50 步中的第 8–10 步。
|
| 177 |
+
- **mask**:可选。下采样至 patch 网格分辨率,用于局部控制。
|
| 178 |
+
|
| 179 |
+
### 锐度干预
|
| 180 |
+
- **strength**(-5.0–5.0):正值 = 更锐利,负值 = 更模糊,0 = 无变化。
|
| 181 |
+
- **start_step / end_step**:干预步骤范围(默认 5–15)。
|
| 182 |
+
- **mask**:可选。用于局部锐度控制。
|
| 183 |
+
|
| 184 |
+
> **步数蒸馏模型提示**:对于步数蒸馏模型(如 z-image-turbo),总步数很少,干预应从更早的步骤开始——甚至可以从第 0 步就开始干预。
|
| 185 |
+
|
| 186 |
+
### 颜色锚定
|
| 187 |
+
|
| 188 |
+
扩散模型在采样过程中有时会出现意想不到的颜色偏移——蓝天突然变紫,或者修复/补绘后留下明显的颜色接缝。颜色锚定节点在图像生成过程中监控和修正这些问题。
|
| 189 |
+
|
| 190 |
+
**模式:**
|
| 191 |
+
|
| 192 |
+
| 模式 | 功能 | 适用场景 |
|
| 193 |
+
|------|------|----------|
|
| 194 |
+
| **auto**(默认) | 根据你连接的输入自动选最合适的策略 | 不想调参,开箱即用 |
|
| 195 |
+
| **self_anchor** | 在早期步骤观察颜色变化规律,在后续步骤防止突然的颜色跳变 | 通用颜色稳定,不需要参考图 |
|
| 196 |
+
| **reference** | 让生成图像的颜色贴近你提供的参考图 | 「我想要这张照片的配色风格」 |
|
| 197 |
+
| **smooth** | 平滑区域之间的突兀颜色边界 | 修复/补绘后消除接缝 |
|
| 198 |
+
|
| 199 |
+
**auto 模式如何自动选择:**
|
| 200 |
+
|
| 201 |
+
1. **用哪种策略?** 看你连了什么:
|
| 202 |
+
- 连了参考图 + VAE → 用 `reference`
|
| 203 |
+
- 连了遮罩(没有参考图)→ 用 `smooth`
|
| 204 |
+
- 什么额外输入都没连 → 用 `self_anchor`
|
| 205 |
+
2. **修正多强?** 节点会测量实际的颜色漂移幅度,据此自动设置校正强度。漂移大 → 修正更强;漂移小 → 轻轻一碰。范围是 0.15–0.6,既不会矫枉过正,也不会毫无作用。
|
| 206 |
+
|
| 207 |
+
**采样过程中发生了什么:**
|
| 208 |
+
|
| 209 |
+
节点在每个采样步都会运行,但不会每步都干预。它自动判断哪些步骤适合校正:
|
| 210 |
+
|
| 211 |
+
1. **早期步骤**(图像基本是噪声)— 太早修正颜色会产生伪影,跳过。在 self_anchor 模式下,节点利用这些步骤*学习*图像的颜色规律。
|
| 212 |
+
2. **中间步骤**(图像逐渐成形)— 最佳校正时机。节点在这里施加校正,平滑地渐入渐出,避免突变。
|
| 213 |
+
3. **后期步骤**(精细细节)— 校正会干扰细节,跳过。
|
| 214 |
+
|
| 215 |
+
只修改颜色——结构、纹理、细节始终不受影响。
|
| 216 |
+
|
| 217 |
+
**参数:**
|
| 218 |
+
|
| 219 |
+
- **mode**:`auto`、`smooth`、`reference` 或 `self_anchor`
|
| 220 |
+
- **intensity**(0.0–1.0):校正强度。auto 模式下自动决定。设为 0 可完全禁用此节点。
|
| 221 |
+
- **vae**(可选):reference 模式需要用它来编码参考图
|
| 222 |
+
- **reference_image**(可选):你想匹配其颜色的参考图
|
| 223 |
+
- **mask**(可选):只在遮罩区域内校正颜色
|
| 224 |
+
|
| 225 |
+
## 色调预设
|
| 226 |
+
|
| 227 |
+
选择预���后滑条实时更新。可在预设基础上微调。选择 **Custom** 可完全手动设置。
|
| 228 |
+
|
| 229 |
+
| 预设 | 对比度 | 亮度 | 饱和度 | 色温 |
|
| 230 |
+
|------|--------|------|--------|------|
|
| 231 |
+
| Base | 1.0 | 0.0 | 1.0 | 0.0 |
|
| 232 |
+
| Cinematic | 1.20 | -0.05 | 0.90 | 0.05 |
|
| 233 |
+
| HDR | 1.40 | 0.0 | 1.20 | 0.0 |
|
| 234 |
+
| Vivid | 1.10 | 0.0 | 1.50 | 0.0 |
|
| 235 |
+
| Dramatic | 1.50 | -0.10 | 0.85 | 0.0 |
|
| 236 |
+
| Low Key | 1.30 | -0.20 | 0.80 | 0.0 |
|
| 237 |
+
| High Key | 0.80 | 0.20 | 0.90 | 0.0 |
|
| 238 |
+
| Warm | 1.05 | 0.03 | 1.10 | 0.30 |
|
| 239 |
+
| Cool | 1.05 | 0.0 | 1.05 | -0.30 |
|
| 240 |
+
| Desaturated | 1.0 | 0.0 | 0.40 | 0.0 |
|
| 241 |
+
|
| 242 |
+
## 工作原理
|
| 243 |
+
|
| 244 |
+
### 颜色(LCS)
|
| 245 |
+
|
| 246 |
+
1. **投影** — 将去噪预测转换到 64D patch 空间,投影到 3D LCS 基底
|
| 247 |
+
2. **分解** — 将 3D 颜色坐标与 61D 结构残差分离
|
| 248 |
+
3. **归一化** — 使用学习的 alpha/beta 统计量变换至参考时间步(t=50)
|
| 249 |
+
4. **操作** — 在 3D LCS 中偏移颜色、调整色调或进行其他变换
|
| 250 |
+
5. **重建** — 反归一化,加回保留的 61D 残差,转换回潜在空间
|
| 251 |
+
|
| 252 |
+
61D 残差(结构、纹理、细节)始终不被修改——只有 3D 颜色子空间会被改变。
|
| 253 |
+
|
| 254 |
+
### 锐度
|
| 255 |
+
|
| 256 |
+
锐度存在于与颜色正交的独立子空间中:
|
| 257 |
+
|
| 258 |
+
1. **校准** — 生成灰度噪声图像,应用多级高斯模糊,VAE 编码后对去除颜色分量的 patch 向量做 PCA。PC1 捕获 ~97% 的锐度方差。
|
| 259 |
+
2. **干预** — 在每个 patch 上沿 `strength * pc1_direction` 方向添加偏移。由于 pc1_direction 与颜色正交(校准时已移除 LCS 分量)且无直流分量(PCA 前做了逐向量零均值化),因此只改变空间频率内容,不影响颜色或亮度。
|
| 260 |
+
|
| 261 |
+
### 颜色锚定
|
| 262 |
+
|
| 263 |
+
颜色锚定的作用是稳定颜色,而不是把颜色推向某个特定目标——它防止模型已经在生成的颜色发生偏移:
|
| 264 |
+
|
| 265 |
+
1. **判断何时介入** — 节点检查每个采样步:图像还是一片噪声(太早)、正在成形(适合校正)、还是快完成了(太晚)?只在安全的中间窗口进行校正。
|
| 266 |
+
2. **学习颜色规律**(self_anchor)— 在早期噪声较大的步骤中,节点观察每个区域的颜色与邻居之间的关系,建立一个动态平均值。比起追踪绝对颜色值,这种「相对关系」更可靠,因为绝对颜色在图像成形过程中本来就会自然变化。
|
| 267 |
+
3. **测量漂移** — 在第一个校正步,节点测量颜色实际漂移了多少(根据模式不同:步间跳变幅度、与参考图的差距、或空间粗糙程度)。这决定了 auto 模式下的校正强度。
|
| 268 |
+
4. **温和地修正** — 校正平滑地渐入渐出(不会突变)。每种模式的修正方式不同:self_anchor 修复偏离已学规律的区域,reference 拉近与参考图的颜色,smooth 模糊掉尖锐的颜色边界。
|
| 269 |
+
5. **保留其他一切** — 与所有 LCS 操作一样,只修改 3D 颜色坐标,结构、纹理、细节完全不受影响。
|
| 270 |
+
|
| 271 |
+
## 文件结构
|
| 272 |
+
|
| 273 |
+
```
|
| 274 |
+
ComfyUI-LCS/
|
| 275 |
+
├── __init__.py # 入口(V3 + V2 兼容)
|
| 276 |
+
├── requirements.txt
|
| 277 |
+
├── core/
|
| 278 |
+
│ ├── adaptive.py # 自适应调度(阶段、包络、漂移估计)
|
| 279 |
+
│ ├── bilateral.py # LCS 颜色平滑的双边滤波
|
| 280 |
+
│ ├── calibration.py # PCA 校准流程(颜色)
|
| 281 |
+
│ ├── color_space.py # 双锥 LCS ↔ HSL 映射
|
| 282 |
+
│ ├── defaults.py # 论文中的 Alpha/beta 表
|
| 283 |
+
│ ├── lcs_data.py # LCSData 数据类
|
| 284 |
+
│ ├── patchify.py # Patch ↔ 潜在空间转换
|
| 285 |
+
│ ├── relationships.py # 局部颜色关系分析与异常检测
|
| 286 |
+
│ ├── sampling.py # 共享常量和步骤工具
|
| 287 |
+
│ ├── sharpness.py # 锐度子空间校准
|
| 288 |
+
│ └── timestep.py # Sigma/时间步工具
|
| 289 |
+
├── nodes/
|
| 290 |
+
│ ├── anchor.py # LCSColorAnchor(自适应颜色漂移校正)
|
| 291 |
+
│ ├── calibrate.py # LCSLoadData(自动校准 + 缓存)
|
| 292 |
+
│ ├── intervene.py # LCSColorIntervene, LCSColorBatch, LCSToneAdjust
|
| 293 |
+
│ ├── observe.py # LCSPreviewColors, LCSStepObserver
|
| 294 |
+
│ └── sharpen.py # LCSSharpnessCalibrate, LCSSharpnessIntervene
|
| 295 |
+
├── data/ # 缓存的校准文件
|
| 296 |
+
└── web/js/
|
| 297 |
+
└── tone_preset.js # 前端预设同步
|
| 298 |
+
```
|
| 299 |
+
|
| 300 |
+
## 更新日志
|
| 301 |
+
|
| 302 |
+
### 2026-03-21
|
| 303 |
+
- **颜色锚定:auto 模式** — 新增 `auto` 模式,根据连接的输入自动推断校正策略(self_anchor / reference / smooth),并根据实测漂移推导强度。零配置使用。
|
| 304 |
+
- **颜色锚定:自适应调度** — 阶段分配(observe/correct/skip)和强度包络在运行时从 sigma 调度表推导。
|
| 305 |
+
|
| 306 |
+
### 2026-03-20
|
| 307 |
+
- **锐度控制** — 通过模糊刺激 PCA 发现锐度子空间。新增 `LCS Sharpness Calibrate` + `LCS Sharpness Intervene` 节点。PC1 解释 ~97% 方差,与颜色正交。
|
| 308 |
+
- **颜色正交锐度** — 可选连接 `lcs_data`,在锐度校准时移除颜色分量,防止颜色偏移。
|
| 309 |
+
|
| 310 |
+
### 2026-03-19
|
| 311 |
+
- **视频 VAE 支持(Wan)** — 在 patchify/unpatchify 中处理 5D 视频潜在表示。视频 VAE 自动回退到逐帧编码。
|
| 312 |
+
- **LTXV 兼容** — patchify 中填充奇数空间维度,处理 3D 张量,不兼容格式时优雅跳过。
|
| 313 |
+
- **FLUX2 支持** — unpatchify 自动检测 128 通道潜在表示。
|
| 314 |
+
- **通用潜在格式** — 使用模型的 `latent_format` 进行空间转换,不再硬编码 FLUX 常量。
|
| 315 |
+
|
| 316 |
+
### 2026-03-18
|
| 317 |
+
- **色调调整** — `LCS Tone Adjust` 节点,支持对比度、亮度、饱和度、色温滑条。10 个预设,前端实时同步。
|
| 318 |
+
- **色温控制** — 沿 LCS 蓝-黄轴的暖/冷偏移。
|
| 319 |
+
- **双锥 HSL 几何** — 通过双锥 LCS-to-HSL 映射实现正确的 Type II 干预。
|
| 320 |
+
|
| 321 |
+
### 2026-03-17
|
| 322 |
+
- **首次发布** — 颜色引导(Type I + Type II + 插值模式)、批量多色、局部遮罩控制、潜在颜色预览、步骤观察器。按 VAE 自动校准并缓存。
|
| 323 |
+
|
| 324 |
+
## 引用
|
| 325 |
+
|
| 326 |
+
官方仓库:[ExplainableML/LCS](https://github.com/ExplainableML/LCS)
|
| 327 |
+
|
| 328 |
+
```bibtex
|
| 329 |
+
@article{pach2026latentcolorsubspace,
|
| 330 |
+
title={The Latent Color Subspace: Emergent Order in High-Dimensional Chaos},
|
| 331 |
+
author={Mateusz Pach and Jessica Bader and Quentin Bouniot and Serge Belongie and Zeynep Akata},
|
| 332 |
+
journal={arxiv},
|
| 333 |
+
year={2026}
|
| 334 |
+
}
|
| 335 |
+
```
|
| 336 |
+
|
| 337 |
+
## 致谢
|
| 338 |
+
|
| 339 |
+
感谢 Mateusz Pach、Jessica Bader、Quentin Bouniot、Serge Belongie 和 Zeynep Akata,他们的研究使免训练颜色控制成为可能。
|
| 340 |
+
|
| 341 |
+
## 许可证
|
| 342 |
+
|
| 343 |
+
MIT
|
custom_nodes/ComfyUI-LCS/__init__.py
ADDED
|
@@ -0,0 +1,53 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""ComfyUI-LCS: The Latent Color Subspace — training-free color control for FLUX.
|
| 2 |
+
|
| 3 |
+
Paper: "The Latent Color Subspace" (arXiv:2603.12261v1, ICML 2026)
|
| 4 |
+
"""
|
| 5 |
+
|
| 6 |
+
# Register as ComfyUI_LCS so other plugins can `from ComfyUI_LCS.core.xxx import ...`
|
| 7 |
+
import sys as _sys
|
| 8 |
+
_sys.modules.setdefault("ComfyUI_LCS", _sys.modules[__name__])
|
| 9 |
+
|
| 10 |
+
# V3 ComfyExtension entry point
|
| 11 |
+
from comfy_api.latest import ComfyExtension, io
|
| 12 |
+
from .nodes.calibrate import LCSLoadData
|
| 13 |
+
from .nodes.intervene import LCSColorIntervene, LCSColorBatch, LCSToneAdjust
|
| 14 |
+
from .nodes.observe import LCSPreviewColors, LCSStepObserver
|
| 15 |
+
from .nodes.sharpen import LCSSharpnessCalibrate, LCSSharpnessIntervene
|
| 16 |
+
from .nodes.anchor import LCSColorAnchor
|
| 17 |
+
|
| 18 |
+
|
| 19 |
+
class LCSExtension(ComfyExtension):
|
| 20 |
+
"""V3 ComfyExtension providing all LCS nodes to ComfyUI."""
|
| 21 |
+
|
| 22 |
+
async def get_node_list(self) -> list[type[io.ComfyNode]]:
|
| 23 |
+
"""Return all LCS node classes."""
|
| 24 |
+
return [
|
| 25 |
+
LCSLoadData,
|
| 26 |
+
LCSColorIntervene,
|
| 27 |
+
LCSColorBatch,
|
| 28 |
+
LCSToneAdjust,
|
| 29 |
+
LCSPreviewColors,
|
| 30 |
+
LCSStepObserver,
|
| 31 |
+
LCSSharpnessCalibrate,
|
| 32 |
+
LCSSharpnessIntervene,
|
| 33 |
+
LCSColorAnchor,
|
| 34 |
+
]
|
| 35 |
+
|
| 36 |
+
|
| 37 |
+
async def comfy_entrypoint() -> LCSExtension:
|
| 38 |
+
"""V3 async entry point called by ComfyUI on startup."""
|
| 39 |
+
return LCSExtension()
|
| 40 |
+
|
| 41 |
+
|
| 42 |
+
# V2 backward compatibility
|
| 43 |
+
from .nodes import NODE_CLASS_MAPPINGS, NODE_DISPLAY_NAME_MAPPINGS
|
| 44 |
+
|
| 45 |
+
WEB_DIRECTORY = "./web"
|
| 46 |
+
|
| 47 |
+
__all__ = [
|
| 48 |
+
"NODE_CLASS_MAPPINGS",
|
| 49 |
+
"NODE_DISPLAY_NAME_MAPPINGS",
|
| 50 |
+
"WEB_DIRECTORY",
|
| 51 |
+
"LCSExtension",
|
| 52 |
+
"comfy_entrypoint",
|
| 53 |
+
]
|
custom_nodes/ComfyUI-LCS/core/__init__.py
ADDED
|
@@ -0,0 +1,10 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from .lcs_data import LCSData
|
| 2 |
+
from .patchify import patchify, unpatchify
|
| 3 |
+
from .timestep import sigma_to_paper_t, get_alpha_beta, normalize_to_t50, denormalize_from_t50
|
| 4 |
+
from .color_space import decode_lcs_to_hsl, encode_hsl_to_lcs, hex_to_hsl, hsl_to_rgb
|
| 5 |
+
|
| 6 |
+
|
| 7 |
+
def calibrate(*args, **kwargs):
|
| 8 |
+
"""Lazy wrapper for core.calibration.calibrate (avoids importing comfy.utils at module level)."""
|
| 9 |
+
from .calibration import calibrate as _calibrate
|
| 10 |
+
return _calibrate(*args, **kwargs)
|
custom_nodes/ComfyUI-LCS/core/adaptive.py
ADDED
|
@@ -0,0 +1,109 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Schedule-aware adaptive logic for LCS color anchoring.
|
| 2 |
+
|
| 3 |
+
Derives intervention windows, strength envelopes, and phase assignments
|
| 4 |
+
from the sigma schedule's amplification factor (beta_50 / beta_t), replacing
|
| 5 |
+
all manually-tuned step/strength parameters with data-driven decisions.
|
| 6 |
+
"""
|
| 7 |
+
|
| 8 |
+
import math
|
| 9 |
+
import torch
|
| 10 |
+
from .defaults import get_beta_table
|
| 11 |
+
|
| 12 |
+
|
| 13 |
+
def compute_amplification(sigma_val, device=None):
|
| 14 |
+
"""Compute amplification factor A = max_k(beta_50[k] / beta_t(sigma)[k]).
|
| 15 |
+
|
| 16 |
+
The amplification factor measures how much the normalization step inflates
|
| 17 |
+
noise relative to signal. High A means corrections are dangerous (amplified
|
| 18 |
+
noise dominates), low A means corrections are safe.
|
| 19 |
+
|
| 20 |
+
sigma_val: float in [0, 1] (FLUX sigma, 1=noise, 0=clean)
|
| 21 |
+
Returns: float amplification factor
|
| 22 |
+
"""
|
| 23 |
+
beta_table = get_beta_table() # [51, 3]
|
| 24 |
+
beta_50 = beta_table[50] # [3]
|
| 25 |
+
|
| 26 |
+
# Convert sigma to paper timestep
|
| 27 |
+
t = 50.0 * (1.0 - max(0.0, min(1.0, sigma_val)))
|
| 28 |
+
t = max(0.0, min(50.0, t))
|
| 29 |
+
t_low = int(t)
|
| 30 |
+
t_high = min(t_low + 1, 50)
|
| 31 |
+
frac = t - t_low
|
| 32 |
+
|
| 33 |
+
beta_t = (1.0 - frac) * beta_table[t_low] + frac * beta_table[t_high]
|
| 34 |
+
|
| 35 |
+
# Per-component ratio, take max
|
| 36 |
+
beta_t_safe = beta_t.clamp(min=1e-8)
|
| 37 |
+
ratios = beta_50 / beta_t_safe # [3]
|
| 38 |
+
return ratios.max().item()
|
| 39 |
+
|
| 40 |
+
|
| 41 |
+
def compute_step_phases(sigmas, mode):
|
| 42 |
+
"""Assign a phase to each sampling step based on amplification factor.
|
| 43 |
+
|
| 44 |
+
Physics-derived constants (not empirical):
|
| 45 |
+
A_MAX = 10.0 — above: normalization amplifies noise >10x → skip
|
| 46 |
+
A_WARMUP = 5.0 — self_anchor only: observe phase for EMA buildup
|
| 47 |
+
SIGMA_MIN = 0.15 — below: final detail refinement → skip
|
| 48 |
+
|
| 49 |
+
sigmas: 1D tensor of sigma values for each step (length N+1, last is 0)
|
| 50 |
+
mode: "smooth", "reference", or "self_anchor"
|
| 51 |
+
|
| 52 |
+
Returns: list of N strings, each "skip" / "observe" / "correct"
|
| 53 |
+
"""
|
| 54 |
+
A_MAX = 10.0
|
| 55 |
+
A_WARMUP = 5.0
|
| 56 |
+
SIGMA_MIN = 0.15
|
| 57 |
+
|
| 58 |
+
n_steps = len(sigmas) - 1 # last sigma is terminal (0)
|
| 59 |
+
phases = []
|
| 60 |
+
|
| 61 |
+
for i in range(n_steps):
|
| 62 |
+
sigma_val = float(sigmas[i])
|
| 63 |
+
|
| 64 |
+
# Final refinement — skip
|
| 65 |
+
if sigma_val < SIGMA_MIN:
|
| 66 |
+
phases.append("skip")
|
| 67 |
+
continue
|
| 68 |
+
|
| 69 |
+
amp = compute_amplification(sigma_val)
|
| 70 |
+
|
| 71 |
+
# Too noisy — skip
|
| 72 |
+
if amp > A_MAX:
|
| 73 |
+
phases.append("skip")
|
| 74 |
+
continue
|
| 75 |
+
|
| 76 |
+
# Self-anchor warmup zone
|
| 77 |
+
if mode == "self_anchor" and amp > A_WARMUP:
|
| 78 |
+
phases.append("observe")
|
| 79 |
+
continue
|
| 80 |
+
|
| 81 |
+
phases.append("correct")
|
| 82 |
+
|
| 83 |
+
return phases
|
| 84 |
+
|
| 85 |
+
|
| 86 |
+
def estimate_intensity(drift_signal):
|
| 87 |
+
"""Map drift magnitude to intensity in [0.15, 0.6]."""
|
| 88 |
+
DRIFT_SCALE = 0.2
|
| 89 |
+
INTENSITY_MIN = 0.15
|
| 90 |
+
INTENSITY_MAX = 0.6
|
| 91 |
+
return max(INTENSITY_MIN, min(INTENSITY_MAX, drift_signal / DRIFT_SCALE))
|
| 92 |
+
|
| 93 |
+
|
| 94 |
+
def compute_strength_envelope(n_correction_steps):
|
| 95 |
+
"""Sinusoidal bell envelope over correction steps.
|
| 96 |
+
|
| 97 |
+
sin(pi * i / (n-1)) for i in 0..n-1
|
| 98 |
+
Prevents abrupt on/off at phase boundaries.
|
| 99 |
+
Single step returns [1.0].
|
| 100 |
+
|
| 101 |
+
Returns: 1D tensor of length n_correction_steps
|
| 102 |
+
"""
|
| 103 |
+
if n_correction_steps <= 0:
|
| 104 |
+
return torch.zeros(0)
|
| 105 |
+
if n_correction_steps == 1:
|
| 106 |
+
return torch.ones(1)
|
| 107 |
+
n = n_correction_steps
|
| 108 |
+
indices = torch.arange(n, dtype=torch.float32)
|
| 109 |
+
return torch.sin(math.pi * indices / (n - 1))
|
custom_nodes/ComfyUI-LCS/core/bilateral.py
ADDED
|
@@ -0,0 +1,79 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Bilateral filter in LCS space for smooth color anchoring."""
|
| 2 |
+
|
| 3 |
+
import math
|
| 4 |
+
|
| 5 |
+
import torch
|
| 6 |
+
import torch.nn.functional as F
|
| 7 |
+
|
| 8 |
+
|
| 9 |
+
def estimate_bilateral_params(c, h_len, w_len):
|
| 10 |
+
"""Estimate bilateral filter parameters from local color statistics.
|
| 11 |
+
|
| 12 |
+
Computes per-channel spatial std of c across the grid, takes the median
|
| 13 |
+
to derive sigma_color. sigma_spatial is fixed at 1.5 (5x5 kernel is small).
|
| 14 |
+
|
| 15 |
+
c: [B, L, 3] LCS coordinates
|
| 16 |
+
Returns: (sigma_spatial, sigma_color) floats
|
| 17 |
+
"""
|
| 18 |
+
B = c.shape[0]
|
| 19 |
+
grid = c.reshape(B, h_len, w_len, 3) # [B, H, W, 3]
|
| 20 |
+
# Per-channel std across spatial dims → [B, 3]
|
| 21 |
+
channel_std = grid.reshape(B, -1, 3).std(dim=1) # [B, 3]
|
| 22 |
+
# Median across batch and channels
|
| 23 |
+
median_std = float(channel_std.median())
|
| 24 |
+
sigma_color = max(0.05, min(3.0, 0.75 * median_std))
|
| 25 |
+
sigma_spatial = 1.5
|
| 26 |
+
return sigma_spatial, sigma_color
|
| 27 |
+
|
| 28 |
+
|
| 29 |
+
def bilateral_filter_lcs(c, h_len, w_len, sigma_spatial, sigma_color, kernel_radius=2):
|
| 30 |
+
"""Bilateral filter on [B, L, 3] LCS coordinates arranged on h_len x w_len grid.
|
| 31 |
+
|
| 32 |
+
Uses spatial distance + LCS color distance as joint weights.
|
| 33 |
+
kernel_radius=2 -> 5x5 neighborhood (25 lookups per patch).
|
| 34 |
+
Returns [B, L, 3] filtered coordinates.
|
| 35 |
+
"""
|
| 36 |
+
B = c.shape[0]
|
| 37 |
+
# Reshape to spatial grid
|
| 38 |
+
grid = c.reshape(B, h_len, w_len, 3) # [B, H, W, 3]
|
| 39 |
+
|
| 40 |
+
# Pad by kernel_radius (replicate) — pad last two spatial dims
|
| 41 |
+
# F.pad on [B, H, W, 3]: need to pad dims -3 and -2 (H and W)
|
| 42 |
+
# Permute to [B, 3, H, W] for F.pad, then back
|
| 43 |
+
grid_chw = grid.permute(0, 3, 1, 2) # [B, 3, H, W]
|
| 44 |
+
r = kernel_radius
|
| 45 |
+
padded = F.pad(grid_chw, (r, r, r, r), mode="replicate") # [B, 3, H+2r, W+2r]
|
| 46 |
+
|
| 47 |
+
# Precompute spatial Gaussian weights for each offset in kernel
|
| 48 |
+
inv_2ss = -0.5 / (sigma_spatial * sigma_spatial)
|
| 49 |
+
inv_2sc = -0.5 / (sigma_color * sigma_color)
|
| 50 |
+
|
| 51 |
+
# Accumulate weighted sum
|
| 52 |
+
weight_sum = torch.zeros(B, 1, h_len, w_len, device=c.device, dtype=c.dtype)
|
| 53 |
+
value_sum = torch.zeros(B, 3, h_len, w_len, device=c.device, dtype=c.dtype)
|
| 54 |
+
|
| 55 |
+
for dy in range(-r, r + 1):
|
| 56 |
+
for dx in range(-r, r + 1):
|
| 57 |
+
# Spatial weight (constant per offset)
|
| 58 |
+
spatial_dist_sq = float(dy * dy + dx * dx)
|
| 59 |
+
w_spatial = math.exp(spatial_dist_sq * inv_2ss)
|
| 60 |
+
|
| 61 |
+
# Extract neighbor values from padded grid
|
| 62 |
+
y_start = r + dy
|
| 63 |
+
x_start = r + dx
|
| 64 |
+
neighbor = padded[:, :, y_start:y_start + h_len, x_start:x_start + w_len] # [B, 3, H, W]
|
| 65 |
+
|
| 66 |
+
# Color distance weight (per-pixel)
|
| 67 |
+
diff = neighbor - grid_chw # [B, 3, H, W]
|
| 68 |
+
color_dist_sq = (diff * diff).sum(dim=1, keepdim=True) # [B, 1, H, W]
|
| 69 |
+
w_color = torch.exp(color_dist_sq * inv_2sc) # [B, 1, H, W]
|
| 70 |
+
|
| 71 |
+
w = w_spatial * w_color
|
| 72 |
+
weight_sum.add_(w)
|
| 73 |
+
value_sum.add_(w * neighbor)
|
| 74 |
+
|
| 75 |
+
# Normalize
|
| 76 |
+
result = value_sum / weight_sum.clamp(min=1e-8) # [B, 3, H, W]
|
| 77 |
+
|
| 78 |
+
# Back to [B, L, 3]
|
| 79 |
+
return result.permute(0, 2, 3, 1).reshape(B, -1, 3)
|
custom_nodes/ComfyUI-LCS/core/calibration.py
ADDED
|
@@ -0,0 +1,214 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""PCA calibration from FLUX VAE: compute LCS basis, mean, and anchor positions."""
|
| 2 |
+
|
| 3 |
+
import hashlib
|
| 4 |
+
import math
|
| 5 |
+
import torch
|
| 6 |
+
import comfy.utils
|
| 7 |
+
from .patchify import patchify
|
| 8 |
+
from .lcs_data import LCSData
|
| 9 |
+
from .color_space import _chromatic_plane_basis
|
| 10 |
+
|
| 11 |
+
|
| 12 |
+
def vae_fingerprint(vae) -> str:
|
| 13 |
+
"""8-char hex fingerprint from VAE decoder weights.
|
| 14 |
+
|
| 15 |
+
Used to cache calibration data per-VAE so different VAE models
|
| 16 |
+
get separate calibration files automatically.
|
| 17 |
+
"""
|
| 18 |
+
sd = vae.get_sd()
|
| 19 |
+
# Use first decoder weight tensor as fingerprint source
|
| 20 |
+
for key in sorted(sd.keys()):
|
| 21 |
+
if "decoder" in key and "weight" in key:
|
| 22 |
+
w = sd[key]
|
| 23 |
+
return hashlib.sha256(w.cpu().float().numpy().tobytes()).hexdigest()[:8]
|
| 24 |
+
# Fallback: hash first weight found
|
| 25 |
+
first_key = sorted(sd.keys())[0]
|
| 26 |
+
w = sd[first_key]
|
| 27 |
+
return hashlib.sha256(w.cpu().float().numpy().tobytes()).hexdigest()[:8]
|
| 28 |
+
|
| 29 |
+
|
| 30 |
+
# 8 anchor colors: R, B, G, M, C, Y, Black, White
|
| 31 |
+
ANCHOR_COLORS_RGB = [
|
| 32 |
+
(1.0, 0.0, 0.0), # Red
|
| 33 |
+
(0.0, 0.0, 1.0), # Blue
|
| 34 |
+
(0.0, 1.0, 0.0), # Green
|
| 35 |
+
(1.0, 0.0, 1.0), # Magenta
|
| 36 |
+
(0.0, 1.0, 1.0), # Cyan
|
| 37 |
+
(1.0, 1.0, 0.0), # Yellow
|
| 38 |
+
(0.0, 0.0, 0.0), # Black
|
| 39 |
+
(1.0, 1.0, 1.0), # White
|
| 40 |
+
]
|
| 41 |
+
|
| 42 |
+
|
| 43 |
+
def calibrate(vae, num_colors=512, image_size=512, batch_size=8):
|
| 44 |
+
"""Compute LCS data (PCA basis, mean, anchors) from FLUX VAE.
|
| 45 |
+
|
| 46 |
+
1. Sample num_colors solid-color images uniformly from HSV
|
| 47 |
+
2. VAE encode each → latent
|
| 48 |
+
3. Patchify → average patches per image → vector in R^64
|
| 49 |
+
4. PCA on all vectors → basis B [64,3], mean μ [64]
|
| 50 |
+
5. Encode 8 anchor colors → compute LCS coords + hue angles
|
| 51 |
+
|
| 52 |
+
Returns: LCSData
|
| 53 |
+
"""
|
| 54 |
+
device = comfy.model_management.intermediate_device()
|
| 55 |
+
|
| 56 |
+
print(f"\n[LCS Calibration] Starting calibration for {num_colors} colors...")
|
| 57 |
+
print(f"[LCS Calibration] Image size: {image_size}x{image_size}, Batch size: {batch_size}")
|
| 58 |
+
|
| 59 |
+
# Step 1: Sample colors uniformly from HSV (full saturation, full value for diversity)
|
| 60 |
+
colors = []
|
| 61 |
+
for i in range(num_colors):
|
| 62 |
+
# Uniform sampling in HSV
|
| 63 |
+
h = (i * 137.508) % 360.0 / 360.0 # Golden angle for uniform coverage
|
| 64 |
+
s = 0.3 + 0.7 * ((i * 73) % 100) / 100.0 # Vary saturation 0.3-1.0
|
| 65 |
+
v = 0.3 + 0.7 * ((i * 47) % 100) / 100.0 # Vary value 0.3-1.0
|
| 66 |
+
# HSV to RGB
|
| 67 |
+
r, g, b = _hsv_to_rgb(h, s, v)
|
| 68 |
+
colors.append((r, g, b))
|
| 69 |
+
|
| 70 |
+
# Step 2+3: Encode and average patches
|
| 71 |
+
vectors = []
|
| 72 |
+
pbar = comfy.utils.ProgressBar(num_colors)
|
| 73 |
+
|
| 74 |
+
num_batches = (num_colors + batch_size - 1) // batch_size
|
| 75 |
+
print(f"[LCS Calibration] Encoding {num_colors} color images in {num_batches} batches...")
|
| 76 |
+
|
| 77 |
+
for batch_start in range(0, num_colors, batch_size):
|
| 78 |
+
batch_end = min(batch_start + batch_size, num_colors)
|
| 79 |
+
batch_colors = colors[batch_start:batch_end]
|
| 80 |
+
actual_batch = len(batch_colors)
|
| 81 |
+
|
| 82 |
+
# Create solid color images [B, H, W, 3] (BHWC format for ComfyUI VAE)
|
| 83 |
+
imgs = torch.zeros(actual_batch, image_size, image_size, 3, dtype=torch.float32, device="cpu")
|
| 84 |
+
for j, (r, g, b) in enumerate(batch_colors):
|
| 85 |
+
imgs[j, :, :, 0] = r
|
| 86 |
+
imgs[j, :, :, 1] = g
|
| 87 |
+
imgs[j, :, :, 2] = b
|
| 88 |
+
|
| 89 |
+
# VAE encode — try batch first, fall back to per-image for video VAEs
|
| 90 |
+
latent = vae.encode(imgs[:, :, :, :3])
|
| 91 |
+
|
| 92 |
+
# Squeeze video VAE temporal dim — calibration uses still images
|
| 93 |
+
if latent.ndim == 5:
|
| 94 |
+
latent = latent[:, :, 0, :, :]
|
| 95 |
+
|
| 96 |
+
# Patchify → [B', L, D]
|
| 97 |
+
patches, _, _, _ = patchify(latent)
|
| 98 |
+
|
| 99 |
+
# Average across patches → [B', D]
|
| 100 |
+
avg = patches.mean(dim=1).cpu()
|
| 101 |
+
|
| 102 |
+
if avg.shape[0] == actual_batch:
|
| 103 |
+
# Normal VAE: batch encode worked
|
| 104 |
+
vectors.extend(avg.unbind(0))
|
| 105 |
+
else:
|
| 106 |
+
# Video VAE or unexpected batch collapse — encode one by one
|
| 107 |
+
for k in range(actual_batch):
|
| 108 |
+
single = imgs[k:k+1, :, :, :3]
|
| 109 |
+
lat = vae.encode(single)
|
| 110 |
+
if lat.ndim == 5:
|
| 111 |
+
lat = lat[:, :, 0, :, :]
|
| 112 |
+
p, _, _, _ = patchify(lat)
|
| 113 |
+
vectors.append(p.mean(dim=1).cpu().squeeze(0))
|
| 114 |
+
|
| 115 |
+
pbar.update(actual_batch)
|
| 116 |
+
|
| 117 |
+
# Stack all vectors: [N, 64]
|
| 118 |
+
X = torch.stack(vectors, dim=0).float()
|
| 119 |
+
print(f"[LCS Calibration] Collected {X.shape[0]} patch vectors of dimension {X.shape[1]}")
|
| 120 |
+
|
| 121 |
+
# Step 4: PCA
|
| 122 |
+
print("[LCS Calibration] Computing PCA...")
|
| 123 |
+
mean = X.mean(dim=0) # [64]
|
| 124 |
+
X_centered = X - mean
|
| 125 |
+
# SVD for PCA
|
| 126 |
+
U, S, Vh = torch.linalg.svd(X_centered, full_matrices=False)
|
| 127 |
+
# Top 3 components: B = V[:, :3] (columns are principal directions)
|
| 128 |
+
basis = Vh[:3].T # [64, 3] (Vh rows are right singular vectors)
|
| 129 |
+
|
| 130 |
+
# Variance explained
|
| 131 |
+
total_var = (S ** 2).sum()
|
| 132 |
+
explained = (S[:3] ** 2) / total_var
|
| 133 |
+
print(f"[LCS Calibration] Top 3 components explain {explained.sum():.1%} variance")
|
| 134 |
+
print(f"[LCS Calibration] PC1: {explained[0]:.1%}, PC2: {explained[1]:.1%}, PC3: {explained[2]:.1%}")
|
| 135 |
+
|
| 136 |
+
# Step 5: Encode 8 anchor colors → LCS coords
|
| 137 |
+
print("[LCS Calibration] Encoding 8 anchor colors...")
|
| 138 |
+
anchor_lcs_list = []
|
| 139 |
+
for i, (r, g, b) in enumerate(ANCHOR_COLORS_RGB):
|
| 140 |
+
img = torch.zeros(1, image_size, image_size, 3, dtype=torch.float32, device="cpu")
|
| 141 |
+
img[0, :, :, 0] = r
|
| 142 |
+
img[0, :, :, 1] = g
|
| 143 |
+
img[0, :, :, 2] = b
|
| 144 |
+
latent = vae.encode(img[:, :, :, :3])
|
| 145 |
+
if latent.ndim == 5:
|
| 146 |
+
latent = latent[:, :, 0, :, :]
|
| 147 |
+
patches, _, _, _ = patchify(latent)
|
| 148 |
+
avg = patches.mean(dim=1).cpu().squeeze(0) # [64]
|
| 149 |
+
# Project to LCS
|
| 150 |
+
lcs_coord = (avg - mean) @ basis # [3]
|
| 151 |
+
anchor_lcs_list.append(lcs_coord)
|
| 152 |
+
|
| 153 |
+
anchor_lcs = torch.stack(anchor_lcs_list, dim=0) # [8, 3]
|
| 154 |
+
|
| 155 |
+
# Compute hue angles for 6 chromatic anchors
|
| 156 |
+
anchor_angles = _compute_anchor_angles(anchor_lcs, basis, mean)
|
| 157 |
+
|
| 158 |
+
print(f"[LCS Calibration] Complete! Basis shape: {basis.shape}")
|
| 159 |
+
print(f"[LCS Calibration] Anchor LCS coords:\n{anchor_lcs}")
|
| 160 |
+
|
| 161 |
+
return LCSData(
|
| 162 |
+
basis=basis,
|
| 163 |
+
mean=mean,
|
| 164 |
+
anchor_lcs=anchor_lcs,
|
| 165 |
+
anchor_angles=anchor_angles,
|
| 166 |
+
)
|
| 167 |
+
|
| 168 |
+
|
| 169 |
+
def _compute_anchor_angles(anchor_lcs, basis, mean):
|
| 170 |
+
"""Compute hue angles of the 6 chromatic anchors in the chromatic plane.
|
| 171 |
+
|
| 172 |
+
The chromatic plane is perpendicular to the achromatic axis (black→white).
|
| 173 |
+
Returns [6] tensor of angles in radians.
|
| 174 |
+
"""
|
| 175 |
+
black = anchor_lcs[6] # [3]
|
| 176 |
+
white = anchor_lcs[7] # [3]
|
| 177 |
+
chromatic = anchor_lcs[:6] # [6, 3]
|
| 178 |
+
|
| 179 |
+
# Achromatic axis
|
| 180 |
+
a = white - black
|
| 181 |
+
a_unit, e1, e2 = _chromatic_plane_basis(a)
|
| 182 |
+
|
| 183 |
+
# Project each chromatic anchor onto the plane and compute angle
|
| 184 |
+
angles = []
|
| 185 |
+
for i in range(6):
|
| 186 |
+
c = chromatic[i]
|
| 187 |
+
# Project onto achromatic axis
|
| 188 |
+
c_proj = black + ((c - black) * a).sum() / ((a * a).sum() + 1e-10) * a
|
| 189 |
+
# Chromatic residual
|
| 190 |
+
chroma = c - c_proj
|
| 191 |
+
x = (chroma * e1).sum()
|
| 192 |
+
y = (chroma * e2).sum()
|
| 193 |
+
angle = torch.atan2(y, x) % (2 * math.pi)
|
| 194 |
+
angles.append(angle)
|
| 195 |
+
|
| 196 |
+
return torch.stack(angles) # [6]
|
| 197 |
+
|
| 198 |
+
|
| 199 |
+
def _hsv_to_rgb(h, s, v):
|
| 200 |
+
"""Convert HSV to RGB (scalars in [0,1])."""
|
| 201 |
+
if s < 1e-10:
|
| 202 |
+
return v, v, v
|
| 203 |
+
h6 = h * 6.0
|
| 204 |
+
i = int(h6) % 6
|
| 205 |
+
f = h6 - int(h6)
|
| 206 |
+
p = v * (1.0 - s)
|
| 207 |
+
q = v * (1.0 - s * f)
|
| 208 |
+
t = v * (1.0 - s * (1.0 - f))
|
| 209 |
+
if i == 0: return v, t, p
|
| 210 |
+
if i == 1: return q, v, p
|
| 211 |
+
if i == 2: return p, v, t
|
| 212 |
+
if i == 3: return p, q, v
|
| 213 |
+
if i == 4: return t, p, v
|
| 214 |
+
return v, p, q
|
custom_nodes/ComfyUI-LCS/core/color_space.py
ADDED
|
@@ -0,0 +1,380 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Bicone LCS ↔ HSL mapping using 8 anchor colors.
|
| 2 |
+
|
| 3 |
+
Anchors are indexed as: [Red, Blue, Green, Magenta, Cyan, Yellow, Black, White]
|
| 4 |
+
Indices: 0=R, 1=B, 2=G, 3=M, 4=C, 5=Y, 6=Black, 7=White
|
| 5 |
+
"""
|
| 6 |
+
|
| 7 |
+
import math
|
| 8 |
+
import torch
|
| 9 |
+
|
| 10 |
+
# Standard HSL hue for each anchor: R=0, B=4/6, G=2/6, M=5/6, C=3/6, Y=1/6
|
| 11 |
+
_ANCHOR_HUES = (0.0, 4.0/6.0, 2.0/6.0, 5.0/6.0, 3.0/6.0, 1.0/6.0)
|
| 12 |
+
|
| 13 |
+
|
| 14 |
+
def _bicone_factor(l, clamp_min=None):
|
| 15 |
+
"""Compute bicone scaling factor: 1 - |2L - 1|.
|
| 16 |
+
|
| 17 |
+
At l=0.5 (equator), factor=1 (full radius).
|
| 18 |
+
At l=0 or l=1 (poles), factor=0 (zero radius).
|
| 19 |
+
|
| 20 |
+
Args:
|
| 21 |
+
l: Lightness tensor [...]
|
| 22 |
+
clamp_min: Optional minimum value for numerical stability
|
| 23 |
+
|
| 24 |
+
Returns:
|
| 25 |
+
Bicone factor tensor [...]
|
| 26 |
+
"""
|
| 27 |
+
factor = 1.0 - (2.0 * l - 1.0).abs()
|
| 28 |
+
if clamp_min is not None:
|
| 29 |
+
factor = factor.clamp(min=clamp_min)
|
| 30 |
+
return factor
|
| 31 |
+
|
| 32 |
+
|
| 33 |
+
def _wrap_hue_diff(diff):
|
| 34 |
+
"""Wrap hue differences to the shortest path on the unit circle [-0.5, 0.5]."""
|
| 35 |
+
return diff - (diff > 0.5).float() + (diff < -0.5).float()
|
| 36 |
+
|
| 37 |
+
|
| 38 |
+
def _hue_lerp(h1, h2, t):
|
| 39 |
+
"""Lerp hues on the circle [0,1], taking the shortest path."""
|
| 40 |
+
return (h1 + t * _wrap_hue_diff(h2 - h1)) % 1.0
|
| 41 |
+
|
| 42 |
+
|
| 43 |
+
def _chromatic_plane_basis(a):
|
| 44 |
+
"""Build orthonormal basis (a_unit, e1, e2) for the chromatic plane perpendicular to a."""
|
| 45 |
+
a_unit = a / (a.norm() + 1e-10)
|
| 46 |
+
arb = torch.zeros(3, device=a.device, dtype=a.dtype)
|
| 47 |
+
arb[0] = 1.0
|
| 48 |
+
if a_unit[0].abs() > 0.9:
|
| 49 |
+
arb[0] = 0.0
|
| 50 |
+
arb[1] = 1.0
|
| 51 |
+
e1 = arb - (arb * a_unit).sum() * a_unit
|
| 52 |
+
e1 = e1 / (e1.norm() + 1e-10)
|
| 53 |
+
e2 = torch.linalg.cross(a_unit, e1)
|
| 54 |
+
return a_unit, e1, e2
|
| 55 |
+
|
| 56 |
+
|
| 57 |
+
def hex_to_hsl(hex_str):
|
| 58 |
+
"""Convert "#RRGGBB" to (h, s, l) where h∈[0,1], s∈[0,1], l∈[0,1]."""
|
| 59 |
+
hex_str = hex_str.lstrip("#")
|
| 60 |
+
r = int(hex_str[0:2], 16) / 255.0
|
| 61 |
+
g = int(hex_str[2:4], 16) / 255.0
|
| 62 |
+
b = int(hex_str[4:6], 16) / 255.0
|
| 63 |
+
return rgb_to_hsl(r, g, b)
|
| 64 |
+
|
| 65 |
+
|
| 66 |
+
def rgb_to_hsl(r, g, b):
|
| 67 |
+
"""Convert RGB [0,1] to HSL [0,1]."""
|
| 68 |
+
cmax = max(r, g, b)
|
| 69 |
+
cmin = min(r, g, b)
|
| 70 |
+
delta = cmax - cmin
|
| 71 |
+
l = (cmax + cmin) / 2.0
|
| 72 |
+
|
| 73 |
+
if delta < 1e-10:
|
| 74 |
+
return 0.0, 0.0, l
|
| 75 |
+
|
| 76 |
+
s = delta / (1.0 - abs(2.0 * l - 1.0)) if abs(2.0 * l - 1.0) < 1.0 else 0.0
|
| 77 |
+
|
| 78 |
+
if cmax == r:
|
| 79 |
+
h = ((g - b) / delta) % 6.0
|
| 80 |
+
elif cmax == g:
|
| 81 |
+
h = (b - r) / delta + 2.0
|
| 82 |
+
else:
|
| 83 |
+
h = (r - g) / delta + 4.0
|
| 84 |
+
h = h / 6.0
|
| 85 |
+
if h < 0:
|
| 86 |
+
h += 1.0
|
| 87 |
+
|
| 88 |
+
return h, max(0.0, min(1.0, s)), max(0.0, min(1.0, l))
|
| 89 |
+
|
| 90 |
+
|
| 91 |
+
def hsl_to_rgb(h, s, l):
|
| 92 |
+
"""Convert HSL [0,1] to RGB [0,1]. Works with scalars or tensors."""
|
| 93 |
+
if isinstance(h, torch.Tensor):
|
| 94 |
+
return _hsl_to_rgb_tensor(h, s, l)
|
| 95 |
+
|
| 96 |
+
c = (1.0 - abs(2.0 * l - 1.0)) * s
|
| 97 |
+
x = c * (1.0 - abs((h * 6.0) % 2.0 - 1.0))
|
| 98 |
+
m = l - c / 2.0
|
| 99 |
+
|
| 100 |
+
h6 = h * 6.0
|
| 101 |
+
if h6 < 1:
|
| 102 |
+
r, g, b = c, x, 0
|
| 103 |
+
elif h6 < 2:
|
| 104 |
+
r, g, b = x, c, 0
|
| 105 |
+
elif h6 < 3:
|
| 106 |
+
r, g, b = 0, c, x
|
| 107 |
+
elif h6 < 4:
|
| 108 |
+
r, g, b = 0, x, c
|
| 109 |
+
elif h6 < 5:
|
| 110 |
+
r, g, b = x, 0, c
|
| 111 |
+
else:
|
| 112 |
+
r, g, b = c, 0, x
|
| 113 |
+
|
| 114 |
+
return r + m, g + m, b + m
|
| 115 |
+
|
| 116 |
+
|
| 117 |
+
def _hsl_to_rgb_tensor(h, s, l):
|
| 118 |
+
"""Vectorized HSL→RGB for tensors."""
|
| 119 |
+
c = _bicone_factor(l) * s
|
| 120 |
+
h6 = h * 6.0
|
| 121 |
+
x = c * (1.0 - ((h6 % 2.0) - 1.0).abs())
|
| 122 |
+
m = l - c / 2.0
|
| 123 |
+
|
| 124 |
+
r = torch.zeros_like(h)
|
| 125 |
+
g = torch.zeros_like(h)
|
| 126 |
+
b = torch.zeros_like(h)
|
| 127 |
+
|
| 128 |
+
mask0 = h6 < 1
|
| 129 |
+
mask1 = (h6 >= 1) & (h6 < 2)
|
| 130 |
+
mask2 = (h6 >= 2) & (h6 < 3)
|
| 131 |
+
mask3 = (h6 >= 3) & (h6 < 4)
|
| 132 |
+
mask4 = (h6 >= 4) & (h6 < 5)
|
| 133 |
+
mask5 = h6 >= 5
|
| 134 |
+
|
| 135 |
+
r[mask0] = c[mask0]; g[mask0] = x[mask0]
|
| 136 |
+
r[mask1] = x[mask1]; g[mask1] = c[mask1]
|
| 137 |
+
g[mask2] = c[mask2]; b[mask2] = x[mask2]
|
| 138 |
+
g[mask3] = x[mask3]; b[mask3] = c[mask3]
|
| 139 |
+
r[mask4] = x[mask4]; b[mask4] = c[mask4]
|
| 140 |
+
r[mask5] = c[mask5]; b[mask5] = x[mask5]
|
| 141 |
+
|
| 142 |
+
return (r + m).clamp(0, 1), (g + m).clamp(0, 1), (b + m).clamp(0, 1)
|
| 143 |
+
|
| 144 |
+
|
| 145 |
+
def decode_lcs_to_hsl(c, anchor_lcs, anchor_angles):
|
| 146 |
+
"""Decode LCS coordinates to HSL using bicone geometry.
|
| 147 |
+
|
| 148 |
+
c: [..., 3] LCS coordinates (normalized to t=50)
|
| 149 |
+
anchor_lcs: [8, 3] anchor positions [R,B,G,M,C,Y,Black,White]
|
| 150 |
+
anchor_angles: [6] hue angles of chromatic anchors in radians
|
| 151 |
+
|
| 152 |
+
Returns: (h, s, l) each [...] in [0,1]
|
| 153 |
+
"""
|
| 154 |
+
black = anchor_lcs[6] # [3]
|
| 155 |
+
white = anchor_lcs[7] # [3]
|
| 156 |
+
chromatic = anchor_lcs[:6] # [6, 3]
|
| 157 |
+
|
| 158 |
+
# Achromatic axis
|
| 159 |
+
a = white - black # [3]
|
| 160 |
+
a_norm_sq = (a * a).sum() + 1e-10
|
| 161 |
+
|
| 162 |
+
# Lightness: project onto achromatic axis
|
| 163 |
+
diff = c - black # [..., 3]
|
| 164 |
+
l = (diff * a).sum(dim=-1) / a_norm_sq # [...]
|
| 165 |
+
l = l.clamp(0.0, 1.0)
|
| 166 |
+
|
| 167 |
+
# Point on achromatic axis
|
| 168 |
+
c_L = black + l.unsqueeze(-1) * a # [..., 3]
|
| 169 |
+
|
| 170 |
+
# Chromatic residual
|
| 171 |
+
chroma_vec = c - c_L # [..., 3]
|
| 172 |
+
chroma_dist = chroma_vec.norm(dim=-1) + 1e-10 # [...]
|
| 173 |
+
|
| 174 |
+
# Compute hue angle in chromatic plane
|
| 175 |
+
a_unit, e1, e2 = _chromatic_plane_basis(a)
|
| 176 |
+
|
| 177 |
+
# Project chromatic vector to 2D
|
| 178 |
+
x_coord = (chroma_vec * e1).sum(dim=-1) # [...]
|
| 179 |
+
y_coord = (chroma_vec * e2).sum(dim=-1) # [...]
|
| 180 |
+
angle = torch.atan2(y_coord, x_coord) # [...] radians
|
| 181 |
+
angle = angle % (2 * math.pi)
|
| 182 |
+
|
| 183 |
+
# Map angle to hue [0,1] using sorted anchor angles
|
| 184 |
+
# anchor_angles are the angles of [R,B,G,M,C,Y] in the same coordinate system
|
| 185 |
+
# Standard HSL hue: R=0, Y=1/6, G=2/6, C=3/6, B=4/6, M=5/6
|
| 186 |
+
# But anchors may not be in that order in angle-space, so we interpolate
|
| 187 |
+
sorted_angles, sort_idx = anchor_angles.sort()
|
| 188 |
+
anchor_hues = torch.tensor(_ANCHOR_HUES, device=c.device, dtype=c.dtype)
|
| 189 |
+
sorted_hues = anchor_hues[sort_idx]
|
| 190 |
+
|
| 191 |
+
# Piecewise linear interpolation around the circle
|
| 192 |
+
h = _angle_to_hue(angle, sorted_angles, sorted_hues)
|
| 193 |
+
|
| 194 |
+
# Saturation: distance to achromatic axis normalized by max distance
|
| 195 |
+
# Max distance at this hue and lightness
|
| 196 |
+
bicone_factor = _bicone_factor(l, clamp_min=1e-10)
|
| 197 |
+
|
| 198 |
+
# Find the chroma boundary at this hue (perpendicular to achromatic axis)
|
| 199 |
+
chroma_boundary = _hue_to_chroma_vector(h, chromatic, anchor_angles, a_unit, e1, e2, black, a)
|
| 200 |
+
max_radius = chroma_boundary.norm(dim=-1) + 1e-10
|
| 201 |
+
s = chroma_dist / (max_radius * bicone_factor)
|
| 202 |
+
s = s.clamp(0.0, 1.0)
|
| 203 |
+
|
| 204 |
+
return h, s, l
|
| 205 |
+
|
| 206 |
+
|
| 207 |
+
def encode_hsl_to_lcs(h, s, l, anchor_lcs, anchor_angles):
|
| 208 |
+
"""Encode HSL to LCS coordinates using bicone geometry.
|
| 209 |
+
|
| 210 |
+
h, s, l: [...] in [0,1]
|
| 211 |
+
anchor_lcs: [8, 3]
|
| 212 |
+
anchor_angles: [6] radians
|
| 213 |
+
|
| 214 |
+
Returns: c [..., 3] LCS coordinates
|
| 215 |
+
"""
|
| 216 |
+
black = anchor_lcs[6] # [3]
|
| 217 |
+
white = anchor_lcs[7] # [3]
|
| 218 |
+
chromatic = anchor_lcs[:6] # [6, 3]
|
| 219 |
+
|
| 220 |
+
a = white - black
|
| 221 |
+
a_unit, e1, e2 = _chromatic_plane_basis(a)
|
| 222 |
+
|
| 223 |
+
# Lightness point on achromatic axis
|
| 224 |
+
c_L = black + l.unsqueeze(-1) * a # [..., 3]
|
| 225 |
+
|
| 226 |
+
# Chroma direction vector (equatorial radius at this hue)
|
| 227 |
+
chroma_dir = _hue_to_chroma_vector(h, chromatic, anchor_angles, a_unit, e1, e2, black, a)
|
| 228 |
+
|
| 229 |
+
# Combine: c = c_L + s * (1 - |2l-1|) * chroma_dir
|
| 230 |
+
bicone_factor = _bicone_factor(l)
|
| 231 |
+
c = c_L + (s * bicone_factor).unsqueeze(-1) * chroma_dir
|
| 232 |
+
|
| 233 |
+
return c
|
| 234 |
+
|
| 235 |
+
|
| 236 |
+
def _angle_to_hue(angle, sorted_angles, sorted_hues):
|
| 237 |
+
"""Map an angle [...] to hue [0,1] via piecewise linear interpolation on anchor angles."""
|
| 238 |
+
n = len(sorted_angles)
|
| 239 |
+
h = torch.zeros_like(angle)
|
| 240 |
+
|
| 241 |
+
for i in range(n):
|
| 242 |
+
j = (i + 1) % n
|
| 243 |
+
a_start = sorted_angles[i]
|
| 244 |
+
a_end = sorted_angles[j]
|
| 245 |
+
h_start = sorted_hues[i]
|
| 246 |
+
h_end = sorted_hues[j]
|
| 247 |
+
|
| 248 |
+
# Handle wraparound
|
| 249 |
+
if a_end < a_start:
|
| 250 |
+
a_end = a_end + 2 * math.pi
|
| 251 |
+
span = a_end - a_start
|
| 252 |
+
if span < 1e-10:
|
| 253 |
+
continue
|
| 254 |
+
|
| 255 |
+
# Check which angles fall in this segment
|
| 256 |
+
if a_end > 2 * math.pi:
|
| 257 |
+
# Wraparound segment
|
| 258 |
+
mask = (angle >= a_start) | (angle < (a_end - 2 * math.pi))
|
| 259 |
+
angle_shifted = torch.where(angle < a_start, angle + 2 * math.pi, angle)
|
| 260 |
+
else:
|
| 261 |
+
mask = (angle >= a_start) & (angle < a_end)
|
| 262 |
+
angle_shifted = angle
|
| 263 |
+
|
| 264 |
+
frac = ((angle_shifted - a_start) / span).clamp(0, 1)
|
| 265 |
+
|
| 266 |
+
# Interpolate hue (handling hue wraparound)
|
| 267 |
+
h_diff = h_end - h_start
|
| 268 |
+
if abs(h_diff) > 0.5:
|
| 269 |
+
if h_diff > 0:
|
| 270 |
+
h_diff -= 1.0
|
| 271 |
+
else:
|
| 272 |
+
h_diff += 1.0
|
| 273 |
+
interp = h_start + frac * h_diff
|
| 274 |
+
interp = interp % 1.0
|
| 275 |
+
|
| 276 |
+
h = torch.where(mask, interp, h)
|
| 277 |
+
|
| 278 |
+
return h
|
| 279 |
+
|
| 280 |
+
|
| 281 |
+
def _hue_to_chroma_vector(h, chromatic, anchor_angles, a_unit, e1, e2, black, a):
|
| 282 |
+
"""Map hue values [...] to EQUATORIAL chroma direction vectors.
|
| 283 |
+
|
| 284 |
+
Returns vectors in 3D LCS space that lie in the chromatic plane (perpendicular to a_unit)
|
| 285 |
+
with magnitude equal to the equatorial chroma radius at that hue (i.e., the radius at l=0.5).
|
| 286 |
+
|
| 287 |
+
The equatorial radius is computed by normalizing each anchor's chroma radius by its
|
| 288 |
+
bicone factor (1 - |2L - 1|), where L is the anchor's lightness. This ensures proper
|
| 289 |
+
round-trip encoding/decoding across the bicone.
|
| 290 |
+
|
| 291 |
+
chromatic: [6, 3] anchor LCS positions
|
| 292 |
+
anchor_angles: [6] calibrated angles of chromatic anchors (radians)
|
| 293 |
+
a_unit: [3] unit vector along achromatic axis
|
| 294 |
+
e1, e2: [3] orthonormal basis for chromatic plane
|
| 295 |
+
black: [3] black anchor position
|
| 296 |
+
a: [3] full achromatic axis vector (white - black)
|
| 297 |
+
"""
|
| 298 |
+
# Compute each anchor's lightness (scalar projection onto achromatic axis)
|
| 299 |
+
a_sq = (a * a).sum() + 1e-10
|
| 300 |
+
anchor_diff = chromatic - black # [6, 3]
|
| 301 |
+
anchor_l = (anchor_diff * a).sum(dim=-1) / a_sq # [6] lightness values
|
| 302 |
+
|
| 303 |
+
# Project anchors onto chromatic plane to get chroma vectors
|
| 304 |
+
anchor_on_axis = black + anchor_l.unsqueeze(-1) * a # [6, 3]
|
| 305 |
+
anchor_chroma = chromatic - anchor_on_axis # [6, 3] chroma vectors
|
| 306 |
+
anchor_r = anchor_chroma.norm(dim=-1) # [6] radii at anchor lightness
|
| 307 |
+
|
| 308 |
+
# Normalize to equatorial radii (radius at l=0.5 where bicone_factor=1)
|
| 309 |
+
bicone_factors = _bicone_factor(anchor_l, clamp_min=1e-6) # [6]
|
| 310 |
+
equatorial_r = anchor_r / bicone_factors # [6] equatorial radii
|
| 311 |
+
|
| 312 |
+
anchor_hues = torch.tensor(_ANCHOR_HUES, device=chromatic.device, dtype=chromatic.dtype)
|
| 313 |
+
|
| 314 |
+
# Sort by ANGLE (same as _angle_to_hue) to match segment structure
|
| 315 |
+
sorted_angles, sort_idx = anchor_angles.sort()
|
| 316 |
+
sorted_hues = anchor_hues[sort_idx]
|
| 317 |
+
sorted_radii = equatorial_r[sort_idx] # [6] equatorial radii
|
| 318 |
+
|
| 319 |
+
# Iterate segments in angle order (same as _angle_to_hue)
|
| 320 |
+
n = 6
|
| 321 |
+
result = torch.empty(h.shape + (3,), device=chromatic.device, dtype=chromatic.dtype)
|
| 322 |
+
|
| 323 |
+
for i in range(n):
|
| 324 |
+
j = (i + 1) % n
|
| 325 |
+
h_start = sorted_hues[i]
|
| 326 |
+
h_end = sorted_hues[j]
|
| 327 |
+
|
| 328 |
+
# Hue span with wraparound (same logic as _angle_to_hue)
|
| 329 |
+
h_diff = h_end - h_start
|
| 330 |
+
if abs(h_diff) > 0.5:
|
| 331 |
+
if h_diff > 0:
|
| 332 |
+
h_diff -= 1.0
|
| 333 |
+
else:
|
| 334 |
+
h_diff += 1.0
|
| 335 |
+
|
| 336 |
+
if abs(h_diff) < 1e-10:
|
| 337 |
+
continue
|
| 338 |
+
|
| 339 |
+
# Determine hue range for this segment
|
| 340 |
+
h_end_unwrapped = h_start + h_diff
|
| 341 |
+
|
| 342 |
+
# Build mask for which input hues fall in this segment
|
| 343 |
+
if h_diff > 0:
|
| 344 |
+
if h_end_unwrapped > 1.0:
|
| 345 |
+
mask = (h >= h_start) | (h < (h_end_unwrapped - 1.0))
|
| 346 |
+
h_shifted = torch.where(h < h_start, h + 1.0, h)
|
| 347 |
+
else:
|
| 348 |
+
mask = (h >= h_start) & (h < h_end_unwrapped)
|
| 349 |
+
h_shifted = h
|
| 350 |
+
else:
|
| 351 |
+
# Hue decreases
|
| 352 |
+
if h_end_unwrapped < 0.0:
|
| 353 |
+
mask = (h <= h_start) | (h > (h_end_unwrapped + 1.0))
|
| 354 |
+
h_shifted = torch.where(h > h_start, h - 1.0, h)
|
| 355 |
+
else:
|
| 356 |
+
mask = (h <= h_start) & (h > h_end_unwrapped)
|
| 357 |
+
h_shifted = h
|
| 358 |
+
|
| 359 |
+
frac = ((h_shifted - h_start) / h_diff).clamp(0, 1)
|
| 360 |
+
|
| 361 |
+
# Interpolate radius
|
| 362 |
+
interp_r = sorted_radii[i] + frac * (sorted_radii[j] - sorted_radii[i])
|
| 363 |
+
|
| 364 |
+
# Interpolate angle
|
| 365 |
+
a_start = sorted_angles[i]
|
| 366 |
+
a_end = sorted_angles[j]
|
| 367 |
+
a_span = a_end - a_start
|
| 368 |
+
if a_span < 0:
|
| 369 |
+
a_span += 2 * math.pi
|
| 370 |
+
interp_angle = (a_start + frac * a_span) % (2 * math.pi)
|
| 371 |
+
|
| 372 |
+
# Reconstruct 3D chroma vector
|
| 373 |
+
interp_vec = interp_r.unsqueeze(-1) * (
|
| 374 |
+
torch.cos(interp_angle).unsqueeze(-1) * e1
|
| 375 |
+
+ torch.sin(interp_angle).unsqueeze(-1) * e2
|
| 376 |
+
)
|
| 377 |
+
|
| 378 |
+
result = torch.where(mask.unsqueeze(-1), interp_vec, result)
|
| 379 |
+
|
| 380 |
+
return result
|
custom_nodes/ComfyUI-LCS/core/defaults.py
ADDED
|
@@ -0,0 +1,65 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Hardcoded alpha_t and beta_t tables from paper Appendix F (51 entries, t=0..50)."""
|
| 2 |
+
|
| 3 |
+
import torch
|
| 4 |
+
|
| 5 |
+
# Shift alpha_t: 3D vectors for each timestep t=0..50
|
| 6 |
+
ALPHA_T = [
|
| 7 |
+
[2.3413, -2.3586, 0.4266], [2.3574, -2.3833, 0.4644], [2.3638, -2.3904, 0.4883],
|
| 8 |
+
[2.3734, -2.3951, 0.5122], [2.3831, -2.3993, 0.5384], [2.3925, -2.4026, 0.5647],
|
| 9 |
+
[2.4023, -2.4047, 0.5919], [2.4124, -2.4060, 0.6198], [2.4226, -2.4064, 0.6484],
|
| 10 |
+
[2.4330, -2.4060, 0.6772], [2.4437, -2.4051, 0.7065], [2.4546, -2.4035, 0.7367],
|
| 11 |
+
[2.4659, -2.4011, 0.7668], [2.4775, -2.3981, 0.7974], [2.4897, -2.4009, 0.8312],
|
| 12 |
+
[2.5021, -2.4036, 0.8656], [2.5148, -2.4065, 0.9008], [2.5277, -2.4093, 0.9364],
|
| 13 |
+
[2.5408, -2.4123, 0.9727], [2.5542, -2.4154, 1.0099], [2.5680, -2.4186, 1.0481],
|
| 14 |
+
[2.5820, -2.4218, 1.0868], [2.5963, -2.4252, 1.1263], [2.6110, -2.4288, 1.1672],
|
| 15 |
+
[2.6261, -2.4324, 1.2090], [2.6416, -2.4363, 1.2520], [2.6575, -2.4403, 1.2957],
|
| 16 |
+
[2.6738, -2.4444, 1.3406], [2.6904, -2.4485, 1.3865], [2.7074, -2.4529, 1.4336],
|
| 17 |
+
[2.7250, -2.4574, 1.4818], [2.7432, -2.4621, 1.5314], [2.7618, -2.4669, 1.5823],
|
| 18 |
+
[2.7810, -2.4720, 1.6344], [2.8006, -2.4771, 1.6878], [2.8209, -2.4826, 1.7430],
|
| 19 |
+
[2.8418, -2.4883, 1.7995], [2.8631, -2.4944, 1.8578], [2.8853, -2.5005, 1.9179],
|
| 20 |
+
[2.9080, -2.5066, 1.9793], [2.9313, -2.5132, 2.0426], [2.9555, -2.5199, 2.1082],
|
| 21 |
+
[2.9804, -2.5268, 2.1756], [3.0060, -2.5338, 2.2450], [3.0328, -2.5411, 2.3172],
|
| 22 |
+
[3.0603, -2.5486, 2.3914], [3.0889, -2.5561, 2.4682], [3.1189, -2.5640, 2.5482],
|
| 23 |
+
[3.1497, -2.5725, 2.6302], [3.1824, -2.5796, 2.7175], [3.2152, -2.5889, 2.8050],
|
| 24 |
+
]
|
| 25 |
+
|
| 26 |
+
# Scale beta_t: 3D vectors for each timestep t=0..50
|
| 27 |
+
BETA_T = [
|
| 28 |
+
[0.0163, 0.0172, 0.0295], [0.0905, 0.0716, 0.0999], [0.1345, 0.1123, 0.1544],
|
| 29 |
+
[0.1826, 0.1491, 0.2065], [0.2360, 0.1899, 0.2630], [0.2904, 0.2316, 0.3202],
|
| 30 |
+
[0.3471, 0.2749, 0.3793], [0.4050, 0.3191, 0.4394], [0.4640, 0.3641, 0.5003],
|
| 31 |
+
[0.5231, 0.4091, 0.5611], [0.5834, 0.4547, 0.6228], [0.6456, 0.5016, 0.6861],
|
| 32 |
+
[0.7077, 0.5481, 0.7488], [0.7713, 0.5958, 0.8127], [0.8410, 0.6496, 0.8866],
|
| 33 |
+
[0.9119, 0.7044, 0.9616], [0.9845, 0.7605, 1.0386], [1.0578, 0.8172, 1.1163],
|
| 34 |
+
[1.1325, 0.8750, 1.1957], [1.2094, 0.9344, 1.2771], [1.2880, 0.9953, 1.3606],
|
| 35 |
+
[1.3680, 1.0571, 1.4453], [1.4498, 1.1205, 1.5321], [1.5341, 1.1858, 1.6216],
|
| 36 |
+
[1.6206, 1.2526, 1.7131], [1.7094, 1.3214, 1.8072], [1.7998, 1.3913, 1.9030],
|
| 37 |
+
[1.8927, 1.4633, 2.0014], [1.9879, 1.5370, 2.1022], [2.0854, 1.6126, 2.2056],
|
| 38 |
+
[2.1853, 1.6900, 2.3114], [2.2881, 1.7696, 2.4202], [2.3939, 1.8515, 2.5321],
|
| 39 |
+
[2.5021, 1.9354, 2.6467], [2.6133, 2.0215, 2.7642], [2.7280, 2.1106, 2.8857],
|
| 40 |
+
[2.8455, 2.2017, 3.0101], [2.9668, 2.2957, 3.1386], [3.0921, 2.3929, 3.2712],
|
| 41 |
+
[3.2204, 2.4922, 3.4067], [3.3523, 2.5946, 3.5464], [3.4888, 2.7006, 3.6911],
|
| 42 |
+
[3.6292, 2.8097, 3.8398], [3.7741, 2.9222, 3.9931], [3.9247, 3.0394, 4.1527],
|
| 43 |
+
[4.0793, 3.1597, 4.3168], [4.2393, 3.2843, 4.4866], [4.4053, 3.4142, 4.6636],
|
| 44 |
+
[4.5760, 3.5480, 4.8461], [4.7541, 3.6886, 5.0383], [4.9407, 3.8364, 5.2390],
|
| 45 |
+
]
|
| 46 |
+
|
| 47 |
+
# Pre-convert to tensors (lazily cached on first access)
|
| 48 |
+
_alpha_tensor = None
|
| 49 |
+
_beta_tensor = None
|
| 50 |
+
|
| 51 |
+
|
| 52 |
+
def get_alpha_table():
|
| 53 |
+
"""Return α_t table as tensor [51, 3], cached after first call."""
|
| 54 |
+
global _alpha_tensor
|
| 55 |
+
if _alpha_tensor is None:
|
| 56 |
+
_alpha_tensor = torch.tensor(ALPHA_T, dtype=torch.float32) # [51, 3]
|
| 57 |
+
return _alpha_tensor
|
| 58 |
+
|
| 59 |
+
|
| 60 |
+
def get_beta_table():
|
| 61 |
+
"""Return β_t table as tensor [51, 3], cached after first call."""
|
| 62 |
+
global _beta_tensor
|
| 63 |
+
if _beta_tensor is None:
|
| 64 |
+
_beta_tensor = torch.tensor(BETA_T, dtype=torch.float32) # [51, 3]
|
| 65 |
+
return _beta_tensor
|
custom_nodes/ComfyUI-LCS/core/diagnostics.py
ADDED
|
@@ -0,0 +1,246 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Diagnostic tests for LCS intervention pipeline.
|
| 2 |
+
|
| 3 |
+
This module provides tests and diagnostics to identify conditions that
|
| 4 |
+
cause image blurriness or quality degradation during LCS intervention.
|
| 5 |
+
"""
|
| 6 |
+
|
| 7 |
+
import torch
|
| 8 |
+
import math
|
| 9 |
+
from .color_space import decode_lcs_to_hsl, encode_hsl_to_lcs, _hue_lerp
|
| 10 |
+
from .timestep import get_alpha_beta, get_alpha_beta_t50, normalize_to_t50, denormalize_from_t50
|
| 11 |
+
|
| 12 |
+
# Test constants
|
| 13 |
+
_T50_REFERENCE_COORD = [0.5, 0.3, 0.1] # Typical LCS magnitude at t=50
|
| 14 |
+
_TEST_STRENGTHS = [0.0, 0.25, 0.5, 0.75, 1.0, 1.5, 2.0] # Range from none to overshoot
|
| 15 |
+
_VARIATION_SCALE = 0.5 # Scale for test patch variation
|
| 16 |
+
_NOISE_SCALE = 2.0 # Simulated diffusion noise magnitude
|
| 17 |
+
_PROBLEMATIC_AMPLIFICATION_THRESHOLD = 50 # >50x noise amplification is problematic
|
| 18 |
+
|
| 19 |
+
|
| 20 |
+
def test_round_trip_consistency(anchor_lcs, anchor_angles):
|
| 21 |
+
"""Test that encode(decode(x)) ≈ x for typical LCS coordinates.
|
| 22 |
+
|
| 23 |
+
This verifies the bicone geometry math is correct.
|
| 24 |
+
"""
|
| 25 |
+
chromatic = anchor_lcs[:6]
|
| 26 |
+
black, white = anchor_lcs[6], anchor_lcs[7]
|
| 27 |
+
|
| 28 |
+
# Test round-trip on anchor positions
|
| 29 |
+
errors = []
|
| 30 |
+
test_cases = list(chromatic) # All 6 chromatic anchors
|
| 31 |
+
|
| 32 |
+
# Add some mid-tones and random points
|
| 33 |
+
for _ in range(5):
|
| 34 |
+
# Generate random LCS point
|
| 35 |
+
h = torch.rand(1).item()
|
| 36 |
+
s = torch.rand(1).item()
|
| 37 |
+
l = torch.rand(1).item()
|
| 38 |
+
c = encode_hsl_to_lcs(
|
| 39 |
+
torch.tensor(h), torch.tensor(s), torch.tensor(l),
|
| 40 |
+
anchor_lcs, anchor_angles
|
| 41 |
+
)
|
| 42 |
+
test_cases.append(c)
|
| 43 |
+
|
| 44 |
+
for c in test_cases:
|
| 45 |
+
h, s, l = decode_lcs_to_hsl(c, anchor_lcs, anchor_angles)
|
| 46 |
+
c_round = encode_hsl_to_lcs(h, s, l, anchor_lcs, anchor_angles)
|
| 47 |
+
error = (c - c_round).norm().item()
|
| 48 |
+
errors.append(error)
|
| 49 |
+
|
| 50 |
+
max_error = max(errors)
|
| 51 |
+
avg_error = sum(errors) / len(errors)
|
| 52 |
+
return {
|
| 53 |
+
"max_round_trip_error": max_error,
|
| 54 |
+
"avg_round_trip_error": avg_error,
|
| 55 |
+
"passed": max_error < 1e-4,
|
| 56 |
+
"errors": errors,
|
| 57 |
+
}
|
| 58 |
+
|
| 59 |
+
|
| 60 |
+
def test_normalization_stability():
|
| 61 |
+
"""Test that normalize/denormalize round-trip is stable across all timesteps.
|
| 62 |
+
|
| 63 |
+
Identifies timesteps where numerical instability could cause issues.
|
| 64 |
+
"""
|
| 65 |
+
# Sample LCS coordinates at t=50 (clean image reference)
|
| 66 |
+
c_t50 = torch.tensor(_T50_REFERENCE_COORD, dtype=torch.float32)
|
| 67 |
+
alpha_50, beta_50 = get_alpha_beta_t50()
|
| 68 |
+
|
| 69 |
+
results = []
|
| 70 |
+
for t in range(51):
|
| 71 |
+
sigma = 1.0 - t / 50.0 # sigma = 1 - t/50
|
| 72 |
+
alpha_t, beta_t = get_alpha_beta(sigma)
|
| 73 |
+
|
| 74 |
+
# Normalize then denormalize
|
| 75 |
+
c_norm = normalize_to_t50(c_t50, alpha_t, beta_t, alpha_50, beta_50)
|
| 76 |
+
c_back = denormalize_from_t50(c_norm, alpha_t, beta_t, alpha_50, beta_50)
|
| 77 |
+
|
| 78 |
+
error = (c_t50 - c_back).norm().item()
|
| 79 |
+
|
| 80 |
+
# Check amplification factor
|
| 81 |
+
amplification = (beta_50 / beta_t).max().item()
|
| 82 |
+
|
| 83 |
+
results.append({
|
| 84 |
+
"t": t,
|
| 85 |
+
"sigma": sigma,
|
| 86 |
+
"beta_t_min": beta_t.min().item(),
|
| 87 |
+
"amplification": amplification,
|
| 88 |
+
"round_trip_error": error,
|
| 89 |
+
})
|
| 90 |
+
|
| 91 |
+
return results
|
| 92 |
+
|
| 93 |
+
|
| 94 |
+
def test_type_ii_uniformity(anchor_lcs, anchor_angles):
|
| 95 |
+
"""Test if Type II intervention at high strength produces uniform outputs.
|
| 96 |
+
|
| 97 |
+
This is a key diagnostic for the blurriness issue - if all patches
|
| 98 |
+
converge to the same HSL values, the image loses detail.
|
| 99 |
+
"""
|
| 100 |
+
# Create diverse patch set (simulate image with color variation)
|
| 101 |
+
patches = torch.randn(100, 3) * _VARIATION_SCALE + torch.tensor([0.3, 0.2, 0.1])
|
| 102 |
+
|
| 103 |
+
# Target color (e.g., saturated red)
|
| 104 |
+
t_h, t_s, t_l = 0.0, 1.0, 0.5
|
| 105 |
+
|
| 106 |
+
# Decode all patches ONCE (constant across strengths)
|
| 107 |
+
h_cur, s_cur, l_cur = decode_lcs_to_hsl(patches, anchor_lcs, anchor_angles)
|
| 108 |
+
|
| 109 |
+
# Target HSL tensors
|
| 110 |
+
h_new = torch.full_like(h_cur, t_h)
|
| 111 |
+
s_new = torch.full_like(s_cur, t_s)
|
| 112 |
+
l_new = torch.full_like(l_cur, t_l)
|
| 113 |
+
|
| 114 |
+
# Compute input variance once (patches never changes)
|
| 115 |
+
input_var = patches.var(dim=0).mean().item()
|
| 116 |
+
|
| 117 |
+
# Test different strengths
|
| 118 |
+
for strength in _TEST_STRENGTHS:
|
| 119 |
+
# Hue lerp using shared helper
|
| 120 |
+
h_interp = _hue_lerp(h_cur, h_new, strength)
|
| 121 |
+
s_interp = (s_cur + strength * (s_new - s_cur)).clamp(0, 1)
|
| 122 |
+
l_interp = (l_cur + strength * (l_new - l_cur)).clamp(0, 1)
|
| 123 |
+
|
| 124 |
+
# Re-encode
|
| 125 |
+
new_patches = encode_hsl_to_lcs(h_interp, s_interp, l_interp, anchor_lcs, anchor_angles)
|
| 126 |
+
|
| 127 |
+
# Measure variance loss
|
| 128 |
+
output_var = new_patches.var(dim=0).mean().item()
|
| 129 |
+
var_ratio = output_var / (input_var + 1e-10)
|
| 130 |
+
|
| 131 |
+
# Check how many unique HSL values we end up with
|
| 132 |
+
h_unique = len(torch.unique(h_interp.round(decimals=3)))
|
| 133 |
+
s_unique = len(torch.unique(s_interp.round(decimals=3)))
|
| 134 |
+
l_unique = len(torch.unique(l_interp.round(decimals=3)))
|
| 135 |
+
|
| 136 |
+
print(f"strength={strength:.2f}: var_ratio={var_ratio:.3f}, "
|
| 137 |
+
f"unique_h={h_unique}, unique_s={s_unique}, unique_l={l_unique}")
|
| 138 |
+
|
| 139 |
+
|
| 140 |
+
def test_early_timestep_amplification():
|
| 141 |
+
"""Test numerical behavior at very early timesteps (high sigma).
|
| 142 |
+
|
| 143 |
+
At t≈0 (sigma≈1), beta_t is very small, causing large amplification
|
| 144 |
+
in normalize_to_t50. This could amplify noise and corrupt the signal.
|
| 145 |
+
"""
|
| 146 |
+
# Typical LCS coordinate magnitude at t=50
|
| 147 |
+
c_ref = torch.tensor(_T50_REFERENCE_COORD, dtype=torch.float32)
|
| 148 |
+
alpha_50, beta_50 = get_alpha_beta_t50() # Constant across all sigmas
|
| 149 |
+
|
| 150 |
+
for sigma in [1.0, 0.99, 0.95, 0.90, 0.85, 0.80, 0.50, 0.0]:
|
| 151 |
+
alpha_t, beta_t = get_alpha_beta(sigma)
|
| 152 |
+
|
| 153 |
+
# Simulate a noisy observation at timestep t
|
| 154 |
+
# In diffusion, the observation is alpha_t * clean + beta_t * noise
|
| 155 |
+
# At high sigma, noise dominates
|
| 156 |
+
noise = torch.randn(3) * _NOISE_SCALE
|
| 157 |
+
c_observed = alpha_t + beta_t * c_ref + beta_t * noise
|
| 158 |
+
|
| 159 |
+
# Normalize to t=50
|
| 160 |
+
c_norm = normalize_to_t50(c_observed, alpha_t, beta_t, alpha_50, beta_50)
|
| 161 |
+
|
| 162 |
+
# Measure deviation from reference
|
| 163 |
+
deviation = (c_norm - c_ref).norm().item()
|
| 164 |
+
amplification = (beta_50 / beta_t).max().item()
|
| 165 |
+
|
| 166 |
+
print(f"sigma={sigma:.2f}: beta_t={beta_t.numpy()}, "
|
| 167 |
+
f"amplification={amplification:.1f}x, deviation={deviation:.3f}")
|
| 168 |
+
|
| 169 |
+
|
| 170 |
+
def analyze_blurriness_causes(lcs_data_path=None):
|
| 171 |
+
"""Comprehensive analysis of all potential blurriness causes."""
|
| 172 |
+
print("=" * 60)
|
| 173 |
+
print("LCS INTERVENTION BLURRINESS ANALYSIS")
|
| 174 |
+
print("=" * 60)
|
| 175 |
+
|
| 176 |
+
# Load actual calibration data
|
| 177 |
+
if lcs_data_path is None:
|
| 178 |
+
from pathlib import Path
|
| 179 |
+
data_dir = Path(__file__).parent.parent / "data"
|
| 180 |
+
safetensors_files = list(data_dir.glob("lcs_*.safetensors"))
|
| 181 |
+
if safetensors_files:
|
| 182 |
+
lcs_data_path = safetensors_files[0]
|
| 183 |
+
else:
|
| 184 |
+
print("ERROR: No calibration data found. Run LCSLoadData with calibrate=True first.")
|
| 185 |
+
return
|
| 186 |
+
|
| 187 |
+
from safetensors.torch import load_file
|
| 188 |
+
data = load_file(lcs_data_path)
|
| 189 |
+
anchor_lcs = data["anchor_lcs"]
|
| 190 |
+
anchor_angles = data["anchor_angles"]
|
| 191 |
+
|
| 192 |
+
print(f"\nLoaded calibration data from: {lcs_data_path}")
|
| 193 |
+
print(f"anchor_lcs shape: {anchor_lcs.shape}")
|
| 194 |
+
print(f"anchor_angles shape: {anchor_angles.shape}")
|
| 195 |
+
|
| 196 |
+
print("\n1. ROUND-TRIP CONSISTENCY TEST")
|
| 197 |
+
print("-" * 40)
|
| 198 |
+
result = test_round_trip_consistency(anchor_lcs, anchor_angles)
|
| 199 |
+
print(f"Max error: {result['max_round_trip_error']:.2e}")
|
| 200 |
+
print(f"Avg error: {result['avg_round_trip_error']:.2e}")
|
| 201 |
+
print(f"Status: {'PASS' if result['passed'] else 'FAIL'}")
|
| 202 |
+
|
| 203 |
+
print("\n2. NORMALIZATION STABILITY TEST")
|
| 204 |
+
print("-" * 40)
|
| 205 |
+
norm_results = test_normalization_stability()
|
| 206 |
+
problematic = [r for r in norm_results if r['amplification'] > _PROBLEMATIC_AMPLIFICATION_THRESHOLD]
|
| 207 |
+
print(f"Timesteps with >{_PROBLEMATIC_AMPLIFICATION_THRESHOLD}x amplification: {len(problematic)}")
|
| 208 |
+
for r in problematic[:5]:
|
| 209 |
+
print(f" t={r['t']:2d} (sigma={r['sigma']:.2f}): amp={r['amplification']:.1f}x")
|
| 210 |
+
|
| 211 |
+
print("\n3. TYPE II UNIFORMITY TEST")
|
| 212 |
+
print("-" * 40)
|
| 213 |
+
test_type_ii_uniformity(anchor_lcs, anchor_angles)
|
| 214 |
+
|
| 215 |
+
print("\n4. EARLY TIMESTEP AMPLIFICATION TEST")
|
| 216 |
+
print("-" * 40)
|
| 217 |
+
test_early_timestep_amplification()
|
| 218 |
+
|
| 219 |
+
print("\n" + "=" * 60)
|
| 220 |
+
print("CONCLUSIONS")
|
| 221 |
+
print("=" * 60)
|
| 222 |
+
print("""
|
| 223 |
+
Potential blurriness causes identified:
|
| 224 |
+
|
| 225 |
+
1. TYPE II AT HIGH STRENGTH: At strength=1.0, all patches get the same
|
| 226 |
+
target HSL, destroying spatial color variation. This is the PRIMARY
|
| 227 |
+
cause of blur in type_ii mode.
|
| 228 |
+
|
| 229 |
+
2. EARLY TIMESTEP AMPLIFICATION: At sigma>0.95 (t<2.5), beta_t is ~0.02,
|
| 230 |
+
causing ~250x amplification of noise. Intervening too early (step 0-2)
|
| 231 |
+
will corrupt the signal.
|
| 232 |
+
|
| 233 |
+
3. OVERSHOOTING: strength>1.0 overshoots the target, potentially pushing
|
| 234 |
+
values outside the valid color gamut. This can cause clipping and
|
| 235 |
+
artifacts.
|
| 236 |
+
|
| 237 |
+
RECOMMENDATIONS:
|
| 238 |
+
- For type_ii mode, use strength<0.8 to preserve some original variation
|
| 239 |
+
- Avoid intervening before step 5 (sigma<0.90)
|
| 240 |
+
- For interpolated mode, the gamma=sigma blending naturally limits damage
|
| 241 |
+
at early steps
|
| 242 |
+
""")
|
| 243 |
+
|
| 244 |
+
|
| 245 |
+
if __name__ == "__main__":
|
| 246 |
+
analyze_blurriness_causes()
|
custom_nodes/ComfyUI-LCS/core/lcs_data.py
ADDED
|
@@ -0,0 +1,28 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from dataclasses import dataclass
|
| 2 |
+
import torch
|
| 3 |
+
|
| 4 |
+
|
| 5 |
+
@dataclass
|
| 6 |
+
class LCSData:
|
| 7 |
+
"""Calibration data for the Latent Color Subspace.
|
| 8 |
+
|
| 9 |
+
Produced by PCA on FLUX VAE-encoded solid-color images. Flows between
|
| 10 |
+
all LCS nodes as the shared LCS_DATA custom type.
|
| 11 |
+
"""
|
| 12 |
+
|
| 13 |
+
basis: torch.Tensor # [64, 3] PCA basis B (orthonormal columns)
|
| 14 |
+
mean: torch.Tensor # [64] PCA mean mu
|
| 15 |
+
anchor_lcs: torch.Tensor # [8, 3] LCS coords of 8 anchor colors [R,B,G,M,C,Y,Black,White]
|
| 16 |
+
anchor_angles: torch.Tensor # [6] hue angles (radians) of the 6 chromatic anchors
|
| 17 |
+
|
| 18 |
+
def to(self, device, dtype=None):
|
| 19 |
+
"""Move all tensors to device/dtype."""
|
| 20 |
+
kw = {"device": device}
|
| 21 |
+
if dtype is not None:
|
| 22 |
+
kw["dtype"] = dtype
|
| 23 |
+
return LCSData(
|
| 24 |
+
basis=self.basis.to(**kw),
|
| 25 |
+
mean=self.mean.to(**kw),
|
| 26 |
+
anchor_lcs=self.anchor_lcs.to(**kw),
|
| 27 |
+
anchor_angles=self.anchor_angles.to(**kw),
|
| 28 |
+
)
|
custom_nodes/ComfyUI-LCS/core/patchify.py
ADDED
|
@@ -0,0 +1,93 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Patchify/unpatchify for latent tensors (patch_size=2, auto-detect channels).
|
| 2 |
+
|
| 3 |
+
Handles 3D, 4D, and 5D inputs. Pads odd spatial dims to even before patchifying.
|
| 4 |
+
"""
|
| 5 |
+
|
| 6 |
+
from einops import rearrange
|
| 7 |
+
import torch.nn.functional as F
|
| 8 |
+
|
| 9 |
+
|
| 10 |
+
def patchify(x):
|
| 11 |
+
"""Convert latent [C, H, W], [B, C, H, W], or [B, C, T, H, W] → patch sequence [B, L, C*4].
|
| 12 |
+
|
| 13 |
+
Handles three input formats:
|
| 14 |
+
- 3D [C, H, W]: adds batch dim, extra_shape="unbatched"
|
| 15 |
+
- 4D [B, C, H, W]: standard path, extra_shape=None
|
| 16 |
+
- 5D [B, C, T, H, W]: video VAE, merges T into batch, extra_shape=(B, C, T)
|
| 17 |
+
|
| 18 |
+
Pads odd H/W to even before patchifying. The pad amounts are stored
|
| 19 |
+
in the returned extra_shape for unpatchify to crop back.
|
| 20 |
+
|
| 21 |
+
L = (H_padded/2) * (W_padded/2), d = C * 2 * 2.
|
| 22 |
+
"""
|
| 23 |
+
extra_shape = None
|
| 24 |
+
pad_h = 0
|
| 25 |
+
pad_w = 0
|
| 26 |
+
|
| 27 |
+
if x.ndim == 3:
|
| 28 |
+
extra_shape = "unbatched"
|
| 29 |
+
x = x.unsqueeze(0)
|
| 30 |
+
elif x.ndim == 5:
|
| 31 |
+
B_orig, C, T, H, W = x.shape
|
| 32 |
+
extra_shape = (B_orig, C, T)
|
| 33 |
+
x = x.permute(0, 2, 1, 3, 4).reshape(B_orig * T, C, H, W)
|
| 34 |
+
|
| 35 |
+
B, C, H, W = x.shape
|
| 36 |
+
if H < 1 or W < 1:
|
| 37 |
+
return None, None, None, None
|
| 38 |
+
|
| 39 |
+
# Pad odd dimensions to even (replicate last row/col)
|
| 40 |
+
if H % 2 != 0:
|
| 41 |
+
pad_h = 1
|
| 42 |
+
if W % 2 != 0:
|
| 43 |
+
pad_w = 1
|
| 44 |
+
if pad_h or pad_w:
|
| 45 |
+
x = F.pad(x, (0, pad_w, 0, pad_h), mode="replicate")
|
| 46 |
+
|
| 47 |
+
H_p, W_p = x.shape[2], x.shape[3]
|
| 48 |
+
h_len = H_p // 2
|
| 49 |
+
w_len = W_p // 2
|
| 50 |
+
patches = rearrange(x, "b c (h ph) (w pw) -> b (h w) (c ph pw)", ph=2, pw=2)
|
| 51 |
+
|
| 52 |
+
# Bundle pad info with extra_shape
|
| 53 |
+
if pad_h or pad_w:
|
| 54 |
+
extra_shape = {"orig_extra": extra_shape, "pad_h": pad_h, "pad_w": pad_w}
|
| 55 |
+
|
| 56 |
+
return patches, h_len, w_len, extra_shape
|
| 57 |
+
|
| 58 |
+
|
| 59 |
+
def unpatchify(patches, h_len, w_len, extra_shape=None):
|
| 60 |
+
"""Convert patch sequence [B, L, C*4] → latent, restoring original shape.
|
| 61 |
+
|
| 62 |
+
Auto-detects channel count from patch dimension: C = D / 4.
|
| 63 |
+
Handles padding removal and 3D/5D restoration based on extra_shape.
|
| 64 |
+
"""
|
| 65 |
+
D = patches.shape[-1]
|
| 66 |
+
C = D // 4 # patch_size=2×2=4
|
| 67 |
+
x = rearrange(patches, "b (h w) (c ph pw) -> b c (h ph) (w pw)",
|
| 68 |
+
h=h_len, w=w_len, c=C, ph=2, pw=2)
|
| 69 |
+
|
| 70 |
+
# Unwrap pad info if present
|
| 71 |
+
pad_h = 0
|
| 72 |
+
pad_w = 0
|
| 73 |
+
orig_extra = extra_shape
|
| 74 |
+
if isinstance(extra_shape, dict):
|
| 75 |
+
pad_h = extra_shape["pad_h"]
|
| 76 |
+
pad_w = extra_shape["pad_w"]
|
| 77 |
+
orig_extra = extra_shape["orig_extra"]
|
| 78 |
+
|
| 79 |
+
# Remove padding
|
| 80 |
+
if pad_h:
|
| 81 |
+
x = x[:, :, :-pad_h, :]
|
| 82 |
+
if pad_w:
|
| 83 |
+
x = x[:, :, :, :-pad_w]
|
| 84 |
+
|
| 85 |
+
# Restore original format
|
| 86 |
+
if orig_extra == "unbatched":
|
| 87 |
+
x = x.squeeze(0)
|
| 88 |
+
elif orig_extra is not None:
|
| 89 |
+
B_orig, C_orig, T = orig_extra
|
| 90 |
+
H, W = x.shape[2], x.shape[3]
|
| 91 |
+
x = x.reshape(B_orig, T, C_orig, H, W).permute(0, 2, 1, 3, 4)
|
| 92 |
+
|
| 93 |
+
return x
|
custom_nodes/ComfyUI-LCS/core/relationships.py
ADDED
|
@@ -0,0 +1,117 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Local color relationship analysis for drift detection and correction."""
|
| 2 |
+
|
| 3 |
+
import torch
|
| 4 |
+
import torch.nn.functional as F
|
| 5 |
+
|
| 6 |
+
|
| 7 |
+
def compute_local_relationships(c, h_len, w_len, kernel_radius=2):
|
| 8 |
+
"""Compute per-patch relationship vector from 5x5 neighborhood.
|
| 9 |
+
|
| 10 |
+
For each patch, cosine similarity with each of up to 24 neighbors.
|
| 11 |
+
Returns [B, L, N_neighbors] relationship vectors where N_neighbors = (2*r+1)^2 - 1.
|
| 12 |
+
"""
|
| 13 |
+
B = c.shape[0]
|
| 14 |
+
r = kernel_radius
|
| 15 |
+
k_size = 2 * r + 1
|
| 16 |
+
n_neighbors = k_size * k_size - 1 # 24 for r=2
|
| 17 |
+
|
| 18 |
+
# Reshape to spatial grid
|
| 19 |
+
grid = c.reshape(B, h_len, w_len, 3) # [B, H, W, 3]
|
| 20 |
+
|
| 21 |
+
# Permute to [B, 3, H, W] for padding
|
| 22 |
+
grid_chw = grid.permute(0, 3, 1, 2) # [B, 3, H, W]
|
| 23 |
+
padded = F.pad(grid_chw, (r, r, r, r), mode="replicate") # [B, 3, H+2r, W+2r]
|
| 24 |
+
|
| 25 |
+
# Center values — normalize for cosine similarity
|
| 26 |
+
center_norm = grid_chw / grid_chw.norm(dim=1, keepdim=True).clamp(min=1e-8)
|
| 27 |
+
|
| 28 |
+
# Pre-normalize padded tensor once (avoids per-neighbor normalization in loop)
|
| 29 |
+
padded_norm = padded / padded.norm(dim=1, keepdim=True).clamp(min=1e-8)
|
| 30 |
+
|
| 31 |
+
# Collect cosine similarities with each neighbor
|
| 32 |
+
similarities = []
|
| 33 |
+
for dy in range(-r, r + 1):
|
| 34 |
+
for dx in range(-r, r + 1):
|
| 35 |
+
if dy == 0 and dx == 0:
|
| 36 |
+
continue
|
| 37 |
+
y_start = r + dy
|
| 38 |
+
x_start = r + dx
|
| 39 |
+
neighbor_norm = padded_norm[:, :, y_start:y_start + h_len, x_start:x_start + w_len]
|
| 40 |
+
# Cosine similarity per pixel
|
| 41 |
+
sim = (center_norm * neighbor_norm).sum(dim=1) # [B, H, W]
|
| 42 |
+
similarities.append(sim)
|
| 43 |
+
|
| 44 |
+
# Stack to [B, H, W, N_neighbors] -> [B, L, N_neighbors]
|
| 45 |
+
rel = torch.stack(similarities, dim=-1) # [B, H, W, N_neighbors]
|
| 46 |
+
return rel.reshape(B, -1, n_neighbors)
|
| 47 |
+
|
| 48 |
+
|
| 49 |
+
def detect_anomalies_adaptive(r_current, r_reference):
|
| 50 |
+
"""Compare current vs reference relationships with adaptive threshold.
|
| 51 |
+
|
| 52 |
+
Uses per-batch robust outlier detection: threshold = median + 3.0 * 1.4826 * MAD.
|
| 53 |
+
Returns anomaly_magnitude [B, L, 1] in [0, 1].
|
| 54 |
+
"""
|
| 55 |
+
# Mean absolute difference across neighbor relationships
|
| 56 |
+
diff = (r_current - r_reference).abs().mean(dim=-1) # [B, L]
|
| 57 |
+
|
| 58 |
+
# Per-batch robust statistics
|
| 59 |
+
median = diff.median(dim=-1, keepdim=True).values # [B, 1]
|
| 60 |
+
mad = (diff - median).abs().median(dim=-1, keepdim=True).values # [B, 1]
|
| 61 |
+
threshold = median + 3.0 * 1.4826 * mad # [B, 1]
|
| 62 |
+
|
| 63 |
+
# Soft ramp above threshold, normalized to [0, 1]
|
| 64 |
+
anomaly = (diff - threshold).clamp(min=0.0) # [B, L]
|
| 65 |
+
# Normalize per-batch: max anomaly → 1.0
|
| 66 |
+
amax = anomaly.amax(dim=-1, keepdim=True).clamp(min=1e-8) # [B, 1]
|
| 67 |
+
anomaly = anomaly / amax
|
| 68 |
+
|
| 69 |
+
return anomaly.unsqueeze(-1) # [B, L, 1]
|
| 70 |
+
|
| 71 |
+
|
| 72 |
+
def infer_color_from_neighbors(c, anomaly_mag, h_len, w_len, kernel_radius=2):
|
| 73 |
+
"""For anomalous patches, infer correct color from non-anomalous neighbors.
|
| 74 |
+
|
| 75 |
+
Uses inverse-anomaly weighting: patches with low anomaly contribute more.
|
| 76 |
+
Returns [B, L, 3] corrected colors (blended: anomalous patches get
|
| 77 |
+
neighbor-inferred values, non-anomalous patches keep their original).
|
| 78 |
+
"""
|
| 79 |
+
B = c.shape[0]
|
| 80 |
+
r = kernel_radius
|
| 81 |
+
|
| 82 |
+
# Reshape to spatial grid
|
| 83 |
+
grid = c.reshape(B, h_len, w_len, 3)
|
| 84 |
+
anom_grid = anomaly_mag.reshape(B, h_len, w_len, 1)
|
| 85 |
+
|
| 86 |
+
# Pad both grid and anomaly
|
| 87 |
+
grid_chw = grid.permute(0, 3, 1, 2) # [B, 3, H, W]
|
| 88 |
+
anom_chw = anom_grid.permute(0, 3, 1, 2) # [B, 1, H, W]
|
| 89 |
+
padded_c = F.pad(grid_chw, (r, r, r, r), mode="replicate")
|
| 90 |
+
padded_a = F.pad(anom_chw, (r, r, r, r), mode="replicate")
|
| 91 |
+
|
| 92 |
+
# Weight neighbors by how non-anomalous they are
|
| 93 |
+
weight_sum = torch.zeros(B, 1, h_len, w_len, device=c.device, dtype=c.dtype)
|
| 94 |
+
value_sum = torch.zeros(B, 3, h_len, w_len, device=c.device, dtype=c.dtype)
|
| 95 |
+
|
| 96 |
+
for dy in range(-r, r + 1):
|
| 97 |
+
for dx in range(-r, r + 1):
|
| 98 |
+
if dy == 0 and dx == 0:
|
| 99 |
+
continue
|
| 100 |
+
y_start = r + dy
|
| 101 |
+
x_start = r + dx
|
| 102 |
+
neighbor_c = padded_c[:, :, y_start:y_start + h_len, x_start:x_start + w_len]
|
| 103 |
+
neighbor_a = padded_a[:, :, y_start:y_start + h_len, x_start:x_start + w_len]
|
| 104 |
+
|
| 105 |
+
# Weight: 1 - anomaly (non-anomalous neighbors get high weight)
|
| 106 |
+
w = (1.0 - neighbor_a).clamp(min=0.01) # [B, 1, H, W]
|
| 107 |
+
weight_sum = weight_sum + w
|
| 108 |
+
value_sum = value_sum + w * neighbor_c
|
| 109 |
+
|
| 110 |
+
# Inferred color from neighbors
|
| 111 |
+
inferred = value_sum / weight_sum.clamp(min=1e-8) # [B, 3, H, W]
|
| 112 |
+
inferred = inferred.permute(0, 2, 3, 1).reshape(B, -1, 3) # [B, L, 3]
|
| 113 |
+
|
| 114 |
+
# Blend: anomalous patches use inferred, non-anomalous keep original
|
| 115 |
+
# anomaly_mag is [B, L, 1], range [0, ~1]
|
| 116 |
+
blend = anomaly_mag.clamp(0, 1)
|
| 117 |
+
return c * (1.0 - blend) + inferred * blend
|
custom_nodes/ComfyUI-LCS/core/sampling.py
ADDED
|
@@ -0,0 +1,105 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Shared sampling utilities for LCS intervention hooks."""
|
| 2 |
+
|
| 3 |
+
import comfy.utils
|
| 4 |
+
import torch
|
| 5 |
+
import torch.nn.functional as F
|
| 6 |
+
|
| 7 |
+
|
| 8 |
+
def find_step_index(sigma, sigmas):
|
| 9 |
+
"""Find the step index for a given sigma value in the sigma schedule.
|
| 10 |
+
|
| 11 |
+
Uses torch.isclose for robust matching across dtype differences (e.g. bfloat16
|
| 12 |
+
sigma vs float32 sample_sigmas), with argmin fallback for edge cases.
|
| 13 |
+
"""
|
| 14 |
+
sigma_val = sigma.flatten()[0].float()
|
| 15 |
+
sigmas_f = sigmas.float()
|
| 16 |
+
matched = torch.isclose(sigmas_f, sigma_val, rtol=1e-3, atol=1e-5).nonzero()
|
| 17 |
+
if len(matched) > 0:
|
| 18 |
+
return matched[0].item()
|
| 19 |
+
return (sigmas_f - sigma_val).abs().argmin().item()
|
| 20 |
+
|
| 21 |
+
|
| 22 |
+
def denoised_to_raw(denoised, model):
|
| 23 |
+
"""Convert denoised tensor from process_in space to raw VAE space.
|
| 24 |
+
|
| 25 |
+
Uses the model's latent_format.process_out (inverse of process_in).
|
| 26 |
+
Works for any model: FLUX (scale+shift), LTXV (identity), SD (scale), etc.
|
| 27 |
+
"""
|
| 28 |
+
return model.latent_format.process_out(denoised)
|
| 29 |
+
|
| 30 |
+
|
| 31 |
+
def raw_to_denoised(raw, model):
|
| 32 |
+
"""Convert raw VAE space tensor back to process_in space.
|
| 33 |
+
|
| 34 |
+
Uses the model's latent_format.process_in.
|
| 35 |
+
"""
|
| 36 |
+
return model.latent_format.process_in(raw)
|
| 37 |
+
|
| 38 |
+
|
| 39 |
+
def unpack_video_if_needed(denoised, args):
|
| 40 |
+
"""Unpack LTXAV-style packed latents if detected.
|
| 41 |
+
|
| 42 |
+
LTXAV packs video [B,128,F,H,W] + audio [B,ch,T,freq] into [B,1,flat].
|
| 43 |
+
Returns (tensor_to_process, pack_info) where pack_info is None for
|
| 44 |
+
non-packed formats or a dict for repacking.
|
| 45 |
+
"""
|
| 46 |
+
# Detect packed format: shape [B, 1, flat] with very large last dim
|
| 47 |
+
if denoised.ndim == 3 and denoised.shape[1] == 1:
|
| 48 |
+
# Try to find latent_shapes from cond data
|
| 49 |
+
cond = args.get("cond")
|
| 50 |
+
latent_shapes = _extract_latent_shapes(cond)
|
| 51 |
+
if latent_shapes is not None and len(latent_shapes) > 1:
|
| 52 |
+
tensors = comfy.utils.unpack_latents(denoised, latent_shapes)
|
| 53 |
+
# tensors[0] = video [B, 128, F, H, W], tensors[1] = audio [B, ch, T, freq]
|
| 54 |
+
return tensors[0], {"other_tensors": tensors[1:]}
|
| 55 |
+
return denoised, None
|
| 56 |
+
|
| 57 |
+
|
| 58 |
+
def repack_video_if_needed(modified, pack_info):
|
| 59 |
+
"""Repack video tensor back into LTXAV packed format if it was unpacked.
|
| 60 |
+
|
| 61 |
+
modified: the video tensor after intervention [B, 128, F, H, W]
|
| 62 |
+
pack_info: from unpack_video_if_needed
|
| 63 |
+
"""
|
| 64 |
+
if pack_info is None:
|
| 65 |
+
return modified
|
| 66 |
+
all_tensors = [modified] + pack_info["other_tensors"]
|
| 67 |
+
packed, _ = comfy.utils.pack_latents(all_tensors)
|
| 68 |
+
return packed
|
| 69 |
+
|
| 70 |
+
|
| 71 |
+
def downsample_mask(mask, h_len, w_len, device, dtype):
|
| 72 |
+
"""Downsample a mask to patch grid and flatten to [1, L, 1]."""
|
| 73 |
+
mask_dev = mask.to(device=device, dtype=dtype)
|
| 74 |
+
if mask_dev.ndim == 3:
|
| 75 |
+
mask_dev = mask_dev[:1]
|
| 76 |
+
if mask_dev.ndim == 2:
|
| 77 |
+
mask_4d = mask_dev.unsqueeze(0).unsqueeze(0) # [1, 1, H, W]
|
| 78 |
+
elif mask_dev.ndim == 3:
|
| 79 |
+
mask_4d = mask_dev.unsqueeze(1) # [B, 1, H, W]
|
| 80 |
+
else:
|
| 81 |
+
mask_4d = mask_dev
|
| 82 |
+
mask_resized = F.interpolate(
|
| 83 |
+
mask_4d, size=(h_len, w_len), mode="bilinear", align_corners=False
|
| 84 |
+
)
|
| 85 |
+
return mask_resized.reshape(1, -1, 1) # [1, L, 1]
|
| 86 |
+
|
| 87 |
+
|
| 88 |
+
def _extract_latent_shapes(cond):
|
| 89 |
+
"""Try to extract latent_shapes from conditioning data.
|
| 90 |
+
|
| 91 |
+
After convert_cond, cond is a list of dicts with 'model_conds' containing
|
| 92 |
+
CONDConstant-wrapped values like 'latent_shapes'.
|
| 93 |
+
"""
|
| 94 |
+
if cond is None:
|
| 95 |
+
return None
|
| 96 |
+
for c in cond:
|
| 97 |
+
if isinstance(c, dict):
|
| 98 |
+
model_conds = c.get('model_conds', {})
|
| 99 |
+
if 'latent_shapes' in model_conds:
|
| 100 |
+
ls = model_conds['latent_shapes']
|
| 101 |
+
# CONDConstant wraps the value in .cond
|
| 102 |
+
if hasattr(ls, 'cond'):
|
| 103 |
+
return ls.cond
|
| 104 |
+
return ls
|
| 105 |
+
return None
|
custom_nodes/ComfyUI-LCS/core/sharpness.py
ADDED
|
@@ -0,0 +1,213 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Sharpness subspace calibration via sinusoidal grating stimuli.
|
| 2 |
+
|
| 3 |
+
Replaces the previous Gaussian blur approach with narrowband frequency
|
| 4 |
+
gratings, which achieve higher linearity (R²=0.94 vs 0.88) because each
|
| 5 |
+
stimulus contains a single spatial frequency — a purer probe of the VAE's
|
| 6 |
+
frequency encoding axis.
|
| 7 |
+
|
| 8 |
+
The two methods discover the same 1D subspace (|cos|=0.986, 9.7° apart),
|
| 9 |
+
but grating stimuli yield a cleaner PC1 direction.
|
| 10 |
+
"""
|
| 11 |
+
|
| 12 |
+
import math
|
| 13 |
+
from dataclasses import dataclass
|
| 14 |
+
from typing import List, Optional, Tuple
|
| 15 |
+
import warnings
|
| 16 |
+
|
| 17 |
+
import torch
|
| 18 |
+
import comfy.utils
|
| 19 |
+
|
| 20 |
+
from .patchify import patchify
|
| 21 |
+
from .lcs_data import LCSData
|
| 22 |
+
|
| 23 |
+
|
| 24 |
+
@dataclass
|
| 25 |
+
class SharpnessData:
|
| 26 |
+
"""Calibration data for the sharpness subspace.
|
| 27 |
+
|
| 28 |
+
Produced by PCA on FLUX VAE-encoded sinusoidal gratings at varying
|
| 29 |
+
spatial frequencies. PC1 captures ~94% of variance with R²=0.94
|
| 30 |
+
linearity vs log₂(frequency).
|
| 31 |
+
"""
|
| 32 |
+
|
| 33 |
+
basis: torch.Tensor # [64, K] PCA basis (columns), K typically 1-2
|
| 34 |
+
mean: torch.Tensor # [64] PCA mean (in color-removed space if lcs_data was used)
|
| 35 |
+
sign: float # +1 or -1: ensures positive strength = sharper
|
| 36 |
+
lcs_basis: Optional[torch.Tensor] = None # [64, 3] LCS basis used during calibration (for re-orthogonalization)
|
| 37 |
+
|
| 38 |
+
def to(self, device, dtype=None):
|
| 39 |
+
"""Move all tensors to device/dtype."""
|
| 40 |
+
kw = {"device": device}
|
| 41 |
+
if dtype is not None:
|
| 42 |
+
kw["dtype"] = dtype
|
| 43 |
+
return SharpnessData(
|
| 44 |
+
basis=self.basis.to(**kw),
|
| 45 |
+
mean=self.mean.to(**kw),
|
| 46 |
+
sign=self.sign,
|
| 47 |
+
lcs_basis=self.lcs_basis.to(**kw) if self.lcs_basis is not None else None,
|
| 48 |
+
)
|
| 49 |
+
|
| 50 |
+
|
| 51 |
+
def _generate_grating_batch(
|
| 52 |
+
indices: List[int],
|
| 53 |
+
angles: torch.Tensor,
|
| 54 |
+
phases: torch.Tensor,
|
| 55 |
+
frequencies: Tuple[float, ...],
|
| 56 |
+
coord_x: torch.Tensor,
|
| 57 |
+
coord_y: torch.Tensor,
|
| 58 |
+
) -> torch.Tensor:
|
| 59 |
+
"""Generate a batch of sinusoidal grating stimuli by flat index.
|
| 60 |
+
|
| 61 |
+
Each flat index maps to (orientation, frequency) via divmod.
|
| 62 |
+
Returns [len(indices), 3, H, W] tensor.
|
| 63 |
+
"""
|
| 64 |
+
num_freqs = len(frequencies)
|
| 65 |
+
batch = []
|
| 66 |
+
for idx in indices:
|
| 67 |
+
ori = idx // num_freqs
|
| 68 |
+
freq = frequencies[idx % num_freqs]
|
| 69 |
+
angle = angles[ori].item()
|
| 70 |
+
phase = phases[ori].item()
|
| 71 |
+
cos_a, sin_a = math.cos(angle), math.sin(angle)
|
| 72 |
+
coord = coord_x * cos_a + coord_y * sin_a
|
| 73 |
+
grating = 0.5 + 0.3 * torch.sin(2 * math.pi * freq * coord + phase)
|
| 74 |
+
batch.append(grating.unsqueeze(0).expand(3, -1, -1))
|
| 75 |
+
return torch.stack(batch, dim=0)
|
| 76 |
+
|
| 77 |
+
|
| 78 |
+
def calibrate_sharpness(vae, num_samples: int = 64, image_size: int = 512,
|
| 79 |
+
frequencies: Tuple[float, ...] = (1, 2, 4, 8, 16, 32, 64),
|
| 80 |
+
batch_size: int = 8,
|
| 81 |
+
lcs_data: LCSData = None,
|
| 82 |
+
# Legacy parameter — accepted but ignored
|
| 83 |
+
blur_levels: Optional[Tuple[float, ...]] = None,
|
| 84 |
+
) -> SharpnessData:
|
| 85 |
+
"""Compute sharpness subspace data (PCA basis, mean, sign) from FLUX VAE.
|
| 86 |
+
|
| 87 |
+
Generates sinusoidal gratings at varying spatial frequencies (one pure
|
| 88 |
+
frequency per stimulus), VAE-encodes them, and runs PCA to find the
|
| 89 |
+
sharpness/frequency direction in 64D patch space.
|
| 90 |
+
|
| 91 |
+
Args:
|
| 92 |
+
vae: ComfyUI VAE object
|
| 93 |
+
num_samples: Number of orientations (each combined with all frequencies)
|
| 94 |
+
image_size: Size of generated images
|
| 95 |
+
frequencies: Spatial frequencies in cycles/image
|
| 96 |
+
batch_size: Batch size for VAE encoding
|
| 97 |
+
lcs_data: Optional LCS data for removing color component during calibration.
|
| 98 |
+
When provided, the sharpness PC1 will be orthogonal to the color subspace,
|
| 99 |
+
preventing color shifts during intervention.
|
| 100 |
+
|
| 101 |
+
Returns: SharpnessData
|
| 102 |
+
"""
|
| 103 |
+
if blur_levels is not None:
|
| 104 |
+
warnings.warn(
|
| 105 |
+
"blur_levels is deprecated and ignored; calibration now uses sinusoidal gratings",
|
| 106 |
+
DeprecationWarning, stacklevel=2,
|
| 107 |
+
)
|
| 108 |
+
|
| 109 |
+
n_freqs = len(frequencies)
|
| 110 |
+
total_images = num_samples * n_freqs
|
| 111 |
+
|
| 112 |
+
print(f"\n[LCS Sharpness Calibration] Starting: {num_samples} orientations × {n_freqs} frequencies = {total_images} stimuli")
|
| 113 |
+
print(f"[LCS Sharpness Calibration] Frequencies: {list(frequencies)} cycles/image")
|
| 114 |
+
|
| 115 |
+
# Pre-compute shared state for grating generation
|
| 116 |
+
gen = torch.Generator().manual_seed(42)
|
| 117 |
+
angles = torch.rand(num_samples, generator=gen) * math.pi # [0, π)
|
| 118 |
+
phases = torch.rand(num_samples, generator=gen) * 2 * math.pi # [0, 2π)
|
| 119 |
+
y_coords = torch.linspace(-0.5, 0.5, image_size).unsqueeze(1)
|
| 120 |
+
x_coords = torch.linspace(-0.5, 0.5, image_size).unsqueeze(0)
|
| 121 |
+
coord_y = y_coords.expand(image_size, image_size)
|
| 122 |
+
coord_x = x_coords.expand(image_size, image_size)
|
| 123 |
+
|
| 124 |
+
# Build frequency labels for all stimuli (flat index → frequency)
|
| 125 |
+
freq_labels = [frequencies[idx % n_freqs] for idx in range(total_images)]
|
| 126 |
+
freq_labels_t = torch.tensor(freq_labels, dtype=torch.float32)
|
| 127 |
+
log_freq = torch.log2(freq_labels_t.clamp(min=0.5))
|
| 128 |
+
|
| 129 |
+
# Generate stimuli lazily per batch and VAE encode
|
| 130 |
+
vectors = []
|
| 131 |
+
pbar = comfy.utils.ProgressBar(total_images)
|
| 132 |
+
|
| 133 |
+
for batch_start in range(0, total_images, batch_size):
|
| 134 |
+
batch_end = min(batch_start + batch_size, total_images)
|
| 135 |
+
indices = list(range(batch_start, batch_end))
|
| 136 |
+
batch = _generate_grating_batch(indices, angles, phases, frequencies, coord_x, coord_y)
|
| 137 |
+
actual_batch = batch.shape[0]
|
| 138 |
+
|
| 139 |
+
# Convert BCHW → BHWC for ComfyUI VAE
|
| 140 |
+
imgs_bhwc = batch.permute(0, 2, 3, 1).contiguous().cpu()
|
| 141 |
+
|
| 142 |
+
# VAE encode — try batch first, fall back to per-image for video VAEs
|
| 143 |
+
latent = vae.encode(imgs_bhwc)
|
| 144 |
+
patches, _, _, _ = patchify(latent)
|
| 145 |
+
avg = patches.mean(dim=1).cpu()
|
| 146 |
+
|
| 147 |
+
if avg.shape[0] == actual_batch:
|
| 148 |
+
vectors.extend(avg.unbind(0))
|
| 149 |
+
else:
|
| 150 |
+
# Video VAE: batch not fully supported, encode one by one
|
| 151 |
+
vectors.extend(avg.unbind(0))
|
| 152 |
+
for k in range(1, actual_batch):
|
| 153 |
+
single = imgs_bhwc[k:k+1]
|
| 154 |
+
lat = vae.encode(single)
|
| 155 |
+
p, _, _, _ = patchify(lat)
|
| 156 |
+
vectors.append(p.mean(dim=1).cpu().squeeze(0))
|
| 157 |
+
|
| 158 |
+
pbar.update(actual_batch)
|
| 159 |
+
|
| 160 |
+
# Stack all vectors: [N, 64]
|
| 161 |
+
X = torch.stack(vectors, dim=0).float()
|
| 162 |
+
print(f"[LCS Sharpness Calibration] Collected {X.shape[0]} vectors of dimension {X.shape[1]}")
|
| 163 |
+
|
| 164 |
+
# Remove LCS color component FIRST, in the raw space where LCS was calibrated.
|
| 165 |
+
# This must happen before per-vector DC removal, because the LCS basis has
|
| 166 |
+
# significant DC components (PC1 ≈ brightness). Doing DC removal first would
|
| 167 |
+
# shift vectors into a different space where B^T(x - mu) is incorrect.
|
| 168 |
+
if lcs_data is not None:
|
| 169 |
+
print("[LCS Sharpness Calibration] Removing LCS color component...")
|
| 170 |
+
lcs_mean = lcs_data.mean.to(X.device, X.dtype)
|
| 171 |
+
lcs_basis = lcs_data.basis.to(X.device, X.dtype)
|
| 172 |
+
# Project out color: X' = X - B B^T (X - mu)
|
| 173 |
+
centered = X - lcs_mean
|
| 174 |
+
lcs_coords = centered @ lcs_basis # [N, 3]
|
| 175 |
+
X = X - lcs_coords @ lcs_basis.T
|
| 176 |
+
print("[LCS Sharpness Calibration] Color component removed")
|
| 177 |
+
|
| 178 |
+
# Remove per-vector DC AFTER color removal.
|
| 179 |
+
# VAE encoding shifts the latent mean depending on stimulus content.
|
| 180 |
+
# Per-vector zero-mean forces PCA to find patterns in the relative channel
|
| 181 |
+
# structure, not in the absolute level.
|
| 182 |
+
X = X - X.mean(dim=1, keepdim=True)
|
| 183 |
+
|
| 184 |
+
# Step 3: PCA
|
| 185 |
+
print("[LCS Sharpness Calibration] Computing PCA...")
|
| 186 |
+
mean = X.mean(dim=0) # [64]
|
| 187 |
+
X_centered = X - mean
|
| 188 |
+
U, S, Vh = torch.linalg.svd(X_centered, full_matrices=False)
|
| 189 |
+
# Top 2 components
|
| 190 |
+
basis = Vh[:2].T # [64, 2]
|
| 191 |
+
|
| 192 |
+
# Variance explained
|
| 193 |
+
total_var = (S ** 2).sum()
|
| 194 |
+
explained = (S[:2] ** 2) / total_var
|
| 195 |
+
print(f"[LCS Sharpness Calibration] PC1: {explained[0]:.1%}, PC2: {explained[1]:.1%} ({(explained[0]+explained[1]):.1%} total)")
|
| 196 |
+
|
| 197 |
+
# Step 4: Determine sign convention
|
| 198 |
+
# Project all vectors onto PC1
|
| 199 |
+
pc1_scores = X_centered @ basis[:, 0] # [N]
|
| 200 |
+
|
| 201 |
+
# Correlate PC1 score with log₂(frequency)
|
| 202 |
+
# Higher frequency = sharper → if positive correlation, sign = +1
|
| 203 |
+
correlation = torch.corrcoef(torch.stack([pc1_scores, log_freq]))[0, 1]
|
| 204 |
+
sign = 1.0 if correlation > 0 else -1.0
|
| 205 |
+
print(f"[LCS Sharpness Calibration] PC1-frequency correlation: {correlation:.3f} → sign = {sign:+.0f}")
|
| 206 |
+
print(f"[LCS Sharpness Calibration] Complete! Basis shape: {basis.shape}")
|
| 207 |
+
|
| 208 |
+
return SharpnessData(
|
| 209 |
+
basis=basis,
|
| 210 |
+
mean=mean,
|
| 211 |
+
sign=sign,
|
| 212 |
+
lcs_basis=lcs_data.basis.clone() if lcs_data is not None else None,
|
| 213 |
+
)
|
custom_nodes/ComfyUI-LCS/core/timestep.py
ADDED
|
@@ -0,0 +1,75 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Sigma ↔ paper timestep conversion and α_t/β_t interpolation."""
|
| 2 |
+
|
| 3 |
+
import torch
|
| 4 |
+
from .defaults import get_alpha_table, get_beta_table
|
| 5 |
+
|
| 6 |
+
|
| 7 |
+
def sigma_to_paper_t(sigma):
|
| 8 |
+
"""Convert FLUX sigma ∈ [0,1] to paper timestep t ∈ [0,50].
|
| 9 |
+
|
| 10 |
+
sigma=1 → noise → t=0, sigma=0 → clean → t=50.
|
| 11 |
+
"""
|
| 12 |
+
if isinstance(sigma, torch.Tensor):
|
| 13 |
+
return 50.0 * (1.0 - sigma.clamp(0.0, 1.0))
|
| 14 |
+
return 50.0 * (1.0 - max(0.0, min(1.0, sigma)))
|
| 15 |
+
|
| 16 |
+
|
| 17 |
+
def get_alpha_beta(sigma, device=None):
|
| 18 |
+
"""Get interpolated α_t and β_t [3] vectors for a given sigma.
|
| 19 |
+
|
| 20 |
+
Returns (alpha_t, beta_t) as tensors on the specified device.
|
| 21 |
+
"""
|
| 22 |
+
t = sigma_to_paper_t(sigma)
|
| 23 |
+
if isinstance(t, torch.Tensor):
|
| 24 |
+
t = t.item()
|
| 25 |
+
|
| 26 |
+
alpha_table = get_alpha_table() # [51, 3]
|
| 27 |
+
beta_table = get_beta_table() # [51, 3]
|
| 28 |
+
|
| 29 |
+
t = max(0.0, min(50.0, t))
|
| 30 |
+
t_low = int(t)
|
| 31 |
+
t_high = min(t_low + 1, 50)
|
| 32 |
+
frac = t - t_low
|
| 33 |
+
|
| 34 |
+
alpha = (1.0 - frac) * alpha_table[t_low] + frac * alpha_table[t_high]
|
| 35 |
+
beta = (1.0 - frac) * beta_table[t_low] + frac * beta_table[t_high]
|
| 36 |
+
|
| 37 |
+
if device is not None:
|
| 38 |
+
alpha = alpha.to(device)
|
| 39 |
+
beta = beta.to(device)
|
| 40 |
+
return alpha, beta
|
| 41 |
+
|
| 42 |
+
|
| 43 |
+
def get_alpha_beta_t50(device=None):
|
| 44 |
+
"""Get α_50 and β_50 (reference timestep t=50, clean image)."""
|
| 45 |
+
alpha_table = get_alpha_table()
|
| 46 |
+
beta_table = get_beta_table()
|
| 47 |
+
alpha_50 = alpha_table[50]
|
| 48 |
+
beta_50 = beta_table[50]
|
| 49 |
+
if device is not None:
|
| 50 |
+
alpha_50 = alpha_50.to(device)
|
| 51 |
+
beta_50 = beta_50.to(device)
|
| 52 |
+
return alpha_50, beta_50
|
| 53 |
+
|
| 54 |
+
|
| 55 |
+
def normalize_to_t50(c, alpha_t, beta_t, alpha_50, beta_50):
|
| 56 |
+
"""Normalize LCS coords from timestep t to reference t=50.
|
| 57 |
+
|
| 58 |
+
ĉ = (c - α_t) / β_t * β_50 + α_50
|
| 59 |
+
c: [..., 3], alpha_t/beta_t/alpha_50/beta_50: [3]
|
| 60 |
+
"""
|
| 61 |
+
beta_t_safe = beta_t.clone()
|
| 62 |
+
beta_t_safe = torch.where(beta_t_safe.abs() < 1e-6,
|
| 63 |
+
torch.full_like(beta_t_safe, 1e-6), beta_t_safe)
|
| 64 |
+
return (c - alpha_t) / beta_t_safe * beta_50 + alpha_50
|
| 65 |
+
|
| 66 |
+
|
| 67 |
+
def denormalize_from_t50(c_hat, alpha_t, beta_t, alpha_50, beta_50):
|
| 68 |
+
"""Denormalize LCS coords from reference t=50 back to timestep t.
|
| 69 |
+
|
| 70 |
+
c = (ĉ - α_50) / β_50 * β_t + α_t
|
| 71 |
+
"""
|
| 72 |
+
beta_50_safe = beta_50.clone()
|
| 73 |
+
beta_50_safe = torch.where(beta_50_safe.abs() < 1e-6,
|
| 74 |
+
torch.full_like(beta_50_safe, 1e-6), beta_50_safe)
|
| 75 |
+
return (c_hat - alpha_50) / beta_50_safe * beta_t + alpha_t
|