Sek2810 committed on
Commit 901fb88 · verified · 1 Parent(s): edb3cd2

Delete models

models/__init__.py DELETED
File without changes (empty file)
models/attention.py DELETED
@@ -1,1245 +0,0 @@
1
- # Copyright 2024 The HuggingFace Team. All rights reserved.
2
- #
3
- # Licensed under the Apache License, Version 2.0 (the "License");
4
- # you may not use this file except in compliance with the License.
5
- # You may obtain a copy of the License at
6
- #
7
- # http://www.apache.org/licenses/LICENSE-2.0
8
- #
9
- # Unless required by applicable law or agreed to in writing, software
10
- # distributed under the License is distributed on an "AS IS" BASIS,
11
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
- # See the License for the specific language governing permissions and
13
- # limitations under the License.
14
- from typing import Any, Dict, List, Optional, Tuple
15
-
16
- import torch
17
- import torch.nn.functional as F
18
- from torch import nn
19
-
20
- from diffusers.utils import deprecate, logging
21
- from diffusers.utils.torch_utils import maybe_allow_in_graph
22
- from diffusers.models.activations import GEGLU, GELU, ApproximateGELU, FP32SiLU, SwiGLU
23
- from diffusers.models.attention_processor import Attention, JointAttnProcessor2_0
24
- from diffusers.models.embeddings import SinusoidalPositionalEmbedding
25
- from diffusers.models.normalization import AdaLayerNorm, AdaLayerNormContinuous, AdaLayerNormZero, RMSNorm, SD35AdaLayerNormZeroX
26
-
27
-
28
- logger = logging.get_logger(__name__)
29
-
30
-
31
- def _chunked_feed_forward(ff: nn.Module, hidden_states: torch.Tensor, chunk_dim: int, chunk_size: int):
32
- # "feed_forward_chunk_size" can be used to save memory
33
- if hidden_states.shape[chunk_dim] % chunk_size != 0:
34
- raise ValueError(
35
- f"`hidden_states` dimension to be chunked: {hidden_states.shape[chunk_dim]} has to be divisible by chunk size: {chunk_size}. Make sure to set an appropriate `chunk_size` when calling `unet.enable_forward_chunking`."
36
- )
37
-
38
- num_chunks = hidden_states.shape[chunk_dim] // chunk_size
39
- ff_output = torch.cat(
40
- [ff(hid_slice) for hid_slice in hidden_states.chunk(num_chunks, dim=chunk_dim)],
41
- dim=chunk_dim,
42
- )
43
- return ff_output
44
-
45
-
46
- @maybe_allow_in_graph
47
- class GatedSelfAttentionDense(nn.Module):
48
- r"""
49
- A gated self-attention dense layer that combines visual features and object features.
50
-
51
- Parameters:
52
- query_dim (`int`): The number of channels in the query.
53
- context_dim (`int`): The number of channels in the context.
54
- n_heads (`int`): The number of heads to use for attention.
55
- d_head (`int`): The number of channels in each head.
56
- """
57
-
58
- def __init__(self, query_dim: int, context_dim: int, n_heads: int, d_head: int):
59
- super().__init__()
60
-
61
- # we need a linear projection since we concatenate the visual features and the object features
62
- self.linear = nn.Linear(context_dim, query_dim)
63
-
64
- self.attn = Attention(query_dim=query_dim, heads=n_heads, dim_head=d_head)
65
- self.ff = FeedForward(query_dim, activation_fn="geglu")
66
-
67
- self.norm1 = nn.LayerNorm(query_dim)
68
- self.norm2 = nn.LayerNorm(query_dim)
69
-
70
- self.register_parameter("alpha_attn", nn.Parameter(torch.tensor(0.0)))
71
- self.register_parameter("alpha_dense", nn.Parameter(torch.tensor(0.0)))
72
-
73
- self.enabled = True
74
-
75
- def forward(self, x: torch.Tensor, objs: torch.Tensor) -> torch.Tensor:
76
- if not self.enabled:
77
- return x
78
-
79
- n_visual = x.shape[1]
80
- objs = self.linear(objs)
81
-
82
- x = x + self.alpha_attn.tanh() * self.attn(self.norm1(torch.cat([x, objs], dim=1)))[:, :n_visual, :]
83
- x = x + self.alpha_dense.tanh() * self.ff(self.norm2(x))
84
-
85
- return x
86
-
87
-
88
- @maybe_allow_in_graph
89
- class JointTransformerBlock(nn.Module):
90
- r"""
91
- A Transformer block following the MMDiT architecture, introduced in Stable Diffusion 3.
92
-
93
- Reference: https://arxiv.org/abs/2403.03206
94
-
95
- Parameters:
96
- dim (`int`): The number of channels in the input and output.
97
- num_attention_heads (`int`): The number of heads to use for multi-head attention.
98
- attention_head_dim (`int`): The number of channels in each head.
99
- context_pre_only (`bool`): Boolean to determine if we should add some blocks associated with the
100
- processing of `context` conditions.
101
- """
102
-
103
- def __init__(
104
- self,
105
- dim: int,
106
- num_attention_heads: int,
107
- attention_head_dim: int,
108
- context_pre_only: bool = False,
109
- qk_norm: Optional[str] = None,
110
- use_dual_attention: bool = False,
111
- ):
112
- super().__init__()
113
-
114
- self.use_dual_attention = use_dual_attention
115
- self.context_pre_only = context_pre_only
116
- context_norm_type = "ada_norm_continous" if context_pre_only else "ada_norm_zero"
117
-
118
- if use_dual_attention:
119
- self.norm1 = SD35AdaLayerNormZeroX(dim)
120
- else:
121
- self.norm1 = AdaLayerNormZero(dim)
122
-
123
- if context_norm_type == "ada_norm_continous":
124
- self.norm1_context = AdaLayerNormContinuous(
125
- dim, dim, elementwise_affine=False, eps=1e-6, bias=True, norm_type="layer_norm"
126
- )
127
- elif context_norm_type == "ada_norm_zero":
128
- self.norm1_context = AdaLayerNormZero(dim)
129
- else:
130
- raise ValueError(
131
- f"Unknown context_norm_type: {context_norm_type}, currently only support `ada_norm_continous`, `ada_norm_zero`"
132
- )
133
-
134
- if hasattr(F, "scaled_dot_product_attention"):
135
- processor = JointAttnProcessor2_0()
136
- else:
137
- raise ValueError(
138
- "The current PyTorch version does not support the `scaled_dot_product_attention` function."
139
- )
140
-
141
- self.attn = Attention(
142
- query_dim=dim,
143
- cross_attention_dim=None,
144
- added_kv_proj_dim=dim,
145
- dim_head=attention_head_dim,
146
- heads=num_attention_heads,
147
- out_dim=dim,
148
- context_pre_only=context_pre_only,
149
- bias=True,
150
- processor=processor,
151
- qk_norm=qk_norm,
152
- eps=1e-6,
153
- )
154
-
155
- if use_dual_attention:
156
- self.attn2 = Attention(
157
- query_dim=dim,
158
- cross_attention_dim=None,
159
- dim_head=attention_head_dim,
160
- heads=num_attention_heads,
161
- out_dim=dim,
162
- bias=True,
163
- processor=processor,
164
- qk_norm=qk_norm,
165
- eps=1e-6,
166
- )
167
- else:
168
- self.attn2 = None
169
-
170
- self.norm2 = nn.LayerNorm(dim, elementwise_affine=False, eps=1e-6)
171
- self.ff = FeedForward(dim=dim, dim_out=dim, activation_fn="gelu-approximate")
172
-
173
- if not context_pre_only:
174
- self.norm2_context = nn.LayerNorm(dim, elementwise_affine=False, eps=1e-6)
175
- self.ff_context = FeedForward(dim=dim, dim_out=dim, activation_fn="gelu-approximate")
176
- else:
177
- self.norm2_context = None
178
- self.ff_context = None
179
-
180
- # let chunk size default to None
181
- self._chunk_size = None
182
- self._chunk_dim = 0
183
-
184
- # Copied from diffusers.models.attention.BasicTransformerBlock.set_chunk_feed_forward
185
- def set_chunk_feed_forward(self, chunk_size: Optional[int], dim: int = 0):
186
- # Sets chunk feed-forward
187
- self._chunk_size = chunk_size
188
- self._chunk_dim = dim
189
-
190
- def forward(
191
- self, hidden_states: torch.FloatTensor, encoder_hidden_states: torch.FloatTensor, temb: torch.FloatTensor,
192
- joint_attention_kwargs=None,
193
- ):
194
- if self.use_dual_attention:
195
- norm_hidden_states, gate_msa, shift_mlp, scale_mlp, gate_mlp, norm_hidden_states2, gate_msa2 = self.norm1(
196
- hidden_states, emb=temb
197
- )
198
- else:
199
- norm_hidden_states, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.norm1(hidden_states, emb=temb)
200
-
201
- if self.context_pre_only:
202
- norm_encoder_hidden_states = self.norm1_context(encoder_hidden_states, temb)
203
- else:
204
- norm_encoder_hidden_states, c_gate_msa, c_shift_mlp, c_scale_mlp, c_gate_mlp = self.norm1_context(
205
- encoder_hidden_states, emb=temb
206
- )
207
-
208
- # Attention.
209
- attn_output, context_attn_output = self.attn(
210
- hidden_states=norm_hidden_states, encoder_hidden_states=norm_encoder_hidden_states,
211
- **({} if joint_attention_kwargs is None else joint_attention_kwargs),
212
- )
213
-
214
- # Process attention outputs for the `hidden_states`.
215
- attn_output = gate_msa.unsqueeze(1) * attn_output
216
- hidden_states = hidden_states + attn_output
217
-
218
- if self.use_dual_attention:
219
- attn_output2 = self.attn2(hidden_states=norm_hidden_states2, **({} if joint_attention_kwargs is None else joint_attention_kwargs),)
220
- attn_output2 = gate_msa2.unsqueeze(1) * attn_output2
221
- hidden_states = hidden_states + attn_output2
222
-
223
- norm_hidden_states = self.norm2(hidden_states)
224
- norm_hidden_states = norm_hidden_states * (1 + scale_mlp[:, None]) + shift_mlp[:, None]
225
- if self._chunk_size is not None:
226
- # "feed_forward_chunk_size" can be used to save memory
227
- ff_output = _chunked_feed_forward(self.ff, norm_hidden_states, self._chunk_dim, self._chunk_size)
228
- else:
229
- ff_output = self.ff(norm_hidden_states)
230
- ff_output = gate_mlp.unsqueeze(1) * ff_output
231
-
232
- hidden_states = hidden_states + ff_output
233
-
234
- # Process attention outputs for the `encoder_hidden_states`.
235
- if self.context_pre_only:
236
- encoder_hidden_states = None
237
- else:
238
- context_attn_output = c_gate_msa.unsqueeze(1) * context_attn_output
239
- encoder_hidden_states = encoder_hidden_states + context_attn_output
240
-
241
- norm_encoder_hidden_states = self.norm2_context(encoder_hidden_states)
242
- norm_encoder_hidden_states = norm_encoder_hidden_states * (1 + c_scale_mlp[:, None]) + c_shift_mlp[:, None]
243
- if self._chunk_size is not None:
244
- # "feed_forward_chunk_size" can be used to save memory
245
- context_ff_output = _chunked_feed_forward(
246
- self.ff_context, norm_encoder_hidden_states, self._chunk_dim, self._chunk_size
247
- )
248
- else:
249
- context_ff_output = self.ff_context(norm_encoder_hidden_states)
250
- encoder_hidden_states = encoder_hidden_states + c_gate_mlp.unsqueeze(1) * context_ff_output
251
-
252
- return encoder_hidden_states, hidden_states
253
-
254
-
255
- @maybe_allow_in_graph
256
- class BasicTransformerBlock(nn.Module):
257
- r"""
258
- A basic Transformer block.
259
-
260
- Parameters:
261
- dim (`int`): The number of channels in the input and output.
262
- num_attention_heads (`int`): The number of heads to use for multi-head attention.
263
- attention_head_dim (`int`): The number of channels in each head.
264
- dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use.
265
- cross_attention_dim (`int`, *optional*): The size of the encoder_hidden_states vector for cross attention.
266
- activation_fn (`str`, *optional*, defaults to `"geglu"`): Activation function to be used in feed-forward.
267
- num_embeds_ada_norm (`int`, *optional*):
268
- The number of diffusion steps used during training. See `Transformer2DModel`.
269
- attention_bias (`bool`, *optional*, defaults to `False`):
270
- Configure if the attentions should contain a bias parameter.
271
- only_cross_attention (`bool`, *optional*):
272
- Whether to use only cross-attention layers. In this case two cross attention layers are used.
273
- double_self_attention (`bool`, *optional*):
274
- Whether to use two self-attention layers. In this case no cross attention layers are used.
275
- upcast_attention (`bool`, *optional*):
276
- Whether to upcast the attention computation to float32. This is useful for mixed precision training.
277
- norm_elementwise_affine (`bool`, *optional*, defaults to `True`):
278
- Whether to use learnable elementwise affine parameters for normalization.
279
- norm_type (`str`, *optional*, defaults to `"layer_norm"`):
280
- The normalization layer to use. Can be `"layer_norm"`, `"ada_norm"` or `"ada_norm_zero"`.
281
- final_dropout (`bool` *optional*, defaults to False):
282
- Whether to apply a final dropout after the last feed-forward layer.
283
- attention_type (`str`, *optional*, defaults to `"default"`):
284
- The type of attention to use. Can be `"default"` or `"gated"` or `"gated-text-image"`.
285
- positional_embeddings (`str`, *optional*, defaults to `None`):
286
- The type of positional embeddings to apply.
287
- num_positional_embeddings (`int`, *optional*, defaults to `None`):
288
- The maximum number of positional embeddings to apply.
289
- """
290
-
291
- def __init__(
292
- self,
293
- dim: int,
294
- num_attention_heads: int,
295
- attention_head_dim: int,
296
- dropout=0.0,
297
- cross_attention_dim: Optional[int] = None,
298
- activation_fn: str = "geglu",
299
- num_embeds_ada_norm: Optional[int] = None,
300
- attention_bias: bool = False,
301
- only_cross_attention: bool = False,
302
- double_self_attention: bool = False,
303
- upcast_attention: bool = False,
304
- norm_elementwise_affine: bool = True,
305
- norm_type: str = "layer_norm", # 'layer_norm', 'ada_norm', 'ada_norm_zero', 'ada_norm_single', 'ada_norm_continuous', 'layer_norm_i2vgen'
306
- norm_eps: float = 1e-5,
307
- final_dropout: bool = False,
308
- attention_type: str = "default",
309
- positional_embeddings: Optional[str] = None,
310
- num_positional_embeddings: Optional[int] = None,
311
- ada_norm_continous_conditioning_embedding_dim: Optional[int] = None,
312
- ada_norm_bias: Optional[int] = None,
313
- ff_inner_dim: Optional[int] = None,
314
- ff_bias: bool = True,
315
- attention_out_bias: bool = True,
316
- ):
317
- super().__init__()
318
- self.dim = dim
319
- self.num_attention_heads = num_attention_heads
320
- self.attention_head_dim = attention_head_dim
321
- self.dropout = dropout
322
- self.cross_attention_dim = cross_attention_dim
323
- self.activation_fn = activation_fn
324
- self.attention_bias = attention_bias
325
- self.double_self_attention = double_self_attention
326
- self.norm_elementwise_affine = norm_elementwise_affine
327
- self.positional_embeddings = positional_embeddings
328
- self.num_positional_embeddings = num_positional_embeddings
329
- self.only_cross_attention = only_cross_attention
330
-
331
- # We keep these boolean flags for backward-compatibility.
332
- self.use_ada_layer_norm_zero = (num_embeds_ada_norm is not None) and norm_type == "ada_norm_zero"
333
- self.use_ada_layer_norm = (num_embeds_ada_norm is not None) and norm_type == "ada_norm"
334
- self.use_ada_layer_norm_single = norm_type == "ada_norm_single"
335
- self.use_layer_norm = norm_type == "layer_norm"
336
- self.use_ada_layer_norm_continuous = norm_type == "ada_norm_continuous"
337
-
338
- if norm_type in ("ada_norm", "ada_norm_zero") and num_embeds_ada_norm is None:
339
- raise ValueError(
340
- f"`norm_type` is set to {norm_type}, but `num_embeds_ada_norm` is not defined. Please make sure to"
341
- f" define `num_embeds_ada_norm` if setting `norm_type` to {norm_type}."
342
- )
343
-
344
- self.norm_type = norm_type
345
- self.num_embeds_ada_norm = num_embeds_ada_norm
346
-
347
- if positional_embeddings and (num_positional_embeddings is None):
348
- raise ValueError(
349
- "If `positional_embedding` type is defined, `num_positition_embeddings` must also be defined."
350
- )
351
-
352
- if positional_embeddings == "sinusoidal":
353
- self.pos_embed = SinusoidalPositionalEmbedding(dim, max_seq_length=num_positional_embeddings)
354
- else:
355
- self.pos_embed = None
356
-
357
- # Define 3 blocks. Each block has its own normalization layer.
358
- # 1. Self-Attn
359
- if norm_type == "ada_norm":
360
- self.norm1 = AdaLayerNorm(dim, num_embeds_ada_norm)
361
- elif norm_type == "ada_norm_zero":
362
- self.norm1 = AdaLayerNormZero(dim, num_embeds_ada_norm)
363
- elif norm_type == "ada_norm_continuous":
364
- self.norm1 = AdaLayerNormContinuous(
365
- dim,
366
- ada_norm_continous_conditioning_embedding_dim,
367
- norm_elementwise_affine,
368
- norm_eps,
369
- ada_norm_bias,
370
- "rms_norm",
371
- )
372
- else:
373
- self.norm1 = nn.LayerNorm(dim, elementwise_affine=norm_elementwise_affine, eps=norm_eps)
374
-
375
- self.attn1 = Attention(
376
- query_dim=dim,
377
- heads=num_attention_heads,
378
- dim_head=attention_head_dim,
379
- dropout=dropout,
380
- bias=attention_bias,
381
- cross_attention_dim=cross_attention_dim if only_cross_attention else None,
382
- upcast_attention=upcast_attention,
383
- out_bias=attention_out_bias,
384
- )
385
-
386
- # 2. Cross-Attn
387
- if cross_attention_dim is not None or double_self_attention:
388
- # We currently only use AdaLayerNormZero for self attention where there will only be one attention block.
389
- # I.e. the number of returned modulation chunks from AdaLayerNormZero would not make sense if returned during
390
- # the second cross attention block.
391
- if norm_type == "ada_norm":
392
- self.norm2 = AdaLayerNorm(dim, num_embeds_ada_norm)
393
- elif norm_type == "ada_norm_continuous":
394
- self.norm2 = AdaLayerNormContinuous(
395
- dim,
396
- ada_norm_continous_conditioning_embedding_dim,
397
- norm_elementwise_affine,
398
- norm_eps,
399
- ada_norm_bias,
400
- "rms_norm",
401
- )
402
- else:
403
- self.norm2 = nn.LayerNorm(dim, norm_eps, norm_elementwise_affine)
404
-
405
- self.attn2 = Attention(
406
- query_dim=dim,
407
- cross_attention_dim=cross_attention_dim if not double_self_attention else None,
408
- heads=num_attention_heads,
409
- dim_head=attention_head_dim,
410
- dropout=dropout,
411
- bias=attention_bias,
412
- upcast_attention=upcast_attention,
413
- out_bias=attention_out_bias,
414
- ) # is self-attn if encoder_hidden_states is none
415
- else:
416
- if norm_type == "ada_norm_single": # For Latte
417
- self.norm2 = nn.LayerNorm(dim, norm_eps, norm_elementwise_affine)
418
- else:
419
- self.norm2 = None
420
- self.attn2 = None
421
-
422
- # 3. Feed-forward
423
- if norm_type == "ada_norm_continuous":
424
- self.norm3 = AdaLayerNormContinuous(
425
- dim,
426
- ada_norm_continous_conditioning_embedding_dim,
427
- norm_elementwise_affine,
428
- norm_eps,
429
- ada_norm_bias,
430
- "layer_norm",
431
- )
432
-
433
- elif norm_type in ["ada_norm_zero", "ada_norm", "layer_norm"]:
434
- self.norm3 = nn.LayerNorm(dim, norm_eps, norm_elementwise_affine)
435
- elif norm_type == "layer_norm_i2vgen":
436
- self.norm3 = None
437
-
438
- self.ff = FeedForward(
439
- dim,
440
- dropout=dropout,
441
- activation_fn=activation_fn,
442
- final_dropout=final_dropout,
443
- inner_dim=ff_inner_dim,
444
- bias=ff_bias,
445
- )
446
-
447
- # 4. Fuser
448
- if attention_type == "gated" or attention_type == "gated-text-image":
449
- self.fuser = GatedSelfAttentionDense(dim, cross_attention_dim, num_attention_heads, attention_head_dim)
450
-
451
- # 5. Scale-shift for PixArt-Alpha.
452
- if norm_type == "ada_norm_single":
453
- self.scale_shift_table = nn.Parameter(torch.randn(6, dim) / dim**0.5)
454
-
455
- # let chunk size default to None
456
- self._chunk_size = None
457
- self._chunk_dim = 0
458
-
459
- def set_chunk_feed_forward(self, chunk_size: Optional[int], dim: int = 0):
460
- # Sets chunk feed-forward
461
- self._chunk_size = chunk_size
462
- self._chunk_dim = dim
463
-
464
- def forward(
465
- self,
466
- hidden_states: torch.Tensor,
467
- attention_mask: Optional[torch.Tensor] = None,
468
- encoder_hidden_states: Optional[torch.Tensor] = None,
469
- encoder_attention_mask: Optional[torch.Tensor] = None,
470
- timestep: Optional[torch.LongTensor] = None,
471
- cross_attention_kwargs: Dict[str, Any] = None,
472
- class_labels: Optional[torch.LongTensor] = None,
473
- added_cond_kwargs: Optional[Dict[str, torch.Tensor]] = None,
474
- ) -> torch.Tensor:
475
- if cross_attention_kwargs is not None:
476
- if cross_attention_kwargs.get("scale", None) is not None:
477
- logger.warning("Passing `scale` to `cross_attention_kwargs` is deprecated. `scale` will be ignored.")
478
-
479
- # Notice that normalization is always applied before the real computation in the following blocks.
480
- # 0. Self-Attention
481
- batch_size = hidden_states.shape[0]
482
-
483
- if self.norm_type == "ada_norm":
484
- norm_hidden_states = self.norm1(hidden_states, timestep)
485
- elif self.norm_type == "ada_norm_zero":
486
- norm_hidden_states, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.norm1(
487
- hidden_states, timestep, class_labels, hidden_dtype=hidden_states.dtype
488
- )
489
- elif self.norm_type in ["layer_norm", "layer_norm_i2vgen"]:
490
- norm_hidden_states = self.norm1(hidden_states)
491
- elif self.norm_type == "ada_norm_continuous":
492
- norm_hidden_states = self.norm1(hidden_states, added_cond_kwargs["pooled_text_emb"])
493
- elif self.norm_type == "ada_norm_single":
494
- shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = (
495
- self.scale_shift_table[None] + timestep.reshape(batch_size, 6, -1)
496
- ).chunk(6, dim=1)
497
- norm_hidden_states = self.norm1(hidden_states)
498
- norm_hidden_states = norm_hidden_states * (1 + scale_msa) + shift_msa
499
- else:
500
- raise ValueError("Incorrect norm used")
501
-
502
- if self.pos_embed is not None:
503
- norm_hidden_states = self.pos_embed(norm_hidden_states)
504
-
505
- # 1. Prepare GLIGEN inputs
506
- cross_attention_kwargs = cross_attention_kwargs.copy() if cross_attention_kwargs is not None else {}
507
- gligen_kwargs = cross_attention_kwargs.pop("gligen", None)
508
-
509
- attn_output = self.attn1(
510
- norm_hidden_states,
511
- encoder_hidden_states=encoder_hidden_states if self.only_cross_attention else None,
512
- attention_mask=attention_mask,
513
- **cross_attention_kwargs,
514
- )
515
-
516
- if self.norm_type == "ada_norm_zero":
517
- attn_output = gate_msa.unsqueeze(1) * attn_output
518
- elif self.norm_type == "ada_norm_single":
519
- attn_output = gate_msa * attn_output
520
-
521
- hidden_states = attn_output + hidden_states
522
- if hidden_states.ndim == 4:
523
- hidden_states = hidden_states.squeeze(1)
524
-
525
- # 1.2 GLIGEN Control
526
- if gligen_kwargs is not None:
527
- hidden_states = self.fuser(hidden_states, gligen_kwargs["objs"])
528
-
529
- # 3. Cross-Attention
530
- if self.attn2 is not None:
531
- if self.norm_type == "ada_norm":
532
- norm_hidden_states = self.norm2(hidden_states, timestep)
533
- elif self.norm_type in ["ada_norm_zero", "layer_norm", "layer_norm_i2vgen"]:
534
- norm_hidden_states = self.norm2(hidden_states)
535
- elif self.norm_type == "ada_norm_single":
536
- # For PixArt norm2 isn't applied here:
537
- # https://github.com/PixArt-alpha/PixArt-alpha/blob/0f55e922376d8b797edd44d25d0e7464b260dcab/diffusion/model/nets/PixArtMS.py#L70C1-L76C103
538
- norm_hidden_states = hidden_states
539
- elif self.norm_type == "ada_norm_continuous":
540
- norm_hidden_states = self.norm2(hidden_states, added_cond_kwargs["pooled_text_emb"])
541
- else:
542
- raise ValueError("Incorrect norm")
543
-
544
- if self.pos_embed is not None and self.norm_type != "ada_norm_single":
545
- norm_hidden_states = self.pos_embed(norm_hidden_states)
546
-
547
- attn_output = self.attn2(
548
- norm_hidden_states,
549
- encoder_hidden_states=encoder_hidden_states,
550
- attention_mask=encoder_attention_mask,
551
- **cross_attention_kwargs,
552
- )
553
- hidden_states = attn_output + hidden_states
554
-
555
- # 4. Feed-forward
556
- # i2vgen doesn't have this norm 🤷‍♂️
557
- if self.norm_type == "ada_norm_continuous":
558
- norm_hidden_states = self.norm3(hidden_states, added_cond_kwargs["pooled_text_emb"])
559
- elif not self.norm_type == "ada_norm_single":
560
- norm_hidden_states = self.norm3(hidden_states)
561
-
562
- if self.norm_type == "ada_norm_zero":
563
- norm_hidden_states = norm_hidden_states * (1 + scale_mlp[:, None]) + shift_mlp[:, None]
564
-
565
- if self.norm_type == "ada_norm_single":
566
- norm_hidden_states = self.norm2(hidden_states)
567
- norm_hidden_states = norm_hidden_states * (1 + scale_mlp) + shift_mlp
568
-
569
- if self._chunk_size is not None:
570
- # "feed_forward_chunk_size" can be used to save memory
571
- ff_output = _chunked_feed_forward(self.ff, norm_hidden_states, self._chunk_dim, self._chunk_size)
572
- else:
573
- ff_output = self.ff(norm_hidden_states)
574
-
575
- if self.norm_type == "ada_norm_zero":
576
- ff_output = gate_mlp.unsqueeze(1) * ff_output
577
- elif self.norm_type == "ada_norm_single":
578
- ff_output = gate_mlp * ff_output
579
-
580
- hidden_states = ff_output + hidden_states
581
- if hidden_states.ndim == 4:
582
- hidden_states = hidden_states.squeeze(1)
583
-
584
- return hidden_states
585
-
586
-
587
- class LuminaFeedForward(nn.Module):
588
- r"""
589
- A feed-forward layer.
590
-
591
- Parameters:
592
- dim (`int`):
593
- The dimensionality of the hidden layers in the model. This parameter determines the width of the model's
594
- hidden representations.
595
- inner_dim (`int`): The intermediate dimension of the feed-forward layer.
596
- multiple_of (`int`, *optional*, defaults to 256): Value to ensure the hidden dimension is a multiple
597
- of this value.
598
- ffn_dim_multiplier (`float`, *optional*): Custom multiplier for the hidden
599
- dimension. Defaults to `None`.
600
- """
601
-
602
- def __init__(
603
- self,
604
- dim: int,
605
- inner_dim: int,
606
- multiple_of: Optional[int] = 256,
607
- ffn_dim_multiplier: Optional[float] = None,
608
- ):
609
- super().__init__()
610
- inner_dim = int(2 * inner_dim / 3)
611
- # custom hidden_size factor multiplier
612
- if ffn_dim_multiplier is not None:
613
- inner_dim = int(ffn_dim_multiplier * inner_dim)
614
- inner_dim = multiple_of * ((inner_dim + multiple_of - 1) // multiple_of)
615
-
616
- self.linear_1 = nn.Linear(
617
- dim,
618
- inner_dim,
619
- bias=False,
620
- )
621
- self.linear_2 = nn.Linear(
622
- inner_dim,
623
- dim,
624
- bias=False,
625
- )
626
- self.linear_3 = nn.Linear(
627
- dim,
628
- inner_dim,
629
- bias=False,
630
- )
631
- self.silu = FP32SiLU()
632
-
633
- def forward(self, x):
634
- return self.linear_2(self.silu(self.linear_1(x)) * self.linear_3(x))
635
-
636
-
637
- @maybe_allow_in_graph
638
- class TemporalBasicTransformerBlock(nn.Module):
639
- r"""
640
- A basic Transformer block for video like data.
641
-
642
- Parameters:
643
- dim (`int`): The number of channels in the input and output.
644
- time_mix_inner_dim (`int`): The number of channels for temporal attention.
645
- num_attention_heads (`int`): The number of heads to use for multi-head attention.
646
- attention_head_dim (`int`): The number of channels in each head.
647
- cross_attention_dim (`int`, *optional*): The size of the encoder_hidden_states vector for cross attention.
648
- """
649
-
650
- def __init__(
651
- self,
652
- dim: int,
653
- time_mix_inner_dim: int,
654
- num_attention_heads: int,
655
- attention_head_dim: int,
656
- cross_attention_dim: Optional[int] = None,
657
- ):
658
- super().__init__()
659
- self.is_res = dim == time_mix_inner_dim
660
-
661
- self.norm_in = nn.LayerNorm(dim)
662
-
663
- # Define 3 blocks. Each block has its own normalization layer.
664
- # 1. Self-Attn
665
- self.ff_in = FeedForward(
666
- dim,
667
- dim_out=time_mix_inner_dim,
668
- activation_fn="geglu",
669
- )
670
-
671
- self.norm1 = nn.LayerNorm(time_mix_inner_dim)
672
- self.attn1 = Attention(
673
- query_dim=time_mix_inner_dim,
674
- heads=num_attention_heads,
675
- dim_head=attention_head_dim,
676
- cross_attention_dim=None,
677
- )
678
-
679
- # 2. Cross-Attn
680
- if cross_attention_dim is not None:
681
- # We currently only use AdaLayerNormZero for self attention where there will only be one attention block.
682
- # I.e. the number of returned modulation chunks from AdaLayerNormZero would not make sense if returned during
683
- # the second cross attention block.
684
- self.norm2 = nn.LayerNorm(time_mix_inner_dim)
685
- self.attn2 = Attention(
686
- query_dim=time_mix_inner_dim,
687
- cross_attention_dim=cross_attention_dim,
688
- heads=num_attention_heads,
689
- dim_head=attention_head_dim,
690
- ) # is self-attn if encoder_hidden_states is none
691
- else:
692
- self.norm2 = None
693
- self.attn2 = None
694
-
695
- # 3. Feed-forward
696
- self.norm3 = nn.LayerNorm(time_mix_inner_dim)
697
- self.ff = FeedForward(time_mix_inner_dim, activation_fn="geglu")
698
-
699
- # let chunk size default to None
700
- self._chunk_size = None
701
- self._chunk_dim = None
702
-
703
- def set_chunk_feed_forward(self, chunk_size: Optional[int], **kwargs):
704
- # Sets chunk feed-forward
705
- self._chunk_size = chunk_size
706
- # chunk dim should be hardcoded to 1 to have better speed vs. memory trade-off
707
- self._chunk_dim = 1
708
-
709
- def forward(
710
- self,
711
- hidden_states: torch.Tensor,
712
- num_frames: int,
713
- encoder_hidden_states: Optional[torch.Tensor] = None,
714
- ) -> torch.Tensor:
715
- # Notice that normalization is always applied before the real computation in the following blocks.
716
- # 0. Self-Attention
717
- batch_size = hidden_states.shape[0]
718
-
719
- batch_frames, seq_length, channels = hidden_states.shape
720
- batch_size = batch_frames // num_frames
721
-
722
- hidden_states = hidden_states[None, :].reshape(batch_size, num_frames, seq_length, channels)
723
- hidden_states = hidden_states.permute(0, 2, 1, 3)
724
- hidden_states = hidden_states.reshape(batch_size * seq_length, num_frames, channels)
725
-
726
- residual = hidden_states
727
- hidden_states = self.norm_in(hidden_states)
728
-
729
- if self._chunk_size is not None:
730
- hidden_states = _chunked_feed_forward(self.ff_in, hidden_states, self._chunk_dim, self._chunk_size)
731
- else:
732
- hidden_states = self.ff_in(hidden_states)
733
-
734
- if self.is_res:
735
- hidden_states = hidden_states + residual
736
-
737
- norm_hidden_states = self.norm1(hidden_states)
738
- attn_output = self.attn1(norm_hidden_states, encoder_hidden_states=None)
739
- hidden_states = attn_output + hidden_states
740
-
741
- # 3. Cross-Attention
742
- if self.attn2 is not None:
743
- norm_hidden_states = self.norm2(hidden_states)
744
- attn_output = self.attn2(norm_hidden_states, encoder_hidden_states=encoder_hidden_states)
745
- hidden_states = attn_output + hidden_states
746
-
747
- # 4. Feed-forward
748
- norm_hidden_states = self.norm3(hidden_states)
749
-
750
- if self._chunk_size is not None:
751
- ff_output = _chunked_feed_forward(self.ff, norm_hidden_states, self._chunk_dim, self._chunk_size)
752
- else:
753
- ff_output = self.ff(norm_hidden_states)
754
-
755
- if self.is_res:
756
- hidden_states = ff_output + hidden_states
757
- else:
758
- hidden_states = ff_output
759
-
760
- hidden_states = hidden_states[None, :].reshape(batch_size, seq_length, num_frames, channels)
761
- hidden_states = hidden_states.permute(0, 2, 1, 3)
762
- hidden_states = hidden_states.reshape(batch_size * num_frames, seq_length, channels)
763
-
764
- return hidden_states
765
-
766
-
767
- class SkipFFTransformerBlock(nn.Module):
768
- def __init__(
769
- self,
770
- dim: int,
771
- num_attention_heads: int,
772
- attention_head_dim: int,
773
- kv_input_dim: int,
774
- kv_input_dim_proj_use_bias: bool,
775
- dropout=0.0,
776
- cross_attention_dim: Optional[int] = None,
777
- attention_bias: bool = False,
778
- attention_out_bias: bool = True,
779
- ):
780
- super().__init__()
781
- if kv_input_dim != dim:
782
- self.kv_mapper = nn.Linear(kv_input_dim, dim, kv_input_dim_proj_use_bias)
783
- else:
784
- self.kv_mapper = None
785
-
786
- self.norm1 = RMSNorm(dim, 1e-06)
787
-
788
- self.attn1 = Attention(
789
- query_dim=dim,
790
- heads=num_attention_heads,
791
- dim_head=attention_head_dim,
792
- dropout=dropout,
793
- bias=attention_bias,
794
- cross_attention_dim=cross_attention_dim,
795
- out_bias=attention_out_bias,
796
- )
797
-
798
- self.norm2 = RMSNorm(dim, 1e-06)
799
-
800
- self.attn2 = Attention(
801
- query_dim=dim,
802
- cross_attention_dim=cross_attention_dim,
803
- heads=num_attention_heads,
804
- dim_head=attention_head_dim,
805
- dropout=dropout,
806
- bias=attention_bias,
807
- out_bias=attention_out_bias,
808
- )
809
-
810
- def forward(self, hidden_states, encoder_hidden_states, cross_attention_kwargs):
811
- cross_attention_kwargs = cross_attention_kwargs.copy() if cross_attention_kwargs is not None else {}
812
-
813
- if self.kv_mapper is not None:
814
- encoder_hidden_states = self.kv_mapper(F.silu(encoder_hidden_states))
815
-
816
- norm_hidden_states = self.norm1(hidden_states)
817
-
818
- attn_output = self.attn1(
819
- norm_hidden_states,
820
- encoder_hidden_states=encoder_hidden_states,
821
- **cross_attention_kwargs,
822
- )
823
-
824
- hidden_states = attn_output + hidden_states
825
-
826
- norm_hidden_states = self.norm2(hidden_states)
827
-
828
- attn_output = self.attn2(
829
- norm_hidden_states,
830
- encoder_hidden_states=encoder_hidden_states,
831
- **cross_attention_kwargs,
832
- )
833
-
834
- hidden_states = attn_output + hidden_states
835
-
836
- return hidden_states
837
-
838
-
839
- @maybe_allow_in_graph
840
- class FreeNoiseTransformerBlock(nn.Module):
841
- r"""
842
- A FreeNoise Transformer block.
843
-
844
- Parameters:
845
- dim (`int`):
846
- The number of channels in the input and output.
847
- num_attention_heads (`int`):
848
- The number of heads to use for multi-head attention.
849
- attention_head_dim (`int`):
850
- The number of channels in each head.
851
- dropout (`float`, *optional*, defaults to 0.0):
852
- The dropout probability to use.
853
- cross_attention_dim (`int`, *optional*):
854
- The size of the encoder_hidden_states vector for cross attention.
855
- activation_fn (`str`, *optional*, defaults to `"geglu"`):
856
- Activation function to be used in feed-forward.
857
- num_embeds_ada_norm (`int`, *optional*):
858
- The number of diffusion steps used during training. See `Transformer2DModel`.
859
- attention_bias (`bool`, defaults to `False`):
860
- Configure if the attentions should contain a bias parameter.
861
- only_cross_attention (`bool`, defaults to `False`):
862
- Whether to use only cross-attention layers. In this case two cross attention layers are used.
863
- double_self_attention (`bool`, defaults to `False`):
864
- Whether to use two self-attention layers. In this case no cross attention layers are used.
865
- upcast_attention (`bool`, defaults to `False`):
866
- Whether to upcast the attention computation to float32. This is useful for mixed precision training.
867
- norm_elementwise_affine (`bool`, defaults to `True`):
868
- Whether to use learnable elementwise affine parameters for normalization.
869
- norm_type (`str`, defaults to `"layer_norm"`):
870
- The normalization layer to use. Can be `"layer_norm"`, `"ada_norm"` or `"ada_norm_zero"`.
871
- final_dropout (`bool` defaults to `False`):
872
- Whether to apply a final dropout after the last feed-forward layer.
873
- attention_type (`str`, defaults to `"default"`):
874
- The type of attention to use. Can be `"default"` or `"gated"` or `"gated-text-image"`.
875
- positional_embeddings (`str`, *optional*):
876
- The type of positional embeddings to apply.
877
- num_positional_embeddings (`int`, *optional*, defaults to `None`):
878
- The maximum number of positional embeddings to apply.
879
- ff_inner_dim (`int`, *optional*):
880
- Hidden dimension of feed-forward MLP.
881
- ff_bias (`bool`, defaults to `True`):
882
- Whether or not to use bias in feed-forward MLP.
883
- attention_out_bias (`bool`, defaults to `True`):
884
- Whether or not to use bias in attention output project layer.
885
- context_length (`int`, defaults to `16`):
886
- The maximum number of frames that the FreeNoise block processes at once.
887
- context_stride (`int`, defaults to `4`):
888
- The number of frames to be skipped before starting to process a new batch of `context_length` frames.
889
- weighting_scheme (`str`, defaults to `"pyramid"`):
890
- The weighting scheme to use for weighting averaging of processed latent frames. As described in the
891
- Equation 9. of the [FreeNoise](https://arxiv.org/abs/2310.15169) paper, "pyramid" is the default setting
892
- used.
893
- """
894
-
895
- def __init__(
896
- self,
897
- dim: int,
898
- num_attention_heads: int,
899
- attention_head_dim: int,
900
- dropout: float = 0.0,
901
- cross_attention_dim: Optional[int] = None,
902
- activation_fn: str = "geglu",
903
- num_embeds_ada_norm: Optional[int] = None,
904
- attention_bias: bool = False,
905
- only_cross_attention: bool = False,
906
- double_self_attention: bool = False,
907
- upcast_attention: bool = False,
908
- norm_elementwise_affine: bool = True,
909
- norm_type: str = "layer_norm",
910
- norm_eps: float = 1e-5,
911
- final_dropout: bool = False,
912
- positional_embeddings: Optional[str] = None,
913
- num_positional_embeddings: Optional[int] = None,
914
- ff_inner_dim: Optional[int] = None,
915
- ff_bias: bool = True,
916
- attention_out_bias: bool = True,
917
- context_length: int = 16,
918
- context_stride: int = 4,
919
- weighting_scheme: str = "pyramid",
920
- ):
921
- super().__init__()
922
- self.dim = dim
923
- self.num_attention_heads = num_attention_heads
924
- self.attention_head_dim = attention_head_dim
925
- self.dropout = dropout
926
- self.cross_attention_dim = cross_attention_dim
927
- self.activation_fn = activation_fn
928
- self.attention_bias = attention_bias
929
- self.double_self_attention = double_self_attention
930
- self.norm_elementwise_affine = norm_elementwise_affine
931
- self.positional_embeddings = positional_embeddings
932
- self.num_positional_embeddings = num_positional_embeddings
933
- self.only_cross_attention = only_cross_attention
934
-
935
- self.set_free_noise_properties(context_length, context_stride, weighting_scheme)
936
-
937
- # We keep these boolean flags for backward-compatibility.
938
- self.use_ada_layer_norm_zero = (num_embeds_ada_norm is not None) and norm_type == "ada_norm_zero"
939
- self.use_ada_layer_norm = (num_embeds_ada_norm is not None) and norm_type == "ada_norm"
940
- self.use_ada_layer_norm_single = norm_type == "ada_norm_single"
941
- self.use_layer_norm = norm_type == "layer_norm"
942
- self.use_ada_layer_norm_continuous = norm_type == "ada_norm_continuous"
943
-
944
- if norm_type in ("ada_norm", "ada_norm_zero") and num_embeds_ada_norm is None:
945
- raise ValueError(
946
- f"`norm_type` is set to {norm_type}, but `num_embeds_ada_norm` is not defined. Please make sure to"
947
- f" define `num_embeds_ada_norm` if setting `norm_type` to {norm_type}."
948
- )
949
-
950
- self.norm_type = norm_type
951
- self.num_embeds_ada_norm = num_embeds_ada_norm
952
-
953
- if positional_embeddings and (num_positional_embeddings is None):
954
- raise ValueError(
955
- "If `positional_embedding` type is defined, `num_positition_embeddings` must also be defined."
956
- )
957
-
958
- if positional_embeddings == "sinusoidal":
959
- self.pos_embed = SinusoidalPositionalEmbedding(dim, max_seq_length=num_positional_embeddings)
960
- else:
961
- self.pos_embed = None
962
-
963
- # Define 3 blocks. Each block has its own normalization layer.
964
- # 1. Self-Attn
965
- self.norm1 = nn.LayerNorm(dim, elementwise_affine=norm_elementwise_affine, eps=norm_eps)
966
-
967
- self.attn1 = Attention(
968
- query_dim=dim,
969
- heads=num_attention_heads,
970
- dim_head=attention_head_dim,
971
- dropout=dropout,
972
- bias=attention_bias,
973
- cross_attention_dim=cross_attention_dim if only_cross_attention else None,
974
- upcast_attention=upcast_attention,
975
- out_bias=attention_out_bias,
976
- )
977
-
978
- # 2. Cross-Attn
979
- if cross_attention_dim is not None or double_self_attention:
980
- self.norm2 = nn.LayerNorm(dim, norm_eps, norm_elementwise_affine)
981
-
982
- self.attn2 = Attention(
983
- query_dim=dim,
984
- cross_attention_dim=cross_attention_dim if not double_self_attention else None,
985
- heads=num_attention_heads,
986
- dim_head=attention_head_dim,
987
- dropout=dropout,
988
- bias=attention_bias,
989
- upcast_attention=upcast_attention,
990
- out_bias=attention_out_bias,
991
- ) # is self-attn if encoder_hidden_states is none
992
-
993
- # 3. Feed-forward
994
- self.ff = FeedForward(
995
- dim,
996
- dropout=dropout,
997
- activation_fn=activation_fn,
998
- final_dropout=final_dropout,
999
- inner_dim=ff_inner_dim,
1000
- bias=ff_bias,
1001
- )
1002
-
1003
- self.norm3 = nn.LayerNorm(dim, norm_eps, norm_elementwise_affine)
1004
-
1005
- # let chunk size default to None
1006
- self._chunk_size = None
1007
- self._chunk_dim = 0
1008
-
1009
- def _get_frame_indices(self, num_frames: int) -> List[Tuple[int, int]]:
1010
- frame_indices = []
1011
- for i in range(0, num_frames - self.context_length + 1, self.context_stride):
1012
- window_start = i
1013
- window_end = min(num_frames, i + self.context_length)
1014
- frame_indices.append((window_start, window_end))
1015
- return frame_indices
1016
-
1017
- def _get_frame_weights(self, num_frames: int, weighting_scheme: str = "pyramid") -> List[float]:
1018
- if weighting_scheme == "flat":
1019
- weights = [1.0] * num_frames
1020
-
1021
- elif weighting_scheme == "pyramid":
1022
- if num_frames % 2 == 0:
1023
- # num_frames = 4 => [1, 2, 2, 1]
1024
- mid = num_frames // 2
1025
- weights = list(range(1, mid + 1))
1026
- weights = weights + weights[::-1]
1027
- else:
1028
- # num_frames = 5 => [1, 2, 3, 2, 1]
1029
- mid = (num_frames + 1) // 2
1030
- weights = list(range(1, mid))
1031
- weights = weights + [mid] + weights[::-1]
1032
-
1033
- elif weighting_scheme == "delayed_reverse_sawtooth":
1034
- if num_frames % 2 == 0:
1035
- # num_frames = 4 => [0.01, 2, 2, 1]
1036
- mid = num_frames // 2
1037
- weights = [0.01] * (mid - 1) + [mid]
1038
- weights = weights + list(range(mid, 0, -1))
1039
- else:
1040
- # num_frames = 5 => [0.01, 0.01, 3, 2, 1]
1041
- mid = (num_frames + 1) // 2
1042
- weights = [0.01] * mid
1043
- weights = weights + list(range(mid, 0, -1))
1044
- else:
1045
- raise ValueError(f"Unsupported value for weighting_scheme={weighting_scheme}")
1046
-
1047
- return weights
1048
-
1049
- def set_free_noise_properties(
1050
- self, context_length: int, context_stride: int, weighting_scheme: str = "pyramid"
1051
- ) -> None:
1052
- self.context_length = context_length
1053
- self.context_stride = context_stride
1054
- self.weighting_scheme = weighting_scheme
1055
-
1056
- def set_chunk_feed_forward(self, chunk_size: Optional[int], dim: int = 0) -> None:
1057
- # Sets chunk feed-forward
1058
- self._chunk_size = chunk_size
1059
- self._chunk_dim = dim
1060
-
1061
- def forward(
1062
- self,
1063
- hidden_states: torch.Tensor,
1064
- attention_mask: Optional[torch.Tensor] = None,
1065
- encoder_hidden_states: Optional[torch.Tensor] = None,
1066
- encoder_attention_mask: Optional[torch.Tensor] = None,
1067
- cross_attention_kwargs: Dict[str, Any] = None,
1068
- *args,
1069
- **kwargs,
1070
- ) -> torch.Tensor:
1071
- if cross_attention_kwargs is not None:
1072
- if cross_attention_kwargs.get("scale", None) is not None:
1073
- logger.warning("Passing `scale` to `cross_attention_kwargs` is deprecated. `scale` will be ignored.")
1074
-
1075
- cross_attention_kwargs = cross_attention_kwargs.copy() if cross_attention_kwargs is not None else {}
1076
-
1077
- # hidden_states: [B x H x W, F, C]
1078
- device = hidden_states.device
1079
- dtype = hidden_states.dtype
1080
-
1081
- num_frames = hidden_states.size(1)
1082
- frame_indices = self._get_frame_indices(num_frames)
1083
- frame_weights = self._get_frame_weights(self.context_length, self.weighting_scheme)
1084
- frame_weights = torch.tensor(frame_weights, device=device, dtype=dtype).unsqueeze(0).unsqueeze(-1)
1085
- is_last_frame_batch_complete = frame_indices[-1][1] == num_frames
1086
-
1087
- # Handle the out-of-bounds case when num_frames isn't perfectly divisible by context_length.
1088
- # For example, num_frames=25, context_length=16, context_stride=4 yields the ranges
1089
- # [(0, 16), (4, 20), (8, 24)], and we append (9, 25) so that the remaining frames are covered.
1090
- if not is_last_frame_batch_complete:
1091
- if num_frames < self.context_length:
1092
- raise ValueError(f"Expected {num_frames=} to be greater or equal than {self.context_length=}")
1093
- last_frame_batch_length = num_frames - frame_indices[-1][1]
1094
- frame_indices.append((num_frames - self.context_length, num_frames))
1095
-
1096
- num_times_accumulated = torch.zeros((1, num_frames, 1), device=device)
1097
- accumulated_values = torch.zeros_like(hidden_states)
1098
-
1099
- for i, (frame_start, frame_end) in enumerate(frame_indices):
1100
- # The slicing here guards against short trailing windows, i.e. cases where (frame_end - frame_start)
1101
- # is smaller than `context_length`, e.g. frame_indices=[(0, 16), (16, 20)] if the user provided a
1102
- # video with 19 frames, essentially a non-multiple of `context_length`.
1103
- weights = torch.ones_like(num_times_accumulated[:, frame_start:frame_end])
1104
- weights *= frame_weights
1105
-
1106
- hidden_states_chunk = hidden_states[:, frame_start:frame_end]
1107
-
1108
- # Notice that normalization is always applied before the real computation in the following blocks.
1109
- # 1. Self-Attention
1110
- norm_hidden_states = self.norm1(hidden_states_chunk)
1111
-
1112
- if self.pos_embed is not None:
1113
- norm_hidden_states = self.pos_embed(norm_hidden_states)
1114
-
1115
- attn_output = self.attn1(
1116
- norm_hidden_states,
1117
- encoder_hidden_states=encoder_hidden_states if self.only_cross_attention else None,
1118
- attention_mask=attention_mask,
1119
- **cross_attention_kwargs,
1120
- )
1121
-
1122
- hidden_states_chunk = attn_output + hidden_states_chunk
1123
- if hidden_states_chunk.ndim == 4:
1124
- hidden_states_chunk = hidden_states_chunk.squeeze(1)
1125
-
1126
- # 2. Cross-Attention
1127
- if self.attn2 is not None:
1128
- norm_hidden_states = self.norm2(hidden_states_chunk)
1129
-
1130
- if self.pos_embed is not None and self.norm_type != "ada_norm_single":
1131
- norm_hidden_states = self.pos_embed(norm_hidden_states)
1132
-
1133
- attn_output = self.attn2(
1134
- norm_hidden_states,
1135
- encoder_hidden_states=encoder_hidden_states,
1136
- attention_mask=encoder_attention_mask,
1137
- **cross_attention_kwargs,
1138
- )
1139
- hidden_states_chunk = attn_output + hidden_states_chunk
1140
-
1141
- if i == len(frame_indices) - 1 and not is_last_frame_batch_complete:
1142
- accumulated_values[:, -last_frame_batch_length:] += (
1143
- hidden_states_chunk[:, -last_frame_batch_length:] * weights[:, -last_frame_batch_length:]
1144
- )
1145
- num_times_accumulated[:, -last_frame_batch_length:] += weights[:, -last_frame_batch_length:]
1146
- else:
1147
- accumulated_values[:, frame_start:frame_end] += hidden_states_chunk * weights
1148
- num_times_accumulated[:, frame_start:frame_end] += weights
1149
-
1150
- # TODO(aryan): Maybe this could be done in a better way.
1151
- #
1152
- # Previously, this was:
1153
- # hidden_states = torch.where(
1154
- # num_times_accumulated > 0, accumulated_values / num_times_accumulated, accumulated_values
1155
- # )
1156
- #
1157
- # The reasoning for the change here is `torch.where` became a bottleneck at some point when golfing memory
1158
- # spikes. It is particularly noticeable when the number of frames is high. My understanding is that this comes
1159
- # from tensors being copied - which is why we resort to splitting and concatenating here. I've not particularly
1160
- # looked into this deeply because other memory optimizations led to more pronounced reductions.
1161
- hidden_states = torch.cat(
1162
- [
1163
- torch.where(num_times_split > 0, accumulated_split / num_times_split, accumulated_split)
1164
- for accumulated_split, num_times_split in zip(
1165
- accumulated_values.split(self.context_length, dim=1),
1166
- num_times_accumulated.split(self.context_length, dim=1),
1167
- )
1168
- ],
1169
- dim=1,
1170
- ).to(dtype)
1171
-
1172
- # 3. Feed-forward
1173
- norm_hidden_states = self.norm3(hidden_states)
1174
-
1175
- if self._chunk_size is not None:
1176
- ff_output = _chunked_feed_forward(self.ff, norm_hidden_states, self._chunk_dim, self._chunk_size)
1177
- else:
1178
- ff_output = self.ff(norm_hidden_states)
1179
-
1180
- hidden_states = ff_output + hidden_states
1181
- if hidden_states.ndim == 4:
1182
- hidden_states = hidden_states.squeeze(1)
1183
-
1184
- return hidden_states
1185
-
1186
-
1187
- class FeedForward(nn.Module):
1188
- r"""
1189
- A feed-forward layer.
1190
-
1191
- Parameters:
1192
- dim (`int`): The number of channels in the input.
1193
- dim_out (`int`, *optional*): The number of channels in the output. If not given, defaults to `dim`.
1194
- mult (`int`, *optional*, defaults to 4): The multiplier to use for the hidden dimension.
1195
- dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use.
1196
- activation_fn (`str`, *optional*, defaults to `"geglu"`): Activation function to be used in feed-forward.
1197
- final_dropout (`bool` *optional*, defaults to False): Apply a final dropout.
1198
- bias (`bool`, defaults to True): Whether to use a bias in the linear layer.
1199
- """
1200
-
1201
- def __init__(
1202
- self,
1203
- dim: int,
1204
- dim_out: Optional[int] = None,
1205
- mult: int = 4,
1206
- dropout: float = 0.0,
1207
- activation_fn: str = "geglu",
1208
- final_dropout: bool = False,
1209
- inner_dim=None,
1210
- bias: bool = True,
1211
- ):
1212
- super().__init__()
1213
- if inner_dim is None:
1214
- inner_dim = int(dim * mult)
1215
- dim_out = dim_out if dim_out is not None else dim
1216
-
1217
- if activation_fn == "gelu":
1218
- act_fn = GELU(dim, inner_dim, bias=bias)
1219
- if activation_fn == "gelu-approximate":
1220
- act_fn = GELU(dim, inner_dim, approximate="tanh", bias=bias)
1221
- elif activation_fn == "geglu":
1222
- act_fn = GEGLU(dim, inner_dim, bias=bias)
1223
- elif activation_fn == "geglu-approximate":
1224
- act_fn = ApproximateGELU(dim, inner_dim, bias=bias)
1225
- elif activation_fn == "swiglu":
1226
- act_fn = SwiGLU(dim, inner_dim, bias=bias)
1227
-
1228
- self.net = nn.ModuleList([])
1229
- # project in
1230
- self.net.append(act_fn)
1231
- # project dropout
1232
- self.net.append(nn.Dropout(dropout))
1233
- # project out
1234
- self.net.append(nn.Linear(inner_dim, dim_out, bias=bias))
1235
- # FF as used in Vision Transformer, MLP-Mixer, etc. have a final dropout
1236
- if final_dropout:
1237
- self.net.append(nn.Dropout(dropout))
1238
-
1239
- def forward(self, hidden_states: torch.Tensor, *args, **kwargs) -> torch.Tensor:
1240
- if len(args) > 0 or kwargs.get("scale", None) is not None:
1241
- deprecation_message = "The `scale` argument is deprecated and will be ignored. Please remove it, as passing it will raise an error in the future. `scale` should directly be passed while calling the underlying pipeline component i.e., via `cross_attention_kwargs`."
1242
- deprecate("scale", "1.0.0", deprecation_message)
1243
- for module in self.net:
1244
- hidden_states = module(hidden_states)
1245
- return hidden_states
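
For reference, the sketch below is a minimal, self-contained illustration of the chunked feed-forward pattern that the deleted `_chunked_feed_forward` helper and `FeedForward` module implemented; the toy shapes and the simplified `GEGLU` re-implementation are assumptions for illustration, not the diffusers classes.

```python
import torch
import torch.nn.functional as F
from torch import nn


class GEGLU(nn.Module):
    """Toy gated-GELU: project to 2 * dim_out, gate one half with GELU of the other."""

    def __init__(self, dim_in: int, dim_out: int):
        super().__init__()
        self.proj = nn.Linear(dim_in, dim_out * 2)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        hidden, gate = self.proj(x).chunk(2, dim=-1)
        return hidden * F.gelu(gate)


dim, mult = 64, 4
# Mirrors FeedForward.net above: gated projection in, dropout, projection out.
ff = nn.Sequential(GEGLU(dim, dim * mult), nn.Dropout(0.0), nn.Linear(dim * mult, dim))

hidden_states = torch.randn(2, 128, dim)  # (batch, seq_len, dim), toy values
chunk_size, chunk_dim = 32, 1             # seq_len must be divisible by chunk_size
num_chunks = hidden_states.shape[chunk_dim] // chunk_size

# Same result as ff(hidden_states), but only one chunk at a time occupies the
# widened inner dimension, bounding peak activation memory.
out = torch.cat(
    [ff(chunk) for chunk in hidden_states.chunk(num_chunks, dim=chunk_dim)],
    dim=chunk_dim,
)
assert out.shape == hidden_states.shape
```

In the blocks above, `set_chunk_feed_forward(chunk_size, dim)` only stores these two values; the forward passes then route through the chunked path whenever `_chunk_size` is not `None`.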
models/resampler.py DELETED
@@ -1,304 +0,0 @@
1
- # modified from https://github.com/mlfoundations/open_flamingo/blob/main/open_flamingo/src/helpers.py
2
- import math
3
-
4
- import torch
5
- import torch.nn as nn
6
-
7
- from diffusers.models.embeddings import Timesteps, TimestepEmbedding
8
-
9
- def get_timestep_embedding(
10
- timesteps: torch.Tensor,
11
- embedding_dim: int,
12
- flip_sin_to_cos: bool = False,
13
- downscale_freq_shift: float = 1,
14
- scale: float = 1,
15
- max_period: int = 10000,
16
- ):
17
- """
18
- This matches the implementation in Denoising Diffusion Probabilistic Models: Create sinusoidal timestep embeddings.
19
-
20
- :param timesteps: a 1-D Tensor of N indices, one per batch element. These may be fractional.
21
- :param embedding_dim: the dimension of the output.
22
- :param max_period: controls the minimum frequency of the embeddings.
23
- :return: an [N x dim] Tensor of positional embeddings.
24
- """
25
- assert len(timesteps.shape) == 1, "Timesteps should be a 1d-array"
26
-
27
- half_dim = embedding_dim // 2
28
- exponent = -math.log(max_period) * torch.arange(
29
- start=0, end=half_dim, dtype=torch.float32, device=timesteps.device
30
- )
31
- exponent = exponent / (half_dim - downscale_freq_shift)
32
-
33
- emb = torch.exp(exponent)
34
- emb = timesteps[:, None].float() * emb[None, :]
35
-
36
- # scale embeddings
37
- emb = scale * emb
38
-
39
- # concat sine and cosine embeddings
40
- emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=-1)
41
-
42
- # flip sine and cosine embeddings
43
- if flip_sin_to_cos:
44
- emb = torch.cat([emb[:, half_dim:], emb[:, :half_dim]], dim=-1)
45
-
46
- # zero pad
47
- if embedding_dim % 2 == 1:
48
- emb = torch.nn.functional.pad(emb, (0, 1, 0, 0))
49
- return emb
50
-
51
-
52
- # FFN
53
- def FeedForward(dim, mult=4):
54
- inner_dim = int(dim * mult)
55
- return nn.Sequential(
56
- nn.LayerNorm(dim),
57
- nn.Linear(dim, inner_dim, bias=False),
58
- nn.GELU(),
59
- nn.Linear(inner_dim, dim, bias=False),
60
- )
61
-
62
-
63
- def reshape_tensor(x, heads):
64
- bs, length, width = x.shape
65
- # (bs, length, width) --> (bs, length, n_heads, dim_per_head)
66
- x = x.view(bs, length, heads, -1)
67
- # (bs, length, n_heads, dim_per_head) --> (bs, n_heads, length, dim_per_head)
68
- x = x.transpose(1, 2)
69
- # shape is already (bs, n_heads, length, dim_per_head); this reshape is effectively a no-op
70
- x = x.reshape(bs, heads, length, -1)
71
- return x
72
-
73
-
74
- class PerceiverAttention(nn.Module):
75
- def __init__(self, *, dim, dim_head=64, heads=8):
76
- super().__init__()
77
- self.scale = dim_head**-0.5
78
- self.dim_head = dim_head
79
- self.heads = heads
80
- inner_dim = dim_head * heads
81
-
82
- self.norm1 = nn.LayerNorm(dim)
83
- self.norm2 = nn.LayerNorm(dim)
84
-
85
- self.to_q = nn.Linear(dim, inner_dim, bias=False)
86
- self.to_kv = nn.Linear(dim, inner_dim * 2, bias=False)
87
- self.to_out = nn.Linear(inner_dim, dim, bias=False)
88
-
89
-
90
- def forward(self, x, latents, shift=None, scale=None):
91
- """
92
- Args:
93
- x (torch.Tensor): image features
94
- shape (b, n1, D)
95
- latents (torch.Tensor): latent features
96
- shape (b, n2, D)
97
- """
98
- x = self.norm1(x)
99
- latents = self.norm2(latents)
100
-
101
- if shift is not None and scale is not None:
102
- latents = latents * (1 + scale.unsqueeze(1)) + shift.unsqueeze(1)
103
-
104
- b, l, _ = latents.shape
105
-
106
- q = self.to_q(latents)
107
- kv_input = torch.cat((x, latents), dim=-2)
108
- k, v = self.to_kv(kv_input).chunk(2, dim=-1)
109
-
110
- q = reshape_tensor(q, self.heads)
111
- k = reshape_tensor(k, self.heads)
112
- v = reshape_tensor(v, self.heads)
113
-
114
- # attention
115
- attn_scale = 1 / math.sqrt(math.sqrt(self.dim_head))  # renamed to avoid shadowing the `scale` argument
116
- weight = (q * attn_scale) @ (k * attn_scale).transpose(-2, -1)  # more stable with fp16 than dividing afterwards
117
- weight = torch.softmax(weight.float(), dim=-1).type(weight.dtype)
118
- out = weight @ v
119
-
120
- out = out.permute(0, 2, 1, 3).reshape(b, l, -1)
121
-
122
- return self.to_out(out)
123
-
124
-
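A hedged usage sketch for `PerceiverAttention` (editor illustration; the shapes are assumptions): the latents attend over the concatenation of the image features and themselves, so the output keeps the latent sequence length.

```python
import torch

attn = PerceiverAttention(dim=1024, dim_head=64, heads=16)  # inner_dim = 64 * 16 = 1024
x = torch.randn(2, 257, 1024)        # image features (b, n1, D), e.g. CLIP patch tokens
latents = torch.randn(2, 16, 1024)   # latent queries (b, n2, D)
out = attn(x, latents)
print(out.shape)  # torch.Size([2, 16, 1024]) -- one output per latent query
```
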
125
- class Resampler(nn.Module):
126
- def __init__(
127
- self,
128
- dim=1024,
129
- depth=8,
130
- dim_head=64,
131
- heads=16,
132
- num_queries=8,
133
- embedding_dim=768,
134
- output_dim=1024,
135
- ff_mult=4,
136
- *args,
137
- **kwargs,
138
- ):
139
- super().__init__()
140
-
141
- self.latents = nn.Parameter(torch.randn(1, num_queries, dim) / dim**0.5)
142
-
143
- self.proj_in = nn.Linear(embedding_dim, dim)
144
-
145
- self.proj_out = nn.Linear(dim, output_dim)
146
- self.norm_out = nn.LayerNorm(output_dim)
147
-
148
- self.layers = nn.ModuleList([])
149
- for _ in range(depth):
150
- self.layers.append(
151
- nn.ModuleList(
152
- [
153
- PerceiverAttention(dim=dim, dim_head=dim_head, heads=heads),
154
- FeedForward(dim=dim, mult=ff_mult),
155
- ]
156
- )
157
- )
158
-
159
- def forward(self, x):
160
-
161
- latents = self.latents.repeat(x.size(0), 1, 1)
162
-
163
- x = self.proj_in(x)
164
-
165
- for attn, ff in self.layers:
166
- latents = attn(x, latents) + latents
167
- latents = ff(latents) + latents
168
-
169
- latents = self.proj_out(latents)
170
- return self.norm_out(latents)
171
-
172
-
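A usage sketch for `Resampler` under assumed CLIP-like dimensions (editor-added, not from the source): it compresses a variable-length feature sequence into `num_queries` tokens at `output_dim` width.

```python
import torch

resampler = Resampler(dim=1024, depth=4, dim_head=64, heads=16,
                      num_queries=8, embedding_dim=768, output_dim=2048)
image_embeds = torch.randn(2, 257, 768)  # e.g. CLIP vision hidden states
tokens = resampler(image_embeds)
print(tokens.shape)  # torch.Size([2, 8, 2048])
```
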
173
- class TimeResampler(nn.Module):
174
- def __init__(
175
- self,
176
- dim=1024,
177
- depth=8,
178
- dim_head=64,
179
- heads=16,
180
- num_queries=8,
181
- embedding_dim=768,
182
- output_dim=1024,
183
- ff_mult=4,
184
- timestep_in_dim=320,
185
- timestep_flip_sin_to_cos=True,
186
- timestep_freq_shift=0,
187
- ):
188
- super().__init__()
189
-
190
- self.latents = nn.Parameter(torch.randn(1, num_queries, dim) / dim**0.5)
191
-
192
- self.proj_in = nn.Linear(embedding_dim, dim)
193
-
194
- self.proj_out = nn.Linear(dim, output_dim)
195
- self.norm_out = nn.LayerNorm(output_dim)
196
-
197
- self.layers = nn.ModuleList([])
198
- for _ in range(depth):
199
- self.layers.append(
200
- nn.ModuleList(
201
- [
202
- # msa
203
- PerceiverAttention(dim=dim, dim_head=dim_head, heads=heads),
204
- # ff
205
- FeedForward(dim=dim, mult=ff_mult),
206
- # adaLN
207
- nn.Sequential(nn.SiLU(), nn.Linear(dim, 4 * dim, bias=True))
208
- ]
209
- )
210
- )
211
-
212
- # time
213
- self.time_proj = Timesteps(timestep_in_dim, timestep_flip_sin_to_cos, timestep_freq_shift)
214
- self.time_embedding = TimestepEmbedding(timestep_in_dim, dim, act_fn="silu")
215
-
216
- # adaLN
217
- # self.adaLN_modulation = nn.Sequential(
218
- # nn.SiLU(),
219
- # nn.Linear(timestep_out_dim, 6 * timestep_out_dim, bias=True)
220
- # )
221
-
222
-
223
- def forward(self, x, timestep, need_temb=False):
224
- timestep_emb = self.embedding_time(x, timestep) # bs, dim
225
-
226
- latents = self.latents.repeat(x.size(0), 1, 1)
227
-
228
- x = self.proj_in(x)
229
- x = x + timestep_emb[:, None]
230
-
231
- for attn, ff, adaLN_modulation in self.layers:
232
- shift_msa, scale_msa, shift_mlp, scale_mlp = adaLN_modulation(timestep_emb).chunk(4, dim=1)
233
- latents = attn(x, latents, shift_msa, scale_msa) + latents
234
-
235
- res = latents
236
- for idx_ff in range(len(ff)):
237
- layer_ff = ff[idx_ff]
238
- latents = layer_ff(latents)
239
- if idx_ff == 0 and isinstance(layer_ff, nn.LayerNorm): # adaLN
240
- latents = latents * (1 + scale_mlp.unsqueeze(1)) + shift_mlp.unsqueeze(1)
241
- latents = latents + res
242
-
243
- # latents = ff(latents) + latents
244
-
245
- latents = self.proj_out(latents)
246
- latents = self.norm_out(latents)
247
-
248
- if need_temb:
249
- return latents, timestep_emb
250
- else:
251
- return latents
252
-
253
-
254
-
255
- def embedding_time(self, sample, timestep):
256
-
257
- # 1. time
258
- timesteps = timestep
259
- if not torch.is_tensor(timesteps):
260
- # TODO: this requires sync between CPU and GPU. So try to pass timesteps as tensors if you can
261
- # This would be a good case for the `match` statement (Python 3.10+)
262
- is_mps = sample.device.type == "mps"
263
- if isinstance(timestep, float):
264
- dtype = torch.float32 if is_mps else torch.float64
265
- else:
266
- dtype = torch.int32 if is_mps else torch.int64
267
- timesteps = torch.tensor([timesteps], dtype=dtype, device=sample.device)
268
- elif len(timesteps.shape) == 0:
269
- timesteps = timesteps[None].to(sample.device)
270
-
271
- # broadcast to batch dimension in a way that's compatible with ONNX/Core ML
272
- timesteps = timesteps.expand(sample.shape[0])
273
-
274
- t_emb = self.time_proj(timesteps)
275
-
276
- # timesteps does not contain any weights and will always return f32 tensors
277
- # but time_embedding might actually be running in fp16. so we need to cast here.
278
- # there might be better ways to encapsulate this.
279
- t_emb = t_emb.to(dtype=sample.dtype)
280
-
281
- emb = self.time_embedding(t_emb, None)
282
- return emb
283
-
284
-
285
-
286
-
287
-
288
- if __name__ == '__main__':
289
- model = TimeResampler(
290
- dim=1280,
291
- depth=4,
292
- dim_head=64,
293
- heads=20,
294
- num_queries=16,
295
- embedding_dim=512,
296
- output_dim=2048,
297
- ff_mult=4,
298
- timestep_in_dim=320,
299
- timestep_flip_sin_to_cos=True,
300
- timestep_freq_shift=0,
301
- # in_channel_extra_emb=2048,  # removed: TimeResampler.__init__ does not accept this kwarg (TypeError)
302
- )
303
-
304
-
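Extending the `__main__` block above, a hedged smoke test of the forward pass (editor-added; the input width 512 matches `embedding_dim` in the constructor call):

```python
import torch

x = torch.randn(2, 100, 512)         # conditioning tokens (b, n, embedding_dim)
timestep = torch.tensor([999, 999])  # one denoising step per batch element
tokens, temb = model(x, timestep, need_temb=True)
print(tokens.shape)  # torch.Size([2, 16, 2048]) -- num_queries x output_dim
print(temb.shape)    # torch.Size([2, 1280])     -- the internal timestep embedding
```
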
models/transformer_sd3.py DELETED
@@ -1,375 +0,0 @@
1
- # Copyright 2024 Stability AI, The HuggingFace Team and The InstantX Team. All rights reserved.
2
- #
3
- # Licensed under the Apache License, Version 2.0 (the "License");
4
- # you may not use this file except in compliance with the License.
5
- # You may obtain a copy of the License at
6
- #
7
- # http://www.apache.org/licenses/LICENSE-2.0
8
- #
9
- # Unless required by applicable law or agreed to in writing, software
10
- # distributed under the License is distributed on an "AS IS" BASIS,
11
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
- # See the License for the specific language governing permissions and
13
- # limitations under the License.
14
-
15
-
16
- from typing import Any, Dict, List, Optional, Tuple, Union
17
-
18
- import torch
19
- import torch.nn as nn
20
-
21
- from diffusers.configuration_utils import ConfigMixin, register_to_config
22
- from diffusers.loaders import FromOriginalModelMixin, PeftAdapterMixin
23
- from .attention import JointTransformerBlock
24
- from diffusers.models.attention_processor import Attention, AttentionProcessor, FusedJointAttnProcessor2_0
25
- from diffusers.models.modeling_utils import ModelMixin
26
- from diffusers.models.normalization import AdaLayerNormContinuous
27
- from diffusers.utils import USE_PEFT_BACKEND, is_torch_version, logging, scale_lora_layers, unscale_lora_layers
28
- from diffusers.models.embeddings import CombinedTimestepTextProjEmbeddings, PatchEmbed
29
- from diffusers.models.modeling_outputs import Transformer2DModelOutput
30
-
31
-
32
- logger = logging.get_logger(__name__) # pylint: disable=invalid-name
33
-
34
-
35
- class SD3Transformer2DModel(ModelMixin, ConfigMixin, PeftAdapterMixin, FromOriginalModelMixin):
36
- """
37
- The Transformer model introduced in Stable Diffusion 3.
38
-
39
- Reference: https://arxiv.org/abs/2403.03206
40
-
41
- Parameters:
42
- sample_size (`int`): The width of the latent images. This is fixed during training since
43
- it is used to learn a number of position embeddings.
44
- patch_size (`int`): Patch size to turn the input data into small patches.
45
- in_channels (`int`, *optional*, defaults to 16): The number of channels in the input.
46
- num_layers (`int`, *optional*, defaults to 18): The number of layers of Transformer blocks to use.
47
- attention_head_dim (`int`, *optional*, defaults to 64): The number of channels in each head.
48
- num_attention_heads (`int`, *optional*, defaults to 18): The number of heads to use for multi-head attention.
49
- cross_attention_dim (`int`, *optional*): The number of `encoder_hidden_states` dimensions to use.
50
- caption_projection_dim (`int`): Number of dimensions to use when projecting the `encoder_hidden_states`.
51
- pooled_projection_dim (`int`): Number of dimensions to use when projecting the `pooled_projections`.
52
- out_channels (`int`, defaults to 16): Number of output channels.
53
-
54
- """
55
-
56
- _supports_gradient_checkpointing = True
57
-
58
- @register_to_config
59
- def __init__(
60
- self,
61
- sample_size: int = 128,
62
- patch_size: int = 2,
63
- in_channels: int = 16,
64
- num_layers: int = 18,
65
- attention_head_dim: int = 64,
66
- num_attention_heads: int = 18,
67
- joint_attention_dim: int = 4096,
68
- caption_projection_dim: int = 1152,
69
- pooled_projection_dim: int = 2048,
70
- out_channels: int = 16,
71
- pos_embed_max_size: int = 96,
72
- dual_attention_layers: Tuple[
73
- int, ...
74
- ] = (), # () for sd3.0; (0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12) for sd3.5
75
- qk_norm: Optional[str] = None,
76
- ):
77
- super().__init__()
78
- default_out_channels = in_channels
79
- self.out_channels = out_channels if out_channels is not None else default_out_channels
80
- self.inner_dim = self.config.num_attention_heads * self.config.attention_head_dim
81
-
82
- self.pos_embed = PatchEmbed(
83
- height=self.config.sample_size,
84
- width=self.config.sample_size,
85
- patch_size=self.config.patch_size,
86
- in_channels=self.config.in_channels,
87
- embed_dim=self.inner_dim,
88
- pos_embed_max_size=pos_embed_max_size,
89
- )
90
- self.time_text_embed = CombinedTimestepTextProjEmbeddings(
91
- embedding_dim=self.inner_dim, pooled_projection_dim=self.config.pooled_projection_dim
92
- )
93
- self.context_embedder = nn.Linear(self.config.joint_attention_dim, self.config.caption_projection_dim)
94
-
95
- # `attention_head_dim` is doubled to account for the mixing.
96
- # It needs to crafted when we get the actual checkpoints.
97
- self.transformer_blocks = nn.ModuleList(
98
- [
99
- JointTransformerBlock(
100
- dim=self.inner_dim,
101
- num_attention_heads=self.config.num_attention_heads,
102
- attention_head_dim=self.config.attention_head_dim,
103
- context_pre_only=i == num_layers - 1,
104
- qk_norm=qk_norm,
105
- use_dual_attention=i in dual_attention_layers,
106
- )
107
- for i in range(self.config.num_layers)
108
- ]
109
- )
110
-
111
- self.norm_out = AdaLayerNormContinuous(self.inner_dim, self.inner_dim, elementwise_affine=False, eps=1e-6)
112
- self.proj_out = nn.Linear(self.inner_dim, patch_size * patch_size * self.out_channels, bias=True)
113
-
114
- self.gradient_checkpointing = False
115
-
116
- # Copied from diffusers.models.unets.unet_3d_condition.UNet3DConditionModel.enable_forward_chunking
117
- def enable_forward_chunking(self, chunk_size: Optional[int] = None, dim: int = 0) -> None:
118
- """
119
- Sets the attention processor to use [feed forward
120
- chunking](https://huggingface.co/blog/reformer#2-chunked-feed-forward-layers).
121
-
122
- Parameters:
123
- chunk_size (`int`, *optional*):
124
- The chunk size of the feed-forward layers. If not specified, will run feed-forward layer individually
125
- over each tensor of dim=`dim`.
126
- dim (`int`, *optional*, defaults to `0`):
127
- The dimension over which the feed-forward computation should be chunked. Choose between dim=0 (batch)
128
- or dim=1 (sequence length).
129
- """
130
- if dim not in [0, 1]:
131
- raise ValueError(f"Make sure to set `dim` to either 0 or 1, not {dim}")
132
-
133
- # By default chunk size is 1
134
- chunk_size = chunk_size or 1
135
-
136
- def fn_recursive_feed_forward(module: torch.nn.Module, chunk_size: int, dim: int):
137
- if hasattr(module, "set_chunk_feed_forward"):
138
- module.set_chunk_feed_forward(chunk_size=chunk_size, dim=dim)
139
-
140
- for child in module.children():
141
- fn_recursive_feed_forward(child, chunk_size, dim)
142
-
143
- for module in self.children():
144
- fn_recursive_feed_forward(module, chunk_size, dim)
145
-
146
- # Copied from diffusers.models.unets.unet_3d_condition.UNet3DConditionModel.disable_forward_chunking
147
- def disable_forward_chunking(self):
148
- def fn_recursive_feed_forward(module: torch.nn.Module, chunk_size: int, dim: int):
149
- if hasattr(module, "set_chunk_feed_forward"):
150
- module.set_chunk_feed_forward(chunk_size=chunk_size, dim=dim)
151
-
152
- for child in module.children():
153
- fn_recursive_feed_forward(child, chunk_size, dim)
154
-
155
- for module in self.children():
156
- fn_recursive_feed_forward(module, None, 0)
157
-
158
- @property
159
- # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.attn_processors
160
- def attn_processors(self) -> Dict[str, AttentionProcessor]:
161
- r"""
162
- Returns:
163
- `dict` of attention processors: A dictionary containing all attention processors used in the model with
164
- indexed by its weight name.
165
- """
166
- # set recursively
167
- processors = {}
168
-
169
- def fn_recursive_add_processors(name: str, module: torch.nn.Module, processors: Dict[str, AttentionProcessor]):
170
- if hasattr(module, "get_processor"):
171
- processors[f"{name}.processor"] = module.get_processor()
172
-
173
- for sub_name, child in module.named_children():
174
- fn_recursive_add_processors(f"{name}.{sub_name}", child, processors)
175
-
176
- return processors
177
-
178
- for name, module in self.named_children():
179
- fn_recursive_add_processors(name, module, processors)
180
-
181
- return processors
182
-
183
- # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.set_attn_processor
184
- def set_attn_processor(self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]]):
185
- r"""
186
- Sets the attention processor to use to compute attention.
187
-
188
- Parameters:
189
- processor (`dict` of `AttentionProcessor` or only `AttentionProcessor`):
190
- The instantiated processor class or a dictionary of processor classes that will be set as the processor
191
- for **all** `Attention` layers.
192
-
193
- If `processor` is a dict, the key needs to define the path to the corresponding cross attention
194
- processor. This is strongly recommended when setting trainable attention processors.
195
-
196
- """
197
- count = len(self.attn_processors.keys())
198
-
199
- if isinstance(processor, dict) and len(processor) != count:
200
- raise ValueError(
201
- f"A dict of processors was passed, but the number of processors {len(processor)} does not match the"
202
- f" number of attention layers: {count}. Please make sure to pass {count} processor classes."
203
- )
204
-
205
- def fn_recursive_attn_processor(name: str, module: torch.nn.Module, processor):
206
- if hasattr(module, "set_processor"):
207
- if not isinstance(processor, dict):
208
- module.set_processor(processor)
209
- else:
210
- module.set_processor(processor.pop(f"{name}.processor"))
211
-
212
- for sub_name, child in module.named_children():
213
- fn_recursive_attn_processor(f"{name}.{sub_name}", child, processor)
214
-
215
- for name, module in self.named_children():
216
- fn_recursive_attn_processor(name, module, processor)
217
-
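The two APIs round-trip: the dict returned by `attn_processors` uses exactly the `"{name}.processor"` keys that `set_attn_processor` pops. A sketch, assuming `model` is an instance of this class:

```python
procs = model.attn_processors    # e.g. {"transformer_blocks.0.attn.processor": <...>, ...}
model.set_attn_processor(procs)  # a dict must supply one processor per attention layer
```
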
218
- # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.fuse_qkv_projections with FusedAttnProcessor2_0->FusedJointAttnProcessor2_0
219
- def fuse_qkv_projections(self):
220
- """
221
- Enables fused QKV projections. For self-attention modules, all projection matrices (i.e., query, key, value)
222
- are fused. For cross-attention modules, key and value projection matrices are fused.
223
-
224
- <Tip warning={true}>
225
-
226
- This API is 🧪 experimental.
227
-
228
- </Tip>
229
- """
230
- self.original_attn_processors = None
231
-
232
- for _, attn_processor in self.attn_processors.items():
233
- if "Added" in str(attn_processor.__class__.__name__):
234
- raise ValueError("`fuse_qkv_projections()` is not supported for models having added KV projections.")
235
-
236
- self.original_attn_processors = self.attn_processors
237
-
238
- for module in self.modules():
239
- if isinstance(module, Attention):
240
- module.fuse_projections(fuse=True)
241
-
242
- self.set_attn_processor(FusedJointAttnProcessor2_0())
243
-
244
- # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.unfuse_qkv_projections
245
- def unfuse_qkv_projections(self):
246
- """Disables the fused QKV projection if enabled.
247
-
248
- <Tip warning={true}>
249
-
250
- This API is 🧪 experimental.
251
-
252
- </Tip>
253
-
254
- """
255
- if self.original_attn_processors is not None:
256
- self.set_attn_processor(self.original_attn_processors)
257
-
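A sketch of the fuse/unfuse pair (experimental per the docstrings above; `model` is again assumed to be an instance of this class):

```python
model.fuse_qkv_projections()    # fuses q/k/v weights and swaps in FusedJointAttnProcessor2_0
# ... inference ...
model.unfuse_qkv_projections()  # restores the processors captured before fusing
```
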
258
- def _set_gradient_checkpointing(self, module, value=False):
259
- if hasattr(module, "gradient_checkpointing"):
260
- module.gradient_checkpointing = value
261
-
262
- def forward(
263
- self,
264
- hidden_states: torch.FloatTensor,
265
- encoder_hidden_states: torch.FloatTensor = None,
266
- pooled_projections: torch.FloatTensor = None,
267
- timestep: torch.LongTensor = None,
268
- block_controlnet_hidden_states: List = None,
269
- joint_attention_kwargs: Optional[Dict[str, Any]] = None,
270
- return_dict: bool = True,
271
- ) -> Union[torch.FloatTensor, Transformer2DModelOutput]:
272
- """
273
- The [`SD3Transformer2DModel`] forward method.
274
-
275
- Args:
276
- hidden_states (`torch.FloatTensor` of shape `(batch size, channel, height, width)`):
277
- Input `hidden_states`.
278
- encoder_hidden_states (`torch.FloatTensor` of shape `(batch size, sequence_len, embed_dims)`):
279
- Conditional embeddings (embeddings computed from the input conditions such as prompts) to use.
280
- pooled_projections (`torch.FloatTensor` of shape `(batch_size, projection_dim)`): Embeddings projected
281
- from the embeddings of input conditions.
282
- timestep ( `torch.LongTensor`):
283
- Used to indicate denoising step.
284
- block_controlnet_hidden_states: (`list` of `torch.Tensor`):
285
- A list of tensors that if specified are added to the residuals of transformer blocks.
286
- joint_attention_kwargs (`dict`, *optional*):
287
- A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
288
- `self.processor` in
289
- [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
290
- return_dict (`bool`, *optional*, defaults to `True`):
291
- Whether or not to return a [`~models.transformer_2d.Transformer2DModelOutput`] instead of a plain
292
- tuple.
293
-
294
- Returns:
295
- If `return_dict` is True, an [`~models.transformer_2d.Transformer2DModelOutput`] is returned, otherwise a
296
- `tuple` where the first element is the sample tensor.
297
- """
298
- if joint_attention_kwargs is not None:
299
- joint_attention_kwargs = joint_attention_kwargs.copy()
300
- lora_scale = joint_attention_kwargs.pop("scale", 1.0)
301
- else:
302
- lora_scale = 1.0
303
-
304
- if USE_PEFT_BACKEND:
305
- # weight the lora layers by setting `lora_scale` for each PEFT layer
306
- scale_lora_layers(self, lora_scale)
307
- else:
308
- if joint_attention_kwargs is not None and joint_attention_kwargs.get("scale", None) is not None:
309
- logger.warning(
310
- "Passing `scale` via `joint_attention_kwargs` when not using the PEFT backend is ineffective."
311
- )
312
-
313
- height, width = hidden_states.shape[-2:]
314
-
315
- hidden_states = self.pos_embed(hidden_states) # takes care of adding positional embeddings too.
316
- temb = self.time_text_embed(timestep, pooled_projections)
317
- encoder_hidden_states = self.context_embedder(encoder_hidden_states)
318
-
319
- for index_block, block in enumerate(self.transformer_blocks):
320
- if self.training and self.gradient_checkpointing:
321
-
322
- def create_custom_forward(module, return_dict=None):
323
- def custom_forward(*inputs):
324
- if return_dict is not None:
325
- return module(*inputs, return_dict=return_dict)
326
- else:
327
- return module(*inputs)
328
-
329
- return custom_forward
330
-
331
- ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
332
- encoder_hidden_states, hidden_states = torch.utils.checkpoint.checkpoint(
333
- create_custom_forward(block),
334
- hidden_states,
335
- encoder_hidden_states,
336
- temb,
337
- joint_attention_kwargs,
338
- **ckpt_kwargs,
339
- )
340
-
341
- else:
342
- encoder_hidden_states, hidden_states = block(
343
- hidden_states=hidden_states, encoder_hidden_states=encoder_hidden_states, temb=temb,
344
- joint_attention_kwargs=joint_attention_kwargs,
345
- )
346
-
347
- # controlnet residual
348
- if block_controlnet_hidden_states is not None and block.context_pre_only is False:
349
- interval_control = len(self.transformer_blocks) // len(block_controlnet_hidden_states)
350
- hidden_states = hidden_states + block_controlnet_hidden_states[index_block // interval_control]
351
-
352
- hidden_states = self.norm_out(hidden_states, temb)
353
- hidden_states = self.proj_out(hidden_states)
354
-
355
- # unpatchify
356
- patch_size = self.config.patch_size
357
- height = height // patch_size
358
- width = width // patch_size
359
-
360
- hidden_states = hidden_states.reshape(
361
- shape=(hidden_states.shape[0], height, width, patch_size, patch_size, self.out_channels)
362
- )
363
- hidden_states = torch.einsum("nhwpqc->nchpwq", hidden_states)
364
- output = hidden_states.reshape(
365
- shape=(hidden_states.shape[0], self.out_channels, height * patch_size, width * patch_size)
366
- )
367
-
368
- if USE_PEFT_BACKEND:
369
- # remove `lora_scale` from each PEFT layer
370
- unscale_lora_layers(self, lora_scale)
371
-
372
- if not return_dict:
373
- return (output,)
374
-
375
- return Transformer2DModelOutput(sample=output)
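Finally, a hedged end-to-end smoke test with deliberately small, assumed dimensions (editor-added; it requires the `JointTransformerBlock` from the deleted `models/attention.py`, and `caption_projection_dim` is set equal to `inner_dim = num_attention_heads * attention_head_dim` so the context stream matches the block width):

```python
import torch

model = SD3Transformer2DModel(
    sample_size=64, num_layers=2, num_attention_heads=4,
    attention_head_dim=64, caption_projection_dim=256,
)
hidden_states = torch.randn(1, 16, 64, 64)        # latent image (b, c, h, w)
encoder_hidden_states = torch.randn(1, 77, 4096)  # joint text embeddings
pooled_projections = torch.randn(1, 2048)
timestep = torch.tensor([999])
out = model(hidden_states=hidden_states, encoder_hidden_states=encoder_hidden_states,
            pooled_projections=pooled_projections, timestep=timestep)
print(out.sample.shape)  # torch.Size([1, 16, 64, 64]) -- same spatial size after unpatchify
```
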