mhnakif committed on
Commit
18cc273
·
verified ·
1 Parent(s): b049c2f

Upload folder using huggingface_hub

Browse files
This view is limited to 50 files because it contains too many changes.   See raw diff
Files changed (50) hide show
  1. .gitattributes +6 -0
  2. custom_nodes/ComfyUI-FlashVSR_Ultra_Fast/.gitignore +210 -0
  3. custom_nodes/ComfyUI-FlashVSR_Ultra_Fast/LICENSE +674 -0
  4. custom_nodes/ComfyUI-FlashVSR_Ultra_Fast/README.md +68 -0
  5. custom_nodes/ComfyUI-FlashVSR_Ultra_Fast/README_zh.md +66 -0
  6. custom_nodes/ComfyUI-FlashVSR_Ultra_Fast/__init__.py +3 -0
  7. custom_nodes/ComfyUI-FlashVSR_Ultra_Fast/img/preview.jpg +3 -0
  8. custom_nodes/ComfyUI-FlashVSR_Ultra_Fast/nodes.py +553 -0
  9. custom_nodes/ComfyUI-FlashVSR_Ultra_Fast/posi_prompt.pth +3 -0
  10. custom_nodes/ComfyUI-FlashVSR_Ultra_Fast/requirements.txt +10 -0
  11. custom_nodes/ComfyUI-FlashVSR_Ultra_Fast/src/LICENSE.txt +201 -0
  12. custom_nodes/ComfyUI-FlashVSR_Ultra_Fast/src/__init__.py +3 -0
  13. custom_nodes/ComfyUI-FlashVSR_Ultra_Fast/src/configs/__init__.py +0 -0
  14. custom_nodes/ComfyUI-FlashVSR_Ultra_Fast/src/configs/model_config.py +29 -0
  15. custom_nodes/ComfyUI-FlashVSR_Ultra_Fast/src/models/TCDecoder.py +320 -0
  16. custom_nodes/ComfyUI-FlashVSR_Ultra_Fast/src/models/__init__.py +1 -0
  17. custom_nodes/ComfyUI-FlashVSR_Ultra_Fast/src/models/model_manager.py +402 -0
  18. custom_nodes/ComfyUI-FlashVSR_Ultra_Fast/src/models/sparse_sage/LICENSE.txt +201 -0
  19. custom_nodes/ComfyUI-FlashVSR_Ultra_Fast/src/models/sparse_sage/core.py +45 -0
  20. custom_nodes/ComfyUI-FlashVSR_Ultra_Fast/src/models/sparse_sage/quant_per_block.py +101 -0
  21. custom_nodes/ComfyUI-FlashVSR_Ultra_Fast/src/models/sparse_sage/sparse_int8_attn.py +162 -0
  22. custom_nodes/ComfyUI-FlashVSR_Ultra_Fast/src/models/utils.py +462 -0
  23. custom_nodes/ComfyUI-FlashVSR_Ultra_Fast/src/models/wan_video_dit.py +864 -0
  24. custom_nodes/ComfyUI-FlashVSR_Ultra_Fast/src/models/wan_video_vae.py +847 -0
  25. custom_nodes/ComfyUI-FlashVSR_Ultra_Fast/src/pipelines/__init__.py +3 -0
  26. custom_nodes/ComfyUI-FlashVSR_Ultra_Fast/src/pipelines/base.py +130 -0
  27. custom_nodes/ComfyUI-FlashVSR_Ultra_Fast/src/pipelines/flashvsr_full.py +618 -0
  28. custom_nodes/ComfyUI-FlashVSR_Ultra_Fast/src/pipelines/flashvsr_tiny.py +615 -0
  29. custom_nodes/ComfyUI-FlashVSR_Ultra_Fast/src/pipelines/flashvsr_tiny_long.py +620 -0
  30. custom_nodes/ComfyUI-FlashVSR_Ultra_Fast/src/schedulers/__init__.py +1 -0
  31. custom_nodes/ComfyUI-FlashVSR_Ultra_Fast/src/schedulers/flow_match.py +79 -0
  32. custom_nodes/ComfyUI-FlashVSR_Ultra_Fast/src/vram_management/__init__.py +1 -0
  33. custom_nodes/ComfyUI-FlashVSR_Ultra_Fast/src/vram_management/layers.py +95 -0
  34. custom_nodes/ComfyUI-LCS/.gitignore +2 -0
  35. custom_nodes/ComfyUI-LCS/README.md +344 -0
  36. custom_nodes/ComfyUI-LCS/README_zh.md +343 -0
  37. custom_nodes/ComfyUI-LCS/__init__.py +53 -0
  38. custom_nodes/ComfyUI-LCS/core/__init__.py +10 -0
  39. custom_nodes/ComfyUI-LCS/core/adaptive.py +109 -0
  40. custom_nodes/ComfyUI-LCS/core/bilateral.py +79 -0
  41. custom_nodes/ComfyUI-LCS/core/calibration.py +214 -0
  42. custom_nodes/ComfyUI-LCS/core/color_space.py +380 -0
  43. custom_nodes/ComfyUI-LCS/core/defaults.py +65 -0
  44. custom_nodes/ComfyUI-LCS/core/diagnostics.py +246 -0
  45. custom_nodes/ComfyUI-LCS/core/lcs_data.py +28 -0
  46. custom_nodes/ComfyUI-LCS/core/patchify.py +93 -0
  47. custom_nodes/ComfyUI-LCS/core/relationships.py +117 -0
  48. custom_nodes/ComfyUI-LCS/core/sampling.py +105 -0
  49. custom_nodes/ComfyUI-LCS/core/sharpness.py +213 -0
  50. custom_nodes/ComfyUI-LCS/core/timestep.py +75 -0
.gitattributes CHANGED
@@ -142,6 +142,8 @@ models/unet/zit_beyond_reality.safetensors filter=lfs diff=lfs merge=lfs -text
142
  models/vae/ae.safetensors filter=lfs diff=lfs merge=lfs -text
143
  models/vae/flux2-vae.safetensors filter=lfs diff=lfs merge=lfs -text
144
  models/vae/wan_2.1_vae.safetensors filter=lfs diff=lfs merge=lfs -text
 
 
145
  models/FlashVSR/FlashVSR1_1.safetensors filter=lfs diff=lfs merge=lfs -text
146
  models/FlashVSR/LQ_proj_in.safetensors filter=lfs diff=lfs merge=lfs -text
147
  models/FlashVSR/Prompt.safetensors filter=lfs diff=lfs merge=lfs -text
@@ -149,3 +151,7 @@ models/FlashVSR/TCDecoder.safetensors filter=lfs diff=lfs merge=lfs -text
149
  models/FlashVSR/Wan2.1_VAE.safetensors filter=lfs diff=lfs merge=lfs -text
150
  models/FlashVSR/Wan2_1-T2V-1_3B_FlashVSR_fp32.safetensors filter=lfs diff=lfs merge=lfs -text
151
  models/FlashVSR/Wan2_1_FlashVSR_LQ_proj_model_bf16.safetensors filter=lfs diff=lfs merge=lfs -text
 
 
 
 
 
142
  models/vae/ae.safetensors filter=lfs diff=lfs merge=lfs -text
143
  models/vae/flux2-vae.safetensors filter=lfs diff=lfs merge=lfs -text
144
  models/vae/wan_2.1_vae.safetensors filter=lfs diff=lfs merge=lfs -text
145
+ custom_nodes/ComfyUI-FlashVSR_Ultra_Fast/img/preview.jpg filter=lfs diff=lfs merge=lfs -text
146
+ custom_nodes/ComfyUI-FlashVSR_Ultra_Fast/posi_prompt.pth filter=lfs diff=lfs merge=lfs -text
147
  models/FlashVSR/FlashVSR1_1.safetensors filter=lfs diff=lfs merge=lfs -text
148
  models/FlashVSR/LQ_proj_in.safetensors filter=lfs diff=lfs merge=lfs -text
149
  models/FlashVSR/Prompt.safetensors filter=lfs diff=lfs merge=lfs -text
 
151
  models/FlashVSR/Wan2.1_VAE.safetensors filter=lfs diff=lfs merge=lfs -text
152
  models/FlashVSR/Wan2_1-T2V-1_3B_FlashVSR_fp32.safetensors filter=lfs diff=lfs merge=lfs -text
153
  models/FlashVSR/Wan2_1_FlashVSR_LQ_proj_model_bf16.safetensors filter=lfs diff=lfs merge=lfs -text
154
+ models/FlashVSR-v1.1/LQ_proj_in.ckpt filter=lfs diff=lfs merge=lfs -text
155
+ models/FlashVSR-v1.1/TCDecoder.ckpt filter=lfs diff=lfs merge=lfs -text
156
+ models/FlashVSR-v1.1/Wan2.1_VAE.pth filter=lfs diff=lfs merge=lfs -text
157
+ models/FlashVSR-v1.1/diffusion_pytorch_model_streaming_dmd.safetensors filter=lfs diff=lfs merge=lfs -text
custom_nodes/ComfyUI-FlashVSR_Ultra_Fast/.gitignore ADDED
@@ -0,0 +1,210 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Byte-compiled / optimized / DLL files
2
+ __pycache__/
3
+ *.py[codz]
4
+ *$py.class
5
+
6
+ # C extensions
7
+ *.so
8
+
9
+ # Distribution / packaging
10
+ .Python
11
+ build/
12
+ develop-eggs/
13
+ dist/
14
+ downloads/
15
+ eggs/
16
+ .eggs/
17
+ lib/
18
+ lib64/
19
+ parts/
20
+ sdist/
21
+ var/
22
+ wheels/
23
+ share/python-wheels/
24
+ *.egg-info/
25
+ .installed.cfg
26
+ *.egg
27
+ MANIFEST
28
+
29
+ # PyInstaller
30
+ # Usually these files are written by a python script from a template
31
+ # before PyInstaller builds the exe, so as to inject date/other infos into it.
32
+ *.manifest
33
+ *.spec
34
+
35
+ # Installer logs
36
+ pip-log.txt
37
+ pip-delete-this-directory.txt
38
+
39
+ # Unit test / coverage reports
40
+ htmlcov/
41
+ .tox/
42
+ .nox/
43
+ .coverage
44
+ .coverage.*
45
+ .cache
46
+ nosetests.xml
47
+ coverage.xml
48
+ *.cover
49
+ *.py.cover
50
+ .hypothesis/
51
+ .pytest_cache/
52
+ cover/
53
+
54
+ # Translations
55
+ *.mo
56
+ *.pot
57
+
58
+ # Django stuff:
59
+ *.log
60
+ local_settings.py
61
+ db.sqlite3
62
+ db.sqlite3-journal
63
+
64
+ # Flask stuff:
65
+ instance/
66
+ .webassets-cache
67
+
68
+ # Scrapy stuff:
69
+ .scrapy
70
+
71
+ # Sphinx documentation
72
+ docs/_build/
73
+
74
+ # PyBuilder
75
+ .pybuilder/
76
+ target/
77
+
78
+ # Jupyter Notebook
79
+ .ipynb_checkpoints
80
+
81
+ # IPython
82
+ profile_default/
83
+ ipython_config.py
84
+
85
+ # pyenv
86
+ # For a library or package, you might want to ignore these files since the code is
87
+ # intended to run in multiple environments; otherwise, check them in:
88
+ # .python-version
89
+
90
+ # pipenv
91
+ # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
92
+ # However, in case of collaboration, if having platform-specific dependencies or dependencies
93
+ # having no cross-platform support, pipenv may install dependencies that don't work, or not
94
+ # install all needed dependencies.
95
+ #Pipfile.lock
96
+
97
+ # UV
98
+ # Similar to Pipfile.lock, it is generally recommended to include uv.lock in version control.
99
+ # This is especially recommended for binary packages to ensure reproducibility, and is more
100
+ # commonly ignored for libraries.
101
+ #uv.lock
102
+
103
+ # poetry
104
+ # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
105
+ # This is especially recommended for binary packages to ensure reproducibility, and is more
106
+ # commonly ignored for libraries.
107
+ # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
108
+ #poetry.lock
109
+ #poetry.toml
110
+
111
+ # pdm
112
+ # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
113
+ # pdm recommends including project-wide configuration in pdm.toml, but excluding .pdm-python.
114
+ # https://pdm-project.org/en/latest/usage/project/#working-with-version-control
115
+ #pdm.lock
116
+ #pdm.toml
117
+ .pdm-python
118
+ .pdm-build/
119
+
120
+ # pixi
121
+ # Similar to Pipfile.lock, it is generally recommended to include pixi.lock in version control.
122
+ #pixi.lock
123
+ # Pixi creates a virtual environment in the .pixi directory, just like venv module creates one
124
+ # in the .venv directory. It is recommended not to include this directory in version control.
125
+ .pixi
126
+
127
+ # PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
128
+ __pypackages__/
129
+
130
+ # Celery stuff
131
+ celerybeat-schedule
132
+ celerybeat.pid
133
+
134
+ # SageMath parsed files
135
+ *.sage.py
136
+
137
+ # Environments
138
+ .env
139
+ .envrc
140
+ .venv
141
+ env/
142
+ venv/
143
+ ENV/
144
+ env.bak/
145
+ venv.bak/
146
+
147
+ # Spyder project settings
148
+ .spyderproject
149
+ .spyproject
150
+
151
+ # Rope project settings
152
+ .ropeproject
153
+
154
+ # mkdocs documentation
155
+ /site
156
+
157
+ # mypy
158
+ .mypy_cache/
159
+ .dmypy.json
160
+ dmypy.json
161
+
162
+ # Pyre type checker
163
+ .pyre/
164
+
165
+ # pytype static type analyzer
166
+ .pytype/
167
+
168
+ # Cython debug symbols
169
+ cython_debug/
170
+
171
+ # PyCharm
172
+ # JetBrains specific template is maintained in a separate JetBrains.gitignore that can
173
+ # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
174
+ # and can be added to the global gitignore or merged into this file. For a more nuclear
175
+ # option (not recommended) you can uncomment the following to ignore the entire idea folder.
176
+ #.idea/
177
+
178
+ # Abstra
179
+ # Abstra is an AI-powered process automation framework.
180
+ # Ignore directories containing user credentials, local state, and settings.
181
+ # Learn more at https://abstra.io/docs
182
+ .abstra/
183
+
184
+ # Visual Studio Code
185
+ # Visual Studio Code specific template is maintained in a separate VisualStudioCode.gitignore
186
+ # that can be found at https://github.com/github/gitignore/blob/main/Global/VisualStudioCode.gitignore
187
+ # and can be added to the global gitignore or merged into this file. However, if you prefer,
188
+ # you could uncomment the following to ignore the entire vscode folder
189
+ # .vscode/
190
+
191
+ # Ruff stuff:
192
+ .ruff_cache/
193
+
194
+ # PyPI configuration file
195
+ .pypirc
196
+
197
+ # Cursor
198
+ # Cursor is an AI-powered code editor. `.cursorignore` specifies files/directories to
199
+ # exclude from AI features like autocomplete and code analysis. Recommended for sensitive data
200
+ # refer to https://docs.cursor.com/context/ignore-files
201
+ .cursorignore
202
+ .cursorindexingignore
203
+
204
+ # Marimo
205
+ marimo/_static/
206
+ marimo/_lsp/
207
+ __marimo__/
208
+
209
+ # macOS
210
+ .DS_Store
custom_nodes/ComfyUI-FlashVSR_Ultra_Fast/LICENSE ADDED
@@ -0,0 +1,674 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ GNU GENERAL PUBLIC LICENSE
2
+ Version 3, 29 June 2007
3
+
4
+ Copyright (C) 2007 Free Software Foundation, Inc. <https://fsf.org/>
5
+ Everyone is permitted to copy and distribute verbatim copies
6
+ of this license document, but changing it is not allowed.
7
+
8
+ Preamble
9
+
10
+ The GNU General Public License is a free, copyleft license for
11
+ software and other kinds of works.
12
+
13
+ The licenses for most software and other practical works are designed
14
+ to take away your freedom to share and change the works. By contrast,
15
+ the GNU General Public License is intended to guarantee your freedom to
16
+ share and change all versions of a program--to make sure it remains free
17
+ software for all its users. We, the Free Software Foundation, use the
18
+ GNU General Public License for most of our software; it applies also to
19
+ any other work released this way by its authors. You can apply it to
20
+ your programs, too.
21
+
22
+ When we speak of free software, we are referring to freedom, not
23
+ price. Our General Public Licenses are designed to make sure that you
24
+ have the freedom to distribute copies of free software (and charge for
25
+ them if you wish), that you receive source code or can get it if you
26
+ want it, that you can change the software or use pieces of it in new
27
+ free programs, and that you know you can do these things.
28
+
29
+ To protect your rights, we need to prevent others from denying you
30
+ these rights or asking you to surrender the rights. Therefore, you have
31
+ certain responsibilities if you distribute copies of the software, or if
32
+ you modify it: responsibilities to respect the freedom of others.
33
+
34
+ For example, if you distribute copies of such a program, whether
35
+ gratis or for a fee, you must pass on to the recipients the same
36
+ freedoms that you received. You must make sure that they, too, receive
37
+ or can get the source code. And you must show them these terms so they
38
+ know their rights.
39
+
40
+ Developers that use the GNU GPL protect your rights with two steps:
41
+ (1) assert copyright on the software, and (2) offer you this License
42
+ giving you legal permission to copy, distribute and/or modify it.
43
+
44
+ For the developers' and authors' protection, the GPL clearly explains
45
+ that there is no warranty for this free software. For both users' and
46
+ authors' sake, the GPL requires that modified versions be marked as
47
+ changed, so that their problems will not be attributed erroneously to
48
+ authors of previous versions.
49
+
50
+ Some devices are designed to deny users access to install or run
51
+ modified versions of the software inside them, although the manufacturer
52
+ can do so. This is fundamentally incompatible with the aim of
53
+ protecting users' freedom to change the software. The systematic
54
+ pattern of such abuse occurs in the area of products for individuals to
55
+ use, which is precisely where it is most unacceptable. Therefore, we
56
+ have designed this version of the GPL to prohibit the practice for those
57
+ products. If such problems arise substantially in other domains, we
58
+ stand ready to extend this provision to those domains in future versions
59
+ of the GPL, as needed to protect the freedom of users.
60
+
61
+ Finally, every program is threatened constantly by software patents.
62
+ States should not allow patents to restrict development and use of
63
+ software on general-purpose computers, but in those that do, we wish to
64
+ avoid the special danger that patents applied to a free program could
65
+ make it effectively proprietary. To prevent this, the GPL assures that
66
+ patents cannot be used to render the program non-free.
67
+
68
+ The precise terms and conditions for copying, distribution and
69
+ modification follow.
70
+
71
+ TERMS AND CONDITIONS
72
+
73
+ 0. Definitions.
74
+
75
+ "This License" refers to version 3 of the GNU General Public License.
76
+
77
+ "Copyright" also means copyright-like laws that apply to other kinds of
78
+ works, such as semiconductor masks.
79
+
80
+ "The Program" refers to any copyrightable work licensed under this
81
+ License. Each licensee is addressed as "you". "Licensees" and
82
+ "recipients" may be individuals or organizations.
83
+
84
+ To "modify" a work means to copy from or adapt all or part of the work
85
+ in a fashion requiring copyright permission, other than the making of an
86
+ exact copy. The resulting work is called a "modified version" of the
87
+ earlier work or a work "based on" the earlier work.
88
+
89
+ A "covered work" means either the unmodified Program or a work based
90
+ on the Program.
91
+
92
+ To "propagate" a work means to do anything with it that, without
93
+ permission, would make you directly or secondarily liable for
94
+ infringement under applicable copyright law, except executing it on a
95
+ computer or modifying a private copy. Propagation includes copying,
96
+ distribution (with or without modification), making available to the
97
+ public, and in some countries other activities as well.
98
+
99
+ To "convey" a work means any kind of propagation that enables other
100
+ parties to make or receive copies. Mere interaction with a user through
101
+ a computer network, with no transfer of a copy, is not conveying.
102
+
103
+ An interactive user interface displays "Appropriate Legal Notices"
104
+ to the extent that it includes a convenient and prominently visible
105
+ feature that (1) displays an appropriate copyright notice, and (2)
106
+ tells the user that there is no warranty for the work (except to the
107
+ extent that warranties are provided), that licensees may convey the
108
+ work under this License, and how to view a copy of this License. If
109
+ the interface presents a list of user commands or options, such as a
110
+ menu, a prominent item in the list meets this criterion.
111
+
112
+ 1. Source Code.
113
+
114
+ The "source code" for a work means the preferred form of the work
115
+ for making modifications to it. "Object code" means any non-source
116
+ form of a work.
117
+
118
+ A "Standard Interface" means an interface that either is an official
119
+ standard defined by a recognized standards body, or, in the case of
120
+ interfaces specified for a particular programming language, one that
121
+ is widely used among developers working in that language.
122
+
123
+ The "System Libraries" of an executable work include anything, other
124
+ than the work as a whole, that (a) is included in the normal form of
125
+ packaging a Major Component, but which is not part of that Major
126
+ Component, and (b) serves only to enable use of the work with that
127
+ Major Component, or to implement a Standard Interface for which an
128
+ implementation is available to the public in source code form. A
129
+ "Major Component", in this context, means a major essential component
130
+ (kernel, window system, and so on) of the specific operating system
131
+ (if any) on which the executable work runs, or a compiler used to
132
+ produce the work, or an object code interpreter used to run it.
133
+
134
+ The "Corresponding Source" for a work in object code form means all
135
+ the source code needed to generate, install, and (for an executable
136
+ work) run the object code and to modify the work, including scripts to
137
+ control those activities. However, it does not include the work's
138
+ System Libraries, or general-purpose tools or generally available free
139
+ programs which are used unmodified in performing those activities but
140
+ which are not part of the work. For example, Corresponding Source
141
+ includes interface definition files associated with source files for
142
+ the work, and the source code for shared libraries and dynamically
143
+ linked subprograms that the work is specifically designed to require,
144
+ such as by intimate data communication or control flow between those
145
+ subprograms and other parts of the work.
146
+
147
+ The Corresponding Source need not include anything that users
148
+ can regenerate automatically from other parts of the Corresponding
149
+ Source.
150
+
151
+ The Corresponding Source for a work in source code form is that
152
+ same work.
153
+
154
+ 2. Basic Permissions.
155
+
156
+ All rights granted under this License are granted for the term of
157
+ copyright on the Program, and are irrevocable provided the stated
158
+ conditions are met. This License explicitly affirms your unlimited
159
+ permission to run the unmodified Program. The output from running a
160
+ covered work is covered by this License only if the output, given its
161
+ content, constitutes a covered work. This License acknowledges your
162
+ rights of fair use or other equivalent, as provided by copyright law.
163
+
164
+ You may make, run and propagate covered works that you do not
165
+ convey, without conditions so long as your license otherwise remains
166
+ in force. You may convey covered works to others for the sole purpose
167
+ of having them make modifications exclusively for you, or provide you
168
+ with facilities for running those works, provided that you comply with
169
+ the terms of this License in conveying all material for which you do
170
+ not control copyright. Those thus making or running the covered works
171
+ for you must do so exclusively on your behalf, under your direction
172
+ and control, on terms that prohibit them from making any copies of
173
+ your copyrighted material outside their relationship with you.
174
+
175
+ Conveying under any other circumstances is permitted solely under
176
+ the conditions stated below. Sublicensing is not allowed; section 10
177
+ makes it unnecessary.
178
+
179
+ 3. Protecting Users' Legal Rights From Anti-Circumvention Law.
180
+
181
+ No covered work shall be deemed part of an effective technological
182
+ measure under any applicable law fulfilling obligations under article
183
+ 11 of the WIPO copyright treaty adopted on 20 December 1996, or
184
+ similar laws prohibiting or restricting circumvention of such
185
+ measures.
186
+
187
+ When you convey a covered work, you waive any legal power to forbid
188
+ circumvention of technological measures to the extent such circumvention
189
+ is effected by exercising rights under this License with respect to
190
+ the covered work, and you disclaim any intention to limit operation or
191
+ modification of the work as a means of enforcing, against the work's
192
+ users, your or third parties' legal rights to forbid circumvention of
193
+ technological measures.
194
+
195
+ 4. Conveying Verbatim Copies.
196
+
197
+ You may convey verbatim copies of the Program's source code as you
198
+ receive it, in any medium, provided that you conspicuously and
199
+ appropriately publish on each copy an appropriate copyright notice;
200
+ keep intact all notices stating that this License and any
201
+ non-permissive terms added in accord with section 7 apply to the code;
202
+ keep intact all notices of the absence of any warranty; and give all
203
+ recipients a copy of this License along with the Program.
204
+
205
+ You may charge any price or no price for each copy that you convey,
206
+ and you may offer support or warranty protection for a fee.
207
+
208
+ 5. Conveying Modified Source Versions.
209
+
210
+ You may convey a work based on the Program, or the modifications to
211
+ produce it from the Program, in the form of source code under the
212
+ terms of section 4, provided that you also meet all of these conditions:
213
+
214
+ a) The work must carry prominent notices stating that you modified
215
+ it, and giving a relevant date.
216
+
217
+ b) The work must carry prominent notices stating that it is
218
+ released under this License and any conditions added under section
219
+ 7. This requirement modifies the requirement in section 4 to
220
+ "keep intact all notices".
221
+
222
+ c) You must license the entire work, as a whole, under this
223
+ License to anyone who comes into possession of a copy. This
224
+ License will therefore apply, along with any applicable section 7
225
+ additional terms, to the whole of the work, and all its parts,
226
+ regardless of how they are packaged. This License gives no
227
+ permission to license the work in any other way, but it does not
228
+ invalidate such permission if you have separately received it.
229
+
230
+ d) If the work has interactive user interfaces, each must display
231
+ Appropriate Legal Notices; however, if the Program has interactive
232
+ interfaces that do not display Appropriate Legal Notices, your
233
+ work need not make them do so.
234
+
235
+ A compilation of a covered work with other separate and independent
236
+ works, which are not by their nature extensions of the covered work,
237
+ and which are not combined with it such as to form a larger program,
238
+ in or on a volume of a storage or distribution medium, is called an
239
+ "aggregate" if the compilation and its resulting copyright are not
240
+ used to limit the access or legal rights of the compilation's users
241
+ beyond what the individual works permit. Inclusion of a covered work
242
+ in an aggregate does not cause this License to apply to the other
243
+ parts of the aggregate.
244
+
245
+ 6. Conveying Non-Source Forms.
246
+
247
+ You may convey a covered work in object code form under the terms
248
+ of sections 4 and 5, provided that you also convey the
249
+ machine-readable Corresponding Source under the terms of this License,
250
+ in one of these ways:
251
+
252
+ a) Convey the object code in, or embodied in, a physical product
253
+ (including a physical distribution medium), accompanied by the
254
+ Corresponding Source fixed on a durable physical medium
255
+ customarily used for software interchange.
256
+
257
+ b) Convey the object code in, or embodied in, a physical product
258
+ (including a physical distribution medium), accompanied by a
259
+ written offer, valid for at least three years and valid for as
260
+ long as you offer spare parts or customer support for that product
261
+ model, to give anyone who possesses the object code either (1) a
262
+ copy of the Corresponding Source for all the software in the
263
+ product that is covered by this License, on a durable physical
264
+ medium customarily used for software interchange, for a price no
265
+ more than your reasonable cost of physically performing this
266
+ conveying of source, or (2) access to copy the
267
+ Corresponding Source from a network server at no charge.
268
+
269
+ c) Convey individual copies of the object code with a copy of the
270
+ written offer to provide the Corresponding Source. This
271
+ alternative is allowed only occasionally and noncommercially, and
272
+ only if you received the object code with such an offer, in accord
273
+ with subsection 6b.
274
+
275
+ d) Convey the object code by offering access from a designated
276
+ place (gratis or for a charge), and offer equivalent access to the
277
+ Corresponding Source in the same way through the same place at no
278
+ further charge. You need not require recipients to copy the
279
+ Corresponding Source along with the object code. If the place to
280
+ copy the object code is a network server, the Corresponding Source
281
+ may be on a different server (operated by you or a third party)
282
+ that supports equivalent copying facilities, provided you maintain
283
+ clear directions next to the object code saying where to find the
284
+ Corresponding Source. Regardless of what server hosts the
285
+ Corresponding Source, you remain obligated to ensure that it is
286
+ available for as long as needed to satisfy these requirements.
287
+
288
+ e) Convey the object code using peer-to-peer transmission, provided
289
+ you inform other peers where the object code and Corresponding
290
+ Source of the work are being offered to the general public at no
291
+ charge under subsection 6d.
292
+
293
+ A separable portion of the object code, whose source code is excluded
294
+ from the Corresponding Source as a System Library, need not be
295
+ included in conveying the object code work.
296
+
297
+ A "User Product" is either (1) a "consumer product", which means any
298
+ tangible personal property which is normally used for personal, family,
299
+ or household purposes, or (2) anything designed or sold for incorporation
300
+ into a dwelling. In determining whether a product is a consumer product,
301
+ doubtful cases shall be resolved in favor of coverage. For a particular
302
+ product received by a particular user, "normally used" refers to a
303
+ typical or common use of that class of product, regardless of the status
304
+ of the particular user or of the way in which the particular user
305
+ actually uses, or expects or is expected to use, the product. A product
306
+ is a consumer product regardless of whether the product has substantial
307
+ commercial, industrial or non-consumer uses, unless such uses represent
308
+ the only significant mode of use of the product.
309
+
310
+ "Installation Information" for a User Product means any methods,
311
+ procedures, authorization keys, or other information required to install
312
+ and execute modified versions of a covered work in that User Product from
313
+ a modified version of its Corresponding Source. The information must
314
+ suffice to ensure that the continued functioning of the modified object
315
+ code is in no case prevented or interfered with solely because
316
+ modification has been made.
317
+
318
+ If you convey an object code work under this section in, or with, or
319
+ specifically for use in, a User Product, and the conveying occurs as
320
+ part of a transaction in which the right of possession and use of the
321
+ User Product is transferred to the recipient in perpetuity or for a
322
+ fixed term (regardless of how the transaction is characterized), the
323
+ Corresponding Source conveyed under this section must be accompanied
324
+ by the Installation Information. But this requirement does not apply
325
+ if neither you nor any third party retains the ability to install
326
+ modified object code on the User Product (for example, the work has
327
+ been installed in ROM).
328
+
329
+ The requirement to provide Installation Information does not include a
330
+ requirement to continue to provide support service, warranty, or updates
331
+ for a work that has been modified or installed by the recipient, or for
332
+ the User Product in which it has been modified or installed. Access to a
333
+ network may be denied when the modification itself materially and
334
+ adversely affects the operation of the network or violates the rules and
335
+ protocols for communication across the network.
336
+
337
+ Corresponding Source conveyed, and Installation Information provided,
338
+ in accord with this section must be in a format that is publicly
339
+ documented (and with an implementation available to the public in
340
+ source code form), and must require no special password or key for
341
+ unpacking, reading or copying.
342
+
343
+ 7. Additional Terms.
344
+
345
+ "Additional permissions" are terms that supplement the terms of this
346
+ License by making exceptions from one or more of its conditions.
347
+ Additional permissions that are applicable to the entire Program shall
348
+ be treated as though they were included in this License, to the extent
349
+ that they are valid under applicable law. If additional permissions
350
+ apply only to part of the Program, that part may be used separately
351
+ under those permissions, but the entire Program remains governed by
352
+ this License without regard to the additional permissions.
353
+
354
+ When you convey a copy of a covered work, you may at your option
355
+ remove any additional permissions from that copy, or from any part of
356
+ it. (Additional permissions may be written to require their own
357
+ removal in certain cases when you modify the work.) You may place
358
+ additional permissions on material, added by you to a covered work,
359
+ for which you have or can give appropriate copyright permission.
360
+
361
+ Notwithstanding any other provision of this License, for material you
362
+ add to a covered work, you may (if authorized by the copyright holders of
363
+ that material) supplement the terms of this License with terms:
364
+
365
+ a) Disclaiming warranty or limiting liability differently from the
366
+ terms of sections 15 and 16 of this License; or
367
+
368
+ b) Requiring preservation of specified reasonable legal notices or
369
+ author attributions in that material or in the Appropriate Legal
370
+ Notices displayed by works containing it; or
371
+
372
+ c) Prohibiting misrepresentation of the origin of that material, or
373
+ requiring that modified versions of such material be marked in
374
+ reasonable ways as different from the original version; or
375
+
376
+ d) Limiting the use for publicity purposes of names of licensors or
377
+ authors of the material; or
378
+
379
+ e) Declining to grant rights under trademark law for use of some
380
+ trade names, trademarks, or service marks; or
381
+
382
+ f) Requiring indemnification of licensors and authors of that
383
+ material by anyone who conveys the material (or modified versions of
384
+ it) with contractual assumptions of liability to the recipient, for
385
+ any liability that these contractual assumptions directly impose on
386
+ those licensors and authors.
387
+
388
+ All other non-permissive additional terms are considered "further
389
+ restrictions" within the meaning of section 10. If the Program as you
390
+ received it, or any part of it, contains a notice stating that it is
391
+ governed by this License along with a term that is a further
392
+ restriction, you may remove that term. If a license document contains
393
+ a further restriction but permits relicensing or conveying under this
394
+ License, you may add to a covered work material governed by the terms
395
+ of that license document, provided that the further restriction does
396
+ not survive such relicensing or conveying.
397
+
398
+ If you add terms to a covered work in accord with this section, you
399
+ must place, in the relevant source files, a statement of the
400
+ additional terms that apply to those files, or a notice indicating
401
+ where to find the applicable terms.
402
+
403
+ Additional terms, permissive or non-permissive, may be stated in the
404
+ form of a separately written license, or stated as exceptions;
405
+ the above requirements apply either way.
406
+
407
+ 8. Termination.
408
+
409
+ You may not propagate or modify a covered work except as expressly
410
+ provided under this License. Any attempt otherwise to propagate or
411
+ modify it is void, and will automatically terminate your rights under
412
+ this License (including any patent licenses granted under the third
413
+ paragraph of section 11).
414
+
415
+ However, if you cease all violation of this License, then your
416
+ license from a particular copyright holder is reinstated (a)
417
+ provisionally, unless and until the copyright holder explicitly and
418
+ finally terminates your license, and (b) permanently, if the copyright
419
+ holder fails to notify you of the violation by some reasonable means
420
+ prior to 60 days after the cessation.
421
+
422
+ Moreover, your license from a particular copyright holder is
423
+ reinstated permanently if the copyright holder notifies you of the
424
+ violation by some reasonable means, this is the first time you have
425
+ received notice of violation of this License (for any work) from that
426
+ copyright holder, and you cure the violation prior to 30 days after
427
+ your receipt of the notice.
428
+
429
+ Termination of your rights under this section does not terminate the
430
+ licenses of parties who have received copies or rights from you under
431
+ this License. If your rights have been terminated and not permanently
432
+ reinstated, you do not qualify to receive new licenses for the same
433
+ material under section 10.
434
+
435
+ 9. Acceptance Not Required for Having Copies.
436
+
437
+ You are not required to accept this License in order to receive or
438
+ run a copy of the Program. Ancillary propagation of a covered work
439
+ occurring solely as a consequence of using peer-to-peer transmission
440
+ to receive a copy likewise does not require acceptance. However,
441
+ nothing other than this License grants you permission to propagate or
442
+ modify any covered work. These actions infringe copyright if you do
443
+ not accept this License. Therefore, by modifying or propagating a
444
+ covered work, you indicate your acceptance of this License to do so.
445
+
446
+ 10. Automatic Licensing of Downstream Recipients.
447
+
448
+ Each time you convey a covered work, the recipient automatically
449
+ receives a license from the original licensors, to run, modify and
450
+ propagate that work, subject to this License. You are not responsible
451
+ for enforcing compliance by third parties with this License.
452
+
453
+ An "entity transaction" is a transaction transferring control of an
454
+ organization, or substantially all assets of one, or subdividing an
455
+ organization, or merging organizations. If propagation of a covered
456
+ work results from an entity transaction, each party to that
457
+ transaction who receives a copy of the work also receives whatever
458
+ licenses to the work the party's predecessor in interest had or could
459
+ give under the previous paragraph, plus a right to possession of the
460
+ Corresponding Source of the work from the predecessor in interest, if
461
+ the predecessor has it or can get it with reasonable efforts.
462
+
463
+ You may not impose any further restrictions on the exercise of the
464
+ rights granted or affirmed under this License. For example, you may
465
+ not impose a license fee, royalty, or other charge for exercise of
466
+ rights granted under this License, and you may not initiate litigation
467
+ (including a cross-claim or counterclaim in a lawsuit) alleging that
468
+ any patent claim is infringed by making, using, selling, offering for
469
+ sale, or importing the Program or any portion of it.
470
+
471
+ 11. Patents.
472
+
473
+ A "contributor" is a copyright holder who authorizes use under this
474
+ License of the Program or a work on which the Program is based. The
475
+ work thus licensed is called the contributor's "contributor version".
476
+
477
+ A contributor's "essential patent claims" are all patent claims
478
+ owned or controlled by the contributor, whether already acquired or
479
+ hereafter acquired, that would be infringed by some manner, permitted
480
+ by this License, of making, using, or selling its contributor version,
481
+ but do not include claims that would be infringed only as a
482
+ consequence of further modification of the contributor version. For
483
+ purposes of this definition, "control" includes the right to grant
484
+ patent sublicenses in a manner consistent with the requirements of
485
+ this License.
486
+
487
+ Each contributor grants you a non-exclusive, worldwide, royalty-free
488
+ patent license under the contributor's essential patent claims, to
489
+ make, use, sell, offer for sale, import and otherwise run, modify and
490
+ propagate the contents of its contributor version.
491
+
492
+ In the following three paragraphs, a "patent license" is any express
493
+ agreement or commitment, however denominated, not to enforce a patent
494
+ (such as an express permission to practice a patent or covenant not to
495
+ sue for patent infringement). To "grant" such a patent license to a
496
+ party means to make such an agreement or commitment not to enforce a
497
+ patent against the party.
498
+
499
+ If you convey a covered work, knowingly relying on a patent license,
500
+ and the Corresponding Source of the work is not available for anyone
501
+ to copy, free of charge and under the terms of this License, through a
502
+ publicly available network server or other readily accessible means,
503
+ then you must either (1) cause the Corresponding Source to be so
504
+ available, or (2) arrange to deprive yourself of the benefit of the
505
+ patent license for this particular work, or (3) arrange, in a manner
506
+ consistent with the requirements of this License, to extend the patent
507
+ license to downstream recipients. "Knowingly relying" means you have
508
+ actual knowledge that, but for the patent license, your conveying the
509
+ covered work in a country, or your recipient's use of the covered work
510
+ in a country, would infringe one or more identifiable patents in that
511
+ country that you have reason to believe are valid.
512
+
513
+ If, pursuant to or in connection with a single transaction or
514
+ arrangement, you convey, or propagate by procuring conveyance of, a
515
+ covered work, and grant a patent license to some of the parties
516
+ receiving the covered work authorizing them to use, propagate, modify
517
+ or convey a specific copy of the covered work, then the patent license
518
+ you grant is automatically extended to all recipients of the covered
519
+ work and works based on it.
520
+
521
+ A patent license is "discriminatory" if it does not include within
522
+ the scope of its coverage, prohibits the exercise of, or is
523
+ conditioned on the non-exercise of one or more of the rights that are
524
+ specifically granted under this License. You may not convey a covered
525
+ work if you are a party to an arrangement with a third party that is
526
+ in the business of distributing software, under which you make payment
527
+ to the third party based on the extent of your activity of conveying
528
+ the work, and under which the third party grants, to any of the
529
+ parties who would receive the covered work from you, a discriminatory
530
+ patent license (a) in connection with copies of the covered work
531
+ conveyed by you (or copies made from those copies), or (b) primarily
532
+ for and in connection with specific products or compilations that
533
+ contain the covered work, unless you entered into that arrangement,
534
+ or that patent license was granted, prior to 28 March 2007.
535
+
536
+ Nothing in this License shall be construed as excluding or limiting
537
+ any implied license or other defenses to infringement that may
538
+ otherwise be available to you under applicable patent law.
539
+
540
+ 12. No Surrender of Others' Freedom.
541
+
542
+ If conditions are imposed on you (whether by court order, agreement or
543
+ otherwise) that contradict the conditions of this License, they do not
544
+ excuse you from the conditions of this License. If you cannot convey a
545
+ covered work so as to satisfy simultaneously your obligations under this
546
+ License and any other pertinent obligations, then as a consequence you may
547
+ not convey it at all. For example, if you agree to terms that obligate you
548
+ to collect a royalty for further conveying from those to whom you convey
549
+ the Program, the only way you could satisfy both those terms and this
550
+ License would be to refrain entirely from conveying the Program.
551
+
552
+ 13. Use with the GNU Affero General Public License.
553
+
554
+ Notwithstanding any other provision of this License, you have
555
+ permission to link or combine any covered work with a work licensed
556
+ under version 3 of the GNU Affero General Public License into a single
557
+ combined work, and to convey the resulting work. The terms of this
558
+ License will continue to apply to the part which is the covered work,
559
+ but the special requirements of the GNU Affero General Public License,
560
+ section 13, concerning interaction through a network will apply to the
561
+ combination as such.
562
+
563
+ 14. Revised Versions of this License.
564
+
565
+ The Free Software Foundation may publish revised and/or new versions of
566
+ the GNU General Public License from time to time. Such new versions will
567
+ be similar in spirit to the present version, but may differ in detail to
568
+ address new problems or concerns.
569
+
570
+ Each version is given a distinguishing version number. If the
571
+ Program specifies that a certain numbered version of the GNU General
572
+ Public License "or any later version" applies to it, you have the
573
+ option of following the terms and conditions either of that numbered
574
+ version or of any later version published by the Free Software
575
+ Foundation. If the Program does not specify a version number of the
576
+ GNU General Public License, you may choose any version ever published
577
+ by the Free Software Foundation.
578
+
579
+ If the Program specifies that a proxy can decide which future
580
+ versions of the GNU General Public License can be used, that proxy's
581
+ public statement of acceptance of a version permanently authorizes you
582
+ to choose that version for the Program.
583
+
584
+ Later license versions may give you additional or different
585
+ permissions. However, no additional obligations are imposed on any
586
+ author or copyright holder as a result of your choosing to follow a
587
+ later version.
588
+
589
+ 15. Disclaimer of Warranty.
590
+
591
+ THERE IS NO WARRANTY FOR THE PROGRAM, TO THE EXTENT PERMITTED BY
592
+ APPLICABLE LAW. EXCEPT WHEN OTHERWISE STATED IN WRITING THE COPYRIGHT
593
+ HOLDERS AND/OR OTHER PARTIES PROVIDE THE PROGRAM "AS IS" WITHOUT WARRANTY
594
+ OF ANY KIND, EITHER EXPRESSED OR IMPLIED, INCLUDING, BUT NOT LIMITED TO,
595
+ THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR
596
+ PURPOSE. THE ENTIRE RISK AS TO THE QUALITY AND PERFORMANCE OF THE PROGRAM
597
+ IS WITH YOU. SHOULD THE PROGRAM PROVE DEFECTIVE, YOU ASSUME THE COST OF
598
+ ALL NECESSARY SERVICING, REPAIR OR CORRECTION.
599
+
600
+ 16. Limitation of Liability.
601
+
602
+ IN NO EVENT UNLESS REQUIRED BY APPLICABLE LAW OR AGREED TO IN WRITING
603
+ WILL ANY COPYRIGHT HOLDER, OR ANY OTHER PARTY WHO MODIFIES AND/OR CONVEYS
604
+ THE PROGRAM AS PERMITTED ABOVE, BE LIABLE TO YOU FOR DAMAGES, INCLUDING ANY
605
+ GENERAL, SPECIAL, INCIDENTAL OR CONSEQUENTIAL DAMAGES ARISING OUT OF THE
606
+ USE OR INABILITY TO USE THE PROGRAM (INCLUDING BUT NOT LIMITED TO LOSS OF
607
+ DATA OR DATA BEING RENDERED INACCURATE OR LOSSES SUSTAINED BY YOU OR THIRD
608
+ PARTIES OR A FAILURE OF THE PROGRAM TO OPERATE WITH ANY OTHER PROGRAMS),
609
+ EVEN IF SUCH HOLDER OR OTHER PARTY HAS BEEN ADVISED OF THE POSSIBILITY OF
610
+ SUCH DAMAGES.
611
+
612
+ 17. Interpretation of Sections 15 and 16.
613
+
614
+ If the disclaimer of warranty and limitation of liability provided
615
+ above cannot be given local legal effect according to their terms,
616
+ reviewing courts shall apply local law that most closely approximates
617
+ an absolute waiver of all civil liability in connection with the
618
+ Program, unless a warranty or assumption of liability accompanies a
619
+ copy of the Program in return for a fee.
620
+
621
+ END OF TERMS AND CONDITIONS
622
+
623
+ How to Apply These Terms to Your New Programs
624
+
625
+ If you develop a new program, and you want it to be of the greatest
626
+ possible use to the public, the best way to achieve this is to make it
627
+ free software which everyone can redistribute and change under these terms.
628
+
629
+ To do so, attach the following notices to the program. It is safest
630
+ to attach them to the start of each source file to most effectively
631
+ state the exclusion of warranty; and each file should have at least
632
+ the "copyright" line and a pointer to where the full notice is found.
633
+
634
+ <one line to give the program's name and a brief idea of what it does.>
635
+ Copyright (C) <year> <name of author>
636
+
637
+ This program is free software: you can redistribute it and/or modify
638
+ it under the terms of the GNU General Public License as published by
639
+ the Free Software Foundation, either version 3 of the License, or
640
+ (at your option) any later version.
641
+
642
+ This program is distributed in the hope that it will be useful,
643
+ but WITHOUT ANY WARRANTY; without even the implied warranty of
644
+ MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
645
+ GNU General Public License for more details.
646
+
647
+ You should have received a copy of the GNU General Public License
648
+ along with this program. If not, see <https://www.gnu.org/licenses/>.
649
+
650
+ Also add information on how to contact you by electronic and paper mail.
651
+
652
+ If the program does terminal interaction, make it output a short
653
+ notice like this when it starts in an interactive mode:
654
+
655
+ <program> Copyright (C) <year> <name of author>
656
+ This program comes with ABSOLUTELY NO WARRANTY; for details type `show w'.
657
+ This is free software, and you are welcome to redistribute it
658
+ under certain conditions; type `show c' for details.
659
+
660
+ The hypothetical commands `show w' and `show c' should show the appropriate
661
+ parts of the General Public License. Of course, your program's commands
662
+ might be different; for a GUI interface, you would use an "about box".
663
+
664
+ You should also get your employer (if you work as a programmer) or school,
665
+ if any, to sign a "copyright disclaimer" for the program, if necessary.
666
+ For more information on this, and how to apply and follow the GNU GPL, see
667
+ <https://www.gnu.org/licenses/>.
668
+
669
+ The GNU General Public License does not permit incorporating your program
670
+ into proprietary programs. If your program is a subroutine library, you
671
+ may consider it more useful to permit linking proprietary applications with
672
+ the library. If this is what you want to do, use the GNU Lesser General
673
+ Public License instead of this License. But first, please read
674
+ <https://www.gnu.org/licenses/why-not-lgpl.html>.
custom_nodes/ComfyUI-FlashVSR_Ultra_Fast/README.md ADDED
@@ -0,0 +1,68 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # ComfyUI-FlashVSR_Ultra_Fast
2
+ Running FlashVSR on lower VRAM without any artifacts.
3
+ **[[📃中文版本](./README_zh.md)]**
4
+
5
+ ## Changelog
6
+ #### 2025-10-24
7
+ - Added long video pipeline that significantly reduces VRAM usage when upscaling long videos.
8
+
9
+ #### 2025-10-21
10
+ - Initial release of this project, introducing features such as `tile_dit` that significantly reduce VRAM usage.
11
+
12
+ #### 2025-10-22
13
+ - Replaced `Block-Sparse-Attention` with `Sparse_Sage`, removing the need to compile any custom kernels.
14
+ - Added support for running on RTX 50 series GPUs.
15
+
16
+ ## Preview
17
+ ![](./img/preview.jpg)
18
+
19
+ ## Usage
20
+ - **mode:**
21
+ `tiny` -> faster (default); `full` -> higher quality
22
+ - **scale:**
23
+ `4` generally gives better results; use `2` if you are low on VRAM
24
+ - **color_fix:**
25
+ Use wavelet transform to correct the color of output video.
26
+ - **tiled_vae:**
27
+ Set to True for lower VRAM consumption during decoding at the cost of speed.
28
+ - **tiled_dit:**
29
+ Significantly reduces VRAM usage at the cost of speed.
30
+ - **tile\_size, tile\_overlap**:
31
+ How to split the input video.
32
+ - **unload_dit:**
33
+ Unload DiT before decoding to reduce VRAM peak at the cost of speed.
34
+
35
+ ## Installation
36
+
37
+ #### nodes:
38
+
39
+ ```bash
40
+ cd ComfyUI/custom_nodes
41
+ git clone https://github.com/lihaoyun6/ComfyUI-FlashVSR_Ultra_Fast.git
42
+ python -m pip install -r ComfyUI-FlashVSR_Ultra_Fast/requirements.txt
43
+ ```
44
+ 📢: For Turing or older GPUs, please install `triton<3.3.0`:
45
+
46
+ ```bash
47
+ # Windows
48
+ python -m pip install -U "triton-windows<3.3.0"
49
+ # Linux
50
+ python -m pip install -U "triton<3.3.0"
51
+ ```
52
+
53
+ #### models:
54
+
55
+ - Download the entire `FlashVSR` folder with all the files inside it from [here](https://huggingface.co/JunhaoZhuang/FlashVSR) and put it in the `ComfyUI/models` directory.
56
+
57
+ ```
58
+ ├── ComfyUI/models/FlashVSR
59
+ | ├── LQ_proj_in.ckpt
60
+ | ├── TCDecoder.ckpt
61
+ | ├── diffusion_pytorch_model_streaming_dmd.safetensors
62
+ | ├── Wan2.1_VAE.pth
63
+ ```
64
+
65
+ ## Acknowledgments
66
+ - [FlashVSR](https://github.com/OpenImagingLab/FlashVSR) @OpenImagingLab
67
+ - [Sparse_SageAttention](https://github.com/jt-zhang/Sparse_SageAttention_API) @jt-zhang
68
+ - [ComfyUI](https://github.com/comfyanonymous/ComfyUI) @comfyanonymous
custom_nodes/ComfyUI-FlashVSR_Ultra_Fast/README_zh.md ADDED
@@ -0,0 +1,66 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # ComfyUI-FlashVSR_Ultra_Fast
2
+ 在低显存环境下运行 FlashVSR,同时保持无伪影高质量输出。
3
+ **[[📃English](./README.md)]**
4
+
5
+ ## 更新日志
6
+ #### 2025-10-24
7
+ - 新增长视频管道, 可显著降低长视频放大的显存用量
8
+
9
+ #### 2025-10-21
10
+ - 项目首次发布, 引入了`tile_dit`等功能, 大幅度降低显存需求
11
+
12
+ #### 2025-10-22
13
+ - 使用`Sparse_SageAttention`替换了`Block-Sparse-Attention`, 无需编译安装任何自定义内核, 开箱即用.
14
+ - 支持在 RTX50 系列显卡上运行.
15
+
16
+ ## 预览
17
+ ![](./img/preview.jpg)
18
+
19
+ ## 使用说明
20
+ - **mode(模式):**
21
+ `tiny` → 更快(默认);`full` → 更高质量
22
+ - **scale(放大倍数):**
23
+ 通常使用 `4` 效果更好;如果显存不足,可使用 `2`
24
+ - **color_fix(颜色修正):**
25
+ 使用小波变换方法修正输出视频的颜色偏差。
26
+ - **tiled_vae(VAE分块解码):**
27
+ 启用后可显著降低显存占用,但会降低解码速度。
28
+ - **tiled_dit(DiT分块计算):**
29
+ 大幅减少显存占用,但会降低推理速度。
30
+ - **tile_size / tile_overlap(分块大小与重叠):**
31
+ 控制输入视频在推理时的分块方式。
32
+ - **unload_dit(卸载DiT模型):**
33
+ 解码前卸载 DiT 模型以降低显存峰值,但会略微降低速度。
34
+
35
+ ## 安装步骤
36
+
37
+ #### 安装节点:
38
+ ```bash
39
+ cd ComfyUI/custom_nodes
40
+ git clone https://github.com/lihaoyun6/ComfyUI-FlashVSR_Ultra_Fast.git
41
+ python -m pip install -r ComfyUI-FlashVSR_Ultra_Fast/requirements.txt
42
+ ```
43
+ 📢: 要在RTX20系或更早的GPU上运行, 请安装`triton<3.3.0`:
44
+
45
+ ```bash
46
+ # Windows
47
+ python -m pip install -U "triton-windows<3.3.0"
48
+ # Linux
49
+ python -m pip install -U "triton<3.3.0"
50
+ ```
51
+
52
+ #### 模型下载:
53
+ - 从[这里](https://huggingface.co/JunhaoZhuang/FlashVSR)下载整个`FlashVSR`文件夹和它里面的所有文件, 并将其放到`ComfyUI/models`目录中。
54
+
55
+ ```
56
+ ├── ComfyUI/models/FlashVSR
57
+ | ├── LQ_proj_in.ckpt
58
+ | ├── TCDecoder.ckpt
59
+ | ├── diffusion_pytorch_model_streaming_dmd.safetensors
60
+ | ├── Wan2.1_VAE.pth
61
+ ```
62
+
63
+ ## 致谢
64
+ - [FlashVSR](https://github.com/OpenImagingLab/FlashVSR) @OpenImagingLab
65
+ - [Sparse_SageAttention](https://github.com/jt-zhang/Sparse_SageAttention_API) @jt-zhang
66
+ - [ComfyUI](https://github.com/comfyanonymous/ComfyUI) @comfyanonymous
custom_nodes/ComfyUI-FlashVSR_Ultra_Fast/__init__.py ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ from .nodes import NODE_CLASS_MAPPINGS, NODE_DISPLAY_NAME_MAPPINGS
2
+
3
+ __all__ = ["NODE_CLASS_MAPPINGS", "NODE_DISPLAY_NAME_MAPPINGS"]
custom_nodes/ComfyUI-FlashVSR_Ultra_Fast/img/preview.jpg ADDED

Git LFS Details

  • SHA256: ad7cc28a6c911472d5653b7c90aa8ca0737c42f34fa82b5f093e48af53039c0e
  • Pointer size: 131 Bytes
  • Size of remote file: 776 kB
custom_nodes/ComfyUI-FlashVSR_Ultra_Fast/nodes.py ADDED
@@ -0,0 +1,553 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python3
2
+ # -*- coding: utf-8 -*-
3
+
4
+ import os,gc
5
+ import math
6
+ import torch
7
+ import folder_paths
8
+ import comfy.utils
9
+
10
+ import numpy as np
11
+ import torch.nn.functional as F
12
+
13
+ from einops import rearrange
14
+ from huggingface_hub import snapshot_download
15
+ from .src import ModelManager, FlashVSRFullPipeline, FlashVSRTinyPipeline, FlashVSRTinyLongPipeline
16
+ from .src.models.TCDecoder import build_tcdecoder
17
+ from .src.models.utils import clean_vram, get_device_list, Buffer_LQ4x_Proj, Causal_LQ4x_Proj
18
+ from .src.models import wan_video_dit
19
+
20
# Device options, obtained once at import time from get_device_list().
device_choices = get_device_list()
21
+
22
def log(message:str, message_type:str='normal'):
    """Print ``message`` to stdout, colourised according to ``message_type``.

    Args:
        message: Text to print.
        message_type: One of 'error', 'warning', 'finish' or 'info'. Each of
            these wraps the text in an ANSI SGR escape sequence followed by a
            reset. Any other value prints the text unchanged.
    """
    # ANSI SGR prefixes: bold on red background, bold red, bold green, bold yellow.
    ansi_prefixes = {
        'error': '\033[1;41m',
        'warning': '\033[1;31m',
        'finish': '\033[1;32m',
        'info': '\033[1;33m',
    }
    prefix = ansi_prefixes.get(message_type)
    if prefix is not None:
        # '\033[m' resets the terminal attributes after the message.
        message = prefix + message + '\033[m'
    print(f"{message}")
34
+
35
def model_downlod(model_name="JunhaoZhuang/FlashVSR"):
    """Download a Hugging Face repo into the ComfyUI models directory if absent.

    The local folder is named after the last '/'-separated component of
    ``model_name`` and is placed under ``folder_paths.models_dir``. When that
    folder already exists nothing is downloaded; its contents are not checked.

    Args:
        model_name: Hugging Face repo id, e.g. "owner/name".
    """
    repo_folder = model_name.split("/")[-1]
    target_dir = os.path.join(folder_paths.models_dir, repo_folder)
    if os.path.exists(target_dir):
        return
    log(f"Downloading model '{model_name}' from huggingface...", message_type='info')
    snapshot_download(
        repo_id=model_name,
        local_dir=target_dir,
        local_dir_use_symlinks=False,
        resume_download=True,
    )
40
+
41
def tensor2video(frames: torch.Tensor):
    """Rearrange a "C F H W" video tensor to "F H W C" and rescale its values.

    A leading dimension of size 1 is squeezed away first. The result is cast
    to float32 and mapped through ``(x + 1) / 2`` (so an input in [-1, 1]
    would land in [0, 1]).

    Args:
        frames: Video tensor; after ``squeeze(0)`` it must have exactly the
            four axes "C F H W".

    Returns:
        A float tensor laid out as "F H W C".
    """
    channels_last = rearrange(frames.squeeze(0), "C F H W -> F H W C")
    return (channels_last.float() + 1.0) / 2.0
46
+
47
def largest_8n1_leq(n):  # 8n+1
    """Return the largest value of the form ``8*k + 1`` (k >= 0) that is <= ``n``.

    Returns 0 when ``n`` is less than 1, i.e. when no such value exists.
    """
    if n < 1:
        return 0
    whole_blocks, _ = divmod(n - 1, 8)
    return whole_blocks * 8 + 1
49
+
50
def next_8n5(n):  # next 8n+5
    """Return the smallest value of the form ``8*k + 5`` that is >= ``n``.

    The result is never below 21: any ``n`` smaller than 21 yields 21.
    """
    if n < 21:
        return 21
    # (n + 2) // 8 is ceil((n - 5) / 8) for integer n.
    return (n + 2) // 8 * 8 + 5
52
+
53
def compute_scaled_and_target_dims(w0: int, h0: int, scale: int = 4, multiple: int = 128):
    """Compute the upscaled size and the size rounded down to a multiple.

    Args:
        w0: Original width; must be positive.
        h0: Original height; must be positive.
        scale: Upscaling factor applied to both sides.
        multiple: Target sides are floored to a multiple of this value, but
            are never smaller than ``multiple`` itself.

    Returns:
        ``(sW, sH, tW, tH)``: scaled width/height and target width/height.

    Raises:
        ValueError: If ``w0`` or ``h0`` is not positive.
    """
    if w0 <= 0 or h0 <= 0:
        raise ValueError("invalid original size")

    scaled_w = w0 * scale
    scaled_h = h0 * scale
    # side - side % multiple floors the side to a multiple of `multiple`.
    target_w, target_h = (
        max(multiple, side - side % multiple) for side in (scaled_w, scaled_h)
    )
    return scaled_w, scaled_h, target_w, target_h
61
+
62
def tensor_upscale_then_center_crop(frame_tensor: torch.Tensor, scale: int, tW: int, tH: int) -> torch.Tensor:
    """Bicubically upscale one HWC frame and crop a centred ``tH`` x ``tW`` window.

    Args:
        frame_tensor: Single frame with three axes in (H, W, C) order.
        scale: Integer factor applied to both height and width.
        tW: Crop width.
        tH: Crop height.

    Returns:
        The cropped frame in (C, H, W) order. If the upscaled frame is
        smaller than the requested window on an axis, the crop simply covers
        what is available on that axis.
    """
    src_h, src_w, _ = frame_tensor.shape
    up_h, up_w = src_h * scale, src_w * scale

    # F.interpolate expects a batched, channels-first tensor: HWC -> CHW -> 1CHW.
    batched = frame_tensor.permute(2, 0, 1).unsqueeze(0)
    enlarged = F.interpolate(batched, size=(up_h, up_w), mode='bicubic', align_corners=False)

    # Offsets of the centred window, clamped so they never go negative.
    top = max(0, (up_h - tH) // 2)
    left = max(0, (up_w - tW) // 2)
    window = enlarged[:, :, top:top + tH, left:left + tW]

    return window.squeeze(0)
74
+
75
def prepare_input_tensor(image_tensor: torch.Tensor, device, scale: int = 4, dtype=torch.bfloat16):
    """Upscale, crop and stack input frames into a single video tensor.

    Each frame is moved to ``device``, upscaled and centre-cropped, moved back
    to the CPU, cast to ``dtype`` and mapped through ``x * 2 - 1``. The frame
    count is the input count plus 4, reduced to the largest ``8*k + 1``;
    indices past the last input frame reuse the last frame.

    Args:
        image_tensor: Frames with four axes in (N, H, W, C) order.
        device: Device on which the per-frame upscaling runs.
        scale: Integer upscaling factor.
        dtype: Dtype of the returned video tensor.

    Returns:
        ``(video, tH, tW, frame_count)`` where ``video`` is a CPU tensor laid
        out as (1, C, frame_count, tH, tW).

    Raises:
        ValueError: If the input height or width is not positive.
        RuntimeError: If no frames remain after padding.
    """
    input_count, src_h, src_w, _ = image_tensor.shape

    _, _, tW, tH = compute_scaled_and_target_dims(src_w, src_h, scale=scale, multiple=128)
    padded_count = input_count + 4
    frame_count = largest_8n1_leq(padded_count)

    if frame_count == 0:
        raise RuntimeError(f"Not enough frames after padding. Got {padded_count}.")

    last_index = input_count - 1
    processed = []
    for position in range(frame_count):
        # Positions beyond the clip repeat its final frame.
        source = image_tensor[min(position, last_index)].to(device)
        # Kept as one expression so no device-side intermediate stays referenced.
        processed.append(
            tensor_upscale_then_center_crop(source, scale=scale, tW=tW, tH=tH)
            .to('cpu')
            .to(dtype) * 2.0 - 1.0
        )
        del source

    stacked = torch.stack(processed, 0)  # (F, C, H, W)
    video = stacked.permute(1, 0, 2, 3).unsqueeze(0)  # (1, C, F, H, W)

    del stacked
    clean_vram()

    return video, tH, tW, frame_count
101
+
102
def calculate_tile_coords(height, width, tile_size, overlap):
    """List the ``(x1, y1, x2, y2)`` boxes of overlapping tiles covering an area.

    Tiles advance by ``tile_size - overlap`` along each axis. A tile that
    would run past the edge is clipped to the edge and then shifted back so
    it keeps the full ``tile_size`` where the area is large enough.

    Args:
        height: Height of the area to cover.
        width: Width of the area to cover.
        tile_size: Side length of each tile.
        overlap: Overlap between neighbouring tiles.

    Returns:
        A list of ``(x1, y1, x2, y2)`` tuples in row-major order.
    """
    step = tile_size - overlap

    def axis_spans(extent):
        # (start, stop) pairs for every tile position along one axis.
        spans = []
        for k in range(math.ceil((extent - overlap) / step)):
            start = k * step
            stop = min(start + tile_size, extent)
            if stop - start < tile_size:
                # Clipped at the edge: pull the start back to keep full size.
                start = max(0, stop - tile_size)
            spans.append((start, stop))
        return spans

    row_spans = axis_spans(height)
    col_spans = axis_spans(width)

    return [(x1, y1, x2, y2) for (y1, y2) in row_spans for (x1, x2) in col_spans]
125
+
126
def create_feather_mask(size, overlap):
    """Build a ``(1, 1, H, W)`` mask that ramps linearly from 0 to 1 at each edge.

    The ramp spans ``overlap`` pixels on all four sides; where ramps from
    different sides meet, the smaller value wins.

    Args:
        size: ``(H, W)`` of the mask.
        overlap: Width of the ramp in pixels. A value <= 0 disables
            feathering and yields an all-ones mask. Values larger than
            ``H`` or ``W`` are not supported (the ramp cannot be broadcast
            onto the shorter slice).

    Returns:
        A float tensor of shape ``(1, 1, H, W)`` with values in [0, 1].
    """
    H, W = size
    mask = torch.ones(1, 1, H, W)

    if overlap <= 0:
        # Nothing to feather. Without this guard `-overlap:` with overlap == 0
        # selects the whole axis and the empty ramp fails to broadcast.
        return mask

    ramp = torch.linspace(0, 1, overlap)
    rising_cols = ramp.view(1, 1, 1, -1)
    falling_cols = ramp.flip(0).view(1, 1, 1, -1)
    rising_rows = ramp.view(1, 1, -1, 1)
    falling_rows = ramp.flip(0).view(1, 1, -1, 1)

    # Left / right edges.
    mask[:, :, :, :overlap] = torch.minimum(mask[:, :, :, :overlap], rising_cols)
    mask[:, :, :, -overlap:] = torch.minimum(mask[:, :, :, -overlap:], falling_cols)

    # Top / bottom edges.
    mask[:, :, :overlap, :] = torch.minimum(mask[:, :, :overlap, :], rising_rows)
    mask[:, :, -overlap:, :] = torch.minimum(mask[:, :, -overlap:, :], falling_rows)

    return mask
138
+
139
def init_pipeline(model, mode, device, dtype, alt_vae="none"):
    """Assemble a FlashVSR pipeline from the weight files of ``model``.

    All weight files are looked up below ``folder_paths.models_dir/<model>``
    and verified before anything is loaded; a missing file raises
    ``RuntimeError`` naming the file and the expected directory.

    Args:
        model: Name of the model folder (also used to build the
            ``JunhaoZhuang/<model>`` id passed to ``model_downlod``).
        mode: ``"full"`` builds ``FlashVSRFullPipeline``; ``"tiny"`` builds
            ``FlashVSRTinyPipeline``; anything else builds
            ``FlashVSRTinyLongPipeline``.
        device: Target device for the pipeline.
        dtype: Torch dtype used for the weights.
        alt_vae: File name of an alternative VAE, or ``"none"`` to use the
            bundled ``Wan2.1_VAE.pth``.

    Returns:
        The initialised pipeline object.
    """
    model_downlod(model_name="JunhaoZhuang/"+model)
    model_path = os.path.join(folder_paths.models_dir, model)
    if not os.path.exists(model_path):
        raise RuntimeError(f'Model directory does not exist!\nPlease save all weights to "{model_path}"')

    def _require(file_name):
        # Resolve a weight file inside the model folder or fail loudly.
        full_path = os.path.join(model_path, file_name)
        if not os.path.exists(full_path):
            raise RuntimeError(f'"{file_name}" does not exist!\nPlease save it to "{model_path}"')
        return full_path

    # The checks run in a fixed order so the first missing file is reported.
    ckpt_path = _require("diffusion_pytorch_model_streaming_dmd.safetensors")
    if alt_vae == "none":
        vae_path = _require("Wan2.1_VAE.pth")
    else:
        vae_path = folder_paths.get_full_path_or_raise("vae", alt_vae)
        if not os.path.exists(vae_path):
            raise RuntimeError(f'"{alt_vae}" does not exist!')
    lq_path = _require("LQ_proj_in.ckpt")
    tcd_path = _require("TCDecoder.ckpt")

    # The positive-prompt tensor ships next to this source file.
    prompt_path = os.path.join(os.path.dirname(os.path.abspath(__file__)), "posi_prompt.pth")

    mm = ModelManager(torch_dtype=dtype, device="cpu")
    if mode == "full":
        mm.load_models([ckpt_path, vae_path])
        pipe = FlashVSRFullPipeline.from_model_manager(mm, device=device)
        # Only decoding is needed, so the encoder side of the VAE is dropped.
        pipe.vae.model.encoder = None
        pipe.vae.model.conv1 = None
    else:
        mm.load_models([ckpt_path])
        pipeline_cls = FlashVSRTinyPipeline if mode == "tiny" else FlashVSRTinyLongPipeline
        pipe = pipeline_cls.from_model_manager(mm, device=device)
        # The tiny pipelines decode with the TCDecoder instead of the VAE.
        pipe.TCDecoder = build_tcdecoder(
            new_channels=[512, 256, 128, 128],
            device=device,
            dtype=dtype,
            new_latent_channels=16+768,
        )
        pipe.TCDecoder.load_state_dict(torch.load(tcd_path, map_location=device), strict=False)
        pipe.TCDecoder.clean_mem()

    # The original "FlashVSR" release uses the buffered projector; other
    # versions use the causal one.
    lq_proj_cls = Buffer_LQ4x_Proj if model == "FlashVSR" else Causal_LQ4x_Proj
    lq_proj = lq_proj_cls(in_dim=3, out_dim=1536, layer_num=1).to(device, dtype=dtype)
    pipe.denoising_model().LQ_proj_in = lq_proj
    pipe.denoising_model().LQ_proj_in.load_state_dict(torch.load(lq_path, map_location="cpu"), strict=True)
    pipe.denoising_model().LQ_proj_in.to(device)

    pipe.to(device, dtype=dtype)
    pipe.enable_vram_management(num_persistent_param_in_dit=None)
    pipe.init_cross_kv(prompt_path=prompt_path)
    pipe.load_models_to_device(["dit","vae"])
    pipe.offload_model()

    return pipe
194
+
195
class cqdm:
    """Minimal tqdm-style wrapper that reports to ``comfy.utils.ProgressBar``.

    Two usage patterns are supported:

    * ``for item in cqdm(iterable)``: the bar advances by one per item.
    * ``with cqdm(total=n) as pbar``: yields the underlying progress bar so
      the caller can update it manually.

    Args:
        iterable: Optional iterable to wrap.
        total: Number of steps; required when ``iterable`` is missing or has
            no ``len()``. Ignored in favour of ``len(iterable)`` when that is
            available.
        desc: Label kept on the instance as ``self.desc``.

    Raises:
        ValueError: If neither ``iterable`` nor ``total`` is given, or if the
            iterable has no length and ``total`` is missing.
    """

    def __init__(self, iterable=None, total=None, desc="Processing"):
        self.desc = desc
        self.pbar = None
        self.iterable = None
        self.total = total

        if iterable is None:
            if total is None:
                raise ValueError("Either iterable or total must be provided.")
            return

        try:
            self.total = len(iterable)
        except TypeError:
            if self.total is None:
                raise ValueError("Total must be provided for iterables with no length.")

        # Store the iterator even when the iterable has no len() (e.g. a
        # generator used together with ``total=``); otherwise such objects
        # could never be iterated.
        try:
            self.iterable = iter(iterable)
        except TypeError:
            # Not iterable at all: only the context-manager mode stays usable.
            self.iterable = None

    def __iter__(self):
        if self.iterable is None:
            raise TypeError(f"'{type(self).__name__}' object is not iterable. Did you mean to use it with a 'with' statement?")
        if self.pbar is None:
            self.pbar = comfy.utils.ProgressBar(self.total)
        return self

    def __next__(self):
        if self.iterable is None:
            raise TypeError("Cannot call __next__ on a non-iterable cqdm object.")
        # StopIteration from the wrapped iterator propagates and ends the loop.
        val = next(self.iterable)
        if self.pbar:
            self.pbar.update(1)
        return val

    def __enter__(self):
        if self.pbar is None:
            self.pbar = comfy.utils.ProgressBar(self.total)
        return self.pbar

    def __exit__(self, exc_type, exc_val, exc_tb):
        # Nothing to release; exceptions are not suppressed.
        pass

    def __len__(self):
        return self.total
244
+
245
def flashvsr(pipe, frames, scale, color_fix, tiled_vae, tiled_dit, tile_size, tile_overlap, unload_dit, sparse_ratio, kv_ratio, local_range, seed, force_offload):
    """Run a FlashVSR pipeline over ``frames`` and return the upscaled frames.

    The clip is first padded with copies of its last frame up to the length
    given by ``next_8n5``. With ``tiled_dit`` the frames are split into
    overlapping spatial tiles that are processed one by one and blended on a
    CPU canvas with a feathered mask; otherwise the whole clip goes through
    the pipeline in a single call.

    Args:
        pipe: Initialised FlashVSR pipeline (callable).
        frames: Frame batch; unpacked as ``(N, H, W, C)`` in the tiled path.
        scale: Integer upscale factor.
        color_fix, tiled_vae, unload_dit, kv_ratio, local_range, seed,
            force_offload: Forwarded to the pipeline call.
        tiled_dit: Enables spatial tiling.
        tile_size, tile_overlap: Tile geometry in input pixels.
        sparse_ratio: Scaled by ``768*1280/(th*tw)`` to form ``topk_ratio``.

    Returns:
        The output frames on the CPU, trimmed to the original frame count.
        For a single input frame, the median over the padded frames is
        returned as a one-frame float tensor.
    """
    run_device = pipe.device
    dtype = pipe.torch_dtype
    frame_count = frames.shape[0]

    # Pad by repeating the last frame.
    pad_count = next_8n5(frame_count) - frame_count
    tail = frames[-1:, :, :, :].repeat(pad_count, 1, 1, 1)
    padded = torch.cat([frames, tail], dim=0)

    # Arguments that are identical for every pipeline invocation.
    shared_args = dict(
        prompt="",
        negative_prompt="",
        cfg_scale=1.0,
        num_inference_steps=1,
        seed=seed,
        tiled=tiled_vae,
        is_full_block=False,
        if_buffer=True,
        kv_ratio=kv_ratio,
        local_range=local_range,
        color_fix=color_fix,
        unload_dit=unload_dit,
        force_offload=force_offload,
    )

    if tiled_dit:
        N, H, W, C = padded.shape

        # Weighted sum of tiles and the matching sum of weights, both on CPU.
        canvas = torch.zeros(
            (N, H * scale, W * scale, C),
            dtype=torch.float16,
            device="cpu"
        )
        weights = torch.zeros_like(canvas)
        tile_coords = calculate_tile_coords(H, W, tile_size, tile_overlap)
        tile_total = len(tile_coords)

        for idx, (x1, y1, x2, y2) in enumerate(cqdm(tile_coords, desc="Processing Tiles")):
            log(f"[FlashVSR] Processing tile {idx+1}/{tile_total}: coords ({x1},{y1}) to ({x2},{y2})", message_type='info')
            input_tile = padded[:, y1:y2, x1:x2, :]

            LQ_tile, th, tw, F = prepare_input_tensor(input_tile, run_device, scale=scale, dtype=dtype)
            if not isinstance(pipe, FlashVSRTinyLongPipeline):
                LQ_tile = LQ_tile.to(run_device)

            output_tile_gpu = pipe(
                **shared_args,
                LQ_video=LQ_tile, num_frames=F, height=th, width=tw,
                topk_ratio=sparse_ratio*768*1280/(th*tw),
            )

            processed = tensor2video(output_tile_gpu).to("cpu")
            tile_h = processed.shape[1]
            tile_w = processed.shape[2]

            # Mask is built channel-first and permuted to the frame layout.
            mask_nhwc = create_feather_mask(
                (tile_h, tile_w),
                tile_overlap * scale
            ).to("cpu").permute(0, 2, 3, 1)

            # Position of this tile on the upscaled canvas.
            left, top = x1 * scale, y1 * scale
            right, bottom = left + tile_w, top + tile_h
            canvas[:, top:bottom, left:right, :] += processed * mask_nhwc
            weights[:, top:bottom, left:right, :] += mask_nhwc

            del LQ_tile, output_tile_gpu, processed, input_tile
            clean_vram()

        # Avoid dividing by zero where no weight was accumulated.
        weights[weights == 0] = 1.0
        final_output = canvas / weights
    else:
        log("[FlashVSR] Preparing frames...")
        LQ, th, tw, F = prepare_input_tensor(padded, run_device, scale=scale, dtype=dtype)
        if not isinstance(pipe, FlashVSRTinyLongPipeline):
            LQ = LQ.to(run_device)
        log(f"[FlashVSR] Processing {frame_count} frames...", message_type='info')

        video = pipe(
            **shared_args,
            progress_bar_cmd=cqdm, LQ_video=LQ, num_frames=F, height=th, width=tw,
            topk_ratio=sparse_ratio*768*1280/(th*tw),
        )

        final_output = tensor2video(video).to('cpu')

        del video, LQ
        clean_vram()

    log("[FlashVSR] Done.", message_type='info')

    if frame_count == 1:
        # Single image: collapse the padded copies into one frame via the median.
        final_output = final_output.to(run_device)
        stacked_image_tensor = torch.median(final_output, dim=0).values.unsqueeze(0).float().to('cpu')
        del final_output
        clean_vram()
        return stacked_image_tensor

    # Drop the padding frames again.
    return final_output[:frame_count, :, :, :]
329
+
330
class FlashVSRNodeInitPipe:
    """Node that builds a FlashVSR pipeline once so it can be reused downstream.

    The output is a ``(pipeline, force_offload)`` tuple carried on a ``PIPE``
    socket.
    """

    @classmethod
    def INPUT_TYPES(cls):
        """Describe the node's inputs (model, mode, VAE, precision, device, attention)."""
        return {
            "required": {
                "model": (["FlashVSR", "FlashVSR-v1.1"], {
                    "default": "FlashVSR-v1.1",
                    "tooltip": "Model version."
                }),
                "mode": (["tiny", "tiny-long", "full"], {
                    "default": "tiny",
                    "tooltip": 'Using "tiny-long" mode can significantly reduce VRAM used with long video input.'
                }),
                "alt_vae": (["none"] + folder_paths.get_filename_list("vae"), {
                    "default": "none",
                    "tooltip": 'Replaces the built-in VAE, only available in "full" mode.'
                }),
                "force_offload": ("BOOLEAN", {
                    "default": True,
                    "tooltip": "Offload all weights to CPU after running a workflow to free up VRAM."
                }),
                "precision": (["fp16", "bf16"], {
                    "default": "bf16",
                    "tooltip": "Data and inference precision."
                }),
                "device": (device_choices, {
                    "default": device_choices[0],
                    "tooltip": "Device to load the weights, default: auto (CUDA if available, else CPU)"
                }),
                "attention_mode": (["sparse_sage_attention", "block_sparse_attention"], {
                    "default": "sparse_sage_attention",
                    "tooltip": '"sparse_sage_attention" is available for sm_75 to sm_120\n"block_sparse_attention" is available for sm_80 to sm_100'
                }),
            }
        }

    RETURN_TYPES = ("PIPE",)
    RETURN_NAMES = ("pipe",)
    FUNCTION = "main"
    CATEGORY = "FlashVSR"
    DESCRIPTION = 'Download the entire "FlashVSR" folder with all the files inside it from "https://huggingface.co/JunhaoZhuang/FlashVSR" and put it in the "ComfyUI/models"'

    def main(self, model, mode, alt_vae, force_offload, precision, device, attention_mode):
        """Resolve device, attention backend and dtype, then build the pipeline.

        Returns:
            A one-element tuple holding ``(pipe, force_offload)``.

        Raises:
            RuntimeError: If no usable device can be determined.
        """
        _device = device
        if device == "auto":
            # Prefer CUDA, then MPS; otherwise keep "auto" so the check below fails.
            if torch.cuda.is_available():
                _device = "cuda:0"
            elif torch.backends.mps.is_available():
                _device = "mps"
        if _device == "auto" or _device not in device_choices:
            raise RuntimeError("No devices found to run FlashVSR!")

        if _device.startswith("cuda"):
            torch.cuda.set_device(_device)

        # Module-level switch read by the DiT implementation.
        wan_video_dit.USE_BLOCK_ATTN = attention_mode != "sparse_sage_attention"

        dtype_map = {
            "fp32": torch.float32,
            "fp16": torch.float16,
            "bf16": torch.bfloat16,
        }
        try:
            dtype = dtype_map[precision]
        except (KeyError, TypeError):
            # Unknown precision label: fall back to bf16 instead of failing.
            dtype = torch.bfloat16

        pipe = init_pipeline(model, mode, _device, dtype, alt_vae=alt_vae)
        return((pipe, force_offload),)
399
+
400
class FlashVSRNodeAdv:
    """Upscaling node that runs a pre-built pipeline with every tuning option exposed."""

    RETURN_TYPES = ("IMAGE",)
    RETURN_NAMES = ("image",)
    FUNCTION = "main"
    CATEGORY = "FlashVSR"

    @classmethod
    def INPUT_TYPES(cls):
        """Return the input specification; entry order defines the widget order."""
        required = {}

        # Data inputs.
        required["pipe"] = ("PIPE", {"tooltip": "FlashVSR pipeline"})
        required["frames"] = ("IMAGE", {"tooltip": "Sequential video frames as IMAGE tensor batch"})

        # General options.
        required["scale"] = ("INT", {"default": 2, "min": 2, "max": 4})
        required["color_fix"] = ("BOOLEAN", {
            "default": True,
            "tooltip": "Use wavelet transform to correct output video color.",
        })

        # Tiling options.
        required["tiled_vae"] = ("BOOLEAN", {
            "default": True,
            "tooltip": "Disable tiling: faster decode but higher VRAM usage.\nSet to True for lower memory consumption at the cost of speed.",
        })
        required["tiled_dit"] = ("BOOLEAN", {
            "default": True,
            "tooltip": "Significantly reduces VRAM usage at the cost of speed.",
        })
        required["tile_size"] = ("INT", {"default": 256, "min": 32, "max": 1024, "step": 32})
        required["tile_overlap"] = ("INT", {"default": 24, "min": 8, "max": 512, "step": 8})
        required["unload_dit"] = ("BOOLEAN", {
            "default": False,
            "tooltip": "Unload DiT before decoding to reduce VRAM peak at the cost of speed.",
        })

        # Attention / quality options.
        required["sparse_ratio"] = ("FLOAT", {
            "default": 2.0,
            "min": 1.5,
            "max": 2.0,
            "step": 0.1,
            "display": "slider",
            "tooltip": "Recommended: 1.5 or 2.0\n1.5 → faster; 2.0 → more stable",
        })
        required["kv_ratio"] = ("FLOAT", {
            "default": 3.0,
            "min": 1.0,
            "max": 3.0,
            "step": 0.1,
            "display": "slider",
            "tooltip": "Recommended: 1.0 to 3.0\n1.0 → less vram; 3.0 → high quality",
        })
        required["local_range"] = ("INT", {
            "default": 11,
            "min": 9,
            "max": 11,
            "step": 2,
            "tooltip": "Recommended: 9 or 11\nlocal_range=9 → sharper details; 11 → more stable results",
        })
        required["seed"] = ("INT", {"default": 0, "min": 0, "max": 1125899906842624})

        return {"required": required}

    def main(self, pipe, frames, scale, color_fix, tiled_vae, tiled_dit, tile_size, tile_overlap, unload_dit, sparse_ratio, kv_ratio, local_range, seed):
        """Unpack the ``(pipeline, force_offload)`` pair and run ``flashvsr``."""
        _pipe, force_offload = pipe
        output = flashvsr(
            _pipe, frames, scale, color_fix, tiled_vae, tiled_dit,
            tile_size, tile_overlap, unload_dit, sparse_ratio, kv_ratio,
            local_range, seed, force_offload,
        )
        return(output,)
485
+
486
class FlashVSRNode:
    """All-in-one upscaling node: builds the pipeline and runs it with fixed defaults."""

    RETURN_TYPES = ("IMAGE",)
    RETURN_NAMES = ("image",)
    FUNCTION = "main"
    CATEGORY = "FlashVSR"
    DESCRIPTION = 'Download the entire "FlashVSR" folder with all the files inside it from "https://huggingface.co/JunhaoZhuang/FlashVSR" and put it in the "ComfyUI/models"'

    @classmethod
    def INPUT_TYPES(cls):
        """Return the input specification; entry order defines the widget order."""
        required = {}
        required["frames"] = ("IMAGE", {"tooltip": "Sequential video frames as IMAGE tensor batch"})
        required["model"] = (["FlashVSR", "FlashVSR-v1.1"], {
            "default": "FlashVSR-v1.1",
            "tooltip": "Model version.",
        })
        required["mode"] = (["tiny", "tiny-long", "full"], {
            "default": "tiny",
            "tooltip": 'Using "tiny-long" mode can significantly reduce VRAM used with long video input.',
        })
        required["scale"] = ("INT", {"default": 2, "min": 2, "max": 4})
        required["tiled_vae"] = ("BOOLEAN", {
            "default": True,
            "tooltip": "Disable tiling: faster decode but higher VRAM usage.\nSet to True for lower memory consumption at the cost of speed.",
        })
        required["tiled_dit"] = ("BOOLEAN", {
            "default": True,
            "tooltip": "Significantly reduces VRAM usage at the cost of speed.",
        })
        required["unload_dit"] = ("BOOLEAN", {
            "default": False,
            "tooltip": "Unload DiT before decoding to reduce VRAM peak at the cost of speed.",
        })
        required["seed"] = ("INT", {"default": 0, "min": 0, "max": 1125899906842624})
        return {"required": required}

    def main(self, model, frames, mode, scale, tiled_vae, tiled_dit, unload_dit, seed):
        """Pick a device, build an fp16 pipeline and upscale ``frames``.

        Fixed settings passed to ``flashvsr``: colour fix on, tile size 256,
        tile overlap 24, sparse ratio 2.0, kv ratio 3.0, local range 11 and
        forced offload.

        Raises:
            RuntimeError: If neither CUDA nor MPS is available, or the
                selected device is not listed in ``device_choices``.
        """
        wan_video_dit.USE_BLOCK_ATTN = False

        if torch.cuda.is_available():
            _device = "cuda:0"
        elif torch.backends.mps.is_available():
            _device = "mps"
        else:
            _device = "auto"
        if _device == "auto" or _device not in device_choices:
            raise RuntimeError("No devices found to run FlashVSR!")

        pipe = init_pipeline(model, mode, _device, torch.float16)
        output = flashvsr(
            pipe, frames, scale, True, tiled_vae, tiled_dit,
            256, 24, unload_dit, 2.0, 3.0, 11, seed, True,
        )
        return(output,)
542
+
543
# Node identifier -> class implementing that node.
NODE_CLASS_MAPPINGS = {
    "FlashVSRNode": FlashVSRNode,
    "FlashVSRNodeAdv": FlashVSRNodeAdv,
    "FlashVSRInitPipe": FlashVSRNodeInitPipe,
}

# Node identifier -> human-readable title; keys mirror NODE_CLASS_MAPPINGS.
NODE_DISPLAY_NAME_MAPPINGS = {
    "FlashVSRNode": "FlashVSR Ultra-Fast",
    "FlashVSRNodeAdv": "FlashVSR Ultra-Fast (Advanced)",
    "FlashVSRInitPipe": "FlashVSR Init Pipeline",
}
custom_nodes/ComfyUI-FlashVSR_Ultra_Fast/posi_prompt.pth ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:4601107a11e4e11a936a6b79df579e54dbc99872132bf542151f0ffd65b4b1ef
3
+ size 4195504
custom_nodes/ComfyUI-FlashVSR_Ultra_Fast/requirements.txt ADDED
@@ -0,0 +1,10 @@
 
 
 
 
 
 
 
 
 
 
 
1
+ torch
2
+ torchvision
3
+ numpy
4
+ einops
5
+ safetensors
6
+ tqdm
7
+ pillow
8
+ huggingface_hub
9
+ triton; platform_system!="Windows"
10
+ triton-windows; platform_system=="Windows"
custom_nodes/ComfyUI-FlashVSR_Ultra_Fast/src/LICENSE.txt ADDED
@@ -0,0 +1,201 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ Apache License
2
+ Version 2.0, January 2004
3
+ http://www.apache.org/licenses/
4
+
5
+ TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
6
+
7
+ 1. Definitions.
8
+
9
+ "License" shall mean the terms and conditions for use, reproduction,
10
+ and distribution as defined by Sections 1 through 9 of this document.
11
+
12
+ "Licensor" shall mean the copyright owner or entity authorized by
13
+ the copyright owner that is granting the License.
14
+
15
+ "Legal Entity" shall mean the union of the acting entity and all
16
+ other entities that control, are controlled by, or are under common
17
+ control with that entity. For the purposes of this definition,
18
+ "control" means (i) the power, direct or indirect, to cause the
19
+ direction or management of such entity, whether by contract or
20
+ otherwise, or (ii) ownership of fifty percent (50%) or more of the
21
+ outstanding shares, or (iii) beneficial ownership of such entity.
22
+
23
+ "You" (or "Your") shall mean an individual or Legal Entity
24
+ exercising permissions granted by this License.
25
+
26
+ "Source" form shall mean the preferred form for making modifications,
27
+ including but not limited to software source code, documentation
28
+ source, and configuration files.
29
+
30
+ "Object" form shall mean any form resulting from mechanical
31
+ transformation or translation of a Source form, including but
32
+ not limited to compiled object code, generated documentation,
33
+ and conversions to other media types.
34
+
35
+ "Work" shall mean the work of authorship, whether in Source or
36
+ Object form, made available under the License, as indicated by a
37
+ copyright notice that is included in or attached to the work
38
+ (an example is provided in the Appendix below).
39
+
40
+ "Derivative Works" shall mean any work, whether in Source or Object
41
+ form, that is based on (or derived from) the Work and for which the
42
+ editorial revisions, annotations, elaborations, or other modifications
43
+ represent, as a whole, an original work of authorship. For the purposes
44
+ of this License, Derivative Works shall not include works that remain
45
+ separable from, or merely link (or bind by name) to the interfaces of,
46
+ the Work and Derivative Works thereof.
47
+
48
+ "Contribution" shall mean any work of authorship, including
49
+ the original version of the Work and any modifications or additions
50
+ to that Work or Derivative Works thereof, that is intentionally
51
+ submitted to Licensor for inclusion in the Work by the copyright owner
52
+ or by an individual or Legal Entity authorized to submit on behalf of
53
+ the copyright owner. For the purposes of this definition, "submitted"
54
+ means any form of electronic, verbal, or written communication sent
55
+ to the Licensor or its representatives, including but not limited to
56
+ communication on electronic mailing lists, source code control systems,
57
+ and issue tracking systems that are managed by, or on behalf of, the
58
+ Licensor for the purpose of discussing and improving the Work, but
59
+ excluding communication that is conspicuously marked or otherwise
60
+ designated in writing by the copyright owner as "Not a Contribution."
61
+
62
+ "Contributor" shall mean Licensor and any individual or Legal Entity
63
+ on behalf of whom a Contribution has been received by Licensor and
64
+ subsequently incorporated within the Work.
65
+
66
+ 2. Grant of Copyright License. Subject to the terms and conditions of
67
+ this License, each Contributor hereby grants to You a perpetual,
68
+ worldwide, non-exclusive, no-charge, royalty-free, irrevocable
69
+ copyright license to reproduce, prepare Derivative Works of,
70
+ publicly display, publicly perform, sublicense, and distribute the
71
+ Work and such Derivative Works in Source or Object form.
72
+
73
+ 3. Grant of Patent License. Subject to the terms and conditions of
74
+ this License, each Contributor hereby grants to You a perpetual,
75
+ worldwide, non-exclusive, no-charge, royalty-free, irrevocable
76
+ (except as stated in this section) patent license to make, have made,
77
+ use, offer to sell, sell, import, and otherwise transfer the Work,
78
+ where such license applies only to those patent claims licensable
79
+ by such Contributor that are necessarily infringed by their
80
+ Contribution(s) alone or by combination of their Contribution(s)
81
+ with the Work to which such Contribution(s) was submitted. If You
82
+ institute patent litigation against any entity (including a
83
+ cross-claim or counterclaim in a lawsuit) alleging that the Work
84
+ or a Contribution incorporated within the Work constitutes direct
85
+ or contributory patent infringement, then any patent licenses
86
+ granted to You under this License for that Work shall terminate
87
+ as of the date such litigation is filed.
88
+
89
+ 4. Redistribution. You may reproduce and distribute copies of the
90
+ Work or Derivative Works thereof in any medium, with or without
91
+ modifications, and in Source or Object form, provided that You
92
+ meet the following conditions:
93
+
94
+ (a) You must give any other recipients of the Work or
95
+ Derivative Works a copy of this License; and
96
+
97
+ (b) You must cause any modified files to carry prominent notices
98
+ stating that You changed the files; and
99
+
100
+ (c) You must retain, in the Source form of any Derivative Works
101
+ that You distribute, all copyright, patent, trademark, and
102
+ attribution notices from the Source form of the Work,
103
+ excluding those notices that do not pertain to any part of
104
+ the Derivative Works; and
105
+
106
+ (d) If the Work includes a "NOTICE" text file as part of its
107
+ distribution, then any Derivative Works that You distribute must
108
+ include a readable copy of the attribution notices contained
109
+ within such NOTICE file, excluding those notices that do not
110
+ pertain to any part of the Derivative Works, in at least one
111
+ of the following places: within a NOTICE text file distributed
112
+ as part of the Derivative Works; within the Source form or
113
+ documentation, if provided along with the Derivative Works; or,
114
+ within a display generated by the Derivative Works, if and
115
+ wherever such third-party notices normally appear. The contents
116
+ of the NOTICE file are for informational purposes only and
117
+ do not modify the License. You may add Your own attribution
118
+ notices within Derivative Works that You distribute, alongside
119
+ or as an addendum to the NOTICE text from the Work, provided
120
+ that such additional attribution notices cannot be construed
121
+ as modifying the License.
122
+
123
+ You may add Your own copyright statement to Your modifications and
124
+ may provide additional or different license terms and conditions
125
+ for use, reproduction, or distribution of Your modifications, or
126
+ for any such Derivative Works as a whole, provided Your use,
127
+ reproduction, and distribution of the Work otherwise complies with
128
+ the conditions stated in this License.
129
+
130
+ 5. Submission of Contributions. Unless You explicitly state otherwise,
131
+ any Contribution intentionally submitted for inclusion in the Work
132
+ by You to the Licensor shall be under the terms and conditions of
133
+ this License, without any additional terms or conditions.
134
+ Notwithstanding the above, nothing herein shall supersede or modify
135
+ the terms of any separate license agreement you may have executed
136
+ with Licensor regarding such Contributions.
137
+
138
+ 6. Trademarks. This License does not grant permission to use the trade
139
+ names, trademarks, service marks, or product names of the Licensor,
140
+ except as required for reasonable and customary use in describing the
141
+ origin of the Work and reproducing the content of the NOTICE file.
142
+
143
+ 7. Disclaimer of Warranty. Unless required by applicable law or
144
+ agreed to in writing, Licensor provides the Work (and each
145
+ Contributor provides its Contributions) on an "AS IS" BASIS,
146
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
147
+ implied, including, without limitation, any warranties or conditions
148
+ of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
149
+ PARTICULAR PURPOSE. You are solely responsible for determining the
150
+ appropriateness of using or redistributing the Work and assume any
151
+ risks associated with Your exercise of permissions under this License.
152
+
153
+ 8. Limitation of Liability. In no event and under no legal theory,
154
+ whether in tort (including negligence), contract, or otherwise,
155
+ unless required by applicable law (such as deliberate and grossly
156
+ negligent acts) or agreed to in writing, shall any Contributor be
157
+ liable to You for damages, including any direct, indirect, special,
158
+ incidental, or consequential damages of any character arising as a
159
+ result of this License or out of the use or inability to use the
160
+ Work (including but not limited to damages for loss of goodwill,
161
+ work stoppage, computer failure or malfunction, or any and all
162
+ other commercial damages or losses), even if such Contributor
163
+ has been advised of the possibility of such damages.
164
+
165
+ 9. Accepting Warranty or Additional Liability. While redistributing
166
+ the Work or Derivative Works thereof, You may choose to offer,
167
+ and charge a fee for, acceptance of support, warranty, indemnity,
168
+ or other liability obligations and/or rights consistent with this
169
+ License. However, in accepting such obligations, You may act only
170
+ on Your own behalf and on Your sole responsibility, not on behalf
171
+ of any other Contributor, and only if You agree to indemnify,
172
+ defend, and hold each Contributor harmless for any liability
173
+ incurred by, or claims asserted against, such Contributor by reason
174
+ of your accepting any such warranty or additional liability.
175
+
176
+ END OF TERMS AND CONDITIONS
177
+
178
+ APPENDIX: How to apply the Apache License to your work.
179
+
180
+ To apply the Apache License to your work, attach the following
181
+ boilerplate notice, with the fields enclosed by brackets "[]"
182
+ replaced with your own identifying information. (Don't include
183
+ the brackets!) The text should be enclosed in the appropriate
184
+ comment syntax for the file format. We also recommend that a
185
+ file or class name and description of purpose be included on the
186
+ same "printed page" as the copyright notice for easier
187
+ identification within third-party archives.
188
+
189
+ Copyright [yyyy] [name of copyright owner]
190
+
191
+ Licensed under the Apache License, Version 2.0 (the "License");
192
+ you may not use this file except in compliance with the License.
193
+ You may obtain a copy of the License at
194
+
195
+ http://www.apache.org/licenses/LICENSE-2.0
196
+
197
+ Unless required by applicable law or agreed to in writing, software
198
+ distributed under the License is distributed on an "AS IS" BASIS,
199
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
200
+ See the License for the specific language governing permissions and
201
+ limitations under the License.
custom_nodes/ComfyUI-FlashVSR_Ultra_Fast/src/__init__.py ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ from .models import *
2
+ from .pipelines import *
3
+ from .schedulers import *
custom_nodes/ComfyUI-FlashVSR_Ultra_Fast/src/configs/__init__.py ADDED
File without changes
custom_nodes/ComfyUI-FlashVSR_Ultra_Fast/src/configs/model_config.py ADDED
@@ -0,0 +1,29 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing_extensions import Literal, TypeAlias
2
+
3
+ from ..models.wan_video_dit import WanModel
4
+ from ..models.wan_video_vae import WanVideoVAE
5
+
6
+
7
model_loader_configs = [
    # These configs are provided for detecting model type automatically.
    # The format is (state_dict_keys_hash, state_dict_keys_hash_with_shape, model_names, model_classes, model_resource)
    # Each hash is listed once; a repeated identical entry adds nothing.
    (None, "9269f8db9040a9d860eaca435be61814", ["wan_video_dit"], [WanModel], "civitai"),
    (None, "aafcfd9672c3a2456dc46e1cb6e52c70", ["wan_video_dit"], [WanModel], "civitai"),
    (None, "6bfcfb3b342cb286ce886889d519a77e", ["wan_video_dit"], [WanModel], "civitai"),
    (None, "6d6ccde6845b95ad9114ab993d917893", ["wan_video_dit"], [WanModel], "civitai"),
    (None, "349723183fc063b2bfc10bb2835cf677", ["wan_video_dit"], [WanModel], "civitai"),
    (None, "efa44cddf936c70abd0ea28b6cbe946c", ["wan_video_dit"], [WanModel], "civitai"),
    (None, "3ef3b1f8e1dab83d5b71fd7b617f859f", ["wan_video_dit"], [WanModel], "civitai"),
    (None, "cb104773c6c2cb6df4f9529ad5c60d0b", ["wan_video_dit"], [WanModel], "diffusers"),
    (None, "1378ea763357eea97acdef78e65d6d96", ["wan_video_vae"], [WanVideoVAE], "civitai"),
    (None, "ccc42284ea13e1ad04693284c7a09be6", ["wan_video_vae"], [WanVideoVAE], "civitai"),
]
huggingface_model_loader_configs = [
    # These configs are provided for detecting model type automatically.
    # The format is (architecture_in_huggingface_config, huggingface_lib, model_name, redirected_architecture)
]
patch_model_loader_configs = [
    # These configs are provided for detecting model type automatically.
    # The format is (state_dict_keys_hash_with_shape, model_name, model_class, extra_kwargs)
]
custom_nodes/ComfyUI-FlashVSR_Ultra_Fast/src/models/TCDecoder.py ADDED
@@ -0,0 +1,320 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python3
2
+ """
3
+ Tiny AutoEncoder for Hunyuan Video (Decoder-only, pruned)
4
+ - Encoder removed
5
+ - Transplant/widening helpers removed
6
+ - Deepening (IdentityConv2d+ReLU) is now built into the decoder structure itself
7
+ """
8
+
9
+ import torch
10
+ import torch.nn as nn
11
+ import torch.nn.functional as F
12
+ from tqdm.auto import tqdm
13
+ from collections import namedtuple
14
+ from einops import rearrange
15
+ import torch.nn.init as init
16
+
17
+ DecoderResult = namedtuple("DecoderResult", ("frame", "memory"))
18
+ TWorkItem = namedtuple("TWorkItem", ("input_tensor", "block_index"))
19
+
20
+ # ----------------------------
21
+ # Utility / building blocks
22
+ # ----------------------------
23
+
24
class IdentityConv2d(nn.Conv2d):
    """Shape-preserving ``Conv2d`` whose weights start as the identity map.

    The kernel is Dirac-initialised and padded by ``kernel_size // 2``, so a
    freshly built layer returns its input unchanged.
    """

    def __init__(self, C, kernel_size=3, bias=False):
        super().__init__(C, C, kernel_size, padding=kernel_size // 2, bias=bias)
        with torch.no_grad():
            # Overwrite the default random init with an identity (Dirac) kernel.
            init.dirac_(self.weight)
            has_bias = self.bias is not None
            if has_bias:
                self.bias.zero_()
33
+
34
def conv(n_in, n_out, **kwargs):
    """Return a 3x3 ``nn.Conv2d`` with padding 1; extra kwargs are forwarded to ``nn.Conv2d``."""
    layer = nn.Conv2d(n_in, n_out, kernel_size=3, padding=1, **kwargs)
    return layer
36
+
37
class Clamp(nn.Module):
    """Soft clamp: squashes values smoothly into (-3, 3) via a scaled tanh."""

    def forward(self, x):
        squashed = torch.tanh(x / 3)
        return squashed * 3
40
+
41
class MemBlock(nn.Module):
    """Residual conv block that also looks at the previous frame's features.

    ``forward(x, past)`` concatenates the current input with ``past`` along
    the channel axis, runs three 3x3 convolutions, adds a skip connection
    from ``x`` and applies a final ReLU.
    """

    def __init__(self, n_in, n_out):
        super().__init__()
        # The first conv sees current + past features, hence 2 * n_in input channels.
        layers = [
            conv(n_in * 2, n_out), nn.ReLU(inplace=True),
            conv(n_out, n_out), nn.ReLU(inplace=True),
            conv(n_out, n_out),
        ]
        self.conv = nn.Sequential(*layers)
        if n_in == n_out:
            self.skip = nn.Identity()
        else:
            # A 1x1 projection is only needed when the channel count changes.
            self.skip = nn.Conv2d(n_in, n_out, 1, bias=False)
        self.act = nn.ReLU(inplace=True)

    def forward(self, x, past):
        stacked = torch.cat([x, past], 1)
        residual = self.skip(x)
        return self.act(self.conv(stacked) + residual)
53
+
54
class TPool(nn.Module):
    """Temporal pooling: folds ``stride`` consecutive frames into one via a 1x1 conv."""

    def __init__(self, n_f, stride):
        super().__init__()
        self.stride = stride
        self.conv = nn.Conv2d(n_f*stride, n_f, 1, bias=False)

    def forward(self, x):
        # Input is (frames, C, H, W); groups of `stride` frames are stacked on the channel axis.
        _, channels, height, width = x.shape
        grouped = x.reshape(-1, self.stride * channels, height, width)
        return self.conv(grouped)
62
+
63
class TGrow(nn.Module):
    """Temporal growth: expands each frame into ``stride`` frames via a 1x1 conv."""

    def __init__(self, n_f, stride):
        super().__init__()
        self.stride = stride
        self.conv = nn.Conv2d(n_f, n_f*stride, 1, bias=False)

    def forward(self, x):
        # The conv multiplies channels by `stride`; the reshape moves that factor to the frame axis.
        _, channels, height, width = x.shape
        expanded = self.conv(x)
        return expanded.reshape(-1, channels, height, width)
72
+
73
class PixelShuffle3d(nn.Module):
    """Fold ``ff x hh x ww`` space-time patches into the channel axis.

    Input is ``(B, C, F, H, W)``. The result has ``C * ff * hh * ww`` channels
    and, after the final transpose, frames on axis 1:
    ``(B, F / ff, C * ff * hh * ww, H / hh, W / ww)``.
    """

    def __init__(self, ff, hh, ww):
        super().__init__()
        self.ff = ff
        self.hh = hh
        self.ww = ww

    def forward(self, x):
        # x: (B, C, F, H, W)
        _, _, num_frames, _, _ = x.shape
        remainder = num_frames % self.ff
        if remainder != 0:
            # Left-pad with copies of the first frame so the frame count divides by ff.
            padding = x[:, :, 0:1, :, :].repeat(1, 1, self.ff - remainder, 1, 1)
            x = torch.cat([padding, x], dim=2)
        folded = rearrange(
            x,
            'b c (f ff) (h hh) (w ww) -> b (c ff hh ww) f h w',
            ff=self.ff, hh=self.hh, ww=self.ww
        )
        return folded.transpose(1, 2)
90
+
91
+ # ----------------------------
92
+ # Generic NTCHW graph executor (kept; used by decoder)
93
+ # ----------------------------
94
+
95
def apply_model_with_memblocks(model, x, parallel, show_progress_bar, mem=None):
    """
    Apply a sequential model with memblocks to the given input.

    Args:
    - model: nn.Sequential of blocks to apply
    - x: input data, of dimensions NTCHW
    - parallel: if True, parallelize over timesteps (fast but uses O(T) memory)
        if False, each timestep will be processed sequentially (slow but uses O(1) memory)
    - show_progress_bar: if True, enables tqdm progressbar display
    - mem: per-layer streaming state for the sequential path (one slot per
        block in ``model``). It is updated in place; when omitted a fresh,
        empty state is created. The parallel path never reads or modifies it.

    Returns a tuple ``(output, mem)`` where ``output`` is an NTCHW tensor and
    ``mem`` is the streaming state (returned unchanged in parallel mode).
    """
    # Function-scope import keeps this block self-contained.
    from collections import deque

    assert x.ndim == 5, f"TAEHV operates on NTCHW tensors, but got {x.ndim}-dim tensor"
    N, T, C, H, W = x.shape
    if parallel:
        x = x.reshape(N*T, C, H, W)
        for b in tqdm(model, disable=not show_progress_bar):
            if isinstance(b, MemBlock):
                NT, C, H, W = x.shape
                T = NT // N
                _x = x.reshape(N, T, C, H, W)
                # Shift frames by one along T (zero for the first frame) to form the
                # "previous frame" input. Kept in a local so the caller's `mem`
                # is not replaced by this temporary tensor.
                past = F.pad(_x, (0,0,0,0,0,0,1,0), value=0)[:,:T].reshape(x.shape)
                x = b(x, past)
            else:
                x = b(x)
        NT, C, H, W = x.shape
        T = NT // N
        x = x.view(N, T, C, H, W)
    else:
        if mem is None:
            # Without this, the default argument would fail on the first mem[i] lookup.
            mem = [None] * len(model)
        out = []
        # Depth-first work queue of (tensor, next block index); deque gives O(1) head operations.
        work_queue = deque(TWorkItem(xt, 0) for xt in x.reshape(N, T * C, H, W).chunk(T, dim=1))
        progress_bar = tqdm(range(T), disable=not show_progress_bar)
        while work_queue:
            xt, i = work_queue.popleft()
            if i == 0:
                # A new input frame has entered the pipeline.
                progress_bar.update(1)
            if i == len(model):
                out.append(xt)
            else:
                b = model[i]
                if isinstance(b, MemBlock):
                    if mem[i] is None:
                        # First frame seen by this block: use zeros as the "past".
                        xt_new = b(xt, xt * 0)
                        mem[i] = xt
                    else:
                        xt_new = b(xt, mem[i])
                        mem[i].copy_(xt)
                    work_queue.appendleft(TWorkItem(xt_new, i+1))
                elif isinstance(b, TPool):
                    # Buffer frames until `stride` of them are available, then pool them together.
                    if mem[i] is None:
                        mem[i] = []
                    mem[i].append(xt)
                    if len(mem[i]) > b.stride:
                        raise ValueError("TPool internal state invalid.")
                    elif len(mem[i]) == b.stride:
                        N_, C_, H_, W_ = xt.shape
                        xt = b(torch.cat(mem[i], 1).view(N_*b.stride, C_, H_, W_))
                        mem[i] = []
                        work_queue.appendleft(TWorkItem(xt, i+1))
                elif isinstance(b, TGrow):
                    xt = b(xt)
                    NT, C_, H_, W_ = xt.shape
                    # Push the generated frames in reverse so they are popped in temporal order.
                    for xt_next in reversed(xt.view(N, b.stride*C_, H_, W_).chunk(b.stride, 1)):
                        work_queue.appendleft(TWorkItem(xt_next, i+1))
                else:
                    xt = b(xt)
                    work_queue.appendleft(TWorkItem(xt, i+1))
        progress_bar.close()
        x = torch.stack(out, 1)
    return x, mem
165
+
166
+ # ----------------------------
167
+ # Decoder-only TAEHV
168
+ # ----------------------------
169
+
170
class TAEHV(nn.Module):
    """Decoder-only tiny autoencoder with per-layer streaming memory.

    The decoder is an ``nn.Sequential`` of conv / MemBlock / Upsample / TGrow
    layers in which every ``nn.ReLU`` is followed by an identity-initialised
    conv plus another ReLU ("deepening"). ``self.mem`` holds one state slot
    per decoder layer and is carried across sequential ``decode_video`` calls
    until ``clean_mem`` is called.
    """

    image_channels = 3

    def __init__(
        self,
        checkpoint_path="taehv.pth",
        decoder_time_upscale=(True, True),
        decoder_space_upscale=(True, True, True),
        channels = [256, 128, 64, 64],
        latent_channels = 16
    ):
        """Initialize TAEHV (decoder-only) with built-in deepening after every ReLU.

        Deepening config: how_many_each=1, k=3 (fixed as requested).

        Args:
            checkpoint_path: state-dict file loaded non-strictly after
                construction; pass ``None`` to keep the fresh initialisation.
            decoder_time_upscale: two flags enabling 2x temporal growth in the
                second and third decoder stages.
            decoder_space_upscale: three flags enabling 2x spatial upsampling
                in each decoder stage.
            channels: channel widths of the four decoder stages (read only).
            latent_channels: number of channels of the input latents.
        """
        super().__init__()
        self.latent_channels = latent_channels
        n_f = channels
        # Number of leading output frames dropped on the first decode.
        self.frames_to_trim = 2**sum(decoder_time_upscale) - 1

        # Build the decoder "skeleton"
        base_decoder = nn.Sequential(
            Clamp(), conv(self.latent_channels, n_f[0]), nn.ReLU(inplace=True),

            MemBlock(n_f[0], n_f[0]), MemBlock(n_f[0], n_f[0]), MemBlock(n_f[0], n_f[0]),
            nn.Upsample(scale_factor=2 if decoder_space_upscale[0] else 1),
            TGrow(n_f[0], 1),
            conv(n_f[0], n_f[1], bias=False),

            MemBlock(n_f[1], n_f[1]), MemBlock(n_f[1], n_f[1]), MemBlock(n_f[1], n_f[1]),
            nn.Upsample(scale_factor=2 if decoder_space_upscale[1] else 1),
            TGrow(n_f[1], 2 if decoder_time_upscale[0] else 1),
            conv(n_f[1], n_f[2], bias=False),

            MemBlock(n_f[2], n_f[2]), MemBlock(n_f[2], n_f[2]), MemBlock(n_f[2], n_f[2]),
            nn.Upsample(scale_factor=2 if decoder_space_upscale[2] else 1),
            TGrow(n_f[2], 2 if decoder_time_upscale[1] else 1),
            conv(n_f[2], n_f[3], bias=False),

            nn.ReLU(inplace=True), conv(n_f[3], TAEHV.image_channels),
        )

        # Inline deepening: insert (IdentityConv2d(k=3) + ReLU) after every ReLU
        self.decoder = self._apply_identity_deepen(base_decoder, how_many_each=1, k=3)

        self.pixel_shuffle = PixelShuffle3d(4, 8, 8)

        if checkpoint_path is not None:
            # weights_only=True restricts unpickling to tensor data; strict=False
            # tolerates key mismatches, which are printed below.
            missing_keys = self.load_state_dict(
                self.patch_tgrow_layers(torch.load(checkpoint_path, map_location="cpu", weights_only=True)),
                strict=False
            )
            print('missing_keys', missing_keys)

        # Initialize decoder mem state
        self.mem = [None] * len(self.decoder)

    @staticmethod
    def _apply_identity_deepen(decoder: nn.Sequential, how_many_each=1, k=3) -> nn.Sequential:
        """Return a new Sequential where every nn.ReLU is followed by how_many_each*(IdentityConv2d(k)+ReLU)."""
        new_layers = []
        for b in decoder:
            new_layers.append(b)
            if isinstance(b, nn.ReLU):
                # Deduce channel count from preceding layer
                C = None
                if len(new_layers) >= 2 and isinstance(new_layers[-2], nn.Conv2d):
                    C = new_layers[-2].out_channels
                elif len(new_layers) >= 2 and isinstance(new_layers[-2], MemBlock):
                    C = new_layers[-2].conv[-1].out_channels
                # When the width cannot be deduced the ReLU is left as is.
                if C is not None:
                    for _ in range(how_many_each):
                        new_layers.append(IdentityConv2d(C, kernel_size=k, bias=False))
                        new_layers.append(nn.ReLU(inplace=True))
        return nn.Sequential(*new_layers)

    def patch_tgrow_layers(self, sd):
        """Patch TGrow layers to use a smaller kernel if needed (decoder-only).

        When a checkpoint's TGrow weight has more output channels than this
        model's, only the trailing rows are kept. ``sd`` is modified in place
        and returned.
        """
        new_sd = self.state_dict()
        for i, layer in enumerate(self.decoder):
            if isinstance(layer, TGrow):
                key = f"decoder.{i}.conv.weight"
                if key in sd and sd[key].shape[0] > new_sd[key].shape[0]:
                    sd[key] = sd[key][-new_sd[key].shape[0]:]
        return sd

    def decode_video(self, x, parallel=True, show_progress_bar=False, cond=None):
        """Decode a sequence of frames from latents.

        x: NTCHW latent tensor; returns NTCHW RGB in ~[0, 1].

        Args:
            parallel: process all frames at once (stateless) instead of frame
                by frame with the streaming memory in ``self.mem``.
            show_progress_bar: forwarded to the block executor.
            cond: optional conditioning video; it is pixel-shuffled and
                concatenated to ``x`` along dim 2.

        The first ``frames_to_trim`` output frames are dropped whenever the
        streaming memory is still empty at call time.
        """
        # Slot -8 is the last MemBlock in the default layer layout; it is None
        # until a sequential decode has run. (keeps original relative check)
        trim_flag = self.mem[-8] is None

        if cond is not None:
            x = torch.cat([self.pixel_shuffle(cond), x], dim=2)

        x, mem = apply_model_with_memblocks(self.decoder, x, parallel, show_progress_bar, mem=self.mem)
        if not parallel:
            # Only the sequential path produces streaming state. Persisting the
            # value returned by a parallel run would replace the per-layer list
            # with an unrelated tensor and break the next trim check.
            self.mem = mem

        if trim_flag:
            return x[:, self.frames_to_trim:]
        return x

    def forward(self, *args, **kwargs):
        raise NotImplementedError("Decoder-only model: call decode_video(...) instead.")

    def clean_mem(self):
        """Reset the streaming memory so the next decode starts a new sequence."""
        self.mem = [None] * len(self.decoder)
274
+
275
class DotDict(dict):
    """Dictionary whose items can also be read and written as attributes.

    Missing attributes raise ``AttributeError`` (not ``KeyError``) so that
    ``hasattr``, ``getattr(obj, name, default)`` and ``copy.deepcopy`` behave
    as they do for ordinary objects.
    """

    def __getattr__(self, name):
        # Only called when normal attribute lookup fails; fall back to the items.
        try:
            return self[name]
        except KeyError:
            raise AttributeError(name) from None

    __setattr__ = dict.__setitem__
278
+
279
class TAEW2_1DiffusersWrapper(nn.Module):
    """Thin wrapper that exposes a TAEHV decoder through a VAE-like interface.

    Latents are taken as ``(n, c, t, h, w)``; decoded frames are rescaled
    in place with ``* 2 - 1`` and returned with the same axis order.
    """

    def __init__(self, pretrained_path=None, channels = [256, 128, 64, 64]):
        super().__init__()
        self.dtype = torch.bfloat16
        self.device = "cuda"
        self.taehv = TAEHV(pretrained_path, channels = channels).to(self.dtype)
        self.temperal_downsample = [True, True, False] # [sic]
        self.config = DotDict(scaling_factor=1.0, latents_mean=torch.zeros(16), z_dim=16, latents_std=torch.ones(16))

    def _decode_frames(self, latents, cond=None):
        # Unpacking also enforces that the latents are 5-dimensional.
        batch, chans, frames, height, width = latents.shape
        # The decoder works on NTCHW, so swap the channel and time axes going in and out.
        decoded = self.taehv.decode_video(latents.transpose(1, 2), parallel=False, cond=cond)
        return decoded.transpose(1, 2).mul_(2).sub_(1)

    def decode(self, latents, return_dict=None):
        """Decode latents sequentially; returns a 1-tuple holding the video tensor."""
        return (self._decode_frames(latents),)

    def stream_decode_with_cond(self, latents, tiled=False, cond=None):
        """Decode latents sequentially with optional conditioning; returns the video tensor."""
        return self._decode_frames(latents, cond=cond)

    def clean_mem(self):
        """Reset the wrapped decoder's streaming memory."""
        self.taehv.clean_mem()
298
+
299
+ # ----------------------------
300
+ # Simplified builder (no small, no transplant, no post-hoc deepening)
301
+ # ----------------------------
302
+
303
def build_tcdecoder(new_channels = [512, 256, 128, 128],
                    device="cuda",
                    dtype=torch.bfloat16,
                    new_latent_channels=None):
    """
    Build a "wider" decoder; the depth enhancement (IdentityConv2d+ReLU) is
    already built into TAEHV itself.
    - No small model is created and no weights are transplanted.
    - No checkpoint is loaded: the model is constructed with checkpoint_path=None.

    Returns: big (a single model, in train mode, with its streaming memory reset)
    """
    ctor_kwargs = {"checkpoint_path": None, "channels": new_channels}
    if new_latent_channels is not None:
        # Only override the latent width when the caller asks for it.
        ctor_kwargs["latent_channels"] = new_latent_channels

    big = TAEHV(**ctor_kwargs).to(device).to(dtype).train()
    big.clean_mem()
    return big
custom_nodes/ComfyUI-FlashVSR_Ultra_Fast/src/models/__init__.py ADDED
@@ -0,0 +1 @@
 
 
1
+ from .model_manager import *
custom_nodes/ComfyUI-FlashVSR_Ultra_Fast/src/models/model_manager.py ADDED
@@ -0,0 +1,402 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os, torch, json, importlib
2
+ from typing import List
3
+
4
+ from ..configs.model_config import model_loader_configs, huggingface_model_loader_configs, patch_model_loader_configs
5
+ from .utils import load_state_dict, init_weights_on_device, hash_state_dict_keys, split_state_dict_with_prefix
6
+
7
def load_model_from_single_file(state_dict, model_names, model_classes, model_resource, torch_dtype, device):
    """Instantiate each model class and load it from a single state dict.

    Args:
        state_dict: raw checkpoint state dict.
        model_names: names to report for the loaded models.
        model_classes: classes paired with ``model_names``; each must provide
            ``state_dict_converter()``.
        model_resource: checkpoint flavour, ``"civitai"`` or ``"diffusers"``;
            selects the converter method.
        torch_dtype: target dtype, unless a converter requests
            ``upcast_to_float32`` for its own model.
        device: target device.

    Returns:
        ``(loaded_model_names, loaded_models)``.

    Raises:
        ValueError: if ``model_resource`` is not a supported flavour.
    """
    loaded_model_names, loaded_models = [], []
    for model_name, model_class in zip(model_names, model_classes):
        state_dict_converter = model_class.state_dict_converter()
        if model_resource == "civitai":
            state_dict_results = state_dict_converter.from_civitai(state_dict)
        elif model_resource == "diffusers":
            state_dict_results = state_dict_converter.from_diffusers(state_dict)
        else:
            # Previously this fell through and failed later with an UnboundLocalError.
            raise ValueError(
                f"Unsupported model_resource {model_resource!r}; expected 'civitai' or 'diffusers'."
            )
        # Converters may return either a state dict or (state_dict, constructor kwargs).
        if isinstance(state_dict_results, tuple):
            model_state_dict, extra_kwargs = state_dict_results
        else:
            model_state_dict, extra_kwargs = state_dict_results, {}
        # Per-model dtype: an upcast requested by one model must not leak into the next one.
        model_dtype = torch.float32 if extra_kwargs.get("upcast_to_float32", False) else torch_dtype
        with init_weights_on_device():
            model = model_class(**extra_kwargs)
        if hasattr(model, "eval"):
            model = model.eval()
        model.load_state_dict(model_state_dict, assign=True)
        model = model.to(dtype=model_dtype, device=device)
        loaded_model_names.append(model_name)
        loaded_models.append(model)
    return loaded_model_names, loaded_models
31
+
32
+
33
def load_model_from_huggingface_folder(file_path, model_names, model_classes, torch_dtype, device):
    """Load models from a Hugging Face style folder via ``from_pretrained``.

    Args:
        file_path: folder passed to ``model_class.from_pretrained``.
        model_names: names to report for the loaded models.
        model_classes: classes paired with ``model_names``.
        torch_dtype: target dtype; float32/float16/bfloat16 are passed to
            ``from_pretrained`` directly, other dtypes are applied afterwards.
        device: target device. Moving the model is best effort: a failure is
            reported and the model is kept where it was loaded.

    Returns:
        ``(loaded_model_names, loaded_models)``.
    """
    loaded_model_names, loaded_models = [], []
    for model_name, model_class in zip(model_names, model_classes):
        if torch_dtype in [torch.float32, torch.float16, torch.bfloat16]:
            model = model_class.from_pretrained(file_path, torch_dtype=torch_dtype).eval()
        else:
            model = model_class.from_pretrained(file_path).eval().to(dtype=torch_dtype)
        if torch_dtype == torch.float16 and hasattr(model, "half"):
            model = model.half()
        try:
            model = model.to(device=device)
        except Exception as error:
            # Still best effort, but no longer swallows KeyboardInterrupt/SystemExit,
            # and the failure is reported instead of being silent.
            print(f"Warning: could not move {model_name} to {device}: {error}")
        loaded_model_names.append(model_name)
        loaded_models.append(model)
    return loaded_model_names, loaded_models
49
+
50
+
51
def load_single_patch_model_from_single_file(state_dict, model_name, model_class, base_model, extra_kwargs, torch_dtype, device):
    """Build a patched model from a base model's weights plus a patch state dict.

    A new ``model_class(**extra_kwargs)`` is created, the base model's weights
    are loaded first, then ``state_dict`` is applied on top (both non-strict),
    and the result is moved to ``torch_dtype`` / ``device``. The base model is
    moved to the CPU before the new model is built.
    """
    inherited_weights = base_model.state_dict()
    base_model.to("cpu")
    del base_model  # drop the local reference before building the new model

    patched = model_class(**extra_kwargs)
    # Later loads win: the patch overrides whatever the base provided.
    for weights in (inherited_weights, state_dict):
        patched.load_state_dict(weights, strict=False)
    patched.to(dtype=torch_dtype, device=device)
    return patched
61
+
62
+
63
def load_patch_model_from_single_file(state_dict, model_names, model_classes, extra_kwargs, model_manager, torch_dtype, device):
    """Patch every loaded model whose name matches, removing the originals.

    For each ``(model_name, model_class)`` pair, all models in
    ``model_manager`` registered under ``model_name`` are taken out of the
    manager one at a time and replaced by a patched copy, which is returned
    to the caller (it is not re-registered here).

    Returns:
        ``(loaded_model_names, loaded_models)`` for the patched models.
    """
    loaded_model_names, loaded_models = [], []
    for model_name, model_class in zip(model_names, model_classes):
        while True:
            # Re-scan from the start each round: popping shifts the remaining indices.
            hit = next(
                (idx for idx in range(len(model_manager.model))
                 if model_manager.model_name[idx] == model_name),
                None,
            )
            if hit is None:
                break
            base_model_name = model_manager.model_name[hit]
            base_model_path = model_manager.model_path[hit]
            base_model = model_manager.model[hit]
            print(f" Adding patch model to {base_model_name} ({base_model_path})")
            patched_model = load_single_patch_model_from_single_file(
                state_dict, model_name, model_class, base_model, extra_kwargs, torch_dtype, device)
            loaded_model_names.append(base_model_name)
            loaded_models.append(patched_model)
            model_manager.model.pop(hit)
            model_manager.model_path.pop(hit)
            model_manager.model_name.pop(hit)
    return loaded_model_names, loaded_models
84
+
85
+
86
+
87
class ModelDetectorTemplate:
    """Interface sketch for model detectors: never matches and loads nothing."""

    def __init__(self):
        pass

    def match(self, file_path="", state_dict={}):
        """Report whether this detector can handle the input (always ``False`` here)."""
        return False

    def load(self, file_path="", state_dict={}, device="cuda", torch_dtype=torch.float16, **kwargs):
        """Return ``(model_names, models)``; the template yields two empty lists."""
        names, models = [], []
        return names, models
96
+
97
+
98
+
99
class ModelDetectorFromSingleFile:
    """Detect and load models from a single checkpoint by hashing its state dict.

    Two lookup tables are kept: one keyed by a hash of keys *and* shapes
    (strict match) and one keyed by a hash of keys only (loose match).
    """

    def __init__(self, model_loader_configs=[]):
        self.keys_hash_with_shape_dict = {}
        self.keys_hash_dict = {}
        for metadata in model_loader_configs:
            self.add_model_metadata(*metadata)

    def add_model_metadata(self, keys_hash, keys_hash_with_shape, model_names, model_classes, model_resource):
        """Register one config row; ``keys_hash`` may be ``None`` to skip loose matching."""
        self.keys_hash_with_shape_dict[keys_hash_with_shape] = (model_names, model_classes, model_resource)
        if keys_hash is not None:
            self.keys_hash_dict[keys_hash] = (model_names, model_classes, model_resource)

    def _find_config(self, state_dict):
        """Return ``(model_names, model_classes, model_resource)`` for ``state_dict``, or ``None``."""
        # Strict matching first: keys and tensor shapes must agree.
        keys_hash_with_shape = hash_state_dict_keys(state_dict, with_shape=True)
        if keys_hash_with_shape in self.keys_hash_with_shape_dict:
            return self.keys_hash_with_shape_dict[keys_hash_with_shape]
        # Loose matching: the shape of parameters may be inconsistent, and the
        # state_dict_converter will modify the model architecture.
        keys_hash = hash_state_dict_keys(state_dict, with_shape=False)
        return self.keys_hash_dict.get(keys_hash)

    def match(self, file_path="", state_dict={}):
        """Return ``True`` if the checkpoint corresponds to a registered config.

        Directories never match. When ``state_dict`` is empty or ``None`` it
        is read from ``file_path``.
        """
        if isinstance(file_path, str) and os.path.isdir(file_path):
            return False
        if not state_dict:
            state_dict = load_state_dict(file_path)
        return self._find_config(state_dict) is not None

    def load(self, file_path="", state_dict={}, device="cuda", torch_dtype=torch.float16, **kwargs):
        """Load the models registered for this checkpoint.

        Returns:
            ``(loaded_model_names, loaded_models)``; two empty lists when the
            checkpoint matches no registered config.
        """
        if not state_dict:
            state_dict = load_state_dict(file_path)

        config = self._find_config(state_dict)
        if config is None:
            # The previous implementation returned unbound locals here.
            return [], []
        model_names, model_classes, model_resource = config
        return load_model_from_single_file(state_dict, model_names, model_classes, model_resource, torch_dtype, device)
147
+
148
+
149
+
150
class ModelDetectorFromSplitedSingleFile(ModelDetectorFromSingleFile):
    """Detector for checkpoints that bundle several models under key prefixes.

    The state dict is split with ``split_state_dict_with_prefix`` and each
    part is matched against the registered configs of the parent class.
    """

    def __init__(self, model_loader_configs=[]):
        super().__init__(model_loader_configs)

    def match(self, file_path="", state_dict={}):
        """Return ``True`` if at least one split component matches a registered config."""
        if isinstance(file_path, str) and os.path.isdir(file_path):
            return False
        if not state_dict:
            state_dict = load_state_dict(file_path)
        splited_state_dict = split_state_dict_with_prefix(state_dict)
        for sub_state_dict in splited_state_dict:
            if super().match(file_path, sub_state_dict):
                return True
        return False

    def load(self, file_path="", state_dict={}, device="cuda", torch_dtype=torch.float16, **kwargs):
        """Load models from the matching components of a bundled checkpoint.

        The union of all matching components is tried first; if that union is
        not itself a registered checkpoint, each matching component is loaded
        on its own.

        Returns:
            ``(loaded_model_names, loaded_models)``.
        """
        if not state_dict:
            # Mirror match(): read the checkpoint when no state dict was supplied.
            state_dict = load_state_dict(file_path)

        # Split the state_dict and load from each component
        splited_state_dict = split_state_dict_with_prefix(state_dict)
        valid_state_dict = {}
        for sub_state_dict in splited_state_dict:
            if super().match(file_path, sub_state_dict):
                valid_state_dict.update(sub_state_dict)
        if super().match(file_path, valid_state_dict):
            loaded_model_names, loaded_models = super().load(file_path, valid_state_dict, device, torch_dtype)
        else:
            loaded_model_names, loaded_models = [], []
            for sub_state_dict in splited_state_dict:
                if super().match(file_path, sub_state_dict):
                    # Load the component that matched, not the merged dict
                    # (the merged dict is known not to match in this branch).
                    loaded_model_names_, loaded_models_ = super().load(file_path, sub_state_dict, device, torch_dtype)
                    loaded_model_names += loaded_model_names_
                    loaded_models += loaded_models_
        return loaded_model_names, loaded_models
184
+
185
+
186
+
187
class ModelDetectorFromHuggingfaceFolder:
    """Detect and load models stored as a Hugging Face style folder.

    A folder qualifies when its ``config.json`` names, under ``architectures``
    or ``_class_name``, at least one architecture registered with this
    detector.
    """

    def __init__(self, model_loader_configs=[]):
        self.architecture_dict = {}
        for metadata in model_loader_configs:
            self.add_model_metadata(*metadata)

    def add_model_metadata(self, architecture, huggingface_lib, model_name, redirected_architecture):
        """Register how to load ``architecture``: library to import, model name, optional class override."""
        self.architecture_dict[architecture] = (huggingface_lib, model_name, redirected_architecture)

    @staticmethod
    def _read_architectures(file_path):
        """Return the architecture names declared in ``<file_path>/config.json`` (possibly empty)."""
        with open(os.path.join(file_path, "config.json"), "r") as f:
            config = json.load(f)
        if "architectures" in config:
            return config["architectures"]
        if "_class_name" in config:
            return [config["_class_name"]]
        return []

    def match(self, file_path="", state_dict={}):
        """Return ``True`` if ``file_path`` is a folder declaring a registered architecture."""
        # Anything that is not an existing directory (files, lists, missing paths) cannot match.
        if not isinstance(file_path, str) or not os.path.isdir(file_path):
            return False
        if not os.path.isfile(os.path.join(file_path, "config.json")):
            return False
        # Only claim folders that load() can actually handle; otherwise load()
        # would fail on the architecture lookup.
        return any(
            architecture in self.architecture_dict
            for architecture in self._read_architectures(file_path)
        )

    def load(self, file_path="", state_dict={}, device="cuda", torch_dtype=torch.float16, **kwargs):
        """Load every registered architecture declared by the folder's config.

        Returns:
            ``(loaded_model_names, loaded_models)``.
        """
        loaded_model_names, loaded_models = [], []
        for architecture in self._read_architectures(file_path):
            if architecture not in self.architecture_dict:
                print(f"Skipping unregistered architecture: {architecture}")
                continue
            huggingface_lib, model_name, redirected_architecture = self.architecture_dict[architecture]
            if redirected_architecture is not None:
                architecture = redirected_architecture
            # getattr (unlike calling __getattribute__ directly) goes through the
            # normal attribute protocol, including any lazy-loading hook on the module.
            model_class = getattr(importlib.import_module(huggingface_lib), architecture)
            loaded_model_names_, loaded_models_ = load_model_from_huggingface_folder(file_path, [model_name], [model_class], torch_dtype, device)
            loaded_model_names += loaded_model_names_
            loaded_models += loaded_models_
        return loaded_model_names, loaded_models
225
+
226
+
227
+
228
class ModelDetectorFromPatchedSingleFile:
    """Detect patch checkpoints (by keys+shape hash) and apply them to loaded models."""

    def __init__(self, model_loader_configs=[]):
        self.keys_hash_with_shape_dict = {}
        for entry in model_loader_configs:
            self.add_model_metadata(*entry)

    def add_model_metadata(self, keys_hash_with_shape, model_name, model_class, extra_kwargs):
        """Register the patch target for a given keys+shape hash."""
        self.keys_hash_with_shape_dict[keys_hash_with_shape] = (model_name, model_class, extra_kwargs)

    def match(self, file_path="", state_dict={}):
        """Return ``True`` if the checkpoint's keys+shape hash is registered.

        Only string paths that are not directories are considered.
        """
        if not isinstance(file_path, str):
            return False
        if os.path.isdir(file_path):
            return False
        if len(state_dict) == 0:
            state_dict = load_state_dict(file_path)
        fingerprint = hash_state_dict_keys(state_dict, with_shape=True)
        return fingerprint in self.keys_hash_with_shape_dict

    def load(self, file_path="", state_dict={}, device="cuda", torch_dtype=torch.float16, model_manager=None, **kwargs):
        """Apply the registered patch to matching models held by ``model_manager``.

        Returns:
            ``(loaded_model_names, loaded_models)``; both empty when the
            checkpoint is not registered.
        """
        if len(state_dict) == 0:
            state_dict = load_state_dict(file_path)

        # Load models with strict matching
        fingerprint = hash_state_dict_keys(state_dict, with_shape=True)
        if fingerprint not in self.keys_hash_with_shape_dict:
            return [], []
        # NOTE(review): the stored values are passed on as parallel sequences of
        # names and classes -- confirm the registry entries are lists.
        model_names, model_classes, extra_kwargs = self.keys_hash_with_shape_dict[fingerprint]
        patched_names, patched_models = load_patch_model_from_single_file(
            state_dict, model_names, model_classes, extra_kwargs, model_manager, torch_dtype, device)
        return list(patched_names), list(patched_models)
264
+
265
+
266
+
267
class ModelManager:
    """Registry of loaded models with automatic model-type detection.

    Loaded models, their source paths and their names are kept in three
    parallel lists: ``model``, ``model_path`` and ``model_name``.
    """

    def __init__(
        self,
        torch_dtype=torch.float16,
        device="cuda",
        file_path_list: List[str] = [],
    ):
        """Create the manager and immediately load every path in ``file_path_list``."""
        self.torch_dtype = torch_dtype
        self.device = device
        self.model = []
        self.model_path = []
        self.model_name = []
        # Detectors are tried in this order; the first one that matches wins.
        self.model_detector = [
            ModelDetectorFromSingleFile(model_loader_configs),
            ModelDetectorFromSplitedSingleFile(model_loader_configs),
            ModelDetectorFromHuggingfaceFolder(huggingface_model_loader_configs),
            ModelDetectorFromPatchedSingleFile(patch_model_loader_configs),
        ]
        self.load_models(file_path_list)

    def load_model_from_single_file(self, file_path="", state_dict={}, model_names=[], model_classes=[], model_resource=None):
        """Load explicitly specified models from a single checkpoint and register them."""
        print(f"Loading models from file: {file_path}")
        if len(state_dict) == 0:
            state_dict = load_state_dict(file_path)
        model_names, models = load_model_from_single_file(state_dict, model_names, model_classes, model_resource, self.torch_dtype, self.device)
        for model_name, model in zip(model_names, models):
            self.model.append(model)
            self.model_path.append(file_path)
            self.model_name.append(model_name)

    def load_model_from_huggingface_folder(self, file_path="", model_names=[], model_classes=[]):
        """Load explicitly specified models from a Hugging Face style folder and register them."""
        print(f"Loading models from folder: {file_path}")
        model_names, models = load_model_from_huggingface_folder(file_path, model_names, model_classes, self.torch_dtype, self.device)
        for model_name, model in zip(model_names, models):
            self.model.append(model)
            self.model_path.append(file_path)
            self.model_name.append(model_name)

    def load_patch_model_from_single_file(self, file_path="", state_dict={}, model_names=[], model_classes=[], extra_kwargs={}):
        """Patch already loaded models with ``state_dict`` and register the patched versions."""
        print(f"Loading patch models from file: {file_path}")
        model_names, models = load_patch_model_from_single_file(
            state_dict, model_names, model_classes, extra_kwargs, self, self.torch_dtype, self.device)
        for model_name, model in zip(model_names, models):
            self.model.append(model)
            self.model_path.append(file_path)
            self.model_name.append(model_name)
        print(f" The following patched models are loaded: {model_names}.")

    def load_lora(self, file_path="", state_dict={}, lora_alpha=1.0):
        """Apply a LoRA checkpoint (or a list of them) to every compatible loaded model."""
        if isinstance(file_path, list):
            for file_path_ in file_path:
                self.load_lora(file_path_, state_dict=state_dict, lora_alpha=lora_alpha)
        else:
            print(f"Loading LoRA models from file: {file_path}")
            is_loaded = False
            if len(state_dict) == 0:
                state_dict = load_state_dict(file_path)
            for model_name, model, model_path in zip(self.model_name, self.model, self.model_path):
                # NOTE(review): get_lora_loaders is neither defined nor imported in this
                # module, so this line raises NameError once a model is loaded. Import it
                # from wherever the LoRA loaders live before relying on this method.
                for lora in get_lora_loaders():
                    match_results = lora.match(model, state_dict)
                    if match_results is not None:
                        print(f" Adding LoRA to {model_name} ({model_path}).")
                        lora_prefix, model_resource = match_results
                        lora.load(model, state_dict, lora_prefix, alpha=lora_alpha, model_resource=model_resource)
                        is_loaded = True
                        break
            if not is_loaded:
                print(f" Cannot load LoRA: {file_path}")

    def load_model(self, file_path, model_names=None, device=None, torch_dtype=None):
        """Detect the model type of ``file_path`` and register whatever can be loaded.

        Args:
            file_path: a checkpoint file, a list of checkpoint files whose
                state dicts are merged, or a model folder.
            model_names: forwarded to the detector as ``allowed_model_names``.
            device: overrides the manager's default device.
            torch_dtype: overrides the manager's default dtype.
        """
        if device is None: device = self.device
        if torch_dtype is None: torch_dtype = self.torch_dtype
        if isinstance(file_path, list):
            state_dict = {}
            for path in file_path:
                state_dict.update(load_state_dict(path))
        elif os.path.isfile(file_path):
            state_dict = load_state_dict(file_path)
        else:
            # Folders and unknown paths: hand the detectors an empty dict rather than
            # None, because several of them call len() on this argument.
            state_dict = {}
        for model_detector in self.model_detector:
            if model_detector.match(file_path, state_dict):
                model_names, models = model_detector.load(
                    file_path, state_dict,
                    device=device, torch_dtype=torch_dtype,
                    allowed_model_names=model_names, model_manager=self
                )
                for model_name, model in zip(model_names, models):
                    self.model.append(model)
                    self.model_path.append(file_path)
                    self.model_name.append(model_name)
                break
        else:
            print(f" We cannot detect the model type. No models are loaded.")

    def load_models(self, file_path_list, model_names=None, device=None, torch_dtype=None):
        """Call ``load_model`` for every entry of ``file_path_list``."""
        for file_path in file_path_list:
            self.load_model(file_path, model_names, device=device, torch_dtype=torch_dtype)

    def fetch_model(self, model_name, file_path=None, require_model_path=False):
        """Return the first loaded model called ``model_name``.

        Args:
            model_name: name to look for.
            file_path: when given, only models loaded from this path qualify.
            require_model_path: also return the path the model came from.

        Returns:
            ``None`` when nothing matches; otherwise the model, or
            ``(model, model_path)`` if ``require_model_path`` is true.
        """
        fetched_models = []
        fetched_model_paths = []
        for model, model_path, model_name_ in zip(self.model, self.model_path, self.model_name):
            if file_path is not None and file_path != model_path:
                continue
            if model_name == model_name_:
                fetched_models.append(model)
                fetched_model_paths.append(model_path)
        if len(fetched_models) == 0:
            return None
        if len(fetched_models) == 1:
            print(f"Using {model_name} from {fetched_model_paths[0]}")
        else:
            print(f"More than one {model_name} models are loaded in model manager: {fetched_model_paths}. Using {model_name} from {fetched_model_paths[0]}")
        if require_model_path:
            return fetched_models[0], fetched_model_paths[0]
        else:
            return fetched_models[0]

    def to(self, device):
        """Move every registered model to ``device``."""
        for model in self.model:
            model.to(device)
402
+
custom_nodes/ComfyUI-FlashVSR_Ultra_Fast/src/models/sparse_sage/LICENSE.txt ADDED
@@ -0,0 +1,201 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ Apache License
2
+ Version 2.0, January 2004
3
+ http://www.apache.org/licenses/
4
+
5
+ TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
6
+
7
+ 1. Definitions.
8
+
9
+ "License" shall mean the terms and conditions for use, reproduction,
10
+ and distribution as defined by Sections 1 through 9 of this document.
11
+
12
+ "Licensor" shall mean the copyright owner or entity authorized by
13
+ the copyright owner that is granting the License.
14
+
15
+ "Legal Entity" shall mean the union of the acting entity and all
16
+ other entities that control, are controlled by, or are under common
17
+ control with that entity. For the purposes of this definition,
18
+ "control" means (i) the power, direct or indirect, to cause the
19
+ direction or management of such entity, whether by contract or
20
+ otherwise, or (ii) ownership of fifty percent (50%) or more of the
21
+ outstanding shares, or (iii) beneficial ownership of such entity.
22
+
23
+ "You" (or "Your") shall mean an individual or Legal Entity
24
+ exercising permissions granted by this License.
25
+
26
+ "Source" form shall mean the preferred form for making modifications,
27
+ including but not limited to software source code, documentation
28
+ source, and configuration files.
29
+
30
+ "Object" form shall mean any form resulting from mechanical
31
+ transformation or translation of a Source form, including but
32
+ not limited to compiled object code, generated documentation,
33
+ and conversions to other media types.
34
+
35
+ "Work" shall mean the work of authorship, whether in Source or
36
+ Object form, made available under the License, as indicated by a
37
+ copyright notice that is included in or attached to the work
38
+ (an example is provided in the Appendix below).
39
+
40
+ "Derivative Works" shall mean any work, whether in Source or Object
41
+ form, that is based on (or derived from) the Work and for which the
42
+ editorial revisions, annotations, elaborations, or other modifications
43
+ represent, as a whole, an original work of authorship. For the purposes
44
+ of this License, Derivative Works shall not include works that remain
45
+ separable from, or merely link (or bind by name) to the interfaces of,
46
+ the Work and Derivative Works thereof.
47
+
48
+ "Contribution" shall mean any work of authorship, including
49
+ the original version of the Work and any modifications or additions
50
+ to that Work or Derivative Works thereof, that is intentionally
51
+ submitted to Licensor for inclusion in the Work by the copyright owner
52
+ or by an individual or Legal Entity authorized to submit on behalf of
53
+ the copyright owner. For the purposes of this definition, "submitted"
54
+ means any form of electronic, verbal, or written communication sent
55
+ to the Licensor or its representatives, including but not limited to
56
+ communication on electronic mailing lists, source code control systems,
57
+ and issue tracking systems that are managed by, or on behalf of, the
58
+ Licensor for the purpose of discussing and improving the Work, but
59
+ excluding communication that is conspicuously marked or otherwise
60
+ designated in writing by the copyright owner as "Not a Contribution."
61
+
62
+ "Contributor" shall mean Licensor and any individual or Legal Entity
63
+ on behalf of whom a Contribution has been received by Licensor and
64
+ subsequently incorporated within the Work.
65
+
66
+ 2. Grant of Copyright License. Subject to the terms and conditions of
67
+ this License, each Contributor hereby grants to You a perpetual,
68
+ worldwide, non-exclusive, no-charge, royalty-free, irrevocable
69
+ copyright license to reproduce, prepare Derivative Works of,
70
+ publicly display, publicly perform, sublicense, and distribute the
71
+ Work and such Derivative Works in Source or Object form.
72
+
73
+ 3. Grant of Patent License. Subject to the terms and conditions of
74
+ this License, each Contributor hereby grants to You a perpetual,
75
+ worldwide, non-exclusive, no-charge, royalty-free, irrevocable
76
+ (except as stated in this section) patent license to make, have made,
77
+ use, offer to sell, sell, import, and otherwise transfer the Work,
78
+ where such license applies only to those patent claims licensable
79
+ by such Contributor that are necessarily infringed by their
80
+ Contribution(s) alone or by combination of their Contribution(s)
81
+ with the Work to which such Contribution(s) was submitted. If You
82
+ institute patent litigation against any entity (including a
83
+ cross-claim or counterclaim in a lawsuit) alleging that the Work
84
+ or a Contribution incorporated within the Work constitutes direct
85
+ or contributory patent infringement, then any patent licenses
86
+ granted to You under this License for that Work shall terminate
87
+ as of the date such litigation is filed.
88
+
89
+ 4. Redistribution. You may reproduce and distribute copies of the
90
+ Work or Derivative Works thereof in any medium, with or without
91
+ modifications, and in Source or Object form, provided that You
92
+ meet the following conditions:
93
+
94
+ (a) You must give any other recipients of the Work or
95
+ Derivative Works a copy of this License; and
96
+
97
+ (b) You must cause any modified files to carry prominent notices
98
+ stating that You changed the files; and
99
+
100
+ (c) You must retain, in the Source form of any Derivative Works
101
+ that You distribute, all copyright, patent, trademark, and
102
+ attribution notices from the Source form of the Work,
103
+ excluding those notices that do not pertain to any part of
104
+ the Derivative Works; and
105
+
106
+ (d) If the Work includes a "NOTICE" text file as part of its
107
+ distribution, then any Derivative Works that You distribute must
108
+ include a readable copy of the attribution notices contained
109
+ within such NOTICE file, excluding those notices that do not
110
+ pertain to any part of the Derivative Works, in at least one
111
+ of the following places: within a NOTICE text file distributed
112
+ as part of the Derivative Works; within the Source form or
113
+ documentation, if provided along with the Derivative Works; or,
114
+ within a display generated by the Derivative Works, if and
115
+ wherever such third-party notices normally appear. The contents
116
+ of the NOTICE file are for informational purposes only and
117
+ do not modify the License. You may add Your own attribution
118
+ notices within Derivative Works that You distribute, alongside
119
+ or as an addendum to the NOTICE text from the Work, provided
120
+ that such additional attribution notices cannot be construed
121
+ as modifying the License.
122
+
123
+ You may add Your own copyright statement to Your modifications and
124
+ may provide additional or different license terms and conditions
125
+ for use, reproduction, or distribution of Your modifications, or
126
+ for any such Derivative Works as a whole, provided Your use,
127
+ reproduction, and distribution of the Work otherwise complies with
128
+ the conditions stated in this License.
129
+
130
+ 5. Submission of Contributions. Unless You explicitly state otherwise,
131
+ any Contribution intentionally submitted for inclusion in the Work
132
+ by You to the Licensor shall be under the terms and conditions of
133
+ this License, without any additional terms or conditions.
134
+ Notwithstanding the above, nothing herein shall supersede or modify
135
+ the terms of any separate license agreement you may have executed
136
+ with Licensor regarding such Contributions.
137
+
138
+ 6. Trademarks. This License does not grant permission to use the trade
139
+ names, trademarks, service marks, or product names of the Licensor,
140
+ except as required for reasonable and customary use in describing the
141
+ origin of the Work and reproducing the content of the NOTICE file.
142
+
143
+ 7. Disclaimer of Warranty. Unless required by applicable law or
144
+ agreed to in writing, Licensor provides the Work (and each
145
+ Contributor provides its Contributions) on an "AS IS" BASIS,
146
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
147
+ implied, including, without limitation, any warranties or conditions
148
+ of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
149
+ PARTICULAR PURPOSE. You are solely responsible for determining the
150
+ appropriateness of using or redistributing the Work and assume any
151
+ risks associated with Your exercise of permissions under this License.
152
+
153
+ 8. Limitation of Liability. In no event and under no legal theory,
154
+ whether in tort (including negligence), contract, or otherwise,
155
+ unless required by applicable law (such as deliberate and grossly
156
+ negligent acts) or agreed to in writing, shall any Contributor be
157
+ liable to You for damages, including any direct, indirect, special,
158
+ incidental, or consequential damages of any character arising as a
159
+ result of this License or out of the use or inability to use the
160
+ Work (including but not limited to damages for loss of goodwill,
161
+ work stoppage, computer failure or malfunction, or any and all
162
+ other commercial damages or losses), even if such Contributor
163
+ has been advised of the possibility of such damages.
164
+
165
+ 9. Accepting Warranty or Additional Liability. While redistributing
166
+ the Work or Derivative Works thereof, You may choose to offer,
167
+ and charge a fee for, acceptance of support, warranty, indemnity,
168
+ or other liability obligations and/or rights consistent with this
169
+ License. However, in accepting such obligations, You may act only
170
+ on Your own behalf and on Your sole responsibility, not on behalf
171
+ of any other Contributor, and only if You agree to indemnify,
172
+ defend, and hold each Contributor harmless for any liability
173
+ incurred by, or claims asserted against, such Contributor by reason
174
+ of your accepting any such warranty or additional liability.
175
+
176
+ END OF TERMS AND CONDITIONS
177
+
178
+ APPENDIX: How to apply the Apache License to your work.
179
+
180
+ To apply the Apache License to your work, attach the following
181
+ boilerplate notice, with the fields enclosed by brackets "[]"
182
+ replaced with your own identifying information. (Don't include
183
+ the brackets!) The text should be enclosed in the appropriate
184
+ comment syntax for the file format. We also recommend that a
185
+ file or class name and description of purpose be included on the
186
+ same "printed page" as the copyright notice for easier
187
+ identification within third-party archives.
188
+
189
+ Copyright 2025 Jintao Zhang
190
+
191
+ Licensed under the Apache License, Version 2.0 (the "License");
192
+ you may not use this file except in compliance with the License.
193
+ You may obtain a copy of the License at
194
+
195
+ http://www.apache.org/licenses/LICENSE-2.0
196
+
197
+ Unless required by applicable law or agreed to in writing, software
198
+ distributed under the License is distributed on an "AS IS" BASIS,
199
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
200
+ See the License for the specific language governing permissions and
201
+ limitations under the License.
custom_nodes/ComfyUI-FlashVSR_Ultra_Fast/src/models/sparse_sage/core.py ADDED
@@ -0,0 +1,45 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ https://github.com/jt-zhang/Sparse_SageAttention_API
3
+
4
+ Copyright (c) 2024 by SageAttention team.
5
+
6
+ Licensed under the Apache License, Version 2.0 (the "License");
7
+ you may not use this file except in compliance with the License.
8
+ You may obtain a copy of the License at
9
+
10
+ http://www.apache.org/licenses/LICENSE-2.0
11
+
12
+ Unless required by applicable law or agreed to in writing, software
13
+ distributed under the License is distributed on an "AS IS" BASIS,
14
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
15
+ See the License for the specific language governing permissions and
16
+ limitations under the License.
17
+ """
18
+
19
+ from .quant_per_block import per_block_int8
20
+ from .sparse_int8_attn import forward as sparse_sageattn_fwd
21
+ import torch
22
+
23
+
24
def sparse_sageattn(q, k, v, mask_id = None, is_causal=False, tensor_layout="HND"):
    """Run block-sparse attention on per-block INT8-quantized q/k.

    Args:
        q, k, v: 4-D tensors. With ``tensor_layout="HND"`` the head axis is
            dim 1 and the sequence axis is dim 2; with ``"NHD"`` they are
            swapped.
        mask_id: Block mask with one entry per (batch, head, 128-row query
            block, 64-row key block); a non-zero entry means the key block is
            attended. When ``None`` a dense all-ones mask is built.
        is_causal: Forwarded to the attention kernel launcher.
        tensor_layout: ``"HND"`` or ``"NHD"``.

    Returns:
        The attention output, produced in ``q``'s dtype.
    """
    # Block sizes must match the defaults used by per_block_int8 (BLKQ/BLKK)
    # and by the kernel launcher (BLOCK_M/BLOCK_N).
    q_block, k_block = 128, 64

    # Anything other than "NHD" is treated as head-major here; an unknown
    # layout string is rejected by per_block_int8 further down.
    seq_dim = 1 if tensor_layout == "NHD" else 2
    head_axis = 3 - seq_dim

    if mask_id is None:
        # The last axis counts *key* blocks, so it must come from k's sequence
        # length. The previous default used q.shape[3] (the head dimension),
        # which under-sizes the mask whenever kv_len exceeds that and lets the
        # kernel read mask entries past the end of the tensor.
        mask_id = torch.ones(
            (
                q.shape[0],
                q.shape[head_axis],
                (q.shape[seq_dim] + q_block - 1) // q_block,
                (k.shape[seq_dim] + k_block - 1) // k_block,
            ),
            dtype=torch.int8,
            device=q.device,
        )

    output_dtype = q.dtype
    if output_dtype == torch.bfloat16 or output_dtype == torch.float32:
        # v is handed to the kernel in fp16; the result is still written out
        # in output_dtype.
        v = v.to(torch.float16)

    # Mean of k over the sequence axis; per_block_int8 subtracts it from k
    # before quantizing.
    km = k.mean(dim=seq_dim, keepdim=True)

    q_int8, q_scale, k_int8, k_scale = per_block_int8(q, k, km=km, tensor_layout=tensor_layout)

    o = sparse_sageattn_fwd(
        q_int8, k_int8, mask_id, v, q_scale, k_scale,
        is_causal=is_causal, tensor_layout=tensor_layout, output_dtype=output_dtype
    )
    return o
43
+
44
+
45
+ # flops = 4 * q.size(0) * q.size(1) * q.size(2)**2 * q.size(3) / (2 if is_causal else 1)
custom_nodes/ComfyUI-FlashVSR_Ultra_Fast/src/models/sparse_sage/quant_per_block.py ADDED
@@ -0,0 +1,101 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Copyright (c) 2024 by SageAttention team.
3
+
4
+ Licensed under the Apache License, Version 2.0 (the "License");
5
+ you may not use this file except in compliance with the License.
6
+ You may obtain a copy of the License at
7
+
8
+ http://www.apache.org/licenses/LICENSE-2.0
9
+
10
+ Unless required by applicable law or agreed to in writing, software
11
+ distributed under the License is distributed on an "AS IS" BASIS,
12
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ See the License for the specific language governing permissions and
14
+ limitations under the License.
15
+ """
16
+
17
+ import torch
18
+ import triton
19
+ import triton.language as tl
20
+
21
# Per-block symmetric INT8 quantization kernel.
#
# Each program handles one tile of BLK sequence rows by C channels for a
# single (batch, head) pair, writes the int8 tile to Output and one float
# scale for the whole tile to Scale.
#
#   Input / Output : base pointers of the source and int8 destination.
#   Scale          : base pointer of the per-block scale tensor.
#   L              : number of valid sequence rows; rows >= L are masked.
#   stride_*z/h/n  : element strides for batch / head / sequence of Input (i),
#                    Output (o) and Scale (s, batch and head only).
#   sm_scale       : factor multiplied into the values before quantizing.
#   C, BLK         : compile-time channel count and rows per block.
#
# NOTE(review): channel offsets are added without a stride, so the innermost
# dimension of Input and Output is assumed to have stride 1 - confirm callers
# always pass tensors that satisfy this.
@triton.jit
def quant_per_block_int8_kernel(Input, Output, Scale, L,
                                stride_iz, stride_ih, stride_in,
                                stride_oz, stride_oh, stride_on,
                                stride_sz, stride_sh,
                                sm_scale,
                                C: tl.constexpr, BLK: tl.constexpr):
    # Launch grid (see per_block_int8): axis 0 = sequence block, 1 = head, 2 = batch.
    off_blk = tl.program_id(0)
    off_h = tl.program_id(1)
    off_b = tl.program_id(2)

    # Row indices covered by this block and the full channel range.
    offs_n = off_blk * BLK + tl.arange(0, BLK)
    offs_k = tl.arange(0, C)

    input_ptrs = Input + off_b * stride_iz + off_h * stride_ih + offs_n[:, None] * stride_in + offs_k[None, :]
    output_ptrs = Output + off_b * stride_oz + off_h * stride_oh + offs_n[:, None] * stride_on + offs_k[None, :]
    # One scale per block; consecutive blocks are adjacent in Scale.
    scale_ptrs = Scale + off_b * stride_sz + off_h * stride_sh + off_blk

    # NOTE(review): no `other=` value is given for masked-out rows, yet they
    # take part in the max below - presumably they read as 0; confirm.
    x = tl.load(input_ptrs, mask=offs_n[:, None] < L)
    x = x.to(tl.float32)
    x *= sm_scale
    # Symmetric scale: the largest magnitude in the tile maps to 127.
    # NOTE(review): an all-zero tile gives scale == 0 and a 0/0 division on
    # the next line - confirm whether that can occur in practice.
    scale = tl.max(tl.abs(x)) / 127.
    x_int8 = x / scale
    # Add +/-0.5 by sign before the cast; this rounds half away from zero
    # provided the int8 cast truncates toward zero (presumably - confirm).
    x_int8 += 0.5 * tl.where(x_int8 >= 0, 1, -1)
    x_int8 = x_int8.to(tl.int8)
    tl.store(output_ptrs, x_int8, mask=offs_n[:, None] < L)
    tl.store(scale_ptrs, scale)
48
+
49
def per_block_int8(q, k, km=None, BLKQ=128, BLKK=64, sm_scale=None, tensor_layout="HND"):
    """Quantize q and k to INT8 with one scale per sequence block.

    Args:
        q, k: 4-D tensors laid out as (batch, head, seq, dim) for ``"HND"`` or
            (batch, seq, head, dim) for ``"NHD"``.
        km: Optional tensor subtracted from ``k`` before quantization.
        BLKQ, BLKK: Rows per quantization block for q and k.
        sm_scale: Scale folded into q; defaults to ``head_dim ** -0.5``. It is
            additionally multiplied by 1.44269504 (approximately log2(e),
            presumably to pair with the exp2-based softmax in the attention
            kernel).
        tensor_layout: ``"HND"`` or ``"NHD"``.

    Returns:
        ``(q_int8, q_scale, k_int8, k_scale)`` where the scales are float32
        tensors of shape (batch, heads, number_of_blocks).

    Raises:
        ValueError: If ``tensor_layout`` is not one of the two known layouts.
    """
    q_int8 = torch.empty(q.shape, dtype=torch.int8, device=q.device)
    k_int8 = torch.empty(k.shape, dtype=torch.int8, device=k.device)

    if km is not None:
        k = k - km

    if tensor_layout == "HND":
        b, h_qo, qo_len, head_dim = q.shape
        _, h_kv, kv_len, _ = k.shape
        head_axis, seq_axis = 1, 2
    elif tensor_layout == "NHD":
        b, qo_len, h_qo, head_dim = q.shape
        _, kv_len, h_kv, _ = k.shape
        head_axis, seq_axis = 2, 1
    else:
        raise ValueError(f"Unknown tensor layout: {tensor_layout}")

    def _bhn_strides(tensor):
        # (batch, head, sequence) strides in the order the kernel expects.
        return tensor.stride(0), tensor.stride(head_axis), tensor.stride(seq_axis)

    q_blocks = (qo_len + BLKQ - 1) // BLKQ
    k_blocks = (kv_len + BLKK - 1) // BLKK
    q_scale = torch.empty((b, h_qo, q_blocks), device=q.device, dtype=torch.float32)
    k_scale = torch.empty((b, h_kv, k_blocks), device=q.device, dtype=torch.float32)

    if sm_scale is None:
        sm_scale = head_dim**-0.5

    def _quantize(src, dst, scales, length, heads, block_rows, factor):
        # One kernel program per (sequence block, head, batch entry).
        launch_grid = ((length + block_rows - 1) // block_rows, heads, b)
        quant_per_block_int8_kernel[launch_grid](
            src, dst, scales, length,
            *_bhn_strides(src),
            *_bhn_strides(dst),
            scales.stride(0), scales.stride(1),
            sm_scale=factor,
            C=head_dim, BLK=block_rows
        )

    # q carries the softmax scale; k is quantized unscaled.
    _quantize(q, q_int8, q_scale, qo_len, h_qo, BLKQ, (sm_scale * 1.44269504))
    _quantize(k, k_int8, k_scale, kv_len, h_kv, BLKK, 1.0)

    return q_int8, q_scale, k_int8, k_scale
custom_nodes/ComfyUI-FlashVSR_Ultra_Fast/src/models/sparse_sage/sparse_int8_attn.py ADDED
@@ -0,0 +1,162 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Copyright (c) 2024 by SageAttention team.
3
+
4
+ Licensed under the Apache License, Version 2.0 (the "License");
5
+ you may not use this file except in compliance with the License.
6
+ You may obtain a copy of the License at
7
+
8
+ http://www.apache.org/licenses/LICENSE-2.0
9
+
10
+ Unless required by applicable law or agreed to in writing, software
11
+ distributed under the License is distributed on an "AS IS" BASIS,
12
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ See the License for the specific language governing permissions and
14
+ limitations under the License.
15
+ """
16
+
17
+ import torch, math
18
+ import triton
19
+ import triton.language as tl
20
+ import torch.nn.functional as F
21
+
22
# Inner loop of the block-sparse attention forward pass: walks key/value
# blocks of BLOCK_N rows for one query block and updates the online-softmax
# state (running max `old_m`, normaliser `l_i`, un-normalised output `acc`).
#
# STAGE selects which key range is visited:
#   1 -> keys [0, start_m * BLOCK_M)                     (blocks before the diagonal)
#   2 -> keys [start_m * BLOCK_M, (start_m + 1) * BLOCK_M)  (diagonal block, with
#        an element-wise causal mask applied)
#   3 -> keys [0, kv_len)                                 (all keys)
#
# K_bid_ptr points at this query block's row of the block mask; a zero entry
# skips the corresponding key block entirely.
# Returns the updated (acc, l_i, old_m).
@triton.jit
def _attn_fwd_inner(acc, l_i, old_m, q, q_scale, kv_len,
                    K_ptrs, K_bid_ptr, K_scale_ptr, V_ptrs, stride_kn, stride_vn, start_m,
                    BLOCK_M: tl.constexpr, HEAD_DIM: tl.constexpr, BLOCK_N: tl.constexpr,
                    STAGE: tl.constexpr, offs_m: tl.constexpr, offs_n: tl.constexpr,
                    ):
    if STAGE == 1:
        lo, hi = 0, start_m * BLOCK_M
    elif STAGE == 2:
        lo, hi = start_m * BLOCK_M, (start_m + 1) * BLOCK_M
        lo = tl.multiple_of(lo, BLOCK_M)
        # Starting mid-sequence: advance the K scale, K and V pointers to `lo`.
        # K_bid_ptr is not advanced because the mask is indexed with the
        # absolute block number below.
        K_scale_ptr += lo // BLOCK_N
        K_ptrs += stride_kn * lo
        V_ptrs += stride_vn * lo
    elif STAGE == 3:
        lo, hi = 0, kv_len
    for start_n in range(lo, hi, BLOCK_N):
        # Block-mask entry for (this query block, key block start_n // BLOCK_N).
        kbid = tl.load(K_bid_ptr + start_n//BLOCK_N)
        if kbid:
            # Key columns past kv_len are masked out of the load.
            k_mask = offs_n[None, :] < (kv_len - start_n)
            k = tl.load(K_ptrs, mask = k_mask)
            k_scale = tl.load(K_scale_ptr)
            # Rescale the quantized q.k product with both per-block scales.
            qk = tl.dot(q, k).to(tl.float32) * q_scale * k_scale
            if STAGE == 2:
                # Causal mask inside the diagonal block: a query may only see
                # keys at the same or an earlier position; others get -1.0e6.
                mask = offs_m[:, None] >= (start_n + offs_n[None, :])
                qk = qk + tl.where(mask, 0, -1.0e6)
                local_m = tl.max(qk, 1)
                new_m = tl.maximum(old_m, local_m)
                qk -= new_m[:, None]
            else:
                local_m = tl.max(qk, 1)
                new_m = tl.maximum(old_m, local_m)
                qk = qk - new_m[:, None]

            # Base-2 exponentials; per_block_int8 multiplies q's scale by
            # 1.44269504, presumably log2(e), so this matches exp().
            p = tl.math.exp2(qk)
            l_ij = tl.sum(p, 1)
            # Rescale the previous state to the new running maximum.
            alpha = tl.math.exp2(old_m - new_m)
            l_i = l_i * alpha + l_ij
            acc = acc * alpha[:, None]
            v = tl.load(V_ptrs, mask = offs_n[:, None] < (kv_len - start_n))
            # P.V is computed on fp16 inputs with an fp16 result type and then
            # added into acc.
            p = p.to(tl.float16)
            acc += tl.dot(p, v, out_dtype=tl.float16)
            old_m = new_m
        # Pointers advance for every block, including skipped ones, so they
        # stay aligned with start_n.
        K_ptrs += BLOCK_N * stride_kn
        K_scale_ptr += 1
        V_ptrs += BLOCK_N * stride_vn
    return acc, l_i, old_m
69
+
70
# Block-sparse attention forward kernel. One program computes the output for
# one BLOCK_M-row query block of one (batch, head) pair.
#
#   Q, K, V, Out       : base pointers; stride_{q,k,v,o}{z,h,n} are their
#                        batch / head / sequence strides in elements.
#   K_blkid            : block mask; stride_kbidq and stride_kbidk are the
#                        strides the launcher takes from its dims 1 and 2.
#   Q_scale, K_scale   : per-block quantization scales.
#   H                  : number of query heads.
#   num_kv_groups      : query heads per key/value head.
#   STAGE              : 1 = non-causal, 3 = causal (as set by `forward`).
#
# NOTE(review): channel offsets (offs_k) are added without a stride, so the
# innermost dimension of Q, K, V and Out is assumed to have stride 1 - confirm.
@triton.jit
def _attn_fwd(Q, K, K_blkid, V, Q_scale, K_scale, Out,
              stride_qz, stride_qh, stride_qn,
              stride_kz, stride_kh, stride_kn,
              stride_vz, stride_vh, stride_vn,
              stride_oz, stride_oh, stride_on,
              stride_kbidq, stride_kbidk,
              qo_len, kv_len, H:tl.constexpr, num_kv_groups:tl.constexpr,
              HEAD_DIM: tl.constexpr,
              BLOCK_M: tl.constexpr,
              BLOCK_N: tl.constexpr,
              STAGE: tl.constexpr
              ):
    # Launch grid (see forward): axis 0 = query block, 1 = head, 2 = batch.
    start_m = tl.program_id(0)
    off_z = tl.program_id(2).to(tl.int64)
    off_h = tl.program_id(1).to(tl.int64)
    # Scale tensors are indexed as if densely packed (batch, head, block).
    q_scale_offset = (off_z * H + off_h) * tl.cdiv(qo_len, BLOCK_M)
    k_scale_offset = (off_z * (H // num_kv_groups) + off_h // num_kv_groups) * tl.cdiv(kv_len, BLOCK_N)
    # NOTE(review): the mask is addressed by key/value-head index and only the
    # dim-1 stride, i.e. as if its batch stride were (H // num_kv_groups) *
    # stride_kbidq and its last dim had stride 1 - confirm against the shape
    # of the mask callers pass.
    k_bid_offset = (off_z * (H // num_kv_groups) + off_h // num_kv_groups) * stride_kbidq
    offs_m = start_m * BLOCK_M + tl.arange(0, BLOCK_M)
    offs_n = tl.arange(0, BLOCK_N)
    offs_k = tl.arange(0, HEAD_DIM)
    Q_ptrs = Q + (off_z * stride_qz + off_h * stride_qh) + offs_m[:, None] * stride_qn + offs_k[None, :]
    Q_scale_ptr = Q_scale + q_scale_offset + start_m
    # K is addressed transposed (HEAD_DIM x BLOCK_N) so tl.dot(q, k) yields q.k^T.
    K_ptrs = K + (off_z * stride_kz + (off_h // num_kv_groups) * stride_kh) + offs_n[None, :] * stride_kn + offs_k[:, None]
    K_scale_ptr = K_scale + k_scale_offset
    # Row of the block mask belonging to this query block.
    K_bid_ptr = K_blkid + k_bid_offset + start_m * stride_kbidk
    V_ptrs = V + (off_z * stride_vz + (off_h // num_kv_groups) * stride_vh) + offs_n[:, None] * stride_vn + offs_k[None, :]
    O_block_ptr = Out + (off_z * stride_oz + off_h * stride_oh) + offs_m[:, None] * stride_on + offs_k[None, :]
    # Online-softmax state: running max, normaliser and output accumulator.
    m_i = tl.zeros([BLOCK_M], dtype=tl.float32) - float("inf")
    l_i = tl.zeros([BLOCK_M], dtype=tl.float32) + 1.0
    acc = tl.zeros([BLOCK_M, HEAD_DIM], dtype=tl.float32)
    q = tl.load(Q_ptrs, mask = offs_m[:, None] < qo_len)
    q_scale = tl.load(Q_scale_ptr)
    # 4 - STAGE maps STAGE 1 (non-causal) to inner stage 3 (all keys) and
    # STAGE 3 (causal) to inner stage 1 (key blocks before the diagonal).
    acc, l_i, m_i = _attn_fwd_inner(acc, l_i, m_i, q, q_scale, kv_len, K_ptrs, K_bid_ptr, K_scale_ptr, V_ptrs, stride_kn, stride_vn,
                                    start_m,
                                    BLOCK_M, HEAD_DIM, BLOCK_N,
                                    4 - STAGE, offs_m, offs_n
                                    )
    if STAGE != 1:
        # Causal only: process the diagonal block with the element-wise mask.
        acc, l_i, _ = _attn_fwd_inner(acc, l_i, m_i, q, q_scale, kv_len, K_ptrs, K_bid_ptr, K_scale_ptr, V_ptrs, stride_kn, stride_vn,
                                      start_m,
                                      BLOCK_M, HEAD_DIM, BLOCK_N,
                                      2, offs_m, offs_n
                                      )
    # Final softmax normalisation, then store only the valid query rows.
    acc = acc / l_i[:, None]
    tl.store(O_block_ptr, acc.to(Out.type.element_ty), mask = (offs_m[:, None] < qo_len))
117
+
118
+
119
def forward(q, k, k_block_id, v, q_scale, k_scale, is_causal=False, tensor_layout="HND", output_dtype=torch.float16):
    """Launch the block-sparse attention kernel and return its output.

    Args:
        q, k: Quantized query/key tensors (4-D) with their per-block scales in
            ``q_scale`` / ``k_scale``.
        k_block_id: Block mask; its strides along dims 1 and 2 are handed to
            the kernel.
        v: Value tensor, laid out like ``k``.
        is_causal: Selects the causal kernel path; requires equal query and
            key lengths.
        tensor_layout: ``"HND"`` (batch, head, seq, dim) or ``"NHD"``
            (batch, seq, head, dim).
        output_dtype: dtype of the returned tensor.

    Returns:
        A tensor with ``q``'s shape and dtype ``output_dtype``.

    Raises:
        ValueError: For an unsupported ``tensor_layout``.
        AssertionError: If ``is_causal`` and the sequence lengths differ.
    """
    BLOCK_M, BLOCK_N = 128, 64
    stage = 3 if is_causal else 1
    o = torch.empty(q.shape, dtype=output_dtype, device=q.device)

    if tensor_layout == "HND":
        b, h_qo, qo_len, head_dim = q.shape
        _, h_kv, kv_len, _ = k.shape
        head_axis, seq_axis = 1, 2
    elif tensor_layout == "NHD":
        b, qo_len, h_qo, head_dim = q.shape
        _, kv_len, h_kv, _ = k.shape
        head_axis, seq_axis = 2, 1
    else:
        raise ValueError(f"tensor_layout {tensor_layout} not supported")

    def _bhn_strides(tensor):
        # (batch, head, sequence) strides in the order the kernel expects.
        return tensor.stride(0), tensor.stride(head_axis), tensor.stride(seq_axis)

    if is_causal:
        assert qo_len == kv_len, "qo_len and kv_len must be equal for causal attention"

    # Query heads sharing one key/value head.
    num_kv_groups = h_qo // h_kv

    launch_grid = (triton.cdiv(qo_len, BLOCK_M), h_qo, b)
    _attn_fwd[launch_grid](
        q, k, k_block_id, v, q_scale, k_scale, o,
        *_bhn_strides(q),
        *_bhn_strides(k),
        *_bhn_strides(v),
        *_bhn_strides(o),
        k_block_id.stride(1), k_block_id.stride(2),
        qo_len, kv_len,
        h_qo, num_kv_groups,
        BLOCK_M=BLOCK_M, BLOCK_N=BLOCK_N, HEAD_DIM=head_dim,
        STAGE=stage,
        num_warps=4 if head_dim == 64 else 8,
        num_stages=4)
    return o
custom_nodes/ComfyUI-FlashVSR_Ultra_Fast/src/models/utils.py ADDED
@@ -0,0 +1,462 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch, os, gc
2
+ from safetensors import safe_open
3
+ from contextlib import contextmanager
4
+ from einops import rearrange, repeat
5
+ import torch.nn as nn
6
+ import torch.nn.functional as F
7
+ from tqdm import tqdm
8
+ import time
9
+ import hashlib
10
+
11
+ CACHE_T = 2
12
+
13
@contextmanager
def init_weights_on_device(device=torch.device("meta"), include_buffers: bool = False):
    """Context manager that creates module parameters directly on ``device``.

    While active, ``torch.nn.Module.register_parameter`` is replaced so every
    newly registered parameter is moved to ``device``. With
    ``include_buffers=True`` the same is done for ``register_buffer``, and
    ``torch.empty`` / ``zeros`` / ``ones`` / ``full`` are wrapped so they
    always receive ``device=device``. All patches are undone on exit, also
    when the body raises.
    """
    original_register_parameter = torch.nn.Module.register_parameter
    if include_buffers:
        original_register_buffer = torch.nn.Module.register_buffer

    def _register_parameter_on_device(module, name, param):
        original_register_parameter(module, name, param)
        if param is None:
            return
        registered = module._parameters[name]
        param_type = type(registered)
        # Re-create the parameter on the target device, carrying over its
        # attribute dict plus the original requires_grad flag.
        extra = registered.__dict__
        extra["requires_grad"] = param.requires_grad
        module._parameters[name] = param_type(registered.to(device), **extra)

    def _register_buffer_on_device(module, name, buffer, persistent=True):
        original_register_buffer(module, name, buffer, persistent=persistent)
        if buffer is None:
            return
        module._buffers[name] = module._buffers[name].to(device)

    def _force_device(constructor):
        def _constructor_on_device(*args, **kwargs):
            kwargs["device"] = device
            return constructor(*args, **kwargs)

        return _constructor_on_device

    constructor_names = ["empty", "zeros", "ones", "full"] if include_buffers else []
    saved_constructors = {fn_name: getattr(torch, fn_name) for fn_name in constructor_names}

    try:
        torch.nn.Module.register_parameter = _register_parameter_on_device
        if include_buffers:
            torch.nn.Module.register_buffer = _register_buffer_on_device
        for fn_name in saved_constructors:
            setattr(torch, fn_name, _force_device(getattr(torch, fn_name)))
        yield
    finally:
        torch.nn.Module.register_parameter = original_register_parameter
        if include_buffers:
            torch.nn.Module.register_buffer = original_register_buffer
        for fn_name, original_fn in saved_constructors.items():
            setattr(torch, fn_name, original_fn)
61
+
62
def load_state_dict_from_folder(file_path, torch_dtype=None):
    """Merge the state dicts of all weight files directly inside a folder.

    Files are recognised by the text after their last dot (``safetensors``,
    ``bin``, ``ckpt``, ``pth`` or ``pt``); sub-folders are not descended into.

    Args:
        file_path: Directory to scan.
        torch_dtype: Optional dtype forwarded to ``load_state_dict``.

    Returns:
        A single dict holding the entries of every matching file. When two
        files contain the same key, the file that sorts later by name wins.
    """
    weight_extensions = ("safetensors", "bin", "ckpt", "pth", "pt")
    state_dict = {}
    # os.listdir returns entries in arbitrary order; sort so the merge result
    # is deterministic (search_for_files in this module sorts the same way).
    for file_name in sorted(os.listdir(file_path)):
        if "." in file_name and file_name.split(".")[-1] in weight_extensions:
            state_dict.update(load_state_dict(os.path.join(file_path, file_name), torch_dtype=torch_dtype))
    return state_dict
70
+
71
+
72
def load_state_dict(file_path, torch_dtype=None):
    """Load a state dict, choosing the reader from the file name.

    Paths ending in ``.safetensors`` go to the safetensors reader; every other
    path goes to the ``torch.load``-based reader.
    """
    reader = (
        load_state_dict_from_safetensors
        if file_path.endswith(".safetensors")
        else load_state_dict_from_bin
    )
    return reader(file_path, torch_dtype=torch_dtype)
77
+
78
+
79
def load_state_dict_from_safetensors(file_path, torch_dtype=None):
    """Read every tensor of a safetensors file into a dict on the CPU.

    When ``torch_dtype`` is given, each tensor is converted to that dtype.
    """
    tensors = {}
    with safe_open(file_path, framework="pt", device="cpu") as handle:
        for key in handle.keys():
            tensor = handle.get_tensor(key)
            tensors[key] = tensor if torch_dtype is None else tensor.to(torch_dtype)
    return tensors
87
+
88
+
89
def load_state_dict_from_bin(file_path, torch_dtype=None):
    """Load a checkpoint with ``torch.load`` (CPU, ``weights_only=True``).

    When ``torch_dtype`` is given, every top-level tensor value is converted
    to that dtype in place; non-tensor values are left untouched.
    """
    loaded = torch.load(file_path, map_location="cpu", weights_only=True)
    if torch_dtype is None:
        return loaded
    for key in loaded:
        value = loaded[key]
        if isinstance(value, torch.Tensor):
            loaded[key] = value.to(torch_dtype)
    return loaded
96
+
97
+
98
def search_for_embeddings(state_dict):
    """Collect every tensor found in a (possibly nested) dict.

    Tensors are returned in iteration order, descending depth-first into
    values that are themselves dicts. Other value types are ignored.
    """
    found = []
    for key in state_dict:
        entry = state_dict[key]
        if isinstance(entry, torch.Tensor):
            found.append(entry)
        elif isinstance(entry, dict):
            found.extend(search_for_embeddings(entry))
    return found
106
+
107
+
108
def search_parameter(param, state_dict):
    """Return the name of the first entry in ``state_dict`` that equals ``param``.

    An entry matches when it has the same number of elements and its distance
    to ``param`` (``torch.dist``) is below 1e-3. Tensors of different shape
    are compared after flattening. Returns ``None`` when nothing matches.
    """
    for candidate_name, candidate in state_dict.items():
        if param.numel() != candidate.numel():
            continue
        if param.shape == candidate.shape:
            distance = torch.dist(param, candidate)
        else:
            distance = torch.dist(param.flatten(), candidate.flatten())
        if distance < 1e-3:
            return candidate_name
    return None
118
+
119
+
120
def build_rename_dict(source_state_dict, target_state_dict, split_qkv=False):
    """Print a source-to-target key mapping found by comparing tensor values.

    Each source tensor is looked up in ``target_state_dict`` with
    ``search_parameter``. With ``split_qkv=True``, an unmatched tensor whose
    first dimension is divisible by three is cut into three equal slices that
    are looked up individually; the mapping is printed only if all three are
    found. Finally every target key that was never matched is reported.

    Nothing is returned; the output is the printed text.
    """
    matched = set()
    with torch.no_grad():
        for source_name in source_state_dict:
            tensor = source_state_dict[source_name]
            hit = search_parameter(tensor, target_state_dict)
            if hit is not None:
                print(f'"{source_name}": "{hit}",')
                matched.add(hit)
                continue
            if not (split_qkv and len(tensor.shape)>=1 and tensor.shape[0]%3==0):
                continue
            chunk = tensor.shape[0] // 3
            hits = [
                search_parameter(tensor[i*chunk: i*chunk+chunk], target_state_dict)
                for i in range(3)
            ]
            if None not in hits:
                print(f'"{source_name}": {hits},')
                matched.update(hits)
        for target_name in target_state_dict:
            if target_name not in matched:
                print("Cannot find", target_name, target_state_dict[target_name].shape)
140
+
141
+
142
def search_for_files(folder, extensions):
    """Recursively list files under ``folder`` whose path ends with an extension.

    ``folder`` may also be a single file, in which case it is returned (in a
    one-element list) if it matches. Directory entries are visited in sorted
    order. Paths that are neither a directory nor a file yield an empty list.
    """
    if os.path.isdir(folder):
        collected = []
        for entry in sorted(os.listdir(folder)):
            collected += search_for_files(os.path.join(folder, entry), extensions)
        return collected
    if os.path.isfile(folder) and any(folder.endswith(suffix) for suffix in extensions):
        return [folder]
    return []
153
+
154
+
155
def convert_state_dict_keys_to_single_str(state_dict, with_shape=True):
    """Serialise the string keys of a state dict into one canonical string.

    For each tensor value the bare key is emitted, preceded by a
    ``"key:d0_d1_..."`` entry when ``with_shape`` is true (both entries are
    kept). For each nested dict ``"key|<nested string>"`` is emitted.
    Non-string keys and other value types are skipped. The entries are sorted
    and joined with commas, so the result does not depend on dict order.
    """
    entries = []
    for key, value in state_dict.items():
        if not isinstance(key, str):
            continue
        if isinstance(value, torch.Tensor):
            if with_shape:
                dims = "_".join(str(d) for d in value.shape)
                entries.append(key + ":" + dims)
            entries.append(key)
        elif isinstance(value, dict):
            nested = convert_state_dict_keys_to_single_str(value, with_shape=with_shape)
            entries.append(key + "|" + nested)
    return ",".join(sorted(entries))
169
+
170
+
171
def split_state_dict_with_prefix(state_dict):
    """Split a state dict into sub-dicts grouped by the first dotted key component.

    Only string keys are considered. Keys are processed in sorted order, so the
    returned list is ordered by prefix and each sub-dict by key. A key without
    a dot forms its own group.
    """
    groups = {}
    for name in sorted(k for k in state_dict if isinstance(k, str)):
        prefix = name.partition(".")[0]
        groups.setdefault(prefix, []).append(name)
    return [{name: state_dict[name] for name in names} for names in groups.values()]
184
+
185
def hash_state_dict_keys(state_dict, with_shape=True):
    """Return the MD5 hex digest of the canonical key string of ``state_dict``.

    The key string comes from ``convert_state_dict_keys_to_single_str`` and is
    encoded as UTF-8 before hashing.
    """
    signature = convert_state_dict_keys_to_single_str(state_dict, with_shape=with_shape)
    digest = hashlib.md5(signature.encode(encoding="UTF-8"))
    return digest.hexdigest()
189
+
190
def clean_vram():
    """Run the garbage collector and release cached accelerator memory.

    CUDA caches are emptied (and IPC handles collected) when CUDA is
    available; the MPS cache is emptied when the MPS backend is available.
    The MPS modules are looked up defensively so that PyTorch builds without
    ``torch.backends.mps`` / ``torch.mps`` do not raise ``AttributeError``.
    """
    gc.collect()
    if torch.cuda.is_available():
        torch.cuda.empty_cache()
        torch.cuda.ipc_collect()

    mps_backend = getattr(torch.backends, "mps", None)
    if mps_backend is not None and mps_backend.is_available():
        empty_cache = getattr(getattr(torch, "mps", None), "empty_cache", None)
        if empty_cache is not None:
            empty_cache()
197
+
198
def get_device_list():
    """Return the selectable device names, always starting with ``"auto"``.

    CUDA devices are listed as ``"cuda:<i>"`` and MPS devices as ``"mps:<i>"``.
    Probing is best-effort: any error raised while querying a backend simply
    leaves that backend out of the list.
    """
    devs = ["auto"]
    try:
        if hasattr(torch, "cuda") and hasattr(torch.cuda, "is_available") and torch.cuda.is_available():
            devs += [f"cuda:{i}" for i in range(torch.cuda.device_count())]
    except Exception:
        # Deliberate best-effort: a broken CUDA install must not break the UI list.
        pass
    try:
        # Availability is queried on torch.backends.mps, so that is the module
        # whose presence must be checked (the previous guard inspected
        # torch.mps instead, which hid MPS whenever torch.mps had no
        # ``is_available`` attribute).
        mps_backend = getattr(torch.backends, "mps", None)
        if mps_backend is not None and mps_backend.is_available():
            count_fn = getattr(getattr(torch, "mps", None), "device_count", None)
            # Fall back to a single device when device_count() is unavailable.
            count = count_fn() if count_fn is not None else 1
            devs += [f"mps:{i}" for i in range(count)]
    except Exception:
        # Deliberate best-effort, same as for CUDA above.
        pass
    return devs
211
+
212
class RMS_norm(nn.Module):
    """L2-normalise along the channel axis, then rescale by ``sqrt(dim) * gamma``.

    With ``channel_first`` the channel axis is dim 1 and ``gamma`` / ``bias``
    carry trailing singleton dims so they broadcast (two when ``images`` is
    true, three otherwise); without it the last axis is normalised and the
    parameters have shape ``(dim,)``. ``bias`` is the constant ``0.`` unless
    ``bias=True``.
    """

    def __init__(self, dim, channel_first=True, images=True, bias=False):
        super().__init__()
        if channel_first:
            trailing = (1, 1) if images else (1, 1, 1)
            param_shape = (dim, *trailing)
        else:
            param_shape = (dim,)

        self.channel_first = channel_first
        self.scale = dim**0.5
        self.gamma = nn.Parameter(torch.ones(param_shape))
        if bias:
            self.bias = nn.Parameter(torch.zeros(param_shape))
        else:
            self.bias = 0.

    def forward(self, x):
        axis = 1 if self.channel_first else -1
        unit = F.normalize(x, dim=axis)
        return unit * self.scale * self.gamma + self.bias
228
+
229
class CausalConv3d(nn.Conv3d):
    """
    Causal 3d convolution.

    The constructor's symmetric ``padding`` is converted into an explicit
    ``F.pad`` spec that pads the temporal axis only at the front (twice the
    requested amount) and the spatial axes on both sides; the convolution
    itself then runs with zero padding. Padding uses ``mode='replicate'``.
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        pad_t, pad_h, pad_w = self.padding
        # F.pad order: (w_left, w_right, h_top, h_bottom, t_front, t_back)
        self._padding = (pad_w, pad_w, pad_h, pad_h, 2 * pad_t, 0)
        self.padding = (0, 0, 0)

    def forward(self, x, cache_x=None):
        """Convolve ``x``; ``cache_x`` (if given) is prepended along dim 2.

        Frames supplied through ``cache_x`` replace the same number of
        front-padding frames.
        """
        pad = list(self._padding)
        if cache_x is not None and self._padding[4] > 0:
            cache_x = cache_x.to(x.device)
            x = torch.cat([cache_x, x], dim=2)
            pad[4] -= cache_x.shape[2]
        padded = F.pad(x, pad, mode='replicate')
        return super().forward(padded)
252
+
253
class PixelShuffle3d(nn.Module):
    """Fold ``ff x hh x ww`` blocks of the (F, H, W) axes into the channel axis."""

    def __init__(self, ff, hh, ww):
        super().__init__()
        self.ff, self.hh, self.ww = ff, hh, ww

    def forward(self, x):
        # x: (B, C, F, H, W) -> (B, C*ff*hh*ww, F/ff, H/hh, W/ww)
        factors = dict(ff=self.ff, hh=self.hh, ww=self.ww)
        return rearrange(x,
                         'b c (f ff) (h hh) (w ww) -> b (c ff hh ww) f h w',
                         **factors)
265
+
266
class Buffer_LQ4x_Proj(nn.Module):
    """Project a video into per-layer token features through two causal convs.

    The input is pixel-unshuffled by 16x16 spatially, passed through two
    temporally strided ``CausalConv3d`` stages (each followed by ``RMS_norm``
    and SiLU) and finally through ``layer_num`` independent linear heads.

    In this variant each conv is fed the tail (last ``CACHE_T`` frames) of the
    clip *currently* being processed as its cache, i.e. the cache entry is
    written before the conv reads it.
    """

    def __init__(self, in_dim, out_dim, layer_num=30):
        super().__init__()
        self.ff = 1
        self.hh = 16
        self.ww = 16
        self.hidden_dim1 = 2048
        self.hidden_dim2 = 3072
        self.layer_num = layer_num

        self.pixel_shuffle = PixelShuffle3d(self.ff, self.hh, self.ww)

        self.conv1 = CausalConv3d(in_dim*self.ff*self.hh*self.ww, self.hidden_dim1, (4, 3, 3), stride=(2, 1, 1), padding=(1, 1, 1))  # f -> f/2 h -> h w -> w
        self.norm1 = RMS_norm(self.hidden_dim1, images=False)
        self.act1 = nn.SiLU()

        self.conv2 = CausalConv3d(self.hidden_dim1, self.hidden_dim2, (4, 3, 3), stride=(2, 1, 1), padding=(1, 1, 1))  # f -> f/2 h -> h w -> w
        self.norm2 = RMS_norm(self.hidden_dim2, images=False)
        self.act2 = nn.SiLU()

        self.linear_layers = nn.ModuleList([nn.Linear(self.hidden_dim2, out_dim) for _ in range(layer_num)])

        # Create ``self.cache`` and reset ``clip_idx`` up front so that
        # ``stream_forward`` works even if the caller never invoked
        # ``clear_cache()`` (previously that raised AttributeError).
        self.clear_cache()

    def _stage1(self, clip):
        """Pixel-unshuffle ``clip`` and run conv1/norm1/act1, refreshing both caches."""
        x = self.pixel_shuffle(clip)
        self.cache['conv1'] = x[:, :, -CACHE_T:, :, :].clone()
        x = self.conv1(x, self.cache['conv1'])
        x = self.norm1(x)
        x = self.act1(x)
        self.cache['conv2'] = x[:, :, -CACHE_T:, :, :].clone()
        return x

    def _stage2(self, x):
        """Run conv2/norm2/act2 using the conv2 cache written by ``_stage1``."""
        x = self.conv2(x, self.cache['conv2'])
        x = self.norm2(x)
        return self.act2(x)

    def _project(self, x):
        """Flatten (f, h, w) into tokens and apply every linear head."""
        tokens = rearrange(x, 'b c f h w -> b (f h w) c')
        return [self.linear_layers[i](tokens) for i in range(self.layer_num)]

    def forward(self, video):
        """Process a whole video at once.

        Args:
            video: tensor indexed as (B, C, F, H, W).

        Returns:
            A list of ``layer_num`` tensors, one per linear head.
        """
        self.clear_cache()
        # x: (B, C, F, H, W)

        t = video.shape[2]
        iter_ = 1 + (t - 1) // 4
        # Prepend three copies of the first frame so it fills a 4-frame clip.
        first_frame = video[:, :, :1, :, :].repeat(1, 1, 3, 1, 1)
        video = torch.cat([first_frame, video], dim=2)

        out_x = []
        for i in range(iter_):
            x = self._stage1(video[:, :, i*4:(i+1)*4, :, :])
            if i == 0:
                # The first clip only primes the caches; it emits no output.
                continue
            out_x.append(self._stage2(x))
        return self._project(torch.cat(out_x, dim=2))

    def clear_cache(self):
        """Reset the per-conv caches and the streaming clip counter."""
        self.cache = {}
        self.cache['conv1'] = None
        self.cache['conv2'] = None
        self.clip_idx = 0

    def stream_forward(self, video_clip):
        """Process one clip of a stream.

        The first call after ``clear_cache()`` only primes the caches and
        returns ``None``; every later call returns the list of per-head
        outputs for that clip.
        """
        if self.clip_idx == 0:
            first_frame = video_clip[:, :, :1, :, :].repeat(1, 1, 3, 1, 1)
            video_clip = torch.cat([first_frame, video_clip], dim=2)
            self._stage1(video_clip)
            self.clip_idx += 1
            return None

        outputs = self._project(self._stage2(self._stage1(video_clip)))
        self.clip_idx += 1
        return outputs
364
+
365
class Causal_LQ4x_Proj(nn.Module):
    """Project a video into per-layer token features through two causal convs.

    The input is pixel-unshuffled by 16x16 spatially, passed through two
    temporally strided ``CausalConv3d`` stages (each followed by ``RMS_norm``
    and SiLU) and finally through ``layer_num`` independent linear heads.

    In this variant each conv reads the cache left by the *previous* clip
    (last ``CACHE_T`` frames of that clip's conv input) and only afterwards
    stores the tail of the current clip.
    """

    def __init__(self, in_dim, out_dim, layer_num=30):
        super().__init__()
        self.ff = 1
        self.hh = 16
        self.ww = 16
        self.hidden_dim1 = 2048
        self.hidden_dim2 = 3072
        self.layer_num = layer_num

        self.pixel_shuffle = PixelShuffle3d(self.ff, self.hh, self.ww)

        self.conv1 = CausalConv3d(in_dim*self.ff*self.hh*self.ww, self.hidden_dim1, (4, 3, 3), stride=(2, 1, 1), padding=(1, 1, 1))  # f -> f/2 h -> h w -> w
        self.norm1 = RMS_norm(self.hidden_dim1, images=False)
        self.act1 = nn.SiLU()

        self.conv2 = CausalConv3d(self.hidden_dim1, self.hidden_dim2, (4, 3, 3), stride=(2, 1, 1), padding=(1, 1, 1))  # f -> f/2 h -> h w -> w
        self.norm2 = RMS_norm(self.hidden_dim2, images=False)
        self.act2 = nn.SiLU()

        self.linear_layers = nn.ModuleList([nn.Linear(self.hidden_dim2, out_dim) for _ in range(layer_num)])

        # Create ``self.cache`` and reset ``clip_idx`` up front so that
        # ``stream_forward`` works even if the caller never invoked
        # ``clear_cache()`` (previously that raised AttributeError).
        self.clear_cache()

    def _stage1(self, clip):
        """Pixel-unshuffle ``clip`` and run conv1/norm1/act1 with the previous conv1 cache."""
        x = self.pixel_shuffle(clip)
        tail = x[:, :, -CACHE_T:, :, :].clone()
        x = self.conv1(x, self.cache['conv1'])
        # Store the new tail only after the conv has consumed the old one.
        self.cache['conv1'] = tail
        x = self.norm1(x)
        return self.act1(x)

    def _stage2(self, x):
        """Run conv2/norm2/act2 with the previous conv2 cache, then update it."""
        tail = x[:, :, -CACHE_T:, :, :].clone()
        x = self.conv2(x, self.cache['conv2'])
        self.cache['conv2'] = tail
        x = self.norm2(x)
        return self.act2(x)

    def _project(self, x):
        """Flatten (f, h, w) into tokens and apply every linear head."""
        tokens = rearrange(x, 'b c f h w -> b (f h w) c')
        return [self.linear_layers[i](tokens) for i in range(self.layer_num)]

    def forward(self, video):
        """Process a whole video at once.

        Args:
            video: tensor indexed as (B, C, F, H, W).

        Returns:
            A list of ``layer_num`` tensors, one per linear head.
        """
        self.clear_cache()
        # x: (B, C, F, H, W)

        t = video.shape[2]
        iter_ = 1 + (t - 1) // 4
        # Prepend three copies of the first frame so it fills a 4-frame clip.
        first_frame = video[:, :, :1, :, :].repeat(1, 1, 3, 1, 1)
        video = torch.cat([first_frame, video], dim=2)

        out_x = []
        for i in range(iter_):
            x = self._stage1(video[:, :, i*4:(i+1)*4, :, :])
            if i == 0:
                # The first clip only primes the conv2 cache; no output.
                self.cache['conv2'] = x[:, :, -CACHE_T:, :, :].clone()
                continue
            out_x.append(self._stage2(x))
        return self._project(torch.cat(out_x, dim=2))

    def clear_cache(self):
        """Reset the per-conv caches and the streaming clip counter."""
        self.cache = {}
        self.cache['conv1'] = None
        self.cache['conv2'] = None
        self.clip_idx = 0

    def stream_forward(self, video_clip):
        """Process one clip of a stream.

        The first call after ``clear_cache()`` only primes the caches and
        returns ``None``; every later call returns the list of per-head
        outputs for that clip.
        """
        if self.clip_idx == 0:
            first_frame = video_clip[:, :, :1, :, :].repeat(1, 1, 3, 1, 1)
            video_clip = torch.cat([first_frame, video_clip], dim=2)
            x = self._stage1(video_clip)
            self.cache['conv2'] = x[:, :, -CACHE_T:, :, :].clone()
            self.clip_idx += 1
            return None

        outputs = self._project(self._stage2(self._stage1(video_clip)))
        self.clip_idx += 1
        return outputs
custom_nodes/ComfyUI-FlashVSR_Ultra_Fast/src/models/wan_video_dit.py ADDED
@@ -0,0 +1,864 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn as nn
3
+ import torch.nn.functional as F
4
+ import math
5
+ import random
6
+ import os
7
+ import time
8
+ from typing import Tuple, Optional, List
9
+ from einops import rearrange
10
+ from .utils import hash_state_dict_keys
11
+
12
+ try:
13
+ import flash_attn_interface
14
+ FLASH_ATTN_3_AVAILABLE = True
15
+ except ModuleNotFoundError:
16
+ FLASH_ATTN_3_AVAILABLE = False
17
+
18
+ try:
19
+ import flash_attn
20
+ FLASH_ATTN_2_AVAILABLE = True
21
+ except ModuleNotFoundError:
22
+ FLASH_ATTN_2_AVAILABLE = False
23
+
24
+ try:
25
+ from sageattention import sageattn
26
+ SAGE_ATTN_AVAILABLE = True
27
+ except ModuleNotFoundError:
28
+ SAGE_ATTN_AVAILABLE = False
29
+
30
+ try:
31
+ from block_sparse_attn import block_sparse_attn_func
32
+ BLOCK_ATTN_AVAILABLE = True
33
+ except:
34
+ BLOCK_ATTN_AVAILABLE = False
35
+
36
+ from .sparse_sage.core import sparse_sageattn
37
+ from PIL import Image
38
+ import numpy as np
39
+
40
+ USE_BLOCK_ATTN = False
41
+
42
+ # ----------------------------
43
+ # Local / window masks
44
+ # ----------------------------
45
@torch.no_grad()
def build_local_block_mask_shifted_vec(block_h: int,
                                       block_w: int,
                                       win_h: int = 6,
                                       win_w: int = 6,
                                       include_self: bool = True,
                                       device=None) -> torch.Tensor:
    """Build a boolean (H*W, H*W) locality mask over a ``block_h x block_w`` grid.

    Entry ``[i, j]`` is True when block ``j`` lies inside the ``win_h x win_w``
    window assigned to block ``i``. The window start is clamped to
    ``[0, size - win]`` so windows near the border are shifted inwards instead
    of being cut off. With ``include_self=False`` the diagonal is cleared.
    """
    device = device or torch.device("cpu")
    grid_r, grid_c = torch.meshgrid(
        torch.arange(block_h, device=device),
        torch.arange(block_w, device=device),
        indexing="ij",
    )
    rows = grid_r.reshape(-1)
    cols = grid_c.reshape(-1)

    row_lo = torch.clamp(rows - win_h // 2, 0, block_h - win_h)
    row_hi = row_lo + win_h - 1
    col_lo = torch.clamp(cols - win_w // 2, 0, block_w - win_w)
    col_hi = col_lo + win_w - 1

    # Broadcast: query block along dim 0, candidate block along dim 1.
    row_ok = (rows[None, :] >= row_lo[:, None]) & (rows[None, :] <= row_hi[:, None])
    col_ok = (cols[None, :] >= col_lo[:, None]) & (cols[None, :] <= col_hi[:, None])
    mask = row_ok & col_ok
    if not include_self:
        mask.fill_diagonal_(False)
    return mask
71
+
72
@torch.no_grad()
def build_local_block_mask_shifted_vec_normal_slide(block_h: int,
                                                    block_w: int,
                                                    win_h: int = 6,
                                                    win_w: int = 6,
                                                    include_self: bool = True,
                                                    device=None) -> torch.Tensor:
    """Build a boolean (H*W, H*W) locality mask with an unclamped sliding window.

    Entry ``[i, j]`` is True when block ``j`` lies inside the ``win_h x win_w``
    window that starts ``win // 2`` blocks before block ``i`` on each axis.
    The window is not shifted at the border, so edge blocks simply see fewer
    neighbours. With ``include_self=False`` the diagonal is cleared.
    """
    device = device or torch.device("cpu")
    grid_r, grid_c = torch.meshgrid(
        torch.arange(block_h, device=device),
        torch.arange(block_w, device=device),
        indexing="ij",
    )
    rows = grid_r.reshape(-1)
    cols = grid_c.reshape(-1)

    row_lo = rows - win_h // 2
    row_hi = row_lo + win_h - 1
    col_lo = cols - win_w // 2
    col_hi = col_lo + win_w - 1

    # Broadcast: query block along dim 0, candidate block along dim 1.
    row_ok = (rows[None, :] >= row_lo[:, None]) & (rows[None, :] <= row_hi[:, None])
    col_ok = (cols[None, :] >= col_lo[:, None]) & (cols[None, :] <= col_hi[:, None])
    mask = row_ok & col_ok
    if not include_self:
        mask.fill_diagonal_(False)
    return mask
98
+
99
+
100
class WindowPartition3D:
    """Partition / reverse-partition helpers for 5-D tensors (B,F,H,W,C)."""

    @staticmethod
    def partition(x: torch.Tensor, win: Tuple[int, int, int]):
        """Cut ``x`` into non-overlapping windows.

        Returns a tensor of shape ``(B * n_windows, wf*wh*ww, C)``; windows are
        ordered by (batch, frame block, row block, column block).
        """
        batch, frames, height, width, channels = x.shape
        wf, wh, ww = win
        assert frames % wf == 0 and height % wh == 0 and width % ww == 0, "Dims must divide by window size."
        blocks = x.view(batch, frames // wf, wf, height // wh, wh, width // ww, ww, channels)
        # Bring the three block indices in front of the three in-window indices.
        blocks = blocks.permute(0, 1, 3, 5, 2, 4, 6, 7).contiguous()
        return blocks.view(-1, wf * wh * ww, channels)

    @staticmethod
    def reverse(windows: torch.Tensor, win: Tuple[int, int, int], orig: Tuple[int, int, int]):
        """Inverse of ``partition``: rebuild a ``(B, F, H, W, C)`` tensor from windows."""
        frames, height, width = orig
        wf, wh, ww = win
        nf, nh, nw = frames // wf, height // wh, width // ww
        batch = windows.size(0) // (nf * nh * nw)
        blocks = windows.view(batch, nf, nh, nw, wf, wh, ww, -1)
        # Interleave each block index with its in-window index again.
        blocks = blocks.permute(0, 1, 4, 2, 5, 3, 6, 7).contiguous()
        return blocks.view(batch, frames, height, width, -1)
120
+
121
+
122
@torch.no_grad()
def generate_draft_block_mask(batch_size, nheads, seqlen,
                              q_w, k_w, topk=10, local_attn_mask=None):
    """Select, per head, the highest-scoring (query window, key window) pairs.

    Each window of ``q_w`` / ``k_w`` is mean-pooled over its tokens (dim 1),
    the pooled vectors are scored per head with scaled dot products, pairs
    outside the tiled ``local_attn_mask`` are set to ``-inf`` and a softmax is
    taken over keys. The query axis is then split into ``seqlen`` groups and,
    inside each (head, group), entries strictly greater than the
    ``(topk + 1)``-th largest value are marked True.

    Returns a boolean tensor of shape (batch_size, heads, n_q_windows, n_k_windows).
    """
    assert batch_size == 1, "Only batch_size=1 supported for now"
    assert local_attn_mask is not None, "local_attn_mask must be provided"

    pooled_q = rearrange(torch.mean(q_w, dim=1), 's (h d) -> s h d', h=nheads)
    pooled_k = rearrange(torch.mean(k_w, dim=1), 's (h d) -> s h d', h=nheads)
    q_heads = pooled_q.permute(1, 0, 2)
    k_heads = pooled_k.permute(1, 0, 2)
    head_dim = pooled_q.shape[-1]
    scores = torch.einsum("hld,hmd->hlm", q_heads, k_heads) / math.sqrt(head_dim)

    # Tile the single-frame local mask so it covers every query/key window.
    n_heads = scores.shape[0]
    tiles_q = scores.shape[1] // local_attn_mask.shape[0]
    tiles_k = scores.shape[2] // local_attn_mask.shape[1]
    tiled = local_attn_mask.unsqueeze(1).unsqueeze(0).repeat(tiles_q, 1, tiles_k, 1)
    tiled = rearrange(tiled, 'x a y b -> (x a) (y b)')
    tiled = tiled.unsqueeze(0).repeat(n_heads, 1, 1)

    # Turn the mask into an additive bias: 0 where allowed, -inf elsewhere.
    bias = tiled.to(torch.float32)
    bias = bias.masked_fill(bias == False, -float('inf'))
    bias = bias.masked_fill(bias == True, 0)
    scores = scores + bias

    attn_map = torch.softmax(scores, dim=-1)
    attn_map = rearrange(attn_map, 'h (it s1) s2 -> (h it) s1 s2', it=seqlen)
    loop_num, s1, s2 = attn_map.shape
    flat = attn_map.reshape(loop_num, -1)
    apply_topk = min(flat.shape[1]-1, topk)
    kth = torch.topk(flat, k=apply_topk + 1, dim=1, largest=True).values[:, -1]
    selected = (flat > kth.unsqueeze(1)).reshape(loop_num, s1, s2)
    selected = rearrange(selected, '(h it) s1 s2 -> h (it s1) s2', it=seqlen)
    return selected.unsqueeze(0).repeat(batch_size, 1, 1, 1)
161
+
162
+
163
@torch.no_grad()
def generate_draft_block_mask_sage(batch_size, nheads, seqlen,
                                   q_w, k_w, topk=10, local_attn_mask=None):
    """Variant of the draft block mask with two key blocks per window.

    Queries are mean-pooled per window. Each key window is viewed as two
    halves of 64 tokens (so ``k_w`` must hold exactly 128 tokens per window)
    and each half is pooled separately, which doubles the key axis. The tiled
    ``local_attn_mask`` is repeated along the key axis accordingly. Selection
    then follows the same softmax + top-k thresholding per (head, group).

    Returns a boolean tensor of shape
    (batch_size, heads, n_q_windows, 2 * n_k_windows).
    """
    assert batch_size == 1, "Only batch_size=1 supported for now"
    assert local_attn_mask is not None, "local_attn_mask must be provided"

    pooled_q = rearrange(torch.mean(q_w, dim=1), 's (h d) -> s h d', h=nheads)
    q_heads = pooled_q.permute(1, 0, 2)
    head_dim = pooled_q.shape[-1]

    halves = k_w.view(k_w.shape[0], 2, 64, k_w.shape[2])
    pooled_k = torch.mean(halves, dim=2)
    pooled_k = rearrange(pooled_k, 's two d -> (s two) d', two=2)  # shape: (s*2, C)
    pooled_k = rearrange(pooled_k, 's (h d) -> s h d', h=nheads)  # shape: (s*2, h, d)
    k_heads = pooled_k.permute(1, 0, 2)  # shape: (h, s*2, d)

    k_first, k_second = torch.chunk(k_heads, 2, dim=1)
    norm = math.sqrt(head_dim)
    scores = torch.cat([
        torch.einsum("hld,hmd->hlm", q_heads, k_first) / norm,
        torch.einsum("hld,hmd->hlm", q_heads, k_second) / norm,
    ], dim=-1)

    n_heads = scores.shape[0]
    tiles_q = scores.shape[1] // local_attn_mask.shape[0]
    tiles_k = (scores.shape[2] // 2) // local_attn_mask.shape[1]

    local_attn_mask = local_attn_mask.unsqueeze(1).unsqueeze(0).repeat(tiles_q, 1, tiles_k, 1)
    local_attn_mask = rearrange(local_attn_mask, 'x a y b -> (x a) (y b)')
    # Each key window contributes two pooled entries, so duplicate its column.
    local_attn_mask = local_attn_mask.repeat_interleave(2, dim=1)
    local_attn_mask = local_attn_mask.unsqueeze(0).repeat(n_heads, 1, 1)

    assert scores.shape == local_attn_mask.shape, \
        f"Scores shape {scores.shape} != Mask shape {local_attn_mask.shape}"

    # Turn the mask into an additive bias: 0 where allowed, -inf elsewhere.
    bias = local_attn_mask.to(torch.float32)
    bias = bias.masked_fill(bias == False, -float('inf'))
    bias = bias.masked_fill(bias == True, 0)
    scores = scores + bias

    attn_map = torch.softmax(scores, dim=-1)
    attn_map = rearrange(attn_map, 'h (it s1) s2 -> (h it) s1 s2', it=seqlen)
    loop_num, s1, s2 = attn_map.shape
    flat = attn_map.reshape(loop_num, -1)
    apply_topk = min(flat.shape[1]-1, topk)

    if apply_topk > 0:
        kth = torch.topk(flat, k=apply_topk + 1, dim=1, largest=True).values[:, -1]
        selected = (flat > kth.unsqueeze(1)).reshape(loop_num, s1, s2)
    else:
        # Nothing to keep: emit an all-False mask of the right shape.
        selected = torch.zeros_like(flat, dtype=torch.bool).reshape(loop_num, s1, s2)

    selected = rearrange(selected, '(h it) s1 s2 -> h (it s1) s2', it=seqlen)
    return selected.unsqueeze(0).repeat(batch_size, 1, 1, 1)
218
+
219
+
220
+ # ----------------------------
221
+ # Attention kernels
222
+ # ----------------------------
223
def flash_attention(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, num_heads: int, compatibility_mode=False, attention_mask=None, return_KV=False):
    """Multi-head attention dispatcher over the available kernels.

    ``q``, ``k`` and ``v`` are ``(b, s, n*d)`` tensors; the result has the same
    layout as ``q``. Kernel choice, in priority order:

    1. ``attention_mask`` given: block-sparse attention (``block_sparse_attn``
       when ``USE_BLOCK_ATTN`` is set and the package imported, otherwise
       ``sparse_sageattn``).
    2. ``compatibility_mode``: PyTorch ``scaled_dot_product_attention``.
    3. FlashAttention 3, then FlashAttention 2, then SageAttention, if present.
    4. PyTorch ``scaled_dot_product_attention`` as the final fallback.

    ``return_KV`` is accepted but not used by this function.
    """

    def heads_first(t):
        return rearrange(t, "b s (n d) -> b n s d", n=num_heads)

    def seq_first(t):
        return rearrange(t, "b s (n d) -> b s n d", n=num_heads)

    def merge_heads_first(t):
        return rearrange(t, "b n s d -> b s (n d)", n=num_heads)

    def merge_seq_first(t):
        return rearrange(t, "b s n d -> b s (n d)", n=num_heads)

    if attention_mask is not None:
        seqlen = q.shape[1]
        seqlen_kv = k.shape[1]
        if USE_BLOCK_ATTN and BLOCK_ATTN_AVAILABLE:
            # Var-len layout: batch and sequence are packed into one axis.
            q = rearrange(q, "b s (n d) -> (b s) n d", n=num_heads)
            k = rearrange(k, "b s (n d) -> (b s) n d", n=num_heads)
            v = rearrange(v, "b s (n d) -> (b s) n d", n=num_heads)
            cu_seqlens_q = torch.tensor([0, seqlen], device=q.device, dtype=torch.int32)
            cu_seqlens_k = torch.tensor([0, seqlen_kv], device=q.device, dtype=torch.int32)
            head_mask_type = torch.tensor([1]*num_heads, device=q.device, dtype=torch.int32)
            streaming_info = None
            p_dropout = 0.0
            out = block_sparse_attn_func(
                q, k, v,
                cu_seqlens_q, cu_seqlens_k,
                head_mask_type,
                streaming_info,
                attention_mask,
                seqlen, seqlen_kv,
                p_dropout,
                deterministic=False,
                softmax_scale=None,
                is_causal=False,
                exact_streaming=False,
                return_attn_probs=False,
            ).unsqueeze(0)
            return merge_seq_first(out)

        out = sparse_sageattn(
            heads_first(q), heads_first(k), heads_first(v),
            mask_id=attention_mask.to(torch.int8),
            is_causal=False,
            tensor_layout="HND"
        )
        return merge_heads_first(out)

    if compatibility_mode:
        out = F.scaled_dot_product_attention(heads_first(q), heads_first(k), heads_first(v))
        return merge_heads_first(out)

    if FLASH_ATTN_3_AVAILABLE:
        out = flash_attn_interface.flash_attn_func(seq_first(q), seq_first(k), seq_first(v))
        if isinstance(out, tuple):
            # Some versions return (output, extra); keep the output only.
            out = out[0]
        return merge_seq_first(out)

    if FLASH_ATTN_2_AVAILABLE:
        out = flash_attn.flash_attn_func(seq_first(q), seq_first(k), seq_first(v))
        return merge_seq_first(out)

    if SAGE_ATTN_AVAILABLE:
        out = sageattn(heads_first(q), heads_first(k), heads_first(v))
        return merge_heads_first(out)

    out = F.scaled_dot_product_attention(heads_first(q), heads_first(k), heads_first(v))
    return merge_heads_first(out)
300
+
301
+
302
def modulate(x: torch.Tensor, shift: torch.Tensor, scale: torch.Tensor):
    """Apply the affine modulation ``x * (1 + scale) + shift``."""
    scaled = x * (1 + scale)
    return scaled + shift
304
+
305
+
306
def sinusoidal_embedding_1d(dim, position):
    """Return the ``[cos | sin]`` sinusoidal embedding of a 1-D ``position`` tensor.

    The result has shape ``(len(position), 2 * (dim // 2))``; frequencies are
    ``10000 ** (-i / (dim // 2))``. Computation runs in float64 and the result
    is cast back to ``position.dtype``.
    """
    half = dim // 2
    exponents = -torch.arange(half, dtype=torch.float64, device=position.device).div(half)
    inv_freq = torch.pow(10000, exponents)
    angles = torch.outer(position.type(torch.float64), inv_freq)
    embedding = torch.cat([torch.cos(angles), torch.sin(angles)], dim=1)
    return embedding.to(position.dtype)
311
+
312
+
313
def precompute_freqs_cis_3d(dim: int, end: int = 1024, theta: float = 10000.0):
    """Build rotary frequency tables for the frame, height and width axes.

    Height and width each get ``dim // 3`` channels; the frame axis gets the
    remainder ``dim - 2 * (dim // 3)``. Returns the three tables as a tuple
    ``(f, h, w)``.
    """
    spatial_dim = dim // 3
    temporal_dim = dim - 2 * spatial_dim
    f_table = precompute_freqs_cis(temporal_dim, end, theta)
    h_table = precompute_freqs_cis(spatial_dim, end, theta)
    w_table = precompute_freqs_cis(spatial_dim, end, theta)
    return f_table, h_table, w_table
318
+
319
+
320
def precompute_freqs_cis(dim: int, end: int = 1024, theta: float = 10000.0):
    """Return unit-magnitude complex rotary factors of shape ``(end, dim // 2)``.

    Entry ``[p, i]`` is ``exp(1j * p * theta ** (-2 * i / dim))``, computed in
    float64 (so the result is complex128).
    """
    exponents = torch.arange(0, dim, 2)[: (dim // 2)].double() / dim
    inv_freq = 1.0 / (theta ** exponents)
    positions = torch.arange(end, device=inv_freq.device)
    angles = torch.outer(positions, inv_freq)
    return torch.polar(torch.ones_like(angles), angles)
326
+
327
+
328
def rope_apply(x, freqs, num_heads):
    """Rotate ``x`` by the complex factors ``freqs`` (rotary position embedding).

    ``x`` is ``(b, s, n*d)``. Consecutive channel pairs of every head are
    treated as complex numbers, multiplied by ``freqs`` in float64 and
    flattened back to ``(b, s, n*d)`` in the input dtype.
    """
    x = rearrange(x, "b s (n d) -> b s n d", n=num_heads)
    batch, seq, heads = x.shape[0], x.shape[1], x.shape[2]
    pairs = x.to(torch.float64).reshape(batch, seq, heads, -1, 2)
    rotated = torch.view_as_complex(pairs) * freqs
    return torch.view_as_real(rotated).flatten(2).to(x.dtype)
334
+
335
+
336
+ # ----------------------------
337
+ # Norms & Blocks
338
+ # ----------------------------
339
class RMSNorm(nn.Module):
    """Root-mean-square normalisation over the last axis with a learned scale."""

    def __init__(self, dim, eps=1e-5):
        super().__init__()
        self.eps = eps
        self.weight = nn.Parameter(torch.ones(dim))

    def norm(self, x):
        """Divide ``x`` by the RMS of its last axis (``eps`` added for stability)."""
        mean_square = x.pow(2).mean(dim=-1, keepdim=True)
        return x * torch.rsqrt(mean_square + self.eps)

    def forward(self, x):
        # Normalise in float32, cast back, then apply the learned weight.
        input_dtype = x.dtype
        normed = self.norm(x.float())
        return normed.to(input_dtype) * self.weight
351
+
352
+
353
class AttentionModule(nn.Module):
    """Thin ``nn.Module`` wrapper that forwards to ``flash_attention``."""

    def __init__(self, num_heads):
        super().__init__()
        self.num_heads = num_heads

    def forward(self, q, k, v, attention_mask=None):
        return flash_attention(
            q=q,
            k=k,
            v=v,
            num_heads=self.num_heads,
            attention_mask=attention_mask,
        )
361
+
362
+
363
class SelfAttention(nn.Module):
    """Windowed, block-sparse self-attention with an optional streaming KV cache.

    Tokens are arranged on an ``(f, h, w)`` grid, cut into ``(2, 8, 8)``
    windows and attended with a block mask that is drafted from pooled
    window scores restricted to a local spatial neighbourhood.
    """

    def __init__(self, dim: int, num_heads: int, eps: float = 1e-6):
        super().__init__()
        self.dim = dim
        self.num_heads = num_heads
        self.head_dim = dim // num_heads

        self.q = nn.Linear(dim, dim)
        self.k = nn.Linear(dim, dim)
        self.v = nn.Linear(dim, dim)
        self.o = nn.Linear(dim, dim)
        self.norm_q = RMSNorm(dim, eps=eps)
        self.norm_k = RMSNorm(dim, eps=eps)

        self.attn = AttentionModule(self.num_heads)
        # Cached local block mask and the settings it was built for.
        self.local_attn_mask = None
        self.local_attn_mask_h = None
        self.local_attn_mask_w = None
        self.local_range = None

    def forward(self, x, freqs, f=None, h=None, w=None, local_num=None, topk=None,
                train_img=False, block_id=None, kv_len=None, is_full_block=False,
                is_stream=False, pre_cache_k=None, pre_cache_v=None, local_range = 9):
        """Run windowed sparse self-attention.

        Args:
            x: ``(B, L, D)`` tokens with ``L == f * h * w``.
            freqs: rotary factors passed to ``rope_apply``.
            f, h, w: token grid size; must be divisible by the ``(2, 8, 8)`` window.
            topk: number of blocks kept per (head, frame group) by the draft mask.
            kv_len: maximum number of cached frame groups; ``None`` disables trimming.
            is_stream: when true, also return the updated key/value caches.
            pre_cache_k, pre_cache_v: windowed keys/values from earlier calls,
                prepended to the current ones.
            local_range: side length (in windows) of the local neighbourhood.
            local_num, train_img, block_id, is_full_block: accepted but unused here.

        Returns:
            ``self.o(x)``, or ``(self.o(x), cache_k, cache_v)`` when ``is_stream``.
        """
        B, L, D = x.shape
        if is_stream and pre_cache_k is not None and pre_cache_v is not None:
            assert f==2, "f must be 2"
        if is_stream and (pre_cache_k is None or pre_cache_v is None):
            assert f==6, " start f must be 6"
        assert L == f * h * w, "Sequence length mismatch with provided (f,h,w)."

        q = self.norm_q(self.q(x))
        k = self.norm_k(self.k(x))
        v = self.v(x)
        q = rope_apply(q, freqs, self.num_heads)
        k = rope_apply(k, freqs, self.num_heads)

        win = (2, 8, 8)
        q = q.view(B, f, h, w, D)
        k = k.view(B, f, h, w, D)
        v = v.view(B, f, h, w, D)

        q_w = WindowPartition3D.partition(q, win)
        k_w = WindowPartition3D.partition(k, win)
        v_w = WindowPartition3D.partition(v, win)

        seqlen = f//win[0]
        # Number of windows in one frame group, measured before the cache is prepended.
        one_len = k_w.shape[0] // B // seqlen
        if pre_cache_k is not None and pre_cache_v is not None:
            k_w = torch.cat([pre_cache_k, k_w], dim=0)
            v_w = torch.cat([pre_cache_v, v_w], dim=0)

        block_n = q_w.shape[0] // B
        block_s = q_w.shape[1]
        block_n_kv = k_w.shape[0] // B

        reorder_q = rearrange(q_w, '(b block_n) (block_s) d -> b (block_n block_s) d', block_n=block_n, block_s=block_s)
        reorder_k = rearrange(k_w, '(b block_n) (block_s) d -> b (block_n block_s) d', block_n=block_n_kv, block_s=block_s)
        reorder_v = rearrange(v_w, '(b block_n) (block_s) d -> b (block_n block_s) d', block_n=block_n_kv, block_s=block_s)

        # Rebuild the cached local mask when the grid, range or device changed.
        mask_stale = (
            self.local_attn_mask is None
            or self.local_attn_mask_h != h//8
            or self.local_attn_mask_w != w//8
            or self.local_range != local_range
            or self.local_attn_mask.device != k_w.device
        )
        if mask_stale:
            self.local_attn_mask = build_local_block_mask_shifted_vec_normal_slide(h//8, w//8, local_range, local_range, include_self=True, device=k_w.device)
            self.local_attn_mask_h = h//8
            self.local_attn_mask_w = w//8
            self.local_range = local_range
        if USE_BLOCK_ATTN and BLOCK_ATTN_AVAILABLE:
            attention_mask = generate_draft_block_mask(B, self.num_heads, seqlen, q_w, k_w, topk=topk, local_attn_mask=self.local_attn_mask)
        else:
            attention_mask = generate_draft_block_mask_sage(B, self.num_heads, seqlen, q_w, k_w, topk=topk, local_attn_mask=self.local_attn_mask)

        x = self.attn(reorder_q, reorder_k, reorder_v, attention_mask)

        # Drop the oldest frame group once the cache exceeds kv_len groups.
        # kv_len=None (the default) keeps everything instead of raising TypeError.
        cache_num = k_w.shape[0] // one_len
        if kv_len is not None and cache_num > kv_len:
            cache_k = k_w[one_len:, :, :]
            cache_v = v_w[one_len:, :, :]
        else:
            cache_k = k_w
            cache_v = v_w

        x = rearrange(x, 'b (block_n block_s) d -> (b block_n) (block_s) d', block_n=block_n, block_s=block_s)
        x = WindowPartition3D.reverse(x, win, (f, h, w))
        x = x.view(B, f*h*w, D)

        if is_stream:
            return self.o(x), cache_k, cache_v
        return self.o(x)
449
+
450
+
451
class CrossAttention(nn.Module):
    """
    Cross-attention over the text context only, with a persistent KV cache.
    """
    def __init__(self, dim: int, num_heads: int, eps: float = 1e-6):
        super().__init__()
        self.dim = dim
        self.num_heads = num_heads
        self.head_dim = dim // num_heads

        self.q = nn.Linear(dim, dim)
        self.k = nn.Linear(dim, dim)
        self.v = nn.Linear(dim, dim)
        self.o = nn.Linear(dim, dim)

        self.norm_q = RMSNorm(dim, eps=eps)
        self.norm_k = RMSNorm(dim, eps=eps)

        self.attn = AttentionModule(self.num_heads)

        # Persistent key/value cache, filled by init_cache().
        self.cache_k = None
        self.cache_v = None

    @torch.no_grad()
    def init_cache(self, ctx: torch.Tensor):
        """Precompute keys/values from ``ctx``.

        Per the original note, ``ctx`` is the context after text_embedding,
        shaped [B, S_ctx, dim] (shape not checked here).
        """
        keys = self.norm_k(self.k(ctx))
        values = self.v(ctx)
        self.cache_k = keys
        self.cache_v = values

    def clear_cache(self):
        """Forget the cached keys and values."""
        self.cache_k = self.cache_v = None

    def forward(self, x: torch.Tensor, y: torch.Tensor, is_stream: bool = False):
        """
        Attend from ``x`` to the cached context keys/values.

        ``y`` is the text context (no other branch is implemented); it is not
        read here because keys and values come from the cache, so
        ``init_cache`` must have been called first. ``is_stream`` is unused.
        """
        query = self.norm_q(self.q(x))
        assert self.cache_k is not None and self.cache_v is not None
        attended = self.attn(query, self.cache_k, self.cache_v)
        return self.o(attended)
496
+
497
+
498
class GateModule(nn.Module):
    """Parameter-free gated residual: ``x + gate * residual``."""

    def __init__(self):
        super().__init__()

    def forward(self, x, gate, residual):
        """Return ``x`` plus ``residual`` scaled element-wise by ``gate``."""
        gated = gate * residual
        return x + gated
504
+
505
+
506
class DiTBlock(nn.Module):
    """
    One transformer block: modulated self-attention, cross-attention to the
    cached text context, then a modulated feed-forward network.
    """

    def __init__(self, dim: int, num_heads: int, ffn_dim: int, eps: float = 1e-6):
        super().__init__()
        self.dim = dim
        self.num_heads = num_heads
        self.ffn_dim = ffn_dim

        self.self_attn = SelfAttention(dim, num_heads, eps)
        self.cross_attn = CrossAttention(dim, num_heads, eps)

        self.norm1 = nn.LayerNorm(dim, eps=eps, elementwise_affine=False)
        self.norm2 = nn.LayerNorm(dim, eps=eps, elementwise_affine=False)
        self.norm3 = nn.LayerNorm(dim, eps=eps)
        self.ffn = nn.Sequential(
            nn.Linear(dim, ffn_dim),
            nn.GELU(approximate='tanh'),
            nn.Linear(ffn_dim, dim),
        )
        # Learned table of six modulation vectors (shift/scale/gate for the
        # attention branch and for the MLP branch), added to ``t_mod``.
        self.modulation = nn.Parameter(torch.randn(1, 6, dim) / dim**0.5)
        self.gate = GateModule()

    def forward(self, x, context, t_mod, freqs, f, h, w, local_num=None, topk=None,
                train_img=False, block_id=None, kv_len=None, is_full_block=False,
                is_stream=False, pre_cache_k=None, pre_cache_v=None, local_range = 9):
        """
        Run the block on ``x``.

        Returns ``(x, cache_k, cache_v)`` when ``is_stream`` is true, where the
        caches are whatever the self-attention layer handed back; otherwise
        returns ``x`` alone. All attention-related arguments are forwarded
        unchanged to ``self.self_attn``.
        """
        table = self.modulation.to(dtype=t_mod.dtype, device=t_mod.device)
        shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = (
            table + t_mod).chunk(6, dim=1)

        input_x = modulate(self.norm1(x), shift_msa, scale_msa)
        attn_result = self.self_attn(
            input_x, freqs, f, h, w, local_num, topk, train_img, block_id,
            kv_len=kv_len, is_full_block=is_full_block, is_stream=is_stream,
            pre_cache_k=pre_cache_k, pre_cache_v=pre_cache_v, local_range = local_range)

        # The self-attention layer returns (out, cache_k, cache_v) in stream
        # mode but a bare tensor otherwise; unpacking the bare tensor into
        # three names would iterate over its first dimension and fail.
        if isinstance(attn_result, tuple):
            self_attn_output, self_attn_cache_k, self_attn_cache_v = attn_result
        else:
            self_attn_output = attn_result
            self_attn_cache_k = self_attn_cache_v = None

        x = self.gate(x, gate_msa, self_attn_output)
        x = x + self.cross_attn(self.norm3(x), context, is_stream=is_stream)
        input_x = modulate(self.norm2(x), shift_mlp, scale_mlp)
        x = self.gate(x, gate_mlp, self.ffn(input_x))
        if is_stream:
            return x, self_attn_cache_k, self_attn_cache_v
        return x
542
+
543
+
544
class MLP(torch.nn.Module):
    """
    LayerNorm -> Linear -> GELU -> Linear -> LayerNorm projection, with an
    optional learned additive embedding applied to the input first.
    """

    def __init__(self, in_dim, out_dim, has_pos_emb=False):
        super().__init__()
        stages = [
            nn.LayerNorm(in_dim),
            nn.Linear(in_dim, in_dim),
            nn.GELU(),
            nn.Linear(in_dim, out_dim),
            nn.LayerNorm(out_dim),
        ]
        self.proj = torch.nn.Sequential(*stages)
        self.has_pos_emb = has_pos_emb
        if has_pos_emb:
            # Fixed-size table; the input must broadcast against (1, 514, 1280).
            self.emb_pos = torch.nn.Parameter(torch.zeros((1, 514, 1280)))

    def forward(self, x):
        """Project ``x``, adding ``emb_pos`` first when it is enabled."""
        if not self.has_pos_emb:
            return self.proj(x)
        shifted = x + self.emb_pos.to(dtype=x.dtype, device=x.device)
        return self.proj(shifted)
562
+
563
+
564
class Head(nn.Module):
    """
    Output head: non-affine LayerNorm, shift/scale modulation, then a linear
    layer producing ``out_dim * prod(patch_size)`` features per token.
    """

    def __init__(self, dim: int, out_dim: int, patch_size: Tuple[int, int, int], eps: float):
        super().__init__()
        self.dim = dim
        self.patch_size = patch_size
        self.norm = nn.LayerNorm(dim, eps=eps, elementwise_affine=False)
        features_per_token = out_dim * math.prod(patch_size)
        self.head = nn.Linear(dim, features_per_token)
        # Learned (shift, scale) pair that is added to ``t_mod``.
        self.modulation = nn.Parameter(torch.randn(1, 2, dim) / dim**0.5)

    def forward(self, x, t_mod):
        """Modulate the normalised ``x`` with ``t_mod`` and project it."""
        table = self.modulation.to(dtype=t_mod.dtype, device=t_mod.device)
        shift, scale = (table + t_mod).chunk(2, dim=1)
        normed = self.norm(x)
        return self.head(normed * (1 + scale) + shift)
577
+
578
+
579
+ # ----------------------------
580
+ # WanModel (no image branch) — the cross-attention KV cache is built by reinit_cross_kv()
581
+ # ----------------------------
582
class WanModel(torch.nn.Module):
    """
    Diffusion transformer without an image branch.

    The cross-attention key/value cache is NOT built in ``__init__``; call
    ``reinit_cross_kv`` with the raw text context before running the blocks,
    and ``clear_cross_kv`` to drop it again.
    """

    def __init__(
        self,
        dim: int,
        in_dim: int,
        ffn_dim: int,
        out_dim: int,
        text_dim: int,
        freq_dim: int,
        eps: float,
        patch_size: Tuple[int, int, int],
        num_heads: int,
        num_layers: int,
        has_image_input: bool = False,
    ):
        # ``has_image_input`` is accepted for config compatibility only; it is
        # not read anywhere in this class.
        super().__init__()
        self.dim = dim
        self.freq_dim = freq_dim
        self.patch_size = patch_size

        # patch embed
        self.patch_embedding = nn.Conv3d(
            in_dim, dim, kernel_size=patch_size, stride=patch_size)

        # text / time embed
        self.text_embedding = nn.Sequential(
            nn.Linear(text_dim, dim),
            nn.GELU(approximate='tanh'),
            nn.Linear(dim, dim)
        )
        self.time_embedding = nn.Sequential(
            nn.Linear(freq_dim, dim),
            nn.SiLU(),
            nn.Linear(dim, dim)
        )
        self.time_projection = nn.Sequential(
            nn.SiLU(), nn.Linear(dim, dim * 6))

        # blocks
        self.blocks = nn.ModuleList([
            DiTBlock(dim, num_heads, ffn_dim, eps)
            for _ in range(num_layers)
        ])
        self.head = Head(dim, out_dim, patch_size, eps)

        head_dim = dim // num_heads
        self.freqs = precompute_freqs_cis_3d(head_dim)

        self._cross_kv_initialized = False

    def clear_cross_kv(self):
        """Drop the cross-attention cache of every block."""
        for blk in self.blocks:
            blk.cross_attn.clear_cache()
        self._cross_kv_initialized = False

    @torch.no_grad()
    def reinit_cross_kv(self, new_context: torch.Tensor):
        """Embed ``new_context`` and rebuild every block's cross-attention cache."""
        ctx_txt = self.text_embedding(new_context)
        for blk in self.blocks:
            blk.cross_attn.init_cache(ctx_txt)
        self._cross_kv_initialized = True

    def patchify(self, x: torch.Tensor):
        """Patch-embed ``x``; return the token sequence and the ``(f, h, w)`` grid."""
        x = self.patch_embedding(x)
        grid_size = x.shape[2:]
        x = rearrange(x, 'b c f h w -> b (f h w) c').contiguous()
        return x, grid_size

    def unpatchify(self, x: torch.Tensor, grid_size: Tuple[int, int, int]):
        """Inverse of the token flattening: fold tokens back into a 5-D volume."""
        return rearrange(
            x, 'b (f h w) (x y z c) -> b c (f x) (h y) (w z)',
            f=grid_size[0], h=grid_size[1], w=grid_size[2],
            x=self.patch_size[0], y=self.patch_size[1], z=self.patch_size[2]
        )

    def forward(self,
                x: torch.Tensor,
                timestep: torch.Tensor,
                context: torch.Tensor,
                use_gradient_checkpointing: bool = False,
                use_gradient_checkpointing_offload: bool = False,
                LQ_latents: Optional[List[torch.Tensor]] = None,
                train_img: bool = False,
                topk_ratio: Optional[float] = None,
                kv_ratio: Optional[float] = None,
                local_num: Optional[int] = None,
                is_full_block: bool = False,
                causal_idx: Optional[int] = None,
                **kwargs,
                ):
        """
        Non-streaming forward pass.

        ``context`` is handed to the blocks as-is (it is not re-embedded
        here). ``topk_ratio`` defaults to ``2.0`` when omitted; ``kv_ratio``
        and ``local_num`` are drawn at random when omitted. ``causal_idx`` and
        ``**kwargs`` are accepted but not used.
        """
        # time embeds
        t = self.time_embedding(
            sinusoidal_embedding_1d(self.freq_dim, timestep))
        t_mod = self.time_projection(t).unflatten(1, (6, self.dim))

        # patchify the input
        x, (f, h, w) = self.patchify(x)
        B = x.shape[0]

        # window / mask hyper-parameters
        win = (2, 8, 8)
        seqlen = f//win[0]
        if local_num is None:
            local_random = random.random()
            if local_random < 0.3:
                local_num = seqlen - 3
            elif local_random < 0.4:
                local_num = seqlen - 4
            elif local_random < 0.5:
                local_num = seqlen - 2
            else:
                local_num = seqlen

        window_size = win[0]*h*w//128
        square_num = window_size*window_size
        # Previously the argument was overwritten unconditionally, so a
        # caller-supplied ratio was silently ignored.
        if topk_ratio is None:
            topk_ratio = 2.0
        topk = min(max(int(square_num*topk_ratio), 1), int(square_num*seqlen)-1)

        if kv_ratio is None:
            kv_ratio = (random.uniform(0., 1.0)**2)*(local_num-2-2)+2
        kv_len = min(max(int(window_size*kv_ratio), 1), int(window_size*seqlen)-1)

        # The drawn value is unused; the call is kept so that the sequence of
        # numbers consumed from ``random`` is unchanged.
        random.uniform(0.7, 1.0)

        # RoPE 3D
        freqs = torch.cat([
            self.freqs[0][:f].view(f, 1, 1, -1).expand(f, h, w, -1),
            self.freqs[1][:h].view(1, h, 1, -1).expand(f, h, w, -1),
            self.freqs[2][:w].view(1, 1, w, -1).expand(f, h, w, -1)
        ], dim=-1).reshape(f * h * w, 1, -1).to(x.device)

        def create_custom_forward(module):
            def custom_forward(*inputs):
                return module(*inputs)
            return custom_forward

        checkpointing = self.training and use_gradient_checkpointing

        # blocks
        for block_id, block in enumerate(self.blocks):
            if LQ_latents is not None and block_id < len(LQ_latents):
                x += LQ_latents[block_id]

            # Positional order matches DiTBlock.forward; the trailing
            # (False, None, None) are is_stream, pre_cache_k, pre_cache_v.
            block_args = (
                x, context, t_mod, freqs, f, h, w, local_num, topk,
                train_img, block_id, kv_len, is_full_block, False,
                None, None,
            )

            if not checkpointing:
                x = block(*block_args)
            elif use_gradient_checkpointing_offload:
                with torch.autograd.graph.save_on_cpu():
                    x = torch.utils.checkpoint.checkpoint(
                        create_custom_forward(block),
                        *block_args,
                        use_reentrant=False,
                    )
            else:
                x = torch.utils.checkpoint.checkpoint(
                    create_custom_forward(block),
                    *block_args,
                    use_reentrant=False,
                )

        x = self.head(x, t)
        x = self.unpatchify(x, (f, h, w))
        return x

    @staticmethod
    def state_dict_converter():
        """Return the checkpoint key converter for this model."""
        return WanModelStateDictConverter()
758
+
759
+
760
+ # ----------------------------
761
+ # State dict converter (original key mapping kept; has_image_input is no longer used)
762
+ # ----------------------------
763
class WanModelStateDictConverter:
    """Translate external checkpoints into this model's key layout and config."""

    def __init__(self):
        pass

    @staticmethod
    def _civitai_config(in_dim, large, **extra):
        """Build one ``from_civitai`` config; ``large`` selects the 5120-wide model."""
        if large:
            dim, ffn_dim, num_heads, num_layers = 5120, 13824, 40, 40
        else:
            dim, ffn_dim, num_heads, num_layers = 1536, 8960, 12, 30
        config = {
            "has_image_input": False,
            "patch_size": [1, 2, 2],
            "in_dim": in_dim,
            "dim": dim,
            "ffn_dim": ffn_dim,
            "freq_dim": 256,
            "text_dim": 4096,
            "out_dim": 16,
            "num_heads": num_heads,
            "num_layers": num_layers,
            "eps": 1e-6,
        }
        config.update(extra)
        return config

    def from_diffusers(self, state_dict):
        """
        Rename diffusers-style keys to this model's names.

        Per-block keys are matched against the ``blocks.0`` templates below
        and the real block index is put back afterwards. Keys with no match
        are dropped. Returns ``(renamed_state_dict, config)``; ``config`` is
        empty unless the key-set hash is recognised.
        """
        rename_dict = {
            "blocks.0.attn1.norm_k.weight": "blocks.0.self_attn.norm_k.weight",
            "blocks.0.attn1.norm_q.weight": "blocks.0.self_attn.norm_q.weight",
            "blocks.0.attn1.to_k.bias": "blocks.0.self_attn.k.bias",
            "blocks.0.attn1.to_k.weight": "blocks.0.self_attn.k.weight",
            "blocks.0.attn1.to_out.0.bias": "blocks.0.self_attn.o.bias",
            "blocks.0.attn1.to_out.0.weight": "blocks.0.self_attn.o.weight",
            "blocks.0.attn1.to_q.bias": "blocks.0.self_attn.q.bias",
            "blocks.0.attn1.to_q.weight": "blocks.0.self_attn.q.weight",
            "blocks.0.attn1.to_v.bias": "blocks.0.self_attn.v.bias",
            "blocks.0.attn1.to_v.weight": "blocks.0.self_attn.v.weight",
            "blocks.0.attn2.norm_k.weight": "blocks.0.cross_attn.norm_k.weight",
            "blocks.0.attn2.norm_q.weight": "blocks.0.cross_attn.norm_q.weight",
            "blocks.0.attn2.to_k.bias": "blocks.0.cross_attn.k.bias",
            "blocks.0.attn2.to_k.weight": "blocks.0.cross_attn.k.weight",
            "blocks.0.attn2.to_out.0.bias": "blocks.0.cross_attn.o.bias",
            "blocks.0.attn2.to_out.0.weight": "blocks.0.cross_attn.o.weight",
            "blocks.0.attn2.to_q.bias": "blocks.0.cross_attn.q.bias",
            "blocks.0.attn2.to_q.weight": "blocks.0.cross_attn.q.weight",
            "blocks.0.attn2.to_v.bias": "blocks.0.cross_attn.v.bias",
            "blocks.0.attn2.to_v.weight": "blocks.0.cross_attn.v.weight",
            "blocks.0.ffn.net.0.proj.bias": "blocks.0.ffn.0.bias",
            "blocks.0.ffn.net.0.proj.weight": "blocks.0.ffn.0.weight",
            "blocks.0.ffn.net.2.bias": "blocks.0.ffn.2.bias",
            "blocks.0.ffn.net.2.weight": "blocks.0.ffn.2.weight",
            "blocks.0.norm2.bias": "blocks.0.norm3.bias",
            "blocks.0.norm2.weight": "blocks.0.norm3.weight",
            "blocks.0.scale_shift_table": "blocks.0.modulation",
            "condition_embedder.text_embedder.linear_1.bias": "text_embedding.0.bias",
            "condition_embedder.text_embedder.linear_1.weight": "text_embedding.0.weight",
            "condition_embedder.text_embedder.linear_2.bias": "text_embedding.2.bias",
            "condition_embedder.text_embedder.linear_2.weight": "text_embedding.2.weight",
            "condition_embedder.time_embedder.linear_1.bias": "time_embedding.0.bias",
            "condition_embedder.time_embedder.linear_1.weight": "time_embedding.0.weight",
            "condition_embedder.time_embedder.linear_2.bias": "time_embedding.2.bias",
            "condition_embedder.time_embedder.linear_2.weight": "time_embedding.2.weight",
            "condition_embedder.time_proj.bias": "time_projection.1.bias",
            "condition_embedder.time_proj.weight": "time_projection.1.weight",
            "patch_embedding.bias": "patch_embedding.bias",
            "patch_embedding.weight": "patch_embedding.weight",
            "scale_shift_table": "head.modulation",
            "proj_out.bias": "head.head.bias",
            "proj_out.weight": "head.head.weight",
        }
        state_dict_ = {}
        for name, param in state_dict.items():
            if name in rename_dict:
                state_dict_[rename_dict[name]] = param
                continue
            # Swap the second path component for "0" to look up the template,
            # then restore the original component in the renamed key.
            parts = name.split(".")
            template = ".".join(parts[:1] + ["0"] + parts[2:])
            if template in rename_dict:
                target = rename_dict[template].split(".")
                state_dict_[".".join(target[:1] + [parts[1]] + target[2:])] = param

        if hash_state_dict_keys(state_dict) == "cb104773c6c2cb6df4f9529ad5c60d0b":
            config = {
                "model_type": "t2v",
                "patch_size": (1, 2, 2),
                "text_len": 512,
                "in_dim": 16,
                "dim": 5120,
                "ffn_dim": 13824,
                "freq_dim": 256,
                "text_dim": 4096,
                "out_dim": 16,
                "num_heads": 40,
                "num_layers": 40,
                "window_size": (-1, -1),
                "qk_norm": True,
                "cross_attn_norm": True,
                "eps": 1e-6,
            }
        else:
            config = {}
        return state_dict_, config

    def from_civitai(self, state_dict):
        """
        Drop every key starting with ``"vace"`` and pick a config from the
        hash of the remaining keys.

        Returns ``(filtered_state_dict, config)``; ``config`` is empty when
        the hash is unknown. The hash is computed once rather than once per
        candidate.
        """
        state_dict = {name: param for name, param in state_dict.items() if not name.startswith("vace")}
        # hash -> (in_dim, large model?, extra config entries)
        known = {
            "9269f8db9040a9d860eaca435be61814": (16, False, {}),
            "aafcfd9672c3a2456dc46e1cb6e52c70": (16, True, {}),
            "6bfcfb3b342cb286ce886889d519a77e": (36, True, {}),
            "6d6ccde6845b95ad9114ab993d917893": (36, False, {}),
            "349723183fc063b2bfc10bb2835cf677": (48, False, {}),
            "efa44cddf936c70abd0ea28b6cbe946c": (48, True, {}),
            "3ef3b1f8e1dab83d5b71fd7b617f859f": (36, True, {"has_image_pos_emb": False}),
        }
        match = known.get(hash_state_dict_keys(state_dict))
        if match is None:
            config = {}
        else:
            in_dim, large, extra = match
            config = self._civitai_config(in_dim, large, **extra)
        return state_dict, config
864
+
custom_nodes/ComfyUI-FlashVSR_Ultra_Fast/src/models/wan_video_vae.py ADDED
@@ -0,0 +1,847 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from einops import rearrange, repeat
2
+
3
+ import torch
4
+ import torch.nn as nn
5
+ import torch.nn.functional as F
6
+ from tqdm import tqdm
7
+
8
+ CACHE_T = 2
9
+
10
+
11
def check_is_instance(model, module_class):
    """Return True if ``model`` or its ``.module`` attribute is a ``module_class``.

    The ``.module`` lookup lets an object that holds the real module in a
    ``module`` attribute match as well.
    """
    return isinstance(model, module_class) or (
        hasattr(model, "module") and isinstance(model.module, module_class))
17
+
18
+
19
def block_causal_mask(x, block_size):
    """Build a boolean block-causal mask of shape ``(b, n, s, s)`` for a 4-D ``x``.

    Position ``i`` may see position ``j`` when ``j``'s block index is not
    greater than ``i``'s. ``s`` (the third dimension of ``x``) must be a
    multiple of ``block_size``.
    """
    b, n, s, _ = x.size()
    device = x.device
    assert s % block_size == 0

    # Block index of every position; row block >= column block means visible.
    block_ids = torch.arange(s, device=device) // block_size
    visible = block_ids.unsqueeze(1) >= block_ids.unsqueeze(0)
    return visible.expand(b, n, s, s).contiguous()
31
+
32
+
33
class CausalConv3d(nn.Conv3d):
    """
    Causal 3d convolution.

    The padding requested at construction is applied manually in ``forward``:
    symmetric in height and width, and entirely on the leading side of the
    time axis, so an output frame never depends on later input frames.
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        pad_t, pad_h, pad_w = self.padding[0], self.padding[1], self.padding[2]
        # F.pad order: (w_left, w_right, h_top, h_bottom, t_front, t_back).
        self._padding = (pad_w, pad_w, pad_h, pad_h, 2 * pad_t, 0)
        # Disable the built-in padding; forward() pads explicitly.
        self.padding = (0, 0, 0)

    def forward(self, x, cache_x=None):
        """Convolve ``x``, optionally prepending ``cache_x`` frames along dim 2.

        Cached frames replace the same number of front padding frames.
        """
        pad = list(self._padding)
        has_front_pad = self._padding[4] > 0
        if cache_x is not None and has_front_pad:
            cache_x = cache_x.to(x.device)
            x = torch.cat([cache_x, x], dim=2)
            pad[4] -= cache_x.shape[2]
        padded = F.pad(x, pad)
        return super().forward(padded)
54
+
55
+
56
class RMS_norm(nn.Module):
    """L2-normalise along the channel axis, then rescale by ``sqrt(dim) * gamma``.

    ``channel_first`` selects axis 1 (otherwise the last axis); ``images``
    selects two trailing broadcast dims for ``gamma`` instead of three.
    """

    def __init__(self, dim, channel_first=True, images=True, bias=False):
        super().__init__()
        if channel_first:
            trailing = (1, 1) if images else (1, 1, 1)
            shape = (dim, *trailing)
        else:
            shape = (dim,)

        self.channel_first = channel_first
        self.scale = dim**0.5
        self.gamma = nn.Parameter(torch.ones(shape))
        # Without a bias this is the plain float 0., not a parameter.
        self.bias = nn.Parameter(torch.zeros(shape)) if bias else 0.

    def forward(self, x):
        axis = 1 if self.channel_first else -1
        unit = F.normalize(x, dim=axis)
        return unit * self.scale * self.gamma + self.bias
72
+
73
+
74
class Upsample(nn.Upsample):
    """``nn.Upsample`` that interpolates in float32 and casts back."""

    def forward(self, x):
        """
        Fix bfloat16 support for nearest neighbor interpolation.
        """
        upsampled = super().forward(x.float())
        return upsampled.type_as(x)
81
+
82
+
83
class Resample(nn.Module):
    """
    Spatial (2d) or spatio-temporal (3d) up/down-sampling stage.

    ``mode`` is one of ``'none'``, ``'upsample2d'``, ``'upsample3d'``,
    ``'downsample2d'``, ``'downsample3d'``. The 3d modes add a causal
    temporal convolution that is only applied when a ``feat_cache`` is given.
    """

    def __init__(self, dim, mode):
        assert mode in ('none', 'upsample2d', 'upsample3d', 'downsample2d',
                        'downsample3d')
        super().__init__()
        self.dim = dim
        self.mode = mode

        # layers
        if mode == 'upsample2d':
            self.resample = nn.Sequential(
                Upsample(scale_factor=(2., 2.), mode='nearest-exact'),
                nn.Conv2d(dim, dim // 2, 3, padding=1))
        elif mode == 'upsample3d':
            self.resample = nn.Sequential(
                Upsample(scale_factor=(2., 2.), mode='nearest-exact'),
                nn.Conv2d(dim, dim // 2, 3, padding=1))
            self.time_conv = CausalConv3d(dim,
                                          dim * 2, (3, 1, 1),
                                          padding=(1, 0, 0))

        elif mode == 'downsample2d':
            self.resample = nn.Sequential(
                nn.ZeroPad2d((0, 1, 0, 1)),
                nn.Conv2d(dim, dim, 3, stride=(2, 2)))
        elif mode == 'downsample3d':
            self.resample = nn.Sequential(
                nn.ZeroPad2d((0, 1, 0, 1)),
                nn.Conv2d(dim, dim, 3, stride=(2, 2)))
            self.time_conv = CausalConv3d(dim,
                                          dim, (3, 1, 1),
                                          stride=(2, 1, 1),
                                          padding=(0, 0, 0))

        else:
            self.resample = nn.Identity()

    def forward(self, x, feat_cache=None, feat_idx=None):
        """
        Resample ``x`` (a 5-D ``b c t h w`` tensor).

        ``feat_cache`` is a list of per-layer cache slots and ``feat_idx`` a
        one-element list holding the next slot index; both are updated in
        place. ``feat_idx`` defaults to a fresh ``[0]`` per call; the previous
        shared ``[0]`` default was mutated in place and leaked between calls.
        """
        if feat_idx is None:
            feat_idx = [0]

        b, c, t, h, w = x.size()
        if self.mode == 'upsample3d' and feat_cache is not None:
            idx = feat_idx[0]
            cached = feat_cache[idx]
            if cached is None:
                # First chunk: mark the slot, skip the temporal upsampling.
                feat_cache[idx] = 'Rep'
                feat_idx[0] += 1
            else:
                # A slot holds either the 'Rep' marker or a tensor; check the
                # type first so a tensor is never compared with a string.
                is_rep = isinstance(cached, str) and cached == 'Rep'

                cache_x = x[:, :, -CACHE_T:, :, :].clone()
                if cache_x.shape[2] < 2:
                    if is_rep:
                        # No earlier frame exists yet: pad with zeros.
                        front = torch.zeros_like(cache_x).to(cache_x.device)
                    else:
                        # cache last frame of last two chunk
                        front = cached[:, :, -1, :, :].unsqueeze(2).to(
                            cache_x.device)
                    cache_x = torch.cat([front, cache_x], dim=2)

                if is_rep:
                    x = self.time_conv(x)
                else:
                    x = self.time_conv(x, cached)
                feat_cache[idx] = cache_x
                feat_idx[0] += 1

                # time_conv doubled the channels; interleave the two halves
                # along time to double the frame count.
                x = x.reshape(b, 2, c, t, h, w)
                x = torch.stack((x[:, 0, :, :, :, :], x[:, 1, :, :, :, :]),
                                3)
                x = x.reshape(b, c, t * 2, h, w)

        # Spatial resampling is applied frame by frame.
        t = x.shape[2]
        x = rearrange(x, 'b c t h w -> (b t) c h w')
        x = self.resample(x)
        x = rearrange(x, '(b t) c h w -> b c t h w', t=t)

        if self.mode == 'downsample3d' and feat_cache is not None:
            idx = feat_idx[0]
            if feat_cache[idx] is None:
                # First chunk: just remember it, no temporal downsampling.
                feat_cache[idx] = x.clone()
                feat_idx[0] += 1
            else:
                cache_x = x[:, :, -1:, :, :].clone()
                x = self.time_conv(
                    torch.cat([feat_cache[idx][:, :, -1:, :, :], x], 2))
                feat_cache[idx] = cache_x
                feat_idx[0] += 1
        return x

    def init_weight(self, conv):
        """Zero ``conv`` and place an identity matrix at temporal tap 1."""
        weight = conv.weight.data
        nn.init.zeros_(weight)
        c1, c2, t, h, w = weight.size()
        weight[:, :, 1, 0, 0] = torch.eye(c1, c2)
        nn.init.zeros_(conv.bias.data)

    def init_weight2(self, conv):
        """Zero ``conv`` and place an identity in each output-channel half at the last temporal tap."""
        weight = conv.weight.data
        nn.init.zeros_(weight)
        c1, c2, t, h, w = weight.size()
        init_matrix = torch.eye(c1 // 2, c2)
        weight[:c1 // 2, :, -1, 0, 0] = init_matrix
        weight[c1 // 2:, :, -1, 0, 0] = init_matrix
        nn.init.zeros_(conv.bias.data)
197
+
198
+
199
class ResidualBlock(nn.Module):
    """
    Two norm -> SiLU -> causal-conv stages (dropout before the second conv)
    plus a shortcut, which is a 1x1x1 causal conv when the channel count
    changes and the identity otherwise.
    """

    def __init__(self, in_dim, out_dim, dropout=0.0):
        super().__init__()
        self.in_dim = in_dim
        self.out_dim = out_dim

        # layers
        self.residual = nn.Sequential(
            RMS_norm(in_dim, images=False), nn.SiLU(),
            CausalConv3d(in_dim, out_dim, 3, padding=1),
            RMS_norm(out_dim, images=False), nn.SiLU(), nn.Dropout(dropout),
            CausalConv3d(out_dim, out_dim, 3, padding=1))
        self.shortcut = CausalConv3d(in_dim, out_dim, 1) \
            if in_dim != out_dim else nn.Identity()

    def forward(self, x, feat_cache=None, feat_idx=None):
        """
        Apply the block to ``x``.

        ``feat_cache`` is a list of per-conv cache slots and ``feat_idx`` a
        one-element list holding the next slot index; both are updated in
        place. ``feat_idx`` defaults to a fresh ``[0]`` per call; the previous
        shared ``[0]`` default was mutated in place and leaked between calls.
        """
        if feat_idx is None:
            feat_idx = [0]

        h = self.shortcut(x)
        for layer in self.residual:
            if feat_cache is None or not check_is_instance(layer, CausalConv3d):
                x = layer(x)
                continue

            idx = feat_idx[0]
            cache_x = x[:, :, -CACHE_T:, :, :].clone()
            if cache_x.shape[2] < 2 and feat_cache[idx] is not None:
                # cache last frame of last two chunk
                cache_x = torch.cat([
                    feat_cache[idx][:, :, -1, :, :].unsqueeze(2).to(
                        cache_x.device), cache_x
                ],
                                    dim=2)
            x = layer(x, feat_cache[idx])
            feat_cache[idx] = cache_x
            feat_idx[0] += 1
        return x + h
234
+
235
+
236
class AttentionBlock(nn.Module):
    """
    Single-head self-attention applied to each frame independently.

    Frames are folded into the batch axis, so attention runs over the
    ``h * w`` positions of one frame at a time and no causal mask is used.
    """

    def __init__(self, dim):
        super().__init__()
        self.dim = dim

        # layers
        self.norm = RMS_norm(dim)
        self.to_qkv = nn.Conv2d(dim, dim * 3, 1)
        self.proj = nn.Conv2d(dim, dim, 1)

        # zero out the last layer params
        nn.init.zeros_(self.proj.weight)

    def forward(self, x):
        """Return ``x`` plus the projected per-frame attention output."""
        residual = x
        b, c, t, h, w = x.size()
        frames = rearrange(x, 'b c t h w -> (b t) c h w')
        frames = self.norm(frames)

        # compute query, key, value: (b*t, 1, h*w, c) each
        qkv = self.to_qkv(frames).reshape(b * t, 1, c * 3, -1)
        qkv = qkv.permute(0, 1, 3, 2).contiguous()
        q, k, v = qkv.chunk(3, dim=-1)

        # apply attention (no attention mask)
        attended = F.scaled_dot_product_attention(q, k, v)
        attended = attended.squeeze(1).permute(0, 2, 1).reshape(b * t, c, h, w)

        # output
        out = self.proj(attended)
        out = rearrange(out, '(b t) c h w-> b c t h w', t=t)
        return out + residual
275
+
276
+
277
class Encoder3d(nn.Module):
    """
    Causal 3-D encoder: input conv, residual/downsample stages, a middle
    residual-attention-residual stack, and a norm/SiLU/conv head that emits
    ``z_dim`` channels.
    """

    def __init__(self,
                 dim=128,
                 z_dim=4,
                 dim_mult=[1, 2, 4, 4],
                 num_res_blocks=2,
                 attn_scales=[],
                 temperal_downsample=[True, True, False],
                 dropout=0.0):
        # The list defaults are only read here, never mutated.
        super().__init__()
        self.dim = dim
        self.z_dim = z_dim
        self.dim_mult = dim_mult
        self.num_res_blocks = num_res_blocks
        self.attn_scales = attn_scales
        self.temperal_downsample = temperal_downsample

        # dimensions
        dims = [dim * u for u in [1] + dim_mult]
        scale = 1.0

        # init block
        self.conv1 = CausalConv3d(3, dims[0], 3, padding=1)

        # downsample blocks
        downsamples = []
        for i, (in_dim, out_dim) in enumerate(zip(dims[:-1], dims[1:])):
            # residual (+attention) blocks
            for _ in range(num_res_blocks):
                downsamples.append(ResidualBlock(in_dim, out_dim, dropout))
                if scale in attn_scales:
                    downsamples.append(AttentionBlock(out_dim))
                in_dim = out_dim

            # downsample block
            if i != len(dim_mult) - 1:
                mode = 'downsample3d' if temperal_downsample[
                    i] else 'downsample2d'
                downsamples.append(Resample(out_dim, mode=mode))
                scale /= 2.0
        self.downsamples = nn.Sequential(*downsamples)

        # middle blocks (``out_dim`` is the width of the last stage)
        self.middle = nn.Sequential(ResidualBlock(out_dim, out_dim, dropout),
                                    AttentionBlock(out_dim),
                                    ResidualBlock(out_dim, out_dim, dropout))

        # output blocks
        self.head = nn.Sequential(RMS_norm(out_dim, images=False), nn.SiLU(),
                                  CausalConv3d(out_dim, z_dim, 3, padding=1))

    @staticmethod
    def _cached_conv(conv, x, feat_cache, feat_idx):
        """Run one causal conv with its cache slot, then refresh and advance the slot."""
        idx = feat_idx[0]
        cache_x = x[:, :, -CACHE_T:, :, :].clone()
        if cache_x.shape[2] < 2 and feat_cache[idx] is not None:
            # cache last frame of last two chunk
            cache_x = torch.cat([
                feat_cache[idx][:, :, -1, :, :].unsqueeze(2).to(
                    cache_x.device), cache_x
            ],
                                dim=2)
        x = conv(x, feat_cache[idx])
        feat_cache[idx] = cache_x
        feat_idx[0] += 1
        return x

    def forward(self, x, feat_cache=None, feat_idx=None):
        """
        Encode ``x``.

        ``feat_cache`` is a list of per-conv cache slots and ``feat_idx`` a
        one-element list holding the next slot index; both are updated in
        place. ``feat_idx`` defaults to a fresh ``[0]`` per call; the previous
        shared ``[0]`` default was mutated in place and leaked between calls.
        """
        if feat_idx is None:
            feat_idx = [0]
        use_cache = feat_cache is not None

        if use_cache:
            x = self._cached_conv(self.conv1, x, feat_cache, feat_idx)
        else:
            x = self.conv1(x)

        ## downsamples
        for layer in self.downsamples:
            # Only residual and resample layers take cache arguments;
            # AttentionBlock.forward accepts ``x`` alone.
            if use_cache and check_is_instance(layer, (ResidualBlock, Resample)):
                x = layer(x, feat_cache, feat_idx)
            else:
                x = layer(x)

        ## middle
        for layer in self.middle:
            if use_cache and check_is_instance(layer, ResidualBlock):
                x = layer(x, feat_cache, feat_idx)
            else:
                x = layer(x)

        ## head
        for layer in self.head:
            if use_cache and check_is_instance(layer, CausalConv3d):
                x = self._cached_conv(layer, x, feat_cache, feat_idx)
            else:
                x = layer(x)
        return x
378
+
379
+
380
class Decoder3d(nn.Module):
    """
    Causal 3-D decoder: input conv, a middle residual-attention-residual
    stack, residual/upsample stages, and a norm/SiLU/conv head that emits
    3 channels.
    """

    def __init__(self,
                 dim=128,
                 z_dim=4,
                 dim_mult=[1, 2, 4, 4],
                 num_res_blocks=2,
                 attn_scales=[],
                 temperal_upsample=[False, True, True],
                 dropout=0.0):
        # The list defaults are only read here, never mutated.
        super().__init__()
        self.dim = dim
        self.z_dim = z_dim
        self.dim_mult = dim_mult
        self.num_res_blocks = num_res_blocks
        self.attn_scales = attn_scales
        self.temperal_upsample = temperal_upsample

        # dimensions
        dims = [dim * u for u in [dim_mult[-1]] + dim_mult[::-1]]
        scale = 1.0 / 2**(len(dim_mult) - 2)

        # init block
        self.conv1 = CausalConv3d(z_dim, dims[0], 3, padding=1)

        # middle blocks
        self.middle = nn.Sequential(ResidualBlock(dims[0], dims[0], dropout),
                                    AttentionBlock(dims[0]),
                                    ResidualBlock(dims[0], dims[0], dropout))

        # upsample blocks
        upsamples = []
        for i, (in_dim, out_dim) in enumerate(zip(dims[:-1], dims[1:])):
            # residual (+attention) blocks
            if i in (1, 2, 3):
                # The preceding Resample's conv outputs half its input channels.
                in_dim = in_dim // 2
            for _ in range(num_res_blocks + 1):
                upsamples.append(ResidualBlock(in_dim, out_dim, dropout))
                if scale in attn_scales:
                    upsamples.append(AttentionBlock(out_dim))
                in_dim = out_dim

            # upsample block
            if i != len(dim_mult) - 1:
                mode = 'upsample3d' if temperal_upsample[i] else 'upsample2d'
                upsamples.append(Resample(out_dim, mode=mode))
                scale *= 2.0
        self.upsamples = nn.Sequential(*upsamples)

        # output blocks (``out_dim`` is the width of the last stage)
        self.head = nn.Sequential(RMS_norm(out_dim, images=False), nn.SiLU(),
                                  CausalConv3d(out_dim, 3, 3, padding=1))

    @staticmethod
    def _cached_conv(conv, x, feat_cache, feat_idx):
        """Run one causal conv with its cache slot, then refresh and advance the slot."""
        idx = feat_idx[0]
        cache_x = x[:, :, -CACHE_T:, :, :].clone()
        if cache_x.shape[2] < 2 and feat_cache[idx] is not None:
            # cache last frame of last two chunk
            cache_x = torch.cat([
                feat_cache[idx][:, :, -1, :, :].unsqueeze(2).to(
                    cache_x.device), cache_x
            ],
                                dim=2)
        x = conv(x, feat_cache[idx])
        feat_cache[idx] = cache_x
        feat_idx[0] += 1
        return x

    def forward(self, x, feat_cache=None, feat_idx=None):
        """
        Decode ``x``.

        ``feat_cache`` is a list of per-conv cache slots and ``feat_idx`` a
        one-element list holding the next slot index; both are updated in
        place. ``feat_idx`` defaults to a fresh ``[0]`` per call; the previous
        shared ``[0]`` default was mutated in place and leaked between calls.
        """
        if feat_idx is None:
            feat_idx = [0]
        use_cache = feat_cache is not None

        ## conv1
        if use_cache:
            x = self._cached_conv(self.conv1, x, feat_cache, feat_idx)
        else:
            x = self.conv1(x)

        ## middle
        for layer in self.middle:
            if use_cache and check_is_instance(layer, ResidualBlock):
                x = layer(x, feat_cache, feat_idx)
            else:
                x = layer(x)

        ## upsamples
        for layer in self.upsamples:
            # Only residual and resample layers take cache arguments;
            # AttentionBlock.forward accepts ``x`` alone.
            if use_cache and check_is_instance(layer, (ResidualBlock, Resample)):
                x = layer(x, feat_cache, feat_idx)
            else:
                x = layer(x)

        ## head
        for layer in self.head:
            if use_cache and check_is_instance(layer, CausalConv3d):
                x = self._cached_conv(layer, x, feat_cache, feat_idx)
            else:
                x = layer(x)
        return x
483
+
484
+
485
def count_conv3d(model):
    """Return how many entries of ``model.modules()`` are ``CausalConv3d`` instances."""
    return sum(1 for module in model.modules() if check_is_instance(module, CausalConv3d))
491
+
492
+
493
class VideoVAE_(nn.Module):
    """Causal 3-D video VAE that encodes and decodes chunk by chunk.

    Both directions keep a per-convolution feature cache (``_enc_feat_map`` /
    ``_feat_map``) so that consecutive chunks share temporal context.
    """

    # NOTE: the list defaults below are stored and forwarded as-is, never
    # mutated here; they are kept as lists so downstream code receives the
    # same types as before.
    def __init__(self,
                 dim=96,
                 z_dim=16,
                 dim_mult=[1, 2, 4, 4],
                 num_res_blocks=2,
                 attn_scales=[],
                 temperal_downsample=[False, True, True],
                 dropout=0.0):
        super().__init__()
        self.dim = dim
        self.z_dim = z_dim
        self.dim_mult = dim_mult
        self.num_res_blocks = num_res_blocks
        self.attn_scales = attn_scales
        self.temperal_downsample = temperal_downsample
        # The decoder undoes the encoder's temporal downsampling in reverse order.
        self.temperal_upsample = temperal_downsample[::-1]

        # modules
        self.encoder = Encoder3d(dim, z_dim * 2, dim_mult, num_res_blocks,
                                 attn_scales, self.temperal_downsample, dropout)
        self.conv1 = CausalConv3d(z_dim * 2, z_dim * 2, 1)
        self.conv2 = CausalConv3d(z_dim, z_dim, 1)
        self.decoder = Decoder3d(dim, z_dim, dim_mult, num_res_blocks,
                                 attn_scales, self.temperal_upsample, dropout)

    # ------------------------------------------------------------------
    # private helpers
    # ------------------------------------------------------------------
    def _split_scale(self, scale, ref):
        """Return ``(shift, factor)`` taken from ``scale``, matched to ``ref``.

        A sequence of tensors is moved to ``ref``'s dtype/device and reshaped
        to ``(1, z_dim, 1, 1, 1)``. Anything else is indexed directly, after
        a ``.to(...)`` call when the object offers one.
        """
        if isinstance(scale[0], torch.Tensor):
            moved = [s.to(dtype=ref.dtype, device=ref.device) for s in scale]
            return (moved[0].view(1, self.z_dim, 1, 1, 1),
                    moved[1].view(1, self.z_dim, 1, 1, 1))
        if hasattr(scale, "to"):
            scale = scale.to(dtype=ref.dtype, device=ref.device)
        return scale[0], scale[1]

    def _encode_moments(self, x):
        """Encode ``x`` chunk-wise and return the raw ``(mu, log_var)`` pair."""
        self.clear_cache()
        t = x.shape[2]
        # First chunk holds one frame, every following chunk four frames.
        num_chunks = 1 + (t - 1) // 4
        if num_chunks < 1:
            raise ValueError("encode() needs at least one frame along dim 2")
        chunks = []
        for i in range(num_chunks):
            self._enc_conv_idx = [0]
            frames = x[:, :, :1, :, :] if i == 0 else x[:, :, 1 + 4 * (i - 1):1 + 4 * i, :, :]
            chunks.append(self.encoder(frames,
                                       feat_cache=self._enc_feat_map,
                                       feat_idx=self._enc_conv_idx))
        # Concatenate once instead of re-copying the running result per chunk.
        out = chunks[0] if len(chunks) == 1 else torch.cat(chunks, 2)
        return self.conv1(out).chunk(2, dim=1)

    def _decode_latents(self, z):
        """Decode already de-normalised latents one time step at a time.

        Uses the current decoder cache as-is; callers decide whether to
        reset it first.
        """
        iter_ = z.shape[2]
        if iter_ < 1:
            raise ValueError("decode() needs at least one latent frame along dim 2")
        x = self.conv2(z)
        frames = []
        for i in range(iter_):
            self._conv_idx = [0]
            frames.append(self.decoder(x[:, :, i:i + 1, :, :],
                                       feat_cache=self._feat_map,
                                       feat_idx=self._conv_idx))
        return frames[0] if len(frames) == 1 else torch.cat(frames, 2)

    # ------------------------------------------------------------------
    # public API
    # ------------------------------------------------------------------
    def forward(self, x):
        """Encode, sample and decode ``x`` without latent normalisation.

        Returns:
            ``(x_recon, mu, log_var)``.
        """
        # The previous body called ``encode``/``decode`` without their
        # required ``scale`` argument and therefore always raised.
        mu, log_var = self._encode_moments(x)
        z = self.reparameterize(mu, log_var)
        # ``_encode_moments`` has just reset both caches.
        x_recon = self._decode_latents(z)
        return x_recon, mu, log_var

    def encode(self, x, scale):
        """Return the latent mean of ``x`` normalised as ``(mu - scale[0]) * scale[1]``."""
        mu, _ = self._encode_moments(x)
        shift, factor = self._split_scale(scale, mu)
        return (mu - shift) * factor

    def decode(self, z, scale):
        """Reset the caches, undo the normalisation (``z / scale[1] + scale[0]``) and decode."""
        self.clear_cache()
        # z: [b,c,t,h,w]
        shift, factor = self._split_scale(scale, z)
        return self._decode_latents(z / factor + shift)

    def stream_decode(self, z, scale):
        """Same as :meth:`decode` but keeps the existing decoder cache.

        The cache is not reset here, so it must already exist (see
        :meth:`clear_cache`).
        """
        # z: [b,c,t,h,w]
        shift, factor = self._split_scale(scale, z)
        return self._decode_latents(z / factor + shift)

    def reparameterize(self, mu, log_var):
        """Draw ``mu + eps * exp(0.5 * log_var)`` with ``eps ~ N(0, 1)``."""
        std = torch.exp(0.5 * log_var)
        eps = torch.randn_like(std)
        return eps * std + mu

    def sample(self, imgs, deterministic=False):
        """Return ``mu`` (deterministic) or a sample from the posterior of ``imgs``."""
        mu, log_var = self._encode_moments(imgs)
        if deterministic:
            return mu
        # Clamp the log-variance so that exp() stays finite.
        std = torch.exp(0.5 * log_var.clamp(-30.0, 20.0))
        return mu + std * torch.randn_like(std)

    def clear_cache(self):
        """Reset the decoder cache and, when an encoder exists, the encoder cache."""
        self._conv_num = count_conv3d(self.decoder)
        self._conv_idx = [0]
        self._feat_map = [None] * self._conv_num
        # cache encode
        if self.encoder is not None:
            self._enc_conv_num = count_conv3d(self.encoder)
            self._enc_conv_idx = [0]
            self._enc_feat_map = [None] * self._enc_conv_num
627
+
628
+
629
class WanVideoVAE(nn.Module):
    """Wrapper around ``VideoVAE_`` with latent normalisation and tiled processing."""

    def __init__(self, z_dim=16, dim=96):
        super().__init__()

        # 16-entry latent statistics; passed to the inner model as
        # ``[mean, 1 / std]`` for (de-)normalisation.
        mean = [
            -0.7571, -0.7089, -0.9113, 0.1075, -0.1745, 0.9653, -0.1517, 1.5508,
            0.4134, -0.0715, 0.5517, -0.3632, -0.1922, -0.9497, 0.2503, -0.2921
        ]
        std = [
            2.8184, 1.4541, 2.3275, 2.6558, 1.2196, 1.7708, 2.6052, 2.0743,
            3.2687, 2.1526, 2.8652, 1.5579, 1.6382, 1.1253, 2.8251, 1.9160
        ]
        self.mean = torch.tensor(mean)
        self.std = torch.tensor(std)
        self.scale = [self.mean, 1.0 / self.std]

        # init model
        self.model = VideoVAE_(z_dim=z_dim, dim=dim).eval().requires_grad_(False)
        # Ratio between pixel and latent resolution used by the tiling code.
        self.upsampling_factor = 8

    def build_1d_mask(self, length, left_bound, right_bound, border_width):
        """Return a 1-D blend weight of ``length`` ones with linear ramps.

        A side that is not a boundary (``left_bound`` / ``right_bound`` is
        false) ramps over ``border_width`` samples. A ``border_width`` of 0
        yields no ramp (previously ``x[-0:]`` addressed the whole vector and
        the assignment failed).
        """
        x = torch.ones((length,))
        if border_width > 0:
            ramp = (torch.arange(border_width) + 1) / border_width
            if not left_bound:
                x[:border_width] = ramp
            if not right_bound:
                x[-border_width:] = torch.flip(ramp, dims=(0,))
        return x

    def build_mask(self, data, is_bound, border_width):
        """Return a ``(1, 1, 1, H, W)`` blend mask for a tile shaped like ``data``.

        Args:
            data: 5-D tensor; only its last two dimensions are used.
            is_bound: ``(top, bottom, left, right)`` flags; a true flag means
                that side lies on the frame border and gets no ramp.
            border_width: ``(height_ramp, width_ramp)``.
        """
        _, _, _, H, W = data.shape
        h = self.build_1d_mask(H, is_bound[0], is_bound[1], border_width[0])
        w = self.build_1d_mask(W, is_bound[2], is_bound[3], border_width[1])
        # Element-wise minimum of the row and column ramps.
        mask = torch.minimum(h[:, None], w[None, :])
        return mask[None, None, None]

    @staticmethod
    def _split_tiles(H, W, tile_size, tile_stride):
        """List ``(h0, h1, w0, w1)`` tiles; ``h1``/``w1`` may exceed ``H``/``W``."""
        size_h, size_w = tile_size
        stride_h, stride_w = tile_stride
        tasks = []
        for h in range(0, H, stride_h):
            # Skip a row when the previous row already reached the bottom edge.
            if h - stride_h >= 0 and h - stride_h + size_h >= H:
                continue
            for w in range(0, W, stride_w):
                if w - stride_w >= 0 and w - stride_w + size_w >= W:
                    continue
                tasks.append((h, h + size_h, w, w + size_w))
        return tasks

    def tiled_decode(self, hidden_states, device, tile_size, tile_stride):
        """Decode latents tile by tile and blend the overlaps on the CPU.

        ``tile_size`` and ``tile_stride`` are in latent units. The result is
        clamped to ``[-1, 1]``.
        """
        _, _, T, H, W = hidden_states.shape
        size_h, size_w = tile_size
        stride_h, stride_w = tile_stride
        tasks = self._split_tiles(H, W, tile_size, tile_stride)

        data_device = "cpu"
        computation_device = device
        factor = self.upsampling_factor

        out_T = T * 4 - 3
        weight = torch.zeros((1, 1, out_T, H * factor, W * factor), dtype=hidden_states.dtype, device=data_device)
        values = torch.zeros((1, 3, out_T, H * factor, W * factor), dtype=hidden_states.dtype, device=data_device)
        border_width = ((size_h - stride_h) * factor, (size_w - stride_w) * factor)

        for h, h_, w, w_ in tqdm(tasks, desc="VAE decoding"):
            tile = hidden_states[:, :, :, h:h_, w:w_].to(computation_device)
            tile = self.model.decode(tile, self.scale).to(data_device)

            mask = self.build_mask(
                tile,
                is_bound=(h == 0, h_ >= H, w == 0, w_ >= W),
                border_width=border_width,
            ).to(dtype=hidden_states.dtype, device=data_device)

            target_h = h * factor
            target_w = w * factor
            region = (
                slice(None), slice(None), slice(None),
                slice(target_h, target_h + tile.shape[3]),
                slice(target_w, target_w + tile.shape[4]),
            )
            values[region] += tile * mask
            weight[region] += mask
        values = values / weight
        values = values.clamp_(-1, 1)
        return values

    def tiled_encode(self, video, device, tile_size, tile_stride):
        """Encode a video tile by tile and blend the overlaps on the CPU.

        ``tile_size`` and ``tile_stride`` are in pixel units here.
        """
        _, _, T, H, W = video.shape
        size_h, size_w = tile_size
        stride_h, stride_w = tile_stride
        tasks = self._split_tiles(H, W, tile_size, tile_stride)

        data_device = "cpu"
        computation_device = device
        factor = self.upsampling_factor

        out_T = (T + 3) // 4
        weight = torch.zeros((1, 1, out_T, H // factor, W // factor), dtype=video.dtype, device=data_device)
        values = torch.zeros((1, 16, out_T, H // factor, W // factor), dtype=video.dtype, device=data_device)
        border_width = ((size_h - stride_h) // factor, (size_w - stride_w) // factor)

        for h, h_, w, w_ in tqdm(tasks, desc="VAE encoding"):
            tile = video[:, :, :, h:h_, w:w_].to(computation_device)
            tile = self.model.encode(tile, self.scale).to(data_device)

            mask = self.build_mask(
                tile,
                is_bound=(h == 0, h_ >= H, w == 0, w_ >= W),
                border_width=border_width,
            ).to(dtype=video.dtype, device=data_device)

            target_h = h // factor
            target_w = w // factor
            region = (
                slice(None), slice(None), slice(None),
                slice(target_h, target_h + tile.shape[3]),
                slice(target_w, target_w + tile.shape[4]),
            )
            values[region] += tile * mask
            weight[region] += mask
        values = values / weight
        return values

    def single_encode(self, video, device):
        """Encode one batched video on ``device`` in a single pass."""
        video = video.to(device)
        x = self.model.encode(video, self.scale)
        return x

    def single_decode(self, hidden_state, device):
        """Decode one batched latent on ``device`` and clamp to ``[-1, 1]``."""
        hidden_state = hidden_state.to(device)
        video = self.model.decode(hidden_state, self.scale)
        return video.clamp_(-1, 1)

    def encode(self, videos, device, tiled=False, tile_size=(34, 34), tile_stride=(18, 16)):
        """Encode every video in ``videos`` and stack the resulting latents.

        ``tile_size`` / ``tile_stride`` are given in latent units and are
        converted to pixel units once. They used to be rescaled inside the
        loop, which made every video after the first use tiles 8x too large.
        """
        if tiled:
            factor = self.upsampling_factor
            pixel_tile_size = (tile_size[0] * factor, tile_size[1] * factor)
            pixel_tile_stride = (tile_stride[0] * factor, tile_stride[1] * factor)

        hidden_states = []
        for video in videos:
            video = video.to("cpu").unsqueeze(0)
            if tiled:
                hidden_state = self.tiled_encode(video, device, pixel_tile_size, pixel_tile_stride)
            else:
                hidden_state = self.single_encode(video, device)
            hidden_states.append(hidden_state.squeeze(0))
        return torch.stack(hidden_states)

    def decode(self, hidden_states, device, tiled=False, tile_size=(34, 34), tile_stride=(18, 16)):
        """Decode every latent in ``hidden_states`` and stack the resulting videos."""
        videos = []
        for hidden_state in hidden_states:
            hidden_state = hidden_state.to("cpu").unsqueeze(0)
            if tiled:
                video = self.tiled_decode(hidden_state, device, tile_size, tile_stride)
            else:
                video = self.single_decode(hidden_state, device)
            videos.append(video.squeeze(0))
        return torch.stack(videos)

    def clear_cache(self):
        """Reset the feature caches of the wrapped model."""
        self.model.clear_cache()

    def stream_decode(self, hidden_states, tiled=False, tile_size=(34, 34), tile_stride=(18, 16)):
        """Decode the single latent in ``hidden_states`` without resetting the cache.

        The tiling arguments are accepted but not used by this method.
        """
        hidden_states = list(hidden_states)
        assert len(hidden_states) == 1
        return self.model.stream_decode(hidden_states[0], self.scale)

    @staticmethod
    def state_dict_converter():
        """Return the converter used to adapt external checkpoints."""
        return WanVideoVAEStateDictConverter()
834
+
835
+
836
class WanVideoVAEStateDictConverter:
    """Adapts external VAE checkpoints to the key layout of ``WanVideoVAE``."""

    def __init__(self):
        pass

    def from_civitai(self, state_dict):
        """Prefix every key with ``'model.'``.

        When the checkpoint nests its weights under ``'model_state'``, that
        inner mapping is converted instead of the outer one.
        """
        weights = state_dict['model_state'] if 'model_state' in state_dict else state_dict
        return {'model.' + name: weights[name] for name in weights}
custom_nodes/ComfyUI-FlashVSR_Ultra_Fast/src/pipelines/__init__.py ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ from .flashvsr_full import FlashVSRFullPipeline
2
+ from .flashvsr_tiny import FlashVSRTinyPipeline
3
+ from .flashvsr_tiny_long import FlashVSRTinyLongPipeline
custom_nodes/ComfyUI-FlashVSR_Ultra_Fast/src/pipelines/base.py ADDED
@@ -0,0 +1,130 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import gc
3
+ import numpy as np
4
+ from PIL import Image
5
+ from torchvision.transforms import GaussianBlur
6
+
7
class BasePipeline(torch.nn.Module):
    """Shared helpers for pipelines: size rounding, image conversion, offloading, noise."""

    def __init__(self, device="cuda", torch_dtype=torch.float16, height_division_factor=64, width_division_factor=64):
        super().__init__()
        self.device = device
        self.torch_dtype = torch_dtype
        self.height_division_factor = height_division_factor
        self.width_division_factor = width_division_factor
        self.cpu_offload = False
        # Attribute names of the sub-models managed by load_models_to_device().
        self.model_names = []

    def check_resize_height_width(self, height, width):
        """Round ``height`` and ``width`` up to multiples of their division factors."""
        if height % self.height_division_factor != 0:
            height = (height + self.height_division_factor - 1) // self.height_division_factor * self.height_division_factor
            print(f"The height cannot be evenly divided by {self.height_division_factor}. We round it up to {height}.")
        if width % self.width_division_factor != 0:
            width = (width + self.width_division_factor - 1) // self.width_division_factor * self.width_division_factor
            print(f"The width cannot be evenly divided by {self.width_division_factor}. We round it up to {width}.")
        return height, width

    def preprocess_image(self, image):
        """Convert an image into a ``(1, C, H, W)`` float tensor scaled from [0, 255] to [-1, 1]."""
        image = torch.Tensor(np.array(image, dtype=np.float32) * (2 / 255) - 1).permute(2, 0, 1).unsqueeze(0)
        return image

    def preprocess_images(self, images):
        """Apply :meth:`preprocess_image` to every element of ``images``."""
        return [self.preprocess_image(image) for image in images]

    def vae_output_to_image(self, vae_output):
        """Convert the first element of ``vae_output`` (values in [-1, 1]) to a PIL image."""
        image = vae_output[0].cpu().float().permute(1, 2, 0).numpy()
        image = Image.fromarray(((image / 2 + 0.5).clip(0, 1) * 255).astype("uint8"))
        return image

    def vae_output_to_video(self, vae_output):
        """Convert ``vae_output`` (values in [-1, 1]) to a list of PIL images.

        NOTE(review): ``permute(1, 2, 0)`` only accepts a 3-D tensor; confirm
        the layout callers pass in.
        """
        # ``.float()`` mirrors vae_output_to_image and lets reduced-precision
        # tensors be converted to NumPy.
        video = vae_output.cpu().float().permute(1, 2, 0).numpy()
        video = [Image.fromarray(((image / 2 + 0.5).clip(0, 1) * 255).astype("uint8")) for image in video]
        return video

    def merge_latents(self, value, latents, masks, scales, blur_kernel_size=33, blur_sigma=10.0):
        """Blend ``latents`` into ``value`` using blurred masks and per-mask scales.

        ``value`` is updated in place and also returned. Nothing happens when
        ``latents`` is empty.
        """
        if len(latents) > 0:
            blur = GaussianBlur(kernel_size=blur_kernel_size, sigma=blur_sigma)
            height, width = value.shape[-2:]
            weight = torch.ones_like(value)
            for latent, mask, scale in zip(latents, masks, scales):
                # Binarise the resized mask, then soften its edge with the blur.
                mask = self.preprocess_image(mask.resize((width, height))).mean(dim=1, keepdim=True) > 0
                mask = mask.repeat(1, latent.shape[1], 1, 1).to(dtype=latent.dtype, device=latent.device)
                mask = blur(mask)
                value += latent * mask * scale
                weight += mask * scale
            value /= weight
        return value

    def control_noise_via_local_prompts(self, prompt_emb_global, prompt_emb_locals, masks, mask_scales, inference_callback, special_kwargs=None, special_local_kwargs_list=None):
        """Run ``inference_callback`` for the global and local prompts and merge the results."""
        if special_kwargs is None:
            noise_pred_global = inference_callback(prompt_emb_global)
        else:
            noise_pred_global = inference_callback(prompt_emb_global, special_kwargs)
        if special_local_kwargs_list is None:
            noise_pred_locals = [inference_callback(prompt_emb_local) for prompt_emb_local in prompt_emb_locals]
        else:
            noise_pred_locals = [
                inference_callback(prompt_emb_local, local_kwargs)
                for prompt_emb_local, local_kwargs in zip(prompt_emb_locals, special_local_kwargs_list)
            ]
        noise_pred = self.merge_latents(noise_pred_global, noise_pred_locals, masks, mask_scales)
        return noise_pred

    def extend_prompt(self, prompt, local_prompts, masks, mask_scales):
        """Extend ``prompt`` through ``self.prompter`` and append what it yields.

        The incoming lists are copied first, so the caller's lists are no
        longer modified in place.

        Returns:
            ``(prompt, local_prompts, masks, mask_scales)``.
        """
        local_prompts = list(local_prompts or [])
        masks = list(masks or [])
        mask_scales = list(mask_scales or [])
        extended_prompt_dict = self.prompter.extend_prompt(prompt)
        prompt = extended_prompt_dict.get("prompt", prompt)
        extra_masks = extended_prompt_dict.get("masks", [])
        local_prompts += extended_prompt_dict.get("prompts", [])
        masks += extra_masks
        # Every mask supplied by the prompter gets a fixed scale of 100.
        mask_scales += [100.0] * len(extra_masks)
        return prompt, local_prompts, masks, mask_scales

    def enable_cpu_offload(self):
        """Turn on model offloading for :meth:`load_models_to_device`."""
        self.cpu_offload = True

    def load_models_to_device(self, loadmodel_names=None):
        """Keep only ``loadmodel_names`` on ``self.device`` and offload the rest.

        Does nothing unless CPU offload has been enabled.
        """
        # only load models to device if cpu_offload is enabled
        if not self.cpu_offload:
            return
        if loadmodel_names is None:
            loadmodel_names = []
        # offload the unneeded models to cpu
        for model_name in self.model_names:
            if model_name in loadmodel_names:
                continue
            model = getattr(self, model_name)
            if model is None:
                continue
            if getattr(model, "vram_management_enabled", False):
                for module in model.modules():
                    if hasattr(module, "offload"):
                        module.offload()
            else:
                model.cpu()
        # load the needed models to device
        for model_name in loadmodel_names:
            model = getattr(self, model_name)
            if model is None:
                continue
            if getattr(model, "vram_management_enabled", False):
                for module in model.modules():
                    if hasattr(module, "onload"):
                        module.onload()
            else:
                model.to(self.device)
        # release cached accelerator memory
        if torch.cuda.is_available():
            torch.cuda.empty_cache()
        if torch.backends.mps.is_available():
            torch.mps.empty_cache()

    def generate_noise(self, shape, seed=None, device="cpu", dtype=torch.float16):
        """Return standard-normal noise of ``shape``; reproducible when ``seed`` is given."""
        generator = None if seed is None else torch.Generator(device).manual_seed(seed)
        noise = torch.randn(shape, generator=generator, device=device, dtype=dtype)
        return noise
+ return noise
130
+
custom_nodes/ComfyUI-FlashVSR_Ultra_Fast/src/pipelines/flashvsr_full.py ADDED
@@ -0,0 +1,618 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import types
2
+ import os
3
+ import time
4
+ from typing import Optional, Tuple, Literal
5
+
6
+ import torch
7
+ import torch.nn as nn
8
+ import torch.nn.functional as F
9
+ import numpy as np
10
+ from einops import rearrange
11
+ from PIL import Image
12
+ from tqdm import tqdm
13
+ # import pyfiglet
14
+
15
+ from ..models import ModelManager
16
+ from ..models.utils import clean_vram
17
+ from ..models.wan_video_dit import WanModel, RMSNorm, sinusoidal_embedding_1d
18
+ from ..models.wan_video_vae import WanVideoVAE, RMS_norm, CausalConv3d, Upsample
19
+ from ..schedulers.flow_match import FlowMatchScheduler
20
+ from .base import BasePipeline
21
+
22
+
23
+ # -----------------------------
24
+ # 基础工具:ADAIN 所需的统计量(保留以备需要;管线默认用 wavelet)
25
+ # -----------------------------
26
+ def _calc_mean_std(feat: torch.Tensor, eps: float = 1e-5) -> Tuple[torch.Tensor, torch.Tensor]:
27
+ assert feat.dim() == 4, 'feat 必须是 (N, C, H, W)'
28
+ N, C = feat.shape[:2]
29
+ var = feat.view(N, C, -1).var(dim=2, unbiased=False) + eps
30
+ std = var.sqrt().view(N, C, 1, 1)
31
+ mean = feat.view(N, C, -1).mean(dim=2).view(N, C, 1, 1)
32
+ return mean, std
33
+
34
+
35
+ def _adain(content_feat: torch.Tensor, style_feat: torch.Tensor) -> torch.Tensor:
36
+ assert content_feat.shape[:2] == style_feat.shape[:2], "ADAIN: N、C 必须匹配"
37
+ size = content_feat.size()
38
+ style_mean, style_std = _calc_mean_std(style_feat)
39
+ content_mean, content_std = _calc_mean_std(content_feat)
40
+ normalized = (content_feat - content_mean.expand(size)) / content_std.expand(size)
41
+ return normalized * style_std.expand(size) + style_mean.expand(size)
42
+
43
+
44
+ # -----------------------------
45
+ # 小波式模糊与分解/重构(ColorCorrector 用)
46
+ # -----------------------------
47
+ def _make_gaussian3x3_kernel(dtype, device) -> torch.Tensor:
48
+ vals = [
49
+ [0.0625, 0.125, 0.0625],
50
+ [0.125, 0.25, 0.125 ],
51
+ [0.0625, 0.125, 0.0625],
52
+ ]
53
+ return torch.tensor(vals, dtype=dtype, device=device)
54
+
55
+
56
+ def _wavelet_blur(x: torch.Tensor, radius: int) -> torch.Tensor:
57
+ assert x.dim() == 4, 'x 必须是 (N, C, H, W)'
58
+ N, C, H, W = x.shape
59
+ base = _make_gaussian3x3_kernel(x.dtype, x.device)
60
+ weight = base.view(1, 1, 3, 3).repeat(C, 1, 1, 1)
61
+ pad = radius
62
+ x_pad = F.pad(x, (pad, pad, pad, pad), mode='replicate')
63
+ out = F.conv2d(x_pad, weight, bias=None, stride=1, padding=0, dilation=radius, groups=C)
64
+ return out
65
+
66
+
67
+ def _wavelet_decompose(x: torch.Tensor, levels: int = 5) -> Tuple[torch.Tensor, torch.Tensor]:
68
+ assert x.dim() == 4, 'x 必须是 (N, C, H, W)'
69
+ high = torch.zeros_like(x)
70
+ low = x
71
+ for i in range(levels):
72
+ radius = 2 ** i
73
+ blurred = _wavelet_blur(low, radius)
74
+ high = high + (low - blurred)
75
+ low = blurred
76
+ return high, low
77
+
78
+
79
def _wavelet_reconstruct(content: torch.Tensor, style: torch.Tensor, levels: int = 5) -> torch.Tensor:
    """Combine the high band of ``content`` with the low band of ``style``."""
    content_detail = _wavelet_decompose(content, levels=levels)[0]
    style_base = _wavelet_decompose(style, levels=levels)[1]
    return content_detail + style_base
83
+
84
+
85
+ # -----------------------------
86
+ # 无状态颜色矫正模块(视频友好,默认 wavelet)
87
+ # -----------------------------
88
class TorchColorCorrectorWavelet(nn.Module):
    """Stateless colour correction of a video against a reference video.

    Frames are processed independently: the time axis is folded into the
    batch axis, the 2-D correction is applied, and the result is unfolded.
    """

    def __init__(self, levels: int = 5):
        super().__init__()
        # Number of decomposition levels used by the 'wavelet' method.
        self.levels = levels

    @staticmethod
    def _flatten_time(x: torch.Tensor) -> Tuple[torch.Tensor, int, int]:
        """Fold ``(B, C, f, H, W)`` into ``(B * f, C, H, W)``; also return ``B`` and ``f``."""
        assert x.dim() == 5, '输入必须是 (B, C, f, H, W)'
        B, C, f, H, W = x.shape
        y = x.permute(0, 2, 1, 3, 4).reshape(B * f, C, H, W)
        return y, B, f

    @staticmethod
    def _unflatten_time(y: torch.Tensor, B: int, f: int) -> torch.Tensor:
        """Inverse of :meth:`_flatten_time`: ``(B * f, C, H, W)`` back to ``(B, C, f, H, W)``."""
        BF, C, H, W = y.shape
        assert BF == B * f
        return y.reshape(B, f, C, H, W).permute(0, 2, 1, 3, 4)

    def _correct_frames(
        self,
        hq: torch.Tensor,
        lq: torch.Tensor,
        clip_range: Tuple[float, float],
        method: str,
    ) -> torch.Tensor:
        """Correct one ``(B, C, f, H, W)`` clip (or chunk) and clamp the result."""
        hq4, batch, frames = self._flatten_time(hq)
        lq4, _, _ = self._flatten_time(lq)
        if method == 'wavelet':
            out4 = _wavelet_reconstruct(hq4, lq4, levels=self.levels)
        elif method == 'adain':
            out4 = _adain(hq4, lq4)
        else:
            raise ValueError(f"未知 method: {method}")
        out4 = torch.clamp(out4, *clip_range)
        return self._unflatten_time(out4, batch, frames)

    def forward(
        self,
        hq_image: torch.Tensor,   # (B, C, f, H, W)
        lq_image: torch.Tensor,   # (B, C, f, H, W)
        clip_range: Tuple[float, float] = (-1.0, 1.0),
        method: Literal['wavelet', 'adain'] = 'wavelet',
        chunk_size: Optional[int] = None,
    ) -> torch.Tensor:
        """Return ``hq_image`` colour-corrected towards ``lq_image``.

        Args:
            hq_image: Video to correct, ``(B, 3, f, H, W)``.
            lq_image: Reference video with the same shape.
            clip_range: ``(min, max)`` the result is clamped to.
            method: ``'wavelet'`` or ``'adain'``.
            chunk_size: Frames per processing chunk. ``None`` (or a value
                covering all frames) processes the clip in one go.

        Raises:
            ValueError: Unknown ``method`` or non-positive ``chunk_size``.
        """
        assert hq_image.shape == lq_image.shape, "HQ 与 LQ 的形状必须一致"
        assert hq_image.dim() == 5 and hq_image.shape[1] == 3, "输入必须是 (B, 3, f, H, W)"

        f = hq_image.shape[2]
        if chunk_size is None or chunk_size >= f:
            return self._correct_frames(hq_image, lq_image, clip_range, method)

        if chunk_size <= 0:
            # range() would reject a zero step and an empty range would make
            # torch.cat fail; report the real cause instead.
            raise ValueError(f"chunk_size must be a positive integer, got {chunk_size}")

        # Slicing clips at the end of the axis, so the last chunk may be shorter.
        outs = [
            self._correct_frames(
                hq_image[:, :, start:start + chunk_size],
                lq_image[:, :, start:start + chunk_size],
                clip_range,
                method,
            )
            for start in range(0, f, chunk_size)
        ]
        return torch.cat(outs, dim=2)
149
+
150
+
151
+ # -----------------------------
152
+ # 简化版 Pipeline(仅 dit + vae)
153
+ # -----------------------------
154
+ class FlashVSRFullPipeline(BasePipeline):
155
+
156
def __init__(self, device="cuda", torch_dtype=torch.float16):
    """Set up scheduler, model slots and colour corrector, then print a banner.

    The models themselves (``dit`` and ``vae``) are left as ``None`` here.
    """
    super().__init__(device=device, torch_dtype=torch_dtype)
    self.scheduler = FlowMatchScheduler(shift=5, sigma_min=0.0, extra_one_step=True)
    self.dit: WanModel = None
    self.vae: WanVideoVAE = None
    # Attribute names of the models handled by load_models_to_device().
    self.model_names = ['dit', 'vae']
    # Overrides the division factors passed to the base class.
    self.height_division_factor = 16
    self.width_division_factor = 16
    self.use_unified_sequence_parallel = False
    # Filled by init_cross_kv().
    self.prompt_emb_posi = None
    self.ColorCorrector = TorchColorCorrectorWavelet(levels=5)

    print(r"""
███████╗██╗ █████╗ ███████╗██╗ ██╗██╗ ██╗███████╗█████╗
██╔════╝██║ ██╔══██╗██╔════╝██║ ██║██║ ██║██╔════╝██╔══██╗ ██╗
█████╗ ██║ ███████║███████╗███████║╚██╗ ██╔╝███████╗███████║ ██████╗
██╔══╝ ██║ ██╔══██║╚════██║██╔══██║ ╚████╔╝ ╚════██║██╔═██║ ██╔═╝
██║ ███████╗██║ ██║███████║██║ ██║ ╚██╔╝ ███████║██║ ██║ ╚═╝
╚═╝ ╚══════╝╚═╝ ╚═╝╚══════╝╚═╝ ╚═╝ ╚═╝ ╚══════╝╚═╝ ╚═╝
""")
176
+
177
def enable_vram_management(self, num_persistent_param_in_dit=None):
    """Wrap the layers of ``dit`` and ``vae`` for offloading and enable CPU offload.

    Args:
        num_persistent_param_in_dit: Passed on as ``max_num_param`` for the
            DiT; layers beyond that budget get the overflow configuration.
    """
    def placement(param_dtype, onload_device):
        # Only the on-load device differs between the regular and the
        # overflow configuration; everything else is shared.
        return dict(
            offload_dtype=param_dtype,
            offload_device="cpu",
            onload_dtype=param_dtype,
            onload_device=onload_device,
            computation_dtype=self.torch_dtype,
            computation_device=self.device,
        )

    # Only dit / vae are managed.
    dit_dtype = next(iter(self.dit.parameters())).dtype
    from ..vram_management import enable_vram_management, AutoWrappedModule, AutoWrappedLinear

    dit_module_map = {
        torch.nn.Linear: AutoWrappedLinear,
        torch.nn.Conv3d: AutoWrappedModule,
        torch.nn.LayerNorm: AutoWrappedModule,
        RMSNorm: AutoWrappedModule,
    }
    enable_vram_management(
        self.dit,
        module_map=dit_module_map,
        module_config=placement(dit_dtype, self.device),
        max_num_param=num_persistent_param_in_dit,
        overflow_module_config=placement(dit_dtype, "cpu"),
    )

    vae_dtype = next(iter(self.vae.parameters())).dtype
    vae_module_map = {
        torch.nn.Linear: AutoWrappedLinear,
        torch.nn.Conv2d: AutoWrappedModule,
        RMS_norm: AutoWrappedModule,
        CausalConv3d: AutoWrappedModule,
        Upsample: AutoWrappedModule,
        torch.nn.SiLU: AutoWrappedModule,
        torch.nn.Dropout: AutoWrappedModule,
    }
    enable_vram_management(
        self.vae,
        module_map=vae_module_map,
        module_config=placement(vae_dtype, self.device),
    )
    self.enable_cpu_offload()
229
+
230
def fetch_models(self, model_manager: ModelManager):
    """Pull the DiT and the VAE out of ``model_manager`` and store them on the pipeline."""
    fetch = model_manager.fetch_model
    self.dit = fetch("wan_video_dit")
    self.vae = fetch("wan_video_vae")
233
+
234
@staticmethod
def from_model_manager(model_manager: ModelManager, torch_dtype=None, device=None, use_usp=False):
    """Build a pipeline whose models come from ``model_manager``.

    ``device`` and ``torch_dtype`` fall back to the manager's own values.
    ``use_usp`` is accepted but not used: unified sequence parallelism is
    always switched off here.
    """
    device = model_manager.device if device is None else device
    torch_dtype = model_manager.torch_dtype if torch_dtype is None else torch_dtype
    pipeline = FlashVSRFullPipeline(device=device, torch_dtype=torch_dtype)
    pipeline.fetch_models(model_manager)
    # Optional unified-sequence-parallel entry point (disabled by default here).
    pipeline.use_unified_sequence_parallel = False
    return pipeline
243
+
244
def denoising_model(self):
    """Return the model used for denoising, i.e. ``self.dit``."""
    return self.dit
246
+
247
+ # -------------------------
248
+ # 新增:显式 KV 预初始化函数
249
+ # -------------------------
250
def init_cross_kv(
    self,
    context_tensor: Optional[torch.Tensor] = None,
    prompt_path=None,
):
    """Load the fixed text context and initialise the cross-attention KV cache.

    Must be called explicitly once before ``__call__``. Besides the KV cache
    this also caches the timestep embedding/modulation for timestep 1000 and
    configures the scheduler for a single step.

    Args:
        context_tensor: Ready-made context tensor. Takes precedence over
            ``prompt_path``.
        prompt_path: File to ``torch.load`` the context from when no tensor
            is given.

    Raises:
        RuntimeError: ``self.dit`` has not been set yet.
        ValueError: Neither ``context_tensor`` nor ``prompt_path`` was given.
        AttributeError: The DiT lacks ``reinit_cross_kv``.
    """
    # Validate first so that a bad call does not shuffle models between devices.
    if self.dit is None:
        raise RuntimeError("请先通过 fetch_models / from_model_manager 初始化 self.dit")
    if context_tensor is None and prompt_path is None:
        raise ValueError("init_cross_kv: 需要提供 prompt_path 或 context_tensor 其一")

    self.load_models_to_device(["dit"])

    if context_tensor is None:
        # NOTE(security): torch.load unpickles the file; only load prompt
        # tensors from trusted sources.
        ctx = torch.load(prompt_path, map_location=self.device)
    else:
        ctx = context_tensor

    ctx = ctx.to(dtype=self.torch_dtype, device=self.device)

    if self.prompt_emb_posi is None:
        self.prompt_emb_posi = {}
    self.prompt_emb_posi['context'] = ctx
    self.prompt_emb_posi['stats'] = "load"

    if hasattr(self.dit, "reinit_cross_kv"):
        self.dit.reinit_cross_kv(ctx)
    else:
        raise AttributeError("WanModel 缺少 reinit_cross_kv(ctx) 方法,请在模型实现中加入该能力。")

    # Cache the embedding and modulation of the single timestep that is used.
    self.timestep = torch.tensor([1000.], device=self.device, dtype=self.torch_dtype)
    self.t = self.dit.time_embedding(sinusoidal_embedding_1d(self.dit.freq_dim, self.timestep))
    self.t_mod = self.dit.time_projection(self.t).unflatten(1, (6, self.dit.dim))
    # Scheduler
    self.scheduler.set_timesteps(1, denoising_strength=1.0, shift=5.0)
    self.load_models_to_device([])
288
+
289
def prepare_unified_sequence_parallel(self):
    """Return the unified-sequence-parallel flag wrapped in a keyword dict."""
    return {"use_unified_sequence_parallel": self.use_unified_sequence_parallel}
291
+
292
def prepare_extra_input(self, latents=None):
    """Return extra model inputs; always an empty dict (``latents`` is unused)."""
    return {}
294
+
295
def encode_video(self, input_video, tiled=True, tile_size=(34, 34), tile_stride=(18, 16)):
    """Encode ``input_video`` with the VAE on ``self.device`` and return the latents."""
    tiler_kwargs = dict(tiled=tiled, tile_size=tile_size, tile_stride=tile_stride)
    return self.vae.encode(input_video, device=self.device, **tiler_kwargs)
298
+
299
def decode_video(self, latents, tiled=True, tile_size=(34, 34), tile_stride=(18, 16)):
    """Decode ``latents`` with the VAE on ``self.device`` and return the frames."""
    tiler_kwargs = dict(tiled=tiled, tile_size=tile_size, tile_stride=tile_stride)
    return self.vae.decode(latents, device=self.device, **tiler_kwargs)
302
+
303
def offload_model(self, keep_vae=False):
    """Drop the cross-attention KV cache and offload the models.

    Args:
        keep_vae: When true the VAE is requested to stay loaded; otherwise
            no model is.
    """
    self.dit.clear_cross_kv()
    # Recorded so that a later step can tell the cache has to be rebuilt.
    self.prompt_emb_posi['stats'] = "offload"
    if hasattr(self.dit, "LQ_proj_in"):
        self.dit.LQ_proj_in.to('cpu')
    resident = ["vae"] if keep_vae else []
    self.load_models_to_device(resident)
312
+
313
@torch.no_grad()
def __call__(
    self,
    prompt=None,
    negative_prompt="",
    denoising_strength=1.0,
    seed=None,
    rand_device="gpu",
    height=480,
    width=832,
    num_frames=81,
    cfg_scale=5.0,
    num_inference_steps=50,
    sigma_shift=5.0,
    tiled=True,
    tile_size=(60, 104),
    tile_stride=(30, 52),
    tea_cache_l1_thresh=None,
    tea_cache_model_id="Wan2.1-T2V-1.3B",
    progress_bar_cmd=tqdm,
    progress_bar_st=None,
    LQ_video=None,
    is_full_block=False,
    if_buffer=False,
    topk_ratio=2.0,
    kv_ratio=3.0,
    local_range = 9,
    color_fix = True,
    unload_dit = False,
    force_offload = False,
):
    """Run streaming one-step restoration and return the decoded frames.

    ``init_cross_kv`` must have been called first. The latent noise is
    processed chunk by chunk: the first chunk covers 6 latent frames, every
    later chunk 2. Per-block KV caches returned by ``model_fn_wan_video`` are
    carried from one chunk to the next.

    Args:
        seed, height, width, num_frames: control the shape and seeding of the
            initial noise; ``num_frames`` is rounded so that
            ``num_frames % 4 == 1``.
        cfg_scale: must be exactly 1.0 (the default of 5.0 is rejected).
        tiled, tile_size, tile_stride: forwarded to ``decode_video``.
        progress_bar_cmd: callable wrapping the chunk iterator.
        LQ_video: low-quality input indexed as ``[:, :, frame, :, :]``; when
            ``None`` no LQ conditioning and no colour correction is applied.
        is_full_block, topk_ratio, kv_ratio, local_range: forwarded to
            ``model_fn_wan_video``.
        if_buffer: when true the noise has one latent frame fewer.
        color_fix: apply ``self.ColorCorrector`` (AdaIN) against ``LQ_video``.
        unload_dit: offload the DiT (keeping the VAE) before decoding.
        force_offload: offload everything after decoding.
        prompt, negative_prompt, denoising_strength, rand_device,
        num_inference_steps, sigma_shift, tea_cache_l1_thresh,
        tea_cache_model_id, progress_bar_st: accepted but not read here.

    Returns:
        ``frames[0]``, the first batch element of the decoded video.

    Raises:
        AssertionError: if ``cfg_scale != 1.0``.
        RuntimeError: if ``init_cross_kv`` has not been called.
    """
    # Only cfg_scale == 1.0 is supported.
    assert cfg_scale == 1.0, "cfg_scale must be 1.0"

    # init_cross_kv() must have been called beforehand.
    if self.prompt_emb_posi is None or 'context' not in self.prompt_emb_posi:
        raise RuntimeError(
            "Cross-Attn KV 未初始化。请在调用 __call__ 前先执行:\n"
            " pipe.init_cross_kv()\n"
            "或传入自定义 context:\n"
            " pipe.init_cross_kv(context_tensor=your_context_tensor)"
        )

    if num_frames % 4 != 1:
        num_frames = (num_frames + 2) // 4 * 4 + 1
        # The original message stated the inverted condition (`!= 1`).
        print(f"Only `num_frames % 4 == 1` is acceptable. We round it up to {num_frames}.")

    # Tiling options for the VAE decoder.
    tiler_kwargs = {"tiled": tiled, "tile_size": tile_size, "tile_stride": tile_stride}

    # Initial noise; buffered mode uses one latent frame fewer.
    latent_frames = (num_frames - 1) // 4 + (0 if if_buffer else 1)
    latents = self.generate_noise(
        (1, 16, latent_frames, height//8, width//8),
        seed=seed, device=self.device, dtype=self.torch_dtype,
    )

    process_total_num = (num_frames - 1) // 8 - 2
    is_stream = True

    if self.prompt_emb_posi['stats'] == "offload":
        self.init_cross_kv(context_tensor=self.prompt_emb_posi['context'])
    self.load_models_to_device(["dit", "vae"])
    self.dit.LQ_proj_in.to(self.device)

    # Drop any LQ_proj_in cache left over from a previous run.
    if hasattr(self.dit, "LQ_proj_in"):
        self.dit.LQ_proj_in.clear_cache()

    latents_total = []
    self.vae.clear_cache()

    def project_lq(frame_ranges):
        """Stream each (start, stop) LQ frame range through LQ_proj_in and
        concatenate the per-layer outputs along dim 1 (None without LQ_video)."""
        if LQ_video is None:
            return None
        merged = None
        for start, stop in frame_ranges:
            cur = self.denoising_model().LQ_proj_in.stream_forward(
                LQ_video[:, :, start:stop, :, :]
            )
            if cur is None:
                continue
            if merged is None:
                merged = cur
            else:
                for layer_idx in range(len(merged)):
                    merged[layer_idx] = torch.cat([merged[layer_idx], cur[layer_idx]], dim=1)
        return merged

    # The method already runs under @torch.no_grad(); no inner context needed.
    for cur_process_idx in progress_bar_cmd(range(process_total_num)):
        if cur_process_idx == 0:
            pre_cache_k = [None] * len(self.dit.blocks)
            pre_cache_v = [None] * len(self.dit.blocks)
            LQ_latents = project_lq(
                [(max(0, i*4-3), (i+1)*4-3) for i in range(7)]
            )
            cur_latents = latents[:, :, :6, :, :]
        else:
            base = cur_process_idx*8
            LQ_latents = project_lq(
                [(base+17+i*4, base+21+i*4) for i in range(2)]
            )
            cur_latents = latents[:, :, 4+cur_process_idx*2:6+cur_process_idx*2, :, :]

        # Inference (no motion_controller / vace).
        noise_pred_posi, pre_cache_k, pre_cache_v = model_fn_wan_video(
            self.dit,
            x=cur_latents,
            timestep=self.timestep,
            context=None,
            tea_cache=None,
            use_unified_sequence_parallel=False,
            LQ_latents=LQ_latents,
            is_full_block=is_full_block,
            is_stream=is_stream,
            pre_cache_k=pre_cache_k,
            pre_cache_v=pre_cache_v,
            topk_ratio=topk_ratio,
            kv_ratio=kv_ratio,
            cur_process_idx=cur_process_idx,
            t_mod=self.t_mod,
            t=self.t,
            local_range = local_range,
        )

        # Single-step update of the latent chunk.
        cur_latents = cur_latents - noise_pred_posi
        latents_total.append(cur_latents)

    if hasattr(self.dit, "LQ_proj_in"):
        self.dit.LQ_proj_in.clear_cache()

    if unload_dit and hasattr(self, 'dit') and not next(self.dit.parameters()).is_cpu:
        print("[FlashVSR] Offloading DiT to the CPU to free up VRAM...")
        self.offload_model(keep_vae=True)

    latents = torch.cat(latents_total, dim=2)

    # Decode
    print("[FlashVSR] Starting VAE decoding...")
    frames = self.decode_video(latents, **tiler_kwargs)

    self.vae.clear_cache()
    if force_offload:
        self.offload_model()

    # Colour correction (AdaIN) stays best-effort, but failures are reported
    # instead of being swallowed by a bare `except`.
    if color_fix and LQ_video is not None:
        try:
            frames = self.ColorCorrector(
                frames.to(device=LQ_video.device),
                LQ_video[:, :, :frames.shape[2], :, :],
                clip_range=(-1, 1),
                chunk_size=16,
                method='adain'
            )
        except Exception as exc:
            print(f"[FlashVSR] Color correction skipped: {exc}")

    return frames[0]
477
+
478
+
479
+ # -----------------------------
480
+ # TeaCache(保留原逻辑;此处默认不启用)
481
+ # -----------------------------
482
class TeaCache:
    """Step-skipping cache driven by the change of the modulation input.

    ``check`` accumulates a polynomially rescaled relative L1 change of
    ``t_mod`` between steps and reports whether the transformer blocks may be
    skipped; ``store`` records the residual produced by a computed step and
    ``update`` re-applies that residual on a skipped step.
    """

    def __init__(self, num_inference_steps, rel_l1_thresh, model_id):
        self.num_inference_steps = num_inference_steps
        self.rel_l1_thresh = rel_l1_thresh
        self.step = 0
        self.accumulated_rel_l1_distance = 0
        self.previous_modulated_input = None
        self.previous_residual = None
        self.previous_hidden_states = None

        # Per-model polynomial coefficients (highest power first, np.poly1d order).
        self.coefficients_dict = {
            "Wan2.1-T2V-1.3B": [-5.21862437e+04, 9.23041404e+03, -5.28275948e+02, 1.36987616e+01, -4.99875664e-02],
            "Wan2.1-T2V-14B": [-3.03318725e+05, 4.90537029e+04, -2.65530556e+03, 5.87365115e+01, -3.15583525e-01],
            "Wan2.1-I2V-14B-480P": [2.57151496e+05, -3.54229917e+04, 1.40286849e+03, -1.35890334e+01, 1.32517977e-01],
            "Wan2.1-I2V-14B-720P": [8.10705460e+03, 2.13393892e+03, -3.72934672e+02, 1.66203073e+01, -4.17769401e-02],
        }
        if model_id not in self.coefficients_dict:
            supported_model_ids = ", ".join(self.coefficients_dict)
            raise ValueError(f"{model_id} is not a supported TeaCache model id. Please choose a valid model id in ({supported_model_ids}).")
        self.coefficients = self.coefficients_dict[model_id]

    def check(self, dit: WanModel, x, t_mod):
        """Advance one step and return True when the blocks may be skipped.

        The first and last step of a cycle are always computed. On a computed
        step a copy of ``x`` is kept for ``store``. ``dit`` is not read.
        """
        modulated_inp = t_mod.clone()
        if self.step in (0, self.num_inference_steps - 1):
            should_calc = True
            self.accumulated_rel_l1_distance = 0
        else:
            previous = self.previous_modulated_input
            rel_change = ((modulated_inp-previous).abs().mean() / previous.abs().mean()).cpu().item()
            self.accumulated_rel_l1_distance += np.poly1d(self.coefficients)(rel_change)
            # Written as `not (<)` so a NaN distance still forces a recompute.
            should_calc = not (self.accumulated_rel_l1_distance < self.rel_l1_thresh)
            if should_calc:
                self.accumulated_rel_l1_distance = 0
        self.previous_modulated_input = modulated_inp
        self.step = (self.step + 1) % self.num_inference_steps
        if should_calc:
            self.previous_hidden_states = x.clone()
        return not should_calc

    def store(self, hidden_states):
        """Record the residual between ``hidden_states`` and the saved input."""
        self.previous_residual = hidden_states - self.previous_hidden_states
        self.previous_hidden_states = None

    def update(self, hidden_states):
        """Return ``hidden_states`` plus the most recently stored residual."""
        return hidden_states + self.previous_residual
528
+
529
+
530
+ # -----------------------------
531
+ # 简化版模型前向封装(无 vace / 无 motion_controller)
532
+ # -----------------------------
533
def model_fn_wan_video(
    dit: WanModel,
    x: torch.Tensor,
    timestep: torch.Tensor,
    context: torch.Tensor,
    tea_cache: Optional[TeaCache] = None,
    use_unified_sequence_parallel: bool = False,
    LQ_latents: Optional[torch.Tensor] = None,
    is_full_block: bool = False,
    is_stream: bool = False,
    pre_cache_k: Optional[list[torch.Tensor]] = None,
    pre_cache_v: Optional[list[torch.Tensor]] = None,
    topk_ratio: float = 2.0,
    kv_ratio: float = 3.0,
    cur_process_idx: int = 0,
    t_mod : torch.Tensor = None,
    t : torch.Tensor = None,
    local_range: int = 9,
    **kwargs,
):
    """Run one chunk through the DiT blocks (no vace / motion controller).

    Args:
        dit: model providing ``patchify``, ``freqs``, ``blocks``, ``head`` and
            ``unpatchify``.
        x: latent chunk passed to ``dit.patchify``.
        timestep: accepted but not read; ``t`` / ``t_mod`` are used instead.
        context: forwarded to every block.
        tea_cache: optional ``TeaCache``; when it reports a skip, the stored
            residual is applied instead of running the blocks.
        use_unified_sequence_parallel: split/gather the sequence via xfuser.
        LQ_latents: optional per-block tensors added to ``x`` before the
            block with the same index.
        pre_cache_k, pre_cache_v: per-block caches; updated in place with the
            values each block returns.
        topk_ratio, kv_ratio, local_range, is_full_block, is_stream:
            forwarded (after integer conversion where computed below) to
            every block.
        cur_process_idx: chunk index; chunk 0 starts at temporal RoPE
            position 0, later chunks at ``4 + 2 * cur_process_idx``.
        t_mod, t: precomputed time modulation / embedding.

    Returns:
        ``(x, pre_cache_k, pre_cache_v)`` with ``x`` un-patchified.
    """
    # patchify
    x, (f, h, w) = dit.patchify(x)

    win = (2, 8, 8)
    local_num = f // win[0]
    window_size = win[0] * h * w // 128
    topk = int(window_size * window_size * topk_ratio) - 1
    kv_len = int(kv_ratio)

    # RoPE positions: only the temporal start offset differs between chunks.
    t_start = 0 if cur_process_idx == 0 else 4 + cur_process_idx*2
    freqs = torch.cat([
        dit.freqs[0][t_start:t_start + f].view(f, 1, 1, -1).expand(f, h, w, -1),
        dit.freqs[1][:h].view(1, h, 1, -1).expand(f, h, w, -1),
        dit.freqs[2][:w].view(1, 1, w, -1).expand(f, h, w, -1)
    ], dim=-1).reshape(f * h * w, 1, -1).to(x.device)

    # TeaCache (disabled unless a cache object is passed).
    tea_cache_update = tea_cache.check(dit, x, t_mod) if tea_cache is not None else False

    # Unified sequence parallel (off by default); imports stay local so that
    # xfuser is only required when the feature is used.
    if use_unified_sequence_parallel:
        import torch.distributed as dist
        from xfuser.core.distributed import (get_sequence_parallel_rank,
                                             get_sequence_parallel_world_size,
                                             get_sp_group)
        if dist.is_initialized() and dist.get_world_size() > 1:
            x = torch.chunk(x, get_sequence_parallel_world_size(), dim=1)[get_sequence_parallel_rank()]

    # Block stack
    if tea_cache_update:
        x = tea_cache.update(x)
    else:
        for block_id, block in enumerate(dit.blocks):
            if LQ_latents is not None and block_id < len(LQ_latents):
                x = x + LQ_latents[block_id]
            x, last_pre_cache_k, last_pre_cache_v = block(
                x, context, t_mod, freqs, f, h, w,
                local_num, topk,
                block_id=block_id,
                kv_len=kv_len,
                is_full_block=is_full_block,
                is_stream=is_stream,
                pre_cache_k=pre_cache_k[block_id] if pre_cache_k is not None else None,
                pre_cache_v=pre_cache_v[block_id] if pre_cache_v is not None else None,
                local_range = local_range,
            )
            if pre_cache_k is not None: pre_cache_k[block_id] = last_pre_cache_k
            if pre_cache_v is not None: pre_cache_v[block_id] = last_pre_cache_v
        # Record the residual of this computed step; without it a later
        # skipped step would call tea_cache.update() with no residual stored.
        if tea_cache is not None:
            tea_cache.store(x)

    x = dit.head(x, t)
    if use_unified_sequence_parallel:
        # `dist` and `get_sp_group` were imported above under the same flag.
        if dist.is_initialized() and dist.get_world_size() > 1:
            x = get_sp_group().all_gather(x, dim=1)
    x = dit.unpatchify(x, (f, h, w))
    return x, pre_cache_k, pre_cache_v
custom_nodes/ComfyUI-FlashVSR_Ultra_Fast/src/pipelines/flashvsr_tiny.py ADDED
@@ -0,0 +1,615 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import types
2
+ import os
3
+ import time
4
+ from typing import Optional, Tuple, Literal
5
+
6
+ import torch
7
+ import torch.nn as nn
8
+ import torch.nn.functional as F
9
+ import numpy as np
10
+ from einops import rearrange
11
+ from PIL import Image
12
+ from tqdm import tqdm
13
+ # import pyfiglet
14
+
15
+ from ..models import ModelManager
16
+ from ..models.utils import clean_vram
17
+ from ..models.wan_video_dit import WanModel, RMSNorm, sinusoidal_embedding_1d
18
+ from ..models.wan_video_vae import WanVideoVAE, RMS_norm, CausalConv3d, Upsample
19
+ from ..schedulers.flow_match import FlowMatchScheduler
20
+ from .base import BasePipeline
21
+
22
+
23
+ # -----------------------------
24
+ # 基础工具:ADAIN 所需的统计量(保留以备需要;管线默认用 wavelet)
25
+ # -----------------------------
26
+ def _calc_mean_std(feat: torch.Tensor, eps: float = 1e-5) -> Tuple[torch.Tensor, torch.Tensor]:
27
+ assert feat.dim() == 4, 'feat 必须是 (N, C, H, W)'
28
+ N, C = feat.shape[:2]
29
+ var = feat.view(N, C, -1).var(dim=2, unbiased=False) + eps
30
+ std = var.sqrt().view(N, C, 1, 1)
31
+ mean = feat.view(N, C, -1).mean(dim=2).view(N, C, 1, 1)
32
+ return mean, std
33
+
34
+
35
def _adain(content_feat: torch.Tensor, style_feat: torch.Tensor) -> torch.Tensor:
    """Re-normalise ``content_feat`` to the channel statistics of ``style_feat``.

    Both inputs must agree in their first two dimensions. The content is
    standardised with its own per-channel mean/std and then scaled and shifted
    by the style's per-channel std/mean.
    """
    assert content_feat.shape[:2] == style_feat.shape[:2], "ADAIN: N、C 必须匹配"
    full_shape = content_feat.size()
    s_mean, s_std = _calc_mean_std(style_feat)
    c_mean, c_std = _calc_mean_std(content_feat)
    whitened = (content_feat - c_mean.expand(full_shape)) / c_std.expand(full_shape)
    recoloured = whitened * s_std.expand(full_shape)
    return recoloured + s_mean.expand(full_shape)
42
+
43
+
44
+ # -----------------------------
45
+ # 小波式模糊与分解/重构(ColorCorrector 用)
46
+ # -----------------------------
47
+ def _make_gaussian3x3_kernel(dtype, device) -> torch.Tensor:
48
+ vals = [
49
+ [0.0625, 0.125, 0.0625],
50
+ [0.125, 0.25, 0.125 ],
51
+ [0.0625, 0.125, 0.0625],
52
+ ]
53
+ return torch.tensor(vals, dtype=dtype, device=device)
54
+
55
+
56
def _wavelet_blur(x: torch.Tensor, radius: int) -> torch.Tensor:
    """Blur ``x`` channel-wise with the 3x3 kernel dilated by ``radius``.

    The input is replicate-padded by ``radius`` on every side, so the output
    keeps the spatial size of ``x``.
    """
    assert x.dim() == 4, 'x 必须是 (N, C, H, W)'
    channels = x.shape[1]
    kernel = _make_gaussian3x3_kernel(x.dtype, x.device)
    # One copy of the kernel per channel: depthwise (groups == channels) conv.
    depthwise_weight = kernel.view(1, 1, 3, 3).repeat(channels, 1, 1, 1)
    padded = F.pad(x, (radius, radius, radius, radius), mode='replicate')
    return F.conv2d(
        padded, depthwise_weight, bias=None, stride=1, padding=0,
        dilation=radius, groups=channels,
    )
65
+
66
+
67
def _wavelet_decompose(x: torch.Tensor, levels: int = 5) -> Tuple[torch.Tensor, torch.Tensor]:
    """Split ``x`` into accumulated detail and a smooth remainder.

    At level ``i`` the current low-pass image is blurred with radius ``2**i``;
    what the blur removed is added to the high-frequency sum. Returns
    ``(high, low)``.
    """
    assert x.dim() == 4, 'x 必须是 (N, C, H, W)'
    low = x
    high = torch.zeros_like(x)
    for level in range(levels):
        smoothed = _wavelet_blur(low, 2 ** level)
        high = high + (low - smoothed)
        low = smoothed
    return high, low
77
+
78
+
79
def _wavelet_reconstruct(content: torch.Tensor, style: torch.Tensor, levels: int = 5) -> torch.Tensor:
    """Combine the detail of ``content`` with the smooth part of ``style``.

    Both tensors are decomposed with the same number of ``levels``; the result
    is the content's high-frequency sum plus the style's low-frequency remainder.
    """
    content_detail = _wavelet_decompose(content, levels=levels)[0]
    style_base = _wavelet_decompose(style, levels=levels)[1]
    return content_detail + style_base
83
+
84
+
85
+ # -----------------------------
86
+ # 无状态颜色矫正模块(视频友好,默认 wavelet)
87
+ # -----------------------------
88
class TorchColorCorrectorWavelet(nn.Module):
    """Stateless colour corrector for video tensors ``(B, C, f, H, W)``.

    Frames are folded into the batch dimension and corrected either with the
    wavelet recombination (``_wavelet_reconstruct``) or with AdaIN (``_adain``).
    """

    def __init__(self, levels: int = 5):
        super().__init__()
        # Number of decomposition levels used by the wavelet method.
        self.levels = levels

    @staticmethod
    def _flatten_time(x: torch.Tensor) -> Tuple[torch.Tensor, int, int]:
        """Fold the frame axis into the batch: (B, C, f, H, W) -> (B*f, C, H, W)."""
        assert x.dim() == 5, '输入必须是 (B, C, f, H, W)'
        B, C, f, H, W = x.shape
        y = x.permute(0, 2, 1, 3, 4).reshape(B * f, C, H, W)
        return y, B, f

    @staticmethod
    def _unflatten_time(y: torch.Tensor, B: int, f: int) -> torch.Tensor:
        """Inverse of ``_flatten_time``: (B*f, C, H, W) -> (B, C, f, H, W)."""
        BF, C, H, W = y.shape
        assert BF == B * f
        return y.reshape(B, f, C, H, W).permute(0, 2, 1, 3, 4)

    def _correct_clip(
        self,
        hq_clip: torch.Tensor,
        lq_clip: torch.Tensor,
        clip_range: Tuple[float, float],
        method: str,
    ) -> torch.Tensor:
        """Correct one clip (all of its frames at once) and clamp the result."""
        hq4, B, f = self._flatten_time(hq_clip)
        lq4, _, _ = self._flatten_time(lq_clip)
        if method == 'wavelet':
            out4 = _wavelet_reconstruct(hq4, lq4, levels=self.levels)
        else:
            out4 = _adain(hq4, lq4)
        out4 = torch.clamp(out4, *clip_range)
        return self._unflatten_time(out4, B, f)

    def forward(
        self,
        hq_image: torch.Tensor,  # (B, C, f, H, W)
        lq_image: torch.Tensor,  # (B, C, f, H, W)
        clip_range: Tuple[float, float] = (-1.0, 1.0),
        method: Literal['wavelet', 'adain'] = 'wavelet',
        chunk_size: Optional[int] = None,
    ) -> torch.Tensor:
        """Colour-correct ``hq_image`` towards ``lq_image``.

        Args:
            hq_image: tensor ``(B, 3, f, H, W)`` to be corrected.
            lq_image: reference tensor of identical shape.
            clip_range: ``(min, max)`` passed to ``torch.clamp``.
            method: ``'wavelet'`` or ``'adain'``.
            chunk_size: frames processed per step; ``None`` or a value of at
                least ``f`` processes all frames at once.

        Returns:
            Tensor with the shape of ``hq_image``.

        Raises:
            AssertionError: on mismatched shapes or a non 3-channel input.
            ValueError: on an unknown ``method`` or a non-positive
                ``chunk_size``.
        """
        assert hq_image.shape == lq_image.shape, "HQ 与 LQ 的形状必须一致"
        assert hq_image.dim() == 5 and hq_image.shape[1] == 3, "输入必须是 (B, 3, f, H, W)"

        # Validate up front so a bad argument fails before any tensor work.
        if method not in ('wavelet', 'adain'):
            raise ValueError(f"未知 method: {method}")

        f = hq_image.shape[2]
        if chunk_size is None or chunk_size >= f:
            return self._correct_clip(hq_image, lq_image, clip_range, method)

        if chunk_size <= 0:
            # range() rejects a zero step, and a negative step would yield no
            # chunks and an opaque torch.cat error further down.
            raise ValueError(f"chunk_size must be a positive integer, got {chunk_size}")

        # Slices past the last frame are clamped by tensor indexing.
        outs = [
            self._correct_clip(
                hq_image[:, :, start:start + chunk_size],
                lq_image[:, :, start:start + chunk_size],
                clip_range,
                method,
            )
            for start in range(0, f, chunk_size)
        ]
        return torch.cat(outs, dim=2)
149
+
150
+
151
+ # -----------------------------
152
+ # 简化版 Pipeline(仅 dit + vae)
153
+ # -----------------------------
154
class FlashVSRTinyPipeline(BasePipeline):
    """Simplified pipeline that runs the DiT and decodes with ``TCDecoder``.

    ``self.TCDecoder`` is used by ``decode_video``, ``offload_model`` and
    ``__call__`` but is never assigned inside this class; it has to be
    attached to the instance by the caller before those methods run.
    """

    def __init__(self, device="cuda", torch_dtype=torch.float16):
        super().__init__(device=device, torch_dtype=torch_dtype)
        self.scheduler = FlowMatchScheduler(shift=5, sigma_min=0.0, extra_one_step=True)
        self.dit: WanModel = None
        self.vae: WanVideoVAE = None
        self.model_names = ['dit', 'vae']
        self.height_division_factor = 16
        self.width_division_factor = 16
        self.use_unified_sequence_parallel = False
        # Filled by init_cross_kv(): {'context': tensor, 'stats': 'load' | 'offload'}.
        self.prompt_emb_posi = None
        self.ColorCorrector = TorchColorCorrectorWavelet(levels=5)

        print(r"""
███████╗██╗ █████╗ ███████╗██╗ ██╗██╗ ██╗███████╗█████╗
██╔════╝██║ ██╔══██╗██╔════╝██║ ██║██║ ██║██╔════╝██╔══██╗ ██╗
█████╗ ██║ ███████║███████╗███████║╚██╗ ██╔╝███████╗███████║ ██████╗
██╔══╝ ██║ ██╔══██║╚════██║██╔══██║ ╚████╔╝ ╚════██║██╔═██║ ██╔═╝
██║ ███████╗██║ ██║███████║██║ ██║ ╚██╔╝ ███████║██║ ██║ ╚═╝
╚═╝ ╚══════╝╚═╝ ╚═╝╚══════╝╚═╝ ╚═╝ ╚═╝ ╚══════╝╚═╝ ╚═╝
""")

    def enable_vram_management(self, num_persistent_param_in_dit=None):
        """Wrap the DiT's layers for CPU offloading, then enable CPU offload.

        ``num_persistent_param_in_dit`` is passed on as ``max_num_param``;
        modules beyond that budget use the overflow config (onload on CPU).
        """
        # Only the DiT is wrapped here.
        dtype = next(iter(self.dit.parameters())).dtype
        from ..vram_management import enable_vram_management, AutoWrappedModule, AutoWrappedLinear
        enable_vram_management(
            self.dit,
            module_map={
                torch.nn.Linear: AutoWrappedLinear,
                torch.nn.Conv3d: AutoWrappedModule,
                torch.nn.LayerNorm: AutoWrappedModule,
                RMSNorm: AutoWrappedModule,
            },
            module_config=dict(
                offload_dtype=dtype,
                offload_device="cpu",
                onload_dtype=dtype,
                onload_device=self.device,
                computation_dtype=self.torch_dtype,
                computation_device=self.device,
            ),
            max_num_param=num_persistent_param_in_dit,
            overflow_module_config=dict(
                offload_dtype=dtype,
                offload_device="cpu",
                onload_dtype=dtype,
                onload_device="cpu",
                computation_dtype=self.torch_dtype,
                computation_device=self.device,
            ),
        )
        self.enable_cpu_offload()

    def fetch_models(self, model_manager: ModelManager):
        """Fetch the DiT and the VAE from ``model_manager``."""
        self.dit = model_manager.fetch_model("wan_video_dit")
        self.vae = model_manager.fetch_model("wan_video_vae")

    @staticmethod
    def from_model_manager(model_manager: ModelManager, torch_dtype=None, device=None, use_usp=False):
        """Build a pipeline from ``model_manager``.

        ``device`` / ``torch_dtype`` default to the manager's values.
        ``use_usp`` is accepted but ignored: sequence parallelism is always
        switched off here.
        """
        if device is None: device = model_manager.device
        if torch_dtype is None: torch_dtype = model_manager.torch_dtype
        pipe = FlashVSRTinyPipeline(device=device, torch_dtype=torch_dtype)
        pipe.fetch_models(model_manager)
        # Optional unified-sequence-parallel entry point (off by default).
        pipe.use_unified_sequence_parallel = False
        return pipe

    def denoising_model(self):
        """Return the DiT."""
        return self.dit

    # -------------------------
    # Explicit cross-attention KV pre-initialisation
    # -------------------------
    def init_cross_kv(
        self,
        context_tensor: Optional[torch.Tensor] = None,
        prompt_path = None,
    ):
        """Initialise the DiT's cross-attention KV cache from a text context.

        Must be called once before ``__call__``. The context comes from
        ``context_tensor`` if given, otherwise it is loaded from
        ``prompt_path``. The method also precomputes the fixed timestep
        embedding (``self.t`` / ``self.t_mod``), configures the scheduler for
        a single step and finally offloads all models.

        Raises:
            RuntimeError: if ``self.dit`` has not been set.
            ValueError: if neither ``context_tensor`` nor ``prompt_path`` is given.
            AttributeError: if the DiT lacks ``reinit_cross_kv``.
        """
        # Check before touching devices so a missing model gives this error
        # instead of whatever load_models_to_device does with a None entry.
        if self.dit is None:
            raise RuntimeError("请先通过 fetch_models / from_model_manager 初始化 self.dit")

        self.load_models_to_device(["dit"])

        if context_tensor is None:
            if prompt_path is None:
                raise ValueError("init_cross_kv: 需要提供 prompt_path 或 context_tensor 其一")
            # NOTE(review): torch.load unpickles the file; only load prompt
            # tensors from trusted sources (consider weights_only=True).
            ctx = torch.load(prompt_path, map_location=self.device)
        else:
            ctx = context_tensor

        ctx = ctx.to(dtype=self.torch_dtype, device=self.device)

        if self.prompt_emb_posi is None:
            self.prompt_emb_posi = {}
        self.prompt_emb_posi['context'] = ctx
        self.prompt_emb_posi['stats'] = "load"

        if hasattr(self.dit, "reinit_cross_kv"):
            self.dit.reinit_cross_kv(ctx)
        else:
            raise AttributeError("WanModel 缺少 reinit_cross_kv(ctx) 方法,请在模型实现中加入该能力。")
        self.timestep = torch.tensor([1000.], device=self.device, dtype=self.torch_dtype)
        self.t = self.dit.time_embedding(sinusoidal_embedding_1d(self.dit.freq_dim, self.timestep))
        self.t_mod = self.dit.time_projection(self.t).unflatten(1, (6, self.dit.dim))
        # Scheduler
        self.scheduler.set_timesteps(1, denoising_strength=1.0, shift=5.0)
        self.load_models_to_device([])

    def prepare_unified_sequence_parallel(self):
        """Return the sequence-parallel switch as a keyword-argument dict."""
        return {"use_unified_sequence_parallel": self.use_unified_sequence_parallel}

    def prepare_extra_input(self, latents=None):
        """Return extra model inputs; this pipeline has none."""
        return {}

    def encode_video(self, input_video, tiled=True, tile_size=(34, 34), tile_stride=(18, 16)):
        """Encode ``input_video`` with the VAE, forwarding the tiling options."""
        latents = self.vae.encode(input_video, device=self.device, tiled=tiled, tile_size=tile_size, tile_stride=tile_stride)
        return latents

    def _decode_video(self, latents, tiled=True, tile_size=(34, 34), tile_stride=(18, 16)):
        """Decode ``latents`` with the full VAE, forwarding the tiling options."""
        frames = self.vae.decode(latents, device=self.device, tiled=tiled, tile_size=tile_size, tile_stride=tile_stride)
        return frames

    def decode_video(self, latents, cond=None, **kwargs):
        """Decode ``latents`` with ``self.TCDecoder``.

        The latents are transposed on dims 1/2 before decoding and back
        afterwards, and the result is mapped through ``x * 2 - 1`` in place.
        Extra ``kwargs`` are accepted but ignored.
        """
        frames = self.TCDecoder.decode_video(
            latents.transpose(1, 2),  # per the original comment, TCDecoder takes (B, F, C, H, W)
            parallel=False,
            show_progress_bar=False,
            cond=cond
        ).transpose(1, 2).mul_(2).sub_(1)  # back to (B, C, F, H, W); rescaled as x*2-1

        return frames

    def offload_model(self, keep_vae=False):
        """Offload the DiT and, unless ``keep_vae`` is true, the TCDecoder.

        Clears the DiT's cross-attention KV and marks the prompt state as
        ``"offload"`` so that ``__call__`` re-initialises it.
        """
        self.dit.clear_cross_kv()
        # init_cross_kv() may not have run yet; nothing to mark in that case.
        if self.prompt_emb_posi is not None:
            self.prompt_emb_posi['stats'] = "offload"
        self.load_models_to_device([])
        if hasattr(self.dit, "LQ_proj_in"):
            self.dit.LQ_proj_in.to('cpu')
        if not keep_vae:
            self.TCDecoder.to('cpu')

    @torch.no_grad()
    def __call__(
        self,
        prompt=None,
        negative_prompt="",
        denoising_strength=1.0,
        seed=None,
        rand_device="gpu",
        height=480,
        width=832,
        num_frames=81,
        cfg_scale=5.0,
        num_inference_steps=50,
        sigma_shift=5.0,
        tiled=True,
        tile_size=(60, 104),
        tile_stride=(30, 52),
        tea_cache_l1_thresh=None,
        tea_cache_model_id="Wan2.1-T2V-1.3B",
        progress_bar_cmd=tqdm,
        progress_bar_st=None,
        LQ_video=None,
        is_full_block=False,
        if_buffer=False,
        topk_ratio=2.0,
        kv_ratio=3.0,
        local_range = 9,
        color_fix = True,
        unload_dit = False,
        force_offload = False,
    ):
        """Run streaming one-step restoration and return the decoded frames.

        ``init_cross_kv`` must have been called first. The latent noise is
        processed chunk by chunk (6 latent frames first, then 2 per chunk)
        with per-block KV caches carried between chunks; the concatenated
        latents are decoded by ``decode_video`` with the LQ frames seen so
        far as ``cond``.

        Args:
            seed, height, width, num_frames: shape and seeding of the initial
                noise; height/width go through ``check_resize_height_width``
                and ``num_frames`` is rounded so ``num_frames % 4 == 1``.
            cfg_scale: must be exactly 1.0 (the default of 5.0 is rejected).
            progress_bar_cmd: callable wrapping the chunk iterator.
            LQ_video: low-quality input indexed as ``[:, :, frame, :, :]``;
                when ``None`` no conditioning or colour correction is applied.
            is_full_block, topk_ratio, kv_ratio, local_range: forwarded to
                ``model_fn_wan_video``.
            if_buffer: when true the noise has one latent frame fewer.
            color_fix: apply ``self.ColorCorrector`` (AdaIN) against ``LQ_video``.
            unload_dit: offload the DiT (keeping the decoder) before decoding.
            force_offload: offload everything after decoding.
            Remaining parameters are accepted but not used by this method.

        Returns:
            ``frames[0]``, the first batch element of the decoded video.

        Raises:
            AssertionError: if ``cfg_scale != 1.0``.
            RuntimeError: if ``init_cross_kv`` has not been called.
        """
        # Only cfg_scale == 1.0 is supported.
        assert cfg_scale == 1.0, "cfg_scale must be 1.0"

        # init_cross_kv() must have been called beforehand.
        if self.prompt_emb_posi is None or 'context' not in self.prompt_emb_posi:
            raise RuntimeError(
                "Cross-Attn KV 未初始化。请在调用 __call__ 前先执行:\n"
                " pipe.init_cross_kv()\n"
                "或传入自定义 context:\n"
                " pipe.init_cross_kv(context_tensor=your_context_tensor)"
            )

        # Size correction
        height, width = self.check_resize_height_width(height, width)
        if num_frames % 4 != 1:
            num_frames = (num_frames + 2) // 4 * 4 + 1
            # The original message stated the inverted condition (`!= 1`).
            print(f"Only `num_frames % 4 == 1` is acceptable. We round it up to {num_frames}.")

        # Initial noise; buffered mode uses one latent frame fewer.
        latent_frames = (num_frames - 1) // 4 + (0 if if_buffer else 1)
        latents = self.generate_noise(
            (1, 16, latent_frames, height//8, width//8),
            seed=seed, device=self.device, dtype=self.torch_dtype,
        )

        process_total_num = (num_frames - 1) // 8 - 2
        is_stream = True

        if self.prompt_emb_posi['stats'] == "offload":
            self.init_cross_kv(context_tensor=self.prompt_emb_posi['context'])
        self.load_models_to_device(["dit"])
        self.dit.LQ_proj_in.to(self.device)
        self.TCDecoder.to(self.device)

        # Drop any LQ_proj_in cache left over from a previous run.
        if hasattr(self.dit, "LQ_proj_in"):
            self.dit.LQ_proj_in.clear_cache()

        latents_total = []
        self.TCDecoder.clean_mem()
        # Number of LQ frames handed to the decoder as `cond`.
        LQ_cur_idx = 0

        def project_lq(frame_ranges):
            """Stream each (start, stop) LQ frame range through LQ_proj_in and
            concatenate the per-layer outputs along dim 1 (None without LQ_video)."""
            if LQ_video is None:
                return None
            merged = None
            for start, stop in frame_ranges:
                cur = self.denoising_model().LQ_proj_in.stream_forward(
                    LQ_video[:, :, start:stop, :, :]
                )
                if cur is None:
                    continue
                if merged is None:
                    merged = cur
                else:
                    for layer_idx in range(len(merged)):
                        merged[layer_idx] = torch.cat([merged[layer_idx], cur[layer_idx]], dim=1)
            return merged

        # The method already runs under @torch.no_grad(); no inner context needed.
        for cur_process_idx in progress_bar_cmd(range(process_total_num)):
            if cur_process_idx == 0:
                pre_cache_k = [None] * len(self.dit.blocks)
                pre_cache_v = [None] * len(self.dit.blocks)
                LQ_latents = project_lq(
                    [(max(0, i*4-3), (i+1)*4-3) for i in range(7)]
                )
                cur_latents = latents[:, :, :6, :, :]
            else:
                base = cur_process_idx*8
                LQ_latents = project_lq(
                    [(base+17+i*4, base+21+i*4) for i in range(2)]
                )
                cur_latents = latents[:, :, 4+cur_process_idx*2:6+cur_process_idx*2, :, :]
            # Same value as both original branches: 21 for chunk 0, idx*8+21 after.
            LQ_cur_idx = cur_process_idx*8+21

            # Inference (no motion_controller / vace).
            noise_pred_posi, pre_cache_k, pre_cache_v = model_fn_wan_video(
                self.dit,
                x=cur_latents,
                timestep=self.timestep,
                context=None,
                tea_cache=None,
                use_unified_sequence_parallel=False,
                LQ_latents=LQ_latents,
                is_full_block=is_full_block,
                is_stream=is_stream,
                pre_cache_k=pre_cache_k,
                pre_cache_v=pre_cache_v,
                topk_ratio=topk_ratio,
                kv_ratio=kv_ratio,
                cur_process_idx=cur_process_idx,
                t_mod=self.t_mod,
                t=self.t,
                local_range = local_range,
            )

            # Single-step update of the latent chunk.
            cur_latents = cur_latents - noise_pred_posi
            latents_total.append(cur_latents)

        if hasattr(self.dit, "LQ_proj_in"):
            self.dit.LQ_proj_in.clear_cache()

        if unload_dit and hasattr(self, 'dit') and not next(self.dit.parameters()).is_cpu:
            print("[FlashVSR] Offloading DiT to the CPU to free up VRAM...")
            self.offload_model(keep_vae=True)

        latents = torch.cat(latents_total, dim=2)

        # Decode. Without an LQ video there is nothing to slice for `cond`.
        print("[FlashVSR] Starting VAE decoding...")
        cond = LQ_video[:, :, :LQ_cur_idx, :, :] if LQ_video is not None else None
        frames = self.decode_video(latents, cond=cond)

        self.TCDecoder.clean_mem()
        if force_offload:
            self.offload_model()

        # Colour correction (AdaIN) stays best-effort, but failures are
        # reported instead of being swallowed by a bare `except`.
        if color_fix and LQ_video is not None:
            try:
                frames = self.ColorCorrector(
                    frames.to(device=LQ_video.device),
                    LQ_video[:, :, :frames.shape[2], :, :],
                    clip_range=(-1, 1),
                    chunk_size=16,
                    method='adain'
                )
            except Exception as exc:
                print(f"[FlashVSR] Color correction skipped: {exc}")

        return frames[0]
474
+
475
+
476
+ # -----------------------------
477
+ # TeaCache(保留原逻辑;此处默认不启用)
478
+ # -----------------------------
479
class TeaCache:
    """Step-skipping cache driven by the change of the modulation input.

    ``check`` accumulates a polynomially rescaled relative L1 change of
    ``t_mod`` between steps and reports whether the transformer blocks may be
    skipped; ``store`` records the residual produced by a computed step and
    ``update`` re-applies that residual on a skipped step.
    """

    def __init__(self, num_inference_steps, rel_l1_thresh, model_id):
        self.num_inference_steps = num_inference_steps
        self.rel_l1_thresh = rel_l1_thresh
        self.step = 0
        self.accumulated_rel_l1_distance = 0
        self.previous_modulated_input = None
        self.previous_residual = None
        self.previous_hidden_states = None

        # Per-model polynomial coefficients (highest power first, np.poly1d order).
        self.coefficients_dict = {
            "Wan2.1-T2V-1.3B": [-5.21862437e+04, 9.23041404e+03, -5.28275948e+02, 1.36987616e+01, -4.99875664e-02],
            "Wan2.1-T2V-14B": [-3.03318725e+05, 4.90537029e+04, -2.65530556e+03, 5.87365115e+01, -3.15583525e-01],
            "Wan2.1-I2V-14B-480P": [2.57151496e+05, -3.54229917e+04, 1.40286849e+03, -1.35890334e+01, 1.32517977e-01],
            "Wan2.1-I2V-14B-720P": [8.10705460e+03, 2.13393892e+03, -3.72934672e+02, 1.66203073e+01, -4.17769401e-02],
        }
        if model_id not in self.coefficients_dict:
            supported_model_ids = ", ".join(self.coefficients_dict)
            raise ValueError(f"{model_id} is not a supported TeaCache model id. Please choose a valid model id in ({supported_model_ids}).")
        self.coefficients = self.coefficients_dict[model_id]

    def check(self, dit: WanModel, x, t_mod):
        """Advance one step and return True when the blocks may be skipped.

        The first and last step of a cycle are always computed. On a computed
        step a copy of ``x`` is kept for ``store``. ``dit`` is not read.
        """
        modulated_inp = t_mod.clone()
        if self.step in (0, self.num_inference_steps - 1):
            should_calc = True
            self.accumulated_rel_l1_distance = 0
        else:
            previous = self.previous_modulated_input
            rel_change = ((modulated_inp-previous).abs().mean() / previous.abs().mean()).cpu().item()
            self.accumulated_rel_l1_distance += np.poly1d(self.coefficients)(rel_change)
            # Written as `not (<)` so a NaN distance still forces a recompute.
            should_calc = not (self.accumulated_rel_l1_distance < self.rel_l1_thresh)
            if should_calc:
                self.accumulated_rel_l1_distance = 0
        self.previous_modulated_input = modulated_inp
        self.step = (self.step + 1) % self.num_inference_steps
        if should_calc:
            self.previous_hidden_states = x.clone()
        return not should_calc

    def store(self, hidden_states):
        """Record the residual between ``hidden_states`` and the saved input."""
        self.previous_residual = hidden_states - self.previous_hidden_states
        self.previous_hidden_states = None

    def update(self, hidden_states):
        """Return ``hidden_states`` plus the most recently stored residual."""
        return hidden_states + self.previous_residual
525
+
526
+
527
+ # -----------------------------
528
+ # 简化版模型前向封装(无 vace / 无 motion_controller)
529
+ # -----------------------------
530
def model_fn_wan_video(
    dit: WanModel,
    x: torch.Tensor,
    timestep: torch.Tensor,
    context: torch.Tensor,
    tea_cache: Optional[TeaCache] = None,
    use_unified_sequence_parallel: bool = False,
    LQ_latents: Optional[torch.Tensor] = None,
    is_full_block: bool = False,
    is_stream: bool = False,
    pre_cache_k: Optional[list[torch.Tensor]] = None,
    pre_cache_v: Optional[list[torch.Tensor]] = None,
    topk_ratio: float = 2.0,
    kv_ratio: float = 3.0,
    cur_process_idx: int = 0,
    t_mod : torch.Tensor = None,
    t : torch.Tensor = None,
    local_range: int = 9,
    **kwargs,
):
    """Run one chunk through the DiT blocks (no vace / motion controller).

    Args:
        dit: model providing ``patchify``, ``freqs``, ``blocks``, ``head`` and
            ``unpatchify``.
        x: latent chunk passed to ``dit.patchify``.
        timestep: accepted but not read; ``t`` / ``t_mod`` are used instead.
        context: forwarded to every block.
        tea_cache: optional ``TeaCache``; when it reports a skip, the stored
            residual is applied instead of running the blocks.
        use_unified_sequence_parallel: split/gather the sequence via xfuser.
        LQ_latents: optional per-block tensors added to ``x`` before the
            block with the same index.
        pre_cache_k, pre_cache_v: per-block caches; updated in place with the
            values each block returns.
        topk_ratio, kv_ratio, local_range, is_full_block, is_stream:
            forwarded (after integer conversion where computed below) to
            every block.
        cur_process_idx: chunk index; chunk 0 starts at temporal RoPE
            position 0, later chunks at ``4 + 2 * cur_process_idx``.
        t_mod, t: precomputed time modulation / embedding.

    Returns:
        ``(x, pre_cache_k, pre_cache_v)`` with ``x`` un-patchified.
    """
    # patchify
    x, (f, h, w) = dit.patchify(x)

    win = (2, 8, 8)
    local_num = f // win[0]
    window_size = win[0] * h * w // 128
    topk = int(window_size * window_size * topk_ratio) - 1
    kv_len = int(kv_ratio)

    # RoPE positions: only the temporal start offset differs between chunks.
    t_start = 0 if cur_process_idx == 0 else 4 + cur_process_idx*2
    freqs = torch.cat([
        dit.freqs[0][t_start:t_start + f].view(f, 1, 1, -1).expand(f, h, w, -1),
        dit.freqs[1][:h].view(1, h, 1, -1).expand(f, h, w, -1),
        dit.freqs[2][:w].view(1, 1, w, -1).expand(f, h, w, -1)
    ], dim=-1).reshape(f * h * w, 1, -1).to(x.device)

    # TeaCache (disabled unless a cache object is passed).
    tea_cache_update = tea_cache.check(dit, x, t_mod) if tea_cache is not None else False

    # Unified sequence parallel (off by default); imports stay local so that
    # xfuser is only required when the feature is used.
    if use_unified_sequence_parallel:
        import torch.distributed as dist
        from xfuser.core.distributed import (get_sequence_parallel_rank,
                                             get_sequence_parallel_world_size,
                                             get_sp_group)
        if dist.is_initialized() and dist.get_world_size() > 1:
            x = torch.chunk(x, get_sequence_parallel_world_size(), dim=1)[get_sequence_parallel_rank()]

    # Block stack
    if tea_cache_update:
        x = tea_cache.update(x)
    else:
        for block_id, block in enumerate(dit.blocks):
            if LQ_latents is not None and block_id < len(LQ_latents):
                x = x + LQ_latents[block_id]
            x, last_pre_cache_k, last_pre_cache_v = block(
                x, context, t_mod, freqs, f, h, w,
                local_num, topk,
                block_id=block_id,
                kv_len=kv_len,
                is_full_block=is_full_block,
                is_stream=is_stream,
                pre_cache_k=pre_cache_k[block_id] if pre_cache_k is not None else None,
                pre_cache_v=pre_cache_v[block_id] if pre_cache_v is not None else None,
                local_range = local_range,
            )
            if pre_cache_k is not None: pre_cache_k[block_id] = last_pre_cache_k
            if pre_cache_v is not None: pre_cache_v[block_id] = last_pre_cache_v
        # Record the residual of this computed step; without it a later
        # skipped step would call tea_cache.update() with no residual stored.
        if tea_cache is not None:
            tea_cache.store(x)

    x = dit.head(x, t)
    if use_unified_sequence_parallel:
        # `dist` and `get_sp_group` were imported above under the same flag.
        if dist.is_initialized() and dist.get_world_size() > 1:
            x = get_sp_group().all_gather(x, dim=1)
    x = dit.unpatchify(x, (f, h, w))
    return x, pre_cache_k, pre_cache_v
custom_nodes/ComfyUI-FlashVSR_Ultra_Fast/src/pipelines/flashvsr_tiny_long.py ADDED
@@ -0,0 +1,620 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import types
2
+ import os
3
+ import time
4
+ from typing import Optional, Tuple, Literal
5
+
6
+ import torch
7
+ import torch.nn as nn
8
+ import torch.nn.functional as F
9
+ import numpy as np
10
+ from einops import rearrange
11
+ from PIL import Image
12
+ from tqdm import tqdm
13
+ # import pyfiglet
14
+
15
+ from ..models import ModelManager
16
+ from ..models.utils import clean_vram
17
+ from ..models.wan_video_dit import WanModel, RMSNorm, sinusoidal_embedding_1d
18
+ from ..models.wan_video_vae import WanVideoVAE, RMS_norm, CausalConv3d, Upsample
19
+ from ..schedulers.flow_match import FlowMatchScheduler
20
+ from .base import BasePipeline
21
+
22
+
23
+ # -----------------------------
24
+ # 基础工具:ADAIN 所需的统计量(保留以备需要;管线默认用 wavelet)
25
+ # -----------------------------
26
+ def _calc_mean_std(feat: torch.Tensor, eps: float = 1e-5) -> Tuple[torch.Tensor, torch.Tensor]:
27
+ assert feat.dim() == 4, 'feat 必须是 (N, C, H, W)'
28
+ N, C = feat.shape[:2]
29
+ var = feat.view(N, C, -1).var(dim=2, unbiased=False) + eps
30
+ std = var.sqrt().view(N, C, 1, 1)
31
+ mean = feat.view(N, C, -1).mean(dim=2).view(N, C, 1, 1)
32
+ return mean, std
33
+
34
+
35
def _adain(content_feat: torch.Tensor, style_feat: torch.Tensor) -> torch.Tensor:
    """Re-normalize ``content_feat`` to the channel statistics of ``style_feat``.

    The content is whitened with its own per-channel mean/std and then scaled
    and shifted with the style's per-channel std/mean.

    Args:
        content_feat: Tensor whose statistics are replaced.
        style_feat: Tensor that supplies the target statistics; its first two
            dimensions (N, C) must equal those of ``content_feat``.

    Returns:
        A tensor with the shape of ``content_feat``.
    """
    assert content_feat.shape[:2] == style_feat.shape[:2], "ADAIN: N、C 必须匹配"
    target_mean, target_std = _calc_mean_std(style_feat)
    source_mean, source_std = _calc_mean_std(content_feat)
    # The (N, C, 1, 1) statistics broadcast against the content tensor, which
    # is equivalent to expanding them to the content size explicitly.
    whitened = (content_feat - source_mean) / source_std
    return whitened * target_std + target_mean
42
+
43
+
44
+ # -----------------------------
45
+ # 小波式模糊与分解/重构(ColorCorrector 用)
46
+ # -----------------------------
47
+ def _make_gaussian3x3_kernel(dtype, device) -> torch.Tensor:
48
+ vals = [
49
+ [0.0625, 0.125, 0.0625],
50
+ [0.125, 0.25, 0.125 ],
51
+ [0.0625, 0.125, 0.0625],
52
+ ]
53
+ return torch.tensor(vals, dtype=dtype, device=device)
54
+
55
+
56
def _wavelet_blur(x: torch.Tensor, radius: int) -> torch.Tensor:
    """Blur ``x`` with the 3x3 smoothing kernel dilated by ``radius``.

    Each channel is filtered independently (grouped convolution). The input is
    replicate-padded by ``radius`` on every side so the output keeps the
    spatial size of the input.

    Args:
        x: 4-D tensor laid out as (N, C, H, W).
        radius: Dilation of the kernel and padding width, in pixels.

    Returns:
        The blurred tensor, same shape as ``x``.
    """
    assert x.dim() == 4, 'x 必须是 (N, C, H, W)'
    channels = x.shape[1]
    kernel = _make_gaussian3x3_kernel(x.dtype, x.device)
    # One identical single-channel filter per input channel.
    depthwise_weight = kernel.view(1, 1, 3, 3).repeat(channels, 1, 1, 1)
    padded = F.pad(x, (radius, radius, radius, radius), mode='replicate')
    return F.conv2d(
        padded,
        depthwise_weight,
        bias=None,
        stride=1,
        padding=0,
        dilation=radius,
        groups=channels,
    )
65
+
66
+
67
def _wavelet_decompose(x: torch.Tensor, levels: int = 5) -> Tuple[torch.Tensor, torch.Tensor]:
    """Split ``x`` into an accumulated detail component and a smooth residual.

    At level ``i`` the current residual is blurred with radius ``2 ** i``; the
    difference between the residual and its blurred version is added to the
    detail component, and the blurred version becomes the new residual.

    Args:
        x: 4-D tensor laid out as (N, C, H, W).
        levels: Number of blur levels to apply.

    Returns:
        ``(high, low)``: the summed detail bands and the final blurred
        residual, both shaped like ``x``.
    """
    assert x.dim() == 4, 'x 必须是 (N, C, H, W)'
    detail = torch.zeros_like(x)
    residual = x
    for level in range(levels):
        smoothed = _wavelet_blur(residual, 2 ** level)
        detail = detail + (residual - smoothed)
        residual = smoothed
    return detail, residual
77
+
78
+
79
def _wavelet_reconstruct(content: torch.Tensor, style: torch.Tensor, levels: int = 5) -> torch.Tensor:
    """Combine the detail component of ``content`` with the smooth part of ``style``.

    Args:
        content: Tensor that contributes the high-frequency component.
        style: Tensor that contributes the low-frequency component.
        levels: Number of decomposition levels used for both inputs.

    Returns:
        ``high(content) + low(style)``.
    """
    content_detail = _wavelet_decompose(content, levels=levels)[0]
    style_base = _wavelet_decompose(style, levels=levels)[1]
    return content_detail + style_base
83
+
84
+
85
+ # -----------------------------
86
+ # 无状态颜色矫正模块(视频友好,默认 wavelet)
87
+ # -----------------------------
88
class TorchColorCorrectorWavelet(nn.Module):
    """Stateless color corrector for video tensors.

    Transfers the color/tone of a low-quality reference clip onto a
    high-quality clip frame by frame, using either the wavelet
    reconstruction or AdaIN. The module holds no parameters or buffers; the
    only configuration is the number of wavelet levels.
    """

    _METHODS = ('wavelet', 'adain')

    def __init__(self, levels: int = 5):
        super().__init__()
        self.levels = levels

    @staticmethod
    def _flatten_time(x: torch.Tensor) -> Tuple[torch.Tensor, int, int]:
        """Fold the frame axis into the batch axis: (B, C, f, H, W) -> (B*f, C, H, W)."""
        assert x.dim() == 5, '输入必须是 (B, C, f, H, W)'
        B, C, f, H, W = x.shape
        y = x.permute(0, 2, 1, 3, 4).reshape(B * f, C, H, W)
        return y, B, f

    @staticmethod
    def _unflatten_time(y: torch.Tensor, B: int, f: int) -> torch.Tensor:
        """Inverse of ``_flatten_time``: (B*f, C, H, W) -> (B, C, f, H, W)."""
        BF, C, H, W = y.shape
        assert BF == B * f
        return y.reshape(B, f, C, H, W).permute(0, 2, 1, 3, 4)

    def _correct_clip(
        self,
        hq_clip: torch.Tensor,
        lq_clip: torch.Tensor,
        clip_range: Tuple[float, float],
        method: str,
    ) -> torch.Tensor:
        """Color-correct one (B, C, f, H, W) clip and clamp it to ``clip_range``."""
        hq4, batch, frames = self._flatten_time(hq_clip)
        lq4, _, _ = self._flatten_time(lq_clip)
        if method == 'wavelet':
            out4 = _wavelet_reconstruct(hq4, lq4, levels=self.levels)
        elif method == 'adain':
            out4 = _adain(hq4, lq4)
        else:
            raise ValueError(f"未知 method: {method}")
        out4 = torch.clamp(out4, *clip_range)
        return self._unflatten_time(out4, batch, frames)

    def forward(
        self,
        hq_image: torch.Tensor,  # (B, C, f, H, W)
        lq_image: torch.Tensor,  # (B, C, f, H, W)
        clip_range: Tuple[float, float] = (-1.0, 1.0),
        method: Literal['wavelet', 'adain'] = 'wavelet',
        chunk_size: Optional[int] = None,
    ) -> torch.Tensor:
        """Return ``hq_image`` with its color matched to ``lq_image``.

        Args:
            hq_image: Clip to correct, shaped (B, 3, f, H, W).
            lq_image: Reference clip with exactly the same shape.
            clip_range: ``(min, max)`` the result is clamped to.
            method: ``'wavelet'`` or ``'adain'``.
            chunk_size: If given and smaller than ``f``, frames are processed
                in groups of at most this many and concatenated again along
                the frame axis. ``None`` processes all frames at once.

        Returns:
            The corrected clip, shaped like ``hq_image``.

        Raises:
            AssertionError: If the shapes differ or the input is not
                (B, 3, f, H, W).
            ValueError: If ``method`` is unknown or ``chunk_size`` is not
                positive.
        """
        assert hq_image.shape == lq_image.shape, "HQ 与 LQ 的形状必须一致"
        assert hq_image.dim() == 5 and hq_image.shape[1] == 3, "输入必须是 (B, 3, f, H, W)"

        # Fail fast on bad arguments instead of part-way through the work.
        if method not in self._METHODS:
            raise ValueError(f"未知 method: {method}")
        if chunk_size is not None and chunk_size <= 0:
            # Previously 0 hit "range() arg 3 must not be zero" and negative
            # values ended in torch.cat() on an empty list.
            raise ValueError(f"chunk_size must be a positive integer or None, got {chunk_size}")

        num_frames = hq_image.shape[2]
        if chunk_size is None or chunk_size >= num_frames:
            return self._correct_clip(hq_image, lq_image, clip_range, method)

        # Slicing past the end is clamped by the tensor, so the last chunk is
        # simply shorter.
        corrected = [
            self._correct_clip(
                hq_image[:, :, start:start + chunk_size],
                lq_image[:, :, start:start + chunk_size],
                clip_range,
                method,
            )
            for start in range(0, num_frames, chunk_size)
        ]
        return torch.cat(corrected, dim=2)
149
+
150
+
151
+ # -----------------------------
152
+ # 简化版 Pipeline(仅 dit + vae)
153
+ # -----------------------------
154
class FlashVSRTinyLongPipeline(BasePipeline):
    """Simplified streaming pipeline that drives only a DiT and a decoder.

    The pipeline expects a ``TCDecoder`` attribute to be attached by the
    caller before ``__call__`` / ``decode_video`` / ``offload_model`` are
    used; it is not created in ``__init__``.
    """

    def __init__(self, device="cuda", torch_dtype=torch.float16):
        super().__init__(device=device, torch_dtype=torch_dtype)
        self.scheduler = FlowMatchScheduler(shift=5, sigma_min=0.0, extra_one_step=True)
        self.dit: WanModel = None
        self.vae: WanVideoVAE = None
        self.model_names = ['dit', 'vae']
        self.height_division_factor = 16
        self.width_division_factor = 16
        self.use_unified_sequence_parallel = False
        # Filled by init_cross_kv(): {'context': tensor, 'stats': "load" | "offload"}.
        self.prompt_emb_posi = None
        self.ColorCorrector = TorchColorCorrectorWavelet(levels=5)

        print(r"""
███████╗██╗ █████╗ ███████╗██╗ ██╗██╗ ██╗███████╗█████╗
██╔════╝██║ ██╔══██╗██╔════╝██║ ██║██║ ██║██╔════╝██╔══██╗ ██╗
█████╗ ██║ ███████║███████╗███████║╚██╗ ██╔╝███████╗███████║ ██████╗
██╔══╝ ██║ ██╔══██║╚════██║██╔══██║ ╚████╔╝ ╚════██║██╔═██║ ██╔═╝
██║ ███████╗██║ ██║███████║██║ ██║ ╚██╔╝ ███████║██║ ██║ ╚═╝
╚═╝ ╚══════╝╚═╝ ╚═╝╚══════╝╚═╝ ╚═╝ ╚═╝ ╚══════╝╚═╝ ╚═╝
""")

    def enable_vram_management(self, num_persistent_param_in_dit=None):
        """Wrap the DiT's layers for CPU offloading, then enable CPU offload.

        Args:
            num_persistent_param_in_dit: Parameter budget forwarded as
                ``max_num_param``; layers beyond the budget get the overflow
                config, whose on-load device is the CPU.
        """
        # Only the DiT is wrapped here.
        dtype = next(iter(self.dit.parameters())).dtype
        from ..vram_management import enable_vram_management, AutoWrappedModule, AutoWrappedLinear
        enable_vram_management(
            self.dit,
            module_map={
                torch.nn.Linear: AutoWrappedLinear,
                torch.nn.Conv3d: AutoWrappedModule,
                torch.nn.LayerNorm: AutoWrappedModule,
                RMSNorm: AutoWrappedModule,
            },
            module_config=dict(
                offload_dtype=dtype,
                offload_device="cpu",
                onload_dtype=dtype,
                onload_device=self.device,
                computation_dtype=self.torch_dtype,
                computation_device=self.device,
            ),
            max_num_param=num_persistent_param_in_dit,
            overflow_module_config=dict(
                offload_dtype=dtype,
                offload_device="cpu",
                onload_dtype=dtype,
                onload_device="cpu",
                computation_dtype=self.torch_dtype,
                computation_device=self.device,
            ),
        )
        self.enable_cpu_offload()

    def fetch_models(self, model_manager: ModelManager):
        """Pull the ``wan_video_dit`` and ``wan_video_vae`` models from the manager."""
        self.dit = model_manager.fetch_model("wan_video_dit")
        self.vae = model_manager.fetch_model("wan_video_vae")

    @staticmethod
    def from_model_manager(model_manager: ModelManager, torch_dtype=None, device=None, use_usp=False):
        """Create a pipeline whose models come from ``model_manager``.

        ``device`` and ``torch_dtype`` default to the manager's values.
        ``use_usp`` is accepted for interface compatibility, but unified
        sequence parallelism is always left disabled here.
        """
        if device is None:
            device = model_manager.device
        if torch_dtype is None:
            torch_dtype = model_manager.torch_dtype
        pipe = FlashVSRTinyLongPipeline(device=device, torch_dtype=torch_dtype)
        pipe.fetch_models(model_manager)
        # Optional unified-sequence-parallel entry point (disabled here).
        pipe.use_unified_sequence_parallel = False
        return pipe

    def denoising_model(self):
        """Return the DiT used for denoising."""
        return self.dit

    # -------------------------
    # Explicit cross-attention KV pre-initialization
    # -------------------------
    def init_cross_kv(
        self,
        context_tensor: Optional[torch.Tensor] = None,
        prompt_path=None,
    ):
        """Initialize the DiT's cross-attention KV cache from a fixed context.

        Must be called once before ``__call__``. Also caches the timestep
        embedding / modulation for timestep 1000 and configures the scheduler
        for a single step.

        Args:
            context_tensor: Pre-computed text context. Takes precedence over
                ``prompt_path``.
            prompt_path: Path to a saved context tensor, loaded with
                ``torch.load`` when ``context_tensor`` is ``None``.

        Raises:
            RuntimeError: If ``self.dit`` has not been set.
            ValueError: If neither ``context_tensor`` nor ``prompt_path`` is given.
            AttributeError: If the DiT has no ``reinit_cross_kv`` method.
        """
        # Validate before touching devices so a bad call has no side effects.
        if self.dit is None:
            raise RuntimeError("请先通过 fetch_models / from_model_manager 初始化 self.dit")

        if context_tensor is None:
            if prompt_path is None:
                raise ValueError("init_cross_kv: 需要提供 prompt_path 或 context_tensor 其一")
            # NOTE(security): torch.load unpickles the file; only pass paths to
            # trusted prompt tensors (consider weights_only=True where the
            # installed torch supports it).
            ctx = torch.load(prompt_path, map_location=self.device)
        else:
            ctx = context_tensor

        self.load_models_to_device(["dit"])
        ctx = ctx.to(dtype=self.torch_dtype, device=self.device)

        if self.prompt_emb_posi is None:
            self.prompt_emb_posi = {}
        self.prompt_emb_posi['context'] = ctx
        self.prompt_emb_posi['stats'] = "load"

        if hasattr(self.dit, "reinit_cross_kv"):
            self.dit.reinit_cross_kv(ctx)
        else:
            raise AttributeError("WanModel 缺少 reinit_cross_kv(ctx) 方法,请在模型实现中加入该能力。")
        self.timestep = torch.tensor([1000.], device=self.device, dtype=self.torch_dtype)
        self.t = self.dit.time_embedding(sinusoidal_embedding_1d(self.dit.freq_dim, self.timestep))
        self.t_mod = self.dit.time_projection(self.t).unflatten(1, (6, self.dit.dim))
        # Scheduler
        self.scheduler.set_timesteps(1, denoising_strength=1.0, shift=5.0)
        self.load_models_to_device([])

    def prepare_unified_sequence_parallel(self):
        """Return the sequence-parallel flag as a kwargs dict."""
        return {"use_unified_sequence_parallel": self.use_unified_sequence_parallel}

    def prepare_extra_input(self, latents=None):
        """Return extra model inputs; this pipeline has none."""
        return {}

    def encode_video(self, input_video, tiled=True, tile_size=(34, 34), tile_stride=(18, 16)):
        """Encode ``input_video`` to latents with the VAE."""
        latents = self.vae.encode(input_video, device=self.device, tiled=tiled, tile_size=tile_size, tile_stride=tile_stride)
        return latents

    def _decode_video(self, latents, tiled=True, tile_size=(34, 34), tile_stride=(18, 16)):
        """Decode ``latents`` to frames with the VAE."""
        frames = self.vae.decode(latents, device=self.device, tiled=tiled, tile_size=tile_size, tile_stride=tile_stride)
        return frames

    def decode_video(self, latents, cond=None, **kwargs):
        """Decode ``latents`` with the TCDecoder and rescale the result to [-1, 1].

        The latents are transposed to (B, F, C, H, W) for the decoder and the
        output is transposed back to (B, C, F, H, W). Extra ``kwargs`` are
        ignored.
        """
        frames = self.TCDecoder.decode_video(
            latents.transpose(1, 2),  # TCDecoder expects (B, F, C, H, W)
            parallel=False,
            show_progress_bar=False,
            cond=cond
        ).transpose(1, 2).mul_(2).sub_(1)  # back to (B, C, F, H, W), range -1 to 1

        return frames

    def offload_model(self, keep_vae=False):
        """Release the DiT's cross-attention cache and move models off the device.

        Args:
            keep_vae: When true, the TCDecoder is left where it is.
        """
        self.dit.clear_cross_kv()
        # init_cross_kv() may not have run yet; there is then no state to mark.
        if self.prompt_emb_posi is not None:
            self.prompt_emb_posi['stats'] = "offload"
        self.load_models_to_device([])
        if hasattr(self.dit, "LQ_proj_in"):
            self.dit.LQ_proj_in.to('cpu')
        if not keep_vae:
            self.TCDecoder.to('cpu')

    def _stream_lq_latents(self, LQ_video, frame_windows):
        """Project LQ frame windows and concatenate the per-layer results.

        Each ``(start, stop)`` window of ``LQ_video`` (frame axis 2) is fed to
        ``LQ_proj_in.stream_forward``. Windows for which it returns ``None``
        are skipped; the remaining per-layer outputs are concatenated along
        dim 1. Returns ``None`` if every window produced ``None``.
        """
        projector = self.denoising_model().LQ_proj_in
        merged = None
        for start, stop in frame_windows:
            cur = projector.stream_forward(LQ_video[:, :, start:stop, :, :].to(self.device))
            if cur is None:
                continue
            if merged is None:
                merged = cur
            else:
                for layer_idx, previous in enumerate(merged):
                    merged[layer_idx] = torch.cat([previous, cur[layer_idx]], dim=1)
        return merged

    @torch.no_grad()
    def __call__(
        self,
        prompt=None,
        negative_prompt="",
        denoising_strength=1.0,
        seed=None,
        rand_device="gpu",
        height=480,
        width=832,
        num_frames=81,
        cfg_scale=5.0,
        num_inference_steps=50,
        sigma_shift=5.0,
        tiled=True,
        tile_size=(60, 104),
        tile_stride=(30, 52),
        tea_cache_l1_thresh=None,
        tea_cache_model_id="Wan2.1-T2V-1.3B",
        progress_bar_cmd=tqdm,
        progress_bar_st=None,
        LQ_video=None,
        is_full_block=False,
        if_buffer=False,
        topk_ratio=2.0,
        kv_ratio=3.0,
        local_range = 9,
        color_fix = True,
        unload_dit = False,
        force_offload = False,
    ):
        """Run streaming one-step restoration over ``LQ_video``.

        The video is processed in chunks: the first chunk covers six latent
        frames, every later chunk two. Each chunk is denoised in a single
        step, decoded by the TCDecoder conditioned on the matching LQ frames,
        optionally color-corrected, and moved to the CPU.

        Args:
            seed: Seed for the initial noise.
            height, width: Output size; adjusted by ``check_resize_height_width``.
            num_frames: Number of frames; rounded to satisfy ``n % 4 == 1``.
            cfg_scale: Must be ``1.0``.
            progress_bar_cmd: Callable wrapping the chunk iterator.
            LQ_video: Low-quality input indexed as (B, C, F, H, W). Required.
            is_full_block, topk_ratio, kv_ratio, local_range: Forwarded to
                ``model_fn_wan_video``.
            if_buffer: When true, one fewer latent frame of noise is generated.
            color_fix: Apply best-effort color correction against the LQ frames.
            unload_dit: Free per-chunk tensors and call ``clean_vram()`` after
                every chunk.
            force_offload: Call ``offload_model()`` before returning.

        The remaining parameters (``prompt``, ``negative_prompt``,
        ``denoising_strength``, ``rand_device``, ``num_inference_steps``,
        ``sigma_shift``, ``tiled``, ``tile_size``, ``tile_stride``,
        ``tea_cache_*``, ``progress_bar_st``) are accepted for interface
        compatibility and not used by this method.

        Returns:
            The restored frames of the first batch element, concatenated
            along the frame axis, on the CPU.

        Raises:
            AssertionError: If ``cfg_scale != 1.0``.
            RuntimeError: If ``init_cross_kv()`` has not been called.
            ValueError: If ``LQ_video`` is ``None`` or ``num_frames`` is too
                small to form a single chunk.
        """
        # Only cfg=1.0 is supported.
        assert cfg_scale == 1.0, "cfg_scale must be 1.0"

        # init_cross_kv() must have run first.
        if self.prompt_emb_posi is None or 'context' not in self.prompt_emb_posi:
            raise RuntimeError(
                "Cross-Attn KV 未初始化。请在调用 __call__ 前先执行:\n"
                "    pipe.init_cross_kv()\n"
                "或传入自定义 context:\n"
                "    pipe.init_cross_kv(context_tensor=your_context_tensor)"
            )

        # The decode step always indexes LQ_video, so None can never succeed;
        # say so up front instead of failing with a TypeError mid-run.
        if LQ_video is None:
            raise ValueError("LQ_video is required: it conditions both the DiT and the decoder.")

        # Size adjustment
        height, width = self.check_resize_height_width(height, width)
        if num_frames % 4 != 1:
            num_frames = (num_frames + 2) // 4 * 4 + 1
            print(f"Only `num_frames % 4 == 1` is acceptable. We round it up to {num_frames}.")

        process_total_num = (num_frames - 1) // 8 - 2
        if process_total_num < 1:
            # With no chunk to process, the final torch.cat() would receive an
            # empty list.
            raise ValueError(
                f"num_frames={num_frames} is too small: at least 25 frames are needed to form one chunk."
            )

        # Initial noise
        latent_frames = (num_frames - 1) // 4 if if_buffer else (num_frames - 1) // 4 + 1
        latents = self.generate_noise(
            (1, 16, latent_frames, height//8, width//8),
            seed=seed, device=self.device, dtype=self.torch_dtype,
        )

        is_stream = True

        if self.prompt_emb_posi['stats'] == "offload":
            self.init_cross_kv(context_tensor=self.prompt_emb_posi['context'])
        # init_cross_kv() ends by unloading every model, so bring the DiT and
        # the auxiliary modules back unconditionally; this is a no-op when
        # they are already resident.
        self.load_models_to_device(["dit"])
        self.dit.LQ_proj_in.to(self.device)
        self.TCDecoder.to(self.device)

        # Drop any cache left in LQ_proj_in by a previous run.
        if hasattr(self.dit, "LQ_proj_in"):
            self.dit.LQ_proj_in.clear_cache()

        frames_total = []
        LQ_pre_idx = 0
        LQ_cur_idx = 0
        self.TCDecoder.clean_mem()

        pre_cache_k = [None] * len(self.dit.blocks)
        pre_cache_v = [None] * len(self.dit.blocks)
        color_fix_warned = False

        for cur_process_idx in progress_bar_cmd(range(process_total_num)):
            if cur_process_idx == 0:
                # First chunk: a 1-frame window followed by six 4-frame windows.
                inner_loop_num = 7
                windows = [
                    (max(0, inner_idx*4-3), (inner_idx+1)*4-3)
                    for inner_idx in range(inner_loop_num)
                ]
                LQ_latents = self._stream_lq_latents(LQ_video, windows)
                LQ_cur_idx = (inner_loop_num-1)*4-3
                cur_latents = latents[:, :, :6, :, :]
            else:
                # Later chunks: two consecutive 4-frame windows.
                inner_loop_num = 2
                windows = [
                    (cur_process_idx*8+17+inner_idx*4, cur_process_idx*8+21+inner_idx*4)
                    for inner_idx in range(inner_loop_num)
                ]
                LQ_latents = self._stream_lq_latents(LQ_video, windows)
                LQ_cur_idx = cur_process_idx*8+21+(inner_loop_num-2)*4
                cur_latents = latents[:, :, 4+cur_process_idx*2:6+cur_process_idx*2, :, :]

            # Inference (no motion_controller / vace)
            noise_pred_posi, pre_cache_k, pre_cache_v = model_fn_wan_video(
                self.dit,
                x=cur_latents,
                timestep=self.timestep,
                context=None,
                tea_cache=None,
                use_unified_sequence_parallel=False,
                LQ_latents=LQ_latents,
                is_full_block=is_full_block,
                is_stream=is_stream,
                pre_cache_k=pre_cache_k,
                pre_cache_v=pre_cache_v,
                topk_ratio=topk_ratio,
                kv_ratio=kv_ratio,
                cur_process_idx=cur_process_idx,
                t_mod=self.t_mod,
                t=self.t,
                local_range = local_range,
            )

            # Single-step latent update
            cur_latents = cur_latents - noise_pred_posi

            # Decode
            cur_LQ_frame = LQ_video[:,:,LQ_pre_idx:LQ_cur_idx,:,:].to(self.device)
            cur_frames = self.TCDecoder.decode_video(
                cur_latents.transpose(1, 2),
                parallel=False,
                show_progress_bar=False,
                cond=cur_LQ_frame).transpose(1, 2).mul_(2).sub_(1)

            # Color correction stays best-effort (a failure keeps the
            # uncorrected frames), but it no longer swallows
            # KeyboardInterrupt/SystemExit and reports the first failure.
            if color_fix:
                try:
                    cur_frames = self.ColorCorrector(
                        cur_frames.to(device=self.device),
                        cur_LQ_frame,
                        clip_range=(-1, 1),
                        chunk_size=None,
                        method='adain'
                    )
                except Exception as exc:
                    if not color_fix_warned:
                        print(f"[FlashVSR] color correction skipped: {exc!r}")
                        color_fix_warned = True

            frames_total.append(cur_frames.to('cpu'))
            LQ_pre_idx = LQ_cur_idx

            if unload_dit:
                del noise_pred_posi, cur_frames, cur_latents, cur_LQ_frame
                clean_vram()

        if hasattr(self.dit, "LQ_proj_in"):
            self.dit.LQ_proj_in.clear_cache()

        self.TCDecoder.clean_mem()
        if force_offload:
            self.offload_model()

        frames = torch.cat(frames_total, dim=2)

        return frames[0]
479
+
480
+
481
+ # -----------------------------
482
+ # TeaCache(保留原逻辑;此处默认不启用)
483
+ # -----------------------------
484
class TeaCache:
    """Decide when the transformer blocks can be skipped and replay a cached residual.

    Usage per step: call ``check``; if it returns ``True`` the caller applies
    ``update`` instead of running the blocks, otherwise it runs the blocks
    and passes their output to ``store``.
    """

    def __init__(self, num_inference_steps, rel_l1_thresh, model_id):
        """Create a cache.

        Args:
            num_inference_steps: Number of steps in one run; the internal
                step counter wraps around at this value.
            rel_l1_thresh: Threshold on the accumulated rescaled distance
                below which a step is skipped.
            model_id: Key selecting the polynomial coefficients.

        Raises:
            ValueError: If ``model_id`` has no coefficients.
        """
        self.num_inference_steps = num_inference_steps
        self.step = 0
        self.accumulated_rel_l1_distance = 0
        self.previous_modulated_input = None
        self.rel_l1_thresh = rel_l1_thresh
        self.previous_residual = None
        self.previous_hidden_states = None

        self.coefficients_dict = {
            "Wan2.1-T2V-1.3B": [-5.21862437e+04, 9.23041404e+03, -5.28275948e+02, 1.36987616e+01, -4.99875664e-02],
            "Wan2.1-T2V-14B": [-3.03318725e+05, 4.90537029e+04, -2.65530556e+03, 5.87365115e+01, -3.15583525e-01],
            "Wan2.1-I2V-14B-480P": [2.57151496e+05, -3.54229917e+04, 1.40286849e+03, -1.35890334e+01, 1.32517977e-01],
            "Wan2.1-I2V-14B-720P": [8.10705460e+03, 2.13393892e+03, -3.72934672e+02, 1.66203073e+01, -4.17769401e-02],
        }
        if model_id not in self.coefficients_dict:
            supported_model_ids = ", ".join(self.coefficients_dict)
            raise ValueError(f"{model_id} is not a supported TeaCache model id. Please choose a valid model id in ({supported_model_ids}).")
        self.coefficients = self.coefficients_dict[model_id]

    def check(self, dit: "WanModel", x, t_mod):
        """Advance one step and report whether the blocks can be skipped.

        The first and last step of a run are always computed. In between,
        the relative L1 change of ``t_mod`` versus the previous step is
        rescaled by the model's polynomial and accumulated; the step is
        skipped while the accumulated value stays below the threshold.

        Args:
            dit: Unused.
            x: Hidden states before the blocks; cloned when the step is computed.
            t_mod: Modulation input compared between consecutive steps.

        Returns:
            ``True`` if the caller should use ``update`` instead of running
            the blocks, ``False`` otherwise.
        """
        modulated_inp = t_mod.clone()
        if self.step == 0 or self.step == self.num_inference_steps - 1:
            should_calc = True
            self.accumulated_rel_l1_distance = 0
        else:
            rescale_func = np.poly1d(self.coefficients)
            previous = self.previous_modulated_input
            rel_l1 = ((modulated_inp - previous).abs().mean() / previous.abs().mean()).cpu().item()
            self.accumulated_rel_l1_distance += rescale_func(rel_l1)
            # Written as "not (<)" so a NaN distance forces a recompute.
            should_calc = not (self.accumulated_rel_l1_distance < self.rel_l1_thresh)
            if should_calc:
                self.accumulated_rel_l1_distance = 0
        self.previous_modulated_input = modulated_inp
        self.step = (self.step + 1) % self.num_inference_steps
        if should_calc:
            self.previous_hidden_states = x.clone()
        return not should_calc

    def store(self, hidden_states):
        """Record the residual produced by the blocks for the current step.

        Raises:
            RuntimeError: If no pre-block hidden states are pending, i.e.
                ``check`` did not just request a computation.
        """
        if self.previous_hidden_states is None:
            # Previously surfaced as an opaque TypeError (tensor - None).
            raise RuntimeError("TeaCache.store() called without a preceding check() that requested computation.")
        self.previous_residual = hidden_states - self.previous_hidden_states
        self.previous_hidden_states = None

    def update(self, hidden_states):
        """Return ``hidden_states`` plus the last stored residual.

        Raises:
            RuntimeError: If no residual has been stored yet.
        """
        if self.previous_residual is None:
            # Previously surfaced as an opaque TypeError (tensor + None).
            raise RuntimeError("TeaCache.update() called before any residual was stored via store().")
        hidden_states = hidden_states + self.previous_residual
        return hidden_states
530
+
531
+
532
+ # -----------------------------
533
+ # 简化版模型前向封装(无 vace / 无 motion_controller)
534
+ # -----------------------------
535
def model_fn_wan_video(
    dit: "WanModel",
    x: torch.Tensor,
    timestep: torch.Tensor,
    context: torch.Tensor,
    tea_cache: Optional["TeaCache"] = None,
    use_unified_sequence_parallel: bool = False,
    LQ_latents: Optional[torch.Tensor] = None,
    is_full_block: bool = False,
    is_stream: bool = False,
    pre_cache_k: Optional[list[torch.Tensor]] = None,
    pre_cache_v: Optional[list[torch.Tensor]] = None,
    topk_ratio: float = 2.0,
    kv_ratio: float = 3.0,
    cur_process_idx: int = 0,
    t_mod : torch.Tensor = None,
    t : torch.Tensor = None,
    local_range: int = 9,
    **kwargs,
):
    """Run one forward pass of the DiT on a chunk of latents.

    The latents are patchified, given rotary position frequencies for the
    chunk's position in the stream, passed through every block (with
    per-block KV caches and optional per-block LQ conditioning), then through
    the head and unpatchified.

    Args:
        dit: Model providing ``patchify``, ``freqs``, ``blocks``, ``head``
            and ``unpatchify``.
        x: Latents for this chunk.
        timestep: Unused here; ``t`` and ``t_mod`` carry the time conditioning.
        context: Forwarded to every block.
        tea_cache: Optional cache object; when its ``check`` returns true the
            blocks are skipped and ``update`` is applied instead.
        use_unified_sequence_parallel: Split the sequence across ranks with
            xfuser when a multi-process group is initialized.
        LQ_latents: Optional per-block tensors added to ``x`` before the
            block with the same index.
        is_full_block, is_stream, local_range: Forwarded to every block.
        pre_cache_k, pre_cache_v: Optional per-block cache lists; updated in
            place with the values each block returns.
        topk_ratio: Scales the ``topk`` value handed to the blocks.
        kv_ratio: Truncated to an int and passed as ``kv_len``.
        cur_process_idx: Index of the chunk in the stream; selects the
            temporal RoPE offset.
        t_mod: Time modulation passed to the blocks (and to ``tea_cache``).
        t: Time embedding passed to the head.
        **kwargs: Ignored.

    Returns:
        ``(output, pre_cache_k, pre_cache_v)``.
    """
    # patchify
    x, (f, h, w) = dit.patchify(x)

    win = (2, 8, 8)
    local_num = f // win[0]
    window_size = win[0] * h * w // 128
    topk = int(window_size * window_size * topk_ratio) - 1
    kv_len = int(kv_ratio)

    # RoPE positions: the first chunk starts at temporal index 0, chunk i > 0
    # at 4 + 2 * i. Height and width always start at 0.
    frame_start = 0 if cur_process_idx == 0 else 4 + cur_process_idx * 2
    freqs = torch.cat([
        dit.freqs[0][frame_start:frame_start + f].view(f, 1, 1, -1).expand(f, h, w, -1),
        dit.freqs[1][:h].view(1, h, 1, -1).expand(f, h, w, -1),
        dit.freqs[2][:w].view(1, 1, w, -1).expand(f, h, w, -1)
    ], dim=-1).reshape(f * h * w, 1, -1).to(x.device)

    # TeaCache (disabled by default)
    tea_cache_update = tea_cache.check(dit, x, t_mod) if tea_cache is not None else False

    # Unified sequence parallelism (disabled by default)
    if use_unified_sequence_parallel:
        import torch.distributed as dist
        from xfuser.core.distributed import (get_sequence_parallel_rank,
                                             get_sequence_parallel_world_size)
        if dist.is_initialized() and dist.get_world_size() > 1:
            x = torch.chunk(x, get_sequence_parallel_world_size(), dim=1)[get_sequence_parallel_rank()]

    # Block stack
    if tea_cache_update:
        x = tea_cache.update(x)
    else:
        for block_id, block in enumerate(dit.blocks):
            if LQ_latents is not None and block_id < len(LQ_latents):
                x = x + LQ_latents[block_id]
            x, last_pre_cache_k, last_pre_cache_v = block(
                x, context, t_mod, freqs, f, h, w,
                local_num, topk,
                block_id=block_id,
                kv_len=kv_len,
                is_full_block=is_full_block,
                is_stream=is_stream,
                pre_cache_k=pre_cache_k[block_id] if pre_cache_k is not None else None,
                pre_cache_v=pre_cache_v[block_id] if pre_cache_v is not None else None,
                local_range = local_range,
            )
            if pre_cache_k is not None:
                pre_cache_k[block_id] = last_pre_cache_k
            if pre_cache_v is not None:
                pre_cache_v[block_id] = last_pre_cache_v
        # Record the residual of a computed step; without this the next
        # skipped step has nothing to replay in tea_cache.update().
        if tea_cache is not None:
            tea_cache.store(x)

    x = dit.head(x, t)
    if use_unified_sequence_parallel:
        import torch.distributed as dist
        from xfuser.core.distributed import get_sp_group
        if dist.is_initialized() and dist.get_world_size() > 1:
            x = get_sp_group().all_gather(x, dim=1)
    x = dit.unpatchify(x, (f, h, w))
    return x, pre_cache_k, pre_cache_v
custom_nodes/ComfyUI-FlashVSR_Ultra_Fast/src/schedulers/__init__.py ADDED
@@ -0,0 +1 @@
 
 
1
+ from .flow_match import FlowMatchScheduler
custom_nodes/ComfyUI-FlashVSR_Ultra_Fast/src/schedulers/flow_match.py ADDED
@@ -0,0 +1,79 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+
3
+
4
+
5
class FlowMatchScheduler():
    """Sigma / timestep schedule for flow matching.

    ``sigmas`` holds the noise levels of the schedule and ``timesteps`` the
    same values scaled by ``num_train_timesteps``. All lookups by timestep
    use the schedule entry closest to the given value.
    """

    def __init__(self, num_inference_steps=100, num_train_timesteps=1000, shift=3.0, sigma_max=1.0, sigma_min=0.003/1.002, inverse_timesteps=False, extra_one_step=False, reverse_sigmas=False):
        self.num_train_timesteps = num_train_timesteps
        self.shift = shift
        self.sigma_max = sigma_max
        self.sigma_min = sigma_min
        self.inverse_timesteps = inverse_timesteps
        self.extra_one_step = extra_one_step
        self.reverse_sigmas = reverse_sigmas
        self.set_timesteps(num_inference_steps)

    def _nearest(self, timestep):
        """Return ``(index, sigma)`` of the schedule entry closest to ``timestep``."""
        if isinstance(timestep, torch.Tensor):
            timestep = timestep.cpu()
        index = torch.argmin((self.timesteps - timestep).abs())
        return index, self.sigmas[index]

    def set_timesteps(self, num_inference_steps=100, denoising_strength=1.0, training=False, shift=None):
        """Rebuild ``sigmas`` and ``timesteps``.

        Args:
            num_inference_steps: Number of schedule entries.
            denoising_strength: Fraction of the ``sigma_min``..``sigma_max``
                range the schedule starts from.
            training: Also compute ``linear_timesteps_weights``.
            shift: If given, replaces ``self.shift`` before the schedule is built.
        """
        if shift is not None:
            self.shift = shift
        sigma_start = self.sigma_min + (self.sigma_max - self.sigma_min) * denoising_strength
        if self.extra_one_step:
            # One extra point is generated and the final one dropped.
            sigmas = torch.linspace(sigma_start, self.sigma_min, num_inference_steps + 1)[:-1]
        else:
            sigmas = torch.linspace(sigma_start, self.sigma_min, num_inference_steps)
        if self.inverse_timesteps:
            sigmas = torch.flip(sigmas, dims=[0])
        sigmas = self.shift * sigmas / (1 + (self.shift - 1) * sigmas)
        if self.reverse_sigmas:
            sigmas = 1 - sigmas
        self.sigmas = sigmas
        self.timesteps = self.sigmas * self.num_train_timesteps
        if training:
            centered = (self.timesteps - num_inference_steps / 2) / num_inference_steps
            bell = torch.exp(-2 * centered ** 2)
            bell_shifted = bell - bell.min()
            self.linear_timesteps_weights = bell_shifted * (num_inference_steps / bell_shifted.sum())

    def step(self, model_output, timestep, sample, to_final=False, **kwargs):
        """Advance ``sample`` from ``timestep`` to the next schedule entry.

        When ``to_final`` is set or ``timestep`` is the last entry, the target
        sigma is 1 for inverse/reversed schedules and 0 otherwise.
        """
        index, sigma = self._nearest(timestep)
        if to_final or index + 1 >= len(self.timesteps):
            next_sigma = 1 if (self.inverse_timesteps or self.reverse_sigmas) else 0
        else:
            next_sigma = self.sigmas[index + 1]
        return sample + model_output * (next_sigma - sigma)

    def return_to_timestep(self, timestep, sample, sample_stablized):
        """Return the model output implied by ``sample`` and ``sample_stablized`` at ``timestep``."""
        _, sigma = self._nearest(timestep)
        return (sample - sample_stablized) / sigma

    def add_noise(self, original_samples, noise, timestep):
        """Blend ``original_samples`` and ``noise`` with the sigma at ``timestep``."""
        _, sigma = self._nearest(timestep)
        return (1 - sigma) * original_samples + sigma * noise

    def training_target(self, sample, noise, timestep):
        """Return the training target ``noise - sample``."""
        return noise - sample

    def training_weight(self, timestep):
        """Return the loss weight of the schedule entry closest to ``timestep``.

        Requires a prior ``set_timesteps(..., training=True)``.
        """
        index = torch.argmin((self.timesteps - timestep.to(self.timesteps.device)).abs())
        return self.linear_timesteps_weights[index]
custom_nodes/ComfyUI-FlashVSR_Ultra_Fast/src/vram_management/__init__.py ADDED
@@ -0,0 +1 @@
 
 
1
+ from .layers import *
custom_nodes/ComfyUI-FlashVSR_Ultra_Fast/src/vram_management/layers.py ADDED
@@ -0,0 +1,95 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch, copy
2
+ from ..models.utils import init_weights_on_device
3
+
4
+
5
def cast_to(weight, dtype, device):
    """Return a copy of ``weight`` with the given ``dtype`` on ``device``.

    A new tensor is always allocated, even when ``weight`` already has the
    requested dtype and device.
    """
    # copy_() returns the destination tensor it filled.
    return torch.empty_like(weight, dtype=dtype, device=device).copy_(weight)
9
+
10
+
11
class AutoWrappedModule(torch.nn.Module):
    """Wrap a module with separate offload, on-load and computation placements.

    ``state`` is 0 while the wrapped module sits at its offload placement and
    1 after ``onload()`` moved it. ``forward`` runs the wrapped module
    directly when the on-load placement equals the computation placement;
    otherwise it runs a converted deep copy made for that call.
    """

    def __init__(self, module: torch.nn.Module, offload_dtype, offload_device, onload_dtype, onload_device, computation_dtype, computation_device):
        super().__init__()
        # The wrapped module starts at the offload placement.
        self.module = module.to(dtype=offload_dtype, device=offload_device)
        self.offload_dtype = offload_dtype
        self.offload_device = offload_device
        self.onload_dtype = onload_dtype
        self.onload_device = onload_device
        self.computation_dtype = computation_dtype
        self.computation_device = computation_device
        self.state = 0

    def offload(self):
        """Move the module back to the offload placement if it is currently on-loaded."""
        if self.state != 1:
            return
        if self.offload_dtype == self.onload_dtype and self.offload_device == self.onload_device:
            return  # Both placements coincide; nothing to move.
        self.module.to(dtype=self.offload_dtype, device=self.offload_device)
        self.state = 0

    def onload(self):
        """Move the module to the on-load placement if it is currently offloaded."""
        if self.state != 0:
            return
        if self.offload_dtype == self.onload_dtype and self.offload_device == self.onload_device:
            return  # Both placements coincide; nothing to move.
        self.module.to(dtype=self.onload_dtype, device=self.onload_device)
        self.state = 1

    def forward(self, *args, **kwargs):
        """Call the wrapped module at the computation dtype/device."""
        runs_in_place = (
            self.onload_dtype == self.computation_dtype
            and self.onload_device == self.computation_device
        )
        if runs_in_place:
            return self.module(*args, **kwargs)
        # Convert a throwaway copy so the stored module keeps its placement.
        converted = copy.deepcopy(self.module).to(dtype=self.computation_dtype, device=self.computation_device)
        return converted(*args, **kwargs)
39
+
40
+
41
class AutoWrappedLinear(torch.nn.Linear):
    """``nn.Linear`` that reuses an existing layer's parameters with managed placement.

    ``state`` is 0 while the layer sits at its offload placement and 1 after
    ``onload()`` moved it. ``forward`` uses the stored parameters directly
    when the on-load placement equals the computation placement; otherwise it
    casts fresh copies of weight and bias for that call.
    """

    def __init__(self, module: torch.nn.Linear, offload_dtype, offload_device, onload_dtype, onload_device, computation_dtype, computation_device):
        # Build the Linear shell under the meta-device context; the real
        # parameters are taken from ``module`` right after.
        # NOTE(review): presumably this avoids allocating throwaway weights;
        # confirm against init_weights_on_device.
        with init_weights_on_device(device=torch.device("meta")):
            super().__init__(in_features=module.in_features, out_features=module.out_features, bias=module.bias is not None, dtype=offload_dtype, device=offload_device)
        self.weight = module.weight
        self.bias = module.bias
        self.offload_dtype = offload_dtype
        self.offload_device = offload_device
        self.onload_dtype = onload_dtype
        self.onload_device = onload_device
        self.computation_dtype = computation_dtype
        self.computation_device = computation_device
        self.state = 0

    def offload(self):
        """Move the layer back to the offload placement if it is currently on-loaded."""
        if self.state != 1:
            return
        if self.offload_dtype == self.onload_dtype and self.offload_device == self.onload_device:
            return  # Both placements coincide; nothing to move.
        self.to(dtype=self.offload_dtype, device=self.offload_device)
        self.state = 0

    def onload(self):
        """Move the layer to the on-load placement if it is currently offloaded."""
        if self.state != 0:
            return
        if self.offload_dtype == self.onload_dtype and self.offload_device == self.onload_device:
            return  # Both placements coincide; nothing to move.
        self.to(dtype=self.onload_dtype, device=self.onload_device)
        self.state = 1

    def forward(self, x, *args, **kwargs):
        """Apply the linear map at the computation dtype/device.

        Extra positional and keyword arguments are accepted and ignored.
        """
        runs_in_place = (
            self.onload_dtype == self.computation_dtype
            and self.onload_device == self.computation_device
        )
        if runs_in_place:
            return torch.nn.functional.linear(x, self.weight, self.bias)
        weight = cast_to(self.weight, self.computation_dtype, self.computation_device)
        bias = cast_to(self.bias, self.computation_dtype, self.computation_device) if self.bias is not None else None
        return torch.nn.functional.linear(x, weight, bias)
72
+
73
+
74
def enable_vram_management_recursively(model: torch.nn.Module, module_map: dict, module_config: dict, max_num_param=None, overflow_module_config: dict = None, total_num_param=0):
    """Replace matching submodules of ``model`` with their wrapper classes.

    Each direct child is compared against the source types in ``module_map``
    in order. The first match is replaced by
    ``target(child, **config)``; children without a match are searched
    recursively.

    Args:
        model: Module whose children are rewritten in place.
        module_map: Mapping from source module type to wrapper class.
        module_config: Keyword arguments for wrappers within the budget.
        max_num_param: Optional parameter budget. A child that would push the
            running total past it is built with ``overflow_module_config``.
        overflow_module_config: Keyword arguments for over-budget wrappers.
        total_num_param: Running parameter count carried through the recursion.

    Returns:
        The running parameter count after processing ``model``; wrapped
        children count toward it whether or not they were over budget.
    """
    no_match = object()
    for child_name, child in model.named_children():
        wrapper_cls = next(
            (target for source, target in module_map.items() if isinstance(child, source)),
            no_match,
        )
        if wrapper_cls is no_match:
            total_num_param = enable_vram_management_recursively(
                child, module_map, module_config, max_num_param, overflow_module_config, total_num_param
            )
            continue
        child_params = sum(p.numel() for p in child.parameters())
        over_budget = max_num_param is not None and total_num_param + child_params > max_num_param
        chosen_config = overflow_module_config if over_budget else module_config
        setattr(model, child_name, wrapper_cls(child, **chosen_config))
        total_num_param += child_params
    return total_num_param
90
+
91
+
92
def enable_vram_management(model: torch.nn.Module, module_map: dict, module_config: dict, max_num_param=None, overflow_module_config: dict = None):
    """Wrap ``model``'s submodules for VRAM management and flag the model.

    Runs the recursive wrapping starting from a parameter count of zero, then
    sets ``model.vram_management_enabled = True``.
    """
    enable_vram_management_recursively(
        model,
        module_map,
        module_config,
        max_num_param,
        overflow_module_config,
        total_num_param=0,
    )
    model.vram_management_enabled = True
95
+
custom_nodes/ComfyUI-LCS/.gitignore ADDED
@@ -0,0 +1,2 @@
 
 
 
1
+ .claude/
2
+ __pycache__/
custom_nodes/ComfyUI-LCS/README.md ADDED
@@ -0,0 +1,344 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # ComfyUI-LCS
2
+
3
+ Training-free color control via the **Latent Color Subspace**, plus **sharpness control** via a discovered sharpness subspace.
4
+
5
+ > **Note:** This is an unofficial community implementation. For the official code, see [ExplainableML/LCS](https://github.com/ExplainableML/LCS).
6
+
7
+ Based on ["The Latent Color Subspace"](https://arxiv.org/abs/2603.12261v1) (ICML 2026): color in diffusion model latent patch spaces lives in a **3D subspace** (PCA captures 100% color variance), while the remaining 61 dimensions encode structure and detail orthogonally.
8
+
9
+ This plugin steers colors directly in the 3D LCS during diffusion sampling — no training, no LoRA, no post-processing.
10
+
11
+ > [中文版 README](README_zh.md)
12
+
13
+ ## LCS vs Traditional Post-Processing
14
+
15
+ LCS operates **during** diffusion sampling, not after — this is the key difference from traditional color grading (Photoshop, filters, etc.).
16
+
17
+ | | Traditional Post-Processing | LCS |
18
+ |---|---|---|
19
+ | **When** | After VAE decode, in pixel space | During sampling, in latent space |
20
+ | **Mechanism** | Color filter on the final image | Modifies 3D color subspace mid-generation |
21
+ | **Model awareness** | None — structure already locked | Model adapts to color shifts in subsequent steps |
22
+ | **Result** | Colors can look "painted on" | Colors look naturally intended by the model |
23
+
24
+ For example: to get a warm orange sunset, post-processing tints everything orange (muddying shadows and skin tones), while LCS nudges the color subspace early in sampling so clouds, lighting, and reflections are *coherently* warm.
25
+
26
+ The core insight: color and structure are **orthogonal** in the latent patch space — you can steer one without disturbing the other.
27
+
28
+ ## Tested Models
29
+
30
+ | Model | Status |
31
+ |-------|--------|
32
+ | FLUX | Tested |
33
+ | FLUX2.klein | Tested |
34
+ | z-image | Tested |
35
+ | z-image-turbo | Tested |
36
+ | Wan (qwen-image) | Tested |
37
+ | LTX2.3 | Tested |
38
+
39
+
40
+ LCS calibrates per-VAE, so it should work with any model using a compatible VAE. Feel free to report results with other models.
41
+
42
+ ## Features
43
+
44
+ - **Color Steering** — Push colors toward any target color
45
+ - **Batch Multi-Color** — Different colors per batch item
46
+ - **Tone Adjustment** — Contrast, brightness, saturation, temperature with one-click presets
47
+ - **Color Anchor** — Zero-config color drift correction: self-anchor, reference-based, or spatial smoothing with auto mode
48
+ - **Sharpness Control** — Sharpen or blur during generation via a discovered sharpness subspace (PC1 explains ~97% variance)
49
+ - **Localized Control** — Optional mask for region-specific changes
50
+ - **Latent Color Preview** — Visualize color structure without VAE decoding
51
+ - **Step Observer** — Per-step color previews for debugging
52
+
53
+ ## Installation
54
+
55
+ ```bash
56
+ cd ComfyUI/custom_nodes
57
+ git clone https://github.com/facok/ComfyUI-LCS.git
58
+ ```
59
+
60
+ Dependencies (usually already present in ComfyUI):
61
+
62
+ ```bash
63
+ pip install einops safetensors
64
+ ```
65
+
66
+ ## Quick Start
67
+
68
+ ### Basic Color Control
69
+
70
+ ```
71
+ LCS Load Data → LCS Color Intervene → KSampler
72
+
73
+ (pick a color)
74
+ ```
75
+
76
+ 1. **LCS Load Data** — connect your VAE (auto-calibrates on first run)
77
+ 2. **LCS Color Intervene** — connect MODEL and LCS_DATA, pick a target color
78
+ 3. Connect the output MODEL to KSampler
79
+
80
+ ### Tone Adjustment
81
+
82
+ ```
83
+ LCS Load Data → LCS Tone Adjust → KSampler
84
+ ```
85
+
86
+ 1. **LCS Load Data** → **LCS Tone Adjust**
87
+ 2. Select a preset (e.g., "Cinematic") or adjust sliders manually
88
+
89
+ ![3d3c82eb0e89ed1608e40ac7a8cc3408](https://github.com/user-attachments/assets/62868e2d-0275-4801-a9bd-606bfea3ce2f)
90
+ ![42541357](https://github.com/user-attachments/assets/fe22f09e-98ac-4281-ae40-f58232c7700f)
91
+
92
+ ### Sharpness Control
93
+
94
+ ```
95
+ LCS Load Data ──→ LCS Sharpness Calibrate → LCS Sharpness Intervene → KSampler
96
+ ↑ lcs_data
97
+ ```
98
+
99
+ 1. **LCS Sharpness Calibrate** — connect VAE (auto-calibrates and caches). Optionally connect `lcs_data` from LCS Load Data to ensure sharpness edits don't affect color.
100
+ 2. **LCS Sharpness Intervene** — connect MODEL and SHARPNESS_DATA, set strength
101
+ - Positive strength → sharper
102
+ - Negative strength → blurrier
103
+ - 0 → no change
104
+ ![89814728](https://github.com/user-attachments/assets/62f036e9-0bea-4cc0-9220-af4c2fb8fa76)
105
+ ### Multi-Color Batch
106
+
107
+ ```
108
+ LCS Load Data → LCS Color Batch → KSampler
109
+
110
+ batch_size → EmptyLatentImage
111
+ ```
112
+
113
+ Enter comma-separated hex colors (e.g., `#FF0000,#00FF00,#0000FF`). Each color applies to one batch item.
114
+
115
+ ### Color Anchor (Zero-Config Drift Correction)
116
+
117
+ ```
118
+ LCS Load Data → LCS Color Anchor → KSampler
119
+ ```
120
+
121
+ 1. **LCS Load Data** → **LCS Color Anchor** — connect MODEL and LCS_DATA
122
+ 2. Set mode to **auto** (default) and leave intensity at default
123
+ 3. Connect the output MODEL to KSampler
124
+
125
+ That's it. In `auto` mode, the node automatically selects the correction strategy based on which optional inputs are connected:
126
+
127
+ | Connected Inputs | Resolved Mode | Behavior |
128
+ |---|---|---|
129
+ | Nothing | self_anchor | Learns the image's color patterns early on, then prevents sudden color shifts |
130
+ | reference_image + vae | reference | Keeps generated colors close to your reference image |
131
+ | mask (no reference) | smooth | Smooths out color seams (great for inpainting) |
132
+
133
+ Intensity is also derived automatically from measured drift — no manual tuning needed.
134
+
135
+ > **When to use manual mode:** If you want full control, set mode to `smooth`, `reference`, or `self_anchor` explicitly and adjust the `intensity` slider (0–1). Auto mode is designed for zero-config "just works" usage.
136
+
137
+ ## Nodes
138
+
139
+ ### Calibration
140
+
141
+ | Node | Description |
142
+ |------|-------------|
143
+ | **LCS Load Data** | Auto-calibrate and cache LCS color data per-VAE. Fingerprints VAE weights for automatic cache management. |
144
+ | **LCS Sharpness Calibrate** | Discover sharpness subspace via PCA on blur stimuli. Optionally connect `lcs_data` for color-orthogonal sharpness. |
145
+
146
+ Calibration runs once per VAE and caches automatically. Subsequent runs load instantly.
147
+
148
+ ### Intervention
149
+
150
+ | Node | Description |
151
+ |------|-------------|
152
+ | **LCS Color Intervene** | Steer colors toward a target. Supports Type I (LCS shift), Type II (HSL shift), or interpolated mode. |
153
+ | **LCS Color Batch** | Different target colors per batch item. Outputs `batch_size` for EmptyLatentImage. |
154
+ | **LCS Tone Adjust** | Contrast, brightness, saturation, temperature. Preset dropdown with real-time slider sync. |
155
+ | **LCS Color Anchor** | Correct color drift during sampling. Auto mode infers strategy and intensity from connected inputs. |
156
+ | **LCS Sharpness Intervene** | Control sharpness during generation. Positive = sharper, negative = blurrier. |
157
+
158
+ ### Observation
159
+
160
+ | Node | Description |
161
+ |------|-------------|
162
+ | **LCS Preview Colors** | Decode latent colors to RGB preview without VAE decoding. |
163
+ | **LCS Step Observer** | Save per-step color preview PNGs to ComfyUI temp directory. |
164
+
165
+ ## Intervention Modes
166
+
167
+ | Mode | Description | Best For |
168
+ |------|-------------|----------|
169
+ | **interpolated** (default) | Blends Type I and Type II using sigma | General use |
170
+ | **type_i** | Direct translation in 3D LCS space | Strong global color shifts |
171
+ | **type_ii** | Per-patch HSL interpolation via bicone geometry | Precise local color control |
172
+
173
+ ## Key Parameters
174
+
175
+ ### Color Intervention
176
+ - **strength** (0.0–2.0): Intervention intensity. 1.0 = full, 0.0 = none.
177
+ - **start_step / end_step**: Step range for intervention. Paper optimal: steps 8–10 of 50.
178
+ - **mask**: Optional. Downsampled to patch grid for localized control.
179
+
180
+ ### Sharpness Intervention
181
+ - **strength** (-5.0–5.0): Positive = sharper, negative = blurrier, 0 = no change.
182
+ - **start_step / end_step**: Step range (default 5–15).
183
+ - **mask**: Optional. Localized sharpness control.
184
+
185
+ > **Tip for distilled models**: Step-distilled models (e.g., z-image-turbo) use far fewer steps, so intervention should start earlier — even from step 0.
186
+
187
+ ### Color Anchor
188
+
189
+ Sometimes diffusion models produce unexpected color shifts during sampling — a blue sky suddenly turns purple, or inpainting leaves visible color seams. The Color Anchor node fixes these problems by monitoring and correcting colors as the image is being generated.
190
+
191
+ **Modes:**
192
+
193
+ | Mode | What it does | When to use |
194
+ |------|-------------|----------|
195
+ | **auto** (default) | Looks at what you connected and picks the best strategy for you | Just want it to work, no config needed |
196
+ | **self_anchor** | Watches how colors evolve in early steps, then prevents sudden color jumps in later steps | General color stability, no reference needed |
197
+ | **reference** | Keeps the generated image's colors close to a reference image you provide | "Make it look like this photo's color palette" |
198
+ | **smooth** | Smooths out abrupt color boundaries between regions | Fixing visible seams after inpainting |
199
+
200
+ **How auto mode picks for you:**
201
+
202
+ 1. **Which strategy?** Based on what you plugged in:
203
+ - Connected a reference image + VAE → uses `reference`
204
+ - Connected a mask (but no reference) → uses `smooth`
205
+ - Connected nothing extra → uses `self_anchor`
206
+ 2. **How strong?** The node measures how much color drift is actually happening, then sets the correction strength accordingly. Big drift → stronger fix. Small drift → gentle touch. The strength stays within 0.15–0.6, so the node neither over-corrects nor ends up having no effect.
207
+
208
+ **What happens during sampling:**
209
+
210
+ The node runs at every sampling step but doesn't always intervene. It automatically figures out which steps are safe to correct:
211
+
212
+ 1. **Early steps** (image is mostly noise) — Too early to fix colors without creating artifacts. Skipped. In self_anchor mode, the node uses these steps to *learn* the image's color patterns.
213
+ 2. **Middle steps** (image is taking shape) — The sweet spot. The node applies corrections here, ramping smoothly in and out to avoid sudden changes.
214
+ 3. **Late steps** (fine details) — Corrections would disturb fine detail. Skipped.
215
+
216
+ Only colors are modified — structure, texture, and detail are never touched.
217
+
218
+ **Parameters:**
219
+
220
+ - **mode**: `auto`, `smooth`, `reference`, or `self_anchor`
221
+ - **intensity** (0.0–1.0): How strong the correction is. In `auto` mode this is determined automatically. Set to 0 to disable the node entirely.
222
+ - **vae** (optional): Needed for `reference` mode to encode the reference image
223
+ - **reference_image** (optional): The image whose colors you want to match
224
+ - **mask** (optional): Only correct colors inside the masked area
225
+
226
+ ## Tone Presets
227
+
228
+ Select a preset — sliders update in real-time. Tweak after selecting for fine-tuning. Select **Custom** to set values manually.
229
+
230
+ | Preset | Contrast | Brightness | Saturation | Temperature |
231
+ |--------|----------|------------|------------|-------------|
232
+ | Base | 1.0 | 0.0 | 1.0 | 0.0 |
233
+ | Cinematic | 1.20 | -0.05 | 0.90 | 0.05 |
234
+ | HDR | 1.40 | 0.0 | 1.20 | 0.0 |
235
+ | Vivid | 1.10 | 0.0 | 1.50 | 0.0 |
236
+ | Dramatic | 1.50 | -0.10 | 0.85 | 0.0 |
237
+ | Low Key | 1.30 | -0.20 | 0.80 | 0.0 |
238
+ | High Key | 0.80 | 0.20 | 0.90 | 0.0 |
239
+ | Warm | 1.05 | 0.03 | 1.10 | 0.30 |
240
+ | Cool | 1.05 | 0.0 | 1.05 | -0.30 |
241
+ | Desaturated | 1.0 | 0.0 | 0.40 | 0.0 |
242
+
243
+ ## How It Works
244
+
245
+ ### Color (LCS)
246
+
247
+ 1. **Project** — Convert denoised prediction to 64D patch space, project onto 3D LCS basis
248
+ 2. **Decompose** — Separate 3D color coordinates from the 61D structural residual
249
+ 3. **Normalize** — Transform to reference timestep (t=50) using learned alpha/beta statistics
250
+ 4. **Manipulate** — Shift colors, adjust tone, or apply other transformations in 3D LCS
251
+ 5. **Reconstruct** — Denormalize, add back the preserved 61D residual, convert to latent space
252
+
253
+ The 61D residual (structure, texture, detail) is never modified — only the 3D color subspace is touched.
254
+
255
+ ### Sharpness
256
+
257
+ Sharpness lives in a separate subspace orthogonal to color:
258
+
259
+ 1. **Calibrate** — Generate grayscale noise images at multiple blur levels, VAE-encode them, and run PCA on the color-removed patch vectors. PC1 captures ~97% of sharpness variance.
260
+ 2. **Intervene** — Add `strength * pc1_direction` to each patch. Since pc1_direction is orthogonal to color (calibrated with LCS removal) and DC-free (per-vector zero-mean before PCA), this modifies only spatial frequency content without affecting color or brightness.
261
+
262
+ ### Color Anchor
263
+
264
+ The Color Anchor stabilizes colors without pushing them toward a specific target — it prevents drift from what the model is already generating:
265
+
266
+ 1. **Decide when to act** — The node checks each sampling step: is the image still mostly noise (too early), taking shape (good time to correct), or nearly finished (too late)? It only corrects during the safe middle window.
267
+ 2. **Learn the color pattern** (self_anchor) — During early noisy steps, the node watches how colors relate to their neighbors and builds a running average of these relationships. This is more reliable than tracking absolute colors, which shift naturally as the image forms.
268
+ 3. **Measure drift** — On the first correction step, the node measures how much the colors have actually drifted (varies by mode: step-to-step jumps, distance from reference, or spatial roughness). This sets the correction strength in auto mode.
269
+ 4. **Apply gentle corrections** — Corrections ramp smoothly in and out (no sudden jumps). Each mode corrects differently: self_anchor fixes patches that deviate from learned patterns, reference pulls toward the reference image's colors, smooth blurs out sharp color boundaries.
270
+ 5. **Preserve everything else** — As with all LCS operations, only the 3D color coordinates change. Structure, texture, and detail are untouched.
271
+
272
+ ## File Structure
273
+
274
+ ```
275
+ ComfyUI-LCS/
276
+ ├── __init__.py # Entry point (V3 + V2 compat)
277
+ ├── requirements.txt
278
+ ├── core/
279
+ │ ├── adaptive.py # Adaptive scheduling (phases, envelopes, drift estimation)
280
+ │ ├── bilateral.py # Bilateral filter for LCS color smoothing
281
+ │ ├── calibration.py # PCA calibration pipeline (color)
282
+ │ ├── color_space.py # Bicone LCS ↔ HSL mapping
283
+ │ ├── defaults.py # Alpha/beta tables from paper
284
+ │ ├── lcs_data.py # LCSData dataclass
285
+ │ ├── patchify.py # Patch ↔ latent conversion
286
+ │ ├── relationships.py # Local color relationship analysis & anomaly detection
287
+ │ ├── sampling.py # Shared constants & step utilities
288
+ │ ├── sharpness.py # Sharpness subspace calibration
289
+ │ └── timestep.py # Sigma/timestep utilities
290
+ ├── nodes/
291
+ │ ├── anchor.py # LCSColorAnchor (adaptive color drift correction)
292
+ │ ├── calibrate.py # LCSLoadData (auto-calibrate + cache)
293
+ │ ├── intervene.py # LCSColorIntervene, LCSColorBatch, LCSToneAdjust
294
+ │ ├── observe.py # LCSPreviewColors, LCSStepObserver
295
+ │ └── sharpen.py # LCSSharpnessCalibrate, LCSSharpnessIntervene
296
+ ├── data/ # Cached calibration files
297
+ └── web/js/
298
+ └── tone_preset.js # Frontend preset sync
299
+ ```
300
+
301
+ ## Changelog
302
+
303
+ ### 2026-03-21
304
+ - **Color Anchor: auto mode** — New `auto` mode that infers correction strategy (self_anchor / reference / smooth) from connected inputs and derives intensity from measured drift. Zero-config usage.
305
+ - **Color Anchor: adaptive scheduling** — Phase assignment (observe/correct/skip) and strength envelope are derived from the sigma schedule at runtime.
306
+
307
+ ### 2026-03-20
308
+ - **Sharpness Control** — New sharpness subspace discovered via PCA on blur stimuli. `LCS Sharpness Calibrate` + `LCS Sharpness Intervene` nodes. PC1 explains ~97% variance, orthogonal to color.
309
+ - **Color-orthogonal sharpness** — Optional `lcs_data` input removes color component during sharpness calibration, preventing color shift.
310
+
311
+ ### 2026-03-19
312
+ - **Video VAE support (Wan)** — Handle 5D video latents in patchify/unpatchify. Per-image VAE encoding fallback for video VAEs.
313
+ - **LTXV compatibility** — Pad odd spatial dims in patchify, handle 3D tensors, skip gracefully for incompatible formats.
314
+ - **FLUX2 support** — Auto-detect 128-channel latents in unpatchify.
315
+ - **Universal latent format** — Use model's `latent_format` for space conversion instead of hardcoded FLUX constants.
316
+
317
+ ### 2026-03-18
318
+ - **Tone Adjust** — `LCS Tone Adjust` node with contrast, brightness, saturation, temperature sliders. 10 presets with frontend real-time sync.
319
+ - **Color temperature** — Warm/cool shift along LCS blue-yellow axis.
320
+ - **Bicone HSL geometry** — Correct Type II intervention via bicone LCS-to-HSL mapping.
321
+
322
+ ### 2026-03-17
323
+ - **Initial release** — Color steering (Type I + Type II + interpolated), batch multi-color, localized mask control, latent color preview, step observer. Per-VAE auto-calibration with caching.
324
+
325
+ ## Citation
326
+
327
+ Official repository: [ExplainableML/LCS](https://github.com/ExplainableML/LCS)
328
+
329
+ ```bibtex
330
+ @article{pach2026latentcolorsubspace,
331
+ title={The Latent Color Subspace: Emergent Order in High-Dimensional Chaos},
332
+ author={Mateusz Pach and Jessica Bader and Quentin Bouniot and Serge Belongie and Zeynep Akata},
333
+ journal={arxiv},
334
+ year={2026}
335
+ }
336
+ ```
337
+
338
+ ## Acknowledgments
339
+
340
+ Thanks to Mateusz Pach, Jessica Bader, Quentin Bouniot, Serge Belongie, and Zeynep Akata for their research making training-free color control possible.
341
+
342
+ ## License
343
+
344
+ MIT
custom_nodes/ComfyUI-LCS/README_zh.md ADDED
@@ -0,0 +1,343 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # ComfyUI-LCS
2
+
3
+ 基于**潜在颜色子空间**(Latent Color Subspace)的免训练颜色控制,以及基于发现的**锐度子空间**的锐度控制。
4
+
5
+ > **注意:** 本项目为非官方社区实现。官方代码见 [ExplainableML/LCS](https://github.com/ExplainableML/LCS)。
6
+
7
+ 基于论文 ["The Latent Color Subspace"](https://arxiv.org/abs/2603.12261v1)(ICML 2026):扩散模型潜在 patch 空间中的颜色完全存在于一个 **3 维子空间**(PCA 捕获 100% 颜色方差),剩余 61 维编码结构与细节,与颜色正交。
8
+
9
+ 本插件在扩散采样过程中直接操作 3D LCS 控制颜色——无需训练、无需 LoRA、无需后处理。
10
+
11
+ > [English README](README.md)
12
+
13
+ ## LCS 与传统后处理调色的区别
14
+
15
+ LCS 在扩散采样**过程中**操作,而非生成之后——这是与传统调色(Photoshop、滤镜等)的根本区别。
16
+
17
+ | | 传统后处理 | LCS |
18
+ |---|---|---|
19
+ | **时机** | VAE 解码后,像素空间 | 采样过程中,潜在空间 |
20
+ | **机制** | 对成品图像施加颜色滤镜 | 在生成中途修改 3D 颜色子空间 |
21
+ | **模型感知** | 无——结构已定型 | 模型在后续步骤中自适应颜色偏移 |
22
+ | **效果** | 颜色容易显得"涂上去的" | 颜色与内容自然融合 |
23
+
24
+ 例:想要暖橙色日落,后处理会给全图叠橙色(阴影和肤色变脏),而 LCS 在采样早期推动颜色子空间,模型生成的云层、光照、反射与暖色调**内在一致**。
25
+
26
+ 核心发现:颜色与结构在潜在 patch 空间中**正交**——可以单独控制颜色而不干扰结构。
27
+
28
+ ## 已测试模型
29
+
30
+ | 模型 | 状态 |
31
+ |------|------|
32
+ | FLUX | 已测试 |
33
+ | FLUX2.klein | 已测试 |
34
+ | z-image | 已测试 |
35
+ | z-image-turbo | 已测试 |
36
+ | Wan (qwen-image) | 已测试 |
37
+ | LTX2.3 | 已测试 |
38
+
39
+ LCS 按 VAE 校准,理论上适用于任何使用兼容 VAE 架构的模型。欢迎反馈其他模型的测试结果。
40
+
41
+ ## 功能
42
+
43
+ - **颜色引导** — 将颜色推向任意目标色
44
+ - **批量多色** — 为批次中每张图像指定不同颜色
45
+ - **色调调整** — 对比度、亮度、饱和度、色温,支持一键预设
46
+ - **颜色锚定** — 零配置颜色漂移校正:自锚定、参考图锚定、空间平滑,支持全自动模式
47
+ - **锐度控制** — 在生成过程中增强或减弱锐度,基于发现的锐度子空间(PC1 解释 ~97% 方差)
48
+ - **局部控制** — 可选遮罩,实现区域性变化
49
+ - **潜在颜色预览** — 无需 VAE 解码即可可视化颜色结构
50
+ - **步骤观察器** — 保存每步颜色预览,用于调试
51
+
52
+ ## 安装
53
+
54
+ ```bash
55
+ cd ComfyUI/custom_nodes
56
+ git clone https://github.com/facok/ComfyUI-LCS.git
57
+ ```
58
+
59
+ 依赖(通常 ComfyUI 已自带):
60
+
61
+ ```bash
62
+ pip install einops safetensors
63
+ ```
64
+
65
+ ## 快速开始
66
+
67
+ ### 基本颜色控制
68
+
69
+ ```
70
+ LCS Load Data → LCS Color Intervene → KSampler
71
+
72
+ (选择颜色)
73
+ ```
74
+
75
+ 1. **LCS Load Data** — 连接 VAE(首次运行自动校准)
76
+ 2. **LCS Color Intervene** — 连接 MODEL 和 LCS_DATA,选择目标颜色
77
+ 3. 将输出 MODEL 连接到 KSampler
78
+
79
+ ### 色调调整
80
+
81
+ ```
82
+ LCS Load Data → LCS Tone Adjust → KSampler
83
+ ```
84
+
85
+ 1. **LCS Load Data** → **LCS Tone Adjust**
86
+ 2. 选择预设(如 "Cinematic")或手动调整滑条
87
+
88
+ ![3d3c82eb0e89ed1608e40ac7a8cc3408](https://github.com/user-attachments/assets/62868e2d-0275-4801-a9bd-606bfea3ce2f)
89
+ ![42541357](https://github.com/user-attachments/assets/fe22f09e-98ac-4281-ae40-f58232c7700f)
90
+ ### 锐度控制
91
+
92
+ ```
93
+ LCS Load Data ──→ LCS Sharpness Calibrate → LCS Sharpness Intervene → KSampler
94
+ ↑ lcs_data
95
+ ```
96
+
97
+ 1. **LCS Sharpness Calibrate** — 连接 VAE(首次运行自动校准并缓存)。可选连接 `lcs_data`(来自 LCS Load Data),确保锐度编辑不影响颜色。
98
+ 2. **LCS Sharpness Intervene** — 连接 MODEL 和 SHARPNESS_DATA,设置强度
99
+ - 正值 → 更锐利
100
+ - 负值 → 更模糊
101
+ - 0 → 无变化
102
+ ![89814728](https://github.com/user-attachments/assets/62f036e9-0bea-4cc0-9220-af4c2fb8fa76)
103
+
104
+ ### 批量多色生成
105
+
106
+ ```
107
+ LCS Load Data → LCS Color Batch → KSampler
108
+
109
+ batch_size → EmptyLatentImage
110
+ ```
111
+
112
+ 输入逗号分隔的十六进制颜色(如 `#FF0000,#00FF00,#0000FF`),每个颜色对应一个批次项。
113
+
114
+ ### 颜色锚定(零配置漂移校正)
115
+
116
+ ```
117
+ LCS Load Data → LCS Color Anchor → KSampler
118
+ ```
119
+
120
+ 1. **LCS Load Data** → **LCS Color Anchor** — 连接 MODEL 和 LCS_DATA
121
+ 2. 模式设为 **auto**(默认),intensity 保持默认值
122
+ 3. 将输出 MODEL 连接到 KSampler
123
+
124
+ 完成。在 `auto` 模式下,节点根据连接的可选输入自动选择校正策略:
125
+
126
+ | 已连接输入 | 解析模式 | 行为 |
127
+ |---|---|---|
128
+ | 无 | self_anchor | 在早期学习图像的颜色规律,然后防止突然的颜色偏移 |
129
+ | reference_image + vae | reference | 让生成的颜色贴近你的参考图 |
130
+ | mask(无参考图) | smooth | 平滑颜色接缝(很适合修复/补绘) |
131
+
132
+ intensity 也会根据实测漂移自动推导——无需手动调参。
133
+
134
+ > **手动模式:** 如果需要完全控制,可以将模式设为 `smooth`、`reference` 或 `self_anchor`,并手动调节 `intensity` 滑条(0–1)。auto 模式适合零配置「开箱即用」场景。
135
+
136
+ ## 节点一览
137
+
138
+ ### 校准
139
+
140
+ | 节点 | 说明 |
141
+ |------|------|
142
+ | **LCS Load Data** | 自动校准并按 VAE 缓存 LCS 颜色数据。通过 VAE 权重指纹自动管理缓存。 |
143
+ | **LCS Sharpness Calibrate** | 通过模糊刺激 PCA 发现锐度子空间。可选连接 `lcs_data` 使锐度正交于颜色。 |
144
+
145
+ 每个 VAE 只需校准一次,结果自动缓存,后续运行瞬时加载。
146
+
147
+ ### 干预
148
+
149
+ | 节点 | 说明 |
150
+ |------|------|
151
+ | **LCS Color Intervene** | 将颜色引导至目标色。支持 Type I(LCS 平移)、Type II(HSL 偏移)或插值模式。 |
152
+ | **LCS Color Batch** | 每个批次项施加不同目标颜色。输出 `batch_size` 可连接 EmptyLatentImage。 |
153
+ | **LCS Tone Adjust** | 对比度、亮度、饱和度、色温调整。预设下拉菜单,滑条实时同步。 |
154
+ | **LCS Color Anchor** | 采样过程中校正颜色漂移。auto 模式根据连接输入自动推断策略和强度。 |
155
+ | **LCS Sharpness Intervene** | 在生成过程中控制锐度。正值 = 更锐利,负值 = 更模糊。 |
156
+
157
+ ### 观察
158
+
159
+ | 节点 | 说明 |
160
+ |------|------|
161
+ | **LCS Preview Colors** | 将潜在颜色解码为 RGB 预览图,无需 VAE 解码。 |
162
+ | **LCS Step Observer** | 将每步颜色预览 PNG 保存至 ComfyUI 临时目录。 |
163
+
164
+ ## 干预模式
165
+
166
+ | 模式 | 说明 | 适用场景 |
167
+ |------|------|----------|
168
+ | **interpolated**(默认) | 以 sigma 为权重混合 Type I 和 Type II | 通用场景 |
169
+ | **type_i** | 3D LCS 空间中的直接平移 | 强烈的全局颜色偏移 |
170
+ | **type_ii** | 通过双锥几何进行逐 patch HSL 插值 | 精确的局部颜色控制 |
171
+
172
+ ## 关键参数
173
+
174
+ ### 颜色干预
175
+ - **strength**(0.0–2.0):干预强度。1.0 = 完整干预,0.0 = 无干预。
176
+ - **start_step / end_step**:干预步骤范围。论文最优:50 步中的第 8–10 步。
177
+ - **mask**:可选。下采样至 patch 网格分辨率,用于局部控制。
178
+
179
+ ### 锐度干预
180
+ - **strength**(-5.0–5.0):正值 = 更锐利,负值 = 更模糊,0 = 无变化。
181
+ - **start_step / end_step**:干预步骤范围(默认 5–15)。
182
+ - **mask**:可选。用于局部锐度控制。
183
+
184
+ > **步数蒸馏模型提示**:对于步数蒸馏模型(如 z-image-turbo),总步数很少,干预应从更早的步骤开始——甚至可以从第 0 步就开始干预。
185
+
186
+ ### 颜色锚定
187
+
188
+ 扩散模型在采样过程中有时会出现意想不到的颜色偏移——蓝天突然变紫,或者修复/补绘后留下明显的颜色接缝。颜色锚定节点在图像生成过程中监控和修正这些问题。
189
+
190
+ **模式:**
191
+
192
+ | 模式 | 功能 | 适用场景 |
193
+ |------|------|----------|
194
+ | **auto**(默认) | 根据你连接的输入自动选最合适的策略 | 不想调参,开箱即用 |
195
+ | **self_anchor** | 在早期步骤观察颜色变化规律,在后续步骤防止突然的颜色跳变 | 通用颜色稳定,不需要参考图 |
196
+ | **reference** | 让生成图像的颜色贴近你提供的参考图 | 「我想要这张照片的配色风格」 |
197
+ | **smooth** | 平滑区域之间的突兀颜色边界 | 修复/补绘后消除接缝 |
198
+
199
+ **auto 模式如何自动选择:**
200
+
201
+ 1. **用哪种策略?** 看你连了什么:
202
+ - 连了参考图 + VAE → 用 `reference`
203
+ - 连了遮罩(没有参考图)→ 用 `smooth`
204
+ - 什么额外输入都没连 → 用 `self_anchor`
205
+ 2. **修正多强?** 节点会测量实际的颜色漂移幅度,据此自动设置校正强度。漂移大 → 修正更强;漂移小 → 轻轻一碰。范围是 0.15–0.6,既不会矫枉过正,也不会毫无作用。
206
+
207
+ **采样过程中发生了什么:**
208
+
209
+ 节点在每个采样步都会运行,但不会每步都干预。它自动判断哪些步骤适合校正:
210
+
211
+ 1. **早期步骤**(图像基本是噪声)— 太早修正颜色会产生伪影,跳过。在 self_anchor 模式下,节点利用这些步骤*学习*图像的颜色规律。
212
+ 2. **中间步骤**(图像逐渐成形)— 最佳校正时机。节点在这里施加校正,平滑地渐入渐出,避免突变。
213
+ 3. **后期步骤**(精细细节)— 校正会干扰细节,跳过。
214
+
215
+ 只修改颜色——结构、纹理、细节始终不受影响。
216
+
217
+ **参数:**
218
+
219
+ - **mode**:`auto`、`smooth`、`reference` 或 `self_anchor`
220
+ - **intensity**(0.0–1.0):校正强度。auto 模式下自动决定。设为 0 可完全禁用此节点。
221
+ - **vae**(可选):reference 模式需要用它来编码参考图
222
+ - **reference_image**(可选):你想匹配其颜色的参考图
223
+ - **mask**(可选):只在遮罩区域内校正颜色
224
+
225
+ ## 色调预设
226
+
227
+ 选择预设后滑条实时更新。可在预设基础上微调。选择 **Custom** 可完全手动设置。
228
+
229
+ | 预设 | 对比度 | 亮度 | 饱和度 | 色温 |
230
+ |------|--------|------|--------|------|
231
+ | Base | 1.0 | 0.0 | 1.0 | 0.0 |
232
+ | Cinematic | 1.20 | -0.05 | 0.90 | 0.05 |
233
+ | HDR | 1.40 | 0.0 | 1.20 | 0.0 |
234
+ | Vivid | 1.10 | 0.0 | 1.50 | 0.0 |
235
+ | Dramatic | 1.50 | -0.10 | 0.85 | 0.0 |
236
+ | Low Key | 1.30 | -0.20 | 0.80 | 0.0 |
237
+ | High Key | 0.80 | 0.20 | 0.90 | 0.0 |
238
+ | Warm | 1.05 | 0.03 | 1.10 | 0.30 |
239
+ | Cool | 1.05 | 0.0 | 1.05 | -0.30 |
240
+ | Desaturated | 1.0 | 0.0 | 0.40 | 0.0 |
241
+
242
+ ## 工作原理
243
+
244
+ ### 颜色(LCS)
245
+
246
+ 1. **投影** — 将去噪预测转换到 64D patch 空间,投影到 3D LCS 基底
247
+ 2. **分解** — 将 3D 颜色坐标与 61D 结构残差分离
248
+ 3. **归一化** — 使用学习的 alpha/beta 统计量变换至参考时间步(t=50)
249
+ 4. **操作** — 在 3D LCS 中偏移颜色、调整色调或进行其他变换
250
+ 5. **重建** — 反归一化,加回保留的 61D 残差,转换回潜在空间
251
+
252
+ 61D 残差(结构、纹理、细节)始终不被修改——只有 3D 颜色子空间会被改变。
253
+
254
+ ### 锐度
255
+
256
+ 锐度存在于与颜色正交的独立子空间中:
257
+
258
+ 1. **校准** — 生成灰度噪声图像,应用多级高斯模糊,VAE 编码后对去除颜色分量的 patch 向量做 PCA。PC1 捕获 ~97% 的锐度方差。
259
+ 2. **干预** — 在每个 patch 上沿 `strength * pc1_direction` 方向添加偏移。由于 pc1_direction 与颜色正交(校准时已移除 LCS 分量)且无直流分量(PCA 前做了逐向量零均值化),因此只改变空间频率内容,不影响颜色或亮度。
260
+
261
+ ### 颜色锚定
262
+
263
+ 颜色锚定的作用是稳定颜色,而不是把颜色推向某个特定目标——它防止模型已经在生成的颜色发生偏移:
264
+
265
+ 1. **判断何时介入** — 节点检查每个采样步:图像还是一片噪声(太早)、正在成形(适合校正)、还是快完成了(太晚)?只在安全的中间窗口进行校正。
266
+ 2. **学习颜色规律**(self_anchor)— 在早期噪声较大的步骤中,节点观察每个区域的颜色与邻居之间的关系,建立一个动态平均值。比起追踪绝对颜色值,这种「相对关系」更可靠,因为绝对颜色在图像成形过程中本来就会自然变化。
267
+ 3. **测量漂移** — 在第一个校正步,节点测量颜色实际漂移了多少(根据模式不同:步间跳变幅度、与参考图的差距、或空间粗糙程度)。这决定了 auto 模式下的校正强度。
268
+ 4. **温和地修正** — 校正平滑地渐入渐出(不会突变)。每种模式的修正方式不同:self_anchor 修复偏离已学规律的区域,reference 拉近与参考图的颜色,smooth 模糊掉尖锐的颜色边界。
269
+ 5. **保留其他一切** — 与所有 LCS 操作一样,只修改 3D 颜色坐标,结构、纹理、细节完全不受影响。
270
+
271
+ ## 文件结构
272
+
273
+ ```
274
+ ComfyUI-LCS/
275
+ ├── __init__.py # 入口(V3 + V2 兼容)
276
+ ├── requirements.txt
277
+ ├── core/
278
+ │ ├── adaptive.py # 自适应调度(阶段、包络、漂移估计)
279
+ │ ├── bilateral.py # LCS 颜色平滑的双边滤波
280
+ │ ├── calibration.py # PCA 校准流程(颜色)
281
+ │ ├── color_space.py # 双锥 LCS ↔ HSL 映射
282
+ │ ├── defaults.py # 论文中的 Alpha/beta 表
283
+ │ ├── lcs_data.py # LCSData 数据类
284
+ │ ├── patchify.py # Patch ↔ 潜在空间转换
285
+ │ ├── relationships.py # 局部颜色关系分析与异常检测
286
+ │ ├── sampling.py # 共享常量和步骤工具
287
+ │ ├── sharpness.py # 锐度子空间校准
288
+ │ └── timestep.py # Sigma/时间步工具
289
+ ├── nodes/
290
+ │ ├── anchor.py # LCSColorAnchor(自适应颜色漂移校正)
291
+ │ ├── calibrate.py # LCSLoadData(自动校准 + 缓存)
292
+ │ ├── intervene.py # LCSColorIntervene, LCSColorBatch, LCSToneAdjust
293
+ │ ├── observe.py # LCSPreviewColors, LCSStepObserver
294
+ │ └── sharpen.py # LCSSharpnessCalibrate, LCSSharpnessIntervene
295
+ ├── data/ # 缓存的校准文件
296
+ └── web/js/
297
+ └── tone_preset.js # 前端预设同步
298
+ ```
299
+
300
+ ## 更新日志
301
+
302
+ ### 2026-03-21
303
+ - **颜色锚定:auto 模式** — 新增 `auto` 模式,根据连接的输入自动推断校正策略(self_anchor / reference / smooth),并根据实测漂移推导强度。零配置使用。
304
+ - **颜色锚定:自适应调度** — 阶段分配(observe/correct/skip)和强度包络在运行时从 sigma 调度表推导。
305
+
306
+ ### 2026-03-20
307
+ - **锐度控制** — 通过模糊刺激 PCA 发现锐度子空间。新增 `LCS Sharpness Calibrate` + `LCS Sharpness Intervene` 节点。PC1 解释 ~97% 方差,与颜色正交。
308
+ - **颜色正交锐度** — 可选连接 `lcs_data`,在锐度校准时移除颜色分量,防止颜色偏移。
309
+
310
+ ### 2026-03-19
311
+ - **视频 VAE 支持(Wan)** — 在 patchify/unpatchify 中处理 5D 视频潜在表示。视频 VAE 自动回退到逐帧编码。
312
+ - **LTXV 兼容** — patchify 中填充奇数空间维度,处理 3D 张量,不兼容格式时优雅跳过。
313
+ - **FLUX2 支持** — unpatchify 自动检测 128 通道潜在表示。
314
+ - **通用潜在格式** — 使用模型的 `latent_format` 进行空间转换,不再硬编码 FLUX 常量。
315
+
316
+ ### 2026-03-18
317
+ - **色调调整** — `LCS Tone Adjust` 节点,支持对比度、亮度、饱和度、色温滑条。10 个预设,前端实时同步。
318
+ - **色温控制** — 沿 LCS 蓝-黄轴的暖/冷偏移。
319
+ - **双锥 HSL 几何** — 通过双锥 LCS-to-HSL 映射实现正确的 Type II 干预。
320
+
321
+ ### 2026-03-17
322
+ - **首次发布** — 颜色引导(Type I + Type II + 插值模式)、批量多色、局部遮罩控制、潜在颜色预览、步骤观察器。按 VAE 自动校准并缓存。
323
+
324
+ ## 引用
325
+
326
+ 官方仓库:[ExplainableML/LCS](https://github.com/ExplainableML/LCS)
327
+
328
+ ```bibtex
329
+ @article{pach2026latentcolorsubspace,
330
+ title={The Latent Color Subspace: Emergent Order in High-Dimensional Chaos},
331
+ author={Mateusz Pach and Jessica Bader and Quentin Bouniot and Serge Belongie and Zeynep Akata},
332
+ journal={arxiv},
333
+ year={2026}
334
+ }
335
+ ```
336
+
337
+ ## 致谢
338
+
339
+ 感谢 Mateusz Pach、Jessica Bader、Quentin Bouniot、Serge Belongie 和 Zeynep Akata,他们的研究使免训练颜色控制成为可能。
340
+
341
+ ## 许可证
342
+
343
+ MIT
custom_nodes/ComfyUI-LCS/__init__.py ADDED
@@ -0,0 +1,53 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """ComfyUI-LCS: The Latent Color Subspace — training-free color control for FLUX.
2
+
3
+ Paper: "The Latent Color Subspace" (arXiv:2603.12261v1, ICML 2026)
4
+ """
5
+
6
+ # Register as ComfyUI_LCS so other plugins can `from ComfyUI_LCS.core.xxx import ...`
7
+ import sys as _sys
8
+ _sys.modules.setdefault("ComfyUI_LCS", _sys.modules[__name__])
9
+
10
+ # V3 ComfyExtension entry point
11
+ from comfy_api.latest import ComfyExtension, io
12
+ from .nodes.calibrate import LCSLoadData
13
+ from .nodes.intervene import LCSColorIntervene, LCSColorBatch, LCSToneAdjust
14
+ from .nodes.observe import LCSPreviewColors, LCSStepObserver
15
+ from .nodes.sharpen import LCSSharpnessCalibrate, LCSSharpnessIntervene
16
+ from .nodes.anchor import LCSColorAnchor
17
+
18
+
19
class LCSExtension(ComfyExtension):
    """ComfyExtension (V3 API) that exposes every LCS node class to ComfyUI."""

    async def get_node_list(self) -> list[type[io.ComfyNode]]:
        """Return the full list of LCS node classes."""
        node_classes = [
            LCSLoadData,
            LCSColorIntervene,
            LCSColorBatch,
            LCSToneAdjust,
            LCSPreviewColors,
            LCSStepObserver,
            LCSSharpnessCalibrate,
            LCSSharpnessIntervene,
            LCSColorAnchor,
        ]
        return node_classes
35
+
36
+
37
async def comfy_entrypoint() -> LCSExtension:
    """Create the LCS extension; V3 async entry point called by ComfyUI on startup."""
    extension = LCSExtension()
    return extension
40
+
41
+
42
+ # V2 backward compatibility
43
+ from .nodes import NODE_CLASS_MAPPINGS, NODE_DISPLAY_NAME_MAPPINGS
44
+
45
+ WEB_DIRECTORY = "./web"
46
+
47
+ __all__ = [
48
+ "NODE_CLASS_MAPPINGS",
49
+ "NODE_DISPLAY_NAME_MAPPINGS",
50
+ "WEB_DIRECTORY",
51
+ "LCSExtension",
52
+ "comfy_entrypoint",
53
+ ]
custom_nodes/ComfyUI-LCS/core/__init__.py ADDED
@@ -0,0 +1,10 @@
 
 
 
 
 
 
 
 
 
 
 
1
+ from .lcs_data import LCSData
2
+ from .patchify import patchify, unpatchify
3
+ from .timestep import sigma_to_paper_t, get_alpha_beta, normalize_to_t50, denormalize_from_t50
4
+ from .color_space import decode_lcs_to_hsl, encode_hsl_to_lcs, hex_to_hsl, hsl_to_rgb
5
+
6
+
7
def calibrate(*args, **kwargs):
    """Forward all arguments to ``core.calibration.calibrate``.

    The import happens inside the call rather than at module level, so that
    importing this package does not pull in ``comfy.utils`` (which the
    calibration module imports).
    """
    from .calibration import calibrate as _impl
    return _impl(*args, **kwargs)
custom_nodes/ComfyUI-LCS/core/adaptive.py ADDED
@@ -0,0 +1,109 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Schedule-aware adaptive logic for LCS color anchoring.
2
+
3
+ Derives intervention windows, strength envelopes, and phase assignments
4
+ from the sigma schedule's amplification factor (beta_50 / beta_t), replacing
5
+ all manually-tuned step/strength parameters with data-driven decisions.
6
+ """
7
+
8
+ import math
9
+ import torch
10
+ from .defaults import get_beta_table
11
+
12
+
13
def compute_amplification(sigma_val, device=None):
    """Return the amplification factor A = max_k(beta_50[k] / beta_t(sigma)[k]).

    A measures how strongly the normalization step inflates noise relative to
    signal: a high value means corrections are risky, a low value means they
    are safe.

    Args:
        sigma_val: FLUX sigma as a float; values outside [0, 1] are clamped
            (1 = noise, 0 = clean).
        device: Accepted for interface compatibility; not used here.

    Returns:
        The largest per-component ratio as a Python float.
    """
    table = get_beta_table()  # [51, 3]
    reference = table[50]  # [3], beta at t=50

    # sigma -> paper timestep in [0, 50]
    sigma_clamped = max(0.0, min(1.0, sigma_val))
    t = 50.0 * (1.0 - sigma_clamped)
    t = max(0.0, min(50.0, t))

    # Linear interpolation between the two neighbouring table rows.
    lower = int(t)
    upper = min(lower + 1, 50)
    weight = t - lower
    interpolated = (1.0 - weight) * table[lower] + weight * table[upper]

    # Guard against division by (near) zero before taking the ratio.
    ratios = reference / interpolated.clamp(min=1e-8)  # [3]
    return ratios.max().item()
39
+
40
+
41
def compute_step_phases(sigmas, mode):
    """Classify every sampling step as "skip", "observe" or "correct".

    Thresholds (described by the original author as physics-derived):
        A_MAX = 10.0     amplification above this -> "skip"
        A_WARMUP = 5.0   in "self_anchor" mode, amplification above this -> "observe"
        SIGMA_MIN = 0.15 sigma below this (final refinement) -> "skip"

    Args:
        sigmas: 1D sequence/tensor of sigma values of length N+1; the last
            entry is the terminal sigma and gets no phase.
        mode: "smooth", "reference" or "self_anchor". Only "self_anchor"
            changes the outcome (it enables the "observe" phase).

    Returns:
        List of N phase strings.
    """
    A_MAX = 10.0
    A_WARMUP = 5.0
    SIGMA_MIN = 0.15

    def phase_for(sigma_val):
        # Final detail refinement: leave untouched.
        if sigma_val < SIGMA_MIN:
            return "skip"
        amp = compute_amplification(sigma_val)
        # Too noisy to correct safely.
        if amp > A_MAX:
            return "skip"
        # Warm-up zone used only by self-anchoring.
        if mode == "self_anchor" and amp > A_WARMUP:
            return "observe"
        return "correct"

    n_steps = len(sigmas) - 1  # the terminal sigma is not a step
    return [phase_for(float(sigmas[i])) for i in range(n_steps)]
84
+
85
+
86
def estimate_intensity(drift_signal):
    """Scale a drift value by 1/0.2 and clamp the result to [0.15, 0.6]."""
    DRIFT_SCALE = 0.2
    INTENSITY_MIN = 0.15
    INTENSITY_MAX = 0.6
    scaled = drift_signal / DRIFT_SCALE
    capped = min(INTENSITY_MAX, scaled)
    return max(INTENSITY_MIN, capped)
92
+
93
+
94
def compute_strength_envelope(n_correction_steps):
    """Return a sinusoidal bell envelope over the correction steps.

    The envelope is sin(pi * i / (n - 1)) for i in 0..n-1, so it ramps up from
    zero and back down to zero instead of switching on and off abruptly at the
    phase boundaries.

    Args:
        n_correction_steps: Number of correction steps.

    Returns:
        1D float32 tensor of length ``n_correction_steps`` with values in
        [0, 1]. Empty for non-positive counts; ``[1.0]`` for a single step.
    """
    if n_correction_steps <= 0:
        return torch.zeros(0)
    if n_correction_steps == 1:
        return torch.ones(1)
    n = n_correction_steps
    indices = torch.arange(n, dtype=torch.float32)
    envelope = torch.sin(math.pi * indices / (n - 1))
    # float32 pi is slightly larger than the true value, so the final sample
    # evaluates to about -8.7e-8. Clamp so a strength is never negative.
    return envelope.clamp_(min=0.0)
custom_nodes/ComfyUI-LCS/core/bilateral.py ADDED
@@ -0,0 +1,79 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Bilateral filter in LCS space for smooth color anchoring."""
2
+
3
+ import math
4
+
5
+ import torch
6
+ import torch.nn.functional as F
7
+
8
+
9
def estimate_bilateral_params(c, h_len, w_len):
    """Derive bilateral filter parameters from the color spread of ``c``.

    The standard deviation of each of the 3 channels is taken over all grid
    positions of each batch item; sigma_color is 0.75 times the median of
    those values, clamped to [0.05, 3.0]. sigma_spatial is the constant 1.5.

    Args:
        c: [B, L, 3] LCS coordinates with L == h_len * w_len.
        h_len: Grid height.
        w_len: Grid width.

    Returns:
        Tuple ``(sigma_spatial, sigma_color)`` of floats.
    """
    batch = c.shape[0]
    # Reshaping through the grid also checks that L matches h_len * w_len.
    grid = c.reshape(batch, h_len, w_len, 3)  # [B, H, W, 3]
    per_channel_std = grid.reshape(batch, -1, 3).std(dim=1)  # [B, 3]
    typical_std = float(per_channel_std.median())

    sigma_color = min(3.0, 0.75 * typical_std)
    sigma_color = max(0.05, sigma_color)
    return 1.5, sigma_color
27
+
28
+
29
def bilateral_filter_lcs(c, h_len, w_len, sigma_spatial, sigma_color, kernel_radius=2):
    """Apply a bilateral filter to LCS coordinates laid out on an h_len x w_len grid.

    Each output value is a weighted average over a (2r+1) x (2r+1)
    neighbourhood. The weight is a spatial Gaussian on the grid offset times a
    Gaussian on the squared LCS distance to the centre value. Borders use
    replicate padding.

    Args:
        c: [B, L, 3] LCS coordinates with L == h_len * w_len.
        h_len: Grid height.
        w_len: Grid width.
        sigma_spatial: Std-dev of the spatial Gaussian.
        sigma_color: Std-dev of the color Gaussian.
        kernel_radius: Neighbourhood radius r (2 -> 5x5, 25 lookups per patch).

    Returns:
        [B, L, 3] filtered coordinates.
    """
    batch = c.shape[0]
    radius = kernel_radius

    # [B, L, 3] -> [B, 3, H, W] so that F.pad pads the two spatial dims.
    center = c.reshape(batch, h_len, w_len, 3).permute(0, 3, 1, 2)
    padded = F.pad(center, (radius, radius, radius, radius), mode="replicate")  # [B, 3, H+2r, W+2r]

    # Exponent coefficients of the two Gaussians: exp(coeff * dist^2).
    spatial_coeff = -0.5 / (sigma_spatial * sigma_spatial)
    color_coeff = -0.5 / (sigma_color * sigma_color)

    weight_total = torch.zeros(batch, 1, h_len, w_len, device=c.device, dtype=c.dtype)
    value_total = torch.zeros(batch, 3, h_len, w_len, device=c.device, dtype=c.dtype)

    offsets = [
        (dy, dx)
        for dy in range(-radius, radius + 1)
        for dx in range(-radius, radius + 1)
    ]
    for dy, dx in offsets:
        # The spatial weight is one scalar per offset.
        spatial_weight = math.exp(float(dy * dy + dx * dx) * spatial_coeff)

        # Window of the padded grid shifted by (dy, dx).
        top = radius + dy
        left = radius + dx
        shifted = padded[:, :, top:top + h_len, left:left + w_len]  # [B, 3, H, W]

        # The color weight varies per grid position.
        delta = shifted - center
        color_dist_sq = (delta * delta).sum(dim=1, keepdim=True)  # [B, 1, H, W]
        color_weight = torch.exp(color_dist_sq * color_coeff)

        combined = spatial_weight * color_weight
        weight_total.add_(combined)
        value_total.add_(combined * shifted)

    filtered = value_total / weight_total.clamp(min=1e-8)  # [B, 3, H, W]
    return filtered.permute(0, 2, 3, 1).reshape(batch, -1, 3)
custom_nodes/ComfyUI-LCS/core/calibration.py ADDED
@@ -0,0 +1,214 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """PCA calibration from FLUX VAE: compute LCS basis, mean, and anchor positions."""
2
+
3
+ import hashlib
4
+ import math
5
+ import torch
6
+ import comfy.utils
7
+ from .patchify import patchify
8
+ from .lcs_data import LCSData
9
+ from .color_space import _chromatic_plane_basis
10
+
11
+
12
def vae_fingerprint(vae) -> str:
    """Return an 8-character hex fingerprint of a VAE's weights.

    The fingerprint is the start of the SHA-256 digest of one tensor from
    ``vae.get_sd()``. In sorted key order, the first key containing both
    "decoder" and "weight" is used; if no key matches, the first key overall
    is used. The original author uses it to keep a separate calibration cache
    per VAE model.
    """
    state = vae.get_sd()
    ordered_keys = sorted(state.keys())

    # Prefer a decoder weight; otherwise fall back to the first tensor found.
    chosen = next(
        (key for key in ordered_keys if "decoder" in key and "weight" in key),
        None,
    )
    if chosen is None:
        chosen = ordered_keys[0]

    raw = state[chosen].cpu().float().numpy().tobytes()
    return hashlib.sha256(raw).hexdigest()[:8]
28
+
29
+
30
# The 8 anchor colors as (r, g, b) tuples in [0, 1], in the order
# R, B, G, M, C, Y, Black, White. Other code indexes this order directly
# (entries 0-5 are the chromatic anchors, 6 is black, 7 is white), so the
# order must not change.
ANCHOR_COLORS_RGB = [
    (1.0, 0.0, 0.0),  # Red
    (0.0, 0.0, 1.0),  # Blue
    (0.0, 1.0, 0.0),  # Green
    (1.0, 0.0, 1.0),  # Magenta
    (0.0, 1.0, 1.0),  # Cyan
    (1.0, 1.0, 0.0),  # Yellow
    (0.0, 0.0, 0.0),  # Black
    (1.0, 1.0, 1.0),  # White
]
41
+
42
+
43
def _solid_color_images(colors, image_size):
    """Build a [len(colors), image_size, image_size, 3] float32 CPU batch of flat-color images (BHWC)."""
    imgs = torch.zeros(len(colors), image_size, image_size, 3, dtype=torch.float32, device="cpu")
    for j, (r, g, b) in enumerate(colors):
        imgs[j, :, :, 0] = r
        imgs[j, :, :, 1] = g
        imgs[j, :, :, 2] = b
    return imgs


def _mean_patch_vectors(vae, imgs):
    """VAE-encode ``imgs`` and return the patch-averaged latent per encoded item, on the CPU ([B', D])."""
    latent = vae.encode(imgs[:, :, :, :3])
    # A 5D latent carries a temporal dim (video VAE); calibration uses still
    # images, so keep only the first frame.
    if latent.ndim == 5:
        latent = latent[:, :, 0, :, :]
    patches, _, _, _ = patchify(latent)  # [B', L, D]
    return patches.mean(dim=1).cpu()


def calibrate(vae, num_colors=512, image_size=512, batch_size=8):
    """Compute LCS data (PCA basis, mean, anchors) from a VAE.

    Steps:
        1. Sample ``num_colors`` solid colors from HSV (golden-angle hue,
           saturation and value each varied within roughly 0.3-1.0).
        2. VAE-encode each solid-color image to a latent.
        3. Patchify and average the patches per image to get one vector per color
           (64-dimensional for FLUX according to the original comments).
        4. PCA on all vectors gives the basis (top 3 components) and the mean.
        5. Encode the 8 anchor colors and compute their LCS coordinates and
           the hue angles of the 6 chromatic anchors.

    Args:
        vae: Object exposing ``encode(images)`` for BHWC float images.
        num_colors: Number of solid-color samples used for the PCA.
        image_size: Side length of the square calibration images.
        batch_size: Number of images encoded per ``vae.encode`` call; must be >= 1.

    Returns:
        LCSData holding ``basis``, ``mean``, ``anchor_lcs`` and ``anchor_angles``.

    Raises:
        ValueError: If ``batch_size`` is smaller than 1.
    """
    if batch_size < 1:
        raise ValueError(f"batch_size must be >= 1, got {batch_size}")

    print(f"\n[LCS Calibration] Starting calibration for {num_colors} colors...")
    print(f"[LCS Calibration] Image size: {image_size}x{image_size}, Batch size: {batch_size}")

    # Step 1: deterministic HSV sampling. The golden angle spreads hues evenly;
    # saturation and value cycle through 0.3..~1.0 for diversity.
    colors = [
        _hsv_to_rgb(
            (i * 137.508) % 360.0 / 360.0,
            0.3 + 0.7 * ((i * 73) % 100) / 100.0,
            0.3 + 0.7 * ((i * 47) % 100) / 100.0,
        )
        for i in range(num_colors)
    ]

    # Step 2+3: encode in batches and average patches.
    vectors = []
    pbar = comfy.utils.ProgressBar(num_colors)

    num_batches = (num_colors + batch_size - 1) // batch_size
    print(f"[LCS Calibration] Encoding {num_colors} color images in {num_batches} batches...")

    for batch_start in range(0, num_colors, batch_size):
        batch_colors = colors[batch_start:batch_start + batch_size]
        actual_batch = len(batch_colors)
        imgs = _solid_color_images(batch_colors, image_size)

        avg = _mean_patch_vectors(vae, imgs)  # [B', D]
        if avg.shape[0] == actual_batch:
            # Normal VAE: the batch encode kept one row per image.
            vectors.extend(avg.unbind(0))
        else:
            # Video VAE or unexpected batch collapse: encode one by one.
            for k in range(actual_batch):
                vectors.append(_mean_patch_vectors(vae, imgs[k:k + 1]).squeeze(0))

        pbar.update(actual_batch)

    # Stack all vectors: [N, D]
    X = torch.stack(vectors, dim=0).float()
    print(f"[LCS Calibration] Collected {X.shape[0]} patch vectors of dimension {X.shape[1]}")

    # Step 4: PCA via SVD of the centered data.
    print("[LCS Calibration] Computing PCA...")
    mean = X.mean(dim=0)  # [D]
    X_centered = X - mean
    _, S, Vh = torch.linalg.svd(X_centered, full_matrices=False)
    # Rows of Vh are the right singular vectors; the top 3 become the columns of the basis.
    basis = Vh[:3].T  # [D, 3]

    # Variance explained by each of the top 3 components.
    total_var = (S ** 2).sum()
    explained = (S[:3] ** 2) / total_var
    print(f"[LCS Calibration] Top 3 components explain {explained.sum():.1%} variance")
    print(f"[LCS Calibration] PC1: {explained[0]:.1%}, PC2: {explained[1]:.1%}, PC3: {explained[2]:.1%}")

    # Step 5: project the 8 anchor colors into LCS.
    print("[LCS Calibration] Encoding 8 anchor colors...")
    anchor_lcs_list = []
    for rgb in ANCHOR_COLORS_RGB:
        avg = _mean_patch_vectors(vae, _solid_color_images([rgb], image_size)).squeeze(0)  # [D]
        anchor_lcs_list.append((avg - mean) @ basis)  # [3]

    anchor_lcs = torch.stack(anchor_lcs_list, dim=0)  # [8, 3]

    # Hue angles of the 6 chromatic anchors.
    anchor_angles = _compute_anchor_angles(anchor_lcs, basis, mean)

    print(f"[LCS Calibration] Complete! Basis shape: {basis.shape}")
    print(f"[LCS Calibration] Anchor LCS coords:\n{anchor_lcs}")

    return LCSData(
        basis=basis,
        mean=mean,
        anchor_lcs=anchor_lcs,
        anchor_angles=anchor_angles,
    )
167
+
168
+
169
def _compute_anchor_angles(anchor_lcs, basis, mean):
    """Return the hue angles of the 6 chromatic anchors in the chromatic plane.

    The chromatic plane is perpendicular to the achromatic axis (black to white).

    Args:
        anchor_lcs: [8, 3] anchor coordinates; rows 0-5 are chromatic,
            row 6 is black, row 7 is white.
        basis: Unused; kept for the existing call signature.
        mean: Unused; kept for the existing call signature.

    Returns:
        [6] tensor of angles in radians, each in [0, 2*pi).
    """
    black = anchor_lcs[6]  # [3]
    white = anchor_lcs[7]  # [3]

    # Achromatic axis and an orthonormal basis (e1, e2) of the plane across it.
    a = white - black
    _, e1, e2 = _chromatic_plane_basis(a)
    full_turn = 2 * math.pi

    def hue_angle(point):
        # Foot of the perpendicular from the anchor onto the achromatic axis.
        on_axis = black + ((point - black) * a).sum() / ((a * a).sum() + 1e-10) * a
        residual = point - on_axis
        x = (residual * e1).sum()
        y = (residual * e2).sum()
        return torch.atan2(y, x) % full_turn

    return torch.stack([hue_angle(anchor_lcs[k]) for k in range(6)])  # [6]
197
+
198
+
199
def _hsv_to_rgb(h, s, v):
    """Convert scalar HSV (each in [0, 1]) to an ``(r, g, b)`` tuple."""
    # Effectively unsaturated: a pure gray of brightness v.
    if s < 1e-10:
        return v, v, v

    scaled = h * 6.0
    whole = int(scaled)
    f = scaled - whole  # position inside the current 60-degree sector

    p = v * (1.0 - s)
    q = v * (1.0 - s * f)
    t = v * (1.0 - s * (1.0 - f))

    # One (r, g, b) layout per hue sector 0..5.
    sectors = (
        (v, t, p),
        (q, v, p),
        (p, v, t),
        (p, q, v),
        (t, p, v),
        (v, p, q),
    )
    return sectors[whole % 6]
custom_nodes/ComfyUI-LCS/core/color_space.py ADDED
@@ -0,0 +1,380 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Bicone LCS ↔ HSL mapping using 8 anchor colors.
2
+
3
+ Anchors are indexed as: [Red, Blue, Green, Magenta, Cyan, Yellow, Black, White]
4
+ Indices: 0=R, 1=B, 2=G, 3=M, 4=C, 5=Y, 6=Black, 7=White
5
+ """
6
+
7
+ import math
8
+ import torch
9
+
10
# Standard HSL hue (as a fraction of a full turn) of each chromatic anchor, listed in
# anchor index order [R, B, G, M, C, Y]: R=0, B=4/6, G=2/6, M=5/6, C=3/6, Y=1/6.
_ANCHOR_HUES = (0.0, 4.0/6.0, 2.0/6.0, 5.0/6.0, 3.0/6.0, 1.0/6.0)
12
+
13
+
14
def _bicone_factor(l, clamp_min=None):
    """Return the bicone scaling factor ``1 - |2L - 1|``.

    The factor is 1 at l=0.5 (equator, full radius) and 0 at l=0 or l=1
    (poles, zero radius).

    Args:
        l: Lightness tensor [...].
        clamp_min: Optional lower bound applied to the factor for numerical
            stability; ``None`` leaves the factor unclamped.

    Returns:
        Tensor [...] of bicone factors.
    """
    factor = 1.0 - (2.0 * l - 1.0).abs()
    return factor if clamp_min is None else factor.clamp(min=clamp_min)
31
+
32
+
33
def _wrap_hue_diff(diff):
    """Wrap hue differences onto the shortest path around the unit circle, [-0.5, 0.5]."""
    overshoot = (diff > 0.5).float()    # 1 where a full turn must be subtracted
    undershoot = (diff < -0.5).float()  # 1 where a full turn must be added
    return diff - overshoot + undershoot
36
+
37
+
38
def _hue_lerp(h1, h2, t):
    """Interpolate from hue ``h1`` to ``h2`` by fraction ``t`` along the shortest arc of the [0, 1] hue circle."""
    shortest_step = _wrap_hue_diff(h2 - h1)
    blended = h1 + t * shortest_step
    return blended % 1.0
41
+
42
+
43
def _chromatic_plane_basis(a):
    """Build an orthonormal frame ``(a_unit, e1, e2)`` where e1 and e2 span the plane perpendicular to ``a``.

    Args:
        a: [3] axis vector.

    Returns:
        Tuple of three [3] tensors: the normalized axis and two unit vectors
        orthogonal to it and to each other (e2 = a_unit x e1).
    """
    a_unit = a / (a.norm() + 1e-10)

    # Seed with the x axis, unless a is nearly parallel to it; then use the y axis.
    seed_axis = 1 if a_unit[0].abs() > 0.9 else 0
    arb = torch.zeros(3, device=a.device, dtype=a.dtype)
    arb[seed_axis] = 1.0

    # Gram-Schmidt: remove the component along a_unit, then normalize.
    e1 = arb - (arb * a_unit).sum() * a_unit
    e1 = e1 / (e1.norm() + 1e-10)
    e2 = torch.linalg.cross(a_unit, e1)
    return a_unit, e1, e2
55
+
56
+
57
def hex_to_hsl(hex_str):
    """Convert a hex color string to ``(h, s, l)`` with each component in [0, 1].

    Accepts "#RRGGBB" (the leading "#" is optional) and the CSS shorthand
    forms "#RGB" / "#RGBA", where every digit is doubled. Surrounding
    whitespace is ignored. Any characters after the first six hex digits
    (for example an alpha byte) are ignored.

    Raises:
        ValueError: If the red, green or blue pair is not valid hexadecimal.
    """
    digits = hex_str.strip().lstrip("#")
    # Expand shorthand ("f80" -> "ff8800"); these lengths could never be
    # parsed by the pair slicing below.
    if len(digits) in (3, 4):
        digits = "".join(ch * 2 for ch in digits)
    r = int(digits[0:2], 16) / 255.0
    g = int(digits[2:4], 16) / 255.0
    b = int(digits[4:6], 16) / 255.0
    return rgb_to_hsl(r, g, b)
64
+
65
+
66
def rgb_to_hsl(r, g, b):
    """Convert scalar RGB in [0, 1] to an ``(h, s, l)`` tuple in [0, 1].

    Grays (max and min channel within 1e-10 of each other) return hue 0 and
    saturation 0.
    """
    hi = max(r, g, b)
    lo = min(r, g, b)
    spread = hi - lo
    l = (hi + lo) / 2.0

    if spread < 1e-10:
        return 0.0, 0.0, l

    # Saturation is the spread relative to the largest spread possible at this lightness.
    cone = abs(2.0 * l - 1.0)
    s = spread / (1.0 - cone) if cone < 1.0 else 0.0

    # Hue in sextants (0..6), chosen by which channel is the maximum.
    if hi == r:
        sextant = ((g - b) / spread) % 6.0
    elif hi == g:
        sextant = (b - r) / spread + 2.0
    else:
        sextant = (r - g) / spread + 4.0

    h = sextant / 6.0
    if h < 0:
        h += 1.0

    return h, max(0.0, min(1.0, s)), max(0.0, min(1.0, l))
89
+
90
+
91
def hsl_to_rgb(h, s, l):
    """Convert HSL in [0, 1] to ``(r, g, b)`` in [0, 1].

    Tensor input (decided by the type of ``h``) is handed to the vectorized
    implementation; otherwise the scalar conversion below runs.
    """
    if isinstance(h, torch.Tensor):
        return _hsl_to_rgb_tensor(h, s, l)

    c = (1.0 - abs(2.0 * l - 1.0)) * s          # chroma
    x = c * (1.0 - abs((h * 6.0) % 2.0 - 1.0))  # second-largest component
    m = l - c / 2.0                             # offset added to every channel

    h6 = h * 6.0
    # (r, g, b) layouts for sectors 0..4; sector 5 is the fall-through case.
    layouts = ((c, x, 0), (x, c, 0), (0, c, x), (0, x, c), (x, 0, c))
    for upper_bound, layout in enumerate(layouts, start=1):
        if h6 < upper_bound:
            r, g, b = layout
            break
    else:
        r, g, b = c, 0, x

    return r + m, g + m, b + m
115
+
116
+
117
def _hsl_to_rgb_tensor(h, s, l):
    """Vectorized HSL to RGB for tensors; returns ``(r, g, b)`` clamped to [0, 1]."""
    c = _bicone_factor(l) * s                 # chroma
    h6 = h * 6.0
    x = c * (1.0 - ((h6 % 2.0) - 1.0).abs())  # second-largest component
    m = l - c / 2.0                           # offset added to every channel

    r = torch.zeros_like(h)
    g = torch.zeros_like(h)
    b = torch.zeros_like(h)

    # One row per hue sector: (mask, source for r, source for g, source for b).
    # None leaves that channel at zero for the sector.
    sector_table = (
        (h6 < 1, c, x, None),
        ((h6 >= 1) & (h6 < 2), x, c, None),
        ((h6 >= 2) & (h6 < 3), None, c, x),
        ((h6 >= 3) & (h6 < 4), None, x, c),
        ((h6 >= 4) & (h6 < 5), x, None, c),
        (h6 >= 5, c, None, x),
    )
    for mask, r_src, g_src, b_src in sector_table:
        if r_src is not None:
            r[mask] = r_src[mask]
        if g_src is not None:
            g[mask] = g_src[mask]
        if b_src is not None:
            b[mask] = b_src[mask]

    return (r + m).clamp(0, 1), (g + m).clamp(0, 1), (b + m).clamp(0, 1)
143
+
144
+
145
def decode_lcs_to_hsl(c, anchor_lcs, anchor_angles):
    """Decode LCS coordinates to HSL using the bicone geometry of the anchors.

    Args:
        c: [..., 3] LCS coordinates (normalized to t=50 per the original docs).
        anchor_lcs: [8, 3] anchor positions [R, B, G, M, C, Y, Black, White].
        anchor_angles: [6] hue angles of the chromatic anchors in radians.

    Returns:
        Tuple ``(h, s, l)``, each of shape [...] with values in [0, 1].
    """
    black = anchor_lcs[6]        # [3]
    white = anchor_lcs[7]        # [3]
    chromatic = anchor_lcs[:6]   # [6, 3]

    # Lightness: normalized projection onto the black-to-white axis.
    a = white - black
    a_norm_sq = (a * a).sum() + 1e-10
    l = ((c - black) * a).sum(dim=-1) / a_norm_sq
    l = l.clamp(0.0, 1.0)

    # What is left after removing the achromatic part is the chroma vector.
    chroma_vec = c - (black + l.unsqueeze(-1) * a)      # [..., 3]
    chroma_dist = chroma_vec.norm(dim=-1) + 1e-10       # [...]

    # Angle of the chroma vector inside the chromatic plane, in [0, 2*pi).
    a_unit, e1, e2 = _chromatic_plane_basis(a)
    x_coord = (chroma_vec * e1).sum(dim=-1)
    y_coord = (chroma_vec * e2).sum(dim=-1)
    angle = torch.atan2(y_coord, x_coord) % (2 * math.pi)

    # The anchors need not appear in standard hue order around the circle, so
    # the angle is mapped to hue by interpolating between the anchors sorted
    # by their calibrated angle.
    sorted_angles, sort_idx = anchor_angles.sort()
    anchor_hues = torch.tensor(_ANCHOR_HUES, device=c.device, dtype=c.dtype)
    h = _angle_to_hue(angle, sorted_angles, anchor_hues[sort_idx])

    # Saturation: chroma distance relative to the largest radius reachable at
    # this hue (equatorial radius) and lightness (bicone factor).
    bicone_factor = _bicone_factor(l, clamp_min=1e-10)
    chroma_boundary = _hue_to_chroma_vector(h, chromatic, anchor_angles, a_unit, e1, e2, black, a)
    max_radius = chroma_boundary.norm(dim=-1) + 1e-10
    s = (chroma_dist / (max_radius * bicone_factor)).clamp(0.0, 1.0)

    return h, s, l
205
+
206
+
207
def encode_hsl_to_lcs(h, s, l, anchor_lcs, anchor_angles):
    """Encode HSL values to LCS coordinates using the bicone geometry of the anchors.

    Args:
        h, s, l: Tensors [...] with values in [0, 1].
        anchor_lcs: [8, 3] anchor positions [R, B, G, M, C, Y, Black, White].
        anchor_angles: [6] hue angles of the chromatic anchors in radians.

    Returns:
        [..., 3] LCS coordinates.
    """
    black, white = anchor_lcs[6], anchor_lcs[7]   # [3] each
    chromatic = anchor_lcs[:6]                    # [6, 3]

    a = white - black
    a_unit, e1, e2 = _chromatic_plane_basis(a)

    # Equatorial chroma vector for each hue, shrunk towards the poles by the
    # bicone factor and scaled by saturation.
    chroma_dir = _hue_to_chroma_vector(h, chromatic, anchor_angles, a_unit, e1, e2, black, a)
    radial_scale = (s * _bicone_factor(l)).unsqueeze(-1)

    # Start from the point on the achromatic axis at lightness l.
    on_axis = black + l.unsqueeze(-1) * a   # [..., 3]
    return on_axis + radial_scale * chroma_dir
234
+
235
+
236
def _angle_to_hue(angle, sorted_angles, sorted_hues):
    """Map angles to hues in [0, 1) by piecewise linear interpolation between anchors.

    Args:
        angle: Tensor [...] of angles in radians, expected in [0, 2*pi).
        sorted_angles: Anchor angles in ascending order.
        sorted_hues: Hue of each anchor, in the same order as ``sorted_angles``.

    Returns:
        Tensor [...] of hues. Angles that fall into no segment keep hue 0.
    """
    full_turn = 2 * math.pi
    count = len(sorted_angles)
    h = torch.zeros_like(angle)

    for idx in range(count):
        nxt = (idx + 1) % count
        seg_start = sorted_angles[idx]
        seg_end = sorted_angles[nxt]

        # The last segment runs across the 2*pi seam back to the first anchor.
        if seg_end < seg_start:
            seg_end = seg_end + full_turn
        seg_span = seg_end - seg_start
        if seg_span < 1e-10:
            continue

        if seg_end > full_turn:
            # Seam-crossing segment: it covers the top of the range and the
            # bottom; lift the bottom part by one turn before interpolating.
            mask = (angle >= seg_start) | (angle < (seg_end - full_turn))
            angle_shifted = torch.where(angle < seg_start, angle + full_turn, angle)
        else:
            mask = (angle >= seg_start) & (angle < seg_end)
            angle_shifted = angle

        frac = ((angle_shifted - seg_start) / seg_span).clamp(0, 1)

        # Interpolate the hue along the shorter way around the hue circle.
        hue_start = sorted_hues[idx]
        hue_step = sorted_hues[nxt] - hue_start
        if abs(hue_step) > 0.5:
            hue_step = hue_step - 1.0 if hue_step > 0 else hue_step + 1.0
        interp = (hue_start + frac * hue_step) % 1.0

        h = torch.where(mask, interp, h)

    return h
279
+
280
+
281
def _hue_to_chroma_vector(h, chromatic, anchor_angles, a_unit, e1, e2, black, a):
    """Map hue values to equatorial chroma vectors in 3D LCS space.

    Each returned vector lies in the plane spanned by e1 and e2. Its length is
    the equatorial chroma radius at that hue, i.e. the radius at l=0.5. The
    equatorial radius of each anchor is its chroma radius divided by its
    bicone factor (1 - |2L - 1|), where L is the anchor's lightness; radius
    and angle are then interpolated linearly between neighbouring anchors.
    Segments follow the anchors' ANGLE order, the same order _angle_to_hue
    uses, so that encoding and decoding round-trip.

    Hue is periodic, so input hues are wrapped into [0, 1) first and h == 1.0
    is treated as h == 0.0. Hues that match no segment (e.g. NaN) yield the
    zero vector.

    Args:
        h: Tensor [...] of hues.
        chromatic: [6, 3] LCS positions of the chromatic anchors.
        anchor_angles: [6] calibrated angles of the chromatic anchors (radians).
        a_unit: [3] unit vector along the achromatic axis (not used here).
        e1, e2: [3] orthonormal basis of the chromatic plane.
        black: [3] black anchor position.
        a: [3] full achromatic axis vector (white - black).

    Returns:
        Tensor [..., 3] of chroma vectors.
    """
    # Wrap into [0, 1). The segment masks below are half-open, so without this
    # h == 1.0 can match none of them. The second step catches values that
    # round up to exactly 1.0 after the modulo.
    h = h % 1.0
    h = torch.where(h >= 1.0, h - 1.0, h)

    # Lightness of each anchor: scalar projection onto the achromatic axis.
    a_sq = (a * a).sum() + 1e-10
    anchor_diff = chromatic - black  # [6, 3]
    anchor_l = (anchor_diff * a).sum(dim=-1) / a_sq  # [6]

    # Chroma radius of each anchor at its own lightness.
    anchor_on_axis = black + anchor_l.unsqueeze(-1) * a  # [6, 3]
    anchor_chroma = chromatic - anchor_on_axis  # [6, 3]
    anchor_r = anchor_chroma.norm(dim=-1)  # [6]

    # Rescale to the radius at l=0.5, where the bicone factor is 1.
    bicone_factors = _bicone_factor(anchor_l, clamp_min=1e-6)  # [6]
    equatorial_r = anchor_r / bicone_factors  # [6]

    anchor_hues = torch.tensor(_ANCHOR_HUES, device=chromatic.device, dtype=chromatic.dtype)

    # Sort by angle to match the segment structure of _angle_to_hue.
    sorted_angles, sort_idx = anchor_angles.sort()
    sorted_hues = anchor_hues[sort_idx]
    sorted_radii = equatorial_r[sort_idx]  # [6]

    n = 6
    # Zero-initialised (not torch.empty) so that a hue covered by no segment
    # gives a deterministic zero vector instead of uninitialised memory.
    result = torch.zeros(h.shape + (3,), device=chromatic.device, dtype=chromatic.dtype)

    for i in range(n):
        j = (i + 1) % n
        h_start = sorted_hues[i]
        h_end = sorted_hues[j]

        # Signed hue step along the shorter way around the hue circle.
        h_diff = h_end - h_start
        if abs(h_diff) > 0.5:
            if h_diff > 0:
                h_diff -= 1.0
            else:
                h_diff += 1.0

        if abs(h_diff) < 1e-10:
            continue

        # Segment end expressed relative to h_start; may leave [0, 1].
        h_end_unwrapped = h_start + h_diff

        # Select the input hues that fall inside this segment.
        if h_diff > 0:
            if h_end_unwrapped > 1.0:
                mask = (h >= h_start) | (h < (h_end_unwrapped - 1.0))
                h_shifted = torch.where(h < h_start, h + 1.0, h)
            else:
                mask = (h >= h_start) & (h < h_end_unwrapped)
                h_shifted = h
        else:
            # Hue decreases along this segment.
            if h_end_unwrapped < 0.0:
                mask = (h <= h_start) | (h > (h_end_unwrapped + 1.0))
                h_shifted = torch.where(h > h_start, h - 1.0, h)
            else:
                mask = (h <= h_start) & (h > h_end_unwrapped)
                h_shifted = h

        frac = ((h_shifted - h_start) / h_diff).clamp(0, 1)

        # Interpolate radius.
        interp_r = sorted_radii[i] + frac * (sorted_radii[j] - sorted_radii[i])

        # Interpolate angle; the last segment wraps across 2*pi.
        a_start = sorted_angles[i]
        a_end = sorted_angles[j]
        a_span = a_end - a_start
        if a_span < 0:
            a_span += 2 * math.pi
        interp_angle = (a_start + frac * a_span) % (2 * math.pi)

        # Rebuild the 3D chroma vector from polar coordinates in the (e1, e2) plane.
        interp_vec = interp_r.unsqueeze(-1) * (
            torch.cos(interp_angle).unsqueeze(-1) * e1
            + torch.sin(interp_angle).unsqueeze(-1) * e2
        )

        result = torch.where(mask.unsqueeze(-1), interp_vec, result)

    return result
custom_nodes/ComfyUI-LCS/core/defaults.py ADDED
@@ -0,0 +1,65 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Hardcoded alpha_t and beta_t tables from paper Appendix F (51 entries, t=0..50)."""
2
+
3
+ import torch
4
+
5
# Shift alpha_t: one 3D vector per timestep t=0..50 (51 rows, indexed by t).
# Values are hardcoded from the paper's Appendix F (per the module docstring);
# do not edit them by hand.
ALPHA_T = [
    [2.3413, -2.3586, 0.4266], [2.3574, -2.3833, 0.4644], [2.3638, -2.3904, 0.4883],
    [2.3734, -2.3951, 0.5122], [2.3831, -2.3993, 0.5384], [2.3925, -2.4026, 0.5647],
    [2.4023, -2.4047, 0.5919], [2.4124, -2.4060, 0.6198], [2.4226, -2.4064, 0.6484],
    [2.4330, -2.4060, 0.6772], [2.4437, -2.4051, 0.7065], [2.4546, -2.4035, 0.7367],
    [2.4659, -2.4011, 0.7668], [2.4775, -2.3981, 0.7974], [2.4897, -2.4009, 0.8312],
    [2.5021, -2.4036, 0.8656], [2.5148, -2.4065, 0.9008], [2.5277, -2.4093, 0.9364],
    [2.5408, -2.4123, 0.9727], [2.5542, -2.4154, 1.0099], [2.5680, -2.4186, 1.0481],
    [2.5820, -2.4218, 1.0868], [2.5963, -2.4252, 1.1263], [2.6110, -2.4288, 1.1672],
    [2.6261, -2.4324, 1.2090], [2.6416, -2.4363, 1.2520], [2.6575, -2.4403, 1.2957],
    [2.6738, -2.4444, 1.3406], [2.6904, -2.4485, 1.3865], [2.7074, -2.4529, 1.4336],
    [2.7250, -2.4574, 1.4818], [2.7432, -2.4621, 1.5314], [2.7618, -2.4669, 1.5823],
    [2.7810, -2.4720, 1.6344], [2.8006, -2.4771, 1.6878], [2.8209, -2.4826, 1.7430],
    [2.8418, -2.4883, 1.7995], [2.8631, -2.4944, 1.8578], [2.8853, -2.5005, 1.9179],
    [2.9080, -2.5066, 1.9793], [2.9313, -2.5132, 2.0426], [2.9555, -2.5199, 2.1082],
    [2.9804, -2.5268, 2.1756], [3.0060, -2.5338, 2.2450], [3.0328, -2.5411, 2.3172],
    [3.0603, -2.5486, 2.3914], [3.0889, -2.5561, 2.4682], [3.1189, -2.5640, 2.5482],
    [3.1497, -2.5725, 2.6302], [3.1824, -2.5796, 2.7175], [3.2152, -2.5889, 2.8050],
]
25
+
26
# Scale beta_t: one 3D vector per timestep t=0..50 (51 rows, indexed by t).
# Values are hardcoded from the paper's Appendix F (per the module docstring);
# do not edit them by hand.
BETA_T = [
    [0.0163, 0.0172, 0.0295], [0.0905, 0.0716, 0.0999], [0.1345, 0.1123, 0.1544],
    [0.1826, 0.1491, 0.2065], [0.2360, 0.1899, 0.2630], [0.2904, 0.2316, 0.3202],
    [0.3471, 0.2749, 0.3793], [0.4050, 0.3191, 0.4394], [0.4640, 0.3641, 0.5003],
    [0.5231, 0.4091, 0.5611], [0.5834, 0.4547, 0.6228], [0.6456, 0.5016, 0.6861],
    [0.7077, 0.5481, 0.7488], [0.7713, 0.5958, 0.8127], [0.8410, 0.6496, 0.8866],
    [0.9119, 0.7044, 0.9616], [0.9845, 0.7605, 1.0386], [1.0578, 0.8172, 1.1163],
    [1.1325, 0.8750, 1.1957], [1.2094, 0.9344, 1.2771], [1.2880, 0.9953, 1.3606],
    [1.3680, 1.0571, 1.4453], [1.4498, 1.1205, 1.5321], [1.5341, 1.1858, 1.6216],
    [1.6206, 1.2526, 1.7131], [1.7094, 1.3214, 1.8072], [1.7998, 1.3913, 1.9030],
    [1.8927, 1.4633, 2.0014], [1.9879, 1.5370, 2.1022], [2.0854, 1.6126, 2.2056],
    [2.1853, 1.6900, 2.3114], [2.2881, 1.7696, 2.4202], [2.3939, 1.8515, 2.5321],
    [2.5021, 1.9354, 2.6467], [2.6133, 2.0215, 2.7642], [2.7280, 2.1106, 2.8857],
    [2.8455, 2.2017, 3.0101], [2.9668, 2.2957, 3.1386], [3.0921, 2.3929, 3.2712],
    [3.2204, 2.4922, 3.4067], [3.3523, 2.5946, 3.5464], [3.4888, 2.7006, 3.6911],
    [3.6292, 2.8097, 3.8398], [3.7741, 2.9222, 3.9931], [3.9247, 3.0394, 4.1527],
    [4.0793, 3.1597, 4.3168], [4.2393, 3.2843, 4.4866], [4.4053, 3.4142, 4.6636],
    [4.5760, 3.5480, 4.8461], [4.7541, 3.6886, 5.0383], [4.9407, 3.8364, 5.2390],
]
46
+
47
# Module-level caches for the tensor forms of ALPHA_T / BETA_T. They start as
# None and are filled on the first call to get_alpha_table() / get_beta_table().
_alpha_tensor = None
_beta_tensor = None
50
+
51
+
52
def get_alpha_table():
    """Return the α_t table as a float32 tensor [51, 3].

    The tensor is built from ALPHA_T on the first call and the same object is
    returned on every later call.
    """
    global _alpha_tensor
    cached = _alpha_tensor
    if cached is not None:
        return cached
    cached = torch.tensor(ALPHA_T, dtype=torch.float32)  # [51, 3]
    _alpha_tensor = cached
    return cached
58
+
59
+
60
def get_beta_table():
    """Return the β_t table as a float32 tensor [51, 3].

    The tensor is built from BETA_T on the first call and the same object is
    returned on every later call.
    """
    global _beta_tensor
    cached = _beta_tensor
    if cached is not None:
        return cached
    cached = torch.tensor(BETA_T, dtype=torch.float32)  # [51, 3]
    _beta_tensor = cached
    return cached
custom_nodes/ComfyUI-LCS/core/diagnostics.py ADDED
@@ -0,0 +1,246 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Diagnostic tests for LCS intervention pipeline.
2
+
3
+ This module provides tests and diagnostics to identify conditions that
4
+ cause image blurriness or quality degradation during LCS intervention.
5
+ """
6
+
7
+ import torch
8
+ import math
9
+ from .color_space import decode_lcs_to_hsl, encode_hsl_to_lcs, _hue_lerp
10
+ from .timestep import get_alpha_beta, get_alpha_beta_t50, normalize_to_t50, denormalize_from_t50
11
+
12
+ # Test constants
13
+ _T50_REFERENCE_COORD = [0.5, 0.3, 0.1] # Typical LCS magnitude at t=50
14
+ _TEST_STRENGTHS = [0.0, 0.25, 0.5, 0.75, 1.0, 1.5, 2.0] # Range from none to overshoot
15
+ _VARIATION_SCALE = 0.5 # Scale for test patch variation
16
+ _NOISE_SCALE = 2.0 # Simulated diffusion noise magnitude
17
+ _PROBLEMATIC_AMPLIFICATION_THRESHOLD = 50 # >50x noise amplification is problematic
18
+
19
+
20
def test_round_trip_consistency(anchor_lcs, anchor_angles):
    """Check that encode(decode(c)) ≈ c for a set of LCS coordinates.

    The test points are the first six rows of ``anchor_lcs`` (the chromatic
    anchors) plus five points obtained by encoding random HSL triples.

    Args:
        anchor_lcs: Anchor coordinates; only rows 0-5 are read here, the
            full table is forwarded to the encode/decode helpers.
        anchor_angles: Anchor hue angles, forwarded to the helpers.

    Returns:
        dict with ``max_round_trip_error``, ``avg_round_trip_error``,
        ``passed`` (max error below 1e-4) and the per-point ``errors`` list.
    """
    # Chromatic anchors are the first six rows; the remaining rows are not
    # needed by this check, so they are no longer indexed.
    test_cases = list(anchor_lcs[:6])

    # Add five points generated from random (hue, saturation, lightness).
    for _ in range(5):
        hue = torch.rand(1).item()
        saturation = torch.rand(1).item()
        lightness = torch.rand(1).item()
        test_cases.append(
            encode_hsl_to_lcs(
                torch.tensor(hue), torch.tensor(saturation), torch.tensor(lightness),
                anchor_lcs, anchor_angles,
            )
        )

    errors = []
    for c in test_cases:
        h, s, l_val = decode_lcs_to_hsl(c, anchor_lcs, anchor_angles)
        c_round = encode_hsl_to_lcs(h, s, l_val, anchor_lcs, anchor_angles)
        errors.append((c - c_round).norm().item())

    max_error = max(errors)
    avg_error = sum(errors) / len(errors)
    return {
        "max_round_trip_error": max_error,
        "avg_round_trip_error": avg_error,
        "passed": max_error < 1e-4,
        "errors": errors,
    }
58
+
59
+
60
def test_normalization_stability():
    """Round-trip a reference coordinate through normalize/denormalize per step.

    For each integer timestep t in 0..50 (sigma = 1 - t/50) the reference
    coordinate is normalized to t=50 and denormalized again.

    Returns:
        A list with one dict per timestep holding ``t``, ``sigma``,
        ``beta_t_min``, ``amplification`` (max of β_50 / β_t) and
        ``round_trip_error`` (norm of the difference after the round trip).
    """
    reference = torch.tensor(_T50_REFERENCE_COORD, dtype=torch.float32)
    alpha_50, beta_50 = get_alpha_beta_t50()

    def probe(step):
        # sigma = 1 - t/50
        sigma = 1.0 - step / 50.0
        alpha_t, beta_t = get_alpha_beta(sigma)

        normalized = normalize_to_t50(reference, alpha_t, beta_t, alpha_50, beta_50)
        restored = denormalize_from_t50(normalized, alpha_t, beta_t, alpha_50, beta_50)

        return {
            "t": step,
            "sigma": sigma,
            "beta_t_min": beta_t.min().item(),
            "amplification": (beta_50 / beta_t).max().item(),
            "round_trip_error": (reference - restored).norm().item(),
        }

    return [probe(step) for step in range(51)]
92
+
93
+
94
def test_type_ii_uniformity(anchor_lcs, anchor_angles):
    """Print how much patch variation survives pulling every patch to one HSL.

    100 random LCS patches are decoded to HSL, interpolated toward a fixed
    target (h=0.0, s=1.0, l=0.5) at each strength in ``_TEST_STRENGTHS``,
    re-encoded, and compared with the input: the ratio of output to input
    variance and the number of distinct H/S/L values (rounded to 3
    decimals) are printed per strength. Nothing is returned.
    """
    # Random patch cloud centred on a fixed offset.
    patches = torch.randn(100, 3) * _VARIATION_SCALE + torch.tensor([0.3, 0.2, 0.1])

    target_h, target_s, target_l = 0.0, 1.0, 0.5

    # Decoding and the input variance do not depend on strength: do them once.
    h_cur, s_cur, l_cur = decode_lcs_to_hsl(patches, anchor_lcs, anchor_angles)
    h_goal = torch.full_like(h_cur, target_h)
    s_goal = torch.full_like(s_cur, target_s)
    l_goal = torch.full_like(l_cur, target_l)
    input_var = patches.var(dim=0).mean().item()

    def count_distinct(values):
        return len(torch.unique(values.round(decimals=3)))

    for strength in _TEST_STRENGTHS:
        h_mix = _hue_lerp(h_cur, h_goal, strength)
        s_mix = (s_cur + strength * (s_goal - s_cur)).clamp(0, 1)
        l_mix = (l_cur + strength * (l_goal - l_cur)).clamp(0, 1)

        recoded = encode_hsl_to_lcs(h_mix, s_mix, l_mix, anchor_lcs, anchor_angles)

        output_var = recoded.var(dim=0).mean().item()
        var_ratio = output_var / (input_var + 1e-10)

        h_unique = count_distinct(h_mix)
        s_unique = count_distinct(s_mix)
        l_unique = count_distinct(l_mix)

        print(f"strength={strength:.2f}: var_ratio={var_ratio:.3f}, "
              f"unique_h={h_unique}, unique_s={s_unique}, unique_l={l_unique}")
138
+
139
+
140
def test_early_timestep_amplification():
    """Print normalization behaviour for a fixed list of sigma values.

    For each sigma a synthetic observation is built from the reference
    coordinate plus random noise, normalized to t=50, and its distance to
    the reference is printed together with β_t and max(β_50 / β_t).
    Nothing is returned.
    """
    reference = torch.tensor(_T50_REFERENCE_COORD, dtype=torch.float32)
    alpha_50, beta_50 = get_alpha_beta_t50()  # same for every sigma

    sigma_values = (1.0, 0.99, 0.95, 0.90, 0.85, 0.80, 0.50, 0.0)
    for sigma in sigma_values:
        alpha_t, beta_t = get_alpha_beta(sigma)

        # Synthetic observation at this timestep.
        # NOTE(review): signal and noise are both scaled by beta_t here, so
        # the printed deviation does not obviously grow with the printed
        # amplification factor — confirm this is the intended noise model.
        noise = torch.randn(3) * _NOISE_SCALE
        observed = alpha_t + beta_t * reference + beta_t * noise

        normalized = normalize_to_t50(observed, alpha_t, beta_t, alpha_50, beta_50)

        deviation = (normalized - reference).norm().item()
        amplification = (beta_50 / beta_t).max().item()

        print(f"sigma={sigma:.2f}: beta_t={beta_t.numpy()}, "
              f"amplification={amplification:.1f}x, deviation={deviation:.3f}")
168
+
169
+
170
def analyze_blurriness_causes(lcs_data_path=None):
    """Run all diagnostics in this module and print a report.

    Args:
        lcs_data_path: Path to a safetensors calibration file containing
            ``anchor_lcs`` and ``anchor_angles``. When omitted, the first
            ``lcs_*.safetensors`` file (in sorted name order) found in the
            sibling ``data`` directory is used; if none exists an error is
            printed and the function returns without running anything.
    """
    print("=" * 60)
    print("LCS INTERVENTION BLURRINESS ANALYSIS")
    print("=" * 60)

    # Locate calibration data when the caller did not supply a path.
    if lcs_data_path is None:
        from pathlib import Path
        data_dir = Path(__file__).parent.parent / "data"
        # glob() yields entries in arbitrary order; sort so the same file is
        # picked on every run when several calibration files are present.
        safetensors_files = sorted(data_dir.glob("lcs_*.safetensors"))
        if safetensors_files:
            lcs_data_path = safetensors_files[0]
        else:
            print("ERROR: No calibration data found. Run LCSLoadData with calibrate=True first.")
            return

    from safetensors.torch import load_file
    data = load_file(lcs_data_path)
    anchor_lcs = data["anchor_lcs"]
    anchor_angles = data["anchor_angles"]

    print(f"\nLoaded calibration data from: {lcs_data_path}")
    print(f"anchor_lcs shape: {anchor_lcs.shape}")
    print(f"anchor_angles shape: {anchor_angles.shape}")

    print("\n1. ROUND-TRIP CONSISTENCY TEST")
    print("-" * 40)
    result = test_round_trip_consistency(anchor_lcs, anchor_angles)
    print(f"Max error: {result['max_round_trip_error']:.2e}")
    print(f"Avg error: {result['avg_round_trip_error']:.2e}")
    print(f"Status: {'PASS' if result['passed'] else 'FAIL'}")

    print("\n2. NORMALIZATION STABILITY TEST")
    print("-" * 40)
    norm_results = test_normalization_stability()
    problematic = [r for r in norm_results if r['amplification'] > _PROBLEMATIC_AMPLIFICATION_THRESHOLD]
    print(f"Timesteps with >{_PROBLEMATIC_AMPLIFICATION_THRESHOLD}x amplification: {len(problematic)}")
    # Only the first five offenders are listed to keep the report short.
    for r in problematic[:5]:
        print(f"  t={r['t']:2d} (sigma={r['sigma']:.2f}): amp={r['amplification']:.1f}x")

    print("\n3. TYPE II UNIFORMITY TEST")
    print("-" * 40)
    test_type_ii_uniformity(anchor_lcs, anchor_angles)

    print("\n4. EARLY TIMESTEP AMPLIFICATION TEST")
    print("-" * 40)
    test_early_timestep_amplification()

    # The summary below is fixed text; it is printed regardless of the
    # numbers produced by the tests above.
    print("\n" + "=" * 60)
    print("CONCLUSIONS")
    print("=" * 60)
    print("""
Potential blurriness causes identified:

1. TYPE II AT HIGH STRENGTH: At strength=1.0, all patches get the same
   target HSL, destroying spatial color variation. This is the PRIMARY
   cause of blur in type_ii mode.

2. EARLY TIMESTEP AMPLIFICATION: At sigma>0.95 (t<2.5), beta_t is ~0.02,
   causing ~250x amplification of noise. Intervening too early (step 0-2)
   will corrupt the signal.

3. OVERSHOOTING: strength>1.0 overshoots the target, potentially pushing
   values outside the valid color gamut. This can cause clipping and
   artifacts.

RECOMMENDATIONS:
- For type_ii mode, use strength<0.8 to preserve some original variation
- Avoid intervening before step 5 (sigma<0.90)
- For interpolated mode, the gamma=sigma blending naturally limits damage
  at early steps
""")
243
+
244
+
245
+ if __name__ == "__main__":
246
+ analyze_blurriness_causes()
custom_nodes/ComfyUI-LCS/core/lcs_data.py ADDED
@@ -0,0 +1,28 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from dataclasses import dataclass
2
+ import torch
3
+
4
+
5
@dataclass
class LCSData:
    """Calibration data for the Latent Color Subspace.

    Produced by PCA on FLUX VAE-encoded solid-color images. Flows between
    all LCS nodes as the shared LCS_DATA custom type.
    """

    basis: torch.Tensor          # [64, 3] PCA basis B (orthonormal columns)
    mean: torch.Tensor           # [64] PCA mean mu
    anchor_lcs: torch.Tensor     # [8, 3] LCS coords of 8 anchor colors [R,B,G,M,C,Y,Black,White]
    anchor_angles: torch.Tensor  # [6] hue angles (radians) of the 6 chromatic anchors

    def to(self, device, dtype=None):
        """Return a new LCSData with every tensor moved to device (and dtype, if given)."""
        target = {"device": device}
        if dtype is not None:
            target["dtype"] = dtype
        moved = {
            name: getattr(self, name).to(**target)
            for name in ("basis", "mean", "anchor_lcs", "anchor_angles")
        }
        return LCSData(**moved)
custom_nodes/ComfyUI-LCS/core/patchify.py ADDED
@@ -0,0 +1,93 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Patchify/unpatchify for latent tensors (patch_size=2, auto-detect channels).
2
+
3
+ Handles 3D, 4D, and 5D inputs. Pads odd spatial dims to even before patchifying.
4
+ """
5
+
6
+ from einops import rearrange
7
+ import torch.nn.functional as F
8
+
9
+
10
def patchify(x):
    """Turn a latent into a sequence of 2×2 patches: [B, L, C*4].

    Accepted layouts:
    - [C, H, W]: a batch dim is added; extra_shape is "unbatched".
    - [B, C, H, W]: used as-is; extra_shape is None.
    - [B, C, T, H, W]: T is folded into the batch; extra_shape is (B, C, T).

    Odd H or W is padded by one row/column (replicating the edge). When
    padding happened, extra_shape becomes a dict holding the pad amounts and
    the original extra_shape so that unpatchify can undo both.

    Returns (patches, h_len, w_len, extra_shape) with
    L = h_len * w_len = (H_padded/2) * (W_padded/2), or four Nones when H or
    W is smaller than 1.
    """
    extra_shape = None

    if x.ndim == 3:
        extra_shape = "unbatched"
        x = x.unsqueeze(0)
    elif x.ndim == 5:
        batch, chans, frames, rows, cols = x.shape
        extra_shape = (batch, chans, frames)
        # Move T next to B, then merge the two into one batch axis.
        x = x.permute(0, 2, 1, 3, 4).reshape(batch * frames, chans, rows, cols)

    _, _, H, W = x.shape
    if H < 1 or W < 1:
        return None, None, None, None

    # One extra row/column is needed exactly when the size is odd.
    pad_h = 1 if H % 2 != 0 else 0
    pad_w = 1 if W % 2 != 0 else 0
    if pad_h or pad_w:
        x = F.pad(x, (0, pad_w, 0, pad_h), mode="replicate")
        extra_shape = {"orig_extra": extra_shape, "pad_h": pad_h, "pad_w": pad_w}

    h_len = x.shape[2] // 2
    w_len = x.shape[3] // 2
    patches = rearrange(x, "b c (h ph) (w pw) -> b (h w) (c ph pw)", ph=2, pw=2)

    return patches, h_len, w_len, extra_shape
57
+
58
+
59
def unpatchify(patches, h_len, w_len, extra_shape=None):
    """Rebuild a latent from a patch sequence [B, L, C*4].

    The channel count is derived from the last dimension (C = D // 4 for
    2×2 patches). ``extra_shape`` is the value returned by patchify: it
    tells this function whether padding must be cropped and whether the
    result must be returned unbatched or reshaped back to 5D.
    """
    channels = patches.shape[-1] // 4  # 2×2 patch → 4 values per channel
    x = rearrange(patches, "b (h w) (c ph pw) -> b c (h ph) (w pw)",
                  h=h_len, w=w_len, c=channels, ph=2, pw=2)

    # A dict means patchify padded the input; unwrap it.
    if isinstance(extra_shape, dict):
        pad_h = extra_shape["pad_h"]
        pad_w = extra_shape["pad_w"]
        layout = extra_shape["orig_extra"]
    else:
        pad_h, pad_w, layout = 0, 0, extra_shape

    # Crop the rows/columns that were added by padding.
    if pad_h:
        x = x[:, :, :-pad_h, :]
    if pad_w:
        x = x[:, :, :, :-pad_w]

    if layout is None:
        return x
    if layout == "unbatched":
        return x.squeeze(0)

    # 5D input: split the merged batch axis back into (B, T) and restore order.
    B_orig, C_orig, T = layout
    H, W = x.shape[2], x.shape[3]
    return x.reshape(B_orig, T, C_orig, H, W).permute(0, 2, 1, 3, 4)
custom_nodes/ComfyUI-LCS/core/relationships.py ADDED
@@ -0,0 +1,117 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Local color relationship analysis for drift detection and correction."""
2
+
3
+ import torch
4
+ import torch.nn.functional as F
5
+
6
+
7
def compute_local_relationships(c, h_len, w_len, kernel_radius=2):
    """Cosine similarity between every patch and each of its grid neighbours.

    ``c`` is reshaped to a [B, h_len, w_len, 3] grid; the grid is
    edge-replicated by ``kernel_radius`` so border patches also have a full
    neighbourhood. Neighbours are visited row by row (dy outer, dx inner),
    skipping the centre.

    Returns [B, L, N] with N = (2*r + 1)^2 - 1 (24 for the default r=2).
    """
    B = c.shape[0]
    r = kernel_radius
    n_neighbors = (2 * r + 1) * (2 * r + 1) - 1

    # [B, L, 3] → [B, 3, H, W] so that F.pad works on the spatial axes.
    grid_chw = c.reshape(B, h_len, w_len, 3).permute(0, 3, 1, 2)
    padded = F.pad(grid_chw, (r, r, r, r), mode="replicate")

    def to_unit(t):
        # Unit-length along the channel axis; the clamp avoids division by 0.
        return t / t.norm(dim=1, keepdim=True).clamp(min=1e-8)

    center_unit = to_unit(grid_chw)
    padded_unit = to_unit(padded)  # normalized once, sliced per neighbour

    similarities = [
        (center_unit
         * padded_unit[:, :, r + dy:r + dy + h_len, r + dx:r + dx + w_len]).sum(dim=1)
        for dy in range(-r, r + 1)
        for dx in range(-r, r + 1)
        if not (dy == 0 and dx == 0)
    ]

    # [B, H, W, N] → [B, L, N]
    return torch.stack(similarities, dim=-1).reshape(B, -1, n_neighbors)
47
+
48
+
49
def detect_anomalies_adaptive(r_current, r_reference):
    """Score how far each patch's relationships drifted from the reference.

    Per batch item the mean absolute difference over the neighbour axis is
    compared against a robust threshold, median + 3.0 * 1.4826 * MAD. The
    amount above the threshold is divided by the batch item's largest
    excess, so the result lies in [0, 1].

    Returns anomaly magnitudes of shape [B, L, 1].
    """
    # [B, L]: average drift over all neighbour relationships.
    drift = (r_current - r_reference).abs().mean(dim=-1)

    # Robust location/scale per batch item, both [B, 1].
    median = drift.median(dim=-1, keepdim=True).values
    mad = (drift - median).abs().median(dim=-1, keepdim=True).values
    threshold = median + 3.0 * 1.4826 * mad

    # Only the part above the threshold counts.
    excess = (drift - threshold).clamp(min=0.0)

    # Scale so the largest excess maps to 1; the clamp guards the all-zero case.
    peak = excess.amax(dim=-1, keepdim=True).clamp(min=1e-8)
    return (excess / peak).unsqueeze(-1)
70
+
71
+
72
def infer_color_from_neighbors(c, anomaly_mag, h_len, w_len, kernel_radius=2):
    """Blend each patch toward a weighted average of its grid neighbours.

    Every neighbour inside the (2r+1)×(2r+1) window (centre excluded,
    edges replicated) contributes with weight ``1 - anomaly`` (floored at
    0.01), so neighbours with low anomaly dominate. The final colour is
    ``c * (1 - a) + inferred * a`` with ``a = anomaly_mag`` clamped to
    [0, 1].

    Returns corrected colours of shape [B, L, 3].
    """
    B = c.shape[0]
    r = kernel_radius

    # Lay both inputs out as channel-first grids and replicate the border.
    colors_chw = c.reshape(B, h_len, w_len, 3).permute(0, 3, 1, 2)            # [B, 3, H, W]
    anom_chw = anomaly_mag.reshape(B, h_len, w_len, 1).permute(0, 3, 1, 2)    # [B, 1, H, W]
    padded_c = F.pad(colors_chw, (r, r, r, r), mode="replicate")
    padded_a = F.pad(anom_chw, (r, r, r, r), mode="replicate")

    weight_sum = torch.zeros(B, 1, h_len, w_len, device=c.device, dtype=c.dtype)
    value_sum = torch.zeros(B, 3, h_len, w_len, device=c.device, dtype=c.dtype)

    offsets = [
        (dy, dx)
        for dy in range(-r, r + 1)
        for dx in range(-r, r + 1)
        if not (dy == 0 and dx == 0)
    ]
    for dy, dx in offsets:
        rows = slice(r + dy, r + dy + h_len)
        cols = slice(r + dx, r + dx + w_len)

        # Trustworthy (low-anomaly) neighbours weigh more; never exactly zero.
        w = (1.0 - padded_a[:, :, rows, cols]).clamp(min=0.01)  # [B, 1, H, W]
        weight_sum = weight_sum + w
        value_sum = value_sum + w * padded_c[:, :, rows, cols]

    inferred = value_sum / weight_sum.clamp(min=1e-8)                # [B, 3, H, W]
    inferred = inferred.permute(0, 2, 3, 1).reshape(B, -1, 3)        # [B, L, 3]

    blend = anomaly_mag.clamp(0, 1)
    return c * (1.0 - blend) + inferred * blend
custom_nodes/ComfyUI-LCS/core/sampling.py ADDED
@@ -0,0 +1,105 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Shared sampling utilities for LCS intervention hooks."""
2
+
3
+ import comfy.utils
4
+ import torch
5
+ import torch.nn.functional as F
6
+
7
+
8
def find_step_index(sigma, sigmas):
    """Return the index of ``sigma`` within the schedule ``sigmas``.

    Only the first element of ``sigma`` is used. Both inputs are compared
    in float32 with a tolerance (rtol=1e-3, atol=1e-5) so that a
    lower-precision sigma still matches its schedule entry; the first match
    wins. If nothing is within tolerance, the index of the nearest schedule
    value is returned instead.
    """
    target = sigma.flatten()[0].float()
    schedule = sigmas.float()

    hits = torch.isclose(schedule, target, rtol=1e-3, atol=1e-5).nonzero()
    if len(hits) == 0:
        # No entry within tolerance: fall back to the closest one.
        return (schedule - target).abs().argmin().item()
    return hits[0].item()
20
+
21
+
22
def denoised_to_raw(denoised, model):
    """Map a denoised tensor from process_in space back to raw VAE space.

    Delegates to ``model.latent_format.process_out`` (the inverse of
    process_in), so whatever scaling or shifting the model's latent format
    defines is applied.
    """
    latent_format = model.latent_format
    return latent_format.process_out(denoised)
29
+
30
+
31
def raw_to_denoised(raw, model):
    """Map a raw VAE-space tensor into the model's process_in space.

    Delegates to ``model.latent_format.process_in``.
    """
    latent_format = model.latent_format
    return latent_format.process_in(raw)
37
+
38
+
39
def unpack_video_if_needed(denoised, args):
    """Split an LTXAV-style packed latent into its video part, if packed.

    A tensor is treated as packed when it is 3D with a middle dimension of
    1 and the conditioning in ``args["cond"]`` carries more than one entry
    in ``latent_shapes``. In that case the first unpacked tensor is returned
    together with ``{"other_tensors": <remaining tensors>}`` for repacking.

    Otherwise returns ``(denoised, None)`` unchanged.
    """
    looks_packed = denoised.ndim == 3 and denoised.shape[1] == 1
    if not looks_packed:
        return denoised, None

    # The shapes needed for unpacking travel with the conditioning data.
    latent_shapes = _extract_latent_shapes(args.get("cond"))
    if latent_shapes is None or len(latent_shapes) <= 1:
        return denoised, None

    tensors = comfy.utils.unpack_latents(denoised, latent_shapes)
    # tensors[0] = video [B, 128, F, H, W], tensors[1] = audio [B, ch, T, freq]
    return tensors[0], {"other_tensors": tensors[1:]}
56
+
57
+
58
def repack_video_if_needed(modified, pack_info):
    """Undo unpack_video_if_needed after the video tensor was modified.

    modified: the video tensor after intervention [B, 128, F, H, W]
    pack_info: second value returned by unpack_video_if_needed; ``None``
        means nothing was unpacked and ``modified`` is returned as-is.
    """
    if pack_info is not None:
        # Video first, then the tensors that were set aside while unpacking.
        all_tensors = [modified] + pack_info["other_tensors"]
        packed, _ = comfy.utils.pack_latents(all_tensors)
        return packed
    return modified
69
+
70
+
71
def downsample_mask(mask, h_len, w_len, device, dtype):
    """Downsample a mask to the patch grid and flatten it to [1, L, 1].

    Only one mask plane is ever used: for a 3D ``[B, H, W]`` input the first
    batch entry, for a 4D ``[B, C, H, W]`` input the first batch entry and
    first channel. Without that restriction a multi-plane 4D input would be
    flattened into ``B*C*L`` values instead of ``L``.

    Args:
        mask: 2D, 3D or 4D tensor as described above.
        h_len, w_len: Target grid size; L = h_len * w_len.
        device, dtype: Where and in which dtype the result is produced.

    Returns:
        Tensor of shape [1, L, 1], bilinearly resized.

    Raises:
        ValueError: If ``mask`` is not 2D, 3D or 4D.
    """
    mask_dev = mask.to(device=device, dtype=dtype)

    if mask_dev.ndim == 2:
        mask_4d = mask_dev.unsqueeze(0).unsqueeze(0)      # [1, 1, H, W]
    elif mask_dev.ndim == 3:
        mask_4d = mask_dev[:1].unsqueeze(1)               # [1, 1, H, W]
    elif mask_dev.ndim == 4:
        mask_4d = mask_dev[:1, :1]                        # [1, 1, H, W]
    else:
        raise ValueError(
            f"downsample_mask expects a 2D, 3D or 4D mask, got {mask_dev.ndim}D"
        )

    mask_resized = F.interpolate(
        mask_4d, size=(h_len, w_len), mode="bilinear", align_corners=False
    )
    return mask_resized.reshape(1, -1, 1)  # [1, L, 1]
86
+
87
+
88
+ def _extract_latent_shapes(cond):
89
+ """Try to extract latent_shapes from conditioning data.
90
+
91
+ After convert_cond, cond is a list of dicts with 'model_conds' containing
92
+ CONDConstant-wrapped values like 'latent_shapes'.
93
+ """
94
+ if cond is None:
95
+ return None
96
+ for c in cond:
97
+ if isinstance(c, dict):
98
+ model_conds = c.get('model_conds', {})
99
+ if 'latent_shapes' in model_conds:
100
+ ls = model_conds['latent_shapes']
101
+ # CONDConstant wraps the value in .cond
102
+ if hasattr(ls, 'cond'):
103
+ return ls.cond
104
+ return ls
105
+ return None
custom_nodes/ComfyUI-LCS/core/sharpness.py ADDED
@@ -0,0 +1,213 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Sharpness subspace calibration via sinusoidal grating stimuli.
2
+
3
+ Replaces the previous Gaussian blur approach with narrowband frequency
4
+ gratings, which achieve higher linearity (R²=0.94 vs 0.88) because each
5
+ stimulus contains a single spatial frequency — a purer probe of the VAE's
6
+ frequency encoding axis.
7
+
8
+ The two methods discover the same 1D subspace (|cos|=0.986, 9.7° apart),
9
+ but grating stimuli yield a cleaner PC1 direction.
10
+ """
11
+
12
+ import math
13
+ from dataclasses import dataclass
14
+ from typing import List, Optional, Tuple
15
+ import warnings
16
+
17
+ import torch
18
+ import comfy.utils
19
+
20
+ from .patchify import patchify
21
+ from .lcs_data import LCSData
22
+
23
+
24
@dataclass
class SharpnessData:
    """Calibration data for the sharpness subspace.

    Produced by PCA on FLUX VAE-encoded sinusoidal gratings at varying
    spatial frequencies. PC1 captures ~94% of variance with R²=0.94
    linearity vs log₂(frequency).
    """

    basis: torch.Tensor  # [64, K] PCA basis (columns), K typically 1-2
    mean: torch.Tensor   # [64] PCA mean (in color-removed space if lcs_data was used)
    sign: float          # +1 or -1: ensures positive strength = sharper
    # [64, 3] LCS basis used during calibration (for re-orthogonalization)
    lcs_basis: Optional[torch.Tensor] = None

    def to(self, device, dtype=None):
        """Return a new SharpnessData with every tensor moved to device (and dtype, if given)."""
        target = {"device": device}
        if dtype is not None:
            target["dtype"] = dtype

        # lcs_basis is optional and stays None when it was never set.
        moved_lcs = None if self.lcs_basis is None else self.lcs_basis.to(**target)
        return SharpnessData(
            basis=self.basis.to(**target),
            mean=self.mean.to(**target),
            sign=self.sign,
            lcs_basis=moved_lcs,
        )
49
+
50
+
51
+ def _generate_grating_batch(
52
+ indices: List[int],
53
+ angles: torch.Tensor,
54
+ phases: torch.Tensor,
55
+ frequencies: Tuple[float, ...],
56
+ coord_x: torch.Tensor,
57
+ coord_y: torch.Tensor,
58
+ ) -> torch.Tensor:
59
+ """Generate a batch of sinusoidal grating stimuli by flat index.
60
+
61
+ Each flat index maps to (orientation, frequency) via divmod.
62
+ Returns [len(indices), 3, H, W] tensor.
63
+ """
64
+ num_freqs = len(frequencies)
65
+ batch = []
66
+ for idx in indices:
67
+ ori = idx // num_freqs
68
+ freq = frequencies[idx % num_freqs]
69
+ angle = angles[ori].item()
70
+ phase = phases[ori].item()
71
+ cos_a, sin_a = math.cos(angle), math.sin(angle)
72
+ coord = coord_x * cos_a + coord_y * sin_a
73
+ grating = 0.5 + 0.3 * torch.sin(2 * math.pi * freq * coord + phase)
74
+ batch.append(grating.unsqueeze(0).expand(3, -1, -1))
75
+ return torch.stack(batch, dim=0)
76
+
77
+
78
def calibrate_sharpness(vae, num_samples: int = 64, image_size: int = 512,
                        frequencies: Tuple[float, ...] = (1, 2, 4, 8, 16, 32, 64),
                        batch_size: int = 8,
                        lcs_data: Optional[LCSData] = None,
                        # Legacy parameter — accepted but ignored
                        blur_levels: Optional[Tuple[float, ...]] = None,
                        ) -> SharpnessData:
    """Compute sharpness subspace data (PCA basis, mean, sign) from a VAE.

    Generates sinusoidal gratings at varying spatial frequencies (one pure
    frequency per stimulus), VAE-encodes them, averages the patch vectors of
    each stimulus, and runs PCA on those averages to find the
    sharpness/frequency direction in patch space.

    Args:
        vae: ComfyUI VAE object (only ``vae.encode`` is used).
        num_samples: Number of orientations (each combined with all frequencies).
        image_size: Side length of the generated square images.
        frequencies: Spatial frequencies in cycles/image.
        batch_size: Number of stimuli encoded per ``vae.encode`` call.
        lcs_data: Optional LCS data. When provided, the component lying in
            the LCS colour basis is projected out of every vector before
            PCA, and the basis is stored on the result.
        blur_levels: Deprecated and ignored; passing it only emits a
            DeprecationWarning.

    Returns:
        SharpnessData with the top-2 PCA basis, the PCA mean, and a sign
        chosen so that positive PC1 correlates with higher frequency.
    """
    if blur_levels is not None:
        warnings.warn(
            "blur_levels is deprecated and ignored; calibration now uses sinusoidal gratings",
            DeprecationWarning, stacklevel=2,
        )

    n_freqs = len(frequencies)
    total_images = num_samples * n_freqs

    print(f"\n[LCS Sharpness Calibration] Starting: {num_samples} orientations × {n_freqs} frequencies = {total_images} stimuli")
    print(f"[LCS Sharpness Calibration] Frequencies: {list(frequencies)} cycles/image")

    # Shared state for grating generation; fixed seed keeps calibration repeatable.
    gen = torch.Generator().manual_seed(42)
    angles = torch.rand(num_samples, generator=gen) * math.pi  # [0, π)
    phases = torch.rand(num_samples, generator=gen) * 2 * math.pi  # [0, 2π)
    y_coords = torch.linspace(-0.5, 0.5, image_size).unsqueeze(1)
    x_coords = torch.linspace(-0.5, 0.5, image_size).unsqueeze(0)
    coord_y = y_coords.expand(image_size, image_size)
    coord_x = x_coords.expand(image_size, image_size)

    # Frequency label of every stimulus (flat index → frequency), used later
    # to orient the sign of PC1.
    freq_labels = [frequencies[idx % n_freqs] for idx in range(total_images)]
    freq_labels_t = torch.tensor(freq_labels, dtype=torch.float32)
    log_freq = torch.log2(freq_labels_t.clamp(min=0.5))

    # Generate stimuli lazily per batch and VAE-encode them.
    vectors = []
    pbar = comfy.utils.ProgressBar(total_images)

    for batch_start in range(0, total_images, batch_size):
        batch_end = min(batch_start + batch_size, total_images)
        indices = list(range(batch_start, batch_end))
        batch = _generate_grating_batch(indices, angles, phases, frequencies, coord_x, coord_y)
        actual_batch = batch.shape[0]

        # Convert BCHW → BHWC for ComfyUI VAE
        imgs_bhwc = batch.permute(0, 2, 3, 1).contiguous().cpu()

        # Try the whole batch first.
        latent = vae.encode(imgs_bhwc)
        patches, _, _, _ = patchify(latent)
        avg = patches.mean(dim=1).cpu()

        if avg.shape[0] == actual_batch:
            vectors.extend(avg.unbind(0))
        else:
            # The batched encode did not yield one row per image, so its rows
            # cannot be attributed to individual stimuli. Discard it and
            # encode every image on its own, collapsing whatever rows come
            # back into exactly one vector per image. This keeps `vectors`
            # aligned one-to-one with `freq_labels`.
            for k in range(actual_batch):
                lat = vae.encode(imgs_bhwc[k:k + 1])
                p, _, _, _ = patchify(lat)
                vectors.append(p.mean(dim=(0, 1)).cpu())

        pbar.update(actual_batch)

    # Stack all vectors: [N, D]
    X = torch.stack(vectors, dim=0).float()
    print(f"[LCS Sharpness Calibration] Collected {X.shape[0]} vectors of dimension {X.shape[1]}")

    # Remove LCS color component FIRST, in the raw space where LCS was calibrated.
    # This must happen before per-vector DC removal, because the LCS basis has
    # significant DC components (PC1 ≈ brightness). Doing DC removal first would
    # shift vectors into a different space where B^T(x - mu) is incorrect.
    if lcs_data is not None:
        print("[LCS Sharpness Calibration] Removing LCS color component...")
        lcs_mean = lcs_data.mean.to(X.device, X.dtype)
        lcs_basis = lcs_data.basis.to(X.device, X.dtype)
        # Project out color: X' = X - B B^T (X - mu)
        centered = X - lcs_mean
        lcs_coords = centered @ lcs_basis  # [N, 3]
        X = X - lcs_coords @ lcs_basis.T
        print("[LCS Sharpness Calibration] Color component removed")

    # Remove per-vector DC AFTER color removal.
    # VAE encoding shifts the latent mean depending on stimulus content.
    # Per-vector zero-mean forces PCA to find patterns in the relative channel
    # structure, not in the absolute level.
    X = X - X.mean(dim=1, keepdim=True)

    # PCA via SVD of the centered data; rows of Vh are the principal axes.
    print("[LCS Sharpness Calibration] Computing PCA...")
    mean = X.mean(dim=0)
    X_centered = X - mean
    _, S, Vh = torch.linalg.svd(X_centered, full_matrices=False)
    # Top 2 components as columns
    basis = Vh[:2].T

    # Variance explained
    total_var = (S ** 2).sum()
    explained = (S[:2] ** 2) / total_var
    print(f"[LCS Sharpness Calibration] PC1: {explained[0]:.1%}, PC2: {explained[1]:.1%} ({(explained[0]+explained[1]):.1%} total)")

    # Sign convention: project onto PC1 and correlate with log₂(frequency).
    # Higher frequency = sharper → positive correlation means sign = +1.
    pc1_scores = X_centered @ basis[:, 0]  # [N]
    correlation = torch.corrcoef(torch.stack([pc1_scores, log_freq]))[0, 1]
    sign = 1.0 if correlation > 0 else -1.0
    print(f"[LCS Sharpness Calibration] PC1-frequency correlation: {correlation:.3f} → sign = {sign:+.0f}")
    print(f"[LCS Sharpness Calibration] Complete! Basis shape: {basis.shape}")

    return SharpnessData(
        basis=basis,
        mean=mean,
        sign=sign,
        lcs_basis=lcs_data.basis.clone() if lcs_data is not None else None,
    )
custom_nodes/ComfyUI-LCS/core/timestep.py ADDED
@@ -0,0 +1,75 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Sigma ↔ paper timestep conversion and α_t/β_t interpolation."""
2
+
3
+ import torch
4
+ from .defaults import get_alpha_table, get_beta_table
5
+
6
+
7
def sigma_to_paper_t(sigma):
    """Map a FLUX sigma onto the paper's timestep axis.

    Sigma is clamped to [0, 1] and then mapped linearly onto [0, 50] with
    the direction reversed: sigma=1 (noise) gives t=0 and sigma=0 (clean)
    gives t=50.

    Accepts either a ``torch.Tensor`` (clamped element-wise, tensor returned)
    or a plain Python number (number returned).
    """
    if not isinstance(sigma, torch.Tensor):
        # Plain number: clamp with the builtins.
        bounded = max(0.0, min(1.0, sigma))
    else:
        # Tensor: element-wise clamp, result stays a tensor.
        bounded = sigma.clamp(0.0, 1.0)
    return 50.0 * (1.0 - bounded)
15
+
16
+
17
def get_alpha_beta(sigma, device=None):
    """Return linearly interpolated α_t and β_t for the given sigma.

    Sigma is converted to a paper timestep t in [0, 50]; the two table rows
    that bracket t are blended by the fractional part of t.

    Args:
        sigma: Python number or tensor. A tensor is reduced with ``.item()``,
            so it must hold exactly one element.
        device: Optional target device; when given, both results are moved
            there with ``.to(device)``.

    Returns:
        ``(alpha_t, beta_t)`` — one interpolated row from each table.
    """
    paper_t = sigma_to_paper_t(sigma)
    if isinstance(paper_t, torch.Tensor):
        # Row indexing below needs a Python scalar.
        paper_t = paper_t.item()

    alphas = get_alpha_table()  # [51, 3]
    betas = get_beta_table()  # [51, 3]

    # Re-clamp so the integer row indices always land inside 0..50.
    paper_t = max(0.0, min(50.0, paper_t))
    lower = int(paper_t)
    upper = min(lower + 1, 50)  # at t == 50 both indices point at the last row
    w_upper = paper_t - lower
    w_lower = 1.0 - w_upper

    alpha = w_lower * alphas[lower] + w_upper * alphas[upper]
    beta = w_lower * betas[lower] + w_upper * betas[upper]

    if device is None:
        return alpha, beta
    return alpha.to(device), beta.to(device)
41
+
42
+
43
def get_alpha_beta_t50(device=None):
    """Return α_50 and β_50, the table rows at the reference timestep t=50.

    Args:
        device: Optional target device; when given, both rows are moved there
            with ``.to(device)``.

    Returns:
        ``(alpha_50, beta_50)`` — row 50 of the alpha and beta tables.
    """
    alphas, betas = get_alpha_table(), get_beta_table()
    ref_alpha, ref_beta = alphas[50], betas[50]
    if device is None:
        return ref_alpha, ref_beta
    return ref_alpha.to(device), ref_beta.to(device)
53
+
54
+
55
def normalize_to_t50(c, alpha_t, beta_t, alpha_50, beta_50):
    """Normalize LCS coords from timestep t to the reference timestep t=50.

    Computes ``ĉ = (c - α_t) / β_t * β_50 + α_50``.

    Args:
        c: Coordinates of shape [..., 3].
        alpha_t, beta_t: Offset and scale at timestep t, shape [3].
        alpha_50, beta_50: Offset and scale at the reference timestep, shape [3].

    Returns:
        Tensor of normalized coordinates, broadcast against ``c``.

    Entries of ``beta_t`` with magnitude below 1e-6 are replaced by an
    epsilon of the same sign before dividing, so the division stays finite
    and a tiny negative scale does not flip direction. Exact zero is treated
    as positive. ``beta_t`` itself is not modified.
    """
    eps = torch.full_like(beta_t, 1e-6)
    # Keep the sign of near-zero scales instead of forcing them positive.
    signed_eps = torch.where(beta_t < 0, -eps, eps)
    beta_t_safe = torch.where(beta_t.abs() < 1e-6, signed_eps, beta_t)
    return (c - alpha_t) / beta_t_safe * beta_50 + alpha_50
65
+
66
+
67
def denormalize_from_t50(c_hat, alpha_t, beta_t, alpha_50, beta_50):
    """Denormalize LCS coords from the reference timestep t=50 back to t.

    Computes ``c = (ĉ - α_50) / β_50 * β_t + α_t``.

    Args:
        c_hat: Normalized coordinates of shape [..., 3].
        alpha_t, beta_t: Offset and scale at timestep t, shape [3].
        alpha_50, beta_50: Offset and scale at the reference timestep, shape [3].

    Returns:
        Tensor of coordinates at timestep t, broadcast against ``c_hat``.

    Entries of ``beta_50`` with magnitude below 1e-6 are replaced by an
    epsilon of the same sign before dividing, so the division stays finite
    and a tiny negative scale does not flip direction. Exact zero is treated
    as positive. ``beta_50`` itself is not modified.
    """
    eps = torch.full_like(beta_50, 1e-6)
    # Keep the sign of near-zero scales instead of forcing them positive.
    signed_eps = torch.where(beta_50 < 0, -eps, eps)
    beta_50_safe = torch.where(beta_50.abs() < 1e-6, signed_eps, beta_50)
    return (c_hat - alpha_50) / beta_50_safe * beta_t + alpha_t