saliacoel committed
Commit f93d68a · verified · 1 Parent(s): 5f4a806

Upload 111 files

This view is limited to 50 files because it contains too many changes.
Files changed (50)
  1. .gitattributes +29 -0
  2. comfyui-mvadapter/.github/workflows/publish.yml +25 -0
  3. comfyui-mvadapter/BACKUP_nodes.py +843 -0
  4. comfyui-mvadapter/LICENSE +201 -0
  5. comfyui-mvadapter/README.md +88 -0
  6. comfyui-mvadapter/__init__.py +45 -0
  7. comfyui-mvadapter/__pycache__/__init__.cpython-312.pyc +0 -0
  8. comfyui-mvadapter/__pycache__/nodes.cpython-312.pyc +0 -0
  9. comfyui-mvadapter/__pycache__/nodes_local_mv.cpython-312.pyc +0 -0
  10. comfyui-mvadapter/__pycache__/utils.cpython-312.pyc +0 -0
  11. comfyui-mvadapter/assets/CustomLoraModelLoader.png +0 -0
  12. comfyui-mvadapter/assets/comfyui_i2mv.png +3 -0
  13. comfyui-mvadapter/assets/comfyui_i2mv_lora.png +3 -0
  14. comfyui-mvadapter/assets/comfyui_i2mv_multiple_loras.jpg +3 -0
  15. comfyui-mvadapter/assets/comfyui_i2mv_view_selector.png +3 -0
  16. comfyui-mvadapter/assets/comfyui_ldm_vae.png +0 -0
  17. comfyui-mvadapter/assets/comfyui_model_makeup.png +0 -0
  18. comfyui-mvadapter/assets/comfyui_t2mv.png +3 -0
  19. comfyui-mvadapter/assets/comfyui_t2mv_controlnet.png +3 -0
  20. comfyui-mvadapter/assets/comfyui_t2mv_lora.png +3 -0
  21. comfyui-mvadapter/assets/comfyui_t2mv_multiple_loras.jpg +3 -0
  22. comfyui-mvadapter/assets/demo/scribbles/scribble_0.png +0 -0
  23. comfyui-mvadapter/assets/demo/scribbles/scribble_1.png +0 -0
  24. comfyui-mvadapter/assets/demo/scribbles/scribble_2.png +0 -0
  25. comfyui-mvadapter/assets/demo/scribbles/scribble_3.png +0 -0
  26. comfyui-mvadapter/assets/demo/scribbles/scribble_4.png +0 -0
  27. comfyui-mvadapter/assets/demo/scribbles/scribble_5.png +0 -0
  28. comfyui-mvadapter/cache/stable-diffusion-v1-inference.yaml +70 -0
  29. comfyui-mvadapter/mvadapter/__init__.py +0 -0
  30. comfyui-mvadapter/mvadapter/__pycache__/__init__.cpython-312.pyc +0 -0
  31. comfyui-mvadapter/mvadapter/loaders/__init__.py +1 -0
  32. comfyui-mvadapter/mvadapter/loaders/__pycache__/__init__.cpython-312.pyc +0 -0
  33. comfyui-mvadapter/mvadapter/loaders/__pycache__/custom_adapter.cpython-312.pyc +0 -0
  34. comfyui-mvadapter/mvadapter/loaders/custom_adapter.py +98 -0
  35. comfyui-mvadapter/mvadapter/models/__init__.py +0 -0
  36. comfyui-mvadapter/mvadapter/models/__pycache__/__init__.cpython-312.pyc +0 -0
  37. comfyui-mvadapter/mvadapter/models/__pycache__/attention_processor.cpython-312.pyc +0 -0
  38. comfyui-mvadapter/mvadapter/models/attention_processor.py +377 -0
  39. comfyui-mvadapter/mvadapter/pipelines/__pycache__/pipeline_mvadapter_i2mv_sd.cpython-312.pyc +0 -0
  40. comfyui-mvadapter/mvadapter/pipelines/__pycache__/pipeline_mvadapter_i2mv_sdxl.cpython-312.pyc +0 -0
  41. comfyui-mvadapter/mvadapter/pipelines/__pycache__/pipeline_mvadapter_t2mv_sd.cpython-312.pyc +0 -0
  42. comfyui-mvadapter/mvadapter/pipelines/__pycache__/pipeline_mvadapter_t2mv_sdxl.cpython-312.pyc +0 -0
  43. comfyui-mvadapter/mvadapter/pipelines/pipeline_mvadapter_i2mv_sdxl.py +903 -0
  44. comfyui-mvadapter/mvadapter/schedulers/ShiftSNRSchedulerKarras.py +120 -0
  45. comfyui-mvadapter/mvadapter/schedulers/__pycache__/ShiftSNRSchedulerKarras.cpython-312.pyc +0 -0
  46. comfyui-mvadapter/mvadapter/schedulers/__pycache__/scheduler_utils.cpython-312.pyc +0 -0
  47. comfyui-mvadapter/mvadapter/schedulers/__pycache__/scheduling_shift_snr.cpython-312.pyc +0 -0
  48. comfyui-mvadapter/mvadapter/schedulers/scheduler_utils.py +70 -0
  49. comfyui-mvadapter/mvadapter/schedulers/scheduling_shift_snr.py +140 -0
  50. comfyui-mvadapter/mvadapter/utils/__init__.py +3 -0
.gitattributes CHANGED
@@ -33,3 +33,32 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
33
  *.zip filter=lfs diff=lfs merge=lfs -text
34
  *.zst filter=lfs diff=lfs merge=lfs -text
35
  *tfevents* filter=lfs diff=lfs merge=lfs -text
36
+ comfyui-mvadapter/assets/comfyui_i2mv_lora.png filter=lfs diff=lfs merge=lfs -text
37
+ comfyui-mvadapter/assets/comfyui_i2mv_multiple_loras.jpg filter=lfs diff=lfs merge=lfs -text
38
+ comfyui-mvadapter/assets/comfyui_i2mv_view_selector.png filter=lfs diff=lfs merge=lfs -text
39
+ comfyui-mvadapter/assets/comfyui_i2mv.png filter=lfs diff=lfs merge=lfs -text
40
+ comfyui-mvadapter/assets/comfyui_t2mv_controlnet.png filter=lfs diff=lfs merge=lfs -text
41
+ comfyui-mvadapter/assets/comfyui_t2mv_lora.png filter=lfs diff=lfs merge=lfs -text
42
+ comfyui-mvadapter/assets/comfyui_t2mv_multiple_loras.jpg filter=lfs diff=lfs merge=lfs -text
43
+ comfyui-mvadapter/assets/comfyui_t2mv.png filter=lfs diff=lfs merge=lfs -text
44
+ comfyui-salia/assets/images/boy0.png filter=lfs diff=lfs merge=lfs -text
45
+ comfyui-salia/assets/images/boy1.png filter=lfs diff=lfs merge=lfs -text
46
+ comfyui-salia/assets/images/boy2.png filter=lfs diff=lfs merge=lfs -text
47
+ comfyui-salia/assets/images/boy3.png filter=lfs diff=lfs merge=lfs -text
48
+ comfyui-salia/assets/images/boy4.png filter=lfs diff=lfs merge=lfs -text
49
+ comfyui-salia/assets/images/boy5.png filter=lfs diff=lfs merge=lfs -text
50
+ comfyui-salia/assets/images/girl0.png filter=lfs diff=lfs merge=lfs -text
51
+ comfyui-salia/assets/images/girl1.png filter=lfs diff=lfs merge=lfs -text
52
+ comfyui-salia/assets/images/girl2.png filter=lfs diff=lfs merge=lfs -text
53
+ comfyui-salia/assets/images/girl3.png filter=lfs diff=lfs merge=lfs -text
54
+ comfyui-salia/assets/images/girl4.png filter=lfs diff=lfs merge=lfs -text
55
+ comfyui-salia/assets/images/girl5.png filter=lfs diff=lfs merge=lfs -text
56
+ comfyui-salia/assets/images/hair_L_Bound_Braided.png filter=lfs diff=lfs merge=lfs -text
57
+ comfyui-salia/assets/images/hair_L_Bound.png filter=lfs diff=lfs merge=lfs -text
58
+ comfyui-salia/assets/images/hair_L_Loose.png filter=lfs diff=lfs merge=lfs -text
59
+ comfyui-salia/assets/images/hair_M_Bound_Braided.png filter=lfs diff=lfs merge=lfs -text
60
+ comfyui-salia/assets/images/hair_M_Bound.png filter=lfs diff=lfs merge=lfs -text
61
+ comfyui-salia/assets/images/hair_M_Loose.png filter=lfs diff=lfs merge=lfs -text
62
+ comfyui-salia/assets/images/hair_S_Bound_Braided.png filter=lfs diff=lfs merge=lfs -text
63
+ comfyui-salia/assets/images/hair_S_Bound.png filter=lfs diff=lfs merge=lfs -text
64
+ comfyui-salia/assets/images/hair_S_Loose.png filter=lfs diff=lfs merge=lfs -text
comfyui-mvadapter/.github/workflows/publish.yml ADDED
@@ -0,0 +1,25 @@
1
+ name: Publish to Comfy registry
2
+ on:
3
+ workflow_dispatch:
4
+ push:
5
+ branches:
6
+ - main
7
+ paths:
8
+ - "pyproject.toml"
9
+
10
+ permissions:
11
+ issues: write
12
+
13
+ jobs:
14
+ publish-node:
15
+ name: Publish Custom Node to registry
16
+ runs-on: ubuntu-latest
17
+ if: ${{ github.repository_owner == 'huanngzh' }}
18
+ steps:
19
+ - name: Check out code
20
+ uses: actions/checkout@v4
21
+ - name: Publish Custom Node
22
+ uses: Comfy-Org/publish-node-action@v1
23
+ with:
24
+ ## Add your own personal access token to your Github Repository secrets and reference it here.
25
+ personal_access_token: ${{ secrets.REGISTRY_ACCESS_TOKEN }}
comfyui-mvadapter/BACKUP_nodes.py ADDED
@@ -0,0 +1,843 @@
1
+ # Adapted from https://github.com/Limitex/ComfyUI-Diffusers/blob/main/nodes.py
2
+ import copy
3
+ import os
4
+ import torch
5
+ from safetensors.torch import load_file
6
+ from torchvision import transforms
7
+ from .utils import (
8
+ SCHEDULERS,
9
+ PIPELINES,
10
+ MVADAPTERS,
11
+ vae_pt_to_vae_diffuser,
12
+ convert_images_to_tensors,
13
+ convert_tensors_to_images,
14
+ prepare_camera_embed,
15
+ preprocess_image,
16
+ )
17
+ from comfy.model_management import get_torch_device
18
+ import folder_paths
19
+
20
+ from diffusers import StableDiffusionXLPipeline, AutoencoderKL, ControlNetModel
21
+ from transformers import AutoModelForImageSegmentation # <-- restored
22
+
23
+ # ADDED: import DPMSolverMultistepScheduler for DPM++ Karras
24
+ from diffusers import DPMSolverMultistepScheduler
25
+
26
+ from .mvadapter.pipelines.pipeline_mvadapter_t2mv_sdxl import MVAdapterT2MVSDXLPipeline
27
+ from .mvadapter.schedulers.scheduling_shift_snr import ShiftSNRScheduler
28
+
29
+ # ADDED: import your new Karras-enabled shift scheduler (file sits next to scheduling_shift_snr.py)
30
+ from .mvadapter.schedulers.ShiftSNRSchedulerKarras import ShiftSNRSchedulerKarras
31
+
32
+
33
+
34
+ class DiffusersMVPipelineLoader:
35
+ def __init__(self):
36
+ self.hf_dir = folder_paths.get_folder_paths("diffusers")[0]
37
+ self.dtype = torch.float16
38
+
39
+ @classmethod
40
+ def INPUT_TYPES(s):
41
+ return {
42
+ "required": {
43
+ "ckpt_name": (
44
+ "STRING",
45
+ {"default": "stabilityai/stable-diffusion-xl-base-1.0"},
46
+ ),
47
+ "pipeline_name": (
48
+ list(PIPELINES.keys()),
49
+ {"default": "MVAdapterT2MVSDXLPipeline"},
50
+ ),
51
+ }
52
+ }
53
+
54
+ RETURN_TYPES = (
55
+ "PIPELINE",
56
+ "AUTOENCODER",
57
+ "SCHEDULER",
58
+ )
59
+
60
+ FUNCTION = "create_pipeline"
61
+
62
+ CATEGORY = "MV-Adapter"
63
+
64
+ def create_pipeline(self, ckpt_name, pipeline_name):
65
+ pipeline_class = PIPELINES[pipeline_name]
66
+ pipe = pipeline_class.from_pretrained(
67
+ pretrained_model_name_or_path=ckpt_name,
68
+ torch_dtype=self.dtype,
69
+ cache_dir=self.hf_dir,
70
+ )
71
+ return (pipe, pipe.vae, pipe.scheduler)
72
+
73
+
74
+ class LdmPipelineLoader:
75
+ def __init__(self):
76
+ self.hf_dir = folder_paths.get_folder_paths("diffusers")[0]
77
+ self.dtype = torch.float16
78
+
79
+ @classmethod
80
+ def INPUT_TYPES(s):
81
+ return {
82
+ "required": {
83
+ "ckpt_name": (folder_paths.get_filename_list("checkpoints"),),
84
+ "pipeline_name": (
85
+ list(PIPELINES.keys()),
86
+ {"default": "MVAdapterT2MVSDXLPipeline"},
87
+ ),
88
+ }
89
+ }
90
+
91
+ RETURN_TYPES = (
92
+ "PIPELINE",
93
+ "AUTOENCODER",
94
+ "SCHEDULER",
95
+ )
96
+
97
+ FUNCTION = "create_pipeline"
98
+
99
+ CATEGORY = "MV-Adapter"
100
+
101
+ def create_pipeline(self, ckpt_name, pipeline_name):
102
+ pipeline_class = PIPELINES[pipeline_name]
103
+
104
+ pipe = pipeline_class.from_single_file(
105
+ pretrained_model_link_or_path=folder_paths.get_full_path(
106
+ "checkpoints", ckpt_name
107
+ ),
108
+ torch_dtype=self.dtype,
109
+ cache_dir=self.hf_dir,
110
+ )
111
+
112
+ return (pipe, pipe.vae, pipe.scheduler)
113
+
114
+
115
+ class DiffusersMVVaeLoader:
116
+ def __init__(self):
117
+ self.hf_dir = folder_paths.get_folder_paths("diffusers")[0]
118
+ self.dtype = torch.float16
119
+
120
+ @classmethod
121
+ def INPUT_TYPES(s):
122
+ return {
123
+ "required": {
124
+ "vae_name": (
125
+ "STRING",
126
+ {"default": "madebyollin/sdxl-vae-fp16-fix"},
127
+ ),
128
+ }
129
+ }
130
+
131
+ RETURN_TYPES = ("AUTOENCODER",)
132
+
133
+ FUNCTION = "create_pipeline"
134
+
135
+ CATEGORY = "MV-Adapter"
136
+
137
+ def create_pipeline(self, vae_name):
138
+ vae = AutoencoderKL.from_pretrained(
139
+ pretrained_model_name_or_path=vae_name,
140
+ torch_dtype=self.dtype,
141
+ cache_dir=self.hf_dir,
142
+ )
143
+
144
+ return (vae,)
145
+
146
+
147
+ class LdmVaeLoader:
148
+ def __init__(self):
149
+ self.dtype = torch.float16
150
+
151
+ @classmethod
152
+ def INPUT_TYPES(s):
153
+ return {
154
+ "required": {
155
+ "vae_name": (folder_paths.get_filename_list("vae"),),
156
+ "upcast_fp32": ("BOOLEAN", {"default": True}),
157
+ },
158
+ }
159
+
160
+ RETURN_TYPES = ("AUTOENCODER",)
161
+
162
+ FUNCTION = "create_pipeline"
163
+
164
+ CATEGORY = "MV-Adapter"
165
+
166
+ def create_pipeline(self, vae_name, upcast_fp32):
167
+ vae = vae_pt_to_vae_diffuser(
168
+ folder_paths.get_full_path("vae", vae_name), force_upcast=upcast_fp32
169
+ ).to(self.dtype)
170
+
171
+ return (vae,)
172
+
173
+
174
+ class DiffusersMVSchedulerLoader:
175
+ def __init__(self):
176
+ self.hf_dir = folder_paths.get_folder_paths("diffusers")[0]
177
+ self.dtype = torch.float16
178
+
179
+ @classmethod
180
+ def INPUT_TYPES(s):
181
+ return {
182
+ "required": {
183
+ "pipeline": ("PIPELINE",),
184
+ "scheduler_name": (list(SCHEDULERS.keys()),),
185
+ "shift_snr": ("BOOLEAN", {"default": True}),
186
+ "shift_mode": (
187
+ list(ShiftSNRScheduler.SHIFT_MODES),
188
+ {"default": "interpolated"},
189
+ ),
190
+ "shift_scale": (
191
+ "FLOAT",
192
+ {"default": 8.0, "min": 0.0, "max": 50.0, "step": 1.0},
193
+ ),
194
+ }
195
+ }
196
+
197
+ RETURN_TYPES = ("SCHEDULER",)
198
+
199
+ FUNCTION = "load_scheduler"
200
+
201
+ CATEGORY = "MV-Adapter"
202
+
203
+ def load_scheduler(
204
+ self, pipeline, scheduler_name, shift_snr, shift_mode, shift_scale
205
+ ):
206
+ scheduler = SCHEDULERS[scheduler_name].from_config(
207
+ pipeline.scheduler.config, torch_dtype=self.dtype
208
+ )
209
+ if shift_snr:
210
+ scheduler = ShiftSNRScheduler.from_scheduler(
211
+ scheduler,
212
+ shift_mode=shift_mode,
213
+ shift_scale=shift_scale,
214
+ scheduler_class=scheduler.__class__,
215
+ )
216
+ return (scheduler,)
217
+
218
+
219
+ # ADDED: Karras version — same inputs/outputs, but always returns a DPM++ (Karras) scheduler.
220
+ class DiffusersMVSchedulerLoaderKarras:
221
+ def __init__(self):
222
+ self.hf_dir = folder_paths.get_folder_paths("diffusers")[0]
223
+ self.dtype = torch.float16
224
+
225
+ @classmethod
226
+ def INPUT_TYPES(s):
227
+ return {
228
+ "required": {
229
+ "pipeline": ("PIPELINE",),
230
+ "scheduler_name": (list(SCHEDULERS.keys()),),
231
+ "shift_snr": ("BOOLEAN", {"default": True}),
232
+ "shift_mode": (
233
+ list(ShiftSNRSchedulerKarras.SHIFT_MODES),
234
+ {"default": "interpolated"},
235
+ ),
236
+ "shift_scale": (
237
+ "FLOAT",
238
+ {"default": 8.0, "min": 0.0, "max": 50.0, "step": 1.0},
239
+ ),
240
+ }
241
+ }
242
+
243
+ RETURN_TYPES = ("SCHEDULER",)
244
+
245
+ FUNCTION = "load_scheduler"
246
+
247
+ CATEGORY = "MV-Adapter"
248
+
249
+ def load_scheduler(
250
+ self, pipeline, scheduler_name, shift_snr, shift_mode, shift_scale
251
+ ):
252
+ # Build a base scheduler from the pipeline config (kept for parity with original UI),
253
+ # then *replace* it with DPM++ (Karras). If SNR shift is requested, apply via your Karras class.
254
+ base_sched = SCHEDULERS[scheduler_name].from_config(
255
+ pipeline.scheduler.config, torch_dtype=self.dtype
256
+ )
257
+
258
+ # Always use DPM++ Karras:
259
+ if shift_snr:
260
+ # Apply your Karras-enabled Shift SNR on top, and force DPM++ class to guarantee Karras works.
261
+ scheduler = ShiftSNRSchedulerKarras.from_scheduler(
262
+ base_sched,
263
+ shift_mode=shift_mode,
264
+ shift_scale=shift_scale,
265
+ scheduler_class=DPMSolverMultistepScheduler,
266
+ )
267
+ else:
268
+ # No SNR shift requested: just return DPM++ with Karras sigmas
269
+ scheduler = DPMSolverMultistepScheduler.from_config(
270
+ pipeline.scheduler.config,
271
+ algorithm_type="dpmsolver++",
272
+ use_karras_sigmas=True,
273
+ torch_dtype=self.dtype,
274
+ )
275
+
276
+ return (scheduler,)
277
+
278
+
279
+ class CustomLoraModelLoader:
280
+ def __init__(self):
281
+ self.loaded_lora = None
282
+
283
+ @classmethod
284
+ def INPUT_TYPES(s):
285
+ return {
286
+ "required": {
287
+ "pipeline": ("PIPELINE",),
288
+ "lora_name": (folder_paths.get_filename_list("loras"),),
289
+ "strength_model": (
290
+ "FLOAT",
291
+ {"default": 1.0, "min": -100.0, "max": 100.0, "step": 0.01},
292
+ ),
293
+ "enable": (
294
+ "BOOLEAN",
295
+ {"default": True},
296
+ ),
297
+ "last_lora_node": (
298
+ "BOOLEAN",
299
+ {"default": True},
300
+ ),
301
+ }
302
+ }
303
+
304
+ RETURN_TYPES = ("PIPELINE",)
305
+ FUNCTION = "load_lora"
306
+ CATEGORY = "MV-Adapter"
307
+
308
+ def load_lora(self, pipeline, lora_name, strength_model, enable, last_lora_node):
309
+ if not hasattr(pipeline, "loaded_loras"):
310
+ pipeline.loaded_loras = []
311
+
312
+ lora_path = folder_paths.get_full_path("loras", lora_name)
313
+ lora_dir = os.path.dirname(lora_path)
314
+ lora_name = os.path.basename(lora_path)
315
+ lora = None
316
+ if enable:
317
+ if self.loaded_lora is not None:
318
+ if self.loaded_lora[0] == lora_path:
319
+ lora = self.loaded_lora[1]
320
+ else:
321
+ temp = self.loaded_lora
322
+ pipeline.delete_adapters(temp[1])
323
+ pipeline.loaded_loras = [(name, strength) for (name, strength) in pipeline.loaded_loras if name != temp[1]]
324
+ self.loaded_lora = None
325
+
326
+ if lora is None:
327
+ adapter_name = lora_name.rsplit(".", 1)[0]
328
+ pipeline.load_lora_weights(
329
+ lora_dir, weight_name=lora_name, adapter_name=adapter_name
330
+ )
331
+ pipeline.set_adapters(adapter_name, strength_model)
332
+ self.loaded_lora = (lora_path, adapter_name)
333
+ lora = adapter_name
334
+
335
+ pipeline.loaded_loras.append((adapter_name, strength_model))
336
+ else:
337
+ # Delete the loaded lora
338
+ if self.loaded_lora is not None:
339
+ temp = self.loaded_lora
340
+ pipeline.delete_adapters(temp[1])
341
+ pipeline.loaded_loras = [(name, strength) for (name, strength) in pipeline.loaded_loras if name != temp[1]]
342
+ self.loaded_lora = None
343
+
344
+ if last_lora_node:
345
+ adapter_names = [x[0] for x in pipeline.loaded_loras]
346
+ strengths = [x[1] for x in pipeline.loaded_loras]
347
+ pipeline.set_adapters(adapter_names, strengths)
348
+
349
+ print(adapter_names)
350
+
351
+ return (pipeline,)
352
+
353
+
354
+ class ControlNetModelLoader:
355
+ def __init__(self):
356
+ self.loaded_controlnet = None
357
+ self.dtype = torch.float16
358
+ self.torch_device = get_torch_device()
359
+ self.hf_dir = folder_paths.get_folder_paths("diffusers")[0]
360
+
361
+ @classmethod
362
+ def INPUT_TYPES(s):
363
+ return {
364
+ "required": {
365
+ "pipeline": ("PIPELINE",),
366
+ "controlnet_name": (
367
+ "STRING",
368
+ {"default": "xinsir/controlnet-scribble-sdxl-1.0"},
369
+ ),
370
+ }
371
+ }
372
+
373
+ RETURN_TYPES = ("PIPELINE",)
374
+ FUNCTION = "load_controlnet"
375
+ CATEGORY = "MV-Adapter"
376
+
377
+ def load_controlnet(self, pipeline, controlnet_name):
378
+ controlnet = None
379
+ if self.loaded_controlnet is not None:
380
+ if self.loaded_controlnet == controlnet_name:
381
+ controlnet = self.loaded_controlnet
382
+ else:
383
+ del pipeline.controlnet
384
+ self.loaded_controlnet = None
385
+
386
+ if controlnet is None:
387
+ controlnet = ControlNetModel.from_pretrained(
388
+ controlnet_name, cache_dir=self.hf_dir, torch_dtype=self.dtype
389
+ )
390
+ pipeline.controlnet = controlnet
391
+ pipeline.controlnet.to(device=self.torch_device, dtype=self.dtype)
392
+
393
+ self.loaded_controlnet = controlnet_name
394
+ controlnet = controlnet_name
395
+
396
+ return (pipeline,)
397
+
398
+
399
+ class DiffusersMVModelMakeup:
400
+ def __init__(self):
401
+ self.hf_dir = folder_paths.get_folder_paths("diffusers")[0]
402
+ self.torch_device = get_torch_device()
403
+ self.dtype = torch.float16
404
+
405
+ @classmethod
406
+ def INPUT_TYPES(s):
407
+ return {
408
+ "required": {
409
+ "pipeline": ("PIPELINE",),
410
+ "scheduler": ("SCHEDULER",),
411
+ "autoencoder": ("AUTOENCODER",),
412
+ "load_mvadapter": ("BOOLEAN", {"default": True}),
413
+ "adapter_path": ("STRING", {"default": "huanngzh/mv-adapter"}),
414
+ "adapter_name": (
415
+ MVADAPTERS,
416
+ {"default": "mvadapter_t2mv_sdxl.safetensors"},
417
+ ),
418
+ "num_views": ("INT", {"default": 6, "min": 1, "max": 12}),
419
+ },
420
+ "optional": {
421
+ "enable_vae_slicing": ("BOOLEAN", {"default": True}),
422
+ "enable_vae_tiling": ("BOOLEAN", {"default": False}),
423
+ },
424
+ }
425
+
426
+ RETURN_TYPES = ("PIPELINE",)
427
+
428
+ FUNCTION = "makeup_pipeline"
429
+
430
+ CATEGORY = "MV-Adapter"
431
+
432
+ def makeup_pipeline(
433
+ self,
434
+ pipeline,
435
+ scheduler,
436
+ autoencoder,
437
+ load_mvadapter,
438
+ adapter_path,
439
+ adapter_name,
440
+ num_views,
441
+ enable_vae_slicing=True,
442
+ enable_vae_tiling=False,
443
+ ):
444
+ pipeline.vae = autoencoder
445
+ pipeline.scheduler = scheduler
446
+
447
+ if load_mvadapter:
448
+ pipeline.init_custom_adapter(num_views=num_views)
449
+ pipeline.load_custom_adapter(
450
+ adapter_path, weight_name=adapter_name, cache_dir=self.hf_dir
451
+ )
452
+ pipeline.cond_encoder.to(device=self.torch_device, dtype=self.dtype)
453
+
454
+ pipeline = pipeline.to(self.torch_device, self.dtype)
455
+
456
+ if enable_vae_slicing:
457
+ pipeline.enable_vae_slicing()
458
+ if enable_vae_tiling:
459
+ pipeline.enable_vae_tiling()
460
+
461
+ return (pipeline,)
462
+
463
+
464
+ class DiffusersSampler:
465
+ def __init__(self):
466
+ self.torch_device = get_torch_device()
467
+
468
+ @classmethod
469
+ def INPUT_TYPES(s):
470
+ return {
471
+ "required": {
472
+ "pipeline": ("PIPELINE",),
473
+ "prompt": (
474
+ "STRING",
475
+ {"multiline": True, "default": "a photo of a cat"},
476
+ ),
477
+ "negative_prompt": (
478
+ "STRING",
479
+ {
480
+ "multiline": True,
481
+ "default": "watermark, ugly, deformed, noisy, blurry, low contrast",
482
+ },
483
+ ),
484
+ "width": ("INT", {"default": 768, "min": 1, "max": 2048, "step": 1}),
485
+ "height": ("INT", {"default": 768, "min": 1, "max": 2048, "step": 1}),
486
+ "steps": ("INT", {"default": 50, "min": 1, "max": 2000}),
487
+ "cfg": (
488
+ "FLOAT",
489
+ {
490
+ "default": 7.0,
491
+ "min": 0.0,
492
+ "max": 100.0,
493
+ "step": 0.1,
494
+ "round": 0.01,
495
+ },
496
+ ),
497
+ "seed": ("INT", {"default": 0, "min": 0, "max": 0xFFFFFFFFFFFFFFFF}),
498
+ }
499
+ }
500
+
501
+ RETURN_TYPES = ("IMAGE",)
502
+
503
+ FUNCTION = "sample"
504
+
505
+ CATEGORY = "MV-Adapter"
506
+
507
+ def sample(
508
+ self,
509
+ pipeline,
510
+ prompt,
511
+ negative_prompt,
512
+ height,
513
+ width,
514
+ steps,
515
+ cfg,
516
+ seed,
517
+ ):
518
+ images = pipeline(
519
+ prompt=prompt,
520
+ height=height,
521
+ width=width,
522
+ num_inference_steps=steps,
523
+ guidance_scale=cfg,
524
+ negative_prompt=negative_prompt,
525
+ generator=torch.Generator(self.torch_device).manual_seed(seed),
526
+ ).images
527
+ return (convert_images_to_tensors(images),)
528
+
529
+
530
+ class DiffusersMVSampler:
531
+ def __init__(self):
532
+ self.torch_device = get_torch_device()
533
+
534
+ @classmethod
535
+ def INPUT_TYPES(s):
536
+ return {
537
+ "required": {
538
+ "pipeline": ("PIPELINE",),
539
+ "num_views": ("INT", {"default": 6, "min": 1, "max": 12}),
540
+ "prompt": (
541
+ "STRING",
542
+ {"multiline": True, "default": "an astronaut riding a horse"},
543
+ ),
544
+ "negative_prompt": (
545
+ "STRING",
546
+ {
547
+ "multiline": True,
548
+ "default": "watermark, ugly, deformed, noisy, blurry, low contrast",
549
+ },
550
+ ),
551
+ "width": ("INT", {"default": 768, "min": 1, "max": 2048, "step": 1}),
552
+ "height": ("INT", {"default": 768, "min": 1, "max": 2048, "step": 1}),
553
+ "steps": ("INT", {"default": 50, "min": 1, "max": 2000}),
554
+ "cfg": (
555
+ "FLOAT",
556
+ {
557
+ "default": 7.0,
558
+ "min": 0.0,
559
+ "max": 100.0,
560
+ "step": 0.1,
561
+ "round": 0.01,
562
+ },
563
+ ),
564
+ "seed": ("INT", {"default": 0, "min": 0, "max": 0xFFFFFFFFFFFFFFFF}),
565
+ },
566
+ "optional": {
567
+ "reference_image": ("IMAGE",),
568
+ "controlnet_image": ("IMAGE",),
569
+ "controlnet_conditioning_scale": ("FLOAT", {"default": 1.0}),
570
+ "azimuth_degrees": ("LIST", {"default": [0, 45, 90, 180, 270, 315]}),
571
+ },
572
+ }
573
+
574
+ RETURN_TYPES = ("IMAGE",)
575
+
576
+ FUNCTION = "sample"
577
+
578
+ CATEGORY = "MV-Adapter"
579
+
580
+ def sample(
581
+ self,
582
+ pipeline,
583
+ num_views,
584
+ prompt,
585
+ negative_prompt,
586
+ height,
587
+ width,
588
+ steps,
589
+ cfg,
590
+ seed,
591
+ reference_image=None,
592
+ controlnet_image=None,
593
+ controlnet_conditioning_scale=1.0,
594
+ azimuth_degrees=[0, 45, 90, 180, 270, 315],
595
+ ):
596
+ num_views = len(azimuth_degrees)
597
+ control_images = prepare_camera_embed(
598
+ num_views, width, self.torch_device, azimuth_degrees
599
+ )
600
+
601
+ pipe_kwargs = {}
602
+ if reference_image is not None:
603
+ pipe_kwargs.update(
604
+ {
605
+ "reference_image": convert_tensors_to_images(reference_image)[0],
606
+ "reference_conditioning_scale": 1.0,
607
+ }
608
+ )
609
+ if controlnet_image is not None:
610
+ controlnet_image = convert_tensors_to_images(controlnet_image)
611
+ pipe_kwargs.update(
612
+ {
613
+ "controlnet_image": controlnet_image,
614
+ "controlnet_conditioning_scale": controlnet_conditioning_scale,
615
+ }
616
+ )
617
+
618
+ images = pipeline(
619
+ prompt=prompt,
620
+ height=height,
621
+ width=width,
622
+ num_inference_steps=steps,
623
+ guidance_scale=cfg,
624
+ num_images_per_prompt=num_views,
625
+ control_image=control_images,
626
+ control_conditioning_scale=1.0,
627
+ negative_prompt=negative_prompt,
628
+ generator=torch.Generator(self.torch_device).manual_seed(seed),
629
+ cross_attention_kwargs={"num_views": num_views},
630
+ **pipe_kwargs,
631
+ ).images
632
+ return (convert_images_to_tensors(images),)
633
+
634
+
635
+ class BiRefNet:
636
+ def __init__(self):
637
+ self.hf_dir = folder_paths.get_folder_paths("diffusers")[0]
638
+ self.torch_device = get_torch_device()
639
+ self.dtype = torch.float32
640
+
641
+ RETURN_TYPES = ("FUNCTION",)
642
+
643
+ FUNCTION = "load_model_fn"
644
+
645
+ CATEGORY = "MV-Adapter"
646
+
647
+ @classmethod
648
+ def INPUT_TYPES(s):
649
+ return {
650
+ "required": {"ckpt_name": ("STRING", {"default": "briaai/RMBG-2.0"})}
651
+ }
652
+
653
+ def remove_bg(self, image, net, transform, device):
654
+ image_size = image.size
655
+ input_images = transform(image).unsqueeze(0).to(device)
656
+ with torch.no_grad():
657
+ preds = net(input_images)[-1].sigmoid().cpu()
658
+ pred = preds[0].squeeze()
659
+ pred_pil = transforms.ToPILImage()(pred)
660
+ mask = pred_pil.resize(image_size)
661
+ image.putalpha(mask)
662
+ return image
663
+
664
+ def load_model_fn(self, ckpt_name):
665
+ model = AutoModelForImageSegmentation.from_pretrained(
666
+ ckpt_name, trust_remote_code=True, cache_dir=self.hf_dir
667
+ ).to(self.torch_device, self.dtype)
668
+
669
+ transform_image = transforms.Compose(
670
+ [
671
+ transforms.Resize((1024, 1024)),
672
+ transforms.ToTensor(),
673
+ transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
674
+ ]
675
+ )
676
+
677
+ remove_bg_fn = lambda x: self.remove_bg(
678
+ x, model, transform_image, self.torch_device
679
+ )
680
+ return (remove_bg_fn,)
681
+
682
+
683
+ class ImagePreprocessor:
684
+ def __init__(self):
685
+ self.torch_device = get_torch_device()
686
+
687
+ @classmethod
688
+ def INPUT_TYPES(s):
689
+ return {
690
+ "required": {
691
+ "remove_bg_fn": ("FUNCTION",),
692
+ "image": ("IMAGE",),
693
+ "height": ("INT", {"default": 768, "min": 1, "max": 2048, "step": 1}),
694
+ "width": ("INT", {"default": 768, "min": 1, "max": 2048, "step": 1}),
695
+ }
696
+ }
697
+
698
+ RETURN_TYPES = ("IMAGE",)
699
+
700
+ FUNCTION = "process"
701
+
702
+ def process(self, remove_bg_fn, image, height, width):
703
+ images = convert_tensors_to_images(image)
704
+ images = [
705
+ preprocess_image(remove_bg_fn(img.convert("RGB")), height, width)
706
+ for img in images
707
+ ]
708
+
709
+ return (convert_images_to_tensors(images),)
710
+
711
+
712
+ class ControlImagePreprocessor:
713
+ def __init__(self):
714
+ self.torch_device = get_torch_device()
715
+
716
+ @classmethod
717
+ def INPUT_TYPES(s):
718
+ return {
719
+ "required": {
720
+ "front_view": ("IMAGE",),
721
+ "front_right_view": ("IMAGE",),
722
+ "right_view": ("IMAGE",),
723
+ "back_view": ("IMAGE",),
724
+ "left_view": ("IMAGE",),
725
+ "front_left_view": ("IMAGE",),
726
+ "width": ("INT", {"default": 768, "min": 1, "max": 2048, "step": 1}),
727
+ "height": ("INT", {"default": 768, "min": 1, "max": 2048, "step": 1}),
728
+ }
729
+ }
730
+
731
+ RETURN_TYPES = ("IMAGE",)
732
+
733
+ FUNCTION = "process"
734
+
735
+ def process(
736
+ self,
737
+ front_view,
738
+ front_right_view,
739
+ right_view,
740
+ back_view,
741
+ left_view,
742
+ front_left_view,
743
+ width,
744
+ height,
745
+ ):
746
+ images = torch.cat(
747
+ [
748
+ front_view,
749
+ front_right_view,
750
+ right_view,
751
+ back_view,
752
+ left_view,
753
+ front_left_view,
754
+ ],
755
+ dim=0,
756
+ )
757
+ images = convert_tensors_to_images(images)
758
+ images = [img.resize((width, height)).convert("RGB") for img in images]
759
+ return (convert_images_to_tensors(images),)
760
+
761
+
762
+ class ViewSelector:
763
+ def __init__(self):
764
+ pass
765
+
766
+ @classmethod
767
+ def INPUT_TYPES(s):
768
+ return {
769
+ "required": {
770
+ "front_view": ("BOOLEAN", {"default": True}),
771
+ "front_right_view": ("BOOLEAN", {"default": True}),
772
+ "right_view": ("BOOLEAN", {"default": True}),
773
+ "back_view": ("BOOLEAN", {"default": True}),
774
+ "left_view": ("BOOLEAN", {"default": True}),
775
+ "front_left_view": ("BOOLEAN", {"default": True}),
776
+ }
777
+ }
778
+
779
+ RETURN_TYPES = ("LIST",)
780
+ FUNCTION = "process"
781
+ CATEGORY = "MV-Adapter"
782
+
783
+ def process(
784
+ self,
785
+ front_view,
786
+ front_right_view,
787
+ right_view,
788
+ back_view,
789
+ left_view,
790
+ front_left_view,
791
+ ):
792
+ azimuth_deg = []
793
+ if front_view:
794
+ azimuth_deg.append(0)
795
+ if front_right_view:
796
+ azimuth_deg.append(45)
797
+ if right_view:
798
+ azimuth_deg.append(90)
799
+ if back_view:
800
+ azimuth_deg.append(180)
801
+ if left_view:
802
+ azimuth_deg.append(270)
803
+ if front_left_view:
804
+ azimuth_deg.append(315)
805
+
806
+ return (azimuth_deg,)
807
+
808
+
809
+ NODE_CLASS_MAPPINGS = {
810
+ "LdmPipelineLoader": LdmPipelineLoader,
811
+ "LdmVaeLoader": LdmVaeLoader,
812
+ "DiffusersMVPipelineLoader": DiffusersMVPipelineLoader,
813
+ "DiffusersMVVaeLoader": DiffusersMVVaeLoader,
814
+ "DiffusersMVSchedulerLoader": DiffusersMVSchedulerLoader,
815
+ # ADDED: Karras version
816
+ "DiffusersMVSchedulerLoaderKarras": DiffusersMVSchedulerLoaderKarras,
817
+ "DiffusersMVModelMakeup": DiffusersMVModelMakeup,
818
+ "CustomLoraModelLoader": CustomLoraModelLoader,
819
+ "DiffusersMVSampler": DiffusersMVSampler,
820
+ "BiRefNet": BiRefNet,
821
+ "ImagePreprocessor": ImagePreprocessor,
822
+ "ControlNetModelLoader": ControlNetModelLoader,
823
+ "ControlImagePreprocessor": ControlImagePreprocessor,
824
+ "ViewSelector": ViewSelector,
825
+ }
826
+
827
+ NODE_DISPLAY_NAME_MAPPINGS = {
828
+ "LdmPipelineLoader": "LDM Pipeline Loader",
829
+ "LdmVaeLoader": "LDM Vae Loader",
830
+ "DiffusersMVPipelineLoader": "Diffusers MV Pipeline Loader",
831
+ "DiffusersMVVaeLoader": "Diffusers MV Vae Loader",
832
+ "DiffusersMVSchedulerLoader": "Diffusers MV Scheduler Loader",
833
+ # ADDED: Karras version
834
+ "DiffusersMVSchedulerLoaderKarras": "Diffusers MV Scheduler Loader (Karras)",
835
+ "DiffusersMVModelMakeup": "Diffusers MV Model Makeup",
836
+ "CustomLoraModelLoader": "Custom Lora Model Loader",
837
+ "DiffusersMVSampler": "Diffusers MV Sampler",
838
+ "BiRefNet": "BiRefNet",
839
+ "ImagePreprocessor": "Image Preprocessor",
840
+ "ControlNetModelLoader": "ControlNet Model Loader",
841
+ "ControlImagePreprocessor": "Control Image Preprocessor",
842
+ "ViewSelector": "View Selector",
843
+ }
comfyui-mvadapter/LICENSE ADDED
@@ -0,0 +1,201 @@
1
+ Apache License
2
+ Version 2.0, January 2004
3
+ http://www.apache.org/licenses/
4
+
5
+ TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
6
+
7
+ 1. Definitions.
8
+
9
+ "License" shall mean the terms and conditions for use, reproduction,
10
+ and distribution as defined by Sections 1 through 9 of this document.
11
+
12
+ "Licensor" shall mean the copyright owner or entity authorized by
13
+ the copyright owner that is granting the License.
14
+
15
+ "Legal Entity" shall mean the union of the acting entity and all
16
+ other entities that control, are controlled by, or are under common
17
+ control with that entity. For the purposes of this definition,
18
+ "control" means (i) the power, direct or indirect, to cause the
19
+ direction or management of such entity, whether by contract or
20
+ otherwise, or (ii) ownership of fifty percent (50%) or more of the
21
+ outstanding shares, or (iii) beneficial ownership of such entity.
22
+
23
+ "You" (or "Your") shall mean an individual or Legal Entity
24
+ exercising permissions granted by this License.
25
+
26
+ "Source" form shall mean the preferred form for making modifications,
27
+ including but not limited to software source code, documentation
28
+ source, and configuration files.
29
+
30
+ "Object" form shall mean any form resulting from mechanical
31
+ transformation or translation of a Source form, including but
32
+ not limited to compiled object code, generated documentation,
33
+ and conversions to other media types.
34
+
35
+ "Work" shall mean the work of authorship, whether in Source or
36
+ Object form, made available under the License, as indicated by a
37
+ copyright notice that is included in or attached to the work
38
+ (an example is provided in the Appendix below).
39
+
40
+ "Derivative Works" shall mean any work, whether in Source or Object
41
+ form, that is based on (or derived from) the Work and for which the
42
+ editorial revisions, annotations, elaborations, or other modifications
43
+ represent, as a whole, an original work of authorship. For the purposes
44
+ of this License, Derivative Works shall not include works that remain
45
+ separable from, or merely link (or bind by name) to the interfaces of,
46
+ the Work and Derivative Works thereof.
47
+
48
+ "Contribution" shall mean any work of authorship, including
49
+ the original version of the Work and any modifications or additions
50
+ to that Work or Derivative Works thereof, that is intentionally
51
+ submitted to Licensor for inclusion in the Work by the copyright owner
52
+ or by an individual or Legal Entity authorized to submit on behalf of
53
+ the copyright owner. For the purposes of this definition, "submitted"
54
+ means any form of electronic, verbal, or written communication sent
55
+ to the Licensor or its representatives, including but not limited to
56
+ communication on electronic mailing lists, source code control systems,
57
+ and issue tracking systems that are managed by, or on behalf of, the
58
+ Licensor for the purpose of discussing and improving the Work, but
59
+ excluding communication that is conspicuously marked or otherwise
60
+ designated in writing by the copyright owner as "Not a Contribution."
61
+
62
+ "Contributor" shall mean Licensor and any individual or Legal Entity
63
+ on behalf of whom a Contribution has been received by Licensor and
64
+ subsequently incorporated within the Work.
65
+
66
+ 2. Grant of Copyright License. Subject to the terms and conditions of
67
+ this License, each Contributor hereby grants to You a perpetual,
68
+ worldwide, non-exclusive, no-charge, royalty-free, irrevocable
69
+ copyright license to reproduce, prepare Derivative Works of,
70
+ publicly display, publicly perform, sublicense, and distribute the
71
+ Work and such Derivative Works in Source or Object form.
72
+
73
+ 3. Grant of Patent License. Subject to the terms and conditions of
74
+ this License, each Contributor hereby grants to You a perpetual,
75
+ worldwide, non-exclusive, no-charge, royalty-free, irrevocable
76
+ (except as stated in this section) patent license to make, have made,
77
+ use, offer to sell, sell, import, and otherwise transfer the Work,
78
+ where such license applies only to those patent claims licensable
79
+ by such Contributor that are necessarily infringed by their
80
+ Contribution(s) alone or by combination of their Contribution(s)
81
+ with the Work to which such Contribution(s) was submitted. If You
82
+ institute patent litigation against any entity (including a
83
+ cross-claim or counterclaim in a lawsuit) alleging that the Work
84
+ or a Contribution incorporated within the Work constitutes direct
85
+ or contributory patent infringement, then any patent licenses
86
+ granted to You under this License for that Work shall terminate
87
+ as of the date such litigation is filed.
88
+
89
+ 4. Redistribution. You may reproduce and distribute copies of the
90
+ Work or Derivative Works thereof in any medium, with or without
91
+ modifications, and in Source or Object form, provided that You
92
+ meet the following conditions:
93
+
94
+ (a) You must give any other recipients of the Work or
95
+ Derivative Works a copy of this License; and
96
+
97
+ (b) You must cause any modified files to carry prominent notices
98
+ stating that You changed the files; and
99
+
100
+ (c) You must retain, in the Source form of any Derivative Works
101
+ that You distribute, all copyright, patent, trademark, and
102
+ attribution notices from the Source form of the Work,
103
+ excluding those notices that do not pertain to any part of
104
+ the Derivative Works; and
105
+
106
+ (d) If the Work includes a "NOTICE" text file as part of its
107
+ distribution, then any Derivative Works that You distribute must
108
+ include a readable copy of the attribution notices contained
109
+ within such NOTICE file, excluding those notices that do not
110
+ pertain to any part of the Derivative Works, in at least one
111
+ of the following places: within a NOTICE text file distributed
112
+ as part of the Derivative Works; within the Source form or
113
+ documentation, if provided along with the Derivative Works; or,
114
+ within a display generated by the Derivative Works, if and
115
+ wherever such third-party notices normally appear. The contents
116
+ of the NOTICE file are for informational purposes only and
117
+ do not modify the License. You may add Your own attribution
118
+ notices within Derivative Works that You distribute, alongside
119
+ or as an addendum to the NOTICE text from the Work, provided
120
+ that such additional attribution notices cannot be construed
121
+ as modifying the License.
122
+
123
+ You may add Your own copyright statement to Your modifications and
124
+ may provide additional or different license terms and conditions
125
+ for use, reproduction, or distribution of Your modifications, or
126
+ for any such Derivative Works as a whole, provided Your use,
127
+ reproduction, and distribution of the Work otherwise complies with
128
+ the conditions stated in this License.
129
+
130
+ 5. Submission of Contributions. Unless You explicitly state otherwise,
131
+ any Contribution intentionally submitted for inclusion in the Work
132
+ by You to the Licensor shall be under the terms and conditions of
133
+ this License, without any additional terms or conditions.
134
+ Notwithstanding the above, nothing herein shall supersede or modify
135
+ the terms of any separate license agreement you may have executed
136
+ with Licensor regarding such Contributions.
137
+
138
+ 6. Trademarks. This License does not grant permission to use the trade
139
+ names, trademarks, service marks, or product names of the Licensor,
140
+ except as required for reasonable and customary use in describing the
141
+ origin of the Work and reproducing the content of the NOTICE file.
142
+
143
+ 7. Disclaimer of Warranty. Unless required by applicable law or
144
+ agreed to in writing, Licensor provides the Work (and each
145
+ Contributor provides its Contributions) on an "AS IS" BASIS,
146
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
147
+ implied, including, without limitation, any warranties or conditions
148
+ of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
149
+ PARTICULAR PURPOSE. You are solely responsible for determining the
150
+ appropriateness of using or redistributing the Work and assume any
151
+ risks associated with Your exercise of permissions under this License.
152
+
153
+ 8. Limitation of Liability. In no event and under no legal theory,
154
+ whether in tort (including negligence), contract, or otherwise,
155
+ unless required by applicable law (such as deliberate and grossly
156
+ negligent acts) or agreed to in writing, shall any Contributor be
157
+ liable to You for damages, including any direct, indirect, special,
158
+ incidental, or consequential damages of any character arising as a
159
+ result of this License or out of the use or inability to use the
160
+ Work (including but not limited to damages for loss of goodwill,
161
+ work stoppage, computer failure or malfunction, or any and all
162
+ other commercial damages or losses), even if such Contributor
163
+ has been advised of the possibility of such damages.
164
+
165
+ 9. Accepting Warranty or Additional Liability. While redistributing
166
+ the Work or Derivative Works thereof, You may choose to offer,
167
+ and charge a fee for, acceptance of support, warranty, indemnity,
168
+ or other liability obligations and/or rights consistent with this
169
+ License. However, in accepting such obligations, You may act only
170
+ on Your own behalf and on Your sole responsibility, not on behalf
171
+ of any other Contributor, and only if You agree to indemnify,
172
+ defend, and hold each Contributor harmless for any liability
173
+ incurred by, or claims asserted against, such Contributor by reason
174
+ of your accepting any such warranty or additional liability.
175
+
176
+ END OF TERMS AND CONDITIONS
177
+
178
+ APPENDIX: How to apply the Apache License to your work.
179
+
180
+ To apply the Apache License to your work, attach the following
181
+ boilerplate notice, with the fields enclosed by brackets "[]"
182
+ replaced with your own identifying information. (Don't include
183
+ the brackets!) The text should be enclosed in the appropriate
184
+ comment syntax for the file format. We also recommend that a
185
+ file or class name and description of purpose be included on the
186
+ same "printed page" as the copyright notice for easier
187
+ identification within third-party archives.
188
+
189
+ Copyright [yyyy] [name of copyright owner]
190
+
191
+ Licensed under the Apache License, Version 2.0 (the "License");
192
+ you may not use this file except in compliance with the License.
193
+ You may obtain a copy of the License at
194
+
195
+ http://www.apache.org/licenses/LICENSE-2.0
196
+
197
+ Unless required by applicable law or agreed to in writing, software
198
+ distributed under the License is distributed on an "AS IS" BASIS,
199
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
200
+ See the License for the specific language governing permissions and
201
+ limitations under the License.
comfyui-mvadapter/README.md ADDED
@@ -0,0 +1,88 @@
1
+ # ComfyUI-MVAdapter
2
+
3
+ This extension integrates [MV-Adapter](https://github.com/huanngzh/MV-Adapter) into ComfyUI, allowing users to generate multi-view consistent images from text prompts or single images directly within the ComfyUI interface.
4
+
5
+ ## 🔥 Feature Updates
6
+
7
+ * [2025-06-26] Support multiple LoRAs for multi-view synthesis [See [here](https://github.com/huanngzh/ComfyUI-MVAdapter/pull/96)]
8
+ * [2025-01-15] Support selecting which views to generate, e.g. only 2 views (front & back) [See [here](#view-selection)]
9
+ * [2024-12-25] Support integration with ControlNet, for applications such as scribble-to-multi-view generation [See [here](#with-controlnet)]
10
+ * [2024-12-09] Support integration with SDXL LoRA [See [here](#with-lora)]
11
+ * [2024-12-02] Generate multi-view consistent images from text prompts or a single image
12
+
13
+ ## Installation
14
+
15
+ ### From Source
16
+
17
+ * Clone or download this repository into your `ComfyUI/custom_nodes/` directory.
18
+ * Install the required dependencies by running `pip install -r requirements.txt`.
19
+
20
+ ## Notes
21
+
22
+ ### Workflows
23
+
24
+ We provide example workflows in the `workflows` directory.
25
+
26
+ Note that our code depends on diffusers and will automatically download model weights from Hugging Face to the Hugging Face cache path on first use. The `ckpt_name` in the node corresponds to the model name on Hugging Face, such as `stabilityai/stable-diffusion-xl-base-1.0`; a minimal sketch of what the loader does is shown below.
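The sketch below is illustrative only: the import path and cache directory are assumptions, not exact values used inside ComfyUI (see `DiffusersMVPipelineLoader.create_pipeline` in `nodes.py` for the actual node code).

```python
# Sketch only: mirrors DiffusersMVPipelineLoader.create_pipeline in this commit.
import torch
from mvadapter.pipelines.pipeline_mvadapter_t2mv_sdxl import MVAdapterT2MVSDXLPipeline

pipe = MVAdapterT2MVSDXLPipeline.from_pretrained(
    "stabilityai/stable-diffusion-xl-base-1.0",  # ckpt_name: a Hugging Face model id
    torch_dtype=torch.float16,
    cache_dir="models/diffusers",                # assumed cache directory
)
# The node then exposes the pipeline, its VAE, and its scheduler to downstream nodes.
```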
27
+
28
+ We also provide the `Ldm**Loader` nodes for loading text-to-image models in `ldm` (single checkpoint file) format; see the sketch below. Please refer to the workflow files with the `_ldm.json` suffix.
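A rough, non-authoritative sketch of the ldm-format path, based on `LdmPipelineLoader` in this commit (the checkpoint path is a placeholder):

```python
# Sketch only: mirrors LdmPipelineLoader.create_pipeline; the path is hypothetical.
import torch
from mvadapter.pipelines.pipeline_mvadapter_t2mv_sdxl import MVAdapterT2MVSDXLPipeline

pipe = MVAdapterT2MVSDXLPipeline.from_single_file(
    "ComfyUI/models/checkpoints/sdxl_base_1.0.safetensors",  # local ldm-format checkpoint
    torch_dtype=torch.float16,
)
```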
29
+
30
+ ### GPU Memory
31
+
32
+ If your GPU resources are limited, we recommend using the following configuration:
33
+
34
+ * Use [madebyollin/sdxl-vae-fp16-fix](https://huggingface.co/madebyollin/sdxl-vae-fp16-fix) as the VAE. If using an ldm-format pipeline, remember to set `upcast_fp32` to `False`.
35
+
36
+ ![upcast_fp32_to_false](assets/comfyui_ldm_vae.png)
37
+
38
+ * Set `enable_vae_slicing` in the Diffusers Model Makeup node to `True`.
39
+
40
+ ![enable_vae_slicing](assets/comfyui_model_makeup.png)
41
+
42
+ However, since SDXL is used as the base model, generation still requires about 13–14 GB of GPU memory. A sketch of this low-memory configuration is shown below.
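The following sketch mirrors the `DiffusersMVVaeLoader` and Diffusers Model Makeup nodes in this commit; it assumes `pipe` comes from one of the loader sketches above rather than the node graph itself.

```python
# Sketch only: fp16 VAE plus VAE slicing, as recommended above.
import torch
from diffusers import AutoencoderKL

vae = AutoencoderKL.from_pretrained(
    "madebyollin/sdxl-vae-fp16-fix",  # fp16-safe SDXL VAE
    torch_dtype=torch.float16,
)
pipe.vae = vae               # swap the fp16 VAE into the pipeline
pipe.enable_vae_slicing()    # decode generated views one at a time to cut peak memory
# pipe.enable_vae_tiling()   # optional, saves more memory at high resolutions
```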
43
+
44
+ ## Usage
45
+
46
+ ### Text to Multi-view Images
47
+
48
+ #### With SDXL or other base models
49
+
50
+ ![comfyui_t2mv](assets/comfyui_t2mv.png)
51
+
52
+ * `workflows/t2mv_sdxl_diffusers.json` for loading diffusers-format models
53
+ * `workflows/t2mv_sdxl_ldm.json` for loading ldm-format models
54
+
55
+ #### With LoRA
56
+
57
+ ![comfyui_t2mv_lora](assets/comfyui_t2mv_lora.png)
58
+
59
+ `workflows/t2mv_sdxl_ldm_lora.json` for loading ldm-format models with LoRA for text-to-multi-view generation
60
+
61
+ #### With ControlNet
62
+
63
+ ![comfyui_t2mv_controlnet](assets/comfyui_t2mv_controlnet.png)
64
+
65
+ `workflows/t2mv_sdxl_ldm_controlnet.json` for loading diffusers-format controlnets for text-scribble-to-multi-view generation
66
+
67
+ ### Image to Multi-view Images
68
+
69
+ #### With SDXL or other base models
70
+
71
+ ![comfyui_i2mv](assets/comfyui_i2mv.png)
72
+
73
+ * `workflows/i2mv_sdxl_diffusers.json` for loading diffusers-format models
74
+ * `workflows/i2mv_sdxl_ldm.json` for loading ldm-format models
75
+
76
+ #### With LoRA
77
+
78
+ ![comfyui_i2mv_lora](assets/comfyui_i2mv_lora.png)
79
+
80
+ `workflows/i2mv_sdxl_ldm_lora.json` for loading ldm-format models with LoRA for image-to-multi-view generation
81
+
82
+ #### View Selection
83
+
84
+ ![comfyui_i2mv_pair_views](assets/comfyui_i2mv_view_selector.png)
85
+
86
+ `workflows/i2mv_sdxl_ldm_view_selector.json` for loading ldm-format models and selecting specific views to generate
87
+
88
+ The key is to replace the `adapter_name` in `Diffusers Model Makeup` with `mvadapter_i2mv_sdxl_beta.safetensors` and add a `View Selector` node to choose which views to generate. In rough tests, the beta model works best when generating 2 views (front & back), 3 views (front, right & back), or 4 views (front, right, back & left). Note that the `num_views` attribute is not used and can be ignored (see the sketch below).
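Under the hood, the `View Selector` node maps the enabled views to azimuth angles, and the sampler recomputes the number of views from that list, which is why `num_views` is ignored. A minimal sketch following the `ViewSelector` and `DiffusersMVSampler` code in this commit:

```python
# Sketch only: view names -> azimuth degrees, as in ViewSelector.process.
VIEW_AZIMUTHS = {
    "front": 0,
    "front_right": 45,
    "right": 90,
    "back": 180,
    "left": 270,
    "front_left": 315,
}

# Example: generate only the front and back views.
azimuth_degrees = [VIEW_AZIMUTHS[v] for v in ("front", "back")]  # -> [0, 180]
num_views = len(azimuth_degrees)  # the sampler derives num_views from this list
```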
comfyui-mvadapter/__init__.py ADDED
@@ -0,0 +1,45 @@
1
+ # __init__.py for comfyui-mvadapter
2
+ # Register BOTH node sets: the original nodes.py and nodes_local_mv.py
3
+
4
+ import traceback
5
+
6
+ # Load the original nodes (if present)
7
+ try:
8
+ from .nodes import (
9
+ NODE_CLASS_MAPPINGS as CORE_NODE_CLASS_MAPPINGS,
10
+ NODE_DISPLAY_NAME_MAPPINGS as CORE_NODE_DISPLAY_NAME_MAPPINGS,
11
+ )
12
+ except Exception as e:
13
+ print("[comfyui-mvadapter] WARN: Failed to import .nodes")
14
+ traceback.print_exc()
15
+ CORE_NODE_CLASS_MAPPINGS = {}
16
+ CORE_NODE_DISPLAY_NAME_MAPPINGS = {}
17
+
18
+ # Load the local-only nodes (if present)
19
+ try:
20
+ from .nodes_local_mv import (
21
+ NODE_CLASS_MAPPINGS as LOCAL_NODE_CLASS_MAPPINGS,
22
+ NODE_DISPLAY_NAME_MAPPINGS as LOCAL_NODE_DISPLAY_NAME_MAPPINGS,
23
+ )
24
+ except Exception as e:
25
+ print("[comfyui-mvadapter] WARN: Failed to import .nodes_local_mv")
26
+ traceback.print_exc()
27
+ LOCAL_NODE_CLASS_MAPPINGS = {}
28
+ LOCAL_NODE_DISPLAY_NAME_MAPPINGS = {}
29
+
30
+ # Merge into the symbols ComfyUI looks for
31
+ NODE_CLASS_MAPPINGS = {}
32
+ NODE_CLASS_MAPPINGS.update(CORE_NODE_CLASS_MAPPINGS)
33
+ NODE_CLASS_MAPPINGS.update(LOCAL_NODE_CLASS_MAPPINGS)
34
+
35
+ NODE_DISPLAY_NAME_MAPPINGS = {}
36
+ NODE_DISPLAY_NAME_MAPPINGS.update(CORE_NODE_DISPLAY_NAME_MAPPINGS)
37
+ NODE_DISPLAY_NAME_MAPPINGS.update(LOCAL_NODE_DISPLAY_NAME_MAPPINGS)
38
+
39
+ # Optional: quick summary to help debug load order
40
+ print(
41
+ "[comfyui-mvadapter] Registered nodes:",
42
+ ", ".join(sorted(NODE_CLASS_MAPPINGS.keys())) or "(none)",
43
+ )
44
+
45
+ __all__ = ["NODE_CLASS_MAPPINGS", "NODE_DISPLAY_NAME_MAPPINGS"]
comfyui-mvadapter/__pycache__/__init__.cpython-312.pyc ADDED
Binary file (1.4 kB).

comfyui-mvadapter/__pycache__/nodes.cpython-312.pyc ADDED
Binary file (8.32 kB).

comfyui-mvadapter/__pycache__/nodes_local_mv.cpython-312.pyc ADDED
Binary file (10.7 kB).

comfyui-mvadapter/__pycache__/utils.cpython-312.pyc ADDED
Binary file (13.6 kB).
 
comfyui-mvadapter/assets/CustomLoraModelLoader.png ADDED
comfyui-mvadapter/assets/comfyui_i2mv.png ADDED

Git LFS Details

  • SHA256: 9c364ee7e709ced6c9fe32111ed8ef0f6b893410b7165d87fa12dc7ec6c61953
  • Pointer size: 131 Bytes
  • Size of remote file: 432 kB
comfyui-mvadapter/assets/comfyui_i2mv_lora.png ADDED

Git LFS Details

  • SHA256: 9d037b0b3f026f308e6dacf9261483a8e9e069507ab09cf86ad22fc5fcf2aa49
  • Pointer size: 131 Bytes
  • Size of remote file: 853 kB
comfyui-mvadapter/assets/comfyui_i2mv_multiple_loras.jpg ADDED

Git LFS Details

  • SHA256: 65c901ec52c76dd2e3ee49e121b52a4589ce9e9f9e67edccf297b5028470768b
  • Pointer size: 131 Bytes
  • Size of remote file: 471 kB
comfyui-mvadapter/assets/comfyui_i2mv_view_selector.png ADDED

Git LFS Details

  • SHA256: 6a48cde4ec2a44b1a9a29d4b9e1aaaf5a9ae287ef2d5ad4fe5da23e876c76c74
  • Pointer size: 131 Bytes
  • Size of remote file: 401 kB
comfyui-mvadapter/assets/comfyui_ldm_vae.png ADDED
comfyui-mvadapter/assets/comfyui_model_makeup.png ADDED
comfyui-mvadapter/assets/comfyui_t2mv.png ADDED

Git LFS Details

  • SHA256: 61f807f5665dbe404be09ab27214ae3e545160c6f99005f7d309e31af15ed41f
  • Pointer size: 131 Bytes
  • Size of remote file: 311 kB
comfyui-mvadapter/assets/comfyui_t2mv_controlnet.png ADDED

Git LFS Details

  • SHA256: b1b1923de261e12963fc5dbdc929e3f4f832aae34cb198beab14748c24758aee
  • Pointer size: 131 Bytes
  • Size of remote file: 426 kB
comfyui-mvadapter/assets/comfyui_t2mv_lora.png ADDED

Git LFS Details

  • SHA256: 62293e0d4897848f7b2117d5b18036c9ed82b01eaa7b9e39e55ed33f53ee0ec3
  • Pointer size: 132 Bytes
  • Size of remote file: 1.05 MB
comfyui-mvadapter/assets/comfyui_t2mv_multiple_loras.jpg ADDED

Git LFS Details

  • SHA256: 7436db15d4fb65113bc544eeda0ad9be9ee03cb589959847763eaf85fe93f65e
  • Pointer size: 131 Bytes
  • Size of remote file: 492 kB
comfyui-mvadapter/assets/demo/scribbles/scribble_0.png ADDED
comfyui-mvadapter/assets/demo/scribbles/scribble_1.png ADDED
comfyui-mvadapter/assets/demo/scribbles/scribble_2.png ADDED
comfyui-mvadapter/assets/demo/scribbles/scribble_3.png ADDED
comfyui-mvadapter/assets/demo/scribbles/scribble_4.png ADDED
comfyui-mvadapter/assets/demo/scribbles/scribble_5.png ADDED
comfyui-mvadapter/cache/stable-diffusion-v1-inference.yaml ADDED
@@ -0,0 +1,70 @@
1
+ model:
2
+ base_learning_rate: 1.0e-04
3
+ target: ldm.models.diffusion.ddpm.LatentDiffusion
4
+ params:
5
+ linear_start: 0.00085
6
+ linear_end: 0.0120
7
+ num_timesteps_cond: 1
8
+ log_every_t: 200
9
+ timesteps: 1000
10
+ first_stage_key: "jpg"
11
+ cond_stage_key: "txt"
12
+ image_size: 64
13
+ channels: 4
14
+ cond_stage_trainable: false # Note: different from the one we trained before
15
+ conditioning_key: crossattn
16
+ monitor: val/loss_simple_ema
17
+ scale_factor: 0.18215
18
+ use_ema: False
19
+
20
+ scheduler_config: # 10000 warmup steps
21
+ target: ldm.lr_scheduler.LambdaLinearScheduler
22
+ params:
23
+ warm_up_steps: [ 10000 ]
24
+ cycle_lengths: [ 10000000000000 ] # incredibly large number to prevent corner cases
25
+ f_start: [ 1.e-6 ]
26
+ f_max: [ 1. ]
27
+ f_min: [ 1. ]
28
+
29
+ unet_config:
30
+ target: ldm.modules.diffusionmodules.openaimodel.UNetModel
31
+ params:
32
+ image_size: 32 # unused
33
+ in_channels: 4
34
+ out_channels: 4
35
+ model_channels: 320
36
+ attention_resolutions: [ 4, 2, 1 ]
37
+ num_res_blocks: 2
38
+ channel_mult: [ 1, 2, 4, 4 ]
39
+ num_heads: 8
40
+ use_spatial_transformer: True
41
+ transformer_depth: 1
42
+ context_dim: 768
43
+ use_checkpoint: True
44
+ legacy: False
45
+
46
+ first_stage_config:
47
+ target: ldm.models.autoencoder.AutoencoderKL
48
+ params:
49
+ embed_dim: 4
50
+ monitor: val/rec_loss
51
+ ddconfig:
52
+ double_z: true
53
+ z_channels: 4
54
+ resolution: 256
55
+ in_channels: 3
56
+ out_ch: 3
57
+ ch: 128
58
+ ch_mult:
59
+ - 1
60
+ - 2
61
+ - 4
62
+ - 4
63
+ num_res_blocks: 2
64
+ attn_resolutions: []
65
+ dropout: 0.0
66
+ lossconfig:
67
+ target: torch.nn.Identity
68
+
69
+ cond_stage_config:
70
+ target: ldm.modules.encoders.modules.FrozenCLIPEmbedder
comfyui-mvadapter/mvadapter/__init__.py ADDED
File without changes
comfyui-mvadapter/mvadapter/__pycache__/__init__.cpython-312.pyc ADDED
Binary file (161 Bytes).
 
comfyui-mvadapter/mvadapter/loaders/__init__.py ADDED
@@ -0,0 +1 @@
1
+ from .custom_adapter import CustomAdapterMixin
comfyui-mvadapter/mvadapter/loaders/__pycache__/__init__.cpython-312.pyc ADDED
Binary file (229 Bytes).

comfyui-mvadapter/mvadapter/loaders/__pycache__/custom_adapter.cpython-312.pyc ADDED
Binary file (4.44 kB).
 
comfyui-mvadapter/mvadapter/loaders/custom_adapter.py ADDED
@@ -0,0 +1,98 @@
1
+ import os
2
+ from typing import Dict, Optional, Union
3
+
4
+ import safetensors
5
+ import torch
6
+ from diffusers.utils import _get_model_file, logging
7
+ from safetensors import safe_open
8
+
9
+ logger = logging.get_logger(__name__) # pylint: disable=invalid-name
10
+
11
+
12
+ class CustomAdapterMixin:
13
+ def init_custom_adapter(self, *args, **kwargs):
14
+ self._init_custom_adapter(*args, **kwargs)
15
+
16
+ def _init_custom_adapter(self, *args, **kwargs):
17
+ raise NotImplementedError
18
+
19
+ def load_custom_adapter(
20
+ self,
21
+ pretrained_model_name_or_path_or_dict: Union[str, Dict[str, torch.Tensor]],
22
+ weight_name: str,
23
+ subfolder: Optional[str] = None,
24
+ **kwargs,
25
+ ):
26
+ # Load the main state dict first.
27
+ cache_dir = kwargs.pop("cache_dir", None)
28
+ force_download = kwargs.pop("force_download", False)
29
+ proxies = kwargs.pop("proxies", None)
30
+ local_files_only = kwargs.pop("local_files_only", None)
31
+ token = kwargs.pop("token", None)
32
+ revision = kwargs.pop("revision", None)
33
+
34
+ user_agent = {
35
+ "file_type": "attn_procs_weights",
36
+ "framework": "pytorch",
37
+ }
38
+
39
+ if not isinstance(pretrained_model_name_or_path_or_dict, dict):
40
+ model_file = _get_model_file(
41
+ pretrained_model_name_or_path_or_dict,
42
+ weights_name=weight_name,
43
+ subfolder=subfolder,
44
+ cache_dir=cache_dir,
45
+ force_download=force_download,
46
+ proxies=proxies,
47
+ local_files_only=local_files_only,
48
+ token=token,
49
+ revision=revision,
50
+ user_agent=user_agent,
51
+ )
52
+ if weight_name.endswith(".safetensors"):
53
+ state_dict = {}
54
+ with safe_open(model_file, framework="pt", device="cpu") as f:
55
+ for key in f.keys():
56
+ state_dict[key] = f.get_tensor(key)
57
+ else:
58
+ state_dict = torch.load(model_file, map_location="cpu")
59
+ else:
60
+ state_dict = pretrained_model_name_or_path_or_dict
61
+
62
+ self._load_custom_adapter(state_dict)
63
+
64
+ def _load_custom_adapter(self, state_dict):
65
+ raise NotImplementedError
66
+
67
+ def save_custom_adapter(
68
+ self,
69
+ save_directory: Union[str, os.PathLike],
70
+ weight_name: str,
71
+ safe_serialization: bool = False,
72
+ **kwargs,
73
+ ):
74
+ if os.path.isfile(save_directory):
75
+ logger.error(
76
+ f"Provided path ({save_directory}) should be a directory, not a file"
77
+ )
78
+ return
79
+
80
+ if safe_serialization:
81
+
82
+ def save_function(weights, filename):
83
+ return safetensors.torch.save_file(
84
+ weights, filename, metadata={"format": "pt"}
85
+ )
86
+
87
+ else:
88
+ save_function = torch.save
89
+
90
+ # Save the model
91
+ state_dict = self._save_custom_adapter(**kwargs)
92
+ save_function(state_dict, os.path.join(save_directory, weight_name))
93
+ logger.info(
94
+ f"Custom adapter weights saved in {os.path.join(save_directory, weight_name)}"
95
+ )
96
+
97
+ def _save_custom_adapter(self):
98
+ raise NotImplementedError
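A minimal sketch of how a pipeline plugs into `CustomAdapterMixin`: subclasses implement the three underscore hooks, and callers go through `init_custom_adapter` / `save_custom_adapter` / `load_custom_adapter`. The class, module, and paths below are illustrative only, not part of the repository:

    import os

    import torch.nn as nn

    from mvadapter.loaders import CustomAdapterMixin  # assumes the mvadapter package is importable

    class ToyAdapterPipeline(CustomAdapterMixin):
        def _init_custom_adapter(self, hidden_size: int = 8):
            # create the extra modules the adapter introduces
            self.adapter_proj = nn.Linear(hidden_size, hidden_size)

        def _load_custom_adapter(self, state_dict):
            # restore only the adapter modules from the loaded weights
            self.adapter_proj.load_state_dict(state_dict, strict=False)

        def _save_custom_adapter(self, **kwargs):
            # return exactly the tensors that should be written to disk
            return self.adapter_proj.state_dict()

    pipe = ToyAdapterPipeline()
    pipe.init_custom_adapter(hidden_size=8)

    os.makedirs("adapter_dir", exist_ok=True)
    pipe.save_custom_adapter("adapter_dir", "adapter.safetensors", safe_serialization=True)
    pipe.load_custom_adapter("adapter_dir", weight_name="adapter.safetensors")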
comfyui-mvadapter/mvadapter/models/__init__.py ADDED
File without changes
comfyui-mvadapter/mvadapter/models/__pycache__/__init__.cpython-312.pyc ADDED
Binary file (168 Bytes). View file
 
comfyui-mvadapter/mvadapter/models/__pycache__/attention_processor.cpython-312.pyc ADDED
Binary file (13.8 kB). View file
 
comfyui-mvadapter/mvadapter/models/attention_processor.py ADDED
@@ -0,0 +1,377 @@
1
+ import math
2
+ from typing import Callable, Dict, List, Optional, Union
3
+
4
+ import torch
5
+ import torch.nn.functional as F
6
+ from diffusers.models.attention_processor import Attention
7
+ from diffusers.models.unets import UNet2DConditionModel
8
+ from diffusers.utils import deprecate, logging
9
+ from diffusers.utils.import_utils import is_torch_npu_available, is_xformers_available
10
+ from einops import rearrange
11
+ from torch import nn
12
+
13
+
14
+ def default_set_attn_proc_func(
15
+ name: str,
16
+ hidden_size: int,
17
+ cross_attention_dim: Optional[int],
18
+ ori_attn_proc: object,
19
+ ) -> object:
20
+ return ori_attn_proc
21
+
22
+
23
+ def set_unet_2d_condition_attn_processor(
24
+ unet: UNet2DConditionModel,
25
+ set_self_attn_proc_func: Callable = default_set_attn_proc_func,
26
+ set_cross_attn_proc_func: Callable = default_set_attn_proc_func,
27
+ set_custom_attn_proc_func: Callable = default_set_attn_proc_func,
28
+ set_self_attn_module_names: Optional[List[str]] = None,
29
+ set_cross_attn_module_names: Optional[List[str]] = None,
30
+ set_custom_attn_module_names: Optional[List[str]] = None,
31
+ ) -> None:
32
+ do_set_processor = lambda name, module_names: (
33
+ any([name.startswith(module_name) for module_name in module_names])
34
+ if module_names is not None
35
+ else True
36
+ ) # prefix match
37
+
38
+ attn_procs = {}
39
+ for name, attn_processor in unet.attn_processors.items():
40
+ # set attn_processor by default, if module_names is None
41
+ set_self_attn_processor = do_set_processor(name, set_self_attn_module_names)
42
+ set_cross_attn_processor = do_set_processor(name, set_cross_attn_module_names)
43
+ set_custom_attn_processor = do_set_processor(name, set_custom_attn_module_names)
44
+
45
+ if name.startswith("mid_block"):
46
+ hidden_size = unet.config.block_out_channels[-1]
47
+ elif name.startswith("up_blocks"):
48
+ block_id = int(name[len("up_blocks.")])
49
+ hidden_size = list(reversed(unet.config.block_out_channels))[block_id]
50
+ elif name.startswith("down_blocks"):
51
+ block_id = int(name[len("down_blocks.")])
52
+ hidden_size = unet.config.block_out_channels[block_id]
53
+
54
+ is_custom = "attn_mid_blocks" in name or "attn_post_blocks" in name
55
+ if is_custom:
56
+ attn_procs[name] = (
57
+ set_custom_attn_proc_func(name, hidden_size, None, attn_processor)
58
+ if set_custom_attn_processor
59
+ else attn_processor
60
+ )
61
+ else:
62
+ cross_attention_dim = (
63
+ None
64
+ if name.endswith("attn1.processor")
65
+ else unet.config.cross_attention_dim
66
+ )
67
+ if cross_attention_dim is None or "motion_modules" in name:
68
+ # self attention
69
+ attn_procs[name] = (
70
+ set_self_attn_proc_func(
71
+ name, hidden_size, cross_attention_dim, attn_processor
72
+ )
73
+ if set_self_attn_processor
74
+ else attn_processor
75
+ )
76
+ else:
77
+ # cross attention
78
+ attn_procs[name] = (
79
+ set_cross_attn_proc_func(
80
+ name, hidden_size, cross_attention_dim, attn_processor
81
+ )
82
+ if set_cross_attn_processor
83
+ else attn_processor
84
+ )
85
+
86
+ unet.set_attn_processor(attn_procs)
87
+
88
+
89
+ class DecoupledMVRowSelfAttnProcessor2_0(torch.nn.Module):
90
+ r"""
91
+ Attention processor for Decoupled Row-wise Self-Attention and Image Cross-Attention for PyTorch 2.0.
92
+ """
93
+
94
+ def __init__(
95
+ self,
96
+ query_dim: int,
97
+ inner_dim: int,
98
+ num_views: int = 1,
99
+ name: Optional[str] = None,
100
+ use_mv: bool = True,
101
+ use_ref: bool = False,
102
+ ):
103
+ if not hasattr(F, "scaled_dot_product_attention"):
104
+ raise ImportError(
105
+ "DecoupledMVRowSelfAttnProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0."
106
+ )
107
+
108
+ super().__init__()
109
+
110
+ self.num_views = num_views
111
+ self.name = name # NOTE: need for image cross-attention
112
+ self.use_mv = use_mv
113
+ self.use_ref = use_ref
114
+
115
+ if self.use_mv:
116
+ self.to_q_mv = nn.Linear(
117
+ in_features=query_dim, out_features=inner_dim, bias=False
118
+ )
119
+ self.to_k_mv = nn.Linear(
120
+ in_features=query_dim, out_features=inner_dim, bias=False
121
+ )
122
+ self.to_v_mv = nn.Linear(
123
+ in_features=query_dim, out_features=inner_dim, bias=False
124
+ )
125
+ self.to_out_mv = nn.ModuleList(
126
+ [
127
+ nn.Linear(in_features=inner_dim, out_features=query_dim, bias=True),
128
+ nn.Dropout(0.0),
129
+ ]
130
+ )
131
+
132
+ if self.use_ref:
133
+ self.to_q_ref = nn.Linear(
134
+ in_features=query_dim, out_features=inner_dim, bias=False
135
+ )
136
+ self.to_k_ref = nn.Linear(
137
+ in_features=query_dim, out_features=inner_dim, bias=False
138
+ )
139
+ self.to_v_ref = nn.Linear(
140
+ in_features=query_dim, out_features=inner_dim, bias=False
141
+ )
142
+ self.to_out_ref = nn.ModuleList(
143
+ [
144
+ nn.Linear(in_features=inner_dim, out_features=query_dim, bias=True),
145
+ nn.Dropout(0.0),
146
+ ]
147
+ )
148
+
149
+ def __call__(
150
+ self,
151
+ attn: Attention,
152
+ hidden_states: torch.FloatTensor,
153
+ encoder_hidden_states: Optional[torch.FloatTensor] = None,
154
+ attention_mask: Optional[torch.FloatTensor] = None,
155
+ temb: Optional[torch.FloatTensor] = None,
156
+ mv_scale: float = 1.0,
157
+ ref_hidden_states: Optional[Dict[str, torch.FloatTensor]] = None,
158
+ ref_scale: float = 1.0,
159
+ cache_hidden_states: Optional[Dict[str, torch.FloatTensor]] = None,
160
+ use_mv: bool = True,
161
+ use_ref: bool = True,
162
+ num_views: Optional[int] = None,
163
+ *args,
164
+ **kwargs,
165
+ ) -> torch.FloatTensor:
166
+ """
167
+ New args:
168
+ mv_scale (float): scale for multi-view self-attention.
169
+ ref_hidden_states (Dict[str, torch.FloatTensor]): per-layer reference hidden states for image cross-attention, keyed by processor name.
170
+ ref_scale (float): scale for image cross-attention.
171
+ cache_hidden_states (Dict[str, torch.FloatTensor]): dictionary filled with this layer's hidden states during the reference UNet pass.
172
+
173
+ """
174
+ if len(args) > 0 or kwargs.get("scale", None) is not None:
175
+ deprecation_message = "The `scale` argument is deprecated and will be ignored. Please remove it, as passing it will raise an error in the future. `scale` should directly be passed while calling the underlying pipeline component i.e., via `cross_attention_kwargs`."
176
+ deprecate("scale", "1.0.0", deprecation_message)
177
+
178
+ if num_views is not None:
179
+ self.num_views = num_views
180
+
181
+ # NEW: cache hidden states for reference unet
182
+ if cache_hidden_states is not None:
183
+ cache_hidden_states[self.name] = hidden_states.clone()
184
+
185
+ # NEW: whether to use multi-view attention and image cross-attention
186
+ use_mv = self.use_mv and use_mv
187
+ use_ref = self.use_ref and use_ref
188
+
189
+ residual = hidden_states
190
+ if attn.spatial_norm is not None:
191
+ hidden_states = attn.spatial_norm(hidden_states, temb)
192
+
193
+ input_ndim = hidden_states.ndim
194
+
195
+ if input_ndim == 4:
196
+ batch_size, channel, height, width = hidden_states.shape
197
+ hidden_states = hidden_states.view(
198
+ batch_size, channel, height * width
199
+ ).transpose(1, 2)
200
+
201
+ batch_size, sequence_length, _ = (
202
+ hidden_states.shape
203
+ if encoder_hidden_states is None
204
+ else encoder_hidden_states.shape
205
+ )
206
+
207
+ if attention_mask is not None:
208
+ attention_mask = attn.prepare_attention_mask(
209
+ attention_mask, sequence_length, batch_size
210
+ )
211
+ # scaled_dot_product_attention expects attention_mask shape to be
212
+ # (batch, heads, source_length, target_length)
213
+ attention_mask = attention_mask.view(
214
+ batch_size, attn.heads, -1, attention_mask.shape[-1]
215
+ )
216
+
217
+ if attn.group_norm is not None:
218
+ hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(
219
+ 1, 2
220
+ )
221
+
222
+ query = attn.to_q(hidden_states)
223
+
224
+ # NEW: for decoupled multi-view attention
225
+ if use_mv:
226
+ query_mv = self.to_q_mv(hidden_states)
227
+
228
+ # NEW: for decoupled reference cross attention
229
+ if use_ref:
230
+ query_ref = self.to_q_ref(hidden_states)
231
+
232
+ if encoder_hidden_states is None:
233
+ encoder_hidden_states = hidden_states
234
+ elif attn.norm_cross:
235
+ encoder_hidden_states = attn.norm_encoder_hidden_states(
236
+ encoder_hidden_states
237
+ )
238
+
239
+ key = attn.to_k(encoder_hidden_states)
240
+ value = attn.to_v(encoder_hidden_states)
241
+
242
+ inner_dim = key.shape[-1]
243
+ head_dim = inner_dim // attn.heads
244
+
245
+ query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
246
+
247
+ key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
248
+ value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
249
+
250
+ # the output of sdp = (batch, num_heads, seq_len, head_dim)
251
+ # TODO: add support for attn.scale when we move to Torch 2.1
252
+ hidden_states = F.scaled_dot_product_attention(
253
+ query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
254
+ )
255
+
256
+ hidden_states = hidden_states.transpose(1, 2).reshape(
257
+ batch_size, -1, attn.heads * head_dim
258
+ )
259
+ hidden_states = hidden_states.to(query.dtype)
260
+
261
+ ####### Decoupled multi-view self-attention ########
262
+ if use_mv:
263
+ key_mv = self.to_k_mv(encoder_hidden_states)
264
+ value_mv = self.to_v_mv(encoder_hidden_states)
265
+
266
+ query_mv = query_mv.view(batch_size, -1, attn.heads, head_dim)
267
+ key_mv = key_mv.view(batch_size, -1, attn.heads, head_dim)
268
+ value_mv = value_mv.view(batch_size, -1, attn.heads, head_dim)
269
+
270
+ height = width = math.isqrt(sequence_length)
271
+
272
+ # row self-attention
273
+ query_mv = rearrange(
274
+ query_mv,
275
+ "(b nv) (ih iw) h c -> (b nv ih) iw h c",
276
+ nv=self.num_views,
277
+ ih=height,
278
+ iw=width,
279
+ ).transpose(1, 2)
280
+ key_mv = rearrange(
281
+ key_mv,
282
+ "(b nv) (ih iw) h c -> b ih (nv iw) h c",
283
+ nv=self.num_views,
284
+ ih=height,
285
+ iw=width,
286
+ )
287
+ key_mv = (
288
+ key_mv.repeat_interleave(self.num_views, dim=0)
289
+ .view(batch_size * height, -1, attn.heads, head_dim)
290
+ .transpose(1, 2)
291
+ )
292
+ value_mv = rearrange(
293
+ value_mv,
294
+ "(b nv) (ih iw) h c -> b ih (nv iw) h c",
295
+ nv=self.num_views,
296
+ ih=height,
297
+ iw=width,
298
+ )
299
+ value_mv = (
300
+ value_mv.repeat_interleave(self.num_views, dim=0)
301
+ .view(batch_size * height, -1, attn.heads, head_dim)
302
+ .transpose(1, 2)
303
+ )
304
+
305
+ hidden_states_mv = F.scaled_dot_product_attention(
306
+ query_mv,
307
+ key_mv,
308
+ value_mv,
309
+ dropout_p=0.0,
310
+ is_causal=False,
311
+ )
312
+ hidden_states_mv = rearrange(
313
+ hidden_states_mv,
314
+ "(b nv ih) h iw c -> (b nv) (ih iw) (h c)",
315
+ nv=self.num_views,
316
+ ih=height,
317
+ )
318
+ hidden_states_mv = hidden_states_mv.to(query.dtype)
319
+
320
+ # linear proj
321
+ hidden_states_mv = self.to_out_mv[0](hidden_states_mv)
322
+ # dropout
323
+ hidden_states_mv = self.to_out_mv[1](hidden_states_mv)
324
+
325
+ if use_ref:
326
+ reference_hidden_states = ref_hidden_states[self.name]
327
+
328
+ key_ref = self.to_k_ref(reference_hidden_states)
329
+ value_ref = self.to_v_ref(reference_hidden_states)
330
+
331
+ query_ref = query_ref.view(batch_size, -1, attn.heads, head_dim).transpose(
332
+ 1, 2
333
+ )
334
+ key_ref = key_ref.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
335
+ value_ref = value_ref.view(batch_size, -1, attn.heads, head_dim).transpose(
336
+ 1, 2
337
+ )
338
+
339
+ hidden_states_ref = F.scaled_dot_product_attention(
340
+ query_ref, key_ref, value_ref, dropout_p=0.0, is_causal=False
341
+ )
342
+
343
+ hidden_states_ref = hidden_states_ref.transpose(1, 2).reshape(
344
+ batch_size, -1, attn.heads * head_dim
345
+ )
346
+ hidden_states_ref = hidden_states_ref.to(query.dtype)
347
+
348
+ # linear proj
349
+ hidden_states_ref = self.to_out_ref[0](hidden_states_ref)
350
+ # dropout
351
+ hidden_states_ref = self.to_out_ref[1](hidden_states_ref)
352
+
353
+ # linear proj
354
+ hidden_states = attn.to_out[0](hidden_states)
355
+ # dropout
356
+ hidden_states = attn.to_out[1](hidden_states)
357
+
358
+ if use_mv:
359
+ hidden_states = hidden_states + hidden_states_mv * mv_scale
360
+
361
+ if use_ref:
362
+ hidden_states = hidden_states + hidden_states_ref * ref_scale
363
+
364
+ if input_ndim == 4:
365
+ hidden_states = hidden_states.transpose(-1, -2).reshape(
366
+ batch_size, channel, height, width
367
+ )
368
+
369
+ if attn.residual_connection:
370
+ hidden_states = hidden_states + residual
371
+
372
+ hidden_states = hidden_states / attn.rescale_output_factor
373
+
374
+ return hidden_states
375
+
376
+ def set_num_views(self, num_views: int) -> None:
377
+ self.num_views = num_views
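A hedged sketch of wiring `DecoupledMVRowSelfAttnProcessor2_0` into an SDXL UNet with `set_unet_2d_condition_attn_processor`: self-attention layers receive the decoupled multi-view/reference processor while cross-attention layers keep their original processors. In the repository this happens inside the pipelines' `_init_custom_adapter`; the standalone version below is for illustration:

    import torch
    from diffusers import UNet2DConditionModel

    from mvadapter.models.attention_processor import (
        DecoupledMVRowSelfAttnProcessor2_0,
        set_unet_2d_condition_attn_processor,
    )

    unet = UNet2DConditionModel.from_pretrained(
        "stabilityai/stable-diffusion-xl-base-1.0", subfolder="unet", torch_dtype=torch.float16
    )

    set_unet_2d_condition_attn_processor(
        unet,
        # install the decoupled processor on self-attention layers
        set_self_attn_proc_func=lambda name, hidden_size, cross_attention_dim, ori_proc: (
            DecoupledMVRowSelfAttnProcessor2_0(
                query_dim=hidden_size,
                inner_dim=hidden_size,
                num_views=6,
                name=name,
                use_mv=True,
                use_ref=True,
            )
        ),
        # leave cross-attention untouched by returning the original processor
        set_cross_attn_proc_func=lambda name, hidden_size, cross_attention_dim, ori_proc: ori_proc,
    )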
comfyui-mvadapter/mvadapter/pipelines/__pycache__/pipeline_mvadapter_i2mv_sd.cpython-312.pyc ADDED
Binary file (30 kB). View file
 
comfyui-mvadapter/mvadapter/pipelines/__pycache__/pipeline_mvadapter_i2mv_sdxl.cpython-312.pyc ADDED
Binary file (32.6 kB). View file
 
comfyui-mvadapter/mvadapter/pipelines/__pycache__/pipeline_mvadapter_t2mv_sd.cpython-312.pyc ADDED
Binary file (24.9 kB). View file
 
comfyui-mvadapter/mvadapter/pipelines/__pycache__/pipeline_mvadapter_t2mv_sdxl.cpython-312.pyc ADDED
Binary file (34.8 kB). View file
 
comfyui-mvadapter/mvadapter/pipelines/pipeline_mvadapter_i2mv_sdxl.py ADDED
@@ -0,0 +1,903 @@
1
+ # Copyright 2024 The HuggingFace Team. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ import inspect
16
+ from typing import Any, Callable, Dict, List, Optional, Tuple, Union
17
+
18
+ import numpy as np
19
+ import PIL
20
+ import torch
21
+ import torch.nn as nn
22
+ from diffusers.image_processor import PipelineImageInput, VaeImageProcessor
23
+ from diffusers.models import (
24
+ AutoencoderKL,
25
+ ImageProjection,
26
+ T2IAdapter,
27
+ UNet2DConditionModel,
28
+ )
29
+ from diffusers.pipelines.stable_diffusion_xl.pipeline_output import (
30
+ StableDiffusionXLPipelineOutput,
31
+ )
32
+ from diffusers.pipelines.stable_diffusion_xl.pipeline_stable_diffusion_xl import (
33
+ StableDiffusionXLPipeline,
34
+ rescale_noise_cfg,
35
+ retrieve_timesteps,
36
+ )
37
+ from diffusers.schedulers import KarrasDiffusionSchedulers
38
+ from diffusers.utils import deprecate, logging
39
+ from diffusers.utils.torch_utils import randn_tensor
40
+ from einops import rearrange
41
+ from transformers import (
42
+ CLIPImageProcessor,
43
+ CLIPTextModel,
44
+ CLIPTextModelWithProjection,
45
+ CLIPTokenizer,
46
+ CLIPVisionModelWithProjection,
47
+ )
48
+
49
+ from ..loaders import CustomAdapterMixin
50
+ from ..models.attention_processor import (
51
+ DecoupledMVRowSelfAttnProcessor2_0,
52
+ set_unet_2d_condition_attn_processor,
53
+ )
54
+
55
+ logger = logging.get_logger(__name__) # pylint: disable=invalid-name
56
+
57
+
58
+ def retrieve_latents(
59
+ encoder_output: torch.Tensor,
60
+ generator: Optional[torch.Generator] = None,
61
+ sample_mode: str = "sample",
62
+ ):
63
+ if hasattr(encoder_output, "latent_dist") and sample_mode == "sample":
64
+ return encoder_output.latent_dist.sample(generator)
65
+ elif hasattr(encoder_output, "latent_dist") and sample_mode == "argmax":
66
+ return encoder_output.latent_dist.mode()
67
+ elif hasattr(encoder_output, "latents"):
68
+ return encoder_output.latents
69
+ else:
70
+ raise AttributeError("Could not access latents of provided encoder_output")
71
+
72
+
73
+ class MVAdapterI2MVSDXLPipeline(StableDiffusionXLPipeline, CustomAdapterMixin):
74
+ def __init__(
75
+ self,
76
+ vae: AutoencoderKL,
77
+ text_encoder: CLIPTextModel,
78
+ text_encoder_2: CLIPTextModelWithProjection,
79
+ tokenizer: CLIPTokenizer,
80
+ tokenizer_2: CLIPTokenizer,
81
+ unet: UNet2DConditionModel,
82
+ scheduler: KarrasDiffusionSchedulers,
83
+ image_encoder: CLIPVisionModelWithProjection = None,
84
+ feature_extractor: CLIPImageProcessor = None,
85
+ force_zeros_for_empty_prompt: bool = True,
86
+ add_watermarker: Optional[bool] = None,
87
+ ):
88
+ super().__init__(
89
+ vae=vae,
90
+ text_encoder=text_encoder,
91
+ text_encoder_2=text_encoder_2,
92
+ tokenizer=tokenizer,
93
+ tokenizer_2=tokenizer_2,
94
+ unet=unet,
95
+ scheduler=scheduler,
96
+ image_encoder=image_encoder,
97
+ feature_extractor=feature_extractor,
98
+ force_zeros_for_empty_prompt=force_zeros_for_empty_prompt,
99
+ add_watermarker=add_watermarker,
100
+ )
101
+
102
+ self.control_image_processor = VaeImageProcessor(
103
+ vae_scale_factor=self.vae_scale_factor,
104
+ do_convert_rgb=True,
105
+ do_normalize=False,
106
+ )
107
+
108
+ # Copied from diffusers.pipelines.stable_diffusion_xl.pipeline_stable_diffusion_xl_img2img.prepare_latents
109
+ def prepare_image_latents(
110
+ self,
111
+ image,
112
+ timestep,
113
+ batch_size,
114
+ num_images_per_prompt,
115
+ dtype,
116
+ device,
117
+ generator=None,
118
+ add_noise=True,
119
+ ):
120
+ if not isinstance(image, (torch.Tensor, PIL.Image.Image, list)):
121
+ raise ValueError(
122
+ f"`image` has to be of type `torch.Tensor`, `PIL.Image.Image` or list but is {type(image)}"
123
+ )
124
+
125
+ latents_mean = latents_std = None
126
+ if (
127
+ hasattr(self.vae.config, "latents_mean")
128
+ and self.vae.config.latents_mean is not None
129
+ ):
130
+ latents_mean = torch.tensor(self.vae.config.latents_mean).view(1, 4, 1, 1)
131
+ if (
132
+ hasattr(self.vae.config, "latents_std")
133
+ and self.vae.config.latents_std is not None
134
+ ):
135
+ latents_std = torch.tensor(self.vae.config.latents_std).view(1, 4, 1, 1)
136
+
137
+ # Offload text encoder if `enable_model_cpu_offload` was enabled
138
+ if hasattr(self, "final_offload_hook") and self.final_offload_hook is not None:
139
+ self.text_encoder_2.to("cpu")
140
+ torch.cuda.empty_cache()
141
+
142
+ image = image.to(device=device, dtype=dtype)
143
+
144
+ batch_size = batch_size * num_images_per_prompt
145
+
146
+ if image.shape[1] == 4:
147
+ init_latents = image
148
+
149
+ else:
150
+ # make sure the VAE is in float32 mode, as it overflows in float16
151
+ if self.vae.config.force_upcast:
152
+ image = image.float()
153
+ self.vae.to(dtype=torch.float32)
154
+
155
+ if isinstance(generator, list) and len(generator) != batch_size:
156
+ raise ValueError(
157
+ f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
158
+ f" size of {batch_size}. Make sure the batch size matches the length of the generators."
159
+ )
160
+
161
+ elif isinstance(generator, list):
162
+ if image.shape[0] < batch_size and batch_size % image.shape[0] == 0:
163
+ image = torch.cat([image] * (batch_size // image.shape[0]), dim=0)
164
+ elif image.shape[0] < batch_size and batch_size % image.shape[0] != 0:
165
+ raise ValueError(
166
+ f"Cannot duplicate `image` of batch size {image.shape[0]} to effective batch_size {batch_size} "
167
+ )
168
+
169
+ init_latents = [
170
+ retrieve_latents(
171
+ self.vae.encode(image[i : i + 1]), generator=generator[i]
172
+ )
173
+ for i in range(batch_size)
174
+ ]
175
+ init_latents = torch.cat(init_latents, dim=0)
176
+ else:
177
+ init_latents = retrieve_latents(
178
+ self.vae.encode(image), generator=generator
179
+ )
180
+
181
+ if self.vae.config.force_upcast:
182
+ self.vae.to(dtype)
183
+
184
+ init_latents = init_latents.to(dtype)
185
+ if latents_mean is not None and latents_std is not None:
186
+ latents_mean = latents_mean.to(device=device, dtype=dtype)
187
+ latents_std = latents_std.to(device=device, dtype=dtype)
188
+ init_latents = (
189
+ (init_latents - latents_mean)
190
+ * self.vae.config.scaling_factor
191
+ / latents_std
192
+ )
193
+ else:
194
+ init_latents = self.vae.config.scaling_factor * init_latents
195
+
196
+ if (
197
+ batch_size > init_latents.shape[0]
198
+ and batch_size % init_latents.shape[0] == 0
199
+ ):
200
+ # expand init_latents for batch_size
201
+ additional_image_per_prompt = batch_size // init_latents.shape[0]
202
+ init_latents = torch.cat(
203
+ [init_latents] * additional_image_per_prompt, dim=0
204
+ )
205
+ elif (
206
+ batch_size > init_latents.shape[0]
207
+ and batch_size % init_latents.shape[0] != 0
208
+ ):
209
+ raise ValueError(
210
+ f"Cannot duplicate `image` of batch size {init_latents.shape[0]} to {batch_size} text prompts."
211
+ )
212
+ else:
213
+ init_latents = torch.cat([init_latents], dim=0)
214
+
215
+ if add_noise:
216
+ shape = init_latents.shape
217
+ noise = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
218
+ # get latents
219
+ init_latents = self.scheduler.add_noise(init_latents, noise, timestep)
220
+
221
+ latents = init_latents
222
+
223
+ return latents
224
+
225
+ def prepare_control_image(
226
+ self,
227
+ image,
228
+ width,
229
+ height,
230
+ batch_size,
231
+ num_images_per_prompt,
232
+ device,
233
+ dtype,
234
+ do_classifier_free_guidance=False,
235
+ num_empty_images=0, # for concat in batch like ImageDream
236
+ ):
237
+ """
238
+ Accepts either:
239
+ - regular RGB-like images -> preprocess via VaeImageProcessor, or
240
+ - native 6-channel Plücker tensors (B,6,H,W) or (6,H,W) -> pass through without normalization
241
+ """
242
+ assert hasattr(
243
+ self, "control_image_processor"
244
+ ), "control_image_processor is not initialized"
245
+
246
+ # Fast path: native 6-channel tensor
247
+ if isinstance(image, torch.Tensor):
248
+ if image.dim() == 3 and image.shape[0] == 6:
249
+ image = image.unsqueeze(0) # (1,6,H,W)
250
+ if image.dim() == 4 and image.shape[1] == 6:
251
+ ctrl = image.to(device=device, dtype=torch.float32)
252
+ if num_empty_images > 0:
253
+ ctrl = torch.cat([ctrl, torch.zeros_like(ctrl[:num_empty_images])], dim=0)
254
+
255
+ image_batch_size = ctrl.shape[0]
256
+ repeat_by = batch_size if image_batch_size == 1 else num_images_per_prompt # always 1 per control
257
+ ctrl = ctrl.repeat_interleave(repeat_by, dim=0)
258
+ ctrl = ctrl.to(device=device, dtype=dtype)
259
+
260
+ if do_classifier_free_guidance:
261
+ ctrl = torch.cat([ctrl] * 2)
262
+
263
+ return ctrl
264
+
265
+ # Fallback: treat as regular image(s)
266
+ image = self.control_image_processor.preprocess(
267
+ image, height=height, width=width
268
+ ).to(dtype=torch.float32)
269
+
270
+ if num_empty_images > 0:
271
+ image = torch.cat(
272
+ [image, torch.zeros_like(image[:num_empty_images])], dim=0
273
+ )
274
+
275
+ image_batch_size = image.shape[0]
276
+
277
+ if image_batch_size == 1:
278
+ repeat_by = batch_size
279
+ else:
280
+ # image batch size is the same as prompt batch size
281
+ repeat_by = num_images_per_prompt # always 1 for control image
282
+
283
+ image = image.repeat_interleave(repeat_by, dim=0)
284
+
285
+ image = image.to(device=device, dtype=dtype)
286
+
287
+ if do_classifier_free_guidance:
288
+ image = torch.cat([image] * 2)
289
+
290
+ return image
291
+
292
+ @torch.no_grad()
293
+ def __call__(
294
+ self,
295
+ prompt: Union[str, List[str]] = None,
296
+ prompt_2: Optional[Union[str, List[str]]] = None,
297
+ # --- Task 1: NEW (reference-only prompt used only for the ref cache pass) ---
298
+ reference_prompt: Optional[Union[str, List[str]]] = None,
299
+ reference_prompt_2: Optional[Union[str, List[str]]] = None,
300
+ # -----------------------------------------------------------------------------
301
+ height: Optional[int] = None,
302
+ width: Optional[int] = None,
303
+ num_inference_steps: int = 50,
304
+ timesteps: List[int] = None,
305
+ denoising_end: Optional[float] = None,
306
+ guidance_scale: float = 5.0,
307
+ negative_prompt: Optional[Union[str, List[str]]] = None,
308
+ negative_prompt_2: Optional[Union[str, List[str]]] = None,
309
+ num_images_per_prompt: Optional[int] = 1,
310
+ eta: float = 0.0,
311
+ generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
312
+ latents: Optional[torch.FloatTensor] = None,
313
+ prompt_embeds: Optional[torch.FloatTensor] = None,
314
+ negative_prompt_embeds: Optional[torch.FloatTensor] = None,
315
+ pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
316
+ negative_pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
317
+ ip_adapter_image: Optional[PipelineImageInput] = None,
318
+ ip_adapter_image_embeds: Optional[List[torch.FloatTensor]] = None,
319
+ output_type: Optional[str] = "pil",
320
+ return_dict: bool = True,
321
+ cross_attention_kwargs: Optional[Dict[str, Any]] = None,
322
+ guidance_rescale: float = 0.0,
323
+ original_size: Optional[Tuple[int, int]] = None,
324
+ crops_coords_top_left: Tuple[int, int] = (0, 0),
325
+ target_size: Optional[Tuple[int, int]] = None,
326
+ negative_original_size: Optional[Tuple[int, int]] = None,
327
+ negative_crops_coords_top_left: Tuple[int, int] = (0, 0),
328
+ negative_target_size: Optional[Tuple[int, int]] = None,
329
+ clip_skip: Optional[int] = None,
330
+ callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None,
331
+ callback_on_step_end_tensor_inputs: List[str] = ["latents"],
332
+ # NEW
333
+ mv_scale: float = 1.0,
334
+ # Camera or geometry condition
335
+ control_image: Optional[PipelineImageInput] = None,
336
+ control_conditioning_scale: Optional[float] = 1.0,
337
+ control_conditioning_factor: float = 1.0,
338
+ # Image condition
339
+ reference_image: Optional[PipelineImageInput] = None,
340
+ reference_conditioning_scale: Optional[float] = 1.0,
341
+ **kwargs,
342
+ ):
343
+ r"""
344
+ Function invoked when calling the pipeline for generation.
345
+
346
+ Args:
347
+ prompt (`str` or `List[str]`, *optional*):
348
+ The main prompt(s) for generation.
349
+ prompt_2 (`str` or `List[str]`, *optional*):
350
+ Prompt(s) for the second text encoder. Falls back to `prompt` if None.
351
+ reference_prompt (`str` or `List[str]`, *optional*):
352
+ Prompt used **only** during the one-shot reference UNet pass that caches identity features
353
+ from `reference_image`. If None or empty, falls back to the positive branch of the main prompt
354
+ (original behavior).
355
+ reference_prompt_2 (`str` or `List[str]`, *optional*):
356
+ Second-encoder counterpart for `reference_prompt`.
357
+ ... (other arguments unchanged) ...
358
+ """
359
+
360
+ callback = kwargs.pop("callback", None)
361
+ callback_steps = kwargs.pop("callback_steps", None)
362
+
363
+ if callback is not None:
364
+ deprecate(
365
+ "callback",
366
+ "1.0.0",
367
+ "Passing `callback` as an input argument to `__call__` is deprecated, consider use `callback_on_step_end`",
368
+ )
369
+ if callback_steps is not None:
370
+ deprecate(
371
+ "callback_steps",
372
+ "1.0.0",
373
+ "Passing `callback_steps` as an input argument to `__call__` is deprecated, consider use `callback_on_step_end`",
374
+ )
375
+
376
+ # 0. Default height and width to unet
377
+ height = height or self.default_sample_size * self.vae_scale_factor
378
+ width = width or self.default_sample_size * self.vae_scale_factor
379
+
380
+ original_size = original_size or (height, width)
381
+ target_size = target_size or (height, width)
382
+
383
+ # 1. Check inputs. Raise error if not correct
384
+ self.check_inputs(
385
+ prompt,
386
+ prompt_2,
387
+ height,
388
+ width,
389
+ callback_steps,
390
+ negative_prompt,
391
+ negative_prompt_2,
392
+ prompt_embeds,
393
+ negative_prompt_embeds,
394
+ pooled_prompt_embeds,
395
+ negative_pooled_prompt_embeds,
396
+ ip_adapter_image,
397
+ ip_adapter_image_embeds,
398
+ callback_on_step_end_tensor_inputs,
399
+ )
400
+
401
+ self._guidance_scale = guidance_scale
402
+ self._guidance_rescale = guidance_rescale
403
+ self._clip_skip = clip_skip
404
+ self._cross_attention_kwargs = cross_attention_kwargs
405
+ self._denoising_end = denoising_end
406
+ self._interrupt = False
407
+
408
+ # 2. Define call parameters
409
+ if prompt is not None and isinstance(prompt, str):
410
+ batch_size = 1
411
+ elif prompt is not None and isinstance(prompt, list):
412
+ batch_size = len(prompt)
413
+ else:
414
+ batch_size = prompt_embeds.shape[0]
415
+
416
+ device = self._execution_device
417
+
418
+ # 3. Encode input prompt
419
+ lora_scale = (
420
+ self.cross_attention_kwargs.get("scale", None)
421
+ if self.cross_attention_kwargs is not None
422
+ else None
423
+ )
424
+
425
+ (
426
+ prompt_embeds,
427
+ negative_prompt_embeds,
428
+ pooled_prompt_embeds,
429
+ negative_pooled_prompt_embeds,
430
+ ) = self.encode_prompt(
431
+ prompt=prompt,
432
+ prompt_2=prompt_2,
433
+ device=device,
434
+ num_images_per_prompt=num_images_per_prompt,
435
+ do_classifier_free_guidance=self.do_classifier_free_guidance,
436
+ negative_prompt=negative_prompt,
437
+ negative_prompt_2=negative_prompt_2,
438
+ prompt_embeds=prompt_embeds,
439
+ negative_prompt_embeds=negative_prompt_embeds,
440
+ pooled_prompt_embeds=pooled_prompt_embeds,
441
+ negative_pooled_prompt_embeds=negative_pooled_prompt_embeds,
442
+ lora_scale=lora_scale,
443
+ clip_skip=self.clip_skip,
444
+ )
445
+
446
+ # 4. Prepare timesteps
447
+ timesteps, num_inference_steps = retrieve_timesteps(
448
+ self.scheduler, num_inference_steps, device, timesteps
449
+ )
450
+
451
+ # 5. Prepare latent variables
452
+ num_channels_latents = self.unet.config.in_channels
453
+ latents = self.prepare_latents(
454
+ batch_size * num_images_per_prompt,
455
+ num_channels_latents,
456
+ height,
457
+ width,
458
+ prompt_embeds.dtype,
459
+ device,
460
+ generator,
461
+ latents,
462
+ )
463
+
464
+ # 6. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline
465
+ extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)
466
+
467
+ # 7. Prepare added time ids & embeddings
468
+ add_text_embeds = pooled_prompt_embeds
469
+ if self.text_encoder_2 is None:
470
+ text_encoder_projection_dim = int(pooled_prompt_embeds.shape[-1])
471
+ else:
472
+ text_encoder_projection_dim = self.text_encoder_2.config.projection_dim
473
+
474
+ add_time_ids = self._get_add_time_ids(
475
+ original_size,
476
+ crops_coords_top_left,
477
+ target_size,
478
+ dtype=prompt_embeds.dtype,
479
+ text_encoder_projection_dim=text_encoder_projection_dim,
480
+ )
481
+ if negative_original_size is not None and negative_target_size is not None:
482
+ negative_add_time_ids = self._get_add_time_ids(
483
+ negative_original_size,
484
+ negative_crops_coords_top_left,
485
+ negative_target_size,
486
+ dtype=prompt_embeds.dtype,
487
+ text_encoder_projection_dim=text_encoder_projection_dim,
488
+ )
489
+ else:
490
+ negative_add_time_ids = add_time_ids
491
+
492
+ if self.do_classifier_free_guidance:
493
+ prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds], dim=0)
494
+ add_text_embeds = torch.cat(
495
+ [negative_pooled_prompt_embeds, add_text_embeds], dim=0
496
+ )
497
+ add_time_ids = torch.cat([negative_add_time_ids, add_time_ids], dim=0)
498
+
499
+ prompt_embeds = prompt_embeds.to(device)
500
+ add_text_embeds = add_text_embeds.to(device)
501
+ add_time_ids = add_time_ids.to(device).repeat(
502
+ batch_size * num_images_per_prompt, 1
503
+ )
504
+
505
+ if ip_adapter_image is not None or ip_adapter_image_embeds is not None:
506
+ image_embeds = self.prepare_ip_adapter_image_embeds(
507
+ ip_adapter_image,
508
+ ip_adapter_image_embeds,
509
+ device,
510
+ batch_size * num_images_per_prompt,
511
+ self.do_classifier_free_guidance,
512
+ )
513
+
514
+ # Preprocess reference image (required)
515
+ reference_image = self.image_processor.preprocess(reference_image)
516
+ reference_latents = self.prepare_image_latents(
517
+ reference_image,
518
+ timesteps[:1].repeat(batch_size * num_images_per_prompt),  # not used since add_noise=False
519
+ batch_size,
520
+ 1,
521
+ prompt_embeds.dtype,
522
+ device,
523
+ generator,
524
+ add_noise=False,
525
+ )
526
+
527
+ with torch.no_grad():
528
+ ref_timesteps = torch.zeros_like(timesteps[0])
529
+ ref_hidden_states = {}
530
+
531
+ # reference-only prompt support (Task 1)
532
+ def _first_or_none(x):
533
+ if x is None:
534
+ return None
535
+ if isinstance(x, list) and len(x) > 0:
536
+ return x[0]
537
+ return x
538
+
539
+ rp = _first_or_none(reference_prompt)
540
+ rp2 = _first_or_none(reference_prompt_2)
541
+ have_ref_prompt = (rp is not None and str(rp).strip() != "") or (
542
+ rp2 is not None and str(rp2).strip() != ""
543
+ )
544
+
545
+ if have_ref_prompt:
546
+ ref_prompt_embeds, _, ref_pooled_prompt_embeds, _ = self.encode_prompt(
547
+ prompt=rp or prompt,
548
+ prompt_2=rp2 or prompt_2,
549
+ device=device,
550
+ num_images_per_prompt=1,
551
+ do_classifier_free_guidance=False,
552
+ prompt_embeds=None,
553
+ negative_prompt_embeds=None,
554
+ pooled_prompt_embeds=None,
555
+ negative_pooled_prompt_embeds=None,
556
+ lora_scale=(
557
+ self.cross_attention_kwargs.get("scale", None)
558
+ if self.cross_attention_kwargs is not None
559
+ else None
560
+ ),
561
+ clip_skip=self.clip_skip,
562
+ )
563
+ else:
564
+ if self.do_classifier_free_guidance:
565
+ ref_prompt_embeds = prompt_embeds[-1:].clone()
566
+ ref_pooled_prompt_embeds = add_text_embeds[-1:].clone()
567
+ else:
568
+ ref_prompt_embeds = prompt_embeds[:1].clone()
569
+ ref_pooled_prompt_embeds = add_text_embeds[:1].clone()
570
+
571
+ self.unet(
572
+ reference_latents,
573
+ ref_timesteps,
574
+ encoder_hidden_states=ref_prompt_embeds,
575
+ added_cond_kwargs={
576
+ "text_embeds": ref_pooled_prompt_embeds,
577
+ "time_ids": add_time_ids[-1:],
578
+ },
579
+ cross_attention_kwargs={
580
+ "cache_hidden_states": ref_hidden_states,
581
+ "use_mv": False,
582
+ "use_ref": False,
583
+ },
584
+ return_dict=False,
585
+ )
586
+ ref_hidden_states = {
587
+ k: v.repeat_interleave(num_images_per_prompt, dim=0)
588
+ for k, v in ref_hidden_states.items()
589
+ }
590
+ if self.do_classifier_free_guidance:
591
+ ref_hidden_states = {
592
+ k: torch.cat([torch.zeros_like(v), v], dim=0)
593
+ for k, v in ref_hidden_states.items()
594
+ }
595
+
596
+ cross_attention_kwargs = {
597
+ "mv_scale": mv_scale,
598
+ "ref_hidden_states": ref_hidden_states,
599
+ "ref_scale": reference_conditioning_scale,
600
+ **(self.cross_attention_kwargs or {}),
601
+ }
602
+
603
+ # ------------- control image (Task 2 supports 6ch pass-through) -------------
604
+ control_image_feature = self.prepare_control_image(
605
+ image=control_image,
606
+ width=width,
607
+ height=height,
608
+ batch_size=batch_size * num_images_per_prompt,
609
+ num_images_per_prompt=1, # NOTE: always 1 for control images
610
+ device=device,
611
+ dtype=latents.dtype,
612
+ do_classifier_free_guidance=self.do_classifier_free_guidance,
613
+ ).to(device=device, dtype=latents.dtype)
614
+
615
+ adapter_state = self.cond_encoder(control_image_feature)
616
+ for i, state in enumerate(adapter_state):
617
+ adapter_state[i] = state * control_conditioning_scale
618
+ # ---------------------------------------------------------------------------
619
+
620
+ # 8. Denoising loop
621
+ num_warmup_steps = max(
622
+ len(timesteps) - num_inference_steps * self.scheduler.order, 0
623
+ )
624
+
625
+ # 8.1 Apply denoising_end
626
+ if (
627
+ self.denoising_end is not None
628
+ and isinstance(self.denoising_end, float)
629
+ and self.denoising_end > 0
630
+ and self.denoising_end < 1
631
+ ):
632
+ discrete_timestep_cutoff = int(
633
+ round(
634
+ self.scheduler.config.num_train_timesteps
635
+ - (self.denoising_end * self.scheduler.config.num_train_timesteps)
636
+ )
637
+ )
638
+ num_inference_steps = len(
639
+ list(filter(lambda ts: ts >= discrete_timestep_cutoff, timesteps))
640
+ )
641
+ timesteps = timesteps[:num_inference_steps]
642
+
643
+ # 9. Optionally get Guidance Scale Embedding
644
+ timestep_cond = None
645
+ if self.unet.config.time_cond_proj_dim is not None:
646
+ guidance_scale_tensor = torch.tensor(self.guidance_scale - 1).repeat(
647
+ batch_size * num_images_per_prompt
648
+ )
649
+ timestep_cond = self.get_guidance_scale_embedding(
650
+ guidance_scale_tensor, embedding_dim=self.unet.config.time_cond_proj_dim
651
+ ).to(device=device, dtype=latents.dtype)
652
+
653
+ self._num_timesteps = len(timesteps)
654
+ with self.progress_bar(total=num_inference_steps) as progress_bar:
655
+ for i, t in enumerate(timesteps):
656
+ if self.interrupt:
657
+ continue
658
+
659
+ # expand the latents if we are doing classifier free guidance
660
+ latent_model_input = (
661
+ torch.cat([latents] * 2)
662
+ if self.do_classifier_free_guidance
663
+ else latents
664
+ )
665
+
666
+ latent_model_input = self.scheduler.scale_model_input(
667
+ latent_model_input, t
668
+ )
669
+
670
+ added_cond_kwargs = {
671
+ "text_embeds": add_text_embeds,
672
+ "time_ids": add_time_ids,
673
+ }
674
+ if ip_adapter_image is not None or ip_adapter_image_embeds is not None:
675
+ added_cond_kwargs["image_embeds"] = image_embeds
676
+
677
+ if i < int(num_inference_steps * control_conditioning_factor):
678
+ down_intrablock_additional_residuals = [
679
+ state.clone() for state in adapter_state
680
+ ]
681
+ else:
682
+ down_intrablock_additional_residuals = None
683
+
684
+ # predict the noise residual
685
+ noise_pred = self.unet(
686
+ latent_model_input,
687
+ t,
688
+ encoder_hidden_states=prompt_embeds,
689
+ timestep_cond=timestep_cond,
690
+ cross_attention_kwargs=cross_attention_kwargs,
691
+ down_intrablock_additional_residuals=down_intrablock_additional_residuals,
692
+ added_cond_kwargs=added_cond_kwargs,
693
+ return_dict=False,
694
+ )[0]
695
+
696
+ # perform guidance
697
+ if self.do_classifier_free_guidance:
698
+ noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
699
+ noise_pred = noise_pred_uncond + self.guidance_scale * (
700
+ noise_pred_text - noise_pred_uncond
701
+ )
702
+
703
+ if self.do_classifier_free_guidance and self.guidance_rescale > 0.0:
704
+ noise_pred = rescale_noise_cfg(
705
+ noise_pred,
706
+ noise_pred_text,
707
+ guidance_rescale=self.guidance_rescale,
708
+ )
709
+
710
+ # compute the previous noisy sample x_t -> x_t-1
711
+ latents_dtype = latents.dtype
712
+ latents = self.scheduler.step(
713
+ noise_pred, t, latents, **extra_step_kwargs, return_dict=False
714
+ )[0]
715
+ if latents.dtype != latents_dtype:
716
+ if torch.backends.mps.is_available():
717
+ # some platforms (eg. apple mps) misbehave due to a pytorch bug: https://github.com/pytorch/pytorch/pull/99272
718
+ latents = latents.to(latents_dtype)
719
+
720
+ if callback_on_step_end is not None:
721
+ callback_kwargs = {}
722
+ for k in callback_on_step_end_tensor_inputs:
723
+ callback_kwargs[k] = locals()[k]
724
+ callback_outputs = callback_on_step_end(self, i, t, callback_kwargs)
725
+
726
+ latents = callback_outputs.pop("latents", latents)
727
+ prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds)
728
+ negative_prompt_embeds = callback_outputs.pop(
729
+ "negative_prompt_embeds", negative_prompt_embeds
730
+ )
731
+ add_text_embeds = callback_outputs.pop(
732
+ "add_text_embeds", add_text_embeds
733
+ )
734
+ negative_pooled_prompt_embeds = callback_outputs.pop(
735
+ "negative_pooled_prompt_embeds", negative_pooled_prompt_embeds
736
+ )
737
+ add_time_ids = callback_outputs.pop("add_time_ids", add_time_ids)
738
+ negative_add_time_ids = callback_outputs.pop(
739
+ "negative_add_time_ids", negative_add_time_ids
740
+ )
741
+
742
+ # call the callback, if provided
743
+ if i == len(timesteps) - 1 or (
744
+ (i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0
745
+ ):
746
+ progress_bar.update()
747
+ if callback is not None and i % callback_steps == 0:
748
+ step_idx = i // getattr(self.scheduler, "order", 1)
749
+ callback(step_idx, t, latents)
750
+
751
+ if not output_type == "latent":
752
+ # make sure the VAE is in float32 mode, as it overflows in float16
753
+ needs_upcasting = (
754
+ self.vae.dtype == torch.float16 and self.vae.config.force_upcast
755
+ )
756
+
757
+ if needs_upcasting:
758
+ self.upcast_vae()
759
+ latents = latents.to(
760
+ next(iter(self.vae.post_quant_conv.parameters())).dtype
761
+ )
762
+ elif latents.dtype != self.vae.dtype:
763
+ if torch.backends.mps.is_available():
764
+ # some platforms (eg. apple mps) misbehave due to a pytorch bug: https://github.com/pytorch/pytorch/pull/99272
765
+ self.vae = self.vae.to(latents.dtype)
766
+
767
+ # unscale/denormalize the latents
768
+ # denormalize with the mean and std if available and not None
769
+ has_latents_mean = (
770
+ hasattr(self.vae.config, "latents_mean")
771
+ and self.vae.config.latents_mean is not None
772
+ )
773
+ has_latents_std = (
774
+ hasattr(self.vae.config, "latents_std")
775
+ and self.vae.config.latents_std is not None
776
+ )
777
+ if has_latents_mean and has_latents_std:
778
+ latents_mean = (
779
+ torch.tensor(self.vae.config.latents_mean)
780
+ .view(1, 4, 1, 1)
781
+ .to(latents.device, latents.dtype)
782
+ )
783
+ latents_std = (
784
+ torch.tensor(self.vae.config.latents_std)
785
+ .view(1, 4, 1, 1)
786
+ .to(latents.device, latents.dtype)
787
+ )
788
+ latents = (
789
+ latents * latents_std / self.vae.config.scaling_factor
790
+ + latents_mean
791
+ )
792
+ else:
793
+ latents = latents / self.vae.config.scaling_factor
794
+
795
+ image = self.vae.decode(latents, return_dict=False)[0]
796
+
797
+ # cast back to fp16 if needed
798
+ if needs_upcasting:
799
+ self.vae.to(dtype=torch.float16)
800
+ else:
801
+ image = latents
802
+
803
+ if not output_type == "latent":
804
+ # apply watermark if available
805
+ if self.watermark is not None:
806
+ image = self.watermark.apply_watermark(image)
807
+
808
+ image = self.image_processor.postprocess(image, output_type=output_type)
809
+
810
+ # Offload all models
811
+ self.maybe_free_model_hooks()
812
+
813
+ if not return_dict:
814
+ return (image,)
815
+
816
+ return StableDiffusionXLPipelineOutput(images=image)
817
+
818
+ ### NEW: adapters ###
819
+ def _init_custom_adapter(
820
+ self,
821
+ # Multi-view adapter
822
+ num_views: int = 1,
823
+ self_attn_processor: Any = DecoupledMVRowSelfAttnProcessor2_0,
824
+ # Condition encoder
825
+ cond_in_channels: int = 6,
826
+ # For training
827
+ copy_attn_weights: bool = True,
828
+ zero_init_module_keys: List[str] = [],
829
+ ):
830
+ # Condition encoder
831
+ self.cond_encoder = T2IAdapter(
832
+ in_channels=cond_in_channels,
833
+ channels=(320, 640, 1280, 1280),
834
+ num_res_blocks=2,
835
+ downscale_factor=16,
836
+ adapter_type="full_adapter_xl",
837
+ )
838
+
839
+ # set custom attn processor for multi-view attention and image cross-attention
840
+ self.unet: UNet2DConditionModel
841
+ set_unet_2d_condition_attn_processor(
842
+ self.unet,
843
+ set_self_attn_proc_func=lambda name, hs, cad, ap: self_attn_processor(
844
+ query_dim=hs,
845
+ inner_dim=hs,
846
+ num_views=num_views,
847
+ name=name,
848
+ use_mv=True,
849
+ use_ref=True,
850
+ ),
851
+ set_cross_attn_proc_func=lambda name, hs, cad, ap: self_attn_processor(
852
+ query_dim=hs,
853
+ inner_dim=hs,
854
+ num_views=num_views,
855
+ name=name,
856
+ use_mv=False,
857
+ use_ref=False,
858
+ ),
859
+ )
860
+
861
+ # copy decoupled attention weights from original unet
862
+ if copy_attn_weights:
863
+ state_dict = self.unet.state_dict()
864
+ for key in state_dict.keys():
865
+ if "_mv" in key:
866
+ compatible_key = key.replace("_mv", "").replace("processor.", "")
867
+ elif "_ref" in key:
868
+ compatible_key = key.replace("_ref", "").replace("processor.", "")
869
+ else:
870
+ compatible_key = key
871
+
872
+ is_zero_init_key = any([k in key for k in zero_init_module_keys])
873
+ if is_zero_init_key:
874
+ state_dict[key] = torch.zeros_like(state_dict[compatible_key])
875
+ else:
876
+ state_dict[key] = state_dict[compatible_key].clone()
877
+ self.unet.load_state_dict(state_dict)
878
+
879
+ def _load_custom_adapter(self, state_dict):
880
+ self.unet.load_state_dict(state_dict, strict=False)
881
+ self.cond_encoder.load_state_dict(state_dict, strict=False)
882
+
883
+ def _save_custom_adapter(
884
+ self,
885
+ include_keys: Optional[List[str]] = None,
886
+ exclude_keys: Optional[List[str]] = None,
887
+ ):
888
+ def include_fn(k):
889
+ is_included = False
890
+
891
+ if include_keys is not None:
892
+ is_included = is_included or any([key in k for key in include_keys])
893
+ if exclude_keys is not None:
894
+ is_included = is_included and not any(
895
+ [key in k for key in exclude_keys]
896
+ )
897
+
898
+ return is_included
899
+
900
+ state_dict = {k: v for k, v in self.unet.state_dict().items() if include_fn(k)}
901
+ state_dict.update(self.cond_encoder.state_dict())
902
+
903
+ return state_dict
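An end-to-end usage sketch for `MVAdapterI2MVSDXLPipeline`. The adapter repo id and weight name are placeholders taken from the upstream MV-Adapter release and may differ, and the zero tensor stands in for the real per-view camera (Plücker) maps that the ComfyUI nodes normally supply as `control_image`:

    import torch
    from PIL import Image

    from mvadapter.pipelines.pipeline_mvadapter_i2mv_sdxl import MVAdapterI2MVSDXLPipeline

    pipe = MVAdapterI2MVSDXLPipeline.from_pretrained(
        "stabilityai/stable-diffusion-xl-base-1.0", torch_dtype=torch.float16
    )
    pipe.init_custom_adapter(num_views=6)
    pipe.load_custom_adapter("huanngzh/mv-adapter", weight_name="mvadapter_i2mv_sdxl.safetensors")  # placeholder ids
    pipe.to("cuda", torch.float16)
    pipe.cond_encoder.to("cuda", torch.float16)  # the T2IAdapter is not registered with the pipeline

    reference = Image.open("object.png").convert("RGB").resize((768, 768))
    plucker = torch.zeros(6, 6, 768, 768)  # stand-in for 6 views x 6-channel camera embeddings

    images = pipe(
        prompt="high quality, best quality",
        negative_prompt="watermark, blurry, worst quality",
        reference_image=reference,
        control_image=plucker,
        num_images_per_prompt=6,   # one latent per view
        num_inference_steps=30,
        guidance_scale=3.0,
        mv_scale=1.0,
        reference_conditioning_scale=1.0,
        height=768,
        width=768,
    ).images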
comfyui-mvadapter/mvadapter/schedulers/ShiftSNRSchedulerKarras.py ADDED
@@ -0,0 +1,120 @@
1
+ from typing import Any
2
+
3
+ import torch
4
+
5
+ from .scheduler_utils import SNR_to_betas, compute_snr
6
+
7
+
8
+ class ShiftSNRSchedulerKarras:
9
+ """
10
+ Wraps a Diffusers scheduler to apply SNR shifting to its noise schedule and
11
+ rebuilds a DPMSolverMultistepScheduler that supports Karras sigmas.
12
+
13
+ Usage:
14
+ new_sched = ShiftSNRSchedulerKarras.from_scheduler(
15
+ noise_scheduler=base_sched,
16
+ shift_mode="interpolated", # or "default"
17
+ shift_scale=8.0,
18
+ scheduler_class=DPMSolverMultistepScheduler, # usually this
19
+ )
20
+ """
21
+
22
+ # Supported modes for how the SNR shift is applied
23
+ SHIFT_MODES = ["default", "interpolated"]
24
+
25
+ def __init__(
26
+ self,
27
+ noise_scheduler: Any,
28
+ timesteps: Any,
29
+ shift_scale: float,
30
+ scheduler_class: Any,
31
+ ):
32
+ # original scheduler (used only as a reference/config source)
33
+ self.noise_scheduler = noise_scheduler
34
+ # tensor of timesteps to compute SNR/betas on
35
+ self.timesteps = timesteps
36
+ # scale by which to divide the SNR (e.g., 8.0)
37
+ self.shift_scale = shift_scale
38
+ # the scheduler class to construct for output (e.g., DPMSolverMultistepScheduler)
39
+ self.scheduler_class = scheduler_class
40
+
41
+ def _get_shift_scheduler(self):
42
+ """
43
+ Apply a uniform SNR shift: snr' = snr / shift_scale
44
+ Then convert to betas and rebuild the scheduler with Karras enabled.
45
+ """
46
+ snr = compute_snr(self.timesteps, self.noise_scheduler)
47
+ shifted_betas = SNR_to_betas(snr / self.shift_scale)
48
+
49
+ return self.scheduler_class.from_config(
50
+ self.noise_scheduler.config,
51
+ trained_betas=shifted_betas.numpy(),
52
+ # Enable Karras sigmas in the rebuilt scheduler
53
+ algorithm_type="dpmsolver++",
54
+ use_karras_sigmas=True,
55
+ )
56
+
57
+ def _get_interpolated_shift_scheduler(self):
58
+ """
59
+ Interpolate SNR in log-space between the original and the shifted SNR
60
+ as timesteps progress. This tends to preserve early behavior and
61
+ gradually apply the shift later in the schedule.
62
+ """
63
+ snr = compute_snr(self.timesteps, self.noise_scheduler)
64
+ shifted_snr = snr / self.shift_scale
65
+
66
+ # Interpolate in log-space from original -> shifted across timesteps
67
+ weighting = self.timesteps.float() / (
68
+ self.noise_scheduler.config.num_train_timesteps - 1
69
+ )
70
+ interpolated_snr = torch.exp(
71
+ torch.log(snr) * (1 - weighting) + torch.log(shifted_snr) * weighting
72
+ )
73
+
74
+ shifted_betas = SNR_to_betas(interpolated_snr)
75
+
76
+ return self.scheduler_class.from_config(
77
+ self.noise_scheduler.config,
78
+ trained_betas=shifted_betas.numpy(),
79
+ # Enable Karras sigmas in the rebuilt scheduler
80
+ algorithm_type="dpmsolver++",
81
+ use_karras_sigmas=True,
82
+ )
83
+
84
+ @classmethod
85
+ def from_scheduler(
86
+ cls,
87
+ noise_scheduler: Any,
88
+ shift_mode: str = "default",
89
+ timesteps: Any = None,
90
+ shift_scale: float = 1.0,
91
+ scheduler_class: Any = None,
92
+ ):
93
+ """
94
+ Factory that returns a NEW scheduler instance with the shifted betas applied.
95
+
96
+ Args:
97
+ noise_scheduler: the original Diffusers scheduler (used for config & base betas)
98
+ shift_mode: "default" or "interpolated"
99
+ timesteps: tensor of timesteps to evaluate SNR on; if None, uses full training range
100
+ shift_scale: divide SNR by this value (e.g., 8.0)
101
+ scheduler_class: class to construct for the output scheduler (defaults to original class)
102
+ """
103
+ if timesteps is None:
104
+ timesteps = torch.arange(0, noise_scheduler.config.num_train_timesteps)
105
+ if scheduler_class is None:
106
+ scheduler_class = noise_scheduler.__class__
107
+
108
+ wrapper = cls(
109
+ noise_scheduler=noise_scheduler,
110
+ timesteps=timesteps,
111
+ shift_scale=shift_scale,
112
+ scheduler_class=scheduler_class,
113
+ )
114
+
115
+ if shift_mode == "default":
116
+ return wrapper._get_shift_scheduler()
117
+ elif shift_mode == "interpolated":
118
+ return wrapper._get_interpolated_shift_scheduler()
119
+ else:
120
+ raise ValueError(f"Unknown shift_mode: {shift_mode}")
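A short usage sketch mirroring the docstring above, built on a stand-alone DDPM base schedule with SD-style scaled-linear betas (shift_scale=8.0 is just the value the docstring suggests):

    from diffusers import DDPMScheduler, DPMSolverMultistepScheduler

    from mvadapter.schedulers.ShiftSNRSchedulerKarras import ShiftSNRSchedulerKarras

    base = DDPMScheduler(
        num_train_timesteps=1000,
        beta_start=0.00085,
        beta_end=0.012,
        beta_schedule="scaled_linear",
    )

    shifted = ShiftSNRSchedulerKarras.from_scheduler(
        noise_scheduler=base,
        shift_mode="interpolated",
        shift_scale=8.0,
        scheduler_class=DPMSolverMultistepScheduler,
    )
    # pipe.scheduler = shifted  # swap into an already-loaded pipeline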
comfyui-mvadapter/mvadapter/schedulers/__pycache__/ShiftSNRSchedulerKarras.cpython-312.pyc ADDED
Binary file (4.86 kB). View file
 
comfyui-mvadapter/mvadapter/schedulers/__pycache__/scheduler_utils.cpython-312.pyc ADDED
Binary file (3.78 kB). View file
 
comfyui-mvadapter/mvadapter/schedulers/__pycache__/scheduling_shift_snr.cpython-312.pyc ADDED
Binary file (5.9 kB). View file
 
comfyui-mvadapter/mvadapter/schedulers/scheduler_utils.py ADDED
@@ -0,0 +1,70 @@
1
+ import torch
2
+
3
+
4
+ def get_sigmas(noise_scheduler, timesteps, n_dim=4, dtype=torch.float32, device=None):
5
+ sigmas = noise_scheduler.sigmas.to(device=device, dtype=dtype)
6
+ schedule_timesteps = noise_scheduler.timesteps.to(device)
7
+ timesteps = timesteps.to(device)
8
+ step_indices = [(schedule_timesteps == t).nonzero().item() for t in timesteps]
9
+ sigma = sigmas[step_indices].flatten()
10
+ while len(sigma.shape) < n_dim:
11
+ sigma = sigma.unsqueeze(-1)
12
+ return sigma
13
+
14
+
15
+ def SNR_to_betas(snr):
16
+ """
17
+ Converts SNR to betas
18
+ """
19
+ # alphas_cumprod = pass
20
+ # snr = (alpha / ) ** 2
21
+ # alpha_t^2 / (1 - alpha_t^2) = snr
22
+ alpha_t = (snr / (1 + snr)) ** 0.5
23
+ alphas_cumprod = alpha_t**2
24
+ alphas = alphas_cumprod / torch.cat(
25
+ [torch.ones(1, device=snr.device), alphas_cumprod[:-1]]
26
+ )
27
+ betas = 1 - alphas
28
+ return betas
29
+
30
+
+ def compute_snr(timesteps, noise_scheduler):
+     """
+     Computes SNR as per Min-SNR-Diffusion-Training/guided_diffusion/gaussian_diffusion.py at 521b624bd70c67cee4bdf49225915f5
+     """
+     alphas_cumprod = noise_scheduler.alphas_cumprod
+     sqrt_alphas_cumprod = alphas_cumprod**0.5
+     sqrt_one_minus_alphas_cumprod = (1.0 - alphas_cumprod) ** 0.5
+
+     # Expand the tensors.
+     # Adapted from Min-SNR-Diffusion-Training/guided_diffusion/gaussian_diffusion.py at 521b624bd70c67cee4bdf49225915f5
+     sqrt_alphas_cumprod = sqrt_alphas_cumprod.to(device=timesteps.device)[
+         timesteps
+     ].float()
+     while len(sqrt_alphas_cumprod.shape) < len(timesteps.shape):
+         sqrt_alphas_cumprod = sqrt_alphas_cumprod[..., None]
+     alpha = sqrt_alphas_cumprod.expand(timesteps.shape)
+
+     sqrt_one_minus_alphas_cumprod = sqrt_one_minus_alphas_cumprod.to(
+         device=timesteps.device
+     )[timesteps].float()
+     while len(sqrt_one_minus_alphas_cumprod.shape) < len(timesteps.shape):
+         sqrt_one_minus_alphas_cumprod = sqrt_one_minus_alphas_cumprod[..., None]
+     sigma = sqrt_one_minus_alphas_cumprod.expand(timesteps.shape)
+
+     # Compute SNR.
+     snr = (alpha / sigma) ** 2
+     return snr
+
+
+ def compute_alpha(timesteps, noise_scheduler):
+     alphas_cumprod = noise_scheduler.alphas_cumprod
+     sqrt_alphas_cumprod = alphas_cumprod**0.5
+     sqrt_alphas_cumprod = sqrt_alphas_cumprod.to(device=timesteps.device)[
+         timesteps
+     ].float()
+     while len(sqrt_alphas_cumprod.shape) < len(timesteps.shape):
+         sqrt_alphas_cumprod = sqrt_alphas_cumprod[..., None]
+     alpha = sqrt_alphas_cumprod.expand(timesteps.shape)
+
+     return alpha
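To make the SNR/beta relationship used above concrete, here is a small sanity-check sketch. It is not part of the module; the stock DDPMScheduler configuration and the tolerance are illustrative assumptions.

import torch
from diffusers import DDPMScheduler
from mvadapter.schedulers.scheduler_utils import SNR_to_betas, compute_snr

scheduler = DDPMScheduler(num_train_timesteps=1000)
timesteps = torch.arange(0, scheduler.config.num_train_timesteps)

# compute_snr gives SNR_t = alphas_cumprod_t / (1 - alphas_cumprod_t).
snr = compute_snr(timesteps, scheduler)

# Converting the unshifted SNR back should approximately recover the
# scheduler's own beta schedule.
betas = SNR_to_betas(snr)
assert torch.allclose(betas.float(), scheduler.betas, atol=1e-4)

# Dividing the SNR by a scale > 1 produces a noisier (lower-SNR) schedule.
shifted_betas = SNR_to_betas(snr / 8.0)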
comfyui-mvadapter/mvadapter/schedulers/scheduling_shift_snr.py ADDED
@@ -0,0 +1,140 @@
+ from typing import Any
+
+ import torch
+
+ from .scheduler_utils import SNR_to_betas, compute_snr
+
+
+ class ShiftSNRScheduler:
+     SHIFT_MODES = ["default", "interpolated"]
+
+     def __init__(
+         self,
+         noise_scheduler: Any,
+         timesteps: Any,
+         shift_scale: float,
+         scheduler_class: Any,
+     ):
+         self.noise_scheduler = noise_scheduler
+         self.timesteps = timesteps
+         self.shift_scale = shift_scale
+         self.scheduler_class = scheduler_class
+
+     def _get_shift_scheduler(self):
+         """
+         Prepare scheduler for shifted betas.
+
+         :return: A scheduler object configured with shifted betas
+         """
+         snr = compute_snr(self.timesteps, self.noise_scheduler)
+         shifted_betas = SNR_to_betas(snr / self.shift_scale)
+
+         return self.scheduler_class.from_config(
+             self.noise_scheduler.config, trained_betas=shifted_betas.numpy()
+         )
+
+     def _get_interpolated_shift_scheduler(self):
+         """
+         Prepare scheduler for shifted betas and interpolate with the original betas in log space.
+
+         :return: A scheduler object configured with interpolated shifted betas
+         """
+         snr = compute_snr(self.timesteps, self.noise_scheduler)
+         shifted_snr = snr / self.shift_scale
+
+         weighting = self.timesteps.float() / (
+             self.noise_scheduler.config.num_train_timesteps - 1
+         )
+         interpolated_snr = torch.exp(
+             torch.log(snr) * (1 - weighting) + torch.log(shifted_snr) * weighting
+         )
+
+         shifted_betas = SNR_to_betas(interpolated_snr)
+
+         return self.scheduler_class.from_config(
+             self.noise_scheduler.config, trained_betas=shifted_betas.numpy()
+         )
+
+     @classmethod
+     def from_scheduler(
+         cls,
+         noise_scheduler: Any,
+         shift_mode: str = "default",
+         timesteps: Any = None,
+         shift_scale: float = 1.0,
+         scheduler_class: Any = None,
+     ):
+         # Check input
+         if timesteps is None:
+             timesteps = torch.arange(0, noise_scheduler.config.num_train_timesteps)
+         if scheduler_class is None:
+             scheduler_class = noise_scheduler.__class__
+
+         # Create scheduler
+         shift_scheduler = cls(
+             noise_scheduler=noise_scheduler,
+             timesteps=timesteps,
+             shift_scale=shift_scale,
+             scheduler_class=scheduler_class,
+         )
+
+         if shift_mode == "default":
+             return shift_scheduler._get_shift_scheduler()
+         elif shift_mode == "interpolated":
+             return shift_scheduler._get_interpolated_shift_scheduler()
+         else:
+             raise ValueError(f"Unknown shift_mode: {shift_mode}")
+
+
+ if __name__ == "__main__":
+     """
+     Compare the alpha values for different noise schedulers.
+     """
+     import matplotlib.pyplot as plt
+     from diffusers import DDPMScheduler
+
+     from .scheduler_utils import compute_alpha
+
+     # Base
+     timesteps = torch.arange(0, 1000)
+     noise_scheduler_base = DDPMScheduler.from_pretrained(
+         "runwayml/stable-diffusion-v1-5", subfolder="scheduler"
+     )
+     alpha = compute_alpha(timesteps, noise_scheduler_base)
+     plt.plot(timesteps.numpy(), alpha.numpy(), label="Base")
+
+     # Kolors
+     num_train_timesteps_ = 1100
+     timesteps_ = torch.arange(0, num_train_timesteps_)
+     noise_kwargs = {"beta_end": 0.014, "num_train_timesteps": num_train_timesteps_}
+     noise_scheduler_kolors = DDPMScheduler.from_config(
+         noise_scheduler_base.config, **noise_kwargs
+     )
+     alpha = compute_alpha(timesteps_, noise_scheduler_kolors)
+     plt.plot(timesteps_.numpy(), alpha.numpy(), label="Kolors")
+
+     # Shift betas
+     shift_scale = 8.0
+     noise_scheduler_shift = ShiftSNRScheduler.from_scheduler(
+         noise_scheduler_base, shift_mode="default", shift_scale=shift_scale
+     )
+     alpha = compute_alpha(timesteps, noise_scheduler_shift)
+     plt.plot(timesteps.numpy(), alpha.numpy(), label="Shift Noise (scale 8.0)")
+
+     # Shift betas (interpolated)
+     noise_scheduler_inter = ShiftSNRScheduler.from_scheduler(
+         noise_scheduler_base, shift_mode="interpolated", shift_scale=shift_scale
+     )
+     alpha = compute_alpha(timesteps, noise_scheduler_inter)
+     plt.plot(timesteps.numpy(), alpha.numpy(), label="Interpolated (scale 8.0)")
+
+     # ZeroSNR
+     noise_scheduler = DDPMScheduler.from_config(
+         noise_scheduler_base.config, rescale_betas_zero_snr=True
+     )
+     alpha = compute_alpha(timesteps, noise_scheduler)
+     plt.plot(timesteps.numpy(), alpha.numpy(), label="ZeroSNR")
+
+     plt.legend()
+     plt.grid()
+     plt.savefig("check_alpha.png")
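A note on the interpolated mode above: since shifted_snr = snr / shift_scale, the log-space blend is equivalent to snr / shift_scale ** w with w = t / (T - 1), so the effective shift grows smoothly from none at t = 0 to the full shift_scale at the final timestep. A short sketch checking that equivalence (illustrative only, not part of the module):

import torch
from diffusers import DDPMScheduler
from mvadapter.schedulers.scheduler_utils import compute_snr

scheduler = DDPMScheduler(num_train_timesteps=1000)
timesteps = torch.arange(0, scheduler.config.num_train_timesteps)
shift_scale = 8.0

snr = compute_snr(timesteps, scheduler)
w = timesteps.float() / (scheduler.config.num_train_timesteps - 1)

# Log-space interpolation, as in _get_interpolated_shift_scheduler.
interpolated = torch.exp(torch.log(snr) * (1 - w) + torch.log(snr / shift_scale) * w)

# Closed form: the divisor ramps from 1 up to shift_scale over the schedule.
closed_form = snr / shift_scale**w
assert torch.allclose(interpolated, closed_form, rtol=1e-4)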
comfyui-mvadapter/mvadapter/utils/__init__.py ADDED
@@ -0,0 +1,3 @@
+ from .camera import get_camera, get_orthogonal_camera
+ from .geometry import get_plucker_embeds_from_cameras_ortho
+ from .saving import make_image_grid, tensor_to_image