xiaoanyu123 committed on
Commit
bf8cf37
·
verified ·
1 Parent(s): ddf41bf

Add files using upload-large-folder tool

Browse files
Files changed (20) hide show
  1. pythonProject/.venv/Lib/site-packages/diffusers/models/transformers/__init__.py +39 -0
  2. pythonProject/.venv/Lib/site-packages/diffusers/models/transformers/__pycache__/auraflow_transformer_2d.cpython-310.pyc +0 -0
  3. pythonProject/.venv/Lib/site-packages/diffusers/models/transformers/__pycache__/cogvideox_transformer_3d.cpython-310.pyc +0 -0
  4. pythonProject/.venv/Lib/site-packages/diffusers/models/transformers/__pycache__/consisid_transformer_3d.cpython-310.pyc +0 -0
  5. pythonProject/.venv/Lib/site-packages/diffusers/models/transformers/__pycache__/dit_transformer_2d.cpython-310.pyc +0 -0
  6. pythonProject/.venv/Lib/site-packages/diffusers/models/transformers/__pycache__/dual_transformer_2d.cpython-310.pyc +0 -0
  7. pythonProject/.venv/Lib/site-packages/diffusers/models/transformers/__pycache__/hunyuan_transformer_2d.cpython-310.pyc +0 -0
  8. pythonProject/.venv/Lib/site-packages/diffusers/models/transformers/__pycache__/latte_transformer_3d.cpython-310.pyc +0 -0
  9. pythonProject/.venv/Lib/site-packages/diffusers/models/transformers/__pycache__/lumina_nextdit2d.cpython-310.pyc +0 -0
  10. pythonProject/.venv/Lib/site-packages/diffusers/models/transformers/__pycache__/pixart_transformer_2d.cpython-310.pyc +0 -0
  11. pythonProject/.venv/Lib/site-packages/diffusers/models/transformers/__pycache__/prior_transformer.cpython-310.pyc +0 -0
  12. pythonProject/.venv/Lib/site-packages/diffusers/models/transformers/__pycache__/sana_transformer.cpython-310.pyc +0 -0
  13. pythonProject/.venv/Lib/site-packages/diffusers/models/transformers/transformer_mochi.py +488 -0
  14. pythonProject/.venv/Lib/site-packages/diffusers/models/transformers/transformer_omnigen.py +469 -0
  15. pythonProject/.venv/Lib/site-packages/diffusers/models/transformers/transformer_qwenimage.py +655 -0
  16. pythonProject/.venv/Lib/site-packages/diffusers/models/transformers/transformer_sd3.py +431 -0
  17. pythonProject/.venv/Lib/site-packages/diffusers/models/transformers/transformer_skyreels_v2.py +781 -0
  18. pythonProject/.venv/Lib/site-packages/diffusers/models/transformers/transformer_temporal.py +375 -0
  19. pythonProject/.venv/Lib/site-packages/diffusers/models/transformers/transformer_wan.py +698 -0
  20. pythonProject/.venv/Lib/site-packages/diffusers/models/transformers/transformer_wan_vace.py +389 -0
pythonProject/.venv/Lib/site-packages/diffusers/models/transformers/__init__.py ADDED
@@ -0,0 +1,39 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from ...utils import is_torch_available
2
+
3
+
4
+ if is_torch_available():
5
+ from .auraflow_transformer_2d import AuraFlowTransformer2DModel
6
+ from .cogvideox_transformer_3d import CogVideoXTransformer3DModel
7
+ from .consisid_transformer_3d import ConsisIDTransformer3DModel
8
+ from .dit_transformer_2d import DiTTransformer2DModel
9
+ from .dual_transformer_2d import DualTransformer2DModel
10
+ from .hunyuan_transformer_2d import HunyuanDiT2DModel
11
+ from .latte_transformer_3d import LatteTransformer3DModel
12
+ from .lumina_nextdit2d import LuminaNextDiT2DModel
13
+ from .pixart_transformer_2d import PixArtTransformer2DModel
14
+ from .prior_transformer import PriorTransformer
15
+ from .sana_transformer import SanaTransformer2DModel
16
+ from .stable_audio_transformer import StableAudioDiTModel
17
+ from .t5_film_transformer import T5FilmDecoder
18
+ from .transformer_2d import Transformer2DModel
19
+ from .transformer_allegro import AllegroTransformer3DModel
20
+ from .transformer_bria import BriaTransformer2DModel
21
+ from .transformer_chroma import ChromaTransformer2DModel
22
+ from .transformer_cogview3plus import CogView3PlusTransformer2DModel
23
+ from .transformer_cogview4 import CogView4Transformer2DModel
24
+ from .transformer_cosmos import CosmosTransformer3DModel
25
+ from .transformer_easyanimate import EasyAnimateTransformer3DModel
26
+ from .transformer_flux import FluxTransformer2DModel
27
+ from .transformer_hidream_image import HiDreamImageTransformer2DModel
28
+ from .transformer_hunyuan_video import HunyuanVideoTransformer3DModel
29
+ from .transformer_hunyuan_video_framepack import HunyuanVideoFramepackTransformer3DModel
30
+ from .transformer_ltx import LTXVideoTransformer3DModel
31
+ from .transformer_lumina2 import Lumina2Transformer2DModel
32
+ from .transformer_mochi import MochiTransformer3DModel
33
+ from .transformer_omnigen import OmniGenTransformer2DModel
34
+ from .transformer_qwenimage import QwenImageTransformer2DModel
35
+ from .transformer_sd3 import SD3Transformer2DModel
36
+ from .transformer_skyreels_v2 import SkyReelsV2Transformer3DModel
37
+ from .transformer_temporal import TransformerTemporalModel
38
+ from .transformer_wan import WanTransformer3DModel
39
+ from .transformer_wan_vace import WanVACETransformer3DModel
pythonProject/.venv/Lib/site-packages/diffusers/models/transformers/__pycache__/auraflow_transformer_2d.cpython-310.pyc ADDED
Binary file (16.8 kB). View file
 
pythonProject/.venv/Lib/site-packages/diffusers/models/transformers/__pycache__/cogvideox_transformer_3d.cpython-310.pyc ADDED
Binary file (16.6 kB). View file
 
pythonProject/.venv/Lib/site-packages/diffusers/models/transformers/__pycache__/consisid_transformer_3d.cpython-310.pyc ADDED
Binary file (25.8 kB). View file
 
pythonProject/.venv/Lib/site-packages/diffusers/models/transformers/__pycache__/dit_transformer_2d.cpython-310.pyc ADDED
Binary file (8.2 kB). View file
 
pythonProject/.venv/Lib/site-packages/diffusers/models/transformers/__pycache__/dual_transformer_2d.cpython-310.pyc ADDED
Binary file (6.15 kB). View file
 
pythonProject/.venv/Lib/site-packages/diffusers/models/transformers/__pycache__/hunyuan_transformer_2d.cpython-310.pyc ADDED
Binary file (18.6 kB). View file
 
pythonProject/.venv/Lib/site-packages/diffusers/models/transformers/__pycache__/latte_transformer_3d.cpython-310.pyc ADDED
Binary file (7.22 kB). View file
 
pythonProject/.venv/Lib/site-packages/diffusers/models/transformers/__pycache__/lumina_nextdit2d.cpython-310.pyc ADDED
Binary file (10.9 kB). View file
 
pythonProject/.venv/Lib/site-packages/diffusers/models/transformers/__pycache__/pixart_transformer_2d.cpython-310.pyc ADDED
Binary file (15.4 kB). View file
 
pythonProject/.venv/Lib/site-packages/diffusers/models/transformers/__pycache__/prior_transformer.cpython-310.pyc ADDED
Binary file (13.3 kB). View file
 
pythonProject/.venv/Lib/site-packages/diffusers/models/transformers/__pycache__/sana_transformer.cpython-310.pyc ADDED
Binary file (16.9 kB). View file
 
pythonProject/.venv/Lib/site-packages/diffusers/models/transformers/transformer_mochi.py ADDED
@@ -0,0 +1,488 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright 2025 The Genmo team and The HuggingFace Team.
2
+ # All rights reserved.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
16
+ from typing import Any, Dict, Optional, Tuple
17
+
18
+ import torch
19
+ import torch.nn as nn
20
+
21
+ from ...configuration_utils import ConfigMixin, register_to_config
22
+ from ...loaders import PeftAdapterMixin
23
+ from ...loaders.single_file_model import FromOriginalModelMixin
24
+ from ...utils import USE_PEFT_BACKEND, logging, scale_lora_layers, unscale_lora_layers
25
+ from ...utils.torch_utils import maybe_allow_in_graph
26
+ from ..attention import FeedForward
27
+ from ..attention_processor import MochiAttention, MochiAttnProcessor2_0
28
+ from ..cache_utils import CacheMixin
29
+ from ..embeddings import MochiCombinedTimestepCaptionEmbedding, PatchEmbed
30
+ from ..modeling_outputs import Transformer2DModelOutput
31
+ from ..modeling_utils import ModelMixin
32
+ from ..normalization import AdaLayerNormContinuous, RMSNorm
33
+
34
+
35
+ logger = logging.get_logger(__name__) # pylint: disable=invalid-name
36
+
37
+
class MochiModulatedRMSNorm(nn.Module):
    """Unparameterized RMSNorm computed in float32, with an optional multiplicative modulation.

    The inner `RMSNorm(0, eps, False)` carries no learned weight; normalization is done in
    float32 for numerical stability and the result is cast back to the caller's dtype.
    """

    def __init__(self, eps: float):
        super().__init__()

        self.eps = eps
        # dim=0 + elementwise_affine=False: a pure (weight-free) RMS normalization.
        self.norm = RMSNorm(0, eps, False)

    def forward(self, hidden_states, scale=None):
        # Remember the incoming dtype so we can restore it after the float32 pass.
        orig_dtype = hidden_states.dtype

        normed = self.norm(hidden_states.to(torch.float32))

        # Optional per-channel modulation (e.g. `1 + scale` from an AdaLN projection).
        if scale is not None:
            normed = normed * scale

        return normed.to(orig_dtype)
57
+
58
+
class MochiLayerNormContinuous(nn.Module):
    """AdaLN-style continuous conditioning norm.

    Projects a conditioning embedding to a per-channel scale and applies it to `x`
    through a modulated RMSNorm, returning the result in `x`'s original dtype.
    """

    def __init__(
        self,
        embedding_dim: int,
        conditioning_embedding_dim: int,
        eps=1e-5,
        bias=True,
    ):
        super().__init__()

        # AdaLN: SiLU -> Linear produces the modulation scale.
        self.silu = nn.SiLU()
        self.linear_1 = nn.Linear(conditioning_embedding_dim, embedding_dim, bias=bias)
        self.norm = MochiModulatedRMSNorm(eps=eps)

    def forward(
        self,
        x: torch.Tensor,
        conditioning_embedding: torch.Tensor,
    ) -> torch.Tensor:
        input_dtype = x.dtype

        # Cast the conditioning path back to x's dtype in case it was upcast to float32 upstream.
        scale = self.linear_1(self.silu(conditioning_embedding).to(x.dtype))
        out = self.norm(x, 1 + scale.unsqueeze(1).to(torch.float32))

        return out.to(input_dtype)
