xiaoanyu123 commited on
Commit
0ac5f99
·
verified ·
1 Parent(s): ee2f895

Add files using upload-large-folder tool

Browse files
This view is limited to 50 files because it contains too many changes.   See raw diff
Files changed (50) hide show
  1. pythonProject/.venv/Lib/site-packages/diffusers/models/controlnets/__pycache__/__init__.cpython-310.pyc +0 -0
  2. pythonProject/.venv/Lib/site-packages/diffusers/models/controlnets/__pycache__/multicontrolnet.cpython-310.pyc +0 -0
  3. pythonProject/.venv/Lib/site-packages/diffusers/models/controlnets/__pycache__/multicontrolnet_union.cpython-310.pyc +0 -0
  4. pythonProject/.venv/Lib/site-packages/diffusers/models/transformers/__pycache__/__init__.cpython-310.pyc +0 -0
  5. pythonProject/.venv/Lib/site-packages/diffusers/models/transformers/__pycache__/stable_audio_transformer.cpython-310.pyc +0 -0
  6. pythonProject/.venv/Lib/site-packages/diffusers/models/transformers/__pycache__/t5_film_transformer.cpython-310.pyc +0 -0
  7. pythonProject/.venv/Lib/site-packages/diffusers/models/transformers/__pycache__/transformer_2d.cpython-310.pyc +0 -0
  8. pythonProject/.venv/Lib/site-packages/diffusers/models/transformers/__pycache__/transformer_allegro.cpython-310.pyc +0 -0
  9. pythonProject/.venv/Lib/site-packages/diffusers/models/transformers/__pycache__/transformer_bria.cpython-310.pyc +0 -0
  10. pythonProject/.venv/Lib/site-packages/diffusers/models/transformers/__pycache__/transformer_chroma.cpython-310.pyc +0 -0
  11. pythonProject/.venv/Lib/site-packages/diffusers/models/transformers/__pycache__/transformer_cogview3plus.cpython-310.pyc +0 -0
  12. pythonProject/.venv/Lib/site-packages/diffusers/models/transformers/__pycache__/transformer_cogview4.cpython-310.pyc +0 -0
  13. pythonProject/.venv/Lib/site-packages/diffusers/models/transformers/__pycache__/transformer_cosmos.cpython-310.pyc +0 -0
  14. pythonProject/.venv/Lib/site-packages/diffusers/models/transformers/__pycache__/transformer_easyanimate.cpython-310.pyc +0 -0
  15. pythonProject/.venv/Lib/site-packages/diffusers/models/transformers/__pycache__/transformer_flux.cpython-310.pyc +0 -0
  16. pythonProject/.venv/Lib/site-packages/diffusers/models/transformers/__pycache__/transformer_hidream_image.cpython-310.pyc +0 -0
  17. pythonProject/.venv/Lib/site-packages/diffusers/models/transformers/__pycache__/transformer_hunyuan_video.cpython-310.pyc +0 -0
  18. pythonProject/.venv/Lib/site-packages/diffusers/models/transformers/__pycache__/transformer_hunyuan_video_framepack.cpython-310.pyc +0 -0
  19. pythonProject/.venv/Lib/site-packages/diffusers/models/transformers/__pycache__/transformer_ltx.cpython-310.pyc +0 -0
  20. pythonProject/.venv/Lib/site-packages/diffusers/models/transformers/__pycache__/transformer_lumina2.cpython-310.pyc +0 -0
  21. pythonProject/.venv/Lib/site-packages/diffusers/models/transformers/__pycache__/transformer_mochi.cpython-310.pyc +0 -0
  22. pythonProject/.venv/Lib/site-packages/diffusers/models/transformers/__pycache__/transformer_omnigen.cpython-310.pyc +0 -0
  23. pythonProject/.venv/Lib/site-packages/diffusers/models/transformers/__pycache__/transformer_qwenimage.cpython-310.pyc +0 -0
  24. pythonProject/.venv/Lib/site-packages/diffusers/models/transformers/__pycache__/transformer_sd3.cpython-310.pyc +0 -0
  25. pythonProject/.venv/Lib/site-packages/diffusers/models/transformers/__pycache__/transformer_skyreels_v2.cpython-310.pyc +0 -0
  26. pythonProject/.venv/Lib/site-packages/diffusers/models/transformers/__pycache__/transformer_temporal.cpython-310.pyc +0 -0
  27. pythonProject/.venv/Lib/site-packages/diffusers/models/transformers/__pycache__/transformer_wan.cpython-310.pyc +0 -0
  28. pythonProject/.venv/Lib/site-packages/diffusers/models/transformers/__pycache__/transformer_wan_vace.cpython-310.pyc +0 -0
  29. pythonProject/.venv/Lib/site-packages/diffusers/models/transformers/auraflow_transformer_2d.py +564 -0
  30. pythonProject/.venv/Lib/site-packages/diffusers/models/transformers/cogvideox_transformer_3d.py +531 -0
  31. pythonProject/.venv/Lib/site-packages/diffusers/models/transformers/consisid_transformer_3d.py +789 -0
  32. pythonProject/.venv/Lib/site-packages/diffusers/models/transformers/dit_transformer_2d.py +226 -0
  33. pythonProject/.venv/Lib/site-packages/diffusers/models/transformers/dual_transformer_2d.py +156 -0
  34. pythonProject/.venv/Lib/site-packages/diffusers/models/transformers/hunyuan_transformer_2d.py +579 -0
  35. pythonProject/.venv/Lib/site-packages/diffusers/models/transformers/latte_transformer_3d.py +331 -0
  36. pythonProject/.venv/Lib/site-packages/diffusers/models/transformers/t5_film_transformer.py +436 -0
  37. pythonProject/.venv/Lib/site-packages/diffusers/models/transformers/transformer_hunyuan_video_framepack.py +416 -0
  38. pythonProject/.venv/Lib/site-packages/diffusers/models/transformers/transformer_ltx.py +568 -0
  39. pythonProject/.venv/Lib/site-packages/diffusers/models/transformers/transformer_lumina2.py +548 -0
  40. pythonProject/.venv/Lib/site-packages/diffusers/models/unets/__init__.py +18 -0
  41. pythonProject/.venv/Lib/site-packages/diffusers/models/unets/__pycache__/__init__.cpython-310.pyc +0 -0
  42. pythonProject/.venv/Lib/site-packages/diffusers/models/unets/__pycache__/unet_1d.cpython-310.pyc +0 -0
  43. pythonProject/.venv/Lib/site-packages/diffusers/models/unets/__pycache__/unet_1d_blocks.cpython-310.pyc +0 -0
  44. pythonProject/.venv/Lib/site-packages/diffusers/models/unets/__pycache__/unet_2d.cpython-310.pyc +0 -0
  45. pythonProject/.venv/Lib/site-packages/diffusers/models/unets/__pycache__/unet_2d_blocks.cpython-310.pyc +0 -0
  46. pythonProject/.venv/Lib/site-packages/diffusers/models/unets/__pycache__/unet_2d_blocks_flax.cpython-310.pyc +0 -0
  47. pythonProject/.venv/Lib/site-packages/diffusers/models/unets/__pycache__/unet_2d_condition.cpython-310.pyc +0 -0
  48. pythonProject/.venv/Lib/site-packages/diffusers/models/unets/__pycache__/unet_2d_condition_flax.cpython-310.pyc +0 -0
  49. pythonProject/.venv/Lib/site-packages/diffusers/models/unets/__pycache__/unet_3d_blocks.cpython-310.pyc +0 -0
  50. pythonProject/.venv/Lib/site-packages/diffusers/models/unets/__pycache__/unet_3d_condition.cpython-310.pyc +0 -0
pythonProject/.venv/Lib/site-packages/diffusers/models/controlnets/__pycache__/__init__.cpython-310.pyc ADDED
Binary file (1.44 kB). View file
 
pythonProject/.venv/Lib/site-packages/diffusers/models/controlnets/__pycache__/multicontrolnet.cpython-310.pyc ADDED
Binary file (8.55 kB). View file
 
pythonProject/.venv/Lib/site-packages/diffusers/models/controlnets/__pycache__/multicontrolnet_union.cpython-310.pyc ADDED
Binary file (8.86 kB). View file
 
pythonProject/.venv/Lib/site-packages/diffusers/models/transformers/__pycache__/__init__.cpython-310.pyc ADDED
Binary file (2.67 kB). View file
 
pythonProject/.venv/Lib/site-packages/diffusers/models/transformers/__pycache__/stable_audio_transformer.cpython-310.pyc ADDED
Binary file (14.1 kB). View file
 
pythonProject/.venv/Lib/site-packages/diffusers/models/transformers/__pycache__/t5_film_transformer.cpython-310.pyc ADDED
Binary file (13.7 kB). View file
 
pythonProject/.venv/Lib/site-packages/diffusers/models/transformers/__pycache__/transformer_2d.cpython-310.pyc ADDED
Binary file (16.6 kB). View file
 
pythonProject/.venv/Lib/site-packages/diffusers/models/transformers/__pycache__/transformer_allegro.cpython-310.pyc ADDED
Binary file (8.29 kB). View file
 
pythonProject/.venv/Lib/site-packages/diffusers/models/transformers/__pycache__/transformer_bria.cpython-310.pyc ADDED
Binary file (21.3 kB). View file
 
pythonProject/.venv/Lib/site-packages/diffusers/models/transformers/__pycache__/transformer_chroma.cpython-310.pyc ADDED
Binary file (19.2 kB). View file
 
pythonProject/.venv/Lib/site-packages/diffusers/models/transformers/__pycache__/transformer_cogview3plus.cpython-310.pyc ADDED
Binary file (12.2 kB). View file
 
pythonProject/.venv/Lib/site-packages/diffusers/models/transformers/__pycache__/transformer_cogview4.cpython-310.pyc ADDED
Binary file (21.5 kB). View file
 
pythonProject/.venv/Lib/site-packages/diffusers/models/transformers/__pycache__/transformer_cosmos.cpython-310.pyc ADDED
Binary file (18.3 kB). View file
 
pythonProject/.venv/Lib/site-packages/diffusers/models/transformers/__pycache__/transformer_easyanimate.cpython-310.pyc ADDED
Binary file (14.7 kB). View file
 
pythonProject/.venv/Lib/site-packages/diffusers/models/transformers/__pycache__/transformer_flux.cpython-310.pyc ADDED
Binary file (21 kB). View file
 
pythonProject/.venv/Lib/site-packages/diffusers/models/transformers/__pycache__/transformer_hidream_image.cpython-310.pyc ADDED
Binary file (25.5 kB). View file
 
pythonProject/.venv/Lib/site-packages/diffusers/models/transformers/__pycache__/transformer_hunyuan_video.cpython-310.pyc ADDED
Binary file (30.1 kB). View file
 
pythonProject/.venv/Lib/site-packages/diffusers/models/transformers/__pycache__/transformer_hunyuan_video_framepack.cpython-310.pyc ADDED
Binary file (11.9 kB). View file
 
pythonProject/.venv/Lib/site-packages/diffusers/models/transformers/__pycache__/transformer_ltx.cpython-310.pyc ADDED
Binary file (16.1 kB). View file
 
pythonProject/.venv/Lib/site-packages/diffusers/models/transformers/__pycache__/transformer_lumina2.cpython-310.pyc ADDED
Binary file (15.1 kB). View file
 
pythonProject/.venv/Lib/site-packages/diffusers/models/transformers/__pycache__/transformer_mochi.cpython-310.pyc ADDED
Binary file (14 kB). View file
 
pythonProject/.venv/Lib/site-packages/diffusers/models/transformers/__pycache__/transformer_omnigen.cpython-310.pyc ADDED
Binary file (14.8 kB). View file
 
pythonProject/.venv/Lib/site-packages/diffusers/models/transformers/__pycache__/transformer_qwenimage.cpython-310.pyc ADDED
Binary file (19.4 kB). View file
 
pythonProject/.venv/Lib/site-packages/diffusers/models/transformers/__pycache__/transformer_sd3.cpython-310.pyc ADDED
Binary file (14.9 kB). View file
 
pythonProject/.venv/Lib/site-packages/diffusers/models/transformers/__pycache__/transformer_skyreels_v2.cpython-310.pyc ADDED
Binary file (21.3 kB). View file
 
pythonProject/.venv/Lib/site-packages/diffusers/models/transformers/__pycache__/transformer_temporal.cpython-310.pyc ADDED
Binary file (12.9 kB). View file
 
pythonProject/.venv/Lib/site-packages/diffusers/models/transformers/__pycache__/transformer_wan.cpython-310.pyc ADDED
Binary file (18.9 kB). View file
 
pythonProject/.venv/Lib/site-packages/diffusers/models/transformers/__pycache__/transformer_wan_vace.cpython-310.pyc ADDED
Binary file (10 kB). View file
 
pythonProject/.venv/Lib/site-packages/diffusers/models/transformers/auraflow_transformer_2d.py ADDED
@@ -0,0 +1,564 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright 2025 AuraFlow Authors, The HuggingFace Team. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+
16
+ from typing import Any, Dict, Optional, Union
17
+
18
+ import torch
19
+ import torch.nn as nn
20
+ import torch.nn.functional as F
21
+
22
+ from ...configuration_utils import ConfigMixin, register_to_config
23
+ from ...loaders import FromOriginalModelMixin, PeftAdapterMixin
24
+ from ...utils import USE_PEFT_BACKEND, logging, scale_lora_layers, unscale_lora_layers
25
+ from ...utils.torch_utils import maybe_allow_in_graph
26
+ from ..attention_processor import (
27
+ Attention,
28
+ AttentionProcessor,
29
+ AuraFlowAttnProcessor2_0,
30
+ FusedAuraFlowAttnProcessor2_0,
31
+ )
32
+ from ..embeddings import TimestepEmbedding, Timesteps
33
+ from ..modeling_outputs import Transformer2DModelOutput
34
+ from ..modeling_utils import ModelMixin
35
+ from ..normalization import AdaLayerNormZero, FP32LayerNorm
36
+
37
+
38
+ logger = logging.get_logger(__name__) # pylint: disable=invalid-name
39
+
40
+
41
+ # Taken from the original aura flow inference code.
42
+ def find_multiple(n: int, k: int) -> int:
43
+ if n % k == 0:
44
+ return n
45
+ return n + k - (n % k)
46
+
47
+
48
+ # Aura Flow patch embed doesn't use convs for projections.
49
+ # Additionally, it uses learned positional embeddings.
50
+ class AuraFlowPatchEmbed(nn.Module):
51
+ def __init__(
52
+ self,
53
+ height=224,
54
+ width=224,
55
+ patch_size=16,
56
+ in_channels=3,
57
+ embed_dim=768,
58
+ pos_embed_max_size=None,
59
+ ):
60
+ super().__init__()
61
+
62
+ self.num_patches = (height // patch_size) * (width // patch_size)
63
+ self.pos_embed_max_size = pos_embed_max_size
64
+
65
+ self.proj = nn.Linear(patch_size * patch_size * in_channels, embed_dim)
66
+ self.pos_embed = nn.Parameter(torch.randn(1, pos_embed_max_size, embed_dim) * 0.1)
67
+
68
+ self.patch_size = patch_size
69
+ self.height, self.width = height // patch_size, width // patch_size
70
+ self.base_size = height // patch_size
71
+
72
+ def pe_selection_index_based_on_dim(self, h, w):
73
+ # select subset of positional embedding based on H, W, where H, W is size of latent
74
+ # PE will be viewed as 2d-grid, and H/p x W/p of the PE will be selected
75
+ # because original input are in flattened format, we have to flatten this 2d grid as well.
76
+ h_p, w_p = h // self.patch_size, w // self.patch_size
77
+ h_max, w_max = int(self.pos_embed_max_size**0.5), int(self.pos_embed_max_size**0.5)
78
+
79
+ # Calculate the top-left corner indices for the centered patch grid
80
+ starth = h_max // 2 - h_p // 2
81
+ startw = w_max // 2 - w_p // 2
82
+
83
+ # Generate the row and column indices for the desired patch grid
84
+ rows = torch.arange(starth, starth + h_p, device=self.pos_embed.device)
85
+ cols = torch.arange(startw, startw + w_p, device=self.pos_embed.device)
86
+
87
+ # Create a 2D grid of indices
88
+ row_indices, col_indices = torch.meshgrid(rows, cols, indexing="ij")
89
+
90
+ # Convert the 2D grid indices to flattened 1D indices
91
+ selected_indices = (row_indices * w_max + col_indices).flatten()
92
+
93
+ return selected_indices
94
+
95
+ def forward(self, latent):
96
+ batch_size, num_channels, height, width = latent.size()
97
+ latent = latent.view(
98
+ batch_size,
99
+ num_channels,
100
+ height // self.patch_size,
101
+ self.patch_size,
102
+ width // self.patch_size,
103
+ self.patch_size,
104
+ )
105
+ latent = latent.permute(0, 2, 4, 1, 3, 5).flatten(-3).flatten(1, 2)
106
+ latent = self.proj(latent)
107
+ pe_index = self.pe_selection_index_based_on_dim(height, width)
108
+ return latent + self.pos_embed[:, pe_index]
109
+
110
+
111
+ # Taken from the original Aura flow inference code.
112
+ # Our feedforward only has GELU but Aura uses SiLU.
113
+ class AuraFlowFeedForward(nn.Module):
114
+ def __init__(self, dim, hidden_dim=None) -> None:
115
+ super().__init__()
116
+ if hidden_dim is None:
117
+ hidden_dim = 4 * dim
118
+
119
+ final_hidden_dim = int(2 * hidden_dim / 3)
120
+ final_hidden_dim = find_multiple(final_hidden_dim, 256)
121
+
122
+ self.linear_1 = nn.Linear(dim, final_hidden_dim, bias=False)
123
+ self.linear_2 = nn.Linear(dim, final_hidden_dim, bias=False)
124
+ self.out_projection = nn.Linear(final_hidden_dim, dim, bias=False)
125
+
126
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
127
+ x = F.silu(self.linear_1(x)) * self.linear_2(x)
128
+ x = self.out_projection(x)
129
+ return x
130
+
131
+
132
+ class AuraFlowPreFinalBlock(nn.Module):
133
+ def __init__(self, embedding_dim: int, conditioning_embedding_dim: int):
134
+ super().__init__()
135
+
136
+ self.silu = nn.SiLU()
137
+ self.linear = nn.Linear(conditioning_embedding_dim, embedding_dim * 2, bias=False)
138
+
139
+ def forward(self, x: torch.Tensor, conditioning_embedding: torch.Tensor) -> torch.Tensor:
140
+ emb = self.linear(self.silu(conditioning_embedding).to(x.dtype))
141
+ scale, shift = torch.chunk(emb, 2, dim=1)
142
+ x = x * (1 + scale)[:, None, :] + shift[:, None, :]
143
+ return x
144
+
145
+
146
+ @maybe_allow_in_graph
147
+ class AuraFlowSingleTransformerBlock(nn.Module):
148
+ """Similar to `AuraFlowJointTransformerBlock` with a single DiT instead of an MMDiT."""
149
+
150
+ def __init__(self, dim, num_attention_heads, attention_head_dim):
151
+ super().__init__()
152
+
153
+ self.norm1 = AdaLayerNormZero(dim, bias=False, norm_type="fp32_layer_norm")
154
+
155
+ processor = AuraFlowAttnProcessor2_0()
156
+ self.attn = Attention(
157
+ query_dim=dim,
158
+ cross_attention_dim=None,
159
+ dim_head=attention_head_dim,
160
+ heads=num_attention_heads,
161
+ qk_norm="fp32_layer_norm",
162
+ out_dim=dim,
163
+ bias=False,
164
+ out_bias=False,
165
+ processor=processor,
166
+ )
167
+
168
+ self.norm2 = FP32LayerNorm(dim, elementwise_affine=False, bias=False)
169
+ self.ff = AuraFlowFeedForward(dim, dim * 4)
170
+
171
+ def forward(
172
+ self,
173
+ hidden_states: torch.FloatTensor,
174
+ temb: torch.FloatTensor,
175
+ attention_kwargs: Optional[Dict[str, Any]] = None,
176
+ ):
177
+ residual = hidden_states
178
+ attention_kwargs = attention_kwargs or {}
179
+
180
+ # Norm + Projection.
181
+ norm_hidden_states, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.norm1(hidden_states, emb=temb)
182
+
183
+ # Attention.
184
+ attn_output = self.attn(hidden_states=norm_hidden_states, **attention_kwargs)
185
+
186
+ # Process attention outputs for the `hidden_states`.
187
+ hidden_states = self.norm2(residual + gate_msa.unsqueeze(1) * attn_output)
188
+ hidden_states = hidden_states * (1 + scale_mlp[:, None]) + shift_mlp[:, None]
189
+ ff_output = self.ff(hidden_states)
190
+ hidden_states = gate_mlp.unsqueeze(1) * ff_output
191
+ hidden_states = residual + hidden_states
192
+
193
+ return hidden_states
194
+
195
+
196
+ @maybe_allow_in_graph
197
+ class AuraFlowJointTransformerBlock(nn.Module):
198
+ r"""
199
+ Transformer block for Aura Flow. Similar to SD3 MMDiT. Differences (non-exhaustive):
200
+
201
+ * QK Norm in the attention blocks
202
+ * No bias in the attention blocks
203
+ * Most LayerNorms are in FP32
204
+
205
+ Parameters:
206
+ dim (`int`): The number of channels in the input and output.
207
+ num_attention_heads (`int`): The number of heads to use for multi-head attention.
208
+ attention_head_dim (`int`): The number of channels in each head.
209
+ is_last (`bool`): Boolean to determine if this is the last block in the model.
210
+ """
211
+
212
+ def __init__(self, dim, num_attention_heads, attention_head_dim):
213
+ super().__init__()
214
+
215
+ self.norm1 = AdaLayerNormZero(dim, bias=False, norm_type="fp32_layer_norm")
216
+ self.norm1_context = AdaLayerNormZero(dim, bias=False, norm_type="fp32_layer_norm")
217
+
218
+ processor = AuraFlowAttnProcessor2_0()
219
+ self.attn = Attention(
220
+ query_dim=dim,
221
+ cross_attention_dim=None,
222
+ added_kv_proj_dim=dim,
223
+ added_proj_bias=False,
224
+ dim_head=attention_head_dim,
225
+ heads=num_attention_heads,
226
+ qk_norm="fp32_layer_norm",
227
+ out_dim=dim,
228
+ bias=False,
229
+ out_bias=False,
230
+ processor=processor,
231
+ context_pre_only=False,
232
+ )
233
+
234
+ self.norm2 = FP32LayerNorm(dim, elementwise_affine=False, bias=False)
235
+ self.ff = AuraFlowFeedForward(dim, dim * 4)
236
+ self.norm2_context = FP32LayerNorm(dim, elementwise_affine=False, bias=False)
237
+ self.ff_context = AuraFlowFeedForward(dim, dim * 4)
238
+
239
+ def forward(
240
+ self,
241
+ hidden_states: torch.FloatTensor,
242
+ encoder_hidden_states: torch.FloatTensor,
243
+ temb: torch.FloatTensor,
244
+ attention_kwargs: Optional[Dict[str, Any]] = None,
245
+ ):
246
+ residual = hidden_states
247
+ residual_context = encoder_hidden_states
248
+ attention_kwargs = attention_kwargs or {}
249
+
250
+ # Norm + Projection.
251
+ norm_hidden_states, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.norm1(hidden_states, emb=temb)
252
+ norm_encoder_hidden_states, c_gate_msa, c_shift_mlp, c_scale_mlp, c_gate_mlp = self.norm1_context(
253
+ encoder_hidden_states, emb=temb
254
+ )
255
+
256
+ # Attention.
257
+ attn_output, context_attn_output = self.attn(
258
+ hidden_states=norm_hidden_states,
259
+ encoder_hidden_states=norm_encoder_hidden_states,
260
+ **attention_kwargs,
261
+ )
262
+
263
+ # Process attention outputs for the `hidden_states`.
264
+ hidden_states = self.norm2(residual + gate_msa.unsqueeze(1) * attn_output)
265
+ hidden_states = hidden_states * (1 + scale_mlp[:, None]) + shift_mlp[:, None]
266
+ hidden_states = gate_mlp.unsqueeze(1) * self.ff(hidden_states)
267
+ hidden_states = residual + hidden_states
268
+
269
+ # Process attention outputs for the `encoder_hidden_states`.
270
+ encoder_hidden_states = self.norm2_context(residual_context + c_gate_msa.unsqueeze(1) * context_attn_output)
271
+ encoder_hidden_states = encoder_hidden_states * (1 + c_scale_mlp[:, None]) + c_shift_mlp[:, None]
272
+ encoder_hidden_states = c_gate_mlp.unsqueeze(1) * self.ff_context(encoder_hidden_states)
273
+ encoder_hidden_states = residual_context + encoder_hidden_states
274
+
275
+ return encoder_hidden_states, hidden_states
276
+
277
+
278
+ class AuraFlowTransformer2DModel(ModelMixin, ConfigMixin, PeftAdapterMixin, FromOriginalModelMixin):
279
+ r"""
280
+ A 2D Transformer model as introduced in AuraFlow (https://blog.fal.ai/auraflow/).
281
+
282
+ Parameters:
283
+ sample_size (`int`): The width of the latent images. This is fixed during training since
284
+ it is used to learn a number of position embeddings.
285
+ patch_size (`int`): Patch size to turn the input data into small patches.
286
+ in_channels (`int`, *optional*, defaults to 4): The number of channels in the input.
287
+ num_mmdit_layers (`int`, *optional*, defaults to 4): The number of layers of MMDiT Transformer blocks to use.
288
+ num_single_dit_layers (`int`, *optional*, defaults to 32):
289
+ The number of layers of Transformer blocks to use. These blocks use concatenated image and text
290
+ representations.
291
+ attention_head_dim (`int`, *optional*, defaults to 256): The number of channels in each head.
292
+ num_attention_heads (`int`, *optional*, defaults to 12): The number of heads to use for multi-head attention.
293
+ joint_attention_dim (`int`, *optional*): The number of `encoder_hidden_states` dimensions to use.
294
+ caption_projection_dim (`int`): Number of dimensions to use when projecting the `encoder_hidden_states`.
295
+ out_channels (`int`, defaults to 4): Number of output channels.
296
+ pos_embed_max_size (`int`, defaults to 1024): Maximum positions to embed from the image latents.
297
+ """
298
+
299
+ _no_split_modules = ["AuraFlowJointTransformerBlock", "AuraFlowSingleTransformerBlock", "AuraFlowPatchEmbed"]
300
+ _skip_layerwise_casting_patterns = ["pos_embed", "norm"]
301
+ _supports_gradient_checkpointing = True
302
+
303
+ @register_to_config
304
+ def __init__(
305
+ self,
306
+ sample_size: int = 64,
307
+ patch_size: int = 2,
308
+ in_channels: int = 4,
309
+ num_mmdit_layers: int = 4,
310
+ num_single_dit_layers: int = 32,
311
+ attention_head_dim: int = 256,
312
+ num_attention_heads: int = 12,
313
+ joint_attention_dim: int = 2048,
314
+ caption_projection_dim: int = 3072,
315
+ out_channels: int = 4,
316
+ pos_embed_max_size: int = 1024,
317
+ ):
318
+ super().__init__()
319
+ default_out_channels = in_channels
320
+ self.out_channels = out_channels if out_channels is not None else default_out_channels
321
+ self.inner_dim = self.config.num_attention_heads * self.config.attention_head_dim
322
+
323
+ self.pos_embed = AuraFlowPatchEmbed(
324
+ height=self.config.sample_size,
325
+ width=self.config.sample_size,
326
+ patch_size=self.config.patch_size,
327
+ in_channels=self.config.in_channels,
328
+ embed_dim=self.inner_dim,
329
+ pos_embed_max_size=pos_embed_max_size,
330
+ )
331
+
332
+ self.context_embedder = nn.Linear(
333
+ self.config.joint_attention_dim, self.config.caption_projection_dim, bias=False
334
+ )
335
+ self.time_step_embed = Timesteps(num_channels=256, downscale_freq_shift=0, scale=1000, flip_sin_to_cos=True)
336
+ self.time_step_proj = TimestepEmbedding(in_channels=256, time_embed_dim=self.inner_dim)
337
+
338
+ self.joint_transformer_blocks = nn.ModuleList(
339
+ [
340
+ AuraFlowJointTransformerBlock(
341
+ dim=self.inner_dim,
342
+ num_attention_heads=self.config.num_attention_heads,
343
+ attention_head_dim=self.config.attention_head_dim,
344
+ )
345
+ for i in range(self.config.num_mmdit_layers)
346
+ ]
347
+ )
348
+ self.single_transformer_blocks = nn.ModuleList(
349
+ [
350
+ AuraFlowSingleTransformerBlock(
351
+ dim=self.inner_dim,
352
+ num_attention_heads=self.config.num_attention_heads,
353
+ attention_head_dim=self.config.attention_head_dim,
354
+ )
355
+ for _ in range(self.config.num_single_dit_layers)
356
+ ]
357
+ )
358
+
359
+ self.norm_out = AuraFlowPreFinalBlock(self.inner_dim, self.inner_dim)
360
+ self.proj_out = nn.Linear(self.inner_dim, patch_size * patch_size * self.out_channels, bias=False)
361
+
362
+ # https://huggingface.co/papers/2309.16588
363
+ # prevents artifacts in the attention maps
364
+ self.register_tokens = nn.Parameter(torch.randn(1, 8, self.inner_dim) * 0.02)
365
+
366
+ self.gradient_checkpointing = False
367
+
368
+ @property
369
+ # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.attn_processors
370
+ def attn_processors(self) -> Dict[str, AttentionProcessor]:
371
+ r"""
372
+ Returns:
373
+ `dict` of attention processors: A dictionary containing all attention processors used in the model with
374
+ indexed by its weight name.
375
+ """
376
+ # set recursively
377
+ processors = {}
378
+
379
+ def fn_recursive_add_processors(name: str, module: torch.nn.Module, processors: Dict[str, AttentionProcessor]):
380
+ if hasattr(module, "get_processor"):
381
+ processors[f"{name}.processor"] = module.get_processor()
382
+
383
+ for sub_name, child in module.named_children():
384
+ fn_recursive_add_processors(f"{name}.{sub_name}", child, processors)
385
+
386
+ return processors
387
+
388
+ for name, module in self.named_children():
389
+ fn_recursive_add_processors(name, module, processors)
390
+
391
+ return processors
392
+
393
+ # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.set_attn_processor
394
+ def set_attn_processor(self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]]):
395
+ r"""
396
+ Sets the attention processor to use to compute attention.
397
+
398
+ Parameters:
399
+ processor (`dict` of `AttentionProcessor` or only `AttentionProcessor`):
400
+ The instantiated processor class or a dictionary of processor classes that will be set as the processor
401
+ for **all** `Attention` layers.
402
+
403
+ If `processor` is a dict, the key needs to define the path to the corresponding cross attention
404
+ processor. This is strongly recommended when setting trainable attention processors.
405
+
406
+ """
407
+ count = len(self.attn_processors.keys())
408
+
409
+ if isinstance(processor, dict) and len(processor) != count:
410
+ raise ValueError(
411
+ f"A dict of processors was passed, but the number of processors {len(processor)} does not match the"
412
+ f" number of attention layers: {count}. Please make sure to pass {count} processor classes."
413
+ )
414
+
415
+ def fn_recursive_attn_processor(name: str, module: torch.nn.Module, processor):
416
+ if hasattr(module, "set_processor"):
417
+ if not isinstance(processor, dict):
418
+ module.set_processor(processor)
419
+ else:
420
+ module.set_processor(processor.pop(f"{name}.processor"))
421
+
422
+ for sub_name, child in module.named_children():
423
+ fn_recursive_attn_processor(f"{name}.{sub_name}", child, processor)
424
+
425
+ for name, module in self.named_children():
426
+ fn_recursive_attn_processor(name, module, processor)
427
+
428
+ # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.fuse_qkv_projections with FusedAttnProcessor2_0->FusedAuraFlowAttnProcessor2_0
429
+ def fuse_qkv_projections(self):
430
+ """
431
+ Enables fused QKV projections. For self-attention modules, all projection matrices (i.e., query, key, value)
432
+ are fused. For cross-attention modules, key and value projection matrices are fused.
433
+
434
+ <Tip warning={true}>
435
+
436
+ This API is 🧪 experimental.
437
+
438
+ </Tip>
439
+ """
440
+ self.original_attn_processors = None
441
+
442
+ for _, attn_processor in self.attn_processors.items():
443
+ if "Added" in str(attn_processor.__class__.__name__):
444
+ raise ValueError("`fuse_qkv_projections()` is not supported for models having added KV projections.")
445
+
446
+ self.original_attn_processors = self.attn_processors
447
+
448
+ for module in self.modules():
449
+ if isinstance(module, Attention):
450
+ module.fuse_projections(fuse=True)
451
+
452
+ self.set_attn_processor(FusedAuraFlowAttnProcessor2_0())
453
+
454
+ # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.unfuse_qkv_projections
455
+ def unfuse_qkv_projections(self):
456
+ """Disables the fused QKV projection if enabled.
457
+
458
+ <Tip warning={true}>
459
+
460
+ This API is 🧪 experimental.
461
+
462
+ </Tip>
463
+
464
+ """
465
+ if self.original_attn_processors is not None:
466
+ self.set_attn_processor(self.original_attn_processors)
467
+
468
+ def forward(
469
+ self,
470
+ hidden_states: torch.FloatTensor,
471
+ encoder_hidden_states: torch.FloatTensor = None,
472
+ timestep: torch.LongTensor = None,
473
+ attention_kwargs: Optional[Dict[str, Any]] = None,
474
+ return_dict: bool = True,
475
+ ) -> Union[torch.FloatTensor, Transformer2DModelOutput]:
476
+ if attention_kwargs is not None:
477
+ attention_kwargs = attention_kwargs.copy()
478
+ lora_scale = attention_kwargs.pop("scale", 1.0)
479
+ else:
480
+ lora_scale = 1.0
481
+
482
+ if USE_PEFT_BACKEND:
483
+ # weight the lora layers by setting `lora_scale` for each PEFT layer
484
+ scale_lora_layers(self, lora_scale)
485
+ else:
486
+ if attention_kwargs is not None and attention_kwargs.get("scale", None) is not None:
487
+ logger.warning(
488
+ "Passing `scale` via `attention_kwargs` when not using the PEFT backend is ineffective."
489
+ )
490
+
491
+ height, width = hidden_states.shape[-2:]
492
+
493
+ # Apply patch embedding, timestep embedding, and project the caption embeddings.
494
+ hidden_states = self.pos_embed(hidden_states) # takes care of adding positional embeddings too.
495
+ temb = self.time_step_embed(timestep).to(dtype=next(self.parameters()).dtype)
496
+ temb = self.time_step_proj(temb)
497
+ encoder_hidden_states = self.context_embedder(encoder_hidden_states)
498
+ encoder_hidden_states = torch.cat(
499
+ [self.register_tokens.repeat(encoder_hidden_states.size(0), 1, 1), encoder_hidden_states], dim=1
500
+ )
501
+
502
+ # MMDiT blocks.
503
+ for index_block, block in enumerate(self.joint_transformer_blocks):
504
+ if torch.is_grad_enabled() and self.gradient_checkpointing:
505
+ encoder_hidden_states, hidden_states = self._gradient_checkpointing_func(
506
+ block,
507
+ hidden_states,
508
+ encoder_hidden_states,
509
+ temb,
510
+ )
511
+
512
+ else:
513
+ encoder_hidden_states, hidden_states = block(
514
+ hidden_states=hidden_states,
515
+ encoder_hidden_states=encoder_hidden_states,
516
+ temb=temb,
517
+ attention_kwargs=attention_kwargs,
518
+ )
519
+
520
+ # Single DiT blocks that combine the `hidden_states` (image) and `encoder_hidden_states` (text)
521
+ if len(self.single_transformer_blocks) > 0:
522
+ encoder_seq_len = encoder_hidden_states.size(1)
523
+ combined_hidden_states = torch.cat([encoder_hidden_states, hidden_states], dim=1)
524
+
525
+ for index_block, block in enumerate(self.single_transformer_blocks):
526
+ if torch.is_grad_enabled() and self.gradient_checkpointing:
527
+ combined_hidden_states = self._gradient_checkpointing_func(
528
+ block,
529
+ combined_hidden_states,
530
+ temb,
531
+ )
532
+
533
+ else:
534
+ combined_hidden_states = block(
535
+ hidden_states=combined_hidden_states, temb=temb, attention_kwargs=attention_kwargs
536
+ )
537
+
538
+ hidden_states = combined_hidden_states[:, encoder_seq_len:]
539
+
540
+ hidden_states = self.norm_out(hidden_states, temb)
541
+ hidden_states = self.proj_out(hidden_states)
542
+
543
+ # unpatchify
544
+ patch_size = self.config.patch_size
545
+ out_channels = self.config.out_channels
546
+ height = height // patch_size
547
+ width = width // patch_size
548
+
549
+ hidden_states = hidden_states.reshape(
550
+ shape=(hidden_states.shape[0], height, width, patch_size, patch_size, out_channels)
551
+ )
552
+ hidden_states = torch.einsum("nhwpqc->nchpwq", hidden_states)
553
+ output = hidden_states.reshape(
554
+ shape=(hidden_states.shape[0], out_channels, height * patch_size, width * patch_size)
555
+ )
556
+
557
+ if USE_PEFT_BACKEND:
558
+ # remove `lora_scale` from each PEFT layer
559
+ unscale_lora_layers(self, lora_scale)
560
+
561
+ if not return_dict:
562
+ return (output,)
563
+
564
+ return Transformer2DModelOutput(sample=output)
pythonProject/.venv/Lib/site-packages/diffusers/models/transformers/cogvideox_transformer_3d.py ADDED
@@ -0,0 +1,531 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright 2025 The CogVideoX team, Tsinghua University & ZhipuAI and The HuggingFace Team.
2
+ # All rights reserved.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
16
+ from typing import Any, Dict, Optional, Tuple, Union
17
+
18
+ import torch
19
+ from torch import nn
20
+
21
+ from ...configuration_utils import ConfigMixin, register_to_config
22
+ from ...loaders import PeftAdapterMixin
23
+ from ...utils import USE_PEFT_BACKEND, logging, scale_lora_layers, unscale_lora_layers
24
+ from ...utils.torch_utils import maybe_allow_in_graph
25
+ from ..attention import Attention, FeedForward
26
+ from ..attention_processor import AttentionProcessor, CogVideoXAttnProcessor2_0, FusedCogVideoXAttnProcessor2_0
27
+ from ..cache_utils import CacheMixin
28
+ from ..embeddings import CogVideoXPatchEmbed, TimestepEmbedding, Timesteps
29
+ from ..modeling_outputs import Transformer2DModelOutput
30
+ from ..modeling_utils import ModelMixin
31
+ from ..normalization import AdaLayerNorm, CogVideoXLayerNormZero
32
+
33
+
34
+ logger = logging.get_logger(__name__) # pylint: disable=invalid-name
35
+
36
+
37
@maybe_allow_in_graph
class CogVideoXBlock(nn.Module):
    r"""
    Transformer block used in [CogVideoX](https://github.com/THUDM/CogVideo) model.

    The block runs two gated sub-layers: a joint self-attention over the image and text
    streams, followed by a feed-forward applied to the concatenated streams. Both
    sub-layers use AdaLN-zero style gating conditioned on the timestep embedding.

    Parameters:
        dim (`int`):
            The number of channels in the input and output.
        num_attention_heads (`int`):
            The number of heads to use for multi-head attention.
        attention_head_dim (`int`):
            The number of channels in each head.
        time_embed_dim (`int`):
            The number of channels in timestep embedding.
        dropout (`float`, defaults to `0.0`):
            The dropout probability to use.
        activation_fn (`str`, defaults to `"gelu-approximate"`):
            Activation function to be used in feed-forward.
        attention_bias (`bool`, defaults to `False`):
            Whether or not to use bias in attention projection layers.
        qk_norm (`bool`, defaults to `True`):
            Whether or not to use normalization after query and key projections in Attention.
        norm_elementwise_affine (`bool`, defaults to `True`):
            Whether to use learnable elementwise affine parameters for normalization.
        norm_eps (`float`, defaults to `1e-5`):
            Epsilon value for normalization layers.
        final_dropout (`bool` defaults to `True`):
            Whether to apply a final dropout after the last feed-forward layer.
        ff_inner_dim (`int`, *optional*, defaults to `None`):
            Custom hidden dimension of Feed-forward layer. If not provided, `4 * dim` is used.
        ff_bias (`bool`, defaults to `True`):
            Whether or not to use bias in Feed-forward layer.
        attention_out_bias (`bool`, defaults to `True`):
            Whether or not to use bias in Attention output projection layer.
    """

    def __init__(
        self,
        dim: int,
        num_attention_heads: int,
        attention_head_dim: int,
        time_embed_dim: int,
        dropout: float = 0.0,
        activation_fn: str = "gelu-approximate",
        attention_bias: bool = False,
        qk_norm: bool = True,
        norm_elementwise_affine: bool = True,
        norm_eps: float = 1e-5,
        final_dropout: bool = True,
        ff_inner_dim: Optional[int] = None,
        ff_bias: bool = True,
        attention_out_bias: bool = True,
    ):
        super().__init__()

        # 1. Self-attention sub-layer (AdaLN-zero norm + joint attention).
        self.norm1 = CogVideoXLayerNormZero(time_embed_dim, dim, norm_elementwise_affine, norm_eps, bias=True)

        self.attn1 = Attention(
            query_dim=dim,
            dim_head=attention_head_dim,
            heads=num_attention_heads,
            qk_norm="layer_norm" if qk_norm else None,
            eps=1e-6,
            bias=attention_bias,
            out_bias=attention_out_bias,
            processor=CogVideoXAttnProcessor2_0(),
        )

        # 2. Feed-forward sub-layer (AdaLN-zero norm + MLP).
        self.norm2 = CogVideoXLayerNormZero(time_embed_dim, dim, norm_elementwise_affine, norm_eps, bias=True)

        self.ff = FeedForward(
            dim,
            dropout=dropout,
            activation_fn=activation_fn,
            final_dropout=final_dropout,
            inner_dim=ff_inner_dim,
            bias=ff_bias,
        )

    def forward(
        self,
        hidden_states: torch.Tensor,
        encoder_hidden_states: torch.Tensor,
        temb: torch.Tensor,
        image_rotary_emb: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
        attention_kwargs: Optional[Dict[str, Any]] = None,
    ) -> torch.Tensor:
        text_len = encoder_hidden_states.size(1)
        extra_attn_kwargs = attention_kwargs if attention_kwargs is not None else {}

        # --- Self-attention: norm & modulate, attend, then apply per-stream gates ---
        norm_img, norm_txt, gate_attn_img, gate_attn_txt = self.norm1(
            hidden_states, encoder_hidden_states, temb
        )

        attn_img, attn_txt = self.attn1(
            hidden_states=norm_img,
            encoder_hidden_states=norm_txt,
            image_rotary_emb=image_rotary_emb,
            **extra_attn_kwargs,
        )

        hidden_states = hidden_states + gate_attn_img * attn_img
        encoder_hidden_states = encoder_hidden_states + gate_attn_txt * attn_txt

        # --- Feed-forward: norm & modulate, run MLP over [text, image], split back ---
        norm_img, norm_txt, gate_ff_img, gate_ff_txt = self.norm2(
            hidden_states, encoder_hidden_states, temb
        )

        # The MLP sees text tokens first, image tokens after.
        ff_out = self.ff(torch.cat([norm_txt, norm_img], dim=1))

        hidden_states = hidden_states + gate_ff_img * ff_out[:, text_len:]
        encoder_hidden_states = encoder_hidden_states + gate_ff_txt * ff_out[:, :text_len]

        return hidden_states, encoder_hidden_states
