Upload 111 files
This view is limited to 50 files because the commit contains too many changes; see the raw diff for the full list.
- .gitattributes +29 -0
- comfyui-mvadapter/.github/workflows/publish.yml +25 -0
- comfyui-mvadapter/BACKUP_nodes.py +843 -0
- comfyui-mvadapter/LICENSE +201 -0
- comfyui-mvadapter/README.md +88 -0
- comfyui-mvadapter/__init__.py +45 -0
- comfyui-mvadapter/__pycache__/__init__.cpython-312.pyc +0 -0
- comfyui-mvadapter/__pycache__/nodes.cpython-312.pyc +0 -0
- comfyui-mvadapter/__pycache__/nodes_local_mv.cpython-312.pyc +0 -0
- comfyui-mvadapter/__pycache__/utils.cpython-312.pyc +0 -0
- comfyui-mvadapter/assets/CustomLoraModelLoader.png +0 -0
- comfyui-mvadapter/assets/comfyui_i2mv.png +3 -0
- comfyui-mvadapter/assets/comfyui_i2mv_lora.png +3 -0
- comfyui-mvadapter/assets/comfyui_i2mv_multiple_loras.jpg +3 -0
- comfyui-mvadapter/assets/comfyui_i2mv_view_selector.png +3 -0
- comfyui-mvadapter/assets/comfyui_ldm_vae.png +0 -0
- comfyui-mvadapter/assets/comfyui_model_makeup.png +0 -0
- comfyui-mvadapter/assets/comfyui_t2mv.png +3 -0
- comfyui-mvadapter/assets/comfyui_t2mv_controlnet.png +3 -0
- comfyui-mvadapter/assets/comfyui_t2mv_lora.png +3 -0
- comfyui-mvadapter/assets/comfyui_t2mv_multiple_loras.jpg +3 -0
- comfyui-mvadapter/assets/demo/scribbles/scribble_0.png +0 -0
- comfyui-mvadapter/assets/demo/scribbles/scribble_1.png +0 -0
- comfyui-mvadapter/assets/demo/scribbles/scribble_2.png +0 -0
- comfyui-mvadapter/assets/demo/scribbles/scribble_3.png +0 -0
- comfyui-mvadapter/assets/demo/scribbles/scribble_4.png +0 -0
- comfyui-mvadapter/assets/demo/scribbles/scribble_5.png +0 -0
- comfyui-mvadapter/cache/stable-diffusion-v1-inference.yaml +70 -0
- comfyui-mvadapter/mvadapter/__init__.py +0 -0
- comfyui-mvadapter/mvadapter/__pycache__/__init__.cpython-312.pyc +0 -0
- comfyui-mvadapter/mvadapter/loaders/__init__.py +1 -0
- comfyui-mvadapter/mvadapter/loaders/__pycache__/__init__.cpython-312.pyc +0 -0
- comfyui-mvadapter/mvadapter/loaders/__pycache__/custom_adapter.cpython-312.pyc +0 -0
- comfyui-mvadapter/mvadapter/loaders/custom_adapter.py +98 -0
- comfyui-mvadapter/mvadapter/models/__init__.py +0 -0
- comfyui-mvadapter/mvadapter/models/__pycache__/__init__.cpython-312.pyc +0 -0
- comfyui-mvadapter/mvadapter/models/__pycache__/attention_processor.cpython-312.pyc +0 -0
- comfyui-mvadapter/mvadapter/models/attention_processor.py +377 -0
- comfyui-mvadapter/mvadapter/pipelines/__pycache__/pipeline_mvadapter_i2mv_sd.cpython-312.pyc +0 -0
- comfyui-mvadapter/mvadapter/pipelines/__pycache__/pipeline_mvadapter_i2mv_sdxl.cpython-312.pyc +0 -0
- comfyui-mvadapter/mvadapter/pipelines/__pycache__/pipeline_mvadapter_t2mv_sd.cpython-312.pyc +0 -0
- comfyui-mvadapter/mvadapter/pipelines/__pycache__/pipeline_mvadapter_t2mv_sdxl.cpython-312.pyc +0 -0
- comfyui-mvadapter/mvadapter/pipelines/pipeline_mvadapter_i2mv_sdxl.py +903 -0
- comfyui-mvadapter/mvadapter/schedulers/ShiftSNRSchedulerKarras.py +120 -0
- comfyui-mvadapter/mvadapter/schedulers/__pycache__/ShiftSNRSchedulerKarras.cpython-312.pyc +0 -0
- comfyui-mvadapter/mvadapter/schedulers/__pycache__/scheduler_utils.cpython-312.pyc +0 -0
- comfyui-mvadapter/mvadapter/schedulers/__pycache__/scheduling_shift_snr.cpython-312.pyc +0 -0
- comfyui-mvadapter/mvadapter/schedulers/scheduler_utils.py +70 -0
- comfyui-mvadapter/mvadapter/schedulers/scheduling_shift_snr.py +140 -0
- comfyui-mvadapter/mvadapter/utils/__init__.py +3 -0
.gitattributes
CHANGED
@@ -33,3 +33,32 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
 *.zip filter=lfs diff=lfs merge=lfs -text
 *.zst filter=lfs diff=lfs merge=lfs -text
 *tfevents* filter=lfs diff=lfs merge=lfs -text
+comfyui-mvadapter/assets/comfyui_i2mv_lora.png filter=lfs diff=lfs merge=lfs -text
+comfyui-mvadapter/assets/comfyui_i2mv_multiple_loras.jpg filter=lfs diff=lfs merge=lfs -text
+comfyui-mvadapter/assets/comfyui_i2mv_view_selector.png filter=lfs diff=lfs merge=lfs -text
+comfyui-mvadapter/assets/comfyui_i2mv.png filter=lfs diff=lfs merge=lfs -text
+comfyui-mvadapter/assets/comfyui_t2mv_controlnet.png filter=lfs diff=lfs merge=lfs -text
+comfyui-mvadapter/assets/comfyui_t2mv_lora.png filter=lfs diff=lfs merge=lfs -text
+comfyui-mvadapter/assets/comfyui_t2mv_multiple_loras.jpg filter=lfs diff=lfs merge=lfs -text
+comfyui-mvadapter/assets/comfyui_t2mv.png filter=lfs diff=lfs merge=lfs -text
+comfyui-salia/assets/images/boy0.png filter=lfs diff=lfs merge=lfs -text
+comfyui-salia/assets/images/boy1.png filter=lfs diff=lfs merge=lfs -text
+comfyui-salia/assets/images/boy2.png filter=lfs diff=lfs merge=lfs -text
+comfyui-salia/assets/images/boy3.png filter=lfs diff=lfs merge=lfs -text
+comfyui-salia/assets/images/boy4.png filter=lfs diff=lfs merge=lfs -text
+comfyui-salia/assets/images/boy5.png filter=lfs diff=lfs merge=lfs -text
+comfyui-salia/assets/images/girl0.png filter=lfs diff=lfs merge=lfs -text
+comfyui-salia/assets/images/girl1.png filter=lfs diff=lfs merge=lfs -text
+comfyui-salia/assets/images/girl2.png filter=lfs diff=lfs merge=lfs -text
+comfyui-salia/assets/images/girl3.png filter=lfs diff=lfs merge=lfs -text
+comfyui-salia/assets/images/girl4.png filter=lfs diff=lfs merge=lfs -text
+comfyui-salia/assets/images/girl5.png filter=lfs diff=lfs merge=lfs -text
+comfyui-salia/assets/images/hair_L_Bound_Braided.png filter=lfs diff=lfs merge=lfs -text
+comfyui-salia/assets/images/hair_L_Bound.png filter=lfs diff=lfs merge=lfs -text
+comfyui-salia/assets/images/hair_L_Loose.png filter=lfs diff=lfs merge=lfs -text
+comfyui-salia/assets/images/hair_M_Bound_Braided.png filter=lfs diff=lfs merge=lfs -text
+comfyui-salia/assets/images/hair_M_Bound.png filter=lfs diff=lfs merge=lfs -text
+comfyui-salia/assets/images/hair_M_Loose.png filter=lfs diff=lfs merge=lfs -text
+comfyui-salia/assets/images/hair_S_Bound_Braided.png filter=lfs diff=lfs merge=lfs -text
+comfyui-salia/assets/images/hair_S_Bound.png filter=lfs diff=lfs merge=lfs -text
+comfyui-salia/assets/images/hair_S_Loose.png filter=lfs diff=lfs merge=lfs -text
comfyui-mvadapter/.github/workflows/publish.yml
ADDED
@@ -0,0 +1,25 @@
+name: Publish to Comfy registry
+on:
+  workflow_dispatch:
+  push:
+    branches:
+      - main
+    paths:
+      - "pyproject.toml"
+
+permissions:
+  issues: write
+
+jobs:
+  publish-node:
+    name: Publish Custom Node to registry
+    runs-on: ubuntu-latest
+    if: ${{ github.repository_owner == 'huanngzh' }}
+    steps:
+      - name: Check out code
+        uses: actions/checkout@v4
+      - name: Publish Custom Node
+        uses: Comfy-Org/publish-node-action@v1
+        with:
+          ## Add your own personal access token to your Github Repository secrets and reference it here.
+          personal_access_token: ${{ secrets.REGISTRY_ACCESS_TOKEN }}
comfyui-mvadapter/BACKUP_nodes.py
ADDED
@@ -0,0 +1,843 @@
+# Adapted from https://github.com/Limitex/ComfyUI-Diffusers/blob/main/nodes.py
+import copy
+import os
+import torch
+from safetensors.torch import load_file
+from torchvision import transforms
+from .utils import (
+    SCHEDULERS,
+    PIPELINES,
+    MVADAPTERS,
+    vae_pt_to_vae_diffuser,
+    convert_images_to_tensors,
+    convert_tensors_to_images,
+    prepare_camera_embed,
+    preprocess_image,
+)
+from comfy.model_management import get_torch_device
+import folder_paths
+
+from diffusers import StableDiffusionXLPipeline, AutoencoderKL, ControlNetModel
+from transformers import AutoModelForImageSegmentation  # <-- restored
+
+# ADDED: import DPMSolverMultistepScheduler for DPM++ Karras
+from diffusers import DPMSolverMultistepScheduler
+
+from .mvadapter.pipelines.pipeline_mvadapter_t2mv_sdxl import MVAdapterT2MVSDXLPipeline
+from .mvadapter.schedulers.scheduling_shift_snr import ShiftSNRScheduler
+
+# ADDED: import your new Karras-enabled shift scheduler (file sits next to scheduling_shift_snr.py)
+from .mvadapter.schedulers.ShiftSNRSchedulerKarras import ShiftSNRSchedulerKarras
+
+
+
+class DiffusersMVPipelineLoader:
+    def __init__(self):
+        self.hf_dir = folder_paths.get_folder_paths("diffusers")[0]
+        self.dtype = torch.float16
+
+    @classmethod
+    def INPUT_TYPES(s):
+        return {
+            "required": {
+                "ckpt_name": (
+                    "STRING",
+                    {"default": "stabilityai/stable-diffusion-xl-base-1.0"},
+                ),
+                "pipeline_name": (
+                    list(PIPELINES.keys()),
+                    {"default": "MVAdapterT2MVSDXLPipeline"},
+                ),
+            }
+        }
+
+    RETURN_TYPES = (
+        "PIPELINE",
+        "AUTOENCODER",
+        "SCHEDULER",
+    )
+
+    FUNCTION = "create_pipeline"
+
+    CATEGORY = "MV-Adapter"
+
+    def create_pipeline(self, ckpt_name, pipeline_name):
+        pipeline_class = PIPELINES[pipeline_name]
+        pipe = pipeline_class.from_pretrained(
+            pretrained_model_name_or_path=ckpt_name,
+            torch_dtype=self.dtype,
+            cache_dir=self.hf_dir,
+        )
+        return (pipe, pipe.vae, pipe.scheduler)
+
+
+class LdmPipelineLoader:
+    def __init__(self):
+        self.hf_dir = folder_paths.get_folder_paths("diffusers")[0]
+        self.dtype = torch.float16
+
+    @classmethod
+    def INPUT_TYPES(s):
+        return {
+            "required": {
+                "ckpt_name": (folder_paths.get_filename_list("checkpoints"),),
+                "pipeline_name": (
+                    list(PIPELINES.keys()),
+                    {"default": "MVAdapterT2MVSDXLPipeline"},
+                ),
+            }
+        }
+
+    RETURN_TYPES = (
+        "PIPELINE",
+        "AUTOENCODER",
+        "SCHEDULER",
+    )
+
+    FUNCTION = "create_pipeline"
+
+    CATEGORY = "MV-Adapter"
+
+    def create_pipeline(self, ckpt_name, pipeline_name):
+        pipeline_class = PIPELINES[pipeline_name]
+
+        pipe = pipeline_class.from_single_file(
+            pretrained_model_link_or_path=folder_paths.get_full_path(
+                "checkpoints", ckpt_name
+            ),
+            torch_dtype=self.dtype,
+            cache_dir=self.hf_dir,
+        )
+
+        return (pipe, pipe.vae, pipe.scheduler)
+
+
+class DiffusersMVVaeLoader:
+    def __init__(self):
+        self.hf_dir = folder_paths.get_folder_paths("diffusers")[0]
+        self.dtype = torch.float16
+
+    @classmethod
+    def INPUT_TYPES(s):
+        return {
+            "required": {
+                "vae_name": (
+                    "STRING",
+                    {"default": "madebyollin/sdxl-vae-fp16-fix"},
+                ),
+            }
+        }
+
+    RETURN_TYPES = ("AUTOENCODER",)
+
+    FUNCTION = "create_pipeline"
+
+    CATEGORY = "MV-Adapter"
+
+    def create_pipeline(self, vae_name):
+        vae = AutoencoderKL.from_pretrained(
+            pretrained_model_name_or_path=vae_name,
+            torch_dtype=self.dtype,
+            cache_dir=self.hf_dir,
+        )
+
+        return (vae,)
+
+
+class LdmVaeLoader:
+    def __init__(self):
+        self.dtype = torch.float16
+
+    @classmethod
+    def INPUT_TYPES(s):
+        return {
+            "required": {
+                "vae_name": (folder_paths.get_filename_list("vae"),),
+                "upcast_fp32": ("BOOLEAN", {"default": True}),
+            },
+        }
+
+    RETURN_TYPES = ("AUTOENCODER",)
+
+    FUNCTION = "create_pipeline"
+
+    CATEGORY = "MV-Adapter"
+
+    def create_pipeline(self, vae_name, upcast_fp32):
+        vae = vae_pt_to_vae_diffuser(
+            folder_paths.get_full_path("vae", vae_name), force_upcast=upcast_fp32
+        ).to(self.dtype)
+
+        return (vae,)
+
+
+class DiffusersMVSchedulerLoader:
+    def __init__(self):
+        self.hf_dir = folder_paths.get_folder_paths("diffusers")[0]
+        self.dtype = torch.float16
+
+    @classmethod
+    def INPUT_TYPES(s):
+        return {
+            "required": {
+                "pipeline": ("PIPELINE",),
+                "scheduler_name": (list(SCHEDULERS.keys()),),
+                "shift_snr": ("BOOLEAN", {"default": True}),
+                "shift_mode": (
+                    list(ShiftSNRScheduler.SHIFT_MODES),
+                    {"default": "interpolated"},
+                ),
+                "shift_scale": (
+                    "FLOAT",
+                    {"default": 8.0, "min": 0.0, "max": 50.0, "step": 1.0},
+                ),
+            }
+        }
+
+    RETURN_TYPES = ("SCHEDULER",)
+
+    FUNCTION = "load_scheduler"
+
+    CATEGORY = "MV-Adapter"
+
+    def load_scheduler(
+        self, pipeline, scheduler_name, shift_snr, shift_mode, shift_scale
+    ):
+        scheduler = SCHEDULERS[scheduler_name].from_config(
+            pipeline.scheduler.config, torch_dtype=self.dtype
+        )
+        if shift_snr:
+            scheduler = ShiftSNRScheduler.from_scheduler(
+                scheduler,
+                shift_mode=shift_mode,
+                shift_scale=shift_scale,
+                scheduler_class=scheduler.__class__,
+            )
+        return (scheduler,)
+
+
+# ADDED: Karras version — same inputs/outputs, but always returns a DPM++ (Karras) scheduler.
+class DiffusersMVSchedulerLoaderKarras:
+    def __init__(self):
+        self.hf_dir = folder_paths.get_folder_paths("diffusers")[0]
+        self.dtype = torch.float16
+
+    @classmethod
+    def INPUT_TYPES(s):
+        return {
+            "required": {
+                "pipeline": ("PIPELINE",),
+                "scheduler_name": (list(SCHEDULERS.keys()),),
+                "shift_snr": ("BOOLEAN", {"default": True}),
+                "shift_mode": (
+                    list(ShiftSNRSchedulerKarras.SHIFT_MODES),
+                    {"default": "interpolated"},
+                ),
+                "shift_scale": (
+                    "FLOAT",
+                    {"default": 8.0, "min": 0.0, "max": 50.0, "step": 1.0},
+                ),
+            }
+        }
+
+    RETURN_TYPES = ("SCHEDULER",)
+
+    FUNCTION = "load_scheduler"
+
+    CATEGORY = "MV-Adapter"
+
+    def load_scheduler(
+        self, pipeline, scheduler_name, shift_snr, shift_mode, shift_scale
+    ):
+        # Build a base scheduler from the pipeline config (kept for parity with original UI),
+        # then *replace* it with DPM++ (Karras). If SNR shift is requested, apply via your Karras class.
+        base_sched = SCHEDULERS[scheduler_name].from_config(
+            pipeline.scheduler.config, torch_dtype=self.dtype
+        )
+
+        # Always use DPM++ Karras:
+        if shift_snr:
+            # Apply your Karras-enabled Shift SNR on top, and force DPM++ class to guarantee Karras works.
+            scheduler = ShiftSNRSchedulerKarras.from_scheduler(
+                base_sched,
+                shift_mode=shift_mode,
+                shift_scale=shift_scale,
+                scheduler_class=DPMSolverMultistepScheduler,
+            )
+        else:
+            # No SNR shift requested: just return DPM++ with Karras sigmas
+            scheduler = DPMSolverMultistepScheduler.from_config(
+                pipeline.scheduler.config,
+                algorithm_type="dpmsolver++",
+                use_karras_sigmas=True,
+                torch_dtype=self.dtype,
+            )
+
+        return (scheduler,)
+
+
+class CustomLoraModelLoader:
+    def __init__(self):
+        self.loaded_lora = None
+
+    @classmethod
+    def INPUT_TYPES(s):
+        return {
+            "required": {
+                "pipeline": ("PIPELINE",),
+                "lora_name": (folder_paths.get_filename_list("loras"),),
+                "strength_model": (
+                    "FLOAT",
+                    {"default": 1.0, "min": -100.0, "max": 100.0, "step": 0.01},
+                ),
+                "enable": (
+                    "BOOLEAN",
+                    {"default": True},
+                ),
+                "last_lora_node": (
+                    "BOOLEAN",
+                    {"default": True},
+                ),
+            }
+        }
+
+    RETURN_TYPES = ("PIPELINE",)
+    FUNCTION = "load_lora"
+    CATEGORY = "MV-Adapter"
+
+    def load_lora(self, pipeline, lora_name, strength_model, enable, last_lora_node):
+        if not hasattr(pipeline, "loaded_loras"):
+            pipeline.loaded_loras = []
+
+        lora_path = folder_paths.get_full_path("loras", lora_name)
+        lora_dir = os.path.dirname(lora_path)
+        lora_name = os.path.basename(lora_path)
+        lora = None
+        if enable:
+            if self.loaded_lora is not None:
+                if self.loaded_lora[0] == lora_path:
+                    lora = self.loaded_lora[1]
+                else:
+                    temp = self.loaded_lora
+                    pipeline.delete_adapters(temp[1])
+                    pipeline.loaded_loras = [(name, strength) for (name, strength) in pipeline.loaded_loras if name != temp[1]]
+                    self.loaded_lora = None
+
+            if lora is None:
+                adapter_name = lora_name.rsplit(".", 1)[0]
+                pipeline.load_lora_weights(
+                    lora_dir, weight_name=lora_name, adapter_name=adapter_name
+                )
+                pipeline.set_adapters(adapter_name, strength_model)
+                self.loaded_lora = (lora_path, adapter_name)
+                lora = adapter_name
+
+                pipeline.loaded_loras.append((adapter_name, strength_model))
+        else:
+            # Delete the loaded lora
+            if self.loaded_lora is not None:
+                temp = self.loaded_lora
+                pipeline.delete_adapters(temp[1])
+                pipeline.loaded_loras = [(name, strength) for (name, strength) in pipeline.loaded_loras if name != temp[1]]
+                self.loaded_lora = None
+
+        if last_lora_node:
+            adapter_names = [x[0] for x in pipeline.loaded_loras]
+            strengths = [x[1] for x in pipeline.loaded_loras]
+            pipeline.set_adapters(adapter_names, strengths)
+
+            print(adapter_names)
+
+        return (pipeline,)
+
+
+class ControlNetModelLoader:
+    def __init__(self):
+        self.loaded_controlnet = None
+        self.dtype = torch.float16
+        self.torch_device = get_torch_device()
+        self.hf_dir = folder_paths.get_folder_paths("diffusers")[0]
+
+    @classmethod
+    def INPUT_TYPES(s):
+        return {
+            "required": {
+                "pipeline": ("PIPELINE",),
+                "controlnet_name": (
+                    "STRING",
+                    {"default": "xinsir/controlnet-scribble-sdxl-1.0"},
+                ),
+            }
+        }
+
+    RETURN_TYPES = ("PIPELINE",)
+    FUNCTION = "load_controlnet"
+    CATEGORY = "MV-Adapter"
+
+    def load_controlnet(self, pipeline, controlnet_name):
+        controlnet = None
+        if self.loaded_controlnet is not None:
+            if self.loaded_controlnet == controlnet_name:
+                controlnet = self.loaded_controlnet
+            else:
+                del pipeline.controlnet
+                self.loaded_controlnet = None
+
+        if controlnet is None:
+            controlnet = ControlNetModel.from_pretrained(
+                controlnet_name, cache_dir=self.hf_dir, torch_dtype=self.dtype
+            )
+            pipeline.controlnet = controlnet
+            pipeline.controlnet.to(device=self.torch_device, dtype=self.dtype)
+
+            self.loaded_controlnet = controlnet_name
+            controlnet = controlnet_name
+
+        return (pipeline,)
+
+
+class DiffusersMVModelMakeup:
+    def __init__(self):
+        self.hf_dir = folder_paths.get_folder_paths("diffusers")[0]
+        self.torch_device = get_torch_device()
+        self.dtype = torch.float16
+
+    @classmethod
+    def INPUT_TYPES(s):
+        return {
+            "required": {
+                "pipeline": ("PIPELINE",),
+                "scheduler": ("SCHEDULER",),
+                "autoencoder": ("AUTOENCODER",),
+                "load_mvadapter": ("BOOLEAN", {"default": True}),
+                "adapter_path": ("STRING", {"default": "huanngzh/mv-adapter"}),
+                "adapter_name": (
+                    MVADAPTERS,
+                    {"default": "mvadapter_t2mv_sdxl.safetensors"},
+                ),
+                "num_views": ("INT", {"default": 6, "min": 1, "max": 12}),
+            },
+            "optional": {
+                "enable_vae_slicing": ("BOOLEAN", {"default": True}),
+                "enable_vae_tiling": ("BOOLEAN", {"default": False}),
+            },
+        }
+
+    RETURN_TYPES = ("PIPELINE",)
+
+    FUNCTION = "makeup_pipeline"
+
+    CATEGORY = "MV-Adapter"
+
+    def makeup_pipeline(
+        self,
+        pipeline,
+        scheduler,
+        autoencoder,
+        load_mvadapter,
+        adapter_path,
+        adapter_name,
+        num_views,
+        enable_vae_slicing=True,
+        enable_vae_tiling=False,
+    ):
+        pipeline.vae = autoencoder
+        pipeline.scheduler = scheduler
+
+        if load_mvadapter:
+            pipeline.init_custom_adapter(num_views=num_views)
+            pipeline.load_custom_adapter(
+                adapter_path, weight_name=adapter_name, cache_dir=self.hf_dir
+            )
+            pipeline.cond_encoder.to(device=self.torch_device, dtype=self.dtype)
+
+        pipeline = pipeline.to(self.torch_device, self.dtype)
+
+        if enable_vae_slicing:
+            pipeline.enable_vae_slicing()
+        if enable_vae_tiling:
+            pipeline.enable_vae_tiling()
+
+        return (pipeline,)
+
+
+class DiffusersSampler:
+    def __init__(self):
+        self.torch_device = get_torch_device()
+
+    @classmethod
+    def INPUT_TYPES(s):
+        return {
+            "required": {
+                "pipeline": ("PIPELINE",),
+                "prompt": (
+                    "STRING",
+                    {"multiline": True, "default": "a photo of a cat"},
+                ),
+                "negative_prompt": (
+                    "STRING",
+                    {
+                        "multiline": True,
+                        "default": "watermark, ugly, deformed, noisy, blurry, low contrast",
+                    },
+                ),
+                "width": ("INT", {"default": 768, "min": 1, "max": 2048, "step": 1}),
+                "height": ("INT", {"default": 768, "min": 1, "max": 2048, "step": 1}),
+                "steps": ("INT", {"default": 50, "min": 1, "max": 2000}),
+                "cfg": (
+                    "FLOAT",
+                    {
+                        "default": 7.0,
+                        "min": 0.0,
+                        "max": 100.0,
+                        "step": 0.1,
+                        "round": 0.01,
+                    },
+                ),
+                "seed": ("INT", {"default": 0, "min": 0, "max": 0xFFFFFFFFFFFFFFFF}),
+            }
+        }
+
+    RETURN_TYPES = ("IMAGE",)
+
+    FUNCTION = "sample"
+
+    CATEGORY = "MV-Adapter"
+
+    def sample(
+        self,
+        pipeline,
+        prompt,
+        negative_prompt,
+        height,
+        width,
+        steps,
+        cfg,
+        seed,
+    ):
+        images = pipeline(
+            prompt=prompt,
+            height=height,
+            width=width,
+            num_inference_steps=steps,
+            guidance_scale=cfg,
+            negative_prompt=negative_prompt,
+            generator=torch.Generator(self.torch_device).manual_seed(seed),
+        ).images
+        return (convert_images_to_tensors(images),)
+
+
+class DiffusersMVSampler:
+    def __init__(self):
+        self.torch_device = get_torch_device()
+
+    @classmethod
+    def INPUT_TYPES(s):
+        return {
+            "required": {
+                "pipeline": ("PIPELINE",),
+                "num_views": ("INT", {"default": 6, "min": 1, "max": 12}),
+                "prompt": (
+                    "STRING",
+                    {"multiline": True, "default": "an astronaut riding a horse"},
+                ),
+                "negative_prompt": (
+                    "STRING",
+                    {
+                        "multiline": True,
+                        "default": "watermark, ugly, deformed, noisy, blurry, low contrast",
+                    },
+                ),
+                "width": ("INT", {"default": 768, "min": 1, "max": 2048, "step": 1}),
+                "height": ("INT", {"default": 768, "min": 1, "max": 2048, "step": 1}),
+                "steps": ("INT", {"default": 50, "min": 1, "max": 2000}),
+                "cfg": (
+                    "FLOAT",
+                    {
+                        "default": 7.0,
+                        "min": 0.0,
+                        "max": 100.0,
+                        "step": 0.1,
+                        "round": 0.01,
+                    },
+                ),
+                "seed": ("INT", {"default": 0, "min": 0, "max": 0xFFFFFFFFFFFFFFFF}),
+            },
+            "optional": {
+                "reference_image": ("IMAGE",),
+                "controlnet_image": ("IMAGE",),
+                "controlnet_conditioning_scale": ("FLOAT", {"default": 1.0}),
+                "azimuth_degrees": ("LIST", {"default": [0, 45, 90, 180, 270, 315]}),
+            },
+        }
+
+    RETURN_TYPES = ("IMAGE",)
+
+    FUNCTION = "sample"
+
+    CATEGORY = "MV-Adapter"
+
+    def sample(
+        self,
+        pipeline,
+        num_views,
+        prompt,
+        negative_prompt,
+        height,
+        width,
+        steps,
+        cfg,
+        seed,
+        reference_image=None,
+        controlnet_image=None,
+        controlnet_conditioning_scale=1.0,
+        azimuth_degrees=[0, 45, 90, 180, 270, 315],
+    ):
+        num_views = len(azimuth_degrees)
+        control_images = prepare_camera_embed(
+            num_views, width, self.torch_device, azimuth_degrees
+        )
+
+        pipe_kwargs = {}
+        if reference_image is not None:
+            pipe_kwargs.update(
+                {
+                    "reference_image": convert_tensors_to_images(reference_image)[0],
+                    "reference_conditioning_scale": 1.0,
+                }
+            )
+        if controlnet_image is not None:
+            controlnet_image = convert_tensors_to_images(controlnet_image)
+            pipe_kwargs.update(
+                {
+                    "controlnet_image": controlnet_image,
+                    "controlnet_conditioning_scale": controlnet_conditioning_scale,
+                }
+            )
+
+        images = pipeline(
+            prompt=prompt,
+            height=height,
+            width=width,
+            num_inference_steps=steps,
+            guidance_scale=cfg,
+            num_images_per_prompt=num_views,
+            control_image=control_images,
+            control_conditioning_scale=1.0,
+            negative_prompt=negative_prompt,
+            generator=torch.Generator(self.torch_device).manual_seed(seed),
+            cross_attention_kwargs={"num_views": num_views},
+            **pipe_kwargs,
+        ).images
+        return (convert_images_to_tensors(images),)
+
+
+class BiRefNet:
+    def __init__(self):
+        self.hf_dir = folder_paths.get_folder_paths("diffusers")[0]
+        self.torch_device = get_torch_device()
+        self.dtype = torch.float32
+
+    RETURN_TYPES = ("FUNCTION",)
+
+    FUNCTION = "load_model_fn"
+
+    CATEGORY = "MV-Adapter"
+
+    @classmethod
+    def INPUT_TYPES(s):
+        return {
+            "required": {"ckpt_name": ("STRING", {"default": "briaai/RMBG-2.0"})}
+        }
+
+    def remove_bg(self, image, net, transform, device):
+        image_size = image.size
+        input_images = transform(image).unsqueeze(0).to(device)
+        with torch.no_grad():
+            preds = net(input_images)[-1].sigmoid().cpu()
+        pred = preds[0].squeeze()
+        pred_pil = transforms.ToPILImage()(pred)
+        mask = pred_pil.resize(image_size)
+        image.putalpha(mask)
+        return image
+
+    def load_model_fn(self, ckpt_name):
+        model = AutoModelForImageSegmentation.from_pretrained(
+            ckpt_name, trust_remote_code=True, cache_dir=self.hf_dir
+        ).to(self.torch_device, self.dtype)
+
+        transform_image = transforms.Compose(
+            [
+                transforms.Resize((1024, 1024)),
+                transforms.ToTensor(),
+                transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
+            ]
+        )
+
+        remove_bg_fn = lambda x: self.remove_bg(
+            x, model, transform_image, self.torch_device
+        )
+        return (remove_bg_fn,)
+
+
+class ImagePreprocessor:
+    def __init__(self):
+        self.torch_device = get_torch_device()
+
+    @classmethod
+    def INPUT_TYPES(s):
+        return {
+            "required": {
+                "remove_bg_fn": ("FUNCTION",),
+                "image": ("IMAGE",),
+                "height": ("INT", {"default": 768, "min": 1, "max": 2048, "step": 1}),
+                "width": ("INT", {"default": 768, "min": 1, "max": 2048, "step": 1}),
+            }
+        }
+
+    RETURN_TYPES = ("IMAGE",)
+
+    FUNCTION = "process"
+
+    def process(self, remove_bg_fn, image, height, width):
+        images = convert_tensors_to_images(image)
+        images = [
+            preprocess_image(remove_bg_fn(img.convert("RGB")), height, width)
+            for img in images
+        ]
+
+        return (convert_images_to_tensors(images),)
+
+
+class ControlImagePreprocessor:
+    def __init__(self):
+        self.torch_device = get_torch_device()
+
+    @classmethod
+    def INPUT_TYPES(s):
+        return {
+            "required": {
+                "front_view": ("IMAGE",),
+                "front_right_view": ("IMAGE",),
+                "right_view": ("IMAGE",),
+                "back_view": ("IMAGE",),
+                "left_view": ("IMAGE",),
+                "front_left_view": ("IMAGE",),
+                "width": ("INT", {"default": 768, "min": 1, "max": 2048, "step": 1}),
+                "height": ("INT", {"default": 768, "min": 1, "max": 2048, "step": 1}),
+            }
+        }
+
+    RETURN_TYPES = ("IMAGE",)
+
+    FUNCTION = "process"
+
+    def process(
+        self,
+        front_view,
+        front_right_view,
+        right_view,
+        back_view,
+        left_view,
+        front_left_view,
+        width,
+        height,
+    ):
+        images = torch.cat(
+            [
+                front_view,
+                front_right_view,
+                right_view,
+                back_view,
+                left_view,
+                front_left_view,
+            ],
+            dim=0,
+        )
+        images = convert_tensors_to_images(images)
+        images = [img.resize((width, height)).convert("RGB") for img in images]
+        return (convert_images_to_tensors(images),)
+
+
+class ViewSelector:
+    def __init__(self):
+        pass
+
+    @classmethod
+    def INPUT_TYPES(s):
+        return {
+            "required": {
+                "front_view": ("BOOLEAN", {"default": True}),
+                "front_right_view": ("BOOLEAN", {"default": True}),
+                "right_view": ("BOOLEAN", {"default": True}),
+                "back_view": ("BOOLEAN", {"default": True}),
+                "left_view": ("BOOLEAN", {"default": True}),
+                "front_left_view": ("BOOLEAN", {"default": True}),
+            }
+        }
+
+    RETURN_TYPES = ("LIST",)
+    FUNCTION = "process"
+    CATEGORY = "MV-Adapter"
+
+    def process(
+        self,
+        front_view,
+        front_right_view,
+        right_view,
+        back_view,
+        left_view,
+        front_left_view,
+    ):
+        azimuth_deg = []
+        if front_view:
+            azimuth_deg.append(0)
+        if front_right_view:
+            azimuth_deg.append(45)
+        if right_view:
+            azimuth_deg.append(90)
+        if back_view:
+            azimuth_deg.append(180)
+        if left_view:
+            azimuth_deg.append(270)
+        if front_left_view:
+            azimuth_deg.append(315)
+
+        return (azimuth_deg,)
+
+
+NODE_CLASS_MAPPINGS = {
+    "LdmPipelineLoader": LdmPipelineLoader,
+    "LdmVaeLoader": LdmVaeLoader,
+    "DiffusersMVPipelineLoader": DiffusersMVPipelineLoader,
+    "DiffusersMVVaeLoader": DiffusersMVVaeLoader,
+    "DiffusersMVSchedulerLoader": DiffusersMVSchedulerLoader,
+    # ADDED: Karras version
+    "DiffusersMVSchedulerLoaderKarras": DiffusersMVSchedulerLoaderKarras,
+    "DiffusersMVModelMakeup": DiffusersMVModelMakeup,
+    "CustomLoraModelLoader": CustomLoraModelLoader,
+    "DiffusersMVSampler": DiffusersMVSampler,
+    "BiRefNet": BiRefNet,
+    "ImagePreprocessor": ImagePreprocessor,
+    "ControlNetModelLoader": ControlNetModelLoader,
+    "ControlImagePreprocessor": ControlImagePreprocessor,
+    "ViewSelector": ViewSelector,
+}
+
+NODE_DISPLAY_NAME_MAPPINGS = {
+    "LdmPipelineLoader": "LDM Pipeline Loader",
+    "LdmVaeLoader": "LDM Vae Loader",
+    "DiffusersMVPipelineLoader": "Diffusers MV Pipeline Loader",
+    "DiffusersMVVaeLoader": "Diffusers MV Vae Loader",
+    "DiffusersMVSchedulerLoader": "Diffusers MV Scheduler Loader",
+    # ADDED: Karras version
+    "DiffusersMVSchedulerLoaderKarras": "Diffusers MV Scheduler Loader (Karras)",
+    "DiffusersMVModelMakeup": "Diffusers MV Model Makeup",
+    "CustomLoraModelLoader": "Custom Lora Model Loader",
+    "DiffusersMVSampler": "Diffusers MV Sampler",
+    "BiRefNet": "BiRefNet",
+    "ImagePreprocessor": "Image Preprocessor",
+    "ControlNetModelLoader": "ControlNet Model Loader",
+    "ControlImagePreprocessor": "Control Image Preprocessor",
+    "ViewSelector": "View Selector",
+}
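For orientation, the node classes above compose the same way the example workflows wire them: pipeline loader, scheduler loader, model makeup, then sampler. Below is a minimal sketch of that chain called directly from Python. It is not part of the commit: it assumes a running ComfyUI environment (so `folder_paths` and `comfy.model_management` resolve), a hypothetical import path for this file, and that `"DDPMScheduler"` is a valid key in the `SCHEDULERS` dict from `.utils`, which is not shown in this diff.

    # Minimal sketch: chaining the node classes outside the graph UI.
    # Assumptions: ComfyUI runtime available; "DDPMScheduler" is a key in
    # SCHEDULERS (defined in .utils, not shown in this diff).
    from BACKUP_nodes import (  # hypothetical import; ComfyUI normally runs these as graph nodes
        DiffusersMVPipelineLoader,
        DiffusersMVSchedulerLoaderKarras,
        DiffusersMVModelMakeup,
        DiffusersMVSampler,
    )

    pipe, vae, sched = DiffusersMVPipelineLoader().create_pipeline(
        "stabilityai/stable-diffusion-xl-base-1.0", "MVAdapterT2MVSDXLPipeline"
    )
    # DPM++ (Karras) scheduler with the SNR shift applied on top.
    (scheduler,) = DiffusersMVSchedulerLoaderKarras().load_scheduler(
        pipe, "DDPMScheduler", True, "interpolated", 8.0
    )
    # Attach the VAE, the scheduler, and the MV-Adapter weights to the pipeline.
    (pipe,) = DiffusersMVModelMakeup().makeup_pipeline(
        pipe, scheduler, vae,
        load_mvadapter=True,
        adapter_path="huanngzh/mv-adapter",
        adapter_name="mvadapter_t2mv_sdxl.safetensors",
        num_views=6,
    )
    # Generate six consistent views; the result is a ComfyUI IMAGE tensor batch.
    (images,) = DiffusersMVSampler().sample(
        pipe, 6,
        "an astronaut riding a horse",
        "watermark, ugly, deformed, noisy, blurry, low contrast",
        768, 768, 50, 7.0, 0,
    )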
comfyui-mvadapter/LICENSE
ADDED
@@ -0,0 +1,201 @@
+                                 Apache License
+                           Version 2.0, January 2004
+                        http://www.apache.org/licenses/
+
+   TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
+
+   1. Definitions.
+
+      "License" shall mean the terms and conditions for use, reproduction,
+      and distribution as defined by Sections 1 through 9 of this document.
+
+      "Licensor" shall mean the copyright owner or entity authorized by
+      the copyright owner that is granting the License.
+
+      "Legal Entity" shall mean the union of the acting entity and all
+      other entities that control, are controlled by, or are under common
+      control with that entity. For the purposes of this definition,
+      "control" means (i) the power, direct or indirect, to cause the
+      direction or management of such entity, whether by contract or
+      otherwise, or (ii) ownership of fifty percent (50%) or more of the
+      outstanding shares, or (iii) beneficial ownership of such entity.
+
+      "You" (or "Your") shall mean an individual or Legal Entity
+      exercising permissions granted by this License.
+
+      "Source" form shall mean the preferred form for making modifications,
+      including but not limited to software source code, documentation
+      source, and configuration files.
+
+      "Object" form shall mean any form resulting from mechanical
+      transformation or translation of a Source form, including but
+      not limited to compiled object code, generated documentation,
+      and conversions to other media types.
+
+      "Work" shall mean the work of authorship, whether in Source or
+      Object form, made available under the License, as indicated by a
+      copyright notice that is included in or attached to the work
+      (an example is provided in the Appendix below).
+
+      "Derivative Works" shall mean any work, whether in Source or Object
+      form, that is based on (or derived from) the Work and for which the
+      editorial revisions, annotations, elaborations, or other modifications
+      represent, as a whole, an original work of authorship. For the purposes
+      of this License, Derivative Works shall not include works that remain
+      separable from, or merely link (or bind by name) to the interfaces of,
+      the Work and Derivative Works thereof.
+
+      "Contribution" shall mean any work of authorship, including
+      the original version of the Work and any modifications or additions
+      to that Work or Derivative Works thereof, that is intentionally
+      submitted to Licensor for inclusion in the Work by the copyright owner
+      or by an individual or Legal Entity authorized to submit on behalf of
+      the copyright owner. For the purposes of this definition, "submitted"
+      means any form of electronic, verbal, or written communication sent
+      to the Licensor or its representatives, including but not limited to
+      communication on electronic mailing lists, source code control systems,
+      and issue tracking systems that are managed by, or on behalf of, the
+      Licensor for the purpose of discussing and improving the Work, but
+      excluding communication that is conspicuously marked or otherwise
+      designated in writing by the copyright owner as "Not a Contribution."
+
+      "Contributor" shall mean Licensor and any individual or Legal Entity
+      on behalf of whom a Contribution has been received by Licensor and
+      subsequently incorporated within the Work.
+
+   2. Grant of Copyright License. Subject to the terms and conditions of
+      this License, each Contributor hereby grants to You a perpetual,
+      worldwide, non-exclusive, no-charge, royalty-free, irrevocable
+      copyright license to reproduce, prepare Derivative Works of,
+      publicly display, publicly perform, sublicense, and distribute the
+      Work and such Derivative Works in Source or Object form.
+
+   3. Grant of Patent License. Subject to the terms and conditions of
+      this License, each Contributor hereby grants to You a perpetual,
+      worldwide, non-exclusive, no-charge, royalty-free, irrevocable
+      (except as stated in this section) patent license to make, have made,
+      use, offer to sell, sell, import, and otherwise transfer the Work,
+      where such license applies only to those patent claims licensable
+      by such Contributor that are necessarily infringed by their
+      Contribution(s) alone or by combination of their Contribution(s)
+      with the Work to which such Contribution(s) was submitted. If You
+      institute patent litigation against any entity (including a
+      cross-claim or counterclaim in a lawsuit) alleging that the Work
+      or a Contribution incorporated within the Work constitutes direct
+      or contributory patent infringement, then any patent licenses
+      granted to You under this License for that Work shall terminate
+      as of the date such litigation is filed.
+
+   4. Redistribution. You may reproduce and distribute copies of the
+      Work or Derivative Works thereof in any medium, with or without
+      modifications, and in Source or Object form, provided that You
+      meet the following conditions:
+
+      (a) You must give any other recipients of the Work or
+          Derivative Works a copy of this License; and
+
+      (b) You must cause any modified files to carry prominent notices
+          stating that You changed the files; and
+
+      (c) You must retain, in the Source form of any Derivative Works
+          that You distribute, all copyright, patent, trademark, and
+          attribution notices from the Source form of the Work,
+          excluding those notices that do not pertain to any part of
+          the Derivative Works; and
+
+      (d) If the Work includes a "NOTICE" text file as part of its
+          distribution, then any Derivative Works that You distribute must
+          include a readable copy of the attribution notices contained
+          within such NOTICE file, excluding those notices that do not
+          pertain to any part of the Derivative Works, in at least one
+          of the following places: within a NOTICE text file distributed
+          as part of the Derivative Works; within the Source form or
+          documentation, if provided along with the Derivative Works; or,
+          within a display generated by the Derivative Works, if and
+          wherever such third-party notices normally appear. The contents
+          of the NOTICE file are for informational purposes only and
+          do not modify the License. You may add Your own attribution
+          notices within Derivative Works that You distribute, alongside
+          or as an addendum to the NOTICE text from the Work, provided
+          that such additional attribution notices cannot be construed
+          as modifying the License.
+
+      You may add Your own copyright statement to Your modifications and
+      may provide additional or different license terms and conditions
+      for use, reproduction, or distribution of Your modifications, or
+      for any such Derivative Works as a whole, provided Your use,
+      reproduction, and distribution of the Work otherwise complies with
+      the conditions stated in this License.
+
+   5. Submission of Contributions. Unless You explicitly state otherwise,
+      any Contribution intentionally submitted for inclusion in the Work
+      by You to the Licensor shall be under the terms and conditions of
+      this License, without any additional terms or conditions.
+      Notwithstanding the above, nothing herein shall supersede or modify
+      the terms of any separate license agreement you may have executed
+      with Licensor regarding such Contributions.
+
+   6. Trademarks. This License does not grant permission to use the trade
+      names, trademarks, service marks, or product names of the Licensor,
+      except as required for reasonable and customary use in describing the
+      origin of the Work and reproducing the content of the NOTICE file.
+
+   7. Disclaimer of Warranty. Unless required by applicable law or
+      agreed to in writing, Licensor provides the Work (and each
+      Contributor provides its Contributions) on an "AS IS" BASIS,
+      WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
+      implied, including, without limitation, any warranties or conditions
+      of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
+      PARTICULAR PURPOSE. You are solely responsible for determining the
+      appropriateness of using or redistributing the Work and assume any
+      risks associated with Your exercise of permissions under this License.
+
+   8. Limitation of Liability. In no event and under no legal theory,
+      whether in tort (including negligence), contract, or otherwise,
+      unless required by applicable law (such as deliberate and grossly
+      negligent acts) or agreed to in writing, shall any Contributor be
+      liable to You for damages, including any direct, indirect, special,
+      incidental, or consequential damages of any character arising as a
+      result of this License or out of the use or inability to use the
+      Work (including but not limited to damages for loss of goodwill,
+      work stoppage, computer failure or malfunction, or any and all
+      other commercial damages or losses), even if such Contributor
+      has been advised of the possibility of such damages.
+
+   9. Accepting Warranty or Additional Liability. While redistributing
+      the Work or Derivative Works thereof, You may choose to offer,
+      and charge a fee for, acceptance of support, warranty, indemnity,
+      or other liability obligations and/or rights consistent with this
+      License. However, in accepting such obligations, You may act only
+      on Your own behalf and on Your sole responsibility, not on behalf
+      of any other Contributor, and only if You agree to indemnify,
+      defend, and hold each Contributor harmless for any liability
+      incurred by, or claims asserted against, such Contributor by reason
+      of your accepting any such warranty or additional liability.
+
+   END OF TERMS AND CONDITIONS
+
+   APPENDIX: How to apply the Apache License to your work.
+
+      To apply the Apache License to your work, attach the following
+      boilerplate notice, with the fields enclosed by brackets "[]"
+      replaced with your own identifying information. (Don't include
+      the brackets!) The text should be enclosed in the appropriate
+      comment syntax for the file format. We also recommend that a
+      file or class name and description of purpose be included on the
+      same "printed page" as the copyright notice for easier
+      identification within third-party archives.
+
+   Copyright [yyyy] [name of copyright owner]
+
+   Licensed under the Apache License, Version 2.0 (the "License");
+   you may not use this file except in compliance with the License.
+   You may obtain a copy of the License at
+
+       http://www.apache.org/licenses/LICENSE-2.0
+
+   Unless required by applicable law or agreed to in writing, software
+   distributed under the License is distributed on an "AS IS" BASIS,
+   WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+   See the License for the specific language governing permissions and
+   limitations under the License.
comfyui-mvadapter/README.md
ADDED
|
@@ -0,0 +1,88 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# ComfyUI-MVAdapter

This extension integrates [MV-Adapter](https://github.com/huanngzh/MV-Adapter) into ComfyUI, allowing users to generate multi-view consistent images from text prompts or single images directly within the ComfyUI interface.

## 🔥 Feature Updates

* [2025-06-26] Support multiple LoRAs for multi-view synthesis [See [here](https://github.com/huanngzh/ComfyUI-MVAdapter/pull/96)]
* [2025-01-15] Support selecting which views to generate, e.g. only 2 views (front & back) [See [here](#view-selection)]
* [2024-12-25] Support integration with ControlNet, for applications like scribble to multi-view images [See [here](#with-controlnet)]
* [2024-12-09] Support integration with SDXL LoRA [See [here](#with-lora)]
* [2024-12-02] Generate multi-view consistent images from text prompts or a single image

## Installation

### From Source

* Clone or download this repository into your `ComfyUI/custom_nodes/` directory.
* Install the required dependencies by running `pip install -r requirements.txt`.

## Notes

### Workflows

We provide example workflows in the `workflows` directory.

Note that our code depends on diffusers and will automatically download the model weights from Hugging Face to the HF cache path on first use. The `ckpt_name` in the node corresponds to the model name on Hugging Face, such as `stabilityai/stable-diffusion-xl-base-1.0`.
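For example, resolving that name with plain diffusers looks like this (a minimal sketch outside ComfyUI; the first call downloads the weights into the Hugging Face cache, and later calls reuse them):

```python
import torch
from diffusers import StableDiffusionXLPipeline

# Same string you would put in `ckpt_name`
pipe = StableDiffusionXLPipeline.from_pretrained(
    "stabilityai/stable-diffusion-xl-base-1.0",
    torch_dtype=torch.float16,
)
```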

We also provide the `Ldm**Loader` nodes to support loading text-to-image models in `ldm` format. Please see the workflow files with the suffix `_ldm.json`.

### GPU Memory

If your GPU resources are limited, we recommend the following configuration:

* Use [madebyollin/sdxl-vae-fp16-fix](https://huggingface.co/madebyollin/sdxl-vae-fp16-fix) as the VAE. If using the ldm-format pipeline, remember to set `upcast_fp32` to `False`.



* Set `enable_vae_slicing` in the Diffusers Model Makeup node to `True`.



However, since SDXL is used as the base model, it still requires about 13 GB to 14 GB of GPU memory.
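For reference, the two switches above map to standard diffusers calls; a minimal sketch of the equivalent setup outside ComfyUI:

```python
import torch
from diffusers import AutoencoderKL, StableDiffusionXLPipeline

vae = AutoencoderKL.from_pretrained(
    "madebyollin/sdxl-vae-fp16-fix", torch_dtype=torch.float16
)
pipe = StableDiffusionXLPipeline.from_pretrained(
    "stabilityai/stable-diffusion-xl-base-1.0",
    vae=vae,
    torch_dtype=torch.float16,
)
pipe.enable_vae_slicing()  # decode latents one image at a time to reduce peak VRAM
```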

## Usage

### Text to Multi-view Images

#### With SDXL or other base models



* `workflows/t2mv_sdxl_diffusers.json` for loading diffusers-format models
* `workflows/t2mv_sdxl_ldm.json` for loading ldm-format models

#### With LoRA



`workflows/t2mv_sdxl_ldm_lora.json` for loading ldm-format models with LoRA for text-to-multi-view generation

#### With ControlNet



`workflows/t2mv_sdxl_ldm_controlnet.json` for loading diffusers-format ControlNets for text-scribble-to-multi-view generation

### Image to Multi-view Images

#### With SDXL or other base models



* `workflows/i2mv_sdxl_diffusers.json` for loading diffusers-format models
* `workflows/i2mv_sdxl_ldm.json` for loading ldm-format models

#### With LoRA



`workflows/i2mv_sdxl_ldm_lora.json` for loading ldm-format models with LoRA for image-to-multi-view generation

#### View Selection



`workflows/i2mv_sdxl_ldm_view_selector.json` for loading ldm-format models and selecting specific views to generate

The key is to replace the `adapter_name` in `Diffusers Model Makeup` with `mvadapter_i2mv_sdxl_beta.safetensors`, and to add a `View Selector` node to choose which views to generate. In our rough tests, the beta model works best when generating 2 views (front & back), 3 views (front, right & back), or 4 views (front, right, back & left). Note that the `num_views` attribute is not used in this mode and can be ignored.
comfyui-mvadapter/__init__.py
ADDED
@@ -0,0 +1,45 @@
# __init__.py for comfyui-mvadapter
# Register BOTH node sets: the original nodes.py and nodes_local_mv.py

import traceback

# Load the original nodes (if present)
try:
    from .nodes import (
        NODE_CLASS_MAPPINGS as CORE_NODE_CLASS_MAPPINGS,
        NODE_DISPLAY_NAME_MAPPINGS as CORE_NODE_DISPLAY_NAME_MAPPINGS,
    )
except Exception:
    print("[comfyui-mvadapter] WARN: Failed to import .nodes")
    traceback.print_exc()
    CORE_NODE_CLASS_MAPPINGS = {}
    CORE_NODE_DISPLAY_NAME_MAPPINGS = {}

# Load the local-only nodes (if present)
try:
    from .nodes_local_mv import (
        NODE_CLASS_MAPPINGS as LOCAL_NODE_CLASS_MAPPINGS,
        NODE_DISPLAY_NAME_MAPPINGS as LOCAL_NODE_DISPLAY_NAME_MAPPINGS,
    )
except Exception:
    print("[comfyui-mvadapter] WARN: Failed to import .nodes_local_mv")
    traceback.print_exc()
    LOCAL_NODE_CLASS_MAPPINGS = {}
    LOCAL_NODE_DISPLAY_NAME_MAPPINGS = {}

# Merge into the symbols ComfyUI looks for
NODE_CLASS_MAPPINGS = {}
NODE_CLASS_MAPPINGS.update(CORE_NODE_CLASS_MAPPINGS)
NODE_CLASS_MAPPINGS.update(LOCAL_NODE_CLASS_MAPPINGS)

NODE_DISPLAY_NAME_MAPPINGS = {}
NODE_DISPLAY_NAME_MAPPINGS.update(CORE_NODE_DISPLAY_NAME_MAPPINGS)
NODE_DISPLAY_NAME_MAPPINGS.update(LOCAL_NODE_DISPLAY_NAME_MAPPINGS)

# Optional: quick summary to help debug load order
print(
    "[comfyui-mvadapter] Registered nodes:",
    ", ".join(sorted(NODE_CLASS_MAPPINGS.keys())) or "(none)",
)

__all__ = ["NODE_CLASS_MAPPINGS", "NODE_DISPLAY_NAME_MAPPINGS"]
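For reference, a module only has to export those two dictionaries to be picked up by the merge above. A minimal, hypothetical skeleton of what such a node module could look like (the class fields follow ComfyUI's custom-node conventions; the node itself is a placeholder, not a node from this repo):

```python
# Hypothetical skeleton -- not the actual nodes_local_mv.py
class ExampleLocalMVNode:
    """Placeholder node that passes an image through unchanged."""

    @classmethod
    def INPUT_TYPES(cls):
        return {"required": {"image": ("IMAGE",)}}

    RETURN_TYPES = ("IMAGE",)
    FUNCTION = "run"
    CATEGORY = "MV-Adapter"

    def run(self, image):
        return (image,)


NODE_CLASS_MAPPINGS = {"ExampleLocalMVNode": ExampleLocalMVNode}
NODE_DISPLAY_NAME_MAPPINGS = {"ExampleLocalMVNode": "Example Local MV Node"}
```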
comfyui-mvadapter/__pycache__/__init__.cpython-312.pyc
ADDED
Binary file (1.4 kB)

comfyui-mvadapter/__pycache__/nodes.cpython-312.pyc
ADDED
Binary file (8.32 kB)

comfyui-mvadapter/__pycache__/nodes_local_mv.cpython-312.pyc
ADDED
Binary file (10.7 kB)

comfyui-mvadapter/__pycache__/utils.cpython-312.pyc
ADDED
Binary file (13.6 kB)

comfyui-mvadapter/assets/CustomLoraModelLoader.png
ADDED

comfyui-mvadapter/assets/comfyui_i2mv.png
ADDED (Git LFS)

comfyui-mvadapter/assets/comfyui_i2mv_lora.png
ADDED (Git LFS)

comfyui-mvadapter/assets/comfyui_i2mv_multiple_loras.jpg
ADDED (Git LFS)

comfyui-mvadapter/assets/comfyui_i2mv_view_selector.png
ADDED (Git LFS)

comfyui-mvadapter/assets/comfyui_ldm_vae.png
ADDED

comfyui-mvadapter/assets/comfyui_model_makeup.png
ADDED

comfyui-mvadapter/assets/comfyui_t2mv.png
ADDED (Git LFS)

comfyui-mvadapter/assets/comfyui_t2mv_controlnet.png
ADDED (Git LFS)

comfyui-mvadapter/assets/comfyui_t2mv_lora.png
ADDED (Git LFS)

comfyui-mvadapter/assets/comfyui_t2mv_multiple_loras.jpg
ADDED (Git LFS)

comfyui-mvadapter/assets/demo/scribbles/scribble_0.png
ADDED

comfyui-mvadapter/assets/demo/scribbles/scribble_1.png
ADDED

comfyui-mvadapter/assets/demo/scribbles/scribble_2.png
ADDED

comfyui-mvadapter/assets/demo/scribbles/scribble_3.png
ADDED

comfyui-mvadapter/assets/demo/scribbles/scribble_4.png
ADDED

comfyui-mvadapter/assets/demo/scribbles/scribble_5.png
ADDED
comfyui-mvadapter/cache/stable-diffusion-v1-inference.yaml
ADDED
@@ -0,0 +1,70 @@
model:
  base_learning_rate: 1.0e-04
  target: ldm.models.diffusion.ddpm.LatentDiffusion
  params:
    linear_start: 0.00085
    linear_end: 0.0120
    num_timesteps_cond: 1
    log_every_t: 200
    timesteps: 1000
    first_stage_key: "jpg"
    cond_stage_key: "txt"
    image_size: 64
    channels: 4
    cond_stage_trainable: false   # Note: different from the one we trained before
    conditioning_key: crossattn
    monitor: val/loss_simple_ema
    scale_factor: 0.18215
    use_ema: False

    scheduler_config: # 10000 warmup steps
      target: ldm.lr_scheduler.LambdaLinearScheduler
      params:
        warm_up_steps: [ 10000 ]
        cycle_lengths: [ 10000000000000 ] # incredibly large number to prevent corner cases
        f_start: [ 1.e-6 ]
        f_max: [ 1. ]
        f_min: [ 1. ]

    unet_config:
      target: ldm.modules.diffusionmodules.openaimodel.UNetModel
      params:
        image_size: 32 # unused
        in_channels: 4
        out_channels: 4
        model_channels: 320
        attention_resolutions: [ 4, 2, 1 ]
        num_res_blocks: 2
        channel_mult: [ 1, 2, 4, 4 ]
        num_heads: 8
        use_spatial_transformer: True
        transformer_depth: 1
        context_dim: 768
        use_checkpoint: True
        legacy: False

    first_stage_config:
      target: ldm.models.autoencoder.AutoencoderKL
      params:
        embed_dim: 4
        monitor: val/rec_loss
        ddconfig:
          double_z: true
          z_channels: 4
          resolution: 256
          in_channels: 3
          out_ch: 3
          ch: 128
          ch_mult:
          - 1
          - 2
          - 4
          - 4
          num_res_blocks: 2
          attn_resolutions: []
          dropout: 0.0
        lossconfig:
          target: torch.nn.Identity

    cond_stage_config:
      target: ldm.modules.encoders.modules.FrozenCLIPEmbedder
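This is the stock Stable Diffusion v1 inference config used when loading ldm-format checkpoints; loaders in the `ldm` codebase typically read it with OmegaConf and instantiate `model.target` with `model.params`. A minimal sketch of inspecting it (assuming `omegaconf` is installed):

```python
from omegaconf import OmegaConf

cfg = OmegaConf.load("comfyui-mvadapter/cache/stable-diffusion-v1-inference.yaml")
print(cfg.model.target)               # ldm.models.diffusion.ddpm.LatentDiffusion
print(cfg.model.params.scale_factor)  # 0.18215
```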
comfyui-mvadapter/mvadapter/__init__.py
ADDED
File without changes

comfyui-mvadapter/mvadapter/__pycache__/__init__.cpython-312.pyc
ADDED
Binary file (161 Bytes)

comfyui-mvadapter/mvadapter/loaders/__init__.py
ADDED
@@ -0,0 +1 @@
from .custom_adapter import CustomAdapterMixin

comfyui-mvadapter/mvadapter/loaders/__pycache__/__init__.cpython-312.pyc
ADDED
Binary file (229 Bytes)

comfyui-mvadapter/mvadapter/loaders/__pycache__/custom_adapter.cpython-312.pyc
ADDED
Binary file (4.44 kB)
comfyui-mvadapter/mvadapter/loaders/custom_adapter.py
ADDED
@@ -0,0 +1,98 @@
import os
from typing import Dict, Optional, Union

import safetensors
import torch
from diffusers.utils import _get_model_file, logging
from safetensors import safe_open

logger = logging.get_logger(__name__)  # pylint: disable=invalid-name


class CustomAdapterMixin:
    def init_custom_adapter(self, *args, **kwargs):
        self._init_custom_adapter(*args, **kwargs)

    def _init_custom_adapter(self, *args, **kwargs):
        raise NotImplementedError

    def load_custom_adapter(
        self,
        pretrained_model_name_or_path_or_dict: Union[str, Dict[str, torch.Tensor]],
        weight_name: str,
        subfolder: Optional[str] = None,
        **kwargs,
    ):
        # Load the main state dict first.
        cache_dir = kwargs.pop("cache_dir", None)
        force_download = kwargs.pop("force_download", False)
        proxies = kwargs.pop("proxies", None)
        local_files_only = kwargs.pop("local_files_only", None)
        token = kwargs.pop("token", None)
        revision = kwargs.pop("revision", None)

        user_agent = {
            "file_type": "attn_procs_weights",
            "framework": "pytorch",
        }

        if not isinstance(pretrained_model_name_or_path_or_dict, dict):
            model_file = _get_model_file(
                pretrained_model_name_or_path_or_dict,
                weights_name=weight_name,
                subfolder=subfolder,
                cache_dir=cache_dir,
                force_download=force_download,
                proxies=proxies,
                local_files_only=local_files_only,
                token=token,
                revision=revision,
                user_agent=user_agent,
            )
            if weight_name.endswith(".safetensors"):
                state_dict = {}
                with safe_open(model_file, framework="pt", device="cpu") as f:
                    for key in f.keys():
                        state_dict[key] = f.get_tensor(key)
            else:
                state_dict = torch.load(model_file, map_location="cpu")
        else:
            state_dict = pretrained_model_name_or_path_or_dict

        self._load_custom_adapter(state_dict)

    def _load_custom_adapter(self, state_dict):
        raise NotImplementedError

    def save_custom_adapter(
        self,
        save_directory: Union[str, os.PathLike],
        weight_name: str,
        safe_serialization: bool = False,
        **kwargs,
    ):
        if os.path.isfile(save_directory):
            logger.error(
                f"Provided path ({save_directory}) should be a directory, not a file"
            )
            return

        if safe_serialization:

            def save_function(weights, filename):
                return safetensors.torch.save_file(
                    weights, filename, metadata={"format": "pt"}
                )

        else:
            save_function = torch.save

        # Save the model
        state_dict = self._save_custom_adapter(**kwargs)
        save_function(state_dict, os.path.join(save_directory, weight_name))
        logger.info(
            f"Custom adapter weights saved in {os.path.join(save_directory, weight_name)}"
        )

    # Note: accepts **kwargs, since save_custom_adapter forwards them here.
    def _save_custom_adapter(self, **kwargs):
        raise NotImplementedError
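A concrete pipeline is expected to override the three underscore-prefixed hooks. A minimal, hypothetical sketch of the contract (the `ToyAdapter` class below is illustrative, not a class from this repo):

```python
import os

import torch.nn as nn

from mvadapter.loaders import CustomAdapterMixin


class ToyAdapter(CustomAdapterMixin):
    def _init_custom_adapter(self, dim: int = 64):
        # Create the adapter weights that load/save operate on.
        self.adapter = nn.Linear(dim, dim)

    def _load_custom_adapter(self, state_dict):
        self.adapter.load_state_dict(state_dict)

    def _save_custom_adapter(self, **kwargs):
        return self.adapter.state_dict()


adapter = ToyAdapter()
adapter.init_custom_adapter(dim=64)
os.makedirs("out", exist_ok=True)
adapter.save_custom_adapter("out", "toy_adapter.safetensors", safe_serialization=True)
adapter.load_custom_adapter("out", weight_name="toy_adapter.safetensors")
```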
comfyui-mvadapter/mvadapter/models/__init__.py
ADDED
File without changes

comfyui-mvadapter/mvadapter/models/__pycache__/__init__.cpython-312.pyc
ADDED
Binary file (168 Bytes)

comfyui-mvadapter/mvadapter/models/__pycache__/attention_processor.cpython-312.pyc
ADDED
Binary file (13.8 kB)
comfyui-mvadapter/mvadapter/models/attention_processor.py
ADDED
@@ -0,0 +1,377 @@
import math
from typing import Callable, Dict, List, Optional, Union

import torch
import torch.nn.functional as F
from diffusers.models.attention_processor import Attention
from diffusers.models.unets import UNet2DConditionModel
from diffusers.utils import deprecate, logging
from diffusers.utils.import_utils import is_torch_npu_available, is_xformers_available
from einops import rearrange
from torch import nn


def default_set_attn_proc_func(
    name: str,
    hidden_size: int,
    cross_attention_dim: Optional[int],
    ori_attn_proc: object,
) -> object:
    return ori_attn_proc


def set_unet_2d_condition_attn_processor(
    unet: UNet2DConditionModel,
    set_self_attn_proc_func: Callable = default_set_attn_proc_func,
    set_cross_attn_proc_func: Callable = default_set_attn_proc_func,
    set_custom_attn_proc_func: Callable = default_set_attn_proc_func,
    set_self_attn_module_names: Optional[List[str]] = None,
    set_cross_attn_module_names: Optional[List[str]] = None,
    set_custom_attn_module_names: Optional[List[str]] = None,
) -> None:
    do_set_processor = lambda name, module_names: (
        any([name.startswith(module_name) for module_name in module_names])
        if module_names is not None
        else True
    )  # prefix match

    attn_procs = {}
    for name, attn_processor in unet.attn_processors.items():
        # set attn_processor by default, if module_names is None
        set_self_attn_processor = do_set_processor(name, set_self_attn_module_names)
        set_cross_attn_processor = do_set_processor(name, set_cross_attn_module_names)
        set_custom_attn_processor = do_set_processor(name, set_custom_attn_module_names)

        if name.startswith("mid_block"):
            hidden_size = unet.config.block_out_channels[-1]
        elif name.startswith("up_blocks"):
            block_id = int(name[len("up_blocks.")])
            hidden_size = list(reversed(unet.config.block_out_channels))[block_id]
        elif name.startswith("down_blocks"):
            block_id = int(name[len("down_blocks.")])
            hidden_size = unet.config.block_out_channels[block_id]

        is_custom = "attn_mid_blocks" in name or "attn_post_blocks" in name
        if is_custom:
            attn_procs[name] = (
                set_custom_attn_proc_func(name, hidden_size, None, attn_processor)
                if set_custom_attn_processor
                else attn_processor
            )
        else:
            cross_attention_dim = (
                None
                if name.endswith("attn1.processor")
                else unet.config.cross_attention_dim
            )
            if cross_attention_dim is None or "motion_modules" in name:
                # self attention
                attn_procs[name] = (
                    set_self_attn_proc_func(
                        name, hidden_size, cross_attention_dim, attn_processor
                    )
                    if set_self_attn_processor
                    else attn_processor
                )
            else:
                # cross attention
                attn_procs[name] = (
                    set_cross_attn_proc_func(
                        name, hidden_size, cross_attention_dim, attn_processor
                    )
                    if set_cross_attn_processor
                    else attn_processor
                )

    unet.set_attn_processor(attn_procs)


class DecoupledMVRowSelfAttnProcessor2_0(torch.nn.Module):
    r"""
    Attention processor for Decoupled Row-wise Self-Attention and Image Cross-Attention for PyTorch 2.0.
    """

    def __init__(
        self,
        query_dim: int,
        inner_dim: int,
        num_views: int = 1,
        name: Optional[str] = None,
        use_mv: bool = True,
        use_ref: bool = False,
    ):
        if not hasattr(F, "scaled_dot_product_attention"):
            raise ImportError(
                "DecoupledMVRowSelfAttnProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0."
            )

        super().__init__()

        self.num_views = num_views
        self.name = name  # NOTE: needed for image cross-attention
        self.use_mv = use_mv
        self.use_ref = use_ref

        if self.use_mv:
            self.to_q_mv = nn.Linear(
                in_features=query_dim, out_features=inner_dim, bias=False
            )
            self.to_k_mv = nn.Linear(
                in_features=query_dim, out_features=inner_dim, bias=False
            )
            self.to_v_mv = nn.Linear(
                in_features=query_dim, out_features=inner_dim, bias=False
            )
            self.to_out_mv = nn.ModuleList(
                [
                    nn.Linear(in_features=inner_dim, out_features=query_dim, bias=True),
                    nn.Dropout(0.0),
                ]
            )

        if self.use_ref:
            self.to_q_ref = nn.Linear(
                in_features=query_dim, out_features=inner_dim, bias=False
            )
            self.to_k_ref = nn.Linear(
                in_features=query_dim, out_features=inner_dim, bias=False
            )
            self.to_v_ref = nn.Linear(
                in_features=query_dim, out_features=inner_dim, bias=False
            )
            self.to_out_ref = nn.ModuleList(
                [
                    nn.Linear(in_features=inner_dim, out_features=query_dim, bias=True),
                    nn.Dropout(0.0),
                ]
            )

    def __call__(
        self,
        attn: Attention,
        hidden_states: torch.FloatTensor,
        encoder_hidden_states: Optional[torch.FloatTensor] = None,
        attention_mask: Optional[torch.FloatTensor] = None,
        temb: Optional[torch.FloatTensor] = None,
        mv_scale: float = 1.0,
        ref_hidden_states: Optional[Dict[str, torch.FloatTensor]] = None,
        ref_scale: float = 1.0,
        cache_hidden_states: Optional[Dict[str, torch.FloatTensor]] = None,
        use_mv: bool = True,
        use_ref: bool = True,
        num_views: Optional[int] = None,
        *args,
        **kwargs,
    ) -> torch.FloatTensor:
        """
        New args:
            mv_scale (float): scale for multi-view self-attention.
            ref_hidden_states (Dict[str, torch.FloatTensor]): reference encoder hidden states for image cross-attention, keyed by processor name.
            ref_scale (float): scale for image cross-attention.
            cache_hidden_states (Dict[str, torch.FloatTensor]): cache of hidden states from the reference unet, keyed by processor name.
        """
        if len(args) > 0 or kwargs.get("scale", None) is not None:
            deprecation_message = "The `scale` argument is deprecated and will be ignored. Please remove it, as passing it will raise an error in the future. `scale` should directly be passed while calling the underlying pipeline component i.e., via `cross_attention_kwargs`."
            deprecate("scale", "1.0.0", deprecation_message)

        if num_views is not None:
            self.num_views = num_views

        # NEW: cache hidden states for reference unet
        if cache_hidden_states is not None:
            cache_hidden_states[self.name] = hidden_states.clone()

        # NEW: whether to use multi-view attention and image cross-attention
        use_mv = self.use_mv and use_mv
        use_ref = self.use_ref and use_ref

        residual = hidden_states
        if attn.spatial_norm is not None:
            hidden_states = attn.spatial_norm(hidden_states, temb)

        input_ndim = hidden_states.ndim

        if input_ndim == 4:
            batch_size, channel, height, width = hidden_states.shape
            hidden_states = hidden_states.view(
                batch_size, channel, height * width
            ).transpose(1, 2)

        batch_size, sequence_length, _ = (
            hidden_states.shape
            if encoder_hidden_states is None
            else encoder_hidden_states.shape
        )

        if attention_mask is not None:
            attention_mask = attn.prepare_attention_mask(
                attention_mask, sequence_length, batch_size
            )
            # scaled_dot_product_attention expects attention_mask shape to be
            # (batch, heads, source_length, target_length)
            attention_mask = attention_mask.view(
                batch_size, attn.heads, -1, attention_mask.shape[-1]
            )

        if attn.group_norm is not None:
            hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(
                1, 2
            )

        query = attn.to_q(hidden_states)

        # NEW: for decoupled multi-view attention
        if use_mv:
            query_mv = self.to_q_mv(hidden_states)

        # NEW: for decoupled reference cross attention
        if use_ref:
            query_ref = self.to_q_ref(hidden_states)

        if encoder_hidden_states is None:
            encoder_hidden_states = hidden_states
        elif attn.norm_cross:
            encoder_hidden_states = attn.norm_encoder_hidden_states(
                encoder_hidden_states
            )

        key = attn.to_k(encoder_hidden_states)
        value = attn.to_v(encoder_hidden_states)

        inner_dim = key.shape[-1]
        head_dim = inner_dim // attn.heads

        query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)

        key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
        value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)

        # the output of sdp = (batch, num_heads, seq_len, head_dim)
        # TODO: add support for attn.scale when we move to Torch 2.1
        hidden_states = F.scaled_dot_product_attention(
            query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
        )

        hidden_states = hidden_states.transpose(1, 2).reshape(
            batch_size, -1, attn.heads * head_dim
        )
        hidden_states = hidden_states.to(query.dtype)

        ####### Decoupled multi-view self-attention ########
        if use_mv:
            key_mv = self.to_k_mv(encoder_hidden_states)
            value_mv = self.to_v_mv(encoder_hidden_states)

            query_mv = query_mv.view(batch_size, -1, attn.heads, head_dim)
            key_mv = key_mv.view(batch_size, -1, attn.heads, head_dim)
            value_mv = value_mv.view(batch_size, -1, attn.heads, head_dim)

            height = width = math.isqrt(sequence_length)

            # row self-attention
            query_mv = rearrange(
                query_mv,
                "(b nv) (ih iw) h c -> (b nv ih) iw h c",
                nv=self.num_views,
                ih=height,
                iw=width,
            ).transpose(1, 2)
            key_mv = rearrange(
                key_mv,
                "(b nv) (ih iw) h c -> b ih (nv iw) h c",
                nv=self.num_views,
                ih=height,
                iw=width,
            )
            key_mv = (
                key_mv.repeat_interleave(self.num_views, dim=0)
                .view(batch_size * height, -1, attn.heads, head_dim)
                .transpose(1, 2)
            )
            value_mv = rearrange(
                value_mv,
                "(b nv) (ih iw) h c -> b ih (nv iw) h c",
                nv=self.num_views,
                ih=height,
                iw=width,
            )
            value_mv = (
                value_mv.repeat_interleave(self.num_views, dim=0)
                .view(batch_size * height, -1, attn.heads, head_dim)
                .transpose(1, 2)
            )

            hidden_states_mv = F.scaled_dot_product_attention(
                query_mv,
                key_mv,
                value_mv,
                dropout_p=0.0,
                is_causal=False,
            )
            hidden_states_mv = rearrange(
                hidden_states_mv,
                "(b nv ih) h iw c -> (b nv) (ih iw) (h c)",
                nv=self.num_views,
                ih=height,
            )
            hidden_states_mv = hidden_states_mv.to(query.dtype)

            # linear proj
            hidden_states_mv = self.to_out_mv[0](hidden_states_mv)
            # dropout
            hidden_states_mv = self.to_out_mv[1](hidden_states_mv)

        if use_ref:
            reference_hidden_states = ref_hidden_states[self.name]

            key_ref = self.to_k_ref(reference_hidden_states)
            value_ref = self.to_v_ref(reference_hidden_states)

            query_ref = query_ref.view(batch_size, -1, attn.heads, head_dim).transpose(
                1, 2
            )
            key_ref = key_ref.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
            value_ref = value_ref.view(batch_size, -1, attn.heads, head_dim).transpose(
                1, 2
            )

            hidden_states_ref = F.scaled_dot_product_attention(
                query_ref, key_ref, value_ref, dropout_p=0.0, is_causal=False
            )

            hidden_states_ref = hidden_states_ref.transpose(1, 2).reshape(
                batch_size, -1, attn.heads * head_dim
            )
            hidden_states_ref = hidden_states_ref.to(query.dtype)

            # linear proj
            hidden_states_ref = self.to_out_ref[0](hidden_states_ref)
            # dropout
            hidden_states_ref = self.to_out_ref[1](hidden_states_ref)

        # linear proj
        hidden_states = attn.to_out[0](hidden_states)
        # dropout
        hidden_states = attn.to_out[1](hidden_states)

        if use_mv:
            hidden_states = hidden_states + hidden_states_mv * mv_scale

        if use_ref:
            hidden_states = hidden_states + hidden_states_ref * ref_scale

        if input_ndim == 4:
            hidden_states = hidden_states.transpose(-1, -2).reshape(
                batch_size, channel, height, width
            )

        if attn.residual_connection:
            hidden_states = hidden_states + residual

        hidden_states = hidden_states / attn.rescale_output_factor

        return hidden_states

    def set_num_views(self, num_views: int) -> None:
        self.num_views = num_views
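A minimal sketch of wiring the processor into an SDXL UNet with the helper above (names mirror this file; the `num_views` value is illustrative, and in real use the trained MV-Adapter weights are loaded on top of the randomly initialized projections):

```python
import torch
from diffusers import UNet2DConditionModel

from mvadapter.models.attention_processor import (
    DecoupledMVRowSelfAttnProcessor2_0,
    set_unet_2d_condition_attn_processor,
)

unet = UNet2DConditionModel.from_pretrained(
    "stabilityai/stable-diffusion-xl-base-1.0", subfolder="unet"
)


def make_mv_proc(name, hidden_size, cross_attention_dim, ori_proc):
    # Swap every self-attention (attn1) processor for the decoupled row-wise one.
    # For the standard SDXL blocks, query_dim == inner_dim == hidden_size.
    return DecoupledMVRowSelfAttnProcessor2_0(
        query_dim=hidden_size, inner_dim=hidden_size, num_views=4, name=name, use_mv=True
    )


# Cross-attention layers keep their original processors (default func).
set_unet_2d_condition_attn_processor(unet, set_self_attn_proc_func=make_mv_proc)
```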
comfyui-mvadapter/mvadapter/pipelines/__pycache__/pipeline_mvadapter_i2mv_sd.cpython-312.pyc
ADDED
Binary file (30 kB)

comfyui-mvadapter/mvadapter/pipelines/__pycache__/pipeline_mvadapter_i2mv_sdxl.cpython-312.pyc
ADDED
Binary file (32.6 kB)

comfyui-mvadapter/mvadapter/pipelines/__pycache__/pipeline_mvadapter_t2mv_sd.cpython-312.pyc
ADDED
Binary file (24.9 kB)

comfyui-mvadapter/mvadapter/pipelines/__pycache__/pipeline_mvadapter_t2mv_sdxl.cpython-312.pyc
ADDED
Binary file (34.8 kB)
comfyui-mvadapter/mvadapter/pipelines/pipeline_mvadapter_i2mv_sdxl.py
ADDED
@@ -0,0 +1,903 @@
# Copyright 2024 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import inspect
from typing import Any, Callable, Dict, List, Optional, Tuple, Union

import numpy as np
import PIL
import torch
import torch.nn as nn
from diffusers.image_processor import PipelineImageInput, VaeImageProcessor
from diffusers.models import (
    AutoencoderKL,
    ImageProjection,
    T2IAdapter,
    UNet2DConditionModel,
)
from diffusers.pipelines.stable_diffusion_xl.pipeline_output import (
    StableDiffusionXLPipelineOutput,
)
from diffusers.pipelines.stable_diffusion_xl.pipeline_stable_diffusion_xl import (
    StableDiffusionXLPipeline,
    rescale_noise_cfg,
    retrieve_timesteps,
)
from diffusers.schedulers import KarrasDiffusionSchedulers
from diffusers.utils import deprecate, logging
from diffusers.utils.torch_utils import randn_tensor
from einops import rearrange
from transformers import (
    CLIPImageProcessor,
    CLIPTextModel,
    CLIPTextModelWithProjection,
    CLIPTokenizer,
    CLIPVisionModelWithProjection,
)

from ..loaders import CustomAdapterMixin
from ..models.attention_processor import (
    DecoupledMVRowSelfAttnProcessor2_0,
    set_unet_2d_condition_attn_processor,
)

logger = logging.get_logger(__name__)  # pylint: disable=invalid-name


def retrieve_latents(
    encoder_output: torch.Tensor,
    generator: Optional[torch.Generator] = None,
    sample_mode: str = "sample",
):
    if hasattr(encoder_output, "latent_dist") and sample_mode == "sample":
        return encoder_output.latent_dist.sample(generator)
    elif hasattr(encoder_output, "latent_dist") and sample_mode == "argmax":
        return encoder_output.latent_dist.mode()
    elif hasattr(encoder_output, "latents"):
        return encoder_output.latents
    else:
        raise AttributeError("Could not access latents of provided encoder_output")


class MVAdapterI2MVSDXLPipeline(StableDiffusionXLPipeline, CustomAdapterMixin):
    def __init__(
        self,
        vae: AutoencoderKL,
        text_encoder: CLIPTextModel,
        text_encoder_2: CLIPTextModelWithProjection,
        tokenizer: CLIPTokenizer,
        tokenizer_2: CLIPTokenizer,
        unet: UNet2DConditionModel,
        scheduler: KarrasDiffusionSchedulers,
        image_encoder: CLIPVisionModelWithProjection = None,
        feature_extractor: CLIPImageProcessor = None,
        force_zeros_for_empty_prompt: bool = True,
        add_watermarker: Optional[bool] = None,
    ):
        super().__init__(
            vae=vae,
            text_encoder=text_encoder,
            text_encoder_2=text_encoder_2,
            tokenizer=tokenizer,
            tokenizer_2=tokenizer_2,
            unet=unet,
            scheduler=scheduler,
            image_encoder=image_encoder,
            feature_extractor=feature_extractor,
            force_zeros_for_empty_prompt=force_zeros_for_empty_prompt,
            add_watermarker=add_watermarker,
        )

        self.control_image_processor = VaeImageProcessor(
            vae_scale_factor=self.vae_scale_factor,
            do_convert_rgb=True,
            do_normalize=False,
        )

    # Copied from diffusers.pipelines.stable_diffusion_xl.pipeline_stable_diffusion_xl_img2img.prepare_latents
    def prepare_image_latents(
        self,
        image,
        timestep,
        batch_size,
        num_images_per_prompt,
        dtype,
        device,
        generator=None,
        add_noise=True,
    ):
        if not isinstance(image, (torch.Tensor, PIL.Image.Image, list)):
            raise ValueError(
                f"`image` has to be of type `torch.Tensor`, `PIL.Image.Image` or list but is {type(image)}"
            )

        latents_mean = latents_std = None
        if (
            hasattr(self.vae.config, "latents_mean")
            and self.vae.config.latents_mean is not None
        ):
            latents_mean = torch.tensor(self.vae.config.latents_mean).view(1, 4, 1, 1)
        if (
            hasattr(self.vae.config, "latents_std")
            and self.vae.config.latents_std is not None
        ):
            latents_std = torch.tensor(self.vae.config.latents_std).view(1, 4, 1, 1)

        # Offload text encoder if `enable_model_cpu_offload` was enabled
        if hasattr(self, "final_offload_hook") and self.final_offload_hook is not None:
            self.text_encoder_2.to("cpu")
            torch.cuda.empty_cache()

        image = image.to(device=device, dtype=dtype)

        batch_size = batch_size * num_images_per_prompt

        if image.shape[1] == 4:
            init_latents = image

        else:
            # make sure the VAE is in float32 mode, as it overflows in float16
            if self.vae.config.force_upcast:
                image = image.float()
                self.vae.to(dtype=torch.float32)

            if isinstance(generator, list) and len(generator) != batch_size:
                raise ValueError(
                    f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
                    f" size of {batch_size}. Make sure the batch size matches the length of the generators."
                )
            elif isinstance(generator, list):
                if image.shape[0] < batch_size and batch_size % image.shape[0] == 0:
                    image = torch.cat([image] * (batch_size // image.shape[0]), dim=0)
                elif image.shape[0] < batch_size and batch_size % image.shape[0] != 0:
                    raise ValueError(
                        f"Cannot duplicate `image` of batch size {image.shape[0]} to effective batch_size {batch_size} "
                    )

                init_latents = [
                    retrieve_latents(
                        self.vae.encode(image[i : i + 1]), generator=generator[i]
                    )
                    for i in range(batch_size)
                ]
                init_latents = torch.cat(init_latents, dim=0)
            else:
                init_latents = retrieve_latents(
                    self.vae.encode(image), generator=generator
                )

            if self.vae.config.force_upcast:
                self.vae.to(dtype)

            init_latents = init_latents.to(dtype)
            if latents_mean is not None and latents_std is not None:
                latents_mean = latents_mean.to(device=device, dtype=dtype)
                latents_std = latents_std.to(device=device, dtype=dtype)
                init_latents = (
                    (init_latents - latents_mean)
                    * self.vae.config.scaling_factor
                    / latents_std
                )
            else:
                init_latents = self.vae.config.scaling_factor * init_latents

        if (
            batch_size > init_latents.shape[0]
            and batch_size % init_latents.shape[0] == 0
        ):
            # expand init_latents for batch_size
            additional_image_per_prompt = batch_size // init_latents.shape[0]
            init_latents = torch.cat(
                [init_latents] * additional_image_per_prompt, dim=0
            )
        elif (
            batch_size > init_latents.shape[0]
            and batch_size % init_latents.shape[0] != 0
        ):
            raise ValueError(
                f"Cannot duplicate `image` of batch size {init_latents.shape[0]} to {batch_size} text prompts."
            )
        else:
            init_latents = torch.cat([init_latents], dim=0)

        if add_noise:
            shape = init_latents.shape
            noise = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
            # get latents
            init_latents = self.scheduler.add_noise(init_latents, noise, timestep)

        latents = init_latents

        return latents

    def prepare_control_image(
        self,
        image,
        width,
        height,
        batch_size,
        num_images_per_prompt,
        device,
        dtype,
        do_classifier_free_guidance=False,
        num_empty_images=0,  # for concat in batch like ImageDream
    ):
        """
        Accepts either:
          - regular RGB-like images -> preprocess via VaeImageProcessor, or
          - native 6-channel Plücker tensors (B,6,H,W) or (6,H,W) -> pass through without normalization
        """
        assert hasattr(
            self, "control_image_processor"
        ), "control_image_processor is not initialized"

        # Fast path: native 6-channel tensor
        if isinstance(image, torch.Tensor):
            if image.dim() == 3 and image.shape[0] == 6:
                image = image.unsqueeze(0)  # (1,6,H,W)
            if image.dim() == 4 and image.shape[1] == 6:
                ctrl = image.to(device=device, dtype=torch.float32)
                if num_empty_images > 0:
                    ctrl = torch.cat(
                        [ctrl, torch.zeros_like(ctrl[:num_empty_images])], dim=0
                    )

                image_batch_size = ctrl.shape[0]
                repeat_by = (
                    batch_size if image_batch_size == 1 else num_images_per_prompt
                )  # always 1 per control
                ctrl = ctrl.repeat_interleave(repeat_by, dim=0)
                ctrl = ctrl.to(device=device, dtype=dtype)

                if do_classifier_free_guidance:
                    ctrl = torch.cat([ctrl] * 2)

                return ctrl

        # Fallback: treat as regular image(s)
        image = self.control_image_processor.preprocess(
            image, height=height, width=width
        ).to(dtype=torch.float32)

        if num_empty_images > 0:
            image = torch.cat(
                [image, torch.zeros_like(image[:num_empty_images])], dim=0
            )

        image_batch_size = image.shape[0]

        if image_batch_size == 1:
            repeat_by = batch_size
        else:
            # image batch size is the same as prompt batch size
            repeat_by = num_images_per_prompt  # always 1 for control image

        image = image.repeat_interleave(repeat_by, dim=0)

        image = image.to(device=device, dtype=dtype)

        if do_classifier_free_guidance:
            image = torch.cat([image] * 2)

        return image
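    # Illustrative note (not in the original file): the fast path above lets a
    # caller pass raw Plücker-ray conditioning directly, e.g. a tensor of shape
    # (6, H, W) or (B, 6, H, W), skipping the [0, 1] RGB normalization, while
    # PIL/array inputs still go through `control_image_processor.preprocess`.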

    @torch.no_grad()
    def __call__(
        self,
        prompt: Union[str, List[str]] = None,
        prompt_2: Optional[Union[str, List[str]]] = None,
        # --- Task 1: NEW (reference-only prompt used only for the ref cache pass) ---
        reference_prompt: Optional[Union[str, List[str]]] = None,
        reference_prompt_2: Optional[Union[str, List[str]]] = None,
        # -----------------------------------------------------------------------------
        height: Optional[int] = None,
        width: Optional[int] = None,
        num_inference_steps: int = 50,
        timesteps: List[int] = None,
        denoising_end: Optional[float] = None,
        guidance_scale: float = 5.0,
        negative_prompt: Optional[Union[str, List[str]]] = None,
        negative_prompt_2: Optional[Union[str, List[str]]] = None,
        num_images_per_prompt: Optional[int] = 1,
        eta: float = 0.0,
        generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
        latents: Optional[torch.FloatTensor] = None,
        prompt_embeds: Optional[torch.FloatTensor] = None,
        negative_prompt_embeds: Optional[torch.FloatTensor] = None,
        pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
        negative_pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
        ip_adapter_image: Optional[PipelineImageInput] = None,
        ip_adapter_image_embeds: Optional[List[torch.FloatTensor]] = None,
        output_type: Optional[str] = "pil",
        return_dict: bool = True,
        cross_attention_kwargs: Optional[Dict[str, Any]] = None,
        guidance_rescale: float = 0.0,
        original_size: Optional[Tuple[int, int]] = None,
        crops_coords_top_left: Tuple[int, int] = (0, 0),
        target_size: Optional[Tuple[int, int]] = None,
        negative_original_size: Optional[Tuple[int, int]] = None,
        negative_crops_coords_top_left: Tuple[int, int] = (0, 0),
        negative_target_size: Optional[Tuple[int, int]] = None,
        clip_skip: Optional[int] = None,
        callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None,
        callback_on_step_end_tensor_inputs: List[str] = ["latents"],
        # NEW
        mv_scale: float = 1.0,
        # Camera or geometry condition
        control_image: Optional[PipelineImageInput] = None,
        control_conditioning_scale: Optional[float] = 1.0,
        control_conditioning_factor: float = 1.0,
        # Image condition
        reference_image: Optional[PipelineImageInput] = None,
        reference_conditioning_scale: Optional[float] = 1.0,
        **kwargs,
    ):
        r"""
        Function invoked when calling the pipeline for generation.

        Args:
            prompt (`str` or `List[str]`, *optional*):
                The main prompt(s) for generation.
            prompt_2 (`str` or `List[str]`, *optional*):
                Prompt(s) for the second text encoder. Falls back to `prompt` if None.
            reference_prompt (`str` or `List[str]`, *optional*):
                Prompt used **only** during the one-shot reference UNet pass that caches identity features
                from `reference_image`. If None or empty, falls back to the positive branch of the main prompt
                (original behavior).
            reference_prompt_2 (`str` or `List[str]`, *optional*):
                Second-encoder counterpart for `reference_prompt`.
            ... (other arguments unchanged) ...
        """

        callback = kwargs.pop("callback", None)
        callback_steps = kwargs.pop("callback_steps", None)

        if callback is not None:
            deprecate(
                "callback",
                "1.0.0",
                "Passing `callback` as an input argument to `__call__` is deprecated, consider using `callback_on_step_end`",
            )
        if callback_steps is not None:
            deprecate(
                "callback_steps",
                "1.0.0",
                "Passing `callback_steps` as an input argument to `__call__` is deprecated, consider using `callback_on_step_end`",
            )

        # 0. Default height and width to unet
        height = height or self.default_sample_size * self.vae_scale_factor
        width = width or self.default_sample_size * self.vae_scale_factor

        original_size = original_size or (height, width)
        target_size = target_size or (height, width)

        # 1. Check inputs. Raise error if not correct
        self.check_inputs(
            prompt,
            prompt_2,
            height,
            width,
            callback_steps,
            negative_prompt,
            negative_prompt_2,
            prompt_embeds,
            negative_prompt_embeds,
            pooled_prompt_embeds,
            negative_pooled_prompt_embeds,
            ip_adapter_image,
            ip_adapter_image_embeds,
            callback_on_step_end_tensor_inputs,
        )

        self._guidance_scale = guidance_scale
        self._guidance_rescale = guidance_rescale
        self._clip_skip = clip_skip
        self._cross_attention_kwargs = cross_attention_kwargs
        self._denoising_end = denoising_end
        self._interrupt = False

        # 2. Define call parameters
        if prompt is not None and isinstance(prompt, str):
            batch_size = 1
        elif prompt is not None and isinstance(prompt, list):
            batch_size = len(prompt)
        else:
            batch_size = prompt_embeds.shape[0]

        device = self._execution_device

        # 3. Encode input prompt
        lora_scale = (
            self.cross_attention_kwargs.get("scale", None)
            if self.cross_attention_kwargs is not None
            else None
        )

        (
            prompt_embeds,
            negative_prompt_embeds,
            pooled_prompt_embeds,
            negative_pooled_prompt_embeds,
        ) = self.encode_prompt(
            prompt=prompt,
            prompt_2=prompt_2,
            device=device,
            num_images_per_prompt=num_images_per_prompt,
            do_classifier_free_guidance=self.do_classifier_free_guidance,
            negative_prompt=negative_prompt,
            negative_prompt_2=negative_prompt_2,
            prompt_embeds=prompt_embeds,
            negative_prompt_embeds=negative_prompt_embeds,
            pooled_prompt_embeds=pooled_prompt_embeds,
            negative_pooled_prompt_embeds=negative_pooled_prompt_embeds,
            lora_scale=lora_scale,
            clip_skip=self.clip_skip,
        )

        # 4. Prepare timesteps
        timesteps, num_inference_steps = retrieve_timesteps(
            self.scheduler, num_inference_steps, device, timesteps
        )

        # 5. Prepare latent variables
        num_channels_latents = self.unet.config.in_channels
        latents = self.prepare_latents(
            batch_size * num_images_per_prompt,
            num_channels_latents,
            height,
            width,
            prompt_embeds.dtype,
            device,
            generator,
            latents,
        )

        # 6. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline
        extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)

        # 7. Prepare added time ids & embeddings
        add_text_embeds = pooled_prompt_embeds
        if self.text_encoder_2 is None:
            text_encoder_projection_dim = int(pooled_prompt_embeds.shape[-1])
        else:
            text_encoder_projection_dim = self.text_encoder_2.config.projection_dim

        add_time_ids = self._get_add_time_ids(
            original_size,
            crops_coords_top_left,
            target_size,
            dtype=prompt_embeds.dtype,
            text_encoder_projection_dim=text_encoder_projection_dim,
        )
        if negative_original_size is not None and negative_target_size is not None:
            negative_add_time_ids = self._get_add_time_ids(
                negative_original_size,
                negative_crops_coords_top_left,
                negative_target_size,
                dtype=prompt_embeds.dtype,
                text_encoder_projection_dim=text_encoder_projection_dim,
            )
        else:
            negative_add_time_ids = add_time_ids

        if self.do_classifier_free_guidance:
            prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds], dim=0)
            add_text_embeds = torch.cat(
                [negative_pooled_prompt_embeds, add_text_embeds], dim=0
            )
            add_time_ids = torch.cat([negative_add_time_ids, add_time_ids], dim=0)

        prompt_embeds = prompt_embeds.to(device)
        add_text_embeds = add_text_embeds.to(device)
        add_time_ids = add_time_ids.to(device).repeat(
            batch_size * num_images_per_prompt, 1
        )

        if ip_adapter_image is not None or ip_adapter_image_embeds is not None:
            image_embeds = self.prepare_ip_adapter_image_embeds(
                ip_adapter_image,
                ip_adapter_image_embeds,
                device,
                batch_size * num_images_per_prompt,
                self.do_classifier_free_guidance,
            )

        # Preprocess reference image (required)
        reference_image = self.image_processor.preprocess(reference_image)
        reference_latents = self.prepare_image_latents(
            reference_image,
            timesteps[:1].repeat(batch_size * num_images_per_prompt),  # no use
            batch_size,
            1,
            prompt_embeds.dtype,
            device,
            generator,
            add_noise=False,
        )

        with torch.no_grad():
            ref_timesteps = torch.zeros_like(timesteps[0])
|
| 529 |
+
ref_hidden_states = {}
|
| 530 |
+
|
| 531 |
+
# reference-only prompt support (Task 1)
|
| 532 |
+
def _first_or_none(x):
|
| 533 |
+
if x is None:
|
| 534 |
+
return None
|
| 535 |
+
if isinstance(x, list) and len(x) > 0:
|
| 536 |
+
return x[0]
|
| 537 |
+
return x
|
| 538 |
+
|
| 539 |
+
rp = _first_or_none(reference_prompt)
|
| 540 |
+
rp2 = _first_or_none(reference_prompt_2)
|
| 541 |
+
have_ref_prompt = (rp is not None and str(rp).strip() != "") or (
|
| 542 |
+
rp2 is not None and str(rp2).strip() != ""
|
| 543 |
+
)
|
| 544 |
+
|
| 545 |
+
if have_ref_prompt:
|
| 546 |
+
ref_prompt_embeds, _, ref_pooled_prompt_embeds, _ = self.encode_prompt(
|
| 547 |
+
prompt=rp or prompt,
|
| 548 |
+
prompt_2=rp2 or prompt_2,
|
| 549 |
+
device=device,
|
| 550 |
+
num_images_per_prompt=1,
|
| 551 |
+
do_classifier_free_guidance=False,
|
| 552 |
+
prompt_embeds=None,
|
| 553 |
+
negative_prompt_embeds=None,
|
| 554 |
+
pooled_prompt_embeds=None,
|
| 555 |
+
negative_pooled_prompt_embeds=None,
|
| 556 |
+
lora_scale=(
|
| 557 |
+
self.cross_attention_kwargs.get("scale", None)
|
| 558 |
+
if self.cross_attention_kwargs is not None
|
| 559 |
+
else None
|
| 560 |
+
),
|
| 561 |
+
clip_skip=self.clip_skip,
|
| 562 |
+
)
|
| 563 |
+
else:
|
| 564 |
+
if self.do_classifier_free_guidance:
|
| 565 |
+
ref_prompt_embeds = prompt_embeds[-1:].clone()
|
| 566 |
+
ref_pooled_prompt_embeds = add_text_embeds[-1:].clone()
|
| 567 |
+
else:
|
| 568 |
+
ref_prompt_embeds = prompt_embeds[:1].clone()
|
| 569 |
+
ref_pooled_prompt_embeds = add_text_embeds[:1].clone()
|
| 570 |
+
|
| 571 |
+
self.unet(
|
| 572 |
+
reference_latents,
|
| 573 |
+
ref_timesteps,
|
| 574 |
+
encoder_hidden_states=ref_prompt_embeds,
|
| 575 |
+
added_cond_kwargs={
|
| 576 |
+
"text_embeds": ref_pooled_prompt_embeds,
|
| 577 |
+
"time_ids": add_time_ids[-1:],
|
| 578 |
+
},
|
| 579 |
+
cross_attention_kwargs={
|
| 580 |
+
"cache_hidden_states": ref_hidden_states,
|
| 581 |
+
"use_mv": False,
|
| 582 |
+
"use_ref": False,
|
| 583 |
+
},
|
| 584 |
+
return_dict=False,
|
| 585 |
+
)
|
| 586 |
+
ref_hidden_states = {
|
| 587 |
+
k: v.repeat_interleave(num_images_per_prompt, dim=0)
|
| 588 |
+
for k, v in ref_hidden_states.items()
|
| 589 |
+
}
|
| 590 |
+
if self.do_classifier_free_guidance:
|
| 591 |
+
ref_hidden_states = {
|
| 592 |
+
k: torch.cat([torch.zeros_like(v), v], dim=0)
|
| 593 |
+
for k, v in ref_hidden_states.items()
|
| 594 |
+
}
|
| 595 |
+
|
| 596 |
+
cross_attention_kwargs = {
|
| 597 |
+
"mv_scale": mv_scale,
|
| 598 |
+
"ref_hidden_states": ref_hidden_states,
|
| 599 |
+
"ref_scale": reference_conditioning_scale,
|
| 600 |
+
**(self.cross_attention_kwargs or {}),
|
| 601 |
+
}
|
| 602 |
+
|
| 603 |
+
# ------------- control image (Task 2 supports 6ch pass-through) -------------
|
| 604 |
+
control_image_feature = self.prepare_control_image(
|
| 605 |
+
image=control_image,
|
| 606 |
+
width=width,
|
| 607 |
+
height=height,
|
| 608 |
+
batch_size=batch_size * num_images_per_prompt,
|
| 609 |
+
num_images_per_prompt=1, # NOTE: always 1 for control images
|
| 610 |
+
device=device,
|
| 611 |
+
dtype=latents.dtype,
|
| 612 |
+
do_classifier_free_guidance=self.do_classifier_free_guidance,
|
| 613 |
+
).to(device=device, dtype=latents.dtype)
|
| 614 |
+
|
| 615 |
+
adapter_state = self.cond_encoder(control_image_feature)
|
| 616 |
+
for i, state in enumerate(adapter_state):
|
| 617 |
+
adapter_state[i] = state * control_conditioning_scale
|
| 618 |
+
# ---------------------------------------------------------------------------
|
| 619 |
+
|
| 620 |
+
# 8. Denoising loop
|
| 621 |
+
num_warmup_steps = max(
|
| 622 |
+
len(timesteps) - num_inference_steps * self.scheduler.order, 0
|
| 623 |
+
)
|
| 624 |
+
|
| 625 |
+
# 8.1 Apply denoising_end
|
| 626 |
+
if (
|
| 627 |
+
self.denoising_end is not None
|
| 628 |
+
and isinstance(self.denoising_end, float)
|
| 629 |
+
and self.denoising_end > 0
|
| 630 |
+
and self.denoising_end < 1
|
| 631 |
+
):
|
| 632 |
+
discrete_timestep_cutoff = int(
|
| 633 |
+
round(
|
| 634 |
+
self.scheduler.config.num_train_timesteps
|
| 635 |
+
- (self.denoising_end * self.scheduler.config.num_train_timesteps)
|
| 636 |
+
)
|
| 637 |
+
)
|
| 638 |
+
num_inference_steps = len(
|
| 639 |
+
list(filter(lambda ts: ts >= discrete_timestep_cutoff, timesteps))
|
| 640 |
+
)
|
| 641 |
+
timesteps = timesteps[:num_inference_steps]
|
| 642 |
+
|
| 643 |
+
# 9. Optionally get Guidance Scale Embedding
|
| 644 |
+
timestep_cond = None
|
| 645 |
+
if self.unet.config.time_cond_proj_dim is not None:
|
| 646 |
+
guidance_scale_tensor = torch.tensor(self.guidance_scale - 1).repeat(
|
| 647 |
+
batch_size * num_images_per_prompt
|
| 648 |
+
)
|
| 649 |
+
timestep_cond = self.get_guidance_scale_embedding(
|
| 650 |
+
guidance_scale_tensor, embedding_dim=self.unet.config.time_cond_proj_dim
|
| 651 |
+
).to(device=device, dtype=latents.dtype)
|
| 652 |
+
|
| 653 |
+
self._num_timesteps = len(timesteps)
|
| 654 |
+
with self.progress_bar(total=num_inference_steps) as progress_bar:
|
| 655 |
+
for i, t in enumerate(timesteps):
|
| 656 |
+
if self.interrupt:
|
| 657 |
+
continue
|
| 658 |
+
|
| 659 |
+
# expand the latents if we are doing classifier free guidance
|
| 660 |
+
latent_model_input = (
|
| 661 |
+
torch.cat([latents] * 2)
|
| 662 |
+
if self.do_classifier_free_guidance
|
| 663 |
+
else latents
|
| 664 |
+
)
|
| 665 |
+
|
| 666 |
+
latent_model_input = self.scheduler.scale_model_input(
|
| 667 |
+
latent_model_input, t
|
| 668 |
+
)
|
| 669 |
+
|
| 670 |
+
added_cond_kwargs = {
|
| 671 |
+
"text_embeds": add_text_embeds,
|
| 672 |
+
"time_ids": add_time_ids,
|
| 673 |
+
}
|
| 674 |
+
if ip_adapter_image is not None or ip_adapter_image_embeds is not None:
|
| 675 |
+
added_cond_kwargs["image_embeds"] = image_embeds
|
| 676 |
+
|
| 677 |
+
if i < int(num_inference_steps * control_conditioning_factor):
|
| 678 |
+
down_intrablock_additional_residuals = [
|
| 679 |
+
state.clone() for state in adapter_state
|
| 680 |
+
]
|
| 681 |
+
else:
|
| 682 |
+
down_intrablock_additional_residuals = None
|
| 683 |
+
|
| 684 |
+
# predict the noise residual
|
| 685 |
+
noise_pred = self.unet(
|
| 686 |
+
latent_model_input,
|
| 687 |
+
t,
|
| 688 |
+
encoder_hidden_states=prompt_embeds,
|
| 689 |
+
timestep_cond=timestep_cond,
|
| 690 |
+
cross_attention_kwargs=cross_attention_kwargs,
|
| 691 |
+
down_intrablock_additional_residuals=down_intrablock_additional_residuals,
|
| 692 |
+
added_cond_kwargs=added_cond_kwargs,
|
| 693 |
+
return_dict=False,
|
| 694 |
+
)[0]
|
| 695 |
+
|
| 696 |
+
# perform guidance
|
| 697 |
+
if self.do_classifier_free_guidance:
|
| 698 |
+
noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
|
| 699 |
+
noise_pred = noise_pred_uncond + self.guidance_scale * (
|
| 700 |
+
noise_pred_text - noise_pred_uncond
|
| 701 |
+
)
|
| 702 |
+
|
| 703 |
+
if self.do_classifier_free_guidance and self.guidance_rescale > 0.0:
|
| 704 |
+
noise_pred = rescale_noise_cfg(
|
| 705 |
+
noise_pred,
|
| 706 |
+
noise_pred_text,
|
| 707 |
+
guidance_rescale=self.guidance_rescale,
|
| 708 |
+
)
|
| 709 |
+
|
| 710 |
+
# compute the previous noisy sample x_t -> x_t-1
|
| 711 |
+
latents_dtype = latents.dtype
|
| 712 |
+
latents = self.scheduler.step(
|
| 713 |
+
noise_pred, t, latents, **extra_step_kwargs, return_dict=False
|
| 714 |
+
)[0]
|
| 715 |
+
if latents.dtype != latents_dtype:
|
| 716 |
+
if torch.backends.mps.is_available():
|
| 717 |
+
# some platforms (eg. apple mps) misbehave due to a pytorch bug: https://github.com/pytorch/pytorch/pull/99272
|
| 718 |
+
latents = latents.to(latents_dtype)
|
| 719 |
+
|
| 720 |
+
if callback_on_step_end is not None:
|
| 721 |
+
callback_kwargs = {}
|
| 722 |
+
for k in callback_on_step_end_tensor_inputs:
|
| 723 |
+
callback_kwargs[k] = locals()[k]
|
| 724 |
+
callback_outputs = callback_on_step_end(self, i, t, callback_kwargs)
|
| 725 |
+
|
| 726 |
+
latents = callback_outputs.pop("latents", latents)
|
| 727 |
+
prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds)
|
| 728 |
+
negative_prompt_embeds = callback_outputs.pop(
|
| 729 |
+
"negative_prompt_embeds", negative_prompt_embeds
|
| 730 |
+
)
|
| 731 |
+
add_text_embeds = callback_outputs.pop(
|
| 732 |
+
"add_text_embeds", add_text_embeds
|
| 733 |
+
)
|
| 734 |
+
negative_pooled_prompt_embeds = callback_outputs.pop(
|
| 735 |
+
"negative_pooled_prompt_embeds", negative_pooled_prompt_embeds
|
| 736 |
+
)
|
| 737 |
+
add_time_ids = callback_outputs.pop("add_time_ids", add_time_ids)
|
| 738 |
+
negative_add_time_ids = callback_outputs.pop(
|
| 739 |
+
"negative_add_time_ids", negative_add_time_ids
|
| 740 |
+
)
|
| 741 |
+
|
| 742 |
+
# call the callback, if provided
|
| 743 |
+
if i == len(timesteps) - 1 or (
|
| 744 |
+
(i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0
|
| 745 |
+
):
|
| 746 |
+
progress_bar.update()
|
| 747 |
+
if callback is not None and i % callback_steps == 0:
|
| 748 |
+
step_idx = i // getattr(self.scheduler, "order", 1)
|
| 749 |
+
callback(step_idx, t, latents)
|
| 750 |
+
|
| 751 |
+
if not output_type == "latent":
|
| 752 |
+
# make sure the VAE is in float32 mode, as it overflows in float16
|
| 753 |
+
needs_upcasting = (
|
| 754 |
+
self.vae.dtype == torch.float16 and self.vae.config.force_upcast
|
| 755 |
+
)
|
| 756 |
+
|
| 757 |
+
if needs_upcasting:
|
| 758 |
+
self.upcast_vae()
|
| 759 |
+
latents = latents.to(
|
| 760 |
+
next(iter(self.vae.post_quant_conv.parameters())).dtype
|
| 761 |
+
)
|
| 762 |
+
elif latents.dtype != self.vae.dtype:
|
| 763 |
+
if torch.backends.mps.is_available():
|
| 764 |
+
# some platforms (eg. apple mps) misbehave due to a pytorch bug: https://github.com/pytorch/pytorch/pull/99272
|
| 765 |
+
self.vae = self.vae.to(latents.dtype)
|
| 766 |
+
|
| 767 |
+
# unscale/denormalize the latents
|
| 768 |
+
# denormalize with the mean and std if available and not None
|
| 769 |
+
has_latents_mean = (
|
| 770 |
+
hasattr(self.vae.config, "latents_mean")
|
| 771 |
+
and self.vae.config.latents_mean is not None
|
| 772 |
+
)
|
| 773 |
+
has_latents_std = (
|
| 774 |
+
hasattr(self.vae.config, "latents_std")
|
| 775 |
+
and self.vae.config.latents_std is not None
|
| 776 |
+
)
|
| 777 |
+
if has_latents_mean and has_latents_std:
|
| 778 |
+
latents_mean = (
|
| 779 |
+
torch.tensor(self.vae.config.latents_mean)
|
| 780 |
+
.view(1, 4, 1, 1)
|
| 781 |
+
.to(latents.device, latents.dtype)
|
| 782 |
+
)
|
| 783 |
+
latents_std = (
|
| 784 |
+
torch.tensor(self.vae.config.latents_std)
|
| 785 |
+
.view(1, 4, 1, 1)
|
| 786 |
+
.to(latents.device, latents.dtype)
|
| 787 |
+
)
|
| 788 |
+
latents = (
|
| 789 |
+
latents * latents_std / self.vae.config.scaling_factor
|
| 790 |
+
+ latents_mean
|
| 791 |
+
)
|
| 792 |
+
else:
|
| 793 |
+
latents = latents / self.vae.config.scaling_factor
|
| 794 |
+
|
| 795 |
+
image = self.vae.decode(latents, return_dict=False)[0]
|
| 796 |
+
|
| 797 |
+
# cast back to fp16 if needed
|
| 798 |
+
if needs_upcasting:
|
| 799 |
+
self.vae.to(dtype=torch.float16)
|
| 800 |
+
else:
|
| 801 |
+
image = latents
|
| 802 |
+
|
| 803 |
+
if not output_type == "latent":
|
| 804 |
+
# apply watermark if available
|
| 805 |
+
if self.watermark is not None:
|
| 806 |
+
image = self.watermark.apply_watermark(image)
|
| 807 |
+
|
| 808 |
+
image = self.image_processor.postprocess(image, output_type=output_type)
|
| 809 |
+
|
| 810 |
+
# Offload all models
|
| 811 |
+
self.maybe_free_model_hooks()
|
| 812 |
+
|
| 813 |
+
if not return_dict:
|
| 814 |
+
return (image,)
|
| 815 |
+
|
| 816 |
+
return StableDiffusionXLPipelineOutput(images=image)
|
| 817 |
+
|
| 818 |
+
### NEW: adapters ###
|
| 819 |
+
def _init_custom_adapter(
|
| 820 |
+
self,
|
| 821 |
+
# Multi-view adapter
|
| 822 |
+
num_views: int = 1,
|
| 823 |
+
self_attn_processor: Any = DecoupledMVRowSelfAttnProcessor2_0,
|
| 824 |
+
# Condition encoder
|
| 825 |
+
cond_in_channels: int = 6,
|
| 826 |
+
# For training
|
| 827 |
+
copy_attn_weights: bool = True,
|
| 828 |
+
zero_init_module_keys: List[str] = [],
|
| 829 |
+
):
|
| 830 |
+
# Condition encoder
|
| 831 |
+
self.cond_encoder = T2IAdapter(
|
| 832 |
+
in_channels=cond_in_channels,
|
| 833 |
+
channels=(320, 640, 1280, 1280),
|
| 834 |
+
num_res_blocks=2,
|
| 835 |
+
downscale_factor=16,
|
| 836 |
+
adapter_type="full_adapter_xl",
|
| 837 |
+
)
|
| 838 |
+
|
| 839 |
+
# set custom attn processor for multi-view attention and image cross-attention
|
| 840 |
+
self.unet: UNet2DConditionModel
|
| 841 |
+
set_unet_2d_condition_attn_processor(
|
| 842 |
+
self.unet,
|
| 843 |
+
set_self_attn_proc_func=lambda name, hs, cad, ap: self_attn_processor(
|
| 844 |
+
query_dim=hs,
|
| 845 |
+
inner_dim=hs,
|
| 846 |
+
num_views=num_views,
|
| 847 |
+
name=name,
|
| 848 |
+
use_mv=True,
|
| 849 |
+
use_ref=True,
|
| 850 |
+
),
|
| 851 |
+
set_cross_attn_proc_func=lambda name, hs, cad, ap: self_attn_processor(
|
| 852 |
+
query_dim=hs,
|
| 853 |
+
inner_dim=hs,
|
| 854 |
+
num_views=num_views,
|
| 855 |
+
name=name,
|
| 856 |
+
use_mv=False,
|
| 857 |
+
use_ref=False,
|
| 858 |
+
),
|
| 859 |
+
)
|
| 860 |
+
|
| 861 |
+
# copy decoupled attention weights from original unet
|
| 862 |
+
if copy_attn_weights:
|
| 863 |
+
state_dict = self.unet.state_dict()
|
| 864 |
+
for key in state_dict.keys():
|
| 865 |
+
if "_mv" in key:
|
| 866 |
+
compatible_key = key.replace("_mv", "").replace("processor.", "")
|
| 867 |
+
elif "_ref" in key:
|
| 868 |
+
compatible_key = key.replace("_ref", "").replace("processor.", "")
|
| 869 |
+
else:
|
| 870 |
+
compatible_key = key
|
| 871 |
+
|
| 872 |
+
is_zero_init_key = any([k in key for k in zero_init_module_keys])
|
| 873 |
+
if is_zero_init_key:
|
| 874 |
+
state_dict[key] = torch.zeros_like(state_dict[compatible_key])
|
| 875 |
+
else:
|
| 876 |
+
state_dict[key] = state_dict[compatible_key].clone()
|
| 877 |
+
self.unet.load_state_dict(state_dict)
|
| 878 |
+
|
| 879 |
+
def _load_custom_adapter(self, state_dict):
|
| 880 |
+
self.unet.load_state_dict(state_dict, strict=False)
|
| 881 |
+
self.cond_encoder.load_state_dict(state_dict, strict=False)
|
| 882 |
+
|
| 883 |
+
def _save_custom_adapter(
|
| 884 |
+
self,
|
| 885 |
+
include_keys: Optional[List[str]] = None,
|
| 886 |
+
exclude_keys: Optional[List[str]] = None,
|
| 887 |
+
):
|
| 888 |
+
def include_fn(k):
|
| 889 |
+
is_included = False
|
| 890 |
+
|
| 891 |
+
if include_keys is not None:
|
| 892 |
+
is_included = is_included or any([key in k for key in include_keys])
|
| 893 |
+
if exclude_keys is not None:
|
| 894 |
+
is_included = is_included and not any(
|
| 895 |
+
[key in k for key in exclude_keys]
|
| 896 |
+
)
|
| 897 |
+
|
| 898 |
+
return is_included
|
| 899 |
+
|
| 900 |
+
state_dict = {k: v for k, v in self.unet.state_dict().items() if include_fn(k)}
|
| 901 |
+
state_dict.update(self.cond_encoder.state_dict())
|
| 902 |
+
|
| 903 |
+
return state_dict
|
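A minimal usage sketch for the reference-prompt feature added above (illustrative only: the names `pipe`, `ref_image`, and `control_maps` are assumptions, not part of this diff, and the exact set of required arguments depends on the parts of the file not shown here):

# Hedged sketch: call the modified __call__ with a reference-only prompt.
# `pipe` is assumed to be an already-loaded instance of this pipeline.
images = pipe(
    prompt="photorealistic character, high quality",   # main multi-view prompt
    reference_image=ref_image,                         # identity source image
    reference_prompt="studio photo of the character",  # used ONLY for the cached reference pass
    control_image=control_maps,                        # e.g. 6-channel geometry maps (cond_in_channels=6)
    num_images_per_prompt=6,                           # one sample per view
    reference_conditioning_scale=1.0,
    guidance_scale=3.0,
    num_inference_steps=30,
).images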
comfyui-mvadapter/mvadapter/schedulers/ShiftSNRSchedulerKarras.py
ADDED
@@ -0,0 +1,120 @@
from typing import Any

import torch

from .scheduler_utils import SNR_to_betas, compute_snr


class ShiftSNRSchedulerKarras:
    """
    Wraps a Diffusers scheduler to apply SNR shifting to its noise schedule and
    rebuilds a DPMSolverMultistepScheduler that supports Karras sigmas.

    Usage:
        new_sched = ShiftSNRSchedulerKarras.from_scheduler(
            noise_scheduler=base_sched,
            shift_mode="interpolated",  # or "default"
            shift_scale=8.0,
            scheduler_class=DPMSolverMultistepScheduler,  # usually this
        )
    """

    # Supported modes for how the SNR shift is applied
    SHIFT_MODES = ["default", "interpolated"]

    def __init__(
        self,
        noise_scheduler: Any,
        timesteps: Any,
        shift_scale: float,
        scheduler_class: Any,
    ):
        # original scheduler (used only as a reference/config source)
        self.noise_scheduler = noise_scheduler
        # tensor of timesteps to compute SNR/betas on
        self.timesteps = timesteps
        # scale by which to divide the SNR (e.g., 8.0)
        self.shift_scale = shift_scale
        # the scheduler class to construct for output (e.g., DPMSolverMultistepScheduler)
        self.scheduler_class = scheduler_class

    def _get_shift_scheduler(self):
        """
        Apply a uniform SNR shift: snr' = snr / shift_scale.
        Then convert to betas and rebuild the scheduler with Karras enabled.
        """
        snr = compute_snr(self.timesteps, self.noise_scheduler)
        shifted_betas = SNR_to_betas(snr / self.shift_scale)

        return self.scheduler_class.from_config(
            self.noise_scheduler.config,
            trained_betas=shifted_betas.numpy(),
            # Enable Karras sigmas in the rebuilt scheduler
            algorithm_type="dpmsolver++",
            use_karras_sigmas=True,
        )

    def _get_interpolated_shift_scheduler(self):
        """
        Interpolate SNR in log-space between the original and the shifted SNR
        as timesteps progress. This tends to preserve early behavior and
        gradually apply the shift later in the schedule.
        """
        snr = compute_snr(self.timesteps, self.noise_scheduler)
        shifted_snr = snr / self.shift_scale

        # Interpolate in log-space from original -> shifted across timesteps
        weighting = self.timesteps.float() / (
            self.noise_scheduler.config.num_train_timesteps - 1
        )
        interpolated_snr = torch.exp(
            torch.log(snr) * (1 - weighting) + torch.log(shifted_snr) * weighting
        )

        shifted_betas = SNR_to_betas(interpolated_snr)

        return self.scheduler_class.from_config(
            self.noise_scheduler.config,
            trained_betas=shifted_betas.numpy(),
            # Enable Karras sigmas in the rebuilt scheduler
            algorithm_type="dpmsolver++",
            use_karras_sigmas=True,
        )

    @classmethod
    def from_scheduler(
        cls,
        noise_scheduler: Any,
        shift_mode: str = "default",
        timesteps: Any = None,
        shift_scale: float = 1.0,
        scheduler_class: Any = None,
    ):
        """
        Factory that returns a NEW scheduler instance with the shifted betas applied.

        Args:
            noise_scheduler: the original Diffusers scheduler (used for config & base betas)
            shift_mode: "default" or "interpolated"
            timesteps: tensor of timesteps to evaluate SNR on; if None, uses full training range
            shift_scale: divide SNR by this value (e.g., 8.0)
            scheduler_class: class to construct for the output scheduler (defaults to original class)
        """
        if timesteps is None:
            timesteps = torch.arange(0, noise_scheduler.config.num_train_timesteps)
        if scheduler_class is None:
            scheduler_class = noise_scheduler.__class__

        wrapper = cls(
            noise_scheduler=noise_scheduler,
            timesteps=timesteps,
            shift_scale=shift_scale,
            scheduler_class=scheduler_class,
        )

        if shift_mode == "default":
            return wrapper._get_shift_scheduler()
        elif shift_mode == "interpolated":
            return wrapper._get_interpolated_shift_scheduler()
        else:
            raise ValueError(f"Unknown shift_mode: {shift_mode}")
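A short usage sketch for the class above (assumes the diffusers package is installed; `pipe` is an illustrative, already-loaded pipeline, not part of this diff):

# Swap a pipeline's scheduler for the SNR-shifted variant with Karras sigmas.
from diffusers import DPMSolverMultistepScheduler

pipe.scheduler = ShiftSNRSchedulerKarras.from_scheduler(
    noise_scheduler=pipe.scheduler,  # must expose alphas_cumprod and a config
    shift_mode="interpolated",       # or "default" for a uniform shift
    shift_scale=8.0,
    scheduler_class=DPMSolverMultistepScheduler,
)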
comfyui-mvadapter/mvadapter/schedulers/__pycache__/ShiftSNRSchedulerKarras.cpython-312.pyc
ADDED
Binary file (4.86 kB).

comfyui-mvadapter/mvadapter/schedulers/__pycache__/scheduler_utils.cpython-312.pyc
ADDED
Binary file (3.78 kB).

comfyui-mvadapter/mvadapter/schedulers/__pycache__/scheduling_shift_snr.cpython-312.pyc
ADDED
Binary file (5.9 kB).
comfyui-mvadapter/mvadapter/schedulers/scheduler_utils.py
ADDED
@@ -0,0 +1,70 @@
import torch


def get_sigmas(noise_scheduler, timesteps, n_dim=4, dtype=torch.float32, device=None):
    sigmas = noise_scheduler.sigmas.to(device=device, dtype=dtype)
    schedule_timesteps = noise_scheduler.timesteps.to(device)
    timesteps = timesteps.to(device)
    step_indices = [(schedule_timesteps == t).nonzero().item() for t in timesteps]
    sigma = sigmas[step_indices].flatten()
    while len(sigma.shape) < n_dim:
        sigma = sigma.unsqueeze(-1)
    return sigma


def SNR_to_betas(snr):
    """
    Converts SNR to betas.
    """
    # SNR is defined from the cumulative alphas: snr = alpha_t^2 / (1 - alpha_t^2),
    # so invert to recover alpha_t, then the cumulative product, then per-step betas.
    alpha_t = (snr / (1 + snr)) ** 0.5
    alphas_cumprod = alpha_t**2
    alphas = alphas_cumprod / torch.cat(
        [torch.ones(1, device=snr.device), alphas_cumprod[:-1]]
    )
    betas = 1 - alphas
    return betas


def compute_snr(timesteps, noise_scheduler):
    """
    Computes SNR as per Min-SNR-Diffusion-Training/guided_diffusion/gaussian_diffusion.py at 521b624bd70c67cee4bdf49225915f5
    """
    alphas_cumprod = noise_scheduler.alphas_cumprod
    sqrt_alphas_cumprod = alphas_cumprod**0.5
    sqrt_one_minus_alphas_cumprod = (1.0 - alphas_cumprod) ** 0.5

    # Expand the tensors.
    # Adapted from Min-SNR-Diffusion-Training/guided_diffusion/gaussian_diffusion.py at 521b624bd70c67cee4bdf49225915f5
    sqrt_alphas_cumprod = sqrt_alphas_cumprod.to(device=timesteps.device)[
        timesteps
    ].float()
    while len(sqrt_alphas_cumprod.shape) < len(timesteps.shape):
        sqrt_alphas_cumprod = sqrt_alphas_cumprod[..., None]
    alpha = sqrt_alphas_cumprod.expand(timesteps.shape)

    sqrt_one_minus_alphas_cumprod = sqrt_one_minus_alphas_cumprod.to(
        device=timesteps.device
    )[timesteps].float()
    while len(sqrt_one_minus_alphas_cumprod.shape) < len(timesteps.shape):
        sqrt_one_minus_alphas_cumprod = sqrt_one_minus_alphas_cumprod[..., None]
    sigma = sqrt_one_minus_alphas_cumprod.expand(timesteps.shape)

    # Compute SNR.
    snr = (alpha / sigma) ** 2
    return snr


def compute_alpha(timesteps, noise_scheduler):
    alphas_cumprod = noise_scheduler.alphas_cumprod
    sqrt_alphas_cumprod = alphas_cumprod**0.5
    sqrt_alphas_cumprod = sqrt_alphas_cumprod.to(device=timesteps.device)[
        timesteps
    ].float()
    while len(sqrt_alphas_cumprod.shape) < len(timesteps.shape):
        sqrt_alphas_cumprod = sqrt_alphas_cumprod[..., None]
    alpha = sqrt_alphas_cumprod.expand(timesteps.shape)

    return alpha
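A quick round-trip check of the two helpers above: `SNR_to_betas` inverts `compute_snr`, so feeding a scheduler's own SNR back through it should recover its betas (a sketch, assuming the diffusers package is installed):

import torch
from diffusers import DDPMScheduler

sched = DDPMScheduler()  # default linear-beta schedule, 1000 steps
t = torch.arange(0, sched.config.num_train_timesteps)
snr = compute_snr(t, sched)  # alphas_cumprod / (1 - alphas_cumprod)
betas = SNR_to_betas(snr)    # cumprod ratios -> per-step betas
assert torch.allclose(betas, sched.betas, atol=1e-4)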
comfyui-mvadapter/mvadapter/schedulers/scheduling_shift_snr.py
ADDED
@@ -0,0 +1,140 @@
from typing import Any

import torch

from .scheduler_utils import SNR_to_betas, compute_snr


class ShiftSNRScheduler:
    SHIFT_MODES = ["default", "interpolated"]

    def __init__(
        self,
        noise_scheduler: Any,
        timesteps: Any,
        shift_scale: float,
        scheduler_class: Any,
    ):
        self.noise_scheduler = noise_scheduler
        self.timesteps = timesteps
        self.shift_scale = shift_scale
        self.scheduler_class = scheduler_class

    def _get_shift_scheduler(self):
        """
        Prepare scheduler for shifted betas.

        :return: A scheduler object configured with shifted betas
        """
        snr = compute_snr(self.timesteps, self.noise_scheduler)
        shifted_betas = SNR_to_betas(snr / self.shift_scale)

        return self.scheduler_class.from_config(
            self.noise_scheduler.config, trained_betas=shifted_betas.numpy()
        )

    def _get_interpolated_shift_scheduler(self):
        """
        Prepare scheduler for shifted betas and interpolate with the original betas in log space.

        :return: A scheduler object configured with interpolated shifted betas
        """
        snr = compute_snr(self.timesteps, self.noise_scheduler)
        shifted_snr = snr / self.shift_scale

        weighting = self.timesteps.float() / (
            self.noise_scheduler.config.num_train_timesteps - 1
        )
        interpolated_snr = torch.exp(
            torch.log(snr) * (1 - weighting) + torch.log(shifted_snr) * weighting
        )

        shifted_betas = SNR_to_betas(interpolated_snr)

        return self.scheduler_class.from_config(
            self.noise_scheduler.config, trained_betas=shifted_betas.numpy()
        )

    @classmethod
    def from_scheduler(
        cls,
        noise_scheduler: Any,
        shift_mode: str = "default",
        timesteps: Any = None,
        shift_scale: float = 1.0,
        scheduler_class: Any = None,
    ):
        # Check input
        if timesteps is None:
            timesteps = torch.arange(0, noise_scheduler.config.num_train_timesteps)
        if scheduler_class is None:
            scheduler_class = noise_scheduler.__class__

        # Create scheduler
        shift_scheduler = cls(
            noise_scheduler=noise_scheduler,
            timesteps=timesteps,
            shift_scale=shift_scale,
            scheduler_class=scheduler_class,
        )

        if shift_mode == "default":
            return shift_scheduler._get_shift_scheduler()
        elif shift_mode == "interpolated":
            return shift_scheduler._get_interpolated_shift_scheduler()
        else:
            raise ValueError(f"Unknown shift_mode: {shift_mode}")


if __name__ == "__main__":
    """
    Compare the alpha values for different noise schedulers.
    """
    import matplotlib.pyplot as plt
    from diffusers import DDPMScheduler

    from .scheduler_utils import compute_alpha

    # Base
    timesteps = torch.arange(0, 1000)
    noise_scheduler_base = DDPMScheduler.from_pretrained(
        "runwayml/stable-diffusion-v1-5", subfolder="scheduler"
    )
    alpha = compute_alpha(timesteps, noise_scheduler_base)
    plt.plot(timesteps.numpy(), alpha.numpy(), label="Base")

    # Kolors
    num_train_timesteps_ = 1100
    timesteps_ = torch.arange(0, num_train_timesteps_)
    noise_kwargs = {"beta_end": 0.014, "num_train_timesteps": num_train_timesteps_}
    noise_scheduler_kolors = DDPMScheduler.from_config(
        noise_scheduler_base.config, **noise_kwargs
    )
    alpha = compute_alpha(timesteps_, noise_scheduler_kolors)
    plt.plot(timesteps_.numpy(), alpha.numpy(), label="Kolors")

    # Shift betas
    shift_scale = 8.0
    noise_scheduler_shift = ShiftSNRScheduler.from_scheduler(
        noise_scheduler_base, shift_mode="default", shift_scale=shift_scale
    )
    alpha = compute_alpha(timesteps, noise_scheduler_shift)
    plt.plot(timesteps.numpy(), alpha.numpy(), label="Shift Noise (scale 8.0)")

    # Shift betas (interpolated)
    noise_scheduler_inter = ShiftSNRScheduler.from_scheduler(
        noise_scheduler_base, shift_mode="interpolated", shift_scale=shift_scale
    )
    alpha = compute_alpha(timesteps, noise_scheduler_inter)
    plt.plot(timesteps.numpy(), alpha.numpy(), label="Interpolated (scale 8.0)")

    # ZeroSNR
    noise_scheduler = DDPMScheduler.from_config(
        noise_scheduler_base.config, rescale_betas_zero_snr=True
    )
    alpha = compute_alpha(timesteps, noise_scheduler)
    plt.plot(timesteps.numpy(), alpha.numpy(), label="ZeroSNR")

    plt.legend()
    plt.grid()
    plt.savefig("check_alpha.png")
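Worth noting: the log-space interpolation in `_get_interpolated_shift_scheduler` has a simple closed form, snr / shift_scale**weighting, so the shift ramps from nothing at t = 0 to the full 1/shift_scale at the final timestep. A tiny self-contained check (values are illustrative):

import torch

snr = torch.tensor([100.0, 10.0, 1.0])
w = torch.tensor([0.0, 0.5, 1.0])  # weighting = t / (num_train_timesteps - 1)
s = 8.0                            # shift_scale
interpolated = torch.exp(torch.log(snr) * (1 - w) + torch.log(snr / s) * w)
assert torch.allclose(interpolated, snr / s**w)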
comfyui-mvadapter/mvadapter/utils/__init__.py
ADDED
@@ -0,0 +1,3 @@
from .camera import get_camera, get_orthogonal_camera
from .geometry import get_plucker_embeds_from_cameras_ortho
from .saving import make_image_grid, tensor_to_image