Yesianrohn commited on
Commit
65266c7
·
verified ·
1 Parent(s): 5978a74

Upload 700 files

Browse files
This view is limited to 50 files because it contains too many changes.   See raw diff
Files changed (50) hide show
  1. diffusers/__init__.py +944 -0
  2. diffusers/callbacks.py +156 -0
  3. diffusers/commands/__init__.py +27 -0
  4. diffusers/commands/diffusers_cli.py +43 -0
  5. diffusers/commands/env.py +180 -0
  6. diffusers/commands/fp16_safetensors.py +132 -0
  7. diffusers/configuration_utils.py +720 -0
  8. diffusers/dependency_versions_check.py +34 -0
  9. diffusers/dependency_versions_table.py +46 -0
  10. diffusers/experimental/README.md +5 -0
  11. diffusers/experimental/__init__.py +1 -0
  12. diffusers/experimental/rl/__init__.py +1 -0
  13. diffusers/experimental/rl/value_guided_sampling.py +153 -0
  14. diffusers/image_processor.py +1103 -0
  15. diffusers/loaders/__init__.py +102 -0
  16. diffusers/loaders/__pycache__/__init__.cpython-310.pyc +0 -0
  17. diffusers/loaders/__pycache__/__init__.cpython-38.pyc +0 -0
  18. diffusers/loaders/__pycache__/lora_base.cpython-310.pyc +0 -0
  19. diffusers/loaders/__pycache__/lora_base.cpython-38.pyc +0 -0
  20. diffusers/loaders/__pycache__/lora_conversion_utils.cpython-310.pyc +0 -0
  21. diffusers/loaders/__pycache__/lora_conversion_utils.cpython-38.pyc +0 -0
  22. diffusers/loaders/__pycache__/lora_pipeline.cpython-310.pyc +0 -0
  23. diffusers/loaders/__pycache__/lora_pipeline.cpython-38.pyc +0 -0
  24. diffusers/loaders/__pycache__/peft.cpython-310.pyc +0 -0
  25. diffusers/loaders/__pycache__/peft.cpython-38.pyc +0 -0
  26. diffusers/loaders/__pycache__/single_file_model.cpython-310.pyc +0 -0
  27. diffusers/loaders/__pycache__/single_file_model.cpython-38.pyc +0 -0
  28. diffusers/loaders/__pycache__/single_file_utils.cpython-310.pyc +0 -0
  29. diffusers/loaders/__pycache__/single_file_utils.cpython-38.pyc +0 -0
  30. diffusers/loaders/__pycache__/unet.cpython-310.pyc +0 -0
  31. diffusers/loaders/__pycache__/unet.cpython-38.pyc +0 -0
  32. diffusers/loaders/__pycache__/unet_loader_utils.cpython-310.pyc +0 -0
  33. diffusers/loaders/__pycache__/unet_loader_utils.cpython-38.pyc +0 -0
  34. diffusers/loaders/__pycache__/utils.cpython-310.pyc +0 -0
  35. diffusers/loaders/__pycache__/utils.cpython-38.pyc +0 -0
  36. diffusers/loaders/ip_adapter.py +348 -0
  37. diffusers/loaders/lora_base.py +759 -0
  38. diffusers/loaders/lora_conversion_utils.py +660 -0
  39. diffusers/loaders/lora_pipeline.py +0 -0
  40. diffusers/loaders/peft.py +396 -0
  41. diffusers/loaders/single_file.py +550 -0
  42. diffusers/loaders/single_file_model.py +318 -0
  43. diffusers/loaders/single_file_utils.py +2100 -0
  44. diffusers/loaders/textual_inversion.py +578 -0
  45. diffusers/loaders/unet.py +921 -0
  46. diffusers/loaders/unet_loader_utils.py +163 -0
  47. diffusers/loaders/utils.py +59 -0
  48. diffusers/models/README.md +3 -0
  49. diffusers/models/__init__.py +137 -0
  50. diffusers/models/__pycache__/__init__.cpython-310.pyc +0 -0
diffusers/__init__.py ADDED
@@ -0,0 +1,944 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ __version__ = "0.31.0.dev0"
2
+
3
+ from typing import TYPE_CHECKING
4
+
5
+ from .utils import (
6
+ DIFFUSERS_SLOW_IMPORT,
7
+ OptionalDependencyNotAvailable,
8
+ _LazyModule,
9
+ is_flax_available,
10
+ is_k_diffusion_available,
11
+ is_librosa_available,
12
+ is_note_seq_available,
13
+ is_onnx_available,
14
+ is_scipy_available,
15
+ is_sentencepiece_available,
16
+ is_torch_available,
17
+ is_torchsde_available,
18
+ is_transformers_available,
19
+ )
20
+
21
+
22
+ # Lazy Import based on
23
+ # https://github.com/huggingface/transformers/blob/main/src/transformers/__init__.py
24
+
25
+ # When adding a new object to this init, please add it to `_import_structure`. The `_import_structure` is a dictionary submodule to list of object names,
26
+ # and is used to defer the actual importing for when the objects are requested.
27
+ # This way `import diffusers` provides the names in the namespace without actually importing anything (and especially none of the backends).
28
+
29
+ _import_structure = {
30
+ "configuration_utils": ["ConfigMixin"],
31
+ "loaders": ["FromOriginalModelMixin"],
32
+ "models": [],
33
+ "pipelines": [],
34
+ "schedulers": [],
35
+ "utils": [
36
+ "OptionalDependencyNotAvailable",
37
+ "is_flax_available",
38
+ "is_inflect_available",
39
+ "is_invisible_watermark_available",
40
+ "is_k_diffusion_available",
41
+ "is_k_diffusion_version",
42
+ "is_librosa_available",
43
+ "is_note_seq_available",
44
+ "is_onnx_available",
45
+ "is_scipy_available",
46
+ "is_torch_available",
47
+ "is_torchsde_available",
48
+ "is_transformers_available",
49
+ "is_transformers_version",
50
+ "is_unidecode_available",
51
+ "logging",
52
+ ],
53
+ }
54
+
55
+ try:
56
+ if not is_onnx_available():
57
+ raise OptionalDependencyNotAvailable()
58
+ except OptionalDependencyNotAvailable:
59
+ from .utils import dummy_onnx_objects # noqa F403
60
+
61
+ _import_structure["utils.dummy_onnx_objects"] = [
62
+ name for name in dir(dummy_onnx_objects) if not name.startswith("_")
63
+ ]
64
+
65
+ else:
66
+ _import_structure["pipelines"].extend(["OnnxRuntimeModel"])
67
+
68
+ try:
69
+ if not is_torch_available():
70
+ raise OptionalDependencyNotAvailable()
71
+ except OptionalDependencyNotAvailable:
72
+ from .utils import dummy_pt_objects # noqa F403
73
+
74
+ _import_structure["utils.dummy_pt_objects"] = [name for name in dir(dummy_pt_objects) if not name.startswith("_")]
75
+
76
+ else:
77
+ _import_structure["models"].extend(
78
+ [
79
+ "AsymmetricAutoencoderKL",
80
+ "AuraFlowTransformer2DModel",
81
+ "AutoencoderKL",
82
+ "AutoencoderKLCogVideoX",
83
+ "AutoencoderKLTemporalDecoder",
84
+ "AutoencoderOobleck",
85
+ "AutoencoderTiny",
86
+ "CogVideoXTransformer3DModel",
87
+ "ConsistencyDecoderVAE",
88
+ "ControlNetModel",
89
+ "ControlNetXSAdapter",
90
+ "DiTTransformer2DModel",
91
+ "FluxControlNetModel",
92
+ "FluxMultiControlNetModel",
93
+ "FluxTransformer2DModel",
94
+ "HunyuanDiT2DControlNetModel",
95
+ "HunyuanDiT2DModel",
96
+ "HunyuanDiT2DMultiControlNetModel",
97
+ "I2VGenXLUNet",
98
+ "Kandinsky3UNet",
99
+ "LatteTransformer3DModel",
100
+ "LuminaNextDiT2DModel",
101
+ "ModelMixin",
102
+ "MotionAdapter",
103
+ "MultiAdapter",
104
+ "PixArtTransformer2DModel",
105
+ "PriorTransformer",
106
+ "SD3ControlNetModel",
107
+ "SD3MultiControlNetModel",
108
+ "SD3Transformer2DModel",
109
+ "SparseControlNetModel",
110
+ "StableAudioDiTModel",
111
+ "StableCascadeUNet",
112
+ "T2IAdapter",
113
+ "T5FilmDecoder",
114
+ "Transformer2DModel",
115
+ "UNet1DModel",
116
+ "UNet2DConditionModel",
117
+ "UNet2DModel",
118
+ "UNet3DConditionModel",
119
+ "UNetControlNetXSModel",
120
+ "UNetMotionModel",
121
+ "UNetSpatioTemporalConditionModel",
122
+ "UVit2DModel",
123
+ "VQModel",
124
+ ]
125
+ )
126
+
127
+ _import_structure["optimization"] = [
128
+ "get_constant_schedule",
129
+ "get_constant_schedule_with_warmup",
130
+ "get_cosine_schedule_with_warmup",
131
+ "get_cosine_with_hard_restarts_schedule_with_warmup",
132
+ "get_linear_schedule_with_warmup",
133
+ "get_polynomial_decay_schedule_with_warmup",
134
+ "get_scheduler",
135
+ ]
136
+ _import_structure["pipelines"].extend(
137
+ [
138
+ "AudioPipelineOutput",
139
+ "AutoPipelineForImage2Image",
140
+ "AutoPipelineForInpainting",
141
+ "AutoPipelineForText2Image",
142
+ "ConsistencyModelPipeline",
143
+ "DanceDiffusionPipeline",
144
+ "DDIMPipeline",
145
+ "DDPMPipeline",
146
+ "DiffusionPipeline",
147
+ "DiTPipeline",
148
+ "ImagePipelineOutput",
149
+ "KarrasVePipeline",
150
+ "LDMPipeline",
151
+ "LDMSuperResolutionPipeline",
152
+ "PNDMPipeline",
153
+ "RePaintPipeline",
154
+ "ScoreSdeVePipeline",
155
+ "StableDiffusionMixin",
156
+ ]
157
+ )
158
+ _import_structure["schedulers"].extend(
159
+ [
160
+ "AmusedScheduler",
161
+ "CMStochasticIterativeScheduler",
162
+ "CogVideoXDDIMScheduler",
163
+ "CogVideoXDPMScheduler",
164
+ "DDIMInverseScheduler",
165
+ "DDIMParallelScheduler",
166
+ "DDIMScheduler",
167
+ "DDPMParallelScheduler",
168
+ "DDPMScheduler",
169
+ "DDPMWuerstchenScheduler",
170
+ "DEISMultistepScheduler",
171
+ "DPMSolverMultistepInverseScheduler",
172
+ "DPMSolverMultistepScheduler",
173
+ "DPMSolverSinglestepScheduler",
174
+ "EDMDPMSolverMultistepScheduler",
175
+ "EDMEulerScheduler",
176
+ "EulerAncestralDiscreteScheduler",
177
+ "EulerDiscreteScheduler",
178
+ "FlowMatchEulerDiscreteScheduler",
179
+ "FlowMatchHeunDiscreteScheduler",
180
+ "HeunDiscreteScheduler",
181
+ "IPNDMScheduler",
182
+ "KarrasVeScheduler",
183
+ "KDPM2AncestralDiscreteScheduler",
184
+ "KDPM2DiscreteScheduler",
185
+ "LCMScheduler",
186
+ "PNDMScheduler",
187
+ "RePaintScheduler",
188
+ "SASolverScheduler",
189
+ "SchedulerMixin",
190
+ "ScoreSdeVeScheduler",
191
+ "TCDScheduler",
192
+ "UnCLIPScheduler",
193
+ "UniPCMultistepScheduler",
194
+ "VQDiffusionScheduler",
195
+ ]
196
+ )
197
+ _import_structure["training_utils"] = ["EMAModel"]
198
+
199
+ try:
200
+ if not (is_torch_available() and is_scipy_available()):
201
+ raise OptionalDependencyNotAvailable()
202
+ except OptionalDependencyNotAvailable:
203
+ from .utils import dummy_torch_and_scipy_objects # noqa F403
204
+
205
+ _import_structure["utils.dummy_torch_and_scipy_objects"] = [
206
+ name for name in dir(dummy_torch_and_scipy_objects) if not name.startswith("_")
207
+ ]
208
+
209
+ else:
210
+ _import_structure["schedulers"].extend(["LMSDiscreteScheduler"])
211
+
212
+ try:
213
+ if not (is_torch_available() and is_torchsde_available()):
214
+ raise OptionalDependencyNotAvailable()
215
+ except OptionalDependencyNotAvailable:
216
+ from .utils import dummy_torch_and_torchsde_objects # noqa F403
217
+
218
+ _import_structure["utils.dummy_torch_and_torchsde_objects"] = [
219
+ name for name in dir(dummy_torch_and_torchsde_objects) if not name.startswith("_")
220
+ ]
221
+
222
+ else:
223
+ _import_structure["schedulers"].extend(["CosineDPMSolverMultistepScheduler", "DPMSolverSDEScheduler"])
224
+
225
+ try:
226
+ if not (is_torch_available() and is_transformers_available()):
227
+ raise OptionalDependencyNotAvailable()
228
+ except OptionalDependencyNotAvailable:
229
+ from .utils import dummy_torch_and_transformers_objects # noqa F403
230
+
231
+ _import_structure["utils.dummy_torch_and_transformers_objects"] = [
232
+ name for name in dir(dummy_torch_and_transformers_objects) if not name.startswith("_")
233
+ ]
234
+
235
+ else:
236
+ _import_structure["pipelines"].extend(
237
+ [
238
+ "AltDiffusionImg2ImgPipeline",
239
+ "AltDiffusionPipeline",
240
+ "AmusedImg2ImgPipeline",
241
+ "AmusedInpaintPipeline",
242
+ "AmusedPipeline",
243
+ "AnimateDiffControlNetPipeline",
244
+ "AnimateDiffPAGPipeline",
245
+ "AnimateDiffPipeline",
246
+ "AnimateDiffSDXLPipeline",
247
+ "AnimateDiffSparseControlNetPipeline",
248
+ "AnimateDiffVideoToVideoControlNetPipeline",
249
+ "AnimateDiffVideoToVideoPipeline",
250
+ "AudioLDM2Pipeline",
251
+ "AudioLDM2ProjectionModel",
252
+ "AudioLDM2UNet2DConditionModel",
253
+ "AudioLDMPipeline",
254
+ "AuraFlowPipeline",
255
+ "BlipDiffusionControlNetPipeline",
256
+ "BlipDiffusionPipeline",
257
+ "CLIPImageProjection",
258
+ "CogVideoXImageToVideoPipeline",
259
+ "CogVideoXPipeline",
260
+ "CogVideoXVideoToVideoPipeline",
261
+ "CycleDiffusionPipeline",
262
+ "FluxControlNetImg2ImgPipeline",
263
+ "FluxControlNetInpaintPipeline",
264
+ "FluxControlNetPipeline",
265
+ "FluxImg2ImgPipeline",
266
+ "FluxInpaintPipeline",
267
+ "FluxPipeline",
268
+ "HunyuanDiTControlNetPipeline",
269
+ "HunyuanDiTPAGPipeline",
270
+ "HunyuanDiTPipeline",
271
+ "I2VGenXLPipeline",
272
+ "IFImg2ImgPipeline",
273
+ "IFImg2ImgSuperResolutionPipeline",
274
+ "IFInpaintingPipeline",
275
+ "IFInpaintingSuperResolutionPipeline",
276
+ "IFPipeline",
277
+ "IFSuperResolutionPipeline",
278
+ "ImageTextPipelineOutput",
279
+ "Kandinsky3Img2ImgPipeline",
280
+ "Kandinsky3Pipeline",
281
+ "KandinskyCombinedPipeline",
282
+ "KandinskyImg2ImgCombinedPipeline",
283
+ "KandinskyImg2ImgPipeline",
284
+ "KandinskyInpaintCombinedPipeline",
285
+ "KandinskyInpaintPipeline",
286
+ "KandinskyPipeline",
287
+ "KandinskyPriorPipeline",
288
+ "KandinskyV22CombinedPipeline",
289
+ "KandinskyV22ControlnetImg2ImgPipeline",
290
+ "KandinskyV22ControlnetPipeline",
291
+ "KandinskyV22Img2ImgCombinedPipeline",
292
+ "KandinskyV22Img2ImgPipeline",
293
+ "KandinskyV22InpaintCombinedPipeline",
294
+ "KandinskyV22InpaintPipeline",
295
+ "KandinskyV22Pipeline",
296
+ "KandinskyV22PriorEmb2EmbPipeline",
297
+ "KandinskyV22PriorPipeline",
298
+ "LatentConsistencyModelImg2ImgPipeline",
299
+ "LatentConsistencyModelPipeline",
300
+ "LattePipeline",
301
+ "LDMTextToImagePipeline",
302
+ "LEditsPPPipelineStableDiffusion",
303
+ "LEditsPPPipelineStableDiffusionXL",
304
+ "LuminaText2ImgPipeline",
305
+ "MarigoldDepthPipeline",
306
+ "MarigoldNormalsPipeline",
307
+ "MusicLDMPipeline",
308
+ "PaintByExamplePipeline",
309
+ "PIAPipeline",
310
+ "PixArtAlphaPipeline",
311
+ "PixArtSigmaPAGPipeline",
312
+ "PixArtSigmaPipeline",
313
+ "SemanticStableDiffusionPipeline",
314
+ "ShapEImg2ImgPipeline",
315
+ "ShapEPipeline",
316
+ "StableAudioPipeline",
317
+ "StableAudioProjectionModel",
318
+ "StableCascadeCombinedPipeline",
319
+ "StableCascadeDecoderPipeline",
320
+ "StableCascadePriorPipeline",
321
+ "StableDiffusion3ControlNetInpaintingPipeline",
322
+ "StableDiffusion3ControlNetPipeline",
323
+ "StableDiffusion3Img2ImgPipeline",
324
+ "StableDiffusion3InpaintPipeline",
325
+ "StableDiffusion3PAGPipeline",
326
+ "StableDiffusion3Pipeline",
327
+ "StableDiffusionAdapterPipeline",
328
+ "StableDiffusionAttendAndExcitePipeline",
329
+ "StableDiffusionControlNetImg2ImgPipeline",
330
+ "StableDiffusionControlNetInpaintPipeline",
331
+ "StableDiffusionControlNetPAGInpaintPipeline",
332
+ "StableDiffusionControlNetPAGPipeline",
333
+ "StableDiffusionControlNetPipeline",
334
+ "StableDiffusionControlNetXSPipeline",
335
+ "StableDiffusionDepth2ImgPipeline",
336
+ "StableDiffusionDiffEditPipeline",
337
+ "StableDiffusionGLIGENPipeline",
338
+ "StableDiffusionGLIGENTextImagePipeline",
339
+ "StableDiffusionImageVariationPipeline",
340
+ "StableDiffusionImg2ImgPipeline",
341
+ "StableDiffusionInpaintPipeline",
342
+ "StableDiffusionInpaintPipelineLegacy",
343
+ "StableDiffusionInstructPix2PixPipeline",
344
+ "StableDiffusionLatentUpscalePipeline",
345
+ "StableDiffusionLDM3DPipeline",
346
+ "StableDiffusionModelEditingPipeline",
347
+ "StableDiffusionPAGImg2ImgPipeline",
348
+ "StableDiffusionPAGPipeline",
349
+ "StableDiffusionPanoramaPipeline",
350
+ "StableDiffusionParadigmsPipeline",
351
+ "StableDiffusionPipeline",
352
+ "StableDiffusionPipelineSafe",
353
+ "StableDiffusionPix2PixZeroPipeline",
354
+ "StableDiffusionSAGPipeline",
355
+ "StableDiffusionUpscalePipeline",
356
+ "StableDiffusionXLAdapterPipeline",
357
+ "StableDiffusionXLControlNetImg2ImgPipeline",
358
+ "StableDiffusionXLControlNetInpaintPipeline",
359
+ "StableDiffusionXLControlNetPAGImg2ImgPipeline",
360
+ "StableDiffusionXLControlNetPAGPipeline",
361
+ "StableDiffusionXLControlNetPipeline",
362
+ "StableDiffusionXLControlNetXSPipeline",
363
+ "StableDiffusionXLImg2ImgPipeline",
364
+ "StableDiffusionXLInpaintPipeline",
365
+ "StableDiffusionXLInstructPix2PixPipeline",
366
+ "StableDiffusionXLPAGImg2ImgPipeline",
367
+ "StableDiffusionXLPAGInpaintPipeline",
368
+ "StableDiffusionXLPAGPipeline",
369
+ "StableDiffusionXLPipeline",
370
+ "StableUnCLIPImg2ImgPipeline",
371
+ "StableUnCLIPPipeline",
372
+ "StableVideoDiffusionPipeline",
373
+ "TextToVideoSDPipeline",
374
+ "TextToVideoZeroPipeline",
375
+ "TextToVideoZeroSDXLPipeline",
376
+ "UnCLIPImageVariationPipeline",
377
+ "UnCLIPPipeline",
378
+ "UniDiffuserModel",
379
+ "UniDiffuserPipeline",
380
+ "UniDiffuserTextDecoder",
381
+ "VersatileDiffusionDualGuidedPipeline",
382
+ "VersatileDiffusionImageVariationPipeline",
383
+ "VersatileDiffusionPipeline",
384
+ "VersatileDiffusionTextToImagePipeline",
385
+ "VideoToVideoSDPipeline",
386
+ "VQDiffusionPipeline",
387
+ "WuerstchenCombinedPipeline",
388
+ "WuerstchenDecoderPipeline",
389
+ "WuerstchenPriorPipeline",
390
+ ]
391
+ )
392
+
393
+ try:
394
+ if not (is_torch_available() and is_transformers_available() and is_k_diffusion_available()):
395
+ raise OptionalDependencyNotAvailable()
396
+ except OptionalDependencyNotAvailable:
397
+ from .utils import dummy_torch_and_transformers_and_k_diffusion_objects # noqa F403
398
+
399
+ _import_structure["utils.dummy_torch_and_transformers_and_k_diffusion_objects"] = [
400
+ name for name in dir(dummy_torch_and_transformers_and_k_diffusion_objects) if not name.startswith("_")
401
+ ]
402
+
403
+ else:
404
+ _import_structure["pipelines"].extend(["StableDiffusionKDiffusionPipeline", "StableDiffusionXLKDiffusionPipeline"])
405
+
406
+ try:
407
+ if not (is_torch_available() and is_transformers_available() and is_sentencepiece_available()):
408
+ raise OptionalDependencyNotAvailable()
409
+ except OptionalDependencyNotAvailable:
410
+ from .utils import dummy_torch_and_transformers_and_sentencepiece_objects # noqa F403
411
+
412
+ _import_structure["utils.dummy_torch_and_transformers_and_sentencepiece_objects"] = [
413
+ name for name in dir(dummy_torch_and_transformers_and_sentencepiece_objects) if not name.startswith("_")
414
+ ]
415
+
416
+ else:
417
+ _import_structure["pipelines"].extend(["KolorsImg2ImgPipeline", "KolorsPAGPipeline", "KolorsPipeline"])
418
+
419
+ try:
420
+ if not (is_torch_available() and is_transformers_available() and is_onnx_available()):
421
+ raise OptionalDependencyNotAvailable()
422
+ except OptionalDependencyNotAvailable:
423
+ from .utils import dummy_torch_and_transformers_and_onnx_objects # noqa F403
424
+
425
+ _import_structure["utils.dummy_torch_and_transformers_and_onnx_objects"] = [
426
+ name for name in dir(dummy_torch_and_transformers_and_onnx_objects) if not name.startswith("_")
427
+ ]
428
+
429
+ else:
430
+ _import_structure["pipelines"].extend(
431
+ [
432
+ "OnnxStableDiffusionImg2ImgPipeline",
433
+ "OnnxStableDiffusionInpaintPipeline",
434
+ "OnnxStableDiffusionInpaintPipelineLegacy",
435
+ "OnnxStableDiffusionPipeline",
436
+ "OnnxStableDiffusionUpscalePipeline",
437
+ "StableDiffusionOnnxPipeline",
438
+ ]
439
+ )
440
+
441
+ try:
442
+ if not (is_torch_available() and is_librosa_available()):
443
+ raise OptionalDependencyNotAvailable()
444
+ except OptionalDependencyNotAvailable:
445
+ from .utils import dummy_torch_and_librosa_objects # noqa F403
446
+
447
+ _import_structure["utils.dummy_torch_and_librosa_objects"] = [
448
+ name for name in dir(dummy_torch_and_librosa_objects) if not name.startswith("_")
449
+ ]
450
+
451
+ else:
452
+ _import_structure["pipelines"].extend(["AudioDiffusionPipeline", "Mel"])
453
+
454
+ try:
455
+ if not (is_transformers_available() and is_torch_available() and is_note_seq_available()):
456
+ raise OptionalDependencyNotAvailable()
457
+ except OptionalDependencyNotAvailable:
458
+ from .utils import dummy_transformers_and_torch_and_note_seq_objects # noqa F403
459
+
460
+ _import_structure["utils.dummy_transformers_and_torch_and_note_seq_objects"] = [
461
+ name for name in dir(dummy_transformers_and_torch_and_note_seq_objects) if not name.startswith("_")
462
+ ]
463
+
464
+
465
+ else:
466
+ _import_structure["pipelines"].extend(["SpectrogramDiffusionPipeline"])
467
+
468
+ try:
469
+ if not is_flax_available():
470
+ raise OptionalDependencyNotAvailable()
471
+ except OptionalDependencyNotAvailable:
472
+ from .utils import dummy_flax_objects # noqa F403
473
+
474
+ _import_structure["utils.dummy_flax_objects"] = [
475
+ name for name in dir(dummy_flax_objects) if not name.startswith("_")
476
+ ]
477
+
478
+
479
+ else:
480
+ _import_structure["models.controlnet_flax"] = ["FlaxControlNetModel"]
481
+ _import_structure["models.modeling_flax_utils"] = ["FlaxModelMixin"]
482
+ _import_structure["models.unets.unet_2d_condition_flax"] = ["FlaxUNet2DConditionModel"]
483
+ _import_structure["models.vae_flax"] = ["FlaxAutoencoderKL"]
484
+ _import_structure["pipelines"].extend(["FlaxDiffusionPipeline"])
485
+ _import_structure["schedulers"].extend(
486
+ [
487
+ "FlaxDDIMScheduler",
488
+ "FlaxDDPMScheduler",
489
+ "FlaxDPMSolverMultistepScheduler",
490
+ "FlaxEulerDiscreteScheduler",
491
+ "FlaxKarrasVeScheduler",
492
+ "FlaxLMSDiscreteScheduler",
493
+ "FlaxPNDMScheduler",
494
+ "FlaxSchedulerMixin",
495
+ "FlaxScoreSdeVeScheduler",
496
+ ]
497
+ )
498
+
499
+
500
+ try:
501
+ if not (is_flax_available() and is_transformers_available()):
502
+ raise OptionalDependencyNotAvailable()
503
+ except OptionalDependencyNotAvailable:
504
+ from .utils import dummy_flax_and_transformers_objects # noqa F403
505
+
506
+ _import_structure["utils.dummy_flax_and_transformers_objects"] = [
507
+ name for name in dir(dummy_flax_and_transformers_objects) if not name.startswith("_")
508
+ ]
509
+
510
+
511
+ else:
512
+ _import_structure["pipelines"].extend(
513
+ [
514
+ "FlaxStableDiffusionControlNetPipeline",
515
+ "FlaxStableDiffusionImg2ImgPipeline",
516
+ "FlaxStableDiffusionInpaintPipeline",
517
+ "FlaxStableDiffusionPipeline",
518
+ "FlaxStableDiffusionXLPipeline",
519
+ ]
520
+ )
521
+
522
+ try:
523
+ if not (is_note_seq_available()):
524
+ raise OptionalDependencyNotAvailable()
525
+ except OptionalDependencyNotAvailable:
526
+ from .utils import dummy_note_seq_objects # noqa F403
527
+
528
+ _import_structure["utils.dummy_note_seq_objects"] = [
529
+ name for name in dir(dummy_note_seq_objects) if not name.startswith("_")
530
+ ]
531
+
532
+
533
+ else:
534
+ _import_structure["pipelines"].extend(["MidiProcessor"])
535
+
536
+ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
537
+ from .configuration_utils import ConfigMixin
538
+
539
+ try:
540
+ if not is_onnx_available():
541
+ raise OptionalDependencyNotAvailable()
542
+ except OptionalDependencyNotAvailable:
543
+ from .utils.dummy_onnx_objects import * # noqa F403
544
+ else:
545
+ from .pipelines import OnnxRuntimeModel
546
+
547
+ try:
548
+ if not is_torch_available():
549
+ raise OptionalDependencyNotAvailable()
550
+ except OptionalDependencyNotAvailable:
551
+ from .utils.dummy_pt_objects import * # noqa F403
552
+ else:
553
+ from .models import (
554
+ AsymmetricAutoencoderKL,
555
+ AuraFlowTransformer2DModel,
556
+ AutoencoderKL,
557
+ AutoencoderKLCogVideoX,
558
+ AutoencoderKLTemporalDecoder,
559
+ AutoencoderOobleck,
560
+ AutoencoderTiny,
561
+ CogVideoXTransformer3DModel,
562
+ ConsistencyDecoderVAE,
563
+ ControlNetModel,
564
+ ControlNetXSAdapter,
565
+ DiTTransformer2DModel,
566
+ FluxControlNetModel,
567
+ FluxMultiControlNetModel,
568
+ FluxTransformer2DModel,
569
+ HunyuanDiT2DControlNetModel,
570
+ HunyuanDiT2DModel,
571
+ HunyuanDiT2DMultiControlNetModel,
572
+ I2VGenXLUNet,
573
+ Kandinsky3UNet,
574
+ LatteTransformer3DModel,
575
+ LuminaNextDiT2DModel,
576
+ ModelMixin,
577
+ MotionAdapter,
578
+ MultiAdapter,
579
+ PixArtTransformer2DModel,
580
+ PriorTransformer,
581
+ SD3ControlNetModel,
582
+ SD3MultiControlNetModel,
583
+ SD3Transformer2DModel,
584
+ SparseControlNetModel,
585
+ StableAudioDiTModel,
586
+ T2IAdapter,
587
+ T5FilmDecoder,
588
+ Transformer2DModel,
589
+ UNet1DModel,
590
+ UNet2DConditionModel,
591
+ UNet2DModel,
592
+ UNet3DConditionModel,
593
+ UNetControlNetXSModel,
594
+ UNetMotionModel,
595
+ UNetSpatioTemporalConditionModel,
596
+ UVit2DModel,
597
+ VQModel,
598
+ )
599
+ from .optimization import (
600
+ get_constant_schedule,
601
+ get_constant_schedule_with_warmup,
602
+ get_cosine_schedule_with_warmup,
603
+ get_cosine_with_hard_restarts_schedule_with_warmup,
604
+ get_linear_schedule_with_warmup,
605
+ get_polynomial_decay_schedule_with_warmup,
606
+ get_scheduler,
607
+ )
608
+ from .pipelines import (
609
+ AudioPipelineOutput,
610
+ AutoPipelineForImage2Image,
611
+ AutoPipelineForInpainting,
612
+ AutoPipelineForText2Image,
613
+ BlipDiffusionControlNetPipeline,
614
+ BlipDiffusionPipeline,
615
+ CLIPImageProjection,
616
+ ConsistencyModelPipeline,
617
+ DanceDiffusionPipeline,
618
+ DDIMPipeline,
619
+ DDPMPipeline,
620
+ DiffusionPipeline,
621
+ DiTPipeline,
622
+ ImagePipelineOutput,
623
+ KarrasVePipeline,
624
+ LDMPipeline,
625
+ LDMSuperResolutionPipeline,
626
+ PNDMPipeline,
627
+ RePaintPipeline,
628
+ ScoreSdeVePipeline,
629
+ StableDiffusionMixin,
630
+ )
631
+ from .schedulers import (
632
+ AmusedScheduler,
633
+ CMStochasticIterativeScheduler,
634
+ CogVideoXDDIMScheduler,
635
+ CogVideoXDPMScheduler,
636
+ DDIMInverseScheduler,
637
+ DDIMParallelScheduler,
638
+ DDIMScheduler,
639
+ DDPMParallelScheduler,
640
+ DDPMScheduler,
641
+ DDPMWuerstchenScheduler,
642
+ DEISMultistepScheduler,
643
+ DPMSolverMultistepInverseScheduler,
644
+ DPMSolverMultistepScheduler,
645
+ DPMSolverSinglestepScheduler,
646
+ EDMDPMSolverMultistepScheduler,
647
+ EDMEulerScheduler,
648
+ EulerAncestralDiscreteScheduler,
649
+ EulerDiscreteScheduler,
650
+ FlowMatchEulerDiscreteScheduler,
651
+ FlowMatchHeunDiscreteScheduler,
652
+ HeunDiscreteScheduler,
653
+ IPNDMScheduler,
654
+ KarrasVeScheduler,
655
+ KDPM2AncestralDiscreteScheduler,
656
+ KDPM2DiscreteScheduler,
657
+ LCMScheduler,
658
+ PNDMScheduler,
659
+ RePaintScheduler,
660
+ SASolverScheduler,
661
+ SchedulerMixin,
662
+ ScoreSdeVeScheduler,
663
+ TCDScheduler,
664
+ UnCLIPScheduler,
665
+ UniPCMultistepScheduler,
666
+ VQDiffusionScheduler,
667
+ )
668
+ from .training_utils import EMAModel
669
+
670
+ try:
671
+ if not (is_torch_available() and is_scipy_available()):
672
+ raise OptionalDependencyNotAvailable()
673
+ except OptionalDependencyNotAvailable:
674
+ from .utils.dummy_torch_and_scipy_objects import * # noqa F403
675
+ else:
676
+ from .schedulers import LMSDiscreteScheduler
677
+
678
+ try:
679
+ if not (is_torch_available() and is_torchsde_available()):
680
+ raise OptionalDependencyNotAvailable()
681
+ except OptionalDependencyNotAvailable:
682
+ from .utils.dummy_torch_and_torchsde_objects import * # noqa F403
683
+ else:
684
+ from .schedulers import CosineDPMSolverMultistepScheduler, DPMSolverSDEScheduler
685
+
686
+ try:
687
+ if not (is_torch_available() and is_transformers_available()):
688
+ raise OptionalDependencyNotAvailable()
689
+ except OptionalDependencyNotAvailable:
690
+ from .utils.dummy_torch_and_transformers_objects import * # noqa F403
691
+ else:
692
+ from .pipelines import (
693
+ AltDiffusionImg2ImgPipeline,
694
+ AltDiffusionPipeline,
695
+ AmusedImg2ImgPipeline,
696
+ AmusedInpaintPipeline,
697
+ AmusedPipeline,
698
+ AnimateDiffControlNetPipeline,
699
+ AnimateDiffPAGPipeline,
700
+ AnimateDiffPipeline,
701
+ AnimateDiffSDXLPipeline,
702
+ AnimateDiffSparseControlNetPipeline,
703
+ AnimateDiffVideoToVideoControlNetPipeline,
704
+ AnimateDiffVideoToVideoPipeline,
705
+ AudioLDM2Pipeline,
706
+ AudioLDM2ProjectionModel,
707
+ AudioLDM2UNet2DConditionModel,
708
+ AudioLDMPipeline,
709
+ AuraFlowPipeline,
710
+ CLIPImageProjection,
711
+ CogVideoXImageToVideoPipeline,
712
+ CogVideoXPipeline,
713
+ CogVideoXVideoToVideoPipeline,
714
+ CycleDiffusionPipeline,
715
+ FluxControlNetImg2ImgPipeline,
716
+ FluxControlNetInpaintPipeline,
717
+ FluxControlNetPipeline,
718
+ FluxImg2ImgPipeline,
719
+ FluxInpaintPipeline,
720
+ FluxPipeline,
721
+ HunyuanDiTControlNetPipeline,
722
+ HunyuanDiTPAGPipeline,
723
+ HunyuanDiTPipeline,
724
+ I2VGenXLPipeline,
725
+ IFImg2ImgPipeline,
726
+ IFImg2ImgSuperResolutionPipeline,
727
+ IFInpaintingPipeline,
728
+ IFInpaintingSuperResolutionPipeline,
729
+ IFPipeline,
730
+ IFSuperResolutionPipeline,
731
+ ImageTextPipelineOutput,
732
+ Kandinsky3Img2ImgPipeline,
733
+ Kandinsky3Pipeline,
734
+ KandinskyCombinedPipeline,
735
+ KandinskyImg2ImgCombinedPipeline,
736
+ KandinskyImg2ImgPipeline,
737
+ KandinskyInpaintCombinedPipeline,
738
+ KandinskyInpaintPipeline,
739
+ KandinskyPipeline,
740
+ KandinskyPriorPipeline,
741
+ KandinskyV22CombinedPipeline,
742
+ KandinskyV22ControlnetImg2ImgPipeline,
743
+ KandinskyV22ControlnetPipeline,
744
+ KandinskyV22Img2ImgCombinedPipeline,
745
+ KandinskyV22Img2ImgPipeline,
746
+ KandinskyV22InpaintCombinedPipeline,
747
+ KandinskyV22InpaintPipeline,
748
+ KandinskyV22Pipeline,
749
+ KandinskyV22PriorEmb2EmbPipeline,
750
+ KandinskyV22PriorPipeline,
751
+ LatentConsistencyModelImg2ImgPipeline,
752
+ LatentConsistencyModelPipeline,
753
+ LattePipeline,
754
+ LDMTextToImagePipeline,
755
+ LEditsPPPipelineStableDiffusion,
756
+ LEditsPPPipelineStableDiffusionXL,
757
+ LuminaText2ImgPipeline,
758
+ MarigoldDepthPipeline,
759
+ MarigoldNormalsPipeline,
760
+ MusicLDMPipeline,
761
+ PaintByExamplePipeline,
762
+ PIAPipeline,
763
+ PixArtAlphaPipeline,
764
+ PixArtSigmaPAGPipeline,
765
+ PixArtSigmaPipeline,
766
+ SemanticStableDiffusionPipeline,
767
+ ShapEImg2ImgPipeline,
768
+ ShapEPipeline,
769
+ StableAudioPipeline,
770
+ StableAudioProjectionModel,
771
+ StableCascadeCombinedPipeline,
772
+ StableCascadeDecoderPipeline,
773
+ StableCascadePriorPipeline,
774
+ StableDiffusion3ControlNetPipeline,
775
+ StableDiffusion3Img2ImgPipeline,
776
+ StableDiffusion3InpaintPipeline,
777
+ StableDiffusion3PAGPipeline,
778
+ StableDiffusion3Pipeline,
779
+ StableDiffusionAdapterPipeline,
780
+ StableDiffusionAttendAndExcitePipeline,
781
+ StableDiffusionControlNetImg2ImgPipeline,
782
+ StableDiffusionControlNetInpaintPipeline,
783
+ StableDiffusionControlNetPAGInpaintPipeline,
784
+ StableDiffusionControlNetPAGPipeline,
785
+ StableDiffusionControlNetPipeline,
786
+ StableDiffusionControlNetXSPipeline,
787
+ StableDiffusionDepth2ImgPipeline,
788
+ StableDiffusionDiffEditPipeline,
789
+ StableDiffusionGLIGENPipeline,
790
+ StableDiffusionGLIGENTextImagePipeline,
791
+ StableDiffusionImageVariationPipeline,
792
+ StableDiffusionImg2ImgPipeline,
793
+ StableDiffusionInpaintPipeline,
794
+ StableDiffusionInpaintPipelineLegacy,
795
+ StableDiffusionInstructPix2PixPipeline,
796
+ StableDiffusionLatentUpscalePipeline,
797
+ StableDiffusionLDM3DPipeline,
798
+ StableDiffusionModelEditingPipeline,
799
+ StableDiffusionPAGImg2ImgPipeline,
800
+ StableDiffusionPAGPipeline,
801
+ StableDiffusionPanoramaPipeline,
802
+ StableDiffusionParadigmsPipeline,
803
+ StableDiffusionPipeline,
804
+ StableDiffusionPipelineSafe,
805
+ StableDiffusionPix2PixZeroPipeline,
806
+ StableDiffusionSAGPipeline,
807
+ StableDiffusionUpscalePipeline,
808
+ StableDiffusionXLAdapterPipeline,
809
+ StableDiffusionXLControlNetImg2ImgPipeline,
810
+ StableDiffusionXLControlNetInpaintPipeline,
811
+ StableDiffusionXLControlNetPAGImg2ImgPipeline,
812
+ StableDiffusionXLControlNetPAGPipeline,
813
+ StableDiffusionXLControlNetPipeline,
814
+ StableDiffusionXLControlNetXSPipeline,
815
+ StableDiffusionXLImg2ImgPipeline,
816
+ StableDiffusionXLInpaintPipeline,
817
+ StableDiffusionXLInstructPix2PixPipeline,
818
+ StableDiffusionXLPAGImg2ImgPipeline,
819
+ StableDiffusionXLPAGInpaintPipeline,
820
+ StableDiffusionXLPAGPipeline,
821
+ StableDiffusionXLPipeline,
822
+ StableUnCLIPImg2ImgPipeline,
823
+ StableUnCLIPPipeline,
824
+ StableVideoDiffusionPipeline,
825
+ TextToVideoSDPipeline,
826
+ TextToVideoZeroPipeline,
827
+ TextToVideoZeroSDXLPipeline,
828
+ UnCLIPImageVariationPipeline,
829
+ UnCLIPPipeline,
830
+ UniDiffuserModel,
831
+ UniDiffuserPipeline,
832
+ UniDiffuserTextDecoder,
833
+ VersatileDiffusionDualGuidedPipeline,
834
+ VersatileDiffusionImageVariationPipeline,
835
+ VersatileDiffusionPipeline,
836
+ VersatileDiffusionTextToImagePipeline,
837
+ VideoToVideoSDPipeline,
838
+ VQDiffusionPipeline,
839
+ WuerstchenCombinedPipeline,
840
+ WuerstchenDecoderPipeline,
841
+ WuerstchenPriorPipeline,
842
+ )
843
+
844
+ try:
845
+ if not (is_torch_available() and is_transformers_available() and is_k_diffusion_available()):
846
+ raise OptionalDependencyNotAvailable()
847
+ except OptionalDependencyNotAvailable:
848
+ from .utils.dummy_torch_and_transformers_and_k_diffusion_objects import * # noqa F403
849
+ else:
850
+ from .pipelines import StableDiffusionKDiffusionPipeline, StableDiffusionXLKDiffusionPipeline
851
+
852
+ try:
853
+ if not (is_torch_available() and is_transformers_available() and is_sentencepiece_available()):
854
+ raise OptionalDependencyNotAvailable()
855
+ except OptionalDependencyNotAvailable:
856
+ from .utils.dummy_torch_and_transformers_and_sentencepiece_objects import * # noqa F403
857
+ else:
858
+ from .pipelines import KolorsImg2ImgPipeline, KolorsPAGPipeline, KolorsPipeline
859
+ try:
860
+ if not (is_torch_available() and is_transformers_available() and is_onnx_available()):
861
+ raise OptionalDependencyNotAvailable()
862
+ except OptionalDependencyNotAvailable:
863
+ from .utils.dummy_torch_and_transformers_and_onnx_objects import * # noqa F403
864
+ else:
865
+ from .pipelines import (
866
+ OnnxStableDiffusionImg2ImgPipeline,
867
+ OnnxStableDiffusionInpaintPipeline,
868
+ OnnxStableDiffusionInpaintPipelineLegacy,
869
+ OnnxStableDiffusionPipeline,
870
+ OnnxStableDiffusionUpscalePipeline,
871
+ StableDiffusionOnnxPipeline,
872
+ )
873
+
874
+ try:
875
+ if not (is_torch_available() and is_librosa_available()):
876
+ raise OptionalDependencyNotAvailable()
877
+ except OptionalDependencyNotAvailable:
878
+ from .utils.dummy_torch_and_librosa_objects import * # noqa F403
879
+ else:
880
+ from .pipelines import AudioDiffusionPipeline, Mel
881
+
882
+ try:
883
+ if not (is_transformers_available() and is_torch_available() and is_note_seq_available()):
884
+ raise OptionalDependencyNotAvailable()
885
+ except OptionalDependencyNotAvailable:
886
+ from .utils.dummy_transformers_and_torch_and_note_seq_objects import * # noqa F403
887
+ else:
888
+ from .pipelines import SpectrogramDiffusionPipeline
889
+
890
+ try:
891
+ if not is_flax_available():
892
+ raise OptionalDependencyNotAvailable()
893
+ except OptionalDependencyNotAvailable:
894
+ from .utils.dummy_flax_objects import * # noqa F403
895
+ else:
896
+ from .models.controlnet_flax import FlaxControlNetModel
897
+ from .models.modeling_flax_utils import FlaxModelMixin
898
+ from .models.unets.unet_2d_condition_flax import FlaxUNet2DConditionModel
899
+ from .models.vae_flax import FlaxAutoencoderKL
900
+ from .pipelines import FlaxDiffusionPipeline
901
+ from .schedulers import (
902
+ FlaxDDIMScheduler,
903
+ FlaxDDPMScheduler,
904
+ FlaxDPMSolverMultistepScheduler,
905
+ FlaxEulerDiscreteScheduler,
906
+ FlaxKarrasVeScheduler,
907
+ FlaxLMSDiscreteScheduler,
908
+ FlaxPNDMScheduler,
909
+ FlaxSchedulerMixin,
910
+ FlaxScoreSdeVeScheduler,
911
+ )
912
+
913
+ try:
914
+ if not (is_flax_available() and is_transformers_available()):
915
+ raise OptionalDependencyNotAvailable()
916
+ except OptionalDependencyNotAvailable:
917
+ from .utils.dummy_flax_and_transformers_objects import * # noqa F403
918
+ else:
919
+ from .pipelines import (
920
+ FlaxStableDiffusionControlNetPipeline,
921
+ FlaxStableDiffusionImg2ImgPipeline,
922
+ FlaxStableDiffusionInpaintPipeline,
923
+ FlaxStableDiffusionPipeline,
924
+ FlaxStableDiffusionXLPipeline,
925
+ )
926
+
927
+ try:
928
+ if not (is_note_seq_available()):
929
+ raise OptionalDependencyNotAvailable()
930
+ except OptionalDependencyNotAvailable:
931
+ from .utils.dummy_note_seq_objects import * # noqa F403
932
+ else:
933
+ from .pipelines import MidiProcessor
934
+
935
+ else:
936
+ import sys
937
+
938
+ sys.modules[__name__] = _LazyModule(
939
+ __name__,
940
+ globals()["__file__"],
941
+ _import_structure,
942
+ module_spec=__spec__,
943
+ extra_objects={"__version__": __version__},
944
+ )
diffusers/callbacks.py ADDED
@@ -0,0 +1,156 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import Any, Dict, List
2
+
3
+ from .configuration_utils import ConfigMixin, register_to_config
4
+ from .utils import CONFIG_NAME
5
+
6
+
7
class PipelineCallback(ConfigMixin):
    """
    Base class for all the official callbacks used in a pipeline. This class provides a structure for implementing
    custom callbacks and ensures that all callbacks have a consistent interface.

    Please implement the following:
        `tensor_inputs`: This should return a list of tensor inputs specific to your callback. You will only be able to
        include variables listed in the `._callback_tensor_inputs` attribute of your pipeline class.
        `callback_fn`: This method defines the core functionality of your callback.

    Args:
        cutoff_step_ratio (`float`, defaults to `1.0`):
            Ratio of the total number of denoising steps after which the callback triggers. Mutually exclusive with
            `cutoff_step_index`.
        cutoff_step_index (`int`, *optional*):
            Exact step index at which the callback triggers. Mutually exclusive with `cutoff_step_ratio`.
    """

    config_name = CONFIG_NAME

    @register_to_config
    def __init__(self, cutoff_step_ratio=1.0, cutoff_step_index=None):
        super().__init__()

        # Exactly one of the two cutoff specifications must be provided.
        if (cutoff_step_ratio is None and cutoff_step_index is None) or (
            cutoff_step_ratio is not None and cutoff_step_index is not None
        ):
            raise ValueError("Either cutoff_step_ratio or cutoff_step_index should be provided, not both or none.")

        if cutoff_step_ratio is not None and (
            not isinstance(cutoff_step_ratio, float) or not (0.0 <= cutoff_step_ratio <= 1.0)
        ):
            raise ValueError("cutoff_step_ratio must be a float between 0.0 and 1.0.")

    @property
    def tensor_inputs(self) -> List[str]:
        raise NotImplementedError(f"You need to set the attribute `tensor_inputs` for {self.__class__}")

    # Consistency fix: this abstract signature previously named the third parameter
    # `timesteps`, while `__call__` below and every subclass use `timestep`. The
    # parameter is always passed positionally, so the rename is backward-compatible.
    def callback_fn(self, pipeline, step_index, timestep, callback_kwargs) -> Dict[str, Any]:
        raise NotImplementedError(f"You need to implement the method `callback_fn` for {self.__class__}")

    def __call__(self, pipeline, step_index, timestep, callback_kwargs) -> Dict[str, Any]:
        # Dispatch to the subclass implementation; returns the (possibly modified) callback_kwargs.
        return self.callback_fn(pipeline, step_index, timestep, callback_kwargs)
44
+
45
+
46
class MultiPipelineCallbacks:
    """
    This class is designed to handle multiple pipeline callbacks. It accepts a list of PipelineCallback objects and
    provides a unified interface for calling all of them.
    """

    def __init__(self, callbacks: List[PipelineCallback]):
        self.callbacks = callbacks

    @property
    def tensor_inputs(self) -> List[str]:
        # Concatenation (with duplicates preserved) of every wrapped callback's tensor inputs.
        inputs: List[str] = []
        for cb in self.callbacks:
            inputs.extend(cb.tensor_inputs)
        return inputs

    def __call__(self, pipeline, step_index, timestep, callback_kwargs) -> Dict[str, Any]:
        """
        Calls all the callbacks in order with the given arguments and returns the final callback_kwargs.
        """
        result = callback_kwargs
        for cb in self.callbacks:
            # Each callback receives the kwargs produced by the previous one.
            result = cb(pipeline, step_index, timestep, result)
        return result
67
+
68
+
69
class SDCFGCutoffCallback(PipelineCallback):
    """
    Callback function for Stable Diffusion Pipelines. After certain number of steps (set by `cutoff_step_ratio` or
    `cutoff_step_index`), this callback will disable the CFG.

    Note: This callback mutates the pipeline by changing the `_guidance_scale` attribute to 0.0 after the cutoff step.
    """

    tensor_inputs = ["prompt_embeds"]

    def callback_fn(self, pipeline, step_index, timestep, callback_kwargs) -> Dict[str, Any]:
        # An explicit step index takes precedence; otherwise derive the cutoff
        # from the ratio of total denoising steps.
        if self.config.cutoff_step_index is not None:
            cutoff_step = self.config.cutoff_step_index
        else:
            cutoff_step = int(pipeline.num_timesteps * self.config.cutoff_step_ratio)

        if step_index == cutoff_step:
            key = self.tensor_inputs[0]
            # Keep only the last batch entry ("-1" denotes the embeddings for
            # conditional text tokens); the unconditional half is no longer
            # needed once CFG is disabled.
            callback_kwargs[key] = callback_kwargs[key][-1:]
            pipeline._guidance_scale = 0.0

        return callback_kwargs
96
+
97
+
98
class SDXLCFGCutoffCallback(PipelineCallback):
    """
    Callback function for Stable Diffusion XL Pipelines. After certain number of steps (set by `cutoff_step_ratio` or
    `cutoff_step_index`), this callback will disable the CFG.

    Note: This callback mutates the pipeline by changing the `_guidance_scale` attribute to 0.0 after the cutoff step.
    """

    tensor_inputs = ["prompt_embeds", "add_text_embeds", "add_time_ids"]

    def callback_fn(self, pipeline, step_index, timestep, callback_kwargs) -> Dict[str, Any]:
        # An explicit step index takes precedence; otherwise derive the cutoff
        # from the ratio of total denoising steps.
        if self.config.cutoff_step_index is not None:
            cutoff_step = self.config.cutoff_step_index
        else:
            cutoff_step = int(pipeline.num_timesteps * self.config.cutoff_step_ratio)

        if step_index == cutoff_step:
            # For each CFG-batched tensor (prompt embeds, pooled text embeds,
            # added time ids) keep only the last entry — the conditional half —
            # since the unconditional half is unused once CFG is off.
            for key in self.tensor_inputs:
                callback_kwargs[key] = callback_kwargs[key][-1:]
            pipeline._guidance_scale = 0.0

        return callback_kwargs
133
+
134
+
135
class IPAdapterScaleCutoffCallback(PipelineCallback):
    """
    Callback function for any pipeline that inherits `IPAdapterMixin`. After certain number of steps (set by
    `cutoff_step_ratio` or `cutoff_step_index`), this callback will set the IP Adapter scale to `0.0`.

    Note: This callback mutates the IP Adapter attention processors by setting the scale to 0.0 after the cutoff step.
    """

    # This callback does not touch any denoising tensors.
    tensor_inputs = []

    def callback_fn(self, pipeline, step_index, timestep, callback_kwargs) -> Dict[str, Any]:
        config = self.config
        # An explicit step index takes precedence over the ratio-based cutoff.
        cutoff_step = (
            int(pipeline.num_timesteps * config.cutoff_step_ratio)
            if config.cutoff_step_index is None
            else config.cutoff_step_index
        )

        if step_index == cutoff_step:
            pipeline.set_ip_adapter_scale(0.0)
        return callback_kwargs
diffusers/commands/__init__.py ADDED
@@ -0,0 +1,27 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright 2024 The HuggingFace Team. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ from abc import ABC, abstractmethod
16
+ from argparse import ArgumentParser
17
+
18
+
19
class BaseDiffusersCLICommand(ABC):
    """Abstract interface that every `diffusers-cli` sub-command implements."""

    @staticmethod
    @abstractmethod
    def register_subcommand(parser: ArgumentParser):
        """Attach this command's sub-parser and arguments to the given parser."""
        raise NotImplementedError()

    @abstractmethod
    def run(self):
        """Execute the command."""
        raise NotImplementedError()
diffusers/commands/diffusers_cli.py ADDED
@@ -0,0 +1,43 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python
2
+ # Copyright 2024 The HuggingFace Team. All rights reserved.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
16
+ from argparse import ArgumentParser
17
+
18
+ from .env import EnvironmentCommand
19
+ from .fp16_safetensors import FP16SafetensorsCommand
20
+
21
+
22
def main():
    """Entry point for the `diffusers-cli` executable."""
    parser = ArgumentParser("Diffusers CLI tool", usage="diffusers-cli <command> [<args>]")
    commands_parser = parser.add_subparsers(help="diffusers-cli command helpers")

    # Register every available sub-command on the shared sub-parser.
    for command in (EnvironmentCommand, FP16SafetensorsCommand):
        command.register_subcommand(commands_parser)

    args = parser.parse_args()

    # No sub-command given: show usage and exit with an error status.
    if not hasattr(args, "func"):
        parser.print_help()
        exit(1)

    # Each sub-command stored a factory in `args.func`; build the command and run it.
    args.func(args).run()


if __name__ == "__main__":
    main()
diffusers/commands/env.py ADDED
@@ -0,0 +1,180 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright 2024 The HuggingFace Team. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ import platform
16
+ import subprocess
17
+ from argparse import ArgumentParser
18
+
19
+ import huggingface_hub
20
+
21
+ from .. import __version__ as version
22
+ from ..utils import (
23
+ is_accelerate_available,
24
+ is_bitsandbytes_available,
25
+ is_flax_available,
26
+ is_google_colab,
27
+ is_peft_available,
28
+ is_safetensors_available,
29
+ is_torch_available,
30
+ is_transformers_available,
31
+ is_xformers_available,
32
+ )
33
+ from . import BaseDiffusersCLICommand
34
+
35
+
36
def info_command_factory(args):
    """Factory registered via `set_defaults(func=...)`; the parsed CLI args are unused."""
    return EnvironmentCommand()
38
+
39
+
40
class EnvironmentCommand(BaseDiffusersCLICommand):
    """Implements `diffusers-cli env`: prints a bug-report-friendly summary of the environment."""

    @staticmethod
    def register_subcommand(parser: ArgumentParser) -> None:
        # `env` takes no arguments of its own; it only needs a factory to build the command.
        download_parser = parser.add_parser("env")
        download_parser.set_defaults(func=info_command_factory)

    def run(self) -> dict:
        """Collect version/hardware information, print it, and return it as a dict.

        Each optional dependency is probed with its `is_*_available()` helper before
        importing, so missing packages are reported as "not installed" instead of
        raising ImportError.
        """
        hub_version = huggingface_hub.__version__

        safetensors_version = "not installed"
        if is_safetensors_available():
            import safetensors

            safetensors_version = safetensors.__version__

        pt_version = "not installed"
        pt_cuda_available = "NA"
        if is_torch_available():
            import torch

            pt_version = torch.__version__
            pt_cuda_available = torch.cuda.is_available()

        flax_version = "not installed"
        jax_version = "not installed"
        jaxlib_version = "not installed"
        jax_backend = "NA"
        if is_flax_available():
            import flax
            import jax
            import jaxlib

            flax_version = flax.__version__
            jax_version = jax.__version__
            jaxlib_version = jaxlib.__version__
            # NOTE(review): `jax.lib.xla_bridge` is a private/legacy access path in
            # newer JAX releases — confirm it still resolves for the supported versions.
            jax_backend = jax.lib.xla_bridge.get_backend().platform

        transformers_version = "not installed"
        if is_transformers_available():
            import transformers

            transformers_version = transformers.__version__

        accelerate_version = "not installed"
        if is_accelerate_available():
            import accelerate

            accelerate_version = accelerate.__version__

        peft_version = "not installed"
        if is_peft_available():
            import peft

            peft_version = peft.__version__

        bitsandbytes_version = "not installed"
        if is_bitsandbytes_available():
            import bitsandbytes

            bitsandbytes_version = bitsandbytes.__version__

        xformers_version = "not installed"
        if is_xformers_available():
            import xformers

            xformers_version = xformers.__version__

        platform_info = platform.platform()

        is_google_colab_str = "Yes" if is_google_colab() else "No"

        # Best-effort GPU detection: shell out to platform tools and fall back to
        # "NA" when the tool is absent (FileNotFoundError from Popen).
        accelerator = "NA"
        if platform.system() in {"Linux", "Windows"}:
            try:
                sp = subprocess.Popen(
                    ["nvidia-smi", "--query-gpu=gpu_name,memory.total", "--format=csv,noheader"],
                    stdout=subprocess.PIPE,
                    stderr=subprocess.PIPE,
                )
                out_str, _ = sp.communicate()
                out_str = out_str.decode("utf-8")

                if len(out_str) > 0:
                    accelerator = out_str.strip()
            except FileNotFoundError:
                pass
        elif platform.system() == "Darwin":  # Mac OS
            try:
                sp = subprocess.Popen(
                    ["system_profiler", "SPDisplaysDataType"],
                    stdout=subprocess.PIPE,
                    stderr=subprocess.PIPE,
                )
                out_str, _ = sp.communicate()
                out_str = out_str.decode("utf-8")

                # Scrape the GPU name from the "Chipset Model:" line of the report.
                start = out_str.find("Chipset Model:")
                if start != -1:
                    start += len("Chipset Model:")
                    end = out_str.find("\n", start)
                    accelerator = out_str[start:end].strip()

                # Append VRAM info when the report includes it (discrete GPUs).
                start = out_str.find("VRAM (Total):")
                if start != -1:
                    start += len("VRAM (Total):")
                    end = out_str.find("\n", start)
                    accelerator += " VRAM: " + out_str[start:end].strip()
            except FileNotFoundError:
                pass
        else:
            print("It seems you are running an unusual OS. Could you fill in the accelerator manually?")

        info = {
            "🤗 Diffusers version": version,
            "Platform": platform_info,
            "Running on Google Colab?": is_google_colab_str,
            "Python version": platform.python_version(),
            "PyTorch version (GPU?)": f"{pt_version} ({pt_cuda_available})",
            "Flax version (CPU?/GPU?/TPU?)": f"{flax_version} ({jax_backend})",
            "Jax version": jax_version,
            "JaxLib version": jaxlib_version,
            "Huggingface_hub version": hub_version,
            "Transformers version": transformers_version,
            "Accelerate version": accelerate_version,
            "PEFT version": peft_version,
            "Bitsandbytes version": bitsandbytes_version,
            "Safetensors version": safetensors_version,
            "xFormers version": xformers_version,
            "Accelerator": accelerator,
            "Using GPU in script?": "<fill in>",
            "Using distributed or parallel set-up in script?": "<fill in>",
        }

        print("\nCopy-and-paste the text below in your GitHub issue and FILL OUT the two last points.\n")
        print(self.format_dict(info))

        return info

    @staticmethod
    def format_dict(d: dict) -> str:
        """Render the info dict as a markdown bullet list with a trailing newline."""
        return "\n".join([f"- {prop}: {val}" for prop, val in d.items()]) + "\n"
diffusers/commands/fp16_safetensors.py ADDED
@@ -0,0 +1,132 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright 2024 The HuggingFace Team. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ """
16
+ Usage example:
17
+ diffusers-cli fp16_safetensors --ckpt_id=openai/shap-e --fp16 --use_safetensors
18
+ """
19
+
20
+ import glob
21
+ import json
22
+ import warnings
23
+ from argparse import ArgumentParser, Namespace
24
+ from importlib import import_module
25
+
26
+ import huggingface_hub
27
+ import torch
28
+ from huggingface_hub import hf_hub_download
29
+ from packaging import version
30
+
31
+ from ..utils import logging
32
+ from . import BaseDiffusersCLICommand
33
+
34
+
35
def conversion_command_factory(args: Namespace):
    """Build a `FP16SafetensorsCommand` from parsed CLI arguments."""
    if args.use_auth_token:
        # The flag is kept for backward compatibility but has no effect anymore.
        message = (
            "The `--use_auth_token` flag is deprecated and will be removed in a future version. Authentication is now"
            " handled automatically if user is logged in."
        )
        warnings.warn(message)
    return FP16SafetensorsCommand(args.ckpt_id, args.fp16, args.use_safetensors)
42
+
43
+
44
class FP16SafetensorsCommand(BaseDiffusersCLICommand):
    """
    Implements `diffusers-cli fp16_safetensors`.

    Downloads a pipeline checkpoint from the Hub, re-serializes it (optionally with FP16 weights
    and/or in the safetensors format), and opens a PR on the source repository containing the
    converted files.
    """

    @staticmethod
    def register_subcommand(parser: ArgumentParser):
        """Register the `fp16_safetensors` sub-command and its CLI flags."""
        conversion_parser = parser.add_parser("fp16_safetensors")
        conversion_parser.add_argument(
            "--ckpt_id",
            type=str,
            help="Repo id of the checkpoints on which to run the conversion. Example: 'openai/shap-e'.",
        )
        conversion_parser.add_argument(
            "--fp16", action="store_true", help="If serializing the variables in FP16 precision."
        )
        conversion_parser.add_argument(
            "--use_safetensors", action="store_true", help="If serializing in the safetensors format."
        )
        conversion_parser.add_argument(
            "--use_auth_token",
            action="store_true",
            help="When working with checkpoints having private visibility. When used `huggingface-cli login` needs to be run beforehand.",
        )
        conversion_parser.set_defaults(func=conversion_command_factory)

    def __init__(self, ckpt_id: str, fp16: bool, use_safetensors: bool):
        self.logger = logging.get_logger("diffusers-cli/fp16_safetensors")
        self.ckpt_id = ckpt_id
        # Local scratch directory the converted pipeline is saved to before upload.
        self.local_ckpt_dir = f"/tmp/{ckpt_id}"
        self.fp16 = fp16
        self.use_safetensors = use_safetensors

        if not self.use_safetensors and not self.fp16:
            raise NotImplementedError(
                "When `use_safetensors` and `fp16` both are False, then this command is of no use."
            )

    def run(self):
        """Download the checkpoint, convert it, and open a Hub PR with the converted files."""
        import os  # local import: only needed here, keeps module-level imports unchanged

        if version.parse(huggingface_hub.__version__) < version.parse("0.9.0"):
            raise ImportError(
                "The huggingface_hub version must be >= 0.9.0 to use this command. Please update your huggingface_hub"
                " installation."
            )

        from huggingface_hub import create_commit
        from huggingface_hub._commit_api import CommitOperationAdd

        # Resolve the pipeline class dynamically from the repo's model_index.json.
        model_index = hf_hub_download(repo_id=self.ckpt_id, filename="model_index.json")
        with open(model_index, "r") as f:
            pipeline_class_name = json.load(f)["_class_name"]
        pipeline_class = getattr(import_module("diffusers"), pipeline_class_name)
        self.logger.info(f"Pipeline class imported: {pipeline_class_name}.")

        # Load the appropriate pipeline. We could have use `DiffusionPipeline`
        # here, but just to avoid any rough edge cases.
        pipeline = pipeline_class.from_pretrained(
            self.ckpt_id, torch_dtype=torch.float16 if self.fp16 else torch.float32
        )
        pipeline.save_pretrained(
            self.local_ckpt_dir,
            safe_serialization=self.use_safetensors,
            variant="fp16" if self.fp16 else None,
        )
        self.logger.info(f"Pipeline locally saved to {self.local_ckpt_dir}.")

        # Fetch all the paths that were (re-)serialized.
        if self.fp16:
            modified_paths = glob.glob(f"{self.local_ckpt_dir}/*/*.fp16.*")
        elif self.use_safetensors:
            modified_paths = glob.glob(f"{self.local_ckpt_dir}/*/*.safetensors")

        # Prepare for the PR.
        commit_message = f"Serialize variables with FP16: {self.fp16} and safetensors: {self.use_safetensors}."
        operations = []
        for path in modified_paths:
            # FIX: the original computed `"/".join(path.split("/")[4:])`, which silently
            # assumed `ckpt_id` always contains exactly one "/" (i.e. "org/name"); for a
            # single-segment repo id it dropped a directory component from the repo path.
            # Deriving the in-repo path relative to the local checkpoint dir works for
            # any repo id shape.
            operations.append(
                CommitOperationAdd(path_in_repo=os.path.relpath(path, self.local_ckpt_dir), path_or_fileobj=path)
            )

        # Open the PR.
        commit_description = (
            "Variables converted by the [`diffusers`' `fp16_safetensors`"
            " CLI](https://github.com/huggingface/diffusers/blob/main/src/diffusers/commands/fp16_safetensors.py)."
        )
        hub_pr_url = create_commit(
            repo_id=self.ckpt_id,
            operations=operations,
            commit_message=commit_message,
            commit_description=commit_description,
            repo_type="model",
            create_pr=True,
        ).pr_url
        self.logger.info(f"PR created here: {hub_pr_url}.")
diffusers/configuration_utils.py ADDED
@@ -0,0 +1,720 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # coding=utf-8
2
+ # Copyright 2024 The HuggingFace Inc. team.
3
+ # Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved.
4
+ #
5
+ # Licensed under the Apache License, Version 2.0 (the "License");
6
+ # you may not use this file except in compliance with the License.
7
+ # You may obtain a copy of the License at
8
+ #
9
+ # http://www.apache.org/licenses/LICENSE-2.0
10
+ #
11
+ # Unless required by applicable law or agreed to in writing, software
12
+ # distributed under the License is distributed on an "AS IS" BASIS,
13
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14
+ # See the License for the specific language governing permissions and
15
+ # limitations under the License.
16
+ """ConfigMixin base class and utilities."""
17
+
18
+ import dataclasses
19
+ import functools
20
+ import importlib
21
+ import inspect
22
+ import json
23
+ import os
24
+ import re
25
+ from collections import OrderedDict
26
+ from pathlib import Path
27
+ from typing import Any, Dict, Tuple, Union
28
+
29
+ import numpy as np
30
+ from huggingface_hub import create_repo, hf_hub_download
31
+ from huggingface_hub.utils import (
32
+ EntryNotFoundError,
33
+ RepositoryNotFoundError,
34
+ RevisionNotFoundError,
35
+ validate_hf_hub_args,
36
+ )
37
+ from requests import HTTPError
38
+
39
+ from . import __version__
40
+ from .utils import (
41
+ HUGGINGFACE_CO_RESOLVE_ENDPOINT,
42
+ DummyObject,
43
+ deprecate,
44
+ extract_commit_hash,
45
+ http_user_agent,
46
+ logging,
47
+ )
48
+
49
+
50
+ logger = logging.get_logger(__name__)
51
+
52
+ _re_configuration_file = re.compile(r"config\.(.*)\.json")
53
+
54
+
55
class FrozenDict(OrderedDict):
    """An immutable ordered dict whose values are also readable as attributes.

    Used by `ConfigMixin` to store configs: items can be read as `config["key"]`
    or `config.key`, but every mutation path raises once construction finishes.
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)

        # Mirror every item as an instance attribute so values can be read via
        # attribute access as well as item access.
        for key, value in self.items():
            setattr(self, key, value)

        # Name mangling stores this flag as `_FrozenDict__frozen` on the instance.
        self.__frozen = True

    def __delitem__(self, *args, **kwargs):
        raise Exception(f"You cannot use ``__delitem__`` on a {self.__class__.__name__} instance.")

    def setdefault(self, *args, **kwargs):
        raise Exception(f"You cannot use ``setdefault`` on a {self.__class__.__name__} instance.")

    def pop(self, *args, **kwargs):
        raise Exception(f"You cannot use ``pop`` on a {self.__class__.__name__} instance.")

    def update(self, *args, **kwargs):
        raise Exception(f"You cannot use ``update`` on a {self.__class__.__name__} instance.")

    def __setattr__(self, name, value):
        # FIX: the original guard was `hasattr(self, "__frozen")`, but the flag is
        # stored under the mangled name `_FrozenDict__frozen`, so that hasattr was
        # always False and assignments were silently allowed after construction.
        # Look the mangled name up explicitly so freezing is actually enforced.
        # NOTE(review): any downstream code that relied on the broken guard to
        # mutate configs in place will now raise, matching the documented contract.
        if getattr(self, "_FrozenDict__frozen", False):
            raise Exception(f"You cannot use ``__setattr__`` on a {self.__class__.__name__} instance.")
        super().__setattr__(name, value)

    def __setitem__(self, name, value):
        # Same mangled-name fix as `__setattr__`; the error message also wrongly
        # referred to ``__setattr__`` before.
        if getattr(self, "_FrozenDict__frozen", False):
            raise Exception(f"You cannot use ``__setitem__`` on a {self.__class__.__name__} instance.")
        super().__setitem__(name, value)
85
+
86
+
87
class ConfigMixin:
    r"""
    Base class for all configuration classes. All configuration parameters are stored under `self.config`. Also
    provides the [`~ConfigMixin.from_config`] and [`~ConfigMixin.save_config`] methods for loading, downloading, and
    saving classes that inherit from [`ConfigMixin`].

    Class attributes:
        - **config_name** (`str`) -- A filename under which the config should be stored when calling
          [`~ConfigMixin.save_config`] (should be overridden by parent class).
        - **ignore_for_config** (`List[str]`) -- A list of attributes that should not be saved in the config (should be
          overridden by subclass).
        - **has_compatibles** (`bool`) -- Whether the class has compatible classes (should be overridden by subclass).
        - **_deprecated_kwargs** (`List[str]`) -- Keyword arguments that are deprecated. Note that the `init` function
          should only have a `kwargs` argument if at least one argument is deprecated (should be overridden by
          subclass).
    """

    # JSON filename used by `save_config`/`load_config`; must be set by subclasses.
    config_name = None
    # `__init__` argument names listed here are never written into the config dict.
    ignore_for_config = []
    # When True, `extract_init_dict` also accepts init args of `cls._get_compatibles()` classes.
    has_compatibles = False

    # Deprecated init kwargs that are still round-tripped through the config for now.
    _deprecated_kwargs = []
109
+
110
    def register_to_config(self, **kwargs):
        """Merge `kwargs` into the stored config (`self._internal_dict`).

        Usually invoked (via the module-level `register_to_config` decorator) from a subclass's
        `__init__` so that all init arguments are recorded and can be serialized by
        `save_config`. Later calls overwrite earlier values for the same keys.
        """
        if self.config_name is None:
            raise NotImplementedError(f"Make sure that {self.__class__} has defined a class name `config_name`")
        # Special case for `kwargs` used in deprecation warning added to schedulers
        # TODO: remove this when we remove the deprecation warning, and the `kwargs` argument,
        # or solve in a more general way.
        kwargs.pop("kwargs", None)

        if not hasattr(self, "_internal_dict"):
            # First registration: take the kwargs as-is.
            internal_dict = kwargs
        else:
            # Re-registration: new values win over the previously stored ones.
            previous_dict = dict(self._internal_dict)
            internal_dict = {**self._internal_dict, **kwargs}
            logger.debug(f"Updating config from {previous_dict} to {internal_dict}")

        # A fresh FrozenDict is built on every call; the old config object is replaced, not mutated.
        self._internal_dict = FrozenDict(internal_dict)
126
+
127
    def __getattr__(self, name: str) -> Any:
        """The only reason we overwrite `getattr` here is to gracefully deprecate accessing
        config attributes directly. See https://github.com/huggingface/diffusers/pull/3129

        This function is mostly copied from PyTorch's __getattr__ overwrite:
        https://pytorch.org/docs/stable/_modules/torch/nn/modules/module.html#Module
        """

        # Only reached when normal attribute lookup fails, so a real instance attribute with
        # this name would never get here.
        is_in_config = "_internal_dict" in self.__dict__ and hasattr(self.__dict__["_internal_dict"], name)
        is_attribute = name in self.__dict__

        if is_in_config and not is_attribute:
            # Config key accessed as `obj.name` instead of `obj.config.name`: warn, then serve it.
            deprecation_message = f"Accessing config attribute `{name}` directly via '{type(self).__name__}' object attribute is deprecated. Please access '{name}' over '{type(self).__name__}'s config object instead, e.g. 'scheduler.config.{name}'."
            deprecate("direct config name access", "1.0.0", deprecation_message, standard_warn=False)
            return self._internal_dict[name]

        raise AttributeError(f"'{type(self).__name__}' object has no attribute '{name}'")
144
+
145
    def save_config(self, save_directory: Union[str, os.PathLike], push_to_hub: bool = False, **kwargs):
        """
        Save a configuration object to the directory specified in `save_directory` so that it can be reloaded using the
        [`~ConfigMixin.from_config`] class method.

        Args:
            save_directory (`str` or `os.PathLike`):
                Directory where the configuration JSON file is saved (will be created if it does not exist).
            push_to_hub (`bool`, *optional*, defaults to `False`):
                Whether or not to push your model to the Hugging Face Hub after saving it. You can specify the
                repository you want to push to with `repo_id` (will default to the name of `save_directory` in your
                namespace).
            kwargs (`Dict[str, Any]`, *optional*):
                Additional keyword arguments passed along to the [`~utils.PushToHubMixin.push_to_hub`] method.

        Raises:
            AssertionError: If `save_directory` points to an existing file.
        """
        if os.path.isfile(save_directory):
            raise AssertionError(f"Provided path ({save_directory}) should be a directory, not a file")

        os.makedirs(save_directory, exist_ok=True)

        # If we save using the predefined names, we can load using `from_config`
        output_config_file = os.path.join(save_directory, self.config_name)

        self.to_json_file(output_config_file)
        logger.info(f"Configuration saved in {output_config_file}")

        if push_to_hub:
            commit_message = kwargs.pop("commit_message", None)
            private = kwargs.pop("private", False)
            create_pr = kwargs.pop("create_pr", False)
            token = kwargs.pop("token", None)
            # Default repo name: the last path component of `save_directory`.
            repo_id = kwargs.pop("repo_id", save_directory.split(os.path.sep)[-1])
            # Creates the repo on the Hub if needed (network side effect).
            repo_id = create_repo(repo_id, exist_ok=True, private=private, token=token).repo_id

            self._upload_folder(
                save_directory,
                repo_id,
                token=token,
                commit_message=commit_message,
                create_pr=create_pr,
            )
186
+
187
    @classmethod
    def from_config(cls, config: Union[FrozenDict, Dict[str, Any]] = None, return_unused_kwargs=False, **kwargs):
        r"""
        Instantiate a Python class from a config dictionary.

        Parameters:
            config (`Dict[str, Any]`):
                A config dictionary from which the Python class is instantiated. Make sure to only load configuration
                files of compatible classes.
            return_unused_kwargs (`bool`, *optional*, defaults to `False`):
                Whether kwargs that are not consumed by the Python class should be returned or not.
            kwargs (remaining dictionary of keyword arguments, *optional*):
                Can be used to update the configuration object (after it is loaded) and initiate the Python class.
                `**kwargs` are passed directly to the underlying scheduler/model's `__init__` method and eventually
                overwrite the same named arguments in `config`.

        Returns:
            [`ModelMixin`] or [`SchedulerMixin`]:
                A model or scheduler object instantiated from a config dictionary.

        Examples:

        ```python
        >>> from diffusers import DDPMScheduler, DDIMScheduler, PNDMScheduler

        >>> # Download scheduler from huggingface.co and cache.
        >>> scheduler = DDPMScheduler.from_pretrained("google/ddpm-cifar10-32")

        >>> # Instantiate DDIM scheduler class with same config as DDPM
        >>> scheduler = DDIMScheduler.from_config(scheduler.config)

        >>> # Instantiate PNDM scheduler class with same config as DDPM
        >>> scheduler = PNDMScheduler.from_config(scheduler.config)
        ```
        """
        # <===== TO BE REMOVED WITH DEPRECATION
        # TODO(Patrick) - make sure to remove the following lines when config=="model_path" is deprecated
        if "pretrained_model_name_or_path" in kwargs:
            config = kwargs.pop("pretrained_model_name_or_path")

        if config is None:
            raise ValueError("Please make sure to provide a config as the first positional argument.")
        # ======>

        # Deprecated path: `config` was given as a model path/id instead of a dict.
        if not isinstance(config, dict):
            deprecation_message = "It is deprecated to pass a pretrained model name or path to `from_config`."
            if "Scheduler" in cls.__name__:
                deprecation_message += (
                    f"If you were trying to load a scheduler, please use {cls}.from_pretrained(...) instead."
                    " Otherwise, please make sure to pass a configuration dictionary instead. This functionality will"
                    " be removed in v1.0.0."
                )
            elif "Model" in cls.__name__:
                deprecation_message += (
                    f"If you were trying to load a model, please use {cls}.load_config(...) followed by"
                    f" {cls}.from_config(...) instead. Otherwise, please make sure to pass a configuration dictionary"
                    " instead. This functionality will be removed in v1.0.0."
                )
            deprecate("config-passed-as-path", "1.0.0", deprecation_message, standard_warn=False)
            config, kwargs = cls.load_config(pretrained_model_name_or_path=config, return_unused_kwargs=True, **kwargs)

        # Split into kwargs accepted by `cls.__init__`, leftovers, and "hidden" entries kept for
        # compatible classes / private bookkeeping.
        init_dict, unused_kwargs, hidden_dict = cls.extract_init_dict(config, **kwargs)

        # Allow dtype to be specified on initialization
        if "dtype" in unused_kwargs:
            init_dict["dtype"] = unused_kwargs.pop("dtype")

        # add possible deprecated kwargs
        for deprecated_kwarg in cls._deprecated_kwargs:
            if deprecated_kwarg in unused_kwargs:
                init_dict[deprecated_kwarg] = unused_kwargs.pop(deprecated_kwarg)

        # Return model and optionally state and/or unused_kwargs
        model = cls(**init_dict)

        # make sure to also save config parameters that might be used for compatible classes
        # update _class_name
        if "_class_name" in hidden_dict:
            hidden_dict["_class_name"] = cls.__name__

        model.register_to_config(**hidden_dict)

        # add hidden kwargs of compatible classes to unused_kwargs
        unused_kwargs = {**unused_kwargs, **hidden_dict}

        if return_unused_kwargs:
            return (model, unused_kwargs)
        else:
            return model
276
+
277
+ @classmethod
278
+ def get_config_dict(cls, *args, **kwargs):
279
+ deprecation_message = (
280
+ f" The function get_config_dict is deprecated. Please use {cls}.load_config instead. This function will be"
281
+ " removed in version v1.0.0"
282
+ )
283
+ deprecate("get_config_dict", "1.0.0", deprecation_message, standard_warn=False)
284
+ return cls.load_config(*args, **kwargs)
285
+
286
    @classmethod
    @validate_hf_hub_args
    def load_config(
        cls,
        pretrained_model_name_or_path: Union[str, os.PathLike],
        return_unused_kwargs=False,
        return_commit_hash=False,
        **kwargs,
    ) -> Tuple[Dict[str, Any], Dict[str, Any]]:
        r"""
        Load a model or scheduler configuration.

        Parameters:
            pretrained_model_name_or_path (`str` or `os.PathLike`, *optional*):
                Can be either:

                    - A string, the *model id* (for example `google/ddpm-celebahq-256`) of a pretrained model hosted on
                      the Hub.
                    - A path to a *directory* (for example `./my_model_directory`) containing model weights saved with
                      [`~ConfigMixin.save_config`].

            cache_dir (`Union[str, os.PathLike]`, *optional*):
                Path to a directory where a downloaded pretrained model configuration is cached if the standard cache
                is not used.
            force_download (`bool`, *optional*, defaults to `False`):
                Whether or not to force the (re-)download of the model weights and configuration files, overriding the
                cached versions if they exist.
            proxies (`Dict[str, str]`, *optional*):
                A dictionary of proxy servers to use by protocol or endpoint, for example, `{'http': 'foo.bar:3128',
                'http://hostname': 'foo.bar:4012'}`. The proxies are used on each request.
            output_loading_info(`bool`, *optional*, defaults to `False`):
                Whether or not to also return a dictionary containing missing keys, unexpected keys and error messages.
            local_files_only (`bool`, *optional*, defaults to `False`):
                Whether to only load local model weights and configuration files or not. If set to `True`, the model
                won't be downloaded from the Hub.
            token (`str` or *bool*, *optional*):
                The token to use as HTTP bearer authorization for remote files. If `True`, the token generated from
                `diffusers-cli login` (stored in `~/.huggingface`) is used.
            revision (`str`, *optional*, defaults to `"main"`):
                The specific model version to use. It can be a branch name, a tag name, a commit id, or any identifier
                allowed by Git.
            subfolder (`str`, *optional*, defaults to `""`):
                The subfolder location of a model file within a larger model repository on the Hub or locally.
            return_unused_kwargs (`bool`, *optional*, defaults to `False):
                Whether unused keyword arguments of the config are returned.
            return_commit_hash (`bool`, *optional*, defaults to `False):
                Whether the `commit_hash` of the loaded configuration are returned.

        Returns:
            `dict`:
                A dictionary of all the parameters stored in a JSON configuration file.

        """
        cache_dir = kwargs.pop("cache_dir", None)
        local_dir = kwargs.pop("local_dir", None)
        local_dir_use_symlinks = kwargs.pop("local_dir_use_symlinks", "auto")
        force_download = kwargs.pop("force_download", False)
        proxies = kwargs.pop("proxies", None)
        token = kwargs.pop("token", None)
        local_files_only = kwargs.pop("local_files_only", False)
        revision = kwargs.pop("revision", None)
        # `mirror` is accepted for backward compatibility but ignored.
        _ = kwargs.pop("mirror", None)
        subfolder = kwargs.pop("subfolder", None)
        user_agent = kwargs.pop("user_agent", {})

        user_agent = {**user_agent, "file_type": "config"}
        user_agent = http_user_agent(user_agent)

        pretrained_model_name_or_path = str(pretrained_model_name_or_path)

        if cls.config_name is None:
            raise ValueError(
                "`self.config_name` is not defined. Note that one should not load a config from "
                "`ConfigMixin`. Please make sure to define `config_name` in a class inheriting from `ConfigMixin`"
            )

        # Resolution order: direct file path -> local directory (with optional subfolder)
        # -> Hub download.
        if os.path.isfile(pretrained_model_name_or_path):
            config_file = pretrained_model_name_or_path
        elif os.path.isdir(pretrained_model_name_or_path):
            if subfolder is not None and os.path.isfile(
                os.path.join(pretrained_model_name_or_path, subfolder, cls.config_name)
            ):
                config_file = os.path.join(pretrained_model_name_or_path, subfolder, cls.config_name)
            elif os.path.isfile(os.path.join(pretrained_model_name_or_path, cls.config_name)):
                # Load from a PyTorch checkpoint
                config_file = os.path.join(pretrained_model_name_or_path, cls.config_name)
            else:
                raise EnvironmentError(
                    f"Error no file named {cls.config_name} found in directory {pretrained_model_name_or_path}."
                )
        else:
            try:
                # Load from URL or cache if already cached
                config_file = hf_hub_download(
                    pretrained_model_name_or_path,
                    filename=cls.config_name,
                    cache_dir=cache_dir,
                    force_download=force_download,
                    proxies=proxies,
                    local_files_only=local_files_only,
                    token=token,
                    user_agent=user_agent,
                    subfolder=subfolder,
                    revision=revision,
                    local_dir=local_dir,
                    local_dir_use_symlinks=local_dir_use_symlinks,
                )
            # Each hub error is translated into an EnvironmentError with an actionable message.
            except RepositoryNotFoundError:
                raise EnvironmentError(
                    f"{pretrained_model_name_or_path} is not a local folder and is not a valid model identifier"
                    " listed on 'https://huggingface.co/models'\nIf this is a private repository, make sure to pass a"
                    " token having permission to this repo with `token` or log in with `huggingface-cli login`."
                )
            except RevisionNotFoundError:
                raise EnvironmentError(
                    f"{revision} is not a valid git identifier (branch name, tag name or commit id) that exists for"
                    " this model name. Check the model page at"
                    f" 'https://huggingface.co/{pretrained_model_name_or_path}' for available revisions."
                )
            except EntryNotFoundError:
                raise EnvironmentError(
                    f"{pretrained_model_name_or_path} does not appear to have a file named {cls.config_name}."
                )
            except HTTPError as err:
                raise EnvironmentError(
                    "There was a specific connection error when trying to load"
                    f" {pretrained_model_name_or_path}:\n{err}"
                )
            except ValueError:
                raise EnvironmentError(
                    f"We couldn't connect to '{HUGGINGFACE_CO_RESOLVE_ENDPOINT}' to load this model, couldn't find it"
                    f" in the cached files and it looks like {pretrained_model_name_or_path} is not the path to a"
                    f" directory containing a {cls.config_name} file.\nCheckout your internet connection or see how to"
                    " run the library in offline mode at"
                    " 'https://huggingface.co/docs/diffusers/installation#offline-mode'."
                )
            except EnvironmentError:
                raise EnvironmentError(
                    f"Can't load config for '{pretrained_model_name_or_path}'. If you were trying to load it from "
                    "'https://huggingface.co/models', make sure you don't have a local directory with the same name. "
                    f"Otherwise, make sure '{pretrained_model_name_or_path}' is the correct path to a directory "
                    f"containing a {cls.config_name} file"
                )

        try:
            # Load config dict
            config_dict = cls._dict_from_json_file(config_file)

            commit_hash = extract_commit_hash(config_file)
        except (json.JSONDecodeError, UnicodeDecodeError):
            raise EnvironmentError(f"It looks like the config file at '{config_file}' is not a valid JSON file.")

        # Return shape depends on the two flags: plain dict, or a tuple with the extras appended.
        if not (return_unused_kwargs or return_commit_hash):
            return config_dict

        outputs = (config_dict,)

        if return_unused_kwargs:
            outputs += (kwargs,)

        if return_commit_hash:
            outputs += (commit_hash,)

        return outputs
450
+
451
+ @staticmethod
452
+ def _get_init_keys(input_class):
453
+ return set(dict(inspect.signature(input_class.__init__).parameters).keys())
454
+
455
    @classmethod
    def extract_init_dict(cls, config_dict, **kwargs):
        """Partition `config_dict` merged with `kwargs` into three dicts.

        Returns:
            `(init_dict, unused_kwargs, hidden_config_dict)`:
                - `init_dict`: keyword arguments accepted by `cls.__init__` (kwargs win over
                  config values).
                - `unused_kwargs`: leftover config entries and kwargs not consumed by the init.
                - `hidden_config_dict`: entries of the original config not placed in
                  `init_dict`, preserved so compatible classes can be re-instantiated later.
        """
        # Skip keys that were not present in the original config, so default __init__ values were used
        used_defaults = config_dict.get("_use_default_values", [])
        config_dict = {k: v for k, v in config_dict.items() if k not in used_defaults and k != "_use_default_values"}

        # 0. Copy origin config dict
        original_dict = dict(config_dict.items())

        # 1. Retrieve expected config attributes from __init__ signature
        expected_keys = cls._get_init_keys(cls)
        expected_keys.remove("self")
        # remove general kwargs if present in dict
        if "kwargs" in expected_keys:
            expected_keys.remove("kwargs")
        # remove flax internal keys
        if hasattr(cls, "_flax_internal_args"):
            for arg in cls._flax_internal_args:
                expected_keys.remove(arg)

        # 2. Remove attributes that cannot be expected from expected config attributes
        # remove keys to be ignored
        if len(cls.ignore_for_config) > 0:
            expected_keys = expected_keys - set(cls.ignore_for_config)

        # load diffusers library to import compatible and original scheduler
        diffusers_library = importlib.import_module(__name__.split(".")[0])

        if cls.has_compatibles:
            compatible_classes = [c for c in cls._get_compatibles() if not isinstance(c, DummyObject)]
        else:
            compatible_classes = []

        # Drop config entries that only exist for sibling (compatible) classes' inits.
        expected_keys_comp_cls = set()
        for c in compatible_classes:
            expected_keys_c = cls._get_init_keys(c)
            expected_keys_comp_cls = expected_keys_comp_cls.union(expected_keys_c)
        expected_keys_comp_cls = expected_keys_comp_cls - cls._get_init_keys(cls)
        config_dict = {k: v for k, v in config_dict.items() if k not in expected_keys_comp_cls}

        # remove attributes from orig class that cannot be expected
        orig_cls_name = config_dict.pop("_class_name", cls.__name__)
        if (
            isinstance(orig_cls_name, str)
            and orig_cls_name != cls.__name__
            and hasattr(diffusers_library, orig_cls_name)
        ):
            orig_cls = getattr(diffusers_library, orig_cls_name)
            unexpected_keys_from_orig = cls._get_init_keys(orig_cls) - expected_keys
            config_dict = {k: v for k, v in config_dict.items() if k not in unexpected_keys_from_orig}
        elif not isinstance(orig_cls_name, str) and not isinstance(orig_cls_name, (list, tuple)):
            raise ValueError(
                "Make sure that the `_class_name` is of type string or list of string (for custom pipelines)."
            )

        # remove private attributes
        config_dict = {k: v for k, v in config_dict.items() if not k.startswith("_")}

        # 3. Create keyword arguments that will be passed to __init__ from expected keyword arguments
        init_dict = {}
        for key in expected_keys:
            # if config param is passed to kwarg and is present in config dict
            # it should overwrite existing config dict key
            if key in kwargs and key in config_dict:
                config_dict[key] = kwargs.pop(key)

            if key in kwargs:
                # overwrite key
                init_dict[key] = kwargs.pop(key)
            elif key in config_dict:
                # use value from config dict
                init_dict[key] = config_dict.pop(key)

        # 4. Give nice warning if unexpected values have been passed
        if len(config_dict) > 0:
            logger.warning(
                f"The config attributes {config_dict} were passed to {cls.__name__}, "
                "but are not expected and will be ignored. Please verify your "
                f"{cls.config_name} configuration file."
            )

        # 5. Give nice info if config attributes are initialized to default because they have not been passed
        passed_keys = set(init_dict.keys())
        if len(expected_keys - passed_keys) > 0:
            logger.info(
                f"{expected_keys - passed_keys} was not found in config. Values will be initialized to default values."
            )

        # 6. Define unused keyword arguments
        unused_kwargs = {**config_dict, **kwargs}

        # 7. Define "hidden" config parameters that were saved for compatible classes
        hidden_config_dict = {k: v for k, v in original_dict.items() if k not in init_dict}

        return init_dict, unused_kwargs, hidden_config_dict
550
+
551
+ @classmethod
552
+ def _dict_from_json_file(cls, json_file: Union[str, os.PathLike]):
553
+ with open(json_file, "r", encoding="utf-8") as reader:
554
+ text = reader.read()
555
+ return json.loads(text)
556
+
557
+ def __repr__(self):
558
+ return f"{self.__class__.__name__} {self.to_json_string()}"
559
+
560
    @property
    def config(self) -> Dict[str, Any]:
        """
        Returns the config of the class as a frozen dictionary

        Returns:
            `Dict[str, Any]`: Config of the class.
        """
        # `_internal_dict` is a FrozenDict populated by `register_to_config`; accessing this
        # property before any registration raises AttributeError (via `__getattr__`).
        return self._internal_dict
569
+
570
+ def to_json_string(self) -> str:
571
+ """
572
+ Serializes the configuration instance to a JSON string.
573
+
574
+ Returns:
575
+ `str`:
576
+ String containing all the attributes that make up the configuration instance in JSON format.
577
+ """
578
+ config_dict = self._internal_dict if hasattr(self, "_internal_dict") else {}
579
+ config_dict["_class_name"] = self.__class__.__name__
580
+ config_dict["_diffusers_version"] = __version__
581
+
582
+ def to_json_saveable(value):
583
+ if isinstance(value, np.ndarray):
584
+ value = value.tolist()
585
+ elif isinstance(value, Path):
586
+ value = value.as_posix()
587
+ return value
588
+
589
+ config_dict = {k: to_json_saveable(v) for k, v in config_dict.items()}
590
+ # Don't save "_ignore_files" or "_use_default_values"
591
+ config_dict.pop("_ignore_files", None)
592
+ config_dict.pop("_use_default_values", None)
593
+
594
+ return json.dumps(config_dict, indent=2, sort_keys=True) + "\n"
595
+
596
+ def to_json_file(self, json_file_path: Union[str, os.PathLike]):
597
+ """
598
+ Save the configuration instance's parameters to a JSON file.
599
+
600
+ Args:
601
+ json_file_path (`str` or `os.PathLike`):
602
+ Path to the JSON file to save a configuration instance's parameters.
603
+ """
604
+ with open(json_file_path, "w", encoding="utf-8") as writer:
605
+ writer.write(self.to_json_string())
606
+
607
+
608
def register_to_config(init):
    r"""
    Decorator to apply on the init of classes inheriting from [`ConfigMixin`] so that all the arguments are
    automatically sent to `self.register_for_config`. To ignore a specific argument accepted by the init but that
    shouldn't be registered in the config, use the `ignore_for_config` class variable

    Warning: Once decorated, all private arguments (beginning with an underscore) are trashed and not sent to the init!
    """

    @functools.wraps(init)
    def inner_init(self, *args, **kwargs):
        # Ignore private kwargs in the init.
        init_kwargs = {k: v for k, v in kwargs.items() if not k.startswith("_")}
        config_init_kwargs = {k: v for k, v in kwargs.items() if k.startswith("_")}
        if not isinstance(self, ConfigMixin):
            raise RuntimeError(
                f"`@register_for_config` was applied to {self.__class__.__name__} init method, but this class does "
                "not inherit from `ConfigMixin`."
            )

        ignore = getattr(self, "ignore_for_config", [])
        # Get positional arguments aligned with kwargs
        new_kwargs = {}
        signature = inspect.signature(init)
        # `i > 0` skips `self`; ignored parameters never reach the config.
        parameters = {
            name: p.default for i, (name, p) in enumerate(signature.parameters.items()) if i > 0 and name not in ignore
        }
        # Map positional args onto parameter names in declaration order.
        for arg, name in zip(args, parameters.keys()):
            new_kwargs[name] = arg

        # Then add all kwargs
        new_kwargs.update(
            {
                k: init_kwargs.get(k, default)
                for k, default in parameters.items()
                if k not in ignore and k not in new_kwargs
            }
        )

        # Take note of the parameters that were not present in the loaded config
        # NOTE(review): parameters passed *positionally* (captured in the zip loop above) are in
        # `new_kwargs` but not in `init_kwargs`, so they are also flagged here as defaulted —
        # confirm this is intended before relying on `_use_default_values`.
        if len(set(new_kwargs.keys()) - set(init_kwargs)) > 0:
            new_kwargs["_use_default_values"] = list(set(new_kwargs.keys()) - set(init_kwargs))

        new_kwargs = {**config_init_kwargs, **new_kwargs}
        # Register first so the real `__init__` can already read `self.config`.
        getattr(self, "register_to_config")(**new_kwargs)
        init(self, *args, **init_kwargs)

    return inner_init
656
+
657
+
658
def flax_register_to_config(cls):
    """Class decorator: Flax counterpart of `register_to_config`.

    Wraps `cls.__init__` so that all dataclass fields (minus `_flax_internal_args` and `dtype`)
    plus any passed args/kwargs are registered into the config before the original init runs.
    """
    original_init = cls.__init__

    @functools.wraps(original_init)
    def init(self, *args, **kwargs):
        if not isinstance(self, ConfigMixin):
            raise RuntimeError(
                f"`@register_for_config` was applied to {self.__class__.__name__} init method, but this class does "
                "not inherit from `ConfigMixin`."
            )

        # Ignore private kwargs in the init. Retrieve all passed attributes
        init_kwargs = dict(kwargs.items())

        # Retrieve default values
        fields = dataclasses.fields(self)
        default_kwargs = {}
        for field in fields:
            # ignore flax specific attributes
            if field.name in self._flax_internal_args:
                continue
            if type(field.default) == dataclasses._MISSING_TYPE:
                # Field has no default: use None as a placeholder (overridden below if passed).
                default_kwargs[field.name] = None
            else:
                default_kwargs[field.name] = getattr(self, field.name)

        # Make sure init_kwargs override default kwargs
        new_kwargs = {**default_kwargs, **init_kwargs}
        # dtype should be part of `init_kwargs`, but not `new_kwargs`
        if "dtype" in new_kwargs:
            new_kwargs.pop("dtype")

        # Get positional arguments aligned with kwargs
        for i, arg in enumerate(args):
            name = fields[i].name
            new_kwargs[name] = arg

        # Take note of the parameters that were not present in the loaded config
        if len(set(new_kwargs.keys()) - set(init_kwargs)) > 0:
            new_kwargs["_use_default_values"] = list(set(new_kwargs.keys()) - set(init_kwargs))

        # Register first, then run the original init with the untouched args/kwargs.
        getattr(self, "register_to_config")(**new_kwargs)
        original_init(self, *args, **kwargs)

    cls.__init__ = init
    return cls
704
+
705
+
706
class LegacyConfigMixin(ConfigMixin):
    r"""
    A subclass of `ConfigMixin` to resolve class mapping from legacy classes (like `Transformer2DModel`) to more
    pipeline-specific classes (like `DiTTransformer2DModel`).
    """

    @classmethod
    def from_config(cls, config: Union[FrozenDict, Dict[str, Any]] = None, return_unused_kwargs=False, **kwargs):
        # To prevent dependency import problem.
        from .models.model_loading_utils import _fetch_remapped_cls_from_config

        # resolve remapping
        remapped_class = _fetch_remapped_cls_from_config(config, cls)

        # Delegate to the (possibly different) target class's regular `from_config`.
        return remapped_class.from_config(config, return_unused_kwargs, **kwargs)
diffusers/dependency_versions_check.py ADDED
@@ -0,0 +1,34 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright 2024 The HuggingFace Team. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ from .dependency_versions_table import deps
16
+ from .utils.versions import require_version, require_version_core
17
+
18
+
19
# define which module versions we always want to check at run time
# (usually the ones defined in `install_requires` in setup.py)
#
# order specific notes:
# - tqdm must be checked before tokenizers

pkgs_to_check_at_runtime = ["python", "requests", "filelock", "numpy"]
for pkg in pkgs_to_check_at_runtime:
    # Every runtime-checked package must have a pinned requirement in the deps table.
    if pkg not in deps:
        raise ValueError(f"can't find {pkg} in {deps.keys()}, check dependency_versions_table.py")
    require_version_core(deps[pkg])
31
+
32
+
33
def dep_version_check(pkg, hint=None):
    """Verify that the installed version of `pkg` satisfies the requirement pinned in `deps`.

    Args:
        pkg: Key into the `deps` table (e.g. `"torch"`); raises `KeyError` if unknown.
        hint: Optional extra text passed to `require_version` for the error message.
    """
    require_version(deps[pkg], hint)
diffusers/dependency_versions_table.py ADDED
@@ -0,0 +1,46 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
# THIS FILE HAS BEEN AUTOGENERATED. To update:
# 1. modify the `_deps` dict in setup.py
# 2. run `make deps_table_update`
# Maps a pip package name to the full requirement specifier used for runtime version checks.
deps = {
    "Pillow": "Pillow",
    "accelerate": "accelerate>=0.31.0",
    "compel": "compel==0.1.8",
    "datasets": "datasets",
    "filelock": "filelock",
    "flax": "flax>=0.4.1",
    "hf-doc-builder": "hf-doc-builder>=0.3.0",
    "huggingface-hub": "huggingface-hub>=0.23.2",
    "requests-mock": "requests-mock==1.10.0",
    "importlib_metadata": "importlib_metadata",
    "invisible-watermark": "invisible-watermark>=0.2.0",
    "isort": "isort>=5.5.4",
    "jax": "jax>=0.4.1",
    "jaxlib": "jaxlib>=0.4.1",
    "Jinja2": "Jinja2",
    "k-diffusion": "k-diffusion>=0.0.12",
    "torchsde": "torchsde",
    "note_seq": "note_seq",
    "librosa": "librosa",
    "numpy": "numpy",
    "parameterized": "parameterized",
    "peft": "peft>=0.6.0",
    "protobuf": "protobuf>=3.20.3,<4",
    "pytest": "pytest",
    "pytest-timeout": "pytest-timeout",
    "pytest-xdist": "pytest-xdist",
    "python": "python>=3.8.0",
    "ruff": "ruff==0.1.5",
    "safetensors": "safetensors>=0.3.1",
    "sentencepiece": "sentencepiece>=0.1.91,!=0.1.92",
    "GitPython": "GitPython<3.1.19",
    "scipy": "scipy",
    "onnx": "onnx",
    "regex": "regex!=2019.12.17",
    "requests": "requests",
    "tensorboard": "tensorboard",
    "torch": "torch>=1.4",
    "torchvision": "torchvision",
    "transformers": "transformers>=4.41.2",
    "urllib3": "urllib3<=2.0.0",
    "black": "black",
}
diffusers/experimental/README.md ADDED
@@ -0,0 +1,5 @@
 
 
 
 
 
 
1
+ # 🧨 Diffusers Experimental
2
+
3
+ We are adding experimental code to support novel applications and usages of the Diffusers library.
4
+ Currently, the following experiments are supported:
5
+ * Reinforcement learning via an implementation of the [Diffuser](https://arxiv.org/abs/2205.09991) model.
diffusers/experimental/__init__.py ADDED
@@ -0,0 +1 @@
 
 
1
+ from .rl import ValueGuidedRLPipeline
diffusers/experimental/rl/__init__.py ADDED
@@ -0,0 +1 @@
 
 
1
+ from .value_guided_sampling import ValueGuidedRLPipeline
diffusers/experimental/rl/value_guided_sampling.py ADDED
@@ -0,0 +1,153 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright 2024 The HuggingFace Team. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ import numpy as np
16
+ import torch
17
+ import tqdm
18
+
19
+ from ...models.unets.unet_1d import UNet1DModel
20
+ from ...pipelines import DiffusionPipeline
21
+ from ...utils.dummy_pt_objects import DDPMScheduler
22
+ from ...utils.torch_utils import randn_tensor
23
+
24
+
25
class ValueGuidedRLPipeline(DiffusionPipeline):
    r"""
    Pipeline for value-guided sampling from a diffusion model trained to predict sequences of states.

    This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods
    implemented for all pipelines (downloading, saving, running on a particular device, etc.).

    Parameters:
        value_function ([`UNet1DModel`]):
            A specialized UNet for fine-tuning trajectories based on reward.
        unet ([`UNet1DModel`]):
            UNet architecture to denoise the encoded trajectories.
        scheduler ([`SchedulerMixin`]):
            A scheduler to be used in combination with `unet` to denoise the encoded trajectories. Default for this
            application is [`DDPMScheduler`].
        env ():
            An environment following the OpenAI gym API to act in. For now only Hopper has pretrained models.
    """

    def __init__(
        self,
        value_function: UNet1DModel,
        unet: UNet1DModel,
        scheduler: DDPMScheduler,
        env,
    ):
        super().__init__()

        self.register_modules(value_function=value_function, unet=unet, scheduler=scheduler, env=env)

        self.data = env.get_dataset()
        # Per-key dataset statistics used by normalize/de_normalize. Non-numeric dataset
        # entries (e.g. metadata) raise inside mean()/std() and are deliberately skipped.
        self.means = {}
        for key in self.data.keys():
            try:
                self.means[key] = self.data[key].mean()
            except:  # noqa: E722
                pass
        self.stds = {}
        for key in self.data.keys():
            try:
                self.stds[key] = self.data[key].std()
            except:  # noqa: E722
                pass
        self.state_dim = env.observation_space.shape[0]
        self.action_dim = env.action_space.shape[0]

    def normalize(self, x_in, key):
        # Standardize `x_in` using the dataset statistics recorded for `key`.
        return (x_in - self.means[key]) / self.stds[key]

    def de_normalize(self, x_in, key):
        # Inverse of `normalize`: map standardized values back to the original data scale.
        return x_in * self.stds[key] + self.means[key]

    def to_torch(self, x_in):
        # Recursively move arrays/tensors (possibly nested in dicts) onto the UNet's device.
        if isinstance(x_in, dict):
            return {k: self.to_torch(v) for k, v in x_in.items()}
        elif torch.is_tensor(x_in):
            return x_in.to(self.unet.device)
        return torch.tensor(x_in, device=self.unet.device)

    def reset_x0(self, x_in, cond, act_dim):
        # Re-impose the conditioning: overwrite the observation channels (after the first
        # `act_dim` action channels) at each conditioned timestep index.
        for key, val in cond.items():
            x_in[:, key, act_dim:] = val.clone()
        return x_in

    def run_diffusion(self, x, conditions, n_guide_steps, scale):
        """Denoise trajectories `x`, nudging each step toward higher predicted value.

        Returns the denoised trajectories and the last value prediction `y`;
        `y` is `None` when `n_guide_steps` is 0 (no value guidance performed).
        """
        batch_size = x.shape[0]
        y = None
        for i in tqdm.tqdm(self.scheduler.timesteps):
            # create batch of timesteps to pass into model
            timesteps = torch.full((batch_size,), i, device=self.unet.device, dtype=torch.long)
            for _ in range(n_guide_steps):
                with torch.enable_grad():
                    x.requires_grad_()

                    # permute to match dimension for pre-trained models
                    y = self.value_function(x.permute(0, 2, 1), timesteps).sample
                    grad = torch.autograd.grad([y.sum()], [x])[0]

                    # scale the guidance gradient by the scheduler's posterior std so the
                    # nudge follows the noise level of the current timestep
                    posterior_variance = self.scheduler._get_variance(i)
                    model_std = torch.exp(0.5 * posterior_variance)
                    grad = model_std * grad

                # no guidance near the end of the chain
                grad[timesteps < 2] = 0
                x = x.detach()
                x = x + scale * grad
                x = self.reset_x0(x, conditions, self.action_dim)

            prev_x = self.unet(x.permute(0, 2, 1), timesteps).sample.permute(0, 2, 1)

            # TODO: verify deprecation of this kwarg
            x = self.scheduler.step(prev_x, i, x)["prev_sample"]

            # apply conditions to the trajectory (set the initial state)
            x = self.reset_x0(x, conditions, self.action_dim)
            x = self.to_torch(x)
        return x, y

    def __call__(self, obs, batch_size=64, planning_horizon=32, n_guide_steps=2, scale=0.1):
        """Plan a batch of trajectories from observation `obs` and return the first action
        of the best (highest predicted value) trajectory, denormalized to env scale."""
        # normalize the observations and create batch dimension
        obs = self.normalize(obs, "observations")
        obs = obs[None].repeat(batch_size, axis=0)

        conditions = {0: self.to_torch(obs)}
        shape = (batch_size, planning_horizon, self.state_dim + self.action_dim)

        # generate initial noise and apply our conditions (to make the trajectories start at current state)
        x1 = randn_tensor(shape, device=self.unet.device)
        x = self.reset_x0(x1, conditions, self.action_dim)
        x = self.to_torch(x)

        # run the diffusion process
        x, y = self.run_diffusion(x, conditions, n_guide_steps, scale)

        # sort output trajectories by value. BUGFIX: `y` is None when n_guide_steps == 0;
        # previously `y.argsort` raised AttributeError before the `y is not None` check below.
        if y is not None:
            sorted_idx = y.argsort(0, descending=True).squeeze()
            sorted_values = x[sorted_idx]
        else:
            sorted_values = x
        actions = sorted_values[:, :, : self.action_dim]
        actions = actions.detach().cpu().numpy()
        denorm_actions = self.de_normalize(actions, key="actions")

        # select the action with the highest value
        if y is not None:
            selected_index = 0
        else:
            # if we didn't run value guiding, select a random action
            selected_index = np.random.randint(0, batch_size)

        denorm_actions = denorm_actions[selected_index, 0]
        return denorm_actions
diffusers/image_processor.py ADDED
@@ -0,0 +1,1103 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright 2024 The HuggingFace Team. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ import math
16
+ import warnings
17
+ from typing import List, Optional, Tuple, Union
18
+
19
+ import numpy as np
20
+ import PIL.Image
21
+ import torch
22
+ import torch.nn.functional as F
23
+ from PIL import Image, ImageFilter, ImageOps
24
+
25
+ from .configuration_utils import ConfigMixin, register_to_config
26
+ from .utils import CONFIG_NAME, PIL_INTERPOLATION, deprecate
27
+
28
+
29
+ PipelineImageInput = Union[
30
+ PIL.Image.Image,
31
+ np.ndarray,
32
+ torch.Tensor,
33
+ List[PIL.Image.Image],
34
+ List[np.ndarray],
35
+ List[torch.Tensor],
36
+ ]
37
+
38
+ PipelineDepthInput = PipelineImageInput
39
+
40
+
41
def is_valid_image(image):
    """Return True if `image` is a single supported image: a PIL image, or a
    2-d / 3-d numpy array or torch tensor (grayscale or channeled image)."""
    if isinstance(image, PIL.Image.Image):
        return True
    return isinstance(image, (np.ndarray, torch.Tensor)) and image.ndim in (2, 3)
43
+
44
+
45
def is_valid_image_imagelist(images):
    """Return True if `images` is one of the supported image / image-list formats.

    Accepted inputs are:
      (1) a 4-d numpy array or torch tensor (a batch of images),
      (2) a single valid image (PIL image, or 2-/3-d numpy array / torch tensor),
      (3) a list in which every element is a valid image.
    """
    if isinstance(images, (np.ndarray, torch.Tensor)) and images.ndim == 4:
        return True
    if is_valid_image(images):
        return True
    if isinstance(images, list):
        return all(is_valid_image(img) for img in images)
    return False
58
+
59
+
60
+ class VaeImageProcessor(ConfigMixin):
61
+ """
62
+ Image processor for VAE.
63
+
64
+ Args:
65
+ do_resize (`bool`, *optional*, defaults to `True`):
66
+ Whether to downscale the image's (height, width) dimensions to multiples of `vae_scale_factor`. Can accept
67
+ `height` and `width` arguments from [`image_processor.VaeImageProcessor.preprocess`] method.
68
+ vae_scale_factor (`int`, *optional*, defaults to `8`):
69
+ VAE scale factor. If `do_resize` is `True`, the image is automatically resized to multiples of this factor.
70
+ resample (`str`, *optional*, defaults to `lanczos`):
71
+ Resampling filter to use when resizing the image.
72
+ do_normalize (`bool`, *optional*, defaults to `True`):
73
+ Whether to normalize the image to [-1,1].
74
+ do_binarize (`bool`, *optional*, defaults to `False`):
75
+ Whether to binarize the image to 0/1.
76
+ do_convert_rgb (`bool`, *optional*, defaults to be `False`):
77
+ Whether to convert the images to RGB format.
78
+ do_convert_grayscale (`bool`, *optional*, defaults to be `False`):
79
+ Whether to convert the images to grayscale format.
80
+ """
81
+
82
+ config_name = CONFIG_NAME
83
+
84
+ @register_to_config
85
+ def __init__(
86
+ self,
87
+ do_resize: bool = True,
88
+ vae_scale_factor: int = 8,
89
+ vae_latent_channels: int = 4,
90
+ resample: str = "lanczos",
91
+ do_normalize: bool = True,
92
+ do_binarize: bool = False,
93
+ do_convert_rgb: bool = False,
94
+ do_convert_grayscale: bool = False,
95
+ ):
96
+ super().__init__()
97
+ if do_convert_rgb and do_convert_grayscale:
98
+ raise ValueError(
99
+ "`do_convert_rgb` and `do_convert_grayscale` can not both be set to `True`,"
100
+ " if you intended to convert the image into RGB format, please set `do_convert_grayscale = False`.",
101
+ " if you intended to convert the image into grayscale format, please set `do_convert_rgb = False`",
102
+ )
103
+
104
+ @staticmethod
105
+ def numpy_to_pil(images: np.ndarray) -> List[PIL.Image.Image]:
106
+ """
107
+ Convert a numpy image or a batch of images to a PIL image.
108
+ """
109
+ if images.ndim == 3:
110
+ images = images[None, ...]
111
+ images = (images * 255).round().astype("uint8")
112
+ if images.shape[-1] == 1:
113
+ # special case for grayscale (single channel) images
114
+ pil_images = [Image.fromarray(image.squeeze(), mode="L") for image in images]
115
+ else:
116
+ pil_images = [Image.fromarray(image) for image in images]
117
+
118
+ return pil_images
119
+
120
+ @staticmethod
121
+ def pil_to_numpy(images: Union[List[PIL.Image.Image], PIL.Image.Image]) -> np.ndarray:
122
+ """
123
+ Convert a PIL image or a list of PIL images to NumPy arrays.
124
+ """
125
+ if not isinstance(images, list):
126
+ images = [images]
127
+ images = [np.array(image).astype(np.float32) / 255.0 for image in images]
128
+ images = np.stack(images, axis=0)
129
+
130
+ return images
131
+
132
+ @staticmethod
133
+ def numpy_to_pt(images: np.ndarray) -> torch.Tensor:
134
+ """
135
+ Convert a NumPy image to a PyTorch tensor.
136
+ """
137
+ if images.ndim == 3:
138
+ images = images[..., None]
139
+
140
+ images = torch.from_numpy(images.transpose(0, 3, 1, 2))
141
+ return images
142
+
143
+ @staticmethod
144
+ def pt_to_numpy(images: torch.Tensor) -> np.ndarray:
145
+ """
146
+ Convert a PyTorch tensor to a NumPy image.
147
+ """
148
+ images = images.cpu().permute(0, 2, 3, 1).float().numpy()
149
+ return images
150
+
151
+ @staticmethod
152
+ def normalize(images: Union[np.ndarray, torch.Tensor]) -> Union[np.ndarray, torch.Tensor]:
153
+ """
154
+ Normalize an image array to [-1,1].
155
+ """
156
+ return 2.0 * images - 1.0
157
+
158
+ @staticmethod
159
+ def denormalize(images: Union[np.ndarray, torch.Tensor]) -> Union[np.ndarray, torch.Tensor]:
160
+ """
161
+ Denormalize an image array to [0,1].
162
+ """
163
+ return (images / 2 + 0.5).clamp(0, 1)
164
+
165
+ @staticmethod
166
+ def convert_to_rgb(image: PIL.Image.Image) -> PIL.Image.Image:
167
+ """
168
+ Converts a PIL image to RGB format.
169
+ """
170
+ image = image.convert("RGB")
171
+
172
+ return image
173
+
174
+ @staticmethod
175
+ def convert_to_grayscale(image: PIL.Image.Image) -> PIL.Image.Image:
176
+ """
177
+ Converts a PIL image to grayscale format.
178
+ """
179
+ image = image.convert("L")
180
+
181
+ return image
182
+
183
+ @staticmethod
184
+ def blur(image: PIL.Image.Image, blur_factor: int = 4) -> PIL.Image.Image:
185
+ """
186
+ Applies Gaussian blur to an image.
187
+ """
188
+ image = image.filter(ImageFilter.GaussianBlur(blur_factor))
189
+
190
+ return image
191
+
192
    @staticmethod
    def get_crop_region(mask_image: PIL.Image.Image, width: int, height: int, pad=0):
        """
        Finds a rectangular region that contains all masked areas in an image, and expands region to match the aspect
        ratio of the original image; for example, if user drew mask in a 128x32 region, and the dimensions for
        processing are 512x512, the region will be expanded to 128x128.

        Args:
            mask_image (PIL.Image.Image): Mask image.
            width (int): Width of the image to be processed.
            height (int): Height of the image to be processed.
            pad (int, optional): Padding to be added to the crop region. Defaults to 0.

        Returns:
            tuple: (x1, y1, x2, y2) represent a rectangular region that contains all masked areas in an image and
            matches the original aspect ratio.
        """

        mask_image = mask_image.convert("L")
        mask = np.array(mask_image)

        # 1. find a rectangular region that contains all masked ares in an image
        h, w = mask.shape
        # count fully-empty columns from the left edge until the first column with any mask pixel
        crop_left = 0
        for i in range(w):
            if not (mask[:, i] == 0).all():
                break
            crop_left += 1

        # same scan from the right edge
        crop_right = 0
        for i in reversed(range(w)):
            if not (mask[:, i] == 0).all():
                break
            crop_right += 1

        # fully-empty rows from the top
        crop_top = 0
        for i in range(h):
            if not (mask[i] == 0).all():
                break
            crop_top += 1

        # fully-empty rows from the bottom
        crop_bottom = 0
        for i in reversed(range(h)):
            if not (mask[i] == 0).all():
                break
            crop_bottom += 1

        # 2. add padding to the crop region (clamped to the image bounds)
        x1, y1, x2, y2 = (
            int(max(crop_left - pad, 0)),
            int(max(crop_top - pad, 0)),
            int(min(w - crop_right + pad, w)),
            int(min(h - crop_bottom + pad, h)),
        )

        # 3. expands crop region to match the aspect ratio of the image to be processed
        ratio_crop_region = (x2 - x1) / (y2 - y1)
        ratio_processing = width / height

        if ratio_crop_region > ratio_processing:
            # crop region is wider than target ratio: grow it vertically, centered,
            # then shift/clip so it stays inside the mask image
            desired_height = (x2 - x1) / ratio_processing
            desired_height_diff = int(desired_height - (y2 - y1))
            y1 -= desired_height_diff // 2
            y2 += desired_height_diff - desired_height_diff // 2
            if y2 >= mask_image.height:
                diff = y2 - mask_image.height
                y2 -= diff
                y1 -= diff
            if y1 < 0:
                y2 -= y1
                y1 -= y1
            if y2 >= mask_image.height:
                y2 = mask_image.height
        else:
            # crop region is taller than target ratio: grow it horizontally, centered,
            # then shift/clip so it stays inside the mask image
            desired_width = (y2 - y1) * ratio_processing
            desired_width_diff = int(desired_width - (x2 - x1))
            x1 -= desired_width_diff // 2
            x2 += desired_width_diff - desired_width_diff // 2
            if x2 >= mask_image.width:
                diff = x2 - mask_image.width
                x2 -= diff
                x1 -= diff
            if x1 < 0:
                x2 -= x1
                x1 -= x1
            if x2 >= mask_image.width:
                x2 = mask_image.width

        return x1, y1, x2, y2
281
+
282
    def _resize_and_fill(
        self,
        image: PIL.Image.Image,
        width: int,
        height: int,
    ) -> PIL.Image.Image:
        """
        Resize the image to fit within the specified width and height, maintaining the aspect ratio, and then center
        the image within the dimensions, filling empty with data from image.

        Args:
            image: The image to resize.
            width: The width to resize the image to.
            height: The height to resize the image to.
        """

        ratio = width / height
        src_ratio = image.width / image.height

        # scale so the image fits entirely inside the target box, preserving aspect ratio
        src_w = width if ratio < src_ratio else image.width * height // image.height
        src_h = height if ratio >= src_ratio else image.height * width // image.width

        resized = image.resize((src_w, src_h), resample=PIL_INTERPOLATION["lanczos"])
        res = Image.new("RGB", (width, height))
        res.paste(resized, box=(width // 2 - src_w // 2, height // 2 - src_h // 2))

        if ratio < src_ratio:
            # target is taller than the resized image: fill the top/bottom bands by
            # stretching a zero-height slice at each edge (edge replication)
            fill_height = height // 2 - src_h // 2
            if fill_height > 0:
                res.paste(resized.resize((width, fill_height), box=(0, 0, width, 0)), box=(0, 0))
                res.paste(
                    resized.resize((width, fill_height), box=(0, resized.height, width, resized.height)),
                    box=(0, fill_height + src_h),
                )
        elif ratio > src_ratio:
            # target is wider than the resized image: fill the left/right bands by
            # stretching a zero-width slice at each edge (edge replication)
            fill_width = width // 2 - src_w // 2
            if fill_width > 0:
                res.paste(resized.resize((fill_width, height), box=(0, 0, 0, height)), box=(0, 0))
                res.paste(
                    resized.resize((fill_width, height), box=(resized.width, 0, resized.width, height)),
                    box=(fill_width + src_w, 0),
                )

        return res
326
+
327
+ def _resize_and_crop(
328
+ self,
329
+ image: PIL.Image.Image,
330
+ width: int,
331
+ height: int,
332
+ ) -> PIL.Image.Image:
333
+ """
334
+ Resize the image to fit within the specified width and height, maintaining the aspect ratio, and then center
335
+ the image within the dimensions, cropping the excess.
336
+
337
+ Args:
338
+ image: The image to resize.
339
+ width: The width to resize the image to.
340
+ height: The height to resize the image to.
341
+ """
342
+ ratio = width / height
343
+ src_ratio = image.width / image.height
344
+
345
+ src_w = width if ratio > src_ratio else image.width * height // image.height
346
+ src_h = height if ratio <= src_ratio else image.height * width // image.width
347
+
348
+ resized = image.resize((src_w, src_h), resample=PIL_INTERPOLATION["lanczos"])
349
+ res = Image.new("RGB", (width, height))
350
+ res.paste(resized, box=(width // 2 - src_w // 2, height // 2 - src_h // 2))
351
+ return res
352
+
353
+ def resize(
354
+ self,
355
+ image: Union[PIL.Image.Image, np.ndarray, torch.Tensor],
356
+ height: int,
357
+ width: int,
358
+ resize_mode: str = "default", # "default", "fill", "crop"
359
+ ) -> Union[PIL.Image.Image, np.ndarray, torch.Tensor]:
360
+ """
361
+ Resize image.
362
+
363
+ Args:
364
+ image (`PIL.Image.Image`, `np.ndarray` or `torch.Tensor`):
365
+ The image input, can be a PIL image, numpy array or pytorch tensor.
366
+ height (`int`):
367
+ The height to resize to.
368
+ width (`int`):
369
+ The width to resize to.
370
+ resize_mode (`str`, *optional*, defaults to `default`):
371
+ The resize mode to use, can be one of `default` or `fill`. If `default`, will resize the image to fit
372
+ within the specified width and height, and it may not maintaining the original aspect ratio. If `fill`,
373
+ will resize the image to fit within the specified width and height, maintaining the aspect ratio, and
374
+ then center the image within the dimensions, filling empty with data from image. If `crop`, will resize
375
+ the image to fit within the specified width and height, maintaining the aspect ratio, and then center
376
+ the image within the dimensions, cropping the excess. Note that resize_mode `fill` and `crop` are only
377
+ supported for PIL image input.
378
+
379
+ Returns:
380
+ `PIL.Image.Image`, `np.ndarray` or `torch.Tensor`:
381
+ The resized image.
382
+ """
383
+ if resize_mode != "default" and not isinstance(image, PIL.Image.Image):
384
+ raise ValueError(f"Only PIL image input is supported for resize_mode {resize_mode}")
385
+ if isinstance(image, PIL.Image.Image):
386
+ if resize_mode == "default":
387
+ image = image.resize((width, height), resample=PIL_INTERPOLATION[self.config.resample])
388
+ elif resize_mode == "fill":
389
+ image = self._resize_and_fill(image, width, height)
390
+ elif resize_mode == "crop":
391
+ image = self._resize_and_crop(image, width, height)
392
+ else:
393
+ raise ValueError(f"resize_mode {resize_mode} is not supported")
394
+
395
+ elif isinstance(image, torch.Tensor):
396
+ image = torch.nn.functional.interpolate(
397
+ image,
398
+ size=(height, width),
399
+ )
400
+ elif isinstance(image, np.ndarray):
401
+ image = self.numpy_to_pt(image)
402
+ image = torch.nn.functional.interpolate(
403
+ image,
404
+ size=(height, width),
405
+ )
406
+ image = self.pt_to_numpy(image)
407
+ return image
408
+
409
    def binarize(self, image: PIL.Image.Image) -> PIL.Image.Image:
        """
        Create a mask.

        Args:
            image (`PIL.Image.Image`):
                The image input, should be a PIL image.

        Returns:
            `PIL.Image.Image`:
                The binarized image. Values less than 0.5 are set to 0, values greater than 0.5 are set to 1.
        """
        # NOTE(review): despite the PIL annotations, the boolean-mask assignment below only
        # works on tensor/array-like inputs; `preprocess` calls this after converting the
        # input to a torch tensor — confirm the annotation against actual callers.
        image[image < 0.5] = 0
        image[image >= 0.5] = 1

        return image
425
+
426
+ def get_default_height_width(
427
+ self,
428
+ image: Union[PIL.Image.Image, np.ndarray, torch.Tensor],
429
+ height: Optional[int] = None,
430
+ width: Optional[int] = None,
431
+ ) -> Tuple[int, int]:
432
+ """
433
+ This function return the height and width that are downscaled to the next integer multiple of
434
+ `vae_scale_factor`.
435
+
436
+ Args:
437
+ image(`PIL.Image.Image`, `np.ndarray` or `torch.Tensor`):
438
+ The image input, can be a PIL image, numpy array or pytorch tensor. if it is a numpy array, should have
439
+ shape `[batch, height, width]` or `[batch, height, width, channel]` if it is a pytorch tensor, should
440
+ have shape `[batch, channel, height, width]`.
441
+ height (`int`, *optional*, defaults to `None`):
442
+ The height in preprocessed image. If `None`, will use the height of `image` input.
443
+ width (`int`, *optional*`, defaults to `None`):
444
+ The width in preprocessed. If `None`, will use the width of the `image` input.
445
+ """
446
+
447
+ if height is None:
448
+ if isinstance(image, PIL.Image.Image):
449
+ height = image.height
450
+ elif isinstance(image, torch.Tensor):
451
+ height = image.shape[2]
452
+ else:
453
+ height = image.shape[1]
454
+
455
+ if width is None:
456
+ if isinstance(image, PIL.Image.Image):
457
+ width = image.width
458
+ elif isinstance(image, torch.Tensor):
459
+ width = image.shape[3]
460
+ else:
461
+ width = image.shape[2]
462
+
463
+ width, height = (
464
+ x - x % self.config.vae_scale_factor for x in (width, height)
465
+ ) # resize to integer multiple of vae_scale_factor
466
+
467
+ return height, width
468
+
469
    def preprocess(
        self,
        image: PipelineImageInput,
        height: Optional[int] = None,
        width: Optional[int] = None,
        resize_mode: str = "default",  # "default", "fill", "crop"
        crops_coords: Optional[Tuple[int, int, int, int]] = None,
    ) -> torch.Tensor:
        """
        Preprocess the image input into a batched `torch.Tensor` ready for the VAE.

        Args:
            image (`pipeline_image_input`):
                The image input, accepted formats are PIL images, NumPy arrays, PyTorch tensors; Also accept list of
                supported formats.
            height (`int`, *optional*, defaults to `None`):
                The height in preprocessed image. If `None`, will use the `get_default_height_width()` to get default
                height.
            width (`int`, *optional*`, defaults to `None`):
                The width in preprocessed. If `None`, will use get_default_height_width()` to get the default width.
            resize_mode (`str`, *optional*, defaults to `default`):
                The resize mode, can be one of `default` or `fill`. If `default`, will resize the image to fit within
                the specified width and height, and it may not maintaining the original aspect ratio. If `fill`, will
                resize the image to fit within the specified width and height, maintaining the aspect ratio, and then
                center the image within the dimensions, filling empty with data from image. If `crop`, will resize the
                image to fit within the specified width and height, maintaining the aspect ratio, and then center the
                image within the dimensions, cropping the excess. Note that resize_mode `fill` and `crop` are only
                supported for PIL image input.
            crops_coords (`List[Tuple[int, int, int, int]]`, *optional*, defaults to `None`):
                The crop coordinates for each image in the batch. If `None`, will not crop the image.
        """
        supported_formats = (PIL.Image.Image, np.ndarray, torch.Tensor)

        # Expand the missing dimension for 3-dimensional pytorch tensor or numpy array that represents grayscale image
        if self.config.do_convert_grayscale and isinstance(image, (torch.Tensor, np.ndarray)) and image.ndim == 3:
            if isinstance(image, torch.Tensor):
                # if image is a pytorch tensor could have 2 possible shapes:
                # 1. batch x height x width: we should insert the channel dimension at position 1
                # 2. channel x height x width: we should insert batch dimension at position 0,
                # however, since both channel and batch dimension has same size 1, it is same to insert at position 1
                # for simplicity, we insert a dimension of size 1 at position 1 for both cases
                image = image.unsqueeze(1)
            else:
                # if it is a numpy array, it could have 2 possible shapes:
                # 1. batch x height x width: insert channel dimension on last position
                # 2. height x width x channel: insert batch dimension on first position
                if image.shape[-1] == 1:
                    image = np.expand_dims(image, axis=0)
                else:
                    image = np.expand_dims(image, axis=-1)

        # legacy list-of-4d inputs: collapse into a single 4d batch, with a deprecation warning
        if isinstance(image, list) and isinstance(image[0], np.ndarray) and image[0].ndim == 4:
            warnings.warn(
                "Passing `image` as a list of 4d np.ndarray is deprecated."
                "Please concatenate the list along the batch dimension and pass it as a single 4d np.ndarray",
                FutureWarning,
            )
            image = np.concatenate(image, axis=0)
        if isinstance(image, list) and isinstance(image[0], torch.Tensor) and image[0].ndim == 4:
            warnings.warn(
                "Passing `image` as a list of 4d torch.Tensor is deprecated."
                "Please concatenate the list along the batch dimension and pass it as a single 4d torch.Tensor",
                FutureWarning,
            )
            image = torch.cat(image, axis=0)

        if not is_valid_image_imagelist(image):
            raise ValueError(
                f"Input is in incorrect format. Currently, we only support {', '.join(str(x) for x in supported_formats)}"
            )
        # normalize to a list so the per-format branches below can treat input uniformly
        if not isinstance(image, list):
            image = [image]

        if isinstance(image[0], PIL.Image.Image):
            # PIL path: optional crop, resize, color conversion, then convert to tensor
            if crops_coords is not None:
                image = [i.crop(crops_coords) for i in image]
            if self.config.do_resize:
                height, width = self.get_default_height_width(image[0], height, width)
                image = [self.resize(i, height, width, resize_mode=resize_mode) for i in image]
            if self.config.do_convert_rgb:
                image = [self.convert_to_rgb(i) for i in image]
            elif self.config.do_convert_grayscale:
                image = [self.convert_to_grayscale(i) for i in image]
            image = self.pil_to_numpy(image)  # to np
            image = self.numpy_to_pt(image)  # to pt

        elif isinstance(image[0], np.ndarray):
            # numpy path: batch, convert to tensor, then resize
            image = np.concatenate(image, axis=0) if image[0].ndim == 4 else np.stack(image, axis=0)

            image = self.numpy_to_pt(image)

            height, width = self.get_default_height_width(image, height, width)
            if self.config.do_resize:
                image = self.resize(image, height, width)

        elif isinstance(image[0], torch.Tensor):
            # tensor path: batch, optionally add grayscale channel, then resize
            image = torch.cat(image, axis=0) if image[0].ndim == 4 else torch.stack(image, axis=0)

            if self.config.do_convert_grayscale and image.ndim == 3:
                image = image.unsqueeze(1)

            channel = image.shape[1]
            # don't need any preprocess if the image is latents
            if channel == self.config.vae_latent_channels:
                return image

            height, width = self.get_default_height_width(image, height, width)
            if self.config.do_resize:
                image = self.resize(image, height, width)

        # expected range [0,1], normalize to [-1,1]
        do_normalize = self.config.do_normalize
        if do_normalize and image.min() < 0:
            # already-normalized input: warn and skip normalization instead of double-normalizing
            warnings.warn(
                "Passing `image` as torch tensor with value range in [-1,1] is deprecated. The expected value range for image tensor is [0,1] "
                f"when passing as pytorch tensor or numpy Array. You passed `image` with value range [{image.min()},{image.max()}]",
                FutureWarning,
            )
            do_normalize = False
        if do_normalize:
            image = self.normalize(image)

        if self.config.do_binarize:
            image = self.binarize(image)

        return image
595
+
596
def postprocess(
    self,
    image: torch.Tensor,
    output_type: str = "pil",
    do_denormalize: Optional[List[bool]] = None,
) -> Union[PIL.Image.Image, np.ndarray, torch.Tensor]:
    """
    Postprocess the image output from tensor to `output_type`.

    Args:
        image (`torch.Tensor`):
            The image input, should be a pytorch tensor with shape `B x C x H x W`.
        output_type (`str`, *optional*, defaults to `pil`):
            The output type of the image, can be one of `pil`, `np`, `pt`, `latent`.
        do_denormalize (`List[bool]`, *optional*, defaults to `None`):
            Whether to denormalize the image to [0,1]. If `None`, will use the value of `do_normalize` in the
            `VaeImageProcessor` config.

    Returns:
        `PIL.Image.Image`, `np.ndarray` or `torch.Tensor`:
            The postprocessed image.
    """
    if not isinstance(image, torch.Tensor):
        raise ValueError(
            f"Input for postprocessing is in incorrect format: {type(image)}. We only support pytorch tensor"
        )

    # Unknown output types are deprecated rather than rejected: fall back to numpy.
    if output_type not in ("latent", "pt", "np", "pil"):
        deprecation_message = (
            f"the output_type {output_type} is outdated and has been set to `np`. Please make sure to set it to one of these instead: "
            "`pil`, `np`, `pt`, `latent`"
        )
        deprecate("Unsupported output_type", "1.0.0", deprecation_message, standard_warn=False)
        output_type = "np"

    if output_type == "latent":
        return image

    if do_denormalize is None:
        # One flag per batch element, defaulting to the processor-wide setting.
        do_denormalize = [self.config.do_normalize] * image.shape[0]

    per_sample = [
        self.denormalize(image[idx]) if flag else image[idx] for idx, flag in enumerate(do_denormalize)
    ]
    image = torch.stack(per_sample)

    if output_type == "pt":
        return image

    image = self.pt_to_numpy(image)

    if output_type == "np":
        return image

    if output_type == "pil":
        return self.numpy_to_pil(image)
650
+
651
def apply_overlay(
    self,
    mask: PIL.Image.Image,
    init_image: PIL.Image.Image,
    image: PIL.Image.Image,
    crop_coords: Optional[Tuple[int, int, int, int]] = None,
) -> PIL.Image.Image:
    """
    overlay the inpaint output to the original image
    """
    width, height = image.width, image.height

    # Bring the original image and the mask to the output resolution.
    init_image = self.resize(init_image, width=width, height=height)
    mask = self.resize(mask, width=width, height=height)

    # Keep only the unmasked region of the original image; "RGBa" is PIL's
    # premultiplied-alpha mode, which avoids halo artifacts when compositing.
    keep_region = ImageOps.invert(mask.convert("L"))
    init_image_masked = PIL.Image.new("RGBa", (width, height))
    init_image_masked.paste(init_image.convert("RGBA").convert("RGBa"), mask=keep_region)
    init_image_masked = init_image_masked.convert("RGBA")

    if crop_coords is not None:
        # Paste the (re-cropped) inpaint result back into its original location.
        left, top, right, bottom = crop_coords
        crop_width = right - left
        crop_height = bottom - top
        canvas = PIL.Image.new("RGBA", (width, height))
        image = self.resize(image, height=crop_height, width=crop_width, resize_mode="crop")
        canvas.paste(image, (left, top))
        image = canvas.convert("RGB")

    # Composite the preserved original pixels over the generated image.
    composed = image.convert("RGBA")
    composed.alpha_composite(init_image_masked)
    return composed.convert("RGB")
685
+
686
+
687
class VaeImageProcessorLDM3D(VaeImageProcessor):
    """
    Image processor for VAE LDM3D. Handles paired RGB + depth data; depth is stored either as a
    16-bit single-channel map or packed into the G/B channels of an RGB-like array.

    Args:
        do_resize (`bool`, *optional*, defaults to `True`):
            Whether to downscale the image's (height, width) dimensions to multiples of `vae_scale_factor`.
        vae_scale_factor (`int`, *optional*, defaults to `8`):
            VAE scale factor. If `do_resize` is `True`, the image is automatically resized to multiples of this factor.
        resample (`str`, *optional*, defaults to `lanczos`):
            Resampling filter to use when resizing the image.
        do_normalize (`bool`, *optional*, defaults to `True`):
            Whether to normalize the image to [-1,1].
    """

    config_name = CONFIG_NAME

    @register_to_config
    def __init__(
        self,
        do_resize: bool = True,
        vae_scale_factor: int = 8,
        resample: str = "lanczos",
        do_normalize: bool = True,
    ):
        super().__init__()

    @staticmethod
    def numpy_to_pil(images: np.ndarray) -> List[PIL.Image.Image]:
        """
        Convert a NumPy image or a batch of images to a PIL image.

        Only the first three channels are kept; any extra (depth) channels are dropped here and
        handled separately by `numpy_to_depth`.
        """
        if images.ndim == 3:
            images = images[None, ...]
        images = (images * 255).round().astype("uint8")
        if images.shape[-1] == 1:
            # special case for grayscale (single channel) images
            pil_images = [Image.fromarray(image.squeeze(), mode="L") for image in images]
        else:
            pil_images = [Image.fromarray(image[:, :, :3]) for image in images]

        return pil_images

    @staticmethod
    def depth_pil_to_numpy(images: Union[List[PIL.Image.Image], PIL.Image.Image]) -> np.ndarray:
        """
        Convert a PIL image or a list of PIL images to NumPy arrays.

        Depth images are assumed to be 16-bit, so values are scaled by `2**16 - 1` into [0, 1].
        """
        if not isinstance(images, list):
            images = [images]

        images = [np.array(image).astype(np.float32) / (2**16 - 1) for image in images]
        images = np.stack(images, axis=0)
        return images

    @staticmethod
    def rgblike_to_depthmap(image: Union[np.ndarray, torch.Tensor]) -> Union[np.ndarray, torch.Tensor]:
        """
        Args:
            image: RGB-like depth image

        Returns: depth map

        """
        # Depth is packed big-endian: G channel holds the high byte, B channel the low byte.
        return image[:, :, 1] * 2**8 + image[:, :, 2]

    def numpy_to_depth(self, images: np.ndarray) -> List[PIL.Image.Image]:
        """
        Convert a NumPy depth image or a batch of images to a list of 16-bit (`I;16`) PIL images.

        Raises:
            Exception: if the channel count is neither 6 (RGB + RGB-like depth) nor 4 (RGB + depth).
        """
        if images.ndim == 3:
            images = images[None, ...]
        images_depth = images[:, :, :, 3:]
        if images.shape[-1] == 6:
            # 6 channels: RGB image followed by RGB-like packed depth.
            images_depth = (images_depth * 255).round().astype("uint8")
            pil_images = [
                Image.fromarray(self.rgblike_to_depthmap(image_depth), mode="I;16") for image_depth in images_depth
            ]
        elif images.shape[-1] == 4:
            # 4 channels: RGB image followed by a single-channel depth map.
            images_depth = (images_depth * 65535.0).astype(np.uint16)
            pil_images = [Image.fromarray(image_depth, mode="I;16") for image_depth in images_depth]
        else:
            raise Exception("Not supported")

        return pil_images

    def postprocess(
        self,
        image: torch.Tensor,
        output_type: str = "pil",
        do_denormalize: Optional[List[bool]] = None,
    ) -> Union[PIL.Image.Image, np.ndarray, torch.Tensor]:
        """
        Postprocess the image output from tensor to `output_type`.

        Args:
            image (`torch.Tensor`):
                The image input, should be a pytorch tensor with shape `B x C x H x W`.
            output_type (`str`, *optional*, defaults to `pil`):
                The output type of the image, can be one of `pil`, `np`, `pt`, `latent`.
            do_denormalize (`List[bool]`, *optional*, defaults to `None`):
                Whether to denormalize the image to [0,1]. If `None`, will use the value of `do_normalize` in the
                `VaeImageProcessor` config.

        Returns:
            `PIL.Image.Image`, `np.ndarray` or `torch.Tensor`:
                The postprocessed image.
        """
        if not isinstance(image, torch.Tensor):
            raise ValueError(
                f"Input for postprocessing is in incorrect format: {type(image)}. We only support pytorch tensor"
            )
        if output_type not in ["latent", "pt", "np", "pil"]:
            deprecation_message = (
                f"the output_type {output_type} is outdated and has been set to `np`. Please make sure to set it to one of these instead: "
                "`pil`, `np`, `pt`, `latent`"
            )
            deprecate("Unsupported output_type", "1.0.0", deprecation_message, standard_warn=False)
            output_type = "np"

        if do_denormalize is None:
            do_denormalize = [self.config.do_normalize] * image.shape[0]

        image = torch.stack(
            [self.denormalize(image[i]) if do_denormalize[i] else image[i] for i in range(image.shape[0])]
        )

        image = self.pt_to_numpy(image)

        # NOTE: unlike the base class, only `np` and `pil` are actually produced here;
        # `latent`/`pt` fall through to the error below (existing behavior, kept as-is).
        if output_type == "np":
            if image.shape[-1] == 6:
                image_depth = np.stack([self.rgblike_to_depthmap(im[:, :, 3:]) for im in image], axis=0)
            else:
                image_depth = image[:, :, :, 3:]
            return image[:, :, :, :3], image_depth

        if output_type == "pil":
            return self.numpy_to_pil(image), self.numpy_to_depth(image)
        else:
            raise Exception(f"This type {output_type} is not supported")

    def preprocess(
        self,
        rgb: Union[torch.Tensor, PIL.Image.Image, np.ndarray],
        depth: Union[torch.Tensor, PIL.Image.Image, np.ndarray],
        height: Optional[int] = None,
        width: Optional[int] = None,
        target_res: Optional[int] = None,
    ) -> torch.Tensor:
        """
        Preprocess the image input. Accepted formats are PIL images, NumPy arrays or PyTorch tensors.

        Returns a `(rgb, depth)` tuple of 4-D torch tensors. Torch-tensor inputs and grayscale
        conversion are not yet supported and raise.
        """
        supported_formats = (PIL.Image.Image, np.ndarray, torch.Tensor)

        # Expand the missing dimension for 3-dimensional pytorch tensor or numpy array that represents grayscale image
        if self.config.do_convert_grayscale and isinstance(rgb, (torch.Tensor, np.ndarray)) and rgb.ndim == 3:
            raise Exception("This is not yet supported")

        if isinstance(rgb, supported_formats):
            rgb = [rgb]
            depth = [depth]
        elif not (isinstance(rgb, list) and all(isinstance(i, supported_formats) for i in rgb)):
            # join class names explicitly: joining the classes themselves raises TypeError
            raise ValueError(
                f"Input is in incorrect format: {[type(i) for i in rgb]}. Currently, we only support {', '.join(str(x) for x in supported_formats)}"
            )

        if isinstance(rgb[0], PIL.Image.Image):
            if self.config.do_convert_rgb:
                raise Exception("This is not yet supported")
            if self.config.do_resize or target_res:
                # NOTE(review): despite the `Optional[int]` annotation, `target_res` is unpacked
                # as a (height, width) pair here — confirm intended type with callers.
                height, width = self.get_default_height_width(rgb[0], height, width) if not target_res else target_res
                rgb = [self.resize(i, height, width) for i in rgb]
                depth = [self.resize(i, height, width) for i in depth]
            rgb = self.pil_to_numpy(rgb)  # to np
            rgb = self.numpy_to_pt(rgb)  # to pt

            depth = self.depth_pil_to_numpy(depth)  # to np
            depth = self.numpy_to_pt(depth)  # to pt

        elif isinstance(rgb[0], np.ndarray):
            rgb = np.concatenate(rgb, axis=0) if rgb[0].ndim == 4 else np.stack(rgb, axis=0)
            rgb = self.numpy_to_pt(rgb)
            height, width = self.get_default_height_width(rgb, height, width)
            if self.config.do_resize:
                rgb = self.resize(rgb, height, width)

            # check `depth[0]`, not `rgb` — `rgb` has already been converted to a 4-D torch
            # tensor above, so `rgb[0].ndim` is always 3 and 4-D depth inputs were mis-stacked
            depth = np.concatenate(depth, axis=0) if depth[0].ndim == 4 else np.stack(depth, axis=0)
            depth = self.numpy_to_pt(depth)
            height, width = self.get_default_height_width(depth, height, width)
            if self.config.do_resize:
                depth = self.resize(depth, height, width)

        elif isinstance(rgb[0], torch.Tensor):
            raise Exception("This is not yet supported")

        # expected range [0,1], normalize to [-1,1]
        do_normalize = self.config.do_normalize
        if rgb.min() < 0 and do_normalize:
            warnings.warn(
                "Passing `image` as torch tensor with value range in [-1,1] is deprecated. The expected value range for image tensor is [0,1] "
                f"when passing as pytorch tensor or numpy Array. You passed `image` with value range [{rgb.min()},{rgb.max()}]",
                FutureWarning,
            )
            do_normalize = False

        if do_normalize:
            rgb = self.normalize(rgb)
            depth = self.normalize(depth)

        if self.config.do_binarize:
            rgb = self.binarize(rgb)
            depth = self.binarize(depth)

        return rgb, depth
926
+
927
+
928
class IPAdapterMaskProcessor(VaeImageProcessor):
    """
    Image processor for IP Adapter image masks.

    Args:
        do_resize (`bool`, *optional*, defaults to `True`):
            Whether to downscale the image's (height, width) dimensions to multiples of `vae_scale_factor`.
        vae_scale_factor (`int`, *optional*, defaults to `8`):
            VAE scale factor. If `do_resize` is `True`, the image is automatically resized to multiples of this factor.
        resample (`str`, *optional*, defaults to `lanczos`):
            Resampling filter to use when resizing the image.
        do_normalize (`bool`, *optional*, defaults to `False`):
            Whether to normalize the image to [-1,1].
        do_binarize (`bool`, *optional*, defaults to `True`):
            Whether to binarize the image to 0/1.
        do_convert_grayscale (`bool`, *optional*, defaults to be `True`):
            Whether to convert the images to grayscale format.

    """

    config_name = CONFIG_NAME

    @register_to_config
    def __init__(
        self,
        do_resize: bool = True,
        vae_scale_factor: int = 8,
        resample: str = "lanczos",
        do_normalize: bool = False,
        do_binarize: bool = True,
        do_convert_grayscale: bool = True,
    ):
        # Masks differ from images: no [-1,1] normalization, but binarized grayscale output.
        super().__init__(
            do_resize=do_resize,
            vae_scale_factor=vae_scale_factor,
            resample=resample,
            do_normalize=do_normalize,
            do_binarize=do_binarize,
            do_convert_grayscale=do_convert_grayscale,
        )

    @staticmethod
    def downsample(mask: torch.Tensor, batch_size: int, num_queries: int, value_embed_dim: int):
        """
        Downsamples the provided mask tensor to match the expected dimensions for scaled dot-product attention. If the
        aspect ratio of the mask does not match the aspect ratio of the output image, a warning is issued.

        Args:
            mask (`torch.Tensor`):
                The input mask tensor generated with `IPAdapterMaskProcessor.preprocess()`.
            batch_size (`int`):
                The batch size.
            num_queries (`int`):
                The number of queries.
            value_embed_dim (`int`):
                The dimensionality of the value embeddings.

        Returns:
            `torch.Tensor`:
                The downsampled mask tensor.

        """
        source_h = mask.shape[1]
        source_w = mask.shape[2]
        aspect = source_w / source_h

        # Pick a target grid whose area approximates num_queries at the mask's aspect ratio;
        # bump the height by one row when the division leaves a remainder.
        target_h = int(math.sqrt(num_queries / aspect))
        target_h = target_h + int((num_queries % target_h) != 0)
        target_w = num_queries // target_h

        resized = F.interpolate(mask.unsqueeze(0), size=(target_h, target_w), mode="bicubic").squeeze(0)

        # Broadcast a single mask across the whole batch when needed.
        if resized.shape[0] < batch_size:
            resized = resized.repeat(batch_size, 1, 1)

        resized = resized.view(resized.shape[0], -1)

        grid_area = target_h * target_w
        # If the output image and the mask do not have the same aspect ratio, the flattened
        # grid will not contain exactly num_queries entries: pad with zeros or truncate.
        if grid_area < num_queries:
            warnings.warn(
                "The aspect ratio of the mask does not match the aspect ratio of the output image. "
                "Please update your masks or adjust the output size for optimal performance.",
                UserWarning,
            )
            resized = F.pad(resized, (0, num_queries - resized.shape[1]), value=0.0)
        if grid_area > num_queries:
            warnings.warn(
                "The aspect ratio of the mask does not match the aspect ratio of the output image. "
                "Please update your masks or adjust the output size for optimal performance.",
                UserWarning,
            )
            resized = resized[:, :num_queries]

        # Replicate along a trailing embedding axis so the shape matches the SDPA output.
        return resized.view(resized.shape[0], resized.shape[1], 1).repeat(1, 1, value_embed_dim)
1030
+
1031
+
1032
class PixArtImageProcessor(VaeImageProcessor):
    """
    Image processor for PixArt image resize and crop.

    Args:
        do_resize (`bool`, *optional*, defaults to `True`):
            Whether to downscale the image's (height, width) dimensions to multiples of `vae_scale_factor`. Can accept
            `height` and `width` arguments from [`image_processor.VaeImageProcessor.preprocess`] method.
        vae_scale_factor (`int`, *optional*, defaults to `8`):
            VAE scale factor. If `do_resize` is `True`, the image is automatically resized to multiples of this factor.
        resample (`str`, *optional*, defaults to `lanczos`):
            Resampling filter to use when resizing the image.
        do_normalize (`bool`, *optional*, defaults to `True`):
            Whether to normalize the image to [-1,1].
        do_binarize (`bool`, *optional*, defaults to `False`):
            Whether to binarize the image to 0/1.
        do_convert_rgb (`bool`, *optional*, defaults to be `False`):
            Whether to convert the images to RGB format.
        do_convert_grayscale (`bool`, *optional*, defaults to be `False`):
            Whether to convert the images to grayscale format.
    """

    @register_to_config
    def __init__(
        self,
        do_resize: bool = True,
        vae_scale_factor: int = 8,
        resample: str = "lanczos",
        do_normalize: bool = True,
        do_binarize: bool = False,
        do_convert_grayscale: bool = False,
    ):
        super().__init__(
            do_resize=do_resize,
            vae_scale_factor=vae_scale_factor,
            resample=resample,
            do_normalize=do_normalize,
            do_binarize=do_binarize,
            do_convert_grayscale=do_convert_grayscale,
        )

    @staticmethod
    def classify_height_width_bin(height: int, width: int, ratios: dict) -> Tuple[int, int]:
        """Returns binned height and width."""
        aspect = float(height / width)
        # Pick the bin whose key (an aspect-ratio string) is numerically closest.
        best_bin = min(ratios.keys(), key=lambda ratio: abs(float(ratio) - aspect))
        binned = ratios[best_bin]
        return int(binned[0]), int(binned[1])

    @staticmethod
    def resize_and_crop_tensor(samples: torch.Tensor, new_width: int, new_height: int) -> torch.Tensor:
        """Resize a `B x C x H x W` batch to cover the target size, then center-crop the excess."""
        orig_height, orig_width = samples.shape[2], samples.shape[3]

        if orig_height == new_height and orig_width == new_width:
            # Nothing to do: already at the requested resolution.
            return samples

        # Scale so that BOTH dimensions are at least the target size.
        scale = max(new_height / orig_height, new_width / orig_width)
        resized_height = int(orig_height * scale)
        resized_width = int(orig_width * scale)

        samples = F.interpolate(
            samples, size=(resized_height, resized_width), mode="bilinear", align_corners=False
        )

        # Center crop down to exactly (new_height, new_width).
        left = (resized_width - new_width) // 2
        top = (resized_height - new_height) // 2
        return samples[:, :, top : top + new_height, left : left + new_width]
diffusers/loaders/__init__.py ADDED
@@ -0,0 +1,102 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import TYPE_CHECKING
2
+
3
+ from ..utils import DIFFUSERS_SLOW_IMPORT, _LazyModule, deprecate
4
+ from ..utils.import_utils import is_peft_available, is_torch_available, is_transformers_available
5
+
6
+
7
def text_encoder_lora_state_dict(text_encoder):
    """Deprecated: collect the LoRA layer weights of a text encoder's attention projections."""
    deprecate(
        "text_encoder_load_state_dict in `models`",
        "0.27.0",
        "`text_encoder_lora_state_dict` is deprecated and will be removed in 0.27.0. Make sure to retrieve the weights using `get_peft_model`. See https://huggingface.co/docs/peft/v0.6.2/en/quicktour#peftmodel for more information.",
    )
    state_dict = {}

    # The four attention projections are handled identically; iterate instead of copy-pasting.
    for name, module in text_encoder_attn_modules(text_encoder):
        for proj_name in ("q_proj", "k_proj", "v_proj", "out_proj"):
            proj = getattr(module, proj_name)
            for k, v in proj.lora_linear_layer.state_dict().items():
                state_dict[f"{name}.{proj_name}.lora_linear_layer.{k}"] = v

    return state_dict
29
+
30
+
31
if is_transformers_available():

    def text_encoder_attn_modules(text_encoder):
        """Deprecated: return `(name, module)` pairs for each self-attention block of a CLIP text encoder."""
        deprecate(
            "text_encoder_attn_modules in `models`",
            "0.27.0",
            "`text_encoder_lora_state_dict` is deprecated and will be removed in 0.27.0. Make sure to retrieve the weights using `get_peft_model`. See https://huggingface.co/docs/peft/v0.6.2/en/quicktour#peftmodel for more information.",
        )
        from transformers import CLIPTextModel, CLIPTextModelWithProjection

        # Guard clause: only CLIP-style text encoders expose the expected layer layout.
        if not isinstance(text_encoder, (CLIPTextModel, CLIPTextModelWithProjection)):
            raise ValueError(f"do not know how to get attention modules for: {text_encoder.__class__.__name__}")

        return [
            (f"text_model.encoder.layers.{i}.self_attn", layer.self_attn)
            for i, layer in enumerate(text_encoder.text_model.encoder.layers)
        ]
52
+
53
+
54
# Lazy-import table: maps submodule name -> public names it exports.
_import_structure = {}

if is_torch_available():
    _import_structure.update(
        {
            "single_file_model": ["FromOriginalModelMixin"],
            "unet": ["UNet2DConditionLoadersMixin"],
            "utils": ["AttnProcsLayers"],
        }
    )
    # These loaders additionally require transformers.
    if is_transformers_available():
        _import_structure.update(
            {
                "single_file": ["FromSingleFileMixin"],
                "lora_pipeline": [
                    "AmusedLoraLoaderMixin",
                    "StableDiffusionLoraLoaderMixin",
                    "SD3LoraLoaderMixin",
                    "StableDiffusionXLLoraLoaderMixin",
                    "LoraLoaderMixin",
                    "FluxLoraLoaderMixin",
                    "CogVideoXLoraLoaderMixin",
                ],
                "textual_inversion": ["TextualInversionLoaderMixin"],
                "ip_adapter": ["IPAdapterMixin"],
            }
        )

# PeftAdapterMixin is available regardless of torch/transformers.
_import_structure["peft"] = ["PeftAdapterMixin"]
76
+
77
+
78
+ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
79
+ if is_torch_available():
80
+ from .single_file_model import FromOriginalModelMixin
81
+ from .unet import UNet2DConditionLoadersMixin
82
+ from .utils import AttnProcsLayers
83
+
84
+ if is_transformers_available():
85
+ from .ip_adapter import IPAdapterMixin
86
+ from .lora_pipeline import (
87
+ AmusedLoraLoaderMixin,
88
+ CogVideoXLoraLoaderMixin,
89
+ FluxLoraLoaderMixin,
90
+ LoraLoaderMixin,
91
+ SD3LoraLoaderMixin,
92
+ StableDiffusionLoraLoaderMixin,
93
+ StableDiffusionXLLoraLoaderMixin,
94
+ )
95
+ from .single_file import FromSingleFileMixin
96
+ from .textual_inversion import TextualInversionLoaderMixin
97
+
98
+ from .peft import PeftAdapterMixin
99
+ else:
100
+ import sys
101
+
102
+ sys.modules[__name__] = _LazyModule(__name__, globals()["__file__"], _import_structure, module_spec=__spec__)
diffusers/loaders/__pycache__/__init__.cpython-310.pyc ADDED
Binary file (2.96 kB). View file
 
diffusers/loaders/__pycache__/__init__.cpython-38.pyc ADDED
Binary file (2.99 kB). View file
 
diffusers/loaders/__pycache__/lora_base.cpython-310.pyc ADDED
Binary file (22.8 kB). View file
 
diffusers/loaders/__pycache__/lora_base.cpython-38.pyc ADDED
Binary file (23 kB). View file
 
diffusers/loaders/__pycache__/lora_conversion_utils.cpython-310.pyc ADDED
Binary file (16 kB). View file
 
diffusers/loaders/__pycache__/lora_conversion_utils.cpython-38.pyc ADDED
Binary file (16.4 kB). View file
 
diffusers/loaders/__pycache__/lora_pipeline.cpython-310.pyc ADDED
Binary file (63.6 kB). View file
 
diffusers/loaders/__pycache__/lora_pipeline.cpython-38.pyc ADDED
Binary file (69.7 kB). View file
 
diffusers/loaders/__pycache__/peft.cpython-310.pyc ADDED
Binary file (12.9 kB). View file
 
diffusers/loaders/__pycache__/peft.cpython-38.pyc ADDED
Binary file (13 kB). View file
 
diffusers/loaders/__pycache__/single_file_model.cpython-310.pyc ADDED
Binary file (9.83 kB). View file
 
diffusers/loaders/__pycache__/single_file_model.cpython-38.pyc ADDED
Binary file (9.84 kB). View file
 
diffusers/loaders/__pycache__/single_file_utils.cpython-310.pyc ADDED
Binary file (53.4 kB). View file
 
diffusers/loaders/__pycache__/single_file_utils.cpython-38.pyc ADDED
Binary file (54.8 kB). View file
 
diffusers/loaders/__pycache__/unet.cpython-310.pyc ADDED
Binary file (25.7 kB). View file
 
diffusers/loaders/__pycache__/unet.cpython-38.pyc ADDED
Binary file (25.6 kB). View file
 
diffusers/loaders/__pycache__/unet_loader_utils.cpython-310.pyc ADDED
Binary file (4.95 kB). View file
 
diffusers/loaders/__pycache__/unet_loader_utils.cpython-38.pyc ADDED
Binary file (5 kB). View file
 
diffusers/loaders/__pycache__/utils.cpython-310.pyc ADDED
Binary file (2.09 kB). View file
 
diffusers/loaders/__pycache__/utils.cpython-38.pyc ADDED
Binary file (2.1 kB). View file
 
diffusers/loaders/ip_adapter.py ADDED
@@ -0,0 +1,348 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright 2024 The HuggingFace Team. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ from pathlib import Path
16
+ from typing import Dict, List, Optional, Union
17
+
18
+ import torch
19
+ import torch.nn.functional as F
20
+ from huggingface_hub.utils import validate_hf_hub_args
21
+ from safetensors import safe_open
22
+
23
+ from ..models.modeling_utils import _LOW_CPU_MEM_USAGE_DEFAULT, load_state_dict
24
+ from ..utils import (
25
+ USE_PEFT_BACKEND,
26
+ _get_model_file,
27
+ is_accelerate_available,
28
+ is_torch_version,
29
+ is_transformers_available,
30
+ logging,
31
+ )
32
+ from .unet_loader_utils import _maybe_expand_lora_scales
33
+
34
+
35
+ if is_transformers_available():
36
+ from transformers import (
37
+ CLIPImageProcessor,
38
+ CLIPVisionModelWithProjection,
39
+ )
40
+
41
+ from ..models.attention_processor import (
42
+ AttnProcessor,
43
+ AttnProcessor2_0,
44
+ IPAdapterAttnProcessor,
45
+ IPAdapterAttnProcessor2_0,
46
+ )
47
+
48
+ logger = logging.get_logger(__name__)
49
+
50
+
51
+ class IPAdapterMixin:
52
+ """Mixin for handling IP Adapters."""
53
+
54
+ @validate_hf_hub_args
55
+ def load_ip_adapter(
56
+ self,
57
+ pretrained_model_name_or_path_or_dict: Union[str, List[str], Dict[str, torch.Tensor]],
58
+ subfolder: Union[str, List[str]],
59
+ weight_name: Union[str, List[str]],
60
+ image_encoder_folder: Optional[str] = "image_encoder",
61
+ **kwargs,
62
+ ):
63
+ """
64
+ Parameters:
65
+ pretrained_model_name_or_path_or_dict (`str` or `List[str]` or `os.PathLike` or `List[os.PathLike]` or `dict` or `List[dict]`):
66
+ Can be either:
67
+
68
+ - A string, the *model id* (for example `google/ddpm-celebahq-256`) of a pretrained model hosted on
69
+ the Hub.
70
+ - A path to a *directory* (for example `./my_model_directory`) containing the model weights saved
71
+ with [`ModelMixin.save_pretrained`].
72
+ - A [torch state
73
+ dict](https://pytorch.org/tutorials/beginner/saving_loading_models.html#what-is-a-state-dict).
74
+ subfolder (`str` or `List[str]`):
75
+ The subfolder location of a model file within a larger model repository on the Hub or locally. If a
76
+ list is passed, it should have the same length as `weight_name`.
77
+ weight_name (`str` or `List[str]`):
78
+ The name of the weight file to load. If a list is passed, it should have the same length as
79
+ `weight_name`.
80
+ image_encoder_folder (`str`, *optional*, defaults to `image_encoder`):
81
+ The subfolder location of the image encoder within a larger model repository on the Hub or locally.
82
+ Pass `None` to not load the image encoder. If the image encoder is located in a folder inside
83
+ `subfolder`, you only need to pass the name of the folder that contains image encoder weights, e.g.
84
+ `image_encoder_folder="image_encoder"`. If the image encoder is located in a folder other than
85
+ `subfolder`, you should pass the path to the folder that contains image encoder weights, for example,
86
+ `image_encoder_folder="different_subfolder/image_encoder"`.
87
+ cache_dir (`Union[str, os.PathLike]`, *optional*):
88
+ Path to a directory where a downloaded pretrained model configuration is cached if the standard cache
89
+ is not used.
90
+ force_download (`bool`, *optional*, defaults to `False`):
91
+ Whether or not to force the (re-)download of the model weights and configuration files, overriding the
92
+ cached versions if they exist.
93
+
94
+ proxies (`Dict[str, str]`, *optional*):
95
+ A dictionary of proxy servers to use by protocol or endpoint, for example, `{'http': 'foo.bar:3128',
96
+ 'http://hostname': 'foo.bar:4012'}`. The proxies are used on each request.
97
+ local_files_only (`bool`, *optional*, defaults to `False`):
98
+ Whether to only load local model weights and configuration files or not. If set to `True`, the model
99
+ won't be downloaded from the Hub.
100
+ token (`str` or *bool*, *optional*):
101
+ The token to use as HTTP bearer authorization for remote files. If `True`, the token generated from
102
+ `diffusers-cli login` (stored in `~/.huggingface`) is used.
103
+ revision (`str`, *optional*, defaults to `"main"`):
104
+ The specific model version to use. It can be a branch name, a tag name, a commit id, or any identifier
105
+ allowed by Git.
106
+ low_cpu_mem_usage (`bool`, *optional*, defaults to `True` if torch version >= 1.9.0 else `False`):
107
+ Speed up model loading only loading the pretrained weights and not initializing the weights. This also
108
+ tries to not use more than 1x model size in CPU memory (including peak memory) while loading the model.
109
+ Only supported for PyTorch >= 1.9.0. If you are using an older version of PyTorch, setting this
110
+ argument to `True` will raise an error.
111
+ """
112
+
113
+ # handle the list inputs for multiple IP Adapters
114
+ if not isinstance(weight_name, list):
115
+ weight_name = [weight_name]
116
+
117
+ if not isinstance(pretrained_model_name_or_path_or_dict, list):
118
+ pretrained_model_name_or_path_or_dict = [pretrained_model_name_or_path_or_dict]
119
+ if len(pretrained_model_name_or_path_or_dict) == 1:
120
+ pretrained_model_name_or_path_or_dict = pretrained_model_name_or_path_or_dict * len(weight_name)
121
+
122
+ if not isinstance(subfolder, list):
123
+ subfolder = [subfolder]
124
+ if len(subfolder) == 1:
125
+ subfolder = subfolder * len(weight_name)
126
+
127
+ if len(weight_name) != len(pretrained_model_name_or_path_or_dict):
128
+ raise ValueError("`weight_name` and `pretrained_model_name_or_path_or_dict` must have the same length.")
129
+
130
+ if len(weight_name) != len(subfolder):
131
+ raise ValueError("`weight_name` and `subfolder` must have the same length.")
132
+
133
+ # Load the main state dict first.
134
+ cache_dir = kwargs.pop("cache_dir", None)
135
+ force_download = kwargs.pop("force_download", False)
136
+ proxies = kwargs.pop("proxies", None)
137
+ local_files_only = kwargs.pop("local_files_only", None)
138
+ token = kwargs.pop("token", None)
139
+ revision = kwargs.pop("revision", None)
140
+ low_cpu_mem_usage = kwargs.pop("low_cpu_mem_usage", _LOW_CPU_MEM_USAGE_DEFAULT)
141
+
142
+ if low_cpu_mem_usage and not is_accelerate_available():
143
+ low_cpu_mem_usage = False
144
+ logger.warning(
145
+ "Cannot initialize model with low cpu memory usage because `accelerate` was not found in the"
146
+ " environment. Defaulting to `low_cpu_mem_usage=False`. It is strongly recommended to install"
147
+ " `accelerate` for faster and less memory-intense model loading. You can do so with: \n```\npip"
148
+ " install accelerate\n```\n."
149
+ )
150
+
151
+ if low_cpu_mem_usage is True and not is_torch_version(">=", "1.9.0"):
152
+ raise NotImplementedError(
153
+ "Low memory initialization requires torch >= 1.9.0. Please either update your PyTorch version or set"
154
+ " `low_cpu_mem_usage=False`."
155
+ )
156
+
157
+ user_agent = {
158
+ "file_type": "attn_procs_weights",
159
+ "framework": "pytorch",
160
+ }
161
+ state_dicts = []
162
+ for pretrained_model_name_or_path_or_dict, weight_name, subfolder in zip(
163
+ pretrained_model_name_or_path_or_dict, weight_name, subfolder
164
+ ):
165
+ if not isinstance(pretrained_model_name_or_path_or_dict, dict):
166
+ model_file = _get_model_file(
167
+ pretrained_model_name_or_path_or_dict,
168
+ weights_name=weight_name,
169
+ cache_dir=cache_dir,
170
+ force_download=force_download,
171
+ proxies=proxies,
172
+ local_files_only=local_files_only,
173
+ token=token,
174
+ revision=revision,
175
+ subfolder=subfolder,
176
+ user_agent=user_agent,
177
+ )
178
+ if weight_name.endswith(".safetensors"):
179
+ state_dict = {"image_proj": {}, "ip_adapter": {}}
180
+ with safe_open(model_file, framework="pt", device="cpu") as f:
181
+ for key in f.keys():
182
+ if key.startswith("image_proj."):
183
+ state_dict["image_proj"][key.replace("image_proj.", "")] = f.get_tensor(key)
184
+ elif key.startswith("ip_adapter."):
185
+ state_dict["ip_adapter"][key.replace("ip_adapter.", "")] = f.get_tensor(key)
186
+ else:
187
+ state_dict = load_state_dict(model_file)
188
+ else:
189
+ state_dict = pretrained_model_name_or_path_or_dict
190
+
191
+ keys = list(state_dict.keys())
192
+ if keys != ["image_proj", "ip_adapter"]:
193
+ raise ValueError("Required keys are (`image_proj` and `ip_adapter`) missing from the state dict.")
194
+
195
+ state_dicts.append(state_dict)
196
+
197
+ # load CLIP image encoder here if it has not been registered to the pipeline yet
198
+ if hasattr(self, "image_encoder") and getattr(self, "image_encoder", None) is None:
199
+ if image_encoder_folder is not None:
200
+ if not isinstance(pretrained_model_name_or_path_or_dict, dict):
201
+ logger.info(f"loading image_encoder from {pretrained_model_name_or_path_or_dict}")
202
+ if image_encoder_folder.count("/") == 0:
203
+ image_encoder_subfolder = Path(subfolder, image_encoder_folder).as_posix()
204
+ else:
205
+ image_encoder_subfolder = Path(image_encoder_folder).as_posix()
206
+
207
+ image_encoder = CLIPVisionModelWithProjection.from_pretrained(
208
+ pretrained_model_name_or_path_or_dict,
209
+ subfolder=image_encoder_subfolder,
210
+ low_cpu_mem_usage=low_cpu_mem_usage,
211
+ cache_dir=cache_dir,
212
+ local_files_only=local_files_only,
213
+ ).to(self.device, dtype=self.dtype)
214
+ self.register_modules(image_encoder=image_encoder)
215
+ else:
216
+ raise ValueError(
217
+ "`image_encoder` cannot be loaded because `pretrained_model_name_or_path_or_dict` is a state dict."
218
+ )
219
+ else:
220
+ logger.warning(
221
+ "image_encoder is not loaded since `image_encoder_folder=None` passed. You will not be able to use `ip_adapter_image` when calling the pipeline with IP-Adapter."
222
+ "Use `ip_adapter_image_embeds` to pass pre-generated image embedding instead."
223
+ )
224
+
225
+ # create feature extractor if it has not been registered to the pipeline yet
226
+ if hasattr(self, "feature_extractor") and getattr(self, "feature_extractor", None) is None:
227
+ # FaceID IP adapters don't need the image encoder so it's not present, in this case we default to 224
228
+ default_clip_size = 224
229
+ clip_image_size = (
230
+ self.image_encoder.config.image_size if self.image_encoder is not None else default_clip_size
231
+ )
232
+ feature_extractor = CLIPImageProcessor(size=clip_image_size, crop_size=clip_image_size)
233
+ self.register_modules(feature_extractor=feature_extractor)
234
+
235
+ # load ip-adapter into unet
236
+ unet = getattr(self, self.unet_name) if not hasattr(self, "unet") else self.unet
237
+ unet._load_ip_adapter_weights(state_dicts, low_cpu_mem_usage=low_cpu_mem_usage)
238
+
239
+ extra_loras = unet._load_ip_adapter_loras(state_dicts)
240
+ if extra_loras != {}:
241
+ if not USE_PEFT_BACKEND:
242
+ logger.warning("PEFT backend is required to load these weights.")
243
+ else:
244
+ # apply the IP Adapter Face ID LoRA weights
245
+ peft_config = getattr(unet, "peft_config", {})
246
+ for k, lora in extra_loras.items():
247
+ if f"faceid_{k}" not in peft_config:
248
+ self.load_lora_weights(lora, adapter_name=f"faceid_{k}")
249
+ self.set_adapters([f"faceid_{k}"], adapter_weights=[1.0])
250
+
251
+ def set_ip_adapter_scale(self, scale):
252
+ """
253
+ Set IP-Adapter scales per-transformer block. Input `scale` could be a single config or a list of configs for
254
+ granular control over each IP-Adapter behavior. A config can be a float or a dictionary.
255
+
256
+ Example:
257
+
258
+ ```py
259
+ # To use original IP-Adapter
260
+ scale = 1.0
261
+ pipeline.set_ip_adapter_scale(scale)
262
+
263
+ # To use style block only
264
+ scale = {
265
+ "up": {"block_0": [0.0, 1.0, 0.0]},
266
+ }
267
+ pipeline.set_ip_adapter_scale(scale)
268
+
269
+ # To use style+layout blocks
270
+ scale = {
271
+ "down": {"block_2": [0.0, 1.0]},
272
+ "up": {"block_0": [0.0, 1.0, 0.0]},
273
+ }
274
+ pipeline.set_ip_adapter_scale(scale)
275
+
276
+ # To use style and layout from 2 reference images
277
+ scales = [{"down": {"block_2": [0.0, 1.0]}}, {"up": {"block_0": [0.0, 1.0, 0.0]}}]
278
+ pipeline.set_ip_adapter_scale(scales)
279
+ ```
280
+ """
281
+ unet = getattr(self, self.unet_name) if not hasattr(self, "unet") else self.unet
282
+ if not isinstance(scale, list):
283
+ scale = [scale]
284
+ scale_configs = _maybe_expand_lora_scales(unet, scale, default_scale=0.0)
285
+
286
+ for attn_name, attn_processor in unet.attn_processors.items():
287
+ if isinstance(attn_processor, (IPAdapterAttnProcessor, IPAdapterAttnProcessor2_0)):
288
+ if len(scale_configs) != len(attn_processor.scale):
289
+ raise ValueError(
290
+ f"Cannot assign {len(scale_configs)} scale_configs to "
291
+ f"{len(attn_processor.scale)} IP-Adapter."
292
+ )
293
+ elif len(scale_configs) == 1:
294
+ scale_configs = scale_configs * len(attn_processor.scale)
295
+ for i, scale_config in enumerate(scale_configs):
296
+ if isinstance(scale_config, dict):
297
+ for k, s in scale_config.items():
298
+ if attn_name.startswith(k):
299
+ attn_processor.scale[i] = s
300
+ else:
301
+ attn_processor.scale[i] = scale_config
302
+
303
+ def unload_ip_adapter(self):
304
+ """
305
+ Unloads the IP Adapter weights
306
+
307
+ Examples:
308
+
309
+ ```python
310
+ >>> # Assuming `pipeline` is already loaded with the IP Adapter weights.
311
+ >>> pipeline.unload_ip_adapter()
312
+ >>> ...
313
+ ```
314
+ """
315
+ # remove CLIP image encoder
316
+ if hasattr(self, "image_encoder") and getattr(self, "image_encoder", None) is not None:
317
+ self.image_encoder = None
318
+ self.register_to_config(image_encoder=[None, None])
319
+
320
+ # remove feature extractor only when safety_checker is None as safety_checker uses
321
+ # the feature_extractor later
322
+ if not hasattr(self, "safety_checker"):
323
+ if hasattr(self, "feature_extractor") and getattr(self, "feature_extractor", None) is not None:
324
+ self.feature_extractor = None
325
+ self.register_to_config(feature_extractor=[None, None])
326
+
327
+ # remove hidden encoder
328
+ self.unet.encoder_hid_proj = None
329
+ self.unet.config.encoder_hid_dim_type = None
330
+
331
+ # Kolors: restore `encoder_hid_proj` with `text_encoder_hid_proj`
332
+ if hasattr(self.unet, "text_encoder_hid_proj") and self.unet.text_encoder_hid_proj is not None:
333
+ self.unet.encoder_hid_proj = self.unet.text_encoder_hid_proj
334
+ self.unet.text_encoder_hid_proj = None
335
+ self.unet.config.encoder_hid_dim_type = "text_proj"
336
+
337
+ # restore original Unet attention processors layers
338
+ attn_procs = {}
339
+ for name, value in self.unet.attn_processors.items():
340
+ attn_processor_class = (
341
+ AttnProcessor2_0() if hasattr(F, "scaled_dot_product_attention") else AttnProcessor()
342
+ )
343
+ attn_procs[name] = (
344
+ attn_processor_class
345
+ if isinstance(value, (IPAdapterAttnProcessor, IPAdapterAttnProcessor2_0))
346
+ else value.__class__()
347
+ )
348
+ self.unet.set_attn_processor(attn_procs)
diffusers/loaders/lora_base.py ADDED
@@ -0,0 +1,759 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright 2024 The HuggingFace Team. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ import copy
16
+ import inspect
17
+ import os
18
+ from pathlib import Path
19
+ from typing import Callable, Dict, List, Optional, Union
20
+
21
+ import safetensors
22
+ import torch
23
+ import torch.nn as nn
24
+ from huggingface_hub import model_info
25
+ from huggingface_hub.constants import HF_HUB_OFFLINE
26
+
27
+ from ..models.modeling_utils import ModelMixin, load_state_dict
28
+ from ..utils import (
29
+ USE_PEFT_BACKEND,
30
+ _get_model_file,
31
+ delete_adapter_layers,
32
+ deprecate,
33
+ is_accelerate_available,
34
+ is_peft_available,
35
+ is_transformers_available,
36
+ logging,
37
+ recurse_remove_peft_layers,
38
+ set_adapter_layers,
39
+ set_weights_and_activate_adapters,
40
+ )
41
+
42
+
43
+ if is_transformers_available():
44
+ from transformers import PreTrainedModel
45
+
46
+ if is_peft_available():
47
+ from peft.tuners.tuners_utils import BaseTunerLayer
48
+
49
+ if is_accelerate_available():
50
+ from accelerate.hooks import AlignDevicesHook, CpuOffload, remove_hook_from_module
51
+
52
+ logger = logging.get_logger(__name__)
53
+
54
+
55
+ def fuse_text_encoder_lora(text_encoder, lora_scale=1.0, safe_fusing=False, adapter_names=None):
56
+ """
57
+ Fuses LoRAs for the text encoder.
58
+
59
+ Args:
60
+ text_encoder (`torch.nn.Module`):
61
+ The text encoder module to set the adapter layers for. If `None`, it will try to get the `text_encoder`
62
+ attribute.
63
+ lora_scale (`float`, defaults to 1.0):
64
+ Controls how much to influence the outputs with the LoRA parameters.
65
+ safe_fusing (`bool`, defaults to `False`):
66
+ Whether to check fused weights for NaN values before fusing and if values are NaN not fusing them.
67
+ adapter_names (`List[str]` or `str`):
68
+ The names of the adapters to use.
69
+ """
70
+ merge_kwargs = {"safe_merge": safe_fusing}
71
+
72
+ for module in text_encoder.modules():
73
+ if isinstance(module, BaseTunerLayer):
74
+ if lora_scale != 1.0:
75
+ module.scale_layer(lora_scale)
76
+
77
+ # For BC with previous PEFT versions, we need to check the signature
78
+ # of the `merge` method to see if it supports the `adapter_names` argument.
79
+ supported_merge_kwargs = list(inspect.signature(module.merge).parameters)
80
+ if "adapter_names" in supported_merge_kwargs:
81
+ merge_kwargs["adapter_names"] = adapter_names
82
+ elif "adapter_names" not in supported_merge_kwargs and adapter_names is not None:
83
+ raise ValueError(
84
+ "The `adapter_names` argument is not supported with your PEFT version. "
85
+ "Please upgrade to the latest version of PEFT. `pip install -U peft`"
86
+ )
87
+
88
+ module.merge(**merge_kwargs)
89
+
90
+
91
+ def unfuse_text_encoder_lora(text_encoder):
92
+ """
93
+ Unfuses LoRAs for the text encoder.
94
+
95
+ Args:
96
+ text_encoder (`torch.nn.Module`):
97
+ The text encoder module to set the adapter layers for. If `None`, it will try to get the `text_encoder`
98
+ attribute.
99
+ """
100
+ for module in text_encoder.modules():
101
+ if isinstance(module, BaseTunerLayer):
102
+ module.unmerge()
103
+
104
+
105
+ def set_adapters_for_text_encoder(
106
+ adapter_names: Union[List[str], str],
107
+ text_encoder: Optional["PreTrainedModel"] = None, # noqa: F821
108
+ text_encoder_weights: Optional[Union[float, List[float], List[None]]] = None,
109
+ ):
110
+ """
111
+ Sets the adapter layers for the text encoder.
112
+
113
+ Args:
114
+ adapter_names (`List[str]` or `str`):
115
+ The names of the adapters to use.
116
+ text_encoder (`torch.nn.Module`, *optional*):
117
+ The text encoder module to set the adapter layers for. If `None`, it will try to get the `text_encoder`
118
+ attribute.
119
+ text_encoder_weights (`List[float]`, *optional*):
120
+ The weights to use for the text encoder. If `None`, the weights are set to `1.0` for all the adapters.
121
+ """
122
+ if text_encoder is None:
123
+ raise ValueError(
124
+ "The pipeline does not have a default `pipe.text_encoder` class. Please make sure to pass a `text_encoder` instead."
125
+ )
126
+
127
+ def process_weights(adapter_names, weights):
128
+ # Expand weights into a list, one entry per adapter
129
+ # e.g. for 2 adapters: 7 -> [7,7] ; [3, None] -> [3, None]
130
+ if not isinstance(weights, list):
131
+ weights = [weights] * len(adapter_names)
132
+
133
+ if len(adapter_names) != len(weights):
134
+ raise ValueError(
135
+ f"Length of adapter names {len(adapter_names)} is not equal to the length of the weights {len(weights)}"
136
+ )
137
+
138
+ # Set None values to default of 1.0
139
+ # e.g. [7,7] -> [7,7] ; [3, None] -> [3,1]
140
+ weights = [w if w is not None else 1.0 for w in weights]
141
+
142
+ return weights
143
+
144
+ adapter_names = [adapter_names] if isinstance(adapter_names, str) else adapter_names
145
+ text_encoder_weights = process_weights(adapter_names, text_encoder_weights)
146
+ set_weights_and_activate_adapters(text_encoder, adapter_names, text_encoder_weights)
147
+
148
+
149
+ def disable_lora_for_text_encoder(text_encoder: Optional["PreTrainedModel"] = None):
150
+ """
151
+ Disables the LoRA layers for the text encoder.
152
+
153
+ Args:
154
+ text_encoder (`torch.nn.Module`, *optional*):
155
+ The text encoder module to disable the LoRA layers for. If `None`, it will try to get the `text_encoder`
156
+ attribute.
157
+ """
158
+ if text_encoder is None:
159
+ raise ValueError("Text Encoder not found.")
160
+ set_adapter_layers(text_encoder, enabled=False)
161
+
162
+
163
+ def enable_lora_for_text_encoder(text_encoder: Optional["PreTrainedModel"] = None):
164
+ """
165
+ Enables the LoRA layers for the text encoder.
166
+
167
+ Args:
168
+ text_encoder (`torch.nn.Module`, *optional*):
169
+ The text encoder module to enable the LoRA layers for. If `None`, it will try to get the `text_encoder`
170
+ attribute.
171
+ """
172
+ if text_encoder is None:
173
+ raise ValueError("Text Encoder not found.")
174
+ set_adapter_layers(text_encoder, enabled=True)
175
+
176
+
177
+ def _remove_text_encoder_monkey_patch(text_encoder):
178
+ recurse_remove_peft_layers(text_encoder)
179
+ if getattr(text_encoder, "peft_config", None) is not None:
180
+ del text_encoder.peft_config
181
+ text_encoder._hf_peft_config_loaded = None
182
+
183
+
184
+ class LoraBaseMixin:
185
+ """Utility class for handling LoRAs."""
186
+
187
+ _lora_loadable_modules = []
188
+ num_fused_loras = 0
189
+
190
+ def load_lora_weights(self, **kwargs):
191
+ raise NotImplementedError("`load_lora_weights()` is not implemented.")
192
+
193
+ @classmethod
194
+ def save_lora_weights(cls, **kwargs):
195
+ raise NotImplementedError("`save_lora_weights()` not implemented.")
196
+
197
+ @classmethod
198
+ def lora_state_dict(cls, **kwargs):
199
+ raise NotImplementedError("`lora_state_dict()` is not implemented.")
200
+
201
+ @classmethod
202
+ def _optionally_disable_offloading(cls, _pipeline):
203
+ """
204
+ Optionally removes offloading in case the pipeline has been already sequentially offloaded to CPU.
205
+
206
+ Args:
207
+ _pipeline (`DiffusionPipeline`):
208
+ The pipeline to disable offloading for.
209
+
210
+ Returns:
211
+ tuple:
212
+ A tuple indicating if `is_model_cpu_offload` or `is_sequential_cpu_offload` is True.
213
+ """
214
+ is_model_cpu_offload = False
215
+ is_sequential_cpu_offload = False
216
+
217
+ if _pipeline is not None and _pipeline.hf_device_map is None:
218
+ for _, component in _pipeline.components.items():
219
+ if isinstance(component, nn.Module) and hasattr(component, "_hf_hook"):
220
+ if not is_model_cpu_offload:
221
+ is_model_cpu_offload = isinstance(component._hf_hook, CpuOffload)
222
+ if not is_sequential_cpu_offload:
223
+ is_sequential_cpu_offload = (
224
+ isinstance(component._hf_hook, AlignDevicesHook)
225
+ or hasattr(component._hf_hook, "hooks")
226
+ and isinstance(component._hf_hook.hooks[0], AlignDevicesHook)
227
+ )
228
+
229
+ logger.info(
230
+ "Accelerate hooks detected. Since you have called `load_lora_weights()`, the previous hooks will be first removed. Then the LoRA parameters will be loaded and the hooks will be applied again."
231
+ )
232
+ remove_hook_from_module(component, recurse=is_sequential_cpu_offload)
233
+
234
+ return (is_model_cpu_offload, is_sequential_cpu_offload)
235
+
236
+ @classmethod
237
+ def _fetch_state_dict(
238
+ cls,
239
+ pretrained_model_name_or_path_or_dict,
240
+ weight_name,
241
+ use_safetensors,
242
+ local_files_only,
243
+ cache_dir,
244
+ force_download,
245
+ proxies,
246
+ token,
247
+ revision,
248
+ subfolder,
249
+ user_agent,
250
+ allow_pickle,
251
+ ):
252
+ from .lora_pipeline import LORA_WEIGHT_NAME, LORA_WEIGHT_NAME_SAFE
253
+
254
+ model_file = None
255
+ if not isinstance(pretrained_model_name_or_path_or_dict, dict):
256
+ # Let's first try to load .safetensors weights
257
+ if (use_safetensors and weight_name is None) or (
258
+ weight_name is not None and weight_name.endswith(".safetensors")
259
+ ):
260
+ try:
261
+ # Here we're relaxing the loading check to enable more Inference API
262
+ # friendliness where sometimes, it's not at all possible to automatically
263
+ # determine `weight_name`.
264
+ if weight_name is None:
265
+ weight_name = cls._best_guess_weight_name(
266
+ pretrained_model_name_or_path_or_dict,
267
+ file_extension=".safetensors",
268
+ local_files_only=local_files_only,
269
+ )
270
+ model_file = _get_model_file(
271
+ pretrained_model_name_or_path_or_dict,
272
+ weights_name=weight_name or LORA_WEIGHT_NAME_SAFE,
273
+ cache_dir=cache_dir,
274
+ force_download=force_download,
275
+ proxies=proxies,
276
+ local_files_only=local_files_only,
277
+ token=token,
278
+ revision=revision,
279
+ subfolder=subfolder,
280
+ user_agent=user_agent,
281
+ )
282
+ state_dict = safetensors.torch.load_file(model_file, device="cpu")
283
+ except (IOError, safetensors.SafetensorError) as e:
284
+ if not allow_pickle:
285
+ raise e
286
+ # try loading non-safetensors weights
287
+ model_file = None
288
+ pass
289
+
290
+ if model_file is None:
291
+ if weight_name is None:
292
+ weight_name = cls._best_guess_weight_name(
293
+ pretrained_model_name_or_path_or_dict, file_extension=".bin", local_files_only=local_files_only
294
+ )
295
+ model_file = _get_model_file(
296
+ pretrained_model_name_or_path_or_dict,
297
+ weights_name=weight_name or LORA_WEIGHT_NAME,
298
+ cache_dir=cache_dir,
299
+ force_download=force_download,
300
+ proxies=proxies,
301
+ local_files_only=local_files_only,
302
+ token=token,
303
+ revision=revision,
304
+ subfolder=subfolder,
305
+ user_agent=user_agent,
306
+ )
307
+ state_dict = load_state_dict(model_file)
308
+ else:
309
+ state_dict = pretrained_model_name_or_path_or_dict
310
+
311
+ return state_dict
312
+
313
+ @classmethod
314
+ def _best_guess_weight_name(
315
+ cls, pretrained_model_name_or_path_or_dict, file_extension=".safetensors", local_files_only=False
316
+ ):
317
+ from .lora_pipeline import LORA_WEIGHT_NAME, LORA_WEIGHT_NAME_SAFE
318
+
319
+ if local_files_only or HF_HUB_OFFLINE:
320
+ raise ValueError("When using the offline mode, you must specify a `weight_name`.")
321
+
322
+ targeted_files = []
323
+
324
+ if os.path.isfile(pretrained_model_name_or_path_or_dict):
325
+ return
326
+ elif os.path.isdir(pretrained_model_name_or_path_or_dict):
327
+ targeted_files = [
328
+ f for f in os.listdir(pretrained_model_name_or_path_or_dict) if f.endswith(file_extension)
329
+ ]
330
+ else:
331
+ files_in_repo = model_info(pretrained_model_name_or_path_or_dict).siblings
332
+ targeted_files = [f.rfilename for f in files_in_repo if f.rfilename.endswith(file_extension)]
333
+ if len(targeted_files) == 0:
334
+ return
335
+
336
+ # "scheduler" does not correspond to a LoRA checkpoint.
337
+ # "optimizer" does not correspond to a LoRA checkpoint
338
+ # only top-level checkpoints are considered and not the other ones, hence "checkpoint".
339
+ unallowed_substrings = {"scheduler", "optimizer", "checkpoint"}
340
+ targeted_files = list(
341
+ filter(lambda x: all(substring not in x for substring in unallowed_substrings), targeted_files)
342
+ )
343
+
344
+ if any(f.endswith(LORA_WEIGHT_NAME) for f in targeted_files):
345
+ targeted_files = list(filter(lambda x: x.endswith(LORA_WEIGHT_NAME), targeted_files))
346
+ elif any(f.endswith(LORA_WEIGHT_NAME_SAFE) for f in targeted_files):
347
+ targeted_files = list(filter(lambda x: x.endswith(LORA_WEIGHT_NAME_SAFE), targeted_files))
348
+
349
+ if len(targeted_files) > 1:
350
+ raise ValueError(
351
+ f"Provided path contains more than one weights file in the {file_extension} format. Either specify `weight_name` in `load_lora_weights` or make sure there's only one `.safetensors` or `.bin` file in {pretrained_model_name_or_path_or_dict}."
352
+ )
353
+ weight_name = targeted_files[0]
354
+ return weight_name
355
+
356
+ def unload_lora_weights(self):
357
+ """
358
+ Unloads the LoRA parameters.
359
+
360
+ Examples:
361
+
362
+ ```python
363
+ >>> # Assuming `pipeline` is already loaded with the LoRA parameters.
364
+ >>> pipeline.unload_lora_weights()
365
+ >>> ...
366
+ ```
367
+ """
368
+ if not USE_PEFT_BACKEND:
369
+ raise ValueError("PEFT backend is required for this method.")
370
+
371
+ for component in self._lora_loadable_modules:
372
+ model = getattr(self, component, None)
373
+ if model is not None:
374
+ if issubclass(model.__class__, ModelMixin):
375
+ model.unload_lora()
376
+ elif issubclass(model.__class__, PreTrainedModel):
377
+ _remove_text_encoder_monkey_patch(model)
378
+
379
+ def fuse_lora(
380
+ self,
381
+ components: List[str] = [],
382
+ lora_scale: float = 1.0,
383
+ safe_fusing: bool = False,
384
+ adapter_names: Optional[List[str]] = None,
385
+ **kwargs,
386
+ ):
387
+ r"""
388
+ Fuses the LoRA parameters into the original parameters of the corresponding blocks.
389
+
390
+ <Tip warning={true}>
391
+
392
+ This is an experimental API.
393
+
394
+ </Tip>
395
+
396
+ Args:
397
+ components: (`List[str]`): List of LoRA-injectable components to fuse the LoRAs into.
398
+ lora_scale (`float`, defaults to 1.0):
399
+ Controls how much to influence the outputs with the LoRA parameters.
400
+ safe_fusing (`bool`, defaults to `False`):
401
+ Whether to check fused weights for NaN values before fusing and if values are NaN not fusing them.
402
+ adapter_names (`List[str]`, *optional*):
403
+ Adapter names to be used for fusing. If nothing is passed, all active adapters will be fused.
404
+
405
+ Example:
406
+
407
+ ```py
408
+ from diffusers import DiffusionPipeline
409
+ import torch
410
+
411
+ pipeline = DiffusionPipeline.from_pretrained(
412
+ "stabilityai/stable-diffusion-xl-base-1.0", torch_dtype=torch.float16
413
+ ).to("cuda")
414
+ pipeline.load_lora_weights("nerijs/pixel-art-xl", weight_name="pixel-art-xl.safetensors", adapter_name="pixel")
415
+ pipeline.fuse_lora(lora_scale=0.7)
416
+ ```
417
+ """
418
+ if "fuse_unet" in kwargs:
419
+ depr_message = "Passing `fuse_unet` to `fuse_lora()` is deprecated and will be ignored. Please use the `components` argument and provide a list of the components whose LoRAs are to be fused. `fuse_unet` will be removed in a future version."
420
+ deprecate(
421
+ "fuse_unet",
422
+ "1.0.0",
423
+ depr_message,
424
+ )
425
+ if "fuse_transformer" in kwargs:
426
+ depr_message = "Passing `fuse_transformer` to `fuse_lora()` is deprecated and will be ignored. Please use the `components` argument and provide a list of the components whose LoRAs are to be fused. `fuse_transformer` will be removed in a future version."
427
+ deprecate(
428
+ "fuse_transformer",
429
+ "1.0.0",
430
+ depr_message,
431
+ )
432
+ if "fuse_text_encoder" in kwargs:
433
+ depr_message = "Passing `fuse_text_encoder` to `fuse_lora()` is deprecated and will be ignored. Please use the `components` argument and provide a list of the components whose LoRAs are to be fused. `fuse_text_encoder` will be removed in a future version."
434
+ deprecate(
435
+ "fuse_text_encoder",
436
+ "1.0.0",
437
+ depr_message,
438
+ )
439
+
440
+ if len(components) == 0:
441
+ raise ValueError("`components` cannot be an empty list.")
442
+
443
+ for fuse_component in components:
444
+ if fuse_component not in self._lora_loadable_modules:
445
+ raise ValueError(f"{fuse_component} is not found in {self._lora_loadable_modules=}.")
446
+
447
+ model = getattr(self, fuse_component, None)
448
+ if model is not None:
449
+ # check if diffusers model
450
+ if issubclass(model.__class__, ModelMixin):
451
+ model.fuse_lora(lora_scale, safe_fusing=safe_fusing, adapter_names=adapter_names)
452
+ # handle transformers models.
453
+ if issubclass(model.__class__, PreTrainedModel):
454
+ fuse_text_encoder_lora(
455
+ model, lora_scale=lora_scale, safe_fusing=safe_fusing, adapter_names=adapter_names
456
+ )
457
+
458
+ self.num_fused_loras += 1
459
+
460
+ def unfuse_lora(self, components: List[str] = [], **kwargs):
461
+ r"""
462
+ Reverses the effect of
463
+ [`pipe.fuse_lora()`](https://huggingface.co/docs/diffusers/main/en/api/loaders#diffusers.loaders.LoraBaseMixin.fuse_lora).
464
+
465
+ <Tip warning={true}>
466
+
467
+ This is an experimental API.
468
+
469
+ </Tip>
470
+
471
+ Args:
472
+ components (`List[str]`): List of LoRA-injectable components to unfuse LoRA from.
473
+ unfuse_unet (`bool`, defaults to `True`): Whether to unfuse the UNet LoRA parameters.
474
+ unfuse_text_encoder (`bool`, defaults to `True`):
475
+ Whether to unfuse the text encoder LoRA parameters. If the text encoder wasn't monkey-patched with the
476
+ LoRA parameters then it won't have any effect.
477
+ """
478
+ if "unfuse_unet" in kwargs:
479
+ depr_message = "Passing `unfuse_unet` to `unfuse_lora()` is deprecated and will be ignored. Please use the `components` argument. `unfuse_unet` will be removed in a future version."
480
+ deprecate(
481
+ "unfuse_unet",
482
+ "1.0.0",
483
+ depr_message,
484
+ )
485
+ if "unfuse_transformer" in kwargs:
486
+ depr_message = "Passing `unfuse_transformer` to `unfuse_lora()` is deprecated and will be ignored. Please use the `components` argument. `unfuse_transformer` will be removed in a future version."
487
+ deprecate(
488
+ "unfuse_transformer",
489
+ "1.0.0",
490
+ depr_message,
491
+ )
492
+ if "unfuse_text_encoder" in kwargs:
493
+ depr_message = "Passing `unfuse_text_encoder` to `unfuse_lora()` is deprecated and will be ignored. Please use the `components` argument. `unfuse_text_encoder` will be removed in a future version."
494
+ deprecate(
495
+ "unfuse_text_encoder",
496
+ "1.0.0",
497
+ depr_message,
498
+ )
499
+
500
+ if len(components) == 0:
501
+ raise ValueError("`components` cannot be an empty list.")
502
+
503
+ for fuse_component in components:
504
+ if fuse_component not in self._lora_loadable_modules:
505
+ raise ValueError(f"{fuse_component} is not found in {self._lora_loadable_modules=}.")
506
+
507
+ model = getattr(self, fuse_component, None)
508
+ if model is not None:
509
+ if issubclass(model.__class__, (ModelMixin, PreTrainedModel)):
510
+ for module in model.modules():
511
+ if isinstance(module, BaseTunerLayer):
512
+ module.unmerge()
513
+
514
+ self.num_fused_loras -= 1
515
+
516
+ def set_adapters(
517
+ self,
518
+ adapter_names: Union[List[str], str],
519
+ adapter_weights: Optional[Union[float, Dict, List[float], List[Dict]]] = None,
520
+ ):
521
+ adapter_names = [adapter_names] if isinstance(adapter_names, str) else adapter_names
522
+
523
+ adapter_weights = copy.deepcopy(adapter_weights)
524
+
525
+ # Expand weights into a list, one entry per adapter
526
+ if not isinstance(adapter_weights, list):
527
+ adapter_weights = [adapter_weights] * len(adapter_names)
528
+
529
+ if len(adapter_names) != len(adapter_weights):
530
+ raise ValueError(
531
+ f"Length of adapter names {len(adapter_names)} is not equal to the length of the weights {len(adapter_weights)}"
532
+ )
533
+
534
+ list_adapters = self.get_list_adapters() # eg {"unet": ["adapter1", "adapter2"], "text_encoder": ["adapter2"]}
535
+ # eg ["adapter1", "adapter2"]
536
+ all_adapters = {adapter for adapters in list_adapters.values() for adapter in adapters}
537
+ missing_adapters = set(adapter_names) - all_adapters
538
+ if len(missing_adapters) > 0:
539
+ raise ValueError(
540
+ f"Adapter name(s) {missing_adapters} not in the list of present adapters: {all_adapters}."
541
+ )
542
+
543
+ # eg {"adapter1": ["unet"], "adapter2": ["unet", "text_encoder"]}
544
+ invert_list_adapters = {
545
+ adapter: [part for part, adapters in list_adapters.items() if adapter in adapters]
546
+ for adapter in all_adapters
547
+ }
548
+
549
+ # Decompose weights into weights for denoiser and text encoders.
550
+ _component_adapter_weights = {}
551
+ for component in self._lora_loadable_modules:
552
+ model = getattr(self, component)
553
+
554
+ for adapter_name, weights in zip(adapter_names, adapter_weights):
555
+ if isinstance(weights, dict):
556
+ component_adapter_weights = weights.pop(component, None)
557
+
558
+ if component_adapter_weights is not None and not hasattr(self, component):
559
+ logger.warning(
560
+ f"Lora weight dict contains {component} weights but will be ignored because pipeline does not have {component}."
561
+ )
562
+
563
+ if component_adapter_weights is not None and component not in invert_list_adapters[adapter_name]:
564
+ logger.warning(
565
+ (
566
+ f"Lora weight dict for adapter '{adapter_name}' contains {component},"
567
+ f"but this will be ignored because {adapter_name} does not contain weights for {component}."
568
+ f"Valid parts for {adapter_name} are: {invert_list_adapters[adapter_name]}."
569
+ )
570
+ )
571
+
572
+ else:
573
+ component_adapter_weights = weights
574
+
575
+ _component_adapter_weights.setdefault(component, [])
576
+ _component_adapter_weights[component].append(component_adapter_weights)
577
+
578
+ if issubclass(model.__class__, ModelMixin):
579
+ model.set_adapters(adapter_names, _component_adapter_weights[component])
580
+ elif issubclass(model.__class__, PreTrainedModel):
581
+ set_adapters_for_text_encoder(adapter_names, model, _component_adapter_weights[component])
582
+
583
+ def disable_lora(self):
584
+ if not USE_PEFT_BACKEND:
585
+ raise ValueError("PEFT backend is required for this method.")
586
+
587
+ for component in self._lora_loadable_modules:
588
+ model = getattr(self, component, None)
589
+ if model is not None:
590
+ if issubclass(model.__class__, ModelMixin):
591
+ model.disable_lora()
592
+ elif issubclass(model.__class__, PreTrainedModel):
593
+ disable_lora_for_text_encoder(model)
594
+
595
+ def enable_lora(self):
596
+ if not USE_PEFT_BACKEND:
597
+ raise ValueError("PEFT backend is required for this method.")
598
+
599
+ for component in self._lora_loadable_modules:
600
+ model = getattr(self, component, None)
601
+ if model is not None:
602
+ if issubclass(model.__class__, ModelMixin):
603
+ model.enable_lora()
604
+ elif issubclass(model.__class__, PreTrainedModel):
605
+ enable_lora_for_text_encoder(model)
606
+
607
+ def delete_adapters(self, adapter_names: Union[List[str], str]):
608
+ """
609
+ Args:
610
+ Deletes the LoRA layers of `adapter_name` for the unet and text-encoder(s).
611
+ adapter_names (`Union[List[str], str]`):
612
+ The names of the adapter to delete. Can be a single string or a list of strings
613
+ """
614
+ if not USE_PEFT_BACKEND:
615
+ raise ValueError("PEFT backend is required for this method.")
616
+
617
+ if isinstance(adapter_names, str):
618
+ adapter_names = [adapter_names]
619
+
620
+ for component in self._lora_loadable_modules:
621
+ model = getattr(self, component, None)
622
+ if model is not None:
623
+ if issubclass(model.__class__, ModelMixin):
624
+ model.delete_adapters(adapter_names)
625
+ elif issubclass(model.__class__, PreTrainedModel):
626
+ for adapter_name in adapter_names:
627
+ delete_adapter_layers(model, adapter_name)
628
+
629
+ def get_active_adapters(self) -> List[str]:
630
+ """
631
+ Gets the list of the current active adapters.
632
+
633
+ Example:
634
+
635
+ ```python
636
+ from diffusers import DiffusionPipeline
637
+
638
+ pipeline = DiffusionPipeline.from_pretrained(
639
+ "stabilityai/stable-diffusion-xl-base-1.0",
640
+ ).to("cuda")
641
+ pipeline.load_lora_weights("CiroN2022/toy-face", weight_name="toy_face_sdxl.safetensors", adapter_name="toy")
642
+ pipeline.get_active_adapters()
643
+ ```
644
+ """
645
+ if not USE_PEFT_BACKEND:
646
+ raise ValueError(
647
+ "PEFT backend is required for this method. Please install the latest version of PEFT `pip install -U peft`"
648
+ )
649
+
650
+ active_adapters = []
651
+
652
+ for component in self._lora_loadable_modules:
653
+ model = getattr(self, component, None)
654
+ if model is not None and issubclass(model.__class__, ModelMixin):
655
+ for module in model.modules():
656
+ if isinstance(module, BaseTunerLayer):
657
+ active_adapters = module.active_adapters
658
+ break
659
+
660
+ return active_adapters
661
+
662
+ def get_list_adapters(self) -> Dict[str, List[str]]:
663
+ """
664
+ Gets the current list of all available adapters in the pipeline.
665
+ """
666
+ if not USE_PEFT_BACKEND:
667
+ raise ValueError(
668
+ "PEFT backend is required for this method. Please install the latest version of PEFT `pip install -U peft`"
669
+ )
670
+
671
+ set_adapters = {}
672
+
673
+ for component in self._lora_loadable_modules:
674
+ model = getattr(self, component, None)
675
+ if (
676
+ model is not None
677
+ and issubclass(model.__class__, (ModelMixin, PreTrainedModel))
678
+ and hasattr(model, "peft_config")
679
+ ):
680
+ set_adapters[component] = list(model.peft_config.keys())
681
+
682
+ return set_adapters
683
+
684
+ def set_lora_device(self, adapter_names: List[str], device: Union[torch.device, str, int]) -> None:
685
+ """
686
+ Moves the LoRAs listed in `adapter_names` to a target device. Useful for offloading the LoRA to the CPU in case
687
+ you want to load multiple adapters and free some GPU memory.
688
+
689
+ Args:
690
+ adapter_names (`List[str]`):
691
+ List of adapters to send device to.
692
+ device (`Union[torch.device, str, int]`):
693
+ Device to send the adapters to. Can be either a torch device, a str or an integer.
694
+ """
695
+ if not USE_PEFT_BACKEND:
696
+ raise ValueError("PEFT backend is required for this method.")
697
+
698
+ for component in self._lora_loadable_modules:
699
+ model = getattr(self, component, None)
700
+ if model is not None:
701
+ for module in model.modules():
702
+ if isinstance(module, BaseTunerLayer):
703
+ for adapter_name in adapter_names:
704
+ module.lora_A[adapter_name].to(device)
705
+ module.lora_B[adapter_name].to(device)
706
+ # this is a param, not a module, so device placement is not in-place -> re-assign
707
+ if hasattr(module, "lora_magnitude_vector") and module.lora_magnitude_vector is not None:
708
+ if adapter_name in module.lora_magnitude_vector:
709
+ module.lora_magnitude_vector[adapter_name] = module.lora_magnitude_vector[
710
+ adapter_name
711
+ ].to(device)
712
+
713
+ @staticmethod
714
+ def pack_weights(layers, prefix):
715
+ layers_weights = layers.state_dict() if isinstance(layers, torch.nn.Module) else layers
716
+ layers_state_dict = {f"{prefix}.{module_name}": param for module_name, param in layers_weights.items()}
717
+ return layers_state_dict
718
+
719
+ @staticmethod
720
+ def write_lora_layers(
721
+ state_dict: Dict[str, torch.Tensor],
722
+ save_directory: str,
723
+ is_main_process: bool,
724
+ weight_name: str,
725
+ save_function: Callable,
726
+ safe_serialization: bool,
727
+ ):
728
+ from .lora_pipeline import LORA_WEIGHT_NAME, LORA_WEIGHT_NAME_SAFE
729
+
730
+ if os.path.isfile(save_directory):
731
+ logger.error(f"Provided path ({save_directory}) should be a directory, not a file")
732
+ return
733
+
734
+ if save_function is None:
735
+ if safe_serialization:
736
+
737
+ def save_function(weights, filename):
738
+ return safetensors.torch.save_file(weights, filename, metadata={"format": "pt"})
739
+
740
+ else:
741
+ save_function = torch.save
742
+
743
+ os.makedirs(save_directory, exist_ok=True)
744
+
745
+ if weight_name is None:
746
+ if safe_serialization:
747
+ weight_name = LORA_WEIGHT_NAME_SAFE
748
+ else:
749
+ weight_name = LORA_WEIGHT_NAME
750
+
751
+ save_path = Path(save_directory, weight_name).as_posix()
752
+ save_function(state_dict, save_path)
753
+ logger.info(f"Model weights saved in {save_path}")
754
+
755
+ @property
756
+ def lora_scale(self) -> float:
757
+ # property function that returns the lora scale which can be set at run time by the pipeline.
758
+ # if _lora_scale has not been set, return 1
759
+ return self._lora_scale if hasattr(self, "_lora_scale") else 1.0
diffusers/loaders/lora_conversion_utils.py ADDED
@@ -0,0 +1,660 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright 2024 The HuggingFace Team. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ import re
16
+
17
+ import torch
18
+
19
+ from ..utils import is_peft_version, logging
20
+
21
+
22
+ logger = logging.get_logger(__name__)
23
+
24
+
25
+ def _maybe_map_sgm_blocks_to_diffusers(state_dict, unet_config, delimiter="_", block_slice_pos=5):
26
+ # 1. get all state_dict_keys
27
+ all_keys = list(state_dict.keys())
28
+ sgm_patterns = ["input_blocks", "middle_block", "output_blocks"]
29
+
30
+ # 2. check if needs remapping, if not return original dict
31
+ is_in_sgm_format = False
32
+ for key in all_keys:
33
+ if any(p in key for p in sgm_patterns):
34
+ is_in_sgm_format = True
35
+ break
36
+
37
+ if not is_in_sgm_format:
38
+ return state_dict
39
+
40
+ # 3. Else remap from SGM patterns
41
+ new_state_dict = {}
42
+ inner_block_map = ["resnets", "attentions", "upsamplers"]
43
+
44
+ # Retrieves # of down, mid and up blocks
45
+ input_block_ids, middle_block_ids, output_block_ids = set(), set(), set()
46
+
47
+ for layer in all_keys:
48
+ if "text" in layer:
49
+ new_state_dict[layer] = state_dict.pop(layer)
50
+ else:
51
+ layer_id = int(layer.split(delimiter)[:block_slice_pos][-1])
52
+ if sgm_patterns[0] in layer:
53
+ input_block_ids.add(layer_id)
54
+ elif sgm_patterns[1] in layer:
55
+ middle_block_ids.add(layer_id)
56
+ elif sgm_patterns[2] in layer:
57
+ output_block_ids.add(layer_id)
58
+ else:
59
+ raise ValueError(f"Checkpoint not supported because layer {layer} not supported.")
60
+
61
+ input_blocks = {
62
+ layer_id: [key for key in state_dict if f"input_blocks{delimiter}{layer_id}" in key]
63
+ for layer_id in input_block_ids
64
+ }
65
+ middle_blocks = {
66
+ layer_id: [key for key in state_dict if f"middle_block{delimiter}{layer_id}" in key]
67
+ for layer_id in middle_block_ids
68
+ }
69
+ output_blocks = {
70
+ layer_id: [key for key in state_dict if f"output_blocks{delimiter}{layer_id}" in key]
71
+ for layer_id in output_block_ids
72
+ }
73
+
74
+ # Rename keys accordingly
75
+ for i in input_block_ids:
76
+ block_id = (i - 1) // (unet_config.layers_per_block + 1)
77
+ layer_in_block_id = (i - 1) % (unet_config.layers_per_block + 1)
78
+
79
+ for key in input_blocks[i]:
80
+ inner_block_id = int(key.split(delimiter)[block_slice_pos])
81
+ inner_block_key = inner_block_map[inner_block_id] if "op" not in key else "downsamplers"
82
+ inner_layers_in_block = str(layer_in_block_id) if "op" not in key else "0"
83
+ new_key = delimiter.join(
84
+ key.split(delimiter)[: block_slice_pos - 1]
85
+ + [str(block_id), inner_block_key, inner_layers_in_block]
86
+ + key.split(delimiter)[block_slice_pos + 1 :]
87
+ )
88
+ new_state_dict[new_key] = state_dict.pop(key)
89
+
90
+ for i in middle_block_ids:
91
+ key_part = None
92
+ if i == 0:
93
+ key_part = [inner_block_map[0], "0"]
94
+ elif i == 1:
95
+ key_part = [inner_block_map[1], "0"]
96
+ elif i == 2:
97
+ key_part = [inner_block_map[0], "1"]
98
+ else:
99
+ raise ValueError(f"Invalid middle block id {i}.")
100
+
101
+ for key in middle_blocks[i]:
102
+ new_key = delimiter.join(
103
+ key.split(delimiter)[: block_slice_pos - 1] + key_part + key.split(delimiter)[block_slice_pos:]
104
+ )
105
+ new_state_dict[new_key] = state_dict.pop(key)
106
+
107
+ for i in output_block_ids:
108
+ block_id = i // (unet_config.layers_per_block + 1)
109
+ layer_in_block_id = i % (unet_config.layers_per_block + 1)
110
+
111
+ for key in output_blocks[i]:
112
+ inner_block_id = int(key.split(delimiter)[block_slice_pos])
113
+ inner_block_key = inner_block_map[inner_block_id]
114
+ inner_layers_in_block = str(layer_in_block_id) if inner_block_id < 2 else "0"
115
+ new_key = delimiter.join(
116
+ key.split(delimiter)[: block_slice_pos - 1]
117
+ + [str(block_id), inner_block_key, inner_layers_in_block]
118
+ + key.split(delimiter)[block_slice_pos + 1 :]
119
+ )
120
+ new_state_dict[new_key] = state_dict.pop(key)
121
+
122
+ if len(state_dict) > 0:
123
+ raise ValueError("At this point all state dict entries have to be converted.")
124
+
125
+ return new_state_dict
126
+
127
+
128
+ def _convert_non_diffusers_lora_to_diffusers(state_dict, unet_name="unet", text_encoder_name="text_encoder"):
129
+ """
130
+ Converts a non-Diffusers LoRA state dict to a Diffusers compatible state dict.
131
+
132
+ Args:
133
+ state_dict (`dict`): The state dict to convert.
134
+ unet_name (`str`, optional): The name of the U-Net module in the Diffusers model. Defaults to "unet".
135
+ text_encoder_name (`str`, optional): The name of the text encoder module in the Diffusers model. Defaults to
136
+ "text_encoder".
137
+
138
+ Returns:
139
+ `tuple`: A tuple containing the converted state dict and a dictionary of alphas.
140
+ """
141
+ unet_state_dict = {}
142
+ te_state_dict = {}
143
+ te2_state_dict = {}
144
+ network_alphas = {}
145
+
146
+ # Check for DoRA-enabled LoRAs.
147
+ dora_present_in_unet = any("dora_scale" in k and "lora_unet_" in k for k in state_dict)
148
+ dora_present_in_te = any("dora_scale" in k and ("lora_te_" in k or "lora_te1_" in k) for k in state_dict)
149
+ dora_present_in_te2 = any("dora_scale" in k and "lora_te2_" in k for k in state_dict)
150
+ if dora_present_in_unet or dora_present_in_te or dora_present_in_te2:
151
+ if is_peft_version("<", "0.9.0"):
152
+ raise ValueError(
153
+ "You need `peft` 0.9.0 at least to use DoRA-enabled LoRAs. Please upgrade your installation of `peft`."
154
+ )
155
+
156
+ # Iterate over all LoRA weights.
157
+ all_lora_keys = list(state_dict.keys())
158
+ for key in all_lora_keys:
159
+ if not key.endswith("lora_down.weight"):
160
+ continue
161
+
162
+ # Extract LoRA name.
163
+ lora_name = key.split(".")[0]
164
+
165
+ # Find corresponding up weight and alpha.
166
+ lora_name_up = lora_name + ".lora_up.weight"
167
+ lora_name_alpha = lora_name + ".alpha"
168
+
169
+ # Handle U-Net LoRAs.
170
+ if lora_name.startswith("lora_unet_"):
171
+ diffusers_name = _convert_unet_lora_key(key)
172
+
173
+ # Store down and up weights.
174
+ unet_state_dict[diffusers_name] = state_dict.pop(key)
175
+ unet_state_dict[diffusers_name.replace(".down.", ".up.")] = state_dict.pop(lora_name_up)
176
+
177
+ # Store DoRA scale if present.
178
+ if dora_present_in_unet:
179
+ dora_scale_key_to_replace = "_lora.down." if "_lora.down." in diffusers_name else ".lora.down."
180
+ unet_state_dict[
181
+ diffusers_name.replace(dora_scale_key_to_replace, ".lora_magnitude_vector.")
182
+ ] = state_dict.pop(key.replace("lora_down.weight", "dora_scale"))
183
+
184
+ # Handle text encoder LoRAs.
185
+ elif lora_name.startswith(("lora_te_", "lora_te1_", "lora_te2_")):
186
+ diffusers_name = _convert_text_encoder_lora_key(key, lora_name)
187
+
188
+ # Store down and up weights for te or te2.
189
+ if lora_name.startswith(("lora_te_", "lora_te1_")):
190
+ te_state_dict[diffusers_name] = state_dict.pop(key)
191
+ te_state_dict[diffusers_name.replace(".down.", ".up.")] = state_dict.pop(lora_name_up)
192
+ else:
193
+ te2_state_dict[diffusers_name] = state_dict.pop(key)
194
+ te2_state_dict[diffusers_name.replace(".down.", ".up.")] = state_dict.pop(lora_name_up)
195
+
196
+ # Store DoRA scale if present.
197
+ if dora_present_in_te or dora_present_in_te2:
198
+ dora_scale_key_to_replace_te = (
199
+ "_lora.down." if "_lora.down." in diffusers_name else ".lora_linear_layer."
200
+ )
201
+ if lora_name.startswith(("lora_te_", "lora_te1_")):
202
+ te_state_dict[
203
+ diffusers_name.replace(dora_scale_key_to_replace_te, ".lora_magnitude_vector.")
204
+ ] = state_dict.pop(key.replace("lora_down.weight", "dora_scale"))
205
+ elif lora_name.startswith("lora_te2_"):
206
+ te2_state_dict[
207
+ diffusers_name.replace(dora_scale_key_to_replace_te, ".lora_magnitude_vector.")
208
+ ] = state_dict.pop(key.replace("lora_down.weight", "dora_scale"))
209
+
210
+ # Store alpha if present.
211
+ if lora_name_alpha in state_dict:
212
+ alpha = state_dict.pop(lora_name_alpha).item()
213
+ network_alphas.update(_get_alpha_name(lora_name_alpha, diffusers_name, alpha))
214
+
215
+ # Check if any keys remain.
216
+ if len(state_dict) > 0:
217
+ raise ValueError(f"The following keys have not been correctly renamed: \n\n {', '.join(state_dict.keys())}")
218
+
219
+ logger.info("Non-diffusers checkpoint detected.")
220
+
221
+ # Construct final state dict.
222
+ unet_state_dict = {f"{unet_name}.{module_name}": params for module_name, params in unet_state_dict.items()}
223
+ te_state_dict = {f"{text_encoder_name}.{module_name}": params for module_name, params in te_state_dict.items()}
224
+ te2_state_dict = (
225
+ {f"text_encoder_2.{module_name}": params for module_name, params in te2_state_dict.items()}
226
+ if len(te2_state_dict) > 0
227
+ else None
228
+ )
229
+ if te2_state_dict is not None:
230
+ te_state_dict.update(te2_state_dict)
231
+
232
+ new_state_dict = {**unet_state_dict, **te_state_dict}
233
+ return new_state_dict, network_alphas
234
+
235
+
236
+ def _convert_unet_lora_key(key):
237
+ """
238
+ Converts a U-Net LoRA key to a Diffusers compatible key.
239
+ """
240
+ diffusers_name = key.replace("lora_unet_", "").replace("_", ".")
241
+
242
+ # Replace common U-Net naming patterns.
243
+ diffusers_name = diffusers_name.replace("input.blocks", "down_blocks")
244
+ diffusers_name = diffusers_name.replace("down.blocks", "down_blocks")
245
+ diffusers_name = diffusers_name.replace("middle.block", "mid_block")
246
+ diffusers_name = diffusers_name.replace("mid.block", "mid_block")
247
+ diffusers_name = diffusers_name.replace("output.blocks", "up_blocks")
248
+ diffusers_name = diffusers_name.replace("up.blocks", "up_blocks")
249
+ diffusers_name = diffusers_name.replace("transformer.blocks", "transformer_blocks")
250
+ diffusers_name = diffusers_name.replace("to.q.lora", "to_q_lora")
251
+ diffusers_name = diffusers_name.replace("to.k.lora", "to_k_lora")
252
+ diffusers_name = diffusers_name.replace("to.v.lora", "to_v_lora")
253
+ diffusers_name = diffusers_name.replace("to.out.0.lora", "to_out_lora")
254
+ diffusers_name = diffusers_name.replace("proj.in", "proj_in")
255
+ diffusers_name = diffusers_name.replace("proj.out", "proj_out")
256
+ diffusers_name = diffusers_name.replace("emb.layers", "time_emb_proj")
257
+
258
+ # SDXL specific conversions.
259
+ if "emb" in diffusers_name and "time.emb.proj" not in diffusers_name:
260
+ pattern = r"\.\d+(?=\D*$)"
261
+ diffusers_name = re.sub(pattern, "", diffusers_name, count=1)
262
+ if ".in." in diffusers_name:
263
+ diffusers_name = diffusers_name.replace("in.layers.2", "conv1")
264
+ if ".out." in diffusers_name:
265
+ diffusers_name = diffusers_name.replace("out.layers.3", "conv2")
266
+ if "downsamplers" in diffusers_name or "upsamplers" in diffusers_name:
267
+ diffusers_name = diffusers_name.replace("op", "conv")
268
+ if "skip" in diffusers_name:
269
+ diffusers_name = diffusers_name.replace("skip.connection", "conv_shortcut")
270
+
271
+ # LyCORIS specific conversions.
272
+ if "time.emb.proj" in diffusers_name:
273
+ diffusers_name = diffusers_name.replace("time.emb.proj", "time_emb_proj")
274
+ if "conv.shortcut" in diffusers_name:
275
+ diffusers_name = diffusers_name.replace("conv.shortcut", "conv_shortcut")
276
+
277
+ # General conversions.
278
+ if "transformer_blocks" in diffusers_name:
279
+ if "attn1" in diffusers_name or "attn2" in diffusers_name:
280
+ diffusers_name = diffusers_name.replace("attn1", "attn1.processor")
281
+ diffusers_name = diffusers_name.replace("attn2", "attn2.processor")
282
+ elif "ff" in diffusers_name:
283
+ pass
284
+ elif any(key in diffusers_name for key in ("proj_in", "proj_out")):
285
+ pass
286
+ else:
287
+ pass
288
+
289
+ return diffusers_name
290
+
291
+
292
+ def _convert_text_encoder_lora_key(key, lora_name):
293
+ """
294
+ Converts a text encoder LoRA key to a Diffusers compatible key.
295
+ """
296
+ if lora_name.startswith(("lora_te_", "lora_te1_")):
297
+ key_to_replace = "lora_te_" if lora_name.startswith("lora_te_") else "lora_te1_"
298
+ else:
299
+ key_to_replace = "lora_te2_"
300
+
301
+ diffusers_name = key.replace(key_to_replace, "").replace("_", ".")
302
+ diffusers_name = diffusers_name.replace("text.model", "text_model")
303
+ diffusers_name = diffusers_name.replace("self.attn", "self_attn")
304
+ diffusers_name = diffusers_name.replace("q.proj.lora", "to_q_lora")
305
+ diffusers_name = diffusers_name.replace("k.proj.lora", "to_k_lora")
306
+ diffusers_name = diffusers_name.replace("v.proj.lora", "to_v_lora")
307
+ diffusers_name = diffusers_name.replace("out.proj.lora", "to_out_lora")
308
+ diffusers_name = diffusers_name.replace("text.projection", "text_projection")
309
+
310
+ if "self_attn" in diffusers_name or "text_projection" in diffusers_name:
311
+ pass
312
+ elif "mlp" in diffusers_name:
313
+ # Be aware that this is the new diffusers convention and the rest of the code might
314
+ # not utilize it yet.
315
+ diffusers_name = diffusers_name.replace(".lora.", ".lora_linear_layer.")
316
+ return diffusers_name
317
+
318
+
319
+ def _get_alpha_name(lora_name_alpha, diffusers_name, alpha):
320
+ """
321
+ Gets the correct alpha name for the Diffusers model.
322
+ """
323
+ if lora_name_alpha.startswith("lora_unet_"):
324
+ prefix = "unet."
325
+ elif lora_name_alpha.startswith(("lora_te_", "lora_te1_")):
326
+ prefix = "text_encoder."
327
+ else:
328
+ prefix = "text_encoder_2."
329
+ new_name = prefix + diffusers_name.split(".lora.")[0] + ".alpha"
330
+ return {new_name: alpha}
331
+
332
+
333
+ # The utilities under `_convert_kohya_flux_lora_to_diffusers()`
334
+ # are taken from https://github.com/kohya-ss/sd-scripts/blob/a61cf73a5cb5209c3f4d1a3688dd276a4dfd1ecb/networks/convert_flux_lora.py
335
+ # All credits go to `kohya-ss`.
336
+ def _convert_kohya_flux_lora_to_diffusers(state_dict):
337
+ def _convert_to_ai_toolkit(sds_sd, ait_sd, sds_key, ait_key):
338
+ if sds_key + ".lora_down.weight" not in sds_sd:
339
+ return
340
+ down_weight = sds_sd.pop(sds_key + ".lora_down.weight")
341
+
342
+ # scale weight by alpha and dim
343
+ rank = down_weight.shape[0]
344
+ alpha = sds_sd.pop(sds_key + ".alpha").item() # alpha is scalar
345
+ scale = alpha / rank # LoRA is scaled by 'alpha / rank' in forward pass, so we need to scale it back here
346
+
347
+ # calculate scale_down and scale_up to keep the same value. if scale is 4, scale_down is 2 and scale_up is 2
348
+ scale_down = scale
349
+ scale_up = 1.0
350
+ while scale_down * 2 < scale_up:
351
+ scale_down *= 2
352
+ scale_up /= 2
353
+
354
+ ait_sd[ait_key + ".lora_A.weight"] = down_weight * scale_down
355
+ ait_sd[ait_key + ".lora_B.weight"] = sds_sd.pop(sds_key + ".lora_up.weight") * scale_up
356
+
357
+ def _convert_to_ai_toolkit_cat(sds_sd, ait_sd, sds_key, ait_keys, dims=None):
358
+ if sds_key + ".lora_down.weight" not in sds_sd:
359
+ return
360
+ down_weight = sds_sd.pop(sds_key + ".lora_down.weight")
361
+ up_weight = sds_sd.pop(sds_key + ".lora_up.weight")
362
+ sd_lora_rank = down_weight.shape[0]
363
+
364
+ # scale weight by alpha and dim
365
+ alpha = sds_sd.pop(sds_key + ".alpha")
366
+ scale = alpha / sd_lora_rank
367
+
368
+ # calculate scale_down and scale_up
369
+ scale_down = scale
370
+ scale_up = 1.0
371
+ while scale_down * 2 < scale_up:
372
+ scale_down *= 2
373
+ scale_up /= 2
374
+
375
+ down_weight = down_weight * scale_down
376
+ up_weight = up_weight * scale_up
377
+
378
+ # calculate dims if not provided
379
+ num_splits = len(ait_keys)
380
+ if dims is None:
381
+ dims = [up_weight.shape[0] // num_splits] * num_splits
382
+ else:
383
+ assert sum(dims) == up_weight.shape[0]
384
+
385
+ # check upweight is sparse or not
386
+ is_sparse = False
387
+ if sd_lora_rank % num_splits == 0:
388
+ ait_rank = sd_lora_rank // num_splits
389
+ is_sparse = True
390
+ i = 0
391
+ for j in range(len(dims)):
392
+ for k in range(len(dims)):
393
+ if j == k:
394
+ continue
395
+ is_sparse = is_sparse and torch.all(
396
+ up_weight[i : i + dims[j], k * ait_rank : (k + 1) * ait_rank] == 0
397
+ )
398
+ i += dims[j]
399
+ if is_sparse:
400
+ logger.info(f"weight is sparse: {sds_key}")
401
+
402
+ # make ai-toolkit weight
403
+ ait_down_keys = [k + ".lora_A.weight" for k in ait_keys]
404
+ ait_up_keys = [k + ".lora_B.weight" for k in ait_keys]
405
+ if not is_sparse:
406
+ # down_weight is copied to each split
407
+ ait_sd.update({k: down_weight for k in ait_down_keys})
408
+
409
+ # up_weight is split to each split
410
+ ait_sd.update({k: v for k, v in zip(ait_up_keys, torch.split(up_weight, dims, dim=0))}) # noqa: C416
411
+ else:
412
+ # down_weight is chunked to each split
413
+ ait_sd.update({k: v for k, v in zip(ait_down_keys, torch.chunk(down_weight, num_splits, dim=0))}) # noqa: C416
414
+
415
+ # up_weight is sparse: only non-zero values are copied to each split
416
+ i = 0
417
+ for j in range(len(dims)):
418
+ ait_sd[ait_up_keys[j]] = up_weight[i : i + dims[j], j * ait_rank : (j + 1) * ait_rank].contiguous()
419
+ i += dims[j]
420
+
421
+ def _convert_sd_scripts_to_ai_toolkit(sds_sd):
422
+ ait_sd = {}
423
+ for i in range(19):
424
+ _convert_to_ai_toolkit(
425
+ sds_sd,
426
+ ait_sd,
427
+ f"lora_unet_double_blocks_{i}_img_attn_proj",
428
+ f"transformer.transformer_blocks.{i}.attn.to_out.0",
429
+ )
430
+ _convert_to_ai_toolkit_cat(
431
+ sds_sd,
432
+ ait_sd,
433
+ f"lora_unet_double_blocks_{i}_img_attn_qkv",
434
+ [
435
+ f"transformer.transformer_blocks.{i}.attn.to_q",
436
+ f"transformer.transformer_blocks.{i}.attn.to_k",
437
+ f"transformer.transformer_blocks.{i}.attn.to_v",
438
+ ],
439
+ )
440
+ _convert_to_ai_toolkit(
441
+ sds_sd,
442
+ ait_sd,
443
+ f"lora_unet_double_blocks_{i}_img_mlp_0",
444
+ f"transformer.transformer_blocks.{i}.ff.net.0.proj",
445
+ )
446
+ _convert_to_ai_toolkit(
447
+ sds_sd,
448
+ ait_sd,
449
+ f"lora_unet_double_blocks_{i}_img_mlp_2",
450
+ f"transformer.transformer_blocks.{i}.ff.net.2",
451
+ )
452
+ _convert_to_ai_toolkit(
453
+ sds_sd,
454
+ ait_sd,
455
+ f"lora_unet_double_blocks_{i}_img_mod_lin",
456
+ f"transformer.transformer_blocks.{i}.norm1.linear",
457
+ )
458
+ _convert_to_ai_toolkit(
459
+ sds_sd,
460
+ ait_sd,
461
+ f"lora_unet_double_blocks_{i}_txt_attn_proj",
462
+ f"transformer.transformer_blocks.{i}.attn.to_add_out",
463
+ )
464
+ _convert_to_ai_toolkit_cat(
465
+ sds_sd,
466
+ ait_sd,
467
+ f"lora_unet_double_blocks_{i}_txt_attn_qkv",
468
+ [
469
+ f"transformer.transformer_blocks.{i}.attn.add_q_proj",
470
+ f"transformer.transformer_blocks.{i}.attn.add_k_proj",
471
+ f"transformer.transformer_blocks.{i}.attn.add_v_proj",
472
+ ],
473
+ )
474
+ _convert_to_ai_toolkit(
475
+ sds_sd,
476
+ ait_sd,
477
+ f"lora_unet_double_blocks_{i}_txt_mlp_0",
478
+ f"transformer.transformer_blocks.{i}.ff_context.net.0.proj",
479
+ )
480
+ _convert_to_ai_toolkit(
481
+ sds_sd,
482
+ ait_sd,
483
+ f"lora_unet_double_blocks_{i}_txt_mlp_2",
484
+ f"transformer.transformer_blocks.{i}.ff_context.net.2",
485
+ )
486
+ _convert_to_ai_toolkit(
487
+ sds_sd,
488
+ ait_sd,
489
+ f"lora_unet_double_blocks_{i}_txt_mod_lin",
490
+ f"transformer.transformer_blocks.{i}.norm1_context.linear",
491
+ )
492
+
493
+ for i in range(38):
494
+ _convert_to_ai_toolkit_cat(
495
+ sds_sd,
496
+ ait_sd,
497
+ f"lora_unet_single_blocks_{i}_linear1",
498
+ [
499
+ f"transformer.single_transformer_blocks.{i}.attn.to_q",
500
+ f"transformer.single_transformer_blocks.{i}.attn.to_k",
501
+ f"transformer.single_transformer_blocks.{i}.attn.to_v",
502
+ f"transformer.single_transformer_blocks.{i}.proj_mlp",
503
+ ],
504
+ dims=[3072, 3072, 3072, 12288],
505
+ )
506
+ _convert_to_ai_toolkit(
507
+ sds_sd,
508
+ ait_sd,
509
+ f"lora_unet_single_blocks_{i}_linear2",
510
+ f"transformer.single_transformer_blocks.{i}.proj_out",
511
+ )
512
+ _convert_to_ai_toolkit(
513
+ sds_sd,
514
+ ait_sd,
515
+ f"lora_unet_single_blocks_{i}_modulation_lin",
516
+ f"transformer.single_transformer_blocks.{i}.norm.linear",
517
+ )
518
+
519
+ remaining_keys = list(sds_sd.keys())
520
+ te_state_dict = {}
521
+ if remaining_keys:
522
+ if not all(k.startswith("lora_te1") for k in remaining_keys):
523
+ raise ValueError(f"Incompatible keys detected: \n\n {', '.join(remaining_keys)}")
524
+ for key in remaining_keys:
525
+ if not key.endswith("lora_down.weight"):
526
+ continue
527
+
528
+ lora_name = key.split(".")[0]
529
+ lora_name_up = f"{lora_name}.lora_up.weight"
530
+ lora_name_alpha = f"{lora_name}.alpha"
531
+ diffusers_name = _convert_text_encoder_lora_key(key, lora_name)
532
+
533
+ if lora_name.startswith(("lora_te_", "lora_te1_")):
534
+ down_weight = sds_sd.pop(key)
535
+ sd_lora_rank = down_weight.shape[0]
536
+ te_state_dict[diffusers_name] = down_weight
537
+ te_state_dict[diffusers_name.replace(".down.", ".up.")] = sds_sd.pop(lora_name_up)
538
+
539
+ if lora_name_alpha in sds_sd:
540
+ alpha = sds_sd.pop(lora_name_alpha).item()
541
+ scale = alpha / sd_lora_rank
542
+
543
+ scale_down = scale
544
+ scale_up = 1.0
545
+ while scale_down * 2 < scale_up:
546
+ scale_down *= 2
547
+ scale_up /= 2
548
+
549
+ te_state_dict[diffusers_name] *= scale_down
550
+ te_state_dict[diffusers_name.replace(".down.", ".up.")] *= scale_up
551
+
552
+ if len(sds_sd) > 0:
553
+ logger.warning(f"Unsupported keys for ai-toolkit: {sds_sd.keys()}")
554
+
555
+ if te_state_dict:
556
+ te_state_dict = {f"text_encoder.{module_name}": params for module_name, params in te_state_dict.items()}
557
+
558
+ new_state_dict = {**ait_sd, **te_state_dict}
559
+ return new_state_dict
560
+
561
+ return _convert_sd_scripts_to_ai_toolkit(state_dict)
562
+
563
+
564
+ # Adapted from https://gist.github.com/Leommm-byte/6b331a1e9bd53271210b26543a7065d6
565
+ # Some utilities were reused from
566
+ # https://github.com/kohya-ss/sd-scripts/blob/a61cf73a5cb5209c3f4d1a3688dd276a4dfd1ecb/networks/convert_flux_lora.py
567
+ def _convert_xlabs_flux_lora_to_diffusers(old_state_dict):
568
+ new_state_dict = {}
569
+ orig_keys = list(old_state_dict.keys())
570
+
571
+ def handle_qkv(sds_sd, ait_sd, sds_key, ait_keys, dims=None):
572
+ down_weight = sds_sd.pop(sds_key)
573
+ up_weight = sds_sd.pop(sds_key.replace(".down.weight", ".up.weight"))
574
+
575
+ # calculate dims if not provided
576
+ num_splits = len(ait_keys)
577
+ if dims is None:
578
+ dims = [up_weight.shape[0] // num_splits] * num_splits
579
+ else:
580
+ assert sum(dims) == up_weight.shape[0]
581
+
582
+ # make ai-toolkit weight
583
+ ait_down_keys = [k + ".lora_A.weight" for k in ait_keys]
584
+ ait_up_keys = [k + ".lora_B.weight" for k in ait_keys]
585
+
586
+ # down_weight is copied to each split
587
+ ait_sd.update({k: down_weight for k in ait_down_keys})
588
+
589
+ # up_weight is split to each split
590
+ ait_sd.update({k: v for k, v in zip(ait_up_keys, torch.split(up_weight, dims, dim=0))}) # noqa: C416
591
+
592
+ for old_key in orig_keys:
593
+ # Handle double_blocks
594
+ if old_key.startswith(("diffusion_model.double_blocks", "double_blocks")):
595
+ block_num = re.search(r"double_blocks\.(\d+)", old_key).group(1)
596
+ new_key = f"transformer.transformer_blocks.{block_num}"
597
+
598
+ if "processor.proj_lora1" in old_key:
599
+ new_key += ".attn.to_out.0"
600
+ elif "processor.proj_lora2" in old_key:
601
+ new_key += ".attn.to_add_out"
602
+ # Handle text latents.
603
+ elif "processor.qkv_lora2" in old_key and "up" not in old_key:
604
+ handle_qkv(
605
+ old_state_dict,
606
+ new_state_dict,
607
+ old_key,
608
+ [
609
+ f"transformer.transformer_blocks.{block_num}.attn.add_q_proj",
610
+ f"transformer.transformer_blocks.{block_num}.attn.add_k_proj",
611
+ f"transformer.transformer_blocks.{block_num}.attn.add_v_proj",
612
+ ],
613
+ )
614
+ # continue
615
+ # Handle image latents.
616
+ elif "processor.qkv_lora1" in old_key and "up" not in old_key:
617
+ handle_qkv(
618
+ old_state_dict,
619
+ new_state_dict,
620
+ old_key,
621
+ [
622
+ f"transformer.transformer_blocks.{block_num}.attn.to_q",
623
+ f"transformer.transformer_blocks.{block_num}.attn.to_k",
624
+ f"transformer.transformer_blocks.{block_num}.attn.to_v",
625
+ ],
626
+ )
627
+ # continue
628
+
629
+ if "down" in old_key:
630
+ new_key += ".lora_A.weight"
631
+ elif "up" in old_key:
632
+ new_key += ".lora_B.weight"
633
+
634
+ # Handle single_blocks
635
+ elif old_key.startswith(("diffusion_model.single_blocks", "single_blocks")):
636
+ block_num = re.search(r"single_blocks\.(\d+)", old_key).group(1)
637
+ new_key = f"transformer.single_transformer_blocks.{block_num}"
638
+
639
+ if "proj_lora1" in old_key or "proj_lora2" in old_key:
640
+ new_key += ".proj_out"
641
+ elif "qkv_lora1" in old_key or "qkv_lora2" in old_key:
642
+ new_key += ".norm.linear"
643
+
644
+ if "down" in old_key:
645
+ new_key += ".lora_A.weight"
646
+ elif "up" in old_key:
647
+ new_key += ".lora_B.weight"
648
+
649
+ else:
650
+ # Handle other potential key patterns here
651
+ new_key = old_key
652
+
653
+ # Since we already handle qkv above.
654
+ if "qkv" not in old_key:
655
+ new_state_dict[new_key] = old_state_dict.pop(old_key)
656
+
657
+ if len(old_state_dict) > 0:
658
+ raise ValueError(f"`old_state_dict` should be at this point but has: {list(old_state_dict.keys())}.")
659
+
660
+ return new_state_dict
diffusers/loaders/lora_pipeline.py ADDED
The diff for this file is too large to render. See raw diff
 
diffusers/loaders/peft.py ADDED
@@ -0,0 +1,396 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # coding=utf-8
2
+ # Copyright 2024 The HuggingFace Inc. team.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+ import inspect
16
+ from functools import partial
17
+ from typing import Dict, List, Optional, Union
18
+
19
+ from ..utils import (
20
+ MIN_PEFT_VERSION,
21
+ USE_PEFT_BACKEND,
22
+ check_peft_version,
23
+ delete_adapter_layers,
24
+ is_peft_available,
25
+ set_adapter_layers,
26
+ set_weights_and_activate_adapters,
27
+ )
28
+ from .unet_loader_utils import _maybe_expand_lora_scales
29
+
30
+
31
+ _SET_ADAPTER_SCALE_FN_MAPPING = {
32
+ "UNet2DConditionModel": _maybe_expand_lora_scales,
33
+ "UNetMotionModel": _maybe_expand_lora_scales,
34
+ "SD3Transformer2DModel": lambda model_cls, weights: weights,
35
+ "FluxTransformer2DModel": lambda model_cls, weights: weights,
36
+ "CogVideoXTransformer3DModel": lambda model_cls, weights: weights,
37
+ }
38
+
39
+
40
class PeftAdapterMixin:
    """
    Mixin bundling every helper needed to load and drive adapter weights supported by the PEFT library. For
    more details about adapters and injecting them in a base model, check out the PEFT
    [documentation](https://huggingface.co/docs/peft/index).

    Install the latest version of PEFT, and use this mixin to:

    - Attach new adapters in the model.
    - Attach multiple adapters and iteratively activate/deactivate them.
    - Activate/deactivate all adapters from the model.
    - Get a list of the active adapters.
    """

    # Flipped to True the first time an adapter config is injected into this model.
    _hf_peft_config_loaded = False

    def set_adapters(
        self,
        adapter_names: Union[List[str], str],
        weights: Optional[Union[float, Dict, List[float], List[Dict], List[None]]] = None,
    ):
        """
        Activate the given adapters on the model, optionally with per-adapter scales.

        Args:
            adapter_names (`List[str]` or `str`):
                The names of the adapters to use.
            weights (`Union[List[float], float]`, *optional*):
                The adapter(s) weights to use. If `None`, the weights are set to `1.0` for all the adapters.

        Example:

        ```py
        from diffusers import AutoPipelineForText2Image
        import torch

        pipeline = AutoPipelineForText2Image.from_pretrained(
            "stabilityai/stable-diffusion-xl-base-1.0", torch_dtype=torch.float16
        ).to("cuda")
        pipeline.load_lora_weights(
            "jbilcke-hf/sdxl-cinematic-1", weight_name="pytorch_lora_weights.safetensors", adapter_name="cinematic"
        )
        pipeline.load_lora_weights("nerijs/pixel-art-xl", weight_name="pixel-art-xl.safetensors", adapter_name="pixel")
        pipeline.set_adapters(["cinematic", "pixel"], adapter_weights=[0.5, 0.5])
        ```
        """
        if not USE_PEFT_BACKEND:
            raise ValueError("PEFT backend is required for `set_adapters()`.")

        if isinstance(adapter_names, str):
            adapter_names = [adapter_names]

        # Broadcast a scalar/dict/None into one entry per adapter,
        # e.g. for two adapters: 7 -> [7, 7]; None -> [None, None].
        if not isinstance(weights, list):
            weights = [weights] * len(adapter_names)

        if len(adapter_names) != len(weights):
            raise ValueError(
                f"Length of adapter names {len(adapter_names)} is not equal to the length of their weights {len(weights)}."
            )

        # Unspecified scales default to 1.0.
        weights = [1.0 if w is None else w for w in weights]

        # Model-specific expansion of structured scale dicts (identity for the transformer models).
        scale_expansion_fn = _SET_ADAPTER_SCALE_FN_MAPPING[self.__class__.__name__]
        weights = scale_expansion_fn(self, weights)

        set_weights_and_activate_adapters(self, adapter_names, weights)

    def add_adapter(self, adapter_config, adapter_name: str = "default") -> None:
        r"""
        Adds a new adapter to the current model for training. If no adapter name is passed, a default name is assigned
        to the adapter to follow the convention of the PEFT library.

        If you are not familiar with adapters and PEFT methods, we invite you to read more about them in the PEFT
        [documentation](https://huggingface.co/docs/peft).

        Args:
            adapter_config (`[~peft.PeftConfig]`):
                The configuration of the adapter to add; supported adapters are non-prefix tuning and adaption prompt
                methods.
            adapter_name (`str`, *optional*, defaults to `"default"`):
                The name of the adapter to add. If no name is passed, a default name is assigned to the adapter.
        """
        check_peft_version(min_version=MIN_PEFT_VERSION)

        if not is_peft_available():
            raise ImportError("PEFT is not available. Please install PEFT to use this function: `pip install peft`.")

        from peft import PeftConfig, inject_adapter_in_model

        if self._hf_peft_config_loaded:
            # A previous adapter exists already; names must be unique.
            if adapter_name in self.peft_config:
                raise ValueError(f"Adapter with name {adapter_name} already exists. Please use a different name.")
        else:
            self._hf_peft_config_loaded = True

        if not isinstance(adapter_config, PeftConfig):
            raise ValueError(
                f"adapter_config should be an instance of PeftConfig. Got {type(adapter_config)} instead."
            )

        # Unlike transformers, here we don't need to retrieve the name_or_path of the unet as the loading logic is
        # handled by the `load_lora_layers` or `StableDiffusionLoraLoaderMixin`. Therefore we set it to `None` here.
        adapter_config.base_model_name_or_path = None
        inject_adapter_in_model(adapter_config, self, adapter_name)
        self.set_adapter(adapter_name)

    def set_adapter(self, adapter_name: Union[str, List[str]]) -> None:
        """
        Sets a specific adapter by forcing the model to only use that adapter and disables the other adapters.

        If you are not familiar with adapters and PEFT methods, we invite you to read more about them on the PEFT
        [documentation](https://huggingface.co/docs/peft).

        Args:
            adapter_name (Union[str, List[str]])):
                The list of adapters to set or the adapter name in the case of a single adapter.
        """
        check_peft_version(min_version=MIN_PEFT_VERSION)

        if not self._hf_peft_config_loaded:
            raise ValueError("No adapter loaded. Please load an adapter first.")

        if isinstance(adapter_name, str):
            adapter_name = [adapter_name]

        missing = set(adapter_name) - set(self.peft_config)
        if len(missing) > 0:
            raise ValueError(
                f"Following adapter(s) could not be found: {', '.join(missing)}. Make sure you are passing the correct adapter name(s)."
                f" current loaded adapters are: {list(self.peft_config.keys())}"
            )

        from peft.tuners.tuners_utils import BaseTunerLayer

        found_tuner_layer = False

        for _, module in self.named_modules():
            if not isinstance(module, BaseTunerLayer):
                continue
            if hasattr(module, "set_adapter"):
                module.set_adapter(adapter_name)
            elif len(adapter_name) != 1:
                # Older PEFT versions do not support multi-adapter inference.
                raise ValueError(
                    "You are trying to set multiple adapters and you have a PEFT version that does not support multi-adapter inference. Please upgrade to the latest version of PEFT."
                    " `pip install -U peft` or `pip install -U git+https://github.com/huggingface/peft.git`"
                )
            else:
                module.active_adapter = adapter_name
            found_tuner_layer = True

        if not found_tuner_layer:
            raise ValueError(
                "Did not succeeded in setting the adapter. Please make sure you are using a model that supports adapters."
            )

    def disable_adapters(self) -> None:
        r"""
        Disable all adapters attached to the model and fallback to inference with the base model only.

        If you are not familiar with adapters and PEFT methods, we invite you to read more about them on the PEFT
        [documentation](https://huggingface.co/docs/peft).
        """
        check_peft_version(min_version=MIN_PEFT_VERSION)

        if not self._hf_peft_config_loaded:
            raise ValueError("No adapter loaded. Please load an adapter first.")

        from peft.tuners.tuners_utils import BaseTunerLayer

        for _, module in self.named_modules():
            if not isinstance(module, BaseTunerLayer):
                continue
            if hasattr(module, "enable_adapters"):
                module.enable_adapters(enabled=False)
            else:
                # Fallback for older PEFT versions.
                module.disable_adapters = True

    def enable_adapters(self) -> None:
        """
        Enable adapters that are attached to the model. The model uses `self.active_adapters()` to retrieve the list of
        adapters to enable.

        If you are not familiar with adapters and PEFT methods, we invite you to read more about them on the PEFT
        [documentation](https://huggingface.co/docs/peft).
        """
        check_peft_version(min_version=MIN_PEFT_VERSION)

        if not self._hf_peft_config_loaded:
            raise ValueError("No adapter loaded. Please load an adapter first.")

        from peft.tuners.tuners_utils import BaseTunerLayer

        for _, module in self.named_modules():
            if not isinstance(module, BaseTunerLayer):
                continue
            if hasattr(module, "enable_adapters"):
                module.enable_adapters(enabled=True)
            else:
                # Fallback for older PEFT versions.
                module.disable_adapters = False

    def active_adapters(self) -> List[str]:
        """
        Gets the current list of active adapters of the model.

        If you are not familiar with adapters and PEFT methods, we invite you to read more about them on the PEFT
        [documentation](https://huggingface.co/docs/peft).
        """
        check_peft_version(min_version=MIN_PEFT_VERSION)

        if not is_peft_available():
            raise ImportError("PEFT is not available. Please install PEFT to use this function: `pip install peft`.")

        if not self._hf_peft_config_loaded:
            raise ValueError("No adapter loaded. Please load an adapter first.")

        from peft.tuners.tuners_utils import BaseTunerLayer

        # NOTE(review): this reports the first tuner layer encountered; if no tuner
        # layer exists the method implicitly returns None.
        for _, module in self.named_modules():
            if isinstance(module, BaseTunerLayer):
                return module.active_adapter

    def fuse_lora(self, lora_scale=1.0, safe_fusing=False, adapter_names=None):
        """Merge the LoRA weights into the base weights of every tuner layer."""
        if not USE_PEFT_BACKEND:
            raise ValueError("PEFT backend is required for `fuse_lora()`.")

        self.lora_scale = lora_scale
        self._safe_fusing = safe_fusing
        self.apply(partial(self._fuse_lora_apply, adapter_names=adapter_names))

    def _fuse_lora_apply(self, module, adapter_names=None):
        # Per-module callback for `fuse_lora`; only tuner layers are touched.
        from peft.tuners.tuners_utils import BaseTunerLayer

        if not isinstance(module, BaseTunerLayer):
            return

        if self.lora_scale != 1.0:
            module.scale_layer(self.lora_scale)

        merge_kwargs = {"safe_merge": self._safe_fusing}

        # For BC with previous PEFT versions, inspect the `merge` signature to see
        # whether it accepts the `adapter_names` argument.
        supported_merge_kwargs = list(inspect.signature(module.merge).parameters)
        if "adapter_names" in supported_merge_kwargs:
            merge_kwargs["adapter_names"] = adapter_names
        elif adapter_names is not None:
            raise ValueError(
                "The `adapter_names` argument is not supported with your PEFT version. Please upgrade"
                " to the latest version of PEFT. `pip install -U peft`"
            )

        module.merge(**merge_kwargs)

    def unfuse_lora(self):
        """Undo a previous `fuse_lora` by unmerging LoRA weights from the base weights."""
        if not USE_PEFT_BACKEND:
            raise ValueError("PEFT backend is required for `unfuse_lora()`.")
        self.apply(self._unfuse_lora_apply)

    def _unfuse_lora_apply(self, module):
        # Per-module callback for `unfuse_lora`.
        from peft.tuners.tuners_utils import BaseTunerLayer

        if isinstance(module, BaseTunerLayer):
            module.unmerge()

    def unload_lora(self):
        """Remove all PEFT layers from the model and drop the adapter config entirely."""
        if not USE_PEFT_BACKEND:
            raise ValueError("PEFT backend is required for `unload_lora()`.")

        from ..utils import recurse_remove_peft_layers

        recurse_remove_peft_layers(self)
        if hasattr(self, "peft_config"):
            del self.peft_config

    def disable_lora(self):
        """
        Disables the active LoRA layers of the underlying model.

        Example:

        ```py
        from diffusers import AutoPipelineForText2Image
        import torch

        pipeline = AutoPipelineForText2Image.from_pretrained(
            "stabilityai/stable-diffusion-xl-base-1.0", torch_dtype=torch.float16
        ).to("cuda")
        pipeline.load_lora_weights(
            "jbilcke-hf/sdxl-cinematic-1", weight_name="pytorch_lora_weights.safetensors", adapter_name="cinematic"
        )
        pipeline.disable_lora()
        ```
        """
        if not USE_PEFT_BACKEND:
            raise ValueError("PEFT backend is required for this method.")
        set_adapter_layers(self, enabled=False)

    def enable_lora(self):
        """
        Enables the active LoRA layers of the underlying model.

        Example:

        ```py
        from diffusers import AutoPipelineForText2Image
        import torch

        pipeline = AutoPipelineForText2Image.from_pretrained(
            "stabilityai/stable-diffusion-xl-base-1.0", torch_dtype=torch.float16
        ).to("cuda")
        pipeline.load_lora_weights(
            "jbilcke-hf/sdxl-cinematic-1", weight_name="pytorch_lora_weights.safetensors", adapter_name="cinematic"
        )
        pipeline.enable_lora()
        ```
        """
        if not USE_PEFT_BACKEND:
            raise ValueError("PEFT backend is required for this method.")
        set_adapter_layers(self, enabled=True)

    def delete_adapters(self, adapter_names: Union[List[str], str]):
        """
        Delete an adapter's LoRA layers from the underlying model.

        Args:
            adapter_names (`Union[List[str], str]`):
                The names (single string or list of strings) of the adapter to delete.

        Example:

        ```py
        from diffusers import AutoPipelineForText2Image
        import torch

        pipeline = AutoPipelineForText2Image.from_pretrained(
            "stabilityai/stable-diffusion-xl-base-1.0", torch_dtype=torch.float16
        ).to("cuda")
        pipeline.load_lora_weights(
            "jbilcke-hf/sdxl-cinematic-1", weight_name="pytorch_lora_weights.safetensors", adapter_names="cinematic"
        )
        pipeline.delete_adapters("cinematic")
        ```
        """
        if not USE_PEFT_BACKEND:
            raise ValueError("PEFT backend is required for this method.")

        if isinstance(adapter_names, str):
            adapter_names = [adapter_names]

        for name in adapter_names:
            delete_adapter_layers(self, name)

            # Also drop the corresponding entry from the adapter config, if present.
            if hasattr(self, "peft_config"):
                self.peft_config.pop(name, None)
diffusers/loaders/single_file.py ADDED
@@ -0,0 +1,550 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright 2024 The HuggingFace Team. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ import importlib
15
+ import inspect
16
+ import os
17
+
18
+ import torch
19
+ from huggingface_hub import snapshot_download
20
+ from huggingface_hub.utils import LocalEntryNotFoundError, validate_hf_hub_args
21
+ from packaging import version
22
+
23
+ from ..utils import deprecate, is_transformers_available, logging
24
+ from .single_file_utils import (
25
+ SingleFileComponentError,
26
+ _is_legacy_scheduler_kwargs,
27
+ _is_model_weights_in_cached_folder,
28
+ _legacy_load_clip_tokenizer,
29
+ _legacy_load_safety_checker,
30
+ _legacy_load_scheduler,
31
+ create_diffusers_clip_model_from_ldm,
32
+ create_diffusers_t5_model_from_checkpoint,
33
+ fetch_diffusers_config,
34
+ fetch_original_config,
35
+ is_clip_model_in_single_file,
36
+ is_t5_in_single_file,
37
+ load_single_file_checkpoint,
38
+ )
39
+
40
+
41
logger = logging.get_logger(__name__)

# Legacy behaviour: `from_single_file` does not load the safety checker unless it is explicitly provided.
SINGLE_FILE_OPTIONAL_COMPONENTS = ["safety_checker"]

# transformers is an optional dependency; only pull in its base classes when present.
if is_transformers_available():
    import transformers
    from transformers import PreTrainedModel, PreTrainedTokenizer
49
+
50
+
51
def load_single_file_sub_model(
    library_name,
    class_name,
    name,
    checkpoint,
    pipelines,
    is_pipeline_module,
    cached_model_config_path,
    original_config=None,
    local_files_only=False,
    torch_dtype=None,
    is_legacy_loading=False,
    **kwargs,
):
    """
    Instantiate a single pipeline component (model, tokenizer or scheduler) from a single-file checkpoint.

    The component class is resolved either from the pipeline module (when `is_pipeline_module` is True) or from
    the library named by `library_name`, and is then dispatched to the most specific loading path available:
    `from_single_file`, a CLIP/T5 conversion helper, a legacy tokenizer/scheduler loader, or plain
    `from_pretrained` as the fallback.
    """
    # Resolve the component class.
    if is_pipeline_module:
        pipeline_module = getattr(pipelines, library_name)
        class_obj = getattr(pipeline_module, class_name)
    else:
        # Otherwise import it straight from the named library.
        class_obj = getattr(importlib.import_module(library_name), class_name)

    if is_transformers_available():
        transformers_version = version.parse(version.parse(transformers.__version__).base_version)
    else:
        transformers_version = "N/A"

    # transformers-based checks only apply for versions that support the APIs we rely on.
    is_transformers_model = (
        is_transformers_available()
        and issubclass(class_obj, PreTrainedModel)
        and transformers_version >= version.parse("4.20.0")
    )
    is_tokenizer = (
        is_transformers_available()
        and issubclass(class_obj, PreTrainedTokenizer)
        and transformers_version >= version.parse("4.20.0")
    )

    diffusers_module = importlib.import_module(__name__.split(".")[0])
    is_diffusers_single_file_model = issubclass(class_obj, diffusers_module.FromOriginalModelMixin)
    is_diffusers_model = issubclass(class_obj, diffusers_module.ModelMixin)
    is_diffusers_scheduler = issubclass(class_obj, diffusers_module.SchedulerMixin)

    if is_diffusers_single_file_model:
        load_method = getattr(class_obj, "from_single_file")

        # `from_single_file` cannot take two config sources at once: an explicit
        # `original_config` wins over the cached diffusers config.
        if original_config:
            cached_model_config_path = None

        return load_method(
            pretrained_model_link_or_path_or_dict=checkpoint,
            original_config=original_config,
            config=cached_model_config_path,
            subfolder=name,
            torch_dtype=torch_dtype,
            local_files_only=local_files_only,
            **kwargs,
        )

    if is_transformers_model and is_clip_model_in_single_file(class_obj, checkpoint):
        return create_diffusers_clip_model_from_ldm(
            class_obj,
            checkpoint=checkpoint,
            config=cached_model_config_path,
            subfolder=name,
            torch_dtype=torch_dtype,
            local_files_only=local_files_only,
            is_legacy_loading=is_legacy_loading,
        )

    if is_transformers_model and is_t5_in_single_file(checkpoint):
        return create_diffusers_t5_model_from_checkpoint(
            class_obj,
            checkpoint=checkpoint,
            config=cached_model_config_path,
            subfolder=name,
            torch_dtype=torch_dtype,
            local_files_only=local_files_only,
        )

    if is_tokenizer and is_legacy_loading:
        return _legacy_load_clip_tokenizer(
            class_obj, checkpoint=checkpoint, config=cached_model_config_path, local_files_only=local_files_only
        )

    if is_diffusers_scheduler and (is_legacy_loading or _is_legacy_scheduler_kwargs(kwargs)):
        return _legacy_load_scheduler(
            class_obj, checkpoint=checkpoint, component_name=name, original_config=original_config, **kwargs
        )

    # Fallback: standard `from_pretrained` loading from the cached config folder.
    if not hasattr(class_obj, "from_pretrained"):
        raise ValueError(
            (
                f"The component {class_obj.__name__} cannot be loaded as it does not seem to have"
                " a supported loading method."
            )
        )

    loading_kwargs = {
        "pretrained_model_name_or_path": cached_model_config_path,
        "subfolder": name,
        "local_files_only": local_files_only,
    }

    # Schedulers and tokenizers don't make use of `torch_dtype`; only pass it to nn.Modules.
    if issubclass(class_obj, torch.nn.Module):
        loading_kwargs["torch_dtype"] = torch_dtype

    if is_diffusers_model or is_transformers_model:
        if not _is_model_weights_in_cached_folder(cached_model_config_path, name):
            raise SingleFileComponentError(
                f"Failed to load {class_name}. Weights for this component appear to be missing in the checkpoint."
            )

    load_method = getattr(class_obj, "from_pretrained")
    return load_method(**loading_kwargs)
176
+
177
+
178
def _map_component_types_to_config_dict(component_types):
    """
    Translate `{component_name: (type, ...)}` signature info into a model_index-style config dict mapping each
    component name to a `[library, class_name]` pair. Components without a recognizable library (and optional
    components such as the safety checker) map to `[None, None]`.
    """
    diffusers_module = importlib.import_module(__name__.split(".")[0])
    config_dict = {}
    component_types.pop("self", None)

    if is_transformers_available():
        transformers_version = version.parse(version.parse(transformers.__version__).base_version)
    else:
        transformers_version = "N/A"

    for component_name, component_value in component_types.items():
        component_cls = component_value[0]

        is_diffusers_model = issubclass(component_cls, diffusers_module.ModelMixin)
        is_scheduler_enum = component_cls.__name__ == "KarrasDiffusionSchedulers"
        is_scheduler = issubclass(component_cls, diffusers_module.SchedulerMixin)

        # transformers-based checks only apply for versions that support the APIs we rely on.
        transformers_supported = is_transformers_available() and transformers_version >= version.parse("4.20.0")
        is_transformers_model = transformers_supported and issubclass(component_cls, PreTrainedModel)
        is_transformers_tokenizer = transformers_supported and issubclass(component_cls, PreTrainedTokenizer)

        if is_diffusers_model and component_name not in SINGLE_FILE_OPTIONAL_COMPONENTS:
            config_dict[component_name] = ["diffusers", component_cls.__name__]

        elif is_scheduler_enum:
            # A scheduler config cannot be fetched from the hub for a KarrasDiffusionSchedulers
            # enum type hint, so default to DDIMScheduler.
            config_dict[component_name] = ["diffusers", "DDIMScheduler"]

        elif is_scheduler:
            config_dict[component_name] = ["diffusers", component_cls.__name__]

        elif (
            is_transformers_model or is_transformers_tokenizer
        ) and component_name not in SINGLE_FILE_OPTIONAL_COMPONENTS:
            config_dict[component_name] = ["transformers", component_cls.__name__]

        else:
            config_dict[component_name] = [None, None]

    return config_dict
225
+
226
+
227
def _infer_pipeline_config_dict(pipeline_class):
    """
    Build a config dict for `pipeline_class` from its `__init__` signature, covering only the components that
    are required (i.e. have no default value).
    """
    parameters = inspect.signature(pipeline_class.__init__).parameters
    required_parameters = {k: v for k, v in parameters.items() if v.default is inspect._empty}

    # Keep only the type hints for components the pipeline actually requires.
    component_types = {
        k: v for k, v in pipeline_class._get_signature_types().items() if k in required_parameters
    }
    return _map_component_types_to_config_dict(component_types)
237
+
238
+
239
def _download_diffusers_model_config_from_hub(
    pretrained_model_name_or_path,
    cache_dir,
    revision,
    proxies,
    force_download=None,
    local_files_only=None,
    token=None,
):
    """
    Snapshot only the lightweight configuration-style files (json/txt/sentencepiece models) of a diffusers
    repo from the Hub and return the local cache path. Weight files are deliberately excluded.
    """
    # Restrict the snapshot to config-style files; weights are intentionally not downloaded here.
    config_file_patterns = ["**/*.json", "*.json", "*.txt", "**/*.txt", "**/*.model"]
    return snapshot_download(
        pretrained_model_name_or_path,
        cache_dir=cache_dir,
        revision=revision,
        proxies=proxies,
        force_download=force_download,
        local_files_only=local_files_only,
        token=token,
        allow_patterns=config_file_patterns,
    )
261
+
262
+
263
+ class FromSingleFileMixin:
264
+ """
265
+ Load model weights saved in the `.ckpt` format into a [`DiffusionPipeline`].
266
+ """
267
+
268
+ @classmethod
269
+ @validate_hf_hub_args
270
+ def from_single_file(cls, pretrained_model_link_or_path, **kwargs):
271
+ r"""
272
+ Instantiate a [`DiffusionPipeline`] from pretrained pipeline weights saved in the `.ckpt` or `.safetensors`
273
+ format. The pipeline is set in evaluation mode (`model.eval()`) by default.
274
+
275
+ Parameters:
276
+ pretrained_model_link_or_path (`str` or `os.PathLike`, *optional*):
277
+ Can be either:
278
+ - A link to the `.ckpt` file (for example
279
+ `"https://huggingface.co/<repo_id>/blob/main/<path_to_file>.ckpt"`) on the Hub.
280
+ - A path to a *file* containing all pipeline weights.
281
+ torch_dtype (`str` or `torch.dtype`, *optional*):
282
+ Override the default `torch.dtype` and load the model with another dtype.
283
+ force_download (`bool`, *optional*, defaults to `False`):
284
+ Whether or not to force the (re-)download of the model weights and configuration files, overriding the
285
+ cached versions if they exist.
286
+ cache_dir (`Union[str, os.PathLike]`, *optional*):
287
+ Path to a directory where a downloaded pretrained model configuration is cached if the standard cache
288
+ is not used.
289
+
290
+ proxies (`Dict[str, str]`, *optional*):
291
+ A dictionary of proxy servers to use by protocol or endpoint, for example, `{'http': 'foo.bar:3128',
292
+ 'http://hostname': 'foo.bar:4012'}`. The proxies are used on each request.
293
+ local_files_only (`bool`, *optional*, defaults to `False`):
294
+ Whether to only load local model weights and configuration files or not. If set to `True`, the model
295
+ won't be downloaded from the Hub.
296
+ token (`str` or *bool*, *optional*):
297
+ The token to use as HTTP bearer authorization for remote files. If `True`, the token generated from
298
+ `diffusers-cli login` (stored in `~/.huggingface`) is used.
299
+ revision (`str`, *optional*, defaults to `"main"`):
300
+ The specific model version to use. It can be a branch name, a tag name, a commit id, or any identifier
301
+ allowed by Git.
302
+ original_config_file (`str`, *optional*):
303
+ The path to the original config file that was used to train the model. If not provided, the config file
304
+ will be inferred from the checkpoint file.
305
+ config (`str`, *optional*):
306
+ Can be either:
307
+ - A string, the *repo id* (for example `CompVis/ldm-text2im-large-256`) of a pretrained pipeline
308
+ hosted on the Hub.
309
+ - A path to a *directory* (for example `./my_pipeline_directory/`) containing the pipeline
310
+ component configs in Diffusers format.
311
+ kwargs (remaining dictionary of keyword arguments, *optional*):
312
+ Can be used to overwrite load and saveable variables (the pipeline components of the specific pipeline
313
+ class). The overwritten components are passed directly to the pipelines `__init__` method. See example
314
+ below for more information.
315
+
316
+ Examples:
317
+
318
+ ```py
319
+ >>> from diffusers import StableDiffusionPipeline
320
+
321
+ >>> # Download pipeline from huggingface.co and cache.
322
+ >>> pipeline = StableDiffusionPipeline.from_single_file(
323
+ ... "https://huggingface.co/WarriorMama777/OrangeMixs/blob/main/Models/AbyssOrangeMix/AbyssOrangeMix.safetensors"
324
+ ... )
325
+
326
+ >>> # Download pipeline from local file
327
+ >>> # file is downloaded under ./v1-5-pruned-emaonly.ckpt
328
+ >>> pipeline = StableDiffusionPipeline.from_single_file("./v1-5-pruned-emaonly.ckpt")
329
+
330
+ >>> # Enable float16 and move to GPU
331
+ >>> pipeline = StableDiffusionPipeline.from_single_file(
332
+ ... "https://huggingface.co/runwayml/stable-diffusion-v1-5/blob/main/v1-5-pruned-emaonly.ckpt",
333
+ ... torch_dtype=torch.float16,
334
+ ... )
335
+ >>> pipeline.to("cuda")
336
+ ```
337
+
338
+ """
339
+ original_config_file = kwargs.pop("original_config_file", None)
340
+ config = kwargs.pop("config", None)
341
+ original_config = kwargs.pop("original_config", None)
342
+
343
+ if original_config_file is not None:
344
+ deprecation_message = (
345
+ "`original_config_file` argument is deprecated and will be removed in future versions."
346
+ "please use the `original_config` argument instead."
347
+ )
348
+ deprecate("original_config_file", "1.0.0", deprecation_message)
349
+ original_config = original_config_file
350
+
351
+ force_download = kwargs.pop("force_download", False)
352
+ proxies = kwargs.pop("proxies", None)
353
+ token = kwargs.pop("token", None)
354
+ cache_dir = kwargs.pop("cache_dir", None)
355
+ local_files_only = kwargs.pop("local_files_only", False)
356
+ revision = kwargs.pop("revision", None)
357
+ torch_dtype = kwargs.pop("torch_dtype", None)
358
+
359
+ is_legacy_loading = False
360
+
361
+ # We shouldn't allow configuring individual models components through a Pipeline creation method
362
+ # These model kwargs should be deprecated
363
+ scaling_factor = kwargs.get("scaling_factor", None)
364
+ if scaling_factor is not None:
365
+ deprecation_message = (
366
+ "Passing the `scaling_factor` argument to `from_single_file is deprecated "
367
+ "and will be ignored in future versions."
368
+ )
369
+ deprecate("scaling_factor", "1.0.0", deprecation_message)
370
+
371
+ if original_config is not None:
372
+ original_config = fetch_original_config(original_config, local_files_only=local_files_only)
373
+
374
+ from ..pipelines.pipeline_utils import _get_pipeline_class
375
+
376
+ pipeline_class = _get_pipeline_class(cls, config=None)
377
+
378
+ checkpoint = load_single_file_checkpoint(
379
+ pretrained_model_link_or_path,
380
+ force_download=force_download,
381
+ proxies=proxies,
382
+ token=token,
383
+ cache_dir=cache_dir,
384
+ local_files_only=local_files_only,
385
+ revision=revision,
386
+ )
387
+
388
+ if config is None:
389
+ config = fetch_diffusers_config(checkpoint)
390
+ default_pretrained_model_config_name = config["pretrained_model_name_or_path"]
391
+ else:
392
+ default_pretrained_model_config_name = config
393
+
394
+ if not os.path.isdir(default_pretrained_model_config_name):
395
+ # Provided config is a repo_id
396
+ if default_pretrained_model_config_name.count("/") > 1:
397
+ raise ValueError(
398
+ f'The provided config "{config}"'
399
+ " is neither a valid local path nor a valid repo id. Please check the parameter."
400
+ )
401
+ try:
402
+ # Attempt to download the config files for the pipeline
403
+ cached_model_config_path = _download_diffusers_model_config_from_hub(
404
+ default_pretrained_model_config_name,
405
+ cache_dir=cache_dir,
406
+ revision=revision,
407
+ proxies=proxies,
408
+ force_download=force_download,
409
+ local_files_only=local_files_only,
410
+ token=token,
411
+ )
412
+ config_dict = pipeline_class.load_config(cached_model_config_path)
413
+
414
+ except LocalEntryNotFoundError:
415
+ # `local_files_only=True` but a local diffusers format model config is not available in the cache
416
+ # If `original_config` is not provided, we need override `local_files_only` to False
417
+ # to fetch the config files from the hub so that we have a way
418
+ # to configure the pipeline components.
419
+
420
+ if original_config is None:
421
+ logger.warning(
422
+ "`local_files_only` is True but no local configs were found for this checkpoint.\n"
423
+ "Attempting to download the necessary config files for this pipeline.\n"
424
+ )
425
+ cached_model_config_path = _download_diffusers_model_config_from_hub(
426
+ default_pretrained_model_config_name,
427
+ cache_dir=cache_dir,
428
+ revision=revision,
429
+ proxies=proxies,
430
+ force_download=force_download,
431
+ local_files_only=False,
432
+ token=token,
433
+ )
434
+ config_dict = pipeline_class.load_config(cached_model_config_path)
435
+
436
+ else:
437
+ # For backwards compatibility
438
+ # If `original_config` is provided, then we need to assume we are using legacy loading for pipeline components
439
+ logger.warning(
440
+ "Detected legacy `from_single_file` loading behavior. Attempting to create the pipeline based on inferred components.\n"
441
+ "This may lead to errors if the model components are not correctly inferred. \n"
442
+ "To avoid this warning, please explicity pass the `config` argument to `from_single_file` with a path to a local diffusers model repo \n"
443
+ "e.g. `from_single_file(<my model checkpoint path>, config=<path to local diffusers model repo>) \n"
444
+ "or run `from_single_file` with `local_files_only=False` first to update the local cache directory with "
445
+ "the necessary config files.\n"
446
+ )
447
+ is_legacy_loading = True
448
+ cached_model_config_path = None
449
+
450
+ config_dict = _infer_pipeline_config_dict(pipeline_class)
451
+ config_dict["_class_name"] = pipeline_class.__name__
452
+
453
+ else:
454
+ # Provided config is a path to a local directory attempt to load directly.
455
+ cached_model_config_path = default_pretrained_model_config_name
456
+ config_dict = pipeline_class.load_config(cached_model_config_path)
457
+
458
+ # pop out "_ignore_files" as it is only needed for download
459
+ config_dict.pop("_ignore_files", None)
460
+
461
+ expected_modules, optional_kwargs = pipeline_class._get_signature_keys(cls)
462
+ passed_class_obj = {k: kwargs.pop(k) for k in expected_modules if k in kwargs}
463
+ passed_pipe_kwargs = {k: kwargs.pop(k) for k in optional_kwargs if k in kwargs}
464
+
465
+ init_dict, unused_kwargs, _ = pipeline_class.extract_init_dict(config_dict, **kwargs)
466
+ init_kwargs = {k: init_dict.pop(k) for k in optional_kwargs if k in init_dict}
467
+ init_kwargs = {**init_kwargs, **passed_pipe_kwargs}
468
+
469
+ from diffusers import pipelines
470
+
471
+ # remove `null` components
472
+ def load_module(name, value):
473
+ if value[0] is None:
474
+ return False
475
+ if name in passed_class_obj and passed_class_obj[name] is None:
476
+ return False
477
+ if name in SINGLE_FILE_OPTIONAL_COMPONENTS:
478
+ return False
479
+
480
+ return True
481
+
482
+ init_dict = {k: v for k, v in init_dict.items() if load_module(k, v)}
483
+
484
+ for name, (library_name, class_name) in logging.tqdm(
485
+ sorted(init_dict.items()), desc="Loading pipeline components..."
486
+ ):
487
+ loaded_sub_model = None
488
+ is_pipeline_module = hasattr(pipelines, library_name)
489
+
490
+ if name in passed_class_obj:
491
+ loaded_sub_model = passed_class_obj[name]
492
+
493
+ else:
494
+ try:
495
+ loaded_sub_model = load_single_file_sub_model(
496
+ library_name=library_name,
497
+ class_name=class_name,
498
+ name=name,
499
+ checkpoint=checkpoint,
500
+ is_pipeline_module=is_pipeline_module,
501
+ cached_model_config_path=cached_model_config_path,
502
+ pipelines=pipelines,
503
+ torch_dtype=torch_dtype,
504
+ original_config=original_config,
505
+ local_files_only=local_files_only,
506
+ is_legacy_loading=is_legacy_loading,
507
+ **kwargs,
508
+ )
509
+ except SingleFileComponentError as e:
510
+ raise SingleFileComponentError(
511
+ (
512
+ f"{e.message}\n"
513
+ f"Please load the component before passing it in as an argument to `from_single_file`.\n"
514
+ f"\n"
515
+ f"{name} = {class_name}.from_pretrained('...')\n"
516
+ f"pipe = {pipeline_class.__name__}.from_single_file(<checkpoint path>, {name}={name})\n"
517
+ f"\n"
518
+ )
519
+ )
520
+
521
+ init_kwargs[name] = loaded_sub_model
522
+
523
+ missing_modules = set(expected_modules) - set(init_kwargs.keys())
524
+ passed_modules = list(passed_class_obj.keys())
525
+ optional_modules = pipeline_class._optional_components
526
+
527
+ if len(missing_modules) > 0 and missing_modules <= set(passed_modules + optional_modules):
528
+ for module in missing_modules:
529
+ init_kwargs[module] = passed_class_obj.get(module, None)
530
+ elif len(missing_modules) > 0:
531
+ passed_modules = set(list(init_kwargs.keys()) + list(passed_class_obj.keys())) - optional_kwargs
532
+ raise ValueError(
533
+ f"Pipeline {pipeline_class} expected {expected_modules}, but only {passed_modules} were passed."
534
+ )
535
+
536
+ # deprecated kwargs
537
+ load_safety_checker = kwargs.pop("load_safety_checker", None)
538
+ if load_safety_checker is not None:
539
+ deprecation_message = (
540
+ "Please pass instances of `StableDiffusionSafetyChecker` and `AutoImageProcessor`"
541
+ "using the `safety_checker` and `feature_extractor` arguments in `from_single_file`"
542
+ )
543
+ deprecate("load_safety_checker", "1.0.0", deprecation_message)
544
+
545
+ safety_checker_components = _legacy_load_safety_checker(local_files_only, torch_dtype)
546
+ init_kwargs.update(safety_checker_components)
547
+
548
+ pipe = pipeline_class(**init_kwargs)
549
+
550
+ return pipe
diffusers/loaders/single_file_model.py ADDED
@@ -0,0 +1,318 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright 2024 The HuggingFace Team. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ import importlib
15
+ import inspect
16
+ import re
17
+ from contextlib import nullcontext
18
+ from typing import Optional
19
+
20
+ from huggingface_hub.utils import validate_hf_hub_args
21
+
22
+ from ..utils import deprecate, is_accelerate_available, logging
23
+ from .single_file_utils import (
24
+ SingleFileComponentError,
25
+ convert_animatediff_checkpoint_to_diffusers,
26
+ convert_controlnet_checkpoint,
27
+ convert_flux_transformer_checkpoint_to_diffusers,
28
+ convert_ldm_unet_checkpoint,
29
+ convert_ldm_vae_checkpoint,
30
+ convert_sd3_transformer_checkpoint_to_diffusers,
31
+ convert_stable_cascade_unet_single_file_to_diffusers,
32
+ create_controlnet_diffusers_config_from_ldm,
33
+ create_unet_diffusers_config_from_ldm,
34
+ create_vae_diffusers_config_from_ldm,
35
+ fetch_diffusers_config,
36
+ fetch_original_config,
37
+ load_single_file_checkpoint,
38
+ )
39
+
40
+
41
+ logger = logging.get_logger(__name__)
42
+
43
+
44
+ if is_accelerate_available():
45
+ from accelerate import init_empty_weights
46
+
47
+ from ..models.modeling_utils import load_model_dict_into_meta
48
+
49
+
50
# Registry of model classes that support `from_single_file`.
#
# Each entry maps a Diffusers model class name to:
#   - "checkpoint_mapping_fn": converts an original-format state dict to Diffusers format
#     (required for every entry).
#   - "config_mapping_fn": converts an original (LDM-style) config to a Diffusers config
#     (only for classes that support loading with `original_config`).
#   - "default_subfolder": subfolder inside the reference Diffusers repo that holds the
#     component's config.
#   - "legacy_kwargs": legacy `from_single_file` kwarg names mapped to their modern
#     equivalents.
SINGLE_FILE_LOADABLE_CLASSES = {
    "StableCascadeUNet": {
        "checkpoint_mapping_fn": convert_stable_cascade_unet_single_file_to_diffusers,
    },
    "UNet2DConditionModel": {
        "checkpoint_mapping_fn": convert_ldm_unet_checkpoint,
        "config_mapping_fn": create_unet_diffusers_config_from_ldm,
        "default_subfolder": "unet",
        "legacy_kwargs": {
            # Legacy kwargs supported by `from_single_file` mapped to new args
            "num_in_channels": "in_channels",
        },
    },
    "AutoencoderKL": {
        "checkpoint_mapping_fn": convert_ldm_vae_checkpoint,
        "config_mapping_fn": create_vae_diffusers_config_from_ldm,
        "default_subfolder": "vae",
    },
    "ControlNetModel": {
        "checkpoint_mapping_fn": convert_controlnet_checkpoint,
        "config_mapping_fn": create_controlnet_diffusers_config_from_ldm,
    },
    "SD3Transformer2DModel": {
        "checkpoint_mapping_fn": convert_sd3_transformer_checkpoint_to_diffusers,
        "default_subfolder": "transformer",
    },
    "MotionAdapter": {
        "checkpoint_mapping_fn": convert_animatediff_checkpoint_to_diffusers,
    },
    "SparseControlNetModel": {
        "checkpoint_mapping_fn": convert_animatediff_checkpoint_to_diffusers,
    },
    "FluxTransformer2DModel": {
        "checkpoint_mapping_fn": convert_flux_transformer_checkpoint_to_diffusers,
        "default_subfolder": "transformer",
    },
}
86
+
87
+
88
def _get_single_file_loadable_mapping_class(cls):
    """Return the key in `SINGLE_FILE_LOADABLE_CLASSES` whose class `cls` derives from.

    The lookup resolves each registered class name against the top-level `diffusers`
    module, so subclasses of supported models are recognized as well. Returns `None`
    when `cls` matches no registered class.
    """
    diffusers_module = importlib.import_module(__name__.split(".")[0])
    return next(
        (
            candidate_name
            for candidate_name in SINGLE_FILE_LOADABLE_CLASSES
            if issubclass(cls, getattr(diffusers_module, candidate_name))
        ),
        None,
    )
97
+
98
+
99
+ def _get_mapping_function_kwargs(mapping_fn, **kwargs):
100
+ parameters = inspect.signature(mapping_fn).parameters
101
+
102
+ mapping_kwargs = {}
103
+ for parameter in parameters:
104
+ if parameter in kwargs:
105
+ mapping_kwargs[parameter] = kwargs[parameter]
106
+
107
+ return mapping_kwargs
108
+
109
+
110
class FromOriginalModelMixin:
    """
    Load pretrained weights saved in the `.ckpt` or `.safetensors` format into a model.
    """

    @classmethod
    @validate_hf_hub_args
    def from_single_file(cls, pretrained_model_link_or_path_or_dict: Optional[str] = None, **kwargs):
        r"""
        Instantiate a model from pretrained weights saved in the original `.ckpt` or `.safetensors` format. The model
        is set in evaluation mode (`model.eval()`) by default.

        Parameters:
            pretrained_model_link_or_path_or_dict (`str`, *optional*):
                Can be either:
                    - A link to the `.safetensors` or `.ckpt` file (for example
                      `"https://huggingface.co/<repo_id>/blob/main/<path_to_file>.safetensors"`) on the Hub.
                    - A path to a local *file* containing the weights of the component model.
                    - A state dict containing the component model weights.
            config (`str`, *optional*):
                - A string, the *repo id* (for example `CompVis/ldm-text2im-large-256`) of a pretrained pipeline hosted
                  on the Hub.
                - A path to a *directory* (for example `./my_pipeline_directory/`) containing the pipeline component
                  configs in Diffusers format.
            subfolder (`str`, *optional*, defaults to `""`):
                The subfolder location of a model file within a larger model repository on the Hub or locally.
            original_config (`str`, *optional*):
                Dict or path to a yaml file containing the configuration for the model in its original format.
                If a dict is provided, it will be used to initialize the model configuration.
            torch_dtype (`str` or `torch.dtype`, *optional*):
                Override the default `torch.dtype` and load the model with another dtype. If `"auto"` is passed, the
                dtype is automatically derived from the model's weights.
            force_download (`bool`, *optional*, defaults to `False`):
                Whether or not to force the (re-)download of the model weights and configuration files, overriding the
                cached versions if they exist.
            cache_dir (`Union[str, os.PathLike]`, *optional*):
                Path to a directory where a downloaded pretrained model configuration is cached if the standard cache
                is not used.
            proxies (`Dict[str, str]`, *optional*):
                A dictionary of proxy servers to use by protocol or endpoint, for example, `{'http': 'foo.bar:3128',
                'http://hostname': 'foo.bar:4012'}`. The proxies are used on each request.
            local_files_only (`bool`, *optional*, defaults to `False`):
                Whether to only load local model weights and configuration files or not. If set to True, the model
                won't be downloaded from the Hub.
            token (`str` or *bool*, *optional*):
                The token to use as HTTP bearer authorization for remote files. If `True`, the token generated from
                `diffusers-cli login` (stored in `~/.huggingface`) is used.
            revision (`str`, *optional*, defaults to `"main"`):
                The specific model version to use. It can be a branch name, a tag name, a commit id, or any identifier
                allowed by Git.
            kwargs (remaining dictionary of keyword arguments, *optional*):
                Can be used to overwrite load and saveable variables (for example the pipeline components of the
                specific pipeline class). The overwritten components are directly passed to the pipelines `__init__`
                method. See example below for more information.

        ```py
        >>> from diffusers import StableCascadeUNet

        >>> ckpt_path = "https://huggingface.co/stabilityai/stable-cascade/blob/main/stage_b_lite.safetensors"
        >>> model = StableCascadeUNet.from_single_file(ckpt_path)
        ```
        """
        mapping_class_name = _get_single_file_loadable_mapping_class(cls)
        if mapping_class_name is None:
            raise ValueError(
                f"FromOriginalModelMixin is currently only compatible with {', '.join(SINGLE_FILE_LOADABLE_CLASSES.keys())}"
            )

        # Deprecated alias of the positional argument. Use `pop` (not `get`) so the legacy
        # kwarg is consumed here and cannot leak into the mapping-function kwarg filtering
        # further below.
        pretrained_model_link_or_path = kwargs.pop("pretrained_model_link_or_path", None)
        if pretrained_model_link_or_path is not None:
            deprecation_message = (
                "Please use `pretrained_model_link_or_path_or_dict` argument instead for model classes"
            )
            deprecate("pretrained_model_link_or_path", "1.0.0", deprecation_message)
            pretrained_model_link_or_path_or_dict = pretrained_model_link_or_path

        config = kwargs.pop("config", None)
        original_config = kwargs.pop("original_config", None)

        if config is not None and original_config is not None:
            raise ValueError(
                "`from_single_file` cannot accept both `config` and `original_config` arguments. Please provide only one of these arguments"
            )

        force_download = kwargs.pop("force_download", False)
        proxies = kwargs.pop("proxies", None)
        token = kwargs.pop("token", None)
        cache_dir = kwargs.pop("cache_dir", None)
        local_files_only = kwargs.pop("local_files_only", None)
        subfolder = kwargs.pop("subfolder", None)
        revision = kwargs.pop("revision", None)
        torch_dtype = kwargs.pop("torch_dtype", None)

        if isinstance(pretrained_model_link_or_path_or_dict, dict):
            # The caller already loaded the state dict; use it directly.
            checkpoint = pretrained_model_link_or_path_or_dict
        else:
            checkpoint = load_single_file_checkpoint(
                pretrained_model_link_or_path_or_dict,
                force_download=force_download,
                proxies=proxies,
                token=token,
                cache_dir=cache_dir,
                local_files_only=local_files_only,
                revision=revision,
            )

        mapping_functions = SINGLE_FILE_LOADABLE_CLASSES[mapping_class_name]

        checkpoint_mapping_fn = mapping_functions["checkpoint_mapping_fn"]
        if original_config:
            config_mapping_fn = mapping_functions.get("config_mapping_fn")
            if config_mapping_fn is None:
                raise ValueError(
                    (
                        f"`original_config` has been provided for {mapping_class_name} but no mapping function"
                        "was found to convert the original config to a Diffusers config in"
                        "`diffusers.loaders.single_file_utils`"
                    )
                )

            if isinstance(original_config, str):
                # If original_config is a URL or filepath fetch the original_config dict
                original_config = fetch_original_config(original_config, local_files_only=local_files_only)

            config_mapping_kwargs = _get_mapping_function_kwargs(config_mapping_fn, **kwargs)
            diffusers_model_config = config_mapping_fn(
                original_config=original_config, checkpoint=checkpoint, **config_mapping_kwargs
            )
        else:
            if config:
                if isinstance(config, str):
                    default_pretrained_model_config_name = config
                else:
                    raise ValueError(
                        (
                            "Invalid `config` argument. Please provide a string representing a repo id"
                            "or path to a local Diffusers model repo."
                        )
                    )
            else:
                config = fetch_diffusers_config(checkpoint)
                default_pretrained_model_config_name = config["pretrained_model_name_or_path"]

            if "default_subfolder" in mapping_functions:
                subfolder = mapping_functions["default_subfolder"]

            # Some configs carry a subfolder key (e.g. StableCascadeUNet). Only dict-style
            # configs (returned by `fetch_diffusers_config`) can do so; a user-provided
            # repo-id string has no `.pop`, so guard with `isinstance` to avoid an
            # `AttributeError` when `subfolder` was not otherwise set.
            if not subfolder and isinstance(config, dict):
                subfolder = config.pop("subfolder", None)

            diffusers_model_config = cls.load_config(
                pretrained_model_name_or_path=default_pretrained_model_config_name,
                subfolder=subfolder,
                local_files_only=local_files_only,
            )
            expected_kwargs, optional_kwargs = cls._get_signature_keys(cls)

            # Map legacy kwargs to new kwargs
            if "legacy_kwargs" in mapping_functions:
                legacy_kwargs = mapping_functions["legacy_kwargs"]
                for legacy_key, new_key in legacy_kwargs.items():
                    if legacy_key in kwargs:
                        kwargs[new_key] = kwargs.pop(legacy_key)

            # Only kwargs that are part of the model's config signature may override it.
            model_kwargs = {k: kwargs.get(k) for k in kwargs if k in expected_kwargs or k in optional_kwargs}
            diffusers_model_config.update(model_kwargs)

        checkpoint_mapping_kwargs = _get_mapping_function_kwargs(checkpoint_mapping_fn, **kwargs)
        diffusers_format_checkpoint = checkpoint_mapping_fn(
            config=diffusers_model_config, checkpoint=checkpoint, **checkpoint_mapping_kwargs
        )
        if not diffusers_format_checkpoint:
            raise SingleFileComponentError(
                f"Failed to load {mapping_class_name}. Weights for this component appear to be missing in the checkpoint."
            )

        # Instantiate on the meta device when accelerate is available so the weights are
        # not allocated twice (once at init, once when loading the state dict).
        ctx = init_empty_weights if is_accelerate_available() else nullcontext
        with ctx():
            model = cls.from_config(diffusers_model_config)

        if is_accelerate_available():
            unexpected_keys = load_model_dict_into_meta(model, diffusers_format_checkpoint, dtype=torch_dtype)
        else:
            _, unexpected_keys = model.load_state_dict(diffusers_format_checkpoint, strict=False)

        if model._keys_to_ignore_on_load_unexpected is not None:
            for pat in model._keys_to_ignore_on_load_unexpected:
                unexpected_keys = [k for k in unexpected_keys if re.search(pat, k) is None]

        if len(unexpected_keys) > 0:
            # NOTE: previously the joined key names were wrapped in a list literal inside the
            # f-string, which logged them as "['k1, k2']"; log the plain joined names instead.
            logger.warning(
                f"Some weights of the model checkpoint were not used when initializing {cls.__name__}: \n {', '.join(unexpected_keys)}"
            )

        if torch_dtype is not None:
            model.to(torch_dtype)

        model.eval()

        return model
diffusers/loaders/single_file_utils.py ADDED
@@ -0,0 +1,2100 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # coding=utf-8
2
+ # Copyright 2024 The HuggingFace Inc. team.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+ """Conversion script for the Stable Diffusion checkpoints."""
16
+
17
+ import copy
18
+ import os
19
+ import re
20
+ from contextlib import nullcontext
21
+ from io import BytesIO
22
+ from urllib.parse import urlparse
23
+
24
+ import requests
25
+ import torch
26
+ import yaml
27
+
28
+ from ..models.modeling_utils import load_state_dict
29
+ from ..schedulers import (
30
+ DDIMScheduler,
31
+ DPMSolverMultistepScheduler,
32
+ EDMDPMSolverMultistepScheduler,
33
+ EulerAncestralDiscreteScheduler,
34
+ EulerDiscreteScheduler,
35
+ HeunDiscreteScheduler,
36
+ LMSDiscreteScheduler,
37
+ PNDMScheduler,
38
+ )
39
+ from ..utils import (
40
+ SAFETENSORS_WEIGHTS_NAME,
41
+ WEIGHTS_NAME,
42
+ deprecate,
43
+ is_accelerate_available,
44
+ is_transformers_available,
45
+ logging,
46
+ )
47
+ from ..utils.hub_utils import _get_model_file
48
+
49
+
50
+ if is_transformers_available():
51
+ from transformers import AutoImageProcessor
52
+
53
+ if is_accelerate_available():
54
+ from accelerate import init_empty_weights
55
+
56
+ from ..models.modeling_utils import load_model_dict_into_meta
57
+
58
+ logger = logging.get_logger(__name__) # pylint: disable=invalid-name
59
+
60
# State-dict key names that uniquely identify each supported checkpoint family.
# `infer` logic elsewhere in this module probes a loaded checkpoint for these keys
# to decide which model/pipeline variant the single file corresponds to.
CHECKPOINT_KEY_NAMES = {
    # Stable Diffusion UNet variants
    "v2": "model.diffusion_model.input_blocks.2.1.transformer_blocks.0.attn2.to_k.weight",
    "xl_base": "conditioner.embedders.1.model.transformer.resblocks.9.mlp.c_proj.bias",
    "xl_refiner": "conditioner.embedders.0.model.transformer.resblocks.9.mlp.c_proj.bias",
    "upscale": "model.diffusion_model.input_blocks.10.0.skip_connection.bias",
    "controlnet": "control_model.time_embed.0.weight",
    "playground-v2-5": "edm_mean",
    "inpainting": "model.diffusion_model.input_blocks.0.0.weight",
    # Text encoder variants
    "clip": "cond_stage_model.transformer.text_model.embeddings.position_embedding.weight",
    "clip_sdxl": "conditioner.embedders.0.transformer.text_model.embeddings.position_embedding.weight",
    "clip_sd3": "text_encoders.clip_l.transformer.text_model.embeddings.position_embedding.weight",
    "open_clip": "cond_stage_model.model.token_embedding.weight",
    "open_clip_sdxl": "conditioner.embedders.1.model.positional_embedding",
    "open_clip_sdxl_refiner": "conditioner.embedders.0.model.text_projection",
    "open_clip_sd3": "text_encoders.clip_g.transformer.text_model.embeddings.position_embedding.weight",
    # Stable Cascade
    "stable_cascade_stage_b": "down_blocks.1.0.channelwise.0.weight",
    "stable_cascade_stage_c": "clip_txt_mapper.weight",
    # Stable Diffusion 3
    "sd3": "model.diffusion_model.joint_blocks.0.context_block.adaLN_modulation.1.bias",
    # AnimateDiff motion modules and SparseControlNet
    "animatediff": "down_blocks.0.motion_modules.0.temporal_transformer.transformer_blocks.0.attention_blocks.0.pos_encoder.pe",
    "animatediff_v2": "mid_block.motion_modules.0.temporal_transformer.norm.bias",
    "animatediff_sdxl_beta": "up_blocks.2.motion_modules.0.temporal_transformer.norm.weight",
    "animatediff_scribble": "controlnet_cond_embedding.conv_in.weight",
    "animatediff_rgb": "controlnet_cond_embedding.weight",
    # Flux checkpoints may or may not carry the "model.diffusion_model." prefix
    "flux": [
        "double_blocks.0.img_attn.norm.key_norm.scale",
        "model.diffusion_model.double_blocks.0.img_attn.norm.key_norm.scale",
    ],
}
88
+
89
# Default Hub repositories (and, where needed, subfolders) that hold the Diffusers-format
# configs for each inferred checkpoint type. Used when `from_single_file` is called
# without an explicit `config` argument.
DIFFUSERS_DEFAULT_PIPELINE_PATHS = {
    "xl_base": {"pretrained_model_name_or_path": "stabilityai/stable-diffusion-xl-base-1.0"},
    "xl_refiner": {"pretrained_model_name_or_path": "stabilityai/stable-diffusion-xl-refiner-1.0"},
    "xl_inpaint": {"pretrained_model_name_or_path": "diffusers/stable-diffusion-xl-1.0-inpainting-0.1"},
    "playground-v2-5": {"pretrained_model_name_or_path": "playgroundai/playground-v2.5-1024px-aesthetic"},
    "upscale": {"pretrained_model_name_or_path": "stabilityai/stable-diffusion-x4-upscaler"},
    "inpainting": {"pretrained_model_name_or_path": "stable-diffusion-v1-5/stable-diffusion-inpainting"},
    "inpainting_v2": {"pretrained_model_name_or_path": "stabilityai/stable-diffusion-2-inpainting"},
    "controlnet": {"pretrained_model_name_or_path": "lllyasviel/control_v11p_sd15_canny"},
    "v2": {"pretrained_model_name_or_path": "stabilityai/stable-diffusion-2-1"},
    "v1": {"pretrained_model_name_or_path": "stable-diffusion-v1-5/stable-diffusion-v1-5"},
    "stable_cascade_stage_b": {"pretrained_model_name_or_path": "stabilityai/stable-cascade", "subfolder": "decoder"},
    "stable_cascade_stage_b_lite": {
        "pretrained_model_name_or_path": "stabilityai/stable-cascade",
        "subfolder": "decoder_lite",
    },
    "stable_cascade_stage_c": {
        "pretrained_model_name_or_path": "stabilityai/stable-cascade-prior",
        "subfolder": "prior",
    },
    "stable_cascade_stage_c_lite": {
        "pretrained_model_name_or_path": "stabilityai/stable-cascade-prior",
        "subfolder": "prior_lite",
    },
    "sd3": {
        "pretrained_model_name_or_path": "stabilityai/stable-diffusion-3-medium-diffusers",
    },
    "animatediff_v1": {"pretrained_model_name_or_path": "guoyww/animatediff-motion-adapter-v1-5"},
    "animatediff_v2": {"pretrained_model_name_or_path": "guoyww/animatediff-motion-adapter-v1-5-2"},
    "animatediff_v3": {"pretrained_model_name_or_path": "guoyww/animatediff-motion-adapter-v1-5-3"},
    "animatediff_sdxl_beta": {"pretrained_model_name_or_path": "guoyww/animatediff-motion-adapter-sdxl-beta"},
    "animatediff_scribble": {"pretrained_model_name_or_path": "guoyww/animatediff-sparsectrl-scribble"},
    "animatediff_rgb": {"pretrained_model_name_or_path": "guoyww/animatediff-sparsectrl-rgb"},
    "flux-dev": {"pretrained_model_name_or_path": "black-forest-labs/FLUX.1-dev"},
    "flux-schnell": {"pretrained_model_name_or_path": "black-forest-labs/FLUX.1-schnell"},
}
125
+
126
# Use to configure model sample size when original config is provided
# Maps an inferred model type to its default training image size (pixels),
# used by `set_image_size` when the caller does not pass an explicit size.
DIFFUSERS_TO_LDM_DEFAULT_IMAGE_SIZE_MAP = {
    "xl_base": 1024,
    "xl_refiner": 1024,
    "xl_inpaint": 1024,
    "playground-v2-5": 1024,
    "upscale": 512,
    "inpainting": 512,
    "inpainting_v2": 512,
    "controlnet": 512,
    "v2": 768,
    "v1": 512,
}
# Direct key renames between diffusers module names (dict keys) and original
# LDM checkpoint names (dict values), grouped per component. The "transformer"
# sub-map under "openclip" holds substring renames rather than full-key renames.
DIFFUSERS_TO_LDM_MAPPING = {
    "unet": {
        "layers": {
            "time_embedding.linear_1.weight": "time_embed.0.weight",
            "time_embedding.linear_1.bias": "time_embed.0.bias",
            "time_embedding.linear_2.weight": "time_embed.2.weight",
            "time_embedding.linear_2.bias": "time_embed.2.bias",
            "conv_in.weight": "input_blocks.0.0.weight",
            "conv_in.bias": "input_blocks.0.0.bias",
            "conv_norm_out.weight": "out.0.weight",
            "conv_norm_out.bias": "out.0.bias",
            "conv_out.weight": "out.2.weight",
            "conv_out.bias": "out.2.bias",
        },
        # Note: class and addition embeddings share the same LDM key names
        # ("label_emb.*"); which mapping applies depends on the model config.
        "class_embed_type": {
            "class_embedding.linear_1.weight": "label_emb.0.0.weight",
            "class_embedding.linear_1.bias": "label_emb.0.0.bias",
            "class_embedding.linear_2.weight": "label_emb.0.2.weight",
            "class_embedding.linear_2.bias": "label_emb.0.2.bias",
        },
        "addition_embed_type": {
            "add_embedding.linear_1.weight": "label_emb.0.0.weight",
            "add_embedding.linear_1.bias": "label_emb.0.0.bias",
            "add_embedding.linear_2.weight": "label_emb.0.2.weight",
            "add_embedding.linear_2.bias": "label_emb.0.2.bias",
        },
    },
    "controlnet": {
        "layers": {
            "time_embedding.linear_1.weight": "time_embed.0.weight",
            "time_embedding.linear_1.bias": "time_embed.0.bias",
            "time_embedding.linear_2.weight": "time_embed.2.weight",
            "time_embedding.linear_2.bias": "time_embed.2.bias",
            "conv_in.weight": "input_blocks.0.0.weight",
            "conv_in.bias": "input_blocks.0.0.bias",
            "controlnet_cond_embedding.conv_in.weight": "input_hint_block.0.weight",
            "controlnet_cond_embedding.conv_in.bias": "input_hint_block.0.bias",
            "controlnet_cond_embedding.conv_out.weight": "input_hint_block.14.weight",
            "controlnet_cond_embedding.conv_out.bias": "input_hint_block.14.bias",
        },
        "class_embed_type": {
            "class_embedding.linear_1.weight": "label_emb.0.0.weight",
            "class_embedding.linear_1.bias": "label_emb.0.0.bias",
            "class_embedding.linear_2.weight": "label_emb.0.2.weight",
            "class_embedding.linear_2.bias": "label_emb.0.2.bias",
        },
        "addition_embed_type": {
            "add_embedding.linear_1.weight": "label_emb.0.0.weight",
            "add_embedding.linear_1.bias": "label_emb.0.0.bias",
            "add_embedding.linear_2.weight": "label_emb.0.2.weight",
            "add_embedding.linear_2.bias": "label_emb.0.2.bias",
        },
    },
    # VAE conv keys mostly match; only the output norms are renamed.
    "vae": {
        "encoder.conv_in.weight": "encoder.conv_in.weight",
        "encoder.conv_in.bias": "encoder.conv_in.bias",
        "encoder.conv_out.weight": "encoder.conv_out.weight",
        "encoder.conv_out.bias": "encoder.conv_out.bias",
        "encoder.conv_norm_out.weight": "encoder.norm_out.weight",
        "encoder.conv_norm_out.bias": "encoder.norm_out.bias",
        "decoder.conv_in.weight": "decoder.conv_in.weight",
        "decoder.conv_in.bias": "decoder.conv_in.bias",
        "decoder.conv_out.weight": "decoder.conv_out.weight",
        "decoder.conv_out.bias": "decoder.conv_out.bias",
        "decoder.conv_norm_out.weight": "decoder.norm_out.weight",
        "decoder.conv_norm_out.bias": "decoder.norm_out.bias",
        "quant_conv.weight": "quant_conv.weight",
        "quant_conv.bias": "quant_conv.bias",
        "post_quant_conv.weight": "post_quant_conv.weight",
        "post_quant_conv.bias": "post_quant_conv.bias",
    },
    "openclip": {
        "layers": {
            "text_model.embeddings.position_embedding.weight": "positional_embedding",
            "text_model.embeddings.token_embedding.weight": "token_embedding.weight",
            "text_model.final_layer_norm.weight": "ln_final.weight",
            "text_model.final_layer_norm.bias": "ln_final.bias",
            "text_projection.weight": "text_projection",
        },
        # Substring renames applied within transformer-block key names.
        "transformer": {
            "text_model.encoder.layers.": "resblocks.",
            "layer_norm1": "ln_1",
            "layer_norm2": "ln_2",
            ".fc1.": ".c_fc.",
            ".fc2.": ".c_proj.",
            ".self_attn": ".attn",
            "transformer.text_model.final_layer_norm.": "ln_final.",
            "transformer.text_model.embeddings.token_embedding.weight": "token_embedding.weight",
            "transformer.text_model.embeddings.position_embedding.weight": "positional_embedding",
        },
    },
}
# SD 2.x OpenCLIP text-encoder keys to skip during conversion: the final
# transformer block (index 23) and the text projection.
# NOTE(review): presumably ignored because the diffusers text encoder does not
# use these layers — confirm against the text-encoder conversion code.
SD_2_TEXT_ENCODER_KEYS_TO_IGNORE = [
    "cond_stage_model.model.transformer.resblocks.23.attn.in_proj_bias",
    "cond_stage_model.model.transformer.resblocks.23.attn.in_proj_weight",
    "cond_stage_model.model.transformer.resblocks.23.attn.out_proj.bias",
    "cond_stage_model.model.transformer.resblocks.23.attn.out_proj.weight",
    "cond_stage_model.model.transformer.resblocks.23.ln_1.bias",
    "cond_stage_model.model.transformer.resblocks.23.ln_1.weight",
    "cond_stage_model.model.transformer.resblocks.23.ln_2.bias",
    "cond_stage_model.model.transformer.resblocks.23.ln_2.weight",
    "cond_stage_model.model.transformer.resblocks.23.mlp.c_fc.bias",
    "cond_stage_model.model.transformer.resblocks.23.mlp.c_fc.weight",
    "cond_stage_model.model.transformer.resblocks.23.mlp.c_proj.bias",
    "cond_stage_model.model.transformer.resblocks.23.mlp.c_proj.weight",
    "cond_stage_model.model.text_projection",
]
# To support legacy scheduler_type argument
# Baseline scheduler configuration used when a legacy scheduler kwarg
# (see SCHEDULER_LEGACY_KWARGS) is passed to the single-file loading path.
SCHEDULER_DEFAULT_CONFIG = {
    "beta_schedule": "scaled_linear",
    "beta_start": 0.00085,
    "beta_end": 0.012,
    "interpolation_type": "linear",
    "num_train_timesteps": 1000,
    "prediction_type": "epsilon",
    "sample_max_value": 1.0,
    "set_alpha_to_one": False,
    "skip_prk_steps": True,
    "steps_offset": 1,
    "timestep_spacing": "leading",
}
# Prefixes under which VAE weights may appear in an original checkpoint
# (use not visible in this chunk — presumably stripped during VAE extraction).
LDM_VAE_KEYS = ["first_stage_model.", "vae."]
# Fallback VAE scaling factor when the original config provides none.
LDM_VAE_DEFAULT_SCALING_FACTOR = 0.18215
# VAE scaling factor selected when EDM latent statistics are present
# (Playground-style checkpoints; see `create_vae_diffusers_config_from_ldm`).
PLAYGROUND_VAE_SCALING_FACTOR = 0.5
# Prefix of UNet weights in an original LDM checkpoint.
LDM_UNET_KEY = "model.diffusion_model."
# Prefix of ControlNet weights in an original checkpoint.
LDM_CONTROLNET_KEY = "control_model."
# Prefixes removed from CLIP text-encoder keys
# (use not visible in this chunk — verify against the CLIP conversion code).
LDM_CLIP_PREFIX_TO_REMOVE = [
    "cond_stage_model.transformer.",
    "conditioner.embedders.0.transformer.",
]
# OpenCLIP text projection dimension (use not visible in this chunk).
LDM_OPEN_CLIP_TEXT_PROJECTION_DIM = 1024
# Kwargs that trigger the legacy scheduler configuration path.
SCHEDULER_LEGACY_KWARGS = ["prediction_type", "scheduler_type"]

# URL prefixes accepted when a Hub URL is passed instead of a local file.
VALID_URL_PREFIXES = ["https://huggingface.co/", "huggingface.co/", "hf.co/", "https://hf.co/"]
class SingleFileComponentError(Exception):
    """Raised when a pipeline component cannot be loaded from a single-file checkpoint."""

    def __init__(self, message=None):
        # Keep the message reachable as an attribute for callers that inspect it.
        self.message = message
        super().__init__(message)
def is_valid_url(url):
    """Return True when `url` parses with both a scheme and a network location."""
    parsed = urlparse(url)
    return bool(parsed.scheme and parsed.netloc)
def _extract_repo_id_and_weights_name(pretrained_model_name_or_path):
    """
    Split a Hub file URL into `(repo_id, weights_name)`.

    Accepts URLs of the form `<prefix><org>/<repo>/blob/main/<file>` (the
    `blob/main/` segment is optional) with any prefix from `VALID_URL_PREFIXES`.
    On an unrecognized layout a warning is logged and `(None, None)` is returned.

    Args:
        pretrained_model_name_or_path (`str`): URL pointing at a file on the Hub.

    Returns:
        `tuple`: `(repo_id, weights_name)`; both are `None` when parsing fails.

    Raises:
        ValueError: If the input is not a valid URL.
    """
    if not is_valid_url(pretrained_model_name_or_path):
        raise ValueError("Invalid `pretrained_model_name_or_path` provided. Please set it to a valid URL.")

    pattern = r"([^/]+)/([^/]+)/(?:blob/main/)?(.+)"
    weights_name = None
    # Bug fix: this was previously `repo_id = (None,)` — the stray trailing
    # comma made the failure-path repo_id a one-element tuple instead of None.
    repo_id = None
    # Strip any recognized Hub prefix so the regex sees "<org>/<repo>/...".
    for prefix in VALID_URL_PREFIXES:
        pretrained_model_name_or_path = pretrained_model_name_or_path.replace(prefix, "")
    match = re.match(pattern, pretrained_model_name_or_path)
    if not match:
        logger.warning("Unable to identify the repo_id and weights_name from the provided URL.")
        return repo_id, weights_name

    repo_id = f"{match.group(1)}/{match.group(2)}"
    weights_name = match.group(3)

    return repo_id, weights_name
def _is_model_weights_in_cached_folder(cached_folder, name):
    """Return True if a known weights file exists under `cached_folder/name`."""
    component_dir = os.path.join(cached_folder, name)
    return any(
        os.path.isfile(os.path.join(component_dir, weights_file))
        for weights_file in (WEIGHTS_NAME, SAFETENSORS_WEIGHTS_NAME)
    )
def _is_legacy_scheduler_kwargs(kwargs):
    """Return True if any legacy scheduler kwarg (see SCHEDULER_LEGACY_KWARGS) was supplied."""
    legacy_keys = set(SCHEDULER_LEGACY_KWARGS)
    return not legacy_keys.isdisjoint(kwargs)
def load_single_file_checkpoint(
    pretrained_model_link_or_path,
    force_download=False,
    proxies=None,
    token=None,
    cache_dir=None,
    local_files_only=None,
    revision=None,
):
    """
    Load the state dict of a single-file checkpoint from a local path or a Hub URL.

    Args:
        pretrained_model_link_or_path (`str`): Local file path or Hub file URL.
        force_download (`bool`, *optional*): Re-download even if cached.
        proxies (`dict`, *optional*): Proxies forwarded to the download request.
        token (`str`, *optional*): Hub auth token.
        cache_dir (`str`, *optional*): Cache directory for downloaded files.
        local_files_only (`bool`, *optional*): Only look at local cache.
        revision (`str`, *optional*): Repo revision to download from.

    Returns:
        `dict`: The loaded (and unwrapped) state dict.
    """
    # Fix: the original had an `if` branch that self-assigned the path (a no-op);
    # only the non-local case needs work, so the condition is inverted.
    if not os.path.isfile(pretrained_model_link_or_path):
        repo_id, weights_name = _extract_repo_id_and_weights_name(pretrained_model_link_or_path)
        pretrained_model_link_or_path = _get_model_file(
            repo_id,
            weights_name=weights_name,
            force_download=force_download,
            cache_dir=cache_dir,
            proxies=proxies,
            local_files_only=local_files_only,
            token=token,
            revision=revision,
        )

    checkpoint = load_state_dict(pretrained_model_link_or_path)

    # some checkpoints contain the model state dict under a (possibly repeatedly
    # nested) "state_dict" key
    while "state_dict" in checkpoint:
        checkpoint = checkpoint["state_dict"]

    return checkpoint
def fetch_original_config(original_config_file, local_files_only=False):
    """
    Load an original (LDM-style) YAML config from a local path or a URL.

    Args:
        original_config_file (`str`): Path to a YAML file, or a URL pointing at one.
        local_files_only (`bool`, *optional*): If True, URLs are rejected.

    Returns:
        `dict`: The parsed YAML config.

    Raises:
        ValueError: If the input is neither an existing file nor a valid URL,
            or a URL was given while `local_files_only` is True.
    """
    if os.path.isfile(original_config_file):
        with open(original_config_file, "r") as fp:
            original_config_file = fp.read()

    elif is_valid_url(original_config_file):
        if local_files_only:
            raise ValueError(
                "`local_files_only` is set to True, but a URL was provided as `original_config_file`. "
                "Please provide a valid local file path."
            )

        # Fix: a bounded timeout prevents hanging forever on an unresponsive
        # host (the original request had no timeout). Configs are small files.
        original_config_file = BytesIO(requests.get(original_config_file, timeout=60).content)

    else:
        raise ValueError("Invalid `original_config_file` provided. Please set it to a valid file path or URL.")

    original_config = yaml.safe_load(original_config_file)

    return original_config
def is_clip_model(checkpoint):
    """Return True when the checkpoint carries the CLIP text-encoder marker key."""
    return CHECKPOINT_KEY_NAMES["clip"] in checkpoint
def is_clip_sdxl_model(checkpoint):
    """Return True when the checkpoint carries the SDXL CLIP marker key."""
    return CHECKPOINT_KEY_NAMES["clip_sdxl"] in checkpoint
def is_clip_sd3_model(checkpoint):
    """Return True when the checkpoint carries the SD3 CLIP marker key."""
    return CHECKPOINT_KEY_NAMES["clip_sd3"] in checkpoint
def is_open_clip_model(checkpoint):
    """Return True when the checkpoint carries the OpenCLIP marker key."""
    return CHECKPOINT_KEY_NAMES["open_clip"] in checkpoint
def is_open_clip_sdxl_model(checkpoint):
    """Return True when the checkpoint carries the SDXL OpenCLIP marker key."""
    return CHECKPOINT_KEY_NAMES["open_clip_sdxl"] in checkpoint
def is_open_clip_sd3_model(checkpoint):
    """Return True when the checkpoint carries the SD3 OpenCLIP marker key."""
    return CHECKPOINT_KEY_NAMES["open_clip_sd3"] in checkpoint
def is_open_clip_sdxl_refiner_model(checkpoint):
    """Return True when the checkpoint carries the SDXL-refiner OpenCLIP marker key."""
    return CHECKPOINT_KEY_NAMES["open_clip_sdxl_refiner"] in checkpoint
def is_clip_model_in_single_file(class_obj, checkpoint):
    """
    Return True when `class_obj` is a CLIP text-model class and `checkpoint`
    contains weights for any supported CLIP/OpenCLIP variant.
    """
    # Only the two transformers CLIP text-model classes are eligible.
    if class_obj.__name__ not in ("CLIPTextModel", "CLIPTextModelWithProjection"):
        return False

    detectors = (
        is_clip_model,
        is_clip_sd3_model,
        is_open_clip_model,
        is_open_clip_sdxl_model,
        is_open_clip_sdxl_refiner_model,
        is_open_clip_sd3_model,
    )
    return any(detector(checkpoint) for detector in detectors)
def infer_diffusers_model_type(checkpoint):
    """
    Infer which model family a single-file `checkpoint` belongs to.

    Detection is purely key/shape based: the presence of signature keys from
    `CHECKPOINT_KEY_NAMES` (and, where needed, the shapes of their tensors)
    selects one of the keys of `DIFFUSERS_DEFAULT_PIPELINE_PATHS`. Branch
    order matters — more specific checks run first and "v1" is the fallback.

    Args:
        checkpoint (`dict`): A loaded single-file state dict.

    Returns:
        `str`: A model-type key of `DIFFUSERS_DEFAULT_PIPELINE_PATHS`.
    """
    # 9 UNet input channels mark an inpainting checkpoint.
    if (
        CHECKPOINT_KEY_NAMES["inpainting"] in checkpoint
        and checkpoint[CHECKPOINT_KEY_NAMES["inpainting"]].shape[1] == 9
    ):
        # A 1024-wide "v2" tensor distinguishes SD2 inpainting from SD1.
        if CHECKPOINT_KEY_NAMES["v2"] in checkpoint and checkpoint[CHECKPOINT_KEY_NAMES["v2"]].shape[-1] == 1024:
            model_type = "inpainting_v2"
        elif CHECKPOINT_KEY_NAMES["xl_base"] in checkpoint:
            model_type = "xl_inpaint"
        else:
            model_type = "inpainting"

    elif CHECKPOINT_KEY_NAMES["v2"] in checkpoint and checkpoint[CHECKPOINT_KEY_NAMES["v2"]].shape[-1] == 1024:
        model_type = "v2"

    # Playground must be checked before plain SDXL: its checkpoints carry an
    # extra marker key on top of the SDXL base keys.
    elif CHECKPOINT_KEY_NAMES["playground-v2-5"] in checkpoint:
        model_type = "playground-v2-5"

    elif CHECKPOINT_KEY_NAMES["xl_base"] in checkpoint:
        model_type = "xl_base"

    elif CHECKPOINT_KEY_NAMES["xl_refiner"] in checkpoint:
        model_type = "xl_refiner"

    elif CHECKPOINT_KEY_NAMES["upscale"] in checkpoint:
        model_type = "upscale"

    elif CHECKPOINT_KEY_NAMES["controlnet"] in checkpoint:
        model_type = "controlnet"

    # Stable Cascade lite vs. full variants share the same marker key; the
    # tensor's leading/last dimension tells them apart.
    elif (
        CHECKPOINT_KEY_NAMES["stable_cascade_stage_c"] in checkpoint
        and checkpoint[CHECKPOINT_KEY_NAMES["stable_cascade_stage_c"]].shape[0] == 1536
    ):
        model_type = "stable_cascade_stage_c_lite"

    elif (
        CHECKPOINT_KEY_NAMES["stable_cascade_stage_c"] in checkpoint
        and checkpoint[CHECKPOINT_KEY_NAMES["stable_cascade_stage_c"]].shape[0] == 2048
    ):
        model_type = "stable_cascade_stage_c"

    elif (
        CHECKPOINT_KEY_NAMES["stable_cascade_stage_b"] in checkpoint
        and checkpoint[CHECKPOINT_KEY_NAMES["stable_cascade_stage_b"]].shape[-1] == 576
    ):
        model_type = "stable_cascade_stage_b_lite"

    elif (
        CHECKPOINT_KEY_NAMES["stable_cascade_stage_b"] in checkpoint
        and checkpoint[CHECKPOINT_KEY_NAMES["stable_cascade_stage_b"]].shape[-1] == 640
    ):
        model_type = "stable_cascade_stage_b"

    elif CHECKPOINT_KEY_NAMES["sd3"] in checkpoint:
        model_type = "sd3"

    # AnimateDiff variants: disambiguate by extra keys first, then by shape.
    elif CHECKPOINT_KEY_NAMES["animatediff"] in checkpoint:
        if CHECKPOINT_KEY_NAMES["animatediff_scribble"] in checkpoint:
            model_type = "animatediff_scribble"

        elif CHECKPOINT_KEY_NAMES["animatediff_rgb"] in checkpoint:
            model_type = "animatediff_rgb"

        elif CHECKPOINT_KEY_NAMES["animatediff_v2"] in checkpoint:
            model_type = "animatediff_v2"

        elif checkpoint[CHECKPOINT_KEY_NAMES["animatediff_sdxl_beta"]].shape[-1] == 320:
            model_type = "animatediff_sdxl_beta"

        elif checkpoint[CHECKPOINT_KEY_NAMES["animatediff"]].shape[1] == 24:
            model_type = "animatediff_v1"

        else:
            model_type = "animatediff_v3"

    elif any(key in checkpoint for key in CHECKPOINT_KEY_NAMES["flux"]):
        # Flux dev carries guidance-embedding layers that schnell lacks.
        if any(
            g in checkpoint for g in ["guidance_in.in_layer.bias", "model.diffusion_model.guidance_in.in_layer.bias"]
        ):
            model_type = "flux-dev"
        else:
            model_type = "flux-schnell"
    else:
        # Default: assume a Stable Diffusion v1 checkpoint.
        model_type = "v1"

    return model_type
def fetch_diffusers_config(checkpoint):
    """Return the default Hub repo config for the checkpoint's inferred model type."""
    inferred_type = infer_diffusers_model_type(checkpoint)
    default_path = DIFFUSERS_DEFAULT_PIPELINE_PATHS[inferred_type]
    # Deep-copy so callers can mutate the result without corrupting the module-level table.
    return copy.deepcopy(default_path)
def set_image_size(checkpoint, image_size=None):
    """Return `image_size` when given (truthy), else the default for the inferred model type."""
    if image_size:
        return image_size
    return DIFFUSERS_TO_LDM_DEFAULT_IMAGE_SIZE_MAP[infer_diffusers_model_type(checkpoint)]
def conv_attn_to_linear(checkpoint):
    """Squeeze 1x1-conv attention projection weights to linear shape, in place."""
    qkv_suffixes = ("query.weight", "key.weight", "value.weight")
    for key in list(checkpoint.keys()):
        tail = ".".join(key.split(".")[-2:])
        if tail in qkv_suffixes:
            # (out, in, 1, 1) conv weight -> (out, in) linear weight
            if checkpoint[key].ndim > 2:
                checkpoint[key] = checkpoint[key][:, :, 0, 0]
        elif "proj_attn.weight" in key and checkpoint[key].ndim > 2:
            # (out, in, 1) conv weight -> (out, in) linear weight
            checkpoint[key] = checkpoint[key][:, :, 0]
def create_unet_diffusers_config_from_ldm(
    original_config, checkpoint, image_size=None, upcast_attention=None, num_in_channels=None
):
    """
    Creates a config for the diffusers `UNet2DConditionModel` based on the
    config of the LDM model.

    Args:
        original_config (`dict`): Parsed original LDM YAML config.
        checkpoint (`dict`): The loaded state dict; used to infer a default sample size.
        image_size (`int`, *optional*): Deprecated; overrides the inferred image size.
        upcast_attention (`bool`, *optional*): Deprecated; forwarded into the config if set.
        num_in_channels (`int`, *optional*): Deprecated; overrides the UNet input channels.

    Returns:
        `dict`: Keyword arguments for `UNet2DConditionModel`.
    """
    if image_size is not None:
        # Fix: missing space between the implicitly-concatenated string parts.
        deprecation_message = (
            "Configuring UNet2DConditionModel with the `image_size` argument to `from_single_file` "
            "is deprecated and will be ignored in future versions."
        )
        deprecate("image_size", "1.0.0", deprecation_message)

    image_size = set_image_size(checkpoint, image_size=image_size)

    if (
        "unet_config" in original_config["model"]["params"]
        and original_config["model"]["params"]["unet_config"] is not None
    ):
        unet_params = original_config["model"]["params"]["unet_config"]["params"]
    else:
        unet_params = original_config["model"]["params"]["network_config"]["params"]

    if num_in_channels is not None:
        deprecation_message = (
            "Configuring UNet2DConditionModel with the `num_in_channels` argument to `from_single_file` "
            "is deprecated and will be ignored in future versions."
        )
        # Fix: the deprecation was registered under "image_size" (copy-paste),
        # mislabeling the warning for this argument.
        deprecate("num_in_channels", "1.0.0", deprecation_message)
        in_channels = num_in_channels
    else:
        in_channels = unet_params["in_channels"]

    vae_params = original_config["model"]["params"]["first_stage_config"]["params"]["ddconfig"]
    block_out_channels = [unet_params["model_channels"] * mult for mult in unet_params["channel_mult"]]

    # Down blocks use cross-attention at the resolutions listed in
    # `attention_resolutions`; resolution doubles after each non-final block.
    down_block_types = []
    resolution = 1
    for i in range(len(block_out_channels)):
        block_type = "CrossAttnDownBlock2D" if resolution in unet_params["attention_resolutions"] else "DownBlock2D"
        down_block_types.append(block_type)
        if i != len(block_out_channels) - 1:
            resolution *= 2

    # Up blocks mirror the down path, halving the resolution each step.
    up_block_types = []
    for i in range(len(block_out_channels)):
        block_type = "CrossAttnUpBlock2D" if resolution in unet_params["attention_resolutions"] else "UpBlock2D"
        up_block_types.append(block_type)
        resolution //= 2

    if unet_params["transformer_depth"] is not None:
        transformer_layers_per_block = (
            unet_params["transformer_depth"]
            if isinstance(unet_params["transformer_depth"], int)
            else list(unet_params["transformer_depth"])
        )
    else:
        transformer_layers_per_block = 1

    vae_scale_factor = 2 ** (len(vae_params["ch_mult"]) - 1)

    head_dim = unet_params["num_heads"] if "num_heads" in unet_params else None
    use_linear_projection = (
        unet_params["use_linear_in_transformer"] if "use_linear_in_transformer" in unet_params else False
    )
    if use_linear_projection:
        # stable diffusion 2-base-512 and 2-768
        if head_dim is None:
            head_dim_mult = unet_params["model_channels"] // unet_params["num_head_channels"]
            head_dim = [head_dim_mult * c for c in list(unet_params["channel_mult"])]

    class_embed_type = None
    addition_embed_type = None
    addition_time_embed_dim = None
    projection_class_embeddings_input_dim = None
    context_dim = None

    if unet_params["context_dim"] is not None:
        context_dim = (
            unet_params["context_dim"]
            if isinstance(unet_params["context_dim"], int)
            else unet_params["context_dim"][0]
        )

    if "num_classes" in unet_params:
        if unet_params["num_classes"] == "sequential":
            if context_dim in [2048, 1280]:
                # SDXL
                addition_embed_type = "text_time"
                addition_time_embed_dim = 256
            else:
                class_embed_type = "projection"
                assert "adm_in_channels" in unet_params
                projection_class_embeddings_input_dim = unet_params["adm_in_channels"]

    config = {
        "sample_size": image_size // vae_scale_factor,
        "in_channels": in_channels,
        "down_block_types": down_block_types,
        "block_out_channels": block_out_channels,
        "layers_per_block": unet_params["num_res_blocks"],
        "cross_attention_dim": context_dim,
        "attention_head_dim": head_dim,
        "use_linear_projection": use_linear_projection,
        "class_embed_type": class_embed_type,
        "addition_embed_type": addition_embed_type,
        "addition_time_embed_dim": addition_time_embed_dim,
        "projection_class_embeddings_input_dim": projection_class_embeddings_input_dim,
        "transformer_layers_per_block": transformer_layers_per_block,
    }

    if upcast_attention is not None:
        deprecation_message = (
            "Configuring UNet2DConditionModel with the `upcast_attention` argument to `from_single_file` "
            "is deprecated and will be ignored in future versions."
        )
        # Fix: deprecation label was "image_size" (copy-paste) for this argument too.
        deprecate("upcast_attention", "1.0.0", deprecation_message)
        config["upcast_attention"] = upcast_attention

    if "disable_self_attentions" in unet_params:
        config["only_cross_attention"] = unet_params["disable_self_attentions"]

    if "num_classes" in unet_params and isinstance(unet_params["num_classes"], int):
        config["num_class_embeds"] = unet_params["num_classes"]

    config["out_channels"] = unet_params["out_channels"]
    config["up_block_types"] = up_block_types

    return config
def create_controlnet_diffusers_config_from_ldm(original_config, checkpoint, image_size=None, **kwargs):
    """
    Build a diffusers `ControlNetModel` config from an original LDM config.

    UNet-derived fields come from the regular UNet config conversion; only
    `conditioning_channels` is taken from the `control_stage_config` section.

    Args:
        original_config (`dict`): Parsed original LDM YAML config.
        checkpoint (`dict`): The loaded state dict.
        image_size (`int`, *optional*): Deprecated; overrides the inferred image size.
        **kwargs: Ignored; accepted for interface compatibility.

    Returns:
        `dict`: Keyword arguments for `ControlNetModel`.
    """
    if image_size is not None:
        # Fix: missing space between the implicitly-concatenated string parts.
        deprecation_message = (
            "Configuring ControlNetModel with the `image_size` argument "
            "is deprecated and will be ignored in future versions."
        )
        deprecate("image_size", "1.0.0", deprecation_message)

    image_size = set_image_size(checkpoint, image_size=image_size)

    unet_params = original_config["model"]["params"]["control_stage_config"]["params"]
    # Bug fix: `checkpoint` is a required positional argument of
    # `create_unet_diffusers_config_from_ldm`; it was previously omitted,
    # which made this call raise a TypeError.
    diffusers_unet_config = create_unet_diffusers_config_from_ldm(original_config, checkpoint, image_size=image_size)

    controlnet_config = {
        "conditioning_channels": unet_params["hint_channels"],
        "in_channels": diffusers_unet_config["in_channels"],
        "down_block_types": diffusers_unet_config["down_block_types"],
        "block_out_channels": diffusers_unet_config["block_out_channels"],
        "layers_per_block": diffusers_unet_config["layers_per_block"],
        "cross_attention_dim": diffusers_unet_config["cross_attention_dim"],
        "attention_head_dim": diffusers_unet_config["attention_head_dim"],
        "use_linear_projection": diffusers_unet_config["use_linear_projection"],
        "class_embed_type": diffusers_unet_config["class_embed_type"],
        "addition_embed_type": diffusers_unet_config["addition_embed_type"],
        "addition_time_embed_dim": diffusers_unet_config["addition_time_embed_dim"],
        "projection_class_embeddings_input_dim": diffusers_unet_config["projection_class_embeddings_input_dim"],
        "transformer_layers_per_block": diffusers_unet_config["transformer_layers_per_block"],
    }

    return controlnet_config
def create_vae_diffusers_config_from_ldm(original_config, checkpoint, image_size=None, scaling_factor=None):
    """
    Creates a config for the diffusers `AutoencoderKL` based on the config of
    the LDM model.

    Args:
        original_config (`dict`): Parsed original LDM YAML config.
        checkpoint (`dict`): The loaded state dict; EDM latent statistics, when
            present, are copied into the config.
        image_size (`int`, *optional*): Deprecated; overrides the inferred image size.
        scaling_factor (`float`, *optional*): Latent scaling factor; when omitted it
            is resolved from the checkpoint/config, falling back to the LDM default.

    Returns:
        `dict`: Keyword arguments for `AutoencoderKL`.
    """
    if image_size is not None:
        deprecation_message = (
            "Configuring AutoencoderKL with the `image_size` argument"
            "is deprecated and will be ignored in future versions."
        )
        deprecate("image_size", "1.0.0", deprecation_message)

    image_size = set_image_size(checkpoint, image_size=image_size)

    # EDM latent statistics (shipped by Playground-style checkpoints) also
    # select the Playground scaling factor below.
    if "edm_mean" in checkpoint and "edm_std" in checkpoint:
        latents_mean = checkpoint["edm_mean"]
        latents_std = checkpoint["edm_std"]
    else:
        latents_mean = None
        latents_std = None

    vae_params = original_config["model"]["params"]["first_stage_config"]["params"]["ddconfig"]
    # Scaling-factor resolution order: explicit arg > EDM stats present >
    # original config's scale_factor > LDM default.
    if (scaling_factor is None) and (latents_mean is not None) and (latents_std is not None):
        scaling_factor = PLAYGROUND_VAE_SCALING_FACTOR

    elif (scaling_factor is None) and ("scale_factor" in original_config["model"]["params"]):
        scaling_factor = original_config["model"]["params"]["scale_factor"]

    elif scaling_factor is None:
        scaling_factor = LDM_VAE_DEFAULT_SCALING_FACTOR

    block_out_channels = [vae_params["ch"] * mult for mult in vae_params["ch_mult"]]
    down_block_types = ["DownEncoderBlock2D"] * len(block_out_channels)
    up_block_types = ["UpDecoderBlock2D"] * len(block_out_channels)

    config = {
        "sample_size": image_size,
        "in_channels": vae_params["in_channels"],
        "out_channels": vae_params["out_ch"],
        "down_block_types": down_block_types,
        "up_block_types": up_block_types,
        "block_out_channels": block_out_channels,
        "latent_channels": vae_params["z_channels"],
        "layers_per_block": vae_params["num_res_blocks"],
        "scaling_factor": scaling_factor,
    }
    if latents_mean is not None and latents_std is not None:
        config.update({"latents_mean": latents_mean, "latents_std": latents_std})

    return config
def update_unet_resnet_ldm_to_diffusers(ldm_keys, new_checkpoint, checkpoint, mapping=None):
    """Copy UNet resnet weights from LDM naming into diffusers naming, in place."""
    # LDM layer-name fragments and their diffusers equivalents.
    fragment_renames = (
        ("in_layers.0", "norm1"),
        ("in_layers.2", "conv1"),
        ("out_layers.0", "norm2"),
        ("out_layers.3", "conv2"),
        ("emb_layers.1", "time_emb_proj"),
        ("skip_connection", "conv_shortcut"),
    )
    for ldm_key in ldm_keys:
        diffusers_key = ldm_key
        for old_fragment, new_fragment in fragment_renames:
            diffusers_key = diffusers_key.replace(old_fragment, new_fragment)
        # Optionally swap the block prefix (e.g. input_blocks.N -> down_blocks.M).
        if mapping:
            diffusers_key = diffusers_key.replace(mapping["old"], mapping["new"])
        new_checkpoint[diffusers_key] = checkpoint.get(ldm_key)
def update_unet_attention_ldm_to_diffusers(ldm_keys, new_checkpoint, checkpoint, mapping):
    """Copy UNet attention weights into diffusers naming by swapping the block prefix, in place."""
    old_prefix, new_prefix = mapping["old"], mapping["new"]
    for ldm_key in ldm_keys:
        new_checkpoint[ldm_key.replace(old_prefix, new_prefix)] = checkpoint.get(ldm_key)
def update_vae_resnet_ldm_to_diffusers(keys, new_checkpoint, checkpoint, mapping):
    """Copy VAE resnet weights into diffusers naming, in place."""
    old_prefix, new_prefix = mapping["old"], mapping["new"]
    for ldm_key in keys:
        renamed = ldm_key.replace(old_prefix, new_prefix)
        # LDM names the 1x1 shortcut conv "nin_shortcut"; diffusers uses "conv_shortcut".
        new_checkpoint[renamed.replace("nin_shortcut", "conv_shortcut")] = checkpoint.get(ldm_key)
def update_vae_attentions_ldm_to_diffusers(keys, new_checkpoint, checkpoint, mapping):
    """Copy VAE attention weights into diffusers naming, squeezing 1x1-conv projections, in place."""
    # Applied in order: first the block-prefix swap, then the per-layer renames.
    fragment_renames = (
        (mapping["old"], mapping["new"]),
        ("norm.weight", "group_norm.weight"),
        ("norm.bias", "group_norm.bias"),
        ("q.weight", "to_q.weight"),
        ("q.bias", "to_q.bias"),
        ("k.weight", "to_k.weight"),
        ("k.bias", "to_k.bias"),
        ("v.weight", "to_v.weight"),
        ("v.bias", "to_v.bias"),
        ("proj_out.weight", "to_out.0.weight"),
        ("proj_out.bias", "to_out.0.bias"),
    )
    for ldm_key in keys:
        diffusers_key = ldm_key
        for old_fragment, new_fragment in fragment_renames:
            diffusers_key = diffusers_key.replace(old_fragment, new_fragment)
        value = checkpoint.get(ldm_key)

        # proj_attn.weight has to be converted from conv 1D to linear:
        # (out, in, 1) or (out, in, 1, 1) conv weights become (out, in).
        dims = len(value.shape)
        if dims == 3:
            value = value[:, :, 0]
        elif dims == 4:
            value = value[:, :, 0, 0]
        new_checkpoint[diffusers_key] = value
def convert_stable_cascade_unet_single_file_to_diffusers(checkpoint, **kwargs):
    """
    Convert a single-file Stable Cascade UNet state dict to diffusers naming.

    Fused attention projections (`attn.in_proj_*`) are split into separate
    `to_q`/`to_k`/`to_v` tensors and `attn.out_proj` is renamed to `to_out.0`.
    Stage B checkpoints additionally rename `clip_mapper` to
    `clip_txt_pooled_mapper`; stage C checkpoints (identified by the presence
    of `clip_txt_mapper.weight`) keep their clip-mapper keys unchanged.

    Args:
        checkpoint (`dict`): The original single-file state dict.
        **kwargs: Ignored; accepted for interface compatibility with other converters.

    Returns:
        `dict`: The converted state dict.

    Note: the original implementation duplicated the attention handling in two
    near-identical branches; they are merged here, with the stage-B-only
    clip_mapper rename guarded by `not is_stage_c`.
    """
    # Stage C (prior) checkpoints carry `clip_txt_mapper.weight`; stage B decoders do not.
    is_stage_c = "clip_txt_mapper.weight" in checkpoint

    state_dict = {}
    for key in checkpoint.keys():
        if key.endswith("in_proj_weight"):
            # Split the fused QKV projection weight into three equal chunks along dim 0.
            weights = checkpoint[key].chunk(3, 0)
            state_dict[key.replace("attn.in_proj_weight", "to_q.weight")] = weights[0]
            state_dict[key.replace("attn.in_proj_weight", "to_k.weight")] = weights[1]
            state_dict[key.replace("attn.in_proj_weight", "to_v.weight")] = weights[2]
        elif key.endswith("in_proj_bias"):
            weights = checkpoint[key].chunk(3, 0)
            state_dict[key.replace("attn.in_proj_bias", "to_q.bias")] = weights[0]
            state_dict[key.replace("attn.in_proj_bias", "to_k.bias")] = weights[1]
            state_dict[key.replace("attn.in_proj_bias", "to_v.bias")] = weights[2]
        elif key.endswith("out_proj.weight"):
            state_dict[key.replace("attn.out_proj.weight", "to_out.0.weight")] = checkpoint[key]
        elif key.endswith("out_proj.bias"):
            state_dict[key.replace("attn.out_proj.bias", "to_out.0.bias")] = checkpoint[key]
        # rename clip_mapper to clip_txt_pooled_mapper (stage B only)
        elif not is_stage_c and key.endswith("clip_mapper.weight"):
            state_dict[key.replace("clip_mapper.weight", "clip_txt_pooled_mapper.weight")] = checkpoint[key]
        elif not is_stage_c and key.endswith("clip_mapper.bias"):
            state_dict[key.replace("clip_mapper.bias", "clip_txt_pooled_mapper.bias")] = checkpoint[key]
        else:
            state_dict[key] = checkpoint[key]

    return state_dict
def convert_ldm_unet_checkpoint(checkpoint, config, extract_ema=False, **kwargs):
    """
    Takes a state dict and a config, and returns a converted checkpoint.

    Remaps LDM UNet keys (``input_blocks`` / ``middle_block`` / ``output_blocks``) to the
    diffusers layout (``down_blocks`` / ``mid_block`` / ``up_blocks``).

    Args:
        checkpoint: Full LDM state dict; UNet weights are expected under the
            ``LDM_UNET_KEY`` prefix (or flattened ``model_ema.*`` keys for EMA weights).
        config: Diffusers UNet config dict; ``layers_per_block`` drives the
            flat-index -> (block, layer) arithmetic below.
        extract_ema: If True and the checkpoint contains EMA weights, extract those
            instead of the non-EMA training weights.

    Returns:
        dict: state dict with diffusers-style keys.
    """
    # extract state_dict for UNet
    unet_state_dict = {}
    keys = list(checkpoint.keys())
    unet_key = LDM_UNET_KEY

    # at least a 100 parameters have to start with `model_ema` in order for the checkpoint to be EMA
    if sum(k.startswith("model_ema") for k in keys) > 100 and extract_ema:
        logger.warning("Checkpoint has both EMA and non-EMA weights.")
        logger.warning(
            "In this conversion only the EMA weights are extracted. If you want to instead extract the non-EMA"
            " weights (useful to continue fine-tuning), please make sure to remove the `--extract_ema` flag."
        )
        for key in keys:
            if key.startswith("model.diffusion_model"):
                # EMA keys are stored flattened: the dots after the leading "model." are dropped.
                flat_ema_key = "model_ema." + "".join(key.split(".")[1:])
                unet_state_dict[key.replace(unet_key, "")] = checkpoint.get(flat_ema_key)
    else:
        if sum(k.startswith("model_ema") for k in keys) > 100:
            logger.warning(
                "In this conversion only the non-EMA weights are extracted. If you want to instead extract the EMA"
                " weights (usually better for inference), please make sure to add the `--extract_ema` flag."
            )
        for key in keys:
            if key.startswith(unet_key):
                unet_state_dict[key.replace(unet_key, "")] = checkpoint.get(key)

    new_checkpoint = {}
    # Direct one-to-one key renames (time embedding, conv_in/out, norms, ...).
    ldm_unet_keys = DIFFUSERS_TO_LDM_MAPPING["unet"]["layers"]
    for diffusers_key, ldm_key in ldm_unet_keys.items():
        if ldm_key not in unet_state_dict:
            continue
        new_checkpoint[diffusers_key] = unet_state_dict[ldm_key]

    if ("class_embed_type" in config) and (config["class_embed_type"] in ["timestep", "projection"]):
        class_embed_keys = DIFFUSERS_TO_LDM_MAPPING["unet"]["class_embed_type"]
        for diffusers_key, ldm_key in class_embed_keys.items():
            new_checkpoint[diffusers_key] = unet_state_dict[ldm_key]

    if ("addition_embed_type" in config) and (config["addition_embed_type"] == "text_time"):
        addition_embed_keys = DIFFUSERS_TO_LDM_MAPPING["unet"]["addition_embed_type"]
        for diffusers_key, ldm_key in addition_embed_keys.items():
            new_checkpoint[diffusers_key] = unet_state_dict[ldm_key]

    # Relevant to StableDiffusionUpscalePipeline
    if "num_class_embeds" in config:
        if (config["num_class_embeds"] is not None) and ("label_emb.weight" in unet_state_dict):
            new_checkpoint["class_embedding.weight"] = unet_state_dict["label_emb.weight"]

    # Retrieves the keys for the input blocks only
    num_input_blocks = len({".".join(layer.split(".")[:2]) for layer in unet_state_dict if "input_blocks" in layer})
    input_blocks = {
        layer_id: [key for key in unet_state_dict if f"input_blocks.{layer_id}" in key]
        for layer_id in range(num_input_blocks)
    }

    # Retrieves the keys for the middle blocks only
    num_middle_blocks = len({".".join(layer.split(".")[:2]) for layer in unet_state_dict if "middle_block" in layer})
    middle_blocks = {
        layer_id: [key for key in unet_state_dict if f"middle_block.{layer_id}" in key]
        for layer_id in range(num_middle_blocks)
    }

    # Retrieves the keys for the output blocks only
    num_output_blocks = len({".".join(layer.split(".")[:2]) for layer in unet_state_dict if "output_blocks" in layer})
    output_blocks = {
        layer_id: [key for key in unet_state_dict if f"output_blocks.{layer_id}" in key]
        for layer_id in range(num_output_blocks)
    }

    # Down blocks
    # input_blocks.0 is the stem conv (handled by the direct renames above), so start at 1.
    # Each down stage occupies `layers_per_block` resnet slots plus one downsampler slot
    # in the flat LDM list, hence the // and % arithmetic below.
    for i in range(1, num_input_blocks):
        block_id = (i - 1) // (config["layers_per_block"] + 1)
        layer_in_block_id = (i - 1) % (config["layers_per_block"] + 1)

        # Sub-index 0 holds the resnet; ".op" is the downsampler conv, handled separately below.
        resnets = [
            key for key in input_blocks[i] if f"input_blocks.{i}.0" in key and f"input_blocks.{i}.0.op" not in key
        ]
        update_unet_resnet_ldm_to_diffusers(
            resnets,
            new_checkpoint,
            unet_state_dict,
            {"old": f"input_blocks.{i}.0", "new": f"down_blocks.{block_id}.resnets.{layer_in_block_id}"},
        )

        if f"input_blocks.{i}.0.op.weight" in unet_state_dict:
            new_checkpoint[f"down_blocks.{block_id}.downsamplers.0.conv.weight"] = unet_state_dict.get(
                f"input_blocks.{i}.0.op.weight"
            )
            new_checkpoint[f"down_blocks.{block_id}.downsamplers.0.conv.bias"] = unet_state_dict.get(
                f"input_blocks.{i}.0.op.bias"
            )

        # Sub-index 1, when present, is the attention/transformer block.
        attentions = [key for key in input_blocks[i] if f"input_blocks.{i}.1" in key]
        if attentions:
            update_unet_attention_ldm_to_diffusers(
                attentions,
                new_checkpoint,
                unet_state_dict,
                {"old": f"input_blocks.{i}.1", "new": f"down_blocks.{block_id}.attentions.{layer_in_block_id}"},
            )

    # Mid blocks
    # middle_block alternates [resnet, attention, resnet, ...]: even indices are resnets,
    # odd ones attentions; max(key - 1, 0) collapses both sequences to 0-based indices.
    for key in middle_blocks.keys():
        diffusers_key = max(key - 1, 0)
        if key % 2 == 0:
            update_unet_resnet_ldm_to_diffusers(
                middle_blocks[key],
                new_checkpoint,
                unet_state_dict,
                mapping={"old": f"middle_block.{key}", "new": f"mid_block.resnets.{diffusers_key}"},
            )
        else:
            update_unet_attention_ldm_to_diffusers(
                middle_blocks[key],
                new_checkpoint,
                unet_state_dict,
                mapping={"old": f"middle_block.{key}", "new": f"mid_block.attentions.{diffusers_key}"},
            )

    # Up Blocks
    for i in range(num_output_blocks):
        block_id = i // (config["layers_per_block"] + 1)
        layer_in_block_id = i % (config["layers_per_block"] + 1)

        resnets = [
            key for key in output_blocks[i] if f"output_blocks.{i}.0" in key and f"output_blocks.{i}.0.op" not in key
        ]
        update_unet_resnet_ldm_to_diffusers(
            resnets,
            new_checkpoint,
            unet_state_dict,
            {"old": f"output_blocks.{i}.0", "new": f"up_blocks.{block_id}.resnets.{layer_in_block_id}"},
        )

        attentions = [
            key for key in output_blocks[i] if f"output_blocks.{i}.1" in key and f"output_blocks.{i}.1.conv" not in key
        ]
        if attentions:
            update_unet_attention_ldm_to_diffusers(
                attentions,
                new_checkpoint,
                unet_state_dict,
                {"old": f"output_blocks.{i}.1", "new": f"up_blocks.{block_id}.attentions.{layer_in_block_id}"},
            )

        # The upsampler conv can sit at sub-index 1 (no attention present) or at
        # sub-index 2 (after an attention block); check both locations.
        if f"output_blocks.{i}.1.conv.weight" in unet_state_dict:
            new_checkpoint[f"up_blocks.{block_id}.upsamplers.0.conv.weight"] = unet_state_dict[
                f"output_blocks.{i}.1.conv.weight"
            ]
            new_checkpoint[f"up_blocks.{block_id}.upsamplers.0.conv.bias"] = unet_state_dict[
                f"output_blocks.{i}.1.conv.bias"
            ]
        if f"output_blocks.{i}.2.conv.weight" in unet_state_dict:
            new_checkpoint[f"up_blocks.{block_id}.upsamplers.0.conv.weight"] = unet_state_dict[
                f"output_blocks.{i}.2.conv.weight"
            ]
            new_checkpoint[f"up_blocks.{block_id}.upsamplers.0.conv.bias"] = unet_state_dict[
                f"output_blocks.{i}.2.conv.bias"
            ]

    return new_checkpoint
def convert_controlnet_checkpoint(
    checkpoint,
    config,
    **kwargs,
):
    """
    Convert an LDM ControlNet state dict to the diffusers key layout.

    Args:
        checkpoint: ControlNet state dict, either standalone (keys start at
            ``time_embed.0.weight``) or embedded under the ``LDM_CONTROLNET_KEY`` prefix.
        config: Diffusers ControlNet config dict; ``layers_per_block`` drives the
            flat-index -> (block, layer) arithmetic.

    Returns:
        dict: state dict with diffusers-style keys.
    """
    # Some controlnet ckpt files are distributed independently from the rest of the
    # model components i.e. https://huggingface.co/thibaud/controlnet-sd21/
    if "time_embed.0.weight" in checkpoint:
        controlnet_state_dict = checkpoint

    else:
        controlnet_state_dict = {}
        keys = list(checkpoint.keys())
        controlnet_key = LDM_CONTROLNET_KEY
        for key in keys:
            if key.startswith(controlnet_key):
                controlnet_state_dict[key.replace(controlnet_key, "")] = checkpoint.get(key)

    new_checkpoint = {}
    # Direct one-to-one key renames.
    ldm_controlnet_keys = DIFFUSERS_TO_LDM_MAPPING["controlnet"]["layers"]
    for diffusers_key, ldm_key in ldm_controlnet_keys.items():
        if ldm_key not in controlnet_state_dict:
            continue
        new_checkpoint[diffusers_key] = controlnet_state_dict[ldm_key]

    # Retrieves the keys for the input blocks only
    num_input_blocks = len(
        {".".join(layer.split(".")[:2]) for layer in controlnet_state_dict if "input_blocks" in layer}
    )
    input_blocks = {
        layer_id: [key for key in controlnet_state_dict if f"input_blocks.{layer_id}" in key]
        for layer_id in range(num_input_blocks)
    }

    # Down blocks
    # input_blocks.0 is the stem conv, so start at 1; each stage holds
    # `layers_per_block` resnet slots plus one downsampler slot in the flat list.
    for i in range(1, num_input_blocks):
        block_id = (i - 1) // (config["layers_per_block"] + 1)
        layer_in_block_id = (i - 1) % (config["layers_per_block"] + 1)

        # Sub-index 0 holds the resnet; ".op" is the downsampler conv, handled below.
        resnets = [
            key for key in input_blocks[i] if f"input_blocks.{i}.0" in key and f"input_blocks.{i}.0.op" not in key
        ]
        update_unet_resnet_ldm_to_diffusers(
            resnets,
            new_checkpoint,
            controlnet_state_dict,
            {"old": f"input_blocks.{i}.0", "new": f"down_blocks.{block_id}.resnets.{layer_in_block_id}"},
        )

        if f"input_blocks.{i}.0.op.weight" in controlnet_state_dict:
            new_checkpoint[f"down_blocks.{block_id}.downsamplers.0.conv.weight"] = controlnet_state_dict.get(
                f"input_blocks.{i}.0.op.weight"
            )
            new_checkpoint[f"down_blocks.{block_id}.downsamplers.0.conv.bias"] = controlnet_state_dict.get(
                f"input_blocks.{i}.0.op.bias"
            )

        # Sub-index 1, when present, is the attention/transformer block.
        attentions = [key for key in input_blocks[i] if f"input_blocks.{i}.1" in key]
        if attentions:
            update_unet_attention_ldm_to_diffusers(
                attentions,
                new_checkpoint,
                controlnet_state_dict,
                {"old": f"input_blocks.{i}.1", "new": f"down_blocks.{block_id}.attentions.{layer_in_block_id}"},
            )

    # controlnet down blocks
    # One zero-conv per input block: zero_convs.N.0 -> controlnet_down_blocks.N
    for i in range(num_input_blocks):
        new_checkpoint[f"controlnet_down_blocks.{i}.weight"] = controlnet_state_dict.get(f"zero_convs.{i}.0.weight")
        new_checkpoint[f"controlnet_down_blocks.{i}.bias"] = controlnet_state_dict.get(f"zero_convs.{i}.0.bias")

    # Retrieves the keys for the middle blocks only
    num_middle_blocks = len(
        {".".join(layer.split(".")[:2]) for layer in controlnet_state_dict if "middle_block" in layer}
    )
    middle_blocks = {
        layer_id: [key for key in controlnet_state_dict if f"middle_block.{layer_id}" in key]
        for layer_id in range(num_middle_blocks)
    }

    # Mid blocks: even indices are resnets, odd ones attentions (same layout as the UNet).
    for key in middle_blocks.keys():
        diffusers_key = max(key - 1, 0)
        if key % 2 == 0:
            update_unet_resnet_ldm_to_diffusers(
                middle_blocks[key],
                new_checkpoint,
                controlnet_state_dict,
                mapping={"old": f"middle_block.{key}", "new": f"mid_block.resnets.{diffusers_key}"},
            )
        else:
            update_unet_attention_ldm_to_diffusers(
                middle_blocks[key],
                new_checkpoint,
                controlnet_state_dict,
                mapping={"old": f"middle_block.{key}", "new": f"mid_block.attentions.{diffusers_key}"},
            )

    # mid block
    new_checkpoint["controlnet_mid_block.weight"] = controlnet_state_dict.get("middle_block_out.0.weight")
    new_checkpoint["controlnet_mid_block.bias"] = controlnet_state_dict.get("middle_block_out.0.bias")

    # controlnet cond embedding blocks
    # input_hint_block.0 and input_hint_block.14 are excluded here (presumably covered by
    # the direct rename map above — TODO confirm); the remaining weight-bearing layers sit
    # at even indices, hence `2 * idx` below.
    cond_embedding_blocks = {
        ".".join(layer.split(".")[:2])
        for layer in controlnet_state_dict
        if "input_hint_block" in layer and ("input_hint_block.0" not in layer) and ("input_hint_block.14" not in layer)
    }
    num_cond_embedding_blocks = len(cond_embedding_blocks)

    for idx in range(1, num_cond_embedding_blocks + 1):
        diffusers_idx = idx - 1
        cond_block_id = 2 * idx

        new_checkpoint[f"controlnet_cond_embedding.blocks.{diffusers_idx}.weight"] = controlnet_state_dict.get(
            f"input_hint_block.{cond_block_id}.weight"
        )
        new_checkpoint[f"controlnet_cond_embedding.blocks.{diffusers_idx}.bias"] = controlnet_state_dict.get(
            f"input_hint_block.{cond_block_id}.bias"
        )

    return new_checkpoint
def convert_ldm_vae_checkpoint(checkpoint, config):
    """
    Convert an LDM VAE state dict to the diffusers AutoencoderKL key layout.

    Args:
        checkpoint: state dict, optionally prefixed with one of ``LDM_VAE_KEYS``.
        config: Diffusers VAE config dict; ``down_block_types`` / ``up_block_types``
            give the encoder/decoder block counts.

    Returns:
        dict: state dict with diffusers-style keys.
    """
    # extract state dict for VAE
    # remove the LDM_VAE_KEY prefix from the ldm checkpoint keys so that it is easier to map them to diffusers keys
    vae_state_dict = {}
    keys = list(checkpoint.keys())
    vae_key = ""
    for ldm_vae_key in LDM_VAE_KEYS:
        if any(k.startswith(ldm_vae_key) for k in keys):
            vae_key = ldm_vae_key

    # If no known prefix matched, vae_key stays "" and every key is kept as-is.
    for key in keys:
        if key.startswith(vae_key):
            vae_state_dict[key.replace(vae_key, "")] = checkpoint.get(key)

    new_checkpoint = {}
    # Direct one-to-one key renames (conv_in/out, norms, quant convs, ...).
    vae_diffusers_ldm_map = DIFFUSERS_TO_LDM_MAPPING["vae"]
    for diffusers_key, ldm_key in vae_diffusers_ldm_map.items():
        if ldm_key not in vae_state_dict:
            continue
        new_checkpoint[diffusers_key] = vae_state_dict[ldm_key]

    # Retrieves the keys for the encoder down blocks only
    num_down_blocks = len(config["down_block_types"])
    down_blocks = {
        layer_id: [key for key in vae_state_dict if f"down.{layer_id}" in key] for layer_id in range(num_down_blocks)
    }

    for i in range(num_down_blocks):
        # Resnets live under down.{i}.block; the downsample conv is handled separately.
        resnets = [key for key in down_blocks[i] if f"down.{i}" in key and f"down.{i}.downsample" not in key]
        update_vae_resnet_ldm_to_diffusers(
            resnets,
            new_checkpoint,
            vae_state_dict,
            mapping={"old": f"down.{i}.block", "new": f"down_blocks.{i}.resnets"},
        )
        if f"encoder.down.{i}.downsample.conv.weight" in vae_state_dict:
            new_checkpoint[f"encoder.down_blocks.{i}.downsamplers.0.conv.weight"] = vae_state_dict.get(
                f"encoder.down.{i}.downsample.conv.weight"
            )
            new_checkpoint[f"encoder.down_blocks.{i}.downsamplers.0.conv.bias"] = vae_state_dict.get(
                f"encoder.down.{i}.downsample.conv.bias"
            )

    # Encoder mid block: two resnets (block_1, block_2 -> resnets.0, resnets.1).
    mid_resnets = [key for key in vae_state_dict if "encoder.mid.block" in key]
    num_mid_res_blocks = 2
    for i in range(1, num_mid_res_blocks + 1):
        resnets = [key for key in mid_resnets if f"encoder.mid.block_{i}" in key]
        update_vae_resnet_ldm_to_diffusers(
            resnets,
            new_checkpoint,
            vae_state_dict,
            mapping={"old": f"mid.block_{i}", "new": f"mid_block.resnets.{i - 1}"},
        )

    mid_attentions = [key for key in vae_state_dict if "encoder.mid.attn" in key]
    update_vae_attentions_ldm_to_diffusers(
        mid_attentions, new_checkpoint, vae_state_dict, mapping={"old": "mid.attn_1", "new": "mid_block.attentions.0"}
    )

    # Retrieves the keys for the decoder up blocks only
    num_up_blocks = len(config["up_block_types"])
    up_blocks = {
        layer_id: [key for key in vae_state_dict if f"up.{layer_id}" in key] for layer_id in range(num_up_blocks)
    }

    for i in range(num_up_blocks):
        # LDM enumerates decoder up blocks in reverse order relative to diffusers.
        block_id = num_up_blocks - 1 - i
        resnets = [
            key for key in up_blocks[block_id] if f"up.{block_id}" in key and f"up.{block_id}.upsample" not in key
        ]
        update_vae_resnet_ldm_to_diffusers(
            resnets,
            new_checkpoint,
            vae_state_dict,
            mapping={"old": f"up.{block_id}.block", "new": f"up_blocks.{i}.resnets"},
        )
        if f"decoder.up.{block_id}.upsample.conv.weight" in vae_state_dict:
            new_checkpoint[f"decoder.up_blocks.{i}.upsamplers.0.conv.weight"] = vae_state_dict[
                f"decoder.up.{block_id}.upsample.conv.weight"
            ]
            new_checkpoint[f"decoder.up_blocks.{i}.upsamplers.0.conv.bias"] = vae_state_dict[
                f"decoder.up.{block_id}.upsample.conv.bias"
            ]

    # Decoder mid block: same two-resnet + one-attention layout as the encoder.
    mid_resnets = [key for key in vae_state_dict if "decoder.mid.block" in key]
    num_mid_res_blocks = 2
    for i in range(1, num_mid_res_blocks + 1):
        resnets = [key for key in mid_resnets if f"decoder.mid.block_{i}" in key]
        update_vae_resnet_ldm_to_diffusers(
            resnets,
            new_checkpoint,
            vae_state_dict,
            mapping={"old": f"mid.block_{i}", "new": f"mid_block.resnets.{i - 1}"},
        )

    mid_attentions = [key for key in vae_state_dict if "decoder.mid.attn" in key]
    update_vae_attentions_ldm_to_diffusers(
        mid_attentions, new_checkpoint, vae_state_dict, mapping={"old": "mid.attn_1", "new": "mid_block.attentions.0"}
    )
    # LDM stores attention projections as 1x1 convs; diffusers uses linear layers.
    conv_attn_to_linear(new_checkpoint)

    return new_checkpoint
def convert_ldm_clip_checkpoint(checkpoint, remove_prefix=None):
    """
    Extract CLIP text-encoder weights from an LDM checkpoint by stripping known prefixes.

    Args:
        checkpoint: source state dict.
        remove_prefix: optional extra prefix to strip in addition to
            ``LDM_CLIP_PREFIX_TO_REMOVE``.

    Returns:
        dict: state dict keyed by the prefix-stripped names; keys that match no
        prefix are dropped.
    """
    prefixes = list(LDM_CLIP_PREFIX_TO_REMOVE)
    if remove_prefix:
        prefixes.append(remove_prefix)

    text_model_dict = {}
    for key in list(checkpoint.keys()):
        for prefix in prefixes:
            if not key.startswith(prefix):
                continue
            # str.replace (not removeprefix) is intentional: it mirrors the original
            # behavior of replacing every occurrence of the prefix substring.
            text_model_dict[key.replace(prefix, "")] = checkpoint.get(key)

    return text_model_dict
def convert_open_clip_checkpoint(
    text_model,
    checkpoint,
    prefix="cond_stage_model.model.",
):
    """
    Convert an OpenCLIP text-encoder state dict to the transformers CLIP layout.

    Args:
        text_model: instantiated CLIP text model, used to infer ``projection_dim``
            when the checkpoint has no ``text_projection`` tensor.
        checkpoint: source state dict.
        prefix: key prefix the OpenCLIP weights live under in ``checkpoint``.

    Returns:
        dict: converted state dict (fused qkv projections split into q/k/v).
    """
    text_model_dict = {}
    text_proj_key = prefix + "text_projection"

    # Determine the projection width. It is also used below to slice the fused
    # in_proj qkv tensors — this assumes hidden size == projection dim, which holds
    # for the OpenCLIP text towers handled here (TODO confirm for new variants).
    if text_proj_key in checkpoint:
        text_proj_dim = int(checkpoint[text_proj_key].shape[0])
    elif hasattr(text_model.config, "projection_dim"):
        text_proj_dim = text_model.config.projection_dim
    else:
        text_proj_dim = LDM_OPEN_CLIP_TEXT_PROJECTION_DIM

    keys = list(checkpoint.keys())
    keys_to_ignore = SD_2_TEXT_ENCODER_KEYS_TO_IGNORE

    # Direct one-to-one key renames (embeddings, final norm, text projection).
    openclip_diffusers_ldm_map = DIFFUSERS_TO_LDM_MAPPING["openclip"]["layers"]
    for diffusers_key, ldm_key in openclip_diffusers_ldm_map.items():
        ldm_key = prefix + ldm_key
        if ldm_key not in checkpoint:
            continue
        if ldm_key in keys_to_ignore:
            continue
        if ldm_key.endswith("text_projection"):
            # OpenCLIP stores the projection transposed relative to transformers.
            text_model_dict[diffusers_key] = checkpoint[ldm_key].T.contiguous()
        else:
            text_model_dict[diffusers_key] = checkpoint[ldm_key]

    # Hoisted out of the per-key loop below: this mapping is loop-invariant.
    transformer_diffusers_to_ldm_map = DIFFUSERS_TO_LDM_MAPPING["openclip"]["transformer"]

    for key in keys:
        if key in keys_to_ignore:
            continue

        if not key.startswith(prefix + "transformer."):
            continue

        diffusers_key = key.replace(prefix + "transformer.", "")
        for new_key, old_key in transformer_diffusers_to_ldm_map.items():
            diffusers_key = (
                diffusers_key.replace(old_key, new_key).replace(".in_proj_weight", "").replace(".in_proj_bias", "")
            )

        if key.endswith(".in_proj_weight"):
            # Split the fused qkv projection weight into separate q/k/v tensors.
            weight_value = checkpoint.get(key)

            text_model_dict[diffusers_key + ".q_proj.weight"] = weight_value[:text_proj_dim, :].clone().detach()
            text_model_dict[diffusers_key + ".k_proj.weight"] = (
                weight_value[text_proj_dim : text_proj_dim * 2, :].clone().detach()
            )
            text_model_dict[diffusers_key + ".v_proj.weight"] = weight_value[text_proj_dim * 2 :, :].clone().detach()

        elif key.endswith(".in_proj_bias"):
            # Same split for the fused qkv bias.
            weight_value = checkpoint.get(key)
            text_model_dict[diffusers_key + ".q_proj.bias"] = weight_value[:text_proj_dim].clone().detach()
            text_model_dict[diffusers_key + ".k_proj.bias"] = (
                weight_value[text_proj_dim : text_proj_dim * 2].clone().detach()
            )
            text_model_dict[diffusers_key + ".v_proj.bias"] = weight_value[text_proj_dim * 2 :].clone().detach()
        else:
            text_model_dict[diffusers_key] = checkpoint.get(key)

    return text_model_dict
def create_diffusers_clip_model_from_ldm(
    cls,
    checkpoint,
    subfolder="",
    config=None,
    torch_dtype=None,
    local_files_only=None,
    is_legacy_loading=False,
):
    """
    Instantiate CLIP model class `cls` and populate it with weights converted from a
    single-file LDM `checkpoint`.

    The checkpoint flavor (CLIP / OpenCLIP, SD / SDXL / SD3 / refiner) is detected via
    the `is_*_model` predicates and dispatched to the matching converter.

    Args:
        cls: transformers model class to instantiate (must expose `config_class`).
        checkpoint: single-file state dict.
        subfolder: config subfolder; overridden below based on the detected flavor.
        config: optional repo id to fetch the model config from.
        torch_dtype: optional dtype to load/cast the weights to.
        local_files_only: forwarded to `from_pretrained`.
        is_legacy_loading: emit a warning for the pre-subfolder cache layout.

    Returns:
        The populated model in eval mode.

    Raises:
        ValueError: if no CLIP variant can be detected in the checkpoint.
    """
    if config:
        config = {"pretrained_model_name_or_path": config}
    else:
        config = fetch_diffusers_config(checkpoint)

    # For backwards compatibility
    # Older versions of `from_single_file` expected CLIP configs to be placed in their original transformers model repo
    # in the cache_dir, rather than in a subfolder of the Diffusers model
    if is_legacy_loading:
        logger.warning(
            (
                "Detected legacy CLIP loading behavior. Please run `from_single_file` with `local_files_only=False once to update "
                "the local cache directory with the necessary CLIP model config files. "
                "Attempting to load CLIP model from legacy cache directory."
            )
        )

    # Pick the config repo matching the detected flavor; this overrides whatever
    # `pretrained_model_name_or_path` fetch_diffusers_config returned.
    if is_clip_model(checkpoint) or is_clip_sdxl_model(checkpoint):
        clip_config = "openai/clip-vit-large-patch14"
        config["pretrained_model_name_or_path"] = clip_config
        subfolder = ""

    elif is_open_clip_model(checkpoint):
        clip_config = "stabilityai/stable-diffusion-2"
        config["pretrained_model_name_or_path"] = clip_config
        subfolder = "text_encoder"

    else:
        clip_config = "laion/CLIP-ViT-bigG-14-laion2B-39B-b160k"
        config["pretrained_model_name_or_path"] = clip_config
        subfolder = ""

    model_config = cls.config_class.from_pretrained(**config, subfolder=subfolder, local_files_only=local_files_only)
    # Build the model on the meta device when accelerate is available to avoid
    # allocating weights that are immediately overwritten.
    ctx = init_empty_weights if is_accelerate_available() else nullcontext
    with ctx():
        model = cls(model_config)

    # Used below to disambiguate checkpoint variants that share key names but
    # differ in hidden width.
    position_embedding_dim = model.text_model.embeddings.position_embedding.weight.shape[-1]

    if is_clip_model(checkpoint):
        diffusers_format_checkpoint = convert_ldm_clip_checkpoint(checkpoint)

    elif (
        is_clip_sdxl_model(checkpoint)
        and checkpoint[CHECKPOINT_KEY_NAMES["clip_sdxl"]].shape[-1] == position_embedding_dim
    ):
        diffusers_format_checkpoint = convert_ldm_clip_checkpoint(checkpoint)

    elif (
        is_clip_sd3_model(checkpoint)
        and checkpoint[CHECKPOINT_KEY_NAMES["clip_sd3"]].shape[-1] == position_embedding_dim
    ):
        diffusers_format_checkpoint = convert_ldm_clip_checkpoint(checkpoint, "text_encoders.clip_l.transformer.")
        # SD3 checkpoints carry no text projection for CLIP-L; substitute an identity.
        diffusers_format_checkpoint["text_projection.weight"] = torch.eye(position_embedding_dim)

    elif is_open_clip_model(checkpoint):
        prefix = "cond_stage_model.model."
        diffusers_format_checkpoint = convert_open_clip_checkpoint(model, checkpoint, prefix=prefix)

    elif (
        is_open_clip_sdxl_model(checkpoint)
        and checkpoint[CHECKPOINT_KEY_NAMES["open_clip_sdxl"]].shape[-1] == position_embedding_dim
    ):
        prefix = "conditioner.embedders.1.model."
        diffusers_format_checkpoint = convert_open_clip_checkpoint(model, checkpoint, prefix=prefix)

    elif is_open_clip_sdxl_refiner_model(checkpoint):
        prefix = "conditioner.embedders.0.model."
        diffusers_format_checkpoint = convert_open_clip_checkpoint(model, checkpoint, prefix=prefix)

    elif (
        is_open_clip_sd3_model(checkpoint)
        and checkpoint[CHECKPOINT_KEY_NAMES["open_clip_sd3"]].shape[-1] == position_embedding_dim
    ):
        diffusers_format_checkpoint = convert_ldm_clip_checkpoint(checkpoint, "text_encoders.clip_g.transformer.")

    else:
        raise ValueError("The provided checkpoint does not seem to contain a valid CLIP model.")

    if is_accelerate_available():
        unexpected_keys = load_model_dict_into_meta(model, diffusers_format_checkpoint, dtype=torch_dtype)
    else:
        _, unexpected_keys = model.load_state_dict(diffusers_format_checkpoint, strict=False)

    if model._keys_to_ignore_on_load_unexpected is not None:
        for pat in model._keys_to_ignore_on_load_unexpected:
            unexpected_keys = [k for k in unexpected_keys if re.search(pat, k) is None]

    if len(unexpected_keys) > 0:
        # NOTE(review): the surrounding `[...]` makes this render as a one-element
        # list literal rather than a plain string — likely unintended formatting.
        logger.warning(
            f"Some weights of the model checkpoint were not used when initializing {cls.__name__}: \n {[', '.join(unexpected_keys)]}"
        )

    if torch_dtype is not None:
        model.to(torch_dtype)

    model.eval()

    return model
def _legacy_load_scheduler(
    cls,
    checkpoint,
    component_name,
    original_config=None,
    **kwargs,
):
    """
    Build a scheduler for the deprecated `scheduler_type` / `prediction_type`
    `from_single_file` arguments.

    Args:
        cls: scheduler class used when no explicit `scheduler_type` is given.
        checkpoint: state dict; used to infer the model type and read `global_step`.
        component_name: pipeline component name; `"low_res_scheduler"` receives the
            fixed StableDiffusionUpscale config.
        original_config: optional parsed original LDM YAML config.
        **kwargs: may carry the deprecated `scheduler_type` / `prediction_type`.

    Returns:
        A configured scheduler instance.

    Raises:
        ValueError: if an unknown `scheduler_type` is passed.
    """
    scheduler_type = kwargs.get("scheduler_type", None)
    prediction_type = kwargs.get("prediction_type", None)

    if scheduler_type is not None:
        deprecation_message = (
            "Please pass an instance of a Scheduler object directly to the `scheduler` argument in `from_single_file`\n\n"
            "Example:\n\n"
            "from diffusers import StableDiffusionPipeline, DDIMScheduler\n\n"
            "scheduler = DDIMScheduler()\n"
            "pipe = StableDiffusionPipeline.from_single_file(<checkpoint path>, scheduler=scheduler)\n"
        )
        deprecate("scheduler_type", "1.0.0", deprecation_message)

    if prediction_type is not None:
        deprecation_message = (
            "Please configure an instance of a Scheduler with the appropriate `prediction_type` and "
            "pass the object directly to the `scheduler` argument in `from_single_file`.\n\n"
            "Example:\n\n"
            "from diffusers import StableDiffusionPipeline, DDIMScheduler\n\n"
            'scheduler = DDIMScheduler(prediction_type="v_prediction")\n'
            "pipe = StableDiffusionPipeline.from_single_file(<checkpoint path>, scheduler=scheduler)\n"
        )
        deprecate("prediction_type", "1.0.0", deprecation_message)

    # BUGFIX: copy the shared default config so the mutations below do not leak
    # into SCHEDULER_DEFAULT_CONFIG (and therefore into subsequent calls).
    scheduler_config = dict(SCHEDULER_DEFAULT_CONFIG)
    model_type = infer_diffusers_model_type(checkpoint=checkpoint)

    global_step = checkpoint["global_step"] if "global_step" in checkpoint else None

    if original_config:
        # BUGFIX: `original_config["model"]["params"]` is a mapping, so the previous
        # `getattr(..., "timesteps", 1000)` always returned the default. `.get` works
        # for plain dicts and OmegaConf DictConfig alike.
        num_train_timesteps = original_config["model"]["params"].get("timesteps", 1000)
    else:
        num_train_timesteps = 1000

    scheduler_config["num_train_timesteps"] = num_train_timesteps

    if model_type == "v2":
        if prediction_type is None:
            # NOTE: For stable diffusion 2 base it is recommended to pass `prediction_type=="epsilon"` # as it relies on a brittle global step parameter here
            prediction_type = "epsilon" if global_step == 875000 else "v_prediction"

    else:
        prediction_type = prediction_type or "epsilon"

    scheduler_config["prediction_type"] = prediction_type

    if model_type in ["xl_base", "xl_refiner"]:
        scheduler_type = "euler"
    elif model_type == "playground":
        scheduler_type = "edm_dpm_solver_multistep"
    else:
        if original_config:
            beta_start = original_config["model"]["params"].get("linear_start")
            beta_end = original_config["model"]["params"].get("linear_end")

        else:
            # NOTE(review): these fallback values look suspect — SD configs typically use
            # linear_start=0.00085 / linear_end=0.012. Preserved as-is; confirm upstream.
            beta_start = 0.02
            beta_end = 0.085

        scheduler_config["beta_start"] = beta_start
        scheduler_config["beta_end"] = beta_end
        scheduler_config["beta_schedule"] = "scaled_linear"
        scheduler_config["clip_sample"] = False
        scheduler_config["set_alpha_to_one"] = False

    # to deal with an edge case StableDiffusionUpscale pipeline has two schedulers
    if component_name == "low_res_scheduler":
        return cls.from_config(
            {
                "beta_end": 0.02,
                "beta_schedule": "scaled_linear",
                "beta_start": 0.0001,
                "clip_sample": True,
                "num_train_timesteps": 1000,
                "prediction_type": "epsilon",
                "trained_betas": None,
                "variance_type": "fixed_small",
            }
        )

    if scheduler_type is None:
        return cls.from_config(scheduler_config)

    elif scheduler_type == "pndm":
        scheduler_config["skip_prk_steps"] = True
        scheduler = PNDMScheduler.from_config(scheduler_config)

    elif scheduler_type == "lms":
        scheduler = LMSDiscreteScheduler.from_config(scheduler_config)

    elif scheduler_type == "heun":
        scheduler = HeunDiscreteScheduler.from_config(scheduler_config)

    elif scheduler_type == "euler":
        scheduler = EulerDiscreteScheduler.from_config(scheduler_config)

    elif scheduler_type == "euler-ancestral":
        scheduler = EulerAncestralDiscreteScheduler.from_config(scheduler_config)

    elif scheduler_type == "dpm":
        scheduler = DPMSolverMultistepScheduler.from_config(scheduler_config)

    elif scheduler_type == "ddim":
        scheduler = DDIMScheduler.from_config(scheduler_config)

    elif scheduler_type == "edm_dpm_solver_multistep":
        scheduler_config = {
            "algorithm_type": "dpmsolver++",
            "dynamic_thresholding_ratio": 0.995,
            "euler_at_final": False,
            "final_sigmas_type": "zero",
            "lower_order_final": True,
            "num_train_timesteps": 1000,
            "prediction_type": "epsilon",
            "rho": 7.0,
            "sample_max_value": 1.0,
            "sigma_data": 0.5,
            "sigma_max": 80.0,
            "sigma_min": 0.002,
            "solver_order": 2,
            "solver_type": "midpoint",
            "thresholding": False,
        }
        scheduler = EDMDPMSolverMultistepScheduler(**scheduler_config)

    else:
        raise ValueError(f"Scheduler of type {scheduler_type} doesn't exist!")

    return scheduler
def _legacy_load_clip_tokenizer(cls, checkpoint, config=None, local_files_only=False):
    """
    Resolve the tokenizer repo matching the CLIP flavor found in `checkpoint`
    and load it via `cls.from_pretrained`.

    Args:
        cls: tokenizer class to instantiate.
        checkpoint: single-file state dict used for flavor detection.
        config: optional repo id overriding `fetch_diffusers_config`.
        local_files_only: forwarded to `from_pretrained`.

    Returns:
        The loaded tokenizer instance.
    """
    if config:
        config = {"pretrained_model_name_or_path": config}
    else:
        config = fetch_diffusers_config(checkpoint)

    # Map the detected checkpoint flavor to (repo id, subfolder).
    if is_clip_model(checkpoint) or is_clip_sdxl_model(checkpoint):
        repo_id, subfolder = "openai/clip-vit-large-patch14", ""
    elif is_open_clip_model(checkpoint):
        repo_id, subfolder = "stabilityai/stable-diffusion-2", "tokenizer"
    else:
        repo_id, subfolder = "laion/CLIP-ViT-bigG-14-laion2B-39B-b160k", ""

    config["pretrained_model_name_or_path"] = repo_id

    return cls.from_pretrained(**config, subfolder=subfolder, local_files_only=local_files_only)
def _legacy_load_safety_checker(local_files_only, torch_dtype):
    # Support for loading safety checker components using the deprecated
    # `load_safety_checker` argument.
    """
    Load the safety-checker model and its feature extractor from
    "CompVis/stable-diffusion-safety-checker".

    Args:
        local_files_only: forwarded to `from_pretrained`.
        torch_dtype: forwarded to `from_pretrained`.

    Returns:
        dict with keys "safety_checker" and "feature_extractor".
    """
    from ..pipelines.stable_diffusion.safety_checker import StableDiffusionSafetyChecker

    repo_id = "CompVis/stable-diffusion-safety-checker"
    load_kwargs = {"local_files_only": local_files_only, "torch_dtype": torch_dtype}

    feature_extractor = AutoImageProcessor.from_pretrained(repo_id, **load_kwargs)
    safety_checker = StableDiffusionSafetyChecker.from_pretrained(repo_id, **load_kwargs)

    return {"safety_checker": safety_checker, "feature_extractor": feature_extractor}
# in SD3's reference implementation AdaLayerNormContinuous splits the linear projection
# output into (shift, scale), while diffusers splits into (scale, shift). Swapping the
# halves of the weight lets us reuse the diffusers implementation unchanged.
def swap_scale_shift(weight, dim):
    """
    Return `weight` with its two halves along dim 0 swapped: (shift, scale) -> (scale, shift).

    `dim` is accepted for call-site compatibility but the split is always on dim 0
    (callers pass non-axis values here, so it must not be used as a tensor dim).
    """
    halves = torch.chunk(weight, 2, dim=0)
    return torch.cat((halves[1], halves[0]), dim=0)
def convert_sd3_transformer_checkpoint_to_diffusers(checkpoint, **kwargs):
    """Remap an original-format SD3 MMDiT state dict to the diffusers key layout.

    Keys are moved out of ``checkpoint`` via ``pop``, so the input dict is
    consumed/mutated in place; a new dict with diffusers-style keys is returned.
    ``**kwargs`` is accepted for call-site compatibility and ignored.
    """
    converted_state_dict = {}
    keys = list(checkpoint.keys())
    for k in keys:
        # Strip the "model.diffusion_model." wrapper prefix used by original checkpoints.
        if "model.diffusion_model." in k:
            checkpoint[k.replace("model.diffusion_model.", "")] = checkpoint.pop(k)

    # Depth is inferred from the joint-block indices present in the checkpoint.
    # NOTE(review): this relies on the iteration order of a set of ints to pick
    # the largest index; max(...) would state the intent directly — confirm before changing.
    num_layers = list(set(int(k.split(".", 2)[1]) for k in checkpoint if "joint_blocks" in k))[-1] + 1  # noqa: C401
    caption_projection_dim = 1536  # hidden size of the context (caption) stream

    # Positional and patch embeddings.
    converted_state_dict["pos_embed.pos_embed"] = checkpoint.pop("pos_embed")
    converted_state_dict["pos_embed.proj.weight"] = checkpoint.pop("x_embedder.proj.weight")
    converted_state_dict["pos_embed.proj.bias"] = checkpoint.pop("x_embedder.proj.bias")

    # Timestep embeddings.
    converted_state_dict["time_text_embed.timestep_embedder.linear_1.weight"] = checkpoint.pop(
        "t_embedder.mlp.0.weight"
    )
    converted_state_dict["time_text_embed.timestep_embedder.linear_1.bias"] = checkpoint.pop("t_embedder.mlp.0.bias")
    converted_state_dict["time_text_embed.timestep_embedder.linear_2.weight"] = checkpoint.pop(
        "t_embedder.mlp.2.weight"
    )
    converted_state_dict["time_text_embed.timestep_embedder.linear_2.bias"] = checkpoint.pop("t_embedder.mlp.2.bias")

    # Context projections.
    converted_state_dict["context_embedder.weight"] = checkpoint.pop("context_embedder.weight")
    converted_state_dict["context_embedder.bias"] = checkpoint.pop("context_embedder.bias")

    # Pooled context projection.
    converted_state_dict["time_text_embed.text_embedder.linear_1.weight"] = checkpoint.pop("y_embedder.mlp.0.weight")
    converted_state_dict["time_text_embed.text_embedder.linear_1.bias"] = checkpoint.pop("y_embedder.mlp.0.bias")
    converted_state_dict["time_text_embed.text_embedder.linear_2.weight"] = checkpoint.pop("y_embedder.mlp.2.weight")
    converted_state_dict["time_text_embed.text_embedder.linear_2.bias"] = checkpoint.pop("y_embedder.mlp.2.bias")

    # Transformer blocks 🎸.
    for i in range(num_layers):
        # Q, K, V: the original checkpoint fuses them into one qkv tensor;
        # split into three equal chunks along dim 0.
        sample_q, sample_k, sample_v = torch.chunk(
            checkpoint.pop(f"joint_blocks.{i}.x_block.attn.qkv.weight"), 3, dim=0
        )
        context_q, context_k, context_v = torch.chunk(
            checkpoint.pop(f"joint_blocks.{i}.context_block.attn.qkv.weight"), 3, dim=0
        )
        sample_q_bias, sample_k_bias, sample_v_bias = torch.chunk(
            checkpoint.pop(f"joint_blocks.{i}.x_block.attn.qkv.bias"), 3, dim=0
        )
        context_q_bias, context_k_bias, context_v_bias = torch.chunk(
            checkpoint.pop(f"joint_blocks.{i}.context_block.attn.qkv.bias"), 3, dim=0
        )

        # torch.cat on a single chunk materializes the view as a contiguous copy.
        converted_state_dict[f"transformer_blocks.{i}.attn.to_q.weight"] = torch.cat([sample_q])
        converted_state_dict[f"transformer_blocks.{i}.attn.to_q.bias"] = torch.cat([sample_q_bias])
        converted_state_dict[f"transformer_blocks.{i}.attn.to_k.weight"] = torch.cat([sample_k])
        converted_state_dict[f"transformer_blocks.{i}.attn.to_k.bias"] = torch.cat([sample_k_bias])
        converted_state_dict[f"transformer_blocks.{i}.attn.to_v.weight"] = torch.cat([sample_v])
        converted_state_dict[f"transformer_blocks.{i}.attn.to_v.bias"] = torch.cat([sample_v_bias])

        converted_state_dict[f"transformer_blocks.{i}.attn.add_q_proj.weight"] = torch.cat([context_q])
        converted_state_dict[f"transformer_blocks.{i}.attn.add_q_proj.bias"] = torch.cat([context_q_bias])
        converted_state_dict[f"transformer_blocks.{i}.attn.add_k_proj.weight"] = torch.cat([context_k])
        converted_state_dict[f"transformer_blocks.{i}.attn.add_k_proj.bias"] = torch.cat([context_k_bias])
        converted_state_dict[f"transformer_blocks.{i}.attn.add_v_proj.weight"] = torch.cat([context_v])
        converted_state_dict[f"transformer_blocks.{i}.attn.add_v_proj.bias"] = torch.cat([context_v_bias])

        # output projections.
        converted_state_dict[f"transformer_blocks.{i}.attn.to_out.0.weight"] = checkpoint.pop(
            f"joint_blocks.{i}.x_block.attn.proj.weight"
        )
        converted_state_dict[f"transformer_blocks.{i}.attn.to_out.0.bias"] = checkpoint.pop(
            f"joint_blocks.{i}.x_block.attn.proj.bias"
        )
        # The last joint block carries no context output projection in the checkpoint.
        if not (i == num_layers - 1):
            converted_state_dict[f"transformer_blocks.{i}.attn.to_add_out.weight"] = checkpoint.pop(
                f"joint_blocks.{i}.context_block.attn.proj.weight"
            )
            converted_state_dict[f"transformer_blocks.{i}.attn.to_add_out.bias"] = checkpoint.pop(
                f"joint_blocks.{i}.context_block.attn.proj.bias"
            )

        # norms.
        converted_state_dict[f"transformer_blocks.{i}.norm1.linear.weight"] = checkpoint.pop(
            f"joint_blocks.{i}.x_block.adaLN_modulation.1.weight"
        )
        converted_state_dict[f"transformer_blocks.{i}.norm1.linear.bias"] = checkpoint.pop(
            f"joint_blocks.{i}.x_block.adaLN_modulation.1.bias"
        )
        if not (i == num_layers - 1):
            converted_state_dict[f"transformer_blocks.{i}.norm1_context.linear.weight"] = checkpoint.pop(
                f"joint_blocks.{i}.context_block.adaLN_modulation.1.weight"
            )
            converted_state_dict[f"transformer_blocks.{i}.norm1_context.linear.bias"] = checkpoint.pop(
                f"joint_blocks.{i}.context_block.adaLN_modulation.1.bias"
            )
        else:
            # The final context norm uses the AdaLayerNormContinuous ordering,
            # so its (shift, scale) halves must be swapped (see swap_scale_shift).
            converted_state_dict[f"transformer_blocks.{i}.norm1_context.linear.weight"] = swap_scale_shift(
                checkpoint.pop(f"joint_blocks.{i}.context_block.adaLN_modulation.1.weight"),
                dim=caption_projection_dim,
            )
            converted_state_dict[f"transformer_blocks.{i}.norm1_context.linear.bias"] = swap_scale_shift(
                checkpoint.pop(f"joint_blocks.{i}.context_block.adaLN_modulation.1.bias"),
                dim=caption_projection_dim,
            )

        # ffs.
        converted_state_dict[f"transformer_blocks.{i}.ff.net.0.proj.weight"] = checkpoint.pop(
            f"joint_blocks.{i}.x_block.mlp.fc1.weight"
        )
        converted_state_dict[f"transformer_blocks.{i}.ff.net.0.proj.bias"] = checkpoint.pop(
            f"joint_blocks.{i}.x_block.mlp.fc1.bias"
        )
        converted_state_dict[f"transformer_blocks.{i}.ff.net.2.weight"] = checkpoint.pop(
            f"joint_blocks.{i}.x_block.mlp.fc2.weight"
        )
        converted_state_dict[f"transformer_blocks.{i}.ff.net.2.bias"] = checkpoint.pop(
            f"joint_blocks.{i}.x_block.mlp.fc2.bias"
        )
        # Again, the last joint block has no context feed-forward.
        if not (i == num_layers - 1):
            converted_state_dict[f"transformer_blocks.{i}.ff_context.net.0.proj.weight"] = checkpoint.pop(
                f"joint_blocks.{i}.context_block.mlp.fc1.weight"
            )
            converted_state_dict[f"transformer_blocks.{i}.ff_context.net.0.proj.bias"] = checkpoint.pop(
                f"joint_blocks.{i}.context_block.mlp.fc1.bias"
            )
            converted_state_dict[f"transformer_blocks.{i}.ff_context.net.2.weight"] = checkpoint.pop(
                f"joint_blocks.{i}.context_block.mlp.fc2.weight"
            )
            converted_state_dict[f"transformer_blocks.{i}.ff_context.net.2.bias"] = checkpoint.pop(
                f"joint_blocks.{i}.context_block.mlp.fc2.bias"
            )

    # Final blocks.
    converted_state_dict["proj_out.weight"] = checkpoint.pop("final_layer.linear.weight")
    converted_state_dict["proj_out.bias"] = checkpoint.pop("final_layer.linear.bias")
    converted_state_dict["norm_out.linear.weight"] = swap_scale_shift(
        checkpoint.pop("final_layer.adaLN_modulation.1.weight"), dim=caption_projection_dim
    )
    converted_state_dict["norm_out.linear.bias"] = swap_scale_shift(
        checkpoint.pop("final_layer.adaLN_modulation.1.bias"), dim=caption_projection_dim
    )

    return converted_state_dict
1815
+
1816
+
1817
def is_t5_in_single_file(checkpoint):
    """Return True when the single-file checkpoint bundles T5-XXL text-encoder weights."""
    return "text_encoders.t5xxl.transformer.shared.weight" in checkpoint
1822
+
1823
+
1824
def convert_sd3_t5_checkpoint_to_diffusers(checkpoint):
    """Extract the T5 text-encoder weights from a combined SD3 single-file checkpoint.

    Entries under "text_encoders.t5xxl.transformer." are re-keyed with that
    prefix removed; all other entries are ignored. The input mapping is read
    (not mutated) and a new dict is returned.
    """
    prefixes_to_strip = ("text_encoders.t5xxl.transformer.",)
    text_model_dict = {}
    for original_key in list(checkpoint.keys()):
        for prefix in prefixes_to_strip:
            if original_key.startswith(prefix):
                text_model_dict[original_key.replace(prefix, "")] = checkpoint.get(original_key)
    return text_model_dict
1837
+
1838
+
1839
def create_diffusers_t5_model_from_checkpoint(
    cls,
    checkpoint,
    subfolder="",
    config=None,
    torch_dtype=None,
    local_files_only=None,
):
    """Instantiate a T5 text-encoder of class ``cls`` and load SD3 single-file weights into it.

    Parameters:
        cls: model class exposing ``config_class`` and ``_keep_in_fp32_modules``.
        checkpoint: combined single-file state dict with
            "text_encoders.t5xxl.transformer.*" keys.
        subfolder: config subfolder forwarded to ``from_pretrained``.
        config: optional repo id / path overriding automatic config resolution.
        torch_dtype: target dtype for the loaded weights (accelerate path only).
        local_files_only: forwarded to config resolution.

    Returns:
        The model instance with converted weights loaded.
    """
    if config:
        # Caller supplied an explicit repo id / path for the config.
        config = {"pretrained_model_name_or_path": config}
    else:
        config = fetch_diffusers_config(checkpoint)

    model_config = cls.config_class.from_pretrained(**config, subfolder=subfolder, local_files_only=local_files_only)
    # Build the model on the meta device when accelerate is available to avoid
    # allocating real randomly-initialized weights that are immediately replaced.
    ctx = init_empty_weights if is_accelerate_available() else nullcontext
    with ctx():
        model = cls(model_config)

    diffusers_format_checkpoint = convert_sd3_t5_checkpoint_to_diffusers(checkpoint)

    if is_accelerate_available():
        unexpected_keys = load_model_dict_into_meta(model, diffusers_format_checkpoint, dtype=torch_dtype)
        if model._keys_to_ignore_on_load_unexpected is not None:
            # Filter out keys the model class declares as ignorable.
            for pat in model._keys_to_ignore_on_load_unexpected:
                unexpected_keys = [k for k in unexpected_keys if re.search(pat, k) is None]

        if len(unexpected_keys) > 0:
            logger.warning(
                f"Some weights of the model checkpoint were not used when initializing {cls.__name__}: \n {[', '.join(unexpected_keys)]}"
            )

    else:
        model.load_state_dict(diffusers_format_checkpoint)

    # Keep numerically sensitive modules in fp32 when the rest loads in fp16.
    use_keep_in_fp32_modules = (cls._keep_in_fp32_modules is not None) and (torch_dtype == torch.float16)
    if use_keep_in_fp32_modules:
        keep_in_fp32_modules = model._keep_in_fp32_modules
    else:
        keep_in_fp32_modules = []

    if keep_in_fp32_modules is not None:
        for name, param in model.named_parameters():
            if any(module_to_keep_in_fp32 in name.split(".") for module_to_keep_in_fp32 in keep_in_fp32_modules):
                # param = param.to(torch.float32) does not work here as only in the local scope.
                param.data = param.data.to(torch.float32)

    return model
1886
+
1887
+
1888
def convert_animatediff_checkpoint_to_diffusers(checkpoint, **kwargs):
    """Rename AnimateDiff motion-module keys to the diffusers naming scheme.

    Keys containing "pos_encoder" are dropped; every other key has its norm /
    attention / temporal-transformer segments renamed. The input mapping is not
    mutated; a new dict is returned. ``**kwargs`` is accepted and ignored.
    """
    renames = (
        (".norms.0", ".norm1"),
        (".norms.1", ".norm2"),
        (".ff_norm", ".norm3"),
        (".attention_blocks.0", ".attn1"),
        (".attention_blocks.1", ".attn2"),
        (".temporal_transformer", ""),
    )
    converted_state_dict = {}
    for key, value in checkpoint.items():
        if "pos_encoder" in key:
            # Positional-encoder buffers are not carried over.
            continue
        for old_segment, new_segment in renames:
            key = key.replace(old_segment, new_segment)
        converted_state_dict[key] = value
    return converted_state_dict
1905
+
1906
+
1907
def convert_flux_transformer_checkpoint_to_diffusers(checkpoint, **kwargs):
    """Remap an original-format Flux transformer state dict to the diffusers key layout.

    Keys are moved out of ``checkpoint`` via ``pop``, so the input dict is
    consumed/mutated in place; a new dict with diffusers-style keys is returned.
    ``**kwargs`` is accepted for call-site compatibility and ignored.
    """
    converted_state_dict = {}
    keys = list(checkpoint.keys())
    for k in keys:
        # Strip the "model.diffusion_model." wrapper prefix used by original checkpoints.
        if "model.diffusion_model." in k:
            checkpoint[k.replace("model.diffusion_model.", "")] = checkpoint.pop(k)

    # Depths inferred from the block indices present in the checkpoint.
    # NOTE(review): relies on set iteration order to pick the largest index;
    # max(...) would state the intent directly — confirm before changing.
    num_layers = list(set(int(k.split(".", 2)[1]) for k in checkpoint if "double_blocks." in k))[-1] + 1  # noqa: C401
    num_single_layers = list(set(int(k.split(".", 2)[1]) for k in checkpoint if "single_blocks." in k))[-1] + 1  # noqa: C401
    mlp_ratio = 4.0  # single-block fused linear packs an MLP of this expansion ratio
    inner_dim = 3072  # transformer hidden size assumed by the fused-linear split widths

    # in SD3 original implementation of AdaLayerNormContinuous, it split linear projection output into shift, scale;
    # while in diffusers it split into scale, shift. Here we swap the linear projection weights in order to be able to use diffusers implementation
    def swap_scale_shift(weight):
        # Swap the (shift, scale) halves along dim 0 into (scale, shift) order.
        shift, scale = weight.chunk(2, dim=0)
        new_weight = torch.cat([scale, shift], dim=0)
        return new_weight

    ## time_text_embed.timestep_embedder <- time_in
    converted_state_dict["time_text_embed.timestep_embedder.linear_1.weight"] = checkpoint.pop(
        "time_in.in_layer.weight"
    )
    converted_state_dict["time_text_embed.timestep_embedder.linear_1.bias"] = checkpoint.pop("time_in.in_layer.bias")
    converted_state_dict["time_text_embed.timestep_embedder.linear_2.weight"] = checkpoint.pop(
        "time_in.out_layer.weight"
    )
    converted_state_dict["time_text_embed.timestep_embedder.linear_2.bias"] = checkpoint.pop("time_in.out_layer.bias")

    ## time_text_embed.text_embedder <- vector_in
    converted_state_dict["time_text_embed.text_embedder.linear_1.weight"] = checkpoint.pop("vector_in.in_layer.weight")
    converted_state_dict["time_text_embed.text_embedder.linear_1.bias"] = checkpoint.pop("vector_in.in_layer.bias")
    converted_state_dict["time_text_embed.text_embedder.linear_2.weight"] = checkpoint.pop(
        "vector_in.out_layer.weight"
    )
    converted_state_dict["time_text_embed.text_embedder.linear_2.bias"] = checkpoint.pop("vector_in.out_layer.bias")

    # guidance: only guidance-distilled checkpoints ship these keys.
    has_guidance = any("guidance" in k for k in checkpoint)
    if has_guidance:
        converted_state_dict["time_text_embed.guidance_embedder.linear_1.weight"] = checkpoint.pop(
            "guidance_in.in_layer.weight"
        )
        converted_state_dict["time_text_embed.guidance_embedder.linear_1.bias"] = checkpoint.pop(
            "guidance_in.in_layer.bias"
        )
        converted_state_dict["time_text_embed.guidance_embedder.linear_2.weight"] = checkpoint.pop(
            "guidance_in.out_layer.weight"
        )
        converted_state_dict["time_text_embed.guidance_embedder.linear_2.bias"] = checkpoint.pop(
            "guidance_in.out_layer.bias"
        )

    # context_embedder
    converted_state_dict["context_embedder.weight"] = checkpoint.pop("txt_in.weight")
    converted_state_dict["context_embedder.bias"] = checkpoint.pop("txt_in.bias")

    # x_embedder
    converted_state_dict["x_embedder.weight"] = checkpoint.pop("img_in.weight")
    converted_state_dict["x_embedder.bias"] = checkpoint.pop("img_in.bias")

    # double transformer blocks
    for i in range(num_layers):
        block_prefix = f"transformer_blocks.{i}."
        # norms.
        ## norm1
        converted_state_dict[f"{block_prefix}norm1.linear.weight"] = checkpoint.pop(
            f"double_blocks.{i}.img_mod.lin.weight"
        )
        converted_state_dict[f"{block_prefix}norm1.linear.bias"] = checkpoint.pop(
            f"double_blocks.{i}.img_mod.lin.bias"
        )
        ## norm1_context
        converted_state_dict[f"{block_prefix}norm1_context.linear.weight"] = checkpoint.pop(
            f"double_blocks.{i}.txt_mod.lin.weight"
        )
        converted_state_dict[f"{block_prefix}norm1_context.linear.bias"] = checkpoint.pop(
            f"double_blocks.{i}.txt_mod.lin.bias"
        )
        # Q, K, V: fused qkv tensors split into three equal chunks along dim 0.
        sample_q, sample_k, sample_v = torch.chunk(checkpoint.pop(f"double_blocks.{i}.img_attn.qkv.weight"), 3, dim=0)
        context_q, context_k, context_v = torch.chunk(
            checkpoint.pop(f"double_blocks.{i}.txt_attn.qkv.weight"), 3, dim=0
        )
        sample_q_bias, sample_k_bias, sample_v_bias = torch.chunk(
            checkpoint.pop(f"double_blocks.{i}.img_attn.qkv.bias"), 3, dim=0
        )
        context_q_bias, context_k_bias, context_v_bias = torch.chunk(
            checkpoint.pop(f"double_blocks.{i}.txt_attn.qkv.bias"), 3, dim=0
        )
        # torch.cat on a single chunk materializes the view as a contiguous copy.
        converted_state_dict[f"{block_prefix}attn.to_q.weight"] = torch.cat([sample_q])
        converted_state_dict[f"{block_prefix}attn.to_q.bias"] = torch.cat([sample_q_bias])
        converted_state_dict[f"{block_prefix}attn.to_k.weight"] = torch.cat([sample_k])
        converted_state_dict[f"{block_prefix}attn.to_k.bias"] = torch.cat([sample_k_bias])
        converted_state_dict[f"{block_prefix}attn.to_v.weight"] = torch.cat([sample_v])
        converted_state_dict[f"{block_prefix}attn.to_v.bias"] = torch.cat([sample_v_bias])
        converted_state_dict[f"{block_prefix}attn.add_q_proj.weight"] = torch.cat([context_q])
        converted_state_dict[f"{block_prefix}attn.add_q_proj.bias"] = torch.cat([context_q_bias])
        converted_state_dict[f"{block_prefix}attn.add_k_proj.weight"] = torch.cat([context_k])
        converted_state_dict[f"{block_prefix}attn.add_k_proj.bias"] = torch.cat([context_k_bias])
        converted_state_dict[f"{block_prefix}attn.add_v_proj.weight"] = torch.cat([context_v])
        converted_state_dict[f"{block_prefix}attn.add_v_proj.bias"] = torch.cat([context_v_bias])
        # qk_norm
        converted_state_dict[f"{block_prefix}attn.norm_q.weight"] = checkpoint.pop(
            f"double_blocks.{i}.img_attn.norm.query_norm.scale"
        )
        converted_state_dict[f"{block_prefix}attn.norm_k.weight"] = checkpoint.pop(
            f"double_blocks.{i}.img_attn.norm.key_norm.scale"
        )
        converted_state_dict[f"{block_prefix}attn.norm_added_q.weight"] = checkpoint.pop(
            f"double_blocks.{i}.txt_attn.norm.query_norm.scale"
        )
        converted_state_dict[f"{block_prefix}attn.norm_added_k.weight"] = checkpoint.pop(
            f"double_blocks.{i}.txt_attn.norm.key_norm.scale"
        )
        # ff img_mlp
        converted_state_dict[f"{block_prefix}ff.net.0.proj.weight"] = checkpoint.pop(
            f"double_blocks.{i}.img_mlp.0.weight"
        )
        converted_state_dict[f"{block_prefix}ff.net.0.proj.bias"] = checkpoint.pop(f"double_blocks.{i}.img_mlp.0.bias")
        converted_state_dict[f"{block_prefix}ff.net.2.weight"] = checkpoint.pop(f"double_blocks.{i}.img_mlp.2.weight")
        converted_state_dict[f"{block_prefix}ff.net.2.bias"] = checkpoint.pop(f"double_blocks.{i}.img_mlp.2.bias")
        converted_state_dict[f"{block_prefix}ff_context.net.0.proj.weight"] = checkpoint.pop(
            f"double_blocks.{i}.txt_mlp.0.weight"
        )
        converted_state_dict[f"{block_prefix}ff_context.net.0.proj.bias"] = checkpoint.pop(
            f"double_blocks.{i}.txt_mlp.0.bias"
        )
        converted_state_dict[f"{block_prefix}ff_context.net.2.weight"] = checkpoint.pop(
            f"double_blocks.{i}.txt_mlp.2.weight"
        )
        converted_state_dict[f"{block_prefix}ff_context.net.2.bias"] = checkpoint.pop(
            f"double_blocks.{i}.txt_mlp.2.bias"
        )
        # output projections.
        converted_state_dict[f"{block_prefix}attn.to_out.0.weight"] = checkpoint.pop(
            f"double_blocks.{i}.img_attn.proj.weight"
        )
        converted_state_dict[f"{block_prefix}attn.to_out.0.bias"] = checkpoint.pop(
            f"double_blocks.{i}.img_attn.proj.bias"
        )
        converted_state_dict[f"{block_prefix}attn.to_add_out.weight"] = checkpoint.pop(
            f"double_blocks.{i}.txt_attn.proj.weight"
        )
        converted_state_dict[f"{block_prefix}attn.to_add_out.bias"] = checkpoint.pop(
            f"double_blocks.{i}.txt_attn.proj.bias"
        )

    # single transfomer blocks
    for i in range(num_single_layers):
        block_prefix = f"single_transformer_blocks.{i}."
        # norm.linear <- single_blocks.0.modulation.lin
        converted_state_dict[f"{block_prefix}norm.linear.weight"] = checkpoint.pop(
            f"single_blocks.{i}.modulation.lin.weight"
        )
        converted_state_dict[f"{block_prefix}norm.linear.bias"] = checkpoint.pop(
            f"single_blocks.{i}.modulation.lin.bias"
        )
        # Q, K, V, mlp: linear1 fuses q/k/v and the MLP up-projection; split by width.
        mlp_hidden_dim = int(inner_dim * mlp_ratio)
        split_size = (inner_dim, inner_dim, inner_dim, mlp_hidden_dim)
        q, k, v, mlp = torch.split(checkpoint.pop(f"single_blocks.{i}.linear1.weight"), split_size, dim=0)
        q_bias, k_bias, v_bias, mlp_bias = torch.split(
            checkpoint.pop(f"single_blocks.{i}.linear1.bias"), split_size, dim=0
        )
        converted_state_dict[f"{block_prefix}attn.to_q.weight"] = torch.cat([q])
        converted_state_dict[f"{block_prefix}attn.to_q.bias"] = torch.cat([q_bias])
        converted_state_dict[f"{block_prefix}attn.to_k.weight"] = torch.cat([k])
        converted_state_dict[f"{block_prefix}attn.to_k.bias"] = torch.cat([k_bias])
        converted_state_dict[f"{block_prefix}attn.to_v.weight"] = torch.cat([v])
        converted_state_dict[f"{block_prefix}attn.to_v.bias"] = torch.cat([v_bias])
        converted_state_dict[f"{block_prefix}proj_mlp.weight"] = torch.cat([mlp])
        converted_state_dict[f"{block_prefix}proj_mlp.bias"] = torch.cat([mlp_bias])
        # qk norm
        converted_state_dict[f"{block_prefix}attn.norm_q.weight"] = checkpoint.pop(
            f"single_blocks.{i}.norm.query_norm.scale"
        )
        converted_state_dict[f"{block_prefix}attn.norm_k.weight"] = checkpoint.pop(
            f"single_blocks.{i}.norm.key_norm.scale"
        )
        # output projections.
        converted_state_dict[f"{block_prefix}proj_out.weight"] = checkpoint.pop(f"single_blocks.{i}.linear2.weight")
        converted_state_dict[f"{block_prefix}proj_out.bias"] = checkpoint.pop(f"single_blocks.{i}.linear2.bias")

    # Final output projection and norm (shift/scale order swapped for diffusers).
    converted_state_dict["proj_out.weight"] = checkpoint.pop("final_layer.linear.weight")
    converted_state_dict["proj_out.bias"] = checkpoint.pop("final_layer.linear.bias")
    converted_state_dict["norm_out.linear.weight"] = swap_scale_shift(
        checkpoint.pop("final_layer.adaLN_modulation.1.weight")
    )
    converted_state_dict["norm_out.linear.bias"] = swap_scale_shift(
        checkpoint.pop("final_layer.adaLN_modulation.1.bias")
    )

    return converted_state_dict
diffusers/loaders/textual_inversion.py ADDED
@@ -0,0 +1,578 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright 2024 The HuggingFace Team. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ from typing import Dict, List, Optional, Union
15
+
16
+ import safetensors
17
+ import torch
18
+ from huggingface_hub.utils import validate_hf_hub_args
19
+ from torch import nn
20
+
21
+ from ..models.modeling_utils import load_state_dict
22
+ from ..utils import _get_model_file, is_accelerate_available, is_transformers_available, logging
23
+
24
+
25
+ if is_transformers_available():
26
+ from transformers import PreTrainedModel, PreTrainedTokenizer
27
+
28
+ if is_accelerate_available():
29
+ from accelerate.hooks import AlignDevicesHook, CpuOffload, remove_hook_from_module
30
+
31
+ logger = logging.get_logger(__name__)
32
+
33
+ TEXT_INVERSION_NAME = "learned_embeds.bin"
34
+ TEXT_INVERSION_NAME_SAFE = "learned_embeds.safetensors"
35
+
36
+
37
@validate_hf_hub_args
def load_textual_inversion_state_dicts(pretrained_model_name_or_paths, **kwargs):
    """Resolve each entry of ``pretrained_model_name_or_paths`` to a loaded state dict.

    Entries that are already a ``dict`` or ``torch.Tensor`` are passed through
    unchanged; everything else is treated as a Hub repo id / local path and the
    embedding file is downloaded and loaded (safetensors preferred, pickle
    fallback when ``use_safetensors`` was not explicitly requested).

    Returns:
        list: one state dict (or tensor) per input entry, in order.
    """
    cache_dir = kwargs.pop("cache_dir", None)
    force_download = kwargs.pop("force_download", False)
    proxies = kwargs.pop("proxies", None)
    local_files_only = kwargs.pop("local_files_only", None)
    token = kwargs.pop("token", None)
    revision = kwargs.pop("revision", None)
    subfolder = kwargs.pop("subfolder", None)
    weight_name = kwargs.pop("weight_name", None)
    use_safetensors = kwargs.pop("use_safetensors", None)

    # When the caller did not force a format, try safetensors first but allow
    # falling back to the pickle (.bin) file.
    allow_pickle = False
    if use_safetensors is None:
        use_safetensors = True
        allow_pickle = True

    user_agent = {
        "file_type": "text_inversion",
        "framework": "pytorch",
    }
    state_dicts = []
    for pretrained_model_name_or_path in pretrained_model_name_or_paths:
        if not isinstance(pretrained_model_name_or_path, (dict, torch.Tensor)):
            # 3.1. Load textual inversion file
            model_file = None

            # Let's first try to load .safetensors weights
            if (use_safetensors and weight_name is None) or (
                weight_name is not None and weight_name.endswith(".safetensors")
            ):
                try:
                    model_file = _get_model_file(
                        pretrained_model_name_or_path,
                        weights_name=weight_name or TEXT_INVERSION_NAME_SAFE,
                        cache_dir=cache_dir,
                        force_download=force_download,
                        proxies=proxies,
                        local_files_only=local_files_only,
                        token=token,
                        revision=revision,
                        subfolder=subfolder,
                        user_agent=user_agent,
                    )
                    state_dict = safetensors.torch.load_file(model_file, device="cpu")
                except Exception as e:
                    # Re-raise unless we are allowed to fall back to the .bin file.
                    if not allow_pickle:
                        raise e

                    model_file = None

            if model_file is None:
                model_file = _get_model_file(
                    pretrained_model_name_or_path,
                    weights_name=weight_name or TEXT_INVERSION_NAME,
                    cache_dir=cache_dir,
                    force_download=force_download,
                    proxies=proxies,
                    local_files_only=local_files_only,
                    token=token,
                    revision=revision,
                    subfolder=subfolder,
                    user_agent=user_agent,
                )
                state_dict = load_state_dict(model_file)
        else:
            # Already-loaded state dict or raw embedding tensor: use as-is.
            state_dict = pretrained_model_name_or_path

        state_dicts.append(state_dict)

    return state_dicts
108
+
109
+
110
+ class TextualInversionLoaderMixin:
111
+ r"""
112
+ Load Textual Inversion tokens and embeddings to the tokenizer and text encoder.
113
+ """
114
+
115
+ def maybe_convert_prompt(self, prompt: Union[str, List[str]], tokenizer: "PreTrainedTokenizer"): # noqa: F821
116
+ r"""
117
+ Processes prompts that include a special token corresponding to a multi-vector textual inversion embedding to
118
+ be replaced with multiple special tokens each corresponding to one of the vectors. If the prompt has no textual
119
+ inversion token or if the textual inversion token is a single vector, the input prompt is returned.
120
+
121
+ Parameters:
122
+ prompt (`str` or list of `str`):
123
+ The prompt or prompts to guide the image generation.
124
+ tokenizer (`PreTrainedTokenizer`):
125
+ The tokenizer responsible for encoding the prompt into input tokens.
126
+
127
+ Returns:
128
+ `str` or list of `str`: The converted prompt
129
+ """
130
+ if not isinstance(prompt, List):
131
+ prompts = [prompt]
132
+ else:
133
+ prompts = prompt
134
+
135
+ prompts = [self._maybe_convert_prompt(p, tokenizer) for p in prompts]
136
+
137
+ if not isinstance(prompt, List):
138
+ return prompts[0]
139
+
140
+ return prompts
141
+
142
+ def _maybe_convert_prompt(self, prompt: str, tokenizer: "PreTrainedTokenizer"): # noqa: F821
143
+ r"""
144
+ Maybe convert a prompt into a "multi vector"-compatible prompt. If the prompt includes a token that corresponds
145
+ to a multi-vector textual inversion embedding, this function will process the prompt so that the special token
146
+ is replaced with multiple special tokens each corresponding to one of the vectors. If the prompt has no textual
147
+ inversion token or a textual inversion token that is a single vector, the input prompt is simply returned.
148
+
149
+ Parameters:
150
+ prompt (`str`):
151
+ The prompt to guide the image generation.
152
+ tokenizer (`PreTrainedTokenizer`):
153
+ The tokenizer responsible for encoding the prompt into input tokens.
154
+
155
+ Returns:
156
+ `str`: The converted prompt
157
+ """
158
+ tokens = tokenizer.tokenize(prompt)
159
+ unique_tokens = set(tokens)
160
+ for token in unique_tokens:
161
+ if token in tokenizer.added_tokens_encoder:
162
+ replacement = token
163
+ i = 1
164
+ while f"{token}_{i}" in tokenizer.added_tokens_encoder:
165
+ replacement += f" {token}_{i}"
166
+ i += 1
167
+
168
+ prompt = prompt.replace(token, replacement)
169
+
170
+ return prompt
171
+
172
+ def _check_text_inv_inputs(self, tokenizer, text_encoder, pretrained_model_name_or_paths, tokens):
173
+ if tokenizer is None:
174
+ raise ValueError(
175
+ f"{self.__class__.__name__} requires `self.tokenizer` or passing a `tokenizer` of type `PreTrainedTokenizer` for calling"
176
+ f" `{self.load_textual_inversion.__name__}`"
177
+ )
178
+
179
+ if text_encoder is None:
180
+ raise ValueError(
181
+ f"{self.__class__.__name__} requires `self.text_encoder` or passing a `text_encoder` of type `PreTrainedModel` for calling"
182
+ f" `{self.load_textual_inversion.__name__}`"
183
+ )
184
+
185
+ if len(pretrained_model_name_or_paths) > 1 and len(pretrained_model_name_or_paths) != len(tokens):
186
+ raise ValueError(
187
+ f"You have passed a list of models of length {len(pretrained_model_name_or_paths)}, and list of tokens of length {len(tokens)} "
188
+ f"Make sure both lists have the same length."
189
+ )
190
+
191
+ valid_tokens = [t for t in tokens if t is not None]
192
+ if len(set(valid_tokens)) < len(valid_tokens):
193
+ raise ValueError(f"You have passed a list of tokens that contains duplicates: {tokens}")
194
+
195
+ @staticmethod
196
+ def _retrieve_tokens_and_embeddings(tokens, state_dicts, tokenizer):
197
+ all_tokens = []
198
+ all_embeddings = []
199
+ for state_dict, token in zip(state_dicts, tokens):
200
+ if isinstance(state_dict, torch.Tensor):
201
+ if token is None:
202
+ raise ValueError(
203
+ "You are trying to load a textual inversion embedding that has been saved as a PyTorch tensor. Make sure to pass the name of the corresponding token in this case: `token=...`."
204
+ )
205
+ loaded_token = token
206
+ embedding = state_dict
207
+ elif len(state_dict) == 1:
208
+ # diffusers
209
+ loaded_token, embedding = next(iter(state_dict.items()))
210
+ elif "string_to_param" in state_dict:
211
+ # A1111
212
+ loaded_token = state_dict["name"]
213
+ embedding = state_dict["string_to_param"]["*"]
214
+ else:
215
+ raise ValueError(
216
+ f"Loaded state dictionary is incorrect: {state_dict}. \n\n"
217
+ "Please verify that the loaded state dictionary of the textual embedding either only has a single key or includes the `string_to_param`"
218
+ " input key."
219
+ )
220
+
221
+ if token is not None and loaded_token != token:
222
+ logger.info(f"The loaded token: {loaded_token} is overwritten by the passed token {token}.")
223
+ else:
224
+ token = loaded_token
225
+
226
+ if token in tokenizer.get_vocab():
227
+ raise ValueError(
228
+ f"Token {token} already in tokenizer vocabulary. Please choose a different token name or remove {token} and embedding from the tokenizer and text encoder."
229
+ )
230
+
231
+ all_tokens.append(token)
232
+ all_embeddings.append(embedding)
233
+
234
+ return all_tokens, all_embeddings
235
+
236
+ @staticmethod
237
+ def _extend_tokens_and_embeddings(tokens, embeddings, tokenizer):
238
+ all_tokens = []
239
+ all_embeddings = []
240
+
241
+ for embedding, token in zip(embeddings, tokens):
242
+ if f"{token}_1" in tokenizer.get_vocab():
243
+ multi_vector_tokens = [token]
244
+ i = 1
245
+ while f"{token}_{i}" in tokenizer.added_tokens_encoder:
246
+ multi_vector_tokens.append(f"{token}_{i}")
247
+ i += 1
248
+
249
+ raise ValueError(
250
+ f"Multi-vector Token {multi_vector_tokens} already in tokenizer vocabulary. Please choose a different token name or remove the {multi_vector_tokens} and embedding from the tokenizer and text encoder."
251
+ )
252
+
253
+ is_multi_vector = len(embedding.shape) > 1 and embedding.shape[0] > 1
254
+ if is_multi_vector:
255
+ all_tokens += [token] + [f"{token}_{i}" for i in range(1, embedding.shape[0])]
256
+ all_embeddings += [e for e in embedding] # noqa: C416
257
+ else:
258
+ all_tokens += [token]
259
+ all_embeddings += [embedding[0]] if len(embedding.shape) > 1 else [embedding]
260
+
261
+ return all_tokens, all_embeddings
262
+
263
@validate_hf_hub_args
def load_textual_inversion(
    self,
    pretrained_model_name_or_path: Union[str, List[str], Dict[str, torch.Tensor], List[Dict[str, torch.Tensor]]],
    token: Optional[Union[str, List[str]]] = None,
    tokenizer: Optional["PreTrainedTokenizer"] = None,  # noqa: F821
    text_encoder: Optional["PreTrainedModel"] = None,  # noqa: F821
    **kwargs,
):
    r"""
    Load Textual Inversion embeddings into the text encoder of [`StableDiffusionPipeline`] (both 🤗 Diffusers and
    Automatic1111 formats are supported).

    Parameters:
        pretrained_model_name_or_path (`str` or `os.PathLike` or `List[str or os.PathLike]` or `Dict` or `List[Dict]`):
            Can be either one of the following or a list of them:

                - A string, the *model id* (for example `sd-concepts-library/low-poly-hd-logos-icons`) of a
                  pretrained model hosted on the Hub.
                - A path to a *directory* (for example `./my_text_inversion_directory/`) containing the textual
                  inversion weights.
                - A path to a *file* (for example `./my_text_inversions.pt`) containing textual inversion weights.
                - A [torch state
                  dict](https://pytorch.org/tutorials/beginner/saving_loading_models.html#what-is-a-state-dict).

        token (`str` or `List[str]`, *optional*):
            Override the token to use for the textual inversion weights. If `pretrained_model_name_or_path` is a
            list, then `token` must also be a list of equal length.
        text_encoder ([`~transformers.CLIPTextModel`], *optional*):
            Frozen text-encoder ([clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14)).
            If not specified, function will take self.text_encoder.
        tokenizer ([`~transformers.CLIPTokenizer`], *optional*):
            A `CLIPTokenizer` to tokenize text. If not specified, function will take self.tokenizer.
        weight_name (`str`, *optional*):
            Name of a custom weight file. This should be used when:

                - The saved textual inversion file is in 🤗 Diffusers format, but was saved under a specific weight
                  name such as `text_inv.bin`.
                - The saved textual inversion file is in the Automatic1111 format.
        cache_dir (`Union[str, os.PathLike]`, *optional*):
            Path to a directory where a downloaded pretrained model configuration is cached if the standard cache
            is not used.
        force_download (`bool`, *optional*, defaults to `False`):
            Whether or not to force the (re-)download of the model weights and configuration files, overriding the
            cached versions if they exist.
        proxies (`Dict[str, str]`, *optional*):
            A dictionary of proxy servers to use by protocol or endpoint, for example, `{'http': 'foo.bar:3128',
            'http://hostname': 'foo.bar:4012'}`. The proxies are used on each request.
        local_files_only (`bool`, *optional*, defaults to `False`):
            Whether to only load local model weights and configuration files or not. If set to `True`, the model
            won't be downloaded from the Hub.
        token (`str` or *bool*, *optional*):
            The token to use as HTTP bearer authorization for remote files. If `True`, the token generated from
            `diffusers-cli login` (stored in `~/.huggingface`) is used.
        revision (`str`, *optional*, defaults to `"main"`):
            The specific model version to use. It can be a branch name, a tag name, a commit id, or any identifier
            allowed by Git.
        subfolder (`str`, *optional*, defaults to `""`):
            The subfolder location of a model file within a larger model repository on the Hub or locally.
        mirror (`str`, *optional*):
            Mirror source to resolve accessibility issues if you're downloading a model in China. We do not
            guarantee the timeliness or safety of the source, and you should refer to the mirror site for more
            information.

    Example:

    To load a Textual Inversion embedding vector in 🤗 Diffusers format:

    ```py
    from diffusers import StableDiffusionPipeline
    import torch

    model_id = "runwayml/stable-diffusion-v1-5"
    pipe = StableDiffusionPipeline.from_pretrained(model_id, torch_dtype=torch.float16).to("cuda")

    pipe.load_textual_inversion("sd-concepts-library/cat-toy")

    prompt = "A <cat-toy> backpack"

    image = pipe(prompt, num_inference_steps=50).images[0]
    image.save("cat-backpack.png")
    ```

    To load a Textual Inversion embedding vector in Automatic1111 format, make sure to download the vector first
    (for example from [civitAI](https://civitai.com/models/3036?modelVersionId=9857)) and then load the vector
    locally:

    ```py
    from diffusers import StableDiffusionPipeline
    import torch

    model_id = "runwayml/stable-diffusion-v1-5"
    pipe = StableDiffusionPipeline.from_pretrained(model_id, torch_dtype=torch.float16).to("cuda")

    pipe.load_textual_inversion("./charturnerv2.pt", token="charturnerv2")

    prompt = "charturnerv2, multiple views of the same character in the same outfit, a character turnaround of a woman wearing a black jacket and red shirt, best quality, intricate details."

    image = pipe(prompt, num_inference_steps=50).images[0]
    image.save("character.png")
    ```

    """
    # 1. Set correct tokenizer and text encoder: fall back to the pipeline's own components.
    tokenizer = tokenizer or getattr(self, "tokenizer", None)
    text_encoder = text_encoder or getattr(self, "text_encoder", None)

    # 2. Normalize inputs so the rest of the method can work on parallel lists.
    pretrained_model_name_or_paths = (
        [pretrained_model_name_or_path]
        if not isinstance(pretrained_model_name_or_path, list)
        else pretrained_model_name_or_path
    )
    tokens = [token] if not isinstance(token, list) else token
    if tokens[0] is None:
        # No token override: one `None` placeholder per source; tokens are then
        # resolved from each loaded state dict.
        tokens = tokens * len(pretrained_model_name_or_paths)

    # 3. Check inputs
    self._check_text_inv_inputs(tokenizer, text_encoder, pretrained_model_name_or_paths, tokens)

    # 4. Load state dicts of textual embeddings
    state_dicts = load_textual_inversion_state_dicts(pretrained_model_name_or_paths, **kwargs)

    # 4.1 Handle the special case when the state dict is a single tensor that holds
    # n embeddings for n tokens: split it into one embedding per token.
    if len(tokens) > 1 and len(state_dicts) == 1:
        if isinstance(state_dicts[0], torch.Tensor):
            state_dicts = list(state_dicts[0])
            if len(tokens) != len(state_dicts):
                raise ValueError(
                    f"You have passed a state_dict contains {len(state_dicts)} embeddings, and list of tokens of length {len(tokens)} "
                    f"Make sure both have the same length."
                )

    # 5. Retrieve tokens and embeddings from the loaded state dicts.
    tokens, embeddings = self._retrieve_tokens_and_embeddings(tokens, state_dicts, tokenizer)

    # 6. Expand multi-vector embeddings into one token/embedding pair per vector.
    tokens, embeddings = self._extend_tokens_and_embeddings(tokens, embeddings, tokenizer)

    # 7. Make sure all embeddings match the text encoder's embedding dimension.
    # (Fixed: this message previously was not an f-string and referenced undefined
    # names, so it printed the literal placeholders.)
    expected_emb_dim = text_encoder.get_input_embeddings().weight.shape[-1]
    if any(expected_emb_dim != emb.shape[-1] for emb in embeddings):
        raise ValueError(
            "Loaded embeddings are of incorrect shape. Expected each textual inversion embedding "
            f"to be of shape {expected_emb_dim}, but found {[emb.shape[-1] for emb in embeddings]}."
        )

    # 8. Now we can be sure that loading the embedding matrix works
    # < Unsafe code:

    # 8.1 Offload all hooks in case the pipeline was cpu offloaded before; make sure we offload and onload again
    is_model_cpu_offload = False
    is_sequential_cpu_offload = False
    if self.hf_device_map is None:
        for _, component in self.components.items():
            if isinstance(component, nn.Module):
                if hasattr(component, "_hf_hook"):
                    is_model_cpu_offload = isinstance(getattr(component, "_hf_hook"), CpuOffload)
                    # `and` binds tighter than `or`: "AlignDevicesHook directly, or a
                    # hook group whose first hook is an AlignDevicesHook".
                    is_sequential_cpu_offload = (
                        isinstance(getattr(component, "_hf_hook"), AlignDevicesHook)
                        or hasattr(component._hf_hook, "hooks")
                        and isinstance(component._hf_hook.hooks[0], AlignDevicesHook)
                    )
                    logger.info(
                        "Accelerate hooks detected. Since you have called `load_textual_inversion()`, the previous hooks will be first removed. Then the textual inversion parameters will be loaded and the hooks will be applied again."
                    )
                    remove_hook_from_module(component, recurse=is_sequential_cpu_offload)

    # 8.2 save expected device and dtype
    device = text_encoder.device
    dtype = text_encoder.dtype

    # 8.3 Increase token embedding matrix
    text_encoder.resize_token_embeddings(len(tokenizer) + len(tokens))
    input_embeddings = text_encoder.get_input_embeddings().weight

    # 8.4 Load token and embedding
    for token, embedding in zip(tokens, embeddings):
        # add tokens and get ids
        tokenizer.add_tokens(token)
        token_id = tokenizer.convert_tokens_to_ids(token)
        input_embeddings.data[token_id] = embedding
        logger.info(f"Loaded textual inversion embedding for {token}.")

    input_embeddings.to(dtype=dtype, device=device)

    # 8.5 Offload the model again
    if is_model_cpu_offload:
        self.enable_model_cpu_offload()
    elif is_sequential_cpu_offload:
        self.enable_sequential_cpu_offload()

    # / Unsafe Code >
457
+
458
def unload_textual_inversion(
    self,
    tokens: Optional[Union[str, List[str]]] = None,
    tokenizer: Optional["PreTrainedTokenizer"] = None,
    text_encoder: Optional["PreTrainedModel"] = None,
):
    r"""
    Unload Textual Inversion embeddings from the text encoder of [`StableDiffusionPipeline`].

    Removes the given added tokens from the tokenizer (or all non-special added tokens if `tokens` is `None`),
    re-compacts the remaining added-token ids to be sequential, and rebuilds the text encoder's input embedding
    matrix without the removed rows.

    Args:
        tokens (`str` or `List[str]`, *optional*):
            The textual inversion token(s) to remove. If `None`, every non-special added token is removed.
        tokenizer ([`~transformers.CLIPTokenizer`], *optional*):
            The tokenizer to remove tokens from. Defaults to `self.tokenizer`.
        text_encoder ([`~transformers.CLIPTextModel`], *optional*):
            The text encoder whose embedding rows are removed. Defaults to `self.text_encoder`.

    Example:
    ```py
    from diffusers import AutoPipelineForText2Image
    import torch

    pipeline = AutoPipelineForText2Image.from_pretrained("runwayml/stable-diffusion-v1-5")

    # Example 1
    pipeline.load_textual_inversion("sd-concepts-library/gta5-artwork")
    pipeline.load_textual_inversion("sd-concepts-library/moeb-style")

    # Remove all token embeddings
    pipeline.unload_textual_inversion()

    # Example 2
    pipeline.load_textual_inversion("sd-concepts-library/moeb-style")
    pipeline.load_textual_inversion("sd-concepts-library/gta5-artwork")

    # Remove just one token
    pipeline.unload_textual_inversion("<moe-bius>")

    # Example 3: unload from SDXL
    pipeline = AutoPipelineForText2Image.from_pretrained("stabilityai/stable-diffusion-xl-base-1.0")
    embedding_path = hf_hub_download(
        repo_id="linoyts/web_y2k", filename="web_y2k_emb.safetensors", repo_type="model"
    )

    # load embeddings to the text encoders
    state_dict = load_file(embedding_path)

    # load embeddings of text_encoder 1 (CLIP ViT-L/14)
    pipeline.load_textual_inversion(
        state_dict["clip_l"],
        token=["<s0>", "<s1>"],
        text_encoder=pipeline.text_encoder,
        tokenizer=pipeline.tokenizer,
    )
    # load embeddings of text_encoder 2 (CLIP ViT-G/14)
    pipeline.load_textual_inversion(
        state_dict["clip_g"],
        token=["<s0>", "<s1>"],
        text_encoder=pipeline.text_encoder_2,
        tokenizer=pipeline.tokenizer_2,
    )

    # Unload explicitly from both text encoders and tokenizers
    pipeline.unload_textual_inversion(
        tokens=["<s0>", "<s1>"], text_encoder=pipeline.text_encoder, tokenizer=pipeline.tokenizer
    )
    pipeline.unload_textual_inversion(
        tokens=["<s0>", "<s1>"], text_encoder=pipeline.text_encoder_2, tokenizer=pipeline.tokenizer_2
    )
    ```
    """

    # Fall back to the pipeline's own components when not supplied explicitly.
    tokenizer = tokenizer or getattr(self, "tokenizer", None)
    text_encoder = text_encoder or getattr(self, "text_encoder", None)

    # Get textual inversion tokens and ids
    token_ids = []
    last_special_token_id = None

    if tokens:
        if isinstance(tokens, str):
            tokens = [tokens]
        # Collect the ids of the requested tokens; also track the id of the last
        # *special* added token so all earlier embedding rows are preserved.
        for added_token_id, added_token in tokenizer.added_tokens_decoder.items():
            if not added_token.special:
                if added_token.content in tokens:
                    token_ids.append(added_token_id)
            else:
                last_special_token_id = added_token_id
        if len(token_ids) == 0:
            raise ValueError("No tokens to remove found")
    else:
        # No tokens given: remove every non-special added token.
        tokens = []
        for added_token_id, added_token in tokenizer.added_tokens_decoder.items():
            if not added_token.special:
                token_ids.append(added_token_id)
                tokens.append(added_token.content)
            else:
                last_special_token_id = added_token_id

    # Delete from tokenizer.
    # NOTE(review): this reaches into private `transformers` tokenizer internals
    # (`_added_tokens_decoder` / `_added_tokens_encoder`) — verify against the
    # pinned transformers version when upgrading.
    for token_id, token_to_remove in zip(token_ids, tokens):
        del tokenizer._added_tokens_decoder[token_id]
        del tokenizer._added_tokens_encoder[token_to_remove]

    # Make all token ids sequential in tokenizer by sliding surviving added
    # tokens down into the gaps left by the deletions.
    # NOTE(review): assumes at least one special added token exists; otherwise
    # `last_special_token_id` is still None and the comparison below raises — confirm.
    key_id = 1
    for token_id in tokenizer.added_tokens_decoder:
        if token_id > last_special_token_id and token_id > last_special_token_id + key_id:
            token = tokenizer._added_tokens_decoder[token_id]
            tokenizer._added_tokens_decoder[last_special_token_id + key_id] = token
            del tokenizer._added_tokens_decoder[token_id]
            tokenizer._added_tokens_encoder[token.content] = last_special_token_id + key_id
        key_id += 1
    # Rebuild the tokenizer's internal matching trie after mutating its token maps.
    tokenizer._update_trie()

    # Delete from text encoder: rebuild the embedding matrix without the removed rows.
    text_embedding_dim = text_encoder.get_input_embeddings().embedding_dim
    temp_text_embedding_weights = text_encoder.get_input_embeddings().weight
    # Rows up to and including the last special token are kept verbatim.
    text_embedding_weights = temp_text_embedding_weights[: last_special_token_id + 1]
    to_append = []
    for i in range(last_special_token_id + 1, temp_text_embedding_weights.shape[0]):
        if i not in token_ids:
            to_append.append(temp_text_embedding_weights[i].unsqueeze(0))
    if len(to_append) > 0:
        to_append = torch.cat(to_append, dim=0)
        text_embedding_weights = torch.cat([text_embedding_weights, to_append], dim=0)
    text_embeddings_filtered = nn.Embedding(text_embedding_weights.shape[0], text_embedding_dim)
    text_embeddings_filtered.weight.data = text_embedding_weights
    text_encoder.set_input_embeddings(text_embeddings_filtered)
diffusers/loaders/unet.py ADDED
@@ -0,0 +1,921 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright 2024 The HuggingFace Team. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ import os
15
+ from collections import defaultdict
16
+ from contextlib import nullcontext
17
+ from pathlib import Path
18
+ from typing import Callable, Dict, Union
19
+
20
+ import safetensors
21
+ import torch
22
+ import torch.nn.functional as F
23
+ from huggingface_hub.utils import validate_hf_hub_args
24
+ from torch import nn
25
+
26
+ from ..models.embeddings import (
27
+ ImageProjection,
28
+ IPAdapterFaceIDImageProjection,
29
+ IPAdapterFaceIDPlusImageProjection,
30
+ IPAdapterFullImageProjection,
31
+ IPAdapterPlusImageProjection,
32
+ MultiIPAdapterImageProjection,
33
+ )
34
+ from ..models.modeling_utils import load_model_dict_into_meta, load_state_dict
35
+ from ..utils import (
36
+ USE_PEFT_BACKEND,
37
+ _get_model_file,
38
+ convert_unet_state_dict_to_peft,
39
+ get_adapter_name,
40
+ get_peft_kwargs,
41
+ is_accelerate_available,
42
+ is_peft_version,
43
+ is_torch_version,
44
+ logging,
45
+ )
46
+ from .lora_pipeline import LORA_WEIGHT_NAME, LORA_WEIGHT_NAME_SAFE, TEXT_ENCODER_NAME, UNET_NAME
47
+ from .utils import AttnProcsLayers
48
+
49
+
50
+ if is_accelerate_available():
51
+ from accelerate.hooks import AlignDevicesHook, CpuOffload, remove_hook_from_module
52
+
53
+ logger = logging.get_logger(__name__)
54
+
55
+
56
+ CUSTOM_DIFFUSION_WEIGHT_NAME = "pytorch_custom_diffusion_weights.bin"
57
+ CUSTOM_DIFFUSION_WEIGHT_NAME_SAFE = "pytorch_custom_diffusion_weights.safetensors"
58
+
59
+
60
+ class UNet2DConditionLoadersMixin:
61
+ """
62
+ Load LoRA layers into a [`UNet2DCondtionModel`].
63
+ """
64
+
65
+ text_encoder_name = TEXT_ENCODER_NAME
66
+ unet_name = UNET_NAME
67
+
68
@validate_hf_hub_args
def load_attn_procs(self, pretrained_model_name_or_path_or_dict: Union[str, Dict[str, torch.Tensor]], **kwargs):
    r"""
    Load pretrained attention processor layers into [`UNet2DConditionModel`]. Attention processor layers have to be
    defined in
    [`attention_processor.py`](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py)
    and be a `torch.nn.Module` class. Currently supported: LoRA, Custom Diffusion. For LoRA, one must install
    `peft`: `pip install -U peft`.

    Parameters:
        pretrained_model_name_or_path_or_dict (`str` or `os.PathLike` or `dict`):
            Can be either:

                - A string, the model id (for example `google/ddpm-celebahq-256`) of a pretrained model hosted on
                  the Hub.
                - A path to a directory (for example `./my_model_directory`) containing the model weights saved
                  with [`ModelMixin.save_pretrained`].
                - A [torch state
                  dict](https://pytorch.org/tutorials/beginner/saving_loading_models.html#what-is-a-state-dict).

        cache_dir (`Union[str, os.PathLike]`, *optional*):
            Path to a directory where a downloaded pretrained model configuration is cached if the standard cache
            is not used.
        force_download (`bool`, *optional*, defaults to `False`):
            Whether or not to force the (re-)download of the model weights and configuration files, overriding the
            cached versions if they exist.
        proxies (`Dict[str, str]`, *optional*):
            A dictionary of proxy servers to use by protocol or endpoint, for example, `{'http': 'foo.bar:3128',
            'http://hostname': 'foo.bar:4012'}`. The proxies are used on each request.
        local_files_only (`bool`, *optional*, defaults to `False`):
            Whether to only load local model weights and configuration files or not. If set to `True`, the model
            won't be downloaded from the Hub.
        token (`str` or *bool*, *optional*):
            The token to use as HTTP bearer authorization for remote files. If `True`, the token generated from
            `diffusers-cli login` (stored in `~/.huggingface`) is used.
        revision (`str`, *optional*, defaults to `"main"`):
            The specific model version to use. It can be a branch name, a tag name, a commit id, or any identifier
            allowed by Git.
        subfolder (`str`, *optional*, defaults to `""`):
            The subfolder location of a model file within a larger model repository on the Hub or locally.
        network_alphas (`Dict[str, float]`):
            The value of the network alpha used for stable learning and preventing underflow. This value has the
            same meaning as the `--network_alpha` option in the kohya-ss trainer script. Refer to [this
            link](https://github.com/darkstorm2150/sd-scripts/blob/main/docs/train_network_README-en.md#execute-learning).
        adapter_name (`str`, *optional*, defaults to None):
            Adapter name to be used for referencing the loaded adapter model. If not specified, it will use
            `default_{i}` where i is the total number of adapters being loaded.
        weight_name (`str`, *optional*, defaults to None):
            Name of the serialized state dict file.
        low_cpu_mem_usage (`bool`, *optional*):
            Speed up model loading by only loading the pretrained LoRA weights and not initializing the random
            weights.

    Example:

    ```py
    from diffusers import AutoPipelineForText2Image
    import torch

    pipeline = AutoPipelineForText2Image.from_pretrained(
        "stabilityai/stable-diffusion-xl-base-1.0", torch_dtype=torch.float16
    ).to("cuda")
    pipeline.unet.load_attn_procs(
        "jbilcke-hf/sdxl-cinematic-1", weight_name="pytorch_lora_weights.safetensors", adapter_name="cinematic"
    )
    ```
    """
    # Pop all download/loading options out of kwargs before dispatching.
    cache_dir = kwargs.pop("cache_dir", None)
    force_download = kwargs.pop("force_download", False)
    proxies = kwargs.pop("proxies", None)
    local_files_only = kwargs.pop("local_files_only", None)
    token = kwargs.pop("token", None)
    revision = kwargs.pop("revision", None)
    subfolder = kwargs.pop("subfolder", None)
    weight_name = kwargs.pop("weight_name", None)
    use_safetensors = kwargs.pop("use_safetensors", None)
    adapter_name = kwargs.pop("adapter_name", None)
    _pipeline = kwargs.pop("_pipeline", None)
    network_alphas = kwargs.pop("network_alphas", None)
    low_cpu_mem_usage = kwargs.pop("low_cpu_mem_usage", False)
    allow_pickle = False

    # `low_cpu_mem_usage` relies on peft's meta-device loading added after 0.13.0.
    if low_cpu_mem_usage and is_peft_version("<=", "0.13.0"):
        raise ValueError(
            "`low_cpu_mem_usage=True` is not compatible with this `peft` version. Please update it with `pip install -U peft`."
        )

    # When the caller did not pin a serialization format, prefer safetensors but
    # allow falling back to a pickled checkpoint.
    if use_safetensors is None:
        use_safetensors = True
        allow_pickle = True

    user_agent = {
        "file_type": "attn_procs_weights",
        "framework": "pytorch",
    }

    model_file = None
    if not isinstance(pretrained_model_name_or_path_or_dict, dict):
        # Let's first try to load .safetensors weights
        if (use_safetensors and weight_name is None) or (
            weight_name is not None and weight_name.endswith(".safetensors")
        ):
            try:
                model_file = _get_model_file(
                    pretrained_model_name_or_path_or_dict,
                    weights_name=weight_name or LORA_WEIGHT_NAME_SAFE,
                    cache_dir=cache_dir,
                    force_download=force_download,
                    proxies=proxies,
                    local_files_only=local_files_only,
                    token=token,
                    revision=revision,
                    subfolder=subfolder,
                    user_agent=user_agent,
                )
                state_dict = safetensors.torch.load_file(model_file, device="cpu")
            except IOError as e:
                if not allow_pickle:
                    raise e
                # try loading non-safetensors weights
                pass
        if model_file is None:
            # Safetensors missing or explicitly skipped: load the pickled checkpoint.
            model_file = _get_model_file(
                pretrained_model_name_or_path_or_dict,
                weights_name=weight_name or LORA_WEIGHT_NAME,
                cache_dir=cache_dir,
                force_download=force_download,
                proxies=proxies,
                local_files_only=local_files_only,
                token=token,
                revision=revision,
                subfolder=subfolder,
                user_agent=user_agent,
            )
            state_dict = load_state_dict(model_file)
    else:
        # A state dict was passed directly; no file resolution needed.
        state_dict = pretrained_model_name_or_path_or_dict

    # Detect the checkpoint flavor from its keys, then dispatch.
    is_custom_diffusion = any("custom_diffusion" in k for k in state_dict.keys())
    is_lora = all(("lora" in k or k.endswith(".alpha")) for k in state_dict.keys())
    is_model_cpu_offload = False
    is_sequential_cpu_offload = False

    if is_custom_diffusion:
        attn_processors = self._process_custom_diffusion(state_dict=state_dict)
    elif is_lora:
        is_model_cpu_offload, is_sequential_cpu_offload = self._process_lora(
            state_dict=state_dict,
            unet_identifier_key=self.unet_name,
            network_alphas=network_alphas,
            adapter_name=adapter_name,
            _pipeline=_pipeline,
            low_cpu_mem_usage=low_cpu_mem_usage,
        )
    else:
        raise ValueError(
            f"{model_file} does not seem to be in the correct format expected by Custom Diffusion training."
        )

    # <Unsafe code
    # We can be sure that the following works as it just sets attention processors, lora layers and puts all in the same dtype
    # Now we remove any existing hooks to `_pipeline`.

    # For LoRA, the UNet is already offloaded at this stage as it is handled inside `_process_lora`.
    if is_custom_diffusion and _pipeline is not None:
        is_model_cpu_offload, is_sequential_cpu_offload = self._optionally_disable_offloading(_pipeline=_pipeline)

    # only custom diffusion needs to set attn processors
    self.set_attn_processor(attn_processors)
    self.to(dtype=self.dtype, device=self.device)

    # Offload back.
    if is_model_cpu_offload:
        _pipeline.enable_model_cpu_offload()
    elif is_sequential_cpu_offload:
        _pipeline.enable_sequential_cpu_offload()
    # Unsafe code />
246
+
247
def _process_custom_diffusion(self, state_dict):
    """Build Custom Diffusion attention processors from a flat checkpoint state dict.

    Groups the flat ``state_dict`` keys by attention-processor module, then
    instantiates one `CustomDiffusionAttnProcessor` per group and loads its
    weights. Empty-tensor entries mark processors that were saved untrained and
    become bare (train_kv=False, train_q_out=False) processors.

    Returns:
        dict: mapping of attention-processor key to `CustomDiffusionAttnProcessor`.
    """
    from ..models.attention_processor import CustomDiffusionAttnProcessor

    grouped = defaultdict(dict)
    for full_key, tensor in state_dict.items():
        if len(tensor) == 0:
            # Untrained processor: keep the full key with an empty parameter dict.
            grouped[full_key] = {}
        else:
            # "to_out" parameters sit one module level deeper
            # (…to_out_custom_diffusion.0.weight), so split three segments off
            # instead of two.
            parts = full_key.split(".")
            depth = 3 if "to_out" in full_key else 2
            grouped[".".join(parts[:-depth])][".".join(parts[-depth:])] = tensor

    processors = {}
    for module_key, params in grouped.items():
        if len(params) == 0:
            processors[module_key] = CustomDiffusionAttnProcessor(
                train_kv=False, train_q_out=False, hidden_size=None, cross_attention_dim=None
            )
        else:
            # The K-projection weight fixes both dimensions of the processor.
            key_weight = params["to_k_custom_diffusion.weight"]
            processors[module_key] = CustomDiffusionAttnProcessor(
                train_kv=True,
                train_q_out="to_q_custom_diffusion.weight" in params,
                hidden_size=key_weight.shape[0],
                cross_attention_dim=key_weight.shape[1],
            )
            processors[module_key].load_state_dict(params)

    return processors
280
+
281
def _process_lora(
    self, state_dict, unet_identifier_key, network_alphas, adapter_name, _pipeline, low_cpu_mem_usage
):
    # Inject a LoRA checkpoint into this UNet via peft and report the pipeline's
    # offloading state. This method does the following things:
    # 1. Filters the `state_dict` with keys matching `unet_identifier_key` when using the non-legacy
    #    format. For legacy format no filtering is applied.
    # 2. Converts the `state_dict` to the `peft` compatible format.
    # 3. Creates a `LoraConfig` and then injects the converted `state_dict` into the UNet per the
    #    `LoraConfig` specs.
    # 4. It also reports if the underlying `_pipeline` has any kind of offloading inside of it.
    if not USE_PEFT_BACKEND:
        raise ValueError("PEFT backend is required for this method.")

    # Imported lazily so that `peft` stays an optional dependency.
    from peft import LoraConfig, inject_adapter_in_model, set_peft_model_state_dict

    keys = list(state_dict.keys())

    # Non-legacy checkpoints prefix UNet keys with `unet_identifier_key`; strip it.
    unet_keys = [k for k in keys if k.startswith(unet_identifier_key)]
    unet_state_dict = {
        k.replace(f"{unet_identifier_key}.", ""): v for k, v in state_dict.items() if k in unet_keys
    }

    if network_alphas is not None:
        # Apply the same prefix filtering/stripping to the alpha dict.
        alpha_keys = [k for k in network_alphas.keys() if k.startswith(unet_identifier_key)]
        network_alphas = {
            k.replace(f"{unet_identifier_key}.", ""): v for k, v in network_alphas.items() if k in alpha_keys
        }

    is_model_cpu_offload = False
    is_sequential_cpu_offload = False
    # Legacy checkpoints have no UNet prefix: fall back to the unfiltered dict.
    state_dict_to_be_used = unet_state_dict if len(unet_state_dict) > 0 else state_dict

    if len(state_dict_to_be_used) > 0:
        if adapter_name in getattr(self, "peft_config", {}):
            raise ValueError(
                f"Adapter name {adapter_name} already in use in the Unet - please select a new adapter name."
            )

        state_dict = convert_unet_state_dict_to_peft(state_dict_to_be_used)

        if network_alphas is not None:
            # The alphas state dict have the same structure as Unet, thus we convert it to peft format using
            # `convert_unet_state_dict_to_peft` method.
            network_alphas = convert_unet_state_dict_to_peft(network_alphas)

        # LoRA rank per layer is the inner dimension of the `lora_B` matrix.
        rank = {}
        for key, val in state_dict.items():
            if "lora_B" in key:
                rank[key] = val.shape[1]

        lora_config_kwargs = get_peft_kwargs(rank, network_alphas, state_dict, is_unet=True)
        if "use_dora" in lora_config_kwargs:
            if lora_config_kwargs["use_dora"]:
                # DoRA support landed in peft 0.9.0.
                if is_peft_version("<", "0.9.0"):
                    raise ValueError(
                        "You need `peft` 0.9.0 at least to use DoRA-enabled LoRAs. Please upgrade your installation of `peft`."
                    )
            else:
                # Older peft versions reject the kwarg even when it is False.
                if is_peft_version("<", "0.9.0"):
                    lora_config_kwargs.pop("use_dora")
        lora_config = LoraConfig(**lora_config_kwargs)

        # adapter_name
        if adapter_name is None:
            adapter_name = get_adapter_name(self)

        # In case the pipeline has been already offloaded to CPU - temporarily remove the hooks
        # otherwise loading LoRA weights will lead to an error
        is_model_cpu_offload, is_sequential_cpu_offload = self._optionally_disable_offloading(_pipeline)
        peft_kwargs = {}
        if is_peft_version(">=", "0.13.1"):
            peft_kwargs["low_cpu_mem_usage"] = low_cpu_mem_usage

        inject_adapter_in_model(lora_config, self, adapter_name=adapter_name, **peft_kwargs)
        incompatible_keys = set_peft_model_state_dict(self, state_dict, adapter_name, **peft_kwargs)

        if incompatible_keys is not None:
            # check only for unexpected keys; missing keys are expected because the
            # base model weights are not part of the LoRA state dict.
            unexpected_keys = getattr(incompatible_keys, "unexpected_keys", None)
            if unexpected_keys:
                logger.warning(
                    f"Loading adapter weights from state_dict led to unexpected keys not found in the model: "
                    f" {unexpected_keys}. "
                )

    # Caller is responsible for re-enabling offloading based on these flags.
    return is_model_cpu_offload, is_sequential_cpu_offload
367
+
368
@classmethod
# Copied from diffusers.loaders.lora_base.LoraBaseMixin._optionally_disable_offloading
def _optionally_disable_offloading(cls, _pipeline):
    """
    Optionally removes offloading in case the pipeline has been already sequentially offloaded to CPU.

    Args:
        _pipeline (`DiffusionPipeline`):
            The pipeline to disable offloading for.

    Returns:
        tuple:
            A tuple indicating if `is_model_cpu_offload` or `is_sequential_cpu_offload` is True.
    """
    is_model_cpu_offload = False
    is_sequential_cpu_offload = False

    # Skip entirely when the pipeline uses an explicit device map (hooks are managed elsewhere).
    if _pipeline is not None and _pipeline.hf_device_map is None:
        for _, component in _pipeline.components.items():
            if isinstance(component, nn.Module) and hasattr(component, "_hf_hook"):
                # Once a flag is set by one component it stays set for the rest.
                if not is_model_cpu_offload:
                    is_model_cpu_offload = isinstance(component._hf_hook, CpuOffload)
                if not is_sequential_cpu_offload:
                    # `and` binds tighter than `or`: either the hook itself is an
                    # AlignDevicesHook, or it is a hook group whose first hook is one.
                    is_sequential_cpu_offload = (
                        isinstance(component._hf_hook, AlignDevicesHook)
                        or hasattr(component._hf_hook, "hooks")
                        and isinstance(component._hf_hook.hooks[0], AlignDevicesHook)
                    )

                logger.info(
                    "Accelerate hooks detected. Since you have called `load_lora_weights()`, the previous hooks will be first removed. Then the LoRA parameters will be loaded and the hooks will be applied again."
                )
                remove_hook_from_module(component, recurse=is_sequential_cpu_offload)

    return (is_model_cpu_offload, is_sequential_cpu_offload)
403
+
404
+ def save_attn_procs(
405
+ self,
406
+ save_directory: Union[str, os.PathLike],
407
+ is_main_process: bool = True,
408
+ weight_name: str = None,
409
+ save_function: Callable = None,
410
+ safe_serialization: bool = True,
411
+ **kwargs,
412
+ ):
413
+ r"""
414
+ Save attention processor layers to a directory so that it can be reloaded with the
415
+ [`~loaders.UNet2DConditionLoadersMixin.load_attn_procs`] method.
416
+
417
+ Arguments:
418
+ save_directory (`str` or `os.PathLike`):
419
+ Directory to save an attention processor to (will be created if it doesn't exist).
420
+ is_main_process (`bool`, *optional*, defaults to `True`):
421
+ Whether the process calling this is the main process or not. Useful during distributed training and you
422
+ need to call this function on all processes. In this case, set `is_main_process=True` only on the main
423
+ process to avoid race conditions.
424
+ save_function (`Callable`):
425
+ The function to use to save the state dictionary. Useful during distributed training when you need to
426
+ replace `torch.save` with another method. Can be configured with the environment variable
427
+ `DIFFUSERS_SAVE_MODE`.
428
+ safe_serialization (`bool`, *optional*, defaults to `True`):
429
+ Whether to save the model using `safetensors` or with `pickle`.
430
+
431
+ Example:
432
+
433
+ ```py
434
+ import torch
435
+ from diffusers import DiffusionPipeline
436
+
437
+ pipeline = DiffusionPipeline.from_pretrained(
438
+ "CompVis/stable-diffusion-v1-4",
439
+ torch_dtype=torch.float16,
440
+ ).to("cuda")
441
+ pipeline.unet.load_attn_procs("path-to-save-model", weight_name="pytorch_custom_diffusion_weights.bin")
442
+ pipeline.unet.save_attn_procs("path-to-save-model", weight_name="pytorch_custom_diffusion_weights.bin")
443
+ ```
444
+ """
445
+ from ..models.attention_processor import (
446
+ CustomDiffusionAttnProcessor,
447
+ CustomDiffusionAttnProcessor2_0,
448
+ CustomDiffusionXFormersAttnProcessor,
449
+ )
450
+
451
+ if os.path.isfile(save_directory):
452
+ logger.error(f"Provided path ({save_directory}) should be a directory, not a file")
453
+ return
454
+
455
+ is_custom_diffusion = any(
456
+ isinstance(
457
+ x,
458
+ (CustomDiffusionAttnProcessor, CustomDiffusionAttnProcessor2_0, CustomDiffusionXFormersAttnProcessor),
459
+ )
460
+ for (_, x) in self.attn_processors.items()
461
+ )
462
+ if is_custom_diffusion:
463
+ state_dict = self._get_custom_diffusion_state_dict()
464
+ if save_function is None and safe_serialization:
465
+ # safetensors does not support saving dicts with non-tensor values
466
+ empty_state_dict = {k: v for k, v in state_dict.items() if not isinstance(v, torch.Tensor)}
467
+ if len(empty_state_dict) > 0:
468
+ logger.warning(
469
+ f"Safetensors does not support saving dicts with non-tensor values. "
470
+ f"The following keys will be ignored: {empty_state_dict.keys()}"
471
+ )
472
+ state_dict = {k: v for k, v in state_dict.items() if isinstance(v, torch.Tensor)}
473
+ else:
474
+ if not USE_PEFT_BACKEND:
475
+ raise ValueError("PEFT backend is required for saving LoRAs using the `save_attn_procs()` method.")
476
+
477
+ from peft.utils import get_peft_model_state_dict
478
+
479
+ state_dict = get_peft_model_state_dict(self)
480
+
481
+ if save_function is None:
482
+ if safe_serialization:
483
+
484
+ def save_function(weights, filename):
485
+ return safetensors.torch.save_file(weights, filename, metadata={"format": "pt"})
486
+
487
+ else:
488
+ save_function = torch.save
489
+
490
+ os.makedirs(save_directory, exist_ok=True)
491
+
492
+ if weight_name is None:
493
+ if safe_serialization:
494
+ weight_name = CUSTOM_DIFFUSION_WEIGHT_NAME_SAFE if is_custom_diffusion else LORA_WEIGHT_NAME_SAFE
495
+ else:
496
+ weight_name = CUSTOM_DIFFUSION_WEIGHT_NAME if is_custom_diffusion else LORA_WEIGHT_NAME
497
+
498
+ # Save the model
499
+ save_path = Path(save_directory, weight_name).as_posix()
500
+ save_function(state_dict, save_path)
501
+ logger.info(f"Model weights saved in {save_path}")
502
+
503
+ def _get_custom_diffusion_state_dict(self):
504
+ from ..models.attention_processor import (
505
+ CustomDiffusionAttnProcessor,
506
+ CustomDiffusionAttnProcessor2_0,
507
+ CustomDiffusionXFormersAttnProcessor,
508
+ )
509
+
510
+ model_to_save = AttnProcsLayers(
511
+ {
512
+ y: x
513
+ for (y, x) in self.attn_processors.items()
514
+ if isinstance(
515
+ x,
516
+ (
517
+ CustomDiffusionAttnProcessor,
518
+ CustomDiffusionAttnProcessor2_0,
519
+ CustomDiffusionXFormersAttnProcessor,
520
+ ),
521
+ )
522
+ }
523
+ )
524
+ state_dict = model_to_save.state_dict()
525
+ for name, attn in self.attn_processors.items():
526
+ if len(attn.state_dict()) == 0:
527
+ state_dict[name] = {}
528
+
529
+ return state_dict
530
+
531
+ def _convert_ip_adapter_image_proj_to_diffusers(self, state_dict, low_cpu_mem_usage=False):
532
+ if low_cpu_mem_usage:
533
+ if is_accelerate_available():
534
+ from accelerate import init_empty_weights
535
+
536
+ else:
537
+ low_cpu_mem_usage = False
538
+ logger.warning(
539
+ "Cannot initialize model with low cpu memory usage because `accelerate` was not found in the"
540
+ " environment. Defaulting to `low_cpu_mem_usage=False`. It is strongly recommended to install"
541
+ " `accelerate` for faster and less memory-intense model loading. You can do so with: \n```\npip"
542
+ " install accelerate\n```\n."
543
+ )
544
+
545
+ if low_cpu_mem_usage is True and not is_torch_version(">=", "1.9.0"):
546
+ raise NotImplementedError(
547
+ "Low memory initialization requires torch >= 1.9.0. Please either update your PyTorch version or set"
548
+ " `low_cpu_mem_usage=False`."
549
+ )
550
+
551
+ updated_state_dict = {}
552
+ image_projection = None
553
+ init_context = init_empty_weights if low_cpu_mem_usage else nullcontext
554
+
555
+ if "proj.weight" in state_dict:
556
+ # IP-Adapter
557
+ num_image_text_embeds = 4
558
+ clip_embeddings_dim = state_dict["proj.weight"].shape[-1]
559
+ cross_attention_dim = state_dict["proj.weight"].shape[0] // 4
560
+
561
+ with init_context():
562
+ image_projection = ImageProjection(
563
+ cross_attention_dim=cross_attention_dim,
564
+ image_embed_dim=clip_embeddings_dim,
565
+ num_image_text_embeds=num_image_text_embeds,
566
+ )
567
+
568
+ for key, value in state_dict.items():
569
+ diffusers_name = key.replace("proj", "image_embeds")
570
+ updated_state_dict[diffusers_name] = value
571
+
572
+ elif "proj.3.weight" in state_dict:
573
+ # IP-Adapter Full
574
+ clip_embeddings_dim = state_dict["proj.0.weight"].shape[0]
575
+ cross_attention_dim = state_dict["proj.3.weight"].shape[0]
576
+
577
+ with init_context():
578
+ image_projection = IPAdapterFullImageProjection(
579
+ cross_attention_dim=cross_attention_dim, image_embed_dim=clip_embeddings_dim
580
+ )
581
+
582
+ for key, value in state_dict.items():
583
+ diffusers_name = key.replace("proj.0", "ff.net.0.proj")
584
+ diffusers_name = diffusers_name.replace("proj.2", "ff.net.2")
585
+ diffusers_name = diffusers_name.replace("proj.3", "norm")
586
+ updated_state_dict[diffusers_name] = value
587
+
588
+ elif "perceiver_resampler.proj_in.weight" in state_dict:
589
+ # IP-Adapter Face ID Plus
590
+ id_embeddings_dim = state_dict["proj.0.weight"].shape[1]
591
+ embed_dims = state_dict["perceiver_resampler.proj_in.weight"].shape[0]
592
+ hidden_dims = state_dict["perceiver_resampler.proj_in.weight"].shape[1]
593
+ output_dims = state_dict["perceiver_resampler.proj_out.weight"].shape[0]
594
+ heads = state_dict["perceiver_resampler.layers.0.0.to_q.weight"].shape[0] // 64
595
+
596
+ with init_context():
597
+ image_projection = IPAdapterFaceIDPlusImageProjection(
598
+ embed_dims=embed_dims,
599
+ output_dims=output_dims,
600
+ hidden_dims=hidden_dims,
601
+ heads=heads,
602
+ id_embeddings_dim=id_embeddings_dim,
603
+ )
604
+
605
+ for key, value in state_dict.items():
606
+ diffusers_name = key.replace("perceiver_resampler.", "")
607
+ diffusers_name = diffusers_name.replace("0.to", "attn.to")
608
+ diffusers_name = diffusers_name.replace("0.1.0.", "0.ff.0.")
609
+ diffusers_name = diffusers_name.replace("0.1.1.weight", "0.ff.1.net.0.proj.weight")
610
+ diffusers_name = diffusers_name.replace("0.1.3.weight", "0.ff.1.net.2.weight")
611
+ diffusers_name = diffusers_name.replace("1.1.0.", "1.ff.0.")
612
+ diffusers_name = diffusers_name.replace("1.1.1.weight", "1.ff.1.net.0.proj.weight")
613
+ diffusers_name = diffusers_name.replace("1.1.3.weight", "1.ff.1.net.2.weight")
614
+ diffusers_name = diffusers_name.replace("2.1.0.", "2.ff.0.")
615
+ diffusers_name = diffusers_name.replace("2.1.1.weight", "2.ff.1.net.0.proj.weight")
616
+ diffusers_name = diffusers_name.replace("2.1.3.weight", "2.ff.1.net.2.weight")
617
+ diffusers_name = diffusers_name.replace("3.1.0.", "3.ff.0.")
618
+ diffusers_name = diffusers_name.replace("3.1.1.weight", "3.ff.1.net.0.proj.weight")
619
+ diffusers_name = diffusers_name.replace("3.1.3.weight", "3.ff.1.net.2.weight")
620
+ diffusers_name = diffusers_name.replace("layers.0.0", "layers.0.ln0")
621
+ diffusers_name = diffusers_name.replace("layers.0.1", "layers.0.ln1")
622
+ diffusers_name = diffusers_name.replace("layers.1.0", "layers.1.ln0")
623
+ diffusers_name = diffusers_name.replace("layers.1.1", "layers.1.ln1")
624
+ diffusers_name = diffusers_name.replace("layers.2.0", "layers.2.ln0")
625
+ diffusers_name = diffusers_name.replace("layers.2.1", "layers.2.ln1")
626
+ diffusers_name = diffusers_name.replace("layers.3.0", "layers.3.ln0")
627
+ diffusers_name = diffusers_name.replace("layers.3.1", "layers.3.ln1")
628
+
629
+ if "norm1" in diffusers_name:
630
+ updated_state_dict[diffusers_name.replace("0.norm1", "0")] = value
631
+ elif "norm2" in diffusers_name:
632
+ updated_state_dict[diffusers_name.replace("0.norm2", "1")] = value
633
+ elif "to_kv" in diffusers_name:
634
+ v_chunk = value.chunk(2, dim=0)
635
+ updated_state_dict[diffusers_name.replace("to_kv", "to_k")] = v_chunk[0]
636
+ updated_state_dict[diffusers_name.replace("to_kv", "to_v")] = v_chunk[1]
637
+ elif "to_out" in diffusers_name:
638
+ updated_state_dict[diffusers_name.replace("to_out", "to_out.0")] = value
639
+ elif "proj.0.weight" == diffusers_name:
640
+ updated_state_dict["proj.net.0.proj.weight"] = value
641
+ elif "proj.0.bias" == diffusers_name:
642
+ updated_state_dict["proj.net.0.proj.bias"] = value
643
+ elif "proj.2.weight" == diffusers_name:
644
+ updated_state_dict["proj.net.2.weight"] = value
645
+ elif "proj.2.bias" == diffusers_name:
646
+ updated_state_dict["proj.net.2.bias"] = value
647
+ else:
648
+ updated_state_dict[diffusers_name] = value
649
+
650
+ elif "norm.weight" in state_dict:
651
+ # IP-Adapter Face ID
652
+ id_embeddings_dim_in = state_dict["proj.0.weight"].shape[1]
653
+ id_embeddings_dim_out = state_dict["proj.0.weight"].shape[0]
654
+ multiplier = id_embeddings_dim_out // id_embeddings_dim_in
655
+ norm_layer = "norm.weight"
656
+ cross_attention_dim = state_dict[norm_layer].shape[0]
657
+ num_tokens = state_dict["proj.2.weight"].shape[0] // cross_attention_dim
658
+
659
+ with init_context():
660
+ image_projection = IPAdapterFaceIDImageProjection(
661
+ cross_attention_dim=cross_attention_dim,
662
+ image_embed_dim=id_embeddings_dim_in,
663
+ mult=multiplier,
664
+ num_tokens=num_tokens,
665
+ )
666
+
667
+ for key, value in state_dict.items():
668
+ diffusers_name = key.replace("proj.0", "ff.net.0.proj")
669
+ diffusers_name = diffusers_name.replace("proj.2", "ff.net.2")
670
+ updated_state_dict[diffusers_name] = value
671
+
672
+ else:
673
+ # IP-Adapter Plus
674
+ num_image_text_embeds = state_dict["latents"].shape[1]
675
+ embed_dims = state_dict["proj_in.weight"].shape[1]
676
+ output_dims = state_dict["proj_out.weight"].shape[0]
677
+ hidden_dims = state_dict["latents"].shape[2]
678
+ attn_key_present = any("attn" in k for k in state_dict)
679
+ heads = (
680
+ state_dict["layers.0.attn.to_q.weight"].shape[0] // 64
681
+ if attn_key_present
682
+ else state_dict["layers.0.0.to_q.weight"].shape[0] // 64
683
+ )
684
+
685
+ with init_context():
686
+ image_projection = IPAdapterPlusImageProjection(
687
+ embed_dims=embed_dims,
688
+ output_dims=output_dims,
689
+ hidden_dims=hidden_dims,
690
+ heads=heads,
691
+ num_queries=num_image_text_embeds,
692
+ )
693
+
694
+ for key, value in state_dict.items():
695
+ diffusers_name = key.replace("0.to", "2.to")
696
+
697
+ diffusers_name = diffusers_name.replace("0.0.norm1", "0.ln0")
698
+ diffusers_name = diffusers_name.replace("0.0.norm2", "0.ln1")
699
+ diffusers_name = diffusers_name.replace("1.0.norm1", "1.ln0")
700
+ diffusers_name = diffusers_name.replace("1.0.norm2", "1.ln1")
701
+ diffusers_name = diffusers_name.replace("2.0.norm1", "2.ln0")
702
+ diffusers_name = diffusers_name.replace("2.0.norm2", "2.ln1")
703
+ diffusers_name = diffusers_name.replace("3.0.norm1", "3.ln0")
704
+ diffusers_name = diffusers_name.replace("3.0.norm2", "3.ln1")
705
+
706
+ if "to_kv" in diffusers_name:
707
+ parts = diffusers_name.split(".")
708
+ parts[2] = "attn"
709
+ diffusers_name = ".".join(parts)
710
+ v_chunk = value.chunk(2, dim=0)
711
+ updated_state_dict[diffusers_name.replace("to_kv", "to_k")] = v_chunk[0]
712
+ updated_state_dict[diffusers_name.replace("to_kv", "to_v")] = v_chunk[1]
713
+ elif "to_q" in diffusers_name:
714
+ parts = diffusers_name.split(".")
715
+ parts[2] = "attn"
716
+ diffusers_name = ".".join(parts)
717
+ updated_state_dict[diffusers_name] = value
718
+ elif "to_out" in diffusers_name:
719
+ parts = diffusers_name.split(".")
720
+ parts[2] = "attn"
721
+ diffusers_name = ".".join(parts)
722
+ updated_state_dict[diffusers_name.replace("to_out", "to_out.0")] = value
723
+ else:
724
+ diffusers_name = diffusers_name.replace("0.1.0", "0.ff.0")
725
+ diffusers_name = diffusers_name.replace("0.1.1", "0.ff.1.net.0.proj")
726
+ diffusers_name = diffusers_name.replace("0.1.3", "0.ff.1.net.2")
727
+
728
+ diffusers_name = diffusers_name.replace("1.1.0", "1.ff.0")
729
+ diffusers_name = diffusers_name.replace("1.1.1", "1.ff.1.net.0.proj")
730
+ diffusers_name = diffusers_name.replace("1.1.3", "1.ff.1.net.2")
731
+
732
+ diffusers_name = diffusers_name.replace("2.1.0", "2.ff.0")
733
+ diffusers_name = diffusers_name.replace("2.1.1", "2.ff.1.net.0.proj")
734
+ diffusers_name = diffusers_name.replace("2.1.3", "2.ff.1.net.2")
735
+
736
+ diffusers_name = diffusers_name.replace("3.1.0", "3.ff.0")
737
+ diffusers_name = diffusers_name.replace("3.1.1", "3.ff.1.net.0.proj")
738
+ diffusers_name = diffusers_name.replace("3.1.3", "3.ff.1.net.2")
739
+ updated_state_dict[diffusers_name] = value
740
+
741
+ if not low_cpu_mem_usage:
742
+ image_projection.load_state_dict(updated_state_dict, strict=True)
743
+ else:
744
+ load_model_dict_into_meta(image_projection, updated_state_dict, device=self.device, dtype=self.dtype)
745
+
746
+ return image_projection
747
+
748
+ def _convert_ip_adapter_attn_to_diffusers(self, state_dicts, low_cpu_mem_usage=False):
749
+ from ..models.attention_processor import (
750
+ IPAdapterAttnProcessor,
751
+ IPAdapterAttnProcessor2_0,
752
+ )
753
+
754
+ if low_cpu_mem_usage:
755
+ if is_accelerate_available():
756
+ from accelerate import init_empty_weights
757
+
758
+ else:
759
+ low_cpu_mem_usage = False
760
+ logger.warning(
761
+ "Cannot initialize model with low cpu memory usage because `accelerate` was not found in the"
762
+ " environment. Defaulting to `low_cpu_mem_usage=False`. It is strongly recommended to install"
763
+ " `accelerate` for faster and less memory-intense model loading. You can do so with: \n```\npip"
764
+ " install accelerate\n```\n."
765
+ )
766
+
767
+ if low_cpu_mem_usage is True and not is_torch_version(">=", "1.9.0"):
768
+ raise NotImplementedError(
769
+ "Low memory initialization requires torch >= 1.9.0. Please either update your PyTorch version or set"
770
+ " `low_cpu_mem_usage=False`."
771
+ )
772
+
773
+ # set ip-adapter cross-attention processors & load state_dict
774
+ attn_procs = {}
775
+ key_id = 1
776
+ init_context = init_empty_weights if low_cpu_mem_usage else nullcontext
777
+ for name in self.attn_processors.keys():
778
+ cross_attention_dim = None if name.endswith("attn1.processor") else self.config.cross_attention_dim
779
+ if name.startswith("mid_block"):
780
+ hidden_size = self.config.block_out_channels[-1]
781
+ elif name.startswith("up_blocks"):
782
+ block_id = int(name[len("up_blocks.")])
783
+ hidden_size = list(reversed(self.config.block_out_channels))[block_id]
784
+ elif name.startswith("down_blocks"):
785
+ block_id = int(name[len("down_blocks.")])
786
+ hidden_size = self.config.block_out_channels[block_id]
787
+
788
+ if cross_attention_dim is None or "motion_modules" in name:
789
+ attn_processor_class = self.attn_processors[name].__class__
790
+ attn_procs[name] = attn_processor_class()
791
+
792
+ else:
793
+ attn_processor_class = (
794
+ IPAdapterAttnProcessor2_0 if hasattr(F, "scaled_dot_product_attention") else IPAdapterAttnProcessor
795
+ )
796
+ num_image_text_embeds = []
797
+ for state_dict in state_dicts:
798
+ if "proj.weight" in state_dict["image_proj"]:
799
+ # IP-Adapter
800
+ num_image_text_embeds += [4]
801
+ elif "proj.3.weight" in state_dict["image_proj"]:
802
+ # IP-Adapter Full Face
803
+ num_image_text_embeds += [257] # 256 CLIP tokens + 1 CLS token
804
+ elif "perceiver_resampler.proj_in.weight" in state_dict["image_proj"]:
805
+ # IP-Adapter Face ID Plus
806
+ num_image_text_embeds += [4]
807
+ elif "norm.weight" in state_dict["image_proj"]:
808
+ # IP-Adapter Face ID
809
+ num_image_text_embeds += [4]
810
+ else:
811
+ # IP-Adapter Plus
812
+ num_image_text_embeds += [state_dict["image_proj"]["latents"].shape[1]]
813
+
814
+ with init_context():
815
+ attn_procs[name] = attn_processor_class(
816
+ hidden_size=hidden_size,
817
+ cross_attention_dim=cross_attention_dim,
818
+ scale=1.0,
819
+ num_tokens=num_image_text_embeds,
820
+ )
821
+
822
+ value_dict = {}
823
+ for i, state_dict in enumerate(state_dicts):
824
+ value_dict.update({f"to_k_ip.{i}.weight": state_dict["ip_adapter"][f"{key_id}.to_k_ip.weight"]})
825
+ value_dict.update({f"to_v_ip.{i}.weight": state_dict["ip_adapter"][f"{key_id}.to_v_ip.weight"]})
826
+
827
+ if not low_cpu_mem_usage:
828
+ attn_procs[name].load_state_dict(value_dict)
829
+ else:
830
+ device = next(iter(value_dict.values())).device
831
+ dtype = next(iter(value_dict.values())).dtype
832
+ load_model_dict_into_meta(attn_procs[name], value_dict, device=device, dtype=dtype)
833
+
834
+ key_id += 2
835
+
836
+ return attn_procs
837
+
838
+ def _load_ip_adapter_weights(self, state_dicts, low_cpu_mem_usage=False):
839
+ if not isinstance(state_dicts, list):
840
+ state_dicts = [state_dicts]
841
+
842
+ # Kolors Unet already has a `encoder_hid_proj`
843
+ if (
844
+ self.encoder_hid_proj is not None
845
+ and self.config.encoder_hid_dim_type == "text_proj"
846
+ and not hasattr(self, "text_encoder_hid_proj")
847
+ ):
848
+ self.text_encoder_hid_proj = self.encoder_hid_proj
849
+
850
+ # Set encoder_hid_proj after loading ip_adapter weights,
851
+ # because `IPAdapterPlusImageProjection` also has `attn_processors`.
852
+ self.encoder_hid_proj = None
853
+
854
+ attn_procs = self._convert_ip_adapter_attn_to_diffusers(state_dicts, low_cpu_mem_usage=low_cpu_mem_usage)
855
+ self.set_attn_processor(attn_procs)
856
+
857
+ # convert IP-Adapter Image Projection layers to diffusers
858
+ image_projection_layers = []
859
+ for state_dict in state_dicts:
860
+ image_projection_layer = self._convert_ip_adapter_image_proj_to_diffusers(
861
+ state_dict["image_proj"], low_cpu_mem_usage=low_cpu_mem_usage
862
+ )
863
+ image_projection_layers.append(image_projection_layer)
864
+
865
+ self.encoder_hid_proj = MultiIPAdapterImageProjection(image_projection_layers)
866
+ self.config.encoder_hid_dim_type = "ip_image_proj"
867
+
868
+ self.to(dtype=self.dtype, device=self.device)
869
+
870
+ def _load_ip_adapter_loras(self, state_dicts):
871
+ lora_dicts = {}
872
+ for key_id, name in enumerate(self.attn_processors.keys()):
873
+ for i, state_dict in enumerate(state_dicts):
874
+ if f"{key_id}.to_k_lora.down.weight" in state_dict["ip_adapter"]:
875
+ if i not in lora_dicts:
876
+ lora_dicts[i] = {}
877
+ lora_dicts[i].update(
878
+ {
879
+ f"unet.{name}.to_k_lora.down.weight": state_dict["ip_adapter"][
880
+ f"{key_id}.to_k_lora.down.weight"
881
+ ]
882
+ }
883
+ )
884
+ lora_dicts[i].update(
885
+ {
886
+ f"unet.{name}.to_q_lora.down.weight": state_dict["ip_adapter"][
887
+ f"{key_id}.to_q_lora.down.weight"
888
+ ]
889
+ }
890
+ )
891
+ lora_dicts[i].update(
892
+ {
893
+ f"unet.{name}.to_v_lora.down.weight": state_dict["ip_adapter"][
894
+ f"{key_id}.to_v_lora.down.weight"
895
+ ]
896
+ }
897
+ )
898
+ lora_dicts[i].update(
899
+ {
900
+ f"unet.{name}.to_out_lora.down.weight": state_dict["ip_adapter"][
901
+ f"{key_id}.to_out_lora.down.weight"
902
+ ]
903
+ }
904
+ )
905
+ lora_dicts[i].update(
906
+ {f"unet.{name}.to_k_lora.up.weight": state_dict["ip_adapter"][f"{key_id}.to_k_lora.up.weight"]}
907
+ )
908
+ lora_dicts[i].update(
909
+ {f"unet.{name}.to_q_lora.up.weight": state_dict["ip_adapter"][f"{key_id}.to_q_lora.up.weight"]}
910
+ )
911
+ lora_dicts[i].update(
912
+ {f"unet.{name}.to_v_lora.up.weight": state_dict["ip_adapter"][f"{key_id}.to_v_lora.up.weight"]}
913
+ )
914
+ lora_dicts[i].update(
915
+ {
916
+ f"unet.{name}.to_out_lora.up.weight": state_dict["ip_adapter"][
917
+ f"{key_id}.to_out_lora.up.weight"
918
+ ]
919
+ }
920
+ )
921
+ return lora_dicts
diffusers/loaders/unet_loader_utils.py ADDED
@@ -0,0 +1,163 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright 2024 The HuggingFace Team. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ import copy
15
+ from typing import TYPE_CHECKING, Dict, List, Union
16
+
17
+ from ..utils import logging
18
+
19
+
20
+ if TYPE_CHECKING:
21
+ # import here to avoid circular imports
22
+ from ..models import UNet2DConditionModel
23
+
24
+ logger = logging.get_logger(__name__) # pylint: disable=invalid-name
25
+
26
+
27
+ def _translate_into_actual_layer_name(name):
28
+ """Translate user-friendly name (e.g. 'mid') into actual layer name (e.g. 'mid_block.attentions.0')"""
29
+ if name == "mid":
30
+ return "mid_block.attentions.0"
31
+
32
+ updown, block, attn = name.split(".")
33
+
34
+ updown = updown.replace("down", "down_blocks").replace("up", "up_blocks")
35
+ block = block.replace("block_", "")
36
+ attn = "attentions." + attn
37
+
38
+ return ".".join((updown, block, attn))
39
+
40
+
41
+ def _maybe_expand_lora_scales(
42
+ unet: "UNet2DConditionModel", weight_scales: List[Union[float, Dict]], default_scale=1.0
43
+ ):
44
+ blocks_with_transformer = {
45
+ "down": [i for i, block in enumerate(unet.down_blocks) if hasattr(block, "attentions")],
46
+ "up": [i for i, block in enumerate(unet.up_blocks) if hasattr(block, "attentions")],
47
+ }
48
+ transformer_per_block = {"down": unet.config.layers_per_block, "up": unet.config.layers_per_block + 1}
49
+
50
+ expanded_weight_scales = [
51
+ _maybe_expand_lora_scales_for_one_adapter(
52
+ weight_for_adapter,
53
+ blocks_with_transformer,
54
+ transformer_per_block,
55
+ unet.state_dict(),
56
+ default_scale=default_scale,
57
+ )
58
+ for weight_for_adapter in weight_scales
59
+ ]
60
+
61
+ return expanded_weight_scales
62
+
63
+
64
+ def _maybe_expand_lora_scales_for_one_adapter(
65
+ scales: Union[float, Dict],
66
+ blocks_with_transformer: Dict[str, int],
67
+ transformer_per_block: Dict[str, int],
68
+ state_dict: None,
69
+ default_scale: float = 1.0,
70
+ ):
71
+ """
72
+ Expands the inputs into a more granular dictionary. See the example below for more details.
73
+
74
+ Parameters:
75
+ scales (`Union[float, Dict]`):
76
+ Scales dict to expand.
77
+ blocks_with_transformer (`Dict[str, int]`):
78
+ Dict with keys 'up' and 'down', showing which blocks have transformer layers
79
+ transformer_per_block (`Dict[str, int]`):
80
+ Dict with keys 'up' and 'down', showing how many transformer layers each block has
81
+
82
+ E.g. turns
83
+ ```python
84
+ scales = {"down": 2, "mid": 3, "up": {"block_0": 4, "block_1": [5, 6, 7]}}
85
+ blocks_with_transformer = {"down": [1, 2], "up": [0, 1]}
86
+ transformer_per_block = {"down": 2, "up": 3}
87
+ ```
88
+ into
89
+ ```python
90
+ {
91
+ "down.block_1.0": 2,
92
+ "down.block_1.1": 2,
93
+ "down.block_2.0": 2,
94
+ "down.block_2.1": 2,
95
+ "mid": 3,
96
+ "up.block_0.0": 4,
97
+ "up.block_0.1": 4,
98
+ "up.block_0.2": 4,
99
+ "up.block_1.0": 5,
100
+ "up.block_1.1": 6,
101
+ "up.block_1.2": 7,
102
+ }
103
+ ```
104
+ """
105
+ if sorted(blocks_with_transformer.keys()) != ["down", "up"]:
106
+ raise ValueError("blocks_with_transformer needs to be a dict with keys `'down' and `'up'`")
107
+
108
+ if sorted(transformer_per_block.keys()) != ["down", "up"]:
109
+ raise ValueError("transformer_per_block needs to be a dict with keys `'down' and `'up'`")
110
+
111
+ if not isinstance(scales, dict):
112
+ # don't expand if scales is a single number
113
+ return scales
114
+
115
+ scales = copy.deepcopy(scales)
116
+
117
+ if "mid" not in scales:
118
+ scales["mid"] = default_scale
119
+ elif isinstance(scales["mid"], list):
120
+ if len(scales["mid"]) == 1:
121
+ scales["mid"] = scales["mid"][0]
122
+ else:
123
+ raise ValueError(f"Expected 1 scales for mid, got {len(scales['mid'])}.")
124
+
125
+ for updown in ["up", "down"]:
126
+ if updown not in scales:
127
+ scales[updown] = default_scale
128
+
129
+ # eg {"down": 1} to {"down": {"block_1": 1, "block_2": 1}}}
130
+ if not isinstance(scales[updown], dict):
131
+ scales[updown] = {f"block_{i}": copy.deepcopy(scales[updown]) for i in blocks_with_transformer[updown]}
132
+
133
+ # eg {"down": {"block_1": 1}} to {"down": {"block_1": [1, 1]}}
134
+ for i in blocks_with_transformer[updown]:
135
+ block = f"block_{i}"
136
+ # set not assigned blocks to default scale
137
+ if block not in scales[updown]:
138
+ scales[updown][block] = default_scale
139
+ if not isinstance(scales[updown][block], list):
140
+ scales[updown][block] = [scales[updown][block] for _ in range(transformer_per_block[updown])]
141
+ elif len(scales[updown][block]) == 1:
142
+ # a list specifying scale to each masked IP input
143
+ scales[updown][block] = scales[updown][block] * transformer_per_block[updown]
144
+ elif len(scales[updown][block]) != transformer_per_block[updown]:
145
+ raise ValueError(
146
+ f"Expected {transformer_per_block[updown]} scales for {updown}.{block}, got {len(scales[updown][block])}."
147
+ )
148
+
149
+ # eg {"down": "block_1": [1, 1]}} to {"down.block_1.0": 1, "down.block_1.1": 1}
150
+ for i in blocks_with_transformer[updown]:
151
+ block = f"block_{i}"
152
+ for tf_idx, value in enumerate(scales[updown][block]):
153
+ scales[f"{updown}.{block}.{tf_idx}"] = value
154
+
155
+ del scales[updown]
156
+
157
+ for layer in scales.keys():
158
+ if not any(_translate_into_actual_layer_name(layer) in module for module in state_dict.keys()):
159
+ raise ValueError(
160
+ f"Can't set lora scale for layer {layer}. It either doesn't exist in this unet or it has no attentions."
161
+ )
162
+
163
+ return {_translate_into_actual_layer_name(name): weight for name, weight in scales.items()}
diffusers/loaders/utils.py ADDED
@@ -0,0 +1,59 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright 2024 The HuggingFace Team. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ from typing import Dict
16
+
17
+ import torch
18
+
19
+
20
+ class AttnProcsLayers(torch.nn.Module):
21
+ def __init__(self, state_dict: Dict[str, torch.Tensor]):
22
+ super().__init__()
23
+ self.layers = torch.nn.ModuleList(state_dict.values())
24
+ self.mapping = dict(enumerate(state_dict.keys()))
25
+ self.rev_mapping = {v: k for k, v in enumerate(state_dict.keys())}
26
+
27
+ # .processor for unet, .self_attn for text encoder
28
+ self.split_keys = [".processor", ".self_attn"]
29
+
30
+ # we add a hook to state_dict() and load_state_dict() so that the
31
+ # naming fits with `unet.attn_processors`
32
+ def map_to(module, state_dict, *args, **kwargs):
33
+ new_state_dict = {}
34
+ for key, value in state_dict.items():
35
+ num = int(key.split(".")[1]) # 0 is always "layers"
36
+ new_key = key.replace(f"layers.{num}", module.mapping[num])
37
+ new_state_dict[new_key] = value
38
+
39
+ return new_state_dict
40
+
41
+ def remap_key(key, state_dict):
42
+ for k in self.split_keys:
43
+ if k in key:
44
+ return key.split(k)[0] + k
45
+
46
+ raise ValueError(
47
+ f"There seems to be a problem with the state_dict: {set(state_dict.keys())}. {key} has to have one of {self.split_keys}."
48
+ )
49
+
50
+ def map_from(module, state_dict, *args, **kwargs):
51
+ all_keys = list(state_dict.keys())
52
+ for key in all_keys:
53
+ replace_key = remap_key(key, state_dict)
54
+ new_key = key.replace(replace_key, f"layers.{module.rev_mapping[replace_key]}")
55
+ state_dict[new_key] = state_dict[key]
56
+ del state_dict[key]
57
+
58
+ self._register_state_dict_hook(map_to)
59
+ self._register_load_state_dict_pre_hook(map_from, with_module=True)
diffusers/models/README.md ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ # Models
2
+
3
+ For more detail on the models, please refer to the [docs](https://huggingface.co/docs/diffusers/api/models/overview).
diffusers/models/__init__.py ADDED
@@ -0,0 +1,137 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
# Copyright 2024 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import TYPE_CHECKING

from ..utils import (
    DIFFUSERS_SLOW_IMPORT,
    _LazyModule,
    is_flax_available,
    is_torch_available,
)


# Maps a submodule path (relative to this package) to the list of public names
# it provides. Consumed by `_LazyModule` below so that heavy model modules are
# only imported on first attribute access.
_import_structure = {}

# Torch-backed models are only advertised when torch is importable.
if is_torch_available():
    _import_structure["adapter"] = ["MultiAdapter", "T2IAdapter"]
    _import_structure["autoencoders.autoencoder_asym_kl"] = ["AsymmetricAutoencoderKL"]
    _import_structure["autoencoders.autoencoder_kl"] = ["AutoencoderKL"]
    _import_structure["autoencoders.autoencoder_kl_cogvideox"] = ["AutoencoderKLCogVideoX"]
    _import_structure["autoencoders.autoencoder_kl_temporal_decoder"] = ["AutoencoderKLTemporalDecoder"]
    _import_structure["autoencoders.autoencoder_oobleck"] = ["AutoencoderOobleck"]
    _import_structure["autoencoders.autoencoder_tiny"] = ["AutoencoderTiny"]
    _import_structure["autoencoders.consistency_decoder_vae"] = ["ConsistencyDecoderVAE"]
    _import_structure["autoencoders.vq_model"] = ["VQModel"]
    _import_structure["controlnet"] = ["ControlNetModel"]
    _import_structure["controlnet_flux"] = ["FluxControlNetModel", "FluxMultiControlNetModel"]
    _import_structure["controlnet_hunyuan"] = ["HunyuanDiT2DControlNetModel", "HunyuanDiT2DMultiControlNetModel"]
    _import_structure["controlnet_sd3"] = ["SD3ControlNetModel", "SD3MultiControlNetModel"]
    _import_structure["controlnet_sparsectrl"] = ["SparseControlNetModel"]
    _import_structure["controlnet_xs"] = ["ControlNetXSAdapter", "UNetControlNetXSModel"]
    _import_structure["embeddings"] = ["ImageProjection"]
    _import_structure["modeling_utils"] = ["ModelMixin"]
    _import_structure["transformers.auraflow_transformer_2d"] = ["AuraFlowTransformer2DModel"]
    _import_structure["transformers.cogvideox_transformer_3d"] = ["CogVideoXTransformer3DModel"]
    _import_structure["transformers.dit_transformer_2d"] = ["DiTTransformer2DModel"]
    _import_structure["transformers.dual_transformer_2d"] = ["DualTransformer2DModel"]
    _import_structure["transformers.hunyuan_transformer_2d"] = ["HunyuanDiT2DModel"]
    _import_structure["transformers.latte_transformer_3d"] = ["LatteTransformer3DModel"]
    _import_structure["transformers.lumina_nextdit2d"] = ["LuminaNextDiT2DModel"]
    _import_structure["transformers.pixart_transformer_2d"] = ["PixArtTransformer2DModel"]
    _import_structure["transformers.prior_transformer"] = ["PriorTransformer"]
    _import_structure["transformers.stable_audio_transformer"] = ["StableAudioDiTModel"]
    _import_structure["transformers.t5_film_transformer"] = ["T5FilmDecoder"]
    _import_structure["transformers.transformer_2d"] = ["Transformer2DModel"]
    _import_structure["transformers.transformer_flux"] = ["FluxTransformer2DModel"]
    _import_structure["transformers.transformer_sd3"] = ["SD3Transformer2DModel"]
    _import_structure["transformers.transformer_temporal"] = ["TransformerTemporalModel"]
    _import_structure["unets.unet_1d"] = ["UNet1DModel"]
    _import_structure["unets.unet_2d"] = ["UNet2DModel"]
    _import_structure["unets.unet_2d_condition"] = ["UNet2DConditionModel"]
    _import_structure["unets.unet_3d_condition"] = ["UNet3DConditionModel"]
    _import_structure["unets.unet_i2vgen_xl"] = ["I2VGenXLUNet"]
    _import_structure["unets.unet_kandinsky3"] = ["Kandinsky3UNet"]
    _import_structure["unets.unet_motion_model"] = ["MotionAdapter", "UNetMotionModel"]
    _import_structure["unets.unet_spatio_temporal_condition"] = ["UNetSpatioTemporalConditionModel"]
    _import_structure["unets.unet_stable_cascade"] = ["StableCascadeUNet"]
    _import_structure["unets.uvit_2d"] = ["UVit2DModel"]

# Flax-backed models, advertised only when flax/jax is importable.
if is_flax_available():
    _import_structure["controlnet_flax"] = ["FlaxControlNetModel"]
    _import_structure["unets.unet_2d_condition_flax"] = ["FlaxUNet2DConditionModel"]
    _import_structure["vae_flax"] = ["FlaxAutoencoderKL"]


# Eager-import path: used by static type checkers, and at runtime when the
# user opts out of lazy imports via DIFFUSERS_SLOW_IMPORT. The names imported
# here must mirror `_import_structure` above so both paths expose the same API.
if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
    if is_torch_available():
        from .adapter import MultiAdapter, T2IAdapter
        from .autoencoders import (
            AsymmetricAutoencoderKL,
            AutoencoderKL,
            AutoencoderKLCogVideoX,
            AutoencoderKLTemporalDecoder,
            AutoencoderOobleck,
            AutoencoderTiny,
            ConsistencyDecoderVAE,
            VQModel,
        )
        from .controlnet import ControlNetModel
        from .controlnet_flux import FluxControlNetModel, FluxMultiControlNetModel
        from .controlnet_hunyuan import HunyuanDiT2DControlNetModel, HunyuanDiT2DMultiControlNetModel
        from .controlnet_sd3 import SD3ControlNetModel, SD3MultiControlNetModel
        from .controlnet_sparsectrl import SparseControlNetModel
        from .controlnet_xs import ControlNetXSAdapter, UNetControlNetXSModel
        from .embeddings import ImageProjection
        from .modeling_utils import ModelMixin
        from .transformers import (
            AuraFlowTransformer2DModel,
            CogVideoXTransformer3DModel,
            DiTTransformer2DModel,
            DualTransformer2DModel,
            FluxTransformer2DModel,
            HunyuanDiT2DModel,
            LatteTransformer3DModel,
            LuminaNextDiT2DModel,
            PixArtTransformer2DModel,
            PriorTransformer,
            SD3Transformer2DModel,
            StableAudioDiTModel,
            T5FilmDecoder,
            Transformer2DModel,
            TransformerTemporalModel,
        )
        from .unets import (
            I2VGenXLUNet,
            Kandinsky3UNet,
            MotionAdapter,
            StableCascadeUNet,
            UNet1DModel,
            UNet2DConditionModel,
            UNet2DModel,
            UNet3DConditionModel,
            UNetMotionModel,
            UNetSpatioTemporalConditionModel,
            UVit2DModel,
        )

    if is_flax_available():
        from .controlnet_flax import FlaxControlNetModel
        from .unets import FlaxUNet2DConditionModel
        from .vae_flax import FlaxAutoencoderKL

else:
    import sys

    # Default (fast) path: replace this module in sys.modules with a lazy
    # proxy that resolves names from `_import_structure` on demand.
    sys.modules[__name__] = _LazyModule(__name__, globals()["__file__"], _import_structure, module_spec=__spec__)
diffusers/models/__pycache__/__init__.cpython-310.pyc ADDED
Binary file (4.35 kB). View file