diff --git a/.gitattributes b/.gitattributes index 21ae8588d5db87138d82f83eff34c9225403e235..8d46f2a4e5308b8972c736e6998008385436e48d 100644 --- a/.gitattributes +++ b/.gitattributes @@ -142,6 +142,8 @@ models/unet/zit_beyond_reality.safetensors filter=lfs diff=lfs merge=lfs -text models/vae/ae.safetensors filter=lfs diff=lfs merge=lfs -text models/vae/flux2-vae.safetensors filter=lfs diff=lfs merge=lfs -text models/vae/wan_2.1_vae.safetensors filter=lfs diff=lfs merge=lfs -text +custom_nodes/ComfyUI-FlashVSR_Ultra_Fast/img/preview.jpg filter=lfs diff=lfs merge=lfs -text +custom_nodes/ComfyUI-FlashVSR_Ultra_Fast/posi_prompt.pth filter=lfs diff=lfs merge=lfs -text models/FlashVSR/FlashVSR1_1.safetensors filter=lfs diff=lfs merge=lfs -text models/FlashVSR/LQ_proj_in.safetensors filter=lfs diff=lfs merge=lfs -text models/FlashVSR/Prompt.safetensors filter=lfs diff=lfs merge=lfs -text @@ -149,3 +151,7 @@ models/FlashVSR/TCDecoder.safetensors filter=lfs diff=lfs merge=lfs -text models/FlashVSR/Wan2.1_VAE.safetensors filter=lfs diff=lfs merge=lfs -text models/FlashVSR/Wan2_1-T2V-1_3B_FlashVSR_fp32.safetensors filter=lfs diff=lfs merge=lfs -text models/FlashVSR/Wan2_1_FlashVSR_LQ_proj_model_bf16.safetensors filter=lfs diff=lfs merge=lfs -text +models/FlashVSR-v1.1/LQ_proj_in.ckpt filter=lfs diff=lfs merge=lfs -text +models/FlashVSR-v1.1/TCDecoder.ckpt filter=lfs diff=lfs merge=lfs -text +models/FlashVSR-v1.1/Wan2.1_VAE.pth filter=lfs diff=lfs merge=lfs -text +models/FlashVSR-v1.1/diffusion_pytorch_model_streaming_dmd.safetensors filter=lfs diff=lfs merge=lfs -text diff --git a/custom_nodes/ComfyUI-FlashVSR_Ultra_Fast/.gitignore b/custom_nodes/ComfyUI-FlashVSR_Ultra_Fast/.gitignore new file mode 100644 index 0000000000000000000000000000000000000000..562f04eebd6e79906fd91fd86e8531df360221f6 --- /dev/null +++ b/custom_nodes/ComfyUI-FlashVSR_Ultra_Fast/.gitignore @@ -0,0 +1,210 @@ +# Byte-compiled / optimized / DLL files +__pycache__/ +*.py[codz] +*$py.class + +# C extensions 
+*.so + +# Distribution / packaging +.Python +build/ +develop-eggs/ +dist/ +downloads/ +eggs/ +.eggs/ +lib/ +lib64/ +parts/ +sdist/ +var/ +wheels/ +share/python-wheels/ +*.egg-info/ +.installed.cfg +*.egg +MANIFEST + +# PyInstaller +# Usually these files are written by a python script from a template +# before PyInstaller builds the exe, so as to inject date/other infos into it. +*.manifest +*.spec + +# Installer logs +pip-log.txt +pip-delete-this-directory.txt + +# Unit test / coverage reports +htmlcov/ +.tox/ +.nox/ +.coverage +.coverage.* +.cache +nosetests.xml +coverage.xml +*.cover +*.py.cover +.hypothesis/ +.pytest_cache/ +cover/ + +# Translations +*.mo +*.pot + +# Django stuff: +*.log +local_settings.py +db.sqlite3 +db.sqlite3-journal + +# Flask stuff: +instance/ +.webassets-cache + +# Scrapy stuff: +.scrapy + +# Sphinx documentation +docs/_build/ + +# PyBuilder +.pybuilder/ +target/ + +# Jupyter Notebook +.ipynb_checkpoints + +# IPython +profile_default/ +ipython_config.py + +# pyenv +# For a library or package, you might want to ignore these files since the code is +# intended to run in multiple environments; otherwise, check them in: +# .python-version + +# pipenv +# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. +# However, in case of collaboration, if having platform-specific dependencies or dependencies +# having no cross-platform support, pipenv may install dependencies that don't work, or not +# install all needed dependencies. +#Pipfile.lock + +# UV +# Similar to Pipfile.lock, it is generally recommended to include uv.lock in version control. +# This is especially recommended for binary packages to ensure reproducibility, and is more +# commonly ignored for libraries. +#uv.lock + +# poetry +# Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control. 
+# This is especially recommended for binary packages to ensure reproducibility, and is more +# commonly ignored for libraries. +# https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control +#poetry.lock +#poetry.toml + +# pdm +# Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control. +# pdm recommends including project-wide configuration in pdm.toml, but excluding .pdm-python. +# https://pdm-project.org/en/latest/usage/project/#working-with-version-control +#pdm.lock +#pdm.toml +.pdm-python +.pdm-build/ + +# pixi +# Similar to Pipfile.lock, it is generally recommended to include pixi.lock in version control. +#pixi.lock +# Pixi creates a virtual environment in the .pixi directory, just like venv module creates one +# in the .venv directory. It is recommended not to include this directory in version control. +.pixi + +# PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm +__pypackages__/ + +# Celery stuff +celerybeat-schedule +celerybeat.pid + +# SageMath parsed files +*.sage.py + +# Environments +.env +.envrc +.venv +env/ +venv/ +ENV/ +env.bak/ +venv.bak/ + +# Spyder project settings +.spyderproject +.spyproject + +# Rope project settings +.ropeproject + +# mkdocs documentation +/site + +# mypy +.mypy_cache/ +.dmypy.json +dmypy.json + +# Pyre type checker +.pyre/ + +# pytype static type analyzer +.pytype/ + +# Cython debug symbols +cython_debug/ + +# PyCharm +# JetBrains specific template is maintained in a separate JetBrains.gitignore that can +# be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore +# and can be added to the global gitignore or merged into this file. For a more nuclear +# option (not recommended) you can uncomment the following to ignore the entire idea folder. +#.idea/ + +# Abstra +# Abstra is an AI-powered process automation framework. +# Ignore directories containing user credentials, local state, and settings. 
+# Learn more at https://abstra.io/docs +.abstra/ + +# Visual Studio Code +# Visual Studio Code specific template is maintained in a separate VisualStudioCode.gitignore +# that can be found at https://github.com/github/gitignore/blob/main/Global/VisualStudioCode.gitignore +# and can be added to the global gitignore or merged into this file. However, if you prefer, +# you could uncomment the following to ignore the entire vscode folder +# .vscode/ + +# Ruff stuff: +.ruff_cache/ + +# PyPI configuration file +.pypirc + +# Cursor +# Cursor is an AI-powered code editor. `.cursorignore` specifies files/directories to +# exclude from AI features like autocomplete and code analysis. Recommended for sensitive data +# refer to https://docs.cursor.com/context/ignore-files +.cursorignore +.cursorindexingignore + +# Marimo +marimo/_static/ +marimo/_lsp/ +__marimo__/ + +# macOS +.DS_Store \ No newline at end of file diff --git a/custom_nodes/ComfyUI-FlashVSR_Ultra_Fast/LICENSE b/custom_nodes/ComfyUI-FlashVSR_Ultra_Fast/LICENSE new file mode 100644 index 0000000000000000000000000000000000000000..f288702d2fa16d3cdf0035b15a9fcbc552cd88e7 --- /dev/null +++ b/custom_nodes/ComfyUI-FlashVSR_Ultra_Fast/LICENSE @@ -0,0 +1,674 @@ + GNU GENERAL PUBLIC LICENSE + Version 3, 29 June 2007 + + Copyright (C) 2007 Free Software Foundation, Inc. + Everyone is permitted to copy and distribute verbatim copies + of this license document, but changing it is not allowed. + + Preamble + + The GNU General Public License is a free, copyleft license for +software and other kinds of works. + + The licenses for most software and other practical works are designed +to take away your freedom to share and change the works. By contrast, +the GNU General Public License is intended to guarantee your freedom to +share and change all versions of a program--to make sure it remains free +software for all its users. 
We, the Free Software Foundation, use the +GNU General Public License for most of our software; it applies also to +any other work released this way by its authors. You can apply it to +your programs, too. + + When we speak of free software, we are referring to freedom, not +price. Our General Public Licenses are designed to make sure that you +have the freedom to distribute copies of free software (and charge for +them if you wish), that you receive source code or can get it if you +want it, that you can change the software or use pieces of it in new +free programs, and that you know you can do these things. + + To protect your rights, we need to prevent others from denying you +these rights or asking you to surrender the rights. Therefore, you have +certain responsibilities if you distribute copies of the software, or if +you modify it: responsibilities to respect the freedom of others. + + For example, if you distribute copies of such a program, whether +gratis or for a fee, you must pass on to the recipients the same +freedoms that you received. You must make sure that they, too, receive +or can get the source code. And you must show them these terms so they +know their rights. + + Developers that use the GNU GPL protect your rights with two steps: +(1) assert copyright on the software, and (2) offer you this License +giving you legal permission to copy, distribute and/or modify it. + + For the developers' and authors' protection, the GPL clearly explains +that there is no warranty for this free software. For both users' and +authors' sake, the GPL requires that modified versions be marked as +changed, so that their problems will not be attributed erroneously to +authors of previous versions. + + Some devices are designed to deny users access to install or run +modified versions of the software inside them, although the manufacturer +can do so. This is fundamentally incompatible with the aim of +protecting users' freedom to change the software. 
The systematic +pattern of such abuse occurs in the area of products for individuals to +use, which is precisely where it is most unacceptable. Therefore, we +have designed this version of the GPL to prohibit the practice for those +products. If such problems arise substantially in other domains, we +stand ready to extend this provision to those domains in future versions +of the GPL, as needed to protect the freedom of users. + + Finally, every program is threatened constantly by software patents. +States should not allow patents to restrict development and use of +software on general-purpose computers, but in those that do, we wish to +avoid the special danger that patents applied to a free program could +make it effectively proprietary. To prevent this, the GPL assures that +patents cannot be used to render the program non-free. + + The precise terms and conditions for copying, distribution and +modification follow. + + TERMS AND CONDITIONS + + 0. Definitions. + + "This License" refers to version 3 of the GNU General Public License. + + "Copyright" also means copyright-like laws that apply to other kinds of +works, such as semiconductor masks. + + "The Program" refers to any copyrightable work licensed under this +License. Each licensee is addressed as "you". "Licensees" and +"recipients" may be individuals or organizations. + + To "modify" a work means to copy from or adapt all or part of the work +in a fashion requiring copyright permission, other than the making of an +exact copy. The resulting work is called a "modified version" of the +earlier work or a work "based on" the earlier work. + + A "covered work" means either the unmodified Program or a work based +on the Program. + + To "propagate" a work means to do anything with it that, without +permission, would make you directly or secondarily liable for +infringement under applicable copyright law, except executing it on a +computer or modifying a private copy. 
Propagation includes copying, +distribution (with or without modification), making available to the +public, and in some countries other activities as well. + + To "convey" a work means any kind of propagation that enables other +parties to make or receive copies. Mere interaction with a user through +a computer network, with no transfer of a copy, is not conveying. + + An interactive user interface displays "Appropriate Legal Notices" +to the extent that it includes a convenient and prominently visible +feature that (1) displays an appropriate copyright notice, and (2) +tells the user that there is no warranty for the work (except to the +extent that warranties are provided), that licensees may convey the +work under this License, and how to view a copy of this License. If +the interface presents a list of user commands or options, such as a +menu, a prominent item in the list meets this criterion. + + 1. Source Code. + + The "source code" for a work means the preferred form of the work +for making modifications to it. "Object code" means any non-source +form of a work. + + A "Standard Interface" means an interface that either is an official +standard defined by a recognized standards body, or, in the case of +interfaces specified for a particular programming language, one that +is widely used among developers working in that language. + + The "System Libraries" of an executable work include anything, other +than the work as a whole, that (a) is included in the normal form of +packaging a Major Component, but which is not part of that Major +Component, and (b) serves only to enable use of the work with that +Major Component, or to implement a Standard Interface for which an +implementation is available to the public in source code form. 
A +"Major Component", in this context, means a major essential component +(kernel, window system, and so on) of the specific operating system +(if any) on which the executable work runs, or a compiler used to +produce the work, or an object code interpreter used to run it. + + The "Corresponding Source" for a work in object code form means all +the source code needed to generate, install, and (for an executable +work) run the object code and to modify the work, including scripts to +control those activities. However, it does not include the work's +System Libraries, or general-purpose tools or generally available free +programs which are used unmodified in performing those activities but +which are not part of the work. For example, Corresponding Source +includes interface definition files associated with source files for +the work, and the source code for shared libraries and dynamically +linked subprograms that the work is specifically designed to require, +such as by intimate data communication or control flow between those +subprograms and other parts of the work. + + The Corresponding Source need not include anything that users +can regenerate automatically from other parts of the Corresponding +Source. + + The Corresponding Source for a work in source code form is that +same work. + + 2. Basic Permissions. + + All rights granted under this License are granted for the term of +copyright on the Program, and are irrevocable provided the stated +conditions are met. This License explicitly affirms your unlimited +permission to run the unmodified Program. The output from running a +covered work is covered by this License only if the output, given its +content, constitutes a covered work. This License acknowledges your +rights of fair use or other equivalent, as provided by copyright law. + + You may make, run and propagate covered works that you do not +convey, without conditions so long as your license otherwise remains +in force. 
You may convey covered works to others for the sole purpose +of having them make modifications exclusively for you, or provide you +with facilities for running those works, provided that you comply with +the terms of this License in conveying all material for which you do +not control copyright. Those thus making or running the covered works +for you must do so exclusively on your behalf, under your direction +and control, on terms that prohibit them from making any copies of +your copyrighted material outside their relationship with you. + + Conveying under any other circumstances is permitted solely under +the conditions stated below. Sublicensing is not allowed; section 10 +makes it unnecessary. + + 3. Protecting Users' Legal Rights From Anti-Circumvention Law. + + No covered work shall be deemed part of an effective technological +measure under any applicable law fulfilling obligations under article +11 of the WIPO copyright treaty adopted on 20 December 1996, or +similar laws prohibiting or restricting circumvention of such +measures. + + When you convey a covered work, you waive any legal power to forbid +circumvention of technological measures to the extent such circumvention +is effected by exercising rights under this License with respect to +the covered work, and you disclaim any intention to limit operation or +modification of the work as a means of enforcing, against the work's +users, your or third parties' legal rights to forbid circumvention of +technological measures. + + 4. Conveying Verbatim Copies. 
+ + You may convey verbatim copies of the Program's source code as you +receive it, in any medium, provided that you conspicuously and +appropriately publish on each copy an appropriate copyright notice; +keep intact all notices stating that this License and any +non-permissive terms added in accord with section 7 apply to the code; +keep intact all notices of the absence of any warranty; and give all +recipients a copy of this License along with the Program. + + You may charge any price or no price for each copy that you convey, +and you may offer support or warranty protection for a fee. + + 5. Conveying Modified Source Versions. + + You may convey a work based on the Program, or the modifications to +produce it from the Program, in the form of source code under the +terms of section 4, provided that you also meet all of these conditions: + + a) The work must carry prominent notices stating that you modified + it, and giving a relevant date. + + b) The work must carry prominent notices stating that it is + released under this License and any conditions added under section + 7. This requirement modifies the requirement in section 4 to + "keep intact all notices". + + c) You must license the entire work, as a whole, under this + License to anyone who comes into possession of a copy. This + License will therefore apply, along with any applicable section 7 + additional terms, to the whole of the work, and all its parts, + regardless of how they are packaged. This License gives no + permission to license the work in any other way, but it does not + invalidate such permission if you have separately received it. + + d) If the work has interactive user interfaces, each must display + Appropriate Legal Notices; however, if the Program has interactive + interfaces that do not display Appropriate Legal Notices, your + work need not make them do so. 
+ + A compilation of a covered work with other separate and independent +works, which are not by their nature extensions of the covered work, +and which are not combined with it such as to form a larger program, +in or on a volume of a storage or distribution medium, is called an +"aggregate" if the compilation and its resulting copyright are not +used to limit the access or legal rights of the compilation's users +beyond what the individual works permit. Inclusion of a covered work +in an aggregate does not cause this License to apply to the other +parts of the aggregate. + + 6. Conveying Non-Source Forms. + + You may convey a covered work in object code form under the terms +of sections 4 and 5, provided that you also convey the +machine-readable Corresponding Source under the terms of this License, +in one of these ways: + + a) Convey the object code in, or embodied in, a physical product + (including a physical distribution medium), accompanied by the + Corresponding Source fixed on a durable physical medium + customarily used for software interchange. + + b) Convey the object code in, or embodied in, a physical product + (including a physical distribution medium), accompanied by a + written offer, valid for at least three years and valid for as + long as you offer spare parts or customer support for that product + model, to give anyone who possesses the object code either (1) a + copy of the Corresponding Source for all the software in the + product that is covered by this License, on a durable physical + medium customarily used for software interchange, for a price no + more than your reasonable cost of physically performing this + conveying of source, or (2) access to copy the + Corresponding Source from a network server at no charge. + + c) Convey individual copies of the object code with a copy of the + written offer to provide the Corresponding Source. 
This + alternative is allowed only occasionally and noncommercially, and + only if you received the object code with such an offer, in accord + with subsection 6b. + + d) Convey the object code by offering access from a designated + place (gratis or for a charge), and offer equivalent access to the + Corresponding Source in the same way through the same place at no + further charge. You need not require recipients to copy the + Corresponding Source along with the object code. If the place to + copy the object code is a network server, the Corresponding Source + may be on a different server (operated by you or a third party) + that supports equivalent copying facilities, provided you maintain + clear directions next to the object code saying where to find the + Corresponding Source. Regardless of what server hosts the + Corresponding Source, you remain obligated to ensure that it is + available for as long as needed to satisfy these requirements. + + e) Convey the object code using peer-to-peer transmission, provided + you inform other peers where the object code and Corresponding + Source of the work are being offered to the general public at no + charge under subsection 6d. + + A separable portion of the object code, whose source code is excluded +from the Corresponding Source as a System Library, need not be +included in conveying the object code work. + + A "User Product" is either (1) a "consumer product", which means any +tangible personal property which is normally used for personal, family, +or household purposes, or (2) anything designed or sold for incorporation +into a dwelling. In determining whether a product is a consumer product, +doubtful cases shall be resolved in favor of coverage. 
For a particular +product received by a particular user, "normally used" refers to a +typical or common use of that class of product, regardless of the status +of the particular user or of the way in which the particular user +actually uses, or expects or is expected to use, the product. A product +is a consumer product regardless of whether the product has substantial +commercial, industrial or non-consumer uses, unless such uses represent +the only significant mode of use of the product. + + "Installation Information" for a User Product means any methods, +procedures, authorization keys, or other information required to install +and execute modified versions of a covered work in that User Product from +a modified version of its Corresponding Source. The information must +suffice to ensure that the continued functioning of the modified object +code is in no case prevented or interfered with solely because +modification has been made. + + If you convey an object code work under this section in, or with, or +specifically for use in, a User Product, and the conveying occurs as +part of a transaction in which the right of possession and use of the +User Product is transferred to the recipient in perpetuity or for a +fixed term (regardless of how the transaction is characterized), the +Corresponding Source conveyed under this section must be accompanied +by the Installation Information. But this requirement does not apply +if neither you nor any third party retains the ability to install +modified object code on the User Product (for example, the work has +been installed in ROM). + + The requirement to provide Installation Information does not include a +requirement to continue to provide support service, warranty, or updates +for a work that has been modified or installed by the recipient, or for +the User Product in which it has been modified or installed. 
Access to a +network may be denied when the modification itself materially and +adversely affects the operation of the network or violates the rules and +protocols for communication across the network. + + Corresponding Source conveyed, and Installation Information provided, +in accord with this section must be in a format that is publicly +documented (and with an implementation available to the public in +source code form), and must require no special password or key for +unpacking, reading or copying. + + 7. Additional Terms. + + "Additional permissions" are terms that supplement the terms of this +License by making exceptions from one or more of its conditions. +Additional permissions that are applicable to the entire Program shall +be treated as though they were included in this License, to the extent +that they are valid under applicable law. If additional permissions +apply only to part of the Program, that part may be used separately +under those permissions, but the entire Program remains governed by +this License without regard to the additional permissions. + + When you convey a copy of a covered work, you may at your option +remove any additional permissions from that copy, or from any part of +it. (Additional permissions may be written to require their own +removal in certain cases when you modify the work.) You may place +additional permissions on material, added by you to a covered work, +for which you have or can give appropriate copyright permission. 
+ + Notwithstanding any other provision of this License, for material you +add to a covered work, you may (if authorized by the copyright holders of +that material) supplement the terms of this License with terms: + + a) Disclaiming warranty or limiting liability differently from the + terms of sections 15 and 16 of this License; or + + b) Requiring preservation of specified reasonable legal notices or + author attributions in that material or in the Appropriate Legal + Notices displayed by works containing it; or + + c) Prohibiting misrepresentation of the origin of that material, or + requiring that modified versions of such material be marked in + reasonable ways as different from the original version; or + + d) Limiting the use for publicity purposes of names of licensors or + authors of the material; or + + e) Declining to grant rights under trademark law for use of some + trade names, trademarks, or service marks; or + + f) Requiring indemnification of licensors and authors of that + material by anyone who conveys the material (or modified versions of + it) with contractual assumptions of liability to the recipient, for + any liability that these contractual assumptions directly impose on + those licensors and authors. + + All other non-permissive additional terms are considered "further +restrictions" within the meaning of section 10. If the Program as you +received it, or any part of it, contains a notice stating that it is +governed by this License along with a term that is a further +restriction, you may remove that term. If a license document contains +a further restriction but permits relicensing or conveying under this +License, you may add to a covered work material governed by the terms +of that license document, provided that the further restriction does +not survive such relicensing or conveying. 
+ + If you add terms to a covered work in accord with this section, you +must place, in the relevant source files, a statement of the +additional terms that apply to those files, or a notice indicating +where to find the applicable terms. + + Additional terms, permissive or non-permissive, may be stated in the +form of a separately written license, or stated as exceptions; +the above requirements apply either way. + + 8. Termination. + + You may not propagate or modify a covered work except as expressly +provided under this License. Any attempt otherwise to propagate or +modify it is void, and will automatically terminate your rights under +this License (including any patent licenses granted under the third +paragraph of section 11). + + However, if you cease all violation of this License, then your +license from a particular copyright holder is reinstated (a) +provisionally, unless and until the copyright holder explicitly and +finally terminates your license, and (b) permanently, if the copyright +holder fails to notify you of the violation by some reasonable means +prior to 60 days after the cessation. + + Moreover, your license from a particular copyright holder is +reinstated permanently if the copyright holder notifies you of the +violation by some reasonable means, this is the first time you have +received notice of violation of this License (for any work) from that +copyright holder, and you cure the violation prior to 30 days after +your receipt of the notice. + + Termination of your rights under this section does not terminate the +licenses of parties who have received copies or rights from you under +this License. If your rights have been terminated and not permanently +reinstated, you do not qualify to receive new licenses for the same +material under section 10. + + 9. Acceptance Not Required for Having Copies. + + You are not required to accept this License in order to receive or +run a copy of the Program. 
Ancillary propagation of a covered work +occurring solely as a consequence of using peer-to-peer transmission +to receive a copy likewise does not require acceptance. However, +nothing other than this License grants you permission to propagate or +modify any covered work. These actions infringe copyright if you do +not accept this License. Therefore, by modifying or propagating a +covered work, you indicate your acceptance of this License to do so. + + 10. Automatic Licensing of Downstream Recipients. + + Each time you convey a covered work, the recipient automatically +receives a license from the original licensors, to run, modify and +propagate that work, subject to this License. You are not responsible +for enforcing compliance by third parties with this License. + + An "entity transaction" is a transaction transferring control of an +organization, or substantially all assets of one, or subdividing an +organization, or merging organizations. If propagation of a covered +work results from an entity transaction, each party to that +transaction who receives a copy of the work also receives whatever +licenses to the work the party's predecessor in interest had or could +give under the previous paragraph, plus a right to possession of the +Corresponding Source of the work from the predecessor in interest, if +the predecessor has it or can get it with reasonable efforts. + + You may not impose any further restrictions on the exercise of the +rights granted or affirmed under this License. For example, you may +not impose a license fee, royalty, or other charge for exercise of +rights granted under this License, and you may not initiate litigation +(including a cross-claim or counterclaim in a lawsuit) alleging that +any patent claim is infringed by making, using, selling, offering for +sale, or importing the Program or any portion of it. + + 11. Patents. 
+ + A "contributor" is a copyright holder who authorizes use under this +License of the Program or a work on which the Program is based. The +work thus licensed is called the contributor's "contributor version". + + A contributor's "essential patent claims" are all patent claims +owned or controlled by the contributor, whether already acquired or +hereafter acquired, that would be infringed by some manner, permitted +by this License, of making, using, or selling its contributor version, +but do not include claims that would be infringed only as a +consequence of further modification of the contributor version. For +purposes of this definition, "control" includes the right to grant +patent sublicenses in a manner consistent with the requirements of +this License. + + Each contributor grants you a non-exclusive, worldwide, royalty-free +patent license under the contributor's essential patent claims, to +make, use, sell, offer for sale, import and otherwise run, modify and +propagate the contents of its contributor version. + + In the following three paragraphs, a "patent license" is any express +agreement or commitment, however denominated, not to enforce a patent +(such as an express permission to practice a patent or covenant not to +sue for patent infringement). To "grant" such a patent license to a +party means to make such an agreement or commitment not to enforce a +patent against the party. 
+ + If you convey a covered work, knowingly relying on a patent license, +and the Corresponding Source of the work is not available for anyone +to copy, free of charge and under the terms of this License, through a +publicly available network server or other readily accessible means, +then you must either (1) cause the Corresponding Source to be so +available, or (2) arrange to deprive yourself of the benefit of the +patent license for this particular work, or (3) arrange, in a manner +consistent with the requirements of this License, to extend the patent +license to downstream recipients. "Knowingly relying" means you have +actual knowledge that, but for the patent license, your conveying the +covered work in a country, or your recipient's use of the covered work +in a country, would infringe one or more identifiable patents in that +country that you have reason to believe are valid. + + If, pursuant to or in connection with a single transaction or +arrangement, you convey, or propagate by procuring conveyance of, a +covered work, and grant a patent license to some of the parties +receiving the covered work authorizing them to use, propagate, modify +or convey a specific copy of the covered work, then the patent license +you grant is automatically extended to all recipients of the covered +work and works based on it. + + A patent license is "discriminatory" if it does not include within +the scope of its coverage, prohibits the exercise of, or is +conditioned on the non-exercise of one or more of the rights that are +specifically granted under this License. 
You may not convey a covered +work if you are a party to an arrangement with a third party that is +in the business of distributing software, under which you make payment +to the third party based on the extent of your activity of conveying +the work, and under which the third party grants, to any of the +parties who would receive the covered work from you, a discriminatory +patent license (a) in connection with copies of the covered work +conveyed by you (or copies made from those copies), or (b) primarily +for and in connection with specific products or compilations that +contain the covered work, unless you entered into that arrangement, +or that patent license was granted, prior to 28 March 2007. + + Nothing in this License shall be construed as excluding or limiting +any implied license or other defenses to infringement that may +otherwise be available to you under applicable patent law. + + 12. No Surrender of Others' Freedom. + + If conditions are imposed on you (whether by court order, agreement or +otherwise) that contradict the conditions of this License, they do not +excuse you from the conditions of this License. If you cannot convey a +covered work so as to satisfy simultaneously your obligations under this +License and any other pertinent obligations, then as a consequence you may +not convey it at all. For example, if you agree to terms that obligate you +to collect a royalty for further conveying from those to whom you convey +the Program, the only way you could satisfy both those terms and this +License would be to refrain entirely from conveying the Program. + + 13. Use with the GNU Affero General Public License. + + Notwithstanding any other provision of this License, you have +permission to link or combine any covered work with a work licensed +under version 3 of the GNU Affero General Public License into a single +combined work, and to convey the resulting work. 
The terms of this +License will continue to apply to the part which is the covered work, +but the special requirements of the GNU Affero General Public License, +section 13, concerning interaction through a network will apply to the +combination as such. + + 14. Revised Versions of this License. + + The Free Software Foundation may publish revised and/or new versions of +the GNU General Public License from time to time. Such new versions will +be similar in spirit to the present version, but may differ in detail to +address new problems or concerns. + + Each version is given a distinguishing version number. If the +Program specifies that a certain numbered version of the GNU General +Public License "or any later version" applies to it, you have the +option of following the terms and conditions either of that numbered +version or of any later version published by the Free Software +Foundation. If the Program does not specify a version number of the +GNU General Public License, you may choose any version ever published +by the Free Software Foundation. + + If the Program specifies that a proxy can decide which future +versions of the GNU General Public License can be used, that proxy's +public statement of acceptance of a version permanently authorizes you +to choose that version for the Program. + + Later license versions may give you additional or different +permissions. However, no additional obligations are imposed on any +author or copyright holder as a result of your choosing to follow a +later version. + + 15. Disclaimer of Warranty. + + THERE IS NO WARRANTY FOR THE PROGRAM, TO THE EXTENT PERMITTED BY +APPLICABLE LAW. EXCEPT WHEN OTHERWISE STATED IN WRITING THE COPYRIGHT +HOLDERS AND/OR OTHER PARTIES PROVIDE THE PROGRAM "AS IS" WITHOUT WARRANTY +OF ANY KIND, EITHER EXPRESSED OR IMPLIED, INCLUDING, BUT NOT LIMITED TO, +THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR +PURPOSE. 
THE ENTIRE RISK AS TO THE QUALITY AND PERFORMANCE OF THE PROGRAM +IS WITH YOU. SHOULD THE PROGRAM PROVE DEFECTIVE, YOU ASSUME THE COST OF +ALL NECESSARY SERVICING, REPAIR OR CORRECTION. + + 16. Limitation of Liability. + + IN NO EVENT UNLESS REQUIRED BY APPLICABLE LAW OR AGREED TO IN WRITING +WILL ANY COPYRIGHT HOLDER, OR ANY OTHER PARTY WHO MODIFIES AND/OR CONVEYS +THE PROGRAM AS PERMITTED ABOVE, BE LIABLE TO YOU FOR DAMAGES, INCLUDING ANY +GENERAL, SPECIAL, INCIDENTAL OR CONSEQUENTIAL DAMAGES ARISING OUT OF THE +USE OR INABILITY TO USE THE PROGRAM (INCLUDING BUT NOT LIMITED TO LOSS OF +DATA OR DATA BEING RENDERED INACCURATE OR LOSSES SUSTAINED BY YOU OR THIRD +PARTIES OR A FAILURE OF THE PROGRAM TO OPERATE WITH ANY OTHER PROGRAMS), +EVEN IF SUCH HOLDER OR OTHER PARTY HAS BEEN ADVISED OF THE POSSIBILITY OF +SUCH DAMAGES. + + 17. Interpretation of Sections 15 and 16. + + If the disclaimer of warranty and limitation of liability provided +above cannot be given local legal effect according to their terms, +reviewing courts shall apply local law that most closely approximates +an absolute waiver of all civil liability in connection with the +Program, unless a warranty or assumption of liability accompanies a +copy of the Program in return for a fee. + + END OF TERMS AND CONDITIONS + + How to Apply These Terms to Your New Programs + + If you develop a new program, and you want it to be of the greatest +possible use to the public, the best way to achieve this is to make it +free software which everyone can redistribute and change under these terms. + + To do so, attach the following notices to the program. It is safest +to attach them to the start of each source file to most effectively +state the exclusion of warranty; and each file should have at least +the "copyright" line and a pointer to where the full notice is found. 
+ + <one line to give the program's name and a brief idea of what it does.> + Copyright (C) <year> <name of author> + + This program is free software: you can redistribute it and/or modify + it under the terms of the GNU General Public License as published by + the Free Software Foundation, either version 3 of the License, or + (at your option) any later version. + + This program is distributed in the hope that it will be useful, + but WITHOUT ANY WARRANTY; without even the implied warranty of + MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + GNU General Public License for more details. + + You should have received a copy of the GNU General Public License + along with this program. If not, see <https://www.gnu.org/licenses/>. + +Also add information on how to contact you by electronic and paper mail. + + If the program does terminal interaction, make it output a short +notice like this when it starts in an interactive mode: + + <program> Copyright (C) <year> <name of author> + This program comes with ABSOLUTELY NO WARRANTY; for details type `show w'. + This is free software, and you are welcome to redistribute it + under certain conditions; type `show c' for details. + +The hypothetical commands `show w' and `show c' should show the appropriate +parts of the General Public License. Of course, your program's commands +might be different; for a GUI interface, you would use an "about box". + + You should also get your employer (if you work as a programmer) or school, +if any, to sign a "copyright disclaimer" for the program, if necessary. +For more information on this, and how to apply and follow the GNU GPL, see +<https://www.gnu.org/licenses/>. + + The GNU General Public License does not permit incorporating your program +into proprietary programs. If your program is a subroutine library, you +may consider it more useful to permit linking proprietary applications with +the library. If this is what you want to do, use the GNU Lesser General +Public License instead of this License. But first, please read +<https://www.gnu.org/licenses/why-not-lgpl.html>. 
diff --git a/custom_nodes/ComfyUI-FlashVSR_Ultra_Fast/README.md b/custom_nodes/ComfyUI-FlashVSR_Ultra_Fast/README.md new file mode 100644 index 0000000000000000000000000000000000000000..e77018407bf1f45ad5ae61f1017d4b569c624ab7 --- /dev/null +++ b/custom_nodes/ComfyUI-FlashVSR_Ultra_Fast/README.md @@ -0,0 +1,68 @@ +# ComfyUI-FlashVSR_Ultra_Fast +Running FlashVSR on lower VRAM without any artifacts. +**[[📃中文版本](./README_zh.md)]** + +## Changelog +#### 2025-10-24 +- Added long video pipeline that significantly reduces VRAM usage when upscaling long videos. + +#### 2025-10-21 +- Initial release of this project, introducing features such as `tiled_dit` that significantly reduce VRAM usage. + +#### 2025-10-22 +- Replaced `Block-Sparse-Attention` with `Sparse_Sage`, removing the need to compile any custom kernels. +- Added support for running on RTX 50 series GPUs. + +## Preview +![](./img/preview.jpg) + +## Usage +- **mode:** +`tiny` -> faster (default); `full` -> higher quality +- **scale:** +`4` generally gives better results; if you are low on VRAM, use `2` +- **color_fix:** +Use wavelet transform to correct the color of output video. +- **tiled_vae:** +Set to True for lower VRAM consumption during decoding at the cost of speed. +- **tiled_dit:** +Significantly reduces VRAM usage at the cost of speed. +- **tile\_size, tile\_overlap**: +How to split the input video. +- **unload_dit:** +Unload DiT before decoding to reduce VRAM peak at the cost of speed. 
+ +## Installation + +#### nodes: + +```bash +cd ComfyUI/custom_nodes +git clone https://github.com/lihaoyun6/ComfyUI-FlashVSR_Ultra_Fast.git +python -m pip install -r ComfyUI-FlashVSR_Ultra_Fast/requirements.txt +``` +📢: For Turing or older GPUs, please install `triton<3.3.0` (keep the quotes, otherwise the shell treats `<` as input redirection): + +```bash +# Windows +python -m pip install -U "triton-windows<3.3.0" +# Linux +python -m pip install -U "triton<3.3.0" +``` + +#### models: + +- Download the entire `FlashVSR` folder with all the files inside it from [here](https://huggingface.co/JunhaoZhuang/FlashVSR) and put it in `ComfyUI/models` + +``` +├── ComfyUI/models/FlashVSR +| ├── LQ_proj_in.ckpt +| ├── TCDecoder.ckpt +| ├── diffusion_pytorch_model_streaming_dmd.safetensors +| ├── Wan2.1_VAE.pth +``` + +## Acknowledgments +- [FlashVSR](https://github.com/OpenImagingLab/FlashVSR) @OpenImagingLab +- [Sparse_SageAttention](https://github.com/jt-zhang/Sparse_SageAttention_API) @jt-zhang +- [ComfyUI](https://github.com/comfyanonymous/ComfyUI) @comfyanonymous diff --git a/custom_nodes/ComfyUI-FlashVSR_Ultra_Fast/README_zh.md b/custom_nodes/ComfyUI-FlashVSR_Ultra_Fast/README_zh.md new file mode 100644 index 0000000000000000000000000000000000000000..4cd055db1aa5329b15d260d2f9e052e2c48579df --- /dev/null +++ b/custom_nodes/ComfyUI-FlashVSR_Ultra_Fast/README_zh.md @@ -0,0 +1,66 @@ +# ComfyUI-FlashVSR_Ultra_Fast +在低显存环境下运行 FlashVSR,同时保持无伪影高质量输出。 +**[[📃English](./README.md)]** + +## 更新日志 +#### 2025-10-24 +- 新增长视频管道, 可显著降低长视频放大的显存用量 + +#### 2025-10-21 +- 项目首次发布, 引入了`tiled_dit`等功能, 大幅度降低显存需求 + +#### 2025-10-22 +- 使用`Sparse_SageAttention`替换了`Block-Sparse-Attention`, 无需编译安装任何自定义内核, 开箱即用. +- 支持在 RTX50 系列显卡上运行. 
+ +## 预览 +![](./img/preview.jpg) + +## 使用说明 +- **mode(模式):** + `tiny` → 更快(默认);`full` → 更高质量 +- **scale(放大倍数):** + 通常使用 `4` 效果更好;如果显存不足,可使用 `2` +- **color_fix(颜色修正):** + 使用小波变换方法修正输出视频的颜色偏差。 +- **tiled_vae(VAE分块解码):** + 启用后可显著降低显存占用,但会降低解码速度。 +- **tiled_dit(DiT分块计算):** + 大幅减少显存占用,但会降低推理速度。 +- **tile_size / tile_overlap(分块大小与重叠):** + 控制输入视频在推理时的分块方式。 +- **unload_dit(卸载DiT模型):** + 解码前卸载 DiT 模型以降低显存峰值,但会略微降低速度。 + +## 安装步骤 + +#### 安装节点: +```bash +cd ComfyUI/custom_nodes +git clone https://github.com/lihaoyun6/ComfyUI-FlashVSR_Ultra_Fast.git +python -m pip install -r ComfyUI-FlashVSR_Ultra_Fast/requirements.txt +``` +📢: 要在RTX20系或更早的GPU上运行, 请安装`triton<3.3.0` (请保留引号, 否则 shell 会把 `<` 当作输入重定向): + +```bash +# Windows +python -m pip install -U "triton-windows<3.3.0" +# Linux +python -m pip install -U "triton<3.3.0" +``` + +#### 模型下载: +- 从[这里](https://huggingface.co/JunhaoZhuang/FlashVSR)下载整个`FlashVSR`文件夹和它里面的所有文件, 并将其放到`ComfyUI/models`目录中。 + +``` +├── ComfyUI/models/FlashVSR +| ├── LQ_proj_in.ckpt +| ├── TCDecoder.ckpt +| ├── diffusion_pytorch_model_streaming_dmd.safetensors +| ├── Wan2.1_VAE.pth +``` + +## 致谢 +- [FlashVSR](https://github.com/OpenImagingLab/FlashVSR) @OpenImagingLab +- [Sparse_SageAttention](https://github.com/jt-zhang/Sparse_SageAttention_API) @jt-zhang +- [ComfyUI](https://github.com/comfyanonymous/ComfyUI) @comfyanonymous diff --git a/custom_nodes/ComfyUI-FlashVSR_Ultra_Fast/__init__.py b/custom_nodes/ComfyUI-FlashVSR_Ultra_Fast/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..2e96bd6ab3db650f769ae7886e0c13515752bd16 --- /dev/null +++ b/custom_nodes/ComfyUI-FlashVSR_Ultra_Fast/__init__.py @@ -0,0 +1,3 @@ +from .nodes import NODE_CLASS_MAPPINGS, NODE_DISPLAY_NAME_MAPPINGS + +__all__ = ["NODE_CLASS_MAPPINGS", "NODE_DISPLAY_NAME_MAPPINGS"] \ No newline at end of file diff --git a/custom_nodes/ComfyUI-FlashVSR_Ultra_Fast/img/preview.jpg b/custom_nodes/ComfyUI-FlashVSR_Ultra_Fast/img/preview.jpg new file mode 100644 index 
# ANSI SGR prefixes used by log(); message types not listed here print uncolored.
_LOG_STYLES = {
    'error': '\033[1;41m',
    'warning': '\033[1;31m',
    'finish': '\033[1;32m',
    'info': '\033[1;33m',
}


def log(message: str, message_type: str = 'normal'):
    """Print *message* to stdout, colored according to *message_type*.

    Args:
        message: Text to print. Non-string objects (e.g. an exception) are
            formatted with ``str()`` instead of raising ``TypeError``.
        message_type: One of ``'error'``, ``'warning'``, ``'finish'`` or
            ``'info'`` for colored output; any other value prints the
            message without escape codes.
    """
    prefix = _LOG_STYLES.get(message_type)
    if prefix is None:
        print(f"{message}")
    else:
        # '\033[m' resets the terminal attributes after the message.
        print(f"{prefix}{message}\033[m")


def model_downlod(model_name="JunhaoZhuang/FlashVSR"):
    """Download *model_name* from Hugging Face if it is not present locally.

    The target folder is ``<folder_paths.models_dir>/<repo name>``. The
    download is skipped when that path already exists, except when it is an
    empty directory (typically left behind by an interrupted first attempt),
    which is treated as "not downloaded yet".

    Args:
        model_name: Hugging Face repo id in ``"owner/name"`` form.
    """
    model_dir = os.path.join(folder_paths.models_dir, model_name.split("/")[-1])
    is_empty_dir = os.path.isdir(model_dir) and not os.listdir(model_dir)
    if os.path.exists(model_dir) and not is_empty_dir:
        return
    log(f"Downloading model '{model_name}' from huggingface...", message_type='info')
    snapshot_download(repo_id=model_name, local_dir=model_dir, local_dir_use_symlinks=False, resume_download=True)
def tensor2video(frames: torch.Tensor):
    """Convert a pipeline output tensor into an image batch.

    Squeezes dimension 0, reorders ``C F H W`` to ``F H W C`` and maps the
    values with ``(x + 1) / 2`` (no clamping is applied).
    """
    video = rearrange(frames.squeeze(0), "C F H W -> F H W C")
    return (video.float() + 1.0) / 2.0


def largest_8n1_leq(n):
    """Return the largest number of the form ``8k + 1`` that is <= n (0 if n < 1)."""
    return 0 if n < 1 else ((n - 1) // 8) * 8 + 1


def next_8n5(n):
    """Return the smallest number of the form ``8k + 5`` that is >= n, but at least 21."""
    return 21 if n < 21 else ((n - 5 + 7) // 8) * 8 + 5


def compute_scaled_and_target_dims(w0: int, h0: int, scale: int = 4, multiple: int = 128):
    """Compute the upscaled size and the size rounded down to *multiple*.

    Returns:
        ``(sW, sH, tW, tH)``: scaled width/height and target width/height.
        Targets are floored to a multiple of *multiple* but never below it.

    Raises:
        ValueError: If either original dimension is not positive.
    """
    if w0 <= 0 or h0 <= 0:
        raise ValueError("invalid original size")

    sW, sH = w0 * scale, h0 * scale
    tW = max(multiple, (sW // multiple) * multiple)
    tH = max(multiple, (sH // multiple) * multiple)
    return sW, sH, tW, tH


def tensor_upscale_then_center_crop(frame_tensor: torch.Tensor, scale: int, tW: int, tH: int) -> torch.Tensor:
    """Bicubically upscale one ``H W C`` frame by *scale* and center-crop it.

    Returns a ``C H W`` tensor. NOTE(review): when the scaled frame is smaller
    than ``(tH, tW)`` the slice cannot grow it, so the result is smaller than
    the requested target; callers appear to assume this does not happen.
    """
    h0, w0, _ = frame_tensor.shape
    tensor_bchw = frame_tensor.permute(2, 0, 1).unsqueeze(0)  # HWC -> CHW -> BCHW

    sW, sH = w0 * scale, h0 * scale
    upscaled = F.interpolate(tensor_bchw, size=(sH, sW), mode='bicubic', align_corners=False)

    left = max(0, (sW - tW) // 2)
    top = max(0, (sH - tH) // 2)
    return upscaled[:, :, top:top + tH, left:left + tW].squeeze(0)


def prepare_input_tensor(image_tensor: torch.Tensor, device, scale: int = 4, dtype=torch.bfloat16):
    """Build the low-quality video tensor that is fed to the pipeline.

    Each frame is upscaled/cropped on *device*, moved back to the CPU, cast to
    *dtype* and mapped with ``x * 2 - 1``. The frame count is
    ``largest_8n1_leq(N + 4)``; indices past the last input frame repeat it.

    Returns:
        ``(video, tH, tW, num_frames)`` where ``video`` is ``1 C F H W`` on CPU.

    Raises:
        RuntimeError: If the computed frame count is 0.
    """
    N0, h0, w0, _ = image_tensor.shape

    multiple = 128
    sW, sH, tW, tH = compute_scaled_and_target_dims(w0, h0, scale=scale, multiple=multiple)
    num_frames_with_padding = N0 + 4
    # Named num_frames (not F) so the torch.nn.functional alias is not shadowed.
    num_frames = largest_8n1_leq(num_frames_with_padding)

    if num_frames == 0:
        raise RuntimeError(f"Not enough frames after padding. Got {num_frames_with_padding}.")

    frames = []
    for i in range(num_frames):
        frame_idx = min(i, N0 - 1)
        frame_slice = image_tensor[frame_idx].to(device)
        tensor_chw = tensor_upscale_then_center_crop(frame_slice, scale=scale, tW=tW, tH=tH).to('cpu').to(dtype) * 2.0 - 1.0
        frames.append(tensor_chw)
        del frame_slice

    vid_stacked = torch.stack(frames, 0)
    vid_final = vid_stacked.permute(1, 0, 2, 3).unsqueeze(0)  # F C H W -> 1 C F H W

    del vid_stacked
    clean_vram()

    return vid_final, tH, tW, num_frames


def calculate_tile_coords(height, width, tile_size, overlap):
    """Return ``(x1, y1, x2, y2)`` boxes that tile a ``height x width`` image.

    Tiles advance by ``tile_size - overlap``. A tile that would stick out of
    the image is shifted back so it keeps the full tile size where possible.
    At least one tile per axis is always produced, even for images smaller
    than the overlap.

    Raises:
        ValueError: If ``overlap >= tile_size`` (the stride would not be positive).
    """
    stride = tile_size - overlap
    if stride <= 0:
        raise ValueError(f"tile_overlap ({overlap}) must be smaller than tile_size ({tile_size})")

    num_rows = max(1, math.ceil((height - overlap) / stride))
    num_cols = max(1, math.ceil((width - overlap) / stride))

    coords = []
    for r in range(num_rows):
        for c in range(num_cols):
            y1 = r * stride
            x1 = c * stride

            y2 = min(y1 + tile_size, height)
            x2 = min(x1 + tile_size, width)

            # Shift border tiles back inside the image instead of shrinking them.
            if y2 - y1 < tile_size:
                y1 = max(0, y2 - tile_size)
            if x2 - x1 < tile_size:
                x1 = max(0, x2 - tile_size)

            coords.append((x1, y1, x2, y2))

    return coords


def create_feather_mask(size, overlap):
    """Create a ``1 x 1 x H x W`` blending mask with linear ramps on all four edges.

    The ramp values are strictly positive, so a pixel covered by a single tile
    (e.g. on the outer image border) keeps a non-zero weight and is not turned
    black by the weight normalisation. *overlap* is clamped to the mask size;
    ``overlap <= 0`` yields an all-ones mask.
    """
    H, W = size
    mask = torch.ones(1, 1, H, W)
    overlap = min(overlap, H, W)
    if overlap <= 0:
        return mask

    # Drop both endpoints of the 0..1 range: values are k / (overlap + 1), k = 1..overlap.
    ramp = torch.linspace(0, 1, overlap + 2)[1:-1]

    mask[:, :, :, :overlap] = torch.minimum(mask[:, :, :, :overlap], ramp.view(1, 1, 1, -1))
    mask[:, :, :, -overlap:] = torch.minimum(mask[:, :, :, -overlap:], ramp.flip(0).view(1, 1, 1, -1))

    mask[:, :, :overlap, :] = torch.minimum(mask[:, :, :overlap, :], ramp.view(1, 1, -1, 1))
    mask[:, :, -overlap:, :] = torch.minimum(mask[:, :, -overlap:, :], ramp.flip(0).view(1, 1, -1, 1))

    return mask


def _require_weight(model_path, filename):
    """Return the path of *filename* inside *model_path* or raise RuntimeError."""
    path = os.path.join(model_path, filename)
    if not os.path.exists(path):
        raise RuntimeError(f'"{filename}" does not exist!\nPlease save it to "{model_path}"')
    return path


def init_pipeline(model, mode, device, dtype, alt_vae="none"):
    """Load the weights for *model* and build a ready-to-run pipeline.

    Args:
        model: Model folder name under the ComfyUI models directory.
        mode: ``"full"``, ``"tiny"``, or anything else for the tiny-long pipeline.
        device: Torch device string the pipeline runs on.
        dtype: Torch dtype used for the weights.
        alt_vae: File name from the ComfyUI "vae" folder, or ``"none"`` to use
            the bundled ``Wan2.1_VAE.pth``.

    Raises:
        RuntimeError: If the model directory or a required weight file is missing.
    """
    model_downlod(model_name="JunhaoZhuang/" + model)
    model_path = os.path.join(folder_paths.models_dir, model)
    if not os.path.exists(model_path):
        raise RuntimeError(f'Model directory does not exist!\nPlease save all weights to "{model_path}"')

    ckpt_path = _require_weight(model_path, "diffusion_pytorch_model_streaming_dmd.safetensors")
    if alt_vae != "none":
        vae_path = folder_paths.get_full_path_or_raise("vae", alt_vae)
        if not os.path.exists(vae_path):
            raise RuntimeError(f'"{alt_vae}" does not exist!')
    else:
        vae_path = _require_weight(model_path, "Wan2.1_VAE.pth")
    lq_path = _require_weight(model_path, "LQ_proj_in.ckpt")
    tcd_path = _require_weight(model_path, "TCDecoder.ckpt")
    current_dir = os.path.dirname(os.path.abspath(__file__))
    prompt_path = os.path.join(current_dir, "posi_prompt.pth")

    mm = ModelManager(torch_dtype=dtype, device="cpu")
    if mode == "full":
        mm.load_models([ckpt_path, vae_path])
        pipe = FlashVSRFullPipeline.from_model_manager(mm, device=device)
        pipe.vae.model.encoder = None
        pipe.vae.model.conv1 = None
    else:
        # NOTE(review): nesting reconstructed from flattened text; the TCDecoder
        # setup is assumed to apply to the non-"full" pipelines only.
        mm.load_models([ckpt_path])
        if mode == "tiny":
            pipe = FlashVSRTinyPipeline.from_model_manager(mm, device=device)
        else:
            pipe = FlashVSRTinyLongPipeline.from_model_manager(mm, device=device)
        multi_scale_channels = [512, 256, 128, 128]
        pipe.TCDecoder = build_tcdecoder(new_channels=multi_scale_channels, device=device, dtype=dtype, new_latent_channels=16+768)
        # SECURITY NOTE: torch.load unpickles the checkpoint; only use trusted weight files.
        pipe.TCDecoder.load_state_dict(torch.load(tcd_path, map_location=device), strict=False)
        pipe.TCDecoder.clean_mem()

    if model == "FlashVSR":
        pipe.denoising_model().LQ_proj_in = Buffer_LQ4x_Proj(in_dim=3, out_dim=1536, layer_num=1).to(device, dtype=dtype)
    else:
        pipe.denoising_model().LQ_proj_in = Causal_LQ4x_Proj(in_dim=3, out_dim=1536, layer_num=1).to(device, dtype=dtype)
    # SECURITY NOTE: see above, this also unpickles a downloaded checkpoint.
    pipe.denoising_model().LQ_proj_in.load_state_dict(torch.load(lq_path, map_location="cpu"), strict=True)
    pipe.denoising_model().LQ_proj_in.to(device)
    pipe.to(device, dtype=dtype)
    pipe.enable_vram_management(num_persistent_param_in_dit=None)
    pipe.init_cross_kv(prompt_path=prompt_path)
    pipe.load_models_to_device(["dit", "vae"])
    pipe.offload_model()

    return pipe


class cqdm:
    """Minimal tqdm-like wrapper that reports progress to a ComfyUI ProgressBar.

    Use it either as an iterator (``for x in cqdm(items)``) or as a context
    manager (``with cqdm(total=n) as pbar``), which yields the ProgressBar.
    """

    def __init__(self, iterable=None, total=None, desc="Processing"):
        """Store the iterable/total.

        Raises:
            ValueError: If neither *iterable* nor *total* is given, or if
                *iterable* has no length and *total* is missing.
        """
        self.desc = desc
        self.pbar = None
        self.iterable = None
        self.total = total

        if iterable is None:
            if total is None:
                raise ValueError("Either iterable or total must be provided.")
            return

        try:
            self.total = len(iterable)
        except TypeError:
            if total is None:
                raise ValueError("Total must be provided for iterables with no length.") from None
        # Set the iterator for sized and unsized iterables alike, so a generator
        # combined with an explicit total can actually be iterated.
        self.iterable = iter(iterable)

    def __iter__(self):
        if self.iterable is None:
            raise TypeError(f"'{type(self).__name__}' object is not iterable. Did you mean to use it with a 'with' statement?")
        if self.pbar is None:
            self.pbar = comfy.utils.ProgressBar(self.total)
        return self

    def __next__(self):
        if self.iterable is None:
            raise TypeError("Cannot call __next__ on a non-iterable cqdm object.")
        val = next(self.iterable)  # StopIteration propagates to end the loop
        if self.pbar:
            self.pbar.update(1)
        return val

    def __enter__(self):
        if self.pbar is None:
            self.pbar = comfy.utils.ProgressBar(self.total)
        return self.pbar

    def __exit__(self, exc_type, exc_val, exc_tb):
        pass

    def __len__(self):
        return self.total


def flashvsr(pipe, frames, scale, color_fix, tiled_vae, tiled_dit, tile_size, tile_overlap, unload_dit, sparse_ratio, kv_ratio, local_range, seed, force_offload):
    """Upscale *frames* with *pipe*, optionally tile by tile.

    The input batch is padded by repeating the last frame up to
    ``next_8n5(len(frames))`` frames. With *tiled_dit* each spatial tile is
    processed separately and blended on a CPU canvas using feather masks.
    A single input frame is returned as the per-pixel median over the
    generated frames; otherwise the output is cut back to the input length.

    Returns:
        CPU image tensor in ``F H W C`` layout.

    Raises:
        ValueError: If *frames* is empty.
    """
    if frames.shape[0] == 0:
        raise ValueError("frames must contain at least one image")

    _device = pipe.device
    dtype = pipe.torch_dtype

    add = next_8n5(frames.shape[0]) - frames.shape[0]
    padding_frames = frames[-1:, :, :, :].repeat(add, 1, 1, 1)
    _frames = torch.cat([frames, padding_frames], dim=0)

    if tiled_dit:
        N, H, W, C = _frames.shape

        final_output_canvas = torch.zeros(
            (N, H * scale, W * scale, C),
            dtype=torch.float16,
            device="cpu"
        )
        weight_sum_canvas = torch.zeros_like(final_output_canvas)
        tile_coords = calculate_tile_coords(H, W, tile_size, tile_overlap)

        for i, (x1, y1, x2, y2) in enumerate(cqdm(tile_coords, desc="Processing Tiles")):
            log(f"[FlashVSR] Processing tile {i+1}/{len(tile_coords)}: coords ({x1},{y1}) to ({x2},{y2})", message_type='info')
            input_tile = _frames[:, y1:y2, x1:x2, :]

            LQ_tile, th, tw, num_frames = prepare_input_tensor(input_tile, _device, scale=scale, dtype=dtype)
            if not isinstance(pipe, FlashVSRTinyLongPipeline):
                LQ_tile = LQ_tile.to(_device)

            output_tile_gpu = pipe(
                prompt="", negative_prompt="", cfg_scale=1.0, num_inference_steps=1, seed=seed, tiled=tiled_vae,
                LQ_video=LQ_tile, num_frames=num_frames, height=th, width=tw, is_full_block=False, if_buffer=True,
                topk_ratio=sparse_ratio*768*1280/(th*tw), kv_ratio=kv_ratio, local_range=local_range,
                color_fix=color_fix, unload_dit=unload_dit, force_offload=force_offload
            )

            processed_tile_cpu = tensor2video(output_tile_gpu).to("cpu")

            tile_H_scaled = processed_tile_cpu.shape[1]
            tile_W_scaled = processed_tile_cpu.shape[2]
            mask_nchw = create_feather_mask(
                (tile_H_scaled, tile_W_scaled),
                tile_overlap * scale
            ).to("cpu")
            mask_nhwc = mask_nchw.permute(0, 2, 3, 1)

            out_x1, out_y1 = x1 * scale, y1 * scale
            out_x2, out_y2 = out_x1 + tile_W_scaled, out_y1 + tile_H_scaled
            # Accumulate weighted pixels and weights; divided after the loop.
            final_output_canvas[:, out_y1:out_y2, out_x1:out_x2, :] += processed_tile_cpu * mask_nhwc
            weight_sum_canvas[:, out_y1:out_y2, out_x1:out_x2, :] += mask_nhwc

            del LQ_tile, output_tile_gpu, processed_tile_cpu, input_tile
            clean_vram()

        # Avoid division by zero for pixels no tile contributed to.
        weight_sum_canvas[weight_sum_canvas == 0] = 1.0
        final_output = final_output_canvas / weight_sum_canvas
    else:
        log("[FlashVSR] Preparing frames...")
        LQ, th, tw, num_frames = prepare_input_tensor(_frames, _device, scale=scale, dtype=dtype)
        if not isinstance(pipe, FlashVSRTinyLongPipeline):
            LQ = LQ.to(_device)
        log(f"[FlashVSR] Processing {frames.shape[0]} frames...", message_type='info')

        video = pipe(
            prompt="", negative_prompt="", cfg_scale=1.0, num_inference_steps=1, seed=seed, tiled=tiled_vae,
            progress_bar_cmd=cqdm, LQ_video=LQ, num_frames=num_frames, height=th, width=tw, is_full_block=False, if_buffer=True,
            topk_ratio=sparse_ratio*768*1280/(th*tw), kv_ratio=kv_ratio, local_range=local_range,
            color_fix=color_fix, unload_dit=unload_dit, force_offload=force_offload
        )

        final_output = tensor2video(video).to('cpu')

        del video, LQ
        clean_vram()

    log("[FlashVSR] Done.", message_type='info')
    if frames.shape[0] == 1:
        # Single image: collapse the generated frames into one via the median.
        final_output = final_output.to(_device)
        stacked_image_tensor = torch.median(final_output, dim=0).values.unsqueeze(0).float().to('cpu')
        del final_output
        clean_vram()
        return stacked_image_tensor

    return final_output[:frames.shape[0], :, :, :]


class FlashVSRNodeInitPipe:
    """ComfyUI node that builds a FlashVSR pipeline for the advanced node."""

    @classmethod
    def INPUT_TYPES(cls):
        return {
            "required": {
                "model": (["FlashVSR", "FlashVSR-v1.1"], {
                    "default": "FlashVSR-v1.1",
                    "tooltip": "Model version."
                }),
                "mode": (["tiny", "tiny-long", "full"], {
                    "default": "tiny",
                    "tooltip": 'Using "tiny-long" mode can significantly reduce VRAM used with long video input.'
                }),
                "alt_vae": (["none"] + folder_paths.get_filename_list("vae"), {
                    "default": "none",
                    "tooltip": 'Replaces the built-in VAE, only available in "full" mode.'
                }),
                "force_offload": ("BOOLEAN", {
                    "default": True,
                    "tooltip": "Offload all weights to CPU after running a workflow to free up VRAM."
                }),
                "precision": (["fp16", "bf16"], {
                    "default": "bf16",
                    "tooltip": "Data and inference precision."
                }),
                "device": (device_choices, {
                    "default": device_choices[0],
                    "tooltip": "Device to load the weights, default: auto (CUDA if available, else CPU)"
                }),
                "attention_mode": (["sparse_sage_attention", "block_sparse_attention"], {
                    "default": "sparse_sage_attention",
                    "tooltip": '"sparse_sage_attention" is available for sm_75 to sm_120\n"block_sparse_attention" is available for sm_80 to sm_100'
                }),
            }
        }

    RETURN_TYPES = ("PIPE",)
    RETURN_NAMES = ("pipe",)
    FUNCTION = "main"
    CATEGORY = "FlashVSR"
    DESCRIPTION = 'Download the entire "FlashVSR" folder with all the files inside it from "https://huggingface.co/JunhaoZhuang/FlashVSR" and put it in the "ComfyUI/models"'

    def main(self, model, mode, alt_vae, force_offload, precision, device, attention_mode):
        """Resolve the device/dtype, select the attention backend and build the pipeline.

        Returns:
            One-element tuple holding ``(pipe, force_offload)``.

        Raises:
            RuntimeError: If no usable device is found.
        """
        _device = device
        if device == "auto":
            _device = "cuda:0" if torch.cuda.is_available() else "mps" if torch.backends.mps.is_available() else device
        if _device == "auto" or _device not in device_choices:
            raise RuntimeError("No devices found to run FlashVSR!")

        if _device.startswith("cuda"):
            torch.cuda.set_device(_device)

        wan_video_dit.USE_BLOCK_ATTN = attention_mode != "sparse_sage_attention"

        dtype_map = {
            "fp32": torch.float32,
            "fp16": torch.float16,
            "bf16": torch.bfloat16,
        }
        # Unknown precision names fall back to bf16.
        dtype = dtype_map.get(precision, torch.bfloat16)

        pipe = init_pipeline(model, mode, _device, dtype, alt_vae=alt_vae)
        return ((pipe, force_offload),)


class FlashVSRNodeAdv:
    """ComfyUI node that runs FlashVSR on a prepared pipeline with all options exposed."""

    @classmethod
    def INPUT_TYPES(cls):
        return {
            "required": {
                "pipe": ("PIPE", {
                    "tooltip": "FlashVSR pipeline"
                }),
                "frames": ("IMAGE", {
                    "tooltip": "Sequential video frames as IMAGE tensor batch"
                }),
                "scale": ("INT", {
                    "default": 2,
                    "min": 2,
                    "max": 4,
                }),
                "color_fix": ("BOOLEAN", {
                    "default": True,
                    "tooltip": "Use wavelet transform to correct output video color."
                }),
                "tiled_vae": ("BOOLEAN", {
                    "default": True,
                    "tooltip": "Disable tiling: faster decode but higher VRAM usage.\nSet to True for lower memory consumption at the cost of speed."
                }),
                "tiled_dit": ("BOOLEAN", {
                    "default": True,
                    "tooltip": "Significantly reduces VRAM usage at the cost of speed."
                }),
                "tile_size": ("INT", {
                    "default": 256,
                    "min": 32,
                    "max": 1024,
                    "step": 32,
                }),
                "tile_overlap": ("INT", {
                    "default": 24,
                    "min": 8,
                    "max": 512,
                    "step": 8,
                }),
                "unload_dit": ("BOOLEAN", {
                    "default": False,
                    "tooltip": "Unload DiT before decoding to reduce VRAM peak at the cost of speed."
                }),
                "sparse_ratio": ("FLOAT", {
                    "default": 2.0,
                    "min": 1.5,
                    "max": 2.0,
                    "step": 0.1,
                    "display": "slider",
                    "tooltip": "Recommended: 1.5 or 2.0\n1.5 → faster; 2.0 → more stable"
                }),
                "kv_ratio": ("FLOAT", {
                    "default": 3.0,
                    "min": 1.0,
                    "max": 3.0,
                    "step": 0.1,
                    "display": "slider",
                    "tooltip": "Recommended: 1.0 to 3.0\n1.0 → less vram; 3.0 → high quality"
                }),
                "local_range": ("INT", {
                    "default": 11,
                    "min": 9,
                    "max": 11,
                    "step": 2,
                    "tooltip": "Recommended: 9 or 11\nlocal_range=9 → sharper details; 11 → more stable results"
                }),
                "seed": ("INT", {
                    "default": 0,
                    "min": 0,
                    "max": 1125899906842624
                }),
            }
        }

    RETURN_TYPES = ("IMAGE",)
    RETURN_NAMES = ("image",)
    FUNCTION = "main"
    CATEGORY = "FlashVSR"

    def main(self, pipe, frames, scale, color_fix, tiled_vae, tiled_dit, tile_size, tile_overlap, unload_dit, sparse_ratio, kv_ratio, local_range, seed):
        """Unpack ``(pipe, force_offload)`` and run :func:`flashvsr`."""
        _pipe, force_offload = pipe
        output = flashvsr(_pipe, frames, scale, color_fix, tiled_vae, tiled_dit, tile_size, tile_overlap, unload_dit, sparse_ratio, kv_ratio, local_range, seed, force_offload)
        return (output,)


class FlashVSRNode:
    """ComfyUI node that builds the pipeline and runs FlashVSR with default settings."""

    @classmethod
    def INPUT_TYPES(cls):
        return {
            "required": {
                "frames": ("IMAGE", {
                    "tooltip": "Sequential video frames as IMAGE tensor batch"
                }),
                "model": (["FlashVSR", "FlashVSR-v1.1"], {
                    "default": "FlashVSR-v1.1",
                    "tooltip": "Model version."
                }),
                "mode": (["tiny", "tiny-long", "full"], {
                    "default": "tiny",
                    "tooltip": 'Using "tiny-long" mode can significantly reduce VRAM used with long video input.'
                }),
                "scale": ("INT", {
                    "default": 2,
                    "min": 2,
                    "max": 4,
                }),
                "tiled_vae": ("BOOLEAN", {
                    "default": True,
                    "tooltip": "Disable tiling: faster decode but higher VRAM usage.\nSet to True for lower memory consumption at the cost of speed."
                }),
                "tiled_dit": ("BOOLEAN", {
                    "default": True,
                    "tooltip": "Significantly reduces VRAM usage at the cost of speed."
                }),
                "unload_dit": ("BOOLEAN", {
                    "default": False,
                    "tooltip": "Unload DiT before decoding to reduce VRAM peak at the cost of speed."
                }),
                "seed": ("INT", {
                    "default": 0,
                    "min": 0,
                    "max": 1125899906842624
                }),
            }
        }

    RETURN_TYPES = ("IMAGE",)
    RETURN_NAMES = ("image",)
    FUNCTION = "main"
    CATEGORY = "FlashVSR"
    DESCRIPTION = 'Download the entire "FlashVSR" folder with all the files inside it from "https://huggingface.co/JunhaoZhuang/FlashVSR" and put it in the "ComfyUI/models"'

    def main(self, model, frames, mode, scale, tiled_vae, tiled_dit, unload_dit, seed):
        """Auto-select a device, build an fp16 pipeline and upscale *frames*.

        Uses sparse sage attention, color fix on, tile size 256, overlap 24,
        sparse_ratio 2.0, kv_ratio 3.0, local_range 11 and forced offload.

        Raises:
            RuntimeError: If neither CUDA nor MPS is usable.
        """
        wan_video_dit.USE_BLOCK_ATTN = False
        _device = "cuda:0" if torch.cuda.is_available() else "mps" if torch.backends.mps.is_available() else "auto"
        if _device == "auto" or _device not in device_choices:
            raise RuntimeError("No devices found to run FlashVSR!")

        pipe = init_pipeline(model, mode, _device, torch.float16)
        output = flashvsr(pipe, frames, scale, True, tiled_vae, tiled_dit, 256, 24, unload_dit, 2.0, 3.0, 11, seed, True)
        return (output,)


NODE_CLASS_MAPPINGS = {
    "FlashVSRNode": FlashVSRNode,
    "FlashVSRNodeAdv": FlashVSRNodeAdv,
    "FlashVSRInitPipe": FlashVSRNodeInitPipe,
}

NODE_DISPLAY_NAME_MAPPINGS = {
    "FlashVSRNode": "FlashVSR Ultra-Fast",
    "FlashVSRNodeAdv": "FlashVSR Ultra-Fast (Advanced)",
    "FlashVSRInitPipe": "FlashVSR Init Pipeline",
}
b/custom_nodes/ComfyUI-FlashVSR_Ultra_Fast/requirements.txt new file mode 100644 index 0000000000000000000000000000000000000000..9412e6c10028f152820bbeb6a0b508ec0b581efa --- /dev/null +++ b/custom_nodes/ComfyUI-FlashVSR_Ultra_Fast/requirements.txt @@ -0,0 +1,10 @@ +torch +torchvision +numpy +einops +safetensors +tqdm +pillow +huggingface_hub +triton; platform_system!="Windows" +triton-windows; platform_system=="Windows" diff --git a/custom_nodes/ComfyUI-FlashVSR_Ultra_Fast/src/LICENSE.txt b/custom_nodes/ComfyUI-FlashVSR_Ultra_Fast/src/LICENSE.txt new file mode 100644 index 0000000000000000000000000000000000000000..261eeb9e9f8b2b4b0d119366dda99c6fd7d35c64 --- /dev/null +++ b/custom_nodes/ComfyUI-FlashVSR_Ultra_Fast/src/LICENSE.txt @@ -0,0 +1,201 @@ + Apache License + Version 2.0, January 2004 + http://www.apache.org/licenses/ + + TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION + + 1. Definitions. + + "License" shall mean the terms and conditions for use, reproduction, + and distribution as defined by Sections 1 through 9 of this document. + + "Licensor" shall mean the copyright owner or entity authorized by + the copyright owner that is granting the License. + + "Legal Entity" shall mean the union of the acting entity and all + other entities that control, are controlled by, or are under common + control with that entity. For the purposes of this definition, + "control" means (i) the power, direct or indirect, to cause the + direction or management of such entity, whether by contract or + otherwise, or (ii) ownership of fifty percent (50%) or more of the + outstanding shares, or (iii) beneficial ownership of such entity. + + "You" (or "Your") shall mean an individual or Legal Entity + exercising permissions granted by this License. + + "Source" form shall mean the preferred form for making modifications, + including but not limited to software source code, documentation + source, and configuration files. 
+ + "Object" form shall mean any form resulting from mechanical + transformation or translation of a Source form, including but + not limited to compiled object code, generated documentation, + and conversions to other media types. + + "Work" shall mean the work of authorship, whether in Source or + Object form, made available under the License, as indicated by a + copyright notice that is included in or attached to the work + (an example is provided in the Appendix below). + + "Derivative Works" shall mean any work, whether in Source or Object + form, that is based on (or derived from) the Work and for which the + editorial revisions, annotations, elaborations, or other modifications + represent, as a whole, an original work of authorship. For the purposes + of this License, Derivative Works shall not include works that remain + separable from, or merely link (or bind by name) to the interfaces of, + the Work and Derivative Works thereof. + + "Contribution" shall mean any work of authorship, including + the original version of the Work and any modifications or additions + to that Work or Derivative Works thereof, that is intentionally + submitted to Licensor for inclusion in the Work by the copyright owner + or by an individual or Legal Entity authorized to submit on behalf of + the copyright owner. For the purposes of this definition, "submitted" + means any form of electronic, verbal, or written communication sent + to the Licensor or its representatives, including but not limited to + communication on electronic mailing lists, source code control systems, + and issue tracking systems that are managed by, or on behalf of, the + Licensor for the purpose of discussing and improving the Work, but + excluding communication that is conspicuously marked or otherwise + designated in writing by the copyright owner as "Not a Contribution." 
+ + "Contributor" shall mean Licensor and any individual or Legal Entity + on behalf of whom a Contribution has been received by Licensor and + subsequently incorporated within the Work. + + 2. Grant of Copyright License. Subject to the terms and conditions of + this License, each Contributor hereby grants to You a perpetual, + worldwide, non-exclusive, no-charge, royalty-free, irrevocable + copyright license to reproduce, prepare Derivative Works of, + publicly display, publicly perform, sublicense, and distribute the + Work and such Derivative Works in Source or Object form. + + 3. Grant of Patent License. Subject to the terms and conditions of + this License, each Contributor hereby grants to You a perpetual, + worldwide, non-exclusive, no-charge, royalty-free, irrevocable + (except as stated in this section) patent license to make, have made, + use, offer to sell, sell, import, and otherwise transfer the Work, + where such license applies only to those patent claims licensable + by such Contributor that are necessarily infringed by their + Contribution(s) alone or by combination of their Contribution(s) + with the Work to which such Contribution(s) was submitted. If You + institute patent litigation against any entity (including a + cross-claim or counterclaim in a lawsuit) alleging that the Work + or a Contribution incorporated within the Work constitutes direct + or contributory patent infringement, then any patent licenses + granted to You under this License for that Work shall terminate + as of the date such litigation is filed. + + 4. Redistribution. 
You may reproduce and distribute copies of the + Work or Derivative Works thereof in any medium, with or without + modifications, and in Source or Object form, provided that You + meet the following conditions: + + (a) You must give any other recipients of the Work or + Derivative Works a copy of this License; and + + (b) You must cause any modified files to carry prominent notices + stating that You changed the files; and + + (c) You must retain, in the Source form of any Derivative Works + that You distribute, all copyright, patent, trademark, and + attribution notices from the Source form of the Work, + excluding those notices that do not pertain to any part of + the Derivative Works; and + + (d) If the Work includes a "NOTICE" text file as part of its + distribution, then any Derivative Works that You distribute must + include a readable copy of the attribution notices contained + within such NOTICE file, excluding those notices that do not + pertain to any part of the Derivative Works, in at least one + of the following places: within a NOTICE text file distributed + as part of the Derivative Works; within the Source form or + documentation, if provided along with the Derivative Works; or, + within a display generated by the Derivative Works, if and + wherever such third-party notices normally appear. The contents + of the NOTICE file are for informational purposes only and + do not modify the License. You may add Your own attribution + notices within Derivative Works that You distribute, alongside + or as an addendum to the NOTICE text from the Work, provided + that such additional attribution notices cannot be construed + as modifying the License. 
+ + You may add Your own copyright statement to Your modifications and + may provide additional or different license terms and conditions + for use, reproduction, or distribution of Your modifications, or + for any such Derivative Works as a whole, provided Your use, + reproduction, and distribution of the Work otherwise complies with + the conditions stated in this License. + + 5. Submission of Contributions. Unless You explicitly state otherwise, + any Contribution intentionally submitted for inclusion in the Work + by You to the Licensor shall be under the terms and conditions of + this License, without any additional terms or conditions. + Notwithstanding the above, nothing herein shall supersede or modify + the terms of any separate license agreement you may have executed + with Licensor regarding such Contributions. + + 6. Trademarks. This License does not grant permission to use the trade + names, trademarks, service marks, or product names of the Licensor, + except as required for reasonable and customary use in describing the + origin of the Work and reproducing the content of the NOTICE file. + + 7. Disclaimer of Warranty. Unless required by applicable law or + agreed to in writing, Licensor provides the Work (and each + Contributor provides its Contributions) on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or + implied, including, without limitation, any warranties or conditions + of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A + PARTICULAR PURPOSE. You are solely responsible for determining the + appropriateness of using or redistributing the Work and assume any + risks associated with Your exercise of permissions under this License. + + 8. Limitation of Liability. 
In no event and under no legal theory, + whether in tort (including negligence), contract, or otherwise, + unless required by applicable law (such as deliberate and grossly + negligent acts) or agreed to in writing, shall any Contributor be + liable to You for damages, including any direct, indirect, special, + incidental, or consequential damages of any character arising as a + result of this License or out of the use or inability to use the + Work (including but not limited to damages for loss of goodwill, + work stoppage, computer failure or malfunction, or any and all + other commercial damages or losses), even if such Contributor + has been advised of the possibility of such damages. + + 9. Accepting Warranty or Additional Liability. While redistributing + the Work or Derivative Works thereof, You may choose to offer, + and charge a fee for, acceptance of support, warranty, indemnity, + or other liability obligations and/or rights consistent with this + License. However, in accepting such obligations, You may act only + on Your own behalf and on Your sole responsibility, not on behalf + of any other Contributor, and only if You agree to indemnify, + defend, and hold each Contributor harmless for any liability + incurred by, or claims asserted against, such Contributor by reason + of your accepting any such warranty or additional liability. + + END OF TERMS AND CONDITIONS + + APPENDIX: How to apply the Apache License to your work. + + To apply the Apache License to your work, attach the following + boilerplate notice, with the fields enclosed by brackets "[]" + replaced with your own identifying information. (Don't include + the brackets!) The text should be enclosed in the appropriate + comment syntax for the file format. We also recommend that a + file or class name and description of purpose be included on the + same "printed page" as the copyright notice for easier + identification within third-party archives. 
+ + Copyright [yyyy] [name of copyright owner] + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. diff --git a/custom_nodes/ComfyUI-FlashVSR_Ultra_Fast/src/__init__.py b/custom_nodes/ComfyUI-FlashVSR_Ultra_Fast/src/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..d2555fe40d2010d1227816853e72a8c78762fa0c --- /dev/null +++ b/custom_nodes/ComfyUI-FlashVSR_Ultra_Fast/src/__init__.py @@ -0,0 +1,3 @@ +from .models import * +from .pipelines import * +from .schedulers import * \ No newline at end of file diff --git a/custom_nodes/ComfyUI-FlashVSR_Ultra_Fast/src/configs/__init__.py b/custom_nodes/ComfyUI-FlashVSR_Ultra_Fast/src/configs/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/custom_nodes/ComfyUI-FlashVSR_Ultra_Fast/src/configs/model_config.py b/custom_nodes/ComfyUI-FlashVSR_Ultra_Fast/src/configs/model_config.py new file mode 100644 index 0000000000000000000000000000000000000000..e89ba3e4a472c50b879010eb32594418a4bff777 --- /dev/null +++ b/custom_nodes/ComfyUI-FlashVSR_Ultra_Fast/src/configs/model_config.py @@ -0,0 +1,29 @@ +from typing_extensions import Literal, TypeAlias + +from ..models.wan_video_dit import WanModel +from ..models.wan_video_vae import WanVideoVAE + + +model_loader_configs = [ + # These configs are provided for detecting model type automatically. 
+ # The format is (state_dict_keys_hash, state_dict_keys_hash_with_shape, model_names, model_classes, model_resource) + (None, "9269f8db9040a9d860eaca435be61814", ["wan_video_dit"], [WanModel], "civitai"), + (None, "aafcfd9672c3a2456dc46e1cb6e52c70", ["wan_video_dit"], [WanModel], "civitai"), + (None, "6bfcfb3b342cb286ce886889d519a77e", ["wan_video_dit"], [WanModel], "civitai"), + (None, "6d6ccde6845b95ad9114ab993d917893", ["wan_video_dit"], [WanModel], "civitai"), + (None, "6bfcfb3b342cb286ce886889d519a77e", ["wan_video_dit"], [WanModel], "civitai"), + (None, "349723183fc063b2bfc10bb2835cf677", ["wan_video_dit"], [WanModel], "civitai"), + (None, "efa44cddf936c70abd0ea28b6cbe946c", ["wan_video_dit"], [WanModel], "civitai"), + (None, "3ef3b1f8e1dab83d5b71fd7b617f859f", ["wan_video_dit"], [WanModel], "civitai"), + (None, "cb104773c6c2cb6df4f9529ad5c60d0b", ["wan_video_dit"], [WanModel], "diffusers"), + (None, "1378ea763357eea97acdef78e65d6d96", ["wan_video_vae"], [WanVideoVAE], "civitai"), + (None, "ccc42284ea13e1ad04693284c7a09be6", ["wan_video_vae"], [WanVideoVAE], "civitai"), +] +huggingface_model_loader_configs = [ + # These configs are provided for detecting model type automatically. + # The format is (architecture_in_huggingface_config, huggingface_lib, model_name, redirected_architecture) +] +patch_model_loader_configs = [ + # These configs are provided for detecting model type automatically. 
class IdentityConv2d(nn.Conv2d):
    """Shape-preserving Conv2d whose weights start as a Dirac (identity) kernel.

    ``C`` is used for both input and output channels; padding is
    ``kernel_size // 2``. When a bias is requested it is zeroed.
    """

    def __init__(self, C, kernel_size=3, bias=False):
        super().__init__(
            C,
            C,
            kernel_size,
            padding=kernel_size // 2,
            bias=bias,
        )
        with torch.no_grad():
            # Replace the default random initialisation with an identity mapping.
            init.dirac_(self.weight)
            bias_param = self.bias
            if bias_param is not None:
                bias_param.zero_()
def apply_model_with_memblocks(model, x, parallel, show_progress_bar, mem=None):
    """
    Apply a sequential model with memblocks to the given input.

    Args:
    - model: nn.Sequential of blocks to apply
    - x: input data, of dimensions NTCHW
    - parallel: if True, parallelize over timesteps (fast but uses O(T) memory)
        if False, each timestep will be processed sequentially (slow but uses O(1) memory)
    - show_progress_bar: if True, enables tqdm progressbar display
    - mem: per-layer state list (one slot per entry of ``model``) carried
        between sequential calls. ``None`` starts from a fresh, empty state.
        Only the sequential path reads or updates it.

    Returns a tuple ``(output, mem)``: the NTCHW output tensor and the state
    list (updated in place by the sequential path, untouched by the parallel
    path).
    """
    assert x.ndim == 5, f"TAEHV operates on NTCHW tensors, but got {x.ndim}-dim tensor"
    N, T, C, H, W = x.shape
    if mem is None:
        # The sequential path indexes mem[i]; without this the documented
        # default crashed with "'NoneType' object is not subscriptable".
        mem = [None] * len(model)
    if parallel:
        x = x.reshape(N*T, C, H, W)
        for b in tqdm(model, disable=not show_progress_bar):
            if isinstance(b, MemBlock):
                NT, C, H, W = x.shape
                T = NT // N
                _x = x.reshape(N, T, C, H, W)
                # Previous-frame input for every timestep: shift by one along T
                # and zero-fill the first frame. Kept in a local name so the
                # caller's state list is not replaced by an activation tensor.
                past = F.pad(_x, (0,0,0,0,0,0,1,0), value=0)[:,:T].reshape(x.shape)
                x = b(x, past)
            else:
                x = b(x)
        NT, C, H, W = x.shape
        T = NT // N
        x = x.view(N, T, C, H, W)
    else:
        out = []
        # One work item per input timestep, each starting at layer 0.
        work_queue = [TWorkItem(xt, 0) for xt in x.reshape(N, T * C, H, W).chunk(T, dim=1)]
        progress_bar = tqdm(range(T), disable=not show_progress_bar)
        while work_queue:
            xt, i = work_queue.pop(0)
            if i == 0:
                # A fresh input frame is entering the graph.
                progress_bar.update(1)
            if i == len(model):
                # Ran past the last layer: this is a finished output frame.
                out.append(xt)
            else:
                b = model[i]
                if isinstance(b, MemBlock):
                    if mem[i] is None:
                        # First frame seen by this block: use zeros as "past".
                        xt_new = b(xt, xt * 0)
                        mem[i] = xt
                    else:
                        xt_new = b(xt, mem[i])
                        mem[i].copy_(xt)
                    work_queue.insert(0, TWorkItem(xt_new, i+1))
                elif isinstance(b, TPool):
                    # Buffer frames until `stride` of them are available.
                    if mem[i] is None:
                        mem[i] = []
                    mem[i].append(xt)
                    if len(mem[i]) > b.stride:
                        raise ValueError("TPool internal state invalid.")
                    elif len(mem[i]) == b.stride:
                        N_, C_, H_, W_ = xt.shape
                        xt = b(torch.cat(mem[i], 1).view(N_*b.stride, C_, H_, W_))
                        mem[i] = []
                        work_queue.insert(0, TWorkItem(xt, i+1))
                elif isinstance(b, TGrow):
                    xt = b(xt)
                    NT, C_, H_, W_ = xt.shape
                    # Push the grown frames in reverse so they are popped in
                    # temporal order from the front of the queue.
                    for xt_next in reversed(xt.view(N, b.stride*C_, H_, W_).chunk(b.stride, 1)):
                        work_queue.insert(0, TWorkItem(xt_next, i+1))
                else:
                    xt = b(xt)
                    work_queue.insert(0, TWorkItem(xt, i+1))
        progress_bar.close()
        x = torch.stack(out, 1)
    return x, mem
= 16 + ): + """Initialize TAEHV (decoder-only) with built-in deepening after every ReLU. + Deepening config: how_many_each=1, k=3 (fixed as requested). + """ + super().__init__() + self.latent_channels = latent_channels + n_f = channels + self.frames_to_trim = 2**sum(decoder_time_upscale) - 1 + + # Build the decoder "skeleton" + base_decoder = nn.Sequential( + Clamp(), conv(self.latent_channels, n_f[0]), nn.ReLU(inplace=True), + + MemBlock(n_f[0], n_f[0]), MemBlock(n_f[0], n_f[0]), MemBlock(n_f[0], n_f[0]), + nn.Upsample(scale_factor=2 if decoder_space_upscale[0] else 1), + TGrow(n_f[0], 1), + conv(n_f[0], n_f[1], bias=False), + + MemBlock(n_f[1], n_f[1]), MemBlock(n_f[1], n_f[1]), MemBlock(n_f[1], n_f[1]), + nn.Upsample(scale_factor=2 if decoder_space_upscale[1] else 1), + TGrow(n_f[1], 2 if decoder_time_upscale[0] else 1), + conv(n_f[1], n_f[2], bias=False), + + MemBlock(n_f[2], n_f[2]), MemBlock(n_f[2], n_f[2]), MemBlock(n_f[2], n_f[2]), + nn.Upsample(scale_factor=2 if decoder_space_upscale[2] else 1), + TGrow(n_f[2], 2 if decoder_time_upscale[1] else 1), + conv(n_f[2], n_f[3], bias=False), + + nn.ReLU(inplace=True), conv(n_f[3], TAEHV.image_channels), + ) + + # Inline deepening: insert (IdentityConv2d(k=3) + ReLU) after every ReLU + self.decoder = self._apply_identity_deepen(base_decoder, how_many_each=1, k=3) + + self.pixel_shuffle = PixelShuffle3d(4, 8, 8) + + if checkpoint_path is not None: + missing_keys = self.load_state_dict( + self.patch_tgrow_layers(torch.load(checkpoint_path, map_location="cpu", weights_only=True)), + strict=False + ) + print('missing_keys', missing_keys) + + # Initialize decoder mem state + self.mem = [None] * len(self.decoder) + + @staticmethod + def _apply_identity_deepen(decoder: nn.Sequential, how_many_each=1, k=3) -> nn.Sequential: + """Return a new Sequential where every nn.ReLU is followed by how_many_each*(IdentityConv2d(k)+ReLU).""" + new_layers = [] + for b in decoder: + new_layers.append(b) + if isinstance(b, nn.ReLU): + # 
Deduce channel count from preceding layer + C = None + if len(new_layers) >= 2 and isinstance(new_layers[-2], nn.Conv2d): + C = new_layers[-2].out_channels + elif len(new_layers) >= 2 and isinstance(new_layers[-2], MemBlock): + C = new_layers[-2].conv[-1].out_channels + if C is not None: + for _ in range(how_many_each): + new_layers.append(IdentityConv2d(C, kernel_size=k, bias=False)) + new_layers.append(nn.ReLU(inplace=True)) + return nn.Sequential(*new_layers) + + def patch_tgrow_layers(self, sd): + """Patch TGrow layers to use a smaller kernel if needed (decoder-only).""" + new_sd = self.state_dict() + for i, layer in enumerate(self.decoder): + if isinstance(layer, TGrow): + key = f"decoder.{i}.conv.weight" + if key in sd and sd[key].shape[0] > new_sd[key].shape[0]: + sd[key] = sd[key][-new_sd[key].shape[0]:] + return sd + + def decode_video(self, x, parallel=True, show_progress_bar=False, cond=None): + """Decode a sequence of frames from latents. + x: NTCHW latent tensor; returns NTCHW RGB in ~[0, 1]. + """ + trim_flag = self.mem[-8] is None # keeps original relative check + + if cond is not None: + x = torch.cat([self.pixel_shuffle(cond), x], dim=2) + + x, self.mem = apply_model_with_memblocks(self.decoder, x, parallel, show_progress_bar, mem=self.mem) + + if trim_flag: + return x[:, self.frames_to_trim:] + return x + + def forward(self, *args, **kwargs): + raise NotImplementedError("Decoder-only model: call decode_video(...) 
class DotDict(dict):
    """Dictionary whose items can also be read and written as attributes.

    Reading a missing attribute raises ``AttributeError`` (not ``KeyError``),
    so ``hasattr(d, name)`` and ``getattr(d, name, default)`` work as they do
    for ordinary objects.
    """

    def __getattr__(self, name):
        # Only called when normal attribute lookup fails; fall back to items.
        try:
            return self[name]
        except KeyError:
            raise AttributeError(
                f"{type(self).__name__!r} object has no attribute {name!r}"
            ) from None

    __setattr__ = dict.__setitem__
def load_model_from_single_file(state_dict, model_names, model_classes, model_resource, torch_dtype, device):
    """Instantiate each model class from one shared state dict.

    For every ``(model_name, model_class)`` pair the class's state-dict
    converter translates ``state_dict`` (via ``from_civitai`` or
    ``from_diffusers`` depending on ``model_resource``), the model is built
    under ``init_weights_on_device()``, the converted weights are assigned,
    and the model is moved to ``device`` with the requested dtype.

    A converter may return ``(state_dict, extra_kwargs)``; ``extra_kwargs``
    are passed to the model constructor, and ``upcast_to_float32=True`` forces
    float32 for that model only.

    Returns:
        ``(loaded_model_names, loaded_models)`` as two parallel lists.

    Raises:
        ValueError: if ``model_resource`` is neither ``"civitai"`` nor
            ``"diffusers"`` (and there is at least one model to load).
    """
    loaded_model_names, loaded_models = [], []
    for model_name, model_class in zip(model_names, model_classes):
        state_dict_converter = model_class.state_dict_converter()
        if model_resource == "civitai":
            state_dict_results = state_dict_converter.from_civitai(state_dict)
        elif model_resource == "diffusers":
            state_dict_results = state_dict_converter.from_diffusers(state_dict)
        else:
            # Previously fell through with `state_dict_results` unbound (or
            # stale from the prior iteration).
            raise ValueError(f"Unsupported model_resource: {model_resource!r}")
        if isinstance(state_dict_results, tuple):
            model_state_dict, extra_kwargs = state_dict_results
        else:
            model_state_dict, extra_kwargs = state_dict_results, {}
        # Per-model dtype: do not overwrite `torch_dtype`, otherwise one
        # upcast request would leak to every model loaded after it.
        model_dtype = torch.float32 if extra_kwargs.get("upcast_to_float32", False) else torch_dtype
        with init_weights_on_device():
            model = model_class(**extra_kwargs)
        if hasattr(model, "eval"):
            model = model.eval()
        model.load_state_dict(model_state_dict, assign=True)
        model = model.to(dtype=model_dtype, device=device)
        loaded_model_names.append(model_name)
        loaded_models.append(model)
    return loaded_model_names, loaded_models
class ModelDetectorFromSingleFile:
    """Recognise and load models stored as a single state-dict file.

    Each registered config maps a hash of the state dict's keys to
    ``(model_names, model_classes, model_resource)``. The hash computed with
    ``with_shape=True`` is tried first; the ``with_shape=False`` hash is the
    fallback and is only registered when a config supplies it.
    """

    def __init__(self, model_loader_configs=[]):
        # The default list is only iterated, never mutated.
        self.keys_hash_with_shape_dict = {}
        self.keys_hash_dict = {}
        for metadata in model_loader_configs:
            self.add_model_metadata(*metadata)


    def add_model_metadata(self, keys_hash, keys_hash_with_shape, model_names, model_classes, model_resource):
        """Register one config; ``keys_hash`` may be ``None`` to skip the fallback table."""
        entry = (model_names, model_classes, model_resource)
        self.keys_hash_with_shape_dict[keys_hash_with_shape] = entry
        if keys_hash is not None:
            self.keys_hash_dict[keys_hash] = entry


    def _find_config(self, state_dict):
        """Return the registered config tuple for ``state_dict``, or ``None``."""
        keys_hash_with_shape = hash_state_dict_keys(state_dict, with_shape=True)
        if keys_hash_with_shape in self.keys_hash_with_shape_dict:
            return self.keys_hash_with_shape_dict[keys_hash_with_shape]
        # Fallback: the shape of parameters may be inconsistent, and the
        # state_dict_converter will modify the model architecture.
        keys_hash = hash_state_dict_keys(state_dict, with_shape=False)
        return self.keys_hash_dict.get(keys_hash)


    def match(self, file_path="", state_dict={}):
        """Return True when the file (or the given ``state_dict``) is a registered model.

        Directories never match. When ``state_dict`` is empty it is read from
        ``file_path`` with ``load_state_dict``.
        """
        if isinstance(file_path, str) and os.path.isdir(file_path):
            return False
        if len(state_dict) == 0:
            state_dict = load_state_dict(file_path)
        return self._find_config(state_dict) is not None


    def load(self, file_path="", state_dict={}, device="cuda", torch_dtype=torch.float16, **kwargs):
        """Load the models registered for this state dict.

        Returns ``(loaded_model_names, loaded_models)``; both lists are empty
        when no registered hash matches (this path used to raise
        ``UnboundLocalError``).
        """
        if len(state_dict) == 0:
            state_dict = load_state_dict(file_path)

        config = self._find_config(state_dict)
        if config is None:
            return [], []

        model_names, model_classes, model_resource = config
        return load_model_from_single_file(state_dict, model_names, model_classes, model_resource, torch_dtype, device)
splited_state_dict: + if super().match(file_path, sub_state_dict): + loaded_model_names_, loaded_models_ = super().load(file_path, valid_state_dict, device, torch_dtype) + loaded_model_names += loaded_model_names_ + loaded_models += loaded_models_ + return loaded_model_names, loaded_models + + + +class ModelDetectorFromHuggingfaceFolder: + def __init__(self, model_loader_configs=[]): + self.architecture_dict = {} + for metadata in model_loader_configs: + self.add_model_metadata(*metadata) + + + def add_model_metadata(self, architecture, huggingface_lib, model_name, redirected_architecture): + self.architecture_dict[architecture] = (huggingface_lib, model_name, redirected_architecture) + + + def match(self, file_path="", state_dict={}): + if not isinstance(file_path, str) or os.path.isfile(file_path): + return False + file_list = os.listdir(file_path) + if "config.json" not in file_list: + return False + with open(os.path.join(file_path, "config.json"), "r") as f: + config = json.load(f) + if "architectures" not in config and "_class_name" not in config: + return False + return True + + + def load(self, file_path="", state_dict={}, device="cuda", torch_dtype=torch.float16, **kwargs): + with open(os.path.join(file_path, "config.json"), "r") as f: + config = json.load(f) + loaded_model_names, loaded_models = [], [] + architectures = config["architectures"] if "architectures" in config else [config["_class_name"]] + for architecture in architectures: + huggingface_lib, model_name, redirected_architecture = self.architecture_dict[architecture] + if redirected_architecture is not None: + architecture = redirected_architecture + model_class = importlib.import_module(huggingface_lib).__getattribute__(architecture) + loaded_model_names_, loaded_models_ = load_model_from_huggingface_folder(file_path, [model_name], [model_class], torch_dtype, device) + loaded_model_names += loaded_model_names_ + loaded_models += loaded_models_ + return loaded_model_names, loaded_models + + + 
class ModelDetectorFromPatchedSingleFile:
    """Detect patch checkpoints (weights layered on an already loaded model)."""

    def __init__(self, model_loader_configs=None):
        self.keys_hash_with_shape_dict = {}
        for metadata in model_loader_configs or ():
            self.add_model_metadata(*metadata)


    def add_model_metadata(self, keys_hash_with_shape, model_name, model_class, extra_kwargs):
        """Register the patch target(s) for a key+shape hash."""
        self.keys_hash_with_shape_dict[keys_hash_with_shape] = (model_name, model_class, extra_kwargs)


    def match(self, file_path="", state_dict=None):
        """Return True when the checkpoint's key+shape hash is a registered patch.

        Only string paths to non-folders are considered.
        """
        if not isinstance(file_path, str) or os.path.isdir(file_path):
            return False
        if not state_dict:
            state_dict = load_state_dict(file_path)
        keys_hash_with_shape = hash_state_dict_keys(state_dict, with_shape=True)
        return keys_hash_with_shape in self.keys_hash_with_shape_dict


    def load(self, file_path="", state_dict=None, device="cuda", torch_dtype=torch.float16, model_manager=None, **kwargs):
        """Patch the matching base models held by ``model_manager``.

        Returns:
            ``(loaded_model_names, loaded_models)``; empty lists when the
            checkpoint is not a registered patch.
        """
        if not state_dict:
            state_dict = load_state_dict(file_path)

        loaded_model_names, loaded_models = [], []
        keys_hash_with_shape = hash_state_dict_keys(state_dict, with_shape=True)
        if keys_hash_with_shape in self.keys_hash_with_shape_dict:
            model_names, model_classes, extra_kwargs = self.keys_hash_with_shape_dict[keys_hash_with_shape]
            loaded_model_names_, loaded_models_ = load_patch_model_from_single_file(
                state_dict, model_names, model_classes, extra_kwargs, model_manager, torch_dtype, device)
            loaded_model_names += loaded_model_names_
            loaded_models += loaded_models_
        return loaded_model_names, loaded_models



class ModelManager:
    """Registry of loaded models kept as three parallel lists.

    ``model[i]``, ``model_path[i]`` and ``model_name[i]`` always describe the
    same entry.
    """

    def __init__(
        self,
        torch_dtype=torch.float16,
        device="cuda",
        file_path_list: List[str] = None,
    ):
        self.torch_dtype = torch_dtype
        self.device = device
        self.model = []
        self.model_path = []
        self.model_name = []
        # Detectors are tried in this order by load_model(); the first match wins.
        self.model_detector = [
            ModelDetectorFromSingleFile(model_loader_configs),
            ModelDetectorFromSplitedSingleFile(model_loader_configs),
            ModelDetectorFromHuggingfaceFolder(huggingface_model_loader_configs),
            ModelDetectorFromPatchedSingleFile(patch_model_loader_configs),
        ]
        self.load_models(file_path_list or [])


    def _register(self, file_path, model_names, models):
        """Append ``(model, file_path, name)`` triples to the parallel lists."""
        for model_name, model in zip(model_names, models):
            self.model.append(model)
            self.model_path.append(file_path)
            self.model_name.append(model_name)


    def load_model_from_single_file(self, file_path="", state_dict=None, model_names=None, model_classes=None, model_resource=None):
        """Load explicitly named models from one checkpoint file and register them."""
        print(f"Loading models from file: {file_path}")
        if not state_dict:
            state_dict = load_state_dict(file_path)
        model_names, models = load_model_from_single_file(
            state_dict, model_names or [], model_classes or [], model_resource, self.torch_dtype, self.device)
        self._register(file_path, model_names, models)


    def load_model_from_huggingface_folder(self, file_path="", model_names=None, model_classes=None):
        """Load explicitly named models from a ``from_pretrained`` folder and register them."""
        print(f"Loading models from folder: {file_path}")
        model_names, models = load_model_from_huggingface_folder(
            file_path, model_names or [], model_classes or [], self.torch_dtype, self.device)
        self._register(file_path, model_names, models)


    def load_patch_model_from_single_file(self, file_path="", state_dict=None, model_names=None, model_classes=None, extra_kwargs=None):
        """Patch already registered models with a patch checkpoint and re-register them."""
        print(f"Loading patch models from file: {file_path}")
        if not state_dict:
            # Consistent with the other loaders: read the file when no dict is given.
            state_dict = load_state_dict(file_path)
        model_names, models = load_patch_model_from_single_file(
            state_dict, model_names or [], model_classes or [], extra_kwargs or {}, self, self.torch_dtype, self.device)
        self._register(file_path, model_names, models)
        print(f" The following patched models are loaded: {model_names}.")


    def load_lora(self, file_path="", state_dict=None, lora_alpha=1.0):
        """Apply a LoRA checkpoint to every registered model that a LoRA loader matches.

        ``file_path`` may be a list, in which case each entry is applied in turn.
        """
        if isinstance(file_path, list):
            for file_path_ in file_path:
                self.load_lora(file_path_, state_dict=state_dict, lora_alpha=lora_alpha)
            return
        print(f"Loading LoRA models from file: {file_path}")
        is_loaded = False
        if not state_dict:
            state_dict = load_state_dict(file_path)
        for model_name, model, model_path in zip(self.model_name, self.model, self.model_path):
            for lora in get_lora_loaders():
                match_results = lora.match(model, state_dict)
                if match_results is not None:
                    print(f" Adding LoRA to {model_name} ({model_path}).")
                    lora_prefix, model_resource = match_results
                    lora.load(model, state_dict, lora_prefix, alpha=lora_alpha, model_resource=model_resource)
                    is_loaded = True
                    # Only the first matching loader is applied per model.
                    break
        if not is_loaded:
            print(f" Cannot load LoRA: {file_path}")


    def load_model(self, file_path, model_names=None, device=None, torch_dtype=None):
        """Detect the type of ``file_path`` and register the models it contains.

        Args:
            file_path: A checkpoint file, a model folder, or a list of
                checkpoint files whose state dicts are merged.
            model_names: Forwarded to the detector as ``allowed_model_names``.
            device: Overrides ``self.device`` when given.
            torch_dtype: Overrides ``self.torch_dtype`` when given.

        Raises:
            FileNotFoundError: if a single path is neither a file nor a folder.
        """
        if device is None:
            device = self.device
        if torch_dtype is None:
            torch_dtype = self.torch_dtype
        if isinstance(file_path, list):
            state_dict = {}
            for path in file_path:
                state_dict.update(load_state_dict(path))
        elif os.path.isfile(file_path):
            state_dict = load_state_dict(file_path)
        elif os.path.isdir(file_path):
            # Folder detectors read the folder themselves; an empty dict keeps
            # len()-based checks in the file detectors valid.
            state_dict = {}
        else:
            raise FileNotFoundError(f"Model path does not exist: {file_path}")
        for model_detector in self.model_detector:
            if model_detector.match(file_path, state_dict):
                loaded_names, loaded_models = model_detector.load(
                    file_path, state_dict,
                    device=device, torch_dtype=torch_dtype,
                    allowed_model_names=model_names, model_manager=self
                )
                self._register(file_path, loaded_names, loaded_models)
                break
        else:
            print(f" We cannot detect the model type. No models are loaded.")


    def load_models(self, file_path_list, model_names=None, device=None, torch_dtype=None):
        """Call ``load_model`` for every entry of ``file_path_list``."""
        for file_path in file_path_list:
            self.load_model(file_path, model_names, device=device, torch_dtype=torch_dtype)


    def fetch_model(self, model_name, file_path=None, require_model_path=False):
        """Return the first registered model called ``model_name``.

        Args:
            model_name: Name to look up.
            file_path: When given, only models loaded from this path qualify.
            require_model_path: When True, return ``(model, path)``.

        Returns:
            The model (or ``(model, path)``), or None when nothing matches.
        """
        fetched_models = []
        fetched_model_paths = []
        for model, model_path, model_name_ in zip(self.model, self.model_path, self.model_name):
            if file_path is not None and file_path != model_path:
                continue
            if model_name == model_name_:
                fetched_models.append(model)
                fetched_model_paths.append(model_path)
        if not fetched_models:
            return None
        if len(fetched_models) == 1:
            print(f"Using {model_name} from {fetched_model_paths[0]}")
        else:
            print(f"More than one {model_name} models are loaded in model manager: {fetched_model_paths}. Using {model_name} from {fetched_model_paths[0]}")
        if require_model_path:
            return fetched_models[0], fetched_model_paths[0]
        return fetched_models[0]


    def to(self, device):
        """Move every registered model to ``device``."""
        for model in self.model:
            model.to(device)
+ + "Legal Entity" shall mean the union of the acting entity and all + other entities that control, are controlled by, or are under common + control with that entity. For the purposes of this definition, + "control" means (i) the power, direct or indirect, to cause the + direction or management of such entity, whether by contract or + otherwise, or (ii) ownership of fifty percent (50%) or more of the + outstanding shares, or (iii) beneficial ownership of such entity. + + "You" (or "Your") shall mean an individual or Legal Entity + exercising permissions granted by this License. + + "Source" form shall mean the preferred form for making modifications, + including but not limited to software source code, documentation + source, and configuration files. + + "Object" form shall mean any form resulting from mechanical + transformation or translation of a Source form, including but + not limited to compiled object code, generated documentation, + and conversions to other media types. + + "Work" shall mean the work of authorship, whether in Source or + Object form, made available under the License, as indicated by a + copyright notice that is included in or attached to the work + (an example is provided in the Appendix below). + + "Derivative Works" shall mean any work, whether in Source or Object + form, that is based on (or derived from) the Work and for which the + editorial revisions, annotations, elaborations, or other modifications + represent, as a whole, an original work of authorship. For the purposes + of this License, Derivative Works shall not include works that remain + separable from, or merely link (or bind by name) to the interfaces of, + the Work and Derivative Works thereof. 
+ + "Contribution" shall mean any work of authorship, including + the original version of the Work and any modifications or additions + to that Work or Derivative Works thereof, that is intentionally + submitted to Licensor for inclusion in the Work by the copyright owner + or by an individual or Legal Entity authorized to submit on behalf of + the copyright owner. For the purposes of this definition, "submitted" + means any form of electronic, verbal, or written communication sent + to the Licensor or its representatives, including but not limited to + communication on electronic mailing lists, source code control systems, + and issue tracking systems that are managed by, or on behalf of, the + Licensor for the purpose of discussing and improving the Work, but + excluding communication that is conspicuously marked or otherwise + designated in writing by the copyright owner as "Not a Contribution." + + "Contributor" shall mean Licensor and any individual or Legal Entity + on behalf of whom a Contribution has been received by Licensor and + subsequently incorporated within the Work. + + 2. Grant of Copyright License. Subject to the terms and conditions of + this License, each Contributor hereby grants to You a perpetual, + worldwide, non-exclusive, no-charge, royalty-free, irrevocable + copyright license to reproduce, prepare Derivative Works of, + publicly display, publicly perform, sublicense, and distribute the + Work and such Derivative Works in Source or Object form. + + 3. Grant of Patent License. 
Subject to the terms and conditions of + this License, each Contributor hereby grants to You a perpetual, + worldwide, non-exclusive, no-charge, royalty-free, irrevocable + (except as stated in this section) patent license to make, have made, + use, offer to sell, sell, import, and otherwise transfer the Work, + where such license applies only to those patent claims licensable + by such Contributor that are necessarily infringed by their + Contribution(s) alone or by combination of their Contribution(s) + with the Work to which such Contribution(s) was submitted. If You + institute patent litigation against any entity (including a + cross-claim or counterclaim in a lawsuit) alleging that the Work + or a Contribution incorporated within the Work constitutes direct + or contributory patent infringement, then any patent licenses + granted to You under this License for that Work shall terminate + as of the date such litigation is filed. + + 4. Redistribution. You may reproduce and distribute copies of the + Work or Derivative Works thereof in any medium, with or without + modifications, and in Source or Object form, provided that You + meet the following conditions: + + (a) You must give any other recipients of the Work or + Derivative Works a copy of this License; and + + (b) You must cause any modified files to carry prominent notices + stating that You changed the files; and + + (c) You must retain, in the Source form of any Derivative Works + that You distribute, all copyright, patent, trademark, and + attribution notices from the Source form of the Work, + excluding those notices that do not pertain to any part of + the Derivative Works; and + + (d) If the Work includes a "NOTICE" text file as part of its + distribution, then any Derivative Works that You distribute must + include a readable copy of the attribution notices contained + within such NOTICE file, excluding those notices that do not + pertain to any part of the Derivative Works, in at least one + of 
the following places: within a NOTICE text file distributed + as part of the Derivative Works; within the Source form or + documentation, if provided along with the Derivative Works; or, + within a display generated by the Derivative Works, if and + wherever such third-party notices normally appear. The contents + of the NOTICE file are for informational purposes only and + do not modify the License. You may add Your own attribution + notices within Derivative Works that You distribute, alongside + or as an addendum to the NOTICE text from the Work, provided + that such additional attribution notices cannot be construed + as modifying the License. + + You may add Your own copyright statement to Your modifications and + may provide additional or different license terms and conditions + for use, reproduction, or distribution of Your modifications, or + for any such Derivative Works as a whole, provided Your use, + reproduction, and distribution of the Work otherwise complies with + the conditions stated in this License. + + 5. Submission of Contributions. Unless You explicitly state otherwise, + any Contribution intentionally submitted for inclusion in the Work + by You to the Licensor shall be under the terms and conditions of + this License, without any additional terms or conditions. + Notwithstanding the above, nothing herein shall supersede or modify + the terms of any separate license agreement you may have executed + with Licensor regarding such Contributions. + + 6. Trademarks. This License does not grant permission to use the trade + names, trademarks, service marks, or product names of the Licensor, + except as required for reasonable and customary use in describing the + origin of the Work and reproducing the content of the NOTICE file. + + 7. Disclaimer of Warranty. 
Unless required by applicable law or + agreed to in writing, Licensor provides the Work (and each + Contributor provides its Contributions) on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or + implied, including, without limitation, any warranties or conditions + of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A + PARTICULAR PURPOSE. You are solely responsible for determining the + appropriateness of using or redistributing the Work and assume any + risks associated with Your exercise of permissions under this License. + + 8. Limitation of Liability. In no event and under no legal theory, + whether in tort (including negligence), contract, or otherwise, + unless required by applicable law (such as deliberate and grossly + negligent acts) or agreed to in writing, shall any Contributor be + liable to You for damages, including any direct, indirect, special, + incidental, or consequential damages of any character arising as a + result of this License or out of the use or inability to use the + Work (including but not limited to damages for loss of goodwill, + work stoppage, computer failure or malfunction, or any and all + other commercial damages or losses), even if such Contributor + has been advised of the possibility of such damages. + + 9. Accepting Warranty or Additional Liability. While redistributing + the Work or Derivative Works thereof, You may choose to offer, + and charge a fee for, acceptance of support, warranty, indemnity, + or other liability obligations and/or rights consistent with this + License. However, in accepting such obligations, You may act only + on Your own behalf and on Your sole responsibility, not on behalf + of any other Contributor, and only if You agree to indemnify, + defend, and hold each Contributor harmless for any liability + incurred by, or claims asserted against, such Contributor by reason + of your accepting any such warranty or additional liability. 
+ + END OF TERMS AND CONDITIONS + + APPENDIX: How to apply the Apache License to your work. + + To apply the Apache License to your work, attach the following + boilerplate notice, with the fields enclosed by brackets "[]" + replaced with your own identifying information. (Don't include + the brackets!) The text should be enclosed in the appropriate + comment syntax for the file format. We also recommend that a + file or class name and description of purpose be included on the + same "printed page" as the copyright notice for easier + identification within third-party archives. + + Copyright 2025 Jintao Zhang + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. diff --git a/custom_nodes/ComfyUI-FlashVSR_Ultra_Fast/src/models/sparse_sage/core.py b/custom_nodes/ComfyUI-FlashVSR_Ultra_Fast/src/models/sparse_sage/core.py new file mode 100644 index 0000000000000000000000000000000000000000..62ba86395fe5a111d5f744e22da8e08dd0149e9f --- /dev/null +++ b/custom_nodes/ComfyUI-FlashVSR_Ultra_Fast/src/models/sparse_sage/core.py @@ -0,0 +1,45 @@ +""" +https://github.com/jt-zhang/Sparse_SageAttention_API + +Copyright (c) 2024 by SageAttention team. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. 
def sparse_sageattn(q, k, v, mask_id = None, is_causal=False, tensor_layout="HND"):
    """Run block-sparse INT8-quantised attention.

    Args:
        q, k, v: Query/key/value tensors; "HND" puts the sequence on axis 2
            and heads on axis 1, "NHD" puts the sequence on axis 1 and heads
            on axis 2.
        mask_id: Optional int8 block mask; a non-zero entry keeps the
            corresponding (128-query x 64-key) block. Defaults to all ones
            (dense attention).
        is_causal: Forwarded to the attention kernel.
        tensor_layout: "HND" or "NHD".

    Returns:
        The attention output in ``q``'s dtype.
    """
    seq_dim = 1 if tensor_layout == "NHD" else 2
    head_dim_axis = 2 if tensor_layout == "NHD" else 1

    if mask_id is None:
        # One flag per (query block, key block). The last axis must count the
        # 64-wide key blocks: the kernel reads one flag per key block, so
        # sizing it from head_dim under-allocates for long sequences. Heads
        # follow k because the kernel indexes the mask by key/value head.
        q_blocks = (q.shape[seq_dim] + 128 - 1) // 128
        k_blocks = (k.shape[seq_dim] + 64 - 1) // 64
        mask_id = torch.ones(
            (q.shape[0], k.shape[head_dim_axis], q_blocks, k_blocks),
            dtype=torch.int8, device=q.device)

    output_dtype = q.dtype
    if output_dtype == torch.bfloat16 or output_dtype == torch.float32:
        # The kernel multiplies probabilities by v in float16.
        v = v.to(torch.float16)

    # Keys are centred on their per-head mean before quantisation.
    km = k.mean(dim=seq_dim, keepdim=True)

    q_int8, q_scale, k_int8, k_scale = per_block_int8(q, k, km=km, tensor_layout=tensor_layout)

    o = sparse_sageattn_fwd(
        q_int8, k_int8, mask_id, v, q_scale, k_scale,
        is_causal=is_causal, tensor_layout=tensor_layout, output_dtype=output_dtype
    )
    return o
# Quantise one BLK-row block of one (batch, head) slice to int8.
# Grid axes: 0 = block along the sequence, 1 = head, 2 = batch.
# Rows at or beyond L are masked on load and store.
@triton.jit
def quant_per_block_int8_kernel(Input, Output, Scale, L,
                                stride_iz, stride_ih, stride_in,
                                stride_oz, stride_oh, stride_on,
                                stride_sz, stride_sh,
                                sm_scale,
                                C: tl.constexpr, BLK: tl.constexpr):
    off_blk = tl.program_id(0)
    off_h = tl.program_id(1)
    off_b = tl.program_id(2)

    # Row indices of this block and column indices of the head dimension.
    offs_n = off_blk * BLK + tl.arange(0, BLK)
    offs_k = tl.arange(0, C)

    # The last axis is addressed without a stride, i.e. it is assumed to be
    # contiguous. NOTE(review): confirm callers never pass a transposed view.
    input_ptrs = Input + off_b * stride_iz + off_h * stride_ih + offs_n[:, None] * stride_in + offs_k[None, :]
    output_ptrs = Output + off_b * stride_oz + off_h * stride_oh + offs_n[:, None] * stride_on + offs_k[None, :]
    scale_ptrs = Scale + off_b * stride_sz + off_h * stride_sh + off_blk

    x = tl.load(input_ptrs, mask=offs_n[:, None] < L)
    x = x.to(tl.float32)
    x *= sm_scale
    # One scale per block: the largest magnitude maps to 127.
    # NOTE(review): an all-zero block gives scale == 0 and a 0/0 below.
    scale = tl.max(tl.abs(x)) / 127.
    x_int8 = x / scale
    # Add +/-0.5 before the int8 cast (presumably to round half away from
    # zero, assuming the cast truncates).
    x_int8 += 0.5 * tl.where(x_int8 >= 0, 1, -1)
    x_int8 = x_int8.to(tl.int8)
    tl.store(output_ptrs, x_int8, mask=offs_n[:, None] < L)
    tl.store(scale_ptrs, scale)

def per_block_int8(q, k, km=None, BLKQ=128, BLKK=64, sm_scale=None, tensor_layout="HND"):
    """Quantise ``q`` and ``k`` to int8 with one float32 scale per block.

    Args:
        q, k: 4-D tensors laid out according to ``tensor_layout``.
        km: Optional tensor subtracted from ``k`` before quantisation.
        BLKQ, BLKK: Rows per quantisation block for ``q`` and ``k``.
        sm_scale: Scale folded into ``q``; defaults to ``head_dim ** -0.5``.
        tensor_layout: "HND" (batch, heads, seq, dim) or "NHD"
            (batch, seq, heads, dim).

    Returns:
        ``(q_int8, q_scale, k_int8, k_scale)``; the scales have shape
        ``(batch, heads, number_of_blocks)``.

    Raises:
        ValueError: for an unknown ``tensor_layout``.
    """
    q_int8 = torch.empty(q.shape, dtype=torch.int8, device=q.device)
    k_int8 = torch.empty(k.shape, dtype=torch.int8, device=k.device)

    if km is not None:
        k = k - km

    # Pick (batch, head, sequence) strides so the kernel is layout-agnostic.
    if tensor_layout == "HND":
        b, h_qo, qo_len, head_dim = q.shape
        _, h_kv, kv_len, _ = k.shape

        stride_bz_q, stride_h_q, stride_seq_q = q.stride(0), q.stride(1), q.stride(2)
        stride_bz_qo, stride_h_qo, stride_seq_qo = q_int8.stride(0), q_int8.stride(1), q_int8.stride(2)
        stride_bz_k, stride_h_k, stride_seq_k = k.stride(0), k.stride(1), k.stride(2)
        stride_bz_ko, stride_h_ko, stride_seq_ko = k_int8.stride(0), k_int8.stride(1), k_int8.stride(2)
    elif tensor_layout == "NHD":
        b, qo_len, h_qo, head_dim = q.shape
        _, kv_len, h_kv, _ = k.shape

        stride_bz_q, stride_h_q, stride_seq_q = q.stride(0), q.stride(2), q.stride(1)
        stride_bz_qo, stride_h_qo, stride_seq_qo = q_int8.stride(0), q_int8.stride(2), q_int8.stride(1)
        stride_bz_k, stride_h_k, stride_seq_k = k.stride(0), k.stride(2), k.stride(1)
        stride_bz_ko, stride_h_ko, stride_seq_ko = k_int8.stride(0), k_int8.stride(2), k_int8.stride(1)
    else:
        raise ValueError(f"Unknown tensor layout: {tensor_layout}")

    # NOTE(review): k_scale is allocated on q.device; assumes q and k share a device.
    q_scale = torch.empty((b, h_qo, (qo_len + BLKQ - 1) // BLKQ), device=q.device, dtype=torch.float32)
    k_scale = torch.empty((b, h_kv, (kv_len + BLKK - 1) // BLKK), device=q.device, dtype=torch.float32)

    if sm_scale is None:
        sm_scale = head_dim**-0.5

    grid = ((qo_len + BLKQ - 1) // BLKQ, h_qo, b)
    quant_per_block_int8_kernel[grid](
        q, q_int8, q_scale, qo_len,
        stride_bz_q, stride_h_q, stride_seq_q,
        stride_bz_qo, stride_h_qo, stride_seq_qo,
        q_scale.stride(0), q_scale.stride(1),
        # 1.44269504 ~= log2(e); presumably lets the attention kernel use exp2.
        sm_scale=(sm_scale * 1.44269504),
        C=head_dim, BLK=BLKQ
    )

    grid = ((kv_len + BLKK - 1) // BLKK, h_kv, b)
    quant_per_block_int8_kernel[grid](
        k, k_int8, k_scale, kv_len,
        stride_bz_k, stride_h_k, stride_seq_k,
        stride_bz_ko, stride_h_ko, stride_seq_ko,
        k_scale.stride(0), k_scale.stride(1),
        # The softmax scale is folded into q only.
        sm_scale=1.0,
        C=head_dim, BLK=BLKK
    )

    return q_int8, q_scale, k_int8, k_scale
# Accumulate attention over a range of key blocks for one query block, using
# a running max (old_m) and running normaliser (l_i).
# STAGE selects the key range:
#   1 -> keys strictly before this query block,
#   2 -> the diagonal block, with an element-wise causal mask,
#   3 -> all keys.
@triton.jit
def _attn_fwd_inner(acc, l_i, old_m, q, q_scale, kv_len,
                    K_ptrs, K_bid_ptr, K_scale_ptr, V_ptrs, stride_kn, stride_vn, start_m,
                    BLOCK_M: tl.constexpr, HEAD_DIM: tl.constexpr, BLOCK_N: tl.constexpr,
                    STAGE: tl.constexpr, offs_m: tl.constexpr, offs_n: tl.constexpr,
                    ):
    if STAGE == 1:
        lo, hi = 0, start_m * BLOCK_M
    elif STAGE == 2:
        lo, hi = start_m * BLOCK_M, (start_m + 1) * BLOCK_M
        lo = tl.multiple_of(lo, BLOCK_M)
        # Skip the pointers ahead to the first key of the diagonal block.
        K_scale_ptr += lo // BLOCK_N
        K_ptrs += stride_kn * lo
        V_ptrs += stride_vn * lo
    elif STAGE == 3:
        lo, hi = 0, kv_len
    for start_n in range(lo, hi, BLOCK_N):
        # Block-sparsity flag for this key block; a zero flag skips all work.
        kbid = tl.load(K_bid_ptr + start_n//BLOCK_N)
        if kbid:
            k_mask = offs_n[None, :] < (kv_len - start_n)
            k = tl.load(K_ptrs, mask = k_mask)
            k_scale = tl.load(K_scale_ptr)
            # Dequantise the int8 product with the per-block scales.
            qk = tl.dot(q, k).to(tl.float32) * q_scale * k_scale
            if STAGE == 2:
                mask = offs_m[:, None] >= (start_n + offs_n[None, :])
                qk = qk + tl.where(mask, 0, -1.0e6)
                local_m = tl.max(qk, 1)
                new_m = tl.maximum(old_m, local_m)
                qk -= new_m[:, None]
            else:
                local_m = tl.max(qk, 1)
                new_m = tl.maximum(old_m, local_m)
                qk = qk - new_m[:, None]

            p = tl.math.exp2(qk)
            l_ij = tl.sum(p, 1)
            # Rescale what was accumulated under the previous running max.
            alpha = tl.math.exp2(old_m - new_m)
            l_i = l_i * alpha + l_ij
            acc = acc * alpha[:, None]
            v = tl.load(V_ptrs, mask = offs_n[:, None] < (kv_len - start_n))
            p = p.to(tl.float16)
            acc += tl.dot(p, v, out_dtype=tl.float16)
            old_m = new_m
        # Pointers advance for every block, skipped or not.
        # NOTE(review): nesting reconstructed from whitespace-collapsed text; confirm.
        K_ptrs += BLOCK_N * stride_kn
        K_scale_ptr += 1
        V_ptrs += BLOCK_N * stride_vn
    return acc, l_i, old_m

# Attention forward kernel: one program per (query block, head, batch).
# STAGE == 1 is non-causal (a single pass over all keys); any other value
# runs the off-diagonal pass and then the masked diagonal pass.
@triton.jit
def _attn_fwd(Q, K, K_blkid, V, Q_scale, K_scale, Out,
              stride_qz, stride_qh, stride_qn,
              stride_kz, stride_kh, stride_kn,
              stride_vz, stride_vh, stride_vn,
              stride_oz, stride_oh, stride_on,
              stride_kbidq, stride_kbidk,
              qo_len, kv_len, H:tl.constexpr, num_kv_groups:tl.constexpr,
              HEAD_DIM: tl.constexpr,
              BLOCK_M: tl.constexpr,
              BLOCK_N: tl.constexpr,
              STAGE: tl.constexpr
              ):
    start_m = tl.program_id(0)
    off_z = tl.program_id(2).to(tl.int64)
    off_h = tl.program_id(1).to(tl.int64)
    # Scales are addressed as (batch, head, block); the offsets are built
    # without strides, i.e. the scale tensors are assumed contiguous.
    q_scale_offset = (off_z * H + off_h) * tl.cdiv(qo_len, BLOCK_M)
    k_scale_offset = (off_z * (H // num_kv_groups) + off_h // num_kv_groups) * tl.cdiv(kv_len, BLOCK_N)
    # The block mask is indexed by key/value head (off_h // num_kv_groups).
    k_bid_offset = (off_z * (H // num_kv_groups) + off_h // num_kv_groups) * stride_kbidq
    offs_m = start_m * BLOCK_M + tl.arange(0, BLOCK_M)
    offs_n = tl.arange(0, BLOCK_N)
    offs_k = tl.arange(0, HEAD_DIM)
    Q_ptrs = Q + (off_z * stride_qz + off_h * stride_qh) + offs_m[:, None] * stride_qn + offs_k[None, :]
    Q_scale_ptr = Q_scale + q_scale_offset + start_m
    # K is addressed transposed (HEAD_DIM x BLOCK_N) so tl.dot(q, k) yields q @ k^T.
    K_ptrs = K + (off_z * stride_kz + (off_h // num_kv_groups) * stride_kh) + offs_n[None, :] * stride_kn + offs_k[:, None]
    K_scale_ptr = K_scale + k_scale_offset
    K_bid_ptr = K_blkid + k_bid_offset + start_m * stride_kbidk
    V_ptrs = V + (off_z * stride_vz + (off_h // num_kv_groups) * stride_vh) + offs_n[:, None] * stride_vn + offs_k[None, :]
    O_block_ptr = Out + (off_z * stride_oz + off_h * stride_oh) + offs_m[:, None] * stride_on + offs_k[None, :]
    m_i = tl.zeros([BLOCK_M], dtype=tl.float32) - float("inf")
    l_i = tl.zeros([BLOCK_M], dtype=tl.float32) + 1.0
    acc = tl.zeros([BLOCK_M, HEAD_DIM], dtype=tl.float32)
    q = tl.load(Q_ptrs, mask = offs_m[:, None] < qo_len)
    q_scale = tl.load(Q_scale_ptr)
    # 4 - STAGE maps STAGE 1 -> inner stage 3 (all keys) and STAGE 3 -> inner
    # stage 1 (keys before the diagonal).
    acc, l_i, m_i = _attn_fwd_inner(acc, l_i, m_i, q, q_scale, kv_len, K_ptrs, K_bid_ptr, K_scale_ptr, V_ptrs, stride_kn, stride_vn,
                                    start_m,
                                    BLOCK_M, HEAD_DIM, BLOCK_N,
                                    4 - STAGE, offs_m, offs_n
                                    )
    if STAGE != 1:
        # Causal case: finish with the masked diagonal block.
        acc, l_i, _ = _attn_fwd_inner(acc, l_i, m_i, q, q_scale, kv_len, K_ptrs, K_bid_ptr, K_scale_ptr, V_ptrs, stride_kn, stride_vn,
                                      start_m,
                                      BLOCK_M, HEAD_DIM, BLOCK_N,
                                      2, offs_m, offs_n
                                      )
    acc = acc / l_i[:, None]
    tl.store(O_block_ptr, acc.to(Out.type.element_ty), mask = (offs_m[:, None] < qo_len))


def forward(q, k, k_block_id, v, q_scale, k_scale, is_causal=False, tensor_layout="HND", output_dtype=torch.float16):
    """Launch the block-sparse int8 attention kernel.

    Args:
        q, k: Quantised query/key tensors laid out per ``tensor_layout``.
        k_block_id: Block mask; its strides along axes 1 and 2 are passed to
            the kernel, which skips key blocks whose flag is zero.
        v: Value tensor (the kernel casts probabilities to float16 before
            multiplying).
        q_scale, k_scale: Per-block dequantisation scales.
        is_causal: Selects the two-pass causal schedule; requires
            ``qo_len == kv_len``.
        tensor_layout: "HND" or "NHD".
        output_dtype: dtype of the returned tensor.

    Returns:
        Output tensor with ``q``'s shape and ``output_dtype``.

    Raises:
        ValueError: for an unsupported ``tensor_layout``.
    """
    # Block sizes: 128 query rows by 64 key rows per tile.
    BLOCK_M = 128
    BLOCK_N = 64
    stage = 3 if is_causal else 1
    o = torch.empty(q.shape, dtype=output_dtype, device=q.device)

    if tensor_layout == "HND":
        b, h_qo, qo_len, head_dim = q.shape
        _, h_kv, kv_len, _ = k.shape
        stride_bz_q, stride_h_q, stride_seq_q = q.stride(0), q.stride(1), q.stride(2)
        stride_bz_k, stride_h_k, stride_seq_k = k.stride(0), k.stride(1), k.stride(2)
        stride_bz_v, stride_h_v, stride_seq_v = v.stride(0), v.stride(1), v.stride(2)
        stride_bz_o, stride_h_o, stride_seq_o = o.stride(0), o.stride(1), o.stride(2)
    elif tensor_layout == "NHD":
        b, qo_len, h_qo, head_dim = q.shape
        _, kv_len, h_kv, _ = k.shape
        stride_bz_q, stride_h_q, stride_seq_q = q.stride(0), q.stride(2), q.stride(1)
        stride_bz_k, stride_h_k, stride_seq_k = k.stride(0), k.stride(2), k.stride(1)
        stride_bz_v, stride_h_v, stride_seq_v = v.stride(0), v.stride(2), v.stride(1)
        stride_bz_o, stride_h_o, stride_seq_o = o.stride(0), o.stride(2), o.stride(1)
    else:
        raise ValueError(f"tensor_layout {tensor_layout} not supported")

    if is_causal:
        assert qo_len == kv_len, "qo_len and kv_len must be equal for causal attention"

    HEAD_DIM_K = head_dim
    # Number of query heads sharing one key/value head.
    num_kv_groups = h_qo // h_kv

    grid = (triton.cdiv(qo_len, BLOCK_M), h_qo, b )
    _attn_fwd[grid](
        q, k, k_block_id, v, q_scale, k_scale, o,
        stride_bz_q, stride_h_q, stride_seq_q,
        stride_bz_k, stride_h_k, stride_seq_k,
        stride_bz_v, stride_h_v, stride_seq_v,
        stride_bz_o, stride_h_o, stride_seq_o,
        k_block_id.stride(1), k_block_id.stride(2),
        qo_len, kv_len,
        h_qo, num_kv_groups,
        BLOCK_M=BLOCK_M, BLOCK_N=BLOCK_N, HEAD_DIM=HEAD_DIM_K,
        STAGE=stage,
        num_warps=4 if head_dim == 64 else 8,
        num_stages=4)
    return o
@contextmanager
def init_weights_on_device(device=torch.device("meta"), include_buffers: bool = False):
    """Context manager that creates module parameters on ``device``.

    While active, ``torch.nn.Module.register_parameter`` is replaced so every
    newly registered parameter is moved to ``device`` (the meta device by
    default). With ``include_buffers=True`` the same is done for buffers, and
    ``torch.empty`` / ``zeros`` / ``ones`` / ``full`` are wrapped so they
    always allocate on ``device``. All patches are undone on exit, even when
    the body raises.
    """
    module_cls = torch.nn.Module
    saved_register_parameter = module_cls.register_parameter
    saved_register_buffer = module_cls.register_buffer

    def _register_parameter_on_device(module, name, param):
        saved_register_parameter(module, name, param)
        if param is None:
            return
        registered = module._parameters[name]
        # Re-create the parameter with its own class so subclasses survive;
        # the instance __dict__ is forwarded as constructor kwargs.
        ctor_kwargs = registered.__dict__
        ctor_kwargs["requires_grad"] = param.requires_grad
        module._parameters[name] = type(registered)(registered.to(device), **ctor_kwargs)

    def _register_buffer_on_device(module, name, buffer, persistent=True):
        saved_register_buffer(module, name, buffer, persistent=persistent)
        if buffer is None:
            return
        module._buffers[name] = module._buffers[name].to(device)

    def _force_device(constructor):
        def _constructor_on_device(*args, **kwargs):
            kwargs["device"] = device
            return constructor(*args, **kwargs)

        return _constructor_on_device

    # Tensor constructors are only patched when buffers are included.
    saved_constructors = (
        {fn_name: getattr(torch, fn_name) for fn_name in ["empty", "zeros", "ones", "full"]}
        if include_buffers
        else {}
    )

    try:
        module_cls.register_parameter = _register_parameter_on_device
        if include_buffers:
            module_cls.register_buffer = _register_buffer_on_device
        for fn_name, original_fn in saved_constructors.items():
            setattr(torch, fn_name, _force_device(original_fn))
        yield
    finally:
        module_cls.register_parameter = saved_register_parameter
        if include_buffers:
            module_cls.register_buffer = saved_register_buffer
        for fn_name, original_fn in saved_constructors.items():
            setattr(torch, fn_name, original_fn)
def convert_state_dict_keys_to_single_str(state_dict, with_shape=True):
    """Serialise the key structure of a (possibly nested) state dict.

    Only string keys are considered. A tensor contributes its key and, when
    ``with_shape`` is true, an additional ``"key:d0_d1_..."`` entry; a nested
    dict contributes ``"key|<serialised nested dict>"``. Entries are sorted
    and joined with commas.

    NOTE: with ``with_shape=True`` every tensor key appears twice (with and
    without shape). Fingerprints derived from this string depend on that, so
    the format is kept as is.
    """
    entries = []
    for name, item in state_dict.items():
        if not isinstance(name, str):
            continue
        if isinstance(item, torch.Tensor):
            if with_shape:
                dims = "_".join(str(size) for size in item.shape)
                entries.append(f"{name}:{dims}")
            entries.append(name)
        elif isinstance(item, dict):
            nested = convert_state_dict_keys_to_single_str(item, with_shape=with_shape)
            entries.append(f"{name}|{nested}")
    return ",".join(sorted(entries))
class RMS_norm(nn.Module):
    """L2-normalise along one axis, then rescale by ``sqrt(dim) * gamma``.

    With ``channel_first=True`` the channel axis is dim 1 and the learnable
    parameters have shape ``(dim, 1, 1)`` (``images=True``) or
    ``(dim, 1, 1, 1)`` (``images=False``) so they broadcast over the trailing
    axes. With ``channel_first=False`` the last axis is normalised and the
    parameters have shape ``(dim,)``. ``bias`` is a learnable parameter only
    when requested, otherwise the constant ``0.``.
    """

    def __init__(self, dim, channel_first=True, images=True, bias=False):
        super().__init__()
        if channel_first:
            trailing = (1, 1) if images else (1, 1, 1)
            param_shape = (dim,) + trailing
        else:
            param_shape = (dim,)

        self.channel_first = channel_first
        self.scale = dim**0.5
        self.gamma = nn.Parameter(torch.ones(param_shape))
        if bias:
            self.bias = nn.Parameter(torch.zeros(param_shape))
        else:
            self.bias = 0.

    def forward(self, x):
        norm_axis = 1 if self.channel_first else -1
        unit = F.normalize(x, dim=norm_axis)
        return unit * self.scale * self.gamma + self.bias
class PixelShuffle3d(nn.Module):
    """Fold ``ff x hh x ww`` blocks of a 5-D tensor into the channel axis.

    Maps ``(B, C, F*ff, H*hh, W*ww)`` to ``(B, C*ff*hh*ww, F, H, W)``;
    despite the name this reduces the frame/spatial resolution and grows the
    channel count.
    """

    _PATTERN = 'b c (f ff) (h hh) (w ww) -> b (c ff hh ww) f h w'

    def __init__(self, ff, hh, ww):
        super().__init__()
        self.ff, self.hh, self.ww = ff, hh, ww

    def forward(self, x):
        # x: (B, C, F, H, W)
        factors = {"ff": self.ff, "hh": self.hh, "ww": self.ww}
        return rearrange(x, self._PATTERN, **factors)
class Causal_LQ4x_Proj(nn.Module):
    """Causal projector from a low-quality video to per-layer token features.

    The video is pixel-unshuffled by 16x16, passed through two causal 3-D
    convolutions (each halving the frame axis) and finally through
    ``layer_num`` independent linear heads. Each convolution keeps the last
    ``CACHE_T`` frames of its *previous* input in ``self.cache`` so chunks
    can be processed one after another.

    ``forward`` handles a whole video; ``stream_forward`` handles one clip
    per call and relies on ``clip_idx`` / ``cache`` carried between calls.
    """

    def __init__(self, in_dim, out_dim, layer_num=30):
        super().__init__()
        self.ff = 1
        self.hh = 16
        self.ww = 16
        self.hidden_dim1 = 2048
        self.hidden_dim2 = 3072
        self.layer_num = layer_num

        self.pixel_shuffle = PixelShuffle3d(self.ff, self.hh, self.ww)

        # f -> f/2, h -> h, w -> w
        self.conv1 = CausalConv3d(in_dim*self.ff*self.hh*self.ww, self.hidden_dim1, (4, 3, 3), stride=(2, 1, 1), padding=(1, 1, 1))
        self.norm1 = RMS_norm(self.hidden_dim1, images=False)
        self.act1 = nn.SiLU()

        # f -> f/2, h -> h, w -> w
        self.conv2 = CausalConv3d(self.hidden_dim1, self.hidden_dim2, (4, 3, 3), stride=(2, 1, 1), padding=(1, 1, 1))
        self.norm2 = RMS_norm(self.hidden_dim2, images=False)
        self.act2 = nn.SiLU()

        self.linear_layers = nn.ModuleList([nn.Linear(self.hidden_dim2, out_dim) for _ in range(layer_num)])

        # Create the cache and reset clip_idx here, so stream_forward() can be
        # called on a fresh module without an explicit clear_cache() first.
        self.clear_cache()

    @staticmethod
    def _prepend_first_frame(video):
        """Pad the start of the frame axis with three copies of frame 0."""
        first_frame = video[:, :, :1, :, :].repeat(1, 1, 3, 1, 1)
        return torch.cat([first_frame, video], dim=2)

    def _stage1(self, chunk):
        """Pixel-unshuffle a chunk and run conv1/norm1/act1, updating the conv1 cache."""
        x = self.pixel_shuffle(chunk)
        next_cache = x[:, :, -CACHE_T:, :, :].clone()
        x = self.conv1(x, self.cache['conv1'])
        self.cache['conv1'] = next_cache
        x = self.norm1(x)
        return self.act1(x)

    def _prime_stage2(self, x):
        """Store the tail of the first stage-1 output as the initial conv2 cache."""
        self.cache['conv2'] = x[:, :, -CACHE_T:, :, :].clone()

    def _stage2(self, x):
        """Run conv2/norm2/act2 on a stage-1 output, updating the conv2 cache."""
        next_cache = x[:, :, -CACHE_T:, :, :].clone()
        x = self.conv2(x, self.cache['conv2'])
        self.cache['conv2'] = next_cache
        x = self.norm2(x)
        return self.act2(x)

    def _project(self, x):
        """Flatten (b, c, f, h, w) features to tokens and apply every linear head."""
        tokens = rearrange(x, 'b c f h w -> b (f h w) c')
        return [self.linear_layers[i](tokens) for i in range(self.layer_num)]

    def forward(self, video):
        """Process a whole video of shape (B, C, F, H, W).

        The video is padded with three copies of its first frame and consumed
        in chunks of four frames. The first chunk only warms up the caches;
        every later chunk contributes output frames.

        Returns:
            A list of ``layer_num`` tensors, one per linear head.

        Raises:
            RuntimeError: from ``torch.cat`` when the video has fewer than
                five frames, because no chunk produces output then.
        """
        self.clear_cache()

        t = video.shape[2]
        iter_ = 1 + (t - 1) // 4
        video = self._prepend_first_frame(video)

        out_x = []
        for i in range(iter_):
            x = self._stage1(video[:, :, i*4:(i+1)*4, :, :])
            if i == 0:
                self._prime_stage2(x)
                continue
            out_x.append(self._stage2(x))
        out_x = torch.cat(out_x, dim=2)
        return self._project(out_x)

    def clear_cache(self):
        """Drop both convolution caches and restart the clip counter."""
        self.cache = {}
        self.cache['conv1'] = None
        self.cache['conv2'] = None
        self.clip_idx = 0

    def stream_forward(self, video_clip):
        """Process one clip of a stream.

        The first clip after ``clear_cache()`` is padded with three copies of
        its first frame and only warms up the caches; it returns ``None``.
        Every later clip returns a list of ``layer_num`` tensors.
        """
        if self.clip_idx == 0:
            x = self._stage1(self._prepend_first_frame(video_clip))
            self._prime_stage2(x)
            self.clip_idx += 1
            return None

        x = self._stage2(self._stage1(video_clip))
        outputs = self._project(x)
        self.clip_idx += 1
        return outputs
@torch.no_grad()
def build_local_block_mask_shifted_vec(block_h: int,
                                       block_w: int,
                                       win_h: int = 6,
                                       win_w: int = 6,
                                       include_self: bool = True,
                                       device=None) -> torch.Tensor:
    """Build a boolean local-window mask over a ``block_h x block_w`` grid.

    Entry ``[i, j]`` is True when block ``j`` lies inside the
    ``win_h x win_w`` window assigned to block ``i``. Blocks are numbered in
    row-major order. The window start is clamped to ``[0, size - win]`` on
    each axis, so windows near a border are shifted inwards instead of being
    cut off.

    Returns:
        A ``(block_h * block_w, block_h * block_w)`` bool tensor on ``device``
        (CPU when ``device`` is falsy). The diagonal is cleared when
        ``include_self`` is False.
    """
    device = device or torch.device("cpu")
    grid_rows, grid_cols = torch.meshgrid(
        torch.arange(block_h, device=device),
        torch.arange(block_w, device=device),
        indexing="ij",
    )
    flat_rows = grid_rows.reshape(-1)
    flat_cols = grid_cols.reshape(-1)

    def _inside_window(coords, extent, win):
        # Row i of the result marks which blocks fall inside block i's window
        # along this axis.
        first = torch.clamp(coords - win // 2, 0, extent - win)
        last = first + win - 1
        return (coords[None, :] >= first[:, None]) & (coords[None, :] <= last[:, None])

    mask = _inside_window(flat_rows, block_h, win_h) & _inside_window(flat_cols, block_w, win_w)
    if not include_self:
        mask.fill_diagonal_(False)
    return mask
class WindowPartition3D:
    """Split 5-D tensors (B, F, H, W, C) into 3-D windows and merge them back."""

    @staticmethod
    def partition(x: torch.Tensor, win: Tuple[int, int, int]):
        """Cut ``x`` into non-overlapping ``(wf, wh, ww)`` windows.

        Returns a tensor of shape ``(B * num_windows, wf * wh * ww, C)``;
        windows are ordered by (batch, frame, row, column).
        """
        batch, frames, height, width, channels = x.shape
        wf, wh, ww = win
        assert frames % wf == 0 and height % wh == 0 and width % ww == 0, "Dims must divide by window size."
        n_f, n_h, n_w = frames // wf, height // wh, width // ww
        # Separate every axis into (window index, offset inside the window) ...
        blocks = x.view(batch, n_f, wf, n_h, wh, n_w, ww, channels)
        # ... then bring the window indices to the front of the offsets.
        blocks = blocks.permute(0, 1, 3, 5, 2, 4, 6, 7).contiguous()
        return blocks.view(-1, wf * wh * ww, channels)

    @staticmethod
    def reverse(windows: torch.Tensor, win: Tuple[int, int, int], orig: Tuple[int, int, int]):
        """Inverse of :meth:`partition`; ``orig`` is the original (F, H, W)."""
        frames, height, width = orig
        wf, wh, ww = win
        n_f, n_h, n_w = frames // wf, height // wh, width // ww
        batch = windows.size(0) // (n_f * n_h * n_w)
        blocks = windows.view(batch, n_f, n_h, n_w, wf, wh, ww, -1)
        # Interleave window indices and in-window offsets per axis again.
        blocks = blocks.permute(0, 1, 4, 2, 5, 3, 6, 7).contiguous()
        return blocks.view(batch, frames, height, width, -1)
@torch.no_grad()
def generate_draft_block_mask_sage(batch_size, nheads, seqlen,
                                   q_w, k_w, topk=10, local_attn_mask=None):
    """Build a boolean block-sparsity mask from pooled window attention.

    Queries are mean-pooled per window; every key window is split into two
    halves of 64 tokens and each half is pooled separately, so the mask has
    twice as many key columns as there are key windows. Pooled scores outside
    ``local_attn_mask`` are set to ``-inf`` before the softmax. Within each
    (head, group) slice only entries strictly greater than the
    ``(apply_topk + 1)``-th largest value are kept.

    Args:
        batch_size: must be 1 (asserted).
        nheads: number of attention heads the channel axis is split into.
        seqlen: number of groups the query-window axis is divided into for
            the top-k selection.
        q_w, k_w: windowed queries / keys.
            # assumes shape (num_windows, 128, channels); the 128 is implied
            # by the hard-coded view(..., 2, 64, ...) on k_w - TODO confirm
            # for q_w.
        topk: upper bound on the number of entries kept per (head, group).
        local_attn_mask: boolean locality mask, tiled to the score shape.

    Returns:
        Bool tensor of shape ``(batch_size, heads, q_windows, 2 * k_windows)``.
    """
    assert batch_size == 1, "Only batch_size=1 supported for now"
    assert local_attn_mask is not None, "local_attn_mask must be provided"

    # One pooled query vector per window, split into heads: (h, s, d).
    avgpool_q = torch.mean(q_w, dim=1)
    avgpool_q = rearrange(avgpool_q, 's (h d) -> s h d', h=nheads)
    q_heads = avgpool_q.permute(1, 0, 2)
    D = avgpool_q.shape[-1]

    # Two pooled key vectors per window (one per 64-token half).
    k_w_split = k_w.view(k_w.shape[0], 2, 64, k_w.shape[2])
    avgpool_k_split = torch.mean(k_w_split, dim=2)
    avgpool_k_refined = rearrange(avgpool_k_split, 's two d -> (s two) d', two=2)  # shape: (s*2, C)
    avgpool_k_refined = rearrange(avgpool_k_refined, 's (h d) -> s h d', h=nheads)  # shape: (s*2, h, d)
    k_heads_doubled = avgpool_k_refined.permute(1, 0, 2)  # shape: (h, s*2, d)

    # NOTE(review): chunking along the key axis and concatenating the two
    # score blocks again keeps the column order of k_heads_doubled, so this
    # looks equivalent to a single einsum - confirm before simplifying.
    k_heads_1, k_heads_2 = torch.chunk(k_heads_doubled, 2, dim=1)
    scores_1 = torch.einsum("hld,hmd->hlm", q_heads, k_heads_1) / math.sqrt(D)
    scores_2 = torch.einsum("hld,hmd->hlm", q_heads, k_heads_2) / math.sqrt(D)
    scores = torch.cat([scores_1, scores_2], dim=-1)

    repeat_head = scores.shape[0]
    repeat_len = scores.shape[1] // local_attn_mask.shape[0]
    repeat_num = (scores.shape[2] // 2) // local_attn_mask.shape[1]

    # Tile the locality mask over the query / key axes, double every key
    # column to match the two halves per window, then copy it for each head.
    local_attn_mask = local_attn_mask.unsqueeze(1).unsqueeze(0).repeat(repeat_len, 1, repeat_num, 1)
    local_attn_mask = rearrange(local_attn_mask, 'x a y b -> (x a) (y b)')
    local_attn_mask = local_attn_mask.repeat_interleave(2, dim=1)
    local_attn_mask = local_attn_mask.unsqueeze(0).repeat(repeat_head, 1, 1)

    assert scores.shape == local_attn_mask.shape, \
        f"Scores shape {scores.shape} != Mask shape {local_attn_mask.shape}"

    # Turn the mask into an additive bias: 0 where allowed, -inf elsewhere.
    local_attn_mask = local_attn_mask.to(torch.float32)
    local_attn_mask = local_attn_mask.masked_fill(local_attn_mask == False, -float('inf'))
    local_attn_mask = local_attn_mask.masked_fill(local_attn_mask == True, 0)
    scores = scores + local_attn_mask

    attn_map = torch.softmax(scores, dim=-1)
    # Select the top entries independently inside every (head, group) slice.
    attn_map = rearrange(attn_map, 'h (it s1) s2 -> (h it) s1 s2', it=seqlen)
    loop_num, s1, s2 = attn_map.shape
    flat = attn_map.reshape(loop_num, -1)
    apply_topk = min(flat.shape[1]-1, topk)

    if apply_topk <= 0:
        mask_new = torch.zeros_like(flat, dtype=torch.bool).reshape(loop_num, s1, s2)
    else:
        # Threshold is the (apply_topk + 1)-th largest value; the strict ">"
        # keeps at most apply_topk entries (fewer when values tie with it).
        thresholds = torch.topk(flat, k=apply_topk + 1, dim=1, largest=True).values[:, -1]
        thresholds = thresholds.unsqueeze(1)
        mask_new = (flat > thresholds).reshape(loop_num, s1, s2)

    mask_new = rearrange(mask_new, '(h it) s1 s2 -> h (it s1) s2', it=seqlen)
    mask = mask_new.unsqueeze(0).repeat(batch_size, 1, 1, 1)
    return mask
n=num_heads) + k = rearrange(k, "b s (n d) -> (b s) n d", n=num_heads) + v = rearrange(v, "b s (n d) -> (b s) n d", n=num_heads) + else: + q = rearrange(q, "b s (n d) -> b n s d", n=num_heads) + k = rearrange(k, "b s (n d) -> b n s d", n=num_heads) + v = rearrange(v, "b s (n d) -> b n s d", n=num_heads) + cu_seqlens_q = torch.tensor([0, seqlen], device=q.device, dtype=torch.int32) + cu_seqlens_k = torch.tensor([0, seqlen_kv], device=q.device, dtype=torch.int32) + head_mask_type = torch.tensor([1]*num_heads, device=q.device, dtype=torch.int32) + streaming_info = None + base_blockmask = attention_mask + max_seqlen_q_ = seqlen + max_seqlen_k_ = seqlen_kv + p_dropout = 0.0 + if USE_BLOCK_ATTN and BLOCK_ATTN_AVAILABLE: + x = block_sparse_attn_func( + q, k, v, + cu_seqlens_q, cu_seqlens_k, + head_mask_type, + streaming_info, + base_blockmask, + max_seqlen_q_, max_seqlen_k_, + p_dropout, + deterministic=False, + softmax_scale=None, + is_causal=False, + exact_streaming=False, + return_attn_probs=False, + ).unsqueeze(0) + x = rearrange(x, "b s n d -> b s (n d)", n=num_heads) + else: + x = sparse_sageattn( + q, k, v, + mask_id=base_blockmask.to(torch.int8), + is_causal=False, + tensor_layout="HND" + ) + x = rearrange(x, "b n s d -> b s (n d)", n=num_heads) + elif compatibility_mode: + q = rearrange(q, "b s (n d) -> b n s d", n=num_heads) + k = rearrange(k, "b s (n d) -> b n s d", n=num_heads) + v = rearrange(v, "b s (n d) -> b n s d", n=num_heads) + x = F.scaled_dot_product_attention(q, k, v) + x = rearrange(x, "b n s d -> b s (n d)", n=num_heads) + elif FLASH_ATTN_3_AVAILABLE: + q = rearrange(q, "b s (n d) -> b s n d", n=num_heads) + k = rearrange(k, "b s (n d) -> b s n d", n=num_heads) + v = rearrange(v, "b s (n d) -> b s n d", n=num_heads) + x = flash_attn_interface.flash_attn_func(q, k, v) + if isinstance(x, tuple): + x = x[0] + x = rearrange(x, "b s n d -> b s (n d)", n=num_heads) + elif FLASH_ATTN_2_AVAILABLE: + q = rearrange(q, "b s (n d) -> b s n d", n=num_heads) + k 
def precompute_freqs_cis(dim: int, end: int = 1024, theta: float = 10000.0):
    """Precompute rotary-embedding phases as unit complex numbers.

    Args:
        dim: rotary dimension; ``dim // 2`` frequencies are produced.
        end: number of positions (``0 .. end - 1``).
        theta: base of the geometric frequency progression.

    Returns:
        A complex128 tensor of shape ``(end, dim // 2)`` whose entry
        ``[p, i]`` is ``exp(1j * p * theta ** (-2 * i / dim))``.
    """
    half = dim // 2
    even_idx = torch.arange(0, dim, 2)[:half]
    exponents = even_idx.double() / dim
    inv_freq = 1.0 / (theta ** exponents)
    positions = torch.arange(end, device=inv_freq.device)
    angles = torch.outer(positions, inv_freq)
    # Magnitude 1, phase = position * frequency.
    return torch.polar(torch.ones_like(angles), angles)
class RMSNorm(nn.Module):
    """Root-mean-square normalisation over the last axis with a learned gain.

    The statistics are computed in float32; the normalised values are cast
    back to the input dtype before being multiplied by ``weight``.
    """

    def __init__(self, dim, eps=1e-5):
        super().__init__()
        self.eps = eps
        self.weight = nn.Parameter(torch.ones(dim))

    def norm(self, x):
        """Divide ``x`` by the RMS of its last axis (``eps`` added under the root)."""
        mean_square = x.pow(2).mean(dim=-1, keepdim=True)
        return x * torch.rsqrt(mean_square + self.eps)

    def forward(self, x):
        normed = self.norm(x.float())
        return normed.to(x.dtype) * self.weight
+ + q = self.norm_q(self.q(x)) + k = self.norm_k(self.k(x)) + v = self.v(x) + q = rope_apply(q, freqs, self.num_heads) + k = rope_apply(k, freqs, self.num_heads) + + win = (2, 8, 8) + q = q.view(B, f, h, w, D) + k = k.view(B, f, h, w, D) + v = v.view(B, f, h, w, D) + + q_w = WindowPartition3D.partition(q, win) + k_w = WindowPartition3D.partition(k, win) + v_w = WindowPartition3D.partition(v, win) + + seqlen = f//win[0] + one_len = k_w.shape[0] // B // seqlen + if pre_cache_k is not None and pre_cache_v is not None: + k_w = torch.cat([pre_cache_k, k_w], dim=0) + v_w = torch.cat([pre_cache_v, v_w], dim=0) + + block_n = q_w.shape[0] // B + block_s = q_w.shape[1] + block_n_kv = k_w.shape[0] // B + + reorder_q = rearrange(q_w, '(b block_n) (block_s) d -> b (block_n block_s) d', block_n=block_n, block_s=block_s) + reorder_k = rearrange(k_w, '(b block_n) (block_s) d -> b (block_n block_s) d', block_n=block_n_kv, block_s=block_s) + reorder_v = rearrange(v_w, '(b block_n) (block_s) d -> b (block_n block_s) d', block_n=block_n_kv, block_s=block_s) + + window_size = win[0]*h*w//128 + + if self.local_attn_mask is None or self.local_attn_mask_h!=h//8 or self.local_attn_mask_w!=w//8 or self.local_range!=local_range: + self.local_attn_mask = build_local_block_mask_shifted_vec_normal_slide(h//8, w//8, local_range, local_range, include_self=True, device=k_w.device) + self.local_attn_mask_h = h//8 + self.local_attn_mask_w = w//8 + self.local_range = local_range + if USE_BLOCK_ATTN and BLOCK_ATTN_AVAILABLE: + attention_mask = generate_draft_block_mask(B, self.num_heads, seqlen, q_w, k_w, topk=topk, local_attn_mask=self.local_attn_mask) + else: + attention_mask = generate_draft_block_mask_sage(B, self.num_heads, seqlen, q_w, k_w, topk=topk, local_attn_mask=self.local_attn_mask) + + x = self.attn(reorder_q, reorder_k, reorder_v, attention_mask) + + cur_block_n, cur_block_s, _ = k_w.shape + cache_num = cur_block_n // one_len + if cache_num > kv_len: + cache_k = k_w[one_len:, :, :] + 
cache_v = v_w[one_len:, :, :] + else: + cache_k = k_w + cache_v = v_w + + x = rearrange(x, 'b (block_n block_s) d -> (b block_n) (block_s) d', block_n=block_n, block_s=block_s) + x = WindowPartition3D.reverse(x, win, (f, h, w)) + x = x.view(B, f*h*w, D) + + if is_stream: + return self.o(x), cache_k, cache_v + return self.o(x) + + +class CrossAttention(nn.Module): + """ + 仅考虑文本 context;提供持久 KV 缓存。 + """ + def __init__(self, dim: int, num_heads: int, eps: float = 1e-6): + super().__init__() + self.dim = dim + self.num_heads = num_heads + self.head_dim = dim // num_heads + + self.q = nn.Linear(dim, dim) + self.k = nn.Linear(dim, dim) + self.v = nn.Linear(dim, dim) + self.o = nn.Linear(dim, dim) + + self.norm_q = RMSNorm(dim, eps=eps) + self.norm_k = RMSNorm(dim, eps=eps) + + self.attn = AttentionModule(self.num_heads) + + # 持久缓存 + self.cache_k = None + self.cache_v = None + + @torch.no_grad() + def init_cache(self, ctx: torch.Tensor): + """ctx: [B, S_ctx, dim] —— 经过 text_embedding 之后的上下文""" + self.cache_k = self.norm_k(self.k(ctx)) + self.cache_v = self.v(ctx) + + def clear_cache(self): + self.cache_k = None + self.cache_v = None + + def forward(self, x: torch.Tensor, y: torch.Tensor, is_stream: bool = False): + """ + y 即文本上下文(未做其他分支)。 + """ + q = self.norm_q(self.q(x)) + assert self.cache_k is not None and self.cache_v is not None + k = self.cache_k + v = self.cache_v + + x = self.attn(q, k, v) + return self.o(x) + + +class GateModule(nn.Module): + def __init__(self,): + super().__init__() + + def forward(self, x, gate, residual): + return x + gate * residual + + +class DiTBlock(nn.Module): + def __init__(self, dim: int, num_heads: int, ffn_dim: int, eps: float = 1e-6): + super().__init__() + self.dim = dim + self.num_heads = num_heads + self.ffn_dim = ffn_dim + + self.self_attn = SelfAttention(dim, num_heads, eps) + self.cross_attn = CrossAttention(dim, num_heads, eps) + + self.norm1 = nn.LayerNorm(dim, eps=eps, elementwise_affine=False) + self.norm2 = 
nn.LayerNorm(dim, eps=eps, elementwise_affine=False) + self.norm3 = nn.LayerNorm(dim, eps=eps) + self.ffn = nn.Sequential(nn.Linear(dim, ffn_dim), nn.GELU( + approximate='tanh'), nn.Linear(ffn_dim, dim)) + self.modulation = nn.Parameter(torch.randn(1, 6, dim) / dim**0.5) + self.gate = GateModule() + + def forward(self, x, context, t_mod, freqs, f, h, w, local_num=None, topk=None, + train_img=False, block_id=None, kv_len=None, is_full_block=False, + is_stream=False, pre_cache_k=None, pre_cache_v=None, local_range = 9): + shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = ( + self.modulation.to(dtype=t_mod.dtype, device=t_mod.device) + t_mod).chunk(6, dim=1) + input_x = modulate(self.norm1(x), shift_msa, scale_msa) + self_attn_output, self_attn_cache_k, self_attn_cache_v = self.self_attn( + input_x, freqs, f, h, w, local_num, topk, train_img, block_id, + kv_len=kv_len, is_full_block=is_full_block, is_stream=is_stream, + pre_cache_k=pre_cache_k, pre_cache_v=pre_cache_v, local_range = local_range) + + x = self.gate(x, gate_msa, self_attn_output) + x = x + self.cross_attn(self.norm3(x), context, is_stream=is_stream) + input_x = modulate(self.norm2(x), shift_mlp, scale_mlp) + x = self.gate(x, gate_mlp, self.ffn(input_x)) + if is_stream: + return x, self_attn_cache_k, self_attn_cache_v + return x + + +class MLP(torch.nn.Module): + def __init__(self, in_dim, out_dim, has_pos_emb=False): + super().__init__() + self.proj = torch.nn.Sequential( + nn.LayerNorm(in_dim), + nn.Linear(in_dim, in_dim), + nn.GELU(), + nn.Linear(in_dim, out_dim), + nn.LayerNorm(out_dim) + ) + self.has_pos_emb = has_pos_emb + if has_pos_emb: + self.emb_pos = torch.nn.Parameter(torch.zeros((1, 514, 1280))) + + def forward(self, x): + if self.has_pos_emb: + x = x + self.emb_pos.to(dtype=x.dtype, device=x.device) + return self.proj(x) + + +class Head(nn.Module): + def __init__(self, dim: int, out_dim: int, patch_size: Tuple[int, int, int], eps: float): + super().__init__() + self.dim = dim 
+ self.patch_size = patch_size + self.norm = nn.LayerNorm(dim, eps=eps, elementwise_affine=False) + self.head = nn.Linear(dim, out_dim * math.prod(patch_size)) + self.modulation = nn.Parameter(torch.randn(1, 2, dim) / dim**0.5) + + def forward(self, x, t_mod): + shift, scale = (self.modulation.to(dtype=t_mod.dtype, device=t_mod.device) + t_mod).chunk(2, dim=1) + x = (self.head(self.norm(x) * (1 + scale) + shift)) + return x + + +# ---------------------------- +# WanModel (no image branch) — init 时即产生 KV 缓存 +# ---------------------------- +class WanModel(torch.nn.Module): + def __init__( + self, + dim: int, + in_dim: int, + ffn_dim: int, + out_dim: int, + text_dim: int, + freq_dim: int, + eps: float, + patch_size: Tuple[int, int, int], + num_heads: int, + num_layers: int, + # init_context: torch.Tensor, # <<<< 必填:在 __init__ 里用它生成 cross-attn KV 缓存 + has_image_input: bool = False, + ): + super().__init__() + self.dim = dim + self.freq_dim = freq_dim + self.patch_size = patch_size + + # patch embed + self.patch_embedding = nn.Conv3d( + in_dim, dim, kernel_size=patch_size, stride=patch_size) + + # text / time embed + self.text_embedding = nn.Sequential( + nn.Linear(text_dim, dim), + nn.GELU(approximate='tanh'), + nn.Linear(dim, dim) + ) + self.time_embedding = nn.Sequential( + nn.Linear(freq_dim, dim), + nn.SiLU(), + nn.Linear(dim, dim) + ) + self.time_projection = nn.Sequential( + nn.SiLU(), nn.Linear(dim, dim * 6)) + + # blocks + self.blocks = nn.ModuleList([ + DiTBlock(dim, num_heads, ffn_dim, eps) + for _ in range(num_layers) + ]) + self.head = Head(dim, out_dim, patch_size, eps) + + head_dim = dim // num_heads + self.freqs = precompute_freqs_cis_3d(head_dim) + + self._cross_kv_initialized = False + + # 可选:手动清空 / 重新初始化 + def clear_cross_kv(self): + for blk in self.blocks: + blk.cross_attn.clear_cache() + self._cross_kv_initialized = False + + @torch.no_grad() + def reinit_cross_kv(self, new_context: torch.Tensor): + ctx_txt = self.text_embedding(new_context) + for 
blk in self.blocks: + blk.cross_attn.init_cache(ctx_txt) + self._cross_kv_initialized = True + + def patchify(self, x: torch.Tensor): + x = self.patch_embedding(x) + grid_size = x.shape[2:] + x = rearrange(x, 'b c f h w -> b (f h w) c').contiguous() + return x, grid_size # x, grid_size: (f, h, w) + + def unpatchify(self, x: torch.Tensor, grid_size: torch.Tensor): + return rearrange( + x, 'b (f h w) (x y z c) -> b c (f x) (h y) (w z)', + f=grid_size[0], h=grid_size[1], w=grid_size[2], + x=self.patch_size[0], y=self.patch_size[1], z=self.patch_size[2] + ) + + def forward(self, + x: torch.Tensor, + timestep: torch.Tensor, + context: torch.Tensor, + use_gradient_checkpointing: bool = False, + use_gradient_checkpointing_offload: bool = False, + LQ_latents: Optional[List[torch.Tensor]] = None, + train_img: bool = False, + topk_ratio: Optional[float] = None, + kv_ratio: Optional[float] = None, + local_num: Optional[int] = None, + is_full_block: bool = False, + causal_idx: Optional[int] = None, + **kwargs, + ): + # time / text embeds + t = self.time_embedding( + sinusoidal_embedding_1d(self.freq_dim, timestep)) + t_mod = self.time_projection(t).unflatten(1, (6, self.dim)) + + # 这里仍会嵌入 text(CrossAttention 若已有缓存会忽略它) + # context = self.text_embedding(context) + + # 输入打补丁 + x, (f, h, w) = self.patchify(x) + B = x.shape[0] + + # window / masks 超参 + win = (2, 8, 8) + seqlen = f//win[0] + if local_num is None: + local_random = random.random() + if local_random < 0.3: + local_num = seqlen - 3 + elif local_random < 0.4: + local_num = seqlen - 4 + elif local_random < 0.5: + local_num = seqlen - 2 + else: + local_num = seqlen + + window_size = win[0]*h*w//128 + square_num = window_size*window_size + topk_ratio = 2.0 + topk = min(max(int(square_num*topk_ratio), 1), int(square_num*seqlen)-1) + + if kv_ratio is None: + kv_ratio = (random.uniform(0., 1.0)**2)*(local_num-2-2)+2 + kv_len = min(max(int(window_size*kv_ratio), 1), int(window_size*seqlen)-1) + + decay_ratio = 
random.uniform(0.7, 1.0) + + # RoPE 3D + freqs = torch.cat([ + self.freqs[0][:f].view(f, 1, 1, -1).expand(f, h, w, -1), + self.freqs[1][:h].view(1, h, 1, -1).expand(f, h, w, -1), + self.freqs[2][:w].view(1, 1, w, -1).expand(f, h, w, -1) + ], dim=-1).reshape(f * h * w, 1, -1).to(x.device) + + def create_custom_forward(module): + def custom_forward(*inputs): + return module(*inputs) + return custom_forward + + # blocks + for block_id, block in enumerate(self.blocks): + if LQ_latents is not None and block_id < len(LQ_latents): + x += LQ_latents[block_id] + + if self.training and use_gradient_checkpointing: + if use_gradient_checkpointing_offload: + with torch.autograd.graph.save_on_cpu(): + x = torch.utils.checkpoint.checkpoint( + create_custom_forward(block), + x, context, t_mod, freqs, f, h, w, local_num, topk, + train_img, block_id, kv_len, is_full_block, False, + None, None, + use_reentrant=False, + ) + else: + x = torch.utils.checkpoint.checkpoint( + create_custom_forward(block), + x, context, t_mod, freqs, f, h, w, local_num, topk, + train_img, block_id, kv_len, is_full_block, False, + None, None, + use_reentrant=False, + ) + else: + x = block(x, context, t_mod, freqs, f, h, w, local_num, topk, + train_img, block_id, kv_len, is_full_block, False, + None, None) + + x = self.head(x, t) + x = self.unpatchify(x, (f, h, w)) + return x + + @staticmethod + def state_dict_converter(): + return WanModelStateDictConverter() + + +# ---------------------------- +# State dict converter(保持原映射;已忽略 has_image_input 使用) +# ---------------------------- +class WanModelStateDictConverter: + def __init__(self): + pass + + def from_diffusers(self, state_dict): + rename_dict = { + "blocks.0.attn1.norm_k.weight": "blocks.0.self_attn.norm_k.weight", + "blocks.0.attn1.norm_q.weight": "blocks.0.self_attn.norm_q.weight", + "blocks.0.attn1.to_k.bias": "blocks.0.self_attn.k.bias", + "blocks.0.attn1.to_k.weight": "blocks.0.self_attn.k.weight", + "blocks.0.attn1.to_out.0.bias": 
"blocks.0.self_attn.o.bias", + "blocks.0.attn1.to_out.0.weight": "blocks.0.self_attn.o.weight", + "blocks.0.attn1.to_q.bias": "blocks.0.self_attn.q.bias", + "blocks.0.attn1.to_q.weight": "blocks.0.self_attn.q.weight", + "blocks.0.attn1.to_v.bias": "blocks.0.self_attn.v.bias", + "blocks.0.attn1.to_v.weight": "blocks.0.self_attn.v.weight", + "blocks.0.attn2.norm_k.weight": "blocks.0.cross_attn.norm_k.weight", + "blocks.0.attn2.norm_q.weight": "blocks.0.cross_attn.norm_q.weight", + "blocks.0.attn2.to_k.bias": "blocks.0.cross_attn.k.bias", + "blocks.0.attn2.to_k.weight": "blocks.0.cross_attn.k.weight", + "blocks.0.attn2.to_out.0.bias": "blocks.0.cross_attn.o.bias", + "blocks.0.attn2.to_out.0.weight": "blocks.0.cross_attn.o.weight", + "blocks.0.attn2.to_q.bias": "blocks.0.cross_attn.q.bias", + "blocks.0.attn2.to_q.weight": "blocks.0.cross_attn.q.weight", + "blocks.0.attn2.to_v.bias": "blocks.0.cross_attn.v.bias", + "blocks.0.attn2.to_v.weight": "blocks.0.cross_attn.v.weight", + "blocks.0.ffn.net.0.proj.bias": "blocks.0.ffn.0.bias", + "blocks.0.ffn.net.0.proj.weight": "blocks.0.ffn.0.weight", + "blocks.0.ffn.net.2.bias": "blocks.0.ffn.2.bias", + "blocks.0.ffn.net.2.weight": "blocks.0.ffn.2.weight", + "blocks.0.norm2.bias": "blocks.0.norm3.bias", + "blocks.0.norm2.weight": "blocks.0.norm3.weight", + "blocks.0.scale_shift_table": "blocks.0.modulation", + "condition_embedder.text_embedder.linear_1.bias": "text_embedding.0.bias", + "condition_embedder.text_embedder.linear_1.weight": "text_embedding.0.weight", + "condition_embedder.text_embedder.linear_2.bias": "text_embedding.2.bias", + "condition_embedder.text_embedder.linear_2.weight": "text_embedding.2.weight", + "condition_embedder.time_embedder.linear_1.bias": "time_embedding.0.bias", + "condition_embedder.time_embedder.linear_1.weight": "time_embedding.0.weight", + "condition_embedder.time_embedder.linear_2.bias": "time_embedding.2.bias", + "condition_embedder.time_embedder.linear_2.weight": "time_embedding.2.weight", 
+ "condition_embedder.time_proj.bias": "time_projection.1.bias", + "condition_embedder.time_proj.weight": "time_projection.1.weight", + "patch_embedding.bias": "patch_embedding.bias", + "patch_embedding.weight": "patch_embedding.weight", + "scale_shift_table": "head.modulation", + "proj_out.bias": "head.head.bias", + "proj_out.weight": "head.head.weight", + } + state_dict_ = {} + for name, param in state_dict.items(): + if name in rename_dict: + state_dict_[rename_dict[name]] = param + else: + name_ = ".".join(name.split(".")[:1] + ["0"] + name.split(".")[2:]) + if name_ in rename_dict: + name_ = rename_dict[name_] + name_ = ".".join(name_.split(".")[:1] + [name.split(".")[1]] + name_.split(".")[2:]) + state_dict_[name_] = param + if hash_state_dict_keys(state_dict) == "cb104773c6c2cb6df4f9529ad5c60d0b": + config = { + "model_type": "t2v", + "patch_size": (1, 2, 2), + "text_len": 512, + "in_dim": 16, + "dim": 5120, + "ffn_dim": 13824, + "freq_dim": 256, + "text_dim": 4096, + "out_dim": 16, + "num_heads": 40, + "num_layers": 40, + "window_size": (-1, -1), + "qk_norm": True, + "cross_attn_norm": True, + "eps": 1e-6, + } + else: + config = {} + return state_dict_, config + + def from_civitai(self, state_dict): + state_dict = {name: param for name, param in state_dict.items() if not name.startswith("vace")} + # 保留原有哈希匹配返回的 config;实现本身不使用 has_image_input 分支 + if hash_state_dict_keys(state_dict) == "9269f8db9040a9d860eaca435be61814": + config = {"has_image_input": False,"patch_size": [1, 2, 2],"in_dim": 16,"dim": 1536,"ffn_dim": 8960,"freq_dim": 256,"text_dim": 4096,"out_dim": 16,"num_heads": 12,"num_layers": 30,"eps": 1e-6} + elif hash_state_dict_keys(state_dict) == "aafcfd9672c3a2456dc46e1cb6e52c70": + config = {"has_image_input": False,"patch_size": [1, 2, 2],"in_dim": 16,"dim": 5120,"ffn_dim": 13824,"freq_dim": 256,"text_dim": 4096,"out_dim": 16,"num_heads": 40,"num_layers": 40,"eps": 1e-6} + elif hash_state_dict_keys(state_dict) == 
"6bfcfb3b342cb286ce886889d519a77e": + config = {"has_image_input": False,"patch_size": [1, 2, 2],"in_dim": 36,"dim": 5120,"ffn_dim": 13824,"freq_dim": 256,"text_dim": 4096,"out_dim": 16,"num_heads": 40,"num_layers": 40,"eps": 1e-6} + elif hash_state_dict_keys(state_dict) == "6d6ccde6845b95ad9114ab993d917893": + config = {"has_image_input": False,"patch_size": [1, 2, 2],"in_dim": 36,"dim": 1536,"ffn_dim": 8960,"freq_dim": 256,"text_dim": 4096,"out_dim": 16,"num_heads": 12,"num_layers": 30,"eps": 1e-6} + elif hash_state_dict_keys(state_dict) == "349723183fc063b2bfc10bb2835cf677": + config = {"has_image_input": False,"patch_size": [1, 2, 2],"in_dim": 48,"dim": 1536,"ffn_dim": 8960,"freq_dim": 256,"text_dim": 4096,"out_dim": 16,"num_heads": 12,"num_layers": 30,"eps": 1e-6} + elif hash_state_dict_keys(state_dict) == "efa44cddf936c70abd0ea28b6cbe946c": + config = {"has_image_input": False,"patch_size": [1, 2, 2],"in_dim": 48,"dim": 5120,"ffn_dim": 13824,"freq_dim": 256,"text_dim": 4096,"out_dim": 16,"num_heads": 40,"num_layers": 40,"eps": 1e-6} + elif hash_state_dict_keys(state_dict) == "3ef3b1f8e1dab83d5b71fd7b617f859f": + config = {"has_image_input": False,"patch_size": [1, 2, 2],"in_dim": 36,"dim": 5120,"ffn_dim": 13824,"freq_dim": 256,"text_dim": 4096,"out_dim": 16,"num_heads": 40,"num_layers": 40,"eps": 1e-6,"has_image_pos_emb": False} + else: + config = {} + return state_dict, config + \ No newline at end of file diff --git a/custom_nodes/ComfyUI-FlashVSR_Ultra_Fast/src/models/wan_video_vae.py b/custom_nodes/ComfyUI-FlashVSR_Ultra_Fast/src/models/wan_video_vae.py new file mode 100644 index 0000000000000000000000000000000000000000..123303963af5e0e426e2f316a6399b22fcdffac5 --- /dev/null +++ b/custom_nodes/ComfyUI-FlashVSR_Ultra_Fast/src/models/wan_video_vae.py @@ -0,0 +1,847 @@ +from einops import rearrange, repeat + +import torch +import torch.nn as nn +import torch.nn.functional as F +from tqdm import tqdm + +CACHE_T = 2 + + +def check_is_instance(model, 
module_class): + if isinstance(model, module_class): + return True + if hasattr(model, "module") and isinstance(model.module, module_class): + return True + return False + + +def block_causal_mask(x, block_size): + # params + b, n, s, _, device = *x.size(), x.device + assert s % block_size == 0 + num_blocks = s // block_size + + # build mask + mask = torch.zeros(b, n, s, s, dtype=torch.bool, device=device) + for i in range(num_blocks): + mask[:, :, + i * block_size:(i + 1) * block_size, :(i + 1) * block_size] = 1 + return mask + + +class CausalConv3d(nn.Conv3d): + """ + Causal 3d convolusion. + """ + + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + self._padding = (self.padding[2], self.padding[2], self.padding[1], + self.padding[1], 2 * self.padding[0], 0) + self.padding = (0, 0, 0) + + def forward(self, x, cache_x=None): + padding = list(self._padding) + if cache_x is not None and self._padding[4] > 0: + cache_x = cache_x.to(x.device) + # print('cache_x.shape', cache_x.shape, 'x.shape', x.shape) + x = torch.cat([cache_x, x], dim=2) + padding[4] -= cache_x.shape[2] + x = F.pad(x, padding) + + return super().forward(x) + + +class RMS_norm(nn.Module): + + def __init__(self, dim, channel_first=True, images=True, bias=False): + super().__init__() + broadcastable_dims = (1, 1, 1) if not images else (1, 1) + shape = (dim, *broadcastable_dims) if channel_first else (dim,) + + self.channel_first = channel_first + self.scale = dim**0.5 + self.gamma = nn.Parameter(torch.ones(shape)) + self.bias = nn.Parameter(torch.zeros(shape)) if bias else 0. + + def forward(self, x): + return F.normalize( + x, dim=(1 if self.channel_first else + -1)) * self.scale * self.gamma + self.bias + + +class Upsample(nn.Upsample): + + def forward(self, x): + """ + Fix bfloat16 support for nearest neighbor interpolation. 
+ """ + return super().forward(x.float()).type_as(x) + + +class Resample(nn.Module): + + def __init__(self, dim, mode): + assert mode in ('none', 'upsample2d', 'upsample3d', 'downsample2d', + 'downsample3d') + super().__init__() + self.dim = dim + self.mode = mode + + # layers + if mode == 'upsample2d': + self.resample = nn.Sequential( + Upsample(scale_factor=(2., 2.), mode='nearest-exact'), + nn.Conv2d(dim, dim // 2, 3, padding=1)) + elif mode == 'upsample3d': + self.resample = nn.Sequential( + Upsample(scale_factor=(2., 2.), mode='nearest-exact'), + nn.Conv2d(dim, dim // 2, 3, padding=1)) + self.time_conv = CausalConv3d(dim, + dim * 2, (3, 1, 1), + padding=(1, 0, 0)) + + elif mode == 'downsample2d': + self.resample = nn.Sequential( + nn.ZeroPad2d((0, 1, 0, 1)), + nn.Conv2d(dim, dim, 3, stride=(2, 2))) + elif mode == 'downsample3d': + self.resample = nn.Sequential( + nn.ZeroPad2d((0, 1, 0, 1)), + nn.Conv2d(dim, dim, 3, stride=(2, 2))) + self.time_conv = CausalConv3d(dim, + dim, (3, 1, 1), + stride=(2, 1, 1), + padding=(0, 0, 0)) + + else: + self.resample = nn.Identity() + + def forward(self, x, feat_cache=None, feat_idx=[0]): + b, c, t, h, w = x.size() + if self.mode == 'upsample3d': + if feat_cache is not None: + idx = feat_idx[0] + if feat_cache[idx] is None: + feat_cache[idx] = 'Rep' + feat_idx[0] += 1 + else: + + cache_x = x[:, :, -CACHE_T:, :, :].clone() + if cache_x.shape[2] < 2 and feat_cache[ + idx] is not None and feat_cache[idx] != 'Rep': + # cache last frame of last two chunk + cache_x = torch.cat([ + feat_cache[idx][:, :, -1, :, :].unsqueeze(2).to( + cache_x.device), cache_x + ], + dim=2) + if cache_x.shape[2] < 2 and feat_cache[ + idx] is not None and feat_cache[idx] == 'Rep': + cache_x = torch.cat([ + torch.zeros_like(cache_x).to(cache_x.device), + cache_x + ], + dim=2) + if feat_cache[idx] == 'Rep': + x = self.time_conv(x) + else: + x = self.time_conv(x, feat_cache[idx]) + feat_cache[idx] = cache_x + feat_idx[0] += 1 + + x = x.reshape(b, 2, c, t, h, 
w) + x = torch.stack((x[:, 0, :, :, :, :], x[:, 1, :, :, :, :]), + 3) + x = x.reshape(b, c, t * 2, h, w) + t = x.shape[2] + x = rearrange(x, 'b c t h w -> (b t) c h w') + x = self.resample(x) + x = rearrange(x, '(b t) c h w -> b c t h w', t=t) + + if self.mode == 'downsample3d': + if feat_cache is not None: + idx = feat_idx[0] + if feat_cache[idx] is None: + feat_cache[idx] = x.clone() + feat_idx[0] += 1 + else: + cache_x = x[:, :, -1:, :, :].clone() + x = self.time_conv( + torch.cat([feat_cache[idx][:, :, -1:, :, :], x], 2)) + feat_cache[idx] = cache_x + feat_idx[0] += 1 + return x + + def init_weight(self, conv): + conv_weight = conv.weight + nn.init.zeros_(conv_weight) + c1, c2, t, h, w = conv_weight.size() + one_matrix = torch.eye(c1, c2) + init_matrix = one_matrix + nn.init.zeros_(conv_weight) + conv_weight.data[:, :, 1, 0, 0] = init_matrix + conv.weight.data.copy_(conv_weight) + nn.init.zeros_(conv.bias.data) + + def init_weight2(self, conv): + conv_weight = conv.weight.data + nn.init.zeros_(conv_weight) + c1, c2, t, h, w = conv_weight.size() + init_matrix = torch.eye(c1 // 2, c2) + conv_weight[:c1 // 2, :, -1, 0, 0] = init_matrix + conv_weight[c1 // 2:, :, -1, 0, 0] = init_matrix + conv.weight.data.copy_(conv_weight) + nn.init.zeros_(conv.bias.data) + + +class ResidualBlock(nn.Module): + + def __init__(self, in_dim, out_dim, dropout=0.0): + super().__init__() + self.in_dim = in_dim + self.out_dim = out_dim + + # layers + self.residual = nn.Sequential( + RMS_norm(in_dim, images=False), nn.SiLU(), + CausalConv3d(in_dim, out_dim, 3, padding=1), + RMS_norm(out_dim, images=False), nn.SiLU(), nn.Dropout(dropout), + CausalConv3d(out_dim, out_dim, 3, padding=1)) + self.shortcut = CausalConv3d(in_dim, out_dim, 1) \ + if in_dim != out_dim else nn.Identity() + + def forward(self, x, feat_cache=None, feat_idx=[0]): + h = self.shortcut(x) + for layer in self.residual: + if check_is_instance(layer, CausalConv3d) and feat_cache is not None: + idx = feat_idx[0] + cache_x = 
x[:, :, -CACHE_T:, :, :].clone() + if cache_x.shape[2] < 2 and feat_cache[idx] is not None: + # cache last frame of last two chunk + cache_x = torch.cat([ + feat_cache[idx][:, :, -1, :, :].unsqueeze(2).to( + cache_x.device), cache_x + ], + dim=2) + x = layer(x, feat_cache[idx]) + feat_cache[idx] = cache_x + feat_idx[0] += 1 + else: + x = layer(x) + return x + h + + +class AttentionBlock(nn.Module): + """ + Causal self-attention with a single head. + """ + + def __init__(self, dim): + super().__init__() + self.dim = dim + + # layers + self.norm = RMS_norm(dim) + self.to_qkv = nn.Conv2d(dim, dim * 3, 1) + self.proj = nn.Conv2d(dim, dim, 1) + + # zero out the last layer params + nn.init.zeros_(self.proj.weight) + + def forward(self, x): + identity = x + b, c, t, h, w = x.size() + x = rearrange(x, 'b c t h w -> (b t) c h w') + x = self.norm(x) + # compute query, key, value + q, k, v = self.to_qkv(x).reshape(b * t, 1, c * 3, -1).permute( + 0, 1, 3, 2).contiguous().chunk(3, dim=-1) + + # apply attention + x = F.scaled_dot_product_attention( + q, + k, + v, + #attn_mask=block_causal_mask(q, block_size=h * w) + ) + x = x.squeeze(1).permute(0, 2, 1).reshape(b * t, c, h, w) + + # output + x = self.proj(x) + x = rearrange(x, '(b t) c h w-> b c t h w', t=t) + return x + identity + + +class Encoder3d(nn.Module): + + def __init__(self, + dim=128, + z_dim=4, + dim_mult=[1, 2, 4, 4], + num_res_blocks=2, + attn_scales=[], + temperal_downsample=[True, True, False], + dropout=0.0): + super().__init__() + self.dim = dim + self.z_dim = z_dim + self.dim_mult = dim_mult + self.num_res_blocks = num_res_blocks + self.attn_scales = attn_scales + self.temperal_downsample = temperal_downsample + + # dimensions + dims = [dim * u for u in [1] + dim_mult] + scale = 1.0 + + # init block + self.conv1 = CausalConv3d(3, dims[0], 3, padding=1) + + # downsample blocks + downsamples = [] + for i, (in_dim, out_dim) in enumerate(zip(dims[:-1], dims[1:])): + # residual (+attention) blocks + for _ in 
range(num_res_blocks): + downsamples.append(ResidualBlock(in_dim, out_dim, dropout)) + if scale in attn_scales: + downsamples.append(AttentionBlock(out_dim)) + in_dim = out_dim + + # downsample block + if i != len(dim_mult) - 1: + mode = 'downsample3d' if temperal_downsample[ + i] else 'downsample2d' + downsamples.append(Resample(out_dim, mode=mode)) + scale /= 2.0 + self.downsamples = nn.Sequential(*downsamples) + + # middle blocks + self.middle = nn.Sequential(ResidualBlock(out_dim, out_dim, dropout), + AttentionBlock(out_dim), + ResidualBlock(out_dim, out_dim, dropout)) + + # output blocks + self.head = nn.Sequential(RMS_norm(out_dim, images=False), nn.SiLU(), + CausalConv3d(out_dim, z_dim, 3, padding=1)) + + def forward(self, x, feat_cache=None, feat_idx=[0]): + if feat_cache is not None: + idx = feat_idx[0] + cache_x = x[:, :, -CACHE_T:, :, :].clone() + if cache_x.shape[2] < 2 and feat_cache[idx] is not None: + # cache last frame of last two chunk + cache_x = torch.cat([ + feat_cache[idx][:, :, -1, :, :].unsqueeze(2).to( + cache_x.device), cache_x + ], + dim=2) + x = self.conv1(x, feat_cache[idx]) + feat_cache[idx] = cache_x + feat_idx[0] += 1 + else: + x = self.conv1(x) + + ## downsamples + for layer in self.downsamples: + if feat_cache is not None: + x = layer(x, feat_cache, feat_idx) + else: + x = layer(x) + + ## middle + for layer in self.middle: + if check_is_instance(layer, ResidualBlock) and feat_cache is not None: + x = layer(x, feat_cache, feat_idx) + else: + x = layer(x) + + ## head + for layer in self.head: + if check_is_instance(layer, CausalConv3d) and feat_cache is not None: + idx = feat_idx[0] + cache_x = x[:, :, -CACHE_T:, :, :].clone() + if cache_x.shape[2] < 2 and feat_cache[idx] is not None: + # cache last frame of last two chunk + cache_x = torch.cat([ + feat_cache[idx][:, :, -1, :, :].unsqueeze(2).to( + cache_x.device), cache_x + ], + dim=2) + x = layer(x, feat_cache[idx]) + feat_cache[idx] = cache_x + feat_idx[0] += 1 + else: + x = 
layer(x) + return x + + +class Decoder3d(nn.Module): + + def __init__(self, + dim=128, + z_dim=4, + dim_mult=[1, 2, 4, 4], + num_res_blocks=2, + attn_scales=[], + temperal_upsample=[False, True, True], + dropout=0.0): + super().__init__() + self.dim = dim + self.z_dim = z_dim + self.dim_mult = dim_mult + self.num_res_blocks = num_res_blocks + self.attn_scales = attn_scales + self.temperal_upsample = temperal_upsample + + # dimensions + dims = [dim * u for u in [dim_mult[-1]] + dim_mult[::-1]] + scale = 1.0 / 2**(len(dim_mult) - 2) + + # init block + self.conv1 = CausalConv3d(z_dim, dims[0], 3, padding=1) + + # middle blocks + self.middle = nn.Sequential(ResidualBlock(dims[0], dims[0], dropout), + AttentionBlock(dims[0]), + ResidualBlock(dims[0], dims[0], dropout)) + + # upsample blocks + upsamples = [] + for i, (in_dim, out_dim) in enumerate(zip(dims[:-1], dims[1:])): + # residual (+attention) blocks + if i == 1 or i == 2 or i == 3: + in_dim = in_dim // 2 + for _ in range(num_res_blocks + 1): + upsamples.append(ResidualBlock(in_dim, out_dim, dropout)) + if scale in attn_scales: + upsamples.append(AttentionBlock(out_dim)) + in_dim = out_dim + + # upsample block + if i != len(dim_mult) - 1: + mode = 'upsample3d' if temperal_upsample[i] else 'upsample2d' + upsamples.append(Resample(out_dim, mode=mode)) + scale *= 2.0 + self.upsamples = nn.Sequential(*upsamples) + + # output blocks + self.head = nn.Sequential(RMS_norm(out_dim, images=False), nn.SiLU(), + CausalConv3d(out_dim, 3, 3, padding=1)) + + def forward(self, x, feat_cache=None, feat_idx=[0]): + ## conv1 + if feat_cache is not None: + idx = feat_idx[0] + cache_x = x[:, :, -CACHE_T:, :, :].clone() + if cache_x.shape[2] < 2 and feat_cache[idx] is not None: + # cache last frame of last two chunk + cache_x = torch.cat([ + feat_cache[idx][:, :, -1, :, :].unsqueeze(2).to( + cache_x.device), cache_x + ], + dim=2) + x = self.conv1(x, feat_cache[idx]) + feat_cache[idx] = cache_x + feat_idx[0] += 1 + else: + x = 
self.conv1(x) + + ## middle + for layer in self.middle: + if check_is_instance(layer, ResidualBlock) and feat_cache is not None: + x = layer(x, feat_cache, feat_idx) + else: + x = layer(x) + + ## upsamples + for layer in self.upsamples: + if feat_cache is not None: + x = layer(x, feat_cache, feat_idx) + else: + x = layer(x) + + ## head + for layer in self.head: + if check_is_instance(layer, CausalConv3d) and feat_cache is not None: + idx = feat_idx[0] + cache_x = x[:, :, -CACHE_T:, :, :].clone() + if cache_x.shape[2] < 2 and feat_cache[idx] is not None: + # cache last frame of last two chunk + cache_x = torch.cat([ + feat_cache[idx][:, :, -1, :, :].unsqueeze(2).to( + cache_x.device), cache_x + ], + dim=2) + x = layer(x, feat_cache[idx]) + feat_cache[idx] = cache_x + feat_idx[0] += 1 + else: + x = layer(x) + return x + + +def count_conv3d(model): + count = 0 + for m in model.modules(): + if check_is_instance(m, CausalConv3d): + count += 1 + return count + + +class VideoVAE_(nn.Module): + + def __init__(self, + dim=96, + z_dim=16, + dim_mult=[1, 2, 4, 4], + num_res_blocks=2, + attn_scales=[], + temperal_downsample=[False, True, True], + dropout=0.0): + super().__init__() + self.dim = dim + self.z_dim = z_dim + self.dim_mult = dim_mult + self.num_res_blocks = num_res_blocks + self.attn_scales = attn_scales + self.temperal_downsample = temperal_downsample + self.temperal_upsample = temperal_downsample[::-1] + + # modules + self.encoder = Encoder3d(dim, z_dim * 2, dim_mult, num_res_blocks, + attn_scales, self.temperal_downsample, dropout) + self.conv1 = CausalConv3d(z_dim * 2, z_dim * 2, 1) + self.conv2 = CausalConv3d(z_dim, z_dim, 1) + self.decoder = Decoder3d(dim, z_dim, dim_mult, num_res_blocks, + attn_scales, self.temperal_upsample, dropout) + + def forward(self, x): + mu, log_var = self.encode(x) + z = self.reparameterize(mu, log_var) + x_recon = self.decode(z) + return x_recon, mu, log_var + + def encode(self, x, scale): + self.clear_cache() + ## cache + t = 
x.shape[2] + iter_ = 1 + (t - 1) // 4 + + for i in range(iter_): + self._enc_conv_idx = [0] + if i == 0: + out = self.encoder(x[:, :, :1, :, :], + feat_cache=self._enc_feat_map, + feat_idx=self._enc_conv_idx) + else: + out_ = self.encoder(x[:, :, 1 + 4 * (i - 1):1 + 4 * i, :, :], + feat_cache=self._enc_feat_map, + feat_idx=self._enc_conv_idx) + out = torch.cat([out, out_], 2) + mu, log_var = self.conv1(out).chunk(2, dim=1) + if isinstance(scale[0], torch.Tensor): + scale = [s.to(dtype=mu.dtype, device=mu.device) for s in scale] + mu = (mu - scale[0].view(1, self.z_dim, 1, 1, 1)) * scale[1].view( + 1, self.z_dim, 1, 1, 1) + else: + scale = scale.to(dtype=mu.dtype, device=mu.device) + mu = (mu - scale[0]) * scale[1] + return mu + + def decode(self, z, scale): + self.clear_cache() + # z: [b,c,t,h,w] + if isinstance(scale[0], torch.Tensor): + scale = [s.to(dtype=z.dtype, device=z.device) for s in scale] + z = z / scale[1].view(1, self.z_dim, 1, 1, 1) + scale[0].view( + 1, self.z_dim, 1, 1, 1) + else: + scale = scale.to(dtype=z.dtype, device=z.device) + z = z / scale[1] + scale[0] + iter_ = z.shape[2] + x = self.conv2(z) + for i in range(iter_): + self._conv_idx = [0] + if i == 0: + out = self.decoder(x[:, :, i:i + 1, :, :], + feat_cache=self._feat_map, + feat_idx=self._conv_idx) + else: + out_ = self.decoder(x[:, :, i:i + 1, :, :], + feat_cache=self._feat_map, + feat_idx=self._conv_idx) + out = torch.cat([out, out_], 2) # may add tensor offload + return out + + + def stream_decode(self, z, scale): + # self.clear_cache() + # z: [b,c,t,h,w] + if isinstance(scale[0], torch.Tensor): + scale = [s.to(dtype=z.dtype, device=z.device) for s in scale] + z = z / scale[1].view(1, self.z_dim, 1, 1, 1) + scale[0].view( + 1, self.z_dim, 1, 1, 1) + else: + scale = scale.to(dtype=z.dtype, device=z.device) + z = z / scale[1] + scale[0] + iter_ = z.shape[2] + x = self.conv2(z) + for i in range(iter_): + self._conv_idx = [0] + if i == 0: + out = self.decoder(x[:, :, i:i + 1, :, :], + 
feat_cache=self._feat_map, + feat_idx=self._conv_idx) + else: + out_ = self.decoder(x[:, :, i:i + 1, :, :], + feat_cache=self._feat_map, + feat_idx=self._conv_idx) + out = torch.cat([out, out_], 2) # may add tensor offload + return out + + def reparameterize(self, mu, log_var): + std = torch.exp(0.5 * log_var) + eps = torch.randn_like(std) + return eps * std + mu + + def sample(self, imgs, deterministic=False): + mu, log_var = self.encode(imgs) + if deterministic: + return mu + std = torch.exp(0.5 * log_var.clamp(-30.0, 20.0)) + return mu + std * torch.randn_like(std) + + def clear_cache(self): + self._conv_num = count_conv3d(self.decoder) + self._conv_idx = [0] + self._feat_map = [None] * self._conv_num + # print('self._feat_map', len(self._feat_map)) + # cache encode + if self.encoder is not None: + self._enc_conv_num = count_conv3d(self.encoder) + self._enc_conv_idx = [0] + self._enc_feat_map = [None] * self._enc_conv_num + # print('self._enc_feat_map', len(self._enc_feat_map)) + + +class WanVideoVAE(nn.Module): + + def __init__(self, z_dim=16, dim=96): + super().__init__() + + mean = [ + -0.7571, -0.7089, -0.9113, 0.1075, -0.1745, 0.9653, -0.1517, 1.5508, + 0.4134, -0.0715, 0.5517, -0.3632, -0.1922, -0.9497, 0.2503, -0.2921 + ] + std = [ + 2.8184, 1.4541, 2.3275, 2.6558, 1.2196, 1.7708, 2.6052, 2.0743, + 3.2687, 2.1526, 2.8652, 1.5579, 1.6382, 1.1253, 2.8251, 1.9160 + ] + self.mean = torch.tensor(mean) + self.std = torch.tensor(std) + self.scale = [self.mean, 1.0 / self.std] + + # init model + self.model = VideoVAE_(z_dim=z_dim, dim = dim).eval().requires_grad_(False) + self.upsampling_factor = 8 + + + def build_1d_mask(self, length, left_bound, right_bound, border_width): + x = torch.ones((length,)) + if not left_bound: + x[:border_width] = (torch.arange(border_width) + 1) / border_width + if not right_bound: + x[-border_width:] = torch.flip((torch.arange(border_width) + 1) / border_width, dims=(0,)) + return x + + + def build_mask(self, data, is_bound, 
def tiled_decode(self, hidden_states, device, tile_size, tile_stride):
    """Decode a latent video tile by tile and blend the overlapping tiles.

    Each spatial tile is moved to ``device``, decoded with
    ``self.model.decode`` and accumulated on the CPU, weighted by the mask
    from ``self.build_mask``. The weighted sum is divided by the summed
    weights and clamped to [-1, 1].

    Args:
        hidden_states: 5-D latent tensor; the last two axes are tiled.
        device: Device on which each tile is decoded.
        tile_size: ``(height, width)`` of a tile, in latent units.
        tile_stride: ``(height, width)`` step between tiles, in latent units.

    Returns:
        A CPU tensor of shape ``(1, 3, T*4-3, H*factor, W*factor)`` where
        ``factor`` is ``self.upsampling_factor``.
    """
    _, _, T, H, W = hidden_states.shape
    size_h, size_w = tile_size
    stride_h, stride_w = tile_stride
    factor = self.upsampling_factor

    def _tile_starts(extent, size, stride):
        # Drop a start offset when the previous tile already reached the edge.
        return [
            pos for pos in range(0, extent, stride)
            if not (pos - stride >= 0 and pos - stride + size >= extent)
        ]

    # (top, bottom, left, right) of every tile, row-major.
    tiles = [
        (top, top + size_h, left, left + size_w)
        for top in _tile_starts(H, size_h, stride_h)
        for left in _tile_starts(W, size_w, stride_w)
    ]

    storage_device = "cpu"
    out_T = T * 4 - 3
    canvas_hw = (H * factor, W * factor)
    weight = torch.zeros((1, 1, out_T) + canvas_hw,
                         dtype=hidden_states.dtype, device=storage_device)
    values = torch.zeros((1, 3, out_T) + canvas_hw,
                         dtype=hidden_states.dtype, device=storage_device)
    # Width of the blending ramp, expressed in output pixels.
    border = ((size_h - stride_h) * factor, (size_w - stride_w) * factor)

    for top, bottom, left, right in tqdm(tiles, desc="VAE decoding"):
        latent_tile = hidden_states[:, :, :, top:bottom, left:right].to(device)
        decoded = self.model.decode(latent_tile, self.scale).to(storage_device)

        mask = self.build_mask(
            decoded,
            is_bound=(top == 0, bottom >= H, left == 0, right >= W),
            border_width=border,
        ).to(dtype=hidden_states.dtype, device=storage_device)

        rows = slice(top * factor, top * factor + decoded.shape[3])
        cols = slice(left * factor, left * factor + decoded.shape[4])
        values[:, :, :, rows, cols] += decoded * mask
        weight[:, :, :, rows, cols] += mask

    values = values / weight
    values = values.clamp_(-1, 1)
    return values
def encode(self, videos, device, tiled=False, tile_size=(34, 34), tile_stride=(18, 16)):
    """Encode a sequence of videos into stacked latents.

    Every video is moved to the CPU, given a leading batch axis, encoded
    either tile by tile or in one pass, and the per-video results are
    stacked along a new first axis.

    Args:
        videos: Iterable of video tensors (without a batch axis).
        device: Device on which the encoder runs.
        tiled: Use ``self.tiled_encode`` instead of ``self.single_encode``.
        tile_size: Tile ``(height, width)``; multiplied by 8 before being
            handed to ``tiled_encode``.
        tile_stride: Tile step ``(height, width)``; multiplied by 8 likewise.

    Returns:
        ``torch.stack`` of the encoded videos.
    """
    videos = [video.to("cpu") for video in videos]

    # Scale the tile geometry once, before the loop. Rescaling inside the
    # loop multiplied it by 8 again for every video after the first.
    if tiled:
        scaled_tile_size = (tile_size[0] * 8, tile_size[1] * 8)
        scaled_tile_stride = (tile_stride[0] * 8, tile_stride[1] * 8)

    hidden_states = []
    for video in videos:
        video = video.unsqueeze(0)
        if tiled:
            hidden_state = self.tiled_encode(video, device, scaled_tile_size, scaled_tile_stride)
        else:
            hidden_state = self.single_encode(video, device)
        hidden_states.append(hidden_state.squeeze(0))
    return torch.stack(hidden_states)
def check_resize_height_width(self, height, width):
    """Round ``height`` and ``width`` up to the pipeline's division factors.

    A dimension that is already a multiple of its factor is returned
    unchanged; otherwise it is rounded up to the next multiple and a notice
    is printed.

    Returns:
        The ``(height, width)`` pair after rounding.
    """
    def _round_up(value, factor, label):
        if value % factor == 0:
            return value
        rounded = (value + factor - 1) // factor * factor
        print(f"The {label} cannot be evenly divided by {factor}. We round it up to {rounded}.")
        return rounded

    # Height is handled first so the notices appear in height, width order.
    new_height = _round_up(height, self.height_division_factor, "height")
    new_width = _round_up(width, self.width_division_factor, "width")
    return new_height, new_width
def extend_prompt(self, prompt, local_prompts, masks, mask_scales):
    """Merge the prompter's extension of ``prompt`` into the local prompt lists.

    ``self.prompter.extend_prompt(prompt)`` returns a dict; its optional
    ``"prompt"`` entry replaces the prompt, and its optional ``"prompts"``
    and ``"masks"`` entries are appended to ``local_prompts`` and ``masks``.
    Every appended mask gets a scale of 100.0.

    Args:
        prompt: The global prompt.
        local_prompts: Existing local prompts, or None.
        masks: Existing masks, or None.
        mask_scales: Existing mask scales, or None.

    Returns:
        ``(prompt, local_prompts, masks, mask_scales)`` where the three lists
        are new objects; the caller's lists are left untouched.
    """
    extended = self.prompter.extend_prompt(prompt)
    extra_masks = list(extended.get("masks", []))
    prompt = extended.get("prompt", prompt)
    # Build fresh lists instead of ``+=``. The in-place form extended the
    # caller's own lists, so reusing them across calls accumulated entries.
    local_prompts = list(local_prompts or []) + list(extended.get("prompts", []))
    masks = list(masks or []) + extra_masks
    mask_scales = list(mask_scales or []) + [100.0] * len(extra_masks)
    return prompt, local_prompts, masks, mask_scales
def generate_noise(self, shape, seed=None, device="cpu", dtype=torch.float16):
    """Draw standard-normal noise, optionally from a seeded generator.

    Args:
        shape: Shape of the tensor to create.
        seed: When not None, a ``torch.Generator`` on ``device`` is seeded
            with it so the draw is reproducible.
        device: Device for both the generator and the result.
        dtype: Element type of the result.

    Returns:
        The noise tensor produced by ``torch.randn``.
    """
    rng = None
    if seed is not None:
        rng = torch.Generator(device)
        rng.manual_seed(seed)
    return torch.randn(shape, generator=rng, device=device, dtype=dtype)
def _make_gaussian3x3_kernel(dtype, device) -> torch.Tensor:
    """Return the fixed 3x3 blur kernel as a tensor of ``dtype`` on ``device``.

    The rows are [1, 2, 1] / 16, [2, 4, 2] / 16 and [1, 2, 1] / 16, so the
    nine weights sum to 1.
    """
    outer_row = [0.0625, 0.125, 0.0625]
    centre_row = [0.125, 0.25, 0.125]
    return torch.tensor([outer_row, centre_row, outer_row], dtype=dtype, device=device)
def forward(
    self,
    hq_image: torch.Tensor,       # (B, C, f, H, W)
    lq_image: torch.Tensor,       # (B, C, f, H, W)
    clip_range: Tuple[float, float] = (-1.0, 1.0),
    method: Literal['wavelet', 'adain'] = 'wavelet',
    chunk_size: Optional[int] = None,
) -> torch.Tensor:
    """Transfer the colour of ``lq_image`` onto ``hq_image`` frame by frame.

    Frames are folded into the batch axis, corrected with the chosen method,
    clamped to ``clip_range`` and unfolded again. When ``chunk_size`` is
    given and smaller than the frame count, the clip is processed in
    consecutive chunks along the frame axis and the results are concatenated.

    Args:
        hq_image: Tensor of shape (B, 3, f, H, W) supplying the content.
        lq_image: Tensor of the same shape supplying the colour reference.
        clip_range: ``(min, max)`` passed to ``torch.clamp``.
        method: ``'wavelet'`` or ``'adain'``.
        chunk_size: Maximum number of frames per chunk, or None for one pass.

    Returns:
        The corrected tensor, same shape as ``hq_image``.

    Raises:
        ValueError: If ``method`` is neither ``'wavelet'`` nor ``'adain'``.
    """
    assert hq_image.shape == lq_image.shape, "HQ 与 LQ 的形状必须一致"
    assert hq_image.dim() == 5 and hq_image.shape[1] == 3, "输入必须是 (B, 3, f, H, W)"

    def _correct(hq_clip, lq_clip):
        # Fold time into batch, apply the per-frame correction, unfold.
        hq_frames, batch, n_frames = self._flatten_time(hq_clip)
        lq_frames, _, _ = self._flatten_time(lq_clip)
        if method == 'wavelet':
            fused = _wavelet_reconstruct(hq_frames, lq_frames, levels=self.levels)
        elif method == 'adain':
            fused = _adain(hq_frames, lq_frames)
        else:
            raise ValueError(f"未知 method: {method}")
        fused = torch.clamp(fused, *clip_range)
        return self._unflatten_time(fused, batch, n_frames)

    total_frames = hq_image.shape[2]
    if chunk_size is None or chunk_size >= total_frames:
        return _correct(hq_image, lq_image)

    pieces = []
    for begin in range(0, total_frames, chunk_size):
        stop = min(begin + chunk_size, total_frames)
        pieces.append(_correct(hq_image[:, :, begin:stop], lq_image[:, :, begin:stop]))
    return torch.cat(pieces, dim=2)
def enable_vram_management(self, num_persistent_param_in_dit=None):
    """Wrap the DiT and VAE sub-modules for offloading, then enable CPU offload.

    Both models are passed to the project's ``enable_vram_management``
    helper with a per-module-type wrapper map. Parameters are kept in their
    current dtype, offloaded to the CPU, and computed in ``self.torch_dtype``
    on ``self.device``. Only the DiT receives ``max_num_param`` and an
    overflow configuration, whose onload device is the CPU.

    Args:
        num_persistent_param_in_dit: Forwarded as ``max_num_param`` for the DiT.
    """
    # Only the DiT and the VAE are managed.
    dit_dtype = next(iter(self.dit.parameters())).dtype
    from ..vram_management import enable_vram_management, AutoWrappedModule, AutoWrappedLinear

    def _module_config(param_dtype, onload_device):
        return dict(
            offload_dtype=param_dtype,
            offload_device="cpu",
            onload_dtype=param_dtype,
            onload_device=onload_device,
            computation_dtype=self.torch_dtype,
            computation_device=self.device,
        )

    dit_wrappers = {
        torch.nn.Linear: AutoWrappedLinear,
        torch.nn.Conv3d: AutoWrappedModule,
        torch.nn.LayerNorm: AutoWrappedModule,
        RMSNorm: AutoWrappedModule,
    }
    enable_vram_management(
        self.dit,
        module_map=dit_wrappers,
        module_config=_module_config(dit_dtype, self.device),
        max_num_param=num_persistent_param_in_dit,
        overflow_module_config=_module_config(dit_dtype, "cpu"),
    )

    vae_dtype = next(iter(self.vae.parameters())).dtype
    vae_wrappers = {
        torch.nn.Linear: AutoWrappedLinear,
        torch.nn.Conv2d: AutoWrappedModule,
        RMS_norm: AutoWrappedModule,
        CausalConv3d: AutoWrappedModule,
        Upsample: AutoWrappedModule,
        torch.nn.SiLU: AutoWrappedModule,
        torch.nn.Dropout: AutoWrappedModule,
    }
    enable_vram_management(
        self.vae,
        module_map=vae_wrappers,
        module_config=_module_config(vae_dtype, self.device),
    )
    self.enable_cpu_offload()
@torch.no_grad()
def __call__(
    self,
    prompt=None,
    negative_prompt="",
    denoising_strength=1.0,
    seed=None,
    rand_device="gpu",
    height=480,
    width=832,
    num_frames=81,
    cfg_scale=5.0,
    num_inference_steps=50,
    sigma_shift=5.0,
    tiled=True,
    tile_size=(60, 104),
    tile_stride=(30, 52),
    tea_cache_l1_thresh=None,
    tea_cache_model_id="Wan2.1-T2V-1.3B",
    progress_bar_cmd=tqdm,
    progress_bar_st=None,
    LQ_video=None,
    is_full_block=False,
    if_buffer=False,
    topk_ratio=2.0,
    kv_ratio=3.0,
    local_range = 9,
    color_fix = True,
    unload_dit = False,
    force_offload = False,
):
    """Run streaming super-resolution on ``LQ_video`` and return the frames.

    The latent noise is denoised in one step per temporal window (six latent
    frames for the first window, two for each following one), with the
    low-quality video injected through ``dit.LQ_proj_in``. The denoised
    windows are concatenated, decoded by the VAE and optionally
    colour-corrected against ``LQ_video``.

    ``init_cross_kv`` must have been called beforehand.

    Args:
        seed: Seed for the initial noise (None for unseeded).
        height, width: Output size; the latent uses ``height // 8`` and
            ``width // 8``.
        num_frames: Frame count; rounded up to ``4 * k + 1`` when needed.
        cfg_scale: Must be exactly 1.0.
        tiled, tile_size, tile_stride: Forwarded to ``decode_video``.
        progress_bar_cmd: Callable wrapping the window iterator.
        LQ_video: 5-D low-quality input, sliced along its third axis.
            NOTE(review): presumably (B, C, f, H, W) - confirm with caller.
        is_full_block, topk_ratio, kv_ratio, local_range: Forwarded to
            ``model_fn_wan_video``.
        if_buffer: When true, the noise has one latent frame fewer.
        color_fix: Apply AdaIN colour correction against ``LQ_video``.
        unload_dit: Offload the DiT (keeping the VAE) before decoding.
        force_offload: Offload all models after decoding.

    The remaining parameters (``prompt``, ``negative_prompt``,
    ``denoising_strength``, ``rand_device``, ``num_inference_steps``,
    ``sigma_shift``, ``tea_cache_l1_thresh``, ``tea_cache_model_id``,
    ``progress_bar_st``) are accepted but not read by this implementation.

    Returns:
        ``frames[0]``: the decoded (and optionally colour-corrected) video
        without its leading axis.

    Raises:
        RuntimeError: If the cross-attention context was never initialised.
    """
    # Only cfg_scale == 1.0 is supported.
    assert cfg_scale == 1.0, "cfg_scale must be 1.0"

    # init_cross_kv() has to run before the first call.
    if self.prompt_emb_posi is None or 'context' not in self.prompt_emb_posi:
        raise RuntimeError(
            "Cross-Attn KV 未初始化。请在调用 __call__ 前先执行:\n"
            " pipe.init_cross_kv()\n"
            "或传入自定义 context:\n"
            " pipe.init_cross_kv(context_tensor=your_context_tensor)"
        )

    if num_frames % 4 != 1:
        num_frames = (num_frames + 2) // 4 * 4 + 1
        # The accepted condition is `% 4 == 1`; the old message said `!= 1`.
        print(f"Only `num_frames % 4 == 1` is acceptable. We round it up to {num_frames}.")

    # Tiling parameters for the VAE decoder.
    tiler_kwargs = {"tiled": tiled, "tile_size": tile_size, "tile_stride": tile_stride}

    # Initial noise.
    if if_buffer:
        noise = self.generate_noise((1, 16, (num_frames - 1) // 4, height//8, width//8), seed=seed, device=self.device, dtype=self.torch_dtype)
    else:
        noise = self.generate_noise((1, 16, (num_frames - 1) // 4 + 1, height//8, width//8), seed=seed, device=self.device, dtype=self.torch_dtype)
    latents = noise

    process_total_num = (num_frames - 1) // 8 - 2
    is_stream = True

    # After an offload the cross-attention KV cache has to be rebuilt.
    if self.prompt_emb_posi['stats'] == "offload":
        self.init_cross_kv(context_tensor=self.prompt_emb_posi['context'])
    self.load_models_to_device(["dit", "vae"])
    self.dit.LQ_proj_in.to(self.device)

    # Drop any stale LQ_proj_in cache from a previous run.
    if hasattr(self.dit, "LQ_proj_in"):
        self.dit.LQ_proj_in.clear_cache()

    latents_total = []
    self.vae.clear_cache()

    with torch.no_grad():
        for cur_process_idx in progress_bar_cmd(range(process_total_num)):
            if cur_process_idx == 0:
                # First window: fresh KV caches and seven LQ chunks.
                pre_cache_k = [None] * len(self.dit.blocks)
                pre_cache_v = [None] * len(self.dit.blocks)
                LQ_latents = None
                inner_loop_num = 7
                for inner_idx in range(inner_loop_num):
                    cur = self.denoising_model().LQ_proj_in.stream_forward(
                        LQ_video[:, :, max(0, inner_idx*4-3):(inner_idx+1)*4-3, :, :]
                    ) if LQ_video is not None else None
                    if cur is None:
                        continue
                    if LQ_latents is None:
                        LQ_latents = cur
                    else:
                        for layer_idx in range(len(LQ_latents)):
                            LQ_latents[layer_idx] = torch.cat([LQ_latents[layer_idx], cur[layer_idx]], dim=1)
                cur_latents = latents[:, :, :6, :, :]
            else:
                # Later windows: two more LQ chunks, two more latent frames.
                LQ_latents = None
                inner_loop_num = 2
                for inner_idx in range(inner_loop_num):
                    cur = self.denoising_model().LQ_proj_in.stream_forward(
                        LQ_video[:, :, cur_process_idx*8+17+inner_idx*4:cur_process_idx*8+21+inner_idx*4, :, :]
                    ) if LQ_video is not None else None
                    if cur is None:
                        continue
                    if LQ_latents is None:
                        LQ_latents = cur
                    else:
                        for layer_idx in range(len(LQ_latents)):
                            LQ_latents[layer_idx] = torch.cat([LQ_latents[layer_idx], cur[layer_idx]], dim=1)
                cur_latents = latents[:, :, 4+cur_process_idx*2:6+cur_process_idx*2, :, :]

            # Inference (no motion controller / VACE).
            noise_pred_posi, pre_cache_k, pre_cache_v = model_fn_wan_video(
                self.dit,
                x=cur_latents,
                timestep=self.timestep,
                context=None,
                tea_cache=None,
                use_unified_sequence_parallel=False,
                LQ_latents=LQ_latents,
                is_full_block=is_full_block,
                is_stream=is_stream,
                pre_cache_k=pre_cache_k,
                pre_cache_v=pre_cache_v,
                topk_ratio=topk_ratio,
                kv_ratio=kv_ratio,
                cur_process_idx=cur_process_idx,
                t_mod=self.t_mod,
                t=self.t,
                local_range = local_range,
            )

            # Update the latent with the single denoising step.
            cur_latents = cur_latents - noise_pred_posi
            latents_total.append(cur_latents)

        if hasattr(self.dit, "LQ_proj_in"):
            self.dit.LQ_proj_in.clear_cache()

    if unload_dit and hasattr(self, 'dit') and not next(self.dit.parameters()).is_cpu:
        print("[FlashVSR] Offloading DiT to the CPU to free up VRAM...")
        self.offload_model(keep_vae=True)

    latents = torch.cat(latents_total, dim=2)

    # Decode
    print("[FlashVSR] Starting VAE decoding...")
    frames = self.decode_video(latents, **tiler_kwargs)

    self.vae.clear_cache()
    if force_offload:
        self.offload_model()

    # Colour correction (AdaIN). Best effort: a failure keeps the uncorrected
    # frames, but it is reported, and KeyboardInterrupt / SystemExit are no
    # longer swallowed as they were by the former bare ``except``.
    if color_fix:
        try:
            frames = self.ColorCorrector(
                frames.to(device=LQ_video.device),
                LQ_video[:, :, :frames.shape[2], :, :],
                clip_range=(-1, 1),
                chunk_size=16,
                method='adain'
            )
        except Exception as exc:
            print(f"[FlashVSR] Color correction skipped: {exc!r}")

    return frames[0]
def __init__(self, num_inference_steps, rel_l1_thresh, model_id):
    """Set up TeaCache state and pick the rescaling polynomial for ``model_id``.

    Args:
        num_inference_steps: Number of denoising steps in one run.
        rel_l1_thresh: Threshold compared against the accumulated distance.
        model_id: Key selecting one of the built-in coefficient sets.

    Raises:
        ValueError: If ``model_id`` has no coefficient set.
    """
    # Run configuration.
    self.num_inference_steps = num_inference_steps
    self.rel_l1_thresh = rel_l1_thresh

    # Mutable per-run state.
    self.step = 0
    self.accumulated_rel_l1_distance = 0
    self.previous_modulated_input = None
    self.previous_residual = None
    self.previous_hidden_states = None

    # Polynomial coefficients per supported model.
    self.coefficients_dict = {
        "Wan2.1-T2V-1.3B": [-5.21862437e+04, 9.23041404e+03, -5.28275948e+02, 1.36987616e+01, -4.99875664e-02],
        "Wan2.1-T2V-14B": [-3.03318725e+05, 4.90537029e+04, -2.65530556e+03, 5.87365115e+01, -3.15583525e-01],
        "Wan2.1-I2V-14B-480P": [2.57151496e+05, -3.54229917e+04, 1.40286849e+03, -1.35890334e+01, 1.32517977e-01],
        "Wan2.1-I2V-14B-720P": [8.10705460e+03, 2.13393892e+03, -3.72934672e+02, 1.66203073e+01, -4.17769401e-02],
    }
    table = self.coefficients_dict
    if model_id not in table:
        supported_model_ids = ", ".join(table)
        raise ValueError(f"{model_id} is not a supported TeaCache model id. Please choose a valid model id in ({supported_model_ids}).")
    self.coefficients = table[model_id]
def model_fn_wan_video(
    dit: WanModel,
    x: torch.Tensor,
    timestep: torch.Tensor,
    context: torch.Tensor,
    tea_cache: Optional[TeaCache] = None,
    use_unified_sequence_parallel: bool = False,
    LQ_latents: Optional[torch.Tensor] = None,
    is_full_block: bool = False,
    is_stream: bool = False,
    pre_cache_k: Optional[list[torch.Tensor]] = None,
    pre_cache_v: Optional[list[torch.Tensor]] = None,
    topk_ratio: float = 2.0,
    kv_ratio: float = 3.0,
    cur_process_idx: int = 0,
    t_mod : torch.Tensor = None,
    t : torch.Tensor = None,
    local_range: int = 9,
    **kwargs,
):
    """Run one forward pass of the DiT over a window of latent frames.

    The input is patchified, rotary frequencies are gathered for the
    window's temporal position, every block is applied (adding the matching
    ``LQ_latents`` entry first, when present), and the head output is
    unpatchified.

    Args:
        dit: The diffusion transformer.
        x: Latent window to denoise.
        timestep: Not read in this function body.
        context: Passed through to every block.
        tea_cache: Optional TeaCache; when it reports a hit, the blocks are
            skipped and its cached residual is applied instead.
        use_unified_sequence_parallel: Shard / gather the sequence through
            xfuser when torch.distributed runs with more than one process.
        LQ_latents: Per-block tensors added to ``x`` before the block runs.
        is_full_block, is_stream, local_range: Forwarded to every block.
        pre_cache_k, pre_cache_v: Per-block caches, updated in place.
        topk_ratio: Scales the squared window size to obtain ``topk``.
        kv_ratio: Truncated to an int and passed as ``kv_len``.
        cur_process_idx: Window index; 0 starts the temporal frequencies at
            0, any other value starts them at ``4 + 2 * cur_process_idx``.
        t_mod, t: Time modulation for the blocks and embedding for the head.

    Returns:
        ``(x, pre_cache_k, pre_cache_v)``.
    """
    # Patchify.
    x, (f, h, w) = dit.patchify(x)

    window = (2, 8, 8)
    local_num = f // window[0]
    window_size = window[0] * h * w // 128
    topk = int(window_size * window_size * topk_ratio) - 1
    kv_len = int(kv_ratio)

    # RoPE positions: only the temporal offset depends on the window index.
    t_start = 0 if cur_process_idx == 0 else 4 + cur_process_idx*2
    freqs = torch.cat([
        dit.freqs[0][t_start:t_start + f].view(f, 1, 1, -1).expand(f, h, w, -1),
        dit.freqs[1][:h].view(1, h, 1, -1).expand(f, h, w, -1),
        dit.freqs[2][:w].view(1, 1, w, -1).expand(f, h, w, -1)
    ], dim=-1).reshape(f * h * w, 1, -1).to(x.device)

    # TeaCache (disabled unless an instance is supplied).
    tea_cache_update = False
    if tea_cache is not None:
        tea_cache_update = tea_cache.check(dit, x, t_mod)

    # Unified sequence parallelism: keep only this rank's sequence shard.
    if use_unified_sequence_parallel:
        import torch.distributed as dist
        from xfuser.core.distributed import (get_sequence_parallel_rank,
                                             get_sequence_parallel_world_size,
                                             get_sp_group)
        if dist.is_initialized() and dist.get_world_size() > 1:
            x = torch.chunk(x, get_sequence_parallel_world_size(), dim=1)[get_sequence_parallel_rank()]

    # Transformer blocks.
    if tea_cache_update:
        x = tea_cache.update(x)
    else:
        for block_id, block in enumerate(dit.blocks):
            if LQ_latents is not None and block_id < len(LQ_latents):
                x = x + LQ_latents[block_id]
            cached_k = None if pre_cache_k is None else pre_cache_k[block_id]
            cached_v = None if pre_cache_v is None else pre_cache_v[block_id]
            x, last_pre_cache_k, last_pre_cache_v = block(
                x, context, t_mod, freqs, f, h, w,
                local_num, topk,
                block_id=block_id,
                kv_len=kv_len,
                is_full_block=is_full_block,
                is_stream=is_stream,
                pre_cache_k=cached_k,
                pre_cache_v=cached_v,
                local_range = local_range,
            )
            if pre_cache_k is not None:
                pre_cache_k[block_id] = last_pre_cache_k
            if pre_cache_v is not None:
                pre_cache_v[block_id] = last_pre_cache_v

    x = dit.head(x, t)
    if use_unified_sequence_parallel:
        import torch.distributed as dist
        from xfuser.core.distributed import get_sp_group
        if dist.is_initialized() and dist.get_world_size() > 1:
            x = get_sp_group().all_gather(x, dim=1)
    x = dit.unpatchify(x, (f, h, w))
    return x, pre_cache_k, pre_cache_v
..models.wan_video_dit import WanModel, RMSNorm, sinusoidal_embedding_1d +from ..models.wan_video_vae import WanVideoVAE, RMS_norm, CausalConv3d, Upsample +from ..schedulers.flow_match import FlowMatchScheduler +from .base import BasePipeline + + +# ----------------------------- +# 基础工具:ADAIN 所需的统计量(保留以备需要;管线默认用 wavelet) +# ----------------------------- +def _calc_mean_std(feat: torch.Tensor, eps: float = 1e-5) -> Tuple[torch.Tensor, torch.Tensor]: + assert feat.dim() == 4, 'feat 必须是 (N, C, H, W)' + N, C = feat.shape[:2] + var = feat.view(N, C, -1).var(dim=2, unbiased=False) + eps + std = var.sqrt().view(N, C, 1, 1) + mean = feat.view(N, C, -1).mean(dim=2).view(N, C, 1, 1) + return mean, std + + +def _adain(content_feat: torch.Tensor, style_feat: torch.Tensor) -> torch.Tensor: + assert content_feat.shape[:2] == style_feat.shape[:2], "ADAIN: N、C 必须匹配" + size = content_feat.size() + style_mean, style_std = _calc_mean_std(style_feat) + content_mean, content_std = _calc_mean_std(content_feat) + normalized = (content_feat - content_mean.expand(size)) / content_std.expand(size) + return normalized * style_std.expand(size) + style_mean.expand(size) + + +# ----------------------------- +# 小波式模糊与分解/重构(ColorCorrector 用) +# ----------------------------- +def _make_gaussian3x3_kernel(dtype, device) -> torch.Tensor: + vals = [ + [0.0625, 0.125, 0.0625], + [0.125, 0.25, 0.125 ], + [0.0625, 0.125, 0.0625], + ] + return torch.tensor(vals, dtype=dtype, device=device) + + +def _wavelet_blur(x: torch.Tensor, radius: int) -> torch.Tensor: + assert x.dim() == 4, 'x 必须是 (N, C, H, W)' + N, C, H, W = x.shape + base = _make_gaussian3x3_kernel(x.dtype, x.device) + weight = base.view(1, 1, 3, 3).repeat(C, 1, 1, 1) + pad = radius + x_pad = F.pad(x, (pad, pad, pad, pad), mode='replicate') + out = F.conv2d(x_pad, weight, bias=None, stride=1, padding=0, dilation=radius, groups=C) + return out + + +def _wavelet_decompose(x: torch.Tensor, levels: int = 5) -> Tuple[torch.Tensor, torch.Tensor]: + 
assert x.dim() == 4, 'x 必须是 (N, C, H, W)' + high = torch.zeros_like(x) + low = x + for i in range(levels): + radius = 2 ** i + blurred = _wavelet_blur(low, radius) + high = high + (low - blurred) + low = blurred + return high, low + + +def _wavelet_reconstruct(content: torch.Tensor, style: torch.Tensor, levels: int = 5) -> torch.Tensor: + c_high, _ = _wavelet_decompose(content, levels=levels) + _, s_low = _wavelet_decompose(style, levels=levels) + return c_high + s_low + + +# ----------------------------- +# 无状态颜色矫正模块(视频友好,默认 wavelet) +# ----------------------------- +class TorchColorCorrectorWavelet(nn.Module): + def __init__(self, levels: int = 5): + super().__init__() + self.levels = levels + + @staticmethod + def _flatten_time(x: torch.Tensor) -> Tuple[torch.Tensor, int, int]: + assert x.dim() == 5, '输入必须是 (B, C, f, H, W)' + B, C, f, H, W = x.shape + y = x.permute(0, 2, 1, 3, 4).reshape(B * f, C, H, W) + return y, B, f + + @staticmethod + def _unflatten_time(y: torch.Tensor, B: int, f: int) -> torch.Tensor: + BF, C, H, W = y.shape + assert BF == B * f + return y.reshape(B, f, C, H, W).permute(0, 2, 1, 3, 4) + + def forward( + self, + hq_image: torch.Tensor, # (B, C, f, H, W) + lq_image: torch.Tensor, # (B, C, f, H, W) + clip_range: Tuple[float, float] = (-1.0, 1.0), + method: Literal['wavelet', 'adain'] = 'wavelet', + chunk_size: Optional[int] = None, + ) -> torch.Tensor: + assert hq_image.shape == lq_image.shape, "HQ 与 LQ 的形状必须一致" + assert hq_image.dim() == 5 and hq_image.shape[1] == 3, "输入必须是 (B, 3, f, H, W)" + + B, C, f, H, W = hq_image.shape + if chunk_size is None or chunk_size >= f: + hq4, B, f = self._flatten_time(hq_image) + lq4, _, _ = self._flatten_time(lq_image) + if method == 'wavelet': + out4 = _wavelet_reconstruct(hq4, lq4, levels=self.levels) + elif method == 'adain': + out4 = _adain(hq4, lq4) + else: + raise ValueError(f"未知 method: {method}") + out4 = torch.clamp(out4, *clip_range) + out = self._unflatten_time(out4, B, f) + return out + + outs = 
[] + for start in range(0, f, chunk_size): + end = min(start + chunk_size, f) + hq_chunk = hq_image[:, :, start:end] + lq_chunk = lq_image[:, :, start:end] + hq4, B_, f_ = self._flatten_time(hq_chunk) + lq4, _, _ = self._flatten_time(lq_chunk) + if method == 'wavelet': + out4 = _wavelet_reconstruct(hq4, lq4, levels=self.levels) + elif method == 'adain': + out4 = _adain(hq4, lq4) + else: + raise ValueError(f"未知 method: {method}") + out4 = torch.clamp(out4, *clip_range) + out_chunk = self._unflatten_time(out4, B_, f_) + outs.append(out_chunk) + out = torch.cat(outs, dim=2) + return out + + +# ----------------------------- +# 简化版 Pipeline(仅 dit + vae) +# ----------------------------- +class FlashVSRTinyPipeline(BasePipeline): + + def __init__(self, device="cuda", torch_dtype=torch.float16): + super().__init__(device=device, torch_dtype=torch_dtype) + self.scheduler = FlowMatchScheduler(shift=5, sigma_min=0.0, extra_one_step=True) + self.dit: WanModel = None + self.vae: WanVideoVAE = None + self.model_names = ['dit', 'vae'] + self.height_division_factor = 16 + self.width_division_factor = 16 + self.use_unified_sequence_parallel = False + self.prompt_emb_posi = None + self.ColorCorrector = TorchColorCorrectorWavelet(levels=5) + + print(r""" + ███████╗██╗ █████╗ ███████╗██╗ ██╗██╗ ██╗███████╗█████╗ + ██╔════╝██║ ██╔══██╗██╔════╝██║ ██║██║ ██║██╔════╝██╔══██╗ ██╗ + █████╗ ██║ ███████║███████╗███████║╚██╗ ██╔╝███████╗███████║ ██████╗ + ██╔══╝ ██║ ██╔══██║╚════██║██╔══██║ ╚████╔╝ ╚════██║██╔═██║ ██╔═╝ + ██║ ███████╗██║ ██║███████║██║ ██║ ╚██╔╝ ███████║██║ ██║ ╚═╝ + ╚═╝ ╚══════╝╚═╝ ╚═╝╚══════╝╚═╝ ╚═╝ ╚═╝ ╚══════╝╚═╝ ╚═╝ +""") + + def enable_vram_management(self, num_persistent_param_in_dit=None): + # 仅管理 dit / vae + dtype = next(iter(self.dit.parameters())).dtype + from ..vram_management import enable_vram_management, AutoWrappedModule, AutoWrappedLinear + enable_vram_management( + self.dit, + module_map={ + torch.nn.Linear: AutoWrappedLinear, + torch.nn.Conv3d: 
AutoWrappedModule, + torch.nn.LayerNorm: AutoWrappedModule, + RMSNorm: AutoWrappedModule, + }, + module_config=dict( + offload_dtype=dtype, + offload_device="cpu", + onload_dtype=dtype, + onload_device=self.device, + computation_dtype=self.torch_dtype, + computation_device=self.device, + ), + max_num_param=num_persistent_param_in_dit, + overflow_module_config=dict( + offload_dtype=dtype, + offload_device="cpu", + onload_dtype=dtype, + onload_device="cpu", + computation_dtype=self.torch_dtype, + computation_device=self.device, + ), + ) + self.enable_cpu_offload() + + def fetch_models(self, model_manager: ModelManager): + self.dit = model_manager.fetch_model("wan_video_dit") + self.vae = model_manager.fetch_model("wan_video_vae") + + @staticmethod + def from_model_manager(model_manager: ModelManager, torch_dtype=None, device=None, use_usp=False): + if device is None: device = model_manager.device + if torch_dtype is None: torch_dtype = model_manager.torch_dtype + pipe = FlashVSRTinyPipeline(device=device, torch_dtype=torch_dtype) + pipe.fetch_models(model_manager) + # 可选:统一序列并行入口(此处默认关闭) + pipe.use_unified_sequence_parallel = False + return pipe + + def denoising_model(self): + return self.dit + + # ------------------------- + # 新增:显式 KV 预初始化函数 + # ------------------------- + def init_cross_kv( + self, + context_tensor: Optional[torch.Tensor] = None, + prompt_path = None, + ): + self.load_models_to_device(["dit"]) + """ + 使用固定 prompt 生成文本 context,并在 WanModel 中初始化所有 CrossAttention 的 KV 缓存。 + 必须在 __call__ 前显式调用一次。 + """ + #prompt_path = "../../examples/WanVSR/prompt_tensor/posi_prompt.pth" + + if self.dit is None: + raise RuntimeError("请先通过 fetch_models / from_model_manager 初始化 self.dit") + + if context_tensor is None: + if prompt_path is None: + raise ValueError("init_cross_kv: 需要提供 prompt_path 或 context_tensor 其一") + ctx = torch.load(prompt_path, map_location=self.device) + else: + ctx = context_tensor + + ctx = ctx.to(dtype=self.torch_dtype, device=self.device) + + 
if self.prompt_emb_posi is None: + self.prompt_emb_posi = {} + self.prompt_emb_posi['context'] = ctx + self.prompt_emb_posi['stats'] = "load" + + if hasattr(self.dit, "reinit_cross_kv"): + self.dit.reinit_cross_kv(ctx) + else: + raise AttributeError("WanModel 缺少 reinit_cross_kv(ctx) 方法,请在模型实现中加入该能力。") + self.timestep = torch.tensor([1000.], device=self.device, dtype=self.torch_dtype) + self.t = self.dit.time_embedding(sinusoidal_embedding_1d(self.dit.freq_dim, self.timestep)) + self.t_mod = self.dit.time_projection(self.t).unflatten(1, (6, self.dit.dim)) + # Scheduler + self.scheduler.set_timesteps(1, denoising_strength=1.0, shift=5.0) + self.load_models_to_device([]) + + def prepare_unified_sequence_parallel(self): + return {"use_unified_sequence_parallel": self.use_unified_sequence_parallel} + + def prepare_extra_input(self, latents=None): + return {} + + def encode_video(self, input_video, tiled=True, tile_size=(34, 34), tile_stride=(18, 16)): + latents = self.vae.encode(input_video, device=self.device, tiled=tiled, tile_size=tile_size, tile_stride=tile_stride) + return latents + + def _decode_video(self, latents, tiled=True, tile_size=(34, 34), tile_stride=(18, 16)): + frames = self.vae.decode(latents, device=self.device, tiled=tiled, tile_size=tile_size, tile_stride=tile_stride) + return frames + + def decode_video(self, latents, cond=None, **kwargs): + frames = self.TCDecoder.decode_video( + latents.transpose(1, 2), # TCDecoder 需要 (B, F, C, H, W) + parallel=False, + show_progress_bar=False, + cond=cond + ).transpose(1, 2).mul_(2).sub_(1) # 转回 (B, C, F, H, W) 格式,范围 -1 to 1 + + return frames + + def offload_model(self, keep_vae=False): + self.dit.clear_cross_kv() + self.prompt_emb_posi['stats'] = "offload" + self.load_models_to_device([]) + if hasattr(self.dit, "LQ_proj_in"): + self.dit.LQ_proj_in.to('cpu') + if not keep_vae: + self.TCDecoder.to('cpu') + + @torch.no_grad() + def __call__( + self, + prompt=None, + negative_prompt="", + denoising_strength=1.0, + 
seed=None, + rand_device="gpu", + height=480, + width=832, + num_frames=81, + cfg_scale=5.0, + num_inference_steps=50, + sigma_shift=5.0, + tiled=True, + tile_size=(60, 104), + tile_stride=(30, 52), + tea_cache_l1_thresh=None, + tea_cache_model_id="Wan2.1-T2V-1.3B", + progress_bar_cmd=tqdm, + progress_bar_st=None, + LQ_video=None, + is_full_block=False, + if_buffer=False, + topk_ratio=2.0, + kv_ratio=3.0, + local_range = 9, + color_fix = True, + unload_dit = False, + force_offload = False, + ): + # 只接受 cfg=1.0(与原代码一致) + assert cfg_scale == 1.0, "cfg_scale must be 1.0" + + # 要求:必须先 init_cross_kv() + if self.prompt_emb_posi is None or 'context' not in self.prompt_emb_posi: + raise RuntimeError( + "Cross-Attn KV 未初始化。请在调用 __call__ 前先执行:\n" + " pipe.init_cross_kv()\n" + "或传入自定义 context:\n" + " pipe.init_cross_kv(context_tensor=your_context_tensor)" + ) + + # 尺寸修正 + height, width = self.check_resize_height_width(height, width) + if num_frames % 4 != 1: + num_frames = (num_frames + 2) // 4 * 4 + 1 + print(f"Only `num_frames % 4 != 1` is acceptable. 
We round it up to {num_frames}.") + + # Tiler 参数 + tiler_kwargs = {"tiled": tiled, "tile_size": tile_size, "tile_stride": tile_stride} + + # 初始化噪声 + if if_buffer: + noise = self.generate_noise((1, 16, (num_frames - 1) // 4, height//8, width//8), seed=seed, device=self.device, dtype=self.torch_dtype) + else: + noise = self.generate_noise((1, 16, (num_frames - 1) // 4 + 1, height//8, width//8), seed=seed, device=self.device, dtype=self.torch_dtype) + # noise = noise.to(dtype=self.torch_dtype, device=self.device) + latents = noise + + process_total_num = (num_frames - 1) // 8 - 2 + is_stream = True + + if self.prompt_emb_posi['stats'] == "offload": + self.init_cross_kv(context_tensor=self.prompt_emb_posi['context']) + self.load_models_to_device(["dit"]) + self.dit.LQ_proj_in.to(self.device) + self.TCDecoder.to(self.device) + + # 清理可能存在的 LQ_proj_in cache + if hasattr(self.dit, "LQ_proj_in"): + self.dit.LQ_proj_in.clear_cache() + + latents_total = [] + self.TCDecoder.clean_mem() + LQ_pre_idx = 0 + LQ_cur_idx = 0 + + with torch.no_grad(): + for cur_process_idx in progress_bar_cmd(range(process_total_num)): + if cur_process_idx == 0: + pre_cache_k = [None] * len(self.dit.blocks) + pre_cache_v = [None] * len(self.dit.blocks) + LQ_latents = None + inner_loop_num = 7 + for inner_idx in range(inner_loop_num): + cur = self.denoising_model().LQ_proj_in.stream_forward( + LQ_video[:, :, max(0, inner_idx*4-3):(inner_idx+1)*4-3, :, :] + ) if LQ_video is not None else None + if cur is None: + continue + if LQ_latents is None: + LQ_latents = cur + else: + for layer_idx in range(len(LQ_latents)): + LQ_latents[layer_idx] = torch.cat([LQ_latents[layer_idx], cur[layer_idx]], dim=1) + LQ_cur_idx = (inner_loop_num-1)*4-3 + cur_latents = latents[:, :, :6, :, :] + else: + LQ_latents = None + inner_loop_num = 2 + for inner_idx in range(inner_loop_num): + cur = self.denoising_model().LQ_proj_in.stream_forward( + LQ_video[:, :, cur_process_idx*8+17+inner_idx*4:cur_process_idx*8+21+inner_idx*4, 
:, :] + ) if LQ_video is not None else None + if cur is None: + continue + if LQ_latents is None: + LQ_latents = cur + else: + for layer_idx in range(len(LQ_latents)): + LQ_latents[layer_idx] = torch.cat([LQ_latents[layer_idx], cur[layer_idx]], dim=1) + LQ_cur_idx = cur_process_idx*8+21+(inner_loop_num-2)*4 + cur_latents = latents[:, :, 4+cur_process_idx*2:6+cur_process_idx*2, :, :] + + # 推理(无 motion_controller / vace) + noise_pred_posi, pre_cache_k, pre_cache_v = model_fn_wan_video( + self.dit, + x=cur_latents, + timestep=self.timestep, + context=None, + tea_cache=None, + use_unified_sequence_parallel=False, + LQ_latents=LQ_latents, + is_full_block=is_full_block, + is_stream=is_stream, + pre_cache_k=pre_cache_k, + pre_cache_v=pre_cache_v, + topk_ratio=topk_ratio, + kv_ratio=kv_ratio, + cur_process_idx=cur_process_idx, + t_mod=self.t_mod, + t=self.t, + local_range = local_range, + ) + + # 更新 latent + cur_latents = cur_latents - noise_pred_posi + latents_total.append(cur_latents) + LQ_pre_idx = LQ_cur_idx + + if hasattr(self.dit, "LQ_proj_in"): + self.dit.LQ_proj_in.clear_cache() + + if unload_dit and hasattr(self, 'dit') and not next(self.dit.parameters()).is_cpu: + print("[FlashVSR] Offloading DiT to the CPU to free up VRAM...") + self.offload_model(keep_vae=True) + + latents = torch.cat(latents_total, dim=2) + + # Decode + print("[FlashVSR] Starting VAE decoding...") + frames = self.TCDecoder.decode_video(latents.transpose(1, 2),parallel=False, show_progress_bar=False, cond=LQ_video[:,:,:LQ_cur_idx,:,:]).transpose(1, 2).mul_(2).sub_(1) + + self.TCDecoder.clean_mem() + if force_offload: + self.offload_model() + + # 颜色校正(wavelet) + try: + if color_fix: + frames = self.ColorCorrector( + frames.to(device=LQ_video.device), + LQ_video[:, :, :frames.shape[2], :, :], + clip_range=(-1, 1), + chunk_size=16, + method='adain' + ) + except: + pass + + return frames[0] + + +# ----------------------------- +# TeaCache(保留原逻辑;此处默认不启用) +# ----------------------------- +class 
# -----------------------------
# TeaCache (original logic kept; disabled by default in this pipeline)
# -----------------------------
class TeaCache:
    """Decide per step whether the block stack can be skipped.

    When the accumulated, polynomial-rescaled relative L1 change of the time
    modulation stays below ``rel_l1_thresh`` the blocks are skipped and the
    residual stored by the last full pass is re-applied.
    """

    def __init__(self, num_inference_steps, rel_l1_thresh, model_id):
        """Set up the cache.

        Raises:
            ValueError: If ``model_id`` has no rescaling coefficients.
        """
        self.num_inference_steps = num_inference_steps
        self.step = 0
        self.accumulated_rel_l1_distance = 0
        self.previous_modulated_input = None
        self.rel_l1_thresh = rel_l1_thresh
        self.previous_residual = None
        self.previous_hidden_states = None

        # Polynomial coefficients (highest degree first) per model.
        self.coefficients_dict = {
            "Wan2.1-T2V-1.3B": [-5.21862437e+04, 9.23041404e+03, -5.28275948e+02, 1.36987616e+01, -4.99875664e-02],
            "Wan2.1-T2V-14B": [-3.03318725e+05, 4.90537029e+04, -2.65530556e+03, 5.87365115e+01, -3.15583525e-01],
            "Wan2.1-I2V-14B-480P": [2.57151496e+05, -3.54229917e+04, 1.40286849e+03, -1.35890334e+01, 1.32517977e-01],
            "Wan2.1-I2V-14B-720P": [8.10705460e+03, 2.13393892e+03, -3.72934672e+02, 1.66203073e+01, -4.17769401e-02],
        }
        if model_id not in self.coefficients_dict:
            supported_model_ids = ", ".join(self.coefficients_dict)
            raise ValueError(f"{model_id} is not a supported TeaCache model id. Please choose a valid model id in ({supported_model_ids}).")
        self.coefficients = self.coefficients_dict[model_id]

    def check(self, dit: WanModel, x, t_mod):
        """Advance one step; return True when the blocks may be skipped.

        The first and the last step of every cycle are always computed. When
        a full pass is required, a copy of ``x`` is kept so ``store`` can
        derive the residual afterwards.
        """
        modulated_inp = t_mod.clone()
        if self.step == 0 or self.step == self.num_inference_steps - 1:
            should_calc = True
            self.accumulated_rel_l1_distance = 0
        else:
            rescale_func = np.poly1d(self.coefficients)
            rel_l1 = (
                (modulated_inp - self.previous_modulated_input).abs().mean()
                / self.previous_modulated_input.abs().mean()
            ).cpu().item()
            self.accumulated_rel_l1_distance += rescale_func(rel_l1)
            should_calc = not (self.accumulated_rel_l1_distance < self.rel_l1_thresh)
            if should_calc:
                self.accumulated_rel_l1_distance = 0
        self.previous_modulated_input = modulated_inp
        self.step = (self.step + 1) % self.num_inference_steps
        if should_calc:
            self.previous_hidden_states = x.clone()
        return not should_calc

    def store(self, hidden_states):
        """Record the residual produced by a full pass through the blocks."""
        self.previous_residual = hidden_states - self.previous_hidden_states
        self.previous_hidden_states = None

    def update(self, hidden_states):
        """Apply the stored residual instead of running the blocks."""
        hidden_states = hidden_states + self.previous_residual
        return hidden_states


# -----------------------------
# Simplified model forward wrapper (no VACE / no motion controller)
# -----------------------------
def model_fn_wan_video(
    dit: WanModel,
    x: torch.Tensor,
    timestep: torch.Tensor,
    context: torch.Tensor,
    tea_cache: Optional[TeaCache] = None,
    use_unified_sequence_parallel: bool = False,
    LQ_latents: Optional[torch.Tensor] = None,
    is_full_block: bool = False,
    is_stream: bool = False,
    pre_cache_k: Optional[list[torch.Tensor]] = None,
    pre_cache_v: Optional[list[torch.Tensor]] = None,
    topk_ratio: float = 2.0,
    kv_ratio: float = 3.0,
    cur_process_idx: int = 0,
    t_mod: torch.Tensor = None,
    t: torch.Tensor = None,
    local_range: int = 9,
    **kwargs,
):
    """Run one streaming forward pass of the DiT over a latent chunk.

    Args:
        dit: Model providing ``patchify``, ``freqs``, ``blocks``, ``head`` and
            ``unpatchify``.
        x: Latent chunk handed to ``dit.patchify``.
        timestep: Accepted for interface compatibility; not read here.
        context: Forwarded unchanged to every block.
        tea_cache: Optional :class:`TeaCache`; on a cache hit the block stack
            is skipped and the stored residual is applied.
        use_unified_sequence_parallel: Shard the token axis across the
            sequence-parallel group when a multi-process group is active.
        LQ_latents: Optional per-block tensors added to ``x`` before the
            block with the same index runs.
        is_full_block, is_stream, local_range: Forwarded to every block.
        pre_cache_k, pre_cache_v: Optional per-block caches; updated IN PLACE.
        topk_ratio: Scales the ``topk`` value handed to the blocks.
        kv_ratio: Truncated to ``kv_len`` for the blocks.
        cur_process_idx: Streamed chunk index; selects the temporal RoPE
            offset.
        t_mod: Time modulation for the blocks and ``tea_cache.check``.
        t: Time embedding passed to ``dit.head``.
        **kwargs: Ignored.

    Returns:
        Tuple ``(prediction, pre_cache_k, pre_cache_v)``.
    """
    x, (f, h, w) = dit.patchify(x)

    # Sparse-attention bookkeeping handed to every block.
    win = (2, 8, 8)
    local_num = f // win[0]
    window_size = win[0] * h * w // 128
    topk = int(window_size * window_size * topk_ratio) - 1
    kv_len = int(kv_ratio)

    # RoPE positions: chunk 0 starts at temporal index 0, chunk i > 0 starts
    # at 4 + 2 * i.
    frame_start = 0 if cur_process_idx == 0 else 4 + cur_process_idx * 2
    freqs = torch.cat([
        dit.freqs[0][frame_start:frame_start + f].view(f, 1, 1, -1).expand(f, h, w, -1),
        dit.freqs[1][:h].view(1, h, 1, -1).expand(f, h, w, -1),
        dit.freqs[2][:w].view(1, 1, w, -1).expand(f, h, w, -1),
    ], dim=-1).reshape(f * h * w, 1, -1).to(x.device)

    # TeaCache decision is taken on the un-sharded tokens.
    tea_cache_update = tea_cache.check(dit, x, t_mod) if tea_cache is not None else False

    # Unified sequence parallelism (off by default); xfuser is only imported
    # when a multi-process group is really active.
    sp_active = False
    if use_unified_sequence_parallel:
        import torch.distributed as dist
        sp_active = dist.is_initialized() and dist.get_world_size() > 1
    if sp_active:
        from xfuser.core.distributed import (
            get_sequence_parallel_rank,
            get_sequence_parallel_world_size,
            get_sp_group,
        )
        x = torch.chunk(x, get_sequence_parallel_world_size(), dim=1)[get_sequence_parallel_rank()]

    if tea_cache_update:
        x = tea_cache.update(x)
    else:
        for block_id, block in enumerate(dit.blocks):
            if LQ_latents is not None and block_id < len(LQ_latents):
                x = x + LQ_latents[block_id]
            x, last_pre_cache_k, last_pre_cache_v = block(
                x, context, t_mod, freqs, f, h, w,
                local_num, topk,
                block_id=block_id,
                kv_len=kv_len,
                is_full_block=is_full_block,
                is_stream=is_stream,
                pre_cache_k=pre_cache_k[block_id] if pre_cache_k is not None else None,
                pre_cache_v=pre_cache_v[block_id] if pre_cache_v is not None else None,
                local_range=local_range,
            )
            if pre_cache_k is not None:
                pre_cache_k[block_id] = last_pre_cache_k
            if pre_cache_v is not None:
                pre_cache_v[block_id] = last_pre_cache_v
        # Fix: remember the residual of this full pass. Without it the next
        # cache hit runs `hidden_states + None` inside TeaCache.update().
        if tea_cache is not None:
            tea_cache.store(x)

    x = dit.head(x, t)
    if sp_active:
        x = get_sp_group().all_gather(x, dim=1)
    x = dit.unpatchify(x, (f, h, w))
    return x, pre_cache_k, pre_cache_v
..models.wan_video_dit import WanModel, RMSNorm, sinusoidal_embedding_1d +from ..models.wan_video_vae import WanVideoVAE, RMS_norm, CausalConv3d, Upsample +from ..schedulers.flow_match import FlowMatchScheduler +from .base import BasePipeline + + +# ----------------------------- +# 基础工具:ADAIN 所需的统计量(保留以备需要;管线默认用 wavelet) +# ----------------------------- +def _calc_mean_std(feat: torch.Tensor, eps: float = 1e-5) -> Tuple[torch.Tensor, torch.Tensor]: + assert feat.dim() == 4, 'feat 必须是 (N, C, H, W)' + N, C = feat.shape[:2] + var = feat.view(N, C, -1).var(dim=2, unbiased=False) + eps + std = var.sqrt().view(N, C, 1, 1) + mean = feat.view(N, C, -1).mean(dim=2).view(N, C, 1, 1) + return mean, std + + +def _adain(content_feat: torch.Tensor, style_feat: torch.Tensor) -> torch.Tensor: + assert content_feat.shape[:2] == style_feat.shape[:2], "ADAIN: N、C 必须匹配" + size = content_feat.size() + style_mean, style_std = _calc_mean_std(style_feat) + content_mean, content_std = _calc_mean_std(content_feat) + normalized = (content_feat - content_mean.expand(size)) / content_std.expand(size) + return normalized * style_std.expand(size) + style_mean.expand(size) + + +# ----------------------------- +# 小波式模糊与分解/重构(ColorCorrector 用) +# ----------------------------- +def _make_gaussian3x3_kernel(dtype, device) -> torch.Tensor: + vals = [ + [0.0625, 0.125, 0.0625], + [0.125, 0.25, 0.125 ], + [0.0625, 0.125, 0.0625], + ] + return torch.tensor(vals, dtype=dtype, device=device) + + +def _wavelet_blur(x: torch.Tensor, radius: int) -> torch.Tensor: + assert x.dim() == 4, 'x 必须是 (N, C, H, W)' + N, C, H, W = x.shape + base = _make_gaussian3x3_kernel(x.dtype, x.device) + weight = base.view(1, 1, 3, 3).repeat(C, 1, 1, 1) + pad = radius + x_pad = F.pad(x, (pad, pad, pad, pad), mode='replicate') + out = F.conv2d(x_pad, weight, bias=None, stride=1, padding=0, dilation=radius, groups=C) + return out + + +def _wavelet_decompose(x: torch.Tensor, levels: int = 5) -> Tuple[torch.Tensor, torch.Tensor]: + 
assert x.dim() == 4, 'x 必须是 (N, C, H, W)' + high = torch.zeros_like(x) + low = x + for i in range(levels): + radius = 2 ** i + blurred = _wavelet_blur(low, radius) + high = high + (low - blurred) + low = blurred + return high, low + + +def _wavelet_reconstruct(content: torch.Tensor, style: torch.Tensor, levels: int = 5) -> torch.Tensor: + c_high, _ = _wavelet_decompose(content, levels=levels) + _, s_low = _wavelet_decompose(style, levels=levels) + return c_high + s_low + + +# ----------------------------- +# 无状态颜色矫正模块(视频友好,默认 wavelet) +# ----------------------------- +class TorchColorCorrectorWavelet(nn.Module): + def __init__(self, levels: int = 5): + super().__init__() + self.levels = levels + + @staticmethod + def _flatten_time(x: torch.Tensor) -> Tuple[torch.Tensor, int, int]: + assert x.dim() == 5, '输入必须是 (B, C, f, H, W)' + B, C, f, H, W = x.shape + y = x.permute(0, 2, 1, 3, 4).reshape(B * f, C, H, W) + return y, B, f + + @staticmethod + def _unflatten_time(y: torch.Tensor, B: int, f: int) -> torch.Tensor: + BF, C, H, W = y.shape + assert BF == B * f + return y.reshape(B, f, C, H, W).permute(0, 2, 1, 3, 4) + + def forward( + self, + hq_image: torch.Tensor, # (B, C, f, H, W) + lq_image: torch.Tensor, # (B, C, f, H, W) + clip_range: Tuple[float, float] = (-1.0, 1.0), + method: Literal['wavelet', 'adain'] = 'wavelet', + chunk_size: Optional[int] = None, + ) -> torch.Tensor: + assert hq_image.shape == lq_image.shape, "HQ 与 LQ 的形状必须一致" + assert hq_image.dim() == 5 and hq_image.shape[1] == 3, "输入必须是 (B, 3, f, H, W)" + + B, C, f, H, W = hq_image.shape + if chunk_size is None or chunk_size >= f: + hq4, B, f = self._flatten_time(hq_image) + lq4, _, _ = self._flatten_time(lq_image) + if method == 'wavelet': + out4 = _wavelet_reconstruct(hq4, lq4, levels=self.levels) + elif method == 'adain': + out4 = _adain(hq4, lq4) + else: + raise ValueError(f"未知 method: {method}") + out4 = torch.clamp(out4, *clip_range) + out = self._unflatten_time(out4, B, f) + return out + + outs = 
[] + for start in range(0, f, chunk_size): + end = min(start + chunk_size, f) + hq_chunk = hq_image[:, :, start:end] + lq_chunk = lq_image[:, :, start:end] + hq4, B_, f_ = self._flatten_time(hq_chunk) + lq4, _, _ = self._flatten_time(lq_chunk) + if method == 'wavelet': + out4 = _wavelet_reconstruct(hq4, lq4, levels=self.levels) + elif method == 'adain': + out4 = _adain(hq4, lq4) + else: + raise ValueError(f"未知 method: {method}") + out4 = torch.clamp(out4, *clip_range) + out_chunk = self._unflatten_time(out4, B_, f_) + outs.append(out_chunk) + out = torch.cat(outs, dim=2) + return out + + +# ----------------------------- +# 简化版 Pipeline(仅 dit + vae) +# ----------------------------- +class FlashVSRTinyLongPipeline(BasePipeline): + + def __init__(self, device="cuda", torch_dtype=torch.float16): + super().__init__(device=device, torch_dtype=torch_dtype) + self.scheduler = FlowMatchScheduler(shift=5, sigma_min=0.0, extra_one_step=True) + self.dit: WanModel = None + self.vae: WanVideoVAE = None + self.model_names = ['dit', 'vae'] + self.height_division_factor = 16 + self.width_division_factor = 16 + self.use_unified_sequence_parallel = False + self.prompt_emb_posi = None + self.ColorCorrector = TorchColorCorrectorWavelet(levels=5) + + print(r""" + ███████╗██╗ █████╗ ███████╗██╗ ██╗██╗ ██╗███████╗█████╗ + ██╔════╝██║ ██╔══██╗██╔════╝██║ ██║██║ ██║██╔════╝██╔══██╗ ██╗ + █████╗ ██║ ███████║███████╗███████║╚██╗ ██╔╝███████╗███████║ ██████╗ + ██╔══╝ ██║ ██╔══██║╚════██║██╔══██║ ╚████╔╝ ╚════██║██╔═██║ ██╔═╝ + ██║ ███████╗██║ ██║███████║██║ ██║ ╚██╔╝ ███████║██║ ██║ ╚═╝ + ╚═╝ ╚══════╝╚═╝ ╚═╝╚══════╝╚═╝ ╚═╝ ╚═╝ ╚══════╝╚═╝ ╚═╝ +""") + + def enable_vram_management(self, num_persistent_param_in_dit=None): + # 仅管理 dit / vae + dtype = next(iter(self.dit.parameters())).dtype + from ..vram_management import enable_vram_management, AutoWrappedModule, AutoWrappedLinear + enable_vram_management( + self.dit, + module_map={ + torch.nn.Linear: AutoWrappedLinear, + torch.nn.Conv3d: 
AutoWrappedModule, + torch.nn.LayerNorm: AutoWrappedModule, + RMSNorm: AutoWrappedModule, + }, + module_config=dict( + offload_dtype=dtype, + offload_device="cpu", + onload_dtype=dtype, + onload_device=self.device, + computation_dtype=self.torch_dtype, + computation_device=self.device, + ), + max_num_param=num_persistent_param_in_dit, + overflow_module_config=dict( + offload_dtype=dtype, + offload_device="cpu", + onload_dtype=dtype, + onload_device="cpu", + computation_dtype=self.torch_dtype, + computation_device=self.device, + ), + ) + self.enable_cpu_offload() + + def fetch_models(self, model_manager: ModelManager): + self.dit = model_manager.fetch_model("wan_video_dit") + self.vae = model_manager.fetch_model("wan_video_vae") + + @staticmethod + def from_model_manager(model_manager: ModelManager, torch_dtype=None, device=None, use_usp=False): + if device is None: device = model_manager.device + if torch_dtype is None: torch_dtype = model_manager.torch_dtype + pipe = FlashVSRTinyLongPipeline(device=device, torch_dtype=torch_dtype) + pipe.fetch_models(model_manager) + # 可选:统一序列并行入口(此处默认关闭) + pipe.use_unified_sequence_parallel = False + return pipe + + def denoising_model(self): + return self.dit + + # ------------------------- + # 新增:显式 KV 预初始化函数 + # ------------------------- + def init_cross_kv( + self, + context_tensor: Optional[torch.Tensor] = None, + prompt_path = None, + ): + self.load_models_to_device(["dit"]) + """ + 使用固定 prompt 生成文本 context,并在 WanModel 中初始化所有 CrossAttention 的 KV 缓存。 + 必须在 __call__ 前显式调用一次。 + """ + #prompt_path = "../../examples/WanVSR/prompt_tensor/posi_prompt.pth" + + if self.dit is None: + raise RuntimeError("请先通过 fetch_models / from_model_manager 初始化 self.dit") + + if context_tensor is None: + if prompt_path is None: + raise ValueError("init_cross_kv: 需要提供 prompt_path 或 context_tensor 其一") + ctx = torch.load(prompt_path, map_location=self.device) + else: + ctx = context_tensor + + ctx = ctx.to(dtype=self.torch_dtype, device=self.device) 
def prepare_unified_sequence_parallel(self):
    """Return the keyword argument that toggles unified sequence parallelism."""
    return {"use_unified_sequence_parallel": self.use_unified_sequence_parallel}


def prepare_extra_input(self, latents=None):
    """Return extra model inputs; this pipeline has none, so the dict is empty."""
    return {}


def encode_video(self, input_video, tiled=True, tile_size=(34, 34), tile_stride=(18, 16)):
    """Encode ``input_video`` with ``self.vae`` on ``self.device`` and return the latents."""
    latents = self.vae.encode(
        input_video,
        device=self.device,
        tiled=tiled,
        tile_size=tile_size,
        tile_stride=tile_stride,
    )
    return latents


def _decode_video(self, latents, tiled=True, tile_size=(34, 34), tile_stride=(18, 16)):
    """Decode ``latents`` with ``self.vae`` on ``self.device`` and return the frames."""
    frames = self.vae.decode(
        latents,
        device=self.device,
        tiled=tiled,
        tile_size=tile_size,
        tile_stride=tile_stride,
    )
    return frames


def decode_video(self, latents, cond=None, **kwargs):
    """Decode ``latents`` with ``self.TCDecoder`` and rescale the result to [-1, 1].

    Args:
        latents: Latent tensor; axes 1 and 2 are swapped before decoding.
        cond: Optional conditioning forwarded to ``TCDecoder.decode_video``.
        **kwargs: Accepted for call compatibility; not used.
    """
    frames = self.TCDecoder.decode_video(
        latents.transpose(1, 2),  # TCDecoder expects (B, F, C, H, W)
        parallel=False,
        show_progress_bar=False,
        cond=cond,
    ).transpose(1, 2).mul_(2).sub_(1)  # back to (B, C, F, H, W), range -1 to 1
    return frames


def offload_model(self, keep_vae=False):
    """Clear the cross-attention KV state and move model weights off the device.

    Args:
        keep_vae: When true, ``self.TCDecoder`` is left where it is; otherwise
            it is moved to the CPU.
    """
    self.dit.clear_cross_kv()
    # Flag the cached prompt context so the next __call__ re-runs
    # init_cross_kv(); nothing to flag if it was never initialised.
    if self.prompt_emb_posi is not None:
        self.prompt_emb_posi['stats'] = "offload"
    self.load_models_to_device([])
    if hasattr(self.dit, "LQ_proj_in"):
        self.dit.LQ_proj_in.to('cpu')
    if not keep_vae:
        self.TCDecoder.to('cpu')


@torch.no_grad()
def __call__(
    self,
    prompt=None,
    negative_prompt="",
    denoising_strength=1.0,
    seed=None,
    rand_device="gpu",
    height=480,
    width=832,
    num_frames=81,
    cfg_scale=5.0,
    num_inference_steps=50,
    sigma_shift=5.0,
    tiled=True,
    tile_size=(60, 104),
    tile_stride=(30, 52),
    tea_cache_l1_thresh=None,
    tea_cache_model_id="Wan2.1-T2V-1.3B",
    progress_bar_cmd=tqdm,
    progress_bar_st=None,
    LQ_video=None,
    is_full_block=False,
    if_buffer=False,
    topk_ratio=2.0,
    kv_ratio=3.0,
    local_range=9,
    color_fix=True,
    unload_dit=False,
    force_offload=False,
):
    """Run streaming one-step restoration of ``LQ_video`` chunk by chunk.

    The first iteration consumes seven low-quality sub-chunks and six latent
    frames; every later iteration consumes two sub-chunks and two latent
    frames, reusing the per-block KV caches returned by the previous one.

    Args:
        LQ_video: Low-quality input video, sliced along axis 2 (frames).
            Required. Presumably laid out as (B, C, F, H, W) -- TODO confirm.
        seed: Seed forwarded to ``self.generate_noise``.
        height, width: Target size, first adjusted by
            ``self.check_resize_height_width``.
        num_frames: Frame count; rounded so that ``num_frames % 4 == 1``.
            At least 25 frames are needed to produce one output chunk.
        cfg_scale: Must be exactly 1.0 (note that the default is 5.0).
        progress_bar_cmd: Callable wrapped around the chunk iterator.
        is_full_block, topk_ratio, kv_ratio, local_range: Forwarded to
            ``model_fn_wan_video``.
        if_buffer: When true, one fewer latent frame of noise is generated.
        color_fix: Apply ``self.ColorCorrector`` (method ``'adain'``) to each
            decoded chunk, using the matching low-quality frames as reference.
        unload_dit: Drop per-chunk tensors and call ``clean_vram()`` after
            every chunk.
        force_offload: Call ``self.offload_model()`` before returning.

        The remaining parameters (``prompt``, ``negative_prompt``,
        ``denoising_strength``, ``rand_device``, ``num_inference_steps``,
        ``sigma_shift``, ``tiled``, ``tile_size``, ``tile_stride``,
        ``tea_cache_*``, ``progress_bar_st``) are accepted for interface
        compatibility and are not used here.

    Returns:
        The first batch element of all decoded chunks concatenated along
        axis 2, held on the CPU.

    Raises:
        AssertionError: If ``cfg_scale`` is not 1.0.
        RuntimeError: If ``init_cross_kv()`` has not been called.
        ValueError: If ``LQ_video`` is missing or ``num_frames`` is too small
            to yield any chunk.
    """
    # Only cfg=1.0 is supported (same as the original code).
    assert cfg_scale == 1.0, "cfg_scale must be 1.0"

    # init_cross_kv() must have been called beforehand.
    if self.prompt_emb_posi is None or 'context' not in self.prompt_emb_posi:
        raise RuntimeError(
            "Cross-Attn KV 未初始化。请在调用 __call__ 前先执行:\n"
            " pipe.init_cross_kv()\n"
            "或传入自定义 context:\n"
            " pipe.init_cross_kv(context_tensor=your_context_tensor)"
        )

    # Fail fast: every chunk is decoded against LQ_video slices, so a missing
    # video could only crash later, after the models were already loaded.
    if LQ_video is None:
        raise ValueError("LQ_video is required.")

    # Size correction.
    height, width = self.check_resize_height_width(height, width)
    if num_frames % 4 != 1:
        num_frames = (num_frames + 2) // 4 * 4 + 1
        print(f"Only `num_frames % 4 == 1` is acceptable. We round it up to {num_frames}.")

    process_total_num = (num_frames - 1) // 8 - 2
    if process_total_num <= 0:
        # Without this check the loop would not run and torch.cat would be
        # handed an empty list.
        raise ValueError(
            f"num_frames={num_frames} is too short: at least 25 frames are required."
        )

    # Initial noise; buffered mode uses one fewer latent frame.
    latent_frames = (num_frames - 1) // 4 if if_buffer else (num_frames - 1) // 4 + 1
    latents = self.generate_noise(
        (1, 16, latent_frames, height // 8, width // 8),
        seed=seed,
        device=self.device,
        dtype=self.torch_dtype,
    )

    is_stream = True

    # Rebuild the cross-attention KV if a previous call offloaded it.
    if self.prompt_emb_posi['stats'] == "offload":
        self.init_cross_kv(context_tensor=self.prompt_emb_posi['context'])
    self.load_models_to_device(["dit"])
    self.dit.LQ_proj_in.to(self.device)
    self.TCDecoder.to(self.device)

    # Drop any stale LQ_proj_in cache left by an earlier run.
    if hasattr(self.dit, "LQ_proj_in"):
        self.dit.LQ_proj_in.clear_cache()

    frames_total = []
    LQ_pre_idx = 0
    LQ_cur_idx = 0
    self.TCDecoder.clean_mem()

    # The method is already decorated with torch.no_grad(), so no nested
    # no_grad context is needed around the loop.
    for cur_process_idx in progress_bar_cmd(range(process_total_num)):
        if cur_process_idx == 0:
            pre_cache_k = [None] * len(self.dit.blocks)
            pre_cache_v = [None] * len(self.dit.blocks)
            LQ_latents = None
            inner_loop_num = 7
            for inner_idx in range(inner_loop_num):
                cur = self.denoising_model().LQ_proj_in.stream_forward(
                    LQ_video[:, :, max(0, inner_idx*4-3):(inner_idx+1)*4-3, :, :].to(self.device)
                )
                # stream_forward may yield nothing for a sub-chunk; skip it.
                if cur is None:
                    continue
                if LQ_latents is None:
                    LQ_latents = cur
                else:
                    for layer_idx in range(len(LQ_latents)):
                        LQ_latents[layer_idx] = torch.cat([LQ_latents[layer_idx], cur[layer_idx]], dim=1)
            LQ_cur_idx = (inner_loop_num-1)*4-3
            cur_latents = latents[:, :, :6, :, :]
        else:
            LQ_latents = None
            inner_loop_num = 2
            for inner_idx in range(inner_loop_num):
                cur = self.denoising_model().LQ_proj_in.stream_forward(
                    LQ_video[:, :, cur_process_idx*8+17+inner_idx*4:cur_process_idx*8+21+inner_idx*4, :, :].to(self.device)
                )
                if cur is None:
                    continue
                if LQ_latents is None:
                    LQ_latents = cur
                else:
                    for layer_idx in range(len(LQ_latents)):
                        LQ_latents[layer_idx] = torch.cat([LQ_latents[layer_idx], cur[layer_idx]], dim=1)
            LQ_cur_idx = cur_process_idx*8+21+(inner_loop_num-2)*4
            cur_latents = latents[:, :, 4+cur_process_idx*2:6+cur_process_idx*2, :, :]

        # Inference (no motion_controller / vace).
        noise_pred_posi, pre_cache_k, pre_cache_v = model_fn_wan_video(
            self.dit,
            x=cur_latents,
            timestep=self.timestep,
            context=None,
            tea_cache=None,
            use_unified_sequence_parallel=False,
            LQ_latents=LQ_latents,
            is_full_block=is_full_block,
            is_stream=is_stream,
            pre_cache_k=pre_cache_k,
            pre_cache_v=pre_cache_v,
            topk_ratio=topk_ratio,
            kv_ratio=kv_ratio,
            cur_process_idx=cur_process_idx,
            t_mod=self.t_mod,
            t=self.t,
            local_range=local_range,
        )

        # Update the latent.
        cur_latents = cur_latents - noise_pred_posi

        # Decode, conditioned on the low-quality frames of this chunk.
        cur_LQ_frame = LQ_video[:, :, LQ_pre_idx:LQ_cur_idx, :, :].to(self.device)
        cur_frames = self.TCDecoder.decode_video(
            cur_latents.transpose(1, 2),
            parallel=False,
            show_progress_bar=False,
            cond=cur_LQ_frame,
        ).transpose(1, 2).mul_(2).sub_(1)

        # Colour correction is best-effort: on failure keep the uncorrected
        # frames, but report the reason instead of swallowing it silently.
        if color_fix:
            try:
                cur_frames = self.ColorCorrector(
                    cur_frames.to(device=self.device),
                    cur_LQ_frame,
                    clip_range=(-1, 1),
                    chunk_size=None,
                    method='adain',
                )
            except Exception as exc:
                print(f"[FlashVSR] Color correction failed; keeping uncorrected frames: {exc}")

        frames_total.append(cur_frames.to('cpu'))
        LQ_pre_idx = LQ_cur_idx

        if unload_dit:
            del noise_pred_posi, cur_frames, cur_latents, cur_LQ_frame
            clean_vram()

    if hasattr(self.dit, "LQ_proj_in"):
        self.dit.LQ_proj_in.clear_cache()

    self.TCDecoder.clean_mem()
    if force_offload:
        self.offload_model()

    frames = torch.cat(frames_total, dim=2)

    return frames[0]
# -----------------------------
# TeaCache (original logic kept; not enabled by default here)
# -----------------------------
class TeaCache:
    """Decide per step whether the transformer blocks can be skipped.

    ``check`` accumulates a rescaled relative L1 change of the modulation
    input; while the accumulated value stays below ``rel_l1_thresh`` the
    blocks are skipped and ``update`` re-applies the residual recorded by the
    most recent ``store``.
    """

    def __init__(self, num_inference_steps, rel_l1_thresh, model_id):
        """Set up counters and the rescaling polynomial for ``model_id``.

        Raises:
            ValueError: If ``model_id`` has no coefficient set.
        """
        self.num_inference_steps = num_inference_steps
        self.step = 0
        self.accumulated_rel_l1_distance = 0
        self.previous_modulated_input = None
        self.rel_l1_thresh = rel_l1_thresh
        self.previous_residual = None
        self.previous_hidden_states = None

        self.coefficients_dict = {
            "Wan2.1-T2V-1.3B": [-5.21862437e+04, 9.23041404e+03, -5.28275948e+02, 1.36987616e+01, -4.99875664e-02],
            "Wan2.1-T2V-14B": [-3.03318725e+05, 4.90537029e+04, -2.65530556e+03, 5.87365115e+01, -3.15583525e-01],
            "Wan2.1-I2V-14B-480P": [2.57151496e+05, -3.54229917e+04, 1.40286849e+03, -1.35890334e+01, 1.32517977e-01],
            "Wan2.1-I2V-14B-720P": [8.10705460e+03, 2.13393892e+03, -3.72934672e+02, 1.66203073e+01, -4.17769401e-02],
        }
        if model_id not in self.coefficients_dict:
            supported_model_ids = ", ".join(self.coefficients_dict)
            raise ValueError(f"{model_id} is not a supported TeaCache model id. Please choose a valid model id in ({supported_model_ids}).")
        self.coefficients = self.coefficients_dict[model_id]
        # Built once here instead of on every check() call.
        self._rescale_func = np.poly1d(self.coefficients)

    def check(self, dit: WanModel, x, t_mod):
        """Advance one step and return True when the blocks may be skipped.

        The first and the last step of a cycle are always computed. When a
        step is computed, a copy of ``x`` is kept so ``store`` can derive the
        residual afterwards. ``dit`` is accepted but not used.
        """
        modulated_inp = t_mod.clone()
        if self.step == 0 or self.step == self.num_inference_steps - 1:
            should_calc = True
            self.accumulated_rel_l1_distance = 0
        else:
            rel_l1 = (
                (modulated_inp - self.previous_modulated_input).abs().mean()
                / self.previous_modulated_input.abs().mean()
            ).cpu().item()
            self.accumulated_rel_l1_distance += self._rescale_func(rel_l1)
            should_calc = not (self.accumulated_rel_l1_distance < self.rel_l1_thresh)
            if should_calc:
                self.accumulated_rel_l1_distance = 0
        self.previous_modulated_input = modulated_inp
        self.step = (self.step + 1) % self.num_inference_steps
        if should_calc:
            self.previous_hidden_states = x.clone()
        return not should_calc

    def store(self, hidden_states):
        """Record the residual between ``hidden_states`` and the input kept by ``check``."""
        self.previous_residual = hidden_states - self.previous_hidden_states
        self.previous_hidden_states = None

    def update(self, hidden_states):
        """Return ``hidden_states`` plus the residual recorded by ``store``."""
        hidden_states = hidden_states + self.previous_residual
        return hidden_states


# -----------------------------
# Simplified model forward wrapper (no vace / no motion_controller)
# -----------------------------
def model_fn_wan_video(
    dit: WanModel,
    x: torch.Tensor,
    timestep: torch.Tensor,
    context: torch.Tensor,
    tea_cache: Optional[TeaCache] = None,
    use_unified_sequence_parallel: bool = False,
    LQ_latents: Optional[torch.Tensor] = None,
    is_full_block: bool = False,
    is_stream: bool = False,
    pre_cache_k: Optional[list[torch.Tensor]] = None,
    pre_cache_v: Optional[list[torch.Tensor]] = None,
    topk_ratio: float = 2.0,
    kv_ratio: float = 3.0,
    cur_process_idx: int = 0,
    t_mod: torch.Tensor = None,
    t: torch.Tensor = None,
    local_range: int = 9,
    **kwargs,
):
    """Run one forward pass of ``dit`` over a chunk of latents.

    Args:
        dit: Model providing ``patchify``, ``freqs``, ``blocks``, ``head`` and
            ``unpatchify``.
        x: Latents of the current chunk.
        timestep: Accepted for interface compatibility; not used here.
        context: Forwarded unchanged to every block.
        tea_cache: Optional ``TeaCache``; when it reports a cache hit the
            blocks are skipped and the stored residual is applied instead.
        use_unified_sequence_parallel: Shard the token axis across the xfuser
            sequence-parallel group when torch.distributed is initialised
            with more than one rank.
        LQ_latents: Optional per-block features added to ``x`` before the
            block with the same index.
        pre_cache_k, pre_cache_v: Optional per-block caches; entries are
            overwritten in place with what each block returns.
        topk_ratio, kv_ratio, local_range, is_full_block, is_stream: Attention
            settings forwarded to the blocks.
        cur_process_idx: Chunk index; selects the temporal RoPE offset.
        t_mod, t: Time modulation for the blocks and time embedding for the head.

    Returns:
        Tuple ``(x, pre_cache_k, pre_cache_v)`` with the unpatchified output
        and the (possibly updated) cache lists.
    """
    # patchify
    x, (f, h, w) = dit.patchify(x)

    win = (2, 8, 8)
    seqlen = f // win[0]
    local_num = seqlen
    window_size = win[0] * h * w // 128
    square_num = window_size * window_size
    topk = int(square_num * topk_ratio) - 1
    kv_len = int(kv_ratio)

    # RoPE positions: the first chunk starts at frame 0, later chunks are
    # offset by 4 + 2 * cur_process_idx along the temporal axis.
    f_start = 0 if cur_process_idx == 0 else 4 + cur_process_idx * 2
    freqs = torch.cat([
        dit.freqs[0][f_start:f_start + f].view(f, 1, 1, -1).expand(f, h, w, -1),
        dit.freqs[1][:h].view(1, h, 1, -1).expand(f, h, w, -1),
        dit.freqs[2][:w].view(1, 1, w, -1).expand(f, h, w, -1),
    ], dim=-1).reshape(f * h * w, 1, -1).to(x.device)

    # TeaCache (not enabled by default)
    tea_cache_update = tea_cache.check(dit, x, t_mod) if tea_cache is not None else False

    # Unified sequence parallelism (disabled by default here)
    if use_unified_sequence_parallel:
        import torch.distributed as dist
        from xfuser.core.distributed import (get_sequence_parallel_rank,
                                             get_sequence_parallel_world_size,
                                             get_sp_group)
        if dist.is_initialized() and dist.get_world_size() > 1:
            x = torch.chunk(x, get_sequence_parallel_world_size(), dim=1)[get_sequence_parallel_rank()]

    # Block stack
    if tea_cache_update:
        x = tea_cache.update(x)
    else:
        for block_id, block in enumerate(dit.blocks):
            if LQ_latents is not None and block_id < len(LQ_latents):
                x = x + LQ_latents[block_id]
            x, last_pre_cache_k, last_pre_cache_v = block(
                x, context, t_mod, freqs, f, h, w,
                local_num, topk,
                block_id=block_id,
                kv_len=kv_len,
                is_full_block=is_full_block,
                is_stream=is_stream,
                pre_cache_k=pre_cache_k[block_id] if pre_cache_k is not None else None,
                pre_cache_v=pre_cache_v[block_id] if pre_cache_v is not None else None,
                local_range=local_range,
            )
            if pre_cache_k is not None:
                pre_cache_k[block_id] = last_pre_cache_k
            if pre_cache_v is not None:
                pre_cache_v[block_id] = last_pre_cache_v
        # Record the residual of this computed step; without it the next
        # cache hit would add None in TeaCache.update.
        if tea_cache is not None:
            tea_cache.store(x)

    x = dit.head(x, t)
    if use_unified_sequence_parallel:
        import torch.distributed as dist
        from xfuser.core.distributed import get_sp_group
        if dist.is_initialized() and dist.get_world_size() > 1:
            x = get_sp_group().all_gather(x, dim=1)
    x = dit.unpatchify(x, (f, h, w))
    return x, pre_cache_k, pre_cache_v
class FlowMatchScheduler:
    """Shifted sigma/timestep schedule with a linear update step for flow matching."""

    def __init__(self, num_inference_steps=100, num_train_timesteps=1000, shift=3.0, sigma_max=1.0, sigma_min=0.003/1.002, inverse_timesteps=False, extra_one_step=False, reverse_sigmas=False):
        """Store the schedule options and build the initial schedule."""
        self.num_train_timesteps = num_train_timesteps
        self.shift = shift
        self.sigma_max = sigma_max
        self.sigma_min = sigma_min
        self.inverse_timesteps = inverse_timesteps
        self.extra_one_step = extra_one_step
        self.reverse_sigmas = reverse_sigmas
        self.set_timesteps(num_inference_steps)

    def set_timesteps(self, num_inference_steps=100, denoising_strength=1.0, training=False, shift=None):
        """Rebuild ``self.sigmas`` and ``self.timesteps``.

        Args:
            num_inference_steps: Number of schedule points.
            denoising_strength: Fraction of the sigma range to start from.
            training: Also compute ``self.linear_timesteps_weights``.
            shift: When given, replaces ``self.shift`` before the schedule is built.
        """
        if shift is not None:
            self.shift = shift

        first_sigma = self.sigma_min + (self.sigma_max - self.sigma_min) * denoising_strength
        if self.extra_one_step:
            # One extra point is generated and the final one dropped.
            schedule = torch.linspace(first_sigma, self.sigma_min, num_inference_steps + 1)[:-1]
        else:
            schedule = torch.linspace(first_sigma, self.sigma_min, num_inference_steps)
        if self.inverse_timesteps:
            schedule = torch.flip(schedule, dims=[0])

        schedule = self.shift * schedule / (1 + (self.shift - 1) * schedule)
        if self.reverse_sigmas:
            schedule = 1 - schedule

        self.sigmas = schedule
        self.timesteps = self.sigmas * self.num_train_timesteps

        if training:
            centred = (self.timesteps - num_inference_steps / 2) / num_inference_steps
            bell = torch.exp(-2 * centred ** 2)
            bell_from_zero = bell - bell.min()
            # Scale so the weights sum to num_inference_steps.
            self.linear_timesteps_weights = bell_from_zero * (num_inference_steps / bell_from_zero.sum())

    def _sigma_at(self, timestep):
        """Return ``(index, sigma)`` of the schedule entry closest to ``timestep``."""
        if isinstance(timestep, torch.Tensor):
            timestep = timestep.cpu()
        index = torch.argmin((self.timesteps - timestep).abs())
        return index, self.sigmas[index]

    def step(self, model_output, timestep, sample, to_final=False, **kwargs):
        """Move ``sample`` from the sigma at ``timestep`` to the next sigma.

        On the last schedule entry, or when ``to_final`` is set, the target
        sigma is 0 (1 if the schedule is inverted or reversed).
        """
        index, sigma = self._sigma_at(timestep)
        if to_final or index + 1 >= len(self.timesteps):
            target_sigma = 1 if (self.inverse_timesteps or self.reverse_sigmas) else 0
        else:
            target_sigma = self.sigmas[index + 1]
        return sample + model_output * (target_sigma - sigma)

    def return_to_timestep(self, timestep, sample, sample_stablized):
        """Return ``(sample - sample_stablized) / sigma`` for the sigma at ``timestep``."""
        _, sigma = self._sigma_at(timestep)
        return (sample - sample_stablized) / sigma

    def add_noise(self, original_samples, noise, timestep):
        """Blend samples and noise as ``(1 - sigma) * original + sigma * noise``."""
        _, sigma = self._sigma_at(timestep)
        return (1 - sigma) * original_samples + sigma * noise

    def training_target(self, sample, noise, timestep):
        """Return the regression target ``noise - sample``; ``timestep`` is unused."""
        return noise - sample

    def training_weight(self, timestep):
        """Return the training weight of the schedule entry closest to ``timestep``.

        Requires a prior ``set_timesteps(..., training=True)`` call, which is
        what creates ``linear_timesteps_weights``.
        """
        on_schedule_device = timestep.to(self.timesteps.device)
        index = torch.argmin((self.timesteps - on_schedule_device).abs())
        return self.linear_timesteps_weights[index]
def cast_to(weight, dtype, device):
    """Return a fresh copy of ``weight`` with the given dtype on the given device."""
    return torch.empty_like(weight, dtype=dtype, device=device).copy_(weight)


class AutoWrappedModule(torch.nn.Module):
    """Wrap a module so it can sit in an offload, onload or computation placement.

    ``state`` is 0 while the wrapped module is in its offload placement and
    1 after ``onload``.
    """

    def __init__(self, module: torch.nn.Module, offload_dtype, offload_device, onload_dtype, onload_device, computation_dtype, computation_device):
        super().__init__()
        # The wrapped module starts out in its offload placement.
        self.module = module.to(dtype=offload_dtype, device=offload_device)
        self.offload_dtype = offload_dtype
        self.offload_device = offload_device
        self.onload_dtype = onload_dtype
        self.onload_device = onload_device
        self.computation_dtype = computation_dtype
        self.computation_device = computation_device
        self.state = 0

    def offload(self):
        """Move the module to its offload placement if it is currently onloaded."""
        if self.state != 1:
            return
        if self.offload_dtype == self.onload_dtype and self.offload_device == self.onload_device:
            # Both placements coincide: nothing to move, state is left as is.
            return
        self.module.to(dtype=self.offload_dtype, device=self.offload_device)
        self.state = 0

    def onload(self):
        """Move the module to its onload placement if it is currently offloaded."""
        if self.state != 0:
            return
        if self.offload_dtype == self.onload_dtype and self.offload_device == self.onload_device:
            return
        self.module.to(dtype=self.onload_dtype, device=self.onload_device)
        self.state = 1

    def forward(self, *args, **kwargs):
        """Run the wrapped module, on a temporary copy when the computation
        placement differs from the onload placement."""
        runs_in_place = (
            self.onload_dtype == self.computation_dtype
            and self.onload_device == self.computation_device
        )
        if runs_in_place:
            return self.module(*args, **kwargs)
        # The stored module is left untouched; only the copy is converted.
        staged = copy.deepcopy(self.module).to(dtype=self.computation_dtype, device=self.computation_device)
        return staged(*args, **kwargs)


class AutoWrappedLinear(torch.nn.Linear):
    """``torch.nn.Linear`` variant that reuses an existing layer's parameters
    and casts them to the computation placement at call time when needed."""

    def __init__(self, module: torch.nn.Linear, offload_dtype, offload_device, onload_dtype, onload_device, computation_dtype, computation_device):
        # Construct the layer under the meta-device context, then adopt the
        # source layer's weight and bias.
        with init_weights_on_device(device=torch.device("meta")):
            super().__init__(in_features=module.in_features, out_features=module.out_features, bias=module.bias is not None, dtype=offload_dtype, device=offload_device)
        self.weight = module.weight
        self.bias = module.bias
        self.offload_dtype = offload_dtype
        self.offload_device = offload_device
        self.onload_dtype = onload_dtype
        self.onload_device = onload_device
        self.computation_dtype = computation_dtype
        self.computation_device = computation_device
        self.state = 0

    def offload(self):
        """Move the layer to its offload placement if it is currently onloaded."""
        if self.state != 1:
            return
        if self.offload_dtype == self.onload_dtype and self.offload_device == self.onload_device:
            return
        self.to(dtype=self.offload_dtype, device=self.offload_device)
        self.state = 0

    def onload(self):
        """Move the layer to its onload placement if it is currently offloaded."""
        if self.state != 0:
            return
        if self.offload_dtype == self.onload_dtype and self.offload_device == self.onload_device:
            return
        self.to(dtype=self.onload_dtype, device=self.onload_device)
        self.state = 1

    def forward(self, x, *args, **kwargs):
        """Apply the linear map; extra positional/keyword arguments are ignored."""
        runs_in_place = (
            self.onload_dtype == self.computation_dtype
            and self.onload_device == self.computation_device
        )
        if runs_in_place:
            return torch.nn.functional.linear(x, self.weight, self.bias)
        weight = cast_to(self.weight, self.computation_dtype, self.computation_device)
        bias = None
        if self.bias is not None:
            bias = cast_to(self.bias, self.computation_dtype, self.computation_device)
        return torch.nn.functional.linear(x, weight, bias)


def enable_vram_management_recursively(model: torch.nn.Module, module_map: dict, module_config: dict, max_num_param=None, overflow_module_config: dict = None, total_num_param=0):
    """Replace mapped submodules of ``model`` with their wrapper classes.

    Children matching a key of ``module_map`` are wrapped with the mapped
    class; other children are searched recursively.

    Args:
        module_map: Source module class -> wrapper class.
        module_config: Keyword arguments passed to each wrapper.
        max_num_param: Optional parameter budget; a module that would push
            the running total past it is built with ``overflow_module_config``.
        overflow_module_config: Wrapper keyword arguments used over budget.
        total_num_param: Running parameter count carried through the recursion.

    Returns:
        The updated running parameter count.
    """
    for child_name, child in model.named_children():
        wrapper_cls = next(
            (target for source, target in module_map.items() if isinstance(child, source)),
            None,
        )
        if wrapper_cls is None:
            # No direct match: look for wrappable modules further down.
            total_num_param = enable_vram_management_recursively(child, module_map, module_config, max_num_param, overflow_module_config, total_num_param)
            continue

        child_params = sum(p.numel() for p in child.parameters())
        over_budget = max_num_param is not None and total_num_param + child_params > max_num_param
        chosen_config = overflow_module_config if over_budget else module_config
        setattr(model, child_name, wrapper_cls(child, **chosen_config))
        total_num_param += child_params
    return total_num_param


def enable_vram_management(model: torch.nn.Module, module_map: dict, module_config: dict, max_num_param=None, overflow_module_config: dict = None):
    """Wrap the mapped submodules of ``model`` and set ``model.vram_management_enabled``."""
    enable_vram_management_recursively(model, module_map, module_config, max_num_param, overflow_module_config, total_num_param=0)
    model.vram_management_enabled = True
module_config, max_num_param, overflow_module_config, total_num_param=0) + model.vram_management_enabled = True + diff --git a/custom_nodes/ComfyUI-LCS/.gitignore b/custom_nodes/ComfyUI-LCS/.gitignore new file mode 100644 index 0000000000000000000000000000000000000000..16bb595a9895bc9d4a1bd311eabf767b85ac2f9b --- /dev/null +++ b/custom_nodes/ComfyUI-LCS/.gitignore @@ -0,0 +1,2 @@ +.claude/ +__pycache__/ diff --git a/custom_nodes/ComfyUI-LCS/README.md b/custom_nodes/ComfyUI-LCS/README.md new file mode 100644 index 0000000000000000000000000000000000000000..3b60676e72bf1d195553323109217277904f8429 --- /dev/null +++ b/custom_nodes/ComfyUI-LCS/README.md @@ -0,0 +1,344 @@ +# ComfyUI-LCS + +Training-free color control via the **Latent Color Subspace**, plus **sharpness control** via a discovered sharpness subspace. + +> **Note:** This is an unofficial community implementation. For the official code, see [ExplainableML/LCS](https://github.com/ExplainableML/LCS). + +Based on ["The Latent Color Subspace"](https://arxiv.org/abs/2603.12261v1) (ICML 2026): color in diffusion model latent patch spaces lives in a **3D subspace** (PCA captures 100% color variance), while the remaining 61 dimensions encode structure and detail orthogonally. + +This plugin steers colors directly in the 3D LCS during diffusion sampling — no training, no LoRA, no post-processing. + +> [中文版 README](README_zh.md) + +## LCS vs Traditional Post-Processing + +LCS operates **during** diffusion sampling, not after — this is the key difference from traditional color grading (Photoshop, filters, etc.). 
+ +| | Traditional Post-Processing | LCS | +|---|---|---| +| **When** | After VAE decode, in pixel space | During sampling, in latent space | +| **Mechanism** | Color filter on the final image | Modifies 3D color subspace mid-generation | +| **Model awareness** | None — structure already locked | Model adapts to color shifts in subsequent steps | +| **Result** | Colors can look "painted on" | Colors look naturally intended by the model | + +For example: to get a warm orange sunset, post-processing tints everything orange (muddying shadows and skin tones), while LCS nudges the color subspace early in sampling so clouds, lighting, and reflections are *coherently* warm. + +The core insight: color and structure are **orthogonal** in the latent patch space — you can steer one without disturbing the other. + +## Tested Models + +| Model | Status | +|-------|--------| +| FLUX | Tested | +| FLUX2.klein | Tested | +| z-image | Tested | +| z-image-turbo | Tested | +| Wan (qwen-image) | Tested | +| LTX2.3 | Tested | + + +LCS calibrates per-VAE, so it should work with any model using a compatible VAE. Feel free to report results with other models. 
+ +## Features + +- **Color Steering** — Push colors toward any target color +- **Batch Multi-Color** — Different colors per batch item +- **Tone Adjustment** — Contrast, brightness, saturation, temperature with one-click presets +- **Color Anchor** — Zero-config color drift correction: self-anchor, reference-based, or spatial smoothing with auto mode +- **Sharpness Control** — Sharpen or blur during generation via a discovered sharpness subspace (PC1 explains ~97% variance) +- **Localized Control** — Optional mask for region-specific changes +- **Latent Color Preview** — Visualize color structure without VAE decoding +- **Step Observer** — Per-step color previews for debugging + +## Installation + +```bash +cd ComfyUI/custom_nodes +git clone https://github.com/facok/ComfyUI-LCS.git +``` + +Dependencies (usually already present in ComfyUI): + +```bash +pip install einops safetensors +``` + +## Quick Start + +### Basic Color Control + +``` +LCS Load Data → LCS Color Intervene → KSampler + ↑ + (pick a color) +``` + +1. **LCS Load Data** — connect your VAE (auto-calibrates on first run) +2. **LCS Color Intervene** — connect MODEL and LCS_DATA, pick a target color +3. Connect the output MODEL to KSampler + +### Tone Adjustment + +``` +LCS Load Data → LCS Tone Adjust → KSampler +``` + +1. **LCS Load Data** → **LCS Tone Adjust** +2. Select a preset (e.g., "Cinematic") or adjust sliders manually + +![3d3c82eb0e89ed1608e40ac7a8cc3408](https://github.com/user-attachments/assets/62868e2d-0275-4801-a9bd-606bfea3ce2f) +![42541357](https://github.com/user-attachments/assets/fe22f09e-98ac-4281-ae40-f58232c7700f) + +### Sharpness Control + +``` +LCS Load Data ──→ LCS Sharpness Calibrate → LCS Sharpness Intervene → KSampler + ↑ lcs_data +``` + +1. **LCS Sharpness Calibrate** — connect VAE (auto-calibrates and caches). Optionally connect `lcs_data` from LCS Load Data to ensure sharpness edits don't affect color. +2. 
**LCS Sharpness Intervene** — connect MODEL and SHARPNESS_DATA, set strength + - Positive strength → sharper + - Negative strength → blurrier + - 0 → no change + +![Sharpness control example](https://github.com/user-attachments/assets/62f036e9-0bea-4cc0-9220-af4c2fb8fa76) + +### Multi-Color Batch + +``` +LCS Load Data → LCS Color Batch → KSampler + ↓ + batch_size → EmptyLatentImage +``` + +Enter comma-separated hex colors (e.g., `#FF0000,#00FF00,#0000FF`). Each color applies to one batch item. + +### Color Anchor (Zero-Config Drift Correction) + +``` +LCS Load Data → LCS Color Anchor → KSampler +``` + +1. **LCS Load Data** → **LCS Color Anchor** — connect MODEL and LCS_DATA +2. Set mode to **auto** (default) and leave intensity at default +3. Connect the output MODEL to KSampler + +That's it. In `auto` mode, the node automatically selects the correction strategy based on which optional inputs are connected: + +| Connected Inputs | Resolved Mode | Behavior | +|---|---|---| +| Nothing | self_anchor | Learns the image's color patterns early on, then prevents sudden color shifts | +| reference_image + vae | reference | Keeps generated colors close to your reference image | +| mask (no reference) | smooth | Smooths out color seams (great for inpainting) | + +Intensity is also derived automatically from measured drift — no manual tuning needed. + +> **When to use manual mode:** If you want full control, set mode to `smooth`, `reference`, or `self_anchor` explicitly and adjust the `intensity` slider (0–1). Auto mode is designed for zero-config "just works" usage. + +## Nodes + +### Calibration + +| Node | Description | +|------|-------------| +| **LCS Load Data** | Auto-calibrate and cache LCS color data per-VAE. Fingerprints VAE weights for automatic cache management. | +| **LCS Sharpness Calibrate** | Discover sharpness subspace via PCA on blur stimuli. Optionally connect `lcs_data` for color-orthogonal sharpness. | + +Calibration runs once per VAE and caches automatically. 
Subsequent runs load instantly. + +### Intervention + +| Node | Description | +|------|-------------| +| **LCS Color Intervene** | Steer colors toward a target. Supports Type I (LCS shift), Type II (HSL shift), or interpolated mode. | +| **LCS Color Batch** | Different target colors per batch item. Outputs `batch_size` for EmptyLatentImage. | +| **LCS Tone Adjust** | Contrast, brightness, saturation, temperature. Preset dropdown with real-time slider sync. | +| **LCS Color Anchor** | Correct color drift during sampling. Auto mode infers strategy and intensity from connected inputs. | +| **LCS Sharpness Intervene** | Control sharpness during generation. Positive = sharper, negative = blurrier. | + +### Observation + +| Node | Description | +|------|-------------| +| **LCS Preview Colors** | Decode latent colors to RGB preview without VAE decoding. | +| **LCS Step Observer** | Save per-step color preview PNGs to ComfyUI temp directory. | + +## Intervention Modes + +| Mode | Description | Best For | +|------|-------------|----------| +| **interpolated** (default) | Blends Type I and Type II using sigma | General use | +| **type_i** | Direct translation in 3D LCS space | Strong global color shifts | +| **type_ii** | Per-patch HSL interpolation via bicone geometry | Precise local color control | + +## Key Parameters + +### Color Intervention +- **strength** (0.0–2.0): Intervention intensity. 1.0 = full, 0.0 = none. +- **start_step / end_step**: Step range for intervention. Paper optimal: steps 8–10 of 50. +- **mask**: Optional. Downsampled to patch grid for localized control. + +### Sharpness Intervention +- **strength** (-5.0–5.0): Positive = sharper, negative = blurrier, 0 = no change. +- **start_step / end_step**: Step range (default 5–15). +- **mask**: Optional. Localized sharpness control. + +> **Tip for distilled models**: Step-distilled models (e.g., z-image-turbo) use far fewer steps, so intervention should start earlier — even from step 0. 
+ +### Color Anchor + +Sometimes diffusion models produce unexpected color shifts during sampling — a blue sky suddenly turns purple, or inpainting leaves visible color seams. The Color Anchor node fixes these problems by monitoring and correcting colors as the image is being generated. + +**Modes:** + +| Mode | What it does | When to use | +|------|-------------|----------| +| **auto** (default) | Looks at what you connected and picks the best strategy for you | Just want it to work, no config needed | +| **self_anchor** | Watches how colors evolve in early steps, then prevents sudden color jumps in later steps | General color stability, no reference needed | +| **reference** | Keeps the generated image's colors close to a reference image you provide | "Make it look like this photo's color palette" | +| **smooth** | Smooths out abrupt color boundaries between regions | Fixing visible seams after inpainting | + +**How auto mode picks for you:** + +1. **Which strategy?** Based on what you plugged in: + - Connected a reference image + VAE → uses `reference` + - Connected a mask (but no reference) → uses `smooth` + - Connected nothing extra → uses `self_anchor` +2. **How strong?** The node measures how much color drift is actually happening, then sets the correction strength accordingly. Big drift → stronger fix. Small drift → gentle touch. The range is 0.15–0.6, so it never over-corrects or does nothing. + +**What happens during sampling:** + +The node runs at every sampling step but doesn't always intervene. It automatically figures out which steps are safe to correct: + +1. **Early steps** (image is mostly noise) — Too early to fix colors without creating artifacts. Skipped. In self_anchor mode, the node uses these steps to *learn* the image's color patterns. +2. **Middle steps** (image is taking shape) — The sweet spot. The node applies corrections here, ramping smoothly in and out to avoid sudden changes. +3. 
**Late steps** (fine details) — Corrections would disturb fine detail. Skipped. + +Only colors are modified — structure, texture, and detail are never touched. + +**Parameters:** + +- **mode**: `auto`, `smooth`, `reference`, or `self_anchor` +- **intensity** (0.0–1.0): How strong the correction is. In `auto` mode this is determined automatically. Set to 0 to disable the node entirely. +- **vae** (optional): Needed for `reference` mode to encode the reference image +- **reference_image** (optional): The image whose colors you want to match +- **mask** (optional): Only correct colors inside the masked area + +## Tone Presets + +Select a preset — sliders update in real-time. Tweak after selecting for fine-tuning. Select **Custom** to set values manually. + +| Preset | Contrast | Brightness | Saturation | Temperature | +|--------|----------|------------|------------|-------------| +| Base | 1.0 | 0.0 | 1.0 | 0.0 | +| Cinematic | 1.20 | -0.05 | 0.90 | 0.05 | +| HDR | 1.40 | 0.0 | 1.20 | 0.0 | +| Vivid | 1.10 | 0.0 | 1.50 | 0.0 | +| Dramatic | 1.50 | -0.10 | 0.85 | 0.0 | +| Low Key | 1.30 | -0.20 | 0.80 | 0.0 | +| High Key | 0.80 | 0.20 | 0.90 | 0.0 | +| Warm | 1.05 | 0.03 | 1.10 | 0.30 | +| Cool | 1.05 | 0.0 | 1.05 | -0.30 | +| Desaturated | 1.0 | 0.0 | 0.40 | 0.0 | + +## How It Works + +### Color (LCS) + +1. **Project** — Convert denoised prediction to 64D patch space, project onto 3D LCS basis +2. **Decompose** — Separate 3D color coordinates from the 61D structural residual +3. **Normalize** — Transform to reference timestep (t=50) using learned alpha/beta statistics +4. **Manipulate** — Shift colors, adjust tone, or apply other transformations in 3D LCS +5. **Reconstruct** — Denormalize, add back the preserved 61D residual, convert to latent space + +The 61D residual (structure, texture, detail) is never modified — only the 3D color subspace is touched. + +### Sharpness + +Sharpness lives in a separate subspace orthogonal to color: + +1. 
**Calibrate** — Generate grayscale noise images at multiple blur levels, VAE-encode, PCA on color-removed patch vectors. PC1 captures ~97% of sharpness variance. +2. **Intervene** — Add `strength * pc1_direction` to each patch. Since pc1_direction is orthogonal to color (calibrated with LCS removal) and DC-free (per-vector zero-mean before PCA), this modifies only spatial frequency content without affecting color or brightness. + +### Color Anchor + +The Color Anchor stabilizes colors without pushing them toward a specific target — it prevents drift from what the model is already generating: + +1. **Decide when to act** — The node checks each sampling step: is the image still mostly noise (too early), taking shape (good time to correct), or nearly finished (too late)? It only corrects during the safe middle window. +2. **Learn the color pattern** (self_anchor) — During early noisy steps, the node watches how colors relate to their neighbors and builds a running average of these relationships. This is more reliable than tracking absolute colors, which shift naturally as the image forms. +3. **Measure drift** — On the first correction step, the node measures how much the colors have actually drifted (varies by mode: step-to-step jumps, distance from reference, or spatial roughness). This sets the correction strength in auto mode. +4. **Apply gentle corrections** — Corrections ramp smoothly in and out (no sudden jumps). Each mode corrects differently: self_anchor fixes patches that deviate from learned patterns, reference pulls toward the reference image's colors, smooth blurs out sharp color boundaries. +5. **Preserve everything else** — As with all LCS operations, only the 3D color coordinates change. Structure, texture, and detail are untouched. 
+ +## File Structure + +``` +ComfyUI-LCS/ +├── __init__.py # Entry point (V3 + V2 compat) +├── requirements.txt +├── core/ +│ ├── adaptive.py # Adaptive scheduling (phases, envelopes, drift estimation) +│ ├── bilateral.py # Bilateral filter for LCS color smoothing +│ ├── calibration.py # PCA calibration pipeline (color) +│ ├── color_space.py # Bicone LCS ↔ HSL mapping +│ ├── defaults.py # Alpha/beta tables from paper +│ ├── lcs_data.py # LCSData dataclass +│ ├── patchify.py # Patch ↔ latent conversion +│ ├── relationships.py # Local color relationship analysis & anomaly detection +│ ├── sampling.py # Shared constants & step utilities +│ ├── sharpness.py # Sharpness subspace calibration +│ └── timestep.py # Sigma/timestep utilities +├── nodes/ +│ ├── anchor.py # LCSColorAnchor (adaptive color drift correction) +│ ├── calibrate.py # LCSLoadData (auto-calibrate + cache) +│ ├── intervene.py # LCSColorIntervene, LCSColorBatch, LCSToneAdjust +│ ├── observe.py # LCSPreviewColors, LCSStepObserver +│ └── sharpen.py # LCSSharpnessCalibrate, LCSSharpnessIntervene +├── data/ # Cached calibration files +└── web/js/ + └── tone_preset.js # Frontend preset sync +``` + +## Changelog + +### 2026-03-21 +- **Color Anchor: auto mode** — New `auto` mode that infers correction strategy (self_anchor / reference / smooth) from connected inputs and derives intensity from measured drift. Zero-config usage. +- **Color Anchor: adaptive scheduling** — Phase assignment (observe/correct/skip) and strength envelope are derived from the sigma schedule at runtime. + +### 2026-03-20 +- **Sharpness Control** — New sharpness subspace discovered via PCA on blur stimuli. `LCS Sharpness Calibrate` + `LCS Sharpness Intervene` nodes. PC1 explains ~97% variance, orthogonal to color. +- **Color-orthogonal sharpness** — Optional `lcs_data` input removes color component during sharpness calibration, preventing color shift. 
+ +### 2026-03-19 +- **Video VAE support (Wan)** — Handle 5D video latents in patchify/unpatchify. Per-image VAE encoding fallback for video VAEs. +- **LTXV compatibility** — Pad odd spatial dims in patchify, handle 3D tensors, skip gracefully for incompatible formats. +- **FLUX2 support** — Auto-detect 128-channel latents in unpatchify. +- **Universal latent format** — Use model's `latent_format` for space conversion instead of hardcoded FLUX constants. + +### 2026-03-18 +- **Tone Adjust** — `LCS Tone Adjust` node with contrast, brightness, saturation, temperature sliders. 10 presets with frontend real-time sync. +- **Color temperature** — Warm/cool shift along LCS blue-yellow axis. +- **Bicone HSL geometry** — Correct Type II intervention via bicone LCS-to-HSL mapping. + +### 2026-03-17 +- **Initial release** — Color steering (Type I + Type II + interpolated), batch multi-color, localized mask control, latent color preview, step observer. Per-VAE auto-calibration with caching. + +## Citation + +Official repository: [ExplainableML/LCS](https://github.com/ExplainableML/LCS) + +```bibtex +@article{pach2026latentcolorsubspace, + title={The Latent Color Subspace: Emergent Order in High-Dimensional Chaos}, + author={Mateusz Pach and Jessica Bader and Quentin Bouniot and Serge Belongie and Zeynep Akata}, + journal={arxiv}, + year={2026} +} +``` + +## Acknowledgments + +Thanks to Mateusz Pach, Jessica Bader, Quentin Bouniot, Serge Belongie, and Zeynep Akata for their research making training-free color control possible. 
+ +## License + +MIT diff --git a/custom_nodes/ComfyUI-LCS/README_zh.md b/custom_nodes/ComfyUI-LCS/README_zh.md new file mode 100644 index 0000000000000000000000000000000000000000..74404355d02edb11207fe03fad45b85f80fcfc22 --- /dev/null +++ b/custom_nodes/ComfyUI-LCS/README_zh.md @@ -0,0 +1,343 @@ +# ComfyUI-LCS + +基于**潜在颜色子空间**(Latent Color Subspace)的免训练颜色控制,以及基于发现的**锐度子空间**的锐度控制。 + +> **注意:** 本项目为非官方社区实现。官方代码见 [ExplainableML/LCS](https://github.com/ExplainableML/LCS)。 + +基于论文 ["The Latent Color Subspace"](https://arxiv.org/abs/2603.12261v1)(ICML 2026):扩散模型潜在 patch 空间中的颜色完全存在于一个 **3 维子空间**(PCA 捕获 100% 颜色方差),剩余 61 维编码结构与细节,与颜色正交。 + +本插件在扩散采样过程中直接操作 3D LCS 控制颜色——无需训练、无需 LoRA、无需后处理。 + +> [English README](README.md) + +## LCS 与传统后处理调色的区别 + +LCS 在扩散采样**过程中**操作,而非生成之后——这是与传统调色(Photoshop、滤镜等)的根本区别。 + +| | 传统后处理 | LCS | +|---|---|---| +| **时机** | VAE 解码后,像素空间 | 采样过程中,潜在空间 | +| **机制** | 对成品图像施加颜色滤镜 | 在生成中途修改 3D 颜色子空间 | +| **模型感知** | 无——结构已定型 | 模型在后续步骤中自适应颜色偏移 | +| **效果** | 颜色容易显得"涂上去的" | 颜色与内容自然融合 | + +例:想要暖橙色日落,后处理会给全图叠橙色(阴影和肤色变脏),而 LCS 在采样早期推动颜色子空间,模型生成的云层、光照、反射与暖色调**内在一致**。 + +核心发现:颜色与结构在潜在 patch 空间中**正交**——可以单独控制颜色而不干扰结构。 + +## 已测试模型 + +| 模型 | 状态 | +|------|------| +| FLUX | 已测试 | +| FLUX2.klein | 已测试 | +| z-image | 已测试 | +| z-image-turbo | 已测试 | +| Wan (qwen-image) | 已测试 | +| LTX2.3 | 已测试 | + +LCS 按 VAE 校准,理论上适用于任何使用兼容 VAE 架构的模型。欢迎反馈其他模型的测试结果。 + +## 功能 + +- **颜色引导** — 将颜色推向任意目标色 +- **批量多色** — 为批次中每张图像指定不同颜色 +- **色调调整** — 对比度、亮度、饱和度、色温,支持一键预设 +- **颜色锚定** — 零配置颜色漂移校正:自锚定、参考图锚定、空间平滑,支持全自动模式 +- **锐度控制** — 在生成过程中增强或减弱锐度,基于发现的锐度子空间(PC1 解释 ~97% 方差) +- **局部控制** — 可选遮罩,实现区域性变化 +- **潜在颜色预览** — 无需 VAE 解码即可可视化颜色结构 +- **步骤观察器** — 保存每步颜色预览,用于调试 + +## 安装 + +```bash +cd ComfyUI/custom_nodes +git clone https://github.com/facok/ComfyUI-LCS.git +``` + +依赖(通常 ComfyUI 已自带): + +```bash +pip install einops safetensors +``` + +## 快速开始 + +### 基本颜色控制 + +``` +LCS Load Data → LCS Color Intervene → KSampler + ↑ + (选择颜色) +``` + +1. **LCS Load Data** — 连接 VAE(首次运行自动校准) +2. 
**LCS Color Intervene** — 连接 MODEL 和 LCS_DATA,选择目标颜色 +3. 将输出 MODEL 连接到 KSampler + +### 色调调整 + +``` +LCS Load Data → LCS Tone Adjust → KSampler +``` + +1. **LCS Load Data** → **LCS Tone Adjust** +2. 选择预设(如 "Cinematic")或手动调整滑条 + +![3d3c82eb0e89ed1608e40ac7a8cc3408](https://github.com/user-attachments/assets/62868e2d-0275-4801-a9bd-606bfea3ce2f) +![42541357](https://github.com/user-attachments/assets/fe22f09e-98ac-4281-ae40-f58232c7700f) +### 锐度控制 + +``` +LCS Load Data ──→ LCS Sharpness Calibrate → LCS Sharpness Intervene → KSampler + ↑ lcs_data +``` + +1. **LCS Sharpness Calibrate** — 连接 VAE(首次运行自动校准并缓存)。可选连接 `lcs_data`(来自 LCS Load Data),确保锐度编辑不影响颜色。 +2. **LCS Sharpness Intervene** — 连接 MODEL 和 SHARPNESS_DATA,设置强度 + - 正值 → 更锐利 + - 负值 → 更模糊 + - 0 → 无变化 +![89814728](https://github.com/user-attachments/assets/62f036e9-0bea-4cc0-9220-af4c2fb8fa76) + +### 批量多色生成 + +``` +LCS Load Data → LCS Color Batch → KSampler + ↓ + batch_size → EmptyLatentImage +``` + +输入逗号分隔的十六进制颜色(如 `#FF0000,#00FF00,#0000FF`),每个颜色对应一个批次项。 + +### 颜色锚定(零配置漂移校正) + +``` +LCS Load Data → LCS Color Anchor → KSampler +``` + +1. **LCS Load Data** → **LCS Color Anchor** — 连接 MODEL 和 LCS_DATA +2. 模式设为 **auto**(默认),intensity 保持默认值 +3. 
将输出 MODEL 连接到 KSampler + +完成。在 `auto` 模式下,节点根据连接的可选输入自动选择校正策略: + +| 已连接输入 | 解析模式 | 行为 | +|---|---|---| +| 无 | self_anchor | 在早期学习图像的颜色规律,然后防止突然的颜色偏移 | +| reference_image + vae | reference | 让生成的颜色贴近你的参考图 | +| mask(无参考图) | smooth | 平滑颜色接缝(很适合修复/补绘) | + +intensity 也会根据实测漂移自动推导——无需手动调参。 + +> **手动模式:** 如果需要完全控制,可以将模式设为 `smooth`、`reference` 或 `self_anchor`,并手动调节 `intensity` 滑条(0–1)。auto 模式适合零配置「开箱即用」场景。 + +## 节点一览 + +### 校准 + +| 节点 | 说明 | +|------|------| +| **LCS Load Data** | 自动校准并按 VAE 缓存 LCS 颜色数据。通过 VAE 权重指纹自动管理缓存。 | +| **LCS Sharpness Calibrate** | 通过模糊刺激 PCA 发现锐度子空间。可选连接 `lcs_data` 使锐度正交于颜色。 | + +每个 VAE 只需校准一次,结果自动缓存,后续运行瞬时加载。 + +### 干预 + +| 节点 | 说明 | +|------|------| +| **LCS Color Intervene** | 将颜色引导至目标色。支持 Type I(LCS 平移)、Type II(HSL 偏移)或插值模式。 | +| **LCS Color Batch** | 每个批次项施加不同目标颜色。输出 `batch_size` 可连接 EmptyLatentImage。 | +| **LCS Tone Adjust** | 对比度、亮度、饱和度、色温调整。预设下拉菜单,滑条实时同步。 | +| **LCS Color Anchor** | 采样过程中校正颜色漂移。auto 模式根据连接输入自动推断策略和强度。 | +| **LCS Sharpness Intervene** | 在生成过程中控制锐度。正值 = 更锐利,负值 = 更模糊。 | + +### 观察 + +| 节点 | 说明 | +|------|------| +| **LCS Preview Colors** | 将潜在颜色解码为 RGB 预览图,无需 VAE 解码。 | +| **LCS Step Observer** | 将每步颜色预览 PNG 保存至 ComfyUI 临时目录。 | + +## 干预模式 + +| 模式 | 说明 | 适用场景 | +|------|------|----------| +| **interpolated**(默认) | 以 sigma 为权重混合 Type I 和 Type II | 通用场景 | +| **type_i** | 3D LCS 空间中的直接平移 | 强烈的全局颜色偏移 | +| **type_ii** | 通过双锥几何进行逐 patch HSL 插值 | 精确的局部颜色控制 | + +## 关键参数 + +### 颜色干预 +- **strength**(0.0–2.0):干预强度。1.0 = 完整干预,0.0 = 无干预。 +- **start_step / end_step**:干预步骤范围。论文最优:50 步中的第 8–10 步。 +- **mask**:可选。下采样至 patch 网格分辨率,用于局部控制。 + +### 锐度干预 +- **strength**(-5.0–5.0):正值 = 更锐利,负值 = 更模糊,0 = 无变化。 +- **start_step / end_step**:干预步骤范围(默认 5–15)。 +- **mask**:可选。用于局部锐度控制。 + +> **步数蒸馏模型提示**:对于步数蒸馏模型(如 z-image-turbo),总步数很少,干预应从更早的步骤开始——甚至可以从第 0 步就开始干预。 + +### 颜色锚定 + +扩散模型在采样过程中有时会出现意想不到的颜色偏移——蓝天突然变紫,或者修复/补绘后留下明显的颜色接缝。颜色锚定节点在图像生成过程中监控和修正这些问题。 + +**模式:** + +| 模式 | 功能 | 适用场景 | +|------|------|----------| +| **auto**(默认) | 根据你连接的输入自动选最合适的策略 | 不想调参,开箱即用 
| +| **self_anchor** | 在早期步骤观察颜色变化规律,在后续步骤防止突然的颜色跳变 | 通用颜色稳定,不需要参考图 | +| **reference** | 让生成图像的颜色贴近你提供的参考图 | 「我想要这张照片的配色风格」 | +| **smooth** | 平滑区域之间的突兀颜色边界 | 修复/补绘后消除接缝 | + +**auto 模式如何自动选择:** + +1. **用哪种策略?** 看你连了什么: + - 连了参考图 + VAE → 用 `reference` + - 连了遮罩(没有参考图)→ 用 `smooth` + - 什么额外输入都没连 → 用 `self_anchor` +2. **修正多强?** 节点会测量实际的颜色漂移幅度,据此自动设置校正强度。漂移大 → 修正更强;漂移小 → 轻轻一碰。范围是 0.15–0.6,既不会矫枉过正,也不会毫无作用。 + +**采样过程中发生了什么:** + +节点在每个采样步都会运行,但不会每步都干预。它自动判断哪些步骤适合校正: + +1. **早期步骤**(图像基本是噪声)— 太早修正颜色会产生伪影,跳过。在 self_anchor 模式下,节点利用这些步骤*学习*图像的颜色规律。 +2. **中间步骤**(图像逐渐成形)— 最佳校正时机。节点在这里施加校正,平滑地渐入渐出,避免突变。 +3. **后期步骤**(精细细节)— 校正会干扰细节,跳过。 + +只修改颜色——结构、纹理、细节始终不受影响。 + +**参数:** + +- **mode**:`auto`、`smooth`、`reference` 或 `self_anchor` +- **intensity**(0.0–1.0):校正强度。auto 模式下自动决定。设为 0 可完全禁用此节点。 +- **vae**(可选):reference 模式需要用它来编码参考图 +- **reference_image**(可选):你想匹配其颜色的参考图 +- **mask**(可选):只在遮罩区域内校正颜色 + +## 色调预设 + +选择预设后滑条实时更新。可在预设基础上微调。选择 **Custom** 可完全手动设置。 + +| 预设 | 对比度 | 亮度 | 饱和度 | 色温 | +|------|--------|------|--------|------| +| Base | 1.0 | 0.0 | 1.0 | 0.0 | +| Cinematic | 1.20 | -0.05 | 0.90 | 0.05 | +| HDR | 1.40 | 0.0 | 1.20 | 0.0 | +| Vivid | 1.10 | 0.0 | 1.50 | 0.0 | +| Dramatic | 1.50 | -0.10 | 0.85 | 0.0 | +| Low Key | 1.30 | -0.20 | 0.80 | 0.0 | +| High Key | 0.80 | 0.20 | 0.90 | 0.0 | +| Warm | 1.05 | 0.03 | 1.10 | 0.30 | +| Cool | 1.05 | 0.0 | 1.05 | -0.30 | +| Desaturated | 1.0 | 0.0 | 0.40 | 0.0 | + +## 工作原理 + +### 颜色(LCS) + +1. **投影** — 将去噪预测转换到 64D patch 空间,投影到 3D LCS 基底 +2. **分解** — 将 3D 颜色坐标与 61D 结构残差分离 +3. **归一化** — 使用学习的 alpha/beta 统计量变换至参考时间步(t=50) +4. **操作** — 在 3D LCS 中偏移颜色、调整色调或进行其他变换 +5. **重建** — 反归一化,加回保留的 61D 残差,转换回潜在空间 + +61D 残差(结构、纹理、细节)始终不被修改——只有 3D 颜色子空间会被改变。 + +### 锐度 + +锐度存在于与颜色正交的独立子空间中: + +1. **校准** — 生成灰度噪声图像,应用多级高斯模糊,VAE 编码后对去除颜色分量的 patch 向量做 PCA。PC1 捕获 ~97% 的锐度方差。 +2. 
**干预** — 在每个 patch 上沿 `strength * pc1_direction` 方向添加偏移。由于 pc1_direction 与颜色正交(校准时已移除 LCS 分量)且无直流分量(PCA 前做了逐向量零均值化),因此只改变空间频率内容,不影响颜色或亮度。 + +### 颜色锚定 + +颜色锚定的作用是稳定颜色,而不是把颜色推向某个特定目标——它防止模型已经在生成的颜色发生偏移: + +1. **判断何时介入** — 节点检查每个采样步:图像还是一片噪声(太早)、正在成形(适合校正)、还是快完成了(太晚)?只在安全的中间窗口进行校正。 +2. **学习颜色规律**(self_anchor)— 在早期噪声较大的步骤中,节点观察每个区域的颜色与邻居之间的关系,建立一个动态平均值。比起追踪绝对颜色值,这种「相对关系」更可靠,因为绝对颜色在图像成形过程中本来就会自然变化。 +3. **测量漂移** — 在第一个校正步,节点测量颜色实际漂移了多少(根据模式不同:步间跳变幅度、与参考图的差距、或空间粗糙程度)。这决定了 auto 模式下的校正强度。 +4. **温和地修正** — 校正平滑地渐入渐出(不会突变)。每种模式的修正方式不同:self_anchor 修复偏离已学规律的区域,reference 拉近与参考图的颜色,smooth 模糊掉尖锐的颜色边界。 +5. **保留其他一切** — 与所有 LCS 操作一样,只修改 3D 颜色坐标,结构、纹理、细节完全不受影响。 + +## 文件结构 + +``` +ComfyUI-LCS/ +├── __init__.py # 入口(V3 + V2 兼容) +├── requirements.txt +├── core/ +│ ├── adaptive.py # 自适应调度(阶段、包络、漂移估计) +│ ├── bilateral.py # LCS 颜色平滑的双边滤波 +│ ├── calibration.py # PCA 校准流程(颜色) +│ ├── color_space.py # 双锥 LCS ↔ HSL 映射 +│ ├── defaults.py # 论文中的 Alpha/beta 表 +│ ├── lcs_data.py # LCSData 数据类 +│ ├── patchify.py # Patch ↔ 潜在空间转换 +│ ├── relationships.py # 局部颜色关系分析与异常检测 +│ ├── sampling.py # 共享常量和步骤工具 +│ ├── sharpness.py # 锐度子空间校准 +│ └── timestep.py # Sigma/时间步工具 +├── nodes/ +│ ├── anchor.py # LCSColorAnchor(自适应颜色漂移校正) +│ ├── calibrate.py # LCSLoadData(自动校准 + 缓存) +│ ├── intervene.py # LCSColorIntervene, LCSColorBatch, LCSToneAdjust +│ ├── observe.py # LCSPreviewColors, LCSStepObserver +│ └── sharpen.py # LCSSharpnessCalibrate, LCSSharpnessIntervene +├── data/ # 缓存的校准文件 +└── web/js/ + └── tone_preset.js # 前端预设同步 +``` + +## 更新日志 + +### 2026-03-21 +- **颜色锚定:auto 模式** — 新增 `auto` 模式,根据连接的输入自动推断校正策略(self_anchor / reference / smooth),并根据实测漂移推导强度。零配置使用。 +- **颜色锚定:自适应调度** — 阶段分配(observe/correct/skip)和强度包络在运行时从 sigma 调度表推导。 + +### 2026-03-20 +- **锐度控制** — 通过模糊刺激 PCA 发现锐度子空间。新增 `LCS Sharpness Calibrate` + `LCS Sharpness Intervene` 节点。PC1 解释 ~97% 方差,与颜色正交。 +- **颜色正交锐度** — 可选连接 `lcs_data`,在锐度校准时移除颜色分量,防止颜色偏移。 + +### 2026-03-19 +- **视频 VAE 支持(Wan)** — 在 patchify/unpatchify 中处理 5D 视频潜在表示。视频 VAE 自动回退到逐帧编码。 +- **LTXV 
兼容** — patchify 中填充奇数空间维度,处理 3D 张量,不兼容格式时优雅跳过。 +- **FLUX2 支持** — unpatchify 自动检测 128 通道潜在表示。 +- **通用潜在格式** — 使用模型的 `latent_format` 进行空间转换,不再硬编码 FLUX 常量。 + +### 2026-03-18 +- **色调调整** — `LCS Tone Adjust` 节点,支持对比度、亮度、饱和度、色温滑条。10 个预设,前端实时同步。 +- **色温控制** — 沿 LCS 蓝-黄轴的暖/冷偏移。 +- **双锥 HSL 几何** — 通过双锥 LCS-to-HSL 映射实现正确的 Type II 干预。 + +### 2026-03-17 +- **首次发布** — 颜色引导(Type I + Type II + 插值模式)、批量多色、局部遮罩控制、潜在颜色预览、步骤观察器。按 VAE 自动校准并缓存。 + +## 引用 + +官方仓库:[ExplainableML/LCS](https://github.com/ExplainableML/LCS) + +```bibtex +@article{pach2026latentcolorsubspace, + title={The Latent Color Subspace: Emergent Order in High-Dimensional Chaos}, + author={Mateusz Pach and Jessica Bader and Quentin Bouniot and Serge Belongie and Zeynep Akata}, + journal={arxiv}, + year={2026} +} +``` + +## 致谢 + +感谢 Mateusz Pach、Jessica Bader、Quentin Bouniot、Serge Belongie 和 Zeynep Akata,他们的研究使免训练颜色控制成为可能。 + +## 许可证 + +MIT diff --git a/custom_nodes/ComfyUI-LCS/__init__.py b/custom_nodes/ComfyUI-LCS/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..70ba7471ce85cc5cf323fe2e0872df1f146d1418 --- /dev/null +++ b/custom_nodes/ComfyUI-LCS/__init__.py @@ -0,0 +1,53 @@ +"""ComfyUI-LCS: The Latent Color Subspace — training-free color control for FLUX. 
class LCSExtension(ComfyExtension):
    """ComfyUI V3 extension that exposes the complete set of LCS nodes."""

    async def get_node_list(self) -> list[type[io.ComfyNode]]:
        """Return every LCS node class, in registration order."""
        # Grouped by role for readability; the concatenation order below is
        # the order in which the classes are handed back to the caller.
        calibration = [LCSLoadData]
        color = [LCSColorIntervene, LCSColorBatch, LCSToneAdjust]
        observers = [LCSPreviewColors, LCSStepObserver]
        sharpness = [LCSSharpnessCalibrate, LCSSharpnessIntervene]
        anchoring = [LCSColorAnchor]
        return [*calibration, *color, *observers, *sharpness, *anchoring]
def compute_amplification(sigma_val, device=None):
    """Return the amplification factor A = max_k(beta_50[k] / beta_t(sigma)[k]).

    The factor describes how strongly normalizing to the reference timestep
    (t=50) inflates noise relative to signal: a large value means corrections
    are risky, a small value means they are safe.

    Args:
        sigma_val: Sigma in [0, 1] (1 = noise, 0 = clean). Values outside the
            range are clamped.
        device: Accepted for interface compatibility; not used by this
            function.

    Returns:
        The largest per-component ratio as a Python float.
    """
    table = get_beta_table()  # expected [51, 3] per the upstream comment
    reference_beta = table[50]

    # Map sigma onto the paper's timestep axis: sigma=1 -> t=0, sigma=0 -> t=50.
    sigma_clamped = max(0.0, min(1.0, sigma_val))
    paper_t = max(0.0, min(50.0, 50.0 * (1.0 - sigma_clamped)))

    # Linear interpolation between the two neighbouring table rows; the upper
    # index is capped so that t == 50 does not read past the last row.
    lower = int(paper_t)
    upper = min(lower + 1, 50)
    weight = paper_t - lower
    interpolated = (1.0 - weight) * table[lower] + weight * table[upper]

    # Guard against division by (near-)zero before taking the ratio.
    return (reference_beta / interpolated.clamp(min=1e-8)).max().item()
def compute_strength_envelope(n_correction_steps):
    """Sinusoidal bell envelope over the correction steps.

    The bell is sampled at the midpoint of each step,
    ``sin(pi * (i + 0.5) / n)`` for ``i`` in ``0..n-1``, so the strength
    ramps in and out smoothly without abrupt on/off at the phase boundaries.

    Sampling at ``i / (n - 1)`` instead would pin both endpoints to zero (the
    last one even lands slightly below zero in float32). A two-step window
    would then apply no correction at all and a three-step window only one
    effective step, which silently disables the correction on short schedules.
    Midpoint sampling keeps every step strictly positive while preserving the
    symmetric bell shape.

    Args:
        n_correction_steps: Number of steps in the correction window.

    Returns:
        1D float32 tensor of length ``n_correction_steps`` with values in
        (0, 1]. Empty for ``n_correction_steps <= 0``; ``[1.0]`` for a single
        step.
    """
    if n_correction_steps <= 0:
        return torch.zeros(0)
    if n_correction_steps == 1:
        # Explicit branch keeps the single-step strength at exactly 1.0.
        return torch.ones(1)
    n = n_correction_steps
    midpoints = torch.arange(n, dtype=torch.float32) + 0.5
    return torch.sin(math.pi * midpoints / n)
+ """ + B = c.shape[0] + # Reshape to spatial grid + grid = c.reshape(B, h_len, w_len, 3) # [B, H, W, 3] + + # Pad by kernel_radius (replicate) — pad last two spatial dims + # F.pad on [B, H, W, 3]: need to pad dims -3 and -2 (H and W) + # Permute to [B, 3, H, W] for F.pad, then back + grid_chw = grid.permute(0, 3, 1, 2) # [B, 3, H, W] + r = kernel_radius + padded = F.pad(grid_chw, (r, r, r, r), mode="replicate") # [B, 3, H+2r, W+2r] + + # Precompute spatial Gaussian weights for each offset in kernel + inv_2ss = -0.5 / (sigma_spatial * sigma_spatial) + inv_2sc = -0.5 / (sigma_color * sigma_color) + + # Accumulate weighted sum + weight_sum = torch.zeros(B, 1, h_len, w_len, device=c.device, dtype=c.dtype) + value_sum = torch.zeros(B, 3, h_len, w_len, device=c.device, dtype=c.dtype) + + for dy in range(-r, r + 1): + for dx in range(-r, r + 1): + # Spatial weight (constant per offset) + spatial_dist_sq = float(dy * dy + dx * dx) + w_spatial = math.exp(spatial_dist_sq * inv_2ss) + + # Extract neighbor values from padded grid + y_start = r + dy + x_start = r + dx + neighbor = padded[:, :, y_start:y_start + h_len, x_start:x_start + w_len] # [B, 3, H, W] + + # Color distance weight (per-pixel) + diff = neighbor - grid_chw # [B, 3, H, W] + color_dist_sq = (diff * diff).sum(dim=1, keepdim=True) # [B, 1, H, W] + w_color = torch.exp(color_dist_sq * inv_2sc) # [B, 1, H, W] + + w = w_spatial * w_color + weight_sum.add_(w) + value_sum.add_(w * neighbor) + + # Normalize + result = value_sum / weight_sum.clamp(min=1e-8) # [B, 3, H, W] + + # Back to [B, L, 3] + return result.permute(0, 2, 3, 1).reshape(B, -1, 3) diff --git a/custom_nodes/ComfyUI-LCS/core/calibration.py b/custom_nodes/ComfyUI-LCS/core/calibration.py new file mode 100644 index 0000000000000000000000000000000000000000..a9fbc80c0f17fcc5282582577c35ccb8f610a4f2 --- /dev/null +++ b/custom_nodes/ComfyUI-LCS/core/calibration.py @@ -0,0 +1,214 @@ +"""PCA calibration from FLUX VAE: compute LCS basis, mean, and anchor 
def vae_fingerprint(vae) -> str:
    """Return an 8-character hex fingerprint derived from one VAE weight tensor.

    The fingerprint lets calibration data be cached per VAE, so different VAE
    models end up with separate calibration files.

    The tensor hashed is the first key (in sorted order) whose name contains
    both "decoder" and "weight"; if no such key exists, the first key of the
    sorted state dict is used instead. An empty state dict raises IndexError.
    """
    state = vae.get_sd()
    ordered_keys = sorted(state.keys())

    decoder_weight_keys = [
        name for name in ordered_keys if "decoder" in name and "weight" in name
    ]
    chosen = decoder_weight_keys[0] if decoder_weight_keys else ordered_keys[0]

    # Hash the raw float32 bytes of the tensor.
    raw = state[chosen].cpu().float().numpy().tobytes()
    return hashlib.sha256(raw).hexdigest()[:8]
Encode 8 anchor colors → compute LCS coords + hue angles + + Returns: LCSData + """ + device = comfy.model_management.intermediate_device() + + print(f"\n[LCS Calibration] Starting calibration for {num_colors} colors...") + print(f"[LCS Calibration] Image size: {image_size}x{image_size}, Batch size: {batch_size}") + + # Step 1: Sample colors uniformly from HSV (full saturation, full value for diversity) + colors = [] + for i in range(num_colors): + # Uniform sampling in HSV + h = (i * 137.508) % 360.0 / 360.0 # Golden angle for uniform coverage + s = 0.3 + 0.7 * ((i * 73) % 100) / 100.0 # Vary saturation 0.3-1.0 + v = 0.3 + 0.7 * ((i * 47) % 100) / 100.0 # Vary value 0.3-1.0 + # HSV to RGB + r, g, b = _hsv_to_rgb(h, s, v) + colors.append((r, g, b)) + + # Step 2+3: Encode and average patches + vectors = [] + pbar = comfy.utils.ProgressBar(num_colors) + + num_batches = (num_colors + batch_size - 1) // batch_size + print(f"[LCS Calibration] Encoding {num_colors} color images in {num_batches} batches...") + + for batch_start in range(0, num_colors, batch_size): + batch_end = min(batch_start + batch_size, num_colors) + batch_colors = colors[batch_start:batch_end] + actual_batch = len(batch_colors) + + # Create solid color images [B, H, W, 3] (BHWC format for ComfyUI VAE) + imgs = torch.zeros(actual_batch, image_size, image_size, 3, dtype=torch.float32, device="cpu") + for j, (r, g, b) in enumerate(batch_colors): + imgs[j, :, :, 0] = r + imgs[j, :, :, 1] = g + imgs[j, :, :, 2] = b + + # VAE encode — try batch first, fall back to per-image for video VAEs + latent = vae.encode(imgs[:, :, :, :3]) + + # Squeeze video VAE temporal dim — calibration uses still images + if latent.ndim == 5: + latent = latent[:, :, 0, :, :] + + # Patchify → [B', L, D] + patches, _, _, _ = patchify(latent) + + # Average across patches → [B', D] + avg = patches.mean(dim=1).cpu() + + if avg.shape[0] == actual_batch: + # Normal VAE: batch encode worked + vectors.extend(avg.unbind(0)) + else: + # 
Video VAE or unexpected batch collapse — encode one by one + for k in range(actual_batch): + single = imgs[k:k+1, :, :, :3] + lat = vae.encode(single) + if lat.ndim == 5: + lat = lat[:, :, 0, :, :] + p, _, _, _ = patchify(lat) + vectors.append(p.mean(dim=1).cpu().squeeze(0)) + + pbar.update(actual_batch) + + # Stack all vectors: [N, 64] + X = torch.stack(vectors, dim=0).float() + print(f"[LCS Calibration] Collected {X.shape[0]} patch vectors of dimension {X.shape[1]}") + + # Step 4: PCA + print("[LCS Calibration] Computing PCA...") + mean = X.mean(dim=0) # [64] + X_centered = X - mean + # SVD for PCA + U, S, Vh = torch.linalg.svd(X_centered, full_matrices=False) + # Top 3 components: B = V[:, :3] (columns are principal directions) + basis = Vh[:3].T # [64, 3] (Vh rows are right singular vectors) + + # Variance explained + total_var = (S ** 2).sum() + explained = (S[:3] ** 2) / total_var + print(f"[LCS Calibration] Top 3 components explain {explained.sum():.1%} variance") + print(f"[LCS Calibration] PC1: {explained[0]:.1%}, PC2: {explained[1]:.1%}, PC3: {explained[2]:.1%}") + + # Step 5: Encode 8 anchor colors → LCS coords + print("[LCS Calibration] Encoding 8 anchor colors...") + anchor_lcs_list = [] + for i, (r, g, b) in enumerate(ANCHOR_COLORS_RGB): + img = torch.zeros(1, image_size, image_size, 3, dtype=torch.float32, device="cpu") + img[0, :, :, 0] = r + img[0, :, :, 1] = g + img[0, :, :, 2] = b + latent = vae.encode(img[:, :, :, :3]) + if latent.ndim == 5: + latent = latent[:, :, 0, :, :] + patches, _, _, _ = patchify(latent) + avg = patches.mean(dim=1).cpu().squeeze(0) # [64] + # Project to LCS + lcs_coord = (avg - mean) @ basis # [3] + anchor_lcs_list.append(lcs_coord) + + anchor_lcs = torch.stack(anchor_lcs_list, dim=0) # [8, 3] + + # Compute hue angles for 6 chromatic anchors + anchor_angles = _compute_anchor_angles(anchor_lcs, basis, mean) + + print(f"[LCS Calibration] Complete! 
def _hsv_to_rgb(h, s, v):
    """Convert an HSV triple (scalars in [0, 1]) to an ``(r, g, b)`` tuple."""
    # Effectively zero saturation: achromatic, every channel equals the value.
    if s < 1e-10:
        return v, v, v

    scaled = h * 6.0
    whole = int(scaled)
    sector = whole % 6      # which sixth of the hue wheel
    f = scaled - whole      # position inside that sector

    p = v * (1.0 - s)
    q = v * (1.0 - s * f)
    t = v * (1.0 - s * (1.0 - f))

    # One (r, g, b) arrangement per hue sector, indexed by `sector`.
    by_sector = (
        (v, t, p),
        (q, v, p),
        (p, v, t),
        (p, q, v),
        (t, p, v),
        (v, p, q),
    )
    return by_sector[sector]
def hex_to_hsl(hex_str):
    """Convert a hex color string to ``(h, s, l)``, each component in [0, 1].

    Accepted forms:
        * ``"#RRGGBB"`` or ``"RRGGBB"``
        * CSS-style shorthand ``"#RGB"`` / ``"RGB"``, where each digit is
          doubled (``"#F80"`` is read as ``"#FF8800"``)

    Surrounding whitespace is ignored, so entries taken from a
    comma-separated list such as ``"#FF0000, #00FF00"`` parse correctly.

    Raises:
        ValueError: If the string cannot be parsed as a hex color.
    """
    digits = hex_str.strip().lstrip("#")
    if len(digits) == 3:
        # Expand shorthand: every nibble is repeated to form a full byte.
        digits = "".join(ch * 2 for ch in digits)
    try:
        r, g, b = (int(digits[i:i + 2], 16) / 255.0 for i in (0, 2, 4))
    except ValueError as err:
        # Replace int()'s low-level message with one that names the input.
        raise ValueError(f"Invalid hex color: {hex_str!r}") from err
    return rgb_to_hsl(r, g, b)
l + + s = delta / (1.0 - abs(2.0 * l - 1.0)) if abs(2.0 * l - 1.0) < 1.0 else 0.0 + + if cmax == r: + h = ((g - b) / delta) % 6.0 + elif cmax == g: + h = (b - r) / delta + 2.0 + else: + h = (r - g) / delta + 4.0 + h = h / 6.0 + if h < 0: + h += 1.0 + + return h, max(0.0, min(1.0, s)), max(0.0, min(1.0, l)) + + +def hsl_to_rgb(h, s, l): + """Convert HSL [0,1] to RGB [0,1]. Works with scalars or tensors.""" + if isinstance(h, torch.Tensor): + return _hsl_to_rgb_tensor(h, s, l) + + c = (1.0 - abs(2.0 * l - 1.0)) * s + x = c * (1.0 - abs((h * 6.0) % 2.0 - 1.0)) + m = l - c / 2.0 + + h6 = h * 6.0 + if h6 < 1: + r, g, b = c, x, 0 + elif h6 < 2: + r, g, b = x, c, 0 + elif h6 < 3: + r, g, b = 0, c, x + elif h6 < 4: + r, g, b = 0, x, c + elif h6 < 5: + r, g, b = x, 0, c + else: + r, g, b = c, 0, x + + return r + m, g + m, b + m + + +def _hsl_to_rgb_tensor(h, s, l): + """Vectorized HSL→RGB for tensors.""" + c = _bicone_factor(l) * s + h6 = h * 6.0 + x = c * (1.0 - ((h6 % 2.0) - 1.0).abs()) + m = l - c / 2.0 + + r = torch.zeros_like(h) + g = torch.zeros_like(h) + b = torch.zeros_like(h) + + mask0 = h6 < 1 + mask1 = (h6 >= 1) & (h6 < 2) + mask2 = (h6 >= 2) & (h6 < 3) + mask3 = (h6 >= 3) & (h6 < 4) + mask4 = (h6 >= 4) & (h6 < 5) + mask5 = h6 >= 5 + + r[mask0] = c[mask0]; g[mask0] = x[mask0] + r[mask1] = x[mask1]; g[mask1] = c[mask1] + g[mask2] = c[mask2]; b[mask2] = x[mask2] + g[mask3] = x[mask3]; b[mask3] = c[mask3] + r[mask4] = x[mask4]; b[mask4] = c[mask4] + r[mask5] = c[mask5]; b[mask5] = x[mask5] + + return (r + m).clamp(0, 1), (g + m).clamp(0, 1), (b + m).clamp(0, 1) + + +def decode_lcs_to_hsl(c, anchor_lcs, anchor_angles): + """Decode LCS coordinates to HSL using bicone geometry. + + c: [..., 3] LCS coordinates (normalized to t=50) + anchor_lcs: [8, 3] anchor positions [R,B,G,M,C,Y,Black,White] + anchor_angles: [6] hue angles of chromatic anchors in radians + + Returns: (h, s, l) each [...] 
in [0,1] + """ + black = anchor_lcs[6] # [3] + white = anchor_lcs[7] # [3] + chromatic = anchor_lcs[:6] # [6, 3] + + # Achromatic axis + a = white - black # [3] + a_norm_sq = (a * a).sum() + 1e-10 + + # Lightness: project onto achromatic axis + diff = c - black # [..., 3] + l = (diff * a).sum(dim=-1) / a_norm_sq # [...] + l = l.clamp(0.0, 1.0) + + # Point on achromatic axis + c_L = black + l.unsqueeze(-1) * a # [..., 3] + + # Chromatic residual + chroma_vec = c - c_L # [..., 3] + chroma_dist = chroma_vec.norm(dim=-1) + 1e-10 # [...] + + # Compute hue angle in chromatic plane + a_unit, e1, e2 = _chromatic_plane_basis(a) + + # Project chromatic vector to 2D + x_coord = (chroma_vec * e1).sum(dim=-1) # [...] + y_coord = (chroma_vec * e2).sum(dim=-1) # [...] + angle = torch.atan2(y_coord, x_coord) # [...] radians + angle = angle % (2 * math.pi) + + # Map angle to hue [0,1] using sorted anchor angles + # anchor_angles are the angles of [R,B,G,M,C,Y] in the same coordinate system + # Standard HSL hue: R=0, Y=1/6, G=2/6, C=3/6, B=4/6, M=5/6 + # But anchors may not be in that order in angle-space, so we interpolate + sorted_angles, sort_idx = anchor_angles.sort() + anchor_hues = torch.tensor(_ANCHOR_HUES, device=c.device, dtype=c.dtype) + sorted_hues = anchor_hues[sort_idx] + + # Piecewise linear interpolation around the circle + h = _angle_to_hue(angle, sorted_angles, sorted_hues) + + # Saturation: distance to achromatic axis normalized by max distance + # Max distance at this hue and lightness + bicone_factor = _bicone_factor(l, clamp_min=1e-10) + + # Find the chroma boundary at this hue (perpendicular to achromatic axis) + chroma_boundary = _hue_to_chroma_vector(h, chromatic, anchor_angles, a_unit, e1, e2, black, a) + max_radius = chroma_boundary.norm(dim=-1) + 1e-10 + s = chroma_dist / (max_radius * bicone_factor) + s = s.clamp(0.0, 1.0) + + return h, s, l + + +def encode_hsl_to_lcs(h, s, l, anchor_lcs, anchor_angles): + """Encode HSL to LCS coordinates using bicone 
def _hue_to_chroma_vector(h, chromatic, anchor_angles, a_unit, e1, e2, black, a):
    """Map hue values [...] to EQUATORIAL chroma direction vectors.

    Returns vectors in 3D LCS space that lie in the plane spanned by e1/e2,
    with magnitude equal to the equatorial chroma radius at that hue (the
    radius at l=0.5). Each anchor's radius is divided by its own bicone
    factor (1 - |2L - 1|) so that anchors measured at different lightness
    values are comparable.

    Hue is periodic, so ``h`` is wrapped into [0, 1) before segment lookup.
    Hues that no segment claims get a zero vector rather than uninitialised
    memory.

    Args:
        h: Hue tensor [...].
        chromatic: [6, 3] chromatic anchor LCS positions.
        anchor_angles: [6] angles of the chromatic anchors (radians).
        a_unit: [3] unit vector along the achromatic axis (not used here;
            kept for signature compatibility).
        e1, e2: [3] orthonormal basis of the chromatic plane.
        black: [3] black anchor position.
        a: [3] full achromatic axis vector (white - black).

    Returns:
        Tensor [..., 3] of chroma vectors, on ``chromatic``'s device/dtype.
    """
    # 1.0 and 0.0 are the same hue. Without wrapping, h == 1.0 can fall
    # outside every half-open segment below.
    h = h % 1.0

    # Lightness of each anchor: scalar projection onto the achromatic axis.
    a_sq = (a * a).sum() + 1e-10
    anchor_l = ((chromatic - black) * a).sum(dim=-1) / a_sq  # [6]

    # Chroma radius of each anchor: distance from its point on the axis.
    anchor_on_axis = black + anchor_l.unsqueeze(-1) * a  # [6, 3]
    anchor_r = (chromatic - anchor_on_axis).norm(dim=-1)  # [6]

    # Normalise to the equatorial radius (bicone factor 1 at l=0.5).
    bicone_factors = _bicone_factor(anchor_l, clamp_min=1e-6)  # [6]
    equatorial_r = anchor_r / bicone_factors  # [6]

    anchor_hues = torch.tensor(_ANCHOR_HUES, device=chromatic.device, dtype=chromatic.dtype)

    # Sort by ANGLE so the segments match those used by _angle_to_hue.
    sorted_angles, sort_idx = anchor_angles.sort()
    sorted_hues = anchor_hues[sort_idx]
    sorted_radii = equatorial_r[sort_idx]

    n = len(sorted_angles)
    # zeros, not empty: unmatched hues must give a deterministic result.
    result = torch.zeros(h.shape + (3,), device=chromatic.device, dtype=chromatic.dtype)

    for i in range(n):
        j = (i + 1) % n
        h_start = sorted_hues[i]
        h_end = sorted_hues[j]

        # Signed hue span of this segment, taking the short way round.
        h_diff = h_end - h_start
        if abs(h_diff) > 0.5:
            if h_diff > 0:
                h_diff = h_diff - 1.0
            else:
                h_diff = h_diff + 1.0

        if abs(h_diff) < 1e-10:
            continue

        h_end_unwrapped = h_start + h_diff

        # Select the hues inside this segment; shift those on the far side
        # of the 0/1 seam so that frac is measured continuously.
        if h_diff > 0:
            if h_end_unwrapped > 1.0:
                mask = (h >= h_start) | (h < (h_end_unwrapped - 1.0))
                h_shifted = torch.where(h < h_start, h + 1.0, h)
            else:
                mask = (h >= h_start) & (h < h_end_unwrapped)
                h_shifted = h
        else:
            # Hue decreases along this segment.
            if h_end_unwrapped < 0.0:
                mask = (h <= h_start) | (h > (h_end_unwrapped + 1.0))
                h_shifted = torch.where(h > h_start, h - 1.0, h)
            else:
                mask = (h <= h_start) & (h > h_end_unwrapped)
                h_shifted = h

        frac = ((h_shifted - h_start) / h_diff).clamp(0, 1)

        # Interpolate the radius linearly between the two anchors.
        interp_r = sorted_radii[i] + frac * (sorted_radii[j] - sorted_radii[i])

        # Interpolate the angle, going forward through the 2*pi seam.
        a_start = sorted_angles[i]
        a_span = sorted_angles[j] - a_start
        if a_span < 0:
            a_span = a_span + 2 * math.pi
        interp_angle = (a_start + frac * a_span) % (2 * math.pi)

        # Rebuild the 3D chroma vector in the (e1, e2) plane.
        interp_vec = interp_r.unsqueeze(-1) * (
            torch.cos(interp_angle).unsqueeze(-1) * e1
            + torch.sin(interp_angle).unsqueeze(-1) * e2
        )

        result = torch.where(mask.unsqueeze(-1), interp_vec, result)

    return result
-2.3993, 0.5384], [2.3925, -2.4026, 0.5647], + [2.4023, -2.4047, 0.5919], [2.4124, -2.4060, 0.6198], [2.4226, -2.4064, 0.6484], + [2.4330, -2.4060, 0.6772], [2.4437, -2.4051, 0.7065], [2.4546, -2.4035, 0.7367], + [2.4659, -2.4011, 0.7668], [2.4775, -2.3981, 0.7974], [2.4897, -2.4009, 0.8312], + [2.5021, -2.4036, 0.8656], [2.5148, -2.4065, 0.9008], [2.5277, -2.4093, 0.9364], + [2.5408, -2.4123, 0.9727], [2.5542, -2.4154, 1.0099], [2.5680, -2.4186, 1.0481], + [2.5820, -2.4218, 1.0868], [2.5963, -2.4252, 1.1263], [2.6110, -2.4288, 1.1672], + [2.6261, -2.4324, 1.2090], [2.6416, -2.4363, 1.2520], [2.6575, -2.4403, 1.2957], + [2.6738, -2.4444, 1.3406], [2.6904, -2.4485, 1.3865], [2.7074, -2.4529, 1.4336], + [2.7250, -2.4574, 1.4818], [2.7432, -2.4621, 1.5314], [2.7618, -2.4669, 1.5823], + [2.7810, -2.4720, 1.6344], [2.8006, -2.4771, 1.6878], [2.8209, -2.4826, 1.7430], + [2.8418, -2.4883, 1.7995], [2.8631, -2.4944, 1.8578], [2.8853, -2.5005, 1.9179], + [2.9080, -2.5066, 1.9793], [2.9313, -2.5132, 2.0426], [2.9555, -2.5199, 2.1082], + [2.9804, -2.5268, 2.1756], [3.0060, -2.5338, 2.2450], [3.0328, -2.5411, 2.3172], + [3.0603, -2.5486, 2.3914], [3.0889, -2.5561, 2.4682], [3.1189, -2.5640, 2.5482], + [3.1497, -2.5725, 2.6302], [3.1824, -2.5796, 2.7175], [3.2152, -2.5889, 2.8050], +] + +# Scale beta_t: 3D vectors for each timestep t=0..50 +BETA_T = [ + [0.0163, 0.0172, 0.0295], [0.0905, 0.0716, 0.0999], [0.1345, 0.1123, 0.1544], + [0.1826, 0.1491, 0.2065], [0.2360, 0.1899, 0.2630], [0.2904, 0.2316, 0.3202], + [0.3471, 0.2749, 0.3793], [0.4050, 0.3191, 0.4394], [0.4640, 0.3641, 0.5003], + [0.5231, 0.4091, 0.5611], [0.5834, 0.4547, 0.6228], [0.6456, 0.5016, 0.6861], + [0.7077, 0.5481, 0.7488], [0.7713, 0.5958, 0.8127], [0.8410, 0.6496, 0.8866], + [0.9119, 0.7044, 0.9616], [0.9845, 0.7605, 1.0386], [1.0578, 0.8172, 1.1163], + [1.1325, 0.8750, 1.1957], [1.2094, 0.9344, 1.2771], [1.2880, 0.9953, 1.3606], + [1.3680, 1.0571, 1.4453], [1.4498, 1.1205, 1.5321], [1.5341, 1.1858, 
def get_alpha_table():
    """Return the shift table alpha_t as a float32 tensor [51, 3].

    The tensor is built from ``ALPHA_T`` on the first call and the same
    object is returned on every later call.
    """
    global _alpha_tensor
    if _alpha_tensor is not None:
        return _alpha_tensor
    _alpha_tensor = torch.tensor(ALPHA_T, dtype=torch.float32)  # [51, 3]
    return _alpha_tensor


def get_beta_table():
    """Return the scale table beta_t as a float32 tensor [51, 3].

    The tensor is built from ``BETA_T`` on the first call and the same
    object is returned on every later call.
    """
    global _beta_tensor
    if _beta_tensor is not None:
        return _beta_tensor
    _beta_tensor = torch.tensor(BETA_T, dtype=torch.float32)  # [51, 3]
    return _beta_tensor
def test_round_trip_consistency(anchor_lcs, anchor_angles):
    """Check that encode(decode(x)) is close to x for typical LCS coordinates.

    The test points are the six chromatic anchors plus five points encoded
    from random (h, s, l) triples.

    Args:
        anchor_lcs: [8, 3] anchor positions; rows 0-5 are the chromatic anchors.
        anchor_angles: [6] hue angles of the chromatic anchors (radians).

    Returns:
        Dict with ``max_round_trip_error``, ``avg_round_trip_error``,
        ``passed`` (max error below 1e-4) and the per-point ``errors`` list.
    """
    test_cases = list(anchor_lcs[:6])  # all 6 chromatic anchors

    # Draw on the CPU generator, then move the values to the anchors'
    # device/dtype so the encode call never mixes devices or precisions.
    for _ in range(5):
        h, s, l = torch.rand(3).to(device=anchor_lcs.device, dtype=anchor_lcs.dtype).unbind()
        test_cases.append(encode_hsl_to_lcs(h, s, l, anchor_lcs, anchor_angles))

    errors = []
    for c in test_cases:
        h, s, l = decode_lcs_to_hsl(c, anchor_lcs, anchor_angles)
        c_round = encode_hsl_to_lcs(h, s, l, anchor_lcs, anchor_angles)
        errors.append((c - c_round).norm().item())

    max_error = max(errors)
    avg_error = sum(errors) / len(errors)
    return {
        "max_round_trip_error": max_error,
        "avg_round_trip_error": avg_error,
        "passed": max_error < 1e-4,
        "errors": errors,
    }
def test_type_ii_uniformity(anchor_lcs, anchor_angles):
    """Print how much patch variation survives a Type II pull toward one colour.

    A set of 100 random LCS patches is decoded to HSL once, blended toward
    a fixed target (h=0, s=1, l=0.5) at each strength in ``_TEST_STRENGTHS``,
    and re-encoded. For every strength one line is printed with the
    output/input variance ratio and the number of distinct h, s and l
    values after rounding to 3 decimals. Nothing is returned.
    """
    # Synthetic "image": noise around a fixed offset.
    centre = torch.tensor([0.3, 0.2, 0.1])
    patches = torch.randn(100, 3) * _VARIATION_SCALE + centre

    # The starting HSL and the input variance do not depend on strength.
    h_cur, s_cur, l_cur = decode_lcs_to_hsl(patches, anchor_lcs, anchor_angles)
    input_var = patches.var(dim=0).mean().item()

    # Target colour, broadcast to one value per patch.
    h_new = torch.full_like(h_cur, 0.0)
    s_new = torch.full_like(s_cur, 1.0)
    l_new = torch.full_like(l_cur, 0.5)

    def _distinct(values):
        return len(torch.unique(values.round(decimals=3)))

    for strength in _TEST_STRENGTHS:
        h_interp = _hue_lerp(h_cur, h_new, strength)  # shortest-path hue blend
        s_interp = (s_cur + strength * (s_new - s_cur)).clamp(0, 1)
        l_interp = (l_cur + strength * (l_new - l_cur)).clamp(0, 1)

        new_patches = encode_hsl_to_lcs(h_interp, s_interp, l_interp, anchor_lcs, anchor_angles)

        output_var = new_patches.var(dim=0).mean().item()
        var_ratio = output_var / (input_var + 1e-10)

        h_unique = _distinct(h_interp)
        s_unique = _distinct(s_interp)
        l_unique = _distinct(l_interp)

        print(f"strength={strength:.2f}: var_ratio={var_ratio:.3f}, "
              f"unique_h={h_unique}, unique_s={s_unique}, unique_l={l_unique}")
+ """ + # Typical LCS coordinate magnitude at t=50 + c_ref = torch.tensor(_T50_REFERENCE_COORD, dtype=torch.float32) + alpha_50, beta_50 = get_alpha_beta_t50() # Constant across all sigmas + + for sigma in [1.0, 0.99, 0.95, 0.90, 0.85, 0.80, 0.50, 0.0]: + alpha_t, beta_t = get_alpha_beta(sigma) + + # Simulate a noisy observation at timestep t + # In diffusion, the observation is alpha_t * clean + beta_t * noise + # At high sigma, noise dominates + noise = torch.randn(3) * _NOISE_SCALE + c_observed = alpha_t + beta_t * c_ref + beta_t * noise + + # Normalize to t=50 + c_norm = normalize_to_t50(c_observed, alpha_t, beta_t, alpha_50, beta_50) + + # Measure deviation from reference + deviation = (c_norm - c_ref).norm().item() + amplification = (beta_50 / beta_t).max().item() + + print(f"sigma={sigma:.2f}: beta_t={beta_t.numpy()}, " + f"amplification={amplification:.1f}x, deviation={deviation:.3f}") + + +def analyze_blurriness_causes(lcs_data_path=None): + """Comprehensive analysis of all potential blurriness causes.""" + print("=" * 60) + print("LCS INTERVENTION BLURRINESS ANALYSIS") + print("=" * 60) + + # Load actual calibration data + if lcs_data_path is None: + from pathlib import Path + data_dir = Path(__file__).parent.parent / "data" + safetensors_files = list(data_dir.glob("lcs_*.safetensors")) + if safetensors_files: + lcs_data_path = safetensors_files[0] + else: + print("ERROR: No calibration data found. Run LCSLoadData with calibrate=True first.") + return + + from safetensors.torch import load_file + data = load_file(lcs_data_path) + anchor_lcs = data["anchor_lcs"] + anchor_angles = data["anchor_angles"] + + print(f"\nLoaded calibration data from: {lcs_data_path}") + print(f"anchor_lcs shape: {anchor_lcs.shape}") + print(f"anchor_angles shape: {anchor_angles.shape}") + + print("\n1. 
ROUND-TRIP CONSISTENCY TEST") + print("-" * 40) + result = test_round_trip_consistency(anchor_lcs, anchor_angles) + print(f"Max error: {result['max_round_trip_error']:.2e}") + print(f"Avg error: {result['avg_round_trip_error']:.2e}") + print(f"Status: {'PASS' if result['passed'] else 'FAIL'}") + + print("\n2. NORMALIZATION STABILITY TEST") + print("-" * 40) + norm_results = test_normalization_stability() + problematic = [r for r in norm_results if r['amplification'] > _PROBLEMATIC_AMPLIFICATION_THRESHOLD] + print(f"Timesteps with >{_PROBLEMATIC_AMPLIFICATION_THRESHOLD}x amplification: {len(problematic)}") + for r in problematic[:5]: + print(f" t={r['t']:2d} (sigma={r['sigma']:.2f}): amp={r['amplification']:.1f}x") + + print("\n3. TYPE II UNIFORMITY TEST") + print("-" * 40) + test_type_ii_uniformity(anchor_lcs, anchor_angles) + + print("\n4. EARLY TIMESTEP AMPLIFICATION TEST") + print("-" * 40) + test_early_timestep_amplification() + + print("\n" + "=" * 60) + print("CONCLUSIONS") + print("=" * 60) + print(""" +Potential blurriness causes identified: + +1. TYPE II AT HIGH STRENGTH: At strength=1.0, all patches get the same + target HSL, destroying spatial color variation. This is the PRIMARY + cause of blur in type_ii mode. + +2. EARLY TIMESTEP AMPLIFICATION: At sigma>0.95 (t<2.5), beta_t is ~0.02, + causing ~250x amplification of noise. Intervening too early (step 0-2) + will corrupt the signal. + +3. OVERSHOOTING: strength>1.0 overshoots the target, potentially pushing + values outside the valid color gamut. This can cause clipping and + artifacts. 
@dataclass
class LCSData:
    """Calibration data for the Latent Color Subspace.

    Produced by PCA on FLUX VAE-encoded solid-color images. Flows between
    all LCS nodes as the shared LCS_DATA custom type.
    """

    basis: torch.Tensor  # [64, 3] PCA basis B (orthonormal columns)
    mean: torch.Tensor  # [64] PCA mean mu
    anchor_lcs: torch.Tensor  # [8, 3] LCS coords of the anchors [R,B,G,M,C,Y,Black,White]
    anchor_angles: torch.Tensor  # [6] hue angles (radians) of the chromatic anchors

    def to(self, device, dtype=None):
        """Return a new LCSData with every tensor moved to device/dtype.

        ``dtype`` is only forwarded when given, so by default each tensor
        keeps its own dtype.
        """
        target = {"device": device} if dtype is None else {"device": device, "dtype": dtype}
        tensors = (self.basis, self.mean, self.anchor_lcs, self.anchor_angles)
        return LCSData(*(t.to(**target) for t in tensors))
def patchify(x):
    """Turn a latent into a sequence of 2x2 patches [B, L, C*4].

    Accepted layouts:
      - 3D [C, H, W]: a batch dim is added; layout tag is "unbatched".
      - 4D [B, C, H, W]: used as is; layout tag is None.
      - 5D [B, C, T, H, W]: T is folded into the batch; tag is (B, C, T).

    Odd H or W is padded by one row/column (replicating the edge). When
    padding happened the layout tag is wrapped in a dict together with the
    pad amounts so that unpatchify can crop them away again.

    Returns:
        (patches, h_len, w_len, extra_shape) where h_len/w_len are the
        padded height/width divided by 2, or four Nones if H or W is
        smaller than 1.
    """
    layout = None
    if x.ndim == 3:
        layout, x = "unbatched", x.unsqueeze(0)
    elif x.ndim == 5:
        n_batch, n_chan, n_frames, height, width = x.shape
        layout = (n_batch, n_chan, n_frames)
        x = x.permute(0, 2, 1, 3, 4).reshape(n_batch * n_frames, n_chan, height, width)

    _, _, height, width = x.shape
    if height < 1 or width < 1:
        return None, None, None, None

    # One extra row/column is enough to make an odd size even.
    pad_h, pad_w = height % 2, width % 2
    if pad_h or pad_w:
        x = F.pad(x, (0, pad_w, 0, pad_h), mode="replicate")
        layout = {"orig_extra": layout, "pad_h": pad_h, "pad_w": pad_w}

    h_len, w_len = x.shape[2] // 2, x.shape[3] // 2
    patches = rearrange(x, "b c (h ph) (w pw) -> b (h w) (c ph pw)", ph=2, pw=2)
    return patches, h_len, w_len, layout
def compute_local_relationships(c, h_len, w_len, kernel_radius=2):
    """Cosine similarity of every patch with each neighbour in its window.

    ``c`` is reshaped to a [B, h_len, w_len, 3] grid and padded by edge
    replication, so border patches are compared with copies of the border.
    The window is (2r+1) x (2r+1) with the centre excluded, scanned row by
    row.

    Returns:
        [B, L, N] tensor with N = (2r+1)^2 - 1 (24 for the default r=2).
    """
    batch = c.shape[0]
    radius = kernel_radius
    window = 2 * radius + 1
    neighbor_count = window * window - 1

    # [B, L, 3] -> [B, 3, H, W] so spatial padding can be applied.
    centers = c.reshape(batch, h_len, w_len, 3).permute(0, 3, 1, 2)
    padded = F.pad(centers, (radius, radius, radius, radius), mode="replicate")

    # Normalise once up front; the dot product of unit vectors is the cosine.
    unit_centers = centers / centers.norm(dim=1, keepdim=True).clamp(min=1e-8)
    unit_padded = padded / padded.norm(dim=1, keepdim=True).clamp(min=1e-8)

    # Top-left corner of each shifted view inside the padded grid; the
    # corner (radius, radius) is the patch itself and is skipped.
    corners = [
        (top, left)
        for top in range(window)
        for left in range(window)
        if (top, left) != (radius, radius)
    ]
    similarities = [
        (unit_centers * unit_padded[:, :, top:top + h_len, left:left + w_len]).sum(dim=1)
        for top, left in corners
    ]

    stacked = torch.stack(similarities, dim=-1)  # [B, H, W, N]
    return stacked.reshape(batch, -1, neighbor_count)
def infer_color_from_neighbors(c, anomaly_mag, h_len, w_len, kernel_radius=2):
    """Blend anomalous patches toward a weighted average of their neighbours.

    Each neighbour in the (2r+1) x (2r+1) window (centre excluded, edges
    replicated) contributes with weight ``1 - anomaly``, floored at 0.01,
    so low-anomaly neighbours dominate. The result mixes the original
    colour and the inferred one using the patch's own anomaly, clamped to
    [0, 1], as the blend factor.

    Args:
        c: [B, L, 3] patch colours.
        anomaly_mag: [B, L, 1] anomaly magnitude per patch.
        h_len, w_len: grid size with h_len * w_len == L.
        kernel_radius: window radius r.

    Returns:
        [B, L, 3] corrected colours.
    """
    batch = c.shape[0]
    radius = kernel_radius
    window = 2 * radius + 1
    border = (radius, radius, radius, radius)

    # Channel-first grids, padded so every patch has a full window.
    colors = c.reshape(batch, h_len, w_len, 3).permute(0, 3, 1, 2)  # [B, 3, H, W]
    anomalies = anomaly_mag.reshape(batch, h_len, w_len, 1).permute(0, 3, 1, 2)  # [B, 1, H, W]
    padded_colors = F.pad(colors, border, mode="replicate")
    padded_anomalies = F.pad(anomalies, border, mode="replicate")

    weight_total = torch.zeros(batch, 1, h_len, w_len, device=c.device, dtype=c.dtype)
    color_total = torch.zeros(batch, 3, h_len, w_len, device=c.device, dtype=c.dtype)

    for top in range(window):
        rows = slice(top, top + h_len)
        for left in range(window):
            if top == radius and left == radius:
                continue  # the patch itself is not its own neighbour
            cols = slice(left, left + w_len)
            # Trust a neighbour in proportion to how non-anomalous it is.
            weight = (1.0 - padded_anomalies[:, :, rows, cols]).clamp(min=0.01)
            weight_total = weight_total + weight
            color_total = color_total + weight * padded_colors[:, :, rows, cols]

    inferred = color_total / weight_total.clamp(min=1e-8)  # [B, 3, H, W]
    inferred = inferred.permute(0, 2, 3, 1).reshape(batch, -1, 3)  # [B, L, 3]

    blend = anomaly_mag.clamp(0, 1)
    return c * (1.0 - blend) + inferred * blend
def unpack_video_if_needed(denoised, args):
    """Split a packed latent into its parts when it looks packed.

    A tensor is treated as packed when it is 3-D with a middle dimension of
    1 ([B, 1, flat]) and the conditioning in ``args["cond"]`` carries more
    than one entry in ``latent_shapes``. In that case the first unpacked
    tensor is returned together with a dict holding the remaining ones
    under ``"other_tensors"`` for later repacking.

    Returns:
        ``(tensor_to_process, pack_info)``; ``pack_info`` is None when the
        input is returned untouched.
    """
    looks_packed = denoised.ndim == 3 and denoised.shape[1] == 1
    if not looks_packed:
        return denoised, None

    # The shapes needed for unpacking travel with the conditioning data.
    latent_shapes = _extract_latent_shapes(args.get("cond"))
    if latent_shapes is None or len(latent_shapes) <= 1:
        return denoised, None

    tensors = comfy.utils.unpack_latents(denoised, latent_shapes)
    first, rest = tensors[0], tensors[1:]
    return first, {"other_tensors": rest}
+ + modified: the video tensor after intervention [B, 128, F, H, W] + pack_info: from unpack_video_if_needed + """ + if pack_info is None: + return modified + all_tensors = [modified] + pack_info["other_tensors"] + packed, _ = comfy.utils.pack_latents(all_tensors) + return packed + + +def downsample_mask(mask, h_len, w_len, device, dtype): + """Downsample a mask to patch grid and flatten to [1, L, 1].""" + mask_dev = mask.to(device=device, dtype=dtype) + if mask_dev.ndim == 3: + mask_dev = mask_dev[:1] + if mask_dev.ndim == 2: + mask_4d = mask_dev.unsqueeze(0).unsqueeze(0) # [1, 1, H, W] + elif mask_dev.ndim == 3: + mask_4d = mask_dev.unsqueeze(1) # [B, 1, H, W] + else: + mask_4d = mask_dev + mask_resized = F.interpolate( + mask_4d, size=(h_len, w_len), mode="bilinear", align_corners=False + ) + return mask_resized.reshape(1, -1, 1) # [1, L, 1] + + +def _extract_latent_shapes(cond): + """Try to extract latent_shapes from conditioning data. + + After convert_cond, cond is a list of dicts with 'model_conds' containing + CONDConstant-wrapped values like 'latent_shapes'. + """ + if cond is None: + return None + for c in cond: + if isinstance(c, dict): + model_conds = c.get('model_conds', {}) + if 'latent_shapes' in model_conds: + ls = model_conds['latent_shapes'] + # CONDConstant wraps the value in .cond + if hasattr(ls, 'cond'): + return ls.cond + return ls + return None diff --git a/custom_nodes/ComfyUI-LCS/core/sharpness.py b/custom_nodes/ComfyUI-LCS/core/sharpness.py new file mode 100644 index 0000000000000000000000000000000000000000..47f74fb9d5d35b9791db6ab1cad69f84a5b34158 --- /dev/null +++ b/custom_nodes/ComfyUI-LCS/core/sharpness.py @@ -0,0 +1,213 @@ +"""Sharpness subspace calibration via sinusoidal grating stimuli. + +Replaces the previous Gaussian blur approach with narrowband frequency +gratings, which achieve higher linearity (R²=0.94 vs 0.88) because each +stimulus contains a single spatial frequency — a purer probe of the VAE's +frequency encoding axis. 
@dataclass
class SharpnessData:
    """Container for the calibrated sharpness subspace.

    Holds the PCA basis and mean produced by ``calibrate_sharpness`` together
    with the sign that makes positive strength mean "sharper".
    """

    basis: torch.Tensor  # [64, K] PCA basis (columns), K typically 1-2
    mean: torch.Tensor  # [64] PCA mean (in color-removed space if lcs_data was used)
    sign: float  # +1 or -1: ensures positive strength = sharper
    lcs_basis: Optional[torch.Tensor] = None  # [64, 3] LCS basis used during calibration (for re-orthogonalization)

    def to(self, device, dtype=None):
        """Return a copy with every tensor moved to ``device`` (and ``dtype`` if given)."""
        target = {"device": device}
        if dtype is not None:
            target["dtype"] = dtype
        moved_lcs = None if self.lcs_basis is None else self.lcs_basis.to(**target)
        return SharpnessData(
            basis=self.basis.to(**target),
            mean=self.mean.to(**target),
            sign=self.sign,
            lcs_basis=moved_lcs,
        )


def _generate_grating_batch(
    indices: List[int],
    angles: torch.Tensor,
    phases: torch.Tensor,
    frequencies: Tuple[float, ...],
    coord_x: torch.Tensor,
    coord_y: torch.Tensor,
) -> torch.Tensor:
    """Render sinusoidal gratings for the given flat stimulus indices.

    A flat index decomposes as ``divmod(index, len(frequencies))`` into an
    orientation slot (indexing ``angles``/``phases``) and a frequency slot.
    Each grating is ``0.5 + 0.3 * sin(...)`` replicated over 3 channels.
    Returns a ``[len(indices), 3, H, W]`` tensor.
    """
    freq_count = len(frequencies)

    def _render(flat_idx):
        orientation, freq_slot = divmod(flat_idx, freq_count)
        theta = angles[orientation].item()
        phi = phases[orientation].item()
        # Coordinate along the grating's direction of variation.
        axis = coord_x * math.cos(theta) + coord_y * math.sin(theta)
        wave = 0.5 + 0.3 * torch.sin(2 * math.pi * frequencies[freq_slot] * axis + phi)
        return wave.unsqueeze(0).expand(3, -1, -1)

    return torch.stack([_render(i) for i in indices], dim=0)
def calibrate_sharpness(vae, num_samples: int = 64, image_size: int = 512,
                        frequencies: Tuple[float, ...] = (1, 2, 4, 8, 16, 32, 64),
                        batch_size: int = 8,
                        lcs_data: "Optional[LCSData]" = None,
                        # Legacy parameter — accepted but ignored
                        blur_levels: Optional[Tuple[float, ...]] = None,
                        ) -> "SharpnessData":
    """Compute sharpness subspace data (PCA basis, mean, sign) from a VAE.

    Generates sinusoidal gratings at varying spatial frequencies (one pure
    frequency per stimulus), VAE-encodes them, averages the patch vectors of
    each stimulus, and runs PCA on those averages.

    Args:
        vae: ComfyUI VAE object (must provide ``encode`` on BHWC images).
        num_samples: Number of orientations (each combined with all frequencies).
        image_size: Side length of the generated square images.
        frequencies: Spatial frequencies in cycles/image.
        batch_size: Batch size for VAE encoding.
        lcs_data: Optional LCS data. When provided, the LCS color component is
            projected out of every vector before PCA.
        blur_levels: Deprecated and ignored; passing it emits a
            ``DeprecationWarning``.

    Returns:
        SharpnessData with the top-2 PCA basis, the PCA mean, the sign that
        makes PC1 correlate positively with log2(frequency), and a clone of
        the LCS basis (or ``None``).

    Raises:
        ValueError: if fewer than two stimuli are requested, or if ``patchify``
            cannot handle the VAE's latent format.
    """
    if blur_levels is not None:
        warnings.warn(
            "blur_levels is deprecated and ignored; calibration now uses sinusoidal gratings",
            DeprecationWarning, stacklevel=2,
        )

    n_freqs = len(frequencies)
    total_images = num_samples * n_freqs
    if total_images < 2:
        # The top-2 PCA below needs at least two stimuli.
        raise ValueError(
            f"calibrate_sharpness needs at least 2 stimuli, got {total_images}"
        )

    print(f"\n[LCS Sharpness Calibration] Starting: {num_samples} orientations × {n_freqs} frequencies = {total_images} stimuli")
    print(f"[LCS Sharpness Calibration] Frequencies: {list(frequencies)} cycles/image")

    # Pre-compute shared state for grating generation (fixed seed → reproducible)
    gen = torch.Generator().manual_seed(42)
    angles = torch.rand(num_samples, generator=gen) * math.pi  # [0, π)
    phases = torch.rand(num_samples, generator=gen) * 2 * math.pi  # [0, 2π)
    y_coords = torch.linspace(-0.5, 0.5, image_size).unsqueeze(1)
    x_coords = torch.linspace(-0.5, 0.5, image_size).unsqueeze(0)
    coord_y = y_coords.expand(image_size, image_size)
    coord_x = x_coords.expand(image_size, image_size)

    # Frequency label of every stimulus (flat index → frequency)
    freq_labels = [frequencies[idx % n_freqs] for idx in range(total_images)]
    freq_labels_t = torch.tensor(freq_labels, dtype=torch.float32)
    log_freq = torch.log2(freq_labels_t.clamp(min=0.5))

    def _encode_mean(images_bhwc):
        """VAE-encode images and return the per-item mean patch vector on CPU."""
        patches, _, _, _ = patchify(vae.encode(images_bhwc))
        if patches is None:
            raise ValueError(
                "calibrate_sharpness: patchify could not handle the VAE latent format"
            )
        return patches.mean(dim=1).cpu()

    # Generate stimuli lazily per batch and VAE encode
    vectors = []
    pbar = comfy.utils.ProgressBar(total_images)

    for batch_start in range(0, total_images, batch_size):
        batch_end = min(batch_start + batch_size, total_images)
        indices = list(range(batch_start, batch_end))
        batch = _generate_grating_batch(indices, angles, phases, frequencies, coord_x, coord_y)
        actual_batch = batch.shape[0]

        # Convert BCHW → BHWC for ComfyUI VAE
        imgs_bhwc = batch.permute(0, 2, 3, 1).contiguous().cpu()

        # Try the whole batch first.
        avg = _encode_mean(imgs_bhwc)

        if avg.shape[0] == actual_batch:
            vectors.extend(avg.unbind(0))
        else:
            # The VAE did not return one latent per image (e.g. a video VAE
            # treating the batch as frames). The batched average then mixes
            # several stimuli, so discard it and encode EVERY image on its
            # own; this also guarantees exactly one vector per stimulus.
            for k in range(actual_batch):
                single = _encode_mean(imgs_bhwc[k:k + 1])
                vectors.append(single[0])

        pbar.update(actual_batch)

    # Stack all vectors: [N, 64]
    X = torch.stack(vectors, dim=0).float()
    print(f"[LCS Sharpness Calibration] Collected {X.shape[0]} vectors of dimension {X.shape[1]}")

    # Remove LCS color component FIRST, in the raw space where LCS was calibrated.
    # This must happen before per-vector DC removal, because the LCS basis has
    # significant DC components (PC1 ≈ brightness). Doing DC removal first would
    # shift vectors into a different space where B^T(x - mu) is incorrect.
    if lcs_data is not None:
        print("[LCS Sharpness Calibration] Removing LCS color component...")
        lcs_mean = lcs_data.mean.to(X.device, X.dtype)
        lcs_basis = lcs_data.basis.to(X.device, X.dtype)
        # Project out color: X' = X - B B^T (X - mu)
        centered = X - lcs_mean
        lcs_coords = centered @ lcs_basis  # [N, 3]
        X = X - lcs_coords @ lcs_basis.T
        print("[LCS Sharpness Calibration] Color component removed")

    # Remove per-vector DC AFTER color removal, so PCA looks at the relative
    # channel structure rather than the absolute level.
    X = X - X.mean(dim=1, keepdim=True)

    # PCA via SVD of the centered data
    print("[LCS Sharpness Calibration] Computing PCA...")
    mean = X.mean(dim=0)  # [64]
    X_centered = X - mean
    _, S, Vh = torch.linalg.svd(X_centered, full_matrices=False)
    basis = Vh[:2].T  # [64, 2] top-2 components

    # Variance explained
    total_var = (S ** 2).sum()
    explained = (S[:2] ** 2) / total_var
    print(f"[LCS Sharpness Calibration] PC1: {explained[0]:.1%}, PC2: {explained[1]:.1%} ({(explained[0]+explained[1]):.1%} total)")

    # Sign convention: project onto PC1 and correlate with log₂(frequency).
    # Higher frequency = sharper → positive correlation means sign = +1.
    pc1_scores = X_centered @ basis[:, 0]  # [N]
    correlation = torch.corrcoef(torch.stack([pc1_scores, log_freq]))[0, 1]
    sign = 1.0 if correlation > 0 else -1.0
    print(f"[LCS Sharpness Calibration] PC1-frequency correlation: {correlation:.3f} → sign = {sign:+.0f}")
    print(f"[LCS Sharpness Calibration] Complete! Basis shape: {basis.shape}")

    return SharpnessData(
        basis=basis,
        mean=mean,
        sign=sign,
        lcs_basis=lcs_data.basis.clone() if lcs_data is not None else None,
    )
# Betas with magnitude below this are replaced by it before scaling.
_BETA_EPS = 1e-6


def _safe_beta(beta):
    """Return ``beta`` with near-zero entries (|x| < _BETA_EPS) set to _BETA_EPS."""
    return torch.where(beta.abs() < _BETA_EPS,
                       torch.full_like(beta, _BETA_EPS), beta)


def sigma_to_paper_t(sigma):
    """Convert sigma to paper timestep t ∈ [0, 50].

    sigma is clamped to [0, 1] first: sigma=1 → t=0, sigma=0 → t=50.
    Accepts a tensor (returns a tensor) or a plain number (returns a float).
    """
    if isinstance(sigma, torch.Tensor):
        return 50.0 * (1.0 - sigma.clamp(0.0, 1.0))
    return 50.0 * (1.0 - max(0.0, min(1.0, sigma)))


def get_alpha_beta(sigma, device=None):
    """Get linearly interpolated α_t and β_t vectors for a given sigma.

    The fractional timestep from ``sigma_to_paper_t`` is interpolated between
    the two neighbouring rows of the alpha/beta tables. For a tensor sigma only
    the first element is used.

    Returns (alpha_t, beta_t), moved to ``device`` when one is given.
    """
    t = sigma_to_paper_t(sigma)
    if isinstance(t, torch.Tensor):
        # flatten()[0] keeps this working for batched sigma tensors, where a
        # bare .item() would raise.
        t = t.flatten()[0].item()

    alpha_table = get_alpha_table()  # [51, 3]
    beta_table = get_beta_table()  # [51, 3]

    t = max(0.0, min(50.0, t))
    t_low = int(t)
    t_high = min(t_low + 1, 50)
    frac = t - t_low

    alpha = (1.0 - frac) * alpha_table[t_low] + frac * alpha_table[t_high]
    beta = (1.0 - frac) * beta_table[t_low] + frac * beta_table[t_high]

    if device is not None:
        alpha = alpha.to(device)
        beta = beta.to(device)
    return alpha, beta


def get_alpha_beta_t50(device=None):
    """Get α_50 and β_50 (row 50 of the tables, the reference timestep)."""
    alpha_50 = get_alpha_table()[50]
    beta_50 = get_beta_table()[50]
    if device is not None:
        alpha_50 = alpha_50.to(device)
        beta_50 = beta_50.to(device)
    return alpha_50, beta_50


def normalize_to_t50(c, alpha_t, beta_t, alpha_50, beta_50):
    """Normalize LCS coords from timestep t to reference t=50.

    ĉ = (c - α_t) / β_t * β_50 + α_50
    c: [..., 3], alpha_t/beta_t/alpha_50/beta_50: [3]

    Both betas go through the same near-zero guard used by
    ``denormalize_from_t50`` so the two functions stay exact inverses even
    when a beta is (almost) zero.
    """
    return (c - alpha_t) / _safe_beta(beta_t) * _safe_beta(beta_50) + alpha_50


def denormalize_from_t50(c_hat, alpha_t, beta_t, alpha_50, beta_50):
    """Denormalize LCS coords from reference t=50 back to timestep t.

    c = (ĉ - α_50) / β_50 * β_t + α_t

    Uses the same guarded betas as ``normalize_to_t50``. Previously β_t was
    multiplied unguarded here while it was guarded in the forward direction,
    so a near-zero β_t collapsed every colour to α_t on the round trip.
    """
    return (c_hat - alpha_50) / _safe_beta(beta_50) * _safe_beta(beta_t) + alpha_t
"""V2 backward compatibility exports for all LCS nodes."""

from .calibrate import LCSLoadData
from .intervene import LCSColorIntervene, LCSColorBatch, LCSToneAdjust
from .observe import LCSPreviewColors, LCSStepObserver
from .sharpen import LCSSharpnessCalibrate, LCSSharpnessIntervene
from .anchor import LCSColorAnchor

# Node identifier → implementing class, one entry per node imported above.
# NOTE(review): presumably consumed by ComfyUI's custom-node loader — confirm.
NODE_CLASS_MAPPINGS = {
    "LCSLoadData": LCSLoadData,
    "LCSColorIntervene": LCSColorIntervene,
    "LCSColorBatch": LCSColorBatch,
    "LCSToneAdjust": LCSToneAdjust,
    "LCSPreviewColors": LCSPreviewColors,
    "LCSStepObserver": LCSStepObserver,
    "LCSSharpnessCalibrate": LCSSharpnessCalibrate,
    "LCSSharpnessIntervene": LCSSharpnessIntervene,
    "LCSColorAnchor": LCSColorAnchor,
}

# Node identifier → human-readable display name; keys mirror
# NODE_CLASS_MAPPINGS one-to-one.
NODE_DISPLAY_NAME_MAPPINGS = {
    "LCSLoadData": "LCS Load Data",
    "LCSColorIntervene": "LCS Color Intervene",
    "LCSColorBatch": "LCS Color Batch",
    "LCSToneAdjust": "LCS Tone Adjust",
    "LCSPreviewColors": "LCS Preview Colors",
    "LCSStepObserver": "LCS Step Observer",
    "LCSSharpnessCalibrate": "LCS Sharpness Calibrate",
    "LCSSharpnessIntervene": "LCS Sharpness Intervene",
    "LCSColorAnchor": "LCS Color Anchor",
}
def _encode_reference_to_lcs(reference_image, vae, lcs_data):
    """Project the first reference image into LCS coordinates.

    reference_image: [B, H, W, C] image batch; only item 0 and the first three
        channels are encoded.
    Returns (c_ref [1, L, 3], h_len, w_len), or (None, 0, 0) when ``patchify``
    yields no patches for the encoded latent.
    """
    first_rgb = reference_image[:1, :, :, :3]
    patches, h_len, w_len, _ = patchify(vae.encode(first_rgb))
    if patches is None:
        return None, 0, 0

    # Match the calibration tensors to wherever the patches live.
    calib = lcs_data.to(patches.device, patches.dtype)
    coords = (patches - calib.mean) @ calib.basis  # [1, L, 3]
    return coords, h_len, w_len


def _resize_color_field(c, src_h, src_w, dst_h, dst_w):
    """Bilinearly resample a flattened [B, L, 3] colour field to a new grid.

    The input is returned unchanged (same object) when source and destination
    grids already agree.
    """
    if (src_h, src_w) == (dst_h, dst_w):
        return c
    batch = c.shape[0]
    # [B, L, 3] → [B, 3, H, W] so F.interpolate can treat it as an image.
    as_image = c.reshape(batch, src_h, src_w, 3).permute(0, 3, 1, 2)
    scaled = F.interpolate(
        as_image, size=(dst_h, dst_w), mode="bilinear", align_corners=False
    )
    return scaled.permute(0, 2, 3, 1).reshape(batch, -1, 3)
def _build_adaptive_anchor_fn(lcs_data, mode, intensity, mask,
                              c_ref=None, ref_h=0, ref_w=0, r_ref=None,
                              auto_intensity=False):
    """Build unified post_cfg_function for all anchor modes.

    Phase assignment and strength scheduling are derived from the sigma
    schedule on the first hook call of a sampling run. All filter/threshold
    parameters are estimated from the data at each step.

    The returned closure keeps per-run state (phases, envelope position, EMA
    of local relationships, drift statistics). Because the same closure may be
    invoked by more than one sampling run (e.g. when the patched model is
    reused), the state is reset whenever the sigma schedule changes or the
    step index returns to 0 after having advanced.

    Args:
        lcs_data: calibration data providing ``basis`` and ``mean``.
        mode: "smooth", "reference" or "self_anchor".
        intensity: user intensity; ignored when ``auto_intensity`` is True.
        mask: optional mask restricting the correction.
        c_ref, ref_h, ref_w, r_ref: reference colours, their grid size and
            local relationships (reference mode only).
        auto_intensity: derive the intensity from measured drift.
    """
    def _fresh_state():
        return {
            "phases": None,
            "envelope": None,
            "correction_index": 0,
            "r_ema": None,
            "prev_c_mean": None,
            "drift_sum": 0.0,
            "drift_count": 0,
            "auto_intensity_val": None,
            # Bookkeeping used only to detect the start of a new sampling run.
            "schedule_key": None,
            "last_step": -1,
        }

    state = _fresh_state()

    def post_cfg_fn(args):
        denoised = args["denoised"]
        sigma = args["sigma"]
        model = args["model"]

        sigmas = args["model_options"]["transformer_options"]["sample_sigmas"]
        step_index = find_step_index(sigma, sigmas)

        # --- Detect a new sampling run and drop stale state ---
        schedule_key = tuple(sigmas.flatten().tolist())
        restarted = step_index == 0 and state["last_step"] > 0
        if state["phases"] is not None and (
            restarted or schedule_key != state["schedule_key"]
        ):
            state.clear()
            state.update(_fresh_state())

        # --- Lazy init: compute phases and envelope from sigma schedule ---
        if state["phases"] is None:
            state["phases"] = compute_step_phases(sigmas, mode)
            n_correct = sum(1 for p in state["phases"] if p == "correct")
            state["envelope"] = compute_strength_envelope(n_correct)
            state["correction_index"] = 0
            state["schedule_key"] = schedule_key
        state["last_step"] = step_index

        # Look up phase (guard against out-of-range)
        if step_index >= len(state["phases"]):
            return denoised
        phase = state["phases"][step_index]

        # Skip phase — return unchanged
        if phase == "skip":
            return denoised

        # --- Common pipeline: unpack → raw → patchify → project → normalize ---
        working, pack_info = unpack_video_if_needed(denoised, args)

        sigma_val = float(sigma.flatten()[0])
        device = working.device
        dtype = working.dtype

        ld = lcs_data.to(device, dtype)
        B_mat = ld.basis
        mu = ld.mean

        raw = denoised_to_raw(working, model)
        patches, h_len, w_len, extra_shape = patchify(raw)
        if patches is None:
            return denoised

        projection = (patches - mu) @ B_mat  # [B, L, 3]
        reconstruction = projection @ B_mat.T + mu
        # Everything outside the colour subspace is carried through untouched.
        residual = patches - reconstruction

        alpha_t, beta_t = get_alpha_beta(sigma_val, device=device)
        alpha_t, beta_t = alpha_t.to(dtype), beta_t.to(dtype)
        alpha_50, beta_50 = get_alpha_beta_t50(device=device)
        alpha_50, beta_50 = alpha_50.to(dtype), beta_50.to(dtype)

        c_norm = normalize_to_t50(projection, alpha_t, beta_t, alpha_50, beta_50)

        # --- Observe phase (self_anchor warmup): update EMA, return unchanged ---
        if phase == "observe":
            r_current = compute_local_relationships(c_norm, h_len, w_len)
            decay = 0.8
            if state["r_ema"] is None:
                state["r_ema"] = r_current.detach().clone()
            else:
                state["r_ema"] = decay * state["r_ema"] + (1 - decay) * r_current.detach()

            # Collect step-to-step drift for auto_intensity (self_anchor)
            c_mean_now = c_norm.detach().mean(dim=1, keepdim=True)
            if auto_intensity and state["prev_c_mean"] is not None:
                drift = (c_mean_now - state["prev_c_mean"]).abs().mean().item()
                state["drift_sum"] += drift
                state["drift_count"] += 1
            state["prev_c_mean"] = c_mean_now
            return denoised

        # --- Correct phase ---
        # Compute mode-specific target first (reused for auto-intensity drift measurement)
        if mode == "smooth":
            sigma_s, sigma_c = estimate_bilateral_params(c_norm, h_len, w_len)
            c_filtered = bilateral_filter_lcs(c_norm, h_len, w_len, sigma_s, sigma_c)
        elif mode == "reference":
            c_ref_dev = c_ref.to(device=device, dtype=dtype)
            r_ref_dev = r_ref.to(device=device, dtype=dtype)
            if ref_h != h_len or ref_w != w_len:
                c_ref_resized = _resize_color_field(c_ref_dev, ref_h, ref_w, h_len, w_len)
                r_ref_resized = compute_local_relationships(c_ref_resized, h_len, w_len)
            else:
                c_ref_resized = c_ref_dev
                r_ref_resized = r_ref_dev

        # Auto-intensity: measure drift on first correction step, cache for rest
        effective_intensity = intensity
        if auto_intensity:
            if state["auto_intensity_val"] is None:
                if mode == "self_anchor" and state["drift_count"] > 0:
                    drift_signal = state["drift_sum"] / state["drift_count"]
                elif mode == "self_anchor":
                    # No observe phase (img2img) — measure from seeded baseline
                    if state["prev_c_mean"] is not None:
                        drift_signal = (c_norm.detach().mean(dim=1, keepdim=True)
                                        - state["prev_c_mean"]).abs().mean().item()
                    else:
                        drift_signal = 0.05  # conservative default
                elif mode == "reference":
                    drift_signal = (c_norm - c_ref_resized).abs().mean().item()
                elif mode == "smooth":
                    drift_signal = (c_filtered - c_norm).abs().mean().item()
                state["auto_intensity_val"] = estimate_intensity(drift_signal)
            effective_intensity = state["auto_intensity_val"]

        # Compute step strength from envelope
        ci = state["correction_index"]
        envelope = state["envelope"]
        if ci < len(envelope):
            step_strength = effective_intensity * float(envelope[ci])
        else:
            step_strength = effective_intensity
        state["correction_index"] = ci + 1

        # Self-anchor convergence damping
        if mode == "self_anchor" and state["prev_c_mean"] is not None:
            c_mean_now = c_norm.detach().mean(dim=1, keepdim=True)
            delta = (c_mean_now - state["prev_c_mean"]).abs().mean().item()
            step_strength *= min(delta / 0.1, 1.0)

        # Mode-specific correction (reuses targets computed above)
        if mode == "smooth":
            new_c_norm = c_norm + step_strength * (c_filtered - c_norm)

        elif mode == "reference":
            B_size = c_norm.shape[0]
            c_ref_exp = c_ref_resized.expand(B_size, -1, -1)
            r_ref_exp = r_ref_resized.expand(B_size, -1, -1)

            r_current = compute_local_relationships(c_norm, h_len, w_len)
            anomaly_mag = detect_anomalies_adaptive(r_current, r_ref_exp)

            correction = c_ref_exp - c_norm
            new_c_norm = c_norm + step_strength * anomaly_mag * correction

        else:  # self_anchor
            r_current = compute_local_relationships(c_norm, h_len, w_len)

            if state["r_ema"] is None:
                # Seed EMA — first step, no correction yet (anomalies will be zero)
                state["r_ema"] = r_current.detach().clone()

            anomaly_mag = detect_anomalies_adaptive(r_current, state["r_ema"])
            c_corrected = infer_color_from_neighbors(
                c_norm, anomaly_mag, h_len, w_len
            )
            new_c_norm = c_norm + step_strength * (c_corrected - c_norm)

            # Update EMA (slow decay during correction)
            decay = 0.95
            state["r_ema"] = decay * state["r_ema"] + (1 - decay) * r_current.detach()
            state["prev_c_mean"] = c_norm.detach().mean(dim=1, keepdim=True)

        # --- Apply mask ---
        if mask is not None:
            mask_flat = downsample_mask(mask, h_len, w_len, device, dtype)
            if mask_flat.shape[1] != new_c_norm.shape[1]:
                mask_flat = mask_flat[:, :new_c_norm.shape[1], :]
            new_c_norm = c_norm + mask_flat * (new_c_norm - c_norm)

        # --- Denormalize → reconstruct → unpatchify → repack ---
        new_projection = denormalize_from_t50(new_c_norm, alpha_t, beta_t, alpha_50, beta_50)
        patches_new = new_projection @ B_mat.T + mu + residual
        raw_new = unpatchify(patches_new, h_len, w_len, extra_shape)
        modified = raw_to_denoised(raw_new, model).to(dtype)
        return repack_video_if_needed(modified, pack_info)

    return post_cfg_fn
class LCSColorAnchor(io.ComfyNode):
    """Attach a colour-drift correction hook to a model.

    Modes:
    - auto: pick the mode from the connected inputs and derive the intensity
      from drift signals
    - smooth: bilateral filtering of the colour field
    - reference: anchor to a reference image's colour relationships
    - self_anchor: observe during warmup, then correct drift

    The node clones the incoming model and installs a post-CFG function built
    by ``_build_adaptive_anchor_fn``.
    """

    @classmethod
    def define_schema(cls) -> io.Schema:
        """Declare the node's inputs (model, LCS data, mode, intensity, optional
        VAE / reference image / mask) and its single MODEL output."""
        node_inputs = [
            io.Model.Input("model"),
            LCS_DATA.Input("lcs_data", tooltip="Calibration data from LCSLoadData"),
            io.Combo.Input("mode", options=["auto", "smooth", "reference", "self_anchor"],
                           default="auto",
                           tooltip="auto: infer mode and intensity from connected inputs; smooth: bilateral filter; reference: anchor to image; self_anchor: auto-detect drift"),
            io.Float.Input("intensity", default=0.5, min=0.0, max=1.0, step=0.05,
                           tooltip="Correction intensity (0 = none, 1 = full)"),
            io.Vae.Input("vae", optional=True,
                         tooltip="Required for reference mode (VAE-encodes reference image)"),
            io.Image.Input("reference_image", optional=True,
                           tooltip="Reference image for reference mode"),
            io.Mask.Input("mask", optional=True,
                          tooltip="Optional mask for localized correction"),
        ]
        node_outputs = [
            io.Model.Output(display_name="model"),
        ]
        return io.Schema(
            node_id="LCSColorAnchor",
            display_name="LCS Color Anchor",
            category="LCS/intervention",
            description="Correct color drift during sampling by anchoring local color relationships",
            inputs=node_inputs,
            outputs=node_outputs,
        )

    @classmethod
    def execute(cls, model, lcs_data, mode, intensity,
                vae=None, reference_image=None, mask=None) -> io.NodeOutput:
        """Clone the model and install the colour anchor hook.

        Returns the unmodified clone when the intensity is effectively zero
        (non-auto modes) or when reference mode lacks usable inputs.
        """
        patched = model.clone()

        # "auto" resolves to a concrete mode from whichever inputs are wired.
        auto_intensity = mode == "auto"
        if auto_intensity:
            if reference_image is not None and vae is not None:
                mode = "reference"
            elif mask is not None:
                mode = "smooth"
            else:
                mode = "self_anchor"
        elif intensity < 1e-6:
            return io.NodeOutput(patched)

        c_ref, ref_h, ref_w, r_ref = None, 0, 0, None

        if mode == "reference":
            if vae is None or reference_image is None:
                print("[LCS Color Anchor] Reference mode requires vae and reference_image — skipping.")
                return io.NodeOutput(patched)
            c_ref, ref_h, ref_w = _encode_reference_to_lcs(reference_image, vae, lcs_data)
            if c_ref is None:
                print("[LCS Color Anchor] Failed to encode reference image — skipping.")
                return io.NodeOutput(patched)
            r_ref = compute_local_relationships(c_ref, ref_h, ref_w)

        patched.set_model_sampler_post_cfg_function(
            _build_adaptive_anchor_fn(
                lcs_data, mode, intensity, mask,
                c_ref=c_ref, ref_h=ref_h, ref_w=ref_w, r_ref=r_ref,
                auto_intensity=auto_intensity,
            )
        )
        return io.NodeOutput(patched)
LCS_DATA = io.Custom("LCS_DATA")
DATA_DIR = os.path.join(os.path.dirname(os.path.dirname(__file__)), "data")


def _save_lcs(lcs_data: LCSData, path: str):
    """Save LCSData to a safetensors file.

    The data is written to a temporary sibling file and then moved into place
    with ``os.replace``, so an interrupted write never leaves a truncated
    cache file at ``path``.
    """
    # ``or "."`` keeps makedirs happy when ``path`` is a bare filename.
    os.makedirs(os.path.dirname(path) or ".", exist_ok=True)
    tmp_path = f"{path}.{os.getpid()}.tmp"
    try:
        save_file({
            "basis": lcs_data.basis.contiguous(),
            "mean": lcs_data.mean.contiguous(),
            "anchor_lcs": lcs_data.anchor_lcs.contiguous(),
            "anchor_angles": lcs_data.anchor_angles.contiguous(),
        }, tmp_path)
        os.replace(tmp_path, path)
    finally:
        # Clean up the temp file if the write or the move failed.
        if os.path.exists(tmp_path):
            os.remove(tmp_path)


def _load_lcs(path: str) -> LCSData:
    """Load LCSData from a safetensors file written by ``_save_lcs``."""
    data = load_file(path)
    return LCSData(
        basis=data["basis"],
        mean=data["mean"],
        anchor_lcs=data["anchor_lcs"],
        anchor_angles=data["anchor_angles"],
    )


class LCSLoadData(io.ComfyNode):
    """Load or auto-compute LCS calibration data for a VAE.

    A fingerprint of the VAE selects a cache file under ``DATA_DIR``. On a
    cache hit the file is loaded; on a miss (or when the cached file cannot
    be read) calibration runs and the result is saved for reuse.
    """

    @classmethod
    def define_schema(cls) -> io.Schema:
        """Declare the single VAE input and the LCS_DATA output."""
        return io.Schema(
            node_id="LCSLoadData",
            display_name="LCS Load Data",
            category="LCS/calibration",
            description="Auto-calibrate and cache LCS data per-VAE",
            inputs=[
                io.Vae.Input("vae", tooltip="VAE model (calibration is cached per-VAE)"),
            ],
            outputs=[
                LCS_DATA.Output(display_name="lcs_data"),
            ],
        )

    @classmethod
    def execute(cls, vae) -> io.NodeOutput:
        """Return cached calibration data for ``vae``, calibrating if needed."""
        fp = vae_fingerprint(vae)
        cache_path = os.path.join(DATA_DIR, f"lcs_{fp}.safetensors")

        lcs_data = None
        if os.path.exists(cache_path):
            try:
                lcs_data = _load_lcs(cache_path)
            except Exception as err:  # cache boundary: any unreadable file → rebuild
                print(f"[LCS Load Data] Cached calibration unreadable ({err}); recalibrating.")

        if lcs_data is None:
            lcs_data = calibrate(vae, num_colors=512, image_size=512)
            try:
                _save_lcs(lcs_data, cache_path)
            except OSError as err:
                # A read-only data directory must not discard a finished calibration.
                print(f"[LCS Load Data] Could not write calibration cache ({err}); continuing without cache.")

        return io.NodeOutput(lcs_data)
def _build_post_cfg_fn(lcs_data, target_colors_hsl, strength, mode, start_step, end_step, mask):
    """Build the post_cfg_function closure for color intervention.

    Args:
        lcs_data: calibration data (basis, mean, anchor_lcs, anchor_angles).
        target_colors_hsl: list of (h, s, l) tuples, one per batch item; batch
            items beyond the list length fall back to the first entry.
        strength: blend factor toward the target colour.
        mode: "type_i" (shift in LCS), "type_ii" (shift in HSL); any other
            value selects the sigma-weighted blend of both.
        start_step, end_step: inclusive step range in which the hook is active.
        mask: optional mask restricting the intervention.

    Returns:
        A function taking the post-CFG ``args`` dict and returning the
        (possibly modified) denoised tensor.
    """
    def _type_i(c_b, target_lcs):
        """Type I: translate the patch colours so their mean moves toward the target."""
        return c_b + strength * (target_lcs - c_b.mean(dim=0))

    def _type_ii(c_b, t_h_t, t_s_t, t_l_t, anchor_lcs, anchor_angles):
        """Type II: decode to HSL, interpolate toward the target, re-encode."""
        h_cur, s_cur, l_cur = decode_lcs_to_hsl(c_b, anchor_lcs, anchor_angles)
        h_interp = _hue_lerp(h_cur, t_h_t.expand_as(h_cur), strength)
        s_interp = s_cur + strength * (t_s_t.expand_as(s_cur) - s_cur)
        l_interp = l_cur + strength * (t_l_t.expand_as(l_cur) - l_cur)
        return encode_hsl_to_lcs(h_interp, s_interp.clamp(0, 1),
                                 l_interp.clamp(0, 1),
                                 anchor_lcs, anchor_angles)

    def post_cfg_fn(args):
        """Post-CFG hook: project to LCS, apply color intervention, reconstruct."""
        denoised = args["denoised"]
        sigma = args["sigma"]
        model = args["model"]

        # Determine current step index
        sigmas = args["model_options"]["transformer_options"]["sample_sigmas"]
        step_index = _find_step_index(sigma, sigmas)

        # Only act inside the configured step window
        if step_index < start_step or step_index > end_step:
            return denoised

        # Unpack packed latent format if needed
        working, pack_info = unpack_video_if_needed(denoised, args)

        sigma_val = float(sigma.flatten()[0])
        device = working.device
        dtype = working.dtype

        # Move LCS data to device
        ld = lcs_data.to(device, dtype)
        B_mat = ld.basis
        mu = ld.mean
        anchor_lcs = ld.anchor_lcs
        anchor_angles = ld.anchor_angles

        # Convert from process_in to raw VAE space
        raw = denoised_to_raw(working, model)

        patches, h_len, w_len, extra_shape = patchify(raw)
        if patches is None:
            return denoised  # Incompatible latent format

        # Project to LCS and keep what the projection does not explain
        projection = (patches - mu) @ B_mat  # [B, L, 3]
        reconstruction = projection @ B_mat.T + mu
        residual = patches - reconstruction

        # Get timestep statistics
        alpha_t, beta_t = get_alpha_beta(sigma_val, device=device)
        alpha_t, beta_t = alpha_t.to(dtype), beta_t.to(dtype)
        alpha_50, beta_50 = get_alpha_beta_t50(device=device)
        alpha_50, beta_50 = alpha_50.to(dtype), beta_50.to(dtype)

        # Normalize to t=50
        c_norm = normalize_to_t50(projection, alpha_t, beta_t, alpha_50, beta_50)  # [B, L, 3]

        # gamma weights Type I vs Type II in the blended mode. It is a blend
        # weight, so clamp it to [0, 1]; an unclamped sigma > 1 would
        # extrapolate beyond both results.
        gamma = min(max(sigma_val, 0.0), 1.0)

        # Apply intervention per batch item
        new_c_norm = c_norm.clone()
        for b in range(c_norm.shape[0]):
            color_idx = b if b < len(target_colors_hsl) else 0
            t_h, t_s, t_l = target_colors_hsl[color_idx]

            # Encode target color to LCS at t=50
            t_h_t = torch.tensor(t_h, device=device, dtype=dtype)
            t_s_t = torch.tensor(t_s, device=device, dtype=dtype)
            t_l_t = torch.tensor(t_l, device=device, dtype=dtype)
            target_lcs = encode_hsl_to_lcs(t_h_t, t_s_t, t_l_t, anchor_lcs, anchor_angles)  # [3]

            c_b = c_norm[b]  # [L, 3]

            if mode == "type_i":
                new_c_norm[b] = _type_i(c_b, target_lcs)
            elif mode == "type_ii":
                new_c_norm[b] = _type_ii(c_b, t_h_t, t_s_t, t_l_t, anchor_lcs, anchor_angles)
            else:  # interpolated (default)
                c_type_i = _type_i(c_b, target_lcs)
                c_type_ii = _type_ii(c_b, t_h_t, t_s_t, t_l_t, anchor_lcs, anchor_angles)
                # High sigma → mostly Type I, low sigma → mostly Type II
                new_c_norm[b] = gamma * c_type_i + (1.0 - gamma) * c_type_ii

        # Apply mask if provided
        if mask is not None:
            mask_flat = downsample_mask(mask, h_len, w_len, device, dtype)
            if mask_flat.shape[1] != new_c_norm.shape[1]:
                mask_flat = mask_flat[:, :new_c_norm.shape[1], :]
            # Blend: masked areas get intervention, unmasked keep original
            new_c_norm = c_norm + mask_flat * (new_c_norm - c_norm)

        # Denormalize back to timestep t
        new_projection = denormalize_from_t50(new_c_norm, alpha_t, beta_t, alpha_50, beta_50)

        # Reconstruct patches, restoring the untouched residual
        patches_new = new_projection @ B_mat.T + mu + residual

        raw_new = unpatchify(patches_new, h_len, w_len, extra_shape)

        # Convert back to process_in space
        modified = raw_to_denoised(raw_new, model).to(dtype)

        # Repack if the latent was unpacked above
        return repack_video_if_needed(modified, pack_info)

    return post_cfg_fn
(new_c_norm - c_norm) + + # Denormalize back to timestep t + new_projection = denormalize_from_t50(new_c_norm, alpha_t, beta_t, alpha_50, beta_50) + + # Reconstruct patches + patches_new = new_projection @ B_mat.T + mu + residual + + # Unpatchify + raw_new = unpatchify(patches_new, h_len, w_len, extra_shape) + + # Convert back to process_in space + modified = raw_to_denoised(raw_new, model).to(dtype) + + # Repack if LTXAV + return repack_video_if_needed(modified, pack_info) + + return post_cfg_fn + + +class LCSColorIntervene(io.ComfyNode): + """Steer colors during FLUX generation via the Latent Color Subspace. + + Installs a post-CFG hook that projects the denoised prediction into the + 3D LCS, shifts it toward the target color (Type I, Type II, or interpolated), + preserves the 61D residual, and writes the modified prediction back. + Active only during [start_step, end_step]. + """ + + @classmethod + def define_schema(cls) -> io.Schema: + """Define inputs (MODEL, LCS_DATA, color, strength, mode, steps, mask) and MODEL output.""" + return io.Schema( + node_id="LCSColorIntervene", + display_name="LCS Color Intervene", + category="LCS/intervention", + description="Steer colors during FLUX generation via Latent Color Subspace", + inputs=[ + io.Model.Input("model"), + LCS_DATA.Input("lcs_data", tooltip="Calibration data from LCSLoadData"), + io.Color.Input("color", default="#FF0000", tooltip="Target color"), + io.Float.Input("strength", default=1.0, min=0.0, max=2.0, step=0.05, + tooltip="Intervention strength (1.0 = full, 0.0 = none)"), + io.Combo.Input("mode", options=["interpolated", "type_i", "type_ii"], + default="interpolated", + tooltip="Interpolated blends Type I (LCS shift) and Type II (HSL shift)"), + io.Int.Input("start_step", default=8, min=0, max=50, + tooltip="First step to apply intervention (paper optimal: 8)"), + io.Int.Input("end_step", default=10, min=0, max=50, + tooltip="Last step to apply intervention (paper optimal: 10)"), + io.Mask.Input("mask", 
optional=True, + tooltip="Optional mask for localized color control"), + ], + outputs=[ + io.Model.Output(display_name="model"), + ], + ) + + @classmethod + def execute(cls, model, lcs_data, color, strength, mode, start_step, end_step, + mask=None) -> io.NodeOutput: + """Clone model, attach LCS color intervention hook. Returns patched MODEL.""" + m = model.clone() + h, s, l = hex_to_hsl(color) + hook = _build_post_cfg_fn(lcs_data, [(h, s, l)], strength, mode, start_step, end_step, mask) + m.set_model_sampler_post_cfg_function(hook) + return io.NodeOutput(m) + + +class LCSColorBatch(io.ComfyNode): + """Apply different target colors to each batch item for multi-color generation. + + Parses comma-separated hex colors and installs a post-CFG hook that applies + a distinct color target per batch index. Also outputs batch_size (INT) for + connecting to EmptyLatentImage. + """ + + @classmethod + def define_schema(cls) -> io.Schema: + """Define inputs (MODEL, LCS_DATA, colors string, strength, mode, steps, mask) and (MODEL, INT) outputs.""" + return io.Schema( + node_id="LCSColorBatch", + display_name="LCS Color Batch", + category="LCS/intervention", + description="Apply different target colors to each batch item for multi-color generation", + inputs=[ + io.Model.Input("model"), + LCS_DATA.Input("lcs_data", tooltip="Calibration data from LCSLoadData"), + io.String.Input("colors", default="#FF0000,#00FF00,#0000FF", + tooltip="Comma-separated hex colors, one per batch item"), + io.Float.Input("strength", default=1.0, min=0.0, max=2.0, step=0.05), + io.Combo.Input("mode", options=["interpolated", "type_i", "type_ii"], + default="interpolated"), + io.Int.Input("start_step", default=8, min=0, max=50), + io.Int.Input("end_step", default=10, min=0, max=50), + io.Mask.Input("mask", optional=True), + ], + outputs=[ + io.Model.Output(display_name="model"), + io.Int.Output(display_name="batch_size"), + ], + ) + + @classmethod + def execute(cls, model, lcs_data, colors, strength, 
def _build_tone_fn(lcs_data, contrast, brightness, saturation, color_temperature,
                   start_step, end_step, mask):
    """Build the post_cfg_function closure for tone adjustment.

    Operates directly in 3D LCS space by decomposing each point into lightness
    (projection onto the black→white anchor axis) and chroma (the perpendicular
    remainder). No HSL round-trip.

    Args:
        lcs_data: Calibration object. ``anchor_lcs`` is read here when
            ``color_temperature`` is non-zero; the hook calls
            ``lcs_data.to(device, dtype)`` and reads ``basis``, ``mean`` and
            ``anchor_lcs``.
        contrast: Multiplier applied to lightness around its per-item mean.
        brightness: Additive lightness shift.
        saturation: Multiplier applied to chroma.
        color_temperature: Signed weight of the shift along the warm↔cool
            direction; values within ``_EPS`` of zero disable the shift.
        start_step: First sampler step (inclusive) on which the hook acts.
        end_step: Last sampler step (inclusive) on which the hook acts.
        mask: Optional mask forwarded to ``downsample_mask``.

    Returns:
        A callable that takes the post-CFG ``args`` dict and returns the
        (possibly modified) denoised tensor.
    """
    # Precompute the warm/cool direction once; it depends only on calibration data.
    # Scaled so that color_temperature=1.0 shifts by the mean anchor chroma magnitude.
    warm_dir = None
    if abs(color_temperature) > _EPS:
        anchors = lcs_data.anchor_lcs  # [8, 3]
        black_anc, white_anc = anchors[6], anchors[7]
        axis = white_anc - black_anc
        axis_sq = (axis * axis).sum()

        def _anchor_chroma(idx):
            """Return the component of anchor ``idx`` perpendicular to the achromatic axis."""
            point = anchors[idx]
            along = ((point - black_anc) * axis).sum() / axis_sq
            return point - (black_anc + along * axis)

        chromas = [_anchor_chroma(i) for i in range(6)]
        warm_center = (chromas[0] + chromas[5]) / 2  # Red + Yellow
        cool_center = (chromas[1] + chromas[4]) / 2  # Blue + Cyan
        wd = warm_center - cool_center
        wd_norm = wd.norm()
        mean_chroma_norm = sum(c.norm() for c in chromas) / 6.0
        candidate = wd / wd_norm * mean_chroma_norm

        # A zero-length or non-finite direction (degenerate anchors) would inject
        # NaN into every patch of the latent; skip the temperature shift instead.
        if float(wd_norm) > _EPS and bool(torch.isfinite(candidate).all()):
            warm_dir = candidate
        else:
            import logging  # local: only needed on this rare diagnostic path

            logging.getLogger(__name__).warning(
                "LCS tone: warm/cool axis is degenerate; color_temperature is ignored"
            )

    def post_cfg_fn(args):
        """Post-CFG hook: project to LCS, adjust contrast/brightness/saturation, reconstruct."""
        denoised = args["denoised"]
        sigma = args["sigma"]
        model = args["model"]

        # Step gating: outside [start_step, end_step] the prediction is untouched.
        sigmas = args["model_options"]["transformer_options"]["sample_sigmas"]
        step_index = _find_step_index(sigma, sigmas)
        if step_index < start_step or step_index > end_step:
            return denoised

        # Unpack LTXAV packed format if needed
        working, pack_info = unpack_video_if_needed(denoised, args)

        sigma_val = float(sigma.flatten()[0])
        device = working.device
        dtype = working.dtype

        ld = lcs_data.to(device, dtype)
        B_mat = ld.basis
        mu = ld.mean
        anchor_lcs = ld.anchor_lcs

        # Convert from process_in to raw VAE space, then split into patches
        raw = denoised_to_raw(working, model)
        patches, h_len, w_len, extra_shape = patchify(raw)
        if patches is None:
            return denoised  # Incompatible latent format

        # Project to LCS and keep the part the projection cannot represent
        projection = (patches - mu) @ B_mat
        residual = patches - (projection @ B_mat.T + mu)

        # Timestep statistics for the current sigma and for the t=50 reference
        alpha_t, beta_t = get_alpha_beta(sigma_val, device=device)
        alpha_t, beta_t = alpha_t.to(dtype), beta_t.to(dtype)
        alpha_50, beta_50 = get_alpha_beta_t50(device=device)
        alpha_50, beta_50 = alpha_50.to(dtype), beta_50.to(dtype)

        # Normalize to t=50
        c_norm = normalize_to_t50(projection, alpha_t, beta_t, alpha_50, beta_50)

        # Achromatic axis: black → white in LCS anchor space
        black = anchor_lcs[6]  # [3]
        white = anchor_lcs[7]  # [3]
        a = white - black
        a_sq = (a * a).sum()

        # Decompose into lightness (scalar along the axis) and chroma (perpendicular)
        l_scalar = ((c_norm - black) * a).sum(dim=-1) / a_sq  # [B, L]
        c_L = black + l_scalar.unsqueeze(-1) * a  # [B, L, 3]
        chroma = c_norm - c_L  # [B, L, 3]

        # Contrast around the per-item mean, plus brightness shift. Deliberately
        # not clamped: LCS coordinates are unbounded during denoising.
        l_mean = l_scalar.mean(dim=-1, keepdim=True)  # [B, 1]
        l_new = (l_scalar - l_mean) * contrast + l_mean + brightness

        # Color temperature: shift chroma along the precomputed warm↔cool direction
        if warm_dir is not None:
            chroma = chroma + color_temperature * warm_dir.to(device=device, dtype=dtype)

        # Saturation scales the (possibly shifted) chroma
        chroma_new = chroma * saturation

        # Reconstruct in normalized LCS space
        new_c_norm = black + l_new.unsqueeze(-1) * a + chroma_new  # [B, L, 3]

        # Apply mask if provided
        if mask is not None:
            mask_flat = downsample_mask(mask, h_len, w_len, device, dtype)
            if mask_flat.shape[1] != new_c_norm.shape[1]:
                mask_flat = mask_flat[:, :new_c_norm.shape[1], :]
            new_c_norm = c_norm + mask_flat * (new_c_norm - c_norm)

        # Denormalize back to timestep t and rebuild the full patches
        new_projection = denormalize_from_t50(new_c_norm, alpha_t, beta_t, alpha_50, beta_50)
        patches_new = new_projection @ B_mat.T + mu + residual

        # Unpatchify and convert back to process_in space
        raw_new = unpatchify(patches_new, h_len, w_len, extra_shape)
        modified = raw_to_denoised(raw_new, model).to(dtype)

        # Repack if LTXAV
        return repack_video_if_needed(modified, pack_info)

    return post_cfg_fn
+TONE_PRESETS = { + "Custom": None, + "Base": {"contrast": 1.0, "brightness": 0.0, "saturation": 1.0, "color_temperature": 0.0}, + "Cinematic": {"contrast": 1.20, "brightness": -0.05, "saturation": 0.90, "color_temperature": 0.05}, + "HDR": {"contrast": 1.40, "brightness": 0.0, "saturation": 1.20, "color_temperature": 0.0}, + "Vivid": {"contrast": 1.10, "brightness": 0.0, "saturation": 1.50, "color_temperature": 0.0}, + "Dramatic": {"contrast": 1.50, "brightness": -0.10, "saturation": 0.85, "color_temperature": 0.0}, + "Low Key": {"contrast": 1.30, "brightness": -0.20, "saturation": 0.80, "color_temperature": 0.0}, + "High Key": {"contrast": 0.80, "brightness": 0.20, "saturation": 0.90, "color_temperature": 0.0}, + "Warm": {"contrast": 1.05, "brightness": 0.03, "saturation": 1.10, "color_temperature": 0.30}, + "Cool": {"contrast": 1.05, "brightness": 0.0, "saturation": 1.05, "color_temperature": -0.30}, + "Desaturated": {"contrast": 1.0, "brightness": 0.0, "saturation": 0.40, "color_temperature": 0.0}, +} + + +class LCSToneAdjust(io.ComfyNode): + """Adjust tone (contrast, brightness, saturation, color temperature) in the Latent Color Subspace. + + Decomposes each patch into lightness (projection onto black→white axis) + and chroma (perpendicular residual). Contrast scales lightness around its + mean, brightness shifts it, saturation scales the chroma magnitude, and + color temperature shifts chroma along the warm↔cool axis. + All math is done directly in 3D LCS — no HSL round-trip needed. + Select a preset for one-click tonal styles, or use Custom to set sliders manually. 
+ """ + + @classmethod + def define_schema(cls) -> io.Schema: + """Define inputs and MODEL output for tone adjustment.""" + return io.Schema( + node_id="LCSToneAdjust", + display_name="LCS Tone Adjust", + category="LCS/intervention", + description="Adjust tone (contrast, brightness, saturation, color temperature) via Latent Color Subspace", + inputs=[ + io.Model.Input("model"), + LCS_DATA.Input("lcs_data", tooltip="Calibration data from LCSLoadData"), + io.Combo.Input("preset", options=list(TONE_PRESETS.keys()), default="Custom", + tooltip="Select a tonal preset or Custom to use the sliders below"), + io.Float.Input("contrast", default=1.0, min=0.0, max=3.0, step=0.05, + tooltip="Lightness contrast multiplier (>1 = more contrast, <1 = less, 1 = no change)"), + io.Float.Input("brightness", default=0.0, min=-1.0, max=1.0, step=0.05, + tooltip="Lightness shift (>0 = brighter, <0 = darker)"), + io.Float.Input("saturation", default=1.0, min=0.0, max=3.0, step=0.05, + tooltip="Saturation multiplier (>1 = more vivid, <1 = more muted, 0 = grayscale)"), + io.Float.Input("color_temperature", default=0.0, min=-1.0, max=1.0, step=0.05, + tooltip="Color temperature shift (>0 = warmer/amber, <0 = cooler/blue)"), + io.Int.Input("start_step", default=5, min=0, max=50, + tooltip="First step to apply adjustment"), + io.Int.Input("end_step", default=15, min=0, max=50, + tooltip="Last step to apply adjustment"), + io.Mask.Input("mask", optional=True, + tooltip="Optional mask for localized adjustment"), + ], + outputs=[ + io.Model.Output(display_name="model"), + ], + ) + + @classmethod + def execute(cls, model, lcs_data, preset, contrast, brightness, saturation, + color_temperature, start_step, end_step, mask=None) -> io.NodeOutput: + """Clone model, attach LCS tone adjustment hook. Returns patched MODEL. + + Frontend JS syncs preset values into sliders on user interaction. 
def _latent_to_color_preview(samples, lcs_data, sigma, upscale=8, model=None):
    """Render a latent tensor as an LCS color preview image.

    samples: [B, C, H, W] or [B, C, T, H, W] in process_in space
    lcs_data: calibration object; moved to the device/dtype of ``samples``
    sigma: noise level handed to ``get_alpha_beta`` for normalization
    upscale: nearest-neighbour enlargement factor applied when > 1
    model: if provided, its latent format drives the space conversion;
           otherwise the FLUX scale/shift constants are used.
    Returns: [B, H_up, W_up, 3] float32 in [0,1], on CPU
    """
    device, dtype = samples.device, samples.dtype
    ld = lcs_data.to(device, dtype)

    # process_in space -> raw VAE space
    raw = (
        denoised_to_raw(samples, model)
        if model is not None
        else samples / _FLUX_SCALE_FACTOR + _FLUX_SHIFT_FACTOR
    )

    patches, h_len, w_len, _ = patchify(raw)
    if patches is None:
        # Incompatible latent format — hand back a small black placeholder
        batch = samples.shape[0] if samples.ndim >= 4 else 1
        side = upscale * 2
        return torch.zeros(batch, side, side, 3)

    projection = (patches - ld.mean) @ ld.basis

    alpha_t, beta_t = (stat.to(dtype) for stat in get_alpha_beta(sigma, device=device))
    alpha_50, beta_50 = (stat.to(dtype) for stat in get_alpha_beta_t50(device=device))
    c_norm = normalize_to_t50(projection, alpha_t, beta_t, alpha_50, beta_50)

    def _render(item_lcs):
        """Decode one batch item's LCS points to an (optionally upscaled) RGB grid."""
        hue, sat, lum = decode_lcs_to_hsl(item_lcs, ld.anchor_lcs, ld.anchor_angles)
        red, green, blue = hsl_to_rgb(hue, sat, lum)
        grid = torch.stack([red, green, blue], dim=-1).reshape(h_len, w_len, 3)
        if upscale > 1:
            # F.interpolate wants channels-first with a batch axis
            chw = grid.permute(2, 0, 1).unsqueeze(0)
            chw = F.interpolate(chw, scale_factor=upscale, mode="nearest")
            grid = chw.squeeze(0).permute(1, 2, 0)
        return grid

    frames = [_render(c_norm[i]) for i in range(c_norm.shape[0])]
    return torch.stack(frames, dim=0).clamp(0, 1).cpu().float()
class LCSStepObserver(io.ComfyNode):
    """Patch a model so every post-CFG call writes an LCS color preview to disk.

    The installed hook renders a preview of the first batch item and saves it in
    ComfyUI's temp directory as lcs_step_NNN_sX.XXX.png. The denoised prediction
    is passed through unchanged.
    """

    @classmethod
    def define_schema(cls) -> io.Schema:
        """Declare the node: MODEL and LCS_DATA inputs, one MODEL output."""
        node_inputs = [
            io.Model.Input("model"),
            LCS_DATA.Input("lcs_data"),
        ]
        node_outputs = [
            io.Model.Output(display_name="model"),
        ]
        return io.Schema(
            node_id="LCSStepObserver",
            display_name="LCS Step Observer",
            category="LCS/observe",
            description="Patches model to save per-step LCS color previews to temp directory",
            inputs=node_inputs,
            outputs=node_outputs,
        )

    @classmethod
    def execute(cls, model, lcs_data) -> io.NodeOutput:
        """Clone the model, install the observer hook, and return the patched clone."""
        patched = model.clone()
        # NOTE: counts hook invocations for the lifetime of this patched model and
        # is never reset, so numbering continues if the same patched model is
        # sampled again.
        calls_seen = 0

        def observer_fn(args):
            """Post-CFG hook: render a color preview, save it, return input untouched."""
            nonlocal calls_seen
            prediction = args["denoised"]
            sigma = args["sigma"]
            hook_model = args["model"]
            sigma_val = float(sigma.flatten()[0])

            # Unpack LTXAV packed format if needed
            unpacked, _ = unpack_video_if_needed(prediction, args)

            # Only the first batch item is previewed
            preview = _latent_to_color_preview(
                unpacked[:1], lcs_data, sigma_val, upscale=4, model=hook_model
            )

            out_dir = folder_paths.get_temp_directory()
            pixels = (preview[0].numpy() * 255).clip(0, 255).astype(np.uint8)
            image = PILImage.fromarray(pixels)
            file_name = f"lcs_step_{calls_seen:03d}_s{sigma_val:.3f}.png"
            image.save(os.path.join(out_dir, file_name))
            calls_seen += 1

            return prediction  # observation only

        patched.set_model_sampler_post_cfg_function(observer_fn)
        return io.NodeOutput(patched)
def _load_sharpness(path: str) -> SharpnessData:
    """Read a safetensors file at ``path`` and rebuild a SharpnessData from it."""
    stored = load_file(path)
    # "sign" is stored as a one-element tensor; recover the Python float.
    sign_value = float(stored["sign"].item())
    # "lcs_basis" is only present when it was set at save time.
    color_basis = stored.get("lcs_basis")
    return SharpnessData(
        basis=stored["basis"],
        mean=stored["mean"],
        sign=sign_value,
        lcs_basis=color_basis,
    )
def _build_sharpness_fn(sharpness_data, strength, start_step, end_step, mask):
    """Create the post-CFG hook that nudges patches along the sharpness direction.

    The edit is purely additive: patches_new = patches + edit_vec. Adding a
    vector along a single direction leaves every other direction untouched, so
    no explicit projection / residual / reconstruction step is required.

    The edit vector is derived once, here, from the first column of
    ``sharpness_data.basis``: its mean (DC component) is subtracted, then, when
    ``sharpness_data.lcs_basis`` is present, its projection onto that basis is
    removed. The result is scaled by ``strength * sharpness_data.sign``.
    The hook acts only on steps in [start_step, end_step]; ``mask``, when given,
    weights the edit per patch.
    """
    # First basis column, copied so the calibration tensor is never modified.
    direction = sharpness_data.basis[:, 0].clone()
    # Drop the DC component to avoid a brightness shift.
    direction = direction - direction.mean()
    color_basis = sharpness_data.lcs_basis
    if color_basis is not None:
        # DC removal can reintroduce a color component; project it out again.
        B = color_basis.to(direction.device, direction.dtype)
        direction = direction - B @ (B.T @ direction)
    edit_vec = (strength * sharpness_data.sign) * direction

    def post_cfg_fn(args):
        """Post-CFG hook: add the (optionally masked) sharpness edit to every patch."""
        denoised = args["denoised"]
        sigma = args["sigma"]
        model = args["model"]

        # Step gating
        schedule = args["model_options"]["transformer_options"]["sample_sigmas"]
        step_index = find_step_index(sigma, schedule)
        if not (start_step <= step_index <= end_step):
            return denoised

        # Unpack LTXAV packed format if needed
        working, pack_info = unpack_video_if_needed(denoised, args)
        device, dtype = working.device, working.dtype

        # Bring the precomputed edit onto the working device/dtype
        shift = edit_vec.to(device=device, dtype=dtype)

        # process_in space -> raw VAE space -> patches
        raw = denoised_to_raw(working, model)
        patches, h_len, w_len, extra_shape = patchify(raw)
        if patches is None:
            return denoised  # Incompatible latent format

        if mask is None:
            patches_new = patches + shift
        else:
            weights = downsample_mask(mask, h_len, w_len, device, dtype)
            if weights.shape[1] != patches.shape[1]:
                weights = weights[:, :patches.shape[1], :]
            patches_new = patches + weights * shift

        # patches -> raw VAE space -> process_in space
        raw_new = unpatchify(patches_new, h_len, w_len, extra_shape)
        modified = raw_to_denoised(raw_new, model).to(dtype)

        # Repack if LTXAV
        return repack_video_if_needed(modified, pack_info)

    return post_cfg_fn
When calibrated + with lcs_data, the sharpness direction is orthogonal to color, so color is + preserved by construction. + Positive strength = sharper, negative = blurrier. + """ + + @classmethod + def define_schema(cls) -> io.Schema: + return io.Schema( + node_id="LCSSharpnessIntervene", + display_name="LCS Sharpness Intervene", + category="LCS/intervention", + description="Control sharpness during FLUX generation (positive = sharper, negative = blurrier)", + inputs=[ + io.Model.Input("model"), + SHARPNESS_DATA.Input("sharpness_data", tooltip="Calibration data from LCSSharpnessCalibrate"), + io.Float.Input("strength", default=0.0, min=-5.0, max=5.0, step=0.1, + tooltip="Sharpness strength (>0 = sharper, <0 = blurrier, 0 = no change)"), + io.Int.Input("start_step", default=5, min=0, max=50, + tooltip="First step to apply sharpness intervention"), + io.Int.Input("end_step", default=15, min=0, max=50, + tooltip="Last step to apply sharpness intervention"), + io.Mask.Input("mask", optional=True, + tooltip="Optional mask for localized sharpness control"), + ], + outputs=[ + io.Model.Output(display_name="model"), + ], + ) + + @classmethod + def execute(cls, model, sharpness_data, strength, start_step, end_step, + mask=None) -> io.NodeOutput: + m = model.clone() + # Skip hook when strength is zero (true no-op) + if abs(strength) > 1e-6: + hook = _build_sharpness_fn(sharpness_data, strength, start_step, end_step, mask) + m.set_model_sampler_post_cfg_function(hook) + return io.NodeOutput(m) diff --git a/custom_nodes/ComfyUI-LCS/requirements.txt b/custom_nodes/ComfyUI-LCS/requirements.txt new file mode 100644 index 0000000000000000000000000000000000000000..146237c43ec6bf86e4198a15f6111156500b8230 --- /dev/null +++ b/custom_nodes/ComfyUI-LCS/requirements.txt @@ -0,0 +1,2 @@ +einops +safetensors diff --git a/custom_nodes/ComfyUI-LCS/web/js/tone_preset.js b/custom_nodes/ComfyUI-LCS/web/js/tone_preset.js new file mode 100644 index 
// Copy the values of the named preset into the node's matching slider widgets,
// then ask the graph to redraw. Names without an entry in TONE_PRESETS
// (e.g. "Custom") leave the sliders as they are.
function syncPreset(node, presetName) {
    const values = TONE_PRESETS[presetName];
    if (!values) {
        return;
    }

    SLIDER_NAMES.forEach((sliderName) => {
        const target = node.widgets?.find((w) => w.name === sliderName);
        if (!target || !(sliderName in values)) {
            return;
        }
        target.value = values[sliderName];
    });

    node.graph?.setDirtyCanvas(true, true);
}
node.widgets?.find(w => w.name === "preset"); + if (!presetWidget) return; + + // Only sync on explicit user interaction (dropdown change), not on + // workflow load or paste — those restore saved slider values directly. + const origCallback = presetWidget.callback; + presetWidget.callback = function (value, canvas, nodeRef, pos, event) { + if (origCallback) origCallback.call(this, value, canvas, nodeRef, pos, event); + syncPreset(node, value); + }; + }, +}); diff --git a/models/FlashVSR-v1.1/.gitattributes b/models/FlashVSR-v1.1/.gitattributes new file mode 100644 index 0000000000000000000000000000000000000000..a6344aac8c09253b3b630fb776ae94478aa0275b --- /dev/null +++ b/models/FlashVSR-v1.1/.gitattributes @@ -0,0 +1,35 @@ +*.7z filter=lfs diff=lfs merge=lfs -text +*.arrow filter=lfs diff=lfs merge=lfs -text +*.bin filter=lfs diff=lfs merge=lfs -text +*.bz2 filter=lfs diff=lfs merge=lfs -text +*.ckpt filter=lfs diff=lfs merge=lfs -text +*.ftz filter=lfs diff=lfs merge=lfs -text +*.gz filter=lfs diff=lfs merge=lfs -text +*.h5 filter=lfs diff=lfs merge=lfs -text +*.joblib filter=lfs diff=lfs merge=lfs -text +*.lfs.* filter=lfs diff=lfs merge=lfs -text +*.mlmodel filter=lfs diff=lfs merge=lfs -text +*.model filter=lfs diff=lfs merge=lfs -text +*.msgpack filter=lfs diff=lfs merge=lfs -text +*.npy filter=lfs diff=lfs merge=lfs -text +*.npz filter=lfs diff=lfs merge=lfs -text +*.onnx filter=lfs diff=lfs merge=lfs -text +*.ot filter=lfs diff=lfs merge=lfs -text +*.parquet filter=lfs diff=lfs merge=lfs -text +*.pb filter=lfs diff=lfs merge=lfs -text +*.pickle filter=lfs diff=lfs merge=lfs -text +*.pkl filter=lfs diff=lfs merge=lfs -text +*.pt filter=lfs diff=lfs merge=lfs -text +*.pth filter=lfs diff=lfs merge=lfs -text +*.rar filter=lfs diff=lfs merge=lfs -text +*.safetensors filter=lfs diff=lfs merge=lfs -text +saved_model/**/* filter=lfs diff=lfs merge=lfs -text +*.tar.* filter=lfs diff=lfs merge=lfs -text +*.tar filter=lfs diff=lfs merge=lfs -text +*.tflite filter=lfs 
diff=lfs merge=lfs -text +*.tgz filter=lfs diff=lfs merge=lfs -text +*.wasm filter=lfs diff=lfs merge=lfs -text +*.xz filter=lfs diff=lfs merge=lfs -text +*.zip filter=lfs diff=lfs merge=lfs -text +*.zst filter=lfs diff=lfs merge=lfs -text +*tfevents* filter=lfs diff=lfs merge=lfs -text diff --git a/models/FlashVSR-v1.1/LQ_proj_in.ckpt b/models/FlashVSR-v1.1/LQ_proj_in.ckpt new file mode 100644 index 0000000000000000000000000000000000000000..22366868eea9a8e63218a5f3c2266c80fabbdc04 --- /dev/null +++ b/models/FlashVSR-v1.1/LQ_proj_in.ckpt @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:d6d011cdaaba6a52645086caa08fa04124e746f6ca568140a24007591142bfd2 +size 575694948 diff --git a/models/FlashVSR-v1.1/README.md b/models/FlashVSR-v1.1/README.md new file mode 100644 index 0000000000000000000000000000000000000000..0b592ae115fb8061ce99e8e701c9688a255fa03d --- /dev/null +++ b/models/FlashVSR-v1.1/README.md @@ -0,0 +1,221 @@ +--- +license: apache-2.0 +pipeline_tag: video-to-video +--- + +# ⚡ FlashVSR + +**Towards Real-Time Diffusion-Based Streaming Video Super-Resolution** + +**Authors:** Junhao Zhuang, Shi Guo, Xin Cai, Xiaohui Li, Yihao Liu, Chun Yuan, Tianfan Xue + +   +   +   +   +   + + +**Your star means a lot for us to develop this project!** :star: + + + +--- + +### 🌟 Abstract + +Diffusion models have recently advanced video restoration, but applying them to real-world video super-resolution (VSR) remains challenging due to high latency, prohibitive computation, and poor generalization to ultra-high resolutions. Our goal in this work is to make diffusion-based VSR practical by achieving **efficiency, scalability, and real-time performance**. To this end, we propose **FlashVSR**, the first diffusion-based one-step streaming framework towards real-time VSR. 
**FlashVSR runs at ∼17 FPS for 768 × 1408 videos on a single A100 GPU** by combining three complementary innovations: (i) a train-friendly three-stage distillation pipeline that enables streaming super-resolution, (ii) locality-constrained sparse attention that cuts redundant computation while bridging the train–test resolution gap, and (iii) a tiny conditional decoder that accelerates reconstruction without sacrificing quality. To support large-scale training, we also construct **VSR-120K**, a new dataset with 120k videos and 180k images. Extensive experiments show that FlashVSR scales reliably to ultra-high resolutions and achieves **state-of-the-art performance with up to ∼12× speedup** over prior one-step diffusion VSR models. + +--- + +### 📰 News + +- **Nov 2025 — 🎉 [FlashVSR v1.1](https://huggingface.co/JunhaoZhuang/FlashVSR-v1.1) released:** enhanced stability + fidelity +- **Oct 2025 — [FlashVSR v1](https://huggingface.co/JunhaoZhuang/FlashVSR) (initial release)**: Inference code and model weights are available now! 🎉 +- **Bug Fix (October 21, 2025):** Fixed `local_attention_mask` update logic to prevent artifacts when switching between different aspect ratios during continuous inference. +- **Coming Soon:** Dataset release (**VSR-120K**) for large-scale training. + +--- + +### 📢 Important Quality Note (ComfyUI & other third-party implementations) + +First of all, huge thanks to the community for the fast adoption, feedback, and contributions to FlashVSR! 🙌 +During community testing, we noticed that some third-party implementations of FlashVSR (e.g. early ComfyUI versions) do **not include our Locality-Constrained Sparse Attention (LCSA)** module and instead fall back to **dense attention**. This may lead to **noticeable quality degradation**, especially at higher resolutions. 
+Community discussion: https://github.com/kijai/ComfyUI-WanVideoWrapper/issues/1441 + +Below is a comparison example provided by a community member: + +| Fig.1 – LR Input Video | Fig.2 – 3rd-party (no LCSA) | Fig.3 – Official FlashVSR | +|------------------|-----------------------------------------------|--------------------------------------| +| | | | + +✅ The **official FlashVSR pipeline (this repository)**: +- **Better preserves fine structures and details** +- **Effectively avoids texture aliasing and visual artifacts** + +We are also working on a **version that does not rely on the Block-Sparse Attention library** while keeping **the same output quality**; this alternative may run slower than the optimized original implementation. + +Thanks again to the community for actively testing and helping improve FlashVSR together! 🚀 + +--- + +### 📋 TODO + +- ✅ Release inference code and model weights +- ⬜ Release dataset (VSR-120K) + +--- + +### 🚀 Getting Started + +Follow these steps to set up and run **FlashVSR** on your local machine: + +> ⚠️ **Note:** This project is primarily designed and optimized for **4× video super-resolution**. +> We **strongly recommend** using the **4× SR setting** to achieve better results and stability. ✅ + +#### 1️⃣ Clone the Repository + +```bash +git clone https://github.com/OpenImagingLab/FlashVSR +cd FlashVSR +``` + +#### 2️⃣ Set Up the Python Environment + +Create and activate the environment (**Python 3.11.13**): + +```bash +conda create -n flashvsr python=3.11.13 +conda activate flashvsr +``` + +Install project dependencies: + +```bash +pip install -e . +pip install -r requirements.txt +``` + +#### 3️⃣ Install Block-Sparse Attention (Required) + +FlashVSR relies on the **Block-Sparse Attention** backend to enable flexible and dynamic attention masking for efficient inference. 
+ +> **⚠️ Note:** +> +> * The Block-Sparse Attention build process can be memory-intensive, especially when compiling in parallel with multiple `ninja` jobs. It is recommended to keep sufficient memory available during compilation to avoid OOM errors. Once the build is complete, runtime memory usage is stable and not an issue. +> * Based on our testing, the Block-Sparse Attention backend works correctly on **NVIDIA A100 and A800** (Ampere) with **ideal acceleration performance**, and it also runs correctly on **H200** (Hopper) but the acceleration is limited due to hardware scheduling differences and sparse kernel behavior. **Compatibility and performance on other GPUs (e.g., RTX 40/50 series or H800) are currently unknown**. For more details, please refer to the official documentation: https://github.com/mit-han-lab/Block-Sparse-Attention + + +```bash +# ✅ Recommended: clone and install in a separate clean folder (outside the FlashVSR repo) +git clone https://github.com/mit-han-lab/Block-Sparse-Attention +cd Block-Sparse-Attention +pip install packaging +pip install ninja +python setup.py install +``` + +#### 4️⃣ Download Model Weights from Hugging Face + +FlashVSR provides both **v1** and **v1.1** model weights on Hugging Face (via **Git LFS**). +Please install Git LFS first: + +```bash +# From the repo root +cd examples/WanVSR + +# Install Git LFS (once per machine) +git lfs install + +# Clone v1 (original) or v1.1 (recommended) +git lfs clone https://huggingface.co/JunhaoZhuang/FlashVSR # v1 +# or +git lfs clone https://huggingface.co/JunhaoZhuang/FlashVSR-v1.1 # v1.1 +``` + +After cloning, you should have one of the following folders: + +``` +./examples/WanVSR/FlashVSR/ # v1 +./examples/WanVSR/FlashVSR-v1.1/ # v1.1 +│ +├── LQ_proj_in.ckpt +├── TCDecoder.ckpt +├── Wan2.1_VAE.pth +├── diffusion_pytorch_model_streaming_dmd.safetensors +└── README.md +``` + +> Inference scripts automatically load weights from the corresponding folder. 
+ +--- + +#### 5️⃣ Run Inference + +```bash +# From the repo root +cd examples/WanVSR + +# v1 (original) +python infer_flashvsr_full.py +# or +python infer_flashvsr_tiny.py +# or +python infer_flashvsr_tiny_long_video.py + +# v1.1 (recommended) +python infer_flashvsr_v1.1_full.py +# or +python infer_flashvsr_v1.1_tiny.py +# or +python infer_flashvsr_v1.1_tiny_long_video.py +``` + +--- + +### 🛠️ Method + +The overview of **FlashVSR**. This framework features: + +* **Three-Stage Distillation Pipeline** for streaming VSR training. +* **Locality-Constrained Sparse Attention** to cut redundant computation and bridge the train–test resolution gap. +* **Tiny Conditional Decoder** for efficient, high-quality reconstruction. +* **VSR-120K Dataset** consisting of **120k videos** and **180k images**, supporting joint training on both images and videos. + + + +--- + +### 🤗 Feedback & Support + +We welcome feedback and issues. Thank you for trying **FlashVSR**! + +--- + +### 📄 Acknowledgments + +We gratefully acknowledge the following open-source projects: + +* **DiffSynth Studio** — [https://github.com/modelscope/DiffSynth-Studio](https://github.com/modelscope/DiffSynth-Studio) +* **Block-Sparse-Attention** — [https://github.com/mit-han-lab/Block-Sparse-Attention](https://github.com/mit-han-lab/Block-Sparse-Attention) +* **taehv** — [https://github.com/madebyollin/taehv](https://github.com/madebyollin/taehv) + +--- + +### 📞 Contact + +* **Junhao Zhuang** + Email: [zhuangjh23@mails.tsinghua.edu.cn](mailto:zhuangjh23@mails.tsinghua.edu.cn) + +--- + +### 📜 Citation + +```bibtex +@article{zhuang2025flashvsr, + title={FlashVSR: Towards Real-Time Diffusion-Based Streaming Video Super-Resolution}, + author={Zhuang, Junhao and Guo, Shi and Cai, Xin and Li, Xiaohui and Liu, Yihao and Yuan, Chun and Xue, Tianfan}, + journal={arXiv preprint arXiv:2510.12747}, + year={2025} +} +``` diff --git a/models/FlashVSR-v1.1/TCDecoder.ckpt b/models/FlashVSR-v1.1/TCDecoder.ckpt new file mode 100644 
index 0000000000000000000000000000000000000000..c9172bab1dddb2bfcebf49749292b990ef00209f --- /dev/null +++ b/models/FlashVSR-v1.1/TCDecoder.ckpt @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:e224bdcf2f52745cbf4d393ff5374c2ba09e90285d5d19062d2bf63b915b6161 +size 189018333 diff --git a/models/FlashVSR-v1.1/Wan2.1_VAE.pth b/models/FlashVSR-v1.1/Wan2.1_VAE.pth new file mode 100644 index 0000000000000000000000000000000000000000..5897fba405232a6b07a947d6188d19a8e050ccfb --- /dev/null +++ b/models/FlashVSR-v1.1/Wan2.1_VAE.pth @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:38071ab59bd94681c686fa51d75a1968f64e470262043be31f7a094e442fd981 +size 507609880 diff --git a/models/FlashVSR-v1.1/config.json b/models/FlashVSR-v1.1/config.json new file mode 100644 index 0000000000000000000000000000000000000000..3163ad0144826c126dfd750454ac5d30757cfec8 --- /dev/null +++ b/models/FlashVSR-v1.1/config.json @@ -0,0 +1,3 @@ +{ + "model_type": "flashvsr" +} \ No newline at end of file diff --git a/models/FlashVSR-v1.1/diffusion_pytorch_model_streaming_dmd.safetensors b/models/FlashVSR-v1.1/diffusion_pytorch_model_streaming_dmd.safetensors new file mode 100644 index 0000000000000000000000000000000000000000..cd6a37b2d2caa77887bdc8349e43af2a88d38b5f --- /dev/null +++ b/models/FlashVSR-v1.1/diffusion_pytorch_model_streaming_dmd.safetensors @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:bd28180edcf3446c028e32fc6b731a80bf7e4da2ab4caac3186b9499964d37be +size 5676070392 diff --git a/models/FlashVSR-v1.1/model_index.json b/models/FlashVSR-v1.1/model_index.json new file mode 100644 index 0000000000000000000000000000000000000000..27b27d4163fa0d50959936e28624795c515ade4a --- /dev/null +++ b/models/FlashVSR-v1.1/model_index.json @@ -0,0 +1,4 @@ +{ + "_class_name": "FlashVSRPipeline", + "_diffusers_version": "0.24.0" +} \ No newline at end of file