ssoxye committed
Commit 313389b · 1 parent: 61345be

Track diffusers3 models/pipelines (fix .gitignore)

This view is limited to 50 files because the commit contains too many changes.
Files changed (50)
  1. .gitignore +6 -0
  2. diffusers3/models/.ipynb_checkpoints/__init__-checkpoint.py +137 -0
  3. diffusers3/models/.ipynb_checkpoints/attention-checkpoint.py +1202 -0
  4. diffusers3/models/.ipynb_checkpoints/controlnet-checkpoint.py +870 -0
  5. diffusers3/models/README.md +3 -0
  6. diffusers3/models/__init__.py +137 -0
  7. diffusers3/models/__pycache__/__init__.cpython-310.pyc +0 -0
  8. diffusers3/models/__pycache__/__init__.cpython-38.pyc +0 -0
  9. diffusers3/models/__pycache__/activations.cpython-310.pyc +0 -0
  10. diffusers3/models/__pycache__/activations.cpython-38.pyc +0 -0
  11. diffusers3/models/__pycache__/attention.cpython-310.pyc +0 -0
  12. diffusers3/models/__pycache__/attention.cpython-38.pyc +0 -0
  13. diffusers3/models/__pycache__/attention_processor.cpython-310.pyc +0 -0
  14. diffusers3/models/__pycache__/attention_processor.cpython-38.pyc +0 -0
  15. diffusers3/models/__pycache__/controlnet.cpython-310.pyc +0 -0
  16. diffusers3/models/__pycache__/controlnet.cpython-38.pyc +0 -0
  17. diffusers3/models/__pycache__/downsampling.cpython-310.pyc +0 -0
  18. diffusers3/models/__pycache__/downsampling.cpython-38.pyc +0 -0
  19. diffusers3/models/__pycache__/embeddings.cpython-310.pyc +0 -0
  20. diffusers3/models/__pycache__/embeddings.cpython-38.pyc +0 -0
  21. diffusers3/models/__pycache__/lora.cpython-310.pyc +0 -0
  22. diffusers3/models/__pycache__/lora.cpython-38.pyc +0 -0
  23. diffusers3/models/__pycache__/model_loading_utils.cpython-310.pyc +0 -0
  24. diffusers3/models/__pycache__/model_loading_utils.cpython-38.pyc +0 -0
  25. diffusers3/models/__pycache__/modeling_outputs.cpython-310.pyc +0 -0
  26. diffusers3/models/__pycache__/modeling_outputs.cpython-38.pyc +0 -0
  27. diffusers3/models/__pycache__/modeling_utils.cpython-310.pyc +0 -0
  28. diffusers3/models/__pycache__/modeling_utils.cpython-38.pyc +0 -0
  29. diffusers3/models/__pycache__/normalization.cpython-310.pyc +0 -0
  30. diffusers3/models/__pycache__/normalization.cpython-38.pyc +0 -0
  31. diffusers3/models/__pycache__/resnet.cpython-310.pyc +0 -0
  32. diffusers3/models/__pycache__/resnet.cpython-38.pyc +0 -0
  33. diffusers3/models/__pycache__/upsampling.cpython-310.pyc +0 -0
  34. diffusers3/models/__pycache__/upsampling.cpython-38.pyc +0 -0
  35. diffusers3/models/activations.py +165 -0
  36. diffusers3/models/adapter.py +584 -0
  37. diffusers3/models/attention.py +1202 -0
  38. diffusers3/models/attention_flax.py +494 -0
  39. diffusers3/models/attention_processor.py +0 -0
  40. diffusers3/models/autoencoders/__init__.py +8 -0
  41. diffusers3/models/autoencoders/__pycache__/__init__.cpython-310.pyc +0 -0
  42. diffusers3/models/autoencoders/__pycache__/__init__.cpython-38.pyc +0 -0
  43. diffusers3/models/autoencoders/__pycache__/autoencoder_asym_kl.cpython-310.pyc +0 -0
  44. diffusers3/models/autoencoders/__pycache__/autoencoder_asym_kl.cpython-38.pyc +0 -0
  45. diffusers3/models/autoencoders/__pycache__/autoencoder_kl.cpython-310.pyc +0 -0
  46. diffusers3/models/autoencoders/__pycache__/autoencoder_kl.cpython-38.pyc +0 -0
  47. diffusers3/models/autoencoders/__pycache__/autoencoder_kl_cogvideox.cpython-310.pyc +0 -0
  48. diffusers3/models/autoencoders/__pycache__/autoencoder_kl_cogvideox.cpython-38.pyc +0 -0
  49. diffusers3/models/autoencoders/__pycache__/autoencoder_kl_temporal_decoder.cpython-310.pyc +0 -0
  50. diffusers3/models/autoencoders/__pycache__/autoencoder_kl_temporal_decoder.cpython-38.pyc +0 -0
.gitignore CHANGED
@@ -8,3 +8,9 @@ preprocess/mhp_extension/demo/*.jpg
 preprocess/demo/*.jpg
 preprocess/mhp_extension/demo/*.jpg
 **/.ipynb_checkpoints/
+
+# allow python package path (was unintentionally ignored by "models/")
+!diffusers3/models/
+!diffusers3/models/**
+!diffusers3/pipelines/
+!diffusers3/pipelines/**
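
Note: gitignore negation cannot re-include a file whose parent directory is still excluded, which is why the directories themselves (`!diffusers3/models/`, `!diffusers3/pipelines/`) are re-included before their contents (`!diffusers3/models/**`, `!diffusers3/pipelines/**`). The fix can be verified with `git check-ignore -v diffusers3/models/__init__.py`, which prints the ignore rule that matches the path, if any.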
diffusers3/models/.ipynb_checkpoints/__init__-checkpoint.py ADDED
@@ -0,0 +1,137 @@
+# Copyright 2024 The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from typing import TYPE_CHECKING
+
+from ..utils import (
+    DIFFUSERS_SLOW_IMPORT,
+    _LazyModule,
+    is_flax_available,
+    is_torch_available,
+)
+
+
+_import_structure = {}
+
+if is_torch_available():
+    _import_structure["adapter"] = ["MultiAdapter", "T2IAdapter"]
+    _import_structure["autoencoders.autoencoder_asym_kl"] = ["AsymmetricAutoencoderKL"]
+    _import_structure["autoencoders.autoencoder_kl"] = ["AutoencoderKL"]
+    _import_structure["autoencoders.autoencoder_kl_cogvideox"] = ["AutoencoderKLCogVideoX"]
+    _import_structure["autoencoders.autoencoder_kl_temporal_decoder"] = ["AutoencoderKLTemporalDecoder"]
+    _import_structure["autoencoders.autoencoder_oobleck"] = ["AutoencoderOobleck"]
+    _import_structure["autoencoders.autoencoder_tiny"] = ["AutoencoderTiny"]
+    _import_structure["autoencoders.consistency_decoder_vae"] = ["ConsistencyDecoderVAE"]
+    _import_structure["autoencoders.vq_model"] = ["VQModel"]
+    _import_structure["controlnet"] = ["ControlNetModel"]
+    _import_structure["controlnet_flux"] = ["FluxControlNetModel", "FluxMultiControlNetModel"]
+    _import_structure["controlnet_hunyuan"] = ["HunyuanDiT2DControlNetModel", "HunyuanDiT2DMultiControlNetModel"]
+    _import_structure["controlnet_sd3"] = ["SD3ControlNetModel", "SD3MultiControlNetModel"]
+    _import_structure["controlnet_sparsectrl"] = ["SparseControlNetModel"]
+    _import_structure["controlnet_xs"] = ["ControlNetXSAdapter", "UNetControlNetXSModel"]
+    _import_structure["embeddings"] = ["ImageProjection"]
+    _import_structure["modeling_utils"] = ["ModelMixin"]
+    _import_structure["transformers.auraflow_transformer_2d"] = ["AuraFlowTransformer2DModel"]
+    _import_structure["transformers.cogvideox_transformer_3d"] = ["CogVideoXTransformer3DModel"]
+    _import_structure["transformers.dit_transformer_2d"] = ["DiTTransformer2DModel"]
+    _import_structure["transformers.dual_transformer_2d"] = ["DualTransformer2DModel"]
+    _import_structure["transformers.hunyuan_transformer_2d"] = ["HunyuanDiT2DModel"]
+    _import_structure["transformers.latte_transformer_3d"] = ["LatteTransformer3DModel"]
+    _import_structure["transformers.lumina_nextdit2d"] = ["LuminaNextDiT2DModel"]
+    _import_structure["transformers.pixart_transformer_2d"] = ["PixArtTransformer2DModel"]
+    _import_structure["transformers.prior_transformer"] = ["PriorTransformer"]
+    _import_structure["transformers.stable_audio_transformer"] = ["StableAudioDiTModel"]
+    _import_structure["transformers.t5_film_transformer"] = ["T5FilmDecoder"]
+    _import_structure["transformers.transformer_2d"] = ["Transformer2DModel"]
+    _import_structure["transformers.transformer_flux"] = ["FluxTransformer2DModel"]
+    _import_structure["transformers.transformer_sd3"] = ["SD3Transformer2DModel"]
+    _import_structure["transformers.transformer_temporal"] = ["TransformerTemporalModel"]
+    _import_structure["unets.unet_1d"] = ["UNet1DModel"]
+    _import_structure["unets.unet_2d"] = ["UNet2DModel"]
+    _import_structure["unets.unet_2d_condition"] = ["UNet2DConditionModel"]
+    _import_structure["unets.unet_3d_condition"] = ["UNet3DConditionModel"]
+    _import_structure["unets.unet_i2vgen_xl"] = ["I2VGenXLUNet"]
+    _import_structure["unets.unet_kandinsky3"] = ["Kandinsky3UNet"]
+    _import_structure["unets.unet_motion_model"] = ["MotionAdapter", "UNetMotionModel"]
+    _import_structure["unets.unet_spatio_temporal_condition"] = ["UNetSpatioTemporalConditionModel"]
+    _import_structure["unets.unet_stable_cascade"] = ["StableCascadeUNet"]
+    _import_structure["unets.uvit_2d"] = ["UVit2DModel"]
+
+if is_flax_available():
+    _import_structure["controlnet_flax"] = ["FlaxControlNetModel"]
+    _import_structure["unets.unet_2d_condition_flax"] = ["FlaxUNet2DConditionModel"]
+    _import_structure["vae_flax"] = ["FlaxAutoencoderKL"]
+
+
+if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
+    if is_torch_available():
+        from .adapter import MultiAdapter, T2IAdapter
+        from .autoencoders import (
+            AsymmetricAutoencoderKL,
+            AutoencoderKL,
+            AutoencoderKLCogVideoX,
+            AutoencoderKLTemporalDecoder,
+            AutoencoderOobleck,
+            AutoencoderTiny,
+            ConsistencyDecoderVAE,
+            VQModel,
+        )
+        from .controlnet import ControlNetModel
+        from .controlnet_flux import FluxControlNetModel, FluxMultiControlNetModel
+        from .controlnet_hunyuan import HunyuanDiT2DControlNetModel, HunyuanDiT2DMultiControlNetModel
+        from .controlnet_sd3 import SD3ControlNetModel, SD3MultiControlNetModel
+        from .controlnet_sparsectrl import SparseControlNetModel
+        from .controlnet_xs import ControlNetXSAdapter, UNetControlNetXSModel
+        from .embeddings import ImageProjection
+        from .modeling_utils import ModelMixin
+        from .transformers import (
+            AuraFlowTransformer2DModel,
+            CogVideoXTransformer3DModel,
+            DiTTransformer2DModel,
+            DualTransformer2DModel,
+            FluxTransformer2DModel,
+            HunyuanDiT2DModel,
+            LatteTransformer3DModel,
+            LuminaNextDiT2DModel,
+            PixArtTransformer2DModel,
+            PriorTransformer,
+            SD3Transformer2DModel,
+            StableAudioDiTModel,
+            T5FilmDecoder,
+            Transformer2DModel,
+            TransformerTemporalModel,
+        )
+        from .unets import (
+            I2VGenXLUNet,
+            Kandinsky3UNet,
+            MotionAdapter,
+            StableCascadeUNet,
+            UNet1DModel,
+            UNet2DConditionModel,
+            UNet2DModel,
+            UNet3DConditionModel,
+            UNetMotionModel,
+            UNetSpatioTemporalConditionModel,
+            UVit2DModel,
+        )
+
+    if is_flax_available():
+        from .controlnet_flax import FlaxControlNetModel
+        from .unets import FlaxUNet2DConditionModel
+        from .vae_flax import FlaxAutoencoderKL
+
+else:
+    import sys
+
+    sys.modules[__name__] = _LazyModule(__name__, globals()["__file__"], _import_structure, module_spec=__spec__)
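
Note: this `__init__.py` follows the standard diffusers lazy-import pattern: unless `TYPE_CHECKING` or `DIFFUSERS_SLOW_IMPORT` is set, `sys.modules[__name__]` is replaced with a `_LazyModule` driven by `_import_structure`, so a statement such as `from diffusers3.models import AutoencoderKL` only imports the `autoencoders.autoencoder_kl` submodule on first attribute access. The eager branch exists so static type checkers still see the real imports.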
diffusers3/models/.ipynb_checkpoints/attention-checkpoint.py ADDED
@@ -0,0 +1,1202 @@
1
+ # Copyright 2024 The HuggingFace Team. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ from typing import Any, Dict, List, Optional, Tuple
15
+
16
+ import torch
17
+ import torch.nn.functional as F
18
+ from torch import nn
19
+
20
+ from ..utils import deprecate, logging
21
+ from ..utils.torch_utils import maybe_allow_in_graph
22
+ from .activations import GEGLU, GELU, ApproximateGELU, FP32SiLU, SwiGLU
23
+ from .attention_processor import Attention, JointAttnProcessor2_0
24
+ from .embeddings import SinusoidalPositionalEmbedding
25
+ from .normalization import AdaLayerNorm, AdaLayerNormContinuous, AdaLayerNormZero, RMSNorm
26
+
27
+
28
+ logger = logging.get_logger(__name__)
29
+
30
+
31
+ def _chunked_feed_forward(ff: nn.Module, hidden_states: torch.Tensor, chunk_dim: int, chunk_size: int):
32
+ # "feed_forward_chunk_size" can be used to save memory
33
+ if hidden_states.shape[chunk_dim] % chunk_size != 0:
34
+ raise ValueError(
35
+ f"`hidden_states` dimension to be chunked: {hidden_states.shape[chunk_dim]} has to be divisible by chunk size: {chunk_size}. Make sure to set an appropriate `chunk_size` when calling `unet.enable_forward_chunking`."
36
+ )
37
+
38
+ num_chunks = hidden_states.shape[chunk_dim] // chunk_size
39
+ ff_output = torch.cat(
40
+ [ff(hid_slice) for hid_slice in hidden_states.chunk(num_chunks, dim=chunk_dim)],
41
+ dim=chunk_dim,
42
+ )
43
+ return ff_output
44
+
45
+
46
+ @maybe_allow_in_graph
47
+ class GatedSelfAttentionDense(nn.Module):
48
+ r"""
49
+ A gated self-attention dense layer that combines visual features and object features.
50
+
51
+ Parameters:
52
+ query_dim (`int`): The number of channels in the query.
53
+ context_dim (`int`): The number of channels in the context.
54
+ n_heads (`int`): The number of heads to use for attention.
55
+ d_head (`int`): The number of channels in each head.
56
+ """
57
+
58
+ def __init__(self, query_dim: int, context_dim: int, n_heads: int, d_head: int):
59
+ super().__init__()
60
+
61
+ # we need a linear projection since we need cat visual feature and obj feature
62
+ self.linear = nn.Linear(context_dim, query_dim)
63
+
64
+ self.attn = Attention(query_dim=query_dim, heads=n_heads, dim_head=d_head)
65
+ self.ff = FeedForward(query_dim, activation_fn="geglu")
66
+
67
+ self.norm1 = nn.LayerNorm(query_dim)
68
+ self.norm2 = nn.LayerNorm(query_dim)
69
+
70
+ self.register_parameter("alpha_attn", nn.Parameter(torch.tensor(0.0)))
71
+ self.register_parameter("alpha_dense", nn.Parameter(torch.tensor(0.0)))
72
+
73
+ self.enabled = True
74
+
75
+ def forward(self, x: torch.Tensor, objs: torch.Tensor) -> torch.Tensor:
76
+ if not self.enabled:
77
+ return x
78
+
79
+ n_visual = x.shape[1]
80
+ objs = self.linear(objs)
81
+
82
+ x = x + self.alpha_attn.tanh() * self.attn(self.norm1(torch.cat([x, objs], dim=1)))[:, :n_visual, :]
83
+ x = x + self.alpha_dense.tanh() * self.ff(self.norm2(x))
84
+
85
+ return x
86
+
87
+
88
+ @maybe_allow_in_graph
89
+ class JointTransformerBlock(nn.Module):
90
+ r"""
91
+ A Transformer block following the MMDiT architecture, introduced in Stable Diffusion 3.
92
+
93
+ Reference: https://arxiv.org/abs/2403.03206
94
+
95
+ Parameters:
96
+ dim (`int`): The number of channels in the input and output.
97
+ num_attention_heads (`int`): The number of heads to use for multi-head attention.
98
+ attention_head_dim (`int`): The number of channels in each head.
99
+ context_pre_only (`bool`): Boolean to determine if we should add some blocks associated with the
100
+ processing of `context` conditions.
101
+ """
102
+
103
+ def __init__(self, dim, num_attention_heads, attention_head_dim, context_pre_only=False):
104
+ super().__init__()
105
+
106
+ self.context_pre_only = context_pre_only
107
+ context_norm_type = "ada_norm_continous" if context_pre_only else "ada_norm_zero"
108
+
109
+ self.norm1 = AdaLayerNormZero(dim)
110
+
111
+ if context_norm_type == "ada_norm_continous":
112
+ self.norm1_context = AdaLayerNormContinuous(
113
+ dim, dim, elementwise_affine=False, eps=1e-6, bias=True, norm_type="layer_norm"
114
+ )
115
+ elif context_norm_type == "ada_norm_zero":
116
+ self.norm1_context = AdaLayerNormZero(dim)
117
+ else:
118
+ raise ValueError(
119
+ f"Unknown context_norm_type: {context_norm_type}, currently only support `ada_norm_continous`, `ada_norm_zero`"
120
+ )
121
+ if hasattr(F, "scaled_dot_product_attention"):
122
+ processor = JointAttnProcessor2_0()
123
+ else:
124
+ raise ValueError(
125
+ "The current PyTorch version does not support the `scaled_dot_product_attention` function."
126
+ )
127
+ self.attn = Attention(
128
+ query_dim=dim,
129
+ cross_attention_dim=None,
130
+ added_kv_proj_dim=dim,
131
+ dim_head=attention_head_dim,
132
+ heads=num_attention_heads,
133
+ out_dim=dim,
134
+ context_pre_only=context_pre_only,
135
+ bias=True,
136
+ processor=processor,
137
+ )
138
+
139
+ self.norm2 = nn.LayerNorm(dim, elementwise_affine=False, eps=1e-6)
140
+ self.ff = FeedForward(dim=dim, dim_out=dim, activation_fn="gelu-approximate")
141
+
142
+ if not context_pre_only:
143
+ self.norm2_context = nn.LayerNorm(dim, elementwise_affine=False, eps=1e-6)
144
+ self.ff_context = FeedForward(dim=dim, dim_out=dim, activation_fn="gelu-approximate")
145
+ else:
146
+ self.norm2_context = None
147
+ self.ff_context = None
148
+
149
+ # let chunk size default to None
150
+ self._chunk_size = None
151
+ self._chunk_dim = 0
152
+
153
+ # Copied from diffusers.models.attention.BasicTransformerBlock.set_chunk_feed_forward
154
+ def set_chunk_feed_forward(self, chunk_size: Optional[int], dim: int = 0):
155
+ # Sets chunk feed-forward
156
+ self._chunk_size = chunk_size
157
+ self._chunk_dim = dim
158
+
159
+ def forward(
160
+ self, hidden_states: torch.FloatTensor, encoder_hidden_states: torch.FloatTensor, temb: torch.FloatTensor
161
+ ):
162
+ norm_hidden_states, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.norm1(hidden_states, emb=temb)
163
+
164
+ if self.context_pre_only:
165
+ norm_encoder_hidden_states = self.norm1_context(encoder_hidden_states, temb)
166
+ else:
167
+ norm_encoder_hidden_states, c_gate_msa, c_shift_mlp, c_scale_mlp, c_gate_mlp = self.norm1_context(
168
+ encoder_hidden_states, emb=temb
169
+ )
170
+
171
+ # Attention.
172
+ attn_output, context_attn_output = self.attn(
173
+ hidden_states=norm_hidden_states, encoder_hidden_states=norm_encoder_hidden_states
174
+ )
175
+
176
+ # Process attention outputs for the `hidden_states`.
177
+ attn_output = gate_msa.unsqueeze(1) * attn_output
178
+ hidden_states = hidden_states + attn_output
179
+
180
+ norm_hidden_states = self.norm2(hidden_states)
181
+ norm_hidden_states = norm_hidden_states * (1 + scale_mlp[:, None]) + shift_mlp[:, None]
182
+ if self._chunk_size is not None:
183
+ # "feed_forward_chunk_size" can be used to save memory
184
+ ff_output = _chunked_feed_forward(self.ff, norm_hidden_states, self._chunk_dim, self._chunk_size)
185
+ else:
186
+ ff_output = self.ff(norm_hidden_states)
187
+ ff_output = gate_mlp.unsqueeze(1) * ff_output
188
+
189
+ hidden_states = hidden_states + ff_output
190
+
191
+ # Process attention outputs for the `encoder_hidden_states`.
192
+ if self.context_pre_only:
193
+ encoder_hidden_states = None
194
+ else:
195
+ context_attn_output = c_gate_msa.unsqueeze(1) * context_attn_output
196
+ encoder_hidden_states = encoder_hidden_states + context_attn_output
197
+
198
+ norm_encoder_hidden_states = self.norm2_context(encoder_hidden_states)
199
+ norm_encoder_hidden_states = norm_encoder_hidden_states * (1 + c_scale_mlp[:, None]) + c_shift_mlp[:, None]
200
+ if self._chunk_size is not None:
201
+ # "feed_forward_chunk_size" can be used to save memory
202
+ context_ff_output = _chunked_feed_forward(
203
+ self.ff_context, norm_encoder_hidden_states, self._chunk_dim, self._chunk_size
204
+ )
205
+ else:
206
+ context_ff_output = self.ff_context(norm_encoder_hidden_states)
207
+ encoder_hidden_states = encoder_hidden_states + c_gate_mlp.unsqueeze(1) * context_ff_output
208
+
209
+ return encoder_hidden_states, hidden_states
210
+
211
+
212
+ @maybe_allow_in_graph
213
+ class BasicTransformerBlock(nn.Module):
214
+ r"""
215
+ A basic Transformer block.
216
+
217
+ Parameters:
218
+ dim (`int`): The number of channels in the input and output.
219
+ num_attention_heads (`int`): The number of heads to use for multi-head attention.
220
+ attention_head_dim (`int`): The number of channels in each head.
221
+ dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use.
222
+ cross_attention_dim (`int`, *optional*): The size of the encoder_hidden_states vector for cross attention.
223
+ activation_fn (`str`, *optional*, defaults to `"geglu"`): Activation function to be used in feed-forward.
224
+ num_embeds_ada_norm (:
225
+ obj: `int`, *optional*): The number of diffusion steps used during training. See `Transformer2DModel`.
226
+ attention_bias (:
227
+ obj: `bool`, *optional*, defaults to `False`): Configure if the attentions should contain a bias parameter.
228
+ only_cross_attention (`bool`, *optional*):
229
+ Whether to use only cross-attention layers. In this case two cross attention layers are used.
230
+ double_self_attention (`bool`, *optional*):
231
+ Whether to use two self-attention layers. In this case no cross attention layers are used.
232
+ upcast_attention (`bool`, *optional*):
233
+ Whether to upcast the attention computation to float32. This is useful for mixed precision training.
234
+ norm_elementwise_affine (`bool`, *optional*, defaults to `True`):
235
+ Whether to use learnable elementwise affine parameters for normalization.
236
+ norm_type (`str`, *optional*, defaults to `"layer_norm"`):
237
+ The normalization layer to use. Can be `"layer_norm"`, `"ada_norm"` or `"ada_norm_zero"`.
238
+ final_dropout (`bool` *optional*, defaults to False):
239
+ Whether to apply a final dropout after the last feed-forward layer.
240
+ attention_type (`str`, *optional*, defaults to `"default"`):
241
+ The type of attention to use. Can be `"default"` or `"gated"` or `"gated-text-image"`.
242
+ positional_embeddings (`str`, *optional*, defaults to `None`):
243
+ The type of positional embeddings to apply to.
244
+ num_positional_embeddings (`int`, *optional*, defaults to `None`):
245
+ The maximum number of positional embeddings to apply.
246
+ """
247
+
248
+ def __init__(
249
+ self,
250
+ dim: int,
251
+ num_attention_heads: int,
252
+ attention_head_dim: int,
253
+ dropout=0.0,
254
+ cross_attention_dim: Optional[int] = None,
255
+ activation_fn: str = "geglu",
256
+ num_embeds_ada_norm: Optional[int] = None,
257
+ attention_bias: bool = False,
258
+ only_cross_attention: bool = False,
259
+ double_self_attention: bool = False,
260
+ upcast_attention: bool = False,
261
+ norm_elementwise_affine: bool = True,
262
+ norm_type: str = "layer_norm", # 'layer_norm', 'ada_norm', 'ada_norm_zero', 'ada_norm_single', 'ada_norm_continuous', 'layer_norm_i2vgen'
263
+ norm_eps: float = 1e-5,
264
+ final_dropout: bool = False,
265
+ attention_type: str = "default",
266
+ positional_embeddings: Optional[str] = None,
267
+ num_positional_embeddings: Optional[int] = None,
268
+ ada_norm_continous_conditioning_embedding_dim: Optional[int] = None,
269
+ ada_norm_bias: Optional[int] = None,
270
+ ff_inner_dim: Optional[int] = None,
271
+ ff_bias: bool = True,
272
+ attention_out_bias: bool = True,
273
+ ):
274
+ super().__init__()
275
+ self.dim = dim
276
+ self.num_attention_heads = num_attention_heads
277
+ self.attention_head_dim = attention_head_dim
278
+ self.dropout = dropout
279
+ self.cross_attention_dim = cross_attention_dim
280
+ self.activation_fn = activation_fn
281
+ self.attention_bias = attention_bias
282
+ self.double_self_attention = double_self_attention
283
+ self.norm_elementwise_affine = norm_elementwise_affine
284
+ self.positional_embeddings = positional_embeddings
285
+ self.num_positional_embeddings = num_positional_embeddings
286
+ self.only_cross_attention = only_cross_attention
287
+
288
+ # We keep these boolean flags for backward-compatibility.
289
+ self.use_ada_layer_norm_zero = (num_embeds_ada_norm is not None) and norm_type == "ada_norm_zero"
290
+ self.use_ada_layer_norm = (num_embeds_ada_norm is not None) and norm_type == "ada_norm"
291
+ self.use_ada_layer_norm_single = norm_type == "ada_norm_single"
292
+ self.use_layer_norm = norm_type == "layer_norm"
293
+ self.use_ada_layer_norm_continuous = norm_type == "ada_norm_continuous"
294
+
295
+ if norm_type in ("ada_norm", "ada_norm_zero") and num_embeds_ada_norm is None:
296
+ raise ValueError(
297
+ f"`norm_type` is set to {norm_type}, but `num_embeds_ada_norm` is not defined. Please make sure to"
298
+ f" define `num_embeds_ada_norm` if setting `norm_type` to {norm_type}."
299
+ )
300
+
301
+ self.norm_type = norm_type
302
+ self.num_embeds_ada_norm = num_embeds_ada_norm
303
+
304
+ if positional_embeddings and (num_positional_embeddings is None):
305
+ raise ValueError(
306
+ "If `positional_embedding` type is defined, `num_positition_embeddings` must also be defined."
307
+ )
308
+
309
+ if positional_embeddings == "sinusoidal":
310
+ self.pos_embed = SinusoidalPositionalEmbedding(dim, max_seq_length=num_positional_embeddings)
311
+ else:
312
+ self.pos_embed = None
313
+
314
+ # Define 3 blocks. Each block has its own normalization layer.
315
+ # 1. Self-Attn
316
+ if norm_type == "ada_norm":
317
+ self.norm1 = AdaLayerNorm(dim, num_embeds_ada_norm)
318
+ elif norm_type == "ada_norm_zero":
319
+ self.norm1 = AdaLayerNormZero(dim, num_embeds_ada_norm)
320
+ elif norm_type == "ada_norm_continuous":
321
+ self.norm1 = AdaLayerNormContinuous(
322
+ dim,
323
+ ada_norm_continous_conditioning_embedding_dim,
324
+ norm_elementwise_affine,
325
+ norm_eps,
326
+ ada_norm_bias,
327
+ "rms_norm",
328
+ )
329
+ else:
330
+ self.norm1 = nn.LayerNorm(dim, elementwise_affine=norm_elementwise_affine, eps=norm_eps)
331
+
332
+ self.attn1 = Attention(
333
+ query_dim=dim,
334
+ heads=num_attention_heads,
335
+ dim_head=attention_head_dim,
336
+ dropout=dropout,
337
+ bias=attention_bias,
338
+ cross_attention_dim=cross_attention_dim if only_cross_attention else None,
339
+ upcast_attention=upcast_attention,
340
+ out_bias=attention_out_bias,
341
+ )
342
+
343
+ # 2. Cross-Attn
344
+ if cross_attention_dim is not None or double_self_attention:
345
+ # We currently only use AdaLayerNormZero for self attention where there will only be one attention block.
346
+ # I.e. the number of returned modulation chunks from AdaLayerZero would not make sense if returned during
347
+ # the second cross attention block.
348
+ if norm_type == "ada_norm":
349
+ self.norm2 = AdaLayerNorm(dim, num_embeds_ada_norm)
350
+ elif norm_type == "ada_norm_continuous":
351
+ self.norm2 = AdaLayerNormContinuous(
352
+ dim,
353
+ ada_norm_continous_conditioning_embedding_dim,
354
+ norm_elementwise_affine,
355
+ norm_eps,
356
+ ada_norm_bias,
357
+ "rms_norm",
358
+ )
359
+ else:
360
+ self.norm2 = nn.LayerNorm(dim, norm_eps, norm_elementwise_affine)
361
+
362
+ self.attn2 = Attention(
363
+ query_dim=dim,
364
+ cross_attention_dim=cross_attention_dim if not double_self_attention else None,
365
+ heads=num_attention_heads,
366
+ dim_head=attention_head_dim,
367
+ dropout=dropout,
368
+ bias=attention_bias,
369
+ upcast_attention=upcast_attention,
370
+ out_bias=attention_out_bias,
371
+ ) # is self-attn if encoder_hidden_states is none
372
+ else:
373
+ if norm_type == "ada_norm_single": # For Latte
374
+ self.norm2 = nn.LayerNorm(dim, norm_eps, norm_elementwise_affine)
375
+ else:
376
+ self.norm2 = None
377
+ self.attn2 = None
378
+
379
+ # 3. Feed-forward
380
+ if norm_type == "ada_norm_continuous":
381
+ self.norm3 = AdaLayerNormContinuous(
382
+ dim,
383
+ ada_norm_continous_conditioning_embedding_dim,
384
+ norm_elementwise_affine,
385
+ norm_eps,
386
+ ada_norm_bias,
387
+ "layer_norm",
388
+ )
389
+
390
+ elif norm_type in ["ada_norm_zero", "ada_norm", "layer_norm"]:
391
+ self.norm3 = nn.LayerNorm(dim, norm_eps, norm_elementwise_affine)
392
+ elif norm_type == "layer_norm_i2vgen":
393
+ self.norm3 = None
394
+
395
+ self.ff = FeedForward(
396
+ dim,
397
+ dropout=dropout,
398
+ activation_fn=activation_fn,
399
+ final_dropout=final_dropout,
400
+ inner_dim=ff_inner_dim,
401
+ bias=ff_bias,
402
+ )
403
+
404
+ # 4. Fuser
405
+ if attention_type == "gated" or attention_type == "gated-text-image":
406
+ self.fuser = GatedSelfAttentionDense(dim, cross_attention_dim, num_attention_heads, attention_head_dim)
407
+
408
+ # 5. Scale-shift for PixArt-Alpha.
409
+ if norm_type == "ada_norm_single":
410
+ self.scale_shift_table = nn.Parameter(torch.randn(6, dim) / dim**0.5)
411
+
412
+ # let chunk size default to None
413
+ self._chunk_size = None
414
+ self._chunk_dim = 0
415
+
416
+ def set_chunk_feed_forward(self, chunk_size: Optional[int], dim: int = 0):
417
+ # Sets chunk feed-forward
418
+ self._chunk_size = chunk_size
419
+ self._chunk_dim = dim
420
+
421
+ def forward(
422
+ self,
423
+ hidden_states: torch.Tensor,
424
+ attention_mask: Optional[torch.Tensor] = None,
425
+ encoder_hidden_states: Optional[torch.Tensor] = None,
426
+ encoder_attention_mask: Optional[torch.Tensor] = None,
427
+ timestep: Optional[torch.LongTensor] = None,
428
+ cross_attention_kwargs: Dict[str, Any] = None,
429
+ class_labels: Optional[torch.LongTensor] = None,
430
+ added_cond_kwargs: Optional[Dict[str, torch.Tensor]] = None,
431
+ ) -> torch.Tensor:
432
+ if cross_attention_kwargs is not None:
433
+ if cross_attention_kwargs.get("scale", None) is not None:
434
+ logger.warning("Passing `scale` to `cross_attention_kwargs` is deprecated. `scale` will be ignored.")
435
+
436
+ # Notice that normalization is always applied before the real computation in the following blocks.
437
+ # 0. Self-Attention
438
+ batch_size = hidden_states.shape[0]
439
+
440
+ if self.norm_type == "ada_norm":
441
+ norm_hidden_states = self.norm1(hidden_states, timestep)
442
+ elif self.norm_type == "ada_norm_zero":
443
+ norm_hidden_states, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.norm1(
444
+ hidden_states, timestep, class_labels, hidden_dtype=hidden_states.dtype
445
+ )
446
+ elif self.norm_type in ["layer_norm", "layer_norm_i2vgen"]:
447
+ norm_hidden_states = self.norm1(hidden_states)
448
+ elif self.norm_type == "ada_norm_continuous":
449
+ norm_hidden_states = self.norm1(hidden_states, added_cond_kwargs["pooled_text_emb"])
450
+ elif self.norm_type == "ada_norm_single":
451
+ shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = (
452
+ self.scale_shift_table[None] + timestep.reshape(batch_size, 6, -1)
453
+ ).chunk(6, dim=1)
454
+ norm_hidden_states = self.norm1(hidden_states)
455
+ norm_hidden_states = norm_hidden_states * (1 + scale_msa) + shift_msa
456
+ else:
457
+ raise ValueError("Incorrect norm used")
458
+
459
+ if self.pos_embed is not None:
460
+ norm_hidden_states = self.pos_embed(norm_hidden_states)
461
+
462
+ # 1. Prepare GLIGEN inputs
463
+ cross_attention_kwargs = cross_attention_kwargs.copy() if cross_attention_kwargs is not None else {}
464
+ gligen_kwargs = cross_attention_kwargs.pop("gligen", None)
465
+
466
+ attn_output = self.attn1(
467
+ norm_hidden_states,
468
+ encoder_hidden_states=encoder_hidden_states if self.only_cross_attention else None,
469
+ attention_mask=attention_mask,
470
+ **cross_attention_kwargs,
471
+ )
472
+
473
+ if self.norm_type == "ada_norm_zero":
474
+ attn_output = gate_msa.unsqueeze(1) * attn_output
475
+ elif self.norm_type == "ada_norm_single":
476
+ attn_output = gate_msa * attn_output
477
+
478
+ hidden_states = attn_output + hidden_states
479
+ if hidden_states.ndim == 4:
480
+ hidden_states = hidden_states.squeeze(1)
481
+
482
+ # 1.2 GLIGEN Control
483
+ if gligen_kwargs is not None:
484
+ hidden_states = self.fuser(hidden_states, gligen_kwargs["objs"])
485
+
486
+ # 3. Cross-Attention
487
+ if self.attn2 is not None:
488
+ if self.norm_type == "ada_norm":
489
+ norm_hidden_states = self.norm2(hidden_states, timestep)
490
+ elif self.norm_type in ["ada_norm_zero", "layer_norm", "layer_norm_i2vgen"]:
491
+ norm_hidden_states = self.norm2(hidden_states)
492
+ elif self.norm_type == "ada_norm_single":
493
+ # For PixArt norm2 isn't applied here:
494
+ # https://github.com/PixArt-alpha/PixArt-alpha/blob/0f55e922376d8b797edd44d25d0e7464b260dcab/diffusion/model/nets/PixArtMS.py#L70C1-L76C103
495
+ norm_hidden_states = hidden_states
496
+ elif self.norm_type == "ada_norm_continuous":
497
+ norm_hidden_states = self.norm2(hidden_states, added_cond_kwargs["pooled_text_emb"])
498
+ else:
499
+ raise ValueError("Incorrect norm")
500
+
501
+ if self.pos_embed is not None and self.norm_type != "ada_norm_single":
502
+ norm_hidden_states = self.pos_embed(norm_hidden_states)
503
+
504
+ attn_output = self.attn2(
505
+ norm_hidden_states,
506
+ encoder_hidden_states=encoder_hidden_states,
507
+ attention_mask=encoder_attention_mask,
508
+ **cross_attention_kwargs,
509
+ )
510
+ hidden_states = attn_output + hidden_states
511
+
512
+ # 4. Feed-forward
513
+ # i2vgen doesn't have this norm 🤷‍♂️
514
+ if self.norm_type == "ada_norm_continuous":
515
+ norm_hidden_states = self.norm3(hidden_states, added_cond_kwargs["pooled_text_emb"])
516
+ elif not self.norm_type == "ada_norm_single":
517
+ norm_hidden_states = self.norm3(hidden_states)
518
+
519
+ if self.norm_type == "ada_norm_zero":
520
+ norm_hidden_states = norm_hidden_states * (1 + scale_mlp[:, None]) + shift_mlp[:, None]
521
+
522
+ if self.norm_type == "ada_norm_single":
523
+ norm_hidden_states = self.norm2(hidden_states)
524
+ norm_hidden_states = norm_hidden_states * (1 + scale_mlp) + shift_mlp
525
+
526
+ if self._chunk_size is not None:
527
+ # "feed_forward_chunk_size" can be used to save memory
528
+ ff_output = _chunked_feed_forward(self.ff, norm_hidden_states, self._chunk_dim, self._chunk_size)
529
+ else:
530
+ ff_output = self.ff(norm_hidden_states)
531
+
532
+ if self.norm_type == "ada_norm_zero":
533
+ ff_output = gate_mlp.unsqueeze(1) * ff_output
534
+ elif self.norm_type == "ada_norm_single":
535
+ ff_output = gate_mlp * ff_output
536
+
537
+ hidden_states = ff_output + hidden_states
538
+ if hidden_states.ndim == 4:
539
+ hidden_states = hidden_states.squeeze(1)
540
+
541
+ return hidden_states
542
+
543
+
544
+ class LuminaFeedForward(nn.Module):
545
+ r"""
546
+ A feed-forward layer.
547
+
548
+ Parameters:
549
+ hidden_size (`int`):
550
+ The dimensionality of the hidden layers in the model. This parameter determines the width of the model's
551
+ hidden representations.
552
+ intermediate_size (`int`): The intermediate dimension of the feedforward layer.
553
+ multiple_of (`int`, *optional*): Value to ensure hidden dimension is a multiple
554
+ of this value.
555
+ ffn_dim_multiplier (float, *optional*): Custom multiplier for hidden
556
+ dimension. Defaults to None.
557
+ """
558
+
559
+ def __init__(
560
+ self,
561
+ dim: int,
562
+ inner_dim: int,
563
+ multiple_of: Optional[int] = 256,
564
+ ffn_dim_multiplier: Optional[float] = None,
565
+ ):
566
+ super().__init__()
567
+ inner_dim = int(2 * inner_dim / 3)
568
+ # custom hidden_size factor multiplier
569
+ if ffn_dim_multiplier is not None:
570
+ inner_dim = int(ffn_dim_multiplier * inner_dim)
571
+ inner_dim = multiple_of * ((inner_dim + multiple_of - 1) // multiple_of)
572
+
573
+ self.linear_1 = nn.Linear(
574
+ dim,
575
+ inner_dim,
576
+ bias=False,
577
+ )
578
+ self.linear_2 = nn.Linear(
579
+ inner_dim,
580
+ dim,
581
+ bias=False,
582
+ )
583
+ self.linear_3 = nn.Linear(
584
+ dim,
585
+ inner_dim,
586
+ bias=False,
587
+ )
588
+ self.silu = FP32SiLU()
589
+
590
+ def forward(self, x):
591
+ return self.linear_2(self.silu(self.linear_1(x)) * self.linear_3(x))
592
+
593
+
594
+ @maybe_allow_in_graph
595
+ class TemporalBasicTransformerBlock(nn.Module):
596
+ r"""
597
+ A basic Transformer block for video like data.
598
+
599
+ Parameters:
600
+ dim (`int`): The number of channels in the input and output.
601
+ time_mix_inner_dim (`int`): The number of channels for temporal attention.
602
+ num_attention_heads (`int`): The number of heads to use for multi-head attention.
603
+ attention_head_dim (`int`): The number of channels in each head.
604
+ cross_attention_dim (`int`, *optional*): The size of the encoder_hidden_states vector for cross attention.
605
+ """
606
+
607
+ def __init__(
608
+ self,
609
+ dim: int,
610
+ time_mix_inner_dim: int,
611
+ num_attention_heads: int,
612
+ attention_head_dim: int,
613
+ cross_attention_dim: Optional[int] = None,
614
+ ):
615
+ super().__init__()
616
+ self.is_res = dim == time_mix_inner_dim
617
+
618
+ self.norm_in = nn.LayerNorm(dim)
619
+
620
+ # Define 3 blocks. Each block has its own normalization layer.
621
+ # 1. Self-Attn
622
+ self.ff_in = FeedForward(
623
+ dim,
624
+ dim_out=time_mix_inner_dim,
625
+ activation_fn="geglu",
626
+ )
627
+
628
+ self.norm1 = nn.LayerNorm(time_mix_inner_dim)
629
+ self.attn1 = Attention(
630
+ query_dim=time_mix_inner_dim,
631
+ heads=num_attention_heads,
632
+ dim_head=attention_head_dim,
633
+ cross_attention_dim=None,
634
+ )
635
+
636
+ # 2. Cross-Attn
637
+ if cross_attention_dim is not None:
638
+ # We currently only use AdaLayerNormZero for self attention where there will only be one attention block.
639
+ # I.e. the number of returned modulation chunks from AdaLayerZero would not make sense if returned during
640
+ # the second cross attention block.
641
+ self.norm2 = nn.LayerNorm(time_mix_inner_dim)
642
+ self.attn2 = Attention(
643
+ query_dim=time_mix_inner_dim,
644
+ cross_attention_dim=cross_attention_dim,
645
+ heads=num_attention_heads,
646
+ dim_head=attention_head_dim,
647
+ ) # is self-attn if encoder_hidden_states is none
648
+ else:
649
+ self.norm2 = None
650
+ self.attn2 = None
651
+
652
+ # 3. Feed-forward
653
+ self.norm3 = nn.LayerNorm(time_mix_inner_dim)
654
+ self.ff = FeedForward(time_mix_inner_dim, activation_fn="geglu")
655
+
656
+ # let chunk size default to None
657
+ self._chunk_size = None
658
+ self._chunk_dim = None
659
+
660
+ def set_chunk_feed_forward(self, chunk_size: Optional[int], **kwargs):
661
+ # Sets chunk feed-forward
662
+ self._chunk_size = chunk_size
663
+ # chunk dim should be hardcoded to 1 to have better speed vs. memory trade-off
664
+ self._chunk_dim = 1
665
+
666
+ def forward(
667
+ self,
668
+ hidden_states: torch.Tensor,
669
+ num_frames: int,
670
+ encoder_hidden_states: Optional[torch.Tensor] = None,
671
+ ) -> torch.Tensor:
672
+ # Notice that normalization is always applied before the real computation in the following blocks.
673
+ # 0. Self-Attention
674
+ batch_size = hidden_states.shape[0]
675
+
676
+ batch_frames, seq_length, channels = hidden_states.shape
677
+ batch_size = batch_frames // num_frames
678
+
679
+ hidden_states = hidden_states[None, :].reshape(batch_size, num_frames, seq_length, channels)
680
+ hidden_states = hidden_states.permute(0, 2, 1, 3)
681
+ hidden_states = hidden_states.reshape(batch_size * seq_length, num_frames, channels)
682
+
683
+ residual = hidden_states
684
+ hidden_states = self.norm_in(hidden_states)
685
+
686
+ if self._chunk_size is not None:
687
+ hidden_states = _chunked_feed_forward(self.ff_in, hidden_states, self._chunk_dim, self._chunk_size)
688
+ else:
689
+ hidden_states = self.ff_in(hidden_states)
690
+
691
+ if self.is_res:
692
+ hidden_states = hidden_states + residual
693
+
694
+ norm_hidden_states = self.norm1(hidden_states)
695
+ attn_output = self.attn1(norm_hidden_states, encoder_hidden_states=None)
696
+ hidden_states = attn_output + hidden_states
697
+
698
+ # 3. Cross-Attention
699
+ if self.attn2 is not None:
700
+ norm_hidden_states = self.norm2(hidden_states)
701
+ attn_output = self.attn2(norm_hidden_states, encoder_hidden_states=encoder_hidden_states)
702
+ hidden_states = attn_output + hidden_states
703
+
704
+ # 4. Feed-forward
705
+ norm_hidden_states = self.norm3(hidden_states)
706
+
707
+ if self._chunk_size is not None:
708
+ ff_output = _chunked_feed_forward(self.ff, norm_hidden_states, self._chunk_dim, self._chunk_size)
709
+ else:
710
+ ff_output = self.ff(norm_hidden_states)
711
+
712
+ if self.is_res:
713
+ hidden_states = ff_output + hidden_states
714
+ else:
715
+ hidden_states = ff_output
716
+
717
+ hidden_states = hidden_states[None, :].reshape(batch_size, seq_length, num_frames, channels)
718
+ hidden_states = hidden_states.permute(0, 2, 1, 3)
719
+ hidden_states = hidden_states.reshape(batch_size * num_frames, seq_length, channels)
720
+
721
+ return hidden_states
722
+
723
+
724
+ class SkipFFTransformerBlock(nn.Module):
725
+ def __init__(
726
+ self,
727
+ dim: int,
728
+ num_attention_heads: int,
729
+ attention_head_dim: int,
730
+ kv_input_dim: int,
731
+ kv_input_dim_proj_use_bias: bool,
732
+ dropout=0.0,
733
+ cross_attention_dim: Optional[int] = None,
734
+ attention_bias: bool = False,
735
+ attention_out_bias: bool = True,
736
+ ):
737
+ super().__init__()
738
+ if kv_input_dim != dim:
739
+ self.kv_mapper = nn.Linear(kv_input_dim, dim, kv_input_dim_proj_use_bias)
740
+ else:
741
+ self.kv_mapper = None
742
+
743
+ self.norm1 = RMSNorm(dim, 1e-06)
744
+
745
+ self.attn1 = Attention(
746
+ query_dim=dim,
747
+ heads=num_attention_heads,
748
+ dim_head=attention_head_dim,
749
+ dropout=dropout,
750
+ bias=attention_bias,
751
+ cross_attention_dim=cross_attention_dim,
752
+ out_bias=attention_out_bias,
753
+ )
754
+
755
+ self.norm2 = RMSNorm(dim, 1e-06)
756
+
757
+ self.attn2 = Attention(
758
+ query_dim=dim,
759
+ cross_attention_dim=cross_attention_dim,
760
+ heads=num_attention_heads,
761
+ dim_head=attention_head_dim,
762
+ dropout=dropout,
763
+ bias=attention_bias,
764
+ out_bias=attention_out_bias,
765
+ )
766
+
767
+ def forward(self, hidden_states, encoder_hidden_states, cross_attention_kwargs):
768
+ cross_attention_kwargs = cross_attention_kwargs.copy() if cross_attention_kwargs is not None else {}
769
+
770
+ if self.kv_mapper is not None:
771
+ encoder_hidden_states = self.kv_mapper(F.silu(encoder_hidden_states))
772
+
773
+ norm_hidden_states = self.norm1(hidden_states)
774
+
775
+ attn_output = self.attn1(
776
+ norm_hidden_states,
777
+ encoder_hidden_states=encoder_hidden_states,
778
+ **cross_attention_kwargs,
779
+ )
780
+
781
+ hidden_states = attn_output + hidden_states
782
+
783
+ norm_hidden_states = self.norm2(hidden_states)
784
+
785
+ attn_output = self.attn2(
786
+ norm_hidden_states,
787
+ encoder_hidden_states=encoder_hidden_states,
788
+ **cross_attention_kwargs,
789
+ )
790
+
791
+ hidden_states = attn_output + hidden_states
792
+
793
+ return hidden_states
794
+
795
+
796
+ @maybe_allow_in_graph
797
+ class FreeNoiseTransformerBlock(nn.Module):
798
+ r"""
799
+ A FreeNoise Transformer block.
800
+
801
+ Parameters:
802
+ dim (`int`):
803
+ The number of channels in the input and output.
804
+ num_attention_heads (`int`):
805
+ The number of heads to use for multi-head attention.
806
+ attention_head_dim (`int`):
807
+ The number of channels in each head.
808
+ dropout (`float`, *optional*, defaults to 0.0):
809
+ The dropout probability to use.
810
+ cross_attention_dim (`int`, *optional*):
811
+ The size of the encoder_hidden_states vector for cross attention.
812
+ activation_fn (`str`, *optional*, defaults to `"geglu"`):
813
+ Activation function to be used in feed-forward.
814
+ num_embeds_ada_norm (`int`, *optional*):
815
+ The number of diffusion steps used during training. See `Transformer2DModel`.
816
+ attention_bias (`bool`, defaults to `False`):
817
+ Configure if the attentions should contain a bias parameter.
818
+ only_cross_attention (`bool`, defaults to `False`):
819
+ Whether to use only cross-attention layers. In this case two cross attention layers are used.
820
+ double_self_attention (`bool`, defaults to `False`):
821
+ Whether to use two self-attention layers. In this case no cross attention layers are used.
822
+ upcast_attention (`bool`, defaults to `False`):
823
+ Whether to upcast the attention computation to float32. This is useful for mixed precision training.
824
+ norm_elementwise_affine (`bool`, defaults to `True`):
825
+ Whether to use learnable elementwise affine parameters for normalization.
826
+ norm_type (`str`, defaults to `"layer_norm"`):
827
+ The normalization layer to use. Can be `"layer_norm"`, `"ada_norm"` or `"ada_norm_zero"`.
828
+ final_dropout (`bool` defaults to `False`):
829
+ Whether to apply a final dropout after the last feed-forward layer.
830
+ attention_type (`str`, defaults to `"default"`):
831
+ The type of attention to use. Can be `"default"` or `"gated"` or `"gated-text-image"`.
832
+ positional_embeddings (`str`, *optional*):
833
+ The type of positional embeddings to apply to.
834
+ num_positional_embeddings (`int`, *optional*, defaults to `None`):
835
+ The maximum number of positional embeddings to apply.
836
+ ff_inner_dim (`int`, *optional*):
837
+ Hidden dimension of feed-forward MLP.
838
+ ff_bias (`bool`, defaults to `True`):
839
+ Whether or not to use bias in feed-forward MLP.
840
+ attention_out_bias (`bool`, defaults to `True`):
841
+ Whether or not to use bias in attention output project layer.
842
+ context_length (`int`, defaults to `16`):
843
+ The maximum number of frames that the FreeNoise block processes at once.
844
+ context_stride (`int`, defaults to `4`):
845
+ The number of frames to be skipped before starting to process a new batch of `context_length` frames.
846
+ weighting_scheme (`str`, defaults to `"pyramid"`):
847
+ The weighting scheme to use for weighting averaging of processed latent frames. As described in the
848
+ Equation 9. of the [FreeNoise](https://arxiv.org/abs/2310.15169) paper, "pyramid" is the default setting
849
+ used.
850
+ """
851
+
852
+ def __init__(
853
+ self,
854
+ dim: int,
855
+ num_attention_heads: int,
856
+ attention_head_dim: int,
857
+ dropout: float = 0.0,
858
+ cross_attention_dim: Optional[int] = None,
859
+ activation_fn: str = "geglu",
860
+ num_embeds_ada_norm: Optional[int] = None,
861
+ attention_bias: bool = False,
862
+ only_cross_attention: bool = False,
863
+ double_self_attention: bool = False,
864
+ upcast_attention: bool = False,
865
+ norm_elementwise_affine: bool = True,
866
+ norm_type: str = "layer_norm",
867
+ norm_eps: float = 1e-5,
868
+ final_dropout: bool = False,
869
+ positional_embeddings: Optional[str] = None,
870
+ num_positional_embeddings: Optional[int] = None,
871
+ ff_inner_dim: Optional[int] = None,
872
+ ff_bias: bool = True,
873
+ attention_out_bias: bool = True,
874
+ context_length: int = 16,
875
+ context_stride: int = 4,
876
+ weighting_scheme: str = "pyramid",
877
+ ):
878
+ super().__init__()
879
+ self.dim = dim
880
+ self.num_attention_heads = num_attention_heads
881
+ self.attention_head_dim = attention_head_dim
882
+ self.dropout = dropout
883
+ self.cross_attention_dim = cross_attention_dim
884
+ self.activation_fn = activation_fn
885
+ self.attention_bias = attention_bias
886
+ self.double_self_attention = double_self_attention
887
+ self.norm_elementwise_affine = norm_elementwise_affine
888
+ self.positional_embeddings = positional_embeddings
889
+ self.num_positional_embeddings = num_positional_embeddings
890
+ self.only_cross_attention = only_cross_attention
891
+
892
+ self.set_free_noise_properties(context_length, context_stride, weighting_scheme)
893
+
894
+ # We keep these boolean flags for backward-compatibility.
895
+ self.use_ada_layer_norm_zero = (num_embeds_ada_norm is not None) and norm_type == "ada_norm_zero"
896
+ self.use_ada_layer_norm = (num_embeds_ada_norm is not None) and norm_type == "ada_norm"
897
+ self.use_ada_layer_norm_single = norm_type == "ada_norm_single"
898
+ self.use_layer_norm = norm_type == "layer_norm"
899
+ self.use_ada_layer_norm_continuous = norm_type == "ada_norm_continuous"
900
+
901
+ if norm_type in ("ada_norm", "ada_norm_zero") and num_embeds_ada_norm is None:
902
+ raise ValueError(
903
+ f"`norm_type` is set to {norm_type}, but `num_embeds_ada_norm` is not defined. Please make sure to"
904
+ f" define `num_embeds_ada_norm` if setting `norm_type` to {norm_type}."
905
+ )
906
+
907
+ self.norm_type = norm_type
908
+ self.num_embeds_ada_norm = num_embeds_ada_norm
909
+
910
+ if positional_embeddings and (num_positional_embeddings is None):
911
+ raise ValueError(
912
+ "If `positional_embedding` type is defined, `num_positition_embeddings` must also be defined."
913
+ )
914
+
915
+ if positional_embeddings == "sinusoidal":
916
+ self.pos_embed = SinusoidalPositionalEmbedding(dim, max_seq_length=num_positional_embeddings)
917
+ else:
918
+ self.pos_embed = None
919
+
920
+ # Define 3 blocks. Each block has its own normalization layer.
921
+ # 1. Self-Attn
922
+ self.norm1 = nn.LayerNorm(dim, elementwise_affine=norm_elementwise_affine, eps=norm_eps)
923
+
924
+ self.attn1 = Attention(
925
+ query_dim=dim,
926
+ heads=num_attention_heads,
927
+ dim_head=attention_head_dim,
928
+ dropout=dropout,
929
+ bias=attention_bias,
930
+ cross_attention_dim=cross_attention_dim if only_cross_attention else None,
931
+ upcast_attention=upcast_attention,
932
+ out_bias=attention_out_bias,
933
+ )
934
+
935
+ # 2. Cross-Attn
936
+ if cross_attention_dim is not None or double_self_attention:
937
+ self.norm2 = nn.LayerNorm(dim, norm_eps, norm_elementwise_affine)
938
+
939
+ self.attn2 = Attention(
940
+ query_dim=dim,
941
+ cross_attention_dim=cross_attention_dim if not double_self_attention else None,
942
+ heads=num_attention_heads,
943
+ dim_head=attention_head_dim,
944
+ dropout=dropout,
945
+ bias=attention_bias,
946
+ upcast_attention=upcast_attention,
947
+ out_bias=attention_out_bias,
948
+ ) # is self-attn if encoder_hidden_states is none
949
+
950
+ # 3. Feed-forward
951
+ self.ff = FeedForward(
952
+ dim,
953
+ dropout=dropout,
954
+ activation_fn=activation_fn,
955
+ final_dropout=final_dropout,
956
+ inner_dim=ff_inner_dim,
957
+ bias=ff_bias,
958
+ )
959
+
960
+ self.norm3 = nn.LayerNorm(dim, norm_eps, norm_elementwise_affine)
961
+
962
+ # let chunk size default to None
963
+ self._chunk_size = None
964
+ self._chunk_dim = 0
965
+
966
+ def _get_frame_indices(self, num_frames: int) -> List[Tuple[int, int]]:
967
+ frame_indices = []
968
+ for i in range(0, num_frames - self.context_length + 1, self.context_stride):
969
+ window_start = i
970
+ window_end = min(num_frames, i + self.context_length)
971
+ frame_indices.append((window_start, window_end))
972
+ return frame_indices
973
+
974
+ def _get_frame_weights(self, num_frames: int, weighting_scheme: str = "pyramid") -> List[float]:
975
+ if weighting_scheme == "flat":
976
+ weights = [1.0] * num_frames
977
+
978
+ elif weighting_scheme == "pyramid":
979
+ if num_frames % 2 == 0:
980
+ # num_frames = 4 => [1, 2, 2, 1]
981
+ mid = num_frames // 2
982
+ weights = list(range(1, mid + 1))
983
+ weights = weights + weights[::-1]
984
+ else:
985
+ # num_frames = 5 => [1, 2, 3, 2, 1]
986
+ mid = (num_frames + 1) // 2
987
+ weights = list(range(1, mid))
988
+ weights = weights + [mid] + weights[::-1]
989
+
990
+ elif weighting_scheme == "delayed_reverse_sawtooth":
991
+ if num_frames % 2 == 0:
992
+ # num_frames = 4 => [0.01, 2, 2, 1]
993
+ mid = num_frames // 2
994
+ weights = [0.01] * (mid - 1) + [mid]
995
+ weights = weights + list(range(mid, 0, -1))
996
+ else:
997
+ # num_frames = 5 => [0.01, 0.01, 3, 2, 1]
998
+ mid = (num_frames + 1) // 2
999
+ weights = [0.01] * mid
1000
+ weights = weights + list(range(mid, 0, -1))
1001
+ else:
1002
+ raise ValueError(f"Unsupported value for weighting_scheme={weighting_scheme}")
1003
+
1004
+ return weights
1005
+
1006
+ def set_free_noise_properties(
1007
+ self, context_length: int, context_stride: int, weighting_scheme: str = "pyramid"
1008
+ ) -> None:
1009
+ self.context_length = context_length
1010
+ self.context_stride = context_stride
1011
+ self.weighting_scheme = weighting_scheme
1012
+
1013
+ def set_chunk_feed_forward(self, chunk_size: Optional[int], dim: int = 0) -> None:
1014
+ # Sets chunk feed-forward
1015
+ self._chunk_size = chunk_size
1016
+ self._chunk_dim = dim
1017
+
1018
+ def forward(
1019
+ self,
1020
+ hidden_states: torch.Tensor,
1021
+ attention_mask: Optional[torch.Tensor] = None,
1022
+ encoder_hidden_states: Optional[torch.Tensor] = None,
1023
+ encoder_attention_mask: Optional[torch.Tensor] = None,
1024
+ cross_attention_kwargs: Dict[str, Any] = None,
1025
+ *args,
1026
+ **kwargs,
1027
+ ) -> torch.Tensor:
1028
+ if cross_attention_kwargs is not None:
1029
+ if cross_attention_kwargs.get("scale", None) is not None:
1030
+ logger.warning("Passing `scale` to `cross_attention_kwargs` is deprecated. `scale` will be ignored.")
1031
+
1032
+ cross_attention_kwargs = cross_attention_kwargs.copy() if cross_attention_kwargs is not None else {}
1033
+
1034
+ # hidden_states: [B x H x W, F, C]
1035
+ device = hidden_states.device
1036
+ dtype = hidden_states.dtype
1037
+
1038
+ num_frames = hidden_states.size(1)
1039
+ frame_indices = self._get_frame_indices(num_frames)
1040
+ frame_weights = self._get_frame_weights(self.context_length, self.weighting_scheme)
1041
+ frame_weights = torch.tensor(frame_weights, device=device, dtype=dtype).unsqueeze(0).unsqueeze(-1)
1042
+ is_last_frame_batch_complete = frame_indices[-1][1] == num_frames
1043
+
1044
+ # Handle out-of-bounds case if num_frames isn't perfectly divisible by context_length
1045
+ # For example, num_frames=25, context_length=16, context_stride=4, then we expect the ranges:
1046
+ # [(0, 16), (4, 20), (8, 24), (10, 26)]
1047
+ if not is_last_frame_batch_complete:
1048
+ if num_frames < self.context_length:
1049
+ raise ValueError(f"Expected {num_frames=} to be greater or equal than {self.context_length=}")
1050
+ last_frame_batch_length = num_frames - frame_indices[-1][1]
1051
+ frame_indices.append((num_frames - self.context_length, num_frames))
1052
+
1053
+ num_times_accumulated = torch.zeros((1, num_frames, 1), device=device)
1054
+ accumulated_values = torch.zeros_like(hidden_states)
1055
+
1056
+ for i, (frame_start, frame_end) in enumerate(frame_indices):
1057
+ # The reason for slicing here is to ensure that if (frame_end - frame_start) is to handle
1058
+ # cases like frame_indices=[(0, 16), (16, 20)], if the user provided a video with 19 frames, or
1059
+ # essentially a non-multiple of `context_length`.
1060
+ weights = torch.ones_like(num_times_accumulated[:, frame_start:frame_end])
1061
+ weights *= frame_weights
1062
+
1063
+ hidden_states_chunk = hidden_states[:, frame_start:frame_end]
1064
+
1065
+ # Notice that normalization is always applied before the real computation in the following blocks.
1066
+ # 1. Self-Attention
1067
+ norm_hidden_states = self.norm1(hidden_states_chunk)
1068
+
1069
+ if self.pos_embed is not None:
1070
+ norm_hidden_states = self.pos_embed(norm_hidden_states)
1071
+
1072
+ attn_output = self.attn1(
1073
+ norm_hidden_states,
1074
+ encoder_hidden_states=encoder_hidden_states if self.only_cross_attention else None,
1075
+ attention_mask=attention_mask,
1076
+ **cross_attention_kwargs,
1077
+ )
1078
+
1079
+ hidden_states_chunk = attn_output + hidden_states_chunk
1080
+ if hidden_states_chunk.ndim == 4:
1081
+ hidden_states_chunk = hidden_states_chunk.squeeze(1)
1082
+
1083
+ # 2. Cross-Attention
1084
+ if self.attn2 is not None:
1085
+ norm_hidden_states = self.norm2(hidden_states_chunk)
1086
+
1087
+ if self.pos_embed is not None and self.norm_type != "ada_norm_single":
1088
+ norm_hidden_states = self.pos_embed(norm_hidden_states)
1089
+
1090
+ attn_output = self.attn2(
1091
+ norm_hidden_states,
1092
+ encoder_hidden_states=encoder_hidden_states,
1093
+ attention_mask=encoder_attention_mask,
1094
+ **cross_attention_kwargs,
1095
+ )
1096
+ hidden_states_chunk = attn_output + hidden_states_chunk
1097
+
1098
+ if i == len(frame_indices) - 1 and not is_last_frame_batch_complete:
1099
+ accumulated_values[:, -last_frame_batch_length:] += (
1100
+ hidden_states_chunk[:, -last_frame_batch_length:] * weights[:, -last_frame_batch_length:]
1101
+ )
1102
+ num_times_accumulated[:, -last_frame_batch_length:] += weights[:, -last_frame_batch_length:]
1103
+ else:
1104
+ accumulated_values[:, frame_start:frame_end] += hidden_states_chunk * weights
1105
+ num_times_accumulated[:, frame_start:frame_end] += weights
1106
+
1107
+ # TODO(aryan): Maybe this could be done in a better way.
1108
+ #
1109
+ # Previously, this was:
1110
+ # hidden_states = torch.where(
1111
+ # num_times_accumulated > 0, accumulated_values / num_times_accumulated, accumulated_values
1112
+ # )
1113
+ #
1114
+ # The reasoning for the change here is that `torch.where` became a bottleneck at some point when golfing memory
1115
+ # spikes. It is particularly noticeable when the number of frames is high. My understanding is that this comes
1116
+ # from tensors being copied - which is why we resort to splitting and concatenating here. I've not particularly
1117
+ # looked into this deeply because other memory optimizations led to more pronounced reductions.
1118
+ hidden_states = torch.cat(
1119
+ [
1120
+ torch.where(num_times_split > 0, accumulated_split / num_times_split, accumulated_split)
1121
+ for accumulated_split, num_times_split in zip(
1122
+ accumulated_values.split(self.context_length, dim=1),
1123
+ num_times_accumulated.split(self.context_length, dim=1),
1124
+ )
1125
+ ],
1126
+ dim=1,
1127
+ ).to(dtype)
1128
+
1129
+ # 3. Feed-forward
1130
+ norm_hidden_states = self.norm3(hidden_states)
1131
+
1132
+ if self._chunk_size is not None:
1133
+ ff_output = _chunked_feed_forward(self.ff, norm_hidden_states, self._chunk_dim, self._chunk_size)
1134
+ else:
1135
+ ff_output = self.ff(norm_hidden_states)
1136
+
1137
+ hidden_states = ff_output + hidden_states
1138
+ if hidden_states.ndim == 4:
1139
+ hidden_states = hidden_states.squeeze(1)
1140
+
1141
+ return hidden_states
1142
+
1143
+
1144
+ class FeedForward(nn.Module):
1145
+ r"""
1146
+ A feed-forward layer.
1147
+
1148
+ Parameters:
1149
+ dim (`int`): The number of channels in the input.
1150
+ dim_out (`int`, *optional*): The number of channels in the output. If not given, defaults to `dim`.
1151
+ mult (`int`, *optional*, defaults to 4): The multiplier to use for the hidden dimension.
1152
+ dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use.
1153
+ activation_fn (`str`, *optional*, defaults to `"geglu"`): Activation function to be used in feed-forward.
1154
+ final_dropout (`bool` *optional*, defaults to False): Apply a final dropout.
1155
+ bias (`bool`, defaults to True): Whether to use a bias in the linear layer.
1156
+ """
1157
+
1158
+ def __init__(
1159
+ self,
1160
+ dim: int,
1161
+ dim_out: Optional[int] = None,
1162
+ mult: int = 4,
1163
+ dropout: float = 0.0,
1164
+ activation_fn: str = "geglu",
1165
+ final_dropout: bool = False,
1166
+ inner_dim=None,
1167
+ bias: bool = True,
1168
+ ):
1169
+ super().__init__()
1170
+ if inner_dim is None:
1171
+ inner_dim = int(dim * mult)
1172
+ dim_out = dim_out if dim_out is not None else dim
1173
+
1174
+ if activation_fn == "gelu":
1175
+ act_fn = GELU(dim, inner_dim, bias=bias)
1176
+ if activation_fn == "gelu-approximate":
1177
+ act_fn = GELU(dim, inner_dim, approximate="tanh", bias=bias)
1178
+ elif activation_fn == "geglu":
1179
+ act_fn = GEGLU(dim, inner_dim, bias=bias)
1180
+ elif activation_fn == "geglu-approximate":
1181
+ act_fn = ApproximateGELU(dim, inner_dim, bias=bias)
1182
+ elif activation_fn == "swiglu":
1183
+ act_fn = SwiGLU(dim, inner_dim, bias=bias)
1184
+
1185
+ self.net = nn.ModuleList([])
1186
+ # project in
1187
+ self.net.append(act_fn)
1188
+ # project dropout
1189
+ self.net.append(nn.Dropout(dropout))
1190
+ # project out
1191
+ self.net.append(nn.Linear(inner_dim, dim_out, bias=bias))
1192
+ # FF as used in Vision Transformer, MLP-Mixer, etc. have a final dropout
1193
+ if final_dropout:
1194
+ self.net.append(nn.Dropout(dropout))
1195
+
1196
+ def forward(self, hidden_states: torch.Tensor, *args, **kwargs) -> torch.Tensor:
1197
+ if len(args) > 0 or kwargs.get("scale", None) is not None:
1198
+ deprecation_message = "The `scale` argument is deprecated and will be ignored. Please remove it, as passing it will raise an error in the future. `scale` should directly be passed while calling the underlying pipeline component i.e., via `cross_attention_kwargs`."
1199
+ deprecate("scale", "1.0.0", deprecation_message)
1200
+ for module in self.net:
1201
+ hidden_states = module(hidden_states)
1202
+ return hidden_states
diffusers3/models/.ipynb_checkpoints/controlnet-checkpoint.py ADDED
@@ -0,0 +1,870 @@
1
+ # Copyright 2024 The HuggingFace Team. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ from dataclasses import dataclass
15
+ from typing import Any, Dict, List, Optional, Tuple, Union
16
+
17
+ import torch
18
+ from torch import nn
19
+ from torch.nn import functional as F
20
+
21
+ from ..configuration_utils import ConfigMixin, register_to_config
22
+ from ..loaders.single_file_model import FromOriginalModelMixin
23
+ from ..utils import BaseOutput, logging
24
+ from .attention_processor import (
25
+ ADDED_KV_ATTENTION_PROCESSORS,
26
+ CROSS_ATTENTION_PROCESSORS,
27
+ AttentionProcessor,
28
+ AttnAddedKVProcessor,
29
+ AttnProcessor,
30
+ )
31
+ from .embeddings import TextImageProjection, TextImageTimeEmbedding, TextTimeEmbedding, TimestepEmbedding, Timesteps
32
+ from .modeling_utils import ModelMixin
33
+ from .unets.unet_2d_blocks import (
34
+ CrossAttnDownBlock2D,
35
+ DownBlock2D,
36
+ UNetMidBlock2D,
37
+ UNetMidBlock2DCrossAttn,
38
+ get_down_block,
39
+ )
40
+ from .unets.unet_2d_condition import UNet2DConditionModel
41
+
42
+
43
+ logger = logging.get_logger(__name__) # pylint: disable=invalid-name
44
+
45
+
46
+ @dataclass
47
+ class ControlNetOutput(BaseOutput):
48
+ """
49
+ The output of [`ControlNetModel`].
50
+
51
+ Args:
52
+ down_block_res_samples (`tuple[torch.Tensor]`):
53
+ A tuple of downsample activations at different resolutions for each downsampling block. Each tensor should
54
+ be of shape `(batch_size, channel * resolution, height // resolution, width // resolution)`. Output can be
55
+ used to condition the original UNet's downsampling activations.
56
+ mid_block_res_sample (`torch.Tensor`):
57
+ The activation of the middle block (the lowest sample resolution). The tensor should be of shape
58
+ `(batch_size, channel * lowest_resolution, height // lowest_resolution, width // lowest_resolution)`.
59
+ Output can be used to condition the original UNet's middle block activation.
60
+ """
61
+
62
+ down_block_res_samples: Tuple[torch.Tensor]
63
+ mid_block_res_sample: torch.Tensor
64
+
65
+
66
+ class ControlNetConditioningEmbedding(nn.Module):
67
+ """
68
+ Quoting from https://arxiv.org/abs/2302.05543: "Stable Diffusion uses a pre-processing method similar to VQ-GAN
69
+ [11] to convert the entire dataset of 512 × 512 images into smaller 64 × 64 “latent images” for stabilized
70
+ training. This requires ControlNets to convert image-based conditions to 64 × 64 feature space to match the
71
+ convolution size. We use a tiny network E(·) of four convolution layers with 4 × 4 kernels and 2 × 2 strides
72
+ (activated by ReLU, channels are 16, 32, 64, 128, initialized with Gaussian weights, trained jointly with the full
73
+ model) to encode image-space conditions ... into feature maps ..."
74
+ """
75
+
76
+ def __init__(
77
+ self,
78
+ conditioning_embedding_channels: int,
79
+ conditioning_channels: int = 3,
80
+ block_out_channels: Tuple[int, ...] = (16, 32, 96, 256),
81
+ ):
82
+ super().__init__()
83
+
84
+ self.conv_in = nn.Conv2d(conditioning_channels, block_out_channels[0], kernel_size=3, padding=1)
85
+
86
+ self.blocks = nn.ModuleList([])
87
+
88
+ for i in range(len(block_out_channels) - 1):
89
+ channel_in = block_out_channels[i]
90
+ channel_out = block_out_channels[i + 1]
91
+ self.blocks.append(nn.Conv2d(channel_in, channel_in, kernel_size=3, padding=1))
92
+ self.blocks.append(nn.Conv2d(channel_in, channel_out, kernel_size=3, padding=1, stride=2))
93
+
94
+ self.conv_out = zero_module(
95
+ nn.Conv2d(block_out_channels[-1], conditioning_embedding_channels, kernel_size=3, padding=1)
96
+ )
97
+
98
+ def forward(self, conditioning):
99
+ embedding = self.conv_in(conditioning)
100
+ embedding = F.silu(embedding)
101
+
102
+ for block in self.blocks:
103
+ embedding = block(embedding)
104
+ embedding = F.silu(embedding)
105
+
106
+ embedding = self.conv_out(embedding)
107
+
108
+ return embedding
109
+
110
+
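+ # Shape sketch (illustrative assumption: the default block_out_channels=(16, 32, 96, 256)
+ # and a 3 x 512 x 512 conditioning image): the three stride-2 convolutions downsample
+ # 512 -> 256 -> 128 -> 64, so the zero-initialized conv_out produces a
+ # conditioning_embedding_channels x 64 x 64 map matching the UNet's latent resolution.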
111
+ class ControlNetModel(ModelMixin, ConfigMixin, FromOriginalModelMixin):
112
+ """
113
+ A ControlNet model.
114
+
115
+ Args:
116
+ in_channels (`int`, defaults to 4):
117
+ The number of channels in the input sample.
118
+ flip_sin_to_cos (`bool`, defaults to `True`):
119
+ Whether to flip the sin to cos in the time embedding.
120
+ freq_shift (`int`, defaults to 0):
121
+ The frequency shift to apply to the time embedding.
122
+ down_block_types (`tuple[str]`, defaults to `("CrossAttnDownBlock2D", "CrossAttnDownBlock2D", "CrossAttnDownBlock2D", "DownBlock2D")`):
123
+ The tuple of downsample blocks to use.
124
+ only_cross_attention (`Union[bool, Tuple[bool]]`, defaults to `False`):
+ Whether the corresponding down blocks should use only cross-attention (no self-attention).
125
+ block_out_channels (`tuple[int]`, defaults to `(320, 640, 1280, 1280)`):
126
+ The tuple of output channels for each block.
127
+ layers_per_block (`int`, defaults to 2):
128
+ The number of layers per block.
129
+ downsample_padding (`int`, defaults to 1):
130
+ The padding to use for the downsampling convolution.
131
+ mid_block_scale_factor (`float`, defaults to 1):
132
+ The scale factor to use for the mid block.
133
+ act_fn (`str`, defaults to "silu"):
134
+ The activation function to use.
135
+ norm_num_groups (`int`, *optional*, defaults to 32):
136
+ The number of groups to use for the normalization. If None, normalization and activation layers are skipped
137
+ in post-processing.
138
+ norm_eps (`float`, defaults to 1e-5):
139
+ The epsilon to use for the normalization.
140
+ cross_attention_dim (`int`, defaults to 1280):
141
+ The dimension of the cross attention features.
142
+ transformer_layers_per_block (`int` or `Tuple[int]`, *optional*, defaults to 1):
143
+ The number of transformer blocks of type [`~models.attention.BasicTransformerBlock`]. Only relevant for
144
+ [`~models.unet_2d_blocks.CrossAttnDownBlock2D`], [`~models.unet_2d_blocks.CrossAttnUpBlock2D`],
145
+ [`~models.unet_2d_blocks.UNetMidBlock2DCrossAttn`].
146
+ encoder_hid_dim (`int`, *optional*, defaults to None):
147
+ If `encoder_hid_dim_type` is defined, `encoder_hidden_states` will be projected from `encoder_hid_dim`
148
+ dimension to `cross_attention_dim`.
149
+ encoder_hid_dim_type (`str`, *optional*, defaults to `None`):
150
+ If given, the `encoder_hidden_states` and potentially other embeddings are down-projected to text
151
+ embeddings of dimension `cross_attention_dim` according to `encoder_hid_dim_type`.
152
+ attention_head_dim (`Union[int, Tuple[int]]`, defaults to 8):
153
+ The dimension of the attention heads.
154
+ use_linear_projection (`bool`, defaults to `False`):
+ Whether to use a linear projection instead of a 1x1 convolution for the transformer input/output layers.
155
+ class_embed_type (`str`, *optional*, defaults to `None`):
156
+ The type of class embedding to use which is ultimately summed with the time embeddings. Choose from None,
157
+ `"timestep"`, `"identity"`, `"projection"`, or `"simple_projection"`.
158
+ addition_embed_type (`str`, *optional*, defaults to `None`):
159
+ Configures an optional embedding which will be summed with the time embeddings. Choose from `None` or
160
+ "text". "text" will use the `TextTimeEmbedding` layer.
161
+ num_class_embeds (`int`, *optional*, defaults to `None`):
162
+ Input dimension of the learnable embedding matrix to be projected to `time_embed_dim`, when performing
163
+ class conditioning with `class_embed_type` equal to `None`.
164
+ upcast_attention (`bool`, defaults to `False`):
+ Whether to upcast the attention computation to `float32`.
165
+ resnet_time_scale_shift (`str`, defaults to `"default"`):
166
+ Time scale shift config for ResNet blocks (see `ResnetBlock2D`). Choose from `default` or `scale_shift`.
167
+ projection_class_embeddings_input_dim (`int`, *optional*, defaults to `None`):
168
+ The dimension of the `class_labels` input when `class_embed_type="projection"`. Required when
169
+ `class_embed_type="projection"`.
170
+ controlnet_conditioning_channel_order (`str`, defaults to `"rgb"`):
171
+ The channel order of the conditional image. Will convert to `rgb` if it's `bgr`.
172
+ conditioning_embedding_out_channels (`tuple[int]`, *optional*, defaults to `(16, 32, 96, 256)`):
173
+ The tuple of output channel for each block in the `conditioning_embedding` layer.
174
+ global_pool_conditions (`bool`, defaults to `False`):
175
+ TODO(Patrick) - unused parameter.
176
+ addition_embed_type_num_heads (`int`, defaults to 64):
177
+ The number of heads to use for the `TextTimeEmbedding` layer.
178
+ """
179
+
180
+ _supports_gradient_checkpointing = True
181
+
182
+ @register_to_config
183
+ def __init__(
184
+ self,
185
+ in_channels: int = 4,
186
+ conditioning_channels: int = 3,
187
+ flip_sin_to_cos: bool = True,
188
+ freq_shift: int = 0,
189
+ down_block_types: Tuple[str, ...] = (
190
+ "CrossAttnDownBlock2D",
191
+ "CrossAttnDownBlock2D",
192
+ "CrossAttnDownBlock2D",
193
+ "DownBlock2D",
194
+ ),
195
+ mid_block_type: Optional[str] = "UNetMidBlock2DCrossAttn",
196
+ only_cross_attention: Union[bool, Tuple[bool]] = False,
197
+ block_out_channels: Tuple[int, ...] = (320, 640, 1280, 1280),
198
+ layers_per_block: int = 2,
199
+ downsample_padding: int = 1,
200
+ mid_block_scale_factor: float = 1,
201
+ act_fn: str = "silu",
202
+ norm_num_groups: Optional[int] = 32,
203
+ norm_eps: float = 1e-5,
204
+ cross_attention_dim: int = 1280,
205
+ transformer_layers_per_block: Union[int, Tuple[int, ...]] = 1,
206
+ encoder_hid_dim: Optional[int] = None,
207
+ encoder_hid_dim_type: Optional[str] = None,
208
+ attention_head_dim: Union[int, Tuple[int, ...]] = 8,
209
+ num_attention_heads: Optional[Union[int, Tuple[int, ...]]] = None,
210
+ use_linear_projection: bool = False,
211
+ class_embed_type: Optional[str] = None,
212
+ addition_embed_type: Optional[str] = None,
213
+ addition_time_embed_dim: Optional[int] = None,
214
+ num_class_embeds: Optional[int] = None,
215
+ upcast_attention: bool = False,
216
+ resnet_time_scale_shift: str = "default",
217
+ projection_class_embeddings_input_dim: Optional[int] = None,
218
+ controlnet_conditioning_channel_order: str = "rgb",
219
+ conditioning_embedding_out_channels: Optional[Tuple[int, ...]] = (16, 32, 96, 256),
220
+ global_pool_conditions: bool = False,
221
+ addition_embed_type_num_heads: int = 64,
222
+ ):
223
+ super().__init__()
224
+
225
+ # If `num_attention_heads` is not defined (which is the case for most models)
226
+ # it will default to `attention_head_dim`. This looks weird upon first reading it and it is.
227
+ # The reason for this behavior is to correct for incorrectly named variables that were introduced
228
+ # when this library was created. The incorrect naming was only discovered much later in https://github.com/huggingface/diffusers/issues/2011#issuecomment-1547958131
229
+ # Changing `attention_head_dim` to `num_attention_heads` for 40,000+ configurations is too backwards breaking
230
+ # which is why we correct for the naming here.
231
+ num_attention_heads = num_attention_heads or attention_head_dim
232
+
233
+ # Check inputs
234
+ if len(block_out_channels) != len(down_block_types):
235
+ raise ValueError(
236
+ f"Must provide the same number of `block_out_channels` as `down_block_types`. `block_out_channels`: {block_out_channels}. `down_block_types`: {down_block_types}."
237
+ )
238
+
239
+ if not isinstance(only_cross_attention, bool) and len(only_cross_attention) != len(down_block_types):
240
+ raise ValueError(
241
+ f"Must provide the same number of `only_cross_attention` as `down_block_types`. `only_cross_attention`: {only_cross_attention}. `down_block_types`: {down_block_types}."
242
+ )
243
+
244
+ if not isinstance(num_attention_heads, int) and len(num_attention_heads) != len(down_block_types):
245
+ raise ValueError(
246
+ f"Must provide the same number of `num_attention_heads` as `down_block_types`. `num_attention_heads`: {num_attention_heads}. `down_block_types`: {down_block_types}."
247
+ )
248
+
249
+ if isinstance(transformer_layers_per_block, int):
250
+ transformer_layers_per_block = [transformer_layers_per_block] * len(down_block_types)
251
+
252
+ # input
253
+ conv_in_kernel = 3
254
+ conv_in_padding = (conv_in_kernel - 1) // 2
255
+ self.conv_in = nn.Conv2d(
256
+ in_channels, block_out_channels[0], kernel_size=conv_in_kernel, padding=conv_in_padding
257
+ )
258
+
259
+ # time
260
+ time_embed_dim = block_out_channels[0] * 4
261
+ self.time_proj = Timesteps(block_out_channels[0], flip_sin_to_cos, freq_shift)
262
+ timestep_input_dim = block_out_channels[0]
263
+ self.time_embedding = TimestepEmbedding(
264
+ timestep_input_dim,
265
+ time_embed_dim,
266
+ act_fn=act_fn,
267
+ )
268
+
269
+ if encoder_hid_dim_type is None and encoder_hid_dim is not None:
270
+ encoder_hid_dim_type = "text_proj"
271
+ self.register_to_config(encoder_hid_dim_type=encoder_hid_dim_type)
272
+ logger.info("encoder_hid_dim_type defaults to 'text_proj' as `encoder_hid_dim` is defined.")
273
+
274
+ if encoder_hid_dim is None and encoder_hid_dim_type is not None:
275
+ raise ValueError(
276
+ f"`encoder_hid_dim` has to be defined when `encoder_hid_dim_type` is set to {encoder_hid_dim_type}."
277
+ )
278
+
279
+ if encoder_hid_dim_type == "text_proj":
280
+ self.encoder_hid_proj = nn.Linear(encoder_hid_dim, cross_attention_dim)
281
+ elif encoder_hid_dim_type == "text_image_proj":
282
+ # image_embed_dim DOESN'T have to be `cross_attention_dim`. To not clutter the __init__ too much
283
+ # they are set to `cross_attention_dim` here as this is exactly the required dimension for the only current use
284
+ # case, when `addition_embed_type == "text_image_proj"` (Kandinsky 2.1).
285
+ self.encoder_hid_proj = TextImageProjection(
286
+ text_embed_dim=encoder_hid_dim,
287
+ image_embed_dim=cross_attention_dim,
288
+ cross_attention_dim=cross_attention_dim,
289
+ )
290
+
291
+ elif encoder_hid_dim_type is not None:
292
+ raise ValueError(
293
+ f"encoder_hid_dim_type: {encoder_hid_dim_type} must be None, 'text_proj' or 'text_image_proj'."
294
+ )
295
+ else:
296
+ self.encoder_hid_proj = None
297
+
298
+ # class embedding
299
+ if class_embed_type is None and num_class_embeds is not None:
300
+ self.class_embedding = nn.Embedding(num_class_embeds, time_embed_dim)
301
+ elif class_embed_type == "timestep":
302
+ self.class_embedding = TimestepEmbedding(timestep_input_dim, time_embed_dim)
303
+ elif class_embed_type == "identity":
304
+ self.class_embedding = nn.Identity(time_embed_dim, time_embed_dim)
305
+ elif class_embed_type == "projection":
306
+ if projection_class_embeddings_input_dim is None:
307
+ raise ValueError(
308
+ "`class_embed_type`: 'projection' requires `projection_class_embeddings_input_dim` be set"
309
+ )
310
+ # The projection `class_embed_type` is the same as the timestep `class_embed_type` except
311
+ # 1. the `class_labels` inputs are not first converted to sinusoidal embeddings
312
+ # 2. it projects from an arbitrary input dimension.
313
+ #
314
+ # Note that `TimestepEmbedding` is quite general, being mainly linear layers and activations.
315
+ # When used for embedding actual timesteps, the timesteps are first converted to sinusoidal embeddings.
316
+ # As a result, `TimestepEmbedding` can be passed arbitrary vectors.
317
+ self.class_embedding = TimestepEmbedding(projection_class_embeddings_input_dim, time_embed_dim)
318
+ else:
319
+ self.class_embedding = None
320
+
321
+ if addition_embed_type == "text":
322
+ if encoder_hid_dim is not None:
323
+ text_time_embedding_from_dim = encoder_hid_dim
324
+ else:
325
+ text_time_embedding_from_dim = cross_attention_dim
326
+
327
+ self.add_embedding = TextTimeEmbedding(
328
+ text_time_embedding_from_dim, time_embed_dim, num_heads=addition_embed_type_num_heads
329
+ )
330
+ elif addition_embed_type == "text_image":
331
+ # text_embed_dim and image_embed_dim DON'T have to be `cross_attention_dim`. To not clutter the __init__ too much
332
+ # they are set to `cross_attention_dim` here as this is exactly the required dimension for the only current use
333
+ # case, when `addition_embed_type == "text_image"` (Kandinsky 2.1).
334
+ self.add_embedding = TextImageTimeEmbedding(
335
+ text_embed_dim=cross_attention_dim, image_embed_dim=cross_attention_dim, time_embed_dim=time_embed_dim
336
+ )
337
+ elif addition_embed_type == "text_time":
338
+ self.add_time_proj = Timesteps(addition_time_embed_dim, flip_sin_to_cos, freq_shift)
339
+ self.add_embedding = TimestepEmbedding(projection_class_embeddings_input_dim, time_embed_dim)
340
+
341
+ elif addition_embed_type is not None:
342
+ raise ValueError(f"addition_embed_type: {addition_embed_type} must be None, 'text' or 'text_image'.")
343
+
344
+ # control net conditioning embedding
345
+ self.controlnet_cond_embedding = ControlNetConditioningEmbedding(
346
+ conditioning_embedding_channels=block_out_channels[0],
347
+ block_out_channels=conditioning_embedding_out_channels,
348
+ conditioning_channels=conditioning_channels,
349
+ )
350
+
351
+ self.down_blocks = nn.ModuleList([])
352
+ self.controlnet_down_blocks = nn.ModuleList([])
353
+
354
+ if isinstance(only_cross_attention, bool):
355
+ only_cross_attention = [only_cross_attention] * len(down_block_types)
356
+
357
+ if isinstance(attention_head_dim, int):
358
+ attention_head_dim = (attention_head_dim,) * len(down_block_types)
359
+
360
+ if isinstance(num_attention_heads, int):
361
+ num_attention_heads = (num_attention_heads,) * len(down_block_types)
362
+
363
+ # down
364
+ output_channel = block_out_channels[0]
365
+
366
+ controlnet_block = nn.Conv2d(output_channel, output_channel, kernel_size=1)
367
+ controlnet_block = zero_module(controlnet_block)
368
+ self.controlnet_down_blocks.append(controlnet_block)
369
+
370
+ for i, down_block_type in enumerate(down_block_types):
371
+ input_channel = output_channel
372
+ output_channel = block_out_channels[i]
373
+ is_final_block = i == len(block_out_channels) - 1
374
+
375
+ down_block = get_down_block(
376
+ down_block_type,
377
+ num_layers=layers_per_block,
378
+ transformer_layers_per_block=transformer_layers_per_block[i],
379
+ in_channels=input_channel,
380
+ out_channels=output_channel,
381
+ temb_channels=time_embed_dim,
382
+ add_downsample=not is_final_block,
383
+ resnet_eps=norm_eps,
384
+ resnet_act_fn=act_fn,
385
+ resnet_groups=norm_num_groups,
386
+ cross_attention_dim=cross_attention_dim,
387
+ num_attention_heads=num_attention_heads[i],
388
+ attention_head_dim=attention_head_dim[i] if attention_head_dim[i] is not None else output_channel,
389
+ downsample_padding=downsample_padding,
390
+ use_linear_projection=use_linear_projection,
391
+ only_cross_attention=only_cross_attention[i],
392
+ upcast_attention=upcast_attention,
393
+ resnet_time_scale_shift=resnet_time_scale_shift,
394
+ )
395
+ self.down_blocks.append(down_block)
396
+
397
+ for _ in range(layers_per_block):
398
+ controlnet_block = nn.Conv2d(output_channel, output_channel, kernel_size=1)
399
+ controlnet_block = zero_module(controlnet_block)
400
+ self.controlnet_down_blocks.append(controlnet_block)
401
+
402
+ if not is_final_block:
403
+ controlnet_block = nn.Conv2d(output_channel, output_channel, kernel_size=1)
404
+ controlnet_block = zero_module(controlnet_block)
405
+ self.controlnet_down_blocks.append(controlnet_block)
406
+
407
+ # mid
408
+ mid_block_channel = block_out_channels[-1]
409
+
410
+ controlnet_block = nn.Conv2d(mid_block_channel, mid_block_channel, kernel_size=1)
411
+ controlnet_block = zero_module(controlnet_block)
412
+ self.controlnet_mid_block = controlnet_block
413
+
414
+ if mid_block_type == "UNetMidBlock2DCrossAttn":
415
+ self.mid_block = UNetMidBlock2DCrossAttn(
416
+ transformer_layers_per_block=transformer_layers_per_block[-1],
417
+ in_channels=mid_block_channel,
418
+ temb_channels=time_embed_dim,
419
+ resnet_eps=norm_eps,
420
+ resnet_act_fn=act_fn,
421
+ output_scale_factor=mid_block_scale_factor,
422
+ resnet_time_scale_shift=resnet_time_scale_shift,
423
+ cross_attention_dim=cross_attention_dim,
424
+ num_attention_heads=num_attention_heads[-1],
425
+ resnet_groups=norm_num_groups,
426
+ use_linear_projection=use_linear_projection,
427
+ upcast_attention=upcast_attention,
428
+ )
429
+ elif mid_block_type == "UNetMidBlock2D":
430
+ self.mid_block = UNetMidBlock2D(
431
+ in_channels=block_out_channels[-1],
432
+ temb_channels=time_embed_dim,
433
+ num_layers=0,
434
+ resnet_eps=norm_eps,
435
+ resnet_act_fn=act_fn,
436
+ output_scale_factor=mid_block_scale_factor,
437
+ resnet_groups=norm_num_groups,
438
+ resnet_time_scale_shift=resnet_time_scale_shift,
439
+ add_attention=False,
440
+ )
441
+ else:
442
+ raise ValueError(f"unknown mid_block_type : {mid_block_type}")
443
+
444
+ @classmethod
445
+ def from_unet(
446
+ cls,
447
+ unet: UNet2DConditionModel,
448
+ controlnet_conditioning_channel_order: str = "rgb",
449
+ conditioning_embedding_out_channels: Optional[Tuple[int, ...]] = (16, 32, 96, 256),
450
+ load_weights_from_unet: bool = True,
451
+ conditioning_channels: int = 3,
452
+ ):
453
+ r"""
454
+ Instantiate a [`ControlNetModel`] from [`UNet2DConditionModel`].
455
+
456
+ Parameters:
457
+ unet (`UNet2DConditionModel`):
458
+ The UNet model weights to copy to the [`ControlNetModel`]. All configuration options are also copied
459
+ where applicable.
460
+ """
461
+ transformer_layers_per_block = (
462
+ unet.config.transformer_layers_per_block if "transformer_layers_per_block" in unet.config else 1
463
+ )
464
+ encoder_hid_dim = unet.config.encoder_hid_dim if "encoder_hid_dim" in unet.config else None
465
+ encoder_hid_dim_type = unet.config.encoder_hid_dim_type if "encoder_hid_dim_type" in unet.config else None
466
+ addition_embed_type = unet.config.addition_embed_type if "addition_embed_type" in unet.config else None
467
+ addition_time_embed_dim = (
468
+ unet.config.addition_time_embed_dim if "addition_time_embed_dim" in unet.config else None
469
+ )
470
+
471
+ controlnet = cls(
472
+ encoder_hid_dim=encoder_hid_dim,
473
+ encoder_hid_dim_type=encoder_hid_dim_type,
474
+ addition_embed_type=addition_embed_type,
475
+ addition_time_embed_dim=addition_time_embed_dim,
476
+ transformer_layers_per_block=transformer_layers_per_block,
477
+ in_channels=unet.config.in_channels,
478
+ flip_sin_to_cos=unet.config.flip_sin_to_cos,
479
+ freq_shift=unet.config.freq_shift,
480
+ down_block_types=unet.config.down_block_types,
481
+ only_cross_attention=unet.config.only_cross_attention,
482
+ block_out_channels=unet.config.block_out_channels,
483
+ layers_per_block=unet.config.layers_per_block,
484
+ downsample_padding=unet.config.downsample_padding,
485
+ mid_block_scale_factor=unet.config.mid_block_scale_factor,
486
+ act_fn=unet.config.act_fn,
487
+ norm_num_groups=unet.config.norm_num_groups,
488
+ norm_eps=unet.config.norm_eps,
489
+ cross_attention_dim=unet.config.cross_attention_dim,
490
+ attention_head_dim=unet.config.attention_head_dim,
491
+ num_attention_heads=unet.config.num_attention_heads,
492
+ use_linear_projection=unet.config.use_linear_projection,
493
+ class_embed_type=unet.config.class_embed_type,
494
+ num_class_embeds=unet.config.num_class_embeds,
495
+ upcast_attention=unet.config.upcast_attention,
496
+ resnet_time_scale_shift=unet.config.resnet_time_scale_shift,
497
+ projection_class_embeddings_input_dim=unet.config.projection_class_embeddings_input_dim,
498
+ mid_block_type=unet.config.mid_block_type,
499
+ controlnet_conditioning_channel_order=controlnet_conditioning_channel_order,
500
+ conditioning_embedding_out_channels=conditioning_embedding_out_channels,
501
+ conditioning_channels=conditioning_channels,
502
+ )
503
+
504
+ if load_weights_from_unet:
505
+ controlnet.conv_in.load_state_dict(unet.conv_in.state_dict())
506
+ controlnet.time_proj.load_state_dict(unet.time_proj.state_dict())
507
+ controlnet.time_embedding.load_state_dict(unet.time_embedding.state_dict())
508
+
509
+ if controlnet.class_embedding:
510
+ controlnet.class_embedding.load_state_dict(unet.class_embedding.state_dict())
511
+
512
+ if hasattr(controlnet, "add_embedding"):
513
+ controlnet.add_embedding.load_state_dict(unet.add_embedding.state_dict())
514
+
515
+ controlnet.down_blocks.load_state_dict(unet.down_blocks.state_dict())
516
+ controlnet.mid_block.load_state_dict(unet.mid_block.state_dict())
517
+
518
+ return controlnet
519
+
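+ # Usage sketch (the model id is illustrative; `from_unet` itself is defined above):
+ #
+ #     unet = UNet2DConditionModel.from_pretrained("runwayml/stable-diffusion-v1-5", subfolder="unet")
+ #     controlnet = ControlNetModel.from_unet(unet)
+ #
+ # The ControlNet starts from the UNet's encoder weights, while its zero-initialized
+ # `controlnet_down_blocks` and `controlnet_mid_block` convolutions leave the base
+ # model's behavior unchanged at the start of training.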
520
+ @property
521
+ # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.attn_processors
522
+ def attn_processors(self) -> Dict[str, AttentionProcessor]:
523
+ r"""
524
+ Returns:
525
+ `dict` of attention processors: A dictionary containing all attention processors used in the model,
526
+ indexed by their weight names.
527
+ """
528
+ # set recursively
529
+ processors = {}
530
+
531
+ def fn_recursive_add_processors(name: str, module: torch.nn.Module, processors: Dict[str, AttentionProcessor]):
532
+ if hasattr(module, "get_processor"):
533
+ processors[f"{name}.processor"] = module.get_processor()
534
+
535
+ for sub_name, child in module.named_children():
536
+ fn_recursive_add_processors(f"{name}.{sub_name}", child, processors)
537
+
538
+ return processors
539
+
540
+ for name, module in self.named_children():
541
+ fn_recursive_add_processors(name, module, processors)
542
+
543
+ return processors
544
+
545
+ # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.set_attn_processor
546
+ def set_attn_processor(self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]]):
547
+ r"""
548
+ Sets the attention processor to use to compute attention.
549
+
550
+ Parameters:
551
+ processor (`dict` of `AttentionProcessor` or only `AttentionProcessor`):
552
+ The instantiated processor class or a dictionary of processor classes that will be set as the processor
553
+ for **all** `Attention` layers.
554
+
555
+ If `processor` is a dict, the key needs to define the path to the corresponding cross attention
556
+ processor. This is strongly recommended when setting trainable attention processors.
557
+
558
+ """
559
+ count = len(self.attn_processors.keys())
560
+
561
+ if isinstance(processor, dict) and len(processor) != count:
562
+ raise ValueError(
563
+ f"A dict of processors was passed, but the number of processors {len(processor)} does not match the"
564
+ f" number of attention layers: {count}. Please make sure to pass {count} processor classes."
565
+ )
566
+
567
+ def fn_recursive_attn_processor(name: str, module: torch.nn.Module, processor):
568
+ if hasattr(module, "set_processor"):
569
+ if not isinstance(processor, dict):
570
+ module.set_processor(processor)
571
+ else:
572
+ module.set_processor(processor.pop(f"{name}.processor"))
573
+
574
+ for sub_name, child in module.named_children():
575
+ fn_recursive_attn_processor(f"{name}.{sub_name}", child, processor)
576
+
577
+ for name, module in self.named_children():
578
+ fn_recursive_attn_processor(name, module, processor)
579
+
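+ # Usage sketch: passing a single instance applies it to every attention layer, e.g.
+ # `controlnet.set_attn_processor(AttnProcessor())`; passing a dict keyed by the names
+ # returned from `attn_processors` sets processors per layer instead.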
580
+ # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.set_default_attn_processor
581
+ def set_default_attn_processor(self):
582
+ """
583
+ Disables custom attention processors and sets the default attention implementation.
584
+ """
585
+ if all(proc.__class__ in ADDED_KV_ATTENTION_PROCESSORS for proc in self.attn_processors.values()):
586
+ processor = AttnAddedKVProcessor()
587
+ elif all(proc.__class__ in CROSS_ATTENTION_PROCESSORS for proc in self.attn_processors.values()):
588
+ processor = AttnProcessor()
589
+ else:
590
+ raise ValueError(
591
+ f"Cannot call `set_default_attn_processor` when attention processors are of type {next(iter(self.attn_processors.values()))}"
592
+ )
593
+
594
+ self.set_attn_processor(processor)
595
+
596
+ # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.set_attention_slice
597
+ def set_attention_slice(self, slice_size: Union[str, int, List[int]]) -> None:
598
+ r"""
599
+ Enable sliced attention computation.
600
+
601
+ When this option is enabled, the attention module splits the input tensor in slices to compute attention in
602
+ several steps. This is useful for saving some memory in exchange for a small decrease in speed.
603
+
604
+ Args:
605
+ slice_size (`str` or `int` or `list(int)`, *optional*, defaults to `"auto"`):
606
+ When `"auto"`, input to the attention heads is halved, so attention is computed in two steps. If
607
+ `"max"`, maximum amount of memory is saved by running only one slice at a time. If a number is
608
+ provided, uses as many slices as `attention_head_dim // slice_size`. In this case, `attention_head_dim`
609
+ must be a multiple of `slice_size`.
610
+ """
611
+ sliceable_head_dims = []
612
+
613
+ def fn_recursive_retrieve_sliceable_dims(module: torch.nn.Module):
614
+ if hasattr(module, "set_attention_slice"):
615
+ sliceable_head_dims.append(module.sliceable_head_dim)
616
+
617
+ for child in module.children():
618
+ fn_recursive_retrieve_sliceable_dims(child)
619
+
620
+ # retrieve number of attention layers
621
+ for module in self.children():
622
+ fn_recursive_retrieve_sliceable_dims(module)
623
+
624
+ num_sliceable_layers = len(sliceable_head_dims)
625
+
626
+ if slice_size == "auto":
627
+ # half the attention head size is usually a good trade-off between
628
+ # speed and memory
629
+ slice_size = [dim // 2 for dim in sliceable_head_dims]
630
+ elif slice_size == "max":
631
+ # make smallest slice possible
632
+ slice_size = num_sliceable_layers * [1]
633
+
634
+ slice_size = num_sliceable_layers * [slice_size] if not isinstance(slice_size, list) else slice_size
635
+
636
+ if len(slice_size) != len(sliceable_head_dims):
637
+ raise ValueError(
638
+ f"You have provided {len(slice_size)}, but {self.config} has {len(sliceable_head_dims)} different"
639
+ f" attention layers. Make sure to match `len(slice_size)` to be {len(sliceable_head_dims)}."
640
+ )
641
+
642
+ for i in range(len(slice_size)):
643
+ size = slice_size[i]
644
+ dim = sliceable_head_dims[i]
645
+ if size is not None and size > dim:
646
+ raise ValueError(f"size {size} has to be smaller or equal to {dim}.")
647
+
648
+ # Recursively walk through all the children.
649
+ # Any children which exposes the set_attention_slice method
650
+ # gets the message
651
+ def fn_recursive_set_attention_slice(module: torch.nn.Module, slice_size: List[int]):
652
+ if hasattr(module, "set_attention_slice"):
653
+ module.set_attention_slice(slice_size.pop())
654
+
655
+ for child in module.children():
656
+ fn_recursive_set_attention_slice(child, slice_size)
657
+
658
+ reversed_slice_size = list(reversed(slice_size))
659
+ for module in self.children():
660
+ fn_recursive_set_attention_slice(module, reversed_slice_size)
661
+
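+ # Usage sketch: `controlnet.set_attention_slice("auto")` computes each attention layer in
+ # two slices (half the sliceable head dimension each), while `"max"` uses a slice size of
+ # 1 for the largest memory savings at the cost of speed.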
662
+ def _set_gradient_checkpointing(self, module, value: bool = False) -> None:
663
+ if isinstance(module, (CrossAttnDownBlock2D, DownBlock2D)):
664
+ module.gradient_checkpointing = value
665
+
666
+ def forward(
667
+ self,
668
+ sample: torch.Tensor,
669
+ timestep: Union[torch.Tensor, float, int],
670
+ encoder_hidden_states: torch.Tensor,
671
+ controlnet_cond: torch.Tensor,
672
+ conditioning_scale: float = 1.0,
673
+ class_labels: Optional[torch.Tensor] = None,
674
+ timestep_cond: Optional[torch.Tensor] = None,
675
+ attention_mask: Optional[torch.Tensor] = None,
676
+ added_cond_kwargs: Optional[Dict[str, torch.Tensor]] = None,
677
+ cross_attention_kwargs: Optional[Dict[str, Any]] = None,
678
+ guess_mode: bool = False,
679
+ return_dict: bool = True,
680
+ ) -> Union[ControlNetOutput, Tuple[Tuple[torch.Tensor, ...], torch.Tensor]]:
681
+ """
682
+ The [`ControlNetModel`] forward method.
683
+
684
+ Args:
685
+ sample (`torch.Tensor`):
686
+ The noisy input tensor.
687
+ timestep (`Union[torch.Tensor, float, int]`):
688
+ The timestep at which to denoise the input.
689
+ encoder_hidden_states (`torch.Tensor`):
690
+ The encoder hidden states.
691
+ controlnet_cond (`torch.Tensor`):
692
+ The conditional input tensor of shape `(batch_size, conditioning_channels, height, width)`.
693
+ conditioning_scale (`float`, defaults to `1.0`):
694
+ The scale factor for ControlNet outputs.
695
+ class_labels (`torch.Tensor`, *optional*, defaults to `None`):
696
+ Optional class labels for conditioning. Their embeddings will be summed with the timestep embeddings.
697
+ timestep_cond (`torch.Tensor`, *optional*, defaults to `None`):
698
+ Additional conditional embeddings for timestep. If provided, the embeddings will be summed with the
699
+ timestep_embedding passed through the `self.time_embedding` layer to obtain the final timestep
700
+ embeddings.
701
+ attention_mask (`torch.Tensor`, *optional*, defaults to `None`):
702
+ An attention mask of shape `(batch, key_tokens)` is applied to `encoder_hidden_states`. If `1` the mask
703
+ is kept, otherwise if `0` it is discarded. Mask will be converted into a bias, which adds large
704
+ negative values to the attention scores corresponding to "discard" tokens.
705
+ added_cond_kwargs (`dict`):
706
+ Additional conditions for the Stable Diffusion XL UNet.
707
+ cross_attention_kwargs (`dict[str]`, *optional*, defaults to `None`):
708
+ A kwargs dictionary that if specified is passed along to the `AttnProcessor`.
709
+ guess_mode (`bool`, defaults to `False`):
710
+ In this mode, the ControlNet encoder tries its best to recognize the content of the input even if
711
+ you remove all prompts. A `guidance_scale` between 3.0 and 5.0 is recommended.
712
+ return_dict (`bool`, defaults to `True`):
713
+ Whether or not to return a [`~models.controlnet.ControlNetOutput`] instead of a plain tuple.
714
+
715
+ Returns:
716
+ [`~models.controlnet.ControlNetOutput`] **or** `tuple`:
717
+ If `return_dict` is `True`, a [`~models.controlnet.ControlNetOutput`] is returned, otherwise a tuple is
718
+ returned where the first element is the sample tensor.
719
+ """
720
+ # check channel order
721
+ channel_order = self.config.controlnet_conditioning_channel_order
722
+
723
+ if channel_order == "rgb":
724
+ # in rgb order by default
725
+ ...
726
+ elif channel_order == "bgr":
727
+ controlnet_cond = torch.flip(controlnet_cond, dims=[1])
728
+ else:
729
+ raise ValueError(f"unknown `controlnet_conditioning_channel_order`: {channel_order}")
730
+
731
+ # prepare attention_mask
732
+ if attention_mask is not None:
733
+ attention_mask = (1 - attention_mask.to(sample.dtype)) * -10000.0
734
+ attention_mask = attention_mask.unsqueeze(1)
735
+
736
+ # 1. time
737
+ timesteps = timestep
738
+ if not torch.is_tensor(timesteps):
739
+ # TODO: this requires sync between CPU and GPU. So try to pass timesteps as tensors if you can
740
+ # This would be a good case for the `match` statement (Python 3.10+)
741
+ is_mps = sample.device.type == "mps"
742
+ if isinstance(timestep, float):
743
+ dtype = torch.float32 if is_mps else torch.float64
744
+ else:
745
+ dtype = torch.int32 if is_mps else torch.int64
746
+ timesteps = torch.tensor([timesteps], dtype=dtype, device=sample.device)
747
+ elif len(timesteps.shape) == 0:
748
+ timesteps = timesteps[None].to(sample.device)
749
+
750
+ # broadcast to batch dimension in a way that's compatible with ONNX/Core ML
751
+ timesteps = timesteps.expand(sample.shape[0])
752
+
753
+ t_emb = self.time_proj(timesteps)
754
+
755
+ # timesteps does not contain any weights and will always return f32 tensors
756
+ # but time_embedding might actually be running in fp16. so we need to cast here.
757
+ # there might be better ways to encapsulate this.
758
+ t_emb = t_emb.to(dtype=sample.dtype)
759
+
760
+ emb = self.time_embedding(t_emb, timestep_cond)
761
+ aug_emb = None
762
+
763
+ if self.class_embedding is not None:
764
+ if class_labels is None:
765
+ raise ValueError("class_labels should be provided when num_class_embeds > 0")
766
+
767
+ if self.config.class_embed_type == "timestep":
768
+ class_labels = self.time_proj(class_labels)
769
+
770
+ class_emb = self.class_embedding(class_labels).to(dtype=self.dtype)
771
+ emb = emb + class_emb
772
+
773
+ if self.config.addition_embed_type is not None:
774
+ if self.config.addition_embed_type == "text":
775
+ aug_emb = self.add_embedding(encoder_hidden_states)
776
+
777
+ elif self.config.addition_embed_type == "text_time":
778
+ if "text_embeds" not in added_cond_kwargs:
779
+ raise ValueError(
780
+ f"{self.__class__} has the config param `addition_embed_type` set to 'text_time' which requires the keyword argument `text_embeds` to be passed in `added_cond_kwargs`"
781
+ )
782
+ text_embeds = added_cond_kwargs.get("text_embeds")
783
+ if "time_ids" not in added_cond_kwargs:
784
+ raise ValueError(
785
+ f"{self.__class__} has the config param `addition_embed_type` set to 'text_time' which requires the keyword argument `time_ids` to be passed in `added_cond_kwargs`"
786
+ )
787
+ time_ids = added_cond_kwargs.get("time_ids")
788
+ time_embeds = self.add_time_proj(time_ids.flatten())
789
+ time_embeds = time_embeds.reshape((text_embeds.shape[0], -1))
790
+
791
+ add_embeds = torch.concat([text_embeds, time_embeds], dim=-1)
792
+ add_embeds = add_embeds.to(emb.dtype)
793
+ aug_emb = self.add_embedding(add_embeds)
794
+
795
+ emb = emb + aug_emb if aug_emb is not None else emb
796
+
797
+ # 2. pre-process
798
+ sample = self.conv_in(sample)
799
+
800
+ controlnet_cond = self.controlnet_cond_embedding(controlnet_cond)
801
+ sample = sample + controlnet_cond
802
+
803
+ # 3. down
804
+ down_block_res_samples = (sample,)
805
+ for downsample_block in self.down_blocks:
806
+ if hasattr(downsample_block, "has_cross_attention") and downsample_block.has_cross_attention:
807
+ sample, res_samples = downsample_block(
808
+ hidden_states=sample,
809
+ temb=emb,
810
+ encoder_hidden_states=encoder_hidden_states,
811
+ attention_mask=attention_mask,
812
+ cross_attention_kwargs=cross_attention_kwargs,
813
+ )
814
+ else:
815
+ sample, res_samples = downsample_block(hidden_states=sample, temb=emb)
816
+
817
+ down_block_res_samples += res_samples
818
+
819
+ # 4. mid
820
+ if self.mid_block is not None:
821
+ if hasattr(self.mid_block, "has_cross_attention") and self.mid_block.has_cross_attention:
822
+ sample = self.mid_block(
823
+ sample,
824
+ emb,
825
+ encoder_hidden_states=encoder_hidden_states,
826
+ attention_mask=attention_mask,
827
+ cross_attention_kwargs=cross_attention_kwargs,
828
+ )
829
+ else:
830
+ sample = self.mid_block(sample, emb)
831
+
832
+ # 5. Control net blocks
833
+ controlnet_down_block_res_samples = ()
834
+
835
+ for down_block_res_sample, controlnet_block in zip(down_block_res_samples, self.controlnet_down_blocks):
836
+ down_block_res_sample = controlnet_block(down_block_res_sample)
837
+ controlnet_down_block_res_samples = controlnet_down_block_res_samples + (down_block_res_sample,)
838
+
839
+ down_block_res_samples = controlnet_down_block_res_samples
840
+
841
+ mid_block_res_sample = self.controlnet_mid_block(sample)
842
+
843
+ # 6. scaling
844
+ if guess_mode and not self.config.global_pool_conditions:
845
+ scales = torch.logspace(-1, 0, len(down_block_res_samples) + 1, device=sample.device) # 0.1 to 1.0
846
+ scales = scales * conditioning_scale
847
+ down_block_res_samples = [sample * scale for sample, scale in zip(down_block_res_samples, scales)]
848
+ mid_block_res_sample = mid_block_res_sample * scales[-1] # last one
849
+ else:
850
+ down_block_res_samples = [sample * conditioning_scale for sample in down_block_res_samples]
851
+ mid_block_res_sample = mid_block_res_sample * conditioning_scale
852
+
853
+ if self.config.global_pool_conditions:
854
+ down_block_res_samples = [
855
+ torch.mean(sample, dim=(2, 3), keepdim=True) for sample in down_block_res_samples
856
+ ]
857
+ mid_block_res_sample = torch.mean(mid_block_res_sample, dim=(2, 3), keepdim=True)
858
+
859
+ if not return_dict:
860
+ return (down_block_res_samples, mid_block_res_sample)
861
+
862
+ return ControlNetOutput(
863
+ down_block_res_samples=down_block_res_samples, mid_block_res_sample=mid_block_res_sample
864
+ )
865
+
866
+
867
+ def zero_module(module):
868
+ for p in module.parameters():
869
+ nn.init.zeros_(p)
870
+ return module
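Putting the pieces together, one denoising step with ControlNet conditioning looks roughly like the sketch below (shapes are illustrative for a 512x512 Stable Diffusion 1.x setup, and `unet`/`controlnet` are assumed to have been instantiated elsewhere):

import torch

# latents: [B, 4, 64, 64]; conditioning image: [B, 3, 512, 512]; prompt embeds: [B, 77, 768]
latents = torch.randn(1, 4, 64, 64)
cond_image = torch.rand(1, 3, 512, 512)
prompt_embeds = torch.randn(1, 77, 768)
timestep = torch.tensor([999])

down_res, mid_res = controlnet(
    latents,
    timestep,
    encoder_hidden_states=prompt_embeds,
    controlnet_cond=cond_image,
    conditioning_scale=1.0,
    return_dict=False,
)
# The residuals are added onto the UNet's skip connections:
noise_pred = unet(
    latents,
    timestep,
    encoder_hidden_states=prompt_embeds,
    down_block_additional_residuals=down_res,
    mid_block_additional_residual=mid_res,
).sample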
diffusers3/models/README.md ADDED
@@ -0,0 +1,3 @@
1
+ # Models
2
+
3
+ For more detail on the models, please refer to the [docs](https://huggingface.co/docs/diffusers/api/models/overview).
diffusers3/models/__init__.py ADDED
@@ -0,0 +1,137 @@
1
+ # Copyright 2024 The HuggingFace Team. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ from typing import TYPE_CHECKING
16
+
17
+ from ..utils import (
18
+ DIFFUSERS_SLOW_IMPORT,
19
+ _LazyModule,
20
+ is_flax_available,
21
+ is_torch_available,
22
+ )
23
+
24
+
25
+ _import_structure = {}
26
+
27
+ if is_torch_available():
28
+ _import_structure["adapter"] = ["MultiAdapter", "T2IAdapter"]
29
+ _import_structure["autoencoders.autoencoder_asym_kl"] = ["AsymmetricAutoencoderKL"]
30
+ _import_structure["autoencoders.autoencoder_kl"] = ["AutoencoderKL"]
31
+ _import_structure["autoencoders.autoencoder_kl_cogvideox"] = ["AutoencoderKLCogVideoX"]
32
+ _import_structure["autoencoders.autoencoder_kl_temporal_decoder"] = ["AutoencoderKLTemporalDecoder"]
33
+ _import_structure["autoencoders.autoencoder_oobleck"] = ["AutoencoderOobleck"]
34
+ _import_structure["autoencoders.autoencoder_tiny"] = ["AutoencoderTiny"]
35
+ _import_structure["autoencoders.consistency_decoder_vae"] = ["ConsistencyDecoderVAE"]
36
+ _import_structure["autoencoders.vq_model"] = ["VQModel"]
37
+ _import_structure["controlnet"] = ["ControlNetModel"]
38
+ _import_structure["controlnet_flux"] = ["FluxControlNetModel", "FluxMultiControlNetModel"]
39
+ _import_structure["controlnet_hunyuan"] = ["HunyuanDiT2DControlNetModel", "HunyuanDiT2DMultiControlNetModel"]
40
+ _import_structure["controlnet_sd3"] = ["SD3ControlNetModel", "SD3MultiControlNetModel"]
41
+ _import_structure["controlnet_sparsectrl"] = ["SparseControlNetModel"]
42
+ _import_structure["controlnet_xs"] = ["ControlNetXSAdapter", "UNetControlNetXSModel"]
43
+ _import_structure["embeddings"] = ["ImageProjection"]
44
+ _import_structure["modeling_utils"] = ["ModelMixin"]
45
+ _import_structure["transformers.auraflow_transformer_2d"] = ["AuraFlowTransformer2DModel"]
46
+ _import_structure["transformers.cogvideox_transformer_3d"] = ["CogVideoXTransformer3DModel"]
47
+ _import_structure["transformers.dit_transformer_2d"] = ["DiTTransformer2DModel"]
48
+ _import_structure["transformers.dual_transformer_2d"] = ["DualTransformer2DModel"]
49
+ _import_structure["transformers.hunyuan_transformer_2d"] = ["HunyuanDiT2DModel"]
50
+ _import_structure["transformers.latte_transformer_3d"] = ["LatteTransformer3DModel"]
51
+ _import_structure["transformers.lumina_nextdit2d"] = ["LuminaNextDiT2DModel"]
52
+ _import_structure["transformers.pixart_transformer_2d"] = ["PixArtTransformer2DModel"]
53
+ _import_structure["transformers.prior_transformer"] = ["PriorTransformer"]
54
+ _import_structure["transformers.stable_audio_transformer"] = ["StableAudioDiTModel"]
55
+ _import_structure["transformers.t5_film_transformer"] = ["T5FilmDecoder"]
56
+ _import_structure["transformers.transformer_2d"] = ["Transformer2DModel"]
57
+ _import_structure["transformers.transformer_flux"] = ["FluxTransformer2DModel"]
58
+ _import_structure["transformers.transformer_sd3"] = ["SD3Transformer2DModel"]
59
+ _import_structure["transformers.transformer_temporal"] = ["TransformerTemporalModel"]
60
+ _import_structure["unets.unet_1d"] = ["UNet1DModel"]
61
+ _import_structure["unets.unet_2d"] = ["UNet2DModel"]
62
+ _import_structure["unets.unet_2d_condition"] = ["UNet2DConditionModel"]
63
+ _import_structure["unets.unet_3d_condition"] = ["UNet3DConditionModel"]
64
+ _import_structure["unets.unet_i2vgen_xl"] = ["I2VGenXLUNet"]
65
+ _import_structure["unets.unet_kandinsky3"] = ["Kandinsky3UNet"]
66
+ _import_structure["unets.unet_motion_model"] = ["MotionAdapter", "UNetMotionModel"]
67
+ _import_structure["unets.unet_spatio_temporal_condition"] = ["UNetSpatioTemporalConditionModel"]
68
+ _import_structure["unets.unet_stable_cascade"] = ["StableCascadeUNet"]
69
+ _import_structure["unets.uvit_2d"] = ["UVit2DModel"]
70
+
71
+ if is_flax_available():
72
+ _import_structure["controlnet_flax"] = ["FlaxControlNetModel"]
73
+ _import_structure["unets.unet_2d_condition_flax"] = ["FlaxUNet2DConditionModel"]
74
+ _import_structure["vae_flax"] = ["FlaxAutoencoderKL"]
75
+
76
+
77
+ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
78
+ if is_torch_available():
79
+ from .adapter import MultiAdapter, T2IAdapter
80
+ from .autoencoders import (
81
+ AsymmetricAutoencoderKL,
82
+ AutoencoderKL,
83
+ AutoencoderKLCogVideoX,
84
+ AutoencoderKLTemporalDecoder,
85
+ AutoencoderOobleck,
86
+ AutoencoderTiny,
87
+ ConsistencyDecoderVAE,
88
+ VQModel,
89
+ )
90
+ from .controlnet import ControlNetModel
91
+ from .controlnet_flux import FluxControlNetModel, FluxMultiControlNetModel
92
+ from .controlnet_hunyuan import HunyuanDiT2DControlNetModel, HunyuanDiT2DMultiControlNetModel
93
+ from .controlnet_sd3 import SD3ControlNetModel, SD3MultiControlNetModel
94
+ from .controlnet_sparsectrl import SparseControlNetModel
95
+ from .controlnet_xs import ControlNetXSAdapter, UNetControlNetXSModel
96
+ from .embeddings import ImageProjection
97
+ from .modeling_utils import ModelMixin
98
+ from .transformers import (
99
+ AuraFlowTransformer2DModel,
100
+ CogVideoXTransformer3DModel,
101
+ DiTTransformer2DModel,
102
+ DualTransformer2DModel,
103
+ FluxTransformer2DModel,
104
+ HunyuanDiT2DModel,
105
+ LatteTransformer3DModel,
106
+ LuminaNextDiT2DModel,
107
+ PixArtTransformer2DModel,
108
+ PriorTransformer,
109
+ SD3Transformer2DModel,
110
+ StableAudioDiTModel,
111
+ T5FilmDecoder,
112
+ Transformer2DModel,
113
+ TransformerTemporalModel,
114
+ )
115
+ from .unets import (
116
+ I2VGenXLUNet,
117
+ Kandinsky3UNet,
118
+ MotionAdapter,
119
+ StableCascadeUNet,
120
+ UNet1DModel,
121
+ UNet2DConditionModel,
122
+ UNet2DModel,
123
+ UNet3DConditionModel,
124
+ UNetMotionModel,
125
+ UNetSpatioTemporalConditionModel,
126
+ UVit2DModel,
127
+ )
128
+
129
+ if is_flax_available():
130
+ from .controlnet_flax import FlaxControlNetModel
131
+ from .unets import FlaxUNet2DConditionModel
132
+ from .vae_flax import FlaxAutoencoderKL
133
+
134
+ else:
135
+ import sys
136
+
137
+ sys.modules[__name__] = _LazyModule(__name__, globals()["__file__"], _import_structure, module_spec=__spec__)
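The `_LazyModule` registration above keeps `import diffusers3.models` cheap: a submodule is only imported when one of the names listed in `_import_structure` is first accessed. A minimal sketch of the underlying pattern (not the actual `_LazyModule` implementation):

import importlib
import types

class LazyModule(types.ModuleType):
    def __init__(self, name: str, import_structure: dict):
        super().__init__(name)
        # Map each exported name to the submodule that defines it,
        # e.g. "ControlNetModel" -> "controlnet".
        self._name_to_module = {
            attr: module for module, attrs in import_structure.items() for attr in attrs
        }

    def __getattr__(self, attr: str):
        if attr not in self._name_to_module:
            raise AttributeError(attr)
        module = importlib.import_module(f"{self.__name__}.{self._name_to_module[attr]}")
        value = getattr(module, attr)
        setattr(self, attr, value)  # cache so later lookups skip __getattr__
        return value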
diffusers3/models/__pycache__/__init__.cpython-310.pyc ADDED
Binary file (4.34 kB)

diffusers3/models/__pycache__/__init__.cpython-38.pyc ADDED
Binary file (4.34 kB)

diffusers3/models/__pycache__/activations.cpython-310.pyc ADDED
Binary file (6.42 kB)

diffusers3/models/__pycache__/activations.cpython-38.pyc ADDED
Binary file (6.38 kB)

diffusers3/models/__pycache__/attention.cpython-310.pyc ADDED
Binary file (29.6 kB)

diffusers3/models/__pycache__/attention.cpython-38.pyc ADDED
Binary file (29.3 kB)

diffusers3/models/__pycache__/attention_processor.cpython-310.pyc ADDED
Binary file (88.4 kB)

diffusers3/models/__pycache__/attention_processor.cpython-38.pyc ADDED
Binary file (89.7 kB)

diffusers3/models/__pycache__/controlnet.cpython-310.pyc ADDED
Binary file (27.9 kB)

diffusers3/models/__pycache__/controlnet.cpython-38.pyc ADDED
Binary file (27.4 kB)

diffusers3/models/__pycache__/downsampling.cpython-310.pyc ADDED
Binary file (12.5 kB)

diffusers3/models/__pycache__/downsampling.cpython-38.pyc ADDED
Binary file (12.4 kB)

diffusers3/models/__pycache__/embeddings.cpython-310.pyc ADDED
Binary file (52.8 kB)

diffusers3/models/__pycache__/embeddings.cpython-38.pyc ADDED
Binary file (52.5 kB)

diffusers3/models/__pycache__/lora.cpython-310.pyc ADDED
Binary file (13.2 kB)

diffusers3/models/__pycache__/lora.cpython-38.pyc ADDED
Binary file (13.3 kB)

diffusers3/models/__pycache__/model_loading_utils.cpython-310.pyc ADDED
Binary file (5.68 kB)

diffusers3/models/__pycache__/model_loading_utils.cpython-38.pyc ADDED
Binary file (5.66 kB)

diffusers3/models/__pycache__/modeling_outputs.cpython-310.pyc ADDED
Binary file (1.41 kB)

diffusers3/models/__pycache__/modeling_outputs.cpython-38.pyc ADDED
Binary file (1.44 kB)

diffusers3/models/__pycache__/modeling_utils.cpython-310.pyc ADDED
Binary file (39.5 kB)

diffusers3/models/__pycache__/modeling_utils.cpython-38.pyc ADDED
Binary file (39.6 kB)

diffusers3/models/__pycache__/normalization.cpython-310.pyc ADDED
Binary file (14.3 kB)

diffusers3/models/__pycache__/normalization.cpython-38.pyc ADDED
Binary file (14 kB)

diffusers3/models/__pycache__/resnet.cpython-310.pyc ADDED
Binary file (23 kB)

diffusers3/models/__pycache__/resnet.cpython-38.pyc ADDED
Binary file (22.7 kB)

diffusers3/models/__pycache__/upsampling.cpython-310.pyc ADDED
Binary file (14.5 kB)

diffusers3/models/__pycache__/upsampling.cpython-38.pyc ADDED
Binary file (14.3 kB)
diffusers3/models/activations.py ADDED
@@ -0,0 +1,165 @@
+# coding=utf-8
+# Copyright 2024 HuggingFace Inc.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import torch
+import torch.nn.functional as F
+from torch import nn
+
+from ..utils import deprecate
+from ..utils.import_utils import is_torch_npu_available
+
+
+if is_torch_npu_available():
+    import torch_npu
+
+ACTIVATION_FUNCTIONS = {
+    "swish": nn.SiLU(),
+    "silu": nn.SiLU(),
+    "mish": nn.Mish(),
+    "gelu": nn.GELU(),
+    "relu": nn.ReLU(),
+}
+
+
+def get_activation(act_fn: str) -> nn.Module:
+    """Helper function to get activation function from string.
+
+    Args:
+        act_fn (str): Name of activation function.
+
+    Returns:
+        nn.Module: Activation function.
+    """
+
+    act_fn = act_fn.lower()
+    if act_fn in ACTIVATION_FUNCTIONS:
+        return ACTIVATION_FUNCTIONS[act_fn]
+    else:
+        raise ValueError(f"Unsupported activation function: {act_fn}")
+
+
+class FP32SiLU(nn.Module):
+    r"""
+    SiLU activation function with input upcasted to torch.float32.
+    """
+
+    def __init__(self):
+        super().__init__()
+
+    def forward(self, inputs: torch.Tensor) -> torch.Tensor:
+        return F.silu(inputs.float(), inplace=False).to(inputs.dtype)
+
+
+class GELU(nn.Module):
+    r"""
+    GELU activation function with tanh approximation support with `approximate="tanh"`.
+
+    Parameters:
+        dim_in (`int`): The number of channels in the input.
+        dim_out (`int`): The number of channels in the output.
+        approximate (`str`, *optional*, defaults to `"none"`): If `"tanh"`, use tanh approximation.
+        bias (`bool`, defaults to True): Whether to use a bias in the linear layer.
+    """
+
+    def __init__(self, dim_in: int, dim_out: int, approximate: str = "none", bias: bool = True):
+        super().__init__()
+        self.proj = nn.Linear(dim_in, dim_out, bias=bias)
+        self.approximate = approximate
+
+    def gelu(self, gate: torch.Tensor) -> torch.Tensor:
+        if gate.device.type != "mps":
+            return F.gelu(gate, approximate=self.approximate)
+        # mps: gelu is not implemented for float16
+        return F.gelu(gate.to(dtype=torch.float32), approximate=self.approximate).to(dtype=gate.dtype)
+
+    def forward(self, hidden_states):
+        hidden_states = self.proj(hidden_states)
+        hidden_states = self.gelu(hidden_states)
+        return hidden_states
+
+
+class GEGLU(nn.Module):
+    r"""
+    A [variant](https://arxiv.org/abs/2002.05202) of the gated linear unit activation function.
+
+    Parameters:
+        dim_in (`int`): The number of channels in the input.
+        dim_out (`int`): The number of channels in the output.
+        bias (`bool`, defaults to True): Whether to use a bias in the linear layer.
+    """
+
+    def __init__(self, dim_in: int, dim_out: int, bias: bool = True):
+        super().__init__()
+        self.proj = nn.Linear(dim_in, dim_out * 2, bias=bias)
+
+    def gelu(self, gate: torch.Tensor) -> torch.Tensor:
+        if gate.device.type != "mps":
+            return F.gelu(gate)
+        # mps: gelu is not implemented for float16
+        return F.gelu(gate.to(dtype=torch.float32)).to(dtype=gate.dtype)
+
+    def forward(self, hidden_states, *args, **kwargs):
+        if len(args) > 0 or kwargs.get("scale", None) is not None:
+            deprecation_message = "The `scale` argument is deprecated and will be ignored. Please remove it, as passing it will raise an error in the future. `scale` should directly be passed while calling the underlying pipeline component i.e., via `cross_attention_kwargs`."
+            deprecate("scale", "1.0.0", deprecation_message)
+        hidden_states = self.proj(hidden_states)
+        if is_torch_npu_available():
+            # using torch_npu.npu_geglu can run faster and save memory on NPU.
+            return torch_npu.npu_geglu(hidden_states, dim=-1, approximate=1)[0]
+        else:
+            hidden_states, gate = hidden_states.chunk(2, dim=-1)
+            return hidden_states * self.gelu(gate)
+
+
+class SwiGLU(nn.Module):
+    r"""
+    A [variant](https://arxiv.org/abs/2002.05202) of the gated linear unit activation function. It's similar to `GEGLU`
+    but uses SiLU / Swish instead of GeLU.
+
+    Parameters:
+        dim_in (`int`): The number of channels in the input.
+        dim_out (`int`): The number of channels in the output.
+        bias (`bool`, defaults to True): Whether to use a bias in the linear layer.
+    """
+
+    def __init__(self, dim_in: int, dim_out: int, bias: bool = True):
+        super().__init__()
+        self.proj = nn.Linear(dim_in, dim_out * 2, bias=bias)
+        self.activation = nn.SiLU()
+
+    def forward(self, hidden_states):
+        hidden_states = self.proj(hidden_states)
+        hidden_states, gate = hidden_states.chunk(2, dim=-1)
+        return hidden_states * self.activation(gate)
+
+
+class ApproximateGELU(nn.Module):
+    r"""
+    The approximate form of the Gaussian Error Linear Unit (GELU). For more details, see section 2 of this
+    [paper](https://arxiv.org/abs/1606.08415).
+
+    Parameters:
+        dim_in (`int`): The number of channels in the input.
+        dim_out (`int`): The number of channels in the output.
+        bias (`bool`, defaults to True): Whether to use a bias in the linear layer.
+    """
+
+    def __init__(self, dim_in: int, dim_out: int, bias: bool = True):
+        super().__init__()
+        self.proj = nn.Linear(dim_in, dim_out, bias=bias)
+
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
+        x = self.proj(x)
+        return x * torch.sigmoid(1.702 * x)
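A short usage sketch of the helpers defined in this file (the tensor shapes are illustrative assumptions, not part of the source):

    import torch

    act = get_activation("silu")             # returns the shared nn.SiLU() instance
    x = torch.randn(2, 16, 64)               # (batch, seq, dim), illustrative only
    geglu = GEGLU(dim_in=64, dim_out=128)    # projects to 2 * dim_out, gates one half with GELU
    y = geglu(x)                             # -> shape (2, 16, 128)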
diffusers3/models/adapter.py ADDED
@@ -0,0 +1,584 @@
+# Copyright 2022 The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import os
+from typing import Callable, List, Optional, Union
+
+import torch
+import torch.nn as nn
+
+from ..configuration_utils import ConfigMixin, register_to_config
+from ..utils import logging
+from .modeling_utils import ModelMixin
+
+
+logger = logging.get_logger(__name__)
+
+
+class MultiAdapter(ModelMixin):
+    r"""
+    MultiAdapter is a wrapper model that contains multiple adapter models and merges their outputs according to
+    user-assigned weighting.
+
+    This model inherits from [`ModelMixin`]. Check the superclass documentation for the generic methods the library
+    implements for all models (such as downloading or saving).
+
+    Parameters:
+        adapters (`List[T2IAdapter]`, *optional*, defaults to None):
+            A list of `T2IAdapter` model instances.
+    """
+
+    def __init__(self, adapters: List["T2IAdapter"]):
+        super(MultiAdapter, self).__init__()
+
+        self.num_adapter = len(adapters)
+        self.adapters = nn.ModuleList(adapters)
+
+        if len(adapters) == 0:
+            raise ValueError("Expecting at least one adapter")
+
+        if len(adapters) == 1:
+            raise ValueError("For a single adapter, please use the `T2IAdapter` class instead of `MultiAdapter`")
+
+        # The outputs from each adapter are added together with a weight.
+        # This means that the change in dimensions from downsampling must
+        # be the same for all adapters. Inductively, it also means the
+        # downscale_factor and total_downscale_factor must be the same for all
+        # adapters.
+        first_adapter_total_downscale_factor = adapters[0].total_downscale_factor
+        first_adapter_downscale_factor = adapters[0].downscale_factor
+        for idx in range(1, len(adapters)):
+            if (
+                adapters[idx].total_downscale_factor != first_adapter_total_downscale_factor
+                or adapters[idx].downscale_factor != first_adapter_downscale_factor
+            ):
+                raise ValueError(
+                    f"Expecting all adapters to have the same downscaling behavior, but got:\n"
+                    f"adapters[0].total_downscale_factor={first_adapter_total_downscale_factor}\n"
+                    f"adapters[0].downscale_factor={first_adapter_downscale_factor}\n"
+                    f"adapter[`{idx}`].total_downscale_factor={adapters[idx].total_downscale_factor}\n"
+                    f"adapter[`{idx}`].downscale_factor={adapters[idx].downscale_factor}"
+                )
+
+        self.total_downscale_factor = first_adapter_total_downscale_factor
+        self.downscale_factor = first_adapter_downscale_factor
+
+    def forward(self, xs: torch.Tensor, adapter_weights: Optional[List[float]] = None) -> List[torch.Tensor]:
+        r"""
+        Args:
+            xs (`torch.Tensor`):
+                A `(batch, channel, height, width)` tensor of input images for the adapter models, concatenated along
+                dimension 1; `channel` should equal `num_adapter` times the number of channels per image.
+            adapter_weights (`List[float]`, *optional*, defaults to None):
+                A list of floats giving the weight by which each adapter's output is multiplied before the outputs
+                are summed together.
+        """
+        if adapter_weights is None:
+            adapter_weights = torch.tensor([1 / self.num_adapter] * self.num_adapter)
+        else:
+            adapter_weights = torch.tensor(adapter_weights)
+
+        accume_state = None
+        for x, w, adapter in zip(xs, adapter_weights, self.adapters):
+            features = adapter(x)
+            if accume_state is None:
+                accume_state = features
+                for i in range(len(accume_state)):
+                    accume_state[i] = w * accume_state[i]
+            else:
+                for i in range(len(features)):
+                    accume_state[i] += w * features[i]
+        return accume_state
+
+    def save_pretrained(
+        self,
+        save_directory: Union[str, os.PathLike],
+        is_main_process: bool = True,
+        save_function: Callable = None,
+        safe_serialization: bool = True,
+        variant: Optional[str] = None,
+    ):
+        """
+        Save a model and its configuration file to a directory, so that it can be re-loaded using the
+        [`~models.adapter.MultiAdapter.from_pretrained`] class method.
+
+        Arguments:
+            save_directory (`str` or `os.PathLike`):
+                Directory to which to save. Will be created if it doesn't exist.
+            is_main_process (`bool`, *optional*, defaults to `True`):
+                Whether the process calling this is the main process or not. Useful during distributed training (e.g.
+                on TPUs), when this function needs to be called on all processes. In that case, set
+                `is_main_process=True` only on the main process to avoid race conditions.
+            save_function (`Callable`):
+                The function to use to save the state dictionary. Useful during distributed training (e.g. on TPUs),
+                when one needs to replace `torch.save` with another method. Can be configured with the environment
+                variable `DIFFUSERS_SAVE_MODE`.
+            safe_serialization (`bool`, *optional*, defaults to `True`):
+                Whether to save the model using `safetensors` or the traditional PyTorch way (that uses `pickle`).
+            variant (`str`, *optional*):
+                If specified, weights are saved in the format pytorch_model.<variant>.bin.
+        """
+        idx = 0
+        model_path_to_save = save_directory
+        for adapter in self.adapters:
+            adapter.save_pretrained(
+                model_path_to_save,
+                is_main_process=is_main_process,
+                save_function=save_function,
+                safe_serialization=safe_serialization,
+                variant=variant,
+            )
+
+            idx += 1
+            # save to `save_directory`, `save_directory_1`, `save_directory_2`, ... so that
+            # `from_pretrained` below, which probes exactly those suffixes, can find every adapter.
+            model_path_to_save = save_directory + f"_{idx}"
+
+    @classmethod
+    def from_pretrained(cls, pretrained_model_path: Optional[Union[str, os.PathLike]], **kwargs):
+        r"""
+        Instantiate a pretrained `MultiAdapter` model from multiple pre-trained adapter models.
+
+        The model is set in evaluation mode by default using `model.eval()` (Dropout modules are deactivated). To
+        train the model, you should first set it back in training mode with `model.train()`.
+
+        The warning *Weights from XXX not initialized from pretrained model* means that the weights of XXX do not
+        come pretrained with the rest of the model. It is up to you to train those weights with a downstream
+        fine-tuning task.
+
+        The warning *Weights from XXX not used in YYY* means that the layer XXX is not used by YYY, therefore those
+        weights are discarded.
+
+        Parameters:
+            pretrained_model_path (`os.PathLike`):
+                A path to a *directory* containing model weights saved using
+                [`~diffusers.models.adapter.MultiAdapter.save_pretrained`], e.g., `./my_model_directory/adapter`.
+            torch_dtype (`str` or `torch.dtype`, *optional*):
+                Override the default `torch.dtype` and load the model under this dtype. If `"auto"` is passed the
+                dtype will be automatically derived from the model's weights.
+            output_loading_info (`bool`, *optional*, defaults to `False`):
+                Whether or not to also return a dictionary containing missing keys, unexpected keys and error
+                messages.
+            device_map (`str` or `Dict[str, Union[int, str, torch.device]]`, *optional*):
+                A map that specifies where each submodule should go. It doesn't need to be refined to each
+                parameter/buffer name; once a given module name is inside, every submodule of it will be sent to the
+                same device.
+
+                To have Accelerate compute the most optimized `device_map` automatically, set `device_map="auto"`.
+                For more information about each option see [designing a device
+                map](https://hf.co/docs/accelerate/main/en/usage_guides/big_modeling#designing-a-device-map).
+            max_memory (`Dict`, *optional*):
+                A dictionary mapping device identifiers to their maximum memory. Will default to the maximum memory
+                available for each GPU and the available CPU RAM if unset.
+            low_cpu_mem_usage (`bool`, *optional*, defaults to `True` if torch version >= 1.9.0 else `False`):
+                Speed up model loading by not initializing the weights and only loading the pre-trained weights. This
+                also tries to not use more than 1x model size in CPU memory (including peak memory) while loading the
+                model. This is only supported when torch version >= 1.9.0. If you are using an older version of
+                torch, setting this argument to `True` will raise an error.
+            variant (`str`, *optional*):
+                If specified, load weights from a `variant` filename, *e.g.* pytorch_model.<variant>.bin. `variant`
+                is ignored when using `from_flax`.
+            use_safetensors (`bool`, *optional*, defaults to `None`):
+                If set to `None`, the `safetensors` weights will be downloaded if they're available **and** if the
+                `safetensors` library is installed. If set to `True`, the model will be forcibly loaded from
+                `safetensors` weights. If set to `False`, loading will *not* use `safetensors`.
+        """
+        idx = 0
+        adapters = []
+
+        # load adapter and append to list until no adapter directory exists anymore
+        # first adapter has to be saved under `./mydirectory/adapter` to be compliant with `DiffusionPipeline.from_pretrained`
+        # second, third, ... adapters have to be saved under `./mydirectory/adapter_1`, `./mydirectory/adapter_2`, ...
+        model_path_to_load = pretrained_model_path
+        while os.path.isdir(model_path_to_load):
+            adapter = T2IAdapter.from_pretrained(model_path_to_load, **kwargs)
+            adapters.append(adapter)
+
+            idx += 1
+            model_path_to_load = pretrained_model_path + f"_{idx}"
+
+        logger.info(f"{len(adapters)} adapters loaded from {pretrained_model_path}.")
+
+        if len(adapters) == 0:
+            raise ValueError(
+                f"No T2IAdapters found under {os.path.dirname(pretrained_model_path)}. Expected at least {pretrained_model_path + '_0'}."
+            )
+
+        return cls(adapters)
+
+
+class T2IAdapter(ModelMixin, ConfigMixin):
+    r"""
+    A simple ResNet-like model that accepts images containing control signals such as keyposes and depth. The model
+    generates multiple feature maps that are used as additional conditioning in [`UNet2DConditionModel`]. The model's
+    architecture follows the original implementation of
+    [Adapter](https://github.com/TencentARC/T2I-Adapter/blob/686de4681515662c0ac2ffa07bf5dda83af1038a/ldm/modules/encoders/adapter.py#L97)
+    and
+    [AdapterLight](https://github.com/TencentARC/T2I-Adapter/blob/686de4681515662c0ac2ffa07bf5dda83af1038a/ldm/modules/encoders/adapter.py#L235).
+
+    This model inherits from [`ModelMixin`]. Check the superclass documentation for the generic methods the library
+    implements for all models (such as downloading or saving).
+
+    Parameters:
+        in_channels (`int`, *optional*, defaults to 3):
+            The number of channels in the adapter's input (the *control image*). Set this parameter to 1 if you're
+            using a grayscale image as the *control image*.
+        channels (`List[int]`, *optional*, defaults to `(320, 640, 1280, 1280)`):
+            The number of channels in each downsample block's output hidden state. `len(channels)` also determines
+            the number of downsample blocks in the adapter.
+        num_res_blocks (`int`, *optional*, defaults to 2):
+            Number of ResNet blocks in each downsample block.
+        downscale_factor (`int`, *optional*, defaults to 8):
+            A factor that determines the total downscale factor of the adapter.
+        adapter_type (`str`, *optional*, defaults to `full_adapter`):
+            The type of adapter to use. Choose either `full_adapter`, `full_adapter_xl`, or `light_adapter`.
+    """
+
+    @register_to_config
+    def __init__(
+        self,
+        in_channels: int = 3,
+        channels: List[int] = [320, 640, 1280, 1280],
+        num_res_blocks: int = 2,
+        downscale_factor: int = 8,
+        adapter_type: str = "full_adapter",
+    ):
+        super().__init__()
+
+        if adapter_type == "full_adapter":
+            self.adapter = FullAdapter(in_channels, channels, num_res_blocks, downscale_factor)
+        elif adapter_type == "full_adapter_xl":
+            self.adapter = FullAdapterXL(in_channels, channels, num_res_blocks, downscale_factor)
+        elif adapter_type == "light_adapter":
+            self.adapter = LightAdapter(in_channels, channels, num_res_blocks, downscale_factor)
+        else:
+            raise ValueError(
+                f"Unsupported adapter_type: '{adapter_type}'. Choose either 'full_adapter' or "
+                "'full_adapter_xl' or 'light_adapter'."
+            )
+
+    def forward(self, x: torch.Tensor) -> List[torch.Tensor]:
+        r"""
+        This function processes the input tensor `x` through the adapter model and returns a list of feature tensors,
+        each representing information extracted at a different scale from the input. The length of the list is
+        determined by the number of downsample blocks in the adapter, as specified by the `channels` and
+        `num_res_blocks` parameters during initialization.
+        """
+        return self.adapter(x)
+
+    @property
+    def total_downscale_factor(self):
+        return self.adapter.total_downscale_factor
+
+    @property
+    def downscale_factor(self):
+        """The downscale factor applied in the T2I-Adapter's initial pixel unshuffle operation. If an input image's
+        dimensions are not evenly divisible by the downscale_factor then an exception will be raised.
+        """
+        return self.adapter.unshuffle.downscale_factor
+
+
+# full adapter
+
+
+class FullAdapter(nn.Module):
+    r"""
+    See [`T2IAdapter`] for more information.
+    """
+
+    def __init__(
+        self,
+        in_channels: int = 3,
+        channels: List[int] = [320, 640, 1280, 1280],
+        num_res_blocks: int = 2,
+        downscale_factor: int = 8,
+    ):
+        super().__init__()
+
+        in_channels = in_channels * downscale_factor**2
+
+        self.unshuffle = nn.PixelUnshuffle(downscale_factor)
+        self.conv_in = nn.Conv2d(in_channels, channels[0], kernel_size=3, padding=1)
+
+        self.body = nn.ModuleList(
+            [
+                AdapterBlock(channels[0], channels[0], num_res_blocks),
+                *[
+                    AdapterBlock(channels[i - 1], channels[i], num_res_blocks, down=True)
+                    for i in range(1, len(channels))
+                ],
+            ]
+        )
+
+        self.total_downscale_factor = downscale_factor * 2 ** (len(channels) - 1)
+
+    def forward(self, x: torch.Tensor) -> List[torch.Tensor]:
+        r"""
+        This method processes the input tensor `x` through the FullAdapter model, performing pixel unshuffling, a
+        convolution, and a stack of AdapterBlocks. It returns a list of feature tensors, each capturing information
+        at a different stage of processing within the FullAdapter model. The number of feature tensors in the list is
+        determined by the number of downsample blocks specified during initialization.
+        """
+        x = self.unshuffle(x)
+        x = self.conv_in(x)
+
+        features = []
+
+        for block in self.body:
+            x = block(x)
+            features.append(x)
+
+        return features
+
+
+class FullAdapterXL(nn.Module):
+    r"""
+    See [`T2IAdapter`] for more information.
+    """
+
+    def __init__(
+        self,
+        in_channels: int = 3,
+        channels: List[int] = [320, 640, 1280, 1280],
+        num_res_blocks: int = 2,
+        downscale_factor: int = 16,
+    ):
+        super().__init__()
+
+        in_channels = in_channels * downscale_factor**2
+
+        self.unshuffle = nn.PixelUnshuffle(downscale_factor)
+        self.conv_in = nn.Conv2d(in_channels, channels[0], kernel_size=3, padding=1)
+
+        self.body = []
+        # blocks to extract XL features with dimensions of [320, 64, 64], [640, 64, 64], [1280, 32, 32], [1280, 32, 32]
+        for i in range(len(channels)):
+            if i == 1:
+                self.body.append(AdapterBlock(channels[i - 1], channels[i], num_res_blocks))
+            elif i == 2:
+                self.body.append(AdapterBlock(channels[i - 1], channels[i], num_res_blocks, down=True))
+            else:
+                self.body.append(AdapterBlock(channels[i], channels[i], num_res_blocks))
+
+        self.body = nn.ModuleList(self.body)
+        # XL has only one downsampling AdapterBlock.
+        self.total_downscale_factor = downscale_factor * 2
+
+    def forward(self, x: torch.Tensor) -> List[torch.Tensor]:
+        r"""
+        This method processes the input tensor `x` through the FullAdapterXL model: it unshuffles pixels, applies the
+        input convolution, and runs each block in turn, appending every block's output to a list of feature tensors.
+        """
+        x = self.unshuffle(x)
+        x = self.conv_in(x)
+
+        features = []
+
+        for block in self.body:
+            x = block(x)
+            features.append(x)
+
+        return features
+
+
+class AdapterBlock(nn.Module):
+    r"""
+    An AdapterBlock is a helper model that contains multiple ResNet-like blocks. It is used in the `FullAdapter` and
+    `FullAdapterXL` models.
+
+    Parameters:
+        in_channels (`int`):
+            Number of channels of AdapterBlock's input.
+        out_channels (`int`):
+            Number of channels of AdapterBlock's output.
+        num_res_blocks (`int`):
+            Number of ResNet blocks in the AdapterBlock.
+        down (`bool`, *optional*, defaults to `False`):
+            Whether to perform downsampling on AdapterBlock's input.
+    """
+
+    def __init__(self, in_channels: int, out_channels: int, num_res_blocks: int, down: bool = False):
+        super().__init__()
+
+        self.downsample = None
+        if down:
+            self.downsample = nn.AvgPool2d(kernel_size=2, stride=2, ceil_mode=True)
+
+        self.in_conv = None
+        if in_channels != out_channels:
+            self.in_conv = nn.Conv2d(in_channels, out_channels, kernel_size=1)
+
+        self.resnets = nn.Sequential(
+            *[AdapterResnetBlock(out_channels) for _ in range(num_res_blocks)],
+        )
+
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
+        r"""
+        This method takes tensor `x` as input, applies the optional downsampling and channel-matching convolution
+        (when `self.downsample` and `self.in_conv` are set), and then runs a series of residual blocks on the result.
+        """
+        if self.downsample is not None:
+            x = self.downsample(x)
+
+        if self.in_conv is not None:
+            x = self.in_conv(x)
+
+        x = self.resnets(x)
+
+        return x
+
+
+class AdapterResnetBlock(nn.Module):
+    r"""
+    An `AdapterResnetBlock` is a helper model that implements a ResNet-like block.
+
+    Parameters:
+        channels (`int`):
+            Number of channels of AdapterResnetBlock's input and output.
+    """
+
+    def __init__(self, channels: int):
+        super().__init__()
+        self.block1 = nn.Conv2d(channels, channels, kernel_size=3, padding=1)
+        self.act = nn.ReLU()
+        self.block2 = nn.Conv2d(channels, channels, kernel_size=1)
+
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
+        r"""
+        This method applies a convolutional layer, a ReLU activation, and another convolutional layer to the input
+        tensor, and returns the result added to the input (a residual connection).
+        """
+
+        h = self.act(self.block1(x))
+        h = self.block2(h)
+
+        return h + x
+
+
+# light adapter
+
+
+class LightAdapter(nn.Module):
+    r"""
+    See [`T2IAdapter`] for more information.
+    """
+
+    def __init__(
+        self,
+        in_channels: int = 3,
+        channels: List[int] = [320, 640, 1280],
+        num_res_blocks: int = 4,
+        downscale_factor: int = 8,
+    ):
+        super().__init__()
+
+        in_channels = in_channels * downscale_factor**2
+
+        self.unshuffle = nn.PixelUnshuffle(downscale_factor)
+
+        self.body = nn.ModuleList(
+            [
+                LightAdapterBlock(in_channels, channels[0], num_res_blocks),
+                *[
+                    LightAdapterBlock(channels[i], channels[i + 1], num_res_blocks, down=True)
+                    for i in range(len(channels) - 1)
+                ],
+                LightAdapterBlock(channels[-1], channels[-1], num_res_blocks, down=True),
+            ]
+        )
+
+        self.total_downscale_factor = downscale_factor * (2 ** len(channels))
+
+    def forward(self, x: torch.Tensor) -> List[torch.Tensor]:
+        r"""
+        This method unshuffles the input tensor `x`, passes it through each block in turn, and appends every block's
+        output to a list of feature tensors. Each feature tensor corresponds to a different level of processing
+        within the LightAdapter.
+        """
+        x = self.unshuffle(x)
+
+        features = []
+
+        for block in self.body:
+            x = block(x)
+            features.append(x)
+
+        return features
+
+
+class LightAdapterBlock(nn.Module):
+    r"""
+    A `LightAdapterBlock` is a helper model that contains multiple `LightAdapterResnetBlocks`. It is used in the
+    `LightAdapter` model.
+
+    Parameters:
+        in_channels (`int`):
+            Number of channels of LightAdapterBlock's input.
+        out_channels (`int`):
+            Number of channels of LightAdapterBlock's output.
+        num_res_blocks (`int`):
+            Number of LightAdapterResnetBlocks in the LightAdapterBlock.
+        down (`bool`, *optional*, defaults to `False`):
+            Whether to perform downsampling on LightAdapterBlock's input.
+    """
+
+    def __init__(self, in_channels: int, out_channels: int, num_res_blocks: int, down: bool = False):
+        super().__init__()
+        mid_channels = out_channels // 4
+
+        self.downsample = None
+        if down:
+            self.downsample = nn.AvgPool2d(kernel_size=2, stride=2, ceil_mode=True)
+
+        self.in_conv = nn.Conv2d(in_channels, mid_channels, kernel_size=1)
+        self.resnets = nn.Sequential(*[LightAdapterResnetBlock(mid_channels) for _ in range(num_res_blocks)])
+        self.out_conv = nn.Conv2d(mid_channels, out_channels, kernel_size=1)
+
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
+        r"""
+        This method downsamples the input tensor `x` if required, then applies the input convolution, a sequence of
+        residual blocks, and the output convolution.
+        """
+        if self.downsample is not None:
+            x = self.downsample(x)
+
+        x = self.in_conv(x)
+        x = self.resnets(x)
+        x = self.out_conv(x)
+
+        return x
+
+
+class LightAdapterResnetBlock(nn.Module):
+    """
+    A `LightAdapterResnetBlock` is a helper model that implements a ResNet-like block with a slightly different
+    architecture than `AdapterResnetBlock`.
+
+    Parameters:
+        channels (`int`):
+            Number of channels of LightAdapterResnetBlock's input and output.
+    """
+
+    def __init__(self, channels: int):
+        super().__init__()
+        self.block1 = nn.Conv2d(channels, channels, kernel_size=3, padding=1)
+        self.act = nn.ReLU()
+        self.block2 = nn.Conv2d(channels, channels, kernel_size=3, padding=1)
+
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
+        r"""
+        This function applies a convolutional layer, a ReLU activation, and another convolutional layer to the input
+        tensor, and adds the result to the input tensor (a residual connection).
+        """
+
+        h = self.act(self.block1(x))
+        h = self.block2(h)
+
+        return h + x
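To make the feature-pyramid contract above concrete, here is a hedged smoke-test sketch (randomly initialized weights and an assumed 512x512 input; not part of the commit):

    import torch

    # PixelUnshuffle(8) maps 512x512 -> 64x64; the three down=True AdapterBlocks
    # then halve the resolution, so total_downscale_factor = 8 * 2**3 = 64.
    adapter = T2IAdapter(in_channels=3, channels=[320, 640, 1280, 1280], num_res_blocks=2)
    control = torch.randn(1, 3, 512, 512)    # dummy keypose/depth-style control image
    features = adapter(control)              # one feature map per AdapterBlock
    # shapes: (1, 320, 64, 64), (1, 640, 32, 32), (1, 1280, 16, 16), (1, 1280, 8, 8)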
diffusers3/models/attention.py ADDED
@@ -0,0 +1,1202 @@
1
+ # Copyright 2024 The HuggingFace Team. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ from typing import Any, Dict, List, Optional, Tuple
15
+
16
+ import torch
17
+ import torch.nn.functional as F
18
+ from torch import nn
19
+
20
+ from ..utils import deprecate, logging
21
+ from ..utils.torch_utils import maybe_allow_in_graph
22
+ from .activations import GEGLU, GELU, ApproximateGELU, FP32SiLU, SwiGLU
23
+ from .attention_processor import Attention, JointAttnProcessor2_0
24
+ from .embeddings import SinusoidalPositionalEmbedding
25
+ from .normalization import AdaLayerNorm, AdaLayerNormContinuous, AdaLayerNormZero, RMSNorm
26
+
27
+
28
+ logger = logging.get_logger(__name__)
29
+
30
+
31
+ def _chunked_feed_forward(ff: nn.Module, hidden_states: torch.Tensor, chunk_dim: int, chunk_size: int):
32
+ # "feed_forward_chunk_size" can be used to save memory
33
+ if hidden_states.shape[chunk_dim] % chunk_size != 0:
34
+ raise ValueError(
35
+ f"`hidden_states` dimension to be chunked: {hidden_states.shape[chunk_dim]} has to be divisible by chunk size: {chunk_size}. Make sure to set an appropriate `chunk_size` when calling `unet.enable_forward_chunking`."
36
+ )
37
+
38
+ num_chunks = hidden_states.shape[chunk_dim] // chunk_size
39
+ ff_output = torch.cat(
40
+ [ff(hid_slice) for hid_slice in hidden_states.chunk(num_chunks, dim=chunk_dim)],
41
+ dim=chunk_dim,
42
+ )
43
+ return ff_output
44
+
45
+
46
+ @maybe_allow_in_graph
47
+ class GatedSelfAttentionDense(nn.Module):
48
+ r"""
49
+ A gated self-attention dense layer that combines visual features and object features.
50
+
51
+ Parameters:
52
+ query_dim (`int`): The number of channels in the query.
53
+ context_dim (`int`): The number of channels in the context.
54
+ n_heads (`int`): The number of heads to use for attention.
55
+ d_head (`int`): The number of channels in each head.
56
+ """
57
+
58
+ def __init__(self, query_dim: int, context_dim: int, n_heads: int, d_head: int):
59
+ super().__init__()
60
+
61
+ # we need a linear projection since we need cat visual feature and obj feature
62
+ self.linear = nn.Linear(context_dim, query_dim)
63
+
64
+ self.attn = Attention(query_dim=query_dim, heads=n_heads, dim_head=d_head)
65
+ self.ff = FeedForward(query_dim, activation_fn="geglu")
66
+
67
+ self.norm1 = nn.LayerNorm(query_dim)
68
+ self.norm2 = nn.LayerNorm(query_dim)
69
+
70
+ self.register_parameter("alpha_attn", nn.Parameter(torch.tensor(0.0)))
71
+ self.register_parameter("alpha_dense", nn.Parameter(torch.tensor(0.0)))
72
+
73
+ self.enabled = True
74
+
75
+ def forward(self, x: torch.Tensor, objs: torch.Tensor) -> torch.Tensor:
76
+ if not self.enabled:
77
+ return x
78
+
79
+ n_visual = x.shape[1]
80
+ objs = self.linear(objs)
81
+
82
+ x = x + self.alpha_attn.tanh() * self.attn(self.norm1(torch.cat([x, objs], dim=1)))[:, :n_visual, :]
83
+ x = x + self.alpha_dense.tanh() * self.ff(self.norm2(x))
84
+
85
+ return x
86
+
87
+
88
+ @maybe_allow_in_graph
89
+ class JointTransformerBlock(nn.Module):
90
+ r"""
91
+ A Transformer block following the MMDiT architecture, introduced in Stable Diffusion 3.
92
+
93
+ Reference: https://arxiv.org/abs/2403.03206
94
+
95
+ Parameters:
96
+ dim (`int`): The number of channels in the input and output.
97
+ num_attention_heads (`int`): The number of heads to use for multi-head attention.
98
+ attention_head_dim (`int`): The number of channels in each head.
99
+ context_pre_only (`bool`): Boolean to determine if we should add some blocks associated with the
100
+ processing of `context` conditions.
101
+ """
102
+
103
+ def __init__(self, dim, num_attention_heads, attention_head_dim, context_pre_only=False):
104
+ super().__init__()
105
+
106
+ self.context_pre_only = context_pre_only
107
+ context_norm_type = "ada_norm_continous" if context_pre_only else "ada_norm_zero"
108
+
109
+ self.norm1 = AdaLayerNormZero(dim)
110
+
111
+ if context_norm_type == "ada_norm_continous":
112
+ self.norm1_context = AdaLayerNormContinuous(
113
+ dim, dim, elementwise_affine=False, eps=1e-6, bias=True, norm_type="layer_norm"
114
+ )
115
+ elif context_norm_type == "ada_norm_zero":
116
+ self.norm1_context = AdaLayerNormZero(dim)
117
+ else:
118
+ raise ValueError(
119
+ f"Unknown context_norm_type: {context_norm_type}, currently only support `ada_norm_continous`, `ada_norm_zero`"
120
+ )
121
+ if hasattr(F, "scaled_dot_product_attention"):
122
+ processor = JointAttnProcessor2_0()
123
+ else:
124
+ raise ValueError(
125
+ "The current PyTorch version does not support the `scaled_dot_product_attention` function."
126
+ )
127
+ self.attn = Attention(
128
+ query_dim=dim,
129
+ cross_attention_dim=None,
130
+ added_kv_proj_dim=dim,
131
+ dim_head=attention_head_dim,
132
+ heads=num_attention_heads,
133
+ out_dim=dim,
134
+ context_pre_only=context_pre_only,
135
+ bias=True,
136
+ processor=processor,
137
+ )
138
+
139
+ self.norm2 = nn.LayerNorm(dim, elementwise_affine=False, eps=1e-6)
140
+ self.ff = FeedForward(dim=dim, dim_out=dim, activation_fn="gelu-approximate")
141
+
142
+ if not context_pre_only:
143
+ self.norm2_context = nn.LayerNorm(dim, elementwise_affine=False, eps=1e-6)
144
+ self.ff_context = FeedForward(dim=dim, dim_out=dim, activation_fn="gelu-approximate")
145
+ else:
146
+ self.norm2_context = None
147
+ self.ff_context = None
148
+
149
+ # let chunk size default to None
150
+ self._chunk_size = None
151
+ self._chunk_dim = 0
152
+
153
+ # Copied from diffusers.models.attention.BasicTransformerBlock.set_chunk_feed_forward
154
+ def set_chunk_feed_forward(self, chunk_size: Optional[int], dim: int = 0):
155
+ # Sets chunk feed-forward
156
+ self._chunk_size = chunk_size
157
+ self._chunk_dim = dim
158
+
159
+ def forward(
160
+ self, hidden_states: torch.FloatTensor, encoder_hidden_states: torch.FloatTensor, temb: torch.FloatTensor
161
+ ):
162
+ norm_hidden_states, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.norm1(hidden_states, emb=temb)
163
+
164
+ if self.context_pre_only:
165
+ norm_encoder_hidden_states = self.norm1_context(encoder_hidden_states, temb)
166
+ else:
167
+ norm_encoder_hidden_states, c_gate_msa, c_shift_mlp, c_scale_mlp, c_gate_mlp = self.norm1_context(
168
+ encoder_hidden_states, emb=temb
169
+ )
170
+
171
+ # Attention.
172
+ attn_output, context_attn_output = self.attn(
173
+ hidden_states=norm_hidden_states, encoder_hidden_states=norm_encoder_hidden_states
174
+ )
175
+
176
+ # Process attention outputs for the `hidden_states`.
177
+ attn_output = gate_msa.unsqueeze(1) * attn_output
178
+ hidden_states = hidden_states + attn_output
179
+
180
+ norm_hidden_states = self.norm2(hidden_states)
181
+ norm_hidden_states = norm_hidden_states * (1 + scale_mlp[:, None]) + shift_mlp[:, None]
182
+ if self._chunk_size is not None:
183
+ # "feed_forward_chunk_size" can be used to save memory
184
+ ff_output = _chunked_feed_forward(self.ff, norm_hidden_states, self._chunk_dim, self._chunk_size)
185
+ else:
186
+ ff_output = self.ff(norm_hidden_states)
187
+ ff_output = gate_mlp.unsqueeze(1) * ff_output
188
+
189
+ hidden_states = hidden_states + ff_output
190
+
191
+ # Process attention outputs for the `encoder_hidden_states`.
192
+ if self.context_pre_only:
193
+ encoder_hidden_states = None
194
+ else:
195
+ context_attn_output = c_gate_msa.unsqueeze(1) * context_attn_output
196
+ encoder_hidden_states = encoder_hidden_states + context_attn_output
197
+
198
+ norm_encoder_hidden_states = self.norm2_context(encoder_hidden_states)
199
+ norm_encoder_hidden_states = norm_encoder_hidden_states * (1 + c_scale_mlp[:, None]) + c_shift_mlp[:, None]
200
+ if self._chunk_size is not None:
201
+ # "feed_forward_chunk_size" can be used to save memory
202
+ context_ff_output = _chunked_feed_forward(
203
+ self.ff_context, norm_encoder_hidden_states, self._chunk_dim, self._chunk_size
204
+ )
205
+ else:
206
+ context_ff_output = self.ff_context(norm_encoder_hidden_states)
207
+ encoder_hidden_states = encoder_hidden_states + c_gate_mlp.unsqueeze(1) * context_ff_output
208
+
209
+ return encoder_hidden_states, hidden_states
210
+
211
+
212
+ @maybe_allow_in_graph
213
+ class BasicTransformerBlock(nn.Module):
214
+ r"""
215
+ A basic Transformer block.
216
+
217
+ Parameters:
218
+ dim (`int`): The number of channels in the input and output.
219
+ num_attention_heads (`int`): The number of heads to use for multi-head attention.
220
+ attention_head_dim (`int`): The number of channels in each head.
221
+ dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use.
222
+ cross_attention_dim (`int`, *optional*): The size of the encoder_hidden_states vector for cross attention.
223
+ activation_fn (`str`, *optional*, defaults to `"geglu"`): Activation function to be used in feed-forward.
224
+ num_embeds_ada_norm (:
225
+ obj: `int`, *optional*): The number of diffusion steps used during training. See `Transformer2DModel`.
226
+ attention_bias (:
227
+ obj: `bool`, *optional*, defaults to `False`): Configure if the attentions should contain a bias parameter.
228
+ only_cross_attention (`bool`, *optional*):
229
+ Whether to use only cross-attention layers. In this case two cross attention layers are used.
230
+ double_self_attention (`bool`, *optional*):
231
+ Whether to use two self-attention layers. In this case no cross attention layers are used.
232
+ upcast_attention (`bool`, *optional*):
233
+ Whether to upcast the attention computation to float32. This is useful for mixed precision training.
234
+ norm_elementwise_affine (`bool`, *optional*, defaults to `True`):
235
+ Whether to use learnable elementwise affine parameters for normalization.
236
+ norm_type (`str`, *optional*, defaults to `"layer_norm"`):
237
+ The normalization layer to use. Can be `"layer_norm"`, `"ada_norm"` or `"ada_norm_zero"`.
238
+ final_dropout (`bool` *optional*, defaults to False):
239
+ Whether to apply a final dropout after the last feed-forward layer.
240
+ attention_type (`str`, *optional*, defaults to `"default"`):
241
+ The type of attention to use. Can be `"default"` or `"gated"` or `"gated-text-image"`.
242
+ positional_embeddings (`str`, *optional*, defaults to `None`):
243
+ The type of positional embeddings to apply to.
244
+ num_positional_embeddings (`int`, *optional*, defaults to `None`):
245
+ The maximum number of positional embeddings to apply.
246
+ """
247
+
248
+ def __init__(
249
+ self,
250
+ dim: int,
251
+ num_attention_heads: int,
252
+ attention_head_dim: int,
253
+ dropout=0.0,
254
+ cross_attention_dim: Optional[int] = None,
255
+ activation_fn: str = "geglu",
256
+ num_embeds_ada_norm: Optional[int] = None,
257
+ attention_bias: bool = False,
258
+ only_cross_attention: bool = False,
259
+ double_self_attention: bool = False,
260
+ upcast_attention: bool = False,
261
+ norm_elementwise_affine: bool = True,
262
+ norm_type: str = "layer_norm", # 'layer_norm', 'ada_norm', 'ada_norm_zero', 'ada_norm_single', 'ada_norm_continuous', 'layer_norm_i2vgen'
263
+ norm_eps: float = 1e-5,
264
+ final_dropout: bool = False,
265
+ attention_type: str = "default",
266
+ positional_embeddings: Optional[str] = None,
267
+ num_positional_embeddings: Optional[int] = None,
268
+ ada_norm_continous_conditioning_embedding_dim: Optional[int] = None,
269
+ ada_norm_bias: Optional[int] = None,
270
+ ff_inner_dim: Optional[int] = None,
271
+ ff_bias: bool = True,
272
+ attention_out_bias: bool = True,
273
+ ):
274
+ super().__init__()
275
+ self.dim = dim
276
+ self.num_attention_heads = num_attention_heads
277
+ self.attention_head_dim = attention_head_dim
278
+ self.dropout = dropout
279
+ self.cross_attention_dim = cross_attention_dim
280
+ self.activation_fn = activation_fn
281
+ self.attention_bias = attention_bias
282
+ self.double_self_attention = double_self_attention
283
+ self.norm_elementwise_affine = norm_elementwise_affine
284
+ self.positional_embeddings = positional_embeddings
285
+ self.num_positional_embeddings = num_positional_embeddings
286
+ self.only_cross_attention = only_cross_attention
287
+
288
+ # We keep these boolean flags for backward-compatibility.
289
+ self.use_ada_layer_norm_zero = (num_embeds_ada_norm is not None) and norm_type == "ada_norm_zero"
290
+ self.use_ada_layer_norm = (num_embeds_ada_norm is not None) and norm_type == "ada_norm"
291
+ self.use_ada_layer_norm_single = norm_type == "ada_norm_single"
292
+ self.use_layer_norm = norm_type == "layer_norm"
293
+ self.use_ada_layer_norm_continuous = norm_type == "ada_norm_continuous"
294
+
295
+ if norm_type in ("ada_norm", "ada_norm_zero") and num_embeds_ada_norm is None:
296
+ raise ValueError(
297
+ f"`norm_type` is set to {norm_type}, but `num_embeds_ada_norm` is not defined. Please make sure to"
298
+ f" define `num_embeds_ada_norm` if setting `norm_type` to {norm_type}."
299
+ )
300
+
301
+ self.norm_type = norm_type
302
+ self.num_embeds_ada_norm = num_embeds_ada_norm
303
+
304
+ if positional_embeddings and (num_positional_embeddings is None):
305
+ raise ValueError(
306
+ "If `positional_embedding` type is defined, `num_positition_embeddings` must also be defined."
307
+ )
308
+
309
+ if positional_embeddings == "sinusoidal":
310
+ self.pos_embed = SinusoidalPositionalEmbedding(dim, max_seq_length=num_positional_embeddings)
311
+ else:
312
+ self.pos_embed = None
313
+
314
+ # Define 3 blocks. Each block has its own normalization layer.
315
+ # 1. Self-Attn
316
+ if norm_type == "ada_norm":
317
+ self.norm1 = AdaLayerNorm(dim, num_embeds_ada_norm)
318
+ elif norm_type == "ada_norm_zero":
319
+ self.norm1 = AdaLayerNormZero(dim, num_embeds_ada_norm)
320
+ elif norm_type == "ada_norm_continuous":
321
+ self.norm1 = AdaLayerNormContinuous(
322
+ dim,
323
+ ada_norm_continous_conditioning_embedding_dim,
324
+ norm_elementwise_affine,
325
+ norm_eps,
326
+ ada_norm_bias,
327
+ "rms_norm",
328
+ )
329
+ else:
330
+ self.norm1 = nn.LayerNorm(dim, elementwise_affine=norm_elementwise_affine, eps=norm_eps)
331
+
332
+ self.attn1 = Attention(
333
+ query_dim=dim,
334
+ heads=num_attention_heads,
335
+ dim_head=attention_head_dim,
336
+ dropout=dropout,
337
+ bias=attention_bias,
338
+ cross_attention_dim=cross_attention_dim if only_cross_attention else None,
339
+ upcast_attention=upcast_attention,
340
+ out_bias=attention_out_bias,
341
+ )
342
+
343
+ # 2. Cross-Attn
344
+ if cross_attention_dim is not None or double_self_attention:
345
+ # We currently only use AdaLayerNormZero for self attention where there will only be one attention block.
346
+ # I.e. the number of returned modulation chunks from AdaLayerZero would not make sense if returned during
347
+ # the second cross attention block.
348
+ if norm_type == "ada_norm":
349
+ self.norm2 = AdaLayerNorm(dim, num_embeds_ada_norm)
350
+ elif norm_type == "ada_norm_continuous":
351
+ self.norm2 = AdaLayerNormContinuous(
352
+ dim,
353
+ ada_norm_continous_conditioning_embedding_dim,
354
+ norm_elementwise_affine,
355
+ norm_eps,
356
+ ada_norm_bias,
357
+ "rms_norm",
358
+ )
359
+ else:
360
+ self.norm2 = nn.LayerNorm(dim, norm_eps, norm_elementwise_affine)
361
+
362
+ self.attn2 = Attention(
363
+ query_dim=dim,
364
+ cross_attention_dim=cross_attention_dim if not double_self_attention else None,
365
+ heads=num_attention_heads,
366
+ dim_head=attention_head_dim,
367
+ dropout=dropout,
368
+ bias=attention_bias,
369
+ upcast_attention=upcast_attention,
370
+ out_bias=attention_out_bias,
371
+ ) # is self-attn if encoder_hidden_states is none
372
+ else:
373
+ if norm_type == "ada_norm_single": # For Latte
374
+ self.norm2 = nn.LayerNorm(dim, norm_eps, norm_elementwise_affine)
375
+ else:
376
+ self.norm2 = None
377
+ self.attn2 = None
378
+
379
+ # 3. Feed-forward
380
+ if norm_type == "ada_norm_continuous":
381
+ self.norm3 = AdaLayerNormContinuous(
382
+ dim,
383
+ ada_norm_continous_conditioning_embedding_dim,
384
+ norm_elementwise_affine,
385
+ norm_eps,
386
+ ada_norm_bias,
387
+ "layer_norm",
388
+ )
389
+
390
+ elif norm_type in ["ada_norm_zero", "ada_norm", "layer_norm"]:
391
+ self.norm3 = nn.LayerNorm(dim, norm_eps, norm_elementwise_affine)
392
+ elif norm_type == "layer_norm_i2vgen":
393
+ self.norm3 = None
394
+
395
+ self.ff = FeedForward(
396
+ dim,
397
+ dropout=dropout,
398
+ activation_fn=activation_fn,
399
+ final_dropout=final_dropout,
400
+ inner_dim=ff_inner_dim,
401
+ bias=ff_bias,
402
+ )
403
+
404
+ # 4. Fuser
405
+ if attention_type == "gated" or attention_type == "gated-text-image":
406
+ self.fuser = GatedSelfAttentionDense(dim, cross_attention_dim, num_attention_heads, attention_head_dim)
407
+
408
+ # 5. Scale-shift for PixArt-Alpha.
409
+ if norm_type == "ada_norm_single":
410
+ self.scale_shift_table = nn.Parameter(torch.randn(6, dim) / dim**0.5)
411
+
412
+ # let chunk size default to None
413
+ self._chunk_size = None
414
+ self._chunk_dim = 0
415
+
416
+ def set_chunk_feed_forward(self, chunk_size: Optional[int], dim: int = 0):
417
+ # Sets chunk feed-forward
418
+ self._chunk_size = chunk_size
419
+ self._chunk_dim = dim
420
+
421
+ def forward(
422
+ self,
423
+ hidden_states: torch.Tensor,
424
+ attention_mask: Optional[torch.Tensor] = None,
425
+ encoder_hidden_states: Optional[torch.Tensor] = None,
426
+ encoder_attention_mask: Optional[torch.Tensor] = None,
427
+ timestep: Optional[torch.LongTensor] = None,
428
+         cross_attention_kwargs: Dict[str, Any] = None,
+         class_labels: Optional[torch.LongTensor] = None,
+         added_cond_kwargs: Optional[Dict[str, torch.Tensor]] = None,
+     ) -> torch.Tensor:
+         if cross_attention_kwargs is not None:
+             if cross_attention_kwargs.get("scale", None) is not None:
+                 logger.warning("Passing `scale` to `cross_attention_kwargs` is deprecated. `scale` will be ignored.")
+
+         # Notice that normalization is always applied before the real computation in the following blocks.
+         # 0. Self-Attention
+         batch_size = hidden_states.shape[0]
+
+         if self.norm_type == "ada_norm":
+             norm_hidden_states = self.norm1(hidden_states, timestep)
+         elif self.norm_type == "ada_norm_zero":
+             norm_hidden_states, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.norm1(
+                 hidden_states, timestep, class_labels, hidden_dtype=hidden_states.dtype
+             )
+         elif self.norm_type in ["layer_norm", "layer_norm_i2vgen"]:
+             norm_hidden_states = self.norm1(hidden_states)
+         elif self.norm_type == "ada_norm_continuous":
+             norm_hidden_states = self.norm1(hidden_states, added_cond_kwargs["pooled_text_emb"])
+         elif self.norm_type == "ada_norm_single":
+             shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = (
+                 self.scale_shift_table[None] + timestep.reshape(batch_size, 6, -1)
+             ).chunk(6, dim=1)
+             norm_hidden_states = self.norm1(hidden_states)
+             norm_hidden_states = norm_hidden_states * (1 + scale_msa) + shift_msa
+         else:
+             raise ValueError("Incorrect norm used")
+
+         if self.pos_embed is not None:
+             norm_hidden_states = self.pos_embed(norm_hidden_states)
+
+         # 1. Prepare GLIGEN inputs
+         cross_attention_kwargs = cross_attention_kwargs.copy() if cross_attention_kwargs is not None else {}
+         gligen_kwargs = cross_attention_kwargs.pop("gligen", None)
+
+         attn_output = self.attn1(
+             norm_hidden_states,
+             encoder_hidden_states=encoder_hidden_states if self.only_cross_attention else None,
+             attention_mask=attention_mask,
+             **cross_attention_kwargs,
+         )
+
+         if self.norm_type == "ada_norm_zero":
+             attn_output = gate_msa.unsqueeze(1) * attn_output
+         elif self.norm_type == "ada_norm_single":
+             attn_output = gate_msa * attn_output
+
+         hidden_states = attn_output + hidden_states
+         if hidden_states.ndim == 4:
+             hidden_states = hidden_states.squeeze(1)
+
+         # 1.2 GLIGEN Control
+         if gligen_kwargs is not None:
+             hidden_states = self.fuser(hidden_states, gligen_kwargs["objs"])
+
+         # 3. Cross-Attention
+         if self.attn2 is not None:
+             if self.norm_type == "ada_norm":
+                 norm_hidden_states = self.norm2(hidden_states, timestep)
+             elif self.norm_type in ["ada_norm_zero", "layer_norm", "layer_norm_i2vgen"]:
+                 norm_hidden_states = self.norm2(hidden_states)
+             elif self.norm_type == "ada_norm_single":
+                 # For PixArt, norm2 isn't applied here:
+                 # https://github.com/PixArt-alpha/PixArt-alpha/blob/0f55e922376d8b797edd44d25d0e7464b260dcab/diffusion/model/nets/PixArtMS.py#L70C1-L76C103
+                 norm_hidden_states = hidden_states
+             elif self.norm_type == "ada_norm_continuous":
+                 norm_hidden_states = self.norm2(hidden_states, added_cond_kwargs["pooled_text_emb"])
+             else:
+                 raise ValueError("Incorrect norm used")
+
+             if self.pos_embed is not None and self.norm_type != "ada_norm_single":
+                 norm_hidden_states = self.pos_embed(norm_hidden_states)
+
+             attn_output = self.attn2(
+                 norm_hidden_states,
+                 encoder_hidden_states=encoder_hidden_states,
+                 attention_mask=encoder_attention_mask,
+                 **cross_attention_kwargs,
+             )
+             hidden_states = attn_output + hidden_states
+
+         # 4. Feed-forward
+         # i2vgen doesn't have this norm 🤷‍♂️
+         if self.norm_type == "ada_norm_continuous":
+             norm_hidden_states = self.norm3(hidden_states, added_cond_kwargs["pooled_text_emb"])
+         elif self.norm_type != "ada_norm_single":
+             norm_hidden_states = self.norm3(hidden_states)
+
+         if self.norm_type == "ada_norm_zero":
+             norm_hidden_states = norm_hidden_states * (1 + scale_mlp[:, None]) + shift_mlp[:, None]
+
+         if self.norm_type == "ada_norm_single":
+             norm_hidden_states = self.norm2(hidden_states)
+             norm_hidden_states = norm_hidden_states * (1 + scale_mlp) + shift_mlp
+
+         if self._chunk_size is not None:
+             # "feed_forward_chunk_size" can be used to save memory
+             ff_output = _chunked_feed_forward(self.ff, norm_hidden_states, self._chunk_dim, self._chunk_size)
+         else:
+             ff_output = self.ff(norm_hidden_states)
+
+         if self.norm_type == "ada_norm_zero":
+             ff_output = gate_mlp.unsqueeze(1) * ff_output
+         elif self.norm_type == "ada_norm_single":
+             ff_output = gate_mlp * ff_output
+
+         hidden_states = ff_output + hidden_states
+         if hidden_states.ndim == 4:
+             hidden_states = hidden_states.squeeze(1)
+
+         return hidden_states
+
+
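+ # Illustrative sketch (a hypothetical helper added for exposition only, not used by
+ # any class in this module) of what `_chunked_feed_forward`, defined earlier in this
+ # file, computes: split the chunked axis, run the MLP per slice, and concatenate,
+ # trading a Python-level loop for a lower peak activation memory.
+ def _chunked_feed_forward_sketch(ff, hidden_states, chunk_dim: int, chunk_size: int) -> torch.Tensor:
+     if hidden_states.shape[chunk_dim] % chunk_size != 0:
+         raise ValueError("the chunked dimension must be divisible by chunk_size")
+     num_chunks = hidden_states.shape[chunk_dim] // chunk_size
+     return torch.cat(
+         [ff(hid_slice) for hid_slice in hidden_states.chunk(num_chunks, dim=chunk_dim)],
+         dim=chunk_dim,
+     )
+
+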
+ class LuminaFeedForward(nn.Module):
+     r"""
+     A feed-forward layer.
+
+     Parameters:
+         dim (`int`):
+             The dimensionality of the hidden layers in the model. This parameter determines the width of the model's
+             hidden representations.
+         inner_dim (`int`): The intermediate dimension of the feedforward layer.
+         multiple_of (`int`, *optional*, defaults to 256): Value to ensure the hidden dimension is a multiple of this
+             value.
+         ffn_dim_multiplier (`float`, *optional*): Custom multiplier for the hidden dimension. Defaults to None.
+     """
+
+     def __init__(
+         self,
+         dim: int,
+         inner_dim: int,
+         multiple_of: Optional[int] = 256,
+         ffn_dim_multiplier: Optional[float] = None,
+     ):
+         super().__init__()
+         inner_dim = int(2 * inner_dim / 3)
+         # custom hidden dimension multiplier
+         if ffn_dim_multiplier is not None:
+             inner_dim = int(ffn_dim_multiplier * inner_dim)
+         inner_dim = multiple_of * ((inner_dim + multiple_of - 1) // multiple_of)
+
+         self.linear_1 = nn.Linear(dim, inner_dim, bias=False)
+         self.linear_2 = nn.Linear(inner_dim, dim, bias=False)
+         self.linear_3 = nn.Linear(dim, inner_dim, bias=False)
+         self.silu = FP32SiLU()
+
+     def forward(self, x):
+         # SwiGLU-style gating: silu(W1 x) * (W3 x), projected back through W2
+         return self.linear_2(self.silu(self.linear_1(x)) * self.linear_3(x))
+
+
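+ # Worked example of the hidden-width computation above (values are hypothetical,
+ # chosen only to make the rounding visible):
+ #     inner_dim=1024 -> int(2 * 1024 / 3) = 682
+ #     with multiple_of=256: 256 * ((682 + 256 - 1) // 256) = 256 * 3 = 768
+
+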
+ @maybe_allow_in_graph
+ class TemporalBasicTransformerBlock(nn.Module):
+     r"""
+     A basic Transformer block for video-like data.
+
+     Parameters:
+         dim (`int`): The number of channels in the input and output.
+         time_mix_inner_dim (`int`): The number of channels for temporal attention.
+         num_attention_heads (`int`): The number of heads to use for multi-head attention.
+         attention_head_dim (`int`): The number of channels in each head.
+         cross_attention_dim (`int`, *optional*): The size of the encoder_hidden_states vector for cross attention.
+     """
+
+     def __init__(
+         self,
+         dim: int,
+         time_mix_inner_dim: int,
+         num_attention_heads: int,
+         attention_head_dim: int,
+         cross_attention_dim: Optional[int] = None,
+     ):
+         super().__init__()
+         self.is_res = dim == time_mix_inner_dim
+
+         self.norm_in = nn.LayerNorm(dim)
+
+         # Define 3 blocks. Each block has its own normalization layer.
+         # 1. Self-Attn
+         self.ff_in = FeedForward(
+             dim,
+             dim_out=time_mix_inner_dim,
+             activation_fn="geglu",
+         )
+
+         self.norm1 = nn.LayerNorm(time_mix_inner_dim)
+         self.attn1 = Attention(
+             query_dim=time_mix_inner_dim,
+             heads=num_attention_heads,
+             dim_head=attention_head_dim,
+             cross_attention_dim=None,
+         )
+
+         # 2. Cross-Attn
+         if cross_attention_dim is not None:
+             # We currently only use AdaLayerNormZero for self attention where there will only be one attention block.
+             # I.e. the number of returned modulation chunks from AdaLayerNormZero would not make sense if returned
+             # during the second cross attention block.
+             self.norm2 = nn.LayerNorm(time_mix_inner_dim)
+             self.attn2 = Attention(
+                 query_dim=time_mix_inner_dim,
+                 cross_attention_dim=cross_attention_dim,
+                 heads=num_attention_heads,
+                 dim_head=attention_head_dim,
+             )  # is self-attn if encoder_hidden_states is none
+         else:
+             self.norm2 = None
+             self.attn2 = None
+
+         # 3. Feed-forward
+         self.norm3 = nn.LayerNorm(time_mix_inner_dim)
+         self.ff = FeedForward(time_mix_inner_dim, activation_fn="geglu")
+
+         # let chunk size default to None
+         self._chunk_size = None
+         self._chunk_dim = None
+
+     def set_chunk_feed_forward(self, chunk_size: Optional[int], **kwargs):
+         # Sets chunk feed-forward
+         self._chunk_size = chunk_size
+         # chunk dim should be hardcoded to 1 to have better speed vs. memory trade-off
+         self._chunk_dim = 1
+
+     def forward(
+         self,
+         hidden_states: torch.Tensor,
+         num_frames: int,
+         encoder_hidden_states: Optional[torch.Tensor] = None,
+     ) -> torch.Tensor:
+         # Notice that normalization is always applied before the real computation in the following blocks.
+         # 0. Self-Attention
+         batch_frames, seq_length, channels = hidden_states.shape
+         batch_size = batch_frames // num_frames
+
+         hidden_states = hidden_states[None, :].reshape(batch_size, num_frames, seq_length, channels)
+         hidden_states = hidden_states.permute(0, 2, 1, 3)
+         hidden_states = hidden_states.reshape(batch_size * seq_length, num_frames, channels)
+
+         residual = hidden_states
+         hidden_states = self.norm_in(hidden_states)
+
+         if self._chunk_size is not None:
+             hidden_states = _chunked_feed_forward(self.ff_in, hidden_states, self._chunk_dim, self._chunk_size)
+         else:
+             hidden_states = self.ff_in(hidden_states)
+
+         if self.is_res:
+             hidden_states = hidden_states + residual
+
+         norm_hidden_states = self.norm1(hidden_states)
+         attn_output = self.attn1(norm_hidden_states, encoder_hidden_states=None)
+         hidden_states = attn_output + hidden_states
+
+         # 3. Cross-Attention
+         if self.attn2 is not None:
+             norm_hidden_states = self.norm2(hidden_states)
+             attn_output = self.attn2(norm_hidden_states, encoder_hidden_states=encoder_hidden_states)
+             hidden_states = attn_output + hidden_states
+
+         # 4. Feed-forward
+         norm_hidden_states = self.norm3(hidden_states)
+
+         if self._chunk_size is not None:
+             ff_output = _chunked_feed_forward(self.ff, norm_hidden_states, self._chunk_dim, self._chunk_size)
+         else:
+             ff_output = self.ff(norm_hidden_states)
+
+         if self.is_res:
+             hidden_states = ff_output + hidden_states
+         else:
+             hidden_states = ff_output
+
+         hidden_states = hidden_states[None, :].reshape(batch_size, seq_length, num_frames, channels)
+         hidden_states = hidden_states.permute(0, 2, 1, 3)
+         hidden_states = hidden_states.reshape(batch_size * num_frames, seq_length, channels)
+
+         return hidden_states
+
+
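+ # Illustrative shape walk-through for the temporal reshapes above (hypothetical
+ # sizes): with batch_size=2, num_frames=8, seq_length=64, channels=320 the input
+ # arrives as (batch_frames=16, 64, 320); it is regrouped to (2, 8, 64, 320),
+ # permuted to (2, 64, 8, 320), and flattened to (128, 8, 320) so that attention
+ # mixes the 8 frames at each spatial location. The inverse reshape at the end
+ # restores (16, 64, 320):
+ #     x = torch.randn(2 * 8, 64, 320)
+ #     x = x.reshape(2, 8, 64, 320).permute(0, 2, 1, 3).reshape(2 * 64, 8, 320)
+
+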
+ class SkipFFTransformerBlock(nn.Module):
+     def __init__(
+         self,
+         dim: int,
+         num_attention_heads: int,
+         attention_head_dim: int,
+         kv_input_dim: int,
+         kv_input_dim_proj_use_bias: bool,
+         dropout=0.0,
+         cross_attention_dim: Optional[int] = None,
+         attention_bias: bool = False,
+         attention_out_bias: bool = True,
+     ):
+         super().__init__()
+         if kv_input_dim != dim:
+             self.kv_mapper = nn.Linear(kv_input_dim, dim, kv_input_dim_proj_use_bias)
+         else:
+             self.kv_mapper = None
+
+         self.norm1 = RMSNorm(dim, 1e-06)
+
+         self.attn1 = Attention(
+             query_dim=dim,
+             heads=num_attention_heads,
+             dim_head=attention_head_dim,
+             dropout=dropout,
+             bias=attention_bias,
+             cross_attention_dim=cross_attention_dim,
+             out_bias=attention_out_bias,
+         )
+
+         self.norm2 = RMSNorm(dim, 1e-06)
+
+         self.attn2 = Attention(
+             query_dim=dim,
+             cross_attention_dim=cross_attention_dim,
+             heads=num_attention_heads,
+             dim_head=attention_head_dim,
+             dropout=dropout,
+             bias=attention_bias,
+             out_bias=attention_out_bias,
+         )
+
+     def forward(self, hidden_states, encoder_hidden_states, cross_attention_kwargs):
+         cross_attention_kwargs = cross_attention_kwargs.copy() if cross_attention_kwargs is not None else {}
+
+         if self.kv_mapper is not None:
+             encoder_hidden_states = self.kv_mapper(F.silu(encoder_hidden_states))
+
+         norm_hidden_states = self.norm1(hidden_states)
+
+         attn_output = self.attn1(
+             norm_hidden_states,
+             encoder_hidden_states=encoder_hidden_states,
+             **cross_attention_kwargs,
+         )
+
+         hidden_states = attn_output + hidden_states
+
+         norm_hidden_states = self.norm2(hidden_states)
+
+         attn_output = self.attn2(
+             norm_hidden_states,
+             encoder_hidden_states=encoder_hidden_states,
+             **cross_attention_kwargs,
+         )
+
+         hidden_states = attn_output + hidden_states
+
+         return hidden_states
+
+
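+ # Illustrative usage sketch (sizes are hypothetical). Despite its name, this block
+ # has no feed-forward sublayer ("skip FF"); it is two attention sublayers over RMSNorm:
+ #     block = SkipFFTransformerBlock(
+ #         dim=512, num_attention_heads=8, attention_head_dim=64,
+ #         kv_input_dim=768, kv_input_dim_proj_use_bias=False, cross_attention_dim=512,
+ #     )
+ #     x = torch.randn(2, 256, 512)
+ #     ctx = torch.randn(2, 77, 768)
+ #     out = block(x, ctx, cross_attention_kwargs=None)  # (2, 256, 512)
+
+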
+ @maybe_allow_in_graph
+ class FreeNoiseTransformerBlock(nn.Module):
+     r"""
+     A FreeNoise Transformer block.
+
+     Parameters:
+         dim (`int`):
+             The number of channels in the input and output.
+         num_attention_heads (`int`):
+             The number of heads to use for multi-head attention.
+         attention_head_dim (`int`):
+             The number of channels in each head.
+         dropout (`float`, *optional*, defaults to 0.0):
+             The dropout probability to use.
+         cross_attention_dim (`int`, *optional*):
+             The size of the encoder_hidden_states vector for cross attention.
+         activation_fn (`str`, *optional*, defaults to `"geglu"`):
+             Activation function to be used in feed-forward.
+         num_embeds_ada_norm (`int`, *optional*):
+             The number of diffusion steps used during training. See `Transformer2DModel`.
+         attention_bias (`bool`, defaults to `False`):
+             Whether the attention layers should contain a bias parameter.
+         only_cross_attention (`bool`, defaults to `False`):
+             Whether to use only cross-attention layers. In this case two cross attention layers are used.
+         double_self_attention (`bool`, defaults to `False`):
+             Whether to use two self-attention layers. In this case no cross attention layers are used.
+         upcast_attention (`bool`, defaults to `False`):
+             Whether to upcast the attention computation to float32. This is useful for mixed precision training.
+         norm_elementwise_affine (`bool`, defaults to `True`):
+             Whether to use learnable elementwise affine parameters for normalization.
+         norm_type (`str`, defaults to `"layer_norm"`):
+             The normalization layer to use. Can be `"layer_norm"`, `"ada_norm"` or `"ada_norm_zero"`.
+         norm_eps (`float`, defaults to `1e-5`):
+             Epsilon value for the normalization layers.
+         final_dropout (`bool`, defaults to `False`):
+             Whether to apply a final dropout after the last feed-forward layer.
+         positional_embeddings (`str`, *optional*):
+             The type of positional embeddings to apply.
+         num_positional_embeddings (`int`, *optional*, defaults to `None`):
+             The maximum number of positional embeddings to apply.
+         ff_inner_dim (`int`, *optional*):
+             Hidden dimension of the feed-forward MLP.
+         ff_bias (`bool`, defaults to `True`):
+             Whether or not to use bias in the feed-forward MLP.
+         attention_out_bias (`bool`, defaults to `True`):
+             Whether or not to use bias in the attention output projection layer.
+         context_length (`int`, defaults to `16`):
+             The maximum number of frames that the FreeNoise block processes at once.
+         context_stride (`int`, defaults to `4`):
+             The number of frames to be skipped before starting to process a new batch of `context_length` frames.
+         weighting_scheme (`str`, defaults to `"pyramid"`):
+             The weighting scheme to use for averaging processed latent frames. As described in Equation 9 of the
+             [FreeNoise](https://arxiv.org/abs/2310.15169) paper, "pyramid" is the default setting used.
+     """
+
+     def __init__(
+         self,
+         dim: int,
+         num_attention_heads: int,
+         attention_head_dim: int,
+         dropout: float = 0.0,
+         cross_attention_dim: Optional[int] = None,
+         activation_fn: str = "geglu",
+         num_embeds_ada_norm: Optional[int] = None,
+         attention_bias: bool = False,
+         only_cross_attention: bool = False,
+         double_self_attention: bool = False,
+         upcast_attention: bool = False,
+         norm_elementwise_affine: bool = True,
+         norm_type: str = "layer_norm",
+         norm_eps: float = 1e-5,
+         final_dropout: bool = False,
+         positional_embeddings: Optional[str] = None,
+         num_positional_embeddings: Optional[int] = None,
+         ff_inner_dim: Optional[int] = None,
+         ff_bias: bool = True,
+         attention_out_bias: bool = True,
+         context_length: int = 16,
+         context_stride: int = 4,
+         weighting_scheme: str = "pyramid",
+     ):
+         super().__init__()
+         self.dim = dim
+         self.num_attention_heads = num_attention_heads
+         self.attention_head_dim = attention_head_dim
+         self.dropout = dropout
+         self.cross_attention_dim = cross_attention_dim
+         self.activation_fn = activation_fn
+         self.attention_bias = attention_bias
+         self.double_self_attention = double_self_attention
+         self.norm_elementwise_affine = norm_elementwise_affine
+         self.positional_embeddings = positional_embeddings
+         self.num_positional_embeddings = num_positional_embeddings
+         self.only_cross_attention = only_cross_attention
+
+         self.set_free_noise_properties(context_length, context_stride, weighting_scheme)
+
+         # We keep these boolean flags for backward-compatibility.
+         self.use_ada_layer_norm_zero = (num_embeds_ada_norm is not None) and norm_type == "ada_norm_zero"
+         self.use_ada_layer_norm = (num_embeds_ada_norm is not None) and norm_type == "ada_norm"
+         self.use_ada_layer_norm_single = norm_type == "ada_norm_single"
+         self.use_layer_norm = norm_type == "layer_norm"
+         self.use_ada_layer_norm_continuous = norm_type == "ada_norm_continuous"
+
+         if norm_type in ("ada_norm", "ada_norm_zero") and num_embeds_ada_norm is None:
+             raise ValueError(
+                 f"`norm_type` is set to {norm_type}, but `num_embeds_ada_norm` is not defined. Please make sure to"
+                 f" define `num_embeds_ada_norm` if setting `norm_type` to {norm_type}."
+             )
+
+         self.norm_type = norm_type
+         self.num_embeds_ada_norm = num_embeds_ada_norm
+
+         if positional_embeddings and (num_positional_embeddings is None):
+             raise ValueError(
+                 "If `positional_embeddings` type is defined, `num_positional_embeddings` must also be defined."
+             )
+
+         if positional_embeddings == "sinusoidal":
+             self.pos_embed = SinusoidalPositionalEmbedding(dim, max_seq_length=num_positional_embeddings)
+         else:
+             self.pos_embed = None
+
+         # Define 3 blocks. Each block has its own normalization layer.
+         # 1. Self-Attn
+         self.norm1 = nn.LayerNorm(dim, elementwise_affine=norm_elementwise_affine, eps=norm_eps)
+
+         self.attn1 = Attention(
+             query_dim=dim,
+             heads=num_attention_heads,
+             dim_head=attention_head_dim,
+             dropout=dropout,
+             bias=attention_bias,
+             cross_attention_dim=cross_attention_dim if only_cross_attention else None,
+             upcast_attention=upcast_attention,
+             out_bias=attention_out_bias,
+         )
+
+         # 2. Cross-Attn
+         if cross_attention_dim is not None or double_self_attention:
+             self.norm2 = nn.LayerNorm(dim, norm_eps, norm_elementwise_affine)
+
+             self.attn2 = Attention(
+                 query_dim=dim,
+                 cross_attention_dim=cross_attention_dim if not double_self_attention else None,
+                 heads=num_attention_heads,
+                 dim_head=attention_head_dim,
+                 dropout=dropout,
+                 bias=attention_bias,
+                 upcast_attention=upcast_attention,
+                 out_bias=attention_out_bias,
+             )  # is self-attn if encoder_hidden_states is none
+
+         # 3. Feed-forward
+         self.ff = FeedForward(
+             dim,
+             dropout=dropout,
+             activation_fn=activation_fn,
+             final_dropout=final_dropout,
+             inner_dim=ff_inner_dim,
+             bias=ff_bias,
+         )
+
+         self.norm3 = nn.LayerNorm(dim, norm_eps, norm_elementwise_affine)
+
+         # let chunk size default to None
+         self._chunk_size = None
+         self._chunk_dim = 0
+
+     def _get_frame_indices(self, num_frames: int) -> List[Tuple[int, int]]:
+         frame_indices = []
+         for i in range(0, num_frames - self.context_length + 1, self.context_stride):
+             window_start = i
+             window_end = min(num_frames, i + self.context_length)
+             frame_indices.append((window_start, window_end))
+         return frame_indices
+
+     def _get_frame_weights(self, num_frames: int, weighting_scheme: str = "pyramid") -> List[float]:
+         if weighting_scheme == "flat":
+             weights = [1.0] * num_frames
+
+         elif weighting_scheme == "pyramid":
+             if num_frames % 2 == 0:
+                 # num_frames = 4 => [1, 2, 2, 1]
+                 mid = num_frames // 2
+                 weights = list(range(1, mid + 1))
+                 weights = weights + weights[::-1]
+             else:
+                 # num_frames = 5 => [1, 2, 3, 2, 1]
+                 mid = (num_frames + 1) // 2
+                 weights = list(range(1, mid))
+                 weights = weights + [mid] + weights[::-1]
+
+         elif weighting_scheme == "delayed_reverse_sawtooth":
+             if num_frames % 2 == 0:
+                 # num_frames = 4 => [0.01, 2, 2, 1]
+                 mid = num_frames // 2
+                 weights = [0.01] * (mid - 1) + [mid]
+                 weights = weights + list(range(mid, 0, -1))
+             else:
+                 # num_frames = 5 => [0.01, 0.01, 3, 2, 1]
+                 mid = (num_frames + 1) // 2
+                 weights = [0.01] * mid
+                 weights = weights + list(range(mid, 0, -1))
+         else:
+             raise ValueError(f"Unsupported value for weighting_scheme={weighting_scheme}")
+
+         return weights
+
+     def set_free_noise_properties(
+         self, context_length: int, context_stride: int, weighting_scheme: str = "pyramid"
+     ) -> None:
+         self.context_length = context_length
+         self.context_stride = context_stride
+         self.weighting_scheme = weighting_scheme
+
+     def set_chunk_feed_forward(self, chunk_size: Optional[int], dim: int = 0) -> None:
+         # Sets chunk feed-forward
+         self._chunk_size = chunk_size
+         self._chunk_dim = dim
+
+     def forward(
+         self,
+         hidden_states: torch.Tensor,
+         attention_mask: Optional[torch.Tensor] = None,
+         encoder_hidden_states: Optional[torch.Tensor] = None,
+         encoder_attention_mask: Optional[torch.Tensor] = None,
+         cross_attention_kwargs: Dict[str, Any] = None,
+         *args,
+         **kwargs,
+     ) -> torch.Tensor:
+         if cross_attention_kwargs is not None:
+             if cross_attention_kwargs.get("scale", None) is not None:
+                 logger.warning("Passing `scale` to `cross_attention_kwargs` is deprecated. `scale` will be ignored.")
+
+         cross_attention_kwargs = cross_attention_kwargs.copy() if cross_attention_kwargs is not None else {}
+
+         # hidden_states: [B x H x W, F, C]
+         device = hidden_states.device
+         dtype = hidden_states.dtype
+
+         num_frames = hidden_states.size(1)
+         frame_indices = self._get_frame_indices(num_frames)
+         frame_weights = self._get_frame_weights(self.context_length, self.weighting_scheme)
+         frame_weights = torch.tensor(frame_weights, device=device, dtype=dtype).unsqueeze(0).unsqueeze(-1)
+         is_last_frame_batch_complete = frame_indices[-1][1] == num_frames
+
+         # Handle the out-of-bounds case if num_frames isn't perfectly divisible by context_length
+         # For example, num_frames=25, context_length=16, context_stride=4 gives the ranges:
+         # [(0, 16), (4, 20), (8, 24), (9, 25)]
+         if not is_last_frame_batch_complete:
+             if num_frames < self.context_length:
+                 raise ValueError(f"Expected {num_frames=} to be greater than or equal to {self.context_length=}")
+             last_frame_batch_length = num_frames - frame_indices[-1][1]
+             frame_indices.append((num_frames - self.context_length, num_frames))
+
+         num_times_accumulated = torch.zeros((1, num_frames, 1), device=device)
+         accumulated_values = torch.zeros_like(hidden_states)
+
+         for i, (frame_start, frame_end) in enumerate(frame_indices):
+             # Slicing `num_times_accumulated` keeps the weights aligned with windows that may be
+             # shorter than `context_length`, e.g. frame_indices=[(0, 16), (16, 20)] when the user
+             # provides a video with 19 frames, i.e. a non-multiple of `context_length`.
+             weights = torch.ones_like(num_times_accumulated[:, frame_start:frame_end])
+             weights *= frame_weights
+
+             hidden_states_chunk = hidden_states[:, frame_start:frame_end]
+
+             # Notice that normalization is always applied before the real computation in the following blocks.
+             # 1. Self-Attention
+             norm_hidden_states = self.norm1(hidden_states_chunk)
+
+             if self.pos_embed is not None:
+                 norm_hidden_states = self.pos_embed(norm_hidden_states)
+
+             attn_output = self.attn1(
+                 norm_hidden_states,
+                 encoder_hidden_states=encoder_hidden_states if self.only_cross_attention else None,
+                 attention_mask=attention_mask,
+                 **cross_attention_kwargs,
+             )
+
+             hidden_states_chunk = attn_output + hidden_states_chunk
+             if hidden_states_chunk.ndim == 4:
+                 hidden_states_chunk = hidden_states_chunk.squeeze(1)
+
+             # 2. Cross-Attention
+             if self.attn2 is not None:
+                 norm_hidden_states = self.norm2(hidden_states_chunk)
+
+                 if self.pos_embed is not None and self.norm_type != "ada_norm_single":
+                     norm_hidden_states = self.pos_embed(norm_hidden_states)
+
+                 attn_output = self.attn2(
+                     norm_hidden_states,
+                     encoder_hidden_states=encoder_hidden_states,
+                     attention_mask=encoder_attention_mask,
+                     **cross_attention_kwargs,
+                 )
+                 hidden_states_chunk = attn_output + hidden_states_chunk
+
+             if i == len(frame_indices) - 1 and not is_last_frame_batch_complete:
+                 accumulated_values[:, -last_frame_batch_length:] += (
+                     hidden_states_chunk[:, -last_frame_batch_length:] * weights[:, -last_frame_batch_length:]
+                 )
+                 num_times_accumulated[:, -last_frame_batch_length:] += weights[:, -last_frame_batch_length:]
+             else:
+                 accumulated_values[:, frame_start:frame_end] += hidden_states_chunk * weights
+                 num_times_accumulated[:, frame_start:frame_end] += weights
+
+         # TODO(aryan): Maybe this could be done in a better way.
+         #
+         # Previously, this was:
+         #     hidden_states = torch.where(
+         #         num_times_accumulated > 0, accumulated_values / num_times_accumulated, accumulated_values
+         #     )
+         #
+         # The reasoning for the change here is that `torch.where` became a bottleneck at some point when golfing
+         # memory spikes. It is particularly noticeable when the number of frames is high. My understanding is that
+         # this comes from tensors being copied - which is why we resort to splitting and concatenating here. I've
+         # not looked into this deeply because other memory optimizations led to more pronounced reductions.
+         hidden_states = torch.cat(
+             [
+                 torch.where(num_times_split > 0, accumulated_split / num_times_split, accumulated_split)
+                 for accumulated_split, num_times_split in zip(
+                     accumulated_values.split(self.context_length, dim=1),
+                     num_times_accumulated.split(self.context_length, dim=1),
+                 )
+             ],
+             dim=1,
+         ).to(dtype)
+
+         # 3. Feed-forward
+         norm_hidden_states = self.norm3(hidden_states)
+
+         if self._chunk_size is not None:
+             ff_output = _chunked_feed_forward(self.ff, norm_hidden_states, self._chunk_dim, self._chunk_size)
+         else:
+             ff_output = self.ff(norm_hidden_states)
+
+         hidden_states = ff_output + hidden_states
+         if hidden_states.ndim == 4:
+             hidden_states = hidden_states.squeeze(1)
+
+         return hidden_states
+
+
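+ # Illustrative sketch (a hypothetical helper added for exposition only) reproducing the
+ # sliding-window logic of `_get_frame_indices` plus the out-of-bounds fixup performed in
+ # `FreeNoiseTransformerBlock.forward` above. For num_frames=25, context_length=16,
+ # context_stride=4 it yields [(0, 16), (4, 20), (8, 24), (9, 25)].
+ def _free_noise_windows_sketch(num_frames: int, context_length: int, context_stride: int) -> List[Tuple[int, int]]:
+     windows = [
+         (i, min(num_frames, i + context_length))
+         for i in range(0, num_frames - context_length + 1, context_stride)
+     ]
+     if windows[-1][1] != num_frames:
+         windows.append((num_frames - context_length, num_frames))
+     return windows
+
+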
+ class FeedForward(nn.Module):
+     r"""
+     A feed-forward layer.
+
+     Parameters:
+         dim (`int`): The number of channels in the input.
+         dim_out (`int`, *optional*): The number of channels in the output. If not given, defaults to `dim`.
+         mult (`int`, *optional*, defaults to 4): The multiplier to use for the hidden dimension.
+         dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use.
+         activation_fn (`str`, *optional*, defaults to `"geglu"`): Activation function to be used in feed-forward.
+         final_dropout (`bool`, *optional*, defaults to `False`): Whether to apply a final dropout.
+         inner_dim (`int`, *optional*): The hidden dimension; overrides `dim * mult` when given.
+         bias (`bool`, defaults to `True`): Whether to use a bias in the linear layers.
+     """
+
+     def __init__(
+         self,
+         dim: int,
+         dim_out: Optional[int] = None,
+         mult: int = 4,
+         dropout: float = 0.0,
+         activation_fn: str = "geglu",
+         final_dropout: bool = False,
+         inner_dim: Optional[int] = None,
+         bias: bool = True,
+     ):
+         super().__init__()
+         if inner_dim is None:
+             inner_dim = int(dim * mult)
+         dim_out = dim_out if dim_out is not None else dim
+
+         if activation_fn == "gelu":
+             act_fn = GELU(dim, inner_dim, bias=bias)
+         elif activation_fn == "gelu-approximate":
+             act_fn = GELU(dim, inner_dim, approximate="tanh", bias=bias)
+         elif activation_fn == "geglu":
+             act_fn = GEGLU(dim, inner_dim, bias=bias)
+         elif activation_fn == "geglu-approximate":
+             act_fn = ApproximateGELU(dim, inner_dim, bias=bias)
+         elif activation_fn == "swiglu":
+             act_fn = SwiGLU(dim, inner_dim, bias=bias)
+         else:
+             raise ValueError(f"Unsupported activation function: {activation_fn}")
+
+         self.net = nn.ModuleList([])
+         # project in
+         self.net.append(act_fn)
+         # project dropout
+         self.net.append(nn.Dropout(dropout))
+         # project out
+         self.net.append(nn.Linear(inner_dim, dim_out, bias=bias))
+         # FF as used in Vision Transformer, MLP-Mixer, etc. have a final dropout
+         if final_dropout:
+             self.net.append(nn.Dropout(dropout))
+
+     def forward(self, hidden_states: torch.Tensor, *args, **kwargs) -> torch.Tensor:
+         if len(args) > 0 or kwargs.get("scale", None) is not None:
+             deprecation_message = "The `scale` argument is deprecated and will be ignored. Please remove it, as passing it will raise an error in the future. `scale` should directly be passed while calling the underlying pipeline component i.e., via `cross_attention_kwargs`."
+             deprecate("scale", "1.0.0", deprecation_message)
+         for module in self.net:
+             hidden_states = module(hidden_states)
+         return hidden_states
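+
+ # Illustrative usage sketch for `FeedForward` (sizes are hypothetical):
+ #     ff = FeedForward(dim=320, mult=4, activation_fn="geglu")
+ #     x = torch.randn(2, 64, 320)
+ #     ff(x).shape  # (2, 64, 320); the GEGLU hidden width is 320 * 4 = 1280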
diffusers3/models/attention_flax.py ADDED
@@ -0,0 +1,494 @@
+ # Copyright 2024 The HuggingFace Team. All rights reserved.
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ #     http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+
+ import functools
+ import math
+
+ import flax.linen as nn
+ import jax
+ import jax.numpy as jnp
+
+
+ def _query_chunk_attention(query, key, value, precision, key_chunk_size: int = 4096):
+     """Multi-head dot product attention with a limited number of queries."""
+     num_kv, num_heads, k_features = key.shape[-3:]
+     v_features = value.shape[-1]
+     key_chunk_size = min(key_chunk_size, num_kv)
+     query = query / jnp.sqrt(k_features)
+
+     @functools.partial(jax.checkpoint, prevent_cse=False)
+     def summarize_chunk(query, key, value):
+         attn_weights = jnp.einsum("...qhd,...khd->...qhk", query, key, precision=precision)
+
+         max_score = jnp.max(attn_weights, axis=-1, keepdims=True)
+         max_score = jax.lax.stop_gradient(max_score)
+         exp_weights = jnp.exp(attn_weights - max_score)
+
+         exp_values = jnp.einsum("...vhf,...qhv->...qhf", value, exp_weights, precision=precision)
+         max_score = jnp.einsum("...qhk->...qh", max_score)
+
+         return (exp_values, exp_weights.sum(axis=-1), max_score)
+
+     def chunk_scanner(chunk_idx):
+         # slice out one chunk of the key array
+         key_chunk = jax.lax.dynamic_slice(
+             operand=key,
+             start_indices=[0] * (key.ndim - 3) + [chunk_idx, 0, 0],  # [...,k,h,d]
+             slice_sizes=list(key.shape[:-3]) + [key_chunk_size, num_heads, k_features],  # [...,k,h,d]
+         )
+
+         # slice out one chunk of the value array
+         value_chunk = jax.lax.dynamic_slice(
+             operand=value,
+             start_indices=[0] * (value.ndim - 3) + [chunk_idx, 0, 0],  # [...,v,h,d]
+             slice_sizes=list(value.shape[:-3]) + [key_chunk_size, num_heads, v_features],  # [...,v,h,d]
+         )
+
+         return summarize_chunk(query, key_chunk, value_chunk)
+
+     chunk_values, chunk_weights, chunk_max = jax.lax.map(f=chunk_scanner, xs=jnp.arange(0, num_kv, key_chunk_size))
+
+     global_max = jnp.max(chunk_max, axis=0, keepdims=True)
+     max_diffs = jnp.exp(chunk_max - global_max)
+
+     chunk_values *= jnp.expand_dims(max_diffs, axis=-1)
+     chunk_weights *= max_diffs
+
+     all_values = chunk_values.sum(axis=0)
+     all_weights = jnp.expand_dims(chunk_weights, -1).sum(axis=0)
+
+     return all_values / all_weights
+
+
+ def jax_memory_efficient_attention(
+     query, key, value, precision=jax.lax.Precision.HIGHEST, query_chunk_size: int = 1024, key_chunk_size: int = 4096
+ ):
+     r"""
+     Flax memory-efficient multi-head dot product attention. https://arxiv.org/abs/2112.05682v2
+     https://github.com/AminRezaei0x443/memory-efficient-attention
+
+     Args:
+         query (`jnp.ndarray`): (batch..., query_length, head, query_key_depth_per_head)
+         key (`jnp.ndarray`): (batch..., key_value_length, head, query_key_depth_per_head)
+         value (`jnp.ndarray`): (batch..., key_value_length, head, value_depth_per_head)
+         precision (`jax.lax.Precision`, *optional*, defaults to `jax.lax.Precision.HIGHEST`):
+             numerical precision for the computation
+         query_chunk_size (`int`, *optional*, defaults to 1024):
+             chunk size to divide the query array; the value must divide query_length equally without remainder
+         key_chunk_size (`int`, *optional*, defaults to 4096):
+             chunk size to divide the key and value arrays; the value must divide key_value_length equally without
+             remainder
+
+     Returns:
+         (`jnp.ndarray`) with shape of (batch..., query_length, head, value_depth_per_head)
+     """
+     num_q, num_heads, q_features = query.shape[-3:]
+
+     def chunk_scanner(chunk_idx, _):
+         # slice out one chunk of the query array
+         query_chunk = jax.lax.dynamic_slice(
+             operand=query,
+             start_indices=([0] * (query.ndim - 3)) + [chunk_idx, 0, 0],  # [...,q,h,d]
+             slice_sizes=list(query.shape[:-3]) + [min(query_chunk_size, num_q), num_heads, q_features],  # [...,q,h,d]
+         )
+
+         return (
+             chunk_idx + query_chunk_size,  # carry value (unused by the caller)
+             _query_chunk_attention(
+                 query=query_chunk, key=key, value=value, precision=precision, key_chunk_size=key_chunk_size
+             ),
+         )
+
+     _, res = jax.lax.scan(
+         f=chunk_scanner,
+         init=0,
+         xs=None,
+         length=math.ceil(num_q / query_chunk_size),
+     )
+
+     return jnp.concatenate(res, axis=-3)  # fuse the chunked results back together
+
+
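+ # Illustrative usage sketch (shapes are hypothetical; layout is
+ # (batch..., length, heads, head_dim), and both chunk sizes divide their lengths):
+ #     q = jnp.zeros((1, 4096, 8, 64))
+ #     k = jnp.zeros((1, 4096, 8, 64))
+ #     v = jnp.zeros((1, 4096, 8, 64))
+ #     out = jax_memory_efficient_attention(q, k, v, query_chunk_size=1024, key_chunk_size=4096)
+ #     out.shape  # (1, 4096, 8, 64); peak memory scales with the chunk sizes, not length**2
+
+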
+ class FlaxAttention(nn.Module):
+     r"""
+     A Flax multi-head attention module as described in: https://arxiv.org/abs/1706.03762
+
+     Parameters:
+         query_dim (:obj:`int`):
+             Input hidden states dimension
+         heads (:obj:`int`, *optional*, defaults to 8):
+             Number of heads
+         dim_head (:obj:`int`, *optional*, defaults to 64):
+             Hidden states dimension inside each head
+         dropout (:obj:`float`, *optional*, defaults to 0.0):
+             Dropout rate
+         use_memory_efficient_attention (`bool`, *optional*, defaults to `False`):
+             Enable memory-efficient attention https://arxiv.org/abs/2112.05682
+         split_head_dim (`bool`, *optional*, defaults to `False`):
+             Whether to split the head dimension into a new axis for the self-attention computation. In most cases,
+             enabling this flag should speed up the computation for Stable Diffusion 2.x and Stable Diffusion XL.
+         dtype (:obj:`jnp.dtype`, *optional*, defaults to jnp.float32):
+             Parameters `dtype`
+     """
+
+     query_dim: int
+     heads: int = 8
+     dim_head: int = 64
+     dropout: float = 0.0
+     use_memory_efficient_attention: bool = False
+     split_head_dim: bool = False
+     dtype: jnp.dtype = jnp.float32
+
+     def setup(self):
+         inner_dim = self.dim_head * self.heads
+         self.scale = self.dim_head**-0.5
+
+         # Weights were exported with old names {to_q, to_k, to_v, to_out}
+         self.query = nn.Dense(inner_dim, use_bias=False, dtype=self.dtype, name="to_q")
+         self.key = nn.Dense(inner_dim, use_bias=False, dtype=self.dtype, name="to_k")
+         self.value = nn.Dense(inner_dim, use_bias=False, dtype=self.dtype, name="to_v")
+
+         self.proj_attn = nn.Dense(self.query_dim, dtype=self.dtype, name="to_out_0")
+         self.dropout_layer = nn.Dropout(rate=self.dropout)
+
+     def reshape_heads_to_batch_dim(self, tensor):
+         batch_size, seq_len, dim = tensor.shape
+         head_size = self.heads
+         tensor = tensor.reshape(batch_size, seq_len, head_size, dim // head_size)
+         tensor = jnp.transpose(tensor, (0, 2, 1, 3))
+         tensor = tensor.reshape(batch_size * head_size, seq_len, dim // head_size)
+         return tensor
+
+     def reshape_batch_dim_to_heads(self, tensor):
+         batch_size, seq_len, dim = tensor.shape
+         head_size = self.heads
+         tensor = tensor.reshape(batch_size // head_size, head_size, seq_len, dim)
+         tensor = jnp.transpose(tensor, (0, 2, 1, 3))
+         tensor = tensor.reshape(batch_size // head_size, seq_len, dim * head_size)
+         return tensor
+
+     def __call__(self, hidden_states, context=None, deterministic=True):
+         context = hidden_states if context is None else context
+
+         query_proj = self.query(hidden_states)
+         key_proj = self.key(context)
+         value_proj = self.value(context)
+
+         if self.split_head_dim:
+             b = hidden_states.shape[0]
+             query_states = jnp.reshape(query_proj, (b, -1, self.heads, self.dim_head))
+             key_states = jnp.reshape(key_proj, (b, -1, self.heads, self.dim_head))
+             value_states = jnp.reshape(value_proj, (b, -1, self.heads, self.dim_head))
+         else:
+             query_states = self.reshape_heads_to_batch_dim(query_proj)
+             key_states = self.reshape_heads_to_batch_dim(key_proj)
+             value_states = self.reshape_heads_to_batch_dim(value_proj)
+
+         if self.use_memory_efficient_attention:
+             query_states = query_states.transpose(1, 0, 2)
+             key_states = key_states.transpose(1, 0, 2)
+             value_states = value_states.transpose(1, 0, 2)
+
+             # this if/elif chain creates a chunk size for each layer of the unet;
+             # the chunk size is equal to the query_length dimension of the deepest layer of the unet
+             flatten_latent_dim = query_states.shape[-3]
+             if flatten_latent_dim % 64 == 0:
+                 query_chunk_size = int(flatten_latent_dim / 64)
+             elif flatten_latent_dim % 16 == 0:
+                 query_chunk_size = int(flatten_latent_dim / 16)
+             elif flatten_latent_dim % 4 == 0:
+                 query_chunk_size = int(flatten_latent_dim / 4)
+             else:
+                 query_chunk_size = int(flatten_latent_dim)
+
+             hidden_states = jax_memory_efficient_attention(
+                 query_states, key_states, value_states, query_chunk_size=query_chunk_size, key_chunk_size=4096 * 4
+             )
+
+             hidden_states = hidden_states.transpose(1, 0, 2)
+         else:
+             # compute attention scores
+             if self.split_head_dim:
+                 attention_scores = jnp.einsum("b t n h, b f n h -> b n f t", key_states, query_states)
+             else:
+                 attention_scores = jnp.einsum("b i d, b j d->b i j", query_states, key_states)
+
+             attention_scores = attention_scores * self.scale
+             attention_probs = nn.softmax(attention_scores, axis=-1 if self.split_head_dim else 2)
+
+             # attend to values
+             if self.split_head_dim:
+                 hidden_states = jnp.einsum("b n f t, b t n h -> b f n h", attention_probs, value_states)
+                 b = hidden_states.shape[0]
+                 hidden_states = jnp.reshape(hidden_states, (b, -1, self.heads * self.dim_head))
+             else:
+                 hidden_states = jnp.einsum("b i j, b j d -> b i d", attention_probs, value_states)
+                 hidden_states = self.reshape_batch_dim_to_heads(hidden_states)
+
+         hidden_states = self.proj_attn(hidden_states)
+         return self.dropout_layer(hidden_states, deterministic=deterministic)
+
+
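+ # Illustrative usage sketch (shapes are hypothetical):
+ #     attn = FlaxAttention(query_dim=320, heads=8, dim_head=40)
+ #     x = jnp.zeros((1, 4096, 320))
+ #     params = attn.init(jax.random.PRNGKey(0), x)
+ #     y = attn.apply(params, x)  # (1, 4096, 320)
+
+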
+ class FlaxBasicTransformerBlock(nn.Module):
+     r"""
+     A Flax transformer block layer with `GLU` (Gated Linear Unit) activation function as described in:
+     https://arxiv.org/abs/1706.03762
+
+     Parameters:
+         dim (:obj:`int`):
+             Inner hidden states dimension
+         n_heads (:obj:`int`):
+             Number of heads
+         d_head (:obj:`int`):
+             Hidden states dimension inside each head
+         dropout (:obj:`float`, *optional*, defaults to 0.0):
+             Dropout rate
+         only_cross_attention (`bool`, defaults to `False`):
+             Whether to only apply cross attention.
+         dtype (:obj:`jnp.dtype`, *optional*, defaults to jnp.float32):
+             Parameters `dtype`
+         use_memory_efficient_attention (`bool`, *optional*, defaults to `False`):
+             Enable memory-efficient attention https://arxiv.org/abs/2112.05682
+         split_head_dim (`bool`, *optional*, defaults to `False`):
+             Whether to split the head dimension into a new axis for the self-attention computation. In most cases,
+             enabling this flag should speed up the computation for Stable Diffusion 2.x and Stable Diffusion XL.
+     """
+
+     dim: int
+     n_heads: int
+     d_head: int
+     dropout: float = 0.0
+     only_cross_attention: bool = False
+     dtype: jnp.dtype = jnp.float32
+     use_memory_efficient_attention: bool = False
+     split_head_dim: bool = False
+
+     def setup(self):
+         # self attention (or cross attention if only_cross_attention is True)
+         self.attn1 = FlaxAttention(
+             self.dim,
+             self.n_heads,
+             self.d_head,
+             self.dropout,
+             self.use_memory_efficient_attention,
+             self.split_head_dim,
+             dtype=self.dtype,
+         )
+         # cross attention
+         self.attn2 = FlaxAttention(
+             self.dim,
+             self.n_heads,
+             self.d_head,
+             self.dropout,
+             self.use_memory_efficient_attention,
+             self.split_head_dim,
+             dtype=self.dtype,
+         )
+         self.ff = FlaxFeedForward(dim=self.dim, dropout=self.dropout, dtype=self.dtype)
+         self.norm1 = nn.LayerNorm(epsilon=1e-5, dtype=self.dtype)
+         self.norm2 = nn.LayerNorm(epsilon=1e-5, dtype=self.dtype)
+         self.norm3 = nn.LayerNorm(epsilon=1e-5, dtype=self.dtype)
+         self.dropout_layer = nn.Dropout(rate=self.dropout)
+
+     def __call__(self, hidden_states, context, deterministic=True):
+         # self attention
+         residual = hidden_states
+         if self.only_cross_attention:
+             hidden_states = self.attn1(self.norm1(hidden_states), context, deterministic=deterministic)
+         else:
+             hidden_states = self.attn1(self.norm1(hidden_states), deterministic=deterministic)
+         hidden_states = hidden_states + residual
+
+         # cross attention
+         residual = hidden_states
+         hidden_states = self.attn2(self.norm2(hidden_states), context, deterministic=deterministic)
+         hidden_states = hidden_states + residual
+
+         # feed forward
+         residual = hidden_states
+         hidden_states = self.ff(self.norm3(hidden_states), deterministic=deterministic)
+         hidden_states = hidden_states + residual
+
+         return self.dropout_layer(hidden_states, deterministic=deterministic)
+
+
+ class FlaxTransformer2DModel(nn.Module):
+     r"""
+     A Spatial Transformer layer with Gated Linear Unit (GLU) activation function as described in:
+     https://arxiv.org/pdf/1506.02025.pdf
+
+     Parameters:
+         in_channels (:obj:`int`):
+             Input number of channels
+         n_heads (:obj:`int`):
+             Number of heads
+         d_head (:obj:`int`):
+             Hidden states dimension inside each head
+         depth (:obj:`int`, *optional*, defaults to 1):
+             Number of transformer blocks
+         dropout (:obj:`float`, *optional*, defaults to 0.0):
+             Dropout rate
+         use_linear_projection (`bool`, defaults to `False`):
+             Whether to use a linear (`nn.Dense`) input/output projection instead of a 1x1 convolution
+         only_cross_attention (`bool`, defaults to `False`):
+             Whether to only apply cross attention in the transformer blocks
+         dtype (:obj:`jnp.dtype`, *optional*, defaults to jnp.float32):
+             Parameters `dtype`
+         use_memory_efficient_attention (`bool`, *optional*, defaults to `False`):
+             Enable memory-efficient attention https://arxiv.org/abs/2112.05682
+         split_head_dim (`bool`, *optional*, defaults to `False`):
+             Whether to split the head dimension into a new axis for the self-attention computation. In most cases,
+             enabling this flag should speed up the computation for Stable Diffusion 2.x and Stable Diffusion XL.
+     """
+
+     in_channels: int
+     n_heads: int
+     d_head: int
+     depth: int = 1
+     dropout: float = 0.0
+     use_linear_projection: bool = False
+     only_cross_attention: bool = False
+     dtype: jnp.dtype = jnp.float32
+     use_memory_efficient_attention: bool = False
+     split_head_dim: bool = False
+
+     def setup(self):
+         self.norm = nn.GroupNorm(num_groups=32, epsilon=1e-5)
+
+         inner_dim = self.n_heads * self.d_head
+         if self.use_linear_projection:
+             self.proj_in = nn.Dense(inner_dim, dtype=self.dtype)
+         else:
+             self.proj_in = nn.Conv(
+                 inner_dim,
+                 kernel_size=(1, 1),
+                 strides=(1, 1),
+                 padding="VALID",
+                 dtype=self.dtype,
+             )
+
+         self.transformer_blocks = [
+             FlaxBasicTransformerBlock(
+                 inner_dim,
+                 self.n_heads,
+                 self.d_head,
+                 dropout=self.dropout,
+                 only_cross_attention=self.only_cross_attention,
+                 dtype=self.dtype,
+                 use_memory_efficient_attention=self.use_memory_efficient_attention,
+                 split_head_dim=self.split_head_dim,
+             )
+             for _ in range(self.depth)
+         ]
+
+         if self.use_linear_projection:
+             self.proj_out = nn.Dense(inner_dim, dtype=self.dtype)
+         else:
+             self.proj_out = nn.Conv(
+                 inner_dim,
+                 kernel_size=(1, 1),
+                 strides=(1, 1),
+                 padding="VALID",
+                 dtype=self.dtype,
+             )
+
+         self.dropout_layer = nn.Dropout(rate=self.dropout)
+
+     def __call__(self, hidden_states, context, deterministic=True):
+         batch, height, width, channels = hidden_states.shape
+         residual = hidden_states
+         hidden_states = self.norm(hidden_states)
+         if self.use_linear_projection:
+             hidden_states = hidden_states.reshape(batch, height * width, channels)
+             hidden_states = self.proj_in(hidden_states)
+         else:
+             hidden_states = self.proj_in(hidden_states)
+             hidden_states = hidden_states.reshape(batch, height * width, channels)
+
+         for transformer_block in self.transformer_blocks:
+             hidden_states = transformer_block(hidden_states, context, deterministic=deterministic)
+
+         if self.use_linear_projection:
+             hidden_states = self.proj_out(hidden_states)
+             hidden_states = hidden_states.reshape(batch, height, width, channels)
+         else:
+             hidden_states = hidden_states.reshape(batch, height, width, channels)
+             hidden_states = self.proj_out(hidden_states)
+
+         hidden_states = hidden_states + residual
+         return self.dropout_layer(hidden_states, deterministic=deterministic)
+
+
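+ # Illustrative usage sketch (shapes are hypothetical; note the NHWC layout):
+ #     model = FlaxTransformer2DModel(in_channels=320, n_heads=8, d_head=40, depth=1)
+ #     x = jnp.zeros((1, 64, 64, 320))   # (batch, height, width, channels)
+ #     ctx = jnp.zeros((1, 77, 768))     # encoder hidden states
+ #     params = model.init(jax.random.PRNGKey(0), x, ctx)
+ #     y = model.apply(params, x, ctx)   # (1, 64, 64, 320)
+
+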
+ class FlaxFeedForward(nn.Module):
+     r"""
+     Flax module that encapsulates two Linear layers separated by a non-linearity. It is the counterpart of PyTorch's
+     [`FeedForward`] class, with the following simplifications:
+     - The activation function is currently hardcoded to a gated linear unit from:
+       https://arxiv.org/abs/2002.05202
+     - `dim_out` is equal to `dim`.
+     - The number of hidden dimensions is hardcoded to `dim * 4` in [`FlaxGEGLU`].
+
+     Parameters:
+         dim (:obj:`int`):
+             Inner hidden states dimension
+         dropout (:obj:`float`, *optional*, defaults to 0.0):
+             Dropout rate
+         dtype (:obj:`jnp.dtype`, *optional*, defaults to jnp.float32):
+             Parameters `dtype`
+     """
+
+     dim: int
+     dropout: float = 0.0
+     dtype: jnp.dtype = jnp.float32
+
+     def setup(self):
+         # The second linear layer needs to be called
+         # net_2 for now to match the index of the Sequential layer
+         self.net_0 = FlaxGEGLU(self.dim, self.dropout, self.dtype)
+         self.net_2 = nn.Dense(self.dim, dtype=self.dtype)
+
+     def __call__(self, hidden_states, deterministic=True):
+         hidden_states = self.net_0(hidden_states, deterministic=deterministic)
+         hidden_states = self.net_2(hidden_states)
+         return hidden_states
+
+
+ class FlaxGEGLU(nn.Module):
+     r"""
+     Flax implementation of a Linear layer followed by the variant of the gated linear unit activation function from
+     https://arxiv.org/abs/2002.05202.
+
+     Parameters:
+         dim (:obj:`int`):
+             Input hidden states dimension
+         dropout (:obj:`float`, *optional*, defaults to 0.0):
+             Dropout rate
+         dtype (:obj:`jnp.dtype`, *optional*, defaults to jnp.float32):
+             Parameters `dtype`
+     """
+
+     dim: int
+     dropout: float = 0.0
+     dtype: jnp.dtype = jnp.float32
+
+     def setup(self):
+         inner_dim = self.dim * 4
+         self.proj = nn.Dense(inner_dim * 2, dtype=self.dtype)
+         self.dropout_layer = nn.Dropout(rate=self.dropout)
+
+     def __call__(self, hidden_states, deterministic=True):
+         hidden_states = self.proj(hidden_states)
+         hidden_linear, hidden_gelu = jnp.split(hidden_states, 2, axis=2)
+         return self.dropout_layer(hidden_linear * nn.gelu(hidden_gelu), deterministic=deterministic)
diffusers3/models/attention_processor.py ADDED
The diff for this file is too large to render. See raw diff
 
diffusers3/models/autoencoders/__init__.py ADDED
@@ -0,0 +1,8 @@
+ from .autoencoder_asym_kl import AsymmetricAutoencoderKL
+ from .autoencoder_kl import AutoencoderKL
+ from .autoencoder_kl_cogvideox import AutoencoderKLCogVideoX
+ from .autoencoder_kl_temporal_decoder import AutoencoderKLTemporalDecoder
+ from .autoencoder_oobleck import AutoencoderOobleck
+ from .autoencoder_tiny import AutoencoderTiny
+ from .consistency_decoder_vae import ConsistencyDecoderVAE
+ from .vq_model import VQModel
diffusers3/models/autoencoders/__pycache__/__init__.cpython-310.pyc ADDED
Binary file (658 Bytes).

diffusers3/models/autoencoders/__pycache__/__init__.cpython-38.pyc ADDED
Binary file (663 Bytes).

diffusers3/models/autoencoders/__pycache__/autoencoder_asym_kl.cpython-310.pyc ADDED
Binary file (6.57 kB).

diffusers3/models/autoencoders/__pycache__/autoencoder_asym_kl.cpython-38.pyc ADDED
Binary file (6.48 kB).

diffusers3/models/autoencoders/__pycache__/autoencoder_kl.cpython-310.pyc ADDED
Binary file (18.7 kB).

diffusers3/models/autoencoders/__pycache__/autoencoder_kl.cpython-38.pyc ADDED
Binary file (18.6 kB).

diffusers3/models/autoencoders/__pycache__/autoencoder_kl_cogvideox.cpython-310.pyc ADDED
Binary file (40.6 kB).

diffusers3/models/autoencoders/__pycache__/autoencoder_kl_cogvideox.cpython-38.pyc ADDED
Binary file (40.2 kB).

diffusers3/models/autoencoders/__pycache__/autoencoder_kl_temporal_decoder.cpython-310.pyc ADDED
Binary file (13.3 kB).

diffusers3/models/autoencoders/__pycache__/autoencoder_kl_temporal_decoder.cpython-38.pyc ADDED
Binary file (13.2 kB).