+
+> [!TIP]
+> * Set `tile_overlap` to 0 and `denoise` to 1 to see the tile seams and then adjust the options to your needs.
+> * Increase `tile_batch_size` to increase speed (if your machine can handle it).
+> * Use the [colorfix node](https://github.com/gameltb/Comfyui-StableSR) if your colors look off.
+
+### Options
+
+| Name | Description |
+|-------------------|--------------------------------------------------------------|
+| `method` | Tiling [strategy](https://github.com/pkuliyi2015/multidiffusion-upscaler-for-automatic1111/blob/fbb24736c9bc374c7f098f82b575fcd14a73936a/scripts/tilediffusion.py#L39-L46). |
+| `tile_width` | Tile's width |
+| `tile_height` | Tile's height |
+| `tile_overlap` | Tile's overlap |
+| `tile_batch_size` | The number of tiles to process in a batch |
+
+### How can I specify the tiles' arrangement?
+
+If you have the [Math Expression](https://github.com/pythongosssss/ComfyUI-Custom-Scripts#math-expression) node (or something similar), you can use that to pass in the latent that's passed in your KSampler and divide the `tile_height`/`tile_width` by the number of rows/columns you want.
+
+`C` = number of columns you want
+`R` = number of rows you want
+
+`pixel width of input image or latent // C` = `tile_width`
+`pixel height of input image or latent // R` = `tile_height`
+
+
+
+### SpotDiffusion
+
+[Paper](https://arxiv.org/abs/2407.15507)
+
+A tiling algorithm that attempts to eliminate seams by randomly shifting the denoise window per timestep. It is mainly used for fast inferences by setting `tile_overlap` to 0; otherwise, it's better to stick with the other tiling strategies as they produce better outputs.
+
+This additional feature is experimental, in testing, and subject to change.
+
+## Tiled VAE
+
+
+
+
+
+
+
+The recommended tile sizes are given upon the creation of the node based on the available VRAM.
+
+> [!NOTE]
+> Enabling `fast` for the decoder may produce images with slightly higher contrast and brightness.
+
+### Options
+
+| Name | Description |
+|-------------|----------------------------------------------------------------------------------------------------------------------------------------------|
+| `tile_size` |
The image is split into tiles, which are then padded with 11/32 pixels' in the decoder/encoder.
|
+| `fast` |
When Fast Mode is disabled:
The original VAE forward is decomposed into a task queue and a task worker, which starts to process each tile.
When GroupNorm is needed, it suspends, stores current GroupNorm mean and var, send everything to RAM, and turns to the next tile.
After all GroupNorm means and vars are summarized, it applies group norm to tiles and continues.
A zigzag execution order is used to reduce unnecessary data transfer.
When Fast Mode is enabled:
The original input is downsampled and passed to a separate task queue.
Its group norm parameters are recorded and used by all tiles' task queues.
Each tile is separately processed without any RAM-VRAM data transfer.
After all tiles are processed, tiles are written to a result buffer and returned.
|
+| `color_fix` |
Only estimate GroupNorm before downsampling, i.e., run in a semi-fast mode.
Only for the encoder. Can restore colors if tiles are too small.
|
+
+
+
+## Workflows
+
+The following images can be loaded in ComfyUI.
+
+
+
+ `;
+ }
+ } catch(e) {
+ setError(e.message);
+ } finally {
+ closeButton.disabled = false;
+ }
+}
+
+app.registerExtension({
+ name: "Comfy.CPackExtension",
+
+ async setup() {
+ const styleTag = document.createElement("style");
+ styleTag.innerHTML = style;
+ document.head.appendChild(styleTag);
+ const menu = document.querySelector(".comfy-menu");
+ const separator = document.createElement("hr");
+
+ separator.style.margin = "20px 0";
+ separator.style.width = "100%";
+ menu.append(separator);
+
+ const packButton = document.createElement("button");
+ packButton.textContent = "Package";
+ packButton.onclick = packageAction;
+ menu.append(packButton);
+
+ const unpackButton = document.createElement("button");
+ unpackButton.textContent = "Unpack";
+ unpackButton.onclick = unpackAction;
+ menu.append(unpackButton);
+
+ const serveButton = document.createElement("button");
+ serveButton.textContent = "Serve";
+ serveButton.onclick = serveAction;
+ menu.append(serveButton);
+
+
+ const buildButton = document.createElement("button");
+ buildButton.textContent = "Deploy";
+ buildButton.onclick = deployAction;
+ menu.append(buildButton);
+
+
+ try {
+ // new style Manager buttons
+
+ // unload models button into new style Manager button
+ let cmGroup1 = new (await import("../../scripts/ui/components/buttonGroup.js")).ComfyButtonGroup(
+ new(await import("../../scripts/ui/components/button.js")).ComfyButton({
+ icon: "package-variant-closed",
+ action: packageAction,
+ tooltip: "Comfy-Pack",
+ content: "Package",
+ classList: "comfyui-button comfyui-menu-mobile-collapse primary"
+ }).element,
+ new(await import("../../scripts/ui/components/button.js")).ComfyButton({
+ icon: "package-variant",
+ action: unpackAction,
+ tooltip: "Comfy-Pack",
+ content: "Unpack",
+ classList: "comfyui-button comfyui-menu-mobile-collapse"
+ }).element,
+ );
+
+ app.menu?.settingsGroup.element.before(cmGroup1.element);
+
+ let cmGroup2 = new (await import("../../scripts/ui/components/buttonGroup.js")).ComfyButtonGroup(
+ new(await import("../../scripts/ui/components/button.js")).ComfyButton({
+ icon: "api",
+ action: serveAction,
+ tooltip: "Comfy-Pack",
+ content: "Serve",
+ classList: "comfyui-button comfyui-menu-mobile-collapse primary"
+ }).element,
+ new(await import("../../scripts/ui/components/button.js")).ComfyButton({
+ icon: "cloud-upload",
+ action: deployAction,
+ tooltip: "Comfy-Pack",
+ content: "Deploy",
+ classList: "comfyui-button comfyui-menu-mobile-collapse"
+ }).element,
+ );
+
+ app.menu?.settingsGroup.element.before(cmGroup2.element);
+ }
+ catch(exception) {
+ console.log('ComfyUI is outdated. New style menu based features are disabled.');
+ }
+ }
+});
diff --git a/custom_nodes/comfyui-advanced-controlnet/.github/workflows/publish.yml b/custom_nodes/comfyui-advanced-controlnet/.github/workflows/publish.yml
new file mode 100644
index 0000000000000000000000000000000000000000..27bcdb11695ea19286da34cd5f36cf75155a4dbb
--- /dev/null
+++ b/custom_nodes/comfyui-advanced-controlnet/.github/workflows/publish.yml
@@ -0,0 +1,24 @@
+name: Publish to Comfy registry
+on:
+ workflow_dispatch:
+ push:
+ branches:
+ - main
+ paths:
+ - "pyproject.toml"
+
+permissions:
+ issues: write
+
+jobs:
+ publish-node:
+ name: Publish Custom Node to registry
+ runs-on: ubuntu-latest
+ if: ${{ github.repository_owner == 'Kosinkadink' }}
+ steps:
+ - name: Check out code
+ uses: actions/checkout@v4
+ - name: Publish Custom Node
+ uses: Comfy-Org/publish-node-action@v1
+ with:
+ personal_access_token: ${{ secrets.REGISTRY_ACCESS_TOKEN }} ## Add your own personal access token to your Github Repository secrets and reference it here.
diff --git a/custom_nodes/comfyui-advanced-controlnet/.gitignore b/custom_nodes/comfyui-advanced-controlnet/.gitignore
new file mode 100644
index 0000000000000000000000000000000000000000..68bc17f9ff2104a9d7b6777058bb4c343ca72609
--- /dev/null
+++ b/custom_nodes/comfyui-advanced-controlnet/.gitignore
@@ -0,0 +1,160 @@
+# Byte-compiled / optimized / DLL files
+__pycache__/
+*.py[cod]
+*$py.class
+
+# C extensions
+*.so
+
+# Distribution / packaging
+.Python
+build/
+develop-eggs/
+dist/
+downloads/
+eggs/
+.eggs/
+lib/
+lib64/
+parts/
+sdist/
+var/
+wheels/
+share/python-wheels/
+*.egg-info/
+.installed.cfg
+*.egg
+MANIFEST
+
+# PyInstaller
+# Usually these files are written by a python script from a template
+# before PyInstaller builds the exe, so as to inject date/other infos into it.
+*.manifest
+*.spec
+
+# Installer logs
+pip-log.txt
+pip-delete-this-directory.txt
+
+# Unit test / coverage reports
+htmlcov/
+.tox/
+.nox/
+.coverage
+.coverage.*
+.cache
+nosetests.xml
+coverage.xml
+*.cover
+*.py,cover
+.hypothesis/
+.pytest_cache/
+cover/
+
+# Translations
+*.mo
+*.pot
+
+# Django stuff:
+*.log
+local_settings.py
+db.sqlite3
+db.sqlite3-journal
+
+# Flask stuff:
+instance/
+.webassets-cache
+
+# Scrapy stuff:
+.scrapy
+
+# Sphinx documentation
+docs/_build/
+
+# PyBuilder
+.pybuilder/
+target/
+
+# Jupyter Notebook
+.ipynb_checkpoints
+
+# IPython
+profile_default/
+ipython_config.py
+
+# pyenv
+# For a library or package, you might want to ignore these files since the code is
+# intended to run in multiple environments; otherwise, check them in:
+# .python-version
+
+# pipenv
+# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
+# However, in case of collaboration, if having platform-specific dependencies or dependencies
+# having no cross-platform support, pipenv may install dependencies that don't work, or not
+# install all needed dependencies.
+#Pipfile.lock
+
+# poetry
+# Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
+# This is especially recommended for binary packages to ensure reproducibility, and is more
+# commonly ignored for libraries.
+# https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
+#poetry.lock
+
+# pdm
+# Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
+#pdm.lock
+# pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it
+# in version control.
+# https://pdm.fming.dev/#use-with-ide
+.pdm.toml
+
+# PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
+__pypackages__/
+
+# Celery stuff
+celerybeat-schedule
+celerybeat.pid
+
+# SageMath parsed files
+*.sage.py
+
+# Environments
+.env
+.venv
+env/
+venv/
+ENV/
+env.bak/
+venv.bak/
+
+# Spyder project settings
+.spyderproject
+.spyproject
+
+# Rope project settings
+.ropeproject
+
+# mkdocs documentation
+/site
+
+# mypy
+.mypy_cache/
+.dmypy.json
+dmypy.json
+
+# Pyre type checker
+.pyre/
+
+# pytype static type analyzer
+.pytype/
+
+# Cython debug symbols
+cython_debug/
+
+# PyCharm
+# JetBrains specific template is maintained in a separate JetBrains.gitignore that can
+# be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
+# and can be added to the global gitignore or merged into this file. For a more nuclear
+# option (not recommended) you can uncomment the following to ignore the entire idea folder.
+#.idea/
diff --git a/custom_nodes/comfyui-advanced-controlnet/.tracking b/custom_nodes/comfyui-advanced-controlnet/.tracking
new file mode 100644
index 0000000000000000000000000000000000000000..a95b70198bcdf720e5e3b8c7724840c9feb46ab0
--- /dev/null
+++ b/custom_nodes/comfyui-advanced-controlnet/.tracking
@@ -0,0 +1,30 @@
+.github/workflows/publish.yml
+.gitignore
+LICENSE
+README.md
+__init__.py
+adv_control/control.py
+adv_control/control_ctrlora.py
+adv_control/control_lllite.py
+adv_control/control_plusplus.py
+adv_control/control_reference.py
+adv_control/control_sparsectrl.py
+adv_control/control_svd.py
+adv_control/dinklink.py
+adv_control/documentation.py
+adv_control/logger.py
+adv_control/nodes.py
+adv_control/nodes_ctrlora.py
+adv_control/nodes_deprecated.py
+adv_control/nodes_keyframes.py
+adv_control/nodes_loosecontrol.py
+adv_control/nodes_main.py
+adv_control/nodes_plusplus.py
+adv_control/nodes_reference.py
+adv_control/nodes_sparsectrl.py
+adv_control/nodes_weight.py
+adv_control/sampling.py
+adv_control/utils.py
+pyproject.toml
+web/js/autosize.js
+web/js/documentation.js
\ No newline at end of file
diff --git a/custom_nodes/comfyui-advanced-controlnet/LICENSE b/custom_nodes/comfyui-advanced-controlnet/LICENSE
new file mode 100644
index 0000000000000000000000000000000000000000..f288702d2fa16d3cdf0035b15a9fcbc552cd88e7
--- /dev/null
+++ b/custom_nodes/comfyui-advanced-controlnet/LICENSE
@@ -0,0 +1,674 @@
+ GNU GENERAL PUBLIC LICENSE
+ Version 3, 29 June 2007
+
+ Copyright (C) 2007 Free Software Foundation, Inc.
+ Everyone is permitted to copy and distribute verbatim copies
+ of this license document, but changing it is not allowed.
+
+ Preamble
+
+ The GNU General Public License is a free, copyleft license for
+software and other kinds of works.
+
+ The licenses for most software and other practical works are designed
+to take away your freedom to share and change the works. By contrast,
+the GNU General Public License is intended to guarantee your freedom to
+share and change all versions of a program--to make sure it remains free
+software for all its users. We, the Free Software Foundation, use the
+GNU General Public License for most of our software; it applies also to
+any other work released this way by its authors. You can apply it to
+your programs, too.
+
+ When we speak of free software, we are referring to freedom, not
+price. Our General Public Licenses are designed to make sure that you
+have the freedom to distribute copies of free software (and charge for
+them if you wish), that you receive source code or can get it if you
+want it, that you can change the software or use pieces of it in new
+free programs, and that you know you can do these things.
+
+ To protect your rights, we need to prevent others from denying you
+these rights or asking you to surrender the rights. Therefore, you have
+certain responsibilities if you distribute copies of the software, or if
+you modify it: responsibilities to respect the freedom of others.
+
+ For example, if you distribute copies of such a program, whether
+gratis or for a fee, you must pass on to the recipients the same
+freedoms that you received. You must make sure that they, too, receive
+or can get the source code. And you must show them these terms so they
+know their rights.
+
+ Developers that use the GNU GPL protect your rights with two steps:
+(1) assert copyright on the software, and (2) offer you this License
+giving you legal permission to copy, distribute and/or modify it.
+
+ For the developers' and authors' protection, the GPL clearly explains
+that there is no warranty for this free software. For both users' and
+authors' sake, the GPL requires that modified versions be marked as
+changed, so that their problems will not be attributed erroneously to
+authors of previous versions.
+
+ Some devices are designed to deny users access to install or run
+modified versions of the software inside them, although the manufacturer
+can do so. This is fundamentally incompatible with the aim of
+protecting users' freedom to change the software. The systematic
+pattern of such abuse occurs in the area of products for individuals to
+use, which is precisely where it is most unacceptable. Therefore, we
+have designed this version of the GPL to prohibit the practice for those
+products. If such problems arise substantially in other domains, we
+stand ready to extend this provision to those domains in future versions
+of the GPL, as needed to protect the freedom of users.
+
+ Finally, every program is threatened constantly by software patents.
+States should not allow patents to restrict development and use of
+software on general-purpose computers, but in those that do, we wish to
+avoid the special danger that patents applied to a free program could
+make it effectively proprietary. To prevent this, the GPL assures that
+patents cannot be used to render the program non-free.
+
+ The precise terms and conditions for copying, distribution and
+modification follow.
+
+ TERMS AND CONDITIONS
+
+ 0. Definitions.
+
+ "This License" refers to version 3 of the GNU General Public License.
+
+ "Copyright" also means copyright-like laws that apply to other kinds of
+works, such as semiconductor masks.
+
+ "The Program" refers to any copyrightable work licensed under this
+License. Each licensee is addressed as "you". "Licensees" and
+"recipients" may be individuals or organizations.
+
+ To "modify" a work means to copy from or adapt all or part of the work
+in a fashion requiring copyright permission, other than the making of an
+exact copy. The resulting work is called a "modified version" of the
+earlier work or a work "based on" the earlier work.
+
+ A "covered work" means either the unmodified Program or a work based
+on the Program.
+
+ To "propagate" a work means to do anything with it that, without
+permission, would make you directly or secondarily liable for
+infringement under applicable copyright law, except executing it on a
+computer or modifying a private copy. Propagation includes copying,
+distribution (with or without modification), making available to the
+public, and in some countries other activities as well.
+
+ To "convey" a work means any kind of propagation that enables other
+parties to make or receive copies. Mere interaction with a user through
+a computer network, with no transfer of a copy, is not conveying.
+
+ An interactive user interface displays "Appropriate Legal Notices"
+to the extent that it includes a convenient and prominently visible
+feature that (1) displays an appropriate copyright notice, and (2)
+tells the user that there is no warranty for the work (except to the
+extent that warranties are provided), that licensees may convey the
+work under this License, and how to view a copy of this License. If
+the interface presents a list of user commands or options, such as a
+menu, a prominent item in the list meets this criterion.
+
+ 1. Source Code.
+
+ The "source code" for a work means the preferred form of the work
+for making modifications to it. "Object code" means any non-source
+form of a work.
+
+ A "Standard Interface" means an interface that either is an official
+standard defined by a recognized standards body, or, in the case of
+interfaces specified for a particular programming language, one that
+is widely used among developers working in that language.
+
+ The "System Libraries" of an executable work include anything, other
+than the work as a whole, that (a) is included in the normal form of
+packaging a Major Component, but which is not part of that Major
+Component, and (b) serves only to enable use of the work with that
+Major Component, or to implement a Standard Interface for which an
+implementation is available to the public in source code form. A
+"Major Component", in this context, means a major essential component
+(kernel, window system, and so on) of the specific operating system
+(if any) on which the executable work runs, or a compiler used to
+produce the work, or an object code interpreter used to run it.
+
+ The "Corresponding Source" for a work in object code form means all
+the source code needed to generate, install, and (for an executable
+work) run the object code and to modify the work, including scripts to
+control those activities. However, it does not include the work's
+System Libraries, or general-purpose tools or generally available free
+programs which are used unmodified in performing those activities but
+which are not part of the work. For example, Corresponding Source
+includes interface definition files associated with source files for
+the work, and the source code for shared libraries and dynamically
+linked subprograms that the work is specifically designed to require,
+such as by intimate data communication or control flow between those
+subprograms and other parts of the work.
+
+ The Corresponding Source need not include anything that users
+can regenerate automatically from other parts of the Corresponding
+Source.
+
+ The Corresponding Source for a work in source code form is that
+same work.
+
+ 2. Basic Permissions.
+
+ All rights granted under this License are granted for the term of
+copyright on the Program, and are irrevocable provided the stated
+conditions are met. This License explicitly affirms your unlimited
+permission to run the unmodified Program. The output from running a
+covered work is covered by this License only if the output, given its
+content, constitutes a covered work. This License acknowledges your
+rights of fair use or other equivalent, as provided by copyright law.
+
+ You may make, run and propagate covered works that you do not
+convey, without conditions so long as your license otherwise remains
+in force. You may convey covered works to others for the sole purpose
+of having them make modifications exclusively for you, or provide you
+with facilities for running those works, provided that you comply with
+the terms of this License in conveying all material for which you do
+not control copyright. Those thus making or running the covered works
+for you must do so exclusively on your behalf, under your direction
+and control, on terms that prohibit them from making any copies of
+your copyrighted material outside their relationship with you.
+
+ Conveying under any other circumstances is permitted solely under
+the conditions stated below. Sublicensing is not allowed; section 10
+makes it unnecessary.
+
+ 3. Protecting Users' Legal Rights From Anti-Circumvention Law.
+
+ No covered work shall be deemed part of an effective technological
+measure under any applicable law fulfilling obligations under article
+11 of the WIPO copyright treaty adopted on 20 December 1996, or
+similar laws prohibiting or restricting circumvention of such
+measures.
+
+ When you convey a covered work, you waive any legal power to forbid
+circumvention of technological measures to the extent such circumvention
+is effected by exercising rights under this License with respect to
+the covered work, and you disclaim any intention to limit operation or
+modification of the work as a means of enforcing, against the work's
+users, your or third parties' legal rights to forbid circumvention of
+technological measures.
+
+ 4. Conveying Verbatim Copies.
+
+ You may convey verbatim copies of the Program's source code as you
+receive it, in any medium, provided that you conspicuously and
+appropriately publish on each copy an appropriate copyright notice;
+keep intact all notices stating that this License and any
+non-permissive terms added in accord with section 7 apply to the code;
+keep intact all notices of the absence of any warranty; and give all
+recipients a copy of this License along with the Program.
+
+ You may charge any price or no price for each copy that you convey,
+and you may offer support or warranty protection for a fee.
+
+ 5. Conveying Modified Source Versions.
+
+ You may convey a work based on the Program, or the modifications to
+produce it from the Program, in the form of source code under the
+terms of section 4, provided that you also meet all of these conditions:
+
+ a) The work must carry prominent notices stating that you modified
+ it, and giving a relevant date.
+
+ b) The work must carry prominent notices stating that it is
+ released under this License and any conditions added under section
+ 7. This requirement modifies the requirement in section 4 to
+ "keep intact all notices".
+
+ c) You must license the entire work, as a whole, under this
+ License to anyone who comes into possession of a copy. This
+ License will therefore apply, along with any applicable section 7
+ additional terms, to the whole of the work, and all its parts,
+ regardless of how they are packaged. This License gives no
+ permission to license the work in any other way, but it does not
+ invalidate such permission if you have separately received it.
+
+ d) If the work has interactive user interfaces, each must display
+ Appropriate Legal Notices; however, if the Program has interactive
+ interfaces that do not display Appropriate Legal Notices, your
+ work need not make them do so.
+
+ A compilation of a covered work with other separate and independent
+works, which are not by their nature extensions of the covered work,
+and which are not combined with it such as to form a larger program,
+in or on a volume of a storage or distribution medium, is called an
+"aggregate" if the compilation and its resulting copyright are not
+used to limit the access or legal rights of the compilation's users
+beyond what the individual works permit. Inclusion of a covered work
+in an aggregate does not cause this License to apply to the other
+parts of the aggregate.
+
+ 6. Conveying Non-Source Forms.
+
+ You may convey a covered work in object code form under the terms
+of sections 4 and 5, provided that you also convey the
+machine-readable Corresponding Source under the terms of this License,
+in one of these ways:
+
+ a) Convey the object code in, or embodied in, a physical product
+ (including a physical distribution medium), accompanied by the
+ Corresponding Source fixed on a durable physical medium
+ customarily used for software interchange.
+
+ b) Convey the object code in, or embodied in, a physical product
+ (including a physical distribution medium), accompanied by a
+ written offer, valid for at least three years and valid for as
+ long as you offer spare parts or customer support for that product
+ model, to give anyone who possesses the object code either (1) a
+ copy of the Corresponding Source for all the software in the
+ product that is covered by this License, on a durable physical
+ medium customarily used for software interchange, for a price no
+ more than your reasonable cost of physically performing this
+ conveying of source, or (2) access to copy the
+ Corresponding Source from a network server at no charge.
+
+ c) Convey individual copies of the object code with a copy of the
+ written offer to provide the Corresponding Source. This
+ alternative is allowed only occasionally and noncommercially, and
+ only if you received the object code with such an offer, in accord
+ with subsection 6b.
+
+ d) Convey the object code by offering access from a designated
+ place (gratis or for a charge), and offer equivalent access to the
+ Corresponding Source in the same way through the same place at no
+ further charge. You need not require recipients to copy the
+ Corresponding Source along with the object code. If the place to
+ copy the object code is a network server, the Corresponding Source
+ may be on a different server (operated by you or a third party)
+ that supports equivalent copying facilities, provided you maintain
+ clear directions next to the object code saying where to find the
+ Corresponding Source. Regardless of what server hosts the
+ Corresponding Source, you remain obligated to ensure that it is
+ available for as long as needed to satisfy these requirements.
+
+ e) Convey the object code using peer-to-peer transmission, provided
+ you inform other peers where the object code and Corresponding
+ Source of the work are being offered to the general public at no
+ charge under subsection 6d.
+
+ A separable portion of the object code, whose source code is excluded
+from the Corresponding Source as a System Library, need not be
+included in conveying the object code work.
+
+ A "User Product" is either (1) a "consumer product", which means any
+tangible personal property which is normally used for personal, family,
+or household purposes, or (2) anything designed or sold for incorporation
+into a dwelling. In determining whether a product is a consumer product,
+doubtful cases shall be resolved in favor of coverage. For a particular
+product received by a particular user, "normally used" refers to a
+typical or common use of that class of product, regardless of the status
+of the particular user or of the way in which the particular user
+actually uses, or expects or is expected to use, the product. A product
+is a consumer product regardless of whether the product has substantial
+commercial, industrial or non-consumer uses, unless such uses represent
+the only significant mode of use of the product.
+
+ "Installation Information" for a User Product means any methods,
+procedures, authorization keys, or other information required to install
+and execute modified versions of a covered work in that User Product from
+a modified version of its Corresponding Source. The information must
+suffice to ensure that the continued functioning of the modified object
+code is in no case prevented or interfered with solely because
+modification has been made.
+
+ If you convey an object code work under this section in, or with, or
+specifically for use in, a User Product, and the conveying occurs as
+part of a transaction in which the right of possession and use of the
+User Product is transferred to the recipient in perpetuity or for a
+fixed term (regardless of how the transaction is characterized), the
+Corresponding Source conveyed under this section must be accompanied
+by the Installation Information. But this requirement does not apply
+if neither you nor any third party retains the ability to install
+modified object code on the User Product (for example, the work has
+been installed in ROM).
+
+ The requirement to provide Installation Information does not include a
+requirement to continue to provide support service, warranty, or updates
+for a work that has been modified or installed by the recipient, or for
+the User Product in which it has been modified or installed. Access to a
+network may be denied when the modification itself materially and
+adversely affects the operation of the network or violates the rules and
+protocols for communication across the network.
+
+ Corresponding Source conveyed, and Installation Information provided,
+in accord with this section must be in a format that is publicly
+documented (and with an implementation available to the public in
+source code form), and must require no special password or key for
+unpacking, reading or copying.
+
+ 7. Additional Terms.
+
+ "Additional permissions" are terms that supplement the terms of this
+License by making exceptions from one or more of its conditions.
+Additional permissions that are applicable to the entire Program shall
+be treated as though they were included in this License, to the extent
+that they are valid under applicable law. If additional permissions
+apply only to part of the Program, that part may be used separately
+under those permissions, but the entire Program remains governed by
+this License without regard to the additional permissions.
+
+ When you convey a copy of a covered work, you may at your option
+remove any additional permissions from that copy, or from any part of
+it. (Additional permissions may be written to require their own
+removal in certain cases when you modify the work.) You may place
+additional permissions on material, added by you to a covered work,
+for which you have or can give appropriate copyright permission.
+
+ Notwithstanding any other provision of this License, for material you
+add to a covered work, you may (if authorized by the copyright holders of
+that material) supplement the terms of this License with terms:
+
+ a) Disclaiming warranty or limiting liability differently from the
+ terms of sections 15 and 16 of this License; or
+
+ b) Requiring preservation of specified reasonable legal notices or
+ author attributions in that material or in the Appropriate Legal
+ Notices displayed by works containing it; or
+
+ c) Prohibiting misrepresentation of the origin of that material, or
+ requiring that modified versions of such material be marked in
+ reasonable ways as different from the original version; or
+
+ d) Limiting the use for publicity purposes of names of licensors or
+ authors of the material; or
+
+ e) Declining to grant rights under trademark law for use of some
+ trade names, trademarks, or service marks; or
+
+ f) Requiring indemnification of licensors and authors of that
+ material by anyone who conveys the material (or modified versions of
+ it) with contractual assumptions of liability to the recipient, for
+ any liability that these contractual assumptions directly impose on
+ those licensors and authors.
+
+ All other non-permissive additional terms are considered "further
+restrictions" within the meaning of section 10. If the Program as you
+received it, or any part of it, contains a notice stating that it is
+governed by this License along with a term that is a further
+restriction, you may remove that term. If a license document contains
+a further restriction but permits relicensing or conveying under this
+License, you may add to a covered work material governed by the terms
+of that license document, provided that the further restriction does
+not survive such relicensing or conveying.
+
+ If you add terms to a covered work in accord with this section, you
+must place, in the relevant source files, a statement of the
+additional terms that apply to those files, or a notice indicating
+where to find the applicable terms.
+
+ Additional terms, permissive or non-permissive, may be stated in the
+form of a separately written license, or stated as exceptions;
+the above requirements apply either way.
+
+ 8. Termination.
+
+ You may not propagate or modify a covered work except as expressly
+provided under this License. Any attempt otherwise to propagate or
+modify it is void, and will automatically terminate your rights under
+this License (including any patent licenses granted under the third
+paragraph of section 11).
+
+ However, if you cease all violation of this License, then your
+license from a particular copyright holder is reinstated (a)
+provisionally, unless and until the copyright holder explicitly and
+finally terminates your license, and (b) permanently, if the copyright
+holder fails to notify you of the violation by some reasonable means
+prior to 60 days after the cessation.
+
+ Moreover, your license from a particular copyright holder is
+reinstated permanently if the copyright holder notifies you of the
+violation by some reasonable means, this is the first time you have
+received notice of violation of this License (for any work) from that
+copyright holder, and you cure the violation prior to 30 days after
+your receipt of the notice.
+
+ Termination of your rights under this section does not terminate the
+licenses of parties who have received copies or rights from you under
+this License. If your rights have been terminated and not permanently
+reinstated, you do not qualify to receive new licenses for the same
+material under section 10.
+
+ 9. Acceptance Not Required for Having Copies.
+
+ You are not required to accept this License in order to receive or
+run a copy of the Program. Ancillary propagation of a covered work
+occurring solely as a consequence of using peer-to-peer transmission
+to receive a copy likewise does not require acceptance. However,
+nothing other than this License grants you permission to propagate or
+modify any covered work. These actions infringe copyright if you do
+not accept this License. Therefore, by modifying or propagating a
+covered work, you indicate your acceptance of this License to do so.
+
+ 10. Automatic Licensing of Downstream Recipients.
+
+ Each time you convey a covered work, the recipient automatically
+receives a license from the original licensors, to run, modify and
+propagate that work, subject to this License. You are not responsible
+for enforcing compliance by third parties with this License.
+
+ An "entity transaction" is a transaction transferring control of an
+organization, or substantially all assets of one, or subdividing an
+organization, or merging organizations. If propagation of a covered
+work results from an entity transaction, each party to that
+transaction who receives a copy of the work also receives whatever
+licenses to the work the party's predecessor in interest had or could
+give under the previous paragraph, plus a right to possession of the
+Corresponding Source of the work from the predecessor in interest, if
+the predecessor has it or can get it with reasonable efforts.
+
+ You may not impose any further restrictions on the exercise of the
+rights granted or affirmed under this License. For example, you may
+not impose a license fee, royalty, or other charge for exercise of
+rights granted under this License, and you may not initiate litigation
+(including a cross-claim or counterclaim in a lawsuit) alleging that
+any patent claim is infringed by making, using, selling, offering for
+sale, or importing the Program or any portion of it.
+
+ 11. Patents.
+
+ A "contributor" is a copyright holder who authorizes use under this
+License of the Program or a work on which the Program is based. The
+work thus licensed is called the contributor's "contributor version".
+
+ A contributor's "essential patent claims" are all patent claims
+owned or controlled by the contributor, whether already acquired or
+hereafter acquired, that would be infringed by some manner, permitted
+by this License, of making, using, or selling its contributor version,
+but do not include claims that would be infringed only as a
+consequence of further modification of the contributor version. For
+purposes of this definition, "control" includes the right to grant
+patent sublicenses in a manner consistent with the requirements of
+this License.
+
+ Each contributor grants you a non-exclusive, worldwide, royalty-free
+patent license under the contributor's essential patent claims, to
+make, use, sell, offer for sale, import and otherwise run, modify and
+propagate the contents of its contributor version.
+
+ In the following three paragraphs, a "patent license" is any express
+agreement or commitment, however denominated, not to enforce a patent
+(such as an express permission to practice a patent or covenant not to
+sue for patent infringement). To "grant" such a patent license to a
+party means to make such an agreement or commitment not to enforce a
+patent against the party.
+
+ If you convey a covered work, knowingly relying on a patent license,
+and the Corresponding Source of the work is not available for anyone
+to copy, free of charge and under the terms of this License, through a
+publicly available network server or other readily accessible means,
+then you must either (1) cause the Corresponding Source to be so
+available, or (2) arrange to deprive yourself of the benefit of the
+patent license for this particular work, or (3) arrange, in a manner
+consistent with the requirements of this License, to extend the patent
+license to downstream recipients. "Knowingly relying" means you have
+actual knowledge that, but for the patent license, your conveying the
+covered work in a country, or your recipient's use of the covered work
+in a country, would infringe one or more identifiable patents in that
+country that you have reason to believe are valid.
+
+ If, pursuant to or in connection with a single transaction or
+arrangement, you convey, or propagate by procuring conveyance of, a
+covered work, and grant a patent license to some of the parties
+receiving the covered work authorizing them to use, propagate, modify
+or convey a specific copy of the covered work, then the patent license
+you grant is automatically extended to all recipients of the covered
+work and works based on it.
+
+ A patent license is "discriminatory" if it does not include within
+the scope of its coverage, prohibits the exercise of, or is
+conditioned on the non-exercise of one or more of the rights that are
+specifically granted under this License. You may not convey a covered
+work if you are a party to an arrangement with a third party that is
+in the business of distributing software, under which you make payment
+to the third party based on the extent of your activity of conveying
+the work, and under which the third party grants, to any of the
+parties who would receive the covered work from you, a discriminatory
+patent license (a) in connection with copies of the covered work
+conveyed by you (or copies made from those copies), or (b) primarily
+for and in connection with specific products or compilations that
+contain the covered work, unless you entered into that arrangement,
+or that patent license was granted, prior to 28 March 2007.
+
+ Nothing in this License shall be construed as excluding or limiting
+any implied license or other defenses to infringement that may
+otherwise be available to you under applicable patent law.
+
+ 12. No Surrender of Others' Freedom.
+
+ If conditions are imposed on you (whether by court order, agreement or
+otherwise) that contradict the conditions of this License, they do not
+excuse you from the conditions of this License. If you cannot convey a
+covered work so as to satisfy simultaneously your obligations under this
+License and any other pertinent obligations, then as a consequence you may
+not convey it at all. For example, if you agree to terms that obligate you
+to collect a royalty for further conveying from those to whom you convey
+the Program, the only way you could satisfy both those terms and this
+License would be to refrain entirely from conveying the Program.
+
+ 13. Use with the GNU Affero General Public License.
+
+ Notwithstanding any other provision of this License, you have
+permission to link or combine any covered work with a work licensed
+under version 3 of the GNU Affero General Public License into a single
+combined work, and to convey the resulting work. The terms of this
+License will continue to apply to the part which is the covered work,
+but the special requirements of the GNU Affero General Public License,
+section 13, concerning interaction through a network will apply to the
+combination as such.
+
+ 14. Revised Versions of this License.
+
+ The Free Software Foundation may publish revised and/or new versions of
+the GNU General Public License from time to time. Such new versions will
+be similar in spirit to the present version, but may differ in detail to
+address new problems or concerns.
+
+ Each version is given a distinguishing version number. If the
+Program specifies that a certain numbered version of the GNU General
+Public License "or any later version" applies to it, you have the
+option of following the terms and conditions either of that numbered
+version or of any later version published by the Free Software
+Foundation. If the Program does not specify a version number of the
+GNU General Public License, you may choose any version ever published
+by the Free Software Foundation.
+
+ If the Program specifies that a proxy can decide which future
+versions of the GNU General Public License can be used, that proxy's
+public statement of acceptance of a version permanently authorizes you
+to choose that version for the Program.
+
+ Later license versions may give you additional or different
+permissions. However, no additional obligations are imposed on any
+author or copyright holder as a result of your choosing to follow a
+later version.
+
+ 15. Disclaimer of Warranty.
+
+ THERE IS NO WARRANTY FOR THE PROGRAM, TO THE EXTENT PERMITTED BY
+APPLICABLE LAW. EXCEPT WHEN OTHERWISE STATED IN WRITING THE COPYRIGHT
+HOLDERS AND/OR OTHER PARTIES PROVIDE THE PROGRAM "AS IS" WITHOUT WARRANTY
+OF ANY KIND, EITHER EXPRESSED OR IMPLIED, INCLUDING, BUT NOT LIMITED TO,
+THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR
+PURPOSE. THE ENTIRE RISK AS TO THE QUALITY AND PERFORMANCE OF THE PROGRAM
+IS WITH YOU. SHOULD THE PROGRAM PROVE DEFECTIVE, YOU ASSUME THE COST OF
+ALL NECESSARY SERVICING, REPAIR OR CORRECTION.
+
+ 16. Limitation of Liability.
+
+ IN NO EVENT UNLESS REQUIRED BY APPLICABLE LAW OR AGREED TO IN WRITING
+WILL ANY COPYRIGHT HOLDER, OR ANY OTHER PARTY WHO MODIFIES AND/OR CONVEYS
+THE PROGRAM AS PERMITTED ABOVE, BE LIABLE TO YOU FOR DAMAGES, INCLUDING ANY
+GENERAL, SPECIAL, INCIDENTAL OR CONSEQUENTIAL DAMAGES ARISING OUT OF THE
+USE OR INABILITY TO USE THE PROGRAM (INCLUDING BUT NOT LIMITED TO LOSS OF
+DATA OR DATA BEING RENDERED INACCURATE OR LOSSES SUSTAINED BY YOU OR THIRD
+PARTIES OR A FAILURE OF THE PROGRAM TO OPERATE WITH ANY OTHER PROGRAMS),
+EVEN IF SUCH HOLDER OR OTHER PARTY HAS BEEN ADVISED OF THE POSSIBILITY OF
+SUCH DAMAGES.
+
+ 17. Interpretation of Sections 15 and 16.
+
+ If the disclaimer of warranty and limitation of liability provided
+above cannot be given local legal effect according to their terms,
+reviewing courts shall apply local law that most closely approximates
+an absolute waiver of all civil liability in connection with the
+Program, unless a warranty or assumption of liability accompanies a
+copy of the Program in return for a fee.
+
+ END OF TERMS AND CONDITIONS
+
+ How to Apply These Terms to Your New Programs
+
+ If you develop a new program, and you want it to be of the greatest
+possible use to the public, the best way to achieve this is to make it
+free software which everyone can redistribute and change under these terms.
+
+ To do so, attach the following notices to the program. It is safest
+to attach them to the start of each source file to most effectively
+state the exclusion of warranty; and each file should have at least
+the "copyright" line and a pointer to where the full notice is found.
+
+
+ Copyright (C)
+
+ This program is free software: you can redistribute it and/or modify
+ it under the terms of the GNU General Public License as published by
+ the Free Software Foundation, either version 3 of the License, or
+ (at your option) any later version.
+
+ This program is distributed in the hope that it will be useful,
+ but WITHOUT ANY WARRANTY; without even the implied warranty of
+ MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+ GNU General Public License for more details.
+
+ You should have received a copy of the GNU General Public License
+ along with this program. If not, see .
+
+Also add information on how to contact you by electronic and paper mail.
+
+ If the program does terminal interaction, make it output a short
+notice like this when it starts in an interactive mode:
+
+ Copyright (C)
+ This program comes with ABSOLUTELY NO WARRANTY; for details type `show w'.
+ This is free software, and you are welcome to redistribute it
+ under certain conditions; type `show c' for details.
+
+The hypothetical commands `show w' and `show c' should show the appropriate
+parts of the General Public License. Of course, your program's commands
+might be different; for a GUI interface, you would use an "about box".
+
+ You should also get your employer (if you work as a programmer) or school,
+if any, to sign a "copyright disclaimer" for the program, if necessary.
+For more information on this, and how to apply and follow the GNU GPL, see
+.
+
+ The GNU General Public License does not permit incorporating your program
+into proprietary programs. If your program is a subroutine library, you
+may consider it more useful to permit linking proprietary applications with
+the library. If this is what you want to do, use the GNU Lesser General
+Public License instead of this License. But first, please read
+.
diff --git a/custom_nodes/comfyui-advanced-controlnet/README.md b/custom_nodes/comfyui-advanced-controlnet/README.md
new file mode 100644
index 0000000000000000000000000000000000000000..b44d9126cb8bacb32f72bdd13051069883af68e8
--- /dev/null
+++ b/custom_nodes/comfyui-advanced-controlnet/README.md
@@ -0,0 +1,205 @@
+# ComfyUI-Advanced-ControlNet
+Nodes for scheduling ControlNet strength across timesteps and batched latents, as well as applying custom weights and attention masks. The ControlNet nodes here fully support sliding context sampling, like the one used in the [ComfyUI-AnimateDiff-Evolved](https://github.com/Kosinkadink/ComfyUI-AnimateDiff-Evolved) nodes. Currently supports ControlNets, T2IAdapters, ControlLoRAs, ControlLLLite, SparseCtrls, SVD-ControlNets, and Reference.
+
+Custom weights allow replication of the "My prompt is more important" feature of Auto1111's sd-webui ControlNet extension via Soft Weights, and the "ControlNet is more important" feature can be granularly controlled by changing the uncond_multiplier on the same Soft Weights.
+
+ControlNet preprocessors are available through [comfyui_controlnet_aux](https://github.com/Fannovel16/comfyui_controlnet_aux) nodes.
+
+## Features
+- Timestep and latent strength scheduling
+- Attention masks
+- Replicate ***"My prompt is more important"*** feature from sd-webui-controlnet extension via ***Soft Weights***, and allow softness to be tweaked via ***base_multiplier***
+- Replicate ***"ControlNet is more important"*** feature from sd-webui-controlnet extension via ***uncond_multiplier*** on ***Soft Weights***
+ - uncond_multiplier=0.0 gives identical results of auto1111's feature, but values between 0.0 and 1.0 can be used without issue to granularly control the setting.
+- ControlNet, T2IAdapter, and ControlLoRA support for sliding context windows
+- ControlLLLite support
+- ControlNet++ support
+- CtrLoRA support
+ - Relevant models linked on [CtrLoRA github page](https://github.com/xyfJASON/ctrlora)
+- SparseCtrl support
+- SVD-ControlNet support
+ - Stable Video Diffusion ControlNets trained by **CiaraRowles**: [Depth](https://huggingface.co/CiaraRowles/temporal-controlnet-depth-svd-v1/tree/main/controlnet), [Lineart](https://huggingface.co/CiaraRowles/temporal-controlnet-lineart-svd-v1/tree/main/controlnet)
+- Reference support
+ - Supports ```reference_attn```, ```reference_adain```, and ```refrence_adain+attn``` modes. ```style_fidelity``` and ```ref_weight``` are equivalent to style_fidelity and control_weight in Auto1111, respectively, and strength of the Apply ControlNet is the balance between ref-influenced result and no-ref result. There is also a Reference ControlNet (Finetune) node that allows adjust the style_fidelity, weight, and strength of attn and adain separately.
+
+## Table of Contents:
+- [Scheduling Explanation](#scheduling-explanation)
+- [Nodes](#nodes)
+- [Usage](#usage) (will fill this out soon)
+
+
+# Scheduling Explanation
+
+The two core concepts for scheduling are ***Timestep Keyframes*** and ***Latent Keyframes***.
+
+***Timestep Keyframes*** hold the values that guide the settings for a controlnet, and begin to take effect based on their start_percent, which corresponds to the percentage of the sampling process. They can contain masks for the strengths of each latent, control_net_weights, and latent_keyframes (specific strengths for each latent), all optional.
+
+***Latent Keyframes*** determine the strength of the controlnet for specific latents - all they contain is the batch_index of the latent, and the strength the controlnet should apply for that latent. As a concept, latent keyframes achieve the same affect as a uniform mask with the chosen strength value.
+
+
+
+# Nodes
+
+The ControlNet nodes provided here are the ***Apply Advanced ControlNet*** and ***Load Advanced ControlNet Model*** (or diff) nodes. The vanilla ControlNet nodes are also compatible, and can be used almost interchangeably - the only difference is that **at least one of these nodes must be used** for Advanced versions of ControlNets to be used (important for sliding context sampling, like with AnimateDiff-Evolved).
+
+Key:
+- 🟩 - required inputs
+- 🟨 - optional inputs
+- 🟦 - start as widgets, can be converted to inputs
+- 🟥 - optional input/output, but not recommended to use unless needed
+- 🟪 - output
+
+## Apply Advanced ControlNet
+
+
+Same functionality as the vanilla Apply Advanced ControlNet (Advanced) node, except with Advanced ControlNet features added to it. Automatically converts any ControlNet from ControlNet loaders into Advanced versions.
+
+### Inputs
+- 🟩***positive***: conditioning (positive).
+- 🟩***negative***: conditioning (negative).
+- 🟩***control_net***: loaded controlnet; will be converted to Advanced version automatically by this node, if it's a supported type.
+- 🟩***image***: images to guide controlnets - if the loaded controlnet requires it, they must preprocessed images. If one image provided, will be used for all latents. If more images provided, will use each image separately for each latent. If not enough images to meet latent count, will repeat the images from the beginning to match vanilla ControlNet functionality.
+- 🟨***mask_optional***: attention masks to apply to controlnets; basically, decides what part of the image the controlnet to apply to (and the relative strength, if the mask is not binary). Same as image input, if you provide more than one mask, each can apply to a different latent.
+- 🟨***timestep_kf***: timestep keyframes to guide controlnet effect throughout sampling steps.
+- 🟨***latent_kf_override***: override for latent keyframes, useful if no other features from timestep keyframes is needed. *NOTE: this latent keyframe will be applied to ALL timesteps, regardless if there are other latent keyframes attached to connected timestep keyframes.*
+- 🟨***weights_override***: override for weights, useful if no other features from timestep keyframes is needed. *NOTE: this weight will be applied to ALL timesteps, regardless if there are other weights attached to connected timestep keyframes.*
+- 🟦***strength***: strength of controlnet; 1.0 is full strength, 0.0 is no effect at all.
+- 🟦***start_percent***: sampling step percentage at which controlnet should start to be applied - no matter what start_percent is set on timestep keyframes, they won't take effect until this start_percent is reached.
+- 🟦***stop_percent***: sampling step percentage at which controlnet should stop being applied - no matter what start_percent is set on timestep keyframes, they won't take effect once this end_percent is reached.
+
+### Outputs
+- 🟪***positive***: conditioning (positive) with applied controlnets
+- 🟪***negative***: conditioning (negative) with applied controlnets
+
+## Load Advanced ControlNet Model
+
+
+Loads a ControlNet model and converts it into an Advanced version that supports all the features in this repo. When used with **Apply Advanced ControlNet** node, there is no reason to use the timestep_keyframe input on this node - use timestep_kf on the Apply node instead.
+
+### Inputs
+- 🟥***timestep_keyframe***: optional and likely unnecessary input to have ControlNet use selected timestep_keyframes - should not be used unless you need to. Useful if this node is not attached to **Apply Advanced ControlNet** node, but still want to use Timestep Keyframe, or to use TK_SHORTCUT outputs from ControlWeights in the same scenario. Will be overriden by the timestep_kf input on **Apply Advanced ControlNet** node, if one is provided there.
+- 🟨***model***: model to plug into the diff version of the node. Some controlnets are designed for receive the model; if you don't know what this does, you probably don't want tot use the diff version of the node.
+
+### Outputs
+- 🟪***CONTROL_NET***: loaded Advanced ControlNet
+
+## Timestep Keyframe
+
+
+Scheduling node across timesteps (sampling steps) based on the set start_percent. Chaining Timestep Keyframes allows ControlNet scheduling across sampling steps (percentage-wise), through a timestep keyframe schedule.
+
+### Inputs
+- 🟨***prev_timestep_kf***: used to chain Timestep Keyframes together to create a schedule. The order does not matter - the Timestep Keyframes sort themselves automatically by their start_percent. *Any Timestep Keyframe contained in the prev_timestep_keyframe that contains the same start_percent as the Timestep Keyframe will be overwritten.*
+- 🟨***cn_weights***: weights to apply to controlnet while this Timestep Keyframe is in effect. Must be compatible with the loaded controlnet, or will throw an error explaining what weight types are compatible. If inherit_missing is True, if no control_net_weight is passed in, will attempt to reuse the last-used weights in the timestep keyframe schedule. *If Apply Advanced ControlNet node has a weight_override, the weight_override will be used during sampling instead of control_net_weight.*
+- 🟨***latent_keyframe***: latent keyframes to apply to controlnet while this Timestep Keyframe is in effect. If inherit_missing is True, if no latent_keyframe is passed in, will attempt to reuse the last-used weights in the timestep keyframe schedule. *If Apply Advanced ControlNet node has a latent_kf_override, the latent_lf_override will be used during sampling instead of latent_keyframe.*
+- 🟨***mask_optional***: attention masks to apply to controlnets; basically, decides what part of the image the controlnet to apply to (and the relative strength, if the mask is not binary). Same as mask_optional on the Apply Advanced ControlNet node, can apply either one maks to all latents, or individual masks for each latent. If inherit_missing is True, if no mask_optional is passed in, will attempt to reuse the last-used mask_optional in the timestep keyframe schedule. It is NOT overriden by mask_optional on the Apply Advanced ControlNet node; will be used together.
+- 🟦***start_percent***: sampling step percentage at which this Timestep Keyframe qualifies to be used. Acts as the 'key' for the Timestep Keyframe in the timestep keyframe schedule.
+- 🟦***strength***: strength of the controlnet; multiplies the controlnet by this value, basically, applied alongside the strength on the Apply ControlNet node. If set to 0.0 will not have any effect during the duration of this Timestep Keyframe's effect, and will increase sampling speed by not doing any work.
+- 🟦***null_latent_kf_strength***: strength to assign to latents that are unaccounted for in the passed in latent_keyframes. Has no effect if no latent_keyframes are passed in, or no batch_indeces are unaccounted in the latent_keyframes for during sampling.
+- 🟦***inherit_missing***: determines if should reuse values from previous Timestep Keyframes for optional values (control_net_weights, latent_keyframe, and mask_option) that are not included on this TimestepKeyframe. To inherit only specific inputs, use default inputs.
+- 🟦***guarantee_steps***: when 1 or greater, even if a Timestep Keyframe's start_percent ahead of this one in the schedule is closer to current sampling percentage, this Timestep Keyframe will still be used for the specified amount of steps before moving on to the next selected Timestep Keyframe in the following step. Whether the Timestep Keyframe is used or not, its inputs will still be accounted for inherit_missing purposes.
+
+### Outputs
+- 🟪***TIMESTEP_KF***: the created Timestep Keyframe, that can either be linked to another or into a Timestep Keyframe input.
+
+## Timestep Keyframe Interpolation
+
+
+Allows to create Timestep Keyframe with interpolated strength values in a given percent range. (The first generated keyframe will have guarantee_steps=1, rest that follow will have guarantee_steps=0).
+
+### Inputs
+- 🟨***prev_timestep_kf***: used to chain Timestep Keyframes together to create a schedule. The order does not matter - the Timestep Keyframes sort themselves automatically by their start_percent. *Any Timestep Keyframe contained in the prev_timestep_keyframe that contains the same start_percent as the Timestep Keyframe will be overwritten.*
+- 🟨***cn_weights***: weights to apply to controlnet while this Timestep Keyframe is in effect. Must be compatible with the loaded controlnet, or will throw an error explaining what weight types are compatible. If inherit_missing is True, if no control_net_weight is passed in, will attempt to reuse the last-used weights in the timestep keyframe schedule. *If Apply Advanced ControlNet node has a weight_override, the weight_override will be used during sampling instead of control_net_weight.*
+- 🟨***latent_keyframe***: latent keyframes to apply to controlnet while this Timestep Keyframe is in effect. If inherit_missing is True, if no latent_keyframe is passed in, will attempt to reuse the last-used weights in the timestep keyframe schedule. *If Apply Advanced ControlNet node has a latent_kf_override, the latent_lf_override will be used during sampling instead of latent_keyframe.*
+- 🟨***mask_optional***: attention masks to apply to controlnets; basically, decides what part of the image the controlnet to apply to (and the relative strength, if the mask is not binary). Same as mask_optional on the Apply Advanced ControlNet node, can apply either one maks to all latents, or individual masks for each latent. If inherit_missing is True, if no mask_optional is passed in, will attempt to reuse the last-used mask_optional in the timestep keyframe schedule. It is NOT overriden by mask_optional on the Apply Advanced ControlNet node; will be used together.
+- 🟦***start_percent***: sampling step percentage at which the first generated Timestep Keyframe qualifies to be used.
+- 🟦***end_percent***: sampling step percentage at which the last generated Timestep Keyframe qualifies to be used.
+- 🟦***strength_start***: strength of the Timestep Keyframe at start of range.
+- 🟦***strength_end***: strength of the Timestep Keyframe at end of range.
+- 🟦***interpolation***: the method of interpolation.
+- 🟦***intervals***: the amount of keyframes to generate in total - the first will have its start_percent equal to start_percent, the last will have its start_percent equal to end_percent.
+- 🟦***null_latent_kf_strength***: strength to assign to latents that are unaccounted for in the passed in latent_keyframes. Has no effect if no latent_keyframes are passed in, or no batch_indeces are unaccounted in the latent_keyframes for during sampling.
+- 🟦***inherit_missing***: determines if should reuse values from previous Timestep Keyframes for optional values (control_net_weights, latent_keyframe, and mask_option) that are not included on this TimestepKeyframe. To inherit only specific inputs, use default inputs.
+- 🟦***print_keyframes***: if True, will print the Timestep Keyframes generated by this node for debugging purposes.
+
+### Outputs
+- 🟪***TIMESTEP_KF***: the created Timestep Keyframe, that can either be linked to another or into a Timestep Keyframe input.
+
+## Timestep Keyframe From List
+
+
+Allows to create Timestep Keyframe via a list of floats, such as with Batch Value Schedule from [ComfyUI_FizzNodes](https://github.com/FizzleDorf/ComfyUI_FizzNodes) nodes. (The first generated keyframe will have guarantee_steps=1, rest that follow will have guarantee_steps=0).
+
+### Inputs
+- 🟨***prev_timestep_kf***: used to chain Timestep Keyframes together to create a schedule. The order does not matter - the Timestep Keyframes sort themselves automatically by their start_percent. *Any Timestep Keyframe contained in the prev_timestep_keyframe that contains the same start_percent as the Timestep Keyframe will be overwritten.*
+- 🟨***cn_weights***: weights to apply to controlnet while this Timestep Keyframe is in effect. Must be compatible with the loaded controlnet, or will throw an error explaining what weight types are compatible. If inherit_missing is True, if no control_net_weight is passed in, will attempt to reuse the last-used weights in the timestep keyframe schedule. *If Apply Advanced ControlNet node has a weight_override, the weight_override will be used during sampling instead of control_net_weight.*
+- 🟨***latent_keyframe***: latent keyframes to apply to controlnet while this Timestep Keyframe is in effect. If inherit_missing is True, if no latent_keyframe is passed in, will attempt to reuse the last-used weights in the timestep keyframe schedule. *If Apply Advanced ControlNet node has a latent_kf_override, the latent_lf_override will be used during sampling instead of latent_keyframe.*
+- 🟨***mask_optional***: attention masks to apply to controlnets; basically, decides what part of the image the controlnet to apply to (and the relative strength, if the mask is not binary). Same as mask_optional on the Apply Advanced ControlNet node, can apply either one maks to all latents, or individual masks for each latent. If inherit_missing is True, if no mask_optional is passed in, will attempt to reuse the last-used mask_optional in the timestep keyframe schedule. It is NOT overriden by mask_optional on the Apply Advanced ControlNet node; will be used together.
+- 🟩***float_strengths***: a list of floats, that will correspond to the strength of each Timestep Keyframe; first will be assigned to start_percent, last will be assigned to end_percent, and the rest spread linearly between.
+- 🟦***start_percent***: sampling step percentage at which the first generated Timestep Keyframe qualifies to be used.
+- 🟦***end_percent***: sampling step percentage at which the last generated Timestep Keyframe qualifies to be used.
+- 🟦***null_latent_kf_strength***: strength to assign to latents that are unaccounted for in the passed in latent_keyframes. Has no effect if no latent_keyframes are passed in, or no batch_indeces are unaccounted in the latent_keyframes for during sampling.
+- 🟦***inherit_missing***: determines if should reuse values from previous Timestep Keyframes for optional values (control_net_weights, latent_keyframe, and mask_option) that are not included on this TimestepKeyframe. To inherit only specific inputs, use default inputs.
+- 🟦***print_keyframes***: if True, will print the Timestep Keyframes generated by this node for debugging purposes.
+
+### Outputs
+- 🟪***TIMESTEP_KF***: the created Timestep Keyframe, that can either be linked to another or into a Timestep Keyframe input.
+
+## Latent Keyframe
+
+
+A singular Latent Keyframe, selects the strength for a specific batch_index. If batch_index is not present during sampling, will simply have no effect. Can be chained with any other Latent Keyframe-type node to create a latent keyframe schedule.
+
+### Inputs
+- 🟨***prev_latent_kf***: used to chain Latent Keyframes together to create a schedule. *If a Latent Keyframe contained in prev_latent_keyframes have the same batch_index as this Latent Keyframe, they will take priority over this node's value.*
+- 🟦***batch_index***: index of latent in batch to apply controlnet strength to. Acts as the 'key' for the Latent Keyframe in the latent keyframe schedule.
+- 🟦***strength***: strength of controlnet to apply to the corresponding latent.
+
+### Outputs
+- 🟪***LATENT_KF***: the created Latent Keyframe, that can either be linked to another or into a Latent Keyframe input.
+
+## Latent Keyframe Group
+
+
+Allows to create Latent Keyframes via individual indeces or python-style ranges.
+
+### Inputs
+- 🟨***prev_latent_kf***: used to chain Latent Keyframes together to create a schedule. *If any Latent Keyframes contained in prev_latent_keyframes have the same batch_index as a this Latent Keyframe, they will take priority over this node's version.*
+- 🟨***latent_optional***: the latents expected to be passed in for sampling; only required if you wish to use negative indeces (will be automatically converted to real values).
+- 🟦***index_strengths***: string list of indeces or python-style ranges of indeces to assign strengths to. If latent_optional is passed in, can contain negative indeces or ranges that contain negative numbers, python-style. The different indeces must be comma separated. Individual latents can be specified by ```batch_index=strength```, like ```0=0.9```. Ranges can be specified by ```start_index_inclusive:end_index_exclusive=strength```, like ```0:8=strength```. Negative indeces are possible when latents_optional has an input, with a string such as ```0,-4=0.25```.
+- 🟦***print_keyframes***: if True, will print the Latent Keyframes generated by this node for debugging purposes.
+
+### Outputs
+- 🟪***LATENT_KF***: the created Latent Keyframe, that can either be linked to another or into a Latent Keyframe input.
+
+## Latent Keyframe Interpolation
+
+
+Allows to create Latent Keyframes with interpolated values in a range.
+
+### Inputs
+- 🟨***prev_latent_kf***: used to chain Latent Keyframes together to create a schedule. *If any Latent Keyframes contained in prev_latent_keyframes have the same batch_index as a this Latent Keyframe, they will take priority over this node's version.*
+- 🟦***batch_index_from***: starting batch_index of range, included.
+- 🟦***batch_index_to***: end batch_index of range, excluded (python-style range).
+- 🟦***strength_from***: starting strength of interpolation.
+- 🟦***strength_to***: end strength of interpolation.
+- 🟦***interpolation***: the method of interpolation.
+- 🟦***print_keyframes***: if True, will print the Latent Keyframes generated by this node for debugging purposes.
+
+### Outputs
+- 🟪***LATENT_KF***: the created Latent Keyframe, that can either be linked to another or into a Latent Keyframe input.
+
+## Latent Keyframe From List
+
+
+Allows to create Latent Keyframes via a list of floats, such as with Batch Value Schedule from [ComfyUI_FizzNodes](https://github.com/FizzleDorf/ComfyUI_FizzNodes) nodes.
+
+### Inputs
+- 🟨***prev_latent_kf***: used to chain Latent Keyframes together to create a schedule. *If any Latent Keyframes contained in prev_latent_keyframes have the same batch_index as a this Latent Keyframe, they will take priority over this node's version.*
+- 🟩***float_strengths***: a list of floats, that will correspond to the strength of each Latent Keyframe; the batch_index is the index of each float value in the list.
+- 🟦***print_keyframes***: if True, will print the Latent Keyframes generated by this node for debugging purposes.
+
+### Outputs
+- 🟪***LATENT_KF***: the created Latent Keyframe, that can either be linked to another or into a Latent Keyframe input.
+
+# There are more nodes to document and show usage - will add this soon! TODO
diff --git a/custom_nodes/comfyui-advanced-controlnet/__init__.py b/custom_nodes/comfyui-advanced-controlnet/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..bb650edf1138b066323163d27898b982ebdb12ce
--- /dev/null
+++ b/custom_nodes/comfyui-advanced-controlnet/__init__.py
@@ -0,0 +1,11 @@
+from .adv_control.nodes import NODE_CLASS_MAPPINGS, NODE_DISPLAY_NAME_MAPPINGS
+from .adv_control import documentation
+from .adv_control.dinklink import init_dinklink
+from .adv_control.sampling import prepare_dinklink_acn_wrapper
+
+WEB_DIRECTORY = "./web"
+__all__ = ['NODE_CLASS_MAPPINGS', 'NODE_DISPLAY_NAME_MAPPINGS', "WEB_DIRECTORY"]
+documentation.format_descriptions(NODE_CLASS_MAPPINGS)
+
+init_dinklink()
+prepare_dinklink_acn_wrapper()
diff --git a/custom_nodes/comfyui-advanced-controlnet/__pycache__/__init__.cpython-312.pyc b/custom_nodes/comfyui-advanced-controlnet/__pycache__/__init__.cpython-312.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..324be2dec9a874b639743fceb85f60572f7424c2
Binary files /dev/null and b/custom_nodes/comfyui-advanced-controlnet/__pycache__/__init__.cpython-312.pyc differ
diff --git a/custom_nodes/comfyui-advanced-controlnet/adv_control/__pycache__/control.cpython-312.pyc b/custom_nodes/comfyui-advanced-controlnet/adv_control/__pycache__/control.cpython-312.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..f8abc40eaf9f5a48a8cee13d4e469da0803aaadf
Binary files /dev/null and b/custom_nodes/comfyui-advanced-controlnet/adv_control/__pycache__/control.cpython-312.pyc differ
diff --git a/custom_nodes/comfyui-advanced-controlnet/adv_control/__pycache__/control_ctrlora.cpython-312.pyc b/custom_nodes/comfyui-advanced-controlnet/adv_control/__pycache__/control_ctrlora.cpython-312.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..7ee20ada9c98e4f46905e494b9f1943237675168
Binary files /dev/null and b/custom_nodes/comfyui-advanced-controlnet/adv_control/__pycache__/control_ctrlora.cpython-312.pyc differ
diff --git a/custom_nodes/comfyui-advanced-controlnet/adv_control/__pycache__/control_lllite.cpython-312.pyc b/custom_nodes/comfyui-advanced-controlnet/adv_control/__pycache__/control_lllite.cpython-312.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..86cd2b9ebf5302bd96a01b9a2cdd52b67b3052b7
Binary files /dev/null and b/custom_nodes/comfyui-advanced-controlnet/adv_control/__pycache__/control_lllite.cpython-312.pyc differ
diff --git a/custom_nodes/comfyui-advanced-controlnet/adv_control/__pycache__/control_plusplus.cpython-312.pyc b/custom_nodes/comfyui-advanced-controlnet/adv_control/__pycache__/control_plusplus.cpython-312.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..90aaee70794e5dcdd1389ba73517b9255bf48684
Binary files /dev/null and b/custom_nodes/comfyui-advanced-controlnet/adv_control/__pycache__/control_plusplus.cpython-312.pyc differ
diff --git a/custom_nodes/comfyui-advanced-controlnet/adv_control/__pycache__/control_reference.cpython-312.pyc b/custom_nodes/comfyui-advanced-controlnet/adv_control/__pycache__/control_reference.cpython-312.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..ef836265f18d955653550c54c3fa0968ca09bd2b
Binary files /dev/null and b/custom_nodes/comfyui-advanced-controlnet/adv_control/__pycache__/control_reference.cpython-312.pyc differ
diff --git a/custom_nodes/comfyui-advanced-controlnet/adv_control/__pycache__/control_sparsectrl.cpython-312.pyc b/custom_nodes/comfyui-advanced-controlnet/adv_control/__pycache__/control_sparsectrl.cpython-312.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..94a40c20a2c32e6070c663af5f503c6fd8ce860c
Binary files /dev/null and b/custom_nodes/comfyui-advanced-controlnet/adv_control/__pycache__/control_sparsectrl.cpython-312.pyc differ
diff --git a/custom_nodes/comfyui-advanced-controlnet/adv_control/__pycache__/control_svd.cpython-312.pyc b/custom_nodes/comfyui-advanced-controlnet/adv_control/__pycache__/control_svd.cpython-312.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..1eea8ec6aca93a098f69ef67f9fec612bcae390b
Binary files /dev/null and b/custom_nodes/comfyui-advanced-controlnet/adv_control/__pycache__/control_svd.cpython-312.pyc differ
diff --git a/custom_nodes/comfyui-advanced-controlnet/adv_control/__pycache__/dinklink.cpython-312.pyc b/custom_nodes/comfyui-advanced-controlnet/adv_control/__pycache__/dinklink.cpython-312.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..877d780910335690c8654cb5e732cd9065df1aed
Binary files /dev/null and b/custom_nodes/comfyui-advanced-controlnet/adv_control/__pycache__/dinklink.cpython-312.pyc differ
diff --git a/custom_nodes/comfyui-advanced-controlnet/adv_control/__pycache__/documentation.cpython-312.pyc b/custom_nodes/comfyui-advanced-controlnet/adv_control/__pycache__/documentation.cpython-312.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..6b4ab37f34a06f6a0558ae0febaf21c73c47cfd7
Binary files /dev/null and b/custom_nodes/comfyui-advanced-controlnet/adv_control/__pycache__/documentation.cpython-312.pyc differ
diff --git a/custom_nodes/comfyui-advanced-controlnet/adv_control/__pycache__/logger.cpython-312.pyc b/custom_nodes/comfyui-advanced-controlnet/adv_control/__pycache__/logger.cpython-312.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..bf40d3c767d488799cb498a8863f9f8a54bd795b
Binary files /dev/null and b/custom_nodes/comfyui-advanced-controlnet/adv_control/__pycache__/logger.cpython-312.pyc differ
diff --git a/custom_nodes/comfyui-advanced-controlnet/adv_control/__pycache__/nodes.cpython-312.pyc b/custom_nodes/comfyui-advanced-controlnet/adv_control/__pycache__/nodes.cpython-312.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..abd645adaaf42a425f7809d26ca47fd899039b77
Binary files /dev/null and b/custom_nodes/comfyui-advanced-controlnet/adv_control/__pycache__/nodes.cpython-312.pyc differ
diff --git a/custom_nodes/comfyui-advanced-controlnet/adv_control/__pycache__/nodes_ctrlora.cpython-312.pyc b/custom_nodes/comfyui-advanced-controlnet/adv_control/__pycache__/nodes_ctrlora.cpython-312.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..85e1809452dda5bdf10d307488be7bb1c15a8137
Binary files /dev/null and b/custom_nodes/comfyui-advanced-controlnet/adv_control/__pycache__/nodes_ctrlora.cpython-312.pyc differ
diff --git a/custom_nodes/comfyui-advanced-controlnet/adv_control/__pycache__/nodes_deprecated.cpython-312.pyc b/custom_nodes/comfyui-advanced-controlnet/adv_control/__pycache__/nodes_deprecated.cpython-312.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..191fde872b0d157d99176e4cbcf654a17892413f
Binary files /dev/null and b/custom_nodes/comfyui-advanced-controlnet/adv_control/__pycache__/nodes_deprecated.cpython-312.pyc differ
diff --git a/custom_nodes/comfyui-advanced-controlnet/adv_control/__pycache__/nodes_keyframes.cpython-312.pyc b/custom_nodes/comfyui-advanced-controlnet/adv_control/__pycache__/nodes_keyframes.cpython-312.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..ca6c0e9387c4b05e1fdc15e8b4bd28956703783f
Binary files /dev/null and b/custom_nodes/comfyui-advanced-controlnet/adv_control/__pycache__/nodes_keyframes.cpython-312.pyc differ
diff --git a/custom_nodes/comfyui-advanced-controlnet/adv_control/__pycache__/nodes_loosecontrol.cpython-312.pyc b/custom_nodes/comfyui-advanced-controlnet/adv_control/__pycache__/nodes_loosecontrol.cpython-312.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..2a41eb2e66d775bab7e1beff97e06be32c615ead
Binary files /dev/null and b/custom_nodes/comfyui-advanced-controlnet/adv_control/__pycache__/nodes_loosecontrol.cpython-312.pyc differ
diff --git a/custom_nodes/comfyui-advanced-controlnet/adv_control/__pycache__/nodes_main.cpython-312.pyc b/custom_nodes/comfyui-advanced-controlnet/adv_control/__pycache__/nodes_main.cpython-312.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..c3dea10df75db2931ca9cfed572b694a3b3e5b87
Binary files /dev/null and b/custom_nodes/comfyui-advanced-controlnet/adv_control/__pycache__/nodes_main.cpython-312.pyc differ
diff --git a/custom_nodes/comfyui-advanced-controlnet/adv_control/__pycache__/nodes_plusplus.cpython-312.pyc b/custom_nodes/comfyui-advanced-controlnet/adv_control/__pycache__/nodes_plusplus.cpython-312.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..167aa189d9592218380d6c7b9b93dc83901081e7
Binary files /dev/null and b/custom_nodes/comfyui-advanced-controlnet/adv_control/__pycache__/nodes_plusplus.cpython-312.pyc differ
diff --git a/custom_nodes/comfyui-advanced-controlnet/adv_control/__pycache__/nodes_reference.cpython-312.pyc b/custom_nodes/comfyui-advanced-controlnet/adv_control/__pycache__/nodes_reference.cpython-312.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..2c4af55eeb1294762ec929ddcf038a1e206401d4
Binary files /dev/null and b/custom_nodes/comfyui-advanced-controlnet/adv_control/__pycache__/nodes_reference.cpython-312.pyc differ
diff --git a/custom_nodes/comfyui-advanced-controlnet/adv_control/__pycache__/nodes_sparsectrl.cpython-312.pyc b/custom_nodes/comfyui-advanced-controlnet/adv_control/__pycache__/nodes_sparsectrl.cpython-312.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..b93e1f5ca672b0f3d892d9625bc228cab9c37b98
Binary files /dev/null and b/custom_nodes/comfyui-advanced-controlnet/adv_control/__pycache__/nodes_sparsectrl.cpython-312.pyc differ
diff --git a/custom_nodes/comfyui-advanced-controlnet/adv_control/__pycache__/nodes_weight.cpython-312.pyc b/custom_nodes/comfyui-advanced-controlnet/adv_control/__pycache__/nodes_weight.cpython-312.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..aa350148e950e03c56ff04a52488cfc6b3d56058
Binary files /dev/null and b/custom_nodes/comfyui-advanced-controlnet/adv_control/__pycache__/nodes_weight.cpython-312.pyc differ
diff --git a/custom_nodes/comfyui-advanced-controlnet/adv_control/__pycache__/sampling.cpython-312.pyc b/custom_nodes/comfyui-advanced-controlnet/adv_control/__pycache__/sampling.cpython-312.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..8de3d46856f7cc58663f50b659bdcbb87490eabb
Binary files /dev/null and b/custom_nodes/comfyui-advanced-controlnet/adv_control/__pycache__/sampling.cpython-312.pyc differ
diff --git a/custom_nodes/comfyui-advanced-controlnet/adv_control/__pycache__/utils.cpython-312.pyc b/custom_nodes/comfyui-advanced-controlnet/adv_control/__pycache__/utils.cpython-312.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..eefa33d57d923d18f3df3316d369f2eca5e59ce3
Binary files /dev/null and b/custom_nodes/comfyui-advanced-controlnet/adv_control/__pycache__/utils.cpython-312.pyc differ
diff --git a/custom_nodes/comfyui-advanced-controlnet/adv_control/control.py b/custom_nodes/comfyui-advanced-controlnet/adv_control/control.py
new file mode 100644
index 0000000000000000000000000000000000000000..eb71b0f212a541b77f524b99fa0efa6a50b38c32
--- /dev/null
+++ b/custom_nodes/comfyui-advanced-controlnet/adv_control/control.py
@@ -0,0 +1,983 @@
+from typing import Callable, Union
+from torch import Tensor
+import torch
+import os
+
+import comfy.model_base
+import comfy.ops
+import comfy.utils
+import comfy.model_management
+import comfy.model_detection
+import comfy.controlnet as comfy_cn
+from comfy.controlnet import ControlBase, ControlNet, ControlNetSD35, ControlLora, T2IAdapter, StrengthType
+from comfy.model_patcher import ModelPatcher
+
+from .control_sparsectrl import SparseControlNet, SparseSettings, SparseConst, InterfaceAnimateDiffModel, create_sparse_modelpatcher, load_sparsectrl_motionmodel
+from .control_lllite import LLLiteModule, LLLitePatch, load_controllllite
+from .control_svd import svd_unet_config_from_diffusers_unet, SVDControlNet, svd_unet_to_diffusers
+from .utils import (AdvancedControlBase, TimestepKeyframeGroup, LatentKeyframeGroup, AbstractPreprocWrapper, ControlWeightType, ControlWeights, WeightTypeException, Extras,
+ manual_cast_clean_groupnorm, disable_weight_init_clean_groupnorm, WrapperConsts, prepare_mask_batch, get_properly_arranged_t2i_weights, load_torch_file_with_dict_factory,
+ broadcast_image_to_extend, extend_to_batch_size, ORIG_PREVIOUS_CONTROLNET, CONTROL_INIT_BY_ACN)
+from .logger import logger
+
+
+class ControlNetAdvanced(ControlNet, AdvancedControlBase):
+ def __init__(self, control_model, timestep_keyframes: TimestepKeyframeGroup, global_average_pooling=False, compression_ratio=8, latent_format=None, load_device=None, manual_cast_dtype=None,
+ extra_conds=["y"], strength_type=StrengthType.CONSTANT, concat_mask=False, preprocess_image=lambda a: a):
+ super().__init__(control_model=control_model, global_average_pooling=global_average_pooling, compression_ratio=compression_ratio, latent_format=latent_format, load_device=load_device, manual_cast_dtype=manual_cast_dtype,
+ extra_conds=extra_conds, strength_type=strength_type, concat_mask=concat_mask, preprocess_image=preprocess_image)
+ AdvancedControlBase.__init__(self, super(type(self), self), timestep_keyframes=timestep_keyframes, weights_default=ControlWeights.controlnet())
+ self.is_flux = False
+ self.x_noisy_shape = None
+
+ def get_universal_weights(self) -> ControlWeights:
+ def cn_weights_func(idx: int, control: dict[str, list[Tensor]], key: str):
+ if key == "middle":
+ return 1.0 * self.weights.extras.get(Extras.MIDDLE_MULT, 1.0)
+ c_len = len(control[key])
+ raw_weights = [(self.weights.base_multiplier ** float((c_len) - i)) for i in range(c_len+1)]
+ raw_weights = raw_weights[:-1]
+ if key == "input":
+ raw_weights.reverse()
+ return raw_weights[idx]
+ return self.weights.copy_with_new_weights(new_weight_func=cn_weights_func)
+
+ def get_control_advanced(self, x_noisy, t, cond, batched_number, transformer_options):
+ # perform special version of get_control that supports sliding context and masks
+ return self.sliding_get_control(x_noisy, t, cond, batched_number, transformer_options)
+
+ def sliding_get_control(self, x_noisy: Tensor, t, cond, batched_number, transformer_options):
+ control_prev = None
+ if self.previous_controlnet is not None:
+ control_prev = self.previous_controlnet.get_control(x_noisy, t, cond, batched_number, transformer_options)
+
+ if self.timestep_range is not None:
+ if t[0] > self.timestep_range[0] or t[0] < self.timestep_range[1]:
+ if control_prev is not None:
+ return control_prev
+ else:
+ return None
+
+ dtype = self.control_model.dtype
+ if self.manual_cast_dtype is not None:
+ dtype = self.manual_cast_dtype
+
+ # make cond_hint appropriate dimensions
+ # TODO: change this to not require cond_hint upscaling every step when self.sub_idxs are present
+ if self.sub_idxs is not None or self.cond_hint is None or x_noisy.shape[2] * self.real_compression_ratio != self.cond_hint.shape[2] or x_noisy.shape[3] * self.real_compression_ratio != self.cond_hint.shape[3]:
+ if self.cond_hint is not None:
+ del self.cond_hint
+ self.cond_hint = None
+ self.real_compression_ratio = self.compression_ratio
+ compression_ratio = self.compression_ratio
+ if self.vae is not None and self.mult_by_ratio_when_vae:
+ compression_ratio *= self.vae.downscale_ratio
+ # if self.cond_hint_original length greater or equal to real latent count, subdivide it before scaling
+ if self.sub_idxs is not None:
+ actual_cond_hint_orig = self.cond_hint_original
+ if self.cond_hint_original.size(0) < self.full_latent_length:
+ actual_cond_hint_orig = extend_to_batch_size(tensor=actual_cond_hint_orig, batch_size=self.full_latent_length)
+ self.cond_hint = comfy.utils.common_upscale(actual_cond_hint_orig[self.sub_idxs], x_noisy.shape[3] * compression_ratio, x_noisy.shape[2] * compression_ratio, self.upscale_algorithm, "center")
+ else:
+ self.cond_hint = comfy.utils.common_upscale(self.cond_hint_original, x_noisy.shape[3] * compression_ratio, x_noisy.shape[2] * compression_ratio, self.upscale_algorithm, "center")
+ self.cond_hint = self.preprocess_image(self.cond_hint)
+ if self.vae is not None:
+ loaded_models = comfy.model_management.loaded_models(only_currently_used=True)
+ self.cond_hint = self.vae.encode(self.cond_hint.movedim(1, -1))
+ comfy.model_management.load_models_gpu(loaded_models)
+ if not self.mult_by_ratio_when_vae:
+ self.real_compression_ratio = 1
+ if self.latent_format is not None:
+ self.cond_hint = self.latent_format.process_in(self.cond_hint)
+ if len(self.extra_concat_orig) > 0:
+ to_concat = []
+ for c in self.extra_concat_orig:
+ c = c.to(self.cond_hint.device)
+ c = comfy.utils.common_upscale(c, self.cond_hint.shape[3], self.cond_hint.shape[2], self.upscale_algorithm, "center")
+ to_concat.append(comfy.utils.repeat_to_batch_size(c, self.cond_hint.shape[0]))
+ self.cond_hint = torch.cat([self.cond_hint] + to_concat, dim=1)
+
+ self.cond_hint = self.cond_hint.to(device=x_noisy.device, dtype=dtype)
+ if x_noisy.shape[0] != self.cond_hint.shape[0]:
+ self.cond_hint = broadcast_image_to_extend(self.cond_hint, x_noisy.shape[0], batched_number)
+
+ # prepare mask_cond_hint
+ self.prepare_mask_cond_hint(x_noisy=x_noisy, t=t, cond=cond, batched_number=batched_number, dtype=dtype)
+
+ context = cond.get('crossattn_controlnet', cond['c_crossattn'])
+ extra = self.extra_args.copy()
+ for c in self.extra_conds:
+ temp = cond.get(c, None)
+ if temp is not None:
+ extra[c] = comfy.model_base.convert_tensor(temp, dtype, x_noisy.device)
+
+ timestep = self.model_sampling_current.timestep(t)
+ x_noisy = self.model_sampling_current.calculate_input(t, x_noisy)
+ self.x_noisy_shape = x_noisy.shape
+
+ control = self.control_model(x=x_noisy.to(dtype), hint=self.cond_hint, timesteps=timestep.to(dtype), context=comfy.model_management.cast_to_device(context, x_noisy.device, dtype), **extra)
+ return self.control_merge(control, control_prev, output_dtype=None)
+
+ def pre_run_advanced(self, *args, **kwargs):
+ self.is_flux = "Flux" in str(type(self.control_model).__name__)
+ return super().pre_run_advanced(*args, **kwargs)
+
+ def apply_advanced_strengths_and_masks(self, x: Tensor, batched_number: int, flux_shape=None):
+ if self.is_flux:
+ flux_shape = self.x_noisy_shape
+ return super().apply_advanced_strengths_and_masks(x, batched_number, flux_shape)
+
+ def copy(self, subtype=None):
+ if subtype is None:
+ subtype = ControlNetAdvanced
+ c = subtype(self.control_model, self.timestep_keyframes, global_average_pooling=self.global_average_pooling, load_device=self.load_device, manual_cast_dtype=self.manual_cast_dtype)
+ c.control_model = self.control_model
+ c.control_model_wrapped = self.control_model_wrapped
+ self.copy_to(c)
+ self.copy_to_advanced(c)
+ return c
+
+ def cleanup_advanced(self):
+ self.x_noisy_shape = None
+ return super().cleanup_advanced()
+
+ @staticmethod
+ def from_vanilla(v: ControlNet, timestep_keyframe: TimestepKeyframeGroup=None, subtype=None) -> 'ControlNetAdvanced':
+ if subtype is None:
+ subtype = ControlNetAdvanced
+ to_return = subtype(control_model=v.control_model, timestep_keyframes=timestep_keyframe,
+ global_average_pooling=v.global_average_pooling, compression_ratio=v.compression_ratio, latent_format=v.latent_format, load_device=v.load_device,
+ manual_cast_dtype=v.manual_cast_dtype, extra_conds=v.extra_conds, strength_type=v.strength_type, concat_mask=v.concat_mask, preprocess_image=v.preprocess_image)
+ v.copy_to(to_return)
+ to_return.control_model_wrapped = v.control_model_wrapped.clone() # needed to avoid breaking memory management system (parent tracking)
+ return to_return
+
+
+class ControlNetSD35Advanced(ControlNetSD35, ControlNetAdvanced):
+ def __init__(self, *args, **kwargs):
+ ControlNetAdvanced.__init__(self, *args, **kwargs)
+
+ def copy(self):
+ return ControlNetAdvanced.copy(self, subtype=ControlNetSD35Advanced)
+
+ @staticmethod
+ def from_vanilla(v: ControlNetSD35, timestep_keyframe=None):
+ return ControlNetAdvanced.from_vanilla(v, timestep_keyframe, subtype=ControlNetSD35Advanced)
+
+
+class T2IAdapterAdvanced(T2IAdapter, AdvancedControlBase):
+ def __init__(self, t2i_model, timestep_keyframes: TimestepKeyframeGroup, channels_in, compression_ratio=8, upscale_algorithm="nearest_exact", device=None):
+ super().__init__(t2i_model=t2i_model, channels_in=channels_in, compression_ratio=compression_ratio, upscale_algorithm=upscale_algorithm, device=device)
+ AdvancedControlBase.__init__(self, super(), timestep_keyframes=timestep_keyframes, weights_default=ControlWeights.t2iadapter())
+
+ def control_merge_inject(self, control: dict[str, list[Tensor]], control_prev, output_dtype):
+ # match batch_size
+ # TODO: make this more efficient by modifying the cached self.control_input val instead of doing this every step
+ for key in control:
+ control_current = control[key]
+ for i in range(len(control_current)):
+ x = control_current[i]
+ if x is not None and x.size(0) == 1 and x.size(0) != self.batch_size:
+ control_current[i] = x.repeat(self.batch_size, 1, 1, 1)[:self.batch_size]
+ return AdvancedControlBase.control_merge_inject(self, control, control_prev, output_dtype)
+
+ def get_universal_weights(self) -> ControlWeights:
+ def t2i_weights_func(idx: int, control: dict[str, list[Tensor]], key: str):
+ if key == "middle":
+ return 1.0 * self.weights.extras.get(Extras.MIDDLE_MULT, 1.0)
+ c_len = 8 #len(control[key])
+ raw_weights = [(self.weights.base_multiplier ** float((c_len-1) - i)) for i in range(c_len)]
+ raw_weights = [raw_weights[-c_len], raw_weights[-3], raw_weights[-2], raw_weights[-1]]
+ raw_weights = get_properly_arranged_t2i_weights(raw_weights)
+ if key == "input":
+ raw_weights.reverse()
+ return raw_weights[idx]
+ return self.weights.copy_with_new_weights(new_weight_func=t2i_weights_func)
+
+ def get_calc_pow(self, idx: int, control: dict[str, list[Tensor]], key: str) -> int:
+ if key == "middle":
+ return 0
+ # match how T2IAdapterAdvanced deals with universal weights
+ c_len = 8 #len(control[key])
+ indeces = [(c_len-1) - i for i in range(c_len)]
+ indeces = [indeces[-c_len], indeces[-3], indeces[-2], indeces[-1]]
+ indeces = get_properly_arranged_t2i_weights(indeces)
+ if key == "input":
+ indeces.reverse() # need to reverse to match recent ComfyUI changes
+ return indeces[idx]
+
+ def get_control_advanced(self, x_noisy, t, cond, batched_number, transformer_options):
+ try:
+ # if sub indexes present, replace original hint with subsection
+ if self.sub_idxs is not None:
+ # cond hints
+ full_cond_hint_original = self.cond_hint_original
+ actual_cond_hint_orig = full_cond_hint_original
+ del self.cond_hint
+ self.cond_hint = None
+ if full_cond_hint_original.size(0) < self.full_latent_length:
+ actual_cond_hint_orig = extend_to_batch_size(tensor=full_cond_hint_original, batch_size=full_cond_hint_original.size(0))
+ self.cond_hint_original = actual_cond_hint_orig[self.sub_idxs]
+ # mask hints
+ self.prepare_mask_cond_hint(x_noisy=x_noisy, t=t, cond=cond, batched_number=batched_number)
+ return super().get_control(x_noisy, t, cond, batched_number, transformer_options)
+ finally:
+ if self.sub_idxs is not None:
+ # replace original cond hint
+ self.cond_hint_original = full_cond_hint_original
+ del full_cond_hint_original
+
+ def copy(self):
+ c = T2IAdapterAdvanced(self.t2i_model, self.timestep_keyframes, self.channels_in, self.compression_ratio, self.upscale_algorithm)
+ self.copy_to(c)
+ self.copy_to_advanced(c)
+ return c
+
+ def cleanup(self):
+ super().cleanup()
+ self.cleanup_advanced()
+
+ @staticmethod
+ def from_vanilla(v: T2IAdapter, timestep_keyframe: TimestepKeyframeGroup=None) -> 'T2IAdapterAdvanced':
+ to_return = T2IAdapterAdvanced(t2i_model=v.t2i_model, timestep_keyframes=timestep_keyframe, channels_in=v.channels_in,
+ compression_ratio=v.compression_ratio, upscale_algorithm=v.upscale_algorithm, device=v.device)
+ v.copy_to(to_return)
+ return to_return
+
+
+class ControlLoraAdvanced(ControlLora, AdvancedControlBase):
+ def __init__(self, control_weights, timestep_keyframes: TimestepKeyframeGroup, global_average_pooling=False):
+ super().__init__(control_weights=control_weights, global_average_pooling=global_average_pooling)
+ AdvancedControlBase.__init__(self, super(), timestep_keyframes=timestep_keyframes, weights_default=ControlWeights.controllora())
+ # use some functions from ControlNetAdvanced
+ self.get_control_advanced = ControlNetAdvanced.get_control_advanced.__get__(self, type(self))
+ self.sliding_get_control = ControlNetAdvanced.sliding_get_control.__get__(self, type(self))
+
+ def get_universal_weights(self) -> ControlWeights:
+ raw_weights = [(self.weights.base_multiplier ** float(9 - i)) for i in range(10)]
+ return self.weights.copy_with_new_weights(raw_weights)
+
+ def copy(self):
+ c = ControlLoraAdvanced(self.control_weights, self.timestep_keyframes, global_average_pooling=self.global_average_pooling)
+ self.copy_to(c)
+ self.copy_to_advanced(c)
+ return c
+
+ def cleanup(self):
+ super().cleanup()
+ self.cleanup_advanced()
+
+ @staticmethod
+ def from_vanilla(v: ControlLora, timestep_keyframe: TimestepKeyframeGroup=None) -> 'ControlLoraAdvanced':
+ to_return = ControlLoraAdvanced(control_weights=v.control_weights, timestep_keyframes=timestep_keyframe,
+ global_average_pooling=v.global_average_pooling)
+ v.copy_to(to_return)
+ return to_return
+
+
+class SVDControlNetAdvanced(ControlNetAdvanced):
+ def __init__(self, control_model: SVDControlNet, timestep_keyframes: TimestepKeyframeGroup, global_average_pooling=False, load_device=None, manual_cast_dtype=None):
+ super().__init__(control_model=control_model, timestep_keyframes=timestep_keyframes, global_average_pooling=global_average_pooling, load_device=load_device, manual_cast_dtype=manual_cast_dtype)
+
+ def set_cond_hint_inject(self, *args, **kwargs):
+ to_return = super().set_cond_hint_inject(*args, **kwargs)
+ # cond hint for SVD-ControlNet needs to be scaled between (-1, 1) instead of (0, 1)
+ self.cond_hint_original = self.cond_hint_original * 2.0 - 1.0
+ return to_return
+
+ def get_control_advanced(self, x_noisy, t, cond, batched_number, transformer_options):
+ control_prev = None
+ if self.previous_controlnet is not None:
+ control_prev = self.previous_controlnet.get_control(x_noisy, t, cond, batched_number, transformer_options)
+
+ if self.timestep_range is not None:
+ if t[0] > self.timestep_range[0] or t[0] < self.timestep_range[1]:
+ if control_prev is not None:
+ return control_prev
+ else:
+ return None
+
+ dtype = self.control_model.dtype
+ if self.manual_cast_dtype is not None:
+ dtype = self.manual_cast_dtype
+
+ output_dtype = x_noisy.dtype
+ # make cond_hint appropriate dimensions
+ # TODO: change this to not require cond_hint upscaling every step when self.sub_idxs are present
+ if self.sub_idxs is not None or self.cond_hint is None or x_noisy.shape[2] * 8 != self.cond_hint.shape[2] or x_noisy.shape[3] * 8 != self.cond_hint.shape[3]:
+ if self.cond_hint is not None:
+ del self.cond_hint
+ self.cond_hint = None
+ # if self.cond_hint_original length greater or equal to real latent count, subdivide it before scaling
+ if self.sub_idxs is not None:
+ actual_cond_hint_orig = self.cond_hint_original
+ if self.cond_hint_original.size(0) < self.full_latent_length:
+ actual_cond_hint_orig = extend_to_batch_size(tensor=actual_cond_hint_orig, batch_size=self.full_latent_length)
+ self.cond_hint = comfy.utils.common_upscale(actual_cond_hint_orig[self.sub_idxs], x_noisy.shape[3] * 8, x_noisy.shape[2] * 8, 'nearest-exact', "center").to(dtype).to(x_noisy.device)
+ else:
+ self.cond_hint = comfy.utils.common_upscale(self.cond_hint_original, x_noisy.shape[3] * 8, x_noisy.shape[2] * 8, 'nearest-exact', "center").to(dtype).to(x_noisy.device)
+ if x_noisy.shape[0] != self.cond_hint.shape[0]:
+ self.cond_hint = broadcast_image_to_extend(self.cond_hint, x_noisy.shape[0], batched_number)
+
+ # prepare mask_cond_hint
+ self.prepare_mask_cond_hint(x_noisy=x_noisy, t=t, cond=cond, batched_number=batched_number, dtype=dtype)
+
+ context = cond.get('crossattn_controlnet', cond['c_crossattn'])
+ # uses 'y' in new ComfyUI update
+ y = cond.get('y', None)
+ if y is not None:
+ y = comfy.model_base.convert_tensor(y, dtype, x_noisy.device)
+ timestep = self.model_sampling_current.timestep(t)
+ x_noisy = self.model_sampling_current.calculate_input(t, x_noisy)
+ # concat c_concat if exists (should exist for SVD), doubling channels to 8
+ if cond.get('c_concat', None) is not None:
+ x_noisy = torch.cat([x_noisy] + [cond['c_concat']], dim=1)
+
+ control = self.control_model(x=x_noisy.to(dtype), hint=self.cond_hint, timesteps=timestep.float(), context=comfy.model_management.cast_to_device(context, x_noisy.device, dtype), y=y, cond=cond)
+ return self.control_merge(control, control_prev, output_dtype)
+
+ def copy(self):
+ c = SVDControlNetAdvanced(self.control_model, self.timestep_keyframes, global_average_pooling=self.global_average_pooling, load_device=self.load_device, manual_cast_dtype=self.manual_cast_dtype)
+ self.copy_to(c)
+ self.copy_to_advanced(c)
+ return c
+
+
+class SparseCtrlAdvanced(ControlNetAdvanced):
+ def __init__(self, control_model: SparseControlNet, motion_model: InterfaceAnimateDiffModel,
+ timestep_keyframes: TimestepKeyframeGroup, sparse_settings: SparseSettings=None, global_average_pooling=False, load_device=None, manual_cast_dtype=None):
+ super().__init__(control_model=None, timestep_keyframes=timestep_keyframes, global_average_pooling=global_average_pooling, load_device=load_device, manual_cast_dtype=manual_cast_dtype)
+ self.control_model = control_model
+ if control_model is not None:
+ self.control_model_wrapped: ModelPatcher = create_sparse_modelpatcher(self.control_model, motion_model, load_device=load_device, offload_device=comfy.model_management.unet_offload_device())
+ self.prepare_conditioning_info()
+ self.add_compatible_weight(ControlWeightType.SPARSECTRL)
+ self.postpone_condhint_latents_check = True
+ self.sparse_settings = sparse_settings if sparse_settings is not None else SparseSettings.default()
+ self.model_latent_format = None # latent format for active SD model, NOT controlnet
+ self.preprocessed = False
+
+ def prepare_conditioning_info(self):
+ if self.control_model.use_simplified_conditioning_embedding:
+ # TODO: allow vae_optional to be used instead of preprocessor
+ #self.require_vae = True
+ self.allow_condhint_latents = True
+
+ @property
+ def motion_model(self) -> InterfaceAnimateDiffModel:
+ motion_models = self.control_model_wrapped.get_additional_models_with_key(WrapperConsts.ACN)
+ if len(motion_models) == 0:
+ return None
+ return motion_models[0].model
+
+ def get_control_advanced(self, x_noisy: Tensor, t, cond, batched_number: int, transformer_options):
+ # normal ControlNet stuff
+ control_prev = None
+ if self.previous_controlnet is not None:
+ control_prev = self.previous_controlnet.get_control(x_noisy, t, cond, batched_number, transformer_options)
+
+ if self.timestep_range is not None:
+ if t[0] > self.timestep_range[0] or t[0] < self.timestep_range[1]:
+ if control_prev is not None:
+ return control_prev
+ else:
+ return None
+
+ dtype = self.control_model.dtype
+ if self.manual_cast_dtype is not None:
+ dtype = self.manual_cast_dtype
+ output_dtype = x_noisy.dtype
+ # set actual input length on motion model
+ actual_length = x_noisy.size(0)//batched_number
+ full_length = actual_length if self.sub_idxs is None else self.full_latent_length
+ if self.motion_model is not None:
+ self.motion_model.set_video_length(video_length=actual_length, full_length=full_length)
+ # prepare cond_hint, if needed
+ dim_mult = 1 if self.control_model.use_simplified_conditioning_embedding else 8
+ if self.sub_idxs is not None or self.cond_hint is None or x_noisy.shape[2]*dim_mult != self.cond_hint.shape[2] or x_noisy.shape[3]*dim_mult != self.cond_hint.shape[3]:
+ # clear out cond_hint and conditioning_mask
+ if self.cond_hint is not None:
+ del self.cond_hint
+ self.cond_hint = None
+ # first, figure out which cond idxs are relevant, and where they fit in
+ cond_idxs, hint_order = self.sparse_settings.sparse_method.get_indexes(hint_length=self.cond_hint_original.size(0), full_length=full_length,
+ sub_idxs=self.sub_idxs if self.sparse_settings.is_context_aware() else None)
+ range_idxs = list(range(full_length)) if self.sub_idxs is None else self.sub_idxs
+ hint_idxs = [] # idxs in cond_idxs
+ local_idxs = [] # idx to put in final cond_hint
+ for i,cond_idx in enumerate(cond_idxs):
+ if cond_idx in range_idxs:
+ hint_idxs.append(i)
+ local_idxs.append(range_idxs.index(cond_idx))
+ # log_string = f"cond_idxs: {cond_idxs}, local_idxs: {local_idxs}, hint_idxs: {hint_idxs}, hint_order: {hint_order}"
+ # if self.sub_idxs is not None:
+ # log_string += f" sub_idxs: {self.sub_idxs[0]}-{self.sub_idxs[-1]}"
+ # logger.warn(log_string)
+ # determine cond/uncond indexes that will get masked
+ self.local_sparse_idxs = []
+ self.local_sparse_idxs_inverse = list(range(x_noisy.size(0)))
+ for batch_idx in range(batched_number):
+ for i in local_idxs:
+ actual_i = i+(batch_idx*actual_length)
+ self.local_sparse_idxs.append(actual_i)
+ if actual_i in self.local_sparse_idxs_inverse:
+ self.local_sparse_idxs_inverse.remove(actual_i)
+ # sub_cond_hint now contains the hints relevant to current x_noisy
+ if hint_order is None:
+ sub_cond_hint = self.cond_hint_original[hint_idxs].to(dtype).to(x_noisy.device)
+ else:
+ sub_cond_hint = self.cond_hint_original[hint_order][hint_idxs].to(dtype).to(x_noisy.device)
+ # scale cond_hints to match noisy input
+ if self.control_model.use_simplified_conditioning_embedding:
+ # RGB SparseCtrl; the inputs are latents - use bilinear to avoid blocky artifacts
+ sub_cond_hint = self.model_latent_format.process_in(sub_cond_hint) # multiplies by model scale factor
+ sub_cond_hint = comfy.utils.common_upscale(sub_cond_hint, x_noisy.shape[3], x_noisy.shape[2], "nearest-exact", "center").to(dtype).to(x_noisy.device)
+ else:
+ # other SparseCtrl; inputs are typical images
+ sub_cond_hint = comfy.utils.common_upscale(sub_cond_hint, x_noisy.shape[3] * 8, x_noisy.shape[2] * 8, 'nearest-exact', "center").to(dtype).to(x_noisy.device)
+ # prepare cond_hint (b, c, h ,w)
+ cond_shape = list(sub_cond_hint.shape)
+ cond_shape[0] = len(range_idxs)
+ self.cond_hint = torch.zeros(cond_shape).to(dtype).to(x_noisy.device)
+ self.cond_hint[local_idxs] = sub_cond_hint[:]
+ # prepare cond_mask (b, 1, h, w)
+ cond_shape[1] = 1
+ cond_mask = torch.zeros(cond_shape).to(dtype).to(x_noisy.device)
+ cond_mask[local_idxs] = self.sparse_settings.sparse_mask_mult * self.weights.extras.get(SparseConst.MASK_MULT, 1.0)
+ # combine cond_hint and cond_mask into (b, c+1, h, w)
+ if not self.sparse_settings.merged:
+ self.cond_hint = torch.cat([self.cond_hint, cond_mask], dim=1)
+ del sub_cond_hint
+ del cond_mask
+ # make cond_hint match x_noisy batch
+ if x_noisy.shape[0] != self.cond_hint.shape[0]:
+ self.cond_hint = broadcast_image_to_extend(self.cond_hint, x_noisy.shape[0], batched_number)
+
+ # prepare mask_cond_hint
+ self.prepare_mask_cond_hint(x_noisy=x_noisy, t=t, cond=cond, batched_number=batched_number, dtype=dtype)
+
+ context = cond['c_crossattn']
+ y = cond.get('y', None)
+ if y is not None:
+ y = comfy.model_base.convert_tensor(y, dtype, x_noisy.device)
+ timestep = self.model_sampling_current.timestep(t)
+ x_noisy = self.model_sampling_current.calculate_input(t, x_noisy)
+
+ control = self.control_model(x=x_noisy.to(dtype), hint=self.cond_hint, timesteps=timestep.float(), context=comfy.model_management.cast_to_device(context, x_noisy.device, dtype), y=y)
+ return self.control_merge(control, control_prev, output_dtype)
+
+ def apply_advanced_strengths_and_masks(self, x: Tensor, batched_number: int, *args, **kwargs):
+ # apply mults to indexes with and without a direct condhint
+ x[self.local_sparse_idxs] *= self.sparse_settings.sparse_hint_mult * self.weights.extras.get(SparseConst.HINT_MULT, 1.0)
+ x[self.local_sparse_idxs_inverse] *= self.sparse_settings.sparse_nonhint_mult * self.weights.extras.get(SparseConst.NONHINT_MULT, 1.0)
+ return super().apply_advanced_strengths_and_masks(x, batched_number, *args, **kwargs)
+
+ def pre_run_advanced(self, model, percent_to_timestep_function):
+ super().pre_run_advanced(model, percent_to_timestep_function)
+ if isinstance(self.cond_hint_original, AbstractPreprocWrapper):
+ if not self.control_model.use_simplified_conditioning_embedding:
+ raise ValueError("Any model besides RGB SparseCtrl should NOT have its images go through the RGB SparseCtrl preprocessor.")
+ self.cond_hint_original = self.cond_hint_original.condhint
+ self.model_latent_format = model.latent_format # LatentFormat object, used to process_in latent cond hint
+ if self.motion_model is not None:
+ self.motion_model.cleanup()
+ self.motion_model.set_effect(self.sparse_settings.motion_strength)
+ self.motion_model.set_scale(self.sparse_settings.motion_scale)
+
+ def cleanup_advanced(self):
+ super().cleanup_advanced()
+ if self.model_latent_format is not None:
+ del self.model_latent_format
+ self.model_latent_format = None
+ self.local_sparse_idxs = None
+ self.local_sparse_idxs_inverse = None
+ if self.motion_model is not None:
+ self.motion_model.cleanup()
+
+ def copy(self):
+ c = SparseCtrlAdvanced(None, None, self.timestep_keyframes, self.sparse_settings, self.global_average_pooling, self.load_device, self.manual_cast_dtype)
+ c.control_model = self.control_model
+ c.control_model_wrapped = self.control_model_wrapped
+ self.prepare_conditioning_info()
+ self.copy_to(c)
+ self.copy_to_advanced(c)
+ return c
+
+ def get_models(self):
+ to_return = super().get_models()
+ to_return.extend(self.control_model_wrapped.get_additional_models())
+ return to_return
+
+
+def load_controlnet(ckpt_path, timestep_keyframe: TimestepKeyframeGroup=None, model=None):
+ controlnet_data = comfy.utils.load_torch_file(ckpt_path, safe_load=True)
+ # from pathlib import Path
+ # log_name = ckpt_path.split('\\')[-1]
+ # with open(Path(__file__).parent.parent.parent / rf"keys_{log_name}.txt", "w") as afile:
+ # for key, value in controlnet_data.items():
+ # afile.write(f"{key}:\t{value.shape}\n")
+ control = None
+ # check if a non-vanilla ControlNet
+ controlnet_type = ControlWeightType.DEFAULT
+ has_controlnet_key = False
+ has_motion_modules_key = False
+ has_temporal_res_block_key = False
+ for key in controlnet_data:
+ # LLLite check
+ if "lllite" in key:
+ controlnet_type = ControlWeightType.CONTROLLLLITE
+ break
+ # SparseCtrl check
+ elif "motion_modules" in key:
+ has_motion_modules_key = True
+ elif "controlnet" in key:
+ has_controlnet_key = True
+ # SVD-ControlNet check
+ elif "temporal_res_block" in key:
+ has_temporal_res_block_key = True
+ # ControlNet++ check
+ elif "task_embedding" in key:
+ pass
+ # CtrLoRA check
+ elif "lora_layer" in key:
+ controlnet_type = ControlWeightType.CTRLORA
+ break
+
+ if has_controlnet_key and has_motion_modules_key:
+ controlnet_type = ControlWeightType.SPARSECTRL
+ elif has_controlnet_key and has_temporal_res_block_key:
+ controlnet_type = ControlWeightType.SVD_CONTROLNET
+
+ if controlnet_type != ControlWeightType.DEFAULT:
+ if controlnet_type == ControlWeightType.CONTROLLLLITE:
+ control = load_controllllite(ckpt_path, controlnet_data=controlnet_data, timestep_keyframe=timestep_keyframe)
+ elif controlnet_type == ControlWeightType.SPARSECTRL:
+ control = load_sparsectrl(ckpt_path, controlnet_data=controlnet_data, timestep_keyframe=timestep_keyframe, model=model)
+ elif controlnet_type == ControlWeightType.SVD_CONTROLNET:
+ control = load_svdcontrolnet(ckpt_path, controlnet_data=controlnet_data, timestep_keyframe=timestep_keyframe)
+ elif controlnet_type == ControlWeightType.CTRLORA:
+ raise Exception("This is a CtrLoRA; use the Load CtrLoRA Model node.")
+ # otherwise, load vanilla ControlNet
+ else:
+ try:
+ # hacky way of getting load_torch_file in load_controlnet to use already-present controlnet_data and not redo loading
+ orig_load_torch_file = comfy.utils.load_torch_file
+ comfy.utils.load_torch_file = load_torch_file_with_dict_factory(controlnet_data, orig_load_torch_file)
+ control = comfy_cn.load_controlnet(ckpt_path, model=model)
+ finally:
+ comfy.utils.load_torch_file = orig_load_torch_file
+ if control is None:
+ raise Exception(f"Something went wrong when loading '{ckpt_path}'; ControlNet is None.")
+ return convert_to_advanced(control, timestep_keyframe=timestep_keyframe)
+
+
+def convert_to_advanced(control, timestep_keyframe: TimestepKeyframeGroup=None):
+ # if already advanced, leave it be
+ if is_advanced_controlnet(control):
+ return control
+ # if exactly ControlNet returned, transform it into ControlNetAdvanced
+ if type(control) == ControlNet:
+ control = ControlNetAdvanced.from_vanilla(v=control, timestep_keyframe=timestep_keyframe)
+ if is_sd3_advanced_controlnet(control):
+ control.require_vae = True
+ return control
+ # if exactly ControlNetSD35 returned, transform into ControlNetSD35Advanced
+ elif type(control) == ControlNetSD35:
+ control = ControlNetSD35Advanced.from_vanilla(v=control, timestep_keyframe=timestep_keyframe)
+ if is_sd3_advanced_controlnet(control):
+ control.require_vae = True
+ return control
+ # if exactly ControlLora returned, transform it into ControlLoraAdvanced
+ elif type(control) == ControlLora:
+ return ControlLoraAdvanced.from_vanilla(v=control, timestep_keyframe=timestep_keyframe)
+ # if T2IAdapter returned, transform it into T2IAdapterAdvanced
+ elif isinstance(control, T2IAdapter):
+ return T2IAdapterAdvanced.from_vanilla(v=control, timestep_keyframe=timestep_keyframe)
+ # otherwise, leave it be - might be something I am not supporting yet
+ return control
+
+
+def convert_all_to_advanced(conds: dict[str, list[dict[str]]]) -> tuple[bool, list]:
+ cache = {}
+ modified = False
+ new_conds = {}
+ for cond_type in conds:
+ converted_cond: list[dict[str]] = None
+ cond = conds[cond_type]
+ if cond is not None:
+ for actual_cond in cond:
+ need_to_convert = False
+ if "control" in actual_cond:
+ if not are_all_advanced_controlnet(actual_cond["control"]):
+ need_to_convert = True
+ break
+ if not need_to_convert:
+ converted_cond = cond
+ else:
+ converted_cond = []
+ for actual_cond in cond:
+ if not isinstance(actual_cond, dict):
+ converted_cond.append(actual_cond)
+ continue
+ if "control" not in actual_cond:
+ converted_cond.append(actual_cond)
+ elif are_all_advanced_controlnet(actual_cond["control"]):
+ converted_cond.append(actual_cond)
+ else:
+ actual_cond = actual_cond.copy()
+ actual_cond["control"] = _convert_all_control_to_advanced(actual_cond["control"], cache)
+ converted_cond.append(actual_cond)
+ modified = True
+ new_conds[cond_type] = converted_cond
+ return modified, new_conds
+
+
+def _convert_all_control_to_advanced(input_object: ControlBase, cache: dict):
+ output_object = input_object
+ # iteratively convert to advanced, if needed
+ next_cn = None
+ curr_cn = input_object
+ iter = 0
+ while curr_cn is not None:
+ if not is_advanced_controlnet(curr_cn):
+ # if already in cache, then conversion was done before, so just link it and exit
+ if curr_cn in cache:
+ new_cn = cache[curr_cn]
+ if next_cn is not None:
+ setattr(next_cn, ORIG_PREVIOUS_CONTROLNET, next_cn.previous_controlnet)
+ next_cn.previous_controlnet = new_cn
+ if iter == 0: # if was top-level controlnet, that's the new output
+ output_object = new_cn
+ break
+ try:
+ # convert to advanced, and assign previous_controlnet (convert doesn't transfer it)
+ new_cn = convert_to_advanced(curr_cn)
+ except Exception as e:
+ raise Exception("Failed to automatically convert a ControlNet to Advanced to support sliding window context.", e)
+ new_cn.previous_controlnet = curr_cn.previous_controlnet
+ if iter == 0: # if was top-level controlnet, that's the new output
+ output_object = new_cn
+ # if next_cn is present, then it needs to be pointed to new_cn
+ if next_cn is not None:
+ setattr(next_cn, ORIG_PREVIOUS_CONTROLNET, next_cn.previous_controlnet)
+ next_cn.previous_controlnet = new_cn
+ # add to cache
+ cache[curr_cn] = new_cn
+ curr_cn = new_cn
+ next_cn = curr_cn
+ curr_cn = curr_cn.previous_controlnet
+ iter += 1
+ return output_object
+
+
+def restore_all_controlnet_conns(conds: dict[str, list[dict[str]]]):
+ # if a cn has an _orig_previous_controlnet property, restore it and delete
+ for cond_type in conds:
+ cond = conds[cond_type]
+ if cond is not None:
+ for actual_cond in cond:
+ if "control" in actual_cond:
+ # if ACN is the one to have initialized it, delete it
+ # TODO: maybe check if someone else did a similar hack, and carefully pluck out our stuff?
+ if CONTROL_INIT_BY_ACN in actual_cond:
+ actual_cond.pop("control")
+ actual_cond.pop(CONTROL_INIT_BY_ACN)
+ else:
+ _restore_all_controlnet_conns(actual_cond["control"])
+
+
+
+def _restore_all_controlnet_conns(input_object: ControlBase):
+ # restore original previous_controlnet if needed
+ curr_cn = input_object
+ while curr_cn is not None:
+ if hasattr(curr_cn, ORIG_PREVIOUS_CONTROLNET):
+ curr_cn.previous_controlnet = getattr(curr_cn, ORIG_PREVIOUS_CONTROLNET)
+ delattr(curr_cn, ORIG_PREVIOUS_CONTROLNET)
+ curr_cn = curr_cn.previous_controlnet
+
+
+def are_all_advanced_controlnet(input_object: ControlBase):
+ # iteratively check if linked controlnets objects are all advanced
+ curr_cn = input_object
+ while curr_cn is not None:
+ if not is_advanced_controlnet(curr_cn):
+ return False
+ curr_cn = curr_cn.previous_controlnet
+ return True
+
+
+def is_advanced_controlnet(input_object):
+ return hasattr(input_object, "sub_idxs")
+
+
+def is_sd3_advanced_controlnet(input_object: ControlNetAdvanced):
+ return type(input_object) in [ControlNetAdvanced, ControlNetSD35Advanced] and input_object.latent_format is not None
+
+
+def load_sparsectrl(ckpt_path: str, controlnet_data: dict[str, Tensor]=None, timestep_keyframe: TimestepKeyframeGroup=None, sparse_settings=SparseSettings.default(), model=None) -> SparseCtrlAdvanced:
+ if controlnet_data is None:
+ controlnet_data = comfy.utils.load_torch_file(ckpt_path, safe_load=True)
+ # first, separate out motion part from normal controlnet part and attempt to load that portion
+ motion_data = {}
+ for key in list(controlnet_data.keys()):
+ if "temporal" in key:
+ motion_data[key] = controlnet_data.pop(key)
+ if len(motion_data) == 0:
+ raise ValueError(f"No motion-related keys in '{ckpt_path}'; not a valid SparseCtrl model!")
+
+ # now, load as if it was a normal controlnet - mostly copied from comfy load_controlnet function
+ controlnet_config: dict[str] = None
+ is_diffusers = False
+ use_simplified_conditioning_embedding = False
+ if "controlnet_cond_embedding.conv_in.weight" in controlnet_data:
+ is_diffusers = True
+ if "controlnet_cond_embedding.weight" in controlnet_data:
+ is_diffusers = True
+ use_simplified_conditioning_embedding = True
+ if is_diffusers: #diffusers format
+ unet_dtype = comfy.model_management.unet_dtype()
+ controlnet_config = comfy.model_detection.unet_config_from_diffusers_unet(controlnet_data, unet_dtype)
+ diffusers_keys = comfy.utils.unet_to_diffusers(controlnet_config)
+ diffusers_keys["controlnet_mid_block.weight"] = "middle_block_out.0.weight"
+ diffusers_keys["controlnet_mid_block.bias"] = "middle_block_out.0.bias"
+
+ count = 0
+ loop = True
+ while loop:
+ suffix = [".weight", ".bias"]
+ for s in suffix:
+ k_in = "controlnet_down_blocks.{}{}".format(count, s)
+ k_out = "zero_convs.{}.0{}".format(count, s)
+ if k_in not in controlnet_data:
+ loop = False
+ break
+ diffusers_keys[k_in] = k_out
+ count += 1
+ # normal conditioning embedding
+ if not use_simplified_conditioning_embedding:
+ count = 0
+ loop = True
+ while loop:
+ suffix = [".weight", ".bias"]
+ for s in suffix:
+ if count == 0:
+ k_in = "controlnet_cond_embedding.conv_in{}".format(s)
+ else:
+ k_in = "controlnet_cond_embedding.blocks.{}{}".format(count - 1, s)
+ k_out = "input_hint_block.{}{}".format(count * 2, s)
+ if k_in not in controlnet_data:
+ k_in = "controlnet_cond_embedding.conv_out{}".format(s)
+ loop = False
+ diffusers_keys[k_in] = k_out
+ count += 1
+ # simplified conditioning embedding
+ else:
+ count = 0
+ suffix = [".weight", ".bias"]
+ for s in suffix:
+ k_in = "controlnet_cond_embedding{}".format(s)
+ k_out = "input_hint_block.{}{}".format(count, s)
+ diffusers_keys[k_in] = k_out
+
+ new_sd = {}
+ for k in diffusers_keys:
+ if k in controlnet_data:
+ new_sd[diffusers_keys[k]] = controlnet_data.pop(k)
+
+ leftover_keys = controlnet_data.keys()
+ if len(leftover_keys) > 0:
+ logger.info("leftover keys:", leftover_keys)
+ controlnet_data = new_sd
+
+ pth_key = 'control_model.zero_convs.0.0.weight'
+ pth = False
+ key = 'zero_convs.0.0.weight'
+ if pth_key in controlnet_data:
+ pth = True
+ key = pth_key
+ prefix = "control_model."
+ elif key in controlnet_data:
+ prefix = ""
+ else:
+ raise ValueError("The provided model is not a valid SparseCtrl model! [ErrorCode: HORSERADISH]")
+
+ if controlnet_config is None:
+ unet_dtype = comfy.model_management.unet_dtype()
+ controlnet_config = comfy.model_detection.model_config_from_unet(controlnet_data, prefix, unet_dtype, True).unet_config
+ load_device = comfy.model_management.get_torch_device()
+ manual_cast_dtype = comfy.model_management.unet_manual_cast(unet_dtype, load_device)
+ if manual_cast_dtype is not None:
+ controlnet_config["operations"] = manual_cast_clean_groupnorm
+ else:
+ controlnet_config["operations"] = disable_weight_init_clean_groupnorm
+ controlnet_config.pop("out_channels")
+ # get proper hint channels
+ if use_simplified_conditioning_embedding:
+ controlnet_config["hint_channels"] = controlnet_data["{}input_hint_block.0.weight".format(prefix)].shape[1]
+ controlnet_config["use_simplified_conditioning_embedding"] = use_simplified_conditioning_embedding
+ else:
+ controlnet_config["hint_channels"] = controlnet_data["{}input_hint_block.0.weight".format(prefix)].shape[1]
+ controlnet_config["use_simplified_conditioning_embedding"] = use_simplified_conditioning_embedding
+ control_model = SparseControlNet(**controlnet_config)
+
+ if pth:
+ if 'difference' in controlnet_data:
+ if model is not None:
+ comfy.model_management.load_models_gpu([model])
+ model_sd = model.model_state_dict()
+ for x in controlnet_data:
+ c_m = "control_model."
+ if x.startswith(c_m):
+ sd_key = "diffusion_model.{}".format(x[len(c_m):])
+ if sd_key in model_sd:
+ cd = controlnet_data[x]
+ cd += model_sd[sd_key].type(cd.dtype).to(cd.device)
+ else:
+ logger.warning("WARNING: Loaded a diff SparseCtrl without a model. It will very likely not work.")
+
+ class WeightsLoader(torch.nn.Module):
+ pass
+ w = WeightsLoader()
+ w.control_model = control_model
+ missing, unexpected = w.load_state_dict(controlnet_data, strict=False)
+ else:
+ missing, unexpected = control_model.load_state_dict(controlnet_data, strict=False)
+ if len(missing) > 0 or len(unexpected) > 0:
+ logger.info(f"SparseCtrl ControlNet: {missing}, {unexpected}")
+
+ global_average_pooling = False
+ filename = os.path.splitext(ckpt_path)[0]
+ if filename.endswith("_shuffle") or filename.endswith("_shuffle_fp16"): #TODO: smarter way of enabling global_average_pooling
+ global_average_pooling = True
+
+ # actually load motion portion of model now
+ motion_model = load_sparsectrl_motionmodel(ckpt_path=ckpt_path, motion_data=motion_data, ops=controlnet_config.get("operations", None)).to(comfy.model_management.unet_dtype())
+ # both motion portion and controlnet portions are loaded; ignore motion_model if shouldn't use motion portion
+ if not sparse_settings.use_motion:
+ motion_model = None
+
+ control = SparseCtrlAdvanced(control_model, motion_model, timestep_keyframes=timestep_keyframe, sparse_settings=sparse_settings, global_average_pooling=global_average_pooling, load_device=load_device, manual_cast_dtype=manual_cast_dtype)
+ return control
+
+
+def load_svdcontrolnet(ckpt_path: str, controlnet_data: dict[str, Tensor]=None, timestep_keyframe: TimestepKeyframeGroup=None, model=None):
+ if controlnet_data is None:
+ controlnet_data = comfy.utils.load_torch_file(ckpt_path, safe_load=True)
+
+ controlnet_config = None
+ if "controlnet_cond_embedding.conv_in.weight" in controlnet_data: #diffusers format
+ unet_dtype = comfy.model_management.unet_dtype()
+ controlnet_config = svd_unet_config_from_diffusers_unet(controlnet_data, unet_dtype)
+ diffusers_keys = svd_unet_to_diffusers(controlnet_config)
+ diffusers_keys["controlnet_mid_block.weight"] = "middle_block_out.0.weight"
+ diffusers_keys["controlnet_mid_block.bias"] = "middle_block_out.0.bias"
+
+ count = 0
+ loop = True
+ while loop:
+ suffix = [".weight", ".bias"]
+ for s in suffix:
+ k_in = "controlnet_down_blocks.{}{}".format(count, s)
+ k_out = "zero_convs.{}.0{}".format(count, s)
+ if k_in not in controlnet_data:
+ loop = False
+ break
+ diffusers_keys[k_in] = k_out
+ count += 1
+
+ count = 0
+ loop = True
+ while loop:
+ suffix = [".weight", ".bias"]
+ for s in suffix:
+ if count == 0:
+ k_in = "controlnet_cond_embedding.conv_in{}".format(s)
+ else:
+ k_in = "controlnet_cond_embedding.blocks.{}{}".format(count - 1, s)
+ k_out = "input_hint_block.{}{}".format(count * 2, s)
+ if k_in not in controlnet_data:
+ k_in = "controlnet_cond_embedding.conv_out{}".format(s)
+ loop = False
+ diffusers_keys[k_in] = k_out
+ count += 1
+
+ new_sd = {}
+ for k in diffusers_keys:
+ if k in controlnet_data:
+ new_sd[diffusers_keys[k]] = controlnet_data.pop(k)
+
+ leftover_keys = controlnet_data.keys()
+ if len(leftover_keys) > 0:
+ spatial_leftover_keys = []
+ temporal_leftover_keys = []
+ other_leftover_keys = []
+ for key in leftover_keys:
+ if "spatial" in key:
+ spatial_leftover_keys.append(key)
+ elif "temporal" in key:
+ temporal_leftover_keys.append(key)
+ else:
+ other_leftover_keys.append(key)
+ logger.warn(f"spatial_leftover_keys ({len(spatial_leftover_keys)}): {spatial_leftover_keys}")
+ logger.warn(f"temporal_leftover_keys ({len(temporal_leftover_keys)}): {temporal_leftover_keys}")
+ logger.warn(f"other_leftover_keys ({len(other_leftover_keys)}): {other_leftover_keys}")
+ #print("leftover keys:", leftover_keys)
+ controlnet_data = new_sd
+
+ pth_key = 'control_model.zero_convs.0.0.weight'
+ pth = False
+ key = 'zero_convs.0.0.weight'
+ if pth_key in controlnet_data:
+ pth = True
+ key = pth_key
+ prefix = "control_model."
+ elif key in controlnet_data:
+ prefix = ""
+ else:
+ raise ValueError("The provided model is not a valid SVD-ControlNet model! [ErrorCode: MUSTARD]")
+
+ if controlnet_config is None:
+ unet_dtype = comfy.model_management.unet_dtype()
+ controlnet_config = comfy.model_detection.model_config_from_unet(controlnet_data, prefix, unet_dtype, True).unet_config
+ load_device = comfy.model_management.get_torch_device()
+ manual_cast_dtype = comfy.model_management.unet_manual_cast(unet_dtype, load_device)
+ if manual_cast_dtype is not None:
+ controlnet_config["operations"] = comfy.ops.manual_cast
+ controlnet_config.pop("out_channels")
+ controlnet_config["hint_channels"] = controlnet_data["{}input_hint_block.0.weight".format(prefix)].shape[1]
+ control_model = SVDControlNet(**controlnet_config)
+
+ if pth:
+ if 'difference' in controlnet_data:
+ if model is not None:
+ comfy.model_management.load_models_gpu([model])
+ model_sd = model.model_state_dict()
+ for x in controlnet_data:
+ c_m = "control_model."
+ if x.startswith(c_m):
+ sd_key = "diffusion_model.{}".format(x[len(c_m):])
+ if sd_key in model_sd:
+ cd = controlnet_data[x]
+ cd += model_sd[sd_key].type(cd.dtype).to(cd.device)
+ else:
+ print("WARNING: Loaded a diff controlnet without a model. It will very likely not work.")
+
+ class WeightsLoader(torch.nn.Module):
+ pass
+ w = WeightsLoader()
+ w.control_model = control_model
+ missing, unexpected = w.load_state_dict(controlnet_data, strict=False)
+ else:
+ missing, unexpected = control_model.load_state_dict(controlnet_data, strict=False)
+ if len(missing) > 0 or len(unexpected) > 0:
+ logger.info(f"SVD-ControlNet: {missing}, {unexpected}")
+
+ global_average_pooling = False
+ filename = os.path.splitext(ckpt_path)[0]
+ if filename.endswith("_shuffle") or filename.endswith("_shuffle_fp16"): #TODO: smarter way of enabling global_average_pooling
+ global_average_pooling = True
+
+ control = SVDControlNetAdvanced(control_model, timestep_keyframes=timestep_keyframe, global_average_pooling=global_average_pooling, load_device=load_device, manual_cast_dtype=manual_cast_dtype)
+ return control
+
diff --git a/custom_nodes/comfyui-advanced-controlnet/adv_control/control_ctrlora.py b/custom_nodes/comfyui-advanced-controlnet/adv_control/control_ctrlora.py
new file mode 100644
index 0000000000000000000000000000000000000000..e6c6e7fc1f5db5f898687128edb8efb1dc2c8380
--- /dev/null
+++ b/custom_nodes/comfyui-advanced-controlnet/adv_control/control_ctrlora.py
@@ -0,0 +1,231 @@
+# Core code adapted from CtrLoRA github repo:
+# https://github.com/xyfJASON/ctrlora
+import torch
+from torch import Tensor
+
+from comfy.cldm.cldm import ControlNet as ControlNetCLDM
+import comfy.model_detection
+import comfy.model_management
+import comfy.ops
+import comfy.utils
+
+from comfy.ldm.modules.diffusionmodules.util import (
+ zero_module,
+ timestep_embedding,
+)
+
+from .control import ControlNetAdvanced
+from .utils import TimestepKeyframeGroup
+from .logger import logger
+
+
+class ControlNetCtrLoRA(ControlNetCLDM):
+ def __init__(self, *args, **kwargs):
+ super().__init__(*args, **kwargs)
+ # delete input hint block
+ del self.input_hint_block
+
+ def forward(self, x: Tensor, hint: Tensor, timesteps, context, y=None, **kwargs):
+ t_emb = timestep_embedding(timesteps, self.model_channels, repeat_only=False).to(x.dtype)
+ emb = self.time_embed(t_emb)
+
+ out_output = []
+ out_middle = []
+
+ if self.num_classes is not None:
+ assert y.shape[0] == x.shape[0]
+ emb = emb + self.label_emb(y)
+
+ h = hint.to(dtype=x.dtype)
+ for module, zero_conv in zip(self.input_blocks, self.zero_convs):
+ h = module(h, emb, context)
+ out_output.append(zero_conv(h, emb, context))
+
+ h = self.middle_block(h, emb, context)
+ out_middle.append(self.middle_block_out(h, emb, context))
+
+ return {"middle": out_middle, "output": out_output}
+
+
+class CtrLoRAAdvanced(ControlNetAdvanced):
+ def __init__(self, *args, **kwargs):
+ super().__init__(*args, **kwargs)
+ self.preprocess_image = lambda a: (a + 1) / 2.0
+ self.require_vae = True
+ self.mult_by_ratio_when_vae = False
+
+ def pre_run_advanced(self, model, percent_to_timestep_function):
+ super().pre_run_advanced(model, percent_to_timestep_function)
+ self.latent_format = model.latent_format # LatentFormat object, used to process_in latent cond hint
+
+ def cleanup_advanced(self):
+ super().cleanup_advanced()
+ if self.latent_format is not None:
+ del self.latent_format
+ self.latent_format = None
+
+ def copy(self):
+ c = CtrLoRAAdvanced(self.control_model, self.timestep_keyframes, global_average_pooling=self.global_average_pooling, load_device=self.load_device, manual_cast_dtype=self.manual_cast_dtype)
+ c.control_model = self.control_model
+ c.control_model_wrapped = self.control_model_wrapped
+ self.copy_to(c)
+ self.copy_to_advanced(c)
+ return c
+
+
+def load_ctrlora(base_path: str, lora_path: str,
+ base_data: dict[str, Tensor]=None, lora_data: dict[str, Tensor]=None,
+ timestep_keyframe: TimestepKeyframeGroup=None, model=None, model_options={}):
+ if base_data is None:
+ base_data = comfy.utils.load_torch_file(base_path, safe_load=True)
+ controlnet_data = base_data
+
+ # first, check that base_data contains keys with lora_layer
+ contains_lora_layers = False
+ for key in base_data:
+ if "lora_layer" in key:
+ contains_lora_layers = True
+ if not contains_lora_layers:
+ raise Exception(f"File '{base_path}' is not a valid CtrLoRA base model; does not contain any lora_layer keys.")
+
+ controlnet_config = None
+ supported_inference_dtypes = None
+
+ pth_key = 'control_model.zero_convs.0.0.weight'
+ pth = False
+ key = 'zero_convs.0.0.weight'
+ if pth_key in controlnet_data:
+ pth = True
+ key = pth_key
+ prefix = "control_model."
+ elif key in controlnet_data:
+ prefix = ""
+ else:
+ raise Exception("")
+ net = load_t2i_adapter(controlnet_data, model_options=model_options)
+ if net is None:
+ logging.error("error could not detect control model type.")
+ return net
+
+ if controlnet_config is None:
+ model_config = comfy.model_detection.model_config_from_unet(controlnet_data, prefix, True)
+ supported_inference_dtypes = list(model_config.supported_inference_dtypes)
+ controlnet_config = model_config.unet_config
+
+ unet_dtype = model_options.get("dtype", None)
+ if unet_dtype is None:
+ weight_dtype = comfy.utils.weight_dtype(controlnet_data)
+
+ if supported_inference_dtypes is None:
+ supported_inference_dtypes = [comfy.model_management.unet_dtype()]
+
+ if weight_dtype is not None:
+ supported_inference_dtypes.append(weight_dtype)
+
+ unet_dtype = comfy.model_management.unet_dtype(model_params=-1, supported_dtypes=supported_inference_dtypes)
+
+ load_device = comfy.model_management.get_torch_device()
+
+ manual_cast_dtype = comfy.model_management.unet_manual_cast(unet_dtype, load_device)
+ operations = model_options.get("custom_operations", None)
+ if operations is None:
+ operations = comfy.ops.pick_operations(unet_dtype, manual_cast_dtype)
+
+ controlnet_config["operations"] = operations
+ controlnet_config["dtype"] = unet_dtype
+ controlnet_config["device"] = comfy.model_management.unet_offload_device()
+ controlnet_config.pop("out_channels")
+ controlnet_config["hint_channels"] = 3
+ #controlnet_config["hint_channels"] = controlnet_data["{}input_hint_block.0.weight".format(prefix)].shape[1]
+ control_model = ControlNetCtrLoRA(**controlnet_config)
+
+ if pth:
+ if 'difference' in controlnet_data:
+ if model is not None:
+ comfy.model_management.load_models_gpu([model])
+ model_sd = model.model_state_dict()
+ for x in controlnet_data:
+ c_m = "control_model."
+ if x.startswith(c_m):
+ sd_key = "diffusion_model.{}".format(x[len(c_m):])
+ if sd_key in model_sd:
+ cd = controlnet_data[x]
+ cd += model_sd[sd_key].type(cd.dtype).to(cd.device)
+ else:
+ logger.warning("WARNING: Loaded a diff controlnet without a model. It will very likely not work.")
+
+ class WeightsLoader(torch.nn.Module):
+ pass
+ w = WeightsLoader()
+ w.control_model = control_model
+ missing, unexpected = w.load_state_dict(controlnet_data, strict=False)
+ else:
+ missing, unexpected = control_model.load_state_dict(controlnet_data, strict=False)
+
+ if len(missing) > 0:
+ logger.warning("missing controlnet keys: {}".format(missing))
+
+ if len(unexpected) > 0:
+ logger.debug("unexpected controlnet keys: {}".format(unexpected))
+
+ global_average_pooling = model_options.get("global_average_pooling", False)
+ control = CtrLoRAAdvanced(control_model, timestep_keyframe, global_average_pooling=global_average_pooling,
+ load_device=load_device, manual_cast_dtype=manual_cast_dtype)
+ # load lora data onto the controlnet
+ if lora_path is not None:
+ load_lora_data(control, lora_path)
+
+ return control
+
+
+def load_lora_data(control: CtrLoRAAdvanced, lora_path: str, loaded_data: dict[str, Tensor]=None, lora_strength=1.0):
+ if loaded_data is None:
+ loaded_data = comfy.utils.load_torch_file(lora_path, safe_load=True)
+ # check that lora_data contains keys with lora_layer
+ contains_lora_layers = False
+ for key in loaded_data:
+ if "lora_layer" in key:
+ contains_lora_layers = True
+ if not contains_lora_layers:
+ raise Exception(f"File '{lora_path}' is not a valid CtrLoRA lora model; does not contain any lora_layer keys.")
+
+ # now that we know we have a ctrlora file, separate keys into 'set' and 'lora' keys
+ data_set: dict[str, Tensor] = {}
+ data_lora: dict[str, Tensor] = {}
+
+ for key in list(loaded_data.keys()):
+ if 'lora_layer' in key:
+ data_lora[key] = loaded_data.pop(key)
+ else:
+ data_set[key] = loaded_data.pop(key)
+ # no keys should be left over
+ if len(loaded_data) > 0:
+ logger.warning("Not all keys from CtrlLoRA lora model's loaded data were parsed!")
+
+ # turn set/lora data into corresponding patches;
+ patches = {}
+ # set will replace the values
+ for key, value in data_set.items():
+ # prase model key from key;
+ # remove "control_model."
+ model_key = key.replace("control_model.", "")
+ patches[model_key] = ("set", (value,))
+ # lora will do mm of up and down tensors
+ for down_key in data_lora:
+ # only process lora down keys; we will process both up+down at the same time
+ if ".up." in down_key:
+ continue
+ # get up version of down key
+ up_key = down_key.replace(".down.", ".up.")
+ # get key that will match up with model key;
+ # remove "lora_layer.down." and "control_model."
+ model_key = down_key.replace("lora_layer.down.", "").replace("control_model.", "")
+
+ weight_down = data_lora[down_key]
+ weight_up = data_lora[up_key]
+ # currently, ComfyUI expects 6 elements in 'lora' type, but for future-proofing add a bunch more with None
+ patches[model_key] = ("lora", (weight_up, weight_down, None, None, None, None,
+ None, None, None, None, None, None, None, None))
+
+ # now that patches are made, add them to model
+ control.control_model_wrapped.add_patches(patches, strength_patch=lora_strength)
diff --git a/custom_nodes/comfyui-advanced-controlnet/adv_control/control_lllite.py b/custom_nodes/comfyui-advanced-controlnet/adv_control/control_lllite.py
new file mode 100644
index 0000000000000000000000000000000000000000..c5820a8591eecade0123828b774f398aa9065096
--- /dev/null
+++ b/custom_nodes/comfyui-advanced-controlnet/adv_control/control_lllite.py
@@ -0,0 +1,427 @@
+# adapted from https://github.com/kohya-ss/ControlNet-LLLite-ComfyUI
+# basically, all the LLLite core code is from there, which I then combined with
+# Advanced-ControlNet features and QoL
+import math
+from typing import Union
+from torch import Tensor
+import torch
+import os
+
+import comfy.utils
+import comfy.ops
+import comfy.model_management
+from comfy.model_patcher import ModelPatcher
+from comfy.controlnet import ControlBase
+
+from .logger import logger
+from .utils import (AdvancedControlBase, TimestepKeyframeGroup, ControlWeights, broadcast_image_to_extend, extend_to_batch_size,
+ prepare_mask_batch)
+
+
+# based on set_model_patch code in comfy/model_patcher.py
+def set_model_patch(transformer_options, patch, name):
+ to = transformer_options
+ # check if patch was already added
+ if "patches" in to:
+ current_patches = to["patches"].get(name, [])
+ if patch in current_patches:
+ return
+ if "patches" not in to:
+ to["patches"] = {}
+ to["patches"][name] = to["patches"].get(name, []) + [patch]
+
+def set_model_attn1_patch(transformer_options, patch):
+ set_model_patch(transformer_options, patch, "attn1_patch")
+
+def set_model_attn2_patch(transformer_options, patch):
+ set_model_patch(transformer_options, patch, "attn2_patch")
+
+
+def extra_options_to_module_prefix(extra_options):
+ # extra_options = {'transformer_index': 2, 'block_index': 8, 'original_shape': [2, 4, 128, 128], 'block': ('input', 7), 'n_heads': 20, 'dim_head': 64}
+
+ # block is: [('input', 4), ('input', 5), ('input', 7), ('input', 8), ('middle', 0),
+ # ('output', 0), ('output', 1), ('output', 2), ('output', 3), ('output', 4), ('output', 5)]
+ # transformer_index is: [0, 1, 2, 3, 4, 5, 6, 7, 8], for each block
+ # block_index is: 0-1 or 0-9, depends on the block
+ # input 7 and 8, middle has 10 blocks
+
+ # make module name from extra_options
+ block = extra_options["block"]
+ block_index = extra_options["block_index"]
+ if block[0] == "input":
+ module_pfx = f"lllite_unet_input_blocks_{block[1]}_1_transformer_blocks_{block_index}"
+ elif block[0] == "middle":
+ module_pfx = f"lllite_unet_middle_block_1_transformer_blocks_{block_index}"
+ elif block[0] == "output":
+ module_pfx = f"lllite_unet_output_blocks_{block[1]}_1_transformer_blocks_{block_index}"
+ else:
+ raise Exception(f"ControlLLLite: invalid block name '{block[0]}'. Expected 'input', 'middle', or 'output'.")
+ return module_pfx
+
+
+class LLLitePatch:
+ ATTN1 = "attn1"
+ ATTN2 = "attn2"
+ def __init__(self, modules: dict[str, 'LLLiteModule'], patch_type: str, control: Union[AdvancedControlBase, ControlBase]=None):
+ self.modules = modules
+ self.control = control
+ self.patch_type = patch_type
+ #logger.error(f"create LLLitePatch: {id(self)},{control}")
+
+ def __call__(self, q, k, v, extra_options):
+ #logger.error(f"in __call__: {id(self)}")
+ # determine if have anything to run
+ if self.control.timestep_range is not None:
+ # it turns out comparing single-value tensors to floats is extremely slow
+ # a: Tensor = extra_options["sigmas"][0]
+ if self.control.t > self.control.timestep_range[0] or self.control.t < self.control.timestep_range[1]:
+ return q, k, v
+
+ module_pfx = extra_options_to_module_prefix(extra_options)
+
+ is_attn1 = q.shape[-1] == k.shape[-1] # self attention
+ if is_attn1:
+ module_pfx = module_pfx + "_attn1"
+ else:
+ module_pfx = module_pfx + "_attn2"
+
+ module_pfx_to_q = module_pfx + "_to_q"
+ module_pfx_to_k = module_pfx + "_to_k"
+ module_pfx_to_v = module_pfx + "_to_v"
+
+ if module_pfx_to_q in self.modules:
+ q = q + self.modules[module_pfx_to_q](q, self.control)
+ if module_pfx_to_k in self.modules:
+ k = k + self.modules[module_pfx_to_k](k, self.control)
+ if module_pfx_to_v in self.modules:
+ v = v + self.modules[module_pfx_to_v](v, self.control)
+
+ return q, k, v
+
+ def to(self, device):
+ #logger.info(f"to... has control? {self.control}")
+ for d in self.modules.keys():
+ self.modules[d] = self.modules[d].to(device)
+ return self
+
+ def set_control(self, control: Union[AdvancedControlBase, ControlBase]) -> 'LLLitePatch':
+ self.control = control
+ return self
+ #logger.error(f"set control for LLLitePatch: {id(self)}, cn: {id(control)}")
+
+ def clone_with_control(self, control: AdvancedControlBase):
+ #logger.error(f"clone-set control for LLLitePatch: {id(self)},{id(control)}")
+ return LLLitePatch(self.modules, self.patch_type, control)
+
+ def cleanup(self):
+ for module in self.modules.values():
+ module.cleanup()
+
+
+# TODO: use comfy.ops to support fp8 properly
+class LLLiteModule(torch.nn.Module):
+ def __init__(
+ self,
+ name: str,
+ is_conv2d: bool,
+ in_dim: int,
+ depth: int,
+ cond_emb_dim: int,
+ mlp_dim: int,
+ ):
+ super().__init__()
+ self.name = name
+ self.is_conv2d = is_conv2d
+ self.is_first = False
+
+ modules = []
+ modules.append(torch.nn.Conv2d(3, cond_emb_dim // 2, kernel_size=4, stride=4, padding=0)) # to latent (from VAE) size*2
+ if depth == 1:
+ modules.append(torch.nn.ReLU(inplace=True))
+ modules.append(torch.nn.Conv2d(cond_emb_dim // 2, cond_emb_dim, kernel_size=2, stride=2, padding=0))
+ elif depth == 2:
+ modules.append(torch.nn.ReLU(inplace=True))
+ modules.append(torch.nn.Conv2d(cond_emb_dim // 2, cond_emb_dim, kernel_size=4, stride=4, padding=0))
+ elif depth == 3:
+ # kernel size 8 is too large, so set it to 4
+ modules.append(torch.nn.ReLU(inplace=True))
+ modules.append(torch.nn.Conv2d(cond_emb_dim // 2, cond_emb_dim // 2, kernel_size=4, stride=4, padding=0))
+ modules.append(torch.nn.ReLU(inplace=True))
+ modules.append(torch.nn.Conv2d(cond_emb_dim // 2, cond_emb_dim, kernel_size=2, stride=2, padding=0))
+
+ self.conditioning1 = torch.nn.Sequential(*modules)
+
+ if self.is_conv2d:
+ self.down = torch.nn.Sequential(
+ torch.nn.Conv2d(in_dim, mlp_dim, kernel_size=1, stride=1, padding=0),
+ torch.nn.ReLU(inplace=True),
+ )
+ self.mid = torch.nn.Sequential(
+ torch.nn.Conv2d(mlp_dim + cond_emb_dim, mlp_dim, kernel_size=1, stride=1, padding=0),
+ torch.nn.ReLU(inplace=True),
+ )
+ self.up = torch.nn.Sequential(
+ torch.nn.Conv2d(mlp_dim, in_dim, kernel_size=1, stride=1, padding=0),
+ )
+ else:
+ self.down = torch.nn.Sequential(
+ torch.nn.Linear(in_dim, mlp_dim),
+ torch.nn.ReLU(inplace=True),
+ )
+ self.mid = torch.nn.Sequential(
+ torch.nn.Linear(mlp_dim + cond_emb_dim, mlp_dim),
+ torch.nn.ReLU(inplace=True),
+ )
+ self.up = torch.nn.Sequential(
+ torch.nn.Linear(mlp_dim, in_dim),
+ )
+
+ self.depth = depth
+ self.cond_emb = None
+ self.cx_shape = None
+ self.prev_batch = 0
+ self.prev_sub_idxs = None
+
+ def cleanup(self):
+ del self.cond_emb
+ self.cond_emb = None
+ self.cx_shape = None
+ self.prev_batch = 0
+ self.prev_sub_idxs = None
+
+ def forward(self, x: Tensor, control: Union[AdvancedControlBase, ControlBase]):
+ mask = None
+ mask_tk = None
+ #logger.info(x.shape)
+ if self.cond_emb is None or control.sub_idxs != self.prev_sub_idxs or x.shape[0] != self.prev_batch:
+ # print(f"cond_emb is None, {self.name}")
+ cond_hint = control.cond_hint.to(x.device, dtype=x.dtype)
+ if control.latent_dims_div2 is not None and x.shape[-1] != 1280:
+ cond_hint = comfy.utils.common_upscale(cond_hint, control.latent_dims_div2[0] * 8, control.latent_dims_div2[1] * 8, 'nearest-exact', "center").to(x.device, dtype=x.dtype)
+ elif control.latent_dims_div4 is not None and x.shape[-1] == 1280:
+ cond_hint = comfy.utils.common_upscale(cond_hint, control.latent_dims_div4[0] * 8, control.latent_dims_div4[1] * 8, 'nearest-exact', "center").to(x.device, dtype=x.dtype)
+ cx = self.conditioning1(cond_hint)
+ self.cx_shape = cx.shape
+ if not self.is_conv2d:
+ # reshape / b,c,h,w -> b,h*w,c
+ n, c, h, w = cx.shape
+ cx = cx.view(n, c, h * w).permute(0, 2, 1)
+ self.cond_emb = cx
+ # save prev values
+ self.prev_batch = x.shape[0]
+ self.prev_sub_idxs = control.sub_idxs
+
+ cx: torch.Tensor = self.cond_emb
+ # print(f"forward {self.name}, {cx.shape}, {x.shape}")
+
+ # TODO: make masks work for conv2d (could not find any ControlLLLites at this time that use them)
+ # create masks
+ if not self.is_conv2d:
+ n, c, h, w = self.cx_shape
+ if control.mask_cond_hint is not None:
+ mask = prepare_mask_batch(control.mask_cond_hint, (1, 1, h, w)).to(cx.dtype)
+ mask = mask.view(mask.shape[0], 1, h * w).permute(0, 2, 1)
+ if control.tk_mask_cond_hint is not None:
+ mask_tk = prepare_mask_batch(control.mask_cond_hint, (1, 1, h, w)).to(cx.dtype)
+ mask_tk = mask_tk.view(mask_tk.shape[0], 1, h * w).permute(0, 2, 1)
+
+ # x in uncond/cond doubles batch size
+ if x.shape[0] != cx.shape[0]:
+ if self.is_conv2d:
+ cx = cx.repeat(x.shape[0] // cx.shape[0], 1, 1, 1)
+ else:
+ # print("x.shape[0] != cx.shape[0]", x.shape[0], cx.shape[0])
+ cx = cx.repeat(x.shape[0] // cx.shape[0], 1, 1)
+ if mask is not None:
+ mask = mask.repeat(x.shape[0] // mask.shape[0], 1, 1)
+ if mask_tk is not None:
+ mask_tk = mask_tk.repeat(x.shape[0] // mask_tk.shape[0], 1, 1)
+
+ if mask is None:
+ mask = 1.0
+ elif mask_tk is not None:
+ mask = mask * mask_tk
+
+ #logger.info(f"cs: {cx.shape}, x: {x.shape}, is_conv2d: {self.is_conv2d}")
+ cx = torch.cat([cx, self.down(x)], dim=1 if self.is_conv2d else 2)
+ cx = self.mid(cx)
+ cx = self.up(cx)
+ if control.latent_keyframes is not None:
+ cx = cx * control.calc_latent_keyframe_mults(x=cx, batched_number=control.batched_number)
+ if control.weights is not None and control.weights.has_uncond_multiplier:
+ cond_or_uncond = control.batched_number.cond_or_uncond
+ actual_length = cx.size(0) // control.batched_number
+ for idx, cond_type in enumerate(cond_or_uncond):
+ # if uncond, set to weight's uncond_multiplier
+ if cond_type == 1:
+ cx[actual_length*idx:actual_length*(idx+1)] *= control.weights.uncond_multiplier
+ return cx * mask * control.strength * control._current_timestep_keyframe.strength
+
+
+class ControlLLLiteModules(torch.nn.Module):
+ def __init__(self, patch_attn1: LLLitePatch, patch_attn2: LLLitePatch):
+ super().__init__()
+ self.patch_attn1_modules = torch.nn.Sequential(*list(patch_attn1.modules.values()))
+ self.patch_attn2_modules = torch.nn.Sequential(*list(patch_attn2.modules.values()))
+
+
+class ControlLLLiteAdvanced(ControlBase, AdvancedControlBase):
+ # This ControlNet is more of an attention patch than a traditional controlnet
+ def __init__(self, patch_attn1: LLLitePatch, patch_attn2: LLLitePatch, timestep_keyframes: TimestepKeyframeGroup, device, ops: comfy.ops.disable_weight_init):
+ super().__init__()
+ AdvancedControlBase.__init__(self, super(), timestep_keyframes=timestep_keyframes, weights_default=ControlWeights.controllllite())
+ self.device = device
+ self.ops = ops
+ self.patch_attn1 = patch_attn1.clone_with_control(self)
+ self.patch_attn2 = patch_attn2.clone_with_control(self)
+ self.control_model = ControlLLLiteModules(self.patch_attn1, self.patch_attn2)
+ self.control_model_wrapped = ModelPatcher(self.control_model, load_device=device, offload_device=comfy.model_management.unet_offload_device())
+ self.latent_dims_div2 = None
+ self.latent_dims_div4 = None
+
+ def set_cond_hint_inject(self, *args, **kwargs):
+ to_return = super().set_cond_hint_inject(*args, **kwargs)
+ # cond hint for LLLite needs to be scaled between (-1, 1) instead of (0, 1)
+ self.cond_hint_original = self.cond_hint_original * 2.0 - 1.0
+ return to_return
+
+ def pre_run_advanced(self, *args, **kwargs):
+ AdvancedControlBase.pre_run_advanced(self, *args, **kwargs)
+ #logger.error(f"in cn: {id(self.patch_attn1)},{id(self.patch_attn2)}")
+ self.patch_attn1.set_control(self)
+ self.patch_attn2.set_control(self)
+ #logger.warn(f"in pre_run_advanced: {id(self)}")
+
+ def get_control_advanced(self, x_noisy: Tensor, t, cond, batched_number: int, transformer_options: dict):
+ # normal ControlNet stuff
+ control_prev = None
+ if self.previous_controlnet is not None:
+ control_prev = self.previous_controlnet.get_control(x_noisy, t, cond, batched_number, transformer_options)
+
+ if self.timestep_range is not None:
+ if t[0] > self.timestep_range[0] or t[0] < self.timestep_range[1]:
+ return control_prev
+
+ dtype = x_noisy.dtype
+ # prepare cond_hint
+ if self.sub_idxs is not None or self.cond_hint is None or x_noisy.shape[2] * 8 != self.cond_hint.shape[2] or x_noisy.shape[3] * 8 != self.cond_hint.shape[3]:
+ if self.cond_hint is not None:
+ del self.cond_hint
+ self.cond_hint = None
+ # if self.cond_hint_original length greater or equal to real latent count, subdivide it before scaling
+ if self.sub_idxs is not None:
+ actual_cond_hint_orig = self.cond_hint_original
+ if self.cond_hint_original.size(0) < self.full_latent_length:
+ actual_cond_hint_orig = extend_to_batch_size(tensor=actual_cond_hint_orig, batch_size=self.full_latent_length)
+ self.cond_hint = comfy.utils.common_upscale(actual_cond_hint_orig[self.sub_idxs], x_noisy.shape[3] * 8, x_noisy.shape[2] * 8, 'nearest-exact', "center").to(dtype).to(x_noisy.device)
+ else:
+ self.cond_hint = comfy.utils.common_upscale(self.cond_hint_original, x_noisy.shape[3] * 8, x_noisy.shape[2] * 8, 'nearest-exact', "center").to(dtype).to(x_noisy.device)
+ if x_noisy.shape[0] != self.cond_hint.shape[0]:
+ self.cond_hint = broadcast_image_to_extend(self.cond_hint, x_noisy.shape[0], batched_number)
+ # some special logic here compared to other controlnets:
+ # * The cond_emb in attn patches will divide latent dims by 2 or 4, integer
+ # * Due to this loss, the cond_emb will become smaller than x input if latent dims are not divisble by 2 or 4
+ divisible_by_2_h = x_noisy.shape[2]%2==0
+ divisible_by_2_w = x_noisy.shape[3]%2==0
+ if not (divisible_by_2_h and divisible_by_2_w):
+ #logger.warn(f"{x_noisy.shape} not divisible by 2!")
+ new_h = (x_noisy.shape[2]//2)*2
+ new_w = (x_noisy.shape[3]//2)*2
+ if not divisible_by_2_h:
+ new_h += 2
+ if not divisible_by_2_w:
+ new_w += 2
+ self.latent_dims_div2 = (new_h, new_w)
+ divisible_by_4_h = x_noisy.shape[2]%4==0
+ divisible_by_4_w = x_noisy.shape[3]%4==0
+ if not (divisible_by_4_h and divisible_by_4_w):
+ #logger.warn(f"{x_noisy.shape} not divisible by 4!")
+ new_h = (x_noisy.shape[2]//4)*4
+ new_w = (x_noisy.shape[3]//4)*4
+ if not divisible_by_4_h:
+ new_h += 4
+ if not divisible_by_4_w:
+ new_w += 4
+ self.latent_dims_div4 = (new_h, new_w)
+ # prepare mask
+ self.prepare_mask_cond_hint(x_noisy=x_noisy, t=t, cond=cond, batched_number=batched_number)
+ # done preparing; model patches will take care of everything now
+ set_model_attn1_patch(transformer_options, self.patch_attn1.set_control(self))
+ set_model_attn2_patch(transformer_options, self.patch_attn2.set_control(self))
+ # return normal controlnet stuff
+ return control_prev
+
+ def get_models(self):
+ to_return: list = super().get_models()
+ to_return.append(self.control_model_wrapped)
+ return to_return
+
+ def cleanup_advanced(self):
+ super().cleanup_advanced()
+ self.patch_attn1.cleanup()
+ self.patch_attn2.cleanup()
+ self.latent_dims_div2 = None
+ self.latent_dims_div4 = None
+
+ def copy(self):
+ c = ControlLLLiteAdvanced(self.patch_attn1, self.patch_attn2, self.timestep_keyframes, self.device, self.ops)
+ self.copy_to(c)
+ self.copy_to_advanced(c)
+ return c
+
+
+def load_controllllite(ckpt_path: str, controlnet_data: dict[str, Tensor]=None, timestep_keyframe: TimestepKeyframeGroup=None):
+ if controlnet_data is None:
+ controlnet_data = comfy.utils.load_torch_file(ckpt_path, safe_load=True)
+ # adapted from https://github.com/kohya-ss/ControlNet-LLLite-ComfyUI
+ # first, split weights for each module
+ module_weights = {}
+ for key, value in controlnet_data.items():
+ fragments = key.split(".")
+ module_name = fragments[0]
+ weight_name = ".".join(fragments[1:])
+
+ if module_name not in module_weights:
+ module_weights[module_name] = {}
+ module_weights[module_name][weight_name] = value
+
+ unet_dtype = comfy.model_management.unet_dtype()
+ load_device = comfy.model_management.get_torch_device()
+ manual_cast_dtype = comfy.model_management.unet_manual_cast(unet_dtype, load_device)
+ ops = comfy.ops.disable_weight_init
+ if manual_cast_dtype is not None:
+ ops = comfy.ops.manual_cast
+
+ # next, load each module
+ modules = {}
+ for module_name, weights in module_weights.items():
+ # kohya planned to do something about how these should be chosen, so I'm not touching this
+ # since I am not familiar with the logic for this
+ if "conditioning1.4.weight" in weights:
+ depth = 3
+ elif weights["conditioning1.2.weight"].shape[-1] == 4:
+ depth = 2
+ else:
+ depth = 1
+
+ module = LLLiteModule(
+ name=module_name,
+ is_conv2d=weights["down.0.weight"].ndim == 4,
+ in_dim=weights["down.0.weight"].shape[1],
+ depth=depth,
+ cond_emb_dim=weights["conditioning1.0.weight"].shape[0] * 2,
+ mlp_dim=weights["down.0.weight"].shape[0],
+ )
+ # load weights into module
+ module.load_state_dict(weights)
+ modules[module_name] = module.to(dtype=unet_dtype)
+ if len(modules) == 1:
+ module.is_first = True
+
+ #logger.info(f"loaded {ckpt_path} successfully, {len(modules)} modules")
+
+ patch_attn1 = LLLitePatch(modules=modules, patch_type=LLLitePatch.ATTN1)
+ patch_attn2 = LLLitePatch(modules=modules, patch_type=LLLitePatch.ATTN2)
+ control = ControlLLLiteAdvanced(patch_attn1=patch_attn1, patch_attn2=patch_attn2, timestep_keyframes=timestep_keyframe, device=load_device, ops=ops)
+ return control
diff --git a/custom_nodes/comfyui-advanced-controlnet/adv_control/control_plusplus.py b/custom_nodes/comfyui-advanced-controlnet/adv_control/control_plusplus.py
new file mode 100644
index 0000000000000000000000000000000000000000..25cb705734afe9acebf229e63d3f3103daf10948
--- /dev/null
+++ b/custom_nodes/comfyui-advanced-controlnet/adv_control/control_plusplus.py
@@ -0,0 +1,486 @@
+# Code ported and modified from the diffusers ControlNetPlus repo by Qi Xin:
+# https://github.com/xinsir6/ControlNetPlus/blob/main/models/controlnet_union.py
+from typing import Union
+
+import os
+import torch
+import torch as th
+import torch.nn as nn
+from torch import Tensor
+from collections import OrderedDict
+
+
+from comfy.ldm.modules.diffusionmodules.util import (zero_module, timestep_embedding)
+
+from comfy.cldm.cldm import ControlNet as ControlNetCLDM
+import comfy.cldm.cldm
+from comfy.controlnet import ControlNet
+#from comfy.t2i_adapter.adapter import ResidualAttentionBlock
+from comfy.ldm.modules.attention import optimized_attention
+import comfy.ops
+import comfy.model_base
+import comfy.model_management
+import comfy.model_detection
+import comfy.utils
+
+from .utils import (AdvancedControlBase, ControlWeights, ControlWeightType, TimestepKeyframeGroup, AbstractPreprocWrapper, Extras,
+ extend_to_batch_size, broadcast_image_to_extend)
+from .logger import logger
+
+
+class PlusPlusType:
+ OPENPOSE = "openpose"
+ DEPTH = "depth"
+ THICKLINE = "hed/pidi/scribble/ted"
+ THINLINE = "canny/lineart/mlsd"
+ NORMAL = "normal"
+ SEGMENT = "segment"
+ TILE = "tile"
+ REPAINT = "inpaint/outpaint"
+ NONE = "none"
+ _LIST_WITH_NONE = [OPENPOSE, DEPTH, THICKLINE, THINLINE, NORMAL, SEGMENT, TILE, REPAINT, NONE]
+ _LIST = [OPENPOSE, DEPTH, THICKLINE, THINLINE, NORMAL, SEGMENT, TILE, REPAINT]
+ _DICT = {OPENPOSE: 0, DEPTH: 1, THICKLINE: 2, THINLINE: 3, NORMAL: 4, SEGMENT: 5, TILE: 6, REPAINT: 7, NONE: -1}
+
+ @classmethod
+ def to_idx(cls, control_type: str):
+ try:
+ return cls._DICT[control_type]
+ except KeyError:
+ raise Exception(f"Unknown control type '{control_type}'.")
+
+
+class PlusPlusInput:
+ def __init__(self, image: Tensor, control_type: str, strength: float):
+ self.image = image
+ self.control_type = control_type
+ self.strength = strength
+
+ def clone(self):
+ return PlusPlusInput(self.image, self.control_type, self.strength)
+
+
+class PlusPlusInputGroup:
+ def __init__(self):
+ self.controls: dict[str, PlusPlusInput] = {}
+
+ def add(self, pp_input: PlusPlusInput):
+ if pp_input.control_type in self.controls:
+ raise Exception(f"Control type '{pp_input.control_type}' is already present; ControlNet++ does not allow more than 1 of each type.")
+ self.controls[pp_input.control_type] = pp_input
+
+ def clone(self) -> 'PlusPlusInputGroup':
+ cloned = PlusPlusInputGroup()
+ for key, value in self.controls.items():
+ cloned.controls[key] = value.clone()
+ return cloned
+
+
+class PlusPlusImageWrapper(AbstractPreprocWrapper):
+ error_msg = error_msg = "Invalid use of ControlNet++ Image Wrapper. The output of ControlNet++ Image Wrapper is NOT a usual image, but an object holding the images and extra info - you must connect the output directly to an Apply Advanced ControlNet node. It cannot be used for anything else that accepts IMAGE input."
+ def __init__(self, condhint: PlusPlusInputGroup):
+ super().__init__(condhint)
+ # just an IDE type hint
+ self.condhint: PlusPlusInputGroup
+
+ def movedim(self, source: int, destination: int):
+ condhint = self.condhint.clone()
+ for pp_input in condhint.controls.values():
+ pp_input.image = pp_input.image.movedim(source, destination)
+ return PlusPlusImageWrapper(condhint)
+
+# parts taken from comfy/cldm/cldm.py
+class OptimizedAttention(nn.Module):
+ def __init__(self, c, nhead, dropout=0.0, dtype=None, device=None, operations=None):
+ super().__init__()
+ self.heads = nhead
+ self.c = c
+
+ self.in_proj = operations.Linear(c, c * 3, bias=True, dtype=dtype, device=device)
+ self.out_proj = operations.Linear(c, c, bias=True, dtype=dtype, device=device)
+
+ def forward(self, x):
+ x = self.in_proj(x)
+ q, k, v = x.split(self.c, dim=2)
+ out = optimized_attention(q, k, v, self.heads)
+ return self.out_proj(out)
+
+class QuickGELU(nn.Module):
+ def forward(self, x: torch.Tensor):
+ return x * torch.sigmoid(1.702 * x)
+
+class ResBlockUnionControlnet(nn.Module):
+ def __init__(self, dim, nhead, dtype=None, device=None, operations=None):
+ super().__init__()
+ self.attn = OptimizedAttention(dim, nhead, dtype=dtype, device=device, operations=operations)
+ self.ln_1 = operations.LayerNorm(dim, dtype=dtype, device=device)
+ self.mlp = nn.Sequential(
+ OrderedDict([("c_fc", operations.Linear(dim, dim * 4, dtype=dtype, device=device)), ("gelu", QuickGELU()),
+ ("c_proj", operations.Linear(dim * 4, dim, dtype=dtype, device=device))]))
+ self.ln_2 = operations.LayerNorm(dim, dtype=dtype, device=device)
+
+ def attention(self, x: torch.Tensor):
+ return self.attn(x)
+
+ def forward(self, x: torch.Tensor):
+ x = x + self.attention(self.ln_1(x))
+ x = x + self.mlp(self.ln_2(x))
+ return x
+
+
+class ControlAddEmbeddingAdv(nn.Module):
+ def __init__(self, in_dim, out_dim, num_control_type, dtype=None, device=None, operations: comfy.ops.disable_weight_init=None):
+ super().__init__()
+ self.num_control_type = num_control_type
+ self.in_dim = in_dim
+ self.linear_1 = operations.Linear(in_dim * num_control_type, out_dim, dtype=dtype, device=device)
+ self.linear_2 = operations.Linear(out_dim, out_dim, dtype=dtype, device=device)
+
+ def forward(self, control_type, dtype, device):
+ if control_type is None:
+ control_type = torch.zeros((self.num_control_type,), device=device)
+ c_type = timestep_embedding(control_type.flatten(), self.in_dim, repeat_only=False).to(dtype).reshape((-1, self.num_control_type * self.in_dim))
+ return self.linear_2(torch.nn.functional.silu(self.linear_1(c_type)))
+
+
+class ControlNetPlusPlus(ControlNetCLDM):
+ def __init__(self, *args,**kwargs):
+ super().__init__(*args, **kwargs)
+
+ operations: comfy.ops.disable_weight_init = kwargs.get("operations", comfy.ops.disable_weight_init)
+ device = kwargs.get("device", None)
+
+ time_embed_dim = self.model_channels * 4
+ control_add_embed_dim = 256
+
+ self.control_add_embedding = ControlAddEmbeddingAdv(control_add_embed_dim, time_embed_dim, self.num_control_type, dtype=self.dtype, device=device, operations=operations)
+
+ def union_controlnet_merge(self, hint: list[Tensor], control_type, emb, context):
+ # Equivalent to: https://github.com/xinsir6/ControlNetPlus/tree/main
+ indexes = torch.nonzero(control_type[0])
+ inputs = []
+ condition_list = []
+
+ for idx in range(indexes.shape[0]):
+ controlnet_cond = self.input_hint_block(hint[indexes[idx][0]], emb, context)
+ feat_seq = torch.mean(controlnet_cond, dim=(2, 3))
+ if idx < indexes.shape[0]:
+ feat_seq += self.task_embedding[indexes[idx][0]].to(dtype=feat_seq.dtype, device=feat_seq.device)
+
+ inputs.append(feat_seq.unsqueeze(1))
+ condition_list.append(controlnet_cond)
+
+ x = torch.cat(inputs, dim=1)
+ x = self.transformer_layes(x)
+
+ controlnet_cond_fuser = None
+ for idx in range(indexes.shape[0]):
+ alpha = self.spatial_ch_projs(x[:, idx])
+ alpha = alpha.unsqueeze(-1).unsqueeze(-1)
+ o = condition_list[idx] + alpha
+ if controlnet_cond_fuser is None:
+ controlnet_cond_fuser = o
+ else:
+ controlnet_cond_fuser += o
+ return controlnet_cond_fuser
+
+ def forward(self, x: Tensor, hint: list[Tensor], timesteps, context, y: Tensor=None, **kwargs):
+ t_emb = timestep_embedding(timesteps, self.model_channels, repeat_only=False).to(x.dtype)
+ emb = self.time_embed(t_emb)
+
+ guided_hint = None
+ if self.control_add_embedding is not None:
+ control_type = kwargs.get("control_type", None)
+
+ emb += self.control_add_embedding(control_type, emb.dtype, emb.device)
+ if control_type is not None:
+ guided_hint = self.union_controlnet_merge(hint, control_type, emb, context)
+
+ if guided_hint is None:
+ guided_hint = self.input_hint_block(hint[0], emb, context)
+
+ out_output = []
+ out_middle = []
+
+ hs = []
+ if self.num_classes is not None:
+ assert y.shape[0] == x.shape[0]
+ emb = emb + self.label_emb(y)
+
+ h = x
+ for module, zero_conv in zip(self.input_blocks, self.zero_convs):
+ if guided_hint is not None:
+ h = module(h, emb, context)
+ h += guided_hint
+ guided_hint = None
+ else:
+ h = module(h, emb, context)
+ out_output.append(zero_conv(h, emb, context))
+
+ h = self.middle_block(h, emb, context)
+ out_middle.append(self.middle_block_out(h, emb, context))
+
+ return {"middle": out_middle, "output": out_output}
+
+
+class ControlNetPlusPlusAdvanced(ControlNet, AdvancedControlBase):
+ def __init__(self, control_model: ControlNetPlusPlus, timestep_keyframes: TimestepKeyframeGroup, global_average_pooling=False, load_device=None, manual_cast_dtype=None):
+ super().__init__(control_model=control_model, global_average_pooling=global_average_pooling, load_device=load_device, manual_cast_dtype=manual_cast_dtype)
+ AdvancedControlBase.__init__(self, super(), timestep_keyframes=timestep_keyframes, weights_default=ControlWeights.controlnet())
+ self.add_compatible_weight(ControlWeightType.CONTROLNETPLUSPLUS)
+ # for IDE type hint purposes
+ self.control_model: ControlNetPlusPlus
+ self.cond_hint_original: Union[PlusPlusImageWrapper, PlusPlusInputGroup]
+ self.cond_hint: list[Union[Tensor, None]]
+ self.cond_hint_shape: Tensor = None
+ self.cond_hint_types: Tensor = None
+ # in case it is using the single loader
+ self.single_control_type: str = None
+
+ def get_universal_weights(self) -> ControlWeights:
+ def cn_weights_func(idx: int, control: dict[str, list[Tensor]], key: str):
+ if key == "middle":
+ return 1.0 * self.weights.extras.get(Extras.MIDDLE_MULT, 1.0)
+ c_len = len(control[key])
+ raw_weights = [(self.weights.base_multiplier ** float((c_len) - i)) for i in range(c_len+1)]
+ raw_weights = raw_weights[:-1]
+ if key == "input":
+ raw_weights.reverse()
+ return raw_weights[idx]
+ return self.weights.copy_with_new_weights(new_weight_func=cn_weights_func)
+
+ def verify_control_type(self, model_name: str, pp_group: PlusPlusInputGroup=None):
+ if pp_group is not None:
+ for pp_input in pp_group.controls.values():
+ if PlusPlusType.to_idx(pp_input.control_type) >= self.control_model.num_control_type:
+ raise Exception(f"ControlNet++ model '{model_name}' does not support control_type '{pp_input.control_type}'.")
+ if self.single_control_type is not None:
+ if PlusPlusType.to_idx(self.single_control_type) >= self.control_model.num_control_type:
+ raise Exception(f"ControlNet++ model '{model_name}' does not support control_type '{self.single_control_type}'.")
+
+ def set_cond_hint_inject(self, *args, **kwargs):
+ to_return = super().set_cond_hint_inject(*args, **kwargs)
+ # if not single_control_type, expect PlusPlusImageWrapper
+ if self.single_control_type is None:
+ # check that cond_hint is wrapped, and unwrap it
+ if type(self.cond_hint_original) != PlusPlusImageWrapper:
+ raise Exception("ControlNet++ (Multi) expects image input from the Load ControlNet++ Model node, NOT from anything else. Images are provided to that node via ControlNet++ Input nodes.")
+ self.cond_hint_original = self.cond_hint_original.condhint.clone()
+ # otherwise, expect single image input (AKA, usual controlnet input)
+ else:
+ # check that cond_hint is not a PlusPlusImageWrapper
+ if type(self.cond_hint_original) == PlusPlusImageWrapper:
+ raise Exception("ControlNet++ (Single) expects usual image input, NOT the image input from a Load ControlNet++ Model (Multi) node.")
+ pp_group = PlusPlusInputGroup()
+ pp_input = PlusPlusInput(self.cond_hint_original, self.single_control_type, 1.0)
+ pp_group.add(pp_input)
+ self.cond_hint_original = pp_group
+ return to_return
+
+ def get_control_advanced(self, x_noisy: Tensor, t, cond, batched_number, transformer_options):
+ control_prev = None
+ if self.previous_controlnet is not None:
+ control_prev = self.previous_controlnet.get_control(x_noisy, t, cond, batched_number, transformer_options)
+
+ if self.timestep_range is not None:
+ if t[0] > self.timestep_range[0] or t[0] < self.timestep_range[1]:
+ if control_prev is not None:
+ return control_prev
+ else:
+ return None
+
+ dtype = self.control_model.dtype
+ if self.manual_cast_dtype is not None:
+ dtype = self.manual_cast_dtype
+
+ output_dtype = x_noisy.dtype
+
+ # make all cond_hints appropriate dimensions
+ # TODO: change this to not require cond_hint upscaling every step when self.sub_idxs is present
+ if self.sub_idxs is not None or self.cond_hint is None or x_noisy.shape[2] * self.compression_ratio != self.cond_hint_shape[2] or x_noisy.shape[3] * self.compression_ratio != self.cond_hint_shape[3]:
+ if self.cond_hint is not None:
+ del self.cond_hint
+ self.cond_hint = [None] * self.control_model.num_control_type
+ self.cond_hint_types = torch.tensor([0.0] * self.control_model.num_control_type)
+ self.cond_hint_shape = None
+ compression_ratio = self.compression_ratio
+ # unlike normal controlnet, need to handle each input image tensor (for each type)
+ for pp_type, pp_input in self.cond_hint_original.controls.items():
+ pp_idx = PlusPlusType.to_idx(pp_type)
+ # if negative, means no type should be selected (single only)
+ if pp_idx < 0:
+ pp_idx = 0
+ else:
+ self.cond_hint_types[pp_idx] = pp_input.strength
+ # if self.cond_hint_original lengths greater or equal to latent count, subdivide
+ if self.sub_idxs is not None:
+ actual_cond_hint_orig = pp_input.image
+ if pp_input.image.size(0) < self.full_latent_length:
+ actual_cond_hint_orig = extend_to_batch_size(tensor=actual_cond_hint_orig, batch_size=self.full_latent_length)
+ self.cond_hint[pp_idx] = comfy.utils.common_upscale(actual_cond_hint_orig[self.sub_idxs], x_noisy.shape[3] * compression_ratio, x_noisy.shape[2] * compression_ratio, 'nearest-exact', "center")
+ else:
+ self.cond_hint[pp_idx] = comfy.utils.common_upscale(pp_input.image, x_noisy.shape[3] * compression_ratio, x_noisy.shape[2] * compression_ratio, 'nearest-exact', "center")
+ self.cond_hint[pp_idx] = self.cond_hint[pp_idx].to(device=x_noisy.device, dtype=dtype)
+ self.cond_hint_shape = self.cond_hint[pp_idx].shape
+ # prepare cond_hint_controls to match batchsize
+ if self.cond_hint_types.count_nonzero() == 0:
+ self.cond_hint_types = None
+ else:
+ self.cond_hint_types = self.cond_hint_types.unsqueeze(0).to(device=x_noisy.device, dtype=dtype).repeat(x_noisy.shape[0], 1)
+ for i in range(len(self.cond_hint)):
+ if self.cond_hint[i] is not None:
+ if x_noisy.shape[0] != self.cond_hint[i].shape[0]:
+ self.cond_hint[i] = broadcast_image_to_extend(self.cond_hint[i], x_noisy.shape[0], batched_number)
+ if self.cond_hint_types is not None and x_noisy.shape[0] != self.cond_hint_types.shape[0]:
+ self.cond_hint_types = broadcast_image_to_extend(self.cond_hint_types, x_noisy.shape[0], batched_number, False)
+
+ # prepare mask_cond_hint
+ self.prepare_mask_cond_hint(x_noisy=x_noisy, t=t, cond=cond, batched_number=batched_number, dtype=dtype)
+
+ context = cond.get('crossattn_controlnet', cond['c_crossattn'])
+ y = cond.get('y', None)
+ if y is not None:
+ y = comfy.model_base.convert_tensor(y, dtype, x_noisy.device)
+ timestep = self.model_sampling_current.timestep(t)
+ x_noisy = self.model_sampling_current.calculate_input(t, x_noisy)
+
+ control = self.control_model(x=x_noisy.to(dtype), hint=self.cond_hint, timesteps=timestep.float(), context=comfy.model_management.cast_to_device(context, x_noisy.device, dtype), y=y, control_type=self.cond_hint_types)
+ return self.control_merge(control, control_prev, output_dtype)
+
+ def copy(self):
+ c = ControlNetPlusPlusAdvanced(self.control_model, self.timestep_keyframes, global_average_pooling=self.global_average_pooling, load_device=self.load_device, manual_cast_dtype=self.manual_cast_dtype)
+ self.copy_to(c)
+ self.copy_to_advanced(c)
+ c.single_control_type = self.single_control_type
+ return c
+
+
+def load_controlnetplusplus(ckpt_path: str, timestep_keyframe: TimestepKeyframeGroup=None, model=None):
+ controlnet_data = comfy.utils.load_torch_file(ckpt_path, safe_load=True)
+ # check that actually is ControlNet++ model
+ if "task_embedding" not in controlnet_data:
+ raise Exception(f"'{ckpt_path}' is not a valid ControlNet++ model.")
+
+ controlnet_config = None
+ supported_inference_dtypes = None
+
+ if "controlnet_cond_embedding.conv_in.weight" in controlnet_data: #diffusers format
+ controlnet_config = comfy.model_detection.unet_config_from_diffusers_unet(controlnet_data)
+ diffusers_keys = comfy.utils.unet_to_diffusers(controlnet_config)
+ diffusers_keys["controlnet_mid_block.weight"] = "middle_block_out.0.weight"
+ diffusers_keys["controlnet_mid_block.bias"] = "middle_block_out.0.bias"
+
+ count = 0
+ loop = True
+ while loop:
+ suffix = [".weight", ".bias"]
+ for s in suffix:
+ k_in = "controlnet_down_blocks.{}{}".format(count, s)
+ k_out = "zero_convs.{}.0{}".format(count, s)
+ if k_in not in controlnet_data:
+ loop = False
+ break
+ diffusers_keys[k_in] = k_out
+ count += 1
+
+ count = 0
+ loop = True
+ while loop:
+ suffix = [".weight", ".bias"]
+ for s in suffix:
+ if count == 0:
+ k_in = "controlnet_cond_embedding.conv_in{}".format(s)
+ else:
+ k_in = "controlnet_cond_embedding.blocks.{}{}".format(count - 1, s)
+ k_out = "input_hint_block.{}{}".format(count * 2, s)
+ if k_in not in controlnet_data:
+ k_in = "controlnet_cond_embedding.conv_out{}".format(s)
+ loop = False
+ diffusers_keys[k_in] = k_out
+ count += 1
+
+ new_sd = {}
+ for k in diffusers_keys:
+ if k in controlnet_data:
+ new_sd[diffusers_keys[k]] = controlnet_data.pop(k)
+
+ if "control_add_embedding.linear_1.bias" in controlnet_data: #Union Controlnet
+ controlnet_config["union_controlnet_num_control_type"] = controlnet_data["task_embedding"].shape[0]
+ for k in list(controlnet_data.keys()):
+ new_k = k.replace('.attn.in_proj_', '.attn.in_proj.')
+ new_sd[new_k] = controlnet_data.pop(k)
+
+ leftover_keys = controlnet_data.keys()
+ if len(leftover_keys) > 0:
+ logger.warning("leftover ControlNet++ keys: {}".format(leftover_keys))
+ controlnet_data = new_sd
+ elif "controlnet_blocks.0.weight" in controlnet_data: #SD3 diffusers format
+ raise Exception("Unexpected SD3 diffusers format for ControlNet++ model. Something is very wrong.")
+
+ pth_key = 'control_model.zero_convs.0.0.weight'
+ pth = False
+ key = 'zero_convs.0.0.weight'
+ if pth_key in controlnet_data:
+ pth = True
+ key = pth_key
+ prefix = "control_model."
+ elif key in controlnet_data:
+ prefix = ""
+ else:
+ raise Exception("Unexpected T2IAdapter format for ControlNet++ model. Something is very wrong.")
+
+ if controlnet_config is None:
+ model_config = comfy.model_detection.model_config_from_unet(controlnet_data, prefix, True)
+ supported_inference_dtypes = model_config.supported_inference_dtypes
+ controlnet_config = model_config.unet_config
+
+ load_device = comfy.model_management.get_torch_device()
+ if supported_inference_dtypes is None:
+ unet_dtype = comfy.model_management.unet_dtype()
+ else:
+ unet_dtype = comfy.model_management.unet_dtype(supported_dtypes=supported_inference_dtypes)
+
+ manual_cast_dtype = comfy.model_management.unet_manual_cast(unet_dtype, load_device)
+ if manual_cast_dtype is not None:
+ controlnet_config["operations"] = comfy.ops.manual_cast
+ controlnet_config["dtype"] = unet_dtype
+ controlnet_config.pop("out_channels")
+ controlnet_config["hint_channels"] = controlnet_data["{}input_hint_block.0.weight".format(prefix)].shape[1]
+ control_model = ControlNetPlusPlus(**controlnet_config)
+
+ if pth:
+ if 'difference' in controlnet_data:
+ if model is not None:
+ comfy.model_management.load_models_gpu([model])
+ model_sd = model.model_state_dict()
+ for x in controlnet_data:
+ c_m = "control_model."
+ if x.startswith(c_m):
+ sd_key = "diffusion_model.{}".format(x[len(c_m):])
+ if sd_key in model_sd:
+ cd = controlnet_data[x]
+ cd += model_sd[sd_key].type(cd.dtype).to(cd.device)
+ else:
+ logger.warning("WARNING: Loaded a diff controlnet without a model. It will very likely not work.")
+
+ class WeightsLoader(torch.nn.Module):
+ pass
+ w = WeightsLoader()
+ w.control_model = control_model
+ missing, unexpected = w.load_state_dict(controlnet_data, strict=False)
+ else:
+ missing, unexpected = control_model.load_state_dict(controlnet_data, strict=False)
+
+ if len(missing) > 0:
+ logger.warning("missing ControlNet++ keys: {}".format(missing))
+
+ if len(unexpected) > 0:
+ logger.debug("unexpected ControlNet++ keys: {}".format(unexpected))
+
+ global_average_pooling = False
+ filename = os.path.splitext(ckpt_path)[0]
+ if filename.endswith("_shuffle") or filename.endswith("_shuffle_fp16"): #TODO: smarter way of enabling global_average_pooling
+ global_average_pooling = True
+
+ control = ControlNetPlusPlusAdvanced(control_model, timestep_keyframes=timestep_keyframe, global_average_pooling=global_average_pooling, load_device=load_device, manual_cast_dtype=manual_cast_dtype)
+ return control
diff --git a/custom_nodes/comfyui-advanced-controlnet/adv_control/control_reference.py b/custom_nodes/comfyui-advanced-controlnet/adv_control/control_reference.py
new file mode 100644
index 0000000000000000000000000000000000000000..cba83401aaf0dbe2e5d4ad13344a0ae0d37903bb
--- /dev/null
+++ b/custom_nodes/comfyui-advanced-controlnet/adv_control/control_reference.py
@@ -0,0 +1,1206 @@
+from typing import Callable, Union
+
+from uuid import UUID
+import math
+import torch
+from torch import Tensor
+
+import comfy.model_management
+import comfy.patcher_extension
+import comfy.sample
+import comfy.hooks
+import comfy.model_patcher
+import comfy.utils
+from comfy.controlnet import ControlBase
+from comfy.model_patcher import ModelPatcher
+from comfy.ldm.modules.attention import BasicTransformerBlock
+from comfy.ldm.modules.diffusionmodules import openaimodel
+
+from .logger import logger
+from .utils import (AdvancedControlBase, ControlWeights, TimestepKeyframeGroup, TimestepKeyframe, AbstractPreprocWrapper,
+ broadcast_image_to_extend, ORIG_PREVIOUS_CONTROLNET, CONTROL_INIT_BY_ACN)
+
+
+REF_READ_ATTN_CONTROL_LIST = "ref_read_attn_control_list"
+REF_WRITE_ATTN_CONTROL_LIST = "ref_write_attn_control_list"
+REF_READ_ADAIN_CONTROL_LIST = "ref_read_adain_control_list"
+REF_WRITE_ADAIN_CONTROL_LIST = "ref_write_adain_control_list"
+
+REF_ATTN_CONTROL_LIST = "ref_attn_control_list"
+REF_ADAIN_CONTROL_LIST = "ref_adain_control_list"
+REF_CONTROL_LIST_ALL = "ref_control_list_all"
+REF_CONTROL_INFO = "ref_control_info"
+REF_ATTN_MACHINE_STATE = "ref_attn_machine_state"
+REF_ADAIN_MACHINE_STATE = "ref_adain_machine_state"
+REF_COND_IDXS = "ref_cond_idxs"
+REF_UNCOND_IDXS = "ref_uncond_idxs"
+
+CONTEXTREF_OPTIONS_CLASS = "contextref_options_class"
+CONTEXTREF_CLEAN_FUNC = "contextref_clean_func"
+CONTEXTREF_CONTROL_LIST_ALL = "contextref_control_list_all"
+CONTEXTREF_MACHINE_STATE = "contextref_machine_state"
+CONTEXTREF_TEMP_COND_IDX = "contextref_temp_cond_idx"
+
+HIGHEST_VERSION_SUPPORT = 1
+RETURNED_CONTEXTREF_VERSION = 1
+
+
+class RefConst:
+ OPTS = "refcn_opts"
+ CREF_MODE = "contextref_mode"
+ REFCN_PRESENT_IN_CONDS = "refcn_present_in_conds"
+
+
+class MachineState:
+ WRITE = "write"
+ READ = "read"
+ READ_WRITE = "read_write"
+ STYLEALIGN = "stylealign"
+ OFF = "off"
+
+def is_read(state: str):
+ return state in [MachineState.READ, MachineState.READ_WRITE]
+
+def is_write(state: str):
+ return state in [MachineState.WRITE, MachineState.READ_WRITE]
+
+
+class ReferenceType:
+ ATTN = "reference_attn"
+ ADAIN = "reference_adain"
+ ATTN_ADAIN = "reference_attn+adain"
+ STYLE_ALIGN = "StyleAlign"
+
+ _LIST = [ATTN, ADAIN, ATTN_ADAIN]
+ _LIST_ATTN = [ATTN, ATTN_ADAIN]
+ _LIST_ADAIN = [ADAIN, ATTN_ADAIN]
+
+ @classmethod
+ def is_attn(cls, ref_type: str):
+ return ref_type in cls._LIST_ATTN
+
+ @classmethod
+ def is_adain(cls, ref_type: str):
+ return ref_type in cls._LIST_ADAIN
+
+
+class ReferenceOptions:
+ def __init__(self, reference_type: str,
+ attn_style_fidelity: float, adain_style_fidelity: float,
+ attn_ref_weight: float, adain_ref_weight: float,
+ attn_strength: float=1.0, adain_strength: float=1.0,
+ ref_with_other_cns: bool=False):
+ self.reference_type = reference_type
+ # attn
+ self.original_attn_style_fidelity = attn_style_fidelity
+ self.attn_style_fidelity = attn_style_fidelity
+ self.attn_ref_weight = attn_ref_weight
+ self.attn_strength = attn_strength
+ # adain
+ self.original_adain_style_fidelity = adain_style_fidelity
+ self.adain_style_fidelity = adain_style_fidelity
+ self.adain_ref_weight = adain_ref_weight
+ self.adain_strength = adain_strength
+ # other
+ self.ref_with_other_cns = ref_with_other_cns
+
+ def clone(self):
+ return ReferenceOptions(reference_type=self.reference_type,
+ attn_style_fidelity=self.original_attn_style_fidelity, adain_style_fidelity=self.original_adain_style_fidelity,
+ attn_ref_weight=self.attn_ref_weight, adain_ref_weight=self.adain_ref_weight,
+ attn_strength=self.attn_strength, adain_strength=self.adain_strength,
+ ref_with_other_cns=self.ref_with_other_cns)
+
+ @staticmethod
+ def create_combo(reference_type: str, style_fidelity: float, ref_weight: float, ref_with_other_cns: bool=False):
+ return ReferenceOptions(reference_type=reference_type,
+ attn_style_fidelity=style_fidelity, adain_style_fidelity=style_fidelity,
+ attn_ref_weight=ref_weight, adain_ref_weight=ref_weight,
+ ref_with_other_cns=ref_with_other_cns)
+
+ @staticmethod
+ def create_from_kwargs(attn_style_fidelity=0.0, adain_style_fidelity=0.0,
+ attn_ref_weight=0.0, adain_ref_weight=0.0,
+ attn_strength=0.0, adain_strength=0.0, **kwargs):
+ has_attn = attn_strength > 0.0
+ has_adain = adain_strength > 0.0
+ if has_attn and has_adain:
+ reference_type = ReferenceType.ATTN_ADAIN
+ elif has_adain:
+ reference_type = ReferenceType.ADAIN
+ else:
+ reference_type = ReferenceType.ATTN
+ return ReferenceOptions(reference_type=reference_type,
+ attn_style_fidelity=float(attn_style_fidelity), adain_style_fidelity=float(adain_style_fidelity),
+ attn_ref_weight=float(attn_ref_weight), adain_ref_weight=float(adain_ref_weight),
+ attn_strength=float(attn_strength), adain_strength=float(adain_strength))
+
+
+class ReferencePreprocWrapper(AbstractPreprocWrapper):
+ error_msg = error_msg = "Invalid use of Reference Preprocess output. The output of Reference preprocessor is NOT a usual image, but a latent pretending to be an image - you must connect the output directly to an Apply Advanced ControlNet node. It cannot be used for anything else that accepts IMAGE input."
+ def __init__(self, condhint: Tensor):
+ super().__init__(condhint)
+
+
+class ReferenceAdvanced(ControlBase, AdvancedControlBase):
+ CHANNEL_TO_MULT = {320: 1, 640: 2, 1280: 4}
+
+ def __init__(self, ref_opts: ReferenceOptions, timestep_keyframes: TimestepKeyframeGroup, extra_hooks: comfy.hooks.HookGroup=None):
+ super().__init__()
+ AdvancedControlBase.__init__(self, super(), timestep_keyframes=timestep_keyframes, weights_default=ControlWeights.controllllite(), allow_condhint_latents=True)
+ # TODO: allow vae_optional to be used instead of preprocessor
+ #require_vae=True
+ self._ref_opts = ref_opts
+ self.order = 0
+ self.model_latent_format = None
+ self.model_sampling_current = None
+ self.should_apply_attn_effective_strength = False
+ self.should_apply_adain_effective_strength = False
+ self.should_apply_effective_masks = False
+ self.latent_shape = None
+ # wrapper hooks
+ self.extra_hooks = extra_hooks.clone() if extra_hooks else self.import_and_create_wrapper_hooks()
+ # ContextRef stuff
+ self.is_context_ref = False
+ self.contextref_cond_idx = -1 # NOTE: does nothing ever since conds got uuids associated with them; can remove
+ self.contextref_version = RETURNED_CONTEXTREF_VERSION
+
+ @property
+ def ref_opts(self):
+ if self._current_timestep_keyframe is not None and self._current_timestep_keyframe.has_control_weights():
+ return self._current_timestep_keyframe.control_weights.extras.get(RefConst.OPTS, self._ref_opts)
+ return self._ref_opts
+
+ def import_and_create_wrapper_hooks(self):
+ from .sampling import create_wrapper_hooks
+ return create_wrapper_hooks()
+
+ def any_attn_strength_to_apply(self):
+ return self.should_apply_attn_effective_strength or self.should_apply_effective_masks
+
+ def any_adain_strength_to_apply(self):
+ return self.should_apply_adain_effective_strength or self.should_apply_effective_masks
+
+ def get_effective_strength(self):
+ effective_strength = self.strength
+ if self._current_timestep_keyframe is not None:
+ effective_strength = effective_strength * self._current_timestep_keyframe.strength
+ return effective_strength
+
+ def get_effective_attn_mask_or_float(self, x: Tensor, channels: int, is_mid: bool):
+ if not self.should_apply_effective_masks:
+ return self.get_effective_strength() * self.ref_opts.attn_strength
+ if is_mid:
+ div = 8
+ else:
+ div = self.CHANNEL_TO_MULT[channels]
+ real_mask = torch.ones([self.latent_shape[0], 1, self.latent_shape[2]//div, self.latent_shape[3]//div]).to(dtype=x.dtype, device=x.device) * self.strength * self.ref_opts.attn_strength
+ self.apply_advanced_strengths_and_masks(x=real_mask, batched_number=self.batched_number)
+ # mask is now shape [b, 1, h ,w]; need to turn into [b, h*w, 1]
+ b, c, h, w = real_mask.shape
+ real_mask = real_mask.permute(0, 2, 3, 1).reshape(b, h*w, c)
+ return real_mask
+
+ def get_effective_adain_mask_or_float(self, x: Tensor):
+ if not self.should_apply_effective_masks:
+ return self.get_effective_strength() * self.ref_opts.adain_strength
+ b, c, h, w = x.shape
+ real_mask = torch.ones([b, 1, h, w]).to(dtype=x.dtype, device=x.device) * self.strength * self.ref_opts.adain_strength
+ self.apply_advanced_strengths_and_masks(x=real_mask, batched_number=self.batched_number)
+ return real_mask
+
+ def get_contextref_mode_replace(self):
+ # used by ADE to get mode_replace for current keyframe
+ if self._current_timestep_keyframe.has_control_weights():
+ return self._current_timestep_keyframe.control_weights.extras.get(RefConst.CREF_MODE, None)
+ return None
+
+ def should_run(self):
+ running = super().should_run()
+ if not running:
+ return running
+ attn_run = False
+ adain_run = False
+ if ReferenceType.is_attn(self.ref_opts.reference_type):
+ # attn will run as long as neither weight or strength is zero
+ attn_run = not (math.isclose(self.ref_opts.attn_ref_weight, 0.0) or math.isclose(self.ref_opts.attn_strength, 0.0))
+ if ReferenceType.is_adain(self.ref_opts.reference_type):
+ # adain will run as long as neither weight or strength is zero
+ adain_run = not (math.isclose(self.ref_opts.adain_ref_weight, 0.0) or math.isclose(self.ref_opts.adain_strength, 0.0))
+ return attn_run or adain_run
+
+ def pre_run_advanced(self, model, percent_to_timestep_function):
+ AdvancedControlBase.pre_run_advanced(self, model, percent_to_timestep_function)
+ if isinstance(self.cond_hint_original, AbstractPreprocWrapper):
+ self.cond_hint_original = self.cond_hint_original.condhint
+ self.model_latent_format = model.latent_format # LatentFormat object, used to process_in latent cond_hint
+ self.model_sampling_current = model.model_sampling
+ # SDXL is more sensitive to style_fidelity according to sd-webui-controlnet comments;
+ # prepare all ref_opts accordingly
+ all_ref_opts = [self._ref_opts]
+ for kf in self.timestep_keyframes.keyframes:
+ if kf.has_control_weights() and RefConst.OPTS in kf.control_weights.extras:
+ all_ref_opts.append(kf.control_weights.extras[RefConst.OPTS])
+ for ropts in all_ref_opts:
+ if type(model).__name__ == "SDXL":
+ ropts.attn_style_fidelity = ropts.original_attn_style_fidelity ** 3.0
+ ropts.adain_style_fidelity = ropts.original_adain_style_fidelity ** 3.0
+ else:
+ ropts.attn_style_fidelity = ropts.original_attn_style_fidelity
+ ropts.adain_style_fidelity = ropts.original_adain_style_fidelity
+
+ def get_control_advanced(self, x_noisy: Tensor, t, cond, batched_number: int, transformer_options):
+ # normal ControlNet stuff
+ control_prev = None
+ if self.previous_controlnet is not None:
+ control_prev = self.previous_controlnet.get_control(x_noisy, t, cond, batched_number, transformer_options)
+
+ if self.timestep_range is not None:
+ if t[0] > self.timestep_range[0] or t[0] < self.timestep_range[1]:
+ return control_prev
+
+ dtype = x_noisy.dtype
+ # cond_hint_original only matters for RefCN, NOT ContextRef
+ if self.cond_hint_original is not None:
+ # prepare cond_hint - it is a latent, NOT an image
+ #if self.sub_idxs is not None or self.cond_hint is None or x_noisy.shape[2] != self.cond_hint.shape[2] or x_noisy.shape[3] != self.cond_hint.shape[3]:
+ if self.cond_hint is not None:
+ del self.cond_hint
+ self.cond_hint = None
+ # if self.cond_hint_original length greater or equal to real latent count, subdivide it before scaling
+ if self.sub_idxs is not None and self.cond_hint_original.size(0) >= self.full_latent_length:
+ self.cond_hint = comfy.utils.common_upscale(
+ self.cond_hint_original[self.sub_idxs],
+ x_noisy.shape[3], x_noisy.shape[2], 'nearest-exact', "center").to(dtype).to(x_noisy.device)
+ else:
+ self.cond_hint = comfy.utils.common_upscale(
+ self.cond_hint_original,
+ x_noisy.shape[3], x_noisy.shape[2], 'nearest-exact', "center").to(dtype).to(x_noisy.device)
+ if x_noisy.shape[0] != self.cond_hint.shape[0]:
+ self.cond_hint = broadcast_image_to_extend(self.cond_hint, x_noisy.shape[0], batched_number, except_one=False)
+ # noise cond_hint based on sigma (current step)
+ self.cond_hint = self.model_latent_format.process_in(self.cond_hint)
+ self.cond_hint = ref_noise_latents(self.cond_hint, sigma=t, noise=None)
+ timestep = self.model_sampling_current.timestep(t)
+ self.should_apply_attn_effective_strength = not (math.isclose(self.strength, 1.0) and math.isclose(self._current_timestep_keyframe.strength, 1.0) and math.isclose(self.ref_opts.attn_strength, 1.0))
+ self.should_apply_adain_effective_strength = not (math.isclose(self.strength, 1.0) and math.isclose(self._current_timestep_keyframe.strength, 1.0) and math.isclose(self.ref_opts.adain_strength, 1.0))
+ # prepare mask - use direct_attn, so the mask dims will match source latents (and be smaller)
+ self.prepare_mask_cond_hint(x_noisy=x_noisy, t=t, cond=cond, batched_number=batched_number, direct_attn=True)
+ self.should_apply_effective_masks = self.latent_keyframes is not None or self.mask_cond_hint is not None or self.tk_mask_cond_hint is not None
+ self.latent_shape = list(x_noisy.shape)
+ # done preparing; model patches will take care of everything now.
+ transformer_options[RefConst.REFCN_PRESENT_IN_CONDS] = True
+ # return normal controlnet stuff
+ return control_prev
+
+ def cleanup_advanced(self):
+ super().cleanup_advanced()
+ del self.model_latent_format
+ self.model_latent_format = None
+ del self.model_sampling_current
+ self.model_sampling_current = None
+ self.should_apply_attn_effective_strength = False
+ self.should_apply_adain_effective_strength = False
+ self.should_apply_effective_masks = False
+
+ def copy(self):
+ c = ReferenceAdvanced(self.ref_opts, self.timestep_keyframes, self.extra_hooks)
+ c.order = self.order
+ c.is_context_ref = self.is_context_ref
+ self.copy_to(c)
+ self.copy_to_advanced(c)
+ return c
+
+ # avoid deepcopy shenanigans by making deepcopy not do anything to the reference
+ # TODO: do the bookkeeping to do this in a proper way for all Adv-ControlNets
+ def __deepcopy__(self, memo):
+ return self
+
+
+def handle_context_ref_setup(contextref_obj, transformer_options: dict, conds: dict[str, list[dict, str]]):
+ transformer_options[CONTEXTREF_MACHINE_STATE] = MachineState.OFF
+ # verify version is compatible
+ if contextref_obj.version > HIGHEST_VERSION_SUPPORT:
+ raise Exception(f"AnimateDiff-Evolved's ContextRef v{contextref_obj.version} is not supported in currently-installed Advanced-ControlNet (only supports ContextRef up to v{HIGHEST_VERSION_SUPPORT}); " +
+ f"update your Advanced-ControlNet nodes for ContextRef to work.")
+ # init ReferenceOptions
+ cref_opt_dict = contextref_obj.tune.create_dict() # ContextRefTune obj from ADE
+ opts = ReferenceOptions.create_from_kwargs(**cref_opt_dict)
+ # init TimestepKeyframes
+ cref_tks_list = contextref_obj.keyframe.create_list_of_dicts() # ContextRefKeyframeGroup obj from ADE
+ timestep_keyframes = _create_tks_from_dict_list(cref_tks_list)
+ # create ReferenceAdvanced
+ cref = ReferenceAdvanced(ref_opts=opts, timestep_keyframes=timestep_keyframes)
+ cref.strength = contextref_obj.strength # ContextRef obj from ADE
+ cref.set_cond_hint_mask(contextref_obj.mask)
+ cref.order = 99
+ cref.is_context_ref = True
+ context_ref_list = [cref]
+ transformer_options[CONTEXTREF_CONTROL_LIST_ALL] = context_ref_list
+ transformer_options[CONTEXTREF_OPTIONS_CLASS] = ReferenceOptions
+ _add_context_ref_to_conds(conds, cref)
+ return context_ref_list
+
+
+def _create_tks_from_dict_list(dlist: list[dict[str]]) -> TimestepKeyframeGroup:
+ tks = TimestepKeyframeGroup()
+ if dlist is None or len(dlist) == 0:
+ return tks
+ for d in dlist:
+ # scheduling
+ start_percent = d["start_percent"]
+ guarantee_steps = d["guarantee_steps"]
+ inherit_missing = d["inherit_missing"]
+ # values
+ strength = d["strength"]
+ mask = d["mask"]
+ tune = d["tune"]
+ mode = d["mode"]
+ weights = None
+ extras = {}
+ if tune is not None:
+ cref_opt_dict = tune.create_dict() # ContextRefTune obj from ADE
+ opts = ReferenceOptions.create_from_kwargs(**cref_opt_dict)
+ extras[RefConst.OPTS] = opts
+ if mode is not None:
+ extras[RefConst.CREF_MODE] = mode
+ weights = ControlWeights.default(extras=extras)
+ # create keyframe
+ tk = TimestepKeyframe(start_percent=start_percent, guarantee_steps=guarantee_steps, inherit_missing=inherit_missing,
+ strength=strength, mask_hint_orig=mask, control_weights=weights)
+ tks.add(tk)
+ return tks
+
+
+def _add_context_ref_to_conds(conds: dict[list[dict[str]]], context_ref: ReferenceAdvanced):
+ def _add_context_ref_to_existing_control(control: ControlBase, context_ref: ReferenceAdvanced):
+ curr_cn = control
+ while curr_cn is not None:
+ if type(curr_cn) == ReferenceAdvanced and curr_cn.is_context_ref:
+ break
+ if curr_cn.previous_controlnet is not None:
+ curr_cn = curr_cn.previous_controlnet
+ continue
+ orig_previous_controlnet = curr_cn.previous_controlnet
+ # NOTE: code is already in place to restore any ORIG_PREVIOUS_CONTROLNET props
+ setattr(curr_cn, ORIG_PREVIOUS_CONTROLNET, orig_previous_controlnet)
+ curr_cn.previous_controlnet = context_ref
+ curr_cn = orig_previous_controlnet
+
+ def _add_context_ref(actual_cond: dict[str], context_ref: ReferenceAdvanced):
+ # if controls already present on cond, add it to the last previous_controlnet
+ if "control" in actual_cond:
+ return _add_context_ref_to_existing_control(actual_cond["control"], context_ref)
+ # otherwise, need to add it to begin with, and should mark that it should be cleaned after
+ actual_cond["control"] = context_ref
+ actual_cond[CONTROL_INIT_BY_ACN] = True
+
+ # either add context_ref to end of existing cnet chain, or init 'control' key on actual cond
+ for cond_type in conds:
+ cond = conds[cond_type]
+ if cond is not None:
+ for actual_cond in cond:
+ _add_context_ref(actual_cond, context_ref)
+
+
+def ref_noise_latents(latents: Tensor, sigma: Tensor, noise: Tensor=None):
+ sigma = sigma.unsqueeze(-1).unsqueeze(-1).unsqueeze(-1)
+ alpha_cumprod = 1 / ((sigma * sigma) + 1)
+ sqrt_alpha_prod = alpha_cumprod ** 0.5
+ sqrt_one_minus_alpha_prod = (1. - alpha_cumprod) ** 0.5
+ if noise is None:
+ # generator = torch.Generator(device="cuda")
+ # generator.manual_seed(0)
+ # noise = torch.empty_like(latents).normal_(generator=generator)
+ # generator = torch.Generator()
+ # generator.manual_seed(0)
+ # noise = torch.randn(latents.size(), generator=generator).to(latents.device)
+ noise = torch.randn_like(latents).to(latents.device)
+ return sqrt_alpha_prod * latents + sqrt_one_minus_alpha_prod * noise
+
+
+def simple_noise_latents(latents: Tensor, sigma: float, noise: Tensor=None):
+ if noise is None:
+ noise = torch.rand_like(latents)
+ return latents + noise * sigma
+
+
+class BankStylesBasicTransformerBlock:
+ def __init__(self):
+ # ref
+ self.bank = []
+ self.style_cfgs = []
+ self.cn_idx: list[int] = []
+ # contextref - list of lists as each cond/uncond stored separately
+ self.c_bank: dict[UUID, list[Tensor]] = {}
+ self.c_style_cfgs: dict[UUID, list[float]] = {}
+ self.c_cn_idx: dict[UUID, list[int]] = {}
+
+ def set_c_bank_for_uuids(self, x: Tensor, uuids: list[UUID]):
+ per_uuid = len(x) // len(uuids)
+ for uuid, i in zip(uuids, list(range(0, len(x), per_uuid))):
+ self.c_bank.setdefault(uuid, []).append(x[i:i+per_uuid])
+
+ def _get_c_bank_for_uuids(self, uuids: list[UUID]):
+ per_i: list[list[Tensor]] = []
+ for uuid in uuids:
+ for i, bank in enumerate(self.c_bank[uuid]):
+ if i >= len(per_i):
+ per_i.append([])
+ per_i[i].append(bank)
+ real_banks = []
+ for bank in per_i:
+ if len(bank) == 1:
+ combined = bank[0]
+ else:
+ combined = torch.cat(bank, dim=0)
+ real_banks.append(combined)
+ return real_banks
+
+ def get_bank(self, uuids: list[UUID], ignore_contextref, cdevice=None):
+ if ignore_contextref:
+ return self.bank
+ real_c_bank_list = self._get_c_bank_for_uuids(uuids)
+ if cdevice != None:
+ real_c_bank_list = real_c_bank_list.copy()
+ for i in range(len(real_c_bank_list)):
+ real_c_bank_list[i] = real_c_bank_list[i].to(cdevice)
+ return self.bank + real_c_bank_list
+
+
+ def set_c_style_cfgs_for_uuids(self, style_cfg: float, uuids: list[UUID]):
+ for uuid in uuids:
+ self.c_style_cfgs.setdefault(uuid, []).append(style_cfg)
+
+ def get_avg_style_fidelity(self, uuids: list[UUID], ignore_contextref):
+ if ignore_contextref:
+ return sum(self.style_cfgs) / float(len(self.style_cfgs))
+ combined = self.style_cfgs + self._get_c_style_cfgs_for_uuids(uuids)
+ return sum(combined) / float(len(combined))
+
+ def _get_c_style_cfgs_for_uuids(self, uuids: list[UUID]):
+ # c_style_cfgs will be the same for all provided uuids
+ return self.c_style_cfgs[uuids[0]]
+
+
+ def set_c_cn_idx_for_uuids(self, cn_idx: int, uuids: list[UUID]):
+ for uuid in uuids:
+ self.c_cn_idx.setdefault(uuid, []).append(cn_idx)
+
+ def get_cn_idxs(self, uuids: list[UUID], ignore_contxtref):
+ if ignore_contxtref:
+ return self.cn_idx
+ return self.cn_idx + self._get_c_cn_idxs_for_uuids(uuids)
+
+ def _get_c_cn_idxs_for_uuids(self, uuids: list[UUID]):
+ # c_cn_idxs will be the same for all provided uuids
+ return self.c_cn_idx.get(uuids[0], [])
+
+
+ def init_cref_for_uuids(self, uuids: list[UUID]):
+ for uuid in uuids:
+ self.c_bank.setdefault(uuid, [])
+ self.c_style_cfgs.setdefault(uuid, [])
+ self.c_cn_idx.setdefault(uuid, [])
+
+ def clear_cref_for_uuids(self, uuids: list[UUID]):
+ for uuid in uuids:
+ self.c_bank[uuid] = []
+ self.c_style_cfgs[uuid] = []
+ self.c_cn_idx[uuid] = []
+
+ def clean_ref(self):
+ del self.bank
+ del self.style_cfgs
+ del self.cn_idx
+ self.bank = []
+ self.style_cfgs = []
+ self.cn_idx = []
+
+ def clean_contextref(self):
+ del self.c_bank
+ del self.c_style_cfgs
+ del self.c_cn_idx
+ self.c_bank = {}
+ self.c_style_cfgs = {}
+ self.c_cn_idx = {}
+
+ def clean_all(self):
+ self.clean_ref()
+ self.clean_contextref()
+
+
+class BankStylesTimestepEmbedSequential:
+ def __init__(self):
+ # ref
+ self.var_bank = []
+ self.mean_bank = []
+ self.style_cfgs = []
+ self.cn_idx: list[int] = []
+ # cref
+ self.c_var_bank: dict[UUID, list[Tensor]] = {}
+ self.c_mean_bank: dict[UUID, list[Tensor]] = {}
+ self.c_style_cfgs: dict[UUID, list[float]] = {}
+ self.c_cn_idx: dict[UUID, list[int]] = {}
+
+ def set_c_var_bank_for_uuids(self, var: Tensor, uuids: list[UUID]):
+ for uuid in uuids:
+ self.c_var_bank.setdefault(uuid, []).append(var)
+
+ def get_var_bank(self, uuids: list[UUID], ignore_contextref):
+ if ignore_contextref:
+ return self.var_bank
+ return self.var_bank + self._get_c_var_bank_for_uuids(uuids)
+
+ def _get_c_var_bank_for_uuids(self, uuids: list[UUID]):
+ return self.c_var_bank.get(uuids[0], [])
+
+
+ def set_c_mean_bank_for_uuids(self, mean: Tensor, uuids: list[UUID]):
+ for uuid in uuids:
+ self.c_mean_bank.setdefault(uuid, []).append(mean)
+
+ def get_mean_bank(self, uuids: list[UUID], ignore_contextref):
+ if ignore_contextref:
+ return self.mean_bank
+ return self.mean_bank + self._get_c_mean_bank_for_uuids(uuids)
+
+ def _get_c_mean_bank_for_uuids(self, uuids: list[UUID]):
+ return self.c_mean_bank.get(uuids[0], [])
+
+
+ def set_c_style_cfgs_for_uuids(self, style_cfg: float, uuids: list[UUID]):
+ for uuid in uuids:
+ self.c_style_cfgs.setdefault(uuid, []).append(style_cfg)
+
+ def get_style_cfgs(self, uuids: list[UUID], ignore_contextref):
+ if ignore_contextref:
+ return self.style_cfgs
+ return self.style_cfgs + self._get_c_style_cfgs_for_uuids(uuids)
+
+ def _get_c_style_cfgs_for_uuids(self, uuids: list[UUID]):
+ return self.c_style_cfgs.get(uuids[0], [])
+
+
+ def set_c_cn_idx_for_uuids(self, cn_idx: int, uuids: list[UUID]):
+ for uuid in uuids:
+ self.c_cn_idx.setdefault(uuid, []).append(cn_idx)
+
+ def get_cn_idxs(self, uuids: list[UUID], ignore_contextref):
+ if ignore_contextref:
+ return self.cn_idx
+ return self.cn_idx + self._get_c_cn_idxs_for_uuids(uuids)
+
+ def _get_c_cn_idxs_for_uuids(self, uuids: list[UUID]):
+ return self.c_cn_idx.get(uuids[0], [])
+
+
+ def init_cref_for_uuids(self, uuids: list[UUID]):
+ for uuid in uuids:
+ self.c_var_bank.setdefault(uuid, [])
+ self.c_mean_bank.setdefault(uuid, [])
+ self.c_style_cfgs.setdefault(uuid, [])
+ self.c_cn_idx.setdefault(uuid, [])
+
+ def clear_cref_for_uuids(self, uuids: list[UUID]):
+ for uuid in uuids:
+ self.c_var_bank[uuid] = []
+ self.c_mean_bank[uuid] = []
+ self.c_style_cfgs[uuid] = []
+ self.c_cn_idx[uuid] = []
+
+ def clean_ref(self):
+ del self.mean_bank
+ del self.var_bank
+ del self.style_cfgs
+ del self.cn_idx
+ self.mean_bank = []
+ self.var_bank = []
+ self.style_cfgs = []
+ self.cn_idx = []
+
+ def clean_contextref(self):
+ del self.c_var_bank
+ del self.c_mean_bank
+ del self.c_style_cfgs
+ del self.c_cn_idx
+ self.c_var_bank = {}
+ self.c_mean_bank = {}
+ self.c_style_cfgs = {}
+ self.c_cn_idx = {}
+
+ def clean_all(self):
+ self.clean_ref()
+ self.clean_contextref()
+
+
+class InjectionBasicTransformerBlockHolder:
+ def __init__(self, block: BasicTransformerBlock, idx=None):
+ if hasattr(block, "_forward"): # backward compatibility
+ self.original_forward = block._forward
+ else:
+ self.original_forward = block.forward
+ self.idx = idx
+ self.attn_weight = 1.0
+ self.is_middle = False
+ self.bank_styles = BankStylesBasicTransformerBlock()
+
+ def restore(self, block: BasicTransformerBlock):
+ if hasattr(block, "_forward"): # backward compatibility
+ block._forward = self.original_forward
+ else:
+ block.forward = self.original_forward
+
+ def clean_ref(self):
+ self.bank_styles.clean_ref()
+
+ def clean_contextref(self):
+ self.bank_styles.clean_contextref()
+
+ def clean_all(self):
+ self.bank_styles.clean_all()
+
+
+class InjectionTimestepEmbedSequentialHolder:
+ def __init__(self, block: openaimodel.TimestepEmbedSequential, idx=None, is_middle=False, is_input=False, is_output=False):
+ self.original_forward = block.forward
+ self.idx = idx
+ self.gn_weight = 1.0
+ self.is_middle = is_middle
+ self.is_input = is_input
+ self.is_output = is_output
+ self.bank_styles = BankStylesTimestepEmbedSequential()
+
+ def restore(self, block: openaimodel.TimestepEmbedSequential):
+ block.forward = self.original_forward
+
+ def clean_ref(self):
+ self.bank_styles.clean_ref()
+
+ def clean_contextref(self):
+ self.bank_styles.clean_contextref()
+
+ def clean_all(self):
+ self.bank_styles.clean_all()
+
+
+class ReferenceInjections:
+ def __init__(self, attn_modules: list['RefBasicTransformerBlock']=None, gn_modules: list['RefTimestepEmbedSequential']=None):
+ self.attn_modules = attn_modules if attn_modules else []
+ self.gn_modules = gn_modules if gn_modules else []
+
+ def clean_ref_module_mem(self):
+ for attn_module in self.attn_modules:
+ try:
+ attn_module.injection_holder.clean_ref()
+ except Exception:
+ pass
+ for gn_module in self.gn_modules:
+ try:
+ gn_module.injection_holder.clean_ref()
+ except Exception:
+ pass
+
+ def clean_contextref_module_mem(self):
+ for attn_module in self.attn_modules:
+ try:
+ attn_module.injection_holder.clean_contextref()
+ except Exception:
+ pass
+ for gn_module in self.gn_modules:
+ try:
+ gn_module.injection_holder.clean_contextref()
+ except Exception:
+ pass
+
+ def clean_all_module_mem(self):
+ for attn_module in self.attn_modules:
+ try:
+ attn_module.injection_holder.clean_all()
+ except Exception:
+ pass
+ for gn_module in self.gn_modules:
+ try:
+ gn_module.injection_holder.clean_all()
+ except Exception:
+ pass
+
+ def cleanup(self):
+ self.clean_all_module_mem()
+ del self.attn_modules
+ self.attn_modules = []
+ del self.gn_modules
+ self.gn_modules = []
+
+
+def handle_reference_injection(model_options: dict, reference_injections: ReferenceInjections):
+ # register wrapper functions on transformer_options
+ comfy.patcher_extension.add_wrapper_with_key(comfy.patcher_extension.WrappersMP.DIFFUSION_MODEL,
+ "ACN_refcn_diffusion_model",
+ refcn_diffusion_model_wrapper_factory(reference_injections),
+ model_options, is_model_options=True)
+
+
+def refcn_diffusion_model_wrapper_factory(reference_injections: ReferenceInjections):
+ def refcn_diffusion_model_wrapper(executor, x, *args, **kwargs):
+ # get control and transformer_options from args
+ real_args = list(args)
+ real_kwargs = list(kwargs.keys())
+ # args values (x is treated separately, so all args are actually shifted by -1):
+ # -1: x
+ # 0: timesteps
+ # 1: context
+ # 2: y
+ # 3: control
+ # 4: transformer_options
+ control = args[3]
+ transformer_options = args[4]
+ # NOTE: adds support for both ReferenceCN and ContextRef, so need to track them separately
+ # get ReferenceAdvanced objects
+ ref_controlnets: list[ReferenceAdvanced] = transformer_options.get(REF_CONTROL_LIST_ALL, [])
+ context_controlnets: list[ReferenceAdvanced] = transformer_options.get(CONTEXTREF_CONTROL_LIST_ALL, [])
+ # clean contextref stuff if OFF
+ if len(context_controlnets) > 0 and transformer_options[CONTEXTREF_MACHINE_STATE] == MachineState.OFF:
+ reference_injections.clean_contextref_module_mem()
+ context_controlnets = []
+ # discard any controlnets that should not run
+ refcn_present_in_conds = transformer_options.get(RefConst.REFCN_PRESENT_IN_CONDS, False)
+ if refcn_present_in_conds:
+ ref_controlnets = [z for z in ref_controlnets if z.should_run()]
+ else:
+ ref_controlnets = []
+ context_controlnets = [z for z in context_controlnets if z.should_run()]
+ # if nothing related to reference controlnets, do nothing special
+ if len(ref_controlnets) == 0 and len(context_controlnets) == 0:
+ return executor(x, *args, **kwargs)
+ try:
+ # assign cond and uncond idxs
+ batched_number = len(transformer_options["cond_or_uncond"])
+ per_batch = x.shape[0] // batched_number
+ indiv_conds = []
+ for cond_type in transformer_options["cond_or_uncond"]:
+ indiv_conds.extend([cond_type] * per_batch)
+ transformer_options[REF_UNCOND_IDXS] = [i for i, z in enumerate(indiv_conds) if z == 1]
+ transformer_options[REF_COND_IDXS] = [i for i, z in enumerate(indiv_conds) if z == 0]
+ # check which controlnets do which thing
+ attn_controlnets = []
+ adain_controlnets = []
+ for control in ref_controlnets:
+ if ReferenceType.is_attn(control.ref_opts.reference_type):
+ attn_controlnets.append(control)
+ if ReferenceType.is_adain(control.ref_opts.reference_type):
+ adain_controlnets.append(control)
+ context_attn_controlnets = []
+ context_adain_controlnets = []
+ # for ease of access, store current contextref_cond_idx value
+ if len(context_controlnets) == 0:
+ transformer_options[CONTEXTREF_TEMP_COND_IDX] = -1
+ else:
+ transformer_options[CONTEXTREF_TEMP_COND_IDX] = context_controlnets[0].contextref_cond_idx
+ # logger.info(f"{transformer_options[CONTEXTREF_MACHINE_STATE]}: {transformer_options[CONTEXTREF_TEMP_COND_IDX]}")
+
+ for control in context_controlnets:
+ if ReferenceType.is_attn(control.ref_opts.reference_type):
+ context_attn_controlnets.append(control)
+ if ReferenceType.is_adain(control.ref_opts.reference_type):
+ context_adain_controlnets.append(control)
+ if len(adain_controlnets) > 0 or len(context_adain_controlnets) > 0:
+ # ComfyUI uses forward_timestep_embed with the TimestepEmbedSequential passed into it
+ orig_forward_timestep_embed = openaimodel.forward_timestep_embed
+ openaimodel.forward_timestep_embed = forward_timestep_embed_ref_inject_factory(orig_forward_timestep_embed)
+
+ # if RefCN to be used, handle running diffusion with ref cond hints
+ if len(ref_controlnets) > 0:
+ for control in ref_controlnets:
+ read_attn_list = []
+ write_attn_list = []
+ read_adain_list = []
+ write_adain_list = []
+
+ if ReferenceType.is_attn(control.ref_opts.reference_type):
+ write_attn_list.append(control)
+ if ReferenceType.is_adain(control.ref_opts.reference_type):
+ write_adain_list.append(control)
+ # apply lists
+ transformer_options[REF_READ_ATTN_CONTROL_LIST] = read_attn_list
+ transformer_options[REF_WRITE_ATTN_CONTROL_LIST] = write_attn_list
+ transformer_options[REF_READ_ADAIN_CONTROL_LIST] = read_adain_list
+ transformer_options[REF_WRITE_ADAIN_CONTROL_LIST] = write_adain_list
+
+ orig_args = args
+ # disable other controlnets for this run, if specified
+ if not control.ref_opts.ref_with_other_cns:
+ args = list(args)
+ args[3] = None
+ args = tuple(args)
+ executor(control.cond_hint.to(dtype=x.dtype).to(device=x.device), *args, **kwargs)
+ args = orig_args
+ # prepare running diffusion for real now
+ read_attn_list = []
+ write_attn_list = []
+ read_adain_list = []
+ write_adain_list = []
+
+ # add RefCNs to read lists
+ read_attn_list.extend(attn_controlnets)
+ read_adain_list.extend(adain_controlnets)
+
+ # do contextref stuff, if needed
+ if len(context_controlnets) > 0:
+ # clean contextref stuff if first WRITE
+ # if context_controlnets[0].contextref_cond_idx == 0 and is_write(transformer_options[CONTEXTREF_MACHINE_STATE]):
+ # reference_injections.clean_contextref_module_mem()
+ ### add ContextRef to appropriate lists
+ # attn
+ if is_read(transformer_options[CONTEXTREF_MACHINE_STATE]):
+ read_attn_list.extend(context_attn_controlnets)
+ if is_write(transformer_options[CONTEXTREF_MACHINE_STATE]):
+ write_attn_list.extend(context_attn_controlnets)
+ # adain
+ if is_read(transformer_options[CONTEXTREF_MACHINE_STATE]):
+ read_adain_list.extend(context_adain_controlnets)
+ if is_write(transformer_options[CONTEXTREF_MACHINE_STATE]):
+ write_adain_list.extend(context_adain_controlnets)
+ # apply lists, containing both RefCN and ContextRef
+ transformer_options[REF_READ_ATTN_CONTROL_LIST] = read_attn_list
+ transformer_options[REF_WRITE_ATTN_CONTROL_LIST] = write_attn_list
+ transformer_options[REF_READ_ADAIN_CONTROL_LIST] = read_adain_list
+ transformer_options[REF_WRITE_ADAIN_CONTROL_LIST] = write_adain_list
+ # run diffusion for real
+ try:
+ return executor(x, *args, **kwargs)
+ finally:
+ # increment current cond idx
+ if len(context_controlnets) > 0:
+ for cn in context_controlnets:
+ cn.contextref_cond_idx += 1
+ finally:
+ # make sure ref banks are cleared no matter what happens - otherwise, RIP VRAM
+ reference_injections.clean_ref_module_mem()
+ if len(adain_controlnets) > 0 or len(context_adain_controlnets) > 0:
+ openaimodel.forward_timestep_embed = orig_forward_timestep_embed
+ return refcn_diffusion_model_wrapper
+
+
+# dummy class just to help IDE keep track of injected variables
+class RefBasicTransformerBlock(BasicTransformerBlock):
+ injection_holder: InjectionBasicTransformerBlockHolder = None
+
+def _forward_inject_BasicTransformerBlock(self: RefBasicTransformerBlock, x: Tensor, context: Tensor=None, transformer_options: dict[str]={}):
+ extra_options = {}
+ block = transformer_options.get("block", None)
+ block_index = transformer_options.get("block_index", 0)
+ transformer_patches = {}
+ transformer_patches_replace = {}
+
+ for k in transformer_options:
+ if k == "patches":
+ transformer_patches = transformer_options[k]
+ elif k == "patches_replace":
+ transformer_patches_replace = transformer_options[k]
+ else:
+ extra_options[k] = transformer_options[k]
+
+ extra_options["n_heads"] = self.n_heads
+ extra_options["dim_head"] = self.d_head
+
+ if self.ff_in:
+ x_skip = x
+ x = self.ff_in(self.norm_in(x))
+ if self.is_res:
+ x += x_skip
+
+ n: Tensor = self.norm1(x)
+ if self.disable_self_attn:
+ context_attn1 = context
+ else:
+ context_attn1 = None
+ value_attn1 = None
+
+ # Reference CN stuff
+ uc_idx_mask = transformer_options.get(REF_UNCOND_IDXS, [])
+ uuids = transformer_options["uuids"]
+ cref_mode = transformer_options.get(CONTEXTREF_MACHINE_STATE, MachineState.OFF)
+ #c_idx_mask = transformer_options.get(REF_COND_IDXS, [])
+ # WRITE mode may have only 1 ReferenceAdvanced for RefCN at a time, other modes will have all ReferenceAdvanced
+ ref_write_cns: list[ReferenceAdvanced] = transformer_options.get(REF_WRITE_ATTN_CONTROL_LIST, [])
+ ref_read_cns: list[ReferenceAdvanced] = transformer_options.get(REF_READ_ATTN_CONTROL_LIST, [])
+ ignore_contextref_read = cref_mode in [MachineState.OFF, MachineState.WRITE]
+ #logger.info(f"cref: {cref_cond_idx}, cmode: {cref_mode}, ignored: {ignore_contextref_read}")
+
+ cached_n = None
+ cref_write_cns: list[ReferenceAdvanced] = []
+ # check if any WRITE cns are applicable; Reference CN WRITEs immediately, ContextREF WRITEs after READ completed
+ # if any refs to WRITE, save n and style_fidelity
+ for refcn in ref_write_cns:
+ if refcn.ref_opts.attn_ref_weight > self.injection_holder.attn_weight:
+ if cached_n is None:
+ cached_n = n.detach().clone()
+ # for ContextRef, make sure relevant lists are long enough to cond_idx
+ # store RefCN and ContextRef stuff separately
+ if refcn.is_context_ref:
+ cref_write_cns.append(refcn)
+ self.injection_holder.bank_styles.init_cref_for_uuids(uuids)
+ else: # Reference CN WRITE
+ self.injection_holder.bank_styles.bank.append(cached_n)
+ self.injection_holder.bank_styles.style_cfgs.append(refcn.ref_opts.attn_style_fidelity)
+ self.injection_holder.bank_styles.cn_idx.append(refcn.order)
+ if len(cref_write_cns) == 0:
+ del cached_n
+
+ if "attn1_patch" in transformer_patches:
+ patch = transformer_patches["attn1_patch"]
+ if context_attn1 is None:
+ context_attn1 = n
+ value_attn1 = context_attn1
+ for p in patch:
+ n, context_attn1, value_attn1 = p(n, context_attn1, value_attn1, extra_options)
+
+ if block is not None:
+ transformer_block = (block[0], block[1], block_index)
+ else:
+ transformer_block = None
+ attn1_replace_patch = transformer_patches_replace.get("attn1", {})
+ block_attn1 = transformer_block
+ if block_attn1 not in attn1_replace_patch:
+ block_attn1 = block
+
+ if block_attn1 in attn1_replace_patch:
+ if context_attn1 is None:
+ context_attn1 = n
+ value_attn1 = n
+ n = self.attn1.to_q(n)
+ # Reference CN READ - use attn1_replace_patch appropriately
+ if len(ref_read_cns) > 0 and len(self.injection_holder.bank_styles.get_cn_idxs(uuids, ignore_contextref_read)) > 0:
+ bank_styles = self.injection_holder.bank_styles
+ style_fidelity = bank_styles.get_avg_style_fidelity(uuids, ignore_contextref_read)
+ real_bank = bank_styles.get_bank(uuids, ignore_contextref_read, cdevice=n.device).copy()
+ real_cn_idxs = bank_styles.get_cn_idxs(uuids, ignore_contextref_read)
+ cn_idx = 0
+ for idx, order in enumerate(real_cn_idxs):
+ # make sure matching ref cn is selected
+ for i in range(cn_idx, len(ref_read_cns)):
+ if ref_read_cns[i].order == order:
+ cn_idx = i
+ break
+ assert order == ref_read_cns[cn_idx].order
+ if ref_read_cns[cn_idx].any_attn_strength_to_apply():
+ effective_strength = ref_read_cns[cn_idx].get_effective_attn_mask_or_float(x=n, channels=n.shape[2], is_mid=self.injection_holder.is_middle)
+ real_bank[idx] = real_bank[idx] * effective_strength + context_attn1 * (1-effective_strength)
+ n_uc = self.attn1.to_out(attn1_replace_patch[block_attn1](
+ n,
+ self.attn1.to_k(torch.cat([context_attn1] + real_bank, dim=1)),
+ self.attn1.to_v(torch.cat([value_attn1] + real_bank, dim=1)),
+ extra_options))
+ n_c = n_uc.clone()
+ if len(uc_idx_mask) > 0 and not math.isclose(style_fidelity, 0.0):
+ n_c[uc_idx_mask] = self.attn1.to_out(attn1_replace_patch[block_attn1](
+ n[uc_idx_mask],
+ self.attn1.to_k(context_attn1[uc_idx_mask]),
+ self.attn1.to_v(value_attn1[uc_idx_mask]),
+ extra_options))
+ n = style_fidelity * n_c + (1.0-style_fidelity) * n_uc
+ bank_styles.clean_ref()
+ else:
+ context_attn1 = self.attn1.to_k(context_attn1)
+ value_attn1 = self.attn1.to_v(value_attn1)
+ n = attn1_replace_patch[block_attn1](n, context_attn1, value_attn1, extra_options)
+ n = self.attn1.to_out(n)
+ else:
+ # Reference CN READ - no attn1_replace_patch
+ if len(ref_read_cns) > 0 and len(self.injection_holder.bank_styles.get_cn_idxs(uuids, ignore_contextref_read)) > 0:
+ if context_attn1 is None:
+ context_attn1 = n
+ bank_styles = self.injection_holder.bank_styles
+ style_fidelity = bank_styles.get_avg_style_fidelity(uuids, ignore_contextref_read)
+ real_bank = bank_styles.get_bank(uuids, ignore_contextref_read, cdevice=n.device).copy()
+ real_cn_idxs = bank_styles.get_cn_idxs(uuids, ignore_contextref_read)
+ cn_idx = 0
+ for idx, order in enumerate(real_cn_idxs):
+ # make sure matching ref cn is selected
+ for i in range(cn_idx, len(ref_read_cns)):
+ if ref_read_cns[i].order == order:
+ cn_idx = i
+ break
+ assert order == ref_read_cns[cn_idx].order
+ if ref_read_cns[cn_idx].any_attn_strength_to_apply():
+ effective_strength = ref_read_cns[cn_idx].get_effective_attn_mask_or_float(x=n, channels=n.shape[2], is_mid=self.injection_holder.is_middle)
+ real_bank[idx] = real_bank[idx] * effective_strength + context_attn1 * (1-effective_strength)
+ n_uc: Tensor = self.attn1(
+ n,
+ context=torch.cat([context_attn1] + real_bank, dim=1),
+ value=torch.cat([value_attn1] + real_bank, dim=1) if value_attn1 is not None else value_attn1)
+ n_c = n_uc.clone()
+ if len(uc_idx_mask) > 0 and not math.isclose(style_fidelity, 0.0):
+ n_c[uc_idx_mask] = self.attn1(
+ n[uc_idx_mask],
+ context=context_attn1[uc_idx_mask],
+ value=value_attn1[uc_idx_mask] if value_attn1 is not None else value_attn1)
+ n = style_fidelity * n_c + (1.0-style_fidelity) * n_uc
+ bank_styles.clean_ref()
+ else:
+ n = self.attn1(n, context=context_attn1, value=value_attn1)
+
+ # ContextRef CN WRITE
+ if len(cref_write_cns) > 0:
+ # clear so that ContextRef CNs can properly 'replace' previous value at relevant uuids
+ self.injection_holder.bank_styles.clear_cref_for_uuids(uuids)
+ for refcn in cref_write_cns:
+ # add a whole list to match expected type when combining
+ self.injection_holder.bank_styles.set_c_bank_for_uuids(cached_n.to(comfy.model_management.unet_offload_device()), uuids)
+ self.injection_holder.bank_styles.set_c_style_cfgs_for_uuids(refcn.ref_opts.attn_style_fidelity, uuids)
+ self.injection_holder.bank_styles.set_c_cn_idx_for_uuids(refcn.order, uuids)
+ del cached_n
+
+ if "attn1_output_patch" in transformer_patches:
+ patch = transformer_patches["attn1_output_patch"]
+ for p in patch:
+ n = p(n, extra_options)
+
+ x += n
+ if "middle_patch" in transformer_patches:
+ patch = transformer_patches["middle_patch"]
+ for p in patch:
+ x = p(x, extra_options)
+
+ if self.attn2 is not None:
+ n = self.norm2(x)
+ if self.switch_temporal_ca_to_sa:
+ context_attn2 = n
+ else:
+ context_attn2 = context
+ value_attn2 = None
+ if "attn2_patch" in transformer_patches:
+ patch = transformer_patches["attn2_patch"]
+ value_attn2 = context_attn2
+ for p in patch:
+ n, context_attn2, value_attn2 = p(n, context_attn2, value_attn2, extra_options)
+
+ attn2_replace_patch = transformer_patches_replace.get("attn2", {})
+ block_attn2 = transformer_block
+ if block_attn2 not in attn2_replace_patch:
+ block_attn2 = block
+
+ if block_attn2 in attn2_replace_patch:
+ if value_attn2 is None:
+ value_attn2 = context_attn2
+ n = self.attn2.to_q(n)
+ context_attn2 = self.attn2.to_k(context_attn2)
+ value_attn2 = self.attn2.to_v(value_attn2)
+ n = attn2_replace_patch[block_attn2](n, context_attn2, value_attn2, extra_options)
+ n = self.attn2.to_out(n)
+ else:
+ n = self.attn2(n, context=context_attn2, value=value_attn2)
+
+ if "attn2_output_patch" in transformer_patches:
+ patch = transformer_patches["attn2_output_patch"]
+ for p in patch:
+ n = p(n, extra_options)
+
+ x += n
+ if self.is_res:
+ x_skip = x
+ x = self.ff(self.norm3(x))
+ if self.is_res:
+ x += x_skip
+
+ return x
+
+
+class RefTimestepEmbedSequential(openaimodel.TimestepEmbedSequential):
+ injection_holder: InjectionTimestepEmbedSequentialHolder = None
+
+def forward_timestep_embed_ref_inject_factory(orig_timestep_embed_inject_factory: Callable):
+ def forward_timestep_embed_ref_inject(*args, **kwargs):
+ ts: RefTimestepEmbedSequential = args[0]
+ if not hasattr(ts, "injection_holder"):
+ return orig_timestep_embed_inject_factory(*args, **kwargs)
+ eps = 1e-6
+ x: Tensor = orig_timestep_embed_inject_factory(*args, **kwargs)
+ y: Tensor = None
+ transformer_options: dict[str] = args[4]
+ # Reference CN stuff
+ uc_idx_mask = transformer_options.get(REF_UNCOND_IDXS, [])
+ uuids = transformer_options["uuids"]
+ cref_mode = transformer_options.get(CONTEXTREF_MACHINE_STATE, MachineState.OFF)
+ #c_idx_mask = transformer_options.get(REF_COND_IDXS, [])
+ # WRITE mode will only have one ReferenceAdvanced, other modes will have all ReferenceAdvanced
+ ref_write_cns: list[ReferenceAdvanced] = transformer_options.get(REF_WRITE_ADAIN_CONTROL_LIST, [])
+ ref_read_cns: list[ReferenceAdvanced] = transformer_options.get(REF_READ_ADAIN_CONTROL_LIST, [])
+ ignore_contextref_read = cref_mode in [MachineState.OFF, MachineState.WRITE]
+
+ cached_var = None
+ cached_mean = None
+ cref_write_cns: list[ReferenceAdvanced] = []
+ # if any refs to WRITE, save var, mean, and style_cfg
+ for refcn in ref_write_cns:
+ if refcn.ref_opts.adain_ref_weight > ts.injection_holder.gn_weight:
+ if cached_var is None:
+ cached_var, cached_mean = torch.var_mean(x, dim=(2, 3), keepdim=True, correction=0)
+ if refcn.is_context_ref:
+ cref_write_cns.append(refcn)
+ ts.injection_holder.bank_styles.init_cref_for_uuids(uuids)
+ else:
+ ts.injection_holder.bank_styles.var_bank.append(cached_var)
+ ts.injection_holder.bank_styles.mean_bank.append(cached_mean)
+ ts.injection_holder.bank_styles.style_cfgs.append(refcn.ref_opts.adain_style_fidelity)
+ ts.injection_holder.bank_styles.cn_idx.append(refcn.order)
+ if len(cref_write_cns) == 0:
+ del cached_var
+ del cached_mean
+
+ # if any refs to READ, do math with saved var, mean, and style_cfg
+ if len(ref_read_cns) > 0:
+ if len(ts.injection_holder.bank_styles.get_cn_idxs(uuids, ignore_contextref_read)) > 0:
+ bank_styles = ts.injection_holder.bank_styles
+ var, mean = torch.var_mean(x, dim=(2, 3), keepdim=True, correction=0)
+ std = torch.maximum(var, torch.zeros_like(var) + eps) ** 0.5
+ y_uc = torch.zeros_like(x)
+ cn_idx = 0
+ real_style_cfgs = bank_styles.get_style_cfgs(uuids, ignore_contextref_read)
+ real_var_bank = bank_styles.get_var_bank(uuids, ignore_contextref_read)
+ real_mean_bank = bank_styles.get_mean_bank(uuids, ignore_contextref_read)
+ real_cn_idxs = bank_styles.get_cn_idxs(uuids, ignore_contextref_read)
+ for idx, order in enumerate(real_cn_idxs):
+ # make sure matching ref cn is selected
+ for i in range(cn_idx, len(ref_read_cns)):
+ if ref_read_cns[i].order == order:
+ cn_idx = i
+ break
+ assert order == ref_read_cns[cn_idx].order
+ style_fidelity = real_style_cfgs[idx]
+ var_acc = real_var_bank[idx]
+ mean_acc = real_mean_bank[idx]
+ std_acc = torch.maximum(var_acc, torch.zeros_like(var_acc) + eps) ** 0.5
+ sub_y_uc = (((x - mean) / std) * std_acc) + mean_acc
+ if ref_read_cns[cn_idx].any_adain_strength_to_apply():
+ effective_strength = ref_read_cns[cn_idx].get_effective_adain_mask_or_float(x=x)
+ sub_y_uc = sub_y_uc * effective_strength + x * (1-effective_strength)
+ y_uc += sub_y_uc
+ # get average, if more than one
+ if len(real_cn_idxs) > 1:
+ y_uc /= len(real_cn_idxs)
+ y_c = y_uc.clone()
+ if len(uc_idx_mask) > 0 and not math.isclose(style_fidelity, 0.0):
+ y_c[uc_idx_mask] = x.to(y_c.dtype)[uc_idx_mask]
+ y = style_fidelity * y_c + (1.0 - style_fidelity) * y_uc
+ ts.injection_holder.bank_styles.clean_ref()
+
+ # ContextRef CN WRITE
+ if len(cref_write_cns) > 0:
+ # clear so that ContextRef CNs can properly 'replace' previous value at cond_idx
+ ts.injection_holder.bank_styles.clear_cref_for_uuids(uuids)
+ for refcn in cref_write_cns:
+ # add a whole list to match expected type when combining
+ ts.injection_holder.bank_styles.set_c_var_bank_for_uuids(cached_var, uuids)
+ ts.injection_holder.bank_styles.set_c_mean_bank_for_uuids(cached_mean, uuids)
+ ts.injection_holder.bank_styles.set_c_style_cfgs_for_uuids(refcn.ref_opts.adain_style_fidelity, uuids)
+ ts.injection_holder.bank_styles.set_c_cn_idx_for_uuids(refcn.order, uuids)
+ del cached_var
+ del cached_mean
+
+ if y is None:
+ y = x
+ return y.to(x.dtype)
+
+ return forward_timestep_embed_ref_inject
diff --git a/custom_nodes/comfyui-advanced-controlnet/adv_control/control_sparsectrl.py b/custom_nodes/comfyui-advanced-controlnet/adv_control/control_sparsectrl.py
new file mode 100644
index 0000000000000000000000000000000000000000..a102baa4a8df0db1aa4c1483f9dd09d6468cf8b6
--- /dev/null
+++ b/custom_nodes/comfyui-advanced-controlnet/adv_control/control_sparsectrl.py
@@ -0,0 +1,319 @@
+#taken from: https://github.com/lllyasviel/ControlNet
+#and modified
+#and then taken from comfy/cldm/cldm.py and modified again
+
+from abc import ABC, abstractmethod
+import numpy as np
+import torch
+from torch import Tensor
+
+from comfy.ldm.modules.diffusionmodules.util import (
+ zero_module,
+ timestep_embedding,
+)
+
+from comfy.cldm.cldm import ControlNet as ControlNetCLDM
+from comfy.ldm.modules.diffusionmodules.openaimodel import TimestepEmbedSequential
+from comfy.model_patcher import ModelPatcher
+from comfy.patcher_extension import PatcherInjection
+
+from .dinklink import (InterfaceAnimateDiffInfo, InterfaceAnimateDiffModel,
+ get_CreateMotionModelPatcher, get_AnimateDiffModel, get_AnimateDiffInfo)
+from .logger import logger
+from .utils import (BIGMAX, AbstractPreprocWrapper, disable_weight_init_clean_groupnorm, WrapperConsts)
+
+
+class SparseMotionModelPatcher(ModelPatcher):
+ '''Class only used for IDE type hints.'''
+ def __init__(self, *args, **kwargs):
+ self.model = InterfaceAnimateDiffModel
+
+
+class SparseConst:
+ HINT_MULT = "sparse_hint_mult"
+ NONHINT_MULT = "sparse_nonhint_mult"
+ MASK_MULT = "sparse_mask_mult"
+
+
+class SparseControlNet(ControlNetCLDM):
+ def __init__(self, *args,**kwargs):
+ super().__init__(*args, **kwargs)
+ hint_channels = kwargs.get("hint_channels")
+ operations: disable_weight_init_clean_groupnorm = kwargs.get("operations", disable_weight_init_clean_groupnorm)
+ device = kwargs.get("device", None)
+ self.use_simplified_conditioning_embedding = kwargs.get("use_simplified_conditioning_embedding", False)
+ if self.use_simplified_conditioning_embedding:
+ self.input_hint_block = TimestepEmbedSequential(
+ zero_module(operations.conv_nd(self.dims, hint_channels, self.model_channels, 3, padding=1, dtype=self.dtype, device=device)),
+ )
+
+ def forward(self, x: Tensor, hint: Tensor, timesteps, context, y=None, **kwargs):
+ t_emb = timestep_embedding(timesteps, self.model_channels, repeat_only=False).to(x.dtype)
+ emb = self.time_embed(t_emb)
+
+ # SparseCtrl sets noisy input to zeros
+ x = torch.zeros_like(x)
+ guided_hint = self.input_hint_block(hint, emb, context)
+
+ out_output = []
+ out_middle = []
+
+ hs = []
+ if self.num_classes is not None:
+ assert y.shape[0] == x.shape[0]
+ emb = emb + self.label_emb(y)
+
+ h = x
+ for module, zero_conv in zip(self.input_blocks, self.zero_convs):
+ if guided_hint is not None:
+ h = module(h, emb, context)
+ h += guided_hint
+ guided_hint = None
+ else:
+ h = module(h, emb, context)
+ out_output.append(zero_conv(h, emb, context))
+
+ h = self.middle_block(h, emb, context)
+ out_middle.append(self.middle_block_out(h, emb, context))
+
+ return {"middle": out_middle, "output": out_output}
+
+
+def load_sparsectrl_motionmodel(ckpt_path: str, motion_data: dict[str, Tensor], ops=None) -> InterfaceAnimateDiffModel:
+ mm_info: InterfaceAnimateDiffInfo = get_AnimateDiffInfo()("SD1.5", "AnimateDiff", "v3", ckpt_path)
+ init_kwargs = {
+ "ops": ops,
+ "get_unet_func": _get_unet_func,
+ }
+ motion_model: InterfaceAnimateDiffModel = get_AnimateDiffModel()(mm_state_dict=motion_data, mm_info=mm_info, init_kwargs=init_kwargs)
+ missing, unexpected = motion_model.load_state_dict(motion_data)
+ if len(missing) > 0 or len(unexpected) > 0:
+ logger.info(f"SparseCtrl MotionModel: {missing}, {unexpected}")
+ return motion_model
+
+
+def create_sparse_modelpatcher(model, motion_model, load_device, offload_device):
+ patcher = ModelPatcher(model, load_device=load_device, offload_device=offload_device)
+ if motion_model is not None:
+ _motionpatcher = _create_sparse_motionmodelpatcher(motion_model, load_device, offload_device)
+ patcher.set_additional_models(WrapperConsts.ACN, [_motionpatcher])
+ patcher.set_injections(WrapperConsts.ACN,
+ [PatcherInjection(inject=_inject_motion_models, eject=_eject_motion_models)])
+ return patcher
+
+def _create_sparse_motionmodelpatcher(motion_model, load_device, offload_device) -> SparseMotionModelPatcher:
+ return get_CreateMotionModelPatcher()(motion_model, load_device, offload_device)
+
+
+def _inject_motion_models(patcher: ModelPatcher):
+ motion_models: list[SparseMotionModelPatcher] = patcher.get_additional_models_with_key(WrapperConsts.ACN)
+ for mm in motion_models:
+ mm.model.inject(patcher)
+
+def _eject_motion_models(patcher: ModelPatcher):
+ motion_models: list[SparseMotionModelPatcher] = patcher.get_additional_models_with_key(WrapperConsts.ACN)
+ for mm in motion_models:
+ mm.model.eject(patcher)
+
+def _get_unet_func(wrapper, model: ModelPatcher):
+ return model.model
+
+
+class PreprocSparseRGBWrapper(AbstractPreprocWrapper):
+ error_msg = error_msg = "Invalid use of RGB SparseCtrl output. The output of RGB SparseCtrl preprocessor is NOT a usual image, but a latent pretending to be an image - you must connect the output directly to an Apply ControlNet node (advanced or otherwise). It cannot be used for anything else that accepts IMAGE input."
+ def __init__(self, condhint: Tensor):
+ super().__init__(condhint)
+
+
+class SparseContextAware:
+ NEAREST_HINT = "nearest_hint"
+ OFF = "off"
+
+ LIST = [NEAREST_HINT, OFF]
+
+
+class SparseSettings:
+ def __init__(self, sparse_method: 'SparseMethod', use_motion: bool=True, motion_strength=1.0, motion_scale=1.0, merged=False,
+ sparse_mask_mult=1.0, sparse_hint_mult=1.0, sparse_nonhint_mult=1.0, context_aware=SparseContextAware.NEAREST_HINT):
+ # account for Steerable-Motion workflow incompatibility;
+ # doing this to for my own peace of mind (not an issue with my code)
+ if type(sparse_method) == str:
+ logger.warn("Outdated Steerable-Motion workflow detected; attempting to auto-convert indexes input. If you experience an error here, consult Steerable-Motion github, NOT Advanced-ControlNet.")
+ sparse_method = SparseIndexMethod(get_idx_list_from_str(sparse_method))
+ self.sparse_method = sparse_method
+ self.use_motion = use_motion
+ self.motion_strength = motion_strength
+ self.motion_scale = motion_scale
+ self.merged = merged
+ self.sparse_mask_mult = float(sparse_mask_mult)
+ self.sparse_hint_mult = float(sparse_hint_mult)
+ self.sparse_nonhint_mult = float(sparse_nonhint_mult)
+ self.context_aware = context_aware
+
+ def is_context_aware(self):
+ return self.context_aware != SparseContextAware.OFF
+
+ @classmethod
+ def default(cls):
+ return SparseSettings(sparse_method=SparseSpreadMethod(), use_motion=True)
+
+
+class SparseMethod(ABC):
+ SPREAD = "spread"
+ INDEX = "index"
+ def __init__(self, method: str):
+ self.method = method
+
+ @abstractmethod
+ def _get_indexes(self, hint_length: int, full_length: int) -> list[int]:
+ pass
+
+ def get_indexes(self, hint_length: int, full_length: int, sub_idxs: list[int]=None) -> tuple[list[int], list[int]]:
+ returned_idxs = self._get_indexes(hint_length, full_length)
+ if sub_idxs is None:
+ return returned_idxs, None
+ # need to map full indexes to condhint indexes
+ index_mapping = {}
+ for i, value in enumerate(returned_idxs):
+ index_mapping[value] = i
+ def get_mapped_idxs(idxs: list[int]):
+ return [index_mapping[idx] for idx in idxs]
+ # check if returned_idxs fit within subidxs
+ fitting_idxs = []
+ for sub_idx in sub_idxs:
+ if sub_idx in returned_idxs:
+ fitting_idxs.append(sub_idx)
+ # if have any fitting_idxs, deal with it
+ if len(fitting_idxs) > 0:
+ return fitting_idxs, get_mapped_idxs(fitting_idxs)
+
+ # since no returned_idxs fit in sub_idxs, need to get the next-closest hint images based on strategy
+ def get_closest_idx(target_idx: int, idxs: list[int]):
+ min_idx = -1
+ min_dist = BIGMAX
+ for idx in idxs:
+ new_dist = abs(idx-target_idx)
+ if new_dist < min_dist:
+ min_idx = idx
+ min_dist = new_dist
+ if min_dist == 1:
+ return min_idx, min_dist
+ return min_idx, min_dist
+ start_closest_idx, start_dist = get_closest_idx(sub_idxs[0], returned_idxs)
+ end_closest_idx, end_dist = get_closest_idx(sub_idxs[-1], returned_idxs)
+ # if only one cond hint exists, do special behavior
+ if hint_length == 1:
+ # if same distance from start and end,
+ if start_dist == end_dist:
+ # find center index of sub_idxs
+ center_idx = sub_idxs[np.linspace(0, len(sub_idxs)-1, 3, endpoint=True, dtype=int)[1]]
+ return [center_idx], get_mapped_idxs([start_closest_idx])
+ # otherwise, return closest
+ if start_dist < end_dist:
+ return [sub_idxs[0]], get_mapped_idxs([start_closest_idx])
+ return [sub_idxs[-1]], get_mapped_idxs([end_closest_idx])
+ # otherwise, select up to two closest images, or just 1, whichever one applies best
+ # if same distance from start and end, return two images to use
+ if start_dist == end_dist:
+ return [sub_idxs[0], sub_idxs[-1]], get_mapped_idxs([start_closest_idx, end_closest_idx])
+ # else, use just one
+ if start_dist < end_dist:
+ return [sub_idxs[0]], get_mapped_idxs([start_closest_idx])
+ return [sub_idxs[-1]], get_mapped_idxs([end_closest_idx])
+
+
+class SparseSpreadMethod(SparseMethod):
+ UNIFORM = "uniform"
+ STARTING = "starting"
+ ENDING = "ending"
+ CENTER = "center"
+
+ LIST = [UNIFORM, STARTING, ENDING, CENTER]
+
+ def __init__(self, spread=UNIFORM):
+ super().__init__(self.SPREAD)
+ self.spread = spread
+
+ def _get_indexes(self, hint_length: int, full_length: int) -> list[int]:
+ # if hint_length >= full_length, limit hints to full_length
+ if hint_length >= full_length:
+ return list(range(full_length))
+ # handle special case of 1 hint image
+ if hint_length == 1:
+ if self.spread in [self.UNIFORM, self.STARTING]:
+ return [0]
+ elif self.spread == self.ENDING:
+ return [full_length-1]
+ elif self.spread == self.CENTER:
+ # return second (of three) values as the center
+ return [np.linspace(0, full_length-1, 3, endpoint=True, dtype=int)[1]]
+ else:
+ raise ValueError(f"Unrecognized spread: {self.spread}")
+ # otherwise, handle other cases
+ if self.spread == self.UNIFORM:
+ return list(np.linspace(0, full_length-1, hint_length, endpoint=True, dtype=int))
+ elif self.spread == self.STARTING:
+ # make split 1 larger, remove last element
+ return list(np.linspace(0, full_length-1, hint_length+1, endpoint=True, dtype=int))[:-1]
+ elif self.spread == self.ENDING:
+ # make split 1 larger, remove first element
+ return list(np.linspace(0, full_length-1, hint_length+1, endpoint=True, dtype=int))[1:]
+ elif self.spread == self.CENTER:
+ # if hint length is not 3 greater than full length, do STARTING behavior
+ if full_length-hint_length < 3:
+ return list(np.linspace(0, full_length-1, hint_length+1, endpoint=True, dtype=int))[:-1]
+ # otherwise, get linspace of 2 greater than needed, then cut off first and last
+ return list(np.linspace(0, full_length-1, hint_length+2, endpoint=True, dtype=int))[1:-1]
+ return ValueError(f"Unrecognized spread: {self.spread}")
+
+
+class SparseIndexMethod(SparseMethod):
+ def __init__(self, idxs: list[int]):
+ super().__init__(self.INDEX)
+ self.idxs = idxs
+
+ def _get_indexes(self, hint_length: int, full_length: int) -> list[int]:
+ orig_hint_length = hint_length
+ if hint_length > full_length:
+ hint_length = full_length
+ # if idxs is less than hint_length, throw error
+ if len(self.idxs) < hint_length:
+ err_msg = f"There are not enough indexes ({len(self.idxs)}) provided to fit the usable {hint_length} input images."
+ if orig_hint_length != hint_length:
+ err_msg = f"{err_msg} (original input images: {orig_hint_length})"
+ raise ValueError(err_msg)
+ # cap idxs to hint_length
+ idxs = self.idxs[:hint_length]
+ new_idxs = []
+ real_idxs = set()
+ for idx in idxs:
+ if idx < 0:
+ real_idx = full_length+idx
+ if real_idx in real_idxs:
+ raise ValueError(f"Index '{idx}' maps to '{real_idx}' and is duplicate - indexes in Sparse Index Method must be unique.")
+ else:
+ real_idx = idx
+ if real_idx in real_idxs:
+ raise ValueError(f"Index '{idx}' is duplicate (or a negative index is equivalent) - indexes in Sparse Index Method must be unique.")
+ real_idxs.add(real_idx)
+ new_idxs.append(real_idx)
+ return new_idxs
+
+
+def get_idx_list_from_str(indexes: str) -> list[int]:
+ idxs = []
+ unique_idxs = set()
+ # get indeces from string
+ str_idxs = [x.strip() for x in indexes.strip().split(",")]
+ for str_idx in str_idxs:
+ try:
+ idx = int(str_idx)
+ if idx in unique_idxs:
+ raise ValueError(f"'{idx}' is duplicated; indexes must be unique.")
+ idxs.append(idx)
+ unique_idxs.add(idx)
+ except ValueError:
+ raise ValueError(f"'{str_idx}' is not a valid integer index.")
+ if len(idxs) == 0:
+ raise ValueError(f"No indexes were listed in Sparse Index Method.")
+ return idxs
diff --git a/custom_nodes/comfyui-advanced-controlnet/adv_control/control_svd.py b/custom_nodes/comfyui-advanced-controlnet/adv_control/control_svd.py
new file mode 100644
index 0000000000000000000000000000000000000000..7c094fb1a23036eafcdc77cb02eea399e7cb9409
--- /dev/null
+++ b/custom_nodes/comfyui-advanced-controlnet/adv_control/control_svd.py
@@ -0,0 +1,518 @@
+import torch
+import torch.nn as nn
+from torch import Tensor
+
+import comfy.model_detection
+from comfy.utils import UNET_MAP_BASIC, UNET_MAP_RESNET, UNET_MAP_ATTENTIONS, TRANSFORMER_BLOCKS
+
+import torch
+
+
+from comfy.ldm.modules.diffusionmodules.util import (
+ zero_module,
+ timestep_embedding,
+)
+
+from comfy.ldm.modules.attention import SpatialVideoTransformer
+from comfy.ldm.modules.diffusionmodules.openaimodel import UNetModel, TimestepEmbedSequential, VideoResBlock, Downsample
+from comfy.ldm.util import exists
+import comfy.ops
+
+
+class SVDControlNet(nn.Module):
+ def __init__(
+ self,
+ image_size,
+ in_channels,
+ model_channels,
+ hint_channels,
+ num_res_blocks,
+ dropout=0,
+ channel_mult=(1, 2, 4, 8),
+ conv_resample=True,
+ dims=2,
+ num_classes=None,
+ use_checkpoint=False,
+ dtype=torch.float32,
+ num_heads=-1,
+ num_head_channels=-1,
+ num_heads_upsample=-1,
+ use_scale_shift_norm=False,
+ resblock_updown=False,
+ use_new_attention_order=False,
+ use_spatial_transformer=False, # custom transformer support
+ transformer_depth=1, # custom transformer support
+ context_dim=None, # custom transformer support
+ n_embed=None, # custom support for prediction of discrete ids into codebook of first stage vq model
+ legacy=True,
+ disable_self_attentions=None,
+ num_attention_blocks=None,
+ disable_middle_self_attn=False,
+ use_linear_in_transformer=False,
+ adm_in_channels=None,
+ transformer_depth_middle=None,
+ transformer_depth_output=None,
+ use_spatial_context=False,
+ extra_ff_mix_layer=False,
+ merge_strategy="fixed",
+ merge_factor=0.5,
+ video_kernel_size=3,
+ device=None,
+ operations=comfy.ops.disable_weight_init,
+ **kwargs,
+ ):
+ super().__init__()
+ assert use_spatial_transformer == True, "use_spatial_transformer has to be true"
+ if use_spatial_transformer:
+ assert context_dim is not None, 'Fool!! You forgot to include the dimension of your cross-attention conditioning...'
+
+ if context_dim is not None:
+ assert use_spatial_transformer, 'Fool!! You forgot to use the spatial transformer for your cross-attention conditioning...'
+ # from omegaconf.listconfig import ListConfig
+ # if type(context_dim) == ListConfig:
+ # context_dim = list(context_dim)
+
+ if num_heads_upsample == -1:
+ num_heads_upsample = num_heads
+
+ if num_heads == -1:
+ assert num_head_channels != -1, 'Either num_heads or num_head_channels has to be set'
+
+ if num_head_channels == -1:
+ assert num_heads != -1, 'Either num_heads or num_head_channels has to be set'
+
+ self.dims = dims
+ self.image_size = image_size
+ self.in_channels = in_channels
+ self.model_channels = model_channels
+
+ if isinstance(num_res_blocks, int):
+ self.num_res_blocks = len(channel_mult) * [num_res_blocks]
+ else:
+ if len(num_res_blocks) != len(channel_mult):
+ raise ValueError("provide num_res_blocks either as an int (globally constant) or "
+ "as a list/tuple (per-level) with the same length as channel_mult")
+ self.num_res_blocks = num_res_blocks
+
+ if disable_self_attentions is not None:
+ # should be a list of booleans, indicating whether to disable self-attention in TransformerBlocks or not
+ assert len(disable_self_attentions) == len(channel_mult)
+ if num_attention_blocks is not None:
+ assert len(num_attention_blocks) == len(self.num_res_blocks)
+ assert all(map(lambda i: self.num_res_blocks[i] >= num_attention_blocks[i], range(len(num_attention_blocks))))
+
+ transformer_depth = transformer_depth[:]
+
+ self.dropout = dropout
+ self.channel_mult = channel_mult
+ self.conv_resample = conv_resample
+ self.num_classes = num_classes
+ self.use_checkpoint = use_checkpoint
+ self.dtype = dtype
+ self.num_heads = num_heads
+ self.num_head_channels = num_head_channels
+ self.num_heads_upsample = num_heads_upsample
+ self.predict_codebook_ids = n_embed is not None
+
+ time_embed_dim = model_channels * 4
+ self.time_embed = nn.Sequential(
+ operations.Linear(model_channels, time_embed_dim, dtype=self.dtype, device=device),
+ nn.SiLU(),
+ operations.Linear(time_embed_dim, time_embed_dim, dtype=self.dtype, device=device),
+ )
+
+ if self.num_classes is not None:
+ if isinstance(self.num_classes, int):
+ self.label_emb = nn.Embedding(num_classes, time_embed_dim)
+ elif self.num_classes == "continuous":
+ print("setting up linear c_adm embedding layer")
+ self.label_emb = nn.Linear(1, time_embed_dim)
+ elif self.num_classes == "sequential":
+ assert adm_in_channels is not None
+ self.label_emb = nn.Sequential(
+ nn.Sequential(
+ operations.Linear(adm_in_channels, time_embed_dim, dtype=self.dtype, device=device),
+ nn.SiLU(),
+ operations.Linear(time_embed_dim, time_embed_dim, dtype=self.dtype, device=device),
+ )
+ )
+ else:
+ raise ValueError()
+
+ self.input_blocks = nn.ModuleList(
+ [
+ TimestepEmbedSequential(
+ operations.conv_nd(dims, in_channels, model_channels, 3, padding=1, dtype=self.dtype, device=device)
+ )
+ ]
+ )
+ self.zero_convs = nn.ModuleList([self.make_zero_conv(model_channels, operations=operations, dtype=self.dtype, device=device)])
+
+ self.input_hint_block = TimestepEmbedSequential(
+ operations.conv_nd(dims, hint_channels, 16, 3, padding=1, dtype=self.dtype, device=device),
+ nn.SiLU(),
+ operations.conv_nd(dims, 16, 16, 3, padding=1, dtype=self.dtype, device=device),
+ nn.SiLU(),
+ operations.conv_nd(dims, 16, 32, 3, padding=1, stride=2, dtype=self.dtype, device=device),
+ nn.SiLU(),
+ operations.conv_nd(dims, 32, 32, 3, padding=1, dtype=self.dtype, device=device),
+ nn.SiLU(),
+ operations.conv_nd(dims, 32, 96, 3, padding=1, stride=2, dtype=self.dtype, device=device),
+ nn.SiLU(),
+ operations.conv_nd(dims, 96, 96, 3, padding=1, dtype=self.dtype, device=device),
+ nn.SiLU(),
+ operations.conv_nd(dims, 96, 256, 3, padding=1, stride=2, dtype=self.dtype, device=device),
+ nn.SiLU(),
+ operations.conv_nd(dims, 256, model_channels, 3, padding=1, dtype=self.dtype, device=device)
+ )
+
+ self._feature_size = model_channels
+ input_block_chans = [model_channels]
+ ch = model_channels
+ ds = 1
+ for level, mult in enumerate(channel_mult):
+ for nr in range(self.num_res_blocks[level]):
+ layers = [
+ VideoResBlock(
+ ch,
+ time_embed_dim,
+ dropout,
+ out_channels=mult * model_channels,
+ dims=dims,
+ use_checkpoint=use_checkpoint,
+ use_scale_shift_norm=use_scale_shift_norm,
+ dtype=self.dtype,
+ device=device,
+ operations=operations,
+ video_kernel_size=video_kernel_size,
+ merge_strategy=merge_strategy, merge_factor=merge_factor,
+ )
+ ]
+ ch = mult * model_channels
+ num_transformers = transformer_depth.pop(0)
+ if num_transformers > 0:
+ if num_head_channels == -1:
+ dim_head = ch // num_heads
+ else:
+ num_heads = ch // num_head_channels
+ dim_head = num_head_channels
+ if legacy:
+ #num_heads = 1
+ dim_head = ch // num_heads if use_spatial_transformer else num_head_channels
+ if exists(disable_self_attentions):
+ disabled_sa = disable_self_attentions[level]
+ else:
+ disabled_sa = False
+
+ if not exists(num_attention_blocks) or nr < num_attention_blocks[level]:
+ layers.append(
+ SpatialVideoTransformer(
+ ch, num_heads, dim_head, depth=num_transformers, context_dim=context_dim,
+ disable_self_attn=disabled_sa, use_linear=use_linear_in_transformer,
+ checkpoint=use_checkpoint, dtype=self.dtype, device=device, operations=operations,
+ use_spatial_context=use_spatial_context, ff_in=extra_ff_mix_layer,
+ merge_strategy=merge_strategy, merge_factor=merge_factor,
+ )
+ )
+ self.input_blocks.append(TimestepEmbedSequential(*layers))
+ self.zero_convs.append(self.make_zero_conv(ch, operations=operations, dtype=self.dtype, device=device))
+ self._feature_size += ch
+ input_block_chans.append(ch)
+ if level != len(channel_mult) - 1:
+ out_ch = ch
+ self.input_blocks.append(
+ TimestepEmbedSequential(
+ VideoResBlock(
+ ch,
+ time_embed_dim,
+ dropout,
+ out_channels=out_ch,
+ dims=dims,
+ use_checkpoint=use_checkpoint,
+ use_scale_shift_norm=use_scale_shift_norm,
+ down=True,
+ dtype=self.dtype,
+ device=device,
+ operations=operations,
+ video_kernel_size=video_kernel_size,
+ merge_strategy=merge_strategy, merge_factor=merge_factor,
+ )
+ if resblock_updown
+ else Downsample(
+ ch, conv_resample, dims=dims, out_channels=out_ch, dtype=self.dtype, device=device, operations=operations
+ )
+ )
+ )
+ ch = out_ch
+ input_block_chans.append(ch)
+ self.zero_convs.append(self.make_zero_conv(ch, operations=operations, dtype=self.dtype, device=device))
+ ds *= 2
+ self._feature_size += ch
+
+ if num_head_channels == -1:
+ dim_head = ch // num_heads
+ else:
+ num_heads = ch // num_head_channels
+ dim_head = num_head_channels
+ if legacy:
+ #num_heads = 1
+ dim_head = ch // num_heads if use_spatial_transformer else num_head_channels
+ mid_block = [
+ VideoResBlock(
+ ch,
+ time_embed_dim,
+ dropout,
+ dims=dims,
+ use_checkpoint=use_checkpoint,
+ use_scale_shift_norm=use_scale_shift_norm,
+ dtype=self.dtype,
+ device=device,
+ operations=operations,
+ video_kernel_size=video_kernel_size,
+ merge_strategy=merge_strategy, merge_factor=merge_factor,
+ )]
+ if transformer_depth_middle >= 0:
+ mid_block += [SpatialVideoTransformer( # always uses a self-attn
+ ch, num_heads, dim_head, depth=transformer_depth_middle, context_dim=context_dim,
+ disable_self_attn=disable_middle_self_attn, use_linear=use_linear_in_transformer,
+ checkpoint=use_checkpoint, dtype=self.dtype, device=device, operations=operations,
+ use_spatial_context=use_spatial_context, ff_in=extra_ff_mix_layer,
+ merge_strategy=merge_strategy, merge_factor=merge_factor,
+ ),
+ VideoResBlock(
+ ch,
+ time_embed_dim,
+ dropout,
+ dims=dims,
+ use_checkpoint=use_checkpoint,
+ use_scale_shift_norm=use_scale_shift_norm,
+ dtype=self.dtype,
+ device=device,
+ operations=operations,
+ video_kernel_size=video_kernel_size,
+ merge_strategy=merge_strategy, merge_factor=merge_factor,
+ )]
+ self.middle_block = TimestepEmbedSequential(*mid_block)
+ self.middle_block_out = self.make_zero_conv(ch, operations=operations, dtype=self.dtype, device=device)
+ self._feature_size += ch
+
+ def make_zero_conv(self, channels, operations=None, dtype=None, device=None):
+ return TimestepEmbedSequential(operations.conv_nd(self.dims, channels, channels, 1, padding=0, dtype=dtype, device=device))
+
+ def forward(self, x, hint, timesteps, context, y=None, **kwargs):
+ t_emb = timestep_embedding(timesteps, self.model_channels, repeat_only=False).to(x.dtype)
+ emb = self.time_embed(t_emb)
+
+ cond = kwargs["cond"]
+ num_video_frames = cond["num_video_frames"]
+ image_only_indicator = cond.get("image_only_indicator", None)
+ time_context = cond.get("time_context", None)
+ del cond
+
+ guided_hint = self.input_hint_block(hint, emb, context, time_context=time_context, num_video_frames=num_video_frames, image_only_indicator=image_only_indicator)
+
+ out_output = []
+ out_middle = []
+
+ hs = []
+ if self.num_classes is not None:
+ assert y.shape[0] == x.shape[0]
+ emb = emb + self.label_emb(y)
+
+ h = x
+ for module, zero_conv in zip(self.input_blocks, self.zero_convs):
+ if guided_hint is not None:
+ h = module(h, emb, context, time_context=time_context, num_video_frames=num_video_frames, image_only_indicator=image_only_indicator)
+ h += guided_hint
+ guided_hint = None
+ else:
+ h = module(h, emb, context, time_context=time_context, num_video_frames=num_video_frames, image_only_indicator=image_only_indicator)
+ out_output.append(zero_conv(h, emb, context, time_context=time_context, num_video_frames=num_video_frames, image_only_indicator=image_only_indicator))
+
+ h = self.middle_block(h, emb, context, time_context=time_context, num_video_frames=num_video_frames, image_only_indicator=image_only_indicator)
+ out_middle.append(self.middle_block_out(h, emb, context, time_context=time_context, num_video_frames=num_video_frames, image_only_indicator=image_only_indicator))
+
+ return {"middle": out_middle, "output": out_output}
+
+
+TEMPORAL_TRANSFORMER_BLOCKS = {
+ "norm_in.weight",
+ "norm_in.bias",
+ "ff_in.net.0.proj.weight",
+ "ff_in.net.0.proj.bias",
+ "ff_in.net.2.weight",
+ "ff_in.net.2.bias",
+}
+TEMPORAL_TRANSFORMER_BLOCKS.update(TRANSFORMER_BLOCKS)
+
+
+TEMPORAL_UNET_MAP_ATTENTIONS = {
+ "time_mixer.mix_factor",
+}
+TEMPORAL_UNET_MAP_ATTENTIONS.update(UNET_MAP_ATTENTIONS)
+
+
+TEMPORAL_TRANSFORMER_MAP = {
+ "time_pos_embed.0.weight": "time_pos_embed.linear_1.weight",
+ "time_pos_embed.0.bias": "time_pos_embed.linear_1.bias",
+ "time_pos_embed.2.weight": "time_pos_embed.linear_2.weight",
+ "time_pos_embed.2.bias": "time_pos_embed.linear_2.bias",
+}
+
+
+TEMPORAL_RESNET = {
+ "time_mixer.mix_factor",
+}
+
+
+def svd_unet_config_from_diffusers_unet(state_dict: dict[str, Tensor], dtype):
+ match = {}
+ transformer_depth = []
+
+ attn_res = 1
+ down_blocks = comfy.model_detection.count_blocks(state_dict, "down_blocks.{}")
+ for i in range(down_blocks):
+ attn_blocks = comfy.model_detection.count_blocks(state_dict, "down_blocks.{}.attentions.".format(i) + '{}')
+ for ab in range(attn_blocks):
+ transformer_count = comfy.model_detection.count_blocks(state_dict, "down_blocks.{}.attentions.{}.transformer_blocks.".format(i, ab) + '{}')
+ transformer_depth.append(transformer_count)
+ if transformer_count > 0:
+ match["context_dim"] = state_dict["down_blocks.{}.attentions.{}.transformer_blocks.0.attn2.to_k.weight".format(i, ab)].shape[1]
+
+ attn_res *= 2
+ if attn_blocks == 0:
+ transformer_depth.append(0)
+ transformer_depth.append(0)
+
+ match["transformer_depth"] = transformer_depth
+
+ match["model_channels"] = state_dict["conv_in.weight"].shape[0]
+ match["in_channels"] = state_dict["conv_in.weight"].shape[1]
+ match["adm_in_channels"] = None
+ if "class_embedding.linear_1.weight" in state_dict:
+ match["adm_in_channels"] = state_dict["class_embedding.linear_1.weight"].shape[1]
+ elif "add_embedding.linear_1.weight" in state_dict:
+ match["adm_in_channels"] = state_dict["add_embedding.linear_1.weight"].shape[1]
+
+ # based on unet_config of SVD
+ SVD = {
+ 'use_checkpoint': False,
+ 'image_size': 32,
+ 'use_spatial_transformer': True,
+ 'legacy': False,
+ 'num_classes': 'sequential',
+ 'adm_in_channels': 768,
+ 'dtype': dtype,
+ 'in_channels': 8,
+ 'out_channels': 4,
+ 'model_channels': 320,
+ 'num_res_blocks': [2, 2, 2, 2],
+ 'transformer_depth': [1, 1, 1, 1, 1, 1, 0, 0],
+ 'transformer_depth_output': [1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0],
+ 'channel_mult': [1, 2, 4, 4],
+ 'transformer_depth_middle': 1,
+ 'use_linear_in_transformer': True,
+ 'context_dim': 1024,
+ 'extra_ff_mix_layer': True,
+ 'use_spatial_context': True,
+ 'merge_strategy': 'learned_with_images',
+ 'merge_factor': 0.0,
+ 'video_kernel_size': [3, 1, 1],
+ 'use_temporal_attention': True,
+ 'use_temporal_resblock': True,
+ 'num_heads': -1,
+ 'num_head_channels': 64,
+ }
+
+ supported_models = [SVD]
+
+ for unet_config in supported_models:
+ matches = True
+ for k in match:
+ if match[k] != unet_config[k]:
+ matches = False
+ break
+ if matches:
+ return comfy.model_detection.convert_config(unet_config)
+ return None
+
+
+def svd_unet_to_diffusers(unet_config):
+ num_res_blocks = unet_config["num_res_blocks"]
+ channel_mult = unet_config["channel_mult"]
+ transformer_depth = unet_config["transformer_depth"][:]
+ transformer_depth_output = unet_config["transformer_depth_output"][:]
+ num_blocks = len(channel_mult)
+
+ transformers_mid = unet_config.get("transformer_depth_middle", None)
+
+ diffusers_unet_map = {}
+ for x in range(num_blocks):
+ n = 1 + (num_res_blocks[x] + 1) * x
+ for i in range(num_res_blocks[x]):
+ for b in TEMPORAL_RESNET:
+ diffusers_unet_map["down_blocks.{}.resnets.{}.{}".format(x, i, b)] = "input_blocks.{}.0.{}".format(n, b)
+ for b in UNET_MAP_RESNET:
+ diffusers_unet_map["down_blocks.{}.resnets.{}.spatial_res_block.{}".format(x, i, UNET_MAP_RESNET[b])] = "input_blocks.{}.0.{}".format(n, b)
+ diffusers_unet_map["down_blocks.{}.resnets.{}.temporal_res_block.{}".format(x, i, UNET_MAP_RESNET[b])] = "input_blocks.{}.0.time_stack.{}".format(n, b)
+ #diffusers_unet_map["down_blocks.{}.resnets.{}.{}".format(x, i, UNET_MAP_RESNET[b])] = "input_blocks.{}.0.{}".format(n, b)
+ num_transformers = transformer_depth.pop(0)
+ if num_transformers > 0:
+ for b in TEMPORAL_UNET_MAP_ATTENTIONS:
+ diffusers_unet_map["down_blocks.{}.attentions.{}.{}".format(x, i, b)] = "input_blocks.{}.1.{}".format(n, b)
+ for b in TEMPORAL_TRANSFORMER_MAP:
+ diffusers_unet_map["down_blocks.{}.attentions.{}.{}".format(x, i, TEMPORAL_TRANSFORMER_MAP[b])] = "input_blocks.{}.1.{}".format(n, b)
+ for t in range(num_transformers):
+ for b in TRANSFORMER_BLOCKS:
+ diffusers_unet_map["down_blocks.{}.attentions.{}.transformer_blocks.{}.{}".format(x, i, t, b)] = "input_blocks.{}.1.transformer_blocks.{}.{}".format(n, t, b)
+ for b in TEMPORAL_TRANSFORMER_BLOCKS:
+ diffusers_unet_map["down_blocks.{}.attentions.{}.temporal_transformer_blocks.{}.{}".format(x, i, t, b)] = "input_blocks.{}.1.time_stack.{}.{}".format(n, t, b)
+ n += 1
+ for k in ["weight", "bias"]:
+ diffusers_unet_map["down_blocks.{}.downsamplers.0.conv.{}".format(x, k)] = "input_blocks.{}.0.op.{}".format(n, k)
+
+ i = 0
+ for b in TEMPORAL_UNET_MAP_ATTENTIONS:
+ diffusers_unet_map["mid_block.attentions.{}.{}".format(i, b)] = "middle_block.1.{}".format(b)
+ for b in TEMPORAL_TRANSFORMER_MAP:
+ diffusers_unet_map["mid_block.attentions.{}.{}".format(i, TEMPORAL_TRANSFORMER_MAP[b])] = "middle_block.1.{}".format(b)
+ for t in range(transformers_mid):
+ for b in TRANSFORMER_BLOCKS:
+ diffusers_unet_map["mid_block.attentions.{}.transformer_blocks.{}.{}".format(i, t, b)] = "middle_block.1.transformer_blocks.{}.{}".format(t, b)
+ for b in TEMPORAL_TRANSFORMER_BLOCKS:
+ diffusers_unet_map["mid_block.attentions.{}.temporal_transformer_blocks.{}.{}".format(i, t, b)] = "middle_block.1.time_stack.{}.{}".format(t, b)
+
+ for i, n in enumerate([0, 2]):
+ for b in TEMPORAL_RESNET:
+ diffusers_unet_map["mid_block.resnets.{}.{}".format(i, b)] = "middle_block.{}.{}".format(n, b)
+ for b in UNET_MAP_RESNET:
+ diffusers_unet_map["mid_block.resnets.{}.spatial_res_block.{}".format(i, UNET_MAP_RESNET[b])] = "middle_block.{}.{}".format(n, b)
+ diffusers_unet_map["mid_block.resnets.{}.temporal_res_block.{}".format(i, UNET_MAP_RESNET[b])] = "middle_block.{}.time_stack.{}".format(n, b)
+ #diffusers_unet_map["mid_block.resnets.{}.{}".format(i, UNET_MAP_RESNET[b])] = "middle_block.{}.{}".format(n, b)
+
+ num_res_blocks = list(reversed(num_res_blocks))
+ for x in range(num_blocks):
+ n = (num_res_blocks[x] + 1) * x
+ l = num_res_blocks[x] + 1
+ for i in range(l):
+ c = 0
+ for b in UNET_MAP_RESNET:
+ diffusers_unet_map["up_blocks.{}.resnets.{}.{}".format(x, i, UNET_MAP_RESNET[b])] = "output_blocks.{}.0.{}".format(n, b)
+ c += 1
+ num_transformers = transformer_depth_output.pop()
+ if num_transformers > 0:
+ c += 1
+ for b in UNET_MAP_ATTENTIONS:
+ diffusers_unet_map["up_blocks.{}.attentions.{}.{}".format(x, i, b)] = "output_blocks.{}.1.{}".format(n, b)
+ for t in range(num_transformers):
+ for b in TRANSFORMER_BLOCKS:
+ diffusers_unet_map["up_blocks.{}.attentions.{}.transformer_blocks.{}.{}".format(x, i, t, b)] = "output_blocks.{}.1.transformer_blocks.{}.{}".format(n, t, b)
+ if i == l - 1:
+ for k in ["weight", "bias"]:
+ diffusers_unet_map["up_blocks.{}.upsamplers.0.conv.{}".format(x, k)] = "output_blocks.{}.{}.conv.{}".format(n, c, k)
+ n += 1
+
+ for k in UNET_MAP_BASIC:
+ diffusers_unet_map[k[1]] = k[0]
+
+ return diffusers_unet_map
diff --git a/custom_nodes/comfyui-advanced-controlnet/adv_control/dinklink.py b/custom_nodes/comfyui-advanced-controlnet/adv_control/dinklink.py
new file mode 100644
index 0000000000000000000000000000000000000000..c100971ec594fbbd4a59249d3e090462e895b580
--- /dev/null
+++ b/custom_nodes/comfyui-advanced-controlnet/adv_control/dinklink.py
@@ -0,0 +1,112 @@
+####################################################################################################
+# DinkLink is my method of sharing classes/functions between my nodes.
+#
+# My DinkLink-compatible nodes will inject comfy.hooks with a __DINKLINK attr
+# that stores a dictionary, where any of my node packs can store their stuff.
+#
+# It is not intended to be accessed by node packs that I don't develop, so things may change
+# at any time.
+#
+# DinkLink also serves as a proof-of-concept for a future ComfyUI implementation of
+# purposely exposing node pack classes/functions with other node packs.
+####################################################################################################
+from __future__ import annotations
+from typing import Union
+from torch import Tensor, nn
+
+from comfy.model_patcher import ModelPatcher
+import comfy.hooks
+
+DINKLINK = "__DINKLINK"
+
+
+def init_dinklink():
+ create_dinklink()
+ prepare_dinklink()
+
+def create_dinklink():
+ if not hasattr(comfy.hooks, DINKLINK):
+ setattr(comfy.hooks, DINKLINK, {})
+
+def get_dinklink() -> dict[str, dict[str]]:
+ create_dinklink()
+ return getattr(comfy.hooks, DINKLINK)
+
+
+class DinkLinkConst:
+ VERSION = "version"
+ # ADE
+ ADE = "ADE"
+ ADE_ANIMATEDIFFMODEL = "AnimateDiffModel"
+ ADE_ANIMATEDIFFINFO = "AnimateDiffInfo"
+ ADE_CREATE_MOTIONMODELPATCHER = "create_MotionModelPatcher"
+
+def prepare_dinklink():
+ pass
+
+
+class InterfaceAnimateDiffInfo:
+ '''Class only used for IDE type hints; interface of ADE's AnimateDiffInfo'''
+ def __init__(self, sd_type: str, mm_format: str, mm_version: str, mm_name: str):
+ self.sd_type = sd_type
+ self.mm_format = mm_format
+ self.mm_version = mm_version
+ self.mm_name = mm_name
+
+
+class InterfaceAnimateDiffModel(nn.Module):
+ '''Class only used for IDE type hints; interface of ADE's AnimateDiffModel'''
+ def __init__(self, mm_state_dict: dict[str, Tensor], mm_info: InterfaceAnimateDiffInfo, init_kwargs: dict[str]={}):
+ pass
+
+ def set_video_length(self, video_length: int, full_length: int) -> None:
+ raise NotImplemented()
+
+ def set_scale(self, scale: Union[float, Tensor, None], per_block_list: Union[list, None]=None) -> None:
+ raise NotImplemented()
+
+ def set_effect(self, multival: Union[float, Tensor, None], per_block_list: Union[list, None]=None) -> None:
+ raise NotImplemented()
+
+ def cleanup(self):
+ raise NotImplemented()
+
+ def inject(self, model: ModelPatcher):
+ pass
+
+ def eject(self, model: ModelPatcher):
+ pass
+
+
+def get_CreateMotionModelPatcher(throw_exception=True):
+ d = get_dinklink()
+ try:
+ link_ade = d[DinkLinkConst.ADE]
+ return link_ade[DinkLinkConst.ADE_CREATE_MOTIONMODELPATCHER]
+ except KeyError:
+ if throw_exception:
+ raise Exception("Could not get create_MotionModelPatcher function. AnimateDiff-Evolved nodes need to be installed to use SparseCtrl; " + \
+ "they are either not installed or are of an insufficient version.")
+ return None
+
+def get_AnimateDiffModel(throw_exception=True):
+ d = get_dinklink()
+ try:
+ link_ade = d[DinkLinkConst.ADE]
+ return link_ade[DinkLinkConst.ADE_ANIMATEDIFFMODEL]
+ except KeyError:
+ if throw_exception:
+ raise Exception("Could not get AnimateDiffModel class. AnimateDiff-Evolved nodes need to be installed to use SparseCtrl; " + \
+ "they are either not installed or are of an insufficient version.")
+ return None
+
+def get_AnimateDiffInfo(throw_exception=True) -> InterfaceAnimateDiffInfo:
+ d = get_dinklink()
+ try:
+ link_ade = d[DinkLinkConst.ADE]
+ return link_ade[DinkLinkConst.ADE_ANIMATEDIFFINFO]
+ except KeyError:
+ if throw_exception:
+ raise Exception("Could not get AnimateDiffInfo class - AnimateDiff-Evolved nodes need to be installed to use SparseCtrl; " + \
+ "they are either not installed or are of an insufficient version.")
+ return None
diff --git a/custom_nodes/comfyui-advanced-controlnet/adv_control/documentation.py b/custom_nodes/comfyui-advanced-controlnet/adv_control/documentation.py
new file mode 100644
index 0000000000000000000000000000000000000000..f6ccd48d160b83f1671230baa00448c8410927a2
--- /dev/null
+++ b/custom_nodes/comfyui-advanced-controlnet/adv_control/documentation.py
@@ -0,0 +1,47 @@
+from .logger import logger
+
+def image(src):
+ return f''
+def video(src):
+ return f'