158
+
159
+
160
class CogVideoXTransformer3DModel(ModelMixin, ConfigMixin, PeftAdapterMixin, CacheMixin):
    """
    A Transformer model for video-like data in [CogVideoX](https://github.com/THUDM/CogVideo).

    Parameters:
        num_attention_heads (`int`, defaults to `30`):
            The number of heads to use for multi-head attention.
        attention_head_dim (`int`, defaults to `64`):
            The number of channels in each head.
        in_channels (`int`, defaults to `16`):
            The number of channels in the input.
        out_channels (`int`, *optional*, defaults to `16`):
            The number of channels in the output.
        flip_sin_to_cos (`bool`, defaults to `True`):
            Whether to flip the sin to cos in the time embedding.
        time_embed_dim (`int`, defaults to `512`):
            Output dimension of timestep embeddings.
        ofs_embed_dim (`int`, defaults to `512`):
            Output dimension of "ofs" embeddings used in CogVideoX-5b-I2V in version 1.5
        text_embed_dim (`int`, defaults to `4096`):
            Input dimension of text embeddings from the text encoder.
        num_layers (`int`, defaults to `30`):
            The number of layers of Transformer blocks to use.
        dropout (`float`, defaults to `0.0`):
            The dropout probability to use.
        attention_bias (`bool`, defaults to `True`):
            Whether to use bias in the attention projection layers.
        sample_width (`int`, defaults to `90`):
            The width of the input latents.
        sample_height (`int`, defaults to `60`):
            The height of the input latents.
        sample_frames (`int`, defaults to `49`):
            The number of frames in the input latents. Note that this parameter was incorrectly initialized to 49
            instead of 13 because CogVideoX processed 13 latent frames at once in its default and recommended settings,
            but cannot be changed to the correct value to ensure backwards compatibility. To create a transformer with
            K latent frames, the correct value to pass here would be: ((K - 1) * temporal_compression_ratio + 1).
        patch_size (`int`, defaults to `2`):
            The size of the patches to use in the patch embedding layer.
        temporal_compression_ratio (`int`, defaults to `4`):
            The compression ratio across the temporal dimension. See documentation for `sample_frames`.
        max_text_seq_length (`int`, defaults to `226`):
            The maximum sequence length of the input text embeddings.
        activation_fn (`str`, defaults to `"gelu-approximate"`):
            Activation function to use in feed-forward.
        timestep_activation_fn (`str`, defaults to `"silu"`):
            Activation function to use when generating the timestep embeddings.
        norm_elementwise_affine (`bool`, defaults to `True`):
            Whether to use elementwise affine in normalization layers.
        norm_eps (`float`, defaults to `1e-5`):
            The epsilon value to use in normalization layers.
        spatial_interpolation_scale (`float`, defaults to `1.875`):
            Scaling factor to apply in 3D positional embeddings across spatial dimensions.
        temporal_interpolation_scale (`float`, defaults to `1.0`):
            Scaling factor to apply in 3D positional embeddings across temporal dimensions.
    """

    # Module-name patterns excluded from layerwise dtype casting.
    _skip_layerwise_casting_patterns = ["patch_embed", "norm"]
    _supports_gradient_checkpointing = True
    # Modules that must not be split across devices by model parallelism helpers.
    _no_split_modules = ["CogVideoXBlock", "CogVideoXPatchEmbed"]

    @register_to_config
    def __init__(
        self,
        num_attention_heads: int = 30,
        attention_head_dim: int = 64,
        in_channels: int = 16,
        out_channels: Optional[int] = 16,
        flip_sin_to_cos: bool = True,
        freq_shift: int = 0,
        time_embed_dim: int = 512,
        ofs_embed_dim: Optional[int] = None,
        text_embed_dim: int = 4096,
        num_layers: int = 30,
        dropout: float = 0.0,
        attention_bias: bool = True,
        sample_width: int = 90,
        sample_height: int = 60,
        sample_frames: int = 49,
        patch_size: int = 2,
        patch_size_t: Optional[int] = None,
        temporal_compression_ratio: int = 4,
        max_text_seq_length: int = 226,
        activation_fn: str = "gelu-approximate",
        timestep_activation_fn: str = "silu",
        norm_elementwise_affine: bool = True,
        norm_eps: float = 1e-5,
        spatial_interpolation_scale: float = 1.875,
        temporal_interpolation_scale: float = 1.0,
        use_rotary_positional_embeddings: bool = False,
        use_learned_positional_embeddings: bool = False,
        patch_bias: bool = True,
    ):
        super().__init__()
        inner_dim = num_attention_heads * attention_head_dim

        if not use_rotary_positional_embeddings and use_learned_positional_embeddings:
            raise ValueError(
                "There are no CogVideoX checkpoints available with disable rotary embeddings and learned positional "
                "embeddings. If you're using a custom model and/or believe this should be supported, please open an "
                "issue at https://github.com/huggingface/diffusers/issues."
            )

        # 1. Patch embedding (embeds both video latents and text into one sequence)
        self.patch_embed = CogVideoXPatchEmbed(
            patch_size=patch_size,
            patch_size_t=patch_size_t,
            in_channels=in_channels,
            embed_dim=inner_dim,
            text_embed_dim=text_embed_dim,
            bias=patch_bias,
            sample_width=sample_width,
            sample_height=sample_height,
            sample_frames=sample_frames,
            temporal_compression_ratio=temporal_compression_ratio,
            max_text_seq_length=max_text_seq_length,
            spatial_interpolation_scale=spatial_interpolation_scale,
            temporal_interpolation_scale=temporal_interpolation_scale,
            use_positional_embeddings=not use_rotary_positional_embeddings,
            use_learned_positional_embeddings=use_learned_positional_embeddings,
        )
        self.embedding_dropout = nn.Dropout(dropout)

        # 2. Time embeddings and ofs embedding(Only CogVideoX1.5-5B I2V have)

        self.time_proj = Timesteps(inner_dim, flip_sin_to_cos, freq_shift)
        self.time_embedding = TimestepEmbedding(inner_dim, time_embed_dim, timestep_activation_fn)

        self.ofs_proj = None
        self.ofs_embedding = None
        if ofs_embed_dim:
            self.ofs_proj = Timesteps(ofs_embed_dim, flip_sin_to_cos, freq_shift)
            self.ofs_embedding = TimestepEmbedding(
                ofs_embed_dim, ofs_embed_dim, timestep_activation_fn
            )  # same as time embeddings, for ofs

        # 3. Define spatio-temporal transformers blocks
        self.transformer_blocks = nn.ModuleList(
            [
                CogVideoXBlock(
                    dim=inner_dim,
                    num_attention_heads=num_attention_heads,
                    attention_head_dim=attention_head_dim,
                    time_embed_dim=time_embed_dim,
                    dropout=dropout,
                    activation_fn=activation_fn,
                    attention_bias=attention_bias,
                    norm_elementwise_affine=norm_elementwise_affine,
                    norm_eps=norm_eps,
                )
                for _ in range(num_layers)
            ]
        )
        self.norm_final = nn.LayerNorm(inner_dim, norm_eps, norm_elementwise_affine)

        # 4. Output blocks
        self.norm_out = AdaLayerNorm(
            embedding_dim=time_embed_dim,
            output_dim=2 * inner_dim,
            norm_elementwise_affine=norm_elementwise_affine,
            norm_eps=norm_eps,
            chunk_dim=1,
        )

        if patch_size_t is None:
            # For CogVideox 1.0 (spatial-only patching)
            output_dim = patch_size * patch_size * out_channels
        else:
            # For CogVideoX 1.5 (spatio-temporal patching)
            output_dim = patch_size * patch_size * patch_size_t * out_channels

        self.proj_out = nn.Linear(inner_dim, output_dim)

        self.gradient_checkpointing = False

    @property
    # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.attn_processors
    def attn_processors(self) -> Dict[str, AttentionProcessor]:
        r"""
        Returns:
            `dict` of attention processors: A dictionary containing all attention processors used in the model with
            indexed by its weight name.
        """
        # set recursively
        processors = {}

        def fn_recursive_add_processors(name: str, module: torch.nn.Module, processors: Dict[str, AttentionProcessor]):
            if hasattr(module, "get_processor"):
                processors[f"{name}.processor"] = module.get_processor()

            for sub_name, child in module.named_children():
                fn_recursive_add_processors(f"{name}.{sub_name}", child, processors)

            return processors

        for name, module in self.named_children():
            fn_recursive_add_processors(name, module, processors)

        return processors

    # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.set_attn_processor
    def set_attn_processor(self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]]):
        r"""
        Sets the attention processor to use to compute attention.

        Parameters:
            processor (`dict` of `AttentionProcessor` or only `AttentionProcessor`):
                The instantiated processor class or a dictionary of processor classes that will be set as the processor
                for **all** `Attention` layers.

                If `processor` is a dict, the key needs to define the path to the corresponding cross attention
                processor. This is strongly recommended when setting trainable attention processors.

        """
        count = len(self.attn_processors.keys())

        if isinstance(processor, dict) and len(processor) != count:
            raise ValueError(
                f"A dict of processors was passed, but the number of processors {len(processor)} does not match the"
                f" number of attention layers: {count}. Please make sure to pass {count} processor classes."
            )

        def fn_recursive_attn_processor(name: str, module: torch.nn.Module, processor):
            if hasattr(module, "set_processor"):
                if not isinstance(processor, dict):
                    module.set_processor(processor)
                else:
                    module.set_processor(processor.pop(f"{name}.processor"))

            for sub_name, child in module.named_children():
                fn_recursive_attn_processor(f"{name}.{sub_name}", child, processor)

        for name, module in self.named_children():
            fn_recursive_attn_processor(name, module, processor)

    # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.fuse_qkv_projections with FusedAttnProcessor2_0->FusedCogVideoXAttnProcessor2_0
    def fuse_qkv_projections(self):
        """
        Enables fused QKV projections. For self-attention modules, all projection matrices (i.e., query, key, value)
        are fused. For cross-attention modules, key and value projection matrices are fused.

        <Tip warning={true}>

        This API is 🧪 experimental.

        </Tip>
        """
        self.original_attn_processors = None

        for _, attn_processor in self.attn_processors.items():
            if "Added" in str(attn_processor.__class__.__name__):
                raise ValueError("`fuse_qkv_projections()` is not supported for models having added KV projections.")

        # Keep the unfused processors around so `unfuse_qkv_projections` can restore them.
        self.original_attn_processors = self.attn_processors

        for module in self.modules():
            if isinstance(module, Attention):
                module.fuse_projections(fuse=True)

        self.set_attn_processor(FusedCogVideoXAttnProcessor2_0())

    # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.unfuse_qkv_projections
    def unfuse_qkv_projections(self):
        """Disables the fused QKV projection if enabled.

        <Tip warning={true}>

        This API is 🧪 experimental.

        </Tip>

        """
        if self.original_attn_processors is not None:
            self.set_attn_processor(self.original_attn_processors)

    def forward(
        self,
        hidden_states: torch.Tensor,
        encoder_hidden_states: torch.Tensor,
        timestep: Union[int, float, torch.LongTensor],
        timestep_cond: Optional[torch.Tensor] = None,
        ofs: Optional[Union[int, float, torch.LongTensor]] = None,
        image_rotary_emb: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
        attention_kwargs: Optional[Dict[str, Any]] = None,
        return_dict: bool = True,
    ):
        """
        Run the denoising transformer.

        `hidden_states` is expected to be a 5D video latent of shape
        (batch, frames, channels, height, width) — see the unpacking below.
        Returns a `Transformer2DModelOutput` (or a 1-tuple when `return_dict=False`).
        """
        if attention_kwargs is not None:
            # Copy before popping so the caller's dict is not mutated.
            attention_kwargs = attention_kwargs.copy()
            lora_scale = attention_kwargs.pop("scale", 1.0)
        else:
            lora_scale = 1.0

        if USE_PEFT_BACKEND:
            # weight the lora layers by setting `lora_scale` for each PEFT layer
            scale_lora_layers(self, lora_scale)
        else:
            if attention_kwargs is not None and attention_kwargs.get("scale", None) is not None:
                logger.warning(
                    "Passing `scale` via `attention_kwargs` when not using the PEFT backend is ineffective."
                )

        batch_size, num_frames, channels, height, width = hidden_states.shape

        # 1. Time embedding
        timesteps = timestep
        t_emb = self.time_proj(timesteps)

        # timesteps does not contain any weights and will always return f32 tensors
        # but time_embedding might actually be running in fp16. so we need to cast here.
        # there might be better ways to encapsulate this.
        t_emb = t_emb.to(dtype=hidden_states.dtype)
        emb = self.time_embedding(t_emb, timestep_cond)

        if self.ofs_embedding is not None:
            # Only CogVideoX 1.5 I2V checkpoints carry an ofs embedding.
            ofs_emb = self.ofs_proj(ofs)
            ofs_emb = ofs_emb.to(dtype=hidden_states.dtype)
            ofs_emb = self.ofs_embedding(ofs_emb)
            emb = emb + ofs_emb

        # 2. Patch embedding (produces a single [text, image] token sequence)
        hidden_states = self.patch_embed(encoder_hidden_states, hidden_states)
        hidden_states = self.embedding_dropout(hidden_states)

        # Split the joint sequence back into text and image streams.
        text_seq_length = encoder_hidden_states.shape[1]
        encoder_hidden_states = hidden_states[:, :text_seq_length]
        hidden_states = hidden_states[:, text_seq_length:]

        # 3. Transformer blocks
        for i, block in enumerate(self.transformer_blocks):
            if torch.is_grad_enabled() and self.gradient_checkpointing:
                hidden_states, encoder_hidden_states = self._gradient_checkpointing_func(
                    block,
                    hidden_states,
                    encoder_hidden_states,
                    emb,
                    image_rotary_emb,
                    attention_kwargs,
                )
            else:
                hidden_states, encoder_hidden_states = block(
                    hidden_states=hidden_states,
                    encoder_hidden_states=encoder_hidden_states,
                    temb=emb,
                    image_rotary_emb=image_rotary_emb,
                    attention_kwargs=attention_kwargs,
                )

        hidden_states = self.norm_final(hidden_states)

        # 4. Final block
        hidden_states = self.norm_out(hidden_states, temb=emb)
        hidden_states = self.proj_out(hidden_states)

        # 5. Unpatchify: fold the projected patches back into a video latent
        p = self.config.patch_size
        p_t = self.config.patch_size_t

        if p_t is None:
            # CogVideoX 1.0: spatial-only patches
            output = hidden_states.reshape(batch_size, num_frames, height // p, width // p, -1, p, p)
            output = output.permute(0, 1, 4, 2, 5, 3, 6).flatten(5, 6).flatten(3, 4)
        else:
            # CogVideoX 1.5: spatio-temporal patches (frames are grouped by p_t)
            output = hidden_states.reshape(
                batch_size, (num_frames + p_t - 1) // p_t, height // p, width // p, -1, p_t, p, p
            )
            output = output.permute(0, 1, 5, 4, 2, 6, 3, 7).flatten(6, 7).flatten(4, 5).flatten(1, 2)

        if USE_PEFT_BACKEND:
            # remove `lora_scale` from each PEFT layer
            unscale_lora_layers(self, lora_scale)

        if not return_dict:
            return (output,)
        return Transformer2DModelOutput(sample=output)
pythonProject/.venv/Lib/site-packages/diffusers/models/transformers/consisid_transformer_3d.py ADDED
@@ -0,0 +1,789 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright 2025 ConsisID Authors and The HuggingFace Team. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ import math
16
+ from typing import Any, Dict, List, Optional, Tuple, Union
17
+
18
+ import torch
19
+ from torch import nn
20
+
21
+ from ...configuration_utils import ConfigMixin, register_to_config
22
+ from ...loaders import PeftAdapterMixin
23
+ from ...utils import USE_PEFT_BACKEND, logging, scale_lora_layers, unscale_lora_layers
24
+ from ...utils.torch_utils import maybe_allow_in_graph
25
+ from ..attention import Attention, FeedForward
26
+ from ..attention_processor import AttentionProcessor, CogVideoXAttnProcessor2_0
27
+ from ..embeddings import CogVideoXPatchEmbed, TimestepEmbedding, Timesteps
28
+ from ..modeling_outputs import Transformer2DModelOutput
29
+ from ..modeling_utils import ModelMixin
30
+ from ..normalization import AdaLayerNorm, CogVideoXLayerNormZero
31
+
32
+
33
+ logger = logging.get_logger(__name__) # pylint: disable=invalid-name
34
+
35
+
36
class PerceiverAttention(nn.Module):
    """
    Perceiver-style attention: learnable latent queries attend over the image
    features concatenated with the latents themselves.
    """

    def __init__(self, dim: int, dim_head: int = 64, heads: int = 8, kv_dim: Optional[int] = None):
        super().__init__()

        self.scale = dim_head**-0.5
        self.dim_head = dim_head
        self.heads = heads
        inner_dim = dim_head * heads

        # Key/value inputs may live in a different feature space than the latents.
        self.norm1 = nn.LayerNorm(kv_dim if kv_dim is not None else dim)
        self.norm2 = nn.LayerNorm(dim)

        self.to_q = nn.Linear(dim, inner_dim, bias=False)
        self.to_kv = nn.Linear(kv_dim if kv_dim is not None else dim, inner_dim * 2, bias=False)
        self.to_out = nn.Linear(inner_dim, dim, bias=False)

    def forward(self, image_embeds: torch.Tensor, latents: torch.Tensor) -> torch.Tensor:
        # Normalize both inputs before projecting.
        image_embeds = self.norm1(image_embeds)
        latents = self.norm2(latents)

        bsz, num_latents, _ = latents.shape

        def split_heads(t: torch.Tensor) -> torch.Tensor:
            # (batch, seq, heads * dim_head) -> (batch, heads, seq, dim_head)
            return t.reshape(t.size(0), -1, self.heads, self.dim_head).transpose(1, 2)

        # Queries come from the latents; keys/values from [image_embeds, latents].
        query = split_heads(self.to_q(latents))
        joint_kv = self.to_kv(torch.cat((image_embeds, latents), dim=-2))
        key, value = (split_heads(t) for t in joint_kv.chunk(2, dim=-1))

        # Scale q and k each by dim_head**-0.25 before the matmul — more stable
        # in fp16 than dividing the scores afterwards.
        pre_scale = 1 / math.sqrt(math.sqrt(self.dim_head))
        scores = (query * pre_scale) @ (key * pre_scale).transpose(-2, -1)
        probs = torch.softmax(scores.float(), dim=-1).type(scores.dtype)

        # Merge heads back and project to the output dimension.
        attended = (probs @ value).permute(0, 2, 1, 3).reshape(bsz, num_latents, -1)
        return self.to_out(attended)
79
+
80
+
81
class LocalFacialExtractor(nn.Module):
    """
    Maps an identity embedding plus multi-scale ViT hidden states to a fixed set of
    identity-aware query tokens.

    Learnable latent queries are concatenated with tokens derived from the identity
    embedding and refined, one group of Perceiver-attention/feed-forward layers per
    ViT feature scale, before being projected to `output_dim`.

    Parameters:
        id_dim (`int`, defaults to `1280`):
            Dimension of the input identity embedding.
        vit_dim (`int`, defaults to `1024`):
            Dimension of the ViT hidden states (and of the internal latents).
        depth (`int`, defaults to `10`):
            Total number of attention/feed-forward layer pairs; must be divisible by
            `num_scale` so the layers split evenly across feature scales.
        dim_head (`int`, defaults to `64`):
            Per-head channel count in the Perceiver attention layers.
        heads (`int`, defaults to `16`):
            Number of attention heads.
        num_id_token (`int`, defaults to `5`):
            Number of tokens the identity embedding is expanded into.
        num_queries (`int`, defaults to `32`):
            Number of learnable latent query tokens returned.
        output_dim (`int`, defaults to `2048`):
            Dimension the query latents are projected to.
        ff_mult (`int`, defaults to `4`):
            Expansion factor of the feed-forward hidden layer.
        num_scale (`int`, defaults to `5`):
            Number of ViT feature scales consumed (length of `vit_hidden_states`).

    Raises:
        ValueError: If `depth` is not divisible by `num_scale`.
    """

    def __init__(
        self,
        id_dim: int = 1280,
        vit_dim: int = 1024,
        depth: int = 10,
        dim_head: int = 64,
        heads: int = 16,
        num_id_token: int = 5,
        num_queries: int = 32,
        output_dim: int = 2048,
        ff_mult: int = 4,
        num_scale: int = 5,
    ):
        super().__init__()

        # Explicit validation instead of `assert`, which is stripped under `python -O`.
        if depth % num_scale != 0:
            raise ValueError(f"`depth` ({depth}) must be divisible by `num_scale` ({num_scale}).")

        # Storing identity token and query information
        self.num_id_token = num_id_token
        self.vit_dim = vit_dim
        self.num_queries = num_queries
        self.depth = depth // num_scale  # layers per feature scale
        self.num_scale = num_scale
        scale = vit_dim**-0.5

        # Learnable latent query embeddings
        self.latents = nn.Parameter(torch.randn(1, num_queries, vit_dim) * scale)
        # Projection layer to map the latent output to the desired dimension
        self.proj_out = nn.Parameter(scale * torch.randn(vit_dim, output_dim))

        # Attention and ConsisIDFeedForward layer stack
        self.layers = nn.ModuleList([])
        for _ in range(depth):
            self.layers.append(
                nn.ModuleList(
                    [
                        PerceiverAttention(dim=vit_dim, dim_head=dim_head, heads=heads),  # Perceiver Attention layer
                        nn.Sequential(
                            nn.LayerNorm(vit_dim),
                            nn.Linear(vit_dim, vit_dim * ff_mult, bias=False),
                            nn.GELU(),
                            nn.Linear(vit_dim * ff_mult, vit_dim, bias=False),
                        ),  # ConsisIDFeedForward layer
                    ]
                )
            )

        # One MLP mapping per ViT feature scale (accessed as `mapping_{i}` in forward).
        for i in range(num_scale):
            setattr(
                self,
                f"mapping_{i}",
                nn.Sequential(
                    nn.Linear(vit_dim, vit_dim),
                    nn.LayerNorm(vit_dim),
                    nn.LeakyReLU(),
                    nn.Linear(vit_dim, vit_dim),
                    nn.LayerNorm(vit_dim),
                    nn.LeakyReLU(),
                    nn.Linear(vit_dim, vit_dim),
                ),
            )

        # Mapping for identity embedding vectors: one vector -> `num_id_token` tokens.
        self.id_embedding_mapping = nn.Sequential(
            nn.Linear(id_dim, vit_dim),
            nn.LayerNorm(vit_dim),
            nn.LeakyReLU(),
            nn.Linear(vit_dim, vit_dim),
            nn.LayerNorm(vit_dim),
            nn.LeakyReLU(),
            nn.Linear(vit_dim, vit_dim * num_id_token),
        )

    def forward(self, id_embeds: torch.Tensor, vit_hidden_states: List[torch.Tensor]) -> torch.Tensor:
        """
        Args:
            id_embeds: Identity embedding of shape (batch, id_dim).
            vit_hidden_states: `num_scale` ViT feature tensors, one per scale.

        Returns:
            Query latents of shape (batch, num_queries, output_dim).
        """
        # Repeat latent queries for the batch size
        latents = self.latents.repeat(id_embeds.size(0), 1, 1)

        # Map the identity embedding to tokens
        id_embeds = self.id_embedding_mapping(id_embeds)
        id_embeds = id_embeds.reshape(-1, self.num_id_token, self.vit_dim)

        # Concatenate identity tokens with the latent queries
        latents = torch.cat((latents, id_embeds), dim=1)

        # Process each of the num_scale visual feature inputs
        for i in range(self.num_scale):
            vit_feature = getattr(self, f"mapping_{i}")(vit_hidden_states[i])
            ctx_feature = torch.cat((id_embeds, vit_feature), dim=1)

            # Pass this scale's slice of the layer stack over the latents.
            for attn, ff in self.layers[i * self.depth : (i + 1) * self.depth]:
                latents = attn(ctx_feature, latents) + latents
                latents = ff(latents) + latents

        # Retain only the query latents (drop the appended identity tokens)
        latents = latents[:, : self.num_queries]
        # Project the latents to the output dimension
        latents = latents @ self.proj_out
        return latents
181
+
182
+
183
class PerceiverCrossAttention(nn.Module):
    """
    Cross-attention from transformer hidden states (queries) onto image embeddings
    (keys/values), in the Perceiver style.
    """

    def __init__(self, dim: int = 3072, dim_head: int = 128, heads: int = 16, kv_dim: int = 2048):
        super().__init__()

        self.scale = dim_head**-0.5
        self.dim_head = dim_head
        self.heads = heads
        inner_dim = dim_head * heads

        # Normalize each stream in its own feature space.
        self.norm1 = nn.LayerNorm(kv_dim if kv_dim is not None else dim)
        self.norm2 = nn.LayerNorm(dim)

        # Projections for queries (from hidden states) and fused key/value (from images).
        self.to_q = nn.Linear(dim, inner_dim, bias=False)
        self.to_kv = nn.Linear(kv_dim if kv_dim is not None else dim, inner_dim * 2, bias=False)
        self.to_out = nn.Linear(inner_dim, dim, bias=False)

    def forward(self, image_embeds: torch.Tensor, hidden_states: torch.Tensor) -> torch.Tensor:
        image_embeds = self.norm1(image_embeds)
        hidden_states = self.norm2(hidden_states)

        bsz, q_len, _ = hidden_states.shape

        def split_heads(t: torch.Tensor) -> torch.Tensor:
            # (batch, seq, heads * dim_head) -> (batch, heads, seq, dim_head)
            return t.reshape(t.size(0), -1, self.heads, self.dim_head).transpose(1, 2)

        query = split_heads(self.to_q(hidden_states))
        key, value = (split_heads(t) for t in self.to_kv(image_embeds).chunk(2, dim=-1))

        # Scale q and k each by dim_head**-0.25 before the matmul — more stable
        # in fp16 than dividing the scores afterwards.
        pre_scale = 1 / math.sqrt(math.sqrt(self.dim_head))
        scores = (query * pre_scale) @ (key * pre_scale).transpose(-2, -1)
        probs = torch.softmax(scores.float(), dim=-1).type(scores.dtype)

        # Merge heads back and project to the output dimension.
        attended = (probs @ value).permute(0, 2, 1, 3).reshape(bsz, q_len, -1)
        return self.to_out(attended)
229
+
230
+
231
@maybe_allow_in_graph
class ConsisIDBlock(nn.Module):
    r"""
    Transformer block used in [ConsisID](https://github.com/PKU-YuanGroup/ConsisID) model.

    Parameters:
        dim (`int`):
            The number of channels in the input and output.
        num_attention_heads (`int`):
            The number of heads to use for multi-head attention.
        attention_head_dim (`int`):
            The number of channels in each head.
        time_embed_dim (`int`):
            The number of channels in timestep embedding.
        dropout (`float`, defaults to `0.0`):
            The dropout probability to use.
        activation_fn (`str`, defaults to `"gelu-approximate"`):
            Activation function to be used in feed-forward.
        attention_bias (`bool`, defaults to `False`):
            Whether or not to use bias in attention projection layers.
        qk_norm (`bool`, defaults to `True`):
            Whether or not to use normalization after query and key projections in Attention.
        norm_elementwise_affine (`bool`, defaults to `True`):
            Whether to use learnable elementwise affine parameters for normalization.
        norm_eps (`float`, defaults to `1e-5`):
            Epsilon value for normalization layers.
        final_dropout (`bool` defaults to `True`):
            Whether to apply a final dropout after the last feed-forward layer.
        ff_inner_dim (`int`, *optional*, defaults to `None`):
            Custom hidden dimension of Feed-forward layer. If not provided, `4 * dim` is used.
        ff_bias (`bool`, defaults to `True`):
            Whether or not to use bias in Feed-forward layer.
        attention_out_bias (`bool`, defaults to `True`):
            Whether or not to use bias in Attention output projection layer.
    """

    def __init__(
        self,
        dim: int,
        num_attention_heads: int,
        attention_head_dim: int,
        time_embed_dim: int,
        dropout: float = 0.0,
        activation_fn: str = "gelu-approximate",
        attention_bias: bool = False,
        qk_norm: bool = True,
        norm_elementwise_affine: bool = True,
        norm_eps: float = 1e-5,
        final_dropout: bool = True,
        ff_inner_dim: Optional[int] = None,
        ff_bias: bool = True,
        attention_out_bias: bool = True,
    ):
        super().__init__()

        # 1. Self Attention
        self.norm1 = CogVideoXLayerNormZero(time_embed_dim, dim, norm_elementwise_affine, norm_eps, bias=True)

        self.attn1 = Attention(
            query_dim=dim,
            dim_head=attention_head_dim,
            heads=num_attention_heads,
            qk_norm="layer_norm" if qk_norm else None,
            eps=1e-6,
            bias=attention_bias,
            out_bias=attention_out_bias,
            processor=CogVideoXAttnProcessor2_0(),
        )

        # 2. Feed Forward
        self.norm2 = CogVideoXLayerNormZero(time_embed_dim, dim, norm_elementwise_affine, norm_eps, bias=True)

        self.ff = FeedForward(
            dim,
            dropout=dropout,
            activation_fn=activation_fn,
            final_dropout=final_dropout,
            inner_dim=ff_inner_dim,
            bias=ff_bias,
        )

    def forward(
        self,
        hidden_states: torch.Tensor,
        encoder_hidden_states: torch.Tensor,
        temb: torch.Tensor,
        image_rotary_emb: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Run one attention + feed-forward block over the video and text streams.

        Args:
            hidden_states (`torch.Tensor`): Video token stream, `(batch, video_seq, dim)`.
            encoder_hidden_states (`torch.Tensor`): Text token stream, `(batch, text_seq, dim)`.
            temb (`torch.Tensor`): Timestep embedding used by the AdaLN-zero norms to produce
                per-stream shift/scale/gate values.
            image_rotary_emb (`Tuple[torch.Tensor, torch.Tensor]`, *optional*):
                Rotary positional embedding forwarded to the attention processor.

        Returns:
            `Tuple[torch.Tensor, torch.Tensor]`: The updated `(hidden_states, encoder_hidden_states)`.
        """
        text_seq_length = encoder_hidden_states.size(1)

        # norm & modulate — norm1 also emits the gates used for the attention residual
        norm_hidden_states, norm_encoder_hidden_states, gate_msa, enc_gate_msa = self.norm1(
            hidden_states, encoder_hidden_states, temb
        )

        # attention (joint over both streams; returns each stream separately)
        attn_hidden_states, attn_encoder_hidden_states = self.attn1(
            hidden_states=norm_hidden_states,
            encoder_hidden_states=norm_encoder_hidden_states,
            image_rotary_emb=image_rotary_emb,
        )

        # gated residual connections per stream
        hidden_states = hidden_states + gate_msa * attn_hidden_states
        encoder_hidden_states = encoder_hidden_states + enc_gate_msa * attn_encoder_hidden_states

        # norm & modulate
        norm_hidden_states, norm_encoder_hidden_states, gate_ff, enc_gate_ff = self.norm2(
            hidden_states, encoder_hidden_states, temb
        )

        # feed-forward — both streams are concatenated (text first) so one FF pass covers both
        norm_hidden_states = torch.cat([norm_encoder_hidden_states, norm_hidden_states], dim=1)
        ff_output = self.ff(norm_hidden_states)

        # split the FF output back into the two streams and apply gated residuals
        hidden_states = hidden_states + gate_ff * ff_output[:, text_seq_length:]
        encoder_hidden_states = encoder_hidden_states + enc_gate_ff * ff_output[:, :text_seq_length]

        return hidden_states, encoder_hidden_states
349
+
350
+
351
class ConsisIDTransformer3DModel(ModelMixin, ConfigMixin, PeftAdapterMixin):
    """
    A Transformer model for video-like data in [ConsisID](https://github.com/PKU-YuanGroup/ConsisID).

    Parameters:
        num_attention_heads (`int`, defaults to `30`):
            The number of heads to use for multi-head attention.
        attention_head_dim (`int`, defaults to `64`):
            The number of channels in each head.
        in_channels (`int`, defaults to `16`):
            The number of channels in the input.
        out_channels (`int`, *optional*, defaults to `16`):
            The number of channels in the output.
        flip_sin_to_cos (`bool`, defaults to `True`):
            Whether to flip the sin to cos in the time embedding.
        freq_shift (`int`, defaults to `0`):
            Frequency shift passed to the `Timesteps` sinusoidal timestep projection.
        time_embed_dim (`int`, defaults to `512`):
            Output dimension of timestep embeddings.
        text_embed_dim (`int`, defaults to `4096`):
            Input dimension of text embeddings from the text encoder.
        num_layers (`int`, defaults to `30`):
            The number of layers of Transformer blocks to use.
        dropout (`float`, defaults to `0.0`):
            The dropout probability to use.
        attention_bias (`bool`, defaults to `True`):
            Whether to use bias in the attention projection layers.
        sample_width (`int`, defaults to `90`):
            The width of the input latents.
        sample_height (`int`, defaults to `60`):
            The height of the input latents.
        sample_frames (`int`, defaults to `49`):
            The number of frames in the input latents. Note that this parameter was incorrectly initialized to 49
            instead of 13 because ConsisID processed 13 latent frames at once in its default and recommended settings,
            but cannot be changed to the correct value to ensure backwards compatibility. To create a transformer with
            K latent frames, the correct value to pass here would be: ((K - 1) * temporal_compression_ratio + 1).
        patch_size (`int`, defaults to `2`):
            The size of the patches to use in the patch embedding layer.
        temporal_compression_ratio (`int`, defaults to `4`):
            The compression ratio across the temporal dimension. See documentation for `sample_frames`.
        max_text_seq_length (`int`, defaults to `226`):
            The maximum sequence length of the input text embeddings.
        activation_fn (`str`, defaults to `"gelu-approximate"`):
            Activation function to use in feed-forward.
        timestep_activation_fn (`str`, defaults to `"silu"`):
            Activation function to use when generating the timestep embeddings.
        norm_elementwise_affine (`bool`, defaults to `True`):
            Whether to use elementwise affine in normalization layers.
        norm_eps (`float`, defaults to `1e-5`):
            The epsilon value to use in normalization layers.
        spatial_interpolation_scale (`float`, defaults to `1.875`):
            Scaling factor to apply in 3D positional embeddings across spatial dimensions.
        temporal_interpolation_scale (`float`, defaults to `1.0`):
            Scaling factor to apply in 3D positional embeddings across temporal dimensions.
        use_rotary_positional_embeddings (`bool`, defaults to `False`):
            Whether to use rotary positional embeddings; when `False`, the patch embedding uses
            (possibly learned) absolute positional embeddings instead.
        use_learned_positional_embeddings (`bool`, defaults to `False`):
            Whether the absolute positional embeddings are learned. Only valid together with
            `use_rotary_positional_embeddings=True` being disabled is rejected in `__init__`.
        is_train_face (`bool`, defaults to `False`):
            Whether to use enable the identity-preserving module during the training process. When set to `True`, the
            model will focus on identity-preserving tasks.
        is_kps (`bool`, defaults to `False`):
            Whether to enable keypoint for global facial extractor. If `True`, keypoints will be in the model.
        cross_attn_interval (`int`, defaults to `2`):
            The interval between cross-attention layers in the Transformer architecture. A larger value may reduce the
            frequency of cross-attention computations, which can help reduce computational overhead.
        cross_attn_dim_head (`int`, optional, defaults to `128`):
            The dimensionality of each attention head in the cross-attention layers of the Transformer architecture. A
            larger value increases the capacity to attend to more complex patterns, but also increases memory and
            computation costs.
        cross_attn_num_heads (`int`, optional, defaults to `16`):
            The number of attention heads in the cross-attention layers. More heads allow for more parallel attention
            mechanisms, capturing diverse relationships between different components of the input, but can also
            increase computational requirements.
        LFE_id_dim (`int`, optional, defaults to `1280`):
            The dimensionality of the identity vector used in the Local Facial Extractor (LFE). This vector represents
            the identity features of a face, which are important for tasks like face recognition and identity
            preservation across different frames.
        LFE_vit_dim (`int`, optional, defaults to `1024`):
            The dimension of the vision transformer (ViT) output used in the Local Facial Extractor (LFE). This value
            dictates the size of the transformer-generated feature vectors that will be processed for facial feature
            extraction.
        LFE_depth (`int`, optional, defaults to `10`):
            The number of layers in the Local Facial Extractor (LFE). Increasing the depth allows the model to capture
            more complex representations of facial features, but also increases the computational load.
        LFE_dim_head (`int`, optional, defaults to `64`):
            The dimensionality of each attention head in the Local Facial Extractor (LFE). This parameter affects how
            finely the model can process and focus on different parts of the facial features during the extraction
            process.
        LFE_num_heads (`int`, optional, defaults to `16`):
            The number of attention heads in the Local Facial Extractor (LFE). More heads can improve the model's
            ability to capture diverse facial features, but at the cost of increased computational complexity.
        LFE_num_id_token (`int`, optional, defaults to `5`):
            The number of identity tokens used in the Local Facial Extractor (LFE). This defines how many
            identity-related tokens the model will process to ensure face identity preservation during feature
            extraction.
        LFE_num_querie (`int`, optional, defaults to `32`):
            The number of query tokens used in the Local Facial Extractor (LFE). These tokens are used to capture
            high-frequency face-related information that aids in accurate facial feature extraction.
        LFE_output_dim (`int`, optional, defaults to `2048`):
            The output dimension of the Local Facial Extractor (LFE). This dimension determines the size of the feature
            vectors produced by the LFE module, which will be used for subsequent tasks such as face recognition or
            tracking.
        LFE_ff_mult (`int`, optional, defaults to `4`):
            The multiplication factor applied to the feed-forward network's hidden layer size in the Local Facial
            Extractor (LFE). A higher value increases the model's capacity to learn more complex facial feature
            transformations, but also increases the computation and memory requirements.
        LFE_num_scale (`int`, optional, defaults to `5`):
            The number of different scales visual feature. A higher value increases the model's capacity to learn more
            complex facial feature transformations, but also increases the computation and memory requirements.
        local_face_scale (`float`, defaults to `1.0`):
            A scaling factor used to adjust the importance of local facial features in the model. This can influence
            how strongly the model focuses on high frequency face-related content.
    """

    _supports_gradient_checkpointing = True

    @register_to_config
    def __init__(
        self,
        num_attention_heads: int = 30,
        attention_head_dim: int = 64,
        in_channels: int = 16,
        out_channels: Optional[int] = 16,
        flip_sin_to_cos: bool = True,
        freq_shift: int = 0,
        time_embed_dim: int = 512,
        text_embed_dim: int = 4096,
        num_layers: int = 30,
        dropout: float = 0.0,
        attention_bias: bool = True,
        sample_width: int = 90,
        sample_height: int = 60,
        sample_frames: int = 49,
        patch_size: int = 2,
        temporal_compression_ratio: int = 4,
        max_text_seq_length: int = 226,
        activation_fn: str = "gelu-approximate",
        timestep_activation_fn: str = "silu",
        norm_elementwise_affine: bool = True,
        norm_eps: float = 1e-5,
        spatial_interpolation_scale: float = 1.875,
        temporal_interpolation_scale: float = 1.0,
        use_rotary_positional_embeddings: bool = False,
        use_learned_positional_embeddings: bool = False,
        is_train_face: bool = False,
        is_kps: bool = False,
        cross_attn_interval: int = 2,
        cross_attn_dim_head: int = 128,
        cross_attn_num_heads: int = 16,
        LFE_id_dim: int = 1280,
        LFE_vit_dim: int = 1024,
        LFE_depth: int = 10,
        LFE_dim_head: int = 64,
        LFE_num_heads: int = 16,
        LFE_num_id_token: int = 5,
        LFE_num_querie: int = 32,
        LFE_output_dim: int = 2048,
        LFE_ff_mult: int = 4,
        LFE_num_scale: int = 5,
        local_face_scale: float = 1.0,
    ):
        super().__init__()
        inner_dim = num_attention_heads * attention_head_dim

        if not use_rotary_positional_embeddings and use_learned_positional_embeddings:
            raise ValueError(
                "There are no ConsisID checkpoints available with disable rotary embeddings and learned positional "
                "embeddings. If you're using a custom model and/or believe this should be supported, please open an "
                "issue at https://github.com/huggingface/diffusers/issues."
            )

        # 1. Patch embedding
        self.patch_embed = CogVideoXPatchEmbed(
            patch_size=patch_size,
            in_channels=in_channels,
            embed_dim=inner_dim,
            text_embed_dim=text_embed_dim,
            bias=True,
            sample_width=sample_width,
            sample_height=sample_height,
            sample_frames=sample_frames,
            temporal_compression_ratio=temporal_compression_ratio,
            max_text_seq_length=max_text_seq_length,
            spatial_interpolation_scale=spatial_interpolation_scale,
            temporal_interpolation_scale=temporal_interpolation_scale,
            use_positional_embeddings=not use_rotary_positional_embeddings,
            use_learned_positional_embeddings=use_learned_positional_embeddings,
        )
        self.embedding_dropout = nn.Dropout(dropout)

        # 2. Time embeddings
        self.time_proj = Timesteps(inner_dim, flip_sin_to_cos, freq_shift)
        self.time_embedding = TimestepEmbedding(inner_dim, time_embed_dim, timestep_activation_fn)

        # 3. Define spatio-temporal transformers blocks
        self.transformer_blocks = nn.ModuleList(
            [
                ConsisIDBlock(
                    dim=inner_dim,
                    num_attention_heads=num_attention_heads,
                    attention_head_dim=attention_head_dim,
                    time_embed_dim=time_embed_dim,
                    dropout=dropout,
                    activation_fn=activation_fn,
                    attention_bias=attention_bias,
                    norm_elementwise_affine=norm_elementwise_affine,
                    norm_eps=norm_eps,
                )
                for _ in range(num_layers)
            ]
        )
        self.norm_final = nn.LayerNorm(inner_dim, norm_eps, norm_elementwise_affine)

        # 4. Output blocks
        self.norm_out = AdaLayerNorm(
            embedding_dim=time_embed_dim,
            output_dim=2 * inner_dim,
            norm_elementwise_affine=norm_elementwise_affine,
            norm_eps=norm_eps,
            chunk_dim=1,
        )
        self.proj_out = nn.Linear(inner_dim, patch_size * patch_size * out_channels)

        self.is_train_face = is_train_face
        self.is_kps = is_kps

        # 5. Define identity-preserving config
        if is_train_face:
            # LFE configs
            self.LFE_id_dim = LFE_id_dim
            self.LFE_vit_dim = LFE_vit_dim
            self.LFE_depth = LFE_depth
            self.LFE_dim_head = LFE_dim_head
            self.LFE_num_heads = LFE_num_heads
            self.LFE_num_id_token = LFE_num_id_token
            self.LFE_num_querie = LFE_num_querie
            self.LFE_output_dim = LFE_output_dim
            self.LFE_ff_mult = LFE_ff_mult
            self.LFE_num_scale = LFE_num_scale
            # cross configs
            self.inner_dim = inner_dim
            self.cross_attn_interval = cross_attn_interval
            self.num_cross_attn = num_layers // cross_attn_interval
            self.cross_attn_dim_head = cross_attn_dim_head
            self.cross_attn_num_heads = cross_attn_num_heads
            self.cross_attn_kv_dim = int(self.inner_dim / 3 * 2)
            self.local_face_scale = local_face_scale
            # face modules
            self._init_face_inputs()

        self.gradient_checkpointing = False

    def _init_face_inputs(self):
        # Build the identity-preserving submodules from the LFE_* / cross_attn_* attributes
        # stored in __init__ (only called when `is_train_face` is True).
        self.local_facial_extractor = LocalFacialExtractor(
            id_dim=self.LFE_id_dim,
            vit_dim=self.LFE_vit_dim,
            depth=self.LFE_depth,
            dim_head=self.LFE_dim_head,
            heads=self.LFE_num_heads,
            num_id_token=self.LFE_num_id_token,
            num_queries=self.LFE_num_querie,
            output_dim=self.LFE_output_dim,
            ff_mult=self.LFE_ff_mult,
            num_scale=self.LFE_num_scale,
        )
        self.perceiver_cross_attention = nn.ModuleList(
            [
                PerceiverCrossAttention(
                    dim=self.inner_dim,
                    dim_head=self.cross_attn_dim_head,
                    heads=self.cross_attn_num_heads,
                    kv_dim=self.cross_attn_kv_dim,
                )
                for _ in range(self.num_cross_attn)
            ]
        )

    @property
    # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.attn_processors
    def attn_processors(self) -> Dict[str, AttentionProcessor]:
        r"""
        Returns:
            `dict` of attention processors: A dictionary containing all attention processors used in the model with
            indexed by its weight name.
        """
        # set recursively
        processors = {}

        def fn_recursive_add_processors(name: str, module: torch.nn.Module, processors: Dict[str, AttentionProcessor]):
            if hasattr(module, "get_processor"):
                processors[f"{name}.processor"] = module.get_processor()

            for sub_name, child in module.named_children():
                fn_recursive_add_processors(f"{name}.{sub_name}", child, processors)

            return processors

        for name, module in self.named_children():
            fn_recursive_add_processors(name, module, processors)

        return processors

    # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.set_attn_processor
    def set_attn_processor(self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]]):
        r"""
        Sets the attention processor to use to compute attention.

        Parameters:
            processor (`dict` of `AttentionProcessor` or only `AttentionProcessor`):
                The instantiated processor class or a dictionary of processor classes that will be set as the processor
                for **all** `Attention` layers.

                If `processor` is a dict, the key needs to define the path to the corresponding cross attention
                processor. This is strongly recommended when setting trainable attention processors.

        """
        count = len(self.attn_processors.keys())

        if isinstance(processor, dict) and len(processor) != count:
            raise ValueError(
                f"A dict of processors was passed, but the number of processors {len(processor)} does not match the"
                f" number of attention layers: {count}. Please make sure to pass {count} processor classes."
            )

        def fn_recursive_attn_processor(name: str, module: torch.nn.Module, processor):
            if hasattr(module, "set_processor"):
                if not isinstance(processor, dict):
                    module.set_processor(processor)
                else:
                    module.set_processor(processor.pop(f"{name}.processor"))

            for sub_name, child in module.named_children():
                fn_recursive_attn_processor(f"{name}.{sub_name}", child, processor)

        for name, module in self.named_children():
            fn_recursive_attn_processor(name, module, processor)

    def forward(
        self,
        hidden_states: torch.Tensor,
        encoder_hidden_states: torch.Tensor,
        timestep: Union[int, float, torch.LongTensor],
        timestep_cond: Optional[torch.Tensor] = None,
        image_rotary_emb: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
        attention_kwargs: Optional[Dict[str, Any]] = None,
        id_cond: Optional[torch.Tensor] = None,
        id_vit_hidden: Optional[torch.Tensor] = None,
        return_dict: bool = True,
    ):
        """
        Denoise a batch of video latents for one timestep.

        Args:
            hidden_states (`torch.Tensor`): Video latents of shape `(batch, frames, channels, height, width)`
                (unpacked this way below).
            encoder_hidden_states (`torch.Tensor`): Text embeddings; dim 1 is the text sequence length.
            timestep (`int` | `float` | `torch.LongTensor`): Current diffusion timestep(s).
            timestep_cond (`torch.Tensor`, *optional*): Extra conditioning passed to the time embedding.
            image_rotary_emb (`Tuple[torch.Tensor, torch.Tensor]`, *optional*): Rotary embeddings for the blocks.
            attention_kwargs (`Dict[str, Any]`, *optional*): Only the `"scale"` key is consumed here (LoRA scale).
            id_cond (`torch.Tensor`, *optional*): Identity embedding; required when `is_train_face` is True.
            id_vit_hidden (*optional*): List of ViT hidden states for the local facial extractor;
                required when `is_train_face` is True.
            return_dict (`bool`): If False, return a 1-tuple `(output,)` instead of `Transformer2DModelOutput`.
        """
        if attention_kwargs is not None:
            attention_kwargs = attention_kwargs.copy()
            lora_scale = attention_kwargs.pop("scale", 1.0)
        else:
            lora_scale = 1.0

        if USE_PEFT_BACKEND:
            # weight the lora layers by setting `lora_scale` for each PEFT layer
            scale_lora_layers(self, lora_scale)
        else:
            if attention_kwargs is not None and attention_kwargs.get("scale", None) is not None:
                logger.warning(
                    "Passing `scale` via `attention_kwargs` when not using the PEFT backend is ineffective."
                )

        # fuse clip and insightface
        valid_face_emb = None
        if self.is_train_face:
            id_cond = id_cond.to(device=hidden_states.device, dtype=hidden_states.dtype)
            id_vit_hidden = [
                tensor.to(device=hidden_states.device, dtype=hidden_states.dtype) for tensor in id_vit_hidden
            ]
            valid_face_emb = self.local_facial_extractor(
                id_cond, id_vit_hidden
            )  # torch.Size([1, 1280]), list[5](torch.Size([1, 577, 1024])) -> torch.Size([1, 32, 2048])

        batch_size, num_frames, channels, height, width = hidden_states.shape

        # 1. Time embedding
        timesteps = timestep
        t_emb = self.time_proj(timesteps)

        # timesteps does not contain any weights and will always return f32 tensors
        # but time_embedding might actually be running in fp16. so we need to cast here.
        # there might be better ways to encapsulate this.
        t_emb = t_emb.to(dtype=hidden_states.dtype)
        emb = self.time_embedding(t_emb, timestep_cond)

        # 2. Patch embedding
        # torch.Size([1, 226, 4096]) torch.Size([1, 13, 32, 60, 90])
        hidden_states = self.patch_embed(encoder_hidden_states, hidden_states)  # torch.Size([1, 17776, 3072])
        hidden_states = self.embedding_dropout(hidden_states)  # torch.Size([1, 17776, 3072])

        # patch_embed packs [text tokens | video tokens] along dim 1; split them back apart
        text_seq_length = encoder_hidden_states.shape[1]
        encoder_hidden_states = hidden_states[:, :text_seq_length]  # torch.Size([1, 226, 3072])
        hidden_states = hidden_states[:, text_seq_length:]  # torch.Size([1, 17550, 3072])

        # 3. Transformer blocks
        ca_idx = 0
        for i, block in enumerate(self.transformer_blocks):
            if torch.is_grad_enabled() and self.gradient_checkpointing:
                hidden_states, encoder_hidden_states = self._gradient_checkpointing_func(
                    block,
                    hidden_states,
                    encoder_hidden_states,
                    emb,
                    image_rotary_emb,
                )
            else:
                hidden_states, encoder_hidden_states = block(
                    hidden_states=hidden_states,
                    encoder_hidden_states=encoder_hidden_states,
                    temb=emb,
                    image_rotary_emb=image_rotary_emb,
                )

            if self.is_train_face:
                # Inject identity information every `cross_attn_interval` blocks,
                # scaled by `local_face_scale` and added as a residual.
                if i % self.cross_attn_interval == 0 and valid_face_emb is not None:
                    hidden_states = hidden_states + self.local_face_scale * self.perceiver_cross_attention[ca_idx](
                        valid_face_emb, hidden_states
                    )  # torch.Size([2, 32, 2048]) torch.Size([2, 17550, 3072])
                    ca_idx += 1

        hidden_states = torch.cat([encoder_hidden_states, hidden_states], dim=1)
        hidden_states = self.norm_final(hidden_states)
        hidden_states = hidden_states[:, text_seq_length:]

        # 4. Final block
        hidden_states = self.norm_out(hidden_states, temb=emb)
        hidden_states = self.proj_out(hidden_states)

        # 5. Unpatchify
        # Note: we use `-1` instead of `channels`:
        #   - It is okay to `channels` use for ConsisID (number of input channels is equal to output channels)
        p = self.config.patch_size
        output = hidden_states.reshape(batch_size, num_frames, height // p, width // p, -1, p, p)
        output = output.permute(0, 1, 4, 2, 5, 3, 6).flatten(5, 6).flatten(3, 4)

        if USE_PEFT_BACKEND:
            # remove `lora_scale` from each PEFT layer
            unscale_lora_layers(self, lora_scale)

        if not return_dict:
            return (output,)
        return Transformer2DModelOutput(sample=output)
pythonProject/.venv/Lib/site-packages/diffusers/models/transformers/dit_transformer_2d.py ADDED
@@ -0,0 +1,226 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright 2025 The HuggingFace Team. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ from typing import Any, Dict, Optional
15
+
16
+ import torch
17
+ import torch.nn.functional as F
18
+ from torch import nn
19
+
20
+ from ...configuration_utils import ConfigMixin, register_to_config
21
+ from ...utils import logging
22
+ from ..attention import BasicTransformerBlock
23
+ from ..embeddings import PatchEmbed
24
+ from ..modeling_outputs import Transformer2DModelOutput
25
+ from ..modeling_utils import ModelMixin
26
+
27
+
28
+ logger = logging.get_logger(__name__) # pylint: disable=invalid-name
29
+
30
+
31
+ class DiTTransformer2DModel(ModelMixin, ConfigMixin):
32
+ r"""
33
+ A 2D Transformer model as introduced in DiT (https://huggingface.co/papers/2212.09748).
34
+
35
+ Parameters:
36
+ num_attention_heads (int, optional, defaults to 16): The number of heads to use for multi-head attention.
37
+ attention_head_dim (int, optional, defaults to 72): The number of channels in each head.
38
+ in_channels (int, defaults to 4): The number of channels in the input.
39
+ out_channels (int, optional):
40
+ The number of channels in the output. Specify this parameter if the output channel number differs from the
41
+ input.
42
+ num_layers (int, optional, defaults to 28): The number of layers of Transformer blocks to use.
43
+ dropout (float, optional, defaults to 0.0): The dropout probability to use within the Transformer blocks.
44
+ norm_num_groups (int, optional, defaults to 32):
45
+ Number of groups for group normalization within Transformer blocks.
46
+ attention_bias (bool, optional, defaults to True):
47
+ Configure if the Transformer blocks' attention should contain a bias parameter.
48
+ sample_size (int, defaults to 32):
49
+ The width of the latent images. This parameter is fixed during training.
50
+ patch_size (int, defaults to 2):
51
+ Size of the patches the model processes, relevant for architectures working on non-sequential data.
52
+ activation_fn (str, optional, defaults to "gelu-approximate"):
53
+ Activation function to use in feed-forward networks within Transformer blocks.
54
+ num_embeds_ada_norm (int, optional, defaults to 1000):
55
+ Number of embeddings for AdaLayerNorm, fixed during training and affects the maximum denoising steps during
56
+ inference.
57
+ upcast_attention (bool, optional, defaults to False):
58
+ If true, upcasts the attention mechanism dimensions for potentially improved performance.
59
+ norm_type (str, optional, defaults to "ada_norm_zero"):
60
+ Specifies the type of normalization used, can be 'ada_norm_zero'.
61
+ norm_elementwise_affine (bool, optional, defaults to False):
62
+ If true, enables element-wise affine parameters in the normalization layers.
63
+ norm_eps (float, optional, defaults to 1e-5):
64
+ A small constant added to the denominator in normalization layers to prevent division by zero.
65
+ """
66
+
67
+ _skip_layerwise_casting_patterns = ["pos_embed", "norm"]
68
+ _supports_gradient_checkpointing = True
69
+ _supports_group_offloading = False
70
+
71
    @register_to_config
    def __init__(
        self,
        num_attention_heads: int = 16,
        attention_head_dim: int = 72,
        in_channels: int = 4,
        out_channels: Optional[int] = None,
        num_layers: int = 28,
        dropout: float = 0.0,
        norm_num_groups: int = 32,
        attention_bias: bool = True,
        sample_size: int = 32,
        patch_size: int = 2,
        activation_fn: str = "gelu-approximate",
        num_embeds_ada_norm: Optional[int] = 1000,
        upcast_attention: bool = False,
        norm_type: str = "ada_norm_zero",
        norm_elementwise_affine: bool = False,
        norm_eps: float = 1e-5,
    ):
        """
        Build the DiT patch embedding, transformer stack, and output projections.

        Raises:
            NotImplementedError: If `norm_type` is anything other than `"ada_norm_zero"`.
            ValueError: If `norm_type == "ada_norm_zero"` but `num_embeds_ada_norm` is None.
        """
        super().__init__()

        # Validate inputs. Only AdaLN-zero normalization is supported by this model.
        if norm_type != "ada_norm_zero":
            raise NotImplementedError(
                f"Forward pass is not implemented when `patch_size` is not None and `norm_type` is '{norm_type}'."
            )
        elif norm_type == "ada_norm_zero" and num_embeds_ada_norm is None:
            raise ValueError(
                f"When using a `patch_size` and this `norm_type` ({norm_type}), `num_embeds_ada_norm` cannot be None."
            )

        # Set some common variables used across the board.
        # NOTE(review): `norm_num_groups` is not referenced in this constructor body;
        # presumably it is only captured into `self.config` by `@register_to_config` — confirm.
        self.attention_head_dim = attention_head_dim
        self.inner_dim = self.config.num_attention_heads * self.config.attention_head_dim
        self.out_channels = in_channels if out_channels is None else out_channels
        self.gradient_checkpointing = False

        # 2. Initialize the position embedding and transformer blocks.
        # Square latents are assumed: both height and width come from `sample_size`.
        self.height = self.config.sample_size
        self.width = self.config.sample_size

        self.patch_size = self.config.patch_size
        self.pos_embed = PatchEmbed(
            height=self.config.sample_size,
            width=self.config.sample_size,
            patch_size=self.config.patch_size,
            in_channels=self.config.in_channels,
            embed_dim=self.inner_dim,
        )

        self.transformer_blocks = nn.ModuleList(
            [
                BasicTransformerBlock(
                    self.inner_dim,
                    self.config.num_attention_heads,
                    self.config.attention_head_dim,
                    dropout=self.config.dropout,
                    activation_fn=self.config.activation_fn,
                    num_embeds_ada_norm=self.config.num_embeds_ada_norm,
                    attention_bias=self.config.attention_bias,
                    upcast_attention=self.config.upcast_attention,
                    norm_type=norm_type,
                    norm_elementwise_affine=self.config.norm_elementwise_affine,
                    norm_eps=self.config.norm_eps,
                )
                for _ in range(self.config.num_layers)
            ]
        )

        # 3. Output blocks.
        # proj_out_1 produces the AdaLN shift/scale pair (hence 2 * inner_dim);
        # proj_out_2 maps each token to a flattened output patch.
        self.norm_out = nn.LayerNorm(self.inner_dim, elementwise_affine=False, eps=1e-6)
        self.proj_out_1 = nn.Linear(self.inner_dim, 2 * self.inner_dim)
        self.proj_out_2 = nn.Linear(
            self.inner_dim, self.config.patch_size * self.config.patch_size * self.out_channels
        )
147
+
148
+ def forward(
149
+ self,
150
+ hidden_states: torch.Tensor,
151
+ timestep: Optional[torch.LongTensor] = None,
152
+ class_labels: Optional[torch.LongTensor] = None,
153
+ cross_attention_kwargs: Dict[str, Any] = None,
154
+ return_dict: bool = True,
155
+ ):
156
+ """
157
+ The [`DiTTransformer2DModel`] forward method.
158
+
159
+ Args:
160
+ hidden_states (`torch.LongTensor` of shape `(batch size, num latent pixels)` if discrete, `torch.FloatTensor` of shape `(batch size, channel, height, width)` if continuous):
161
+ Input `hidden_states`.
162
+ timestep ( `torch.LongTensor`, *optional*):
163
+ Used to indicate denoising step. Optional timestep to be applied as an embedding in `AdaLayerNorm`.
164
+ class_labels ( `torch.LongTensor` of shape `(batch size, num classes)`, *optional*):
165
+ Used to indicate class labels conditioning. Optional class labels to be applied as an embedding in
166
+ `AdaLayerZeroNorm`.
167
+ cross_attention_kwargs ( `Dict[str, Any]`, *optional*):
168
+ A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
169
+ `self.processor` in
170
+ [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
171
+ return_dict (`bool`, *optional*, defaults to `True`):
172
+ Whether or not to return a [`~models.unets.unet_2d_condition.UNet2DConditionOutput`] instead of a plain
173
+ tuple.
174
+
175
+ Returns:
176
+ If `return_dict` is True, an [`~models.transformer_2d.Transformer2DModelOutput`] is returned, otherwise a
177
+ `tuple` where the first element is the sample tensor.
178
+ """
179
+ # 1. Input
180
+ height, width = hidden_states.shape[-2] // self.patch_size, hidden_states.shape[-1] // self.patch_size
181
+ hidden_states = self.pos_embed(hidden_states)
182
+
183
+ # 2. Blocks
184
+ for block in self.transformer_blocks:
185
+ if torch.is_grad_enabled() and self.gradient_checkpointing:
186
+ hidden_states = self._gradient_checkpointing_func(
187
+ block,
188
+ hidden_states,
189
+ None,
190
+ None,
191
+ None,
192
+ timestep,
193
+ cross_attention_kwargs,
194
+ class_labels,
195
+ )
196
+ else:
197
+ hidden_states = block(
198
+ hidden_states,
199
+ attention_mask=None,
200
+ encoder_hidden_states=None,
201
+ encoder_attention_mask=None,
202
+ timestep=timestep,
203
+ cross_attention_kwargs=cross_attention_kwargs,
204
+ class_labels=class_labels,
205
+ )
206
+
207
+ # 3. Output
208
+ conditioning = self.transformer_blocks[0].norm1.emb(timestep, class_labels, hidden_dtype=hidden_states.dtype)
209
+ shift, scale = self.proj_out_1(F.silu(conditioning)).chunk(2, dim=1)
210
+ hidden_states = self.norm_out(hidden_states) * (1 + scale[:, None]) + shift[:, None]
211
+ hidden_states = self.proj_out_2(hidden_states)
212
+
213
+ # unpatchify
214
+ height = width = int(hidden_states.shape[1] ** 0.5)
215
+ hidden_states = hidden_states.reshape(
216
+ shape=(-1, height, width, self.patch_size, self.patch_size, self.out_channels)
217
+ )
218
+ hidden_states = torch.einsum("nhwpqc->nchpwq", hidden_states)
219
+ output = hidden_states.reshape(
220
+ shape=(-1, self.out_channels, height * self.patch_size, width * self.patch_size)
221
+ )
222
+
223
+ if not return_dict:
224
+ return (output,)
225
+
226
+ return Transformer2DModelOutput(sample=output)
pythonProject/.venv/Lib/site-packages/diffusers/models/transformers/dual_transformer_2d.py ADDED
@@ -0,0 +1,156 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright 2025 The HuggingFace Team. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ from typing import Optional
15
+
16
+ from torch import nn
17
+
18
+ from ..modeling_outputs import Transformer2DModelOutput
19
+ from .transformer_2d import Transformer2DModel
20
+
21
+
22
+ class DualTransformer2DModel(nn.Module):
23
+ """
24
+ Dual transformer wrapper that combines two `Transformer2DModel`s for mixed inference.
25
+
26
+ Parameters:
27
+ num_attention_heads (`int`, *optional*, defaults to 16): The number of heads to use for multi-head attention.
28
+ attention_head_dim (`int`, *optional*, defaults to 88): The number of channels in each head.
29
+ in_channels (`int`, *optional*):
30
+ Pass if the input is continuous. The number of channels in the input and output.
31
+ num_layers (`int`, *optional*, defaults to 1): The number of layers of Transformer blocks to use.
32
+ dropout (`float`, *optional*, defaults to 0.1): The dropout probability to use.
33
+ cross_attention_dim (`int`, *optional*): The number of encoder_hidden_states dimensions to use.
34
+ sample_size (`int`, *optional*): Pass if the input is discrete. The width of the latent images.
35
+ Note that this is fixed at training time as it is used for learning a number of position embeddings. See
36
+ `ImagePositionalEmbeddings`.
37
+ num_vector_embeds (`int`, *optional*):
38
+ Pass if the input is discrete. The number of classes of the vector embeddings of the latent pixels.
39
+ Includes the class for the masked latent pixel.
40
+ activation_fn (`str`, *optional*, defaults to `"geglu"`): Activation function to be used in feed-forward.
41
+ num_embeds_ada_norm ( `int`, *optional*): Pass if at least one of the norm_layers is `AdaLayerNorm`.
42
+ The number of diffusion steps used during training. Note that this is fixed at training time as it is used
43
+ to learn a number of embeddings that are added to the hidden states. During inference, you can denoise for
44
+ up to but not more than steps than `num_embeds_ada_norm`.
45
+ attention_bias (`bool`, *optional*):
46
+ Configure if the TransformerBlocks' attention should contain a bias parameter.
47
+ """
48
+
49
+ def __init__(
50
+ self,
51
+ num_attention_heads: int = 16,
52
+ attention_head_dim: int = 88,
53
+ in_channels: Optional[int] = None,
54
+ num_layers: int = 1,
55
+ dropout: float = 0.0,
56
+ norm_num_groups: int = 32,
57
+ cross_attention_dim: Optional[int] = None,
58
+ attention_bias: bool = False,
59
+ sample_size: Optional[int] = None,
60
+ num_vector_embeds: Optional[int] = None,
61
+ activation_fn: str = "geglu",
62
+ num_embeds_ada_norm: Optional[int] = None,
63
+ ):
64
+ super().__init__()
65
+ self.transformers = nn.ModuleList(
66
+ [
67
+ Transformer2DModel(
68
+ num_attention_heads=num_attention_heads,
69
+ attention_head_dim=attention_head_dim,
70
+ in_channels=in_channels,
71
+ num_layers=num_layers,
72
+ dropout=dropout,
73
+ norm_num_groups=norm_num_groups,
74
+ cross_attention_dim=cross_attention_dim,
75
+ attention_bias=attention_bias,
76
+ sample_size=sample_size,
77
+ num_vector_embeds=num_vector_embeds,
78
+ activation_fn=activation_fn,
79
+ num_embeds_ada_norm=num_embeds_ada_norm,
80
+ )
81
+ for _ in range(2)
82
+ ]
83
+ )
84
+
85
+ # Variables that can be set by a pipeline:
86
+
87
+ # The ratio of transformer1 to transformer2's output states to be combined during inference
88
+ self.mix_ratio = 0.5
89
+
90
+ # The shape of `encoder_hidden_states` is expected to be
91
+ # `(batch_size, condition_lengths[0]+condition_lengths[1], num_features)`
92
+ self.condition_lengths = [77, 257]
93
+
94
+ # Which transformer to use to encode which condition.
95
+ # E.g. `(1, 0)` means that we'll use `transformers[1](conditions[0])` and `transformers[0](conditions[1])`
96
+ self.transformer_index_for_condition = [1, 0]
97
+
98
+ def forward(
99
+ self,
100
+ hidden_states,
101
+ encoder_hidden_states,
102
+ timestep=None,
103
+ attention_mask=None,
104
+ cross_attention_kwargs=None,
105
+ return_dict: bool = True,
106
+ ):
107
+ """
108
+ Args:
109
+ hidden_states ( When discrete, `torch.LongTensor` of shape `(batch size, num latent pixels)`.
110
+ When continuous, `torch.Tensor` of shape `(batch size, channel, height, width)`): Input hidden_states.
111
+ encoder_hidden_states ( `torch.LongTensor` of shape `(batch size, encoder_hidden_states dim)`, *optional*):
112
+ Conditional embeddings for cross attention layer. If not given, cross-attention defaults to
113
+ self-attention.
114
+ timestep ( `torch.long`, *optional*):
115
+ Optional timestep to be applied as an embedding in AdaLayerNorm's. Used to indicate denoising step.
116
+ attention_mask (`torch.Tensor`, *optional*):
117
+ Optional attention mask to be applied in Attention.
118
+ cross_attention_kwargs (`dict`, *optional*):
119
+ A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
120
+ `self.processor` in
121
+ [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
122
+ return_dict (`bool`, *optional*, defaults to `True`):
123
+ Whether or not to return a [`models.unets.unet_2d_condition.UNet2DConditionOutput`] instead of a plain
124
+ tuple.
125
+
126
+ Returns:
127
+ [`~models.transformers.transformer_2d.Transformer2DModelOutput`] or `tuple`:
128
+ [`~models.transformers.transformer_2d.Transformer2DModelOutput`] if `return_dict` is True, otherwise a
129
+ `tuple`. When returning a tuple, the first element is the sample tensor.
130
+ """
131
+ input_states = hidden_states
132
+
133
+ encoded_states = []
134
+ tokens_start = 0
135
+ # attention_mask is not used yet
136
+ for i in range(2):
137
+ # for each of the two transformers, pass the corresponding condition tokens
138
+ condition_state = encoder_hidden_states[:, tokens_start : tokens_start + self.condition_lengths[i]]
139
+ transformer_index = self.transformer_index_for_condition[i]
140
+ encoded_state = self.transformers[transformer_index](
141
+ input_states,
142
+ encoder_hidden_states=condition_state,
143
+ timestep=timestep,
144
+ cross_attention_kwargs=cross_attention_kwargs,
145
+ return_dict=False,
146
+ )[0]
147
+ encoded_states.append(encoded_state - input_states)
148
+ tokens_start += self.condition_lengths[i]
149
+
150
+ output_states = encoded_states[0] * self.mix_ratio + encoded_states[1] * (1 - self.mix_ratio)
151
+ output_states = output_states + input_states
152
+
153
+ if not return_dict:
154
+ return (output_states,)
155
+
156
+ return Transformer2DModelOutput(sample=output_states)
pythonProject/.venv/Lib/site-packages/diffusers/models/transformers/hunyuan_transformer_2d.py ADDED
@@ -0,0 +1,579 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright 2025 HunyuanDiT Authors, Qixun Wang and The HuggingFace Team. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ from typing import Dict, Optional, Union
15
+
16
+ import torch
17
+ from torch import nn
18
+
19
+ from ...configuration_utils import ConfigMixin, register_to_config
20
+ from ...utils import logging
21
+ from ...utils.torch_utils import maybe_allow_in_graph
22
+ from ..attention import FeedForward
23
+ from ..attention_processor import Attention, AttentionProcessor, FusedHunyuanAttnProcessor2_0, HunyuanAttnProcessor2_0
24
+ from ..embeddings import (
25
+ HunyuanCombinedTimestepTextSizeStyleEmbedding,
26
+ PatchEmbed,
27
+ PixArtAlphaTextProjection,
28
+ )
29
+ from ..modeling_outputs import Transformer2DModelOutput
30
+ from ..modeling_utils import ModelMixin
31
+ from ..normalization import AdaLayerNormContinuous, FP32LayerNorm
32
+
33
+
34
+ logger = logging.get_logger(__name__) # pylint: disable=invalid-name
35
+
36
+
37
+ class AdaLayerNormShift(nn.Module):
38
+ r"""
39
+ Norm layer modified to incorporate timestep embeddings.
40
+
41
+ Parameters:
42
+ embedding_dim (`int`): The size of each embedding vector.
43
+ num_embeddings (`int`): The size of the embeddings dictionary.
44
+ """
45
+
46
+ def __init__(self, embedding_dim: int, elementwise_affine=True, eps=1e-6):
47
+ super().__init__()
48
+ self.silu = nn.SiLU()
49
+ self.linear = nn.Linear(embedding_dim, embedding_dim)
50
+ self.norm = FP32LayerNorm(embedding_dim, elementwise_affine=elementwise_affine, eps=eps)
51
+
52
+ def forward(self, x: torch.Tensor, emb: torch.Tensor) -> torch.Tensor:
53
+ shift = self.linear(self.silu(emb.to(torch.float32)).to(emb.dtype))
54
+ x = self.norm(x) + shift.unsqueeze(dim=1)
55
+ return x
56
+
57
+
58
+ @maybe_allow_in_graph
59
+ class HunyuanDiTBlock(nn.Module):
60
+ r"""
61
+ Transformer block used in Hunyuan-DiT model (https://github.com/Tencent/HunyuanDiT). Allow skip connection and
62
+ QKNorm
63
+
64
+ Parameters:
65
+ dim (`int`):
66
+ The number of channels in the input and output.
67
+ num_attention_heads (`int`):
68
+ The number of headsto use for multi-head attention.
69
+ cross_attention_dim (`int`,*optional*):
70
+ The size of the encoder_hidden_states vector for cross attention.
71
+ dropout(`float`, *optional*, defaults to 0.0):
72
+ The dropout probability to use.
73
+ activation_fn (`str`,*optional*, defaults to `"geglu"`):
74
+ Activation function to be used in feed-forward. .
75
+ norm_elementwise_affine (`bool`, *optional*, defaults to `True`):
76
+ Whether to use learnable elementwise affine parameters for normalization.
77
+ norm_eps (`float`, *optional*, defaults to 1e-6):
78
+ A small constant added to the denominator in normalization layers to prevent division by zero.
79
+ final_dropout (`bool` *optional*, defaults to False):
80
+ Whether to apply a final dropout after the last feed-forward layer.
81
+ ff_inner_dim (`int`, *optional*):
82
+ The size of the hidden layer in the feed-forward block. Defaults to `None`.
83
+ ff_bias (`bool`, *optional*, defaults to `True`):
84
+ Whether to use bias in the feed-forward block.
85
+ skip (`bool`, *optional*, defaults to `False`):
86
+ Whether to use skip connection. Defaults to `False` for down-blocks and mid-blocks.
87
+ qk_norm (`bool`, *optional*, defaults to `True`):
88
+ Whether to use normalization in QK calculation. Defaults to `True`.
89
+ """
90
+
91
+ def __init__(
92
+ self,
93
+ dim: int,
94
+ num_attention_heads: int,
95
+ cross_attention_dim: int = 1024,
96
+ dropout=0.0,
97
+ activation_fn: str = "geglu",
98
+ norm_elementwise_affine: bool = True,
99
+ norm_eps: float = 1e-6,
100
+ final_dropout: bool = False,
101
+ ff_inner_dim: Optional[int] = None,
102
+ ff_bias: bool = True,
103
+ skip: bool = False,
104
+ qk_norm: bool = True,
105
+ ):
106
+ super().__init__()
107
+
108
+ # Define 3 blocks. Each block has its own normalization layer.
109
+ # NOTE: when new version comes, check norm2 and norm 3
110
+ # 1. Self-Attn
111
+ self.norm1 = AdaLayerNormShift(dim, elementwise_affine=norm_elementwise_affine, eps=norm_eps)
112
+
113
+ self.attn1 = Attention(
114
+ query_dim=dim,
115
+ cross_attention_dim=None,
116
+ dim_head=dim // num_attention_heads,
117
+ heads=num_attention_heads,
118
+ qk_norm="layer_norm" if qk_norm else None,
119
+ eps=1e-6,
120
+ bias=True,
121
+ processor=HunyuanAttnProcessor2_0(),
122
+ )
123
+
124
+ # 2. Cross-Attn
125
+ self.norm2 = FP32LayerNorm(dim, norm_eps, norm_elementwise_affine)
126
+
127
+ self.attn2 = Attention(
128
+ query_dim=dim,
129
+ cross_attention_dim=cross_attention_dim,
130
+ dim_head=dim // num_attention_heads,
131
+ heads=num_attention_heads,
132
+ qk_norm="layer_norm" if qk_norm else None,
133
+ eps=1e-6,
134
+ bias=True,
135
+ processor=HunyuanAttnProcessor2_0(),
136
+ )
137
+ # 3. Feed-forward
138
+ self.norm3 = FP32LayerNorm(dim, norm_eps, norm_elementwise_affine)
139
+
140
+ self.ff = FeedForward(
141
+ dim,
142
+ dropout=dropout, ### 0.0
143
+ activation_fn=activation_fn, ### approx GeLU
144
+ final_dropout=final_dropout, ### 0.0
145
+ inner_dim=ff_inner_dim, ### int(dim * mlp_ratio)
146
+ bias=ff_bias,
147
+ )
148
+
149
+ # 4. Skip Connection
150
+ if skip:
151
+ self.skip_norm = FP32LayerNorm(2 * dim, norm_eps, elementwise_affine=True)
152
+ self.skip_linear = nn.Linear(2 * dim, dim)
153
+ else:
154
+ self.skip_linear = None
155
+
156
+ # let chunk size default to None
157
+ self._chunk_size = None
158
+ self._chunk_dim = 0
159
+
160
+ # Copied from diffusers.models.attention.BasicTransformerBlock.set_chunk_feed_forward
161
+ def set_chunk_feed_forward(self, chunk_size: Optional[int], dim: int = 0):
162
+ # Sets chunk feed-forward
163
+ self._chunk_size = chunk_size
164
+ self._chunk_dim = dim
165
+
166
+ def forward(
167
+ self,
168
+ hidden_states: torch.Tensor,
169
+ encoder_hidden_states: Optional[torch.Tensor] = None,
170
+ temb: Optional[torch.Tensor] = None,
171
+ image_rotary_emb=None,
172
+ skip=None,
173
+ ) -> torch.Tensor:
174
+ # Notice that normalization is always applied before the real computation in the following blocks.
175
+ # 0. Long Skip Connection
176
+ if self.skip_linear is not None:
177
+ cat = torch.cat([hidden_states, skip], dim=-1)
178
+ cat = self.skip_norm(cat)
179
+ hidden_states = self.skip_linear(cat)
180
+
181
+ # 1. Self-Attention
182
+ norm_hidden_states = self.norm1(hidden_states, temb) ### checked: self.norm1 is correct
183
+ attn_output = self.attn1(
184
+ norm_hidden_states,
185
+ image_rotary_emb=image_rotary_emb,
186
+ )
187
+ hidden_states = hidden_states + attn_output
188
+
189
+ # 2. Cross-Attention
190
+ hidden_states = hidden_states + self.attn2(
191
+ self.norm2(hidden_states),
192
+ encoder_hidden_states=encoder_hidden_states,
193
+ image_rotary_emb=image_rotary_emb,
194
+ )
195
+
196
+ # FFN Layer ### TODO: switch norm2 and norm3 in the state dict
197
+ mlp_inputs = self.norm3(hidden_states)
198
+ hidden_states = hidden_states + self.ff(mlp_inputs)
199
+
200
+ return hidden_states
201
+
202
+
203
+ class HunyuanDiT2DModel(ModelMixin, ConfigMixin):
204
+ """
205
+ HunYuanDiT: Diffusion model with a Transformer backbone.
206
+
207
+ Inherit ModelMixin and ConfigMixin to be compatible with the sampler StableDiffusionPipeline of diffusers.
208
+
209
+ Parameters:
210
+ num_attention_heads (`int`, *optional*, defaults to 16):
211
+ The number of heads to use for multi-head attention.
212
+ attention_head_dim (`int`, *optional*, defaults to 88):
213
+ The number of channels in each head.
214
+ in_channels (`int`, *optional*):
215
+ The number of channels in the input and output (specify if the input is **continuous**).
216
+ patch_size (`int`, *optional*):
217
+ The size of the patch to use for the input.
218
+ activation_fn (`str`, *optional*, defaults to `"geglu"`):
219
+ Activation function to use in feed-forward.
220
+ sample_size (`int`, *optional*):
221
+ The width of the latent images. This is fixed during training since it is used to learn a number of
222
+ position embeddings.
223
+ dropout (`float`, *optional*, defaults to 0.0):
224
+ The dropout probability to use.
225
+ cross_attention_dim (`int`, *optional*):
226
+ The number of dimension in the clip text embedding.
227
+ hidden_size (`int`, *optional*):
228
+ The size of hidden layer in the conditioning embedding layers.
229
+ num_layers (`int`, *optional*, defaults to 1):
230
+ The number of layers of Transformer blocks to use.
231
+ mlp_ratio (`float`, *optional*, defaults to 4.0):
232
+ The ratio of the hidden layer size to the input size.
233
+ learn_sigma (`bool`, *optional*, defaults to `True`):
234
+ Whether to predict variance.
235
+ cross_attention_dim_t5 (`int`, *optional*):
236
+ The number dimensions in t5 text embedding.
237
+ pooled_projection_dim (`int`, *optional*):
238
+ The size of the pooled projection.
239
+ text_len (`int`, *optional*):
240
+ The length of the clip text embedding.
241
+ text_len_t5 (`int`, *optional*):
242
+ The length of the T5 text embedding.
243
+ use_style_cond_and_image_meta_size (`bool`, *optional*):
244
+ Whether or not to use style condition and image meta size. True for version <=1.1, False for version >= 1.2
245
+ """
246
+
247
+ _skip_layerwise_casting_patterns = ["pos_embed", "norm", "pooler"]
248
+ _supports_group_offloading = False
249
+
250
+ @register_to_config
251
+ def __init__(
252
+ self,
253
+ num_attention_heads: int = 16,
254
+ attention_head_dim: int = 88,
255
+ in_channels: Optional[int] = None,
256
+ patch_size: Optional[int] = None,
257
+ activation_fn: str = "gelu-approximate",
258
+ sample_size=32,
259
+ hidden_size=1152,
260
+ num_layers: int = 28,
261
+ mlp_ratio: float = 4.0,
262
+ learn_sigma: bool = True,
263
+ cross_attention_dim: int = 1024,
264
+ norm_type: str = "layer_norm",
265
+ cross_attention_dim_t5: int = 2048,
266
+ pooled_projection_dim: int = 1024,
267
+ text_len: int = 77,
268
+ text_len_t5: int = 256,
269
+ use_style_cond_and_image_meta_size: bool = True,
270
+ ):
271
+ super().__init__()
272
+ self.out_channels = in_channels * 2 if learn_sigma else in_channels
273
+ self.num_heads = num_attention_heads
274
+ self.inner_dim = num_attention_heads * attention_head_dim
275
+
276
+ self.text_embedder = PixArtAlphaTextProjection(
277
+ in_features=cross_attention_dim_t5,
278
+ hidden_size=cross_attention_dim_t5 * 4,
279
+ out_features=cross_attention_dim,
280
+ act_fn="silu_fp32",
281
+ )
282
+
283
+ self.text_embedding_padding = nn.Parameter(torch.randn(text_len + text_len_t5, cross_attention_dim))
284
+
285
+ self.pos_embed = PatchEmbed(
286
+ height=sample_size,
287
+ width=sample_size,
288
+ in_channels=in_channels,
289
+ embed_dim=hidden_size,
290
+ patch_size=patch_size,
291
+ pos_embed_type=None,
292
+ )
293
+
294
+ self.time_extra_emb = HunyuanCombinedTimestepTextSizeStyleEmbedding(
295
+ hidden_size,
296
+ pooled_projection_dim=pooled_projection_dim,
297
+ seq_len=text_len_t5,
298
+ cross_attention_dim=cross_attention_dim_t5,
299
+ use_style_cond_and_image_meta_size=use_style_cond_and_image_meta_size,
300
+ )
301
+
302
+ # HunyuanDiT Blocks
303
+ self.blocks = nn.ModuleList(
304
+ [
305
+ HunyuanDiTBlock(
306
+ dim=self.inner_dim,
307
+ num_attention_heads=self.config.num_attention_heads,
308
+ activation_fn=activation_fn,
309
+ ff_inner_dim=int(self.inner_dim * mlp_ratio),
310
+ cross_attention_dim=cross_attention_dim,
311
+ qk_norm=True, # See https://huggingface.co/papers/2302.05442 for details.
312
+ skip=layer > num_layers // 2,
313
+ )
314
+ for layer in range(num_layers)
315
+ ]
316
+ )
317
+
318
+ self.norm_out = AdaLayerNormContinuous(self.inner_dim, self.inner_dim, elementwise_affine=False, eps=1e-6)
319
+ self.proj_out = nn.Linear(self.inner_dim, patch_size * patch_size * self.out_channels, bias=True)
320
+
321
+ # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.fuse_qkv_projections with FusedAttnProcessor2_0->FusedHunyuanAttnProcessor2_0
322
+ def fuse_qkv_projections(self):
323
+ """
324
+ Enables fused QKV projections. For self-attention modules, all projection matrices (i.e., query, key, value)
325
+ are fused. For cross-attention modules, key and value projection matrices are fused.
326
+
327
+ <Tip warning={true}>
328
+
329
+ This API is 🧪 experimental.
330
+
331
+ </Tip>
332
+ """
333
+ self.original_attn_processors = None
334
+
335
+ for _, attn_processor in self.attn_processors.items():
336
+ if "Added" in str(attn_processor.__class__.__name__):
337
+ raise ValueError("`fuse_qkv_projections()` is not supported for models having added KV projections.")
338
+
339
+ self.original_attn_processors = self.attn_processors
340
+
341
+ for module in self.modules():
342
+ if isinstance(module, Attention):
343
+ module.fuse_projections(fuse=True)
344
+
345
+ self.set_attn_processor(FusedHunyuanAttnProcessor2_0())
346
+
347
+ # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.unfuse_qkv_projections
348
+ def unfuse_qkv_projections(self):
349
+ """Disables the fused QKV projection if enabled.
350
+
351
+ <Tip warning={true}>
352
+
353
+ This API is 🧪 experimental.
354
+
355
+ </Tip>
356
+
357
+ """
358
+ if self.original_attn_processors is not None:
359
+ self.set_attn_processor(self.original_attn_processors)
360
+
361
+ @property
362
+ # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.attn_processors
363
+ def attn_processors(self) -> Dict[str, AttentionProcessor]:
364
+ r"""
365
+ Returns:
366
+ `dict` of attention processors: A dictionary containing all attention processors used in the model with
367
+ indexed by its weight name.
368
+ """
369
+ # set recursively
370
+ processors = {}
371
+
372
+ def fn_recursive_add_processors(name: str, module: torch.nn.Module, processors: Dict[str, AttentionProcessor]):
373
+ if hasattr(module, "get_processor"):
374
+ processors[f"{name}.processor"] = module.get_processor()
375
+
376
+ for sub_name, child in module.named_children():
377
+ fn_recursive_add_processors(f"{name}.{sub_name}", child, processors)
378
+
379
+ return processors
380
+
381
+ for name, module in self.named_children():
382
+ fn_recursive_add_processors(name, module, processors)
383
+
384
+ return processors
385
+
386
+ # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.set_attn_processor
387
+ def set_attn_processor(self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]]):
388
+ r"""
389
+ Sets the attention processor to use to compute attention.
390
+
391
+ Parameters:
392
+ processor (`dict` of `AttentionProcessor` or only `AttentionProcessor`):
393
+ The instantiated processor class or a dictionary of processor classes that will be set as the processor
394
+ for **all** `Attention` layers.
395
+
396
+ If `processor` is a dict, the key needs to define the path to the corresponding cross attention
397
+ processor. This is strongly recommended when setting trainable attention processors.
398
+
399
+ """
400
+ count = len(self.attn_processors.keys())
401
+
402
+ if isinstance(processor, dict) and len(processor) != count:
403
+ raise ValueError(
404
+ f"A dict of processors was passed, but the number of processors {len(processor)} does not match the"
405
+ f" number of attention layers: {count}. Please make sure to pass {count} processor classes."
406
+ )
407
+
408
+ def fn_recursive_attn_processor(name: str, module: torch.nn.Module, processor):
409
+ if hasattr(module, "set_processor"):
410
+ if not isinstance(processor, dict):
411
+ module.set_processor(processor)
412
+ else:
413
+ module.set_processor(processor.pop(f"{name}.processor"))
414
+
415
+ for sub_name, child in module.named_children():
416
+ fn_recursive_attn_processor(f"{name}.{sub_name}", child, processor)
417
+
418
+ for name, module in self.named_children():
419
+ fn_recursive_attn_processor(name, module, processor)
420
+
421
+ def set_default_attn_processor(self):
422
+ """
423
+ Disables custom attention processors and sets the default attention implementation.
424
+ """
425
+ self.set_attn_processor(HunyuanAttnProcessor2_0())
426
+
427
+ def forward(
428
+ self,
429
+ hidden_states,
430
+ timestep,
431
+ encoder_hidden_states=None,
432
+ text_embedding_mask=None,
433
+ encoder_hidden_states_t5=None,
434
+ text_embedding_mask_t5=None,
435
+ image_meta_size=None,
436
+ style=None,
437
+ image_rotary_emb=None,
438
+ controlnet_block_samples=None,
439
+ return_dict=True,
440
+ ):
441
+ """
442
+ The [`HunyuanDiT2DModel`] forward method.
443
+
444
+ Args:
445
+ hidden_states (`torch.Tensor` of shape `(batch size, dim, height, width)`):
446
+ The input tensor.
447
+ timestep ( `torch.LongTensor`, *optional*):
448
+ Used to indicate denoising step.
449
+ encoder_hidden_states ( `torch.Tensor` of shape `(batch size, sequence len, embed dims)`, *optional*):
450
+ Conditional embeddings for cross attention layer. This is the output of `BertModel`.
451
+ text_embedding_mask: torch.Tensor
452
+ An attention mask of shape `(batch, key_tokens)` is applied to `encoder_hidden_states`. This is the output
453
+ of `BertModel`.
454
+ encoder_hidden_states_t5 ( `torch.Tensor` of shape `(batch size, sequence len, embed dims)`, *optional*):
455
+ Conditional embeddings for cross attention layer. This is the output of T5 Text Encoder.
456
+ text_embedding_mask_t5: torch.Tensor
457
+ An attention mask of shape `(batch, key_tokens)` is applied to `encoder_hidden_states`. This is the output
458
+ of T5 Text Encoder.
459
+ image_meta_size (torch.Tensor):
460
+ Conditional embedding indicate the image sizes
461
+ style: torch.Tensor:
462
+ Conditional embedding indicate the style
463
+ image_rotary_emb (`torch.Tensor`):
464
+ The image rotary embeddings to apply on query and key tensors during attention calculation.
465
+ return_dict: bool
466
+ Whether to return a dictionary.
467
+ """
468
+
469
+ height, width = hidden_states.shape[-2:]
470
+
471
+ hidden_states = self.pos_embed(hidden_states)
472
+
473
+ temb = self.time_extra_emb(
474
+ timestep, encoder_hidden_states_t5, image_meta_size, style, hidden_dtype=timestep.dtype
475
+ ) # [B, D]
476
+
477
+ # text projection
478
+ batch_size, sequence_length, _ = encoder_hidden_states_t5.shape
479
+ encoder_hidden_states_t5 = self.text_embedder(
480
+ encoder_hidden_states_t5.view(-1, encoder_hidden_states_t5.shape[-1])
481
+ )
482
+ encoder_hidden_states_t5 = encoder_hidden_states_t5.view(batch_size, sequence_length, -1)
483
+
484
+ encoder_hidden_states = torch.cat([encoder_hidden_states, encoder_hidden_states_t5], dim=1)
485
+ text_embedding_mask = torch.cat([text_embedding_mask, text_embedding_mask_t5], dim=-1)
486
+ text_embedding_mask = text_embedding_mask.unsqueeze(2).bool()
487
+
488
+ encoder_hidden_states = torch.where(text_embedding_mask, encoder_hidden_states, self.text_embedding_padding)
489
+
490
+ skips = []
491
+ for layer, block in enumerate(self.blocks):
492
+ if layer > self.config.num_layers // 2:
493
+ if controlnet_block_samples is not None:
494
+ skip = skips.pop() + controlnet_block_samples.pop()
495
+ else:
496
+ skip = skips.pop()
497
+ hidden_states = block(
498
+ hidden_states,
499
+ temb=temb,
500
+ encoder_hidden_states=encoder_hidden_states,
501
+ image_rotary_emb=image_rotary_emb,
502
+ skip=skip,
503
+ ) # (N, L, D)
504
+ else:
505
+ hidden_states = block(
506
+ hidden_states,
507
+ temb=temb,
508
+ encoder_hidden_states=encoder_hidden_states,
509
+ image_rotary_emb=image_rotary_emb,
510
+ ) # (N, L, D)
511
+
512
+ if layer < (self.config.num_layers // 2 - 1):
513
+ skips.append(hidden_states)
514
+
515
+ if controlnet_block_samples is not None and len(controlnet_block_samples) != 0:
516
+ raise ValueError("The number of controls is not equal to the number of skip connections.")
517
+
518
+ # final layer
519
+ hidden_states = self.norm_out(hidden_states, temb.to(torch.float32))
520
+ hidden_states = self.proj_out(hidden_states)
521
+ # (N, L, patch_size ** 2 * out_channels)
522
+
523
+ # unpatchify: (N, out_channels, H, W)
524
+ patch_size = self.pos_embed.patch_size
525
+ height = height // patch_size
526
+ width = width // patch_size
527
+
528
+ hidden_states = hidden_states.reshape(
529
+ shape=(hidden_states.shape[0], height, width, patch_size, patch_size, self.out_channels)
530
+ )
531
+ hidden_states = torch.einsum("nhwpqc->nchpwq", hidden_states)
532
+ output = hidden_states.reshape(
533
+ shape=(hidden_states.shape[0], self.out_channels, height * patch_size, width * patch_size)
534
+ )
535
+ if not return_dict:
536
+ return (output,)
537
+ return Transformer2DModelOutput(sample=output)
538
+
539
+ # Copied from diffusers.models.unets.unet_3d_condition.UNet3DConditionModel.enable_forward_chunking
540
+ def enable_forward_chunking(self, chunk_size: Optional[int] = None, dim: int = 0) -> None:
541
+ """
542
+ Sets the attention processor to use [feed forward
543
+ chunking](https://huggingface.co/blog/reformer#2-chunked-feed-forward-layers).
544
+
545
+ Parameters:
546
+ chunk_size (`int`, *optional*):
547
+ The chunk size of the feed-forward layers. If not specified, will run feed-forward layer individually
548
+ over each tensor of dim=`dim`.
549
+ dim (`int`, *optional*, defaults to `0`):
550
+ The dimension over which the feed-forward computation should be chunked. Choose between dim=0 (batch)
551
+ or dim=1 (sequence length).
552
+ """
553
+ if dim not in [0, 1]:
554
+ raise ValueError(f"Make sure to set `dim` to either 0 or 1, not {dim}")
555
+
556
+ # By default chunk size is 1
557
+ chunk_size = chunk_size or 1
558
+
559
+ def fn_recursive_feed_forward(module: torch.nn.Module, chunk_size: int, dim: int):
560
+ if hasattr(module, "set_chunk_feed_forward"):
561
+ module.set_chunk_feed_forward(chunk_size=chunk_size, dim=dim)
562
+
563
+ for child in module.children():
564
+ fn_recursive_feed_forward(child, chunk_size, dim)
565
+
566
+ for module in self.children():
567
+ fn_recursive_feed_forward(module, chunk_size, dim)
568
+
569
+ # Copied from diffusers.models.unets.unet_3d_condition.UNet3DConditionModel.disable_forward_chunking
570
+ def disable_forward_chunking(self):
571
+ def fn_recursive_feed_forward(module: torch.nn.Module, chunk_size: int, dim: int):
572
+ if hasattr(module, "set_chunk_feed_forward"):
573
+ module.set_chunk_feed_forward(chunk_size=chunk_size, dim=dim)
574
+
575
+ for child in module.children():
576
+ fn_recursive_feed_forward(child, chunk_size, dim)
577
+
578
+ for module in self.children():
579
+ fn_recursive_feed_forward(module, None, 0)
pythonProject/.venv/Lib/site-packages/diffusers/models/transformers/latte_transformer_3d.py ADDED
@@ -0,0 +1,331 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright 2025 the Latte Team and The HuggingFace Team. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ from typing import Optional
16
+
17
+ import torch
18
+ from torch import nn
19
+
20
+ from ...configuration_utils import ConfigMixin, register_to_config
21
+ from ..attention import BasicTransformerBlock
22
+ from ..cache_utils import CacheMixin
23
+ from ..embeddings import PatchEmbed, PixArtAlphaTextProjection, get_1d_sincos_pos_embed_from_grid
24
+ from ..modeling_outputs import Transformer2DModelOutput
25
+ from ..modeling_utils import ModelMixin
26
+ from ..normalization import AdaLayerNormSingle
27
+
28
+
29
+ class LatteTransformer3DModel(ModelMixin, ConfigMixin, CacheMixin):
30
+ _supports_gradient_checkpointing = True
31
+
32
+ """
33
+ A 3D Transformer model for video-like data, paper: https://huggingface.co/papers/2401.03048, official code:
34
+ https://github.com/Vchitect/Latte
35
+
36
+ Parameters:
37
+ num_attention_heads (`int`, *optional*, defaults to 16): The number of heads to use for multi-head attention.
38
+ attention_head_dim (`int`, *optional*, defaults to 88): The number of channels in each head.
39
+ in_channels (`int`, *optional*):
40
+ The number of channels in the input.
41
+ out_channels (`int`, *optional*):
42
+ The number of channels in the output.
43
+ num_layers (`int`, *optional*, defaults to 1): The number of layers of Transformer blocks to use.
44
+ dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use.
45
+ cross_attention_dim (`int`, *optional*): The number of `encoder_hidden_states` dimensions to use.
46
+ attention_bias (`bool`, *optional*):
47
+ Configure if the `TransformerBlocks` attention should contain a bias parameter.
48
+ sample_size (`int`, *optional*): The width of the latent images (specify if the input is **discrete**).
49
+ This is fixed during training since it is used to learn a number of position embeddings.
50
+ patch_size (`int`, *optional*):
51
+ The size of the patches to use in the patch embedding layer.
52
+ activation_fn (`str`, *optional*, defaults to `"geglu"`): Activation function to use in feed-forward.
53
+ num_embeds_ada_norm ( `int`, *optional*):
54
+ The number of diffusion steps used during training. Pass if at least one of the norm_layers is
55
+ `AdaLayerNorm`. This is fixed during training since it is used to learn a number of embeddings that are
56
+ added to the hidden states. During inference, you can denoise for up to but not more steps than
57
+ `num_embeds_ada_norm`.
58
+ norm_type (`str`, *optional*, defaults to `"layer_norm"`):
59
+ The type of normalization to use. Options are `"layer_norm"` or `"ada_layer_norm"`.
60
+ norm_elementwise_affine (`bool`, *optional*, defaults to `True`):
61
+ Whether or not to use elementwise affine in normalization layers.
62
+ norm_eps (`float`, *optional*, defaults to 1e-5): The epsilon value to use in normalization layers.
63
+ caption_channels (`int`, *optional*):
64
+ The number of channels in the caption embeddings.
65
+ video_length (`int`, *optional*):
66
+ The number of frames in the video-like data.
67
+ """
68
+
69
+ _skip_layerwise_casting_patterns = ["pos_embed", "norm"]
70
+
71
+ @register_to_config
72
+ def __init__(
73
+ self,
74
+ num_attention_heads: int = 16,
75
+ attention_head_dim: int = 88,
76
+ in_channels: Optional[int] = None,
77
+ out_channels: Optional[int] = None,
78
+ num_layers: int = 1,
79
+ dropout: float = 0.0,
80
+ cross_attention_dim: Optional[int] = None,
81
+ attention_bias: bool = False,
82
+ sample_size: int = 64,
83
+ patch_size: Optional[int] = None,
84
+ activation_fn: str = "geglu",
85
+ num_embeds_ada_norm: Optional[int] = None,
86
+ norm_type: str = "layer_norm",
87
+ norm_elementwise_affine: bool = True,
88
+ norm_eps: float = 1e-5,
89
+ caption_channels: int = None,
90
+ video_length: int = 16,
91
+ ):
92
+ super().__init__()
93
+ inner_dim = num_attention_heads * attention_head_dim
94
+
95
+ # 1. Define input layers
96
+ self.height = sample_size
97
+ self.width = sample_size
98
+
99
+ interpolation_scale = self.config.sample_size // 64
100
+ interpolation_scale = max(interpolation_scale, 1)
101
+ self.pos_embed = PatchEmbed(
102
+ height=sample_size,
103
+ width=sample_size,
104
+ patch_size=patch_size,
105
+ in_channels=in_channels,
106
+ embed_dim=inner_dim,
107
+ interpolation_scale=interpolation_scale,
108
+ )
109
+
110
+ # 2. Define spatial transformers blocks
111
+ self.transformer_blocks = nn.ModuleList(
112
+ [
113
+ BasicTransformerBlock(
114
+ inner_dim,
115
+ num_attention_heads,
116
+ attention_head_dim,
117
+ dropout=dropout,
118
+ cross_attention_dim=cross_attention_dim,
119
+ activation_fn=activation_fn,
120
+ num_embeds_ada_norm=num_embeds_ada_norm,
121
+ attention_bias=attention_bias,
122
+ norm_type=norm_type,
123
+ norm_elementwise_affine=norm_elementwise_affine,
124
+ norm_eps=norm_eps,
125
+ )
126
+ for d in range(num_layers)
127
+ ]
128
+ )
129
+
130
+ # 3. Define temporal transformers blocks
131
+ self.temporal_transformer_blocks = nn.ModuleList(
132
+ [
133
+ BasicTransformerBlock(
134
+ inner_dim,
135
+ num_attention_heads,
136
+ attention_head_dim,
137
+ dropout=dropout,
138
+ cross_attention_dim=None,
139
+ activation_fn=activation_fn,
140
+ num_embeds_ada_norm=num_embeds_ada_norm,
141
+ attention_bias=attention_bias,
142
+ norm_type=norm_type,
143
+ norm_elementwise_affine=norm_elementwise_affine,
144
+ norm_eps=norm_eps,
145
+ )
146
+ for d in range(num_layers)
147
+ ]
148
+ )
149
+
150
+ # 4. Define output layers
151
+ self.out_channels = in_channels if out_channels is None else out_channels
152
+ self.norm_out = nn.LayerNorm(inner_dim, elementwise_affine=False, eps=1e-6)
153
+ self.scale_shift_table = nn.Parameter(torch.randn(2, inner_dim) / inner_dim**0.5)
154
+ self.proj_out = nn.Linear(inner_dim, patch_size * patch_size * self.out_channels)
155
+
156
+ # 5. Latte other blocks.
157
+ self.adaln_single = AdaLayerNormSingle(inner_dim, use_additional_conditions=False)
158
+ self.caption_projection = PixArtAlphaTextProjection(in_features=caption_channels, hidden_size=inner_dim)
159
+
160
+ # define temporal positional embedding
161
+ temp_pos_embed = get_1d_sincos_pos_embed_from_grid(
162
+ inner_dim, torch.arange(0, video_length).unsqueeze(1), output_type="pt"
163
+ ) # 1152 hidden size
164
+ self.register_buffer("temp_pos_embed", temp_pos_embed.float().unsqueeze(0), persistent=False)
165
+
166
+ self.gradient_checkpointing = False
167
+
168
+ def forward(
169
+ self,
170
+ hidden_states: torch.Tensor,
171
+ timestep: Optional[torch.LongTensor] = None,
172
+ encoder_hidden_states: Optional[torch.Tensor] = None,
173
+ encoder_attention_mask: Optional[torch.Tensor] = None,
174
+ enable_temporal_attentions: bool = True,
175
+ return_dict: bool = True,
176
+ ):
177
+ """
178
+ The [`LatteTransformer3DModel`] forward method.
179
+
180
+ Args:
181
+ hidden_states shape `(batch size, channel, num_frame, height, width)`:
182
+ Input `hidden_states`.
183
+ timestep ( `torch.LongTensor`, *optional*):
184
+ Used to indicate denoising step. Optional timestep to be applied as an embedding in `AdaLayerNorm`.
185
+ encoder_hidden_states ( `torch.FloatTensor` of shape `(batch size, sequence len, embed dims)`, *optional*):
186
+ Conditional embeddings for cross attention layer. If not given, cross-attention defaults to
187
+ self-attention.
188
+ encoder_attention_mask ( `torch.Tensor`, *optional*):
189
+ Cross-attention mask applied to `encoder_hidden_states`. Two formats supported:
190
+
191
+ * Mask `(batcheight, sequence_length)` True = keep, False = discard.
192
+ * Bias `(batcheight, 1, sequence_length)` 0 = keep, -10000 = discard.
193
+
194
+ If `ndim == 2`: will be interpreted as a mask, then converted into a bias consistent with the format
195
+ above. This bias will be added to the cross-attention scores.
196
+ enable_temporal_attentions:
197
+ (`bool`, *optional*, defaults to `True`): Whether to enable temporal attentions.
198
+ return_dict (`bool`, *optional*, defaults to `True`):
199
+ Whether or not to return a [`~models.unet_2d_condition.UNet2DConditionOutput`] instead of a plain
200
+ tuple.
201
+
202
+ Returns:
203
+ If `return_dict` is True, an [`~models.transformer_2d.Transformer2DModelOutput`] is returned, otherwise a
204
+ `tuple` where the first element is the sample tensor.
205
+ """
206
+
207
+ # Reshape hidden states
208
+ batch_size, channels, num_frame, height, width = hidden_states.shape
209
+ # batch_size channels num_frame height width -> (batch_size * num_frame) channels height width
210
+ hidden_states = hidden_states.permute(0, 2, 1, 3, 4).reshape(-1, channels, height, width)
211
+
212
+ # Input
213
+ height, width = (
214
+ hidden_states.shape[-2] // self.config.patch_size,
215
+ hidden_states.shape[-1] // self.config.patch_size,
216
+ )
217
+ num_patches = height * width
218
+
219
+ hidden_states = self.pos_embed(hidden_states) # already add positional embeddings
220
+
221
+ added_cond_kwargs = {"resolution": None, "aspect_ratio": None}
222
+ timestep, embedded_timestep = self.adaln_single(
223
+ timestep, added_cond_kwargs=added_cond_kwargs, batch_size=batch_size, hidden_dtype=hidden_states.dtype
224
+ )
225
+
226
+ # Prepare text embeddings for spatial block
227
+ # batch_size num_tokens hidden_size -> (batch_size * num_frame) num_tokens hidden_size
228
+ encoder_hidden_states = self.caption_projection(encoder_hidden_states) # 3 120 1152
229
+ encoder_hidden_states_spatial = encoder_hidden_states.repeat_interleave(
230
+ num_frame, dim=0, output_size=encoder_hidden_states.shape[0] * num_frame
231
+ ).view(-1, encoder_hidden_states.shape[-2], encoder_hidden_states.shape[-1])
232
+
233
+ # Prepare timesteps for spatial and temporal block
234
+ timestep_spatial = timestep.repeat_interleave(
235
+ num_frame, dim=0, output_size=timestep.shape[0] * num_frame
236
+ ).view(-1, timestep.shape[-1])
237
+ timestep_temp = timestep.repeat_interleave(
238
+ num_patches, dim=0, output_size=timestep.shape[0] * num_patches
239
+ ).view(-1, timestep.shape[-1])
240
+
241
+ # Spatial and temporal transformer blocks
242
+ for i, (spatial_block, temp_block) in enumerate(
243
+ zip(self.transformer_blocks, self.temporal_transformer_blocks)
244
+ ):
245
+ if torch.is_grad_enabled() and self.gradient_checkpointing:
246
+ hidden_states = self._gradient_checkpointing_func(
247
+ spatial_block,
248
+ hidden_states,
249
+ None, # attention_mask
250
+ encoder_hidden_states_spatial,
251
+ encoder_attention_mask,
252
+ timestep_spatial,
253
+ None, # cross_attention_kwargs
254
+ None, # class_labels
255
+ )
256
+ else:
257
+ hidden_states = spatial_block(
258
+ hidden_states,
259
+ None, # attention_mask
260
+ encoder_hidden_states_spatial,
261
+ encoder_attention_mask,
262
+ timestep_spatial,
263
+ None, # cross_attention_kwargs
264
+ None, # class_labels
265
+ )
266
+
267
+ if enable_temporal_attentions:
268
+ # (batch_size * num_frame) num_tokens hidden_size -> (batch_size * num_tokens) num_frame hidden_size
269
+ hidden_states = hidden_states.reshape(
270
+ batch_size, -1, hidden_states.shape[-2], hidden_states.shape[-1]
271
+ ).permute(0, 2, 1, 3)
272
+ hidden_states = hidden_states.reshape(-1, hidden_states.shape[-2], hidden_states.shape[-1])
273
+
274
+ if i == 0 and num_frame > 1:
275
+ hidden_states = hidden_states + self.temp_pos_embed.to(hidden_states.dtype)
276
+
277
+ if torch.is_grad_enabled() and self.gradient_checkpointing:
278
+ hidden_states = self._gradient_checkpointing_func(
279
+ temp_block,
280
+ hidden_states,
281
+ None, # attention_mask
282
+ None, # encoder_hidden_states
283
+ None, # encoder_attention_mask
284
+ timestep_temp,
285
+ None, # cross_attention_kwargs
286
+ None, # class_labels
287
+ )
288
+ else:
289
+ hidden_states = temp_block(
290
+ hidden_states,
291
+ None, # attention_mask
292
+ None, # encoder_hidden_states
293
+ None, # encoder_attention_mask
294
+ timestep_temp,
295
+ None, # cross_attention_kwargs
296
+ None, # class_labels
297
+ )
298
+
299
+ # (batch_size * num_tokens) num_frame hidden_size -> (batch_size * num_frame) num_tokens hidden_size
300
+ hidden_states = hidden_states.reshape(
301
+ batch_size, -1, hidden_states.shape[-2], hidden_states.shape[-1]
302
+ ).permute(0, 2, 1, 3)
303
+ hidden_states = hidden_states.reshape(-1, hidden_states.shape[-2], hidden_states.shape[-1])
304
+
305
+ embedded_timestep = embedded_timestep.repeat_interleave(
306
+ num_frame, dim=0, output_size=embedded_timestep.shape[0] * num_frame
307
+ ).view(-1, embedded_timestep.shape[-1])
308
+ shift, scale = (self.scale_shift_table[None] + embedded_timestep[:, None]).chunk(2, dim=1)
309
+ hidden_states = self.norm_out(hidden_states)
310
+ # Modulation
311
+ hidden_states = hidden_states * (1 + scale) + shift
312
+ hidden_states = self.proj_out(hidden_states)
313
+
314
+ # unpatchify
315
+ if self.adaln_single is None:
316
+ height = width = int(hidden_states.shape[1] ** 0.5)
317
+ hidden_states = hidden_states.reshape(
318
+ shape=(-1, height, width, self.config.patch_size, self.config.patch_size, self.out_channels)
319
+ )
320
+ hidden_states = torch.einsum("nhwpqc->nchpwq", hidden_states)
321
+ output = hidden_states.reshape(
322
+ shape=(-1, self.out_channels, height * self.config.patch_size, width * self.config.patch_size)
323
+ )
324
+ output = output.reshape(batch_size, -1, output.shape[-3], output.shape[-2], output.shape[-1]).permute(
325
+ 0, 2, 1, 3, 4
326
+ )
327
+
328
+ if not return_dict:
329
+ return (output,)
330
+
331
+ return Transformer2DModelOutput(sample=output)
pythonProject/.venv/Lib/site-packages/diffusers/models/transformers/t5_film_transformer.py ADDED
@@ -0,0 +1,436 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright 2025 The HuggingFace Team. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ import math
15
+ from typing import Optional, Tuple
16
+
17
+ import torch
18
+ from torch import nn
19
+
20
+ from ...configuration_utils import ConfigMixin, register_to_config
21
+ from ..attention_processor import Attention
22
+ from ..embeddings import get_timestep_embedding
23
+ from ..modeling_utils import ModelMixin
24
+
25
+
26
+ class T5FilmDecoder(ModelMixin, ConfigMixin):
27
+ r"""
28
+ T5 style decoder with FiLM conditioning.
29
+
30
+ Args:
31
+ input_dims (`int`, *optional*, defaults to `128`):
32
+ The number of input dimensions.
33
+ targets_length (`int`, *optional*, defaults to `256`):
34
+ The length of the targets.
35
+ d_model (`int`, *optional*, defaults to `768`):
36
+ Size of the input hidden states.
37
+ num_layers (`int`, *optional*, defaults to `12`):
38
+ The number of `DecoderLayer`'s to use.
39
+ num_heads (`int`, *optional*, defaults to `12`):
40
+ The number of attention heads to use.
41
+ d_kv (`int`, *optional*, defaults to `64`):
42
+ Size of the key-value projection vectors.
43
+ d_ff (`int`, *optional*, defaults to `2048`):
44
+ The number of dimensions in the intermediate feed-forward layer of `DecoderLayer`'s.
45
+ dropout_rate (`float`, *optional*, defaults to `0.1`):
46
+ Dropout probability.
47
+ """
48
+
49
+ @register_to_config
50
+ def __init__(
51
+ self,
52
+ input_dims: int = 128,
53
+ targets_length: int = 256,
54
+ max_decoder_noise_time: float = 2000.0,
55
+ d_model: int = 768,
56
+ num_layers: int = 12,
57
+ num_heads: int = 12,
58
+ d_kv: int = 64,
59
+ d_ff: int = 2048,
60
+ dropout_rate: float = 0.1,
61
+ ):
62
+ super().__init__()
63
+
64
+ self.conditioning_emb = nn.Sequential(
65
+ nn.Linear(d_model, d_model * 4, bias=False),
66
+ nn.SiLU(),
67
+ nn.Linear(d_model * 4, d_model * 4, bias=False),
68
+ nn.SiLU(),
69
+ )
70
+
71
+ self.position_encoding = nn.Embedding(targets_length, d_model)
72
+ self.position_encoding.weight.requires_grad = False
73
+
74
+ self.continuous_inputs_projection = nn.Linear(input_dims, d_model, bias=False)
75
+
76
+ self.dropout = nn.Dropout(p=dropout_rate)
77
+
78
+ self.decoders = nn.ModuleList()
79
+ for lyr_num in range(num_layers):
80
+ # FiLM conditional T5 decoder
81
+ lyr = DecoderLayer(d_model=d_model, d_kv=d_kv, num_heads=num_heads, d_ff=d_ff, dropout_rate=dropout_rate)
82
+ self.decoders.append(lyr)
83
+
84
+ self.decoder_norm = T5LayerNorm(d_model)
85
+
86
+ self.post_dropout = nn.Dropout(p=dropout_rate)
87
+ self.spec_out = nn.Linear(d_model, input_dims, bias=False)
88
+
89
+ def encoder_decoder_mask(self, query_input: torch.Tensor, key_input: torch.Tensor) -> torch.Tensor:
90
+ mask = torch.mul(query_input.unsqueeze(-1), key_input.unsqueeze(-2))
91
+ return mask.unsqueeze(-3)
92
+
93
+ def forward(self, encodings_and_masks, decoder_input_tokens, decoder_noise_time):
94
+ batch, _, _ = decoder_input_tokens.shape
95
+ assert decoder_noise_time.shape == (batch,)
96
+
97
+ # decoder_noise_time is in [0, 1), so rescale to expected timing range.
98
+ time_steps = get_timestep_embedding(
99
+ decoder_noise_time * self.config.max_decoder_noise_time,
100
+ embedding_dim=self.config.d_model,
101
+ max_period=self.config.max_decoder_noise_time,
102
+ ).to(dtype=self.dtype)
103
+
104
+ conditioning_emb = self.conditioning_emb(time_steps).unsqueeze(1)
105
+
106
+ assert conditioning_emb.shape == (batch, 1, self.config.d_model * 4)
107
+
108
+ seq_length = decoder_input_tokens.shape[1]
109
+
110
+ # If we want to use relative positions for audio context, we can just offset
111
+ # this sequence by the length of encodings_and_masks.
112
+ decoder_positions = torch.broadcast_to(
113
+ torch.arange(seq_length, device=decoder_input_tokens.device),
114
+ (batch, seq_length),
115
+ )
116
+
117
+ position_encodings = self.position_encoding(decoder_positions)
118
+
119
+ inputs = self.continuous_inputs_projection(decoder_input_tokens)
120
+ inputs += position_encodings
121
+ y = self.dropout(inputs)
122
+
123
+ # decoder: No padding present.
124
+ decoder_mask = torch.ones(
125
+ decoder_input_tokens.shape[:2], device=decoder_input_tokens.device, dtype=inputs.dtype
126
+ )
127
+
128
+ # Translate encoding masks to encoder-decoder masks.
129
+ encodings_and_encdec_masks = [(x, self.encoder_decoder_mask(decoder_mask, y)) for x, y in encodings_and_masks]
130
+
131
+ # cross attend style: concat encodings
132
+ encoded = torch.cat([x[0] for x in encodings_and_encdec_masks], dim=1)
133
+ encoder_decoder_mask = torch.cat([x[1] for x in encodings_and_encdec_masks], dim=-1)
134
+
135
+ for lyr in self.decoders:
136
+ y = lyr(
137
+ y,
138
+ conditioning_emb=conditioning_emb,
139
+ encoder_hidden_states=encoded,
140
+ encoder_attention_mask=encoder_decoder_mask,
141
+ )[0]
142
+
143
+ y = self.decoder_norm(y)
144
+ y = self.post_dropout(y)
145
+
146
+ spec_out = self.spec_out(y)
147
+ return spec_out
148
+
149
+
150
+ class DecoderLayer(nn.Module):
151
+ r"""
152
+ T5 decoder layer.
153
+
154
+ Args:
155
+ d_model (`int`):
156
+ Size of the input hidden states.
157
+ d_kv (`int`):
158
+ Size of the key-value projection vectors.
159
+ num_heads (`int`):
160
+ Number of attention heads.
161
+ d_ff (`int`):
162
+ Size of the intermediate feed-forward layer.
163
+ dropout_rate (`float`):
164
+ Dropout probability.
165
+ layer_norm_epsilon (`float`, *optional*, defaults to `1e-6`):
166
+ A small value used for numerical stability to avoid dividing by zero.
167
+ """
168
+
169
+ def __init__(
170
+ self, d_model: int, d_kv: int, num_heads: int, d_ff: int, dropout_rate: float, layer_norm_epsilon: float = 1e-6
171
+ ):
172
+ super().__init__()
173
+ self.layer = nn.ModuleList()
174
+
175
+ # cond self attention: layer 0
176
+ self.layer.append(
177
+ T5LayerSelfAttentionCond(d_model=d_model, d_kv=d_kv, num_heads=num_heads, dropout_rate=dropout_rate)
178
+ )
179
+
180
+ # cross attention: layer 1
181
+ self.layer.append(
182
+ T5LayerCrossAttention(
183
+ d_model=d_model,
184
+ d_kv=d_kv,
185
+ num_heads=num_heads,
186
+ dropout_rate=dropout_rate,
187
+ layer_norm_epsilon=layer_norm_epsilon,
188
+ )
189
+ )
190
+
191
+ # Film Cond MLP + dropout: last layer
192
+ self.layer.append(
193
+ T5LayerFFCond(d_model=d_model, d_ff=d_ff, dropout_rate=dropout_rate, layer_norm_epsilon=layer_norm_epsilon)
194
+ )
195
+
196
+ def forward(
197
+ self,
198
+ hidden_states: torch.Tensor,
199
+ conditioning_emb: Optional[torch.Tensor] = None,
200
+ attention_mask: Optional[torch.Tensor] = None,
201
+ encoder_hidden_states: Optional[torch.Tensor] = None,
202
+ encoder_attention_mask: Optional[torch.Tensor] = None,
203
+ encoder_decoder_position_bias=None,
204
+ ) -> Tuple[torch.Tensor]:
205
+ hidden_states = self.layer[0](
206
+ hidden_states,
207
+ conditioning_emb=conditioning_emb,
208
+ attention_mask=attention_mask,
209
+ )
210
+
211
+ if encoder_hidden_states is not None:
212
+ encoder_extended_attention_mask = torch.where(encoder_attention_mask > 0, 0, -1e10).to(
213
+ encoder_hidden_states.dtype
214
+ )
215
+
216
+ hidden_states = self.layer[1](
217
+ hidden_states,
218
+ key_value_states=encoder_hidden_states,
219
+ attention_mask=encoder_extended_attention_mask,
220
+ )
221
+
222
+ # Apply Film Conditional Feed Forward layer
223
+ hidden_states = self.layer[-1](hidden_states, conditioning_emb)
224
+
225
+ return (hidden_states,)
226
+
227
+
228
+ class T5LayerSelfAttentionCond(nn.Module):
229
+ r"""
230
+ T5 style self-attention layer with conditioning.
231
+
232
+ Args:
233
+ d_model (`int`):
234
+ Size of the input hidden states.
235
+ d_kv (`int`):
236
+ Size of the key-value projection vectors.
237
+ num_heads (`int`):
238
+ Number of attention heads.
239
+ dropout_rate (`float`):
240
+ Dropout probability.
241
+ """
242
+
243
+ def __init__(self, d_model: int, d_kv: int, num_heads: int, dropout_rate: float):
244
+ super().__init__()
245
+ self.layer_norm = T5LayerNorm(d_model)
246
+ self.FiLMLayer = T5FiLMLayer(in_features=d_model * 4, out_features=d_model)
247
+ self.attention = Attention(query_dim=d_model, heads=num_heads, dim_head=d_kv, out_bias=False, scale_qk=False)
248
+ self.dropout = nn.Dropout(dropout_rate)
249
+
250
+ def forward(
251
+ self,
252
+ hidden_states: torch.Tensor,
253
+ conditioning_emb: Optional[torch.Tensor] = None,
254
+ attention_mask: Optional[torch.Tensor] = None,
255
+ ) -> torch.Tensor:
256
+ # pre_self_attention_layer_norm
257
+ normed_hidden_states = self.layer_norm(hidden_states)
258
+
259
+ if conditioning_emb is not None:
260
+ normed_hidden_states = self.FiLMLayer(normed_hidden_states, conditioning_emb)
261
+
262
+ # Self-attention block
263
+ attention_output = self.attention(normed_hidden_states)
264
+
265
+ hidden_states = hidden_states + self.dropout(attention_output)
266
+
267
+ return hidden_states
268
+
269
+
270
+ class T5LayerCrossAttention(nn.Module):
271
+ r"""
272
+ T5 style cross-attention layer.
273
+
274
+ Args:
275
+ d_model (`int`):
276
+ Size of the input hidden states.
277
+ d_kv (`int`):
278
+ Size of the key-value projection vectors.
279
+ num_heads (`int`):
280
+ Number of attention heads.
281
+ dropout_rate (`float`):
282
+ Dropout probability.
283
+ layer_norm_epsilon (`float`):
284
+ A small value used for numerical stability to avoid dividing by zero.
285
+ """
286
+
287
+ def __init__(self, d_model: int, d_kv: int, num_heads: int, dropout_rate: float, layer_norm_epsilon: float):
288
+ super().__init__()
289
+ self.attention = Attention(query_dim=d_model, heads=num_heads, dim_head=d_kv, out_bias=False, scale_qk=False)
290
+ self.layer_norm = T5LayerNorm(d_model, eps=layer_norm_epsilon)
291
+ self.dropout = nn.Dropout(dropout_rate)
292
+
293
+ def forward(
294
+ self,
295
+ hidden_states: torch.Tensor,
296
+ key_value_states: Optional[torch.Tensor] = None,
297
+ attention_mask: Optional[torch.Tensor] = None,
298
+ ) -> torch.Tensor:
299
+ normed_hidden_states = self.layer_norm(hidden_states)
300
+ attention_output = self.attention(
301
+ normed_hidden_states,
302
+ encoder_hidden_states=key_value_states,
303
+ attention_mask=attention_mask.squeeze(1),
304
+ )
305
+ layer_output = hidden_states + self.dropout(attention_output)
306
+ return layer_output
307
+
308
+
309
class T5LayerFFCond(nn.Module):
    r"""
    T5-style feed-forward block with optional FiLM conditioning.

    Args:
        d_model (`int`):
            Size of the input hidden states.
        d_ff (`int`):
            Size of the intermediate feed-forward layer.
        dropout_rate (`float`):
            Dropout probability.
        layer_norm_epsilon (`float`):
            A small value used for numerical stability to avoid dividing by zero.
    """

    def __init__(self, d_model: int, d_ff: int, dropout_rate: float, layer_norm_epsilon: float):
        super().__init__()
        self.DenseReluDense = T5DenseGatedActDense(d_model=d_model, d_ff=d_ff, dropout_rate=dropout_rate)
        self.film = T5FiLMLayer(in_features=d_model * 4, out_features=d_model)
        self.layer_norm = T5LayerNorm(d_model, eps=layer_norm_epsilon)
        self.dropout = nn.Dropout(dropout_rate)

    def forward(self, hidden_states: torch.Tensor, conditioning_emb: Optional[torch.Tensor] = None) -> torch.Tensor:
        # Pre-norm, then (optionally) FiLM-modulate before the gated MLP.
        normed = self.layer_norm(hidden_states)
        if conditioning_emb is not None:
            normed = self.film(normed, conditioning_emb)

        mlp_out = self.DenseReluDense(normed)
        # Residual connection around the dropout-regularized MLP output.
        return hidden_states + self.dropout(mlp_out)
339
+
340
+
341
class T5DenseGatedActDense(nn.Module):
    r"""
    T5-style gated feed-forward layer (GEGLU-like): two bias-free input
    projections, one gated by a GELU activation, followed by dropout and a
    bias-free output projection.

    Args:
        d_model (`int`):
            Size of the input hidden states.
        d_ff (`int`):
            Size of the intermediate feed-forward layer.
        dropout_rate (`float`):
            Dropout probability.
    """

    def __init__(self, d_model: int, d_ff: int, dropout_rate: float):
        super().__init__()
        self.wi_0 = nn.Linear(d_model, d_ff, bias=False)
        self.wi_1 = nn.Linear(d_model, d_ff, bias=False)
        self.wo = nn.Linear(d_ff, d_model, bias=False)
        self.dropout = nn.Dropout(dropout_rate)
        self.act = NewGELUActivation()

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        # Gate the plain linear branch with the GELU-activated branch.
        gate = self.act(self.wi_0(hidden_states))
        linear = self.wi_1(hidden_states)
        gated = self.dropout(gate * linear)
        return self.wo(gated)
370
+
371
+
372
class T5LayerNorm(nn.Module):
    r"""
    T5-style layer normalization (RMSNorm): scale-only, no mean subtraction,
    no bias. See https://huggingface.co/papers/1910.07467.

    Args:
        hidden_size (`int`):
            Size of the input hidden states.
        eps (`float`, `optional`, defaults to `1e-6`):
            A small value used for numerical stability to avoid dividing by zero.
    """

    def __init__(self, hidden_size: int, eps: float = 1e-6):
        """Construct a layernorm module in the T5 style. No bias and no subtraction of mean."""
        super().__init__()
        self.weight = nn.Parameter(torch.ones(hidden_size))
        self.variance_epsilon = eps

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        # The mean-of-squares is accumulated in fp32 so half-precision inputs
        # stay numerically stable; there is no mean subtraction and no bias.
        mean_square = hidden_states.to(torch.float32).pow(2).mean(-1, keepdim=True)
        hidden_states = hidden_states * torch.rsqrt(mean_square + self.variance_epsilon)

        # Cast back down when the learned scale lives in half precision.
        if self.weight.dtype in (torch.float16, torch.bfloat16):
            hidden_states = hidden_states.to(self.weight.dtype)

        return self.weight * hidden_states
405
+
406
+
407
class NewGELUActivation(nn.Module):
    """
    Tanh approximation of the GELU activation, as used in Google BERT and
    (identically) OpenAI GPT. See the Gaussian Error Linear Units paper:
    https://huggingface.co/papers/1606.08415
    """

    def forward(self, input: torch.Tensor) -> torch.Tensor:
        # 0.5 * x * (1 + tanh(sqrt(2/pi) * (x + 0.044715 * x^3)))
        inner = math.sqrt(2.0 / math.pi) * (input + 0.044715 * torch.pow(input, 3.0))
        return 0.5 * input * (1.0 + torch.tanh(inner))
415
+
416
+
417
class T5FiLMLayer(nn.Module):
    """
    T5-style FiLM layer: predicts a per-feature scale and shift from a
    conditioning embedding and applies them to the input.

    Args:
        in_features (`int`):
            Number of input features.
        out_features (`int`):
            Number of output features.
    """

    def __init__(self, in_features: int, out_features: int):
        super().__init__()
        # A single bias-free projection emits both scale and shift, split below.
        self.scale_bias = nn.Linear(in_features, out_features * 2, bias=False)

    def forward(self, x: torch.Tensor, conditioning_emb: torch.Tensor) -> torch.Tensor:
        scale, shift = self.scale_bias(conditioning_emb).chunk(2, dim=-1)
        # FiLM modulation: x * (1 + scale) + shift.
        return x * (1 + scale) + shift
pythonProject/.venv/Lib/site-packages/diffusers/models/transformers/transformer_hunyuan_video_framepack.py ADDED
@@ -0,0 +1,416 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright 2025 The Framepack Team, The Hunyuan Team and The HuggingFace Team. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ from typing import Any, Dict, List, Optional, Tuple
16
+
17
+ import torch
18
+ import torch.nn as nn
19
+ import torch.nn.functional as F
20
+
21
+ from ...configuration_utils import ConfigMixin, register_to_config
22
+ from ...loaders import FromOriginalModelMixin, PeftAdapterMixin
23
+ from ...utils import USE_PEFT_BACKEND, get_logger, scale_lora_layers, unscale_lora_layers
24
+ from ..cache_utils import CacheMixin
25
+ from ..embeddings import get_1d_rotary_pos_embed
26
+ from ..modeling_outputs import Transformer2DModelOutput
27
+ from ..modeling_utils import ModelMixin
28
+ from ..normalization import AdaLayerNormContinuous
29
+ from .transformer_hunyuan_video import (
30
+ HunyuanVideoConditionEmbedding,
31
+ HunyuanVideoPatchEmbed,
32
+ HunyuanVideoSingleTransformerBlock,
33
+ HunyuanVideoTokenRefiner,
34
+ HunyuanVideoTransformerBlock,
35
+ )
36
+
37
+
38
+ logger = get_logger(__name__) # pylint: disable=invalid-name
39
+
40
+
41
class HunyuanVideoFramepackRotaryPosEmbed(nn.Module):
    """3-axis (time / height / width) rotary position embedding for Framepack.

    Frame indices are supplied explicitly (they need not be contiguous, since
    Framepack mixes current and history latents), while the spatial axes are
    regular grids over the patchified height/width.
    """

    def __init__(self, patch_size: int, patch_size_t: int, rope_dim: List[int], theta: float = 256.0) -> None:
        super().__init__()

        self.patch_size = patch_size
        self.patch_size_t = patch_size_t
        self.rope_dim = rope_dim
        self.theta = theta

    def forward(self, frame_indices: torch.Tensor, height: int, width: int, device: torch.device):
        # Work in patch units along the spatial axes.
        num_patches_h = height // self.patch_size
        num_patches_w = width // self.patch_size
        axes = torch.meshgrid(
            frame_indices.to(device=device, dtype=torch.float32),
            torch.arange(0, num_patches_h, device=device, dtype=torch.float32),
            torch.arange(0, num_patches_w, device=device, dtype=torch.float32),
            indexing="ij",
        )
        axes = torch.stack(axes, dim=0)  # (3, T, H, W) coordinate volume

        # One 1D RoPE per axis, each covering `rope_dim[i]` channels.
        cos_parts = []
        sin_parts = []
        for axis_index in range(3):
            cos, sin = get_1d_rotary_pos_embed(
                self.rope_dim[axis_index], axes[axis_index].reshape(-1), self.theta, use_real=True
            )
            cos_parts.append(cos)
            sin_parts.append(sin)

        freqs_cos = torch.cat(cos_parts, dim=1)  # (T * H * W, D / 2)
        freqs_sin = torch.cat(sin_parts, dim=1)  # (T * H * W, D / 2)

        return freqs_cos, freqs_sin
70
+
71
+
72
class FramepackClipVisionProjection(nn.Module):
    """Two-layer SiLU MLP projecting CLIP vision embeddings into the transformer width."""

    def __init__(self, in_channels: int, out_channels: int):
        super().__init__()
        # Expand to 3x the target width, then contract after the nonlinearity.
        self.up = nn.Linear(in_channels, out_channels * 3)
        self.down = nn.Linear(out_channels * 3, out_channels)

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        return self.down(F.silu(self.up(hidden_states)))
83
+
84
+
85
class HunyuanVideoHistoryPatchEmbed(nn.Module):
    """Patch-embeds history latents at three temporal/spatial compression levels.

    `proj` handles clean (full-rate) latents, while `proj_2x` / `proj_4x`
    compress older history more aggressively so distant context costs fewer
    tokens. Inputs may independently be `None`, in which case `None` is
    returned in the corresponding slot.
    """

    def __init__(self, in_channels: int, inner_dim: int):
        super().__init__()
        self.proj = nn.Conv3d(in_channels, inner_dim, kernel_size=(1, 2, 2), stride=(1, 2, 2))
        self.proj_2x = nn.Conv3d(in_channels, inner_dim, kernel_size=(2, 4, 4), stride=(2, 4, 4))
        self.proj_4x = nn.Conv3d(in_channels, inner_dim, kernel_size=(4, 8, 8), stride=(4, 8, 8))

    def _embed(self, latents, conv, pad_kernel):
        # Optionally replicate-pad so the volume divides evenly, then flatten the
        # spatio-temporal axes into a token sequence: (B, C, T, H, W) -> (B, N, C).
        if latents is None:
            return None
        if pad_kernel is not None:
            latents = _pad_for_3d_conv(latents, pad_kernel)
        return conv(latents).flatten(2).transpose(1, 2)

    def forward(
        self,
        latents_clean: Optional[torch.Tensor] = None,
        latents_clean_2x: Optional[torch.Tensor] = None,
        latents_clean_4x: Optional[torch.Tensor] = None,
    ):
        return (
            self._embed(latents_clean, self.proj, None),
            self._embed(latents_clean_2x, self.proj_2x, (2, 4, 4)),
            self._embed(latents_clean_4x, self.proj_4x, (4, 8, 8)),
        )
110
+
111
+
112
class HunyuanVideoFramepackTransformer3DModel(
    ModelMixin, ConfigMixin, PeftAdapterMixin, FromOriginalModelMixin, CacheMixin
):
    """
    Framepack variant of the HunyuanVideo 3D transformer.

    Beyond the current (noisy) latents, the model can ingest up to three
    "history" latent streams at different compression levels (clean, 2x, 4x),
    each paired with explicit frame indices used to build rotary position
    embeddings. History tokens are prepended to the latent sequence and the
    matching rotary tables are concatenated in front of the current ones.
    """

    _supports_gradient_checkpointing = True
    # Modules matched by these patterns are exempt from layerwise dtype casting.
    _skip_layerwise_casting_patterns = ["x_embedder", "context_embedder", "norm"]
    _no_split_modules = [
        "HunyuanVideoTransformerBlock",
        "HunyuanVideoSingleTransformerBlock",
        "HunyuanVideoHistoryPatchEmbed",
        "HunyuanVideoTokenRefiner",
    ]

    @register_to_config
    def __init__(
        self,
        in_channels: int = 16,
        out_channels: int = 16,
        num_attention_heads: int = 24,
        attention_head_dim: int = 128,
        num_layers: int = 20,
        num_single_layers: int = 40,
        num_refiner_layers: int = 2,
        mlp_ratio: float = 4.0,
        patch_size: int = 2,
        patch_size_t: int = 1,
        qk_norm: str = "rms_norm",
        guidance_embeds: bool = True,
        text_embed_dim: int = 4096,
        pooled_projection_dim: int = 768,
        rope_theta: float = 256.0,
        rope_axes_dim: Tuple[int] = (16, 56, 56),
        image_condition_type: Optional[str] = None,
        has_image_proj: int = False,
        image_proj_dim: int = 1152,
        has_clean_x_embedder: int = False,
    ) -> None:
        super().__init__()

        inner_dim = num_attention_heads * attention_head_dim
        out_channels = out_channels or in_channels

        # 1. Latent and condition embedders
        self.x_embedder = HunyuanVideoPatchEmbed((patch_size_t, patch_size, patch_size), in_channels, inner_dim)

        # Framepack history projection embedder (only built when configured;
        # see NOTE in `forward` about the unconditional call site).
        self.clean_x_embedder = None
        if has_clean_x_embedder:
            self.clean_x_embedder = HunyuanVideoHistoryPatchEmbed(in_channels, inner_dim)

        self.context_embedder = HunyuanVideoTokenRefiner(
            text_embed_dim, num_attention_heads, attention_head_dim, num_layers=num_refiner_layers
        )

        # Framepack image-conditioning embedder
        self.image_projection = FramepackClipVisionProjection(image_proj_dim, inner_dim) if has_image_proj else None

        self.time_text_embed = HunyuanVideoConditionEmbedding(
            inner_dim, pooled_projection_dim, guidance_embeds, image_condition_type
        )

        # 2. RoPE
        self.rope = HunyuanVideoFramepackRotaryPosEmbed(patch_size, patch_size_t, rope_axes_dim, rope_theta)

        # 3. Dual stream transformer blocks
        self.transformer_blocks = nn.ModuleList(
            [
                HunyuanVideoTransformerBlock(
                    num_attention_heads, attention_head_dim, mlp_ratio=mlp_ratio, qk_norm=qk_norm
                )
                for _ in range(num_layers)
            ]
        )

        # 4. Single stream transformer blocks
        self.single_transformer_blocks = nn.ModuleList(
            [
                HunyuanVideoSingleTransformerBlock(
                    num_attention_heads, attention_head_dim, mlp_ratio=mlp_ratio, qk_norm=qk_norm
                )
                for _ in range(num_single_layers)
            ]
        )

        # 5. Output projection
        self.norm_out = AdaLayerNormContinuous(inner_dim, inner_dim, elementwise_affine=False, eps=1e-6)
        self.proj_out = nn.Linear(inner_dim, patch_size_t * patch_size * patch_size * out_channels)

        self.gradient_checkpointing = False

    def forward(
        self,
        hidden_states: torch.Tensor,
        timestep: torch.LongTensor,
        encoder_hidden_states: torch.Tensor,
        encoder_attention_mask: torch.Tensor,
        pooled_projections: torch.Tensor,
        image_embeds: torch.Tensor,
        indices_latents: torch.Tensor,
        guidance: Optional[torch.Tensor] = None,
        latents_clean: Optional[torch.Tensor] = None,
        indices_latents_clean: Optional[torch.Tensor] = None,
        latents_history_2x: Optional[torch.Tensor] = None,
        indices_latents_history_2x: Optional[torch.Tensor] = None,
        latents_history_4x: Optional[torch.Tensor] = None,
        indices_latents_history_4x: Optional[torch.Tensor] = None,
        attention_kwargs: Optional[Dict[str, Any]] = None,
        return_dict: bool = True,
    ):
        # `scale` is popped so downstream attention kwargs never see it.
        if attention_kwargs is not None:
            attention_kwargs = attention_kwargs.copy()
            lora_scale = attention_kwargs.pop("scale", 1.0)
        else:
            lora_scale = 1.0

        if USE_PEFT_BACKEND:
            # weight the lora layers by setting `lora_scale` for each PEFT layer
            scale_lora_layers(self, lora_scale)
        else:
            if attention_kwargs is not None and attention_kwargs.get("scale", None) is not None:
                logger.warning(
                    "Passing `scale` via `attention_kwargs` when not using the PEFT backend is ineffective."
                )

        batch_size, num_channels, num_frames, height, width = hidden_states.shape
        p, p_t = self.config.patch_size, self.config.patch_size_t
        post_patch_num_frames = num_frames // p_t
        post_patch_height = height // p
        post_patch_width = width // p
        # Token count of the *current* latents only; used below to strip the
        # prepended history tokens before the output projection.
        original_context_length = post_patch_num_frames * post_patch_height * post_patch_width

        if indices_latents is None:
            indices_latents = torch.arange(0, num_frames).unsqueeze(0).expand(batch_size, -1)

        hidden_states = self.x_embedder(hidden_states)
        image_rotary_emb = self.rope(
            frame_indices=indices_latents, height=height, width=width, device=hidden_states.device
        )

        # NOTE(review): called unconditionally — assumes the model was built with
        # `has_clean_x_embedder=True` (otherwise `self.clean_x_embedder` is None).
        latents_clean, latents_history_2x, latents_history_4x = self.clean_x_embedder(
            latents_clean, latents_history_2x, latents_history_4x
        )

        # NOTE(review): the three `image_rotary_emb_*` names below are only bound
        # when both the latents and their indices are provided; passing one
        # without the other would raise NameError at the `_pack_history_states`
        # call. Callers are expected to pass them in matched pairs — confirm.
        if latents_clean is not None and indices_latents_clean is not None:
            image_rotary_emb_clean = self.rope(
                frame_indices=indices_latents_clean, height=height, width=width, device=hidden_states.device
            )
        if latents_history_2x is not None and indices_latents_history_2x is not None:
            image_rotary_emb_history_2x = self.rope(
                frame_indices=indices_latents_history_2x, height=height, width=width, device=hidden_states.device
            )
        if latents_history_4x is not None and indices_latents_history_4x is not None:
            image_rotary_emb_history_4x = self.rope(
                frame_indices=indices_latents_history_4x, height=height, width=width, device=hidden_states.device
            )

        hidden_states, image_rotary_emb = self._pack_history_states(
            hidden_states,
            latents_clean,
            latents_history_2x,
            latents_history_4x,
            image_rotary_emb,
            image_rotary_emb_clean,
            image_rotary_emb_history_2x,
            image_rotary_emb_history_4x,
            post_patch_height,
            post_patch_width,
        )

        temb, _ = self.time_text_embed(timestep, pooled_projections, guidance)
        encoder_hidden_states = self.context_embedder(encoder_hidden_states, timestep, encoder_attention_mask)

        # NOTE(review): assumes `has_image_proj=True` (otherwise `image_projection` is None).
        encoder_hidden_states_image = self.image_projection(image_embeds)
        attention_mask_image = encoder_attention_mask.new_ones((batch_size, encoder_hidden_states_image.shape[1]))

        # must cat before (not after) encoder_hidden_states, due to attn masking
        encoder_hidden_states = torch.cat([encoder_hidden_states_image, encoder_hidden_states], dim=1)
        encoder_attention_mask = torch.cat([attention_mask_image, encoder_attention_mask], dim=1)

        latent_sequence_length = hidden_states.shape[1]
        condition_sequence_length = encoder_hidden_states.shape[1]
        sequence_length = latent_sequence_length + condition_sequence_length
        attention_mask = torch.zeros(
            batch_size, sequence_length, device=hidden_states.device, dtype=torch.bool
        )  # [B, N]
        effective_condition_sequence_length = encoder_attention_mask.sum(dim=1, dtype=torch.int)  # [B,]
        effective_sequence_length = latent_sequence_length + effective_condition_sequence_length

        if batch_size == 1:
            # With a single sample, padding can simply be sliced off and no
            # attention mask is needed at all (faster SDPA path).
            encoder_hidden_states = encoder_hidden_states[:, : effective_condition_sequence_length[0]]
            attention_mask = None
        else:
            for i in range(batch_size):
                attention_mask[i, : effective_sequence_length[i]] = True
            # [B, 1, 1, N], for broadcasting across attention heads
            attention_mask = attention_mask.unsqueeze(1).unsqueeze(1)

        if torch.is_grad_enabled() and self.gradient_checkpointing:
            for block in self.transformer_blocks:
                hidden_states, encoder_hidden_states = self._gradient_checkpointing_func(
                    block, hidden_states, encoder_hidden_states, temb, attention_mask, image_rotary_emb
                )

            for block in self.single_transformer_blocks:
                hidden_states, encoder_hidden_states = self._gradient_checkpointing_func(
                    block, hidden_states, encoder_hidden_states, temb, attention_mask, image_rotary_emb
                )

        else:
            for block in self.transformer_blocks:
                hidden_states, encoder_hidden_states = block(
                    hidden_states, encoder_hidden_states, temb, attention_mask, image_rotary_emb
                )

            for block in self.single_transformer_blocks:
                hidden_states, encoder_hidden_states = block(
                    hidden_states, encoder_hidden_states, temb, attention_mask, image_rotary_emb
                )

        # Drop the history tokens: only the trailing `original_context_length`
        # tokens correspond to the current latents being denoised.
        hidden_states = hidden_states[:, -original_context_length:]
        hidden_states = self.norm_out(hidden_states, temb)
        hidden_states = self.proj_out(hidden_states)

        # Un-patchify back to (B, C, F, H, W).
        hidden_states = hidden_states.reshape(
            batch_size, post_patch_num_frames, post_patch_height, post_patch_width, -1, p_t, p, p
        )
        hidden_states = hidden_states.permute(0, 4, 1, 5, 2, 6, 3, 7)
        hidden_states = hidden_states.flatten(6, 7).flatten(4, 5).flatten(2, 3)

        if USE_PEFT_BACKEND:
            # remove `lora_scale` from each PEFT layer
            unscale_lora_layers(self, lora_scale)

        if not return_dict:
            return (hidden_states,)
        return Transformer2DModelOutput(sample=hidden_states)

    def _pack_history_states(
        self,
        hidden_states: torch.Tensor,
        latents_clean: Optional[torch.Tensor] = None,
        latents_history_2x: Optional[torch.Tensor] = None,
        latents_history_4x: Optional[torch.Tensor] = None,
        image_rotary_emb: Tuple[torch.Tensor, torch.Tensor] = None,
        image_rotary_emb_clean: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
        image_rotary_emb_history_2x: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
        image_rotary_emb_history_4x: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
        height: int = None,
        width: int = None,
    ):
        """Prepend history tokens and their rotary tables in front of the current latents.

        Each prepend inserts at the front, so after all three the sequence order
        is: 4x history, 2x history, clean history, current latents.
        """
        image_rotary_emb = list(image_rotary_emb)  # convert tuple to list for in-place modification

        if latents_clean is not None and image_rotary_emb_clean is not None:
            hidden_states = torch.cat([latents_clean, hidden_states], dim=1)
            image_rotary_emb[0] = torch.cat([image_rotary_emb_clean[0], image_rotary_emb[0]], dim=0)
            image_rotary_emb[1] = torch.cat([image_rotary_emb_clean[1], image_rotary_emb[1]], dim=0)

        if latents_history_2x is not None and image_rotary_emb_history_2x is not None:
            hidden_states = torch.cat([latents_history_2x, hidden_states], dim=1)
            # Pool the rotary table so its token count matches the 2x-compressed history.
            image_rotary_emb_history_2x = self._pad_rotary_emb(image_rotary_emb_history_2x, height, width, (2, 2, 2))
            image_rotary_emb[0] = torch.cat([image_rotary_emb_history_2x[0], image_rotary_emb[0]], dim=0)
            image_rotary_emb[1] = torch.cat([image_rotary_emb_history_2x[1], image_rotary_emb[1]], dim=0)

        if latents_history_4x is not None and image_rotary_emb_history_4x is not None:
            hidden_states = torch.cat([latents_history_4x, hidden_states], dim=1)
            # Same as above at the 4x compression level.
            image_rotary_emb_history_4x = self._pad_rotary_emb(image_rotary_emb_history_4x, height, width, (4, 4, 4))
            image_rotary_emb[0] = torch.cat([image_rotary_emb_history_4x[0], image_rotary_emb[0]], dim=0)
            image_rotary_emb[1] = torch.cat([image_rotary_emb_history_4x[1], image_rotary_emb[1]], dim=0)

        return hidden_states, tuple(image_rotary_emb)

    def _pad_rotary_emb(
        self,
        image_rotary_emb: Tuple[torch.Tensor],
        height: int,
        width: int,
        kernel_size: Tuple[int, int, int],
    ):
        """Average-pool a full-rate rotary table down to a compressed history's token count."""
        # freqs_cos, freqs_sin have shape [W * H * T, D / 2], where D is attention head dim
        freqs_cos, freqs_sin = image_rotary_emb
        # Fold the flat token axis back into a (T, H, W) volume, replicate-pad it
        # to a multiple of `kernel_size`, average-pool, then flatten back.
        freqs_cos = freqs_cos.unsqueeze(0).permute(0, 2, 1).unflatten(2, (-1, height, width))
        freqs_sin = freqs_sin.unsqueeze(0).permute(0, 2, 1).unflatten(2, (-1, height, width))
        freqs_cos = _pad_for_3d_conv(freqs_cos, kernel_size)
        freqs_sin = _pad_for_3d_conv(freqs_sin, kernel_size)
        freqs_cos = _center_down_sample_3d(freqs_cos, kernel_size)
        freqs_sin = _center_down_sample_3d(freqs_sin, kernel_size)
        freqs_cos = freqs_cos.flatten(2).permute(0, 2, 1).squeeze(0)
        freqs_sin = freqs_sin.flatten(2).permute(0, 2, 1).squeeze(0)
        return freqs_cos, freqs_sin
400
+
401
+
402
+ def _pad_for_3d_conv(x, kernel_size):
403
+ if isinstance(x, (tuple, list)):
404
+ return tuple(_pad_for_3d_conv(i, kernel_size) for i in x)
405
+ b, c, t, h, w = x.shape
406
+ pt, ph, pw = kernel_size
407
+ pad_t = (pt - (t % pt)) % pt
408
+ pad_h = (ph - (h % ph)) % ph
409
+ pad_w = (pw - (w % pw)) % pw
410
+ return torch.nn.functional.pad(x, (0, pad_w, 0, pad_h, 0, pad_t), mode="replicate")
411
+
412
+
413
+ def _center_down_sample_3d(x, kernel_size):
414
+ if isinstance(x, (tuple, list)):
415
+ return tuple(_center_down_sample_3d(i, kernel_size) for i in x)
416
+ return torch.nn.functional.avg_pool3d(x, kernel_size, stride=kernel_size)
pythonProject/.venv/Lib/site-packages/diffusers/models/transformers/transformer_ltx.py ADDED
@@ -0,0 +1,568 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright 2025 The Lightricks team and The HuggingFace Team.
2
+ # All rights reserved.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
16
+ import inspect
17
+ import math
18
+ from typing import Any, Dict, Optional, Tuple, Union
19
+
20
+ import torch
21
+ import torch.nn as nn
22
+
23
+ from ...configuration_utils import ConfigMixin, register_to_config
24
+ from ...loaders import FromOriginalModelMixin, PeftAdapterMixin
25
+ from ...utils import USE_PEFT_BACKEND, deprecate, is_torch_version, logging, scale_lora_layers, unscale_lora_layers
26
+ from ...utils.torch_utils import maybe_allow_in_graph
27
+ from ..attention import AttentionMixin, AttentionModuleMixin, FeedForward
28
+ from ..attention_dispatch import dispatch_attention_fn
29
+ from ..cache_utils import CacheMixin
30
+ from ..embeddings import PixArtAlphaTextProjection
31
+ from ..modeling_outputs import Transformer2DModelOutput
32
+ from ..modeling_utils import ModelMixin
33
+ from ..normalization import AdaLayerNormSingle, RMSNorm
34
+
35
+
36
+ logger = logging.get_logger(__name__) # pylint: disable=invalid-name
37
+
38
+
39
class LTXVideoAttentionProcessor2_0:
    r"""
    Deprecated alias for `LTXVideoAttnProcessor`.

    Instantiating this class emits a deprecation warning and returns an
    `LTXVideoAttnProcessor` instead. Note that `__new__` returns an object of a
    *different* class, so `isinstance` checks against this alias will not match.
    """

    def __new__(cls, *args, **kwargs):
        deprecation_message = "`LTXVideoAttentionProcessor2_0` is deprecated and this will be removed in a future version. Please use `LTXVideoAttnProcessor`"
        deprecate("LTXVideoAttentionProcessor2_0", "1.0.0", deprecation_message)

        return LTXVideoAttnProcessor(*args, **kwargs)
45
+
46
+
47
class LTXVideoAttnProcessor:
    r"""
    Processor for implementing attention (SDPA is used by default if you're using PyTorch 2.0). This is used in the LTX
    model. It applies a normalization layer and rotary embedding on the query and key vector.
    """

    # Optional override for the backend used by `dispatch_attention_fn`.
    _attention_backend = None

    def __init__(self):
        if is_torch_version("<", "2.0"):
            raise ValueError(
                "LTX attention processors require a minimum PyTorch version of 2.0. Please upgrade your PyTorch installation."
            )

    def __call__(
        self,
        attn: "LTXAttention",
        hidden_states: torch.Tensor,
        encoder_hidden_states: Optional[torch.Tensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        image_rotary_emb: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        # Mask length follows the key/value sequence: the encoder states in the
        # cross-attention case, `hidden_states` for self-attention.
        batch_size, sequence_length, _ = (
            hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
        )

        if attention_mask is not None:
            attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
            attention_mask = attention_mask.view(batch_size, attn.heads, -1, attention_mask.shape[-1])

        # Self-attention fallback when no encoder states are given.
        if encoder_hidden_states is None:
            encoder_hidden_states = hidden_states

        query = attn.to_q(hidden_states)
        key = attn.to_k(encoder_hidden_states)
        value = attn.to_v(encoder_hidden_states)

        # QK normalization over the full projected dim, i.e. before the heads
        # are split out ("rms_norm_across_heads").
        query = attn.norm_q(query)
        key = attn.norm_k(key)

        # Rotary embeddings are likewise applied before the head split.
        if image_rotary_emb is not None:
            query = apply_rotary_emb(query, image_rotary_emb)
            key = apply_rotary_emb(key, image_rotary_emb)

        # (B, S, H*D) -> (B, S, H, D) for the attention dispatcher.
        query = query.unflatten(2, (attn.heads, -1))
        key = key.unflatten(2, (attn.heads, -1))
        value = value.unflatten(2, (attn.heads, -1))

        hidden_states = dispatch_attention_fn(
            query,
            key,
            value,
            attn_mask=attention_mask,
            dropout_p=0.0,
            is_causal=False,
            backend=self._attention_backend,
        )
        # Merge heads back: (B, S, H, D) -> (B, S, H*D).
        hidden_states = hidden_states.flatten(2, 3)
        hidden_states = hidden_states.to(query.dtype)

        # Output projection followed by dropout.
        hidden_states = attn.to_out[0](hidden_states)
        hidden_states = attn.to_out[1](hidden_states)
        return hidden_states
110
+
111
+
112
class LTXAttention(torch.nn.Module, AttentionModuleMixin):
    """
    Multi-head attention module for LTX-Video with RMS-normalized queries/keys.

    Supports self- or cross-attention (via `cross_attention_dim`) and delegates
    the actual computation to a pluggable processor (default:
    `LTXVideoAttnProcessor`).
    """

    _default_processor_cls = LTXVideoAttnProcessor
    _available_processors = [LTXVideoAttnProcessor]

    def __init__(
        self,
        query_dim: int,
        heads: int = 8,
        kv_heads: int = 8,
        dim_head: int = 64,
        dropout: float = 0.0,
        bias: bool = True,
        cross_attention_dim: Optional[int] = None,
        out_bias: bool = True,
        qk_norm: str = "rms_norm_across_heads",
        processor=None,
    ):
        super().__init__()
        if qk_norm != "rms_norm_across_heads":
            raise NotImplementedError("Only 'rms_norm_across_heads' is supported as a valid value for `qk_norm`.")

        self.head_dim = dim_head
        self.inner_dim = dim_head * heads
        # NOTE(review): `inner_kv_dim` guards against `kv_heads is None`, but
        # `norm_k` below would still fail on `dim_head * None` — presumably
        # callers always pass an int `kv_heads`; confirm.
        self.inner_kv_dim = self.inner_dim if kv_heads is None else dim_head * kv_heads
        self.query_dim = query_dim
        # Falls back to self-attention dims when no cross-attention dim is given.
        self.cross_attention_dim = cross_attention_dim if cross_attention_dim is not None else query_dim
        self.use_bias = bias
        self.dropout = dropout
        self.out_dim = query_dim
        self.heads = heads

        norm_eps = 1e-5
        norm_elementwise_affine = True
        # RMS norms span the full projected width (all heads at once).
        self.norm_q = torch.nn.RMSNorm(dim_head * heads, eps=norm_eps, elementwise_affine=norm_elementwise_affine)
        self.norm_k = torch.nn.RMSNorm(dim_head * kv_heads, eps=norm_eps, elementwise_affine=norm_elementwise_affine)
        self.to_q = torch.nn.Linear(query_dim, self.inner_dim, bias=bias)
        self.to_k = torch.nn.Linear(self.cross_attention_dim, self.inner_kv_dim, bias=bias)
        self.to_v = torch.nn.Linear(self.cross_attention_dim, self.inner_kv_dim, bias=bias)
        self.to_out = torch.nn.ModuleList([])
        self.to_out.append(torch.nn.Linear(self.inner_dim, self.out_dim, bias=out_bias))
        self.to_out.append(torch.nn.Dropout(dropout))

        if processor is None:
            processor = self._default_processor_cls()
        self.set_processor(processor)

    def forward(
        self,
        hidden_states: torch.Tensor,
        encoder_hidden_states: Optional[torch.Tensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        image_rotary_emb: Optional[torch.Tensor] = None,
        **kwargs,
    ) -> torch.Tensor:
        # Silently drop kwargs the active processor does not accept, warning once
        # per call so misconfigured attention_kwargs are visible in logs.
        attn_parameters = set(inspect.signature(self.processor.__call__).parameters.keys())
        unused_kwargs = [k for k, _ in kwargs.items() if k not in attn_parameters]
        if len(unused_kwargs) > 0:
            logger.warning(
                f"attention_kwargs {unused_kwargs} are not expected by {self.processor.__class__.__name__} and will be ignored."
            )
        kwargs = {k: w for k, w in kwargs.items() if k in attn_parameters}
        return self.processor(self, hidden_states, encoder_hidden_states, attention_mask, image_rotary_emb, **kwargs)
+ return self.processor(self, hidden_states, encoder_hidden_states, attention_mask, image_rotary_emb, **kwargs)
174
+
175
+
176
+ class LTXVideoRotaryPosEmbed(nn.Module):
177
    def __init__(
        self,
        dim: int,
        base_num_frames: int = 20,
        base_height: int = 2048,
        base_width: int = 2048,
        patch_size: int = 1,
        patch_size_t: int = 1,
        theta: float = 10000.0,
    ) -> None:
        """Rotary position embedding over (frames, height, width) for LTX-Video.

        Coordinates are normalized by the `base_*` extents (see
        `_prepare_video_coords`) so the embedding transfers across resolutions
        and frame counts; `theta` is the usual RoPE frequency base.
        """
        super().__init__()

        self.dim = dim
        self.base_num_frames = base_num_frames
        self.base_height = base_height
        self.base_width = base_width
        self.patch_size = patch_size
        self.patch_size_t = patch_size_t
        self.theta = theta
+
197
    def _prepare_video_coords(
        self,
        batch_size: int,
        num_frames: int,
        height: int,
        width: int,
        rope_interpolation_scale: Tuple[torch.Tensor, float, float],
        device: torch.device,
    ) -> torch.Tensor:
        """Build per-token (frame, row, col) coordinates, shape (B, T*H*W, 3)."""
        # Always compute rope in fp32
        grid_h = torch.arange(height, dtype=torch.float32, device=device)
        grid_w = torch.arange(width, dtype=torch.float32, device=device)
        grid_f = torch.arange(num_frames, dtype=torch.float32, device=device)
        grid = torch.meshgrid(grid_f, grid_h, grid_w, indexing="ij")
        grid = torch.stack(grid, dim=0)  # (3, T, H, W)
        grid = grid.unsqueeze(0).repeat(batch_size, 1, 1, 1, 1)

        if rope_interpolation_scale is not None:
            # Normalize each axis by its base extent; the patch sizes convert
            # latent-grid indices back into pixel/frame units first.
            grid[:, 0:1] = grid[:, 0:1] * rope_interpolation_scale[0] * self.patch_size_t / self.base_num_frames
            grid[:, 1:2] = grid[:, 1:2] * rope_interpolation_scale[1] * self.patch_size / self.base_height
            grid[:, 2:3] = grid[:, 2:3] * rope_interpolation_scale[2] * self.patch_size / self.base_width

        # (B, 3, T, H, W) -> (B, T*H*W, 3): one coordinate triple per token.
        grid = grid.flatten(2, 4).transpose(1, 2)

        return grid
222
+
223
+ def forward(
224
+ self,
225
+ hidden_states: torch.Tensor,
226
+ num_frames: Optional[int] = None,
227
+ height: Optional[int] = None,
228
+ width: Optional[int] = None,
229
+ rope_interpolation_scale: Optional[Tuple[torch.Tensor, float, float]] = None,
230
+ video_coords: Optional[torch.Tensor] = None,
231
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
232
+ batch_size = hidden_states.size(0)
233
+
234
+ if video_coords is None:
235
+ grid = self._prepare_video_coords(
236
+ batch_size,
237
+ num_frames,
238
+ height,
239
+ width,
240
+ rope_interpolation_scale=rope_interpolation_scale,
241
+ device=hidden_states.device,
242
+ )
243
+ else:
244
+ grid = torch.stack(
245
+ [
246
+ video_coords[:, 0] / self.base_num_frames,
247
+ video_coords[:, 1] / self.base_height,
248
+ video_coords[:, 2] / self.base_width,
249
+ ],
250
+ dim=-1,
251
+ )
252
+
253
+ start = 1.0
254
+ end = self.theta
255
+ freqs = self.theta ** torch.linspace(
256
+ math.log(start, self.theta),
257
+ math.log(end, self.theta),
258
+ self.dim // 6,
259
+ device=hidden_states.device,
260
+ dtype=torch.float32,
261
+ )
262
+ freqs = freqs * math.pi / 2.0
263
+ freqs = freqs * (grid.unsqueeze(-1) * 2 - 1)
264
+ freqs = freqs.transpose(-1, -2).flatten(2)
265
+
266
+ cos_freqs = freqs.cos().repeat_interleave(2, dim=-1)
267
+ sin_freqs = freqs.sin().repeat_interleave(2, dim=-1)
268
+
269
+ if self.dim % 6 != 0:
270
+ cos_padding = torch.ones_like(cos_freqs[:, :, : self.dim % 6])
271
+ sin_padding = torch.zeros_like(cos_freqs[:, :, : self.dim % 6])
272
+ cos_freqs = torch.cat([cos_padding, cos_freqs], dim=-1)
273
+ sin_freqs = torch.cat([sin_padding, sin_freqs], dim=-1)
274
+
275
+ return cos_freqs, sin_freqs
276
+
277
+
278
@maybe_allow_in_graph
class LTXVideoTransformerBlock(nn.Module):
    r"""
    Transformer block used in [LTX](https://huggingface.co/Lightricks/LTX-Video).

    Combines AdaLN-modulated gated self-attention, ungated cross-attention over the
    conditioning states, and a gated feed-forward network.

    Args:
        dim (`int`):
            The number of channels in the input and output.
        num_attention_heads (`int`):
            The number of heads to use for multi-head attention.
        attention_head_dim (`int`):
            The number of channels in each head.
        qk_norm (`str`, defaults to `"rms_norm_across_heads"`):
            The normalization layer to use.
        activation_fn (`str`, defaults to `"gelu-approximate"`):
            Activation function to use in feed-forward.
        eps (`float`, defaults to `1e-6`):
            Epsilon value for normalization layers.
    """

    def __init__(
        self,
        dim: int,
        num_attention_heads: int,
        attention_head_dim: int,
        cross_attention_dim: int,
        qk_norm: str = "rms_norm_across_heads",
        activation_fn: str = "gelu-approximate",
        attention_bias: bool = True,
        attention_out_bias: bool = True,
        eps: float = 1e-6,
        elementwise_affine: bool = False,
    ):
        super().__init__()

        # Self-attention branch.
        self.norm1 = RMSNorm(dim, eps=eps, elementwise_affine=elementwise_affine)
        self.attn1 = LTXAttention(
            query_dim=dim,
            heads=num_attention_heads,
            kv_heads=num_attention_heads,
            dim_head=attention_head_dim,
            bias=attention_bias,
            cross_attention_dim=None,
            out_bias=attention_out_bias,
            qk_norm=qk_norm,
        )

        # Cross-attention branch over the text/conditioning states.
        self.norm2 = RMSNorm(dim, eps=eps, elementwise_affine=elementwise_affine)
        self.attn2 = LTXAttention(
            query_dim=dim,
            cross_attention_dim=cross_attention_dim,
            heads=num_attention_heads,
            kv_heads=num_attention_heads,
            dim_head=attention_head_dim,
            bias=attention_bias,
            out_bias=attention_out_bias,
            qk_norm=qk_norm,
        )

        self.ff = FeedForward(dim, activation_fn=activation_fn)

        # Learned base values for the six AdaLN parameters, refined by `temb`.
        self.scale_shift_table = nn.Parameter(torch.randn(6, dim) / dim**0.5)

    def forward(
        self,
        hidden_states: torch.Tensor,
        encoder_hidden_states: torch.Tensor,
        temb: torch.Tensor,
        image_rotary_emb: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
        encoder_attention_mask: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        batch = hidden_states.size(0)

        # Unpack the six AdaLN parameters: shift/scale/gate for attention and for the MLP.
        n_params = self.scale_shift_table.shape[0]
        ada = self.scale_shift_table[None, None] + temb.reshape(batch, temb.size(1), n_params, -1)
        shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = ada.unbind(dim=2)

        # Gated, modulated self-attention with rotary embeddings.
        modulated = self.norm1(hidden_states) * (1 + scale_msa) + shift_msa
        self_attn = self.attn1(
            hidden_states=modulated,
            encoder_hidden_states=None,
            image_rotary_emb=image_rotary_emb,
        )
        hidden_states = hidden_states + self_attn * gate_msa

        # Ungated cross-attention (note: operates on the un-normalized residual stream).
        cross_attn = self.attn2(
            hidden_states,
            encoder_hidden_states=encoder_hidden_states,
            image_rotary_emb=None,
            attention_mask=encoder_attention_mask,
        )
        hidden_states = hidden_states + cross_attn

        # Gated, modulated feed-forward.
        modulated = self.norm2(hidden_states) * (1 + scale_mlp) + shift_mlp
        return hidden_states + self.ff(modulated) * gate_mlp
377
+
378
+
379
@maybe_allow_in_graph
class LTXVideoTransformer3DModel(
    ModelMixin, ConfigMixin, AttentionMixin, FromOriginalModelMixin, PeftAdapterMixin, CacheMixin
):
    r"""
    A Transformer model for video-like data used in [LTX](https://huggingface.co/Lightricks/LTX-Video).

    Args:
        in_channels (`int`, defaults to `128`):
            The number of channels in the input.
        out_channels (`int`, defaults to `128`):
            The number of channels in the output.
        patch_size (`int`, defaults to `1`):
            The size of the spatial patches to use in the patch embedding layer.
        patch_size_t (`int`, defaults to `1`):
            The size of the temporal patches to use in the patch embedding layer.
        num_attention_heads (`int`, defaults to `32`):
            The number of heads to use for multi-head attention.
        attention_head_dim (`int`, defaults to `64`):
            The number of channels in each head.
        cross_attention_dim (`int`, defaults to `2048`):
            The number of channels for cross attention heads.
        num_layers (`int`, defaults to `28`):
            The number of layers of Transformer blocks to use.
        activation_fn (`str`, defaults to `"gelu-approximate"`):
            Activation function to use in feed-forward.
        qk_norm (`str`, defaults to `"rms_norm_across_heads"`):
            The normalization layer to use.
    """

    _supports_gradient_checkpointing = True
    _skip_layerwise_casting_patterns = ["norm"]
    _repeated_blocks = ["LTXVideoTransformerBlock"]

    @register_to_config
    def __init__(
        self,
        in_channels: int = 128,
        out_channels: int = 128,
        patch_size: int = 1,
        patch_size_t: int = 1,
        num_attention_heads: int = 32,
        attention_head_dim: int = 64,
        cross_attention_dim: int = 2048,
        num_layers: int = 28,
        activation_fn: str = "gelu-approximate",
        qk_norm: str = "rms_norm_across_heads",
        norm_elementwise_affine: bool = False,
        norm_eps: float = 1e-6,
        caption_channels: int = 4096,
        attention_bias: bool = True,
        attention_out_bias: bool = True,
    ) -> None:
        super().__init__()

        out_channels = out_channels or in_channels
        inner_dim = num_attention_heads * attention_head_dim

        self.proj_in = nn.Linear(in_channels, inner_dim)

        # Learned (shift, scale) pair used to modulate the final layer-norm output.
        self.scale_shift_table = nn.Parameter(torch.randn(2, inner_dim) / inner_dim**0.5)
        self.time_embed = AdaLayerNormSingle(inner_dim, use_additional_conditions=False)

        self.caption_projection = PixArtAlphaTextProjection(in_features=caption_channels, hidden_size=inner_dim)

        self.rope = LTXVideoRotaryPosEmbed(
            dim=inner_dim,
            base_num_frames=20,
            base_height=2048,
            base_width=2048,
            patch_size=patch_size,
            patch_size_t=patch_size_t,
            theta=10000.0,
        )

        self.transformer_blocks = nn.ModuleList(
            [
                LTXVideoTransformerBlock(
                    dim=inner_dim,
                    num_attention_heads=num_attention_heads,
                    attention_head_dim=attention_head_dim,
                    cross_attention_dim=cross_attention_dim,
                    qk_norm=qk_norm,
                    activation_fn=activation_fn,
                    attention_bias=attention_bias,
                    attention_out_bias=attention_out_bias,
                    eps=norm_eps,
                    elementwise_affine=norm_elementwise_affine,
                )
                for _ in range(num_layers)
            ]
        )

        self.norm_out = nn.LayerNorm(inner_dim, eps=1e-6, elementwise_affine=False)
        self.proj_out = nn.Linear(inner_dim, out_channels)

        self.gradient_checkpointing = False

    def forward(
        self,
        hidden_states: torch.Tensor,
        encoder_hidden_states: torch.Tensor,
        timestep: torch.LongTensor,
        encoder_attention_mask: torch.Tensor,
        num_frames: Optional[int] = None,
        height: Optional[int] = None,
        width: Optional[int] = None,
        rope_interpolation_scale: Optional[Union[Tuple[float, float, float], torch.Tensor]] = None,
        video_coords: Optional[torch.Tensor] = None,
        attention_kwargs: Optional[Dict[str, Any]] = None,
        return_dict: bool = True,
    ) -> torch.Tensor:
        """Denoise a batch of patchified video latents.

        Returns a `Transformer2DModelOutput` (or a 1-tuple when `return_dict=False`).
        """
        # `scale` inside `attention_kwargs` only controls LoRA scaling under the PEFT backend.
        if attention_kwargs is not None:
            attention_kwargs = attention_kwargs.copy()
            lora_scale = attention_kwargs.pop("scale", 1.0)
        else:
            lora_scale = 1.0

        if USE_PEFT_BACKEND:
            # weight the lora layers by setting `lora_scale` for each PEFT layer
            scale_lora_layers(self, lora_scale)
        else:
            if attention_kwargs is not None and attention_kwargs.get("scale", None) is not None:
                logger.warning(
                    "Passing `scale` via `attention_kwargs` when not using the PEFT backend is ineffective."
                )

        image_rotary_emb = self.rope(hidden_states, num_frames, height, width, rope_interpolation_scale, video_coords)

        # convert encoder_attention_mask to a bias the same way we do for attention_mask
        if encoder_attention_mask is not None and encoder_attention_mask.ndim == 2:
            encoder_attention_mask = (1 - encoder_attention_mask.to(hidden_states.dtype)) * -10000.0
            encoder_attention_mask = encoder_attention_mask.unsqueeze(1)

        batch_size = hidden_states.size(0)
        hidden_states = self.proj_in(hidden_states)

        temb, embedded_timestep = self.time_embed(
            timestep.flatten(),
            batch_size=batch_size,
            hidden_dtype=hidden_states.dtype,
        )

        # Restore the batch dimension (timestep was flattened above).
        temb = temb.view(batch_size, -1, temb.size(-1))
        embedded_timestep = embedded_timestep.view(batch_size, -1, embedded_timestep.size(-1))

        encoder_hidden_states = self.caption_projection(encoder_hidden_states)
        encoder_hidden_states = encoder_hidden_states.view(batch_size, -1, hidden_states.size(-1))

        for block in self.transformer_blocks:
            if torch.is_grad_enabled() and self.gradient_checkpointing:
                hidden_states = self._gradient_checkpointing_func(
                    block,
                    hidden_states,
                    encoder_hidden_states,
                    temb,
                    image_rotary_emb,
                    encoder_attention_mask,
                )
            else:
                hidden_states = block(
                    hidden_states=hidden_states,
                    encoder_hidden_states=encoder_hidden_states,
                    temb=temb,
                    image_rotary_emb=image_rotary_emb,
                    encoder_attention_mask=encoder_attention_mask,
                )

        # Final AdaLN-style modulation conditioned on the embedded timestep.
        scale_shift_values = self.scale_shift_table[None, None] + embedded_timestep[:, :, None]
        shift, scale = scale_shift_values[:, :, 0], scale_shift_values[:, :, 1]

        hidden_states = self.norm_out(hidden_states)
        hidden_states = hidden_states * (1 + scale) + shift
        output = self.proj_out(hidden_states)

        if USE_PEFT_BACKEND:
            # remove `lora_scale` from each PEFT layer
            unscale_lora_layers(self, lora_scale)

        if not return_dict:
            return (output,)
        return Transformer2DModelOutput(sample=output)
561
+
562
+
563
def apply_rotary_emb(x, freqs):
    """Rotate `x` by the (cos, sin) rotary tables in `freqs`.

    Adjacent channel pairs of `x` are treated as (real, imaginary) components;
    the computation runs in float32 and is cast back to `x`'s dtype.
    """
    cos, sin = freqs
    # Split interleaved pairs: even channels are "real", odd channels are "imaginary".
    even, odd = x.unflatten(2, (-1, 2)).unbind(-1)  # each [B, S, C // 2]
    # 90-degree rotation of each pair: (a, b) -> (-b, a), re-interleaved.
    rotated = torch.stack([-odd, even], dim=-1).flatten(2)
    return (x.float() * cos + rotated.float() * sin).to(x.dtype)
pythonProject/.venv/Lib/site-packages/diffusers/models/transformers/transformer_lumina2.py ADDED
@@ -0,0 +1,548 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright 2025 Alpha-VLLM Authors and The HuggingFace Team. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ import math
16
+ from typing import Any, Dict, List, Optional, Tuple, Union
17
+
18
+ import torch
19
+ import torch.nn as nn
20
+ import torch.nn.functional as F
21
+
22
+ from ...configuration_utils import ConfigMixin, register_to_config
23
+ from ...loaders import PeftAdapterMixin
24
+ from ...loaders.single_file_model import FromOriginalModelMixin
25
+ from ...utils import USE_PEFT_BACKEND, logging, scale_lora_layers, unscale_lora_layers
26
+ from ..attention import LuminaFeedForward
27
+ from ..attention_processor import Attention
28
+ from ..embeddings import TimestepEmbedding, Timesteps, apply_rotary_emb, get_1d_rotary_pos_embed
29
+ from ..modeling_outputs import Transformer2DModelOutput
30
+ from ..modeling_utils import ModelMixin
31
+ from ..normalization import LuminaLayerNormContinuous, LuminaRMSNormZero, RMSNorm
32
+
33
+
34
+ logger = logging.get_logger(__name__) # pylint: disable=invalid-name
35
+
36
+
37
class Lumina2CombinedTimestepCaptionEmbedding(nn.Module):
    """Embeds the diffusion timestep and projects caption features for Lumina2.

    `forward` returns the pair (timestep embedding, projected caption embedding).
    """

    def __init__(
        self,
        hidden_size: int = 4096,
        cap_feat_dim: int = 2048,
        frequency_embedding_size: int = 256,
        norm_eps: float = 1e-5,
    ) -> None:
        super().__init__()

        # Sinusoidal timestep features, then an MLP embedding (width capped at 1024).
        self.time_proj = Timesteps(
            num_channels=frequency_embedding_size, flip_sin_to_cos=True, downscale_freq_shift=0.0
        )
        self.timestep_embedder = TimestepEmbedding(
            in_channels=frequency_embedding_size, time_embed_dim=min(hidden_size, 1024)
        )

        # RMS-normalize caption features, then project them to the model width.
        self.caption_embedder = nn.Sequential(
            RMSNorm(cap_feat_dim, eps=norm_eps), nn.Linear(cap_feat_dim, hidden_size, bias=True)
        )

    def forward(
        self, hidden_states: torch.Tensor, timestep: torch.Tensor, encoder_hidden_states: torch.Tensor
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        # `hidden_states` is only consulted for its dtype here.
        sinusoid = self.time_proj(timestep).type_as(hidden_states)
        return self.timestep_embedder(sinusoid), self.caption_embedder(encoder_hidden_states)
66
+
67
+
68
class Lumina2AttnProcessor2_0:
    r"""
    Processor for implementing scaled dot-product attention (enabled by default if you're using PyTorch 2.0). This is
    used in the Lumina2Transformer2DModel model. It applies normalization and RoPE on query and key vectors.
    """

    def __init__(self):
        if not hasattr(F, "scaled_dot_product_attention"):
            raise ImportError("AttnProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0.")

    def __call__(
        self,
        attn: Attention,
        hidden_states: torch.Tensor,
        encoder_hidden_states: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        image_rotary_emb: Optional[torch.Tensor] = None,
        base_sequence_length: Optional[int] = None,
    ) -> torch.Tensor:
        """Run (optionally grouped-query) scaled-dot-product attention.

        Args:
            attn: The `Attention` module providing projections, norms and `scale`.
            hidden_states: Query-side states, shape (batch, seq, dim).
            encoder_hidden_states: Key/value-side states.
            attention_mask: Optional boolean-like mask over key positions.
            image_rotary_emb: Optional rotary embedding applied to Q and K.
            base_sequence_length: When given, enables proportional attention by
                rescaling the softmax scale with log(seq_len) / log(base_len).
        """
        batch_size, sequence_length, _ = hidden_states.shape

        # Get Query-Key-Value Pair
        query = attn.to_q(hidden_states)
        key = attn.to_k(encoder_hidden_states)
        value = attn.to_v(encoder_hidden_states)

        query_dim = query.shape[-1]
        inner_dim = key.shape[-1]
        head_dim = query_dim // attn.heads
        dtype = query.dtype

        # Number of key/value heads; may be smaller than query heads (GQA).
        kv_heads = inner_dim // head_dim

        query = query.view(batch_size, -1, attn.heads, head_dim)
        key = key.view(batch_size, -1, kv_heads, head_dim)
        value = value.view(batch_size, -1, kv_heads, head_dim)

        # Apply Query-Key Norm if needed
        if attn.norm_q is not None:
            query = attn.norm_q(query)
        if attn.norm_k is not None:
            key = attn.norm_k(key)

        # Apply RoPE if needed
        if image_rotary_emb is not None:
            query = apply_rotary_emb(query, image_rotary_emb, use_real=False)
            key = apply_rotary_emb(key, image_rotary_emb, use_real=False)

        query, key = query.to(dtype), key.to(dtype)

        # Apply proportional attention if true
        if base_sequence_length is not None:
            softmax_scale = math.sqrt(math.log(sequence_length, base_sequence_length)) * attn.scale
        else:
            softmax_scale = attn.scale

        # Grouped-query attention: replicate each KV head across n_rep query heads.
        # Fix: only materialize the repeat when n_rep > 1 — the previous `>= 1`
        # condition performed a redundant, value-preserving copy of K and V in the
        # standard multi-head case (n_rep == 1).
        n_rep = attn.heads // kv_heads
        if n_rep > 1:
            key = key.unsqueeze(3).repeat(1, 1, 1, n_rep, 1).flatten(2, 3)
            value = value.unsqueeze(3).repeat(1, 1, 1, n_rep, 1).flatten(2, 3)

        # scaled_dot_product_attention expects attention_mask shape to be
        # (batch, heads, source_length, target_length)
        if attention_mask is not None:
            attention_mask = attention_mask.bool().view(batch_size, 1, 1, -1)

        query = query.transpose(1, 2)
        key = key.transpose(1, 2)
        value = value.transpose(1, 2)

        hidden_states = F.scaled_dot_product_attention(
            query, key, value, attn_mask=attention_mask, scale=softmax_scale
        )
        hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
        hidden_states = hidden_states.type_as(query)

        # linear proj
        hidden_states = attn.to_out[0](hidden_states)
        hidden_states = attn.to_out[1](hidden_states)
        return hidden_states
150
+
151
+
152
class Lumina2TransformerBlock(nn.Module):
    """Lumina2 transformer block: self-attention followed by a feed-forward network,
    each wrapped in a sandwich of RMS norms, optionally gated by an AdaLN-Zero-style
    timestep modulation (`modulation=True`)."""

    def __init__(
        self,
        dim: int,
        num_attention_heads: int,
        num_kv_heads: int,
        multiple_of: int,
        ffn_dim_multiplier: float,
        norm_eps: float,
        modulation: bool = True,
    ) -> None:
        super().__init__()
        self.head_dim = dim // num_attention_heads
        self.modulation = modulation

        # Grouped-query self-attention with RMS query/key norms.
        self.attn = Attention(
            query_dim=dim,
            cross_attention_dim=None,
            dim_head=dim // num_attention_heads,
            qk_norm="rms_norm",
            heads=num_attention_heads,
            kv_heads=num_kv_heads,
            eps=1e-5,
            bias=False,
            out_bias=False,
            processor=Lumina2AttnProcessor2_0(),
        )

        self.feed_forward = LuminaFeedForward(
            dim=dim,
            inner_dim=4 * dim,
            multiple_of=multiple_of,
            ffn_dim_multiplier=ffn_dim_multiplier,
        )

        # norm1 doubles as the modulation source when `modulation` is enabled.
        if modulation:
            self.norm1 = LuminaRMSNormZero(
                embedding_dim=dim,
                norm_eps=norm_eps,
                norm_elementwise_affine=True,
            )
        else:
            self.norm1 = RMSNorm(dim, eps=norm_eps)
        self.ffn_norm1 = RMSNorm(dim, eps=norm_eps)

        self.norm2 = RMSNorm(dim, eps=norm_eps)
        self.ffn_norm2 = RMSNorm(dim, eps=norm_eps)

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: torch.Tensor,
        image_rotary_emb: torch.Tensor,
        temb: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        if not self.modulation:
            # Plain sandwich-norm residual block (no timestep conditioning).
            normed = self.norm1(hidden_states)
            attn_out = self.attn(
                hidden_states=normed,
                encoder_hidden_states=normed,
                attention_mask=attention_mask,
                image_rotary_emb=image_rotary_emb,
            )
            hidden_states = hidden_states + self.norm2(attn_out)
            return hidden_states + self.ffn_norm2(self.feed_forward(self.ffn_norm1(hidden_states)))

        # Modulated variant: norm1 also yields per-sample gates/scales from `temb`.
        normed, attn_gate, ff_scale, ff_gate = self.norm1(hidden_states, temb)
        attn_out = self.attn(
            hidden_states=normed,
            encoder_hidden_states=normed,
            attention_mask=attention_mask,
            image_rotary_emb=image_rotary_emb,
        )
        hidden_states = hidden_states + attn_gate.unsqueeze(1).tanh() * self.norm2(attn_out)
        ff_out = self.feed_forward(self.ffn_norm1(hidden_states) * (1 + ff_scale.unsqueeze(1)))
        return hidden_states + ff_gate.unsqueeze(1).tanh() * self.ffn_norm2(ff_out)
231
+
232
+
233
class Lumina2RotaryPosEmbed(nn.Module):
    """3-axis rotary embedding for Lumina2.

    Axis 0 indexes caption tokens (image tokens all share one axis-0 position),
    axes 1/2 index image patch rows/columns. One 1D frequency table per axis is
    precomputed up to `axes_lens` and gathered per token at forward time.
    """

    def __init__(self, theta: int, axes_dim: List[int], axes_lens: List[int] = (300, 512, 512), patch_size: int = 2):
        super().__init__()
        self.theta = theta
        self.axes_dim = axes_dim
        self.axes_lens = axes_lens
        self.patch_size = patch_size

        self.freqs_cis = self._precompute_freqs_cis(axes_dim, axes_lens, theta)

    def _precompute_freqs_cis(self, axes_dim: List[int], axes_lens: List[int], theta: int) -> List[torch.Tensor]:
        """Precompute one 1D rotary table per axis: length `axes_lens[i]`, width based on `axes_dim[i]`."""
        freqs_cis = []
        # MPS does not support float64; fall back to float32 there.
        freqs_dtype = torch.float32 if torch.backends.mps.is_available() else torch.float64
        for i, (d, e) in enumerate(zip(axes_dim, axes_lens)):
            emb = get_1d_rotary_pos_embed(d, e, theta=self.theta, freqs_dtype=freqs_dtype)
            freqs_cis.append(emb)
        return freqs_cis

    def _get_freqs_cis(self, ids: torch.Tensor) -> torch.Tensor:
        """Gather per-token rotary values for integer position ids of shape (batch, seq, 3)."""
        device = ids.device
        # Perform the gather on CPU for MPS inputs, then move the result back.
        if ids.device.type == "mps":
            ids = ids.to("cpu")

        result = []
        for i in range(len(self.axes_dim)):
            freqs = self.freqs_cis[i].to(ids.device)
            index = ids[:, :, i : i + 1].repeat(1, 1, freqs.shape[-1]).to(torch.int64)
            result.append(torch.gather(freqs.unsqueeze(0).repeat(index.shape[0], 1, 1), dim=1, index=index))
        return torch.cat(result, dim=-1).to(device)

    def forward(self, hidden_states: torch.Tensor, attention_mask: torch.Tensor):
        """Patchify `hidden_states` and build caption/image/joint rotary tables.

        Returns:
            (patchified_hidden_states, cap_freqs_cis, img_freqs_cis, freqs_cis,
             l_effective_cap_len, seq_lengths)
        """
        batch_size, channels, height, width = hidden_states.shape
        p = self.patch_size
        post_patch_height, post_patch_width = height // p, width // p
        image_seq_len = post_patch_height * post_patch_width
        device = hidden_states.device

        # Per-sample effective caption lengths from the attention mask.
        encoder_seq_len = attention_mask.shape[1]
        l_effective_cap_len = attention_mask.sum(dim=1).tolist()
        seq_lengths = [cap_seq_len + image_seq_len for cap_seq_len in l_effective_cap_len]
        max_seq_len = max(seq_lengths)

        # Create position IDs
        position_ids = torch.zeros(batch_size, max_seq_len, 3, dtype=torch.int32, device=device)

        for i, (cap_seq_len, seq_len) in enumerate(zip(l_effective_cap_len, seq_lengths)):
            # add caption position ids
            position_ids[i, :cap_seq_len, 0] = torch.arange(cap_seq_len, dtype=torch.int32, device=device)
            # All image tokens share the axis-0 position just past the caption.
            position_ids[i, cap_seq_len:seq_len, 0] = cap_seq_len

            # add image position ids
            row_ids = (
                torch.arange(post_patch_height, dtype=torch.int32, device=device)
                .view(-1, 1)
                .repeat(1, post_patch_width)
                .flatten()
            )
            col_ids = (
                torch.arange(post_patch_width, dtype=torch.int32, device=device)
                .view(1, -1)
                .repeat(post_patch_height, 1)
                .flatten()
            )
            position_ids[i, cap_seq_len:seq_len, 1] = row_ids
            position_ids[i, cap_seq_len:seq_len, 2] = col_ids

        # Get combined rotary embeddings
        freqs_cis = self._get_freqs_cis(position_ids)

        # create separate rotary embeddings for captions and images
        cap_freqs_cis = torch.zeros(
            batch_size, encoder_seq_len, freqs_cis.shape[-1], device=device, dtype=freqs_cis.dtype
        )
        img_freqs_cis = torch.zeros(
            batch_size, image_seq_len, freqs_cis.shape[-1], device=device, dtype=freqs_cis.dtype
        )

        for i, (cap_seq_len, seq_len) in enumerate(zip(l_effective_cap_len, seq_lengths)):
            cap_freqs_cis[i, :cap_seq_len] = freqs_cis[i, :cap_seq_len]
            img_freqs_cis[i, :image_seq_len] = freqs_cis[i, cap_seq_len:seq_len]

        # image patch embeddings
        # (B, C, H, W) -> (B, H/p * W/p, p*p*C) token sequence.
        hidden_states = (
            hidden_states.view(batch_size, channels, post_patch_height, p, post_patch_width, p)
            .permute(0, 2, 4, 3, 5, 1)
            .flatten(3)
            .flatten(1, 2)
        )

        return hidden_states, cap_freqs_cis, img_freqs_cis, freqs_cis, l_effective_cap_len, seq_lengths
323
+
324
+
325
class Lumina2Transformer2DModel(ModelMixin, ConfigMixin, PeftAdapterMixin, FromOriginalModelMixin):
    r"""
    Lumina2NextDiT: Diffusion model with a Transformer backbone.

    Parameters:
        sample_size (`int`): The width of the latent images. This is fixed during training since
            it is used to learn a number of position embeddings.
        patch_size (`int`, *optional*, defaults to 2):
            The size of each patch in the image. This parameter defines the resolution of patches fed into the model.
        in_channels (`int`, *optional*, defaults to 4):
            The number of input channels for the model. Typically, this matches the number of channels in the input
            images.
        hidden_size (`int`, *optional*, defaults to 4096):
            The dimensionality of the hidden layers in the model. This parameter determines the width of the model's
            hidden representations.
        num_layers (`int`, *optional*, default to 32):
            The number of layers in the model. This defines the depth of the neural network.
        num_attention_heads (`int`, *optional*, defaults to 32):
            The number of attention heads in each attention layer. This parameter specifies how many separate attention
            mechanisms are used.
        num_kv_heads (`int`, *optional*, defaults to 8):
            The number of key-value heads in the attention mechanism, if different from the number of attention heads.
            If None, it defaults to num_attention_heads.
        multiple_of (`int`, *optional*, defaults to 256):
            A factor that the hidden size should be a multiple of. This can help optimize certain hardware
            configurations.
        ffn_dim_multiplier (`float`, *optional*):
            A multiplier for the dimensionality of the feed-forward network. If None, it uses a default value based on
            the model configuration.
        norm_eps (`float`, *optional*, defaults to 1e-5):
            A small value added to the denominator for numerical stability in normalization layers.
        scaling_factor (`float`, *optional*, defaults to 1.0):
            A scaling factor applied to certain parameters or layers in the model. This can be used for adjusting the
            overall scale of the model's operations.
    """

    _supports_gradient_checkpointing = True
    _no_split_modules = ["Lumina2TransformerBlock"]
    _skip_layerwise_casting_patterns = ["x_embedder", "norm"]

    @register_to_config
    def __init__(
        self,
        sample_size: int = 128,
        patch_size: int = 2,
        in_channels: int = 16,
        out_channels: Optional[int] = None,
        hidden_size: int = 2304,
        num_layers: int = 26,
        num_refiner_layers: int = 2,
        num_attention_heads: int = 24,
        num_kv_heads: int = 8,
        multiple_of: int = 256,
        ffn_dim_multiplier: Optional[float] = None,
        norm_eps: float = 1e-5,
        scaling_factor: float = 1.0,
        axes_dim_rope: Tuple[int, int, int] = (32, 32, 32),
        axes_lens: Tuple[int, int, int] = (300, 512, 512),
        cap_feat_dim: int = 1024,
    ) -> None:
        super().__init__()
        self.out_channels = out_channels or in_channels

        # 1. Positional, patch & conditional embeddings
        self.rope_embedder = Lumina2RotaryPosEmbed(
            theta=10000, axes_dim=axes_dim_rope, axes_lens=axes_lens, patch_size=patch_size
        )

        self.x_embedder = nn.Linear(in_features=patch_size * patch_size * in_channels, out_features=hidden_size)

        self.time_caption_embed = Lumina2CombinedTimestepCaptionEmbedding(
            hidden_size=hidden_size, cap_feat_dim=cap_feat_dim, norm_eps=norm_eps
        )

        # 2. Noise and context refinement blocks
        self.noise_refiner = nn.ModuleList(
            [
                Lumina2TransformerBlock(
                    hidden_size,
                    num_attention_heads,
                    num_kv_heads,
                    multiple_of,
                    ffn_dim_multiplier,
                    norm_eps,
                    modulation=True,
                )
                for _ in range(num_refiner_layers)
            ]
        )

        self.context_refiner = nn.ModuleList(
            [
                Lumina2TransformerBlock(
                    hidden_size,
                    num_attention_heads,
                    num_kv_heads,
                    multiple_of,
                    ffn_dim_multiplier,
                    norm_eps,
                    modulation=False,
                )
                for _ in range(num_refiner_layers)
            ]
        )

        # 3. Transformer blocks
        self.layers = nn.ModuleList(
            [
                Lumina2TransformerBlock(
                    hidden_size,
                    num_attention_heads,
                    num_kv_heads,
                    multiple_of,
                    ffn_dim_multiplier,
                    norm_eps,
                    modulation=True,
                )
                for _ in range(num_layers)
            ]
        )

        # 4. Output norm & projection
        self.norm_out = LuminaLayerNormContinuous(
            embedding_dim=hidden_size,
            conditioning_embedding_dim=min(hidden_size, 1024),
            elementwise_affine=False,
            eps=1e-6,
            bias=True,
            out_dim=patch_size * patch_size * self.out_channels,
        )

        self.gradient_checkpointing = False

    def forward(
        self,
        hidden_states: torch.Tensor,
        timestep: torch.Tensor,
        encoder_hidden_states: torch.Tensor,
        encoder_attention_mask: torch.Tensor,
        attention_kwargs: Optional[Dict[str, Any]] = None,
        return_dict: bool = True,
    ) -> Union[torch.Tensor, Transformer2DModelOutput]:
        """Denoise a batch of image latents conditioned on caption embeddings.

        Returns a `Transformer2DModelOutput` (or a 1-tuple when `return_dict=False`).
        """
        # `scale` inside `attention_kwargs` only controls LoRA scaling (PEFT backend).
        if attention_kwargs is not None:
            attention_kwargs = attention_kwargs.copy()
            lora_scale = attention_kwargs.pop("scale", 1.0)
        else:
            lora_scale = 1.0

        if USE_PEFT_BACKEND:
            # weight the lora layers by setting `lora_scale` for each PEFT layer
            scale_lora_layers(self, lora_scale)
        else:
            if attention_kwargs is not None and attention_kwargs.get("scale", None) is not None:
                logger.warning(
                    "Passing `scale` via `attention_kwargs` when not using the PEFT backend is ineffective."
                )

        # 1. Condition, positional & patch embedding
        batch_size, _, height, width = hidden_states.shape

        temb, encoder_hidden_states = self.time_caption_embed(hidden_states, timestep, encoder_hidden_states)

        (
            hidden_states,
            context_rotary_emb,
            noise_rotary_emb,
            rotary_emb,
            encoder_seq_lengths,
            seq_lengths,
        ) = self.rope_embedder(hidden_states, encoder_attention_mask)

        hidden_states = self.x_embedder(hidden_states)

        # 2. Context & noise refinement
        for layer in self.context_refiner:
            encoder_hidden_states = layer(encoder_hidden_states, encoder_attention_mask, context_rotary_emb)

        for layer in self.noise_refiner:
            hidden_states = layer(hidden_states, None, noise_rotary_emb, temb)

        # 3. Joint Transformer blocks
        # Concatenate caption + image tokens per sample; a mask is only needed when
        # the per-sample sequence lengths differ within the batch.
        max_seq_len = max(seq_lengths)
        use_mask = len(set(seq_lengths)) > 1

        attention_mask = hidden_states.new_zeros(batch_size, max_seq_len, dtype=torch.bool)
        joint_hidden_states = hidden_states.new_zeros(batch_size, max_seq_len, self.config.hidden_size)
        for i, (encoder_seq_len, seq_len) in enumerate(zip(encoder_seq_lengths, seq_lengths)):
            attention_mask[i, :seq_len] = True
            joint_hidden_states[i, :encoder_seq_len] = encoder_hidden_states[i, :encoder_seq_len]
            joint_hidden_states[i, encoder_seq_len:seq_len] = hidden_states[i]

        hidden_states = joint_hidden_states

        for layer in self.layers:
            if torch.is_grad_enabled() and self.gradient_checkpointing:
                hidden_states = self._gradient_checkpointing_func(
                    layer, hidden_states, attention_mask if use_mask else None, rotary_emb, temb
                )
            else:
                hidden_states = layer(hidden_states, attention_mask if use_mask else None, rotary_emb, temb)

        # 4. Output norm & projection
        hidden_states = self.norm_out(hidden_states, temb)

        # 5. Unpatchify
        # Drop each sample's caption tokens and fold patch tokens back to (C, H, W).
        p = self.config.patch_size
        output = []
        for i, (encoder_seq_len, seq_len) in enumerate(zip(encoder_seq_lengths, seq_lengths)):
            output.append(
                hidden_states[i][encoder_seq_len:seq_len]
                .view(height // p, width // p, p, p, self.out_channels)
                .permute(4, 0, 2, 1, 3)
                .flatten(3, 4)
                .flatten(1, 2)
            )
        output = torch.stack(output, dim=0)

        if USE_PEFT_BACKEND:
            # remove `lora_scale` from each PEFT layer
            unscale_lora_layers(self, lora_scale)

        if not return_dict:
            return (output,)
        return Transformer2DModelOutput(sample=output)
pythonProject/.venv/Lib/site-packages/diffusers/models/unets/__init__.py ADDED
@@ -0,0 +1,18 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from ...utils import is_flax_available, is_torch_available
2
+
3
+
4
+ if is_torch_available():
5
+ from .unet_1d import UNet1DModel
6
+ from .unet_2d import UNet2DModel
7
+ from .unet_2d_condition import UNet2DConditionModel
8
+ from .unet_3d_condition import UNet3DConditionModel
9
+ from .unet_i2vgen_xl import I2VGenXLUNet
10
+ from .unet_kandinsky3 import Kandinsky3UNet
11
+ from .unet_motion_model import MotionAdapter, UNetMotionModel
12
+ from .unet_spatio_temporal_condition import UNetSpatioTemporalConditionModel
13
+ from .unet_stable_cascade import StableCascadeUNet
14
+ from .uvit_2d import UVit2DModel
15
+
16
+
17
+ if is_flax_available():
18
+ from .unet_2d_condition_flax import FlaxUNet2DConditionModel
pythonProject/.venv/Lib/site-packages/diffusers/models/unets/__pycache__/__init__.cpython-310.pyc ADDED
Binary file (940 Bytes). View file
 
pythonProject/.venv/Lib/site-packages/diffusers/models/unets/__pycache__/unet_1d.cpython-310.pyc ADDED
Binary file (7.94 kB). View file
 
pythonProject/.venv/Lib/site-packages/diffusers/models/unets/__pycache__/unet_1d_blocks.cpython-310.pyc ADDED
Binary file (18.7 kB). View file
 
pythonProject/.venv/Lib/site-packages/diffusers/models/unets/__pycache__/unet_2d.cpython-310.pyc ADDED
Binary file (11.6 kB). View file
 
pythonProject/.venv/Lib/site-packages/diffusers/models/unets/__pycache__/unet_2d_blocks.cpython-310.pyc ADDED
Binary file (60.9 kB). View file
 
pythonProject/.venv/Lib/site-packages/diffusers/models/unets/__pycache__/unet_2d_blocks_flax.cpython-310.pyc ADDED
Binary file (12.2 kB). View file
 
pythonProject/.venv/Lib/site-packages/diffusers/models/unets/__pycache__/unet_2d_condition.cpython-310.pyc ADDED
Binary file (40.6 kB). View file
 
pythonProject/.venv/Lib/site-packages/diffusers/models/unets/__pycache__/unet_2d_condition_flax.cpython-310.pyc ADDED
Binary file (14.1 kB). View file
 
pythonProject/.venv/Lib/site-packages/diffusers/models/unets/__pycache__/unet_3d_blocks.cpython-310.pyc ADDED
Binary file (26.4 kB). View file
 
pythonProject/.venv/Lib/site-packages/diffusers/models/unets/__pycache__/unet_3d_condition.cpython-310.pyc ADDED
Binary file (24 kB). View file