Fabrice-TIERCELIN committed on
Commit f2c1d7e · verified · 1 Parent(s): 88027fa

Upload 4 files

ltx_video/models/transformers/attention.py ADDED
@@ -0,0 +1,1265 @@
1
+ import inspect
2
+ from importlib import import_module
3
+ from typing import Any, Dict, Optional, Tuple
4
+
5
+ import torch
6
+ import torch.nn.functional as F
7
+ from diffusers.models.activations import GEGLU, GELU, ApproximateGELU
8
+ from diffusers.models.attention import _chunked_feed_forward
9
+ from diffusers.models.attention_processor import (
10
+ LoRAAttnAddedKVProcessor,
11
+ LoRAAttnProcessor,
12
+ LoRAAttnProcessor2_0,
13
+ LoRAXFormersAttnProcessor,
14
+ SpatialNorm,
15
+ )
16
+ from diffusers.models.lora import LoRACompatibleLinear
17
+ from diffusers.models.normalization import RMSNorm
18
+ from diffusers.utils import deprecate, logging
19
+ from diffusers.utils.torch_utils import maybe_allow_in_graph
20
+ from einops import rearrange
21
+ from torch import nn
22
+
23
+ from ltx_video.utils.skip_layer_strategy import SkipLayerStrategy
24
+
25
+ try:
26
+ from torch_xla.experimental.custom_kernel import flash_attention
27
+ except ImportError:
28
+ # workaround for automatic tests. Currently this function is manually patched
29
+ # into the torch_xla lib during container setup
30
+ pass
31
+
32
+ # code adapted from https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention.py
33
+
34
+ logger = logging.get_logger(__name__)
35
+
36
+
37
+ @maybe_allow_in_graph
38
+ class BasicTransformerBlock(nn.Module):
39
+ r"""
40
+ A basic Transformer block.
41
+
42
+ Parameters:
43
+ dim (`int`): The number of channels in the input and output.
44
+ num_attention_heads (`int`): The number of heads to use for multi-head attention.
45
+ attention_head_dim (`int`): The number of channels in each head.
46
+ dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use.
47
+ cross_attention_dim (`int`, *optional*): The size of the encoder_hidden_states vector for cross attention.
48
+ activation_fn (`str`, *optional*, defaults to `"geglu"`): Activation function to be used in feed-forward.
49
+ num_embeds_ada_norm (:
50
+ obj: `int`, *optional*): The number of diffusion steps used during training. See `Transformer2DModel`.
51
+ attention_bias (:
52
+ obj: `bool`, *optional*, defaults to `False`): Configure if the attentions should contain a bias parameter.
53
+ only_cross_attention (`bool`, *optional*):
54
+ Whether to use only cross-attention layers. In this case two cross attention layers are used.
55
+ double_self_attention (`bool`, *optional*):
56
+ Whether to use two self-attention layers. In this case no cross attention layers are used.
57
+ upcast_attention (`bool`, *optional*):
58
+ Whether to upcast the attention computation to float32. This is useful for mixed precision training.
59
+ norm_elementwise_affine (`bool`, *optional*, defaults to `True`):
60
+ Whether to use learnable elementwise affine parameters for normalization.
61
+ qk_norm (`str`, *optional*, defaults to None):
62
+ Set to 'layer_norm' or `rms_norm` to perform query and key normalization.
63
+ adaptive_norm (`str`, *optional*, defaults to `"single_scale_shift"`):
64
+ The type of adaptive norm to use. Can be `"single_scale_shift"`, `"single_scale"` or "none".
65
+ standardization_norm (`str`, *optional*, defaults to `"layer_norm"`):
66
+ The type of pre-normalization to use. Can be `"layer_norm"` or `"rms_norm"`.
67
+ final_dropout (`bool` *optional*, defaults to False):
68
+ Whether to apply a final dropout after the last feed-forward layer.
69
+ attention_type (`str`, *optional*, defaults to `"default"`):
70
+ The type of attention to use. Can be `"default"` or `"gated"` or `"gated-text-image"`.
71
+ positional_embeddings (`str`, *optional*, defaults to `None`):
72
+ The type of positional embeddings to apply.
73
+ num_positional_embeddings (`int`, *optional*, defaults to `None`):
74
+ The maximum number of positional embeddings to apply.
75
+ """
76
+
77
+ def __init__(
78
+ self,
79
+ dim: int,
80
+ num_attention_heads: int,
81
+ attention_head_dim: int,
82
+ dropout=0.0,
83
+ cross_attention_dim: Optional[int] = None,
84
+ activation_fn: str = "geglu",
85
+ num_embeds_ada_norm: Optional[int] = None, # pylint: disable=unused-argument
86
+ attention_bias: bool = False,
87
+ only_cross_attention: bool = False,
88
+ double_self_attention: bool = False,
89
+ upcast_attention: bool = False,
90
+ norm_elementwise_affine: bool = True,
91
+ adaptive_norm: str = "single_scale_shift", # 'single_scale_shift', 'single_scale' or 'none'
92
+ standardization_norm: str = "layer_norm", # 'layer_norm' or 'rms_norm'
93
+ norm_eps: float = 1e-5,
94
+ qk_norm: Optional[str] = None,
95
+ final_dropout: bool = False,
96
+ attention_type: str = "default", # pylint: disable=unused-argument
97
+ ff_inner_dim: Optional[int] = None,
98
+ ff_bias: bool = True,
99
+ attention_out_bias: bool = True,
100
+ use_tpu_flash_attention: bool = False,
101
+ use_rope: bool = False,
102
+ ):
103
+ super().__init__()
104
+ self.only_cross_attention = only_cross_attention
105
+ self.use_tpu_flash_attention = use_tpu_flash_attention
106
+ self.adaptive_norm = adaptive_norm
107
+
108
+ assert standardization_norm in ["layer_norm", "rms_norm"]
109
+ assert adaptive_norm in ["single_scale_shift", "single_scale", "none"]
110
+
111
+ make_norm_layer = (
112
+ nn.LayerNorm if standardization_norm == "layer_norm" else RMSNorm
113
+ )
114
+
115
+ # Define 3 blocks. Each block has its own normalization layer.
116
+ # 1. Self-Attn
117
+ self.norm1 = make_norm_layer(
118
+ dim, elementwise_affine=norm_elementwise_affine, eps=norm_eps
119
+ )
120
+
121
+ self.attn1 = Attention(
122
+ query_dim=dim,
123
+ heads=num_attention_heads,
124
+ dim_head=attention_head_dim,
125
+ dropout=dropout,
126
+ bias=attention_bias,
127
+ cross_attention_dim=cross_attention_dim if only_cross_attention else None,
128
+ upcast_attention=upcast_attention,
129
+ out_bias=attention_out_bias,
130
+ use_tpu_flash_attention=use_tpu_flash_attention,
131
+ qk_norm=qk_norm,
132
+ use_rope=use_rope,
133
+ )
134
+
135
+ # 2. Cross-Attn
136
+ if cross_attention_dim is not None or double_self_attention:
137
+ self.attn2 = Attention(
138
+ query_dim=dim,
139
+ cross_attention_dim=(
140
+ cross_attention_dim if not double_self_attention else None
141
+ ),
142
+ heads=num_attention_heads,
143
+ dim_head=attention_head_dim,
144
+ dropout=dropout,
145
+ bias=attention_bias,
146
+ upcast_attention=upcast_attention,
147
+ out_bias=attention_out_bias,
148
+ use_tpu_flash_attention=use_tpu_flash_attention,
149
+ qk_norm=qk_norm,
150
+ use_rope=use_rope,
151
+ ) # is self-attn if encoder_hidden_states is none
152
+
153
+ if adaptive_norm == "none":
154
+ self.attn2_norm = make_norm_layer(
155
+ dim, norm_eps, norm_elementwise_affine
156
+ )
157
+ else:
158
+ self.attn2 = None
159
+ self.attn2_norm = None
160
+
161
+ self.norm2 = make_norm_layer(dim, norm_eps, norm_elementwise_affine)
162
+
163
+ # 3. Feed-forward
164
+ self.ff = FeedForward(
165
+ dim,
166
+ dropout=dropout,
167
+ activation_fn=activation_fn,
168
+ final_dropout=final_dropout,
169
+ inner_dim=ff_inner_dim,
170
+ bias=ff_bias,
171
+ )
172
+
173
+ # 5. Scale-shift for PixArt-Alpha.
174
+ if adaptive_norm != "none":
175
+ num_ada_params = 4 if adaptive_norm == "single_scale" else 6
176
+ self.scale_shift_table = nn.Parameter(
177
+ torch.randn(num_ada_params, dim) / dim**0.5
178
+ )
179
+
180
+ # let chunk size default to None
181
+ self._chunk_size = None
182
+ self._chunk_dim = 0
183
+
184
+ def set_use_tpu_flash_attention(self):
185
+ r"""
186
+ Sets the flag on this object and propagates it down to the children. The flag enforces the use of the TPU
187
+ attention kernel.
188
+ """
189
+ self.use_tpu_flash_attention = True
190
+ self.attn1.set_use_tpu_flash_attention()
191
+ self.attn2.set_use_tpu_flash_attention()
192
+
193
+ def set_chunk_feed_forward(self, chunk_size: Optional[int], dim: int = 0):
194
+ # Sets chunk feed-forward
195
+ self._chunk_size = chunk_size
196
+ self._chunk_dim = dim
197
+
198
+ def forward(
199
+ self,
200
+ hidden_states: torch.FloatTensor,
201
+ freqs_cis: Optional[Tuple[torch.FloatTensor, torch.FloatTensor]] = None,
202
+ attention_mask: Optional[torch.FloatTensor] = None,
203
+ encoder_hidden_states: Optional[torch.FloatTensor] = None,
204
+ encoder_attention_mask: Optional[torch.FloatTensor] = None,
205
+ timestep: Optional[torch.LongTensor] = None,
206
+ cross_attention_kwargs: Dict[str, Any] = None,
207
+ class_labels: Optional[torch.LongTensor] = None,
208
+ added_cond_kwargs: Optional[Dict[str, torch.Tensor]] = None,
209
+ skip_layer_mask: Optional[torch.Tensor] = None,
210
+ skip_layer_strategy: Optional[SkipLayerStrategy] = None,
211
+ ) -> torch.FloatTensor:
212
+ if cross_attention_kwargs is not None:
213
+ if cross_attention_kwargs.get("scale", None) is not None:
214
+ logger.warning(
215
+ "Passing `scale` to `cross_attention_kwargs` is deprecated. `scale` will be ignored."
216
+ )
217
+
218
+ # Notice that normalization is always applied before the real computation in the following blocks.
219
+ # 0. Self-Attention
220
+ batch_size = hidden_states.shape[0]
221
+
222
+ original_hidden_states = hidden_states
223
+
224
+ norm_hidden_states = self.norm1(hidden_states)
225
+
226
+ # Apply ada_norm_single
227
+ if self.adaptive_norm in ["single_scale_shift", "single_scale"]:
228
+ assert timestep.ndim == 3 # [batch, 1 or num_tokens, embedding_dim]
229
+ num_ada_params = self.scale_shift_table.shape[0]
230
+ ada_values = self.scale_shift_table[None, None] + timestep.reshape(
231
+ batch_size, timestep.shape[1], num_ada_params, -1
232
+ )
233
+ if self.adaptive_norm == "single_scale_shift":
234
+ shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = (
235
+ ada_values.unbind(dim=2)
236
+ )
237
+ norm_hidden_states = norm_hidden_states * (1 + scale_msa) + shift_msa
238
+ else:
239
+ scale_msa, gate_msa, scale_mlp, gate_mlp = ada_values.unbind(dim=2)
240
+ norm_hidden_states = norm_hidden_states * (1 + scale_msa)
241
+ elif self.adaptive_norm == "none":
242
+ scale_msa, gate_msa, scale_mlp, gate_mlp = None, None, None, None
243
+ else:
244
+ raise ValueError(f"Unknown adaptive norm type: {self.adaptive_norm}")
245
+
246
+ norm_hidden_states = norm_hidden_states.squeeze(
247
+ 1
248
+ ) # TODO: Check if this is needed
249
+
250
+ # 1. Prepare GLIGEN inputs
251
+ cross_attention_kwargs = (
252
+ cross_attention_kwargs.copy() if cross_attention_kwargs is not None else {}
253
+ )
254
+
255
+ attn_output = self.attn1(
256
+ norm_hidden_states,
257
+ freqs_cis=freqs_cis,
258
+ encoder_hidden_states=(
259
+ encoder_hidden_states if self.only_cross_attention else None
260
+ ),
261
+ attention_mask=attention_mask,
262
+ skip_layer_mask=skip_layer_mask,
263
+ skip_layer_strategy=skip_layer_strategy,
264
+ **cross_attention_kwargs,
265
+ )
266
+ if gate_msa is not None:
267
+ attn_output = gate_msa * attn_output
268
+
269
+ hidden_states = attn_output + hidden_states
270
+ if hidden_states.ndim == 4:
271
+ hidden_states = hidden_states.squeeze(1)
272
+
273
+ # 3. Cross-Attention
274
+ if self.attn2 is not None:
275
+ if self.adaptive_norm == "none":
276
+ attn_input = self.attn2_norm(hidden_states)
277
+ else:
278
+ attn_input = hidden_states
279
+ attn_output = self.attn2(
280
+ attn_input,
281
+ freqs_cis=freqs_cis,
282
+ encoder_hidden_states=encoder_hidden_states,
283
+ attention_mask=encoder_attention_mask,
284
+ **cross_attention_kwargs,
285
+ )
286
+ hidden_states = attn_output + hidden_states
287
+
288
+ # 4. Feed-forward
289
+ norm_hidden_states = self.norm2(hidden_states)
290
+ if self.adaptive_norm == "single_scale_shift":
291
+ norm_hidden_states = norm_hidden_states * (1 + scale_mlp) + shift_mlp
292
+ elif self.adaptive_norm == "single_scale":
293
+ norm_hidden_states = norm_hidden_states * (1 + scale_mlp)
294
+ elif self.adaptive_norm == "none":
295
+ pass
296
+ else:
297
+ raise ValueError(f"Unknown adaptive norm type: {self.adaptive_norm}")
298
+
299
+ if self._chunk_size is not None:
300
+ # "feed_forward_chunk_size" can be used to save memory
301
+ ff_output = _chunked_feed_forward(
302
+ self.ff, norm_hidden_states, self._chunk_dim, self._chunk_size
303
+ )
304
+ else:
305
+ ff_output = self.ff(norm_hidden_states)
306
+ if gate_mlp is not None:
307
+ ff_output = gate_mlp * ff_output
308
+
309
+ hidden_states = ff_output + hidden_states
310
+ if hidden_states.ndim == 4:
311
+ hidden_states = hidden_states.squeeze(1)
312
+
313
+ if (
314
+ skip_layer_mask is not None
315
+ and skip_layer_strategy == SkipLayerStrategy.TransformerBlock
316
+ ):
317
+ skip_layer_mask = skip_layer_mask.view(-1, 1, 1)
318
+ hidden_states = hidden_states * skip_layer_mask + original_hidden_states * (
319
+ 1.0 - skip_layer_mask
320
+ )
321
+
322
+ return hidden_states
323
+
324
+
325
+ @maybe_allow_in_graph
326
+ class Attention(nn.Module):
327
+ r"""
328
+ A cross attention layer.
329
+
330
+ Parameters:
331
+ query_dim (`int`):
332
+ The number of channels in the query.
333
+ cross_attention_dim (`int`, *optional*):
334
+ The number of channels in the encoder_hidden_states. If not given, defaults to `query_dim`.
335
+ heads (`int`, *optional*, defaults to 8):
336
+ The number of heads to use for multi-head attention.
337
+ dim_head (`int`, *optional*, defaults to 64):
338
+ The number of channels in each head.
339
+ dropout (`float`, *optional*, defaults to 0.0):
340
+ The dropout probability to use.
341
+ bias (`bool`, *optional*, defaults to False):
342
+ Set to `True` for the query, key, and value linear layers to contain a bias parameter.
343
+ upcast_attention (`bool`, *optional*, defaults to False):
344
+ Set to `True` to upcast the attention computation to `float32`.
345
+ upcast_softmax (`bool`, *optional*, defaults to False):
346
+ Set to `True` to upcast the softmax computation to `float32`.
347
+ cross_attention_norm (`str`, *optional*, defaults to `None`):
348
+ The type of normalization to use for the cross attention. Can be `None`, `layer_norm`, or `group_norm`.
349
+ cross_attention_norm_num_groups (`int`, *optional*, defaults to 32):
350
+ The number of groups to use for the group norm in the cross attention.
351
+ added_kv_proj_dim (`int`, *optional*, defaults to `None`):
352
+ The number of channels to use for the added key and value projections. If `None`, no projection is used.
353
+ norm_num_groups (`int`, *optional*, defaults to `None`):
354
+ The number of groups to use for the group norm in the attention.
355
+ spatial_norm_dim (`int`, *optional*, defaults to `None`):
356
+ The number of channels to use for the spatial normalization.
357
+ out_bias (`bool`, *optional*, defaults to `True`):
358
+ Set to `True` to use a bias in the output linear layer.
359
+ scale_qk (`bool`, *optional*, defaults to `True`):
360
+ Set to `True` to scale the query and key by `1 / sqrt(dim_head)`.
361
+ qk_norm (`str`, *optional*, defaults to None):
362
+ Set to 'layer_norm' or `rms_norm` to perform query and key normalization.
363
+ only_cross_attention (`bool`, *optional*, defaults to `False`):
364
+ Set to `True` to only use cross attention and not added_kv_proj_dim. Can only be set to `True` if
365
+ `added_kv_proj_dim` is not `None`.
366
+ eps (`float`, *optional*, defaults to 1e-5):
367
+ An additional value added to the denominator in group normalization that is used for numerical stability.
368
+ rescale_output_factor (`float`, *optional*, defaults to 1.0):
369
+ A factor to rescale the output by dividing it with this value.
370
+ residual_connection (`bool`, *optional*, defaults to `False`):
371
+ Set to `True` to add the residual connection to the output.
372
+ _from_deprecated_attn_block (`bool`, *optional*, defaults to `False`):
373
+ Set to `True` if the attention block is loaded from a deprecated state dict.
374
+ processor (`AttnProcessor`, *optional*, defaults to `None`):
375
+ The attention processor to use. If `None`, defaults to `AttnProcessor2_0` if `torch 2.x` is used and
376
+ `AttnProcessor` otherwise.
377
+ """
378
+
379
+ def __init__(
380
+ self,
381
+ query_dim: int,
382
+ cross_attention_dim: Optional[int] = None,
383
+ heads: int = 8,
384
+ dim_head: int = 64,
385
+ dropout: float = 0.0,
386
+ bias: bool = False,
387
+ upcast_attention: bool = False,
388
+ upcast_softmax: bool = False,
389
+ cross_attention_norm: Optional[str] = None,
390
+ cross_attention_norm_num_groups: int = 32,
391
+ added_kv_proj_dim: Optional[int] = None,
392
+ norm_num_groups: Optional[int] = None,
393
+ spatial_norm_dim: Optional[int] = None,
394
+ out_bias: bool = True,
395
+ scale_qk: bool = True,
396
+ qk_norm: Optional[str] = None,
397
+ only_cross_attention: bool = False,
398
+ eps: float = 1e-5,
399
+ rescale_output_factor: float = 1.0,
400
+ residual_connection: bool = False,
401
+ _from_deprecated_attn_block: bool = False,
402
+ processor: Optional["AttnProcessor"] = None,
403
+ out_dim: int = None,
404
+ use_tpu_flash_attention: bool = False,
405
+ use_rope: bool = False,
406
+ ):
407
+ super().__init__()
408
+ self.inner_dim = out_dim if out_dim is not None else dim_head * heads
409
+ self.query_dim = query_dim
410
+ self.use_bias = bias
411
+ self.is_cross_attention = cross_attention_dim is not None
412
+ self.cross_attention_dim = (
413
+ cross_attention_dim if cross_attention_dim is not None else query_dim
414
+ )
415
+ self.upcast_attention = upcast_attention
416
+ self.upcast_softmax = upcast_softmax
417
+ self.rescale_output_factor = rescale_output_factor
418
+ self.residual_connection = residual_connection
419
+ self.dropout = dropout
420
+ self.fused_projections = False
421
+ self.out_dim = out_dim if out_dim is not None else query_dim
422
+ self.use_tpu_flash_attention = use_tpu_flash_attention
423
+ self.use_rope = use_rope
424
+
425
+ # we make use of this private variable to know whether this class is loaded
426
+ # with a deprecated state dict so that we can convert it on the fly
427
+ self._from_deprecated_attn_block = _from_deprecated_attn_block
428
+
429
+ self.scale_qk = scale_qk
430
+ self.scale = dim_head**-0.5 if self.scale_qk else 1.0
431
+
432
+ if qk_norm is None:
433
+ self.q_norm = nn.Identity()
434
+ self.k_norm = nn.Identity()
435
+ elif qk_norm == "rms_norm":
436
+ self.q_norm = RMSNorm(dim_head * heads, eps=1e-5)
437
+ self.k_norm = RMSNorm(dim_head * heads, eps=1e-5)
438
+ elif qk_norm == "layer_norm":
439
+ self.q_norm = nn.LayerNorm(dim_head * heads, eps=1e-5)
440
+ self.k_norm = nn.LayerNorm(dim_head * heads, eps=1e-5)
441
+ else:
442
+ raise ValueError(f"Unsupported qk_norm method: {qk_norm}")
443
+
444
+ self.heads = out_dim // dim_head if out_dim is not None else heads
445
+ # for slice_size > 0 the attention score computation
446
+ # is split across the batch axis to save memory
447
+ # You can set slice_size with `set_attention_slice`
448
+ self.sliceable_head_dim = heads
449
+
450
+ self.added_kv_proj_dim = added_kv_proj_dim
451
+ self.only_cross_attention = only_cross_attention
452
+
453
+ if self.added_kv_proj_dim is None and self.only_cross_attention:
454
+ raise ValueError(
455
+ "`only_cross_attention` can only be set to True if `added_kv_proj_dim` is not None. Make sure to set either `only_cross_attention=False` or define `added_kv_proj_dim`."
456
+ )
457
+
458
+ if norm_num_groups is not None:
459
+ self.group_norm = nn.GroupNorm(
460
+ num_channels=query_dim, num_groups=norm_num_groups, eps=eps, affine=True
461
+ )
462
+ else:
463
+ self.group_norm = None
464
+
465
+ if spatial_norm_dim is not None:
466
+ self.spatial_norm = SpatialNorm(
467
+ f_channels=query_dim, zq_channels=spatial_norm_dim
468
+ )
469
+ else:
470
+ self.spatial_norm = None
471
+
472
+ if cross_attention_norm is None:
473
+ self.norm_cross = None
474
+ elif cross_attention_norm == "layer_norm":
475
+ self.norm_cross = nn.LayerNorm(self.cross_attention_dim)
476
+ elif cross_attention_norm == "group_norm":
477
+ if self.added_kv_proj_dim is not None:
478
+ # The given `encoder_hidden_states` are initially of shape
479
+ # (batch_size, seq_len, added_kv_proj_dim) before being projected
480
+ # to (batch_size, seq_len, cross_attention_dim). The norm is applied
481
+ # before the projection, so we need to use `added_kv_proj_dim` as
482
+ # the number of channels for the group norm.
483
+ norm_cross_num_channels = added_kv_proj_dim
484
+ else:
485
+ norm_cross_num_channels = self.cross_attention_dim
486
+
487
+ self.norm_cross = nn.GroupNorm(
488
+ num_channels=norm_cross_num_channels,
489
+ num_groups=cross_attention_norm_num_groups,
490
+ eps=1e-5,
491
+ affine=True,
492
+ )
493
+ else:
494
+ raise ValueError(
495
+ f"unknown cross_attention_norm: {cross_attention_norm}. Should be None, 'layer_norm' or 'group_norm'"
496
+ )
497
+
498
+ linear_cls = nn.Linear
499
+
500
+ self.linear_cls = linear_cls
501
+ self.to_q = linear_cls(query_dim, self.inner_dim, bias=bias)
502
+
503
+ if not self.only_cross_attention:
504
+ # only relevant for the `AddedKVProcessor` classes
505
+ self.to_k = linear_cls(self.cross_attention_dim, self.inner_dim, bias=bias)
506
+ self.to_v = linear_cls(self.cross_attention_dim, self.inner_dim, bias=bias)
507
+ else:
508
+ self.to_k = None
509
+ self.to_v = None
510
+
511
+ if self.added_kv_proj_dim is not None:
512
+ self.add_k_proj = linear_cls(added_kv_proj_dim, self.inner_dim)
513
+ self.add_v_proj = linear_cls(added_kv_proj_dim, self.inner_dim)
514
+
515
+ self.to_out = nn.ModuleList([])
516
+ self.to_out.append(linear_cls(self.inner_dim, self.out_dim, bias=out_bias))
517
+ self.to_out.append(nn.Dropout(dropout))
518
+
519
+ # set attention processor
520
+ # We use the AttnProcessor2_0 by default when torch 2.x is used which uses
521
+ # torch.nn.functional.scaled_dot_product_attention for native Flash/memory_efficient_attention
522
+ # but only if it has the default `scale` argument. TODO remove scale_qk check when we move to torch 2.1
523
+ if processor is None:
524
+ processor = AttnProcessor2_0()
525
+ self.set_processor(processor)
526
+
527
+ def set_use_tpu_flash_attention(self):
528
+ r"""
529
+ Sets the flag on this object. The flag enforces the use of the TPU attention kernel.
530
+ """
531
+ self.use_tpu_flash_attention = True
532
+
533
+ def set_processor(self, processor: "AttnProcessor") -> None:
534
+ r"""
535
+ Set the attention processor to use.
536
+
537
+ Args:
538
+ processor (`AttnProcessor`):
539
+ The attention processor to use.
540
+ """
541
+ # if current processor is in `self._modules` and if passed `processor` is not, we need to
542
+ # pop `processor` from `self._modules`
543
+ if (
544
+ hasattr(self, "processor")
545
+ and isinstance(self.processor, torch.nn.Module)
546
+ and not isinstance(processor, torch.nn.Module)
547
+ ):
548
+ logger.info(
549
+ f"You are removing possibly trained weights of {self.processor} with {processor}"
550
+ )
551
+ self._modules.pop("processor")
552
+
553
+ self.processor = processor
554
+
555
+ def get_processor(
556
+ self, return_deprecated_lora: bool = False
557
+ ) -> "AttentionProcessor": # noqa: F821
558
+ r"""
559
+ Get the attention processor in use.
560
+
561
+ Args:
562
+ return_deprecated_lora (`bool`, *optional*, defaults to `False`):
563
+ Set to `True` to return the deprecated LoRA attention processor.
564
+
565
+ Returns:
566
+ "AttentionProcessor": The attention processor in use.
567
+ """
568
+ if not return_deprecated_lora:
569
+ return self.processor
570
+
571
+ # TODO(Sayak, Patrick). The rest of the function is needed to ensure backwards compatible
572
+ # serialization format for LoRA Attention Processors. It should be deleted once the integration
573
+ # with PEFT is completed.
574
+ is_lora_activated = {
575
+ name: module.lora_layer is not None
576
+ for name, module in self.named_modules()
577
+ if hasattr(module, "lora_layer")
578
+ }
579
+
580
+ # 1. if no layer has a LoRA activated we can return the processor as usual
581
+ if not any(is_lora_activated.values()):
582
+ return self.processor
583
+
584
+ # LoRA is not applied to `add_k_proj` or `add_v_proj`, so ignore them here
585
+ is_lora_activated.pop("add_k_proj", None)
586
+ is_lora_activated.pop("add_v_proj", None)
587
+ # 2. else it is not possible that only some layers have LoRA activated
588
+ if not all(is_lora_activated.values()):
589
+ raise ValueError(
590
+ f"Make sure that either all layers or no layers have LoRA activated, but have {is_lora_activated}"
591
+ )
592
+
593
+ # 3. And we need to merge the current LoRA layers into the corresponding LoRA attention processor
594
+ non_lora_processor_cls_name = self.processor.__class__.__name__
595
+ lora_processor_cls = getattr(
596
+ import_module(__name__), "LoRA" + non_lora_processor_cls_name
597
+ )
598
+
599
+ hidden_size = self.inner_dim
600
+
601
+ # now create a LoRA attention processor from the LoRA layers
602
+ if lora_processor_cls in [
603
+ LoRAAttnProcessor,
604
+ LoRAAttnProcessor2_0,
605
+ LoRAXFormersAttnProcessor,
606
+ ]:
607
+ kwargs = {
608
+ "cross_attention_dim": self.cross_attention_dim,
609
+ "rank": self.to_q.lora_layer.rank,
610
+ "network_alpha": self.to_q.lora_layer.network_alpha,
611
+ "q_rank": self.to_q.lora_layer.rank,
612
+ "q_hidden_size": self.to_q.lora_layer.out_features,
613
+ "k_rank": self.to_k.lora_layer.rank,
614
+ "k_hidden_size": self.to_k.lora_layer.out_features,
615
+ "v_rank": self.to_v.lora_layer.rank,
616
+ "v_hidden_size": self.to_v.lora_layer.out_features,
617
+ "out_rank": self.to_out[0].lora_layer.rank,
618
+ "out_hidden_size": self.to_out[0].lora_layer.out_features,
619
+ }
620
+
621
+ if hasattr(self.processor, "attention_op"):
622
+ kwargs["attention_op"] = self.processor.attention_op
623
+
624
+ lora_processor = lora_processor_cls(hidden_size, **kwargs)
625
+ lora_processor.to_q_lora.load_state_dict(self.to_q.lora_layer.state_dict())
626
+ lora_processor.to_k_lora.load_state_dict(self.to_k.lora_layer.state_dict())
627
+ lora_processor.to_v_lora.load_state_dict(self.to_v.lora_layer.state_dict())
628
+ lora_processor.to_out_lora.load_state_dict(
629
+ self.to_out[0].lora_layer.state_dict()
630
+ )
631
+ elif lora_processor_cls == LoRAAttnAddedKVProcessor:
632
+ lora_processor = lora_processor_cls(
633
+ hidden_size,
634
+ cross_attention_dim=self.add_k_proj.weight.shape[0],
635
+ rank=self.to_q.lora_layer.rank,
636
+ network_alpha=self.to_q.lora_layer.network_alpha,
637
+ )
638
+ lora_processor.to_q_lora.load_state_dict(self.to_q.lora_layer.state_dict())
639
+ lora_processor.to_k_lora.load_state_dict(self.to_k.lora_layer.state_dict())
640
+ lora_processor.to_v_lora.load_state_dict(self.to_v.lora_layer.state_dict())
641
+ lora_processor.to_out_lora.load_state_dict(
642
+ self.to_out[0].lora_layer.state_dict()
643
+ )
644
+
645
+ # only save if used
646
+ if self.add_k_proj.lora_layer is not None:
647
+ lora_processor.add_k_proj_lora.load_state_dict(
648
+ self.add_k_proj.lora_layer.state_dict()
649
+ )
650
+ lora_processor.add_v_proj_lora.load_state_dict(
651
+ self.add_v_proj.lora_layer.state_dict()
652
+ )
653
+ else:
654
+ lora_processor.add_k_proj_lora = None
655
+ lora_processor.add_v_proj_lora = None
656
+ else:
657
+ raise ValueError(f"{lora_processor_cls} does not exist.")
658
+
659
+ return lora_processor
660
+
661
+ def forward(
662
+ self,
663
+ hidden_states: torch.FloatTensor,
664
+ freqs_cis: Optional[Tuple[torch.FloatTensor, torch.FloatTensor]] = None,
665
+ encoder_hidden_states: Optional[torch.FloatTensor] = None,
666
+ attention_mask: Optional[torch.FloatTensor] = None,
667
+ skip_layer_mask: Optional[torch.Tensor] = None,
668
+ skip_layer_strategy: Optional[SkipLayerStrategy] = None,
669
+ **cross_attention_kwargs,
670
+ ) -> torch.Tensor:
671
+ r"""
672
+ The forward method of the `Attention` class.
673
+
674
+ Args:
675
+ hidden_states (`torch.Tensor`):
676
+ The hidden states of the query.
677
+ encoder_hidden_states (`torch.Tensor`, *optional*):
678
+ The hidden states of the encoder.
679
+ attention_mask (`torch.Tensor`, *optional*):
680
+ The attention mask to use. If `None`, no mask is applied.
681
+ skip_layer_mask (`torch.Tensor`, *optional*):
682
+ The skip layer mask to use. If `None`, no mask is applied.
683
+ skip_layer_strategy (`SkipLayerStrategy`, *optional*, defaults to `None`):
684
+ Controls which layers to skip for spatiotemporal guidance.
685
+ **cross_attention_kwargs:
686
+ Additional keyword arguments to pass along to the cross attention.
687
+
688
+ Returns:
689
+ `torch.Tensor`: The output of the attention layer.
690
+ """
691
+ # The `Attention` class can call different attention processors / attention functions
692
+ # here we simply pass along all tensors to the selected processor class
693
+ # For standard processors that are defined here, `**cross_attention_kwargs` is empty
694
+
695
+ attn_parameters = set(
696
+ inspect.signature(self.processor.__call__).parameters.keys()
697
+ )
698
+ unused_kwargs = [
699
+ k for k, _ in cross_attention_kwargs.items() if k not in attn_parameters
700
+ ]
701
+ if len(unused_kwargs) > 0:
702
+ logger.warning(
703
+ f"cross_attention_kwargs {unused_kwargs} are not expected by"
704
+ f" {self.processor.__class__.__name__} and will be ignored."
705
+ )
706
+ cross_attention_kwargs = {
707
+ k: w for k, w in cross_attention_kwargs.items() if k in attn_parameters
708
+ }
709
+
710
+ return self.processor(
711
+ self,
712
+ hidden_states,
713
+ freqs_cis=freqs_cis,
714
+ encoder_hidden_states=encoder_hidden_states,
715
+ attention_mask=attention_mask,
716
+ skip_layer_mask=skip_layer_mask,
717
+ skip_layer_strategy=skip_layer_strategy,
718
+ **cross_attention_kwargs,
719
+ )
720
+
721
+ def batch_to_head_dim(self, tensor: torch.Tensor) -> torch.Tensor:
722
+ r"""
723
+ Reshape the tensor from `[batch_size, seq_len, dim]` to `[batch_size // heads, seq_len, dim * heads]`. `heads`
724
+ is the number of heads initialized while constructing the `Attention` class.
725
+
726
+ Args:
727
+ tensor (`torch.Tensor`): The tensor to reshape.
728
+
729
+ Returns:
730
+ `torch.Tensor`: The reshaped tensor.
731
+ """
732
+ head_size = self.heads
733
+ batch_size, seq_len, dim = tensor.shape
734
+ tensor = tensor.reshape(batch_size // head_size, head_size, seq_len, dim)
735
+ tensor = tensor.permute(0, 2, 1, 3).reshape(
736
+ batch_size // head_size, seq_len, dim * head_size
737
+ )
738
+ return tensor
739
+
740
+ def head_to_batch_dim(self, tensor: torch.Tensor, out_dim: int = 3) -> torch.Tensor:
741
+ r"""
742
+ Reshape the tensor from `[batch_size, seq_len, dim]` to `[batch_size, seq_len, heads, dim // heads]` `heads` is
743
+ the number of heads initialized while constructing the `Attention` class.
744
+
745
+ Args:
746
+ tensor (`torch.Tensor`): The tensor to reshape.
747
+ out_dim (`int`, *optional*, defaults to `3`): The output dimension of the tensor. If `3`, the tensor is
748
+ reshaped to `[batch_size * heads, seq_len, dim // heads]`.
749
+
750
+ Returns:
751
+ `torch.Tensor`: The reshaped tensor.
752
+ """
753
+
754
+ head_size = self.heads
755
+ if tensor.ndim == 3:
756
+ batch_size, seq_len, dim = tensor.shape
757
+ extra_dim = 1
758
+ else:
759
+ batch_size, extra_dim, seq_len, dim = tensor.shape
760
+ tensor = tensor.reshape(
761
+ batch_size, seq_len * extra_dim, head_size, dim // head_size
762
+ )
763
+ tensor = tensor.permute(0, 2, 1, 3)
764
+
765
+ if out_dim == 3:
766
+ tensor = tensor.reshape(
767
+ batch_size * head_size, seq_len * extra_dim, dim // head_size
768
+ )
769
+
770
+ return tensor
771
+
772
+ def get_attention_scores(
773
+ self,
774
+ query: torch.Tensor,
775
+ key: torch.Tensor,
776
+ attention_mask: torch.Tensor = None,
777
+ ) -> torch.Tensor:
778
+ r"""
779
+ Compute the attention scores.
780
+
781
+ Args:
782
+ query (`torch.Tensor`): The query tensor.
783
+ key (`torch.Tensor`): The key tensor.
784
+ attention_mask (`torch.Tensor`, *optional*): The attention mask to use. If `None`, no mask is applied.
785
+
786
+ Returns:
787
+ `torch.Tensor`: The attention probabilities/scores.
788
+ """
789
+ dtype = query.dtype
790
+ if self.upcast_attention:
791
+ query = query.float()
792
+ key = key.float()
793
+
794
+ if attention_mask is None:
795
+ baddbmm_input = torch.empty(
796
+ query.shape[0],
797
+ query.shape[1],
798
+ key.shape[1],
799
+ dtype=query.dtype,
800
+ device=query.device,
801
+ )
802
+ beta = 0
803
+ else:
804
+ baddbmm_input = attention_mask
805
+ beta = 1
806
+
807
+ attention_scores = torch.baddbmm(
808
+ baddbmm_input,
809
+ query,
810
+ key.transpose(-1, -2),
811
+ beta=beta,
812
+ alpha=self.scale,
813
+ )
814
+ del baddbmm_input
815
+
816
+ if self.upcast_softmax:
817
+ attention_scores = attention_scores.float()
818
+
819
+ attention_probs = attention_scores.softmax(dim=-1)
820
+ del attention_scores
821
+
822
+ attention_probs = attention_probs.to(dtype)
823
+
824
+ return attention_probs
825
+
826
+ def prepare_attention_mask(
827
+ self,
828
+ attention_mask: torch.Tensor,
829
+ target_length: int,
830
+ batch_size: int,
831
+ out_dim: int = 3,
832
+ ) -> torch.Tensor:
833
+ r"""
834
+ Prepare the attention mask for the attention computation.
835
+
836
+ Args:
837
+ attention_mask (`torch.Tensor`):
838
+ The attention mask to prepare.
839
+ target_length (`int`):
840
+ The target length of the attention mask. This is the length of the attention mask after padding.
841
+ batch_size (`int`):
842
+ The batch size, which is used to repeat the attention mask.
843
+ out_dim (`int`, *optional*, defaults to `3`):
844
+ The output dimension of the attention mask. Can be either `3` or `4`.
845
+
846
+ Returns:
847
+ `torch.Tensor`: The prepared attention mask.
848
+ """
849
+ head_size = self.heads
850
+ if attention_mask is None:
851
+ return attention_mask
852
+
853
+ current_length: int = attention_mask.shape[-1]
854
+ if current_length != target_length:
855
+ if attention_mask.device.type == "mps":
856
+ # HACK: MPS: Does not support padding by greater than dimension of input tensor.
857
+ # Instead, we can manually construct the padding tensor.
858
+ padding_shape = (
859
+ attention_mask.shape[0],
860
+ attention_mask.shape[1],
861
+ target_length,
862
+ )
863
+ padding = torch.zeros(
864
+ padding_shape,
865
+ dtype=attention_mask.dtype,
866
+ device=attention_mask.device,
867
+ )
868
+ attention_mask = torch.cat([attention_mask, padding], dim=2)
869
+ else:
870
+ # TODO: for pipelines such as stable-diffusion, padding cross-attn mask:
871
+ # we want to instead pad by (0, remaining_length), where remaining_length is:
872
+ # remaining_length: int = target_length - current_length
873
+ # TODO: re-enable tests/models/test_models_unet_2d_condition.py#test_model_xattn_padding
874
+ attention_mask = F.pad(attention_mask, (0, target_length), value=0.0)
875
+
876
+ if out_dim == 3:
877
+ if attention_mask.shape[0] < batch_size * head_size:
878
+ attention_mask = attention_mask.repeat_interleave(head_size, dim=0)
879
+ elif out_dim == 4:
880
+ attention_mask = attention_mask.unsqueeze(1)
881
+ attention_mask = attention_mask.repeat_interleave(head_size, dim=1)
882
+
883
+ return attention_mask
884
+
885
+ def norm_encoder_hidden_states(
886
+ self, encoder_hidden_states: torch.Tensor
887
+ ) -> torch.Tensor:
888
+ r"""
889
+ Normalize the encoder hidden states. Requires `self.norm_cross` to be specified when constructing the
890
+ `Attention` class.
891
+
892
+ Args:
893
+ encoder_hidden_states (`torch.Tensor`): Hidden states of the encoder.
894
+
895
+ Returns:
896
+ `torch.Tensor`: The normalized encoder hidden states.
897
+ """
898
+ assert (
899
+ self.norm_cross is not None
900
+ ), "self.norm_cross must be defined to call self.norm_encoder_hidden_states"
901
+
902
+ if isinstance(self.norm_cross, nn.LayerNorm):
903
+ encoder_hidden_states = self.norm_cross(encoder_hidden_states)
904
+ elif isinstance(self.norm_cross, nn.GroupNorm):
905
+ # Group norm norms along the channels dimension and expects
906
+ # input to be in the shape of (N, C, *). In this case, we want
907
+ # to norm along the hidden dimension, so we need to move
908
+ # (batch_size, sequence_length, hidden_size) ->
909
+ # (batch_size, hidden_size, sequence_length)
910
+ encoder_hidden_states = encoder_hidden_states.transpose(1, 2)
911
+ encoder_hidden_states = self.norm_cross(encoder_hidden_states)
912
+ encoder_hidden_states = encoder_hidden_states.transpose(1, 2)
913
+ else:
914
+ assert False
915
+
916
+ return encoder_hidden_states
917
+
918
+ @staticmethod
919
+ def apply_rotary_emb(
920
+ input_tensor: torch.Tensor,
921
+ freqs_cis: Tuple[torch.FloatTensor, torch.FloatTensor],
922
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
923
+ cos_freqs = freqs_cis[0]
924
+ sin_freqs = freqs_cis[1]
925
+
926
+ t_dup = rearrange(input_tensor, "... (d r) -> ... d r", r=2)
927
+ t1, t2 = t_dup.unbind(dim=-1)
928
+ t_dup = torch.stack((-t2, t1), dim=-1)
929
+ input_tensor_rot = rearrange(t_dup, "... d r -> ... (d r)")
930
+
931
+ out = input_tensor * cos_freqs + input_tensor_rot * sin_freqs
932
+
933
+ return out
934
+
935
+
936
+ class AttnProcessor2_0:
937
+ r"""
938
+ Processor for implementing scaled dot-product attention (enabled by default if you're using PyTorch 2.0).
939
+ """
940
+
941
+ def __init__(self):
942
+ pass
943
+
944
+ def __call__(
945
+ self,
946
+ attn: Attention,
947
+ hidden_states: torch.FloatTensor,
948
+ freqs_cis: Tuple[torch.FloatTensor, torch.FloatTensor],
949
+ encoder_hidden_states: Optional[torch.FloatTensor] = None,
950
+ attention_mask: Optional[torch.FloatTensor] = None,
951
+ temb: Optional[torch.FloatTensor] = None,
952
+ skip_layer_mask: Optional[torch.FloatTensor] = None,
953
+ skip_layer_strategy: Optional[SkipLayerStrategy] = None,
954
+ *args,
955
+ **kwargs,
956
+ ) -> torch.FloatTensor:
957
+ if len(args) > 0 or kwargs.get("scale", None) is not None:
958
+ deprecation_message = "The `scale` argument is deprecated and will be ignored. Please remove it, as passing it will raise an error in the future. `scale` should directly be passed while calling the underlying pipeline component i.e., via `cross_attention_kwargs`."
959
+ deprecate("scale", "1.0.0", deprecation_message)
960
+
961
+ residual = hidden_states
962
+ if attn.spatial_norm is not None:
963
+ hidden_states = attn.spatial_norm(hidden_states, temb)
964
+
965
+ input_ndim = hidden_states.ndim
966
+
967
+ if input_ndim == 4:
968
+ batch_size, channel, height, width = hidden_states.shape
969
+ hidden_states = hidden_states.view(
970
+ batch_size, channel, height * width
971
+ ).transpose(1, 2)
972
+
973
+ batch_size, sequence_length, _ = (
974
+ hidden_states.shape
975
+ if encoder_hidden_states is None
976
+ else encoder_hidden_states.shape
977
+ )
978
+
979
+ if skip_layer_mask is not None:
980
+ skip_layer_mask = skip_layer_mask.reshape(batch_size, 1, 1)
981
+
982
+ if (attention_mask is not None) and (not attn.use_tpu_flash_attention):
983
+ attention_mask = attn.prepare_attention_mask(
984
+ attention_mask, sequence_length, batch_size
985
+ )
986
+ # scaled_dot_product_attention expects attention_mask shape to be
987
+ # (batch, heads, source_length, target_length)
988
+ attention_mask = attention_mask.view(
989
+ batch_size, attn.heads, -1, attention_mask.shape[-1]
990
+ )
991
+
992
+ if attn.group_norm is not None:
993
+ hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(
994
+ 1, 2
995
+ )
996
+
997
+ query = attn.to_q(hidden_states)
998
+ query = attn.q_norm(query)
999
+
1000
+ if encoder_hidden_states is not None:
1001
+ if attn.norm_cross:
1002
+ encoder_hidden_states = attn.norm_encoder_hidden_states(
1003
+ encoder_hidden_states
1004
+ )
1005
+ key = attn.to_k(encoder_hidden_states)
1006
+ key = attn.k_norm(key)
1007
+ else: # if no context provided do self-attention
1008
+ encoder_hidden_states = hidden_states
1009
+ key = attn.to_k(hidden_states)
1010
+ key = attn.k_norm(key)
1011
+ if attn.use_rope:
1012
+ key = attn.apply_rotary_emb(key, freqs_cis)
1013
+ query = attn.apply_rotary_emb(query, freqs_cis)
1014
+
1015
+ value = attn.to_v(encoder_hidden_states)
1016
+ value_for_stg = value
1017
+
1018
+ inner_dim = key.shape[-1]
1019
+ head_dim = inner_dim // attn.heads
1020
+
1021
+ query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
1022
+ key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
1023
+ value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
1024
+
1025
+ # the output of sdp = (batch, num_heads, seq_len, head_dim)
1026
+
1027
+ if attn.use_tpu_flash_attention: # use tpu attention offload 'flash attention'
1028
+ q_segment_indexes = None
1029
+ if (
1030
+ attention_mask is not None
1031
+ ): # if a mask is required, both segment-ids fields need to be set
1032
+ # attention_mask = torch.squeeze(attention_mask).to(torch.float32)
1033
+ attention_mask = attention_mask.to(torch.float32)
1034
+ q_segment_indexes = torch.ones(
1035
+ batch_size, query.shape[2], device=query.device, dtype=torch.float32
1036
+ )
1037
+ assert (
1038
+ attention_mask.shape[1] == key.shape[2]
1039
+ ), f"ERROR: KEY SHAPE must be the same as the attention mask shape [{key.shape[2]}, {attention_mask.shape[1]}]"
1040
+
1041
+ assert (
1042
+ query.shape[2] % 128 == 0
1043
+ ), f"ERROR: QUERY SHAPE must be divisible by 128 (TPU limitation) [{query.shape[2]}]"
1044
+ assert (
1045
+ key.shape[2] % 128 == 0
1046
+ ), f"ERROR: KEY SHAPE must be divisible by 128 (TPU limitation) [{key.shape[2]}]"
1047
+
1048
+ # run the TPU kernel implemented in jax with pallas
1049
+ hidden_states_a = flash_attention(
1050
+ q=query,
1051
+ k=key,
1052
+ v=value,
1053
+ q_segment_ids=q_segment_indexes,
1054
+ kv_segment_ids=attention_mask,
1055
+ sm_scale=attn.scale,
1056
+ )
1057
+ else:
1058
+ hidden_states_a = F.scaled_dot_product_attention(
1059
+ query,
1060
+ key,
1061
+ value,
1062
+ attn_mask=attention_mask,
1063
+ dropout_p=0.0,
1064
+ is_causal=False,
1065
+ )
1066
+
1067
+ hidden_states_a = hidden_states_a.transpose(1, 2).reshape(
1068
+ batch_size, -1, attn.heads * head_dim
1069
+ )
1070
+ hidden_states_a = hidden_states_a.to(query.dtype)
1071
+
1072
+ if (
1073
+ skip_layer_mask is not None
1074
+ and skip_layer_strategy == SkipLayerStrategy.AttentionSkip
1075
+ ):
1076
+ hidden_states = hidden_states_a * skip_layer_mask + hidden_states * (
1077
+ 1.0 - skip_layer_mask
1078
+ )
1079
+ elif (
1080
+ skip_layer_mask is not None
1081
+ and skip_layer_strategy == SkipLayerStrategy.AttentionValues
1082
+ ):
1083
+ hidden_states = hidden_states_a * skip_layer_mask + value_for_stg * (
1084
+ 1.0 - skip_layer_mask
1085
+ )
1086
+ else:
1087
+ hidden_states = hidden_states_a
1088
+
1089
+ # linear proj
1090
+ hidden_states = attn.to_out[0](hidden_states)
1091
+ # dropout
1092
+ hidden_states = attn.to_out[1](hidden_states)
1093
+
1094
+ if input_ndim == 4:
1095
+ hidden_states = hidden_states.transpose(-1, -2).reshape(
1096
+ batch_size, channel, height, width
1097
+ )
1098
+ if (
1099
+ skip_layer_mask is not None
1100
+ and skip_layer_strategy == SkipLayerStrategy.Residual
1101
+ ):
1102
+ skip_layer_mask = skip_layer_mask.reshape(batch_size, 1, 1, 1)
1103
+
1104
+ if attn.residual_connection:
1105
+ if (
1106
+ skip_layer_mask is not None
1107
+ and skip_layer_strategy == SkipLayerStrategy.Residual
1108
+ ):
1109
+ hidden_states = hidden_states + residual * skip_layer_mask
1110
+ else:
1111
+ hidden_states = hidden_states + residual
1112
+
1113
+ hidden_states = hidden_states / attn.rescale_output_factor
1114
+
1115
+ return hidden_states
1116
+
1117
+
1118
+ class AttnProcessor:
1119
+ r"""
1120
+ Default processor for performing attention-related computations.
1121
+ """
1122
+
1123
+ def __call__(
1124
+ self,
1125
+ attn: Attention,
1126
+ hidden_states: torch.FloatTensor,
1127
+ encoder_hidden_states: Optional[torch.FloatTensor] = None,
1128
+ attention_mask: Optional[torch.FloatTensor] = None,
1129
+ temb: Optional[torch.FloatTensor] = None,
1130
+ *args,
1131
+ **kwargs,
1132
+ ) -> torch.Tensor:
1133
+ if len(args) > 0 or kwargs.get("scale", None) is not None:
1134
+ deprecation_message = "The `scale` argument is deprecated and will be ignored. Please remove it, as passing it will raise an error in the future. `scale` should directly be passed while calling the underlying pipeline component i.e., via `cross_attention_kwargs`."
1135
+ deprecate("scale", "1.0.0", deprecation_message)
1136
+
1137
+ residual = hidden_states
1138
+
1139
+ if attn.spatial_norm is not None:
1140
+ hidden_states = attn.spatial_norm(hidden_states, temb)
1141
+
1142
+ input_ndim = hidden_states.ndim
1143
+
1144
+ if input_ndim == 4:
1145
+ batch_size, channel, height, width = hidden_states.shape
1146
+ hidden_states = hidden_states.view(
1147
+ batch_size, channel, height * width
1148
+ ).transpose(1, 2)
1149
+
1150
+ batch_size, sequence_length, _ = (
1151
+ hidden_states.shape
1152
+ if encoder_hidden_states is None
1153
+ else encoder_hidden_states.shape
1154
+ )
1155
+ attention_mask = attn.prepare_attention_mask(
1156
+ attention_mask, sequence_length, batch_size
1157
+ )
1158
+
1159
+ if attn.group_norm is not None:
1160
+ hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(
1161
+ 1, 2
1162
+ )
1163
+
1164
+ query = attn.to_q(hidden_states)
1165
+
1166
+ if encoder_hidden_states is None:
1167
+ encoder_hidden_states = hidden_states
1168
+ elif attn.norm_cross:
1169
+ encoder_hidden_states = attn.norm_encoder_hidden_states(
1170
+ encoder_hidden_states
1171
+ )
1172
+
1173
+ key = attn.to_k(encoder_hidden_states)
1174
+ value = attn.to_v(encoder_hidden_states)
1175
+
1176
+ query = attn.head_to_batch_dim(query)
1177
+ key = attn.head_to_batch_dim(key)
1178
+ value = attn.head_to_batch_dim(value)
1179
+
1180
+ query = attn.q_norm(query)
1181
+ key = attn.k_norm(key)
1182
+
1183
+ attention_probs = attn.get_attention_scores(query, key, attention_mask)
1184
+ hidden_states = torch.bmm(attention_probs, value)
1185
+ hidden_states = attn.batch_to_head_dim(hidden_states)
1186
+
1187
+ # linear proj
1188
+ hidden_states = attn.to_out[0](hidden_states)
1189
+ # dropout
1190
+ hidden_states = attn.to_out[1](hidden_states)
1191
+
1192
+ if input_ndim == 4:
1193
+ hidden_states = hidden_states.transpose(-1, -2).reshape(
1194
+ batch_size, channel, height, width
1195
+ )
1196
+
1197
+ if attn.residual_connection:
1198
+ hidden_states = hidden_states + residual
1199
+
1200
+ hidden_states = hidden_states / attn.rescale_output_factor
1201
+
1202
+ return hidden_states
1203
+
1204
+
1205
+ class FeedForward(nn.Module):
1206
+ r"""
1207
+ A feed-forward layer.
1208
+
1209
+ Parameters:
1210
+ dim (`int`): The number of channels in the input.
1211
+ dim_out (`int`, *optional*): The number of channels in the output. If not given, defaults to `dim`.
1212
+ mult (`int`, *optional*, defaults to 4): The multiplier to use for the hidden dimension.
1213
+ dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use.
1214
+ activation_fn (`str`, *optional*, defaults to `"geglu"`): Activation function to be used in feed-forward.
1215
+ final_dropout (`bool` *optional*, defaults to False): Apply a final dropout.
1216
+ bias (`bool`, defaults to True): Whether to use a bias in the linear layer.
1217
+ """
1218
+
1219
+ def __init__(
1220
+ self,
1221
+ dim: int,
1222
+ dim_out: Optional[int] = None,
1223
+ mult: int = 4,
1224
+ dropout: float = 0.0,
1225
+ activation_fn: str = "geglu",
1226
+ final_dropout: bool = False,
1227
+ inner_dim=None,
1228
+ bias: bool = True,
1229
+ ):
1230
+ super().__init__()
1231
+ if inner_dim is None:
1232
+ inner_dim = int(dim * mult)
1233
+ dim_out = dim_out if dim_out is not None else dim
1234
+ linear_cls = nn.Linear
1235
+
1236
+ if activation_fn == "gelu":
1237
+ act_fn = GELU(dim, inner_dim, bias=bias)
1238
+ elif activation_fn == "gelu-approximate":
1239
+ act_fn = GELU(dim, inner_dim, approximate="tanh", bias=bias)
1240
+ elif activation_fn == "geglu":
1241
+ act_fn = GEGLU(dim, inner_dim, bias=bias)
1242
+ elif activation_fn == "geglu-approximate":
1243
+ act_fn = ApproximateGELU(dim, inner_dim, bias=bias)
1244
+ else:
1245
+ raise ValueError(f"Unsupported activation function: {activation_fn}")
1246
+
1247
+ self.net = nn.ModuleList([])
1248
+ # project in
1249
+ self.net.append(act_fn)
1250
+ # project dropout
1251
+ self.net.append(nn.Dropout(dropout))
1252
+ # project out
1253
+ self.net.append(linear_cls(inner_dim, dim_out, bias=bias))
1254
+ # FF as used in Vision Transformer, MLP-Mixer, etc. have a final dropout
1255
+ if final_dropout:
1256
+ self.net.append(nn.Dropout(dropout))
1257
+
1258
+ def forward(self, hidden_states: torch.Tensor, scale: float = 1.0) -> torch.Tensor:
1259
+ compatible_cls = (GEGLU, LoRACompatibleLinear)
1260
+ for module in self.net:
1261
+ if isinstance(module, compatible_cls):
1262
+ hidden_states = module(hidden_states, scale)
1263
+ else:
1264
+ hidden_states = module(hidden_states)
1265
+ return hidden_states
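
Below is a minimal usage sketch (not part of the committed file) showing how the `BasicTransformerBlock` defined above can be exercised end to end with the default `AttnProcessor2_0` path. All sizes are illustrative assumptions, not LTX-Video defaults; with `adaptive_norm="single_scale_shift"` the block expects a `timestep` tensor carrying six scale/shift/gate vectors of width `dim`.

    import torch

    from ltx_video.models.transformers.attention import BasicTransformerBlock

    # Illustrative sizes only (assumed for the example).
    dim, heads, head_dim = 64, 4, 16
    block = BasicTransformerBlock(
        dim=dim,
        num_attention_heads=heads,
        attention_head_dim=head_dim,
        cross_attention_dim=dim,
        adaptive_norm="single_scale_shift",
    )

    batch, num_tokens = 2, 8
    hidden_states = torch.randn(batch, num_tokens, dim)
    encoder_hidden_states = torch.randn(batch, 16, dim)
    # "single_scale_shift" reshapes `timestep` into 6 ada values of width `dim`.
    timestep = torch.randn(batch, 1, 6 * dim)

    out = block(
        hidden_states,
        encoder_hidden_states=encoder_hidden_states,
        timestep=timestep,
    )
    print(out.shape)  # torch.Size([2, 8, 64])

The same block switches to the rotary-embedding path when it is constructed with `use_rope=True` and the forward call provides `freqs_cis`.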
ltx_video/models/transformers/embeddings.py ADDED
@@ -0,0 +1,129 @@
1
+ # Adapted from: https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/embeddings.py
2
+ import math
3
+
4
+ import numpy as np
5
+ import torch
6
+ from einops import rearrange
7
+ from torch import nn
8
+
9
+
10
+ def get_timestep_embedding(
11
+ timesteps: torch.Tensor,
12
+ embedding_dim: int,
13
+ flip_sin_to_cos: bool = False,
14
+ downscale_freq_shift: float = 1,
15
+ scale: float = 1,
16
+ max_period: int = 10000,
17
+ ):
18
+ """
19
+ This matches the implementation in Denoising Diffusion Probabilistic Models: Create sinusoidal timestep embeddings.
20
+
21
+ :param timesteps: a 1-D Tensor of N indices, one per batch element. These may be fractional.
22
+ :param embedding_dim: the dimension of the output.
23
+ :param max_period: controls the minimum frequency of the embeddings.
24
+ :return: an [N x dim] Tensor of positional embeddings.
25
+ """
26
+ assert len(timesteps.shape) == 1, "Timesteps should be a 1d-array"
27
+
28
+ half_dim = embedding_dim // 2
29
+ exponent = -math.log(max_period) * torch.arange(
30
+ start=0, end=half_dim, dtype=torch.float32, device=timesteps.device
31
+ )
32
+ exponent = exponent / (half_dim - downscale_freq_shift)
33
+
34
+ emb = torch.exp(exponent)
35
+ emb = timesteps[:, None].float() * emb[None, :]
36
+
37
+ # scale embeddings
38
+ emb = scale * emb
39
+
40
+ # concat sine and cosine embeddings
41
+ emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=-1)
42
+
43
+ # flip sine and cosine embeddings
44
+ if flip_sin_to_cos:
45
+ emb = torch.cat([emb[:, half_dim:], emb[:, :half_dim]], dim=-1)
46
+
47
+ # zero pad
48
+ if embedding_dim % 2 == 1:
49
+ emb = torch.nn.functional.pad(emb, (0, 1, 0, 0))
50
+ return emb
51
+
52
+
53
+ def get_3d_sincos_pos_embed(embed_dim, grid, w, h, f):
54
+ """
55
+ grid: the flattened (f, h, w) coordinate grid; w, h, f: the grid sizes along each axis.
56
+ return: pos_embed of shape [f*h*w, embed_dim]
57
+ """
58
+ grid = rearrange(grid, "c (f h w) -> c f h w", h=h, w=w)
59
+ grid = rearrange(grid, "c f h w -> c h w f", h=h, w=w)
60
+ grid = grid.reshape([3, 1, w, h, f])
61
+ pos_embed = get_3d_sincos_pos_embed_from_grid(embed_dim, grid)
62
+ pos_embed = pos_embed.transpose(1, 0, 2, 3)
63
+ return rearrange(pos_embed, "h w f c -> (f h w) c")
64
+
65
+
66
+ def get_3d_sincos_pos_embed_from_grid(embed_dim, grid):
67
+ if embed_dim % 3 != 0:
68
+ raise ValueError("embed_dim must be divisible by 3")
69
+
70
+ # use half of dimensions to encode grid_h
71
+ emb_f = get_1d_sincos_pos_embed_from_grid(embed_dim // 3, grid[0]) # (H*W*T, D/3)
72
+ emb_h = get_1d_sincos_pos_embed_from_grid(embed_dim // 3, grid[1]) # (H*W*T, D/3)
73
+ emb_w = get_1d_sincos_pos_embed_from_grid(embed_dim // 3, grid[2]) # (H*W*T, D/3)
74
+
75
+ emb = np.concatenate([emb_h, emb_w, emb_f], axis=-1) # (H*W*T, D)
76
+ return emb
77
+
78
+
79
+ def get_1d_sincos_pos_embed_from_grid(embed_dim, pos):
80
+ """
81
+ embed_dim: output dimension for each position; pos: a list of positions to be encoded, size (M,); out: (M, D)
82
+ """
83
+ if embed_dim % 2 != 0:
84
+ raise ValueError("embed_dim must be divisible by 2")
85
+
86
+ omega = np.arange(embed_dim // 2, dtype=np.float64)
87
+ omega /= embed_dim / 2.0
88
+ omega = 1.0 / 10000**omega # (D/2,)
89
+
90
+ pos_shape = pos.shape
91
+
92
+ pos = pos.reshape(-1)
93
+ out = np.einsum("m,d->md", pos, omega) # (M, D/2), outer product
94
+ out = out.reshape([*pos_shape, -1])[0]
95
+
96
+ emb_sin = np.sin(out) # (M, D/2)
97
+ emb_cos = np.cos(out) # (M, D/2)
98
+
99
+ emb = np.concatenate([emb_sin, emb_cos], axis=-1) # (M, D)
100
+ return emb
101
+
102
+
103
+ class SinusoidalPositionalEmbedding(nn.Module):
104
+ """Apply positional information to a sequence of embeddings.
105
+
106
+ Takes in a sequence of embeddings with shape (batch_size, seq_length, embed_dim) and adds positional embeddings to
107
+ them
108
+
109
+ Args:
110
+ embed_dim: (int): Dimension of the positional embedding.
111
+ max_seq_length: Maximum sequence length to apply positional embeddings
112
+
113
+ """
114
+
115
+ def __init__(self, embed_dim: int, max_seq_length: int = 32):
116
+ super().__init__()
117
+ position = torch.arange(max_seq_length).unsqueeze(1)
118
+ div_term = torch.exp(
119
+ torch.arange(0, embed_dim, 2) * (-math.log(10000.0) / embed_dim)
120
+ )
121
+ pe = torch.zeros(1, max_seq_length, embed_dim)
122
+ pe[0, :, 0::2] = torch.sin(position * div_term)
123
+ pe[0, :, 1::2] = torch.cos(position * div_term)
124
+ self.register_buffer("pe", pe)
125
+
126
+ def forward(self, x):
127
+ _, seq_length, _ = x.shape
128
+ x = x + self.pe[:, :seq_length]
129
+ return x
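
A quick self-contained check (not part of the committed file) of the two main helpers above; the dimensions are arbitrary assumptions.

    import torch

    from ltx_video.models.transformers.embeddings import (
        SinusoidalPositionalEmbedding,
        get_timestep_embedding,
    )

    # One 256-dim sinusoidal embedding per (possibly fractional) timestep index.
    timesteps = torch.tensor([0.0, 250.0, 999.0])
    emb = get_timestep_embedding(timesteps, embedding_dim=256)
    print(emb.shape)  # torch.Size([3, 256])

    # Additive sinusoidal positional embeddings for a token sequence.
    pos_emb = SinusoidalPositionalEmbedding(embed_dim=64, max_seq_length=128)
    tokens = torch.randn(2, 10, 64)
    print(pos_emb(tokens).shape)  # torch.Size([2, 10, 64])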
ltx_video/models/transformers/symmetric_patchifier.py ADDED
@@ -0,0 +1,84 @@
1
+ from abc import ABC, abstractmethod
2
+ from typing import Tuple
3
+
4
+ import torch
5
+ from diffusers.configuration_utils import ConfigMixin
6
+ from einops import rearrange
7
+ from torch import Tensor
8
+
9
+
10
+ class Patchifier(ConfigMixin, ABC):
11
+ def __init__(self, patch_size: int):
12
+ super().__init__()
13
+ self._patch_size = (1, patch_size, patch_size)
14
+
15
+ @abstractmethod
16
+ def patchify(self, latents: Tensor) -> Tuple[Tensor, Tensor]:
17
+ raise NotImplementedError("Patchify method not implemented")
18
+
19
+ @abstractmethod
20
+ def unpatchify(
21
+ self,
22
+ latents: Tensor,
23
+ output_height: int,
24
+ output_width: int,
25
+ out_channels: int,
26
+ ) -> Tensor:
27
+ pass
28
+
29
+ @property
30
+ def patch_size(self):
31
+ return self._patch_size
32
+
33
+ def get_latent_coords(
34
+ self, latent_num_frames, latent_height, latent_width, batch_size, device
35
+ ):
36
+ """
37
+ Return a tensor of shape [batch_size, 3, num_patches] containing the
38
+ top-left corner latent coordinates of each latent patch.
39
+ The tensor is repeated for each batch element.
40
+ """
41
+ latent_sample_coords = torch.meshgrid(
42
+ torch.arange(0, latent_num_frames, self._patch_size[0], device=device),
43
+ torch.arange(0, latent_height, self._patch_size[1], device=device),
44
+ torch.arange(0, latent_width, self._patch_size[2], device=device),
45
+ )
46
+ latent_sample_coords = torch.stack(latent_sample_coords, dim=0)
47
+ latent_coords = latent_sample_coords.unsqueeze(0).repeat(batch_size, 1, 1, 1, 1)
48
+ latent_coords = rearrange(
49
+ latent_coords, "b c f h w -> b c (f h w)", b=batch_size
50
+ )
51
+ return latent_coords
52
+
53
+
54
+ class SymmetricPatchifier(Patchifier):
55
+ def patchify(self, latents: Tensor) -> Tuple[Tensor, Tensor]:
56
+ b, _, f, h, w = latents.shape
57
+ latent_coords = self.get_latent_coords(f, h, w, b, latents.device)
58
+ latents = rearrange(
59
+ latents,
60
+ "b c (f p1) (h p2) (w p3) -> b (f h w) (c p1 p2 p3)",
61
+ p1=self._patch_size[0],
62
+ p2=self._patch_size[1],
63
+ p3=self._patch_size[2],
64
+ )
65
+ return latents, latent_coords
66
+
67
+ def unpatchify(
68
+ self,
69
+ latents: Tensor,
70
+ output_height: int,
71
+ output_width: int,
72
+ out_channels: int,
73
+ ) -> Tensor:
74
+ output_height = output_height // self._patch_size[1]
75
+ output_width = output_width // self._patch_size[2]
76
+ latents = rearrange(
77
+ latents,
78
+ "b (f h w) (c p q) -> b c f (h p) (w q)",
79
+ h=output_height,
80
+ w=output_width,
81
+ p=self._patch_size[1],
82
+ q=self._patch_size[2],
83
+ )
84
+ return latents
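A minimal round-trip sketch for the patchifier above (patch_size=1 and all shapes are chosen purely for illustration):

import torch
from ltx_video.models.transformers.symmetric_patchifier import SymmetricPatchifier

patchifier = SymmetricPatchifier(patch_size=1)
latents = torch.randn(1, 128, 2, 8, 8)            # (batch, channels, frames, height, width)
tokens, coords = patchifier.patchify(latents)
print(tokens.shape)                               # torch.Size([1, 128, 128]) -> (b, f*h*w, c)
print(coords.shape)                               # torch.Size([1, 3, 128])   -> per-token (f, h, w)
restored = patchifier.unpatchify(tokens, output_height=8, output_width=8, out_channels=128)
print(restored.shape)                             # torch.Size([1, 128, 2, 8, 8])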
ltx_video/models/transformers/transformer3d.py ADDED
@@ -0,0 +1,507 @@
1
+ # Adapted from: https://github.com/huggingface/diffusers/blob/v0.26.3/src/diffusers/models/transformers/transformer_2d.py
2
+ import math
3
+ from dataclasses import dataclass
4
+ from typing import Any, Dict, List, Optional, Union
5
+ import os
6
+ import json
7
+ import glob
8
+ from pathlib import Path
9
+
10
+ import torch
11
+ from diffusers.configuration_utils import ConfigMixin, register_to_config
12
+ from diffusers.models.embeddings import PixArtAlphaTextProjection
13
+ from diffusers.models.modeling_utils import ModelMixin
14
+ from diffusers.models.normalization import AdaLayerNormSingle
15
+ from diffusers.utils import BaseOutput, is_torch_version
16
+ from diffusers.utils import logging
17
+ from torch import nn
18
+ from safetensors import safe_open
19
+
20
+
21
+ from ltx_video.models.transformers.attention import BasicTransformerBlock
22
+ from ltx_video.utils.skip_layer_strategy import SkipLayerStrategy
23
+
24
+ from ltx_video.utils.diffusers_config_mapping import (
25
+ diffusers_and_ours_config_mapping,
26
+ make_hashable_key,
27
+ TRANSFORMER_KEYS_RENAME_DICT,
28
+ )
29
+
30
+
31
+ logger = logging.get_logger(__name__)
32
+
33
+
34
+ @dataclass
35
+ class Transformer3DModelOutput(BaseOutput):
36
+ """
37
+ The output of [`Transformer3DModel`].
38
+
39
+ Args:
40
+ sample (`torch.FloatTensor` of shape `(batch_size, num_tokens, out_channels)`):
41
+ The hidden states output conditioned on the `encoder_hidden_states` input.
43
+ """
44
+
45
+ sample: torch.FloatTensor
46
+
47
+
48
+ class Transformer3DModel(ModelMixin, ConfigMixin):
49
+ _supports_gradient_checkpointing = True
50
+
51
+ @register_to_config
52
+ def __init__(
53
+ self,
54
+ num_attention_heads: int = 16,
55
+ attention_head_dim: int = 88,
56
+ in_channels: Optional[int] = None,
57
+ out_channels: Optional[int] = None,
58
+ num_layers: int = 1,
59
+ dropout: float = 0.0,
60
+ norm_num_groups: int = 32,
61
+ cross_attention_dim: Optional[int] = None,
62
+ attention_bias: bool = False,
63
+ num_vector_embeds: Optional[int] = None,
64
+ activation_fn: str = "geglu",
65
+ num_embeds_ada_norm: Optional[int] = None,
66
+ use_linear_projection: bool = False,
67
+ only_cross_attention: bool = False,
68
+ double_self_attention: bool = False,
69
+ upcast_attention: bool = False,
70
+ adaptive_norm: str = "single_scale_shift", # 'single_scale_shift' or 'single_scale'
71
+ standardization_norm: str = "layer_norm", # 'layer_norm' or 'rms_norm'
72
+ norm_elementwise_affine: bool = True,
73
+ norm_eps: float = 1e-5,
74
+ attention_type: str = "default",
75
+ caption_channels: int = None,
76
+ use_tpu_flash_attention: bool = False, # if True uses the TPU attention offload ('flash attention')
77
+ qk_norm: Optional[str] = None,
78
+ positional_embedding_type: str = "rope",
79
+ positional_embedding_theta: Optional[float] = None,
80
+ positional_embedding_max_pos: Optional[List[int]] = None,
81
+ timestep_scale_multiplier: Optional[float] = None,
82
+ causal_temporal_positioning: bool = False, # For backward compatibility, will be deprecated
83
+ ):
84
+ super().__init__()
85
+ self.use_tpu_flash_attention = (
86
+ use_tpu_flash_attention # FIXME: push config down to the attention modules
87
+ )
88
+ self.use_linear_projection = use_linear_projection
89
+ self.num_attention_heads = num_attention_heads
90
+ self.attention_head_dim = attention_head_dim
91
+ inner_dim = num_attention_heads * attention_head_dim
92
+ self.inner_dim = inner_dim
93
+ self.patchify_proj = nn.Linear(in_channels, inner_dim, bias=True)
94
+ self.positional_embedding_type = positional_embedding_type
95
+ self.positional_embedding_theta = positional_embedding_theta
96
+ self.positional_embedding_max_pos = positional_embedding_max_pos
97
+ self.use_rope = self.positional_embedding_type == "rope"
98
+ self.timestep_scale_multiplier = timestep_scale_multiplier
99
+
100
+ if self.positional_embedding_type == "absolute":
101
+ raise ValueError("Absolute positional embedding is no longer supported")
102
+ elif self.positional_embedding_type == "rope":
103
+ if positional_embedding_theta is None:
104
+ raise ValueError(
105
+ "If `positional_embedding_type` type is rope, `positional_embedding_theta` must also be defined"
106
+ )
107
+ if positional_embedding_max_pos is None:
108
+ raise ValueError(
109
+ "If `positional_embedding_type` type is rope, `positional_embedding_max_pos` must also be defined"
110
+ )
111
+
112
+ # 3. Define transformer blocks
113
+ self.transformer_blocks = nn.ModuleList(
114
+ [
115
+ BasicTransformerBlock(
116
+ inner_dim,
117
+ num_attention_heads,
118
+ attention_head_dim,
119
+ dropout=dropout,
120
+ cross_attention_dim=cross_attention_dim,
121
+ activation_fn=activation_fn,
122
+ num_embeds_ada_norm=num_embeds_ada_norm,
123
+ attention_bias=attention_bias,
124
+ only_cross_attention=only_cross_attention,
125
+ double_self_attention=double_self_attention,
126
+ upcast_attention=upcast_attention,
127
+ adaptive_norm=adaptive_norm,
128
+ standardization_norm=standardization_norm,
129
+ norm_elementwise_affine=norm_elementwise_affine,
130
+ norm_eps=norm_eps,
131
+ attention_type=attention_type,
132
+ use_tpu_flash_attention=use_tpu_flash_attention,
133
+ qk_norm=qk_norm,
134
+ use_rope=self.use_rope,
135
+ )
136
+ for d in range(num_layers)
137
+ ]
138
+ )
139
+
140
+ # 4. Define output layers
141
+ self.out_channels = in_channels if out_channels is None else out_channels
142
+ self.norm_out = nn.LayerNorm(inner_dim, elementwise_affine=False, eps=1e-6)
143
+ self.scale_shift_table = nn.Parameter(
144
+ torch.randn(2, inner_dim) / inner_dim**0.5
145
+ )
146
+ self.proj_out = nn.Linear(inner_dim, self.out_channels)
147
+
148
+ self.adaln_single = AdaLayerNormSingle(
149
+ inner_dim, use_additional_conditions=False
150
+ )
151
+ if adaptive_norm == "single_scale":
152
+ self.adaln_single.linear = nn.Linear(inner_dim, 4 * inner_dim, bias=True)
153
+
154
+ self.caption_projection = None
155
+ if caption_channels is not None:
156
+ self.caption_projection = PixArtAlphaTextProjection(
157
+ in_features=caption_channels, hidden_size=inner_dim
158
+ )
159
+
160
+ self.gradient_checkpointing = False
161
+
162
+ def set_use_tpu_flash_attention(self):
163
+ r"""
164
+ Sets the flag on this object and propagates it down to the child blocks. The flag enforces the use of the TPU
165
+ attention kernel.
166
+ """
167
+ logger.info("ENABLE TPU FLASH ATTENTION -> TRUE")
168
+ self.use_tpu_flash_attention = True
169
+ # push config down to the attention modules
170
+ for block in self.transformer_blocks:
171
+ block.set_use_tpu_flash_attention()
172
+
173
+ def create_skip_layer_mask(
174
+ self,
175
+ batch_size: int,
176
+ num_conds: int,
177
+ ptb_index: int,
178
+ skip_block_list: Optional[List[int]] = None,
179
+ ):
180
+ if skip_block_list is None or len(skip_block_list) == 0:
181
+ return None
182
+ num_layers = len(self.transformer_blocks)
183
+ mask = torch.ones(
184
+ (num_layers, batch_size * num_conds), device=self.device, dtype=self.dtype
185
+ )
186
+ for block_idx in skip_block_list:
187
+ mask[block_idx, ptb_index::num_conds] = 0
188
+ return mask
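As a standalone sketch of the mask layout produced above (it mirrors the loop rather than calling the method, so it runs without instantiating the model; the sizes and skip list are hypothetical):

import torch

num_layers, batch_size, num_conds, ptb_index = 4, 2, 3, 2
mask = torch.ones(num_layers, batch_size * num_conds)
for block_idx in [1, 3]:                 # hypothetical skip_block_list
    mask[block_idx, ptb_index::num_conds] = 0
# rows are transformer blocks, columns are (batch * conds); zeros mark the perturbed
# condition (every num_conds-th column starting at ptb_index) for the skipped blocks
print(mask)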
189
+
190
+ def _set_gradient_checkpointing(self, module, value=False):
191
+ if hasattr(module, "gradient_checkpointing"):
192
+ module.gradient_checkpointing = value
193
+
194
+ def get_fractional_positions(self, indices_grid):
195
+ fractional_positions = torch.stack(
196
+ [
197
+ indices_grid[:, i] / self.positional_embedding_max_pos[i]
198
+ for i in range(3)
199
+ ],
200
+ dim=-1,
201
+ )
202
+ return fractional_positions
203
+
204
+ def precompute_freqs_cis(self, indices_grid, spacing="exp"):
205
+ dtype = torch.float32 # We need full precision in the freqs_cis computation.
206
+ dim = self.inner_dim
207
+ theta = self.positional_embedding_theta
208
+
209
+ fractional_positions = self.get_fractional_positions(indices_grid)
210
+
211
+ start = 1
212
+ end = theta
213
+ device = fractional_positions.device
214
+ if spacing == "exp":
215
+ indices = theta ** (
216
+ torch.linspace(
217
+ math.log(start, theta),
218
+ math.log(end, theta),
219
+ dim // 6,
220
+ device=device,
221
+ dtype=dtype,
222
+ )
223
+ )
224
+ indices = indices.to(dtype=dtype)
225
+ elif spacing == "exp_2":
226
+ indices = 1.0 / theta ** (torch.arange(0, dim, 6, device=device) / dim)
227
+ indices = indices.to(dtype=dtype)
228
+ elif spacing == "linear":
229
+ indices = torch.linspace(start, end, dim // 6, device=device, dtype=dtype)
230
+ elif spacing == "sqrt":
231
+ indices = torch.linspace(
232
+ start**2, end**2, dim // 6, device=device, dtype=dtype
233
+ ).sqrt()
234
+
235
+ indices = indices * math.pi / 2
236
+
237
+ if spacing == "exp_2":
238
+ freqs = (
239
+ (indices * fractional_positions.unsqueeze(-1))
240
+ .transpose(-1, -2)
241
+ .flatten(2)
242
+ )
243
+ else:
244
+ freqs = (
245
+ (indices * (fractional_positions.unsqueeze(-1) * 2 - 1))
246
+ .transpose(-1, -2)
247
+ .flatten(2)
248
+ )
249
+
250
+ cos_freq = freqs.cos().repeat_interleave(2, dim=-1)
251
+ sin_freq = freqs.sin().repeat_interleave(2, dim=-1)
252
+ if dim % 6 != 0:
253
+ cos_padding = torch.ones_like(cos_freq[:, :, : dim % 6])
254
+ sin_padding = torch.zeros_like(cos_freq[:, :, : dim % 6])
255
+ cos_freq = torch.cat([cos_padding, cos_freq], dim=-1)
256
+ sin_freq = torch.cat([sin_padding, sin_freq], dim=-1)
257
+ return cos_freq.to(self.dtype), sin_freq.to(self.dtype)
258
+
259
+ def load_state_dict(
260
+ self,
261
+ state_dict: Dict,
262
+ *args,
263
+ **kwargs,
264
+ ):
265
+ if any([key.startswith("model.diffusion_model.") for key in state_dict.keys()]):
266
+ state_dict = {
267
+ key.replace("model.diffusion_model.", ""): value
268
+ for key, value in state_dict.items()
269
+ if key.startswith("model.diffusion_model.")
270
+ }
271
+ super().load_state_dict(state_dict, **kwargs)
272
+
273
+ @classmethod
274
+ def from_pretrained(
275
+ cls,
276
+ pretrained_model_path: Optional[Union[str, os.PathLike]],
277
+ *args,
278
+ **kwargs,
279
+ ):
280
+ pretrained_model_path = Path(pretrained_model_path)
281
+ if pretrained_model_path.is_dir():
282
+ config_path = pretrained_model_path / "transformer" / "config.json"
283
+ with open(config_path, "r") as f:
284
+ config = make_hashable_key(json.load(f))
285
+
286
+ assert config in diffusers_and_ours_config_mapping, (
287
+ "Provided diffusers checkpoint config for transformer is not suppported. "
288
+ "We only support diffusers configs found in Lightricks/LTX-Video."
289
+ )
290
+
291
+ config = diffusers_and_ours_config_mapping[config]
292
+ state_dict = {}
293
+ ckpt_paths = (
294
+ pretrained_model_path
295
+ / "transformer"
296
+ / "diffusion_pytorch_model*.safetensors"
297
+ )
298
+ dict_list = glob.glob(str(ckpt_paths))
299
+ for dict_path in dict_list:
300
+ part_dict = {}
301
+ with safe_open(dict_path, framework="pt", device="cpu") as f:
302
+ for k in f.keys():
303
+ part_dict[k] = f.get_tensor(k)
304
+ state_dict.update(part_dict)
305
+
306
+ for key in list(state_dict.keys()):
307
+ new_key = key
308
+ for replace_key, rename_key in TRANSFORMER_KEYS_RENAME_DICT.items():
309
+ new_key = new_key.replace(replace_key, rename_key)
310
+ state_dict[new_key] = state_dict.pop(key)
311
+
312
+ with torch.device("meta"):
313
+ transformer = cls.from_config(config)
314
+ transformer.load_state_dict(state_dict, assign=True, strict=True)
315
+ elif pretrained_model_path.is_file() and str(pretrained_model_path).endswith(
316
+ ".safetensors"
317
+ ):
318
+ comfy_single_file_state_dict = {}
319
+ with safe_open(pretrained_model_path, framework="pt", device="cpu") as f:
320
+ metadata = f.metadata()
321
+ for k in f.keys():
322
+ comfy_single_file_state_dict[k] = f.get_tensor(k)
323
+ configs = json.loads(metadata["config"])
324
+ transformer_config = configs["transformer"]
325
+ with torch.device("meta"):
326
+ transformer = Transformer3DModel.from_config(transformer_config)
327
+ transformer.load_state_dict(comfy_single_file_state_dict, assign=True)
328
+ return transformer
329
+
330
+ def forward(
331
+ self,
332
+ hidden_states: torch.Tensor,
333
+ indices_grid: torch.Tensor,
334
+ encoder_hidden_states: Optional[torch.Tensor] = None,
335
+ timestep: Optional[torch.LongTensor] = None,
336
+ class_labels: Optional[torch.LongTensor] = None,
337
+ cross_attention_kwargs: Dict[str, Any] = None,
338
+ attention_mask: Optional[torch.Tensor] = None,
339
+ encoder_attention_mask: Optional[torch.Tensor] = None,
340
+ skip_layer_mask: Optional[torch.Tensor] = None,
341
+ skip_layer_strategy: Optional[SkipLayerStrategy] = None,
342
+ return_dict: bool = True,
343
+ ):
344
+ """
345
+ The [`Transformer3DModel`] forward method.
346
+
347
+ Args:
348
+ hidden_states (`torch.FloatTensor` of shape `(batch size, num latent tokens, in_channels)`):
349
+ Input `hidden_states`.
350
+ indices_grid (`torch.LongTensor` of shape `(batch size, 3, num latent pixels)`):
+ Per-token (frame, row, column) latent coordinates, used to compute the rotary positional embeddings.
351
+ encoder_hidden_states ( `torch.FloatTensor` of shape `(batch size, sequence len, embed dims)`, *optional*):
352
+ Conditional embeddings for cross attention layer. If not given, cross-attention defaults to
353
+ self-attention.
354
+ timestep ( `torch.LongTensor`, *optional*):
355
+ Used to indicate denoising step. Optional timestep to be applied as an embedding in `AdaLayerNorm`.
356
+ class_labels ( `torch.LongTensor` of shape `(batch size, num classes)`, *optional*):
357
+ Used to indicate class labels conditioning. Optional class labels to be applied as an embedding in
358
+ `AdaLayerZeroNorm`.
359
+ cross_attention_kwargs ( `Dict[str, Any]`, *optional*):
360
+ A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
361
+ `self.processor` in
362
+ [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
363
+ attention_mask ( `torch.Tensor`, *optional*):
364
+ An attention mask of shape `(batch, key_tokens)` is applied to `encoder_hidden_states`. If `1` the mask
365
+ is kept, otherwise if `0` it is discarded. Mask will be converted into a bias, which adds large
366
+ negative values to the attention scores corresponding to "discard" tokens.
367
+ encoder_attention_mask ( `torch.Tensor`, *optional*):
368
+ Cross-attention mask applied to `encoder_hidden_states`. Two formats supported:
369
+
370
+ * Mask `(batch, sequence_length)` True = keep, False = discard.
371
+ * Bias `(batch, 1, sequence_length)` 0 = keep, -10000 = discard.
372
+
373
+ If `ndim == 2`: will be interpreted as a mask, then converted into a bias consistent with the format
374
+ above. This bias will be added to the cross-attention scores.
375
+ skip_layer_mask ( `torch.Tensor`, *optional*):
376
+ A mask of shape `(num_layers, batch)` that indicates which layers to skip. `0` at position
377
+ `layer, batch_idx` indicates that the layer should be skipped for the corresponding batch index.
378
+ skip_layer_strategy ( `SkipLayerStrategy`, *optional*, defaults to `None`):
379
+ Controls which layers are skipped when calculating a perturbed latent for spatiotemporal guidance.
380
+ return_dict (`bool`, *optional*, defaults to `True`):
381
+ Whether or not to return a [`Transformer3DModelOutput`] instead of a plain
382
+ tuple.
383
+
384
+ Returns:
385
+ If `return_dict` is True, a [`Transformer3DModelOutput`] is returned, otherwise a
386
+ `tuple` where the first element is the sample tensor.
387
+ """
388
+ # for tpu attention offload 2d token masks are used. No need to transform.
389
+ if not self.use_tpu_flash_attention:
390
+ # ensure attention_mask is a bias, and give it a singleton query_tokens dimension.
391
+ # we may have done this conversion already, e.g. if we came here via UNet2DConditionModel#forward.
392
+ # we can tell by counting dims; if ndim == 2: it's a mask rather than a bias.
393
+ # expects mask of shape:
394
+ # [batch, key_tokens]
395
+ # adds singleton query_tokens dimension:
396
+ # [batch, 1, key_tokens]
397
+ # this helps to broadcast it as a bias over attention scores, which will be in one of the following shapes:
398
+ # [batch, heads, query_tokens, key_tokens] (e.g. torch sdp attn)
399
+ # [batch * heads, query_tokens, key_tokens] (e.g. xformers or classic attn)
400
+ if attention_mask is not None and attention_mask.ndim == 2:
401
+ # assume that mask is expressed as:
402
+ # (1 = keep, 0 = discard)
403
+ # convert mask into a bias that can be added to attention scores:
404
+ # (keep = +0, discard = -10000.0)
405
+ attention_mask = (1 - attention_mask.to(hidden_states.dtype)) * -10000.0
406
+ attention_mask = attention_mask.unsqueeze(1)
407
+
408
+ # convert encoder_attention_mask to a bias the same way we do for attention_mask
409
+ if encoder_attention_mask is not None and encoder_attention_mask.ndim == 2:
410
+ encoder_attention_mask = (
411
+ 1 - encoder_attention_mask.to(hidden_states.dtype)
412
+ ) * -10000.0
413
+ encoder_attention_mask = encoder_attention_mask.unsqueeze(1)
414
+
415
+ # 1. Input
416
+ hidden_states = self.patchify_proj(hidden_states)
417
+
418
+ if self.timestep_scale_multiplier:
419
+ timestep = self.timestep_scale_multiplier * timestep
420
+
421
+ freqs_cis = self.precompute_freqs_cis(indices_grid)
422
+
423
+ batch_size = hidden_states.shape[0]
424
+ timestep, embedded_timestep = self.adaln_single(
425
+ timestep.flatten(),
426
+ {"resolution": None, "aspect_ratio": None},
427
+ batch_size=batch_size,
428
+ hidden_dtype=hidden_states.dtype,
429
+ )
430
+ # Second dimension is 1 or number of tokens (if timestep_per_token)
431
+ timestep = timestep.view(batch_size, -1, timestep.shape[-1])
432
+ embedded_timestep = embedded_timestep.view(
433
+ batch_size, -1, embedded_timestep.shape[-1]
434
+ )
435
+
436
+ # 2. Blocks
437
+ if self.caption_projection is not None:
438
+ batch_size = hidden_states.shape[0]
439
+ encoder_hidden_states = self.caption_projection(encoder_hidden_states)
440
+ encoder_hidden_states = encoder_hidden_states.view(
441
+ batch_size, -1, hidden_states.shape[-1]
442
+ )
443
+
444
+ for block_idx, block in enumerate(self.transformer_blocks):
445
+ if self.training and self.gradient_checkpointing:
446
+
447
+ def create_custom_forward(module, return_dict=None):
448
+ def custom_forward(*inputs):
449
+ if return_dict is not None:
450
+ return module(*inputs, return_dict=return_dict)
451
+ else:
452
+ return module(*inputs)
453
+
454
+ return custom_forward
455
+
456
+ ckpt_kwargs: Dict[str, Any] = (
457
+ {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
458
+ )
459
+ hidden_states = torch.utils.checkpoint.checkpoint(
460
+ create_custom_forward(block),
461
+ hidden_states,
462
+ freqs_cis,
463
+ attention_mask,
464
+ encoder_hidden_states,
465
+ encoder_attention_mask,
466
+ timestep,
467
+ cross_attention_kwargs,
468
+ class_labels,
469
+ (
470
+ skip_layer_mask[block_idx]
471
+ if skip_layer_mask is not None
472
+ else None
473
+ ),
474
+ skip_layer_strategy,
475
+ **ckpt_kwargs,
476
+ )
477
+ else:
478
+ hidden_states = block(
479
+ hidden_states,
480
+ freqs_cis=freqs_cis,
481
+ attention_mask=attention_mask,
482
+ encoder_hidden_states=encoder_hidden_states,
483
+ encoder_attention_mask=encoder_attention_mask,
484
+ timestep=timestep,
485
+ cross_attention_kwargs=cross_attention_kwargs,
486
+ class_labels=class_labels,
487
+ skip_layer_mask=(
488
+ skip_layer_mask[block_idx]
489
+ if skip_layer_mask is not None
490
+ else None
491
+ ),
492
+ skip_layer_strategy=skip_layer_strategy,
493
+ )
494
+
495
+ # 3. Output
496
+ scale_shift_values = (
497
+ self.scale_shift_table[None, None] + embedded_timestep[:, :, None]
498
+ )
499
+ shift, scale = scale_shift_values[:, :, 0], scale_shift_values[:, :, 1]
500
+ hidden_states = self.norm_out(hidden_states)
501
+ # Modulation
502
+ hidden_states = hidden_states * (1 + scale) + shift
503
+ hidden_states = self.proj_out(hidden_states)
504
+ if not return_dict:
505
+ return (hidden_states,)
506
+
507
+ return Transformer3DModelOutput(sample=hidden_states)
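A minimal end-to-end sketch of calling the model above (the configuration and all shapes here are hypothetical, chosen only to illustrate the call signature; real checkpoints ship their own config and should be loaded via `from_pretrained`):

import torch
from ltx_video.models.transformers.transformer3d import Transformer3DModel

model = Transformer3DModel(
    num_attention_heads=4,
    attention_head_dim=32,                 # inner_dim = 4 * 32 = 128
    in_channels=16,
    out_channels=16,
    num_layers=2,
    cross_attention_dim=128,
    caption_channels=64,
    positional_embedding_type="rope",
    positional_embedding_theta=10000.0,
    positional_embedding_max_pos=[20, 2048, 2048],
)
tokens = torch.randn(1, 128, 16)           # patchified latents: (batch, num_tokens, in_channels)
indices_grid = torch.zeros(1, 3, 128)      # per-token (frame, row, column) coordinates
text = torch.randn(1, 77, 64)              # caption embeddings: (batch, seq_len, caption_channels)
timestep = torch.tensor([500.0])
out = model(tokens, indices_grid, encoder_hidden_states=text, timestep=timestep)
print(out.sample.shape)                    # torch.Size([1, 128, 16])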