86
+
87
+
class MochiRMSNormZero(nn.Module):
    r"""
    Adaptive RMS Norm used in Mochi.

    Projects the conditioning embedding into four modulation chunks (attention scale/gate and
    MLP scale/gate), applies the attention scale to the normalized hidden states, and returns
    the remaining gates for the caller to use.

    Parameters:
        embedding_dim (`int`): The size of each embedding vector.
        hidden_dim (`int`): Output width of the modulation projection (split into four chunks).
        eps (`float`, defaults to `1e-5`): Epsilon used by the RMS normalization.
        elementwise_affine (`bool`, defaults to `False`): Accepted for signature compatibility;
            the inner norm is always weight-free.
    """

    def __init__(
        self, embedding_dim: int, hidden_dim: int, eps: float = 1e-5, elementwise_affine: bool = False
    ) -> None:
        super().__init__()

        self.silu = nn.SiLU()
        self.linear = nn.Linear(embedding_dim, hidden_dim)
        self.norm = RMSNorm(0, eps, False)

    def forward(
        self, hidden_states: torch.Tensor, emb: torch.Tensor
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
        orig_dtype = hidden_states.dtype

        # One projection yields all four modulation signals.
        modulation = self.linear(self.silu(emb))
        scale_msa, gate_msa, scale_mlp, gate_mlp = modulation.chunk(4, dim=1)

        # Normalize and scale in float32, then restore the original dtype.
        normed = self.norm(hidden_states.to(torch.float32)) * (1 + scale_msa[:, None].to(torch.float32))

        return normed.to(orig_dtype), gate_msa, scale_mlp, gate_mlp
116
+
117
+
@maybe_allow_in_graph
class MochiTransformerBlock(nn.Module):
    r"""
    Transformer block used in [Mochi](https://huggingface.co/genmo/mochi-1-preview).

    Runs a joint attention over the visual stream and the context (caption) stream, each with
    its own adaptive normalization, gating, and feed-forward. In the final block
    (`context_pre_only=True`) the context stream is normalized for attention but not updated.

    Args:
        dim (`int`):
            The number of channels in the input and output.
        num_attention_heads (`int`):
            The number of heads to use for multi-head attention.
        attention_head_dim (`int`):
            The number of channels in each head.
        pooled_projection_dim (`int`):
            The number of channels of the context (caption) stream.
        qk_norm (`str`, defaults to `"rms_norm"`):
            The normalization layer to use.
        activation_fn (`str`, defaults to `"swiglu"`):
            Activation function to use in feed-forward.
        context_pre_only (`bool`, defaults to `False`):
            Whether or not to process context-related conditions with additional layers.
        eps (`float`, defaults to `1e-6`):
            Epsilon value for normalization layers.
    """

    def __init__(
        self,
        dim: int,
        num_attention_heads: int,
        attention_head_dim: int,
        pooled_projection_dim: int,
        qk_norm: str = "rms_norm",
        activation_fn: str = "swiglu",
        context_pre_only: bool = False,
        eps: float = 1e-6,
    ) -> None:
        super().__init__()

        self.context_pre_only = context_pre_only
        # SwiGLU feed-forward widths: two thirds of the usual 4x expansion.
        self.ff_inner_dim = (4 * dim * 2) // 3
        self.ff_context_inner_dim = (4 * pooled_projection_dim * 2) // 3

        self.norm1 = MochiRMSNormZero(dim, 4 * dim, eps=eps, elementwise_affine=False)

        if context_pre_only:
            # Final block: the context stream only needs a plain modulated norm before attention.
            self.norm1_context = MochiLayerNormContinuous(
                embedding_dim=pooled_projection_dim,
                conditioning_embedding_dim=dim,
                eps=eps,
            )
        else:
            self.norm1_context = MochiRMSNormZero(dim, 4 * pooled_projection_dim, eps=eps, elementwise_affine=False)

        self.attn1 = MochiAttention(
            query_dim=dim,
            heads=num_attention_heads,
            dim_head=attention_head_dim,
            bias=False,
            added_kv_proj_dim=pooled_projection_dim,
            added_proj_bias=False,
            out_dim=dim,
            out_context_dim=pooled_projection_dim,
            context_pre_only=context_pre_only,
            processor=MochiAttnProcessor2_0(),
            eps=1e-5,
        )

        # TODO(aryan): norm_context layers are not needed when `context_pre_only` is True
        self.norm2 = MochiModulatedRMSNorm(eps=eps)
        self.norm2_context = None if self.context_pre_only else MochiModulatedRMSNorm(eps=eps)

        self.norm3 = MochiModulatedRMSNorm(eps)
        self.norm3_context = None if self.context_pre_only else MochiModulatedRMSNorm(eps=eps)

        self.ff = FeedForward(dim, inner_dim=self.ff_inner_dim, activation_fn=activation_fn, bias=False)
        self.ff_context = None
        if not context_pre_only:
            self.ff_context = FeedForward(
                pooled_projection_dim,
                inner_dim=self.ff_context_inner_dim,
                activation_fn=activation_fn,
                bias=False,
            )

        self.norm4 = MochiModulatedRMSNorm(eps=eps)
        self.norm4_context = MochiModulatedRMSNorm(eps=eps)

    def forward(
        self,
        hidden_states: torch.Tensor,
        encoder_hidden_states: torch.Tensor,
        temb: torch.Tensor,
        encoder_attention_mask: torch.Tensor,
        image_rotary_emb: Optional[torch.Tensor] = None,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        # Pre-attention adaptive norm: the visual stream always returns modulation gates;
        # the context stream only does while it is still being updated.
        norm_hidden_states, gate_msa, scale_mlp, gate_mlp = self.norm1(hidden_states, temb)

        if self.context_pre_only:
            norm_encoder_hidden_states = self.norm1_context(encoder_hidden_states, temb)
        else:
            norm_encoder_hidden_states, enc_gate_msa, enc_scale_mlp, enc_gate_mlp = self.norm1_context(
                encoder_hidden_states, temb
            )

        attn_hidden_states, context_attn_hidden_states = self.attn1(
            hidden_states=norm_hidden_states,
            encoder_hidden_states=norm_encoder_hidden_states,
            image_rotary_emb=image_rotary_emb,
            attention_mask=encoder_attention_mask,
        )

        # Visual stream: tanh-gated attention residual, then tanh-gated feed-forward residual.
        hidden_states = hidden_states + self.norm2(attn_hidden_states, torch.tanh(gate_msa).unsqueeze(1))
        norm_hidden_states = self.norm3(hidden_states, 1 + scale_mlp.unsqueeze(1).to(torch.float32))
        hidden_states = hidden_states + self.norm4(self.ff(norm_hidden_states), torch.tanh(gate_mlp).unsqueeze(1))

        if not self.context_pre_only:
            # Context stream mirrors the visual-stream update.
            encoder_hidden_states = encoder_hidden_states + self.norm2_context(
                context_attn_hidden_states, torch.tanh(enc_gate_msa).unsqueeze(1)
            )
            norm_encoder_hidden_states = self.norm3_context(
                encoder_hidden_states, 1 + enc_scale_mlp.unsqueeze(1).to(torch.float32)
            )
            encoder_hidden_states = encoder_hidden_states + self.norm4_context(
                self.ff_context(norm_encoder_hidden_states), torch.tanh(enc_gate_mlp).unsqueeze(1)
            )

        return hidden_states, encoder_hidden_states
244
+
245
+
246
+ class MochiRoPE(nn.Module):
247
+ r"""
248
+ RoPE implementation used in [Mochi](https://huggingface.co/genmo/mochi-1-preview).
249
+
250
+ Args:
251
+ base_height (`int`, defaults to `192`):
252
+ Base height used to compute interpolation scale for rotary positional embeddings.
253
+ base_width (`int`, defaults to `192`):
254
+ Base width used to compute interpolation scale for rotary positional embeddings.
255
+ """
256
+
257
+ def __init__(self, base_height: int = 192, base_width: int = 192) -> None:
258
+ super().__init__()
259
+
260
+ self.target_area = base_height * base_width
261
+
262
+ def _centers(self, start, stop, num, device, dtype) -> torch.Tensor:
263
+ edges = torch.linspace(start, stop, num + 1, device=device, dtype=dtype)
264
+ return (edges[:-1] + edges[1:]) / 2
265
+
266
+ def _get_positions(
267
+ self,
268
+ num_frames: int,
269
+ height: int,
270
+ width: int,
271
+ device: Optional[torch.device] = None,
272
+ dtype: Optional[torch.dtype] = None,
273
+ ) -> torch.Tensor:
274
+ scale = (self.target_area / (height * width)) ** 0.5
275
+
276
+ t = torch.arange(num_frames, device=device, dtype=dtype)
277
+ h = self._centers(-height * scale / 2, height * scale / 2, height, device, dtype)
278
+ w = self._centers(-width * scale / 2, width * scale / 2, width, device, dtype)
279
+
280
+ grid_t, grid_h, grid_w = torch.meshgrid(t, h, w, indexing="ij")
281
+
282
+ positions = torch.stack([grid_t, grid_h, grid_w], dim=-1).view(-1, 3)
283
+ return positions
284
+
285
+ def _create_rope(self, freqs: torch.Tensor, pos: torch.Tensor) -> torch.Tensor:
286
+ with torch.autocast(freqs.device.type, torch.float32):
287
+ # Always run ROPE freqs computation in FP32
288
+ freqs = torch.einsum("nd,dhf->nhf", pos.to(torch.float32), freqs.to(torch.float32))
289
+
290
+ freqs_cos = torch.cos(freqs)
291
+ freqs_sin = torch.sin(freqs)
292
+ return freqs_cos, freqs_sin
293
+
294
+ def forward(
295
+ self,
296
+ pos_frequencies: torch.Tensor,
297
+ num_frames: int,
298
+ height: int,
299
+ width: int,
300
+ device: Optional[torch.device] = None,
301
+ dtype: Optional[torch.dtype] = None,
302
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
303
+ pos = self._get_positions(num_frames, height, width, device, dtype)
304
+ rope_cos, rope_sin = self._create_rope(pos_frequencies, pos)
305
+ return rope_cos, rope_sin
306
+
307
+
@maybe_allow_in_graph
class MochiTransformer3DModel(ModelMixin, ConfigMixin, PeftAdapterMixin, FromOriginalModelMixin, CacheMixin):
    r"""
    A Transformer model for video-like data introduced in [Mochi](https://huggingface.co/genmo/mochi-1-preview).

    Args:
        patch_size (`int`, defaults to `2`):
            The size of the patches to use in the patch embedding layer.
        num_attention_heads (`int`, defaults to `24`):
            The number of heads to use for multi-head attention.
        attention_head_dim (`int`, defaults to `128`):
            The number of channels in each head.
        num_layers (`int`, defaults to `48`):
            The number of layers of Transformer blocks to use.
        pooled_projection_dim (`int`, defaults to `1536`):
            Dimension of the caption/context stream fed through the transformer blocks.
        in_channels (`int`, defaults to `12`):
            The number of channels in the input.
        out_channels (`int`, *optional*, defaults to `None`):
            The number of channels in the output; falls back to `in_channels` when `None`.
        qk_norm (`str`, defaults to `"rms_norm"`):
            The normalization layer to use.
        text_embed_dim (`int`, defaults to `4096`):
            Input dimension of text embeddings from the text encoder.
        time_embed_dim (`int`, defaults to `256`):
            Output dimension of timestep embeddings.
        activation_fn (`str`, defaults to `"swiglu"`):
            Activation function to use in feed-forward.
        max_sequence_length (`int`, defaults to `256`):
            The maximum sequence length of text embeddings supported.
    """

    _supports_gradient_checkpointing = True
    _no_split_modules = ["MochiTransformerBlock"]
    _skip_layerwise_casting_patterns = ["patch_embed", "norm"]

    @register_to_config
    def __init__(
        self,
        patch_size: int = 2,
        num_attention_heads: int = 24,
        attention_head_dim: int = 128,
        num_layers: int = 48,
        pooled_projection_dim: int = 1536,
        in_channels: int = 12,
        out_channels: Optional[int] = None,
        qk_norm: str = "rms_norm",
        text_embed_dim: int = 4096,
        time_embed_dim: int = 256,
        activation_fn: str = "swiglu",
        max_sequence_length: int = 256,
    ) -> None:
        super().__init__()

        inner_dim = num_attention_heads * attention_head_dim
        out_channels = out_channels or in_channels

        # Per-frame spatial patchification; no positional table here (RoPE is used instead).
        self.patch_embed = PatchEmbed(
            patch_size=patch_size,
            in_channels=in_channels,
            embed_dim=inner_dim,
            pos_embed_type=None,
        )

        self.time_embed = MochiCombinedTimestepCaptionEmbedding(
            embedding_dim=inner_dim,
            pooled_projection_dim=pooled_projection_dim,
            text_embed_dim=text_embed_dim,
            time_embed_dim=time_embed_dim,
            num_attention_heads=8,
        )

        # Learned RoPE frequencies over (t, h, w), initialized to zero.
        self.pos_frequencies = nn.Parameter(torch.full((3, num_attention_heads, attention_head_dim // 2), 0.0))
        self.rope = MochiRoPE()

        # Only the last block runs with `context_pre_only=True` (context stream frozen).
        self.transformer_blocks = nn.ModuleList(
            [
                MochiTransformerBlock(
                    dim=inner_dim,
                    num_attention_heads=num_attention_heads,
                    attention_head_dim=attention_head_dim,
                    pooled_projection_dim=pooled_projection_dim,
                    qk_norm=qk_norm,
                    activation_fn=activation_fn,
                    context_pre_only=layer_idx == num_layers - 1,
                )
                for layer_idx in range(num_layers)
            ]
        )

        self.norm_out = AdaLayerNormContinuous(
            inner_dim,
            inner_dim,
            elementwise_affine=False,
            eps=1e-6,
            norm_type="layer_norm",
        )
        self.proj_out = nn.Linear(inner_dim, patch_size * patch_size * out_channels)

        self.gradient_checkpointing = False

    def forward(
        self,
        hidden_states: torch.Tensor,
        encoder_hidden_states: torch.Tensor,
        timestep: torch.LongTensor,
        encoder_attention_mask: torch.Tensor,
        attention_kwargs: Optional[Dict[str, Any]] = None,
        return_dict: bool = True,
    ) -> torch.Tensor:
        if attention_kwargs is not None:
            attention_kwargs = attention_kwargs.copy()
            lora_scale = attention_kwargs.pop("scale", 1.0)
        else:
            lora_scale = 1.0

        if USE_PEFT_BACKEND:
            # weight the lora layers by setting `lora_scale` for each PEFT layer
            scale_lora_layers(self, lora_scale)
        elif attention_kwargs is not None and attention_kwargs.get("scale", None) is not None:
            # NOTE(review): "scale" was popped above, so this branch appears unreachable in
            # practice; kept for parity with the original control flow.
            logger.warning(
                "Passing `scale` via `attention_kwargs` when not using the PEFT backend is ineffective."
            )

        batch_size, num_channels, num_frames, height, width = hidden_states.shape
        p = self.config.patch_size

        post_patch_height = height // p
        post_patch_width = width // p

        temb, encoder_hidden_states = self.time_embed(
            timestep,
            encoder_hidden_states,
            encoder_attention_mask,
            hidden_dtype=hidden_states.dtype,
        )

        # Patchify each frame independently, then merge (frames, patches) into one token sequence.
        hidden_states = hidden_states.permute(0, 2, 1, 3, 4).flatten(0, 1)
        hidden_states = self.patch_embed(hidden_states)
        hidden_states = hidden_states.unflatten(0, (batch_size, -1)).flatten(1, 2)

        image_rotary_emb = self.rope(
            self.pos_frequencies,
            num_frames,
            post_patch_height,
            post_patch_width,
            device=hidden_states.device,
            dtype=torch.float32,
        )

        for block in self.transformer_blocks:
            if torch.is_grad_enabled() and self.gradient_checkpointing:
                hidden_states, encoder_hidden_states = self._gradient_checkpointing_func(
                    block,
                    hidden_states,
                    encoder_hidden_states,
                    temb,
                    encoder_attention_mask,
                    image_rotary_emb,
                )
            else:
                hidden_states, encoder_hidden_states = block(
                    hidden_states=hidden_states,
                    encoder_hidden_states=encoder_hidden_states,
                    temb=temb,
                    encoder_attention_mask=encoder_attention_mask,
                    image_rotary_emb=image_rotary_emb,
                )

        hidden_states = self.norm_out(hidden_states, temb)
        hidden_states = self.proj_out(hidden_states)

        # Un-patchify: tokens -> (batch, channels, frames, height, width).
        hidden_states = hidden_states.reshape(batch_size, num_frames, post_patch_height, post_patch_width, p, p, -1)
        hidden_states = hidden_states.permute(0, 6, 1, 2, 4, 3, 5)
        output = hidden_states.reshape(batch_size, -1, num_frames, height, width)

        if USE_PEFT_BACKEND:
            # remove `lora_scale` from each PEFT layer
            unscale_lora_layers(self, lora_scale)

        if not return_dict:
            return (output,)
        return Transformer2DModelOutput(sample=output)
pythonProject/.venv/Lib/site-packages/diffusers/models/transformers/transformer_omnigen.py ADDED
@@ -0,0 +1,469 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright 2025 OmniGen team and The HuggingFace Team. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ import math
16
+ from typing import Dict, List, Optional, Tuple, Union
17
+
18
+ import torch
19
+ import torch.nn as nn
20
+ import torch.nn.functional as F
21
+
22
+ from ...configuration_utils import ConfigMixin, register_to_config
23
+ from ...utils import logging
24
+ from ..attention_processor import Attention
25
+ from ..embeddings import TimestepEmbedding, Timesteps, get_2d_sincos_pos_embed
26
+ from ..modeling_outputs import Transformer2DModelOutput
27
+ from ..modeling_utils import ModelMixin
28
+ from ..normalization import AdaLayerNorm, RMSNorm
29
+
30
+
31
+ logger = logging.get_logger(__name__) # pylint: disable=invalid-name
32
+
33
+
class OmniGenFeedForward(nn.Module):
    """SwiGLU-style feed-forward used by OmniGen.

    A single fused linear produces both the gate and the up projection; the gated
    activation is then projected back down to `hidden_size`. All linears are bias-free.
    """

    def __init__(self, hidden_size: int, intermediate_size: int):
        super().__init__()

        # One matmul yields both halves: [gate | up].
        self.gate_up_proj = nn.Linear(hidden_size, 2 * intermediate_size, bias=False)
        self.down_proj = nn.Linear(intermediate_size, hidden_size, bias=False)
        self.activation_fn = nn.SiLU()

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        gate, up = self.gate_up_proj(hidden_states).chunk(2, dim=-1)
        return self.down_proj(up * self.activation_fn(gate))
47
+
48
+
class OmniGenPatchEmbed(nn.Module):
    """Patchifies OmniGen latents and adds center-cropped 2D sin-cos positional embeddings.

    Two separate conv patchifiers are kept: one for the noisy output latent and one for
    conditioning input images. `forward` accepts either a single latent tensor or a list of
    latents (possibly of different sizes), optionally right-padding each embedded sequence.
    """

    def __init__(
        self,
        patch_size: int = 2,
        in_channels: int = 4,
        embed_dim: int = 768,
        bias: bool = True,
        interpolation_scale: float = 1,
        pos_embed_max_size: int = 192,
        base_size: int = 64,
    ):
        super().__init__()

        self.output_image_proj = nn.Conv2d(
            in_channels, embed_dim, kernel_size=(patch_size, patch_size), stride=patch_size, bias=bias
        )
        self.input_image_proj = nn.Conv2d(
            in_channels, embed_dim, kernel_size=(patch_size, patch_size), stride=patch_size, bias=bias
        )

        self.patch_size = patch_size
        self.interpolation_scale = interpolation_scale
        self.pos_embed_max_size = pos_embed_max_size

        # Precompute the full (max_size x max_size) positional grid once; cropped per call.
        pos_embed = get_2d_sincos_pos_embed(
            embed_dim,
            self.pos_embed_max_size,
            base_size=base_size,
            interpolation_scale=self.interpolation_scale,
            output_type="pt",
        )
        self.register_buffer("pos_embed", pos_embed.float().unsqueeze(0), persistent=True)

    def _cropped_pos_embed(self, height, width):
        """Crops positional embeddings for SD3 compatibility."""
        if self.pos_embed_max_size is None:
            raise ValueError("`pos_embed_max_size` must be set for cropping.")

        height = height // self.patch_size
        width = width // self.patch_size
        if height > self.pos_embed_max_size:
            raise ValueError(
                f"Height ({height}) cannot be greater than `pos_embed_max_size`: {self.pos_embed_max_size}."
            )
        if width > self.pos_embed_max_size:
            raise ValueError(
                f"Width ({width}) cannot be greater than `pos_embed_max_size`: {self.pos_embed_max_size}."
            )

        # Center-crop the stored grid to the requested patch grid, then flatten to a sequence.
        top = (self.pos_embed_max_size - height) // 2
        left = (self.pos_embed_max_size - width) // 2
        grid = self.pos_embed.reshape(1, self.pos_embed_max_size, self.pos_embed_max_size, -1)
        cropped = grid[:, top : top + height, left : left + width, :]
        return cropped.reshape(1, -1, cropped.shape[-1])

    def _patch_embeddings(self, hidden_states: torch.Tensor, is_input_image: bool) -> torch.Tensor:
        proj = self.input_image_proj if is_input_image else self.output_image_proj
        # (B, C, H, W) -> (B, H*W / patch_size^2, embed_dim)
        return proj(hidden_states).flatten(2).transpose(1, 2)

    def forward(
        self, hidden_states: torch.Tensor, is_input_image: bool, padding_latent: torch.Tensor = None
    ) -> torch.Tensor:
        if not isinstance(hidden_states, list):
            height, width = hidden_states.shape[-2:]
            pos_embed = self._cropped_pos_embed(height, width)
            return self._patch_embeddings(hidden_states, is_input_image) + pos_embed

        # List input: embed each latent separately, optionally appending its padding tokens.
        if padding_latent is None:
            padding_latent = [None] * len(hidden_states)
        patched_latents = []
        for sub_latent, padding in zip(hidden_states, padding_latent):
            height, width = sub_latent.shape[-2:]
            embedded = self._patch_embeddings(sub_latent, is_input_image) + self._cropped_pos_embed(height, width)
            if padding is not None:
                embedded = torch.cat([embedded, padding.to(embedded.device)], dim=-2)
            patched_latents.append(embedded)

        return patched_latents
135
+
136
+
137
class OmniGenSuScaledRotaryEmbedding(nn.Module):
    """Su-scaled ("longrope"-style) rotary positional embedding.

    Rescales the per-dimension inverse frequencies with `short_factor` /
    `long_factor` depending on whether the current sequence exceeds
    `original_max_position_embeddings`, allowing context extrapolation.
    """

    def __init__(
        self, dim, max_position_embeddings=131072, original_max_position_embeddings=4096, base=10000, rope_scaling=None
    ):
        super().__init__()

        self.dim = dim
        self.max_position_embeddings = max_position_embeddings
        self.base = base

        # Unscaled inverse frequencies. NOTE(review): this buffer is rebound in
        # forward() with the rescaled values on every call, so the value
        # registered here is only a placeholder.
        inv_freq = 1.0 / (self.base ** (torch.arange(0, self.dim, 2, dtype=torch.int64).float() / self.dim))
        self.register_buffer("inv_freq", tensor=inv_freq, persistent=False)

        # NOTE(review): rope_scaling must be a dict with "short_factor" and
        # "long_factor"; passing None raises TypeError here.
        self.short_factor = rope_scaling["short_factor"]
        self.long_factor = rope_scaling["long_factor"]
        self.original_max_position_embeddings = original_max_position_embeddings

    def forward(self, hidden_states, position_ids):
        """Return (cos, sin) rotary tables for the given position ids.

        `hidden_states` supplies only the device/dtype context; positions come
        from `position_ids` (shape [batch, seq]).
        """
        seq_len = torch.max(position_ids) + 1
        # Long-context factors only once we exceed the original training window.
        if seq_len > self.original_max_position_embeddings:
            ext_factors = torch.tensor(self.long_factor, dtype=torch.float32, device=hidden_states.device)
        else:
            ext_factors = torch.tensor(self.short_factor, dtype=torch.float32, device=hidden_states.device)

        inv_freq_shape = (
            torch.arange(0, self.dim, 2, dtype=torch.int64, device=hidden_states.device).float() / self.dim
        )
        # Rebind the buffer with the per-dimension rescaled frequencies.
        self.inv_freq = 1.0 / (ext_factors * self.base**inv_freq_shape)

        inv_freq_expanded = self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1)
        position_ids_expanded = position_ids[:, None, :].float()

        # Force float32 since bfloat16 loses precision on long contexts
        # See https://github.com/huggingface/transformers/pull/29285
        device_type = hidden_states.device.type
        device_type = device_type if isinstance(device_type, str) and device_type != "mps" else "cpu"
        with torch.autocast(device_type=device_type, enabled=False):
            freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2)
            # Duplicate the half-dim frequencies and drop the batch dim; the
            # tables are shared across the batch ([seq, dim]).
            emb = torch.cat((freqs, freqs), dim=-1)[0]

        # Attention-magnitude correction from the Su-scaled RoPE formulation.
        scale = self.max_position_embeddings / self.original_max_position_embeddings
        if scale <= 1.0:
            scaling_factor = 1.0
        else:
            scaling_factor = math.sqrt(1 + math.log(scale) / math.log(self.original_max_position_embeddings))

        cos = emb.cos() * scaling_factor
        sin = emb.sin() * scaling_factor
        return cos, sin
186
+
187
+
188
class OmniGenAttnProcessor2_0:
    r"""
    Attention processor used by the OmniGen model, built on PyTorch 2.0's
    `F.scaled_dot_product_attention` (enabled by default on PyTorch 2.0+).
    """

    def __init__(self):
        if not hasattr(F, "scaled_dot_product_attention"):
            raise ImportError("AttnProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0.")

    def __call__(
        self,
        attn: Attention,
        hidden_states: torch.Tensor,
        encoder_hidden_states: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        image_rotary_emb: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        batch_size, seq_len, _ = hidden_states.shape

        # Queries come from the sample stream, keys/values from the context.
        query = attn.to_q(hidden_states)
        key = attn.to_k(encoder_hidden_states)
        value = attn.to_v(encoder_hidden_states)

        head_dim = query.shape[-1] // attn.heads
        # Grouped-query attention: the key/value projection may carry fewer heads.
        kv_heads = key.shape[-1] // head_dim

        # (B, S, H*D) -> (B, H, S, D)
        query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
        key = key.view(batch_size, -1, kv_heads, head_dim).transpose(1, 2)
        value = value.view(batch_size, -1, kv_heads, head_dim).transpose(1, 2)

        # Rotate queries/keys when rotary embeddings are supplied.
        if image_rotary_emb is not None:
            from ..embeddings import apply_rotary_emb

            query = apply_rotary_emb(query, image_rotary_emb, use_real_unbind_dim=-2)
            key = apply_rotary_emb(key, image_rotary_emb, use_real_unbind_dim=-2)

        attn_out = F.scaled_dot_product_attention(query, key, value, attn_mask=attention_mask)
        attn_out = attn_out.transpose(1, 2).type_as(query)
        attn_out = attn_out.reshape(batch_size, seq_len, attn.out_dim)
        return attn.to_out[0](attn_out)
236
+
237
+
238
class OmniGenBlock(nn.Module):
    """One OmniGen transformer layer: pre-norm self-attention then a pre-norm
    feed-forward, each wrapped in a residual connection."""

    def __init__(
        self,
        hidden_size: int,
        num_attention_heads: int,
        num_key_value_heads: int,
        intermediate_size: int,
        rms_norm_eps: float,
    ) -> None:
        super().__init__()

        self.input_layernorm = RMSNorm(hidden_size, eps=rms_norm_eps)
        self.self_attn = Attention(
            query_dim=hidden_size,
            cross_attention_dim=hidden_size,
            dim_head=hidden_size // num_attention_heads,
            heads=num_attention_heads,
            kv_heads=num_key_value_heads,
            bias=False,
            out_dim=hidden_size,
            out_bias=False,
            processor=OmniGenAttnProcessor2_0(),
        )
        self.post_attention_layernorm = RMSNorm(hidden_size, eps=rms_norm_eps)
        self.mlp = OmniGenFeedForward(hidden_size, intermediate_size)

    def forward(
        self, hidden_states: torch.Tensor, attention_mask: torch.Tensor, image_rotary_emb: torch.Tensor
    ) -> torch.Tensor:
        # Self-attention sub-layer (pre-norm, residual).
        normed = self.input_layernorm(hidden_states)
        hidden_states = hidden_states + self.self_attn(
            hidden_states=normed,
            encoder_hidden_states=normed,
            attention_mask=attention_mask,
            image_rotary_emb=image_rotary_emb,
        )

        # Feed-forward sub-layer (pre-norm, residual).
        normed = self.post_attention_layernorm(hidden_states)
        return hidden_states + self.mlp(normed)
282
+
283
+
284
class OmniGenTransformer2DModel(ModelMixin, ConfigMixin):
    """
    The Transformer model introduced in OmniGen (https://huggingface.co/papers/2409.11340).

    Parameters:
        in_channels (`int`, defaults to `4`):
            The number of channels in the input.
        patch_size (`int`, defaults to `2`):
            The size of the spatial patches to use in the patch embedding layer.
        hidden_size (`int`, defaults to `3072`):
            The dimensionality of the hidden layers in the model.
        rms_norm_eps (`float`, defaults to `1e-5`):
            Eps for RMSNorm layer.
        num_attention_heads (`int`, defaults to `32`):
            The number of heads to use for multi-head attention.
        num_key_value_heads (`int`, defaults to `32`):
            The number of heads to use for keys and values in multi-head attention.
        intermediate_size (`int`, defaults to `8192`):
            Dimension of the hidden layer in FeedForward layers.
        num_layers (`int`, default to `32`):
            The number of layers of transformer blocks to use.
        pad_token_id (`int`, default to `32000`):
            The id of the padding token.
        vocab_size (`int`, default to `32064`):
            The size of the vocabulary of the embedding vocabulary.
        rope_base (`int`, default to `10000`):
            The default theta value to use when creating RoPE.
        rope_scaling (`Dict`, optional):
            The scaling factors for the RoPE. Must contain `short_factor` and `long_factor`.
        pos_embed_max_size (`int`, default to `192`):
            The maximum size of the positional embeddings.
        time_step_dim (`int`, default to `256`):
            Output dimension of timestep embeddings.
        flip_sin_to_cos (`bool`, default to `True`):
            Whether to flip the sin and cos in the positional embeddings when preparing timestep embeddings.
        downscale_freq_shift (`int`, default to `0`):
            The frequency shift to use when downscaling the timestep embeddings.
        timestep_activation_fn (`str`, default to `silu`):
            The activation function to use for the timestep embeddings.
    """

    _supports_gradient_checkpointing = True
    _no_split_modules = ["OmniGenBlock"]
    _skip_layerwise_casting_patterns = ["patch_embedding", "embed_tokens", "norm"]

    @register_to_config
    def __init__(
        self,
        in_channels: int = 4,
        patch_size: int = 2,
        hidden_size: int = 3072,
        rms_norm_eps: float = 1e-5,
        num_attention_heads: int = 32,
        num_key_value_heads: int = 32,
        intermediate_size: int = 8192,
        num_layers: int = 32,
        pad_token_id: int = 32000,
        vocab_size: int = 32064,
        max_position_embeddings: int = 131072,
        original_max_position_embeddings: int = 4096,
        rope_base: int = 10000,
        rope_scaling: Dict = None,
        pos_embed_max_size: int = 192,
        time_step_dim: int = 256,
        flip_sin_to_cos: bool = True,
        downscale_freq_shift: int = 0,
        timestep_activation_fn: str = "silu",
    ):
        super().__init__()
        self.in_channels = in_channels
        self.out_channels = in_channels

        # Shared patch embedder for both conditioning images and the noisy latent.
        self.patch_embedding = OmniGenPatchEmbed(
            patch_size=patch_size,
            in_channels=in_channels,
            embed_dim=hidden_size,
            pos_embed_max_size=pos_embed_max_size,
        )

        # Two timestep embedders: `time_token` becomes a token prepended to the
        # sequence, `t_embedder` produces the AdaLN conditioning for the output norm.
        self.time_proj = Timesteps(time_step_dim, flip_sin_to_cos, downscale_freq_shift)
        self.time_token = TimestepEmbedding(time_step_dim, hidden_size, timestep_activation_fn)
        self.t_embedder = TimestepEmbedding(time_step_dim, hidden_size, timestep_activation_fn)

        self.embed_tokens = nn.Embedding(vocab_size, hidden_size, pad_token_id)
        self.rope = OmniGenSuScaledRotaryEmbedding(
            hidden_size // num_attention_heads,
            max_position_embeddings=max_position_embeddings,
            original_max_position_embeddings=original_max_position_embeddings,
            base=rope_base,
            rope_scaling=rope_scaling,
        )

        self.layers = nn.ModuleList(
            [
                OmniGenBlock(hidden_size, num_attention_heads, num_key_value_heads, intermediate_size, rms_norm_eps)
                for _ in range(num_layers)
            ]
        )

        self.norm = RMSNorm(hidden_size, eps=rms_norm_eps)
        self.norm_out = AdaLayerNorm(hidden_size, norm_elementwise_affine=False, norm_eps=1e-6, chunk_dim=1)
        self.proj_out = nn.Linear(hidden_size, patch_size * patch_size * self.out_channels, bias=True)

        self.gradient_checkpointing = False

    def _get_multimodal_embeddings(
        self, input_ids: torch.Tensor, input_img_latents: List[torch.Tensor], input_image_sizes: Dict
    ) -> Optional[torch.Tensor]:
        """Embed the prompt token ids and splice the input-image patch embeddings
        into the placeholder spans recorded in `input_image_sizes`.

        Returns `None` when there is no prompt (`input_ids is None`).
        """
        if input_ids is None:
            return None

        input_img_latents = [x.to(self.dtype) for x in input_img_latents]
        condition_tokens = self.embed_tokens(input_ids)
        input_img_inx = 0
        input_image_tokens = self.patch_embedding(input_img_latents, is_input_image=True)
        # input_image_sizes maps batch index -> list of (start, end) token spans.
        for b_inx in input_image_sizes.keys():
            for start_inx, end_inx in input_image_sizes[b_inx]:
                # replace the placeholder in text tokens with the image embedding.
                condition_tokens[b_inx, start_inx:end_inx] = input_image_tokens[input_img_inx].to(
                    condition_tokens.dtype
                )
                input_img_inx += 1
        return condition_tokens

    def forward(
        self,
        hidden_states: torch.Tensor,
        timestep: Union[int, float, torch.FloatTensor],
        input_ids: torch.Tensor,
        input_img_latents: List[torch.Tensor],
        input_image_sizes: Dict[int, List[int]],
        attention_mask: torch.Tensor,
        position_ids: torch.Tensor,
        return_dict: bool = True,
    ) -> Union[Transformer2DModelOutput, Tuple[torch.Tensor]]:
        """Denoise the latent `hidden_states` (B, C, H, W) conditioned on the
        multimodal prompt and the diffusion `timestep`."""
        batch_size, num_channels, height, width = hidden_states.shape
        p = self.config.patch_size
        post_patch_height, post_patch_width = height // p, width // p

        # 1. Patch & Timestep & Conditional Embedding
        hidden_states = self.patch_embedding(hidden_states, is_input_image=False)
        num_tokens_for_output_image = hidden_states.size(1)

        timestep_proj = self.time_proj(timestep).type_as(hidden_states)
        time_token = self.time_token(timestep_proj).unsqueeze(1)
        temb = self.t_embedder(timestep_proj)

        # Sequence layout: [condition tokens | time token | output-image tokens].
        condition_tokens = self._get_multimodal_embeddings(input_ids, input_img_latents, input_image_sizes)
        if condition_tokens is not None:
            hidden_states = torch.cat([condition_tokens, time_token, hidden_states], dim=1)
        else:
            hidden_states = torch.cat([time_token, hidden_states], dim=1)

        seq_length = hidden_states.size(1)
        position_ids = position_ids.view(-1, seq_length).long()

        # 2. Attention mask preprocessing: turn a 0/1 keep-mask into an additive
        # bias (0 keeps, dtype-min drops) with a broadcastable head dimension.
        if attention_mask is not None and attention_mask.dim() == 3:
            dtype = hidden_states.dtype
            min_dtype = torch.finfo(dtype).min
            attention_mask = (1 - attention_mask) * min_dtype
            attention_mask = attention_mask.unsqueeze(1).type_as(hidden_states)

        # 3. Rotary position embedding
        image_rotary_emb = self.rope(hidden_states, position_ids)

        # 4. Transformer blocks
        for block in self.layers:
            if torch.is_grad_enabled() and self.gradient_checkpointing:
                hidden_states = self._gradient_checkpointing_func(
                    block, hidden_states, attention_mask, image_rotary_emb
                )
            else:
                hidden_states = block(hidden_states, attention_mask=attention_mask, image_rotary_emb=image_rotary_emb)

        # 5. Output norm & projection: keep only the trailing output-image
        # tokens, apply AdaLN + linear head, then unpatchify back to (B, C, H, W).
        hidden_states = self.norm(hidden_states)
        hidden_states = hidden_states[:, -num_tokens_for_output_image:]
        hidden_states = self.norm_out(hidden_states, temb=temb)
        hidden_states = self.proj_out(hidden_states)
        hidden_states = hidden_states.reshape(batch_size, post_patch_height, post_patch_width, p, p, -1)
        output = hidden_states.permute(0, 5, 1, 3, 2, 4).flatten(4, 5).flatten(2, 3)

        if not return_dict:
            return (output,)
        return Transformer2DModelOutput(sample=output)
pythonProject/.venv/Lib/site-packages/diffusers/models/transformers/transformer_qwenimage.py ADDED
@@ -0,0 +1,655 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright 2025 Qwen-Image Team, The HuggingFace Team. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ import functools
16
+ import math
17
+ from typing import Any, Dict, List, Optional, Tuple, Union
18
+
19
+ import numpy as np
20
+ import torch
21
+ import torch.nn as nn
22
+ import torch.nn.functional as F
23
+
24
+ from ...configuration_utils import ConfigMixin, register_to_config
25
+ from ...loaders import FromOriginalModelMixin, PeftAdapterMixin
26
+ from ...utils import USE_PEFT_BACKEND, logging, scale_lora_layers, unscale_lora_layers
27
+ from ...utils.torch_utils import maybe_allow_in_graph
28
+ from ..attention import AttentionMixin, FeedForward
29
+ from ..attention_dispatch import dispatch_attention_fn
30
+ from ..attention_processor import Attention
31
+ from ..cache_utils import CacheMixin
32
+ from ..embeddings import TimestepEmbedding, Timesteps
33
+ from ..modeling_outputs import Transformer2DModelOutput
34
+ from ..modeling_utils import ModelMixin
35
+ from ..normalization import AdaLayerNormContinuous, RMSNorm
36
+
37
+
38
+ logger = logging.get_logger(__name__) # pylint: disable=invalid-name
39
+
40
+
41
def get_timestep_embedding(
    timesteps: torch.Tensor,
    embedding_dim: int,
    flip_sin_to_cos: bool = False,
    downscale_freq_shift: float = 1,
    scale: float = 1,
    max_period: int = 10000,
) -> torch.Tensor:
    """
    Create sinusoidal timestep embeddings, matching the implementation in
    Denoising Diffusion Probabilistic Models.

    Args:
        timesteps (torch.Tensor): a 1-D tensor of N indices, one per batch element; may be fractional.
        embedding_dim (int): the dimension of the output.
        flip_sin_to_cos (bool): whether the embedding order is `cos, sin` (True) or `sin, cos` (False).
        downscale_freq_shift (float): controls the delta between frequencies between dimensions.
        scale (float): scaling factor applied to the embeddings.
        max_period (int): controls the maximum frequency of the embeddings.

    Returns:
        torch.Tensor: an [N x embedding_dim] tensor of positional embeddings.
    """
    assert len(timesteps.shape) == 1, "Timesteps should be a 1d-array"

    half_dim = embedding_dim // 2
    # Geometric progression of frequencies from 1 down toward 1/max_period.
    freq_exponent = torch.arange(start=0, end=half_dim, dtype=torch.float32, device=timesteps.device)
    freq_exponent = -math.log(max_period) * freq_exponent
    freq_exponent = freq_exponent / (half_dim - downscale_freq_shift)

    frequencies = torch.exp(freq_exponent).to(timesteps.dtype)
    angles = timesteps[:, None].float() * frequencies[None, :]
    angles = scale * angles

    # Concatenate the sine and cosine halves.
    emb = torch.cat([torch.sin(angles), torch.cos(angles)], dim=-1)
    if flip_sin_to_cos:
        # Reorder to (cos, sin).
        emb = torch.cat([emb[:, half_dim:], emb[:, :half_dim]], dim=-1)

    # Odd target dimensions get one zero column appended.
    if embedding_dim % 2 == 1:
        emb = torch.nn.functional.pad(emb, (0, 1, 0, 0))
    return emb
93
+
94
+
95
def apply_rotary_emb_qwen(
    x: torch.Tensor,
    freqs_cis: Union[torch.Tensor, Tuple[torch.Tensor]],
    use_real: bool = True,
    use_real_unbind_dim: int = -1,
) -> Tuple[torch.Tensor, torch.Tensor]:
    """
    Apply rotary positional embeddings to a query/key tensor.

    When `use_real` is True, `freqs_cis` is a `(cos, sin)` pair of real tensors of
    shape [S, D]; otherwise it is a single complex tensor and the rotation is
    performed via complex multiplication.

    Args:
        x (`torch.Tensor`):
            Query or key tensor of shape [B, S, H, D] to rotate.
        freqs_cis (`Tuple[torch.Tensor]`):
            Precomputed frequencies, see above.
        use_real (`bool`):
            Whether `freqs_cis` is a real `(cos, sin)` pair.
        use_real_unbind_dim (`int`):
            -1 for interleaved (real, imag) channel pairs; -2 for a
            (real-half, imag-half) channel split; anything else raises.

    Returns:
        The rotated tensor, same shape and dtype as `x`.
    """
    if not use_real:
        # Complex path: pair up the channel dim, rotate by complex multiply.
        x_complex = torch.view_as_complex(x.float().reshape(*x.shape[:-1], -1, 2))
        rotated = torch.view_as_real(x_complex * freqs_cis.unsqueeze(1)).flatten(3)
        return rotated.type_as(x)

    cos, sin = freqs_cis  # [S, D]
    cos = cos[None, None].to(x.device)
    sin = sin[None, None].to(x.device)

    if use_real_unbind_dim == -1:
        # Used for flux, cogvideox, hunyuan-dit: interleaved (real, imag) pairs.
        real_part, imag_part = x.reshape(*x.shape[:-1], -1, 2).unbind(-1)  # [B, S, H, D//2]
        x_rotated = torch.stack([-imag_part, real_part], dim=-1).flatten(3)
    elif use_real_unbind_dim == -2:
        # Used for Stable Audio, OmniGen, CogView4 and Cosmos: split halves.
        real_part, imag_part = x.reshape(*x.shape[:-1], 2, -1).unbind(-2)  # [B, S, H, D//2]
        x_rotated = torch.cat([-imag_part, real_part], dim=-1)
    else:
        raise ValueError(f"`use_real_unbind_dim={use_real_unbind_dim}` but should be -1 or -2.")

    return (x.float() * cos + x_rotated.float() * sin).to(x.dtype)
141
+
142
+
143
class QwenTimestepProjEmbeddings(nn.Module):
    """Sinusoidal timestep projection followed by an MLP embedder."""

    def __init__(self, embedding_dim):
        super().__init__()

        self.time_proj = Timesteps(num_channels=256, flip_sin_to_cos=True, downscale_freq_shift=0, scale=1000)
        self.timestep_embedder = TimestepEmbedding(in_channels=256, time_embed_dim=embedding_dim)

    def forward(self, timestep, hidden_states):
        """Return the conditioning vector (N, D); `hidden_states` only supplies the dtype."""
        sinusoidal = self.time_proj(timestep)
        # Cast to the model activation dtype before the MLP.
        conditioning = self.timestep_embedder(sinusoidal.to(dtype=hidden_states.dtype))
        return conditioning
157
+
158
+
159
+ class QwenEmbedRope(nn.Module):
160
+ def __init__(self, theta: int, axes_dim: List[int], scale_rope=False):
161
+ super().__init__()
162
+ self.theta = theta
163
+ self.axes_dim = axes_dim
164
+ pos_index = torch.arange(4096)
165
+ neg_index = torch.arange(4096).flip(0) * -1 - 1
166
+ self.pos_freqs = torch.cat(
167
+ [
168
+ self.rope_params(pos_index, self.axes_dim[0], self.theta),
169
+ self.rope_params(pos_index, self.axes_dim[1], self.theta),
170
+ self.rope_params(pos_index, self.axes_dim[2], self.theta),
171
+ ],
172
+ dim=1,
173
+ )
174
+ self.neg_freqs = torch.cat(
175
+ [
176
+ self.rope_params(neg_index, self.axes_dim[0], self.theta),
177
+ self.rope_params(neg_index, self.axes_dim[1], self.theta),
178
+ self.rope_params(neg_index, self.axes_dim[2], self.theta),
179
+ ],
180
+ dim=1,
181
+ )
182
+ self.rope_cache = {}
183
+
184
+ # DO NOT USING REGISTER BUFFER HERE, IT WILL CAUSE COMPLEX NUMBERS LOSE ITS IMAGINARY PART
185
+ self.scale_rope = scale_rope
186
+
187
+ def rope_params(self, index, dim, theta=10000):
188
+ """
189
+ Args:
190
+ index: [0, 1, 2, 3] 1D Tensor representing the position index of the token
191
+ """
192
+ assert dim % 2 == 0
193
+ freqs = torch.outer(index, 1.0 / torch.pow(theta, torch.arange(0, dim, 2).to(torch.float32).div(dim)))
194
+ freqs = torch.polar(torch.ones_like(freqs), freqs)
195
+ return freqs
196
+
197
+ def forward(self, video_fhw, txt_seq_lens, device):
198
+ """
199
+ Args: video_fhw: [frame, height, width] a list of 3 integers representing the shape of the video Args:
200
+ txt_length: [bs] a list of 1 integers representing the length of the text
201
+ """
202
+ if self.pos_freqs.device != device:
203
+ self.pos_freqs = self.pos_freqs.to(device)
204
+ self.neg_freqs = self.neg_freqs.to(device)
205
+
206
+ if isinstance(video_fhw, list):
207
+ video_fhw = video_fhw[0]
208
+ if not isinstance(video_fhw, list):
209
+ video_fhw = [video_fhw]
210
+
211
+ vid_freqs = []
212
+ max_vid_index = 0
213
+ for idx, fhw in enumerate(video_fhw):
214
+ frame, height, width = fhw
215
+ rope_key = f"{idx}_{height}_{width}"
216
+
217
+ if not torch.compiler.is_compiling():
218
+ if rope_key not in self.rope_cache:
219
+ self.rope_cache[rope_key] = self._compute_video_freqs(frame, height, width, idx)
220
+ video_freq = self.rope_cache[rope_key]
221
+ else:
222
+ video_freq = self._compute_video_freqs(frame, height, width, idx)
223
+ video_freq = video_freq.to(device)
224
+ vid_freqs.append(video_freq)
225
+
226
+ if self.scale_rope:
227
+ max_vid_index = max(height // 2, width // 2, max_vid_index)
228
+ else:
229
+ max_vid_index = max(height, width, max_vid_index)
230
+
231
+ max_len = max(txt_seq_lens)
232
+ txt_freqs = self.pos_freqs[max_vid_index : max_vid_index + max_len, ...]
233
+ vid_freqs = torch.cat(vid_freqs, dim=0)
234
+
235
+ return vid_freqs, txt_freqs
236
+
237
+ @functools.lru_cache(maxsize=None)
238
+ def _compute_video_freqs(self, frame, height, width, idx=0):
239
+ seq_lens = frame * height * width
240
+ freqs_pos = self.pos_freqs.split([x // 2 for x in self.axes_dim], dim=1)
241
+ freqs_neg = self.neg_freqs.split([x // 2 for x in self.axes_dim], dim=1)
242
+
243
+ freqs_frame = freqs_pos[0][idx : idx + frame].view(frame, 1, 1, -1).expand(frame, height, width, -1)
244
+ if self.scale_rope:
245
+ freqs_height = torch.cat([freqs_neg[1][-(height - height // 2) :], freqs_pos[1][: height // 2]], dim=0)
246
+ freqs_height = freqs_height.view(1, height, 1, -1).expand(frame, height, width, -1)
247
+ freqs_width = torch.cat([freqs_neg[2][-(width - width // 2) :], freqs_pos[2][: width // 2]], dim=0)
248
+ freqs_width = freqs_width.view(1, 1, width, -1).expand(frame, height, width, -1)
249
+ else:
250
+ freqs_height = freqs_pos[1][:height].view(1, height, 1, -1).expand(frame, height, width, -1)
251
+ freqs_width = freqs_pos[2][:width].view(1, 1, width, -1).expand(frame, height, width, -1)
252
+
253
+ freqs = torch.cat([freqs_frame, freqs_height, freqs_width], dim=-1).reshape(seq_lens, -1)
254
+ return freqs.clone().contiguous()
255
+
256
+
257
class QwenDoubleStreamAttnProcessor2_0:
    """
    Attention processor for Qwen double-stream architecture, matching DoubleStreamLayerMegatron logic. This processor
    implements joint attention computation where text and image streams are processed together.
    """

    # Backend hint consumed by `dispatch_attention_fn`; None selects the default.
    _attention_backend = None

    def __init__(self):
        if not hasattr(F, "scaled_dot_product_attention"):
            raise ImportError(
                "QwenDoubleStreamAttnProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0."
            )

    def __call__(
        self,
        attn: Attention,
        hidden_states: torch.FloatTensor,  # Image stream
        encoder_hidden_states: torch.FloatTensor = None,  # Text stream
        encoder_hidden_states_mask: torch.FloatTensor = None,
        attention_mask: Optional[torch.FloatTensor] = None,
        image_rotary_emb: Optional[torch.Tensor] = None,
    ) -> torch.FloatTensor:
        """Run joint text+image attention; returns (img_attn_output, txt_attn_output)."""
        # NOTE(review): `encoder_hidden_states_mask` is accepted but never used in
        # this body — any masking must arrive via `attention_mask`. Confirm intent.
        if encoder_hidden_states is None:
            raise ValueError("QwenDoubleStreamAttnProcessor2_0 requires encoder_hidden_states (text stream)")

        seq_txt = encoder_hidden_states.shape[1]

        # Compute QKV for image stream (sample projections)
        img_query = attn.to_q(hidden_states)
        img_key = attn.to_k(hidden_states)
        img_value = attn.to_v(hidden_states)

        # Compute QKV for text stream (context projections)
        txt_query = attn.add_q_proj(encoder_hidden_states)
        txt_key = attn.add_k_proj(encoder_hidden_states)
        txt_value = attn.add_v_proj(encoder_hidden_states)

        # Reshape for multi-head attention: (B, S, H*D) -> (B, S, H, D)
        img_query = img_query.unflatten(-1, (attn.heads, -1))
        img_key = img_key.unflatten(-1, (attn.heads, -1))
        img_value = img_value.unflatten(-1, (attn.heads, -1))

        txt_query = txt_query.unflatten(-1, (attn.heads, -1))
        txt_key = txt_key.unflatten(-1, (attn.heads, -1))
        txt_value = txt_value.unflatten(-1, (attn.heads, -1))

        # Apply QK normalization
        if attn.norm_q is not None:
            img_query = attn.norm_q(img_query)
        if attn.norm_k is not None:
            img_key = attn.norm_k(img_key)
        if attn.norm_added_q is not None:
            txt_query = attn.norm_added_q(txt_query)
        if attn.norm_added_k is not None:
            txt_key = attn.norm_added_k(txt_key)

        # Apply RoPE (complex path) separately per stream with its own frequencies.
        if image_rotary_emb is not None:
            img_freqs, txt_freqs = image_rotary_emb
            img_query = apply_rotary_emb_qwen(img_query, img_freqs, use_real=False)
            img_key = apply_rotary_emb_qwen(img_key, img_freqs, use_real=False)
            txt_query = apply_rotary_emb_qwen(txt_query, txt_freqs, use_real=False)
            txt_key = apply_rotary_emb_qwen(txt_key, txt_freqs, use_real=False)

        # Concatenate for joint attention
        # Order: [text, image]
        joint_query = torch.cat([txt_query, img_query], dim=1)
        joint_key = torch.cat([txt_key, img_key], dim=1)
        joint_value = torch.cat([txt_value, img_value], dim=1)

        # Compute joint attention
        joint_hidden_states = dispatch_attention_fn(
            joint_query,
            joint_key,
            joint_value,
            attn_mask=attention_mask,
            dropout_p=0.0,
            is_causal=False,
            backend=self._attention_backend,
        )

        # Reshape back: (B, S, H, D) -> (B, S, H*D)
        joint_hidden_states = joint_hidden_states.flatten(2, 3)
        joint_hidden_states = joint_hidden_states.to(joint_query.dtype)

        # Split attention outputs back
        txt_attn_output = joint_hidden_states[:, :seq_txt, :]  # Text part
        img_attn_output = joint_hidden_states[:, seq_txt:, :]  # Image part

        # Apply output projections
        img_attn_output = attn.to_out[0](img_attn_output)
        if len(attn.to_out) > 1:
            img_attn_output = attn.to_out[1](img_attn_output)  # dropout

        txt_attn_output = attn.to_add_out(txt_attn_output)

        return img_attn_output, txt_attn_output
355
+
356
+
357
+ @maybe_allow_in_graph
358
+ class QwenImageTransformerBlock(nn.Module):
359
    def __init__(
        self, dim: int, num_attention_heads: int, attention_head_dim: int, qk_norm: str = "rms_norm", eps: float = 1e-6
    ):
        """Double-stream transformer block: parallel image and text branches that
        share a single joint-attention module.

        Args:
            dim: hidden size of both streams.
            num_attention_heads: number of attention heads.
            attention_head_dim: per-head dimension.
            qk_norm: query/key normalization variant passed to `Attention`.
            eps: epsilon for all LayerNorm/QK-norm layers.
        """
        super().__init__()

        self.dim = dim
        self.num_attention_heads = num_attention_heads
        self.attention_head_dim = attention_head_dim

        # Image processing modules
        self.img_mod = nn.Sequential(
            nn.SiLU(),
            nn.Linear(dim, 6 * dim, bias=True),  # For scale, shift, gate for norm1 and norm2
        )
        self.img_norm1 = nn.LayerNorm(dim, elementwise_affine=False, eps=eps)
        self.attn = Attention(
            query_dim=dim,
            cross_attention_dim=None,  # Enable cross attention for joint computation
            added_kv_proj_dim=dim,  # Enable added KV projections for text stream
            dim_head=attention_head_dim,
            heads=num_attention_heads,
            out_dim=dim,
            context_pre_only=False,
            bias=True,
            processor=QwenDoubleStreamAttnProcessor2_0(),
            qk_norm=qk_norm,
            eps=eps,
        )
        self.img_norm2 = nn.LayerNorm(dim, elementwise_affine=False, eps=eps)
        self.img_mlp = FeedForward(dim=dim, dim_out=dim, activation_fn="gelu-approximate")

        # Text processing modules
        self.txt_mod = nn.Sequential(
            nn.SiLU(),
            nn.Linear(dim, 6 * dim, bias=True),  # For scale, shift, gate for norm1 and norm2
        )
        self.txt_norm1 = nn.LayerNorm(dim, elementwise_affine=False, eps=eps)
        # Text doesn't need separate attention - it's handled by img_attn joint computation
        self.txt_norm2 = nn.LayerNorm(dim, elementwise_affine=False, eps=eps)
        self.txt_mlp = FeedForward(dim=dim, dim_out=dim, activation_fn="gelu-approximate")
399
+
400
+ def _modulate(self, x, mod_params):
401
+ """Apply modulation to input tensor"""
402
+ shift, scale, gate = mod_params.chunk(3, dim=-1)
403
+ return x * (1 + scale.unsqueeze(1)) + shift.unsqueeze(1), gate.unsqueeze(1)
404
+
405
    def forward(
        self,
        hidden_states: torch.Tensor,
        encoder_hidden_states: torch.Tensor,
        encoder_hidden_states_mask: torch.Tensor,
        temb: torch.Tensor,
        image_rotary_emb: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
        joint_attention_kwargs: Optional[Dict[str, Any]] = None,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Run one dual-stream block over the image and text streams.

        Args:
            hidden_states: Image-stream tokens, shape ``[B, image_seq_len, dim]``.
            encoder_hidden_states: Text-stream tokens, shape ``[B, text_seq_len, dim]``.
            encoder_hidden_states_mask: Mask for the text tokens; forwarded to the
                attention processor.
            temb: Conditioning embedding from which the AdaLN modulation parameters
                for both streams are produced.
            image_rotary_emb: Optional rotary position embedding applied inside the
                attention processor.
            joint_attention_kwargs: Extra kwargs forwarded to the attention processor.

        Returns:
            ``(encoder_hidden_states, hidden_states)`` — updated text stream first,
            updated image stream second.
        """
        # Project temb into 6*dim modulation parameters per stream:
        # (shift, scale, gate) for norm1 and again for norm2.
        img_mod_params = self.img_mod(temb)  # [B, 6*dim]
        txt_mod_params = self.txt_mod(temb)  # [B, 6*dim]

        # Split modulation parameters for norm1 and norm2
        img_mod1, img_mod2 = img_mod_params.chunk(2, dim=-1)  # Each [B, 3*dim]
        txt_mod1, txt_mod2 = txt_mod_params.chunk(2, dim=-1)  # Each [B, 3*dim]

        # Process image stream - norm1 + modulation (gate is kept for the residual add)
        img_normed = self.img_norm1(hidden_states)
        img_modulated, img_gate1 = self._modulate(img_normed, img_mod1)

        # Process text stream - norm1 + modulation
        txt_normed = self.txt_norm1(encoder_hidden_states)
        txt_modulated, txt_gate1 = self._modulate(txt_normed, txt_mod1)

        # Use QwenDoubleStreamAttnProcessor2_0 (set in __init__) for joint attention.
        # This directly implements the DoubleStreamLayerMegatron logic:
        # 1. Computes QKV for both streams
        # 2. Applies QK normalization and RoPE
        # 3. Concatenates and runs joint attention
        # 4. Splits results back to separate streams
        joint_attention_kwargs = joint_attention_kwargs or {}
        attn_output = self.attn(
            hidden_states=img_modulated,  # Image stream (will be processed as "sample")
            encoder_hidden_states=txt_modulated,  # Text stream (will be processed as "context")
            encoder_hidden_states_mask=encoder_hidden_states_mask,
            image_rotary_emb=image_rotary_emb,
            **joint_attention_kwargs,
        )

        # The processor returns (img_output, txt_output) when encoder_hidden_states is provided
        img_attn_output, txt_attn_output = attn_output

        # Apply attention gates and add residual (like in Megatron)
        hidden_states = hidden_states + img_gate1 * img_attn_output
        encoder_hidden_states = encoder_hidden_states + txt_gate1 * txt_attn_output

        # Process image stream - norm2 + gated MLP residual
        img_normed2 = self.img_norm2(hidden_states)
        img_modulated2, img_gate2 = self._modulate(img_normed2, img_mod2)
        img_mlp_output = self.img_mlp(img_modulated2)
        hidden_states = hidden_states + img_gate2 * img_mlp_output

        # Process text stream - norm2 + gated MLP residual
        txt_normed2 = self.txt_norm2(encoder_hidden_states)
        txt_modulated2, txt_gate2 = self._modulate(txt_normed2, txt_mod2)
        txt_mlp_output = self.txt_mlp(txt_modulated2)
        encoder_hidden_states = encoder_hidden_states + txt_gate2 * txt_mlp_output

        # Clip to the fp16 representable range to prevent overflow
        if encoder_hidden_states.dtype == torch.float16:
            encoder_hidden_states = encoder_hidden_states.clip(-65504, 65504)
        if hidden_states.dtype == torch.float16:
            hidden_states = hidden_states.clip(-65504, 65504)

        return encoder_hidden_states, hidden_states
471
+
472
+
473
class QwenImageTransformer2DModel(
    ModelMixin, ConfigMixin, PeftAdapterMixin, FromOriginalModelMixin, CacheMixin, AttentionMixin
):
    """
    The Transformer model introduced in Qwen.

    Args:
        patch_size (`int`, defaults to `2`):
            Patch size to turn the input data into small patches.
        in_channels (`int`, defaults to `64`):
            The number of channels in the input.
        out_channels (`int`, *optional*, defaults to `16`):
            The number of channels in the output. If falsy, it defaults to `in_channels`.
        num_layers (`int`, defaults to `60`):
            The number of layers of dual stream DiT blocks to use.
        attention_head_dim (`int`, defaults to `128`):
            The number of dimensions to use for each attention head.
        num_attention_heads (`int`, defaults to `24`):
            The number of attention heads to use.
        joint_attention_dim (`int`, defaults to `3584`):
            The number of dimensions to use for the joint attention (embedding/channel dimension of
            `encoder_hidden_states`).
        guidance_embeds (`bool`, defaults to `False`):
            Whether to use guidance embeddings for guidance-distilled variant of the model.
        axes_dims_rope (`Tuple[int]`, defaults to `(16, 56, 56)`):
            The dimensions to use for the rotary positional embeddings.
    """

    _supports_gradient_checkpointing = True
    _no_split_modules = ["QwenImageTransformerBlock"]
    _skip_layerwise_casting_patterns = ["pos_embed", "norm"]
    _repeated_blocks = ["QwenImageTransformerBlock"]

    @register_to_config
    def __init__(
        self,
        patch_size: int = 2,
        in_channels: int = 64,
        out_channels: Optional[int] = 16,
        num_layers: int = 60,
        attention_head_dim: int = 128,
        num_attention_heads: int = 24,
        joint_attention_dim: int = 3584,
        guidance_embeds: bool = False,  # TODO: this should probably be removed
        axes_dims_rope: Tuple[int, int, int] = (16, 56, 56),
    ):
        super().__init__()
        # Falsy out_channels (None/0) falls back to in_channels.
        self.out_channels = out_channels or in_channels
        self.inner_dim = num_attention_heads * attention_head_dim

        self.pos_embed = QwenEmbedRope(theta=10000, axes_dim=list(axes_dims_rope), scale_rope=True)

        self.time_text_embed = QwenTimestepProjEmbeddings(embedding_dim=self.inner_dim)

        # Text embeddings are RMS-normalized before projection into the joint dim.
        self.txt_norm = RMSNorm(joint_attention_dim, eps=1e-6)

        self.img_in = nn.Linear(in_channels, self.inner_dim)
        self.txt_in = nn.Linear(joint_attention_dim, self.inner_dim)

        self.transformer_blocks = nn.ModuleList(
            [
                QwenImageTransformerBlock(
                    dim=self.inner_dim,
                    num_attention_heads=num_attention_heads,
                    attention_head_dim=attention_head_dim,
                )
                for _ in range(num_layers)
            ]
        )

        self.norm_out = AdaLayerNormContinuous(self.inner_dim, self.inner_dim, elementwise_affine=False, eps=1e-6)
        self.proj_out = nn.Linear(self.inner_dim, patch_size * patch_size * self.out_channels, bias=True)

        self.gradient_checkpointing = False

    def forward(
        self,
        hidden_states: torch.Tensor,
        encoder_hidden_states: torch.Tensor = None,
        encoder_hidden_states_mask: torch.Tensor = None,
        timestep: torch.LongTensor = None,
        img_shapes: Optional[List[Tuple[int, int, int]]] = None,
        txt_seq_lens: Optional[List[int]] = None,
        guidance: torch.Tensor = None,  # TODO: this should probably be removed
        attention_kwargs: Optional[Dict[str, Any]] = None,
        controlnet_block_samples=None,
        return_dict: bool = True,
    ) -> Union[torch.Tensor, Transformer2DModelOutput]:
        """
        The [`QwenTransformer2DModel`] forward method.

        Args:
            hidden_states (`torch.Tensor` of shape `(batch_size, image_sequence_length, in_channels)`):
                Input `hidden_states`.
            encoder_hidden_states (`torch.Tensor` of shape `(batch_size, text_sequence_length, joint_attention_dim)`):
                Conditional embeddings (embeddings computed from the input conditions such as prompts) to use.
            encoder_hidden_states_mask (`torch.Tensor` of shape `(batch_size, text_sequence_length)`):
                Mask of the input conditions.
            timestep ( `torch.LongTensor`):
                Used to indicate denoising step.
            img_shapes (`list`, *optional*):
                Per-sample `(frames, height, width)` latent shapes consumed by the rotary position embedding.
            txt_seq_lens (`list` of `int`, *optional*):
                Text sequence lengths consumed by the rotary position embedding.
            controlnet_block_samples (*optional*):
                Residual tensors from a ControlNet, added to the image stream after each block.
            attention_kwargs (`dict`, *optional*):
                A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
                `self.processor` in
                [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
            return_dict (`bool`, *optional*, defaults to `True`):
                Whether or not to return a [`~models.transformer_2d.Transformer2DModelOutput`] instead of a plain
                tuple.

        Returns:
            If `return_dict` is True, an [`~models.transformer_2d.Transformer2DModelOutput`] is returned, otherwise a
            `tuple` where the first element is the sample tensor.
        """
        # Pop the LoRA scale before the kwargs are forwarded to the attention processors.
        if attention_kwargs is not None:
            attention_kwargs = attention_kwargs.copy()
            lora_scale = attention_kwargs.pop("scale", 1.0)
        else:
            lora_scale = 1.0

        if USE_PEFT_BACKEND:
            # weight the lora layers by setting `lora_scale` for each PEFT layer
            scale_lora_layers(self, lora_scale)
        else:
            if attention_kwargs is not None and attention_kwargs.get("scale", None) is not None:
                logger.warning(
                    "Passing `scale` via `joint_attention_kwargs` when not using the PEFT backend is ineffective."
                )

        hidden_states = self.img_in(hidden_states)

        timestep = timestep.to(hidden_states.dtype)
        encoder_hidden_states = self.txt_norm(encoder_hidden_states)
        encoder_hidden_states = self.txt_in(encoder_hidden_states)

        if guidance is not None:
            # Guidance is scaled into the same range as timesteps before embedding.
            guidance = guidance.to(hidden_states.dtype) * 1000

        temb = (
            self.time_text_embed(timestep, hidden_states)
            if guidance is None
            else self.time_text_embed(timestep, guidance, hidden_states)
        )

        image_rotary_emb = self.pos_embed(img_shapes, txt_seq_lens, device=hidden_states.device)

        for index_block, block in enumerate(self.transformer_blocks):
            if torch.is_grad_enabled() and self.gradient_checkpointing:
                encoder_hidden_states, hidden_states = self._gradient_checkpointing_func(
                    block,
                    hidden_states,
                    encoder_hidden_states,
                    encoder_hidden_states_mask,
                    temb,
                    image_rotary_emb,
                )

            else:
                encoder_hidden_states, hidden_states = block(
                    hidden_states=hidden_states,
                    encoder_hidden_states=encoder_hidden_states,
                    encoder_hidden_states_mask=encoder_hidden_states_mask,
                    temb=temb,
                    image_rotary_emb=image_rotary_emb,
                    joint_attention_kwargs=attention_kwargs,
                )

            # controlnet residual: when fewer controlnet samples than blocks are
            # given, each sample is reused for a contiguous run of blocks.
            if controlnet_block_samples is not None:
                interval_control = len(self.transformer_blocks) / len(controlnet_block_samples)
                interval_control = int(np.ceil(interval_control))
                hidden_states = hidden_states + controlnet_block_samples[index_block // interval_control]

        # Use only the image part (hidden_states) from the dual-stream blocks
        hidden_states = self.norm_out(hidden_states, temb)
        output = self.proj_out(hidden_states)

        if USE_PEFT_BACKEND:
            # remove `lora_scale` from each PEFT layer
            unscale_lora_layers(self, lora_scale)

        if not return_dict:
            return (output,)

        return Transformer2DModelOutput(sample=output)
pythonProject/.venv/Lib/site-packages/diffusers/models/transformers/transformer_sd3.py ADDED
@@ -0,0 +1,431 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright 2025 Stability AI, The HuggingFace Team and The InstantX Team. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ from typing import Any, Dict, List, Optional, Tuple, Union
15
+
16
+ import torch
17
+ import torch.nn as nn
18
+
19
+ from ...configuration_utils import ConfigMixin, register_to_config
20
+ from ...loaders import FromOriginalModelMixin, PeftAdapterMixin, SD3Transformer2DLoadersMixin
21
+ from ...utils import USE_PEFT_BACKEND, logging, scale_lora_layers, unscale_lora_layers
22
+ from ...utils.torch_utils import maybe_allow_in_graph
23
+ from ..attention import FeedForward, JointTransformerBlock
24
+ from ..attention_processor import (
25
+ Attention,
26
+ AttentionProcessor,
27
+ FusedJointAttnProcessor2_0,
28
+ JointAttnProcessor2_0,
29
+ )
30
+ from ..embeddings import CombinedTimestepTextProjEmbeddings, PatchEmbed
31
+ from ..modeling_outputs import Transformer2DModelOutput
32
+ from ..modeling_utils import ModelMixin
33
+ from ..normalization import AdaLayerNormContinuous, AdaLayerNormZero
34
+
35
+
36
+ logger = logging.get_logger(__name__) # pylint: disable=invalid-name
37
+
38
+
39
+ @maybe_allow_in_graph
40
class SD3SingleTransformerBlock(nn.Module):
    """A single-stream (image-only) DiT block used by SD3-family models.

    Runs AdaLayerNormZero-modulated self-attention followed by a gated
    feed-forward network, each wrapped in a residual connection.
    """

    def __init__(
        self,
        dim: int,
        num_attention_heads: int,
        attention_head_dim: int,
    ):
        super().__init__()

        # AdaLN-Zero yields the normalized input plus per-sample gates/shift/scale.
        self.norm1 = AdaLayerNormZero(dim)
        self.attn = Attention(
            query_dim=dim,
            dim_head=attention_head_dim,
            heads=num_attention_heads,
            out_dim=dim,
            bias=True,
            processor=JointAttnProcessor2_0(),
            eps=1e-6,
        )

        self.norm2 = nn.LayerNorm(dim, elementwise_affine=False, eps=1e-6)
        self.ff = FeedForward(dim=dim, dim_out=dim, activation_fn="gelu-approximate")

    def forward(self, hidden_states: torch.Tensor, temb: torch.Tensor):
        """Apply attention and feed-forward with `temb`-conditioned modulation."""
        # Self-attention branch: AdaLN-Zero, attention, gated residual add.
        normed, gate_attn, shift_ff, scale_ff, gate_ff = self.norm1(hidden_states, emb=temb)
        attn_out = self.attn(hidden_states=normed, encoder_hidden_states=None)
        hidden_states = hidden_states + gate_attn.unsqueeze(1) * attn_out

        # Feed-forward branch: LayerNorm, shift/scale modulation, gated residual add.
        normed = self.norm2(hidden_states)
        normed = normed * (1 + scale_ff.unsqueeze(1)) + shift_ff.unsqueeze(1)
        hidden_states = hidden_states + gate_ff.unsqueeze(1) * self.ff(normed)

        return hidden_states
78
+
79
+
80
class SD3Transformer2DModel(
    ModelMixin, ConfigMixin, PeftAdapterMixin, FromOriginalModelMixin, SD3Transformer2DLoadersMixin
):
    """
    The Transformer model introduced in [Stable Diffusion 3](https://huggingface.co/papers/2403.03206).

    Parameters:
        sample_size (`int`, defaults to `128`):
            The width/height of the latents. This is fixed during training since it is used to learn a number of
            position embeddings.
        patch_size (`int`, defaults to `2`):
            Patch size to turn the input data into small patches.
        in_channels (`int`, defaults to `16`):
            The number of latent channels in the input.
        num_layers (`int`, defaults to `18`):
            The number of layers of transformer blocks to use.
        attention_head_dim (`int`, defaults to `64`):
            The number of channels in each head.
        num_attention_heads (`int`, defaults to `18`):
            The number of heads to use for multi-head attention.
        joint_attention_dim (`int`, defaults to `4096`):
            The embedding dimension to use for joint text-image attention.
        caption_projection_dim (`int`, defaults to `1152`):
            The embedding dimension of caption embeddings.
        pooled_projection_dim (`int`, defaults to `2048`):
            The embedding dimension of pooled text projections.
        out_channels (`int`, defaults to `16`):
            The number of latent channels in the output.
        pos_embed_max_size (`int`, defaults to `96`):
            The maximum latent height/width of positional embeddings.
        dual_attention_layers (`Tuple[int, ...]`, defaults to `()`):
            The number of dual-stream transformer blocks to use.
        qk_norm (`str`, *optional*, defaults to `None`):
            The normalization to use for query and key in the attention layer. If `None`, no normalization is used.
    """

    _supports_gradient_checkpointing = True
    _no_split_modules = ["JointTransformerBlock"]
    _skip_layerwise_casting_patterns = ["pos_embed", "norm"]

    @register_to_config
    def __init__(
        self,
        sample_size: int = 128,
        patch_size: int = 2,
        in_channels: int = 16,
        num_layers: int = 18,
        attention_head_dim: int = 64,
        num_attention_heads: int = 18,
        joint_attention_dim: int = 4096,
        caption_projection_dim: int = 1152,
        pooled_projection_dim: int = 2048,
        out_channels: int = 16,
        pos_embed_max_size: int = 96,
        dual_attention_layers: Tuple[
            int, ...
        ] = (),  # () for sd3.0; (0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12) for sd3.5
        qk_norm: Optional[str] = None,
    ):
        super().__init__()
        self.out_channels = out_channels if out_channels is not None else in_channels
        self.inner_dim = num_attention_heads * attention_head_dim

        self.pos_embed = PatchEmbed(
            height=sample_size,
            width=sample_size,
            patch_size=patch_size,
            in_channels=in_channels,
            embed_dim=self.inner_dim,
            pos_embed_max_size=pos_embed_max_size,  # hard-code for now.
        )
        self.time_text_embed = CombinedTimestepTextProjEmbeddings(
            embedding_dim=self.inner_dim, pooled_projection_dim=pooled_projection_dim
        )
        self.context_embedder = nn.Linear(joint_attention_dim, caption_projection_dim)

        # The final block runs with context_pre_only=True (the text stream is not
        # updated there); dual-attention layers are SD3.5-specific.
        self.transformer_blocks = nn.ModuleList(
            [
                JointTransformerBlock(
                    dim=self.inner_dim,
                    num_attention_heads=num_attention_heads,
                    attention_head_dim=attention_head_dim,
                    context_pre_only=i == num_layers - 1,
                    qk_norm=qk_norm,
                    use_dual_attention=True if i in dual_attention_layers else False,
                )
                for i in range(num_layers)
            ]
        )

        self.norm_out = AdaLayerNormContinuous(self.inner_dim, self.inner_dim, elementwise_affine=False, eps=1e-6)
        self.proj_out = nn.Linear(self.inner_dim, patch_size * patch_size * self.out_channels, bias=True)

        self.gradient_checkpointing = False

    # Copied from diffusers.models.unets.unet_3d_condition.UNet3DConditionModel.enable_forward_chunking
    def enable_forward_chunking(self, chunk_size: Optional[int] = None, dim: int = 0) -> None:
        """
        Sets the attention processor to use [feed forward
        chunking](https://huggingface.co/blog/reformer#2-chunked-feed-forward-layers).

        Parameters:
            chunk_size (`int`, *optional*):
                The chunk size of the feed-forward layers. If not specified, will run feed-forward layer individually
                over each tensor of dim=`dim`.
            dim (`int`, *optional*, defaults to `0`):
                The dimension over which the feed-forward computation should be chunked. Choose between dim=0 (batch)
                or dim=1 (sequence length).
        """
        if dim not in [0, 1]:
            raise ValueError(f"Make sure to set `dim` to either 0 or 1, not {dim}")

        # By default chunk size is 1
        chunk_size = chunk_size or 1

        def fn_recursive_feed_forward(module: torch.nn.Module, chunk_size: int, dim: int):
            if hasattr(module, "set_chunk_feed_forward"):
                module.set_chunk_feed_forward(chunk_size=chunk_size, dim=dim)

            for child in module.children():
                fn_recursive_feed_forward(child, chunk_size, dim)

        for module in self.children():
            fn_recursive_feed_forward(module, chunk_size, dim)

    # Copied from diffusers.models.unets.unet_3d_condition.UNet3DConditionModel.disable_forward_chunking
    def disable_forward_chunking(self):
        """Disable feed-forward chunking by resetting chunk_size to None on all sub-modules."""

        def fn_recursive_feed_forward(module: torch.nn.Module, chunk_size: int, dim: int):
            if hasattr(module, "set_chunk_feed_forward"):
                module.set_chunk_feed_forward(chunk_size=chunk_size, dim=dim)

            for child in module.children():
                fn_recursive_feed_forward(child, chunk_size, dim)

        for module in self.children():
            fn_recursive_feed_forward(module, None, 0)

    @property
    # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.attn_processors
    def attn_processors(self) -> Dict[str, AttentionProcessor]:
        r"""
        Returns:
            `dict` of attention processors: A dictionary containing all attention processors used in the model with
            indexed by its weight name.
        """
        # set recursively
        processors = {}

        def fn_recursive_add_processors(name: str, module: torch.nn.Module, processors: Dict[str, AttentionProcessor]):
            if hasattr(module, "get_processor"):
                processors[f"{name}.processor"] = module.get_processor()

            for sub_name, child in module.named_children():
                fn_recursive_add_processors(f"{name}.{sub_name}", child, processors)

            return processors

        for name, module in self.named_children():
            fn_recursive_add_processors(name, module, processors)

        return processors

    # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.set_attn_processor
    def set_attn_processor(self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]]):
        r"""
        Sets the attention processor to use to compute attention.

        Parameters:
            processor (`dict` of `AttentionProcessor` or only `AttentionProcessor`):
                The instantiated processor class or a dictionary of processor classes that will be set as the processor
                for **all** `Attention` layers.

                If `processor` is a dict, the key needs to define the path to the corresponding cross attention
                processor. This is strongly recommended when setting trainable attention processors.

        """
        count = len(self.attn_processors.keys())

        if isinstance(processor, dict) and len(processor) != count:
            raise ValueError(
                f"A dict of processors was passed, but the number of processors {len(processor)} does not match the"
                f" number of attention layers: {count}. Please make sure to pass {count} processor classes."
            )

        def fn_recursive_attn_processor(name: str, module: torch.nn.Module, processor):
            if hasattr(module, "set_processor"):
                if not isinstance(processor, dict):
                    module.set_processor(processor)
                else:
                    # Each processor is consumed exactly once, keyed by module path.
                    module.set_processor(processor.pop(f"{name}.processor"))

            for sub_name, child in module.named_children():
                fn_recursive_attn_processor(f"{name}.{sub_name}", child, processor)

        for name, module in self.named_children():
            fn_recursive_attn_processor(name, module, processor)

    # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.fuse_qkv_projections with FusedAttnProcessor2_0->FusedJointAttnProcessor2_0
    def fuse_qkv_projections(self):
        """
        Enables fused QKV projections. For self-attention modules, all projection matrices (i.e., query, key, value)
        are fused. For cross-attention modules, key and value projection matrices are fused.

        <Tip warning={true}>

        This API is 🧪 experimental.

        </Tip>
        """
        self.original_attn_processors = None

        for _, attn_processor in self.attn_processors.items():
            if "Added" in str(attn_processor.__class__.__name__):
                raise ValueError("`fuse_qkv_projections()` is not supported for models having added KV projections.")

        # Kept so `unfuse_qkv_projections` can restore the pre-fusion processors.
        self.original_attn_processors = self.attn_processors

        for module in self.modules():
            if isinstance(module, Attention):
                module.fuse_projections(fuse=True)

        self.set_attn_processor(FusedJointAttnProcessor2_0())

    # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.unfuse_qkv_projections
    def unfuse_qkv_projections(self):
        """Disables the fused QKV projection if enabled.

        <Tip warning={true}>

        This API is 🧪 experimental.

        </Tip>

        """
        if self.original_attn_processors is not None:
            self.set_attn_processor(self.original_attn_processors)

    def forward(
        self,
        hidden_states: torch.Tensor,
        encoder_hidden_states: torch.Tensor = None,
        pooled_projections: torch.Tensor = None,
        timestep: torch.LongTensor = None,
        block_controlnet_hidden_states: List = None,
        joint_attention_kwargs: Optional[Dict[str, Any]] = None,
        return_dict: bool = True,
        skip_layers: Optional[List[int]] = None,
    ) -> Union[torch.Tensor, Transformer2DModelOutput]:
        """
        The [`SD3Transformer2DModel`] forward method.

        Args:
            hidden_states (`torch.Tensor` of shape `(batch size, channel, height, width)`):
                Input `hidden_states`.
            encoder_hidden_states (`torch.Tensor` of shape `(batch size, sequence_len, embed_dims)`):
                Conditional embeddings (embeddings computed from the input conditions such as prompts) to use.
            pooled_projections (`torch.Tensor` of shape `(batch_size, projection_dim)`):
                Embeddings projected from the embeddings of input conditions.
            timestep (`torch.LongTensor`):
                Used to indicate denoising step.
            block_controlnet_hidden_states (`list` of `torch.Tensor`):
                A list of tensors that if specified are added to the residuals of transformer blocks.
            joint_attention_kwargs (`dict`, *optional*):
                A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
                `self.processor` in
                [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
            return_dict (`bool`, *optional*, defaults to `True`):
                Whether or not to return a [`~models.transformer_2d.Transformer2DModelOutput`] instead of a plain
                tuple.
            skip_layers (`list` of `int`, *optional*):
                A list of layer indices to skip during the forward pass.

        Returns:
            If `return_dict` is True, an [`~models.transformer_2d.Transformer2DModelOutput`] is returned, otherwise a
            `tuple` where the first element is the sample tensor.
        """
        # Pop the LoRA scale before the kwargs reach the attention processors.
        if joint_attention_kwargs is not None:
            joint_attention_kwargs = joint_attention_kwargs.copy()
            lora_scale = joint_attention_kwargs.pop("scale", 1.0)
        else:
            lora_scale = 1.0

        if USE_PEFT_BACKEND:
            # weight the lora layers by setting `lora_scale` for each PEFT layer
            scale_lora_layers(self, lora_scale)
        else:
            if joint_attention_kwargs is not None and joint_attention_kwargs.get("scale", None) is not None:
                logger.warning(
                    "Passing `scale` via `joint_attention_kwargs` when not using the PEFT backend is ineffective."
                )

        height, width = hidden_states.shape[-2:]

        hidden_states = self.pos_embed(hidden_states)  # takes care of adding positional embeddings too.
        temb = self.time_text_embed(timestep, pooled_projections)
        encoder_hidden_states = self.context_embedder(encoder_hidden_states)

        if joint_attention_kwargs is not None and "ip_adapter_image_embeds" in joint_attention_kwargs:
            ip_adapter_image_embeds = joint_attention_kwargs.pop("ip_adapter_image_embeds")
            # NOTE(review): `self.image_proj` is not created in `__init__`; presumably
            # it is attached by `SD3Transformer2DLoadersMixin` when an IP-Adapter is
            # loaded — confirm before calling this path without one.
            ip_hidden_states, ip_temb = self.image_proj(ip_adapter_image_embeds, timestep)

            joint_attention_kwargs.update(ip_hidden_states=ip_hidden_states, temb=ip_temb)

        for index_block, block in enumerate(self.transformer_blocks):
            # Skip specified layers
            is_skip = True if skip_layers is not None and index_block in skip_layers else False

            if torch.is_grad_enabled() and self.gradient_checkpointing and not is_skip:
                encoder_hidden_states, hidden_states = self._gradient_checkpointing_func(
                    block,
                    hidden_states,
                    encoder_hidden_states,
                    temb,
                    joint_attention_kwargs,
                )
            elif not is_skip:
                encoder_hidden_states, hidden_states = block(
                    hidden_states=hidden_states,
                    encoder_hidden_states=encoder_hidden_states,
                    temb=temb,
                    joint_attention_kwargs=joint_attention_kwargs,
                )

            # controlnet residual: controlnet samples are spread evenly across blocks.
            # The last block (context_pre_only=True) receives no residual.
            if block_controlnet_hidden_states is not None and block.context_pre_only is False:
                interval_control = len(self.transformer_blocks) / len(block_controlnet_hidden_states)
                hidden_states = hidden_states + block_controlnet_hidden_states[int(index_block / interval_control)]

        hidden_states = self.norm_out(hidden_states, temb)
        hidden_states = self.proj_out(hidden_states)

        # unpatchify: (B, H*W, p*p*C) -> (B, C, H*p, W*p)
        patch_size = self.config.patch_size
        height = height // patch_size
        width = width // patch_size

        hidden_states = hidden_states.reshape(
            shape=(hidden_states.shape[0], height, width, patch_size, patch_size, self.out_channels)
        )
        hidden_states = torch.einsum("nhwpqc->nchpwq", hidden_states)
        output = hidden_states.reshape(
            shape=(hidden_states.shape[0], self.out_channels, height * patch_size, width * patch_size)
        )

        if USE_PEFT_BACKEND:
            # remove `lora_scale` from each PEFT layer
            unscale_lora_layers(self, lora_scale)

        if not return_dict:
            return (output,)

        return Transformer2DModelOutput(sample=output)
pythonProject/.venv/Lib/site-packages/diffusers/models/transformers/transformer_skyreels_v2.py ADDED
@@ -0,0 +1,781 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright 2025 The SkyReels Team, The Wan Team and The HuggingFace Team. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ import math
16
+ from typing import Any, Dict, Optional, Tuple, Union
17
+
18
+ import torch
19
+ import torch.nn as nn
20
+ import torch.nn.functional as F
21
+
22
+ from ...configuration_utils import ConfigMixin, register_to_config
23
+ from ...loaders import FromOriginalModelMixin, PeftAdapterMixin
24
+ from ...utils import USE_PEFT_BACKEND, deprecate, logging, scale_lora_layers, unscale_lora_layers
25
+ from ...utils.torch_utils import maybe_allow_in_graph
26
+ from ..attention import AttentionMixin, AttentionModuleMixin, FeedForward
27
+ from ..attention_dispatch import dispatch_attention_fn
28
+ from ..cache_utils import CacheMixin
29
+ from ..embeddings import (
30
+ PixArtAlphaTextProjection,
31
+ TimestepEmbedding,
32
+ get_1d_rotary_pos_embed,
33
+ get_1d_sincos_pos_embed_from_grid,
34
+ )
35
+ from ..modeling_outputs import Transformer2DModelOutput
36
+ from ..modeling_utils import ModelMixin, get_parameter_dtype
37
+ from ..normalization import FP32LayerNorm
38
+
39
+
40
+ logger = logging.get_logger(__name__) # pylint: disable=invalid-name
41
+
42
+
43
+ def _get_qkv_projections(
44
+ attn: "SkyReelsV2Attention", hidden_states: torch.Tensor, encoder_hidden_states: torch.Tensor
45
+ ):
46
+ # encoder_hidden_states is only passed for cross-attention
47
+ if encoder_hidden_states is None:
48
+ encoder_hidden_states = hidden_states
49
+
50
+ if attn.fused_projections:
51
+ if attn.cross_attention_dim_head is None:
52
+ # In self-attention layers, we can fuse the entire QKV projection into a single linear
53
+ query, key, value = attn.to_qkv(hidden_states).chunk(3, dim=-1)
54
+ else:
55
+ # In cross-attention layers, we can only fuse the KV projections into a single linear
56
+ query = attn.to_q(hidden_states)
57
+ key, value = attn.to_kv(encoder_hidden_states).chunk(2, dim=-1)
58
+ else:
59
+ query = attn.to_q(hidden_states)
60
+ key = attn.to_k(encoder_hidden_states)
61
+ value = attn.to_v(encoder_hidden_states)
62
+ return query, key, value
63
+
64
+
65
+ def _get_added_kv_projections(attn: "SkyReelsV2Attention", encoder_hidden_states_img: torch.Tensor):
66
+ if attn.fused_projections:
67
+ key_img, value_img = attn.to_added_kv(encoder_hidden_states_img).chunk(2, dim=-1)
68
+ else:
69
+ key_img = attn.add_k_proj(encoder_hidden_states_img)
70
+ value_img = attn.add_v_proj(encoder_hidden_states_img)
71
+ return key_img, value_img
72
+
73
+
74
class SkyReelsV2AttnProcessor:
    # Attention processor for SkyReels-V2: runs self- or cross-attention with
    # optional rotary position embeddings on Q/K, plus an optional extra
    # image-conditioning attention pass whose output is summed in (I2V task).
    _attention_backend = None

    def __init__(self):
        # The dispatcher ultimately relies on PyTorch 2.0's SDPA kernel.
        if not hasattr(F, "scaled_dot_product_attention"):
            raise ImportError(
                "SkyReelsV2AttnProcessor requires PyTorch 2.0. To use it, please upgrade PyTorch to 2.0."
            )

    def __call__(
        self,
        attn: "SkyReelsV2Attention",
        hidden_states: torch.Tensor,
        encoder_hidden_states: Optional[torch.Tensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        rotary_emb: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
    ) -> torch.Tensor:
        """Run one attention layer.

        Args:
            attn: The owning module; supplies projections, norms and head count.
            hidden_states: Query-side sequence, shape ``(batch, seq, dim)``.
            encoder_hidden_states: Context sequence for cross-attention; when
                ``None``, self-attention over ``hidden_states`` is performed.
            attention_mask: Optional mask forwarded to the attention kernel.
            rotary_emb: Optional ``(cos, sin)`` frequency tables applied to Q/K.

        Returns:
            The attention output after the ``to_out`` projection and dropout.
        """
        encoder_hidden_states_img = None
        if attn.add_k_proj is not None:
            # I2V: the context is [image tokens | text tokens]; the trailing
            # 512 tokens are text. 512 is the context length of the text
            # encoder, hardcoded for now.
            image_context_length = encoder_hidden_states.shape[1] - 512
            encoder_hidden_states_img = encoder_hidden_states[:, :image_context_length]
            encoder_hidden_states = encoder_hidden_states[:, image_context_length:]

        query, key, value = _get_qkv_projections(attn, hidden_states, encoder_hidden_states)

        # Q/K RMSNorm happens over the full (heads * head_dim) feature axis,
        # i.e. before the split into heads below.
        query = attn.norm_q(query)
        key = attn.norm_k(key)

        # (batch, seq, heads * head_dim) -> (batch, seq, heads, head_dim)
        query = query.unflatten(2, (attn.heads, -1))
        key = key.unflatten(2, (attn.heads, -1))
        value = value.unflatten(2, (attn.heads, -1))

        if rotary_emb is not None:

            def apply_rotary_emb(
                hidden_states: torch.Tensor,
                freqs_cos: torch.Tensor,
                freqs_sin: torch.Tensor,
            ):
                # Interleaved rotary embedding: even/odd feature positions are
                # treated as the (real, imag) pair of a complex rotation.
                x1, x2 = hidden_states.unflatten(-1, (-1, 2)).unbind(-1)
                cos = freqs_cos[..., 0::2]
                sin = freqs_sin[..., 1::2]
                out = torch.empty_like(hidden_states)
                out[..., 0::2] = x1 * cos - x2 * sin
                out[..., 1::2] = x1 * sin + x2 * cos
                return out.type_as(hidden_states)

            query = apply_rotary_emb(query, *rotary_emb)
            key = apply_rotary_emb(key, *rotary_emb)

        # I2V task
        hidden_states_img = None
        if encoder_hidden_states_img is not None:
            # Separate attention pass over the image tokens; shares the query.
            key_img, value_img = _get_added_kv_projections(attn, encoder_hidden_states_img)
            key_img = attn.norm_added_k(key_img)

            key_img = key_img.unflatten(2, (attn.heads, -1))
            value_img = value_img.unflatten(2, (attn.heads, -1))

            hidden_states_img = dispatch_attention_fn(
                query,
                key_img,
                value_img,
                attn_mask=None,
                dropout_p=0.0,
                is_causal=False,
                backend=self._attention_backend,
            )
            hidden_states_img = hidden_states_img.flatten(2, 3)
            hidden_states_img = hidden_states_img.type_as(query)

        hidden_states = dispatch_attention_fn(
            query,
            key,
            value,
            attn_mask=attention_mask,
            dropout_p=0.0,
            is_causal=False,
            backend=self._attention_backend,
        )

        # Merge heads back: (batch, seq, heads, head_dim) -> (batch, seq, dim)
        hidden_states = hidden_states.flatten(2, 3)
        hidden_states = hidden_states.type_as(query)

        # Image- and text-conditioned outputs are summed, not concatenated.
        if hidden_states_img is not None:
            hidden_states = hidden_states + hidden_states_img

        # to_out = [Linear, Dropout]
        hidden_states = attn.to_out[0](hidden_states)
        hidden_states = attn.to_out[1](hidden_states)
        return hidden_states
165
+
166
+
167
class SkyReelsV2AttnProcessor2_0:
    """Deprecated alias of :class:`SkyReelsV2AttnProcessor`.

    Construction emits a deprecation warning and hands back an instance of the
    replacement class instead of this one.
    """

    def __new__(cls, *args, **kwargs):
        deprecate(
            "SkyReelsV2AttnProcessor2_0",
            "1.0.0",
            "The SkyReelsV2AttnProcessor2_0 class is deprecated and will be removed in a future version. "
            "Please use SkyReelsV2AttnProcessor instead. ",
            standard_warn=False,
        )
        return SkyReelsV2AttnProcessor(*args, **kwargs)
175
+
176
+
177
class SkyReelsV2Attention(torch.nn.Module, AttentionModuleMixin):
    """Multi-head attention module for SkyReels-V2.

    Supports optional fused QKV (or KV) projections and an optional added
    key/value path for the image-conditioning stream used by I2V checkpoints.
    """

    _default_processor_cls = SkyReelsV2AttnProcessor
    _available_processors = [SkyReelsV2AttnProcessor]

    def __init__(
        self,
        dim: int,
        heads: int = 8,
        dim_head: int = 64,
        eps: float = 1e-5,
        dropout: float = 0.0,
        added_kv_proj_dim: Optional[int] = None,
        cross_attention_dim_head: Optional[int] = None,
        processor=None,
        is_cross_attention=None,
    ):
        super().__init__()

        self.inner_dim = dim_head * heads
        self.heads = heads
        self.added_kv_proj_dim = added_kv_proj_dim
        self.cross_attention_dim_head = cross_attention_dim_head
        # K/V width may differ from Q width in cross-attention layers.
        self.kv_inner_dim = self.inner_dim if cross_attention_dim_head is None else cross_attention_dim_head * heads

        self.to_q = torch.nn.Linear(dim, self.inner_dim, bias=True)
        self.to_k = torch.nn.Linear(dim, self.kv_inner_dim, bias=True)
        self.to_v = torch.nn.Linear(dim, self.kv_inner_dim, bias=True)
        # to_out = [projection, dropout] — indexed positionally by the processor.
        self.to_out = torch.nn.ModuleList(
            [
                torch.nn.Linear(self.inner_dim, dim, bias=True),
                torch.nn.Dropout(dropout),
            ]
        )
        # Q/K RMSNorm acts across all heads at once (see processor).
        self.norm_q = torch.nn.RMSNorm(dim_head * heads, eps=eps, elementwise_affine=True)
        self.norm_k = torch.nn.RMSNorm(dim_head * heads, eps=eps, elementwise_affine=True)

        # Extra K/V projections for the image-conditioning stream (I2V only).
        self.add_k_proj = self.add_v_proj = None
        if added_kv_proj_dim is not None:
            self.add_k_proj = torch.nn.Linear(added_kv_proj_dim, self.inner_dim, bias=True)
            self.add_v_proj = torch.nn.Linear(added_kv_proj_dim, self.inner_dim, bias=True)
            self.norm_added_k = torch.nn.RMSNorm(dim_head * heads, eps=eps)

        self.is_cross_attention = cross_attention_dim_head is not None

        self.set_processor(processor)

    def fuse_projections(self):
        """Fuse the separate projection linears into single concatenated layers.

        Self-attention layers fuse Q, K and V into ``to_qkv``; cross-attention
        layers fuse only K and V into ``to_kv`` (queries come from a different
        input). The added K/V projections, when present, fuse into
        ``to_added_kv``. Idempotent: a second call is a no-op.
        """
        if getattr(self, "fused_projections", False):
            return

        if self.cross_attention_dim_head is None:
            concatenated_weights = torch.cat([self.to_q.weight.data, self.to_k.weight.data, self.to_v.weight.data])
            concatenated_bias = torch.cat([self.to_q.bias.data, self.to_k.bias.data, self.to_v.bias.data])
            out_features, in_features = concatenated_weights.shape
            # Create the layer on the meta device, then materialize it via
            # load_state_dict(assign=True) so no throwaway weight allocation
            # happens.
            with torch.device("meta"):
                self.to_qkv = nn.Linear(in_features, out_features, bias=True)
            self.to_qkv.load_state_dict(
                {"weight": concatenated_weights, "bias": concatenated_bias}, strict=True, assign=True
            )
        else:
            concatenated_weights = torch.cat([self.to_k.weight.data, self.to_v.weight.data])
            concatenated_bias = torch.cat([self.to_k.bias.data, self.to_v.bias.data])
            out_features, in_features = concatenated_weights.shape
            with torch.device("meta"):
                self.to_kv = nn.Linear(in_features, out_features, bias=True)
            self.to_kv.load_state_dict(
                {"weight": concatenated_weights, "bias": concatenated_bias}, strict=True, assign=True
            )

        if self.added_kv_proj_dim is not None:
            concatenated_weights = torch.cat([self.add_k_proj.weight.data, self.add_v_proj.weight.data])
            concatenated_bias = torch.cat([self.add_k_proj.bias.data, self.add_v_proj.bias.data])
            out_features, in_features = concatenated_weights.shape
            with torch.device("meta"):
                self.to_added_kv = nn.Linear(in_features, out_features, bias=True)
            self.to_added_kv.load_state_dict(
                {"weight": concatenated_weights, "bias": concatenated_bias}, strict=True, assign=True
            )

        self.fused_projections = True

    @torch.no_grad()
    def unfuse_projections(self):
        """Drop the fused layers created by :meth:`fuse_projections`.

        The original separate linears were never removed, so deleting the
        fused attributes restores the unfused state. No-op when not fused.
        """
        if not getattr(self, "fused_projections", False):
            return

        if hasattr(self, "to_qkv"):
            delattr(self, "to_qkv")
        if hasattr(self, "to_kv"):
            delattr(self, "to_kv")
        if hasattr(self, "to_added_kv"):
            delattr(self, "to_added_kv")

        self.fused_projections = False

    def forward(
        self,
        hidden_states: torch.Tensor,
        encoder_hidden_states: Optional[torch.Tensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        rotary_emb: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
        **kwargs,
    ) -> torch.Tensor:
        # All computation is delegated to the attached processor instance.
        return self.processor(self, hidden_states, encoder_hidden_states, attention_mask, rotary_emb, **kwargs)
281
+
282
+
283
class SkyReelsV2ImageEmbedding(torch.nn.Module):
    """Projects image-encoder features into the transformer's hidden space.

    The pipeline is LayerNorm -> FeedForward -> LayerNorm, with an optional
    learned positional embedding added before the first norm.
    """

    def __init__(self, in_features: int, out_features: int, pos_embed_seq_len=None):
        super().__init__()

        self.norm1 = FP32LayerNorm(in_features)
        self.ff = FeedForward(in_features, out_features, mult=1, activation_fn="gelu")
        self.norm2 = FP32LayerNorm(out_features)
        # The positional table is only present for checkpoints trained with it.
        self.pos_embed = (
            nn.Parameter(torch.zeros(1, pos_embed_seq_len, in_features)) if pos_embed_seq_len is not None else None
        )

    def forward(self, encoder_hidden_states_image: torch.Tensor) -> torch.Tensor:
        if self.pos_embed is not None:
            # Reshape so the sequence axis matches the positional table
            # (assumes pos_embed_seq_len == 2 * seq_len and an even batch —
            # TODO confirm against callers).
            _, seq_len, embed_dim = encoder_hidden_states_image.shape
            encoder_hidden_states_image = (
                encoder_hidden_states_image.view(-1, 2 * seq_len, embed_dim) + self.pos_embed
            )

        out = self.norm1(encoder_hidden_states_image)
        out = self.ff(out)
        return self.norm2(out)
305
+
306
+
307
class SkyReelsV2Timesteps(nn.Module):
    """Sinusoidal timestep embedding.

    Thin wrapper around ``get_1d_sincos_pos_embed_from_grid`` that restores
    the input's batch structure on the output.
    """

    def __init__(self, num_channels: int, flip_sin_to_cos: bool, output_type: str = "pt"):
        super().__init__()
        self.num_channels = num_channels
        self.output_type = output_type
        self.flip_sin_to_cos = flip_sin_to_cos

    def forward(self, timesteps: torch.Tensor) -> torch.Tensor:
        input_shape = timesteps.shape
        embedding = get_1d_sincos_pos_embed_from_grid(
            self.num_channels,
            timesteps,
            output_type=self.output_type,
            flip_sin_to_cos=self.flip_sin_to_cos,
        )
        if len(input_shape) > 1:
            # The helper flattens multi-dim timestep grids; reshape back to
            # (*input_shape, num_channels) to keep the batch structure.
            embedding = embedding.reshape(*input_shape, self.num_channels)
        return embedding
326
+
327
+
328
class SkyReelsV2TimeTextImageEmbedding(nn.Module):
    """Joint conditioning embedder: timestep, text and (optionally) image.

    Produces the base timestep embedding, its 6-way modulation projection,
    the projected text embeddings and, when configured, projected image
    embeddings.
    """

    def __init__(
        self,
        dim: int,
        time_freq_dim: int,
        time_proj_dim: int,
        text_embed_dim: int,
        image_embed_dim: Optional[int] = None,
        pos_embed_seq_len: Optional[int] = None,
    ):
        super().__init__()

        self.timesteps_proj = SkyReelsV2Timesteps(num_channels=time_freq_dim, flip_sin_to_cos=True)
        self.time_embedder = TimestepEmbedding(in_channels=time_freq_dim, time_embed_dim=dim)
        self.act_fn = nn.SiLU()
        self.time_proj = nn.Linear(dim, time_proj_dim)
        self.text_embedder = PixArtAlphaTextProjection(text_embed_dim, dim, act_fn="gelu_tanh")

        # The image path only exists for I2V checkpoints.
        self.image_embedder = (
            SkyReelsV2ImageEmbedding(image_embed_dim, dim, pos_embed_seq_len=pos_embed_seq_len)
            if image_embed_dim is not None
            else None
        )

    def forward(
        self,
        timestep: torch.Tensor,
        encoder_hidden_states: torch.Tensor,
        encoder_hidden_states_image: Optional[torch.Tensor] = None,
    ):
        timestep = self.timesteps_proj(timestep)

        # Cast the sinusoidal features to the embedder's weight dtype; int8 is
        # excluded because quantized weights are not a valid activation dtype.
        target_dtype = get_parameter_dtype(self.time_embedder)
        if timestep.dtype != target_dtype and target_dtype != torch.int8:
            timestep = timestep.to(target_dtype)

        temb = self.time_embedder(timestep).type_as(encoder_hidden_states)
        timestep_proj = self.time_proj(self.act_fn(temb))

        encoder_hidden_states = self.text_embedder(encoder_hidden_states)
        if encoder_hidden_states_image is not None:
            encoder_hidden_states_image = self.image_embedder(encoder_hidden_states_image)

        return temb, timestep_proj, encoder_hidden_states, encoder_hidden_states_image
369
+
370
+
371
class SkyReelsV2RotaryPosEmbed(nn.Module):
    """3D rotary position embedding.

    The per-head feature dimension is partitioned into (temporal, height,
    width) frequency bands; each band gets its own 1D rotary table, and the
    forward pass broadcasts the three tables over the post-patchify token
    grid.
    """

    def __init__(
        self,
        attention_head_dim: int,
        patch_size: Tuple[int, int, int],
        max_seq_len: int,
        theta: float = 10000.0,
    ):
        super().__init__()

        self.attention_head_dim = attention_head_dim
        self.patch_size = patch_size
        self.max_seq_len = max_seq_len

        # Split the head dim: height and width each get 2 * (dim // 6)
        # channels, time gets the remainder.
        h_dim = w_dim = 2 * (attention_head_dim // 6)
        t_dim = attention_head_dim - h_dim - w_dim
        # float64 frequencies for accuracy; MPS does not support float64.
        freqs_dtype = torch.float32 if torch.backends.mps.is_available() else torch.float64

        freqs_cos = []
        freqs_sin = []

        for dim in [t_dim, h_dim, w_dim]:
            freq_cos, freq_sin = get_1d_rotary_pos_embed(
                dim,
                max_seq_len,
                theta,
                use_real=True,
                repeat_interleave_real=True,
                freqs_dtype=freqs_dtype,
            )
            freqs_cos.append(freq_cos)
            freqs_sin.append(freq_sin)

        # Non-persistent: the tables are deterministic and rebuilt on init.
        self.register_buffer("freqs_cos", torch.cat(freqs_cos, dim=1), persistent=False)
        self.register_buffer("freqs_sin", torch.cat(freqs_sin, dim=1), persistent=False)

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        """Return (cos, sin) tables shaped (1, F*H*W, 1, head_dim) for the
        post-patchify grid of `hidden_states` (shape (b, c, frames, h, w))."""
        batch_size, num_channels, num_frames, height, width = hidden_states.shape
        p_t, p_h, p_w = self.patch_size
        ppf, pph, ppw = num_frames // p_t, height // p_h, width // p_w

        # NOTE(review): these sizes use `attention_head_dim // 3`, which equals
        # the `2 * (dim // 6)` used in __init__ only for head dims like the
        # default 128 (both give 42) — confirm for non-default head dims.
        split_sizes = [
            self.attention_head_dim - 2 * (self.attention_head_dim // 3),
            self.attention_head_dim // 3,
            self.attention_head_dim // 3,
        ]

        freqs_cos = self.freqs_cos.split(split_sizes, dim=1)
        freqs_sin = self.freqs_sin.split(split_sizes, dim=1)

        # Broadcast each axis' 1D table over the full (f, h, w) grid.
        freqs_cos_f = freqs_cos[0][:ppf].view(ppf, 1, 1, -1).expand(ppf, pph, ppw, -1)
        freqs_cos_h = freqs_cos[1][:pph].view(1, pph, 1, -1).expand(ppf, pph, ppw, -1)
        freqs_cos_w = freqs_cos[2][:ppw].view(1, 1, ppw, -1).expand(ppf, pph, ppw, -1)

        freqs_sin_f = freqs_sin[0][:ppf].view(ppf, 1, 1, -1).expand(ppf, pph, ppw, -1)
        freqs_sin_h = freqs_sin[1][:pph].view(1, pph, 1, -1).expand(ppf, pph, ppw, -1)
        freqs_sin_w = freqs_sin[2][:ppw].view(1, 1, ppw, -1).expand(ppf, pph, ppw, -1)

        freqs_cos = torch.cat([freqs_cos_f, freqs_cos_h, freqs_cos_w], dim=-1).reshape(1, ppf * pph * ppw, 1, -1)
        freqs_sin = torch.cat([freqs_sin_f, freqs_sin_h, freqs_sin_w], dim=-1).reshape(1, ppf * pph * ppw, 1, -1)

        return freqs_cos, freqs_sin
433
+
434
+
435
@maybe_allow_in_graph
class SkyReelsV2TransformerBlock(nn.Module):
    """One SkyReels-V2 transformer block.

    Structure: AdaLN-modulated self-attention (with rotary embeddings and an
    optional causal mask), cross-attention over the text/image context, and a
    gated feed-forward, all driven by a 6-way timestep modulation signal.
    """

    def __init__(
        self,
        dim: int,
        ffn_dim: int,
        num_heads: int,
        qk_norm: str = "rms_norm_across_heads",
        cross_attn_norm: bool = False,
        eps: float = 1e-6,
        added_kv_proj_dim: Optional[int] = None,
    ):
        super().__init__()

        # 1. Self-attention
        self.norm1 = FP32LayerNorm(dim, eps, elementwise_affine=False)
        self.attn1 = SkyReelsV2Attention(
            dim=dim,
            heads=num_heads,
            dim_head=dim // num_heads,
            eps=eps,
            cross_attention_dim_head=None,
            processor=SkyReelsV2AttnProcessor(),
        )

        # 2. Cross-attention
        self.attn2 = SkyReelsV2Attention(
            dim=dim,
            heads=num_heads,
            dim_head=dim // num_heads,
            eps=eps,
            added_kv_proj_dim=added_kv_proj_dim,
            cross_attention_dim_head=dim // num_heads,
            processor=SkyReelsV2AttnProcessor(),
        )
        self.norm2 = FP32LayerNorm(dim, eps, elementwise_affine=True) if cross_attn_norm else nn.Identity()

        # 3. Feed-forward
        self.ffn = FeedForward(dim, inner_dim=ffn_dim, activation_fn="gelu-approximate")
        self.norm3 = FP32LayerNorm(dim, eps, elementwise_affine=False)

        # Learned per-block offsets added to the timestep modulation signal
        # (shift/scale/gate for attention and for the FFN).
        self.scale_shift_table = nn.Parameter(torch.randn(1, 6, dim) / dim**0.5)

    def forward(
        self,
        hidden_states: torch.Tensor,
        encoder_hidden_states: torch.Tensor,
        temb: torch.Tensor,
        rotary_emb: torch.Tensor,
        attention_mask: torch.Tensor,
    ) -> torch.Tensor:
        """Apply the block to `hidden_states`.

        Args:
            hidden_states: Token sequence of shape (b, seq, dim).
            encoder_hidden_states: Text (and optionally image) context tokens.
            temb: Timestep modulation; 3-D (b, 6, dim) for a shared per-batch
                signal, or 4-D (b, 6, f * pp_h * pp_w, dim) for per-token
                signals in the Diffusion Forcing framework.
            rotary_emb: (cos, sin) rotary tables for self-attention.
            attention_mask: Optional block-causal mask for self-attention.

        Raises:
            ValueError: If `temb` is neither 3-D nor 4-D.
        """
        if temb.dim() == 3:
            shift_msa, scale_msa, gate_msa, c_shift_msa, c_scale_msa, c_gate_msa = (
                self.scale_shift_table + temb.float()
            ).chunk(6, dim=1)
        elif temb.dim() == 4:
            # For 4D temb in Diffusion Forcing framework, we assume the shape is (b, 6, f * pp_h * pp_w, inner_dim)
            e = (self.scale_shift_table.unsqueeze(2) + temb.float()).chunk(6, dim=1)
            shift_msa, scale_msa, gate_msa, c_shift_msa, c_scale_msa, c_gate_msa = [ei.squeeze(1) for ei in e]
        else:
            # Previously this fell through and crashed later with a confusing
            # NameError on the unbound modulation variables; fail loudly here.
            raise ValueError(f"Expected `temb` to be 3-D or 4-D, but got {temb.dim()} dimensions.")

        # 1. Self-attention (normalization and modulation run in fp32).
        norm_hidden_states = (self.norm1(hidden_states.float()) * (1 + scale_msa) + shift_msa).type_as(hidden_states)
        attn_output = self.attn1(norm_hidden_states, None, attention_mask, rotary_emb)
        hidden_states = (hidden_states.float() + attn_output * gate_msa).type_as(hidden_states)

        # 2. Cross-attention (no modulation; plain residual).
        norm_hidden_states = self.norm2(hidden_states.float()).type_as(hidden_states)
        attn_output = self.attn2(norm_hidden_states, encoder_hidden_states, None, None)
        hidden_states = hidden_states + attn_output

        # 3. Feed-forward (second half of the modulation signal).
        norm_hidden_states = (self.norm3(hidden_states.float()) * (1 + c_scale_msa) + c_shift_msa).type_as(
            hidden_states
        )
        ff_output = self.ffn(norm_hidden_states)
        hidden_states = (hidden_states.float() + ff_output.float() * c_gate_msa).type_as(hidden_states)

        return hidden_states
513
+
514
+
515
class SkyReelsV2Transformer3DModel(
    ModelMixin, ConfigMixin, PeftAdapterMixin, FromOriginalModelMixin, CacheMixin, AttentionMixin
):
    r"""
    A Transformer model for video-like data used in the Wan-based SkyReels-V2 model.

    Args:
        patch_size (`Tuple[int]`, defaults to `(1, 2, 2)`):
            3D patch dimensions for video embedding (t_patch, h_patch, w_patch).
        num_attention_heads (`int`, defaults to `16`):
            The number of attention heads.
        attention_head_dim (`int`, defaults to `128`):
            The number of channels in each head.
        in_channels (`int`, defaults to `16`):
            The number of channels in the input.
        out_channels (`int`, defaults to `16`):
            The number of channels in the output.
        text_dim (`int`, defaults to `4096`):
            Input dimension for text embeddings.
        freq_dim (`int`, defaults to `256`):
            Dimension for sinusoidal time embeddings.
        ffn_dim (`int`, defaults to `8192`):
            Intermediate dimension in feed-forward network.
        num_layers (`int`, defaults to `32`):
            The number of layers of transformer blocks to use.
        window_size (`Tuple[int]`, defaults to `(-1, -1)`):
            Window size for local attention (-1 indicates global attention).
        cross_attn_norm (`bool`, defaults to `True`):
            Enable cross-attention normalization.
        qk_norm (`str`, *optional*, defaults to `"rms_norm_across_heads"`):
            Enable query/key normalization.
        eps (`float`, defaults to `1e-6`):
            Epsilon value for normalization layers.
        inject_sample_info (`bool`, defaults to `False`):
            Whether to inject sample information into the model.
        image_dim (`int`, *optional*):
            The dimension of the image embeddings.
        added_kv_proj_dim (`int`, *optional*):
            The dimension of the added key/value projection.
        rope_max_seq_len (`int`, defaults to `1024`):
            The maximum sequence length for the rotary embeddings.
        pos_embed_seq_len (`int`, *optional*):
            The sequence length for the positional embeddings.
    """

    _supports_gradient_checkpointing = True
    _skip_layerwise_casting_patterns = ["patch_embedding", "condition_embedder", "norm"]
    _no_split_modules = ["SkyReelsV2TransformerBlock"]
    _keep_in_fp32_modules = ["time_embedder", "scale_shift_table", "norm1", "norm2", "norm3"]
    _keys_to_ignore_on_load_unexpected = ["norm_added_q"]
    _repeated_blocks = ["SkyReelsV2TransformerBlock"]

    @register_to_config
    def __init__(
        self,
        patch_size: Tuple[int] = (1, 2, 2),
        num_attention_heads: int = 16,
        attention_head_dim: int = 128,
        in_channels: int = 16,
        out_channels: int = 16,
        text_dim: int = 4096,
        freq_dim: int = 256,
        ffn_dim: int = 8192,
        num_layers: int = 32,
        cross_attn_norm: bool = True,
        qk_norm: Optional[str] = "rms_norm_across_heads",
        eps: float = 1e-6,
        image_dim: Optional[int] = None,
        added_kv_proj_dim: Optional[int] = None,
        rope_max_seq_len: int = 1024,
        pos_embed_seq_len: Optional[int] = None,
        inject_sample_info: bool = False,
        num_frame_per_block: int = 1,
    ) -> None:
        super().__init__()

        inner_dim = num_attention_heads * attention_head_dim
        out_channels = out_channels or in_channels

        # 1. Patch & position embedding
        self.rope = SkyReelsV2RotaryPosEmbed(attention_head_dim, patch_size, rope_max_seq_len)
        self.patch_embedding = nn.Conv3d(in_channels, inner_dim, kernel_size=patch_size, stride=patch_size)

        # 2. Condition embeddings
        # image_embedding_dim=1280 for I2V model
        self.condition_embedder = SkyReelsV2TimeTextImageEmbedding(
            dim=inner_dim,
            time_freq_dim=freq_dim,
            time_proj_dim=inner_dim * 6,
            text_embed_dim=text_dim,
            image_embed_dim=image_dim,
            pos_embed_seq_len=pos_embed_seq_len,
        )

        # 3. Transformer blocks
        self.blocks = nn.ModuleList(
            [
                SkyReelsV2TransformerBlock(
                    inner_dim, ffn_dim, num_attention_heads, qk_norm, cross_attn_norm, eps, added_kv_proj_dim
                )
                for _ in range(num_layers)
            ]
        )

        # 4. Output norm & projection
        self.norm_out = FP32LayerNorm(inner_dim, eps, elementwise_affine=False)
        self.proj_out = nn.Linear(inner_dim, out_channels * math.prod(patch_size))
        # Shift/scale offsets for the final AdaLN-style output modulation.
        self.scale_shift_table = nn.Parameter(torch.randn(1, 2, inner_dim) / inner_dim**0.5)

        if inject_sample_info:
            # FPS conditioning: a 2-entry embedding projected into the 6-way
            # modulation space and added to the per-block timestep signal.
            self.fps_embedding = nn.Embedding(2, inner_dim)
            self.fps_projection = FeedForward(inner_dim, inner_dim * 6, mult=1, activation_fn="linear-silu")

        self.gradient_checkpointing = False

    def forward(
        self,
        hidden_states: torch.Tensor,
        timestep: torch.LongTensor,
        encoder_hidden_states: torch.Tensor,
        encoder_hidden_states_image: Optional[torch.Tensor] = None,
        enable_diffusion_forcing: bool = False,
        fps: Optional[torch.Tensor] = None,
        return_dict: bool = True,
        attention_kwargs: Optional[Dict[str, Any]] = None,
    ) -> Union[torch.Tensor, Dict[str, torch.Tensor]]:
        """Denoise a video latent.

        Args:
            hidden_states: Latent video, shape (b, c, frames, height, width).
            timestep: Diffusion timestep(s); 1-D (b,) normally, or 2-D
                (b, frames) when `enable_diffusion_forcing` is set.
            encoder_hidden_states: Text embeddings.
            encoder_hidden_states_image: Optional image embeddings (I2V).
            enable_diffusion_forcing: Use per-frame timestep modulation.
            fps: FPS index for `inject_sample_info` checkpoints.
            return_dict: Return a `Transformer2DModelOutput` instead of a tuple.
            attention_kwargs: May carry a LoRA `scale` under the PEFT backend.
        """
        # Pop the LoRA scale (PEFT backend only) without mutating the caller's dict.
        if attention_kwargs is not None:
            attention_kwargs = attention_kwargs.copy()
            lora_scale = attention_kwargs.pop("scale", 1.0)
        else:
            lora_scale = 1.0

        if USE_PEFT_BACKEND:
            # weight the lora layers by setting `lora_scale` for each PEFT layer
            scale_lora_layers(self, lora_scale)
        else:
            if attention_kwargs is not None and attention_kwargs.get("scale", None) is not None:
                logger.warning(
                    "Passing `scale` via `attention_kwargs` when not using the PEFT backend is ineffective."
                )

        batch_size, num_channels, num_frames, height, width = hidden_states.shape
        p_t, p_h, p_w = self.config.patch_size
        post_patch_num_frames = num_frames // p_t
        post_patch_height = height // p_h
        post_patch_width = width // p_w

        # Rotary tables are computed from the pre-patchify latent shape.
        rotary_emb = self.rope(hidden_states)

        # Patchify: (b, c, f, h, w) -> (b, f*h*w tokens, inner_dim)
        hidden_states = self.patch_embedding(hidden_states)
        hidden_states = hidden_states.flatten(2).transpose(1, 2)

        # Block-causal mask over frames for autoregressive generation: tokens
        # may attend to all tokens in the same or earlier frame blocks.
        causal_mask = None
        if self.config.num_frame_per_block > 1:
            block_num = post_patch_num_frames // self.config.num_frame_per_block
            range_tensor = torch.arange(block_num, device=hidden_states.device).repeat_interleave(
                self.config.num_frame_per_block
            )
            causal_mask = range_tensor.unsqueeze(0) <= range_tensor.unsqueeze(1)  # f, f
            # Expand the frame-level mask to token level over the (h, w) grid.
            causal_mask = causal_mask.view(post_patch_num_frames, 1, 1, post_patch_num_frames, 1, 1)
            causal_mask = causal_mask.repeat(
                1, post_patch_height, post_patch_width, 1, post_patch_height, post_patch_width
            )
            causal_mask = causal_mask.reshape(
                post_patch_num_frames * post_patch_height * post_patch_width,
                post_patch_num_frames * post_patch_height * post_patch_width,
            )
            causal_mask = causal_mask.unsqueeze(0).unsqueeze(0)

        temb, timestep_proj, encoder_hidden_states, encoder_hidden_states_image = self.condition_embedder(
            timestep, encoder_hidden_states, encoder_hidden_states_image
        )

        # (..., 6 * inner_dim) -> (..., 6, inner_dim) modulation signal.
        timestep_proj = timestep_proj.unflatten(-1, (6, -1))

        if encoder_hidden_states_image is not None:
            # Context layout is [image tokens | text tokens]; the attention
            # processor splits them back apart (text length is fixed at 512).
            encoder_hidden_states = torch.concat([encoder_hidden_states_image, encoder_hidden_states], dim=1)

        if self.config.inject_sample_info:
            fps = torch.tensor(fps, dtype=torch.long, device=hidden_states.device)

            fps_emb = self.fps_embedding(fps)
            if enable_diffusion_forcing:
                # Repeat the FPS modulation once per latent frame.
                timestep_proj = timestep_proj + self.fps_projection(fps_emb).unflatten(1, (6, -1)).repeat(
                    timestep.shape[1], 1, 1
                )
            else:
                timestep_proj = timestep_proj + self.fps_projection(fps_emb).unflatten(1, (6, -1))

        if enable_diffusion_forcing:
            # Per-frame timesteps: broadcast each frame's modulation over its
            # spatial tokens so every token gets its own frame's signal.
            b, f = timestep.shape
            temb = temb.view(b, f, 1, 1, -1)
            timestep_proj = timestep_proj.view(b, f, 1, 1, 6, -1)  # (b, f, 1, 1, 6, inner_dim)
            temb = temb.repeat(1, 1, post_patch_height, post_patch_width, 1).flatten(1, 3)
            timestep_proj = timestep_proj.repeat(1, 1, post_patch_height, post_patch_width, 1, 1).flatten(
                1, 3
            )  # (b, f, pp_h, pp_w, 6, inner_dim) -> (b, f * pp_h * pp_w, 6, inner_dim)
            timestep_proj = timestep_proj.transpose(1, 2).contiguous()  # (b, 6, f * pp_h * pp_w, inner_dim)

        # 4. Transformer blocks
        if torch.is_grad_enabled() and self.gradient_checkpointing:
            for block in self.blocks:
                hidden_states = self._gradient_checkpointing_func(
                    block,
                    hidden_states,
                    encoder_hidden_states,
                    timestep_proj,
                    rotary_emb,
                    causal_mask,
                )
        else:
            for block in self.blocks:
                hidden_states = block(
                    hidden_states,
                    encoder_hidden_states,
                    timestep_proj,
                    rotary_emb,
                    causal_mask,
                )

        if temb.dim() == 2:
            # If temb is 2D, we assume it has time 1-D time embedding values for each batch.
            # For models:
            # - Skywork/SkyReels-V2-T2V-14B-540P-Diffusers
            # - Skywork/SkyReels-V2-T2V-14B-720P-Diffusers
            # - Skywork/SkyReels-V2-I2V-1.3B-540P-Diffusers
            # - Skywork/SkyReels-V2-I2V-14B-540P-Diffusers
            # - Skywork/SkyReels-V2-I2V-14B-720P-Diffusers
            shift, scale = (self.scale_shift_table + temb.unsqueeze(1)).chunk(2, dim=1)
        elif temb.dim() == 3:
            # If temb is 3D, we assume it has 2-D time embedding values for each batch.
            # Each time embedding tensor includes values for each latent frame; thus Diffusion Forcing.
            # For models:
            # - Skywork/SkyReels-V2-DF-1.3B-540P-Diffusers
            # - Skywork/SkyReels-V2-DF-14B-540P-Diffusers
            # - Skywork/SkyReels-V2-DF-14B-720P-Diffusers
            shift, scale = (self.scale_shift_table.unsqueeze(2) + temb.unsqueeze(1)).chunk(2, dim=1)
            shift, scale = shift.squeeze(1), scale.squeeze(1)

        # Move the shift and scale tensors to the same device as hidden_states.
        # When using multi-GPU inference via accelerate these will be on the
        # first device rather than the last device, which hidden_states ends up
        # on.
        shift = shift.to(hidden_states.device)
        scale = scale.to(hidden_states.device)

        hidden_states = (self.norm_out(hidden_states.float()) * (1 + scale) + shift).type_as(hidden_states)

        hidden_states = self.proj_out(hidden_states)

        # Un-patchify: tokens -> (b, out_channels, frames, height, width)
        hidden_states = hidden_states.reshape(
            batch_size, post_patch_num_frames, post_patch_height, post_patch_width, p_t, p_h, p_w, -1
        )
        hidden_states = hidden_states.permute(0, 7, 1, 4, 2, 5, 3, 6)
        output = hidden_states.flatten(6, 7).flatten(4, 5).flatten(2, 3)

        if USE_PEFT_BACKEND:
            # remove `lora_scale` from each PEFT layer
            unscale_lora_layers(self, lora_scale)

        if not return_dict:
            return (output,)

        return Transformer2DModelOutput(sample=output)

    def _set_ar_attention(self, causal_block_size: int):
        # Enable block-causal (autoregressive) self-attention by setting the
        # number of latent frames per causal block in the config.
        self.register_to_config(num_frame_per_block=causal_block_size)
pythonProject/.venv/Lib/site-packages/diffusers/models/transformers/transformer_temporal.py ADDED
@@ -0,0 +1,375 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright 2025 The HuggingFace Team. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ from dataclasses import dataclass
15
+ from typing import Any, Dict, Optional
16
+
17
+ import torch
18
+ from torch import nn
19
+
20
+ from ...configuration_utils import ConfigMixin, register_to_config
21
+ from ...utils import BaseOutput
22
+ from ..attention import BasicTransformerBlock, TemporalBasicTransformerBlock
23
+ from ..embeddings import TimestepEmbedding, Timesteps
24
+ from ..modeling_utils import ModelMixin
25
+ from ..resnet import AlphaBlender
26
+
27
+
28
@dataclass
class TransformerTemporalModelOutput(BaseOutput):
    """
    The output of [`TransformerTemporalModel`].

    Args:
        sample (`torch.Tensor` of shape `(batch_size x num_frames, num_channels, height, width)`):
            The hidden states output conditioned on `encoder_hidden_states` input.
    """

    # Frames are flattened into the leading (batch) dimension, matching the
    # `batch_frames` layout of the model's input.
    sample: torch.Tensor
39
+
40
+
41
class TransformerTemporalModel(ModelMixin, ConfigMixin):
    """
    A Transformer model for video-like data.

    Attention is applied along the *temporal* axis: the input is reshaped so
    that each spatial location attends over its `num_frames` time steps.

    Parameters:
        num_attention_heads (`int`, *optional*, defaults to 16): The number of heads to use for multi-head attention.
        attention_head_dim (`int`, *optional*, defaults to 88): The number of channels in each head.
        in_channels (`int`, *optional*):
            The number of channels in the input and output (specify if the input is **continuous**).
        num_layers (`int`, *optional*, defaults to 1): The number of layers of Transformer blocks to use.
        dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use.
        cross_attention_dim (`int`, *optional*): The number of `encoder_hidden_states` dimensions to use.
        attention_bias (`bool`, *optional*):
            Configure if the `TransformerBlock` attention should contain a bias parameter.
        sample_size (`int`, *optional*): The width of the latent images (specify if the input is **discrete**).
            This is fixed during training since it is used to learn a number of position embeddings.
        activation_fn (`str`, *optional*, defaults to `"geglu"`):
            Activation function to use in feed-forward. See `diffusers.models.activations.get_activation` for supported
            activation functions.
        norm_elementwise_affine (`bool`, *optional*):
            Configure if the `TransformerBlock` should use learnable elementwise affine parameters for normalization.
        double_self_attention (`bool`, *optional*):
            Configure if each `TransformerBlock` should contain two self-attention layers.
        positional_embeddings: (`str`, *optional*):
            The type of positional embeddings to apply to the sequence input before passing use.
        num_positional_embeddings: (`int`, *optional*):
            The maximum length of the sequence over which to apply positional embeddings.
    """

    # Keep GroupNorm in full precision when layerwise casting is enabled.
    _skip_layerwise_casting_patterns = ["norm"]

    @register_to_config
    def __init__(
        self,
        num_attention_heads: int = 16,
        attention_head_dim: int = 88,
        in_channels: Optional[int] = None,
        out_channels: Optional[int] = None,
        num_layers: int = 1,
        dropout: float = 0.0,
        norm_num_groups: int = 32,
        cross_attention_dim: Optional[int] = None,
        attention_bias: bool = False,
        sample_size: Optional[int] = None,
        activation_fn: str = "geglu",
        norm_elementwise_affine: bool = True,
        double_self_attention: bool = True,
        positional_embeddings: Optional[str] = None,
        num_positional_embeddings: Optional[int] = None,
    ):
        super().__init__()
        self.num_attention_heads = num_attention_heads
        self.attention_head_dim = attention_head_dim
        # Transformer width = heads * per-head channels.
        inner_dim = num_attention_heads * attention_head_dim

        self.in_channels = in_channels

        # 1./2. Input normalization and projection into the transformer width.
        self.norm = torch.nn.GroupNorm(num_groups=norm_num_groups, num_channels=in_channels, eps=1e-6, affine=True)
        self.proj_in = nn.Linear(in_channels, inner_dim)

        # 3. Define transformers blocks
        self.transformer_blocks = nn.ModuleList(
            [
                BasicTransformerBlock(
                    inner_dim,
                    num_attention_heads,
                    attention_head_dim,
                    dropout=dropout,
                    cross_attention_dim=cross_attention_dim,
                    activation_fn=activation_fn,
                    attention_bias=attention_bias,
                    double_self_attention=double_self_attention,
                    norm_elementwise_affine=norm_elementwise_affine,
                    positional_embeddings=positional_embeddings,
                    num_positional_embeddings=num_positional_embeddings,
                )
                for d in range(num_layers)
            ]
        )

        # NOTE(review): `out_channels` is stored in the config but not used
        # here — the output projection always maps back to `in_channels`.
        self.proj_out = nn.Linear(inner_dim, in_channels)

    def forward(
        self,
        hidden_states: torch.Tensor,
        encoder_hidden_states: Optional[torch.LongTensor] = None,
        timestep: Optional[torch.LongTensor] = None,
        class_labels: torch.LongTensor = None,
        num_frames: int = 1,
        cross_attention_kwargs: Optional[Dict[str, Any]] = None,
        return_dict: bool = True,
    ) -> TransformerTemporalModelOutput:
        """
        The [`TransformerTemporal`] forward method.

        Args:
            hidden_states (`torch.LongTensor` of shape `(batch size, num latent pixels)` if discrete, `torch.Tensor` of shape `(batch size, channel, height, width)` if continuous):
                Input hidden_states.
            encoder_hidden_states ( `torch.LongTensor` of shape `(batch size, encoder_hidden_states dim)`, *optional*):
                Conditional embeddings for cross attention layer. If not given, cross-attention defaults to
                self-attention.
            timestep ( `torch.LongTensor`, *optional*):
                Used to indicate denoising step. Optional timestep to be applied as an embedding in `AdaLayerNorm`.
            class_labels ( `torch.LongTensor` of shape `(batch size, num classes)`, *optional*):
                Used to indicate class labels conditioning. Optional class labels to be applied as an embedding in
                `AdaLayerZeroNorm`.
            num_frames (`int`, *optional*, defaults to 1):
                The number of frames to be processed per batch. This is used to reshape the hidden states.
            cross_attention_kwargs (`dict`, *optional*):
                A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
                `self.processor` in
                [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
            return_dict (`bool`, *optional*, defaults to `True`):
                Whether or not to return a [`~models.transformers.transformer_temporal.TransformerTemporalModelOutput`]
                instead of a plain tuple.

        Returns:
            [`~models.transformers.transformer_temporal.TransformerTemporalModelOutput`] or `tuple`:
                If `return_dict` is True, an
                [`~models.transformers.transformer_temporal.TransformerTemporalModelOutput`] is returned, otherwise a
                `tuple` where the first element is the sample tensor.
        """
        # 1. Input
        # Incoming layout: frames flattened into the batch dim.
        batch_frames, channel, height, width = hidden_states.shape
        batch_size = batch_frames // num_frames

        # Kept for the residual connection applied after the blocks.
        residual = hidden_states

        # (b*f, c, h, w) -> (b, c, f, h, w) so GroupNorm normalizes per video.
        hidden_states = hidden_states[None, :].reshape(batch_size, num_frames, channel, height, width)
        hidden_states = hidden_states.permute(0, 2, 1, 3, 4)

        hidden_states = self.norm(hidden_states)
        # Collapse spatial positions into the batch so attention runs over the
        # temporal axis: (b*h*w, f, c).
        hidden_states = hidden_states.permute(0, 3, 4, 2, 1).reshape(batch_size * height * width, num_frames, channel)

        hidden_states = self.proj_in(hidden_states)

        # 2. Blocks
        for block in self.transformer_blocks:
            hidden_states = block(
                hidden_states,
                encoder_hidden_states=encoder_hidden_states,
                timestep=timestep,
                cross_attention_kwargs=cross_attention_kwargs,
                class_labels=class_labels,
            )

        # 3. Output
        hidden_states = self.proj_out(hidden_states)
        # Invert the earlier reshape: (b*h*w, f, c) -> (b*f, c, h, w).
        hidden_states = (
            hidden_states[None, None, :]
            .reshape(batch_size, height, width, num_frames, channel)
            .permute(0, 3, 4, 1, 2)
            .contiguous()
        )
        hidden_states = hidden_states.reshape(batch_frames, channel, height, width)

        # Residual connection around the whole temporal transformer.
        output = hidden_states + residual

        if not return_dict:
            return (output,)

        return TransformerTemporalModelOutput(sample=output)
203
+
204
+
205
class TransformerSpatioTemporalModel(nn.Module):
    """
    A Transformer model for video-like data.

    Runs paired spatial and temporal transformer blocks and merges their
    outputs with a learned [`AlphaBlender`].

    Parameters:
        num_attention_heads (`int`, *optional*, defaults to 16): The number of heads to use for multi-head attention.
        attention_head_dim (`int`, *optional*, defaults to 88): The number of channels in each head.
        in_channels (`int`, *optional*):
            The number of channels in the input and output (specify if the input is **continuous**).
        out_channels (`int`, *optional*):
            The number of channels in the output (specify if the input is **continuous**).
        num_layers (`int`, *optional*, defaults to 1): The number of layers of Transformer blocks to use.
        cross_attention_dim (`int`, *optional*): The number of `encoder_hidden_states` dimensions to use.
    """

    def __init__(
        self,
        num_attention_heads: int = 16,
        attention_head_dim: int = 88,
        in_channels: int = 320,
        out_channels: Optional[int] = None,
        num_layers: int = 1,
        cross_attention_dim: Optional[int] = None,
    ):
        super().__init__()
        self.num_attention_heads = num_attention_heads
        self.attention_head_dim = attention_head_dim

        inner_dim = num_attention_heads * attention_head_dim
        self.inner_dim = inner_dim

        # 2. Define input layers
        self.in_channels = in_channels
        self.norm = torch.nn.GroupNorm(num_groups=32, num_channels=in_channels, eps=1e-6)
        self.proj_in = nn.Linear(in_channels, inner_dim)

        # 3. Define transformers blocks
        self.transformer_blocks = nn.ModuleList(
            [
                BasicTransformerBlock(
                    inner_dim,
                    num_attention_heads,
                    attention_head_dim,
                    cross_attention_dim=cross_attention_dim,
                )
                for d in range(num_layers)
            ]
        )

        # One temporal block per spatial block; same inner width.
        time_mix_inner_dim = inner_dim
        self.temporal_transformer_blocks = nn.ModuleList(
            [
                TemporalBasicTransformerBlock(
                    inner_dim,
                    time_mix_inner_dim,
                    num_attention_heads,
                    attention_head_dim,
                    cross_attention_dim=cross_attention_dim,
                )
                for _ in range(num_layers)
            ]
        )

        # Embedding of the frame index, added to tokens before the temporal blocks.
        time_embed_dim = in_channels * 4
        self.time_pos_embed = TimestepEmbedding(in_channels, time_embed_dim, out_dim=in_channels)
        self.time_proj = Timesteps(in_channels, True, 0)
        # Learned blend between spatial and temporal outputs; can fall back to
        # image-only behavior via `image_only_indicator`.
        self.time_mixer = AlphaBlender(alpha=0.5, merge_strategy="learned_with_images")

        # 4. Define output layers
        self.out_channels = in_channels if out_channels is None else out_channels
        # TODO: should use out_channels for continuous projections
        self.proj_out = nn.Linear(inner_dim, in_channels)

        self.gradient_checkpointing = False

    def forward(
        self,
        hidden_states: torch.Tensor,
        encoder_hidden_states: Optional[torch.Tensor] = None,
        image_only_indicator: Optional[torch.Tensor] = None,
        return_dict: bool = True,
    ):
        """
        Args:
            hidden_states (`torch.Tensor` of shape `(batch size, channel, height, width)`):
                Input hidden_states.
            num_frames (`int`):
                The number of frames to be processed per batch. This is used to reshape the hidden states.
            encoder_hidden_states ( `torch.LongTensor` of shape `(batch size, encoder_hidden_states dim)`, *optional*):
                Conditional embeddings for cross attention layer. If not given, cross-attention defaults to
                self-attention.
            image_only_indicator (`torch.LongTensor` of shape `(batch size, num_frames)`, *optional*):
                A tensor indicating whether the input contains only images. 1 indicates that the input contains only
                images, 0 indicates that the input contains video frames.
            return_dict (`bool`, *optional*, defaults to `True`):
                Whether or not to return a [`~models.transformers.transformer_temporal.TransformerTemporalModelOutput`]
                instead of a plain tuple.

        Returns:
            [`~models.transformers.transformer_temporal.TransformerTemporalModelOutput`] or `tuple`:
                If `return_dict` is True, an
                [`~models.transformers.transformer_temporal.TransformerTemporalModelOutput`] is returned, otherwise a
                `tuple` where the first element is the sample tensor.
        """
        # 1. Input
        # Frames are flattened into the batch dim; recover the frame count from
        # `image_only_indicator` (required — shape `(batch, num_frames)`).
        batch_frames, _, height, width = hidden_states.shape
        num_frames = image_only_indicator.shape[-1]
        batch_size = batch_frames // num_frames

        # Build the conditioning for the temporal blocks: take the text states
        # of each video's *first* frame and replicate them for every spatial
        # location, giving shape (b*h*w, seq, dim).
        time_context = encoder_hidden_states
        time_context_first_timestep = time_context[None, :].reshape(
            batch_size, num_frames, -1, time_context.shape[-1]
        )[:, 0]
        time_context = time_context_first_timestep[:, None].broadcast_to(
            batch_size, height * width, time_context.shape[-2], time_context.shape[-1]
        )
        time_context = time_context.reshape(batch_size * height * width, -1, time_context.shape[-1])

        residual = hidden_states

        hidden_states = self.norm(hidden_states)
        inner_dim = hidden_states.shape[1]
        # (b*f, c, h, w) -> (b*f, h*w, c): spatial tokens per frame.
        hidden_states = hidden_states.permute(0, 2, 3, 1).reshape(batch_frames, height * width, inner_dim)
        hidden_states = self.proj_in(hidden_states)

        # Frame-index embedding, one value per (batch, frame) pair.
        num_frames_emb = torch.arange(num_frames, device=hidden_states.device)
        num_frames_emb = num_frames_emb.repeat(batch_size, 1)
        num_frames_emb = num_frames_emb.reshape(-1)
        t_emb = self.time_proj(num_frames_emb)

        # `Timesteps` does not contain any weights and will always return f32 tensors
        # but time_embedding might actually be running in fp16. so we need to cast here.
        # there might be better ways to encapsulate this.
        t_emb = t_emb.to(dtype=hidden_states.dtype)

        emb = self.time_pos_embed(t_emb)
        # Broadcast over the spatial-token axis.
        emb = emb[:, None, :]

        # 2. Blocks
        for block, temporal_block in zip(self.transformer_blocks, self.temporal_transformer_blocks):
            if torch.is_grad_enabled() and self.gradient_checkpointing:
                hidden_states = self._gradient_checkpointing_func(
                    block, hidden_states, None, encoder_hidden_states, None
                )
            else:
                hidden_states = block(hidden_states, encoder_hidden_states=encoder_hidden_states)

            # Temporal branch: frame embedding added, then attention over time.
            hidden_states_mix = hidden_states
            hidden_states_mix = hidden_states_mix + emb

            hidden_states_mix = temporal_block(
                hidden_states_mix,
                num_frames=num_frames,
                encoder_hidden_states=time_context,
            )
            # Learned alpha-blend of spatial and temporal outputs.
            hidden_states = self.time_mixer(
                x_spatial=hidden_states,
                x_temporal=hidden_states_mix,
                image_only_indicator=image_only_indicator,
            )

        # 3. Output
        hidden_states = self.proj_out(hidden_states)
        # (b*f, h*w, c) -> (b*f, c, h, w) to match the residual layout.
        hidden_states = hidden_states.reshape(batch_frames, height, width, inner_dim).permute(0, 3, 1, 2).contiguous()

        output = hidden_states + residual

        if not return_dict:
            return (output,)

        return TransformerTemporalModelOutput(sample=output)
pythonProject/.venv/Lib/site-packages/diffusers/models/transformers/transformer_wan.py ADDED
@@ -0,0 +1,698 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright 2025 The Wan Team and The HuggingFace Team. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ import math
16
+ from typing import Any, Dict, Optional, Tuple, Union
17
+
18
+ import torch
19
+ import torch.nn as nn
20
+ import torch.nn.functional as F
21
+
22
+ from ...configuration_utils import ConfigMixin, register_to_config
23
+ from ...loaders import FromOriginalModelMixin, PeftAdapterMixin
24
+ from ...utils import USE_PEFT_BACKEND, deprecate, logging, scale_lora_layers, unscale_lora_layers
25
+ from ...utils.torch_utils import maybe_allow_in_graph
26
+ from ..attention import AttentionMixin, AttentionModuleMixin, FeedForward
27
+ from ..attention_dispatch import dispatch_attention_fn
28
+ from ..cache_utils import CacheMixin
29
+ from ..embeddings import PixArtAlphaTextProjection, TimestepEmbedding, Timesteps, get_1d_rotary_pos_embed
30
+ from ..modeling_outputs import Transformer2DModelOutput
31
+ from ..modeling_utils import ModelMixin
32
+ from ..normalization import FP32LayerNorm
33
+
34
+
35
+ logger = logging.get_logger(__name__) # pylint: disable=invalid-name
36
+
37
+
38
+ def _get_qkv_projections(attn: "WanAttention", hidden_states: torch.Tensor, encoder_hidden_states: torch.Tensor):
39
+ # encoder_hidden_states is only passed for cross-attention
40
+ if encoder_hidden_states is None:
41
+ encoder_hidden_states = hidden_states
42
+
43
+ if attn.fused_projections:
44
+ if attn.cross_attention_dim_head is None:
45
+ # In self-attention layers, we can fuse the entire QKV projection into a single linear
46
+ query, key, value = attn.to_qkv(hidden_states).chunk(3, dim=-1)
47
+ else:
48
+ # In cross-attention layers, we can only fuse the KV projections into a single linear
49
+ query = attn.to_q(hidden_states)
50
+ key, value = attn.to_kv(encoder_hidden_states).chunk(2, dim=-1)
51
+ else:
52
+ query = attn.to_q(hidden_states)
53
+ key = attn.to_k(encoder_hidden_states)
54
+ value = attn.to_v(encoder_hidden_states)
55
+ return query, key, value
56
+
57
+
58
+ def _get_added_kv_projections(attn: "WanAttention", encoder_hidden_states_img: torch.Tensor):
59
+ if attn.fused_projections:
60
+ key_img, value_img = attn.to_added_kv(encoder_hidden_states_img).chunk(2, dim=-1)
61
+ else:
62
+ key_img = attn.add_k_proj(encoder_hidden_states_img)
63
+ value_img = attn.add_v_proj(encoder_hidden_states_img)
64
+ return key_img, value_img
65
+
66
+
67
class WanAttnProcessor:
    """Default attention processor for [`WanAttention`].

    Handles self-attention, text cross-attention, and the I2V case where extra
    image-conditioning tokens are attended to separately and added to the text
    cross-attention output.
    """

    # Optional attention backend override for `dispatch_attention_fn`.
    _attention_backend = None

    def __init__(self):
        if not hasattr(F, "scaled_dot_product_attention"):
            raise ImportError(
                "WanAttnProcessor requires PyTorch 2.0. To use it, please upgrade PyTorch to version 2.0 or higher."
            )

    def __call__(
        self,
        attn: "WanAttention",
        hidden_states: torch.Tensor,
        encoder_hidden_states: Optional[torch.Tensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        rotary_emb: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
    ) -> torch.Tensor:
        # If the layer has image K/V projections, the leading tokens of
        # `encoder_hidden_states` are image-conditioning tokens; split them off.
        encoder_hidden_states_img = None
        if attn.add_k_proj is not None:
            # 512 is the context length of the text encoder, hardcoded for now
            image_context_length = encoder_hidden_states.shape[1] - 512
            encoder_hidden_states_img = encoder_hidden_states[:, :image_context_length]
            encoder_hidden_states = encoder_hidden_states[:, image_context_length:]

        query, key, value = _get_qkv_projections(attn, hidden_states, encoder_hidden_states)

        # QK normalization before splitting into heads.
        query = attn.norm_q(query)
        key = attn.norm_k(key)

        # (batch, seq, heads * head_dim) -> (batch, seq, heads, head_dim)
        query = query.unflatten(2, (attn.heads, -1))
        key = key.unflatten(2, (attn.heads, -1))
        value = value.unflatten(2, (attn.heads, -1))

        if rotary_emb is not None:

            def apply_rotary_emb(
                hidden_states: torch.Tensor,
                freqs_cos: torch.Tensor,
                freqs_sin: torch.Tensor,
            ):
                # Treat channel pairs (even, odd) as complex (real, imag) parts
                # and rotate them. The cos/sin tables are repeat-interleaved,
                # so every other column holds the per-pair frequency.
                x1, x2 = hidden_states.unflatten(-1, (-1, 2)).unbind(-1)
                cos = freqs_cos[..., 0::2]
                sin = freqs_sin[..., 1::2]
                out = torch.empty_like(hidden_states)
                out[..., 0::2] = x1 * cos - x2 * sin
                out[..., 1::2] = x1 * sin + x2 * cos
                return out.type_as(hidden_states)

            query = apply_rotary_emb(query, *rotary_emb)
            key = apply_rotary_emb(key, *rotary_emb)

        # I2V task
        hidden_states_img = None
        if encoder_hidden_states_img is not None:
            key_img, value_img = _get_added_kv_projections(attn, encoder_hidden_states_img)
            key_img = attn.norm_added_k(key_img)

            key_img = key_img.unflatten(2, (attn.heads, -1))
            value_img = value_img.unflatten(2, (attn.heads, -1))

            # Separate attention pass over the image tokens (no mask).
            hidden_states_img = dispatch_attention_fn(
                query,
                key_img,
                value_img,
                attn_mask=None,
                dropout_p=0.0,
                is_causal=False,
                backend=self._attention_backend,
            )
            hidden_states_img = hidden_states_img.flatten(2, 3)
            hidden_states_img = hidden_states_img.type_as(query)

        # Main attention pass (self- or text cross-attention).
        hidden_states = dispatch_attention_fn(
            query,
            key,
            value,
            attn_mask=attention_mask,
            dropout_p=0.0,
            is_causal=False,
            backend=self._attention_backend,
        )
        hidden_states = hidden_states.flatten(2, 3)
        hidden_states = hidden_states.type_as(query)

        # Sum image- and text-conditioned outputs for I2V.
        if hidden_states_img is not None:
            hidden_states = hidden_states + hidden_states_img

        # Output projection followed by dropout.
        hidden_states = attn.to_out[0](hidden_states)
        hidden_states = attn.to_out[1](hidden_states)
        return hidden_states
157
+
158
+
159
class WanAttnProcessor2_0:
    """Deprecated alias for :class:`WanAttnProcessor`.

    Instantiating this class emits a deprecation warning and returns a
    ``WanAttnProcessor`` instance instead.
    """

    def __new__(cls, *args, **kwargs):
        # Emit the standard diffusers deprecation warning, then delegate
        # construction entirely to the replacement class.
        msg = (
            "The WanAttnProcessor2_0 class is deprecated and will be removed in a future version. "
            "Please use WanAttnProcessor instead. "
        )
        deprecate("WanAttnProcessor2_0", "1.0.0", msg, standard_warn=False)
        return WanAttnProcessor(*args, **kwargs)
167
+
168
+
169
class WanAttention(torch.nn.Module, AttentionModuleMixin):
    """Multi-head attention module for the Wan transformer.

    Supports self-attention (``cross_attention_dim_head=None``) and
    cross-attention, optional extra image K/V projections for I2V
    (``added_kv_proj_dim``), and fusing of the Q/K/V linears.
    """

    _default_processor_cls = WanAttnProcessor
    _available_processors = [WanAttnProcessor]

    def __init__(
        self,
        dim: int,
        heads: int = 8,
        dim_head: int = 64,
        eps: float = 1e-5,
        dropout: float = 0.0,
        added_kv_proj_dim: Optional[int] = None,
        cross_attention_dim_head: Optional[int] = None,
        processor=None,
        is_cross_attention=None,
    ):
        super().__init__()

        self.inner_dim = dim_head * heads
        self.heads = heads
        self.added_kv_proj_dim = added_kv_proj_dim
        self.cross_attention_dim_head = cross_attention_dim_head
        # K/V width can differ from Q width in cross-attention layers.
        self.kv_inner_dim = self.inner_dim if cross_attention_dim_head is None else cross_attention_dim_head * heads

        self.to_q = torch.nn.Linear(dim, self.inner_dim, bias=True)
        self.to_k = torch.nn.Linear(dim, self.kv_inner_dim, bias=True)
        self.to_v = torch.nn.Linear(dim, self.kv_inner_dim, bias=True)
        # Output projection + dropout, applied in order by the processor.
        self.to_out = torch.nn.ModuleList(
            [
                torch.nn.Linear(self.inner_dim, dim, bias=True),
                torch.nn.Dropout(dropout),
            ]
        )
        # RMS QK-norm over the full (heads * head_dim) feature dimension.
        self.norm_q = torch.nn.RMSNorm(dim_head * heads, eps=eps, elementwise_affine=True)
        self.norm_k = torch.nn.RMSNorm(dim_head * heads, eps=eps, elementwise_affine=True)

        # Extra K/V projections for image conditioning (I2V checkpoints only).
        self.add_k_proj = self.add_v_proj = None
        if added_kv_proj_dim is not None:
            self.add_k_proj = torch.nn.Linear(added_kv_proj_dim, self.inner_dim, bias=True)
            self.add_v_proj = torch.nn.Linear(added_kv_proj_dim, self.inner_dim, bias=True)
            self.norm_added_k = torch.nn.RMSNorm(dim_head * heads, eps=eps)

        # NOTE(review): the `is_cross_attention` argument is ignored — the flag
        # is always derived from `cross_attention_dim_head`.
        self.is_cross_attention = cross_attention_dim_head is not None

        self.set_processor(processor)

    def fuse_projections(self):
        """Fuse Q/K/V (self-attention) or K/V (cross-attention) linears into one."""
        if getattr(self, "fused_projections", False):
            return

        if self.cross_attention_dim_head is None:
            # In self-attention layers, fuse the entire QKV projection.
            concatenated_weights = torch.cat([self.to_q.weight.data, self.to_k.weight.data, self.to_v.weight.data])
            concatenated_bias = torch.cat([self.to_q.bias.data, self.to_k.bias.data, self.to_v.bias.data])
            out_features, in_features = concatenated_weights.shape
            # Create on the meta device, then adopt the concatenated tensors
            # via `assign=True` to avoid an extra allocation/copy.
            with torch.device("meta"):
                self.to_qkv = nn.Linear(in_features, out_features, bias=True)
            self.to_qkv.load_state_dict(
                {"weight": concatenated_weights, "bias": concatenated_bias}, strict=True, assign=True
            )
        else:
            # In cross-attention layers, only K and V can be fused.
            concatenated_weights = torch.cat([self.to_k.weight.data, self.to_v.weight.data])
            concatenated_bias = torch.cat([self.to_k.bias.data, self.to_v.bias.data])
            out_features, in_features = concatenated_weights.shape
            with torch.device("meta"):
                self.to_kv = nn.Linear(in_features, out_features, bias=True)
            self.to_kv.load_state_dict(
                {"weight": concatenated_weights, "bias": concatenated_bias}, strict=True, assign=True
            )

        # Also fuse the image-conditioning K/V projections when present.
        if self.added_kv_proj_dim is not None:
            concatenated_weights = torch.cat([self.add_k_proj.weight.data, self.add_v_proj.weight.data])
            concatenated_bias = torch.cat([self.add_k_proj.bias.data, self.add_v_proj.bias.data])
            out_features, in_features = concatenated_weights.shape
            with torch.device("meta"):
                self.to_added_kv = nn.Linear(in_features, out_features, bias=True)
            self.to_added_kv.load_state_dict(
                {"weight": concatenated_weights, "bias": concatenated_bias}, strict=True, assign=True
            )

        self.fused_projections = True

    @torch.no_grad()
    def unfuse_projections(self):
        """Drop the fused linears; the original unfused linears are still present."""
        if not getattr(self, "fused_projections", False):
            return

        if hasattr(self, "to_qkv"):
            delattr(self, "to_qkv")
        if hasattr(self, "to_kv"):
            delattr(self, "to_kv")
        if hasattr(self, "to_added_kv"):
            delattr(self, "to_added_kv")

        self.fused_projections = False

    def forward(
        self,
        hidden_states: torch.Tensor,
        encoder_hidden_states: Optional[torch.Tensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        rotary_emb: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
        **kwargs,
    ) -> torch.Tensor:
        # All computation is delegated to the configured processor.
        return self.processor(self, hidden_states, encoder_hidden_states, attention_mask, rotary_emb, **kwargs)
273
+
274
+
275
class WanImageEmbedding(torch.nn.Module):
    """Project image-conditioning embeddings into the transformer dimension.

    Applies ``LayerNorm -> FeedForward -> LayerNorm``. When
    ``pos_embed_seq_len`` is given, a learned positional embedding is added
    first (the input is folded to sequences of twice its length beforehand).
    """

    def __init__(self, in_features: int, out_features: int, pos_embed_seq_len=None):
        super().__init__()

        self.norm1 = FP32LayerNorm(in_features)
        self.ff = FeedForward(in_features, out_features, mult=1, activation_fn="gelu")
        self.norm2 = FP32LayerNorm(out_features)
        # The positional table is only allocated when a sequence length is configured.
        self.pos_embed = (
            nn.Parameter(torch.zeros(1, pos_embed_seq_len, in_features)) if pos_embed_seq_len is not None else None
        )

    def forward(self, encoder_hidden_states_image: torch.Tensor) -> torch.Tensor:
        if self.pos_embed is not None:
            # Fold pairs of batch entries into one sequence of twice the
            # length (the view requires an even batch), then add positions.
            _, seq_len, embed_dim = encoder_hidden_states_image.shape
            folded = encoder_hidden_states_image.view(-1, 2 * seq_len, embed_dim)
            encoder_hidden_states_image = folded + self.pos_embed

        out = self.norm1(encoder_hidden_states_image)
        out = self.ff(out)
        return self.norm2(out)
297
+
298
+
299
class WanTimeTextImageEmbedding(nn.Module):
    """Joint embedder for timestep, text, and optional image conditioning.

    Produces four outputs: the time embedding ``temb``, a projected
    per-block modulation vector ``timestep_proj``, the projected text states,
    and (if an image embedder is configured) the projected image states.
    """

    def __init__(
        self,
        dim: int,
        time_freq_dim: int,
        time_proj_dim: int,
        text_embed_dim: int,
        image_embed_dim: Optional[int] = None,
        pos_embed_seq_len: Optional[int] = None,
    ):
        super().__init__()

        # Sinusoidal frequency features -> MLP time embedding of width `dim`.
        self.timesteps_proj = Timesteps(num_channels=time_freq_dim, flip_sin_to_cos=True, downscale_freq_shift=0)
        self.time_embedder = TimestepEmbedding(in_channels=time_freq_dim, time_embed_dim=dim)
        self.act_fn = nn.SiLU()
        # Expands the time embedding into the modulation projection.
        self.time_proj = nn.Linear(dim, time_proj_dim)
        self.text_embedder = PixArtAlphaTextProjection(text_embed_dim, dim, act_fn="gelu_tanh")

        # The image branch is only built when an image embedding dim is configured.
        self.image_embedder = None
        if image_embed_dim is not None:
            self.image_embedder = WanImageEmbedding(image_embed_dim, dim, pos_embed_seq_len=pos_embed_seq_len)

    def forward(
        self,
        timestep: torch.Tensor,
        encoder_hidden_states: torch.Tensor,
        encoder_hidden_states_image: Optional[torch.Tensor] = None,
        timestep_seq_len: Optional[int] = None,
    ):
        sinusoidal = self.timesteps_proj(timestep)
        if timestep_seq_len is not None:
            # Restore a (batch, seq, ...) layout for per-frame timesteps.
            sinusoidal = sinusoidal.unflatten(0, (-1, timestep_seq_len))

        # `Timesteps` always emits float32; cast to the embedder's parameter
        # dtype — except for int8 (quantized) weights, where casting the
        # activations to int8 would be wrong.
        target_dtype = next(iter(self.time_embedder.parameters())).dtype
        if sinusoidal.dtype != target_dtype and target_dtype != torch.int8:
            sinusoidal = sinusoidal.to(target_dtype)

        temb = self.time_embedder(sinusoidal).type_as(encoder_hidden_states)
        timestep_proj = self.time_proj(self.act_fn(temb))

        encoder_hidden_states = self.text_embedder(encoder_hidden_states)
        if encoder_hidden_states_image is not None:
            encoder_hidden_states_image = self.image_embedder(encoder_hidden_states_image)

        return temb, timestep_proj, encoder_hidden_states, encoder_hidden_states_image
343
+
344
+
345
+ class WanRotaryPosEmbed(nn.Module):
346
    def __init__(
        self,
        attention_head_dim: int,
        patch_size: Tuple[int, int, int],
        max_seq_len: int,
        theta: float = 10000.0,
    ):
        """Precompute 1D rotary tables for the temporal, height and width axes.

        Args:
            attention_head_dim (`int`):
                Channels per attention head; split across the t/h/w axes.
            patch_size (`Tuple[int, int, int]`):
                Patchification factors `(p_t, p_h, p_w)`.
            max_seq_len (`int`):
                Maximum number of positions precomputed per axis.
            theta (`float`, defaults to `10000.0`):
                Rotary frequency base.
        """
        super().__init__()

        self.attention_head_dim = attention_head_dim
        self.patch_size = patch_size
        self.max_seq_len = max_seq_len

        # Height and width each take 2 * (dim // 6) channels; the temporal
        # axis receives the remainder so the three parts sum to the head dim.
        h_dim = w_dim = 2 * (attention_head_dim // 6)
        t_dim = attention_head_dim - h_dim - w_dim
        # Use float64 for accurate frequency tables where supported; MPS
        # lacks float64, so fall back to float32 there.
        freqs_dtype = torch.float32 if torch.backends.mps.is_available() else torch.float64

        freqs_cos = []
        freqs_sin = []

        # One rotary table per axis (t, h, w), concatenated along channels.
        for dim in [t_dim, h_dim, w_dim]:
            freq_cos, freq_sin = get_1d_rotary_pos_embed(
                dim,
                max_seq_len,
                theta,
                use_real=True,
                repeat_interleave_real=True,
                freqs_dtype=freqs_dtype,
            )
            freqs_cos.append(freq_cos)
            freqs_sin.append(freq_sin)

        # Non-persistent buffers: excluded from the state dict but moved with
        # the module across devices/dtypes.
        self.register_buffer("freqs_cos", torch.cat(freqs_cos, dim=1), persistent=False)
        self.register_buffer("freqs_sin", torch.cat(freqs_sin, dim=1), persistent=False)
+
381
+ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
382
+ batch_size, num_channels, num_frames, height, width = hidden_states.shape
383
+ p_t, p_h, p_w = self.patch_size
384
+ ppf, pph, ppw = num_frames // p_t, height // p_h, width // p_w
385
+
386
+ split_sizes = [
387
+ self.attention_head_dim - 2 * (self.attention_head_dim // 3),
388
+ self.attention_head_dim // 3,
389
+ self.attention_head_dim // 3,
390
+ ]
391
+
392
+ freqs_cos = self.freqs_cos.split(split_sizes, dim=1)
393
+ freqs_sin = self.freqs_sin.split(split_sizes, dim=1)
394
+
395
+ freqs_cos_f = freqs_cos[0][:ppf].view(ppf, 1, 1, -1).expand(ppf, pph, ppw, -1)
396
+ freqs_cos_h = freqs_cos[1][:pph].view(1, pph, 1, -1).expand(ppf, pph, ppw, -1)
397
+ freqs_cos_w = freqs_cos[2][:ppw].view(1, 1, ppw, -1).expand(ppf, pph, ppw, -1)
398
+
399
+ freqs_sin_f = freqs_sin[0][:ppf].view(ppf, 1, 1, -1).expand(ppf, pph, ppw, -1)
400
+ freqs_sin_h = freqs_sin[1][:pph].view(1, pph, 1, -1).expand(ppf, pph, ppw, -1)
401
+ freqs_sin_w = freqs_sin[2][:ppw].view(1, 1, ppw, -1).expand(ppf, pph, ppw, -1)
402
+
403
+ freqs_cos = torch.cat([freqs_cos_f, freqs_cos_h, freqs_cos_w], dim=-1).reshape(1, ppf * pph * ppw, 1, -1)
404
+ freqs_sin = torch.cat([freqs_sin_f, freqs_sin_h, freqs_sin_w], dim=-1).reshape(1, ppf * pph * ppw, 1, -1)
405
+
406
+ return freqs_cos, freqs_sin
407
+
408
+
409
@maybe_allow_in_graph
class WanTransformerBlock(nn.Module):
    """A single Wan transformer block: AdaLN-modulated self-attention,
    text/image cross-attention, and a gated feed-forward network.

    Args:
        dim: Hidden size of the block.
        ffn_dim: Inner dimension of the feed-forward network.
        num_heads: Number of attention heads (head dim is ``dim // num_heads``).
        qk_norm: Query/key normalization variant. NOTE(review): accepted for
            config compatibility but not forwarded to the attention layers in
            this block — confirm against the `WanAttention` defaults.
        cross_attn_norm: When True, an affine FP32 LayerNorm is applied before
            cross-attention; otherwise it is an `nn.Identity` no-op.
        eps: Epsilon for all normalization layers.
        added_kv_proj_dim: Width of extra key/value projections used for image
            conditioning in the cross-attention; ``None`` disables them.
    """

    def __init__(
        self,
        dim: int,
        ffn_dim: int,
        num_heads: int,
        qk_norm: str = "rms_norm_across_heads",
        cross_attn_norm: bool = False,
        eps: float = 1e-6,
        added_kv_proj_dim: Optional[int] = None,
    ):
        super().__init__()

        # 1. Self-attention (norm is non-affine: scale/shift come from `temb`)
        self.norm1 = FP32LayerNorm(dim, eps, elementwise_affine=False)
        self.attn1 = WanAttention(
            dim=dim,
            heads=num_heads,
            dim_head=dim // num_heads,
            eps=eps,
            cross_attention_dim_head=None,
            processor=WanAttnProcessor(),
        )

        # 2. Cross-attention
        self.attn2 = WanAttention(
            dim=dim,
            heads=num_heads,
            dim_head=dim // num_heads,
            eps=eps,
            added_kv_proj_dim=added_kv_proj_dim,
            cross_attention_dim_head=dim // num_heads,
            processor=WanAttnProcessor(),
        )
        self.norm2 = FP32LayerNorm(dim, eps, elementwise_affine=True) if cross_attn_norm else nn.Identity()

        # 3. Feed-forward
        self.ffn = FeedForward(dim, inner_dim=ffn_dim, activation_fn="gelu-approximate")
        self.norm3 = FP32LayerNorm(dim, eps, elementwise_affine=False)

        # Learned base for the 6 AdaLN parameters (shift/scale/gate for the
        # attention branch and for the FFN branch); `temb` is added on top.
        self.scale_shift_table = nn.Parameter(torch.randn(1, 6, dim) / dim**0.5)

    def forward(
        self,
        hidden_states: torch.Tensor,
        encoder_hidden_states: torch.Tensor,
        temb: torch.Tensor,
        rotary_emb: torch.Tensor,
    ) -> torch.Tensor:
        """Run one block; modulation arithmetic is done in fp32 for stability."""
        if temb.ndim == 4:
            # temb: batch_size, seq_len, 6, inner_dim (wan2.2 ti2v, per-token modulation)
            shift_msa, scale_msa, gate_msa, c_shift_msa, c_scale_msa, c_gate_msa = (
                self.scale_shift_table.unsqueeze(0) + temb.float()
            ).chunk(6, dim=2)
            # batch_size, seq_len, 1, inner_dim -> drop the singleton chunk dim
            shift_msa = shift_msa.squeeze(2)
            scale_msa = scale_msa.squeeze(2)
            gate_msa = gate_msa.squeeze(2)
            c_shift_msa = c_shift_msa.squeeze(2)
            c_scale_msa = c_scale_msa.squeeze(2)
            c_gate_msa = c_gate_msa.squeeze(2)
        else:
            # temb: batch_size, 6, inner_dim (wan2.1/wan2.2 14B)
            shift_msa, scale_msa, gate_msa, c_shift_msa, c_scale_msa, c_gate_msa = (
                self.scale_shift_table + temb.float()
            ).chunk(6, dim=1)

        # 1. Self-attention (gated residual)
        norm_hidden_states = (self.norm1(hidden_states.float()) * (1 + scale_msa) + shift_msa).type_as(hidden_states)
        attn_output = self.attn1(norm_hidden_states, None, None, rotary_emb)
        hidden_states = (hidden_states.float() + attn_output * gate_msa).type_as(hidden_states)

        # 2. Cross-attention (plain, ungated residual)
        norm_hidden_states = self.norm2(hidden_states.float()).type_as(hidden_states)
        attn_output = self.attn2(norm_hidden_states, encoder_hidden_states, None, None)
        hidden_states = hidden_states + attn_output

        # 3. Feed-forward (gated residual)
        norm_hidden_states = (self.norm3(hidden_states.float()) * (1 + c_scale_msa) + c_shift_msa).type_as(
            hidden_states
        )
        ff_output = self.ffn(norm_hidden_states)
        hidden_states = (hidden_states.float() + ff_output.float() * c_gate_msa).type_as(hidden_states)

        return hidden_states
495
+
496
+
497
class WanTransformer3DModel(
    ModelMixin, ConfigMixin, PeftAdapterMixin, FromOriginalModelMixin, CacheMixin, AttentionMixin
):
    r"""
    A Transformer model for video-like data used in the Wan model.

    Args:
        patch_size (`Tuple[int]`, defaults to `(1, 2, 2)`):
            3D patch dimensions for video embedding (t_patch, h_patch, w_patch).
        num_attention_heads (`int`, defaults to `40`):
            The number of attention heads.
        attention_head_dim (`int`, defaults to `128`):
            The number of channels in each head.
        in_channels (`int`, defaults to `16`):
            The number of channels in the input.
        out_channels (`int`, defaults to `16`):
            The number of channels in the output.
        text_dim (`int`, defaults to `4096`):
            Input dimension for text embeddings.
        freq_dim (`int`, defaults to `256`):
            Dimension for sinusoidal time embeddings.
        ffn_dim (`int`, defaults to `13824`):
            Intermediate dimension in feed-forward network.
        num_layers (`int`, defaults to `40`):
            The number of layers of transformer blocks to use.
        cross_attn_norm (`bool`, defaults to `True`):
            Enable cross-attention normalization.
        qk_norm (`str`, *optional*, defaults to `"rms_norm_across_heads"`):
            Query/key normalization variant.
        eps (`float`, defaults to `1e-6`):
            Epsilon value for normalization layers.
        image_dim (`int`, *optional*, defaults to `None`):
            Input dimension of image encoder states (e.g. 1280 for I2V models). `None` disables the image branch.
        added_kv_proj_dim (`int`, *optional*, defaults to `None`):
            The number of channels to use for the added key and value projections. If `None`, no projection is used.
        rope_max_seq_len (`int`, defaults to `1024`):
            Maximum number of positions precomputed for the rotary embedding.
        pos_embed_seq_len (`int`, *optional*, defaults to `None`):
            Optional fixed sequence length for the image embedder's positional embedding.
    """

    _supports_gradient_checkpointing = True
    _skip_layerwise_casting_patterns = ["patch_embedding", "condition_embedder", "norm"]
    _no_split_modules = ["WanTransformerBlock"]
    _keep_in_fp32_modules = ["time_embedder", "scale_shift_table", "norm1", "norm2", "norm3"]
    _keys_to_ignore_on_load_unexpected = ["norm_added_q"]
    _repeated_blocks = ["WanTransformerBlock"]

    @register_to_config
    def __init__(
        self,
        patch_size: Tuple[int] = (1, 2, 2),
        num_attention_heads: int = 40,
        attention_head_dim: int = 128,
        in_channels: int = 16,
        out_channels: int = 16,
        text_dim: int = 4096,
        freq_dim: int = 256,
        ffn_dim: int = 13824,
        num_layers: int = 40,
        cross_attn_norm: bool = True,
        qk_norm: Optional[str] = "rms_norm_across_heads",
        eps: float = 1e-6,
        image_dim: Optional[int] = None,
        added_kv_proj_dim: Optional[int] = None,
        rope_max_seq_len: int = 1024,
        pos_embed_seq_len: Optional[int] = None,
    ) -> None:
        super().__init__()

        inner_dim = num_attention_heads * attention_head_dim
        out_channels = out_channels or in_channels

        # 1. Patch & position embedding
        self.rope = WanRotaryPosEmbed(attention_head_dim, patch_size, rope_max_seq_len)
        self.patch_embedding = nn.Conv3d(in_channels, inner_dim, kernel_size=patch_size, stride=patch_size)

        # 2. Condition embeddings
        # image_embedding_dim=1280 for I2V model
        self.condition_embedder = WanTimeTextImageEmbedding(
            dim=inner_dim,
            time_freq_dim=freq_dim,
            time_proj_dim=inner_dim * 6,
            text_embed_dim=text_dim,
            image_embed_dim=image_dim,
            pos_embed_seq_len=pos_embed_seq_len,
        )

        # 3. Transformer blocks
        self.blocks = nn.ModuleList(
            [
                WanTransformerBlock(
                    inner_dim, ffn_dim, num_attention_heads, qk_norm, cross_attn_norm, eps, added_kv_proj_dim
                )
                for _ in range(num_layers)
            ]
        )

        # 4. Output norm & projection
        self.norm_out = FP32LayerNorm(inner_dim, eps, elementwise_affine=False)
        self.proj_out = nn.Linear(inner_dim, out_channels * math.prod(patch_size))
        # Learned base for the output shift/scale modulation (2 entries).
        self.scale_shift_table = nn.Parameter(torch.randn(1, 2, inner_dim) / inner_dim**0.5)

        self.gradient_checkpointing = False

    def forward(
        self,
        hidden_states: torch.Tensor,
        timestep: torch.LongTensor,
        encoder_hidden_states: torch.Tensor,
        encoder_hidden_states_image: Optional[torch.Tensor] = None,
        return_dict: bool = True,
        attention_kwargs: Optional[Dict[str, Any]] = None,
    ) -> Union[torch.Tensor, Dict[str, torch.Tensor]]:
        """Denoise a patchified video latent.

        Args:
            hidden_states: Video latent of shape (batch, channels, frames, height, width).
            timestep: Diffusion timestep(s); 1D per-sample, or 2D (batch, seq_len)
                for per-token timesteps (wan 2.2 ti2v).
            encoder_hidden_states: Text-encoder hidden states.
            encoder_hidden_states_image: Optional image-encoder hidden states (I2V).
            return_dict: When False, return a plain tuple instead of
                `Transformer2DModelOutput`.
            attention_kwargs: Optional kwargs; the `scale` entry weights PEFT/LoRA
                layers when the PEFT backend is active.
        """
        if attention_kwargs is not None:
            attention_kwargs = attention_kwargs.copy()
            lora_scale = attention_kwargs.pop("scale", 1.0)
        else:
            lora_scale = 1.0

        if USE_PEFT_BACKEND:
            # weight the lora layers by setting `lora_scale` for each PEFT layer
            scale_lora_layers(self, lora_scale)
        else:
            if attention_kwargs is not None and attention_kwargs.get("scale", None) is not None:
                logger.warning(
                    "Passing `scale` via `attention_kwargs` when not using the PEFT backend is ineffective."
                )

        batch_size, num_channels, num_frames, height, width = hidden_states.shape
        p_t, p_h, p_w = self.config.patch_size
        post_patch_num_frames = num_frames // p_t
        post_patch_height = height // p_h
        post_patch_width = width // p_w

        # Rotary tables are computed from the pre-patchified shape.
        rotary_emb = self.rope(hidden_states)

        # Patchify: (B, C, F, H, W) -> (B, F*H*W patches, inner_dim)
        hidden_states = self.patch_embedding(hidden_states)
        hidden_states = hidden_states.flatten(2).transpose(1, 2)

        # timestep shape: batch_size, or batch_size, seq_len (wan 2.2 ti2v)
        if timestep.ndim == 2:
            ts_seq_len = timestep.shape[1]
            timestep = timestep.flatten()  # batch_size * seq_len
        else:
            ts_seq_len = None

        temb, timestep_proj, encoder_hidden_states, encoder_hidden_states_image = self.condition_embedder(
            timestep, encoder_hidden_states, encoder_hidden_states_image, timestep_seq_len=ts_seq_len
        )
        if ts_seq_len is not None:
            # batch_size, seq_len, 6, inner_dim
            timestep_proj = timestep_proj.unflatten(2, (6, -1))
        else:
            # batch_size, 6, inner_dim
            timestep_proj = timestep_proj.unflatten(1, (6, -1))

        # Image tokens are prepended to the text tokens for cross-attention.
        if encoder_hidden_states_image is not None:
            encoder_hidden_states = torch.concat([encoder_hidden_states_image, encoder_hidden_states], dim=1)

        # 4. Transformer blocks
        if torch.is_grad_enabled() and self.gradient_checkpointing:
            for block in self.blocks:
                hidden_states = self._gradient_checkpointing_func(
                    block, hidden_states, encoder_hidden_states, timestep_proj, rotary_emb
                )
        else:
            for block in self.blocks:
                hidden_states = block(hidden_states, encoder_hidden_states, timestep_proj, rotary_emb)

        # 5. Output norm, projection & unpatchify
        if temb.ndim == 3:
            # batch_size, seq_len, inner_dim (wan 2.2 ti2v)
            shift, scale = (self.scale_shift_table.unsqueeze(0) + temb.unsqueeze(2)).chunk(2, dim=2)
            shift = shift.squeeze(2)
            scale = scale.squeeze(2)
        else:
            # batch_size, inner_dim
            shift, scale = (self.scale_shift_table + temb.unsqueeze(1)).chunk(2, dim=1)

        # Move the shift and scale tensors to the same device as hidden_states.
        # When using multi-GPU inference via accelerate these will be on the
        # first device rather than the last device, which hidden_states ends up
        # on.
        shift = shift.to(hidden_states.device)
        scale = scale.to(hidden_states.device)

        hidden_states = (self.norm_out(hidden_states.float()) * (1 + scale) + shift).type_as(hidden_states)
        hidden_states = self.proj_out(hidden_states)

        # Unpatchify: fold the per-patch pixels back into (B, C, F, H, W).
        hidden_states = hidden_states.reshape(
            batch_size, post_patch_num_frames, post_patch_height, post_patch_width, p_t, p_h, p_w, -1
        )
        hidden_states = hidden_states.permute(0, 7, 1, 4, 2, 5, 3, 6)
        output = hidden_states.flatten(6, 7).flatten(4, 5).flatten(2, 3)

        if USE_PEFT_BACKEND:
            # remove `lora_scale` from each PEFT layer
            unscale_lora_layers(self, lora_scale)

        if not return_dict:
            return (output,)

        return Transformer2DModelOutput(sample=output)
pythonProject/.venv/Lib/site-packages/diffusers/models/transformers/transformer_wan_vace.py ADDED
@@ -0,0 +1,389 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright 2025 The Wan Team and The HuggingFace Team. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ import math
16
+ from typing import Any, Dict, List, Optional, Tuple, Union
17
+
18
+ import torch
19
+ import torch.nn as nn
20
+
21
+ from ...configuration_utils import ConfigMixin, register_to_config
22
+ from ...loaders import FromOriginalModelMixin, PeftAdapterMixin
23
+ from ...utils import USE_PEFT_BACKEND, logging, scale_lora_layers, unscale_lora_layers
24
+ from ..attention import AttentionMixin, FeedForward
25
+ from ..cache_utils import CacheMixin
26
+ from ..modeling_outputs import Transformer2DModelOutput
27
+ from ..modeling_utils import ModelMixin
28
+ from ..normalization import FP32LayerNorm
29
+ from .transformer_wan import (
30
+ WanAttention,
31
+ WanAttnProcessor,
32
+ WanRotaryPosEmbed,
33
+ WanTimeTextImageEmbedding,
34
+ WanTransformerBlock,
35
+ )
36
+
37
+
38
+ logger = logging.get_logger(__name__) # pylint: disable=invalid-name
39
+
40
+
41
class WanVACETransformerBlock(nn.Module):
    """Wan transformer block variant used by the VACE control branch.

    Structurally identical to `WanTransformerBlock`, but it transforms the
    control stream (`control_hidden_states`) instead of the main stream and can
    additionally project the control stream in (first layer only) and project a
    conditioning "hint" out that the main transformer adds to its own states.

    Args:
        dim: Hidden size of the block.
        ffn_dim: Inner dimension of the feed-forward network.
        num_heads: Number of attention heads (head dim is ``dim // num_heads``).
        qk_norm: Query/key normalization variant. NOTE(review): accepted for
            config compatibility but not forwarded to the attention layers here.
        cross_attn_norm: When True, apply an affine FP32 LayerNorm before
            cross-attention; otherwise a no-op Identity.
        eps: Epsilon for all normalization layers.
        added_kv_proj_dim: Width of extra key/value projections for image
            conditioning in the cross-attention; ``None`` disables them.
        apply_input_projection: When True, the incoming control stream is
            linearly projected and fused with the main hidden states (layer 0).
        apply_output_projection: When True, a conditioning hint is emitted via
            an output projection.
    """

    def __init__(
        self,
        dim: int,
        ffn_dim: int,
        num_heads: int,
        qk_norm: str = "rms_norm_across_heads",
        cross_attn_norm: bool = False,
        eps: float = 1e-6,
        added_kv_proj_dim: Optional[int] = None,
        apply_input_projection: bool = False,
        apply_output_projection: bool = False,
    ):
        super().__init__()

        # 1. Input projection (only the first VACE block fuses the main stream)
        self.proj_in = None
        if apply_input_projection:
            self.proj_in = nn.Linear(dim, dim)

        # 2. Self-attention
        self.norm1 = FP32LayerNorm(dim, eps, elementwise_affine=False)
        self.attn1 = WanAttention(
            dim=dim,
            heads=num_heads,
            dim_head=dim // num_heads,
            eps=eps,
            processor=WanAttnProcessor(),
        )

        # 3. Cross-attention
        self.attn2 = WanAttention(
            dim=dim,
            heads=num_heads,
            dim_head=dim // num_heads,
            eps=eps,
            added_kv_proj_dim=added_kv_proj_dim,
            processor=WanAttnProcessor(),
        )
        self.norm2 = FP32LayerNorm(dim, eps, elementwise_affine=True) if cross_attn_norm else nn.Identity()

        # 4. Feed-forward
        self.ffn = FeedForward(dim, inner_dim=ffn_dim, activation_fn="gelu-approximate")
        self.norm3 = FP32LayerNorm(dim, eps, elementwise_affine=False)

        # 5. Output projection (produces the conditioning hint)
        self.proj_out = None
        if apply_output_projection:
            self.proj_out = nn.Linear(dim, dim)

        # Learned base for the 6 AdaLN parameters; `temb` is added on top.
        self.scale_shift_table = nn.Parameter(torch.randn(1, 6, dim) / dim**0.5)

    def forward(
        self,
        hidden_states: torch.Tensor,
        encoder_hidden_states: torch.Tensor,
        control_hidden_states: torch.Tensor,
        temb: torch.Tensor,
        rotary_emb: torch.Tensor,
    ) -> torch.Tensor:
        """Returns `(conditioning_states, control_hidden_states)`.

        `conditioning_states` is the projected hint for the main transformer
        (``None`` when no output projection is configured); the second element
        is the updated control stream passed to the next VACE block.
        """
        if self.proj_in is not None:
            # First block: project the control stream and fuse the main stream in.
            control_hidden_states = self.proj_in(control_hidden_states)
            control_hidden_states = control_hidden_states + hidden_states

        shift_msa, scale_msa, gate_msa, c_shift_msa, c_scale_msa, c_gate_msa = (
            self.scale_shift_table + temb.float()
        ).chunk(6, dim=1)

        # 1. Self-attention (modulation math in fp32 for numerical stability)
        norm_hidden_states = (self.norm1(control_hidden_states.float()) * (1 + scale_msa) + shift_msa).type_as(
            control_hidden_states
        )
        attn_output = self.attn1(norm_hidden_states, None, None, rotary_emb)
        control_hidden_states = (control_hidden_states.float() + attn_output * gate_msa).type_as(control_hidden_states)

        # 2. Cross-attention (plain, ungated residual)
        norm_hidden_states = self.norm2(control_hidden_states.float()).type_as(control_hidden_states)
        attn_output = self.attn2(norm_hidden_states, encoder_hidden_states, None, None)
        control_hidden_states = control_hidden_states + attn_output

        # 3. Feed-forward (gated residual)
        norm_hidden_states = (self.norm3(control_hidden_states.float()) * (1 + c_scale_msa) + c_shift_msa).type_as(
            control_hidden_states
        )
        ff_output = self.ffn(norm_hidden_states)
        control_hidden_states = (control_hidden_states.float() + ff_output.float() * c_gate_msa).type_as(
            control_hidden_states
        )

        conditioning_states = None
        if self.proj_out is not None:
            conditioning_states = self.proj_out(control_hidden_states)

        return conditioning_states, control_hidden_states
135
+
136
+
137
class WanVACETransformer3DModel(
    ModelMixin, ConfigMixin, PeftAdapterMixin, FromOriginalModelMixin, CacheMixin, AttentionMixin
):
    r"""
    A Transformer model for video-like data used in the Wan VACE (video editing / control) model.

    Args:
        patch_size (`Tuple[int]`, defaults to `(1, 2, 2)`):
            3D patch dimensions for video embedding (t_patch, h_patch, w_patch).
        num_attention_heads (`int`, defaults to `40`):
            The number of attention heads.
        attention_head_dim (`int`, defaults to `128`):
            The number of channels in each head.
        in_channels (`int`, defaults to `16`):
            The number of channels in the input.
        out_channels (`int`, defaults to `16`):
            The number of channels in the output.
        text_dim (`int`, defaults to `4096`):
            Input dimension for text embeddings.
        freq_dim (`int`, defaults to `256`):
            Dimension for sinusoidal time embeddings.
        ffn_dim (`int`, defaults to `13824`):
            Intermediate dimension in feed-forward network.
        num_layers (`int`, defaults to `40`):
            The number of layers of transformer blocks to use.
        cross_attn_norm (`bool`, defaults to `True`):
            Enable cross-attention normalization.
        qk_norm (`str`, *optional*, defaults to `"rms_norm_across_heads"`):
            Query/key normalization variant.
        eps (`float`, defaults to `1e-6`):
            Epsilon value for normalization layers.
        image_dim (`int`, *optional*, defaults to `None`):
            Input dimension of image encoder states; `None` disables the image branch.
        added_kv_proj_dim (`int`, *optional*, defaults to `None`):
            The number of channels to use for the added key and value projections. If `None`, no projection is used.
        rope_max_seq_len (`int`, defaults to `1024`):
            Maximum number of positions precomputed for the rotary embedding.
        pos_embed_seq_len (`int`, *optional*, defaults to `None`):
            Optional fixed sequence length for the image embedder's positional embedding.
        vace_layers (`List[int]`, defaults to `[0, 5, 10, 15, 20, 25, 30, 35]`):
            Indices of main-transformer layers that receive a VACE conditioning hint.
        vace_in_channels (`int`, defaults to `96`):
            Number of channels in the control-branch input.
    """

    _supports_gradient_checkpointing = True
    _skip_layerwise_casting_patterns = ["patch_embedding", "vace_patch_embedding", "condition_embedder", "norm"]
    _no_split_modules = ["WanTransformerBlock", "WanVACETransformerBlock"]
    _keep_in_fp32_modules = ["time_embedder", "scale_shift_table", "norm1", "norm2", "norm3"]
    _keys_to_ignore_on_load_unexpected = ["norm_added_q"]

    @register_to_config
    def __init__(
        self,
        patch_size: Tuple[int] = (1, 2, 2),
        num_attention_heads: int = 40,
        attention_head_dim: int = 128,
        in_channels: int = 16,
        out_channels: int = 16,
        text_dim: int = 4096,
        freq_dim: int = 256,
        ffn_dim: int = 13824,
        num_layers: int = 40,
        cross_attn_norm: bool = True,
        qk_norm: Optional[str] = "rms_norm_across_heads",
        eps: float = 1e-6,
        image_dim: Optional[int] = None,
        added_kv_proj_dim: Optional[int] = None,
        rope_max_seq_len: int = 1024,
        pos_embed_seq_len: Optional[int] = None,
        vace_layers: List[int] = [0, 5, 10, 15, 20, 25, 30, 35],
        vace_in_channels: int = 96,
    ) -> None:
        super().__init__()

        inner_dim = num_attention_heads * attention_head_dim
        out_channels = out_channels or in_channels

        if max(vace_layers) >= num_layers:
            raise ValueError(f"VACE layers {vace_layers} exceed the number of transformer layers {num_layers}.")
        if 0 not in vace_layers:
            raise ValueError("VACE layers must include layer 0.")

        # 1. Patch & position embedding (separate patch embedder for the control stream)
        self.rope = WanRotaryPosEmbed(attention_head_dim, patch_size, rope_max_seq_len)
        self.patch_embedding = nn.Conv3d(in_channels, inner_dim, kernel_size=patch_size, stride=patch_size)
        self.vace_patch_embedding = nn.Conv3d(vace_in_channels, inner_dim, kernel_size=patch_size, stride=patch_size)

        # 2. Condition embeddings
        # image_embedding_dim=1280 for I2V model
        self.condition_embedder = WanTimeTextImageEmbedding(
            dim=inner_dim,
            time_freq_dim=freq_dim,
            time_proj_dim=inner_dim * 6,
            text_embed_dim=text_dim,
            image_embed_dim=image_dim,
            pos_embed_seq_len=pos_embed_seq_len,
        )

        # 3. Transformer blocks
        self.blocks = nn.ModuleList(
            [
                WanTransformerBlock(
                    inner_dim, ffn_dim, num_attention_heads, qk_norm, cross_attn_norm, eps, added_kv_proj_dim
                )
                for _ in range(num_layers)
            ]
        )

        # One VACE block per entry in `vace_layers`.
        self.vace_blocks = nn.ModuleList(
            [
                WanVACETransformerBlock(
                    inner_dim,
                    ffn_dim,
                    num_attention_heads,
                    qk_norm,
                    cross_attn_norm,
                    eps,
                    added_kv_proj_dim,
                    apply_input_projection=i == 0,  # Layer 0 always has input projection and is in vace_layers
                    apply_output_projection=True,
                )
                for i in range(len(vace_layers))
            ]
        )

        # 4. Output norm & projection
        self.norm_out = FP32LayerNorm(inner_dim, eps, elementwise_affine=False)
        self.proj_out = nn.Linear(inner_dim, out_channels * math.prod(patch_size))
        # Learned base for the output shift/scale modulation (2 entries).
        self.scale_shift_table = nn.Parameter(torch.randn(1, 2, inner_dim) / inner_dim**0.5)

        self.gradient_checkpointing = False

    def forward(
        self,
        hidden_states: torch.Tensor,
        timestep: torch.LongTensor,
        encoder_hidden_states: torch.Tensor,
        encoder_hidden_states_image: Optional[torch.Tensor] = None,
        control_hidden_states: torch.Tensor = None,
        control_hidden_states_scale: torch.Tensor = None,
        return_dict: bool = True,
        attention_kwargs: Optional[Dict[str, Any]] = None,
    ) -> Union[torch.Tensor, Dict[str, torch.Tensor]]:
        """Denoise a video latent with VACE control conditioning.

        Args:
            hidden_states: Video latent of shape (batch, channels, frames, height, width).
            timestep: Per-sample diffusion timesteps.
            encoder_hidden_states: Text-encoder hidden states.
            encoder_hidden_states_image: Optional image-encoder hidden states.
            control_hidden_states: Control-branch input with `vace_in_channels` channels.
            control_hidden_states_scale: Per-VACE-layer scalar weights for the
                conditioning hints; defaults to all ones.
            return_dict: When False, return a plain tuple.
            attention_kwargs: Optional kwargs; the `scale` entry weights PEFT/LoRA layers.
        """
        if attention_kwargs is not None:
            attention_kwargs = attention_kwargs.copy()
            lora_scale = attention_kwargs.pop("scale", 1.0)
        else:
            lora_scale = 1.0

        if USE_PEFT_BACKEND:
            # weight the lora layers by setting `lora_scale` for each PEFT layer
            scale_lora_layers(self, lora_scale)
        else:
            if attention_kwargs is not None and attention_kwargs.get("scale", None) is not None:
                logger.warning(
                    "Passing `scale` via `attention_kwargs` when not using the PEFT backend is ineffective."
                )

        batch_size, num_channels, num_frames, height, width = hidden_states.shape
        p_t, p_h, p_w = self.config.patch_size
        post_patch_num_frames = num_frames // p_t
        post_patch_height = height // p_h
        post_patch_width = width // p_w

        if control_hidden_states_scale is None:
            control_hidden_states_scale = control_hidden_states.new_ones(len(self.config.vace_layers))
        control_hidden_states_scale = torch.unbind(control_hidden_states_scale)
        if len(control_hidden_states_scale) != len(self.config.vace_layers):
            raise ValueError(
                f"Length of `control_hidden_states_scale` {len(control_hidden_states_scale)} should be "
                f"equal to {len(self.config.vace_layers)}."
            )

        # 1. Rotary position embedding
        rotary_emb = self.rope(hidden_states)

        # 2. Patch embedding
        hidden_states = self.patch_embedding(hidden_states)
        hidden_states = hidden_states.flatten(2).transpose(1, 2)

        # Control stream is patchified separately and zero-padded to the main
        # stream's sequence length so the two can be added token-wise.
        control_hidden_states = self.vace_patch_embedding(control_hidden_states)
        control_hidden_states = control_hidden_states.flatten(2).transpose(1, 2)
        control_hidden_states_padding = control_hidden_states.new_zeros(
            batch_size, hidden_states.size(1) - control_hidden_states.size(1), control_hidden_states.size(2)
        )
        control_hidden_states = torch.cat([control_hidden_states, control_hidden_states_padding], dim=1)

        # 3. Time embedding
        temb, timestep_proj, encoder_hidden_states, encoder_hidden_states_image = self.condition_embedder(
            timestep, encoder_hidden_states, encoder_hidden_states_image
        )
        timestep_proj = timestep_proj.unflatten(1, (6, -1))

        # 4. Image embedding
        if encoder_hidden_states_image is not None:
            encoder_hidden_states = torch.concat([encoder_hidden_states_image, encoder_hidden_states], dim=1)

        # 5. Transformer blocks
        if torch.is_grad_enabled() and self.gradient_checkpointing:
            # Prepare VACE hints
            control_hidden_states_list = []
            for i, block in enumerate(self.vace_blocks):
                conditioning_states, control_hidden_states = self._gradient_checkpointing_func(
                    block, hidden_states, encoder_hidden_states, control_hidden_states, timestep_proj, rotary_emb
                )
                control_hidden_states_list.append((conditioning_states, control_hidden_states_scale[i]))
            # Reverse so that pop() below yields the hints in original order.
            control_hidden_states_list = control_hidden_states_list[::-1]

            for i, block in enumerate(self.blocks):
                hidden_states = self._gradient_checkpointing_func(
                    block, hidden_states, encoder_hidden_states, timestep_proj, rotary_emb
                )
                if i in self.config.vace_layers:
                    control_hint, scale = control_hidden_states_list.pop()
                    hidden_states = hidden_states + control_hint * scale
        else:
            # Prepare VACE hints
            control_hidden_states_list = []
            for i, block in enumerate(self.vace_blocks):
                conditioning_states, control_hidden_states = block(
                    hidden_states, encoder_hidden_states, control_hidden_states, timestep_proj, rotary_emb
                )
                control_hidden_states_list.append((conditioning_states, control_hidden_states_scale[i]))
            # Reverse so that pop() below yields the hints in original order.
            control_hidden_states_list = control_hidden_states_list[::-1]

            for i, block in enumerate(self.blocks):
                hidden_states = block(hidden_states, encoder_hidden_states, timestep_proj, rotary_emb)
                if i in self.config.vace_layers:
                    control_hint, scale = control_hidden_states_list.pop()
                    hidden_states = hidden_states + control_hint * scale

        # 6. Output norm, projection & unpatchify
        shift, scale = (self.scale_shift_table + temb.unsqueeze(1)).chunk(2, dim=1)

        # Move the shift and scale tensors to the same device as hidden_states.
        # When using multi-GPU inference via accelerate these will be on the
        # first device rather than the last device, which hidden_states ends up
        # on.
        shift = shift.to(hidden_states.device)
        scale = scale.to(hidden_states.device)

        hidden_states = (self.norm_out(hidden_states.float()) * (1 + scale) + shift).type_as(hidden_states)
        hidden_states = self.proj_out(hidden_states)

        # Unpatchify: fold the per-patch pixels back into (B, C, F, H, W).
        hidden_states = hidden_states.reshape(
            batch_size, post_patch_num_frames, post_patch_height, post_patch_width, p_t, p_h, p_w, -1
        )
        hidden_states = hidden_states.permute(0, 7, 1, 4, 2, 5, 3, 6)
        output = hidden_states.flatten(6, 7).flatten(4, 5).flatten(2, 3)

        if USE_PEFT_BACKEND:
            # remove `lora_scale` from each PEFT layer
            unscale_lora_layers(self, lora_scale)

        if not return_dict:
            return (output,)

        return Transformer2DModelOutput(sample=output)